Add WavLM- & Wav2Vec2ForAudioFrameClassification support (#611)
* Add WavLMForXVector support
* Fix model docs
* Add WavLMForAudioFrameClassification
* Add missing Wav2Vec2ForAudioFrameClassification
* Add doc comment
* Add doc string for wav2vec2
* Update comment
* Make example match the Python version
* Update src/models.js

---------

Co-authored-by: Joshua Lochner <admin@xenova.com>
parent 5bb8d25337
commit 8eef154b1e
@@ -4571,6 +4571,20 @@ export class Wav2Vec2ForSequenceClassification extends Wav2Vec2PreTrainedModel {
         return new SequenceClassifierOutput(await super._call(model_inputs));
     }
 }
+
+/**
+ * Wav2Vec2 Model with a frame classification head on top for tasks like Speaker Diarization.
+ */
+export class Wav2Vec2ForAudioFrameClassification extends Wav2Vec2PreTrainedModel {
+    /**
+     * Calls the model on new inputs.
+     * @param {Object} model_inputs The inputs to the model.
+     * @returns {Promise<TokenClassifierOutput>} An object containing the model's output logits for audio frame classification.
+     */
+    async _call(model_inputs) {
+        return new TokenClassifierOutput(await super._call(model_inputs));
+    }
+}
 //////////////////////////////////////////////////
 
 //////////////////////////////////////////////////
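Unlike the WavLM class below, the new Wav2Vec2 class ships without a usage example. A minimal sketch of driving it through the processor and the new Auto factory follows; `'Xenova/wav2vec2-base-sd'` is a hypothetical checkpoint id used only for illustration, as any wav2vec2 model exported with a frame-classification head would do:

```javascript
import { AutoProcessor, AutoModelForAudioFrameClassification, read_audio } from '@xenova/transformers';

// NOTE: hypothetical checkpoint id, for illustration only; substitute a
// wav2vec2 model exported with a frame-classification head.
const model_id = 'Xenova/wav2vec2-base-sd';

// Read and preprocess 16 kHz audio
const processor = await AutoProcessor.from_pretrained(model_id);
const audio = await read_audio('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav', 16000);
const inputs = await processor(audio);

// Run the model; logits have shape [batch_size, num_frames, num_labels]
const model = await AutoModelForAudioFrameClassification.from_pretrained(model_id);
const { logits } = await model(inputs);
```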
@@ -4868,6 +4882,54 @@ export class WavLMForXVector extends WavLMPreTrainedModel {
     }
 }
 
+/**
+ * WavLM Model with a frame classification head on top for tasks like Speaker Diarization.
+ *
+ * **Example:** Perform speaker diarization with `WavLMForAudioFrameClassification`.
+ * ```javascript
+ * import { AutoProcessor, AutoModelForAudioFrameClassification, read_audio } from '@xenova/transformers';
+ *
+ * // Read and preprocess audio
+ * const processor = await AutoProcessor.from_pretrained('Xenova/wavlm-base-plus-sd');
+ * const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav';
+ * const audio = await read_audio(url, 16000);
+ * const inputs = await processor(audio);
+ *
+ * // Run model with inputs
+ * const model = await AutoModelForAudioFrameClassification.from_pretrained('Xenova/wavlm-base-plus-sd');
+ * const { logits } = await model(inputs);
+ * // {
+ * //   logits: Tensor {
+ * //     dims: [ 1, 549, 2 ],  // [batch_size, num_frames, num_speakers]
+ * //     type: 'float32',
+ * //     data: Float32Array(1098) [-3.5301010608673096, ...],
+ * //     size: 1098
+ * //   }
+ * // }
+ *
+ * const labels = logits[0].sigmoid().tolist().map(
+ *     frames => frames.map(speaker => speaker > 0.5 ? 1 : 0)
+ * );
+ * console.log(labels); // labels is a binary array of shape (num_frames, num_speakers)
+ * // [
+ * //     [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0],
+ * //     [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0],
+ * //     [0, 0], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1],
+ * //     ...
+ * // ]
+ * ```
+ */
+export class WavLMForAudioFrameClassification extends WavLMPreTrainedModel {
+    /**
+     * Calls the model on new inputs.
+     * @param {Object} model_inputs The inputs to the model.
+     * @returns {Promise<TokenClassifierOutput>} An object containing the model's output logits for audio frame classification.
+     */
+    async _call(model_inputs) {
+        return new TokenClassifierOutput(await super._call(model_inputs));
+    }
+}
+
 //////////////////////////////////////////////////
 // SpeechT5 models
 /**
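The docstring example stops at a per-frame binary mask. To report diarization results as time segments, each frame index can be mapped back to seconds; the 0.02 s stride below is an assumption based on the standard wav2vec2-style feature encoder (320-sample hop at 16 kHz), which is consistent with the example's 549 frames for roughly 11 s of audio:

```javascript
// Convert the per-frame binary mask from the example above into time segments
// for one speaker. FRAME_STRIDE is an assumption (wav2vec2-style encoders hop
// 320 samples at 16 kHz), not a value read from the model config.
const FRAME_STRIDE = 0.02; // seconds per output frame (assumed)

function toSegments(labels, speaker) {
    const segments = [];
    let start = null;
    labels.forEach((frame, i) => {
        if (frame[speaker] === 1 && start === null) {
            start = i; // segment opens
        } else if (frame[speaker] === 0 && start !== null) {
            segments.push({ speaker, start: start * FRAME_STRIDE, end: i * FRAME_STRIDE });
            start = null; // segment closes
        }
    });
    if (start !== null) { // close a segment still open at the end of the clip
        segments.push({ speaker, start: start * FRAME_STRIDE, end: labels.length * FRAME_STRIDE });
    }
    return segments;
}

// e.g. toSegments(labels, 1) → [ { speaker: 1, start: 0.26, end: ... }, ... ]
```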
@@ -5695,6 +5757,8 @@ const MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = new Map([
 
 const MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES = new Map([
     ['unispeech-sat', ['UniSpeechSatForAudioFrameClassification', UniSpeechSatForAudioFrameClassification]],
+    ['wavlm', ['WavLMForAudioFrameClassification', WavLMForAudioFrameClassification]],
+    ['wav2vec2', ['Wav2Vec2ForAudioFrameClassification', Wav2Vec2ForAudioFrameClassification]],
 ]);
 
 const MODEL_FOR_IMAGE_MATTING_MAPPING_NAMES = new Map([
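Each entry in these maps keys a checkpoint's `config.model_type` onto a `[class name, class]` pair. A simplified sketch of how such a mapping might be consumed at load time (illustration only; the actual resolution presumably happens inside `PretrainedMixin.from_pretrained`):

```javascript
// Simplified sketch of model_type → class dispatch; not the library's real
// implementation, just an illustration of the Map's structure.
function resolveModelClass(mapping, config) {
    const entry = mapping.get(config.model_type); // e.g. 'wavlm' or 'wav2vec2'
    if (!entry) {
        throw new Error(`Unsupported model type: ${config.model_type}`);
    }
    const [className, modelClass] = entry;
    return modelClass; // e.g. WavLMForAudioFrameClassification
}
```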
@@ -5961,6 +6025,10 @@ export class AutoModelForXVector extends PretrainedMixin {
     static MODEL_CLASS_MAPPINGS = [MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES];
 }
 
+export class AutoModelForAudioFrameClassification extends PretrainedMixin {
+    static MODEL_CLASS_MAPPINGS = [MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES];
+}
+
 export class AutoModelForDocumentQuestionAnswering extends PretrainedMixin {
     static MODEL_CLASS_MAPPINGS = [MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES];
 }
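With the mappings registered, a single factory call now covers all three supported architectures; which concrete class comes back depends only on the checkpoint's `model_type`. A short sketch using the WavLM checkpoint from the docstring example:

```javascript
import { AutoModelForAudioFrameClassification } from '@xenova/transformers';

// The factory reads the checkpoint's config.model_type and instantiates the
// matching entry from MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES:
// 'wavlm' here resolves to WavLMForAudioFrameClassification.
const model = await AutoModelForAudioFrameClassification.from_pretrained('Xenova/wavlm-base-plus-sd');
console.log(model.constructor.name); // 'WavLMForAudioFrameClassification'
```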