Fix model parallelism test (#17439)

This commit is contained in:
Sylvain Gugger 2022-05-26 09:57:12 -04:00 committed by GitHub
parent 7535d92e71
commit 98f6e1ee87
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 2 deletions

View File

@ -2203,7 +2203,7 @@ class ModelTesterMixin:
@require_torch_gpu
def test_cpu_offload(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if config.num_hidden_layers < 5:
if isinstance(getattr(config, "num_hidden_layers", None), int) and config.num_hidden_layers < 5:
config.num_hidden_layers = 5
for model_class in self.all_model_classes:
@ -2236,7 +2236,7 @@ class ModelTesterMixin:
@require_torch_multi_gpu
def test_model_parallelism(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if config.num_hidden_layers < 5:
if isinstance(getattr(config, "num_hidden_layers", None), int) and config.num_hidden_layers < 5:
config.num_hidden_layers = 5
for model_class in self.all_model_classes: