Add WavLM- & Wav2Vec2ForAudioFrameClassification support (#611)

* Add WavLMForXVector support

* fix model docs

* Add WavLMForAudioFrameClassification

* Add missing Wav2Vec2ForAudioFrameClassification

* Add doc comment

* Add doc string wav2vec2

* update comment

* make example like python

* Update src/models.js

---------

Co-authored-by: Joshua Lochner <admin@xenova.com>
Dave authored on 2024-03-07 14:27:49 +01:00; committed by GitHub
parent 5bb8d25337
commit 8eef154b1e
1 changed file with 68 additions and 0 deletions

@@ -4571,6 +4571,20 @@ export class Wav2Vec2ForSequenceClassification extends Wav2Vec2PreTrainedModel {
return new SequenceClassifierOutput(await super._call(model_inputs));
}
}
/**
* Wav2Vec2 Model with a frame classification head on top for tasks like Speaker Diarization.
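*
* **Example:** A minimal sketch of audio frame classification with `Wav2Vec2ForAudioFrameClassification`.
* The checkpoint id below is a hypothetical placeholder, not a verified repository; substitute any
* compatible ONNX export with a frame classification head.
* ```javascript
* import { AutoProcessor, AutoModelForAudioFrameClassification, read_audio } from '@xenova/transformers';
*
* // Read and preprocess audio
* const processor = await AutoProcessor.from_pretrained('Xenova/wav2vec2-base-sd'); // placeholder id (assumption)
* const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav';
* const audio = await read_audio(url, 16000);
* const inputs = await processor(audio);
*
* // Run model with inputs; logits have shape [batch_size, num_frames, num_labels]
* const model = await AutoModelForAudioFrameClassification.from_pretrained('Xenova/wav2vec2-base-sd'); // placeholder id
* const { logits } = await model(inputs);
* ```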
*/
export class Wav2Vec2ForAudioFrameClassification extends Wav2Vec2PreTrainedModel {
/**
* Calls the model on new inputs.
* @param {Object} model_inputs The inputs to the model.
* @returns {Promise<TokenClassifierOutput>} An object containing the model's output logits for audio frame classification.
*/
async _call(model_inputs) {
return new TokenClassifierOutput(await super._call(model_inputs));
}
}
//////////////////////////////////////////////////
//////////////////////////////////////////////////
@@ -4868,6 +4882,54 @@ export class WavLMForXVector extends WavLMPreTrainedModel {
}
}
/**
* WavLM Model with a frame classification head on top for tasks like Speaker Diarization.
*
* **Example:** Perform speaker diarization with `WavLMForAudioFrameClassification`.
* ```javascript
* import { AutoProcessor, AutoModelForAudioFrameClassification, read_audio } from '@xenova/transformers';
*
* // Read and preprocess audio
* const processor = await AutoProcessor.from_pretrained('Xenova/wavlm-base-plus-sd');
* const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav';
* const audio = await read_audio(url, 16000);
* const inputs = await processor(audio);
*
* // Run model with inputs
* const model = await AutoModelForAudioFrameClassification.from_pretrained('Xenova/wavlm-base-plus-sd');
* const { logits } = await model(inputs);
* // {
* // logits: Tensor {
* // dims: [ 1, 549, 2 ], // [batch_size, num_frames, num_speakers]
* // type: 'float32',
* // data: Float32Array(1098) [-3.5301010608673096, ...],
* // size: 1098
* // }
* // }
*
* const labels = logits[0].sigmoid().tolist().map(
* frames => frames.map(speaker => speaker > 0.5 ? 1 : 0)
* );
* console.log(labels); // labels is a binary array of shape (num_frames, num_speakers)
* // [
* // [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0],
* // [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0],
* // [0, 0], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1],
* // ...
* // ]
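*
* // Illustrative follow-up (not part of the original example): merge per-frame labels into
* // time-stamped speaker segments. The frame duration is approximated from the audio length;
* // derive the exact value from the model's config if precision matters.
* const frameDuration = (audio.length / 16000) / labels.length; // seconds per frame
* const segments = [];
* for (let speaker = 0; speaker < labels[0].length; ++speaker) {
*   let start = null;
*   labels.forEach((frame, i) => {
*     if (frame[speaker] === 1 && start === null) start = i;
*     if (frame[speaker] === 0 && start !== null) {
*       segments.push({ speaker, start: start * frameDuration, end: i * frameDuration });
*       start = null;
*     }
*   });
*   if (start !== null) {
*     segments.push({ speaker, start: start * frameDuration, end: labels.length * frameDuration });
*   }
* }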
* ```
*/
export class WavLMForAudioFrameClassification extends WavLMPreTrainedModel {
/**
* Calls the model on new inputs.
* @param {Object} model_inputs The inputs to the model.
* @returns {Promise<TokenClassifierOutput>} An object containing the model's output logits for audio frame classification.
*/
async _call(model_inputs) {
return new TokenClassifierOutput(await super._call(model_inputs));
}
}
//////////////////////////////////////////////////
// SpeechT5 models
/**
@@ -5695,6 +5757,8 @@ const MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = new Map([
const MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES = new Map([
['unispeech-sat', ['UniSpeechSatForAudioFrameClassification', UniSpeechSatForAudioFrameClassification]],
['wavlm', ['WavLMForAudioFrameClassification', WavLMForAudioFrameClassification]],
['wav2vec2', ['Wav2Vec2ForAudioFrameClassification', Wav2Vec2ForAudioFrameClassification]],
]);
const MODEL_FOR_IMAGE_MATTING_MAPPING_NAMES = new Map([
@@ -5961,6 +6025,10 @@ export class AutoModelForXVector extends PretrainedMixin {
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES];
}
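/**
* Helper class which is used to instantiate pretrained audio frame classification models
* with the `from_pretrained` function. The chosen model class is determined by the type
* specified in the model config.
*/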
export class AutoModelForAudioFrameClassification extends PretrainedMixin {
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES];
}
export class AutoModelForDocumentQuestionAnswering extends PretrainedMixin {
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES];
}