Compare commits

...

5 Commits

Author SHA1 Message Date
Joshua Lochner 3fb575d64f Update tensor import type 2023-12-12 16:38:03 +02:00
Joshua Lochner 3af5b1c8bc Apply suggestions 2023-12-12 16:13:43 +02:00
Joshua Lochner ba76b7931a Typing improvements 2023-12-11 22:47:03 +02:00
Joshua Lochner 70d3f7c34d Fix typing issues 2023-12-11 21:20:42 +02:00
Joshua Lochner b768cb8588 Do not extend from ONNX tensor (fix #437) 2023-12-11 20:37:22 +02:00
5 changed files with 103 additions and 52 deletions

View File

@@ -206,6 +206,7 @@ function validateInputs(session, inputs) {
async function sessionRun(session, inputs) {
const checkedInputs = validateInputs(session, inputs);
try {
// @ts-ignore
let output = await session.run(checkedInputs);
output = replaceTensors(output);
return output;
@@ -292,6 +293,7 @@ function prepareAttentionMask(self, tokens) {
if (is_pad_token_in_inputs && is_pad_token_not_equal_to_eos_token_id) {
let data = BigInt64Array.from(
// Note: != so that int matches bigint
// @ts-ignore
tokens.data.map(x => x != pad_token_id)
)
return new Tensor('int64', data, tokens.dims)
@@ -704,9 +706,10 @@ export class PreTrainedModel extends Callable {
* @todo Use https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/FinalizationRegistry
*/
async dispose() {
let promises = [];
const promises = [];
for (let key of Object.keys(this)) {
let item = this[key];
const item = this[key];
// @ts-ignore
if (item instanceof InferenceSession) {
promises.push(item.handler.dispose())
}

View File

@@ -261,6 +261,8 @@ export class WhisperTimeStampLogitsProcessor extends LogitsProcessor {
return logits;
}
const logitsData = /** @type {Float32Array} */(logits.data);
// timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly
const seq = input_ids.slice(this.begin_index);
const last_was_timestamp = seq.length >= 1 && seq[seq.length - 1] >= this.timestamp_begin;
@@ -268,25 +270,25 @@ export class WhisperTimeStampLogitsProcessor extends LogitsProcessor {
if (last_was_timestamp) {
if (penultimate_was_timestamp) { // has to be non-timestamp
logits.data.subarray(this.timestamp_begin).fill(-Infinity);
logitsData.subarray(this.timestamp_begin).fill(-Infinity);
} else { // cannot be normal text tokens
logits.data.subarray(0, this.eos_token_id).fill(-Infinity);
logitsData.subarray(0, this.eos_token_id).fill(-Infinity);
}
}
// apply the `max_initial_timestamp` option
if (input_ids.length === this.begin_index && this.max_initial_timestamp_index !== null) {
const last_allowed = this.timestamp_begin + this.max_initial_timestamp_index;
logits.data.subarray(last_allowed + 1).fill(-Infinity);
logitsData.subarray(last_allowed + 1).fill(-Infinity);
}
// if sum of probability over timestamps is above any other token, sample timestamp
const logprobs = log_softmax(logits.data);
const logprobs = log_softmax(logitsData);
const timestamp_logprob = Math.log(logprobs.subarray(this.timestamp_begin).map(Math.exp).reduce((a, b) => a + b));
const max_text_token_logprob = max(logprobs.subarray(0, this.timestamp_begin))[0];
if (timestamp_logprob > max_text_token_logprob) {
logits.data.subarray(0, this.timestamp_begin).fill(-Infinity);
logitsData.subarray(0, this.timestamp_begin).fill(-Infinity);
}
return logits;
@@ -697,12 +699,12 @@ export class Sampler extends Callable {
* Returns the specified logits as an array, with temperature applied.
* @param {Tensor} logits
* @param {number} index
* @returns {Array}
* @returns {Float32Array}
*/
getLogits(logits, index) {
let vocabSize = logits.dims.at(-1);
let logs = logits.data;
let logs = /** @type {Float32Array} */(logits.data);
if (index === -1) {
logs = logs.slice(-vocabSize);

View File

@@ -79,7 +79,7 @@ export class RawImage {
/**
* Create a new `RawImage` object.
* @param {Uint8ClampedArray} data The pixel data.
* @param {Uint8ClampedArray|Uint8Array} data The pixel data.
* @param {number} width The width of the image.
* @param {number} height The height of the image.
* @param {1|2|3|4} channels The number of channels.
@@ -173,7 +173,18 @@ export class RawImage {
} else {
throw new Error(`Unsupported channel format: ${channel_format}`);
}
return new RawImage(tensor.data, tensor.dims[1], tensor.dims[0], tensor.dims[2]);
if (!(tensor.data instanceof Uint8ClampedArray || tensor.data instanceof Uint8Array)) {
throw new Error(`Unsupported tensor type: ${tensor.type}`);
}
switch (tensor.dims[2]) {
case 1:
case 2:
case 3:
case 4:
return new RawImage(tensor.data, tensor.dims[1], tensor.dims[0], tensor.dims[2]);
default:
throw new Error(`Unsupported number of channels: ${tensor.dims[2]}`);
}
}
/**

View File

@@ -130,9 +130,9 @@ export function transpose_data(array, dims, axes) {
/**
* Compute the softmax of an array of numbers.
*
* @param {number[]} arr The array of numbers to compute the softmax of.
* @returns {number[]} The softmax array.
* @template {TypedArray|number[]} T
* @param {T} arr The array of numbers to compute the softmax of.
* @returns {T} The softmax array.
*/
export function softmax(arr) {
// Compute the maximum value in the array
@@ -142,18 +142,20 @@ export function softmax(arr) {
const exps = arr.map(x => Math.exp(x - maxVal));
// Compute the sum of the exponentials
// @ts-ignore
const sumExps = exps.reduce((acc, val) => acc + val, 0);
// Compute the softmax values
const softmaxArr = exps.map(x => x / sumExps);
return softmaxArr;
return /** @type {T} */(softmaxArr);
}
/**
* Calculates the logarithm of the softmax function for the input array.
* @param {number[]} arr The input array to calculate the log_softmax function for.
* @returns {any} The resulting log_softmax array.
* @template {TypedArray|number[]} T
* @param {T} arr The input array to calculate the log_softmax function for.
* @returns {T} The resulting log_softmax array.
*/
export function log_softmax(arr) {
// Compute the softmax values
@@ -162,7 +164,7 @@ export function log_softmax(arr) {
// Apply log formula to each element
const logSoftmaxArr = softmaxArr.map(x => Math.log(x));
return logSoftmaxArr;
return /** @type {T} */(logSoftmaxArr);
}
/**
@@ -178,8 +180,7 @@ export function dot(arr1, arr2) {
/**
* Get the top k items from an iterable, sorted by descending order
*
* @param {Array} items The items to be sorted
* @param {any[]|TypedArray} items The items to be sorted
* @param {number} [top_k=0] The number of top items to return (default: 0 = return all)
* @returns {Array} The top k items, sorted by descending order
*/
@@ -252,8 +253,8 @@ export function min(arr) {
/**
* Returns the value and index of the maximum element in an array.
* @param {number[]|TypedArray} arr array of numbers.
* @returns {number[]} the value and index of the maximum element, of the form: [valueOfMax, indexOfMax]
* @param {number[]|AnyTypedArray} arr array of numbers.
* @returns {[number, number]} the value and index of the maximum element, of the form: [valueOfMax, indexOfMax]
* @throws {Error} If array is empty.
*/
export function max(arr) {
@@ -266,7 +267,7 @@ export function max(arr) {
indexOfMax = i;
}
}
return [max, indexOfMax];
return [Number(max), indexOfMax];
}
function isPowerOfTwo(number) {

View File

@@ -15,40 +15,57 @@ import {
} from './maths.js';
// @ts-ignore
const DataTypeMap = new Map([
['bool', Uint8Array],
['float32', Float32Array],
['float64', Float64Array],
['string', Array], // string[]
['int8', Int8Array],
['uint8', Uint8Array],
['int16', Int16Array],
['uint16', Uint16Array],
['int32', Int32Array],
['uint32', Uint32Array],
['int64', BigInt64Array],
])
const DataTypeMap = Object.freeze({
float32: Float32Array,
float64: Float64Array,
string: Array, // string[]
int8: Int8Array,
uint8: Uint8Array,
int16: Int16Array,
uint16: Uint16Array,
int32: Int32Array,
uint32: Uint32Array,
int64: BigInt64Array,
uint64: BigUint64Array,
bool: Uint8Array,
});
/**
* @typedef {keyof typeof DataTypeMap} DataType
* @typedef {import('./maths.js').AnyTypedArray | any[]} DataArray
*/
const ONNXTensor = ONNX.Tensor;
export class Tensor extends ONNXTensor {
export class Tensor {
/** @type {number[]} Dimensions of the tensor. */
dims;
/** @type {DataType} Type of the tensor. */
type;
/** @type {DataArray} The data stored in the tensor. */
data;
/** @type {number} The number of elements in the tensor. */
size;
/**
* Create a new Tensor or copy an existing Tensor.
* @param {[string, DataArray, number[]]|[ONNXTensor]} args
* @param {[DataType, DataArray, number[]]|[import('onnxruntime-common').Tensor]} args
*/
constructor(...args) {
if (args[0] instanceof ONNX.Tensor) {
if (args[0] instanceof ONNXTensor) {
// Create shallow copy
super(args[0].type, args[0].data, args[0].dims);
Object.assign(this, args[0]);
} else {
// Create new
super(...args);
// Create new tensor
Object.assign(this, new ONNXTensor(
/** @type {DataType} */(args[0]),
/** @type {Exclude<import('./maths.js').AnyTypedArray, Uint8ClampedArray>} */(args[1]),
args[2]
));
}
return new Proxy(this, {
@@ -130,14 +147,21 @@ export class Tensor extends ONNXTensor {
* @returns {Tensor}
*/
_subarray(index, iterSize, iterDims) {
let data = this.data.subarray(index * iterSize, (index + 1) * iterSize);
const o1 = index * iterSize;
const o2 = (index + 1) * iterSize;
// We use subarray if available (typed array), otherwise we use slice (normal array)
const data =
('subarray' in this.data)
? this.data.subarray(o1, o2)
: this.data.slice(o1, o2);
return new Tensor(this.type, data, iterDims);
}
/**
* Returns the value of this tensor as a standard JavaScript Number. This only works
* for tensors with one element. For other cases, see `Tensor.tolist()`.
* @returns {number} The value of this tensor as a standard JavaScript Number.
* @returns {number|bigint} The value of this tensor as a standard JavaScript Number.
* @throws {Error} If the tensor has more than one element.
*/
item() {
@@ -265,6 +289,7 @@ export class Tensor extends ONNXTensor {
let newBufferSize = newDims.reduce((a, b) => a * b);
// Allocate memory
// @ts-ignore
let data = new this.data.constructor(newBufferSize);
// Precompute strides
@@ -338,6 +363,7 @@ export class Tensor extends ONNXTensor {
resultDims[dim] = 1; // Remove the specified axis
// Create a new array to store the accumulated values
// @ts-ignore
const result = new this.data.constructor(this.data.length / this.dims[dim]);
// Iterate over the data array
@@ -579,7 +605,7 @@ export class Tensor extends ONNXTensor {
/**
* Performs Tensor dtype conversion.
* @param {'bool'|'float32'|'float64'|'string'|'int8'|'uint8'|'int16'|'uint16'|'int32'|'uint32'|'int64'} type
* @param {DataType} type The desired data type.
* @returns {Tensor} The converted tensor.
*/
to(type) {
@@ -587,11 +613,11 @@ export class Tensor extends ONNXTensor {
if (this.type === type) return this;
// Otherwise, the returned tensor is a copy of self with the desired dtype.
const ArrayConstructor = DataTypeMap.get(type);
if (!ArrayConstructor) {
if (!DataTypeMap.hasOwnProperty(type)) {
throw new Error(`Unsupported type: ${type}`);
}
return new Tensor(type, ArrayConstructor.from(this.data), this.dims);
// @ts-ignore
return new Tensor(type, DataTypeMap[type].from(this.data), this.dims);
}
}
@@ -618,10 +644,10 @@ export class Tensor extends ONNXTensor {
* reshape([1, 2, 3, 4 ], [2, 2 ]); // Type: number[][] Value: [[1, 2], [3, 4]]
* reshape([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2]); // Type: number[][][] Value: [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
* reshape([1, 2, 3, 4, 5, 6, 7, 8], [4, 2 ]); // Type: number[][] Value: [[1, 2], [3, 4], [5, 6], [7, 8]]
* @param {T[]} data The input array to reshape.
* @param {T[]|DataArray} data The input array to reshape.
* @param {DIM} dimensions The target shape/dimensions.
* @template T
* @template {[number]|[number, number]|[number, number, number]|[number, number, number, number]} DIM
* @template {[number]|number[]} DIM
* @returns {NestArray<T, DIM["length"]>} The reshaped array.
*/
function reshape(data, dimensions) {
@@ -681,7 +707,7 @@ export function interpolate(input, [out_height, out_width], mode = 'bilinear', a
const in_width = input.dims.at(-1);
let output = interpolate_data(
input.data,
/** @type {import('./maths.js').TypedArray}*/(input.data),
[in_channels, in_height, in_width],
[out_height, out_width],
mode,
@@ -701,6 +727,7 @@ export function mean_pooling(last_hidden_state, attention_mask) {
// attention_mask: [batchSize, seqLength]
let shape = [last_hidden_state.dims[0], last_hidden_state.dims[2]];
// @ts-ignore
let returnedData = new last_hidden_state.data.constructor(shape[0] * shape[1]);
let [batchSize, seqLength, embedDim] = last_hidden_state.dims;
@@ -813,6 +840,7 @@ export function cat(tensors, dim = 0) {
// Create a new array to store the accumulated values
const resultSize = resultDims.reduce((a, b) => a * b, 1);
// @ts-ignore
const result = new tensors[0].data.constructor(resultSize);
// Create output tensor of same type as first
@@ -884,8 +912,10 @@ export function std_mean(input, dim = null, correction = 1, keepdim = false) {
if (dim === null) {
// None to reduce over all dimensions.
// @ts-ignore
const sum = input.data.reduce((a, b) => a + b, 0);
const mean = sum / input.data.length;
// @ts-ignore
const std = Math.sqrt(input.data.reduce((a, b) => a + (b - mean) ** 2, 0) / (input.data.length - correction));
const meanTensor = new Tensor(input.type, [mean], [/* scalar */]);
@@ -904,6 +934,7 @@ export function std_mean(input, dim = null, correction = 1, keepdim = false) {
resultDims[dim] = 1; // Remove the specified axis
// Create a new array to store the accumulated values
// @ts-ignore
const result = new input.data.constructor(input.data.length / input.dims[dim]);
// Iterate over the data array
@@ -951,6 +982,7 @@ export function mean(input, dim = null, keepdim = false) {
if (dim === null) {
// None to reduce over all dimensions.
// @ts-ignore
let val = input.data.reduce((a, b) => a + b, 0);
return new Tensor(input.type, [val / input.data.length], [/* scalar */]);
}
@@ -963,6 +995,7 @@ export function mean(input, dim = null, keepdim = false) {
resultDims[dim] = 1; // Remove the specified axis
// Create a new array to store the accumulated values
// @ts-ignore
const result = new input.data.constructor(input.data.length / input.dims[dim]);
// Iterate over the data array
@@ -1054,6 +1087,7 @@ export function dynamicTimeWarping(matrix) {
let i = output_length;
let j = input_length;
// @ts-ignore
trace.data.fill(2, 0, outputShape[1]) // trace[0, :] = 2
for (let i = 0; i < outputShape[0]; ++i) { // trace[:, 0] = 1
trace[i].data[0] = 1;