Flax t5 Encoder (#17784)

* first draft adding Flax-t5-encoder and Flax-mt5-encoder

* imports

* after make fixup

* flax t5 encoder test

* black on test

* make fix-copies

* clean

* all_model_classes -> tuple

* clean test

* is_encoder_decoder=False in t5-enc tester

* remove file docstring before FlaxT5Encoder

* black

* isort

* commit suggestions on src/transformers/models/t5/modeling_flax_t5.py

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* commit suggestions on src/transformers/models/t5/modeling_flax_t5.py

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Apply suggestions from code review

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* remove _get_encoder_module

* self.decoder_seq_length -> self.encoder_seq_length as t5-enc does not have decoder

* bugfix - self.module_class is class itself, not instance;

* docs for mt5 and t5

* call -> __call__ in t5 doc

* FlaxMT5EncoderModel to TYPE_HINT

* run doc-builder to allow change the files

Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
Crystina 2022-06-29 15:49:02 -07:00 committed by GitHub
parent eb1493b15d
commit 692e61e91a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 454 additions and 15 deletions

View File

@ -96,3 +96,7 @@ See [`T5TokenizerFast`] for all details.
## FlaxMT5ForConditionalGeneration
[[autodoc]] FlaxMT5ForConditionalGeneration
## FlaxMT5EncoderModel
[[autodoc]] FlaxMT5EncoderModel

View File

@ -371,3 +371,8 @@ T5 is supported by several example scripts, both for pre-training and fine-tunin
- __call__
- encode
- decode
## FlaxT5EncoderModel
[[autodoc]] FlaxT5EncoderModel
- __call__

View File

@ -2704,7 +2704,7 @@ else:
"FlaxMBartPreTrainedModel",
]
)
_import_structure["models.mt5"].extend(["FlaxMT5ForConditionalGeneration", "FlaxMT5Model"])
_import_structure["models.mt5"].extend(["FlaxMT5EncoderModel", "FlaxMT5ForConditionalGeneration", "FlaxMT5Model"])
_import_structure["models.opt"].extend(
[
"FlaxOPTForCausalLM",
@ -2743,7 +2743,9 @@ else:
]
)
_import_structure["models.speech_encoder_decoder"].append("FlaxSpeechEncoderDecoderModel")
_import_structure["models.t5"].extend(["FlaxT5ForConditionalGeneration", "FlaxT5Model", "FlaxT5PreTrainedModel"])
_import_structure["models.t5"].extend(
["FlaxT5EncoderModel", "FlaxT5ForConditionalGeneration", "FlaxT5Model", "FlaxT5PreTrainedModel"]
)
_import_structure["models.vision_encoder_decoder"].append("FlaxVisionEncoderDecoderModel")
_import_structure["models.vision_text_dual_encoder"].extend(["FlaxVisionTextDualEncoderModel"])
_import_structure["models.vit"].extend(["FlaxViTForImageClassification", "FlaxViTModel", "FlaxViTPreTrainedModel"])
@ -4974,7 +4976,7 @@ if TYPE_CHECKING:
FlaxMBartModel,
FlaxMBartPreTrainedModel,
)
from .models.mt5 import FlaxMT5ForConditionalGeneration, FlaxMT5Model
from .models.mt5 import FlaxMT5EncoderModel, FlaxMT5ForConditionalGeneration, FlaxMT5Model
from .models.opt import FlaxOPTForCausalLM, FlaxOPTModel, FlaxOPTPreTrainedModel
from .models.pegasus import FlaxPegasusForConditionalGeneration, FlaxPegasusModel, FlaxPegasusPreTrainedModel
from .models.roberta import (
@ -4997,7 +4999,7 @@ if TYPE_CHECKING:
FlaxRoFormerPreTrainedModel,
)
from .models.speech_encoder_decoder import FlaxSpeechEncoderDecoderModel
from .models.t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel
from .models.t5 import FlaxT5EncoderModel, FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel
from .models.vision_encoder_decoder import FlaxVisionEncoderDecoderModel
from .models.vision_text_dual_encoder import FlaxVisionTextDualEncoderModel
from .models.vit import FlaxViTForImageClassification, FlaxViTModel, FlaxViTPreTrainedModel

View File

