[Flax] improve large model init and loading (#16148)
* begin do_init
* add params_shape_tree
* raise error if params are accessed when do_init is False
* don't allow do_init=False when keys are missing
* make shape tree a property
* assign self._params at the end
* add test for do_init
* add do_init arg to all flax models
* fix param setting
* disable do_init for composite models
* update test
* add do_init in FlaxBigBirdForMultipleChoice
* better names and errors
* improve test
* style
* add a warning when do_init=False
* remove extra if
* set params after _required_params
* add test for from_pretrained
* do_init => _do_init
* change warning to info
* fix typo
* add params in init_weights
* add params to gpt neo init
* add params to init_weights
* update do_init test
* Trigger CI
* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* update template
* trigger CI
* style
* style
* fix template

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 6de4ee61a0
commit d3bd9ac728
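For orientation, the diff below threads a private `_do_init` flag through `FlaxPreTrainedModel` and every Flax model class so that parameters are only materialized when requested. A minimal sketch of the constructor-side behaviour, assuming `FlaxBertModel` and `BertConfig` purely as an example of a model touched by this commit:

```python
import jax
from transformers import BertConfig, FlaxBertModel

config = BertConfig()

# Default: weights are randomly initialized inside __init__.
model = FlaxBertModel(config)  # model.params is populated

# _do_init=False: only jax.eval_shape runs, so no weight memory is allocated.
lazy_model = FlaxBertModel(config, _do_init=False)
print(lazy_model.params_shape_tree)  # shapes/dtypes only, no arrays

# Weights must be created explicitly and kept outside the model;
# accessing lazy_model.params would raise a ValueError.
params = lazy_model.init_weights(jax.random.PRNGKey(0), lazy_model.input_shape)
```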
@@ -140,7 +140,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
         module = self.module_class(config=config, dtype=dtype, **kwargs)
         super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
         # init input tensor
         input_ids = jnp.zeros(input_shape[0], dtype="i4")
         position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape[0])
@@ -90,6 +90,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
     base_model_prefix = ""
     main_input_name = "input_ids"
     _auto_class = None
+    _missing_keys = set()

     def __init__(
         self,
@@ -98,6 +99,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
         input_shape: Tuple = (1, 1),
         seed: int = 0,
         dtype: jnp.dtype = jnp.float32,
+        _do_init: bool = True,
     ):
         if config is None:
             raise ValueError("config cannot be None")
@@ -112,15 +114,35 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
         # Those are public as their type is generic to every derived classes.
         self.key = PRNGKey(seed)
         self.dtype = dtype
+        self.input_shape = input_shape

-        # randomly initialized parameters
-        random_params = self.init_weights(self.key, input_shape)
+        # To check if the model was intialized automatically.
+        self._is_initialized = _do_init
+
+        if _do_init:
+            # randomly initialized parameters
+            random_params = self.init_weights(self.key, input_shape)
+            params_shape_tree = jax.eval_shape(lambda params: params, random_params)
+        else:
+            init_fn = partial(self.init_weights, input_shape=input_shape)
+            params_shape_tree = jax.eval_shape(init_fn, self.key)
+
+            logger.info(
+                "Model weights are not initialized as `_do_init` is set to `False`. "
+                f"Make sure to call `{self.__class__.__name__}.init_weights` manually to initialize the weights."
+            )
+
+        # get the shape of the parameters
+        self._params_shape_tree = params_shape_tree

         # save required_params as set
-        self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
-        self.params = random_params
+        self._required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())

-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> Dict:
+        # initialize the parameters
+        if _do_init:
+            self.params = random_params
+
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> Dict:
         raise NotImplementedError(f"init method has to be implemented for {self}")

     @classmethod
@@ -147,14 +169,31 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):

     @property
     def params(self) -> Union[Dict, FrozenDict]:
+        if not self._is_initialized:
+            raise ValueError(
+                "`params` cannot be accessed from model when the model is created with `_do_init=False`. "
+                "You must call `init_weights` manually and store the params outside of the model and "
+                "pass it explicitly where needed."
+            )
         return self._params

     @property
     def required_params(self) -> Set:
         return self._required_params

+    @property
+    def params_shape_tree(self) -> Dict:
+        return self._params_shape_tree
+
     @params.setter
     def params(self, params: Union[Dict, FrozenDict]):
+        # don't set params if the model is not initialized
+        if not self._is_initialized:
+            raise ValueError(
+                "`params` cannot be set from model when the model is created with `_do_init=False`. "
+                "You store the params outside of the model."
+            )
+
         if isinstance(params, FrozenDict):
             params = unfreeze(params)
         param_keys = set(flatten_dict(params).keys())
@@ -417,6 +456,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
         revision = kwargs.pop("revision", None)
         from_pipeline = kwargs.pop("_from_pipeline", None)
         from_auto_class = kwargs.pop("_from_auto", False)
+        _do_init = kwargs.pop("_do_init", True)

         user_agent = {"file_type": "model", "framework": "flax", "from_auto_class": from_auto_class}
         if from_pipeline is not None:
@@ -553,7 +593,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
             resolved_archive_file = None

         # init random models
-        model = cls(config, *model_args, **model_kwargs)
+        model = cls(config, *model_args, _do_init=_do_init, **model_kwargs)

         if from_pt:
             state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file)
@@ -577,25 +617,36 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
         # make sure all arrays are stored as jnp.arrays
         # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
         # https://github.com/google/flax/issues/1261
-        state = jax.tree_util.tree_map(jnp.array, state)
+        if _do_init:
+            state = jax.tree_util.tree_map(jnp.array, state)
+        else:
+            # keep the params on CPU if we don't want to initialize
+            state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state)

         # if model is base model only use model_prefix key
-        if cls.base_model_prefix not in dict(model.params) and cls.base_model_prefix in state:
+        if cls.base_model_prefix not in dict(model.params_shape_tree) and cls.base_model_prefix in state:
             state = state[cls.base_model_prefix]

         # if model is head model and we are loading weights from base model
         # we initialize new params dict with base_model_prefix
-        if cls.base_model_prefix in dict(model.params) and cls.base_model_prefix not in state:
+        if cls.base_model_prefix in dict(model.params_shape_tree) and cls.base_model_prefix not in state:
             state = {cls.base_model_prefix: state}

         # flatten dicts
         state = flatten_dict(state)

-        random_state = flatten_dict(unfreeze(model.params))
+        random_state = flatten_dict(unfreeze(model.params if _do_init else model.params_shape_tree))

         missing_keys = model.required_params - set(state.keys())
         unexpected_keys = set(state.keys()) - model.required_params

+        if missing_keys and not _do_init:
+            logger.warn(
+                f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. "
+                f"Make sure to call model.init_weights to initialize the missing weights."
+            )
+            cls._missing_keys = missing_keys
+
         # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
         # matching the weights in the model.
         mismatched_keys = []
@@ -612,9 +663,10 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
                 "model."
             )

-        # add missing keys as random parameters
-        for missing_key in missing_keys:
-            state[missing_key] = random_state[missing_key]
+        # add missing keys as random parameters if we are initializing
+        if missing_keys and _do_init:
+            for missing_key in missing_keys:
+                state[missing_key] = random_state[missing_key]

         # remove unexpected keys to not be saved again
         for unexpected_key in unexpected_keys:
@@ -680,10 +732,12 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
                 "See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this."
             )

-        # set correct parameters
-        model.params = unflatten_dict(state)
-
-        return model
+        if _do_init:
+            # set correct parameters
+            model.params = unflatten_dict(state)
+            return model
+        else:
+            return model, unflatten_dict(state)

     def save_pretrained(self, save_directory: Union[str, os.PathLike], params=None, push_to_hub=False, **kwargs):
         """
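Taken together, the base-class changes above let `from_pretrained` defer weight initialization entirely. A short sketch of the new loading path, using `FlaxBertModel` (whose `init_weights` is updated further down); the checkpoint name is only an example:

```python
from transformers import FlaxBertModel

# With _do_init=False, from_pretrained skips the random init pass, keeps the
# loaded weights on CPU, and returns them next to the model instead of
# assigning model.params.
model, params = FlaxBertModel.from_pretrained("bert-base-cased", _do_init=False)

# Only the shape tree lives on the model; accessing model.params would raise.
print(model.params_shape_tree)

# If the checkpoint lacked required keys, they were recorded in
# cls._missing_keys; init_weights(params=...) fills them with random values.
params = model.init_weights(model.key, model.input_shape, params)
```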
@@ -21,8 +21,9 @@ import flax
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
-from flax.core.frozen_dict import FrozenDict
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
 from flax.linen.attention import dot_product_attention_weights
+from flax.traverse_util import flatten_dict, unflatten_dict
 from jax import lax

 from ...modeling_flax_outputs import (
@@ -522,12 +523,13 @@ class FlaxAlbertPreTrainedModel(FlaxPreTrainedModel):
         input_shape: Tuple = (1, 1),
         seed: int = 0,
         dtype: jnp.dtype = jnp.float32,
+        _do_init: bool = True,
         **kwargs
     ):
         module = self.module_class(config=config, dtype=dtype, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
         # init input tensors
         input_ids = jnp.zeros(input_shape, dtype="i4")
         token_type_ids = jnp.zeros_like(input_ids)
@@ -537,9 +539,19 @@ class FlaxAlbertPreTrainedModel(FlaxPreTrainedModel):
         params_rng, dropout_rng = jax.random.split(rng)
         rngs = {"params": params_rng, "dropout": dropout_rng}

-        return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids, return_dict=False)[
-            "params"
-        ]
+        random_params = self.module.init(
+            rngs, input_ids, attention_mask, token_type_ids, position_ids, return_dict=False
+        )["params"]
+
+        if params is not None:
+            random_params = flatten_dict(unfreeze(random_params))
+            params = flatten_dict(unfreeze(params))
+            for missing_key in self._missing_keys:
+                params[missing_key] = random_params[missing_key]
+            self._missing_keys = set()
+            return freeze(unflatten_dict(params))
+        else:
+            return random_params

     @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     def __call__(
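The same three edits are repeated in each model file that follows: a `_do_init` constructor argument forwarded to `super().__init__`, a `params` argument on `init_weights`, and a merge of checkpoint params with freshly initialized values for `self._missing_keys`. In isolation, that merge step amounts to the sketch below; `merge_with_random` is a hypothetical helper used only for illustration, not part of the commit:

```python
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.traverse_util import flatten_dict, unflatten_dict


def merge_with_random(params: FrozenDict, random_params: FrozenDict, missing_keys: set) -> FrozenDict:
    """Keys absent from the checkpoint are filled from the fresh random init."""
    params = flatten_dict(unfreeze(params))
    random_params = flatten_dict(unfreeze(random_params))
    for missing_key in missing_keys:
        params[missing_key] = random_params[missing_key]
    return freeze(unflatten_dict(params))
```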
@ -24,9 +24,10 @@ import numpy as np
|
|||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict, unfreeze
|
||||
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||
from flax.linen import combine_masks, make_causal_mask
|
||||
from flax.linen.attention import dot_product_attention_weights
|
||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||
from jax import lax
|
||||
from jax.random import PRNGKey
|
||||
|
||||
|
@ -912,12 +913,13 @@ class FlaxBartPreTrainedModel(FlaxPreTrainedModel):
|
|||
input_shape: Tuple[int] = (1, 1),
|
||||
seed: int = 0,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
_do_init: bool = True,
|
||||
**kwargs
|
||||
):
|
||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||
# init input tensors
|
||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||
# make sure initialization pass will work for FlaxBartForSequenceClassificationModule
|
||||
|
@ -933,7 +935,7 @@ class FlaxBartPreTrainedModel(FlaxPreTrainedModel):
|
|||
params_rng, dropout_rng = jax.random.split(rng)
|
||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||
|
||||
return self.module.init(
|
||||
random_params = self.module.init(
|
||||
rngs,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
|
@ -943,6 +945,16 @@ class FlaxBartPreTrainedModel(FlaxPreTrainedModel):
|
|||
decoder_position_ids,
|
||||
)["params"]
|
||||
|
||||
if params is not None:
|
||||
random_params = flatten_dict(unfreeze(random_params))
|
||||
params = flatten_dict(unfreeze(params))
|
||||
for missing_key in self._missing_keys:
|
||||
params[missing_key] = random_params[missing_key]
|
||||
self._missing_keys = set()
|
||||
return freeze(unflatten_dict(params))
|
||||
else:
|
||||
return random_params
|
||||
|
||||
def init_cache(self, batch_size, max_length, encoder_outputs):
|
||||
r"""
|
||||
Args:
|
||||
|
@ -1737,14 +1749,15 @@ class FlaxBartDecoderPreTrainedModel(FlaxPreTrainedModel):
|
|||
input_shape: Tuple[int] = (1, 1),
|
||||
seed: int = 0,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
_do_init: bool = True,
|
||||
**kwargs
|
||||
):
|
||||
config.is_decoder = True
|
||||
config.is_encoder_decoder = False
|
||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||
# init input tensors
|
||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||
attention_mask = jnp.ones_like(input_ids)
|
||||
|
|
|
@ -22,8 +22,9 @@ import flax
|
|||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict
|
||||
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||
from flax.linen.attention import dot_product_attention_weights
|
||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||
|
||||
from ...modeling_flax_outputs import (
|
||||
FlaxBaseModelOutput,
|
||||
|
@ -591,13 +592,21 @@ class FlaxBeitPreTrainedModel(FlaxPreTrainedModel):
|
|||
main_input_name = "pixel_values"
|
||||
module_class: nn.Module = None
|
||||
|
||||
def __init__(self, config: BeitConfig, input_shape=None, seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
config: BeitConfig,
|
||||
input_shape=None,
|
||||
seed: int = 0,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
_do_init: bool = True,
|
||||
**kwargs
|
||||
):
|
||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||
if input_shape is None:
|
||||
input_shape = (1, config.image_size, config.image_size, 3)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||
# init input tensors
|
||||
pixel_values = jnp.zeros(input_shape, dtype=self.dtype)
|
||||
|
||||
|
@ -605,7 +614,17 @@ class FlaxBeitPreTrainedModel(FlaxPreTrainedModel):
|
|||
dropout_rng, droppath_rng = jax.random.split(dropout_rng)
|
||||
rngs = {"params": params_rng, "dropout": dropout_rng, "droppath": droppath_rng}
|
||||
|
||||
return self.module.init(rngs, pixel_values, return_dict=False)["params"]
|
||||
random_params = self.module.init(rngs, pixel_values, return_dict=False)["params"]
|
||||
|
||||
if params is not None:
|
||||
random_params = flatten_dict(unfreeze(random_params))
|
||||
params = flatten_dict(unfreeze(params))
|
||||
for missing_key in self._missing_keys:
|
||||
params[missing_key] = random_params[missing_key]
|
||||
self._missing_keys = set()
|
||||
return freeze(unflatten_dict(params))
|
||||
else:
|
||||
return random_params
|
||||
|
||||
@add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
def __call__(
|
||||
|
|
|
@ -21,8 +21,9 @@ import flax
|
|||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict
|
||||
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||
from flax.linen.attention import dot_product_attention_weights
|
||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||
from jax import lax
|
||||
|
||||
from ...modeling_flax_outputs import (
|
||||
|
@ -616,12 +617,18 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
|
|||
module_class: nn.Module = None
|
||||
|
||||
def __init__(
|
||||
self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
|
||||
self,
|
||||
config: BertConfig,
|
||||
input_shape: Tuple = (1, 1),
|
||||
seed: int = 0,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
_do_init: bool = True,
|
||||
**kwargs
|
||||
):
|
||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||
# init input tensors
|
||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||
token_type_ids = jnp.zeros_like(input_ids)
|
||||
|
@ -632,10 +639,20 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
|
|||
params_rng, dropout_rng = jax.random.split(rng)
|
||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||
|
||||
return self.module.init(
|
||||
random_params = self.module.init(
|
||||
rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
|
||||
)["params"]
|
||||
|
||||
if params is not None:
|
||||
random_params = flatten_dict(unfreeze(random_params))
|
||||
params = flatten_dict(unfreeze(params))
|
||||
for missing_key in self._missing_keys:
|
||||
params[missing_key] = random_params[missing_key]
|
||||
self._missing_keys = set()
|
||||
return freeze(unflatten_dict(params))
|
||||
else:
|
||||
return random_params
|
||||
|
||||
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
def __call__(
|
||||
self,
|
||||
|
|
|
@ -21,8 +21,9 @@ import flax
|
|||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict
|
||||
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||
from flax.linen.attention import dot_product_attention_weights
|
||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||
from jax import lax
|
||||
|
||||
from ...modeling_flax_outputs import (
|
||||
|
@ -1420,6 +1421,7 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel):
|
|||
input_shape: Optional[tuple] = None,
|
||||
seed: int = 0,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
_do_init: bool = True,
|
||||
**kwargs
|
||||
):
|
||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||
|
@ -1428,9 +1430,9 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel):
|
|||
elif input_shape is None:
|
||||
input_shape = (1, 1)
|
||||
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||
# init input tensors
|
||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||
token_type_ids = jnp.zeros_like(input_ids)
|
||||
|
@ -1441,10 +1443,20 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel):
|
|||
params_rng, dropout_rng = jax.random.split(rng)
|
||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||
|
||||
return self.module.init(
|
||||
random_params = self.module.init(
|
||||
rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
|
||||
)["params"]
|
||||
|
||||
if params is not None:
|
||||
random_params = flatten_dict(unfreeze(random_params))
|
||||
params = flatten_dict(unfreeze(params))
|
||||
for missing_key in self._missing_keys:
|
||||
params[missing_key] = random_params[missing_key]
|
||||
self._missing_keys = set()
|
||||
return freeze(unflatten_dict(params))
|
||||
else:
|
||||
return random_params
|
||||
|
||||
@add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
def __call__(
|
||||
self,
|
||||
|
@ -1897,13 +1909,14 @@ class FlaxBigBirdForMultipleChoice(FlaxBigBirdPreTrainedModel):
|
|||
input_shape: Optional[tuple] = None,
|
||||
seed: int = 0,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
_do_init: bool = True,
|
||||
**kwargs
|
||||
):
|
||||
if config.attention_type == "block_sparse" and input_shape is None:
|
||||
input_shape = (1, 1, 12 * config.block_size)
|
||||
elif input_shape is None:
|
||||
input_shape = (1, 1)
|
||||
super().__init__(config, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||
super().__init__(config, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||
|
||||
|
||||
overwrite_call_docstring(
|
||||
|
|
|
@ -24,9 +24,10 @@ import numpy as np
|
|||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict, unfreeze
|
||||
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||
from flax.linen import combine_masks, make_causal_mask
|
||||
from flax.linen.attention import dot_product_attention_weights
|
||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||
from jax import lax
|
||||
from jax.random import PRNGKey
|
||||
|
||||
|
@ -887,12 +888,13 @@ class FlaxBlenderbotPreTrainedModel(FlaxPreTrainedModel):
|
|||
input_shape: Tuple[int] = (1, 1),
|
||||
seed: int = 0,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
_do_init: bool = True,
|
||||
**kwargs
|
||||
):
|
||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||
# init input tensors
|
||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||
# make sure initialization pass will work for FlaxBlenderbotForSequenceClassificationModule
|
||||
|
@ -908,7 +910,7 @@ class FlaxBlenderbotPreTrainedModel(FlaxPreTrainedModel):
|
|||
params_rng, dropout_rng = jax.random.split(rng)
|
||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||
|
||||
return self.module.init(
|
||||
random_params = self.module.init(
|
||||
rngs,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
|
@ -918,6 +920,16 @@ class FlaxBlenderbotPreTrainedModel(FlaxPreTrainedModel):
|
|||
decoder_position_ids,
|
||||
)["params"]
|
||||
|
||||
if params is not None:
|
||||
random_params = flatten_dict(unfreeze(random_params))
|
||||
params = flatten_dict(unfreeze(params))
|
||||
for missing_key in self._missing_keys:
|
||||
params[missing_key] = random_params[missing_key]
|
||||
self._missing_keys = set()
|
||||
return freeze(unflatten_dict(params))
|
||||
else:
|
||||
return random_params
|
||||
|
||||
def init_cache(self, batch_size, max_length, encoder_outputs):
|
||||
r"""
|
||||
Args:
|
||||
|
|
|
@ -25,9 +25,10 @@ import numpy as np
|
|||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict, unfreeze
|
||||
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||
from flax.linen import combine_masks, make_causal_mask
|
||||
from flax.linen.attention import dot_product_attention_weights
|
||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||
from jax import lax
|
||||
from jax.random import PRNGKey
|
||||
|
||||
|
@ -885,12 +886,13 @@ class FlaxBlenderbotSmallPreTrainedModel(FlaxPreTrainedModel):
|
|||
input_shape: Tuple[int] = (1, 1),
|
||||
seed: int = 0,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
_do_init: bool = True,
|
||||
**kwargs
|
||||
):
|
||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||
# init input tensors
|
||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||
# make sure initialization pass will work for FlaxBlenderbotSmallForSequenceClassificationModule
|
||||
|
@ -906,7 +908,7 @@ class FlaxBlenderbotSmallPreTrainedModel(FlaxPreTrainedModel):
|
|||
params_rng, dropout_rng = jax.random.split(rng)
|
||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||
|
||||
return self.module.init(
|
||||
random_params = self.module.init(
|
||||
rngs,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
|
@ -916,6 +918,16 @@ class FlaxBlenderbotSmallPreTrainedModel(FlaxPreTrainedModel):
|
|||
decoder_position_ids,
|
||||
)["params"]
|
||||
|
||||
if params is not None:
|
||||
random_params = flatten_dict(unfreeze(random_params))
|
||||
params = flatten_dict(unfreeze(params))
|
||||
for missing_key in self._missing_keys:
|
||||
params[missing_key] = random_params[missing_key]
|
||||
self._missing_keys = set()
|
||||
return freeze(unflatten_dict(params))
|
||||
else:
|
||||
return random_params
|
||||
|
||||
def init_cache(self, batch_size, max_length, encoder_outputs):
|
||||
r"""
|
||||
Args:
|
||||
|
|
|
@ -19,9 +19,10 @@ import flax
|
|||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict
|
||||
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||
from flax.linen import combine_masks, make_causal_mask
|
||||
from flax.linen.attention import dot_product_attention_weights
|
||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||
from jax import lax
|
||||
|
||||
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPooling
|
||||
|
@ -585,12 +586,18 @@ class FlaxCLIPTextPreTrainedModel(FlaxPreTrainedModel):
|
|||
module_class: nn.Module = None
|
||||
|
||||
def __init__(
|
||||
self, config: CLIPTextConfig, input_shape=(1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
|
||||
self,
|
||||
config: CLIPTextConfig,
|
||||
input_shape=(1, 1),
|
||||
seed: int = 0,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
_do_init: bool = True,
|
||||
**kwargs
|
||||
):
|
||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||
# init input tensor
|
||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
|
||||
|
@ -599,7 +606,17 @@ class FlaxCLIPTextPreTrainedModel(FlaxPreTrainedModel):
|
|||
params_rng, dropout_rng = jax.random.split(rng)
|
||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||
|
||||
return self.module.init(rngs, input_ids, attention_mask, position_ids)["params"]
|
||||
random_params = self.module.init(rngs, input_ids, attention_mask, position_ids)["params"]
|
||||
|
||||
if params is not None:
|
||||
random_params = flatten_dict(unfreeze(random_params))
|
||||
params = flatten_dict(unfreeze(params))
|
||||
for missing_key in self._missing_keys:
|
||||
params[missing_key] = random_params[missing_key]
|
||||
self._missing_keys = set()
|
||||
return freeze(unflatten_dict(params))
|
||||
else:
|
||||
return random_params
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
|
@ -654,21 +671,32 @@ class FlaxCLIPVisionPreTrainedModel(FlaxPreTrainedModel):
|
|||
input_shape: Optional[Tuple] = None,
|
||||
seed: int = 0,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
_do_init: bool = True,
|
||||
**kwargs
|
||||
):
|
||||
if input_shape is None:
|
||||
input_shape = (1, config.image_size, config.image_size, 3)
|
||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||
# init input tensor
|
||||
pixel_values = jax.random.normal(rng, input_shape)
|
||||
|
||||
params_rng, dropout_rng = jax.random.split(rng)
|
||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||
|
||||
return self.module.init(rngs, pixel_values)["params"]
|
||||
random_params = self.module.init(rngs, pixel_values)["params"]
|
||||
|
||||
if params is not None:
|
||||
random_params = flatten_dict(unfreeze(random_params))
|
||||
params = flatten_dict(unfreeze(params))
|
||||
for missing_key in self._missing_keys:
|
||||
params[missing_key] = random_params[missing_key]
|
||||
self._missing_keys = set()
|
||||
return freeze(unflatten_dict(params))
|
||||
else:
|
||||
return random_params
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
|
@ -714,14 +742,15 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel):
|
|||
input_shape: Optional[Tuple] = None,
|
||||
seed: int = 0,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
_do_init: bool = True,
|
||||
**kwargs
|
||||
):
|
||||
if input_shape is None:
|
||||
input_shape = ((1, 1), (1, config.vision_config.image_size, config.vision_config.image_size, 3))
|
||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||
# init input tensor
|
||||
input_ids = jnp.zeros(input_shape[0], dtype="i4")
|
||||
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape[0])
|
||||
|
@ -732,7 +761,17 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel):
|
|||
params_rng, dropout_rng = jax.random.split(rng)
|
||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||
|
||||
return self.module.init(rngs, input_ids, pixel_values, attention_mask, position_ids)["params"]
|
||||
random_params = self.module.init(rngs, input_ids, pixel_values, attention_mask, position_ids)["params"]
|
||||
|
||||
if params is not None:
|
||||
random_params = flatten_dict(unfreeze(random_params))
|
||||
params = flatten_dict(unfreeze(params))
|
||||
for missing_key in self._missing_keys:
|
||||
params[missing_key] = random_params[missing_key]
|
||||
self._missing_keys = set()
|
||||
return freeze(unflatten_dict(params))
|
||||
else:
|
||||
return random_params
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
|
|
|
@ -21,7 +21,8 @@ import numpy as np
|
|||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict
|
||||
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||
from jax import lax
|
||||
|
||||
from ...modeling_flax_outputs import (
|
||||
|
@ -428,12 +429,13 @@ class FlaxDistilBertPreTrainedModel(FlaxPreTrainedModel):
|
|||
input_shape: Tuple = (1, 1),
|
||||
seed: int = 0,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
_do_init: bool = True,
|
||||
**kwargs
|
||||
):
|
||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||
# init input tensors
|
||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||
attention_mask = jnp.ones_like(input_ids)
|
||||
|
@ -441,7 +443,17 @@ class FlaxDistilBertPreTrainedModel(FlaxPreTrainedModel):
|
|||
params_rng, dropout_rng = jax.random.split(rng)
|
||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||
|
||||
return self.module.init(rngs, input_ids, attention_mask, return_dict=False)["params"]
|
||||
random_params = self.module.init(rngs, input_ids, attention_mask, return_dict=False)["params"]
|
||||
|
||||
if params is not None:
|
||||
random_params = flatten_dict(unfreeze(random_params))
|
||||
params = flatten_dict(unfreeze(params))
|
||||
for missing_key in self._missing_keys:
|
||||
params[missing_key] = random_params[missing_key]
|
||||
self._missing_keys = set()
|
||||
return freeze(unflatten_dict(params))
|
||||
else:
|
||||
return random_params
|
||||
|
||||
@add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
def __call__(
|
||||
|
|
|
@ -21,8 +21,9 @@ import flax
|
|||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict
|
||||
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||
from flax.linen.attention import dot_product_attention_weights
|
||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||
from jax import lax
|
||||
from jax.random import PRNGKey
|
||||
|
||||
|
@ -541,12 +542,13 @@ class FlaxElectraPreTrainedModel(FlaxPreTrainedModel):
|
|||
input_shape: Tuple = (1, 1),
|
||||
seed: int = 0,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
_do_init: bool = True,
|
||||
**kwargs
|
||||
):
|
||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||
# init input tensors
|
||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||
token_type_ids = jnp.zeros_like(input_ids)
|
||||
|
@ -557,10 +559,20 @@ class FlaxElectraPreTrainedModel(FlaxPreTrainedModel):
|
|||
params_rng, dropout_rng = jax.random.split(rng)
|
||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||
|
||||
return self.module.init(
|
||||
random_params = self.module.init(
|
||||
rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
|
||||
)["params"]
|
||||
|
||||
if params is not None:
|
||||
random_params = flatten_dict(unfreeze(random_params))
|
||||
params = flatten_dict(unfreeze(params))
|
||||
for missing_key in self._missing_keys:
|
||||
params[missing_key] = random_params[missing_key]
|
||||
self._missing_keys = set()
|
||||
return freeze(unflatten_dict(params))
|
||||
else:
|
||||
return random_params
|
||||
|
||||
@add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
def __call__(
|
||||
self,
|
||||
|
|
|
@ -21,7 +21,8 @@ from typing import Optional, Tuple, Union
|
|||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict, unfreeze
|
||||
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||
from jax import lax
|
||||
from jax.random import PRNGKey
|
||||
|
||||
|
@ -315,11 +316,17 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
|
|||
input_shape: Optional[Tuple] = None,
|
||||
seed: int = 0,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
_do_init: bool = True,
|
||||
**kwargs
|
||||
):
|
||||
if input_shape is None:
|
||||
input_shape = ((1, 1), (1, 1))
|
||||
|
||||
if not _do_init:
|
||||
raise ValueError(
|
||||
"`FlaxEncoderDecoderModel` cannot be created without initializing, `_do_init` must be `True`."
|
||||
)
|
||||
|
||||
if config.decoder.cross_attention_hidden_size is not None:
|
||||
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
|
||||
raise ValueError(
|
||||
|
@ -330,9 +337,9 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
|
|||
)
|
||||
|
||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||
encoder_input_shape, decoder_input_shape = input_shape
|
||||
|
||||
# init input tensors
|
||||
|
@ -356,7 +363,7 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
|
|||
params_rng, dropout_rng = jax.random.split(rng)
|
||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||
|
||||
return self.module.init(
|
||||
random_params = self.module.init(
|
||||
rngs,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
|
@ -366,6 +373,16 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
|
|||
decoder_position_ids,
|
||||
)["params"]
|
||||
|
||||
if params is not None:
|
||||
random_params = flatten_dict(unfreeze(random_params))
|
||||
params = flatten_dict(unfreeze(params))
|
||||
for missing_key in self._missing_keys:
|
||||
params[missing_key] = random_params[missing_key]
|
||||
self._missing_keys = set()
|
||||
return freeze(unflatten_dict(params))
|
||||
else:
|
||||
return random_params
|
||||
|
||||
def init_cache(self, batch_size, max_length, encoder_outputs):
|
||||
r"""
|
||||
Args:
|
||||
|
|
|
@ -18,9 +18,10 @@ from typing import Any, Optional, Tuple
|
|||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict, unfreeze
|
||||
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||
from flax.linen import combine_masks, make_causal_mask
|
||||
from flax.linen.attention import dot_product_attention_weights
|
||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||
from jax import lax
|
||||
|
||||
from ...modeling_flax_outputs import (
|
||||
|
@ -394,12 +395,13 @@ class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel):
|
|||
input_shape: Tuple = (1, 1),
|
||||
seed: int = 0,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
_do_init: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||
# init input tensors
|
||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||
attention_mask = jnp.ones_like(input_ids)
|
||||
|
@ -422,7 +424,17 @@ class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel):
|
|||
else:
|
||||
module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)
|
||||
|
||||
return module_init_outputs["params"]
|
||||
random_params = module_init_outputs["params"]
|
||||
|
||||
if params is not None:
|
||||
random_params = flatten_dict(unfreeze(random_params))
|
||||
params = flatten_dict(unfreeze(params))
|
||||
for missing_key in self._missing_keys:
|
||||
params[missing_key] = random_params[missing_key]
|
||||
self._missing_keys = set()
|
||||
return freeze(unflatten_dict(params))
|
||||
else:
|
||||
return random_params
|
||||
|
||||
def init_cache(self, batch_size, max_length):
|
||||
r"""
|
||||
|
|
|
@ -19,9 +19,10 @@ from typing import Optional, Tuple
|
|||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict, unfreeze
|
||||
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||
from flax.linen import combine_masks, make_causal_mask
|
||||
from flax.linen.attention import dot_product_attention_weights
|
||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||
from jax import lax
|
||||
|
||||
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
|
||||
|
@ -353,12 +354,13 @@ class FlaxGPTNeoPreTrainedModel(FlaxPreTrainedModel):
|
|||
input_shape: Tuple = (1, 1),
|
||||
seed: int = 0,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
_do_init: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||
# init input tensors
|
||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||
attention_mask = jnp.ones_like(input_ids)
|
||||
|
@ -366,7 +368,17 @@ class FlaxGPTNeoPreTrainedModel(FlaxPreTrainedModel):
|
|||
params_rng, dropout_rng = jax.random.split(rng)
|
||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||
|
||||
return self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)["params"]
|
||||
random_params = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)["params"]
|
||||
|
||||
if params is not None:
|
||||
random_params = flatten_dict(unfreeze(random_params))
|
||||
params = flatten_dict(unfreeze(params))
|
||||
for missing_key in self._missing_keys:
|
||||
params[missing_key] = random_params[missing_key]
|
||||
self._missing_keys = set()
|
||||
return freeze(unflatten_dict(params))
|
||||
else:
|
||||
return random_params
|
||||
|
||||
def init_cache(self, batch_size, max_length):
|
||||
r"""
|
||||
|
|
|
@ -21,9 +21,10 @@ import numpy as np
|
|||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict, unfreeze
|
||||
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||
from flax.linen import combine_masks, make_causal_mask
|
||||
from flax.linen.attention import dot_product_attention_weights
|
||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||
from jax import lax
|
||||
|
||||
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
|
||||
|
@ -373,12 +374,13 @@ class FlaxGPTJPreTrainedModel(FlaxPreTrainedModel):
|
|||
input_shape: Tuple = (1, 1),
|
||||
seed: int = 0,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
_do_init: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||
# init input tensors
|
||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||
attention_mask = jnp.ones_like(input_ids)
|
||||
|
@ -401,7 +403,17 @@ class FlaxGPTJPreTrainedModel(FlaxPreTrainedModel):
|
|||
else:
|
||||
module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)
|
||||
|
||||
return module_init_outputs["params"]
|
||||
random_params = module_init_outputs["params"]
|
||||
|
||||
if params is not None:
|
||||
random_params = flatten_dict(unfreeze(random_params))
|
||||
params = flatten_dict(unfreeze(params))
|
||||
for missing_key in self._missing_keys:
|
||||
params[missing_key] = random_params[missing_key]
|
||||
self._missing_keys = set()
|
||||
return freeze(unflatten_dict(params))
|
||||
else:
|
||||
return random_params
|
||||
|
||||
def init_cache(self, batch_size, max_length):
|
||||
r"""
|
||||
|
|
|
@ -24,9 +24,10 @@ import numpy as np
|
|||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict, unfreeze
|
||||
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||
from flax.linen import combine_masks, make_causal_mask
|
||||
from flax.linen.attention import dot_product_attention_weights
|
||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||
from jax import lax
|
||||
from jax.random import PRNGKey
|
||||
|
||||
|
@ -882,12 +883,13 @@ class FlaxMarianPreTrainedModel(FlaxPreTrainedModel):
|
|||
input_shape: Tuple[int] = (1, 1),
|
||||
seed: int = 0,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
_do_init: bool = True,
|
||||
**kwargs
|
||||
):
|
||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||
# init input tensors
|
||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||
# make sure initialization pass will work for FlaxMarianForSequenceClassificationModule
|
||||
|
@ -903,7 +905,7 @@ class FlaxMarianPreTrainedModel(FlaxPreTrainedModel):
|
|||
params_rng, dropout_rng = jax.random.split(rng)
|
||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||
|
||||
return self.module.init(
|
||||
random_params = self.module.init(
|
||||
rngs,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
|
@ -913,6 +915,16 @@ class FlaxMarianPreTrainedModel(FlaxPreTrainedModel):
|
|||
decoder_position_ids,
|
||||
)["params"]
|
||||
|
||||
if params is not None:
|
||||
random_params = flatten_dict(unfreeze(random_params))
|
||||
params = flatten_dict(unfreeze(params))
|
||||
for missing_key in self._missing_keys:
|
||||
params[missing_key] = random_params[missing_key]
|
||||
self._missing_keys = set()
|
||||
return freeze(unflatten_dict(params))
|
||||
else:
|
||||
return random_params
|
||||
|
||||
def init_cache(self, batch_size, max_length, encoder_outputs):
|
||||
r"""
|
||||
Args:
|
||||
|
|
|
@ -24,9 +24,10 @@ import numpy as np
|
|||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict, unfreeze
|
||||
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||
from flax.linen import combine_masks, make_causal_mask
|
||||
from flax.linen.attention import dot_product_attention_weights
|
||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||
from jax import lax
|
||||
from jax.random import PRNGKey
|
||||
|
||||
|
@ -951,12 +952,13 @@ class FlaxMBartPreTrainedModel(FlaxPreTrainedModel):
|
|||
input_shape: Tuple[int] = (1, 1),
|
||||
seed: int = 0,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
_do_init: bool = True,
|
||||
**kwargs
|
||||
):
|
||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||
# init input tensors
|
||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||
# make sure initialization pass will work for FlaxMBartForSequenceClassificationModule
|
||||
|
@ -972,7 +974,7 @@ class FlaxMBartPreTrainedModel(FlaxPreTrainedModel):
|
|||
params_rng, dropout_rng = jax.random.split(rng)
|
||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||
|
||||
return self.module.init(
|
||||
random_params = self.module.init(
|
||||
rngs,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
|
@ -982,6 +984,16 @@ class FlaxMBartPreTrainedModel(FlaxPreTrainedModel):
|
|||
decoder_position_ids,
|
||||
)["params"]
|
||||
|
||||
if params is not None:
|
||||
random_params = flatten_dict(unfreeze(random_params))
|
||||
params = flatten_dict(unfreeze(params))
|
||||
for missing_key in self._missing_keys:
|
||||
params[missing_key] = random_params[missing_key]
|
||||
self._missing_keys = set()
|
||||
return freeze(unflatten_dict(params))
|
||||
else:
|
||||
return random_params
|
||||
|
||||
# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartPreTrainedModel.init_cache with Bart->MBart
|
||||
def init_cache(self, batch_size, max_length, encoder_outputs):
|
||||
r"""
|
||||
|
|
|
@ -25,9 +25,10 @@ import numpy as np
|
|||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict, unfreeze
|
||||
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||
from flax.linen import combine_masks, make_causal_mask
|
||||
from flax.linen.attention import dot_product_attention_weights
|
||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||
from jax import lax
|
||||
from jax.random import PRNGKey
|
||||
|
||||
|
@ -901,12 +902,13 @@ class FlaxPegasusPreTrainedModel(FlaxPreTrainedModel):
|
|||
input_shape: Tuple[int] = (1, 1),
|
||||
seed: int = 0,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
_do_init: bool = True,
|
||||
**kwargs
|
||||
):
|
||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||
# init input tensors
|
||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||
attention_mask = jnp.ones_like(input_ids)
|
||||
|
@ -920,7 +922,7 @@ class FlaxPegasusPreTrainedModel(FlaxPreTrainedModel):
|
|||
params_rng, dropout_rng = jax.random.split(rng)
|
||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||
|
||||
return self.module.init(
|
||||
random_params = self.module.init(
|
||||
rngs,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
|
@ -930,6 +932,16 @@ class FlaxPegasusPreTrainedModel(FlaxPreTrainedModel):
|
|||
decoder_position_ids,
|
||||
)["params"]
|
||||
|
||||
if params is not None:
|
||||
random_params = flatten_dict(unfreeze(random_params))
|
||||
params = flatten_dict(unfreeze(params))
|
||||
for missing_key in self._missing_keys:
|
||||
params[missing_key] = random_params[missing_key]
|
||||
self._missing_keys = set()
|
||||
return freeze(unflatten_dict(params))
|
||||
else:
|
||||
return random_params
|
||||
|
||||
def init_cache(self, batch_size, max_length, encoder_outputs):
|
||||
r"""
|
||||
Args:
|
||||
|
|
|
@@ -19,8 +19,9 @@ import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from jax.random import PRNGKey

@@ -585,12 +586,13 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        token_type_ids = jnp.ones_like(input_ids)

@@ -601,10 +603,20 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.module.init(
        random_params = self.module.init(
            rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
        )["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def __call__(
        self,
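With the uniform `_do_init` / `init_weights(..., params=...)` signature in place, a model can be built without materialising its weights, and the parameters can be handled entirely outside the instance. A minimal sketch, assuming a deliberately tiny and purely illustrative RobertaConfig:

import jax.numpy as jnp

from transformers import FlaxRobertaModel, RobertaConfig

# Tiny illustrative config so the example stays cheap to run.
config = RobertaConfig(
    vocab_size=100, hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64
)

# Only the parameter shapes are traced here; no weight arrays are allocated.
model = FlaxRobertaModel(config, _do_init=False)

# Initialise the parameters explicitly and keep them outside the model object.
params = model.init_weights(model.key, model.input_shape)

input_ids = jnp.ones((1, 8), dtype="i4")
outputs = model(input_ids, params=params)
print(outputs.last_hidden_state.shape)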
@@ -21,8 +21,9 @@ import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax

from ...modeling_flax_outputs import (

@@ -621,12 +622,13 @@ class FlaxRoFormerPreTrainedModel(FlaxPreTrainedModel):
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        token_type_ids = jnp.zeros_like(input_ids)

@@ -636,9 +638,19 @@ class FlaxRoFormerPreTrainedModel(FlaxPreTrainedModel):
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.module.init(rngs, input_ids, attention_mask, token_type_ids, head_mask, return_dict=False)[
            "params"
        ]
        random_params = self.module.init(
            rngs, input_ids, attention_mask, token_type_ids, head_mask, return_dict=False
        )["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    @add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def __call__(
@@ -20,7 +20,8 @@ from typing import Optional, Tuple, Union
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from jax.random import PRNGKey

@@ -343,8 +344,15 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
        input_shape: Optional[Tuple] = None,
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs
    ):

        if not _do_init:
            raise ValueError(
                "`FlaxSpeechEncoderDecoderModel` cannot be created without initializing, `_do_init` must be `True`."
            )

        if config.decoder.cross_attention_hidden_size is not None:
            # Raise ValueError or option to project enc to dec hidden_size (eg EncAdapterLayer)
            if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:

@@ -365,9 +373,9 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
            decoder_input_length = module._get_feat_extract_output_lengths(encoder_input_length)
            input_shape = ((1, encoder_input_length), (1, decoder_input_length))

        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        encoder_input_shape, decoder_input_shape = input_shape

        # init input DeviceArrays

@@ -390,7 +398,7 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.module.init(
        random_params = self.module.init(
            rngs,
            inputs,
            attention_mask,

@@ -399,6 +407,16 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
            decoder_position_ids,
        )["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    def init_cache(self, batch_size, max_length, encoder_outputs):
        r"""
        Args:
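Composite models (FlaxSpeechEncoderDecoderModel here, and the vision and dual-encoder variants further down) keep requiring eager initialisation, so `_do_init=False` fails fast with the ValueError added above. A short sketch of that guard, assuming a Wav2Vec2 encoder / GPT-2 decoder pairing and that `SpeechEncoderDecoderConfig.from_encoder_decoder_configs` is available for building the combined config:

from transformers import (
    FlaxSpeechEncoderDecoderModel,
    GPT2Config,
    SpeechEncoderDecoderConfig,
    Wav2Vec2Config,
)

# Illustrative pairing; the guard fires before any weights would be built.
config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(Wav2Vec2Config(), GPT2Config())

try:
    FlaxSpeechEncoderDecoderModel(config, _do_init=False)
except ValueError as err:
    print(err)  # `_do_init` must be `True` for this model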
@@ -23,9 +23,10 @@ import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax.random import PRNGKey

from ...modeling_flax_outputs import (

@@ -919,12 +920,13 @@ class FlaxT5PreTrainedModel(FlaxPreTrainedModel):
        input_shape: Tuple[int] = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")

@@ -935,7 +937,7 @@ class FlaxT5PreTrainedModel(FlaxPreTrainedModel):
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.module.init(
        random_params = self.module.init(
            rngs,
            input_ids,
            attention_mask,

@@ -943,6 +945,16 @@ class FlaxT5PreTrainedModel(FlaxPreTrainedModel):
            decoder_attention_mask,
        )["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
    def __call__(
        self,
@@ -21,7 +21,8 @@ from typing import Optional, Tuple, Union
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from jax.random import PRNGKey

@@ -282,8 +283,14 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel):
        input_shape: Optional[Tuple] = None,
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs
    ):
        if not _do_init:
            raise ValueError(
                "`FlaxVisionEncoderDecoderModel` cannot be created without initializing, `_do_init` must be `True`."
            )

        if input_shape is None:
            num_channels = getattr(config.encoder, "num_channels", 3)
            input_shape = (

@@ -301,9 +308,9 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel):
            )

        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        encoder_input_shape, decoder_input_shape = input_shape

        # init input tensors

@@ -325,7 +332,7 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel):
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.module.init(
        random_params = self.module.init(
            rngs,
            pixel_values,
            decoder_input_ids,

@@ -333,6 +340,16 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel):
            decoder_position_ids,
        )["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    def init_cache(self, batch_size, max_length, encoder_outputs):
        r"""
        Args:
@@ -20,7 +20,8 @@ from typing import Optional, Tuple
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.traverse_util import flatten_dict, unflatten_dict

from ...modeling_flax_utils import FlaxPreTrainedModel, append_replace_return_docstrings, overwrite_call_docstring
from ...utils import add_start_docstrings, logging

@@ -225,15 +226,22 @@ class FlaxVisionTextDualEncoderModel(FlaxPreTrainedModel):
        input_shape: Optional[Tuple] = None,
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs
    ):

        if not _do_init:
            raise ValueError(
                "`FlaxVisionTextDualEncoderModel` cannot be created without initializing, `_do_init` must be `True`."
            )

        if input_shape is None:
            input_shape = ((1, 1), (1, config.vision_config.image_size, config.vision_config.image_size, 3))

        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensor
        input_ids = jnp.zeros(input_shape[0], dtype="i4")
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape[0])

@@ -245,7 +253,19 @@ class FlaxVisionTextDualEncoderModel(FlaxPreTrainedModel):
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.module.init(rngs, input_ids, pixel_values, attention_mask, position_ids, token_type_ids)["params"]
        random_params = self.module.init(rngs, input_ids, pixel_values, attention_mask, position_ids, token_type_ids)[
            "params"
        ]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    def __call__(
        self,
@@ -18,8 +18,9 @@ from typing import Optional, Tuple
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict

from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPooling, FlaxSequenceClassifierOutput
from ...modeling_flax_utils import (

@@ -407,20 +408,38 @@ class FlaxViTPreTrainedModel(FlaxPreTrainedModel):
    main_input_name = "pixel_values"
    module_class: nn.Module = None

    def __init__(self, config: ViTConfig, input_shape=None, seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs):
    def __init__(
        self,
        config: ViTConfig,
        input_shape=None,
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        if input_shape is None:
            input_shape = (1, config.image_size, config.image_size, 3)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        pixel_values = jnp.zeros(input_shape, dtype=self.dtype)

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.module.init(rngs, pixel_values, return_dict=False)["params"]
        random_params = self.module.init(rngs, pixel_values, return_dict=False)["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def __call__(
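Even with `_do_init=False`, the abstract shape tree stays available through `params_shape_tree` (exercised by the tests near the end of this diff), which makes it possible to inspect a large model without allocating any parameters. A minimal sketch, assuming the default ViTConfig:

from flax.core.frozen_dict import unfreeze
from flax.traverse_util import flatten_dict

from transformers import FlaxViTModel, ViTConfig

# `_do_init=False` only traces parameter shapes; no arrays are allocated.
model = FlaxViTModel(ViTConfig(), _do_init=False)

for path, leaf in flatten_dict(unfreeze(model.params_shape_tree)).items():
    print("/".join(path), leaf.shape, leaf.dtype)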
@@ -23,8 +23,9 @@ import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax

from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput

@@ -858,19 +859,30 @@ class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel):
        input_shape: Tuple = (1, 1024),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_values = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_values)
        params_rng, dropout_rng = jax.random.split(rng, 2)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.module.init(rngs, input_values, attention_mask, return_dict=False)["params"]
        random_params = self.module.init(rngs, input_values, attention_mask, return_dict=False)["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
    def __call__(
@@ -25,9 +25,10 @@ import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from jax.random import PRNGKey

@@ -561,12 +562,13 @@ class FlaxXGLMPreTrainedModel(FlaxPreTrainedModel):
        input_shape: Tuple[int] = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids)

@@ -589,7 +591,17 @@ class FlaxXGLMPreTrainedModel(FlaxPreTrainedModel):
        else:
            module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)

        return module_init_outputs["params"]
        random_params = module_init_outputs["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    def init_cache(self, batch_size, max_length):
        r"""
@@ -23,7 +23,8 @@ import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from flax.core.frozen_dict import FrozenDict, unfreeze, freeze
from flax.traverse_util import flatten_dict, unflatten_dict
from flax.linen.attention import dot_product_attention_weights
from jax import lax

@@ -586,12 +587,18 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
    module_class: nn.Module = None

    def __init__(
        self, config: {{cookiecutter.camelcase_modelname}}Config, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
        self,
        config: {{cookiecutter.camelcase_modelname}}Config,
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        token_type_ids = jnp.zeros_like(input_ids)

@@ -602,10 +609,20 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.module.init(
        random_params = self.module.init(
            rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
        )["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def __call__(
        self,
@@ -1130,9 +1147,10 @@ from typing import Callable, Optional, Tuple
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.core.frozen_dict import FrozenDict, unfreeze, freeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from jax.random import PRNGKey

@@ -2031,12 +2049,13 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
        input_shape: Tuple[int] = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        # make sure initialization pass will work for Flax{{cookiecutter.camelcase_modelname}}ForSequenceClassificationModule

@@ -2052,7 +2071,7 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.module.init(
        random_params = self.module.init(
            rngs,
            input_ids,
            attention_mask,

@@ -2062,6 +2081,16 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
            decoder_position_ids,
        )["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    def init_cache(self, batch_size, max_length, encoder_outputs):
        r"""
        Args:
@@ -43,7 +43,7 @@ if is_flax_available():

    import jax
    import jax.numpy as jnp
    from flax.core.frozen_dict import unfreeze
    from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
    from flax.traverse_util import flatten_dict, unflatten_dict
    from transformers import (
        FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,

@@ -904,6 +904,93 @@ class FlaxModelTesterMixin:
        else:
            _check_attentions_validity(outputs.attentions)

    def test_no_automatic_init(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.return_dict = True

        for model_class in self.all_model_classes:
            model = model_class(config, _do_init=False)

            # Check that accessing params raises a ValueError when _do_init is False
            with self.assertRaises(ValueError):
                params = model.params

            # Check if params can be properly initialized when calling init_weights
            params = model.init_weights(model.key, model.input_shape)
            self.assertIsInstance(params, FrozenDict)
            # Check if all required params are initialized
            keys = set(flatten_dict(unfreeze(params)).keys())
            self.assertTrue(all(k in keys for k in model.required_params))
            # Check if the shapes match
            flat_params = flatten_dict(unfreeze(params))
            for k, v in flatten_dict(unfreeze(model.params_shape_tree)).items():
                self.assertEqual(
                    v.shape,
                    flat_params[k].shape,
                    "Shapes of {} do not match. Expecting {}, got {}.".format(k, v.shape, flat_params[k].shape),
                )

            # Check that setting params raises a ValueError when _do_init is False
            with self.assertRaises(ValueError):
                model.params = params

            # Check if we can do a forward pass
            inputs_dict["output_hidden_states"] = True
            inputs = self._prepare_for_class(inputs_dict, model_class).copy()
            model(**inputs, params=params)

    def test_from_pretrained_with_no_automatic_init(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.return_dict = True

        def _assert_all_params_initialised(model, params):
            # Check if all required params are loaded
            keys = set(flatten_dict(unfreeze(params)).keys())
            self.assertTrue(all(k in keys for k in model.required_params))
            # Check if the shapes match
            flat_params = flatten_dict(unfreeze(params))
            for k, v in flatten_dict(unfreeze(model.params_shape_tree)).items():
                self.assertEqual(
                    v.shape,
                    flat_params[k].shape,
                    "Shapes of {} do not match. Expecting {}, got {}.".format(k, v.shape, flat_params[k].shape),
                )

        for model_class in self.all_model_classes:
            # init the model
            model = model_class(config)

            # save the model in the temporary directory
            # load the saved model with _do_init=False
            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model, params = model_class.from_pretrained(tmpdirname, _do_init=False)

            # Check that accessing params raises a ValueError when _do_init is False
            with self.assertRaises(ValueError):
                params = model.params

            # Check if all required params are loaded
            _assert_all_params_initialised(model, params)

            # Check that setting params raises a ValueError when _do_init is False
            with self.assertRaises(ValueError):
                model.params = params

            # Check if init_weights initializes missing keys from from_pretrained
            flat_params = flatten_dict(unfreeze(params))
            random_key = random.choice(list(flat_params.keys()))
            flat_params.pop(random_key)
            params = freeze(unflatten_dict(flat_params))

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname, params=params)
                model, params = model_class.from_pretrained(tmpdirname, _do_init=False)

            params = model.init_weights(model.key, model.input_shape, params=params)
            # Check if all required params are loaded
            _assert_all_params_initialised(model, params)


@require_flax
@is_staging_test
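Taken together, the two tests describe the intended workflow: `from_pretrained(..., _do_init=False)` returns a `(model, params)` tuple instead of storing the weights on the instance, and `init_weights(..., params=...)` back-fills whatever the checkpoint is missing. A condensed sketch of that round trip, reusing the tiny illustrative RobertaConfig from the earlier sketch and a temporary directory in place of a real checkpoint:

import tempfile

import jax.numpy as jnp

from transformers import FlaxRobertaModel, RobertaConfig

config = RobertaConfig(
    vocab_size=100, hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64
)
model = FlaxRobertaModel(config)  # regular eager init; weights live on the instance

with tempfile.TemporaryDirectory() as tmpdirname:
    model.save_pretrained(tmpdirname)

    # Lazy load: the weights come back as a separate pytree, not as model.params.
    model, params = FlaxRobertaModel.from_pretrained(tmpdirname, _do_init=False)

    # Fill any keys the checkpoint did not contain with freshly initialised values.
    params = model.init_weights(model.key, model.input_shape, params=params)

    input_ids = jnp.ones((1, 8), dtype="i4")
    outputs = model(input_ids, params=params)
    print(outputs.last_hidden_state.shape)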