Preserve hub-related kwargs in AutoModel.from_pretrained (#18545)
* Preserve hub-related kwargs in AutoModel.from_pretrained
* Fix tests
* Remove debug statement
This commit is contained in:
parent
34aad0dac0
commit
d7e2d7b40b
|
@ -419,9 +419,24 @@ class _BaseAutoModelClass:
|
|||
config = kwargs.pop("config", None)
|
||||
trust_remote_code = kwargs.pop("trust_remote_code", False)
|
||||
kwargs["_from_auto"] = True
|
||||
hub_kwargs_names = [
|
||||
"cache_dir",
|
||||
"force_download",
|
||||
"local_files_only",
|
||||
"proxies",
|
||||
"resume_download",
|
||||
"revision",
|
||||
"subfolder",
|
||||
"use_auth_token",
|
||||
]
|
||||
hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config, kwargs = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, return_unused_kwargs=True, trust_remote_code=trust_remote_code, **kwargs
|
||||
pretrained_model_name_or_path,
|
||||
return_unused_kwargs=True,
|
||||
trust_remote_code=trust_remote_code,
|
||||
**hub_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
if hasattr(config, "auto_map") and cls.__name__ in config.auto_map:
|
||||
if not trust_remote_code:
|
||||
|
@ -430,7 +445,7 @@ class _BaseAutoModelClass:
|
|||
"on your local machine. Make sure you have read the code there to avoid malicious use, then set "
|
||||
"the option `trust_remote_code=True` to remove this error."
|
||||
)
|
||||
if kwargs.get("revision", None) is None:
|
||||
if hub_kwargs.get("revision", None) is None:
|
||||
logger.warning(
|
||||
"Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure "
|
||||
"no malicious code has been contributed in a newer revision."
|
||||
|
@ -438,12 +453,16 @@ class _BaseAutoModelClass:
|
|||
class_ref = config.auto_map[cls.__name__]
|
||||
module_file, class_name = class_ref.split(".")
|
||||
model_class = get_class_from_dynamic_module(
|
||||
pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
|
||||
pretrained_model_name_or_path, module_file + ".py", class_name, **hub_kwargs, **kwargs
|
||||
)
|
||||
return model_class.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
|
||||
)
|
||||
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||
elif type(config) in cls._model_mapping.keys():
|
||||
model_class = _get_model_class(config, cls._model_mapping)
|
||||
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||
return model_class.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
|
||||
)
|
||||
raise ValueError(
|
||||
f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
|
||||
f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
|
||||
|
|
|
@ -728,7 +728,7 @@ class AutoConfig:
|
|||
kwargs["_from_auto"] = True
|
||||
kwargs["name_or_path"] = pretrained_model_name_or_path
|
||||
trust_remote_code = kwargs.pop("trust_remote_code", False)
|
||||
config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
||||
config_dict, unused_kwargs = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
||||
if "auto_map" in config_dict and "AutoConfig" in config_dict["auto_map"]:
|
||||
if not trust_remote_code:
|
||||
raise ValueError(
|
||||
|
@ -749,13 +749,13 @@ class AutoConfig:
|
|||
return config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
elif "model_type" in config_dict:
|
||||
config_class = CONFIG_MAPPING[config_dict["model_type"]]
|
||||
return config_class.from_dict(config_dict, **kwargs)
|
||||
return config_class.from_dict(config_dict, **unused_kwargs)
|
||||
else:
|
||||
# Fallback: use pattern matching on the string.
|
||||
# We go from longer names to shorter names to catch roberta before bert (for instance)
|
||||
for pattern in sorted(CONFIG_MAPPING.keys(), key=len, reverse=True):
|
||||
if pattern in str(pretrained_model_name_or_path):
|
||||
return CONFIG_MAPPING[pattern].from_dict(config_dict, **kwargs)
|
||||
return CONFIG_MAPPING[pattern].from_dict(config_dict, **unused_kwargs)
|
||||
|
||||
raise ValueError(
|
||||
f"Unrecognized model in {pretrained_model_name_or_path}. "
|
||||
|
|
Loading…
Reference in New Issue