add word-level timestamps to Whisper (#23205)

* let's go!

* initial implementation of token-level timestamps

* only return a single timestamp per token

* remove token probabilities

* fix return type

* fix doc comment

* strip special tokens

* rename

* revert to not stripping special tokens

* only support models that have alignment_heads

* add integration test

* consistently name it token-level timestamps

* small DTW tweak

* initial support for ASR pipeline

* fix pipeline doc comments

* resolve token timestamps in pipeline with chunking

* change warning when no final timestamp is found

* return word-level timestamps

* fixup

* fix bug that skipped final word in each chunk

* fix failing unit tests

* merge punctuations into the words

* also return word tokens

* also return token indices

* add (failing) unit test for combine_tokens_into_words

* make combine_tokens_into_words private

* restore OpenAI's punctuation rules

* add pipeline tests

* make requested changes

* PR review changes

* fix failing pipeline test

* small stuff from PR

* only return words and their timestamps, not segments

* move alignment_heads into generation config

* forgot to set alignment_heads in pipeline tests

* tiny comment fix

* grr
This commit is contained in:
Matthijs Hollemans 2023-06-21 17:48:21 +02:00 committed by GitHub
parent 0f968ddaa3
commit cd927a4736
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 456 additions and 25 deletions

View File

@ -171,7 +171,9 @@ class WhisperConfig(PretrainedConfig):
The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time
step, irrespectively of `mask_feature_prob`. Only relevant if
`mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks`.
median_filter_width (`int`, *optional*, defaults to 7):
Width of the median filter used to smoothen to cross-attention outputs when computing token timestamps.
Should be an odd number.
Example:
@ -229,6 +231,7 @@ class WhisperConfig(PretrainedConfig):
mask_feature_prob=0.0,
mask_feature_length=10,
mask_feature_min_masks=0,
median_filter_width=7,
**kwargs,
):
self.vocab_size = vocab_size
@ -265,6 +268,9 @@ class WhisperConfig(PretrainedConfig):
self.mask_feature_prob = mask_feature_prob
self.mask_feature_length = mask_feature_length
self.mask_feature_min_masks = mask_feature_min_masks
self.median_filter_width = median_filter_width
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,

View File

