import { AutoTokenizer } from '../src/transformers.js';
import { getFile } from '../src/utils/hub.js';
import { m, MAX_TEST_EXECUTION_TIME } from './init.js';
import { compare } from './test_utils.js';

// Load test data generated by the python tests
// TODO do this dynamically?
let testsData = await (await getFile('./tests/data/tokenizer_tests.json')).json()
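
// Each entry in testsData maps a tokenizer name to a list of test cases, each holding
// an input string together with its expected encoding and its decoded output with and
// without special tokens (as generated by the python tests).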
describe('Tokenizers', () => {

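    // One test suite per tokenizer in the reference data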
    for (let [tokenizerName, tests] of Object.entries(testsData)) {

        it(tokenizerName, async () => {
            let tokenizer = await AutoTokenizer.from_pretrained(m(tokenizerName));

            for (let test of tests) {

                // Test encoding
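                // (return_tensor: false makes the tokenizer return plain nested arrays
                // rather than Tensor objects, so the result can be compared directly
                // against the reference JSON.)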
                let encoded = tokenizer(test.input, {
                    return_tensor: false
                });

                // Add the input text to the encoded object for easier debugging
                encoded.input = test.input;
                test.encoded.input = test.input;

                expect(encoded).toEqual(test.encoded);

                // Test decoding
                let decoded_with_special = tokenizer.decode(encoded.input_ids, { skip_special_tokens: false });
                expect(decoded_with_special).toEqual(test.decoded_with_special);

                let decoded_without_special = tokenizer.decode(encoded.input_ids, { skip_special_tokens: true });
                expect(decoded_without_special).toEqual(test.decoded_without_special);
            }
        }, MAX_TEST_EXECUTION_TIME);
    }
});

describe('Edge cases', () => {
    it('should not crash when encoding a very long string', async () => {
        let tokenizer = await AutoTokenizer.from_pretrained('Xenova/t5-small');

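        // ~650,000 characters of input ('Hello world! ' is 13 characters, repeated
        // 50,000 times), which should tokenize to well over 100,000 ids.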
        let text = String.prototype.repeat.call('Hello world! ', 50000);
        let encoded = tokenizer(text);
        expect(encoded.input_ids.data.length).toBeGreaterThan(100000);
    }, MAX_TEST_EXECUTION_TIME);

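    // 50,000 repeated 'a' characters form a single out-of-vocabulary "word", so the
    // expected ids are BERT's [CLS], [UNK], [SEP] (101, 100, 102); the 5 second timeout
    // guards against pathologically slow tokenization of long repeated input.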
    it('should not take too long', async () => {
        let tokenizer = await AutoTokenizer.from_pretrained('Xenova/all-MiniLM-L6-v2');

        let text = String.prototype.repeat.call('a', 50000);
        let token_ids = tokenizer.encode(text);
        compare(token_ids, [101, 100, 102])
    }, 5000); // NOTE: 5 seconds
});