Support multiple choice in tf common model tests (#4920)
* Support multiple choice in tf common model tests
* Add the input_embeds test
parent 699541c4b3
commit 20451195f0
src/transformers/modeling_tf_bert.py

@@ -1100,6 +1100,11 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
         flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
         flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
         flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
+        flat_inputs_embeds = (
+            tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
+            if inputs_embeds is not None
+            else None
+        )
 
         flat_inputs = [
             flat_input_ids,
@@ -1107,7 +1112,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
             flat_token_type_ids,
             flat_position_ids,
             head_mask,
-            inputs_embeds,
+            flat_inputs_embeds,
             output_attentions,
         ]
 
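The modeling change above comes down to one shape convention: multiple-choice inputs arrive as (batch_size, num_choices, seq_length), get flattened so the encoder sees an ordinary 2D batch, and the per-choice scores are folded back at the end. A minimal sketch of that pattern; the dummy tensors and sizes below are illustrative, not taken from the diff:

import tensorflow as tf

batch_size, num_choices, seq_length, hidden_size = 4, 2, 16, 32

input_ids = tf.zeros((batch_size, num_choices, seq_length), dtype=tf.int32)
inputs_embeds = tf.zeros((batch_size, num_choices, seq_length, hidden_size))

# Collapse the choice dimension so the encoder sees a plain (batch * choices, seq) batch.
flat_input_ids = tf.reshape(input_ids, (-1, seq_length))                       # (8, 16)
flat_inputs_embeds = tf.reshape(inputs_embeds, (-1, seq_length, hidden_size))  # (8, 16, 32)

# After the classifier head, one score per choice is reshaped back for the softmax.
logits = tf.zeros((batch_size * num_choices, 1))         # stand-in for the head's output
reshaped_logits = tf.reshape(logits, (-1, num_choices))  # (4, 2)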
tests/test_modeling_tf_bert.py

@@ -49,6 +49,7 @@ class TFBertModelTest(TFModelTesterMixin, unittest.TestCase):
             TFBertForQuestionAnswering,
             TFBertForSequenceClassification,
             TFBertForTokenClassification,
+            TFBertForMultipleChoice,
         )
         if is_tf_available()
         else ()
tests/test_modeling_tf_common.py

@@ -30,7 +30,12 @@ if is_tf_available():
     import tensorflow as tf
     import numpy as np
 
-    from transformers import tf_top_k_top_p_filtering, TFAdaptiveEmbedding, TFSharedEmbeddings
+    from transformers import (
+        tf_top_k_top_p_filtering,
+        TFAdaptiveEmbedding,
+        TFSharedEmbeddings,
+        TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
+    )
 
     if _tf_gpu_memory_limit is not None:
         gpus = tf.config.list_physical_devices("GPU")
@@ -66,6 +71,16 @@ class TFModelTesterMixin:
     test_resize_embeddings = True
     is_encoder_decoder = False
 
+    def _prepare_for_class(self, inputs_dict, model_class):
+        if model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
+            return {
+                k: tf.tile(tf.expand_dims(v, 1), (1, self.model_tester.num_choices, 1))
+                if isinstance(v, tf.Tensor) and v.ndim != 0
+                else v
+                for k, v in inputs_dict.items()
+            }
+        return inputs_dict
+
     def test_initialization(self):
         pass
         # config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
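The new _prepare_for_class hook is the heart of the test change: for multiple-choice classes it duplicates every non-scalar tensor across a choice dimension before the model is called. In isolation the transformation looks like this (values are illustrative):

import tensorflow as tf

num_choices = 2
v = tf.constant([[1, 2, 3], [4, 5, 6]])                     # (batch=2, seq=3)
tiled = tf.tile(tf.expand_dims(v, 1), (1, num_choices, 1))  # (2, 2, 3)
# Each row is repeated once per choice, producing exactly the
# (batch, num_choices, seq) layout the multiple-choice models expect.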
@@ -83,12 +98,12 @@ class TFModelTesterMixin:
 
         for model_class in self.all_model_classes:
             model = model_class(config)
-            outputs = model(inputs_dict)
+            outputs = model(self._prepare_for_class(inputs_dict, model_class))
 
             with tempfile.TemporaryDirectory() as tmpdirname:
                 model.save_pretrained(tmpdirname)
                 model = model_class.from_pretrained(tmpdirname)
-                after_outputs = model(inputs_dict)
+                after_outputs = model(self._prepare_for_class(inputs_dict, model_class))
 
                 self.assert_outputs_same(after_outputs, outputs)
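test_save_load now feeds the prepared inputs through the same round trip as before. The round trip itself, extracted into a standalone sketch (the small BertConfig and dummy ids are illustrative, not the mixin's generated config):

import tempfile
import tensorflow as tf
from transformers import BertConfig, TFBertModel

config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)
model = TFBertModel(config)
inputs = {"input_ids": tf.constant([[7, 4, 2], [9, 1, 3]])}
outputs = model(inputs)

with tempfile.TemporaryDirectory() as tmpdirname:
    model.save_pretrained(tmpdirname)
    model = TFBertModel.from_pretrained(tmpdirname)
    after_outputs = model(inputs)  # should match `outputs` tensor-for-tensor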
@@ -173,13 +188,16 @@ class TFModelTesterMixin:
 
             # Check we can load pt model in tf and vice-versa with model => model functions
 
-            tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=inputs_dict)
+            tf_model = transformers.load_pytorch_model_in_tf2_model(
+                tf_model, pt_model, tf_inputs=self._prepare_for_class(inputs_dict, model_class)
+            )
             pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
 
             # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
             pt_model.eval()
             pt_inputs_dict = dict(
-                (name, torch.from_numpy(key.numpy()).to(torch.long)) for name, key in inputs_dict.items()
+                (name, torch.from_numpy(key.numpy()).to(torch.long))
+                for name, key in self._prepare_for_class(inputs_dict, model_class).items()
             )
             # need to rename encoder-decoder "inputs" for PyTorch
             if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
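The PT/TF equivalence tests reuse one conversion idiom: every prepared TF tensor is round-tripped through NumPy into a torch.long tensor. Extracted from the test, the idiom is just (dummy inputs are illustrative):

import tensorflow as tf
import torch

tf_inputs = {"input_ids": tf.constant([[7, 4, 2], [9, 1, 3]])}
pt_inputs = {name: torch.from_numpy(t.numpy()).to(torch.long) for name, t in tf_inputs.items()}
# pt_inputs["input_ids"] is now a torch.int64 tensor with the same values.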
@@ -187,7 +205,7 @@ class TFModelTesterMixin:
 
             with torch.no_grad():
                 pto = pt_model(**pt_inputs_dict)
-            tfo = tf_model(inputs_dict, training=False)
+            tfo = tf_model(self._prepare_for_class(inputs_dict, model_class), training=False)
             tf_hidden_states = tfo[0].numpy()
             pt_hidden_states = pto[0].numpy()
 
@@ -222,7 +240,8 @@ class TFModelTesterMixin:
             # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
             pt_model.eval()
             pt_inputs_dict = dict(
-                (name, torch.from_numpy(key.numpy()).to(torch.long)) for name, key in inputs_dict.items()
+                (name, torch.from_numpy(key.numpy()).to(torch.long))
+                for name, key in self._prepare_for_class(inputs_dict, model_class).items()
             )
             # need to rename encoder-decoder "inputs" for PyTorch
             if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
@@ -230,7 +249,7 @@ class TFModelTesterMixin:
 
             with torch.no_grad():
                 pto = pt_model(**pt_inputs_dict)
-            tfo = tf_model(inputs_dict)
+            tfo = tf_model(self._prepare_for_class(inputs_dict, model_class))
             tfo = tfo[0].numpy()
             pto = pto[0].numpy()
             tf_nans = np.copy(np.isnan(tfo))
@@ -247,24 +266,29 @@ class TFModelTesterMixin:
     def test_compile_tf_model(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
 
-        if self.is_encoder_decoder:
-            input_ids = {
-                "decoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"),
-                "inputs": tf.keras.Input(batch_shape=(2, 2000), name="inputs", dtype="int32"),
-            }
-        else:
-            input_ids = tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32")
         optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)
         loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
         metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")
 
         for model_class in self.all_model_classes:
+            if self.is_encoder_decoder:
+                input_ids = {
+                    "decoder_input_ids": tf.keras.Input(
+                        batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"
+                    ),
+                    "inputs": tf.keras.Input(batch_shape=(2, 2000), name="inputs", dtype="int32"),
+                }
+            elif model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
+                input_ids = tf.keras.Input(batch_shape=(4, 2, 2000), name="input_ids", dtype="int32")
+            else:
+                input_ids = tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32")
+
             # Prepare our model
             model = model_class(config)
 
             # Let's load it from the disk to be sure we can use pretrained weights
             with tempfile.TemporaryDirectory() as tmpdirname:
-                outputs = model(inputs_dict)  # build the model
+                outputs = model(self._prepare_for_class(inputs_dict, model_class))  # build the model
                 model.save_pretrained(tmpdirname)
                 model = model_class.from_pretrained(tmpdirname)
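test_compile_tf_model now gives multiple-choice classes a rank-3 symbolic input, since their leading axes are (batch, num_choices) rather than just (batch). A hedged sketch of compiling such a model through Keras; the config values and the short sequence length are illustrative, not the test's:

import tensorflow as tf
from transformers import BertConfig, TFBertForMultipleChoice

config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)
model = TFBertForMultipleChoice(config)

# (batch, num_choices, seq_len) symbolic input, mirroring the test's (4, 2, ...) shape.
input_ids = tf.keras.Input(batch_shape=(4, 2, 128), name="input_ids", dtype="int32")
logits = model(input_ids)[0]
extended_model = tf.keras.Model(inputs=[input_ids], outputs=[logits])
extended_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy("accuracy")],
)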
@@ -283,9 +307,9 @@ class TFModelTesterMixin:
 
         for model_class in self.all_model_classes:
             model = model_class(config)
-            outputs_dict = model(inputs_dict)
+            outputs_dict = model(self._prepare_for_class(inputs_dict, model_class))
 
-            inputs_keywords = copy.deepcopy(inputs_dict)
+            inputs_keywords = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
             input_ids = inputs_keywords.pop("input_ids" if not self.is_encoder_decoder else "inputs", None,)
             outputs_keywords = model(input_ids, **inputs_keywords)
             output_dict = outputs_dict[0].numpy()
@@ -317,7 +341,7 @@ class TFModelTesterMixin:
             inputs_dict["output_attentions"] = True
             config.output_hidden_states = False
             model = model_class(config)
-            outputs = model(inputs_dict)
+            outputs = model(self._prepare_for_class(inputs_dict, model_class))
             attentions = [t.numpy() for t in outputs[-1]]
             self.assertEqual(model.config.output_hidden_states, False)
             self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
@@ -341,7 +365,7 @@ class TFModelTesterMixin:
             del inputs_dict["output_attentions"]
             config.output_attentions = True
             model = model_class(config)
-            outputs = model(inputs_dict)
+            outputs = model(self._prepare_for_class(inputs_dict, model_class))
             attentions = [t.numpy() for t in outputs[-1]]
             self.assertEqual(model.config.output_hidden_states, False)
             self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
@@ -354,7 +378,7 @@ class TFModelTesterMixin:
             inputs_dict["output_attentions"] = True
             config.output_hidden_states = True
             model = model_class(config)
-            outputs = model(inputs_dict)
+            outputs = model(self._prepare_for_class(inputs_dict, model_class))
             self.assertEqual(out_len + (2 if self.is_encoder_decoder else 1), len(outputs))
             self.assertEqual(model.config.output_hidden_states, True)
 
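The attention hunks above only swap in the prepared inputs; the knob under test is still the output_attentions flag carried inside the inputs dict. Standalone, the behavior being asserted looks roughly like this (model choice and sizes are illustrative):

import tensorflow as tf
from transformers import BertConfig, TFBertModel

config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)
model = TFBertModel(config)

inputs = {"input_ids": tf.constant([[7, 4, 2]]), "output_attentions": True}
outputs = model(inputs)
attentions = [t.numpy() for t in outputs[-1]]  # one (batch, heads, seq, seq) array per layer
assert len(attentions) == config.num_hidden_layers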
@@ -371,7 +395,7 @@ class TFModelTesterMixin:
         for model_class in self.all_model_classes:
             config.output_hidden_states = True
             model = model_class(config)
-            outputs = model(inputs_dict)
+            outputs = model(self._prepare_for_class(inputs_dict, model_class))
             hidden_states = [t.numpy() for t in outputs[-1]]
             self.assertEqual(model.config.output_hidden_states, True)
             self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
@@ -394,8 +418,8 @@ class TFModelTesterMixin:
         for model_class in self.all_model_classes:
             model = model_class(config)
             first, second = (
-                model(inputs_dict, training=False)[0],
-                model(inputs_dict, training=False)[0],
+                model(self._prepare_for_class(inputs_dict, model_class), training=False)[0],
+                model(self._prepare_for_class(inputs_dict, model_class), training=False)[0],
             )
             out_1 = first.numpy()
             out_2 = second.numpy()
@@ -425,26 +449,28 @@ class TFModelTesterMixin:
 
     def test_inputs_embeds(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        if not self.is_encoder_decoder:
-            input_ids = inputs_dict["input_ids"]
-            del inputs_dict["input_ids"]
-        else:
-            encoder_input_ids = inputs_dict["inputs"]
-            decoder_input_ids = inputs_dict["decoder_input_ids"]
-            del inputs_dict["inputs"]
-            del inputs_dict["decoder_input_ids"]
 
         for model_class in self.all_model_classes:
             model = model_class(config)
 
+            inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
+            if not self.is_encoder_decoder:
+                input_ids = inputs["input_ids"]
+                del inputs["input_ids"]
+            else:
+                encoder_input_ids = inputs["inputs"]
+                decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
+                del inputs["inputs"]
+                inputs.pop("decoder_input_ids", None)
+
             wte = model.get_input_embeddings()
             if not self.is_encoder_decoder:
-                inputs_dict["inputs_embeds"] = self._get_embeds(wte, input_ids)
+                inputs["inputs_embeds"] = self._get_embeds(wte, input_ids)
             else:
-                inputs_dict["inputs_embeds"] = self._get_embeds(wte, encoder_input_ids)
-                inputs_dict["decoder_inputs_embeds"] = self._get_embeds(wte, decoder_input_ids)
+                inputs["inputs_embeds"] = self._get_embeds(wte, encoder_input_ids)
+                inputs["decoder_inputs_embeds"] = self._get_embeds(wte, decoder_input_ids)
 
-            model(inputs_dict)
+            model(inputs)
 
     def test_lm_head_model_random_no_beam_search_generate(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
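The new input_embeds test drives each model through its inputs_embeds path instead of input_ids. Outside the mixin, which relies on its _get_embeds helper, the same exercise can be sketched with a hand-made embedding matrix; the config values and lookup below are illustrative stand-ins:

import tensorflow as tf
from transformers import BertConfig, TFBertModel

config = BertConfig(vocab_size=99, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=64)
model = TFBertModel(config)

input_ids = tf.constant([[7, 4, 2], [9, 1, 3]])         # (batch, seq)
# Stand-in for the mixin's _get_embeds helper: any (vocab, hidden) matrix
# is enough to exercise the inputs_embeds code path.
embedding_matrix = tf.random.normal((config.vocab_size, config.hidden_size))
inputs_embeds = tf.gather(embedding_matrix, input_ids)  # (batch, seq, hidden)

outputs = model({"inputs_embeds": inputs_embeds})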