From 79bbcc5260c3acde3e7156966ba836afcbfd8808 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Fri, 8 Jan 2021 11:50:39 +0100
Subject: [PATCH] [Generation] Fix bug for manual decoder_input_ids + warning
 message (#9472)

* up

* improve style
---
 src/transformers/generation_utils.py | 22 ++++++++++++++--------
 1 file changed, 14 insertions(+), 8 deletions(-)

diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py
index 8147c103dd..6fcf9e5bab 100644
--- a/src/transformers/generation_utils.py
+++ b/src/transformers/generation_utils.py
@@ -379,12 +379,8 @@ class GenerationMixin:
         return model_kwargs
 
     def _prepare_decoder_input_ids_for_generation(
-        self, input_ids: torch.LongTensor, decoder_start_token_id: int = None, bos_token_id: int = None, **model_kwargs
+        self, input_ids: torch.LongTensor, decoder_start_token_id: int = None, bos_token_id: int = None
     ) -> torch.LongTensor:
-
-        if "decoder_input_ids" in model_kwargs:
-            return model_kwargs["decoder_input_ids"]
-
         decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
         decoder_input_ids = (
             torch.ones((input_ids.shape[0], 1), dtype=input_ids.dtype, device=input_ids.device)
@@ -837,13 +833,23 @@ class GenerationMixin:
             model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)
 
             # set input_ids as decoder_input_ids
-            input_ids = self._prepare_decoder_input_ids_for_generation(
-                input_ids, decoder_start_token_id=decoder_start_token_id, bos_token_id=bos_token_id, **model_kwargs
-            )
+            if "decoder_input_ids" in model_kwargs:
+                input_ids = model_kwargs.pop("decoder_input_ids")
+            else:
+                input_ids = self._prepare_decoder_input_ids_for_generation(
+                    input_ids, decoder_start_token_id=decoder_start_token_id, bos_token_id=bos_token_id
+                )
 
             if "encoder_outputs" not in model_kwargs or not isinstance(model_kwargs["encoder_outputs"], ModelOutput):
                 raise ValueError("Make sure that `model_kwargs` include `encoder_outputs` of type `ModelOutput`.")
 
+        if input_ids.shape[-1] >= max_length:
+            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
+            logger.warning(
+                f"Input length of {input_ids_string} is {input_ids.shape[-1]}, but ``max_length`` is set to {max_length}."
+                "This can lead to unexpected behavior. You should consider increasing ``config.max_length`` or ``max_length``."
+            )
+
         # determine generation mode
         is_greedy_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is False
         is_sample_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is True
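
Note (not part of the patch): a minimal usage sketch of the code path this fix touches, where
`decoder_input_ids` is passed manually to `generate()` on an encoder-decoder model. The
checkpoint name "t5-small", the prompt text, and the `max_length` value are illustrative
assumptions, not taken from the patch.

    import torch
    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

    # Any seq2seq checkpoint works here; "t5-small" is only an example.
    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

    input_ids = tokenizer(
        "translate English to German: Hello, how are you?", return_tensors="pt"
    ).input_ids

    # Manually provided decoder prompt. With this fix, `generate()` pops it from
    # `model_kwargs` and uses it as the starting decoder `input_ids` instead of
    # re-deriving them from `decoder_start_token_id`.
    decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])

    # If the provided prompt were already >= max_length, the new warning added
    # by this patch would be emitted.
    outputs = model.generate(input_ids, decoder_input_ids=decoder_input_ids, max_length=40)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))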