* Add `M2M100Tokenizer`
* Allow `added_tokens` list to be empty
* Apply hot-fix for issue in HF's `M2M100Tokenizer`
* Skip M2M100 tokenizer tests for now (TODO: remove when https://github.com/huggingface/transformers/pull/25478 is merged)
* Fix `_build_translation_inputs` for `M2M100Tokenizer`
* Add example code in JSDoc for `TranslationPipeline`
* Update supported_models.py
This commit is contained in:
parent
cc4b857d54
commit
060ac830fc
|
@ -140,6 +140,7 @@ SUPPORTED_MODELS = {
|
|||
],
|
||||
'm2m_100': [
|
||||
'facebook/nllb-200-distilled-600M',
|
||||
'facebook/m2m100_418M',
|
||||
],
|
||||
# TODO:
|
||||
# 'marian': [
|
||||
|
|
|
@ -452,8 +452,36 @@ export class SummarizationPipeline extends Text2TextGenerationPipeline {
|
|||
}
|
||||
|
||||
/**
|
||||
* TranslationPipeline class to translate text from one language to another using the provided model and tokenizer.
|
||||
* @extends Text2TextGenerationPipeline
|
||||
* Translates text from one language to another.
|
||||
*
|
||||
* **Example:** Multilingual translation w/ `Xenova/nllb-200-distilled-600M`.
|
||||
*
|
||||
* See [here](https://github.com/facebookresearch/flores/blob/main/flores200/README.md#languages-in-flores-200)
|
||||
* for the full list of languages and their corresponding codes.
|
||||
*
|
||||
* ```javascript
|
||||
* let translator = await pipeline('translation', 'Xenova/nllb-200-distilled-600M');
|
||||
* let output = await translator('जीवन एक चॉकलेट बॉक्स की तरह है।', {
|
||||
* src_lang: 'hin_Deva', // Hindi
|
||||
* tgt_lang: 'fra_Latn', // French
|
||||
* });
|
||||
* // [ { translation_text: 'La vie est comme une boîte à chocolat.' } ]
|
||||
* ```
|
||||
*
|
||||
* **Example:** Multilingual translation w/ `Xenova/m2m100_418M`.
|
||||
*
|
||||
* See [here](https://huggingface.co/facebook/m2m100_418M#languages-covered)
|
||||
* for the full list of languages and their corresponding codes.
|
||||
*
|
||||
* ```javascript
|
||||
* let translator = await pipeline('translation', 'Xenova/m2m100_418M');
|
||||
* let output = await translator('生活就像一盒巧克力。', {
|
||||
* src_lang: 'zh', // Chinese
|
||||
* tgt_lang: 'en', // English
|
||||
* });
|
||||
* // [ { translation_text: 'Life is like a box of chocolate.' } ]
|
||||
* ```
|
||||
*
|
||||
*/
|
||||
export class TranslationPipeline extends Text2TextGenerationPipeline {
|
||||
_key = 'translation_text';
|
||||
|
|
|
@ -1995,12 +1995,16 @@ export class PreTrainedTokenizer extends Callable {
|
|||
}
|
||||
}
|
||||
|
||||
// Update additional_special_tokens
|
||||
this.special_tokens.push(...(tokenizerConfig.additional_special_tokens ?? []));
|
||||
this.special_tokens = [...new Set(this.special_tokens)]; // Remove duplicates
|
||||
|
||||
// Slight hack, but it prevents code duplication:
|
||||
this.decoder.added_tokens = this.added_tokens;
|
||||
|
||||
this.added_tokens_regex = new RegExp(
|
||||
this.added_tokens_regex = this.added_tokens.length > 0 ? new RegExp(
|
||||
'(' + this.added_tokens.map(escapeRegExp).join('|') + ')'
|
||||
);
|
||||
) : null;
|
||||
|
||||
// Set mask token if present (otherwise will be undefined, which is fine)
|
||||
this.mask_token = this.getToken(tokenizerConfig, 'mask_token');
|
||||
|
@ -2265,8 +2269,7 @@ export class PreTrainedTokenizer extends Callable {
|
|||
// Actual function which does encoding, for a single text
|
||||
// First, we take care of special tokens. Needed to avoid issues arising from
|
||||
// normalization and/or pretokenization (which may not preserve special tokens)
|
||||
const sections = text.split(this.added_tokens_regex).filter(x => x);
|
||||
|
||||
const sections = this.added_tokens_regex ? text.split(this.added_tokens_regex).filter(x => x) : [text];
|
||||
let tokens = sections.map(x => {
|
||||
if (this.added_tokens.includes(x)) {
|
||||
// Ignore added tokens
|
||||
|
@ -2482,6 +2485,58 @@ export class FalconTokenizer extends PreTrainedTokenizer {
|
|||
|
||||
export class GPTNeoXTokenizer extends PreTrainedTokenizer { }
|
||||
|
||||
|
||||
/**
 * Helper function to build translation inputs for an `NllbTokenizer` or `M2M100Tokenizer`.
 *
 * The tokenizer passed as `self` must expose:
 *  - `language_codes`: array of the language codes callers may pass as `src_lang` / `tgt_lang`,
 *  - `languageRegex`: regular expression that recognises a language special token,
 *  - `lang_to_token`: function mapping a language code to its special-token string.
 *
 * Side effects:
 *  - `generate_kwargs.forced_bos_token_id` is overwritten with the id of the target-language token.
 *  - When `generate_kwargs.src_lang` is given, the first language token found in
 *    `self.post_processor.config.single` is rewritten in place, so the change stays
 *    on the tokenizer after this call returns.
 *
 * @param {PreTrainedTokenizer} self The tokenizer instance.
 * @param {string|string[]} raw_inputs The text to tokenize.
 * @param {Object} tokenizer_options Options to be sent to the tokenizer
 * @param {Object} generate_kwargs Generation options. Reads `src_lang` (optional) and `tgt_lang` (required).
 * @returns {Object} Object to be passed to the model.
 * @throws {Error} If the tokenizer lacks one of the required attributes, or if a language code is not in `language_codes`.
 * @private
 */
function _build_translation_inputs(self, raw_inputs, tokenizer_options, generate_kwargs) {
    // Duck-type check: any tokenizer providing these three attributes can use this helper.
    if (!Array.isArray(self.language_codes)) {
        throw new Error('Tokenizer must have `language_codes` attribute set and it should be an array of language ids.');
    }
    if (!(self.languageRegex instanceof RegExp)) {
        throw new Error('Tokenizer must have `languageRegex` attribute set and it should be a regular expression.');
    }
    if (typeof self.lang_to_token !== 'function') {
        throw new Error('Tokenizer must have `lang_to_token` attribute set and it should be a function.');
    }

    const src_lang_token = generate_kwargs.src_lang;
    const tgt_lang_token = generate_kwargs.tgt_lang;

    // Check that the target language is valid:
    if (!self.language_codes.includes(tgt_lang_token)) {
        throw new Error(`Target language code "${tgt_lang_token}" is not valid. Must be one of: {${self.language_codes.join(', ')}}`);
    }

    // Allow `src_lang` to be optional. If not set, the post-processor template is left untouched.
    if (src_lang_token !== undefined) {
        // Check that the source language is valid:
        if (!self.language_codes.includes(src_lang_token)) {
            throw new Error(`Source language code "${src_lang_token}" is not valid. Must be one of: {${self.language_codes.join(', ')}}`);
        }

        // In the same way as the Python library, we override the post-processor
        // to force the source language token. Only the first language token in the
        // single-sequence template is replaced.
        for (const item of self.post_processor.config.single) {
            if ('SpecialToken' in item && self.languageRegex.test(item.SpecialToken.id)) {
                item.SpecialToken.id = self.lang_to_token(src_lang_token);
                break;
            }
        }
        // TODO: Do the same for pair?
    }

    // Override the `forced_bos_token_id` to force the correct language
    generate_kwargs.forced_bos_token_id = self.model.convert_tokens_to_ids([self.lang_to_token(tgt_lang_token)])[0];

    return self._call(raw_inputs, tokenizer_options);
}
|
||||
|
||||
/**
|
||||
* The NllbTokenizer class is used to tokenize text for NLLB ("No Language Left Behind") models.
|
||||
*
|
||||
|
@ -2502,6 +2557,7 @@ export class NllbTokenizer extends PreTrainedTokenizer {
|
|||
|
||||
this.languageRegex = /^[a-z]{3}_[A-Z][a-z]{3}$/;
|
||||
this.language_codes = this.special_tokens.filter(x => this.languageRegex.test(x));
|
||||
this.lang_to_token = x => x; // Identity function
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -2512,34 +2568,40 @@ export class NllbTokenizer extends PreTrainedTokenizer {
|
|||
* @returns {Object} Object to be passed to the model.
|
||||
*/
|
||||
_build_translation_inputs(raw_inputs, tokenizer_options, generate_kwargs) {
|
||||
return _build_translation_inputs(this, raw_inputs, tokenizer_options, generate_kwargs);
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * The M2M100Tokenizer class is used to tokenize text for M2M100 ("Many-to-Many") models.
 *
 * M2M100 is a multilingual encoder-decoder (seq-to-seq) model trained for Many-to-Many
 * multilingual translation. It was introduced in this [paper](https://arxiv.org/abs/2010.11125)
 * and first released in [this](https://github.com/pytorch/fairseq/tree/master/examples/m2m_100) repository.
 *
 * For a list of supported languages (along with their language codes),
 * @see {@link https://huggingface.co/facebook/m2m100_418M#languages-covered}
 */
export class M2M100Tokenizer extends PreTrainedTokenizer {
    constructor(tokenizerJSON, tokenizerConfig) {
        super(tokenizerJSON, tokenizerConfig);

        // Language special tokens are a 2-3 letter lowercase code wrapped in double underscores.
        this.languageRegex = /^__[a-z]{2,3}__$/;

        // Expose the bare codes (wrapper stripped) so callers pass e.g. `en`, not `__en__`.
        this.language_codes = this.special_tokens
            .filter(x => this.languageRegex.test(x))
            .map(x => x.slice(2, -2));

        // Inverse of the stripping above: bare code -> special-token string.
        this.lang_to_token = x => `__${x}__`;
    }

    /**
     * Helper function to build translation inputs for an `M2M100Tokenizer`.
     * Delegates to the module-level `_build_translation_inputs`, passing this tokenizer as `self`.
     * @param {string|string[]} raw_inputs The text to tokenize.
     * @param {Object} tokenizer_options Options to be sent to the tokenizer
     * @param {Object} generate_kwargs Generation options.
     * @returns {Object} Object to be passed to the model.
     */
    _build_translation_inputs(raw_inputs, tokenizer_options, generate_kwargs) {
        return _build_translation_inputs(this, raw_inputs, tokenizer_options, generate_kwargs);
    }
}
|
||||
|
||||
|
@ -3485,6 +3547,7 @@ export class AutoTokenizer {
|
|||
'MarianTokenizer': MarianTokenizer,
|
||||
'BloomTokenizer': BloomTokenizer,
|
||||
'NllbTokenizer': NllbTokenizer,
|
||||
'M2M100Tokenizer': M2M100Tokenizer,
|
||||
'LlamaTokenizer': LlamaTokenizer,
|
||||
'XLMRobertaTokenizer': XLMRobertaTokenizer,
|
||||
'MPNetTokenizer': MPNetTokenizer,
|
||||
|
|
|
@ -21,6 +21,11 @@ ADDITIONAL_TOKENIZERS_TO_TEST = {
|
|||
],
|
||||
}
|
||||
|
||||
TOKENIZERS_TO_IGNORE = [
|
||||
# TODO: remove when https://github.com/huggingface/transformers/pull/25478 is merged
|
||||
'facebook/m2m100_418M',
|
||||
]
|
||||
|
||||
TOKENIZER_TEST_DATA = {
|
||||
"shared": [
|
||||
"hello world",
|
||||
|
@ -92,6 +97,9 @@ def generate_tokenizer_tests():
|
|||
for model_type, tokenizer_names in tokenizers_to_test:
|
||||
print(f'Generating tests for {model_type}')
|
||||
for tokenizer_name in tokenizer_names:
|
||||
if tokenizer_name in TOKENIZERS_TO_IGNORE:
|
||||
continue
|
||||
|
||||
print(' -', tokenizer_name)
|
||||
|
||||
try:
|
||||
|
|
Loading…
Reference in New Issue