@ -227,6 +227,81 @@ def _compute_mask_indices(
return spec_aug_mask
def _median_filter(inputs: torch.Tensor, filter_width: int) -> torch.Tensor:
"""
Applies a median filter of width `filter_width` along the last dimension of the input.
The `inputs` tensor is assumed to be 3- or 4-dimensional.
"""
if filter_width <= 0 or filter_width % 2 != 1:
raise ValueError("`filter_width` should be an odd number")
pad_width = filter_width // 2
if inputs.shape[-1] <= pad_width:
return inputs
# Pad the left and right edges.
inputs = nn.functional.pad(inputs, (pad_width, pad_width, 0, 0), mode="reflect")
# sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
result = inputs.unfold(-1, filter_width, 1).sort()[0][..., pad_width]
return result
def _dynamic_time_warping(matrix: np.ndarray):
"""
Measures similarity between two temporal sequences: the input audio and the output tokens. Used to generate
token-level timestamps.
"""
output_length, input_length = matrix.shape
cost = np.ones((output_length + 1, input_length + 1), dtype=np.float32) * np.inf
trace = -np.ones((output_length + 1, input_length + 1), dtype=np.float32)
cost[0, 0] = 0
for j in range(1, input_length + 1):
for i in range(1, output_length + 1):
c0 = cost[i - 1, j - 1]
c1 = cost[i - 1, j]
c2 = cost[i, j - 1]
if c0 < c1 and c0 < c2:
c, t = c0, 0
elif c1 < c0 and c1 < c2:
c, t = c1, 1
else:
c, t = c2, 2
cost[i, j] = matrix[i - 1, j - 1] + c
trace[i, j] = t
# backtrace
i = trace.shape[0] - 1
j = trace.shape[1] - 1
trace[0, :] = 2
trace[:, 0] = 1
text_indices = []
time_indices = []
while i > 0 or j > 0:
text_indices.append(i - 1)
time_indices.append(j - 1)
if trace[i, j] == 0:
i -= 1
j -= 1
elif trace[i, j] == 1:
i -= 1
elif trace[i, j] == 2:
j -= 1
else:
raise RuntimeError(
f"Internal error in dynamic time warping. Unexpected trace[{i}, {j}]. Please file a bug report."
)
text_indices = np.array(text_indices)[::-1]
time_indices = np.array(time_indices)[::-1]
return text_indices, time_indices
class WhisperPositionalEmbedding(nn.Embedding):
def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
super().__init__(num_positions, embedding_dim)
@ -1472,6 +1547,7 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
language=None,
is_multilingual=None,
prompt_ids: Optional[torch.Tensor] = None,
return_token_timestamps=None,
**kwargs,
):
"""
@ -1534,6 +1610,10 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
provided as a prompt to each chunk. This can be used to provide or "prompt-engineer" a context for
transcription, e.g. custom vocabularies or proper nouns to make it more likely to predict those words
correctly. It cannot be used in conjunction with `decoder_start_token_id` as it overwrites this value.
return_token_timestamps (`bool`, *optional*):
Whether to return token-level timestamps with the text. This can be used with or without the
`return_timestamps` option. To get word-level timestamps, use the tokenizer to group the tokens into
words.
kwargs:
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
@ -1662,7 +1742,19 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
if generation_config.return_timestamps:
logits_processor = [WhisperTimeStampLogitsProcessor(generation_config)]
return super().generate(
if return_token_timestamps:
kwargs["output_attentions"] = True
kwargs["return_dict_in_generate"] = True
if getattr(generation_config, "task", None) == "translate":
logger.warning("Token-level timestamps may not be reliable for task 'translate'.")
if not hasattr(generation_config, "alignment_heads"):
raise ValueError(
"Model generation config has no `alignment_heads`, token-level timestamps not available. "
"See https://gist.github.com/hollance/42e32852f24243b748ae6bc1f985b13a on how to add this property to the generation config."
)
outputs = super().generate(
inputs,
generation_config,
logits_processor,
@ -1672,6 +1764,11 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
**kwargs,
)
if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
outputs["token_timestamps"] = self._extract_token_timestamps(outputs, generation_config.alignment_heads)
return outputs
def prepare_inputs_for_generation(
self,
decoder_input_ids,
@ -1693,7 +1790,6 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
"decoder_attention_mask": None,
}
#
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
@ -1701,6 +1797,44 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
return reordered_past
def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_precision=0.02):
"""
Calculates token-level timestamps using the encoder-decoder cross-attentions and dynamic time-warping (DTW) to
map each output token to a position in the input audio.
Returns:
tensor containing the timestamps in seconds for each predicted token
"""
# Create a list with `decoder_layers` elements, each a tensor of shape
# (batch size, attention_heads, output length, input length).
cross_attentions = []
for i in range(self.config.decoder_layers):
cross_attentions.append(torch.cat([x[i] for x in generate_outputs.cross_attentions], dim=2))
# Select specific cross-attention layers and heads. This is a tensor
# of shape (batch size, num selected, output length, input length).
weights = torch.stack([cross_attentions[l][:, h] for l, h in alignment_heads])
weights = weights.permute([1, 0, 2, 3])
# Normalize and smoothen the weights.
std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
weights = (weights - mean) / std
weights = _median_filter(weights, self.config.median_filter_width)
# Average the different cross-attention heads.
matrix = weights.mean(dim=1)
timestamps = torch.zeros_like(generate_outputs.sequences, dtype=torch.float32)
# Perform dynamic time warping on each element of the batch.
for batch_idx in range(timestamps.shape[0]):
text_indices, time_indices = _dynamic_time_warping(-matrix[batch_idx].double().cpu().numpy())
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
jump_times = time_indices[jumps] * time_precision
timestamps[batch_idx, 1:] = torch.tensor(jump_times)
return timestamps
@add_start_docstrings(
"""

View File

