transformers.js/examples/webgpu-embedding-benchmark/main.js

306 lines
8.7 KiB
JavaScript

import './style.css';
import { env, AutoModel, ones } from '@xenova/transformers';
import Chart from 'chart.js/auto';
// Throw an error if WebGPU is not supported
if (!navigator.gpu) {
const err = 'WebGPU is not supported by this browser.';
alert(err)
throw Error(err);
}
env.backends.onnx.wasm.wasmPaths = 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.17.1/dist/';
env.backends.onnx.wasm.numThreads = 1;
// Reference the elements that we will need
const ctx = document.getElementById('chart');
const batchSizes = document.getElementById('batch-sizes');
const xscale = document.getElementById('x-scale');
const yscale = document.getElementById('y-scale');
const sequenceLength = document.getElementById('sequence-length');
const modelID = document.getElementById('model-id');
const status = document.getElementById('status');
const start = document.getElementById('start');
const stop = document.getElementById('stop');
const tests = document.getElementsByClassName('tests');
// Benchmark settings
const NUM_WARMUP_STEPS = 3;
const MODEL_CACHE = new Map();
// Chart configuration
const initChart = () => {
const config = {
type: 'line',
data: {
labels: [],
datasets: [],
},
options: {
responsive: true,
maintainAspectRatio: false,
plugins: {
legend: {
position: 'top',
},
},
scales: {
x: {
title: {
display: true,
text: 'Batch size',
},
min: 1,
},
y: {
title: {
display: true,
text: 'Time (ms)',
},
}
}
},
};
const chart = new Chart(ctx, config);
return chart;
}
let chart = initChart();
const toggleScale = (axis, enabled) => {
chart.options.scales[axis].type = enabled ? 'logarithmic' : 'linear';
chart.update();
}
const getSelectedTests = () => {
return [...tests].filter(x => x.checked);
}
const updateDatasets = () => {
chart.data.datasets = getSelectedTests().map(test => {
const color = test.getAttribute('data-color');
return {
label: test.value,
data: [],
borderColor: `rgba(${color}, 1)`,
backgroundColor: `rgba(${color}, 0.5)`,
}
})
chart.update();
}
updateDatasets();
[...tests].forEach(test => test.addEventListener('change', updateDatasets));
xscale.addEventListener('change', () => toggleScale('x', xscale.checked));
yscale.addEventListener('change', () => toggleScale('y', yscale.checked));
const generateDummyInputs = (batch_size, seqLength) => {
const inputs = ones([batch_size, seqLength]);
const model_inputs = {
input_ids: inputs,
attention_mask: inputs,
}
return model_inputs;
}
let adapterInfo;
let gpuHasFp16 = false;
try {
// Shouldn't fail since the WebGPU model has loaded successfully
const adapter = await navigator.gpu.requestAdapter();
adapterInfo = await adapter.requestAdapterInfo();
gpuHasFp16 = adapter.features.has('shader-f16')
} catch (err) {
adapterInfo = {};
}
if (!gpuHasFp16) {
const element = document.querySelector('.tests[data-device="webgpu"][data-dtype="fp16"]');
element.setAttribute('unsupported', true);
element.disabled = true;
element.title = 'This device does not support fp16 on WebGPU';
}
status.textContent = 'Ready';
let interrupted = false;
start.addEventListener('click', async () => {
const validTests = [...tests].filter(test => !test.getAttribute('unsupported'))
// Update UI
start.disabled = true;
stop.disabled = false;
batchSizes.disabled = true;
sequenceLength.disabled = true;
modelID.disabled = true;
validTests.forEach(test => test.disabled = true);
interrupted = false;
// Get parameters
const model_id = modelID.value;
const batch_sizes = batchSizes.value.split(',').map(x => parseInt(x)).filter(x => x);
const seqLength = parseInt(sequenceLength.value);
const selectedTests = getSelectedTests().map(x => ({
label: x.value,
dtype: x.getAttribute('data-dtype'),
device: x.getAttribute('data-device'),
}));
// Reset
chart.destroy();
chart = initChart();
updateDatasets();
// NOTE: Models must be loaded sequentially (otherwise it will fail due to multiple calls to initWasm())
const testsToRun = new Map();
for (const test of selectedTests) {
const { label, dtype, device, quantized } = test;
const key = `${model_id}///${label}`;
const cached = MODEL_CACHE.get(key);
if (cached) {
testsToRun.set(label, cached);
continue;
}
status.textContent = 'Loading model(s)...';
try {
const model = await AutoModel.from_pretrained(model_id, {
quantized,
device,
dtype,
});
MODEL_CACHE.set(key, model);
testsToRun.set(label, model);
} catch (err) {
status.textContent = err.message;
alert(err.message)
throw err;
}
}
status.textContent = 'Warming up...';
// Warm up: This is important for the WebGPU execution provider, which compiles the shaders on first load
for (let i = 0; i < NUM_WARMUP_STEPS; ++i) {
const model_inputs = generateDummyInputs(1, seqLength);
for (const [label, model] of testsToRun) {
await model(model_inputs);
}
}
status.textContent = 'Running benchmark...';
for (const batch_size of batch_sizes) {
if (interrupted) break;
const model_inputs = generateDummyInputs(batch_size, seqLength);
const times = []
for (const [label, model] of testsToRun) {
const start = performance.now();
await model(model_inputs);
const end = performance.now();
times.push(end - start);
}
chart.data.labels.push(batch_size);
for (let i = 0; i < times.length; ++i) {
chart.data.datasets[i].data.push(times[i]);
}
chart.update();
}
// Calculate max speedup:
if (chart.data.labels.length === 0) return;
const testNames = [...testsToRun.keys()];
const table = generateResultsTable(model_id, testNames, chart.data, seqLength);
// Calculate slowest and fastest times
let minMaxTimes = [Infinity, 0];
let minMaxIndices = [0, 0];
for (let i = 0; i < chart.data.datasets.length; i++) {
const lastTime = chart.data.datasets[i].data.at(-1);
if (lastTime < minMaxTimes[0]) {
minMaxTimes[0] = lastTime;
minMaxIndices[0] = i;
}
if (lastTime > minMaxTimes[1]) {
minMaxTimes[1] = lastTime;
minMaxIndices[1] = i;
}
}
const speedup = minMaxTimes[1] / minMaxTimes[0];
const roundedSpeedup = speedup.toFixed(2);
const params = new URLSearchParams({
title: `⚡ WebGPU Benchmark Results (${roundedSpeedup}x speedup)`,
description: table.outerHTML,
});
const paramsStr = params.toString();
status.innerHTML = `⚡ Done! ${testNames.at(minMaxIndices[0])} is <strong>${roundedSpeedup}x</strong> faster than ${testNames.at(minMaxIndices[1])}! ⚡<br><a href="https://huggingface.co/spaces/Xenova/webgpu-embedding-benchmark/discussions/new?${paramsStr}" target="_blank">Share results</a>`;
start.disabled = false;
stop.disabled = true;
batchSizes.disabled = false;
sequenceLength.disabled = false;
modelID.disabled = false;
validTests.forEach(test => test.disabled = false);
});
start.disabled = false;
stop.addEventListener('click', () => {
status.textContent = 'Stopping...';
interrupted = true;
stop.disabled = true;
});
function generateResultsTable(model_id, testNames, data, sequence_length) {
const datasets = data.datasets.map(d => d.data);
const batch_sizes = data.labels;
const container = document.createElement('div');
const table = document.createElement('table');
const thead = table.createTHead();
const tbody = table.createTBody();
// Add header row
const headerRow = thead.insertRow();
headerRow.insertCell().textContent = 'Batch Size';
testNames.forEach(model => {
headerRow.insertCell().textContent = model;
});
// Add data rows
batch_sizes.forEach((batchSize, rowIndex) => {
const row = tbody.insertRow();
row.insertCell().textContent = batchSize;
datasets.forEach(dataset => {
row.insertCell().textContent = dataset[rowIndex].toFixed(2);
});
});
container.appendChild(table);
const createBulletPoint = (text) => {
const li = document.createElement('li');
li.textContent = text;
return li;
}
// Add other information
const info = document.createElement('ul');
info.appendChild(createBulletPoint(`Model: ${model_id}`));
info.appendChild(createBulletPoint(`Tests run: ${testNames.join(', ')}`));
info.appendChild(createBulletPoint(`Sequence length: ${sequence_length}`));
info.appendChild(createBulletPoint(`Browser: ${navigator.userAgent}`));
info.appendChild(createBulletPoint(`GPU: vendor=${adapterInfo.vendor}, architecture=${adapterInfo.architecture}, device=${adapterInfo.device}, description=${adapterInfo.description}`));
container.appendChild(info);
return container;
}