Use the Keras set_random_seed in tests (#30504)
Use the Keras set_random_seed to ensure reproducible weight initialization
This commit is contained in:
parent
20081c743e
commit
2de5cb12be
|
@ -541,11 +541,10 @@ class PipelineUtilsTest(unittest.TestCase):
|
|||
@slow
|
||||
@require_tf
|
||||
def test_load_default_pipelines_tf(self):
|
||||
import tensorflow as tf
|
||||
|
||||
from transformers.modeling_tf_utils import keras
|
||||
from transformers.pipelines import SUPPORTED_TASKS
|
||||
|
||||
set_seed_fn = lambda: tf.random.set_seed(0) # noqa: E731
|
||||
set_seed_fn = lambda: keras.utils.set_random_seed(0) # noqa: E731
|
||||
for task in SUPPORTED_TASKS.keys():
|
||||
if task == "table-question-answering":
|
||||
# test table in seperate test due to more dependencies
|
||||
|
@ -553,7 +552,7 @@ class PipelineUtilsTest(unittest.TestCase):
|
|||
|
||||
self.check_default_pipeline(task, "tf", set_seed_fn, self.check_models_equal_tf)
|
||||
|
||||
# clean-up as much as possible GPU memory occupied by PyTorch
|
||||
# clean-up as much as possible GPU memory occupied by TF
|
||||
gc.collect()
|
||||
|
||||
@slow
|
||||
|
|
Loading…
Reference in New Issue