Custom feature extractor (#15630)
* Rework AutoFeatureExtractor.from_pretrained internals * Custom feature extractor * Add more tests * Add support for custom feature extractor code * Clean up
This commit is contained in:
parent
fcb0f74397
commit
7a32e4722f
|
@ -26,6 +26,7 @@ import numpy as np
|
|||
|
||||
from requests import HTTPError
|
||||
|
||||
from .dynamic_module_utils import custom_object_save
|
||||
from .file_utils import (
|
||||
FEATURE_EXTRACTOR_NAME,
|
||||
EntryNotFoundError,
|
||||
|
@ -205,6 +206,8 @@ class FeatureExtractionMixin:
|
|||
extractors.
|
||||
"""
|
||||
|
||||
_auto_class = None
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Set elements of `kwargs` as attributes."""
|
||||
# Pop "processor_class" as it should be saved as private attribute
|
||||
|
@ -316,6 +319,12 @@ class FeatureExtractionMixin:
|
|||
"""
|
||||
if os.path.isfile(save_directory):
|
||||
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||
|
||||
# If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
|
||||
# loaded from the Hub.
|
||||
if self._auto_class is not None:
|
||||
custom_object_save(self, save_directory, config=self)
|
||||
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
# If we save using the predefined names, we can load using `from_pretrained`
|
||||
output_feature_extractor_file = os.path.join(save_directory, FEATURE_EXTRACTOR_NAME)
|
||||
|
@ -539,3 +548,29 @@ class FeatureExtractionMixin:
|
|||
|
||||
def __repr__(self):
    """Readable representation: the class name followed by the JSON-serialized configuration."""
    return " ".join([self.__class__.__name__, self.to_json_string()])
|
||||
|
||||
@classmethod
def register_for_auto_class(cls, auto_class="AutoFeatureExtractor"):
    """
    Register this class with a given auto class. Only needed for custom feature extractors, since the ones
    shipped with the library are already mapped with `AutoFeatureExtractor`.

    <Tip warning={true}>

    This API is experimental and may have some slight breaking changes in the next releases.

    </Tip>

    Args:
        auto_class (`str` or `type`, *optional*, defaults to `"AutoFeatureExtractor"`):
            The auto class to register this new feature extractor with.
    """
    # Accept either the auto class itself or its name; normalize to the string name.
    auto_class = auto_class if isinstance(auto_class, str) else auto_class.__name__

    # Imported here rather than at module top — presumably to avoid a circular import; confirm before moving.
    import transformers.models.auto as auto_module

    if not hasattr(auto_module, auto_class):
        raise ValueError(f"{auto_class} is not a valid auto class.")

    cls._auto_class = auto_class
|
||||
|
|
|
@ -14,23 +14,28 @@
|
|||
# limitations under the License.
|
||||
""" AutoFeatureExtractor class."""
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
# Build the list of all feature extractors
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...dynamic_module_utils import get_class_from_dynamic_module
|
||||
from ...feature_extraction_utils import FeatureExtractionMixin
|
||||
from ...file_utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME
|
||||
from ...file_utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME, get_file_from_repo
|
||||
from ...utils import logging
|
||||
from .auto_factory import _LazyAutoMapping
|
||||
from .configuration_auto import (
|
||||
CONFIG_MAPPING_NAMES,
|
||||
AutoConfig,
|
||||
config_class_to_model_type,
|
||||
model_type_to_module_name,
|
||||
replace_list_option_in_docstrings,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
("beit", "BeitFeatureExtractor"),
|
||||
|
@ -66,6 +71,96 @@ def feature_extractor_class_from_name(class_name: str):
|
|||
return None
|
||||
|
||||
|
||||
def get_feature_extractor_config(
    pretrained_model_name_or_path: Union[str, os.PathLike],
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    **kwargs,
):
    """
    Loads the feature extractor configuration from a pretrained model feature extractor configuration.

    Args:
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            This can be either:

            - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
              huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or
              namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`.
            - a path to a *directory* containing a configuration file saved using the
              [`~feature_extraction_utils.FeatureExtractionMixin.save_pretrained`] method, e.g.,
              `./my_model_directory/`.

        cache_dir (`str` or `os.PathLike`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
            cache should not be used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force to (re-)download the configuration files and override the cached versions if they
            exist.
        resume_download (`bool`, *optional*, defaults to `False`):
            Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
        use_auth_token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `transformers-cli login` (stored in `~/.huggingface`).
        revision(`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, will only try to load the feature extractor configuration from local files.

    <Tip>

    Passing `use_auth_token=True` is required when you want to use a private model.

    </Tip>

    Returns:
        `Dict`: The configuration of the feature extractor.

    Examples:

    ```python
    # Download configuration from huggingface.co and cache.
    feature_extractor_config = get_feature_extractor_config("facebook/wav2vec2-base-960h")
    # This model does not have a feature extractor config so the result will be an empty dict.
    feature_extractor_config = get_feature_extractor_config("xlm-roberta-base")

    # Save a pretrained feature extractor locally and you can reload its config
    from transformers import AutoFeatureExtractor

    feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
    feature_extractor.save_pretrained("feature-extractor-test")
    feature_extractor_config = get_feature_extractor_config("feature-extractor-test")
    ```"""
    # Resolve the preprocessor_config.json file from a local directory or the Hub.
    resolved_config_file = get_file_from_repo(
        pretrained_model_name_or_path,
        FEATURE_EXTRACTOR_NAME,
        cache_dir=cache_dir,
        force_download=force_download,
        resume_download=resume_download,
        proxies=proxies,
        use_auth_token=use_auth_token,
        revision=revision,
        local_files_only=local_files_only,
    )
    # Missing file is not an error: callers fall back to the model config, so return an empty dict.
    if resolved_config_file is None:
        logger.info(
            "Could not locate the feature extractor configuration file, will try to use the model config instead."
        )
        return {}

    with open(resolved_config_file, encoding="utf-8") as reader:
        return json.load(reader)
|
||||
|
||||
|
||||
class AutoFeatureExtractor:
|
||||
r"""
|
||||
This is a generic feature extractor class that will be instantiated as one of the feature extractor classes of the
|
||||
|
@ -128,6 +223,10 @@ class AutoFeatureExtractor:
|
|||
functions returns a `Tuple(feature_extractor, unused_kwargs)` where *unused_kwargs* is a dictionary
|
||||
consisting of the key/value pairs whose keys are not feature extractor attributes: i.e., the part of
|
||||
`kwargs` which has not been used to update `feature_extractor` and is otherwise ignored.
|
||||
trust_remote_code (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
|
||||
should only be set to `True` for repositories you trust and in which you have read the code, as it will
|
||||
execute code present on the Hub on your local machine.
|
||||
kwargs (`Dict[str, Any]`, *optional*):
|
||||
The values in kwargs of any keys which are feature extractor attributes will be used to override the
|
||||
loaded values. Behavior concerning key/value pairs whose keys are *not* feature extractor attributes is
|
||||
|
@ -151,35 +250,54 @@ class AutoFeatureExtractor:
|
|||
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("./test/saved_model/")
|
||||
```"""
|
||||
config = kwargs.pop("config", None)
|
||||
trust_remote_code = kwargs.pop("trust_remote_code", False)
|
||||
kwargs["_from_auto"] = True
|
||||
|
||||
is_feature_extraction_file = os.path.isfile(pretrained_model_name_or_path)
|
||||
is_directory = os.path.isdir(pretrained_model_name_or_path) and os.path.exists(
|
||||
os.path.join(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME)
|
||||
)
|
||||
config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
|
||||
feature_extractor_class = config_dict.get("feature_extractor_type", None)
|
||||
feature_extractor_auto_map = None
|
||||
if "AutoFeatureExtractor" in config_dict.get("auto_map", {}):
|
||||
feature_extractor_auto_map = config_dict["auto_map"]["AutoFeatureExtractor"]
|
||||
|
||||
has_local_config = (
|
||||
os.path.exists(os.path.join(pretrained_model_name_or_path, CONFIG_NAME)) if is_directory else False
|
||||
)
|
||||
|
||||
# load config, if it can be loaded
|
||||
if not is_feature_extraction_file and (has_local_config or not is_directory):
|
||||
# If we don't find the feature extractor class in the feature extractor config, let's try the model config.
|
||||
if feature_extractor_class is None and feature_extractor_auto_map is None:
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
# It could be in `config.feature_extractor_type``
|
||||
feature_extractor_class = getattr(config, "feature_extractor_type", None)
|
||||
if hasattr(config, "auto_map") and "AutoFeatureExtractor" in config.auto_map:
|
||||
feature_extractor_auto_map = config.auto_map["AutoFeatureExtractor"]
|
||||
|
||||
kwargs["_from_auto"] = True
|
||||
config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
|
||||
if feature_extractor_class is not None:
|
||||
# If we have custom code for a feature extractor, we get the proper class.
|
||||
if feature_extractor_auto_map is not None:
|
||||
if not trust_remote_code:
|
||||
raise ValueError(
|
||||
f"Loading {pretrained_model_name_or_path} requires you to execute the feature extractor file "
|
||||
"in that repo on your local machine. Make sure you have read the code there to avoid "
|
||||
"malicious use, then set the option `trust_remote_code=True` to remove this error."
|
||||
)
|
||||
if kwargs.get("revision", None) is None:
|
||||
logger.warning(
|
||||
"Explicitly passing a `revision` is encouraged when loading a feature extractor with custom "
|
||||
"code to ensure no malicious code has been contributed in a newer revision."
|
||||
)
|
||||
|
||||
model_type = config_class_to_model_type(type(config).__name__)
|
||||
module_file, class_name = feature_extractor_auto_map.split(".")
|
||||
feature_extractor_class = get_class_from_dynamic_module(
|
||||
pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
|
||||
)
|
||||
else:
|
||||
feature_extractor_class = feature_extractor_class_from_name(feature_extractor_class)
|
||||
|
||||
if "feature_extractor_type" in config_dict:
|
||||
feature_extractor_class = feature_extractor_class_from_name(config_dict["feature_extractor_type"])
|
||||
return feature_extractor_class.from_dict(config_dict, **kwargs)
|
||||
elif model_type is not None:
|
||||
return FEATURE_EXTRACTOR_MAPPING[type(config)].from_dict(config_dict, **kwargs)
|
||||
# Last try: we use the FEATURE_EXTRACTOR_MAPPING.
|
||||
elif type(config) in FEATURE_EXTRACTOR_MAPPING:
|
||||
feature_extractor_class = FEATURE_EXTRACTOR_MAPPING[type(config)]
|
||||
return feature_extractor_class.from_dict(config_dict, **kwargs)
|
||||
|
||||
raise ValueError(
|
||||
f"Unrecognized feature extractor in {pretrained_model_name_or_path}. Should have a `feature_extractor_type` key in "
|
||||
f"its {FEATURE_EXTRACTOR_NAME}, or one of the following `model_type` keys in its {CONFIG_NAME}: "
|
||||
f"{', '.join(c for c in FEATURE_EXTRACTOR_MAPPING_NAMES.keys())}"
|
||||
f"Unrecognized feature extractor in {pretrained_model_name_or_path}. Should have a "
|
||||
f"`feature_extractor_type` key in its {FEATURE_EXTRACTOR_NAME} of {CONFIG_NAME}, or one of the following "
|
||||
"`model_type` keys in its {CONFIG_NAME}: {', '.join(c for c in FEATURE_EXTRACTOR_MAPPING_NAMES.keys())}"
|
||||
)
|
||||
|
|
|
@ -82,3 +82,9 @@ class AutoFeatureExtractorTest(unittest.TestCase):
|
|||
"hf-internal-testing/config-no-model does not appear to have a file named preprocessor_config.json.",
|
||||
):
|
||||
_ = AutoFeatureExtractor.from_pretrained("hf-internal-testing/config-no-model")
|
||||
|
||||
def test_from_pretrained_dynamic_feature_extractor(self):
    # Loading a feature extractor whose class is defined in the Hub repo requires opting in
    # with `trust_remote_code=True`.
    feature_extractor = AutoFeatureExtractor.from_pretrained(
        "hf-internal-testing/test_dynamic_feature_extractor", trust_remote_code=True
    )
    self.assertEqual(feature_extractor.__class__.__name__, "NewFeatureExtractor")
|
||||
|
|
|
@ -16,9 +16,21 @@
|
|||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub import Repository, delete_repo, login
|
||||
from requests.exceptions import HTTPError
|
||||
from transformers import AutoFeatureExtractor
|
||||
from transformers.file_utils import is_torch_available, is_vision_available
|
||||
from transformers.testing_utils import PASS, USER, is_staging_test
|
||||
|
||||
|
||||
sys.path.append(str(Path(__file__).parent.parent / "utils"))
|
||||
|
||||
from test_module.custom_feature_extraction import CustomFeatureExtractor # noqa E402
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
|
@ -29,6 +41,9 @@ if is_vision_available():
|
|||
from PIL import Image
|
||||
|
||||
|
||||
SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
|
||||
|
||||
|
||||
def prepare_image_inputs(feature_extract_tester, equal_resolution=False, numpify=False, torchify=False):
|
||||
"""This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
|
||||
or a list of PyTorch tensors if one specifies torchify=True.
|
||||
|
@ -99,3 +114,41 @@ class FeatureExtractionSavingTestMixin:
|
|||
def test_init_without_params(self):
|
||||
feat_extract = self.feature_extraction_class()
|
||||
self.assertIsNotNone(feat_extract)
|
||||
|
||||
|
||||
@is_staging_test
class ConfigPushToHubTester(unittest.TestCase):
    """Staging-only tests for pushing a custom (dynamic) feature extractor to the Hub."""

    @classmethod
    def setUpClass(cls):
        # Log in once for the whole class; the token is reused by every test.
        cls._token = login(username=USER, password=PASS)

    @classmethod
    def tearDownClass(cls):
        # Best-effort cleanup of the repo created during the test below.
        try:
            delete_repo(token=cls._token, name="test-dynamic-feature-extractor")
        except HTTPError:
            pass

    def test_push_to_hub_dynamic_feature_extractor(self):
        CustomFeatureExtractor.register_for_auto_class()
        feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)

        with tempfile.TemporaryDirectory() as work_dir:
            hub_repo = Repository(
                work_dir, clone_from=f"{USER}/test-dynamic-feature-extractor", use_auth_token=self._token
            )
            feature_extractor.save_pretrained(work_dir)

            # save_pretrained must have recorded the auto_map entry pointing at the custom class ...
            self.assertDictEqual(
                feature_extractor.auto_map,
                {"AutoFeatureExtractor": "custom_feature_extraction.CustomFeatureExtractor"},
            )
            # ... and copied the defining module next to the config.
            self.assertTrue(os.path.isfile(os.path.join(work_dir, "custom_feature_extraction.py")))

            hub_repo.push_to_hub()

        new_feature_extractor = AutoFeatureExtractor.from_pretrained(
            f"{USER}/test-dynamic-feature-extractor", trust_remote_code=True
        )
        # Can't use an isinstance check: the reloaded class comes from a dynamic module, so it is a
        # distinct class object with the same name.
        self.assertEqual(new_feature_extractor.__class__.__name__, "CustomFeatureExtractor")
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
from transformers import Wav2Vec2FeatureExtractor
|
||||
|
||||
|
||||
class CustomFeatureExtractor(Wav2Vec2FeatureExtractor):
    """Minimal test fixture: a feature extractor defined outside the library.

    Inherits all behavior from `Wav2Vec2FeatureExtractor` unchanged; only the class
    (and the module it lives in) differ, which is what the dynamic-module /
    auto-class registration tests exercise.
    """

    pass
|
Loading…
Reference in New Issue