Using assistant in AutomaticSpeechRecognitionPipeline with different encoder size (#30637)
* fiw input to generate in pipeline * fixup * pass input_features to generate with assistant * error if model and assistant with different enc size * fix * apply review suggestions * use self.config.is_encoder_decoder * pass inputs to generate directly * add slow tests * Update src/transformers/generation/utils.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Update tests/pipelines/test_pipelines_automatic_speech_recognition.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Update tests/pipelines/test_pipelines_automatic_speech_recognition.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Update tests/pipelines/test_pipelines_automatic_speech_recognition.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Update tests/pipelines/test_pipelines_automatic_speech_recognition.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Update tests/pipelines/test_pipelines_automatic_speech_recognition.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * apply review * Update src/transformers/generation/utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/pipelines/test_pipelines_automatic_speech_recognition.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * apply code review * update attributes encoder_xyz to check * Update src/transformers/generation/utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * add slow test * solve conflicts --------- Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
parent
15585b81a5
commit
eb1a77bbb0
|
@ -1097,6 +1097,25 @@ class GenerationMixin:
|
|||
exception_message += f" Please use one of the following classes instead: {generate_compatible_classes}"
|
||||
raise TypeError(exception_message)
|
||||
|
||||
def _validate_assistant(self, assistant_model):
|
||||
if assistant_model is None:
|
||||
return
|
||||
|
||||
if self.config.is_encoder_decoder and not assistant_model.config.is_encoder_decoder:
|
||||
attributes_to_check = ["encoder_attention_heads", "encoder_ffn_dim", "encoder_layers"]
|
||||
attributes_to_check = [attr for attr in dir(assistant_model.config) if attr in attributes_to_check]
|
||||
are_equal = all(
|
||||
getattr(self.config, attr) == getattr(assistant_model.config, attr) for attr in attributes_to_check
|
||||
)
|
||||
if not are_equal:
|
||||
raise ValueError(
|
||||
"The main model and the assistant don't have compatible encoder-dependent input shapes. "
|
||||
"Ensure you load the assistant with the correct encoder-decoder class, e.g. `AutoModelForSpeechSeq2Seq` for Whisper."
|
||||
)
|
||||
|
||||
if not self.config.vocab_size == assistant_model.config.vocab_size:
|
||||
raise ValueError("Make sure the main and assistant model use the same tokenizer")
|
||||
|
||||
def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
|
||||
"""Validates model kwargs for generation. Generate argument typos will also be caught here."""
|
||||
# If a `Cache` instance is passed, checks whether the model is compatible with it
|
||||
|
@ -1547,6 +1566,7 @@ class GenerationMixin:
|
|||
tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria
|
||||
generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
|
||||
self._validate_model_kwargs(model_kwargs.copy())
|
||||
self._validate_assistant(assistant_model)
|
||||
|
||||
# 2. Set generation parameters if not already defined
|
||||
if synced_gpus is None:
|
||||
|
|
|
@ -474,7 +474,6 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||
raise ValueError("num_frames must be used only when stride is None")
|
||||
|
||||
if self.type in {"seq2seq", "seq2seq_whisper"}:
|
||||
encoder = self.model.get_encoder()
|
||||
# Consume values so we can let extra information flow freely through
|
||||
# the pipeline (important for `partial` in microphone)
|
||||
if "input_features" in model_inputs:
|
||||
|
@ -499,16 +498,11 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||
generate_kwargs["num_frames"] = stride[0] // self.feature_extractor.hop_length
|
||||
else:
|
||||
generate_kwargs["num_frames"] = [s[0] // self.feature_extractor.hop_length for s in stride]
|
||||
|
||||
else:
|
||||
generate_kwargs["num_frames"] = num_frames
|
||||
|
||||
if self.type == "seq2seq_whisper" and inputs.shape[-1] > self.feature_extractor.nb_max_frames:
|
||||
generate_kwargs["input_features"] = inputs
|
||||
else:
|
||||
generate_kwargs["encoder_outputs"] = encoder(inputs, attention_mask=attention_mask)
|
||||
|
||||
tokens = self.model.generate(
|
||||
inputs=inputs,
|
||||
attention_mask=attention_mask,
|
||||
**generate_kwargs,
|
||||
)
|
||||
|
|
|
@ -45,6 +45,7 @@ if is_torch_available():
|
|||
AutoModelForSeq2SeqLM,
|
||||
AutoModelForSpeechSeq2Seq,
|
||||
AutoModelForVision2Seq,
|
||||
AutoProcessor,
|
||||
AutoTokenizer,
|
||||
BartForCausalLM,
|
||||
BartForConditionalGeneration,
|
||||
|
@ -2919,6 +2920,67 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||
# update_candidate_strategy is called once but assistant_model.generation_config.num_assistant_tokens should stay 5
|
||||
self.assertEqual(assistant_model.generation_config.num_assistant_tokens, 5)
|
||||
|
||||
@slow
|
||||
def test_validate_assistant(self):
|
||||
# Generate a random sample:
|
||||
inputs = np.random.rand(160000)
|
||||
|
||||
# Load a main encoder-decoder model:
|
||||
model_id = "openai/whisper-large-v2"
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
model = AutoModelForSpeechSeq2Seq.from_pretrained(
|
||||
model_id,
|
||||
low_cpu_mem_usage=True,
|
||||
use_safetensors=True,
|
||||
)
|
||||
model.to(torch_device)
|
||||
|
||||
# process the input:
|
||||
features = processor(inputs, return_tensors="pt").to(torch_device)
|
||||
|
||||
# Load an encoder-decoder assistant with same encoder as the main model:
|
||||
assistant_distil_model_id = "distil-whisper/distil-large-v2"
|
||||
assistant_seq_to_seq = AutoModelForSpeechSeq2Seq.from_pretrained(
|
||||
assistant_distil_model_id,
|
||||
use_safetensors=True,
|
||||
).to(torch_device)
|
||||
self.assertTrue(model.generate(**features, assistant_model=assistant_seq_to_seq).sum())
|
||||
|
||||
# Load its decoder only version:
|
||||
assistant_causal_lm = AutoModelForCausalLM.from_pretrained(
|
||||
assistant_distil_model_id,
|
||||
low_cpu_mem_usage=True,
|
||||
use_safetensors=True,
|
||||
).to(torch_device)
|
||||
self.assertTrue(model.generate(**features, assistant_model=assistant_causal_lm).sum())
|
||||
|
||||
# Load an encoder-decoder assistant with a different encoder than the main model:
|
||||
assistant_distil_model_id = "openai/whisper-tiny"
|
||||
assistant_seq_to_seq = AutoModelForSpeechSeq2Seq.from_pretrained(
|
||||
assistant_distil_model_id,
|
||||
use_safetensors=True,
|
||||
).to(torch_device)
|
||||
self.assertTrue(model.generate(**features, assistant_model=assistant_seq_to_seq).sum())
|
||||
|
||||
# Load its decoder only version:
|
||||
assistant_causal_lm = AutoModelForCausalLM.from_pretrained(
|
||||
assistant_distil_model_id,
|
||||
low_cpu_mem_usage=True,
|
||||
use_safetensors=True,
|
||||
).to(torch_device)
|
||||
# It will raise an error as the encoder of the main and assistant model are not compatible:
|
||||
with self.assertRaises(ValueError):
|
||||
model.generate(**features, assistant_model=assistant_causal_lm)
|
||||
|
||||
# Load an encoder-decoder model with a different tokenizer than the main model:
|
||||
assistant_distil_model_id = "hf-internal-testing/tiny-random-SeamlessM4Tv2ForSpeechToText"
|
||||
assistant_seq_to_seq = AutoModelForSpeechSeq2Seq.from_pretrained(
|
||||
assistant_distil_model_id,
|
||||
).to(torch_device)
|
||||
# This should raise an error as the main and assistant model don't use the same tokenizer:
|
||||
with self.assertRaises(ValueError):
|
||||
model.generate(**features, assistant_model=assistant_seq_to_seq)
|
||||
|
||||
def test_compare_unprocessed_logit_scores(self):
|
||||
# Get unprocessed logit scores back from model generate function.
|
||||
# Assert that unprocessed logits from generate() are same as those from modal eval()
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
@ -23,6 +24,8 @@ from transformers import (
|
|||
MODEL_FOR_CTC_MAPPING,
|
||||
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
|
||||
AutoFeatureExtractor,
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForSpeechSeq2Seq,
|
||||
AutoProcessor,
|
||||
AutoTokenizer,
|
||||
Speech2TextForConditionalGeneration,
|
||||
|
@ -1138,6 +1141,94 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
|||
{"text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."},
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_speculative_decoding_whisper_non_distil(self):
|
||||
# Load data:
|
||||
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation[:1]")
|
||||
sample = dataset[0]["audio"]
|
||||
|
||||
# Load model:
|
||||
model_id = "openai/whisper-large-v2"
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
model = AutoModelForSpeechSeq2Seq.from_pretrained(
|
||||
model_id,
|
||||
use_safetensors=True,
|
||||
)
|
||||
|
||||
# Load assistant:
|
||||
assistant_model_id = "openai/whisper-tiny"
|
||||
assistant_model = AutoModelForSpeechSeq2Seq.from_pretrained(
|
||||
assistant_model_id,
|
||||
use_safetensors=True,
|
||||
)
|
||||
|
||||
# Load pipeline:
|
||||
pipe = AutomaticSpeechRecognitionPipeline(
|
||||
model=model,
|
||||
tokenizer=processor.tokenizer,
|
||||
feature_extractor=processor.feature_extractor,
|
||||
generate_kwargs={"language": "en"},
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
transcription_non_ass = pipe(sample.copy(), generate_kwargs={"assistant_model": assistant_model})["text"]
|
||||
total_time_assist = time.time() - start_time
|
||||
|
||||
start_time = time.time()
|
||||
transcription_ass = pipe(sample)["text"]
|
||||
total_time_non_assist = time.time() - start_time
|
||||
|
||||
self.assertEqual(transcription_ass, transcription_non_ass)
|
||||
self.assertEqual(
|
||||
transcription_ass,
|
||||
" Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.",
|
||||
)
|
||||
self.assertTrue(total_time_non_assist > total_time_assist, "Make sure that assistant decoding is faster")
|
||||
|
||||
@slow
|
||||
def test_speculative_decoding_whisper_distil(self):
|
||||
# Load data:
|
||||
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation[:1]")
|
||||
sample = dataset[0]["audio"]
|
||||
|
||||
# Load model:
|
||||
model_id = "openai/whisper-large-v2"
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
model = AutoModelForSpeechSeq2Seq.from_pretrained(
|
||||
model_id,
|
||||
use_safetensors=True,
|
||||
)
|
||||
|
||||
# Load assistant:
|
||||
assistant_model_id = "distil-whisper/distil-large-v2"
|
||||
assistant_model = AutoModelForCausalLM.from_pretrained(
|
||||
assistant_model_id,
|
||||
use_safetensors=True,
|
||||
)
|
||||
|
||||
# Load pipeline:
|
||||
pipe = AutomaticSpeechRecognitionPipeline(
|
||||
model=model,
|
||||
tokenizer=processor.tokenizer,
|
||||
feature_extractor=processor.feature_extractor,
|
||||
generate_kwargs={"language": "en"},
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
transcription_non_ass = pipe(sample.copy(), generate_kwargs={"assistant_model": assistant_model})["text"]
|
||||
total_time_assist = time.time() - start_time
|
||||
|
||||
start_time = time.time()
|
||||
transcription_ass = pipe(sample)["text"]
|
||||
total_time_non_assist = time.time() - start_time
|
||||
|
||||
self.assertEqual(transcription_ass, transcription_non_ass)
|
||||
self.assertEqual(
|
||||
transcription_ass,
|
||||
" Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.",
|
||||
)
|
||||
self.assertEqual(total_time_non_assist > total_time_assist, "Make sure that assistant decoding is faster")
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
@require_torchaudio
|
||||
|
|
Loading…
Reference in New Issue