diff --git a/src/transformers/models/auto/auto_factory.py b/src/transformers/models/auto/auto_factory.py index d322f83668..5519d82e7a 100644 --- a/src/transformers/models/auto/auto_factory.py +++ b/src/transformers/models/auto/auto_factory.py @@ -418,7 +418,7 @@ class _BaseAutoModelClass: else: repo_id = config.name_or_path model_class = get_class_from_dynamic_module(class_ref, repo_id, **kwargs) - cls._model_mapping.register(config.__class__, model_class, exist_ok=True) + cls.register(config.__class__, model_class, exist_ok=True) _ = kwargs.pop("code_revision", None) return model_class._from_config(config, **kwargs) elif type(config) in cls._model_mapping.keys(): @@ -477,7 +477,7 @@ class _BaseAutoModelClass: class_ref, pretrained_model_name_or_path, **hub_kwargs, **kwargs ) _ = hub_kwargs.pop("code_revision", None) - cls._model_mapping.register(config.__class__, model_class, exist_ok=True) + cls.register(config.__class__, model_class, exist_ok=True) return model_class.from_pretrained( pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs )