diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 7b7301d6d8..c680b4c634 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -541,11 +541,10 @@ class PipelineUtilsTest(unittest.TestCase): @slow @require_tf def test_load_default_pipelines_tf(self): - import tensorflow as tf - + from transformers.modeling_tf_utils import keras from transformers.pipelines import SUPPORTED_TASKS - set_seed_fn = lambda: tf.random.set_seed(0) # noqa: E731 + set_seed_fn = lambda: keras.utils.set_random_seed(0) # noqa: E731 for task in SUPPORTED_TASKS.keys(): if task == "table-question-answering": # test table in seperate test due to more dependencies @@ -553,7 +552,7 @@ class PipelineUtilsTest(unittest.TestCase): self.check_default_pipeline(task, "tf", set_seed_fn, self.check_models_equal_tf) - # clean-up as much as possible GPU memory occupied by PyTorch + # clean-up as much as possible GPU memory occupied by TF gc.collect() @slow