Fix word-level timestamps for non-English languages w/ Whisper (#253)
* Fix language detection
* Remove debug statement
* Fix punctuation regex for whisper decoding (Closes #223)
* Fix word-level timestamps for audio < 30 seconds
  Issue in python library: https://github.com/huggingface/transformers/issues/25605
  PR for above: https://github.com/huggingface/transformers/pull/25607
* Add multilingual transcription w/ word-level timestamps unit test
* Fix unit tests
This commit is contained in:
parent
276bdd06b8
commit
c3af596443
|
@ -2171,15 +2171,20 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel {
|
|||
}
|
||||
|
||||
/**
|
||||
* Generates outputs based on input and generation configuration.
|
||||
* @param {Object} inputs Input data for the model.
|
||||
* @param {Object} generation_config Configuration object for the generation process.
|
||||
* @param {Object} logits_processor Optional logits processor object.
|
||||
* @param {Object} options options
|
||||
* @param {Object} [options.return_timestamps=null] Whether to return the timestamps with the text. This enables the `WhisperTimestampsLogitsProcessor`.
|
||||
* @param {Object} [options.return_token_timestamps=null] Whether to return token-level timestamps
|
||||
* @typedef {Object} WhisperGenerationConfig
|
||||
* @extends GenerationConfig
|
||||
* @property {boolean} [return_timestamps=null] Whether to return the timestamps with the text. This enables the `WhisperTimestampsLogitsProcessor`.
|
||||
* @property {boolean} [return_token_timestamps=null] Whether to return token-level timestamps
|
||||
* with the text. This can be used with or without the `return_timestamps` option. To get word-level
|
||||
* timestamps, use the tokenizer to group the tokens into words.
|
||||
* @property {number} [num_frames=null] The number of audio frames available in this chunk. This is only used generating word-level timestamps.
|
||||
*/
|
||||
|
||||
/**
|
||||
* Generates outputs based on input and generation configuration.
|
||||
* @param {Object} inputs Input data for the model.
|
||||
* @param {WhisperGenerationConfig} generation_config Configuration object for the generation process.
|
||||
* @param {Object} logits_processor Optional logits processor object.
|
||||
* @returns {Promise<Object>} Promise object represents the generated outputs.
|
||||
*/
|
||||
// @ts-ignore
|
||||
|
@ -2226,7 +2231,11 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel {
|
|||
const outputs = await super.generate(inputs, generation_config, logits_processor);
|
||||
|
||||
if (generation_config.return_token_timestamps && generation_config.alignment_heads) {
|
||||
outputs["token_timestamps"] = this._extract_token_timestamps(outputs, generation_config.alignment_heads)
|
||||
outputs["token_timestamps"] = this._extract_token_timestamps(
|
||||
outputs,
|
||||
generation_config.alignment_heads,
|
||||
generation_config.num_frames,
|
||||
)
|
||||
}
|
||||
|
||||
return outputs
|
||||
|
@ -2280,10 +2289,11 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel {
|
|||
* @param {Tensor[][][]} generate_outputs.decoder_attentions The decoder attentions output by the model
|
||||
* @param {number[][]} generate_outputs.sequences The sequences output by the model
|
||||
* @param {number[][]} alignment_heads Alignment heads of the model
|
||||
* @param {number} time_precision Precision of the timestamps in seconds
|
||||
* @param {number} [num_frames=null] Number of frames in the input audio.
|
||||
* @param {number} [time_precision=0.02] Precision of the timestamps in seconds
|
||||
* @returns {Tensor} tensor containing the timestamps in seconds for each predicted token
|
||||
*/
|
||||
_extract_token_timestamps(generate_outputs, alignment_heads, time_precision = 0.02) {
|
||||
_extract_token_timestamps(generate_outputs, alignment_heads, num_frames = null, time_precision = 0.02) {
|
||||
if (!generate_outputs.cross_attentions) {
|
||||
throw new Error(
|
||||
"Model outputs must contain cross attentions to extract timestamps. " +
|
||||
|
@ -2304,7 +2314,11 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel {
|
|||
(_, i) => cat(batch.map(x => x[i]), 2)
|
||||
);
|
||||
|
||||
let weights = stack(alignment_heads.map(([l, h]) => cross_attentions[l].slice(null, h)));
|
||||
let weights = stack(alignment_heads.map(([l, h]) => {
|
||||
return num_frames
|
||||
? cross_attentions[l].slice(null, h, null, [0, num_frames])
|
||||
: cross_attentions[l].slice(null, h);
|
||||
}));
|
||||
weights = weights.transpose(1, 0, 2, 3)
|
||||
|
||||
let [std, calculatedMean] = std_mean(weights, -2, 0, true);
|
||||
|
|
|
@ -1207,6 +1207,7 @@ export class AutomaticSpeechRecognitionPipeline extends Pipeline {
|
|||
|
||||
const sampling_rate = this.processor.feature_extractor.config.sampling_rate;
|
||||
const time_precision = this.processor.feature_extractor.config.chunk_length / this.model.config.max_source_positions;
|
||||
const hop_length = this.processor.feature_extractor.config.hop_length;
|
||||
|
||||
let toReturn = [];
|
||||
for (let aud of audio) {
|
||||
|
@ -1258,6 +1259,8 @@ export class AutomaticSpeechRecognitionPipeline extends Pipeline {
|
|||
|
||||
// Generate for each set of input features
|
||||
for (let chunk of chunks) {
|
||||
kwargs.num_frames = Math.floor(chunk.stride[0] / hop_length);
|
||||
|
||||
// NOTE: doing sequentially for now
|
||||
let data = await this.model.generate(chunk.input_features, kwargs);
|
||||
|
||||
|
|
|
@ -2941,30 +2941,28 @@ export class WhisperTokenizer extends PreTrainedTokenizer {
|
|||
|
||||
if (all_special_ids.has(token)) {
|
||||
const text = this.decode([token]);
|
||||
if (text[0] === "[" && text[text.length - 1] === "]") {
|
||||
const language = WHISPER_LANGUAGE_MAPPING.get(text.slice(1, -1));
|
||||
const language = WHISPER_LANGUAGE_MAPPING.get(text.slice(2, -2));
|
||||
|
||||
if (language !== undefined) {
|
||||
// 1/ Indeed some language
|
||||
// TODO Handle when language is different from the previous
|
||||
// one, and we cannot use timestamped tokens to create chunks
|
||||
if (last_language !== null && language !== last_language && !return_timestamps) {
|
||||
previous_tokens.push(current_tokens);
|
||||
const resolved_tokens = this.findLongestCommonSequence(previous_tokens)[0];
|
||||
const resolved_text = this.decode(resolved_tokens);
|
||||
chunk.text = resolved_text;
|
||||
chunks.push(chunk);
|
||||
if (language !== undefined) {
|
||||
// 1/ Indeed some language
|
||||
// TODO Handle when language is different from the previous
|
||||
// one, and we cannot use timestamped tokens to create chunks
|
||||
if (last_language !== null && language !== last_language && !return_timestamps) {
|
||||
previous_tokens.push(current_tokens);
|
||||
const resolved_tokens = this.findLongestCommonSequence(previous_tokens)[0];
|
||||
const resolved_text = this.decode(resolved_tokens);
|
||||
chunk.text = resolved_text;
|
||||
chunks.push(chunk);
|
||||
|
||||
// Flush all our temporary context
|
||||
previous_tokens = [];
|
||||
current_tokens = [];
|
||||
chunk = new_chunk();
|
||||
}
|
||||
|
||||
last_language = chunk.language = language;
|
||||
} else {
|
||||
// 2/ This is a regular special token, ignoring it
|
||||
// Flush all our temporary context
|
||||
previous_tokens = [];
|
||||
current_tokens = [];
|
||||
chunk = new_chunk();
|
||||
}
|
||||
|
||||
last_language = chunk.language = language;
|
||||
} else {
|
||||
// 2/ This is a regular special token, ignoring it
|
||||
}
|
||||
} else if (token >= timestamp_begin) {
|
||||
// 3/ Timestamp token
|
||||
|
@ -3253,7 +3251,6 @@ export class WhisperTokenizer extends PreTrainedTokenizer {
|
|||
|
||||
if (["chinese", "japanese", "thai", "lao", "myanmar"].includes(language)) {
|
||||
// These languages don't typically use spaces.
|
||||
|
||||
[words, word_tokens, token_indices] = this.splitTokensOnUnicode(tokens)
|
||||
} else {
|
||||
[words, word_tokens, token_indices] = this.splitTokensOnSpaces(tokens)
|
||||
|
@ -3373,7 +3370,7 @@ export class WhisperTokenizer extends PreTrainedTokenizer {
|
|||
let word_tokens = []
|
||||
let token_indices = []
|
||||
|
||||
const punctuationRegex = new RegExp(`[${PUNCTUATION_REGEX}]`)
|
||||
const punctuationRegex = new RegExp(`^[${PUNCTUATION_REGEX}]$`, 'gu');
|
||||
|
||||
for (let i = 0; i < subwords.length; ++i) {
|
||||
|
||||
|
|
|
@ -733,6 +733,7 @@ describe('Pipelines', () => {
|
|||
'openai/whisper-tiny.en', // English-only
|
||||
'openai/whisper-small', // Multilingual
|
||||
['openai/whisper-tiny.en', 'output_attentions'], // English-only + `output_attentions`
|
||||
['openai/whisper-base', 'output_attentions'], // Multilingual + `output_attentions`
|
||||
|
||||
// wav2vec2
|
||||
'jonatasgrosman/wav2vec2-large-xlsr-53-english',
|
||||
|
@ -834,9 +835,60 @@ describe('Pipelines', () => {
|
|||
|
||||
}, MAX_TEST_EXECUTION_TIME);
|
||||
|
||||
it(models[3].join(' + '), async () => {
|
||||
let transcriber = await pipeline('automatic-speech-recognition', m(models[3][0]), {
|
||||
revision: models[3][1],
|
||||
});
|
||||
|
||||
it(models[3], async () => {
|
||||
let transcriber = await pipeline('automatic-speech-recognition', m(models[3]));
|
||||
let url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/japanese-audio.wav';
|
||||
let audioData = await loadAudio(url);
|
||||
|
||||
{ // Transcribe Japanese w/ word-level timestamps.
|
||||
let output = await transcriber(audioData, { return_timestamps: 'word', language: 'japanese', task: 'transcribe' });
|
||||
const target = {
|
||||
"text": "森長の美味しい牛乳は濃い青い牛乳ビーンを足らった階のパック牛乳である",
|
||||
"chunks": [
|
||||
{ "text": "森", "timestamp": [0.14, 0.64] },
|
||||
{ "text": "長", "timestamp": [0.64, 0.82] },
|
||||
{ "text": "の", "timestamp": [0.82, 1.04] },
|
||||
{ "text": "美味", "timestamp": [1.04, 1.2] },
|
||||
{ "text": "しい", "timestamp": [1.2, 1.5] },
|
||||
{ "text": "牛", "timestamp": [1.5, 1.68] },
|
||||
{ "text": "乳", "timestamp": [1.68, 1.92] },
|
||||
{ "text": "は", "timestamp": [1.92, 2.14] },
|
||||
{ "text": "濃", "timestamp": [2.14, 2.32] },
|
||||
{ "text": "い", "timestamp": [2.32, 2.44] },
|
||||
{ "text": "青", "timestamp": [2.44, 2.66] },
|
||||
{ "text": "い", "timestamp": [2.66, 2.76] },
|
||||
{ "text": "牛", "timestamp": [2.76, 3.06] },
|
||||
{ "text": "乳", "timestamp": [3.06, 3.36] },
|
||||
{ "text": "ビ", "timestamp": [3.36, 3.58] },
|
||||
{ "text": "ーン", "timestamp": [3.58, 3.66] },
|
||||
{ "text": "を", "timestamp": [3.66, 3.82] },
|
||||
{ "text": "足", "timestamp": [3.82, 4] },
|
||||
{ "text": "ら", "timestamp": [4, 4.12] },
|
||||
{ "text": "った", "timestamp": [4.12, 4.3] },
|
||||
{ "text": "階", "timestamp": [4.3, 4.56] },
|
||||
{ "text": "の", "timestamp": [4.56, 4.92] },
|
||||
{ "text": "パ", "timestamp": [4.92, 5.1] },
|
||||
{ "text": "ック", "timestamp": [5.1, 5.2] },
|
||||
{ "text": "牛", "timestamp": [5.2, 5.44] },
|
||||
{ "text": "乳", "timestamp": [5.44, 5.64] },
|
||||
{ "text": "で", "timestamp": [5.64, 5.84] },
|
||||
{ "text": "ある", "timestamp": [5.84, 6.06] }
|
||||
]
|
||||
}
|
||||
|
||||
compare(output, target);
|
||||
}
|
||||
|
||||
await transcriber.dispose();
|
||||
|
||||
}, MAX_TEST_EXECUTION_TIME);
|
||||
|
||||
|
||||
it(models[4], async () => {
|
||||
let transcriber = await pipeline('automatic-speech-recognition', m(models[4]));
|
||||
|
||||
let url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav';
|
||||
let audioData = await loadAudio(url);
|
||||
|
|
Loading…
Reference in New Issue