[Generation] Fix bug for manual decoder_input_ids + warning message (#9472)
* up

* improve style
This commit is contained in:
parent 9e1ea846bc
commit 79bbcc5260
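With this fix, a manually supplied `decoder_input_ids` passed to `generate()` is respected for encoder-decoder models instead of being silently rebuilt from the decoder start token. A minimal usage sketch, assuming a BART checkpoint (the model name and prompt are illustrative, not part of this commit):

    import torch
    from transformers import BartForConditionalGeneration, BartTokenizer

    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
    model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")

    input_ids = tokenizer("UN Chief says there is no way to stop the war", return_tensors="pt").input_ids
    # Seed the decoder manually; generate() now pops this tensor from
    # model_kwargs and uses it as the decoder prompt instead of discarding it.
    decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])
    output_ids = model.generate(input_ids, decoder_input_ids=decoder_input_ids)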
@@ -379,12 +379,8 @@ class GenerationMixin:
         return model_kwargs
 
     def _prepare_decoder_input_ids_for_generation(
-        self, input_ids: torch.LongTensor, decoder_start_token_id: int = None, bos_token_id: int = None, **model_kwargs
+        self, input_ids: torch.LongTensor, decoder_start_token_id: int = None, bos_token_id: int = None
     ) -> torch.LongTensor:
-
-        if "decoder_input_ids" in model_kwargs:
-            return model_kwargs["decoder_input_ids"]
-
         decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
         decoder_input_ids = (
             torch.ones((input_ids.shape[0], 1), dtype=input_ids.dtype, device=input_ids.device)
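The helper above no longer accepts `**model_kwargs`; the check for a user-supplied `decoder_input_ids` moves into `generate()` itself (second hunk below). A sketch of the resulting method, assuming the tail beyond the hunk multiplies the ones tensor by the start token id as the opening parenthesis suggests:

    def _prepare_decoder_input_ids_for_generation(
        self, input_ids: torch.LongTensor, decoder_start_token_id: int = None, bos_token_id: int = None
    ) -> torch.LongTensor:
        decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
        # One decoder start token per batch element, matching the encoder
        # input's dtype and device.
        return (
            torch.ones((input_ids.shape[0], 1), dtype=input_ids.dtype, device=input_ids.device)
            * decoder_start_token_id
        )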
@@ -837,13 +833,23 @@ class GenerationMixin:
             model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)
 
             # set input_ids as decoder_input_ids
-            input_ids = self._prepare_decoder_input_ids_for_generation(
-                input_ids, decoder_start_token_id=decoder_start_token_id, bos_token_id=bos_token_id, **model_kwargs
-            )
+            if "decoder_input_ids" in model_kwargs:
+                input_ids = model_kwargs.pop("decoder_input_ids")
+            else:
+                input_ids = self._prepare_decoder_input_ids_for_generation(
+                    input_ids, decoder_start_token_id=decoder_start_token_id, bos_token_id=bos_token_id
+                )
 
             if "encoder_outputs" not in model_kwargs or not isinstance(model_kwargs["encoder_outputs"], ModelOutput):
                 raise ValueError("Make sure that `model_kwargs` include `encoder_outputs` of type `ModelOutput`.")
 
+        if input_ids.shape[-1] >= max_length:
+            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
+            logger.warning(
+                f"Input length of {input_ids_string} is {input_ids.shape[-1]}, but ``max_length`` is set to {max_length}."
+                "This can lead to unexpected behavior. You should consider increasing ``config.max_length`` or ``max_length``."
+            )
+
         # determine generation mode
         is_greedy_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is False
         is_sample_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is True
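The second hunk also adds the promised warning: if the (decoder) prompt is already at least `max_length` tokens long, generation cannot add anything new. A hedged illustration of when it fires (the lengths are made up):

    # Suppose input_ids already holds 20 tokens and max_length is also 20:
    output_ids = model.generate(input_ids, max_length=20)
    # logs roughly: "Input length of input_ids is 20, but ``max_length`` is
    # set to 20. This can lead to unexpected behavior. ..."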