fix rag retriever save pretrained (#7399)

This commit is contained in:
Patrick von Platen 2020-09-25 19:47:12 +02:00 committed by GitHub
parent 1a14687e6f
commit 2c8ecdf8a8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 2 deletions

View File

@@ -312,8 +312,8 @@ class RagRetriever:
def save_pretrained(self, save_directory):
# Calls `save_pretrained(save_directory)` on the retriever's config and on a RagTokenizer
# built from the question-encoder and generator tokenizers held by this retriever.
self.config.save_pretrained(save_directory)
rag_tokenizer = RagTokenizer(
# NOTE(review): this is a flattened diff hunk whose +/- markers were stripped, so old and
# new lines appear side by side. The first two keyword arguments below are presumably the
# removed (pre-fix) names and the last two their replacements -- confirm against
# RagTokenizer's constructor signature before treating this as runnable code.
question_encoder_tokenizer=self.question_encoder_tokenizer,
generator_tokenizer=self.generator_tokenizer,
question_encoder=self.question_encoder_tokenizer,
generator=self.generator_tokenizer,
)
rag_tokenizer.save_pretrained(save_directory)

View File

@@ -168,6 +168,11 @@ class RagRetrieverTest(TestCase):
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
def test_save_and_from_pretrained(self):
    """Check that the retriever from ``get_dummy_hf_index_retriever`` saves without raising.

    The retriever is written into a temporary directory that is removed when the
    ``with`` block exits.

    NOTE(review): despite the method name, only ``save_pretrained`` is exercised here;
    nothing is loaded back and no assertion is made on the saved files.
    """
    dummy_retriever = self.get_dummy_hf_index_retriever()
    with tempfile.TemporaryDirectory() as save_dir:
        dummy_retriever.save_pretrained(save_dir)
def test_legacy_index_retriever_retrieve(self):
n_docs = 1
retriever = self.get_dummy_legacy_index_retriever()