Flax Remat for LongT5 (#17994)

* [Flax] Add remat (gradient checkpointing)

* fix variable naming in test

* flip: checkpoint using a method

* fix naming

* fix class naming

* apply PVP's suggestions from code review

* add gradient_checkpointing to examples

* Add gradient_checkpointing to run_mlm_flax

* Add remat to longt5

* Add gradient checkpointing test longt5

* Fix args errors

* Fix remaining tests

* Make fixup & quality fixes

* replace kwargs

* remove unnecessary kwargs

* Make fixup changes

* revert long_t5_flax changes

* Remove return_dict and copy to LongT5

* Remove test_gradient_checkpointing

Co-authored-by: sanchit-gandhi <sanchit@huggingface.co>
Karim Foda 2022-08-14 17:27:13 +02:00 committed by GitHub
parent 1ccd2515ed
commit d6eeb87170
4 changed files with 149 additions and 39 deletions
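For orientation before the diff: the user-visible result of this commit is a gradient_checkpointing flag on the Flax T5/LongT5 modules plus an enable_gradient_checkpointing() method on the pretrained model classes. A minimal usage sketch (not part of the diff; the checkpoint name is taken from the docstring constant below, and LongT5 is assumed to gain the same method shown for FlaxT5PreTrainedModel in the last file):

from transformers import FlaxLongT5ForConditionalGeneration

model = FlaxLongT5ForConditionalGeneration.from_pretrained("google/long-t5-local-base")

# Rebuilds the underlying Flax module with gradient_checkpointing=True.
# Parameters are unchanged; only forward/backward memory behaviour differs.
model.enable_gradient_checkpointing()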

View File

@@ -107,6 +107,12 @@ class TrainingArguments:
default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
)
hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
gradient_checkpointing: bool = field(
default=False,
metadata={
"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
},
)
def __post_init__(self):
if self.output_dir is not None:
@@ -640,6 +646,9 @@ def main():
dtype=getattr(jnp, model_args.dtype),
)
if training_args.gradient_checkpointing:
model.enable_gradient_checkpointing()
# Store some constant
num_epochs = int(training_args.num_train_epochs)
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
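Both example scripts pick the new field up through the usual dataclass-to-CLI flow, so a training run can simply pass --gradient_checkpointing on the command line. A self-contained sketch of that flow (not part of the diff; it assumes HfArgumentParser's usual handling of boolean fields and trims the dataclass down to the new field):

from dataclasses import dataclass, field
from transformers import HfArgumentParser

@dataclass
class TrainingArguments:
    # Mirrors the field added above; all other fields omitted for brevity.
    gradient_checkpointing: bool = field(
        default=False,
        metadata={"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."},
    )

parser = HfArgumentParser(TrainingArguments)
# Passing the bare flag is expected to flip the field to True.
(training_args,) = parser.parse_args_into_dataclasses(args=["--gradient_checkpointing"])
assert training_args.gradient_checkpointing is True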

View File

@@ -121,6 +121,12 @@ class TrainingArguments:
default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
)
hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
gradient_checkpointing: bool = field(
default=False,
metadata={
"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
},
)
def __post_init__(self):
if self.output_dir is not None:
@@ -535,6 +541,9 @@ def main():
dtype=getattr(jnp, model_args.dtype),
)
if training_args.gradient_checkpointing:
model.enable_gradient_checkpointing()
if model.config.decoder_start_token_id is None:
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

View File

@@ -25,6 +25,7 @@ import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen import partitioning as nn_partitioning
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax.random import PRNGKey
@@ -53,6 +54,8 @@ _CHECKPOINT_FOR_DOC = "google/long-t5-local-base"
_CONFIG_FOR_DOC = "LongT5Config"
_TOKENIZER_FOR_DOC = "T5Tokenizer"
remat = nn_partitioning.remat
# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
@@ -1356,7 +1359,6 @@ class FlaxLongT5LayerCollection(nn.Module):
encoder_attention_mask=None,
encoder_decoder_position_bias=None,
output_attentions=False,
return_dict=True,
deterministic=True,
init_cache=False,
):
@@ -1377,13 +1379,31 @@
class FlaxLongT5BlockCollection(nn.Module):
config: LongT5Config
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False
def setup(self):
self.causal = self.config.causal
self.blocks = [
FlaxLongT5LayerCollection(self.config, has_relative_attention_bias=(i == 0), dtype=self.dtype, name=str(i))
for i in range(self.config.num_layers)
]
if self.gradient_checkpointing:
FlaxLongT5CheckpointLayer = remat(FlaxLongT5LayerCollection, static_argnums=(6, 7, 8))
self.blocks = [
FlaxLongT5CheckpointLayer(
self.config,
has_relative_attention_bias=(i == 0),
dtype=self.dtype,
name=str(i),
)
for i in range(self.config.num_layers)
]
else:
self.blocks = [
FlaxLongT5LayerCollection(
self.config,
has_relative_attention_bias=(i == 0),
dtype=self.dtype,
name=str(i),
)
for i in range(self.config.num_layers)
]
def __call__(
self,
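remat is flax.linen's gradient-checkpointing (rematerialization) transform: wrapping a module class with it makes the backward pass recompute the block's intermediate activations instead of keeping them in memory. static_argnums=(6, 7, 8) marks the three Python booleans of the layer's __call__ (output_attentions, deterministic, init_cache, counted from the first argument after self) as static, since they drive Python-level control flow and must not be traced. A self-contained toy sketch of the same pattern (names and sizes are illustrative, not from the diff):

import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen import partitioning as nn_partitioning

remat = nn_partitioning.remat

class ToyLayer(nn.Module):
    features: int = 16

    @nn.compact
    def __call__(self, hidden_states, deterministic: bool = True):
        hidden_states = nn.Dense(self.features)(hidden_states)
        if not deterministic:  # Python-level branch -> this argument must be static
            hidden_states = nn.Dropout(rate=0.1)(hidden_states, deterministic=False)
        return nn.relu(hidden_states)

# `deterministic` is positional index 1 (self excluded), hence static_argnums=(1,).
ToyCheckpointLayer = remat(ToyLayer, static_argnums=(1,))

x = jnp.ones((2, 8))
params = ToyCheckpointLayer(features=16).init(jax.random.PRNGKey(0), x, True)
y = ToyCheckpointLayer(features=16).apply(params, x, True)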
@@ -1409,14 +1429,14 @@ class FlaxLongT5BlockCollection(nn.Module):
layer_outputs = layer_module(
hidden_states,
attention_mask=attention_mask,
position_bias=position_bias,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
encoder_decoder_position_bias=encoder_decoder_position_bias,
output_attentions=output_attentions,
deterministic=deterministic,
init_cache=init_cache,
attention_mask,
position_bias,
encoder_hidden_states,
encoder_attention_mask,
encoder_decoder_position_bias,
output_attentions,
deterministic,
init_cache,
)
hidden_states = layer_outputs[0]
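The call above switches from keyword to positional arguments because static_argnums identifies arguments purely by position; this is also why return_dict was dropped from the layer signature earlier in this file, so that the three static flags land exactly at indices 6, 7 and 8. The mapping, written out for reference (indices counted from the first argument after self, as assumed above):

# 0: hidden_states                  (traced array)
# 1: attention_mask                 (traced array)
# 2: position_bias                  (traced array)
# 3: encoder_hidden_states          (traced array)
# 4: encoder_attention_mask         (traced array)
# 5: encoder_decoder_position_bias  (traced array)
# 6: output_attentions              (static Python bool)
# 7: deterministic                  (static Python bool)
# 8: init_cache                     (static Python bool)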
@@ -1447,11 +1467,14 @@ class FlaxLongT5Stack(nn.Module):
config: LongT5Config
embed_tokens: nn.Embed
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False
def setup(self):
self.causal = self.config.causal
self.block = FlaxLongT5BlockCollection(self.config, dtype=self.dtype)
self.block = FlaxLongT5BlockCollection(
self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)
self.final_layer_norm = FlaxLongT5LayerNorm(
self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
)
@@ -1989,6 +2012,7 @@ LONGT5_START_DOCSTRING = r"""
class FlaxLongT5Module(nn.Module):
config: LongT5Config
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False
def _get_encoder_module(self):
return self.encoder
@@ -2005,12 +2029,22 @@ class FlaxLongT5Module(nn.Module):
encoder_config = copy.deepcopy(self.config)
encoder_config.causal = False
self.encoder = FlaxLongT5Stack(encoder_config, embed_tokens=self.shared, dtype=self.dtype)
self.encoder = FlaxLongT5Stack(
encoder_config,
embed_tokens=self.shared,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
decoder_config = copy.deepcopy(self.config)
decoder_config.causal = True
decoder_config.num_layers = self.config.num_decoder_layers
self.decoder = FlaxLongT5Stack(decoder_config, embed_tokens=self.shared, dtype=self.dtype)
self.decoder = FlaxLongT5Stack(
decoder_config,
embed_tokens=self.shared,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
def __call__(
self,
@@ -2104,6 +2138,7 @@ append_replace_return_docstrings(FlaxLongT5Model, output_type=FlaxSeq2SeqLMOutpu
class FlaxLongT5ForConditionalGenerationModule(nn.Module):
config: LongT5Config
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False
def _get_encoder_module(self):
return self.encoder
@@ -2124,13 +2159,17 @@ class FlaxLongT5ForConditionalGenerationModule(nn.Module):
encoder_config.causal = False
encoder_config.use_cache = False
encoder_config.is_encoder_decoder = False
self.encoder = FlaxLongT5Stack(encoder_config, self.shared, dtype=self.dtype)
self.encoder = FlaxLongT5Stack(
encoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)
decoder_config = copy.deepcopy(self.config)
decoder_config.causal = True
decoder_config.is_encoder_decoder = False
decoder_config.num_layers = self.config.num_decoder_layers
self.decoder = FlaxLongT5Stack(decoder_config, self.shared, dtype=self.dtype)
self.decoder = FlaxLongT5Stack(
decoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)
self.lm_head = nn.Dense(
self.config.vocab_size,
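The commit message also mentions a gradient-checkpointing test for LongT5. The essential property such a test checks is that rematerialization only trades compute for memory, so forward outputs must be unchanged. A hedged sketch of that check (checkpoint name, inputs and tolerance are illustrative, not taken from the actual test suite):

import numpy as np
from transformers import FlaxLongT5ForConditionalGeneration

model = FlaxLongT5ForConditionalGeneration.from_pretrained("google/long-t5-local-base")
input_ids = np.array([[4, 5, 6, 1]], dtype="i4")
decoder_input_ids = np.array([[0, 4]], dtype="i4")

baseline = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits

model.enable_gradient_checkpointing()
remat_logits = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits

# Same parameters, same math: only peak memory and backward-pass compute differ.
np.testing.assert_allclose(baseline, remat_logits, atol=1e-5)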

View File

@@ -25,6 +25,7 @@ import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen import partitioning as nn_partitioning
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax.random import PRNGKey
@@ -53,6 +54,8 @@ _CHECKPOINT_FOR_DOC = "t5-small"
_CONFIG_FOR_DOC = "T5Config"
_TOKENIZER_FOR_DOC = "T5Tokenizer"
remat = nn_partitioning.remat
# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
@@ -622,7 +625,6 @@ class FlaxT5LayerCollection(nn.Module):
encoder_attention_mask=None,
encoder_decoder_position_bias=None,
output_attentions=False,
return_dict=True,
deterministic=True,
init_cache=False,
):
@@ -642,13 +644,31 @@
class FlaxT5BlockCollection(nn.Module):
config: T5Config
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False
def setup(self):
self.causal = self.config.causal
self.blocks = [
FlaxT5LayerCollection(self.config, has_relative_attention_bias=(i == 0), dtype=self.dtype, name=str(i))
for i in range(self.config.num_layers)
]
if self.gradient_checkpointing:
FlaxT5CheckpointLayer = remat(FlaxT5LayerCollection, static_argnums=(6, 7, 8))
self.blocks = [
FlaxT5CheckpointLayer(
self.config,
has_relative_attention_bias=(i == 0),
dtype=self.dtype,
name=str(i),
)
for i in range(self.config.num_layers)
]
else:
self.blocks = [
FlaxT5LayerCollection(
self.config,
has_relative_attention_bias=(i == 0),
dtype=self.dtype,
name=str(i),
)
for i in range(self.config.num_layers)
]
def __call__(
self,
@@ -674,14 +694,14 @@ class FlaxT5BlockCollection(nn.Module):
layer_outputs = layer_module(
hidden_states,
attention_mask=attention_mask,
position_bias=position_bias,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
encoder_decoder_position_bias=encoder_decoder_position_bias,
output_attentions=output_attentions,
deterministic=deterministic,
init_cache=init_cache,
attention_mask,
position_bias,
encoder_hidden_states,
encoder_attention_mask,
encoder_decoder_position_bias,
output_attentions,
deterministic,
init_cache,
)
hidden_states = layer_outputs[0]
@@ -711,11 +731,14 @@ class FlaxT5Stack(nn.Module):
config: T5Config
embed_tokens: nn.Embed
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False
def setup(self):
self.causal = self.config.causal
self.block = FlaxT5BlockCollection(self.config, dtype=self.dtype)
self.block = FlaxT5BlockCollection(
self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)
self.final_layer_norm = FlaxT5LayerNorm(
self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
)
@@ -919,11 +942,19 @@ class FlaxT5PreTrainedModel(FlaxPreTrainedModel):
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
gradient_checkpointing: bool = False,
**kwargs
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def enable_gradient_checkpointing(self):
self._module = self.module_class(
config=self.config,
dtype=self.dtype,
gradient_checkpointing=True,
)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
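enable_gradient_checkpointing() simply rebinds the model's module to a copy built with gradient_checkpointing=True, so checkpointing can be toggled after from_pretrained without touching the parameters. The saving materialises in the backward pass of the training step; a hedged sketch in the style of the Flax example scripts (TrainState construction, batching and the loss are illustrative placeholders):

import jax
import optax

def train_step(state, batch, dropout_rng):
    # `state` is a flax.training.train_state.TrainState whose apply_fn is the
    # model's __call__; `batch` holds input_ids, attention_mask, labels, ...
    labels = batch.pop("labels")

    def loss_fn(params):
        logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
        return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

    # With remat enabled, each checkpointed block's activations are recomputed
    # here instead of being stored during the forward pass.
    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads), loss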
@@ -1248,6 +1279,7 @@ T5_START_DOCSTRING = r"""
class FlaxT5Module(nn.Module):
config: T5Config
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False
def _get_encoder_module(self):
return self.encoder
@@ -1264,12 +1296,22 @@ class FlaxT5Module(nn.Module):
encoder_config = copy.deepcopy(self.config)
encoder_config.causal = False
self.encoder = FlaxT5Stack(encoder_config, embed_tokens=self.shared, dtype=self.dtype)
self.encoder = FlaxT5Stack(
encoder_config,
embed_tokens=self.shared,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
decoder_config = copy.deepcopy(self.config)
decoder_config.causal = True
decoder_config.num_layers = self.config.num_decoder_layers
self.decoder = FlaxT5Stack(decoder_config, embed_tokens=self.shared, dtype=self.dtype)
self.decoder = FlaxT5Stack(
decoder_config,
embed_tokens=self.shared,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
def __call__(
self,
@@ -1364,6 +1406,7 @@ append_replace_return_docstrings(FlaxT5Model, output_type=FlaxSeq2SeqLMOutput, c
class FlaxT5EncoderModule(nn.Module):
config: T5Config
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False
def setup(self):
self.shared = nn.Embed(
@@ -1376,7 +1419,12 @@ class FlaxT5EncoderModule(nn.Module):
encoder_config.is_decoder = False
encoder_config.is_encoder_decoder = False
encoder_config.causal = False
self.encoder = FlaxT5Stack(encoder_config, embed_tokens=self.shared, dtype=self.dtype)
self.encoder = FlaxT5Stack(
encoder_config,
embed_tokens=self.shared,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
def __call__(
self,
@@ -1384,7 +1432,7 @@
attention_mask=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
return_dict: bool = True,
deterministic: bool = True,
):
@@ -1445,6 +1493,7 @@ class FlaxT5EncoderModel(FlaxT5PreTrainedModel):
class FlaxT5ForConditionalGenerationModule(nn.Module):
config: T5Config
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False
def _get_encoder_module(self):
return self.encoder
@@ -1465,13 +1514,17 @@ class FlaxT5ForConditionalGenerationModule(nn.Module):
encoder_config.causal = False
encoder_config.use_cache = False
encoder_config.is_encoder_decoder = False
self.encoder = FlaxT5Stack(encoder_config, self.shared, dtype=self.dtype)
self.encoder = FlaxT5Stack(
encoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)
decoder_config = copy.deepcopy(self.config)
decoder_config.causal = True
decoder_config.is_encoder_decoder = False
decoder_config.num_layers = self.config.num_decoder_layers
self.decoder = FlaxT5Stack(decoder_config, self.shared, dtype=self.dtype)
self.decoder = FlaxT5Stack(
decoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)
self.lm_head = nn.Dense(
self.config.vocab_size,