Add M2M100 tokenizer (Closes #235) (#250)

* Add `M2M100Tokenizer`

* Allow `added_tokens` list to be empty

* Apply hot-fix for issue in HF's `M2M100Tokenizer`

* Skip M2M100 tokenizer tests for now

TODO: Remove when https://github.com/huggingface/transformers/pull/25478 is merged

* Fix `_build_translation_inputs` for `M2M100Tokenizer`

* Add example code in JSDoc for `TranslationPipeline`

* Update supported_models.py
This commit is contained in:
Joshua Lochner 2023-08-14 17:22:20 +02:00 committed by GitHub
parent cc4b857d54
commit 060ac830fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 131 additions and 31 deletions

View File

@ -140,6 +140,7 @@ SUPPORTED_MODELS = {
],
'm2m_100': [
'facebook/nllb-200-distilled-600M',
'facebook/m2m100_418M',
],
# TODO:
# 'marian': [

View File

@ -452,8 +452,36 @@ export class SummarizationPipeline extends Text2TextGenerationPipeline {
}
/**
* TranslationPipeline class to translate text from one language to another using the provided model and tokenizer.
* @extends Text2TextGenerationPipeline
* Translates text from one language to another.
*
* **Example:** Multilingual translation w/ `Xenova/nllb-200-distilled-600M`.
*
* See [here](https://github.com/facebookresearch/flores/blob/main/flores200/README.md#languages-in-flores-200)
* for the full list of languages and their corresponding codes.
*
* ```javascript
* let translator = await pipeline('translation', 'Xenova/nllb-200-distilled-600M');
* let output = await translator('जीवन एक चॉकलेट बॉक्स की तरह है।', {
* src_lang: 'hin_Deva', // Hindi
* tgt_lang: 'fra_Latn', // French
* });
* // [ { translation_text: 'La vie est comme une boîte à chocolat.' } ]
* ```
*
* **Example:** Multilingual translation w/ `Xenova/m2m100_418M`.
*
* See [here](https://huggingface.co/facebook/m2m100_418M#languages-covered)
* for the full list of languages and their corresponding codes.
*
* ```javascript
* let translator = await pipeline('translation', 'Xenova/m2m100_418M');
* let output = await translator('生活就像一盒巧克力。', {
* src_lang: 'zh', // Chinese
* tgt_lang: 'en', // English
* });
* // [ { translation_text: 'Life is like a box of chocolate.' } ]
* ```
*
*/
export class TranslationPipeline extends Text2TextGenerationPipeline {
_key = 'translation_text';

View File

@ -1995,12 +1995,16 @@ export class PreTrainedTokenizer extends Callable {
}
}
// Update additional_special_tokens
this.special_tokens.push(...(tokenizerConfig.additional_special_tokens ?? []));
this.special_tokens = [...new Set(this.special_tokens)]; // Remove duplicates
// Slight hack, but it prevents code duplication:
this.decoder.added_tokens = this.added_tokens;
this.added_tokens_regex = new RegExp(
this.added_tokens_regex = this.added_tokens.length > 0 ? new RegExp(
'(' + this.added_tokens.map(escapeRegExp).join('|') + ')'
);
) : null;
// Set mask token if present (otherwise will be undefined, which is fine)
this.mask_token = this.getToken(tokenizerConfig, 'mask_token');
@ -2265,8 +2269,7 @@ export class PreTrainedTokenizer extends Callable {
// Actual function which does encoding, for a single text
// First, we take care of special tokens. Needed to avoid issues arising from
// normalization and/or pretokenization (which may not preserve special tokens)
const sections = text.split(this.added_tokens_regex).filter(x => x);
const sections = this.added_tokens_regex ? text.split(this.added_tokens_regex).filter(x => x) : [text];
let tokens = sections.map(x => {
if (this.added_tokens.includes(x)) {
// Ignore added tokens
@ -2482,6 +2485,58 @@ export class FalconTokenizer extends PreTrainedTokenizer {
/**
 * Tokenizer for GPT-NeoX models. Adds no behavior of its own — it inherits
 * everything from `PreTrainedTokenizer`; the subclass presumably exists so the
 * model type has a concrete class name for registry lookup (confirm against
 * the `AutoTokenizer` mapping elsewhere in this file).
 */
export class GPTNeoXTokenizer extends PreTrainedTokenizer { }
/**
 * Helper function to build translation inputs for an `NllbTokenizer` or `M2M100Tokenizer`.
 *
 * Validates the requested language codes, rewrites the post-processor so the
 * source-language token comes first (mirroring the Python library), and forces
 * the decoder to start with the target-language token.
 *
 * @param {PreTrainedTokenizer} self The tokenizer instance.
 * @param {string|string[]} raw_inputs The text to tokenize.
 * @param {Object} tokenizer_options Options to be sent to the tokenizer
 * @param {Object} generate_kwargs Generation options. Mutated: `forced_bos_token_id` is set.
 * @returns {Object} Object to be passed to the model.
 * @throws {Error} If the tokenizer lacks the required language attributes, or a language code is invalid.
 * @private
 */
function _build_translation_inputs(self, raw_inputs, tokenizer_options, generate_kwargs) {
    // Guard clauses: the tokenizer subclass must have configured the three
    // language-related attributes this helper relies on.
    if (!('language_codes' in self && Array.isArray(self.language_codes))) {
        throw new Error('Tokenizer must have `language_codes` attribute set and it should be an array of language ids.')
    }
    if (!('languageRegex' in self && self.languageRegex instanceof RegExp)) {
        throw new Error('Tokenizer must have `languageRegex` attribute set and it should be a regular expression.')
    }
    if (!('lang_to_token' in self && typeof self.lang_to_token === 'function')) {
        throw new Error('Tokenizer must have `lang_to_token` attribute set and it should be a function.')
    }

    const { src_lang, tgt_lang } = generate_kwargs;

    // The target language is mandatory and must be one of the known codes.
    if (!self.language_codes.includes(tgt_lang)) {
        throw new Error(`Target language code "${tgt_lang}" is not valid. Must be one of: {${self.language_codes.join(', ')}}`);
    }

    // `src_lang` is optional; when omitted, the tokenizer's default is used.
    if (src_lang !== undefined) {
        if (!self.language_codes.includes(src_lang)) {
            throw new Error(`Source language code "${src_lang}" is not valid. Must be one of: {${self.language_codes.join(', ')}}`);
        }

        // In the same way as the Python library, we override the post-processor
        // to force the source language to be first: patch the first special
        // token whose id looks like a language token.
        const langEntry = self.post_processor.config.single.find(
            (item) => 'SpecialToken' in item && self.languageRegex.test(item.SpecialToken.id)
        );
        if (langEntry !== undefined) {
            langEntry.SpecialToken.id = self.lang_to_token(src_lang);
        }

        // TODO: Do the same for pair?
    }

    // Override the `forced_bos_token_id` to force the correct language
    generate_kwargs.forced_bos_token_id = self.model.convert_tokens_to_ids([self.lang_to_token(tgt_lang)])[0];
    return self._call(raw_inputs, tokenizer_options);
}
/**
* The NllbTokenizer class is used to tokenize text for NLLB ("No Language Left Behind") models.
*
@ -2502,6 +2557,7 @@ export class NllbTokenizer extends PreTrainedTokenizer {
this.languageRegex = /^[a-z]{3}_[A-Z][a-z]{3}$/;
this.language_codes = this.special_tokens.filter(x => this.languageRegex.test(x));
this.lang_to_token = x => x; // Identity function
}
/**
@ -2512,34 +2568,40 @@ export class NllbTokenizer extends PreTrainedTokenizer {
* @returns {Object} Object to be passed to the model.
*/
_build_translation_inputs(raw_inputs, tokenizer_options, generate_kwargs) {
// Delegate to the shared NLLB/M2M100 helper, which validates
// `generate_kwargs.src_lang`/`tgt_lang` against this tokenizer's
// `language_codes` and sets `generate_kwargs.forced_bos_token_id`.
return _build_translation_inputs(this, raw_inputs, tokenizer_options, generate_kwargs);
}
}
/**
* The M2M100Tokenizer class is used to tokenize text for M2M100 ("Many-to-Many") models.
*
* M2M100 is a multilingual encoder-decoder (seq-to-seq) model trained for Many-to-Many
* multilingual translation. It was introduced in this [paper](https://arxiv.org/abs/2010.11125)
* and first released in [this](https://github.com/pytorch/fairseq/tree/master/examples/m2m_100) repository.
*
* For a list of supported languages (along with their language codes),
* @see {@link https://huggingface.co/facebook/m2m100_418M#languages-covered}
*/
export class M2M100Tokenizer extends PreTrainedTokenizer {
constructor(tokenizerJSON, tokenizerConfig) {
super(tokenizerJSON, tokenizerConfig);
// Check that the target language is valid:
if (!this.language_codes.includes(generate_kwargs.tgt_lang)) {
throw new Error(`Target language code "${generate_kwargs.tgt_lang}" is not valid. Must be one of: {${this.language_codes.join(', ')}}`);
}
this.languageRegex = /^__[a-z]{2,3}__$/;
this.language_codes = this.special_tokens
.filter(x => this.languageRegex.test(x))
.map(x => x.slice(2, -2));
this.lang_to_token = x => `__${x}__`;
}
// Allow `src_lang` to be optional. If not set, we'll use the tokenizer's default.
if (generate_kwargs.src_lang !== undefined) {
// Check that the source language is valid:
if (!this.language_codes.includes(generate_kwargs.src_lang)) {
throw new Error(`Source language code "${generate_kwargs.src_lang}" is not valid. Must be one of: {${this.language_codes.join(', ')}}`);
}
// In the same way as the Python library, we override the post-processor
// to force the source language to be first:
for (let item of this.post_processor.config.single) {
if ('SpecialToken' in item && this.languageRegex.test(item.SpecialToken.id)) {
item.SpecialToken.id = generate_kwargs.src_lang;
break;
}
}
}
// Override the `forced_bos_token_id` to force the correct language
generate_kwargs.forced_bos_token_id = this.model.convert_tokens_to_ids([generate_kwargs.tgt_lang])[0];
return this._call(raw_inputs, tokenizer_options);
/**
* Helper function to build translation inputs for an `M2M100Tokenizer`.
* @param {string|string[]} raw_inputs The text to tokenize.
* @param {Object} tokenizer_options Options to be sent to the tokenizer
* @param {Object} generate_kwargs Generation options.
* @returns {Object} Object to be passed to the model.
*/
_build_translation_inputs(raw_inputs, tokenizer_options, generate_kwargs) {
return _build_translation_inputs(this, raw_inputs, tokenizer_options, generate_kwargs);
}
}
@ -3485,6 +3547,7 @@ export class AutoTokenizer {
'MarianTokenizer': MarianTokenizer,
'BloomTokenizer': BloomTokenizer,
'NllbTokenizer': NllbTokenizer,
'M2M100Tokenizer': M2M100Tokenizer,
'LlamaTokenizer': LlamaTokenizer,
'XLMRobertaTokenizer': XLMRobertaTokenizer,
'MPNetTokenizer': MPNetTokenizer,

View File

@ -21,6 +21,11 @@ ADDITIONAL_TOKENIZERS_TO_TEST = {
],
}
TOKENIZERS_TO_IGNORE = [
# TODO: remove when https://github.com/huggingface/transformers/pull/25478 is merged
'facebook/m2m100_418M',
]
TOKENIZER_TEST_DATA = {
"shared": [
"hello world",
@ -92,6 +97,9 @@ def generate_tokenizer_tests():
for model_type, tokenizer_names in tokenizers_to_test:
print(f'Generating tests for {model_type}')
for tokenizer_name in tokenizer_names:
if tokenizer_name in TOKENIZERS_TO_IGNORE:
continue
print(' -', tokenizer_name)
try: