add word-level timestamps to Whisper (#23205)
* let's go! * initial implementation of token-level timestamps * only return a single timestamp per token * remove token probabilities * fix return type * fix doc comment * strip special tokens * rename * revert to not stripping special tokens * only support models that have alignment_heads * add integration test * consistently name it token-level timestamps * small DTW tweak * initial support for ASR pipeline * fix pipeline doc comments * resolve token timestamps in pipeline with chunking * change warning when no final timestamp is found * return word-level timestamps * fixup * fix bug that skipped final word in each chunk * fix failing unit tests * merge punctuations into the words * also return word tokens * also return token indices * add (failing) unit test for combine_tokens_into_words * make combine_tokens_into_words private * restore OpenAI's punctuation rules * add pipeline tests * make requested changes * PR review changes * fix failing pipeline test * small stuff from PR * only return words and their timestamps, not segments * move alignment_heads into generation config * forgot to set alignment_heads in pipeline tests * tiny comment fix * grr
This commit is contained in:
parent
0f968ddaa3
commit
cd927a4736
|
@ -171,7 +171,9 @@ class WhisperConfig(PretrainedConfig):
|
|||
The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time
|
||||
step, irrespectively of `mask_feature_prob`. Only relevant if
|
||||
`mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks`.
|
||||
|
||||
median_filter_width (`int`, *optional*, defaults to 7):
|
||||
Width of the median filter used to smoothen to cross-attention outputs when computing token timestamps.
|
||||
Should be an odd number.
|
||||
|
||||
Example:
|
||||
|
||||
|
@ -229,6 +231,7 @@ class WhisperConfig(PretrainedConfig):
|
|||
mask_feature_prob=0.0,
|
||||
mask_feature_length=10,
|
||||
mask_feature_min_masks=0,
|
||||
median_filter_width=7,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
|
@ -265,6 +268,9 @@ class WhisperConfig(PretrainedConfig):
|
|||
self.mask_feature_prob = mask_feature_prob
|
||||
self.mask_feature_length = mask_feature_length
|
||||
self.mask_feature_min_masks = mask_feature_min_masks
|
||||
|
||||
self.median_filter_width = median_filter_width
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
|
|
|
@ -227,6 +227,81 @@ def _compute_mask_indices(
|
|||
return spec_aug_mask
|
||||
|
||||
|
||||
def _median_filter(inputs: torch.Tensor, filter_width: int) -> torch.Tensor:
|
||||
"""
|
||||
Applies a median filter of width `filter_width` along the last dimension of the input.
|
||||
|
||||
The `inputs` tensor is assumed to be 3- or 4-dimensional.
|
||||
"""
|
||||
if filter_width <= 0 or filter_width % 2 != 1:
|
||||
raise ValueError("`filter_width` should be an odd number")
|
||||
|
||||
pad_width = filter_width // 2
|
||||
if inputs.shape[-1] <= pad_width:
|
||||
return inputs
|
||||
|
||||
# Pad the left and right edges.
|
||||
inputs = nn.functional.pad(inputs, (pad_width, pad_width, 0, 0), mode="reflect")
|
||||
|
||||
# sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
|
||||
result = inputs.unfold(-1, filter_width, 1).sort()[0][..., pad_width]
|
||||
return result
|
||||
|
||||
|
||||
def _dynamic_time_warping(matrix: np.ndarray):
|
||||
"""
|
||||
Measures similarity between two temporal sequences: the input audio and the output tokens. Used to generate
|
||||
token-level timestamps.
|
||||
"""
|
||||
output_length, input_length = matrix.shape
|
||||
cost = np.ones((output_length + 1, input_length + 1), dtype=np.float32) * np.inf
|
||||
trace = -np.ones((output_length + 1, input_length + 1), dtype=np.float32)
|
||||
|
||||
cost[0, 0] = 0
|
||||
for j in range(1, input_length + 1):
|
||||
for i in range(1, output_length + 1):
|
||||
c0 = cost[i - 1, j - 1]
|
||||
c1 = cost[i - 1, j]
|
||||
c2 = cost[i, j - 1]
|
||||
|
||||
if c0 < c1 and c0 < c2:
|
||||
c, t = c0, 0
|
||||
elif c1 < c0 and c1 < c2:
|
||||
c, t = c1, 1
|
||||
else:
|
||||
c, t = c2, 2
|
||||
|
||||
cost[i, j] = matrix[i - 1, j - 1] + c
|
||||
trace[i, j] = t
|
||||
|
||||
# backtrace
|
||||
i = trace.shape[0] - 1
|
||||
j = trace.shape[1] - 1
|
||||
trace[0, :] = 2
|
||||
trace[:, 0] = 1
|
||||
|
||||
text_indices = []
|
||||
time_indices = []
|
||||
while i > 0 or j > 0:
|
||||
text_indices.append(i - 1)
|
||||
time_indices.append(j - 1)
|
||||
if trace[i, j] == 0:
|
||||
i -= 1
|
||||
j -= 1
|
||||
elif trace[i, j] == 1:
|
||||
i -= 1
|
||||
elif trace[i, j] == 2:
|
||||
j -= 1
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Internal error in dynamic time warping. Unexpected trace[{i}, {j}]. Please file a bug report."
|
||||
)
|
||||
|
||||
text_indices = np.array(text_indices)[::-1]
|
||||
time_indices = np.array(time_indices)[::-1]
|
||||
return text_indices, time_indices
|
||||
|
||||
|
||||
class WhisperPositionalEmbedding(nn.Embedding):
|
||||
def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
|
||||
super().__init__(num_positions, embedding_dim)
|
||||
|
@ -1472,6 +1547,7 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
|||
language=None,
|
||||
is_multilingual=None,
|
||||
prompt_ids: Optional[torch.Tensor] = None,
|
||||
return_token_timestamps=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
|
@ -1534,6 +1610,10 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
|||
provided as a prompt to each chunk. This can be used to provide or "prompt-engineer" a context for
|
||||
transcription, e.g. custom vocabularies or proper nouns to make it more likely to predict those words
|
||||
correctly. It cannot be used in conjunction with `decoder_start_token_id` as it overwrites this value.
|
||||
return_token_timestamps (`bool`, *optional*):
|
||||
Whether to return token-level timestamps with the text. This can be used with or without the
|
||||
`return_timestamps` option. To get word-level timestamps, use the tokenizer to group the tokens into
|
||||
words.
|
||||
kwargs:
|
||||
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
|
||||
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
|
||||
|
@ -1662,7 +1742,19 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
|||
if generation_config.return_timestamps:
|
||||
logits_processor = [WhisperTimeStampLogitsProcessor(generation_config)]
|
||||
|
||||
return super().generate(
|
||||
if return_token_timestamps:
|
||||
kwargs["output_attentions"] = True
|
||||
kwargs["return_dict_in_generate"] = True
|
||||
|
||||
if getattr(generation_config, "task", None) == "translate":
|
||||
logger.warning("Token-level timestamps may not be reliable for task 'translate'.")
|
||||
if not hasattr(generation_config, "alignment_heads"):
|
||||
raise ValueError(
|
||||
"Model generation config has no `alignment_heads`, token-level timestamps not available. "
|
||||
"See https://gist.github.com/hollance/42e32852f24243b748ae6bc1f985b13a on how to add this property to the generation config."
|
||||
)
|
||||
|
||||
outputs = super().generate(
|
||||
inputs,
|
||||
generation_config,
|
||||
logits_processor,
|
||||
|
@ -1672,6 +1764,11 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
|||
**kwargs,
|
||||
)
|
||||
|
||||
if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
|
||||
outputs["token_timestamps"] = self._extract_token_timestamps(outputs, generation_config.alignment_heads)
|
||||
|
||||
return outputs
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
decoder_input_ids,
|
||||
|
@ -1693,7 +1790,6 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
|||
"decoder_attention_mask": None,
|
||||
}
|
||||
|
||||
#
|
||||
@staticmethod
|
||||
def _reorder_cache(past_key_values, beam_idx):
|
||||
reordered_past = ()
|
||||
|
@ -1701,6 +1797,44 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
|||
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
||||
return reordered_past
|
||||
|
||||
def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_precision=0.02):
|
||||
"""
|
||||
Calculates token-level timestamps using the encoder-decoder cross-attentions and dynamic time-warping (DTW) to
|
||||
map each output token to a position in the input audio.
|
||||
|
||||
Returns:
|
||||
tensor containing the timestamps in seconds for each predicted token
|
||||
"""
|
||||
# Create a list with `decoder_layers` elements, each a tensor of shape
|
||||
# (batch size, attention_heads, output length, input length).
|
||||
cross_attentions = []
|
||||
for i in range(self.config.decoder_layers):
|
||||
cross_attentions.append(torch.cat([x[i] for x in generate_outputs.cross_attentions], dim=2))
|
||||
|
||||
# Select specific cross-attention layers and heads. This is a tensor
|
||||
# of shape (batch size, num selected, output length, input length).
|
||||
weights = torch.stack([cross_attentions[l][:, h] for l, h in alignment_heads])
|
||||
weights = weights.permute([1, 0, 2, 3])
|
||||
|
||||
# Normalize and smoothen the weights.
|
||||
std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
|
||||
weights = (weights - mean) / std
|
||||
weights = _median_filter(weights, self.config.median_filter_width)
|
||||
|
||||
# Average the different cross-attention heads.
|
||||
matrix = weights.mean(dim=1)
|
||||
|
||||
timestamps = torch.zeros_like(generate_outputs.sequences, dtype=torch.float32)
|
||||
|
||||
# Perform dynamic time warping on each element of the batch.
|
||||
for batch_idx in range(timestamps.shape[0]):
|
||||
text_indices, time_indices = _dynamic_time_warping(-matrix[batch_idx].double().cpu().numpy())
|
||||
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
|
||||
jump_times = time_indices[jumps] * time_precision
|
||||
timestamps[batch_idx, 1:] = torch.tensor(jump_times)
|
||||
|
||||
return timestamps
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
|
|
|
@ -585,7 +585,7 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
|||
Whether or not to output the offsets of the tokens. This should only be set if the model predicted
|
||||
timestamps.
|
||||
decode_with_timestamps (`bool`, *optional*, defaults to `False`):
|
||||
WHether or not to decode with timestamps included in the raw text.
|
||||
Whether or not to decode with timestamps included in the raw text.
|
||||
Returns:
|
||||
`str`: The decoded sentence.
|
||||
"""
|
||||
|
@ -779,6 +779,7 @@ def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language,
|
|||
time_offset = 0.0
|
||||
timestamp_begin = tokenizer.convert_tokens_to_ids("<|notimestamps|>") + 1
|
||||
previous_tokens = []
|
||||
previous_token_timestamps = []
|
||||
skip = False
|
||||
right_stride_start = None
|
||||
|
||||
|
@ -788,6 +789,8 @@ def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language,
|
|||
# We can drop everything to Python list, it's going to make
|
||||
# our lives easier
|
||||
token_ids = output["tokens"][0].tolist()
|
||||
if return_timestamps == "word":
|
||||
token_timestamps = output["token_timestamps"][0].tolist()
|
||||
|
||||
# Those keep track of timestamps within strides
|
||||
# Which need to be skipped and resolve all tokens in a single
|
||||
|
@ -820,6 +823,7 @@ def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language,
|
|||
last_timestamp = token
|
||||
|
||||
current_tokens = []
|
||||
current_token_timestamps = []
|
||||
|
||||
# - all tokens within output
|
||||
for i, token in enumerate(token_ids):
|
||||
|
@ -883,20 +887,37 @@ def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language,
|
|||
chunk["timestamp"][1] = time
|
||||
# Handling merges.
|
||||
previous_tokens.append(current_tokens)
|
||||
resolved_tokens = _find_longest_common_sequence(previous_tokens)
|
||||
if return_timestamps == "word":
|
||||
previous_token_timestamps.append(current_token_timestamps)
|
||||
resolved_tokens, resolved_token_timestamps = _find_longest_common_sequence(
|
||||
previous_tokens, previous_token_timestamps
|
||||
)
|
||||
resolved_text = tokenizer.decode(resolved_tokens)
|
||||
chunk["text"] = resolved_text
|
||||
if return_timestamps == "word":
|
||||
chunk["words"] = _collate_word_timestamps(
|
||||
tokenizer, resolved_tokens, resolved_token_timestamps, last_language
|
||||
)
|
||||
chunks.append(chunk)
|
||||
|
||||
# Flush all our temporary context
|
||||
previous_tokens = []
|
||||
current_tokens = []
|
||||
previous_token_timestamps = []
|
||||
current_token_timestamps = []
|
||||
chunk = new_chunk()
|
||||
else:
|
||||
# 4/ Regular token
|
||||
# We just append to the list of all tokens so we can handle
|
||||
# merges later and decode into text.
|
||||
current_tokens.append(token)
|
||||
if return_timestamps == "word":
|
||||
start_time = round(token_timestamps[i] + time_offset, 2)
|
||||
if i + 1 < len(token_timestamps):
|
||||
end_time = round(token_timestamps[i + 1] + time_offset, 2)
|
||||
else:
|
||||
end_time = None # should never happen
|
||||
current_token_timestamps.append((start_time, end_time))
|
||||
|
||||
if "stride" in output:
|
||||
time_offset += chunk_len - stride_right
|
||||
|
@ -904,21 +925,31 @@ def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language,
|
|||
# Leftover tokens
|
||||
if current_tokens:
|
||||
previous_tokens.append(current_tokens)
|
||||
if return_timestamps == "word":
|
||||
previous_token_timestamps.append(current_token_timestamps)
|
||||
elif not (any(p for p in previous_tokens)):
|
||||
chunk = new_chunk()
|
||||
previous_tokens = []
|
||||
current_tokens = []
|
||||
previous_token_timestamps = []
|
||||
current_token_timestamps = []
|
||||
|
||||
if previous_tokens:
|
||||
if return_timestamps:
|
||||
logger.warning(
|
||||
"There was an error while processing timestamps, we haven't found a timestamp as last token. Was"
|
||||
" WhisperTimeStampLogitsProcessor used?"
|
||||
"Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. "
|
||||
"Also make sure WhisperTimeStampLogitsProcessor was used during generation."
|
||||
)
|
||||
# Happens when we don't use timestamps
|
||||
resolved_tokens = _find_longest_common_sequence(previous_tokens)
|
||||
resolved_tokens, resolved_token_timestamps = _find_longest_common_sequence(
|
||||
previous_tokens, previous_token_timestamps
|
||||
)
|
||||
resolved_text = tokenizer.decode(resolved_tokens)
|
||||
chunk["text"] = resolved_text
|
||||
if return_timestamps == "word":
|
||||
chunk["words"] = _collate_word_timestamps(
|
||||
tokenizer, resolved_tokens, resolved_token_timestamps, last_language
|
||||
)
|
||||
chunks.append(chunk)
|
||||
|
||||
# Preparing and cleaning up the pipeline output
|
||||
|
@ -931,20 +962,35 @@ def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language,
|
|||
chunk["timestamp"] = tuple(chunk["timestamp"])
|
||||
if not return_language:
|
||||
chunk.pop("language")
|
||||
|
||||
if return_timestamps == "word":
|
||||
new_chunks = []
|
||||
for chunk in chunks:
|
||||
new_chunks.extend(chunk["words"])
|
||||
optional = {"chunks": new_chunks}
|
||||
else:
|
||||
optional = {"chunks": chunks}
|
||||
else:
|
||||
optional = {}
|
||||
return full_text, optional
|
||||
|
||||
|
||||
def _find_longest_common_sequence(sequences):
|
||||
def _find_longest_common_sequence(sequences, token_timestamp_sequences=None):
|
||||
# It would be much harder to do O(n) because of fault tolerance.
|
||||
# We actually have a really good property which is that the total sequence
|
||||
# MUST be those subsequences in order.
|
||||
# If token_timestamp_sequences is provided, will split those sequences in
|
||||
# exactly the same way.
|
||||
|
||||
left_sequence = sequences[0]
|
||||
left_length = len(left_sequence)
|
||||
total_sequence = []
|
||||
for right_sequence in sequences[1:]:
|
||||
|
||||
if token_timestamp_sequences:
|
||||
left_token_timestamp_sequence = token_timestamp_sequences[0]
|
||||
total_token_timestamp_sequence = []
|
||||
|
||||
for seq_idx, right_sequence in enumerate(sequences[1:]):
|
||||
# index = 0
|
||||
max_ = 0.0
|
||||
max_indices = (left_length, left_length, 0, 0)
|
||||
|
@ -1018,6 +1064,148 @@ def _find_longest_common_sequence(sequences):
|
|||
left_sequence = right_sequence[right_mid:]
|
||||
left_length = len(left_sequence)
|
||||
|
||||
if token_timestamp_sequences:
|
||||
total_token_timestamp_sequence.extend(left_token_timestamp_sequence[:left_mid])
|
||||
left_token_timestamp_sequence = token_timestamp_sequences[seq_idx + 1][right_mid:]
|
||||
|
||||
total_sequence.extend(left_sequence)
|
||||
|
||||
if token_timestamp_sequences is None:
|
||||
return total_sequence
|
||||
|
||||
if len(token_timestamp_sequences) > 0:
|
||||
total_token_timestamp_sequence.extend(left_token_timestamp_sequence)
|
||||
return total_sequence, total_token_timestamp_sequence
|
||||
else:
|
||||
return total_sequence, []
|
||||
|
||||
|
||||
def _collate_word_timestamps(tokenizer, tokens, token_timestamps, language):
|
||||
words, _, token_indices = _combine_tokens_into_words(tokenizer, tokens, language)
|
||||
timings = [
|
||||
{
|
||||
"text": word,
|
||||
"timestamp": (token_timestamps[indices[0]][0], token_timestamps[indices[-1]][1]),
|
||||
}
|
||||
for word, indices in zip(words, token_indices)
|
||||
]
|
||||
return timings
|
||||
|
||||
|
||||
def _combine_tokens_into_words(
|
||||
tokenizer,
|
||||
tokens: List[int],
|
||||
language: str = None,
|
||||
prepend_punctuations: str = "\"'“¡¿([{-",
|
||||
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
||||
):
|
||||
"""
|
||||
Groups tokens by word. Returns a tuple containing a list of strings with the words, and a list of `token_id`
|
||||
sequences with the tokens making up each word.
|
||||
"""
|
||||
if language is None:
|
||||
language = tokenizer.language
|
||||
if language is None:
|
||||
language = "english"
|
||||
|
||||
if language in {"chinese", "japanese", "thai", "lao", "myanmar"}:
|
||||
# These languages don't typically use spaces.
|
||||
words, word_tokens, token_indices = _split_tokens_on_unicode(tokenizer, tokens)
|
||||
else:
|
||||
words, word_tokens, token_indices = _split_tokens_on_spaces(tokenizer, tokens)
|
||||
|
||||
_merge_punctuations(words, word_tokens, token_indices, prepend_punctuations, append_punctuations)
|
||||
return words, word_tokens, token_indices
|
||||
|
||||
|
||||
def _split_tokens_on_unicode(tokenizer, tokens: List[int]):
|
||||
"""Combine tokens into words by splitting at any position where the tokens are decoded as valid unicode points."""
|
||||
decoded_full = tokenizer.decode(tokens, decode_with_timestamps=True)
|
||||
replacement_char = "\ufffd"
|
||||
|
||||
words = []
|
||||
word_tokens = []
|
||||
token_indices = []
|
||||
current_tokens = []
|
||||
current_indices = []
|
||||
unicode_offset = 0
|
||||
|
||||
for token_idx, token in enumerate(tokens):
|
||||
current_tokens.append(token)
|
||||
current_indices.append(token_idx)
|
||||
decoded = tokenizer.decode(current_tokens, decode_with_timestamps=True)
|
||||
|
||||
if (
|
||||
replacement_char not in decoded
|
||||
or decoded_full[unicode_offset + decoded.index(replacement_char)] == replacement_char
|
||||
):
|
||||
words.append(decoded)
|
||||
word_tokens.append(current_tokens)
|
||||
token_indices.append(current_indices)
|
||||
current_tokens = []
|
||||
current_indices = []
|
||||
unicode_offset += len(decoded)
|
||||
|
||||
return words, word_tokens, token_indices
|
||||
|
||||
|
||||
def _split_tokens_on_spaces(tokenizer, tokens: List[int]):
|
||||
"""Combine tokens into words by splitting at whitespace and punctuation tokens."""
|
||||
subwords, subword_tokens_list, subword_indices_list = _split_tokens_on_unicode(tokenizer, tokens)
|
||||
words = []
|
||||
word_tokens = []
|
||||
token_indices = []
|
||||
|
||||
for subword, subword_tokens, subword_indices in zip(subwords, subword_tokens_list, subword_indices_list):
|
||||
special = subword_tokens[0] >= tokenizer.eos_token_id
|
||||
with_space = subword.startswith(" ")
|
||||
punctuation = subword.strip() in "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
|
||||
|
||||
if special or with_space or punctuation or len(words) == 0:
|
||||
words.append(subword)
|
||||
word_tokens.append(subword_tokens)
|
||||
token_indices.append(subword_indices)
|
||||
else:
|
||||
words[-1] = words[-1] + subword
|
||||
word_tokens[-1].extend(subword_tokens)
|
||||
token_indices[-1].extend(subword_indices)
|
||||
|
||||
return words, word_tokens, token_indices
|
||||
|
||||
|
||||
def _merge_punctuations(words, tokens, indices, prepended, appended):
|
||||
"""Merges punctuation tokens with neighboring words."""
|
||||
# prepend punctuations
|
||||
i = len(words) - 2
|
||||
j = len(words) - 1
|
||||
while i >= 0:
|
||||
if words[i].startswith(" ") and words[i].strip() in prepended:
|
||||
words[j] = words[i] + words[j]
|
||||
tokens[j] = tokens[i] + tokens[j]
|
||||
indices[j] = indices[i] + indices[j]
|
||||
words[i] = ""
|
||||
tokens[i] = []
|
||||
indices[i] = []
|
||||
else:
|
||||
j = i
|
||||
i -= 1
|
||||
|
||||
# append punctuations
|
||||
i = 0
|
||||
j = 1
|
||||
while j < len(words):
|
||||
if not words[i].endswith(" ") and words[j] in appended:
|
||||
words[i] += words[j]
|
||||
tokens[i] += tokens[j]
|
||||
indices[i] += indices[j]
|
||||
words[j] = ""
|
||||
tokens[j] = []
|
||||
indices[j] = []
|
||||
else:
|
||||
i = j
|
||||
j += 1
|
||||
|
||||
# remove elements that are now empty
|
||||
words[:] = [word for word in words if word]
|
||||
tokens[:] = [token for token in tokens if token]
|
||||
indices[:] = [idx for idx in indices if idx]
|
||||
|
|
|
@ -295,7 +295,7 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
|
|||
Whether or not to output the offsets of the tokens. This should only be set if the model predicted
|
||||
timestamps.
|
||||
decode_with_timestamps (`bool`, *optional*, defaults to `False`):
|
||||
WHether or not to decode with timestamps included in the raw text.
|
||||
Whether or not to decode with timestamps included in the raw text.
|
||||
Returns:
|
||||
`str`: The decoded sentence.
|
||||
"""
|
||||
|
|
|
@ -246,12 +246,12 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||
treat the first `left` samples and last `right` samples to be ignored in decoding (but used at
|
||||
inference to provide more context to the model). Only use `stride` with CTC models.
|
||||
return_timestamps (*optional*, `str`):
|
||||
Only available for pure CTC models. If set to `"char"`, the pipeline will return `timestamps` along the
|
||||
text for every character in the text. For instance if you get `[{"text": "h", "timestamps": (0.5,0.6),
|
||||
{"text": "i", "timestamps": (0.7, .9)}]`, then it means the model predicts that the letter "h" was
|
||||
Only available for pure CTC models. If set to `"char"`, the pipeline will return timestamps along the
|
||||
text for every character in the text. For instance if you get `[{"text": "h", "timestamp": (0.5, 0.6)},
|
||||
{"text": "i", "timestamp": (0.7, 0.9)}]`, then it means the model predicts that the letter "h" was
|
||||
pronounced after `0.5` and before `0.6` seconds. If set to `"word"`, the pipeline will return
|
||||
`timestamps` along the text for every word in the text. For instance if you get `[{"text": "hi ",
|
||||
"timestamps": (0.5,0.9), {"text": "there", "timestamps": (1.0, .1.5)}]`, then it means the model
|
||||
timestamps along the text for every word in the text. For instance if you get `[{"text": "hi ",
|
||||
"timestamp": (0.5, 0.9)}, {"text": "there", "timestamp": (1.0, 1.5)}]`, then it means the model
|
||||
predicts that the word "hi" was pronounced after `0.5` and before `0.9` seconds.
|
||||
generate_kwargs (`dict`, *optional*):
|
||||
The dictionary of ad-hoc parametrization of `generate_config` to be used for the generation call. For a
|
||||
|
@ -265,8 +265,8 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||
- **text** (`str` ) -- The recognized text.
|
||||
- **chunks** (*optional(, `List[Dict]`)
|
||||
When using `return_timestamps`, the `chunks` will become a list containing all the various text
|
||||
chunks identified by the model, *e.g.* `[{"text": "hi ", "timestamps": (0.5,0.9), {"text":
|
||||
"there", "timestamps": (1.0, 1.5)}]`. The original full text can roughly be recovered by doing
|
||||
chunks identified by the model, *e.g.* `[{"text": "hi ", "timestamp": (0.5, 0.9)}, {"text":
|
||||
"there", "timestamp": (1.0, 1.5)}]`. The original full text can roughly be recovered by doing
|
||||
`"".join(chunk["text"] for chunk in output["chunks"])`.
|
||||
"""
|
||||
return super().__call__(inputs, **kwargs)
|
||||
|
@ -421,6 +421,8 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||
generate_kwargs = {}
|
||||
if return_timestamps and self.type == "seq2seq_whisper":
|
||||
generate_kwargs["return_timestamps"] = return_timestamps
|
||||
if return_timestamps == "word":
|
||||
generate_kwargs["return_token_timestamps"] = True
|
||||
is_last = model_inputs.pop("is_last")
|
||||
|
||||
if self.type in {"seq2seq", "seq2seq_whisper"}:
|
||||
|
@ -447,6 +449,9 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||
attention_mask=attention_mask,
|
||||
**generate_kwargs,
|
||||
)
|
||||
if return_timestamps == "word" and self.type == "seq2seq_whisper":
|
||||
out = {"tokens": tokens["sequences"], "token_timestamps": tokens["token_timestamps"]}
|
||||
else:
|
||||
out = {"tokens": tokens}
|
||||
if self.type == "seq2seq_whisper":
|
||||
stride = model_inputs.pop("stride", None)
|
||||
|
@ -486,9 +491,9 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||
if return_timestamps and self.type == "seq2seq":
|
||||
raise ValueError("We cannot return_timestamps yet on non-ctc models apart from Whisper !")
|
||||
if return_timestamps == "char" and self.type == "ctc_with_lm":
|
||||
raise ValueError("CTC with LM cannot return `char` timestamps, only `words`")
|
||||
if return_timestamps in {"char", "words"} and self.type == "seq2seq_whisper":
|
||||
raise ValueError("Whisper cannot return `char` nor `words` timestamps, use `True` instead.")
|
||||
raise ValueError("CTC with LM cannot return `char` timestamps, only `word`")
|
||||
if return_timestamps == "char" and self.type == "seq2seq_whisper":
|
||||
raise ValueError("Whisper cannot return `char` timestamps, use `True` or `word` instead.")
|
||||
|
||||
if return_language is not None and self.type != "seq2seq_whisper":
|
||||
raise ValueError("Only whisper can return language for now.")
|
||||
|
@ -574,6 +579,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||
output.pop("logits", None)
|
||||
output.pop("is_last", None)
|
||||
output.pop("stride", None)
|
||||
output.pop("token_timestamps", None)
|
||||
for k, v in output.items():
|
||||
extra[k].append(v)
|
||||
return {"text": text, **optional, **extra}
|
||||
|
|
|
@ -1436,6 +1436,35 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
|
||||
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||
|
||||
@slow
|
||||
def test_tiny_token_timestamp_generation(self):
|
||||
set_seed(0)
|
||||
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
|
||||
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
|
||||
model.to(torch_device)
|
||||
|
||||
input_speech = self._load_datasamples(4)
|
||||
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
generate_outputs = model.generate(
|
||||
input_features, max_length=448, return_timestamps=True, return_token_timestamps=True
|
||||
)
|
||||
|
||||
self.assertEqual(generate_outputs.sequences.shape, generate_outputs.token_timestamps.shape)
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_OUTPUT = torch.tensor([
|
||||
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.4800, 0.8200, 0.9600, 1.1200, 1.1200, 1.2200, 1.5000, 1.7200, 2.0000, 2.3400, 2.5000, 2.6600, 3.1800, 3.5600, 3.6800, 3.8000, 4.1000, 4.3000, 4.5800, 4.9400, 5.3800, 12.4200, 12.8400, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9400, 26.9400, 26.9400, 26.9400, 29.8400 ],
|
||||
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.5200, 0.9000, 1.1400, 1.4200, 1.5200, 1.6800, 1.6800, 1.8800, 2.1000, 2.2200, 2.6200, 3.1400, 3.5800, 3.9600, 4.4000, 17.3000, 17.3000, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7400, 26.7400, 26.7400, 26.7400, 26.7400, 26.7400, 28.0000 ],
|
||||
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7600, 1.0000, 1.4200, 1.8000, 1.9400, 2.1800, 2.5200, 3.0200, 3.3200, 3.5400, 3.9400, 4.5600, 4.9200, 5.2800, 5.5600, 5.9000, 6.1600, 6.3000, 6.4800, 6.4800, 6.6400, 7.8200, 7.9600, 8.2200, 8.6000, 8.9200, 9.2200, 9.5200, 9.7200, 10.0600, 10.5400, 10.8800, 11.2600, 11.5400, 11.7400, 12.0800, 15.6800, 15.6800],
|
||||
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7400, 1.0400, 1.3200, 1.6800, 2.1400, 2.4800, 2.7800, 3.0800, 3.1600, 3.4000, 3.6000, 4.0200, 4.2200, 4.8600, 5.2400, 5.7400, 6.3400, 6.6200, 6.7600, 6.7600, 6.8600, 7.2400, 7.4200, 7.6800, 7.9200, 8.4800, 8.7600, 9.2000, 9.2000, 9.4200, 15.8200, 15.8200, 29.6400, 29.6600, 29.6600, 29.6600, 29.6600, 29.7600]
|
||||
])
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(torch.allclose(generate_outputs.token_timestamps.to("cpu"), EXPECTED_OUTPUT))
|
||||
|
||||
@slow
|
||||
def test_tiny_specaugment_librispeech(self):
|
||||
torch_device = "cpu"
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
import unittest
|
||||
|
||||
from transformers.models.whisper import WhisperTokenizer, WhisperTokenizerFast
|
||||
from transformers.models.whisper.tokenization_whisper import _find_longest_common_sequence
|
||||
from transformers.models.whisper.tokenization_whisper import _combine_tokens_into_words, _find_longest_common_sequence
|
||||
from transformers.testing_utils import slow
|
||||
|
||||
from ...test_tokenization_common import TokenizerTesterMixin
|
||||
|
@ -255,6 +255,24 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
|||
|
||||
self.assertListEqual(tokenizer_prompt_ids.tolist(), fast_tokenizer_prompt_ids.tolist())
|
||||
|
||||
def test_combine_tokens_into_words(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
rust_tokenizer = self.get_rust_tokenizer()
|
||||
|
||||
# 'whatever "whatever" said someone, clever!?'
|
||||
encoded_input = [1363, 7969, 503, 1363, 7969, 1, 848, 1580, 11, 13494, 7323]
|
||||
expected_words = ["whatever", ' "whatever"', " said", " someone,", " clever!?"]
|
||||
expected_tokens = [[1363, 7969], [503, 1363, 7969, 1], [848], [1580, 11], [13494, 7323]]
|
||||
expected_indices = [[0, 1], [2, 3, 4, 5], [6], [7, 8], [9, 10]]
|
||||
output = _combine_tokens_into_words(tokenizer, encoded_input)
|
||||
self.assertEqual(expected_words, output[0])
|
||||
self.assertEqual(expected_tokens, output[1])
|
||||
self.assertEqual(expected_indices, output[2])
|
||||
output_rust = _combine_tokens_into_words(rust_tokenizer, encoded_input)
|
||||
self.assertEqual(expected_words, output_rust[0])
|
||||
self.assertEqual(expected_tokens, output_rust[1])
|
||||
self.assertEqual(expected_indices, output_rust[2])
|
||||
|
||||
|
||||
class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
|
||||
checkpoint_name = "openai/whisper-small.en"
|
||||
|
|
|
@ -316,6 +316,27 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
|||
"chunks": [{"text": " Conquered returned to its place amidst the tents.", "timestamp": (0.0, 3.36)}],
|
||||
},
|
||||
)
|
||||
pipe.model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]
|
||||
res = pipe(sample["audio"]["array"], return_timestamps="word")
|
||||
# fmt: off
|
||||
# Note that the word-level timestamps predicted here are pretty bad.
|
||||
self.assertEqual(
|
||||
res,
|
||||
{
|
||||
"text": " Conquered returned to its place amidst the tents.",
|
||||
"chunks": [
|
||||
{'text': ' Conquered', 'timestamp': (29.78, 29.9)},
|
||||
{'text': ' returned', 'timestamp': (29.9, 29.9)},
|
||||
{'text': ' to', 'timestamp': (29.9, 29.9)},
|
||||
{'text': ' its', 'timestamp': (29.9, 29.9)},
|
||||
{'text': ' place', 'timestamp': (29.9, 29.9)},
|
||||
{'text': ' amidst', 'timestamp': (29.9, 29.9)},
|
||||
{'text': ' the', 'timestamp': (29.9, 29.9)},
|
||||
{'text': ' tents.', 'timestamp': (29.9, 29.9)}
|
||||
]
|
||||
}
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
|
@ -699,6 +720,35 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
|||
],
|
||||
},
|
||||
)
|
||||
speech_recognizer.model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]
|
||||
output = speech_recognizer(filename, return_timestamps="word")
|
||||
# fmt: off
|
||||
self.assertEqual(
|
||||
output,
|
||||
{
|
||||
"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.",
|
||||
"chunks": [
|
||||
{'text': ' Mr.', 'timestamp': (0.0, 1.02)},
|
||||
{'text': ' Quilter', 'timestamp': (1.02, 1.18)},
|
||||
{'text': ' is', 'timestamp': (1.18, 1.44)},
|
||||
{'text': ' the', 'timestamp': (1.44, 1.58)},
|
||||
{'text': ' apostle', 'timestamp': (1.58, 1.98)},
|
||||
{'text': ' of', 'timestamp': (1.98, 2.3)},
|
||||
{'text': ' the', 'timestamp': (2.3, 2.46)},
|
||||
{'text': ' middle', 'timestamp': (2.46, 2.56)},
|
||||
{'text': ' classes,', 'timestamp': (2.56, 3.38)},
|
||||
{'text': ' and', 'timestamp': (3.38, 3.52)},
|
||||
{'text': ' we', 'timestamp': (3.52, 3.6)},
|
||||
{'text': ' are', 'timestamp': (3.6, 3.72)},
|
||||
{'text': ' glad', 'timestamp': (3.72, 4.0)},
|
||||
{'text': ' to', 'timestamp': (4.0, 4.26)},
|
||||
{'text': ' welcome', 'timestamp': (4.26, 4.54)},
|
||||
{'text': ' his', 'timestamp': (4.54, 4.92)},
|
||||
{'text': ' gospel.', 'timestamp': (4.92, 6.66)},
|
||||
],
|
||||
},
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
|
|
Loading…
Reference in New Issue