Fix model parallelism test (#17439)
This commit is contained in:
parent
7535d92e71
commit
98f6e1ee87
|
@ -2203,7 +2203,7 @@ class ModelTesterMixin:
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
def test_cpu_offload(self):
|
def test_cpu_offload(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
if config.num_hidden_layers < 5:
|
if isinstance(getattr(config, "num_hidden_layers", None), int) and config.num_hidden_layers < 5:
|
||||||
config.num_hidden_layers = 5
|
config.num_hidden_layers = 5
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
|
@ -2236,7 +2236,7 @@ class ModelTesterMixin:
|
||||||
@require_torch_multi_gpu
|
@require_torch_multi_gpu
|
||||||
def test_model_parallelism(self):
|
def test_model_parallelism(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
if config.num_hidden_layers < 5:
|
if isinstance(getattr(config, "num_hidden_layers", None), int) and config.num_hidden_layers < 5:
|
||||||
config.num_hidden_layers = 5
|
config.num_hidden_layers = 5
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
|
|
Loading…
Reference in New Issue