@ -67,7 +67,7 @@ try:
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_flax_mt5"] = ["FlaxMT5ForConditionalGeneration", "FlaxMT5Model"]
_import_structure["modeling_flax_mt5"] = ["FlaxMT5EncoderModel", "FlaxMT5ForConditionalGeneration", "FlaxMT5Model"]
if TYPE_CHECKING:
@ -95,7 +95,7 @@ if TYPE_CHECKING:
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_flax_mt5 import FlaxMT5ForConditionalGeneration, FlaxMT5Model
from .modeling_flax_mt5 import FlaxMT5EncoderModel, FlaxMT5ForConditionalGeneration, FlaxMT5Model
else:
import sys

View File

@ -17,7 +17,7 @@
import numpy as np
from ...utils import logging
from ..t5.modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model
from ..t5.modeling_flax_t5 import FlaxT5EncoderModel, FlaxT5ForConditionalGeneration, FlaxT5Model
from .configuration_mt5 import MT5Config
@ -67,6 +67,33 @@ class FlaxMT5Model(FlaxT5Model):
config_class = MT5Config
class FlaxMT5EncoderModel(FlaxT5EncoderModel):
r"""
This class overrides [`FlaxT5EncoderModel`]. Please check the superclass for the appropriate documentation
alongside usage examples.
Examples:
```python
>>> from transformers import FlaxT5EncoderModel, T5Tokenizer
>>> model = FlaxT5EncoderModel.from_pretrained("google/mt5-small")
>>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
>>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
>>> summary = "Weiter Verhandlung in Syrien."
>>> inputs = tokenizer(article, return_tensors="np")
>>> with tokenizer.as_target_tokenizer():
... decoder_input_ids = tokenizer(summary, return_tensors="np").input_ids
>>> outputs = model(input_ids=inputs["input_ids"])
>>> hidden_states = outputs.last_hidden_state
```"""
model_type = "mt5"
config_class = MT5Config
class FlaxMT5ForConditionalGeneration(FlaxT5ForConditionalGeneration):
r"""
This class overrides [`FlaxT5ForConditionalGeneration`]. Please check the superclass for the appropriate

View File

@ -83,6 +83,7 @@ except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_flax_t5"] = [
"FlaxT5EncoderModel",
"FlaxT5ForConditionalGeneration",
"FlaxT5Model",
"FlaxT5PreTrainedModel",
@ -143,7 +144,12 @@ if TYPE_CHECKING:
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel
from .modeling_flax_t5 import (
FlaxT5EncoderModel,
FlaxT5ForConditionalGeneration,
FlaxT5Model,
FlaxT5PreTrainedModel,
)
else:

View File

@ -929,18 +929,18 @@ class FlaxT5PreTrainedModel(FlaxPreTrainedModel):
input_ids = jnp.zeros(input_shape, dtype="i4")
attention_mask = jnp.ones_like(input_ids)
decoder_input_ids = jnp.ones_like(input_ids)
decoder_attention_mask = jnp.ones_like(input_ids)
args = [input_ids, attention_mask]
if self.module_class not in [FlaxT5EncoderModule]:
decoder_input_ids = jnp.ones_like(input_ids)
decoder_attention_mask = jnp.ones_like(input_ids)
args.extend([decoder_input_ids, decoder_attention_mask])
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
random_params = self.module.init(
rngs,
input_ids,
attention_mask,
decoder_input_ids,
decoder_attention_mask,
*args,
)["params"]
if params is not None:
@ -1357,6 +1357,90 @@ overwrite_call_docstring(FlaxT5Model, T5_INPUTS_DOCSTRING + FLAX_T5_MODEL_DOCSTR
append_replace_return_docstrings(FlaxT5Model, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
@add_start_docstrings(
"The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.",
T5_START_DOCSTRING,
)
class FlaxT5EncoderModule(nn.Module):
config: T5Config
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.shared = nn.Embed(
self.config.vocab_size,
self.config.d_model,
embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0),
)
encoder_config = copy.deepcopy(self.config)
encoder_config.is_decoder = False
encoder_config.is_encoder_decoder = False
encoder_config.causal = False
self.encoder = FlaxT5Stack(encoder_config, embed_tokens=self.shared, dtype=self.dtype)
def __call__(
self,
input_ids=None,
attention_mask=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
deterministic: bool = True,
):
# Encode if needed (training, first prediction pass)
encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=deterministic,
)
return encoder_outputs
class FlaxT5EncoderModel(FlaxT5PreTrainedModel):
module_class = FlaxT5EncoderModule
@add_start_docstrings_to_model_forward(T5_ENCODE_INPUTS_DOCSTRING)
def __call__(
self,
input_ids: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
# prepare encoder inputs
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
# Handle any PRNG if needed
rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
return self.module.apply(
{"params": params or self.params},
input_ids=jnp.array(input_ids, dtype="i4"),
attention_mask=jnp.array(attention_mask, dtype="i4"),
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=not train,
rngs=rngs,
)
@add_start_docstrings("""T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING)
class FlaxT5ForConditionalGenerationModule(nn.Module):
config: T5Config

