add serialization semantics to tokenizers - fix transfo-xl tokenizer

thomwolf 2019-04-15 11:47:25 +02:00
parent 616743330e
commit 3e65f255dc
5 changed files with 67 additions and 110 deletions
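The new save_vocabulary methods pair with the existing from_pretrained loaders, so a vocabulary downloaded once can be written to a local directory and the tokenizer re-created from it. A minimal sketch of that round trip, assuming the bert-base-uncased shortcut name and a writable local directory (the path is illustrative, and from_pretrained is assumed to resolve VOCAB_NAME inside a local folder as it does for names outside the archive map):

import os
from pytorch_pretrained_bert import BertTokenizer

save_dir = "./saved_tokenizer"  # illustrative output directory
os.makedirs(save_dir, exist_ok=True)

# load a pretrained vocabulary (downloaded or read from the cache) ...
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# ... write it back out as a vocab file inside save_dir ...
tokenizer.save_vocabulary(save_dir)
# ... and rebuild the tokenizer from that directory
reloaded = BertTokenizer.from_pretrained(save_dir)
assert reloaded.tokenize("Hello, world!") == tokenizer.tokenize("Hello, world!")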

View File

@@ -28,7 +28,7 @@ import math
import torch
from pytorch_pretrained_bert import TransfoXLLMHeadModel, TransfoXLCorpus
from pytorch_pretrained_bert import TransfoXLLMHeadModel, TransfoXLCorpus, TransfoXLTokenizer
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
@@ -80,6 +80,7 @@ def main():
# The pre-processing involves computing word frequencies to prepare the Adaptive input and SoftMax
# and tokenizing the dataset
# The pre-processed corpus is a conversion (using the conversion script)
tokenizer = TransfoXLTokenizer.from_pretrained(args.model_name)
corpus = TransfoXLCorpus.from_pretrained(args.model_name)
ntokens = len(corpus.vocab)

View File

@@ -134,6 +134,19 @@ class BertTokenizer(object):
tokens.append(self.ids_to_tokens[i])
return tokens
    def save_vocabulary(self, vocab_path):
        """Save the tokenizer vocabulary to a path."""
        index = 0
        vocab_file = os.path.join(vocab_path, VOCAB_NAME)
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive."
                                   " Please check that the vocabulary is not corrupted!".format(vocab_file))
                    index = token_index
                writer.write(token + u'\n')
                index += 1
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
"""

View File

@@ -187,6 +187,22 @@ class GPT2Tokenizer(object):
self.cache[token] = word
return word
    def save_vocabulary(self, vocab_path):
        """Save the tokenizer vocabulary and BPE merge files to a path."""
        vocab_file = os.path.join(vocab_path, VOCAB_NAME)
        merge_file = os.path.join(vocab_path, MERGES_NAME)

        # vocab_file is a path string, so open it before dumping the JSON encoder
        with open(vocab_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(self.encoder, ensure_ascii=False))

        index = 0
        with open(merge_file, "w", encoding="utf-8") as writer:
            writer.write(u'#version: 0.2\n')
            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
                                   " Please check that the tokenizer is not corrupted!".format(merge_file))
                    index = token_index
                # bpe_ranks keys are (first, second) merge pairs, so join them with a space
                writer.write(' '.join(bpe_tokens) + u'\n')
                index += 1
def encode(self, text):
bpe_tokens = []
for token in re.findall(self.pat, text):
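Both the GPT-2 and OpenAI GPT tokenizers serialize two files: a JSON vocabulary mapping token strings to ids, and a merges file listing one BPE merge pair per line in priority order after a '#version' header. A minimal sketch of reading the pair of files back, mirroring how the constructors consume them (read_bpe_files is an illustrative helper):

import json

def read_bpe_files(vocab_file, merge_file):
    """Rebuild the encoder dict and merge-pair ranks written by save_vocabulary."""
    with open(vocab_file, "r", encoding="utf-8") as f:
        encoder = json.load(f)  # token string -> id
    with open(merge_file, "r", encoding="utf-8") as f:
        merges = f.read().split("\n")[1:]  # drop the '#version: 0.2' header line
    merge_pairs = [tuple(merge.split()) for merge in merges if merge]
    bpe_ranks = dict(zip(merge_pairs, range(len(merge_pairs))))  # merge pair -> priority
    return encoder, bpe_ranks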

View File

@@ -261,3 +261,19 @@ class OpenAIGPTTokenizer(object):
).replace(" 's", "'s").replace(" t ", "'t ").replace(" s ", "'s ").replace(" m ", "'m "
).replace(" 've", "'ve")
return out_string
    def save_vocabulary(self, vocab_path):
        """Save the tokenizer vocabulary and BPE merge files to a path."""
        vocab_file = os.path.join(vocab_path, VOCAB_NAME)
        merge_file = os.path.join(vocab_path, MERGES_NAME)

        # vocab_file is a path string, so open it before dumping the JSON encoder
        with open(vocab_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(self.encoder, ensure_ascii=False))

        index = 0
        with open(merge_file, "w", encoding="utf-8") as writer:
            writer.write(u'#version: 0.2\n')
            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
                                   " Please check that the tokenizer is not corrupted!".format(merge_file))
                    index = token_index
                # bpe_ranks keys are (first, second) merge pairs, so join them with a space
                writer.write(' '.join(bpe_tokens) + u'\n')
                index += 1

View File

@@ -63,7 +63,10 @@ class TransfoXLTokenizer(object):
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
else:
vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
if os.path.isdir(pretrained_model_name_or_path):
vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
else:
vocab_file = pretrained_model_name_or_path
# redirect to the cache, if necessary
try:
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
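With this change from_pretrained accepts three kinds of input: a shortcut name listed in PRETRAINED_VOCAB_ARCHIVE_MAP, a local directory containing the saved vocabulary (VOCAB_NAME is appended automatically), or a direct path to the vocabulary file. Illustrative calls (the local paths are assumptions):

from pytorch_pretrained_bert import TransfoXLTokenizer

# shortcut name, resolved through the archive map and the download cache
tokenizer = TransfoXLTokenizer.from_pretrained("transfo-xl-wt103")

# directory produced by save_vocabulary
tokenizer = TransfoXLTokenizer.from_pretrained("./saved_tokenizer")

# direct path to the serialized vocabulary file (assuming it is named after VOCAB_NAME)
tokenizer = TransfoXLTokenizer.from_pretrained("./saved_tokenizer/vocab.bin")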
@@ -141,6 +144,11 @@ class TransfoXLTokenizer(object):
else:
raise ValueError('No <unknown> token in vocabulary')
    def save_vocabulary(self, vocab_path):
        """Save the tokenizer vocabulary to a path (the full tokenizer state is pickled with torch.save)."""
        vocab_file = os.path.join(vocab_path, VOCAB_NAME)
        torch.save(self.__dict__, vocab_file)
def build_vocab(self):
if self.vocab_file:
print('building vocab from {}'.format(self.vocab_file))
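Unlike the plain-text formats above, the Transformer-XL vocabulary is serialized as a torch pickle of the tokenizer's entire __dict__, so the loading side can restore every attribute (idx2sym, lower_case, delimiter, and so on) in one step. Roughly, as in this sketch (restore_tokenizer_state is an illustrative helper, not the library's code verbatim):

import torch

def restore_tokenizer_state(tokenizer, vocab_file):
    """Load the pickled __dict__ and copy it onto a freshly constructed tokenizer."""
    state = torch.load(vocab_file)
    for key, value in state.items():
        tokenizer.__dict__[key] = value
    return tokenizer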
@@ -245,82 +253,24 @@ class TransfoXLTokenizer(object):
def __len__(self):
return len(self.idx2sym)
def _run_split_on_punc(self, text):
"""Splits punctuation on a piece of text."""
if text in self.never_split:
return [text]
chars = list(text)
i = 0
start_new_word = True
output = []
while i < len(chars):
char = chars[i]
if _is_punctuation(char):
output.append([char])
start_new_word = True
else:
if start_new_word:
output.append([])
start_new_word = False
output[-1].append(char)
i += 1
return ["".join(x) for x in output]
def _run_strip_accents(self, text):
"""Strips accents from a piece of text."""
text = unicodedata.normalize("NFD", text)
output = []
for char in text:
cat = unicodedata.category(char)
if cat == "Mn":
continue
output.append(char)
return "".join(output)
def _clean_text(self, text):
"""Performs invalid character removal and whitespace cleanup on text."""
output = []
for char in text:
cp = ord(char)
if cp == 0 or cp == 0xfffd or _is_control(char):
continue
if _is_whitespace(char):
output.append(" ")
else:
output.append(char)
return "".join(output)
def whitespace_tokenize(self, text):
"""Runs basic whitespace cleaning and splitting on a piece of text."""
text = text.strip()
if not text:
return []
if self.delimiter == '':
tokens = text
else:
tokens = text.split(self.delimiter)
return tokens
def tokenize(self, line, add_eos=False, add_double_eos=False):
line = self._clean_text(line)
line = line.strip()
# convert to lower case
if self.lower_case:
line = line.lower()
symbols = self.whitespace_tokenize(line)
split_symbols = []
for symbol in symbols:
if self.lower_case and symbol not in self.never_split:
symbol = symbol.lower()
symbol = self._run_strip_accents(symbol)
split_symbols.extend(self._run_split_on_punc(symbol))
# empty delimiter '' will evaluate False
if self.delimiter == '':
symbols = line
else:
symbols = line.split(self.delimiter)
if add_double_eos: # lm1b
return ['<S>'] + split_symbols + ['<S>']
return ['<S>'] + symbols + ['<S>']
elif add_eos:
return split_symbols + ['<eos>']
return symbols + ['<eos>']
else:
return split_symbols
return symbols
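After this simplification tokenize is back to plain delimiter splitting (whitespace splitting when no delimiter is set) plus optional sentence markers. Expected behaviour, sketched with the wt103 vocabulary (whitespace delimiter assumed) and outputs shown as comments:

tokenizer = TransfoXLTokenizer.from_pretrained("transfo-xl-wt103")

tokenizer.tokenize("the quick brown fox")
# ['the', 'quick', 'brown', 'fox']

tokenizer.tokenize("the quick brown fox", add_eos=True)
# ['the', 'quick', 'brown', 'fox', '<eos>']

tokenizer.tokenize("the quick brown fox", add_double_eos=True)  # lm1b-style
# ['<S>', 'the', 'quick', 'brown', 'fox', '<S>']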
class LMOrderedIterator(object):
@@ -631,42 +581,3 @@ def get_lm_corpus(datadir, dataset):
torch.save(corpus, fn)
return corpus
def _is_whitespace(char):
"""Checks whether `chars` is a whitespace character."""
# \t, \n, and \r are technically control characters but we treat them
# as whitespace since they are generally considered as such.
if char == " " or char == "\t" or char == "\n" or char == "\r":
return True
cat = unicodedata.category(char)
if cat == "Zs":
return True
return False
def _is_control(char):
"""Checks whether `chars` is a control character."""
# These are technically control characters but we count them as whitespace
# characters.
if char == "\t" or char == "\n" or char == "\r":
return False
cat = unicodedata.category(char)
if cat.startswith("C"):
return True
return False
def _is_punctuation(char):
"""Checks whether `chars` is a punctuation character."""
cp = ord(char)
# We treat all non-letter/number ASCII as punctuation.
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
return True
cat = unicodedata.category(char)
if cat.startswith("P"):
return True
return False