From 436ce072183e3e134d2fbc286f6c72f012f31e74 Mon Sep 17 00:00:00 2001
From: Shijie Wu
Date: Fri, 23 Aug 2019 14:40:17 -0400
Subject: [PATCH] Tokenization behaves the same as the original XLM
 preprocessing for most languages except zh, ja and th; change the API to
 allow specifying the language in `tokenize`

---
 pytorch_transformers/tokenization_xlm.py | 148 ++++++++++++++++++++---
 requirements.txt                         |   4 +-
 setup.py                                 |   3 +-
 3 files changed, 135 insertions(+), 20 deletions(-)

diff --git a/pytorch_transformers/tokenization_xlm.py b/pytorch_transformers/tokenization_xlm.py
index 2d2f3a8cd4..8418a5d6f3 100644
--- a/pytorch_transformers/tokenization_xlm.py
+++ b/pytorch_transformers/tokenization_xlm.py
@@ -20,8 +20,11 @@ import json
 import logging
 import os
 import re
+import unicodedata
 from io import open
 
+import sacremoses as sm
+
 from .tokenization_utils import PreTrainedTokenizer
 from .tokenization_bert import BasicTokenizer
 
@@ -95,6 +98,93 @@ def text_standardize(text):
     text = re.sub(r'[^\S\n]+', ' ', text)
     return text.strip()
 
+
+def lowercase_and_remove_accent(text):
+    """
+    Lowercases and strips accents from a piece of text, based on
+    https://github.com/facebookresearch/XLM/blob/master/tools/lowercase_and_remove_accent.py
+    """
+    text = text.lower()
+    text = unicodedata.normalize("NFD", text)
+    output = []
+    for char in text:
+        cat = unicodedata.category(char)
+        if cat == "Mn":
+            continue
+        output.append(char)
+    return "".join(output).lower()
+
+
+def replace_unicode_punct(text):
+    '''
+    Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/replace-unicode-punctuation.perl
+    '''
+    text = text.replace('，', ',')
+    text = re.sub(r'。\s*', '. ', text)
+    text = text.replace('、', ',')
+    text = text.replace('”', '"')
+    text = text.replace('“', '"')
+    text = text.replace('∶', ':')
+    text = text.replace('：', ':')
+    text = text.replace('？', '?')
+    text = text.replace('《', '"')
+    text = text.replace('》', '"')
+    text = text.replace('）', ')')
+    text = text.replace('！', '!')
+    text = text.replace('（', '(')
+    text = text.replace('；', ';')
+    text = text.replace('１', '"')
+    text = text.replace('」', '"')
+    text = text.replace('「', '"')
+    text = text.replace('０', '0')
+    text = text.replace('３', '3')
+    text = text.replace('２', '2')
+    text = text.replace('５', '5')
+    text = text.replace('６', '6')
+    text = text.replace('９', '9')
+    text = text.replace('７', '7')
+    text = text.replace('８', '8')
+    text = text.replace('４', '4')
+    text = re.sub(r'．\s*', '. ', text)
+    text = text.replace('～', '~')
+    text = text.replace('’', '\'')
+    text = text.replace('…', '...')
+    text = text.replace('━', '-')
+    text = text.replace('〈', '<')
+    text = text.replace('〉', '>')
+    text = text.replace('【', '[')
+    text = text.replace('】', ']')
+    text = text.replace('％', '%')
+    return text
+
+
+def remove_non_printing_char(text):
+    '''
+    Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/remove-non-printing-char.perl
+    '''
+    output = []
+    for char in text:
+        cat = unicodedata.category(char)
+        if cat.startswith('C'):
+            continue
+        output.append(char)
+    return "".join(output)
+
+
+def romanian_preprocessing(text):
+    '''Sennrich's WMT16 scripts for Romanian preprocessing, used by model `xlm-mlm-enro-1024`'''
+    # https://github.com/rsennrich/wmt16-scripts/blob/master/preprocess/normalise-romanian.py
+    text = text.replace("\u015e", "\u0218").replace("\u015f", "\u0219")
+    text = text.replace("\u0162", "\u021a").replace("\u0163", "\u021b")
+    # https://github.com/rsennrich/wmt16-scripts/blob/master/preprocess/remove-diacritics.py
+    text = text.replace("\u0218", "S").replace("\u0219", "s") #s-comma
+    text = text.replace("\u021a", "T").replace("\u021b", "t") #t-comma
+    text = text.replace("\u0102", "A").replace("\u0103", "a")
+    text = text.replace("\u00C2", "A").replace("\u00E2", "a")
+    text = text.replace("\u00CE", "I").replace("\u00EE", "i")
+    return text
+
+
 class XLMTokenizer(PreTrainedTokenizer):
     """
     BPE tokenizer for XLM, adapted from OpenAI BPE tokenizer. Peculiarities:
@@ -122,16 +212,14 @@ class XLMTokenizer(PreTrainedTokenizer):
                                            cls_token=cls_token, mask_token=mask_token,
                                            additional_special_tokens=additional_special_tokens,
                                            **kwargs)
-        try:
-            import ftfy
-            from spacy.lang.en import English
-            _nlp = English()
-            self.nlp = _nlp.Defaults.create_tokenizer(_nlp)
-            self.fix_text = ftfy.fix_text
-        except ImportError:
-            logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.")
-            self.nlp = BasicTokenizer(do_lower_case=True)
-            self.fix_text = None
+
+        # cache of sm.MosesPunctNormalizer instance
+        self.cache_moses_punct_normalizer = dict()
+        # cache of sm.MosesTokenizer instance
+        self.cache_moses_tokenizer = dict()
+        self.lang_with_custom_tokenizer = set(['zh', 'th', 'ja'])
+        # True for current supported model (v1.2.0), False for XLM-17 & 100
+        self.do_lowercase_and_remove_accent = True
 
         self.encoder = json.load(open(vocab_file, encoding="utf-8"))
         self.decoder = {v:k for k,v in self.encoder.items()}
@@ -140,6 +228,28 @@ class XLMTokenizer(PreTrainedTokenizer):
         self.bpe_ranks = dict(zip(merges, range(len(merges))))
         self.cache = {}
 
+    def moses_punct_norm(self, text, lang):
+        if lang not in self.cache_moses_punct_normalizer:
+            punct_normalizer = sm.MosesPunctNormalizer(lang=lang)
+            self.cache_moses_punct_normalizer[lang] = punct_normalizer
+        else:
+            punct_normalizer = self.cache_moses_punct_normalizer[lang]
+        return punct_normalizer.normalize(text)
+
+    def moses_tokenize(self, text, lang):
+        if lang not in self.cache_moses_tokenizer:
+            moses_tokenizer = sm.MosesTokenizer(lang=lang)
+            self.cache_moses_tokenizer[lang] = moses_tokenizer
+        else:
+            moses_tokenizer = self.cache_moses_tokenizer[lang]
+        return moses_tokenizer.tokenize(text, return_str=False, escape=False)
+
+    def moses_pipeline(self, text, lang):
+        text = replace_unicode_punct(text)
+        text = self.moses_punct_norm(text, lang)
+        text = remove_non_printing_char(text)
+        return text
+
     @property
     def vocab_size(self):
         return len(self.encoder)
@@ -187,19 +297,21 @@ class XLMTokenizer(PreTrainedTokenizer):
         self.cache[token] = word
         return word
 
-    def _tokenize(self, text):
+    def _tokenize(self, text, lang='en'):
         """ Tokenize a string. """
         split_tokens = []
-        if self.fix_text is None:
-            # Using BERT's BasicTokenizer
-            text = self.nlp.tokenize(text)
+        if self.do_lowercase_and_remove_accent:
+            text = lowercase_and_remove_accent(text)
+        if lang not in self.lang_with_custom_tokenizer:
+            text = self.moses_pipeline(text, lang=lang)
+            # TODO: make sure we are using `xlm-mlm-enro-1024`, since XLM-100 doesn't have this step
+            if lang == 'ro':
+                text = romanian_preprocessing(text)
+            text = self.moses_tokenize(text, lang=lang)
             for token in text:
                 split_tokens.extend([t for t in self.bpe(token).split(' ')])
         else:
-            # Using SpaCy & ftfy (original tokenization process of OpenAI GPT)
-            text = self.nlp(text_standardize(self.fix_text(text)))
-            for token in text:
-                split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')])
+            raise ValueError
         return split_tokens
 
     def _convert_token_to_id(self, token):
diff --git a/requirements.txt b/requirements.txt
index 76532d18a5..01dca79d23 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,4 +9,6 @@ requests
 # For OpenAI GPT
 regex
 # For XLNet
-sentencepiece
\ No newline at end of file
+sentencepiece
+# For XLM
+sacremoses
\ No newline at end of file
diff --git a/setup.py b/setup.py
index c9f80fc224..2979722268 100644
--- a/setup.py
+++ b/setup.py
@@ -55,7 +55,8 @@ setup(
                       'requests',
                       'tqdm',
                       'regex',
-                      'sentencepiece'],
+                      'sentencepiece',
+                      'sacremoses'],
     entry_points={
       'console_scripts': [
         "pytorch_transformers=pytorch_transformers.__main__:main",
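
Usage note (editor's addition, not part of the patch): a minimal sketch of the new language-aware API introduced above. It assumes `PreTrainedTokenizer.tokenize` forwards extra keyword arguments to `_tokenize`; the example sentences are illustrative only, while `xlm-mlm-enro-1024` is the pretrained shortcut referenced in the patch.

# Sketch of calling the new `lang` argument on XLMTokenizer.
from pytorch_transformers import XLMTokenizer

tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-enro-1024')

# English is lowercased/accent-stripped, run through the Moses pipeline
# (unicode punctuation replacement, punctuation normalization,
# non-printing-character removal) and Moses-tokenized before BPE.
en_tokens = tokenizer.tokenize("Hello, world!", lang='en')

# Romanian additionally gets Sennrich's WMT16 normalization before BPE,
# matching the `xlm-mlm-enro-1024` preprocessing.
ro_tokens = tokenizer.tokenize("Aceasta este o propozitie.", lang='ro')

# zh, ja and th are not covered by this patch; `_tokenize` raises ValueError
# for them until their custom tokenizers are wired in.
print(en_tokens)
print(ro_tokens)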