Fix failing tests for XLA generation in TF (#18298)

* Fix failing test_xla_generate_slow tests

* Fix failing speech-to-text xla_generate tests
This commit is contained in:
Daniel Suess 2022-08-03 15:45:15 +02:00 committed by GitHub
parent a507908cd3
commit 8fb7c908c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 11 additions and 9 deletions

View File

@ -1685,6 +1685,17 @@ class TFModelTesterMixin:
config.do_sample = False
config.num_beams = num_beams
config.num_return_sequences = num_return_sequences
# fix config for models with additional sequence-length limiting settings
for var_name in ["max_position_embeddings", "max_target_positions"]:
if hasattr(config, var_name):
try:
setattr(config, var_name, max_length)
except NotImplementedError:
# xlnet will raise an exception when trying to set
# max_position_embeddings.
pass
model = model_class(config)
if model.supports_xla_generation:
@ -1714,15 +1725,6 @@ class TFModelTesterMixin:
Either the model supports XLA generation and passes the inner test, or it raises an appropriate exception
"""
# TODO (Joao): find the issues related to the following models. They are passing the fast test, but failing
# the slow one.
if any(
[
model in str(self).lower()
for model in ["tfbart", "tfblenderbot", "tfmarian", "tfmbart", "tfopt", "tfpegasus"]
]
):
return
num_beams = 8
num_return_sequences = 2
max_length = 128