transformers.js/tests/maths.test.js

import { compare } from './test_utils.js';
import { getFile } from '../src/utils/hub.js';
import { FFT, medianFilter } from '../src/utils/maths.js';
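
// Helper: compute the FFT of `arr` using the implementation from `src/utils/maths.js`.
// Complex inputs/outputs are interleaved [re, im] pairs; when the non-power-of-two
// path allocates a larger output buffer, the result is trimmed back to the expected length.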
const fft = (arr, complex = false) => {
    let output;
    let fft;
    if (complex) {
        fft = new FFT(arr.length / 2);
        output = new Float64Array(fft.outputBufferSize);
        fft.transform(output, arr);
    } else {
        fft = new FFT(arr.length);
        output = new Float64Array(fft.outputBufferSize);
        fft.realTransform(output, arr);
    }
    if (!fft.isPowerOfTwo) {
        output = output.slice(0, complex ? arr.length : 2 * arr.length);
    }
    return output;
}
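
// Reference FFT test cases (inputs and expected outputs, intended to match numpy's np.fft.fft).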
const fftTestsData = await (await getFile('./tests/data/fft_tests.json')).json()
describe('Mathematical operations', () => {

    describe('median filtering', () => {

        it('should compute median filter', async () => {
            const t1 = new Float32Array([5, 12, 2, 6, 3, 10, 9, 1, 4, 8, 11, 7]);
            const window = 3;

            const target = new Float32Array([12, 5, 6, 3, 6, 9, 9, 4, 4, 8, 8, 11]);
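            // Note: the boundary values assume the window is reflected at the edges of the array.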
            const output = medianFilter(t1, window);
            compare(output, target, 1e-3);
        });

        // TODO add tests for errors
    });

    describe('FFT', () => {
        // Should match output of numpy fft
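        // Values are interleaved [re0, im0, re1, im1, ...], so a length-N complex result is 2*N floats.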
        it('should compute real FFT for power of two', () => {
            { // size = 4
                // np.fft.fft([1,2,3,4]) == array([10.+0.j, -2.+2.j, -2.+0.j, -2.-2.j])
                const input = new Float32Array([1, 2, 3, 4]);
                const target = new Float32Array([10, 0, -2, 2, -2, 0, -2, -2]);

                const output = fft(input);
                compare(output, target, 1e-3);
            }

            { // size = 16
                // np.fft.fft([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16])
                // == array([136. +0.j,         -8.+40.21871594j, -8.+19.3137085j,
                //           -8.+11.9728461j,   -8. +8.j,         -8. +5.3454291j,
                //           -8. +3.3137085j,   -8. +1.59129894j, -8. +0.j,
                //           -8. -1.59129894j,  -8. -3.3137085j,  -8. -5.3454291j,
                //           -8. -8.j,          -8.-11.9728461j,  -8.-19.3137085j,
                //           -8.-40.21871594j])
                const input = new Float32Array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]);
                const target = new Float32Array([136.0, 0.0, -8.0, 40.218715937006785, -8.0, 19.31370849898476, -8.0, 11.972846101323912, -8.0, 8.0, -8.0, 5.345429103354389, -8.0, 3.313708498984761, -8.0, 1.5912989390372658, -8.0, 0.0, -8.0, -1.5912989390372658, -8.0, -3.313708498984761, -8.0, -5.345429103354389, -8.0, -8.0, -8.0, -11.972846101323912, -8.0, -19.31370849898476, -8.0, -40.218715937006785]);

                const output = fft(input);
                compare(output, target, 1e-3);
            }
        });

        it('should compute real FFT for non-power of two', () => {
            { // size = 3
                // np.fft.fft([1,2,3]) == array([ 6. +0.j, -1.5+0.8660254j, -1.5-0.8660254j])
                const input = new Float32Array([1, 2, 3]);
                const target = new Float32Array([6, 0, -1.5, 0.8660254, -1.5, -0.8660254]);

                const output = fft(input);
                compare(output, target, 1e-3);
            }
        });

        it('should compute complex FFT for non-power of two', () => {
            { // size = 3
                // np.fft.fft([1+3j,2-2j,3+1j]) == array([ 6. +2.j, -4.09807621+4.3660254j, 1.09807621+2.6339746j])
                const input = new Float32Array([1, 3, 2, -2, 3, 1]);
                const target = new Float32Array([6, 2, -4.09807621, 4.3660254, 1.09807621, 2.6339746]);

                const output = fft(input, true);
                compare(output, target, 1e-3);
            }
        });

        it('should compute complex FFT for power of two', () => {
            { // size = 4
                // np.fft.fft([1+4j, 2-3j, 3+2j, 4-1j]) == array([10. +2.j, -4. +4.j, -2.+10.j, 0. +0.j])
                const input = new Float32Array([1, 4, 2, -3, 3, 2, 4, -1]);
                const target = new Float32Array([10, 2, -4, 4, -2, 10, 0, 0]);

                const output = fft(input, true);
                compare(output, target, 1e-3);
            }
        });
    });

    describe('FFT (dynamic)', () => {
        // Should match output of numpy fft
        for (const [name, test] of Object.entries(fftTestsData)) {
            // if (test.input.length > 5) continue;
            it(name, () => {
                const output = fft(test.input, test.complex);

                // Debugging aid: log the failing case if any element deviates by more than the tolerance.
                if (output.map((v, i) => Math.abs(v - test.output[i])).some(v => v > 1e-4)) {
                    console.log('input', test.input);
                    console.log('output', output);
                    console.log('target', test.output);
                }

                compare(output, test.output, 1e-4);
            });
        }
    });
});