diff --git a/scripts/supported_models.py b/scripts/supported_models.py
index 8690f5a..aeb9cc6 100644
--- a/scripts/supported_models.py
+++ b/scripts/supported_models.py
@@ -140,6 +140,7 @@ SUPPORTED_MODELS = {
     ],
     'm2m_100': [
         'facebook/nllb-200-distilled-600M',
+        'facebook/m2m100_418M',
     ],
     # TODO:
     # 'marian': [
diff --git a/src/pipelines.js b/src/pipelines.js
index f8546c5..7e22d54 100644
--- a/src/pipelines.js
+++ b/src/pipelines.js
@@ -452,8 +452,36 @@
 }
 
 /**
- * TranslationPipeline class to translate text from one language to another using the provided model and tokenizer.
- * @extends Text2TextGenerationPipeline
+ * Translates text from one language to another.
+ *
+ * **Example:** Multilingual translation w/ `Xenova/nllb-200-distilled-600M`.
+ *
+ * See [here](https://github.com/facebookresearch/flores/blob/main/flores200/README.md#languages-in-flores-200)
+ * for the full list of languages and their corresponding codes.
+ *
+ * ```javascript
+ * let translator = await pipeline('translation', 'Xenova/nllb-200-distilled-600M');
+ * let output = await translator('जीवन एक चॉकलेट बॉक्स की तरह है।', {
+ *   src_lang: 'hin_Deva', // Hindi
+ *   tgt_lang: 'fra_Latn', // French
+ * });
+ * // [ { translation_text: 'La vie est comme une boîte à chocolat.' } ]
+ * ```
+ *
+ * **Example:** Multilingual translation w/ `Xenova/m2m100_418M`.
+ *
+ * See [here](https://huggingface.co/facebook/m2m100_418M#languages-covered)
+ * for the full list of languages and their corresponding codes.
+ *
+ * ```javascript
+ * let translator = await pipeline('translation', 'Xenova/m2m100_418M');
+ * let output = await translator('生活就像一盒巧克力。', {
+ *   src_lang: 'zh', // Chinese
+ *   tgt_lang: 'en', // English
+ * });
+ * // [ { translation_text: 'Life is like a box of chocolate.' } ]
+ * ```
+ *
  */
 export class TranslationPipeline extends Text2TextGenerationPipeline {
     _key = 'translation_text';
diff --git a/src/tokenizers.js b/src/tokenizers.js
index 8f71ad3..f9322c4 100644
--- a/src/tokenizers.js
+++ b/src/tokenizers.js
@@ -1995,12 +1995,16 @@
             }
         }
 
+        // Update additional_special_tokens
+        this.special_tokens.push(...(tokenizerConfig.additional_special_tokens ?? []));
+        this.special_tokens = [...new Set(this.special_tokens)]; // Remove duplicates
+
         // Slight hack, but it prevents code duplication:
         this.decoder.added_tokens = this.added_tokens;
 
-        this.added_tokens_regex = new RegExp(
+        this.added_tokens_regex = this.added_tokens.length > 0 ? new RegExp(
             '(' + this.added_tokens.map(escapeRegExp).join('|') + ')'
-        );
+        ) : null;
 
         // Set mask token if present (otherwise will be undefined, which is fine)
         this.mask_token = this.getToken(tokenizerConfig, 'mask_token');
@@ -2265,8 +2269,7 @@
         // Actual function which does encoding, for a single text
         // First, we take care of special tokens. Needed to avoid issues arising from
         // normalization and/or pretokenization (which may not preserve special tokens)
