Replace `past` with `past_key_values` (#20944)
* start cleanup
* more updates
* more models are affected
* more updates
* update generation utils
* style
* revert change that removed reorder cache
* update generation utils
* style
* style
* remove reorder cache
This commit is contained in:
parent
7cb596fa22
commit
f0577df6de
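
Every hunk below applies the same mechanical change: the `past` keyword argument and the `model_kwargs["past"]` entry become `past_key_values`, so the kwarg, the `model_kwargs` key, and the key returned by `prepare_inputs_for_generation` finally share one name. As a rough illustration of the post-rename convention (a minimal sketch only; `MyModelForCausalLM` is hypothetical, not a class touched by this commit):

    # Hedged sketch of the convention the commit enforces; not code from the diff.
    class MyModelForCausalLM:
        def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):
            # With a cache present, only the newest token needs to be re-encoded.
            if past_key_values is not None:
                input_ids = input_ids[:, -1:]
            return {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "past_key_values": past_key_values,  # kwarg name and dict key now match
            }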
@@ -1074,20 +1074,20 @@ class TFGenerationMixin:
     @staticmethod
     def _extract_past_from_model_output(outputs: ModelOutput):
-        past = None
+        past_key_values = None
         if "past_key_values" in outputs:
-            past = outputs.past_key_values
+            past_key_values = outputs.past_key_values
         elif "mems" in outputs:
-            past = outputs.mems
+            past_key_values = outputs.mems
         elif "past_buckets_states" in outputs:
-            past = outputs.past_buckets_states
-        return past
+            past_key_values = outputs.past_buckets_states
+        return past_key_values

     def _update_model_kwargs_for_generation(
         self, outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False
     ) -> Dict[str, Any]:
-        # update past
-        model_kwargs["past"] = self._extract_past_from_model_output(outputs)
+        # update past_key_values
+        model_kwargs["past_key_values"] = self._extract_past_from_model_output(outputs)

         # update attention mask
         if not is_encoder_decoder:
@@ -1112,7 +1112,7 @@ class TFGenerationMixin:
         def _initialize_attention(model_kwargs, num_padding_values, is_encoder_decoder):
             """initializes the appropriate attention mask -- encoder-decoder models use `decoder_attention_mask`"""
             if is_encoder_decoder:
-                # One 1 for decoder_start_token_id, 0s for the currently-unfilled locations in the past tensor,
+                # One 1 for decoder_start_token_id, 0s for the currently-unfilled locations in the past_key_values tensor,
                 # 1s for the actual input_ids
                 decoder_attention_mask = tf.concat(
                     [
@@ -1125,7 +1125,7 @@ class TFGenerationMixin:
                 mask = {"decoder_attention_mask": decoder_attention_mask}
             else:
                 attention_mask = model_kwargs.pop("attention_mask")
-                # 0s for the currently-unfilled locations in the past tensor, 1s for the actual input_ids
+                # 0s for the currently-unfilled locations in the past_key_values tensor, 1s for the actual input_ids
                 attention_mask = tf.concat(
                     [
                         attention_mask,
@@ -1154,32 +1154,32 @@ class TFGenerationMixin:
                 mask = {"attention_mask": attention_mask}
             return mask

-        def _initialize_past(past, num_padding_values, batch_axis):
-            """initialize past with zeros -- the structure depends on `batch_axis`"""
+        def _initialize_past(past_key_values, num_padding_values, batch_axis):
+            """initialize past_key_values with zeros -- the structure depends on `batch_axis`"""
             if batch_axis == 0:
                 padding_values = tf.constant([[0, 0], [0, 0], [0, num_padding_values], [0, 0]], dtype=tf.int32)
                 new_past = ()
-                for past_layer in past:
+                for past_layer in past_key_values:
                     new_past_layer = list(past_layer)
                     for i in range(len(new_past_layer[:2])):
                         new_past_layer[i] = tf.pad(past_layer[i], padding_values)
                     new_past += (tuple(new_past_layer),)
             else:
                 padding_values = tf.scatter_nd(indices=[[3, 1]], updates=[num_padding_values], shape=(5, 2))
-                new_past = list(past)
-                for i in range(len(past)):
-                    new_past[i] = tf.pad(past[i], padding_values)
+                new_past = list(past_key_values)
+                for i in range(len(past_key_values)):
+                    new_past[i] = tf.pad(past_key_values[i], padding_values)
             return new_past

-        def _update_past(past, new_past_index, batch_axis):
+        def _update_past(past_key_values, new_past_index, batch_axis):
             if batch_axis == 0:
                 slice_start_base = tf.constant([0, 0, 1, 0])
                 new_past = ()
-                for past_layer in past:
+                for past_layer in past_key_values:
                     new_past_layer = list(past_layer)
                     for i in range(len(new_past_layer[:2])):
                         update_slice = past_layer[i][:, :, -1:]
-                        # Write the last slice to the first open location in the padded past array
+                        # Write the last slice to the first open location in the padded past_key_values array
                         # and then truncate the last slice off the array
                         new_past_layer[i] = dynamic_update_slice(
                             past_layer[i][:, :, :-1], update_slice, slice_start_base * new_past_index
@@ -1187,41 +1187,42 @@ class TFGenerationMixin:
                 new_past += (tuple(new_past_layer),)
             else:
                 slice_start_base = tf.constant([0, 0, 0, 1, 0])
-                new_past = [None for _ in range(len(past))]
-                for i in range(len(past)):
-                    update_slice = past[i][:, :, :, -1:]
-                    # Write the last slice to the first open location in the padded past array
+                new_past = [None for _ in range(len(past_key_values))]
+                for i in range(len(past_key_values)):
+                    update_slice = past_key_values[i][:, :, :, -1:]
+                    # Write the last slice to the first open location in the padded past_key_values array
                     # and then truncate the last slice off the array
                     new_past[i] = dynamic_update_slice(
-                        past[i][:, :, :, :-1], update_slice, slice_start_base * new_past_index
+                        past_key_values[i][:, :, :, :-1], update_slice, slice_start_base * new_past_index
                     )
             return new_past

-        past = self._extract_past_from_model_output(model_outputs)
-        if past is None:
+        past_key_values = self._extract_past_from_model_output(model_outputs)
+        if past_key_values is None:
             raise ValueError(
-                f"No known past variable found in model outputs (model outputs keys: {list(model_outputs.keys())})"
+                "No known `past_key_values variable` found in model outputs (model outputs keys:"
+                f" {list(model_outputs.keys())})"
             )
-        is_past_initialized = model_kwargs.pop("past", None) is not None
+        is_past_initialized = model_kwargs.pop("past_key_values", None) is not None

         if not is_past_initialized:
-            # The padded version of `past` has a length of `max_length - 1`, as `past` holds information relative to
-            # previous autoregressive generation steps (step 0 has no past, step 1 has 1 past value, ..., the last step
-            # has `max_length - 1` past values).
+            # The padded version of `past_key_values` has a length of `max_length - 1`, as `past_key_values` holds information relative to
+            # previous autoregressive generation steps (step 0 has no past_key_values, step 1 has 1 past_key_values value, ..., the last step
+            # has `max_length - 1` past_key_values values).
             num_padding_values = max_length - cur_len - 1
             mask = _initialize_attention(model_kwargs, num_padding_values, is_encoder_decoder)
-            new_past = _initialize_past(past, num_padding_values, batch_axis)
+            new_past = _initialize_past(past_key_values, num_padding_values, batch_axis)
         else:
-            # The new index of past to be filled corresponds to the current length of the sequence, with two
-            # subtractions: -1 because past holds information regarding previous generation steps (read comment above)
+            # The new index of past_key_values to be filled corresponds to the current length of the sequence, with two
+            # subtractions: -1 because past_key_values holds information regarding previous generation steps (read comment above)
             # and -1 again because in an array the index is the length of the array minus 1.
             new_past_index = cur_len - 2
             mask = _update_attention(model_kwargs, new_past_index, is_encoder_decoder)
-            new_past = _update_past(past, new_past_index, batch_axis)
+            new_past = _update_past(past_key_values, new_past_index, batch_axis)

-        # sets the updated variables (mask and past)
+        # sets the updated variables (mask and past_key_values)
         model_kwargs.update(mask)
-        model_kwargs["past"] = tuple(new_past)
+        model_kwargs["past_key_values"] = tuple(new_past)

         return model_kwargs

@@ -1403,7 +1404,7 @@ class TFGenerationMixin:
         # GPT2 and other models has a slightly different cache structure, with a different batch axis
         model_name = str(self.decoder) if "EncoderDecoder" in str(self) else str(self)
         cache_batch_axis = 1 if any([model_prefix in model_name for model_prefix in ("TFGPT2", "TFCTRL")]) else 0
-        # some models, like XLNet, need more than the last token in the presence of past
+        # some models, like XLNet, need more than the last token in the presence of past_key_values
         needs_full_input = "use_mems" in set(inspect.signature(self.prepare_inputs_for_generation).parameters.keys())

         # 2. init `attentions`, `hidden_states`, and `scores` tuples
@@ -1429,7 +1430,7 @@ class TFGenerationMixin:
         # define condition fn
         def greedy_search_body_fn(generated, finished_sequences, cur_len, model_kwargs):
             """state update fn."""
-            if model_kwargs.get("past") is None or needs_full_input:
+            if model_kwargs.get("past_key_values") is None or needs_full_input:
                 input_ids = generated[:, :cur_len]
             else:
                 input_ids = tf.expand_dims(generated[:, cur_len - 1], -1)
@@ -1492,15 +1493,15 @@ class TFGenerationMixin:
             model_kwargs = self._update_model_kwargs_for_generation(
                 model_outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
             )
-            # if we don't cache past key values we need the whole input
-            if model_kwargs.get("past", None) is None:
-                # let's throw out `past` since we don't want `None` tensors
-                model_kwargs.pop("past", None)
+            # if we don't cache past_key_values key values we need the whole input
+            if model_kwargs.get("past_key_values", None) is None:
+                # let's throw out `past_key_values` since we don't want `None` tensors
+                model_kwargs.pop("past_key_values", None)

             return generated, finished_sequences, cur_len, model_kwargs

         # 5. run generation
-        # 1st generation step has to be run before to initialize `past`
+        # 1st generation step has to be run before to initialize `past_key_values`
         generated, finished_sequences, cur_len, model_kwargs = greedy_search_body_fn(
             generated, finished_sequences, cur_len, model_kwargs
         )
@@ -1680,7 +1681,7 @@ class TFGenerationMixin:
         # GPT2 and other models has a slightly different cache structure, with a different batch axis
         model_name = str(self.decoder) if "EncoderDecoder" in str(self) else str(self)
         cache_batch_axis = 1 if any([model_prefix in model_name for model_prefix in ("TFGPT2", "TFCTRL")]) else 0
-        # some models, like XLNet, need more than the last token in the presence of past
+        # some models, like XLNet, need more than the last token in the presence of past_key_values
         needs_full_input = "use_mems" in set(inspect.signature(self.prepare_inputs_for_generation).parameters.keys())

         # 2. init `attentions`, `hidden_states`, and `scores` tuples
@@ -1702,7 +1703,7 @@ class TFGenerationMixin:
             return ~tf.reduce_all(finished_sequences)

         def sample_body_fn(generated, finished_sequences, cur_len, model_kwargs):
-            if model_kwargs.get("past") is None or needs_full_input:
+            if model_kwargs.get("past_key_values") is None or needs_full_input:
                 input_ids = generated[:, :cur_len]
             else:
                 input_ids = tf.expand_dims(generated[:, cur_len - 1], -1)
@@ -1775,15 +1776,15 @@ class TFGenerationMixin:
             model_kwargs = self._update_model_kwargs_for_generation(
                 model_outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
             )
-            # if we don't cache past key values we need the whole input
-            if model_kwargs.get("past", None) is None:
-                # let's throw out `past` since we don't want `None` tensors
-                model_kwargs.pop("past", None)
+            # if we don't cache past_key_values key values we need the whole input
+            if model_kwargs.get("past_key_values", None) is None:
+                # let's throw out `past_key_values` since we don't want `None` tensors
+                model_kwargs.pop("past_key_values", None)

             return generated, finished_sequences, cur_len, model_kwargs

         # 5. run generation
-        # 1st generation step has to be run before to initialize `past`
+        # 1st generation step has to be run before to initialize `past_key_values`
         generated, finished_sequences, cur_len, model_kwargs = sample_body_fn(
             generated, finished_sequences, cur_len, model_kwargs
         )
@@ -2012,7 +2013,7 @@ class TFGenerationMixin:
         # GPT2 and other models has a slightly different cache structure, with a different batch axis
         model_name = str(self.decoder) if "EncoderDecoder" in str(self) else str(self)
         cache_batch_axis = 1 if any([model_prefix in model_name for model_prefix in ("TFGPT2", "TFCTRL")]) else 0
-        # some models, like XLNet, need more than the last token in the presence of past
+        # some models, like XLNet, need more than the last token in the presence of past_key_values
         needs_full_input = "use_mems" in set(inspect.signature(self.prepare_inputs_for_generation).parameters.keys())

         # 2. init `attentions`, `hidden_states`, and `scores` tuples
@@ -2092,7 +2093,7 @@ class TFGenerationMixin:
             seen so far
             """
             # 1. Forward current tokens
-            if model_kwargs.get("past") is None or needs_full_input:
+            if model_kwargs.get("past_key_values") is None or needs_full_input:
                 input_ids = running_sequences[:, :, :cur_len]
             else:
                 input_ids = tf.expand_dims(running_sequences[:, :, cur_len - 1], -1)
@@ -2248,10 +2249,10 @@ class TFGenerationMixin:
                 model_outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
             )

-            # if we don't cache past key values we need the whole input
-            if model_kwargs.get("past", None) is None:
-                # let's throw out `past` since we don't want `None` tensors
-                model_kwargs.pop("past", None)
+            # if we don't cache past_key_values key values we need the whole input
+            if model_kwargs.get("past_key_values", None) is None:
+                # let's throw out `past_key_values` since we don't want `None` tensors
+                model_kwargs.pop("past_key_values", None)

             return (
                 cur_len,
@@ -2264,7 +2265,7 @@ class TFGenerationMixin:
             )

         # 5. run generation
-        # 1st generation step has to be run before to initialize `past` (if active)
+        # 1st generation step has to be run before to initialize `past_key_values` (if active)
         (
             cur_len,
             running_sequences,
@@ -2472,7 +2473,7 @@ class TFGenerationMixin:

             # if the first step in the loop, encode all the prefix and obtain: (1) past_key_values;
             # (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step
-            if model_kwargs.get("past") is None:
+            if model_kwargs.get("past_key_values") is None:

                 # prepare inputs
                 model_inputs = self.prepare_inputs_for_generation(
@@ -2520,13 +2521,16 @@ class TFGenerationMixin:
                     expand_size=top_k, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
                 )

-                past = model_kwargs.get("past")
-                if past is None:
+                past_key_values = model_kwargs.get("past_key_values")
+                if past_key_values is None:
                     raise ValueError(
                         f"{self.__class__.__name__} does not support caching and therefore **can't** be used "
                         "for contrastive search."
                     )
-                elif not isinstance(past[0], (tuple, tf.Tensor)) or past[0][0].shape[0] != batch_size:
+                elif (
+                    not isinstance(past_key_values[0], (tuple, tf.Tensor))
+                    or past_key_values[0][0].shape[0] != batch_size
+                ):
                     raise ValueError(
                         f"{self.__class__.__name__} does not have a standard cache format and therefore **can't** be "
                         "used for contrastive search without further modifications."
@@ -2562,8 +2566,8 @@ class TFGenerationMixin:
                     decoder_hidden_states.append(outputs.hidden_states)

             # Replicates the new past_key_values to match the `top_k` candidates
-            model_kwargs["past"] = tf.nest.map_structure(
-                lambda tensor: tf.repeat(tensor, top_k, axis=cache_batch_axis), model_kwargs["past"]
+            model_kwargs["past_key_values"] = tf.nest.map_structure(
+                lambda tensor: tf.repeat(tensor, top_k, axis=cache_batch_axis), model_kwargs["past_key_values"]
             )

             # compute the candidate tokens by the language model and collects their hidden_states
@@ -2676,7 +2680,7 @@ class TFGenerationMixin:
             return generated, finished_sequences, cur_len, model_kwargs, next_step_cached_variables

         # 5. run generation
-        # 1st generation step has to be run before to initialize `past`
+        # 1st generation step has to be run before to initialize `past_key_values`
         generated, finished_sequences, cur_len, model_kwargs, next_step_cached_variables = contrastive_search_body_fn(
             generated, finished_sequences, cur_len, model_kwargs, None
         )
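
Note that `_extract_past_from_model_output` above only renames its local variable and return value; the attribute lookup order is unchanged, and the PyTorch counterpart below adds the same rename around its Bloom cache standardization. A standalone sketch of that unchanged lookup, simplified to a plain dict rather than a `ModelOutput` (an illustration, not code from the diff):

    def extract_cache(outputs):
        # Same priority order as _extract_past_from_model_output: regular KV cache,
        # then XLNet-style mems, then Reformer-style past_buckets_states.
        for key in ("past_key_values", "mems", "past_buckets_states"):
            if key in outputs:
                return outputs[key]
        return None  # callers treat None as "this model returned no cache"

    assert extract_cache({"mems": [0, 1]}) == [0, 1]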
@@ -675,19 +675,19 @@ class GenerationMixin:
         return input_ids, model_kwargs

     def _extract_past_from_model_output(self, outputs: ModelOutput, standardize_cache_format: bool = False):
-        past = None
+        past_key_values = None
         if "past_key_values" in outputs:
-            past = outputs.past_key_values
+            past_key_values = outputs.past_key_values
         elif "mems" in outputs:
-            past = outputs.mems
+            past_key_values = outputs.mems
         elif "past_buckets_states" in outputs:
-            past = outputs.past_buckets_states
+            past_key_values = outputs.past_buckets_states

         # Bloom fix: standardizes the cache format when requested
         if standardize_cache_format and hasattr(self, "_convert_to_standard_cache"):
             batch_size = outputs.logits.shape[0]
-            past = self._convert_to_standard_cache(past, batch_size=batch_size)
-        return past
+            past_key_values = self._convert_to_standard_cache(past_key_values, batch_size=batch_size)
+        return past_key_values

     def _update_model_kwargs_for_generation(
         self,
@@ -696,8 +696,8 @@ class GenerationMixin:
         is_encoder_decoder: bool = False,
         standardize_cache_format: bool = False,
     ) -> Dict[str, Any]:
-        # update past
-        model_kwargs["past"] = self._extract_past_from_model_output(
+        # update past_key_values
+        model_kwargs["past_key_values"] = self._extract_past_from_model_output(
             outputs, standardize_cache_format=standardize_cache_format
         )

@@ -1758,7 +1758,7 @@ class GenerationMixin:

            # if the first step in the loop, encode all the prefix and obtain: (1) past_key_values;
            # (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step
-            if model_kwargs.get("past") is None:
+            if model_kwargs.get("past_key_values") is None:

                # prepare inputs
                model_kwargs["use_cache"] = True
@@ -1791,13 +1791,16 @@ class GenerationMixin:
                    expand_size=top_k, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
                )

-                past = model_kwargs.get("past")
-                if past is None:
+                past_key_values = model_kwargs.get("past_key_values")
+                if past_key_values is None:
                    raise ValueError(
                        f"{self.__class__.__name__} does not support caching and therefore **can't** be used "
                        "for contrastive search."
                    )
-                elif not isinstance(past[0], (tuple, torch.Tensor)) or past[0][0].shape[0] != batch_size:
+                elif (
+                    not isinstance(past_key_values[0], (tuple, torch.Tensor))
+                    or past_key_values[0][0].shape[0] != batch_size
+                ):
                    raise ValueError(
                        f"{self.__class__.__name__} does not have a standard cache format and therefore **can't** be "
                        "used for contrastive search without further modifications."
@@ -1832,13 +1835,13 @@ class GenerationMixin:

            # Replicates the new past_key_values to match the `top_k` candidates
            new_key_values = []
-            for layer in model_kwargs["past"]:
+            for layer in model_kwargs["past_key_values"]:
                items = []
                # item is either the key or the value matrix
                for item in layer:
                    items.append(item.repeat_interleave(top_k, dim=0))
                new_key_values.append(items)
-            model_kwargs["past"] = new_key_values
+            model_kwargs["past_key_values"] = new_key_values

            # compute the candidate tokens by the language model and collects their hidden_states
            next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs)
@@ -2718,8 +2721,8 @@ class GenerationMixin:
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
-            if model_kwargs["past"] is not None:
-                model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)
+            if model_kwargs["past_key_values"] is not None:
+                model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)

            if return_dict_in_generate and output_scores:
                beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))
@@ -3040,8 +3043,8 @@ class GenerationMixin:
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
-            if model_kwargs["past"] is not None:
-                model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)
+            if model_kwargs["past_key_values"] is not None:
+                model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)

            if return_dict_in_generate and output_scores:
                beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))
@@ -3410,8 +3413,10 @@ class GenerationMixin:
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
-            if model_kwargs["past"] is not None:
-                model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], reordering_indices)
+            if model_kwargs["past_key_values"] is not None:
+                model_kwargs["past_key_values"] = self._reorder_cache(
+                    model_kwargs["past_key_values"], reordering_indices
+                )

            # increase cur_len
            cur_len = cur_len + 1
@@ -3732,8 +3737,8 @@ class GenerationMixin:
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
-            if model_kwargs["past"] is not None:
-                model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)
+            if model_kwargs["past_key_values"] is not None:
+                model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)

            # increase cur_len
            cur_len = cur_len + 1
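
The beam-search call sites above keep their control flow; only the `model_kwargs` key handed to `_reorder_cache` changes. For orientation, a simplified sketch of what a typical `_reorder_cache` override does with the renamed argument, assuming the standard tuple-of-tuples cache layout with batch-first tensors (it mirrors, but is not copied from, the per-model overrides further down):

    import torch

    def reorder_cache(past_key_values, beam_idx):
        # Select, per layer and per key/value tensor, the batch rows of the
        # beams that survived this step.
        return tuple(
            tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
            for layer_past in past_key_values
        )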
@@ -1418,7 +1418,7 @@ class BartForConditionalGeneration(BartPretrainedModel):
     def prepare_inputs_for_generation(
         self,
         decoder_input_ids,
-        past=None,
+        past_key_values=None,
         attention_mask=None,
         decoder_attention_mask=None,
         head_mask=None,
@@ -1428,14 +1428,14 @@ class BartForConditionalGeneration(BartPretrainedModel):
         encoder_outputs=None,
         **kwargs
     ):
-        # cut decoder_input_ids if past is used
-        if past is not None:
+        # cut decoder_input_ids if past_key_values is used
+        if past_key_values is not None:
             decoder_input_ids = decoder_input_ids[:, -1:]

         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
             "encoder_outputs": encoder_outputs,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "decoder_input_ids": decoder_input_ids,
             "attention_mask": attention_mask,
             "decoder_attention_mask": decoder_attention_mask,
@@ -1910,18 +1910,20 @@ class BartForCausalLM(BartPretrainedModel):
             cross_attentions=outputs.cross_attentions,
         )

-    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
+    def prepare_inputs_for_generation(
+        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
+    ):
         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_ids.shape)

-        if past:
+        if past_key_values:
             input_ids = input_ids[:, -1:]
         # first step, decoder_cached_states are empty
         return {
             "input_ids": input_ids,  # encoder_outputs is defined. input_ids not needed
             "attention_mask": attention_mask,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "use_cache": use_cache,
         }
@@ -1436,7 +1436,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
     def prepare_inputs_for_generation(
         self,
         decoder_input_ids,
-        past=None,
+        past_key_values=None,
         attention_mask=None,
         decoder_attention_mask=None,
         head_mask=None,
@@ -1447,21 +1447,21 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
         **kwargs
     ):

-        # cut decoder_input_ids if past is used
-        if past is not None:
+        # cut decoder_input_ids if past_key_values is used
+        if past_key_values is not None:
             decoder_input_ids = decoder_input_ids[:, -1:]

         if decoder_attention_mask is not None:  # xla
             decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
-        elif past is not None:  # no xla + past
-            decoder_position_ids = past[0][0].shape[2]
-        else:  # no xla + no past
+        elif past_key_values is not None:  # no xla + past_key_values
+            decoder_position_ids = past_key_values[0][0].shape[2]
+        else:  # no xla + no past_key_values
             decoder_position_ids = tf.range(decoder_input_ids.shape[1])

         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
             "encoder_outputs": encoder_outputs,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "decoder_input_ids": decoder_input_ids,
             "attention_mask": attention_mask,
             "decoder_attention_mask": decoder_attention_mask,
@@ -1274,20 +1274,22 @@ class BertLMHeadModel(BertPreTrainedModel):
             cross_attentions=outputs.cross_attentions,
         )

-    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=True, **model_kwargs):
+    def prepare_inputs_for_generation(
+        self, input_ids, past_key_values=None, attention_mask=None, use_cache=True, **model_kwargs
+    ):
         input_shape = input_ids.shape
         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_shape)

-        # cut decoder_input_ids if past is used
-        if past is not None:
+        # cut decoder_input_ids if past_key_values is used
+        if past_key_values is not None:
             input_ids = input_ids[:, -1:]

         return {
             "input_ids": input_ids,
             "attention_mask": attention_mask,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "use_cache": use_cache,
         }
@@ -1395,17 +1395,17 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
         warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
         return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name

-    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
         input_shape = input_ids.shape
         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
         if attention_mask is None:
             attention_mask = tf.ones(input_shape)

         # cut decoder_input_ids if past is used
-        if past is not None:
+        if past_key_values is not None:
             input_ids = input_ids[:, -1:]

-        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
+        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}

     @unpack_inputs
     @add_code_sample_docstrings(
@@ -987,20 +987,20 @@ class BertGenerationDecoder(BertGenerationPreTrainedModel):
             cross_attentions=outputs.cross_attentions,
         )

-    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
         input_shape = input_ids.shape
         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_shape)

         # cut decoder_input_ids if past is used
-        if past is not None:
+        if past_key_values is not None:
             input_ids = input_ids[:, -1:]

-        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
+        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}

-    def _reorder_cache(self, past, beam_idx):
+    def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past
@@ -2626,7 +2626,7 @@ class BigBirdForCausalLM(BigBirdPreTrainedModel):
             cross_attentions=outputs.cross_attentions,
         )

-    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
         input_shape = input_ids.shape

         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
@@ -2634,14 +2634,14 @@ class BigBirdForCausalLM(BigBirdPreTrainedModel):
             attention_mask = input_ids.new_ones(input_shape)

         # cut decoder_input_ids if past is used
-        if past is not None:
+        if past_key_values is not None:
             input_ids = input_ids[:, -1:]

-        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
+        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}

-    def _reorder_cache(self, past, beam_idx):
+    def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
-        for layer_past in past:
+        for layer_past in past_key_values:
             reordered_past += (
                 tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
             )
@@ -2617,7 +2617,7 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
     def prepare_inputs_for_generation(
         self,
         decoder_input_ids,
-        past=None,
+        past_key_values=None,
         attention_mask=None,
         decoder_attention_mask=None,
         head_mask=None,
@@ -2627,14 +2627,14 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
         encoder_outputs=None,
         **kwargs
     ):
-        # cut decoder_input_ids if past is used
-        if past is not None:
+        # cut decoder_input_ids if past_key_values is used
+        if past_key_values is not None:
             decoder_input_ids = decoder_input_ids[:, -1:]

         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
             "encoder_outputs": encoder_outputs,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "decoder_input_ids": decoder_input_ids,
             "attention_mask": attention_mask,
             "decoder_attention_mask": decoder_attention_mask,
@@ -3105,18 +3105,20 @@ class BigBirdPegasusForCausalLM(BigBirdPegasusPreTrainedModel):
             cross_attentions=outputs.cross_attentions,
         )

-    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
+    def prepare_inputs_for_generation(
+        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
+    ):
         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_ids.shape)

-        if past:
+        if past_key_values:
             input_ids = input_ids[:, -1:]
         # first step, decoder_cached_states are empty
         return {
             "input_ids": input_ids,  # encoder_outputs is defined. input_ids not needed
             "attention_mask": attention_mask,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "use_cache": use_cache,
         }
@@ -706,16 +706,16 @@ class BioGptForCausalLM(BioGptPreTrainedModel):
             cross_attentions=outputs.cross_attentions,
         )

-    def prepare_inputs_for_generation(self, input_ids, attention_mask, past=None, **kwargs):
+    def prepare_inputs_for_generation(self, input_ids, attention_mask, past_key_values=None, **kwargs):

         # only last token for inputs_ids if past is defined in kwargs
-        if past:
+        if past_key_values:
             input_ids = input_ids[:, -1].unsqueeze(-1)

         return {
             "input_ids": input_ids,
             "attention_mask": attention_mask,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "use_cache": kwargs.get("use_cache"),
         }
@@ -1378,7 +1378,7 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel):
     def prepare_inputs_for_generation(
         self,
         decoder_input_ids,
-        past=None,
+        past_key_values=None,
         attention_mask=None,
         head_mask=None,
         decoder_head_mask=None,
@@ -1388,13 +1388,13 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel):
         **kwargs
     ):
         # cut decoder_input_ids if past is used
-        if past is not None:
+        if past_key_values is not None:
             decoder_input_ids = decoder_input_ids[:, -1:]

         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
             "encoder_outputs": encoder_outputs,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "decoder_input_ids": decoder_input_ids,
             "attention_mask": attention_mask,
             "head_mask": head_mask,
@@ -1609,18 +1609,20 @@ class BlenderbotForCausalLM(BlenderbotPreTrainedModel):
             cross_attentions=outputs.cross_attentions,
         )

-    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
+    def prepare_inputs_for_generation(
+        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
+    ):
         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_ids.shape)

-        if past:
+        if past_key_values:
             input_ids = input_ids[:, -1:]
         # first step, decoder_cached_states are empty
         return {
             "input_ids": input_ids,  # encoder_outputs is defined. input_ids not needed
             "attention_mask": attention_mask,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "use_cache": use_cache,
         }
@@ -1438,7 +1438,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
     def prepare_inputs_for_generation(
         self,
         decoder_input_ids,
-        past=None,
+        past_key_values=None,
         attention_mask=None,
         decoder_attention_mask=None,
         head_mask=None,
@@ -1449,21 +1449,21 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
         **kwargs
     ):

-        # cut decoder_input_ids if past is used
-        if past is not None:
+        # cut decoder_input_ids if past_key_values is used
+        if past_key_values is not None:
             decoder_input_ids = decoder_input_ids[:, -1:]

         if decoder_attention_mask is not None:  # xla
             decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
-        elif past is not None:  # no xla + past
-            decoder_position_ids = past[0][0].shape[2]
-        else:  # no xla + no past
+        elif past_key_values is not None:  # no xla + past_key_values
+            decoder_position_ids = past_key_values[0][0].shape[2]
+        else:  # no xla + no past_key_values
             decoder_position_ids = tf.range(decoder_input_ids.shape[1])

         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
             "encoder_outputs": encoder_outputs,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "decoder_input_ids": decoder_input_ids,
             "attention_mask": attention_mask,
             "decoder_attention_mask": decoder_attention_mask,
@@ -1345,7 +1345,7 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel):
     def prepare_inputs_for_generation(
         self,
         decoder_input_ids,
-        past=None,
+        past_key_values=None,
         attention_mask=None,
         head_mask=None,
         decoder_head_mask=None,
@@ -1355,13 +1355,13 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel):
         **kwargs
     ):
         # cut decoder_input_ids if past is used
-        if past is not None:
+        if past_key_values is not None:
             decoder_input_ids = decoder_input_ids[:, -1:]

         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
             "encoder_outputs": encoder_outputs,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "decoder_input_ids": decoder_input_ids,
             "attention_mask": attention_mask,
             "head_mask": head_mask,
@@ -1576,18 +1576,20 @@ class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel):
             cross_attentions=outputs.cross_attentions,
         )

-    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
+    def prepare_inputs_for_generation(
+        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
+    ):
         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_ids.shape)

-        if past:
+        if past_key_values:
             input_ids = input_ids[:, -1:]
         # first step, decoder_cached_states are empty
         return {
             "input_ids": input_ids,  # encoder_outputs is defined. input_ids not needed
             "attention_mask": attention_mask,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "use_cache": use_cache,
         }
@@ -1418,7 +1418,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
     def prepare_inputs_for_generation(
         self,
         decoder_input_ids,
-        past=None,
+        past_key_values=None,
         attention_mask=None,
         decoder_attention_mask=None,
         head_mask=None,
@@ -1429,21 +1429,21 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
         **kwargs
     ):

-        # cut decoder_input_ids if past is used
-        if past is not None:
+        # cut decoder_input_ids if past_key_values is used
+        if past_key_values is not None:
             decoder_input_ids = decoder_input_ids[:, -1:]

         if decoder_attention_mask is not None:  # xla
             decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
-        elif past is not None:  # no xla + past
-            decoder_position_ids = past[0][0].shape[2]
-        else:  # no xla + no past
+        elif past_key_values is not None:  # no xla + past_key_values
+            decoder_position_ids = past_key_values[0][0].shape[2]
+        else:  # no xla + no past_key_values
             decoder_position_ids = tf.range(decoder_input_ids.shape[1])

         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
             "encoder_outputs": encoder_outputs,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "decoder_input_ids": decoder_input_ids,
             "attention_mask": attention_mask,
             "decoder_attention_mask": decoder_attention_mask,
@@ -917,20 +917,20 @@ class BlipTextLMHeadModel(BlipTextPreTrainedModel):
             cross_attentions=outputs.cross_attentions,
         )

-    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
         input_shape = input_ids.shape
         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_shape)

-        # cut decoder_input_ids if past is used
-        if past is not None:
+        # cut decoder_input_ids if past_key_values is used
+        if past_key_values is not None:
             input_ids = input_ids[:, -1:]

         return {
             "input_ids": input_ids,
             "attention_mask": attention_mask,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
             "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
             "is_decoder": True,
@@ -842,21 +842,21 @@ class BloomForCausalLM(BloomPreTrainedModel):
     def prepare_inputs_for_generation(
         self,
         input_ids: torch.LongTensor,
-        past: Optional[torch.Tensor] = None,
+        past_key_values: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         **kwargs
     ) -> dict:
         # only last token for input_ids if past is not None
-        if past:
+        if past_key_values:
             input_ids = input_ids[:, -1].unsqueeze(-1)

             # the cache may be in the stardard format (e.g. in contrastive search), convert to bloom's format if needed
-            if past[0][0].shape[0] == input_ids.shape[0]:
-                past = self._convert_to_bloom_cache(past)
+            if past_key_values[0][0].shape[0] == input_ids.shape[0]:
+                past_key_values = self._convert_to_bloom_cache(past_key_values)

         return {
             "input_ids": input_ids,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "use_cache": kwargs.get("use_cache"),
             "attention_mask": attention_mask,
         }
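
The Bloom hunk above is one of the few places where the rename touches extra logic: Bloom's native cache fuses batch and heads into the leading axis, so the method inspects the leading dimension to decide whether an incoming standard-format cache (as produced by contrastive search) needs `_convert_to_bloom_cache`. A hedged sketch of just that check, with the shape convention stated as an assumption rather than taken from the diff:

    # Assumption: standard-format keys are [batch, num_heads, seq, head_dim], while
    # Bloom's native keys fuse batch * num_heads into the leading axis.
    def needs_bloom_conversion(past_key_values, batch_size):
        return past_key_values[0][0].shape[0] == batch_size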
@@ -1559,17 +1559,17 @@ class CamembertForCausalLM(CamembertPreTrainedModel):
             cross_attentions=outputs.cross_attentions,
         )

-    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
         input_shape = input_ids.shape
         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_shape)

         # cut decoder_input_ids if past is used
-        if past is not None:
+        if past_key_values is not None:
             input_ids = input_ids[:, -1:]

-        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
+        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}

     def _reorder_cache(self, past, beam_idx):
         reordered_past = ()
@@ -1612,17 +1612,17 @@ class TFCamembertForCausalLM(TFCamembertPreTrainedModel, TFCausalLanguageModelin
         return self.name + "/" + self.lm_head.name

     # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.prepare_inputs_for_generation
-    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
         input_shape = input_ids.shape
         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
         if attention_mask is None:
             attention_mask = tf.ones(input_shape)

         # cut decoder_input_ids if past is used
-        if past is not None:
+        if past_key_values is not None:
             input_ids = input_ids[:, -1:]

-        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
+        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}

     @unpack_inputs
     @add_start_docstrings_to_model_forward(CAMEMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@@ -632,10 +632,10 @@ class CodeGenForCausalLM(CodeGenPreTrainedModel):
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings

-    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
         token_type_ids = kwargs.get("token_type_ids", None)
         # only last token for inputs_ids if past is defined in kwargs
-        if past:
+        if past_key_values:
             input_ids = input_ids[:, -1].unsqueeze(-1)
             if token_type_ids is not None:
                 token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
@@ -647,13 +647,13 @@ class CodeGenForCausalLM(CodeGenPreTrainedModel):
             # create position_ids on the fly for batch generation
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
-            if past:
+            if past_key_values:
                 position_ids = position_ids[:, -1].unsqueeze(-1)
         else:
             position_ids = None
         return {
             "input_ids": input_ids,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "use_cache": kwargs.get("use_cache"),
             "position_ids": position_ids,
             "attention_mask": attention_mask,
@@ -525,12 +525,12 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings

-    def prepare_inputs_for_generation(self, input_ids, past=None, use_cache=None, **kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, use_cache=None, **kwargs):
         # only last token for inputs_ids if past is defined in kwargs
-        if past:
+        if past_key_values:
             input_ids = input_ids[:, -1].unsqueeze(-1)

-        return {"input_ids": input_ids, "past_key_values": past, "use_cache": use_cache}
+        return {"input_ids": input_ids, "past_key_values": past_key_values, "use_cache": use_cache}

     @add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
@@ -641,12 +641,12 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
         warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
         return self.name + "/" + self.lm_head.name

-    def prepare_inputs_for_generation(self, input_ids, past=None, use_cache=None, **kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, use_cache=None, **kwargs):
         # only last token for inputs_ids if past is defined in kwargs
-        if past:
+        if past_key_values:
             input_ids = tf.expand_dims(input_ids[:, -1], -1)

-        return {"input_ids": input_ids, "past_key_values": past, "use_cache": use_cache}
+        return {"input_ids": input_ids, "past_key_values": past_key_values, "use_cache": use_cache}

     @unpack_inputs
     @add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
@@ -1015,17 +1015,17 @@ class Data2VecTextForCausalLM(Data2VecTextPreTrainedModel):
             cross_attentions=outputs.cross_attentions,
         )

-    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
         input_shape = input_ids.shape
         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_shape)

         # cut decoder_input_ids if past is used
-        if past is not None:
+        if past_key_values is not None:
             input_ids = input_ids[:, -1:]

-        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
+        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}

     def _reorder_cache(self, past, beam_idx):
         reordered_past = ()
@@ -1666,17 +1666,17 @@ class ElectraForCausalLM(ElectraPreTrainedModel):
         )

     # Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM.prepare_inputs_for_generation
-    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
         input_shape = input_ids.shape
         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_shape)

         # cut decoder_input_ids if past is used
-        if past is not None:
+        if past_key_values is not None:
             input_ids = input_ids[:, -1:]

-        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
+        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}

     # Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM._reorder_cache
     def _reorder_cache(self, past, beam_idx):
@@ -1214,20 +1214,22 @@ class ErnieForCausalLM(ErniePreTrainedModel):
         )

     # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.prepare_inputs_for_generation
-    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=True, **model_kwargs):
+    def prepare_inputs_for_generation(
+        self, input_ids, past_key_values=None, attention_mask=None, use_cache=True, **model_kwargs
+    ):
         input_shape = input_ids.shape
         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_shape)

-        # cut decoder_input_ids if past is used
-        if past is not None:
+        # cut decoder_input_ids if past_key_values is used
+        if past_key_values is not None:
             input_ids = input_ids[:, -1:]

         return {
             "input_ids": input_ids,
             "attention_mask": attention_mask,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "use_cache": use_cache,
         }
@@ -1268,7 +1268,7 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel):
     def prepare_inputs_for_generation(
         self,
         decoder_input_ids,
-        past=None,
+        past_key_values=None,
         attention_mask=None,
         head_mask=None,
         decoder_head_mask=None,
@@ -1280,7 +1280,7 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel):
         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
             "encoder_outputs": encoder_outputs,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "decoder_input_ids": decoder_input_ids,
             "attention_mask": attention_mask,
             "head_mask": head_mask,
@@ -983,10 +983,10 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings

-    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
         token_type_ids = kwargs.get("token_type_ids", None)
         # only last token for inputs_ids if past is defined in kwargs
-        if past:
+        if past_key_values:
             input_ids = input_ids[:, -1].unsqueeze(-1)
             if token_type_ids is not None:
                 token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
@@ -998,13 +998,13 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
             # create position_ids on the fly for batch generation
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
-            if past:
+            if past_key_values:
                 position_ids = position_ids[:, -1].unsqueeze(-1)
         else:
             position_ids = None
         return {
             "input_ids": input_ids,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "use_cache": kwargs.get("use_cache"),
             "position_ids": position_ids,
             "attention_mask": attention_mask,
@@ -1156,10 +1156,10 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings

-    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
         token_type_ids = kwargs.get("token_type_ids", None)
         # only last token for inputs_ids if past is defined in kwargs
-        if past:
+        if past_key_values:
             input_ids = input_ids[:, -1].unsqueeze(-1)
             if token_type_ids is not None:
                 token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
@@ -1171,14 +1171,14 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
             # create position_ids on the fly for batch generation
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
-            if past:
+            if past_key_values:
                 position_ids = position_ids[:, -1].unsqueeze(-1)
         else:
             position_ids = None

         return {
             "input_ids": input_ids,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "use_cache": kwargs.get("use_cache"),
             "position_ids": position_ids,
             "attention_mask": attention_mask,
@@ -828,10 +828,10 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
     def set_output_embeddings(self, value):
         self.set_input_embeddings(value)

-    def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, **kwargs):
+    def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_cache=None, **kwargs):
         token_type_ids = kwargs.get("token_type_ids", None)
         # only last token for inputs_ids if past is defined in kwargs
-        if past:
+        if past_key_values:
             inputs = tf.expand_dims(inputs[:, -1], -1)
             if token_type_ids is not None:
                 token_type_ids = tf.expand_dims(token_type_ids[:, -1], -1)
@@ -841,14 +841,14 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):

         if attention_mask is not None and position_ids is None:
             position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True)
-            if past:
+            if past_key_values:
                 position_ids = tf.expand_dims(position_ids[:, -1], -1)

         return {
             "input_ids": inputs,
             "attention_mask": attention_mask,
             "position_ids": position_ids,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "use_cache": use_cache,
             "token_type_ids": token_type_ids,
         }
@@ -683,10 +683,10 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings

def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None)
# only last token for inputs_ids if past is defined in kwargs
if past:
if past_key_values:
input_ids = input_ids[:, -1].unsqueeze(-1)
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)

@@ -698,13 +698,13 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel):
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past:
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
else:
position_ids = None
return {
"input_ids": input_ids,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,

@@ -686,7 +686,7 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
attentions=outputs.attentions,
)

def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape

# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly

@@ -694,13 +694,13 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
attention_mask = input_ids.new_ones(input_shape)

# cut decoder_input_ids if past is used
if past and past[0] is not None:
if past_key_values and past_key_values[0] is not None:
input_ids = input_ids[:, -1:]

return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": past or model_kwargs.get("past_key_values"),
"past_key_values": past_key_values,
}

def _reorder_cache(self, past, beam_idx):
@@ -702,7 +702,7 @@ class GPTNeoXJapaneseForCausalLM(GPTNeoXJapanesePreTrainedModel):
attentions=outputs.attentions,
)

def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape

# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly

@@ -710,10 +710,10 @@ class GPTNeoXJapaneseForCausalLM(GPTNeoXJapanesePreTrainedModel):
attention_mask = input_ids.new_ones(input_shape)

# cut decoder_input_ids if past is used
if past and past[0] is not None:
if past_key_values and past_key_values[0] is not None:
input_ids = input_ids[:, -1:]

return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}

def _reorder_cache(self, past, beam_idx):
reordered_past = ()

@@ -760,10 +760,10 @@ class GPTJForCausalLM(GPTJPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings

def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None)
# only last token for inputs_ids if past is defined in kwargs
if past:
if past_key_values:
input_ids = input_ids[:, -1].unsqueeze(-1)
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)

@@ -775,13 +775,13 @@ class GPTJForCausalLM(GPTJPreTrainedModel):
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past:
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
else:
position_ids = None
return {
"input_ids": input_ids,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
@@ -741,10 +741,10 @@ class TFGPTJForCausalLM(TFGPTJPreTrainedModel, TFCausalLanguageModelingLoss):
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings

def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, **kwargs):
def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_cache=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None)
# only last token for inputs_ids if past is defined in kwargs
if past:
if past_key_values:
inputs = tf.expand_dims(inputs[:, -1], -1)
if token_type_ids is not None:
token_type_ids = tf.expand_dims(token_type_ids[:, -1], -1)

@@ -754,14 +754,14 @@ class TFGPTJForCausalLM(TFGPTJPreTrainedModel, TFCausalLanguageModelingLoss):

if attention_mask is not None and position_ids is None:
position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True)
if past:
if past_key_values:
position_ids = tf.expand_dims(position_ids[:, -1], -1)

return {
"input_ids": inputs,
"attention_mask": attention_mask,
"position_ids": position_ids,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": use_cache,
"token_type_ids": token_type_ids,
}

@@ -914,10 +914,10 @@ class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings

def prepare_inputs_for_generation(self, input_ids: torch.Tensor, past: Optional[bool] = None, **kwargs):
def prepare_inputs_for_generation(self, input_ids: torch.Tensor, past_key_values: Optional[bool] = None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None)
# only last token for inputs_ids if past is defined in kwargs
if past:
if past_key_values:
input_ids = input_ids[:, -1].unsqueeze(-1)
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)

@@ -929,13 +929,13 @@ class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel):
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past:
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
else:
position_ids = None
return {
"input_ids": input_ids,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
@@ -2480,7 +2480,7 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past=None,
past_key_values=None,
attention_mask=None,
global_attention_mask=None,
head_mask=None,

@@ -2491,13 +2491,13 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
**kwargs,
):
# cut decoder_input_ids if past is used
if past is not None:
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]

return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"global_attention_mask": global_attention_mask,

@@ -2512,7 +2512,7 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past=None,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,

@@ -2521,13 +2521,13 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
**kwargs
):
# cut decoder_input_ids if past is used
if past is not None:
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]

return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
@@ -2093,7 +2093,7 @@ class LongT5ForConditionalGeneration(LongT5PreTrainedModel):
def prepare_inputs_for_generation(
self,
input_ids,
past=None,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,

@@ -2104,12 +2104,12 @@ class LongT5ForConditionalGeneration(LongT5PreTrainedModel):
):

# cut decoder_input_ids if past is used
if past is not None:
if past_key_values is not None:
input_ids = input_ids[:, -1:]

return {
"decoder_input_ids": input_ids,
"past_key_values": past,
"past_key_values": past_key_values,
"encoder_outputs": encoder_outputs,
"attention_mask": attention_mask,
"head_mask": head_mask,

@@ -1373,7 +1373,7 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past=None,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,

@@ -1383,13 +1383,13 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
**kwargs,
):
# cut decoder_input_ids if past is used
if past is not None:
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]

return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
@@ -1491,7 +1491,7 @@ class MarianMTModel(MarianPreTrainedModel):
def prepare_inputs_for_generation(
self,
decoder_input_ids: torch.LongTensor,
past: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
decoder_head_mask: Optional[torch.Tensor] = None,

@@ -1501,13 +1501,13 @@ class MarianMTModel(MarianPreTrainedModel):
**kwargs,
) -> Dict:
# cut decoder_input_ids if past is used
if past is not None:
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]

return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,

@@ -1727,18 +1727,20 @@ class MarianForCausalLM(MarianPreTrainedModel):
cross_attentions=outputs.cross_attentions,
)

def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
):
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)

if past:
if past_key_values:
input_ids = input_ids[:, -1:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
"attention_mask": attention_mask,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": use_cache,
}

@@ -1455,7 +1455,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past=None,
past_key_values=None,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,

@@ -1466,21 +1466,21 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
**kwargs
):

# cut decoder_input_ids if past is used
if past is not None:
# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]

if decoder_attention_mask is not None: # xla
decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
elif past is not None: # no xla + past
decoder_position_ids = past[0][0].shape[2]
else: # no xla + no past
elif past_key_values is not None: # no xla + past_key_values
decoder_position_ids = past_key_values[0][0].shape[2]
else: # no xla + no past_key_values
decoder_position_ids = tf.range(decoder_input_ids.shape[1])

return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
@@ -937,20 +937,22 @@ class MarkupLMModel(MarkupLMPreTrainedModel):
)

# Copied from transformers.models.bert.modeling_bert.BertModel.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=True, **model_kwargs):
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, use_cache=True, **model_kwargs
):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)

# cut decoder_input_ids if past is used
if past is not None:
# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
input_ids = input_ids[:, -1:]

return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": use_cache,
}

@@ -1404,7 +1404,7 @@ class MBartForConditionalGeneration(MBartPreTrainedModel):
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past=None,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,

@@ -1414,13 +1414,13 @@ class MBartForConditionalGeneration(MBartPreTrainedModel):
**kwargs
):
# cut decoder_input_ids if past is used
if past is not None:
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]

return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,

@@ -1892,18 +1892,20 @@ class MBartForCausalLM(MBartPreTrainedModel):
cross_attentions=outputs.cross_attentions,
)

def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
):
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)

if past:
if past_key_values:
input_ids = input_ids[:, -1:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
"attention_mask": attention_mask,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": use_cache,
}
@@ -1452,7 +1452,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past=None,
past_key_values=None,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,

@@ -1463,21 +1463,21 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
**kwargs
):

# cut decoder_input_ids if past is used
if past is not None:
# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]

if decoder_attention_mask is not None: # xla
decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
elif past is not None: # no xla + past
decoder_position_ids = past[0][0].shape[2]
else: # no xla + no past
elif past_key_values is not None: # no xla + past_key_values
decoder_position_ids = past_key_values[0][0].shape[2]
else: # no xla + no past_key_values
decoder_position_ids = tf.range(decoder_input_ids.shape[1])

return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,

@@ -1247,17 +1247,17 @@ class MegatronBertForCausalLM(MegatronBertPreTrainedModel):
cross_attentions=outputs.cross_attentions,
)

def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)

# cut decoder_input_ids if past is used
if past is not None:
if past_key_values is not None:
input_ids = input_ids[:, -1:]

return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}

def _reorder_cache(self, past, beam_idx):
reordered_past = ()
@@ -1746,7 +1746,7 @@ class MT5ForConditionalGeneration(MT5PreTrainedModel):
def prepare_inputs_for_generation(
self,
input_ids,
past=None,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,

@@ -1757,12 +1757,12 @@ class MT5ForConditionalGeneration(MT5PreTrainedModel):
):

# cut decoder_input_ids if past is used
if past is not None:
if past_key_values is not None:
input_ids = input_ids[:, -1:]

return {
"decoder_input_ids": input_ids,
"past_key_values": past,
"past_key_values": past_key_values,
"encoder_outputs": encoder_outputs,
"attention_mask": attention_mask,
"head_mask": head_mask,

@@ -1555,7 +1555,7 @@ class MvpForConditionalGeneration(MvpPreTrainedModel):
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past=None,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,

@@ -1565,13 +1565,13 @@ class MvpForConditionalGeneration(MvpPreTrainedModel):
**kwargs
):
# cut decoder_input_ids if past is used
if past is not None:
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]

return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,

@@ -2040,18 +2040,20 @@ class MvpForCausalLM(MvpPreTrainedModel):
cross_attentions=outputs.cross_attentions,
)

def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
):
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)

if past:
if past_key_values:
input_ids = input_ids[:, -1:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
"attention_mask": attention_mask,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": use_cache,
}
@@ -966,18 +966,20 @@ class OPTForCausalLM(OPTPreTrainedModel):
attentions=outputs.attentions,
)

def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
):
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)

if past:
if past_key_values:
input_ids = input_ids[:, -1:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
"attention_mask": attention_mask,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": use_cache,
}

@@ -881,17 +881,17 @@ class TFOPTForCausalLM(TFOPTPreTrainedModel, TFCausalLanguageModelingLoss):
def get_output_embeddings(self):
return self.model.get_input_embeddings()

def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, **kwargs):
def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_cache=None, **kwargs):
attention_mask = kwargs.get("attention_mask", None)

# only last token for inputs_ids if past is defined in kwargs
if past:
if past_key_values:
inputs = tf.expand_dims(inputs[:, -1], -1)

return {
"input_ids": inputs,
"attention_mask": attention_mask,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": use_cache,
}
@@ -1452,7 +1452,7 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past=None,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,

@@ -1462,13 +1462,13 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
**kwargs
):
# cut decoder_input_ids if past is used
if past is not None:
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]

return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,

@@ -1706,18 +1706,20 @@ class PegasusForCausalLM(PegasusPreTrainedModel):
cross_attentions=outputs.cross_attentions,
)

def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
):
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)

if past:
if past_key_values:
input_ids = input_ids[:, -1:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
"attention_mask": attention_mask,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": use_cache,
}

@@ -1465,7 +1465,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past=None,
past_key_values=None,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,

@@ -1476,21 +1476,21 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
**kwargs
):

# cut decoder_input_ids if past is used
if past is not None:
# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]

if decoder_attention_mask is not None: # xla
decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
elif past is not None: # no xla + past
decoder_position_ids = past[0][0].shape[2]
else: # no xla + no past
elif past_key_values is not None: # no xla + past_key_values
decoder_position_ids = past_key_values[0][0].shape[2]
else: # no xla + no past_key_values
decoder_position_ids = tf.range(decoder_input_ids.shape[1])

return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
@@ -1659,16 +1659,22 @@ class PegasusXForConditionalGeneration(PegasusXPreTrainedModel):
)

def prepare_inputs_for_generation(
self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
self,
decoder_input_ids,
past_key_values=None,
attention_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs
):
# cut decoder_input_ids if past is used
if past is not None:
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]

return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)

@@ -1375,7 +1375,7 @@ class PLBartForConditionalGeneration(PLBartPreTrainedModel):
def prepare_inputs_for_generation(
self,
decoder_input_ids: torch.LongTensor,
past: Optional[List[torch.FloatTensor]] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
decoder_head_mask: Optional[torch.Tensor] = None,

@@ -1385,13 +1385,13 @@ class PLBartForConditionalGeneration(PLBartPreTrainedModel):
**kwargs # TODO: Check if this is needed. It is unused?
) -> Dict[str, Any]:
# cut decoder_input_ids if past is used
if past is not None:
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]

return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,

@@ -1737,18 +1737,20 @@ class PLBartForCausalLM(PLBartPreTrainedModel):
cross_attentions=outputs.cross_attentions,
)

def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
):
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)

if past:
if past_key_values:
input_ids = input_ids[:, -1:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
"attention_mask": attention_mask,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": use_cache,
}
@@ -2062,7 +2062,7 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past=None,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,

@@ -2073,13 +2073,13 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
):
assert encoder_outputs is not None, "`encoder_outputs` have to be passed for generation."

if past:
if past_key_values:
decoder_input_ids = decoder_input_ids[:, -1:]
# first step, decoder_cached_states are empty
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,

@@ -2316,7 +2316,7 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
def prepare_inputs_for_generation(
self,
input_ids,
past=None,
past_key_values=None,
attention_mask=None,
head_mask=None,
use_cache=None,

@@ -2326,14 +2326,14 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
if attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)

if past:
if past_key_values:
input_ids = input_ids[:, -1:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
"attention_mask": attention_mask,
"head_mask": head_mask,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": use_cache,
}
@@ -1145,7 +1145,7 @@ class QDQBertLMHeadModel(QDQBertPreTrainedModel):
def prepare_inputs_for_generation(
self,
input_ids: Optional[torch.LongTensor],
past=None,
past_key_values=None,
attention_mask: Optional[torch.Tensor] = None,
**model_kwargs
):

@@ -1155,10 +1155,10 @@ class QDQBertLMHeadModel(QDQBertPreTrainedModel):
attention_mask = input_ids.new_ones(input_shape)

# cut decoder_input_ids if past is used
if past is not None:
if past_key_values is not None:
input_ids = input_ids[:, -1:]

return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}

def _reorder_cache(self, past, beam_idx):
reordered_past = ()

@@ -1170,7 +1170,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past=None,
past_key_values=None,
attention_mask=None,
use_cache=None,
encoder_outputs=None,

@@ -1178,7 +1178,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
n_docs=None,
**kwargs
):
if past is not None:
if past_key_values is not None:
# if past is defined use only last decoder_input_ids
decoder_input_ids = decoder_input_ids[:, -1:]

@@ -1188,7 +1188,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
"doc_scores": doc_scores,
"context_attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": use_cache,
"do_marginalize": True,
"n_docs": n_docs,
@@ -764,7 +764,7 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past=None,
past_key_values=None,
attention_mask=None,
use_cache=None,
encoder_outputs=None,

@@ -772,7 +772,7 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
n_docs=None,
**kwargs
):
if past is not None:
if past_key_values is not None:
# if past is defined use only last decoder_input_ids
decoder_input_ids = decoder_input_ids[:, -1:]

@@ -782,7 +782,7 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
"doc_scores": doc_scores,
"context_attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": use_cache,
"do_marginalize": True,
"n_docs": n_docs,

@@ -2286,14 +2286,16 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
attentions=reformer_outputs.attentions,
)

def prepare_inputs_for_generation(self, input_ids, past=None, use_cache=None, num_hashes=None, **kwargs):
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, use_cache=None, num_hashes=None, **kwargs
):
# only last token for inputs_ids if past is defined in kwargs
if past is not None:
if past_key_values is not None:
input_ids = input_ids[:, -1:]

inputs_dict = {
"input_ids": input_ids,
"past_buckets_states": past,
"past_buckets_states": past_key_values,
"use_cache": use_cache,
"num_hashes": num_hashes,
}
@@ -1141,7 +1141,7 @@ class RemBertForCausalLM(RemBertPreTrainedModel):
cross_attentions=outputs.cross_attentions,
)

def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape

# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly

@@ -1149,10 +1149,10 @@ class RemBertForCausalLM(RemBertPreTrainedModel):
attention_mask = input_ids.new_ones(input_shape)

# cut decoder_input_ids if past is used
if past is not None:
if past_key_values is not None:
input_ids = input_ids[:, -1:]

return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}

def _reorder_cache(self, past, beam_idx):
reordered_past = ()

@@ -1131,17 +1131,17 @@ class TFRemBertForCausalLM(TFRemBertPreTrainedModel, TFCausalLanguageModelingLos
return self.mlm.predictions

# Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = tf.ones(input_shape)

# cut decoder_input_ids if past is used
if past is not None:
if past_key_values is not None:
input_ids = input_ids[:, -1:]

return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}

@unpack_inputs
@add_code_sample_docstrings(
@@ -1013,17 +1013,17 @@ class RobertaForCausalLM(RobertaPreTrainedModel):
cross_attentions=outputs.cross_attentions,
)

def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)

# cut decoder_input_ids if past is used
if past is not None:
if past_key_values is not None:
input_ids = input_ids[:, -1:]

return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}

def _reorder_cache(self, past, beam_idx):
reordered_past = ()

@@ -1171,17 +1171,17 @@ class TFRobertaForCausalLM(TFRobertaPreTrainedModel, TFCausalLanguageModelingLos
return self.name + "/" + self.lm_head.name

# Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = tf.ones(input_shape)

# cut decoder_input_ids if past is used
if past is not None:
if past_key_values is not None:
input_ids = input_ids[:, -1:]

return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}

@unpack_inputs
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@@ -1020,17 +1020,17 @@ class RobertaPreLayerNormForCausalLM(RobertaPreLayerNormPreTrainedModel):
cross_attentions=outputs.cross_attentions,
)

def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)

# cut decoder_input_ids if past is used
if past is not None:
if past_key_values is not None:
input_ids = input_ids[:, -1:]

return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}

def _reorder_cache(self, past, beam_idx):
reordered_past = ()

@@ -1186,17 +1186,17 @@ class TFRobertaPreLayerNormForCausalLM(TFRobertaPreLayerNormPreTrainedModel, TFC
return self.name + "/" + self.lm_head.name

# Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = tf.ones(input_shape)

# cut decoder_input_ids if past is used
if past is not None:
if past_key_values is not None:
input_ids = input_ids[:, -1:]

return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}

@unpack_inputs
@add_start_docstrings_to_model_forward(ROBERTA_PRELAYERNORM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@@ -1551,7 +1551,7 @@ class RoCBertForCausalLM(RoCBertPreTrainedModel):
input_ids,
input_shape_ids=None,
input_pronunciation_ids=None,
past=None,
past_key_values=None,
attention_mask=None,
**model_kwargs
):

@@ -1562,7 +1562,7 @@ class RoCBertForCausalLM(RoCBertPreTrainedModel):
attention_mask = input_ids.new_ones(input_shape)

# cut decoder_input_ids if past is used
if past is not None:
if past_key_values is not None:
input_ids = input_ids[:, -1:]
if input_shape_ids is not None:
input_shape_ids = input_shape_ids[:, -1:]

@@ -1574,7 +1574,7 @@ class RoCBertForCausalLM(RoCBertPreTrainedModel):
"input_shape_ids": input_shape_ids,
"input_pronunciation_ids": input_pronunciation_ids,
"attention_mask": attention_mask,
"past_key_values": past,
"past_key_values": past_key_values,
}

# Copied from transformers.models.bert.modeling_bert.BertLMHeadModel._reorder_cache

@@ -1178,7 +1178,7 @@ class RoFormerForCausalLM(RoFormerPreTrainedModel):
cross_attentions=outputs.cross_attentions,
)

def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape

# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly

@@ -1186,10 +1186,10 @@ class RoFormerForCausalLM(RoFormerPreTrainedModel):
attention_mask = input_ids.new_ones(input_shape)

# cut decoder_input_ids if past is used
if past is not None:
if past_key_values is not None:
input_ids = input_ids[:, -1:]

return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}

def _reorder_cache(self, past, beam_idx):
reordered_past = ()
@@ -1395,7 +1395,7 @@ class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel):
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past=None,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,

@@ -1405,12 +1405,12 @@ class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel):
**kwargs
):
# cut decoder_input_ids if past is used
if past is not None:
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]

return {
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,

@@ -1477,7 +1477,7 @@ class TFSpeech2TextForConditionalGeneration(TFSpeech2TextPreTrainedModel, TFCaus
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past=None,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,

@@ -1487,13 +1487,13 @@ class TFSpeech2TextForConditionalGeneration(TFSpeech2TextPreTrainedModel, TFCaus
**kwargs
):
# cut decoder_input_ids if past is used
if past is not None:
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]

return {
"input_features": None, # needs to be passed to make Keras.layer.__call__ happy
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
@@ -951,18 +951,20 @@ class Speech2Text2ForCausalLM(Speech2Text2PreTrainedModel):
cross_attentions=outputs.cross_attentions,
)

def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
):
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)

if past:
if past_key_values:
input_ids = input_ids[:, -1:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
"attention_mask": attention_mask,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": use_cache,
}

@@ -1750,7 +1750,7 @@ class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedMod
def prepare_inputs_for_generation(
self,
input_ids,
past=None,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,

@@ -1761,12 +1761,12 @@ class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedMod
):

# cut decoder_input_ids if past is used
if past is not None:
if past_key_values is not None:
input_ids = input_ids[:, -1:]

return {
"decoder_input_ids": input_ids,
"past_key_values": past,
"past_key_values": past_key_values,
"encoder_outputs": encoder_outputs,
"attention_mask": attention_mask,
"head_mask": head_mask,
@@ -1713,7 +1713,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
def prepare_inputs_for_generation(
self,
input_ids,
past=None,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,

@@ -1724,12 +1724,12 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
):

# cut decoder_input_ids if past is used
if past is not None:
if past_key_values is not None:
input_ids = input_ids[:, -1:]

return {
"decoder_input_ids": input_ids,
"past_key_values": past,
"past_key_values": past_key_values,
"encoder_outputs": encoder_outputs,
"attention_mask": attention_mask,
"head_mask": head_mask,

@@ -1243,7 +1243,7 @@ class TFT5Model(TFT5PreTrainedModel):
past = decoder_outputs[1] if use_cache else None

if not return_dict:
if past is not None:
if past_key_values is not None:
decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
return decoder_outputs + encoder_outputs

@@ -1441,7 +1441,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling

past = decoder_outputs[1] if use_cache else None
if not return_dict:
if past is not None:
if past_key_values is not None:
decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
output = (logits,) + decoder_outputs[1:] + encoder_outputs
return ((loss,) + output) if loss is not None else output

@@ -1499,7 +1499,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
def prepare_inputs_for_generation(
self,
input_ids,
past=None,
past_key_values=None,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,

@@ -1510,13 +1510,13 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
):

# cut decoder_input_ids if past is used
if past is not None:
if past_key_values is not None:
input_ids = input_ids[:, -1:]

return {
"input_ids": None, # needs to be passed to make Keras.layer.__call__ happy
"decoder_input_ids": input_ids,
"past_key_values": past,
"past_key_values": past_key_values,
"encoder_outputs": encoder_outputs,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
@@ -1028,11 +1028,11 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
attentions=attns,
)

def prepare_inputs_for_generation(self, input_ids, past=None, **model_kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **model_kwargs):
inputs = {}

# if past is defined in model kwargs then use it for faster decoding
if past:
if past_key_values:
input_ids = tf.expand_dims(input_ids[:, -1], axis=-1)
else:
input_ids = input_ids

@@ -1150,12 +1150,12 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
else:
return self.crit.out_layers[-1]

def prepare_inputs_for_generation(self, input_ids, past=None, **model_kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **model_kwargs):
inputs = {}

# if past is defined in model kwargs then use it for faster decoding
if past:
inputs["mems"] = past
if past_key_values:
inputs["mems"] = past_key_values
inputs["input_ids"] = input_ids[:, -1].unsqueeze(-1)
else:
inputs["input_ids"] = input_ids
@ -992,18 +992,20 @@ class TrOCRForCausalLM(TrOCRPreTrainedModel):
|
|||
cross_attentions=outputs.cross_attentions,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
|
||||
def prepare_inputs_for_generation(
|
||||
self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
|
||||
):
|
||||
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.new_ones(input_ids.shape)
|
||||
|
||||
if past:
|
||||
if past_key_values:
|
||||
input_ids = input_ids[:, -1:]
|
||||
# first step, decoder_cached_states are empty
|
||||
return {
|
||||
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
|
||||
"attention_mask": attention_mask,
|
||||
"past_key_values": past,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": use_cache,
|
||||
}
|
||||
|
||||
|
|
|
@@ -1358,7 +1358,7 @@ class TFWhisperForConditionalGeneration(TFWhisperPreTrainedModel, TFCausalLangua
     def prepare_inputs_for_generation(
         self,
         decoder_input_ids,
-        past=None,
+        past_key_values=None,
         use_cache=None,
         encoder_outputs=None,
         attention_mask=None,
@@ -1366,13 +1366,13 @@ class TFWhisperForConditionalGeneration(TFWhisperPreTrainedModel, TFCausalLangua
         **kwargs
     ):
         # cut decoder_input_ids if past is used
-        if past is not None:
+        if past_key_values is not None:
             decoder_input_ids = decoder_input_ids[:, -1:]
 
         if decoder_attention_mask is not None:  # xla
             decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
-        elif past is not None:  # no xla + past
-            decoder_position_ids = past[0][0].shape[2]
+        elif past_key_values is not None:  # no xla + past
+            decoder_position_ids = past_key_values[0][0].shape[2]
         else:  # no xla + no past
             decoder_position_ids = tf.range(decoder_input_ids.shape[1])
         decoder_position_ids = tf.broadcast_to(decoder_position_ids, decoder_input_ids.shape)
@@ -1380,7 +1380,7 @@ class TFWhisperForConditionalGeneration(TFWhisperPreTrainedModel, TFCausalLangua
         return {
             "input_features": None,  # Needs to be passed to make Keras.layer.__call__ happy
             "encoder_outputs": encoder_outputs,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "decoder_input_ids": decoder_input_ids,
             "use_cache": use_cache,
             "decoder_attention_mask": decoder_attention_mask,
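The `past_key_values[0][0].shape[2]` lookup above relies on the conventional cache layout: one tuple per decoder layer, each holding key and value tensors of shape (batch, num_heads, seq_len_so_far, head_dim). Axis 2 therefore counts the tokens already decoded and doubles as the position id of the next token. A toy illustration (torch tensors for brevity; the shapes, not the framework, are the point):

import torch

batch, num_heads, seq_so_far, head_dim = 2, 4, 7, 16
layer_cache = (
    torch.zeros(batch, num_heads, seq_so_far, head_dim),  # cached keys
    torch.zeros(batch, num_heads, seq_so_far, head_dim),  # cached values
)
past_key_values = (layer_cache,)  # one entry per decoder layer

next_position = past_key_values[0][0].shape[2]
assert next_position == 7  # 7 tokens cached, so the next token sits at position 7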
@@ -1233,15 +1233,21 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
         )
 
     def prepare_inputs_for_generation(
-        self, decoder_input_ids, past=None, use_cache=None, encoder_outputs=None, attention_mask=None, **kwargs
+        self,
+        decoder_input_ids,
+        past_key_values=None,
+        use_cache=None,
+        encoder_outputs=None,
+        attention_mask=None,
+        **kwargs
     ):
         # cut decoder_input_ids if past is used
-        if past is not None:
+        if past_key_values is not None:
             decoder_input_ids = decoder_input_ids[:, -1:]
 
         return {
             "encoder_outputs": encoder_outputs,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "decoder_input_ids": decoder_input_ids,
             "use_cache": use_cache,
             "decoder_attention_mask": None,
@@ -888,9 +888,9 @@ class TFXGLMForCausalLM(TFXGLMPreTrainedModel, TFCausalLanguageModelingLoss):
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings
 
-    def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, **kwargs):
+    def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_cache=None, **kwargs):
         # only last token for inputs_ids if past is defined in kwargs
-        if past:
+        if past_key_values:
             inputs = tf.expand_dims(inputs[:, -1], -1)
 
         attention_mask = kwargs.get("attention_mask", None)
@@ -898,7 +898,7 @@ class TFXGLMForCausalLM(TFXGLMPreTrainedModel, TFCausalLanguageModelingLoss):
         return {
             "input_ids": inputs,
             "attention_mask": attention_mask,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "use_cache": use_cache,
         }
@@ -929,18 +929,20 @@ class XGLMForCausalLM(XGLMPreTrainedModel):
             cross_attentions=outputs.cross_attentions,
         )
 
-    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
+    def prepare_inputs_for_generation(
+        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
+    ):
         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_ids.shape)
 
-        if past:
+        if past_key_values:
             input_ids = input_ids[:, -1:]
         # first step, decoder_cached_states are empty
         return {
             "input_ids": input_ids,  # encoder_outputs is defined. input_ids not needed
             "attention_mask": attention_mask,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "use_cache": use_cache,
         }
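The `input_ids.new_ones(input_ids.shape)` fallback above builds an all-ones mask on the same device and with the same dtype as the prompt, treating every token as real (no padding) — which is why these decoder-only models do not require an explicit `attention_mask`. A quick demonstration on a toy tensor:

import torch

input_ids = torch.tensor([[11, 42, 7]])
attention_mask = input_ids.new_ones(input_ids.shape)
assert attention_mask.tolist() == [[1, 1, 1]]       # every position attended to
assert attention_mask.dtype == input_ids.dtype      # dtype/device inherited from input_ids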
@@ -2089,7 +2089,7 @@ class XLMProphetNetForConditionalGeneration(XLMProphetNetPreTrainedModel):
     def prepare_inputs_for_generation(
         self,
         decoder_input_ids,
-        past=None,
+        past_key_values=None,
         attention_mask=None,
         head_mask=None,
         decoder_head_mask=None,
@@ -2100,13 +2100,13 @@ class XLMProphetNetForConditionalGeneration(XLMProphetNetPreTrainedModel):
     ):
         assert encoder_outputs is not None, "`encoder_outputs` have to be passed for generation."
 
-        if past:
+        if past_key_values:
             decoder_input_ids = decoder_input_ids[:, -1:]
         # first step, decoder_cached_states are empty
         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
             "encoder_outputs": encoder_outputs,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "decoder_input_ids": decoder_input_ids,
             "attention_mask": attention_mask,
             "head_mask": head_mask,
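The `assert encoder_outputs is not None` above encodes the encoder-decoder contract: the encoder runs once before decoding starts, and its outputs are reused unchanged at every step while only `past_key_values` grows. A hedged sketch of that calling pattern (a free-standing toy function, not the library internals):

import torch

def prepare(decoder_input_ids, past_key_values=None, encoder_outputs=None, **kwargs):
    # same contract as the method above, reduced to the essentials
    assert encoder_outputs is not None, "`encoder_outputs` have to be passed for generation."
    if past_key_values:
        decoder_input_ids = decoder_input_ids[:, -1:]  # cache covers the earlier tokens
    return {"encoder_outputs": encoder_outputs, "decoder_input_ids": decoder_input_ids}

encoder_outputs = torch.zeros(1, 5, 8)  # computed once from the source sequence
first = prepare(torch.tensor([[0]]), encoder_outputs=encoder_outputs)
later = prepare(torch.tensor([[0, 4]]), past_key_values=(("k", "v"),), encoder_outputs=encoder_outputs)
assert first["encoder_outputs"] is later["encoder_outputs"]  # reused, never recomputed
assert later["decoder_input_ids"].shape == (1, 1)            # only the newest token is fed in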
@@ -2346,7 +2346,7 @@ class XLMProphetNetForCausalLM(XLMProphetNetPreTrainedModel):
     def prepare_inputs_for_generation(
         self,
         input_ids,
-        past=None,
+        past_key_values=None,
         attention_mask=None,
         head_mask=None,
         use_cache=None,
@@ -2356,14 +2356,14 @@ class XLMProphetNetForCausalLM(XLMProphetNetPreTrainedModel):
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_ids.shape)
 
-        if past:
+        if past_key_values:
             input_ids = input_ids[:, -1:]
         # first step, decoder_cached_states are empty
         return {
             "input_ids": input_ids,  # encoder_outputs is defined. input_ids not needed
             "attention_mask": attention_mask,
             "head_mask": head_mask,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "use_cache": use_cache,
         }
@@ -1017,17 +1017,17 @@ class XLMRobertaForCausalLM(XLMRobertaPreTrainedModel):
             cross_attentions=outputs.cross_attentions,
         )
 
-    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
         input_shape = input_ids.shape
         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_shape)
 
         # cut decoder_input_ids if past is used
-        if past is not None:
+        if past_key_values is not None:
             input_ids = input_ids[:, -1:]
 
-        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
+        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
 
     def _reorder_cache(self, past, beam_idx):
         reordered_past = ()
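The hunk above stops inside `_reorder_cache`, whose body this commit leaves untouched. For context, the beam-search reordering these decoder models share follows the pattern below — a sketch of the common implementation, not necessarily this model's exact body: every cached tensor is re-indexed along the batch axis so each surviving beam keeps the cache of the hypothesis it was forked from.

import torch

def _reorder_cache(past, beam_idx):
    # select, for every layer's cached tensors, the batch rows that the
    # chosen beams descend from
    reordered_past = ()
    for layer_past in past:
        reordered_past += (
            tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),
        )
    return reordered_past

# toy check: two layers, a batch of 3 beams, beams 2 and 0 survive
past = tuple((torch.arange(3).view(3, 1).float(),) * 2 for _ in range(2))
out = _reorder_cache(past, torch.tensor([2, 0]))
assert out[0][0].squeeze().tolist() == [2.0, 0.0]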
@@ -979,17 +979,17 @@ class XLMRobertaXLForCausalLM(XLMRobertaXLPreTrainedModel):
             cross_attentions=outputs.cross_attentions,
         )
 
-    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
         input_shape = input_ids.shape
         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_shape)
 
         # cut decoder_input_ids if past is used
-        if past is not None:
+        if past_key_values is not None:
             input_ids = input_ids[:, -1:]
 
-        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
+        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
 
     def _reorder_cache(self, past, beam_idx):
         reordered_past = ()
@@ -1217,7 +1217,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
         warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
         return self.name + "/" + self.lm_loss.name
 
-    def prepare_inputs_for_generation(self, inputs, past=None, use_mems=None, **kwargs):
+    def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_mems=None, **kwargs):
         # Add dummy token at the end (no attention on this one)
         effective_batch_size = inputs.shape[0]
         dummy_token = tf.zeros((effective_batch_size, 1), dtype=inputs.dtype)
@@ -1227,7 +1227,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
         # offset = 1; offset = 2 seems to have slightly better computation.
         offset = 2
 
-        if past:
+        if past_key_values:
             input_ids = tf.concat([inputs[:, -offset:], dummy_token], axis=1)
         else:
             input_ids = tf.concat([inputs, dummy_token], axis=1)
@@ -1251,8 +1251,8 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
         }
 
         # if past is defined in model kwargs then use it for faster decoding
-        if past:
-            inputs["mems"] = tuple(layer_past[:-offset, :, :] for layer_past in past)
+        if past_key_values:
+            inputs["mems"] = tuple(layer_past[:-offset, :, :] for layer_past in past_key_values)
 
         return inputs
@@ -1315,7 +1315,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
     def set_output_embeddings(self, new_embeddings):
         self.lm_loss = new_embeddings
 
-    def prepare_inputs_for_generation(self, input_ids, past=None, use_mems=None, **kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, use_mems=None, **kwargs):
         # Add dummy token at the end (no attention on this one)
 
         effective_batch_size = input_ids.shape[0]
@@ -1326,7 +1326,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
         # offset = 1; offset = 2 seems to have slightly better computation.
         offset = 2
 
-        if past:
+        if past_key_values:
             input_ids = torch.cat([input_ids[:, -offset:], dummy_token], dim=1)
         else:
             input_ids = torch.cat([input_ids, dummy_token], dim=1)
@@ -1352,8 +1352,8 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
         }
 
         # if past is defined in model kwargs then use it for faster decoding
-        if past:
-            inputs["mems"] = tuple(layer_past[:-offset, :, :] for layer_past in past)
+        if past_key_values:
+            inputs["mems"] = tuple(layer_past[:-offset, :, :] for layer_past in past_key_values)
 
         return inputs
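XLNet is the odd one out: instead of key/value pairs it caches time-major `mems` of shape (mem_len, batch, d_model), appends a dummy token that receives no attention, and with `offset = 2` re-feeds the last two real tokens while trimming the same amount off the mems so no position is counted twice. A toy walk-through of just that tensor bookkeeping (shapes are illustrative):

import torch

offset = 2
input_ids = torch.tensor([[5, 3, 8, 1]])
dummy_token = torch.zeros((1, 1), dtype=input_ids.dtype)

past = [torch.zeros(6, 1, 8)]  # one layer of mems: (mem_len, batch, d_model)
step_input = torch.cat([input_ids[:, -offset:], dummy_token], dim=1)  # last 2 tokens + dummy
mems = tuple(layer_past[:-offset, :, :] for layer_past in past)       # drop the same 2 positions

assert step_input.shape == (1, offset + 1)
assert mems[0].shape[0] == 6 - offset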
@@ -1121,15 +1121,15 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca
     def get_lm_head(self) -> tf.keras.layers.Layer:
         return self.mlm.predictions
 
-    def prepare_inputs_for_generation(self, inputs, past=None, attention_mask=None, **model_kwargs):
+    def prepare_inputs_for_generation(self, inputs, past_key_values=None, attention_mask=None, **model_kwargs):
         # cut decoder_input_ids if past is used
-        if past:
+        if past_key_values:
             inputs = tf.expand_dims(inputs[:, -1], -1)
 
         return {
             "input_ids": inputs,
             "attention_mask": attention_mask,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "use_cache": model_kwargs["use_cache"],
         }
 
@@ -3003,7 +3003,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
     def prepare_inputs_for_generation(
         self,
         decoder_input_ids,
-        past=None,
+        past_key_values=None,
         attention_mask=None,
         head_mask=None,
         decoder_head_mask=None,
@@ -3013,13 +3013,13 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
         **kwargs
     ):
         # cut decoder_input_ids if past is used
-        if past is not None:
+        if past_key_values is not None:
             decoder_input_ids = decoder_input_ids[:, -1:]
 
         return {
             "input_ids": None,  # needs to be passed to make Keras.layer.__call__ happy
             "encoder_outputs": encoder_outputs,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "decoder_input_ids": decoder_input_ids,
             "attention_mask": attention_mask,
             "head_mask": head_mask,
@@ -1167,7 +1167,7 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
             cross_attentions=outputs.cross_attentions,
         )
 
-    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
         input_shape = input_ids.shape
 
         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
@@ -1175,10 +1175,10 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
             attention_mask = input_ids.new_ones(input_shape)
 
         # cut decoder_input_ids if past is used
-        if past is not None:
+        if past_key_values is not None:
             input_ids = input_ids[:, -1:]
 
-        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
+        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
 
     def _reorder_cache(self, past, beam_idx):
         reordered_past = ()
@@ -2879,7 +2879,7 @@ class {{cookiecutter.camelcase_modelname}}ForConditionalGeneration({{cookiecutte
     def prepare_inputs_for_generation(
         self,
         decoder_input_ids,
-        past=None,
+        past_key_values=None,
         attention_mask=None,
         head_mask=None,
         decoder_head_mask=None,
@@ -2889,13 +2889,13 @@ class {{cookiecutter.camelcase_modelname}}ForConditionalGeneration({{cookiecutte
         **kwargs
    ):
         # cut decoder_input_ids if past is used
-        if past is not None:
+        if past_key_values is not None:
             decoder_input_ids = decoder_input_ids[:, -1:]
 
         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
             "encoder_outputs": encoder_outputs,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "decoder_input_ids": decoder_input_ids,
             "attention_mask": attention_mask,
             "head_mask": head_mask,
@@ -3328,7 +3328,7 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
             cross_attentions=outputs.cross_attentions,
         )
 
-    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs):
         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_ids.shape)
@@ -3339,7 +3339,7 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
         return {
             "input_ids": input_ids,  # encoder_outputs is defined. input_ids not needed
             "attention_mask": attention_mask,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "use_cache": use_cache,
         }