Fix failing tests for XLA generation in TF (#18298)
* Fix failing test_xla_generate_slow tests * Fix failing speech-to-text xla_generate tests
This commit is contained in:
parent
a507908cd3
commit
8fb7c908c8
|
@ -1685,6 +1685,17 @@ class TFModelTesterMixin:
|
|||
config.do_sample = False
|
||||
config.num_beams = num_beams
|
||||
config.num_return_sequences = num_return_sequences
|
||||
|
||||
# fix config for models with additional sequence-length limiting settings
|
||||
for var_name in ["max_position_embeddings", "max_target_positions"]:
|
||||
if hasattr(config, var_name):
|
||||
try:
|
||||
setattr(config, var_name, max_length)
|
||||
except NotImplementedError:
|
||||
# xlnet will raise an exception when trying to set
|
||||
# max_position_embeddings.
|
||||
pass
|
||||
|
||||
model = model_class(config)
|
||||
|
||||
if model.supports_xla_generation:
|
||||
|
@ -1714,15 +1725,6 @@ class TFModelTesterMixin:
|
|||
|
||||
Either the model supports XLA generation and passes the inner test, or it raises an appropriate exception
|
||||
"""
|
||||
# TODO (Joao): find the issues related to the following models. They are passing the fast test, but failing
|
||||
# the slow one.
|
||||
if any(
|
||||
[
|
||||
model in str(self).lower()
|
||||
for model in ["tfbart", "tfblenderbot", "tfmarian", "tfmbart", "tfopt", "tfpegasus"]
|
||||
]
|
||||
):
|
||||
return
|
||||
num_beams = 8
|
||||
num_return_sequences = 2
|
||||
max_length = 128
|
||||
|
|
Loading…
Reference in New Issue