// Source: transformers.js/examples/semantic-audio-search/worker.js
// (JavaScript, 95 lines, 3.1 KiB)

import { env, AutoTokenizer, ClapTextModelWithProjection } from '@xenova/transformers';
import { getCachedFile } from './utils';
// Skip local model check: turn off the library's `allowLocalModels` setting before
// any model is requested below. Presumably this makes transformers.js fetch model
// files remotely instead of probing local paths first — confirm against its `env` docs.
env.allowLocalModels = false;
/**
 * Lazily loads and memoizes the three resources the worker needs to answer a
 * query: the tokenizer, the CLAP text model and the pre-computed audio
 * embeddings.
 *
 * Each static slot stores the *promise* of its resource (not the resolved
 * value), so concurrent callers share one in-flight load instead of starting
 * duplicates.
 */
class ApplicationSingleton {
    static model_id = 'Xenova/larger_clap_music_and_speech';
    static BASE_URL = 'https://huggingface.co/datasets/Xenova/MusicBenchEmbedded/resolve/main/';

    // Memoized promises; `null` means "not requested yet".
    static tokenizer = null;
    static text_model = null;
    static embeddings = null;

    /**
     * Starts any load that has not been requested yet and waits for all three.
     *
     * Notes on the memoization:
     * - `progress_callback` is only handed to a loader on the call that actually
     *   starts that load; later calls reuse the stored promise and their
     *   callback is not used.
     * - A rejected promise stays stored, so a failed load is not retried by
     *   subsequent calls.
     *
     * @param {Function|null} [progress_callback=null] Passed through to both
     *   `from_pretrained` calls.
     * @returns {Promise<Array>} Resolves to `[tokenizer, text_model, embeddings]`,
     *   where `embeddings` is a `Float32Array` view over the downloaded file.
     *   Rejects if any of the three loads rejects.
     */
    static async getInstance(progress_callback = null) {
        this.tokenizer ??= AutoTokenizer.from_pretrained(this.model_id, { progress_callback });

        this.text_model ??= ClapTextModelWithProjection.from_pretrained(this.model_id, {
            progress_callback,
            quantized: true, // TODO allow user to select quantized or not
        });

        // `getCachedFile` already returns a promise, so chain on it directly
        // instead of re-wrapping it in `new Promise(...)`. A rejection from the
        // download or from the Float32Array construction propagates unchanged.
        // NOTE(review): assumes the file resolves to an ArrayBuffer whose byte
        // length is a multiple of 4 — confirm against `getCachedFile` in ./utils.
        this.embeddings ??= getCachedFile(`${this.BASE_URL}audio-embeddings_52768-512_32bit.bin`)
            .then((buffer) => new Float32Array(buffer));

        return Promise.all([this.tokenizer, this.text_model, this.embeddings]);
    }
}
/**
 * Computes the cosine similarity between one query embedding and every
 * embedding stored in a flat, row-major database buffer.
 *
 * Row `i` of the database occupies indices
 * `[i * embed_dim, (i + 1) * embed_dim)` of `database_embeds`.
 *
 * @param {ArrayLike<number>} query_embeds The query vector; only its first
 *   `embed_dim` entries are read.
 * @param {ArrayLike<number>} database_embeds Concatenated database vectors.
 * @param {number} [embed_dim=512] Number of components per vector.
 * @returns {number[]} One score per database row, in row order. A pair whose
 *   norm product is zero scores 0 (instead of NaN), so the result stays
 *   safely sortable with a numeric comparator.
 * @throws {RangeError} If `database_embeds.length / embed_dim` is not a valid
 *   array length (e.g. the buffer length is not a multiple of `embed_dim`).
 */
function cosineSimilarity(query_embeds, database_embeds, embed_dim = 512) {
    const numDB = database_embeds.length / embed_dim;
    // Plain Array on purpose: callers map each score to a [score, index] pair.
    const similarityScores = new Array(numDB);

    // The query norm does not depend on the database row, so compute it once.
    let normEmbeds = 0;
    for (let j = 0; j < embed_dim; ++j) {
        const embedValue = query_embeds[j];
        normEmbeds += embedValue * embedValue;
    }
    const queryNorm = Math.sqrt(normEmbeds);

    for (let i = 0; i < numDB; ++i) {
        // Index straight into the flat buffer; slicing would copy every row.
        const startOffset = i * embed_dim;
        let dotProduct = 0;
        let normDB = 0;
        for (let j = 0; j < embed_dim; ++j) {
            const dbValue = database_embeds[startOffset + j];
            dotProduct += query_embeds[j] * dbValue;
            normDB += dbValue * dbValue;
        }

        const denominator = queryNorm * Math.sqrt(normDB);
        // Cosine similarity is undefined for a zero vector; report 0 rather than NaN.
        similarityScores[i] = denominator === 0 ? 0 : dotProduct / denominator;
    }

    return similarityScores;
}
// Listen for messages from the main thread.
//
// Each incoming message is expected to carry the search text in
// `event.data.query`. The worker answers with:
//   - whatever the loaders pass to the progress callback (forwarded as-is),
//   - `{ status: 'ready' }` once tokenizer, model and embeddings are available,
//   - `{ status: 'complete', output }` where `output` holds up to 100
//     `[score, index]` pairs sorted by descending score,
//   - `{ status: 'error', error }` if anything above fails.
self.addEventListener('message', async (event) => {
    try {
        // Get the tokenizer, model, and embeddings. When called for the first time,
        // this will load the files and cache them for future use.
        // The callback is wrapped (rather than passing `self.postMessage` itself)
        // so the call does not depend on the receiver or on any extra arguments
        // the library might supply.
        const [tokenizer, text_model, embeddings] = await ApplicationSingleton.getInstance(
            (progress) => self.postMessage(progress),
        );

        // Tell the main thread that loading has finished
        self.postMessage({ status: 'ready' });

        // Run tokenization
        const text_inputs = tokenizer(event.data.query, { padding: true, truncation: true });

        // Compute embeddings
        const { text_embeds } = await text_model(text_inputs);

        // Compute similarity scores
        const scores = cosineSimilarity(text_embeds.data, embeddings);

        const output = scores
            .map((score, i) => [score, i]) // Save index
            .sort((a, b) => b[0] - a[0]) // Sort by scores
            .slice(0, 100); // Get top 100 results

        // Send the output back to the main thread
        self.postMessage({
            status: 'complete',
            output: output,
        });
    } catch (error) {
        // Without this the rejection would stay inside the worker and the main
        // thread would wait forever for a 'complete' message.
        self.postMessage({
            status: 'error',
            error: error instanceof Error ? error.message : String(error),
        });
        // Rethrow so the failure still surfaces in the worker exactly as before.
        throw error;
    }
});