From 667939a2d327cc4db80fe4879f86c4f999dcb490 Mon Sep 17 00:00:00 2001 From: Fanli Lin Date: Mon, 15 Apr 2024 23:30:52 +0800 Subject: [PATCH] [tests] add the missing `require_torch_multi_gpu` flag (#30250) add gpu flag --- tests/models/bros/test_modeling_bros.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/models/bros/test_modeling_bros.py b/tests/models/bros/test_modeling_bros.py index 755deefcb4..4b1290ed49 100644 --- a/tests/models/bros/test_modeling_bros.py +++ b/tests/models/bros/test_modeling_bros.py @@ -17,7 +17,7 @@ import copy import unittest -from transformers.testing_utils import require_torch, slow, torch_device +from transformers.testing_utils import require_torch, require_torch_multi_gpu, slow, torch_device from transformers.utils import is_torch_available from ...test_configuration_common import ConfigTester @@ -344,6 +344,7 @@ class BrosModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) + @require_torch_multi_gpu def test_multi_gpu_data_parallel_forward(self): super().test_multi_gpu_data_parallel_forward()