View File

@ -802,6 +802,13 @@ class FlaxMBartPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["flax"])
class FlaxMT5EncoderModel(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxMT5ForConditionalGeneration(metaclass=DummyObject):
_backends = ["flax"]
@ -970,6 +977,13 @@ class FlaxSpeechEncoderDecoderModel(metaclass=DummyObject):
requires_backends(self, ["flax"])
class FlaxT5EncoderModel(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxT5ForConditionalGeneration(metaclass=DummyObject):
_backends = ["flax"]

View File

@ -48,7 +48,12 @@ if is_flax_available():
from flax.traverse_util import flatten_dict
from transformers import FLAX_MODEL_MAPPING, ByT5Tokenizer, T5Config, T5Tokenizer
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
from transformers.models.t5.modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, shift_tokens_right
from transformers.models.t5.modeling_flax_t5 import (
FlaxT5EncoderModel,
FlaxT5ForConditionalGeneration,
FlaxT5Model,
shift_tokens_right,
)
class FlaxT5ModelTester:
@ -461,6 +466,298 @@ class FlaxT5ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
class FlaxT5EncoderOnlyModelTester:
def __init__(
self,
parent,
vocab_size=99,
batch_size=13,
encoder_seq_length=7,
# For common tests
is_training=True,
use_attention_mask=True,
use_labels=True,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
d_ff=37,
relative_attention_num_buckets=8,
dropout_rate=0.1,
initializer_factor=0.002,
eos_token_id=1,
pad_token_id=0,
decoder_start_token_id=0,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.encoder_seq_length = encoder_seq_length
# For common tests
self.seq_length = self.encoder_seq_length
self.is_training = is_training
self.use_attention_mask = use_attention_mask
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.d_ff = d_ff
self.relative_attention_num_buckets = relative_attention_num_buckets
self.dropout_rate = dropout_rate
self.initializer_factor = initializer_factor
self.eos_token_id = eos_token_id
self.pad_token_id = pad_token_id
self.decoder_start_token_id = decoder_start_token_id
self.scope = None
self.decoder_layers = 0
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size)
attention_mask = None
if self.use_attention_mask:
attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2)
config = T5Config(
vocab_size=self.vocab_size,
d_model=self.hidden_size,
d_ff=self.d_ff,
d_kv=self.hidden_size // self.num_attention_heads,
num_layers=self.num_hidden_layers,
num_decoder_layers=self.decoder_layers,
num_heads=self.num_attention_heads,
relative_attention_num_buckets=self.relative_attention_num_buckets,
dropout_rate=self.dropout_rate,
initializer_factor=self.initializer_factor,
eos_token_id=self.eos_token_id,
bos_token_id=self.pad_token_id,
pad_token_id=self.pad_token_id,
decoder_start_token_id=self.decoder_start_token_id,
is_encoder_decoder=False,
)
return (
config,
input_ids,
attention_mask,
)
def create_and_check_model(
self,
config,
input_ids,
attention_mask,
):
model = FlaxT5EncoderModel(config=config)
result = model(
input_ids=input_ids,
attention_mask=attention_mask,
)
result = model(input_ids=input_ids)
encoder_output = result.last_hidden_state
self.parent.assertEqual(encoder_output.shape, (self.batch_size, self.encoder_seq_length, self.hidden_size))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
attention_mask,
) = config_and_inputs
inputs_dict = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}
return config, inputs_dict
@require_flax
class FlaxT5EncoderOnlyModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = (FlaxT5EncoderModel,) if is_flax_available() else ()
is_encoder_decoder = False
def setUp(self):
self.model_tester = FlaxT5EncoderOnlyModelTester(self)
self.config_tester = ConfigTester(self, config_class=T5Config, d_model=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_model_v1_1(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
# check that gated gelu feed forward and different word embeddings work
config = config_and_inputs[0]
config.tie_word_embeddings = False
config.feed_forward_proj = "gated-gelu"
self.model_tester.create_and_check_model(config, *config_and_inputs[1:])
def test_encode(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
@jax.jit
def encode_jitted(input_ids, attention_mask=None, **kwargs):
return model(input_ids=input_ids, attention_mask=attention_mask)
with self.subTest("JIT Enabled"):
jitted_outputs = encode_jitted(**prepared_inputs_dict).to_tuple()
with self.subTest("JIT Disabled"):
with jax.disable_jit():
outputs = encode_jitted(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(outputs), len(jitted_outputs))
for jitted_output, output in zip(jitted_outputs, outputs):
self.assertEqual(jitted_output.shape, output.shape)
# overwrite since special base model prefix is used
def test_save_load_from_base(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
base_class = FLAX_MODEL_MAPPING[config.__class__]
for model_class in self.all_model_classes:
if model_class == base_class:
continue
model = base_class(config)
base_params = flatten_dict(unfreeze(model.params))
# check that all base model weights are loaded correctly
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
head_model = model_class.from_pretrained(tmpdirname)
base_param_from_head = flatten_dict(unfreeze(head_model.params))
for key in base_param_from_head.keys():
max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
# overwrite since special base model prefix is used
def test_save_load_to_base(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
base_class = FLAX_MODEL_MAPPING[config.__class__]
for model_class in self.all_model_classes:
if model_class == base_class:
continue
model = model_class(config)
base_params_from_head = flatten_dict(unfreeze(model.params))
# check that all base model weights are loaded correctly
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
base_model = base_class.from_pretrained(tmpdirname)
base_params = flatten_dict(unfreeze(base_model.params))
for key in base_params_from_head.keys():
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
# overwrite since special base model prefix is used
@is_pt_flax_cross_test
def test_save_load_from_base_pt(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
base_class = FLAX_MODEL_MAPPING[config.__class__]
for model_class in self.all_model_classes:
if model_class == base_class:
continue
model = base_class(config)
base_params = flatten_dict(unfreeze(model.params))
# convert Flax model to PyTorch model
pt_model_class = getattr(transformers, base_class.__name__[4:]) # Skip the "Flax" at the beginning
pt_model = pt_model_class(config).eval()
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
# check that all base model weights are loaded correctly
with tempfile.TemporaryDirectory() as tmpdirname:
# save pt model
pt_model.save_pretrained(tmpdirname)
head_model = model_class.from_pretrained(tmpdirname, from_pt=True)
base_param_from_head = flatten_dict(unfreeze(head_model.params))
for key in base_param_from_head.keys():
max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
# overwrite since special base model prefix is used
@is_pt_flax_cross_test
def test_save_load_to_base_pt(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
base_class = FLAX_MODEL_MAPPING[config.__class__]
for model_class in self.all_model_classes:
if model_class == base_class:
continue
model = model_class(config)
base_params_from_head = flatten_dict(unfreeze(model.params))
# convert Flax model to PyTorch model
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
pt_model = pt_model_class(config).eval()
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
# check that all base model weights are loaded correctly
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
base_params = flatten_dict(unfreeze(base_model.params))
for key in base_params_from_head.keys():
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
# overwrite since special base model prefix is used
@is_pt_flax_cross_test
def test_save_load_bf16_to_base_pt(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
base_class = FLAX_MODEL_MAPPING[config.__class__]
for model_class in self.all_model_classes:
if model_class == base_class:
continue
model = model_class(config)
model.params = model.to_bf16(model.params)
base_params_from_head = flatten_dict(unfreeze(model.params))
# convert Flax model to PyTorch model
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
pt_model = pt_model_class(config).eval()
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
# check that all base model weights are loaded correctly
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
base_params = flatten_dict(unfreeze(base_model.params))
for key in base_params_from_head.keys():
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
@require_sentencepiece
@require_tokenizers
@require_flax