Update TF(Vision)EncoderDecoderModel PT/TF equivalence tests (#18073)

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Yih-Dar 2022-07-18 15:29:14 +02:00 committed by GitHub
parent cb19c2afdc
commit 6561fbcc6e
4 changed files with 394 additions and 128 deletions

File: tests/models/encoder_decoder/test_modeling_tf_encoder_decoder.py

@@ -23,6 +23,7 @@ import numpy as np
from transformers import is_tf_available, is_torch_available
from transformers.testing_utils import is_pt_tf_cross_test, require_tf, require_torch, slow, torch_device
from transformers.utils.generic import ModelOutput
from ...test_modeling_tf_common import ids_tensor
from ..bert.test_modeling_tf_bert import TFBertModelTester
@@ -326,31 +327,145 @@ class TFEncoderDecoderMixin:
)
self.assertEqual(tuple(generated_output.shape.as_list()), (input_ids.shape[0],) + (decoder_config.max_length,))
def check_pt_tf_equivalence(self, pt_model, tf_model, inputs_dict):
def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None):
"""Check the outputs from PyTorch and TensorFlow models are close enough. Checks are done in a recursive way.
pt_model.to(torch_device)
pt_model.eval()
Args:
model_class: The class of the model that is currently being tested. For example, `TFBertModel`,
`TFBertForMaskedLM`, `TFBertForSequenceClassification`, etc. Mainly used for providing more informative
error messages.
name (`str`): The name of the output. For example, `output.hidden_states`, `output.attentions`, etc.
attributes (`Tuple[str]`): The names of the output's elements if the output is a tuple/list, with each element
being a named field in the output.
"""
# prepare inputs
tf_inputs = inputs_dict
pt_inputs = {k: torch.tensor(v.numpy()) for k, v in tf_inputs.items()}
if "labels" in pt_inputs:
pt_inputs["labels"] = pt_inputs["labels"].type(torch.LongTensor)
self.assertEqual(type(name), str)
if attributes is not None:
self.assertEqual(type(attributes), tuple, f"{name}: The argument `attributes` should be a `tuple`")
# Allow `ModelOutput` (e.g. `CLIPOutput` has `text_model_output` and `vision_model_output`).
if isinstance(tf_outputs, ModelOutput):
self.assertTrue(
isinstance(pt_outputs, ModelOutput),
f"{name}: `pt_outputs` should an instance of `ModelOutput` when `tf_outputs` is",
)
tf_keys = [k for k, v in tf_outputs.items() if v is not None]
pt_keys = [k for k, v in pt_outputs.items() if v is not None]
self.assertEqual(tf_keys, pt_keys, f"{name}: Output keys differ between TF and PyTorch")
# convert to the `tuple` case,
# appending each key to the current (string) `name`
attributes = tuple([f"{name}.{k}" for k in tf_keys])
self.check_pt_tf_outputs(
tf_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, tol=tol, name=name, attributes=attributes
)
# Allow `list` (e.g. `TransfoXLModelOutput.mems` is a list of tensors.)
elif type(tf_outputs) in [tuple, list]:
self.assertEqual(type(tf_outputs), type(pt_outputs), f"{name}: Output types differ between TF and PyTorch")
self.assertEqual(len(tf_outputs), len(pt_outputs), f"{name}: Output lengths differ between TF and PyTorch")
if attributes is not None:
# case 1: each output has assigned name (e.g. a tuple form of a `ModelOutput`)
self.assertEqual(
len(attributes),
len(tf_outputs),
f"{name}: The tuple `names` should have the same length as `tf_outputs`",
)
else:
# case 2: each output has no assigned name (e.g. hidden states of each layer) -> add an index to `name`
attributes = tuple([f"{name}_{idx}" for idx in range(len(tf_outputs))])
for tf_output, pt_output, attr in zip(tf_outputs, pt_outputs, attributes):
self.check_pt_tf_outputs(tf_output, pt_output, model_class, tol=tol, name=attr)
elif isinstance(tf_outputs, tf.Tensor):
self.assertTrue(
isinstance(pt_outputs, torch.Tensor), f"{name}: `pt_outputs` should be a tensor when `tf_outputs` is"
)
tf_outputs = tf_outputs.numpy()
pt_outputs = pt_outputs.detach().to("cpu").numpy()
self.assertEqual(
tf_outputs.shape, pt_outputs.shape, f"{name}: Output shapes differ between TF and PyTorch"
)
# deal with NumPy scalars so that replacing NaN values by 0 works.
if np.isscalar(tf_outputs):
tf_outputs = np.array([tf_outputs])
pt_outputs = np.array([pt_outputs])
tf_nans = np.isnan(tf_outputs)
pt_nans = np.isnan(pt_outputs)
pt_outputs[tf_nans] = 0
tf_outputs[tf_nans] = 0
pt_outputs[pt_nans] = 0
tf_outputs[pt_nans] = 0
max_diff = np.amax(np.abs(tf_outputs - pt_outputs))
self.assertLessEqual(max_diff, tol, f"{name}: Difference between torch and tf is {max_diff} (>= {tol}).")
else:
raise ValueError(
"`tf_outputs` should be an instance of `tf.Tensor`, a `tuple`, or an instance of `tf.Tensor`. Got"
f" {type(tf_outputs)} instead."
)
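For intuition, here is a minimal standalone sketch of the recursion above (a hypothetical helper, not part of the test suite): walk nested tuples/lists, convert leaf values to NumPy, zero out positions that are NaN on either side, and compare the maximum absolute difference against the tolerance.

import numpy as np

def assert_close_recursive(tf_out, pt_out, tol=1e-5, name="outputs"):
    # Recurse into tuples/lists, comparing element-wise with indexed names.
    if isinstance(tf_out, (tuple, list)):
        assert len(tf_out) == len(pt_out), f"{name}: length mismatch"
        for idx, (t, p) in enumerate(zip(tf_out, pt_out)):
            assert_close_recursive(t, p, tol=tol, name=f"{name}_{idx}")
        return
    t = np.array(tf_out, dtype=np.float64)
    p = np.array(pt_out, dtype=np.float64)
    assert t.shape == p.shape, f"{name}: shape mismatch"
    nans = np.isnan(t) | np.isnan(p)  # mask NaNs on either side with 0
    t[nans], p[nans] = 0, 0
    max_diff = np.amax(np.abs(t - p))
    assert max_diff <= tol, f"{name}: difference {max_diff} exceeds {tol}"

# Nested structures are compared element-wise; NaNs do not poison the check.
assert_close_recursive(([1.0, float("nan")],), ([1.0 + 1e-7, 2.0],))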
def prepare_pt_inputs_from_tf_inputs(self, tf_inputs_dict):
pt_inputs_dict = {}
for name, key in tf_inputs_dict.items():
if type(key) == bool:
pt_inputs_dict[name] = key
elif name == "input_values":
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
elif name == "pixel_values":
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
elif name == "input_features":
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
# other general float inputs
elif tf_inputs_dict[name].dtype.is_floating:
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
else:
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
return pt_inputs_dict
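As a usage sketch of the conversion rule (assumes TF and PyTorch are installed; the tensor names are illustrative): float TF tensors become `torch.float32`, everything else becomes `torch.long`.

import tensorflow as tf
import torch

tf_inputs = {
    "input_ids": tf.constant([[101, 102]], dtype=tf.int32),
    "pixel_values": tf.random.uniform((1, 3, 30, 30)),
}
# Same rule as `prepare_pt_inputs_from_tf_inputs`, collapsed into one expression.
pt_inputs = {
    k: torch.from_numpy(v.numpy()).to(torch.float32 if v.dtype.is_floating else torch.long)
    for k, v in tf_inputs.items()
}
assert pt_inputs["input_ids"].dtype == torch.long
assert pt_inputs["pixel_values"].dtype == torch.float32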
def check_pt_tf_models(self, tf_model, pt_model, tf_inputs_dict):
pt_inputs_dict = self.prepare_pt_inputs_from_tf_inputs(tf_inputs_dict)
# send pytorch inputs to the correct device
pt_inputs = {k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()}
pt_inputs_dict = {
k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs_dict.items()
}
# send pytorch model to the correct device
pt_model.to(torch_device)
# Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
pt_model.eval()
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
pt_outputs = pt_model(**pt_inputs_dict)
tf_outputs = tf_model(tf_inputs_dict)
tf_outputs = tf_model(**inputs_dict)
if "loss" in tf_outputs:
tf_outputs.loss = tf.math.reduce_mean(tf_outputs.loss)
tf_outputs = tf_outputs.to_tuple()
self.assertEqual(len(tf_outputs), len(pt_outputs), "Output lengths differ between TF and PyTorch")
# TF models' returned loss is usually a tensor rather than a scalar
# (see `hf_compute_loss`: it uses `tf.keras.losses.Reduction.NONE`).
# Reduce it to a scalar here to match PyTorch models' loss.
tf_loss = getattr(tf_outputs, "loss", None)
if tf_loss is not None:
tf_outputs.loss = tf.math.reduce_mean(tf_loss)
for tf_output, pt_output in zip(tf_outputs, pt_outputs):
self.assert_almost_equals(tf_output.numpy(), pt_output.detach().to("cpu").numpy(), 1e-3)
self.check_pt_tf_outputs(tf_outputs, pt_outputs, type(tf_model))
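The loss note above in toy form (a hedged sketch, not from the test file): with `Reduction.NONE`, Keras returns one loss value per sample, while PyTorch's `cross_entropy` defaults to a scalar mean, so the TF loss is mean-reduced before comparison.

import tensorflow as tf
import torch
import torch.nn.functional as F

logits, labels = [[2.0, 0.5], [0.2, 1.5]], [0, 1]
scce = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction=tf.keras.losses.Reduction.NONE
)
tf_loss = scce(tf.constant(labels), tf.constant(logits))  # shape (2,): per-sample losses
pt_loss = F.cross_entropy(torch.tensor(logits), torch.tensor(labels))  # scalar mean
assert abs(float(tf.math.reduce_mean(tf_loss)) - float(pt_loss)) < 1e-5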
def check_pt_tf_equivalence(self, tf_model, pt_model, tf_inputs_dict):
"""Wrap `check_pt_tf_models` to further check PT -> TF again"""
self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
# PT -> TF
with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname:
@@ -363,18 +478,16 @@ class TFEncoderDecoderMixin:
# This is only for copying some specific attributes of this particular model.
tf_model_loaded.config = pt_model.config
tf_outputs_loaded = tf_model_loaded(**inputs_dict)
if "loss" in tf_outputs_loaded:
tf_outputs_loaded.loss = tf.math.reduce_mean(tf_outputs_loaded.loss)
tf_outputs_loaded = tf_outputs_loaded.to_tuple()
self.assertEqual(len(tf_outputs_loaded), len(pt_outputs), "Output lengths differ between TF and PyTorch")
self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
for tf_output_loaded, pt_output in zip(tf_outputs_loaded, pt_outputs):
self.assert_almost_equals(tf_output_loaded.numpy(), pt_output.detach().to("cpu").numpy(), 1e-3)
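For context, a self-contained sketch of the PT -> TF leg (tiny illustrative configs; the `encoder_from_pt`/`decoder_from_pt` kwargs follow the cross-loading pattern used by these tests): the PyTorch encoder and decoder are saved separately, then reloaded into TF.

import tempfile

from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel, TFEncoderDecoderModel

# Two separate configs: `from_encoder_decoder_configs` mutates the decoder config.
enc_cfg = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=37)
dec_cfg = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=37)
pt_model = EncoderDecoderModel(EncoderDecoderConfig.from_encoder_decoder_configs(enc_cfg, dec_cfg))

with tempfile.TemporaryDirectory() as enc_dir, tempfile.TemporaryDirectory() as dec_dir:
    # Save the PT encoder/decoder separately, then cross-load them into TF.
    pt_model.encoder.save_pretrained(enc_dir)
    pt_model.decoder.save_pretrained(dec_dir)
    tf_model = TFEncoderDecoderModel.from_encoder_decoder_pretrained(
        enc_dir, dec_dir, encoder_from_pt=True, decoder_from_pt=True
    )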
def check_equivalence_pt_to_tf(self, config, decoder_config, inputs_dict):
def check_pt_to_tf_equivalence(self, config, decoder_config, tf_inputs_dict):
"""EncoderDecoderModel requires special way to cross load (PT -> TF)"""
encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
# Output all for aggressive testing
encoder_decoder_config.output_hidden_states = True
# All models tested in this file have attentions
encoder_decoder_config.output_attentions = True
pt_model = EncoderDecoderModel(encoder_decoder_config)
@@ -388,11 +501,16 @@ class TFEncoderDecoderMixin:
# This is only for copying some specific attributes of this particular model.
tf_model.config = pt_model.config
self.check_pt_tf_equivalence(pt_model, tf_model, inputs_dict)
self.check_pt_tf_equivalence(tf_model, pt_model, tf_inputs_dict)
def check_equivalence_tf_to_pt(self, config, decoder_config, inputs_dict):
def check_tf_to_pt_equivalence(self, config, decoder_config, tf_inputs_dict):
"""EncoderDecoderModel requires special way to cross load (TF -> PT)"""
encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
# Output all for aggressive testing
encoder_decoder_config.output_hidden_states = True
# TODO: A generalizable way to determine this attribute
encoder_decoder_config.output_attentions = True
# Using `_tf_model`, the test will fail, because the weights of `_tf_model` get extended before saving
# the encoder/decoder models.
@@ -401,7 +519,7 @@ class TFEncoderDecoderMixin:
# (the change in `src/transformers/modeling_tf_utils.py`)
_tf_model = TFEncoderDecoderModel(encoder_decoder_config)
# Make sure model is built
_tf_model(**inputs_dict)
_tf_model(**tf_inputs_dict)
# Using `tf_model` to pass the test.
encoder = _tf_model.encoder.__class__(encoder_decoder_config.encoder)
@@ -410,6 +528,7 @@ class TFEncoderDecoderMixin:
encoder(encoder.dummy_inputs)
decoder(decoder.dummy_inputs)
tf_model = TFEncoderDecoderModel(encoder=encoder, decoder=decoder)
tf_model.config = encoder_decoder_config
with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname:
@@ -421,7 +540,7 @@ class TFEncoderDecoderMixin:
# This is only for copying some specific attributes of this particular model.
pt_model.config = tf_model.config
self.check_pt_tf_equivalence(pt_model, tf_model, inputs_dict)
self.check_pt_tf_equivalence(tf_model, pt_model, tf_inputs_dict)
def test_encoder_decoder_model(self):
input_ids_dict = self.prepare_config_and_inputs()
@@ -460,7 +579,7 @@ class TFEncoderDecoderMixin:
self.assertLessEqual(diff, tol, f"Difference between torch and tf is {diff} (>= {tol}).")
@is_pt_tf_cross_test
def test_pt_tf_equivalence(self):
def test_pt_tf_model_equivalence(self):
config_inputs_dict = self.prepare_config_and_inputs()
labels = config_inputs_dict.pop("decoder_token_labels")
@@ -480,48 +599,58 @@ class TFEncoderDecoderMixin:
config = config_inputs_dict.pop("config")
decoder_config = config_inputs_dict.pop("decoder_config")
inputs_dict = config_inputs_dict
# Output all for aggressive testing
config.output_hidden_states = True
decoder_config.output_hidden_states = True
# All models tested in this file have attentions
config.output_attentions = True
decoder_config.output_attentions = True
tf_inputs_dict = config_inputs_dict
# `encoder_hidden_states` is not used in model call/forward
del inputs_dict["encoder_hidden_states"]
del tf_inputs_dict["encoder_hidden_states"]
inputs_dict_with_labels = copy.copy(inputs_dict)
inputs_dict_with_labels["labels"] = labels
# Make sure no sequence has an all-zero attention mask; otherwise some tests fail due to the inconsistent
# use of `1e-4`, `1e-9`, `1e-30`, and `-inf` as masking values.
for k in ["attention_mask", "decoder_attention_mask"]:
attention_mask = tf_inputs_dict[k]
# Avoid the case where a sequence has no place to attend (after being combined with the causal attention mask)
batch_size = inputs_dict["decoder_attention_mask"].shape[0]
inputs_dict["decoder_attention_mask"] = tf.constant(
np.concatenate([np.ones(shape=(batch_size, 1)), inputs_dict["decoder_attention_mask"][:, 1:]], axis=1)
)
# Make sure there are no all-zero attention masks, to avoid failures for now.
# Put a `1` at the beginning of each sequence so it still works when combined with the causal attention mask.
# TODO: remove this once a fix for large negative attention-mask values lands.
attention_mask = tf.concat(
[tf.ones_like(attention_mask[:, :1], dtype=attention_mask.dtype), attention_mask[:, 1:]], axis=-1
)
tf_inputs_dict[k] = attention_mask
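A toy run of the mask patch above (values illustrative): each row keeps its tail, but position 0 is forced to 1, so no sequence is left with nowhere to attend.

import tensorflow as tf

attention_mask = tf.constant([[0, 0, 0], [1, 1, 0]], dtype=tf.int32)
# Same pattern as above: force a 1 in the first column, keep the rest.
patched = tf.concat([tf.ones_like(attention_mask[:, :1]), attention_mask[:, 1:]], axis=-1)
print(patched.numpy())  # [[1 0 0], [1 1 0]]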
# TF models don't use the `use_cache` option, and the cache is not returned by default,
# so we disable `use_cache` here for the PyTorch model.
decoder_config.use_cache = False
tf_inputs_dict_with_labels = copy.copy(tf_inputs_dict)
tf_inputs_dict_with_labels["labels"] = labels
self.assertTrue(decoder_config.cross_attention_hidden_size is None)
# check without `enc_to_dec_proj` projection
# Original test: check without `labels` and without `enc_to_dec_proj` projection
self.assertTrue(config.hidden_size == decoder_config.hidden_size)
self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict)
self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict)
self.check_pt_to_tf_equivalence(config, decoder_config, tf_inputs_dict)
self.check_tf_to_pt_equivalence(config, decoder_config, tf_inputs_dict)
# check equivalence with labels
self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict_with_labels)
self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict_with_labels)
# check with `labels`
self.check_pt_to_tf_equivalence(config, decoder_config, tf_inputs_dict_with_labels)
self.check_tf_to_pt_equivalence(config, decoder_config, tf_inputs_dict_with_labels)
# This is not working, because the PT/TF equivalence test for encoder-decoder uses `from_encoder_decoder_pretrained`,
# which randomly initializes `enc_to_dec_proj`.
# # check `enc_to_dec_proj` works as expected
# check `enc_to_dec_proj` works as expected
# decoder_config.hidden_size = decoder_config.hidden_size * 2
# self.assertTrue(config.hidden_size != decoder_config.hidden_size)
# self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict)
# self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict)
# self.check_pt_to_tf_equivalence(config, decoder_config, tf_inputs_dict)
# self.check_tf_to_pt_equivalence(config, decoder_config, tf_inputs_dict)
# Let's just check `enc_to_dec_proj` can run for now
decoder_config.hidden_size = decoder_config.hidden_size * 2
self.assertTrue(config.hidden_size != decoder_config.hidden_size)
encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
model = TFEncoderDecoderModel(encoder_decoder_config)
model(**inputs_dict)
model(tf_inputs_dict)
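A side note on the call-style change from `model(**inputs_dict)` to `model(tf_inputs_dict)`: HF TF models accept the input dict either unpacked as keyword arguments or as a single positional argument. A minimal sketch with a tiny random config (illustrative sizes):

import tensorflow as tf
from transformers import BertConfig, TFBertModel

config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=37)
model = TFBertModel(config)  # random weights, no download needed
inputs = {"input_ids": tf.constant([[1, 2, 3]]), "attention_mask": tf.constant([[1, 1, 1]])}

out1 = model(inputs)    # dict as a single positional argument
out2 = model(**inputs)  # dict unpacked into keyword arguments
assert out1.last_hidden_state.shape == out2.last_hidden_state.shape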
def test_model_save_load_from_pretrained(self):
model_2 = self.get_pretrained_model()
@@ -554,6 +683,10 @@ class TFEncoderDecoderMixin:
@require_tf
class TFBertEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
def setUp(self):
self.encoder_model_tester = TFBertModelTester(self, batch_size=13)
self.decoder_model_tester = TFBertModelTester(self, batch_size=13)
def get_pretrained_model(self):
return TFEncoderDecoderModel.from_encoder_decoder_pretrained(
"hf-internal-testing/tiny-random-bert",
@@ -566,10 +699,8 @@ class TFBertEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
return encoder_model, decoder_model
def prepare_config_and_inputs(self):
model_tester_encoder = TFBertModelTester(self, batch_size=13)
model_tester_decoder = TFBertModelTester(self, batch_size=13)
encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs_for_decoder()
encoder_config_and_inputs = self.encoder_model_tester.prepare_config_and_inputs()
decoder_config_and_inputs = self.decoder_model_tester.prepare_config_and_inputs_for_decoder()
(
config,
input_ids,
@@ -652,6 +783,10 @@ class TFBertEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
@require_tf
class TFGPT2EncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
def setUp(self):
self.encoder_model_tester = TFBertModelTester(self, batch_size=13)
self.decoder_model_tester = TFGPT2ModelTester(self)
def get_pretrained_model(self):
return TFEncoderDecoderModel.from_encoder_decoder_pretrained(
"hf-internal-testing/tiny-random-bert",
@@ -664,10 +799,8 @@ class TFGPT2EncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
return encoder_model, decoder_model
def prepare_config_and_inputs(self):
model_tester_encoder = TFBertModelTester(self, batch_size=13)
model_tester_decoder = TFGPT2ModelTester(self)
encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs_for_decoder()
encoder_config_and_inputs = self.encoder_model_tester.prepare_config_and_inputs()
decoder_config_and_inputs = self.decoder_model_tester.prepare_config_and_inputs_for_decoder()
(
config,
input_ids,
@@ -744,6 +877,10 @@ class TFGPT2EncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
@require_tf
class TFRoBertaEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
def setUp(self):
self.encoder_model_tester = TFRobertaModelTester(self)
self.decoder_model_tester = TFRobertaModelTester(self)
def get_pretrained_model(self):
return TFEncoderDecoderModel.from_encoder_decoder_pretrained(
"hf-internal-testing/tiny-random-roberta",
@@ -756,10 +893,8 @@ class TFRoBertaEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase)
return encoder_model, decoder_model
def prepare_config_and_inputs(self):
model_tester_encoder = TFRobertaModelTester(self)
model_tester_decoder = TFRobertaModelTester(self)
encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs_for_decoder()
encoder_config_and_inputs = self.encoder_model_tester.prepare_config_and_inputs()
decoder_config_and_inputs = self.decoder_model_tester.prepare_config_and_inputs_for_decoder()
(
config,
input_ids,
@@ -803,6 +938,10 @@ class TFRoBertaEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase)
@require_tf
class TFRembertEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
def setUp(self):
self.encoder_model_tester = TFRemBertModelTester(self)
self.decoder_model_tester = TFRemBertModelTester(self)
def get_pretrained_model(self):
return TFEncoderDecoderModel.from_encoder_decoder_pretrained(
"hf-internal-testing/tiny-random-rembert",
@@ -815,10 +954,8 @@ class TFRembertEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase)
return encoder_model, decoder_model
def prepare_config_and_inputs(self):
model_tester_encoder = TFRemBertModelTester(self)
model_tester_decoder = TFRemBertModelTester(self)
encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs_for_decoder()
encoder_config_and_inputs = self.encoder_model_tester.prepare_config_and_inputs()
decoder_config_and_inputs = self.decoder_model_tester.prepare_config_and_inputs_for_decoder()
(
config,
input_ids,

File: tests/models/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py

@@ -31,6 +31,7 @@ from transformers.testing_utils import (
slow,
torch_device,
)
from transformers.utils.generic import ModelOutput
from ...test_modeling_tf_common import floats_tensor, ids_tensor
from ..gpt2.test_modeling_tf_gpt2 import TFGPT2ModelTester
@@ -314,31 +315,145 @@ class TFVisionEncoderDecoderMixin:
tuple(generated_output.shape.as_list()), (pixel_values.shape[0],) + (decoder_config.max_length,)
)
def check_pt_tf_equivalence(self, pt_model, tf_model, inputs_dict):
def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None):
"""Check the outputs from PyTorch and TensorFlow models are close enough. Checks are done in a recursive way.
pt_model.to(torch_device)
pt_model.eval()
Args:
model_class: The class of the model that is currently being tested. For example, `TFBertModel`,
`TFBertForMaskedLM`, `TFBertForSequenceClassification`, etc. Mainly used for providing more informative
error messages.
name (`str`): The name of the output. For example, `output.hidden_states`, `output.attentions`, etc.
attributes (`Tuple[str]`): The names of the output's elements if the output is a tuple/list, with each element
being a named field in the output.
"""
# prepare inputs
tf_inputs = inputs_dict
pt_inputs = {k: torch.tensor(v.numpy()) for k, v in tf_inputs.items()}
if "labels" in pt_inputs:
pt_inputs["labels"] = pt_inputs["labels"].type(torch.LongTensor)
self.assertEqual(type(name), str)
if attributes is not None:
self.assertEqual(type(attributes), tuple, f"{name}: The argument `attributes` should be a `tuple`")
# Allow `ModelOutput` (e.g. `CLIPOutput` has `text_model_output` and `vision_model_output`).
if isinstance(tf_outputs, ModelOutput):
self.assertTrue(
isinstance(pt_outputs, ModelOutput),
f"{name}: `pt_outputs` should an instance of `ModelOutput` when `tf_outputs` is",
)
tf_keys = [k for k, v in tf_outputs.items() if v is not None]
pt_keys = [k for k, v in pt_outputs.items() if v is not None]
self.assertEqual(tf_keys, pt_keys, f"{name}: Output keys differ between TF and PyTorch")
# convert to the `tuple` case,
# appending each key to the current (string) `name`
attributes = tuple([f"{name}.{k}" for k in tf_keys])
self.check_pt_tf_outputs(
tf_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, tol=tol, name=name, attributes=attributes
)
# Allow `list` (e.g. `TransfoXLModelOutput.mems` is a list of tensors.)
elif type(tf_outputs) in [tuple, list]:
self.assertEqual(type(tf_outputs), type(pt_outputs), f"{name}: Output types differ between TF and PyTorch")
self.assertEqual(len(tf_outputs), len(pt_outputs), f"{name}: Output lengths differ between TF and PyTorch")
if attributes is not None:
# case 1: each output has assigned name (e.g. a tuple form of a `ModelOutput`)
self.assertEqual(
len(attributes),
len(tf_outputs),
f"{name}: The tuple `names` should have the same length as `tf_outputs`",
)
else:
# case 2: each output has no assigned name (e.g. hidden states of each layer) -> add an index to `name`
attributes = tuple([f"{name}_{idx}" for idx in range(len(tf_outputs))])
for tf_output, pt_output, attr in zip(tf_outputs, pt_outputs, attributes):
self.check_pt_tf_outputs(tf_output, pt_output, model_class, tol=tol, name=attr)
elif isinstance(tf_outputs, tf.Tensor):
self.assertTrue(
isinstance(pt_outputs, torch.Tensor), f"{name}: `pt_outputs` should be a tensor when `tf_outputs` is"
)
tf_outputs = tf_outputs.numpy()
pt_outputs = pt_outputs.detach().to("cpu").numpy()
self.assertEqual(
tf_outputs.shape, pt_outputs.shape, f"{name}: Output shapes differ between TF and PyTorch"
)
# deal with NumPy scalars so that replacing NaN values by 0 works.
if np.isscalar(tf_outputs):
tf_outputs = np.array([tf_outputs])
pt_outputs = np.array([pt_outputs])
tf_nans = np.isnan(tf_outputs)
pt_nans = np.isnan(pt_outputs)
pt_outputs[tf_nans] = 0
tf_outputs[tf_nans] = 0
pt_outputs[pt_nans] = 0
tf_outputs[pt_nans] = 0
max_diff = np.amax(np.abs(tf_outputs - pt_outputs))
self.assertLessEqual(max_diff, tol, f"{name}: Difference between torch and tf is {max_diff} (>= {tol}).")
else:
raise ValueError(
"`tf_outputs` should be an instance of `tf.Tensor`, a `tuple`, or an instance of `tf.Tensor`. Got"
f" {type(tf_outputs)} instead."
)
def prepare_pt_inputs_from_tf_inputs(self, tf_inputs_dict):
pt_inputs_dict = {}
for name, key in tf_inputs_dict.items():
if type(key) == bool:
pt_inputs_dict[name] = key
elif name == "input_values":
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
elif name == "pixel_values":
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
elif name == "input_features":
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
# other general float inputs
elif tf_inputs_dict[name].dtype.is_floating:
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
else:
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
return pt_inputs_dict
def check_pt_tf_models(self, tf_model, pt_model, tf_inputs_dict):
pt_inputs_dict = self.prepare_pt_inputs_from_tf_inputs(tf_inputs_dict)
# send pytorch inputs to the correct device
pt_inputs = {k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()}
pt_inputs_dict = {
k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs_dict.items()
}
# send pytorch model to the correct device
pt_model.to(torch_device)
# Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
pt_model.eval()
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
pt_outputs = pt_model(**pt_inputs_dict)
tf_outputs = tf_model(tf_inputs_dict)
tf_outputs = tf_model(**inputs_dict)
if "loss" in tf_outputs:
tf_outputs.loss = tf.math.reduce_mean(tf_outputs.loss)
tf_outputs = tf_outputs.to_tuple()
self.assertEqual(len(tf_outputs), len(pt_outputs), "Output lengths differ between TF and PyTorch")
# TF models' returned loss is usually a tensor rather than a scalar
# (see `hf_compute_loss`: it uses `tf.keras.losses.Reduction.NONE`).
# Reduce it to a scalar here to match PyTorch models' loss.
tf_loss = getattr(tf_outputs, "loss", None)
if tf_loss is not None:
tf_outputs.loss = tf.math.reduce_mean(tf_loss)
for tf_output, pt_output in zip(tf_outputs, pt_outputs):
self.assert_almost_equals(tf_output.numpy(), pt_output.detach().to("cpu").numpy(), 1e-3)
self.check_pt_tf_outputs(tf_outputs, pt_outputs, type(tf_model))
def check_pt_tf_equivalence(self, tf_model, pt_model, tf_inputs_dict):
"""Wrap `check_pt_tf_models` to further check PT -> TF again"""
self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
# PT -> TF
with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname:
@@ -351,18 +466,16 @@ class TFVisionEncoderDecoderMixin:
# This is only for copying some specific attributes of this particular model.
tf_model_loaded.config = pt_model.config
tf_outputs_loaded = tf_model_loaded(**inputs_dict)
if "loss" in tf_outputs_loaded:
tf_outputs_loaded.loss = tf.math.reduce_mean(tf_outputs_loaded.loss)
tf_outputs_loaded = tf_outputs_loaded.to_tuple()
self.assertEqual(len(tf_outputs_loaded), len(pt_outputs), "Output lengths differ between TF and PyTorch")
self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
for tf_output_loaded, pt_output in zip(tf_outputs_loaded, pt_outputs):
self.assert_almost_equals(tf_output_loaded.numpy(), pt_output.detach().to("cpu").numpy(), 1e-3)
def check_equivalence_pt_to_tf(self, config, decoder_config, inputs_dict):
def check_pt_to_tf_equivalence(self, config, decoder_config, tf_inputs_dict):
"""EncoderDecoderModel requires special way to cross load (PT -> TF)"""
encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
# Output all for aggressive testing
encoder_decoder_config.output_hidden_states = True
# All models tested in this file have attentions
encoder_decoder_config.output_attentions = True
pt_model = VisionEncoderDecoderModel(encoder_decoder_config)
@@ -376,11 +489,16 @@ class TFVisionEncoderDecoderMixin:
# This is only for copying some specific attributes of this particular model.
tf_model.config = pt_model.config
self.check_pt_tf_equivalence(pt_model, tf_model, inputs_dict)
self.check_pt_tf_equivalence(tf_model, pt_model, tf_inputs_dict)
def check_equivalence_tf_to_pt(self, config, decoder_config, inputs_dict):
def check_tf_to_pt_equivalence(self, config, decoder_config, tf_inputs_dict):
"""EncoderDecoderModel requires special way to cross load (TF -> PT)"""
encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
# Output all for aggressive testing
encoder_decoder_config.output_hidden_states = True
# TODO: A generalizable way to determine this attribute
encoder_decoder_config.output_attentions = True
# Using `_tf_model`, the test will fail, because the weights of `_tf_model` get extended before saving
# the encoder/decoder models.
@@ -389,7 +507,7 @@ class TFVisionEncoderDecoderMixin:
# (the change in `src/transformers/modeling_tf_utils.py`)
_tf_model = TFVisionEncoderDecoderModel(encoder_decoder_config)
# Make sure model is built
_tf_model(**inputs_dict)
_tf_model(**tf_inputs_dict)
# Using `tf_model` to pass the test.
encoder = _tf_model.encoder.__class__(encoder_decoder_config.encoder)
@@ -398,6 +516,7 @@ class TFVisionEncoderDecoderMixin:
encoder(encoder.dummy_inputs)
decoder(decoder.dummy_inputs)
tf_model = TFVisionEncoderDecoderModel(encoder=encoder, decoder=decoder)
tf_model.config = encoder_decoder_config
with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname:
@@ -409,7 +528,7 @@ class TFVisionEncoderDecoderMixin:
# This is only for copying some specific attributes of this particular model.
pt_model.config = tf_model.config
self.check_pt_tf_equivalence(pt_model, tf_model, inputs_dict)
self.check_pt_tf_equivalence(tf_model, pt_model, tf_inputs_dict)
def test_encoder_decoder_model(self):
config_inputs_dict = self.prepare_config_and_inputs()
@@ -448,7 +567,7 @@ class TFVisionEncoderDecoderMixin:
self.assertLessEqual(diff, tol, f"Difference between torch and tf is {diff} (>= {tol}).")
@is_pt_tf_cross_test
def test_pt_tf_equivalence(self):
def test_pt_tf_model_equivalence(self):
config_inputs_dict = self.prepare_config_and_inputs()
labels = config_inputs_dict.pop("decoder_token_labels")
@@ -467,48 +586,58 @@ class TFVisionEncoderDecoderMixin:
config = config_inputs_dict.pop("config")
decoder_config = config_inputs_dict.pop("decoder_config")
inputs_dict = config_inputs_dict
# Output all for aggressive testing
config.output_hidden_states = True
decoder_config.output_hidden_states = True
# All models tested in this file have attentions
config.output_attentions = True
decoder_config.output_attentions = True
tf_inputs_dict = config_inputs_dict
# `encoder_hidden_states` is not used in model call/forward
del inputs_dict["encoder_hidden_states"]
del tf_inputs_dict["encoder_hidden_states"]
inputs_dict_with_labels = copy.copy(inputs_dict)
inputs_dict_with_labels["labels"] = labels
# Make sure no sequence has an all-zero attention mask; otherwise some tests fail due to the inconsistent
# use of `1e-4`, `1e-9`, `1e-30`, and `-inf` as masking values.
for k in ["decoder_attention_mask"]:
attention_mask = tf_inputs_dict[k]
# Avoid the case where a sequence has no place to attend (after being combined with the causal attention mask)
batch_size = inputs_dict["decoder_attention_mask"].shape[0]
inputs_dict["decoder_attention_mask"] = tf.constant(
np.concatenate([np.ones(shape=(batch_size, 1)), inputs_dict["decoder_attention_mask"][:, 1:]], axis=1)
)
# Make sure there are no all-zero attention masks, to avoid failures for now.
# Put a `1` at the beginning of each sequence so it still works when combined with the causal attention mask.
# TODO: remove this once a fix for large negative attention-mask values lands.
attention_mask = tf.concat(
[tf.ones_like(attention_mask[:, :1], dtype=attention_mask.dtype), attention_mask[:, 1:]], axis=-1
)
tf_inputs_dict[k] = attention_mask
# TF models don't use the `use_cache` option, and the cache is not returned by default,
# so we disable `use_cache` here for the PyTorch model.
decoder_config.use_cache = False
tf_inputs_dict_with_labels = copy.copy(tf_inputs_dict)
tf_inputs_dict_with_labels["labels"] = labels
self.assertTrue(decoder_config.cross_attention_hidden_size is None)
# check without `enc_to_dec_proj` projection
# Original test: check without `labels` and without `enc_to_dec_proj` projection
self.assertTrue(config.hidden_size == decoder_config.hidden_size)
self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict)
self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict)
self.check_pt_to_tf_equivalence(config, decoder_config, tf_inputs_dict)
self.check_tf_to_pt_equivalence(config, decoder_config, tf_inputs_dict)
# check equivalence with labels
self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict_with_labels)
self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict_with_labels)
# check with `labels`
self.check_pt_to_tf_equivalence(config, decoder_config, tf_inputs_dict_with_labels)
self.check_tf_to_pt_equivalence(config, decoder_config, tf_inputs_dict_with_labels)
# This is not working, because the PT/TF equivalence test for encoder-decoder uses `from_encoder_decoder_pretrained`,
# which randomly initializes `enc_to_dec_proj`.
# # check `enc_to_dec_proj` works as expected
# check `enc_to_dec_proj` works as expected
# decoder_config.hidden_size = decoder_config.hidden_size * 2
# self.assertTrue(config.hidden_size != decoder_config.hidden_size)
# self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict)
# self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict)
# self.check_pt_to_tf_equivalence(config, decoder_config, tf_inputs_dict)
# self.check_tf_to_pt_equivalence(config, decoder_config, tf_inputs_dict)
# Let's just check `enc_to_dec_proj` can run for now
decoder_config.hidden_size = decoder_config.hidden_size * 2
self.assertTrue(config.hidden_size != decoder_config.hidden_size)
encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
model = TFVisionEncoderDecoderModel(encoder_decoder_config)
model(**inputs_dict)
model(tf_inputs_dict)
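A hedged sketch of that smoke test with explicit tiny configs (illustrative sizes; `from_encoder_decoder_configs` marks the decoder and enables cross-attention, and the mismatched hidden sizes trigger `enc_to_dec_proj`):

import tensorflow as tf
from transformers import GPT2Config, TFVisionEncoderDecoderModel, ViTConfig, VisionEncoderDecoderConfig

enc_cfg = ViTConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2,
                    intermediate_size=37, image_size=30, patch_size=2)
dec_cfg = GPT2Config(n_embd=64, n_layer=2, n_head=2)  # 64 != 32 -> projection is created
config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(enc_cfg, dec_cfg)
model = TFVisionEncoderDecoderModel(config)
# Forward pass only: checks that `enc_to_dec_proj` runs, not that it is PT/TF equivalent.
outputs = model(
    pixel_values=tf.random.uniform((1, 3, 30, 30)),
    decoder_input_ids=tf.constant([[1, 2]]),
)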
@slow
def test_real_model_save_load_from_pretrained(self):

File: tests/test_modeling_common.py

@@ -1616,7 +1616,7 @@ class ModelTesterMixin:
# Copied from tests.test_modeling_tf_common.TFModelTesterMixin.check_pt_tf_outputs
def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None):
"""Check the outputs from PyTorch and TensorFlow models are closed enough. Checks are done in a recursive way.
"""Check the outputs from PyTorch and TensorFlow models are close enough. Checks are done in a recursive way.
Args:
model_class: The class of the model that is currently testing. For example, `TFBertModel`,
@@ -1642,8 +1642,8 @@ class ModelTesterMixin:
# TODO: remove this method and this line after issues are fixed
tf_outputs, pt_outputs = self._postprocessing_to_ignore_test_cases(tf_outputs, pt_outputs, model_class)
tf_keys = tuple([k for k, v in tf_outputs.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
tf_keys = [k for k, v in tf_outputs.items() if v is not None]
pt_keys = [k for k, v in pt_outputs.items() if v is not None]
self.assertEqual(tf_keys, pt_keys, f"{name}: Output keys differ between TF and PyTorch")
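Why the keys are collected dynamically (a small illustration; names are real `transformers` classes, the values illustrative): `ModelOutput` fields left as `None` simply do not show up, so the comparison only covers outputs both frameworks actually returned.

import torch
from transformers.modeling_outputs import BaseModelOutput

out = BaseModelOutput(last_hidden_state=torch.zeros(1, 2, 4))
# `hidden_states` and `attentions` were not set, so only one key survives.
print([k for k, v in out.items() if v is not None])  # ['last_hidden_state']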

File: tests/test_modeling_tf_common.py

@@ -428,7 +428,7 @@ class TFModelTesterMixin:
return new_tf_outputs, new_pt_outputs
def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None):
"""Check the outputs from PyTorch and TensorFlow models are closed enough. Checks are done in a recursive way.
"""Check the outputs from PyTorch and TensorFlow models are close enough. Checks are done in a recursive way.
Args:
model_class: The class of the model that is currently testing. For example, `TFBertModel`,
@@ -454,8 +454,8 @@ class TFModelTesterMixin:
# TODO: remove this method and this line after issues are fixed
tf_outputs, pt_outputs = self._postprocessing_to_ignore_test_cases(tf_outputs, pt_outputs, model_class)
tf_keys = tuple([k for k, v in tf_outputs.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
tf_keys = [k for k, v in tf_outputs.items() if v is not None]
pt_keys = [k for k, v in pt_outputs.items() if v is not None]
self.assertEqual(tf_keys, pt_keys, f"{name}: Output keys differ between TF and PyTorch")