From 8eef154b1eca8694778beddec41222733b3d38fb Mon Sep 17 00:00:00 2001
From: Dave <69651599+D4ve-R@users.noreply.github.com>
Date: Thu, 7 Mar 2024 14:27:49 +0100
Subject: [PATCH] Add WavLM- & Wav2Vec2ForAudioFrameClassification support
 (#611)

* Add WavLMForXVector support
* fix model docs
* Add WavLMForAudioFrameClassification
* Add missing Wav2Vec2ForAudioFrameCl.
* Add doc comment
* Add doc string wav2vec2
* update comment
* make example like python
* Update src/models.js

---------

Co-authored-by: Joshua Lochner
---
 src/models.js | 68 +++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 68 insertions(+)

diff --git a/src/models.js b/src/models.js
index ef3d280..15b656f 100644
--- a/src/models.js
+++ b/src/models.js
@@ -4571,6 +4571,20 @@ export class Wav2Vec2ForSequenceClassification extends Wav2Vec2PreTrainedModel {
         return new SequenceClassifierOutput(await super._call(model_inputs));
     }
 }
+
+/**
+ * Wav2Vec2 Model with a frame classification head on top for tasks like Speaker Diarization.
+ */
+export class Wav2Vec2ForAudioFrameClassification extends Wav2Vec2PreTrainedModel {
+    /**
+     * Calls the model on new inputs.
+     * @param {Object} model_inputs The inputs to the model.
+     * @returns {Promise<TokenClassifierOutput>} An object containing the model's output logits for audio frame classification.
+     */
+    async _call(model_inputs) {
+        return new TokenClassifierOutput(await super._call(model_inputs));
+    }
+}
 //////////////////////////////////////////////////
 
 //////////////////////////////////////////////////
@@ -4868,6 +4882,54 @@ export class WavLMForXVector extends WavLMPreTrainedModel {
     }
 }
 
+/**
+ * WavLM Model with a frame classification head on top for tasks like Speaker Diarization.
+ *
+ * **Example:** Perform speaker diarization with `WavLMForAudioFrameClassification`.
+ * ```javascript
+ * import { AutoProcessor, AutoModelForAudioFrameClassification, read_audio } from '@xenova/transformers';
+ *
+ * // Read and preprocess audio
+ * const processor = await AutoProcessor.from_pretrained('Xenova/wavlm-base-plus-sd');
+ * const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav';
+ * const audio = await read_audio(url, 16000);
+ * const inputs = await processor(audio);
+ *
+ * // Run model with inputs
+ * const model = await AutoModelForAudioFrameClassification.from_pretrained('Xenova/wavlm-base-plus-sd');
+ * const { logits } = await model(inputs);
+ * // {
+ * //   logits: Tensor {
+ * //     dims: [ 1, 549, 2 ],  // [batch_size, num_frames, num_speakers]
+ * //     type: 'float32',
+ * //     data: Float32Array(1098) [-3.5301010608673096, ...],
+ * //     size: 1098
+ * //   }
+ * // }
+ *
+ * const labels = logits[0].sigmoid().tolist().map(
+ *     frames => frames.map(speaker => speaker > 0.5 ? 1 : 0)
+ * );
+ * console.log(labels); // labels is a binary array of shape (num_frames, num_speakers)
+ * // [
+ * //     [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0],
+ * //     [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0],
+ * //     [0, 0], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1],
+ * //     ...
+ * // ]
+ * ```
+ */
+export class WavLMForAudioFrameClassification extends WavLMPreTrainedModel {
+    /**
+     * Calls the model on new inputs.
+     * @param {Object} model_inputs The inputs to the model.
+     * @returns {Promise<TokenClassifierOutput>} An object containing the model's output logits for audio frame classification.
+     */
+    async _call(model_inputs) {
+        return new TokenClassifierOutput(await super._call(model_inputs));
+    }
+}
+
 //////////////////////////////////////////////////
 // SpeechT5 models
 /**
@@ -5695,6 +5757,8 @@ const MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = new Map([
 
 const MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES = new Map([
     ['unispeech-sat', ['UniSpeechSatForAudioFrameClassification', UniSpeechSatForAudioFrameClassification]],
+    ['wavlm', ['WavLMForAudioFrameClassification', WavLMForAudioFrameClassification]],
+    ['wav2vec2', ['Wav2Vec2ForAudioFrameClassification', Wav2Vec2ForAudioFrameClassification]],
 ]);
 
 const MODEL_FOR_IMAGE_MATTING_MAPPING_NAMES = new Map([
@@ -5961,6 +6025,10 @@ export class AutoModelForXVector extends PretrainedMixin {
     static MODEL_CLASS_MAPPINGS = [MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES];
 }
 
+export class AutoModelForAudioFrameClassification extends PretrainedMixin {
+    static MODEL_CLASS_MAPPINGS = [MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES];
+}
+
 export class AutoModelForDocumentQuestionAnswering extends PretrainedMixin {
     static MODEL_CLASS_MAPPINGS = [MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES];
 }
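The patch adds `Wav2Vec2ForAudioFrameClassification` without its own doc example, only the `wavlm` class gets one. For reviewers who want to exercise the new `wav2vec2` mapping end to end, here is a minimal sketch mirroring the WavLM example above. The checkpoint id `Xenova/wav2vec2-base-superb-sd` is an assumption for illustration; substitute any ONNX-converted wav2vec2 speaker-diarization checkpoint.

```javascript
import { AutoProcessor, AutoModelForAudioFrameClassification, read_audio } from '@xenova/transformers';

// ASSUMPTION: hypothetical checkpoint id, used only for illustration.
const model_id = 'Xenova/wav2vec2-base-superb-sd';

// Read and preprocess 16 kHz audio, as in the WavLM example above
const processor = await AutoProcessor.from_pretrained(model_id);
const audio = await read_audio('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav', 16000);
const inputs = await processor(audio);

// The new auto class resolves the 'wav2vec2' model type to
// Wav2Vec2ForAudioFrameClassification via the mapping added in this patch
const model = await AutoModelForAudioFrameClassification.from_pretrained(model_id);
const { logits } = await model(inputs); // dims: [batch_size, num_frames, num_speakers]

// Frame classification is multi-label: sigmoid each logit independently and
// threshold at 0.5, so several speakers can be active in the same frame
const labels = logits[0].sigmoid().tolist().map(
    frames => frames.map(score => score > 0.5 ? 1 : 0)
);
console.log(labels); // binary array of shape (num_frames, num_speakers)
```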