improving GPT2 tokenization and adding tests

This commit is contained in:
thomwolf 2019-04-16 17:00:55 +02:00
parent 3d78e226e6
commit 18a8a15f78
5 changed files with 169 additions and 34 deletions

View File

@ -929,10 +929,11 @@ This class has four arguments:
and five methods:
- `tokenize(text)`: convert a `str` in a list of `str` tokens by (1) performing basic tokenization and (2) WordPiece tokenization.
- `tokenize(text)`: convert a `str` in a list of `str` tokens by performing BPE tokenization.
- `convert_tokens_to_ids(tokens)`: convert a list of `str` tokens in a list of `int` indices in the vocabulary.
- `convert_ids_to_tokens(tokens)`: convert a list of `int` indices in a list of `str` tokens in the vocabulary.
- `set_special_tokens(self, special_tokens)`: update the list of special tokens (see above arguments)
- `encode(text)`: convert a `str` in a list of `int` tokens by performing BPE encoding.
- `decode(ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)`: decode a list of `int` indices in a string and do some post-processing if needed: (i) remove special tokens from the output and (ii) clean up tokenization spaces.
- `save_vocabulary(directory_path)`: save the vocabulary, merge and special tokens files to `directory_path`. Return the path to the three files: `vocab_file_path`, `merge_file_path`, `special_tokens_file_path`. The vocabulary can be reloaded with `OpenAIGPTTokenizer.from_pretrained('directory_path')`.
@ -958,6 +959,10 @@ This class has three arguments:
and two methods:
- `tokenize(text)`: convert a `str` in a list of `str` tokens by performing byte-level BPE.
- `convert_tokens_to_ids(tokens)`: convert a list of `str` tokens in a list of `int` indices in the vocabulary.
- `convert_ids_to_tokens(tokens)`: convert a list of `int` indices in a list of `str` tokens in the vocabulary.
- `set_special_tokens(self, special_tokens)`: update the list of special tokens (see above arguments)
- `encode(text)`: convert a `str` in a list of `int` tokens by performing byte-level BPE.
- `decode(tokens)`: convert back a list of `int` tokens in a `str`.
- `save_vocabulary(directory_path)`: save the vocabulary, merge and special tokens files to `directory_path`. Return the path to the three files: `vocab_file_path`, `merge_file_path`, `special_tokens_file_path`. The vocabulary can be reloaded with `OpenAIGPTTokenizer.from_pretrained('directory_path')`.

View File

@ -16,6 +16,7 @@
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import sys
import json
import logging
import os
@ -138,7 +139,7 @@ class GPT2Tokenizer(object):
tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
return tokenizer
def __init__(self, vocab_file, merges_file, errors='replace', max_len=None):
def __init__(self, vocab_file, merges_file, errors='replace', special_tokens=None, max_len=None):
self.max_len = max_len if max_len is not None else int(1e12)
self.encoder = json.load(open(vocab_file))
self.decoder = {v:k for k,v in self.encoder.items()}
@ -153,8 +154,25 @@ class GPT2Tokenizer(object):
# Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
self.special_tokens = {}
self.special_tokens_decoder = {}
self.set_special_tokens(special_tokens)
def __len__(self):
return len(self.encoder)
return len(self.encoder) + len(self.special_tokens)
def set_special_tokens(self, special_tokens):
""" Add a list of additional tokens to the encoder.
The additional tokens are indexed starting from the last index of the
current vocabulary in the order of the `special_tokens` list.
"""
if not special_tokens:
self.special_tokens = {}
self.special_tokens_decoder = {}
return
self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
logger.info("Special tokens {}".format(self.special_tokens))
def bpe(self, token):
if token in self.cache:
@ -197,6 +215,54 @@ class GPT2Tokenizer(object):
self.cache[token] = word
return word
def tokenize(self, text):
""" Tokenize a string. """
bpe_tokens = []
for token in re.findall(self.pat, text):
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
return bpe_tokens
def convert_tokens_to_ids(self, tokens):
""" Converts a sequence of tokens into ids using the vocab. """
ids = []
if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
if tokens in self.special_tokens:
return self.special_tokens[tokens]
else:
return self.encoder.get(tokens, 0)
for token in tokens:
if token in self.special_tokens:
ids.append(self.special_tokens[token])
else:
ids.append(self.encoder.get(token, 0))
if len(ids) > self.max_len:
logger.warning(
"Token indices sequence length is longer than the specified maximum "
" sequence length for this OpenAI GPT model ({} > {}). Running this"
" sequence through the model will result in indexing errors".format(len(ids), self.max_len)
)
return ids
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
"""Converts a sequence of ids in BPE tokens using the vocab."""
tokens = []
for i in ids:
if i in self.special_tokens_decoder:
if not skip_special_tokens:
tokens.append(self.special_tokens_decoder[i])
else:
tokens.append(self.decoder[i])
return tokens
def encode(self, text):
return self.convert_tokens_to_ids(self.tokenize(text))
def decode(self, tokens):
text = ''.join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
return text
def save_vocabulary(self, vocab_path):
"""Save the tokenizer vocabulary and merge files to a directory."""
if not os.path.isdir(vocab_path):
@ -220,26 +286,14 @@ class GPT2Tokenizer(object):
writer.write(' '.join(bpe_tokens) + u'\n')
index += 1
index = len(self.encoder)
with open(special_tokens_file, 'w', encoding='utf-8') as writer:
for token in sorted(self.special_tokens.keys(), key=lambda kv: kv[1]):
for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(special_tokens_file))
index = token_index
writer.write(token + u'\n')
index += 1
return vocab_file, merge_file, special_tokens_file
def encode(self, text):
bpe_tokens = []
for token in re.findall(self.pat, text):
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
if len(bpe_tokens) > self.max_len:
logger.warning(
"Token indices sequence length is longer than the specified maximum "
" sequence length for this OpenAI GPT-2 model ({} > {}). Running this"
" sequence through the model will result in indexing errors".format(len(bpe_tokens), self.max_len)
)
return bpe_tokens
def decode(self, tokens):
text = ''.join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
return text

