Tokenization behave the same as original XLM proprocessing for most languages except zh, ja and th; Change API to allow specifying language in `tokenize`
This commit is contained in:
parent
df9d6effae
commit
436ce07218
|
@ -20,8 +20,11 @@ import json
|
|||
import logging
|
||||
import os
|
||||
import re
|
||||
import unicodedata
|
||||
from io import open
|
||||
|
||||
import sacremoses as sm
|
||||
|
||||
from .tokenization_utils import PreTrainedTokenizer
|
||||
from .tokenization_bert import BasicTokenizer
|
||||
|
||||
|
@ -95,6 +98,93 @@ def text_standardize(text):
|
|||
text = re.sub(r'[^\S\n]+', ' ', text)
|
||||
return text.strip()
|
||||
|
||||
|
||||
def lowercase_and_remove_accent(text):
|
||||
"""
|
||||
Lowercase and strips accents from a piece of text based on
|
||||
https://github.com/facebookresearch/XLM/blob/master/tools/lowercase_and_remove_accent.py
|
||||
"""
|
||||
text = text.lower()
|
||||
text = unicodedata.normalize("NFD", text)
|
||||
output = []
|
||||
for char in text:
|
||||
cat = unicodedata.category(char)
|
||||
if cat == "Mn":
|
||||
continue
|
||||
output.append(char)
|
||||
return "".join(output).lower()
|
||||
|
||||
|
||||
def replace_unicode_punct(text):
|
||||
'''
|
||||
Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/replace-unicode-punctuation.perl
|
||||
'''
|
||||
text = text.replace(',', ',')
|
||||
text = text.replace('。 *', '. ')
|
||||
text = text.replace('、', ',')
|
||||
text = text.replace('”', '"')
|
||||
text = text.replace('“', '"')
|
||||
text = text.replace('∶', ':')
|
||||
text = text.replace(':', ':')
|
||||
text = text.replace('?', '?')
|
||||
text = text.replace('《', '"')
|
||||
text = text.replace('》', '"')
|
||||
text = text.replace(')', ')')
|
||||
text = text.replace('!', '!')
|
||||
text = text.replace('(', '(')
|
||||
text = text.replace(';', ';')
|
||||
text = text.replace('1', '"')
|
||||
text = text.replace('」', '"')
|
||||
text = text.replace('「', '"')
|
||||
text = text.replace('0', '0')
|
||||
text = text.replace('3', '3')
|
||||
text = text.replace('2', '2')
|
||||
text = text.replace('5', '5')
|
||||
text = text.replace('6', '6')
|
||||
text = text.replace('9', '9')
|
||||
text = text.replace('7', '7')
|
||||
text = text.replace('8', '8')
|
||||
text = text.replace('4', '4')
|
||||
text = re.sub(r'.\s*', '. ', text)
|
||||
text = text.replace('~', '~')
|
||||
text = text.replace('’', '\'')
|
||||
text = text.replace('…', '...')
|
||||
text = text.replace('━', '-')
|
||||
text = text.replace('〈', '<')
|
||||
text = text.replace('〉', '>')
|
||||
text = text.replace('【', '[')
|
||||
text = text.replace('】', ']')
|
||||
text = text.replace('%', '%')
|
||||
return text
|
||||
|
||||
|
||||
def remove_non_printing_char(text):
|
||||
'''
|
||||
Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/remove-non-printing-char.perl
|
||||
'''
|
||||
output = []
|
||||
for char in text:
|
||||
cat = unicodedata.category(char)
|
||||
if cat.startswith('C'):
|
||||
continue
|
||||
output.append(char)
|
||||
return "".join(output)
|
||||
|
||||
|
||||
def romanian_preprocessing(text):
|
||||
'''Sennrich's WMT16 scripts for Romanian preprocessing, used by model `xlm-mlm-enro-1024`'''
|
||||
# https://github.com/rsennrich/wmt16-scripts/blob/master/preprocess/normalise-romanian.py
|
||||
text = text.replace("\u015e", "\u0218").replace("\u015f", "\u0219")
|
||||
text = text.replace("\u0162", "\u021a").replace("\u0163", "\u021b")
|
||||
# https://github.com/rsennrich/wmt16-scripts/blob/master/preprocess/remove-diacritics.py
|
||||
text = text.replace("\u0218", "S").replace("\u0219", "s") #s-comma
|
||||
text = text.replace("\u021a", "T").replace("\u021b", "t") #t-comma
|
||||
text = text.replace("\u0102", "A").replace("\u0103", "a")
|
||||
text = text.replace("\u00C2", "A").replace("\u00E2", "a")
|
||||
text = text.replace("\u00CE", "I").replace("\u00EE", "i")
|
||||
return text
|
||||
|
||||
|
||||
class XLMTokenizer(PreTrainedTokenizer):
|
||||
"""
|
||||
BPE tokenizer for XLM, adapted from OpenAI BPE tokenizer. Peculiarities:
|
||||
|
@ -122,16 +212,14 @@ class XLMTokenizer(PreTrainedTokenizer):
|
|||
cls_token=cls_token, mask_token=mask_token,
|
||||
additional_special_tokens=additional_special_tokens,
|
||||
**kwargs)
|
||||
try:
|
||||
import ftfy
|
||||
from spacy.lang.en import English
|
||||
_nlp = English()
|
||||
self.nlp = _nlp.Defaults.create_tokenizer(_nlp)
|
||||
self.fix_text = ftfy.fix_text
|
||||
except ImportError:
|
||||
logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.")
|
||||
self.nlp = BasicTokenizer(do_lower_case=True)
|
||||
self.fix_text = None
|
||||
|
||||
# cache of sm.MosesPunctNormalizer instance
|
||||
self.cache_moses_punct_normalizer = dict()
|
||||
# cache of sm.MosesTokenizer instance
|
||||
self.cache_moses_tokenizer = dict()
|
||||
self.lang_with_custom_tokenizer = set(['zh', 'th', 'ja'])
|
||||
# True for current supported model (v1.2.0), False for XLM-17 & 100
|
||||
self.do_lowercase_and_remove_accent = True
|
||||
|
||||
self.encoder = json.load(open(vocab_file, encoding="utf-8"))
|
||||
self.decoder = {v:k for k,v in self.encoder.items()}
|
||||
|
@ -140,6 +228,28 @@ class XLMTokenizer(PreTrainedTokenizer):
|
|||
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
||||
self.cache = {}
|
||||
|
||||
def moses_punct_norm(self, text, lang):
|
||||
if lang not in self.cache_moses_punct_normalizer:
|
||||
punct_normalizer = sm.MosesPunctNormalizer(lang=lang)
|
||||
self.cache_moses_punct_normalizer[lang] = punct_normalizer
|
||||
else:
|
||||
punct_normalizer = self.cache_moses_punct_normalizer[lang]
|
||||
return punct_normalizer.normalize(text)
|
||||
|
||||
def moses_tokenize(self, text, lang):
|
||||
if lang not in self.cache_moses_tokenizer:
|
||||
moses_tokenizer = sm.MosesTokenizer(lang=lang)
|
||||
self.cache_moses_tokenizer[lang] = moses_tokenizer
|
||||
else:
|
||||
moses_tokenizer = self.cache_moses_tokenizer[lang]
|
||||
return moses_tokenizer.tokenize(text, return_str=False, escape=False)
|
||||
|
||||
def moses_pipeline(self, text, lang):
|
||||
text = replace_unicode_punct(text)
|
||||
text = self.moses_punct_norm(text, lang)
|
||||
text = remove_non_printing_char(text)
|
||||
return text
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return len(self.encoder)
|
||||
|
@ -187,19 +297,21 @@ class XLMTokenizer(PreTrainedTokenizer):
|
|||
self.cache[token] = word
|
||||
return word
|
||||
|
||||
def _tokenize(self, text):
|
||||
def _tokenize(self, text, lang='en'):
|
||||
""" Tokenize a string. """
|
||||
split_tokens = []
|
||||
if self.fix_text is None:
|
||||
# Using BERT's BasicTokenizer
|
||||
text = self.nlp.tokenize(text)
|
||||
if self.do_lowercase_and_remove_accent:
|
||||
text = lowercase_and_remove_accent(text)
|
||||
if lang not in self.lang_with_custom_tokenizer:
|
||||
text = self.moses_pipeline(text, lang=lang)
|
||||
# TODO: make sure we are using `xlm-mlm-enro-1024`, since XLM-100 doesn't have this step
|
||||
if lang == 'ro':
|
||||
text = romanian_preprocessing(text)
|
||||
text = self.moses_tokenize(text, lang=lang)
|
||||
for token in text:
|
||||
split_tokens.extend([t for t in self.bpe(token).split(' ')])
|
||||
else:
|
||||
# Using SpaCy & ftfy (original tokenization process of OpenAI GPT)
|
||||
text = self.nlp(text_standardize(self.fix_text(text)))
|
||||
for token in text:
|
||||
split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')])
|
||||
raise ValueError
|
||||
return split_tokens
|
||||
|
||||
def _convert_token_to_id(self, token):
|
||||
|
|
|
@ -9,4 +9,6 @@ requests
|
|||
# For OpenAI GPT
|
||||
regex
|
||||
# For XLNet
|
||||
sentencepiece
|
||||
sentencepiece
|
||||
# For XLM
|
||||
sacremoses
|
Loading…
Reference in New Issue