From e6708709cb7ee2cc04df641403ed0671ee7806c6 Mon Sep 17 00:00:00 2001 From: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> Date: Mon, 20 May 2024 13:40:42 +0200 Subject: [PATCH] Add AutoFeatureExtractor support to Wav2Vec2ProcessorWithLM (#28706) * Add AutoFeatureExtractor support to Wav2Vec2ProcessorWithLM * update with a type filter * add raises error test * fix added test --- .../processing_wav2vec2_with_lm.py | 21 ++++++++----- .../test_processor_wav2vec2_with_lm.py | 31 ++++++++++++++++++- 2 files changed, 43 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py b/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py index b388be245f..9fceb1e61a 100644 --- a/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py +++ b/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py @@ -70,15 +70,15 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin): with language model support into a single processor for language model boosted speech recognition decoding. Args: - feature_extractor ([`Wav2Vec2FeatureExtractor`]): - An instance of [`Wav2Vec2FeatureExtractor`]. The feature extractor is a required input. + feature_extractor ([`Wav2Vec2FeatureExtractor`] or [`SeamlessM4TFeatureExtractor`]): + An instance of [`Wav2Vec2FeatureExtractor`] or [`SeamlessM4TFeatureExtractor`]. The feature extractor is a required input. tokenizer ([`Wav2Vec2CTCTokenizer`]): An instance of [`Wav2Vec2CTCTokenizer`]. The tokenizer is a required input. decoder (`pyctcdecode.BeamSearchDecoderCTC`): An instance of [`pyctcdecode.BeamSearchDecoderCTC`]. The decoder is a required input. """ - feature_extractor_class = "Wav2Vec2FeatureExtractor" + feature_extractor_class = "AutoFeatureExtractor" tokenizer_class = "Wav2Vec2CTCTokenizer" def __init__( @@ -93,6 +93,11 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin): if not isinstance(decoder, BeamSearchDecoderCTC): raise ValueError(f"`decoder` has to be of type {BeamSearchDecoderCTC.__class__}, but is {type(decoder)}") + if feature_extractor.__class__.__name__ not in ["Wav2Vec2FeatureExtractor", "SeamlessM4TFeatureExtractor"]: + raise ValueError( + f"`feature_extractor` has to be of type `Wav2Vec2FeatureExtractor` or `SeamlessM4TFeatureExtractor`, but is {type(feature_extractor)}" + ) + # make sure that decoder's alphabet and tokenizer's vocab match in content missing_decoder_tokens = self.get_missing_alphabet_tokens(decoder, tokenizer) if len(missing_decoder_tokens) > 0: @@ -117,7 +122,7 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin): - This class method is simply calling Wav2Vec2FeatureExtractor's + This class method is simply calling the feature extractor's [`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`], Wav2Vec2CTCTokenizer's [`~tokenization_utils_base.PreTrainedTokenizerBase.from_pretrained`], and [`pyctcdecode.BeamSearchDecoderCTC.load_from_hf_hub`]. @@ -213,8 +218,8 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin): def __call__(self, *args, **kwargs): """ - When used in normal mode, this method forwards all its arguments to Wav2Vec2FeatureExtractor's - [`~Wav2Vec2FeatureExtractor.__call__`] and returns its output. If used in the context + When used in normal mode, this method forwards all its arguments to the feature extractor's + [`~FeatureExtractionMixin.__call__`] and returns its output. If used in the context [`~Wav2Vec2ProcessorWithLM.as_target_processor`] this method forwards all its arguments to Wav2Vec2CTCTokenizer's [`~Wav2Vec2CTCTokenizer.__call__`]. Please refer to the docstring of the above two methods for more information. @@ -252,8 +257,8 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin): def pad(self, *args, **kwargs): """ - When used in normal mode, this method forwards all its arguments to Wav2Vec2FeatureExtractor's - [`~Wav2Vec2FeatureExtractor.pad`] and returns its output. If used in the context + When used in normal mode, this method forwards all its arguments to the feature extractor's + [`~FeatureExtractionMixin.pad`] and returns its output. If used in the context [`~Wav2Vec2ProcessorWithLM.as_target_processor`] this method forwards all its arguments to Wav2Vec2CTCTokenizer's [`~Wav2Vec2CTCTokenizer.pad`]. Please refer to the docstring of the above two methods for more information. diff --git a/tests/models/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py b/tests/models/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py index 2c52a92165..61dee30091 100644 --- a/tests/models/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py +++ b/tests/models/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py @@ -25,7 +25,7 @@ import numpy as np from datasets import load_dataset from parameterized import parameterized -from transformers import AutoProcessor +from transformers import AutoFeatureExtractor, AutoProcessor from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES from transformers.testing_utils import require_pyctcdecode, require_torch, require_torchaudio, slow @@ -157,6 +157,35 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase): for key in input_feat_extract.keys(): self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2) + def test_another_feature_extractor(self): + feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0") + tokenizer = self.get_tokenizer() + decoder = self.get_decoder() + + processor = Wav2Vec2ProcessorWithLM(tokenizer=tokenizer, feature_extractor=feature_extractor, decoder=decoder) + + raw_speech = floats_list((3, 1000)) + + input_feat_extract = feature_extractor(raw_speech, return_tensors="np") + input_processor = processor(raw_speech, return_tensors="np") + + for key in input_feat_extract.keys(): + self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2) + + self.assertListEqual( + processor.model_input_names, + feature_extractor.model_input_names, + msg="`processor` and `feature_extractor` model input names do not match", + ) + + def test_wrong_feature_extractor_raises_error(self): + feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-large-v3") + tokenizer = self.get_tokenizer() + decoder = self.get_decoder() + + with self.assertRaises(ValueError): + Wav2Vec2ProcessorWithLM(tokenizer=tokenizer, feature_extractor=feature_extractor, decoder=decoder) + def test_tokenizer(self): feature_extractor = self.get_feature_extractor() tokenizer = self.get_tokenizer()