Generate: Load generation config when `device_map` is passed (#25413)
This commit is contained in:
parent
d0839f1a74
commit
3e41cf13fc
|
@ -2849,9 +2849,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
"'sequential'."
|
||||
)
|
||||
|
||||
kwargs = {"no_split_module_classes": no_split_modules}
|
||||
device_map_kwargs = {"no_split_module_classes": no_split_modules}
|
||||
if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters:
|
||||
kwargs["special_dtypes"] = special_dtypes
|
||||
device_map_kwargs["special_dtypes"] = special_dtypes
|
||||
elif len(special_dtypes) > 0:
|
||||
logger.warning(
|
||||
"This model has some weights that should be kept in higher precision, you need to upgrade "
|
||||
|
@ -2863,12 +2863,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
dtype=target_dtype,
|
||||
low_zero=(device_map == "balanced_low_0"),
|
||||
max_memory=max_memory,
|
||||
**kwargs,
|
||||
**device_map_kwargs,
|
||||
)
|
||||
kwargs["max_memory"] = max_memory
|
||||
device_map_kwargs["max_memory"] = max_memory
|
||||
# Make sure tied weights are tied before creating the device map.
|
||||
model.tie_weights()
|
||||
device_map = infer_auto_device_map(model, dtype=target_dtype, **kwargs)
|
||||
device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs)
|
||||
|
||||
if load_in_8bit or load_in_4bit:
|
||||
# The LM head / tied weights or any last module can stay on disk / CPU
|
||||
|
@ -2966,7 +2966,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
model.eval()
|
||||
|
||||
# If it is a model with generation capabilities, attempt to load the generation config
|
||||
if model.can_generate():
|
||||
if model.can_generate() and pretrained_model_name_or_path is not None:
|
||||
try:
|
||||
model.generation_config = GenerationConfig.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
|
@ -2982,7 +2982,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
_from_pipeline=from_pipeline,
|
||||
**kwargs,
|
||||
)
|
||||
except (OSError, TypeError):
|
||||
except OSError:
|
||||
logger.info(
|
||||
"Generation config file not found, using a generation config created from the model config."
|
||||
)
|
||||
|
@ -2990,10 +2990,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
|
||||
# Dispatch model with hooks on all devices if necessary
|
||||
if device_map is not None:
|
||||
kwargs = {"device_map": device_map, "offload_dir": offload_folder, "offload_index": offload_index}
|
||||
device_map_kwargs = {
|
||||
"device_map": device_map,
|
||||
"offload_dir": offload_folder,
|
||||
"offload_index": offload_index,
|
||||
}
|
||||
if "skip_keys" in inspect.signature(dispatch_model).parameters:
|
||||
kwargs["skip_keys"] = model._skip_keys_device_placement
|
||||
dispatch_model(model, **kwargs)
|
||||
device_map_kwargs["skip_keys"] = model._skip_keys_device_placement
|
||||
dispatch_model(model, **device_map_kwargs)
|
||||
|
||||
if output_loading_info:
|
||||
if loading_info is None:
|
||||
|
|
|
@ -1036,6 +1036,20 @@ class ModelUtilsTest(TestCasePlus):
|
|||
|
||||
self.assertEqual(model.__class__.__name__, model_ref.__class__.__name__)
|
||||
|
||||
def test_generation_config_is_loaded_with_model(self):
|
||||
# Note: `joaogante/tiny-random-gpt2-with-generation-config` has a `generation_config.json` containing a dummy
|
||||
# `transformers_version` field set to `foo`. If loading the file fails, this test also fails.
|
||||
|
||||
# 1. Load without further parameters
|
||||
model = AutoModelForCausalLM.from_pretrained("joaogante/tiny-random-gpt2-with-generation-config")
|
||||
self.assertEqual(model.generation_config.transformers_version, "foo")
|
||||
|
||||
# 2. Load with `device_map`
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"joaogante/tiny-random-gpt2-with-generation-config", device_map="auto"
|
||||
)
|
||||
self.assertEqual(model.generation_config.transformers_version, "foo")
|
||||
|
||||
|
||||
@require_torch
|
||||
@is_staging_test
|
||||
|
|
Loading…
Reference in New Issue