From 4f1788b34d0d99a10805fb8e32e65097bd7719e4 Mon Sep 17 00:00:00 2001
From: Matt
Date: Tue, 13 Dec 2022 12:51:07 +0000
Subject: [PATCH] Fix AdamWeightDecay for TF 2.11 (#20735)

* Fix AdamWeightDecay for TF

* Fix AdamWeightDecay for TF

* make fixup
---
 src/transformers/optimization_tf.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/src/transformers/optimization_tf.py b/src/transformers/optimization_tf.py
index e2b2a961ca..58ff287d8b 100644
--- a/src/transformers/optimization_tf.py
+++ b/src/transformers/optimization_tf.py
@@ -21,6 +21,12 @@ from typing import Callable, List, Optional, Union
 import tensorflow as tf
 
 
+if hasattr(tf.keras, "optimizer") and hasattr(tf.keras.optimizer, "legacy"):
+    Adam = tf.keras.optimizer.legacy.Adam
+else:
+    Adam = tf.keras.optimizers.Adam
+
+
 class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
     """
     Applies a warmup schedule on a given learning rate decay schedule.
@@ -163,7 +169,7 @@ def create_optimizer(
     return optimizer, lr_schedule
 
 
-class AdamWeightDecay(tf.keras.optimizers.Adam):
+class AdamWeightDecay(Adam):
     """
     Adam enables L2 weight decay and clip_by_global_norm on gradients. Just adding the square of the weights to the
     loss function is *not* the correct way of using L2 regularization/weight decay with Adam, since that will interact
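
For reference, a minimal standalone sketch of the compatibility shim this patch introduces. The DecoupledAdam subclass below is a hypothetical stand-in for AdamWeightDecay; which Adam implementation ends up being used depends on the attributes the installed TensorFlow release exposes under tf.keras.

import tensorflow as tf

# Same version check as the patch adds: prefer the legacy Adam implementation
# when TensorFlow exposes one, otherwise fall back to the regular Adam class.
if hasattr(tf.keras, "optimizer") and hasattr(tf.keras.optimizer, "legacy"):
    Adam = tf.keras.optimizer.legacy.Adam
else:
    Adam = tf.keras.optimizers.Adam


class DecoupledAdam(Adam):
    """Hypothetical stand-in for AdamWeightDecay: it only depends on the
    resolved Adam base class, so the same subclass definition works whether
    the legacy or the current optimizer class was selected."""


# Instantiation works regardless of which base class was chosen above.
opt = DecoupledAdam(learning_rate=3e-5)
print(DecoupledAdam.__mro__[1])  # shows which Adam implementation was picked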