diff --git a/src/transformers/generation/tf_utils.py b/src/transformers/generation/tf_utils.py
index 2d91bdb3eb..cbaba81b57 100644
--- a/src/transformers/generation/tf_utils.py
+++ b/src/transformers/generation/tf_utils.py
@@ -2437,6 +2437,8 @@ class TFGenerationMixin:
             else self.generation_config.return_dict_in_generate
         )
         use_cache = True  # In contrastive search, we always use cache
+        model_kwargs.pop("use_cache", None)
+
         use_xla = not tf.executing_eagerly()
         # TODO (Joao): fix cache format or find programatic way to detect cache index
         # GPT2 and other models has a slightly different cache structure, with a different batch axis
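
The added `model_kwargs.pop("use_cache", None)` plausibly exists to drop any caller-supplied `use_cache` from `model_kwargs`, since contrastive search hardcodes `use_cache = True` and forwards it explicitly; otherwise the same keyword could reach the model twice. A minimal, self-contained sketch of that failure mode (the `call_model` helper below is hypothetical, not part of the PR):

# Hypothetical stand-in for a model forward pass that accepts use_cache
# plus arbitrary extra keyword arguments.
def call_model(input_ids, use_cache, **model_kwargs):
    return {"use_cache": use_cache, **model_kwargs}

model_kwargs = {"use_cache": False, "attention_mask": [1, 1, 1]}

# Without the pop, use_cache would be passed twice and Python raises:
#   call_model([0, 1], use_cache=True, **model_kwargs)
#   TypeError: call_model() got multiple values for keyword argument 'use_cache'

model_kwargs.pop("use_cache", None)  # mirrors the diff's added line
out = call_model([0, 1], use_cache=True, **model_kwargs)
print(out)  # {'use_cache': True, 'attention_mask': [1, 1, 1]}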