@ -585,7 +585,7 @@ class WhisperTokenizer(PreTrainedTokenizer):
Whether or not to output the offsets of the tokens. This should only be set if the model predicted
timestamps.
decode_with_timestamps (`bool`, *optional*, defaults to `False`):
WHether or not to decode with timestamps included in the raw text.
Whether or not to decode with timestamps included in the raw text.
Returns:
`str`: The decoded sentence.
"""
@ -779,6 +779,7 @@ def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language,
time_offset = 0.0
timestamp_begin = tokenizer.convert_tokens_to_ids("<|notimestamps|>") + 1
previous_tokens = []
previous_token_timestamps = []
skip = False
right_stride_start = None
@ -788,6 +789,8 @@ def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language,
# We can drop everything to Python list, it's going to make
# our lives easier
token_ids = output["tokens"][0].tolist()
if return_timestamps == "word":
token_timestamps = output["token_timestamps"][0].tolist()
# Those keep track of timestamps within strides
# Which need to be skipped and resolve all tokens in a single
@ -820,6 +823,7 @@ def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language,
last_timestamp = token
current_tokens = []
current_token_timestamps = []
# - all tokens within output
for i, token in enumerate(token_ids):
@ -883,20 +887,37 @@ def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language,
chunk["timestamp"][1] = time
# Handling merges.
previous_tokens.append(current_tokens)
resolved_tokens = _find_longest_common_sequence(previous_tokens)
if return_timestamps == "word":
previous_token_timestamps.append(current_token_timestamps)
resolved_tokens, resolved_token_timestamps = _find_longest_common_sequence(
previous_tokens, previous_token_timestamps
)
resolved_text = tokenizer.decode(resolved_tokens)
chunk["text"] = resolved_text
if return_timestamps == "word":
chunk["words"] = _collate_word_timestamps(
tokenizer, resolved_tokens, resolved_token_timestamps, last_language
)
chunks.append(chunk)
# Flush all our temporary context
previous_tokens = []
current_tokens = []
previous_token_timestamps = []
current_token_timestamps = []
chunk = new_chunk()
else:
# 4/ Regular token
# We just append to the list of all tokens so we can handle
# merges later and decode into text.
current_tokens.append(token)
if return_timestamps == "word":
start_time = round(token_timestamps[i] + time_offset, 2)
if i + 1 < len(token_timestamps):
end_time = round(token_timestamps[i + 1] + time_offset, 2)
else:
end_time = None # should never happen
current_token_timestamps.append((start_time, end_time))
if "stride" in output:
time_offset += chunk_len - stride_right
@ -904,21 +925,31 @@ def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language,
# Leftover tokens
if current_tokens:
previous_tokens.append(current_tokens)
if return_timestamps == "word":
previous_token_timestamps.append(current_token_timestamps)
elif not (any(p for p in previous_tokens)):
chunk = new_chunk()
previous_tokens = []
current_tokens = []
previous_token_timestamps = []
current_token_timestamps = []
if previous_tokens:
if return_timestamps:
logger.warning(
"There was an error while processing timestamps, we haven't found a timestamp as last token. Was"
" WhisperTimeStampLogitsProcessor used?"
"Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. "
"Also make sure WhisperTimeStampLogitsProcessor was used during generation."
)
# Happens when we don't use timestamps
resolved_tokens = _find_longest_common_sequence(previous_tokens)
resolved_tokens, resolved_token_timestamps = _find_longest_common_sequence(
previous_tokens, previous_token_timestamps
)
resolved_text = tokenizer.decode(resolved_tokens)
chunk["text"] = resolved_text
if return_timestamps == "word":
chunk["words"] = _collate_word_timestamps(
tokenizer, resolved_tokens, resolved_token_timestamps, last_language
)
chunks.append(chunk)
# Preparing and cleaning up the pipeline output
@ -931,20 +962,35 @@ def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language,
chunk["timestamp"] = tuple(chunk["timestamp"])
if not return_language:
chunk.pop("language")
optional = {"chunks": chunks}
if return_timestamps == "word":
new_chunks = []
for chunk in chunks:
new_chunks.extend(chunk["words"])
optional = {"chunks": new_chunks}
else:
optional = {"chunks": chunks}
else:
optional = {}
return full_text, optional
def _find_longest_common_sequence(sequences):
def _find_longest_common_sequence(sequences, token_timestamp_sequences=None):
# It would be much harder to do O(n) because of fault tolerance.
# We actually have a really good property which is that the total sequence
# MUST be those subsequences in order.
# If token_timestamp_sequences is provided, will split those sequences in
# exactly the same way.
left_sequence = sequences[0]
left_length = len(left_sequence)
total_sequence = []
for right_sequence in sequences[1:]:
if token_timestamp_sequences:
left_token_timestamp_sequence = token_timestamp_sequences[0]
total_token_timestamp_sequence = []
for seq_idx, right_sequence in enumerate(sequences[1:]):
# index = 0
max_ = 0.0
max_indices = (left_length, left_length, 0, 0)
@ -1018,6 +1064,148 @@ def _find_longest_common_sequence(sequences):
left_sequence = right_sequence[right_mid:]
left_length = len(left_sequence)
if token_timestamp_sequences:
total_token_timestamp_sequence.extend(left_token_timestamp_sequence[:left_mid])
left_token_timestamp_sequence = token_timestamp_sequences[seq_idx + 1][right_mid:]
total_sequence.extend(left_sequence)
return total_sequence
if token_timestamp_sequences is None:
return total_sequence
if len(token_timestamp_sequences) > 0:
total_token_timestamp_sequence.extend(left_token_timestamp_sequence)
return total_sequence, total_token_timestamp_sequence
else:
return total_sequence, []
def _collate_word_timestamps(tokenizer, tokens, token_timestamps, language):
words, _, token_indices = _combine_tokens_into_words(tokenizer, tokens, language)
timings = [
{
"text": word,
"timestamp": (token_timestamps[indices[0]][0], token_timestamps[indices[-1]][1]),
}
for word, indices in zip(words, token_indices)
]
return timings
def _combine_tokens_into_words(
tokenizer,
tokens: List[int],
language: str = None,
prepend_punctuations: str = "\"'“¡¿([{-",
append_punctuations: str = "\"'.。,!?::”)]}、",
):
"""
Groups tokens by word. Returns a tuple containing a list of strings with the words, and a list of `token_id`
sequences with the tokens making up each word.
"""
if language is None:
language = tokenizer.language
if language is None:
language = "english"
if language in {"chinese", "japanese", "thai", "lao", "myanmar"}:
# These languages don't typically use spaces.
words, word_tokens, token_indices = _split_tokens_on_unicode(tokenizer, tokens)
else:
words, word_tokens, token_indices = _split_tokens_on_spaces(tokenizer, tokens)
_merge_punctuations(words, word_tokens, token_indices, prepend_punctuations, append_punctuations)
return words, word_tokens, token_indices
def _split_tokens_on_unicode(tokenizer, tokens: List[int]):
"""Combine tokens into words by splitting at any position where the tokens are decoded as valid unicode points."""
decoded_full = tokenizer.decode(tokens, decode_with_timestamps=True)
replacement_char = "\ufffd"
words = []
word_tokens = []
token_indices = []
current_tokens = []
current_indices = []
unicode_offset = 0
for token_idx, token in enumerate(tokens):
current_tokens.append(token)
current_indices.append(token_idx)
decoded = tokenizer.decode(current_tokens, decode_with_timestamps=True)
if (
replacement_char not in decoded
or decoded_full[unicode_offset + decoded.index(replacement_char)] == replacement_char
):
words.append(decoded)
word_tokens.append(current_tokens)
token_indices.append(current_indices)
current_tokens = []
current_indices = []
unicode_offset += len(decoded)
return words, word_tokens, token_indices
def _split_tokens_on_spaces(tokenizer, tokens: List[int]):
"""Combine tokens into words by splitting at whitespace and punctuation tokens."""
subwords, subword_tokens_list, subword_indices_list = _split_tokens_on_unicode(tokenizer, tokens)
words = []
word_tokens = []
token_indices = []
for subword, subword_tokens, subword_indices in zip(subwords, subword_tokens_list, subword_indices_list):
special = subword_tokens[0] >= tokenizer.eos_token_id
with_space = subword.startswith(" ")
punctuation = subword.strip() in "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
if special or with_space or punctuation or len(words) == 0:
words.append(subword)
word_tokens.append(subword_tokens)
token_indices.append(subword_indices)
else:
words[-1] = words[-1] + subword
word_tokens[-1].extend(subword_tokens)
token_indices[-1].extend(subword_indices)
return words, word_tokens, token_indices
def _merge_punctuations(words, tokens, indices, prepended, appended):
"""Merges punctuation tokens with neighboring words."""
# prepend punctuations
i = len(words) - 2
j = len(words) - 1
while i >= 0:
if words[i].startswith(" ") and words[i].strip() in prepended:
words[j] = words[i] + words[j]
tokens[j] = tokens[i] + tokens[j]
indices[j] = indices[i] + indices[j]
words[i] = ""
tokens[i] = []
indices[i] = []
else:
j = i
i -= 1
# append punctuations
i = 0
j = 1
while j < len(words):
if not words[i].endswith(" ") and words[j] in appended:
words[i] += words[j]
tokens[i] += tokens[j]
indices[i] += indices[j]
words[j] = ""
tokens[j] = []
indices[j] = []
else:
i = j
j += 1
# remove elements that are now empty
words[:] = [word for word in words if word]
tokens[:] = [token for token in tokens if token]
indices[:] = [idx for idx in indices if idx]

View File

@ -295,7 +295,7 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
Whether or not to output the offsets of the tokens. This should only be set if the model predicted
timestamps.
decode_with_timestamps (`bool`, *optional*, defaults to `False`):
WHether or not to decode with timestamps included in the raw text.
Whether or not to decode with timestamps included in the raw text.
Returns:
`str`: The decoded sentence.
"""

