Token level timestamps for long-form generation in Whisper (#29148)

This commit is contained in:
Raushan Turganbay 2024-02-27 23:15:26 +05:00 committed by GitHub
parent 8a1faf2803
commit ddf7ac4237
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 141 additions and 3 deletions

View File

@ -720,6 +720,7 @@ class WhisperGenerationMixin:
input_stride=input_stride,
prev_idx=prev_i,
idx=i,
return_token_timestamps=return_token_timestamps,
)
current_segments[prev_i] += segments
@ -809,11 +810,15 @@ class WhisperGenerationMixin:
# remove eos token id
if is_not_final and seek_sequence[-1] == generation_config.eos_token_id:
seek_sequence = seek_sequence[:-1]
if return_token_timestamps:
seek_outputs[i]["token_timestamps"] = seek_outputs[i]["token_timestamps"][:-1]
# remove all padding tokens
if seek_sequence[-1] == generation_config.pad_token_id:
num_paddings = (seek_sequence == generation_config.pad_token_id).sum()
seek_sequence = seek_sequence[:-num_paddings]
if return_token_timestamps:
seek_outputs[i]["token_timestamps"] = seek_outputs[i]["token_timestamps"][:-num_paddings]
# check which sequences in batch need fallback & which should be skipped
needs_fallback[i], should_skip[i] = self._need_fallback(
@ -878,15 +883,18 @@ class WhisperGenerationMixin:
seek_outputs["token_timestamps"] = self._extract_token_timestamps(
seek_outputs, generation_config.alignment_heads, num_frames=num_frames
)
seek_outputs["token_timestamps"] = seek_outputs["token_timestamps"][:, decoder_input_ids.shape[-1] :]
seek_outputs["sequences"] = seek_outputs["sequences"][:, decoder_input_ids.shape[-1] :]
def split_by_batch_index(values, key, batch_idx):
if key == "scores":
return [v[batch_idx].cpu() for v in values]
if key == "past_key_values":
elif key == "past_key_values":
# we don't save `past_key_values` as this is too costly
return None
elif isinstance(values[batch_idx], tuple) and torch.is_tensor(values[batch_idx][0]):
return tuple(tuple(w[batch_idx][None].cpu() for w in v) for v in values)
return values[batch_idx].cpu()
sequence_tokens = seek_outputs["sequences"]
@ -1611,6 +1619,7 @@ class WhisperGenerationMixin:
input_stride,
prev_idx,
idx,
return_token_timestamps,
):
# find the predicted "end of segment" predictions of Whisper
# "end of segment" predictions occur whenever Whisper predicts a timestamp token
@ -1618,6 +1627,7 @@ class WhisperGenerationMixin:
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
timestamp_segment_indices = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
timestamp_segment_indices.add_(1)
token_timestamps = seek_outputs[idx]["token_timestamps"] if return_token_timestamps else []
# If whisper predicted a "end of segment" via a timestep token, let's go ever each
# "end of segment" prediction and slice the decoding into segments accordingly
@ -1642,6 +1652,10 @@ class WhisperGenerationMixin:
"result": seek_outputs[idx],
}
)
if return_token_timestamps:
segments[-1]["token_timestamps"] = (
token_timestamps[last_slice:current_slice] + time_offset[prev_idx]
)
last_slice = current_slice
if single_timestamp_ending:
@ -1661,7 +1675,6 @@ class WhisperGenerationMixin:
if timestamps.numel() > 0 and timestamps[-1].item() != timestamp_begin:
# no consecutive timestamps but it has a timestamp; use the last one.
last_timestamp_pos = timestamps[-1].item() - timestamp_begin
segments = [
{
"start": time_offset[prev_idx],
@ -1670,6 +1683,8 @@ class WhisperGenerationMixin:
"result": seek_outputs[idx],
}
]
if return_token_timestamps:
segments[-1]["token_timestamps"] = token_timestamps + time_offset[prev_idx]
segment_offset = seek_num_frames[prev_idx]
return segments, segment_offset

View File

