From 692e61e91a0b83f5b847902ed619b7c74c0a5dda Mon Sep 17 00:00:00 2001 From: Crystina Date: Wed, 29 Jun 2022 15:49:02 -0700 Subject: [PATCH] Flax t5 Encoder (#17784) * first draft adding Flax-t5-encoder and Flax-mt5-encoder * imports * after make fixup * flax t5 encoder test * black on test * make fix-copies * clean * all_model_classes -> tuple * clean test * is_encoder_decoder=False in t5-enc tester * remove file docstring before FlaxT5Encoder * black * isort * commit suggestions on src/transformers/models/t5/modeling_flax_t5.py Co-authored-by: Suraj Patil * commit suggestions on src/transformers/models/t5/modeling_flax_t5.py Co-authored-by: Suraj Patil * Apply suggestions from code review Co-authored-by: Suraj Patil * remove _get_encoder_module * self.decoder_seq_length -> self.encoder_seq_length as t5-enc does not have decoder * bugfix - self.module_class is class itself, not instance; * docs for mt5 and t5 * call -> __call__ in t5 doc * FlaxMT5EncoderModel to TYPE_HINT * run doc-builder to allow change the files Co-authored-by: Suraj Patil --- docs/source/en/model_doc/mt5.mdx | 4 + docs/source/en/model_doc/t5.mdx | 5 + src/transformers/__init__.py | 10 +- src/transformers/models/mt5/__init__.py | 4 +- .../models/mt5/modeling_flax_mt5.py | 29 +- src/transformers/models/t5/__init__.py | 8 +- .../models/t5/modeling_flax_t5.py | 96 +++++- src/transformers/utils/dummy_flax_objects.py | 14 + tests/models/t5/test_modeling_flax_t5.py | 299 +++++++++++++++++- 9 files changed, 454 insertions(+), 15 deletions(-) diff --git a/docs/source/en/model_doc/mt5.mdx b/docs/source/en/model_doc/mt5.mdx index fa39e4ab2c..a00003da46 100644 --- a/docs/source/en/model_doc/mt5.mdx +++ b/docs/source/en/model_doc/mt5.mdx @@ -96,3 +96,7 @@ See [`T5TokenizerFast`] for all details. 
## FlaxMT5ForConditionalGeneration [[autodoc]] FlaxMT5ForConditionalGeneration + +## FlaxMT5EncoderModel + +[[autodoc]] FlaxMT5EncoderModel diff --git a/docs/source/en/model_doc/t5.mdx b/docs/source/en/model_doc/t5.mdx index 8034fc010a..5a19289234 100644 --- a/docs/source/en/model_doc/t5.mdx +++ b/docs/source/en/model_doc/t5.mdx @@ -371,3 +371,8 @@ T5 is supported by several example scripts, both for pre-training and fine-tunin - __call__ - encode - decode + +## FlaxT5EncoderModel + +[[autodoc]] FlaxT5EncoderModel + - __call__ diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 999fc54be8..75837649c9 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -2704,7 +2704,7 @@ else: "FlaxMBartPreTrainedModel", ] ) - _import_structure["models.mt5"].extend(["FlaxMT5ForConditionalGeneration", "FlaxMT5Model"]) + _import_structure["models.mt5"].extend(["FlaxMT5EncoderModel", "FlaxMT5ForConditionalGeneration", "FlaxMT5Model"]) _import_structure["models.opt"].extend( [ "FlaxOPTForCausalLM", @@ -2743,7 +2743,9 @@ else: ] ) _import_structure["models.speech_encoder_decoder"].append("FlaxSpeechEncoderDecoderModel") - _import_structure["models.t5"].extend(["FlaxT5ForConditionalGeneration", "FlaxT5Model", "FlaxT5PreTrainedModel"]) + _import_structure["models.t5"].extend( + ["FlaxT5EncoderModel", "FlaxT5ForConditionalGeneration", "FlaxT5Model", "FlaxT5PreTrainedModel"] + ) _import_structure["models.vision_encoder_decoder"].append("FlaxVisionEncoderDecoderModel") _import_structure["models.vision_text_dual_encoder"].extend(["FlaxVisionTextDualEncoderModel"]) _import_structure["models.vit"].extend(["FlaxViTForImageClassification", "FlaxViTModel", "FlaxViTPreTrainedModel"]) @@ -4974,7 +4976,7 @@ if TYPE_CHECKING: FlaxMBartModel, FlaxMBartPreTrainedModel, ) - from .models.mt5 import FlaxMT5ForConditionalGeneration, FlaxMT5Model + from .models.mt5 import FlaxMT5EncoderModel, FlaxMT5ForConditionalGeneration, FlaxMT5Model from .models.opt import FlaxOPTForCausalLM, FlaxOPTModel, FlaxOPTPreTrainedModel from .models.pegasus import FlaxPegasusForConditionalGeneration, FlaxPegasusModel, FlaxPegasusPreTrainedModel from .models.roberta import ( @@ -4997,7 +4999,7 @@ if TYPE_CHECKING: FlaxRoFormerPreTrainedModel, ) from .models.speech_encoder_decoder import FlaxSpeechEncoderDecoderModel - from .models.t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel + from .models.t5 import FlaxT5EncoderModel, FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel from .models.vision_encoder_decoder import FlaxVisionEncoderDecoderModel from .models.vision_text_dual_encoder import FlaxVisionTextDualEncoderModel from .models.vit import FlaxViTForImageClassification, FlaxViTModel, FlaxViTPreTrainedModel diff --git a/src/transformers/models/mt5/__init__.py b/src/transformers/models/mt5/__init__.py index 66daf82b38..3f04a25691 100644 --- a/src/transformers/models/mt5/__init__.py +++ b/src/transformers/models/mt5/__init__.py @@ -67,7 +67,7 @@ try: except OptionalDependencyNotAvailable: pass else: - _import_structure["modeling_flax_mt5"] = ["FlaxMT5ForConditionalGeneration", "FlaxMT5Model"] + _import_structure["modeling_flax_mt5"] = ["FlaxMT5EncoderModel", "FlaxMT5ForConditionalGeneration", "FlaxMT5Model"] if TYPE_CHECKING: @@ -95,7 +95,7 @@ if TYPE_CHECKING: except OptionalDependencyNotAvailable: pass else: - from .modeling_flax_mt5 import FlaxMT5ForConditionalGeneration, FlaxMT5Model + from .modeling_flax_mt5 import FlaxMT5EncoderModel, 
FlaxMT5ForConditionalGeneration, FlaxMT5Model else: import sys diff --git a/src/transformers/models/mt5/modeling_flax_mt5.py b/src/transformers/models/mt5/modeling_flax_mt5.py index 0f35e7f9e4..841d2069e6 100644 --- a/src/transformers/models/mt5/modeling_flax_mt5.py +++ b/src/transformers/models/mt5/modeling_flax_mt5.py @@ -17,7 +17,7 @@ import numpy as np from ...utils import logging -from ..t5.modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model +from ..t5.modeling_flax_t5 import FlaxT5EncoderModel, FlaxT5ForConditionalGeneration, FlaxT5Model from .configuration_mt5 import MT5Config @@ -67,6 +67,29 @@ class FlaxMT5Model(FlaxT5Model): config_class = MT5Config +class FlaxMT5EncoderModel(FlaxT5EncoderModel): + r""" + This class overrides [`FlaxT5EncoderModel`]. Please check the superclass for the appropriate documentation + alongside usage examples. + + Examples: + + ```python + >>> from transformers import FlaxMT5EncoderModel, T5Tokenizer + + >>> model = FlaxMT5EncoderModel.from_pretrained("google/mt5-small") + >>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small") + + >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien." + >>> inputs = tokenizer(article, return_tensors="np") + + >>> outputs = model(input_ids=inputs["input_ids"]) + >>> hidden_states = outputs.last_hidden_state + ```""" + model_type = "mt5" + config_class = MT5Config + + class FlaxMT5ForConditionalGeneration(FlaxT5ForConditionalGeneration): r""" This class overrides [`FlaxT5ForConditionalGeneration`]. Please check the superclass for the appropriate diff --git a/src/transformers/models/t5/__init__.py b/src/transformers/models/t5/__init__.py index 178fb38678..2f0bd9521a 100644 --- a/src/transformers/models/t5/__init__.py +++ b/src/transformers/models/t5/__init__.py @@ -83,6 +83,7 @@ except OptionalDependencyNotAvailable: pass else: _import_structure["modeling_flax_t5"] = [ + "FlaxT5EncoderModel", "FlaxT5ForConditionalGeneration", "FlaxT5Model", "FlaxT5PreTrainedModel", @@ -143,7 +144,12 @@ if TYPE_CHECKING: except OptionalDependencyNotAvailable: pass else: - from .modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel + from .modeling_flax_t5 import ( + FlaxT5EncoderModel, + FlaxT5ForConditionalGeneration, + FlaxT5Model, + FlaxT5PreTrainedModel, + ) else: diff --git a/src/transformers/models/t5/modeling_flax_t5.py b/src/transformers/models/t5/modeling_flax_t5.py index 4094599592..06ad510542 100644 --- a/src/transformers/models/t5/modeling_flax_t5.py +++ b/src/transformers/models/t5/modeling_flax_t5.py @@ -929,18 +929,18 @@ class FlaxT5PreTrainedModel(FlaxPreTrainedModel): input_ids = jnp.zeros(input_shape, dtype="i4") attention_mask = jnp.ones_like(input_ids) - decoder_input_ids = jnp.ones_like(input_ids) - decoder_attention_mask = jnp.ones_like(input_ids) + args = [input_ids, attention_mask] + if self.module_class not in [FlaxT5EncoderModule]: + decoder_input_ids = jnp.ones_like(input_ids) + decoder_attention_mask = jnp.ones_like(input_ids) + args.extend([decoder_input_ids, decoder_attention_mask]) params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} random_params = self.module.init( rngs, - input_ids, - attention_mask, - decoder_input_ids, - decoder_attention_mask, + *args, )["params"] if params is not None: @@ -1357,6 
+1357,90 @@ overwrite_call_docstring(FlaxT5Model, T5_INPUTS_DOCSTRING + FLAX_T5_MODEL_DOCSTR append_replace_return_docstrings(FlaxT5Model, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) +@add_start_docstrings( + "The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.", + T5_START_DOCSTRING, +) +class FlaxT5EncoderModule(nn.Module): + config: T5Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.shared = nn.Embed( + self.config.vocab_size, + self.config.d_model, + embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0), + ) + + encoder_config = copy.deepcopy(self.config) + encoder_config.is_decoder = False + encoder_config.is_encoder_decoder = False + encoder_config.causal = False + self.encoder = FlaxT5Stack(encoder_config, embed_tokens=self.shared, dtype=self.dtype) + + def __call__( + self, + input_ids=None, + attention_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + deterministic: bool = True, + ): + + # Run the encoder stack (this model is encoder-only and has no decoder) + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + return encoder_outputs + + +class FlaxT5EncoderModel(FlaxT5PreTrainedModel): + module_class = FlaxT5EncoderModule + + @add_start_docstrings_to_model_forward(T5_ENCODE_INPUTS_DOCSTRING) + def __call__( + self, + input_ids: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # prepare encoder inputs + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + + # Handle any PRNG if needed + rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} + + return self.module.apply( + {"params": params or self.params}, + input_ids=jnp.array(input_ids, dtype="i4"), + attention_mask=jnp.array(attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + ) + + @add_start_docstrings("""T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING) class FlaxT5ForConditionalGenerationModule(nn.Module): config: T5Config diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py index 44c3b1cf3e..953808dab8 100644 --- a/src/transformers/utils/dummy_flax_objects.py +++ b/src/transformers/utils/dummy_flax_objects.py @@ -802,6 +802,13 @@ class FlaxMBartPreTrainedModel(metaclass=DummyObject): requires_backends(self, ["flax"]) +class FlaxMT5EncoderModel(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + class FlaxMT5ForConditionalGeneration(metaclass=DummyObject): _backends = ["flax"] @@ -970,6 +977,13 @@ class 
FlaxSpeechEncoderDecoderModel(metaclass=DummyObject): requires_backends(self, ["flax"]) +class FlaxT5EncoderModel(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + class FlaxT5ForConditionalGeneration(metaclass=DummyObject): _backends = ["flax"] diff --git a/tests/models/t5/test_modeling_flax_t5.py b/tests/models/t5/test_modeling_flax_t5.py index f3b2c166ed..3186567709 100644 --- a/tests/models/t5/test_modeling_flax_t5.py +++ b/tests/models/t5/test_modeling_flax_t5.py @@ -48,7 +48,12 @@ if is_flax_available(): from flax.traverse_util import flatten_dict from transformers import FLAX_MODEL_MAPPING, ByT5Tokenizer, T5Config, T5Tokenizer from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model - from transformers.models.t5.modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, shift_tokens_right + from transformers.models.t5.modeling_flax_t5 import ( + FlaxT5EncoderModel, + FlaxT5ForConditionalGeneration, + FlaxT5Model, + shift_tokens_right, + ) class FlaxT5ModelTester: @@ -461,6 +466,298 @@ class FlaxT5ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest. self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") +class FlaxT5EncoderOnlyModelTester: + def __init__( + self, + parent, + vocab_size=99, + batch_size=13, + encoder_seq_length=7, + # For common tests + is_training=True, + use_attention_mask=True, + use_labels=True, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=4, + d_ff=37, + relative_attention_num_buckets=8, + dropout_rate=0.1, + initializer_factor=0.002, + eos_token_id=1, + pad_token_id=0, + decoder_start_token_id=0, + scope=None, + ): + + self.parent = parent + self.batch_size = batch_size + self.encoder_seq_length = encoder_seq_length + # For common tests + self.seq_length = self.encoder_seq_length + self.is_training = is_training + self.use_attention_mask = use_attention_mask + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.d_ff = d_ff + self.relative_attention_num_buckets = relative_attention_num_buckets + self.dropout_rate = dropout_rate + self.initializer_factor = initializer_factor + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.decoder_start_token_id = decoder_start_token_id + self.scope = None + self.decoder_layers = 0 + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size) + + attention_mask = None + if self.use_attention_mask: + attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2) + + config = T5Config( + vocab_size=self.vocab_size, + d_model=self.hidden_size, + d_ff=self.d_ff, + d_kv=self.hidden_size // self.num_attention_heads, + num_layers=self.num_hidden_layers, + num_decoder_layers=self.decoder_layers, + num_heads=self.num_attention_heads, + relative_attention_num_buckets=self.relative_attention_num_buckets, + dropout_rate=self.dropout_rate, + initializer_factor=self.initializer_factor, + eos_token_id=self.eos_token_id, + bos_token_id=self.pad_token_id, + pad_token_id=self.pad_token_id, + decoder_start_token_id=self.decoder_start_token_id, + is_encoder_decoder=False, + ) + + return ( + config, + input_ids, + attention_mask, + ) + + def create_and_check_model( + self, + config, + input_ids, + attention_mask, + ): + model = 
FlaxT5EncoderModel(config=config) + result = model( + input_ids=input_ids, + attention_mask=attention_mask, + ) + result = model(input_ids=input_ids) + encoder_output = result.last_hidden_state + + self.parent.assertEqual(encoder_output.shape, (self.batch_size, self.encoder_seq_length, self.hidden_size)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + attention_mask, + ) = config_and_inputs + + inputs_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + return config, inputs_dict + + +@require_flax +class FlaxT5EncoderOnlyModelTest(FlaxModelTesterMixin, unittest.TestCase): + + all_model_classes = (FlaxT5EncoderModel,) if is_flax_available() else () + is_encoder_decoder = False + + def setUp(self): + self.model_tester = FlaxT5EncoderOnlyModelTester(self) + self.config_tester = ConfigTester(self, config_class=T5Config, d_model=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_model_v1_1(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + # check that gated gelu feed forward and different word embeddings work + config = config_and_inputs[0] + config.tie_word_embeddings = False + config.feed_forward_proj = "gated-gelu" + self.model_tester.create_and_check_model(config, *config_and_inputs[1:]) + + def test_encode(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + with self.subTest(model_class.__name__): + prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) + model = model_class(config) + + @jax.jit + def encode_jitted(input_ids, attention_mask=None, **kwargs): + return model(input_ids=input_ids, attention_mask=attention_mask) + + with self.subTest("JIT Enabled"): + jitted_outputs = encode_jitted(**prepared_inputs_dict).to_tuple() + + with self.subTest("JIT Disabled"): + with jax.disable_jit(): + outputs = encode_jitted(**prepared_inputs_dict).to_tuple() + + self.assertEqual(len(outputs), len(jitted_outputs)) + for jitted_output, output in zip(jitted_outputs, outputs): + self.assertEqual(jitted_output.shape, output.shape) + + # overwrite since special base model prefix is used + def test_save_load_from_base(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + base_class = FLAX_MODEL_MAPPING[config.__class__] + + for model_class in self.all_model_classes: + if model_class == base_class: + continue + + model = base_class(config) + base_params = flatten_dict(unfreeze(model.params)) + + # check that all base model weights are loaded correctly + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + head_model = model_class.from_pretrained(tmpdirname) + + base_param_from_head = flatten_dict(unfreeze(head_model.params)) + + for key in base_param_from_head.keys(): + max_diff = (base_params[key] - base_param_from_head[key]).sum().item() + self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + + # overwrite since special base model prefix is used + def test_save_load_to_base(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + base_class = FLAX_MODEL_MAPPING[config.__class__] + + for model_class in self.all_model_classes: + if model_class == base_class: + continue + + model = 
model_class(config) + base_params_from_head = flatten_dict(unfreeze(model.params)) + + # check that all base model weights are loaded correctly + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + base_model = base_class.from_pretrained(tmpdirname) + + base_params = flatten_dict(unfreeze(base_model.params)) + + for key in base_params_from_head.keys(): + max_diff = (base_params[key] - base_params_from_head[key]).sum().item() + self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + + # overwrite since special base model prefix is used + @is_pt_flax_cross_test + def test_save_load_from_base_pt(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + base_class = FLAX_MODEL_MAPPING[config.__class__] + + for model_class in self.all_model_classes: + if model_class == base_class: + continue + + model = base_class(config) + base_params = flatten_dict(unfreeze(model.params)) + + # convert Flax model to PyTorch model + pt_model_class = getattr(transformers, base_class.__name__[4:]) # Skip the "Flax" at the beginning + pt_model = pt_model_class(config).eval() + pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params) + + # check that all base model weights are loaded correctly + with tempfile.TemporaryDirectory() as tmpdirname: + # save pt model + pt_model.save_pretrained(tmpdirname) + head_model = model_class.from_pretrained(tmpdirname, from_pt=True) + + base_param_from_head = flatten_dict(unfreeze(head_model.params)) + + for key in base_param_from_head.keys(): + max_diff = (base_params[key] - base_param_from_head[key]).sum().item() + self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + + # overwrite since special base model prefix is used + @is_pt_flax_cross_test + def test_save_load_to_base_pt(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + base_class = FLAX_MODEL_MAPPING[config.__class__] + + for model_class in self.all_model_classes: + if model_class == base_class: + continue + + model = model_class(config) + base_params_from_head = flatten_dict(unfreeze(model.params)) + + # convert Flax model to PyTorch model + pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning + pt_model = pt_model_class(config).eval() + pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params) + + # check that all base model weights are loaded correctly + with tempfile.TemporaryDirectory() as tmpdirname: + pt_model.save_pretrained(tmpdirname) + base_model = base_class.from_pretrained(tmpdirname, from_pt=True) + + base_params = flatten_dict(unfreeze(base_model.params)) + + for key in base_params_from_head.keys(): + max_diff = (base_params[key] - base_params_from_head[key]).sum().item() + self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + + # overwrite since special base model prefix is used + @is_pt_flax_cross_test + def test_save_load_bf16_to_base_pt(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + base_class = FLAX_MODEL_MAPPING[config.__class__] + + for model_class in self.all_model_classes: + if model_class == base_class: + continue + + model = model_class(config) + model.params = model.to_bf16(model.params) + base_params_from_head = flatten_dict(unfreeze(model.params)) + + # convert Flax model to PyTorch model + pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning + pt_model = pt_model_class(config).eval() + pt_model = 
load_flax_weights_in_pytorch_model(pt_model, model.params) + + # check that all base model weights are loaded correctly + with tempfile.TemporaryDirectory() as tmpdirname: + pt_model.save_pretrained(tmpdirname) + base_model = base_class.from_pretrained(tmpdirname, from_pt=True) + + base_params = flatten_dict(unfreeze(base_model.params)) + + for key in base_params_from_head.keys(): + max_diff = (base_params[key] - base_params_from_head[key]).sum().item() + self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + + @require_sentencepiece @require_tokenizers @require_flax
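
For reviewers: below is a minimal usage sketch of the encoder-only class this patch adds. It is not part of the diff; the `t5-small` checkpoint name is illustrative, and loading a full seq2seq checkpoint into `FlaxT5EncoderModel` should simply leave the decoder weights unused (with a warning from `from_pretrained`).

```python
# Minimal sketch: extract encoder hidden states with the new FlaxT5EncoderModel.
# Assumes the public "t5-small" checkpoint; any T5/mT5 checkpoint should work.
from transformers import FlaxT5EncoderModel, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = FlaxT5EncoderModel.from_pretrained("t5-small")

inputs = tokenizer("Studies have shown that owning a dog is good for you.", return_tensors="np")

# Encoder-only forward pass: no decoder_input_ids are required.
outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
print(outputs.last_hidden_state.shape)  # (batch_size, seq_len, d_model)
```

This mirrors the existing PyTorch `T5EncoderModel` API, so feature-extraction code built on the encoder can move between the two frameworks with minimal changes.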