[RAG] Fix retrieval offset in RAG's HfIndex and better integration tests (#7372)
* Fix retrieval offset in RAG's HfIndex * update slow tests * style * fix new test * style * add better tests Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
parent
571c7a11c1
commit
cf1c88e092
|
@ -153,4 +153,4 @@ class RagRetrieverTest(TestCase):
|
|||
self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
|
||||
self.assertEqual(doc_dicts[0]["id"][0], "1") # max inner product is reached with second doc
|
||||
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
|
||||
self.assertListEqual(list(doc_ids), [1, 0])
|
||||
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
|
||||
|
|
|
@ -203,6 +203,7 @@ class HFIndex:
|
|||
dataset_name: str,
|
||||
dataset_split: str,
|
||||
index_name: str,
|
||||
vector_size: int,
|
||||
index_path: Optional[str] = None,
|
||||
use_dummy_dataset=False,
|
||||
):
|
||||
|
@ -210,6 +211,7 @@ class HFIndex:
|
|||
self.dataset_name = dataset_name
|
||||
self.dataset_split = dataset_split
|
||||
self.index_name = index_name
|
||||
self.vector_size = vector_size
|
||||
self.index_path = index_path
|
||||
self.use_dummy_dataset = use_dummy_dataset
|
||||
self._index_initialize = False
|
||||
|
@ -218,6 +220,7 @@ class HFIndex:
|
|||
self.dataset = load_dataset(
|
||||
self.dataset_name, with_index=False, split=self.dataset_split, dummy=self.use_dummy_dataset
|
||||
)
|
||||
self.dataset.set_format("numpy", columns=["embeddings"], output_all_columns=True)
|
||||
|
||||
def is_initialized(self):
|
||||
return self._index_initialize
|
||||
|
@ -236,15 +239,19 @@ class HFIndex:
|
|||
index_name=self.index_name,
|
||||
dummy=self.use_dummy_dataset,
|
||||
)
|
||||
self.dataset.set_format("numpy", columns=["embeddings"], output_all_columns=True)
|
||||
self._index_initialize = True
|
||||
|
||||
def get_doc_dicts(self, doc_ids: np.ndarray) -> List[dict]:
|
||||
return [self.dataset[doc_ids[i].tolist()] for i in range(doc_ids.shape[0])]
|
||||
|
||||
def get_top_docs(self, question_hidden_states: np.ndarray, n_docs=5) -> Tuple[np.ndarray, np.ndarray]:
|
||||
_, docs = self.dataset.get_nearest_examples_batch("embeddings", question_hidden_states, n_docs)
|
||||
ids = [[int(i) for i in doc["id"]] for doc in docs]
|
||||
_, ids = self.dataset.search_batch("embeddings", question_hidden_states, n_docs)
|
||||
docs = [self.dataset[[i for i in indices if i >= 0]] for indices in ids]
|
||||
vectors = [doc["embeddings"] for doc in docs]
|
||||
for i in range(len(vectors)):
|
||||
if len(vectors[i]) < n_docs:
|
||||
vectors[i] = np.vstack([vectors[i], np.zeros((n_docs - len(vectors[i]), self.vector_size))])
|
||||
return np.array(ids), np.array(vectors) # shapes (batch_size, n_docs) and (batch_size, n_docs, d)
|
||||
|
||||
|
||||
|
@ -274,7 +281,12 @@ class RagRetriever:
|
|||
)
|
||||
if config.index_name == "legacy"
|
||||
else HFIndex(
|
||||
config.dataset, config.dataset_split, config.index_name, config.index_path, config.use_dummy_dataset
|
||||
config.dataset,
|
||||
config.dataset_split,
|
||||
config.index_name,
|
||||
config.retrieval_vector_size,
|
||||
config.index_path,
|
||||
config.use_dummy_dataset,
|
||||
)
|
||||
)
|
||||
self.generator_tokenizer = generator_tokenizer
|
||||
|
@ -384,8 +396,9 @@ class RagRetriever:
|
|||
)
|
||||
ids_batched.extend(ids)
|
||||
vectors_batched.extend(vectors)
|
||||
return np.array(ids_batched), np.array(
|
||||
vectors_batched
|
||||
return (
|
||||
np.array(ids_batched),
|
||||
np.array(vectors_batched),
|
||||
) # shapes (batch_size, n_docs) and (batch_size, n_docs, d)
|
||||
|
||||
def retrieve(self, question_hidden_states: np.ndarray, n_docs: int) -> Tuple[np.ndarray, List[dict]]:
|
||||
|
|
|
@ -54,6 +54,7 @@ if is_torch_available() and is_datasets_available() and is_faiss_available():
|
|||
RagRetriever,
|
||||
RagSequenceForGeneration,
|
||||
RagTokenForGeneration,
|
||||
RagTokenizer,
|
||||
)
|
||||
from transformers.modeling_outputs import BaseModelOutput
|
||||
|
||||
|
@ -519,7 +520,7 @@ class RagModelIntegrationTests(unittest.TestCase):
|
|||
expected_doc_scores = torch.tensor([[75.0286, 74.4998, 74.0804, 74.0306, 73.9504]]).to(torch_device)
|
||||
_assert_tensors_equal(expected_doc_scores, output.doc_scores, atol=TOLERANCE)
|
||||
|
||||
expected_loss = torch.tensor([38.7446]).to(torch_device)
|
||||
expected_loss = torch.tensor([36.7368]).to(torch_device)
|
||||
_assert_tensors_equal(expected_loss, output.loss, atol=TOLERANCE)
|
||||
|
||||
@slow
|
||||
|
@ -558,7 +559,7 @@ class RagModelIntegrationTests(unittest.TestCase):
|
|||
expected_doc_scores = torch.tensor([[75.0286, 74.4998, 74.0804, 74.0306, 73.9504]]).to(torch_device)
|
||||
_assert_tensors_equal(expected_doc_scores, output.doc_scores, atol=TOLERANCE)
|
||||
|
||||
expected_loss = torch.tensor([38.7045]).to(torch_device)
|
||||
expected_loss = torch.tensor([36.3557]).to(torch_device)
|
||||
_assert_tensors_equal(expected_loss, output.loss, atol=TOLERANCE)
|
||||
|
||||
@slow
|
||||
|
@ -594,122 +595,12 @@ class RagModelIntegrationTests(unittest.TestCase):
|
|||
output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True)
|
||||
|
||||
# Expected outputs as given by model at integration time.
|
||||
EXPECTED_OUTPUT_TEXT_1 = "The songwriting credits are credited to ABBA"
|
||||
EXPECTED_OUTPUT_TEXT_2 = 'The songwriting credits are credited to "B'
|
||||
EXPECTED_OUTPUT_TEXT_1 = "\"She's My Kind of Girl"
|
||||
EXPECTED_OUTPUT_TEXT_2 = "\"She's My Kind of Love"
|
||||
|
||||
self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1)
|
||||
self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2)
|
||||
|
||||
@slow
|
||||
def test_rag_token_generate_batch(self):
|
||||
rag_config = self.get_rag_config()
|
||||
rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
|
||||
rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
|
||||
"facebook/dpr-question_encoder-single-nq-base"
|
||||
)
|
||||
rag_retriever = RagRetriever(
|
||||
rag_config,
|
||||
question_encoder_tokenizer=rag_question_encoder_tokenizer,
|
||||
generator_tokenizer=rag_decoder_tokenizer,
|
||||
)
|
||||
|
||||
rag_token = self.token_model
|
||||
rag_token.set_retriever(rag_retriever)
|
||||
|
||||
questions = [
|
||||
"who sings does he love me with reba",
|
||||
"how many pages is invisible man by ralph ellison",
|
||||
"what",
|
||||
]
|
||||
input_dict = rag_question_encoder_tokenizer.batch_encode_plus(
|
||||
questions,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
)
|
||||
|
||||
input_ids = input_dict.input_ids.to(torch_device)
|
||||
attention_mask = input_dict.attention_mask.to(torch_device)
|
||||
|
||||
output_ids = rag_token.generate(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_start_token_id=rag_token.generator.config.decoder_start_token_id,
|
||||
num_beams=4,
|
||||
num_return_sequences=1,
|
||||
max_length=10,
|
||||
)
|
||||
|
||||
# sequence generate test
|
||||
output_text_1 = rag_decoder_tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
||||
output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True)
|
||||
output_text_3 = rag_decoder_tokenizer.decode(output_ids[2], skip_special_tokens=True)
|
||||
|
||||
# Expected outputs as given by model at integration time.
|
||||
EXPECTED_OUTPUT_TEXT_1 = '"People Need Love" is the'
|
||||
EXPECTED_OUTPUT_TEXT_2 = '"How many pages is invisible man'
|
||||
EXPECTED_OUTPUT_TEXT_3 = "Otis the Aardvark"
|
||||
|
||||
self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1)
|
||||
self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2)
|
||||
self.assertEqual(output_text_3, EXPECTED_OUTPUT_TEXT_3)
|
||||
|
||||
@slow
|
||||
def test_rag_sequence_generate_batch(self):
|
||||
# IMPORTAN: This test fails on GPU, but is fine on CPU -> beam search is very sensible
|
||||
rag_config = self.get_rag_config()
|
||||
rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
|
||||
rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
|
||||
"facebook/dpr-question_encoder-single-nq-base"
|
||||
)
|
||||
rag_retriever = RagRetriever(
|
||||
rag_config,
|
||||
question_encoder_tokenizer=rag_question_encoder_tokenizer,
|
||||
generator_tokenizer=rag_decoder_tokenizer,
|
||||
)
|
||||
|
||||
rag_sequence = self.sequence_model
|
||||
rag_sequence.set_retriever(rag_retriever)
|
||||
|
||||
questions = [
|
||||
"who sings does he love me with reba",
|
||||
"how many pages is invisible man by ralph ellison",
|
||||
"what",
|
||||
]
|
||||
|
||||
input_dict = rag_question_encoder_tokenizer.batch_encode_plus(
|
||||
questions,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
)
|
||||
|
||||
input_ids = input_dict.input_ids.to(torch_device)
|
||||
attention_mask = input_dict.attention_mask.to(torch_device)
|
||||
|
||||
output_ids = rag_sequence.generate(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_start_token_id=rag_sequence.generator.config.decoder_start_token_id,
|
||||
num_beams=4,
|
||||
num_return_sequences=1,
|
||||
max_length=10,
|
||||
)
|
||||
|
||||
# sequence generate test
|
||||
output_text_1 = rag_decoder_tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
||||
output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True)
|
||||
output_text_3 = rag_decoder_tokenizer.decode(output_ids[2], skip_special_tokens=True)
|
||||
|
||||
# Expected outputs as given by model at integration time.
|
||||
EXPECTED_OUTPUT_TEXT_1 = '"I Know Him So Well"'
|
||||
EXPECTED_OUTPUT_TEXT_2 = '"Howl" chronicles the'
|
||||
EXPECTED_OUTPUT_TEXT_3 = "Otis the Aardvark"
|
||||
|
||||
self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1)
|
||||
self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2)
|
||||
self.assertEqual(output_text_3, EXPECTED_OUTPUT_TEXT_3)
|
||||
|
||||
@slow
|
||||
def test_rag_sequence_generate_beam(self):
|
||||
rag_config = self.get_rag_config()
|
||||
|
@ -743,12 +634,122 @@ class RagModelIntegrationTests(unittest.TestCase):
|
|||
output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True)
|
||||
|
||||
# Expected outputs as given by model at integration time.
|
||||
EXPECTED_OUTPUT_TEXT_1 = """ ABBA / small label like Playboy Records did not have the distribution resources to meet the demand for the single from retailers and radio programmers. The foursome decided to record their first album together in late 1972, and sessions began on 26 September 1972. The women shared lead vocals on "Nina, Pretty Ballerina" that day."""
|
||||
EXPECTED_OUTPUT_TEXT_2 = """ ABBA / small label like Playboy Records did not have the distribution resources to meet the demand for the single from retailers and radio programmers. The foursome decided to record their first album together in late 1972, and sessions began on 26 September 1972. The women shared lead vocals on "Nina, Pretty Ballerina" (a top ten hit in Austria)"""
|
||||
EXPECTED_OUTPUT_TEXT_1 = """\"She's My Kind of Girl\" was released through Epic Records in Japan in March 1972, giving the duo a Top 10 hit. Two more singles were released in Japan, \"En Carousel\" and \"Love Has Its Ways\" Ulvaeus and Andersson persevered with their songwriting and experimented with new sounds and vocal arrangements."""
|
||||
EXPECTED_OUTPUT_TEXT_2 = """In September 2018, Björn Ulvaeus revealed that the two new songs, \"I Still Have Faith In You\" and \"Don't Shut Me Down\", would be released no earlier than March 2019. The two new tracks will feature in a TV special set to air later in the year."""
|
||||
|
||||
self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1)
|
||||
self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2)
|
||||
|
||||
@property
|
||||
def test_data_questions(self):
|
||||
return [
|
||||
"who got the first nobel prize in physics",
|
||||
"when is the next deadpool movie being released",
|
||||
"which mode is used for short wave broadcast service",
|
||||
"who is the owner of reading football club",
|
||||
"when is the next scandal episode coming out",
|
||||
"when is the last time the philadelphia won the superbowl",
|
||||
"what is the most current adobe flash player version",
|
||||
"how many episodes are there in dragon ball z",
|
||||
"what is the first step in the evolution of the eye",
|
||||
"where is gall bladder situated in human body",
|
||||
"what is the main mineral in lithium batteries",
|
||||
"who is the president of usa right now",
|
||||
"where do the greasers live in the outsiders",
|
||||
"panda is a national animal of which country",
|
||||
"what is the name of manchester united stadium",
|
||||
]
|
||||
|
||||
@slow
|
||||
def test_rag_sequence_generate_batch(self):
|
||||
tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
|
||||
retriever = RagRetriever.from_pretrained(
|
||||
"facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
|
||||
)
|
||||
rag_sequence = RagTokenForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever).to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
input_dict = tokenizer(
|
||||
self.test_data_questions,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
)
|
||||
|
||||
input_ids = input_dict.input_ids.to(torch_device)
|
||||
attention_mask = input_dict.attention_mask.to(torch_device)
|
||||
|
||||
output_ids = rag_sequence.generate(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||
|
||||
EXPECTED_OUTPUTS = [
|
||||
" albert einstein",
|
||||
" june 22, 2018",
|
||||
" amplitude modulation",
|
||||
" tim besley ( chairman )",
|
||||
" june 20, 2018",
|
||||
" 1980",
|
||||
" 7.0",
|
||||
" 8",
|
||||
" reticular formation",
|
||||
" walls of the abdomen",
|
||||
" spodumene",
|
||||
" obama",
|
||||
" grainger's compound",
|
||||
" japan",
|
||||
" old trafford stadium",
|
||||
]
|
||||
self.assertListEqual(outputs, EXPECTED_OUTPUTS)
|
||||
|
||||
@slow
|
||||
def test_rag_token_generate_batch(self):
|
||||
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
|
||||
retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True)
|
||||
rag_token = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever).to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
input_dict = tokenizer(
|
||||
self.test_data_questions,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
)
|
||||
|
||||
input_ids = input_dict.input_ids.to(torch_device)
|
||||
attention_mask = input_dict.attention_mask.to(torch_device)
|
||||
|
||||
output_ids = rag_token.generate(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||
|
||||
EXPECTED_OUTPUTS = [
|
||||
" albert einstein",
|
||||
" september 22, 2017",
|
||||
" amplitude modulation",
|
||||
" stefan persson",
|
||||
" april 20, 2018",
|
||||
" the 1970s",
|
||||
" 7.1. 2",
|
||||
" 13",
|
||||
" step by step",
|
||||
" stomach",
|
||||
" spodumene",
|
||||
" obama",
|
||||
" northern new jersey",
|
||||
" india",
|
||||
" united stadium",
|
||||
]
|
||||
self.assertListEqual(outputs, EXPECTED_OUTPUTS)
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_retrieval
|
||||
|
|
|
@ -166,7 +166,7 @@ class RagRetrieverTest(TestCase):
|
|||
self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
|
||||
self.assertEqual(doc_dicts[0]["id"][0], "1") # max inner product is reached with second doc
|
||||
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
|
||||
self.assertListEqual(list(doc_ids), [1, 0])
|
||||
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
|
||||
|
||||
def test_legacy_index_retriever_retrieve(self):
|
||||
n_docs = 1
|
||||
|
@ -181,7 +181,7 @@ class RagRetrieverTest(TestCase):
|
|||
self.assertEqual(len(doc_dicts[0]["text"]), n_docs)
|
||||
self.assertEqual(doc_dicts[0]["text"][0], "bar") # max inner product is reached with second doc
|
||||
self.assertEqual(doc_dicts[1]["text"][0], "foo") # max inner product is reached with first doc
|
||||
self.assertListEqual(list(doc_ids), [1, 0])
|
||||
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
|
||||
|
||||
@require_torch
|
||||
def test_hf_index_retriever_call(self):
|
||||
|
|
Loading…
Reference in New Issue