Use _BaseAutoModelClass's register method (#24810)

Switching _BaseAutoModelClass from_pretrained and from_config to use the register classmethod that it defines rather than using the _LazyAutoMapping register method directly. This makes use of the additional consistency check within the base model's register.
This commit is contained in:
Fady Nakhla 2023-07-13 12:24:51 -07:00 committed by GitHub
parent 0866705022
commit 9d7a0871e2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 2 deletions

View File

@ -418,7 +418,7 @@ class _BaseAutoModelClass:
else:
repo_id = config.name_or_path
model_class = get_class_from_dynamic_module(class_ref, repo_id, **kwargs)
cls._model_mapping.register(config.__class__, model_class, exist_ok=True)
cls.register(config.__class__, model_class, exist_ok=True)
_ = kwargs.pop("code_revision", None)
return model_class._from_config(config, **kwargs)
elif type(config) in cls._model_mapping.keys():
@ -477,7 +477,7 @@ class _BaseAutoModelClass:
class_ref, pretrained_model_name_or_path, **hub_kwargs, **kwargs
)
_ = hub_kwargs.pop("code_revision", None)
cls._model_mapping.register(config.__class__, model_class, exist_ok=True)
cls.register(config.__class__, model_class, exist_ok=True)
return model_class.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
)