[RAG] Fix retrieval offset in RAG's HfIndex and better integration tests (#7372)

* Fix retrieval offset in RAG's HfIndex

* update slow tests

* style

* fix new test

* style

* add better tests

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Quentin Lhoest 2020-09-25 16:12:46 +02:00 committed by GitHub
parent 571c7a11c1
commit cf1c88e092
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 138 additions and 124 deletions

View File

@ -153,4 +153,4 @@ class RagRetrieverTest(TestCase):
self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
self.assertEqual(doc_dicts[0]["id"][0], "1") # max inner product is reached with second doc
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
self.assertListEqual(list(doc_ids), [1, 0])
self.assertListEqual(doc_ids.tolist(), [[1], [0]])

View File

@ -203,6 +203,7 @@ class HFIndex:
dataset_name: str,
dataset_split: str,
index_name: str,
vector_size: int,
index_path: Optional[str] = None,
use_dummy_dataset=False,
):
@ -210,6 +211,7 @@ class HFIndex:
self.dataset_name = dataset_name
self.dataset_split = dataset_split
self.index_name = index_name
self.vector_size = vector_size
self.index_path = index_path
self.use_dummy_dataset = use_dummy_dataset
self._index_initialize = False
@ -218,6 +220,7 @@ class HFIndex:
self.dataset = load_dataset(
self.dataset_name, with_index=False, split=self.dataset_split, dummy=self.use_dummy_dataset
)
self.dataset.set_format("numpy", columns=["embeddings"], output_all_columns=True)
# Report whether this index has been initialized, as tracked by the `_index_initialize` flag.
def is_initialized(self):
return self._index_initialize
@ -236,15 +239,19 @@ class HFIndex:
index_name=self.index_name,
dummy=self.use_dummy_dataset,
)
self.dataset.set_format("numpy", columns=["embeddings"], output_all_columns=True)
self._index_initialize = True
def get_doc_dicts(self, doc_ids: np.ndarray) -> List[dict]:
    """Fetch one dataset entry per row of ``doc_ids``.

    Each row along the first axis is converted to plain Python values with
    ``tolist()`` and used to index ``self.dataset``. Results are returned in
    row order.
    """
    return [self.dataset[row_ids.tolist()] for row_ids in doc_ids]
