Add more tests on tokenizers serialization - fix bugs (#5056)
* update tests for fast tokenizers + fix small bug in saving/loading
* better tests on serialization
* fixing serialization
* comment cleanup
parent 0148c262e7
commit 7ac9110711
@@ -20,7 +20,7 @@ import itertools
 import logging
 import re
 import unicodedata
-from typing import List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 from .file_utils import add_end_docstrings
 from .tokenization_utils_base import (
@@ -155,10 +155,12 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
-        # Added tokens
-        self.added_tokens_encoder = {}
-        self.unique_added_tokens_encoder = []
-        self.added_tokens_decoder = {}
+        # Added tokens - We store this for both slow and fast tokenizers
+        # until the serialization of Fast tokenizers is updated
+        self.added_tokens_encoder: Dict[str, int] = {}
+        self.added_tokens_decoder: Dict[int, str] = {}
+        self.unique_no_split_tokens: List[str] = []
 
     @property
     def is_fast(self) -> bool:
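For illustration (not part of the diff), a minimal sketch of what the renamed bookkeeping holds after adding a token to a slow tokenizer; the checkpoint name and the printed ids are assumptions:

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
tokenizer.add_tokens(["new_tok"])

print(tokenizer.added_tokens_encoder)  # e.g. {'new_tok': 30522}, token -> index
print(tokenizer.added_tokens_decoder)  # e.g. {30522: 'new_tok'}, index -> token
print("new_tok" in tokenizer.unique_no_split_tokens)  # True: never split while tokenizing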
@@ -173,11 +175,14 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
         """ Returns the vocabulary as a dict of {token: index} pairs. `tokenizer.get_vocab()[token]` is equivalent to `tokenizer.convert_tokens_to_ids(token)` when `token` is in the vocab. """
         raise NotImplementedError()
 
+    def get_added_vocab(self) -> Dict[str, int]:
+        return self.added_tokens_encoder
+
     def __len__(self):
         """ Size of the full vocabulary with the added tokens """
         return self.vocab_size + len(self.added_tokens_encoder)
 
-    def add_tokens(self, new_tokens: Union[str, List[str]], special_token=False) -> int:
+    def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens=False) -> int:
         """
         Add a list of new tokens to the tokenizer class. If the new tokens are not in the
         vocabulary, they are added to it with indices starting from length of the current vocabulary.
@@ -199,16 +204,12 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
             print('We have added', num_added_toks, 'tokens')
             model.resize_token_embeddings(len(tokenizer))  # Notice: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
         """
-        if not new_tokens:
-            return 0
-
-        if not isinstance(new_tokens, list):
-            new_tokens = [new_tokens]
+        new_tokens = [str(tok) for tok in new_tokens]
 
         tokens_to_add = []
         for token in new_tokens:
-            assert isinstance(token, (str, AddedToken))
-            if self.init_kwargs.get("do_lower_case", False) and token not in self.all_special_tokens:
+            assert isinstance(token, str)
+            if not special_tokens and self.init_kwargs.get("do_lower_case", False):
                 token = token.lower()
             if (
                 token != self.unk_token
@@ -222,11 +223,15 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
         added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(tokens_to_add))
         added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
         self.added_tokens_encoder.update(added_tok_encoder)
-        self.unique_added_tokens_encoder = list(
-            set(self.added_tokens_encoder.keys()).union(set(self.all_special_tokens))
-        )
         self.added_tokens_decoder.update(added_tok_decoder)
+
+        # Make sure we don't split on any special tokens (even if they were already in the vocab before, e.g. for Albert)
+        if special_tokens:
+            self.unique_no_split_tokens = list(set(self.unique_no_split_tokens).union(set(new_tokens)))
+        else:
+            # Or on the newly added tokens
+            self.unique_no_split_tokens = list(set(self.unique_no_split_tokens).union(set(tokens_to_add)))
 
         return len(tokens_to_add)
 
     def num_special_tokens_to_add(self, pair=False):
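A sketch of what the new special_tokens flag changes on a lower-casing tokenizer (illustrative; assumes do_lower_case is passed explicitly so that it lands in init_kwargs):

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
tokenizer.add_tokens(["NewTok"])                              # lower-cased to 'newtok' before adding
tokenizer.add_tokens(["[NEW_SPECIAL]"], special_tokens=True)  # kept verbatim, never lower-cased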
@@ -340,7 +345,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
             for tok in tok_list:
                 tokenized_text = []
                 for sub_text in text_list:
-                    if sub_text not in self.unique_added_tokens_encoder:
+                    if sub_text not in self.unique_no_split_tokens:
                         tokenized_text += split_on_token(tok, sub_text)
                     else:
                         tokenized_text += [sub_text]
@@ -349,14 +354,14 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
            return list(
                itertools.chain.from_iterable(
                    (
-                        self._tokenize(token) if token not in self.unique_added_tokens_encoder else [token]
+                        self._tokenize(token) if token not in self.unique_no_split_tokens else [token]
                        for token in tokenized_text
                    )
                )
            )

-        added_tokens = self.unique_added_tokens_encoder
-        tokenized_text = split_on_tokens(added_tokens, text)
+        no_split_token = self.unique_no_split_tokens
+        tokenized_text = split_on_tokens(no_split_token, text)
         return tokenized_text
 
     def _tokenize(self, text, **kwargs):
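The effect of the renamed no-split list on tokenization, as an illustrative sketch (checkpoint name and output are assumptions based on the post-commit behavior):

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
tokenizer.add_tokens(["[NEW_TOK]"], special_tokens=True)
print(tokenizer.tokenize("hello [NEW_TOK] world"))
# ['hello', '[NEW_TOK]', 'world'] - the added token survives whole,
# the surrounding text still goes through the model's own tokenization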
@@ -62,9 +62,12 @@ PreTokenizedInputPair = Tuple[List[str], List[str]]
 EncodedInputPair = Tuple[List[int], List[int]]
 
 
+# Slow tokenizers used to be saved in three separate files
 SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json"
 ADDED_TOKENS_FILE = "added_tokens.json"
 TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
 
+# Fast tokenizers (provided by the HuggingFace tokenizers library) can be saved in a single file
+FULL_TOKENIZER_FILE = "tokenizer.json"
 
 
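For reference, the on-disk layout these constants describe, sketched for a slow tokenizer (exact file names vary by model; the listing shown is an assumption):

import os
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
tokenizer.add_tokens(["new_tok"])
tokenizer.save_pretrained("/tmp/tok")
print(sorted(os.listdir("/tmp/tok")))
# e.g. ['added_tokens.json', 'special_tokens_map.json', 'tokenizer_config.json', 'vocab.txt']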
@@ -589,10 +592,14 @@ class SpecialTokensMixin:
         self._additional_special_tokens = []
         self.verbose = verbose
 
+        # We directly set the hidden value to allow initialization with special tokens
+        # which are not yet in the vocabulary. Necessary for serialization/de-serialization
+        # TODO clean this up at some point (probably by switching to fast tokenizers)
         for key, value in kwargs.items():
             if key in self.SPECIAL_TOKENS_ATTRIBUTES:
                 if key == "additional_special_tokens":
                     assert isinstance(value, (list, tuple)) and all(isinstance(t, str) for t in value)
+                    setattr(self, key, value)
                 elif isinstance(value, (str, AddedToken)):
                     setattr(self, key, value)
                 else:
@@ -607,7 +614,7 @@ class SpecialTokensMixin:
         Return:
             Number of tokens added in the vocabulary during the operation.
         """
-        return self.add_tokens(self.all_special_tokens_extended, special_token=True)
+        return self.add_tokens(self.all_special_tokens_extended, special_tokens=True)
 
     def add_special_tokens(self, special_tokens_dict):
         """
@@ -652,22 +659,56 @@ class SpecialTokensMixin:
         added_tokens = 0
         for key, value in special_tokens_dict.items():
             assert key in self.SPECIAL_TOKENS_ATTRIBUTES
 
             if self.verbose:
                 logger.info("Assigning %s to the %s key of the tokenizer", value, key)
             setattr(self, key, value)
 
             if key == "additional_special_tokens":
                 assert isinstance(value, (list, tuple)) and all(isinstance(t, str) for t in value)
-                added_tokens += self.add_tokens(value, special_token=True)
+                added_tokens += self.add_tokens(value, special_tokens=True)
             else:
                 assert isinstance(value, str)
-                added_tokens += self.add_tokens([value], special_token=True)
+                added_tokens += self.add_tokens([value], special_tokens=True)
 
         return added_tokens
 
-    def add_tokens(self, value, special_token=False):
-        """ To be overriden by derived class to add a token in the vocabulary. """
-        pass
+    def add_tokens(self, new_tokens: Union[str, AddedToken, List[str], List[AddedToken]], special_tokens=False) -> int:
+        """
+        Add a list of new tokens to the tokenizer class. If the new tokens are not in the
+        vocabulary, they are added to it with indices starting from length of the current vocabulary.
+
+        Args:
+            new_tokens: string or list of strings or :class:`~transformers.AddedToken`. Each string is a token to add.
+                Tokens are only added if they are not already in the vocabulary. AddedToken wraps a string token to
+                let you personalize its behavior (whether this token should only match against a single word, whether
+                this token should strip all potential whitespace on the left side, whether this token should strip
+                all potential whitespace on the right side...).
+            special_tokens: can be used to specify if the token is a special token. This mostly changes the normalization
+                behavior (special tokens like CLS or [MASK] are usually not lower-cased for instance).
+
+                See details for :class:`~transformers.AddedToken` in the HuggingFace tokenizers library.
+
+        Returns:
+            Number of tokens added to the vocabulary.
+
+        Examples::
+
+            # Let's see how to increase the vocabulary of Bert model and tokenizer
+            tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
+            model = BertModel.from_pretrained('bert-base-uncased')
+
+            num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2'])
+            print('We have added', num_added_toks, 'tokens')
+            model.resize_token_embeddings(len(tokenizer))  # Notice: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
+        """
+        if not new_tokens:
+            return 0
+
+        if not isinstance(new_tokens, (list, tuple)):
+            new_tokens = [new_tokens]
+
+        return self._add_tokens(new_tokens, special_tokens=special_tokens)
 
     @property
     def bos_token(self):
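A usage sketch for the promoted add_tokens, including an AddedToken with custom matching behavior (the flags shown come from the tokenizers library; the checkpoint name is illustrative):

from tokenizers import AddedToken
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
num_added = tokenizer.add_tokens([AddedToken("new_tok", single_word=True, lstrip=True)])
print(num_added)  # 1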
@@ -964,11 +1005,13 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
 
     padding_side: str = "right"
 
-    def __init__(self, model_max_length=None, **kwargs):
-        super().__init__(**kwargs)
+    def __init__(self, **kwargs):
+        # inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``)
+        self.init_inputs = ()
+        self.init_kwargs = kwargs
 
         # For backward compatibility we fallback to set model_max_length from max_len if provided
-        model_max_length = model_max_length if model_max_length is not None else kwargs.pop("max_len", None)
+        model_max_length = kwargs.pop("model_max_length", kwargs.pop("max_len", None))
         self.model_max_length = model_max_length if model_max_length is not None else VERY_LARGE_INTEGER
 
         # Padding side is right by default and overridden in subclasses. If specified in the kwargs, it is changed.
|
|||
], f"Padding side should be selected between 'right' and 'left', current value: {self.padding_side}"
|
||||
self.model_input_names = kwargs.pop("model_input_names", self.model_input_names)
|
||||
|
||||
# inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``)
|
||||
self.init_inputs = ()
|
||||
self.init_kwargs = {}
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
def max_len(self) -> int:
|
||||
|
@ -1125,8 +1166,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
|
|||
"added_tokens_file": ADDED_TOKENS_FILE,
|
||||
"special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE,
|
||||
"tokenizer_config_file": TOKENIZER_CONFIG_FILE,
|
||||
"full_tokenizer_file": FULL_TOKENIZER_FILE,
|
||||
}
|
||||
# Look for the tokenizer main vocabulary files + the additional tokens files
|
||||
# Look for the tokenizer files
|
||||
for file_id, file_name in {**cls.vocab_files_names, **additional_files_names}.items():
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
|
||||
|
@ -1215,18 +1257,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
|
|||
|
||||
# Merge resolved_vocab_files arguments in init_kwargs.
|
||||
added_tokens_file = resolved_vocab_files.pop("added_tokens_file", None)
|
||||
special_tokens_map_file = resolved_vocab_files.pop("special_tokens_map_file", None)
|
||||
for args_name, file_path in resolved_vocab_files.items():
|
||||
if args_name not in init_kwargs:
|
||||
init_kwargs[args_name] = file_path
|
||||
if special_tokens_map_file is not None:
|
||||
with open(special_tokens_map_file, encoding="utf-8") as special_tokens_map_handle:
|
||||
special_tokens_map = json.load(special_tokens_map_handle)
|
||||
for key, value in special_tokens_map.items():
|
||||
if isinstance(value, dict):
|
||||
value = AddedToken(**value)
|
||||
if key not in init_kwargs:
|
||||
init_kwargs[key] = value
|
||||
|
||||
# Instantiate tokenizer.
|
||||
try:
|
||||
|
@ -1241,20 +1274,39 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
|
|||
tokenizer.init_inputs = init_inputs
|
||||
tokenizer.init_kwargs = init_kwargs
|
||||
|
||||
# update unique_added_tokens_encoder with special tokens for correct tokenization
|
||||
if hasattr(tokenizer, "unique_added_tokens_encoder"):
|
||||
union = set(tokenizer.unique_added_tokens_encoder).union(tokenizer.all_special_tokens)
|
||||
tokenizer.unique_added_tokens_encoder = list(union)
|
||||
# If there is a complementary special token map, load it
|
||||
special_tokens_map_file = resolved_vocab_files.pop("special_tokens_map_file", None)
|
||||
if special_tokens_map_file is not None:
|
||||
with open(special_tokens_map_file, encoding="utf-8") as special_tokens_map_handle:
|
||||
special_tokens_map = json.load(special_tokens_map_handle)
|
||||
|
||||
for key, value in special_tokens_map.items():
|
||||
if isinstance(value, dict):
|
||||
value = AddedToken(**value)
|
||||
setattr(tokenizer, key, value)
|
||||
|
||||
# Add supplementary tokens.
|
||||
special_tokens = tokenizer.all_special_tokens
|
||||
if added_tokens_file is not None:
|
||||
with open(added_tokens_file, encoding="utf-8") as added_tokens_handle:
|
||||
added_tok_encoder = json.load(added_tokens_handle)
|
||||
added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
|
||||
tokenizer.added_tokens_encoder.update(added_tok_encoder)
|
||||
tokenizer.added_tokens_decoder.update(added_tok_decoder)
|
||||
union = set(tokenizer.unique_added_tokens_encoder).union(tokenizer.added_tokens_encoder.keys())
|
||||
tokenizer.unique_added_tokens_encoder = list(union)
|
||||
|
||||
# Sort added tokens by index
|
||||
added_tok_encoder_sorted = list(sorted(added_tok_encoder.items(), key=lambda x: x[1]))
|
||||
|
||||
for token, index in added_tok_encoder_sorted:
|
||||
assert index == len(tokenizer), (
|
||||
f"Non-consecutive added token '{token}' found. "
|
||||
f"Should have index {len(tokenizer)} but has index {index} in saved vocabulary."
|
||||
)
|
||||
tokenizer.add_tokens(token, special_tokens=bool(token in special_tokens))
|
||||
|
||||
# Check all our special tokens are registrered as "no split" token (we don't cut them) and are in the vocab
|
||||
added_tokens = tokenizer.sanitize_special_tokens()
|
||||
if added_tokens:
|
||||
logger.warning(
|
||||
"Special tokens have been added in the vocabulary, make sure the associated word emebedding are fine-tuned or trained."
|
||||
)
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
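The sorting step matters because JSON objects carry no order guarantee; a minimal self-contained sketch of the invariant enforced above (the token names and ids are made up):

added_tok_encoder = {"tok_b": 30523, "tok_a": 30522}  # as deserialized, order not guaranteed
current_len = 30522  # stands in for len(tokenizer) before re-adding

# Re-adding in index order guarantees every token lands back on its saved id.
for token, index in sorted(added_tok_encoder.items(), key=lambda x: x[1]):
    assert index == current_len, f"non-consecutive added token {token!r}"
    current_len += 1  # mimics the vocabulary growing by one per added token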
@@ -1296,9 +1348,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
                     write_dict[key] = value
             f.write(json.dumps(write_dict, ensure_ascii=False))
 
-        if hasattr(self, "added_tokens_encoder") and len(self.added_tokens_encoder) > 0:
+        added_vocab = self.get_added_vocab()
+        if added_vocab:
             with open(added_tokens_file, "w", encoding="utf-8") as f:
-                out_str = json.dumps(self.added_tokens_encoder, ensure_ascii=False)
+                out_str = json.dumps(added_vocab, ensure_ascii=False)
                 f.write(out_str)
 
         vocab_files = self.save_vocabulary(save_directory)
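With save_pretrained now going through get_added_vocab(), only genuinely added tokens reach added_tokens.json; an illustrative round trip (paths and ids are assumptions):

import json
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
tokenizer.add_tokens(["bim", "bambam"])
tokenizer.save_pretrained("/tmp/tok2")
with open("/tmp/tok2/added_tokens.json") as f:
    print(json.load(f))  # e.g. {'bim': 30522, 'bambam': 30523}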
@@ -123,6 +123,12 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
     def get_vocab(self) -> Dict[str, int]:
         return self._tokenizer.get_vocab(with_added_tokens=True)
 
+    def get_added_vocab(self) -> Dict[str, int]:
+        base_vocab = self._tokenizer.get_vocab(with_added_tokens=False)
+        full_vocab = self._tokenizer.get_vocab(with_added_tokens=True)
+        added_vocab = dict((tok, index) for tok, index in full_vocab.items() if tok not in base_vocab)
+        return added_vocab
+
     def __len__(self) -> int:
         return self._tokenizer.get_vocab_size(with_added_tokens=True)
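The fast implementation recovers the added vocabulary as the difference between the full and base vocabs; a quick sketch of the resulting contract (the checkpoint name is an assumption):

from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
tokenizer.add_tokens(["bim"])
full_vocab = tokenizer.get_vocab()
assert tokenizer.get_added_vocab() == {"bim": full_vocab["bim"]}  # only tokens missing from the base vocab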
@@ -206,37 +212,8 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
     def convert_tokens_to_string(self, tokens: List[int], skip_special_tokens: bool = False) -> str:
         return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
 
-    def add_tokens(self, new_tokens: List[Union[str, AddedToken]], special_token=False) -> int:
-        """
-        Add a list of new tokens to the tokenizer class. If the new tokens are not in the
-        vocabulary, they are added to it with indices starting from length of the current vocabulary.
-
-        Args:
-            new_tokens: string or list of string or :class:`~transformers.AddedToken`. Each string is a token to add.
-                Tokens are only added if they are not already in the vocabulary. AddedToken wrap a string token to
-                let you personnalize it's behavior (Whether this token should only match against single word, whether
-                this token should strip all potential whitespaces on the left side, Whether this token should strip
-                all potential whitespaces on the right side...).
-
-                See details for :class:`~transformers.AddedToken` in HuggingFace tokenizers library.
-
-        Returns:
-            Number of tokens added to the vocabulary.
-
-        Examples::
-
-            # Let's see how to increase the vocabulary of Bert model and tokenizer
-            tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
-            model = BertModel.from_pretrained('bert-base-uncased')
-
-            num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2'])
-            print('We have added', num_added_toks, 'tokens')
-            model.resize_token_embeddings(len(tokenizer))  # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
-        """
-        if not isinstance(new_tokens, (list, tuple)):
-            new_tokens = [new_tokens]
-
-        if special_token:
+    def _add_tokens(self, new_tokens: List[Union[str, AddedToken]], special_tokens=False) -> int:
+        if special_tokens:
             return self._tokenizer.add_special_tokens(new_tokens)
 
         return self._tokenizer.add_tokens(new_tokens)
@@ -20,10 +20,10 @@ import re
 import shutil
 import tempfile
 from collections import OrderedDict
-from typing import TYPE_CHECKING, Dict, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Tuple, Union
 
 from tests.utils import require_tf, require_torch
-from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast
 
 
 if TYPE_CHECKING:
|
|||
output_ids = tokenizer.encode(output_txt, add_special_tokens=False)
|
||||
return output_txt, output_ids
|
||||
|
||||
def get_tokenizers(self, fast=True, **kwargs) -> PreTrainedTokenizer:
|
||||
def get_tokenizers(self, fast=True, **kwargs) -> List[PreTrainedTokenizerBase]:
|
||||
if fast and self.test_rust_tokenizer:
|
||||
return [self.get_tokenizer(**kwargs), self.get_rust_tokenizer(**kwargs)]
|
||||
return [self.get_tokenizer(**kwargs)]
|
||||
|
@ -101,7 +101,7 @@ class TokenizerTesterMixin:
|
|||
def get_tokenizer(self, **kwargs) -> PreTrainedTokenizer:
|
||||
return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_rust_tokenizer(self, **kwargs):
|
||||
def get_rust_tokenizer(self, **kwargs) -> PreTrainedTokenizerFast:
|
||||
raise NotImplementedError
|
||||
|
||||
# def get_input_output_texts(self) -> Tuple[str, str]:
|
||||
|
@@ -156,28 +156,62 @@ class TokenizerTesterMixin:
 
     def test_save_and_load_tokenizer(self):
        # safety check on max_len default value so we are sure the test works
-        tokenizers = self.get_tokenizers(fast=False)
+        tokenizers = self.get_tokenizers()
         for tokenizer in tokenizers:
             with self.subTest(f"{tokenizer.__class__.__name__}"):
                 self.assertNotEqual(tokenizer.max_len, 42)
 
         # Now let's start the test
-        tokenizers = self.get_tokenizers(fast=False, model_max_length=42)
+        tokenizers = self.get_tokenizers()
         for tokenizer in tokenizers:
             with self.subTest(f"{tokenizer.__class__.__name__}"):
-                sample_text = "He is very happy, UNwant\u00E9d,running"
+                # Isolate this from the other tests because we save additional tokens/etc
+                tmpdirname = tempfile.mkdtemp()
+
+                sample_text = " He is very happy, UNwant\u00E9d,running"
                 before_tokens = tokenizer.encode(sample_text, add_special_tokens=False)
+                before_vocab = tokenizer.get_vocab()
+                tokenizer.save_pretrained(tmpdirname)
 
-                tokenizer.save_pretrained(self.tmpdirname)
-                tokenizer = self.tokenizer_class.from_pretrained(self.tmpdirname)
-
-                after_tokens = tokenizer.encode(sample_text, add_special_tokens=False)
+                after_tokenizer = tokenizer.__class__.from_pretrained(tmpdirname)
+                after_tokens = after_tokenizer.encode(sample_text, add_special_tokens=False)
+                after_vocab = after_tokenizer.get_vocab()
                 self.assertListEqual(before_tokens, after_tokens)
+                self.assertDictEqual(before_vocab, after_vocab)
 
-                self.assertEqual(tokenizer.model_max_length, 42)
-                tokenizer = self.tokenizer_class.from_pretrained(self.tmpdirname, model_max_length=43)
+                shutil.rmtree(tmpdirname)
+
+        # Now let's start the test
+        tokenizers = self.get_tokenizers(model_max_length=42)
+        for tokenizer in tokenizers:
+            with self.subTest(f"{tokenizer.__class__.__name__}"):
+                # Isolate this from the other tests because we save additional tokens/etc
+                tmpdirname = tempfile.mkdtemp()
+
+                sample_text = " He is very happy, UNwant\u00E9d,running"
+                tokenizer.add_tokens(["bim", "bambam"])
+                additional_special_tokens = tokenizer.additional_special_tokens
+                additional_special_tokens.append("new_additional_special_token")
+                tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
+                before_tokens = tokenizer.encode(sample_text, add_special_tokens=False)
+                before_vocab = tokenizer.get_vocab()
+                tokenizer.save_pretrained(tmpdirname)
+
+                after_tokenizer = tokenizer.__class__.from_pretrained(tmpdirname)
+                after_tokens = after_tokenizer.encode(sample_text, add_special_tokens=False)
+                after_vocab = after_tokenizer.get_vocab()
+                self.assertListEqual(before_tokens, after_tokens)
+                self.assertDictEqual(before_vocab, after_vocab)
+                self.assertIn("bim", after_vocab)
+                self.assertIn("bambam", after_vocab)
+                self.assertIn("new_additional_special_token", after_tokenizer.additional_special_tokens)
+                self.assertEqual(after_tokenizer.model_max_length, 42)
+
+                tokenizer = tokenizer.__class__.from_pretrained(tmpdirname, model_max_length=43)
                 self.assertEqual(tokenizer.model_max_length, 43)
+
+                shutil.rmtree(tmpdirname)
 
     def test_pickle_tokenizer(self):
         """Google pickle __getstate__ __setstate__ if you are struggling with this."""
         tokenizers = self.get_tokenizers()
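Outside the test harness, the round-trip property this test pins down looks roughly like this (a sketch, assuming the post-commit serialization):

import tempfile
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
tokenizer.add_tokens(["bim", "bambam"])
tokenizer.add_special_tokens({"additional_special_tokens": ["new_additional_special_token"]})

with tempfile.TemporaryDirectory() as tmpdirname:
    tokenizer.save_pretrained(tmpdirname)
    reloaded = BertTokenizer.from_pretrained(tmpdirname)

assert tokenizer.get_vocab() == reloaded.get_vocab()
assert "new_additional_special_token" in reloaded.additional_special_tokens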
@@ -265,7 +299,10 @@ class TokenizerTesterMixin:
         all_size = len(tokenizer)
 
         self.assertNotEqual(vocab_size, 0)
-        self.assertEqual(vocab_size, all_size)
+
+        # We usually have added tokens from the start in tests because our vocab fixtures are
+        # smaller than the original vocabs - let's not assert this
+        # self.assertEqual(vocab_size, all_size)
 
         new_toks = ["aaaaa bbbbbb", "cccccccccdddddddd"]
         added_toks = tokenizer.add_tokens(new_toks)