fix rag retriever save pretrained (#7399)
This commit is contained in:
parent
1a14687e6f
commit
2c8ecdf8a8
|
@ -312,8 +312,8 @@ class RagRetriever:
|
|||
def save_pretrained(self, save_directory):
|
||||
self.config.save_pretrained(save_directory)
|
||||
rag_tokenizer = RagTokenizer(
|
||||
question_encoder_tokenizer=self.question_encoder_tokenizer,
|
||||
generator_tokenizer=self.generator_tokenizer,
|
||||
question_encoder=self.question_encoder_tokenizer,
|
||||
generator=self.generator_tokenizer,
|
||||
)
|
||||
rag_tokenizer.save_pretrained(save_directory)
|
||||
|
||||
|
|
|
@ -168,6 +168,11 @@ class RagRetrieverTest(TestCase):
|
|||
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
|
||||
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
|
||||
|
||||
def test_save_and_from_pretrained(self):
|
||||
retriever = self.get_dummy_hf_index_retriever()
|
||||
with tempfile.TemporaryDirectory() as tmp_dirname:
|
||||
retriever.save_pretrained(tmp_dirname)
|
||||
|
||||
def test_legacy_index_retriever_retrieve(self):
|
||||
n_docs = 1
|
||||
retriever = self.get_dummy_legacy_index_retriever()
|
||||
|
|
Loading…
Reference in New Issue