Tokenization behaves the same as the original XLM preprocessing for all languages except zh, ja and th; change the API to allow specifying the language in `tokenize`

Shijie Wu 2019-08-23 14:40:17 -04:00
parent df9d6effae
commit 436ce07218
3 changed files with 135 additions and 20 deletions
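
The commit message above describes forwarding a `lang` keyword from `tokenize` down to `_tokenize`. Below is a minimal usage sketch of that API, assuming keyword arguments passed to `tokenize` reach `_tokenize` as the message states; the checkpoint name and the sample sentences are only illustrative.

from pytorch_transformers import XLMTokenizer

# Illustrative checkpoint; any language handled by the Moses pipeline works the same way.
tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-enfr-1024')

# Default language is English ('en').
print(tokenizer.tokenize("Hello, world!"))

# Use the French Moses punctuation normalizer and tokenizer instead.
print(tokenizer.tokenize("Bonjour, le monde !", lang='fr'))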

pytorch_transformers/tokenization_xlm.py

@@ -20,8 +20,11 @@ import json
import logging
import os
import re
import unicodedata
from io import open
import sacremoses as sm
from .tokenization_utils import PreTrainedTokenizer
from .tokenization_bert import BasicTokenizer
@@ -95,6 +98,93 @@ def text_standardize(text):
    text = re.sub(r'[^\S\n]+', ' ', text)
    return text.strip()

def lowercase_and_remove_accent(text):
    """
    Lowercase and strips accents from a piece of text based on
    https://github.com/facebookresearch/XLM/blob/master/tools/lowercase_and_remove_accent.py
    """
    text = text.lower()
    text = unicodedata.normalize("NFD", text)
    output = []
    for char in text:
        cat = unicodedata.category(char)
        if cat == "Mn":
            continue
        output.append(char)
    return "".join(output).lower()

def replace_unicode_punct(text):
    '''
    Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/replace-unicode-punctuation.perl
    '''
    text = text.replace('，', ',')
    text = re.sub(r'。\s*', '. ', text)
    text = text.replace('、', ',')
    text = text.replace('”', '"')
    text = text.replace('“', '"')
    text = text.replace('∶', ':')
    text = text.replace('：', ':')
    text = text.replace('？', '?')
    text = text.replace('《', '"')
    text = text.replace('》', '"')
    text = text.replace('）', ')')
    text = text.replace('！', '!')
    text = text.replace('（', '(')
    text = text.replace('；', ';')
    text = text.replace('１', '"')
    text = text.replace('」', '"')
    text = text.replace('「', '"')
    text = text.replace('０', '0')
    text = text.replace('３', '3')
    text = text.replace('２', '2')
    text = text.replace('５', '5')
    text = text.replace('６', '6')
    text = text.replace('９', '9')
    text = text.replace('７', '7')
    text = text.replace('８', '8')
    text = text.replace('４', '4')
    text = re.sub(r'．\s*', '. ', text)
    text = text.replace('～', '~')
    text = text.replace('’', '\'')
    text = text.replace('…', '...')
    text = text.replace('━', '-')
    text = text.replace('〈', '<')
    text = text.replace('〉', '>')
    text = text.replace('【', '[')
    text = text.replace('】', ']')
    text = text.replace('％', '%')
    return text

def remove_non_printing_char(text):
    '''
    Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/remove-non-printing-char.perl
    '''
    output = []
    for char in text:
        cat = unicodedata.category(char)
        if cat.startswith('C'):
            continue
        output.append(char)
    return "".join(output)

def romanian_preprocessing(text):
    '''Sennrich's WMT16 scripts for Romanian preprocessing, used by model `xlm-mlm-enro-1024`'''
    # https://github.com/rsennrich/wmt16-scripts/blob/master/preprocess/normalise-romanian.py
    text = text.replace("\u015e", "\u0218").replace("\u015f", "\u0219")
    text = text.replace("\u0162", "\u021a").replace("\u0163", "\u021b")
    # https://github.com/rsennrich/wmt16-scripts/blob/master/preprocess/remove-diacritics.py
    text = text.replace("\u0218", "S").replace("\u0219", "s") #s-comma
    text = text.replace("\u021a", "T").replace("\u021b", "t") #t-comma
    text = text.replace("\u0102", "A").replace("\u0103", "a")
    text = text.replace("\u00C2", "A").replace("\u00E2", "a")
    text = text.replace("\u00CE", "I").replace("\u00EE", "i")
    return text

class XLMTokenizer(PreTrainedTokenizer):
"""
BPE tokenizer for XLM, adapted from OpenAI BPE tokenizer. Peculiarities:
@@ -122,16 +212,14 @@ class XLMTokenizer(PreTrainedTokenizer):
                                           cls_token=cls_token, mask_token=mask_token,
                                           additional_special_tokens=additional_special_tokens,
                                           **kwargs)
        try:
            import ftfy
            from spacy.lang.en import English
            _nlp = English()
            self.nlp = _nlp.Defaults.create_tokenizer(_nlp)
            self.fix_text = ftfy.fix_text
        except ImportError:
            logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.")
            self.nlp = BasicTokenizer(do_lower_case=True)
            self.fix_text = None
        # cache of sm.MosesPunctNormalizer instance
        self.cache_moses_punct_normalizer = dict()
        # cache of sm.MosesTokenizer instance
        self.cache_moses_tokenizer = dict()
        self.lang_with_custom_tokenizer = set(['zh', 'th', 'ja'])
        # True for current supported model (v1.2.0), False for XLM-17 & 100
        self.do_lowercase_and_remove_accent = True
        self.encoder = json.load(open(vocab_file, encoding="utf-8"))
        self.decoder = {v:k for k,v in self.encoder.items()}
@@ -140,6 +228,28 @@ class XLMTokenizer(PreTrainedTokenizer):
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {}

    def moses_punct_norm(self, text, lang):
        if lang not in self.cache_moses_punct_normalizer:
            punct_normalizer = sm.MosesPunctNormalizer(lang=lang)
            self.cache_moses_punct_normalizer[lang] = punct_normalizer
        else:
            punct_normalizer = self.cache_moses_punct_normalizer[lang]
        return punct_normalizer.normalize(text)

    def moses_tokenize(self, text, lang):
        if lang not in self.cache_moses_tokenizer:
            moses_tokenizer = sm.MosesTokenizer(lang=lang)
            self.cache_moses_tokenizer[lang] = moses_tokenizer
        else:
            moses_tokenizer = self.cache_moses_tokenizer[lang]
        return moses_tokenizer.tokenize(text, return_str=False, escape=False)

    def moses_pipeline(self, text, lang):
        text = replace_unicode_punct(text)
        text = self.moses_punct_norm(text, lang)
        text = remove_non_printing_char(text)
        return text

    @property
    def vocab_size(self):
        return len(self.encoder)
@@ -187,19 +297,21 @@ class XLMTokenizer(PreTrainedTokenizer):
        self.cache[token] = word
        return word

    def _tokenize(self, text):
    def _tokenize(self, text, lang='en'):
        """ Tokenize a string. """
        split_tokens = []
        if self.fix_text is None:
            # Using BERT's BasicTokenizer
            text = self.nlp.tokenize(text)
        if self.do_lowercase_and_remove_accent:
            text = lowercase_and_remove_accent(text)
        if lang not in self.lang_with_custom_tokenizer:
            text = self.moses_pipeline(text, lang=lang)
            # TODO: make sure we are using `xlm-mlm-enro-1024`, since XLM-100 doesn't have this step
            if lang == 'ro':
                text = romanian_preprocessing(text)
            text = self.moses_tokenize(text, lang=lang)
            for token in text:
                split_tokens.extend([t for t in self.bpe(token).split(' ')])
        else:
            # Using SpaCy & ftfy (original tokenization process of OpenAI GPT)
            text = self.nlp(text_standardize(self.fix_text(text)))
            for token in text:
                split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')])
            raise ValueError
        return split_tokens

    def _convert_token_to_id(self, token):

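For reference, the sacremoses objects that the new `moses_punct_norm` and `moses_tokenize` helpers cache per language are used roughly as in this standalone sketch (sample text is made up; in the tokenizer, `moses_pipeline` additionally runs `replace_unicode_punct` before and `remove_non_printing_char` after normalization):

import sacremoses as sm

lang = 'fr'
text = "Bonjour, le monde !"

# XLMTokenizer builds one of each per language and stores them in
# cache_moses_punct_normalizer / cache_moses_tokenizer so they are constructed only once.
punct_normalizer = sm.MosesPunctNormalizer(lang=lang)
moses_tokenizer = sm.MosesTokenizer(lang=lang)

normalized = punct_normalizer.normalize(text)
tokens = moses_tokenizer.tokenize(normalized, return_str=False, escape=False)
print(tokens)  # e.g. ['Bonjour', ',', 'le', 'monde', '!']
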
requirements.txt

@@ -9,4 +9,6 @@ requests
# For OpenAI GPT
regex
# For XLNet
sentencepiece
sentencepiece
# For XLM
sacremoses

setup.py

@@ -55,7 +55,8 @@ setup(
                      'requests',
                      'tqdm',
                      'regex',
                      'sentencepiece'],
                      'sentencepiece',
                      'sacremoses'],
    entry_points={
        'console_scripts': [
            "pytorch_transformers=pytorch_transformers.__main__:main",