Add `ViTFeatureExtractor`

This commit is contained in:
Joshua Lochner 2023-03-11 19:12:10 +02:00
parent 051a193bed
commit 6f335fc9a0
2 changed files with 123 additions and 11 deletions

View File

@ -28,6 +28,7 @@
},
"homepage": "https://github.com/xenova/transformers.js#readme",
"dependencies": {
"canvas": "^2.11.0",
"onnxruntime-web": "^1.14.0"
},
"devDependencies": {

View File

@ -5,9 +5,12 @@ const {
getFile
} = require("./utils.js");
const { Tensor } = require('onnxruntime-web')
const { Tensor } = require('onnxruntime-web');
const FFT = require('./fft.js')
const FFT = require('./fft.js');
const { transpose } = require("./tensor_utils.js");
const { createCanvas, loadImage } = require('canvas');
class AutoProcessor {
@ -24,6 +27,9 @@ class AutoProcessor {
case 'WhisperFeatureExtractor':
feature_extractor = new WhisperFeatureExtractor(preprocessorConfig)
break;
case 'ViTFeatureExtractor':
feature_extractor = new ViTFeatureExtractor(preprocessorConfig)
break;
default:
throw new Error(`Unknown Feature Extractor type: ${preprocessorConfig.feature_extractor_type}`);
@ -33,22 +39,122 @@ class AutoProcessor {
case 'WhisperProcessor':
processor_class = WhisperProcessor;
break;
default:
throw new Error(`Unknown Processor type: ${preprocessorConfig.processor_class}`);
// No associated processor class, use default
processor_class = Processor;
}
return new processor_class(feature_extractor);
}
}
class WhisperFeatureExtractor extends Callable {
class FeatureExtractor extends Callable {
constructor(config) {
super();
this.config = config
}
}
class ViTFeatureExtractor extends FeatureExtractor {
constructor(config) {
super(config);
this.image_mean = this.config.image_mean;
if (!Array.isArray(this.image_mean)) {
this.image_mean = new Array(3).fill(this.image_mean);
}
this.image_std = this.config.image_std;
if (!Array.isArray(this.image_std)) {
this.image_std = new Array(3).fill(this.image_std);
}
this.do_rescale = this.config.do_rescale ?? true;
this.do_normalize = this.config.do_normalize;
this.do_resize = this.config.do_resize;
this.size = this.config.size;
}
async resize(image, width, height) {
// Create a canvas element to hold the resized image
const canvas = createCanvas(width, height);
// Draw the resized image onto the canvas
const ctx = canvas.getContext('2d')
ctx.drawImage(image, 0, 0, width, height)
// Get the pixel data for the entire canvas as a typed array
const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
const data = imageData.data;
return data
}
async preprocess(image) {
if (isString(image)) {
// loading from path
image = await loadImage(image);
}
// resize all images
let width = image.width;
let height = image.height;
if (this.do_resize) {
width = this.size;
height = this.size;
}
let data = await this.resize(image, width, height)
// Do not include alpha channel
let convData = new Float32Array(data.length * 3 / 4);
let outIndex = 0;
for (let i = 0; i < data.length; i += 4) {
for (let j = 0; j < 3; ++j) {
convData[outIndex++] = data[i + j];
}
}
if (this.do_rescale) {
for (let i = 0; i < convData.length; ++i) {
convData[i] = convData[i] / 255;
}
}
if (this.do_normalize) {
for (let i = 0; i < convData.length; i += 3) {
for (let j = 0; j < 3; ++j) {
convData[i + j] = (convData[i + j] - this.image_mean[j]) / this.image_std[j]
}
}
}
let img = new Tensor('float32', convData, [width, height, 3]);
let transposed = transpose(img, [2, 0, 1]);
return transposed;
}
async _call(images) {
if (!Array.isArray(images)) {
images = [images];
}
// Convert any non-images to images
images = await Promise.all(images.map(x => this.preprocess(x)));
images.forEach(x => x.dims = [1, ...x.dims]) // add batch dimension
return {
pixel_values: images
};
}
}
class WhisperFeatureExtractor extends FeatureExtractor {
calcOffset(i, w) {
return Math.abs((i + w) % (2 * w) - w);
@ -315,7 +421,7 @@ class WhisperFeatureExtractor extends Callable {
};
}
_call(audio) {
async _call(audio) {
if (audio.length > this.config.n_samples) {
// TODO: https://github.com/openai/whisper/discussions/726
@ -341,15 +447,20 @@ class Processor extends Callable {
this.feature_extractor = feature_extractor;
// TODO use tokenizer here?
}
async _call(input) {
return await this.feature_extractor(input);
}
}
function isString(text) {
return typeof text === 'string' || text instanceof String
}
class WhisperProcessor extends Processor {
async _call(audio) {
if (typeof audio === 'string' || audio instanceof String) {
if (isString(audio)) {
// Attempting to load from path
if (typeof AudioContext === 'undefined') {
@ -368,7 +479,7 @@ class WhisperProcessor extends Processor {
}
// TODO use sampling rate?
return this.feature_extractor(audio)
return await this.feature_extractor(audio)
}
}