diff --git a/tests/models/bros/test_modeling_bros.py b/tests/models/bros/test_modeling_bros.py index 755deefcb4..4b1290ed49 100644 --- a/tests/models/bros/test_modeling_bros.py +++ b/tests/models/bros/test_modeling_bros.py @@ -17,7 +17,7 @@ import copy import unittest -from transformers.testing_utils import require_torch, slow, torch_device +from transformers.testing_utils import require_torch, require_torch_multi_gpu, slow, torch_device from transformers.utils import is_torch_available from ...test_configuration_common import ConfigTester @@ -344,6 +344,7 @@ class BrosModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) + @require_torch_multi_gpu def test_multi_gpu_data_parallel_forward(self): super().test_multi_gpu_data_parallel_forward()