Fix model equivalence tests (#15670)

* Fix model equivalence tests

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Lysandre Debut 2022-02-15 18:55:22 -05:00 committed by GitHub
parent 1690319217
commit 943e2aa036
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 27 additions and 31 deletions

View File

@ -625,15 +625,15 @@ class CLIPModelTest(ModelTesterMixin, unittest.TestCase):
if type(tensor) == bool:
tf_inputs_dict[key] = tensor
elif key == "input_values":
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32)
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
elif key == "pixel_values":
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32)
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
else:
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.int32)
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.int32)
# Check we can load pt model in tf and vice-versa with model => model functions
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model).to(torch_device)
# need to rename encoder-decoder "inputs" for PyTorch
# if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
@ -650,7 +650,7 @@ class CLIPModelTest(ModelTesterMixin, unittest.TestCase):
continue
tf_out = tf_output.numpy()
pt_out = pt_output.numpy()
pt_out = pt_output.cpu().numpy()
self.assertEqual(tf_out.shape, pt_out.shape, "Output component shapes differ between TF and PyTorch")
@ -676,6 +676,7 @@ class CLIPModelTest(ModelTesterMixin, unittest.TestCase):
tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
tf_model.save_weights(tf_checkpoint_path)
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)
pt_model = pt_model.to(torch_device)
# Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences
pt_model.eval()
@ -686,11 +687,11 @@ class CLIPModelTest(ModelTesterMixin, unittest.TestCase):
tensor = np.array(tensor, dtype=bool)
tf_inputs_dict[key] = tf.convert_to_tensor(tensor, dtype=tf.int32)
elif key == "input_values":
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32)
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
elif key == "pixel_values":
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32)
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
else:
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.int32)
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.int32)
# need to rename encoder-decoder "inputs" for PyTorch
# if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
@ -708,7 +709,7 @@ class CLIPModelTest(ModelTesterMixin, unittest.TestCase):
continue
tf_out = tf_output.numpy()
pt_out = pt_output.numpy()
pt_out = pt_output.cpu().numpy()
self.assertEqual(tf_out.shape, pt_out.shape, "Output component shapes differ between TF and PyTorch")

View File

@ -1475,17 +1475,17 @@ class ModelTesterMixin:
if type(tensor) == bool:
tf_inputs_dict[key] = tensor
elif key == "input_values":
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32)
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
elif key == "pixel_values":
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32)
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
elif key == "input_features":
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32)
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
else:
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.int32)
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.int32)
# Check we can load pt model in tf and vice-versa with model => model functions
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model).to(torch_device)
# need to rename encoder-decoder "inputs" for PyTorch
# if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
@ -1496,7 +1496,7 @@ class ModelTesterMixin:
tfo = tf_model(tf_inputs_dict, training=False)
tf_hidden_states = tfo[0].numpy()
pt_hidden_states = pto[0].numpy()
pt_hidden_states = pto[0].cpu().numpy()
tf_nans = np.copy(np.isnan(tf_hidden_states))
pt_nans = np.copy(np.isnan(pt_hidden_states))
@ -1518,6 +1518,7 @@ class ModelTesterMixin:
tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
tf_model.save_weights(tf_checkpoint_path)
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)
pt_model = pt_model.to(torch_device)
# Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences
pt_model.eval()
@ -1528,13 +1529,13 @@ class ModelTesterMixin:
tensor = np.array(tensor, dtype=bool)
tf_inputs_dict[key] = tf.convert_to_tensor(tensor, dtype=tf.int32)
elif key == "input_values":
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32)
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
elif key == "pixel_values":
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32)
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
elif key == "input_features":
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32)
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
else:
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.int32)
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.int32)
# need to rename encoder-decoder "inputs" for PyTorch
# if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
@ -1545,7 +1546,7 @@ class ModelTesterMixin:
tfo = tf_model(tf_inputs_dict)
tfo = tfo[0].numpy()
pto = pto[0].numpy()
pto = pto[0].cpu().numpy()
tf_nans = np.copy(np.isnan(tfo))
pt_nans = np.copy(np.isnan(pto))

View File

@ -776,16 +776,16 @@ class LxmertModelTest(ModelTesterMixin, unittest.TestCase):
else:
if isinstance(value, (list, tuple)):
return_dict[key] = (
tf.convert_to_tensor(iter_value.numpy(), dtype=tf.int32) for iter_value in value
tf.convert_to_tensor(iter_value.cpu().numpy(), dtype=tf.int32) for iter_value in value
)
else:
return_dict[key] = tf.convert_to_tensor(value.numpy(), dtype=tf.int32)
return_dict[key] = tf.convert_to_tensor(value.cpu().numpy(), dtype=tf.int32)
return return_dict
tf_inputs_dict = recursive_numpy_convert(pt_inputs)
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model).to(torch_device)
# Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences
pt_model.eval()
@ -795,12 +795,6 @@ class LxmertModelTest(ModelTesterMixin, unittest.TestCase):
if "obj_labels" in inputs_dict:
del inputs_dict["obj_labels"]
def torch_type(key):
if key in ("visual_feats", "visual_pos"):
return torch.float32
else:
return torch.long
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
tf_inputs_dict = recursive_numpy_convert(pt_inputs)
@ -808,7 +802,7 @@ class LxmertModelTest(ModelTesterMixin, unittest.TestCase):
pto = pt_model(**pt_inputs)
tfo = tf_model(tf_inputs_dict, training=False)
tf_hidden_states = tfo[0].numpy()
pt_hidden_states = pto[0].numpy()
pt_hidden_states = pto[0].cpu().numpy()
tf_nans = np.copy(np.isnan(tf_hidden_states))
pt_nans = np.copy(np.isnan(pt_hidden_states))
@ -852,7 +846,7 @@ class LxmertModelTest(ModelTesterMixin, unittest.TestCase):
tfo = tf_model(tf_inputs_dict)
tfo = tfo[0].numpy()
pto = pto[0].numpy()
pto = pto[0].cpu().numpy()
tf_nans = np.copy(np.isnan(tfo))
pt_nans = np.copy(np.isnan(pto))