diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py
index b65b7cfcd9..e535c3dbde 100644
--- a/src/transformers/feature_extraction_utils.py
+++ b/src/transformers/feature_extraction_utils.py
@@ -26,6 +26,7 @@ import numpy as np
from requests import HTTPError
+from .dynamic_module_utils import custom_object_save
from .file_utils import (
FEATURE_EXTRACTOR_NAME,
EntryNotFoundError,
@@ -205,6 +206,8 @@ class FeatureExtractionMixin:
extractors.
"""
+ _auto_class = None
+
def __init__(self, **kwargs):
"""Set elements of `kwargs` as attributes."""
# Pop "processor_class" as it should be saved as private attribute
@@ -316,6 +319,12 @@ class FeatureExtractionMixin:
"""
if os.path.isfile(save_directory):
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
         os.makedirs(save_directory, exist_ok=True)
+
+        # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
+        # loaded from the Hub.
+        if self._auto_class is not None:
+            custom_object_save(self, save_directory, config=self)
+
# If we save using the predefined names, we can load using `from_pretrained`
output_feature_extractor_file = os.path.join(save_directory, FEATURE_EXTRACTOR_NAME)
@@ -539,3 +548,29 @@ class FeatureExtractionMixin:
def __repr__(self):
return f"{self.__class__.__name__} {self.to_json_string()}"
+
+ @classmethod
+ def register_for_auto_class(cls, auto_class="AutoFeatureExtractor"):
+ """
+ Register this class with a given auto class. This should only be used for custom feature extractors as the ones
+ in the library are already mapped with `AutoFeatureExtractor`.
+
+        <Tip warning={true}>
+
+        This API is experimental and may have some slight breaking changes in the next releases.
+
+        </Tip>
+
+ Args:
+ auto_class (`str` or `type`, *optional*, defaults to `"AutoFeatureExtractor"`):
+ The auto class to register this new feature extractor with.
+ """
+ if not isinstance(auto_class, str):
+ auto_class = auto_class.__name__
+
+ import transformers.models.auto as auto_module
+
+ if not hasattr(auto_module, auto_class):
+ raise ValueError(f"{auto_class} is not a valid auto class.")
+
+ cls._auto_class = auto_class
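For context, a minimal sketch of how the two hooks above combine in user code. `my_feature_extraction.py` and `MyFeatureExtractor` are placeholder names for illustration, not part of this patch:

```python
# Hypothetical custom extractor module; names are illustrative only.
from my_feature_extraction import MyFeatureExtractor

# Record the target auto class on the custom extractor. This sets
# _auto_class, which save_pretrained() checks before writing anything.
MyFeatureExtractor.register_for_auto_class("AutoFeatureExtractor")

feature_extractor = MyFeatureExtractor()
# Because _auto_class is set, save_pretrained() calls custom_object_save(),
# copying my_feature_extraction.py next to preprocessor_config.json and
# adding an `auto_map` entry to the serialized config.
feature_extractor.save_pretrained("my-custom-extractor")
```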
diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py
index a146c611fb..93e0fc1ba9 100644
--- a/src/transformers/models/auto/feature_extraction_auto.py
+++ b/src/transformers/models/auto/feature_extraction_auto.py
@@ -14,23 +14,28 @@
# limitations under the License.
""" AutoFeatureExtractor class."""
import importlib
+import json
import os
from collections import OrderedDict
+from typing import Dict, Optional, Union
# Build the list of all feature extractors
from ...configuration_utils import PretrainedConfig
+from ...dynamic_module_utils import get_class_from_dynamic_module
from ...feature_extraction_utils import FeatureExtractionMixin
-from ...file_utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME
+from ...file_utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME, get_file_from_repo
+from ...utils import logging
from .auto_factory import _LazyAutoMapping
from .configuration_auto import (
CONFIG_MAPPING_NAMES,
AutoConfig,
- config_class_to_model_type,
model_type_to_module_name,
replace_list_option_in_docstrings,
)
+logger = logging.get_logger(__name__)
+
FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
[
("beit", "BeitFeatureExtractor"),
@@ -66,6 +71,96 @@ def feature_extractor_class_from_name(class_name: str):
return None
+def get_feature_extractor_config(
+ pretrained_model_name_or_path: Union[str, os.PathLike],
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
+ force_download: bool = False,
+ resume_download: bool = False,
+ proxies: Optional[Dict[str, str]] = None,
+ use_auth_token: Optional[Union[bool, str]] = None,
+ revision: Optional[str] = None,
+ local_files_only: bool = False,
+ **kwargs,
+):
+ """
+    Loads the feature extractor configuration from a pretrained model feature extractor configuration.
+
+ Args:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ This can be either:
+
+ - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
+ huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced
+ under a user or organization name, like `dbmdz/bert-base-german-cased`.
+ - a path to a *directory* containing a configuration file saved using the
+              [`~feature_extraction_utils.FeatureExtractionMixin.save_pretrained`] method, e.g., `./my_model_directory/`.
+
+ cache_dir (`str` or `os.PathLike`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
+ cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force to (re-)download the configuration files and override the cached versions if they
+ exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+            Whether or not to delete an incompletely received file. Attempts to resume the download if such a file exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `transformers-cli login` (stored in `~/.huggingface`).
+        revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+            If `True`, will only try to load the feature extractor configuration from local files.
+
+    <Tip>
+
+    Passing `use_auth_token=True` is required when you want to use a private model.
+
+    </Tip>
+
+    Returns:
+        `Dict`: The configuration of the feature extractor.
+
+ Examples:
+
+ ```python
+    # Download configuration from huggingface.co and cache.
+    feature_extractor_config = get_feature_extractor_config("facebook/wav2vec2-base-960h")
+    # This model does not have a feature extractor config so the result will be an empty dict.
+    feature_extractor_config = get_feature_extractor_config("bert-base-uncased")
+
+    # Save a pretrained feature extractor locally and you can reload its config
+    from transformers import AutoFeatureExtractor
+
+    feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
+    feature_extractor.save_pretrained("feature-extractor-test")
+    feature_extractor_config = get_feature_extractor_config("feature-extractor-test")
+    ```"""
+ resolved_config_file = get_file_from_repo(
+ pretrained_model_name_or_path,
+ FEATURE_EXTRACTOR_NAME,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ local_files_only=local_files_only,
+ )
+ if resolved_config_file is None:
+ logger.info(
+ "Could not locate the feature extractor configuration file, will try to use the model config instead."
+ )
+ return {}
+
+ with open(resolved_config_file, encoding="utf-8") as reader:
+ return json.load(reader)
+
+
class AutoFeatureExtractor:
r"""
This is a generic feature extractor class that will be instantiated as one of the feature extractor classes of the
@@ -128,6 +223,10 @@ class AutoFeatureExtractor:
functions returns a `Tuple(feature_extractor, unused_kwargs)` where *unused_kwargs* is a dictionary
consisting of the key/value pairs whose keys are not feature extractor attributes: i.e., the part of
`kwargs` which has not been used to update `feature_extractor` and is otherwise ignored.
+ trust_remote_code (`bool`, *optional*, defaults to `False`):
+ Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
+ should only be set to `True` for repositories you trust and in which you have read the code, as it will
+ execute code present on the Hub on your local machine.
kwargs (`Dict[str, Any]`, *optional*):
The values in kwargs of any keys which are feature extractor attributes will be used to override the
loaded values. Behavior concerning key/value pairs whose keys are *not* feature extractor attributes is
@@ -151,35 +250,54 @@ class AutoFeatureExtractor:
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("./test/saved_model/")
```"""
config = kwargs.pop("config", None)
+ trust_remote_code = kwargs.pop("trust_remote_code", False)
kwargs["_from_auto"] = True
- is_feature_extraction_file = os.path.isfile(pretrained_model_name_or_path)
- is_directory = os.path.isdir(pretrained_model_name_or_path) and os.path.exists(
- os.path.join(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME)
- )
+ config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
+ feature_extractor_class = config_dict.get("feature_extractor_type", None)
+ feature_extractor_auto_map = None
+ if "AutoFeatureExtractor" in config_dict.get("auto_map", {}):
+ feature_extractor_auto_map = config_dict["auto_map"]["AutoFeatureExtractor"]
- has_local_config = (
- os.path.exists(os.path.join(pretrained_model_name_or_path, CONFIG_NAME)) if is_directory else False
- )
-
- # load config, if it can be loaded
- if not is_feature_extraction_file and (has_local_config or not is_directory):
+ # If we don't find the feature extractor class in the feature extractor config, let's try the model config.
+ if feature_extractor_class is None and feature_extractor_auto_map is None:
if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            # It could be in `config.feature_extractor_type`
+ feature_extractor_class = getattr(config, "feature_extractor_type", None)
+ if hasattr(config, "auto_map") and "AutoFeatureExtractor" in config.auto_map:
+ feature_extractor_auto_map = config.auto_map["AutoFeatureExtractor"]
- kwargs["_from_auto"] = True
- config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
+ if feature_extractor_class is not None:
+ # If we have custom code for a feature extractor, we get the proper class.
+ if feature_extractor_auto_map is not None:
+ if not trust_remote_code:
+ raise ValueError(
+ f"Loading {pretrained_model_name_or_path} requires you to execute the feature extractor file "
+ "in that repo on your local machine. Make sure you have read the code there to avoid "
+ "malicious use, then set the option `trust_remote_code=True` to remove this error."
+ )
+ if kwargs.get("revision", None) is None:
+ logger.warning(
+ "Explicitly passing a `revision` is encouraged when loading a feature extractor with custom "
+ "code to ensure no malicious code has been contributed in a newer revision."
+ )
- model_type = config_class_to_model_type(type(config).__name__)
+ module_file, class_name = feature_extractor_auto_map.split(".")
+ feature_extractor_class = get_class_from_dynamic_module(
+ pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
+ )
+ else:
+ feature_extractor_class = feature_extractor_class_from_name(feature_extractor_class)
- if "feature_extractor_type" in config_dict:
- feature_extractor_class = feature_extractor_class_from_name(config_dict["feature_extractor_type"])
return feature_extractor_class.from_dict(config_dict, **kwargs)
- elif model_type is not None:
- return FEATURE_EXTRACTOR_MAPPING[type(config)].from_dict(config_dict, **kwargs)
+ # Last try: we use the FEATURE_EXTRACTOR_MAPPING.
+ elif type(config) in FEATURE_EXTRACTOR_MAPPING:
+ feature_extractor_class = FEATURE_EXTRACTOR_MAPPING[type(config)]
+ return feature_extractor_class.from_dict(config_dict, **kwargs)
raise ValueError(
- f"Unrecognized feature extractor in {pretrained_model_name_or_path}. Should have a `feature_extractor_type` key in "
- f"its {FEATURE_EXTRACTOR_NAME}, or one of the following `model_type` keys in its {CONFIG_NAME}: "
- f"{', '.join(c for c in FEATURE_EXTRACTOR_MAPPING_NAMES.keys())}"
+ f"Unrecognized feature extractor in {pretrained_model_name_or_path}. Should have a "
+            f"`feature_extractor_type` key in its {FEATURE_EXTRACTOR_NAME} or {CONFIG_NAME}, or one of the following "
+            f"`model_type` keys in its {CONFIG_NAME}: {', '.join(c for c in FEATURE_EXTRACTOR_MAPPING_NAMES.keys())}"
)
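To summarize the rewritten lookup logic, here is a sketch of the three paths `from_pretrained` now covers; the repo ids are public checkpoints used purely as examples:

```python
from transformers import AutoFeatureExtractor
from transformers.models.auto.feature_extraction_auto import get_feature_extractor_config

# 1. `feature_extractor_type` is found in preprocessor_config.json.
extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")

# 2. An `auto_map` entry points at custom code; without trust_remote_code=True
#    the ValueError added above is raised instead of executing the code.
extractor = AutoFeatureExtractor.from_pretrained(
    "hf-internal-testing/test_dynamic_feature_extractor", trust_remote_code=True
)

# 3. Neither key is present: the model config's type is looked up in
#    FEATURE_EXTRACTOR_MAPPING. The raw preprocessor config can also be
#    fetched directly with the new helper; an empty dict means the repo has
#    no preprocessor_config.json.
config_dict = get_feature_extractor_config("facebook/wav2vec2-base-960h")
```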
diff --git a/tests/test_feature_extraction_auto.py b/tests/test_feature_extraction_auto.py
index c827b0a656..da5386bd50 100644
--- a/tests/test_feature_extraction_auto.py
+++ b/tests/test_feature_extraction_auto.py
@@ -82,3 +82,9 @@ class AutoFeatureExtractorTest(unittest.TestCase):
"hf-internal-testing/config-no-model does not appear to have a file named preprocessor_config.json.",
):
_ = AutoFeatureExtractor.from_pretrained("hf-internal-testing/config-no-model")
+
+ def test_from_pretrained_dynamic_feature_extractor(self):
+ model = AutoFeatureExtractor.from_pretrained(
+ "hf-internal-testing/test_dynamic_feature_extractor", trust_remote_code=True
+ )
+ self.assertEqual(model.__class__.__name__, "NewFeatureExtractor")
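The test opts into remote code without pinning a version, which triggers the new warning. In real use, pinning a commit follows the advice in that warning; a sketch (the revision value is a placeholder, not a real sha):

```python
from transformers import AutoFeatureExtractor

# Pin the exact commit whose code was reviewed, so later pushes to the repo
# cannot silently change what executes locally.
extractor = AutoFeatureExtractor.from_pretrained(
    "hf-internal-testing/test_dynamic_feature_extractor",
    trust_remote_code=True,
    revision="<commit-sha>",  # placeholder; use a real commit hash
)
```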
diff --git a/tests/test_feature_extraction_common.py b/tests/test_feature_extraction_common.py
index 217da135ca..931ee2444e 100644
--- a/tests/test_feature_extraction_common.py
+++ b/tests/test_feature_extraction_common.py
@@ -16,9 +16,21 @@
import json
import os
+import sys
import tempfile
+import unittest
+from pathlib import Path
+from huggingface_hub import Repository, delete_repo, login
+from requests.exceptions import HTTPError
+from transformers import AutoFeatureExtractor
from transformers.file_utils import is_torch_available, is_vision_available
+from transformers.testing_utils import PASS, USER, is_staging_test
+
+
+sys.path.append(str(Path(__file__).parent.parent / "utils"))
+
+from test_module.custom_feature_extraction import CustomFeatureExtractor # noqa E402
if is_torch_available():
@@ -29,6 +41,9 @@ if is_vision_available():
from PIL import Image
+SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
+
+
def prepare_image_inputs(feature_extract_tester, equal_resolution=False, numpify=False, torchify=False):
"""This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
or a list of PyTorch tensors if one specifies torchify=True.
@@ -99,3 +114,41 @@ class FeatureExtractionSavingTestMixin:
def test_init_without_params(self):
feat_extract = self.feature_extraction_class()
self.assertIsNotNone(feat_extract)
+
+
+@is_staging_test
+class ConfigPushToHubTester(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ cls._token = login(username=USER, password=PASS)
+
+ @classmethod
+ def tearDownClass(cls):
+ try:
+ delete_repo(token=cls._token, name="test-dynamic-feature-extractor")
+ except HTTPError:
+ pass
+
+ def test_push_to_hub_dynamic_feature_extractor(self):
+ CustomFeatureExtractor.register_for_auto_class()
+ feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
+
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-feature-extractor", use_auth_token=self._token)
+ feature_extractor.save_pretrained(tmp_dir)
+
+ # This has added the proper auto_map field to the config
+ self.assertDictEqual(
+ feature_extractor.auto_map,
+ {"AutoFeatureExtractor": "custom_feature_extraction.CustomFeatureExtractor"},
+ )
+ # The code has been copied from fixtures
+ self.assertTrue(os.path.isfile(os.path.join(tmp_dir, "custom_feature_extraction.py")))
+
+ repo.push_to_hub()
+
+ new_feature_extractor = AutoFeatureExtractor.from_pretrained(
+ f"{USER}/test-dynamic-feature-extractor", trust_remote_code=True
+ )
+ # Can't make an isinstance check because the new_feature_extractor is from the CustomFeatureExtractor class of a dynamic module
+ self.assertEqual(new_feature_extractor.__class__.__name__, "CustomFeatureExtractor")
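Spelled out, the `auto_map` assertion above corresponds to the following shape in the serialized `preprocessor_config.json` (a sketch; `tmp_dir` stands for the temporary clone used in the test, and the other keys depend on the extractor's attributes):

```python
import json

# Inspect the config written by save_pretrained() in the test above.
with open("tmp_dir/preprocessor_config.json", encoding="utf-8") as f:
    config = json.load(f)

# The module path is relative to the repo root, so AutoFeatureExtractor can
# download custom_feature_extraction.py and import the class from it.
assert config["auto_map"] == {
    "AutoFeatureExtractor": "custom_feature_extraction.CustomFeatureExtractor"
}
```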
diff --git a/utils/test_module/custom_feature_extraction.py b/utils/test_module/custom_feature_extraction.py
new file mode 100644
index 0000000000..de367032d8
--- /dev/null
+++ b/utils/test_module/custom_feature_extraction.py
@@ -0,0 +1,5 @@
+from transformers import Wav2Vec2FeatureExtractor
+
+
+class CustomFeatureExtractor(Wav2Vec2FeatureExtractor):
+ pass
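The fixture is deliberately trivial; it only needs a class the auto machinery cannot already resolve. A custom extractor with real behavior would typically override methods of the parent, for example (illustrative only, assuming the default list output, i.e. no `return_tensors`):

```python
from transformers import Wav2Vec2FeatureExtractor


class ScalingFeatureExtractor(Wav2Vec2FeatureExtractor):
    """Illustrative variant: rescales features after extraction."""

    def __call__(self, raw_speech, **kwargs):
        # Reuse the parent's extraction, then halve every input value.
        features = super().__call__(raw_speech, **kwargs)
        features["input_values"] = [x * 0.5 for x in features["input_values"]]
        return features
```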