[examples/flax] add adafactor optimizer (#12544)

* add adafactor

* Update examples/flax/language-modeling/run_mlm_flax.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Suraj Patil 2021-07-07 11:50:30 +05:30 committed by GitHub
parent 208df208bf
commit 2d42915abe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 48 additions and 27 deletions

View File

@ -2,4 +2,4 @@ datasets >= 1.1.3
jax>=0.2.8
jaxlib>=0.1.59
flax>=0.3.4
optax>=0.0.8
optax>=0.0.9

View File

@ -489,17 +489,24 @@ def main():
return traverse_util.unflatten_dict(flat_mask)
# create adam optimizer
adamw = optax.adamw(
learning_rate=linear_decay_lr_schedule_fn,
b1=training_args.adam_beta1,
b2=training_args.adam_beta2,
eps=training_args.adam_epsilon,
weight_decay=training_args.weight_decay,
mask=decay_mask_fn,
)
if training_args.adafactor:
# We use the default parameters here to initialize adafactor,
# For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
optimizer = optax.adafactor(
learning_rate=linear_decay_lr_schedule_fn,
)
else:
optimizer = optax.adamw(
learning_rate=linear_decay_lr_schedule_fn,
b1=training_args.adam_beta1,
b2=training_args.adam_beta2,
eps=training_args.adam_epsilon,
weight_decay=training_args.weight_decay,
mask=decay_mask_fn,
)
# Setup train state
state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng)
def loss_fn(logits, labels):
shift_logits = logits[..., :-1, :]

View File

@ -513,17 +513,24 @@ if __name__ == "__main__":
return traverse_util.unflatten_dict(flat_mask)
# create adam optimizer
adamw = optax.adamw(
learning_rate=linear_decay_lr_schedule_fn,
b1=training_args.adam_beta1,
b2=training_args.adam_beta2,
eps=1e-8,
weight_decay=training_args.weight_decay,
mask=decay_mask_fn,
)
if training_args.adafactor:
# We use the default parameters here to initialize adafactor,
# For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
optimizer = optax.adafactor(
learning_rate=linear_decay_lr_schedule_fn,
)
else:
optimizer = optax.adamw(
learning_rate=linear_decay_lr_schedule_fn,
b1=training_args.adam_beta1,
b2=training_args.adam_beta2,
eps=training_args.adam_epsilon,
weight_decay=training_args.weight_decay,
mask=decay_mask_fn,
)
# Setup train state
state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)
state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)
# Define gradient update step fn
def train_step(state, batch, dropout_rng):

View File

@ -635,16 +635,23 @@ if __name__ == "__main__":
return traverse_util.unflatten_dict(flat_mask)
# create adam optimizer
adamw = optax.adamw(
learning_rate=linear_decay_lr_schedule_fn,
b1=training_args.adam_beta1,
b2=training_args.adam_beta2,
weight_decay=training_args.weight_decay,
mask=decay_mask_fn,
)
if training_args.adafactor:
# We use the default parameters here to initialize adafactor,
# For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
optimizer = optax.adafactor(
learning_rate=linear_decay_lr_schedule_fn,
)
else:
optimizer = optax.adamw(
learning_rate=linear_decay_lr_schedule_fn,
b1=training_args.adam_beta1,
b2=training_args.adam_beta2,
weight_decay=training_args.weight_decay,
mask=decay_mask_fn,
)
# Setup train state
state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)
state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)
# Define gradient update step fn
def train_step(state, batch, dropout_rng):