[Rag] Fix wrong usage of `num_beams` and `bos_token_id` in Rag Sequence generation (#7386)
* Fix wrong usage of `num_beams` in RagSequence generation
* Second bug fix: fall back to the generator's `bos_token_id` when the RAG config's is `None` (e.g. T5)
This commit is contained in:
parent
415071b4c2
commit
571c7a11c1
|
@@ -882,7 +882,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
         hypos = []
         kwargs["num_beams"] = num_beams
-        kwargs["num_return_sequences"] = num_return_sequences
+        kwargs["num_return_sequences"] = num_beams
         kwargs["attention_mask"] = None

         for index in range(len(input_ids)):
|
@@ -916,7 +916,8 @@ class RagSequenceForGeneration(RagPreTrainedModel):
             )

-            # bos_token_id is None for T5
-            use_bos = self.config.bos_token_id is not None and target[:, 0].eq(self.config.bos_token_id).all()
+            # bos_token_id is None for T5, so fall back to the generator's bos_token_id
+            bos_token_id = self.config.bos_token_id or self.config.generator.bos_token_id
+            use_bos = bos_token_id is not None and target[:, 0].eq(bos_token_id).all()

             def _mask_pads(ll, smooth_obj):
                 pad_mask = target.eq(self.config.generator.pad_token_id)
|
Loading…
Reference in New Issue