-        const sections = text.split(this.added_tokens_regex).filter(x => x);
-
+        const sections = this.added_tokens_regex ? text.split(this.added_tokens_regex).filter(x => x) : [text];
         let tokens = sections.map(x => {
             if (this.added_tokens.includes(x)) {
                 // Ignore added tokens
@@ -2482,6 +2485,58 @@ export class FalconTokenizer extends PreTrainedTokenizer {
 }
 
 export class GPTNeoXTokenizer extends PreTrainedTokenizer { }
+
+/**
+ * Helper function to build translation inputs for an `NllbTokenizer` or `M2M100Tokenizer`.
+ * @param {PreTrainedTokenizer} self The tokenizer instance.
+ * @param {string|string[]} raw_inputs The text to tokenize.
+ * @param {Object} tokenizer_options Options to be sent to the tokenizer
+ * @param {Object} generate_kwargs Generation options.
+ * @returns {Object} Object to be passed to the model.
+ * @private
+ */
+function _build_translation_inputs(self, raw_inputs, tokenizer_options, generate_kwargs) {
+    if (!('language_codes' in self) || !Array.isArray(self.language_codes)) {
+        throw new Error('Tokenizer must have `language_codes` attribute set and it should be an array of language ids.')
+    }
+    if (!('languageRegex' in self) || !(self.languageRegex instanceof RegExp)) {
+        throw new Error('Tokenizer must have `languageRegex` attribute set and it should be a regular expression.')
+    }
+    if (!('lang_to_token' in self) || typeof self.lang_to_token !== 'function') {
+        throw new Error('Tokenizer must have `lang_to_token` attribute set and it should be a function.')
+    }
+    const src_lang_token = generate_kwargs.src_lang;
+    const tgt_lang_token = generate_kwargs.tgt_lang;
+
+    // Check that the target language is valid:
+    if (!self.language_codes.includes(tgt_lang_token)) {
+        throw new Error(`Target language code "${tgt_lang_token}" is not valid. Must be one of: {${self.language_codes.join(', ')}}`);
+    }
+
+    // Allow `src_lang` to be optional. If not set, we'll use the tokenizer's default.
+    if (src_lang_token !== undefined) {
+        // Check that the source language is valid:
+        if (!self.language_codes.includes(src_lang_token)) {
+            throw new Error(`Source language code "${src_lang_token}" is not valid. Must be one of: {${self.language_codes.join(', ')}}`);
+        }
+
+        // In the same way as the Python library, we override the post-processor
+        // to force the source language to be first:
+        for (let item of self.post_processor.config.single) {
+            if ('SpecialToken' in item && self.languageRegex.test(item.SpecialToken.id)) {
+                item.SpecialToken.id = self.lang_to_token(src_lang_token);
+                break;
+            }
+        }
+        // TODO: Do the same for pair?
+    }
+
+    // Override the `forced_bos_token_id` to force the correct language
+    generate_kwargs.forced_bos_token_id = self.model.convert_tokens_to_ids([self.lang_to_token(tgt_lang_token)])[0];
+
+    return self._call(raw_inputs, tokenizer_options);
+}
+
 /**
  * The NllbTokenizer class is used to tokenize text for NLLB ("No Language Left Behind") models.
  *
@@ -2502,6 +2557,7 @@ export class NllbTokenizer extends PreTrainedTokenizer {
 
         this.languageRegex = /^[a-z]{3}_[A-Z][a-z]{3}$/;
         this.language_codes = this.special_tokens.filter(x => this.languageRegex.test(x));
+        this.lang_to_token = x => x; // Identity function
     }
 
     /**
@@ -2512,34 +2568,40 @@
      * @param {string|string[]} raw_inputs The text to tokenize.
      * @param {Object} tokenizer_options Options to be sent to the tokenizer
      * @param {Object} generate_kwargs Generation options.
      * @returns {Object} Object to be passed to the model.
      */
     _build_translation_inputs(raw_inputs, tokenizer_options, generate_kwargs) {
+        return _build_translation_inputs(this, raw_inputs, tokenizer_options, generate_kwargs);
+    }
+}
+/**
+ * The M2M100Tokenizer class is used to tokenize text for M2M100 ("Many-to-Many") models.
+ *
+ * M2M100 is a multilingual encoder-decoder (seq-to-seq) model trained for Many-to-Many
+ * multilingual translation. It was introduced in this [paper](https://arxiv.org/abs/2010.11125)
+ * and first released in [this](https://github.com/pytorch/fairseq/tree/master/examples/m2m_100) repository.
+ *
+ * For a list of supported languages (along with their language codes),
+ * @see {@link https://huggingface.co/facebook/m2m100_418M#languages-covered}
+ */
+export class M2M100Tokenizer extends PreTrainedTokenizer {
+    constructor(tokenizerJSON, tokenizerConfig) {
+        super(tokenizerJSON, tokenizerConfig);
-        // Check that the target language is valid:
-        if (!this.language_codes.includes(generate_kwargs.tgt_lang)) {
-            throw new Error(`Target language code "${generate_kwargs.tgt_lang}" is not valid. Must be one of: {${this.language_codes.join(', ')}}`);
-        }
+        this.languageRegex = /^__[a-z]{2,3}__$/;
+        this.language_codes = this.special_tokens
+            .filter(x => this.languageRegex.test(x))
+            .map(x => x.slice(2, -2));
+        this.lang_to_token = x => `__${x}__`;
+    }
 
-        // Allow `src_lang` to be optional. If not set, we'll use the tokenizer's default.
-        if (generate_kwargs.src_lang !== undefined) {
-            // Check that the source language is valid:
-            if (!this.language_codes.includes(generate_kwargs.src_lang)) {
-                throw new Error(`Source language code "${generate_kwargs.src_lang}" is not valid. Must be one of: {${this.language_codes.join(', ')}}`);
-            }
-
-            // In the same way as the Python library, we override the post-processor
-            // to force the source language to be first:
-            for (let item of this.post_processor.config.single) {
-                if ('SpecialToken' in item && this.languageRegex.test(item.SpecialToken.id)) {
-                    item.SpecialToken.id = generate_kwargs.src_lang;
-                    break;
-                }
-            }
-        }
-
-        // Override the `forced_bos_token_id` to force the correct language
-        generate_kwargs.forced_bos_token_id = this.model.convert_tokens_to_ids([generate_kwargs.tgt_lang])[0];
-
-        return this._call(raw_inputs, tokenizer_options);
+    /**
+     * Helper function to build translation inputs for an `M2M100Tokenizer`.
+     * @param {string|string[]} raw_inputs The text to tokenize.
+     * @param {Object} tokenizer_options Options to be sent to the tokenizer
+     * @param {Object} generate_kwargs Generation options.
+     * @returns {Object} Object to be passed to the model.
+     */
+    _build_translation_inputs(raw_inputs, tokenizer_options, generate_kwargs) {
+        return _build_translation_inputs(this, raw_inputs, tokenizer_options, generate_kwargs);
     }
 }
@@ -3485,6 +3547,7 @@ export class AutoTokenizer {
         'MarianTokenizer': MarianTokenizer,
         'BloomTokenizer': BloomTokenizer,
         'NllbTokenizer': NllbTokenizer,
+        'M2M100Tokenizer': M2M100Tokenizer,
         'LlamaTokenizer': LlamaTokenizer,
         'XLMRobertaTokenizer': XLMRobertaTokenizer,
         'MPNetTokenizer': MPNetTokenizer,
diff --git a/tests/generate_tests.py b/tests/generate_tests.py
index 819681f..268b81a 100644
--- a/tests/generate_tests.py
+++ b/tests/generate_tests.py
@@ -21,6 +21,11 @@ ADDITIONAL_TOKENIZERS_TO_TEST = {
     ],
 }
 
+TOKENIZERS_TO_IGNORE = [
+    # TODO: remove when https://github.com/huggingface/transformers/pull/25478 is merged
+    'facebook/m2m100_418M',
+]
+
 TOKENIZER_TEST_DATA = {
     "shared": [
         "hello world",
@@ -92,6 +97,9 @@
     for model_type, tokenizer_names in tokenizers_to_test:
         print(f'Generating tests for {model_type}')
         for tokenizer_name in tokenizer_names:
+            if tokenizer_name in TOKENIZERS_TO_IGNORE:
+                continue
+
             print(' -', tokenizer_name)
 
             try: