[Flax] improve large model init and loading (#16148)

* begin do_init

* add params_shape_tree

* raise error if params are accessed when do_init is False

* don't allow do_init=False when keys are missing

* make shape tree a property

* assign self._params at the end

* add test for do_init

* add do_init arg to all flax models

* fix param setting

* disbale do_init for composite models

* update test

* add do_init in FlaxBigBirdForMultipleChoice

* better names and errors

* improve test

* style

* add a warning when do_init=False

* remove extra if

* set params after _required_params

* add test for from_pretrained

* do_init => _do_init

* chage warning to info

* fix typo

* add params in init_weights

* add params to gpt neo init

* add params to init_weights

* update do_init test

* Trigger CI

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* update template

* trigger CI

* style

* style

* fix template

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Suraj Patil 2022-04-19 14:19:55 +02:00 committed by GitHub
parent 6de4ee61a0
commit d3bd9ac728
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
30 changed files with 702 additions and 148 deletions

View File

@ -140,7 +140,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensor
input_ids = jnp.zeros(input_shape[0], dtype="i4")
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape[0])

View File

@ -90,6 +90,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
base_model_prefix = ""
main_input_name = "input_ids"
_auto_class = None
_missing_keys = set()
def __init__(
self,
@ -98,6 +99,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
input_shape: Tuple = (1, 1),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
):
if config is None:
raise ValueError("config cannot be None")
@ -112,15 +114,35 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
# Those are public as their type is generic to every derived classes.
self.key = PRNGKey(seed)
self.dtype = dtype
self.input_shape = input_shape
# randomly initialized parameters
random_params = self.init_weights(self.key, input_shape)
# To check if the model was intialized automatically.
self._is_initialized = _do_init
if _do_init:
# randomly initialized parameters
random_params = self.init_weights(self.key, input_shape)
params_shape_tree = jax.eval_shape(lambda params: params, random_params)
else:
init_fn = partial(self.init_weights, input_shape=input_shape)
params_shape_tree = jax.eval_shape(init_fn, self.key)
logger.info(
"Model weights are not initialized as `_do_init` is set to `False`. "
f"Make sure to call `{self.__class__.__name__}.init_weights` manually to initialize the weights."
)
# get the shape of the parameters
self._params_shape_tree = params_shape_tree
# save required_params as set
self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
self.params = random_params
self._required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> Dict:
# initialize the parameters
if _do_init:
self.params = random_params
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> Dict:
raise NotImplementedError(f"init method has to be implemented for {self}")
@classmethod
@ -147,14 +169,31 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
@property
def params(self) -> Union[Dict, FrozenDict]:
if not self._is_initialized:
raise ValueError(
"`params` cannot be accessed from model when the model is created with `_do_init=False`. "
"You must call `init_weights` manually and store the params outside of the model and "
"pass it explicitly where needed."
)
return self._params
@property
def required_params(self) -> Set:
return self._required_params
@property
def params_shape_tree(self) -> Dict:
return self._params_shape_tree
@params.setter
def params(self, params: Union[Dict, FrozenDict]):
# don't set params if the model is not initialized
if not self._is_initialized:
raise ValueError(
"`params` cannot be set from model when the model is created with `_do_init=False`. "
"You store the params outside of the model."
)
if isinstance(params, FrozenDict):
params = unfreeze(params)
param_keys = set(flatten_dict(params).keys())
@ -417,6 +456,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
revision = kwargs.pop("revision", None)
from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False)
_do_init = kwargs.pop("_do_init", True)
user_agent = {"file_type": "model", "framework": "flax", "from_auto_class": from_auto_class}
if from_pipeline is not None:
@ -553,7 +593,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
resolved_archive_file = None
# init random models
model = cls(config, *model_args, **model_kwargs)
model = cls(config, *model_args, _do_init=_do_init, **model_kwargs)
if from_pt:
state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file)
@ -577,25 +617,36 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
# make sure all arrays are stored as jnp.arrays
# NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
# https://github.com/google/flax/issues/1261
state = jax.tree_util.tree_map(jnp.array, state)
if _do_init:
state = jax.tree_util.tree_map(jnp.array, state)
else:
# keep the params on CPU if we don't want to initialize
state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state)
# if model is base model only use model_prefix key
if cls.base_model_prefix not in dict(model.params) and cls.base_model_prefix in state:
if cls.base_model_prefix not in dict(model.params_shape_tree) and cls.base_model_prefix in state:
state = state[cls.base_model_prefix]
# if model is head model and we are loading weights from base model
# we initialize new params dict with base_model_prefix
if cls.base_model_prefix in dict(model.params) and cls.base_model_prefix not in state:
if cls.base_model_prefix in dict(model.params_shape_tree) and cls.base_model_prefix not in state:
state = {cls.base_model_prefix: state}
# flatten dicts
state = flatten_dict(state)
random_state = flatten_dict(unfreeze(model.params))
random_state = flatten_dict(unfreeze(model.params if _do_init else model.params_shape_tree))
missing_keys = model.required_params - set(state.keys())
unexpected_keys = set(state.keys()) - model.required_params
if missing_keys and not _do_init:
logger.warn(
f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. "
f"Make sure to call model.init_weights to initialize the missing weights."
)
cls._missing_keys = missing_keys
# Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
# matching the weights in the model.
mismatched_keys = []
@ -612,9 +663,10 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
"model."
)
# add missing keys as random parameters
for missing_key in missing_keys:
state[missing_key] = random_state[missing_key]
# add missing keys as random parameters if we are initializing
if missing_keys and _do_init:
for missing_key in missing_keys:
state[missing_key] = random_state[missing_key]
# remove unexpected keys to not be saved again
for unexpected_key in unexpected_keys:
@ -680,10 +732,12 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
"See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this."
)
# set correct parameters
model.params = unflatten_dict(state)
return model
if _do_init:
# set correct parameters
model.params = unflatten_dict(state)
return model
else:
return model, unflatten_dict(state)
def save_pretrained(self, save_directory: Union[str, os.PathLike], params=None, push_to_hub=False, **kwargs):
"""

View File

@ -21,8 +21,9 @@ import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from ...modeling_flax_outputs import (
@ -522,12 +523,13 @@ class FlaxAlbertPreTrainedModel(FlaxPreTrainedModel):
input_shape: Tuple = (1, 1),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
token_type_ids = jnp.zeros_like(input_ids)
@ -537,9 +539,19 @@ class FlaxAlbertPreTrainedModel(FlaxPreTrainedModel):
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids, return_dict=False)[
"params"
]
random_params = self.module.init(
rngs, input_ids, attention_mask, token_type_ids, position_ids, return_dict=False
)["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def __call__(

View File

@ -24,9 +24,10 @@ import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from jax.random import PRNGKey
@ -912,12 +913,13 @@ class FlaxBartPreTrainedModel(FlaxPreTrainedModel):
input_shape: Tuple[int] = (1, 1),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
# make sure initialization pass will work for FlaxBartForSequenceClassificationModule
@ -933,7 +935,7 @@ class FlaxBartPreTrainedModel(FlaxPreTrainedModel):
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
return self.module.init(
random_params = self.module.init(
rngs,
input_ids,
attention_mask,
@ -943,6 +945,16 @@ class FlaxBartPreTrainedModel(FlaxPreTrainedModel):
decoder_position_ids,
)["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
def init_cache(self, batch_size, max_length, encoder_outputs):
r"""
Args:
@ -1737,14 +1749,15 @@ class FlaxBartDecoderPreTrainedModel(FlaxPreTrainedModel):
input_shape: Tuple[int] = (1, 1),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs
):
config.is_decoder = True
config.is_encoder_decoder = False
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
attention_mask = jnp.ones_like(input_ids)

