add fix for serialization of tokenizer

thomwolf 2019-06-29 23:35:21 +02:00
parent d9184620f9
commit 4f8b5f687c
2 changed files with 37 additions and 2 deletions

@@ -182,6 +182,21 @@ class XLNetTokenizer(object):
     def __len__(self):
         return len(self.encoder) + len(self.special_tokens)
 
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["sp_model"] = None
+        return state
+
+    def __setstate__(self, d):
+        self.__dict__ = d
+        try:
+            import sentencepiece as spm
+        except ImportError:
+            logger.warning("You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece"
+                           " pip install sentencepiece")
+        self.sp_model = spm.SentencePieceProcessor()
+        self.sp_model.Load(self.vocab_file)
+
     def set_special_tokens(self, special_tokens):
         """ Add a list of additional tokens to the encoder.
             The additional tokens are indexed starting from the last index of the
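
Background on why sp_model is dropped from the pickled state: SentencePieceProcessor is a SWIG wrapper around a native C++ object, and with the sentencepiece releases available at the time of this commit it cannot be pickled directly, so __getstate__ replaces it with None and __setstate__ rebuilds it from self.vocab_file. A minimal sketch of the failure mode, not part of the commit; the local model path spiece.model is an assumption:

import pickle

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.Load("spiece.model")  # assumed local SentencePiece model file

try:
    pickle.dumps(sp)
except TypeError as exc:
    # With the sentencepiece builds this code targets, the SWIG-backed
    # processor has no pickle support, which is why the tokenizer drops
    # it in __getstate__ and reloads it in __setstate__.
    print("cannot pickle SentencePieceProcessor:", exc)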

@@ -15,11 +15,17 @@
 from __future__ import absolute_import, division, print_function, unicode_literals
 
 import os
+import sys
 import unittest
 from io import open
 import shutil
 import pytest
+
+if sys.version_info[0] == 2:
+    import cPickle as pickle
+else:
+    import pickle
 
 from pytorch_pretrained_bert.tokenization_xlnet import (XLNetTokenizer,
                                                          PRETRAINED_VOCAB_ARCHIVE_MAP,
                                                          SPIECE_UNDERLINE)
@@ -43,8 +49,6 @@ class XLNetTokenizationTest(unittest.TestCase):
         vocab_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path)
         tokenizer = tokenizer.from_pretrained(vocab_path,
                                               keep_accents=True)
-        os.remove(vocab_file)
-        os.remove(special_tokens_file)
 
         tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
         self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
@@ -65,6 +69,22 @@
                                       SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's',
                                       u'<unk>', u'.'])
 
+        text = "Munich and Berlin are nice cities"
+        filename = u"/tmp/tokenizer.bin"
+
+        subwords = tokenizer.tokenize(text)
+
+        pickle.dump(tokenizer, open(filename, "wb"))
+
+        tokenizer_new = pickle.load(open(filename, "rb"))
+        subwords_loaded = tokenizer_new.tokenize(text)
+
+        self.assertListEqual(subwords, subwords_loaded)
+
+        os.remove(filename)
+        os.remove(vocab_file)
+        os.remove(special_tokens_file)
+
     @pytest.mark.slow
     def test_tokenizer_from_pretrained(self):
         cache_dir = "/tmp/pytorch_pretrained_bert_test/"
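
For reference, the round trip exercised by the new test can also be reproduced outside the test suite. This is a minimal sketch, not part of the commit; it assumes a local SentencePiece model file at spiece.model, that XLNetTokenizer takes the model path as its first constructor argument (as the tests do), and uses tempfile instead of the fixed /tmp/tokenizer.bin path:

import os
import pickle
import tempfile

from pytorch_pretrained_bert.tokenization_xlnet import XLNetTokenizer

# Build a tokenizer from a local SentencePiece model (path is an assumption).
tokenizer = XLNetTokenizer("spiece.model", keep_accents=True)
text = u"Munich and Berlin are nice cities"

# Serialize and restore: __getstate__ omits sp_model and __setstate__ reloads
# it from tokenizer.vocab_file, so tokenization survives the round trip.
fd, path = tempfile.mkstemp(suffix=".bin")
with os.fdopen(fd, "wb") as handle:
    pickle.dump(tokenizer, handle)
with open(path, "rb") as handle:
    restored = pickle.load(handle)
os.remove(path)

assert tokenizer.tokenize(text) == restored.tokenize(text)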