Use the Keras set_random_seed in tests (#30504)

Use the Keras set_random_seed to ensure reproducible weight initialization
This commit is contained in:
Matt 2024-04-26 16:14:53 +01:00 committed by GitHub
parent 20081c743e
commit 2de5cb12be
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 3 additions and 4 deletions

View File

@ -541,11 +541,10 @@ class PipelineUtilsTest(unittest.TestCase):
@slow
@require_tf
def test_load_default_pipelines_tf(self):
import tensorflow as tf
from transformers.modeling_tf_utils import keras
from transformers.pipelines import SUPPORTED_TASKS
set_seed_fn = lambda: tf.random.set_seed(0) # noqa: E731
set_seed_fn = lambda: keras.utils.set_random_seed(0) # noqa: E731
for task in SUPPORTED_TASKS.keys():
if task == "table-question-answering":
# test table in seperate test due to more dependencies
@ -553,7 +552,7 @@ class PipelineUtilsTest(unittest.TestCase):
self.check_default_pipeline(task, "tf", set_seed_fn, self.check_models_equal_tf)
# clean-up as much as possible GPU memory occupied by PyTorch
# clean-up as much as possible GPU memory occupied by TF
gc.collect()
@slow