Pass encoder outputs into GenerationMixin (#10599)

* Pass encoder_outputs into generate()

* Remove an if-statement

* Reformat

* Minimize changes to generate()

* Comment on input_ids
ymfa 2021-03-12 16:13:11 +00:00 committed by GitHub
parent 00cad2e5c1
commit fa35cda91e
1 changed file with 16 additions and 8 deletions

@@ -376,7 +376,14 @@ class GenerationMixin:
         """
         return logits

-    def _prepare_input_ids_for_generation(self, bos_token_id: int) -> torch.LongTensor:
+    def _prepare_input_ids_for_generation(
+        self, bos_token_id: Optional[int], encoder_outputs: Optional[ModelOutput]
+    ) -> torch.LongTensor:
+        if self.config.is_encoder_decoder and encoder_outputs is not None:
+            # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
+            shape = encoder_outputs.last_hidden_state.size()[:-1]
+            return torch.ones(shape, dtype=torch.long, device=self.device) * -100
+
         if bos_token_id is None:
             raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
         return torch.ones((1, 1), dtype=torch.long, device=self.device) * bos_token_id
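
For intuition, here is a minimal standalone sketch of what the new branch produces (tensor sizes are illustrative; `last_hidden_state` stands in for `encoder_outputs.last_hidden_state`). The dummy `input_ids` inherit the encoder output's (batch, sequence) shape and are filled with -100, so any code path that accidentally tries to embed them fails loudly instead of silently re-encoding:

import torch

# illustrative encoder output: batch of 2, source length 7, hidden size 16
last_hidden_state = torch.randn(2, 7, 16)

# mirror of the new branch: keep (batch, sequence), drop the hidden dim
shape = last_hidden_state.size()[:-1]                     # torch.Size([2, 7])
dummy_input_ids = torch.ones(shape, dtype=torch.long) * -100

print(dummy_input_ids.shape)   # torch.Size([2, 7])
print(dummy_input_ids[0, :3])  # tensor([-100, -100, -100])
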
@@ -395,12 +402,13 @@ class GenerationMixin:
     def _prepare_encoder_decoder_kwargs_for_generation(
         self, input_ids: torch.LongTensor, model_kwargs
     ) -> Dict[str, Any]:
-        # retrieve encoder hidden states
-        encoder = self.get_encoder()
-        encoder_kwargs = {
-            argument: value for argument, value in model_kwargs.items() if not argument.startswith("decoder_")
-        }
-        model_kwargs["encoder_outputs"]: ModelOutput = encoder(input_ids, return_dict=True, **encoder_kwargs)
+        if "encoder_outputs" not in model_kwargs:
+            # retrieve encoder hidden states
+            encoder = self.get_encoder()
+            encoder_kwargs = {
+                argument: value for argument, value in model_kwargs.items() if not argument.startswith("decoder_")
+            }
+            model_kwargs["encoder_outputs"]: ModelOutput = encoder(input_ids, return_dict=True, **encoder_kwargs)
         return model_kwargs

     def _prepare_decoder_input_ids_for_generation(
@@ -887,7 +895,7 @@ class GenerationMixin:

         if input_ids is None:
             # init `input_ids` with bos_token_id
-            input_ids = self._prepare_input_ids_for_generation(bos_token_id)
+            input_ids = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))

         if model_kwargs.get("attention_mask", None) is None:
             # init `attention_mask` depending on `pad_token_id`
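
Taken together, the three hunks let a caller run the encoder once and reuse its output across decoding runs: when `encoder_outputs` is already in `model_kwargs`, `generate()` skips the encoder forward pass and fabricates the dummy `input_ids` itself. A minimal usage sketch (the checkpoint name and inputs are illustrative, assuming a BART-style encoder-decoder model):

from transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

inputs = tokenizer("Studies have shown that owning a dog is good for you.", return_tensors="pt")

# run the encoder exactly once ...
encoder_outputs = model.get_encoder()(
    inputs.input_ids, attention_mask=inputs.attention_mask, return_dict=True
)

# ... then decode from it repeatedly; no input_ids are needed, and the
# encoder is not invoked again
greedy = model.generate(encoder_outputs=encoder_outputs, attention_mask=inputs.attention_mask)
beam = model.generate(encoder_outputs=encoder_outputs, attention_mask=inputs.attention_mask, num_beams=4)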