Fix word-level timestamps for non-English languages w/ Whisper (#253)

* Fix language detection

* Remove debug statement

* Fix punctuation regex for whisper decoding (Closes #223)

* Fix word-level timestamps for audio < 30 seconds

Issue in python library: https://github.com/huggingface/transformers/issues/25605
PR for above: https://github.com/huggingface/transformers/pull/25607

* Add multilingual transcription w/ word-level timestamps unit test

* Fix unit tests
Joshua Lochner 2023-08-22 15:50:30 +02:00 committed by GitHub
parent 276bdd06b8
commit c3af596443
4 changed files with 102 additions and 36 deletions

src/models.js

@@ -2171,15 +2171,20 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel {
}
/**
* Generates outputs based on input and generation configuration.
* @param {Object} inputs Input data for the model.
* @param {Object} generation_config Configuration object for the generation process.
* @param {Object} logits_processor Optional logits processor object.
* @param {Object} options options
* @param {Object} [options.return_timestamps=null] Whether to return the timestamps with the text. This enables the `WhisperTimestampsLogitsProcessor`.
* @param {Object} [options.return_token_timestamps=null] Whether to return token-level timestamps
* @typedef {Object} WhisperGenerationConfig
* @extends GenerationConfig
* @property {boolean} [return_timestamps=null] Whether to return the timestamps with the text. This enables the `WhisperTimestampsLogitsProcessor`.
* @property {boolean} [return_token_timestamps=null] Whether to return token-level timestamps
* with the text. This can be used with or without the `return_timestamps` option. To get word-level
* timestamps, use the tokenizer to group the tokens into words.
* @property {number} [num_frames=null] The number of audio frames available in this chunk. This is only used when generating word-level timestamps.
*/
/**
* Generates outputs based on input and generation configuration.
* @param {Object} inputs Input data for the model.
* @param {WhisperGenerationConfig} generation_config Configuration object for the generation process.
* @param {Object} logits_processor Optional logits processor object.
* @returns {Promise<Object>} Promise object represents the generated outputs.
*/
// @ts-ignore
@@ -2226,7 +2231,11 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel {
const outputs = await super.generate(inputs, generation_config, logits_processor);
if (generation_config.return_token_timestamps && generation_config.alignment_heads) {
outputs["token_timestamps"] = this._extract_token_timestamps(outputs, generation_config.alignment_heads)
outputs["token_timestamps"] = this._extract_token_timestamps(
outputs,
generation_config.alignment_heads,
generation_config.num_frames,
)
}
return outputs
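
For context, a minimal sketch of how this code path is usually reached, via the automatic-speech-recognition pipeline rather than generate() directly. The model id, revision, and audio URL mirror the unit test added below; everything else is illustrative:

import { pipeline } from '@xenova/transformers';

// Sketch: with `return_timestamps: 'word'`, the pipeline enables token-level
// timestamps (and, after this change, also passes `num_frames`) on the generate call above.
const transcriber = await pipeline(
    'automatic-speech-recognition',
    'Xenova/whisper-base',              // assumption: an ONNX export whose decoder emits cross-attentions
    { revision: 'output_attentions' },  // mirrors the multilingual unit test below
);
const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/japanese-audio.wav';
const output = await transcriber(url, {
    return_timestamps: 'word',
    language: 'japanese',
    task: 'transcribe',
});
// output.chunks -> [{ text: '森', timestamp: [0.14, 0.64] }, ...]
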
@@ -2280,10 +2289,11 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel {
* @param {Tensor[][][]} generate_outputs.decoder_attentions The decoder attentions output by the model
* @param {number[][]} generate_outputs.sequences The sequences output by the model
* @param {number[][]} alignment_heads Alignment heads of the model
* @param {number} time_precision Precision of the timestamps in seconds
* @param {number} [num_frames=null] Number of frames in the input audio.
* @param {number} [time_precision=0.02] Precision of the timestamps in seconds
* @returns {Tensor} tensor containing the timestamps in seconds for each predicted token
*/
_extract_token_timestamps(generate_outputs, alignment_heads, time_precision = 0.02) {
_extract_token_timestamps(generate_outputs, alignment_heads, num_frames = null, time_precision = 0.02) {
if (!generate_outputs.cross_attentions) {
throw new Error(
"Model outputs must contain cross attentions to extract timestamps. " +
@@ -2304,7 +2314,11 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel {
(_, i) => cat(batch.map(x => x[i]), 2)
);
let weights = stack(alignment_heads.map(([l, h]) => cross_attentions[l].slice(null, h)));
let weights = stack(alignment_heads.map(([l, h]) => {
return num_frames
? cross_attentions[l].slice(null, h, null, [0, num_frames])
: cross_attentions[l].slice(null, h);
}));
weights = weights.transpose(1, 0, 2, 3)
let [std, calculatedMean] = std_mean(weights, -2, 0, true);
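
For orientation, the quantities involved, assuming Whisper's default configuration (30 s feature window, 1500 encoder positions); the variable names below are illustrative only:

// Assumed Whisper defaults, shown only to motivate the `[0, num_frames]` slice above.
const chunk_length = 30;                  // seconds covered by one padded feature window
const max_source_positions = 1500;        // encoder frames per window
const time_precision = chunk_length / max_source_positions; // 0.02 s per frame, the default above

// For a clip shorter than 30 s, the rest of the window is padding. Cropping the
// cross-attention weights to the frames that contain real audio keeps the std/mean
// normalization (and the alignment derived from it) from being dominated by that
// padding, which is what broke word-level timestamps for audio < 30 seconds.
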

src/pipelines.js

@@ -1207,6 +1207,7 @@ export class AutomaticSpeechRecognitionPipeline extends Pipeline {
const sampling_rate = this.processor.feature_extractor.config.sampling_rate;
const time_precision = this.processor.feature_extractor.config.chunk_length / this.model.config.max_source_positions;
const hop_length = this.processor.feature_extractor.config.hop_length;
let toReturn = [];
for (let aud of audio) {
@@ -1258,6 +1259,8 @@ export class AutomaticSpeechRecognitionPipeline extends Pipeline {
// Generate for each set of input features
for (let chunk of chunks) {
kwargs.num_frames = Math.floor(chunk.stride[0] / hop_length);
// NOTE: doing sequentially for now
let data = await this.model.generate(chunk.input_features, kwargs);
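
For example, with the feature extractor defaults (16 kHz sampling rate, hop_length of 160 samples) and assuming `stride[0]` is the chunk's length in samples, a 6-second chunk works out as:

const sampling_rate = 16000;              // feature_extractor.config.sampling_rate (assumed default)
const hop_length = 160;                   // feature_extractor.config.hop_length -> one frame per 10 ms
const chunk_samples = 6 * sampling_rate;  // stride[0] for a hypothetical 6 s chunk
const num_frames = Math.floor(chunk_samples / hop_length); // 600 frames of real audio passed to generate()
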

src/tokenizers.js

@@ -2941,30 +2941,28 @@ export class WhisperTokenizer extends PreTrainedTokenizer {
if (all_special_ids.has(token)) {
const text = this.decode([token]);
if (text[0] === "[" && text[text.length - 1] === "]") {
const language = WHISPER_LANGUAGE_MAPPING.get(text.slice(1, -1));
const language = WHISPER_LANGUAGE_MAPPING.get(text.slice(2, -2));
if (language !== undefined) {
// 1/ Indeed some language
// TODO Handle when language is different from the previous
// one, and we cannot use timestamped tokens to create chunks
if (last_language !== null && language !== last_language && !return_timestamps) {
previous_tokens.push(current_tokens);
const resolved_tokens = this.findLongestCommonSequence(previous_tokens)[0];
const resolved_text = this.decode(resolved_tokens);
chunk.text = resolved_text;
chunks.push(chunk);
if (language !== undefined) {
// 1/ Indeed some language
// TODO Handle when language is different from the previous
// one, and we cannot use timestamped tokens to create chunks
if (last_language !== null && language !== last_language && !return_timestamps) {
previous_tokens.push(current_tokens);
const resolved_tokens = this.findLongestCommonSequence(previous_tokens)[0];
const resolved_text = this.decode(resolved_tokens);
chunk.text = resolved_text;
chunks.push(chunk);
// Flush all our temporary context
previous_tokens = [];
current_tokens = [];
chunk = new_chunk();
}
last_language = chunk.language = language;
} else {
// 2/ This is a regular special token, ignoring it
// Flush all our temporary context
previous_tokens = [];
current_tokens = [];
chunk = new_chunk();
}
last_language = chunk.language = language;
} else {
// 2/ This is a regular special token, ignoring it
}
} else if (token >= timestamp_begin) {
// 3/ Timestamp token
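
The two-character shift matters because Whisper's language tokens decode to strings of the form `<|xx|>`; a small sketch of the lookup (the mapping entry shown is an assumption based on Whisper's language codes):

const text = '<|ja|>';   // what this.decode([token]) yields for the Japanese language token
text.slice(1, -1);       // '|ja|' -> not a key in WHISPER_LANGUAGE_MAPPING, so the language was never detected
text.slice(2, -2);       // 'ja'   -> WHISPER_LANGUAGE_MAPPING.get('ja') === 'japanese'
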
@@ -3253,7 +3251,6 @@ export class WhisperTokenizer extends PreTrainedTokenizer {
if (["chinese", "japanese", "thai", "lao", "myanmar"].includes(language)) {
// These languages don't typically use spaces.
[words, word_tokens, token_indices] = this.splitTokensOnUnicode(tokens)
} else {
[words, word_tokens, token_indices] = this.splitTokensOnSpaces(tokens)
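
The reason for this branch shows up in the Japanese unit test added below: there are no spaces to split on, so the tokenizer falls back to splitting at codepoint boundaries. A plain-string illustration (the real helpers operate on token ids, not strings):

const ja = 'これはペンです';   // hypothetical transcript with no spaces
ja.split(' ');               // ['これはペンです']                  -> one "word", hence one timestamp for everything
Array.from(ja);              // ['こ','れ','は','ペ','ン','で','す'] -> per-codepoint pieces, like the chunks in the test below
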
@@ -3373,7 +3370,7 @@ export class WhisperTokenizer extends PreTrainedTokenizer {
let word_tokens = []
let token_indices = []
const punctuationRegex = new RegExp(`[${PUNCTUATION_REGEX}]`)
const punctuationRegex = new RegExp(`^[${PUNCTUATION_REGEX}]$`, 'gu');
for (let i = 0; i < subwords.length; ++i) {
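
The `^...$` anchors are the substantive change here; a sketch of the difference using a stand-in character class (the real `PUNCTUATION_REGEX` is much longer, and the `g` flag is dropped in this sketch so repeated `.test()` calls stay stateless):

const PUNCT = `"'“”‘’、。,!?…`;                    // stand-in for PUNCTUATION_REGEX, not the real class
const unanchored = new RegExp(`[${PUNCT}]`);       // old: matches any subword that merely *contains* punctuation
const anchored = new RegExp(`^[${PUNCT}]$`, 'u');  // new: matches only a lone punctuation character

unanchored.test('である。');  // true  -> the whole subword would be misclassified as punctuation
anchored.test('である。');    // false -> kept as its own word with its own timestamp
anchored.test('。');          // true  -> merged into the neighboring word, as intended
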

tests/pipelines.test.js

@@ -733,6 +733,7 @@ describe('Pipelines', () => {
'openai/whisper-tiny.en', // English-only
'openai/whisper-small', // Multilingual
['openai/whisper-tiny.en', 'output_attentions'], // English-only + `output_attentions`
['openai/whisper-base', 'output_attentions'], // Multilingual + `output_attentions`
// wav2vec2
'jonatasgrosman/wav2vec2-large-xlsr-53-english',
@@ -834,9 +835,60 @@ describe('Pipelines', () => {
}, MAX_TEST_EXECUTION_TIME);
it(models[3].join(' + '), async () => {
let transcriber = await pipeline('automatic-speech-recognition', m(models[3][0]), {
revision: models[3][1],
});
it(models[3], async () => {
let transcriber = await pipeline('automatic-speech-recognition', m(models[3]));
let url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/japanese-audio.wav';
let audioData = await loadAudio(url);
{ // Transcribe Japanese w/ word-level timestamps.
let output = await transcriber(audioData, { return_timestamps: 'word', language: 'japanese', task: 'transcribe' });
const target = {
"text": "森長の美味しい牛乳は濃い青い牛乳ビーンを足らった階のパック牛乳である",
"chunks": [
{ "text": "森", "timestamp": [0.14, 0.64] },
{ "text": "長", "timestamp": [0.64, 0.82] },
{ "text": "の", "timestamp": [0.82, 1.04] },
{ "text": "美味", "timestamp": [1.04, 1.2] },
{ "text": "しい", "timestamp": [1.2, 1.5] },
{ "text": "牛", "timestamp": [1.5, 1.68] },
{ "text": "乳", "timestamp": [1.68, 1.92] },
{ "text": "は", "timestamp": [1.92, 2.14] },
{ "text": "濃", "timestamp": [2.14, 2.32] },
{ "text": "い", "timestamp": [2.32, 2.44] },
{ "text": "青", "timestamp": [2.44, 2.66] },
{ "text": "い", "timestamp": [2.66, 2.76] },
{ "text": "牛", "timestamp": [2.76, 3.06] },
{ "text": "乳", "timestamp": [3.06, 3.36] },
{ "text": "ビ", "timestamp": [3.36, 3.58] },
{ "text": "ーン", "timestamp": [3.58, 3.66] },
{ "text": "を", "timestamp": [3.66, 3.82] },
{ "text": "足", "timestamp": [3.82, 4] },
{ "text": "ら", "timestamp": [4, 4.12] },
{ "text": "った", "timestamp": [4.12, 4.3] },
{ "text": "階", "timestamp": [4.3, 4.56] },
{ "text": "の", "timestamp": [4.56, 4.92] },
{ "text": "パ", "timestamp": [4.92, 5.1] },
{ "text": "ック", "timestamp": [5.1, 5.2] },
{ "text": "牛", "timestamp": [5.2, 5.44] },
{ "text": "乳", "timestamp": [5.44, 5.64] },
{ "text": "で", "timestamp": [5.64, 5.84] },
{ "text": "ある", "timestamp": [5.84, 6.06] }
]
}
compare(output, target);
}
await transcriber.dispose();
}, MAX_TEST_EXECUTION_TIME);
it(models[4], async () => {
let transcriber = await pipeline('automatic-speech-recognition', m(models[4]));
let url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav';
let audioData = await loadAudio(url);