add fix for serialization of tokenizer
commit 4f8b5f687c
parent d9184620f9
@@ -182,6 +182,21 @@ class XLNetTokenizer(object):
     def __len__(self):
         return len(self.encoder) + len(self.special_tokens)
 
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["sp_model"] = None
+        return state
+
+    def __setstate__(self, d):
+        self.__dict__ = d
+        try:
+            import sentencepiece as spm
+        except ImportError:
+            logger.warning("You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece"
+                           "pip install sentencepiece")
+        self.sp_model = spm.SentencePieceProcessor()
+        self.sp_model.Load(self.vocab_file)
+
     def set_special_tokens(self, special_tokens):
         """ Add a list of additional tokens to the encoder.
             The additional tokens are indexed starting from the last index of the
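For context, the hunk above uses the standard pattern for making an object with an unpicklable member survive pickle: drop the member in __getstate__ and rebuild it from plain attributes in __setstate__ (here the SentencePiece processor is reloaded from self.vocab_file). The snippet below is a minimal, hypothetical sketch of that pattern; ToyTokenizer and its lambda sp_model are stand-ins, not code from this repository.

import pickle

class ToyTokenizer(object):
    """Stand-in for a tokenizer holding an unpicklable member (sp_model)."""

    def __init__(self, vocab_file="vocab.model"):
        self.vocab_file = vocab_file
        # lambdas cannot be pickled, much like a loaded SentencePiece processor
        self.sp_model = lambda text: text.lower().split()

    def __getstate__(self):
        state = self.__dict__.copy()
        state["sp_model"] = None      # strip the unpicklable member before pickling
        return state

    def __setstate__(self, d):
        self.__dict__ = d
        # rebuild the stripped member from the plain state that survived pickling
        self.sp_model = lambda text: text.lower().split()

tok = ToyTokenizer()
restored = pickle.loads(pickle.dumps(tok))   # round-trips without error
assert restored.sp_model("Munich and Berlin") == ["munich", "and", "berlin"]

Rebuilding the processor at unpickle time keeps the serialized blob small: only vocab_file and the other plain attributes are stored, and the heavy model is reloaded on demand.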
@@ -15,11 +15,17 @@
 from __future__ import absolute_import, division, print_function, unicode_literals
 
 import os
 import sys
 import unittest
 from io import open
 import shutil
+import pytest
 
+if sys.version_info[0] == 2:
+    import cPickle as pickle
+else:
+    import pickle
+
 from pytorch_pretrained_bert.tokenization_xlnet import (XLNetTokenizer,
                                                         PRETRAINED_VOCAB_ARCHIVE_MAP,
                                                         SPIECE_UNDERLINE)
@@ -43,8 +49,6 @@ class XLNetTokenizationTest(unittest.TestCase):
         vocab_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path)
         tokenizer = tokenizer.from_pretrained(vocab_path,
                                               keep_accents=True)
-        os.remove(vocab_file)
-        os.remove(special_tokens_file)
 
         tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
         self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
@@ -65,6 +69,22 @@ class XLNetTokenizationTest(unittest.TestCase):
                                       SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's',
                                       u'<unk>', u'.'])
 
+        text = "Munich and Berlin are nice cities"
+        filename = u"/tmp/tokenizer.bin"
+
+        subwords = tokenizer.tokenize(text)
+
+        pickle.dump(tokenizer, open(filename, "wb"))
+
+        tokenizer_new = pickle.load(open(filename, "rb"))
+        subwords_loaded = tokenizer_new.tokenize(text)
+
+        self.assertListEqual(subwords, subwords_loaded)
+
+        os.remove(filename)
+        os.remove(vocab_file)
+        os.remove(special_tokens_file)
+
     @pytest.mark.slow
     def test_tokenizer_from_pretrained(self):
         cache_dir = "/tmp/pytorch_pretrained_bert_test/"
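As a usage note, the round trip the new test performs can also be written with context managers so both file handles are closed deterministically. The sketch below assumes the same XLNetTokenizer API the test exercises; the vocabulary directory "/tmp/xlnet_vocab/" is a placeholder for wherever save_vocabulary wrote its files, not a path from this commit.

import pickle
from pytorch_pretrained_bert.tokenization_xlnet import XLNetTokenizer

# placeholder directory previously populated by tokenizer.save_vocabulary(...)
tokenizer = XLNetTokenizer.from_pretrained("/tmp/xlnet_vocab/", keep_accents=True)

with open("/tmp/tokenizer.bin", "wb") as f:
    pickle.dump(tokenizer, f)          # __getstate__ drops the SentencePiece handle

with open("/tmp/tokenizer.bin", "rb") as f:
    tokenizer_new = pickle.load(f)     # __setstate__ reloads it from vocab_file

text = "Munich and Berlin are nice cities"
assert tokenizer_new.tokenize(text) == tokenizer.tokenize(text)

Unlike the bare open() calls in the test, the with blocks flush the dump and close both handles before the files are removed, which matters on interpreters without prompt refcount-based cleanup.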