Custom feature extractor (#15630)

* Rework `AutoFeatureExtractor.from_pretrained` internals

* Custom feature extractor

* Add more tests

* Add support for custom feature extractor code

* Clean up
Sylvain Gugger 2022-02-11 16:43:54 -05:00 committed by GitHub
parent fcb0f74397
commit 7a32e4722f
5 changed files with 239 additions and 22 deletions
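
Taken together, the diff below lets users ship a custom feature extractor class alongside a model repo. Here is a minimal end-to-end sketch of the workflow this commit enables (class and folder names are placeholders, not part of the diff, and the class should live in an importable module for the file-copy step to work):

```python
# Illustrative sketch of the new custom feature extractor workflow.
from transformers import AutoFeatureExtractor, Wav2Vec2FeatureExtractor


class MyFeatureExtractor(Wav2Vec2FeatureExtractor):
    """A custom feature extractor, ideally defined in its own module."""


# New in this commit: register the class so AutoFeatureExtractor can find it.
MyFeatureExtractor.register_for_auto_class()

extractor = MyFeatureExtractor()  # Wav2Vec2 defaults: feature_size=1, sampling_rate=16000
# save_pretrained() now copies the defining module into the save folder and
# writes an `auto_map` entry into preprocessor_config.json.
extractor.save_pretrained("my-custom-extractor")

# Reloading code that is not part of the library requires an explicit opt-in.
reloaded = AutoFeatureExtractor.from_pretrained("my-custom-extractor", trust_remote_code=True)
```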

src/transformers/feature_extraction_utils.py

@@ -26,6 +26,7 @@ import numpy as np
from requests import HTTPError

from .dynamic_module_utils import custom_object_save
from .file_utils import (
    FEATURE_EXTRACTOR_NAME,
    EntryNotFoundError,
@@ -205,6 +206,8 @@ class FeatureExtractionMixin:
    extractors.
    """

    _auto_class = None

    def __init__(self, **kwargs):
        """Set elements of `kwargs` as attributes."""
        # Pop "processor_class" as it should be saved as private attribute
@@ -316,6 +319,12 @@ class FeatureExtractionMixin:
        """
        if os.path.isfile(save_directory):
            raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")

        # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
        # loaded from the Hub.
        if self._auto_class is not None:
            custom_object_save(self, save_directory, config=self)

        os.makedirs(save_directory, exist_ok=True)
        # If we save using the predefined names, we can load using `from_pretrained`
        output_feature_extractor_file = os.path.join(save_directory, FEATURE_EXTRACTOR_NAME)
@@ -539,3 +548,29 @@ class FeatureExtractionMixin:

    def __repr__(self):
        return f"{self.__class__.__name__} {self.to_json_string()}"

    @classmethod
    def register_for_auto_class(cls, auto_class="AutoFeatureExtractor"):
        """
        Register this class with a given auto class. This should only be used for custom feature extractors as the
        ones in the library are already mapped with `AutoFeatureExtractor`.

        <Tip warning={true}>

        This API is experimental and may have some slight breaking changes in the next releases.

        </Tip>

        Args:
            auto_class (`str` or `type`, *optional*, defaults to `"AutoFeatureExtractor"`):
                The auto class to register this new feature extractor with.
        """
        if not isinstance(auto_class, str):
            auto_class = auto_class.__name__

        import transformers.models.auto as auto_module

        if not hasattr(auto_module, auto_class):
            raise ValueError(f"{auto_class} is not a valid auto class.")

        cls._auto_class = auto_class
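
For reference, a minimal usage sketch of this new classmethod (assuming a `custom_feature_extraction.py` module that defines `CustomFeatureExtractor` as a subclass of a library feature extractor, mirroring the test fixture added at the end of this commit):

```python
# Hypothetical module layout, mirroring the test fixture in this commit.
from custom_feature_extraction import CustomFeatureExtractor

# Register the custom class; the auto_class argument defaults to "AutoFeatureExtractor".
CustomFeatureExtractor.register_for_auto_class()

# From now on, save_pretrained() records an `auto_map` entry in
# preprocessor_config.json and copies the defining file into the save folder.
```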

src/transformers/models/auto/feature_extraction_auto.py

@@ -14,23 +14,28 @@
# limitations under the License.
""" AutoFeatureExtractor class."""
import importlib
import json
import os
from collections import OrderedDict
from typing import Dict, Optional, Union

# Build the list of all feature extractors
from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module
from ...feature_extraction_utils import FeatureExtractionMixin
from ...file_utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME, get_file_from_repo
from ...utils import logging
from .auto_factory import _LazyAutoMapping
from .configuration_auto import (
    CONFIG_MAPPING_NAMES,
    AutoConfig,
    model_type_to_module_name,
    replace_list_option_in_docstrings,
)

logger = logging.get_logger(__name__)

FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
    [
        ("beit", "BeitFeatureExtractor"),

@@ -66,6 +71,96 @@ def feature_extractor_class_from_name(class_name: str):
    return None
def get_feature_extractor_config(
    pretrained_model_name_or_path: Union[str, os.PathLike],
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    **kwargs,
):
    """
    Loads the feature extractor configuration from a pretrained model feature extractor configuration.

    Args:
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            This can be either:

            - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
              huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or
              namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`.
            - a path to a *directory* containing a configuration file saved using the
              [`~feature_extraction_utils.FeatureExtractionMixin.save_pretrained`] method, e.g.,
              `./my_model_directory/`.

        cache_dir (`str` or `os.PathLike`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the
            standard cache should not be used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force to (re-)download the configuration files and override the cached versions if
            they exist.
        resume_download (`bool`, *optional*, defaults to `False`):
            Whether or not to delete an incompletely received file. Attempts to resume the download if such a file
            exists.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
        use_auth_token (`str` or `bool`, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `transformers-cli login` (stored in `~/.huggingface`).
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, will only try to load the feature extractor configuration from local files.

    <Tip>

    Passing `use_auth_token=True` is required when you want to use a private model.

    </Tip>

    Returns:
        `Dict`: The configuration of the feature extractor.

    Examples:

    ```python
    # Download configuration from huggingface.co and cache.
    feature_extractor_config = get_feature_extractor_config("facebook/wav2vec2-base-960h")
    # This model does not have a feature extractor config so the result will be an empty dict.
    feature_extractor_config = get_feature_extractor_config("xlm-roberta-base")

    # Save a pretrained feature extractor locally and you can reload its config
    from transformers import AutoFeatureExtractor

    feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
    feature_extractor.save_pretrained("feature-extractor-test")
    feature_extractor_config = get_feature_extractor_config("feature-extractor-test")
    ```"""
    resolved_config_file = get_file_from_repo(
        pretrained_model_name_or_path,
        FEATURE_EXTRACTOR_NAME,
        cache_dir=cache_dir,
        force_download=force_download,
        resume_download=resume_download,
        proxies=proxies,
        use_auth_token=use_auth_token,
        revision=revision,
        local_files_only=local_files_only,
    )
    if resolved_config_file is None:
        logger.info(
            "Could not locate the feature extractor configuration file, will try to use the model config instead."
        )
        return {}

    with open(resolved_config_file, encoding="utf-8") as reader:
        return json.load(reader)
class AutoFeatureExtractor:
    r"""
    This is a generic feature extractor class that will be instantiated as one of the feature extractor classes of the

@@ -128,6 +223,10 @@ class AutoFeatureExtractor:
                functions returns a `Tuple(feature_extractor, unused_kwargs)` where *unused_kwargs* is a dictionary
                consisting of the key/value pairs whose keys are not feature extractor attributes: i.e., the part of
                `kwargs` which has not been used to update `feature_extractor` and is otherwise ignored.
            trust_remote_code (`bool`, *optional*, defaults to `False`):
                Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
                should only be set to `True` for repositories you trust and in which you have read the code, as it
                will execute code present on the Hub on your local machine.
            kwargs (`Dict[str, Any]`, *optional*):
                The values in kwargs of any keys which are feature extractor attributes will be used to override the
                loaded values. Behavior concerning key/value pairs whose keys are *not* feature extractor attributes is
@@ -151,35 +250,54 @@ class AutoFeatureExtractor:

        >>> feature_extractor = AutoFeatureExtractor.from_pretrained("./test/saved_model/")
        ```"""
        config = kwargs.pop("config", None)
        trust_remote_code = kwargs.pop("trust_remote_code", False)
        kwargs["_from_auto"] = True

        config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
        feature_extractor_class = config_dict.get("feature_extractor_type", None)
        feature_extractor_auto_map = None
        if "AutoFeatureExtractor" in config_dict.get("auto_map", {}):
            feature_extractor_auto_map = config_dict["auto_map"]["AutoFeatureExtractor"]

        # If we don't find the feature extractor class in the feature extractor config, let's try the model config.
        if feature_extractor_class is None and feature_extractor_auto_map is None:
            if not isinstance(config, PretrainedConfig):
                config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
            # It could be in `config.feature_extractor_type`
            feature_extractor_class = getattr(config, "feature_extractor_type", None)
            if hasattr(config, "auto_map") and "AutoFeatureExtractor" in config.auto_map:
                feature_extractor_auto_map = config.auto_map["AutoFeatureExtractor"]

        if feature_extractor_class is not None:
            # If we have custom code for a feature extractor, we get the proper class.
            if feature_extractor_auto_map is not None:
                if not trust_remote_code:
                    raise ValueError(
                        f"Loading {pretrained_model_name_or_path} requires you to execute the feature extractor file "
                        "in that repo on your local machine. Make sure you have read the code there to avoid "
                        "malicious use, then set the option `trust_remote_code=True` to remove this error."
                    )
                if kwargs.get("revision", None) is None:
                    logger.warning(
                        "Explicitly passing a `revision` is encouraged when loading a feature extractor with custom "
                        "code to ensure no malicious code has been contributed in a newer revision."
                    )

                module_file, class_name = feature_extractor_auto_map.split(".")
                feature_extractor_class = get_class_from_dynamic_module(
                    pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
                )
            else:
                feature_extractor_class = feature_extractor_class_from_name(feature_extractor_class)

            return feature_extractor_class.from_dict(config_dict, **kwargs)
        # Last try: we use the FEATURE_EXTRACTOR_MAPPING.
        elif type(config) in FEATURE_EXTRACTOR_MAPPING:
            feature_extractor_class = FEATURE_EXTRACTOR_MAPPING[type(config)]
            return feature_extractor_class.from_dict(config_dict, **kwargs)

        raise ValueError(
            f"Unrecognized feature extractor in {pretrained_model_name_or_path}. Should have a "
            f"`feature_extractor_type` key in its {FEATURE_EXTRACTOR_NAME} or {CONFIG_NAME}, or one of the following "
            f"`model_type` keys in its {CONFIG_NAME}: {', '.join(c for c in FEATURE_EXTRACTOR_MAPPING_NAMES.keys())}"
        )
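
The resolution order implemented above is worth spelling out: `feature_extractor_type` in `preprocessor_config.json` first, then the model config (a `feature_extractor_type` attribute or an `auto_map` entry), then `FEATURE_EXTRACTOR_MAPPING` keyed by config class. A hedged loading example for the custom-code path (the repo id is hypothetical):

```python
from transformers import AutoFeatureExtractor

# Without trust_remote_code=True, the ValueError added above is raised for any
# repo whose preprocessor_config.json carries an AutoFeatureExtractor entry in
# auto_map.
extractor = AutoFeatureExtractor.from_pretrained(
    "some-user/model-with-custom-extractor",  # hypothetical repo id
    trust_remote_code=True,
    revision="main",  # ideally pin a commit sha; the code warns when revision is None
)
```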

tests/test_feature_extraction_auto.py

@@ -82,3 +82,9 @@ class AutoFeatureExtractorTest(unittest.TestCase):
            "hf-internal-testing/config-no-model does not appear to have a file named preprocessor_config.json.",
        ):
            _ = AutoFeatureExtractor.from_pretrained("hf-internal-testing/config-no-model")

    def test_from_pretrained_dynamic_feature_extractor(self):
        model = AutoFeatureExtractor.from_pretrained(
            "hf-internal-testing/test_dynamic_feature_extractor", trust_remote_code=True
        )
        self.assertEqual(model.__class__.__name__, "NewFeatureExtractor")

tests/test_feature_extraction_common.py

@@ -16,9 +16,21 @@
import json
import os
import sys
import tempfile
import unittest
from pathlib import Path

from huggingface_hub import Repository, delete_repo, login
from requests.exceptions import HTTPError

from transformers import AutoFeatureExtractor
from transformers.file_utils import is_torch_available, is_vision_available
from transformers.testing_utils import PASS, USER, is_staging_test

sys.path.append(str(Path(__file__).parent.parent / "utils"))

from test_module.custom_feature_extraction import CustomFeatureExtractor  # noqa E402

if is_torch_available():

@@ -29,6 +41,9 @@ if is_vision_available():
    from PIL import Image

SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")


def prepare_image_inputs(feature_extract_tester, equal_resolution=False, numpify=False, torchify=False):
    """This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
    or a list of PyTorch tensors if one specifies torchify=True.

@@ -99,3 +114,41 @@ class FeatureExtractionSavingTestMixin:
    def test_init_without_params(self):
        feat_extract = self.feature_extraction_class()
        self.assertIsNotNone(feat_extract)
@is_staging_test
class ConfigPushToHubTester(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls._token = login(username=USER, password=PASS)

    @classmethod
    def tearDownClass(cls):
        try:
            delete_repo(token=cls._token, name="test-dynamic-feature-extractor")
        except HTTPError:
            pass

    def test_push_to_hub_dynamic_feature_extractor(self):
        CustomFeatureExtractor.register_for_auto_class()
        feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)

        with tempfile.TemporaryDirectory() as tmp_dir:
            repo = Repository(
                tmp_dir, clone_from=f"{USER}/test-dynamic-feature-extractor", use_auth_token=self._token
            )
            feature_extractor.save_pretrained(tmp_dir)

            # This has added the proper auto_map field to the config
            self.assertDictEqual(
                feature_extractor.auto_map,
                {"AutoFeatureExtractor": "custom_feature_extraction.CustomFeatureExtractor"},
            )
            # The code has been copied from fixtures
            self.assertTrue(os.path.isfile(os.path.join(tmp_dir, "custom_feature_extraction.py")))

            repo.push_to_hub()

        new_feature_extractor = AutoFeatureExtractor.from_pretrained(
            f"{USER}/test-dynamic-feature-extractor", trust_remote_code=True
        )
        # Can't make an isinstance check because the new_feature_extractor is from the CustomFeatureExtractor class
        # of a dynamic module
        self.assertEqual(new_feature_extractor.__class__.__name__, "CustomFeatureExtractor")

utils/test_module/custom_feature_extraction.py

@@ -0,0 +1,5 @@
from transformers import Wav2Vec2FeatureExtractor


class CustomFeatureExtractor(Wav2Vec2FeatureExtractor):
    pass