# For each query vector in `question_hidden_states`, look up to `n_docs` nearest entries in the
# "embeddings" index and return their ids together with their embedding vectors.
def get_top_docs(self, question_hidden_states: np.ndarray, n_docs=5) -> Tuple[np.ndarray, np.ndarray]:
# NOTE(review): the next two statements assign `docs` and `ids`, which are both reassigned right
# after by the `search_batch` lookup; in this rendered diff they look like the pre-change lines — confirm.
_, docs = self.dataset.get_nearest_examples_batch("embeddings", question_hidden_states, n_docs)
ids = [[int(i) for i in doc["id"]] for doc in docs]
_, ids = self.dataset.search_batch("embeddings", question_hidden_states, n_docs)
# Negative indices are dropped before the rows are fetched from the dataset.
docs = [self.dataset[[i for i in indices if i >= 0]] for indices in ids]
vectors = [doc["embeddings"] for doc in docs]
# Pad with zero rows so every query ends up with `n_docs` vectors of width `self.vector_size`.
for i in range(len(vectors)):
if len(vectors[i]) < n_docs:
vectors[i] = np.vstack([vectors[i], np.zeros((n_docs - len(vectors[i]), self.vector_size))])
return np.array(ids), np.array(vectors) # shapes (batch_size, n_docs) and (batch_size, n_docs, d)
@ -274,7 +281,12 @@ class RagRetriever:
)
if config.index_name == "legacy"
else HFIndex(
config.dataset, config.dataset_split, config.index_name, config.index_path, config.use_dummy_dataset
config.dataset,
config.dataset_split,
config.index_name,
config.retrieval_vector_size,
config.index_path,
config.use_dummy_dataset,
)
)
self.generator_tokenizer = generator_tokenizer
@ -384,8 +396,9 @@ class RagRetriever:
)
ids_batched.extend(ids)
vectors_batched.extend(vectors)
return np.array(ids_batched), np.array(
vectors_batched
return (
np.array(ids_batched),
np.array(vectors_batched),
) # shapes (batch_size, n_docs) and (batch_size, n_docs, d)
def retrieve(self, question_hidden_states: np.ndarray, n_docs: int) -> Tuple[np.ndarray, List[dict]]:

View File

@ -54,6 +54,7 @@ if is_torch_available() and is_datasets_available() and is_faiss_available():
RagRetriever,
RagSequenceForGeneration,
RagTokenForGeneration,
RagTokenizer,
)
from transformers.modeling_outputs import BaseModelOutput
@ -519,7 +520,7 @@ class RagModelIntegrationTests(unittest.TestCase):
expected_doc_scores = torch.tensor([[75.0286, 74.4998, 74.0804, 74.0306, 73.9504]]).to(torch_device)
_assert_tensors_equal(expected_doc_scores, output.doc_scores, atol=TOLERANCE)
expected_loss = torch.tensor([38.7446]).to(torch_device)
expected_loss = torch.tensor([36.7368]).to(torch_device)
_assert_tensors_equal(expected_loss, output.loss, atol=TOLERANCE)
@slow
@ -558,7 +559,7 @@ class RagModelIntegrationTests(unittest.TestCase):
expected_doc_scores = torch.tensor([[75.0286, 74.4998, 74.0804, 74.0306, 73.9504]]).to(torch_device)
_assert_tensors_equal(expected_doc_scores, output.doc_scores, atol=TOLERANCE)
expected_loss = torch.tensor([38.7045]).to(torch_device)
expected_loss = torch.tensor([36.3557]).to(torch_device)
_assert_tensors_equal(expected_loss, output.loss, atol=TOLERANCE)
@slow
@ -594,122 +595,12 @@ class RagModelIntegrationTests(unittest.TestCase):
output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True)
# Expected outputs as given by model at integration time.
EXPECTED_OUTPUT_TEXT_1 = "The songwriting credits are credited to ABBA"
EXPECTED_OUTPUT_TEXT_2 = 'The songwriting credits are credited to "B'
EXPECTED_OUTPUT_TEXT_1 = "\"She's My Kind of Girl"
EXPECTED_OUTPUT_TEXT_2 = "\"She's My Kind of Love"
self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1)
self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2)
# Slow integration test: batch generation with `self.token_model` over three questions,
# using a retriever built from the RAG config (4 beams, one returned sequence, max length 10).
@slow
def test_rag_token_generate_batch(self):
rag_config = self.get_rag_config()
rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
"facebook/dpr-question_encoder-single-nq-base"
)
rag_retriever = RagRetriever(
rag_config,
question_encoder_tokenizer=rag_question_encoder_tokenizer,
generator_tokenizer=rag_decoder_tokenizer,
)
rag_token = self.token_model
rag_token.set_retriever(rag_retriever)
questions = [
"who sings does he love me with reba",
"how many pages is invisible man by ralph ellison",
"what",
]
input_dict = rag_question_encoder_tokenizer.batch_encode_plus(
questions,
return_tensors="pt",
padding=True,
truncation=True,
)
input_ids = input_dict.input_ids.to(torch_device)
attention_mask = input_dict.attention_mask.to(torch_device)
output_ids = rag_token.generate(
input_ids,
attention_mask=attention_mask,
decoder_start_token_id=rag_token.generator.config.decoder_start_token_id,
num_beams=4,
num_return_sequences=1,
max_length=10,
)
# token generate test
output_text_1 = rag_decoder_tokenizer.decode(output_ids[0], skip_special_tokens=True)
output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True)
output_text_3 = rag_decoder_tokenizer.decode(output_ids[2], skip_special_tokens=True)
# Expected outputs as given by model at integration time.
EXPECTED_OUTPUT_TEXT_1 = '"People Need Love" is the'
EXPECTED_OUTPUT_TEXT_2 = '"How many pages is invisible man'
EXPECTED_OUTPUT_TEXT_3 = "Otis the Aardvark"
self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1)
self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2)
self.assertEqual(output_text_3, EXPECTED_OUTPUT_TEXT_3)
# Slow integration test: batch generation with `self.sequence_model` over three questions,
# using a retriever built from the RAG config (4 beams, one returned sequence, max length 10).
@slow
def test_rag_sequence_generate_batch(self):
# IMPORTANT: This test fails on GPU, but is fine on CPU -> beam search is very sensitive
rag_config = self.get_rag_config()
rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
"facebook/dpr-question_encoder-single-nq-base"
)
rag_retriever = RagRetriever(
rag_config,
question_encoder_tokenizer=rag_question_encoder_tokenizer,
generator_tokenizer=rag_decoder_tokenizer,
)
rag_sequence = self.sequence_model
rag_sequence.set_retriever(rag_retriever)
questions = [
"who sings does he love me with reba",
"how many pages is invisible man by ralph ellison",
"what",
]
input_dict = rag_question_encoder_tokenizer.batch_encode_plus(
questions,
return_tensors="pt",
padding=True,
truncation=True,
)
input_ids = input_dict.input_ids.to(torch_device)
attention_mask = input_dict.attention_mask.to(torch_device)
output_ids = rag_sequence.generate(
input_ids,
attention_mask=attention_mask,
decoder_start_token_id=rag_sequence.generator.config.decoder_start_token_id,
num_beams=4,
num_return_sequences=1,
max_length=10,
)
# sequence generate test
output_text_1 = rag_decoder_tokenizer.decode(output_ids[0], skip_special_tokens=True)
output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True)
output_text_3 = rag_decoder_tokenizer.decode(output_ids[2], skip_special_tokens=True)
# Expected outputs as given by model at integration time.
EXPECTED_OUTPUT_TEXT_1 = '"I Know Him So Well"'
EXPECTED_OUTPUT_TEXT_2 = '"Howl" chronicles the'
EXPECTED_OUTPUT_TEXT_3 = "Otis the Aardvark"
self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1)
self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2)
self.assertEqual(output_text_3, EXPECTED_OUTPUT_TEXT_3)
@slow
def test_rag_sequence_generate_beam(self):
rag_config = self.get_rag_config()
@ -743,12 +634,122 @@ class RagModelIntegrationTests(unittest.TestCase):
output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True)
# Expected outputs as given by model at integration time.
EXPECTED_OUTPUT_TEXT_1 = """ ABBA / small label like Playboy Records did not have the distribution resources to meet the demand for the single from retailers and radio programmers. The foursome decided to record their first album together in late 1972, and sessions began on 26 September 1972. The women shared lead vocals on "Nina, Pretty Ballerina" that day."""
EXPECTED_OUTPUT_TEXT_2 = """ ABBA / small label like Playboy Records did not have the distribution resources to meet the demand for the single from retailers and radio programmers. The foursome decided to record their first album together in late 1972, and sessions began on 26 September 1972. The women shared lead vocals on "Nina, Pretty Ballerina" (a top ten hit in Austria)"""
EXPECTED_OUTPUT_TEXT_1 = """\"She's My Kind of Girl\" was released through Epic Records in Japan in March 1972, giving the duo a Top 10 hit. Two more singles were released in Japan, \"En Carousel\" and \"Love Has Its Ways\" Ulvaeus and Andersson persevered with their songwriting and experimented with new sounds and vocal arrangements."""
EXPECTED_OUTPUT_TEXT_2 = """In September 2018, Björn Ulvaeus revealed that the two new songs, \"I Still Have Faith In You\" and \"Don't Shut Me Down\", would be released no earlier than March 2019. The two new tracks will feature in a TV special set to air later in the year."""
self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1)
self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2)
@property
def test_data_questions(self):
    """Return a fresh list of the fifteen natural-language questions used as batch input."""
    questions = (
        "who got the first nobel prize in physics",
        "when is the next deadpool movie being released",
        "which mode is used for short wave broadcast service",
        "who is the owner of reading football club",
        "when is the next scandal episode coming out",
        "when is the last time the philadelphia won the superbowl",
        "what is the most current adobe flash player version",
        "how many episodes are there in dragon ball z",
        "what is the first step in the evolution of the eye",
        "where is gall bladder situated in human body",
        "what is the main mineral in lithium batteries",
        "who is the president of usa right now",
        "where do the greasers live in the outsiders",
        "panda is a national animal of which country",
        "what is the name of manchester united stadium",
    )
    # Hand out a new list on every access so callers cannot mutate shared state.
    return list(questions)
