Generate: TF contrastive search must pop `use_cache` from `model_kwargs` (#21149)

This commit is contained in:
Joao Gante 2023-01-17 13:42:52 +00:00 committed by GitHub
parent 7f3dab39b5
commit 7b5e943cb6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 2 additions and 0 deletions

View File

@ -2437,6 +2437,8 @@ class TFGenerationMixin:
else self.generation_config.return_dict_in_generate
)
use_cache = True # In contrastive search, we always use cache
model_kwargs.pop("use_cache", None)
use_xla = not tf.executing_eagerly()
# TODO (Joao): fix cache format or find programmatic way to detect cache index
# GPT2 and other models have a slightly different cache structure, with a different batch axis