[FeatureExtractorSavingUtils] Refactor PretrainedFeatureExtractor (#10594)

* save first version

* finish refactor

* finish refactor

* correct naming

* correct naming

* shorter names

* Update src/transformers/feature_extraction_common_utils.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* change name

* finish

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
Patrick von Platen 2021-03-09 12:16:59 +03:00 committed by GitHub
parent b6a28e9ac9
commit 9a06b6b11b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 638 additions and 572 deletions

View File

@ -14,16 +14,24 @@
Feature Extractor
-----------------------------------------------------------------------------------------------------------------------
A feature extractor is in charge of preparing read-in audio files for a speech model. This includes feature extraction,
such as processing audio files to, *e.g.*, Log-Mel Spectrogram features, but also padding, normalization, and
conversion to Numpy, PyTorch, and TensorFlow tensors.
A feature extractor is in charge of preparing input features for a multi-modal model. This includes feature extraction
from sequences, *e.g.*, pre-processing audio files to Log-Mel Spectrogram features, feature extraction from images
*e.g.* cropping image image files, but also padding, normalization, and conversion to Numpy, PyTorch, and TensorFlow
tensors.
PreTrainedFeatureExtractor
FeatureExtractionMixin
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.PreTrainedFeatureExtractor
:members: from_pretrained, save_pretrained, pad
.. autoclass:: transformers.feature_extraction_utils.FeatureExtractionMixin
:members: from_pretrained, save_pretrained
SequenceFeatureExtractor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.SequenceFeatureExtractor
:members: pad
BatchFeature

View File

@ -246,7 +246,7 @@ _import_structure = {
"SpecialTokensMixin",
"TokenSpan",
],
"feature_extraction_utils": ["PreTrainedFeatureExtractor", "BatchFeature"],
"feature_extraction_sequence_utils": ["SequenceFeatureExtractor", "BatchFeature"],
"trainer_callback": [
"DefaultFlowCallback",
"EarlyStoppingCallback",
@ -1257,7 +1257,7 @@ if TYPE_CHECKING:
)
# Feature Extractor
from .feature_extraction_utils import BatchFeature, PreTrainedFeatureExtractor
from .feature_extraction_utils import BatchFeature, SequenceFeatureExtractor
# Files and general utilities
from .file_utils import (

View File

@ -0,0 +1,317 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Sequence feature extraction class for common feature extrcactors to preprocess sequences.
"""
from typing import Dict, List, Optional, Union
import numpy as np
from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin
from .file_utils import (
PaddingStrategy,
TensorType,
_is_tensorflow,
_is_torch,
is_tf_available,
is_torch_available,
to_py_obj,
)
from .utils import logging
logger = logging.get_logger(__name__)
class SequenceFeatureExtractor(FeatureExtractionMixin):
"""
This is a general feature extraction class for speech recognition.
Args:
feature_size (:obj:`int`):
The feature dimension of the extracted features.
sampling_rate (:obj:`int`):
The sampling rate at which the audio files should be digitalized expressed in Hertz per second (Hz).
padding_value (:obj:`float`):
The value that is used to fill the padding values / vectors.
"""
def __init__(self, feature_size: int, sampling_rate: int, padding_value: float, **kwargs):
self.feature_size = feature_size
self.sampling_rate = sampling_rate
self.padding_value = padding_value
self.padding_side = kwargs.pop("padding_side", "right")
self.return_attention_mask = kwargs.pop("return_attention_mask", True)
# Additional attributes without default values
for key, value in kwargs.items():
try:
setattr(self, key, value)
except AttributeError as err:
logger.error(f"Can't set {key} with value {value} for {self}")
raise err
def pad(
self,
processed_features: Union[
BatchFeature,
List[BatchFeature],
Dict[str, BatchFeature],
Dict[str, List[BatchFeature]],
List[Dict[str, BatchFeature]],
],
padding: Union[bool, str, PaddingStrategy] = True,
max_length: Optional[int] = None,
pad_to_multiple_of: Optional[int] = None,
return_attention_mask: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
) -> BatchFeature:
"""
Pad input values / input vectors or a batch of input values / input vectors up to predefined length or to the
max sequence length in the batch.
Padding side (left/right) padding values are defined at the feature extractor level (with
``self.padding_side``, ``self.padding_value``)
.. note::
If the ``processed_features`` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors,
the result will use the same type unless you provide a different tensor type with ``return_tensors``. In
the case of PyTorch tensors, you will lose the specific device of your tensors however.
Args:
processed_features (:class:`~transformers.BatchFeature`, list of :class:`~transformers.BatchFeature`, :obj:`Dict[str, List[float]]`, :obj:`Dict[str, List[List[float]]` or :obj:`List[Dict[str, List[float]]]`):
Processed inputs. Can represent one input (:class:`~transformers.BatchFeature` or :obj:`Dict[str,
List[float]]`) or a batch of input values / vectors (list of :class:`~transformers.BatchFeature`,
`Dict[str, List[List[float]]]` or `List[Dict[str, List[float]]]`) so you can use this method during
preprocessing as well as in a PyTorch Dataloader collate function.
Instead of :obj:`List[float]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow
tensors), see the note above for the return type.
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`True`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding
index) among:
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
single sequence if provided).
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
maximum acceptable input length for the model if that argument is not provided.
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
different lengths).
max_length (:obj:`int`, `optional`):
Maximum length of the returned list and optionally padding length (see above).
pad_to_multiple_of (:obj:`int`, `optional`):
If set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
>= 7.5 (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
return_attention_mask (:obj:`bool`, `optional`):
Whether to return the attention mask. If left to the default, will return the attention mask according
to the specific feature_extractor's default.
`What are attention masks? <../glossary.html#attention-mask>`__
return_tensors (:obj:`str` or :class:`~transformers.file_utils.TensorType`, `optional`):
If set, will return tensors instead of list of python integers. Acceptable values are:
* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
* :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
* :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
"""
# If we have a list of dicts, let's convert it in a dict of lists
# We do this to allow using this method as a collate_fn function in PyTorch Dataloader
if isinstance(processed_features, (list, tuple)) and isinstance(processed_features[0], (dict, BatchFeature)):
processed_features = {
key: [example[key] for example in processed_features] for key in processed_features[0].keys()
}
# The model's main input name, usually `input_values`, has be passed for padding
if self.model_input_names[0] not in processed_features:
raise ValueError(
"You should supply an instance of :class:`~transformers.BatchFeature` or list of :class:`~transformers.BatchFeature` to this method"
f"that includes {self.model_input_names[0]}, but you provided {list(processed_features.keys())}"
)
required_input = processed_features[self.model_input_names[0]]
return_attention_mask = (
return_attention_mask if return_attention_mask is not None else self.return_attention_mask
)
if not required_input:
if return_attention_mask:
processed_features["attention_mask"] = []
return processed_features
# If we have PyTorch/TF/NumPy tensors/arrays as inputs, we cast them as python objects
# and rebuild them afterwards if no return_tensors is specified
# Note that we lose the specific device the tensor may be on for PyTorch
first_element = required_input[0]
if isinstance(first_element, (list, tuple)):
# first_element might be an empty list/tuple in some edge cases so we grab the first non empty element.
index = 0
while len(required_input[index]) == 0:
index += 1
if index < len(required_input):
first_element = required_input[index][0]
# At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
if not isinstance(first_element, (float, int, list, tuple)):
if is_tf_available() and _is_tensorflow(first_element):
return_tensors = "tf" if return_tensors is None else return_tensors
elif is_torch_available() and _is_torch(first_element):
return_tensors = "pt" if return_tensors is None else return_tensors
elif isinstance(first_element, np.ndarray):
return_tensors = "np" if return_tensors is None else return_tensors
else:
raise ValueError(
f"type of {first_element} unknown: {type(first_element)}. "
f"Should be one of a python, numpy, pytorch or tensorflow object."
)
for key, value in processed_features.items():
processed_features[key] = to_py_obj(value)
# Convert padding_strategy in PaddingStrategy
padding_strategy, max_length, _ = self._get_padding_strategies(padding=padding, max_length=max_length)
required_input = processed_features[self.model_input_names[0]]
if required_input and not isinstance(required_input[0], (list, tuple)):
processed_features = self._pad(
processed_features,
max_length=max_length,
padding_strategy=padding_strategy,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
)
return BatchFeature(processed_features, tensor_type=return_tensors)
batch_size = len(required_input)
assert all(
len(v) == batch_size for v in processed_features.values()
), "Some items in the output dictionary have a different batch size than others."
if padding_strategy == PaddingStrategy.LONGEST:
max_length = max(len(inputs) for inputs in required_input)
padding_strategy = PaddingStrategy.MAX_LENGTH
batch_outputs = {}
for i in range(batch_size):
inputs = dict((k, v[i]) for k, v in processed_features.items())
outputs = self._pad(
inputs,
max_length=max_length,
padding_strategy=padding_strategy,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
)
for key, value in outputs.items():
if key not in batch_outputs:
batch_outputs[key] = []
batch_outputs[key].append(value)
return BatchFeature(batch_outputs, tensor_type=return_tensors)
def _pad(
self,
processed_features: Union[Dict[str, List[float]], BatchFeature],
max_length: Optional[int] = None,
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
pad_to_multiple_of: Optional[int] = None,
return_attention_mask: Optional[bool] = None,
) -> dict:
"""
Pad inputs (on left/right and up to predefined length or max length in the batch)
Args:
processed_features: Dictionary of input values (`List[float]`) / input vectors (`List[List[float]]`) or batch of inputs values (`List[List[int]]`) / input vectors (`List[List[List[int]]]`)
max_length: maximum length of the returned list and optionally padding length (see below)
padding_strategy: PaddingStrategy to use for padding.
- PaddingStrategy.LONGEST Pad to the longest sequence in the batch
- PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
- PaddingStrategy.DO_NOT_PAD: Do not pad
The feature_extractor padding sides are defined in self.padding_side:
- 'left': pads on the left of the sequences
- 'right': pads on the right of the sequences
pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
>= 7.5 (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics)
"""
required_input = processed_features[self.model_input_names[0]]
if padding_strategy == PaddingStrategy.LONGEST:
max_length = len(required_input)
if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
if needs_to_be_padded:
difference = max_length - len(required_input)
padding_vector = self.feature_size * [self.padding_value] if self.feature_size > 1 else self.padding_value
if self.padding_side == "right":
if return_attention_mask:
processed_features["attention_mask"] = [1] * len(required_input) + [0] * difference
processed_features[self.model_input_names[0]] = required_input + [
padding_vector for _ in range(difference)
]
elif self.padding_side == "left":
if return_attention_mask:
processed_features["attention_mask"] = [0] * difference + [1] * len(required_input)
processed_features[self.model_input_names[0]] = [
padding_vector for _ in range(difference)
] + required_input
else:
raise ValueError("Invalid padding strategy:" + str(self.padding_side))
elif return_attention_mask and "attention_mask" not in processed_features:
processed_features["attention_mask"] = [1] * len(required_input)
return processed_features
def _get_padding_strategies(self, padding=False, max_length=None, pad_to_multiple_of=None, **kwargs):
"""
Find the correct padding strategy
"""
# Get padding strategy
if padding is not False:
if padding is True:
padding_strategy = PaddingStrategy.LONGEST # Default to pad to the longest sequence in the batch
elif not isinstance(padding, PaddingStrategy):
padding_strategy = PaddingStrategy(padding)
elif isinstance(padding, PaddingStrategy):
padding_strategy = padding
else:
padding_strategy = PaddingStrategy.DO_NOT_PAD
# Set max length if needed
if max_length is None:
if padding_strategy == PaddingStrategy.MAX_LENGTH:
raise ValueError(
f"When setting ``padding={PaddingStrategy.MAX_LENGTH}``, make sure that" f" max_length is defined"
)
# Test if we have a padding value
if padding_strategy != PaddingStrategy.DO_NOT_PAD and (self.padding_value is None):
raise ValueError(
"Asking to pad but the feature_extractor does not have a padding value. "
"Please select a value to use as `padding_value`. For example: `feature_extractor.padding_value = 0.0`."
)
return padding_strategy, max_length, kwargs

View File

@ -13,24 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Feature extraction common class for python feature extractors.
Feature extraction saving/loading class for common feature extractors.
"""
import copy
import json
import os
from collections import UserDict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
import numpy as np
from .file_utils import (
FEATURE_EXTRACTOR_NAME,
PaddingStrategy,
TensorType,
_is_jax,
_is_numpy,
_is_tensorflow,
_is_torch,
_is_torch_device,
cached_path,
hf_bucket_url,
@ -39,23 +37,24 @@ from .file_utils import (
is_remote_url,
is_tf_available,
is_torch_available,
to_py_obj,
torch_required,
)
from .utils import logging
logger = logging.get_logger(__name__)
if TYPE_CHECKING:
if is_torch_available():
import torch
logger = logging.get_logger(__name__)
PreTrainedFeatureExtractor = Union["SequenceFeatureExtractor"] # noqa: F821
class BatchFeature(UserDict):
r"""
Holds the output of the :meth:`~transformers.PreTrainedFeatureExtractor.pad` and feature extractor specific
Holds the output of the :meth:`~transformers.SequenceFeatureExtractor.pad` and feature extractor specific
``__call__`` methods.
This class is derived from a python dictionary and can be used as a dictionary.
@ -179,8 +178,7 @@ class BatchFeature(UserDict):
device (:obj:`str` or :obj:`torch.device`): The device to put the tensors on.
Returns:
:class:`~transformers.BatchFeature`: The same instance of :class:`~transformers.BatchFeature` after
modification.
:class:`~transformers.BatchFeature`: The same instance after modification.
"""
# This check catches things like APEX blindly calling "to" on all inputs to a module
@ -193,42 +191,19 @@ class BatchFeature(UserDict):
return self
class PreTrainedFeatureExtractor:
class FeatureExtractionMixin:
"""
This is a general feature extraction class for speech recognition.
Args:
feature_size (:obj:`int`):
The feature dimension of the extracted features.
sampling_rate (:obj:`int`):
The sampling rate at which the audio files should be digitalized expressed in Hertz per second (Hz).
padding_value (:obj:`float`):
The value that is used to fill the padding values / vectors.
This is a feature extraction mixin used to provide saving/loading functionality for sequential and image feature
extractors.
"""
def __init__(self, feature_size: int, sampling_rate: int, padding_value: float, **kwargs):
self.feature_size = feature_size
self.sampling_rate = sampling_rate
self.padding_value = padding_value
self.padding_side = kwargs.pop("padding_side", "right")
self.return_attention_mask = kwargs.pop("return_attention_mask", True)
# Additional attributes without default values
for key, value in kwargs.items():
try:
setattr(self, key, value)
except AttributeError as err:
logger.error(f"Can't set {key} with value {value} for {self}")
raise err
@classmethod
def from_pretrained(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
) -> "PreTrainedFeatureExtractor":
) -> PreTrainedFeatureExtractor:
r"""
Instantiate a :class:`~transformers.PreTrainedFeatureExtractor` (or a derived class) from a pretrained feature
extractor.
Instantiate a type of :class:`~transformers.feature_extraction_utils.FeatureExtractionMixin` from a feature
extractor, *e.g.* a derived class of :class:`~transformers.SequenceFeatureExtractor`.
Args:
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
@ -238,7 +213,7 @@ class PreTrainedFeatureExtractor:
huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or
namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``.
- a path to a `directory` containing a feature extractor file saved using the
:func:`~transformers.PreTrainedFeatureExtractor.save_pretrained` method, e.g.,
:func:`~transformers.feature_extraction_utils.FeatureExtractionMixin.save_pretrained` method, e.g.,
``./my_model_directory/``.
- a path or url to a saved feature extractor JSON `file`, e.g.,
``./my_model_directory/feature_extraction_config.json``.
@ -262,12 +237,10 @@ class PreTrainedFeatureExtractor:
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
identifier allowed by git.
return_unused_kwargs (:obj:`bool`, `optional`, defaults to :obj:`False`):
If :obj:`False`, then this function returns just the final feature extractor object.
If :obj:`True`, then this functions returns a :obj:`Tuple(feature_extractor, unused_kwargs)` where
`unused_kwargs` is a dictionary consisting of the key/value pairs whose keys are not feature extractor
attributes: i.e., the part of ``kwargs`` which has not been used to update ``feature_extractor`` and is
otherwise ignored.
If :obj:`False`, then this function returns just the final feature extractor object. If :obj:`True`,
then this functions returns a :obj:`Tuple(feature_extractor, unused_kwargs)` where `unused_kwargs` is a
dictionary consisting of the key/value pairs whose keys are not feature extractor attributes: i.e., the
part of ``kwargs`` which has not been used to update ``feature_extractor`` and is otherwise ignored.
kwargs (:obj:`Dict[str, Any]`, `optional`):
The values in kwargs of any keys which are feature extractor attributes will be used to override the
loaded values. Behavior concerning key/value pairs whose keys are *not* feature extractor attributes is
@ -279,13 +252,12 @@ class PreTrainedFeatureExtractor:
Returns:
:class:`~transformers.PreTrainedFeatureExtractor`: The feature extractor object instantiated from this
pretrained model.
A feature extractor of type :class:`~transformers.feature_extraction_utils.FeatureExtractionMixin`.
Examples::
# We can't instantiate directly the base class `PreTrainedFeatureExtractor` so let's show the examples on a
# derived class: Wav2Vec2FeatureExtractor
# We can't instantiate directly the base class `FeatureExtractionMixin` nor `SequenceFeatureExtractor` so let's show the examples on a
# derived class: `Wav2Vec2FeatureExtractor`
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained('facebook/wav2vec2-base-960h') # Download feature_extraction_config from huggingface.co and cache.
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained('./test/saved_model/') # E.g. feature_extractor (or model) was saved using `save_pretrained('./test/saved_model/')`
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained('./test/saved_model/preprocessor_config.json')
@ -295,7 +267,6 @@ class PreTrainedFeatureExtractor:
foo=False, return_unused_kwargs=True)
assert feature_extractor.return_attention_mask is False
assert unused_kwargs == {'foo': False}
"""
feature_extractor_dict, kwargs = cls.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
@ -304,7 +275,7 @@ class PreTrainedFeatureExtractor:
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
"""
Save a feature_extractor object to the directory ``save_directory``, so that it can be re-loaded using the
:func:`~transformers.PreTrainedFeatureExtractor.from_pretrained` class method.
:func:`~transformers.feature_extraction_utils.FeatureExtractionMixin.from_pretrained` class method.
Args:
save_directory (:obj:`str` or :obj:`os.PathLike`):
@ -325,7 +296,8 @@ class PreTrainedFeatureExtractor:
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""
From a ``pretrained_model_name_or_path``, resolve to a dictionary of parameters, to be used for instantiating a
:class:`~transformers.PreTrainedFeatureExtractor` using ``from_dict``.
feature extractor of type :class:`~transformers.feature_extraction_utils.FeatureExtractionMixin` using
``from_dict``.
Parameters:
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
@ -400,21 +372,22 @@ class PreTrainedFeatureExtractor:
return feature_extractor_dict, kwargs
@classmethod
def from_dict(cls, feature_extractor_dict: Dict[str, Any], **kwargs) -> "PreTrainedFeatureExtractor":
def from_dict(cls, feature_extractor_dict: Dict[str, Any], **kwargs) -> PreTrainedFeatureExtractor:
"""
Instantiates a :class:`~transformers.PreTrainedFeatureExtractor` from a Python dictionary of parameters.
Instantiates a type of :class:`~transformers.feature_extraction_utils.FeatureExtractionMixin` from a Python
dictionary of parameters.
Args:
feature_extractor_dict (:obj:`Dict[str, Any]`):
Dictionary that will be used to instantiate the feature extractor object. Such a dictionary can be
retrieved from a pretrained checkpoint by leveraging the
:func:`~transformers.PreTrainedFeatureExtractor.to_dict` method.
:func:`~transformers.feature_extraction_utils.FeatureExtractionMixin.to_dict` method.
kwargs (:obj:`Dict[str, Any]`):
Additional parameters from which to initialize the feature extractor object.
Returns:
:class:`~transformers.PreTrainedFeatureExtractor`: The feature extractor object instantiated from those
parameters.
:class:`~transformers.feature_extraction_utils.FeatureExtractionMixin`: The feature extractor object
instantiated from those parameters.
"""
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
@ -447,18 +420,18 @@ class PreTrainedFeatureExtractor:
return output
@classmethod
def from_json_file(cls, json_file: Union[str, os.PathLike]) -> "PreTrainedFeatureExtractor":
def from_json_file(cls, json_file: Union[str, os.PathLike]) -> PreTrainedFeatureExtractor:
"""
Instantiates a :class:`~transformers.PreTrainedFeatureExtractor` from the path to a JSON file of parameters.
Instantiates a feature extractor of type :class:`~transformers.feature_extraction_utils.FeatureExtractionMixin`
from the path to a JSON file of parameters.
Args:
json_file (:obj:`str` or :obj:`os.PathLike`):
Path to the JSON file containing the parameters.
Returns:
:class:`~transformers.PreTrainedFeatureExtractor`: The feature_extractor object instantiated from that JSON
file.
A feature extractor of type :class:`~transformers.feature_extraction_utils.FeatureExtractionMixin`: The
feature_extractor object instantiated from that JSON file.
"""
with open(json_file, "r", encoding="utf-8") as reader:
text = reader.read()
@ -488,255 +461,3 @@ class PreTrainedFeatureExtractor:
def __repr__(self):
return f"{self.__class__.__name__} {self.to_json_string()}"
def pad(
self,
processed_features: Union[
BatchFeature,
List[BatchFeature],
Dict[str, BatchFeature],
Dict[str, List[BatchFeature]],
List[Dict[str, BatchFeature]],
],
padding: Union[bool, str, PaddingStrategy] = True,
max_length: Optional[int] = None,
pad_to_multiple_of: Optional[int] = None,
return_attention_mask: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
) -> BatchFeature:
"""
Pad input values / input vectors or a batch of input values / input vectors up to predefined length or to the
max sequence length in the batch.
Padding side (left/right) padding values are defined at the feature extractor level (with
``self.padding_side``, ``self.padding_value``)
.. note::
If the ``processed_features`` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors,
the result will use the same type unless you provide a different tensor type with ``return_tensors``. In
the case of PyTorch tensors, you will lose the specific device of your tensors however.
Args:
processed_features (:class:`~transformers.BatchFeature`, list of :class:`~transformers.BatchFeature`, :obj:`Dict[str, List[float]]`, :obj:`Dict[str, List[List[float]]` or :obj:`List[Dict[str, List[float]]]`):
Processed inputs. Can represent one input (:class:`~transformers.BatchFeature` or :obj:`Dict[str,
List[float]]`) or a batch of input values / vectors (list of :class:`~transformers.BatchFeature`,
`Dict[str, List[List[float]]]` or `List[Dict[str, List[float]]]`) so you can use this method during
preprocessing as well as in a PyTorch Dataloader collate function.
Instead of :obj:`List[float]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow
tensors), see the note above for the return type.
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`True`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding
index) among:
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
single sequence if provided).
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
maximum acceptable input length for the model if that argument is not provided.
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
different lengths).
max_length (:obj:`int`, `optional`):
Maximum length of the returned list and optionally padding length (see above).
pad_to_multiple_of (:obj:`int`, `optional`):
If set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
>= 7.5 (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
return_attention_mask (:obj:`bool`, `optional`):
Whether to return the attention mask. If left to the default, will return the attention mask according
to the specific feature_extractor's default.
`What are attention masks? <../glossary.html#attention-mask>`__
return_tensors (:obj:`str` or :class:`~transformers.file_utils.TensorType`, `optional`):
If set, will return tensors instead of list of python integers. Acceptable values are:
* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
* :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
* :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
"""
# If we have a list of dicts, let's convert it in a dict of lists
# We do this to allow using this method as a collate_fn function in PyTorch Dataloader
if isinstance(processed_features, (list, tuple)) and isinstance(processed_features[0], (dict, BatchFeature)):
processed_features = {
key: [example[key] for example in processed_features] for key in processed_features[0].keys()
}
# The model's main input name, usually `input_values`, has be passed for padding
if self.model_input_names[0] not in processed_features:
raise ValueError(
"You should supply an instance of :class:`~transformers.BatchFeature` or list of :class:`~transformers.BatchFeature` to this method"
f"that includes {self.model_input_names[0]}, but you provided {list(processed_features.keys())}"
)
required_input = processed_features[self.model_input_names[0]]
return_attention_mask = (
return_attention_mask if return_attention_mask is not None else self.return_attention_mask
)
if not required_input:
if return_attention_mask:
processed_features["attention_mask"] = []
return processed_features
# If we have PyTorch/TF/NumPy tensors/arrays as inputs, we cast them as python objects
# and rebuild them afterwards if no return_tensors is specified
# Note that we lose the specific device the tensor may be on for PyTorch
first_element = required_input[0]
if isinstance(first_element, (list, tuple)):
# first_element might be an empty list/tuple in some edge cases so we grab the first non empty element.
index = 0
while len(required_input[index]) == 0:
index += 1
if index < len(required_input):
first_element = required_input[index][0]
# At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
if not isinstance(first_element, (float, int, list, tuple)):
if is_tf_available() and _is_tensorflow(first_element):
return_tensors = "tf" if return_tensors is None else return_tensors
elif is_torch_available() and _is_torch(first_element):
return_tensors = "pt" if return_tensors is None else return_tensors
elif isinstance(first_element, np.ndarray):
return_tensors = "np" if return_tensors is None else return_tensors
else:
raise ValueError(
f"type of {first_element} unknown: {type(first_element)}. "
f"Should be one of a python, numpy, pytorch or tensorflow object."
)
for key, value in processed_features.items():
processed_features[key] = to_py_obj(value)
# Convert padding_strategy in PaddingStrategy
padding_strategy, max_length, _ = self._get_padding_strategies(padding=padding, max_length=max_length)
required_input = processed_features[self.model_input_names[0]]
if required_input and not isinstance(required_input[0], (list, tuple)):
processed_features = self._pad(
processed_features,
max_length=max_length,
padding_strategy=padding_strategy,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
)
return BatchFeature(processed_features, tensor_type=return_tensors)
batch_size = len(required_input)
assert all(
len(v) == batch_size for v in processed_features.values()
), "Some items in the output dictionary have a different batch size than others."
if padding_strategy == PaddingStrategy.LONGEST:
max_length = max(len(inputs) for inputs in required_input)
padding_strategy = PaddingStrategy.MAX_LENGTH
batch_outputs = {}
for i in range(batch_size):
inputs = dict((k, v[i]) for k, v in processed_features.items())
outputs = self._pad(
inputs,
max_length=max_length,
padding_strategy=padding_strategy,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
)
for key, value in outputs.items():
if key not in batch_outputs:
batch_outputs[key] = []
batch_outputs[key].append(value)
return BatchFeature(batch_outputs, tensor_type=return_tensors)
def _pad(
self,
processed_features: Union[Dict[str, List[float]], BatchFeature],
max_length: Optional[int] = None,
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
pad_to_multiple_of: Optional[int] = None,
return_attention_mask: Optional[bool] = None,
) -> dict:
"""
Pad inputs (on left/right and up to predefined length or max length in the batch)
Args:
processed_features: Dictionary of input values (`List[float]`) / input vectors (`List[List[float]]`) or batch of inputs values (`List[List[int]]`) / input vectors (`List[List[List[int]]]`)
max_length: maximum length of the returned list and optionally padding length (see below)
padding_strategy: PaddingStrategy to use for padding.
- PaddingStrategy.LONGEST Pad to the longest sequence in the batch
- PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
- PaddingStrategy.DO_NOT_PAD: Do not pad
The feature_extractor padding sides are defined in self.padding_side:
- 'left': pads on the left of the sequences
- 'right': pads on the right of the sequences
pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
>= 7.5 (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics)
"""
required_input = processed_features[self.model_input_names[0]]
if padding_strategy == PaddingStrategy.LONGEST:
max_length = len(required_input)
if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
if needs_to_be_padded:
difference = max_length - len(required_input)
padding_vector = self.feature_size * [self.padding_value] if self.feature_size > 1 else self.padding_value
if self.padding_side == "right":
if return_attention_mask:
processed_features["attention_mask"] = [1] * len(required_input) + [0] * difference
processed_features[self.model_input_names[0]] = required_input + [
padding_vector for _ in range(difference)
]
elif self.padding_side == "left":
if return_attention_mask:
processed_features["attention_mask"] = [0] * difference + [1] * len(required_input)
processed_features[self.model_input_names[0]] = [
padding_vector for _ in range(difference)
] + required_input
else:
raise ValueError("Invalid padding strategy:" + str(self.padding_side))
elif return_attention_mask and "attention_mask" not in processed_features:
processed_features["attention_mask"] = [1] * len(required_input)
return processed_features
def _get_padding_strategies(self, padding=False, max_length=None, pad_to_multiple_of=None, **kwargs):
"""
Find the correct padding strategy
"""
# Get padding strategy
if padding is not False:
if padding is True:
padding_strategy = PaddingStrategy.LONGEST # Default to pad to the longest sequence in the batch
elif not isinstance(padding, PaddingStrategy):
padding_strategy = PaddingStrategy(padding)
elif isinstance(padding, PaddingStrategy):
padding_strategy = padding
else:
padding_strategy = PaddingStrategy.DO_NOT_PAD
# Set max length if needed
if max_length is None:
if padding_strategy == PaddingStrategy.MAX_LENGTH:
raise ValueError(
f"When setting ``padding={PaddingStrategy.MAX_LENGTH}``, make sure that" f" max_length is defined"
)
# Test if we have a padding value
if padding_strategy != PaddingStrategy.DO_NOT_PAD and (self.padding_value is None):
raise ValueError(
"Asking to pad but the feature_extractor does not have a padding value. "
"Please select a value to use as `padding_value`. For example: `feature_extractor.padding_value = 0.0`."
)
return padding_strategy, max_length, kwargs

View File

@ -20,7 +20,8 @@ from typing import List, Optional, Union
import numpy as np
from ...feature_extraction_utils import BatchFeature, PreTrainedFeatureExtractor
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
from ...feature_extraction_utils import BatchFeature
from ...file_utils import PaddingStrategy, TensorType
from ...utils import logging
@ -28,7 +29,7 @@ from ...utils import logging
logger = logging.get_logger(__name__)
class Wav2Vec2FeatureExtractor(PreTrainedFeatureExtractor):
class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor):
r"""
Constructs a Wav2Vec2 feature extractor.

View File

@ -59,7 +59,8 @@ class Wav2Vec2Processor:
.. note::
This class method is simply calling :meth:`~transformers.PreTrainedFeatureExtractor.save_pretrained` and
This class method is simply calling
:meth:`~transformers.feature_extraction_utils.FeatureExtractionMixin.save_pretrained` and
:meth:`~transformers.tokenization_utils_base.PreTrainedTokenizer.save_pretrained`. Please refer to the
docstrings of the methods above for more information.
@ -80,9 +81,9 @@ class Wav2Vec2Processor:
.. note::
This class method is simply calling Wav2Vec2FeatureExtractor's
:meth:`~transformers.PreTrainedFeatureExtractor.from_pretrained` and Wav2Vec2CTCTokenizer's
:meth:`~transformers.tokenization_utils_base.PreTrainedTokenizer.from_pretrained`. Please refer to the
docstrings of the methods above for more information.
:meth:`~transformers.feature_extraction_utils.FeatureExtractionMixin.from_pretrained` and
Wav2Vec2CTCTokenizer's :meth:`~transformers.tokenization_utils_base.PreTrainedTokenizer.from_pretrained`.
Please refer to the docstrings of the methods above for more information.
Args:
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
@ -92,12 +93,12 @@ class Wav2Vec2Processor:
huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or
namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``.
- a path to a `directory` containing a feature extractor file saved using the
:meth:`~transformers.PreTrainedFeatureExtractor.save_pretrained` method, e.g.,
:meth:`~transformers.SequenceFeatureExtractor.save_pretrained` method, e.g.,
``./my_model_directory/``.
- a path or url to a saved feature extractor JSON `file`, e.g.,
``./my_model_directory/feature_extraction_config.json``.
**kwargs
Additional keyword arguments passed along to both :class:`~transformers.PreTrainedFeatureExtractor` and
Additional keyword arguments passed along to both :class:`~transformers.SequenceFeatureExtractor` and
:class:`~transformers.PreTrainedTokenizer`
"""
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs)

View File

@ -727,8 +727,7 @@ class BatchEncoding(UserDict):
device (:obj:`str` or :obj:`torch.device`): The device to put the tensors on.
Returns:
:class:`~transformers.BatchEncoding`: The same instance of :class:`~transformers.BatchEncoding` after
modification.
:class:`~transformers.BatchEncoding`: The same instance after modification.
"""
# This check catches things like APEX blindly calling "to" on all inputs to a module

View File

@ -18,28 +18,8 @@ import json
import os
import tempfile
import numpy as np
from transformers import BatchFeature
from transformers.testing_utils import require_tf, require_torch
class FeatureExtractionMixin:
# to overwrite at feature extractactor specific tests
feat_extract_tester = None
feature_extraction_class = None
@property
def feat_extract_dict(self):
return self.feat_extract_tester.prepare_feat_extract_dict()
def test_feat_extract_common_properties(self):
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
self.assertTrue(hasattr(feat_extract, "feature_size"))
self.assertTrue(hasattr(feat_extract, "sampling_rate"))
self.assertTrue(hasattr(feat_extract, "padding_value"))
class FeatureExtractionSavingTestMixin:
def test_feat_extract_to_json_string(self):
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
obj = json.loads(feat_extract.to_json_string())
@ -68,217 +48,3 @@ class FeatureExtractionMixin:
def test_init_without_params(self):
feat_extract = self.feature_extraction_class()
self.assertIsNotNone(feat_extract)
def test_batch_feature(self):
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common()
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
input_name = feat_extract.model_input_names[0]
processed_features = BatchFeature({input_name: speech_inputs})
self.assertTrue(all(len(x) == len(y) for x, y in zip(speech_inputs, processed_features[input_name])))
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common(equal_length=True)
processed_features = BatchFeature({input_name: speech_inputs}, tensor_type="np")
batch_features_input = processed_features[input_name]
if len(batch_features_input.shape) < 3:
batch_features_input = batch_features_input[:, :, None]
self.assertTrue(
batch_features_input.shape
== (self.feat_extract_tester.batch_size, len(speech_inputs[0]), self.feat_extract_tester.feature_size)
)
@require_torch
def test_batch_feature_pt(self):
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common(equal_length=True)
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
input_name = feat_extract.model_input_names[0]
processed_features = BatchFeature({input_name: speech_inputs}, tensor_type="pt")
batch_features_input = processed_features[input_name]
if len(batch_features_input.shape) < 3:
batch_features_input = batch_features_input[:, :, None]
self.assertTrue(
batch_features_input.shape
== (self.feat_extract_tester.batch_size, len(speech_inputs[0]), self.feat_extract_tester.feature_size)
)
@require_tf
def test_batch_feature_tf(self):
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common(equal_length=True)
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
input_name = feat_extract.model_input_names[0]
processed_features = BatchFeature({input_name: speech_inputs}, tensor_type="tf")
batch_features_input = processed_features[input_name]
if len(batch_features_input.shape) < 3:
batch_features_input = batch_features_input[:, :, None]
self.assertTrue(
batch_features_input.shape
== (self.feat_extract_tester.batch_size, len(speech_inputs[0]), self.feat_extract_tester.feature_size)
)
def _check_padding(self, numpify=False):
def _inputs_have_equal_length(input):
length = len(input[0])
for input_slice in input[1:]:
if len(input_slice) != length:
return False
return True
def _inputs_are_equal(input_1, input_2):
if len(input_1) != len(input_2):
return False
for input_slice_1, input_slice_2 in zip(input_1, input_2):
if not np.allclose(np.asarray(input_slice_1), np.asarray(input_slice_2), atol=1e-3):
return False
return True
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common(numpify=numpify)
input_name = feat_extract.model_input_names[0]
processed_features = BatchFeature({input_name: speech_inputs})
pad_diff = self.feat_extract_tester.seq_length_diff
pad_max_length = self.feat_extract_tester.max_seq_length + pad_diff
pad_min_length = self.feat_extract_tester.min_seq_length
batch_size = self.feat_extract_tester.batch_size
feature_size = self.feat_extract_tester.feature_size
# test padding for List[int] + numpy
input_1 = feat_extract.pad(processed_features, padding=False)[input_name]
input_2 = feat_extract.pad(processed_features, padding="longest")[input_name]
input_3 = feat_extract.pad(processed_features, padding="max_length", max_length=len(speech_inputs[-1]))[
input_name
]
input_4 = feat_extract.pad(processed_features, padding="longest", return_tensors="np")[input_name]
# max_length parameter has to be provided when setting `padding="max_length"`
with self.assertRaises(ValueError):
feat_extract.pad(processed_features, padding="max_length")[input_name]
input_5 = feat_extract.pad(
processed_features, padding="max_length", max_length=pad_max_length, return_tensors="np"
)[input_name]
self.assertFalse(_inputs_have_equal_length(input_1))
self.assertTrue(_inputs_have_equal_length(input_2))
self.assertTrue(_inputs_have_equal_length(input_3))
self.assertTrue(_inputs_are_equal(input_2, input_3))
self.assertTrue(len(input_1[0]) == pad_min_length)
self.assertTrue(len(input_1[1]) == pad_min_length + pad_diff)
self.assertTrue(input_4.shape[:2] == (batch_size, len(input_3[0])))
self.assertTrue(input_5.shape[:2] == (batch_size, pad_max_length))
if feature_size > 1:
self.assertTrue(input_4.shape[2] == input_5.shape[2] == feature_size)
# test padding for `pad_to_multiple_of` for List[int] + numpy
input_6 = feat_extract.pad(processed_features, pad_to_multiple_of=10)[input_name]
input_7 = feat_extract.pad(processed_features, padding="longest", pad_to_multiple_of=10)[input_name]
input_8 = feat_extract.pad(
processed_features, padding="max_length", pad_to_multiple_of=10, max_length=pad_max_length
)[input_name]
input_9 = feat_extract.pad(
processed_features,
padding="max_length",
pad_to_multiple_of=10,
max_length=pad_max_length,
return_tensors="np",
)[input_name]
self.assertTrue(all(len(x) % 10 == 0 for x in input_6))
self.assertTrue(_inputs_are_equal(input_6, input_7))
expected_mult_pad_length = pad_max_length if pad_max_length % 10 == 0 else (pad_max_length // 10 + 1) * 10
self.assertTrue(all(len(x) == expected_mult_pad_length for x in input_8))
self.assertTrue(input_9.shape[:2], (batch_size, expected_mult_pad_length))
if feature_size > 1:
self.assertTrue(input_9.shape[2] == feature_size)
# Check padding value is correct
padding_vector_sum = (np.ones(self.feat_extract_tester.feature_size) * feat_extract.padding_value).sum()
self.assertTrue(
abs(np.asarray(input_2[0])[pad_min_length:].sum() - padding_vector_sum * (pad_max_length - pad_min_length))
< 1e-3
)
self.assertTrue(
abs(
np.asarray(input_2[1])[pad_min_length + pad_diff :].sum()
- padding_vector_sum * (pad_max_length - pad_min_length - pad_diff)
)
< 1e-3
)
self.assertTrue(
abs(
np.asarray(input_2[2])[pad_min_length + 2 * pad_diff :].sum()
- padding_vector_sum * (pad_max_length - pad_min_length - 2 * pad_diff)
)
< 1e-3
)
self.assertTrue(
abs(input_5[0, pad_min_length:].sum() - padding_vector_sum * (pad_max_length - pad_min_length)) < 1e-3
)
self.assertTrue(
abs(input_9[0, pad_min_length:].sum() - padding_vector_sum * (expected_mult_pad_length - pad_min_length))
< 1e-3
)
def test_padding_from_list(self):
self._check_padding(numpify=False)
def test_padding_from_array(self):
self._check_padding(numpify=True)
@require_torch
def test_padding_accepts_tensors_pt(self):
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common()
input_name = feat_extract.model_input_names[0]
processed_features = BatchFeature({input_name: speech_inputs})
input_np = feat_extract.pad(processed_features, padding="longest", return_tensors="np")[input_name]
input_pt = feat_extract.pad(processed_features, padding="longest", return_tensors="pt")[input_name]
self.assertTrue(abs(input_np.sum() - input_pt.numpy().sum()) < 1e-2)
@require_tf
def test_padding_accepts_tensors_tf(self):
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common()
input_name = feat_extract.model_input_names[0]
processed_features = BatchFeature({input_name: speech_inputs})
input_np = feat_extract.pad(processed_features, padding="longest", return_tensors="np")[input_name]
input_tf = feat_extract.pad(processed_features, padding="longest", return_tensors="tf")[input_name]
self.assertTrue(abs(input_np.sum() - input_tf.numpy().sum()) < 1e-2)
def test_attention_mask(self):
feat_dict = self.feat_extract_dict
feat_dict["return_attention_mask"] = True
feat_extract = self.feature_extraction_class(**feat_dict)
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common()
input_lenghts = [len(x) for x in speech_inputs]
input_name = feat_extract.model_input_names[0]
processed = BatchFeature({input_name: speech_inputs})
processed = feat_extract.pad(processed, padding="longest", return_tensors="np")
self.assertIn("attention_mask", processed)
self.assertListEqual(list(processed.attention_mask.shape), list(processed[input_name].shape[:2]))
self.assertListEqual(processed.attention_mask.sum(-1).tolist(), input_lenghts)

View File

@ -23,7 +23,7 @@ import numpy as np
from transformers import WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST, Wav2Vec2Config, Wav2Vec2FeatureExtractor
from transformers.testing_utils import slow
from .test_feature_extraction_common import FeatureExtractionMixin
from .test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
global_rng = random.Random()
@ -94,7 +94,7 @@ class Wav2Vec2FeatureExtractionTester(unittest.TestCase):
return speech_inputs
class Wav2Vec2FeatureExtractionTest(FeatureExtractionMixin, unittest.TestCase):
class Wav2Vec2FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
feature_extraction_class = Wav2Vec2FeatureExtractor

View File

@ -0,0 +1,253 @@
# coding=utf-8
# Copyright 2021 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from transformers import BatchFeature
from transformers.testing_utils import require_tf, require_torch
from .test_feature_extraction_common import FeatureExtractionSavingTestMixin
class SequenceFeatureExtractionTestMixin(FeatureExtractionSavingTestMixin):
# to overwrite at feature extractactor specific tests
feat_extract_tester = None
feature_extraction_class = None
@property
def feat_extract_dict(self):
return self.feat_extract_tester.prepare_feat_extract_dict()
def test_feat_extract_common_properties(self):
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
self.assertTrue(hasattr(feat_extract, "feature_size"))
self.assertTrue(hasattr(feat_extract, "sampling_rate"))
self.assertTrue(hasattr(feat_extract, "padding_value"))
def test_batch_feature(self):
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common()
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
input_name = feat_extract.model_input_names[0]
processed_features = BatchFeature({input_name: speech_inputs})
self.assertTrue(all(len(x) == len(y) for x, y in zip(speech_inputs, processed_features[input_name])))
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common(equal_length=True)
processed_features = BatchFeature({input_name: speech_inputs}, tensor_type="np")
batch_features_input = processed_features[input_name]
if len(batch_features_input.shape) < 3:
batch_features_input = batch_features_input[:, :, None]
self.assertTrue(
batch_features_input.shape
== (self.feat_extract_tester.batch_size, len(speech_inputs[0]), self.feat_extract_tester.feature_size)
)
@require_torch
def test_batch_feature_pt(self):
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common(equal_length=True)
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
input_name = feat_extract.model_input_names[0]
processed_features = BatchFeature({input_name: speech_inputs}, tensor_type="pt")
batch_features_input = processed_features[input_name]
if len(batch_features_input.shape) < 3:
batch_features_input = batch_features_input[:, :, None]
self.assertTrue(
batch_features_input.shape
== (self.feat_extract_tester.batch_size, len(speech_inputs[0]), self.feat_extract_tester.feature_size)
)
@require_tf
def test_batch_feature_tf(self):
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common(equal_length=True)
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
input_name = feat_extract.model_input_names[0]
processed_features = BatchFeature({input_name: speech_inputs}, tensor_type="tf")
batch_features_input = processed_features[input_name]
if len(batch_features_input.shape) < 3:
batch_features_input = batch_features_input[:, :, None]
self.assertTrue(
batch_features_input.shape
== (self.feat_extract_tester.batch_size, len(speech_inputs[0]), self.feat_extract_tester.feature_size)
)
def _check_padding(self, numpify=False):
def _inputs_have_equal_length(input):
length = len(input[0])
for input_slice in input[1:]:
if len(input_slice) != length:
return False
return True
def _inputs_are_equal(input_1, input_2):
if len(input_1) != len(input_2):
return False
for input_slice_1, input_slice_2 in zip(input_1, input_2):
if not np.allclose(np.asarray(input_slice_1), np.asarray(input_slice_2), atol=1e-3):
return False
return True
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common(numpify=numpify)
input_name = feat_extract.model_input_names[0]
processed_features = BatchFeature({input_name: speech_inputs})
pad_diff = self.feat_extract_tester.seq_length_diff
pad_max_length = self.feat_extract_tester.max_seq_length + pad_diff
pad_min_length = self.feat_extract_tester.min_seq_length
batch_size = self.feat_extract_tester.batch_size
feature_size = self.feat_extract_tester.feature_size
# test padding for List[int] + numpy
input_1 = feat_extract.pad(processed_features, padding=False)[input_name]
input_2 = feat_extract.pad(processed_features, padding="longest")[input_name]
input_3 = feat_extract.pad(processed_features, padding="max_length", max_length=len(speech_inputs[-1]))[
input_name
]
input_4 = feat_extract.pad(processed_features, padding="longest", return_tensors="np")[input_name]
# max_length parameter has to be provided when setting `padding="max_length"`
with self.assertRaises(ValueError):
feat_extract.pad(processed_features, padding="max_length")[input_name]
input_5 = feat_extract.pad(
processed_features, padding="max_length", max_length=pad_max_length, return_tensors="np"
)[input_name]
self.assertFalse(_inputs_have_equal_length(input_1))
self.assertTrue(_inputs_have_equal_length(input_2))
self.assertTrue(_inputs_have_equal_length(input_3))
self.assertTrue(_inputs_are_equal(input_2, input_3))
self.assertTrue(len(input_1[0]) == pad_min_length)
self.assertTrue(len(input_1[1]) == pad_min_length + pad_diff)
self.assertTrue(input_4.shape[:2] == (batch_size, len(input_3[0])))
self.assertTrue(input_5.shape[:2] == (batch_size, pad_max_length))
if feature_size > 1:
self.assertTrue(input_4.shape[2] == input_5.shape[2] == feature_size)
# test padding for `pad_to_multiple_of` for List[int] + numpy
input_6 = feat_extract.pad(processed_features, pad_to_multiple_of=10)[input_name]
input_7 = feat_extract.pad(processed_features, padding="longest", pad_to_multiple_of=10)[input_name]
input_8 = feat_extract.pad(
processed_features, padding="max_length", pad_to_multiple_of=10, max_length=pad_max_length
)[input_name]
input_9 = feat_extract.pad(
processed_features,
padding="max_length",
pad_to_multiple_of=10,
max_length=pad_max_length,
return_tensors="np",
)[input_name]
self.assertTrue(all(len(x) % 10 == 0 for x in input_6))
self.assertTrue(_inputs_are_equal(input_6, input_7))
expected_mult_pad_length = pad_max_length if pad_max_length % 10 == 0 else (pad_max_length // 10 + 1) * 10
self.assertTrue(all(len(x) == expected_mult_pad_length for x in input_8))
self.assertTrue(input_9.shape[:2], (batch_size, expected_mult_pad_length))
if feature_size > 1:
self.assertTrue(input_9.shape[2] == feature_size)
# Check padding value is correct
padding_vector_sum = (np.ones(self.feat_extract_tester.feature_size) * feat_extract.padding_value).sum()
self.assertTrue(
abs(np.asarray(input_2[0])[pad_min_length:].sum() - padding_vector_sum * (pad_max_length - pad_min_length))
< 1e-3
)
self.assertTrue(
abs(
np.asarray(input_2[1])[pad_min_length + pad_diff :].sum()
- padding_vector_sum * (pad_max_length - pad_min_length - pad_diff)
)
< 1e-3
)
self.assertTrue(
abs(
np.asarray(input_2[2])[pad_min_length + 2 * pad_diff :].sum()
- padding_vector_sum * (pad_max_length - pad_min_length - 2 * pad_diff)
)
< 1e-3
)
self.assertTrue(
abs(input_5[0, pad_min_length:].sum() - padding_vector_sum * (pad_max_length - pad_min_length)) < 1e-3
)
self.assertTrue(
abs(input_9[0, pad_min_length:].sum() - padding_vector_sum * (expected_mult_pad_length - pad_min_length))
< 1e-3
)
def test_padding_from_list(self):
self._check_padding(numpify=False)
def test_padding_from_array(self):
self._check_padding(numpify=True)
@require_torch
def test_padding_accepts_tensors_pt(self):
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common()
input_name = feat_extract.model_input_names[0]
processed_features = BatchFeature({input_name: speech_inputs})
input_np = feat_extract.pad(processed_features, padding="longest", return_tensors="np")[input_name]
input_pt = feat_extract.pad(processed_features, padding="longest", return_tensors="pt")[input_name]
self.assertTrue(abs(input_np.sum() - input_pt.numpy().sum()) < 1e-2)
@require_tf
def test_padding_accepts_tensors_tf(self):
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common()
input_name = feat_extract.model_input_names[0]
processed_features = BatchFeature({input_name: speech_inputs})
input_np = feat_extract.pad(processed_features, padding="longest", return_tensors="np")[input_name]
input_tf = feat_extract.pad(processed_features, padding="longest", return_tensors="tf")[input_name]
self.assertTrue(abs(input_np.sum() - input_tf.numpy().sum()) < 1e-2)
def test_attention_mask(self):
feat_dict = self.feat_extract_dict
feat_dict["return_attention_mask"] = True
feat_extract = self.feature_extraction_class(**feat_dict)
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common()
input_lenghts = [len(x) for x in speech_inputs]
input_name = feat_extract.model_input_names[0]
processed = BatchFeature({input_name: speech_inputs})
processed = feat_extract.pad(processed, padding="longest", return_tensors="np")
self.assertIn("attention_mask", processed)
self.assertListEqual(list(processed.attention_mask.shape), list(processed[input_name].shape[:2]))
self.assertListEqual(processed.attention_mask.sum(-1).tolist(), input_lenghts)