Fix AdamWeightDecay for TF 2.11 (#20735)

* Fix AdamWeightDecay for TF

* Fix AdamWeightDecay for TF

* make fixup
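
Background: TF 2.11 points `tf.keras.optimizers.Adam` at the new Keras optimizer implementation and keeps the previous one under `tf.keras.optimizers.legacy`, which breaks subclasses such as `AdamWeightDecay` that rely on the old class's internals. A minimal sketch of how the selected base class could be checked after this change, assuming `transformers.optimization_tf` as the import path:

```python
import tensorflow as tf
from transformers.optimization_tf import AdamWeightDecay

# On TF >= 2.11 the fallback should resolve to the legacy Adam; on older
# releases it resolves to the classic tf.keras.optimizers.Adam.
print(tf.__version__)
print(AdamWeightDecay.__bases__[0])
```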
Matt 2022-12-13 12:51:07 +00:00 committed by GitHub
parent a12c5cbcd8
commit 4f1788b34d
1 changed file with 7 additions and 1 deletion


@@ -21,6 +21,12 @@ from typing import Callable, List, Optional, Union
 import tensorflow as tf


+if hasattr(tf.keras, "optimizers") and hasattr(tf.keras.optimizers, "legacy"):
+    Adam = tf.keras.optimizers.legacy.Adam
+else:
+    Adam = tf.keras.optimizers.Adam
+
+
 class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
     """
     Applies a warmup schedule on a given learning rate decay schedule.
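
With the `Adam` alias in place, the optimizer is constructed the same way on either side of the API change. A usage sketch, with illustrative hyperparameter values and `transformers.optimization_tf` assumed as the import path:

```python
from transformers.optimization_tf import AdamWeightDecay

# Typical fine-tuning setup: decoupled weight decay, skipped for LayerNorm
# and bias parameters (illustrative values).
optimizer = AdamWeightDecay(
    learning_rate=5e-5,
    weight_decay_rate=0.01,
    epsilon=1e-8,
    exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
)
```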
@@ -163,7 +169,7 @@ def create_optimizer(
     return optimizer, lr_schedule


-class AdamWeightDecay(tf.keras.optimizers.Adam):
+class AdamWeightDecay(Adam):
     """
     Adam enables L2 weight decay and clip_by_global_norm on gradients. Just adding the square of the weights to the
     loss function is *not* the correct way of using L2 regularization/weight decay with Adam, since that will interact