Using assistant in AutomaticSpeechRecognitionPipeline with different encoder size (#30637)

* fix input to generate in pipeline

* fixup

* pass input_features to generate with assistant

* raise an error if the model and assistant have different encoder sizes

* fix

* apply review suggestions

* use self.config.is_encoder_decoder

* pass inputs to generate directly

* add slow tests

* Update src/transformers/generation/utils.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update tests/pipelines/test_pipelines_automatic_speech_recognition.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update tests/pipelines/test_pipelines_automatic_speech_recognition.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update tests/pipelines/test_pipelines_automatic_speech_recognition.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update tests/pipelines/test_pipelines_automatic_speech_recognition.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update tests/pipelines/test_pipelines_automatic_speech_recognition.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* apply review

* Update src/transformers/generation/utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/pipelines/test_pipelines_automatic_speech_recognition.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* apply code review

* update attributes encoder_xyz to check

* Update src/transformers/generation/utils.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/utils.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/utils.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* add slow test

* solve conflicts

---------

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Kamil Akesbi 2024-05-23 10:59:38 +02:00 committed by GitHub
parent 15585b81a5
commit eb1a77bbb0
4 changed files with 174 additions and 7 deletions

src/transformers/generation/utils.py

@@ -1097,6 +1097,25 @@ class GenerationMixin:
                 exception_message += f" Please use one of the following classes instead: {generate_compatible_classes}"
             raise TypeError(exception_message)
 
+    def _validate_assistant(self, assistant_model):
+        if assistant_model is None:
+            return
+
+        if self.config.is_encoder_decoder and not assistant_model.config.is_encoder_decoder:
+            attributes_to_check = ["encoder_attention_heads", "encoder_ffn_dim", "encoder_layers"]
+            attributes_to_check = [attr for attr in dir(assistant_model.config) if attr in attributes_to_check]
+            are_equal = all(
+                getattr(self.config, attr) == getattr(assistant_model.config, attr) for attr in attributes_to_check
+            )
+            if not are_equal:
+                raise ValueError(
+                    "The main model and the assistant don't have compatible encoder-dependent input shapes. "
+                    "Ensure you load the assistant with the correct encoder-decoder class, e.g. `AutoModelForSpeechSeq2Seq` for Whisper."
+                )
+
+        if not self.config.vocab_size == assistant_model.config.vocab_size:
+            raise ValueError("Make sure the main and assistant model use the same tokenizer")
+
     def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
         """Validates model kwargs for generation. Generate argument typos will also be caught here."""
         # If a `Cache` instance is passed, checks whether the model is compatible with it
@@ -1547,6 +1566,7 @@ class GenerationMixin:
         tokenizer = kwargs.pop("tokenizer", None)  # Pull this out first, we only use it for stopping criteria
         generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
         self._validate_model_kwargs(model_kwargs.copy())
+        self._validate_assistant(assistant_model)
 
         # 2. Set generation parameters if not already defined
         if synced_gpus is None:

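For context, the check above runs whenever an `assistant_model` is passed to `generate`: a decoder-only assistant is driven with the main model's encoder outputs, so its encoder-dependent config values must match the main model's. The following is an illustrative sketch only, not code from this commit; it reuses the checkpoints exercised by the slow test added further down.

# Sketch of the behaviour enforced by the new _validate_assistant check (illustrative).
import numpy as np
from transformers import AutoModelForCausalLM, AutoModelForSpeechSeq2Seq, AutoProcessor

model_id = "openai/whisper-large-v2"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id)
features = processor(np.zeros(16000), sampling_rate=16000, return_tensors="pt")

# Decoder-only assistant with the same encoder size as the main model: accepted.
assistant = AutoModelForCausalLM.from_pretrained("distil-whisper/distil-large-v2")
model.generate(**features, assistant_model=assistant)

# Decoder-only assistant whose encoder_attention_heads / encoder_ffn_dim / encoder_layers
# differ from the main model: _validate_assistant raises a ValueError.
mismatched = AutoModelForCausalLM.from_pretrained("openai/whisper-tiny")
model.generate(**features, assistant_model=mismatched)  # ValueError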
src/transformers/pipelines/automatic_speech_recognition.py

@ -474,7 +474,6 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
raise ValueError("num_frames must be used only when stride is None") raise ValueError("num_frames must be used only when stride is None")
if self.type in {"seq2seq", "seq2seq_whisper"}: if self.type in {"seq2seq", "seq2seq_whisper"}:
encoder = self.model.get_encoder()
# Consume values so we can let extra information flow freely through # Consume values so we can let extra information flow freely through
# the pipeline (important for `partial` in microphone) # the pipeline (important for `partial` in microphone)
if "input_features" in model_inputs: if "input_features" in model_inputs:
@ -499,16 +498,11 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
generate_kwargs["num_frames"] = stride[0] // self.feature_extractor.hop_length generate_kwargs["num_frames"] = stride[0] // self.feature_extractor.hop_length
else: else:
generate_kwargs["num_frames"] = [s[0] // self.feature_extractor.hop_length for s in stride] generate_kwargs["num_frames"] = [s[0] // self.feature_extractor.hop_length for s in stride]
else: else:
generate_kwargs["num_frames"] = num_frames generate_kwargs["num_frames"] = num_frames
if self.type == "seq2seq_whisper" and inputs.shape[-1] > self.feature_extractor.nb_max_frames:
generate_kwargs["input_features"] = inputs
else:
generate_kwargs["encoder_outputs"] = encoder(inputs, attention_mask=attention_mask)
tokens = self.model.generate( tokens = self.model.generate(
inputs=inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
**generate_kwargs, **generate_kwargs,
) )

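With the encoder forward pass removed, the pipeline hands the raw features straight to `model.generate`, which lets an assistant with a differently sized encoder compute its own encoder outputs during assisted generation. A minimal usage sketch mirroring the slow pipeline tests below; it is not part of this commit, and "sample.wav" is a placeholder audio path.

# Illustrative sketch of using the ASR pipeline with an assistant model after this change.
from transformers import (
    AutomaticSpeechRecognitionPipeline,
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
)

model_id = "openai/whisper-large-v2"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id)

# The assistant may now have a smaller encoder than the main model.
assistant = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-tiny")

pipe = AutomaticSpeechRecognitionPipeline(
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
)

# `inputs` flows directly into `model.generate`, so assisted (speculative) decoding works.
print(pipe("sample.wav", generate_kwargs={"assistant_model": assistant})["text"])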
tests/generation/test_utils.py

@@ -45,6 +45,7 @@ if is_torch_available():
         AutoModelForSeq2SeqLM,
         AutoModelForSpeechSeq2Seq,
         AutoModelForVision2Seq,
+        AutoProcessor,
         AutoTokenizer,
         BartForCausalLM,
         BartForConditionalGeneration,
@@ -2919,6 +2920,67 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
         # update_candidate_strategy is called once but assistant_model.generation_config.num_assistant_tokens should stay 5
         self.assertEqual(assistant_model.generation_config.num_assistant_tokens, 5)
 
+    @slow
+    def test_validate_assistant(self):
+        # Generate a random sample:
+        inputs = np.random.rand(160000)
+
+        # Load a main encoder-decoder model:
+        model_id = "openai/whisper-large-v2"
+        processor = AutoProcessor.from_pretrained(model_id)
+        model = AutoModelForSpeechSeq2Seq.from_pretrained(
+            model_id,
+            low_cpu_mem_usage=True,
+            use_safetensors=True,
+        )
+        model.to(torch_device)
+
+        # Process the input:
+        features = processor(inputs, return_tensors="pt").to(torch_device)
+
+        # Load an encoder-decoder assistant with the same encoder as the main model:
+        assistant_distil_model_id = "distil-whisper/distil-large-v2"
+        assistant_seq_to_seq = AutoModelForSpeechSeq2Seq.from_pretrained(
+            assistant_distil_model_id,
+            use_safetensors=True,
+        ).to(torch_device)
+        self.assertTrue(model.generate(**features, assistant_model=assistant_seq_to_seq).sum())
+
+        # Load its decoder-only version:
+        assistant_causal_lm = AutoModelForCausalLM.from_pretrained(
+            assistant_distil_model_id,
+            low_cpu_mem_usage=True,
+            use_safetensors=True,
+        ).to(torch_device)
+        self.assertTrue(model.generate(**features, assistant_model=assistant_causal_lm).sum())
+
+        # Load an encoder-decoder assistant with a different encoder than the main model:
+        assistant_distil_model_id = "openai/whisper-tiny"
+        assistant_seq_to_seq = AutoModelForSpeechSeq2Seq.from_pretrained(
+            assistant_distil_model_id,
+            use_safetensors=True,
+        ).to(torch_device)
+        self.assertTrue(model.generate(**features, assistant_model=assistant_seq_to_seq).sum())
+
+        # Load its decoder-only version:
+        assistant_causal_lm = AutoModelForCausalLM.from_pretrained(
+            assistant_distil_model_id,
+            low_cpu_mem_usage=True,
+            use_safetensors=True,
+        ).to(torch_device)
+
+        # This raises an error because the encoders of the main and assistant model are not compatible:
+        with self.assertRaises(ValueError):
+            model.generate(**features, assistant_model=assistant_causal_lm)
+
+        # Load an encoder-decoder model with a different tokenizer than the main model:
+        assistant_distil_model_id = "hf-internal-testing/tiny-random-SeamlessM4Tv2ForSpeechToText"
+        assistant_seq_to_seq = AutoModelForSpeechSeq2Seq.from_pretrained(
+            assistant_distil_model_id,
+        ).to(torch_device)
+
+        # This raises an error because the main and assistant model don't use the same tokenizer:
+        with self.assertRaises(ValueError):
+            model.generate(**features, assistant_model=assistant_seq_to_seq)
+
     def test_compare_unprocessed_logit_scores(self):
         # Get unprocessed logit scores back from model generate function.
         # Assert that unprocessed logits from generate() are same as those from modal eval()

tests/pipelines/test_pipelines_automatic_speech_recognition.py

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import time
 import unittest
 
 import numpy as np
@@ -23,6 +24,8 @@ from transformers import (
     MODEL_FOR_CTC_MAPPING,
     MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
     AutoFeatureExtractor,
+    AutoModelForCausalLM,
+    AutoModelForSpeechSeq2Seq,
     AutoProcessor,
     AutoTokenizer,
     Speech2TextForConditionalGeneration,
@@ -1138,6 +1141,94 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
             {"text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."},
         )
 
+    @slow
+    def test_speculative_decoding_whisper_non_distil(self):
+        # Load data:
+        dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation[:1]")
+        sample = dataset[0]["audio"]
+
+        # Load model:
+        model_id = "openai/whisper-large-v2"
+        processor = AutoProcessor.from_pretrained(model_id)
+        model = AutoModelForSpeechSeq2Seq.from_pretrained(
+            model_id,
+            use_safetensors=True,
+        )
+
+        # Load assistant:
+        assistant_model_id = "openai/whisper-tiny"
+        assistant_model = AutoModelForSpeechSeq2Seq.from_pretrained(
+            assistant_model_id,
+            use_safetensors=True,
+        )
+
+        # Load pipeline:
+        pipe = AutomaticSpeechRecognitionPipeline(
+            model=model,
+            tokenizer=processor.tokenizer,
+            feature_extractor=processor.feature_extractor,
+            generate_kwargs={"language": "en"},
+        )
+
+        start_time = time.time()
+        transcription_ass = pipe(sample.copy(), generate_kwargs={"assistant_model": assistant_model})["text"]
+        total_time_assist = time.time() - start_time
+
+        start_time = time.time()
+        transcription_non_ass = pipe(sample)["text"]
+        total_time_non_assist = time.time() - start_time
+
+        self.assertEqual(transcription_ass, transcription_non_ass)
+        self.assertEqual(
+            transcription_ass,
+            " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.",
+        )
+        self.assertTrue(total_time_non_assist > total_time_assist, "Make sure that assistant decoding is faster")
+
+    @slow
+    def test_speculative_decoding_whisper_distil(self):
+        # Load data:
+        dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation[:1]")
+        sample = dataset[0]["audio"]
+
+        # Load model:
+        model_id = "openai/whisper-large-v2"
+        processor = AutoProcessor.from_pretrained(model_id)
+        model = AutoModelForSpeechSeq2Seq.from_pretrained(
+            model_id,
+            use_safetensors=True,
+        )
+
+        # Load assistant:
+        assistant_model_id = "distil-whisper/distil-large-v2"
+        assistant_model = AutoModelForCausalLM.from_pretrained(
+            assistant_model_id,
+            use_safetensors=True,
+        )
+
+        # Load pipeline:
+        pipe = AutomaticSpeechRecognitionPipeline(
+            model=model,
+            tokenizer=processor.tokenizer,
+            feature_extractor=processor.feature_extractor,
+            generate_kwargs={"language": "en"},
+        )
+
+        start_time = time.time()
+        transcription_ass = pipe(sample.copy(), generate_kwargs={"assistant_model": assistant_model})["text"]
+        total_time_assist = time.time() - start_time
+
+        start_time = time.time()
+        transcription_non_ass = pipe(sample)["text"]
+        total_time_non_assist = time.time() - start_time
+
+        self.assertEqual(transcription_ass, transcription_non_ass)
+        self.assertEqual(
+            transcription_ass,
+            " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.",
+        )
+        self.assertTrue(total_time_non_assist > total_time_assist, "Make sure that assistant decoding is faster")
+
     @slow
     @require_torch
     @require_torchaudio