Pass encoder outputs into GenerationMixin (#10599)
* Pass encoder_outputs into generate()
* Remove an if-statement
* Reformat
* Minimize changes to generate()
* Comment on input_ids
This commit is contained in:
parent
00cad2e5c1
commit
fa35cda91e
|
@ -376,7 +376,14 @@ class GenerationMixin:
|
|||
"""
|
||||
return logits
|
||||
|
||||
def _prepare_input_ids_for_generation(self, bos_token_id: int) -> torch.LongTensor:
|
||||
def _prepare_input_ids_for_generation(
|
||||
self, bos_token_id: Optional[int], encoder_outputs: Optional[ModelOutput]
|
||||
) -> torch.LongTensor:
|
||||
if self.config.is_encoder_decoder and encoder_outputs is not None:
|
||||
# make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
|
||||
shape = encoder_outputs.last_hidden_state.size()[:-1]
|
||||
return torch.ones(shape, dtype=torch.long, device=self.device) * -100
|
||||
|
||||
if bos_token_id is None:
|
||||
raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
|
||||
return torch.ones((1, 1), dtype=torch.long, device=self.device) * bos_token_id
|
||||
|
@ -395,12 +402,13 @@ class GenerationMixin:
|
|||
def _prepare_encoder_decoder_kwargs_for_generation(
|
||||
self, input_ids: torch.LongTensor, model_kwargs
|
||||
) -> Dict[str, Any]:
|
||||
# retrieve encoder hidden states
|
||||
encoder = self.get_encoder()
|
||||
encoder_kwargs = {
|
||||
argument: value for argument, value in model_kwargs.items() if not argument.startswith("decoder_")
|
||||
}
|
||||
model_kwargs["encoder_outputs"]: ModelOutput = encoder(input_ids, return_dict=True, **encoder_kwargs)
|
||||
if "encoder_outputs" not in model_kwargs:
|
||||
# retrieve encoder hidden states
|
||||
encoder = self.get_encoder()
|
||||
encoder_kwargs = {
|
||||
argument: value for argument, value in model_kwargs.items() if not argument.startswith("decoder_")
|
||||
}
|
||||
model_kwargs["encoder_outputs"]: ModelOutput = encoder(input_ids, return_dict=True, **encoder_kwargs)
|
||||
return model_kwargs
|
||||
|
||||
def _prepare_decoder_input_ids_for_generation(
|
||||
|
@ -887,7 +895,7 @@ class GenerationMixin:
|
|||
|
||||
if input_ids is None:
|
||||
# init `input_ids` with bos_token_id
|
||||
input_ids = self._prepare_input_ids_for_generation(bos_token_id)
|
||||
input_ids = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))
|
||||
|
||||
if model_kwargs.get("attention_mask", None) is None:
|
||||
# init `attention_mask` depending on `pad_token_id`
|
||||
|
|
Loading…
Reference in New Issue