View File

@ -246,12 +246,12 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
treat the first `left` samples and last `right` samples to be ignored in decoding (but used at
inference to provide more context to the model). Only use `stride` with CTC models.
return_timestamps (*optional*, `str`):
Only available for pure CTC models. If set to `"char"`, the pipeline will return `timestamps` along the
text for every character in the text. For instance if you get `[{"text": "h", "timestamps": (0.5,0.6),
{"text": "i", "timestamps": (0.7, .9)}]`, then it means the model predicts that the letter "h" was
Only available for pure CTC models. If set to `"char"`, the pipeline will return timestamps along the
text for every character in the text. For instance if you get `[{"text": "h", "timestamp": (0.5, 0.6)},
{"text": "i", "timestamp": (0.7, 0.9)}]`, then it means the model predicts that the letter "h" was
pronounced after `0.5` and before `0.6` seconds. If set to `"word"`, the pipeline will return
`timestamps` along the text for every word in the text. For instance if you get `[{"text": "hi ",
"timestamps": (0.5,0.9), {"text": "there", "timestamps": (1.0, .1.5)}]`, then it means the model
timestamps along the text for every word in the text. For instance if you get `[{"text": "hi ",
"timestamp": (0.5, 0.9)}, {"text": "there", "timestamp": (1.0, 1.5)}]`, then it means the model
predicts that the word "hi" was pronounced after `0.5` and before `0.9` seconds.
generate_kwargs (`dict`, *optional*):
The dictionary of ad-hoc parametrization of `generate_config` to be used for the generation call. For a
@ -265,8 +265,8 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
- **text** (`str` ) -- The recognized text.
- **chunks** (*optional(, `List[Dict]`)
When using `return_timestamps`, the `chunks` will become a list containing all the various text
chunks identified by the model, *e.g.* `[{"text": "hi ", "timestamps": (0.5,0.9), {"text":
"there", "timestamps": (1.0, 1.5)}]`. The original full text can roughly be recovered by doing
chunks identified by the model, *e.g.* `[{"text": "hi ", "timestamp": (0.5, 0.9)}, {"text":
"there", "timestamp": (1.0, 1.5)}]`. The original full text can roughly be recovered by doing
`"".join(chunk["text"] for chunk in output["chunks"])`.
"""
return super().__call__(inputs, **kwargs)
@ -421,6 +421,8 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
generate_kwargs = {}
if return_timestamps and self.type == "seq2seq_whisper":
generate_kwargs["return_timestamps"] = return_timestamps
if return_timestamps == "word":
generate_kwargs["return_token_timestamps"] = True
is_last = model_inputs.pop("is_last")
if self.type in {"seq2seq", "seq2seq_whisper"}:
@ -447,7 +449,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
attention_mask=attention_mask,
**generate_kwargs,
)
out = {"tokens": tokens}
if return_timestamps == "word" and self.type == "seq2seq_whisper":
out = {"tokens": tokens["sequences"], "token_timestamps": tokens["token_timestamps"]}
else:
out = {"tokens": tokens}
if self.type == "seq2seq_whisper":
stride = model_inputs.pop("stride", None)
if stride is not None:
@ -486,9 +491,9 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
if return_timestamps and self.type == "seq2seq":
raise ValueError("We cannot return_timestamps yet on non-ctc models apart from Whisper !")
if return_timestamps == "char" and self.type == "ctc_with_lm":
raise ValueError("CTC with LM cannot return `char` timestamps, only `words`")
if return_timestamps in {"char", "words"} and self.type == "seq2seq_whisper":
raise ValueError("Whisper cannot return `char` nor `words` timestamps, use `True` instead.")
raise ValueError("CTC with LM cannot return `char` timestamps, only `word`")
if return_timestamps == "char" and self.type == "seq2seq_whisper":
raise ValueError("Whisper cannot return `char` timestamps, use `True` or `word` instead.")
if return_language is not None and self.type != "seq2seq_whisper":
raise ValueError("Only whisper can return language for now.")
@ -574,6 +579,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
output.pop("logits", None)
output.pop("is_last", None)
output.pop("stride", None)
output.pop("token_timestamps", None)
for k, v in output.items():
extra[k].append(v)
return {"text": text, **optional, **extra}

View File

@ -1436,6 +1436,35 @@ class WhisperModelIntegrationTests(unittest.TestCase):
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
@slow
def test_tiny_token_timestamp_generation(self):
set_seed(0)
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
model.to(torch_device)
input_speech = self._load_datasamples(4)
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to(
torch_device
)
generate_outputs = model.generate(
input_features, max_length=448, return_timestamps=True, return_token_timestamps=True
)
self.assertEqual(generate_outputs.sequences.shape, generate_outputs.token_timestamps.shape)
# fmt: off
EXPECTED_OUTPUT = torch.tensor([
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.4800, 0.8200, 0.9600, 1.1200, 1.1200, 1.2200, 1.5000, 1.7200, 2.0000, 2.3400, 2.5000, 2.6600, 3.1800, 3.5600, 3.6800, 3.8000, 4.1000, 4.3000, 4.5800, 4.9400, 5.3800, 12.4200, 12.8400, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9400, 26.9400, 26.9400, 26.9400, 29.8400 ],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.5200, 0.9000, 1.1400, 1.4200, 1.5200, 1.6800, 1.6800, 1.8800, 2.1000, 2.2200, 2.6200, 3.1400, 3.5800, 3.9600, 4.4000, 17.3000, 17.3000, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7400, 26.7400, 26.7400, 26.7400, 26.7400, 26.7400, 28.0000 ],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7600, 1.0000, 1.4200, 1.8000, 1.9400, 2.1800, 2.5200, 3.0200, 3.3200, 3.5400, 3.9400, 4.5600, 4.9200, 5.2800, 5.5600, 5.9000, 6.1600, 6.3000, 6.4800, 6.4800, 6.6400, 7.8200, 7.9600, 8.2200, 8.6000, 8.9200, 9.2200, 9.5200, 9.7200, 10.0600, 10.5400, 10.8800, 11.2600, 11.5400, 11.7400, 12.0800, 15.6800, 15.6800],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7400, 1.0400, 1.3200, 1.6800, 2.1400, 2.4800, 2.7800, 3.0800, 3.1600, 3.4000, 3.6000, 4.0200, 4.2200, 4.8600, 5.2400, 5.7400, 6.3400, 6.6200, 6.7600, 6.7600, 6.8600, 7.2400, 7.4200, 7.6800, 7.9200, 8.4800, 8.7600, 9.2000, 9.2000, 9.4200, 15.8200, 15.8200, 29.6400, 29.6600, 29.6600, 29.6600, 29.6600, 29.7600]
])
# fmt: on
self.assertTrue(torch.allclose(generate_outputs.token_timestamps.to("cpu"), EXPECTED_OUTPUT))
@slow
def test_tiny_specaugment_librispeech(self):
torch_device = "cpu"

View File

@ -15,7 +15,7 @@
import unittest
from transformers.models.whisper import WhisperTokenizer, WhisperTokenizerFast
from transformers.models.whisper.tokenization_whisper import _find_longest_common_sequence
from transformers.models.whisper.tokenization_whisper import _combine_tokens_into_words, _find_longest_common_sequence
from transformers.testing_utils import slow
from ...test_tokenization_common import TokenizerTesterMixin
@ -255,6 +255,24 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
self.assertListEqual(tokenizer_prompt_ids.tolist(), fast_tokenizer_prompt_ids.tolist())
def test_combine_tokens_into_words(self):
tokenizer = self.get_tokenizer()
rust_tokenizer = self.get_rust_tokenizer()
# 'whatever "whatever" said someone, clever!?'
encoded_input = [1363, 7969, 503, 1363, 7969, 1, 848, 1580, 11, 13494, 7323]
expected_words = ["whatever", ' "whatever"', " said", " someone,", " clever!?"]
expected_tokens = [[1363, 7969], [503, 1363, 7969, 1], [848], [1580, 11], [13494, 7323]]
expected_indices = [[0, 1], [2, 3, 4, 5], [6], [7, 8], [9, 10]]
output = _combine_tokens_into_words(tokenizer, encoded_input)
self.assertEqual(expected_words, output[0])
self.assertEqual(expected_tokens, output[1])
self.assertEqual(expected_indices, output[2])
output_rust = _combine_tokens_into_words(rust_tokenizer, encoded_input)
self.assertEqual(expected_words, output_rust[0])
self.assertEqual(expected_tokens, output_rust[1])
self.assertEqual(expected_indices, output_rust[2])
class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
checkpoint_name = "openai/whisper-small.en"

View File

@ -316,6 +316,27 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
"chunks": [{"text": " Conquered returned to its place amidst the tents.", "timestamp": (0.0, 3.36)}],
},
)
pipe.model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]
res = pipe(sample["audio"]["array"], return_timestamps="word")
# fmt: off
# Note that the word-level timestamps predicted here are pretty bad.
self.assertEqual(
res,
{
"text": " Conquered returned to its place amidst the tents.",
"chunks": [
{'text': ' Conquered', 'timestamp': (29.78, 29.9)},
{'text': ' returned', 'timestamp': (29.9, 29.9)},
{'text': ' to', 'timestamp': (29.9, 29.9)},
{'text': ' its', 'timestamp': (29.9, 29.9)},
{'text': ' place', 'timestamp': (29.9, 29.9)},
{'text': ' amidst', 'timestamp': (29.9, 29.9)},
{'text': ' the', 'timestamp': (29.9, 29.9)},
{'text': ' tents.', 'timestamp': (29.9, 29.9)}
]
}
)
# fmt: on
@require_torch
@slow
@ -699,6 +720,35 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
],
},
)
speech_recognizer.model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]
output = speech_recognizer(filename, return_timestamps="word")
# fmt: off
self.assertEqual(
output,
{
"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.",
"chunks": [
{'text': ' Mr.', 'timestamp': (0.0, 1.02)},
{'text': ' Quilter', 'timestamp': (1.02, 1.18)},
{'text': ' is', 'timestamp': (1.18, 1.44)},
{'text': ' the', 'timestamp': (1.44, 1.58)},
{'text': ' apostle', 'timestamp': (1.58, 1.98)},
{'text': ' of', 'timestamp': (1.98, 2.3)},
{'text': ' the', 'timestamp': (2.3, 2.46)},
{'text': ' middle', 'timestamp': (2.46, 2.56)},
{'text': ' classes,', 'timestamp': (2.56, 3.38)},
{'text': ' and', 'timestamp': (3.38, 3.52)},
{'text': ' we', 'timestamp': (3.52, 3.6)},
{'text': ' are', 'timestamp': (3.6, 3.72)},
{'text': ' glad', 'timestamp': (3.72, 4.0)},
{'text': ' to', 'timestamp': (4.0, 4.26)},
{'text': ' welcome', 'timestamp': (4.26, 4.54)},
{'text': ' his', 'timestamp': (4.54, 4.92)},
{'text': ' gospel.', 'timestamp': (4.92, 6.66)},
],
},
)
# fmt: on
@slow
@require_torch