Add `ViTFeatureExtractor`
This commit is contained in:
parent
051a193bed
commit
6f335fc9a0
|
@ -28,6 +28,7 @@
|
|||
},
|
||||
"homepage": "https://github.com/xenova/transformers.js#readme",
|
||||
"dependencies": {
|
||||
"canvas": "^2.11.0",
|
||||
"onnxruntime-web": "^1.14.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
|
|
|
@ -5,9 +5,12 @@ const {
|
|||
getFile
|
||||
} = require("./utils.js");
|
||||
|
||||
const { Tensor } = require('onnxruntime-web')
|
||||
const { Tensor } = require('onnxruntime-web');
|
||||
|
||||
const FFT = require('./fft.js')
|
||||
const FFT = require('./fft.js');
|
||||
const { transpose } = require("./tensor_utils.js");
|
||||
|
||||
const { createCanvas, loadImage } = require('canvas');
|
||||
|
||||
|
||||
class AutoProcessor {
|
||||
|
@ -24,6 +27,9 @@ class AutoProcessor {
|
|||
case 'WhisperFeatureExtractor':
|
||||
feature_extractor = new WhisperFeatureExtractor(preprocessorConfig)
|
||||
break;
|
||||
case 'ViTFeatureExtractor':
|
||||
feature_extractor = new ViTFeatureExtractor(preprocessorConfig)
|
||||
break;
|
||||
|
||||
default:
|
||||
throw new Error(`Unknown Feature Extractor type: ${preprocessorConfig.feature_extractor_type}`);
|
||||
|
@ -33,22 +39,122 @@ class AutoProcessor {
|
|||
case 'WhisperProcessor':
|
||||
processor_class = WhisperProcessor;
|
||||
break;
|
||||
|
||||
default:
|
||||
throw new Error(`Unknown Processor type: ${preprocessorConfig.processor_class}`);
|
||||
|
||||
// No associated processor class, use default
|
||||
processor_class = Processor;
|
||||
}
|
||||
|
||||
return new processor_class(feature_extractor);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class WhisperFeatureExtractor extends Callable {
|
||||
class FeatureExtractor extends Callable {
|
||||
constructor(config) {
|
||||
super();
|
||||
this.config = config
|
||||
}
|
||||
}
|
||||
class ViTFeatureExtractor extends FeatureExtractor {
|
||||
|
||||
constructor(config) {
|
||||
super(config);
|
||||
|
||||
this.image_mean = this.config.image_mean;
|
||||
if (!Array.isArray(this.image_mean)) {
|
||||
this.image_mean = new Array(3).fill(this.image_mean);
|
||||
}
|
||||
|
||||
this.image_std = this.config.image_std;
|
||||
if (!Array.isArray(this.image_std)) {
|
||||
this.image_std = new Array(3).fill(this.image_std);
|
||||
}
|
||||
|
||||
this.do_rescale = this.config.do_rescale ?? true;
|
||||
this.do_normalize = this.config.do_normalize;
|
||||
|
||||
this.do_resize = this.config.do_resize;
|
||||
this.size = this.config.size;
|
||||
}
|
||||
|
||||
async resize(image, width, height) {
|
||||
|
||||
// Create a canvas element to hold the resized image
|
||||
const canvas = createCanvas(width, height);
|
||||
|
||||
// Draw the resized image onto the canvas
|
||||
const ctx = canvas.getContext('2d')
|
||||
ctx.drawImage(image, 0, 0, width, height)
|
||||
|
||||
// Get the pixel data for the entire canvas as a typed array
|
||||
const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
|
||||
const data = imageData.data;
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
async preprocess(image) {
|
||||
if (isString(image)) {
|
||||
// loading from path
|
||||
image = await loadImage(image);
|
||||
}
|
||||
|
||||
|
||||
// resize all images
|
||||
let width = image.width;
|
||||
let height = image.height;
|
||||
if (this.do_resize) {
|
||||
width = this.size;
|
||||
height = this.size;
|
||||
}
|
||||
|
||||
let data = await this.resize(image, width, height)
|
||||
|
||||
// Do not include alpha channel
|
||||
let convData = new Float32Array(data.length * 3 / 4);
|
||||
|
||||
let outIndex = 0;
|
||||
for (let i = 0; i < data.length; i += 4) {
|
||||
for (let j = 0; j < 3; ++j) {
|
||||
convData[outIndex++] = data[i + j];
|
||||
}
|
||||
}
|
||||
|
||||
if (this.do_rescale) {
|
||||
for (let i = 0; i < convData.length; ++i) {
|
||||
convData[i] = convData[i] / 255;
|
||||
}
|
||||
}
|
||||
|
||||
if (this.do_normalize) {
|
||||
for (let i = 0; i < convData.length; i += 3) {
|
||||
for (let j = 0; j < 3; ++j) {
|
||||
convData[i + j] = (convData[i + j] - this.image_mean[j]) / this.image_std[j]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let img = new Tensor('float32', convData, [width, height, 3]);
|
||||
let transposed = transpose(img, [2, 0, 1]);
|
||||
|
||||
return transposed;
|
||||
}
|
||||
|
||||
async _call(images) {
|
||||
if (!Array.isArray(images)) {
|
||||
images = [images];
|
||||
}
|
||||
|
||||
// Convert any non-images to images
|
||||
images = await Promise.all(images.map(x => this.preprocess(x)));
|
||||
|
||||
images.forEach(x => x.dims = [1, ...x.dims]) // add batch dimension
|
||||
return {
|
||||
pixel_values: images
|
||||
};
|
||||
}
|
||||
|
||||
}
|
||||
class WhisperFeatureExtractor extends FeatureExtractor {
|
||||
|
||||
calcOffset(i, w) {
|
||||
return Math.abs((i + w) % (2 * w) - w);
|
||||
|
@ -315,7 +421,7 @@ class WhisperFeatureExtractor extends Callable {
|
|||
};
|
||||
}
|
||||
|
||||
_call(audio) {
|
||||
async _call(audio) {
|
||||
|
||||
if (audio.length > this.config.n_samples) {
|
||||
// TODO: https://github.com/openai/whisper/discussions/726
|
||||
|
@ -341,15 +447,20 @@ class Processor extends Callable {
|
|||
this.feature_extractor = feature_extractor;
|
||||
// TODO use tokenizer here?
|
||||
}
|
||||
async _call(input) {
|
||||
return await this.feature_extractor(input);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
function isString(text) {
|
||||
return typeof text === 'string' || text instanceof String
|
||||
}
|
||||
class WhisperProcessor extends Processor {
|
||||
|
||||
async _call(audio) {
|
||||
|
||||
if (typeof audio === 'string' || audio instanceof String) {
|
||||
if (isString(audio)) {
|
||||
// Attempting to load from path
|
||||
|
||||
if (typeof AudioContext === 'undefined') {
|
||||
|
@ -368,7 +479,7 @@ class WhisperProcessor extends Processor {
|
|||
}
|
||||
|
||||
// TODO use sampling rate?
|
||||
return this.feature_extractor(audio)
|
||||
return await this.feature_extractor(audio)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue