Token level timestamps for long-form generation in Whisper (#29148)
This commit is contained in:
parent
8a1faf2803
commit
ddf7ac4237
|
@ -720,6 +720,7 @@ class WhisperGenerationMixin:
|
|||
input_stride=input_stride,
|
||||
prev_idx=prev_i,
|
||||
idx=i,
|
||||
return_token_timestamps=return_token_timestamps,
|
||||
)
|
||||
|
||||
current_segments[prev_i] += segments
|
||||
|
@ -809,11 +810,15 @@ class WhisperGenerationMixin:
|
|||
# remove eos token id
|
||||
if is_not_final and seek_sequence[-1] == generation_config.eos_token_id:
|
||||
seek_sequence = seek_sequence[:-1]
|
||||
if return_token_timestamps:
|
||||
seek_outputs[i]["token_timestamps"] = seek_outputs[i]["token_timestamps"][:-1]
|
||||
|
||||
# remove all padding tokens
|
||||
if seek_sequence[-1] == generation_config.pad_token_id:
|
||||
num_paddings = (seek_sequence == generation_config.pad_token_id).sum()
|
||||
seek_sequence = seek_sequence[:-num_paddings]
|
||||
if return_token_timestamps:
|
||||
seek_outputs[i]["token_timestamps"] = seek_outputs[i]["token_timestamps"][:-num_paddings]
|
||||
|
||||
# check which sequences in batch need fallback & which should be skipped
|
||||
needs_fallback[i], should_skip[i] = self._need_fallback(
|
||||
|
@ -878,15 +883,18 @@ class WhisperGenerationMixin:
|
|||
seek_outputs["token_timestamps"] = self._extract_token_timestamps(
|
||||
seek_outputs, generation_config.alignment_heads, num_frames=num_frames
|
||||
)
|
||||
seek_outputs["token_timestamps"] = seek_outputs["token_timestamps"][:, decoder_input_ids.shape[-1] :]
|
||||
|
||||
seek_outputs["sequences"] = seek_outputs["sequences"][:, decoder_input_ids.shape[-1] :]
|
||||
|
||||
def split_by_batch_index(values, key, batch_idx):
|
||||
if key == "scores":
|
||||
return [v[batch_idx].cpu() for v in values]
|
||||
if key == "past_key_values":
|
||||
elif key == "past_key_values":
|
||||
# we don't save `past_key_values` as this is too costly
|
||||
return None
|
||||
elif isinstance(values[batch_idx], tuple) and torch.is_tensor(values[batch_idx][0]):
|
||||
return tuple(tuple(w[batch_idx][None].cpu() for w in v) for v in values)
|
||||
return values[batch_idx].cpu()
|
||||
|
||||
sequence_tokens = seek_outputs["sequences"]
|
||||
|
@ -1611,6 +1619,7 @@ class WhisperGenerationMixin:
|
|||
input_stride,
|
||||
prev_idx,
|
||||
idx,
|
||||
return_token_timestamps,
|
||||
):
|
||||
# find the predicted "end of segment" predictions of Whisper
|
||||
# "end of segment" predictions occur whenever Whisper predicts a timestamp token
|
||||
|
@ -1618,6 +1627,7 @@ class WhisperGenerationMixin:
|
|||
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
|
||||
timestamp_segment_indices = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
|
||||
timestamp_segment_indices.add_(1)
|
||||
token_timestamps = seek_outputs[idx]["token_timestamps"] if return_token_timestamps else []
|
||||
|
||||
# If whisper predicted a "end of segment" via a timestep token, let's go ever each
|
||||
# "end of segment" prediction and slice the decoding into segments accordingly
|
||||
|
@ -1642,6 +1652,10 @@ class WhisperGenerationMixin:
|
|||
"result": seek_outputs[idx],
|
||||
}
|
||||
)
|
||||
if return_token_timestamps:
|
||||
segments[-1]["token_timestamps"] = (
|
||||
token_timestamps[last_slice:current_slice] + time_offset[prev_idx]
|
||||
)
|
||||
last_slice = current_slice
|
||||
|
||||
if single_timestamp_ending:
|
||||
|
@ -1661,7 +1675,6 @@ class WhisperGenerationMixin:
|
|||
if timestamps.numel() > 0 and timestamps[-1].item() != timestamp_begin:
|
||||
# no consecutive timestamps but it has a timestamp; use the last one.
|
||||
last_timestamp_pos = timestamps[-1].item() - timestamp_begin
|
||||
|
||||
segments = [
|
||||
{
|
||||
"start": time_offset[prev_idx],
|
||||
|
@ -1670,6 +1683,8 @@ class WhisperGenerationMixin:
|
|||
"result": seek_outputs[idx],
|
||||
}
|
||||
]
|
||||
if return_token_timestamps:
|
||||
segments[-1]["token_timestamps"] = token_timestamps + time_offset[prev_idx]
|
||||
segment_offset = seek_num_frames[prev_idx]
|
||||
|
||||
return segments, segment_offset
|
||||
|
|
|
@ -483,6 +483,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||
generate_kwargs["return_timestamps"] = return_timestamps
|
||||
if return_timestamps == "word":
|
||||
generate_kwargs["return_token_timestamps"] = True
|
||||
generate_kwargs["return_segments"] = True
|
||||
|
||||
if stride is not None:
|
||||
if isinstance(stride, tuple):
|
||||
|
@ -499,8 +500,16 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||
attention_mask=attention_mask,
|
||||
**generate_kwargs,
|
||||
)
|
||||
# whisper longform generation stores timestamps in "segments"
|
||||
if return_timestamps == "word" and self.type == "seq2seq_whisper":
|
||||
out = {"tokens": tokens["sequences"], "token_timestamps": tokens["token_timestamps"]}
|
||||
if "segments" not in tokens:
|
||||
out = {"tokens": tokens["sequences"], "token_timestamps": tokens["token_timestamps"]}
|
||||
else:
|
||||
token_timestamps = [
|
||||
torch.cat([segment["token_timestamps"] for segment in segment_list])
|
||||
for segment_list in tokens["segments"]
|
||||
]
|
||||
out = {"tokens": tokens["sequences"], "token_timestamps": token_timestamps}
|
||||
else:
|
||||
out = {"tokens": tokens}
|
||||
if self.type == "seq2seq_whisper":
|
||||
|
|
|
@ -1969,6 +1969,56 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||
|
||||
self.assertEqual(len(generate_outputs.sequences), num_return_sequences * num_samples)
|
||||
|
||||
@slow
|
||||
def test_tiny_token_timestamp_generation_longform(self):
|
||||
set_seed(0)
|
||||
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
|
||||
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
|
||||
model.to(torch_device)
|
||||
model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]
|
||||
|
||||
input_speech = self._load_datasamples(5)
|
||||
long_input_speech = np.concatenate(input_speech, dtype=np.float32)
|
||||
inputs = processor.feature_extractor(
|
||||
raw_speech=long_input_speech,
|
||||
return_tensors="pt",
|
||||
truncation=False, # False so the audio isn't truncated and whole audio is sent to the model
|
||||
return_attention_mask=True,
|
||||
padding=True,
|
||||
)
|
||||
|
||||
inputs = inputs.to(torch_device)
|
||||
generate_outputs = model.generate(**inputs, return_segments=True, return_token_timestamps=True)
|
||||
|
||||
token_timestamps_shape = [
|
||||
[segment["token_timestamps"].shape for segment in segment_list]
|
||||
for segment_list in generate_outputs["segments"]
|
||||
]
|
||||
tokens_shape = [
|
||||
[segment["tokens"].shape for segment in segment_list] for segment_list in generate_outputs["segments"]
|
||||
]
|
||||
self.assertListEqual(tokens_shape, token_timestamps_shape)
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_OUTPUT = [
|
||||
torch.tensor([0.0000, 0.4200, 0.8200, 0.9400, 1.1200, 1.1200, 1.2200, 1.5000, 1.7200, 2.0400, 2.3400, 2.5200, 2.6600, 3.2000, 3.4400, 3.5600, 3.6800, 3.8200, 4.1000, 4.3000, 4.5800, 4.9400, 5.4000, 6.3600]),
|
||||
torch.tensor([ 6.5400, 6.5400, 6.7400, 6.9600, 7.2600, 7.3400, 7.5800, 7.5800, 7.6400, 7.8400, 8.1000, 8.5000, 9.0000, 9.4800, 9.7200, 10.2600, 11.1000]),
|
||||
torch.tensor([11.2200, 11.2200, 11.4200, 11.6600, 12.0800, 12.4400, 12.5800, 12.8400, 13.1800, 13.6800, 14.0000, 14.2200, 14.6200, 14.9800, 15.2200, 15.6000, 15.9400, 16.2000, 16.5600, 16.8400, 16.9800]),
|
||||
torch.tensor([16.9800, 16.9800, 17.3200, 18.1600, 18.6400, 18.8600, 19.2800, 19.5600, 19.8800, 20.1800, 20.3800, 20.7200, 21.1600, 21.5400, 21.9000, 22.2000, 22.4200, 22.8600, 23.7000]),
|
||||
torch.tensor([23.7000, 23.7000, 23.9400, 24.1800, 24.3800, 24.8400, 25.2800, 25.6600, 25.9200, 26.2600, 26.4000, 26.5800, 26.7600, 27.1400, 27.3800, 28.0400, 28.3800, 28.8200, 29.3400, 29.5200]),
|
||||
torch.tensor([29.4400, 29.4400, 29.7000, 30.0800, 30.3800, 30.5400, 30.8200, 31.0600, 31.6600, 31.9200, 32.3000, 32.4800, 32.6200, 33.6800]),
|
||||
torch.tensor([33.8000, 33.8000, 33.9800, 33.9800, 34.1800, 34.4400, 34.6200, 35.0000, 35.2200, 35.3200, 35.5600, 35.9200, 36.3800, 36.6200, 36.6600, 36.9600, 37.3400, 37.9800, 38.5800, 38.7200, 38.9800, 39.4400, 39.5800, 39.8000, 40.1200, 40.2600]),
|
||||
torch.tensor([40.5200, 40.5200, 40.6200, 41.1000, 41.5400, 41.9200, 42.1000, 42.3200, 42.3200, 43.0600, 44.6000]),
|
||||
torch.tensor([44.7000, 44.7000, 44.8600, 44.9400, 45.1400, 45.1400, 45.2800, 45.6200, 45.9000, 46.2600, 47.1600, 47.4800, 47.7400, 48.1000, 48.2800, 48.4000, 48.6200, 48.8400, 49.0400, 49.2800, 49.4800, 49.6600, 49.9400, 50.5400]),
|
||||
torch.tensor([50.5400, 50.5400, 50.6600, 50.8800, 51.2400, 51.7200, 52.8400]),
|
||||
torch.tensor([52.9600, 52.9600, 53.0400, 53.2600, 53.4200, 53.5800, 53.9200, 54.1200, 54.7200, 54.9400, 55.2600, 55.6200, 55.9800, 56.5600, 56.8000, 56.9200, 57.3600, 57.9200, 58.1800, 58.5000, 58.6400, 58.8200]),
|
||||
torch.tensor([58.6800, 58.6800, 59.1400, 59.5400, 59.9200, 60.1600, 60.3800, 60.8200, 61.6200, 62.2600, 75.2000]),
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
for segment, exp_segment in zip(generate_outputs["segments"][0], EXPECTED_OUTPUT):
|
||||
self.assertTrue(torch.allclose(segment["token_timestamps"], exp_segment))
|
||||
|
||||
@slow
|
||||
def test_tiny_specaugment_librispeech(self):
|
||||
torch_device = "cpu"
|
||||
|
|
|
@ -361,6 +361,70 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
|||
)
|
||||
# fmt: on
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
def test_return_timestamps_in_preprocess_longform(self):
|
||||
pipe = pipeline(
|
||||
task="automatic-speech-recognition",
|
||||
model="openai/whisper-tiny.en",
|
||||
)
|
||||
data = load_dataset("librispeech_asr", "clean", split="test", streaming=True)
|
||||
samples = [next(iter(data)) for _ in range(8)]
|
||||
audio = np.concatenate([sample["audio"]["array"] for sample in samples])
|
||||
|
||||
res = pipe(audio)
|
||||
expected_output = {
|
||||
"text": " Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst "
|
||||
"the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst "
|
||||
"the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst "
|
||||
"the tents. Concord returned to its place amidst the tents."
|
||||
}
|
||||
self.assertEqual(res, expected_output)
|
||||
res = pipe(audio, return_timestamps=True)
|
||||
self.assertEqual(
|
||||
res,
|
||||
{
|
||||
"text": " Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents.",
|
||||
"chunks": [
|
||||
{"timestamp": (0.0, 3.22), "text": " Concord returned to its place amidst the tents."},
|
||||
{"timestamp": (3.22, 6.74), "text": " Concord returned to its place amidst the tents."},
|
||||
{"timestamp": (6.74, 10.26), "text": " Concord returned to its place amidst the tents."},
|
||||
{"timestamp": (10.26, 13.78), "text": " Concord returned to its place amidst the tents."},
|
||||
{"timestamp": (13.78, 17.3), "text": " Concord returned to its place amidst the tents."},
|
||||
{"timestamp": (17.3, 20.82), "text": " Concord returned to its place amidst the tents."},
|
||||
{"timestamp": (20.82, 24.34), "text": " Concord returned to its place amidst the tents."},
|
||||
{"timestamp": (24.34, 27.86), "text": " Concord returned to its place amidst the tents."},
|
||||
],
|
||||
},
|
||||
)
|
||||
pipe.model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]
|
||||
res = pipe(audio, return_timestamps="word")
|
||||
|
||||
# fmt: off
|
||||
self.assertEqual(
|
||||
res["chunks"][:15],
|
||||
[
|
||||
{"text": " Concord", "timestamp": (0.5, 0.94)},
|
||||
{"text": " returned", "timestamp": (0.94, 1.52)},
|
||||
{"text": " to", "timestamp": (1.52, 1.78)},
|
||||
{"text": " its", "timestamp": (1.78, 1.98)},
|
||||
{"text": " place", "timestamp": (1.98, 2.16)},
|
||||
{"text": " amidst", "timestamp": (2.16, 2.5)},
|
||||
{"text": " the", "timestamp": (2.5, 2.9)},
|
||||
{"text": " tents.", "timestamp": (2.9, 4.2)},
|
||||
{"text": " Concord", "timestamp": (4.2, 4.5)},
|
||||
{"text": " returned", "timestamp": (4.5, 5.0)},
|
||||
{"text": " to", "timestamp": (5.0, 5.28)},
|
||||
{"text": " its", "timestamp": (5.28, 5.48)},
|
||||
{"text": " place", "timestamp": (5.48, 5.7)},
|
||||
{"text": " amidst", "timestamp": (5.7, 6.02)},
|
||||
{"text": " the", "timestamp": (6.02, 6.4)}
|
||||
|
||||
|
||||
],
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
@require_torch
|
||||
def test_return_timestamps_in_init(self):
|
||||
# segment-level timestamps are accepted
|
||||
|
|
Loading…
Reference in New Issue