Replace `past` with `past_key_values` (#20944)

* start cleanup

* more updates

* more models are affected

* more updates

* update generation utils

* style

* revert change that removed reorder cache

* update generation utils

* style

* style

* remove reorder cache
Arthur, 2023-01-08 10:21:40 +01:00, committed by GitHub
parent 7cb596fa22
commit f0577df6de
84 changed files with 479 additions and 424 deletions
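
Every hunk below follows the same convention: `generate()` now stores the cache under `model_kwargs["past_key_values"]`, so each `prepare_inputs_for_generation` override has to accept that keyword instead of `past`. A minimal runnable sketch of the decoder-only pattern the diff repeats (a standalone illustrative function, not taken verbatim from any one model):

import torch

def prepare_inputs_for_generation(input_ids, past_key_values=None, attention_mask=None, **kwargs):
    if attention_mask is None:
        attention_mask = input_ids.new_ones(input_ids.shape)
    # with a cache, only the newest token has to be fed through the model
    if past_key_values is not None:
        input_ids = input_ids[:, -1:]
    return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}

print(prepare_inputs_for_generation(torch.tensor([[1, 2, 3]])))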

@ -1074,20 +1074,20 @@ class TFGenerationMixin:
@staticmethod
def _extract_past_from_model_output(outputs: ModelOutput):
past = None
past_key_values = None
if "past_key_values" in outputs:
past = outputs.past_key_values
past_key_values = outputs.past_key_values
elif "mems" in outputs:
past = outputs.mems
past_key_values = outputs.mems
elif "past_buckets_states" in outputs:
past = outputs.past_buckets_states
return past
past_key_values = outputs.past_buckets_states
return past_key_values
def _update_model_kwargs_for_generation(
self, outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False
) -> Dict[str, Any]:
# update past
model_kwargs["past"] = self._extract_past_from_model_output(outputs)
# update past_key_values
model_kwargs["past_key_values"] = self._extract_past_from_model_output(outputs)
# update attention mask
if not is_encoder_decoder:
@ -1112,7 +1112,7 @@ class TFGenerationMixin:
def _initialize_attention(model_kwargs, num_padding_values, is_encoder_decoder):
"""initializes the appropriate attention mask -- encoder-decoder models use `decoder_attention_mask`"""
if is_encoder_decoder:
# One 1 for decoder_start_token_id, 0s for the currently-unfilled locations in the past tensor,
# One 1 for decoder_start_token_id, 0s for the currently-unfilled locations in the past_key_values tensor,
# 1s for the actual input_ids
decoder_attention_mask = tf.concat(
[
@ -1125,7 +1125,7 @@ class TFGenerationMixin:
mask = {"decoder_attention_mask": decoder_attention_mask}
else:
attention_mask = model_kwargs.pop("attention_mask")
# 0s for the currently-unfilled locations in the past tensor, 1s for the actual input_ids
# 0s for the currently-unfilled locations in the past_key_values tensor, 1s for the actual input_ids
attention_mask = tf.concat(
[
attention_mask,
@ -1154,32 +1154,32 @@ class TFGenerationMixin:
mask = {"attention_mask": attention_mask}
return mask
def _initialize_past(past, num_padding_values, batch_axis):
"""initialize past with zeros -- the structure depends on `batch_axis`"""
def _initialize_past(past_key_values, num_padding_values, batch_axis):
"""initialize past_key_values with zeros -- the structure depends on `batch_axis`"""
if batch_axis == 0:
padding_values = tf.constant([[0, 0], [0, 0], [0, num_padding_values], [0, 0]], dtype=tf.int32)
new_past = ()
for past_layer in past:
for past_layer in past_key_values:
new_past_layer = list(past_layer)
for i in range(len(new_past_layer[:2])):
new_past_layer[i] = tf.pad(past_layer[i], padding_values)
new_past += (tuple(new_past_layer),)
else:
padding_values = tf.scatter_nd(indices=[[3, 1]], updates=[num_padding_values], shape=(5, 2))
new_past = list(past)
for i in range(len(past)):
new_past[i] = tf.pad(past[i], padding_values)
new_past = list(past_key_values)
for i in range(len(past_key_values)):
new_past[i] = tf.pad(past_key_values[i], padding_values)
return new_past
def _update_past(past, new_past_index, batch_axis):
def _update_past(past_key_values, new_past_index, batch_axis):
if batch_axis == 0:
slice_start_base = tf.constant([0, 0, 1, 0])
new_past = ()
for past_layer in past:
for past_layer in past_key_values:
new_past_layer = list(past_layer)
for i in range(len(new_past_layer[:2])):
update_slice = past_layer[i][:, :, -1:]
# Write the last slice to the first open location in the padded past array
# Write the last slice to the first open location in the padded past_key_values array
# and then truncate the last slice off the array
new_past_layer[i] = dynamic_update_slice(
past_layer[i][:, :, :-1], update_slice, slice_start_base * new_past_index
@ -1187,41 +1187,42 @@ class TFGenerationMixin:
new_past += (tuple(new_past_layer),)
else:
slice_start_base = tf.constant([0, 0, 0, 1, 0])
new_past = [None for _ in range(len(past))]
for i in range(len(past)):
update_slice = past[i][:, :, :, -1:]
# Write the last slice to the first open location in the padded past array
new_past = [None for _ in range(len(past_key_values))]
for i in range(len(past_key_values)):
update_slice = past_key_values[i][:, :, :, -1:]
# Write the last slice to the first open location in the padded past_key_values array
# and then truncate the last slice off the array
new_past[i] = dynamic_update_slice(
past[i][:, :, :, :-1], update_slice, slice_start_base * new_past_index
past_key_values[i][:, :, :, :-1], update_slice, slice_start_base * new_past_index
)
return new_past
past = self._extract_past_from_model_output(model_outputs)
if past is None:
past_key_values = self._extract_past_from_model_output(model_outputs)
if past_key_values is None:
raise ValueError(
f"No known past variable found in model outputs (model outputs keys: {list(model_outputs.keys())})"
"No known `past_key_values` variable found in model outputs (model outputs keys:"
f" {list(model_outputs.keys())})"
)
is_past_initialized = model_kwargs.pop("past", None) is not None
is_past_initialized = model_kwargs.pop("past_key_values", None) is not None
if not is_past_initialized:
# The padded version of `past` has a length of `max_length - 1`, as `past` holds information relative to
# previous autoregressive generation steps (step 0 has no past, step 1 has 1 past value, ..., the last step
# has `max_length - 1` past values).
# The padded version of `past_key_values` has a length of `max_length - 1`, as `past_key_values` holds
# information relative to previous autoregressive generation steps (step 0 has no cache, step 1 has 1 cached
# value, ..., the last step has `max_length - 1` cached values).
num_padding_values = max_length - cur_len - 1
mask = _initialize_attention(model_kwargs, num_padding_values, is_encoder_decoder)
new_past = _initialize_past(past, num_padding_values, batch_axis)
new_past = _initialize_past(past_key_values, num_padding_values, batch_axis)
else:
# The new index of past to be filled corresponds to the current length of the sequence, with two
# subtractions: -1 because past holds information regarding previous generation steps (read comment above)
# The new index of past_key_values to be filled corresponds to the current length of the sequence, with two
# subtractions: -1 because past_key_values holds information regarding previous generation steps (read comment above)
# and -1 again because in an array the index is the length of the array minus 1.
new_past_index = cur_len - 2
mask = _update_attention(model_kwargs, new_past_index, is_encoder_decoder)
new_past = _update_past(past, new_past_index, batch_axis)
new_past = _update_past(past_key_values, new_past_index, batch_axis)
# sets the updated variables (mask and past)
# sets the updated variables (mask and past_key_values)
model_kwargs.update(mask)
model_kwargs["past"] = tuple(new_past)
model_kwargs["past_key_values"] = tuple(new_past)
return model_kwargs
@ -1403,7 +1404,7 @@ class TFGenerationMixin:
# GPT2 and other models have a slightly different cache structure, with a different batch axis
model_name = str(self.decoder) if "EncoderDecoder" in str(self) else str(self)
cache_batch_axis = 1 if any([model_prefix in model_name for model_prefix in ("TFGPT2", "TFCTRL")]) else 0
# some models, like XLNet, need more than the last token in the presence of past
# some models, like XLNet, need more than the last token in the presence of past_key_values
needs_full_input = "use_mems" in set(inspect.signature(self.prepare_inputs_for_generation).parameters.keys())
# 2. init `attentions`, `hidden_states`, and `scores` tuples
@ -1429,7 +1430,7 @@ class TFGenerationMixin:
# define condition fn
def greedy_search_body_fn(generated, finished_sequences, cur_len, model_kwargs):
"""state update fn."""
if model_kwargs.get("past") is None or needs_full_input:
if model_kwargs.get("past_key_values") is None or needs_full_input:
input_ids = generated[:, :cur_len]
else:
input_ids = tf.expand_dims(generated[:, cur_len - 1], -1)
@ -1492,15 +1493,15 @@ class TFGenerationMixin:
model_kwargs = self._update_model_kwargs_for_generation(
model_outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)
# if we don't cache past key values we need the whole input
if model_kwargs.get("past", None) is None:
# let's throw out `past` since we don't want `None` tensors
model_kwargs.pop("past", None)
# if we don't cache `past_key_values` we need the whole input
if model_kwargs.get("past_key_values", None) is None:
# let's throw out `past_key_values` since we don't want `None` tensors
model_kwargs.pop("past_key_values", None)
return generated, finished_sequences, cur_len, model_kwargs
# 5. run generation
# 1st generation step has to be run before to initialize `past`
# 1st generation step has to be run before to initialize `past_key_values`
generated, finished_sequences, cur_len, model_kwargs = greedy_search_body_fn(
generated, finished_sequences, cur_len, model_kwargs
)
@ -1680,7 +1681,7 @@ class TFGenerationMixin:
# GPT2 and other models have a slightly different cache structure, with a different batch axis
model_name = str(self.decoder) if "EncoderDecoder" in str(self) else str(self)
cache_batch_axis = 1 if any([model_prefix in model_name for model_prefix in ("TFGPT2", "TFCTRL")]) else 0
# some models, like XLNet, need more than the last token in the presence of past
# some models, like XLNet, need more than the last token in the presence of past_key_values
needs_full_input = "use_mems" in set(inspect.signature(self.prepare_inputs_for_generation).parameters.keys())
# 2. init `attentions`, `hidden_states`, and `scores` tuples
@ -1702,7 +1703,7 @@ class TFGenerationMixin:
return ~tf.reduce_all(finished_sequences)
def sample_body_fn(generated, finished_sequences, cur_len, model_kwargs):
if model_kwargs.get("past") is None or needs_full_input:
if model_kwargs.get("past_key_values") is None or needs_full_input:
input_ids = generated[:, :cur_len]
else:
input_ids = tf.expand_dims(generated[:, cur_len - 1], -1)
@ -1775,15 +1776,15 @@ class TFGenerationMixin:
model_kwargs = self._update_model_kwargs_for_generation(
model_outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)
# if we don't cache past key values we need the whole input
if model_kwargs.get("past", None) is None:
# let's throw out `past` since we don't want `None` tensors
model_kwargs.pop("past", None)
# if we don't cache `past_key_values` we need the whole input
if model_kwargs.get("past_key_values", None) is None:
# let's throw out `past_key_values` since we don't want `None` tensors
model_kwargs.pop("past_key_values", None)
return generated, finished_sequences, cur_len, model_kwargs
# 5. run generation
# 1st generation step has to be run before to initialize `past`
# 1st generation step has to be run before to initialize `past_key_values`
generated, finished_sequences, cur_len, model_kwargs = sample_body_fn(
generated, finished_sequences, cur_len, model_kwargs
)
@ -2012,7 +2013,7 @@ class TFGenerationMixin:
# GPT2 and other models have a slightly different cache structure, with a different batch axis
model_name = str(self.decoder) if "EncoderDecoder" in str(self) else str(self)
cache_batch_axis = 1 if any([model_prefix in model_name for model_prefix in ("TFGPT2", "TFCTRL")]) else 0
# some models, like XLNet, need more than the last token in the presence of past
# some models, like XLNet, need more than the last token in the presence of past_key_values
needs_full_input = "use_mems" in set(inspect.signature(self.prepare_inputs_for_generation).parameters.keys())
# 2. init `attentions`, `hidden_states`, and `scores` tuples
@ -2092,7 +2093,7 @@ class TFGenerationMixin:
seen so far
"""
# 1. Forward current tokens
if model_kwargs.get("past") is None or needs_full_input:
if model_kwargs.get("past_key_values") is None or needs_full_input:
input_ids = running_sequences[:, :, :cur_len]
else:
input_ids = tf.expand_dims(running_sequences[:, :, cur_len - 1], -1)
@ -2248,10 +2249,10 @@ class TFGenerationMixin:
model_outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)
# if we don't cache past key values we need the whole input
if model_kwargs.get("past", None) is None:
# let's throw out `past` since we don't want `None` tensors
model_kwargs.pop("past", None)
# if we don't cache `past_key_values` we need the whole input
if model_kwargs.get("past_key_values", None) is None:
# let's throw out `past_key_values` since we don't want `None` tensors
model_kwargs.pop("past_key_values", None)
return (
cur_len,
@ -2264,7 +2265,7 @@ class TFGenerationMixin:
)
# 5. run generation
# 1st generation step has to be run before to initialize `past` (if active)
# 1st generation step has to be run before to initialize `past_key_values` (if active)
(
cur_len,
running_sequences,
@ -2472,7 +2473,7 @@ class TFGenerationMixin:
# if the first step in the loop, encode all the prefix and obtain: (1) past_key_values;
# (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step
if model_kwargs.get("past") is None:
if model_kwargs.get("past_key_values") is None:
# prepare inputs
model_inputs = self.prepare_inputs_for_generation(
@ -2520,13 +2521,16 @@ class TFGenerationMixin:
expand_size=top_k, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
)
past = model_kwargs.get("past")
if past is None:
past_key_values = model_kwargs.get("past_key_values")
if past_key_values is None:
raise ValueError(
f"{self.__class__.__name__} does not support caching and therefore **can't** be used "
"for contrastive search."
)
elif not isinstance(past[0], (tuple, tf.Tensor)) or past[0][0].shape[0] != batch_size:
elif (
not isinstance(past_key_values[0], (tuple, tf.Tensor))
or past_key_values[0][0].shape[0] != batch_size
):
raise ValueError(
f"{self.__class__.__name__} does not have a standard cache format and therefore **can't** be "
"used for contrastive search without further modifications."
@ -2562,8 +2566,8 @@ class TFGenerationMixin:
decoder_hidden_states.append(outputs.hidden_states)
# Replicates the new past_key_values to match the `top_k` candidates
model_kwargs["past"] = tf.nest.map_structure(
lambda tensor: tf.repeat(tensor, top_k, axis=cache_batch_axis), model_kwargs["past"]
model_kwargs["past_key_values"] = tf.nest.map_structure(
lambda tensor: tf.repeat(tensor, top_k, axis=cache_batch_axis), model_kwargs["past_key_values"]
)
# compute the candidate tokens by the language model and collects their hidden_states
@ -2676,7 +2680,7 @@ class TFGenerationMixin:
return generated, finished_sequences, cur_len, model_kwargs, next_step_cached_variables
# 5. run generation
# 1st generation step has to be run before to initialize `past`
# 1st generation step has to be run before to initialize `past_key_values`
generated, finished_sequences, cur_len, model_kwargs, next_step_cached_variables = contrastive_search_body_fn(
generated, finished_sequences, cur_len, model_kwargs, None
)
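
The two XLA helpers above keep the cache shape static: `_initialize_past` zero-pads the sequence axis of `past_key_values` up to `max_length - 1`, and `_update_past` writes each new key/value slice into the first open slot. A toy sketch for `batch_axis == 0` (illustrative shapes; the `dynamic_update_slice` import path is an assumption based on TensorFlow's XLA Python bindings):

import tensorflow as tf
from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice  # assumed import path

key = tf.random.normal((1, 2, 3, 4))  # (batch, heads, cur_len, head_dim) with cur_len == 3
num_padding_values = 5                # max_length - cur_len - 1, for max_length == 9

# _initialize_past: zero-pad the sequence axis (axis 2)
padding_values = tf.constant([[0, 0], [0, 0], [0, num_padding_values], [0, 0]], dtype=tf.int32)
padded_key = tf.pad(key, padding_values)  # -> shape (1, 2, 8, 4)

# _update_past: write the newest slice into the first open location, truncating the last slice
new_past_index = 3  # cur_len - 2 on a later step
update_slice = padded_key[:, :, -1:]
slice_start_base = tf.constant([0, 0, 1, 0])
updated_key = dynamic_update_slice(padded_key[:, :, :-1], update_slice, slice_start_base * new_past_index)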

@ -675,19 +675,19 @@ class GenerationMixin:
return input_ids, model_kwargs
def _extract_past_from_model_output(self, outputs: ModelOutput, standardize_cache_format: bool = False):
past = None
past_key_values = None
if "past_key_values" in outputs:
past = outputs.past_key_values
past_key_values = outputs.past_key_values
elif "mems" in outputs:
past = outputs.mems
past_key_values = outputs.mems
elif "past_buckets_states" in outputs:
past = outputs.past_buckets_states
past_key_values = outputs.past_buckets_states
# Bloom fix: standardizes the cache format when requested
if standardize_cache_format and hasattr(self, "_convert_to_standard_cache"):
batch_size = outputs.logits.shape[0]
past = self._convert_to_standard_cache(past, batch_size=batch_size)
return past
past_key_values = self._convert_to_standard_cache(past_key_values, batch_size=batch_size)
return past_key_values
def _update_model_kwargs_for_generation(
self,
@ -696,8 +696,8 @@ class GenerationMixin:
is_encoder_decoder: bool = False,
standardize_cache_format: bool = False,
) -> Dict[str, Any]:
# update past
model_kwargs["past"] = self._extract_past_from_model_output(
# update past_key_values
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
outputs, standardize_cache_format=standardize_cache_format
)
@ -1758,7 +1758,7 @@ class GenerationMixin:
# if the first step in the loop, encode all the prefix and obtain: (1) past_key_values;
# (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step
if model_kwargs.get("past") is None:
if model_kwargs.get("past_key_values") is None:
# prepare inputs
model_kwargs["use_cache"] = True
@ -1791,13 +1791,16 @@ class GenerationMixin:
expand_size=top_k, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
)
past = model_kwargs.get("past")
if past is None:
past_key_values = model_kwargs.get("past_key_values")
if past_key_values is None:
raise ValueError(
f"{self.__class__.__name__} does not support caching and therefore **can't** be used "
"for contrastive search."
)
elif not isinstance(past[0], (tuple, torch.Tensor)) or past[0][0].shape[0] != batch_size:
elif (
not isinstance(past_key_values[0], (tuple, torch.Tensor))
or past_key_values[0][0].shape[0] != batch_size
):
raise ValueError(
f"{self.__class__.__name__} does not have a standard cache format and therefore **can't** be "
"used for contrastive search without further modifications."
@ -1832,13 +1835,13 @@ class GenerationMixin:
# Replicates the new past_key_values to match the `top_k` candidates
new_key_values = []
for layer in model_kwargs["past"]:
for layer in model_kwargs["past_key_values"]:
items = []
# item is either the key or the value matrix
for item in layer:
items.append(item.repeat_interleave(top_k, dim=0))
new_key_values.append(items)
model_kwargs["past"] = new_key_values
model_kwargs["past_key_values"] = new_key_values
# compute the candidate tokens by the language model and collects their hidden_states
next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs)
@ -2718,8 +2721,8 @@ class GenerationMixin:
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)
if model_kwargs["past"] is not None:
model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)
if model_kwargs["past_key_values"] is not None:
model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)
if return_dict_in_generate and output_scores:
beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))
@ -3040,8 +3043,8 @@ class GenerationMixin:
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)
if model_kwargs["past"] is not None:
model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)
if model_kwargs["past_key_values"] is not None:
model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)
if return_dict_in_generate and output_scores:
beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))
@ -3410,8 +3413,10 @@ class GenerationMixin:
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)
if model_kwargs["past"] is not None:
model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], reordering_indices)
if model_kwargs["past_key_values"] is not None:
model_kwargs["past_key_values"] = self._reorder_cache(
model_kwargs["past_key_values"], reordering_indices
)
# increase cur_len
cur_len = cur_len + 1
@ -3732,8 +3737,8 @@ class GenerationMixin:
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)
if model_kwargs["past"] is not None:
model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)
if model_kwargs["past_key_values"] is not None:
model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)
# increase cur_len
cur_len = cur_len + 1
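
The beam methods above all funnel through the same `_reorder_cache` contract: after each step, row `i` of every cached key/value tensor is replaced by the row of the beam hypothesis it continues. A minimal sketch with toy tensors (illustrative layer count and shapes), mirroring the `index_select`-based implementations shown in the model files below:

import torch

# two toy layers, each caching (key, value) of shape (num_beams, heads, seq, head_dim)
past_key_values = tuple((torch.randn(4, 2, 5, 8), torch.randn(4, 2, 5, 8)) for _ in range(2))
beam_idx = torch.tensor([1, 1, 0, 3])  # which beam hypothesis each row continues from

reordered_past = tuple(
    tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
    for layer_past in past_key_values
)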

@ -1418,7 +1418,7 @@ class BartForConditionalGeneration(BartPretrainedModel):
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past=None,
past_key_values=None,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
@ -1428,14 +1428,14 @@ class BartForConditionalGeneration(BartPretrainedModel):
encoder_outputs=None,
**kwargs
):
# cut decoder_input_ids if past is used
if past is not None:
# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
@ -1910,18 +1910,20 @@ class BartForCausalLM(BartPretrainedModel):
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
):
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)
if past:
if past_key_values:
input_ids = input_ids[:, -1:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
"attention_mask": attention_mask,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": use_cache,
}
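
Encoder-decoder models apply the same cut to `decoder_input_ids` instead, and pass `input_ids=None` because `encoder_outputs` has already been computed. A hedged standalone sketch of that variant (toy cache and names for illustration only):

import torch

def prepare_inputs_for_generation(decoder_input_ids, past_key_values=None, encoder_outputs=None, attention_mask=None, **kwargs):
    # cut decoder_input_ids when a cache is present
    if past_key_values is not None:
        decoder_input_ids = decoder_input_ids[:, -1:]
    return {
        "input_ids": None,  # encoder_outputs is defined, so input_ids are not needed
        "encoder_outputs": encoder_outputs,
        "past_key_values": past_key_values,
        "decoder_input_ids": decoder_input_ids,
        "attention_mask": attention_mask,
    }

toy_cache = ((torch.zeros(1, 2, 4, 8),) * 4,)  # one toy layer of (k, v, cross-k, cross-v)
assert prepare_inputs_for_generation(torch.tensor([[0, 5, 7]]), past_key_values=toy_cache)["decoder_input_ids"].shape == (1, 1)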

@ -1436,7 +1436,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past=None,
past_key_values=None,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
@ -1447,21 +1447,21 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
**kwargs
):
# cut decoder_input_ids if past is used
if past is not None:
# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
if decoder_attention_mask is not None: # xla
decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
elif past is not None: # no xla + past
decoder_position_ids = past[0][0].shape[2]
else: # no xla + no past
elif past_key_values is not None: # no xla + past_key_values
decoder_position_ids = past_key_values[0][0].shape[2]
else: # no xla + no past_key_values
decoder_position_ids = tf.range(decoder_input_ids.shape[1])
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,

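The TF variant above additionally derives `decoder_position_ids` for the three cache states. A small runnable sketch of the XLA branch (toy mask values):

import tensorflow as tf

decoder_attention_mask = tf.constant([[1, 1, 1, 1, 0, 0]])
# xla: the exclusive cumsum of the mask gives, in its last column, the position index of the next token
decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
print(decoder_position_ids)  # [[4]]
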
@ -1274,20 +1274,22 @@ class BertLMHeadModel(BertPreTrainedModel):
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=True, **model_kwargs):
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, use_cache=True, **model_kwargs
):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
# cut decoder_input_ids if past is used
if past is not None:
# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
input_ids = input_ids[:, -1:]
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": use_cache,
}

@ -1395,17 +1395,17 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = tf.ones(input_shape)
# cut decoder_input_ids if past is used
if past is not None:
if past_key_values is not None:
input_ids = input_ids[:, -1:]
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
@unpack_inputs
@add_code_sample_docstrings(

@ -987,20 +987,20 @@ class BertGenerationDecoder(BertGenerationPreTrainedModel):
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
# cut decoder_input_ids if past is used
if past is not None:
if past_key_values is not None:
input_ids = input_ids[:, -1:]
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
def _reorder_cache(self, past, beam_idx):
def _reorder_cache(self, past_key_values, beam_idx):
reordered_past = ()
for layer_past in past:
for layer_past in past_key_values:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
return reordered_past

@ -2626,7 +2626,7 @@ class BigBirdForCausalLM(BigBirdPreTrainedModel):
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
@ -2634,14 +2634,14 @@ class BigBirdForCausalLM(BigBirdPreTrainedModel):
attention_mask = input_ids.new_ones(input_shape)
# cut decoder_input_ids if past is used
if past is not None:
if past_key_values is not None:
input_ids = input_ids[:, -1:]
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
def _reorder_cache(self, past, beam_idx):
def _reorder_cache(self, past_key_values, beam_idx):
reordered_past = ()
for layer_past in past:
for layer_past in past_key_values:
reordered_past += (
tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
)

@ -2617,7 +2617,7 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past=None,
past_key_values=None,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
@ -2627,14 +2627,14 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
encoder_outputs=None,
**kwargs
):
# cut decoder_input_ids if past is used
if past is not None:
# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
@ -3105,18 +3105,20 @@ class BigBirdPegasusForCausalLM(BigBirdPegasusPreTrainedModel):
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
):
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)
if past:
if past_key_values:
input_ids = input_ids[:, -1:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
"attention_mask": attention_mask,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": use_cache,
}

@ -706,16 +706,16 @@ class BioGptForCausalLM(BioGptPreTrainedModel):
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(self, input_ids, attention_mask, past=None, **kwargs):
def prepare_inputs_for_generation(self, input_ids, attention_mask, past_key_values=None, **kwargs):
# only last token for input_ids if past is defined in kwargs
if past:
if past_key_values:
input_ids = input_ids[:, -1].unsqueeze(-1)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
}

@ -1378,7 +1378,7 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel):
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past=None,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
@ -1388,13 +1388,13 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel):
**kwargs
):
# cut decoder_input_ids if past is used
if past is not None:
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
@ -1609,18 +1609,20 @@ class BlenderbotForCausalLM(BlenderbotPreTrainedModel):
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
):
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)
if past:
if past_key_values:
input_ids = input_ids[:, -1:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
"attention_mask": attention_mask,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": use_cache,
}

@ -1438,7 +1438,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past=None,
past_key_values=None,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
@ -1449,21 +1449,21 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
**kwargs
):
# cut decoder_input_ids if past is used
if past is not None:
# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
if decoder_attention_mask is not None: # xla
decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
elif past is not None: # no xla + past
decoder_position_ids = past[0][0].shape[2]
else: # no xla + no past
elif past_key_values is not None: # no xla + past_key_values
decoder_position_ids = past_key_values[0][0].shape[2]
else: # no xla + no past_key_values
decoder_position_ids = tf.range(decoder_input_ids.shape[1])
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,

@ -1345,7 +1345,7 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel):
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past=None,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
@ -1355,13 +1355,13 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel):
**kwargs
):
# cut decoder_input_ids if past is used
if past is not None:
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
@ -1576,18 +1576,20 @@ class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel):
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
):
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)
if past:
if past_key_values:
input_ids = input_ids[:, -1:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
"attention_mask": attention_mask,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": use_cache,
}

@ -1418,7 +1418,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past=None,
past_key_values=None,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
@ -1429,21 +1429,21 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
**kwargs
):
# cut decoder_input_ids if past is used
if past is not None:
# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
if decoder_attention_mask is not None: # xla
decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
elif past is not None: # no xla + past
decoder_position_ids = past[0][0].shape[2]
else: # no xla + no past
elif past_key_values is not None: # no xla + past_key_values
decoder_position_ids = past_key_values[0][0].shape[2]
else: # no xla + no past_key_values
decoder_position_ids = tf.range(decoder_input_ids.shape[1])
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,

@ -917,20 +917,20 @@ class BlipTextLMHeadModel(BlipTextPreTrainedModel):
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
# cut decoder_input_ids if past is used
if past is not None:
# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
input_ids = input_ids[:, -1:]
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": past,
"past_key_values": past_key_values,
"encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
"encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
"is_decoder": True,

@ -842,21 +842,21 @@ class BloomForCausalLM(BloomPreTrainedModel):
def prepare_inputs_for_generation(
self,
input_ids: torch.LongTensor,
past: Optional[torch.Tensor] = None,
past_key_values: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs
) -> dict:
# only last token for input_ids if past is not None
if past:
if past_key_values:
input_ids = input_ids[:, -1].unsqueeze(-1)
# the cache may be in the standard format (e.g. in contrastive search), convert to bloom's format if needed
if past[0][0].shape[0] == input_ids.shape[0]:
past = self._convert_to_bloom_cache(past)
if past_key_values[0][0].shape[0] == input_ids.shape[0]:
past_key_values = self._convert_to_bloom_cache(past_key_values)
return {
"input_ids": input_ids,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
}
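
The shape test above distinguishes the two cache layouts. A hedged sketch of why it works (the Bloom layout is assumed from the conversion helpers: heads folded into the batch axis, keys stored transposed):

import torch

batch_size, num_heads, seq_len, head_dim = 2, 4, 5, 8
standard_key = torch.randn(batch_size, num_heads, seq_len, head_dim)  # standard layout
bloom_key = standard_key.permute(0, 1, 3, 2).reshape(batch_size * num_heads, head_dim, seq_len)  # assumed Bloom layout

input_ids = torch.zeros(batch_size, 1, dtype=torch.long)
assert standard_key.shape[0] == input_ids.shape[0]  # matches -> `_convert_to_bloom_cache` is needed
assert bloom_key.shape[0] != input_ids.shape[0]     # already in Bloom format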

@ -1559,17 +1559,17 @@ class CamembertForCausalLM(CamembertPreTrainedModel):
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
# cut decoder_input_ids if past is used
if past is not None:
if past_key_values is not None:
input_ids = input_ids[:, -1:]
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
def _reorder_cache(self, past, beam_idx):
reordered_past = ()

@ -1612,17 +1612,17 @@ class TFCamembertForCausalLM(TFCamembertPreTrainedModel, TFCausalLanguageModelin
return self.name + "/" + self.lm_head.name
# Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = tf.ones(input_shape)
# cut decoder_input_ids if past is used
if past is not None:
if past_key_values is not None:
input_ids = input_ids[:, -1:]
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
@unpack_inputs
@add_start_docstrings_to_model_forward(CAMEMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))

@ -632,10 +632,10 @@ class CodeGenForCausalLM(CodeGenPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None)
# only last token for input_ids if past is defined in kwargs
if past:
if past_key_values:
input_ids = input_ids[:, -1].unsqueeze(-1)
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
@ -647,13 +647,13 @@ class CodeGenForCausalLM(CodeGenPreTrainedModel):
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past:
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
else:
position_ids = None
return {
"input_ids": input_ids,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,

@ -525,12 +525,12 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def prepare_inputs_for_generation(self, input_ids, past=None, use_cache=None, **kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, use_cache=None, **kwargs):
# only last token for input_ids if past is defined in kwargs
if past:
if past_key_values:
input_ids = input_ids[:, -1].unsqueeze(-1)
return {"input_ids": input_ids, "past_key_values": past, "use_cache": use_cache}
return {"input_ids": input_ids, "past_key_values": past_key_values, "use_cache": use_cache}
@add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)

@ -641,12 +641,12 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.lm_head.name
def prepare_inputs_for_generation(self, input_ids, past=None, use_cache=None, **kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, use_cache=None, **kwargs):
# only last token for input_ids if past is defined in kwargs
if past:
if past_key_values:
input_ids = tf.expand_dims(input_ids[:, -1], -1)
return {"input_ids": input_ids, "past_key_values": past, "use_cache": use_cache}
return {"input_ids": input_ids, "past_key_values": past_key_values, "use_cache": use_cache}
@unpack_inputs
@add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)

@ -1015,17 +1015,17 @@ class Data2VecTextForCausalLM(Data2VecTextPreTrainedModel):
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
# cut decoder_input_ids if past is used
if past is not None:
if past_key_values is not None:
input_ids = input_ids[:, -1:]
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
def _reorder_cache(self, past, beam_idx):
reordered_past = ()

@ -1666,17 +1666,17 @@ class ElectraForCausalLM(ElectraPreTrainedModel):
)
# Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
# cut decoder_input_ids if past is used
if past is not None:
if past_key_values is not None:
input_ids = input_ids[:, -1:]
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
# Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM._reorder_cache
def _reorder_cache(self, past, beam_idx):

@ -1214,20 +1214,22 @@ class ErnieForCausalLM(ErniePreTrainedModel):
)
# Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=True, **model_kwargs):
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, use_cache=True, **model_kwargs
):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
# cut decoder_input_ids if past is used
if past is not None:
# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
input_ids = input_ids[:, -1:]
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": use_cache,
}

@ -1268,7 +1268,7 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel):
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past=None,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
@ -1280,7 +1280,7 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel):
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,

@ -983,10 +983,10 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None)
# only last token for input_ids if past is defined in kwargs
if past:
if past_key_values:
input_ids = input_ids[:, -1].unsqueeze(-1)
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
@ -998,13 +998,13 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past:
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
else:
position_ids = None
return {
"input_ids": input_ids,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
@ -1156,10 +1156,10 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None)
# only last token for input_ids if past is defined in kwargs
if past:
if past_key_values:
input_ids = input_ids[:, -1].unsqueeze(-1)
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
@ -1171,14 +1171,14 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past:
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
else:
position_ids = None
return {
"input_ids": input_ids,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,

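The GPT-2 hunks above also rebuild `position_ids` from the attention mask so that left-padded batches generate correctly. A small runnable sketch (toy mask values):

import torch

attention_mask = torch.tensor([[0, 0, 1, 1], [1, 1, 1, 1]])  # first row is left-padded
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)  # tensor([[1, 1, 0, 1], [0, 1, 2, 3]])

# with past_key_values, only the newest token's position is needed
print(position_ids[:, -1].unsqueeze(-1))  # tensor([[1], [3]])
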
@ -828,10 +828,10 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
def set_output_embeddings(self, value):
self.set_input_embeddings(value)
def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, **kwargs):
def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_cache=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None)
# only last token for input_ids if past is defined in kwargs
if past:
if past_key_values:
inputs = tf.expand_dims(inputs[:, -1], -1)
if token_type_ids is not None:
token_type_ids = tf.expand_dims(token_type_ids[:, -1], -1)
@ -841,14 +841,14 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
if attention_mask is not None and position_ids is None:
position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True)
if past:
if past_key_values:
position_ids = tf.expand_dims(position_ids[:, -1], -1)
return {
"input_ids": inputs,
"attention_mask": attention_mask,
"position_ids": position_ids,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": use_cache,
"token_type_ids": token_type_ids,
}

@ -683,10 +683,10 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None)
# only last token for input_ids if past is defined in kwargs
if past:
if past_key_values:
input_ids = input_ids[:, -1].unsqueeze(-1)
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
@ -698,13 +698,13 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel):
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past:
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
else:
position_ids = None
return {
"input_ids": input_ids,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,

@ -686,7 +686,7 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
@ -694,13 +694,13 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
attention_mask = input_ids.new_ones(input_shape)
# cut decoder_input_ids if past is used
if past and past[0] is not None:
if past_key_values and past_key_values[0] is not None:
input_ids = input_ids[:, -1:]
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": past or model_kwargs.get("past_key_values"),
"past_key_values": past_key_values,
}
def _reorder_cache(self, past, beam_idx):

@ -702,7 +702,7 @@ class GPTNeoXJapaneseForCausalLM(GPTNeoXJapanesePreTrainedModel):
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
@ -710,10 +710,10 @@ class GPTNeoXJapaneseForCausalLM(GPTNeoXJapanesePreTrainedModel):
attention_mask = input_ids.new_ones(input_shape)
# cut decoder_input_ids if past is used
if past and past[0] is not None:
if past_key_values and past_key_values[0] is not None:
input_ids = input_ids[:, -1:]
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
def _reorder_cache(self, past, beam_idx):
reordered_past = ()

@ -760,10 +760,10 @@ class GPTJForCausalLM(GPTJPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None)
# only last token for input_ids if past is defined in kwargs
if past:
if past_key_values:
input_ids = input_ids[:, -1].unsqueeze(-1)
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
@ -775,13 +775,13 @@ class GPTJForCausalLM(GPTJPreTrainedModel):
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past:
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
else:
position_ids = None
return {
"input_ids": input_ids,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,

@ -741,10 +741,10 @@ class TFGPTJForCausalLM(TFGPTJPreTrainedModel, TFCausalLanguageModelingLoss):
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, **kwargs):
def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_cache=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None)
# only last token for input_ids if past is defined in kwargs
if past:
if past_key_values:
inputs = tf.expand_dims(inputs[:, -1], -1)
if token_type_ids is not None:
token_type_ids = tf.expand_dims(token_type_ids[:, -1], -1)
@ -754,14 +754,14 @@ class TFGPTJForCausalLM(TFGPTJPreTrainedModel, TFCausalLanguageModelingLoss):
if attention_mask is not None and position_ids is None:
position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True)
if past:
if past_key_values:
position_ids = tf.expand_dims(position_ids[:, -1], -1)
return {
"input_ids": inputs,
"attention_mask": attention_mask,
"position_ids": position_ids,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": use_cache,
"token_type_ids": token_type_ids,
}

@ -914,10 +914,10 @@ class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def prepare_inputs_for_generation(self, input_ids: torch.Tensor, past: Optional[bool] = None, **kwargs):
def prepare_inputs_for_generation(self, input_ids: torch.Tensor, past_key_values: Optional[bool] = None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None)
# only last token for input_ids if past is defined in kwargs
if past:
if past_key_values:
input_ids = input_ids[:, -1].unsqueeze(-1)
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
@ -929,13 +929,13 @@ class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel):
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past:
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
else:
position_ids = None
return {
"input_ids": input_ids,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,

@ -2480,7 +2480,7 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past=None,
past_key_values=None,
attention_mask=None,
global_attention_mask=None,
head_mask=None,
@ -2491,13 +2491,13 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
**kwargs,
):
# cut decoder_input_ids if past_key_values is used
if past is not None:
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"global_attention_mask": global_attention_mask,


@ -2512,7 +2512,7 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past=None,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
@ -2521,13 +2521,13 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
**kwargs
):
# cut decoder_input_ids if past_key_values is used
if past is not None:
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,


@ -2093,7 +2093,7 @@ class LongT5ForConditionalGeneration(LongT5PreTrainedModel):
def prepare_inputs_for_generation(
self,
input_ids,
past=None,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
@ -2104,12 +2104,12 @@ class LongT5ForConditionalGeneration(LongT5PreTrainedModel):
):
# cut decoder_input_ids if past_key_values is used
if past is not None:
if past_key_values is not None:
input_ids = input_ids[:, -1:]
return {
"decoder_input_ids": input_ids,
"past_key_values": past,
"past_key_values": past_key_values,
"encoder_outputs": encoder_outputs,
"attention_mask": attention_mask,
"head_mask": head_mask,


@ -1373,7 +1373,7 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past=None,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
@ -1383,13 +1383,13 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
**kwargs,
):
# cut decoder_input_ids if past_key_values is used
if past is not None:
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,


@ -1491,7 +1491,7 @@ class MarianMTModel(MarianPreTrainedModel):
def prepare_inputs_for_generation(
self,
decoder_input_ids: torch.LongTensor,
past: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
decoder_head_mask: Optional[torch.Tensor] = None,
@ -1501,13 +1501,13 @@ class MarianMTModel(MarianPreTrainedModel):
**kwargs,
) -> Dict:
# cut decoder_input_ids if past_key_values is used
if past is not None:
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
@ -1727,18 +1727,20 @@ class MarianForCausalLM(MarianPreTrainedModel):
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
):
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)
if past:
if past_key_values:
input_ids = input_ids[:, -1:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
"attention_mask": attention_mask,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": use_cache,
}
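
One detail of these causal-LM wrappers is easy to miss: when no mask is supplied, `input_ids.new_ones(input_ids.shape)` fabricates an all-ones attention mask, and `new_ones` (rather than `torch.ones`) matters because the result inherits the dtype and device of `input_ids`. For example:

import torch

input_ids = torch.tensor([[101, 54, 7]])  # int64, on some device
attention_mask = input_ids.new_ones(input_ids.shape)
print(attention_mask)  # tensor([[1, 1, 1]]) -- same dtype and device as input_ids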


@ -1455,7 +1455,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past=None,
past_key_values=None,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
@ -1466,21 +1466,21 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
**kwargs
):
# cut decoder_input_ids if past is used
if past is not None:
# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
if decoder_attention_mask is not None: # xla
decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
elif past is not None: # no xla + past
decoder_position_ids = past[0][0].shape[2]
else: # no xla + no past
elif past_key_values is not None: # no xla + past_key_values
decoder_position_ids = past_key_values[0][0].shape[2]
else: # no xla + no past_key_values
decoder_position_ids = tf.range(decoder_input_ids.shape[1])
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,


@ -937,20 +937,22 @@ class MarkupLMModel(MarkupLMPreTrainedModel):
)
# Copied from transformers.models.bert.modeling_bert.BertModel.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=True, **model_kwargs):
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, use_cache=True, **model_kwargs
):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
# cut decoder_input_ids if past is used
if past is not None:
# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
input_ids = input_ids[:, -1:]
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": use_cache,
}


@ -1404,7 +1404,7 @@ class MBartForConditionalGeneration(MBartPreTrainedModel):
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past=None,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
@ -1414,13 +1414,13 @@ class MBartForConditionalGeneration(MBartPreTrainedModel):
**kwargs
):
# cut decoder_input_ids if past_key_values is used
if past is not None:
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
@ -1892,18 +1892,20 @@ class MBartForCausalLM(MBartPreTrainedModel):
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
):
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)
if past:
if past_key_values:
input_ids = input_ids[:, -1:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
"attention_mask": attention_mask,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": use_cache,
}


@ -1452,7 +1452,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageModelingLoss):
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past=None,
past_key_values=None,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
@ -1463,21 +1463,21 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageModelingLoss):
**kwargs
):
# cut decoder_input_ids if past is used
if past is not None:
# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
if decoder_attention_mask is not None: # xla
decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
elif past is not None: # no xla + past
decoder_position_ids = past[0][0].shape[2]
else: # no xla + no past
elif past_key_values is not None: # no xla + past_key_values
decoder_position_ids = past_key_values[0][0].shape[2]
else: # no xla + no past_key_values
decoder_position_ids = tf.range(decoder_input_ids.shape[1])
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,


@ -1247,17 +1247,17 @@ class MegatronBertForCausalLM(MegatronBertPreTrainedModel):
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
# cut decoder_input_ids if past is used
if past is not None:
if past_key_values is not None:
input_ids = input_ids[:, -1:]
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
def _reorder_cache(self, past, beam_idx):
reordered_past = ()
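
Note that `_reorder_cache` keeps its `past` argument; only `prepare_inputs_for_generation` is renamed here. During beam search, `_reorder_cache` permutes the batch dimension of every cached tensor so each surviving beam carries its own history forward. The usual body, sketched from the common pattern rather than copied from this file:

import torch

def _reorder_cache(past, beam_idx):
    # beam_idx[i] is the old beam whose cache the new beam i inherits
    reordered_past = ()
    for layer_past in past:
        reordered_past += (
            tuple(state.index_select(0, beam_idx) for state in layer_past),
        )
    return reordered_past

# two beams swapping places: rows 0 and 1 of every cached tensor trade spots
layer = (torch.arange(8.0).view(2, 1, 2, 2), torch.arange(8.0).view(2, 1, 2, 2))
reordered = _reorder_cache((layer,), torch.tensor([1, 0]))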


@ -1746,7 +1746,7 @@ class MT5ForConditionalGeneration(MT5PreTrainedModel):
def prepare_inputs_for_generation(
self,
input_ids,
past=None,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
@ -1757,12 +1757,12 @@ class MT5ForConditionalGeneration(MT5PreTrainedModel):
):
# cut decoder_input_ids if past_key_values is used
if past is not None:
if past_key_values is not None:
input_ids = input_ids[:, -1:]
return {
"decoder_input_ids": input_ids,
"past_key_values": past,
"past_key_values": past_key_values,
"encoder_outputs": encoder_outputs,
"attention_mask": attention_mask,
"head_mask": head_mask,


@ -1555,7 +1555,7 @@ class MvpForConditionalGeneration(MvpPreTrainedModel):
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past=None,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
@ -1565,13 +1565,13 @@ class MvpForConditionalGeneration(MvpPreTrainedModel):
**kwargs
):
# cut decoder_input_ids if past_key_values is used
if past is not None:
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
@ -2040,18 +2040,20 @@ class MvpForCausalLM(MvpPreTrainedModel):
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
):
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)
if past:
if past_key_values:
input_ids = input_ids[:, -1:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
"attention_mask": attention_mask,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": use_cache,
}


@ -966,18 +966,20 @@ class OPTForCausalLM(OPTPreTrainedModel):
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
):
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)
if past:
if past_key_values:
input_ids = input_ids[:, -1:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
"attention_mask": attention_mask,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": use_cache,
}


@ -881,17 +881,17 @@ class TFOPTForCausalLM(TFOPTPreTrainedModel, TFCausalLanguageModelingLoss):
def get_output_embeddings(self):
return self.model.get_input_embeddings()
def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, **kwargs):
def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_cache=None, **kwargs):
attention_mask = kwargs.get("attention_mask", None)
# only last token for input_ids if past_key_values is defined in kwargs
if past:
if past_key_values:
inputs = tf.expand_dims(inputs[:, -1], -1)
return {
"input_ids": inputs,
"attention_mask": attention_mask,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": use_cache,
}


@ -1452,7 +1452,7 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past=None,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
@ -1462,13 +1462,13 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
**kwargs
):
# cut decoder_input_ids if past_key_values is used
if past is not None:
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
@ -1706,18 +1706,20 @@ class PegasusForCausalLM(PegasusPreTrainedModel):
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
):
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)
if past:
if past_key_values:
input_ids = input_ids[:, -1:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
"attention_mask": attention_mask,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": use_cache,
}


@ -1465,7 +1465,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLanguageModelingLoss):
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past=None,
past_key_values=None,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
@ -1476,21 +1476,21 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLanguageModelingLoss):
**kwargs
):
# cut decoder_input_ids if past is used
if past is not None:
# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
if decoder_attention_mask is not None: # xla
decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
elif past is not None: # no xla + past
decoder_position_ids = past[0][0].shape[2]
else: # no xla + no past
elif past_key_values is not None: # no xla + past_key_values
decoder_position_ids = past_key_values[0][0].shape[2]
else: # no xla + no past_key_values
decoder_position_ids = tf.range(decoder_input_ids.shape[1])
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,


@ -1659,16 +1659,22 @@ class PegasusXForConditionalGeneration(PegasusXPreTrainedModel):
)
def prepare_inputs_for_generation(
self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
self,
decoder_input_ids,
past_key_values=None,
attention_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs
):
# cut decoder_input_ids if past_key_values is used
if past is not None:
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)


@ -1375,7 +1375,7 @@ class PLBartForConditionalGeneration(PLBartPreTrainedModel):
def prepare_inputs_for_generation(
self,
decoder_input_ids: torch.LongTensor,
past: Optional[List[torch.FloatTensor]] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
decoder_head_mask: Optional[torch.Tensor] = None,
@ -1385,13 +1385,13 @@ class PLBartForConditionalGeneration(PLBartPreTrainedModel):
**kwargs # TODO: check whether this is needed; it appears to be unused
) -> Dict[str, Any]:
# cut decoder_input_ids if past_key_values is used
if past is not None:
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
@ -1737,18 +1737,20 @@ class PLBartForCausalLM(PLBartPreTrainedModel):
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
):
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)
if past:
if past_key_values:
input_ids = input_ids[:, -1:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
"attention_mask": attention_mask,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": use_cache,
}


@ -2062,7 +2062,7 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past=None,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
@ -2073,13 +2073,13 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
):
assert encoder_outputs is not None, "`encoder_outputs` have to be passed for generation."
if past:
if past_key_values:
decoder_input_ids = decoder_input_ids[:, -1:]
# first step, decoder_cached_states are empty
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
@ -2316,7 +2316,7 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
def prepare_inputs_for_generation(
self,
input_ids,
past=None,
past_key_values=None,
attention_mask=None,
head_mask=None,
use_cache=None,
@ -2326,14 +2326,14 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
if attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)
if past:
if past_key_values:
input_ids = input_ids[:, -1:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
"attention_mask": attention_mask,
"head_mask": head_mask,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": use_cache,
}
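
Two guard idioms survive the rename side by side: `if past_key_values:` (truthiness, as in these ProphetNet hunks) and `if past_key_values is not None:` (as in most seq2seq hunks above). They agree when the cache is None or populated, and would diverge only if a caller ever passed an empty container:

past_key_values = ()  # hypothetical empty cache container

if past_key_values:
    print("truthy")  # not reached: () is falsy
if past_key_values is not None:
    print("not None")  # reached: () is not None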


@ -1145,7 +1145,7 @@ class QDQBertLMHeadModel(QDQBertPreTrainedModel):
def prepare_inputs_for_generation(
self,
input_ids: Optional[torch.LongTensor],
past=None,
past_key_values=None,
attention_mask: Optional[torch.Tensor] = None,
**model_kwargs
):
@ -1155,10 +1155,10 @@ class QDQBertLMHeadModel(QDQBertPreTrainedModel):
attention_mask = input_ids.new_ones(input_shape)
# cut decoder_input_ids if past_key_values is used
if past is not None:
if past_key_values is not None:
input_ids = input_ids[:, -1:]
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
def _reorder_cache(self, past, beam_idx):
reordered_past = ()


@ -1170,7 +1170,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past=None,
past_key_values=None,
attention_mask=None,
use_cache=None,
encoder_outputs=None,
@ -1178,7 +1178,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
n_docs=None,
**kwargs
):
if past is not None:
if past_key_values is not None:
# if past_key_values is defined use only last decoder_input_ids
decoder_input_ids = decoder_input_ids[:, -1:]
@ -1188,7 +1188,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
"doc_scores": doc_scores,
"context_attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": use_cache,
"do_marginalize": True,
"n_docs": n_docs,


@ -764,7 +764,7 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss):
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past=None,
past_key_values=None,
attention_mask=None,
use_cache=None,
encoder_outputs=None,
@ -772,7 +772,7 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss):
n_docs=None,
**kwargs
):
if past is not None:
if past_key_values is not None:
# if past_key_values is defined use only last decoder_input_ids
decoder_input_ids = decoder_input_ids[:, -1:]
@ -782,7 +782,7 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss):
"doc_scores": doc_scores,
"context_attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": use_cache,
"do_marginalize": True,
"n_docs": n_docs,


@ -2286,14 +2286,16 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
attentions=reformer_outputs.attentions,
)
def prepare_inputs_for_generation(self, input_ids, past=None, use_cache=None, num_hashes=None, **kwargs):
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, use_cache=None, num_hashes=None, **kwargs
):
# only last token for input_ids if past_key_values is defined in kwargs
if past is not None:
if past_key_values is not None:
input_ids = input_ids[:, -1:]
inputs_dict = {
"input_ids": input_ids,
"past_buckets_states": past,
"past_buckets_states": past_key_values,
"use_cache": use_cache,
"num_hashes": num_hashes,
}


@ -1141,7 +1141,7 @@ class RemBertForCausalLM(RemBertPreTrainedModel):
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
@ -1149,10 +1149,10 @@ class RemBertForCausalLM(RemBertPreTrainedModel):
attention_mask = input_ids.new_ones(input_shape)
# cut decoder_input_ids if past_key_values is used
if past is not None:
if past_key_values is not None:
input_ids = input_ids[:, -1:]
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
def _reorder_cache(self, past, beam_idx):
reordered_past = ()


@ -1131,17 +1131,17 @@ class TFRemBertForCausalLM(TFRemBertPreTrainedModel, TFCausalLanguageModelingLoss):
return self.mlm.predictions
# Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = tf.ones(input_shape)
# cut decoder_input_ids if past_key_values is used
if past is not None:
if past_key_values is not None:
input_ids = input_ids[:, -1:]
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
@unpack_inputs
@add_code_sample_docstrings(


@ -1013,17 +1013,17 @@ class RobertaForCausalLM(RobertaPreTrainedModel):
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
# cut decoder_input_ids if past_key_values is used
if past is not None:
if past_key_values is not None:
input_ids = input_ids[:, -1:]
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
def _reorder_cache(self, past, beam_idx):
reordered_past = ()


@ -1171,17 +1171,17 @@ class TFRobertaForCausalLM(TFRobertaPreTrainedModel, TFCausalLanguageModelingLoss):
return self.name + "/" + self.lm_head.name
# Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = tf.ones(input_shape)
# cut decoder_input_ids if past_key_values is used
if past is not None:
if past_key_values is not None:
input_ids = input_ids[:, -1:]
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
@unpack_inputs
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))


@ -1020,17 +1020,17 @@ class RobertaPreLayerNormForCausalLM(RobertaPreLayerNormPreTrainedModel):
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
# cut decoder_input_ids if past_key_values is used
if past is not None:
if past_key_values is not None:
input_ids = input_ids[:, -1:]
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
def _reorder_cache(self, past, beam_idx):
reordered_past = ()


@ -1186,17 +1186,17 @@ class TFRobertaPreLayerNormForCausalLM(TFRobertaPreLayerNormPreTrainedModel, TFCausalLanguageModelingLoss):
return self.name + "/" + self.lm_head.name
# Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = tf.ones(input_shape)
# cut decoder_input_ids if past_key_values is used
if past is not None:
if past_key_values is not None:
input_ids = input_ids[:, -1:]
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
@unpack_inputs
@add_start_docstrings_to_model_forward(ROBERTA_PRELAYERNORM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))


@ -1551,7 +1551,7 @@ class RoCBertForCausalLM(RoCBertPreTrainedModel):
input_ids,
input_shape_ids=None,
input_pronunciation_ids=None,
past=None,
past_key_values=None,
attention_mask=None,
**model_kwargs
):
@ -1562,7 +1562,7 @@ class RoCBertForCausalLM(RoCBertPreTrainedModel):
attention_mask = input_ids.new_ones(input_shape)
# cut decoder_input_ids if past_key_values is used
if past is not None:
if past_key_values is not None:
input_ids = input_ids[:, -1:]
if input_shape_ids is not None:
input_shape_ids = input_shape_ids[:, -1:]
@ -1574,7 +1574,7 @@ class RoCBertForCausalLM(RoCBertPreTrainedModel):
"input_shape_ids": input_shape_ids,
"input_pronunciation_ids": input_pronunciation_ids,
"attention_mask": attention_mask,
"past_key_values": past,
"past_key_values": past_key_values,
}
# Copied from transformers.models.bert.modeling_bert.BertLMHeadModel._reorder_cache


@ -1178,7 +1178,7 @@ class RoFormerForCausalLM(RoFormerPreTrainedModel):
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
@ -1186,10 +1186,10 @@ class RoFormerForCausalLM(RoFormerPreTrainedModel):
attention_mask = input_ids.new_ones(input_shape)
# cut decoder_input_ids if past_key_values is used
if past is not None:
if past_key_values is not None:
input_ids = input_ids[:, -1:]
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
def _reorder_cache(self, past, beam_idx):
reordered_past = ()


@ -1395,7 +1395,7 @@ class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel):
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past=None,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
@ -1405,12 +1405,12 @@ class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel):
**kwargs
):
# cut decoder_input_ids if past_key_values is used
if past is not None:
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
return {
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,


@ -1477,7 +1477,7 @@ class TFSpeech2TextForConditionalGeneration(TFSpeech2TextPreTrainedModel, TFCausalLanguageModelingLoss):
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past=None,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
@ -1487,13 +1487,13 @@ class TFSpeech2TextForConditionalGeneration(TFSpeech2TextPreTrainedModel, TFCausalLanguageModelingLoss):
**kwargs
):
# cut decoder_input_ids if past_key_values is used
if past is not None:
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
return {
"input_features": None, # needs to be passed to make Keras.layer.__call__ happy
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,


@ -951,18 +951,20 @@ class Speech2Text2ForCausalLM(Speech2Text2PreTrainedModel):
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
):
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)
if past:
if past_key_values:
input_ids = input_ids[:, -1:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
"attention_mask": attention_mask,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": use_cache,
}


@ -1750,7 +1750,7 @@ class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedModel):
def prepare_inputs_for_generation(
self,
input_ids,
past=None,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
@ -1761,12 +1761,12 @@ class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedModel):
):
# cut decoder_input_ids if past_key_values is used
if past is not None:
if past_key_values is not None:
input_ids = input_ids[:, -1:]
return {
"decoder_input_ids": input_ids,
"past_key_values": past,
"past_key_values": past_key_values,
"encoder_outputs": encoder_outputs,
"attention_mask": attention_mask,
"head_mask": head_mask,


@ -1713,7 +1713,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
def prepare_inputs_for_generation(
self,
input_ids,
past=None,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
@ -1724,12 +1724,12 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
):
# cut decoder_input_ids if past_key_values is used
if past is not None:
if past_key_values is not None:
input_ids = input_ids[:, -1:]
return {
"decoder_input_ids": input_ids,
"past_key_values": past,
"past_key_values": past_key_values,
"encoder_outputs": encoder_outputs,
"attention_mask": attention_mask,
"head_mask": head_mask,


@ -1243,7 +1243,7 @@ class TFT5Model(TFT5PreTrainedModel):
past = decoder_outputs[1] if use_cache else None
if not return_dict:
if past is not None:
if past_key_values is not None:
decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
return decoder_outputs + encoder_outputs
@ -1441,7 +1441,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModelingLoss):
past = decoder_outputs[1] if use_cache else None
if not return_dict:
if past is not None:
if past_key_values is not None:
decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
output = (logits,) + decoder_outputs[1:] + encoder_outputs
return ((loss,) + output) if loss is not None else output
@ -1499,7 +1499,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModelingLoss):
def prepare_inputs_for_generation(
self,
input_ids,
past=None,
past_key_values=None,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
@ -1510,13 +1510,13 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModelingLoss):
):
# cut decoder_input_ids if past_key_values is used
if past is not None:
if past_key_values is not None:
input_ids = input_ids[:, -1:]
return {
"input_ids": None, # needs to be passed to make Keras.layer.__call__ happy
"decoder_input_ids": input_ids,
"past_key_values": past,
"past_key_values": past_key_values,
"encoder_outputs": encoder_outputs,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,


@ -1028,11 +1028,11 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
attentions=attns,
)
def prepare_inputs_for_generation(self, input_ids, past=None, **model_kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **model_kwargs):
inputs = {}
# if past_key_values is defined in model kwargs then use it for faster decoding
if past:
if past_key_values:
input_ids = tf.expand_dims(input_ids[:, -1], axis=-1)
else:
input_ids = input_ids


@ -1150,12 +1150,12 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
else:
return self.crit.out_layers[-1]
def prepare_inputs_for_generation(self, input_ids, past=None, **model_kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **model_kwargs):
inputs = {}
# if past_key_values is defined in model kwargs then use it for faster decoding
if past:
inputs["mems"] = past
if past_key_values:
inputs["mems"] = past_key_values
inputs["input_ids"] = input_ids[:, -1].unsqueeze(-1)
else:
inputs["input_ids"] = input_ids


@ -992,18 +992,20 @@ class TrOCRForCausalLM(TrOCRPreTrainedModel):
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
):
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)
if past:
if past_key_values:
input_ids = input_ids[:, -1:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
"attention_mask": attention_mask,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": use_cache,
}


@ -1358,7 +1358,7 @@ class TFWhisperForConditionalGeneration(TFWhisperPreTrainedModel, TFCausalLanguageModelingLoss):
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past=None,
past_key_values=None,
use_cache=None,
encoder_outputs=None,
attention_mask=None,
@ -1366,13 +1366,13 @@ class TFWhisperForConditionalGeneration(TFWhisperPreTrainedModel, TFCausalLanguageModelingLoss):
**kwargs
):
# cut decoder_input_ids if past_key_values is used
if past is not None:
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
if decoder_attention_mask is not None: # xla
decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
elif past is not None: # no xla + past
decoder_position_ids = past[0][0].shape[2]
elif past_key_values is not None:  # no xla + past_key_values
decoder_position_ids = past_key_values[0][0].shape[2]
else:  # no xla + no past_key_values
decoder_position_ids = tf.range(decoder_input_ids.shape[1])
decoder_position_ids = tf.broadcast_to(decoder_position_ids, decoder_input_ids.shape)
@ -1380,7 +1380,7 @@ class TFWhisperForConditionalGeneration(TFWhisperPreTrainedModel, TFCausalLanguageModelingLoss):
return {
"input_features": None, # Needs to be passed to make Keras.layer.__call__ happy
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"use_cache": use_cache,
"decoder_attention_mask": decoder_attention_mask,


@ -1233,15 +1233,21 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
)
def prepare_inputs_for_generation(
self, decoder_input_ids, past=None, use_cache=None, encoder_outputs=None, attention_mask=None, **kwargs
self,
decoder_input_ids,
past_key_values=None,
use_cache=None,
encoder_outputs=None,
attention_mask=None,
**kwargs
):
# cut decoder_input_ids if past_key_values is used
if past is not None:
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
return {
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"use_cache": use_cache,
"decoder_attention_mask": None,


@ -888,9 +888,9 @@ class TFXGLMForCausalLM(TFXGLMPreTrainedModel, TFCausalLanguageModelingLoss):
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, **kwargs):
def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_cache=None, **kwargs):
# only last token for input_ids if past_key_values is defined in kwargs
if past:
if past_key_values:
inputs = tf.expand_dims(inputs[:, -1], -1)
attention_mask = kwargs.get("attention_mask", None)
@ -898,7 +898,7 @@ class TFXGLMForCausalLM(TFXGLMPreTrainedModel, TFCausalLanguageModelingLoss):
return {
"input_ids": inputs,
"attention_mask": attention_mask,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": use_cache,
}


@ -929,18 +929,20 @@ class XGLMForCausalLM(XGLMPreTrainedModel):
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
):
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)
if past:
if past_key_values:
input_ids = input_ids[:, -1:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
"attention_mask": attention_mask,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": use_cache,
}


@ -2089,7 +2089,7 @@ class XLMProphetNetForConditionalGeneration(XLMProphetNetPreTrainedModel):
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past=None,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
@ -2100,13 +2100,13 @@ class XLMProphetNetForConditionalGeneration(XLMProphetNetPreTrainedModel):
):
assert encoder_outputs is not None, "`encoder_outputs` have to be passed for generation."
if past:
if past_key_values:
decoder_input_ids = decoder_input_ids[:, -1:]
# first step, decoder_cached_states are empty
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
@ -2346,7 +2346,7 @@ class XLMProphetNetForCausalLM(XLMProphetNetPreTrainedModel):
def prepare_inputs_for_generation(
self,
input_ids,
past=None,
past_key_values=None,
attention_mask=None,
head_mask=None,
use_cache=None,
@ -2356,14 +2356,14 @@ class XLMProphetNetForCausalLM(XLMProphetNetPreTrainedModel):
if attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)
if past:
if past_key_values:
input_ids = input_ids[:, -1:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
"attention_mask": attention_mask,
"head_mask": head_mask,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": use_cache,
}


@ -1017,17 +1017,17 @@ class XLMRobertaForCausalLM(XLMRobertaPreTrainedModel):
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
# cut decoder_input_ids if past_key_values is used
if past is not None:
if past_key_values is not None:
input_ids = input_ids[:, -1:]
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
def _reorder_cache(self, past, beam_idx):
reordered_past = ()


@ -979,17 +979,17 @@ class XLMRobertaXLForCausalLM(XLMRobertaXLPreTrainedModel):
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
# cut decoder_input_ids if past_key_values is used
if past is not None:
if past_key_values is not None:
input_ids = input_ids[:, -1:]
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
def _reorder_cache(self, past, beam_idx):
reordered_past = ()


@ -1217,7 +1217,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.lm_loss.name
def prepare_inputs_for_generation(self, inputs, past=None, use_mems=None, **kwargs):
def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_mems=None, **kwargs):
# Add dummy token at the end (no attention on this one)
effective_batch_size = inputs.shape[0]
dummy_token = tf.zeros((effective_batch_size, 1), dtype=inputs.dtype)
@ -1227,7 +1227,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
# offset = 1; offset = 2 seems to have slightly better computation.
offset = 2
if past:
if past_key_values:
input_ids = tf.concat([inputs[:, -offset:], dummy_token], axis=1)
else:
input_ids = tf.concat([inputs, dummy_token], axis=1)
@ -1251,8 +1251,8 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
}
# if past_key_values is defined in model kwargs then use it for faster decoding
if past:
inputs["mems"] = tuple(layer_past[:-offset, :, :] for layer_past in past)
if past_key_values:
inputs["mems"] = tuple(layer_past[:-offset, :, :] for layer_past in past_key_values)
return inputs


@ -1315,7 +1315,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.lm_loss = new_embeddings
def prepare_inputs_for_generation(self, input_ids, past=None, use_mems=None, **kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, use_mems=None, **kwargs):
# Add dummy token at the end (no attention on this one)
effective_batch_size = input_ids.shape[0]
@ -1326,7 +1326,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
# offset = 1; offset = 2 seems to have slightly better computation.
offset = 2
if past:
if past_key_values:
input_ids = torch.cat([input_ids[:, -offset:], dummy_token], dim=1)
else:
input_ids = torch.cat([input_ids, dummy_token], dim=1)
@ -1352,8 +1352,8 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
}
# if past_key_values is defined in model kwargs then use it for faster decoding
if past:
inputs["mems"] = tuple(layer_past[:-offset, :, :] for layer_past in past)
if past_key_values:
inputs["mems"] = tuple(layer_past[:-offset, :, :] for layer_past in past_key_values)
return inputs
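
XLNet's variant is the least obvious of these hunks: a dummy token is appended so the model has an explicit slot to predict into (XLNet scores targets through a permutation mask aimed at that slot), and with `offset = 2` the last two real tokens are re-fed while `layer_past[:-offset]` trims the mems those positions would otherwise duplicate. In miniature, with invented contents:

import torch

offset = 2
input_ids = torch.tensor([[11, 12, 13, 14]])
dummy_token = torch.zeros((1, 1), dtype=torch.long)

# re-feed the last `offset` real tokens plus the dummy prediction slot
step_ids = torch.cat([input_ids[:, -offset:], dummy_token], dim=1)
print(step_ids)  # tensor([[13, 14,  0]])

# mems are time-major (seq_len, batch, hidden); drop what will be recomputed
mems = (torch.randn(6, 1, 8),)
trimmed = tuple(layer_past[:-offset, :, :] for layer_past in mems)
print(trimmed[0].shape)  # torch.Size([4, 1, 8])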


@ -1121,15 +1121,15 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelcase_modelname}}PreTrainedModel, TFCausalLanguageModelingLoss):
def get_lm_head(self) -> tf.keras.layers.Layer:
return self.mlm.predictions
def prepare_inputs_for_generation(self, inputs, past=None, attention_mask=None, **model_kwargs):
def prepare_inputs_for_generation(self, inputs, past_key_values=None, attention_mask=None, **model_kwargs):
# cut decoder_input_ids if past_key_values is used
if past:
if past_key_values:
inputs = tf.expand_dims(inputs[:, -1], -1)
return {
"input_ids": inputs,
"attention_mask": attention_mask,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": model_kwargs["use_cache"],
}
@ -3003,7 +3003,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiecutter.camelcase_modelname}}PreTrainedModel, TFCausalLanguageModelingLoss):
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past=None,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
@ -3013,13 +3013,13 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiecutter.camelcase_modelname}}PreTrainedModel, TFCausalLanguageModelingLoss):
**kwargs
):
# cut decoder_input_ids if past_key_values is used
if past is not None:
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
return {
"input_ids": None, # needs to be passed to make Keras.layer.__call__ happy
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,


@ -1167,7 +1167,7 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_modelname}}PreTrainedModel):
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
@ -1175,10 +1175,10 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_modelname}}PreTrainedModel):
attention_mask = input_ids.new_ones(input_shape)
# cut decoder_input_ids if past_key_values is used
if past is not None:
if past_key_values is not None:
input_ids = input_ids[:, -1:]
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
def _reorder_cache(self, past, beam_idx):
reordered_past = ()
@ -2879,7 +2879,7 @@ class {{cookiecutter.camelcase_modelname}}ForConditionalGeneration({{cookiecutter.camelcase_modelname}}PreTrainedModel):
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past=None,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
@ -2889,13 +2889,13 @@ class {{cookiecutter.camelcase_modelname}}ForConditionalGeneration({{cookiecutter.camelcase_modelname}}PreTrainedModel):
**kwargs
):
# cut decoder_input_ids if past_key_values is used
if past is not None:
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
@ -3328,7 +3328,7 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_modelname}}PreTrainedModel):
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs):
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)
@ -3339,7 +3339,7 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_modelname}}PreTrainedModel):
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
"attention_mask": attention_mask,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": use_cache,
}