[Whisper] Fix word-level timestamps for audio < 30 seconds (#25607)
* Fix word-level timestamps for audio < 30 seconds
* Fix code quality
* fix unit tests
* Fix unit tests
* Fix unit test
* temp: print out result
* temp: set max diff to None
* fix unit tests
* fix typo
* Fix typo

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Use generation config for `num_frames`
* fix docs
* Move `num_frames` to kwargs
* compute stride/attn_mask once
* mark test as slow

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: sanchit-gandhi <sanchit@huggingface.co>
This commit is contained in:
parent 44a0490d3c
commit 95fe0f5d80
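For context, a minimal usage sketch of the code path this PR fixes. It is not part of the diff; the checkpoint, audio file, and the pairing of alignment heads with that checkpoint are illustrative assumptions (the heads shown are the ones set in the test below):

from transformers import pipeline

# Hypothetical setup: checkpoint and audio path are placeholders.
pipe = pipeline(task="automatic-speech-recognition", model="openai/whisper-tiny")

# DTW-based word timestamps require alignment heads on the generation config;
# see the gist linked in the error message in the diff below.
pipe.model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]

# Before this PR, a clip shorter than 30 seconds was zero-padded to 30 s and DTW
# aligned tokens against the padded frames, collapsing most word timestamps near
# the 30-second mark (e.g. (29.9, 29.9) in the old test expectations below).
result = pipe("short_clip.wav", return_timestamps="word")
print(result["chunks"])  # e.g. [{"text": " Conquered", "timestamp": (0.5, 1.2)}, ...]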
--- a/src/transformers/models/whisper/modeling_whisper.py
+++ b/src/transformers/models/whisper/modeling_whisper.py
@@ -1754,6 +1754,9 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
                     "See https://gist.github.com/hollance/42e32852f24243b748ae6bc1f985b13a on how to add this property to the generation config."
                 )
 
+        if kwargs.get("num_frames") is not None:
+            generation_config.num_frames = kwargs.pop("num_frames")
+
         outputs = super().generate(
             inputs,
             generation_config,
@@ -1765,7 +1768,10 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
         )
 
         if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
-            outputs["token_timestamps"] = self._extract_token_timestamps(outputs, generation_config.alignment_heads)
+            num_frames = getattr(generation_config, "num_frames", None)
+            outputs["token_timestamps"] = self._extract_token_timestamps(
+                outputs, generation_config.alignment_heads, num_frames=num_frames
+            )
 
         return outputs
 
@@ -1799,10 +1805,11 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
             )
         return reordered_past
 
-    def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_precision=0.02):
+    def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_precision=0.02, num_frames=None):
         """
         Calculates token-level timestamps using the encoder-decoder cross-attentions and dynamic time-warping (DTW) to
-        map each output token to a position in the input audio.
+        map each output token to a position in the input audio. If `num_frames` is specified, the encoder-decoder
+        cross-attentions will be cropped before applying DTW.
 
         Returns:
             tensor containing the timestamps in seconds for each predicted token
@@ -1817,6 +1824,8 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
         # of shape (batch size, num selected, output length, input length).
         weights = torch.stack([cross_attentions[l][:, h] for l, h in alignment_heads])
         weights = weights.permute([1, 0, 2, 3])
+        if num_frames is not None:
+            weights = weights[..., : num_frames // 2]
 
         # Normalize and smoothen the weights.
         std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
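The crop above uses `num_frames // 2` because Whisper's encoder halves the temporal resolution of the log-mel input with a stride-2 convolution, so `num_frames` mel frames map to `num_frames // 2` encoder positions. A standalone sketch of the cropping step, with made-up shapes (not code from the diff):

import torch

# Illustrative cross-attention weights: 1500 encoder positions = 30 s of
# padded audio after the encoder's stride-2 convolution.
batch, heads, out_len, in_len = 1, 6, 12, 1500
weights = torch.rand(batch, heads, out_len, in_len)

num_frames = 336  # mel frames for a ~3.36 s clip (10 ms per frame)
# Keep only the encoder positions covering real audio, dropping the padding
# that previously dragged DTW timestamps toward the 30-second mark.
weights = weights[..., : num_frames // 2]
print(weights.shape)  # torch.Size([1, 6, 12, 168])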
--- a/src/transformers/pipelines/automatic_speech_recognition.py
+++ b/src/transformers/pipelines/automatic_speech_recognition.py
@@ -523,10 +523,8 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
         if generate_kwargs is None:
             generate_kwargs = {}
 
-        if return_timestamps and self.type == "seq2seq_whisper":
-            generate_kwargs["return_timestamps"] = return_timestamps
-            if return_timestamps == "word":
-                generate_kwargs["return_token_timestamps"] = True
+        attention_mask = model_inputs.pop("attention_mask", None)
+        stride = model_inputs.pop("stride", None)
         is_last = model_inputs.pop("is_last")
 
         if self.type in {"seq2seq", "seq2seq_whisper"}:
@@ -543,11 +541,15 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
                     f"`input_features` or `input_values` key, but only has {model_inputs.keys()}"
                 )
 
-            # we need to pass `processed.get("attention_mask")` here since audio encoder
-            # attention mask length is different from expected text decoder `encoder_attention_mask` length
-            # `generate` magic to create the mask automatically won't work, we basically need to help
-            # it here.
-            attention_mask = model_inputs.pop("attention_mask", None)
+            # custom processing for Whisper timestamps and word-level timestamps
+            if return_timestamps and self.type == "seq2seq_whisper":
+                generate_kwargs["return_timestamps"] = return_timestamps
+                if return_timestamps == "word":
+                    generate_kwargs["return_token_timestamps"] = True
+
+                    if stride is not None:
+                        generate_kwargs["num_frames"] = stride[0] // self.feature_extractor.hop_length
+
             tokens = self.model.generate(
                 encoder_outputs=encoder(inputs, attention_mask=attention_mask),
                 attention_mask=attention_mask,
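In the hunk above, `stride[0]` is a sample count, so dividing by the feature extractor's hop length converts samples to mel frames. A quick sanity check under Whisper's defaults (16 kHz sampling rate, `hop_length = 160`); the stride tuple layout here is illustrative, not copied from the diff:

sampling_rate = 16_000
hop_length = 160  # 10 ms per mel frame at 16 kHz

clip_seconds = 3.36
# Assumed layout: (n_samples, left_stride, right_stride), as used elsewhere in the pipeline.
stride = (int(clip_seconds * sampling_rate), 0, 0)

num_frames = stride[0] // hop_length
print(num_frames)       # 336 mel frames
print(num_frames // 2)  # 168 encoder positions kept for DTW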
@@ -558,14 +560,11 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
             else:
                 out = {"tokens": tokens}
             if self.type == "seq2seq_whisper":
-                stride = model_inputs.pop("stride", None)
                 if stride is not None:
                     out["stride"] = stride
 
         else:
-            stride = model_inputs.pop("stride", None)
             input_values = model_inputs.pop("input_values")
-            attention_mask = model_inputs.pop("attention_mask", None)
             outputs = self.model(input_values=input_values, attention_mask=attention_mask)
             logits = outputs.logits
 
--- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py
+++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py
@@ -299,6 +299,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
         output = speech_recognizer(filename)
         self.assertEqual(output, {"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST"})
 
+    @slow
     @require_torch
     @slow
     def test_return_timestamps_in_preprocess(self):
@@ -319,28 +320,28 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
             res,
             {
                 "text": " Conquered returned to its place amidst the tents.",
-                "chunks": [{"text": " Conquered returned to its place amidst the tents.", "timestamp": (0.0, 3.36)}],
+                "chunks": [{"timestamp": (0.0, 3.36), "text": " Conquered returned to its place amidst the tents."}],
             },
         )
         pipe.model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]
         res = pipe(sample["audio"]["array"], return_timestamps="word")
 
         # fmt: off
-        # Note that the word-level timestamps predicted here are pretty bad.
         self.assertEqual(
             res,
             {
                 "text": " Conquered returned to its place amidst the tents.",
                 "chunks": [
-                    {'text': ' Conquered', 'timestamp': (29.78, 29.9)},
-                    {'text': ' returned', 'timestamp': (29.9, 29.9)},
-                    {'text': ' to', 'timestamp': (29.9, 29.9)},
-                    {'text': ' its', 'timestamp': (29.9, 29.9)},
-                    {'text': ' place', 'timestamp': (29.9, 29.9)},
-                    {'text': ' amidst', 'timestamp': (29.9, 29.9)},
-                    {'text': ' the', 'timestamp': (29.9, 29.9)},
-                    {'text': ' tents.', 'timestamp': (29.9, 29.9)}
-                ]
-            }
+                    {"text": " Conquered", "timestamp": (0.5, 1.2)},
+                    {"text": " returned", "timestamp": (1.2, 1.64)},
+                    {"text": " to", "timestamp": (1.64, 1.84)},
+                    {"text": " its", "timestamp": (1.84, 2.02)},
+                    {"text": " place", "timestamp": (2.02, 2.28)},
+                    {"text": " amidst", "timestamp": (2.28, 2.78)},
+                    {"text": " the", "timestamp": (2.78, 2.96)},
+                    {"text": " tents.", "timestamp": (2.96, 3.48)},
+                ],
+            },
         )
         # fmt: on