[T5, TF 2.2] change tf t5 argument naming (#3547)

* change tf t5 argument naming for TF 2.2

* correct bug in testing
Authored by Patrick von Platen on 2020-04-01 22:04:20 +02:00; committed by GitHub
parent 06dd597552
commit a4ee4da18a
3 changed files with 36 additions and 25 deletions


@ -592,8 +592,8 @@ class TFT5PreTrainedModel(TFPreTrainedModel):
         input_ids = tf.constant(DUMMY_INPUTS)
         input_mask = tf.constant(DUMMY_MASK)
         dummy_inputs = {
+            "inputs": input_ids,
             "decoder_input_ids": input_ids,
-            "input_ids": input_ids,
             "decoder_attention_mask": input_mask,
         }
         return dummy_inputs
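Under TF 2.2, Keras matches the keys of an input dict against the name of the first positional argument of call(), so the encoder ids in dummy_inputs move from "input_ids" to "inputs". A minimal sketch of how such a dict drives the first, graph-building call; the model is randomly initialised from a default config and the token values are placeholders, not the actual DUMMY_INPUTS/DUMMY_MASK constants:

# Sketch only: key names ("inputs", "decoder_input_ids", "decoder_attention_mask")
# follow this commit; assumes transformers with TF support installed.
import tensorflow as tf
from transformers import T5Config, TFT5Model

model = TFT5Model(T5Config())

dummy = {
    "inputs": tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]),
    "decoder_input_ids": tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]),
    "decoder_attention_mask": tf.constant([[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]),
}

outputs = model(dummy)  # the first call builds the variables, which weight loading relies on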
@ -637,11 +637,9 @@ T5_START_DOCSTRING = r""" The T5 model was proposed in
 T5_INPUTS_DOCSTRING = r"""
     Args:
-        decoder_input_ids are usually used as a `dict` (see T5 description above for more information) containing all the following.
-        decoder_input_ids (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
-            Provide for sequence to sequence training. T5 uses the pad_token_id as the starting token for decoder_input_ids generation.
+        inputs are usually used as a `dict` (see T5 description above for more information) containing all the following.
-        input_ids (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
+        inputs (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
             Indices of input sequence tokens in the vocabulary.
             T5 is a model with relative position embeddings so you should be able to pad the inputs on
             the right or the left.
@ -650,6 +648,8 @@ T5_INPUTS_DOCSTRING = r"""
             `T5 Training <./t5.html#training>`_ .
             See :func:`transformers.PreTrainedTokenizer.encode` and
             :func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
+        decoder_input_ids (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
+            Provide for sequence to sequence training. T5 uses the pad_token_id as the starting token for decoder_input_ids generation.
         attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
             Mask to avoid performing attention on padding token indices.
             Mask values selected in ``[0, 1]``:
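For reference, a usage sketch matching the updated docstring. The argument names below are the ones introduced by this commit, and the example assumes the t5-small checkpoint can be downloaded:

# Sketch: encoder token ids go in as `inputs`, decoder ids as `decoder_input_ids`.
from transformers import T5Tokenizer, TFT5Model

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = TFT5Model.from_pretrained("t5-small")

encoder_ids = tokenizer.encode("translate English to German: Hello, my dog is cute", return_tensors="tf")
decoder_ids = tokenizer.encode("Hallo, mein Hund ist nett", return_tensors="tf")

# returns a tuple of hidden states as described in the Return section of the docstring
outputs = model(encoder_ids, decoder_input_ids=decoder_ids)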
@ -706,7 +706,7 @@ class TFT5Model(TFT5PreTrainedModel):
         return self.shared
     @add_start_docstrings_to_callable(T5_INPUTS_DOCSTRING)
-    def call(self, decoder_input_ids, **kwargs):
+    def call(self, inputs, **kwargs):
         r"""
     Return:
         :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.T5Config`) and inputs.
@ -736,13 +736,13 @@ class TFT5Model(TFT5PreTrainedModel):
"""
if isinstance(decoder_input_ids, dict):
kwargs.update(decoder_input_ids)
if isinstance(inputs, dict):
kwargs.update(inputs)
else:
kwargs["decoder_input_ids"] = decoder_input_ids
kwargs["inputs"] = inputs
# retrieve arguments
input_ids = kwargs.get("input_ids", None)
input_ids = kwargs.get("inputs", None)
decoder_input_ids = kwargs.get("decoder_input_ids", None)
attention_mask = kwargs.get("attention_mask", None)
encoder_outputs = kwargs.get("encoder_outputs", None)
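The isinstance branch keeps both calling conventions working. A self-contained sketch with a randomly initialised model (argument names per this commit):

# Both calls are equivalent: a single dict as the first positional argument,
# or the encoder ids positionally plus everything else as keyword arguments.
import tensorflow as tf
from transformers import T5Config, TFT5Model

model = TFT5Model(T5Config())
encoder_ids = tf.constant([[13, 7, 14, 1]])
decoder_ids = tf.constant([[0, 13, 7, 14]])

out_from_dict = model({"inputs": encoder_ids, "decoder_input_ids": decoder_ids})
out_from_kwargs = model(encoder_ids, decoder_input_ids=decoder_ids)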
@ -803,7 +803,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel):
         return self.encoder
     @add_start_docstrings_to_callable(T5_INPUTS_DOCSTRING)
-    def call(self, decoder_input_ids, **kwargs):
+    def call(self, inputs, **kwargs):
         r"""
     Return:
         :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.T5Config`) and inputs.
@ -839,13 +839,13 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel):
"""
if isinstance(decoder_input_ids, dict):
kwargs.update(decoder_input_ids)
if isinstance(inputs, dict):
kwargs.update(inputs)
else:
kwargs["decoder_input_ids"] = decoder_input_ids
kwargs["inputs"] = inputs
# retrieve arguments
input_ids = kwargs.get("input_ids", None)
input_ids = kwargs.get("inputs", None)
decoder_input_ids = kwargs.get("decoder_input_ids", None)
attention_mask = kwargs.get("attention_mask", None)
encoder_outputs = kwargs.get("encoder_outputs", None)
@ -890,7 +890,8 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel):
             encoder_outputs = (past,)
         return {
-            "inputs": input_ids,
+            "inputs": None,  # inputs don't have to be defined, but still need to be passed to make Keras.layer.__call__ happy
+            "decoder_input_ids": input_ids,  # input_ids are the decoder_input_ids
             "encoder_outputs": encoder_outputs,
             "attention_mask": attention_mask,
         }
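During generation the encoder only runs once; its output travels in encoder_outputs while the decoder ids grow step by step, which is why "inputs" can be None here. A sketch of the path that consumes this dict, assuming the t5-small checkpoint is available and that generate() forwards the returned dict back into call() as in this version of the library:

# Sketch: summarisation-style generation with the renamed TF arguments.
from transformers import T5Tokenizer, TFT5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")

input_ids = tokenizer.encode("summarize: studies have shown that owning a dog is good for you", return_tensors="tf")
# each decoding step re-enters call() through a dict like
# {"inputs": None, "decoder_input_ids": <ids generated so far>, "encoder_outputs": ..., "attention_mask": ...}
generated = model.generate(input_ids)
print(tokenizer.decode(generated[0].numpy().tolist()))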


@ -162,6 +162,10 @@ class TFModelTesterMixin:
             pt_inputs_dict = dict(
                 (name, torch.from_numpy(key.numpy()).to(torch.long)) for name, key in inputs_dict.items()
             )
+            # need to rename encoder-decoder "inputs" for PyTorch
+            if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
+                pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")
             with torch.no_grad():
                 pto = pt_model(**pt_inputs_dict)
             tfo = tf_model(inputs_dict, training=False)
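The PyTorch T5 implementation keeps input_ids as the name of the encoder tokens, so the common test renames the key before feeding the same tensors to both frameworks. Roughly, the pattern looks like this (sketch with randomly initialised models; the real test also copies the weights across frameworks before comparing outputs):

# Sketch of the cross-framework input handling (needs both torch and TF installed).
import tensorflow as tf
import torch
from transformers import T5Config, T5Model, TFT5Model

config = T5Config()
tf_model = TFT5Model(config)
pt_model = T5Model(config)

inputs_dict = {
    "inputs": tf.constant([[13, 7, 14, 1]]),
    "decoder_input_ids": tf.constant([[0, 13, 7, 14]]),
}

pt_inputs_dict = {name: torch.from_numpy(key.numpy()).to(torch.long) for name, key in inputs_dict.items()}
# the TF-only "inputs" key becomes "input_ids" on the PyTorch side
pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")

with torch.no_grad():
    pto = pt_model(**pt_inputs_dict)
tfo = tf_model(inputs_dict, training=False)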
@ -201,6 +205,10 @@ class TFModelTesterMixin:
             pt_inputs_dict = dict(
                 (name, torch.from_numpy(key.numpy()).to(torch.long)) for name, key in inputs_dict.items()
             )
+            # need to rename encoder-decoder "inputs" for PyTorch
+            if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
+                pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")
             with torch.no_grad():
                 pto = pt_model(**pt_inputs_dict)
             tfo = tf_model(inputs_dict)
@ -223,7 +231,7 @@ class TFModelTesterMixin:
         if self.is_encoder_decoder:
             input_ids = {
                 "decoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"),
-                "input_ids": tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32"),
+                "inputs": tf.keras.Input(batch_shape=(2, 2000), name="inputs", dtype="int32"),
             }
         else:
             input_ids = tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32")
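The compile test wraps the model in a functional tf.keras graph, so the symbolic Input names have to line up with the argument names call() expects. A compressed sketch of that wrapping, using a shorter sequence length than the test's 2000 and a randomly initialised T5:

# Sketch: building a compiled Keras model around TFT5Model with the new names.
import tensorflow as tf
from transformers import T5Config, TFT5Model

symbolic_inputs = {
    "inputs": tf.keras.Input(batch_shape=(2, 128), name="inputs", dtype="int32"),
    "decoder_input_ids": tf.keras.Input(batch_shape=(2, 128), name="decoder_input_ids", dtype="int32"),
}

t5 = TFT5Model(T5Config())
hidden_states = t5(symbolic_inputs)[0]  # symbolic call traces the layer into the graph

keras_model = tf.keras.Model(inputs=symbolic_inputs, outputs=hidden_states)
keras_model.compile(optimizer="adam", loss="mse")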
@ -259,7 +267,7 @@ class TFModelTesterMixin:
             outputs_dict = model(inputs_dict)
             inputs_keywords = copy.deepcopy(inputs_dict)
-            input_ids = inputs_keywords.pop("input_ids" if not self.is_encoder_decoder else "decoder_input_ids", None,)
+            input_ids = inputs_keywords.pop("input_ids" if not self.is_encoder_decoder else "inputs", None,)
             outputs_keywords = model(input_ids, **inputs_keywords)
             output_dict = outputs_dict[0].numpy()
@ -395,9 +403,9 @@ class TFModelTesterMixin:
             input_ids = inputs_dict["input_ids"]
             del inputs_dict["input_ids"]
         else:
-            encoder_input_ids = inputs_dict["input_ids"]
+            encoder_input_ids = inputs_dict["inputs"]
             decoder_input_ids = inputs_dict["decoder_input_ids"]
-            del inputs_dict["input_ids"]
+            del inputs_dict["inputs"]
             del inputs_dict["decoder_input_ids"]
         for model_class in self.all_model_classes:
@ -415,7 +423,7 @@ class TFModelTesterMixin:
     def test_lm_head_model_random_generate(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        input_ids = inputs_dict["input_ids"]
+        input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]
         if self.is_encoder_decoder:
             config.output_past = True  # needed for Bart TODO: might have to update for other encoder-decoder models
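Because encoder-decoder TF models now key their encoder ids under "inputs" while everything else keeps "input_ids", the generate test grabs whichever key is present. A plain-Python sketch of that lookup; the helper name is made up for illustration:

# Sketch: pick the encoder token ids regardless of which naming a model uses.
def get_encoder_ids(inputs_dict):
    # encoder-only TF models use "input_ids"; TF T5 (after this commit) uses "inputs"
    return inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]

assert get_encoder_ids({"input_ids": [1, 2, 3]}) == [1, 2, 3]
assert get_encoder_ids({"inputs": [4, 5], "decoder_input_ids": [0, 4]}) == [4, 5]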


@ -107,13 +107,15 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
     def create_and_check_t5_model(self, config, input_ids, input_mask, token_labels):
         model = TFT5Model(config=config)
         inputs = {
-            "input_ids": input_ids,
+            "inputs": input_ids,
             "decoder_input_ids": input_ids,
             "decoder_attention_mask": input_mask,
         }
         encoder_output, decoder_output = model(inputs)
-        encoder_output, decoder_output = model(input_ids, decoder_attention_mask=input_mask, input_ids=input_ids)
+        encoder_output, decoder_output = model(
+            input_ids, decoder_attention_mask=input_mask, decoder_input_ids=input_ids
+        )
         result = {
             "encoder_output": encoder_output.numpy(),
@ -129,7 +131,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
     def create_and_check_t5_with_lm_head(self, config, input_ids, input_mask, token_labels):
         model = TFT5ForConditionalGeneration(config=config)
         inputs_dict = {
-            "input_ids": input_ids,
+            "inputs": input_ids,
             "decoder_input_ids": input_ids,
             "decoder_attention_mask": input_mask,
         }
@ -147,7 +149,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.prepare_config_and_inputs()
         (config, input_ids, input_mask, token_labels) = config_and_inputs
         inputs_dict = {
-            "input_ids": input_ids,
+            "inputs": input_ids,
             "decoder_input_ids": input_ids,
             "decoder_attention_mask": input_mask,
         }