Add basic 2D `layer_norm` operator (#588)

This commit is contained in:
Joshua Lochner 2024-02-19 14:53:06 +02:00 committed by GitHub
parent 351dbed922
commit e9092d2337
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 51 additions and 2 deletions

View File

@ -762,6 +762,42 @@ export function mean_pooling(last_hidden_state, attention_mask) {
)
}
/**
* Apply Layer Normalization for last certain number of dimensions.
* @param {Tensor} input The input tensor
* @param {number[]} normalized_shape input shape from an expected input of size
* @param {Object} options The options for the layer normalization
* @param {number} [options.eps=1e-5] A value added to the denominator for numerical stability.
* @returns {Tensor} The normalized tensor.
*/
export function layer_norm(input, normalized_shape, {
eps = 1e-5,
} = {}) {
if (input.dims.length !== 2) {
throw new Error('`layer_norm` currently only supports 2D input.');
}
const [batchSize, featureDim] = input.dims;
if (normalized_shape.length !== 1 && normalized_shape[0] !== featureDim) {
throw new Error('`normalized_shape` must be a 1D array with shape `[input.dims[1]]`.');
}
const [std, mean] = std_mean(input, 1, 0, true);
// @ts-ignore
const returnedData = new input.data.constructor(input.data.length);
for (let i = 0; i < batchSize; ++i) {
const offset = i * featureDim;
for (let j = 0; j < featureDim; ++j) {
const offset2 = offset + j;
returnedData[offset2] = (input.data[offset2] - mean.data[i]) / (std.data[i] + eps);
}
}
return new Tensor(input.type, returnedData, input.dims);
}
/**
* Helper function to calculate new dimensions when performing a squeeze operation.
* @param {number[]} dims The dimensions of the tensor.

View File

@ -1,7 +1,7 @@
import { Tensor } from '../src/transformers.js';
import { compare } from './test_utils.js';
import { cat, mean, stack } from '../src/utils/tensor.js';
import { cat, mean, stack, layer_norm } from '../src/utils/tensor.js';
describe('Tensor operations', () => {
@ -103,7 +103,6 @@ describe('Tensor operations', () => {
});
});
describe('mean', () => {
it('should calculate mean', async () => {
const t1 = new Tensor('float32', [1, 2, 3, 4, 5, 6], [2, 3, 1]);
@ -128,4 +127,18 @@ describe('Tensor operations', () => {
})
});
describe('layer_norm', () => {
it('should calculate layer norm', async () => {
const t1 = new Tensor('float32', [1, 2, 3, 4, 5, 6], [2, 3]);
const target = new Tensor('float32', [
-1.2247356176376343, 0.0, 1.2247356176376343,
-1.2247357368469238, -1.1920928955078125e-07, 1.2247354984283447,
], [2, 3]);
const norm = layer_norm(t1, [t1.dims.at(-1)]);
compare(norm, target, 1e-3);
});
});
});