Compare commits
5 Commits
main
...
fix-tensor
Author | SHA1 | Date |
---|---|---|
Joshua Lochner | 3fb575d64f | |
Joshua Lochner | 3af5b1c8bc | |
Joshua Lochner | ba76b7931a | |
Joshua Lochner | 70d3f7c34d | |
Joshua Lochner | b768cb8588 | |
@ -206,6 +206,7 @@ function validateInputs(session, inputs) {
|
|||
async function sessionRun(session, inputs) {
|
||||
const checkedInputs = validateInputs(session, inputs);
|
||||
try {
|
||||
// @ts-ignore
|
||||
let output = await session.run(checkedInputs);
|
||||
output = replaceTensors(output);
|
||||
return output;
|
||||
|
@ -292,6 +293,7 @@ function prepareAttentionMask(self, tokens) {
|
|||
if (is_pad_token_in_inputs && is_pad_token_not_equal_to_eos_token_id) {
|
||||
let data = BigInt64Array.from(
|
||||
// Note: != so that int matches bigint
|
||||
// @ts-ignore
|
||||
tokens.data.map(x => x != pad_token_id)
|
||||
)
|
||||
return new Tensor('int64', data, tokens.dims)
|
||||
|
@ -704,9 +706,10 @@ export class PreTrainedModel extends Callable {
|
|||
* @todo Use https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/FinalizationRegistry
|
||||
*/
|
||||
async dispose() {
|
||||
let promises = [];
|
||||
const promises = [];
|
||||
for (let key of Object.keys(this)) {
|
||||
let item = this[key];
|
||||
const item = this[key];
|
||||
// @ts-ignore
|
||||
if (item instanceof InferenceSession) {
|
||||
promises.push(item.handler.dispose())
|
||||
}
|
||||
|
|
|
@ -261,6 +261,8 @@ export class WhisperTimeStampLogitsProcessor extends LogitsProcessor {
|
|||
return logits;
|
||||
}
|
||||
|
||||
const logitsData = /** @type {Float32Array} */(logits.data);
|
||||
|
||||
// timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly
|
||||
const seq = input_ids.slice(this.begin_index);
|
||||
const last_was_timestamp = seq.length >= 1 && seq[seq.length - 1] >= this.timestamp_begin;
|
||||
|
@ -268,25 +270,25 @@ export class WhisperTimeStampLogitsProcessor extends LogitsProcessor {
|
|||
|
||||
if (last_was_timestamp) {
|
||||
if (penultimate_was_timestamp) { // has to be non-timestamp
|
||||
logits.data.subarray(this.timestamp_begin).fill(-Infinity);
|
||||
logitsData.subarray(this.timestamp_begin).fill(-Infinity);
|
||||
} else { // cannot be normal text tokens
|
||||
logits.data.subarray(0, this.eos_token_id).fill(-Infinity);
|
||||
logitsData.subarray(0, this.eos_token_id).fill(-Infinity);
|
||||
}
|
||||
}
|
||||
|
||||
// apply the `max_initial_timestamp` option
|
||||
if (input_ids.length === this.begin_index && this.max_initial_timestamp_index !== null) {
|
||||
const last_allowed = this.timestamp_begin + this.max_initial_timestamp_index;
|
||||
logits.data.subarray(last_allowed + 1).fill(-Infinity);
|
||||
logitsData.subarray(last_allowed + 1).fill(-Infinity);
|
||||
}
|
||||
|
||||
// if sum of probability over timestamps is above any other token, sample timestamp
|
||||
const logprobs = log_softmax(logits.data);
|
||||
const logprobs = log_softmax(logitsData);
|
||||
const timestamp_logprob = Math.log(logprobs.subarray(this.timestamp_begin).map(Math.exp).reduce((a, b) => a + b));
|
||||
const max_text_token_logprob = max(logprobs.subarray(0, this.timestamp_begin))[0];
|
||||
|
||||
if (timestamp_logprob > max_text_token_logprob) {
|
||||
logits.data.subarray(0, this.timestamp_begin).fill(-Infinity);
|
||||
logitsData.subarray(0, this.timestamp_begin).fill(-Infinity);
|
||||
}
|
||||
|
||||
return logits;
|
||||
|
@ -697,12 +699,12 @@ export class Sampler extends Callable {
|
|||
* Returns the specified logits as an array, with temperature applied.
|
||||
* @param {Tensor} logits
|
||||
* @param {number} index
|
||||
* @returns {Array}
|
||||
* @returns {Float32Array}
|
||||
*/
|
||||
getLogits(logits, index) {
|
||||
let vocabSize = logits.dims.at(-1);
|
||||
|
||||
let logs = logits.data;
|
||||
let logs = /** @type {Float32Array} */(logits.data);
|
||||
|
||||
if (index === -1) {
|
||||
logs = logs.slice(-vocabSize);
|
||||
|
|
|
@ -79,7 +79,7 @@ export class RawImage {
|
|||
|
||||
/**
|
||||
* Create a new `RawImage` object.
|
||||
* @param {Uint8ClampedArray} data The pixel data.
|
||||
* @param {Uint8ClampedArray|Uint8Array} data The pixel data.
|
||||
* @param {number} width The width of the image.
|
||||
* @param {number} height The height of the image.
|
||||
* @param {1|2|3|4} channels The number of channels.
|
||||
|
@ -173,7 +173,18 @@ export class RawImage {
|
|||
} else {
|
||||
throw new Error(`Unsupported channel format: ${channel_format}`);
|
||||
}
|
||||
return new RawImage(tensor.data, tensor.dims[1], tensor.dims[0], tensor.dims[2]);
|
||||
if (!(tensor.data instanceof Uint8ClampedArray || tensor.data instanceof Uint8Array)) {
|
||||
throw new Error(`Unsupported tensor type: ${tensor.type}`);
|
||||
}
|
||||
switch (tensor.dims[2]) {
|
||||
case 1:
|
||||
case 2:
|
||||
case 3:
|
||||
case 4:
|
||||
return new RawImage(tensor.data, tensor.dims[1], tensor.dims[0], tensor.dims[2]);
|
||||
default:
|
||||
throw new Error(`Unsupported number of channels: ${tensor.dims[2]}`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -130,9 +130,9 @@ export function transpose_data(array, dims, axes) {
|
|||
|
||||
/**
|
||||
* Compute the softmax of an array of numbers.
|
||||
*
|
||||
* @param {number[]} arr The array of numbers to compute the softmax of.
|
||||
* @returns {number[]} The softmax array.
|
||||
* @template {TypedArray|number[]} T
|
||||
* @param {T} arr The array of numbers to compute the softmax of.
|
||||
* @returns {T} The softmax array.
|
||||
*/
|
||||
export function softmax(arr) {
|
||||
// Compute the maximum value in the array
|
||||
|
@ -142,18 +142,20 @@ export function softmax(arr) {
|
|||
const exps = arr.map(x => Math.exp(x - maxVal));
|
||||
|
||||
// Compute the sum of the exponentials
|
||||
// @ts-ignore
|
||||
const sumExps = exps.reduce((acc, val) => acc + val, 0);
|
||||
|
||||
// Compute the softmax values
|
||||
const softmaxArr = exps.map(x => x / sumExps);
|
||||
|
||||
return softmaxArr;
|
||||
return /** @type {T} */(softmaxArr);
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculates the logarithm of the softmax function for the input array.
|
||||
* @param {number[]} arr The input array to calculate the log_softmax function for.
|
||||
* @returns {any} The resulting log_softmax array.
|
||||
* @template {TypedArray|number[]} T
|
||||
* @param {T} arr The input array to calculate the log_softmax function for.
|
||||
* @returns {T} The resulting log_softmax array.
|
||||
*/
|
||||
export function log_softmax(arr) {
|
||||
// Compute the softmax values
|
||||
|
@ -162,7 +164,7 @@ export function log_softmax(arr) {
|
|||
// Apply log formula to each element
|
||||
const logSoftmaxArr = softmaxArr.map(x => Math.log(x));
|
||||
|
||||
return logSoftmaxArr;
|
||||
return /** @type {T} */(logSoftmaxArr);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -178,8 +180,7 @@ export function dot(arr1, arr2) {
|
|||
|
||||
/**
|
||||
* Get the top k items from an iterable, sorted by descending order
|
||||
*
|
||||
* @param {Array} items The items to be sorted
|
||||
* @param {any[]|TypedArray} items The items to be sorted
|
||||
* @param {number} [top_k=0] The number of top items to return (default: 0 = return all)
|
||||
* @returns {Array} The top k items, sorted by descending order
|
||||
*/
|
||||
|
@ -252,8 +253,8 @@ export function min(arr) {
|
|||
|
||||
/**
|
||||
* Returns the value and index of the maximum element in an array.
|
||||
* @param {number[]|TypedArray} arr array of numbers.
|
||||
* @returns {number[]} the value and index of the maximum element, of the form: [valueOfMax, indexOfMax]
|
||||
* @param {number[]|AnyTypedArray} arr array of numbers.
|
||||
* @returns {[number, number]} the value and index of the maximum element, of the form: [valueOfMax, indexOfMax]
|
||||
* @throws {Error} If array is empty.
|
||||
*/
|
||||
export function max(arr) {
|
||||
|
@ -266,7 +267,7 @@ export function max(arr) {
|
|||
indexOfMax = i;
|
||||
}
|
||||
}
|
||||
return [max, indexOfMax];
|
||||
return [Number(max), indexOfMax];
|
||||
}
|
||||
|
||||
function isPowerOfTwo(number) {
|
||||
|
|
|
@ -15,40 +15,57 @@ import {
|
|||
} from './maths.js';
|
||||
|
||||
|
||||
// @ts-ignore
|
||||
const DataTypeMap = new Map([
|
||||
['bool', Uint8Array],
|
||||
['float32', Float32Array],
|
||||
['float64', Float64Array],
|
||||
['string', Array], // string[]
|
||||
['int8', Int8Array],
|
||||
['uint8', Uint8Array],
|
||||
['int16', Int16Array],
|
||||
['uint16', Uint16Array],
|
||||
['int32', Int32Array],
|
||||
['uint32', Uint32Array],
|
||||
['int64', BigInt64Array],
|
||||
])
|
||||
const DataTypeMap = Object.freeze({
|
||||
float32: Float32Array,
|
||||
float64: Float64Array,
|
||||
string: Array, // string[]
|
||||
int8: Int8Array,
|
||||
uint8: Uint8Array,
|
||||
int16: Int16Array,
|
||||
uint16: Uint16Array,
|
||||
int32: Int32Array,
|
||||
uint32: Uint32Array,
|
||||
int64: BigInt64Array,
|
||||
uint64: BigUint64Array,
|
||||
bool: Uint8Array,
|
||||
});
|
||||
|
||||
/**
|
||||
* @typedef {keyof typeof DataTypeMap} DataType
|
||||
* @typedef {import('./maths.js').AnyTypedArray | any[]} DataArray
|
||||
*/
|
||||
|
||||
const ONNXTensor = ONNX.Tensor;
|
||||
|
||||
export class Tensor extends ONNXTensor {
|
||||
export class Tensor {
|
||||
/** @type {number[]} Dimensions of the tensor. */
|
||||
dims;
|
||||
|
||||
/** @type {DataType} Type of the tensor. */
|
||||
type;
|
||||
|
||||
/** @type {DataArray} The data stored in the tensor. */
|
||||
data;
|
||||
|
||||
/** @type {number} The number of elements in the tensor. */
|
||||
size;
|
||||
|
||||
/**
|
||||
* Create a new Tensor or copy an existing Tensor.
|
||||
* @param {[string, DataArray, number[]]|[ONNXTensor]} args
|
||||
* @param {[DataType, DataArray, number[]]|[import('onnxruntime-common').Tensor]} args
|
||||
*/
|
||||
constructor(...args) {
|
||||
if (args[0] instanceof ONNX.Tensor) {
|
||||
if (args[0] instanceof ONNXTensor) {
|
||||
// Create shallow copy
|
||||
super(args[0].type, args[0].data, args[0].dims);
|
||||
Object.assign(this, args[0]);
|
||||
|
||||
} else {
|
||||
// Create new
|
||||
super(...args);
|
||||
// Create new tensor
|
||||
Object.assign(this, new ONNXTensor(
|
||||
/** @type {DataType} */(args[0]),
|
||||
/** @type {Exclude<import('./maths.js').AnyTypedArray, Uint8ClampedArray>} */(args[1]),
|
||||
args[2]
|
||||
));
|
||||
}
|
||||
|
||||
return new Proxy(this, {
|
||||
|
@ -130,14 +147,21 @@ export class Tensor extends ONNXTensor {
|
|||
* @returns {Tensor}
|
||||
*/
|
||||
_subarray(index, iterSize, iterDims) {
|
||||
let data = this.data.subarray(index * iterSize, (index + 1) * iterSize);
|
||||
const o1 = index * iterSize;
|
||||
const o2 = (index + 1) * iterSize;
|
||||
|
||||
// We use subarray if available (typed array), otherwise we use slice (normal array)
|
||||
const data =
|
||||
('subarray' in this.data)
|
||||
? this.data.subarray(o1, o2)
|
||||
: this.data.slice(o1, o2);
|
||||
return new Tensor(this.type, data, iterDims);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the value of this tensor as a standard JavaScript Number. This only works
|
||||
* for tensors with one element. For other cases, see `Tensor.tolist()`.
|
||||
* @returns {number} The value of this tensor as a standard JavaScript Number.
|
||||
* @returns {number|bigint} The value of this tensor as a standard JavaScript Number.
|
||||
* @throws {Error} If the tensor has more than one element.
|
||||
*/
|
||||
item() {
|
||||
|
@ -265,6 +289,7 @@ export class Tensor extends ONNXTensor {
|
|||
let newBufferSize = newDims.reduce((a, b) => a * b);
|
||||
|
||||
// Allocate memory
|
||||
// @ts-ignore
|
||||
let data = new this.data.constructor(newBufferSize);
|
||||
|
||||
// Precompute strides
|
||||
|
@ -338,6 +363,7 @@ export class Tensor extends ONNXTensor {
|
|||
resultDims[dim] = 1; // Remove the specified axis
|
||||
|
||||
// Create a new array to store the accumulated values
|
||||
// @ts-ignore
|
||||
const result = new this.data.constructor(this.data.length / this.dims[dim]);
|
||||
|
||||
// Iterate over the data array
|
||||
|
@ -579,7 +605,7 @@ export class Tensor extends ONNXTensor {
|
|||
|
||||
/**
|
||||
* Performs Tensor dtype conversion.
|
||||
* @param {'bool'|'float32'|'float64'|'string'|'int8'|'uint8'|'int16'|'uint16'|'int32'|'uint32'|'int64'} type
|
||||
* @param {DataType} type The desired data type.
|
||||
* @returns {Tensor} The converted tensor.
|
||||
*/
|
||||
to(type) {
|
||||
|
@ -587,11 +613,11 @@ export class Tensor extends ONNXTensor {
|
|||
if (this.type === type) return this;
|
||||
|
||||
// Otherwise, the returned tensor is a copy of self with the desired dtype.
|
||||
const ArrayConstructor = DataTypeMap.get(type);
|
||||
if (!ArrayConstructor) {
|
||||
if (!DataTypeMap.hasOwnProperty(type)) {
|
||||
throw new Error(`Unsupported type: ${type}`);
|
||||
}
|
||||
return new Tensor(type, ArrayConstructor.from(this.data), this.dims);
|
||||
// @ts-ignore
|
||||
return new Tensor(type, DataTypeMap[type].from(this.data), this.dims);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -618,10 +644,10 @@ export class Tensor extends ONNXTensor {
|
|||
* reshape([1, 2, 3, 4 ], [2, 2 ]); // Type: number[][] Value: [[1, 2], [3, 4]]
|
||||
* reshape([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2]); // Type: number[][][] Value: [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
|
||||
* reshape([1, 2, 3, 4, 5, 6, 7, 8], [4, 2 ]); // Type: number[][] Value: [[1, 2], [3, 4], [5, 6], [7, 8]]
|
||||
* @param {T[]} data The input array to reshape.
|
||||
* @param {T[]|DataArray} data The input array to reshape.
|
||||
* @param {DIM} dimensions The target shape/dimensions.
|
||||
* @template T
|
||||
* @template {[number]|[number, number]|[number, number, number]|[number, number, number, number]} DIM
|
||||
* @template {[number]|number[]} DIM
|
||||
* @returns {NestArray<T, DIM["length"]>} The reshaped array.
|
||||
*/
|
||||
function reshape(data, dimensions) {
|
||||
|
@ -681,7 +707,7 @@ export function interpolate(input, [out_height, out_width], mode = 'bilinear', a
|
|||
const in_width = input.dims.at(-1);
|
||||
|
||||
let output = interpolate_data(
|
||||
input.data,
|
||||
/** @type {import('./maths.js').TypedArray}*/(input.data),
|
||||
[in_channels, in_height, in_width],
|
||||
[out_height, out_width],
|
||||
mode,
|
||||
|
@ -701,6 +727,7 @@ export function mean_pooling(last_hidden_state, attention_mask) {
|
|||
// attention_mask: [batchSize, seqLength]
|
||||
|
||||
let shape = [last_hidden_state.dims[0], last_hidden_state.dims[2]];
|
||||
// @ts-ignore
|
||||
let returnedData = new last_hidden_state.data.constructor(shape[0] * shape[1]);
|
||||
let [batchSize, seqLength, embedDim] = last_hidden_state.dims;
|
||||
|
||||
|
@ -813,6 +840,7 @@ export function cat(tensors, dim = 0) {
|
|||
|
||||
// Create a new array to store the accumulated values
|
||||
const resultSize = resultDims.reduce((a, b) => a * b, 1);
|
||||
// @ts-ignore
|
||||
const result = new tensors[0].data.constructor(resultSize);
|
||||
|
||||
// Create output tensor of same type as first
|
||||
|
@ -884,8 +912,10 @@ export function std_mean(input, dim = null, correction = 1, keepdim = false) {
|
|||
|
||||
if (dim === null) {
|
||||
// None to reduce over all dimensions.
|
||||
// @ts-ignore
|
||||
const sum = input.data.reduce((a, b) => a + b, 0);
|
||||
const mean = sum / input.data.length;
|
||||
// @ts-ignore
|
||||
const std = Math.sqrt(input.data.reduce((a, b) => a + (b - mean) ** 2, 0) / (input.data.length - correction));
|
||||
|
||||
const meanTensor = new Tensor(input.type, [mean], [/* scalar */]);
|
||||
|
@ -904,6 +934,7 @@ export function std_mean(input, dim = null, correction = 1, keepdim = false) {
|
|||
resultDims[dim] = 1; // Remove the specified axis
|
||||
|
||||
// Create a new array to store the accumulated values
|
||||
// @ts-ignore
|
||||
const result = new input.data.constructor(input.data.length / input.dims[dim]);
|
||||
|
||||
// Iterate over the data array
|
||||
|
@ -951,6 +982,7 @@ export function mean(input, dim = null, keepdim = false) {
|
|||
|
||||
if (dim === null) {
|
||||
// None to reduce over all dimensions.
|
||||
// @ts-ignore
|
||||
let val = input.data.reduce((a, b) => a + b, 0);
|
||||
return new Tensor(input.type, [val / input.data.length], [/* scalar */]);
|
||||
}
|
||||
|
@ -963,6 +995,7 @@ export function mean(input, dim = null, keepdim = false) {
|
|||
resultDims[dim] = 1; // Remove the specified axis
|
||||
|
||||
// Create a new array to store the accumulated values
|
||||
// @ts-ignore
|
||||
const result = new input.data.constructor(input.data.length / input.dims[dim]);
|
||||
|
||||
// Iterate over the data array
|
||||
|
@ -1054,6 +1087,7 @@ export function dynamicTimeWarping(matrix) {
|
|||
let i = output_length;
|
||||
let j = input_length;
|
||||
|
||||
// @ts-ignore
|
||||
trace.data.fill(2, 0, outputShape[1]) // trace[0, :] = 2
|
||||
for (let i = 0; i < outputShape[0]; ++i) { // trace[:, 0] = 1
|
||||
trace[i].data[0] = 1;
|
||||
|
|
Loading…
Reference in New Issue