@ -483,6 +483,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
generate_kwargs["return_timestamps"] = return_timestamps
if return_timestamps == "word":
generate_kwargs["return_token_timestamps"] = True
generate_kwargs["return_segments"] = True
if stride is not None:
if isinstance(stride, tuple):
@ -499,8 +500,16 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
attention_mask=attention_mask,
**generate_kwargs,
)
# whisper longform generation stores timestamps in "segments"
if return_timestamps == "word" and self.type == "seq2seq_whisper":
out = {"tokens": tokens["sequences"], "token_timestamps": tokens["token_timestamps"]}
if "segments" not in tokens:
out = {"tokens": tokens["sequences"], "token_timestamps": tokens["token_timestamps"]}
else:
token_timestamps = [
torch.cat([segment["token_timestamps"] for segment in segment_list])
for segment_list in tokens["segments"]
]
out = {"tokens": tokens["sequences"], "token_timestamps": token_timestamps}
else:
out = {"tokens": tokens}
if self.type == "seq2seq_whisper":

View File

@ -1969,6 +1969,56 @@ class WhisperModelIntegrationTests(unittest.TestCase):
self.assertEqual(len(generate_outputs.sequences), num_return_sequences * num_samples)
@slow
def test_tiny_token_timestamp_generation_longform(self):
set_seed(0)
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
model.to(torch_device)
model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]
input_speech = self._load_datasamples(5)
long_input_speech = np.concatenate(input_speech, dtype=np.float32)
inputs = processor.feature_extractor(
raw_speech=long_input_speech,
return_tensors="pt",
truncation=False, # False so the audio isn't truncated and whole audio is sent to the model
return_attention_mask=True,
padding=True,
)
inputs = inputs.to(torch_device)
generate_outputs = model.generate(**inputs, return_segments=True, return_token_timestamps=True)
token_timestamps_shape = [
[segment["token_timestamps"].shape for segment in segment_list]
for segment_list in generate_outputs["segments"]
]
tokens_shape = [
[segment["tokens"].shape for segment in segment_list] for segment_list in generate_outputs["segments"]
]
self.assertListEqual(tokens_shape, token_timestamps_shape)
# fmt: off
EXPECTED_OUTPUT = [
torch.tensor([0.0000, 0.4200, 0.8200, 0.9400, 1.1200, 1.1200, 1.2200, 1.5000, 1.7200, 2.0400, 2.3400, 2.5200, 2.6600, 3.2000, 3.4400, 3.5600, 3.6800, 3.8200, 4.1000, 4.3000, 4.5800, 4.9400, 5.4000, 6.3600]),
torch.tensor([ 6.5400, 6.5400, 6.7400, 6.9600, 7.2600, 7.3400, 7.5800, 7.5800, 7.6400, 7.8400, 8.1000, 8.5000, 9.0000, 9.4800, 9.7200, 10.2600, 11.1000]),
torch.tensor([11.2200, 11.2200, 11.4200, 11.6600, 12.0800, 12.4400, 12.5800, 12.8400, 13.1800, 13.6800, 14.0000, 14.2200, 14.6200, 14.9800, 15.2200, 15.6000, 15.9400, 16.2000, 16.5600, 16.8400, 16.9800]),
torch.tensor([16.9800, 16.9800, 17.3200, 18.1600, 18.6400, 18.8600, 19.2800, 19.5600, 19.8800, 20.1800, 20.3800, 20.7200, 21.1600, 21.5400, 21.9000, 22.2000, 22.4200, 22.8600, 23.7000]),
torch.tensor([23.7000, 23.7000, 23.9400, 24.1800, 24.3800, 24.8400, 25.2800, 25.6600, 25.9200, 26.2600, 26.4000, 26.5800, 26.7600, 27.1400, 27.3800, 28.0400, 28.3800, 28.8200, 29.3400, 29.5200]),
torch.tensor([29.4400, 29.4400, 29.7000, 30.0800, 30.3800, 30.5400, 30.8200, 31.0600, 31.6600, 31.9200, 32.3000, 32.4800, 32.6200, 33.6800]),
torch.tensor([33.8000, 33.8000, 33.9800, 33.9800, 34.1800, 34.4400, 34.6200, 35.0000, 35.2200, 35.3200, 35.5600, 35.9200, 36.3800, 36.6200, 36.6600, 36.9600, 37.3400, 37.9800, 38.5800, 38.7200, 38.9800, 39.4400, 39.5800, 39.8000, 40.1200, 40.2600]),
torch.tensor([40.5200, 40.5200, 40.6200, 41.1000, 41.5400, 41.9200, 42.1000, 42.3200, 42.3200, 43.0600, 44.6000]),
torch.tensor([44.7000, 44.7000, 44.8600, 44.9400, 45.1400, 45.1400, 45.2800, 45.6200, 45.9000, 46.2600, 47.1600, 47.4800, 47.7400, 48.1000, 48.2800, 48.4000, 48.6200, 48.8400, 49.0400, 49.2800, 49.4800, 49.6600, 49.9400, 50.5400]),
torch.tensor([50.5400, 50.5400, 50.6600, 50.8800, 51.2400, 51.7200, 52.8400]),
torch.tensor([52.9600, 52.9600, 53.0400, 53.2600, 53.4200, 53.5800, 53.9200, 54.1200, 54.7200, 54.9400, 55.2600, 55.6200, 55.9800, 56.5600, 56.8000, 56.9200, 57.3600, 57.9200, 58.1800, 58.5000, 58.6400, 58.8200]),
torch.tensor([58.6800, 58.6800, 59.1400, 59.5400, 59.9200, 60.1600, 60.3800, 60.8200, 61.6200, 62.2600, 75.2000]),
]
# fmt: on
for segment, exp_segment in zip(generate_outputs["segments"][0], EXPECTED_OUTPUT):
self.assertTrue(torch.allclose(segment["token_timestamps"], exp_segment))
@slow
def test_tiny_specaugment_librispeech(self):
torch_device = "cpu"

