Fix channel format when padding non-square images for certain models (#655)
* Add tensor permute unit tests
* Rename transpose -> permute
* Fix padding for non-square images
* Add vitmatte padding unit test
* Create `RawImage.toTensor` helper function
* Add bankers rounding test case
* `.toBe()` -> `.toBeCloseTo()` for floating point numbers
This commit is contained in:
parent
8c731fa54b
commit
40cdd36597
|
@ -3779,11 +3779,7 @@ export class VitMattePreTrainedModel extends PreTrainedModel { }
|
||||||
* import { Tensor, cat } from '@xenova/transformers';
|
* import { Tensor, cat } from '@xenova/transformers';
|
||||||
*
|
*
|
||||||
* // Visualize predicted alpha matte
|
* // Visualize predicted alpha matte
|
||||||
* const imageTensor = new Tensor(
|
* const imageTensor = image.toTensor();
|
||||||
* 'uint8',
|
|
||||||
* new Uint8Array(image.data),
|
|
||||||
* [image.height, image.width, image.channels]
|
|
||||||
* ).transpose(2, 0, 1);
|
|
||||||
*
|
*
|
||||||
* // Convert float (0-1) alpha matte to uint8 (0-255)
|
* // Convert float (0-1) alpha matte to uint8 (0-255)
|
||||||
* const alphaChannel = alphas
|
* const alphaChannel = alphas
|
||||||
|
|
|
@ -33,10 +33,11 @@ import {
|
||||||
min,
|
min,
|
||||||
max,
|
max,
|
||||||
softmax,
|
softmax,
|
||||||
|
bankers_round,
|
||||||
} from './utils/maths.js';
|
} from './utils/maths.js';
|
||||||
|
|
||||||
|
|
||||||
import { Tensor, transpose, cat, interpolate, stack } from './utils/tensor.js';
|
import { Tensor, permute, cat, interpolate, stack } from './utils/tensor.js';
|
||||||
|
|
||||||
import { RawImage } from './utils/image.js';
|
import { RawImage } from './utils/image.js';
|
||||||
import {
|
import {
|
||||||
|
@ -174,14 +175,15 @@ function validate_audio_inputs(audio, feature_extractor) {
|
||||||
* @private
|
* @private
|
||||||
*/
|
*/
|
||||||
function constraint_to_multiple_of(val, multiple, minVal = 0, maxVal = null) {
|
function constraint_to_multiple_of(val, multiple, minVal = 0, maxVal = null) {
|
||||||
let x = Math.round(val / multiple) * multiple;
|
const a = val / multiple;
|
||||||
|
let x = bankers_round(a) * multiple;
|
||||||
|
|
||||||
if (maxVal !== null && x > maxVal) {
|
if (maxVal !== null && x > maxVal) {
|
||||||
x = Math.floor(val / multiple) * multiple;
|
x = Math.floor(a) * multiple;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (x < minVal) {
|
if (x < minVal) {
|
||||||
x = Math.ceil(val / multiple) * multiple;
|
x = Math.ceil(a) * multiple;
|
||||||
}
|
}
|
||||||
|
|
||||||
return x;
|
return x;
|
||||||
|
@ -195,8 +197,8 @@ function constraint_to_multiple_of(val, multiple, minVal = 0, maxVal = null) {
|
||||||
*/
|
*/
|
||||||
function enforce_size_divisibility([width, height], divisor) {
|
function enforce_size_divisibility([width, height], divisor) {
|
||||||
return [
|
return [
|
||||||
Math.floor(width / divisor) * divisor,
|
Math.max(Math.floor(width / divisor), 1) * divisor,
|
||||||
Math.floor(height / divisor) * divisor
|
Math.max(Math.floor(height / divisor), 1) * divisor
|
||||||
];
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -348,7 +350,7 @@ export class ImageFeatureExtractor extends FeatureExtractor {
|
||||||
/**
|
/**
|
||||||
* Pad the image by a certain amount.
|
* Pad the image by a certain amount.
|
||||||
* @param {Float32Array} pixelData The pixel data to pad.
|
* @param {Float32Array} pixelData The pixel data to pad.
|
||||||
* @param {number[]} imgDims The dimensions of the image.
|
* @param {number[]} imgDims The dimensions of the image (height, width, channels).
|
||||||
* @param {{width:number; height:number}|number} padSize The dimensions of the padded image.
|
* @param {{width:number; height:number}|number} padSize The dimensions of the padded image.
|
||||||
* @param {Object} options The options for padding.
|
* @param {Object} options The options for padding.
|
||||||
* @param {'constant'|'symmetric'} [options.mode='constant'] The type of padding to add.
|
* @param {'constant'|'symmetric'} [options.mode='constant'] The type of padding to add.
|
||||||
|
@ -361,7 +363,7 @@ export class ImageFeatureExtractor extends FeatureExtractor {
|
||||||
center = false,
|
center = false,
|
||||||
constant_values = 0,
|
constant_values = 0,
|
||||||
} = {}) {
|
} = {}) {
|
||||||
const [imageWidth, imageHeight, imageChannels] = imgDims;
|
const [imageHeight, imageWidth, imageChannels] = imgDims;
|
||||||
|
|
||||||
let paddedImageWidth, paddedImageHeight;
|
let paddedImageWidth, paddedImageHeight;
|
||||||
if (typeof padSize === 'number') {
|
if (typeof padSize === 'number') {
|
||||||
|
@ -513,8 +515,8 @@ export class ImageFeatureExtractor extends FeatureExtractor {
|
||||||
if (this.config.keep_aspect_ratio && this.config.ensure_multiple_of) {
|
if (this.config.keep_aspect_ratio && this.config.ensure_multiple_of) {
|
||||||
|
|
||||||
// determine new height and width
|
// determine new height and width
|
||||||
let scale_height = size.height / srcHeight;
|
let scale_height = newHeight / srcHeight;
|
||||||
let scale_width = size.width / srcWidth;
|
let scale_width = newWidth / srcWidth;
|
||||||
|
|
||||||
// scale as little as possible
|
// scale as little as possible
|
||||||
if (Math.abs(1 - scale_width) < Math.abs(1 - scale_height)) {
|
if (Math.abs(1 - scale_width) < Math.abs(1 - scale_height)) {
|
||||||
|
@ -616,6 +618,9 @@ export class ImageFeatureExtractor extends FeatureExtractor {
|
||||||
/** @type {HeightWidth} */
|
/** @type {HeightWidth} */
|
||||||
const reshaped_input_size = [image.height, image.width];
|
const reshaped_input_size = [image.height, image.width];
|
||||||
|
|
||||||
|
// NOTE: All pixel-level manipulation (i.e., modifying `pixelData`)
|
||||||
|
// occurs with data in the hwc format (height, width, channels),
|
||||||
|
// to emulate the behavior of the original Python code (w/ numpy).
|
||||||
let pixelData = Float32Array.from(image.data);
|
let pixelData = Float32Array.from(image.data);
|
||||||
let imgDims = [image.height, image.width, image.channels];
|
let imgDims = [image.height, image.width, image.channels];
|
||||||
|
|
||||||
|
@ -646,21 +651,23 @@ export class ImageFeatureExtractor extends FeatureExtractor {
|
||||||
}
|
}
|
||||||
|
|
||||||
// do padding after rescaling/normalizing
|
// do padding after rescaling/normalizing
|
||||||
if (do_pad ?? (this.do_pad && this.pad_size)) {
|
if (do_pad ?? this.do_pad) {
|
||||||
const padded = this.pad_image(pixelData, [image.width, image.height, image.channels], this.pad_size);
|
if (this.pad_size) {
|
||||||
|
const padded = this.pad_image(pixelData, [image.height, image.width, image.channels], this.pad_size);
|
||||||
[pixelData, imgDims] = padded; // Update pixel data and image dimensions
|
[pixelData, imgDims] = padded; // Update pixel data and image dimensions
|
||||||
|
} else if (this.size_divisibility) {
|
||||||
|
const [paddedWidth, paddedHeight] = enforce_size_divisibility([imgDims[1], imgDims[0]], this.size_divisibility);
|
||||||
|
[pixelData, imgDims] = this.pad_image(pixelData, imgDims, { width: paddedWidth, height: paddedHeight });
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create HWC tensor
|
const pixel_values = new Tensor('float32', pixelData, imgDims)
|
||||||
const img = new Tensor('float32', pixelData, imgDims);
|
.permute(2, 0, 1); // convert to channel dimension format (hwc -> chw)
|
||||||
|
|
||||||
// convert to channel dimension format:
|
|
||||||
const transposed = transpose(img, [2, 0, 1]); // hwc -> chw
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
original_size: [srcHeight, srcWidth],
|
original_size: [srcHeight, srcWidth],
|
||||||
reshaped_input_size: reshaped_input_size,
|
reshaped_input_size: reshaped_input_size,
|
||||||
pixel_values: transposed,
|
pixel_values: pixel_values,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -760,9 +767,9 @@ export class SegformerFeatureExtractor extends ImageFeatureExtractor {
|
||||||
return toReturn;
|
return toReturn;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
export class DPTImageProcessor extends ImageFeatureExtractor { }
|
|
||||||
export class BitImageProcessor extends ImageFeatureExtractor { }
|
|
||||||
export class DPTFeatureExtractor extends ImageFeatureExtractor { }
|
export class DPTFeatureExtractor extends ImageFeatureExtractor { }
|
||||||
|
export class DPTImageProcessor extends DPTFeatureExtractor { } // NOTE: extends DPTFeatureExtractor
|
||||||
|
export class BitImageProcessor extends ImageFeatureExtractor { }
|
||||||
export class GLPNFeatureExtractor extends ImageFeatureExtractor { }
|
export class GLPNFeatureExtractor extends ImageFeatureExtractor { }
|
||||||
export class CLIPFeatureExtractor extends ImageFeatureExtractor { }
|
export class CLIPFeatureExtractor extends ImageFeatureExtractor { }
|
||||||
export class ChineseCLIPFeatureExtractor extends ImageFeatureExtractor { }
|
export class ChineseCLIPFeatureExtractor extends ImageFeatureExtractor { }
|
||||||
|
@ -835,7 +842,7 @@ export class DeiTFeatureExtractor extends ImageFeatureExtractor { }
|
||||||
export class BeitFeatureExtractor extends ImageFeatureExtractor { }
|
export class BeitFeatureExtractor extends ImageFeatureExtractor { }
|
||||||
export class DonutFeatureExtractor extends ImageFeatureExtractor {
|
export class DonutFeatureExtractor extends ImageFeatureExtractor {
|
||||||
pad_image(pixelData, imgDims, padSize, options = {}) {
|
pad_image(pixelData, imgDims, padSize, options = {}) {
|
||||||
const [imageWidth, imageHeight, imageChannels] = imgDims;
|
const [imageHeight, imageWidth, imageChannels] = imgDims;
|
||||||
|
|
||||||
let image_mean = this.image_mean;
|
let image_mean = this.image_mean;
|
||||||
if (!Array.isArray(this.image_mean)) {
|
if (!Array.isArray(this.image_mean)) {
|
||||||
|
@ -1382,7 +1389,7 @@ export class Swin2SRImageProcessor extends ImageFeatureExtractor {
|
||||||
pad_image(pixelData, imgDims, padSize, options = {}) {
|
pad_image(pixelData, imgDims, padSize, options = {}) {
|
||||||
// NOTE: In this case, `padSize` represents the size of the sliding window for the local attention.
|
// NOTE: In this case, `padSize` represents the size of the sliding window for the local attention.
|
||||||
// In other words, the image is padded so that its width and height are multiples of `padSize`.
|
// In other words, the image is padded so that its width and height are multiples of `padSize`.
|
||||||
const [imageWidth, imageHeight, imageChannels] = imgDims;
|
const [imageHeight, imageWidth, imageChannels] = imgDims;
|
||||||
|
|
||||||
return super.pad_image(pixelData, imgDims, {
|
return super.pad_image(pixelData, imgDims, {
|
||||||
// NOTE: For Swin2SR models, the original python implementation adds padding even when the image's width/height is already
|
// NOTE: For Swin2SR models, the original python implementation adds padding even when the image's width/height is already
|
||||||
|
|
|
@ -10,6 +10,7 @@
|
||||||
|
|
||||||
import { getFile } from './hub.js';
|
import { getFile } from './hub.js';
|
||||||
import { env } from '../env.js';
|
import { env } from '../env.js';
|
||||||
|
import { Tensor } from './tensor.js';
|
||||||
|
|
||||||
// Will be empty (or not used) if running in browser or web-worker
|
// Will be empty (or not used) if running in browser or web-worker
|
||||||
import sharp from 'sharp';
|
import sharp from 'sharp';
|
||||||
|
@ -166,7 +167,7 @@ export class RawImage {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Helper method to create a new Image from a tensor
|
* Helper method to create a new Image from a tensor
|
||||||
* @param {import('./tensor.js').Tensor} tensor
|
* @param {Tensor} tensor
|
||||||
*/
|
*/
|
||||||
static fromTensor(tensor, channel_format = 'CHW') {
|
static fromTensor(tensor, channel_format = 'CHW') {
|
||||||
if (tensor.dims.length !== 3) {
|
if (tensor.dims.length !== 3) {
|
||||||
|
@ -586,6 +587,23 @@ export class RawImage {
|
||||||
return await canvas.convertToBlob({ type, quality });
|
return await canvas.convertToBlob({ type, quality });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
toTensor(channel_format = 'CHW') {
|
||||||
|
let tensor = new Tensor(
|
||||||
|
'uint8',
|
||||||
|
new Uint8Array(this.data),
|
||||||
|
[this.height, this.width, this.channels]
|
||||||
|
);
|
||||||
|
|
||||||
|
if (channel_format === 'HWC') {
|
||||||
|
// Do nothing
|
||||||
|
} else if (channel_format === 'CHW') { // hwc -> chw
|
||||||
|
tensor = tensor.permute(2, 0, 1);
|
||||||
|
} else {
|
||||||
|
throw new Error(`Unsupported channel format: ${channel_format}`);
|
||||||
|
}
|
||||||
|
return tensor;
|
||||||
|
}
|
||||||
|
|
||||||
toCanvas() {
|
toCanvas() {
|
||||||
if (!BROWSER_ENV) {
|
if (!BROWSER_ENV) {
|
||||||
throw new Error('toCanvas() is only supported in browser environments.')
|
throw new Error('toCanvas() is only supported in browser environments.')
|
||||||
|
|
|
@ -88,15 +88,15 @@ export function interpolate_data(input, [in_channels, in_height, in_width], [out
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Helper method to transpose a `AnyTypedArray` directly
|
* Helper method to permute a `AnyTypedArray` directly
|
||||||
* @template {AnyTypedArray} T
|
* @template {AnyTypedArray} T
|
||||||
* @param {T} array
|
* @param {T} array
|
||||||
* @param {number[]} dims
|
* @param {number[]} dims
|
||||||
* @param {number[]} axes
|
* @param {number[]} axes
|
||||||
* @returns {[T, number[]]} The transposed array and the new shape.
|
* @returns {[T, number[]]} The permuted array and the new shape.
|
||||||
*/
|
*/
|
||||||
export function transpose_data(array, dims, axes) {
|
export function permute_data(array, dims, axes) {
|
||||||
// Calculate the new shape of the transposed array
|
// Calculate the new shape of the permuted array
|
||||||
// and the stride of the original array
|
// and the stride of the original array
|
||||||
const shape = new Array(axes.length);
|
const shape = new Array(axes.length);
|
||||||
const stride = new Array(axes.length);
|
const stride = new Array(axes.length);
|
||||||
|
@ -110,21 +110,21 @@ export function transpose_data(array, dims, axes) {
|
||||||
// Precompute inverse mapping of stride
|
// Precompute inverse mapping of stride
|
||||||
const invStride = axes.map((_, i) => stride[axes.indexOf(i)]);
|
const invStride = axes.map((_, i) => stride[axes.indexOf(i)]);
|
||||||
|
|
||||||
// Create the transposed array with the new shape
|
// Create the permuted array with the new shape
|
||||||
// @ts-ignore
|
// @ts-ignore
|
||||||
const transposedData = new array.constructor(array.length);
|
const permutedData = new array.constructor(array.length);
|
||||||
|
|
||||||
// Transpose the original array to the new array
|
// Permute the original array to the new array
|
||||||
for (let i = 0; i < array.length; ++i) {
|
for (let i = 0; i < array.length; ++i) {
|
||||||
let newIndex = 0;
|
let newIndex = 0;
|
||||||
for (let j = dims.length - 1, k = i; j >= 0; --j) {
|
for (let j = dims.length - 1, k = i; j >= 0; --j) {
|
||||||
newIndex += (k % dims[j]) * invStride[j];
|
newIndex += (k % dims[j]) * invStride[j];
|
||||||
k = Math.floor(k / dims[j]);
|
k = Math.floor(k / dims[j]);
|
||||||
}
|
}
|
||||||
transposedData[newIndex] = array[i];
|
permutedData[newIndex] = array[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
return [transposedData, shape];
|
return [permutedData, shape];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -952,3 +952,17 @@ export function round(num, decimals) {
|
||||||
const pow = Math.pow(10, decimals);
|
const pow = Math.pow(10, decimals);
|
||||||
return Math.round(num * pow) / pow;
|
return Math.round(num * pow) / pow;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Helper function to round a number to the nearest integer, with ties rounded to the nearest even number.
 * Also known as "bankers' rounding". This is the default rounding mode in python. For example:
 * 1.5 rounds to 2 and 2.5 rounds to 2.
 *
 * @param {number} x The number to round
 * @returns {number} The rounded number
 */
export function bankers_round(x) {
    // `Math.round` resolves ties toward +Infinity (e.g., 2.5 -> 3 and -1.5 -> -1),
    // so a tie always lands on the upper of the two neighbouring integers.
    const r = Math.round(x);

    // Only exact ties (fractional part of precisely 0.5) need correcting: if the
    // upper neighbour is odd, the even neighbour is the one directly below it.
    const br = Math.abs(x) % 1 === 0.5 ? (r % 2 === 0 ? r : r - 1) : r;
    return br;
}
|
||||||
|
|
|
@ -11,7 +11,7 @@ import { ONNX } from '../backends/onnx.js';
|
||||||
|
|
||||||
import {
|
import {
|
||||||
interpolate_data,
|
interpolate_data,
|
||||||
transpose_data
|
permute_data
|
||||||
} from './maths.js';
|
} from './maths.js';
|
||||||
|
|
||||||
|
|
||||||
|
@ -309,16 +309,18 @@ export class Tensor {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Return a transposed version of this Tensor, according to the provided dimensions.
|
* Return a permuted version of this Tensor, according to the provided dimensions.
|
||||||
* @param {...number} dims Dimensions to transpose.
|
* @param {...number} dims Dimensions to permute.
|
||||||
* @returns {Tensor} The transposed tensor.
|
* @returns {Tensor} The permuted tensor.
|
||||||
*/
|
*/
|
||||||
transpose(...dims) {
|
permute(...dims) {
|
||||||
return transpose(this, dims);
|
return permute(this, dims);
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: rename transpose to permute
|
// TODO: implement transpose. For now (backwards compatibility), it's just an alias for permute()
|
||||||
// TODO: implement transpose
|
transpose(...dims) {
|
||||||
|
return this.permute(...dims);
|
||||||
|
}
|
||||||
|
|
||||||
// TODO add .max() and .min() methods
|
// TODO add .max() and .min() methods
|
||||||
|
|
||||||
|
@ -680,14 +682,14 @@ function reshape(data, dimensions) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Transposes a tensor according to the provided axes.
|
* Permutes a tensor according to the provided axes.
|
||||||
* @param {any} tensor The input tensor to transpose.
|
* @param {any} tensor The input tensor to permute.
|
||||||
* @param {Array} axes The axes to transpose the tensor along.
|
* @param {Array} axes The axes to permute the tensor along.
|
||||||
* @returns {Tensor} The transposed tensor.
|
* @returns {Tensor} The permuted tensor.
|
||||||
*/
|
*/
|
||||||
export function transpose(tensor, axes) {
|
export function permute(tensor, axes) {
|
||||||
const [transposedData, shape] = transpose_data(tensor.data, tensor.dims, axes);
|
const [permutedData, shape] = permute_data(tensor.data, tensor.dims, axes);
|
||||||
return new Tensor(tensor.type, transposedData, shape);
|
return new Tensor(tensor.type, permutedData, shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
import { compare } from './test_utils.js';
|
import { compare } from './test_utils.js';
|
||||||
|
|
||||||
import { getFile } from '../src/utils/hub.js';
|
import { getFile } from '../src/utils/hub.js';
|
||||||
import { FFT, medianFilter } from '../src/utils/maths.js';
|
import { FFT, medianFilter, bankers_round } from '../src/utils/maths.js';
|
||||||
|
|
||||||
|
|
||||||
const fft = (arr, complex = false) => {
|
const fft = (arr, complex = false) => {
|
||||||
|
@ -27,6 +27,19 @@ const fftTestsData = await (await getFile('./tests/data/fft_tests.json')).json()
|
||||||
|
|
||||||
describe('Mathematical operations', () => {
|
describe('Mathematical operations', () => {
|
||||||
|
|
||||||
|
describe('bankers rounding', () => {
|
||||||
|
it('should round up to nearest even', () => {
|
||||||
|
expect(bankers_round(-0.5)).toBeCloseTo(0);
|
||||||
|
expect(bankers_round(1.5)).toBeCloseTo(2);
|
||||||
|
expect(bankers_round(19.5)).toBeCloseTo(20);
|
||||||
|
});
|
||||||
|
it('should round down to nearest even', () => {
|
||||||
|
expect(bankers_round(-1.5)).toBeCloseTo(-2);
|
||||||
|
expect(bankers_round(2.5)).toBeCloseTo(2);
|
||||||
|
expect(bankers_round(18.5)).toBeCloseTo(18);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
describe('median filtering', () => {
|
describe('median filtering', () => {
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -50,7 +50,9 @@ describe('Processors', () => {
|
||||||
|
|
||||||
const TEST_IMAGES = {
|
const TEST_IMAGES = {
|
||||||
pattern_3x3: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/pattern_3x3.png',
|
pattern_3x3: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/pattern_3x3.png',
|
||||||
|
pattern_3x5: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/pattern_3x5.png',
|
||||||
checkerboard_8x8: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/checkerboard_8x8.png',
|
checkerboard_8x8: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/checkerboard_8x8.png',
|
||||||
|
checkerboard_64x32: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/checkerboard_64x32.png',
|
||||||
receipt: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/receipt.png',
|
receipt: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/receipt.png',
|
||||||
tiger: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/tiger.jpg',
|
tiger: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/tiger.jpg',
|
||||||
paper: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/nougat_paper.png',
|
paper: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/nougat_paper.png',
|
||||||
|
@ -369,6 +371,7 @@ describe('Processors', () => {
|
||||||
// - tests custom overrides
|
// - tests custom overrides
|
||||||
// - tests multiple inputs
|
// - tests multiple inputs
|
||||||
// - tests `size_divisibility` and no size (size_divisibility=32)
|
// - tests `size_divisibility` and no size (size_divisibility=32)
|
||||||
|
// - tests do_pad and `size_divisibility`
|
||||||
it(MODELS.vitmatte, async () => {
|
it(MODELS.vitmatte, async () => {
|
||||||
const processor = await AutoProcessor.from_pretrained(m(MODELS.vitmatte))
|
const processor = await AutoProcessor.from_pretrained(m(MODELS.vitmatte))
|
||||||
|
|
||||||
|
@ -391,6 +394,25 @@ describe('Processors', () => {
|
||||||
compare(original_sizes, [[640, 960]]);
|
compare(original_sizes, [[640, 960]]);
|
||||||
compare(reshaped_input_sizes, [[640, 960]]);
|
compare(reshaped_input_sizes, [[640, 960]]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
{
|
||||||
|
const image = await load_image(TEST_IMAGES.pattern_3x5);
|
||||||
|
const image2 = await load_image(TEST_IMAGES.pattern_3x5);
|
||||||
|
const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image, image2);
|
||||||
|
|
||||||
|
compare(pixel_values.dims, [1, 4, 32, 32]);
|
||||||
|
expect(avg(pixel_values.data)).toBeCloseTo(-0.00867417361587286);
|
||||||
|
expect(pixel_values.data[0]).toBeCloseTo(-0.9921568632125854);
|
||||||
|
expect(pixel_values.data[1]).toBeCloseTo(-0.9686274528503418);
|
||||||
|
expect(pixel_values.data[5]).toBeCloseTo(0.0);
|
||||||
|
expect(pixel_values.data[32]).toBeCloseTo(-0.9215686321258545);
|
||||||
|
expect(pixel_values.data[33]).toBeCloseTo(-0.8980392217636108);
|
||||||
|
expect(pixel_values.data.at(-1)).toBeCloseTo(0.0);
|
||||||
|
|
||||||
|
compare(original_sizes, [[5, 3]]);
|
||||||
|
compare(reshaped_input_sizes, [[5, 3]]);
|
||||||
|
}
|
||||||
}, MAX_TEST_EXECUTION_TIME);
|
}, MAX_TEST_EXECUTION_TIME);
|
||||||
|
|
||||||
// BitImageProcessor
|
// BitImageProcessor
|
||||||
|
@ -412,6 +434,7 @@ describe('Processors', () => {
|
||||||
// DPTImageProcessor
|
// DPTImageProcessor
|
||||||
// - tests ensure_multiple_of
|
// - tests ensure_multiple_of
|
||||||
// - tests keep_aspect_ratio
|
// - tests keep_aspect_ratio
|
||||||
|
// - tests bankers rounding
|
||||||
it(MODELS.dpt_2, async () => {
|
it(MODELS.dpt_2, async () => {
|
||||||
const processor = await AutoProcessor.from_pretrained(m(MODELS.dpt_2))
|
const processor = await AutoProcessor.from_pretrained(m(MODELS.dpt_2))
|
||||||
|
|
||||||
|
@ -425,6 +448,18 @@ describe('Processors', () => {
|
||||||
compare(original_sizes, [[480, 640]]);
|
compare(original_sizes, [[480, 640]]);
|
||||||
compare(reshaped_input_sizes, [[518, 686]]);
|
compare(reshaped_input_sizes, [[518, 686]]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
const image = await load_image(TEST_IMAGES.checkerboard_64x32);
|
||||||
|
const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image);
|
||||||
|
|
||||||
|
// NOTE: without bankers rounding, this would be [1, 3, 266, 518]
|
||||||
|
compare(pixel_values.dims, [1, 3, 252, 518]);
|
||||||
|
compare(avg(pixel_values.data), 0.2267402559518814);
|
||||||
|
|
||||||
|
compare(original_sizes, [[32, 64]]);
|
||||||
|
compare(reshaped_input_sizes, [[252, 518]]);
|
||||||
|
}
|
||||||
}, MAX_TEST_EXECUTION_TIME);
|
}, MAX_TEST_EXECUTION_TIME);
|
||||||
|
|
||||||
// EfficientNetImageProcessor
|
// EfficientNetImageProcessor
|
||||||
|
|
|
@ -103,6 +103,65 @@ describe('Tensor operations', () => {
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('permute', () => {
|
||||||
|
it('should permute', async () => {
|
||||||
|
const x = new Tensor(
|
||||||
|
'float32',
|
||||||
|
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23],
|
||||||
|
[2, 3, 4],
|
||||||
|
);
|
||||||
|
// Permute axes to (0, 1, 2) - No change
|
||||||
|
const permuted_1 = x.permute(0, 1, 2);
|
||||||
|
const target_1 = x;
|
||||||
|
compare(permuted_1, target_1, 1e-3);
|
||||||
|
|
||||||
|
// Permute axes to (0, 2, 1)
|
||||||
|
const permuted_2 = x.permute(0, 2, 1);
|
||||||
|
const target_2 = new Tensor(
|
||||||
|
'float32',
|
||||||
|
[0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11, 12, 16, 20, 13, 17, 21, 14, 18, 22, 15, 19, 23],
|
||||||
|
[2, 4, 3],
|
||||||
|
);
|
||||||
|
compare(permuted_2, target_2, 1e-3);
|
||||||
|
|
||||||
|
// Permute axes to (1, 0, 2)
|
||||||
|
const permuted_3 = x.permute(1, 0, 2);
|
||||||
|
const target_3 = new Tensor(
|
||||||
|
'float32',
|
||||||
|
[0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, 16, 17, 18, 19, 8, 9, 10, 11, 20, 21, 22, 23],
|
||||||
|
[3, 2, 4],
|
||||||
|
);
|
||||||
|
compare(permuted_3, target_3, 1e-3);
|
||||||
|
|
||||||
|
// Permute axes to (1, 2, 0)
|
||||||
|
const permuted_4 = x.permute(1, 2, 0);
|
||||||
|
const target_4 = new Tensor(
|
||||||
|
'float32',
|
||||||
|
[0, 12, 1, 13, 2, 14, 3, 15, 4, 16, 5, 17, 6, 18, 7, 19, 8, 20, 9, 21, 10, 22, 11, 23],
|
||||||
|
[3, 4, 2],
|
||||||
|
);
|
||||||
|
compare(permuted_4, target_4, 1e-3);
|
||||||
|
|
||||||
|
// Permute axes to (2, 0, 1)
|
||||||
|
const permuted_5 = x.permute(2, 0, 1);
|
||||||
|
const target_5 = new Tensor(
|
||||||
|
'float32',
|
||||||
|
[0, 4, 8, 12, 16, 20, 1, 5, 9, 13, 17, 21, 2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23],
|
||||||
|
[4, 2, 3],
|
||||||
|
);
|
||||||
|
compare(permuted_5, target_5, 1e-3);
|
||||||
|
|
||||||
|
// Permute axes to (2, 1, 0)
|
||||||
|
const permuted_6 = x.permute(2, 1, 0);
|
||||||
|
const target_6 = new Tensor(
|
||||||
|
'float32',
|
||||||
|
[0, 12, 4, 16, 8, 20, 1, 13, 5, 17, 9, 21, 2, 14, 6, 18, 10, 22, 3, 15, 7, 19, 11, 23],
|
||||||
|
[4, 3, 2],
|
||||||
|
);
|
||||||
|
compare(permuted_6, target_6, 1e-3);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
describe('mean', () => {
|
describe('mean', () => {
|
||||||
it('should calculate mean', async () => {
|
it('should calculate mean', async () => {
|
||||||
const t1 = new Tensor('float32', [1, 2, 3, 4, 5, 6], [2, 3, 1]);
|
const t1 = new Tensor('float32', [1, 2, 3, 4, 5, 6], [2, 3, 1]);
|
||||||
|
|
Loading…
Reference in New Issue