From dce33f2150769825ca175df3209441122f85a814 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Mon, 11 Apr 2022 22:19:12 +0200 Subject: [PATCH] Improve PT/TF equivalence test (#16557) * add error message * Use names in the error message * allow ModelOutput * rename to check_pt_tf_outputs and move outside * fix style * skip past_key_values in a better way * Add comments * improve code for label/loss * make the logic clear by moving the ignore keys out * fix _postprocessing_to_ignore * fix _postprocessing_to_ignore: create new outputs from the remaining fields * ignore past_key_values in TFGPT2 models for now * make check_pt_tf_outputs better regarding names * move check_pt_tf_models outside * rename methods * remove test_pt_tf_model_equivalence in TFCLIPModelTest * Reduce TFViTMAEModelTest.test_pt_tf_model_equivalence * move prepare_pt_inputs_from_tf_inputs outside check_pt_tf_models * Fix quality * Clean-up TFLxmertModelTester.test_pt_tf_model_equivalence * Fix quality * fix * fix style * Clean-up TFLEDModelTest.test_pt_tf_model_equivalence * Fix quality * add docstring * improve comment Co-authored-by: ydshieh --- tests/clip/test_modeling_tf_clip.py | 127 +------ tests/led/test_modeling_tf_led.py | 125 +------ tests/lxmert/test_modeling_tf_lxmert.py | 148 ++------- tests/test_modeling_tf_common.py | 384 ++++++++++++---------- tests/vit_mae/test_modeling_tf_vit_mae.py | 134 +------- 5 files changed, 234 insertions(+), 684 deletions(-) diff --git a/tests/clip/test_modeling_tf_clip.py b/tests/clip/test_modeling_tf_clip.py index d3c3cb9f50..7ba9352406 100644 --- a/tests/clip/test_modeling_tf_clip.py +++ b/tests/clip/test_modeling_tf_clip.py @@ -23,7 +23,7 @@ from importlib import import_module import requests from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig -from transformers.testing_utils import is_pt_tf_cross_test, require_tf, require_vision, slow +from transformers.testing_utils import require_tf, require_vision, slow from transformers.utils import is_tf_available, is_vision_available from ..test_configuration_common import ConfigTester @@ -31,7 +31,6 @@ from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_ten if is_tf_available(): - import numpy as np import tensorflow as tf from transformers import TFCLIPModel, TFCLIPTextModel, TFCLIPVisionModel, TFSharedEmbeddings @@ -497,130 +496,6 @@ class TFCLIPModelTest(TFModelTesterMixin, unittest.TestCase): after_outputs = model(inputs_dict) self.assert_outputs_same(after_outputs, outputs) - # overwrite from common since CLIPModel/TFCLIPModel return CLIPOutput/TFCLIPOutput - @is_pt_tf_cross_test - def test_pt_tf_model_equivalence(self): - import torch - - import transformers - - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - pt_model_class_name = model_class.__name__[2:] # Skip the "TF" at the beginning - pt_model_class = getattr(transformers, pt_model_class_name) - - config.output_hidden_states = True - - tf_model = model_class(config) - pt_model = pt_model_class(config) - - # Check we can load pt model in tf and vice-versa with model => model functions - - tf_model = transformers.load_pytorch_model_in_tf2_model( - tf_model, pt_model, tf_inputs=self._prepare_for_class(inputs_dict, model_class) - ) - pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model) - - # Check predictions on first output (logits/hidden-states) are close enought given low-level computational 
differences - pt_model.eval() - pt_inputs_dict = {} - for name, key in self._prepare_for_class(inputs_dict, model_class).items(): - if type(key) == bool: - pt_inputs_dict[name] = key - elif name == "input_values": - pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32) - elif name == "pixel_values": - pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32) - else: - pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long) - - # need to rename encoder-decoder "inputs" for PyTorch - if "inputs" in pt_inputs_dict and self.is_encoder_decoder: - pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs") - - with torch.no_grad(): - pto = pt_model(**pt_inputs_dict) - tfo = tf_model(self._prepare_for_class(inputs_dict, model_class), training=False) - - self.assertEqual(len(tfo), len(pto), "Output lengths differ between TF and PyTorch") - for tf_output, pt_output in zip(tfo.to_tuple(), pto.to_tuple()): - - if not (isinstance(tf_output, tf.Tensor) and isinstance(pt_output, torch.Tensor)): - continue - - tf_out = tf_output.numpy() - pt_out = pt_output.numpy() - - self.assertEqual(tf_out.shape, pt_out.shape, "Output component shapes differ between TF and PyTorch") - - if len(tf_out.shape) > 0: - - tf_nans = np.copy(np.isnan(tf_out)) - pt_nans = np.copy(np.isnan(pt_out)) - - pt_out[tf_nans] = 0 - tf_out[tf_nans] = 0 - pt_out[pt_nans] = 0 - tf_out[pt_nans] = 0 - - max_diff = np.amax(np.abs(tf_out - pt_out)) - self.assertLessEqual(max_diff, 4e-2) - - # Check we can load pt model in tf and vice-versa with checkpoint => model functions - with tempfile.TemporaryDirectory() as tmpdirname: - pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin") - torch.save(pt_model.state_dict(), pt_checkpoint_path) - tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path) - - tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5") - tf_model.save_weights(tf_checkpoint_path) - pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path) - - # Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences - pt_model.eval() - pt_inputs_dict = {} - for name, key in self._prepare_for_class(inputs_dict, model_class).items(): - if type(key) == bool: - key = np.array(key, dtype=bool) - pt_inputs_dict[name] = torch.from_numpy(key).to(torch.long) - elif name == "input_values": - pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32) - elif name == "pixel_values": - pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32) - else: - pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long) - # need to rename encoder-decoder "inputs" for PyTorch - if "inputs" in pt_inputs_dict and self.is_encoder_decoder: - pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs") - - with torch.no_grad(): - pto = pt_model(**pt_inputs_dict) - tfo = tf_model(self._prepare_for_class(inputs_dict, model_class)) - - self.assertEqual(len(tfo), len(pto), "Output lengths differ between TF and PyTorch") - for tf_output, pt_output in zip(tfo.to_tuple(), pto.to_tuple()): - - if not (isinstance(tf_output, tf.Tensor) and isinstance(pt_output, torch.Tensor)): - continue - - tf_out = tf_output.numpy() - pt_out = pt_output.numpy() - - self.assertEqual(tf_out.shape, pt_out.shape, "Output component shapes differ between TF and PyTorch") - - if len(tf_out.shape) > 0: - tf_nans = np.copy(np.isnan(tf_out)) - pt_nans = np.copy(np.isnan(pt_out)) - - pt_out[tf_nans] 
= 0 - tf_out[tf_nans] = 0 - pt_out[pt_nans] = 0 - tf_out[pt_nans] = 0 - - max_diff = np.amax(np.abs(tf_out - pt_out)) - self.assertLessEqual(max_diff, 4e-2) - @slow def test_model_from_pretrained(self): for model_name in TF_CLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: diff --git a/tests/led/test_modeling_tf_led.py b/tests/led/test_modeling_tf_led.py index cb75ddf8c3..df115010f3 100644 --- a/tests/led/test_modeling_tf_led.py +++ b/tests/led/test_modeling_tf_led.py @@ -17,14 +17,13 @@ import unittest from transformers import LEDConfig, is_tf_available -from transformers.testing_utils import is_pt_tf_cross_test, require_tf, slow +from transformers.testing_utils import require_tf, slow from ..test_configuration_common import ConfigTester from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor if is_tf_available(): - import numpy as np import tensorflow as tf from transformers import TFLEDForConditionalGeneration, TFLEDModel @@ -362,128 +361,6 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase): self.assertEqual(model.config.output_hidden_states, True) check_encoder_attentions_output(outputs) - # TODO: Remove this once a more thorough pt/tf equivalence could be implemented in `test_modeling_tf_common.py`. - # (Currently, such a test will fail some other model tests: it requires some time to fix them.) - @is_pt_tf_cross_test - def test_pt_tf_model_equivalence_extra(self): - import torch - - import transformers - - def prepare_pt_inputs_from_tf_inputs(tf_inputs_dict): - - pt_inputs_dict = {} - for name, key in tf_inputs_dict.items(): - if type(key) == bool: - pt_inputs_dict[name] = key - elif name == "input_values": - pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32) - elif name == "pixel_values": - pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32) - else: - pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long) - - return pt_inputs_dict - - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - pt_model_class_name = model_class.__name__[2:] # Skip the "TF" at the beginning - pt_model_class = getattr(transformers, pt_model_class_name) - - config.output_hidden_states = True - - tf_model = model_class(config) - pt_model = pt_model_class(config) - - tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class) - tf_inputs_dict_maybe_with_labels = self._prepare_for_class(inputs_dict, model_class, return_labels=True) - - # Check we can load pt model in tf and vice-versa with model => model functions - - tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict) - pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model) - - # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences - pt_model.eval() - - pt_inputs_dict = prepare_pt_inputs_from_tf_inputs(tf_inputs_dict) - pt_inputs_dict_maybe_with_labels = prepare_pt_inputs_from_tf_inputs(tf_inputs_dict_maybe_with_labels) - - # need to rename encoder-decoder "inputs" for PyTorch - if "inputs" in pt_inputs_dict and self.is_encoder_decoder: - pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs") - - with torch.no_grad(): - pto = pt_model(**pt_inputs_dict) - tfo = tf_model(tf_inputs_dict, training=False) - - tf_hidden_states = tfo[0].numpy() - pt_hidden_states = pto[0].numpy() - - tf_nans = np.isnan(tf_hidden_states) - pt_nans = np.isnan(pt_hidden_states) - - pt_hidden_states[tf_nans] 
= 0
-            tf_hidden_states[tf_nans] = 0
-            pt_hidden_states[pt_nans] = 0
-            tf_hidden_states[pt_nans] = 0
-
-            max_diff = np.amax(np.abs(tf_hidden_states - pt_hidden_states))
-            self.assertLessEqual(max_diff, 1e-4)
-
-            has_labels = any(
-                x in tf_inputs_dict_maybe_with_labels for x in ["labels", "next_sentence_label", "start_positions"]
-            )
-            if has_labels:
-
-                with torch.no_grad():
-                    pto = pt_model(**pt_inputs_dict_maybe_with_labels)
-                tfo = tf_model(tf_inputs_dict_maybe_with_labels, training=False)
-
-                # Some models' output class don't have `loss` attribute despite `labels` is used.
-                tf_loss = getattr(tfo, "loss", None)
-                pt_loss = getattr(pto, "loss", None)
-
-                # Some models require extra condition to return loss. For example, `BertForPreTraining` requires both
-                # `labels` and `next_sentence_label`.
-                # Moreover, some PT models return loss while the corresponding TF/Flax models don't.
-                if tf_loss is not None and pt_loss is not None:
-
-                    tf_loss = tf.math.reduce_mean(tf_loss).numpy()
-                    pt_loss = pt_loss.numpy()
-
-                    tf_nans = np.isnan(tf_loss)
-                    pt_nans = np.isnan(pt_loss)
-                    # the 2 losses need to be both nan or both not nan
-                    # (`TapasForQuestionAnswering` gives nan loss here)
-                    self.assertEqual(tf_nans, pt_nans)
-
-                    if not tf_nans:
-                        max_diff = np.amax(np.abs(tf_loss - pt_loss))
-                        # `TFFunnelForTokenClassification` (and potentially other TF token classification models) give
-                        # large difference (up to 0.1x). PR #15294 addresses this issue.
-                        # There is also an inconsistency between PT/TF `XLNetLMHeadModel`.
-                        # Before these issues are fixed & merged, set a higher threshold here to pass the test.
-                        self.assertLessEqual(max_diff, 1e-4)
-
-                tf_logits = tfo[1].numpy()
-                pt_logits = pto[1].numpy()
-
-                # check on the shape
-                self.assertEqual(tf_logits.shape, pt_logits.shape)
-
-                tf_nans = np.isnan(tf_logits)
-                pt_nans = np.isnan(pt_logits)
-
-                pt_logits[tf_nans] = 0
-                tf_logits[tf_nans] = 0
-                pt_logits[pt_nans] = 0
-                tf_logits[pt_nans] = 0
-
-                max_diff = np.amax(np.abs(tf_logits - pt_logits))
-                self.assertLessEqual(max_diff, 1e-4)
-
     def test_xla_mode(self):
         # TODO JP: Make LED XLA compliant
         pass
diff --git a/tests/lxmert/test_modeling_tf_lxmert.py b/tests/lxmert/test_modeling_tf_lxmert.py
index 63ec44a1ad..19226545a9 100644
--- a/tests/lxmert/test_modeling_tf_lxmert.py
+++ b/tests/lxmert/test_modeling_tf_lxmert.py
@@ -272,6 +272,8 @@ class TFLxmertModelTester(object):
 
         if return_obj_labels:
             inputs_dict["obj_labels"] = obj_labels
+        else:
+            config.task_obj_predict = False
 
         return config, inputs_dict
 
@@ -486,135 +488,31 @@ class TFLxmertModelTest(TFModelTesterMixin, unittest.TestCase):
                 config.output_hidden_states = True
             check_hidden_states_output(config, inputs_dict, model_class)
 
-    def test_pt_tf_model_equivalence(self):
-        from transformers import is_torch_available
-
-        if not is_torch_available():
-            return
-
+    def prepare_pt_inputs_from_tf_inputs(self, tf_inputs_dict):
         import torch
 
-        import transformers
+        pt_inputs_dict = {}
+        for key, value in tf_inputs_dict.items():
 
-        for model_class in self.all_model_classes:
-            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common(
-                return_obj_labels="PreTraining" in model_class.__name__
-            )
+            if isinstance(value, dict):
+                pt_inputs_dict[key] = self.prepare_pt_inputs_from_tf_inputs(value)
+            elif isinstance(value, (list, tuple)):
+                pt_inputs_dict[key] = (self.prepare_pt_inputs_from_tf_inputs(iter_value) for iter_value in value)
+            elif type(value) == bool:
+                pt_inputs_dict[key] = value
+            elif key == "input_values":
+                pt_inputs_dict[key] = 
torch.from_numpy(value.numpy()).to(torch.float32) + elif key == "pixel_values": + pt_inputs_dict[key] = torch.from_numpy(value.numpy()).to(torch.float32) + elif key == "input_features": + pt_inputs_dict[key] = torch.from_numpy(value.numpy()).to(torch.float32) + # other general float inputs + elif tf_inputs_dict[key].dtype.is_floating: + pt_inputs_dict[key] = torch.from_numpy(value.numpy()).to(torch.float32) + else: + pt_inputs_dict[key] = torch.from_numpy(value.numpy()).to(torch.long) - pt_model_class_name = model_class.__name__[2:] # Skip the "TF" at the beginning - pt_model_class = getattr(transformers, pt_model_class_name) - - config.output_hidden_states = True - config.task_obj_predict = False - - tf_model = model_class(config) - pt_model = pt_model_class(config) - - # Check we can load pt model in tf and vice-versa with model => model functions - - tf_model = transformers.load_pytorch_model_in_tf2_model( - tf_model, pt_model, tf_inputs=self._prepare_for_class(inputs_dict, model_class) - ) - pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model) - - # Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences - pt_model.eval() - - # Delete obj labels as we want to compute the hidden states and not the loss - - if "obj_labels" in inputs_dict: - del inputs_dict["obj_labels"] - - def torch_type(key): - if key in ("visual_feats", "visual_pos"): - return torch.float32 - else: - return torch.long - - def recursive_numpy_convert(iterable): - return_dict = {} - for key, value in iterable.items(): - if isinstance(value, dict): - return_dict[key] = recursive_numpy_convert(value) - else: - if isinstance(value, (list, tuple)): - return_dict[key] = ( - torch.from_numpy(iter_value.numpy()).to(torch_type(key)) for iter_value in value - ) - else: - return_dict[key] = torch.from_numpy(value.numpy()).to(torch_type(key)) - return return_dict - - pt_inputs_dict = recursive_numpy_convert(self._prepare_for_class(inputs_dict, model_class)) - - # need to rename encoder-decoder "inputs" for PyTorch - if "inputs" in pt_inputs_dict and self.is_encoder_decoder: - pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs") - - with torch.no_grad(): - pto = pt_model(**pt_inputs_dict) - tfo = tf_model(self._prepare_for_class(inputs_dict, model_class), training=False) - tf_hidden_states = tfo[0].numpy() - pt_hidden_states = pto[0].numpy() - - tf_nans = np.copy(np.isnan(tf_hidden_states)) - pt_nans = np.copy(np.isnan(pt_hidden_states)) - - pt_hidden_states[tf_nans] = 0 - tf_hidden_states[tf_nans] = 0 - pt_hidden_states[pt_nans] = 0 - tf_hidden_states[pt_nans] = 0 - - max_diff = np.amax(np.abs(tf_hidden_states - pt_hidden_states)) - # Debug info (remove when fixed) - if max_diff >= 2e-2: - print("===") - print(model_class) - print(config) - print(inputs_dict) - print(pt_inputs_dict) - self.assertLessEqual(max_diff, 6e-2) - - # Check we can load pt model in tf and vice-versa with checkpoint => model functions - with tempfile.TemporaryDirectory() as tmpdirname: - import os - - pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin") - torch.save(pt_model.state_dict(), pt_checkpoint_path) - tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path) - - tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5") - tf_model.save_weights(tf_checkpoint_path) - pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path) - - # Check predictions on first output (logits/hidden-states) 
are close enought given low-level computational differences - pt_model.eval() - pt_inputs_dict = dict( - (name, torch.from_numpy(key.numpy()).to(torch.long)) - for name, key in self._prepare_for_class(inputs_dict, model_class).items() - ) - - for key, value in pt_inputs_dict.items(): - if key in ("visual_feats", "visual_pos"): - pt_inputs_dict[key] = value.to(torch.float32) - else: - pt_inputs_dict[key] = value.to(torch.long) - - with torch.no_grad(): - pto = pt_model(**pt_inputs_dict) - tfo = tf_model(self._prepare_for_class(inputs_dict, model_class)) - tfo = tfo[0].numpy() - pto = pto[0].numpy() - tf_nans = np.copy(np.isnan(tfo)) - pt_nans = np.copy(np.isnan(pto)) - - pto[tf_nans] = 0 - tfo[tf_nans] = 0 - pto[pt_nans] = 0 - tfo[pt_nans] = 0 - - max_diff = np.amax(np.abs(tfo - pto)) - self.assertLessEqual(max_diff, 6e-2) + return pt_inputs_dict def test_save_load(self): for model_class in self.all_model_classes: diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index bbb0befeb6..195c7daa84 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -27,7 +27,7 @@ from typing import List, Tuple from huggingface_hub import delete_repo, login from requests.exceptions import HTTPError -from transformers import is_tf_available +from transformers import is_tf_available, is_torch_available from transformers.configuration_utils import PretrainedConfig from transformers.models.auto import get_values from transformers.testing_utils import tooslow # noqa: F401 @@ -44,6 +44,7 @@ from transformers.testing_utils import ( torch_device, ) from transformers.utils import logging +from transformers.utils.generic import ModelOutput logger = logging.get_logger(__name__) @@ -98,6 +99,9 @@ if is_tf_available(): # Virtual devices must be set before GPUs have been initialized print(e) +if is_torch_available(): + import torch + def _config_zero_init(config): configs_no_init = copy.deepcopy(config) @@ -350,192 +354,210 @@ class TFModelTesterMixin: max_diff = np.amax(np.abs(out_1 - out_2)) self.assertLessEqual(max_diff, 1e-5) - @is_pt_tf_cross_test - def test_pt_tf_model_equivalence(self): - import torch + # Don't copy this method to model specific test file! + # TODO: remove this method once the issues are all fixed! + def _make_attention_mask_non_null(self, inputs_dict): + """Make sure no sequence has all zeros as attention mask""" - import transformers + for k in ["attention_mask", "encoder_attention_mask", "decoder_attention_mask"]: + if k in inputs_dict: + attention_mask = inputs_dict[k] - def prepare_pt_inputs_from_tf_inputs(tf_inputs_dict): - - pt_inputs_dict = {} - for name, key in tf_inputs_dict.items(): - if type(key) == bool: - pt_inputs_dict[name] = key - elif name == "input_values": - pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32) - elif name == "pixel_values": - pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32) - elif name == "input_features": - pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32) - else: - pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long) - - return pt_inputs_dict - - def check_outputs(tf_outputs, pt_outputs, model_class, names): - """ - Args: - model_class: The class of the model that is currently testing. For example, `TFBertModel`, - TFBertForMaskedLM`, `TFBertForSequenceClassification`, etc. Currently unused, but it could make - debugging easier and faster. - - names: A string, or a tuple of strings. 
These specify what tf_outputs/pt_outputs represent in the model outputs. - Currently unused, but in the future, we could use this information to make the error message clearer - by giving the name(s) of the output tensor(s) with large difference(s) between PT and TF. - """ - - # Some issue (`about past_key_values`) to solve (e.g. `TFPegasusForConditionalGeneration`) in a separate PR. - if names == "past_key_values": - return - - # Allow `list` because `(TF)TransfoXLModelOutput.mems` is a list of tensors. - if type(tf_outputs) in [tuple, list]: - self.assertEqual(type(tf_outputs), type(pt_outputs)) - self.assertEqual(len(tf_outputs), len(pt_outputs)) - if type(names) == tuple: - for tf_output, pt_output, name in zip(tf_outputs, pt_outputs, names): - check_outputs(tf_output, pt_output, model_class, names=name) - elif type(names) == str: - for idx, (tf_output, pt_output) in enumerate(zip(tf_outputs, pt_outputs)): - check_outputs(tf_output, pt_output, model_class, names=f"{names}_{idx}") - else: - raise ValueError(f"`names` should be a `tuple` or a string. Got {type(names)} instead.") - elif isinstance(tf_outputs, tf.Tensor): - self.assertTrue(isinstance(pt_outputs, torch.Tensor)) - - tf_outputs = tf_outputs.numpy() - pt_outputs = pt_outputs.detach().to("cpu").numpy() - - tf_nans = np.isnan(tf_outputs) - pt_nans = np.isnan(pt_outputs) - - pt_outputs[tf_nans] = 0 - tf_outputs[tf_nans] = 0 - pt_outputs[pt_nans] = 0 - tf_outputs[pt_nans] = 0 - - max_diff = np.amax(np.abs(tf_outputs - pt_outputs)) - self.assertLessEqual(max_diff, 1e-5) - else: - raise ValueError( - f"`tf_outputs` should be a `tuple` or an instance of `tf.Tensor`. Got {type(tf_outputs)} instead." + # Make sure no all 0s attention masks - to avoid failure at this moment. + # Put `1` at the beginning of sequences to make it still work when combining causal attention masks. + # TODO: remove this line once a fix regarding large negative values for attention mask is done. + attention_mask = tf.concat( + [tf.ones_like(attention_mask[:, :1], dtype=attention_mask.dtype), attention_mask[:, 1:]], axis=-1 ) - def check_pt_tf_models(tf_model, pt_model): + # Here we make the first sequence with all 0s as attention mask. + # Currently, this will fail for `TFWav2Vec2Model`. This is caused by the different large negative + # values, like `1e-4`, `1e-9`, `1e-30` and `-inf` for attention mask across models/frameworks. + # TODO: enable this block once the large negative values thing is cleaned up. + # (see https://github.com/huggingface/transformers/issues/14859) + # attention_mask = tf.concat( + # [ + # tf.zeros_like(attention_mask[:1], dtype=tf.int32), + # tf.cast(attention_mask[1:], dtype=tf.int32) + # ], + # axis=0 + # ) - # send pytorch model to the correct device - pt_model.to(torch_device) + inputs_dict[k] = attention_mask - # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences - pt_model.eval() + # Don't copy this method to model specific test file! + # TODO: remove this method once the issues are all fixed! 
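+    # Removes output keys whose values are known to mismatch between the PT and TF implementations of a few
+    # models (e.g. `loss`/`losses` for some LM-head models, `past_key_values` for TFGPT2), then rebuilds the
+    # outputs from the remaining fields.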
+    def _postprocessing_to_ignore_test_cases(self, tf_outputs, pt_outputs, model_class):
+        """For temporarily ignoring some failed test cases (issues to be fixed)"""
 
-            pt_inputs_dict = prepare_pt_inputs_from_tf_inputs(tf_inputs_dict)
-            pt_inputs_dict_maybe_with_labels = prepare_pt_inputs_from_tf_inputs(tf_inputs_dict_maybe_with_labels)
+        tf_keys = set([k for k, v in tf_outputs.items() if v is not None])
+        pt_keys = set([k for k, v in pt_outputs.items() if v is not None])
 
-            # send pytorch inputs to the correct device
-            pt_inputs_dict = {
-                k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs_dict.items()
-            }
-            pt_inputs_dict_maybe_with_labels = {
-                k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v
-                for k, v in pt_inputs_dict_maybe_with_labels.items()
-            }
+        key_differences = tf_keys.symmetric_difference(pt_keys)
 
-            # Original test: check without `labels`
-            with torch.no_grad():
-                pt_outputs = pt_model(**pt_inputs_dict)
-            tf_outputs = tf_model(tf_inputs_dict)
+        if model_class.__name__ in [
+            "TFFlaubertWithLMHeadModel",
+            "TFFunnelForPreTraining",
+            "TFElectraForPreTraining",
+            "TFXLMWithLMHeadModel",
+            "TFTransfoXLLMHeadModel",
+        ]:
+            for k in key_differences:
+                if k in ["loss", "losses"]:
+                    tf_keys.discard(k)
+                    pt_keys.discard(k)
+        elif model_class.__name__.startswith("TFGPT2"):
+            # `TFGPT2` has `past_key_values` as a tensor while `GPT2` has it as a tuple.
+            tf_keys.discard("past_key_values")
+            pt_keys.discard("past_key_values")
+
+        # create new outputs from the remaining fields
+        new_tf_outputs = type(tf_outputs)(**{k: tf_outputs[k] for k in tf_keys})
+        new_pt_outputs = type(pt_outputs)(**{k: pt_outputs[k] for k in pt_keys})
+
+        return new_tf_outputs, new_pt_outputs
+
+    def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None):
+        """Check that the outputs from PyTorch and TensorFlow models are close enough. Checks are done in a recursive way.
+
+        Args:
+            model_class: The class of the model that is currently being tested. For example, `TFBertModel`,
+                `TFBertForMaskedLM`, `TFBertForSequenceClassification`, etc. Mainly used for providing more informative
+                error messages.
+            name (`str`): The name of the output. For example, `output.hidden_states`, `output.attentions`, etc.
+            attributes (`Tuple[str]`): The names of the output's elements if the output is a tuple/list, with each element
+                being a named field in the output.
+        """
+
+        self.assertEqual(type(name), str)
+        if attributes is not None:
+            self.assertEqual(type(attributes), tuple, f"{name}: The argument `attributes` should be a `tuple`")
+
+        # Allow `ModelOutput` (e.g. `CLIPOutput` has `text_model_output` and `vision_model_output`).
+        if isinstance(tf_outputs, ModelOutput):
+            self.assertTrue(
+                isinstance(pt_outputs, ModelOutput),
+                f"{name}: `pt_outputs` should be an instance of `ModelOutput` when `tf_outputs` is",
+            )
+
+            # Don't copy this block to model specific test file!
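+            # After this clean-up, the key sets of the two outputs must match exactly; any remaining
+            # difference is reported as a genuine PT/TF mismatch.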
+            # TODO: remove this method and this line after issues are fixed
+            tf_outputs, pt_outputs = self._postprocessing_to_ignore_test_cases(tf_outputs, pt_outputs, model_class)
 
             tf_keys = tuple([k for k, v in tf_outputs.items() if v is not None])
             pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
 
-            self.assertEqual(tf_keys, pt_keys)
-            check_outputs(tf_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=tf_keys)
+            self.assertEqual(tf_keys, pt_keys, f"{name}: Output keys differ between TF and PyTorch")
 
-            # check the case where `labels` is passed
-            has_labels = any(
-                x in tf_inputs_dict_maybe_with_labels for x in ["labels", "next_sentence_label", "start_positions"]
+            # convert to the case of `tuple`
+            # appending each key to the current (string) `name`
+            attributes = tuple([f"{name}.{k}" for k in tf_keys])
+            self.check_pt_tf_outputs(
+                tf_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, tol=tol, name=name, attributes=attributes
             )
-            if has_labels:
-                with torch.no_grad():
-                    pt_outputs = pt_model(**pt_inputs_dict_maybe_with_labels)
-                tf_outputs = tf_model(tf_inputs_dict_maybe_with_labels)
+        # Allow `list` (e.g. `TransfoXLModelOutput.mems` is a list of tensors.)
+        elif type(tf_outputs) in [tuple, list]:
+            self.assertEqual(type(tf_outputs), type(pt_outputs), f"{name}: Output types differ between TF and PyTorch")
+            self.assertEqual(len(tf_outputs), len(pt_outputs), f"{name}: Output lengths differ between TF and PyTorch")
 
-                # Some models' output class don't have `loss` attribute despite `labels` is used.
-                # TODO: identify which models
-                tf_loss = getattr(tf_outputs, "loss", None)
-                pt_loss = getattr(pt_outputs, "loss", None)
+            if attributes is not None:
+                # case 1: each output has an assigned name (e.g. a tuple form of a `ModelOutput`)
+                self.assertEqual(
+                    len(attributes),
+                    len(tf_outputs),
+                    f"{name}: The tuple `attributes` should have the same length as `tf_outputs`",
+                )
+            else:
+                # case 2: each output has no assigned name (e.g. hidden states of each layer) -> add an index to `attributes`
+                attributes = tuple([f"{name}_{idx}" for idx in range(len(tf_outputs))])
 
-            # Some PT models return loss while the corresponding TF models don't (i.e. `None` for `loss`). 
-                # - TFFlaubertWithLMHeadModel
-                # - TFFunnelForPreTraining
-                # - TFElectraForPreTraining
-                # - TFXLMWithLMHeadModel
-                # TODO: Fix PT/TF diff -> remove this condition to fail the test if a diff occurs
-                if not ((tf_loss is None and pt_loss is None) or (tf_loss is not None and pt_loss is not None)):
-                    if model_class.__name__ not in [
-                        "TFFlaubertWithLMHeadModel",
-                        "TFFunnelForPreTraining",
-                        "TFElectraForPreTraining",
-                        "TFXLMWithLMHeadModel",
-                        "TFTransfoXLLMHeadModel",
-                    ]:
-                        self.assertEqual(tf_loss is None, pt_loss is None)
+            for tf_output, pt_output, attr in zip(tf_outputs, pt_outputs, attributes):
+                self.check_pt_tf_outputs(tf_output, pt_output, model_class, tol=tol, name=attr)
 
-                tf_keys = tuple([k for k, v in tf_outputs.items() if v is not None])
-                pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
+        elif isinstance(tf_outputs, tf.Tensor):
+            self.assertTrue(
+                isinstance(pt_outputs, torch.Tensor), f"{name}: `pt_outputs` should be a tensor when `tf_outputs` is"
+            )
 
-                # TODO: remove these 2 conditions once the above TODOs (above loss) are implemented
-                # (Also, `TFTransfoXLLMHeadModel` has no `loss` while `TransfoXLLMHeadModel` return `losses`)
-                if tf_keys != pt_keys:
-                    if model_class.__name__ not in [
-                        "TFFlaubertWithLMHeadModel",
-                        "TFFunnelForPreTraining",
-                        "TFElectraForPreTraining",
-                        "TFXLMWithLMHeadModel",
-                        "TFTransfoXLLMHeadModel",
-                    ]:
-                        self.assertEqual(tf_keys, pt_keys)
+            tf_outputs = tf_outputs.numpy()
+            pt_outputs = pt_outputs.detach().to("cpu").numpy()
 
-                # Since we deliberately make some tests pass above (regarding the `loss`), let's still try to test
-                # some remaining attributes in the outputs.
-                # TODO: remove this block of `index` computing once the above TODOs (above loss) are implemented
-                # compute the 1st `index` where `tf_keys` and `pt_keys` is different
-                index = 0
-                for _ in range(min(len(tf_keys), len(pt_keys))):
-                    if tf_keys[index] == pt_keys[index]:
-                        index += 1
-                    else:
-                        break
-                if tf_keys[:index] != pt_keys[:index]:
-                    self.assertEqual(tf_keys, pt_keys)
+            self.assertEqual(
+                tf_outputs.shape, pt_outputs.shape, f"{name}: Output shapes differ between TF and PyTorch"
+            )
 
-                # Some models require extra condition to return loss. For example, `(TF)BertForPreTraining` requires
-                # both`labels` and `next_sentence_label`.
-                if tf_loss is not None and pt_loss is not None:
+            # deal with NumPy's scalars to make replacing nan values by 0 work.
+            if np.isscalar(tf_outputs):
+                tf_outputs = np.array([tf_outputs])
+                pt_outputs = np.array([pt_outputs])
 
-                    # check anything else than `loss`
-                    keys = tuple([k for k in tf_keys])
-                    check_outputs(tf_outputs[1:index], pt_outputs[1:index], model_class, names=keys[1:index])
+            tf_nans = np.isnan(tf_outputs)
+            pt_nans = np.isnan(pt_outputs)
 
-                    # check `loss`
+            pt_outputs[tf_nans] = 0
+            tf_outputs[tf_nans] = 0
+            pt_outputs[pt_nans] = 0
+            tf_outputs[pt_nans] = 0
 
-                    # tf models returned loss is usually a tensor rather than a scalar.
-                    # (see `hf_compute_loss`: it uses `tf.keras.losses.Reduction.NONE`)
-                    # Change it here to a scalar to match PyTorch models' loss
-                    tf_loss = tf.math.reduce_mean(tf_loss).numpy()
-                    pt_loss = pt_loss.detach().to("cpu").numpy()
+            max_diff = np.amax(np.abs(tf_outputs - pt_outputs))
+            self.assertLessEqual(max_diff, tol, f"{name}: Difference between torch and tf is {max_diff} (>= {tol}).")
+        else:
+            raise ValueError(
+                f"`tf_outputs` should be an instance of `ModelOutput`, a `tuple`/`list`, or an instance of `tf.Tensor`. Got {type(tf_outputs)} instead." 
+ ) - tf_nans = np.isnan(tf_loss) - pt_nans = np.isnan(pt_loss) - # the 2 losses need to be both nan or both not nan - self.assertEqual(tf_nans, pt_nans) + def prepare_pt_inputs_from_tf_inputs(self, tf_inputs_dict): - if not tf_nans: - max_diff = np.amax(np.abs(tf_loss - pt_loss)) - self.assertLessEqual(max_diff, 1e-5) + pt_inputs_dict = {} + for name, key in tf_inputs_dict.items(): + if type(key) == bool: + pt_inputs_dict[name] = key + elif name == "input_values": + pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32) + elif name == "pixel_values": + pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32) + elif name == "input_features": + pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32) + # other general float inputs + elif tf_inputs_dict[name].dtype.is_floating: + pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32) + else: + pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long) + + return pt_inputs_dict + + def check_pt_tf_models(self, tf_model, pt_model, tf_inputs_dict): + + pt_inputs_dict = self.prepare_pt_inputs_from_tf_inputs(tf_inputs_dict) + + # send pytorch inputs to the correct device + pt_inputs_dict = { + k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs_dict.items() + } + + # send pytorch model to the correct device + pt_model.to(torch_device) + + # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences + pt_model.eval() + + with torch.no_grad(): + pt_outputs = pt_model(**pt_inputs_dict) + tf_outputs = tf_model(tf_inputs_dict) + + # tf models returned loss is usually a tensor rather than a scalar. + # (see `hf_compute_loss`: it uses `tf.keras.losses.Reduction.NONE`) + # Change it here to a scalar to match PyTorch models' loss + tf_loss = getattr(tf_outputs, "loss", None) + if tf_loss is not None: + tf_outputs.loss = tf.math.reduce_mean(tf_loss) + + self.check_pt_tf_outputs(tf_outputs, pt_outputs, type(tf_model)) + + @is_pt_tf_cross_test + def test_pt_tf_model_equivalence(self): + import transformers for model_class in self.all_model_classes: @@ -546,25 +568,10 @@ class TFModelTesterMixin: if self.has_attentions: config.output_attentions = True - for k in ["attention_mask", "encoder_attention_mask", "decoder_attention_mask"]: - if k in inputs_dict: - attention_mask = inputs_dict[k] - # make sure no all 0s attention masks - to avoid failure at this moment. - # TODO: remove this line once the TODO below is implemented. - attention_mask = tf.ones_like(attention_mask, dtype=tf.int32) - # Here we make the first sequence with all 0s as attention mask. - # Currently, this will fail for `TFWav2Vec2Model`. This is caused by the different large negative - # values, like `1e-4`, `1e-9`, `1e-30` and `-inf` for attention mask across models/frameworks. - # TODO: enable this block once the large negative values thing is cleaned up. - # (see https://github.com/huggingface/transformers/issues/14859) - # attention_mask = tf.concat( - # [ - # tf.zeros_like(attention_mask[:1], dtype=tf.int32), - # tf.cast(attention_mask[1:], dtype=tf.int32) - # ], - # axis=0 - # ) - inputs_dict[k] = attention_mask + # Make sure no sequence has all zeros as attention mask, otherwise some tests fail due to the inconsistency + # of the usage `1e-4`, `1e-9`, `1e-30`, `-inf`. + # TODO: Use a uniform value for all models, make sure all tests pass without this processing, and remove it. 
+            self._make_attention_mask_non_null(inputs_dict)
 
             pt_model_class_name = model_class.__name__[2:]  # Skip the "TF" at the beginning
             pt_model_class = getattr(transformers, pt_model_class_name)
@@ -573,18 +580,27 @@ class TFModelTesterMixin:
             pt_model = pt_model_class(config)
 
             tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
-            tf_inputs_dict_maybe_with_labels = self._prepare_for_class(
+            tf_inputs_dict_with_labels = self._prepare_for_class(
                 inputs_dict,
                 model_class,
                 # Not all models accept "labels" in the forward pass (yet :) )
                 return_labels=True if "labels" in inspect.signature(model_class.call).parameters.keys() else False,
             )
 
+            # For some models (e.g. base models), there is no label returned.
+            # Set the input dict to `None` to avoid checking the outputs twice for the same input dict.
+            if set(tf_inputs_dict_with_labels.keys()).symmetric_difference(tf_inputs_dict.keys()):
+                tf_inputs_dict_with_labels = None
+
             # Check we can load pt model in tf and vice-versa with model => model functions
             tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
             pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
 
-            check_pt_tf_models(tf_model, pt_model)
+            # Original test: check without `labels`
+            self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
+            # check with `labels`
+            if tf_inputs_dict_with_labels:
+                self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict_with_labels)
 
             # Check we can load pt model in tf and vice-versa with checkpoint => model functions
             with tempfile.TemporaryDirectory() as tmpdirname:
@@ -596,7 +612,11 @@ class TFModelTesterMixin:
                 tf_model.save_weights(tf_checkpoint_path)
                 pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)
 
-                check_pt_tf_models(tf_model, pt_model)
+                # Original test: check without `labels`
+                self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
+                # check with `labels`
+                if tf_inputs_dict_with_labels:
+                    self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict_with_labels)
 
     def test_compile_tf_model(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
diff --git a/tests/vit_mae/test_modeling_tf_vit_mae.py b/tests/vit_mae/test_modeling_tf_vit_mae.py
index e978fabb33..5a95f46350 100644
--- a/tests/vit_mae/test_modeling_tf_vit_mae.py
+++ b/tests/vit_mae/test_modeling_tf_vit_mae.py
@@ -28,7 +28,7 @@ import numpy as np
 
 from transformers import ViTMAEConfig
 from transformers.file_utils import cached_property, is_tf_available, is_vision_available
-from transformers.testing_utils import is_pt_tf_cross_test, require_tf, require_vision, slow, torch_device
+from transformers.testing_utils import require_tf, require_vision, slow
 
 from ..test_configuration_common import ConfigTester
 from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
@@ -363,140 +363,20 @@ class TFViTMAEModelTest(TFModelTesterMixin, unittest.TestCase):
 
     # overwrite from common since TFViTMAEForPretraining has random masking, we need to fix the noise
     # to generate masks during test
-    @is_pt_tf_cross_test
-    def test_pt_tf_model_equivalence(self):
-        import torch
-
-        import transformers
+    def check_pt_tf_models(self, tf_model, pt_model, tf_inputs_dict):
 
         # make masks reproducible
         np.random.seed(2)
 
-        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
-        num_patches = int((config.image_size // config.patch_size) ** 2)
+        num_patches = int((tf_model.config.image_size // tf_model.config.patch_size) ** 2)
         noise = 
np.random.uniform(size=(self.model_tester.batch_size, num_patches)) - pt_noise = torch.from_numpy(noise).to(device=torch_device) tf_noise = tf.constant(noise) - def prepare_pt_inputs_from_tf_inputs(tf_inputs_dict): + # Add `noise` argument. + # PT inputs will be prepared in `super().check_pt_tf_models()` with this added `noise` argument + tf_inputs_dict["noise"] = tf_noise - pt_inputs_dict = {} - for name, key in tf_inputs_dict.items(): - pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32) - - return pt_inputs_dict - - def check_outputs(tf_outputs, pt_outputs, model_class, names): - """ - Args: - model_class: The class of the model that is currently testing. For example, `TFBertModel`, - TFBertForMaskedLM`, `TFBertForSequenceClassification`, etc. Currently unused, but it could make - debugging easier and faster. - - names: A string, or a tuple of strings. These specify what tf_outputs/pt_outputs represent in the model outputs. - Currently unused, but in the future, we could use this information to make the error message clearer - by giving the name(s) of the output tensor(s) with large difference(s) between PT and TF. - """ - - # Allow `list` because `(TF)TransfoXLModelOutput.mems` is a list of tensors. - if type(tf_outputs) in [tuple, list]: - self.assertEqual(type(tf_outputs), type(pt_outputs)) - self.assertEqual(len(tf_outputs), len(pt_outputs)) - if type(names) == tuple: - for tf_output, pt_output, name in zip(tf_outputs, pt_outputs, names): - check_outputs(tf_output, pt_output, model_class, names=name) - elif type(names) == str: - for idx, (tf_output, pt_output) in enumerate(zip(tf_outputs, pt_outputs)): - check_outputs(tf_output, pt_output, model_class, names=f"{names}_{idx}") - else: - raise ValueError(f"`names` should be a `tuple` or a string. Got {type(names)} instead.") - elif isinstance(tf_outputs, tf.Tensor): - self.assertTrue(isinstance(pt_outputs, torch.Tensor)) - - tf_outputs = tf_outputs.numpy() - if isinstance(tf_outputs, np.float32): - tf_outputs = np.array(tf_outputs, dtype=np.float32) - pt_outputs = pt_outputs.detach().to("cpu").numpy() - - tf_nans = np.isnan(tf_outputs) - pt_nans = np.isnan(pt_outputs) - - pt_outputs[tf_nans] = 0 - tf_outputs[tf_nans] = 0 - pt_outputs[pt_nans] = 0 - tf_outputs[pt_nans] = 0 - - max_diff = np.amax(np.abs(tf_outputs - pt_outputs)) - # Set a higher tolerance (2e-5) here than the one in the common test (1e-5). - # TODO: A deeper look to decide the best (common) tolerance for the test to be strict but not too flaky. - self.assertLessEqual(max_diff, 2e-5) - else: - raise ValueError( - f"`tf_outputs` should be a `tuple` or an instance of `tf.Tensor`. Got {type(tf_outputs)} instead." 
- ) - - def check_pt_tf_models(tf_model, pt_model): - # we are not preparing a model with labels because of the formation - # of the ViT MAE model - - # send pytorch model to the correct device - pt_model.to(torch_device) - - # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences - pt_model.eval() - - pt_inputs_dict = prepare_pt_inputs_from_tf_inputs(tf_inputs_dict) - - # send pytorch inputs to the correct device - pt_inputs_dict = { - k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs_dict.items() - } - - # Original test: check without `labels` - with torch.no_grad(): - pt_outputs = pt_model(**pt_inputs_dict, noise=pt_noise) - tf_outputs = tf_model(tf_inputs_dict, noise=tf_noise) - - tf_keys = tuple([k for k, v in tf_outputs.items() if v is not None]) - pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) - - self.assertEqual(tf_keys, pt_keys) - check_outputs(tf_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=tf_keys) - - for model_class in self.all_model_classes: - - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - # Output all for aggressive testing - config.output_hidden_states = True - if self.has_attentions: - config.output_attentions = True - - pt_model_class_name = model_class.__name__[2:] # Skip the "TF" at the beginning - pt_model_class = getattr(transformers, pt_model_class_name) - - tf_model = model_class(config) - pt_model = pt_model_class(config) - - tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class) - - # Check we can load pt model in tf and vice-versa with model => model functions - tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict) - pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model) - - check_pt_tf_models(tf_model, pt_model) - - # Check we can load pt model in tf and vice-versa with checkpoint => model functions - with tempfile.TemporaryDirectory() as tmpdirname: - pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin") - torch.save(pt_model.state_dict(), pt_checkpoint_path) - tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path) - - tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5") - tf_model.save_weights(tf_checkpoint_path) - pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path) - - check_pt_tf_models(tf_model, pt_model) + super().check_pt_tf_models(tf_model, pt_model, tf_inputs_dict) # overwrite from common since TFViTMAEForPretraining outputs loss along with # logits and mask indices. loss and mask indicies are not suitable for integration
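
With this patch, the PT/TF equivalence machinery is exposed as overridable hooks on `TFModelTesterMixin`: `prepare_pt_inputs_from_tf_inputs`, `check_pt_tf_models`, and `check_pt_tf_outputs`, so a model-specific test only overrides the piece it needs (as `TFViTMAEModelTest` does above for the `noise` argument). A minimal sketch of the pattern follows; the `TFFooModelTest` class and its `2e-4` tolerance are hypothetical, for illustration only, while the hook signatures come from the mixin as added here:

import unittest

from ..test_modeling_tf_common import TFModelTesterMixin


class TFFooModelTest(TFModelTesterMixin, unittest.TestCase):
    # Keep the shared pipeline (weight cross-loading, input conversion, device placement,
    # recursive output comparison) and only relax the tolerance applied to each output tensor.
    def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=2e-4, name="outputs", attributes=None):
        super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol=tol, name=name, attributes=attributes)

Because `check_pt_tf_models` calls `self.check_pt_tf_outputs(...)` without an explicit `tol`, the overridden default takes effect for every comparison in that test class, and the recursive calls inside `check_pt_tf_outputs` propagate it via `tol=tol`.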