Fix typing issues

This commit is contained in:
Joshua Lochner 2023-12-11 21:20:42 +02:00
parent b768cb8588
commit 70d3f7c34d
2 changed files with 38 additions and 23 deletions

View File

@ -252,8 +252,8 @@ export function min(arr) {
/**
* Returns the value and index of the maximum element in an array.
* @param {number[]|TypedArray} arr array of numbers.
* @returns {number[]} the value and index of the maximum element, of the form: [valueOfMax, indexOfMax]
* @param {number[]|AnyTypedArray} arr array of numbers.
* @returns {[number, number]} the value and index of the maximum element, of the form: [valueOfMax, indexOfMax]
* @throws {Error} If array is empty.
*/
export function max(arr) {
@ -266,7 +266,7 @@ export function max(arr) {
indexOfMax = i;
}
}
return [max, indexOfMax];
return [Number(max), indexOfMax];
}
function isPowerOfTwo(number) {

View File

@ -15,20 +15,18 @@ import {
} from './maths.js';
// @ts-ignore
const DataTypeMap = new Map([
['bool', Uint8Array],
['float32', Float32Array],
['float64', Float64Array],
['string', Array], // string[]
['int8', Int8Array],
['uint8', Uint8Array],
['int16', Int16Array],
['uint16', Uint16Array],
['int32', Int32Array],
['uint32', Uint32Array],
['int64', BigInt64Array],
])
const DataTypeMap = Object.freeze({
float32: Float32Array,
float64: Float64Array,
string: Array, // string[]
int8: Int8Array,
uint8: Uint8Array,
int16: Int16Array,
uint16: Uint16Array,
int32: Int32Array,
uint32: Uint32Array,
int64: BigInt64Array,
});
/**
* @typedef {import('./maths.js').AnyTypedArray | any[]} DataArray
@ -132,14 +130,21 @@ export class Tensor {
* @returns {Tensor}
*/
_subarray(index, iterSize, iterDims) {
let data = this.data.subarray(index * iterSize, (index + 1) * iterSize);
const o1 = index * iterSize;
const o2 = (index + 1) * iterSize;
// We use subarray if available (typed array), otherwise we use slice (normal array)
const data =
('subarray' in this.data)
? this.data.subarray(o1, o2)
: this.data.slice(o1, o2);
return new Tensor(this.type, data, iterDims);
}
/**
* Returns the value of this tensor as a standard JavaScript Number. This only works
* for tensors with one element. For other cases, see `Tensor.tolist()`.
* @returns {number} The value of this tensor as a standard JavaScript Number.
* @returns {number|bigint} The value of this tensor as a standard JavaScript Number.
* @throws {Error} If the tensor has more than one element.
*/
item() {
@ -267,6 +272,7 @@ export class Tensor {
let newBufferSize = newDims.reduce((a, b) => a * b);
// Allocate memory
// @ts-ignore
let data = new this.data.constructor(newBufferSize);
// Precompute strides
@ -340,6 +346,7 @@ export class Tensor {
resultDims[dim] = 1; // Remove the specified axis
// Create a new array to store the accumulated values
// @ts-ignore
const result = new this.data.constructor(this.data.length / this.dims[dim]);
// Iterate over the data array
@ -589,7 +596,7 @@ export class Tensor {
if (this.type === type) return this;
// Otherwise, the returned tensor is a copy of self with the desired dtype.
const ArrayConstructor = DataTypeMap.get(type);
const ArrayConstructor = DataTypeMap[type];
if (!ArrayConstructor) {
throw new Error(`Unsupported type: ${type}`);
}
@ -620,10 +627,10 @@ export class Tensor {
* reshape([1, 2, 3, 4 ], [2, 2 ]); // Type: number[][] Value: [[1, 2], [3, 4]]
* reshape([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2]); // Type: number[][][] Value: [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
* reshape([1, 2, 3, 4, 5, 6, 7, 8], [4, 2 ]); // Type: number[][] Value: [[1, 2], [3, 4], [5, 6], [7, 8]]
* @param {T[]} data The input array to reshape.
* @param {T[]|DataArray} data The input array to reshape.
* @param {DIM} dimensions The target shape/dimensions.
* @template T
* @template {[number]|[number, number]|[number, number, number]|[number, number, number, number]} DIM
* @template {[number]|number[]} DIM
* @returns {NestArray<T, DIM["length"]>} The reshaped array.
*/
function reshape(data, dimensions) {
@ -683,7 +690,7 @@ export function interpolate(input, [out_height, out_width], mode = 'bilinear', a
const in_width = input.dims.at(-1);
let output = interpolate_data(
input.data,
/** @type {import('./maths.js').TypedArray}*/(input.data),
[in_channels, in_height, in_width],
[out_height, out_width],
mode,
@ -703,6 +710,7 @@ export function mean_pooling(last_hidden_state, attention_mask) {
// attention_mask: [batchSize, seqLength]
let shape = [last_hidden_state.dims[0], last_hidden_state.dims[2]];
// @ts-ignore
let returnedData = new last_hidden_state.data.constructor(shape[0] * shape[1]);
let [batchSize, seqLength, embedDim] = last_hidden_state.dims;
@ -815,6 +823,7 @@ export function cat(tensors, dim = 0) {
// Create a new array to store the accumulated values
const resultSize = resultDims.reduce((a, b) => a * b, 1);
// @ts-ignore
const result = new tensors[0].data.constructor(resultSize);
// Create output tensor of same type as first
@ -886,8 +895,10 @@ export function std_mean(input, dim = null, correction = 1, keepdim = false) {
if (dim === null) {
// None to reduce over all dimensions.
// @ts-ignore
const sum = input.data.reduce((a, b) => a + b, 0);
const mean = sum / input.data.length;
// @ts-ignore
const std = Math.sqrt(input.data.reduce((a, b) => a + (b - mean) ** 2, 0) / (input.data.length - correction));
const meanTensor = new Tensor(input.type, [mean], [/* scalar */]);
@ -906,6 +917,7 @@ export function std_mean(input, dim = null, correction = 1, keepdim = false) {
resultDims[dim] = 1; // Remove the specified axis
// Create a new array to store the accumulated values
// @ts-ignore
const result = new input.data.constructor(input.data.length / input.dims[dim]);
// Iterate over the data array
@ -953,6 +965,7 @@ export function mean(input, dim = null, keepdim = false) {
if (dim === null) {
// None to reduce over all dimensions.
// @ts-ignore
let val = input.data.reduce((a, b) => a + b, 0);
return new Tensor(input.type, [val / input.data.length], [/* scalar */]);
}
@ -965,6 +978,7 @@ export function mean(input, dim = null, keepdim = false) {
resultDims[dim] = 1; // Remove the specified axis
// Create a new array to store the accumulated values
// @ts-ignore
const result = new input.data.constructor(input.data.length / input.dims[dim]);
// Iterate over the data array
@ -1056,6 +1070,7 @@ export function dynamicTimeWarping(matrix) {
let i = output_length;
let j = input_length;
// @ts-ignore
trace.data.fill(2, 0, outputShape[1]) // trace[0, :] = 2
for (let i = 0; i < outputShape[0]; ++i) { // trace[:, 0] = 1
trace[i].data[0] = 1;