[Generate] Remove attention_mask and integrate model_main_input_name (#14856)

* up

* save

* correct

* up

* correct more

* up

* up

* up

* up

* up

* correct

* fix tf

* fix

* remove tokenizer
This commit is contained in:
Patrick von Platen 2021-12-23 19:43:37 +01:00 committed by GitHub
parent 86b40073e9
commit fe4197ab11
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 117 additions and 155 deletions

View File

@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
@ -349,9 +350,6 @@ BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOu
BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput]
ENCODER_MODEL_INPUT_NAMES = ["input_ids", "inputs_embeds", "input_values", "input_features", "pixel_values"]
class GenerationMixin:
"""
A class containing all of the functions supporting generation, to be used as a mixin in
@ -363,58 +361,69 @@ class GenerationMixin:
inputs: Optional[torch.Tensor] = None,
bos_token_id: Optional[int] = None,
model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[str]]:
) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]:
"""
This function extracts the model-specific `inputs` for generation.
"""
# filter model input names that are `None`
model_kwargs = {k: v for k, v in model_kwargs.items() if k not in ENCODER_MODEL_INPUT_NAMES or v is not None}
# extract keyword arguments that are model input specific
model_input_kwarg_names = set(ENCODER_MODEL_INPUT_NAMES) & set(model_kwargs.keys())
# 1. retrieve all kwargs that are non-None or non-model input related.
# some encoder-decoder models have different names for model and encoder
if (
self.config.is_encoder_decoder
and hasattr(self, "encoder")
and self.encoder.main_input_name != self.main_input_name
):
input_name = self.encoder.main_input_name
else:
input_name = self.main_input_name
# There are 5 possible scenarios
if inputs is not None and len(model_input_kwarg_names) == 0:
# 1. `inputs` are passed and no model-specific keyword inputs
# -> return input
model_input_name = None
return inputs, model_input_name, model_kwargs
elif inputs is not None and len(model_input_kwarg_names) > 0:
# 2. `inputs` are passed as well as model-specific keyword inputs
# -> not allowed, raise Error
model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name}
# 2. check whether model_input_name is passed as kwarg
# if yes and `inputs` is None use kwarg inputs
inputs_kwarg = model_kwargs.pop(input_name, None)
if inputs_kwarg is not None and inputs is not None:
raise ValueError(
f"`inputs`: {inputs}` were passed alongside "
f"{model_input_kwarg_names} which is not allowed."
f"Make sure to not pass any of {model_input_kwarg_names} "
"when `inputs` is defined."
f"{input_name} which is not allowed."
f"Make sure to either pass {inputs} or {input_name}=..."
)
elif inputs is None and len(model_input_kwarg_names) == 0:
# 3. no `inputs` and no model-specific keyword inputs are passed
# -> try to create `input_ids` from BOS
input_tensor = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))
return input_tensor, "input_ids", model_kwargs
elif inputs is None and len(model_input_kwarg_names) == 1:
# 4. no `inputs` are passed and exactly one model-specific keyword input
# -> return that model-specific keyword input tensor
model_input_name = model_input_kwarg_names.pop()
input_tensor = model_kwargs.pop(model_input_name)
elif inputs_kwarg is not None:
inputs = inputs_kwarg
# make sure model is encoder decoder if not `input_ids`
if not self.config.is_encoder_decoder and model_input_name != "input_ids":
raise ValueError(
f"If {model_input_name} is passed as model-specific keyword "
"input then model has to be an encoder-decoder and not a "
f"{self.__class__.__name__}."
)
return input_tensor, model_input_name, model_kwargs
else:
# 5. no `inputs` are passed and multiple model-specific keyword inputs
# -> not allowed, raise Error
# 3. models with `input_ids` can also make use of `inputs_embeds`
if self._can_retrieve_inputs_from_name(inputs, "inputs_embeds", model_kwargs):
inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
# 4. Only encoder-decoder models can have non `input_ids` input format
if not self.config.is_encoder_decoder and input_name != "input_ids":
raise ValueError(
f"Can only pass one of {ENCODER_MODEL_INPUT_NAMES}, "
f"but passed {model_input_kwarg_names}."
f"Make sure to only pass one of {model_input_kwarg_names}."
f"If {input_name} is passed as model-specific keyword "
"input then model has to be an encoder-decoder and not a "
f"{self.__class__.__name__}."
)
# 5. if `inputs` is still None, try to create `input_ids` from BOS token
if inputs is None:
inputs = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))
return inputs, input_name, model_kwargs
def _can_retrieve_inputs_from_name(
self, inputs: Optional[torch.Tensor], name: str, model_kwargs: Dict[str, torch.Tensor]
) -> torch.Tensor:
"""
If `inputs` is None and `name` is in both forward function and keyword
arguments, then inputs can be retrieved from name
"""
can_retrieve_inputs = model_kwargs.get(name, None) is not None and name in set(
inspect.signature(self.forward).parameters.keys()
)
if can_retrieve_inputs and inputs is not None:
raise ValueError(f"Cannot only pass one of {name} and {self.main_input_name}")
return can_retrieve_inputs
def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]:
"""
Implement in subclasses of [`PreTrainedModel`] for custom behavior to prepare inputs in the
@ -461,29 +470,22 @@ class GenerationMixin:
def _prepare_encoder_decoder_kwargs_for_generation(
self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
) -> Dict[str, Any]:
if "encoder_outputs" not in model_kwargs:
# 1. get encoder
encoder = self.get_encoder()
# 2. prepare encoder args and encoder kwargs from model kwargs
encoder_args = (inputs_tensor,)
irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
encoder_kwargs = {
argument: value
for argument, value in model_kwargs.items()
if not any(argument.startswith(p) for p in irrelevant_prefix)
}
# 3. make sure that encoder returns `ModelOutput`
encoder_kwargs["return_dict"] = True
# 1. get encoder
encoder = self.get_encoder()
# 4. if model_input_name is not defined then pass input_tensor as
# first input argument and remove from args
if model_input_name is not None:
# make sure inputs_tensor is None in case model
# accepts multiple model input arguments
encoder_kwargs[model_input_name] = inputs_tensor
encoder_args = ()
# 2. prepare encoder args and encoder kwargs from model kwargs
irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
encoder_kwargs = {
argument: value
for argument, value in model_kwargs.items()
if not any(argument.startswith(p) for p in irrelevant_prefix)
}
model_kwargs["encoder_outputs"]: ModelOutput = encoder(*encoder_args, **encoder_kwargs)
# 3. make sure that encoder returns `ModelOutput`
model_input_name = model_input_name if model_input_name is not None else self.main_input_name
encoder_kwargs["return_dict"] = True
encoder_kwargs[model_input_name] = inputs_tensor
model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs)
return model_kwargs
@ -1013,12 +1015,13 @@ class GenerationMixin:
model_kwargs["output_hidden_states"] = output_hidden_states
model_kwargs["use_cache"] = use_cache
if model_kwargs.get("attention_mask", None) is None:
has_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
if model_kwargs.get("attention_mask", None) is None and has_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
inputs_tensor, pad_token_id, eos_token_id
)
if self.config.is_encoder_decoder:
if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
# if model is encoder decoder encoder_outputs are created
# and added to `model_kwargs`
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(

View File

@ -57,8 +57,6 @@ class KerasMetricCallback(Callback):
Validation data to be used to generate predictions for the `metric_fn`.
metric_fn_kwargs (`dict`, *optional*):
Additional keyword arguments to be passed to the metric_fn.
tokenizer ([`PretrainedTokenizerBase`], *optional*):
Tokenizer used to validate column names to be passed to the generate() function.
output_cols (`List[str], *optional*):
A list of columns to be retained from the model output as the predictions. Defaults to all.
label_cols ('`List[str]`, *optional*'):
@ -75,7 +73,6 @@ class KerasMetricCallback(Callback):
self,
metric_fn: Callable,
eval_dataset: Union[tf.data.Dataset, np.ndarray, tf.Tensor, tuple, dict],
tokenizer: Optional[PreTrainedTokenizerBase] = None,
metric_fn_kwargs: Optional[dict] = None,
output_cols: Optional[List[str]] = None,
label_cols: Optional[List[str]] = None,
@ -97,10 +94,11 @@ class KerasMetricCallback(Callback):
self.predict_with_generate = predict_with_generate
self.output_cols = output_cols
self.metric_fn_kwargs = metric_fn_kwargs or dict()
if tokenizer is not None:
self.model_input_names = tokenizer.model_input_names
if hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name:
self.main_input_name = self.model.encoder.main_input_name
else:
self.model_input_names = ["input_ids"]
self.main_input_name = self.model.main_input_name
# This next block attempts to parse out which elements of the dataset should be appended to the labels list
# that is passed to the metric_fn
@ -161,9 +159,13 @@ class KerasMetricCallback(Callback):
labels = None
if self.predict_with_generate:
if isinstance(batch, dict):
# generate() gets stressed out by any unexpected keys
batch = {key: array for key, array in batch.items() if key in self.model_input_names}
predictions = self.model.generate(batch)
generation_inputs = batch[self.main_input_name]
attention_mask = batch.get("attention_mask", None)
else:
generation_inputs = batch
attention_mask = None
predictions = self.model.generate(generation_inputs, attention_mask=attention_mask)
else:
predictions = self.model.predict(batch)
predictions = dict(predictions)

View File

@ -478,7 +478,6 @@ class DeiTModel(DeiTPreTrainedModel):
def forward(
self,
pixel_values=None,
attention_mask=None,
head_mask=None,
output_attentions=None,
output_hidden_states=None,

View File

@ -69,19 +69,11 @@ SPEECH_ENCODER_DECODER_START_DOCSTRING = r"""
SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
Args:
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file
inputs (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, feature_dim)`, *optional*):
Float values of input raw speech waveform or speech features. Values can be obtained by loading a *.flac* or *.wav* audio file
into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile library (*pip install
soundfile*). To prepare the array into *input_values*, the [`Wav2Vec2Processor`] should
be used for padding and conversion into a tensor of type *torch.FloatTensor*. See
[`Wav2Vec2Processor.__call__`] for details.
input_features (`torch.LongTensor` of shape `(batch_size, sequence_length, feature_size)`, *optional*):
Float values of fbank features extracted from the raw speech waveform. Raw speech waveform can be obtained
by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a
`numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array
into `input_features`, the [`Speech2TextTokenizer`] should be used for extracting
the fbank features, padding and conversion into a tensor of type `torch.FloatTensor`. See
[`~Speech2TextTokenizer.__call__`]
soundfile*). To prepare the array into *inputs*, either the [`Wav2Vec2Processor`] or [`Speech2TextProcessor`] should
be used for padding and conversion into a tensor of type *torch.FloatTensor*.
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
@ -137,6 +129,19 @@ SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file
into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile library (*pip install
soundfile*). To prepare the array into *input_values*, the [`Wav2Vec2Processor`] should
be used for padding and conversion into a tensor of type *torch.FloatTensor*. See
[`Wav2Vec2Processor.__call__`] for details.
input_features (`torch.LongTensor` of shape `(batch_size, sequence_length, feature_size)`, *optional*):
Float values of fbank features extracted from the raw speech waveform. Raw speech waveform can be obtained
by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a
`numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array
into `input_features`, the [`Speech2TextTokenizer`] should be used for extracting
the fbank features, padding and conversion into a tensor of type `torch.FloatTensor`. See
[`~Speech2TextTokenizer.__call__`]
return_dict (`bool`, *optional*):
If set to `True`, the model will return a [`~file_utils.Seq2SeqLMOutput`] instead of a
plain tuple.
@ -176,7 +181,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
"""
config_class = SpeechEncoderDecoderConfig
base_model_prefix = "speech_encoder_decoder"
main_input_name = "input_values"
main_input_name = "inputs"
def __init__(
self,
@ -417,8 +422,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_values=None,
input_features=None,
inputs=None,
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
@ -429,6 +433,8 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
use_cache=None,
output_attentions=None,
output_hidden_states=None,
input_values=None,
input_features=None,
return_dict=None,
**kwargs,
):
@ -463,7 +469,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
}
if encoder_outputs is None:
if encoder_outputs is None and inputs is None:
if input_values is not None and input_features is not None:
raise ValueError("You cannot specify both input_values and input_features at the same time")
elif input_values is not None:

View File

@ -507,7 +507,6 @@ class ViTModel(ViTPreTrainedModel):
def forward(
self,
pixel_values=None,
attention_mask=None,
head_mask=None,
output_attentions=None,
output_hidden_states=None,

View File

@ -161,11 +161,17 @@ class Seq2SeqTrainer(Trainer):
"synced_gpus": True if is_deepspeed_zero3_enabled() else False,
}
model_input_names = self.tokenizer.model_input_names if self.tokenizer is not None else ["input_ids"]
generation_inputs = {k: v for k, v in inputs.items() if k in model_input_names}
# prepare generation inputs
# some encoder-decoder models can have varying encder's and thus
# varying model input names
if hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name:
generation_inputs = inputs[self.model.encoder.main_input_name]
else:
generation_inputs = inputs[self.model.main_input_name]
generated_tokens = self.model.generate(
**generation_inputs,
generation_inputs,
attention_mask=inputs.get("attention_mask", None),
**gen_kwargs,
)
# in case the batch is shorter than max length, the output should be padded

View File

@ -1856,7 +1856,7 @@ class GenerationIntegrationTests(unittest.TestCase):
model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2", max_length=10).to(torch_device)
input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
with self.assertRaises(ValueError):
model.generate(input_ids=input_ids, input_values=input_ids)
model.generate(input_ids=input_ids, inputs_embeds=input_ids)
def test_generate_input_values_as_encoder_kwarg(self):
input_values = floats_tensor((2, 250))

View File

@ -64,14 +64,7 @@ class EncoderDecoderMixin:
pass
def check_encoder_decoder_model_from_pretrained_configs(
self,
config,
attention_mask,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
pixel_values=None,
**kwargs
self, config, decoder_config, decoder_input_ids, decoder_attention_mask, pixel_values=None, **kwargs
):
encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
self.assertTrue(encoder_decoder_config.decoder.is_decoder)
@ -84,7 +77,6 @@ class EncoderDecoderMixin:
outputs_encoder_decoder = enc_dec_model(
pixel_values=pixel_values,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
)
@ -94,14 +86,7 @@ class EncoderDecoderMixin:
)
def check_encoder_decoder_model(
self,
config,
attention_mask,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
pixel_values=None,
**kwargs
self, config, decoder_config, decoder_input_ids, decoder_attention_mask, pixel_values=None, **kwargs
):
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
@ -111,7 +96,6 @@ class EncoderDecoderMixin:
enc_dec_model.to(torch_device)
outputs_encoder_decoder = enc_dec_model(
pixel_values=pixel_values,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
output_hidden_states=True,
@ -122,7 +106,6 @@ class EncoderDecoderMixin:
encoder_outputs = BaseModelOutput(last_hidden_state=outputs_encoder_decoder.encoder_hidden_states[-1])
outputs_encoder_decoder = enc_dec_model(
encoder_outputs=encoder_outputs,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
)
@ -134,7 +117,6 @@ class EncoderDecoderMixin:
def check_encoder_decoder_model_from_pretrained(
self,
config,
attention_mask,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
@ -148,7 +130,6 @@ class EncoderDecoderMixin:
enc_dec_model.to(torch_device)
outputs_encoder_decoder = enc_dec_model(
pixel_values=pixel_values,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
output_hidden_states=True,
@ -160,14 +141,7 @@ class EncoderDecoderMixin:
)
def check_save_and_load(
self,
config,
attention_mask,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
pixel_values=None,
**kwargs
self, config, decoder_config, decoder_input_ids, decoder_attention_mask, pixel_values=None, **kwargs
):
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
@ -176,7 +150,6 @@ class EncoderDecoderMixin:
with torch.no_grad():
outputs = enc_dec_model(
pixel_values=pixel_values,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
)
@ -190,7 +163,6 @@ class EncoderDecoderMixin:
after_outputs = enc_dec_model(
pixel_values=pixel_values,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
)
@ -200,14 +172,7 @@ class EncoderDecoderMixin:
self.assertLessEqual(max_diff, 1e-5)
def check_save_and_load_encoder_decoder_model(
self,
config,
attention_mask,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
pixel_values=None,
**kwargs
self, config, decoder_config, decoder_input_ids, decoder_attention_mask, pixel_values=None, **kwargs
):
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
@ -216,7 +181,6 @@ class EncoderDecoderMixin:
with torch.no_grad():
outputs = enc_dec_model(
pixel_values=pixel_values,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
)
@ -233,7 +197,6 @@ class EncoderDecoderMixin:
after_outputs = enc_dec_model(
pixel_values=pixel_values,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
)
@ -245,7 +208,6 @@ class EncoderDecoderMixin:
def check_encoder_decoder_model_output_attentions(
self,
config,
attention_mask,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
@ -261,7 +223,6 @@ class EncoderDecoderMixin:
enc_dec_model.to(torch_device)
outputs_encoder_decoder = enc_dec_model(
pixel_values=pixel_values,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
output_attentions=True,
@ -382,13 +343,10 @@ class DeiT2RobertaModelTest(EncoderDecoderMixin, unittest.TestCase):
]
)
# for DEiT, the sequence length is equal to the number of patches + 2 (for the [CLS] and distillation tokens)
seq_len = (model.encoder.config.image_size // model.encoder.config.patch_size) ** 2 + 2
attention_mask = random_attention_mask([batch_size, seq_len])
decoder_input_ids = ids_tensor([batch_size, 4], model.decoder.config.vocab_size)
decoder_attention_mask = random_attention_mask([batch_size, 4])
inputs = {
"pixel_values": pixel_values,
"attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
}
@ -398,7 +356,6 @@ class DeiT2RobertaModelTest(EncoderDecoderMixin, unittest.TestCase):
def check_encoder_decoder_model_output_attentions(
self,
config,
attention_mask,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
@ -414,7 +371,6 @@ class DeiT2RobertaModelTest(EncoderDecoderMixin, unittest.TestCase):
enc_dec_model.to(torch_device)
outputs_encoder_decoder = enc_dec_model(
pixel_values=pixel_values,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
output_attentions=True,
@ -463,7 +419,6 @@ class DeiT2RobertaModelTest(EncoderDecoderMixin, unittest.TestCase):
encoder_config_and_inputs = deit_model_tester.prepare_config_and_inputs()
decoder_config_and_inputs = bert_model_tester.prepare_config_and_inputs_for_decoder()
config, pixel_values, _ = encoder_config_and_inputs
input_mask = None # TODO add once attention_mask is supported for vision models
(
decoder_config,
decoder_input_ids,
@ -481,7 +436,6 @@ class DeiT2RobertaModelTest(EncoderDecoderMixin, unittest.TestCase):
return {
"config": config,
"pixel_values": pixel_values,
"attention_mask": input_mask,
"decoder_config": decoder_config,
"decoder_input_ids": decoder_input_ids,
"decoder_token_type_ids": decoder_token_type_ids,
@ -509,13 +463,10 @@ class ViT2BertModelTest(EncoderDecoderMixin, unittest.TestCase):
]
)
# for ViT, the sequence length is equal to the number of patches + 1 (for the [CLS] token)
seq_len = (model.encoder.config.image_size // model.encoder.config.patch_size) ** 2 + 1
attention_mask = random_attention_mask([batch_size, seq_len])
decoder_input_ids = ids_tensor([batch_size, 4], model.decoder.config.vocab_size)
decoder_attention_mask = random_attention_mask([batch_size, 4])
inputs = {
"pixel_values": pixel_values,
"attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
}
@ -534,7 +485,6 @@ class ViT2BertModelTest(EncoderDecoderMixin, unittest.TestCase):
decoder_config_and_inputs = bert_model_tester.prepare_config_and_inputs_for_decoder()
config, pixel_values, _ = encoder_config_and_inputs
input_mask = None # TODO add once attention_mask is supported for vision models
(
decoder_config,
@ -553,7 +503,6 @@ class ViT2BertModelTest(EncoderDecoderMixin, unittest.TestCase):
return {
"config": config,
"pixel_values": pixel_values,
"attention_mask": input_mask,
"decoder_config": decoder_config,
"decoder_input_ids": decoder_input_ids,
"decoder_token_type_ids": decoder_token_type_ids,
@ -580,7 +529,6 @@ class ViT2TrOCR(EncoderDecoderMixin, unittest.TestCase):
encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs()
config, pixel_values, _ = encoder_config_and_inputs
input_mask = None # TODO add once attention_mask is supported for vision models
(decoder_config, decoder_input_ids, decoder_attention_mask, _) = decoder_config_and_inputs
# make sure that cross attention layers are added
@ -590,7 +538,6 @@ class ViT2TrOCR(EncoderDecoderMixin, unittest.TestCase):
return {
"config": config,
"pixel_values": pixel_values,
"attention_mask": input_mask,
"decoder_config": decoder_config,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,