import { AutoTokenizer } from '../src/transformers.js';
import { getFile } from '../src/utils/hub.js';
import { m, MAX_TEST_EXECUTION_TIME } from './init.js';
import { compare } from './test_utils.js';

// Load test data generated by the Python tests
// TODO: do this dynamically?
const { tokenization, templates } = await (await getFile('./tests/data/tokenizer_tests.json')).json();
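
// Each entry in `tokenization` maps a tokenizer name to a list of test cases of the
// form { input, encoded, decoded_with_special, decoded_without_special }.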

// Dynamic tests to ensure transformers.js (JavaScript) matches transformers (Python)
describe('Tokenizers (dynamic)', () => {
    for (const [tokenizerName, tests] of Object.entries(tokenization)) {
        it(tokenizerName, async () => {
            const tokenizer = await AutoTokenizer.from_pretrained(m(tokenizerName));

            for (const test of tests) {
                // Test encoding
                const encoded = tokenizer(test.input, {
                    return_tensor: false
                });

                // Add the input text to the encoded object for easier debugging
                test.encoded.input = encoded.input = test.input;

                expect(encoded).toEqual(test.encoded);

                // Skip decoding tests if encoding produces zero tokens
                if (test.encoded.input_ids.length === 0) continue;

                // Test decoding
                const decoded_with_special = tokenizer.decode(encoded.input_ids, { skip_special_tokens: false });
                expect(decoded_with_special).toEqual(test.decoded_with_special);

                const decoded_without_special = tokenizer.decode(encoded.input_ids, { skip_special_tokens: true });
                expect(decoded_without_special).toEqual(test.decoded_without_special);
            }
        }, MAX_TEST_EXECUTION_TIME);
    }
});

// Tests to ensure that, no matter what, the correct tokenization is returned.
// This is necessary since there are sometimes bugs in the transformers library.
describe('Tokenizers (hard-coded)', () => {
    const TESTS = {
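        // NOTE: `legacy` refers to the tokenizer's `legacy` option. For SentencePiece-based
        // tokenizers such as LLaMA's, legacy behaviour inserts an extra "▁" token (id 29871)
        // after special tokens, while `legacy: false` gives the corrected tokenization.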
        'Xenova/llama-tokenizer': [ // Test legacy compatibility
            {
                // legacy unset => legacy=true
                // NOTE: While incorrect, it is necessary to match legacy behaviour
                data: {
                    "<s>\n": [1, 29871, 13],
                },
                legacy: null,
            },
            {
                // override legacy=true (same results as above)
                data: {
                    "<s>\n": [1, 29871, 13],
                },
                legacy: true,
            },
            {
                // override legacy=false (fixed results)
                data: {
                    "<s>\n": [1, 13],
                },
                legacy: false,
            },
        ],

        'Xenova/llama-tokenizer_new': [ // legacy=false
            {
                data: {
                    " </s> 1 2 3 4 ": [259, 2, 29871, 29896, 259, 29906, 1678, 29941, 268, 29946, 1678],
                    "<s>\n": [1, 13],
                    "</s>test</s>": [2, 1688, 2],
                    " </s> test </s> ": [259, 2, 1243, 29871, 2, 29871],
                    "A\n'll": [319, 13, 29915, 645],
                    "Hey </s>. how are you": [18637, 29871, 2, 29889, 920, 526, 366],
                    " Hi Hello ": [259, 6324, 29871, 15043, 259],
                },
                reversible: true,
                legacy: null,
            },
            { // override legacy=true (incorrect results, but necessary to match legacy behaviour)
                data: {
                    "<s>\n": [1, 29871, 13],
                },
                legacy: true,
            },
        ],

        // legacy=false
        'Xenova/t5-tokenizer-new': [
            {
                data: {
                    // https://github.com/huggingface/transformers/pull/26678
                    // ['▁Hey', '▁', '</s>', '.', '▁how', '▁are', '▁you']
                    "Hey </s>. how are you": [9459, 3, 1, 5, 149, 33, 25],
                },
                reversible: true,
                legacy: null,
            },
            {
                data: {
                    "</s>\n": [1, 3],
                    "A\n'll": [71, 3, 31, 195],
                },
                reversible: false,
                legacy: null,
            },
        ],
    };

    // Re-use the same tests for the llama2 tokenizer
    TESTS['Xenova/llama2-tokenizer'] = TESTS['Xenova/llama-tokenizer_new'];
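    // (Llama 2 uses the same SentencePiece tokenizer model as Llama, so the expected outputs are identical.)
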
    for (const [tokenizerName, test_data] of Object.entries(TESTS)) {
        it(tokenizerName, async () => {
            for (const { data, reversible, legacy } of test_data) {
                const tokenizer = await AutoTokenizer.from_pretrained(m(tokenizerName), { legacy });

                for (const [text, expected] of Object.entries(data)) {
                    const token_ids = tokenizer.encode(text, null, { add_special_tokens: false });
                    expect(token_ids).toEqual(expected);

                    // If reversible, test that decoding produces the original text
                    if (reversible) {
                        const decoded = tokenizer.decode(token_ids);
                        expect(decoded).toEqual(text);
                    }
                }
            }
        }, MAX_TEST_EXECUTION_TIME);
    }
});

describe('Edge cases', () => {
    it('should not crash when encoding a very long string', async () => {
        const tokenizer = await AutoTokenizer.from_pretrained('Xenova/t5-small');

        const text = 'Hello world! '.repeat(50000);
        const encoded = tokenizer(text);
        expect(encoded.input_ids.data.length).toBeGreaterThan(100000);
    }, MAX_TEST_EXECUTION_TIME);

    it('should not take too long', async () => {
        const tokenizer = await AutoTokenizer.from_pretrained('Xenova/all-MiniLM-L6-v2');

        const text = 'a'.repeat(50000);
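        // A 50,000-character run with no whitespace exceeds WordPiece's per-word character
        // limit, so it collapses to a single token: [CLS] (101), [UNK] (100), [SEP] (102).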
        const token_ids = tokenizer.encode(text);
        compare(token_ids, [101, 100, 102]);
    }, 5000); // NOTE: 5 seconds
});

describe('Extra decoding tests', () => {
    it('should be able to decode the output of encode', async () => {
        const tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased');

        const text = 'hello world!';

        // Ensure all the following outputs are the same:
        // 1. Tensor of ids: allow decoding of 1D or 2D tensors.
        const encodedTensor = tokenizer(text);
        const decoded1 = tokenizer.decode(encodedTensor.input_ids, { skip_special_tokens: true });
        const decoded2 = tokenizer.batch_decode(encodedTensor.input_ids, { skip_special_tokens: true })[0];
        expect(decoded1).toEqual(text);
        expect(decoded2).toEqual(text);

        // 2. List of ids
        const encodedList = tokenizer(text, { return_tensor: false });
        const decoded3 = tokenizer.decode(encodedList.input_ids, { skip_special_tokens: true });
        const decoded4 = tokenizer.batch_decode([encodedList.input_ids], { skip_special_tokens: true })[0];
        expect(decoded3).toEqual(text);
        expect(decoded4).toEqual(text);
    }, MAX_TEST_EXECUTION_TIME);
});

describe('Chat templates', () => {
    it('should generate a chat template', async () => {
        const tokenizer = await AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1");

        const chat = [
            { "role": "user", "content": "Hello, how are you?" },
            { "role": "assistant", "content": "I'm doing great. How can I help you today?" },
            { "role": "user", "content": "I'd like to show off how chat templating works!" },
        ];

        const text = tokenizer.apply_chat_template(chat, { tokenize: false });

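        // The default Mistral-instruct template wraps each user message in [INST] ... [/INST]
        // and terminates assistant responses with </s>, as the expected string below shows.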
        expect(text).toEqual("<s>[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today?</s> [INST] I'd like to show off how chat templating works! [/INST]");

        const input_ids = tokenizer.apply_chat_template(chat, { tokenize: true, return_tensor: false });
        compare(input_ids, [1, 733, 16289, 28793, 22557, 28725, 910, 460, 368, 28804, 733, 28748, 16289, 28793, 28737, 28742, 28719, 2548, 1598, 28723, 1602, 541, 315, 1316, 368, 3154, 28804, 2, 28705, 733, 16289, 28793, 315, 28742, 28715, 737, 298, 1347, 805, 910, 10706, 5752, 1077, 3791, 28808, 733, 28748, 16289, 28793]);
    });

    it('should support user-defined chat template', async () => {
        const tokenizer = await AutoTokenizer.from_pretrained("Xenova/llama-tokenizer");

        const chat = [
            { role: 'user', content: 'Hello, how are you?' },
            { role: 'assistant', content: "I'm doing great. How can I help you today?" },
            { role: 'user', content: "I'd like to show off how chat templating works!" },
        ];

        // https://discuss.huggingface.co/t/issue-with-llama-2-chat-template-and-out-of-date-documentation/61645/3
        const chat_template = (
            "{% if messages[0]['role'] == 'system' %}" +
            "{% set loop_messages = messages[1:] %}" + // Extract system message if it's present
            "{% set system_message = messages[0]['content'] %}" +
            "{% elif USE_DEFAULT_PROMPT == true and not '<<SYS>>' in messages[0]['content'] %}" +
            "{% set loop_messages = messages %}" + // Or use the default system message if the flag is set
            "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}" +
            "{% else %}" +
            "{% set loop_messages = messages %}" +
            "{% set system_message = false %}" +
            "{% endif %}" +
            "{% if loop_messages|length == 0 and system_message %}" + // Special handling when only sys message present
            "{{ bos_token + '[INST] <<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n [/INST]' }}" +
            "{% endif %}" +
            "{% for message in loop_messages %}" + // Loop over all non-system messages
            "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" +
            "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}" +
            "{% endif %}" +
            "{% if loop.index0 == 0 and system_message != false %}" + // Embed system message in first message
            "{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}" +
            "{% else %}" +
            "{% set content = message['content'] %}" +
            "{% endif %}" +
            "{% if message['role'] == 'user' %}" + // After all of that, handle messages/roles in a fairly normal way
            "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}" +
            "{% elif message['role'] == 'system' %}" +
            "{{ '<<SYS>>\\n' + content.strip() + '\\n<</SYS>>\\n\\n' }}" +
            "{% elif message['role'] == 'assistant' %}" +
            "{{ ' ' + content.strip() + ' ' + eos_token }}" +
            "{% endif %}" +
            "{% endfor %}"
        )
            .replaceAll('USE_DEFAULT_PROMPT', true)
            .replaceAll('DEFAULT_SYSTEM_MESSAGE', 'You are a helpful, respectful and honest assistant.');
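
        // NOTE: Python-style calls like `content.strip()` in the template are handled by the
        // Jinja implementation transformers.js uses to render chat templates, as the expected
        // output below (with stripped content) demonstrates.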

        const text = await tokenizer.apply_chat_template(chat, { tokenize: false, return_tensor: false, chat_template });

        expect(text).toEqual("<s>[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant.\n<</SYS>>\n\nHello, how are you? [/INST] I'm doing great. How can I help you today? </s><s>[INST] I'd like to show off how chat templating works! [/INST]");

        // TODO: Add test for token_ids once bug in transformers is fixed.
    });

    // Dynamically-generated tests
    for (const [tokenizerName, tests] of Object.entries(templates)) {
        it(tokenizerName, async () => {
            // NOTE: not m(...) here
            // TODO: update this?
            const tokenizer = await AutoTokenizer.from_pretrained(tokenizerName);

            for (const { messages, add_generation_prompt, tokenize, target } of tests) {
                const generated = await tokenizer.apply_chat_template(messages, {
                    tokenize,
                    add_generation_prompt,
                    return_tensor: false,
                });
                expect(generated).toEqual(target);
            }
        });
    }
});