Add more tests on tokenizers serialization - fix bugs (#5056)

* update tests for fast tokenizers + fix small bug in saving/loading

* better tests on serialization

* fixing serialization

* comment cleanup
Thomas Wolf 2020-06-24 21:53:08 +02:00 committed by GitHub
parent 0148c262e7
commit 7ac9110711
4 changed files with 170 additions and 98 deletions

View File

@@ -20,7 +20,7 @@ import itertools
import logging
import re
import unicodedata
from typing import List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union
from .file_utils import add_end_docstrings
from .tokenization_utils_base import (
@@ -155,10 +155,12 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
def __init__(self, **kwargs):
super().__init__(**kwargs)
# Added tokens
self.added_tokens_encoder = {}
self.unique_added_tokens_encoder = []
self.added_tokens_decoder = {}
# Added tokens - We store this for both slow and fast tokenizers
# until the serialization of Fast tokenizers is updated
self.added_tokens_encoder: Dict[str, int] = {}
self.added_tokens_decoder: Dict[int, str] = {}
self.unique_no_split_tokens: List[str] = []
@property
def is_fast(self) -> bool:
@@ -173,11 +175,14 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
""" Returns the vocabulary as a dict of {token: index} pairs. `tokenizer.get_vocab()[token]` is equivalent to `tokenizer.convert_tokens_to_ids(token)` when `token` is in the vocab. """
raise NotImplementedError()
def get_added_vocab(self) -> Dict[str, int]:
return self.added_tokens_encoder
def __len__(self):
""" Size of the full vocabulary with the added tokens """
return self.vocab_size + len(self.added_tokens_encoder)
def add_tokens(self, new_tokens: Union[str, List[str]], special_token=False) -> int:
def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens=False) -> int:
"""
Add a list of new tokens to the tokenizer class. If the new tokens are not in the
vocabulary, they are added to it with indices starting from the length of the current vocabulary.
@@ -199,16 +204,12 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
print('We have added', num_added_toks, 'tokens')
model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
"""
if not new_tokens:
return 0
if not isinstance(new_tokens, list):
new_tokens = [new_tokens]
new_tokens = [str(tok) for tok in new_tokens]
tokens_to_add = []
for token in new_tokens:
assert isinstance(token, (str, AddedToken))
if self.init_kwargs.get("do_lower_case", False) and token not in self.all_special_tokens:
assert isinstance(token, str)
if not special_tokens and self.init_kwargs.get("do_lower_case", False):
token = token.lower()
if (
token != self.unk_token
@@ -222,11 +223,15 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(tokens_to_add))
added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
self.added_tokens_encoder.update(added_tok_encoder)
self.unique_added_tokens_encoder = list(
set(self.added_tokens_encoder.keys()).union(set(self.all_special_tokens))
)
self.added_tokens_decoder.update(added_tok_decoder)
# Make sure we don't split on any special tokens (even if they were already in the vocab before, e.g. for Albert)
if special_tokens:
self.unique_no_split_tokens = list(set(self.unique_no_split_tokens).union(set(new_tokens)))
else:
# Or on the newly added tokens
self.unique_no_split_tokens = list(set(self.unique_no_split_tokens).union(set(tokens_to_add)))
return len(tokens_to_add)
def num_special_tokens_to_add(self, pair=False):
@@ -340,7 +345,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
for tok in tok_list:
tokenized_text = []
for sub_text in text_list:
if sub_text not in self.unique_added_tokens_encoder:
if sub_text not in self.unique_no_split_tokens:
tokenized_text += split_on_token(tok, sub_text)
else:
tokenized_text += [sub_text]
@@ -349,14 +354,14 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
return list(
itertools.chain.from_iterable(
(
self._tokenize(token) if token not in self.unique_added_tokens_encoder else [token]
self._tokenize(token) if token not in self.unique_no_split_tokens else [token]
for token in tokenized_text
)
)
)
added_tokens = self.unique_added_tokens_encoder
tokenized_text = split_on_tokens(added_tokens, text)
no_split_token = self.unique_no_split_tokens
tokenized_text = split_on_tokens(no_split_token, text)
return tokenized_text
def _tokenize(self, text, **kwargs):
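The rename from unique_added_tokens_encoder to unique_no_split_tokens matches what the list is actually used for: keeping added and special tokens from being split by the model-specific tokenization. Below is a minimal, self-contained sketch of that splitting pass, simplified from the logic above rather than copied from the library:

    # Standalone sketch: split raw text on "no split" tokens so they survive tokenization intact.
    def split_on_no_split_tokens(no_split_tokens, text):
        segments = [text]
        for tok in no_split_tokens:
            new_segments = []
            for segment in segments:
                if segment in no_split_tokens:
                    # Already an isolated added/special token: keep it as one piece.
                    new_segments.append(segment)
                    continue
                # Split around the token and re-insert it between the remaining pieces.
                pieces = segment.split(tok)
                for i, piece in enumerate(pieces):
                    if piece:
                        new_segments.append(piece)
                    if i < len(pieces) - 1:
                        new_segments.append(tok)
            segments = new_segments
        return segments

    print(split_on_no_split_tokens(["[MASK]", "new_tok"], "Hello [MASK] world new_tok !"))
    # ['Hello ', '[MASK]', ' world ', 'new_tok', ' !']

Each segment that is not itself a protected token would then be passed to the model-specific _tokenize, which is what the itertools.chain.from_iterable expression above does.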

View File

@@ -62,9 +62,12 @@ PreTokenizedInputPair = Tuple[List[str], List[str]]
EncodedInputPair = Tuple[List[int], List[int]]
# Slow tokenizers used to be saved in three separate files
SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json"
ADDED_TOKENS_FILE = "added_tokens.json"
TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
# Fast tokenizers (provided by HuggingFace's tokenizers library) can be saved in a single file
FULL_TOKENIZER_FILE = "tokenizer.json"
@@ -589,10 +592,14 @@ class SpecialTokensMixin:
self._additional_special_tokens = []
self.verbose = verbose
# We directly set the hidden value to allow initialization with special tokens
# which are not yet in the vocabulary. Necessary for serialization/de-serialization
# TODO clean this up at some point (probably by switching to fast tokenizers)
for key, value in kwargs.items():
if key in self.SPECIAL_TOKENS_ATTRIBUTES:
if key == "additional_special_tokens":
assert isinstance(value, (list, tuple)) and all(isinstance(t, str) for t in value)
setattr(self, key, value)
elif isinstance(value, (str, AddedToken)):
setattr(self, key, value)
else:
@@ -607,7 +614,7 @@ class SpecialTokensMixin:
Return:
Number of tokens added to the vocabulary during the operation.
"""
return self.add_tokens(self.all_special_tokens_extended, special_token=True)
return self.add_tokens(self.all_special_tokens_extended, special_tokens=True)
def add_special_tokens(self, special_tokens_dict):
"""
@@ -652,22 +659,56 @@ class SpecialTokensMixin:
added_tokens = 0
for key, value in special_tokens_dict.items():
assert key in self.SPECIAL_TOKENS_ATTRIBUTES
if self.verbose:
logger.info("Assigning %s to the %s key of the tokenizer", value, key)
setattr(self, key, value)
if key == "additional_special_tokens":
assert isinstance(value, (list, tuple)) and all(isinstance(t, str) for t in value)
added_tokens += self.add_tokens(value, special_token=True)
added_tokens += self.add_tokens(value, special_tokens=True)
else:
assert isinstance(value, str)
added_tokens += self.add_tokens([value], special_token=True)
added_tokens += self.add_tokens([value], special_tokens=True)
return added_tokens
def add_tokens(self, value, special_token=False):
""" To be overriden by derived class to add a token in the vocabulary. """
pass
def add_tokens(self, new_tokens: Union[str, AddedToken, List[str], List[AddedToken]], special_tokens=False) -> int:
"""
Add a list of new tokens to the tokenizer class. If the new tokens are not in the
vocabulary, they are added to it with indices starting from the length of the current vocabulary.
Args:
new_tokens: string, or list of strings or of :class:`~transformers.AddedToken`. Each string is a token to add.
Tokens are only added if they are not already in the vocabulary. An AddedToken wraps a string token to
let you personalize its behavior (whether this token should only match against a single word, whether
this token should strip all potential whitespace on the left side, whether this token should strip
all potential whitespace on the right side, ...).
special_tokens: can be used to specify whether the tokens are special tokens. This mostly changes the normalization
behavior (special tokens like CLS or [MASK] are usually not lower-cased, for instance).
See details for :class:`~transformers.AddedToken` in the HuggingFace tokenizers library.
Returns:
Number of tokens added to the vocabulary.
Examples::
# Let's see how to increase the vocabulary of Bert model and tokenizer
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')
num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2'])
print('We have added', num_added_toks, 'tokens')
model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
"""
if not new_tokens:
return 0
if not isinstance(new_tokens, (list, tuple)):
new_tokens = [new_tokens]
return self._add_tokens(new_tokens, special_tokens=special_tokens)
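For illustration only (not part of the diff), here is a hedged sketch of how the new special_tokens flag combines with the AddedToken options the docstring describes; the token strings are invented for the example and AddedToken is imported from the tokenizers library:

    from tokenizers import AddedToken
    from transformers import BertTokenizerFast

    tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

    # Plain string tokens: added as regular tokens, subject to the usual normalization.
    tokenizer.add_tokens(["new_tok1", "my_new-tok2"])

    # An AddedToken can customize matching; "<ent>" is a made-up placeholder token here.
    tokenizer.add_tokens([AddedToken("<ent>", single_word=True, lstrip=True)], special_tokens=True)

    print(len(tokenizer))  # full vocabulary size, including the added tokens

For slow tokenizers, passing special_tokens=True likewise skips the do_lower_case branch shown earlier in _add_tokens.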
@property
def bos_token(self):
@@ -964,11 +1005,13 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
padding_side: str = "right"
def __init__(self, model_max_length=None, **kwargs):
super().__init__(**kwargs)
def __init__(self, **kwargs):
# inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``)
self.init_inputs = ()
self.init_kwargs = kwargs
# For backward compatibility we fall back to setting model_max_length from max_len if provided
model_max_length = model_max_length if model_max_length is not None else kwargs.pop("max_len", None)
model_max_length = kwargs.pop("model_max_length", kwargs.pop("max_len", None))
self.model_max_length = model_max_length if model_max_length is not None else VERY_LARGE_INTEGER
# Padding side is right by default and overridden in subclasses. If specified in the kwargs, it is changed.
@@ -979,9 +1022,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
], f"Padding side should be selected between 'right' and 'left', current value: {self.padding_side}"
self.model_input_names = kwargs.pop("model_input_names", self.model_input_names)
# inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``)
self.init_inputs = ()
self.init_kwargs = {}
super().__init__(**kwargs)
@property
def max_len(self) -> int:
@@ -1125,8 +1166,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
"added_tokens_file": ADDED_TOKENS_FILE,
"special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE,
"tokenizer_config_file": TOKENIZER_CONFIG_FILE,
"full_tokenizer_file": FULL_TOKENIZER_FILE,
}
# Look for the tokenizer main vocabulary files + the additional tokens files
# Look for the tokenizer files
for file_id, file_name in {**cls.vocab_files_names, **additional_files_names}.items():
if os.path.isdir(pretrained_model_name_or_path):
full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
@@ -1215,18 +1257,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
# Merge resolved_vocab_files arguments in init_kwargs.
added_tokens_file = resolved_vocab_files.pop("added_tokens_file", None)
special_tokens_map_file = resolved_vocab_files.pop("special_tokens_map_file", None)
for args_name, file_path in resolved_vocab_files.items():
if args_name not in init_kwargs:
init_kwargs[args_name] = file_path
if special_tokens_map_file is not None:
with open(special_tokens_map_file, encoding="utf-8") as special_tokens_map_handle:
special_tokens_map = json.load(special_tokens_map_handle)
for key, value in special_tokens_map.items():
if isinstance(value, dict):
value = AddedToken(**value)
if key not in init_kwargs:
init_kwargs[key] = value
# Instantiate tokenizer.
try:
@@ -1241,20 +1274,39 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
tokenizer.init_inputs = init_inputs
tokenizer.init_kwargs = init_kwargs
# update unique_added_tokens_encoder with special tokens for correct tokenization
if hasattr(tokenizer, "unique_added_tokens_encoder"):
union = set(tokenizer.unique_added_tokens_encoder).union(tokenizer.all_special_tokens)
tokenizer.unique_added_tokens_encoder = list(union)
# If there is a complementary special token map, load it
special_tokens_map_file = resolved_vocab_files.pop("special_tokens_map_file", None)
if special_tokens_map_file is not None:
with open(special_tokens_map_file, encoding="utf-8") as special_tokens_map_handle:
special_tokens_map = json.load(special_tokens_map_handle)
for key, value in special_tokens_map.items():
if isinstance(value, dict):
value = AddedToken(**value)
setattr(tokenizer, key, value)
# Add supplementary tokens.
special_tokens = tokenizer.all_special_tokens
if added_tokens_file is not None:
with open(added_tokens_file, encoding="utf-8") as added_tokens_handle:
added_tok_encoder = json.load(added_tokens_handle)
added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
tokenizer.added_tokens_encoder.update(added_tok_encoder)
tokenizer.added_tokens_decoder.update(added_tok_decoder)
union = set(tokenizer.unique_added_tokens_encoder).union(tokenizer.added_tokens_encoder.keys())
tokenizer.unique_added_tokens_encoder = list(union)
# Sort added tokens by index
added_tok_encoder_sorted = list(sorted(added_tok_encoder.items(), key=lambda x: x[1]))
for token, index in added_tok_encoder_sorted:
assert index == len(tokenizer), (
f"Non-consecutive added token '{token}' found. "
f"Should have index {len(tokenizer)} but has index {index} in saved vocabulary."
)
tokenizer.add_tokens(token, special_tokens=bool(token in special_tokens))
# Check that all our special tokens are registered as "no split" tokens (we don't cut them) and are in the vocab
added_tokens = tokenizer.sanitize_special_tokens()
if added_tokens:
logger.warning(
"Special tokens have been added in the vocabulary, make sure the associated word emebedding are fine-tuned or trained."
)
return tokenizer
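added_tokens.json is a flat token-to-index mapping, so the loop above replays it in index order to land every token back on its saved id. A self-contained sketch of the consecutive-index check, with made-up file content and vocabulary size:

    import json

    # Example content of an added_tokens.json written by save_pretrained (values invented).
    added_tok_encoder = json.loads('{"bambam": 30523, "bim": 30522}')

    current_size = 30522  # stands in for len(tokenizer) before anything is re-added

    # Replay in index order, checking each index is consecutive with the current size.
    for token, index in sorted(added_tok_encoder.items(), key=lambda x: x[1]):
        assert index == current_size, (
            f"Non-consecutive added token '{token}' found. "
            f"Should have index {current_size} but has index {index} in saved vocabulary."
        )
        current_size += 1  # stands in for tokenizer.add_tokens(token)

    print(sorted(added_tok_encoder, key=added_tok_encoder.get))  # ['bim', 'bambam']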
@@ -1296,9 +1348,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
write_dict[key] = value
f.write(json.dumps(write_dict, ensure_ascii=False))
if hasattr(self, "added_tokens_encoder") and len(self.added_tokens_encoder) > 0:
added_vocab = self.get_added_vocab()
if added_vocab:
with open(added_tokens_file, "w", encoding="utf-8") as f:
out_str = json.dumps(self.added_tokens_encoder, ensure_ascii=False)
out_str = json.dumps(added_vocab, ensure_ascii=False)
f.write(out_str)
vocab_files = self.save_vocabulary(save_directory)
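Putting the two sides together: save_pretrained now writes whatever get_added_vocab() reports into added_tokens.json, and from_pretrained replays it. A hedged end-to-end sketch of the round trip this commit tests (model name taken from the docstring example, token strings from the new test below):

    import tempfile

    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    tokenizer.add_tokens(["bim", "bambam"])
    tokenizer.add_special_tokens({"additional_special_tokens": ["new_additional_special_token"]})

    before = tokenizer.encode("bim and bambam", add_special_tokens=False)

    with tempfile.TemporaryDirectory() as tmpdir:
        tokenizer.save_pretrained(tmpdir)  # writes added_tokens.json, special_tokens_map.json, tokenizer_config.json and the vocab files
        reloaded = BertTokenizer.from_pretrained(tmpdir)

    after = reloaded.encode("bim and bambam", add_special_tokens=False)
    assert before == after
    assert "bim" in reloaded.get_added_vocab()
    assert "new_additional_special_token" in reloaded.additional_special_tokens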

View File

@@ -123,6 +123,12 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
def get_vocab(self) -> Dict[str, int]:
return self._tokenizer.get_vocab(with_added_tokens=True)
def get_added_vocab(self) -> Dict[str, int]:
base_vocab = self._tokenizer.get_vocab(with_added_tokens=False)
full_vocab = self._tokenizer.get_vocab(with_added_tokens=True)
added_vocab = dict((tok, index) for tok, index in full_vocab.items() if tok not in base_vocab)
return added_vocab
def __len__(self) -> int:
return self._tokenizer.get_vocab_size(with_added_tokens=True)
@@ -206,37 +212,8 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
def convert_tokens_to_string(self, tokens: List[int], skip_special_tokens: bool = False) -> str:
return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
def add_tokens(self, new_tokens: List[Union[str, AddedToken]], special_token=False) -> int:
"""
Add a list of new tokens to the tokenizer class. If the new tokens are not in the
vocabulary, they are added to it with indices starting from the length of the current vocabulary.
Args:
new_tokens: string, or list of strings or of :class:`~transformers.AddedToken`. Each string is a token to add.
Tokens are only added if they are not already in the vocabulary. An AddedToken wraps a string token to
let you personalize its behavior (whether this token should only match against a single word, whether
this token should strip all potential whitespace on the left side, whether this token should strip
all potential whitespace on the right side, ...).
See details for :class:`~transformers.AddedToken` in the HuggingFace tokenizers library.
Returns:
Number of tokens added to the vocabulary.
Examples::
# Let's see how to increase the vocabulary of Bert model and tokenizer
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')
num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2'])
print('We have added', num_added_toks, 'tokens')
model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
"""
if not isinstance(new_tokens, (list, tuple)):
new_tokens = [new_tokens]
if special_token:
def _add_tokens(self, new_tokens: List[Union[str, AddedToken]], special_tokens=False) -> int:
if special_tokens:
return self._tokenizer.add_special_tokens(new_tokens)
return self._tokenizer.add_tokens(new_tokens)
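On the fast side, _add_tokens simply forwards to the Rust backend, picking between its regular and special token registries. A short hedged usage sketch (token strings invented for the example):

    from transformers import BertTokenizerFast

    fast_tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

    # Regular added tokens go through the backend's add_tokens().
    fast_tokenizer.add_tokens(["bim", "bambam"])

    # special_tokens=True routes through the backend's add_special_tokens() instead, so the
    # tokens are also registered as special (e.g. skipped when decoding with skip_special_tokens=True).
    fast_tokenizer.add_tokens(["<ent>"], special_tokens=True)  # "<ent>" is a made-up token

    # get_added_vocab() above is computed as the difference between the full and base vocabularies.
    print(fast_tokenizer.get_added_vocab())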

View File

@@ -20,10 +20,10 @@ import re
import shutil
import tempfile
from collections import OrderedDict
from typing import TYPE_CHECKING, Dict, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
from tests.utils import require_tf, require_torch
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast
if TYPE_CHECKING:
@@ -93,7 +93,7 @@ class TokenizerTesterMixin:
output_ids = tokenizer.encode(output_txt, add_special_tokens=False)
return output_txt, output_ids
def get_tokenizers(self, fast=True, **kwargs) -> PreTrainedTokenizer:
def get_tokenizers(self, fast=True, **kwargs) -> List[PreTrainedTokenizerBase]:
if fast and self.test_rust_tokenizer:
return [self.get_tokenizer(**kwargs), self.get_rust_tokenizer(**kwargs)]
return [self.get_tokenizer(**kwargs)]
@@ -101,7 +101,7 @@ class TokenizerTesterMixin:
def get_tokenizer(self, **kwargs) -> PreTrainedTokenizer:
return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
def get_rust_tokenizer(self, **kwargs):
def get_rust_tokenizer(self, **kwargs) -> PreTrainedTokenizerFast:
raise NotImplementedError
# def get_input_output_texts(self) -> Tuple[str, str]:
@@ -156,28 +156,62 @@ class TokenizerTesterMixin:
def test_save_and_load_tokenizer(self):
# safety check on max_len default value so we are sure the test works
tokenizers = self.get_tokenizers(fast=False)
tokenizers = self.get_tokenizers()
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
self.assertNotEqual(tokenizer.max_len, 42)
# Now let's start the test
tokenizers = self.get_tokenizers(fast=False, model_max_length=42)
tokenizers = self.get_tokenizers()
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
# Isolate this from the other tests because we save additional tokens/etc
tmpdirname = tempfile.mkdtemp()
sample_text = " He is very happy, UNwant\u00E9d,running"
before_tokens = tokenizer.encode(sample_text, add_special_tokens=False)
before_vocab = tokenizer.get_vocab()
tokenizer.save_pretrained(tmpdirname)
tokenizer.save_pretrained(self.tmpdirname)
tokenizer = self.tokenizer_class.from_pretrained(self.tmpdirname)
after_tokens = tokenizer.encode(sample_text, add_special_tokens=False)
after_tokenizer = tokenizer.__class__.from_pretrained(tmpdirname)
after_tokens = after_tokenizer.encode(sample_text, add_special_tokens=False)
after_vocab = after_tokenizer.get_vocab()
self.assertListEqual(before_tokens, after_tokens)
self.assertDictEqual(before_vocab, after_vocab)
self.assertEqual(tokenizer.model_max_length, 42)
tokenizer = self.tokenizer_class.from_pretrained(self.tmpdirname, model_max_length=43)
shutil.rmtree(tmpdirname)
# Now let's start the test
tokenizers = self.get_tokenizers(model_max_length=42)
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
# Isolate this from the other tests because we save additional tokens/etc
tmpdirname = tempfile.mkdtemp()
sample_text = " He is very happy, UNwant\u00E9d,running"
tokenizer.add_tokens(["bim", "bambam"])
additional_special_tokens = tokenizer.additional_special_tokens
additional_special_tokens.append("new_additional_special_token")
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
before_tokens = tokenizer.encode(sample_text, add_special_tokens=False)
before_vocab = tokenizer.get_vocab()
tokenizer.save_pretrained(tmpdirname)
after_tokenizer = tokenizer.__class__.from_pretrained(tmpdirname)
after_tokens = after_tokenizer.encode(sample_text, add_special_tokens=False)
after_vocab = after_tokenizer.get_vocab()
self.assertListEqual(before_tokens, after_tokens)
self.assertDictEqual(before_vocab, after_vocab)
self.assertIn("bim", after_vocab)
self.assertIn("bambam", after_vocab)
self.assertIn("new_additional_special_token", after_tokenizer.additional_special_tokens)
self.assertEqual(after_tokenizer.model_max_length, 42)
tokenizer = tokenizer.__class__.from_pretrained(tmpdirname, model_max_length=43)
self.assertEqual(tokenizer.model_max_length, 43)
shutil.rmtree(tmpdirname)
def test_pickle_tokenizer(self):
"""Google pickle __getstate__ __setstate__ if you are struggling with this."""
tokenizers = self.get_tokenizers()
@@ -265,7 +299,10 @@ class TokenizerTesterMixin:
all_size = len(tokenizer)
self.assertNotEqual(vocab_size, 0)
self.assertEqual(vocab_size, all_size)
# We usually have added tokens from the start in tests because our vocab fixtures are
# smaller than the original vocabs - let's not assert this
# self.assertEqual(vocab_size, all_size)
new_toks = ["aaaaa bbbbbb", "cccccccccdddddddd"]
added_toks = tokenizer.add_tokens(new_toks)