diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py
index cc380a9dcc..b60e59d082 100644
--- a/src/transformers/models/whisper/modeling_whisper.py
+++ b/src/transformers/models/whisper/modeling_whisper.py
@@ -1754,6 +1754,9 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
                     "See https://gist.github.com/hollance/42e32852f24243b748ae6bc1f985b13a on how to add this property to the generation config."
                 )
 
+        if kwargs.get("num_frames") is not None:
+            generation_config.num_frames = kwargs.pop("num_frames")
+
         outputs = super().generate(
             inputs,
             generation_config,
@@ -1765,7 +1768,10 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
         )
 
         if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
-            outputs["token_timestamps"] = self._extract_token_timestamps(outputs, generation_config.alignment_heads)
+            num_frames = getattr(generation_config, "num_frames", None)
+            outputs["token_timestamps"] = self._extract_token_timestamps(
+                outputs, generation_config.alignment_heads, num_frames=num_frames
+            )
 
         return outputs
 
@@ -1799,10 +1805,11 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
             )
         return reordered_past
 
-    def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_precision=0.02):
+    def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_precision=0.02, num_frames=None):
         """
         Calculates token-level timestamps using the encoder-decoder cross-attentions and dynamic time-warping (DTW) to
-        map each output token to a position in the input audio.
+        map each output token to a position in the input audio. If `num_frames` is specified, the encoder-decoder
+        cross-attentions will be cropped before applying DTW.
 
         Returns:
             tensor containing the timestamps in seconds for each predicted token
@@ -1817,6 +1824,8 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
         # of shape (batch size, num selected, output length, input length).
         weights = torch.stack([cross_attentions[l][:, h] for l, h in alignment_heads])
         weights = weights.permute([1, 0, 2, 3])
+        if num_frames is not None:
+            weights = weights[..., : num_frames // 2]
 
         # Normalize and smoothen the weights.
         std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py
index 98e43eef85..77470b5b43 100644
--- a/src/transformers/pipelines/automatic_speech_recognition.py
+++ b/src/transformers/pipelines/automatic_speech_recognition.py
@@ -523,10 +523,8 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
         if generate_kwargs is None:
             generate_kwargs = {}
 
-        if return_timestamps and self.type == "seq2seq_whisper":
-            generate_kwargs["return_timestamps"] = return_timestamps
-            if return_timestamps == "word":
-                generate_kwargs["return_token_timestamps"] = True
+        attention_mask = model_inputs.pop("attention_mask", None)
+        stride = model_inputs.pop("stride", None)
 
         is_last = model_inputs.pop("is_last")
         if self.type in {"seq2seq", "seq2seq_whisper"}:
@@ -543,11 +541,15 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
                     f"`input_features` or `input_values` key, but only has {model_inputs.keys()}"
                 )
 
-            # we need to pass `processed.get("attention_mask")` here since audio encoder
-            # attention mask length is different from expected text decoder `encoder_attention_mask` length
-            # `generate` magic to create the mask automatically won't work, we basically need to help
-            # it here.
-            attention_mask = model_inputs.pop("attention_mask", None)
+            # custom processing for Whisper timestamps and word-level timestamps
+            if return_timestamps and self.type == "seq2seq_whisper":
+                generate_kwargs["return_timestamps"] = return_timestamps
+                if return_timestamps == "word":
+                    generate_kwargs["return_token_timestamps"] = True
+
+                    if stride is not None:
+                        generate_kwargs["num_frames"] = stride[0] // self.feature_extractor.hop_length
+
             tokens = self.model.generate(
                 encoder_outputs=encoder(inputs, attention_mask=attention_mask),
                 attention_mask=attention_mask,
@@ -558,14 +560,11 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
             else:
                 out = {"tokens": tokens}
             if self.type == "seq2seq_whisper":
-                stride = model_inputs.pop("stride", None)
                 if stride is not None:
                     out["stride"] = stride
 
         else:
-            stride = model_inputs.pop("stride", None)
             input_values = model_inputs.pop("input_values")
-            attention_mask = model_inputs.pop("attention_mask", None)
             outputs = self.model(input_values=input_values, attention_mask=attention_mask)
             logits = outputs.logits
 
diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py
index 0a90f3e5e3..d989db3ef2 100644
--- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py
+++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py
@@ -299,6 +299,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
         output = speech_recognizer(filename)
         self.assertEqual(output, {"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST"})
 
+    @slow
     @require_torch
     @slow
     def test_return_timestamps_in_preprocess(self):
@@ -319,28 +320,28 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
             res,
             {
                 "text": " Conquered returned to its place amidst the tents.",
-                "chunks": [{"text": " Conquered returned to its place amidst the tents.", "timestamp": (0.0, 3.36)}],
+                "chunks": [{"timestamp": (0.0, 3.36), "text": " Conquered returned to its place amidst the tents."}],
             },
         )
 
         pipe.model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]
         res = pipe(sample["audio"]["array"], return_timestamps="word")
 
+        # fmt: off
-        # Note that the word-level timestamps predicted here are pretty bad.
         self.assertEqual(
             res,
             {
                 "text": " Conquered returned to its place amidst the tents.",
                 "chunks": [
-                    {'text': ' Conquered', 'timestamp': (29.78, 29.9)},
-                    {'text': ' returned', 'timestamp': (29.9, 29.9)},
-                    {'text': ' to', 'timestamp': (29.9, 29.9)},
-                    {'text': ' its', 'timestamp': (29.9, 29.9)},
-                    {'text': ' place', 'timestamp': (29.9, 29.9)},
-                    {'text': ' amidst', 'timestamp': (29.9, 29.9)},
-                    {'text': ' the', 'timestamp': (29.9, 29.9)},
-                    {'text': ' tents.', 'timestamp': (29.9, 29.9)}
-                ]
-            }
+                    {"text": " Conquered", "timestamp": (0.5, 1.2)},
+                    {"text": " returned", "timestamp": (1.2, 1.64)},
+                    {"text": " to", "timestamp": (1.64, 1.84)},
+                    {"text": " its", "timestamp": (1.84, 2.02)},
+                    {"text": " place", "timestamp": (2.02, 2.28)},
+                    {"text": " amidst", "timestamp": (2.28, 2.78)},
+                    {"text": " the", "timestamp": (2.78, 2.96)},
+                    {"text": " tents.", "timestamp": (2.96, 3.48)},
+                ],
+            },
         )
         # fmt: on
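A hedged usage sketch (not part of the patch) of how the new `num_frames` path is meant to be exercised end to end: with chunked long-form audio, each chunk carries a `stride`, the pipeline forwards `num_frames = stride[0] // feature_extractor.hop_length` to `generate`, and `_extract_token_timestamps` crops the cross-attentions to those frames before running DTW. The checkpoint name and audio path below are placeholders; the `alignment_heads` values are copied from the test above.

from transformers import pipeline

# Any Whisper checkpoint should work as long as its generation config defines
# `alignment_heads` (see the gist linked in the error message above); the values
# here are the ones used in the test.
pipe = pipeline("automatic-speech-recognition", model="openai/whisper-tiny.en", chunk_length_s=30)
pipe.model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]

# `return_timestamps="word"` makes the pipeline set `return_token_timestamps=True` and,
# when a chunk has a stride, pass `num_frames = stride[0] // hop_length` to `generate`,
# so DTW only sees the frames that correspond to actual audio in the chunk.
out = pipe("audio.wav", return_timestamps="word")
for chunk in out["chunks"]:
    print(chunk["timestamp"], chunk["text"])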