pass decay_mask fn to optimizer (#12087)

Suraj Patil 2021-06-09 23:19:27 +05:30 committed by GitHub
parent d472bd7b18
commit d1500d9151
1 changed file with 11 additions and 1 deletion


@@ -38,7 +38,7 @@ import flax
 import jax
 import jax.numpy as jnp
 import optax
-from flax import jax_utils
+from flax import jax_utils, traverse_util
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard
 from transformers import (
@@ -504,6 +504,15 @@ if __name__ == "__main__":
         schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
     )
 
+    # We use Optax's "masking" functionality to not apply weight decay
+    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
+    # mask boolean with the same structure as the parameters.
+    # The mask is True for parameters that should be decayed.
+    def decay_mask_fn(params):
+        flat_params = traverse_util.flatten_dict(params)
+        flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
+        return traverse_util.unflatten_dict(flat_mask)
+
     # create adam optimizer
     adamw = optax.adamw(
         learning_rate=linear_decay_lr_schedule_fn,
@@ -511,6 +520,7 @@
         b2=training_args.adam_beta2,
         eps=1e-8,
         weight_decay=training_args.weight_decay,
+        mask=decay_mask_fn,
     )
 
     # Setup train state
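
Outside the diff, here is a minimal sketch of what decay_mask_fn returns and how the callable is wired into optax.adamw via its mask argument. The toy parameter tree and the hyperparameter values are illustrative, not taken from training_args.

    # Minimal sketch, assuming a toy two-layer parameter tree.
    import optax
    from flax import traverse_util

    def decay_mask_fn(params):
        # Flatten the nested params dict into {path-tuple: leaf} pairs.
        flat_params = traverse_util.flatten_dict(params)
        # Decay everything except biases and LayerNorm scales.
        flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
        return traverse_util.unflatten_dict(flat_mask)

    toy_params = {
        "dense": {"kernel": 1.0, "bias": 0.0},
        "LayerNorm": {"scale": 1.0, "bias": 0.0},
    }
    print(decay_mask_fn(toy_params))
    # {'dense': {'kernel': True, 'bias': False}, 'LayerNorm': {'scale': False, 'bias': False}}

    # optax.adamw applies weight decay only where the mask evaluates to True,
    # so bias and LayerNorm scale leaves are left untouched by the decay term.
    adamw = optax.adamw(
        learning_rate=1e-3, b1=0.9, b2=0.999, eps=1e-8,
        weight_decay=0.01, mask=decay_mask_fn,
    )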