Generate: Load generation config when `device_map` is passed (#25413)

Joao Gante 2023-08-10 10:54:26 +01:00 committed by GitHub
parent d0839f1a74
commit 3e41cf13fc
2 changed files with 28 additions and 10 deletions
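
Per the commit title and the regression test added below, the checkpoint's `generation_config.json` was previously not picked up when a `device_map` was passed to `from_pretrained`. A minimal sketch of the behavior this commit covers (assuming `accelerate` is installed, which `device_map="auto"` requires):

    from transformers import AutoModelForCausalLM

    # Checkpoint name taken from the new test below; its generation_config.json sets a
    # dummy `transformers_version` of "foo", so it is easy to tell whether the file was read.
    model = AutoModelForCausalLM.from_pretrained(
        "joaogante/tiny-random-gpt2-with-generation-config", device_map="auto"
    )
    print(model.generation_config.transformers_version)  # expected: "foo" with this fix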


@@ -2849,9 +2849,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                     "'sequential'."
                 )
-            kwargs = {"no_split_module_classes": no_split_modules}
+            device_map_kwargs = {"no_split_module_classes": no_split_modules}
             if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters:
-                kwargs["special_dtypes"] = special_dtypes
+                device_map_kwargs["special_dtypes"] = special_dtypes
             elif len(special_dtypes) > 0:
                 logger.warning(
                     "This model has some weights that should be kept in higher precision, you need to upgrade "
@@ -2863,12 +2863,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                     dtype=target_dtype,
                     low_zero=(device_map == "balanced_low_0"),
                     max_memory=max_memory,
-                    **kwargs,
+                    **device_map_kwargs,
                 )
-            kwargs["max_memory"] = max_memory
+            device_map_kwargs["max_memory"] = max_memory
             # Make sure tied weights are tied before creating the device map.
             model.tie_weights()
-            device_map = infer_auto_device_map(model, dtype=target_dtype, **kwargs)
+            device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs)
             if load_in_8bit or load_in_4bit:
                 # The LM head / tied weights or any last module can stay on disk / CPU
@@ -2966,7 +2966,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         model.eval()
         # If it is a model with generation capabilities, attempt to load the generation config
-        if model.can_generate():
+        if model.can_generate() and pretrained_model_name_or_path is not None:
             try:
                 model.generation_config = GenerationConfig.from_pretrained(
                     pretrained_model_name_or_path,
@@ -2982,7 +2982,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                     _from_pipeline=from_pipeline,
                     **kwargs,
                 )
-            except (OSError, TypeError):
+            except OSError:
                 logger.info(
                     "Generation config file not found, using a generation config created from the model config."
                 )
@@ -2990,10 +2990,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         # Dispatch model with hooks on all devices if necessary
         if device_map is not None:
-            kwargs = {"device_map": device_map, "offload_dir": offload_folder, "offload_index": offload_index}
+            device_map_kwargs = {
+                "device_map": device_map,
+                "offload_dir": offload_folder,
+                "offload_index": offload_index,
+            }
             if "skip_keys" in inspect.signature(dispatch_model).parameters:
-                kwargs["skip_keys"] = model._skip_keys_device_placement
-            dispatch_model(model, **kwargs)
+                device_map_kwargs["skip_keys"] = model._skip_keys_device_placement
+            dispatch_model(model, **device_map_kwargs)
         if output_loading_info:
             if loading_info is None:

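A note on the pattern the renamed `device_map_kwargs` dict feeds into: the loading code checks `inspect.signature(...)` before forwarding newer keyword arguments (`special_dtypes`, `skip_keys`) so that older `accelerate` releases keep working. A generic sketch of that feature-detection pattern, using only the standard library and an illustrative stand-in function (not part of this diff):

    import inspect

    def dispatch(model, device_map, skip_keys=None):
        # Stand-in for a third-party function whose signature may vary across versions.
        return device_map, skip_keys

    call_kwargs = {"device_map": {"": "cpu"}}
    # Forward the newer keyword argument only if the installed version accepts it.
    if "skip_keys" in inspect.signature(dispatch).parameters:
        call_kwargs["skip_keys"] = ["past_key_values"]
    dispatch(model=object(), **call_kwargs)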

@@ -1036,6 +1036,20 @@ class ModelUtilsTest(TestCasePlus):
         self.assertEqual(model.__class__.__name__, model_ref.__class__.__name__)
+    def test_generation_config_is_loaded_with_model(self):
+        # Note: `joaogante/tiny-random-gpt2-with-generation-config` has a `generation_config.json` containing a dummy
+        # `transformers_version` field set to `foo`. If loading the file fails, this test also fails.
+        # 1. Load without further parameters
+        model = AutoModelForCausalLM.from_pretrained("joaogante/tiny-random-gpt2-with-generation-config")
+        self.assertEqual(model.generation_config.transformers_version, "foo")
+        # 2. Load with `device_map`
+        model = AutoModelForCausalLM.from_pretrained(
+            "joaogante/tiny-random-gpt2-with-generation-config", device_map="auto"
+        )
+        self.assertEqual(model.generation_config.transformers_version, "foo")
 @require_torch
 @is_staging_test
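
For context on the `except OSError` branch above: when a checkpoint has no readable `generation_config.json`, the logged fallback builds a generation config from the model config instead. A rough sketch of what that fallback amounts to, assuming the public `GenerationConfig.from_model_config` helper and using `gpt2` purely as an example checkpoint:

    from transformers import AutoConfig, GenerationConfig

    config = AutoConfig.from_pretrained("gpt2")
    # Generation defaults derived from the model config, as the log message describes.
    generation_config = GenerationConfig.from_model_config(config)
    print(generation_config.max_length)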