View File

@ -150,6 +150,8 @@ class OpenAIGPTTokenizer(object):
merges = [tuple(merge.split()) for merge in merges]
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {}
self.special_tokens = {}
self.special_tokens_decoder = {}
self.set_special_tokens(special_tokens)
def __len__(self):
@ -261,7 +263,10 @@ class OpenAIGPTTokenizer(object):
tokens.append(self.decoder[i])
return tokens
def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=False):
def encode(self, text):
return self.convert_tokens_to_ids(self.tokenize(text))
def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
"""Converts a sequence of ids in a string."""
tokens = self.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens)
out_string = ''.join(tokens).replace('</w>', ' ').strip()
@ -296,8 +301,14 @@ class OpenAIGPTTokenizer(object):
writer.write(' '.join(bpe_tokens) + u'\n')
index += 1
index = len(self.encoder)
with open(special_tokens_file, 'w', encoding='utf-8') as writer:
for token in sorted(self.special_tokens.keys(), key=lambda kv: kv[1]):
for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(special_tokens_file))
index = token_index
writer.write(token + u'\n')
index += 1
return vocab_file, merge_file, special_tokens_file

View File

@ -0,0 +1,68 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import unittest
import json
from pytorch_pretrained_bert.tokenization_gpt2 import GPT2Tokenizer
class GPT2TokenizationTest(unittest.TestCase):
def test_full_tokenizer(self):
""" Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
"lo", "low", "er",
"low", "lowest", "newer", "wider"]
vocab_tokens = dict(zip(vocab, range(len(vocab))))
merges = ["#version: 0.2", "l o", "lo w", "e r", ""]
with open("/tmp/openai_tokenizer_vocab_test.json", "w") as fp:
json.dump(vocab_tokens, fp)
vocab_file = fp.name
with open("/tmp/openai_tokenizer_merges_test.txt", "w") as fp:
fp.write("\n".join(merges))
merges_file = fp.name
tokenizer = GPT2Tokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
os.remove(vocab_file)
os.remove(merges_file)
text = "lower"
bpe_tokens = ["low", "er"]
tokens = tokenizer.tokenize(text)
self.assertListEqual(tokens, bpe_tokens)
input_tokens = tokens + ["<unk>"]
input_bpe_tokens = [13, 12, 16]
self.assertListEqual(
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
vocab_file, merges_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
tokenizer_2 = GPT2Tokenizer.from_pretrained("/tmp/")
os.remove(vocab_file)
os.remove(merges_file)
os.remove(special_tokens_file)
self.assertListEqual(
[tokenizer.encoder, tokenizer.decoder, tokenizer.bpe_ranks,
tokenizer.special_tokens, tokenizer.special_tokens_decoder],
[tokenizer_2.encoder, tokenizer_2.decoder, tokenizer_2.bpe_ranks,
tokenizer_2.special_tokens, tokenizer_2.special_tokens_decoder])
if __name__ == '__main__':
unittest.main()

View File

@ -38,7 +38,7 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
fp.write("\n".join(merges))
merges_file = fp.name
tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file, special_tokens=["<unk>"])
tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
os.remove(vocab_file)
os.remove(merges_file)
@ -53,19 +53,16 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
vocab_file, merges_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
tokenizer.from_pretrained("/tmp/")
tokenizer_2 = OpenAIGPTTokenizer.from_pretrained("/tmp/")
os.remove(vocab_file)
os.remove(merges_file)
os.remove(special_tokens_file)
text = "lower"
bpe_tokens = ["low", "er</w>"]
tokens = tokenizer.tokenize(text)
self.assertListEqual(tokens, bpe_tokens)
input_tokens = tokens + ["<unk>"]
input_bpe_tokens = [14, 15, 20]
self.assertListEqual(
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
[tokenizer.encoder, tokenizer.decoder, tokenizer.bpe_ranks,
tokenizer.special_tokens, tokenizer.special_tokens_decoder],
[tokenizer_2.encoder, tokenizer_2.decoder, tokenizer_2.bpe_ranks,
tokenizer_2.special_tokens, tokenizer_2.special_tokens_decoder])
if __name__ == '__main__':