From 571c7a11c17bd00ba3e79f4d853cc51428a14e45 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Fri, 25 Sep 2020 14:35:49 +0200
Subject: [PATCH] [Rag] Fix wrong usage of `num_beams` and `bos_token_id` in
 Rag Sequence generation (#7386)

* fix_rag_sequence

* add second bug fix
---
 src/transformers/modeling_rag.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/transformers/modeling_rag.py b/src/transformers/modeling_rag.py
index c8c3e7eefe..4dc9046069 100644
--- a/src/transformers/modeling_rag.py
+++ b/src/transformers/modeling_rag.py
@@ -882,7 +882,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
 
         hypos = []
         kwargs["num_beams"] = num_beams
-        kwargs["num_return_sequences"] = num_return_sequences
+        kwargs["num_return_sequences"] = num_beams
         kwargs["attention_mask"] = None
 
         for index in range(len(input_ids)):
@@ -916,7 +916,8 @@ class RagSequenceForGeneration(RagPreTrainedModel):
         )
 
         # bos_token_id is None for T5
-        use_bos = self.config.bos_token_id is not None and target[:, 0].eq(self.config.bos_token_id).all()
+        bos_token_id = self.config.bos_token_id or self.config.generator.bos_token_id
+        use_bos = bos_token_id is not None and target[:, 0].eq(bos_token_id).all()
 
         def _mask_pads(ll, smooth_obj):
             pad_mask = target.eq(self.config.generator.pad_token_id)
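
The sketch below (not part of the patch) illustrates the two fixes in isolation. The SimpleNamespace configs are hypothetical stand-ins for RagConfig and its nested generator config, used only so the snippet runs without loading a model.

    # Standalone illustration of the two bug fixes above; the config
    # objects are stand-ins, not real RagConfig instances.
    from types import SimpleNamespace

    import torch

    # Fix 1: RagSequenceForGeneration.generate() rescores every beam
    # hypothesis per retrieved document, so the generator must hand back
    # all `num_beams` candidates; `num_return_sequences` is applied
    # later, after document-level rescoring.
    num_beams = 4
    kwargs = {}
    kwargs["num_beams"] = num_beams
    kwargs["num_return_sequences"] = num_beams  # was: num_return_sequences

    # Fix 2: the top-level RAG config may have no bos_token_id even when
    # the underlying generator (e.g. BART) defines one, so fall back to
    # the generator's config before checking for a leading BOS token.
    config = SimpleNamespace(
        bos_token_id=None,  # unset on the RAG config itself
        generator=SimpleNamespace(bos_token_id=0),  # BART-style generator
    )
    target = torch.tensor([[0, 5, 6], [0, 7, 8]])

    bos_token_id = config.bos_token_id or config.generator.bos_token_id
    use_bos = bos_token_id is not None and target[:, 0].eq(bos_token_id).all()
    print(use_bos)  # tensor(True): every target sequence starts with BOS

Before the patch, the second case evaluated `self.config.bos_token_id is not None` directly, which is False for a RAG config that never sets the field, so BOS scores could be wrongly excluded even for a BART generator.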