diff --git a/src/transformers/models/auto/auto_factory.py b/src/transformers/models/auto/auto_factory.py index b7d8f66c33..b412f14157 100644 --- a/src/transformers/models/auto/auto_factory.py +++ b/src/transformers/models/auto/auto_factory.py @@ -419,9 +419,24 @@ class _BaseAutoModelClass: config = kwargs.pop("config", None) trust_remote_code = kwargs.pop("trust_remote_code", False) kwargs["_from_auto"] = True + hub_kwargs_names = [ + "cache_dir", + "force_download", + "local_files_only", + "proxies", + "resume_download", + "revision", + "subfolder", + "use_auth_token", + ] + hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs} if not isinstance(config, PretrainedConfig): config, kwargs = AutoConfig.from_pretrained( - pretrained_model_name_or_path, return_unused_kwargs=True, trust_remote_code=trust_remote_code, **kwargs + pretrained_model_name_or_path, + return_unused_kwargs=True, + trust_remote_code=trust_remote_code, + **hub_kwargs, + **kwargs, ) if hasattr(config, "auto_map") and cls.__name__ in config.auto_map: if not trust_remote_code: @@ -430,7 +445,7 @@ class _BaseAutoModelClass: "on your local machine. Make sure you have read the code there to avoid malicious use, then set " "the option `trust_remote_code=True` to remove this error." ) - if kwargs.get("revision", None) is None: + if hub_kwargs.get("revision", None) is None: logger.warning( "Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure " "no malicious code has been contributed in a newer revision." @@ -438,12 +453,16 @@ class _BaseAutoModelClass: class_ref = config.auto_map[cls.__name__] module_file, class_name = class_ref.split(".") model_class = get_class_from_dynamic_module( - pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs + pretrained_model_name_or_path, module_file + ".py", class_name, **hub_kwargs, **kwargs + ) + return model_class.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs ) - return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) elif type(config) in cls._model_mapping.keys(): model_class = _get_model_class(config, cls._model_mapping) - return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) + return model_class.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs + ) raise ValueError( f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n" f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}." diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index d8ecbb49e6..c65a2762a0 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -728,7 +728,7 @@ class AutoConfig: kwargs["_from_auto"] = True kwargs["name_or_path"] = pretrained_model_name_or_path trust_remote_code = kwargs.pop("trust_remote_code", False) - config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs) + config_dict, unused_kwargs = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs) if "auto_map" in config_dict and "AutoConfig" in config_dict["auto_map"]: if not trust_remote_code: raise ValueError( @@ -749,13 +749,13 @@ class AutoConfig: return config_class.from_pretrained(pretrained_model_name_or_path, **kwargs) elif "model_type" in config_dict: config_class = CONFIG_MAPPING[config_dict["model_type"]] - return config_class.from_dict(config_dict, **kwargs) + return config_class.from_dict(config_dict, **unused_kwargs) else: # Fallback: use pattern matching on the string. # We go from longer names to shorter names to catch roberta before bert (for instance) for pattern in sorted(CONFIG_MAPPING.keys(), key=len, reverse=True): if pattern in str(pretrained_model_name_or_path): - return CONFIG_MAPPING[pattern].from_dict(config_dict, **kwargs) + return CONFIG_MAPPING[pattern].from_dict(config_dict, **unused_kwargs) raise ValueError( f"Unrecognized model in {pretrained_model_name_or_path}. "