[Rag] Fix wrong usage of `num_beams` and `bos_token_id` in Rag Sequence generation (#7386)

* fix_rag_sequence

* add second bug fix
Patrick von Platen 2020-09-25 14:35:49 +02:00 committed by GitHub
parent 415071b4c2
commit 571c7a11c1
1 changed file with 3 additions and 2 deletions


@@ -882,7 +882,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
         hypos = []
         kwargs["num_beams"] = num_beams
-        kwargs["num_return_sequences"] = num_return_sequences
+        kwargs["num_return_sequences"] = num_beams
         kwargs["attention_mask"] = None
 
         for index in range(len(input_ids)):
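The first change ties the number of sequences returned by the underlying generator to `num_beams`: RAG sequence generation first collects candidate hypotheses per retrieved document and only applies the user-requested `num_return_sequences` after re-scoring them with the full model. A minimal sketch of the intent, with illustrative values rather than the actual call site in `modeling_rag.py`:

```python
# Sketch: forward all beams to the generator; the user-facing
# num_return_sequences is only applied later, after re-scoring.
num_beams = 4
num_return_sequences = 2  # what the caller ultimately wants back

kwargs = {}
kwargs["num_beams"] = num_beams
# Previously this was set to num_return_sequences, so candidate beams were
# dropped before re-scoring; the fix requests every beam from the generator.
kwargs["num_return_sequences"] = num_beams
kwargs["attention_mask"] = None
print(kwargs)
```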
@@ -916,7 +916,8 @@ class RagSequenceForGeneration(RagPreTrainedModel):
             )
 
         # bos_token_id is None for T5
-        use_bos = self.config.bos_token_id is not None and target[:, 0].eq(self.config.bos_token_id).all()
+        bos_token_id = self.config.bos_token_id or self.config.generator.bos_token_id
+        use_bos = bos_token_id is not None and target[:, 0].eq(bos_token_id).all()
 
         def _mask_pads(ll, smooth_obj):
             pad_mask = target.eq(self.config.generator.pad_token_id)
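The second change adds a fallback for the case where the composite Rag config carries no `bos_token_id` even though the underlying generator does, so the leading-token check in the loss still recognizes targets that start with BOS. A minimal, self-contained sketch of the fallback logic (the config classes below are hypothetical stand-ins, not the real Rag configuration objects):

```python
import torch

# Hypothetical stand-ins for config.bos_token_id / config.generator.bos_token_id.
class GeneratorConfig:
    bos_token_id = 0       # the underlying generator (e.g. BART-like) defines a BOS id

class RagConfig:
    bos_token_id = None    # the composite Rag config may not carry one
    generator = GeneratorConfig()

config = RagConfig()
target = torch.tensor([[0, 11, 12], [0, 13, 14]])  # targets starting with BOS

# Before the fix only config.bos_token_id was consulted, so targets beginning
# with the generator's BOS token were not recognized; the fallback also checks
# the generator's config (for T5 both ids are None and use_bos stays False).
bos_token_id = config.bos_token_id or config.generator.bos_token_id
use_bos = bos_token_id is not None and target[:, 0].eq(bos_token_id).all()
print(use_bos)  # tensor(True)
```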