From 3e41cf13fc56335ace852b14decc198557052d4f Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Thu, 10 Aug 2023 10:54:26 +0100
Subject: [PATCH] Generate: Load generation config when `device_map` is passed (#25413)

---
 src/transformers/modeling_utils.py | 24 ++++++++++++++----------
 tests/test_modeling_utils.py       | 14 ++++++++++++++
 2 files changed, 28 insertions(+), 10 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index adb2b4919d..f5b5fa4dd7 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -2849,9 +2849,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                     "'sequential'."
                 )
 
-            kwargs = {"no_split_module_classes": no_split_modules}
+            device_map_kwargs = {"no_split_module_classes": no_split_modules}
             if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters:
-                kwargs["special_dtypes"] = special_dtypes
+                device_map_kwargs["special_dtypes"] = special_dtypes
             elif len(special_dtypes) > 0:
                 logger.warning(
                     "This model has some weights that should be kept in higher precision, you need to upgrade "
@@ -2863,12 +2863,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                     dtype=target_dtype,
                     low_zero=(device_map == "balanced_low_0"),
                     max_memory=max_memory,
-                    **kwargs,
+                    **device_map_kwargs,
                 )
-                kwargs["max_memory"] = max_memory
+                device_map_kwargs["max_memory"] = max_memory
             # Make sure tied weights are tied before creating the device map.
             model.tie_weights()
-            device_map = infer_auto_device_map(model, dtype=target_dtype, **kwargs)
+            device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs)
 
             if load_in_8bit or load_in_4bit:
                 # The LM head / tied weights or any last module can stay on disk / CPU
@@ -2966,7 +2966,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         model.eval()
 
         # If it is a model with generation capabilities, attempt to load the generation config
-        if model.can_generate():
+        if model.can_generate() and pretrained_model_name_or_path is not None:
             try:
                 model.generation_config = GenerationConfig.from_pretrained(
                     pretrained_model_name_or_path,
@@ -2982,7 +2982,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                     _from_pipeline=from_pipeline,
                     **kwargs,
                 )
-            except (OSError, TypeError):
+            except OSError:
                 logger.info(
                     "Generation config file not found, using a generation config created from the model config."
                 )
@@ -2990,10 +2990,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
 
         # Dispatch model with hooks on all devices if necessary
        if device_map is not None:
-            kwargs = {"device_map": device_map, "offload_dir": offload_folder, "offload_index": offload_index}
+            device_map_kwargs = {
+                "device_map": device_map,
+                "offload_dir": offload_folder,
+                "offload_index": offload_index,
+            }
             if "skip_keys" in inspect.signature(dispatch_model).parameters:
-                kwargs["skip_keys"] = model._skip_keys_device_placement
-            dispatch_model(model, **kwargs)
+                device_map_kwargs["skip_keys"] = model._skip_keys_device_placement
+            dispatch_model(model, **device_map_kwargs)
 
         if output_loading_info:
             if loading_info is None:

diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py
index 5019d0ccb3..bdadbe0800 100755
--- a/tests/test_modeling_utils.py
+++ b/tests/test_modeling_utils.py
@@ -1036,6 +1036,20 @@ class ModelUtilsTest(TestCasePlus):
         self.assertEqual(model.__class__.__name__, model_ref.__class__.__name__)
 
+    def test_generation_config_is_loaded_with_model(self):
+        # Note: `joaogante/tiny-random-gpt2-with-generation-config` has a `generation_config.json` containing a dummy
+        # `transformers_version` field set to `foo`. If loading the file fails, this test also fails.
+
+        # 1. Load without further parameters
+        model = AutoModelForCausalLM.from_pretrained("joaogante/tiny-random-gpt2-with-generation-config")
+        self.assertEqual(model.generation_config.transformers_version, "foo")
+
+        # 2. Load with `device_map`
+        model = AutoModelForCausalLM.from_pretrained(
+            "joaogante/tiny-random-gpt2-with-generation-config", device_map="auto"
+        )
+        self.assertEqual(model.generation_config.transformers_version, "foo")
+
 
 @require_torch
 @is_staging_test
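
A minimal usage sketch of the behavior this patch fixes, assuming `torch` and `accelerate` are installed (`device_map="auto"` requires `accelerate`) and the Hub checkpoint from the test above is reachable. As the diff suggests, the local `kwargs` dict built for device-map inference previously rebound the user kwargs that were later forwarded to `GenerationConfig.from_pretrained`; the resulting `TypeError` was swallowed by the broad `except (OSError, TypeError)` clause, so `generation_config.json` was silently skipped whenever a `device_map` was passed.

    from transformers import AutoModelForCausalLM

    # With the patch applied, the generation config is loaded even when the
    # model is dispatched across devices with a device map.
    model = AutoModelForCausalLM.from_pretrained(
        "joaogante/tiny-random-gpt2-with-generation-config", device_map="auto"
    )
    print(model.generation_config.transformers_version)  # "foo" after the fix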