View File

@ -361,6 +361,70 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
)
# fmt: on
@slow
@require_torch
def test_return_timestamps_in_preprocess_longform(self):
pipe = pipeline(
task="automatic-speech-recognition",
model="openai/whisper-tiny.en",
)
data = load_dataset("librispeech_asr", "clean", split="test", streaming=True)
samples = [next(iter(data)) for _ in range(8)]
audio = np.concatenate([sample["audio"]["array"] for sample in samples])
res = pipe(audio)
expected_output = {
"text": " Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst "
"the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst "
"the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst "
"the tents. Concord returned to its place amidst the tents."
}
self.assertEqual(res, expected_output)
res = pipe(audio, return_timestamps=True)
self.assertEqual(
res,
{
"text": " Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents.",
"chunks": [
{"timestamp": (0.0, 3.22), "text": " Concord returned to its place amidst the tents."},
{"timestamp": (3.22, 6.74), "text": " Concord returned to its place amidst the tents."},
{"timestamp": (6.74, 10.26), "text": " Concord returned to its place amidst the tents."},
{"timestamp": (10.26, 13.78), "text": " Concord returned to its place amidst the tents."},
{"timestamp": (13.78, 17.3), "text": " Concord returned to its place amidst the tents."},
{"timestamp": (17.3, 20.82), "text": " Concord returned to its place amidst the tents."},
{"timestamp": (20.82, 24.34), "text": " Concord returned to its place amidst the tents."},
{"timestamp": (24.34, 27.86), "text": " Concord returned to its place amidst the tents."},
],
},
)
pipe.model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]
res = pipe(audio, return_timestamps="word")
# fmt: off
self.assertEqual(
res["chunks"][:15],
[
{"text": " Concord", "timestamp": (0.5, 0.94)},
{"text": " returned", "timestamp": (0.94, 1.52)},
{"text": " to", "timestamp": (1.52, 1.78)},
{"text": " its", "timestamp": (1.78, 1.98)},
{"text": " place", "timestamp": (1.98, 2.16)},
{"text": " amidst", "timestamp": (2.16, 2.5)},
{"text": " the", "timestamp": (2.5, 2.9)},
{"text": " tents.", "timestamp": (2.9, 4.2)},
{"text": " Concord", "timestamp": (4.2, 4.5)},
{"text": " returned", "timestamp": (4.5, 5.0)},
{"text": " to", "timestamp": (5.0, 5.28)},
{"text": " its", "timestamp": (5.28, 5.48)},
{"text": " place", "timestamp": (5.48, 5.7)},
{"text": " amidst", "timestamp": (5.7, 6.02)},
{"text": " the", "timestamp": (6.02, 6.4)}
],
)
# fmt: on
@require_torch
def test_return_timestamps_in_init(self):
# segment-level timestamps are accepted