Add `ClapAudioModelWithProjection` and `ClapTextModelWithProjection`

This commit is contained in:
Joshua Lochner 2023-12-01 22:20:36 +02:00
parent 3a4c71fee1
commit 9fa43597b0
1 changed file with 74 additions and 0 deletions

View File

@ -4043,6 +4043,77 @@ export class FalconForCausalLM extends FalconPreTrainedModel { }
// Empty marker subclass: the shared base for all CLAP model variants below.
// All behavior is inherited from PreTrainedModel; no overrides are added here.
export class ClapPreTrainedModel extends PreTrainedModel { }
// The full CLAP model. Behavior comes entirely from the base class (no overrides),
// so it loads the default model file and exposes the combined model outputs.
export class ClapModel extends ClapPreTrainedModel { }
/**
 * CLAP Text Model with a projection layer on top (a linear layer on top of the pooled output).
 *
 * **Example:** Compute text embeddings with `ClapTextModelWithProjection`.
 *
 * ```javascript
 * import { AutoTokenizer, ClapTextModelWithProjection } from '@xenova/transformers';
 *
 * // Load tokenizer and text model
 * const tokenizer = await AutoTokenizer.from_pretrained('Xenova/clap-htsat-unfused');
 * const text_model = await ClapTextModelWithProjection.from_pretrained('Xenova/clap-htsat-unfused');
 *
 * // Run tokenization
 * const texts = ['a sound of a cat', 'a sound of a dog'];
 * const text_inputs = tokenizer(texts, { padding: true, truncation: true });
 *
 * // Compute embeddings
 * const { text_embeds } = await text_model(text_inputs);
 * // Tensor {
 * //   dims: [ 2, 512 ],
 * //   type: 'float32',
 * //   data: Float32Array(1024) [ ... ],
 * //   size: 1024
 * // }
 * ```
 */
export class ClapTextModelWithProjection extends ClapPreTrainedModel {

    /** @type {PreTrainedModel.from_pretrained} */
    static async from_pretrained(pretrained_model_name_or_path, options = {}) {
        // Default to the text-only model file. Build a fresh options object
        // rather than mutating the caller's `options` (a shared options object
        // reused across several from_pretrained calls must not be altered).
        return super.from_pretrained(pretrained_model_name_or_path, {
            ...options,
            model_file_name: options.model_file_name ?? 'text_model',
        });
    }
}
/**
 * CLAP Audio Model with a projection layer on top (a linear layer on top of the pooled output).
 *
 * **Example:** Compute audio embeddings with `ClapAudioModelWithProjection`.
 *
 * ```javascript
 * import { AutoProcessor, ClapAudioModelWithProjection, read_audio } from '@xenova/transformers';
 *
 * // Load processor and audio model
 * const processor = await AutoProcessor.from_pretrained('Xenova/clap-htsat-unfused');
 * const audio_model = await ClapAudioModelWithProjection.from_pretrained('Xenova/clap-htsat-unfused');
 *
 * // Read audio and run processor
 * const audio = await read_audio('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/cat_meow.wav');
 * const audio_inputs = await processor(audio);
 *
 * // Compute embeddings
 * const { audio_embeds } = await audio_model(audio_inputs);
 * // Tensor {
 * //   dims: [ 1, 512 ],
 * //   type: 'float32',
 * //   data: Float32Array(512) [ ... ],
 * //   size: 512
 * // }
 * ```
 */
export class ClapAudioModelWithProjection extends ClapPreTrainedModel {

    /** @type {PreTrainedModel.from_pretrained} */
    static async from_pretrained(pretrained_model_name_or_path, options = {}) {
        // Default to the audio-only model file. Build a fresh options object
        // rather than mutating the caller's `options` (a shared options object
        // reused across several from_pretrained calls must not be altered).
        return super.from_pretrained(pretrained_model_name_or_path, {
            ...options,
            model_file_name: options.model_file_name ?? 'audio_model',
        });
    }
}
//////////////////////////////////////////////////
@ -4372,6 +4443,9 @@ for (const [mappings, type] of MODEL_CLASS_TYPE_MAPPING) {
// [registered name, class, model type] triples for models that need custom
// registration (e.g., a non-default model file name) — presumably these cannot
// go through the standard mapping tables above; verify against the loop below.
const CUSTOM_MAPPING = [
    ['CLIPTextModelWithProjection', CLIPTextModelWithProjection, MODEL_TYPES.EncoderOnly],
    ['CLIPVisionModelWithProjection', CLIPVisionModelWithProjection, MODEL_TYPES.EncoderOnly],
    ['ClapTextModelWithProjection', ClapTextModelWithProjection, MODEL_TYPES.EncoderOnly],
    ['ClapAudioModelWithProjection', ClapAudioModelWithProjection, MODEL_TYPES.EncoderOnly],
]; // terminate explicitly rather than relying on ASI
for (const [name, model, type] of CUSTOM_MAPPING) {
MODEL_TYPE_MAPPING.set(name, type);