[Generate] Remove attention_mask and integrate model_main_input_name (#14856)
* up
* save
* correct
* up
* correct more
* up
* up
* up
* up
* up
* correct
* fix tf
* fix
* remove tokenizer
parent 86b40073e9, commit fe4197ab11

@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import inspect
 import warnings
 from dataclasses import dataclass
 from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

@@ -349,9 +350,6 @@ BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput]
 BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput]
 
 
-ENCODER_MODEL_INPUT_NAMES = ["input_ids", "inputs_embeds", "input_values", "input_features", "pixel_values"]
-
-
 class GenerationMixin:
     """
     A class containing all of the functions supporting generation, to be used as a mixin in

@@ -363,58 +361,69 @@ class GenerationMixin:
         inputs: Optional[torch.Tensor] = None,
         bos_token_id: Optional[int] = None,
         model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
-    ) -> Tuple[torch.Tensor, Optional[str]]:
+    ) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]:
         """
         This function extracts the model-specific `inputs` for generation.
         """
-        # filter model input names that are `None`
-        model_kwargs = {k: v for k, v in model_kwargs.items() if k not in ENCODER_MODEL_INPUT_NAMES or v is not None}
-        # extract keyword arguments that are model input specific
-        model_input_kwarg_names = set(ENCODER_MODEL_INPUT_NAMES) & set(model_kwargs.keys())
+        # 1. retrieve all kwargs that are non-None or non-model input related.
+        # some encoder-decoder models have different names for model and encoder
+        if (
+            self.config.is_encoder_decoder
+            and hasattr(self, "encoder")
+            and self.encoder.main_input_name != self.main_input_name
+        ):
+            input_name = self.encoder.main_input_name
+        else:
+            input_name = self.main_input_name
 
-        # There are 5 possible scenarios
-        if inputs is not None and len(model_input_kwarg_names) == 0:
-            # 1. `inputs` are passed and no model-specific keyword inputs
-            # -> return input
-            model_input_name = None
-            return inputs, model_input_name, model_kwargs
-        elif inputs is not None and len(model_input_kwarg_names) > 0:
-            # 2. `inputs` are passed as well as model-specific keyword inputs
-            # -> not allowed, raise Error
+        model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name}
+
+        # 2. check whether model_input_name is passed as kwarg
+        # if yes and `inputs` is None use kwarg inputs
+        inputs_kwarg = model_kwargs.pop(input_name, None)
+        if inputs_kwarg is not None and inputs is not None:
             raise ValueError(
                 f"`inputs`: {inputs}` were passed alongside "
-                f"{model_input_kwarg_names} which is not allowed."
-                f"Make sure to not pass any of {model_input_kwarg_names} "
-                "when `inputs` is defined."
+                f"{input_name} which is not allowed."
+                f"Make sure to either pass {inputs} or {input_name}=..."
             )
-        elif inputs is None and len(model_input_kwarg_names) == 0:
-            # 3. no `inputs` and no model-specific keyword inputs are passed
-            # -> try to create `input_ids` from BOS
-            input_tensor = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))
-            return input_tensor, "input_ids", model_kwargs
-        elif inputs is None and len(model_input_kwarg_names) == 1:
-            # 4. no `inputs` are passed and exactly one model-specific keyword input
-            # -> return that model-specific keyword input tensor
-            model_input_name = model_input_kwarg_names.pop()
-            input_tensor = model_kwargs.pop(model_input_name)
+        elif inputs_kwarg is not None:
+            inputs = inputs_kwarg
 
-            # make sure model is encoder decoder if not `input_ids`
-            if not self.config.is_encoder_decoder and model_input_name != "input_ids":
-                raise ValueError(
-                    f"If {model_input_name} is passed as model-specific keyword "
-                    "input then model has to be an encoder-decoder and not a "
-                    f"{self.__class__.__name__}."
-                )
-            return input_tensor, model_input_name, model_kwargs
-        else:
-            # 5. no `inputs` are passed and multiple model-specific keyword inputs
-            # -> not allowed, raise Error
+        # 3. models with `input_ids` can also make use of `inputs_embeds`
+        if self._can_retrieve_inputs_from_name(inputs, "inputs_embeds", model_kwargs):
+            inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
+
+        # 4. Only encoder-decoder models can have non `input_ids` input format
+        if not self.config.is_encoder_decoder and input_name != "input_ids":
             raise ValueError(
-                f"Can only pass one of {ENCODER_MODEL_INPUT_NAMES}, "
-                f"but passed {model_input_kwarg_names}."
-                f"Make sure to only pass one of {model_input_kwarg_names}."
+                f"If {input_name} is passed as model-specific keyword "
+                "input then model has to be an encoder-decoder and not a "
+                f"{self.__class__.__name__}."
             )
 
+        # 5. if `inputs` is still None, try to create `input_ids` from BOS token
+        if inputs is None:
+            inputs = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))
+
+        return inputs, input_name, model_kwargs
+
+    def _can_retrieve_inputs_from_name(
+        self, inputs: Optional[torch.Tensor], name: str, model_kwargs: Dict[str, torch.Tensor]
+    ) -> torch.Tensor:
+        """
+        If `inputs` is None and `name` is in both forward function and keyword
+        arguments, then inputs can be retrieved from name
+        """
+        can_retrieve_inputs = model_kwargs.get(name, None) is not None and name in set(
+            inspect.signature(self.forward).parameters.keys()
+        )
+
+        if can_retrieve_inputs and inputs is not None:
+            raise ValueError(f"Cannot only pass one of {name} and {self.main_input_name}")
+
+        return can_retrieve_inputs
+
     def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]:
         """
         Implement in subclasses of [`PreTrainedModel`] for custom behavior to prepare inputs in the

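The net effect of the rewritten `_prepare_model_inputs` is a single resolution order: explicit `inputs` first, then the model-specific keyword argument, then `inputs_embeds`, and finally the BOS-token fallback. A minimal sketch of the resulting `generate()` behavior, reusing the tiny test checkpoint that this PR's own tests rely on (a decoder-only model, so the main input is `input_ids`):

import torch
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2")
input_ids = torch.tensor([[0, 1, 2]])  # arbitrary ids inside the tiny vocab

out_a = model.generate(input_ids, max_length=5)            # positional `inputs`
out_b = model.generate(input_ids=input_ids, max_length=5)  # model-specific keyword

# Passing both at once is rejected by step 2 above.
try:
    model.generate(input_ids, input_ids=input_ids)
except ValueError as err:
    print(err)

# With no inputs at all, step 5 builds a batch from bos_token_id
# (assuming the checkpoint's config defines one).
out_c = model.generate(max_length=5)
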
@@ -461,29 +470,22 @@ class GenerationMixin:
     def _prepare_encoder_decoder_kwargs_for_generation(
         self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
     ) -> Dict[str, Any]:
-        if "encoder_outputs" not in model_kwargs:
-            # 1. get encoder
-            encoder = self.get_encoder()
-            # 2. prepare encoder args and encoder kwargs from model kwargs
-            encoder_args = (inputs_tensor,)
-            irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
-            encoder_kwargs = {
-                argument: value
-                for argument, value in model_kwargs.items()
-                if not any(argument.startswith(p) for p in irrelevant_prefix)
-            }
-            # 3. make sure that encoder returns `ModelOutput`
-            encoder_kwargs["return_dict"] = True
+        # 1. get encoder
+        encoder = self.get_encoder()
 
-            # 4. if model_input_name is not defined then pass input_tensor as
-            # first input argument and remove from args
-            if model_input_name is not None:
-                # make sure inputs_tensor is None in case model
-                # accepts multiple model input arguments
-                encoder_kwargs[model_input_name] = inputs_tensor
-                encoder_args = ()
+        # 2. prepare encoder args and encoder kwargs from model kwargs
+        irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
+        encoder_kwargs = {
+            argument: value
+            for argument, value in model_kwargs.items()
+            if not any(argument.startswith(p) for p in irrelevant_prefix)
+        }
 
-            model_kwargs["encoder_outputs"]: ModelOutput = encoder(*encoder_args, **encoder_kwargs)
+        # 3. make sure that encoder returns `ModelOutput`
+        model_input_name = model_input_name if model_input_name is not None else self.main_input_name
+        encoder_kwargs["return_dict"] = True
+        encoder_kwargs[model_input_name] = inputs_tensor
+        model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs)
 
         return model_kwargs
 
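The helper now always calls the encoder with keyword arguments only, keyed by the resolved `model_input_name`, which removes the old positional/keyword ambiguity. A standalone sketch of the kwarg-filtering step, with placeholder values instead of real tensors:

# Stand-in for the model_kwargs that reach the encoder preparation step.
model_kwargs = {
    "attention_mask": "<mask tensor>",
    "decoder_input_ids": "<decoder ids>",  # dropped: starts with "decoder_"
    "use_cache": True,                     # dropped: generation-only flag
    "output_attentions": False,            # kept: the encoder understands it
}
irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
encoder_kwargs = {
    argument: value
    for argument, value in model_kwargs.items()
    if not any(argument.startswith(p) for p in irrelevant_prefix)
}
print(encoder_kwargs)  # {'attention_mask': '<mask tensor>', 'output_attentions': False}
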
@@ -1013,12 +1015,13 @@ class GenerationMixin:
         model_kwargs["output_hidden_states"] = output_hidden_states
         model_kwargs["use_cache"] = use_cache
 
-        if model_kwargs.get("attention_mask", None) is None:
+        has_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
+        if model_kwargs.get("attention_mask", None) is None and has_attention_mask:
             model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
                 inputs_tensor, pad_token_id, eos_token_id
             )
 
-        if self.config.is_encoder_decoder:
+        if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
             # if model is encoder decoder encoder_outputs are created
             # and added to `model_kwargs`
             model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(

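This `has_attention_mask` guard is what actually removes the spurious `attention_mask` for vision encoders: a default mask is only built when the model's `forward` accepts one. A self-contained sketch of the signature check on two stand-in classes:

import inspect

class VisionLikeModel:
    # forward takes no attention_mask, like ViTModel/DeiTModel after this PR
    def forward(self, pixel_values=None, head_mask=None):
        ...

class TextLikeModel:
    def forward(self, input_ids=None, attention_mask=None):
        ...

for cls in (VisionLikeModel, TextLikeModel):
    has_attention_mask = "attention_mask" in set(inspect.signature(cls.forward).parameters.keys())
    print(cls.__name__, has_attention_mask)
# VisionLikeModel False -> generate() creates no default mask
# TextLikeModel True    -> default mask from pad/eos token ids, as before
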
@@ -57,8 +57,6 @@ class KerasMetricCallback(Callback):
             Validation data to be used to generate predictions for the `metric_fn`.
         metric_fn_kwargs (`dict`, *optional*):
             Additional keyword arguments to be passed to the metric_fn.
-        tokenizer ([`PretrainedTokenizerBase`], *optional*):
-            Tokenizer used to validate column names to be passed to the generate() function.
         output_cols (`List[str], *optional*):
             A list of columns to be retained from the model output as the predictions. Defaults to all.
         label_cols ('`List[str]`, *optional*'):

@@ -75,7 +73,6 @@ class KerasMetricCallback(Callback):
         self,
         metric_fn: Callable,
         eval_dataset: Union[tf.data.Dataset, np.ndarray, tf.Tensor, tuple, dict],
-        tokenizer: Optional[PreTrainedTokenizerBase] = None,
         metric_fn_kwargs: Optional[dict] = None,
         output_cols: Optional[List[str]] = None,
         label_cols: Optional[List[str]] = None,

@@ -97,10 +94,11 @@ class KerasMetricCallback(Callback):
         self.predict_with_generate = predict_with_generate
         self.output_cols = output_cols
         self.metric_fn_kwargs = metric_fn_kwargs or dict()
-        if tokenizer is not None:
-            self.model_input_names = tokenizer.model_input_names
+
+        if hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name:
+            self.main_input_name = self.model.encoder.main_input_name
         else:
-            self.model_input_names = ["input_ids"]
+            self.main_input_name = self.model.main_input_name
 
         # This next block attempts to parse out which elements of the dataset should be appended to the labels list
         # that is passed to the metric_fn

@@ -161,9 +159,13 @@ class KerasMetricCallback(Callback):
         labels = None
         if self.predict_with_generate:
             if isinstance(batch, dict):
-                # generate() gets stressed out by any unexpected keys
-                batch = {key: array for key, array in batch.items() if key in self.model_input_names}
-            predictions = self.model.generate(batch)
+                generation_inputs = batch[self.main_input_name]
+                attention_mask = batch.get("attention_mask", None)
+            else:
+                generation_inputs = batch
+                attention_mask = None
+
+            predictions = self.model.generate(generation_inputs, attention_mask=attention_mask)
         else:
             predictions = self.model.predict(batch)
         predictions = dict(predictions)

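The callback no longer needs a tokenizer to know which dataset column feeds `generate()`; it applies the same `main_input_name` resolution as `generation_utils.py`. A pure-Python sketch with stand-in objects (no real model is built):

# Stand-ins mimicking a vision encoder-decoder model's attributes.
class Encoder:
    main_input_name = "pixel_values"

class Model:
    main_input_name = "pixel_values"
    encoder = Encoder()

model = Model()
if hasattr(model, "encoder") and model.encoder.main_input_name != model.main_input_name:
    main_input_name = model.encoder.main_input_name
else:
    main_input_name = model.main_input_name

batch = {"pixel_values": "<pixels>", "decoder_input_ids": "<ids>"}
generation_inputs = batch[main_input_name]
attention_mask = batch.get("attention_mask", None)  # None here, and that is fine
print(main_input_name, generation_inputs)  # pixel_values <pixels>
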
@@ -478,7 +478,6 @@ class DeiTModel(DeiTPreTrainedModel):
     def forward(
         self,
         pixel_values=None,
-        attention_mask=None,
         head_mask=None,
         output_attentions=None,
         output_hidden_states=None,

@@ -69,19 +69,11 @@ SPEECH_ENCODER_DECODER_START_DOCSTRING = r"""
 
 SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
     Args:
-        input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file
+        inputs (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, feature_dim)`, *optional*):
+            Float values of input raw speech waveform or speech features. Values can be obtained by loading a *.flac* or *.wav* audio file
             into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile library (*pip install
-            soundfile*). To prepare the array into *input_values*, the [`Wav2Vec2Processor`] should
-            be used for padding and conversion into a tensor of type *torch.FloatTensor*. See
-            [`Wav2Vec2Processor.__call__`] for details.
-        input_features (`torch.LongTensor` of shape `(batch_size, sequence_length, feature_size)`, *optional*):
-            Float values of fbank features extracted from the raw speech waveform. Raw speech waveform can be obtained
-            by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a
-            `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array
-            into `input_features`, the [`Speech2TextTokenizer`] should be used for extracting
-            the fbank features, padding and conversion into a tensor of type `torch.FloatTensor`. See
-            [`~Speech2TextTokenizer.__call__`]
+            soundfile*). To prepare the array into *inputs*, either the [`Wav2Vec2Processor`] or [`Speech2TextProcessor`] should
+            be used for padding and conversion into a tensor of type *torch.FloatTensor*.
         attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
 
@@ -137,6 +129,19 @@ SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
         output_hidden_states (`bool`, *optional*):
             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
             more detail.
+        input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file
+            into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile library (*pip install
+            soundfile*). To prepare the array into *input_values*, the [`Wav2Vec2Processor`] should
+            be used for padding and conversion into a tensor of type *torch.FloatTensor*. See
+            [`Wav2Vec2Processor.__call__`] for details.
+        input_features (`torch.LongTensor` of shape `(batch_size, sequence_length, feature_size)`, *optional*):
+            Float values of fbank features extracted from the raw speech waveform. Raw speech waveform can be obtained
+            by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a
+            `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array
+            into `input_features`, the [`Speech2TextTokenizer`] should be used for extracting
+            the fbank features, padding and conversion into a tensor of type `torch.FloatTensor`. See
+            [`~Speech2TextTokenizer.__call__`]
         return_dict (`bool`, *optional*):
             If set to `True`, the model will return a [`~file_utils.Seq2SeqLMOutput`] instead of a
             plain tuple.

@@ -176,7 +181,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
     """
     config_class = SpeechEncoderDecoderConfig
     base_model_prefix = "speech_encoder_decoder"
-    main_input_name = "input_values"
+    main_input_name = "inputs"
 
     def __init__(
         self,

@@ -417,8 +422,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
     @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_values=None,
-        input_features=None,
+        inputs=None,
         attention_mask=None,
         decoder_input_ids=None,
         decoder_attention_mask=None,

@@ -429,6 +433,8 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
         use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
+        input_values=None,
+        input_features=None,
         return_dict=None,
         **kwargs,
     ):

@@ -463,7 +469,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
             argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
         }
 
-        if encoder_outputs is None:
+        if encoder_outputs is None and inputs is None:
             if input_values is not None and input_features is not None:
                 raise ValueError("You cannot specify both input_values and input_features at the same time")
             elif input_values is not None:

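With `main_input_name = "inputs"`, the wrapper exposes one generic entry point whether its encoder expects `input_values` (Wav2Vec2-style) or `input_features` (Speech2Text-style). A hedged usage sketch; the checkpoint name is an assumption for illustration, and any speech encoder-decoder checkpoint would do:

import torch
from transformers import SpeechEncoderDecoderModel

# Assumed checkpoint, for illustration only.
model = SpeechEncoderDecoderModel.from_pretrained("facebook/s2t-wav2vec2-large-en-de")

waveform = torch.randn(1, 16000)  # one second of dummy 16 kHz audio

# Generic path: `inputs` is handed to the encoder under the encoder's own
# main_input_name ("input_values" for a Wav2Vec2 encoder).
generated_ids = model.generate(waveform)

# The encoder-specific keyword still works; it only conflicts if `inputs`
# is passed at the same time.
generated_ids = model.generate(input_values=waveform)
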
@@ -507,7 +507,6 @@ class ViTModel(ViTPreTrainedModel):
     def forward(
         self,
         pixel_values=None,
-        attention_mask=None,
         head_mask=None,
         output_attentions=None,
         output_hidden_states=None,

@@ -161,11 +161,17 @@ class Seq2SeqTrainer(Trainer):
             "synced_gpus": True if is_deepspeed_zero3_enabled() else False,
         }
 
-        model_input_names = self.tokenizer.model_input_names if self.tokenizer is not None else ["input_ids"]
-        generation_inputs = {k: v for k, v in inputs.items() if k in model_input_names}
+        # prepare generation inputs
+        # some encoder-decoder models can have varying encder's and thus
+        # varying model input names
+        if hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name:
+            generation_inputs = inputs[self.model.encoder.main_input_name]
+        else:
+            generation_inputs = inputs[self.model.main_input_name]
 
         generated_tokens = self.model.generate(
-            **generation_inputs,
+            generation_inputs,
+            attention_mask=inputs.get("attention_mask", None),
             **gen_kwargs,
         )
         # in case the batch is shorter than max length, the output should be padded

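Instead of filtering columns through `tokenizer.model_input_names`, the trainer now passes the main input positionally and forwards `attention_mask` explicitly, so stray columns such as `labels` can no longer leak into `generate()`. A small sketch of the new call shape with placeholder tensors:

import torch

inputs = {
    "input_ids": torch.tensor([[101, 2023, 102]]),
    "attention_mask": torch.tensor([[1, 1, 1]]),
    "labels": torch.tensor([[2023]]),  # previously had to be filtered out by name
}

main_input_name = "input_ids"  # self.model.main_input_name for text seq2seq models

generation_inputs = inputs[main_input_name]
attention_mask = inputs.get("attention_mask", None)
print(generation_inputs.shape, attention_mask.shape)
# then: generate(generation_inputs, attention_mask=attention_mask, **gen_kwargs)
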
@@ -1856,7 +1856,7 @@ class GenerationIntegrationTests(unittest.TestCase):
         model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2", max_length=10).to(torch_device)
         input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
         with self.assertRaises(ValueError):
-            model.generate(input_ids=input_ids, input_values=input_ids)
+            model.generate(input_ids=input_ids, inputs_embeds=input_ids)
 
     def test_generate_input_values_as_encoder_kwarg(self):
         input_values = floats_tensor((2, 250))

@@ -64,14 +64,7 @@ class EncoderDecoderMixin:
         pass
 
     def check_encoder_decoder_model_from_pretrained_configs(
-        self,
-        config,
-        attention_mask,
-        decoder_config,
-        decoder_input_ids,
-        decoder_attention_mask,
-        pixel_values=None,
-        **kwargs
+        self, config, decoder_config, decoder_input_ids, decoder_attention_mask, pixel_values=None, **kwargs
     ):
         encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
         self.assertTrue(encoder_decoder_config.decoder.is_decoder)

@@ -84,7 +77,6 @@ class EncoderDecoderMixin:
 
         outputs_encoder_decoder = enc_dec_model(
             pixel_values=pixel_values,
-            attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
         )

@@ -94,14 +86,7 @@ class EncoderDecoderMixin:
         )
 
     def check_encoder_decoder_model(
-        self,
-        config,
-        attention_mask,
-        decoder_config,
-        decoder_input_ids,
-        decoder_attention_mask,
-        pixel_values=None,
-        **kwargs
+        self, config, decoder_config, decoder_input_ids, decoder_attention_mask, pixel_values=None, **kwargs
     ):
         encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
         enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)

@@ -111,7 +96,6 @@ class EncoderDecoderMixin:
         enc_dec_model.to(torch_device)
         outputs_encoder_decoder = enc_dec_model(
             pixel_values=pixel_values,
-            attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
             output_hidden_states=True,

@@ -122,7 +106,6 @@ class EncoderDecoderMixin:
         encoder_outputs = BaseModelOutput(last_hidden_state=outputs_encoder_decoder.encoder_hidden_states[-1])
         outputs_encoder_decoder = enc_dec_model(
             encoder_outputs=encoder_outputs,
-            attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
         )

@@ -134,7 +117,6 @@ class EncoderDecoderMixin:
     def check_encoder_decoder_model_from_pretrained(
         self,
         config,
-        attention_mask,
         decoder_config,
         decoder_input_ids,
         decoder_attention_mask,

@@ -148,7 +130,6 @@ class EncoderDecoderMixin:
         enc_dec_model.to(torch_device)
         outputs_encoder_decoder = enc_dec_model(
             pixel_values=pixel_values,
-            attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
             output_hidden_states=True,

@@ -160,14 +141,7 @@ class EncoderDecoderMixin:
         )
 
     def check_save_and_load(
-        self,
-        config,
-        attention_mask,
-        decoder_config,
-        decoder_input_ids,
-        decoder_attention_mask,
-        pixel_values=None,
-        **kwargs
+        self, config, decoder_config, decoder_input_ids, decoder_attention_mask, pixel_values=None, **kwargs
     ):
         encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
         enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)

@@ -176,7 +150,6 @@ class EncoderDecoderMixin:
         with torch.no_grad():
             outputs = enc_dec_model(
                 pixel_values=pixel_values,
-                attention_mask=attention_mask,
                 decoder_input_ids=decoder_input_ids,
                 decoder_attention_mask=decoder_attention_mask,
             )

@@ -190,7 +163,6 @@ class EncoderDecoderMixin:
 
         after_outputs = enc_dec_model(
             pixel_values=pixel_values,
-            attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
         )

@@ -200,14 +172,7 @@ class EncoderDecoderMixin:
         self.assertLessEqual(max_diff, 1e-5)
 
     def check_save_and_load_encoder_decoder_model(
-        self,
-        config,
-        attention_mask,
-        decoder_config,
-        decoder_input_ids,
-        decoder_attention_mask,
-        pixel_values=None,
-        **kwargs
+        self, config, decoder_config, decoder_input_ids, decoder_attention_mask, pixel_values=None, **kwargs
    ):
         encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
         enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)

@@ -216,7 +181,6 @@ class EncoderDecoderMixin:
         with torch.no_grad():
             outputs = enc_dec_model(
                 pixel_values=pixel_values,
-                attention_mask=attention_mask,
                 decoder_input_ids=decoder_input_ids,
                 decoder_attention_mask=decoder_attention_mask,
             )

@@ -233,7 +197,6 @@ class EncoderDecoderMixin:
 
         after_outputs = enc_dec_model(
             pixel_values=pixel_values,
-            attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
         )

@@ -245,7 +208,6 @@ class EncoderDecoderMixin:
     def check_encoder_decoder_model_output_attentions(
         self,
         config,
-        attention_mask,
         decoder_config,
         decoder_input_ids,
         decoder_attention_mask,

@@ -261,7 +223,6 @@ class EncoderDecoderMixin:
         enc_dec_model.to(torch_device)
         outputs_encoder_decoder = enc_dec_model(
             pixel_values=pixel_values,
-            attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
             output_attentions=True,

@@ -382,13 +343,10 @@ class DeiT2RobertaModelTest(EncoderDecoderMixin, unittest.TestCase):
             ]
         )
         # for DEiT, the sequence length is equal to the number of patches + 2 (for the [CLS] and distillation tokens)
-        seq_len = (model.encoder.config.image_size // model.encoder.config.patch_size) ** 2 + 2
-        attention_mask = random_attention_mask([batch_size, seq_len])
         decoder_input_ids = ids_tensor([batch_size, 4], model.decoder.config.vocab_size)
         decoder_attention_mask = random_attention_mask([batch_size, 4])
         inputs = {
             "pixel_values": pixel_values,
-            "attention_mask": attention_mask,
             "decoder_input_ids": decoder_input_ids,
             "decoder_attention_mask": decoder_attention_mask,
         }

@@ -398,7 +356,6 @@ class DeiT2RobertaModelTest(EncoderDecoderMixin, unittest.TestCase):
     def check_encoder_decoder_model_output_attentions(
         self,
         config,
-        attention_mask,
         decoder_config,
         decoder_input_ids,
         decoder_attention_mask,

@@ -414,7 +371,6 @@ class DeiT2RobertaModelTest(EncoderDecoderMixin, unittest.TestCase):
         enc_dec_model.to(torch_device)
         outputs_encoder_decoder = enc_dec_model(
             pixel_values=pixel_values,
-            attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
             output_attentions=True,

@@ -463,7 +419,6 @@ class DeiT2RobertaModelTest(EncoderDecoderMixin, unittest.TestCase):
         encoder_config_and_inputs = deit_model_tester.prepare_config_and_inputs()
         decoder_config_and_inputs = bert_model_tester.prepare_config_and_inputs_for_decoder()
         config, pixel_values, _ = encoder_config_and_inputs
-        input_mask = None  # TODO add once attention_mask is supported for vision models
         (
             decoder_config,
             decoder_input_ids,

@@ -481,7 +436,6 @@ class DeiT2RobertaModelTest(EncoderDecoderMixin, unittest.TestCase):
         return {
             "config": config,
             "pixel_values": pixel_values,
-            "attention_mask": input_mask,
             "decoder_config": decoder_config,
             "decoder_input_ids": decoder_input_ids,
             "decoder_token_type_ids": decoder_token_type_ids,

@@ -509,13 +463,10 @@ class ViT2BertModelTest(EncoderDecoderMixin, unittest.TestCase):
             ]
         )
         # for ViT, the sequence length is equal to the number of patches + 1 (for the [CLS] token)
-        seq_len = (model.encoder.config.image_size // model.encoder.config.patch_size) ** 2 + 1
-        attention_mask = random_attention_mask([batch_size, seq_len])
         decoder_input_ids = ids_tensor([batch_size, 4], model.decoder.config.vocab_size)
         decoder_attention_mask = random_attention_mask([batch_size, 4])
         inputs = {
             "pixel_values": pixel_values,
-            "attention_mask": attention_mask,
             "decoder_input_ids": decoder_input_ids,
             "decoder_attention_mask": decoder_attention_mask,
         }

@@ -534,7 +485,6 @@ class ViT2BertModelTest(EncoderDecoderMixin, unittest.TestCase):
         decoder_config_and_inputs = bert_model_tester.prepare_config_and_inputs_for_decoder()
 
         config, pixel_values, _ = encoder_config_and_inputs
-        input_mask = None  # TODO add once attention_mask is supported for vision models
 
         (
             decoder_config,

@@ -553,7 +503,6 @@ class ViT2BertModelTest(EncoderDecoderMixin, unittest.TestCase):
         return {
             "config": config,
             "pixel_values": pixel_values,
-            "attention_mask": input_mask,
             "decoder_config": decoder_config,
             "decoder_input_ids": decoder_input_ids,
             "decoder_token_type_ids": decoder_token_type_ids,

@@ -580,7 +529,6 @@ class ViT2TrOCR(EncoderDecoderMixin, unittest.TestCase):
         encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
         decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs()
         config, pixel_values, _ = encoder_config_and_inputs
-        input_mask = None  # TODO add once attention_mask is supported for vision models
         (decoder_config, decoder_input_ids, decoder_attention_mask, _) = decoder_config_and_inputs
 
         # make sure that cross attention layers are added

@@ -590,7 +538,6 @@ class ViT2TrOCR(EncoderDecoderMixin, unittest.TestCase):
         return {
             "config": config,
             "pixel_values": pixel_values,
-            "attention_mask": input_mask,
             "decoder_config": decoder_config,
             "decoder_input_ids": decoder_input_ids,
             "decoder_attention_mask": decoder_attention_mask,