Fix AdamWeightDecay for TF 2.11 (#20735)
* Fix AdamWeightDecay for TF * Fix AdamWeightDecay for TF * make fixup
This commit is contained in:
parent
a12c5cbcd8
commit
4f1788b34d
|
@ -21,6 +21,12 @@ from typing import Callable, List, Optional, Union
|
|||
import tensorflow as tf
|
||||
|
||||
|
||||
# TF >= 2.11 moved the original optimizer implementations under
# `tf.keras.optimizers.legacy`; `AdamWeightDecay` relies on internals of the
# legacy Adam, so prefer the legacy class whenever it is available.
# NOTE: the namespace is `tf.keras.optimizers` (plural) — the singular
# `tf.keras.optimizer` does not exist, which made the original guard always
# fall through to the non-legacy Adam and broke the TF 2.11 fix.
if hasattr(tf.keras, "optimizers") and hasattr(tf.keras.optimizers, "legacy"):
    Adam = tf.keras.optimizers.legacy.Adam
else:
    # Pre-2.11 TF: there is no `legacy` submodule; the plain Adam is the
    # old-style optimizer this module expects.
    Adam = tf.keras.optimizers.Adam
|
||||
|
||||
|
||||
class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
|
||||
"""
|
||||
Applies a warmup schedule on a given learning rate decay schedule.
|
||||
|
@ -163,7 +169,7 @@ def create_optimizer(
|
|||
return optimizer, lr_schedule
|
||||
|
||||
|
||||
class AdamWeightDecay(tf.keras.optimizers.Adam):
|
||||
class AdamWeightDecay(Adam):
|
||||
"""
|
||||
Adam enables L2 weight decay and clip_by_global_norm on gradients. Just adding the square of the weights to the
|
||||
loss function is *not* the correct way of using L2 regularization/weight decay with Adam, since that will interact
|
||||
|
|
Loading…
Reference in New Issue