pass decay_mask fn to optimizer (#12087)
This commit is contained in:
parent
d472bd7b18
commit
d1500d9151
|
@ -38,7 +38,7 @@ import flax
|
|||
import jax
|
||||
import jax.numpy as jnp
|
||||
import optax
|
||||
from flax import jax_utils
|
||||
from flax import jax_utils, traverse_util
|
||||
from flax.training import train_state
|
||||
from flax.training.common_utils import get_metrics, onehot, shard
|
||||
from transformers import (
|
||||
|
@ -504,6 +504,15 @@ if __name__ == "__main__":
|
|||
schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
|
||||
)
|
||||
|
||||
# We use Optax's "masking" functionality to not apply weight decay
|
||||
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
|
||||
# mask boolean with the same structure as the parameters.
|
||||
# The mask is True for parameters that should be decayed.
|
||||
def decay_mask_fn(params):
|
||||
flat_params = traverse_util.flatten_dict(params)
|
||||
flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
|
||||
return traverse_util.unflatten_dict(flat_mask)
|
||||
|
||||
# create adam optimizer
|
||||
adamw = optax.adamw(
|
||||
learning_rate=linear_decay_lr_schedule_fn,
|
||||
|
@ -511,6 +520,7 @@ if __name__ == "__main__":
|
|||
b2=training_args.adam_beta2,
|
||||
eps=1e-8,
|
||||
weight_decay=training_args.weight_decay,
|
||||
mask=decay_mask_fn,
|
||||
)
|
||||
|
||||
# Setup train state
|
||||
|
|
Loading…
Reference in New Issue