[Whisper] Fix word-level timestamps for audio < 30 seconds (#25607)
* Fix word-level timestamps for audio < 30 seconds
* Fix code quality
* fix unit tests
* Fix unit tests
* Fix unit test
* temp: print out result
* temp: set max diff to None
* fix unit tests
* fix typo
* Fix typo

  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Use generation config for `num_frames`
* fix docs
* Move `num_frames` to kwargs
* compute stride/attn_mask once
* mark test as slow

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: sanchit-gandhi <sanchit@huggingface.co>
This commit is contained in:
parent
44a0490d3c
commit
95fe0f5d80
|
@ -1754,6 +1754,9 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
|||
"See https://gist.github.com/hollance/42e32852f24243b748ae6bc1f985b13a on how to add this property to the generation config."
|
||||
)
|
||||
|
||||
if kwargs.get("num_frames") is not None:
|
||||
generation_config.num_frames = kwargs.pop("num_frames")
|
||||
|
||||
outputs = super().generate(
|
||||
inputs,
|
||||
generation_config,
|
||||
|
@ -1765,7 +1768,10 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
|||
)
|
||||
|
||||
if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
|
||||
outputs["token_timestamps"] = self._extract_token_timestamps(outputs, generation_config.alignment_heads)
|
||||
num_frames = getattr(generation_config, "num_frames", None)
|
||||
outputs["token_timestamps"] = self._extract_token_timestamps(
|
||||
outputs, generation_config.alignment_heads, num_frames=num_frames
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
|
@ -1799,10 +1805,11 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
|||
)
|
||||
return reordered_past
|
||||
|
||||
def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_precision=0.02):
|
||||
def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_precision=0.02, num_frames=None):
|
||||
"""
|
||||
Calculates token-level timestamps using the encoder-decoder cross-attentions and dynamic time-warping (DTW) to
|
||||
map each output token to a position in the input audio.
|
||||
map each output token to a position in the input audio. If `num_frames` is specified, the encoder-decoder
|
||||
cross-attentions will be cropped before applying DTW.
|
||||
|
||||
Returns:
|
||||
tensor containing the timestamps in seconds for each predicted token
|
||||
|
@ -1817,6 +1824,8 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
|||
# of shape (batch size, num selected, output length, input length).
|
||||
weights = torch.stack([cross_attentions[l][:, h] for l, h in alignment_heads])
|
||||
weights = weights.permute([1, 0, 2, 3])
|
||||
if num_frames is not None:
|
||||
weights = weights[..., : num_frames // 2]
|
||||
|
||||
# Normalize and smoothen the weights.
|
||||
std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
|
||||
|
|
|
@ -523,10 +523,8 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||
if generate_kwargs is None:
|
||||
generate_kwargs = {}
|
||||
|
||||
if return_timestamps and self.type == "seq2seq_whisper":
|
||||
generate_kwargs["return_timestamps"] = return_timestamps
|
||||
if return_timestamps == "word":
|
||||
generate_kwargs["return_token_timestamps"] = True
|
||||
attention_mask = model_inputs.pop("attention_mask", None)
|
||||
stride = model_inputs.pop("stride", None)
|
||||
is_last = model_inputs.pop("is_last")
|
||||
|
||||
if self.type in {"seq2seq", "seq2seq_whisper"}:
|
||||
|
@ -543,11 +541,15 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||
f"`input_features` or `input_values` key, but only has {model_inputs.keys()}"
|
||||
)
|
||||
|
||||
# we need to pass `processed.get("attention_mask")` here since audio encoder
|
||||
# attention mask length is different from expected text decoder `encoder_attention_mask` length
|
||||
# `generate` magic to create the mask automatically won't work, we basically need to help
|
||||
# it here.
|
||||
attention_mask = model_inputs.pop("attention_mask", None)
|
||||
# custom processing for Whisper timestamps and word-level timestamps
|
||||
if return_timestamps and self.type == "seq2seq_whisper":
|
||||
generate_kwargs["return_timestamps"] = return_timestamps
|
||||
if return_timestamps == "word":
|
||||
generate_kwargs["return_token_timestamps"] = True
|
||||
|
||||
if stride is not None:
|
||||
generate_kwargs["num_frames"] = stride[0] // self.feature_extractor.hop_length
|
||||
|
||||
tokens = self.model.generate(
|
||||
encoder_outputs=encoder(inputs, attention_mask=attention_mask),
|
||||
attention_mask=attention_mask,
|
||||
|
@ -558,14 +560,11 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||
else:
|
||||
out = {"tokens": tokens}
|
||||
if self.type == "seq2seq_whisper":
|
||||
stride = model_inputs.pop("stride", None)
|
||||
if stride is not None:
|
||||
out["stride"] = stride
|
||||
|
||||
else:
|
||||
stride = model_inputs.pop("stride", None)
|
||||
input_values = model_inputs.pop("input_values")
|
||||
attention_mask = model_inputs.pop("attention_mask", None)
|
||||
outputs = self.model(input_values=input_values, attention_mask=attention_mask)
|
||||
logits = outputs.logits
|
||||
|
||||
|
|
|
@ -299,6 +299,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
|||
output = speech_recognizer(filename)
|
||||
self.assertEqual(output, {"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST"})
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
@slow
|
||||
def test_return_timestamps_in_preprocess(self):
|
||||
|
@ -319,28 +320,28 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
|||
res,
|
||||
{
|
||||
"text": " Conquered returned to its place amidst the tents.",
|
||||
"chunks": [{"text": " Conquered returned to its place amidst the tents.", "timestamp": (0.0, 3.36)}],
|
||||
"chunks": [{"timestamp": (0.0, 3.36), "text": " Conquered returned to its place amidst the tents."}],
|
||||
},
|
||||
)
|
||||
pipe.model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]
|
||||
res = pipe(sample["audio"]["array"], return_timestamps="word")
|
||||
|
||||
# fmt: off
|
||||
# Note that the word-level timestamps predicted here are pretty bad.
|
||||
self.assertEqual(
|
||||
res,
|
||||
{
|
||||
"text": " Conquered returned to its place amidst the tents.",
|
||||
"chunks": [
|
||||
{'text': ' Conquered', 'timestamp': (29.78, 29.9)},
|
||||
{'text': ' returned', 'timestamp': (29.9, 29.9)},
|
||||
{'text': ' to', 'timestamp': (29.9, 29.9)},
|
||||
{'text': ' its', 'timestamp': (29.9, 29.9)},
|
||||
{'text': ' place', 'timestamp': (29.9, 29.9)},
|
||||
{'text': ' amidst', 'timestamp': (29.9, 29.9)},
|
||||
{'text': ' the', 'timestamp': (29.9, 29.9)},
|
||||
{'text': ' tents.', 'timestamp': (29.9, 29.9)}
|
||||
]
|
||||
}
|
||||
{"text": " Conquered", "timestamp": (0.5, 1.2)},
|
||||
{"text": " returned", "timestamp": (1.2, 1.64)},
|
||||
{"text": " to", "timestamp": (1.64, 1.84)},
|
||||
{"text": " its", "timestamp": (1.84, 2.02)},
|
||||
{"text": " place", "timestamp": (2.02, 2.28)},
|
||||
{"text": " amidst", "timestamp": (2.28, 2.78)},
|
||||
{"text": " the", "timestamp": (2.78, 2.96)},
|
||||
{"text": " tents.", "timestamp": (2.96, 3.48)},
|
||||
],
|
||||
},
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
|
|
Loading…
Reference in New Issue