Add standardized get_vocab method to tokenizers

Joe Davison 2020-02-22 12:09:01 -05:00 committed by GitHub
commit c36416e53c
12 changed files with 62 additions and 0 deletions
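As a quick illustration of the new API, here is a minimal usage sketch (the checkpoint name and added token are only examples; any of the tokenizers touched below behaves the same way):

    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    vocab = tokenizer.get_vocab()  # plain {token: index} dict covering the full vocabulary

    # The mapping also covers tokens added after loading, so its size matches len(tokenizer).
    tokenizer.add_tokens(["my_new_token"])
    assert len(tokenizer.get_vocab()) == len(tokenizer)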


@@ -114,6 +114,11 @@ class AlbertTokenizer(PreTrainedTokenizer):
    def vocab_size(self):
        return len(self.sp_model)

    def get_vocab(self):
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def __getstate__(self):
        state = self.__dict__.copy()
        state["sp_model"] = None


@@ -195,6 +195,9 @@ class BertTokenizer(PreTrainedTokenizer):
    def vocab_size(self):
        return len(self.vocab)

    def get_vocab(self):
        return dict(self.vocab, **self.added_tokens_encoder)

    def _tokenize(self, text):
        split_tokens = []
        if self.do_basic_tokenize:
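The dict-backed tokenizers (BERT here, and CTRL, GPT-2, OpenAI GPT, Transformer-XL and XLM below) take the shorter route and merge their existing vocabulary with added_tokens_encoder via dict(base, **added). A toy illustration of that merge semantics (the entries are made up, not real vocab items):

    base_vocab = {"[PAD]": 0, "hello": 1}   # stands in for self.vocab / self.encoder
    added_tokens = {"<new_token>": 2}       # stands in for self.added_tokens_encoder

    merged = dict(base_vocab, **added_tokens)  # new dict: base entries copied, added entries overlaid
    print(merged)  # {'[PAD]': 0, 'hello': 1, '<new_token>': 2}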


@@ -147,6 +147,9 @@ class CTRLTokenizer(PreTrainedTokenizer):
    def vocab_size(self):
        return len(self.encoder)

    def get_vocab(self):
        return dict(self.encoder, **self.added_tokens_encoder)

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]


@@ -149,6 +149,9 @@ class GPT2Tokenizer(PreTrainedTokenizer):
    def vocab_size(self):
        return len(self.encoder)

    def get_vocab(self):
        return dict(self.encoder, **self.added_tokens_encoder)

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]


@@ -125,6 +125,9 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
    def vocab_size(self):
        return len(self.encoder)

    def get_vocab(self):
        return dict(self.encoder, **self.added_tokens_encoder)

    def bpe(self, token):
        word = tuple(token[:-1]) + (token[-1] + "</w>",)
        if token in self.cache:


@@ -119,6 +119,11 @@ class T5Tokenizer(PreTrainedTokenizer):
    def vocab_size(self):
        return self.sp_model.get_piece_size() + self._extra_ids

    def get_vocab(self):
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def __getstate__(self):
        state = self.__dict__.copy()
        state["sp_model"] = None


@@ -273,6 +273,9 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
    def vocab_size(self):
        return len(self.idx2sym)

    def get_vocab(self):
        return dict(self.sym2idx, **self.added_tokens_encoder)

    def _tokenize(self, line, add_eos=False, add_double_eos=False):
        line = line.strip()
        # convert to lower case


@@ -286,6 +286,10 @@ class PreTrainedTokenizer(object):
        """ Ids of all the additional special tokens in the vocabulary (list of integers). Log an error if used while not having been set. """
        return self.convert_tokens_to_ids(self.additional_special_tokens)

    def get_vocab(self):
        """ Returns the vocabulary as a dict of {token: index} pairs. `tokenizer.get_vocab()[token]` is equivalent to `tokenizer.convert_tokens_to_ids(token)` when `token` is in the vocab. """
        raise NotImplementedError()

    def __init__(self, max_len=None, **kwargs):
        self._bos_token = None
        self._eos_token = None
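The base class only documents the contract and raises NotImplementedError, leaving each subclass to supply the mapping. The documented equivalence can be checked directly against any concrete tokenizer; GPT-2 is just an example here:

    from transformers import GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    vocab = tokenizer.get_vocab()

    # For an in-vocabulary token, get_vocab()[token] and convert_tokens_to_ids(token) agree.
    token = tokenizer.tokenize("hello world")[0]
    assert vocab[token] == tokenizer.convert_tokens_to_ids(token)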


@@ -662,6 +662,9 @@ class XLMTokenizer(PreTrainedTokenizer):
    def vocab_size(self):
        return len(self.encoder)

    def get_vocab(self):
        return dict(self.encoder, **self.added_tokens_encoder)

    def bpe(self, token):
        word = tuple(token[:-1]) + (token[-1] + "</w>",)
        if token in self.cache:


@@ -190,6 +190,11 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
    def vocab_size(self):
        return len(self.sp_model) + len(self.fairseq_tokens_to_ids)

    def get_vocab(self):
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text):
        return self.sp_model.EncodeAsPieces(text)


@@ -114,6 +114,11 @@ class XLNetTokenizer(PreTrainedTokenizer):
    def vocab_size(self):
        return len(self.sp_model)

    def get_vocab(self):
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def __getstate__(self):
        state = self.__dict__.copy()
        state["sp_model"] = None


@@ -542,3 +542,23 @@ class TokenizerTesterMixin:
        print(new_tokenizer.init_kwargs)
        assert tokenizer.init_kwargs["random_argument"] is True
        assert new_tokenizer.init_kwargs["random_argument"] is False

    def test_get_vocab(self):
        tokenizer = self.get_tokenizer()
        vocab = tokenizer.get_vocab()

        self.assertIsInstance(vocab, dict)
        self.assertEqual(len(vocab), len(tokenizer))

        for word, ind in vocab.items():
            self.assertEqual(tokenizer.convert_tokens_to_ids(word), ind)
            self.assertEqual(tokenizer.convert_ids_to_tokens(ind), word)

        tokenizer.add_tokens(["asdfasdfasdfasdf"])
        vocab = tokenizer.get_vocab()
        self.assertIsInstance(vocab, dict)
        self.assertEqual(len(vocab), len(tokenizer))

        for word, ind in vocab.items():
            self.assertEqual(tokenizer.convert_tokens_to_ids(word), ind)
            self.assertEqual(tokenizer.convert_ids_to_tokens(ind), word)