import { AutoTokenizer } from '../src/transformers.js';
import { getFile } from '../src/utils/hub.js';
import { m, MAX_TEST_EXECUTION_TIME } from './init.js';
import { compare } from './test_utils.js';

// Load test data generated by the python tests
// TODO do this dynamically?
const { tokenization, templates } = await (await getFile('./tests/data/tokenizer_tests.json')).json();
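
// Assumed shape of the JSON file, based on how it is consumed below:
//  - `tokenization` maps tokenizer names to test cases, each with an `input`
//    (a string, or `{ text, text_pair }`) and the expected encoding/decoding results.
//  - `templates` maps tokenizer names to chat-template test cases of the form
//    `{ messages, add_generation_prompt, tokenize, target }`.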

// Dynamic tests to ensure transformers.js (JavaScript) matches transformers (Python)
describe('Tokenizers (dynamic)', () => {

    for (const [tokenizerName, tests] of Object.entries(tokenization)) {

        it(tokenizerName, async () => {
            const tokenizer = await AutoTokenizer.from_pretrained(m(tokenizerName));

            for (const test of tests) {
                // Two kinds of tests:
                //  1. text w/o text_pair
                //  2. text w/ text_pair

                if (typeof test.input === 'string') {
                    // Test encoding
                    const encoded = tokenizer(test.input, {
                        return_tensor: false,
                    });

                    // Add the input text to the encoded object for easier debugging
                    test.encoded.input = encoded.input = test.input;

                    expect(encoded).toEqual(test.encoded);

                    // Skip decoding tests if encoding produces zero tokens
                    if (test.encoded.input_ids.length === 0) continue;

                    // Test decoding
                    const decoded_with_special = tokenizer.decode(encoded.input_ids, { skip_special_tokens: false });
                    expect(decoded_with_special).toEqual(test.decoded_with_special);

                    const decoded_without_special = tokenizer.decode(encoded.input_ids, { skip_special_tokens: true });
                    expect(decoded_without_special).toEqual(test.decoded_without_special);
                } else {
                    const { text, text_pair } = test.input;
                    const encoded = tokenizer(text, {
                        text_pair,
                        return_tensor: false,
                    });
                    compare(encoded, test.output);
                }
            }
        }, MAX_TEST_EXECUTION_TIME);
    }
});

// Tests to ensure that, no matter what, the correct tokenization is returned.
// This is necessary since there are sometimes bugs in the transformers library.
describe('Tokenizers (hard-coded)', () => {
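    // Each entry maps a tokenizer name to a list of test configs of the form
    // `{ data, legacy, reversible? }`, where `data` maps input text to expected
    // token ids, `legacy` selects the tokenizer's legacy mode (null = unset),
    // and `reversible` additionally checks that decoding round-trips the text.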
    const TESTS = {
        'Xenova/llama-tokenizer': [ // Test legacy compatibility
            {
                // legacy unset => legacy=true
                // NOTE: While incorrect, it is necessary to match legacy behaviour
                data: {
                    "<s>\n": [1, 29871, 13],
                },
                legacy: null,
            },
            {
                // override legacy=true (same results as above)
                data: {
                    "<s>\n": [1, 29871, 13],
                },
                legacy: true,
            },
            {
                // override legacy=false (fixed results)
                data: {
                    "<s>\n": [1, 13],
                },
                legacy: false,
            },
        ],
        'Xenova/llama-tokenizer_new': [ // legacy=false
            {
                data: {
                    " </s> 1 2 3 4 ": [259, 2, 29871, 29896, 259, 29906, 1678, 29941, 268, 29946, 1678],
                    "<s>\n": [1, 13],
                    "</s>test</s>": [2, 1688, 2],
                    " </s> test </s> ": [259, 2, 1243, 29871, 2, 29871],
                    "A\n'll": [319, 13, 29915, 645],
                    "Hey </s>. how are you": [18637, 29871, 2, 29889, 920, 526, 366],
                    " Hi Hello ": [259, 6324, 29871, 15043, 259],
                },
                reversible: true,
                legacy: null,
            },
            { // override legacy=true (incorrect results, but necessary to match legacy behaviour)
                data: {
                    "<s>\n": [1, 29871, 13],
                },
                legacy: true,
            },
        ],
        // legacy=false
        'Xenova/t5-tokenizer-new': [
            {
                data: {
                    // https://github.com/huggingface/transformers/pull/26678
                    // ['▁Hey', '▁', '</s>', '.', '▁how', '▁are', '▁you']
                    "Hey </s>. how are you": [9459, 3, 1, 5, 149, 33, 25],
                },
                reversible: true,
                legacy: null,
            },
            {
                data: {
                    "</s>\n": [1, 3],
                    "A\n'll": [71, 3, 31, 195],
                },
                reversible: false,
                legacy: null,
            },
        ],
    };

    // Re-use the same tests for the llama2 tokenizer
    TESTS['Xenova/llama2-tokenizer'] = TESTS['Xenova/llama-tokenizer_new'];
    for (const [tokenizerName, test_data] of Object.entries(TESTS)) {
        it(tokenizerName, async () => {
            for (const { data, reversible, legacy } of test_data) {
                const tokenizer = await AutoTokenizer.from_pretrained(m(tokenizerName), { legacy });

                for (const [text, expected] of Object.entries(data)) {
                    const token_ids = tokenizer.encode(text, null, { add_special_tokens: false });
                    expect(token_ids).toEqual(expected);

                    // If reversible, test that decoding produces the original text
                    if (reversible) {
                        const decoded = tokenizer.decode(token_ids);
                        expect(decoded).toEqual(text);
                    }
                }
            }
        }, MAX_TEST_EXECUTION_TIME);
    }
});

describe('Tokenizer padding/truncation', () => {
    const inputs = ['a', 'b c'];
    const text_pair = ['d e', 'f g h'];
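    // For reference (consistent with the expected outputs below): in the
    // bert-base-uncased vocabulary, [PAD]=0, [CLS]=101, [SEP]=102, and the
    // single letters 'a'..'h' map to ids 1037..1044.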

    it('should create a jagged array', async () => {
        const tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased');

        { // support jagged array if `return_tensor=false`
            const output = tokenizer(inputs, {
                return_tensor: false,
            });
            const expected = {
                input_ids: [[101, 1037, 102], [101, 1038, 1039, 102]],
                attention_mask: [[1, 1, 1], [1, 1, 1, 1]],
                token_type_ids: [[0, 0, 0], [0, 0, 0, 0]],
            };
            compare(output, expected);
        }

        {
            const output = tokenizer(inputs, {
                return_tensor: false,
                truncation: true,
                add_special_tokens: false,
            });
            const expected = {
                input_ids: [[1037], [1038, 1039]],
                attention_mask: [[1], [1, 1]],
                token_type_ids: [[0], [0, 0]],
            };
            compare(output, expected);
        }
    });

    it('should create a tensor', async () => {
        const tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased');

        { // Expected to throw error if jagged array
            expect(() => tokenizer(inputs)).toThrowError('Unable to create tensor');
        }

        { // Truncation
            const { input_ids, attention_mask, token_type_ids } = tokenizer(inputs, {
                truncation: true,
                max_length: 1,
                add_special_tokens: false,
            });
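            // `.tolist()` yields BigInt values (hence the `n` suffixes below),
            // since the returned tensors hold 64-bit integer ids.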
            expect(input_ids.tolist()).toEqual([[1037n], [1038n]]);
            expect(attention_mask.tolist()).toEqual([[1n], [1n]]);
            expect(token_type_ids.tolist()).toEqual([[0n], [0n]]);
        }

        { // Truncation w/ text pair
            // TODO
        }

        { // Padding
            const { input_ids, attention_mask, token_type_ids } = tokenizer(inputs, {
                padding: true,
                add_special_tokens: false,
            });
            expect(input_ids.tolist()).toEqual([[1037n, 0n], [1038n, 1039n]]);
            expect(attention_mask.tolist()).toEqual([[1n, 0n], [1n, 1n]]);
            expect(token_type_ids.tolist()).toEqual([[0n, 0n], [0n, 0n]]);
        }

        { // Padding w/ text pair
            const { input_ids, attention_mask, token_type_ids } = tokenizer(inputs, {
                text_pair,
                padding: true,
                add_special_tokens: false,
            });
            expect(input_ids.tolist()).toEqual([
                [1037n, 1040n, 1041n, 0n, 0n],
                [1038n, 1039n, 1042n, 1043n, 1044n],
            ]);
            expect(attention_mask.tolist()).toEqual([
                [1n, 1n, 1n, 0n, 0n],
                [1n, 1n, 1n, 1n, 1n],
            ]);
            expect(token_type_ids.tolist()).toEqual([
                [0n, 1n, 1n, 0n, 0n],
                [0n, 0n, 1n, 1n, 1n],
            ]);
        }

        { // Truncation + padding
            const { input_ids, attention_mask, token_type_ids } = tokenizer(['a', 'b c', 'd e f'], {
                padding: true,
                truncation: true,
                add_special_tokens: false,
                max_length: 2,
            });
            expect(input_ids.tolist()).toEqual([[1037n, 0n], [1038n, 1039n], [1040n, 1041n]]);
            expect(attention_mask.tolist()).toEqual([[1n, 0n], [1n, 1n], [1n, 1n]]);
            expect(token_type_ids.tolist()).toEqual([[0n, 0n], [0n, 0n], [0n, 0n]]);
        }
    }, MAX_TEST_EXECUTION_TIME);
});

describe('Token type ids', () => {
    it('should correctly add token type ids', async () => {
        const tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased');

        const model_inputs = tokenizer(
            ['a b c', 'd'],
            {
                text_pair: ['e f', 'g h'],
                padding: true,
                truncation: true,
                return_tensor: false,
            }
        );
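
        // The first sequence (including [CLS] and its [SEP]) gets token_type_ids
        // of 0; the pair (up to and including its [SEP]) gets 1; padding gets 0.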
        const expected = {
            input_ids: [
                [101, 1037, 1038, 1039, 102, 1041, 1042, 102],
                [101, 1040, 102, 1043, 1044, 102, 0, 0],
            ],
            token_type_ids: [
                [0, 0, 0, 0, 0, 1, 1, 1],
                [0, 0, 0, 1, 1, 1, 0, 0],
            ],
            attention_mask: [
                [1, 1, 1, 1, 1, 1, 1, 1],
                [1, 1, 1, 1, 1, 1, 0, 0],
            ],
        };
        compare(model_inputs, expected);
    }, MAX_TEST_EXECUTION_TIME);
});

describe('Edge cases', () => {
    it('should not crash when encoding a very long string', async () => {
        let tokenizer = await AutoTokenizer.from_pretrained('Xenova/t5-small');

        let text = 'Hello world! '.repeat(50000);

        let encoded = tokenizer(text);
        expect(encoded.input_ids.data.length).toBeGreaterThan(100000);
    }, MAX_TEST_EXECUTION_TIME);

    it('should not take too long', async () => {
        let tokenizer = await AutoTokenizer.from_pretrained('Xenova/all-MiniLM-L6-v2');
        let text = 'a'.repeat(50000);
        let token_ids = tokenizer.encode(text);
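        // The 50,000 consecutive 'a's should collapse into a single unknown token,
        // i.e. [CLS], [UNK], [SEP] (ids 101, 100, 102 in the BERT-style vocabulary
        // this model uses).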
        compare(token_ids, [101, 100, 102]);
    }, 5000); // NOTE: 5 seconds
});

describe('Extra decoding tests', () => {
    it('should be able to decode the output of encode', async () => {
        let tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased');

        let text = 'hello world!';

        // Ensure all the following outputs are the same:
        // 1. Tensor of ids: allow decoding of 1D or 2D tensors.
        let encodedTensor = tokenizer(text);
        let decoded1 = tokenizer.decode(encodedTensor.input_ids, { skip_special_tokens: true });
        let decoded2 = tokenizer.batch_decode(encodedTensor.input_ids, { skip_special_tokens: true })[0];
        expect(decoded1).toEqual(text);
        expect(decoded2).toEqual(text);

        // 2. List of ids
        let encodedList = tokenizer(text, { return_tensor: false });
        let decoded3 = tokenizer.decode(encodedList.input_ids, { skip_special_tokens: true });
        let decoded4 = tokenizer.batch_decode([encodedList.input_ids], { skip_special_tokens: true })[0];
        expect(decoded3).toEqual(text);
        expect(decoded4).toEqual(text);
    }, MAX_TEST_EXECUTION_TIME);
});

describe('Chat templates', () => {
    it('should generate a chat template', async () => {
        const tokenizer = await AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1");

        const chat = [
            { "role": "user", "content": "Hello, how are you?" },
            { "role": "assistant", "content": "I'm doing great. How can I help you today?" },
            { "role": "user", "content": "I'd like to show off how chat templating works!" },
        ];

        const text = tokenizer.apply_chat_template(chat, { tokenize: false });

        expect(text).toEqual("<s>[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today?</s> [INST] I'd like to show off how chat templating works! [/INST]");

        const input_ids = tokenizer.apply_chat_template(chat, { tokenize: true, return_tensor: false });
        compare(input_ids, [1, 733, 16289, 28793, 22557, 28725, 910, 460, 368, 28804, 733, 28748, 16289, 28793, 28737, 28742, 28719, 2548, 1598, 28723, 1602, 541, 315, 1316, 368, 3154, 28804, 2, 28705, 733, 16289, 28793, 315, 28742, 28715, 737, 298, 1347, 805, 910, 10706, 5752, 1077, 3791, 28808, 733, 28748, 16289, 28793]);
    });

    it('should support user-defined chat template', async () => {
        const tokenizer = await AutoTokenizer.from_pretrained("Xenova/llama-tokenizer");

        const chat = [
            { role: 'user', content: 'Hello, how are you?' },
            { role: 'assistant', content: "I'm doing great. How can I help you today?" },
            { role: 'user', content: "I'd like to show off how chat templating works!" },
        ];

        // https://discuss.huggingface.co/t/issue-with-llama-2-chat-template-and-out-of-date-documentation/61645/3
        const chat_template = (
            "{% if messages[0]['role'] == 'system' %}" +
            "{% set loop_messages = messages[1:] %}" + // Extract system message if it's present
            "{% set system_message = messages[0]['content'] %}" +
            "{% elif USE_DEFAULT_PROMPT == true and not '<<SYS>>' in messages[0]['content'] %}" +
            "{% set loop_messages = messages %}" + // Or use the default system message if the flag is set
            "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}" +
            "{% else %}" +
            "{% set loop_messages = messages %}" +
            "{% set system_message = false %}" +
            "{% endif %}" +
            "{% if loop_messages|length == 0 and system_message %}" + // Special handling when only sys message present
            "{{ bos_token + '[INST] <<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n [/INST]' }}" +
            "{% endif %}" +
            "{% for message in loop_messages %}" + // Loop over all non-system messages
            "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" +
            "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}" +
            "{% endif %}" +
            "{% if loop.index0 == 0 and system_message != false %}" + // Embed system message in first message
            "{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}" +
            "{% else %}" +
            "{% set content = message['content'] %}" +
            "{% endif %}" +
            "{% if message['role'] == 'user' %}" + // After all of that, handle messages/roles in a fairly normal way
            "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}" +
            "{% elif message['role'] == 'system' %}" +
            "{{ '<<SYS>>\\n' + content.strip() + '\\n<</SYS>>\\n\\n' }}" +
            "{% elif message['role'] == 'assistant' %}" +
            "{{ ' ' + content.strip() + ' ' + eos_token }}" +
            "{% endif %}" +
            "{% endfor %}"
        )
            .replaceAll('USE_DEFAULT_PROMPT', true)
            .replaceAll('DEFAULT_SYSTEM_MESSAGE', 'You are a helpful, respectful and honest assistant.');
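
        // NOTE: `replaceAll` stringifies its second argument, so the boolean `true`
        // becomes the literal text 'true' inside the template.
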
        const text = await tokenizer.apply_chat_template(chat, { tokenize: false, return_tensor: false, chat_template });
        expect(text).toEqual("<s>[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant.\n<</SYS>>\n\nHello, how are you? [/INST] I'm doing great. How can I help you today? </s><s>[INST] I'd like to show off how chat templating works! [/INST]");

        // TODO: Add test for token_ids once bug in transformers is fixed.
    });

    // Dynamically-generated tests
    for (const [tokenizerName, tests] of Object.entries(templates)) {
        it(tokenizerName, async () => {
            // NOTE: not m(...) here
            // TODO: update this?
            const tokenizer = await AutoTokenizer.from_pretrained(tokenizerName);

            for (let { messages, add_generation_prompt, tokenize, target } of tests) {
                const generated = await tokenizer.apply_chat_template(messages, {
                    tokenize,
                    add_generation_prompt,
                    return_tensor: false,
                });
                expect(generated).toEqual(target);
            }
        });
    }
});