diff --git a/src/tokenizers.js b/src/tokenizers.js
index 8cf1325..234eef1 100644
--- a/src/tokenizers.js
+++ b/src/tokenizers.js
@@ -1592,6 +1592,8 @@ class PostProcessor extends Callable {
             case 'BertProcessing':
                 return new BertProcessing(config);
+            case 'Sequence':
+                return new PostProcessorSequence(config);
             default:
                 throw new Error(`Unknown PostProcessor type: ${config.type}`);
         }
     }
@@ -1738,6 +1740,50 @@ class ByteLevelPostProcessor extends PostProcessor {
     }
 }
 
+
+/**
+ * A post-processor that applies multiple post-processors in sequence.
+ */
+class PostProcessorSequence extends PostProcessor {
+
+    /**
+     * Creates a new instance of PostProcessorSequence.
+     * @param {Object} config The configuration object.
+     * @param {Object[]} config.processors The list of post-processors to apply.
+     */
+    constructor(config) {
+        super(config);
+
+        this.processors = config.processors.map(x => PostProcessor.fromConfig(x));
+    }
+
+    /**
+     * Post process the given tokens.
+     * @param {string[]} tokens The list of tokens for the first sequence.
+     * @param {string[]} [tokens_pair=null] The list of tokens for the second sequence (optional).
+     * @returns {PostProcessedOutput} An object containing the post-processed tokens.
+     */
+    post_process(tokens, tokens_pair = null, options = {}) {
+        let token_type_ids;
+        for (const processor of this.processors) {
+            if (processor instanceof ByteLevelPostProcessor) {
+                // Special case: ByteLevel post-processors do not accept a pair, so each sequence is processed separately
+                const output = processor.post_process(tokens);
+                tokens = output.tokens;
+                if (tokens_pair) {
+                    const pair_output = processor.post_process(tokens_pair);
+                    tokens_pair = pair_output.tokens;
+                }
+            } else {
+                const output = processor.post_process(tokens, tokens_pair, options);
+                tokens = output.tokens;
+                token_type_ids = output.token_type_ids;
+            }
+        }
+        return { tokens, token_type_ids };
+    }
+}
+
 /**
  * The base class for token decoders.
  * @extends Callable
@@ -2100,7 +2146,7 @@ class DecoderSequence extends Decoder {
     /**
      * Creates a new instance of DecoderSequence.
      * @param {Object} config The configuration object.
-     * @param {Decoder[]} config.decoders The list of decoders to apply.
+     * @param {Object[]} config.decoders The list of decoders to apply.
      */
     constructor(config) {
         super(config);
@@ -2623,6 +2669,7 @@ export class PreTrainedTokenizer extends Callable {
      * @param {boolean} [options.truncation=null] Whether to truncate the input sequences.
      * @param {number} [options.max_length=null] Maximum length of the returned list and optionally padding length.
      * @param {boolean} [options.return_tensor=true] Whether to return the results as Tensors or arrays.
+     * @param {boolean} [options.return_token_type_ids=null] Whether to return the token type ids.
      * @returns {BatchEncoding} Object to be passed to the model.
      */
     _call(
@@ -2637,6 +2684,7 @@ export class PreTrainedTokenizer extends Callable {
             truncation = null,
             max_length = null,
             return_tensor = true, // Different to HF
+            return_token_type_ids = null,
         } = {},
     ) {
 
@@ -2659,11 +2707,11 @@ export class PreTrainedTokenizer extends Callable {
                 }
 
                 encodedTokens = text.map(
-                    (t, i) => this._encode_plus(t, text_pair[i], { add_special_tokens })
+                    (t, i) => this._encode_plus(t, text_pair[i], { add_special_tokens, return_token_type_ids })
                 )
 
             } else {
-                encodedTokens = text.map(x => this._encode_plus(x, null, { add_special_tokens }));
+                encodedTokens = text.map(x => this._encode_plus(x, null, { add_special_tokens, return_token_type_ids }));
             }
 
         } else {
@@ -2676,7 +2724,7 @@ export class PreTrainedTokenizer extends Callable {
             }
 
             // For single input, we just wrap in an array, and then unwrap later.
-            encodedTokens = [this._encode_plus(text, text_pair, { add_special_tokens })];
+            encodedTokens = [this._encode_plus(text, text_pair, { add_special_tokens, return_token_type_ids })];
         }
         // At this point, tokens is batched: [batch_size, tokens]
         // However, array may be jagged. So, we pad to max_length
@@ -2834,11 +2882,13 @@ export class PreTrainedTokenizer extends Callable {
      * @param {string|null} text_pair The optional second text to encode.
      * @param {Object} options An optional object containing the following properties:
      * @param {boolean} [options.add_special_tokens=true] Whether or not to add the special tokens associated with the corresponding model.
+     * @param {boolean} [options.return_token_type_ids=null] Whether to return token_type_ids.
      * @returns {EncodingSingle} An object containing the encoded text.
      * @private
      */
     _encode_plus(text, text_pair = null, {
         add_special_tokens = true,
+        return_token_type_ids = null,
     } = {}) {
         // Function called by users to encode possibly multiple texts
         const tokens = this._encode_text(text);
@@ -2854,7 +2904,7 @@
             input_ids,
             attention_mask: new Array(input_ids.length).fill(1),
         }
-        if (this.return_token_type_ids && combinedTokens.token_type_ids) {
+        if ((return_token_type_ids ?? this.return_token_type_ids) && combinedTokens.token_type_ids) {
             result.token_type_ids = combinedTokens.token_type_ids;
         }
         return result;
     }
@@ -2867,13 +2917,16 @@ export class PreTrainedTokenizer extends Callable {
      * @param {string|null} text_pair The optional second text to encode.
      * @param {Object} options An optional object containing the following properties:
      * @param {boolean} [options.add_special_tokens=true] Whether or not to add the special tokens associated with the corresponding model.
+     * @param {boolean} [options.return_token_type_ids=null] Whether to return token_type_ids.
      * @returns {number[]} An array of token IDs representing the encoded text(s).
      */
     encode(text, text_pair = null, {
         add_special_tokens = true,
+        return_token_type_ids = null,
     } = {}) {
         const { input_ids } = this._encode_plus(text, text_pair, {
             add_special_tokens,
+            return_token_type_ids,
         });
         return input_ids;
     }
diff --git a/tests/generate_tests.py b/tests/generate_tests.py
index b363879..d529160 100644
--- a/tests/generate_tests.py
+++ b/tests/generate_tests.py
@@ -20,6 +20,9 @@ ADDITIONAL_TOKENIZERS_TO_TEST = {
         'Xenova/llama2-tokenizer', # Special tokens: normalized=false
         'Xenova/llama2-chat-tokenizer', # Special tokens: normalized=false
         'hf-internal-testing/llama-code-tokenizer',
+
+        # TODO: add back when llama tests are fixed
+        # 'Xenova/llama3-tokenizer-new', # PostProcessor type: Sequence
     ],
     'mpt': [
         'mosaicml/mpt-7b',
@@ -289,7 +292,7 @@ def generate_tokenizer_tests():
         # Load tokenizer
         if model_type == 'llama':
             # As of 17/12/2023, there are a few issues with the Llama tokenizers in transformers.
-            # (1) Encoding with fast tokenizer adds whitespace after speical tokens:
+            # (1) Encoding with fast tokenizer adds whitespace after special tokens:
             #   - https://github.com/huggingface/transformers/issues/25881
             #   - https://github.com/huggingface/transformers/issues/26318
             #   - https://github.com/huggingface/transformers/issues/26455
diff --git a/tests/tokenizers.test.js b/tests/tokenizers.test.js
index cfbe7c6..3b0cfe7 100644
--- a/tests/tokenizers.test.js
+++ b/tests/tokenizers.test.js
@@ -288,6 +288,44 @@ describe('Token type ids', () => {
         compare(model_inputs, expected);
 
     }, MAX_TEST_EXECUTION_TIME);
+
+    it('should add token type ids if user requests them', async () => {
+        const tokenizer = await AutoTokenizer.from_pretrained('Xenova/llama3-tokenizer-new');
+
+        { // Without text pair
+            const model_inputs = tokenizer(
+                'hello',
+                {
+                    return_tensor: false,
+                    return_token_type_ids: true,
+                }
+            );
+            const expected = {
+                input_ids: [128000, 15339],
+                attention_mask: [1, 1],
+                token_type_ids: [0, 0]
+            }
+            compare(model_inputs, expected);
+        }
+
+        { // With text pair
+            const model_inputs = tokenizer(
+                'hello',
+                {
+                    text_pair: 'world',
+                    return_tensor: false,
+                    return_token_type_ids: true,
+                }
+            );
+            const expected = {
+                input_ids: [128000, 15339, 128000, 14957],
+                attention_mask: [1, 1, 1, 1],
+                token_type_ids: [0, 0, 1, 1]
+            }
+            compare(model_inputs, expected);
+        }
+
+    }, MAX_TEST_EXECUTION_TIME);
 });
 
 describe('Edge cases', () => {
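
Example of the behavior this diff enables (a minimal sketch, not part of the change itself; it assumes the `@xenova/transformers` package entry point and reuses the `Xenova/llama3-tokenizer-new` checkpoint and expected values from the new test above):

```js
import { AutoTokenizer } from '@xenova/transformers';

// Llama 3's tokenizer.json declares a post_processor of type "Sequence",
// which previously hit `Unknown PostProcessor type: Sequence` in PostProcessor.fromConfig.
const tokenizer = await AutoTokenizer.from_pretrained('Xenova/llama3-tokenizer-new');

// `return_token_type_ids: true` overrides the tokenizer-level default via the
// `return_token_type_ids ?? this.return_token_type_ids` fallback in `_encode_plus`.
const { input_ids, attention_mask, token_type_ids } = tokenizer('hello', {
    text_pair: 'world',
    return_tensor: false, // plain arrays instead of Tensors
    return_token_type_ids: true,
});
// input_ids      -> [128000, 15339, 128000, 14957]
// attention_mask -> [1, 1, 1, 1]
// token_type_ids -> [0, 0, 1, 1]
```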