Flax t5 Encoder (#17784)
* first draft adding Flax-t5-encoder and Flax-mt5-encoder
* imports
* after make fixup
* flax t5 encoder test
* black on test
* make fix-copies
* clean
* all_model_classes -> tuple
* clean test
* is_encoder_decoder=False in t5-enc tester
* remove file docstring before FlaxT5Encoder
* black
* isort
* commit suggestions on src/transformers/models/t5/modeling_flax_t5.py (Co-authored-by: Suraj Patil <surajp815@gmail.com>)
* commit suggestions on src/transformers/models/t5/modeling_flax_t5.py (Co-authored-by: Suraj Patil <surajp815@gmail.com>)
* Apply suggestions from code review (Co-authored-by: Suraj Patil <surajp815@gmail.com>)
* remove _get_encoder_module
* self.decoder_seq_length -> self.encoder_seq_length as t5-enc does not have a decoder
* bugfix - self.module_class is the class itself, not an instance
* docs for mt5 and t5
* call -> __call__ in t5 doc
* FlaxMT5EncoderModel to TYPE_HINT
* run doc-builder to allow changing the files

Co-authored-by: Suraj Patil <surajp815@gmail.com>
Commit 692e61e91a (parent eb1493b15d)
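In short, the commit adds encoder-only Flax classes (`FlaxT5EncoderModel`, `FlaxMT5EncoderModel`) mirroring the existing PyTorch `T5EncoderModel`. A minimal usage sketch; the `t5-small` checkpoint name is illustrative and not part of this diff:

```python
from transformers import FlaxT5EncoderModel, T5Tokenizer

# Load an encoder-only Flax T5 model (checkpoint name is an example; decoder
# weights present in the checkpoint are simply not used).
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = FlaxT5EncoderModel.from_pretrained("t5-small")

inputs = tokenizer("Studies have shown that owning a dog is good for you.", return_tensors="np")
outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])

# Encoder hidden states for every input token: (batch, seq_len, d_model)
print(outputs.last_hidden_state.shape)
```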
@@ -96,3 +96,7 @@ See [`T5TokenizerFast`] for all details.
 ## FlaxMT5ForConditionalGeneration

 [[autodoc]] FlaxMT5ForConditionalGeneration
+
+## FlaxMT5EncoderModel
+
+[[autodoc]] FlaxMT5EncoderModel
@@ -371,3 +371,8 @@ T5 is supported by several example scripts, both for pre-training and fine-tuning.
     - __call__
     - encode
     - decode
+
+## FlaxT5EncoderModel
+
+[[autodoc]] FlaxT5EncoderModel
+    - __call__
@@ -2704,7 +2704,7 @@ else:
             "FlaxMBartPreTrainedModel",
         ]
     )
-    _import_structure["models.mt5"].extend(["FlaxMT5ForConditionalGeneration", "FlaxMT5Model"])
+    _import_structure["models.mt5"].extend(["FlaxMT5EncoderModel", "FlaxMT5ForConditionalGeneration", "FlaxMT5Model"])
     _import_structure["models.opt"].extend(
         [
             "FlaxOPTForCausalLM",
@@ -2743,7 +2743,9 @@ else:
         ]
     )
     _import_structure["models.speech_encoder_decoder"].append("FlaxSpeechEncoderDecoderModel")
-    _import_structure["models.t5"].extend(["FlaxT5ForConditionalGeneration", "FlaxT5Model", "FlaxT5PreTrainedModel"])
+    _import_structure["models.t5"].extend(
+        ["FlaxT5EncoderModel", "FlaxT5ForConditionalGeneration", "FlaxT5Model", "FlaxT5PreTrainedModel"]
+    )
     _import_structure["models.vision_encoder_decoder"].append("FlaxVisionEncoderDecoderModel")
     _import_structure["models.vision_text_dual_encoder"].extend(["FlaxVisionTextDualEncoderModel"])
     _import_structure["models.vit"].extend(["FlaxViTForImageClassification", "FlaxViTModel", "FlaxViTPreTrainedModel"])
@@ -4974,7 +4976,7 @@ if TYPE_CHECKING:
             FlaxMBartModel,
             FlaxMBartPreTrainedModel,
         )
-        from .models.mt5 import FlaxMT5ForConditionalGeneration, FlaxMT5Model
+        from .models.mt5 import FlaxMT5EncoderModel, FlaxMT5ForConditionalGeneration, FlaxMT5Model
         from .models.opt import FlaxOPTForCausalLM, FlaxOPTModel, FlaxOPTPreTrainedModel
         from .models.pegasus import FlaxPegasusForConditionalGeneration, FlaxPegasusModel, FlaxPegasusPreTrainedModel
         from .models.roberta import (
@@ -4997,7 +4999,7 @@ if TYPE_CHECKING:
             FlaxRoFormerPreTrainedModel,
         )
         from .models.speech_encoder_decoder import FlaxSpeechEncoderDecoderModel
-        from .models.t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel
+        from .models.t5 import FlaxT5EncoderModel, FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel
         from .models.vision_encoder_decoder import FlaxVisionEncoderDecoderModel
         from .models.vision_text_dual_encoder import FlaxVisionTextDualEncoderModel
         from .models.vit import FlaxViTForImageClassification, FlaxViTModel, FlaxViTPreTrainedModel
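The `__init__` hunks above register the new classes both in the lazy `_import_structure` and under `TYPE_CHECKING`, so they become importable from the top-level package. A quick sketch (assumes Flax is installed; otherwise the dummy objects further below are what get exposed):

```python
# Both encoder-only classes are now exposed at the top level of transformers.
from transformers import FlaxMT5EncoderModel, FlaxT5EncoderModel

print(FlaxT5EncoderModel.__module__)   # transformers.models.t5.modeling_flax_t5
print(FlaxMT5EncoderModel.__module__)  # transformers.models.mt5.modeling_flax_mt5
```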
@@ -67,7 +67,7 @@ try:
 except OptionalDependencyNotAvailable:
     pass
 else:
-    _import_structure["modeling_flax_mt5"] = ["FlaxMT5ForConditionalGeneration", "FlaxMT5Model"]
+    _import_structure["modeling_flax_mt5"] = ["FlaxMT5EncoderModel", "FlaxMT5ForConditionalGeneration", "FlaxMT5Model"]


 if TYPE_CHECKING:
@@ -95,7 +95,7 @@ if TYPE_CHECKING:
     except OptionalDependencyNotAvailable:
         pass
     else:
-        from .modeling_flax_mt5 import FlaxMT5ForConditionalGeneration, FlaxMT5Model
+        from .modeling_flax_mt5 import FlaxMT5EncoderModel, FlaxMT5ForConditionalGeneration, FlaxMT5Model

 else:
     import sys
@@ -17,7 +17,7 @@
 import numpy as np

 from ...utils import logging
-from ..t5.modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model
+from ..t5.modeling_flax_t5 import FlaxT5EncoderModel, FlaxT5ForConditionalGeneration, FlaxT5Model
 from .configuration_mt5 import MT5Config


@@ -67,6 +67,33 @@ class FlaxMT5Model(FlaxT5Model):
     config_class = MT5Config


+class FlaxMT5EncoderModel(FlaxT5EncoderModel):
+    r"""
+    This class overrides [`FlaxT5EncoderModel`]. Please check the superclass for the appropriate documentation
+    alongside usage examples.
+
+    Examples:
+
+    ```python
+    >>> from transformers import FlaxT5EncoderModel, T5Tokenizer
+
+    >>> model = FlaxT5EncoderModel.from_pretrained("google/mt5-small")
+    >>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
+
+    >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
+    >>> summary = "Weiter Verhandlung in Syrien."
+    >>> inputs = tokenizer(article, return_tensors="np")
+
+    >>> with tokenizer.as_target_tokenizer():
+    ...     decoder_input_ids = tokenizer(summary, return_tensors="np").input_ids
+
+    >>> outputs = model(input_ids=inputs["input_ids"])
+    >>> hidden_states = outputs.last_hidden_state
+    ```"""
+
+    model_type = "mt5"
+    config_class = MT5Config
+
+
 class FlaxMT5ForConditionalGeneration(FlaxT5ForConditionalGeneration):
     r"""
     This class overrides [`FlaxT5ForConditionalGeneration`]. Please check the superclass for the appropriate
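A brief sketch of the new MT5 class in use; as in the docstring above, the `google/mt5-small` checkpoint is illustrative, and `FlaxMT5EncoderModel` only swaps in `MT5Config` and `model_type="mt5"` on top of `FlaxT5EncoderModel`:

```python
from transformers import FlaxMT5EncoderModel, T5Tokenizer

# FlaxMT5EncoderModel reuses FlaxT5EncoderModel's implementation with an MT5 config.
tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
model = FlaxMT5EncoderModel.from_pretrained("google/mt5-small")

inputs = tokenizer("UN Offizier sagt, dass weiter verhandelt werden muss in Syrien.", return_tensors="np")
hidden_states = model(**inputs).last_hidden_state  # (1, seq_len, d_model)
```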
@@ -83,6 +83,7 @@ except OptionalDependencyNotAvailable:
     pass
 else:
     _import_structure["modeling_flax_t5"] = [
+        "FlaxT5EncoderModel",
         "FlaxT5ForConditionalGeneration",
         "FlaxT5Model",
         "FlaxT5PreTrainedModel",
@@ -143,7 +144,12 @@ if TYPE_CHECKING:
     except OptionalDependencyNotAvailable:
         pass
     else:
-        from .modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel
+        from .modeling_flax_t5 import (
+            FlaxT5EncoderModel,
+            FlaxT5ForConditionalGeneration,
+            FlaxT5Model,
+            FlaxT5PreTrainedModel,
+        )


 else:
@@ -929,18 +929,18 @@ class FlaxT5PreTrainedModel(FlaxPreTrainedModel):
         input_ids = jnp.zeros(input_shape, dtype="i4")

         attention_mask = jnp.ones_like(input_ids)
+        args = [input_ids, attention_mask]
+        if self.module_class not in [FlaxT5EncoderModule]:
+            decoder_input_ids = jnp.ones_like(input_ids)
+            decoder_attention_mask = jnp.ones_like(input_ids)
+            args.extend([decoder_input_ids, decoder_attention_mask])

         params_rng, dropout_rng = jax.random.split(rng)
         rngs = {"params": params_rng, "dropout": dropout_rng}

         random_params = self.module.init(
             rngs,
-            input_ids,
-            attention_mask,
-            decoder_input_ids,
-            decoder_attention_mask,
+            *args,
         )["params"]

         if params is not None:
@@ -1357,6 +1357,90 @@ overwrite_call_docstring(FlaxT5Model, T5_INPUTS_DOCSTRING + FLAX_T5_MODEL_DOCSTRING)
 append_replace_return_docstrings(FlaxT5Model, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)


+@add_start_docstrings(
+    "The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.",
+    T5_START_DOCSTRING,
+)
+class FlaxT5EncoderModule(nn.Module):
+    config: T5Config
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+
+    def setup(self):
+        self.shared = nn.Embed(
+            self.config.vocab_size,
+            self.config.d_model,
+            embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0),
+        )
+
+        encoder_config = copy.deepcopy(self.config)
+        encoder_config.is_decoder = False
+        encoder_config.is_encoder_decoder = False
+        encoder_config.causal = False
+        self.encoder = FlaxT5Stack(encoder_config, embed_tokens=self.shared, dtype=self.dtype)
+
+    def __call__(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        output_attentions=False,
+        output_hidden_states=False,
+        return_dict=True,
+        deterministic: bool = True,
+    ):
+
+        # Encode if needed (training, first prediction pass)
+        encoder_outputs = self.encoder(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            deterministic=deterministic,
+        )
+
+        return encoder_outputs
+
+
+class FlaxT5EncoderModel(FlaxT5PreTrainedModel):
+    module_class = FlaxT5EncoderModule
+
+    @add_start_docstrings_to_model_forward(T5_ENCODE_INPUTS_DOCSTRING)
+    def __call__(
+        self,
+        input_ids: jnp.ndarray,
+        attention_mask: Optional[jnp.ndarray] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        train: bool = False,
+        params: dict = None,
+        dropout_rng: PRNGKey = None,
+    ):
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+        # prepare encoder inputs
+        if attention_mask is None:
+            attention_mask = jnp.ones_like(input_ids)
+
+        # Handle any PRNG if needed
+        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
+
+        return self.module.apply(
+            {"params": params or self.params},
+            input_ids=jnp.array(input_ids, dtype="i4"),
+            attention_mask=jnp.array(attention_mask, dtype="i4"),
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            deterministic=not train,
+            rngs=rngs,
+        )
+
+
 @add_start_docstrings("""T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING)
 class FlaxT5ForConditionalGenerationModule(nn.Module):
     config: T5Config
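To illustrate what the new `__call__` above returns, a hedged sketch (the checkpoint name is illustrative; the result is the standard Flax encoder output, with `hidden_states`/`attentions` populated only when requested):

```python
from transformers import FlaxT5EncoderModel, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = FlaxT5EncoderModel.from_pretrained("t5-small")

inputs = tokenizer("translate English to German: hello", return_tensors="np")
outputs = model(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    output_hidden_states=True,
    output_attentions=True,
)

print(outputs.last_hidden_state.shape)  # (batch, seq_len, d_model)
print(len(outputs.hidden_states))       # embedding output + one entry per encoder block
print(outputs.attentions[0].shape)      # (batch, num_heads, seq_len, seq_len)
```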
@@ -802,6 +802,13 @@ class FlaxMBartPreTrainedModel(metaclass=DummyObject):
         requires_backends(self, ["flax"])


+class FlaxMT5EncoderModel(metaclass=DummyObject):
+    _backends = ["flax"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+
 class FlaxMT5ForConditionalGeneration(metaclass=DummyObject):
     _backends = ["flax"]

@@ -970,6 +977,13 @@ class FlaxSpeechEncoderDecoderModel(metaclass=DummyObject):
         requires_backends(self, ["flax"])


+class FlaxT5EncoderModel(metaclass=DummyObject):
+    _backends = ["flax"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+
 class FlaxT5ForConditionalGeneration(metaclass=DummyObject):
     _backends = ["flax"]

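For context, these dummies are what users get when Flax is not installed; a sketch of the behaviour under that assumption (the exact error text produced by `requires_backends` may differ):

```python
# Without a Flax install, `from transformers import FlaxT5EncoderModel` still succeeds,
# but the name is bound to the dummy class above; instantiating it raises an error
# telling the user to install the missing backend.
from transformers import FlaxT5EncoderModel

try:
    FlaxT5EncoderModel()
except ImportError as err:
    print(err)  # mentions that FlaxT5EncoderModel requires the "flax" backend
```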
@@ -48,7 +48,12 @@ if is_flax_available():
     from flax.traverse_util import flatten_dict
     from transformers import FLAX_MODEL_MAPPING, ByT5Tokenizer, T5Config, T5Tokenizer
     from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
-    from transformers.models.t5.modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, shift_tokens_right
+    from transformers.models.t5.modeling_flax_t5 import (
+        FlaxT5EncoderModel,
+        FlaxT5ForConditionalGeneration,
+        FlaxT5Model,
+        shift_tokens_right,
+    )


 class FlaxT5ModelTester:
@@ -461,6 +466,298 @@ class FlaxT5ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
             self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")


+class FlaxT5EncoderOnlyModelTester:
+    def __init__(
+        self,
+        parent,
+        vocab_size=99,
+        batch_size=13,
+        encoder_seq_length=7,
+        # For common tests
+        is_training=True,
+        use_attention_mask=True,
+        use_labels=True,
+        hidden_size=32,
+        num_hidden_layers=5,
+        num_attention_heads=4,
+        d_ff=37,
+        relative_attention_num_buckets=8,
+        dropout_rate=0.1,
+        initializer_factor=0.002,
+        eos_token_id=1,
+        pad_token_id=0,
+        decoder_start_token_id=0,
+        scope=None,
+    ):
+
+        self.parent = parent
+        self.batch_size = batch_size
+        self.encoder_seq_length = encoder_seq_length
+        # For common tests
+        self.seq_length = self.encoder_seq_length
+        self.is_training = is_training
+        self.use_attention_mask = use_attention_mask
+        self.use_labels = use_labels
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.d_ff = d_ff
+        self.relative_attention_num_buckets = relative_attention_num_buckets
+        self.dropout_rate = dropout_rate
+        self.initializer_factor = initializer_factor
+        self.eos_token_id = eos_token_id
+        self.pad_token_id = pad_token_id
+        self.decoder_start_token_id = decoder_start_token_id
+        self.scope = None
+        self.decoder_layers = 0
+
+    def prepare_config_and_inputs(self):
+        input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size)
+
+        attention_mask = None
+        if self.use_attention_mask:
+            attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2)
+
+        config = T5Config(
+            vocab_size=self.vocab_size,
+            d_model=self.hidden_size,
+            d_ff=self.d_ff,
+            d_kv=self.hidden_size // self.num_attention_heads,
+            num_layers=self.num_hidden_layers,
+            num_decoder_layers=self.decoder_layers,
+            num_heads=self.num_attention_heads,
+            relative_attention_num_buckets=self.relative_attention_num_buckets,
+            dropout_rate=self.dropout_rate,
+            initializer_factor=self.initializer_factor,
+            eos_token_id=self.eos_token_id,
+            bos_token_id=self.pad_token_id,
+            pad_token_id=self.pad_token_id,
+            decoder_start_token_id=self.decoder_start_token_id,
+            is_encoder_decoder=False,
+        )
+
+        return (
+            config,
+            input_ids,
+            attention_mask,
+        )
+
+    def create_and_check_model(
+        self,
+        config,
+        input_ids,
+        attention_mask,
+    ):
+        model = FlaxT5EncoderModel(config=config)
+        result = model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+        )
+        result = model(input_ids=input_ids)
+        encoder_output = result.last_hidden_state
+
+        self.parent.assertEqual(encoder_output.shape, (self.batch_size, self.encoder_seq_length, self.hidden_size))
+
+    def prepare_config_and_inputs_for_common(self):
+        config_and_inputs = self.prepare_config_and_inputs()
+        (
+            config,
+            input_ids,
+            attention_mask,
+        ) = config_and_inputs
+
+        inputs_dict = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+        }
+        return config, inputs_dict
+
+
+@require_flax
+class FlaxT5EncoderOnlyModelTest(FlaxModelTesterMixin, unittest.TestCase):
+
+    all_model_classes = (FlaxT5EncoderModel,) if is_flax_available() else ()
+    is_encoder_decoder = False
+
+    def setUp(self):
+        self.model_tester = FlaxT5EncoderOnlyModelTester(self)
+        self.config_tester = ConfigTester(self, config_class=T5Config, d_model=37)
+
+    def test_config(self):
+        self.config_tester.run_common_tests()
+
+    def test_model(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_model(*config_and_inputs)
+
+    def test_model_v1_1(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        # check that gated gelu feed forward and different word embeddings work
+        config = config_and_inputs[0]
+        config.tie_word_embeddings = False
+        config.feed_forward_proj = "gated-gelu"
+        self.model_tester.create_and_check_model(config, *config_and_inputs[1:])
+
+    def test_encode(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            with self.subTest(model_class.__name__):
+                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+                model = model_class(config)
+
+                @jax.jit
+                def encode_jitted(input_ids, attention_mask=None, **kwargs):
+                    return model(input_ids=input_ids, attention_mask=attention_mask)
+
+                with self.subTest("JIT Enabled"):
+                    jitted_outputs = encode_jitted(**prepared_inputs_dict).to_tuple()
+
+                with self.subTest("JIT Disabled"):
+                    with jax.disable_jit():
+                        outputs = encode_jitted(**prepared_inputs_dict).to_tuple()
+
+                self.assertEqual(len(outputs), len(jitted_outputs))
+                for jitted_output, output in zip(jitted_outputs, outputs):
+                    self.assertEqual(jitted_output.shape, output.shape)
+
+    # overwrite since special base model prefix is used
+    def test_save_load_from_base(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+        base_class = FLAX_MODEL_MAPPING[config.__class__]
+
+        for model_class in self.all_model_classes:
+            if model_class == base_class:
+                continue
+
+            model = base_class(config)
+            base_params = flatten_dict(unfreeze(model.params))
+
+            # check that all base model weights are loaded correctly
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model.save_pretrained(tmpdirname)
+                head_model = model_class.from_pretrained(tmpdirname)
+
+                base_param_from_head = flatten_dict(unfreeze(head_model.params))
+
+                for key in base_param_from_head.keys():
+                    max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
+                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
+
+    # overwrite since special base model prefix is used
+    def test_save_load_to_base(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+        base_class = FLAX_MODEL_MAPPING[config.__class__]
+
+        for model_class in self.all_model_classes:
+            if model_class == base_class:
+                continue
+
+            model = model_class(config)
+            base_params_from_head = flatten_dict(unfreeze(model.params))
+
+            # check that all base model weights are loaded correctly
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model.save_pretrained(tmpdirname)
+                base_model = base_class.from_pretrained(tmpdirname)
+
+                base_params = flatten_dict(unfreeze(base_model.params))
+
+                for key in base_params_from_head.keys():
+                    max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
+                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
+
+    # overwrite since special base model prefix is used
+    @is_pt_flax_cross_test
+    def test_save_load_from_base_pt(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+        base_class = FLAX_MODEL_MAPPING[config.__class__]
+
+        for model_class in self.all_model_classes:
+            if model_class == base_class:
+                continue
+
+            model = base_class(config)
+            base_params = flatten_dict(unfreeze(model.params))
+
+            # convert Flax model to PyTorch model
+            pt_model_class = getattr(transformers, base_class.__name__[4:])  # Skip the "Flax" at the beginning
+            pt_model = pt_model_class(config).eval()
+            pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
+
+            # check that all base model weights are loaded correctly
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                # save pt model
+                pt_model.save_pretrained(tmpdirname)
+                head_model = model_class.from_pretrained(tmpdirname, from_pt=True)
+
+                base_param_from_head = flatten_dict(unfreeze(head_model.params))
+
+                for key in base_param_from_head.keys():
+                    max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
+                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
+
+    # overwrite since special base model prefix is used
+    @is_pt_flax_cross_test
+    def test_save_load_to_base_pt(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+        base_class = FLAX_MODEL_MAPPING[config.__class__]
+
+        for model_class in self.all_model_classes:
+            if model_class == base_class:
+                continue
+
+            model = model_class(config)
+            base_params_from_head = flatten_dict(unfreeze(model.params))
+
+            # convert Flax model to PyTorch model
+            pt_model_class = getattr(transformers, model_class.__name__[4:])  # Skip the "Flax" at the beginning
+            pt_model = pt_model_class(config).eval()
+            pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
+
+            # check that all base model weights are loaded correctly
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                pt_model.save_pretrained(tmpdirname)
+                base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
+
+                base_params = flatten_dict(unfreeze(base_model.params))
+
+                for key in base_params_from_head.keys():
+                    max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
+                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
+
+    # overwrite since special base model prefix is used
+    @is_pt_flax_cross_test
+    def test_save_load_bf16_to_base_pt(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+        base_class = FLAX_MODEL_MAPPING[config.__class__]
+
+        for model_class in self.all_model_classes:
+            if model_class == base_class:
+                continue
+
+            model = model_class(config)
+            model.params = model.to_bf16(model.params)
+            base_params_from_head = flatten_dict(unfreeze(model.params))
+
+            # convert Flax model to PyTorch model
+            pt_model_class = getattr(transformers, model_class.__name__[4:])  # Skip the "Flax" at the beginning
+            pt_model = pt_model_class(config).eval()
+            pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
+
+            # check that all base model weights are loaded correctly
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                pt_model.save_pretrained(tmpdirname)
+                base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
+
+                base_params = flatten_dict(unfreeze(base_model.params))
+
+                for key in base_params_from_head.keys():
+                    max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
+                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
+
+
 @require_sentencepiece
 @require_tokenizers
 @require_flax
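Outside the test harness, the core shape check that `create_and_check_model` performs can be reproduced directly; the tiny configuration below mirrors the tester defaults above (running the suite itself would be something like `pytest -k FlaxT5EncoderOnly` against the T5 Flax test module, path depending on the repository layout):

```python
import numpy as np
from transformers import FlaxT5EncoderModel, T5Config

# Tiny configuration mirroring FlaxT5EncoderOnlyModelTester's defaults.
config = T5Config(
    vocab_size=99,
    d_model=32,
    d_ff=37,
    d_kv=8,
    num_layers=5,
    num_decoder_layers=0,
    num_heads=4,
    relative_attention_num_buckets=8,
    is_encoder_decoder=False,
)

model = FlaxT5EncoderModel(config)
input_ids = np.ones((13, 7), dtype="i4")  # (batch_size, encoder_seq_length)
outputs = model(input_ids=input_ids)

assert outputs.last_hidden_state.shape == (13, 7, 32)
```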