View File

@ -22,8 +22,9 @@ import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from ...modeling_flax_outputs import (
FlaxBaseModelOutput,
@ -591,13 +592,21 @@ class FlaxBeitPreTrainedModel(FlaxPreTrainedModel):
main_input_name = "pixel_values"
module_class: nn.Module = None
def __init__(self, config: BeitConfig, input_shape=None, seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs):
def __init__(
self,
config: BeitConfig,
input_shape=None,
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
if input_shape is None:
input_shape = (1, config.image_size, config.image_size, 3)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
pixel_values = jnp.zeros(input_shape, dtype=self.dtype)
@ -605,7 +614,17 @@ class FlaxBeitPreTrainedModel(FlaxPreTrainedModel):
dropout_rng, droppath_rng = jax.random.split(dropout_rng)
rngs = {"params": params_rng, "dropout": dropout_rng, "droppath": droppath_rng}
return self.module.init(rngs, pixel_values, return_dict=False)["params"]
random_params = self.module.init(rngs, pixel_values, return_dict=False)["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
@add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def __call__(

View File

@ -21,8 +21,9 @@ import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from ...modeling_flax_outputs import (
@ -616,12 +617,18 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
module_class: nn.Module = None
def __init__(
self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
self,
config: BertConfig,
input_shape: Tuple = (1, 1),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
token_type_ids = jnp.zeros_like(input_ids)
@ -632,10 +639,20 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
return self.module.init(
random_params = self.module.init(
rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
)["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def __call__(
self,

View File

@ -21,8 +21,9 @@ import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from ...modeling_flax_outputs import (
@ -1420,6 +1421,7 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel):
input_shape: Optional[tuple] = None,
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
@ -1428,9 +1430,9 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel):
elif input_shape is None:
input_shape = (1, 1)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
token_type_ids = jnp.zeros_like(input_ids)
@ -1441,10 +1443,20 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel):
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
return self.module.init(
random_params = self.module.init(
rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
)["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
@add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def __call__(
self,
@ -1897,13 +1909,14 @@ class FlaxBigBirdForMultipleChoice(FlaxBigBirdPreTrainedModel):
input_shape: Optional[tuple] = None,
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs
):
if config.attention_type == "block_sparse" and input_shape is None:
input_shape = (1, 1, 12 * config.block_size)
elif input_shape is None:
input_shape = (1, 1)
super().__init__(config, input_shape=input_shape, seed=seed, dtype=dtype)
super().__init__(config, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
overwrite_call_docstring(

View File

@ -24,9 +24,10 @@ import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from jax.random import PRNGKey
@ -887,12 +888,13 @@ class FlaxBlenderbotPreTrainedModel(FlaxPreTrainedModel):
input_shape: Tuple[int] = (1, 1),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
# make sure initialization pass will work for FlaxBlenderbotForSequenceClassificationModule
@ -908,7 +910,7 @@ class FlaxBlenderbotPreTrainedModel(FlaxPreTrainedModel):
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
return self.module.init(
random_params = self.module.init(
rngs,
input_ids,
attention_mask,
@ -918,6 +920,16 @@ class FlaxBlenderbotPreTrainedModel(FlaxPreTrainedModel):
decoder_position_ids,
)["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
def init_cache(self, batch_size, max_length, encoder_outputs):
r"""
Args:

View File

@ -25,9 +25,10 @@ import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from jax.random import PRNGKey
@ -885,12 +886,13 @@ class FlaxBlenderbotSmallPreTrainedModel(FlaxPreTrainedModel):
input_shape: Tuple[int] = (1, 1),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
# make sure initialization pass will work for FlaxBlenderbotSmallForSequenceClassificationModule
@ -906,7 +908,7 @@ class FlaxBlenderbotSmallPreTrainedModel(FlaxPreTrainedModel):
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
return self.module.init(
random_params = self.module.init(
rngs,
input_ids,
attention_mask,
@ -916,6 +918,16 @@ class FlaxBlenderbotSmallPreTrainedModel(FlaxPreTrainedModel):
decoder_position_ids,
)["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
def init_cache(self, batch_size, max_length, encoder_outputs):
r"""
Args:

View File

@ -19,9 +19,10 @@ import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPooling
@ -585,12 +586,18 @@ class FlaxCLIPTextPreTrainedModel(FlaxPreTrainedModel):
module_class: nn.Module = None
def __init__(
self, config: CLIPTextConfig, input_shape=(1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
self,
config: CLIPTextConfig,
input_shape=(1, 1),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensor
input_ids = jnp.zeros(input_shape, dtype="i4")
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
@ -599,7 +606,17 @@ class FlaxCLIPTextPreTrainedModel(FlaxPreTrainedModel):
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
return self.module.init(rngs, input_ids, attention_mask, position_ids)["params"]
random_params = self.module.init(rngs, input_ids, attention_mask, position_ids)["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
def __call__(
self,
@ -654,21 +671,32 @@ class FlaxCLIPVisionPreTrainedModel(FlaxPreTrainedModel):
input_shape: Optional[Tuple] = None,
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs
):
if input_shape is None:
input_shape = (1, config.image_size, config.image_size, 3)
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensor
pixel_values = jax.random.normal(rng, input_shape)
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
return self.module.init(rngs, pixel_values)["params"]
random_params = self.module.init(rngs, pixel_values)["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
def __call__(
self,
@ -714,14 +742,15 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel):
input_shape: Optional[Tuple] = None,
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs
):
if input_shape is None:
input_shape = ((1, 1), (1, config.vision_config.image_size, config.vision_config.image_size, 3))
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensor
input_ids = jnp.zeros(input_shape[0], dtype="i4")
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape[0])
@ -732,7 +761,17 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel):
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
return self.module.init(rngs, input_ids, pixel_values, attention_mask, position_ids)["params"]
random_params = self.module.init(rngs, input_ids, pixel_values, attention_mask, position_ids)["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
def __call__(
self,

View File

@ -21,7 +21,8 @@ import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from ...modeling_flax_outputs import (
@ -428,12 +429,13 @@ class FlaxDistilBertPreTrainedModel(FlaxPreTrainedModel):
input_shape: Tuple = (1, 1),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
attention_mask = jnp.ones_like(input_ids)
@ -441,7 +443,17 @@ class FlaxDistilBertPreTrainedModel(FlaxPreTrainedModel):
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
return self.module.init(rngs, input_ids, attention_mask, return_dict=False)["params"]
random_params = self.module.init(rngs, input_ids, attention_mask, return_dict=False)["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
@add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def __call__(

View File

@ -21,8 +21,9 @@ import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from jax.random import PRNGKey
@ -541,12 +542,13 @@ class FlaxElectraPreTrainedModel(FlaxPreTrainedModel):
input_shape: Tuple = (1, 1),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
token_type_ids = jnp.zeros_like(input_ids)
@ -557,10 +559,20 @@ class FlaxElectraPreTrainedModel(FlaxPreTrainedModel):
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
return self.module.init(
random_params = self.module.init(
rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
)["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
@add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def __call__(
self,

View File

@ -21,7 +21,8 @@ from typing import Optional, Tuple, Union
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from jax.random import PRNGKey
@ -315,11 +316,17 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
input_shape: Optional[Tuple] = None,
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs
):
if input_shape is None:
input_shape = ((1, 1), (1, 1))
if not _do_init:
raise ValueError(
"`FlaxEncoderDecoderModel` cannot be created without initializing, `_do_init` must be `True`."
)
if config.decoder.cross_attention_hidden_size is not None:
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
raise ValueError(
@ -330,9 +337,9 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
)
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
encoder_input_shape, decoder_input_shape = input_shape
# init input tensors
@ -356,7 +363,7 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
return self.module.init(
random_params = self.module.init(
rngs,
input_ids,
attention_mask,
@ -366,6 +373,16 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
decoder_position_ids,
)["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
def init_cache(self, batch_size, max_length, encoder_outputs):
r"""
Args:

View File

@ -18,9 +18,10 @@ from typing import Any, Optional, Tuple
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from ...modeling_flax_outputs import (
@ -394,12 +395,13 @@ class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel):
input_shape: Tuple = (1, 1),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs,
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
attention_mask = jnp.ones_like(input_ids)
@ -422,7 +424,17 @@ class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel):
else:
module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)
return module_init_outputs["params"]
random_params = module_init_outputs["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
def init_cache(self, batch_size, max_length):
r"""

View File

@ -19,9 +19,10 @@ from typing import Optional, Tuple
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
@ -353,12 +354,13 @@ class FlaxGPTNeoPreTrainedModel(FlaxPreTrainedModel):
input_shape: Tuple = (1, 1),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs,
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
attention_mask = jnp.ones_like(input_ids)
@ -366,7 +368,17 @@ class FlaxGPTNeoPreTrainedModel(FlaxPreTrainedModel):
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
return self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)["params"]
random_params = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
def init_cache(self, batch_size, max_length):
r"""

View File

@ -21,9 +21,10 @@ import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
@ -373,12 +374,13 @@ class FlaxGPTJPreTrainedModel(FlaxPreTrainedModel):
input_shape: Tuple = (1, 1),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs,
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
attention_mask = jnp.ones_like(input_ids)
@ -401,7 +403,17 @@ class FlaxGPTJPreTrainedModel(FlaxPreTrainedModel):
else:
module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)
return module_init_outputs["params"]
random_params = module_init_outputs["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
def init_cache(self, batch_size, max_length):
r"""

View File

@ -24,9 +24,10 @@ import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from jax.random import PRNGKey
@ -882,12 +883,13 @@ class FlaxMarianPreTrainedModel(FlaxPreTrainedModel):
input_shape: Tuple[int] = (1, 1),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
# make sure initialization pass will work for FlaxMarianForSequenceClassificationModule
@ -903,7 +905,7 @@ class FlaxMarianPreTrainedModel(FlaxPreTrainedModel):
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
return self.module.init(
random_params = self.module.init(
rngs,
input_ids,
attention_mask,
@ -913,6 +915,16 @@ class FlaxMarianPreTrainedModel(FlaxPreTrainedModel):
decoder_position_ids,
)["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
def init_cache(self, batch_size, max_length, encoder_outputs):
r"""
Args:

View File

@ -24,9 +24,10 @@ import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from jax.random import PRNGKey
@ -951,12 +952,13 @@ class FlaxMBartPreTrainedModel(FlaxPreTrainedModel):
input_shape: Tuple[int] = (1, 1),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
# make sure initialization pass will work for FlaxMBartForSequenceClassificationModule
@ -972,7 +974,7 @@ class FlaxMBartPreTrainedModel(FlaxPreTrainedModel):
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
return self.module.init(
random_params = self.module.init(
rngs,
input_ids,
attention_mask,
@ -982,6 +984,16 @@ class FlaxMBartPreTrainedModel(FlaxPreTrainedModel):
decoder_position_ids,
)["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartPreTrainedModel.init_cache with Bart->MBart
def init_cache(self, batch_size, max_length, encoder_outputs):
r"""

View File

@ -25,9 +25,10 @@ import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from jax.random import PRNGKey
@ -901,12 +902,13 @@ class FlaxPegasusPreTrainedModel(FlaxPreTrainedModel):
input_shape: Tuple[int] = (1, 1),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
attention_mask = jnp.ones_like(input_ids)
@ -920,7 +922,7 @@ class FlaxPegasusPreTrainedModel(FlaxPreTrainedModel):
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
return self.module.init(
random_params = self.module.init(
rngs,
input_ids,
attention_mask,
@ -930,6 +932,16 @@ class FlaxPegasusPreTrainedModel(FlaxPreTrainedModel):
decoder_position_ids,
)["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
def init_cache(self, batch_size, max_length, encoder_outputs):
r"""
Args:

View File

@ -19,8 +19,9 @@ import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from jax.random import PRNGKey
@ -585,12 +586,13 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
input_shape: Tuple = (1, 1),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
token_type_ids = jnp.ones_like(input_ids)
@ -601,10 +603,20 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
return self.module.init(
random_params = self.module.init(
rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
)["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def __call__(
self,

View File

@ -21,8 +21,9 @@ import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from ...modeling_flax_outputs import (
@ -621,12 +622,13 @@ class FlaxRoFormerPreTrainedModel(FlaxPreTrainedModel):
input_shape: Tuple = (1, 1),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
token_type_ids = jnp.zeros_like(input_ids)
@ -636,9 +638,19 @@ class FlaxRoFormerPreTrainedModel(FlaxPreTrainedModel):
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
return self.module.init(rngs, input_ids, attention_mask, token_type_ids, head_mask, return_dict=False)[
"params"
]
random_params = self.module.init(
rngs, input_ids, attention_mask, token_type_ids, head_mask, return_dict=False
)["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
@add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def __call__(

View File

@ -20,7 +20,8 @@ from typing import Optional, Tuple, Union
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from jax.random import PRNGKey
@ -343,8 +344,15 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
input_shape: Optional[Tuple] = None,
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs
):
if not _do_init:
raise ValueError(
"`FlaxSpeechEncoderDecoderModel` cannot be created without initializing, `_do_init` must be `True`."
)
if config.decoder.cross_attention_hidden_size is not None:
# Raise ValueError or option to project enc to dec hidden_size (eg EncAdapterLayer)
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
@ -365,9 +373,9 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
decoder_input_length = module._get_feat_extract_output_lengths(encoder_input_length)
input_shape = ((1, encoder_input_length), (1, decoder_input_length))
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
encoder_input_shape, decoder_input_shape = input_shape
# init input DeviceArrays
@ -390,7 +398,7 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
return self.module.init(
random_params = self.module.init(
rngs,
inputs,
attention_mask,
@ -399,6 +407,16 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
decoder_position_ids,
)["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
def init_cache(self, batch_size, max_length, encoder_outputs):
r"""
Args:

View File

@ -23,9 +23,10 @@ import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax.random import PRNGKey
from ...modeling_flax_outputs import (
@ -919,12 +920,13 @@ class FlaxT5PreTrainedModel(FlaxPreTrainedModel):
input_shape: Tuple[int] = (1, 1),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
@ -935,7 +937,7 @@ class FlaxT5PreTrainedModel(FlaxPreTrainedModel):
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
return self.module.init(
random_params = self.module.init(
rngs,
input_ids,
attention_mask,
@ -943,6 +945,16 @@ class FlaxT5PreTrainedModel(FlaxPreTrainedModel):
decoder_attention_mask,
)["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
@add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
def __call__(
self,

View File

@ -21,7 +21,8 @@ from typing import Optional, Tuple, Union
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from jax.random import PRNGKey
@ -282,8 +283,14 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel):
input_shape: Optional[Tuple] = None,
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs
):
if not _do_init:
raise ValueError(
"`FlaxVisionEncoderDecoderModel` cannot be created without initializing, `_do_init` must be `True`."
)
if input_shape is None:
num_channels = getattr(config.encoder, "num_channels", 3)
input_shape = (
@ -301,9 +308,9 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel):
)
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
encoder_input_shape, decoder_input_shape = input_shape
# init input tensors
@ -325,7 +332,7 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel):
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
return self.module.init(
random_params = self.module.init(
rngs,
pixel_values,
decoder_input_ids,
@ -333,6 +340,16 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel):
decoder_position_ids,
)["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
def init_cache(self, batch_size, max_length, encoder_outputs):
r"""
Args:

View File

@ -20,7 +20,8 @@ from typing import Optional, Tuple
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.traverse_util import flatten_dict, unflatten_dict
from ...modeling_flax_utils import FlaxPreTrainedModel, append_replace_return_docstrings, overwrite_call_docstring
from ...utils import add_start_docstrings, logging
@ -225,15 +226,22 @@ class FlaxVisionTextDualEncoderModel(FlaxPreTrainedModel):
input_shape: Optional[Tuple] = None,
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs
):
if not _do_init:
raise ValueError(
"`FlaxVisionTextDualEncoderModel` cannot be created without initializing, `_do_init` must be `True`."
)
if input_shape is None:
input_shape = ((1, 1), (1, config.vision_config.image_size, config.vision_config.image_size, 3))
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensor
input_ids = jnp.zeros(input_shape[0], dtype="i4")
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape[0])
@ -245,7 +253,19 @@ class FlaxVisionTextDualEncoderModel(FlaxPreTrainedModel):
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
return self.module.init(rngs, input_ids, pixel_values, attention_mask, position_ids, token_type_ids)["params"]
random_params = self.module.init(rngs, input_ids, pixel_values, attention_mask, position_ids, token_type_ids)[
"params"
]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
def __call__(
self,

View File

@ -18,8 +18,9 @@ from typing import Optional, Tuple
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPooling, FlaxSequenceClassifierOutput
from ...modeling_flax_utils import (
@ -407,20 +408,38 @@ class FlaxViTPreTrainedModel(FlaxPreTrainedModel):
main_input_name = "pixel_values"
module_class: nn.Module = None
def __init__(self, config: ViTConfig, input_shape=None, seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs):
def __init__(
self,
config: ViTConfig,
input_shape=None,
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
if input_shape is None:
input_shape = (1, config.image_size, config.image_size, 3)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
pixel_values = jnp.zeros(input_shape, dtype=self.dtype)
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
return self.module.init(rngs, pixel_values, return_dict=False)["params"]
random_params = self.module.init(rngs, pixel_values, return_dict=False)["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
@add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def __call__(

View File

@ -23,8 +23,9 @@ import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
@ -858,19 +859,30 @@ class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel):
input_shape: Tuple = (1, 1024),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs,
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
input_values = jnp.zeros(input_shape, dtype="i4")
attention_mask = jnp.ones_like(input_values)
params_rng, dropout_rng = jax.random.split(rng, 2)
rngs = {"params": params_rng, "dropout": dropout_rng}
return self.module.init(rngs, input_values, attention_mask, return_dict=False)["params"]
random_params = self.module.init(rngs, input_values, attention_mask, return_dict=False)["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
def __call__(

View File

@ -25,9 +25,10 @@ import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from jax.random import PRNGKey
@ -561,12 +562,13 @@ class FlaxXGLMPreTrainedModel(FlaxPreTrainedModel):
input_shape: Tuple[int] = (1, 1),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
attention_mask = jnp.ones_like(input_ids)
@ -589,7 +591,17 @@ class FlaxXGLMPreTrainedModel(FlaxPreTrainedModel):
else:
module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)
return module_init_outputs["params"]
random_params = module_init_outputs["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
def init_cache(self, batch_size, max_length):
r"""

View File

@ -23,7 +23,8 @@ import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from flax.core.frozen_dict import FrozenDict, unfreeze, freeze
from flax.traverse_util import flatten_dict, unflatten_dict
from flax.linen.attention import dot_product_attention_weights
from jax import lax
@ -586,12 +587,18 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
module_class: nn.Module = None
def __init__(
self, config: {{cookiecutter.camelcase_modelname}}Config, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
self,
config: {{cookiecutter.camelcase_modelname}}Config,
input_shape: Tuple = (1, 1),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
token_type_ids = jnp.zeros_like(input_ids)
@ -602,10 +609,20 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
return self.module.init(
random_params = self.module.init(
rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
)["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def __call__(
self,
@ -1130,9 +1147,10 @@ from typing import Callable, Optional, Tuple
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.core.frozen_dict import FrozenDict, unfreeze, freeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from jax.random import PRNGKey
@ -2031,12 +2049,13 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
input_shape: Tuple[int] = (1, 1),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
# make sure initialization pass will work for Flax{{cookiecutter.camelcase_modelname}}ForSequenceClassificationModule
@ -2052,7 +2071,7 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
return self.module.init(
random_params = self.module.init(
rngs,
input_ids,
attention_mask,
@ -2062,6 +2081,16 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
decoder_position_ids,
)["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
def init_cache(self, batch_size, max_length, encoder_outputs):
r"""
Args:

View File

@ -43,7 +43,7 @@ if is_flax_available():
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import unfreeze
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.traverse_util import flatten_dict, unflatten_dict
from transformers import (
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
@ -904,6 +904,93 @@ class FlaxModelTesterMixin:
else:
_check_attentions_validity(outputs.attentions)
def test_no_automatic_init(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
for model_class in self.all_model_classes:
model = model_class(config, _do_init=False)
# Check that accesing parmas raises an ValueError when _do_init is False
with self.assertRaises(ValueError):
params = model.params
# Check if we params can be properly initialized when calling init_weights
params = model.init_weights(model.key, model.input_shape)
self.assertIsInstance(params, FrozenDict)
# Check if all required parmas are initialized
keys = set(flatten_dict(unfreeze(params)).keys())
self.assertTrue(all(k in keys for k in model.required_params))
# Check if the shapes match
flat_params = flatten_dict(unfreeze(params))
for k, v in flatten_dict(unfreeze(model.params_shape_tree)).items():
self.assertEqual(
v.shape,
flat_params[k].shape,
"Shapes of {} do not match. Expecting {}, got {}.".format(k, v.shape, flat_params[k].shape),
)
# Check that setting params raises an ValueError when _do_init is False
with self.assertRaises(ValueError):
model.params = params
# Check if we can do a forward pass
inputs_dict["output_hidden_states"] = True
inputs = self._prepare_for_class(inputs_dict, model_class).copy()
model(**inputs, params=params)
def test_from_pretrained_with_no_automatic_init(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
def _assert_all_params_initialised(model, params):
# Check if all required parmas are loaded
keys = set(flatten_dict(unfreeze(params)).keys())
self.assertTrue(all(k in keys for k in model.required_params))
# Check if the shapes match
flat_params = flatten_dict(unfreeze(params))
for k, v in flatten_dict(unfreeze(model.params_shape_tree)).items():
self.assertEqual(
v.shape,
flat_params[k].shape,
"Shapes of {} do not match. Expecting {}, got {}.".format(k, v.shape, flat_params[k].shape),
)
for model_class in self.all_model_classes:
# init the model
model = model_class(config)
# save the model in the temporary directory
# load the saved model with _do_init=False
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model, params = model_class.from_pretrained(tmpdirname, _do_init=False)
# Check that accesing parmas raises an ValueError when _do_init is False
with self.assertRaises(ValueError):
params = model.params
# Check if all required parmas are loaded
_assert_all_params_initialised(model, params)
# Check that setting params raises an ValueError when _do_init is False
with self.assertRaises(ValueError):
model.params = params
# Check if init_weights initializes missing keys from from_pretrained
flat_params = flatten_dict(unfreeze(params))
random_key = random.choice(list(flat_params.keys()))
flat_params.pop(random_key)
params = freeze(unflatten_dict(flat_params))
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname, params=params)
model, params = model_class.from_pretrained(tmpdirname, _do_init=False)
params = model.init_weights(model.key, model.input_shape, params=params)
# Check if all required parmas are loaded
_assert_all_params_initialised(model, params)
@require_flax
@is_staging_test