# Slow integration test: generate answers for `self.test_data_questions` with a model and
# retriever loaded from "facebook/rag-sequence-nq" and compare them with the recorded outputs.
@slow
def test_rag_sequence_generate_batch(self):
tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
retriever = RagRetriever.from_pretrained(
"facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
)
# NOTE(review): the test and variable are named for the sequence model, but the class loaded here is
# RagTokenForGeneration — confirm whether RagSequenceForGeneration was intended. The expected outputs
# below would need to be re-recorded if the class is changed.
rag_sequence = RagTokenForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever).to(
torch_device
)
input_dict = tokenizer(
self.test_data_questions,
return_tensors="pt",
padding=True,
truncation=True,
)
input_ids = input_dict.input_ids.to(torch_device)
attention_mask = input_dict.attention_mask.to(torch_device)
output_ids = rag_sequence.generate(
input_ids,
attention_mask=attention_mask,
)
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
# One expected answer per question, in the same order as `self.test_data_questions`.
EXPECTED_OUTPUTS = [
" albert einstein",
" june 22, 2018",
" amplitude modulation",
" tim besley ( chairman )",
" june 20, 2018",
" 1980",
" 7.0",
" 8",
" reticular formation",
" walls of the abdomen",
" spodumene",
" obama",
" grainger's compound",
" japan",
" old trafford stadium",
]
self.assertListEqual(outputs, EXPECTED_OUTPUTS)
@slow
def test_rag_token_generate_batch(self):
    """Generate answers for the shared question batch with the pretrained RAG-token model.

    The tokenizer, retriever (exact index, dummy dataset) and model are all loaded
    from the same checkpoint; decoded outputs must match the recorded answers.
    """
    checkpoint = "facebook/rag-token-nq"
    tokenizer = RagTokenizer.from_pretrained(checkpoint)
    retriever = RagRetriever.from_pretrained(checkpoint, index_name="exact", use_dummy_dataset=True)
    rag_token = RagTokenForGeneration.from_pretrained(checkpoint, retriever=retriever)
    rag_token = rag_token.to(torch_device)

    encoded = tokenizer(self.test_data_questions, return_tensors="pt", padding=True, truncation=True)
    question_ids = encoded.input_ids.to(torch_device)
    question_mask = encoded.attention_mask.to(torch_device)

    generated = rag_token.generate(question_ids, attention_mask=question_mask)
    outputs = tokenizer.batch_decode(generated, skip_special_tokens=True)

    # One expected answer per question, in the same order as `self.test_data_questions`.
    expected_outputs = [
        " albert einstein",
        " september 22, 2017",
        " amplitude modulation",
        " stefan persson",
        " april 20, 2018",
        " the 1970s",
        " 7.1. 2",
        " 13",
        " step by step",
        " stomach",
        " spodumene",
        " obama",
        " northern new jersey",
        " india",
        " united stadium",
    ]
    self.assertListEqual(outputs, expected_outputs)
@require_torch
@require_retrieval

View File

@ -166,7 +166,7 @@ class RagRetrieverTest(TestCase):
self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
self.assertEqual(doc_dicts[0]["id"][0], "1") # max inner product is reached with second doc
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
self.assertListEqual(list(doc_ids), [1, 0])
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
def test_legacy_index_retriever_retrieve(self):
n_docs = 1
@ -181,7 +181,7 @@ class RagRetrieverTest(TestCase):
self.assertEqual(len(doc_dicts[0]["text"]), n_docs)
self.assertEqual(doc_dicts[0]["text"][0], "bar") # max inner product is reached with second doc
self.assertEqual(doc_dicts[1]["text"][0], "foo") # max inner product is reached with first doc
self.assertListEqual(list(doc_ids), [1, 0])
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
@require_torch
def test_hf_index_retriever_call(self):