# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import inspect
import os
import random
import tempfile
from typing import List, Tuple

import numpy as np

import transformers
from transformers import is_flax_available, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax


if is_flax_available():
    import jax
    import jax.numpy as jnp
    import jaxlib.xla_extension as jax_xla
    from transformers.modeling_flax_pytorch_utils import (
        convert_pytorch_state_dict_to_flax,
        load_flax_weights_in_pytorch_model,
    )

    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12"  # assumed parallelism: 8

if is_torch_available():
    import torch


def ids_tensor(shape, vocab_size, rng=None):
    """Creates a random int32 tensor of the given shape with values in [0, vocab_size)."""
    if rng is None:
        rng = random.Random()

    total_dims = 1
    for dim in shape:
        total_dims *= dim

    values = []
    for _ in range(total_dims):
        values.append(rng.randint(0, vocab_size - 1))

    output = np.array(values, dtype=jnp.int32).reshape(shape)

    return output


def floats_tensor(shape, scale=1.0, rng=None, name=None):
    """Creates a random float32 tensor of the given shape with values in [0, scale)."""
    if rng is None:
        rng = random.Random()

    total_dims = 1
    for dim in shape:
        total_dims *= dim

    values = []
    for _ in range(total_dims):
        values.append(rng.random() * scale)

    return np.array(values, dtype=jnp.float32).reshape(shape)


def random_attention_mask(shape, rng=None):
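    """Creates a random 0/1 attention mask of the given shape, with at least one attended token per row."""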
    attn_mask = ids_tensor(shape, vocab_size=2, rng=rng)
    # make sure that at least one token is attended to for each batch
    attn_mask[:, -1] = 1
    return attn_mask


@require_flax
class FlaxModelTesterMixin:
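    """
    Collection of model-agnostic tests shared by the individual Flax model test cases.

    A concrete test case is expected to mix this class into a `unittest.TestCase` and to set
    `model_tester` and `all_model_classes`, roughly like this (illustrative sketch only; the
    tester class and its helpers are hypothetical stand-ins):

        class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase):
            all_model_classes = (FlaxBertModel,) if is_flax_available() else ()

            def setUp(self):
                self.model_tester = FlaxBertModelTester(self)

    The `model_tester` must provide `prepare_config_and_inputs_for_common()` returning a
    `(config, inputs_dict)` pair, plus the attributes read by the tests below
    (`num_choices`, `seq_length`, `hidden_size`, `num_hidden_layers`, `num_attention_heads`).
    """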

    model_tester = None
    all_model_classes = ()

    def _prepare_for_class(self, inputs_dict, model_class):
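        """Returns a deep copy of `inputs_dict` adapted to `model_class`, e.g. broadcasting inputs
        to (batch_size, num_choices, seq_length) for multiple-choice models."""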
        inputs_dict = copy.deepcopy(inputs_dict)

        # hack for now until we have AutoModel classes
        if "ForMultipleChoice" in model_class.__name__:
            inputs_dict = {
                k: jnp.broadcast_to(v[:, None], (v.shape[0], self.model_tester.num_choices, v.shape[-1]))
                if isinstance(v, (jax_xla.DeviceArray, np.ndarray))
                else v
                for k, v in inputs_dict.items()
            }

        return inputs_dict

    def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
        diff = np.abs(a - b).max()
        self.assertLessEqual(diff, tol, f"Difference between outputs is {diff} (>= {tol}).")

    def test_model_outputs_equivalence(self):
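        """Checks that tuple outputs (`return_dict=False`) match dict outputs (`return_dict=True`)."""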
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        def set_nan_tensor_to_zero(t):
            # JAX arrays are immutable, so NaNs are zeroed out-of-place
            # (an in-place `t[t != t] = 0` would raise on jnp arrays)
            return jnp.where(jnp.isnan(t), 0, t)
        def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
            tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
            dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()

            def recursive_check(tuple_object, dict_object):
                if isinstance(tuple_object, (List, Tuple)):
                    for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
                        recursive_check(tuple_iterable_value, dict_iterable_value)
                elif tuple_object is None:
                    return
                else:
                    self.assert_almost_equals(
                        set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), 1e-5
                    )

            recursive_check(tuple_output, dict_output)

        for model_class in self.all_model_classes:
            model = model_class(config)

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
            check_equivalence(model, tuple_inputs, dict_inputs)

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
            check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})

    @is_pt_flax_cross_test
    def test_equivalence_pt_to_flax(self):
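        """Converts randomly initialized PyTorch weights to Flax (in memory and via `from_pt=True`) and compares outputs."""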
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                # prepare inputs
                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
                pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}

                # load corresponding PyTorch class
                pt_model_class_name = model_class.__name__[4:]  # Skip the "Flax" at the beginning
                pt_model_class = getattr(transformers, pt_model_class_name)

                pt_model = pt_model_class(config).eval()
                fx_model = model_class(config, dtype=jnp.float32)

                fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
                fx_model.params = fx_state

                with torch.no_grad():
                    pt_outputs = pt_model(**pt_inputs).to_tuple()

                fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
                self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
                for fx_output, pt_output in zip(fx_outputs, pt_outputs):
                    self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)

                with tempfile.TemporaryDirectory() as tmpdirname:
                    pt_model.save_pretrained(tmpdirname)
                    fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)

                fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict).to_tuple()
                self.assertEqual(
                    len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
                )
                for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
                    self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)

    @is_pt_flax_cross_test
    def test_equivalence_flax_to_pt(self):
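        """Loads randomly initialized Flax weights into PyTorch (in memory and via `from_flax=True`) and compares outputs."""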
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                # prepare inputs
                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
                pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}

                # load corresponding PyTorch class
                pt_model_class_name = model_class.__name__[4:]  # Skip the "Flax" at the beginning
                pt_model_class = getattr(transformers, pt_model_class_name)

                pt_model = pt_model_class(config).eval()
                fx_model = model_class(config, dtype=jnp.float32)

                pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)

                # make sure weights are tied in PyTorch
                pt_model.tie_weights()

                with torch.no_grad():
                    pt_outputs = pt_model(**pt_inputs).to_tuple()

                fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
                self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
                for fx_output, pt_output in zip(fx_outputs, pt_outputs):
                    self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)

                with tempfile.TemporaryDirectory() as tmpdirname:
                    fx_model.save_pretrained(tmpdirname)
                    pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True)

                with torch.no_grad():
                    pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()

                self.assertEqual(
                    len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
                )
                for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded):
                    self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)

    def test_from_pretrained_save_pretrained(self):
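        """Checks that a `save_pretrained`/`from_pretrained` round trip preserves model outputs."""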
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            # this test is currently only run for FlaxBertModel
            if model_class.__name__ != "FlaxBertModel":
                continue

            with self.subTest(model_class.__name__):
                model = model_class(config)

                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
                outputs = model(**prepared_inputs_dict).to_tuple()

                # verify that normal save_pretrained works as expected
                with tempfile.TemporaryDirectory() as tmpdirname:
                    model.save_pretrained(tmpdirname)
                    model_loaded = model_class.from_pretrained(tmpdirname)

                outputs_loaded = model_loaded(**prepared_inputs_dict).to_tuple()
                for output_loaded, output in zip(outputs_loaded, outputs):
                    self.assert_almost_equals(output_loaded, output, 1e-3)

                # verify that save_pretrained for distributed training
                # with `params=params` works as expected
                with tempfile.TemporaryDirectory() as tmpdirname:
                    model.save_pretrained(tmpdirname, params=model.params)
                    model_loaded = model_class.from_pretrained(tmpdirname)

                outputs_loaded = model_loaded(**prepared_inputs_dict).to_tuple()
                for output_loaded, output in zip(outputs_loaded, outputs):
                    self.assert_almost_equals(output_loaded, output, 1e-3)

    def test_jit_compilation(self):
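        """Checks that jitted and non-jitted forward passes return outputs with identical shapes."""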
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
                model = model_class(config)

                @jax.jit
                def model_jitted(input_ids, attention_mask=None, **kwargs):
                    return model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

                with self.subTest("JIT Enabled"):
                    jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple()

                with self.subTest("JIT Disabled"):
                    with jax.disable_jit():
                        outputs = model_jitted(**prepared_inputs_dict).to_tuple()

                self.assertEqual(len(outputs), len(jitted_outputs))
                for jitted_output, output in zip(jitted_outputs, outputs):
                    self.assertEqual(jitted_output.shape, output.shape)

    def test_forward_signature(self):
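        """Checks that the model's `__call__` signature starts with `input_ids` and `attention_mask`."""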
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.__call__)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = ["input_ids", "attention_mask"]
            self.assertListEqual(arg_names[:2], expected_arg_names)

    def test_naming_convention(self):
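        """Checks that every model class has a correspondingly named `...Module` class in its module."""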
        for model_class in self.all_model_classes:
            model_class_name = model_class.__name__
            module_class_name = (
                model_class_name[:-5] + "Module" if model_class_name.endswith("Model") else model_class_name + "Module"
            )
            modeling_flax_module = __import__(model_class.__module__, fromlist=[module_class_name])
            module_cls = getattr(modeling_flax_module, module_class_name)

            self.assertIsNotNone(module_cls)

    def test_hidden_states_output(self):
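        """Checks count and shape of hidden states, requested both via kwarg and via config."""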
        def check_hidden_states_output(inputs_dict, config, model_class):
            model = model_class(config)

            outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            hidden_states = outputs.hidden_states

            self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
            seq_length = self.model_tester.seq_length

            self.assertListEqual(
                list(hidden_states[0].shape[-2:]),
                [seq_length, self.model_tester.hidden_size],
            )

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            inputs_dict["output_hidden_states"] = True
            check_hidden_states_output(inputs_dict, config, model_class)

            # check that output_hidden_states also work using config
            del inputs_dict["output_hidden_states"]
            config.output_hidden_states = True

            check_hidden_states_output(inputs_dict, config, model_class)

    def test_attention_outputs(self):
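        """Checks count and shape of attention weights, requested both via kwarg and via config."""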
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.return_dict = True

        seq_length = getattr(self.model_tester, "seq_length", None)

        for model_class in self.all_model_classes:
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = False
            model = model_class(config)
            outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.attentions
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

            # check that output_attentions also work using config
            del inputs_dict["output_attentions"]
            config.output_attentions = True
            model = model_class(config)
            outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.attentions
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

            self.assertListEqual(
                list(attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, seq_length, seq_length],
            )
            out_len = len(outputs)

            # Check attention is always last and order is fine
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = True
            model = model_class(config)
            outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            added_hidden_states = 1
            self.assertEqual(out_len + added_hidden_states, len(outputs))

            self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
            self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)

            self.assertListEqual(
                list(self_attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, seq_length, seq_length],
            )