fixed GPT-2 tokenization on python 2

thomwolf 2019-04-17 10:56:15 +02:00
parent bdaba1897c
commit bc70779bf0
4 changed files with 7 additions and 5 deletions

View File

@@ -227,7 +227,7 @@ def get_from_cache(url, cache_dir=None):
meta = {'url': url, 'etag': etag}
meta_path = cache_path + '.json'
with open(meta_path, 'w', encoding="utf-8") as meta_file:
- json.dump(meta, meta_file)
+ meta_file.write(json.dumps(meta))
logger.info("removing temp file %s", temp_file.name)
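
The hunk above swaps a streaming json.dump for a single json.dumps followed by a write. As a hedged aside (not the commit's code; the write_json helper and the example path are illustrative only), one way to make such a write behave the same under Python 2 and 3 when the file is opened in text mode:

    # Sketch only: json.dumps returns a byte str on Python 2 (with the default
    # ensure_ascii=True) but unicode on Python 3; a text-mode io.open expects
    # unicode, so decode explicitly on Python 2 before writing.
    import io
    import json
    import sys

    def write_json(path, obj):
        output = json.dumps(obj)
        if sys.version_info[0] == 2 and isinstance(output, str):
            output = output.decode("utf-8")
        with io.open(path, "w", encoding="utf-8") as fp:
            fp.write(output)

    write_json("/tmp/meta_example.json", {"url": "http://example.com", "etag": "abc"})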

View File

@@ -59,6 +59,7 @@ def bytes_to_unicode():
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
+ _chr = unichr if sys.version_info[0] == 2 else chr
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
cs = bs[:]
n = 0
@@ -67,7 +68,7 @@ def bytes_to_unicode():
bs.append(b)
cs.append(2**8+n)
n += 1
- cs = [chr(n) for n in cs]
+ cs = [_chr(n) for n in cs]
return dict(zip(bs, cs))
def get_pairs(word):
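
As a brief hedged sketch of why the _chr shim above matters (example values only, not taken from the diff): on Python 2, chr covers only byte values 0-255 and returns a byte str, while the remapped bytes in the lookup table land at code points 256 and above, which need unichr there.

    import sys

    _chr = unichr if sys.version_info[0] == 2 else chr  # unichr exists only on Python 2

    shifted = 2 ** 8 + 0           # first remapped value the table produces
    print(repr(_chr(shifted)))     # u'\u0100' on Python 2, 'Ā' on Python 3
    # plain chr(256) raises ValueError on Python 2: chr() arg not in range(256)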
@@ -219,7 +220,7 @@ class GPT2Tokenizer(object):
""" Tokenize a string. """
bpe_tokens = []
for token in re.findall(self.pat, text):
- token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
+ token = ''.join(self.byte_encoder[ord(b)] for b in token.encode('utf-8'))
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
return bpe_tokens
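
For context, a hedged sketch of why ord appears in the new line (the byte_values helper is made up for illustration, not the project's code): iterating the result of str.encode('utf-8') yields one-character byte strings on Python 2 but plain integers on Python 3.

    import sys

    def byte_values(token):
        encoded = token.encode("utf-8")
        if sys.version_info[0] == 2:
            # Python 2: iterating a str yields 1-char byte strings
            return [ord(b) for b in encoded]
        # Python 3: iterating bytes already yields ints
        return list(encoded)

    print(byte_values(u"h\xe9llo"))   # 'héllo' -> [104, 195, 169, 108, 108, 111]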

View File

@@ -31,13 +31,14 @@ class GPT2TokenizationTest(unittest.TestCase):
vocab_tokens = dict(zip(vocab, range(len(vocab))))
merges = ["#version: 0.2", "l o", "lo w", "e r", ""]
with open("/tmp/openai_tokenizer_vocab_test.json", "w") as fp:
- json.dump(vocab_tokens, fp)
+ fp.write(json.dumps(vocab_tokens))
vocab_file = fp.name
with open("/tmp/openai_tokenizer_merges_test.txt", "w") as fp:
fp.write("\n".join(merges))
merges_file = fp.name
tokenizer = GPT2Tokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
+ print("encoder", tokenizer.byte_encoder)
os.remove(vocab_file)
os.remove(merges_file)

View File

@@ -32,7 +32,7 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
vocab_tokens = dict(zip(vocab, range(len(vocab))))
merges = ["#version: 0.2", "l o", "lo w", "e r</w>", ""]
with open("/tmp/openai_tokenizer_vocab_test.json", "w") as fp:
- json.dump(vocab_tokens, fp)
+ fp.write(json.dumps(vocab_tokens))
vocab_file = fp.name
with open("/tmp/openai_tokenizer_merges_test.txt", "w") as fp:
fp.write("\n".join(merges))