[Whisper] Fix word-level timestamps for audio < 30 seconds (#25607)

* Fix word-level timestamps for audio < 30 seconds

* Fix code quality

* fix unit tests

* Fix unit tests

* Fix unit test

* temp: print out result

* temp: set max diff to None

* fix unit tests

* fix typo

* Fix typo

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Use generation config for `num_frames`

* fix docs

* Move `num_frames` to kwargs

* compute stride/attn_mask once

* mark test as slow

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: sanchit-gandhi <sanchit@huggingface.co>
Joshua Lochner authored on 2023-09-14 18:42:35 +02:00, committed by GitHub
parent 44a0490d3c
commit 95fe0f5d80
3 changed files with 36 additions and 27 deletions
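For context, a minimal usage sketch of the behaviour this commit fixes (the checkpoint name and audio path are illustrative; the alignment heads are the ones set in the updated pipeline test below):

    from transformers import pipeline

    # Any Whisper checkpoint works once `alignment_heads` is set on its generation config
    # (see the gist referenced in the diff below).
    pipe = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")
    pipe.model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]

    # Before this fix, a clip shorter than 30 seconds came back with word timestamps pinned
    # near the padded 30-second mark; with the fix they track the actual audio.
    out = pipe("short_clip.wav", return_timestamps="word")
    print(out["chunks"])  # e.g. [{"text": " Conquered", "timestamp": (0.5, 1.2)}, ...]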


@@ -1754,6 +1754,9 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
                 "See https://gist.github.com/hollance/42e32852f24243b748ae6bc1f985b13a on how to add this property to the generation config."
             )
 
+        if kwargs.get("num_frames") is not None:
+            generation_config.num_frames = kwargs.pop("num_frames")
+
         outputs = super().generate(
             inputs,
             generation_config,
@@ -1765,7 +1768,10 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
         )
 
         if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
-            outputs["token_timestamps"] = self._extract_token_timestamps(outputs, generation_config.alignment_heads)
+            num_frames = getattr(generation_config, "num_frames", None)
+            outputs["token_timestamps"] = self._extract_token_timestamps(
+                outputs, generation_config.alignment_heads, num_frames=num_frames
+            )
 
         return outputs
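A hedged sketch of the lower-level call this hunk enables: `num_frames` is accepted as a plain keyword argument to `generate`, moved onto the generation config, and then forwarded to `_extract_token_timestamps`. Checkpoint, input, and numbers below are illustrative only.

    import torch
    from transformers import WhisperForConditionalGeneration, WhisperProcessor

    processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
    model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]

    # 10 seconds of silence stands in for real audio; the feature extractor still pads to 30 s.
    audio = torch.zeros(10 * 16_000).numpy()
    inputs = processor(audio, sampling_rate=16_000, return_tensors="pt")

    outputs = model.generate(
        inputs.input_features,
        return_token_timestamps=True,
        num_frames=1000,  # log-mel frames that contain real audio (10 s * 16 kHz / hop length 160)
    )
    print(outputs["token_timestamps"])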
@@ -1799,10 +1805,11 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
             )
         return reordered_past
 
-    def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_precision=0.02):
+    def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_precision=0.02, num_frames=None):
         """
         Calculates token-level timestamps using the encoder-decoder cross-attentions and dynamic time-warping (DTW) to
-        map each output token to a position in the input audio.
+        map each output token to a position in the input audio. If `num_frames` is specified, the encoder-decoder
+        cross-attentions will be cropped before applying DTW.
 
         Returns:
             tensor containing the timestamps in seconds for each predicted token
@@ -1817,6 +1824,8 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
         # of shape (batch size, num selected, output length, input length).
         weights = torch.stack([cross_attentions[l][:, h] for l, h in alignment_heads])
         weights = weights.permute([1, 0, 2, 3])
+        if num_frames is not None:
+            weights = weights[..., : num_frames // 2]
 
         # Normalize and smoothen the weights.
         std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
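A rough illustration of the crop above (not library code; shapes and numbers are invented). The Whisper encoder downsamples the log-mel features by a factor of 2, so `num_frames` mel frames correspond to `num_frames // 2` encoder positions, and only those cross-attention keys carry real audio rather than padding:

    import torch

    # Assume `weights` is (batch, heads, output_len, input_len) with input_len == 1500,
    # the encoder length of a padded 30-second Whisper window.
    weights = torch.rand(1, 6, 12, 1500)

    num_frames = 1000                          # mel frames for ~10 s of audio (hop length 160 at 16 kHz)
    weights = weights[..., : num_frames // 2]  # keep the 500 positions that contain real audio
    print(weights.shape)                       # torch.Size([1, 6, 12, 500])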


@@ -523,10 +523,8 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
         if generate_kwargs is None:
             generate_kwargs = {}
 
-        if return_timestamps and self.type == "seq2seq_whisper":
-            generate_kwargs["return_timestamps"] = return_timestamps
-            if return_timestamps == "word":
-                generate_kwargs["return_token_timestamps"] = True
+        attention_mask = model_inputs.pop("attention_mask", None)
+        stride = model_inputs.pop("stride", None)
 
         is_last = model_inputs.pop("is_last")
         if self.type in {"seq2seq", "seq2seq_whisper"}:
@@ -543,11 +541,15 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
                     f"`input_features` or `input_values` key, but only has {model_inputs.keys()}"
                 )
 
-            # we need to pass `processed.get("attention_mask")` here since audio encoder
-            # attention mask length is different from expected text decoder `encoder_attention_mask` length
-            # `generate` magic to create the mask automatically won't work, we basically need to help
-            # it here.
-            attention_mask = model_inputs.pop("attention_mask", None)
+            # custom processing for Whisper timestamps and word-level timestamps
+            if return_timestamps and self.type == "seq2seq_whisper":
+                generate_kwargs["return_timestamps"] = return_timestamps
+                if return_timestamps == "word":
+                    generate_kwargs["return_token_timestamps"] = True
+
+                    if stride is not None:
+                        generate_kwargs["num_frames"] = stride[0] // self.feature_extractor.hop_length
+
             tokens = self.model.generate(
                 encoder_outputs=encoder(inputs, attention_mask=attention_mask),
                 attention_mask=attention_mask,
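The arithmetic behind the new `num_frames` line, with rough numbers (assuming the stock Whisper feature extractor: 16 kHz sampling rate, hop_length of 160; `stride[0]` is the chunk length in samples):

    sampling_rate = 16_000
    hop_length = 160

    stride_samples = 10 * sampling_rate        # a 10-second clip -> stride[0] == 160_000 samples
    num_frames = stride_samples // hop_length  # 1000 log-mel frames actually covered by audio
    print(num_frames, num_frames // 2)         # 1000 mel frames -> 500 encoder positions used for DTW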
@@ -558,14 +560,11 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
             else:
                 out = {"tokens": tokens}
             if self.type == "seq2seq_whisper":
-                stride = model_inputs.pop("stride", None)
                 if stride is not None:
                     out["stride"] = stride
 
         else:
-            stride = model_inputs.pop("stride", None)
             input_values = model_inputs.pop("input_values")
-            attention_mask = model_inputs.pop("attention_mask", None)
             outputs = self.model(input_values=input_values, attention_mask=attention_mask)
             logits = outputs.logits


@@ -299,6 +299,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
         output = speech_recognizer(filename)
         self.assertEqual(output, {"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST"})
 
+    @slow
     @require_torch
     @slow
     def test_return_timestamps_in_preprocess(self):
@@ -319,28 +320,28 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
             res,
             {
                 "text": " Conquered returned to its place amidst the tents.",
-                "chunks": [{"text": " Conquered returned to its place amidst the tents.", "timestamp": (0.0, 3.36)}],
+                "chunks": [{"timestamp": (0.0, 3.36), "text": " Conquered returned to its place amidst the tents."}],
             },
         )
 
         pipe.model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]
         res = pipe(sample["audio"]["array"], return_timestamps="word")
 
         # fmt: off
-        # Note that the word-level timestamps predicted here are pretty bad.
         self.assertEqual(
             res,
             {
                 "text": " Conquered returned to its place amidst the tents.",
                 "chunks": [
-                    {'text': ' Conquered', 'timestamp': (29.78, 29.9)},
-                    {'text': ' returned', 'timestamp': (29.9, 29.9)},
-                    {'text': ' to', 'timestamp': (29.9, 29.9)},
-                    {'text': ' its', 'timestamp': (29.9, 29.9)},
-                    {'text': ' place', 'timestamp': (29.9, 29.9)},
-                    {'text': ' amidst', 'timestamp': (29.9, 29.9)},
-                    {'text': ' the', 'timestamp': (29.9, 29.9)},
-                    {'text': ' tents.', 'timestamp': (29.9, 29.9)}
-                ]
-            }
+                    {"text": " Conquered", "timestamp": (0.5, 1.2)},
+                    {"text": " returned", "timestamp": (1.2, 1.64)},
+                    {"text": " to", "timestamp": (1.64, 1.84)},
+                    {"text": " its", "timestamp": (1.84, 2.02)},
+                    {"text": " place", "timestamp": (2.02, 2.28)},
+                    {"text": " amidst", "timestamp": (2.28, 2.78)},
+                    {"text": " the", "timestamp": (2.78, 2.96)},
+                    {"text": " tents.", "timestamp": (2.96, 3.48)},
+                ],
+            },
         )
         # fmt: on