Tokenizers v3.0.0 (#3185)

* Renamed num_added_tokens to num_special_tokens_to_add

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
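
A minimal sketch of the renamed helper (example only, assuming a BERT-style checkpoint; the exact counts depend on the model's special tokens):

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Formerly tokenizer.num_added_tokens(...): how many special tokens get added
# around a single sequence or a pair of sequences.
single = tokenizer.num_special_tokens_to_add(pair=False)  # e.g. 2 for BERT: [CLS] ... [SEP]
pair = tokenizer.num_special_tokens_to_add(pair=True)     # e.g. 3 for BERT: [CLS] ... [SEP] ... [SEP]

# Handy for computing how much room is left for actual content.
max_content_length = tokenizer.max_len - single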

* Cherry-pick: Partially fix space-only input without special tokens added to the output (#3091)

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Added property is_fast on PreTrainedTokenizer and PreTrainedTokenizerFast

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
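
A quick illustration of the new flag (sketch; any slow/fast tokenizer pair works the same way):

from transformers import BertTokenizer, BertTokenizerFast

slow = BertTokenizer.from_pretrained("bert-base-uncased")
fast = BertTokenizerFast.from_pretrained("bert-base-uncased")

assert slow.is_fast is False  # pure-Python implementation
assert fast.is_fast is True   # backed by the Rust `tokenizers` library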

* Make fast tokenizer unittests work on Windows.

* Entirely refactored the unittests for fast tokenizers.

* Remove ABC class for CommonFastTokenizerTest

* Added embeded_special_tokens tests from allenai @dirkgr

* Make embeded_special_tokens tests from allenai more generic

* Uniformize vocab_size as a property for both Fast and normal tokenizers

* Move special tokens handling out of PreTrainedTokenizer (into SpecialTokensMixin)

* Ensure providing None input raises the same ValueError as the Python tokenizer, plus tests.

* Fix invalid input for assert_padding when testing batch_encode_plus

* Move add_special_tokens from the constructor to a parameter of the tokenize/encode/[batch_]encode_plus methods.

* Ensure tokenize() correctly forwards add_special_tokens to Rust.
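
A minimal sketch of the call-time parameter (assuming a fast BERT tokenizer; per the later entry below, tokenize() ends up defaulting to add_special_tokens=False while encode/encode_plus keep defaulting to True):

from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

tokenizer.tokenize("Hello world")                           # no [CLS]/[SEP]
tokenizer.tokenize("Hello world", add_special_tokens=True)  # with [CLS]/[SEP]

# encode/encode_plus take the same flag per call instead of at construction time.
ids = tokenizer.encode_plus("Hello world", add_special_tokens=True)["input_ids"]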

* Adding None checking on top of encode / encode_batch for TransfoXLTokenizerFast.
Avoid stripping None values.

* Unittests ensure tokenize() also raises a ValueError if provided None

* Added add_special_tokens unittest for all supported models.

* Style

* Make sure TransfoXL tests run only if PyTorch is available.

* Split up tokenizers tests for each model type.

* Fix invalid unittest with new tokenizers API.

* Filter out Roberta openai detector models from unittests.

* Introduce BatchEncoding on fast tokenizers path.

This new structure exposes all the mappings retrieved from Rust.
It also keeps the current behavior with model forward.
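
A minimal sketch of working with the returned BatchEncoding (assuming a fast tokenizer; on the slow path the extra mapping helpers raise ValueError because no Rust Encoding is attached):

from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
batch = tokenizer.batch_encode_plus(["Hello world", "Tokenizers are fast"])

batch["input_ids"]         # dict-style access is unchanged, so model(**batch) keeps working
batch[0]                   # an int index returns the underlying Rust Encoding for sentence 0
batch.char_to_token(0, 6)  # index of the token covering character 6 of the first sentence
batch.encodings            # list of Encoding objects, or None when produced by a Python tokenizer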

* Introduce BatchEncoding on slow tokenizers path.

Backward compatibility.

* Improve error message on BatchEncoding for slow path

* Make add_prefix_space True by default on the fast Roberta tokenizer to match the Python tokenizer in the majority of cases.
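
Sketch of the effect (assuming roberta-base; RoBERTa's byte-level BPE marks a leading space with 'Ġ', so prepending one makes the first word tokenize the same way it would mid-sentence):

from transformers import RobertaTokenizerFast

# add_prefix_space now defaults to True on the fast tokenizer.
tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

tokenizer.tokenize("Hello world")  # the first token is expected to carry the space marker, e.g. 'ĠHello'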

* Style and format.

* Added typing on all methods for PreTrainedTokenizerFast

* Style and format

* Added path for feeding pretokenized (List[str]) input to PreTrainedTokenizerFast.

* Style and format

* encode_plus now supports pretokenized inputs.

* Remove user warning about add_special_tokens when working on pretokenized inputs.

* Always go through the post processor.

* Added support for pretokenized input pairs on encode_plus

* Added is_pretokenized flag on encode_plus for clarity and improved error message on input TypeError.

* Added pretokenized inputs support on batch_encode_plus
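
A minimal sketch of the pretokenized path (assuming a fast tokenizer; inputs are lists of words instead of raw strings):

from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

# Single pretokenized sequence
enc = tokenizer.encode_plus(["Hello", "world"], is_pretokenized=True)

# Pretokenized pair
enc_pair = tokenizer.encode_plus(["Hello", "world"], ["How", "are", "you"], is_pretokenized=True)

# Batch of pretokenized sequences
batch = tokenizer.batch_encode_plus([["Hello", "world"], ["How", "are", "you"]], is_pretokenized=True)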

* Update BatchEncoding method names to match Encoding.

* Bump setup.py tokenizers dependency to 0.7.0rc1

* Remove unused parameters in BertTokenizerFast

* Make sure Roberta returns token_type_ids for unittests.

* Added missing typings

* Update add_tokens prototype to match tokenizers side and allow AddedToken
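
A minimal sketch of the updated prototype (assuming a fast tokenizer; AddedToken comes from the `tokenizers` package and controls stripping behaviour around the added token):

from tokenizers import AddedToken
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

# Plain strings still work; AddedToken gives finer control, e.g. stripping the space on the left.
num_added = tokenizer.add_tokens(["new_tok1", AddedToken("new_tok2", lstrip=True)])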

* Bumping tokenizers to 0.7.0rc2

* Added documentation for BatchEncoding

* Added (unused) is_pretokenized parameter on PreTrainedTokenizer encode_plus/batch_encode_plus methods.

* Added higher-level typing for tokenize / encode_plus / batch_encode_plus.

* Fix unittests failing because add_special_tokens was defined as a constructor parameter on Rust Tokenizers.

* Fix text-classification pipeline using the wrong tokenizer

* Make pipelines work with BatchEncoding

* Turn off add_special_tokens on tokenize by default.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Remove add_prefix_space from tokenize call in unittest.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Style and quality

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Correct message for the batch_encode_plus None input exception.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Fix invalid list comprehension for offset_mapping overwriting content on every iteration.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* TransfoXL uses Strip normalizer.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Bump tokenizers dependency to 0.7.0rc3

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Support AddedToken for special_tokens and use left stripping on the mask token for Roberta.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
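
Sketch of the RoBERTa mask handling (assuming roberta-base; without left stripping, the space before '<mask>' is encoded as a separate 'Ġ' token, which is the extra encoded space the old warning in the diff below referred to):

from tokenizers import AddedToken
from transformers import RobertaTokenizerFast

tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

# The mask_token setter wraps plain strings in AddedToken(value, lstrip=True), so the
# preceding space is absorbed into the mask token instead of producing an extra 'Ġ'.
tokenizer.mask_token = "<mask>"

tokenizer.tokenize("Paris is the <mask> of France.")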

* SpecialTokensMixin can use slots for faster access to underlying attributes.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Remove update_special_tokens from fast tokenizers.

* Ensure TransfoXL unittests are run only when torch is available.

* Style.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Style

* Style 🙏🙏

* Remove slots on SpecialTokensMixin; needs a deeper dive into the pickle protocol.

* Remove Roberta warning on __init__.

* Move documentation to Google style.

Co-authored-by: LysandreJik <lysandre.debut@reseau.eseo.fr>
Funtowicz Morgan 2020-04-06 22:29:15 +00:00 committed by GitHub
parent e52d1258e0
commit 96ab75b8dd
11 changed files with 852 additions and 579 deletions

View File

@ -96,7 +96,7 @@ setup(
packages=find_packages("src"),
install_requires=[
"numpy",
"tokenizers == 0.5.2",
"tokenizers == 0.7.0rc3",
# dataclasses for Python versions that don't have it
"dataclasses;python_version<'3.7'",
# accessing files from S3 directly

View File

@ -459,7 +459,7 @@ class Pipeline(_ScikitCompat):
)
# Filter out features not available on specific models
inputs = self.inputs_for_model(inputs)
# inputs = self.inputs_for_model(inputs)
return inputs
@ -480,7 +480,7 @@ class Pipeline(_ScikitCompat):
with self.device_placement():
if self.framework == "tf":
# TODO trace model
predictions = self.model(inputs, training=False)[0]
predictions = self.model(inputs.data, training=False)[0]
else:
with torch.no_grad():
inputs = self.ensure_tensor_on_device(**inputs)
@ -778,7 +778,7 @@ class NerPipeline(Pipeline):
# Forward
if self.framework == "tf":
entities = self.model(tokens)[0][0].numpy()
entities = self.model(tokens.data)[0][0].numpy()
input_ids = tokens["input_ids"].numpy()[0]
else:
with torch.no_grad():
@ -1399,7 +1399,7 @@ SUPPORTED_TASKS = {
"tf": "distilbert-base-uncased-finetuned-sst-2-english",
},
"config": "distilbert-base-uncased-finetuned-sst-2-english",
"tokenizer": "distilbert-base-uncased",
"tokenizer": "distilbert-base-cased",
},
},
"ner": {

View File

@ -592,8 +592,6 @@ class BertTokenizerFast(PreTrainedTokenizerFast):
self,
vocab_file,
do_lower_case=True,
do_basic_tokenize=True,
never_split=None,
unk_token="[UNK]",
sep_token="[SEP]",
pad_token="[PAD]",
@ -601,7 +599,6 @@ class BertTokenizerFast(PreTrainedTokenizerFast):
mask_token="[MASK]",
clean_text=True,
tokenize_chinese_chars=True,
add_special_tokens=True,
strip_accents=True,
wordpieces_prefix="##",
**kwargs
@ -609,7 +606,6 @@ class BertTokenizerFast(PreTrainedTokenizerFast):
super().__init__(
BertWordPieceTokenizer(
vocab_file=vocab_file,
add_special_tokens=add_special_tokens,
unk_token=unk_token,
sep_token=sep_token,
cls_token=cls_token,

View File

@ -18,9 +18,11 @@
import logging
from typing import List, Optional
from tokenizers import AddedToken
from tokenizers.processors import RobertaProcessing
from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
from .tokenization_utils import PreTrainedTokenizer
logger = logging.getLogger(__name__)
@ -259,7 +261,7 @@ class RobertaTokenizerFast(GPT2TokenizerFast):
unk_token="<unk>",
pad_token="<pad>",
mask_token="<mask>",
add_prefix_space=False,
add_prefix_space=True,
**kwargs
):
kwargs.setdefault("pad_token", pad_token)
@ -281,16 +283,24 @@ class RobertaTokenizerFast(GPT2TokenizerFast):
(sep_token, self.sep_token_id), (cls_token, self.cls_token_id)
)
self.tokenizer.add_special_tokens([kwargs["mask_token"]])
# As we override the post_processor post super.__init__ the computed num_added_tokens is wrong in super().
# We need to recompute max_len according to the newly register post_processor to get real values.
self.max_len_single_sentence = self.max_len - self.num_added_tokens(False) # take into account special tokens
self.max_len_sentences_pair = self.max_len - self.num_added_tokens(True) # take into account special tokens
self.max_len_single_sentence = self.max_len - self.num_special_tokens_to_add(
False
) # take into account special tokens
self.max_len_sentences_pair = self.max_len - self.num_special_tokens_to_add(
True
) # take into account special tokens
logger.warning(
"RobertaTokenizerFast has an issue when working on mask language modeling "
"where it introduces an extra encoded space before the mask token."
"See https://github.com/huggingface/transformers/pull/2778 for more information."
)
@PreTrainedTokenizer.mask_token.setter
def mask_token(self, value):
if not isinstance(value, AddedToken):
value = AddedToken(value, lstrip=True)
self._mask_token = str(value)
self.tokenizer.add_special_tokens([value])
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]

View File

@ -24,13 +24,13 @@ import os
import pickle
import re
from collections import Counter, OrderedDict
from typing import List, Optional, Tuple, Union
from typing import Optional
import numpy as np
from tokenizers import Encoding, Tokenizer
from tokenizers import Tokenizer
from tokenizers.implementations import BaseTokenizer
from tokenizers.models import WordLevel
from tokenizers.normalizers import Lowercase, Sequence, unicode_normalizer_from_str
from tokenizers.normalizers import Lowercase, Sequence, Strip, unicode_normalizer_from_str
from tokenizers.pre_tokenizers import CharDelimiterSplit, WhitespaceSplit
from tokenizers.processors import BertProcessing
@ -381,6 +381,9 @@ class _TransfoXLDelimiterLookupTokenizer(BaseTokenizer):
if lowercase:
normalizer += [Lowercase()]
# Strip normalizer at the end
normalizer += [Strip(left=True, right=True)]
if len(normalizer) > 0:
tokenizer.normalizer = Sequence(normalizer) if len(normalizer) > 1 else normalizer[0]
@ -404,14 +407,6 @@ class _TransfoXLDelimiterLookupTokenizer(BaseTokenizer):
super().__init__(tokenizer, parameters)
def encode_batch(self, sequences: List[Union[str, Tuple[str, str]]]) -> List[Encoding]:
return super().encode_batch(
[seq.strip() if isinstance(seq, str) else (seq[0].strip(), seq[1].strip()) for seq in sequences]
)
def encode(self, sequence: str, pair: Optional[str] = None) -> Encoding:
return super().encode(sequence.strip(), pair.strip() if pair else pair)
class TransfoXLTokenizerFast(PreTrainedTokenizerFast):

View File

@ -15,15 +15,19 @@
"""Tokenization classes for OpenAI GPT."""
import copy
import functools
import itertools
import json
import logging
import operator
import os
import re
from collections import defaultdict
from collections import UserDict, defaultdict
from contextlib import contextmanager
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Sequence, Tuple, Union
from tokenizers import AddedToken, Encoding
from tokenizers.decoders import Decoder
from tokenizers.implementations import BaseTokenizer
from .file_utils import cached_path, hf_bucket_url, is_remote_url, is_tf_available, is_torch_available
@ -41,6 +45,27 @@ ADDED_TOKENS_FILE = "added_tokens.json"
TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
# Define type aliases
TextInput = str
TextPairInput = Tuple[str, str]
PreTokenizedInput = List[str]
PreTokenizedInputPair = Tuple[List[str], List[str]]
def flatten(x: Sequence):
"""
Flatten the provided (potentially nested) sequence
Args:
x (Sequence): Potentially nested sequence to flatten
Returns:
list: Flattened sequence
"""
return functools.reduce(operator.iconcat, x, [])
@contextmanager
def truncate_and_pad(
tokenizer: BaseTokenizer,
@ -61,16 +86,19 @@ def truncate_and_pad(
before the managed section. If your tokenizer set a padding / truncation strategy before,
then it will be reset to no padding/truncation when exiting the managed section.
:param tokenizer:
:param max_length:
:param stride:
:param strategy:
:param pad_to_max_length:
:param padding_side:
:param pad_token_id:
:param pad_token_type_id:
:param pad_token:
:return:
Args:
tokenizer (BaseTokenizer): The tokenizer which will be used
max_length (int): The maximum size of the sequence
stride (int): The stride to use when handling overflow
strategy (str): Overflowing logic to use
pad_to_max_length (bool): Boolean indicating if the output needs to be padded up to max_length
padding_side (str): "left" or "right" indicating the direction the output sequence will be padded
pad_token_id (int): The integer representation of the padding token to use
pad_token_type_id (int): The integer representation of the padding token type to use
pad_token (str): The string representation of the padding token to use
Returns:
"""
# Handle all the truncation and padding stuff
@ -103,44 +131,118 @@ def truncate_and_pad(
tokenizer.no_padding()
class PreTrainedTokenizer(object):
""" Base class for all tokenizers.
Handle all the shared methods for tokenization and special tokens as well as methods downloading/caching/loading pretrained tokenizers as well as adding tokens to the vocabulary.
class BatchEncoding(UserDict):
"""
Data structure derived from Dictionary holding all the required information to forward through
a model.
This class also contain the added tokens in a unified way on top of all tokenizers so we don't have to handle the specific vocabulary augmentation methods of the various underlying dictionary structures (BPE, sentencepiece...).
Class attributes (overridden by derived classes):
- ``vocab_files_names``: a python ``dict`` with, as keys, the ``__init__`` keyword name of each vocabulary file required by the model, and as associated values, the filename for saving the associated file (string).
- ``pretrained_vocab_files_map``: a python ``dict of dict`` the high-level keys being the ``__init__`` keyword name of each vocabulary file required by the model, the low-level being the `short-cut-names` (string) of the pretrained models with, as associated values, the `url` (string) to the associated pretrained vocabulary file.
- ``max_model_input_sizes``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained models, and as associated values, the maximum length of the sequence inputs of this model, or None if the model has no maximum input size.
- ``pretrained_init_configuration``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained models, and as associated values, a dictionnary of specific arguments to pass to the ``__init__``method of the tokenizer class for this pretrained model when loading the tokenizer with the ``from_pretrained()`` method.
Parameters:
- ``bos_token``: (`Optional`) string: a beginning of sentence token. Will be associated to ``self.bos_token`` and ``self.bos_token_id``
- ``eos_token``: (`Optional`) string: an end of sentence token. Will be associated to ``self.eos_token`` and ``self.eos_token_id``
- ``unk_token``: (`Optional`) string: an unknown token. Will be associated to ``self.unk_token`` and ``self.unk_token_id``
- ``sep_token``: (`Optional`) string: a separation token (e.g. to separate context and query in an input sequence). Will be associated to ``self.sep_token`` and ``self.sep_token_id``
- ``pad_token``: (`Optional`) string: a padding token. Will be associated to ``self.pad_token`` and ``self.pad_token_id``
- ``cls_token``: (`Optional`) string: a classification token (e.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model). Will be associated to ``self.cls_token`` and ``self.cls_token_id``
- ``mask_token``: (`Optional`) string: a masking token (e.g. when training a model with masked-language modeling). Will be associated to ``self.mask_token`` and ``self.mask_token_id``
- ``additional_special_tokens``: (`Optional`) list: a list of additional special tokens. Adding all special tokens here ensure they won't be split by the tokenization process. Will be associated to ``self.additional_special_tokens`` and ``self.additional_special_tokens_ids``
In addition, this structure expose utility methods to map from word/char space to token space.
"""
vocab_files_names = {}
pretrained_vocab_files_map = {}
pretrained_init_configuration = {}
max_model_input_sizes = {}
model_input_names = ["token_type_ids", "attention_mask"]
def __init__(self, data: dict, encoding: Optional[Union[Encoding, Sequence[Encoding]]] = None):
super().__init__(data)
if isinstance(encoding, Encoding):
encoding = [encoding]
self._encodings = encoding
def __getitem__(self, item: Union[int, str]) -> Encoding:
if isinstance(item, str):
return self.data[item]
elif self._encodings is not None:
return self._encodings[item]
else:
raise KeyError("int index is supported only on {} from a Rust tokenizer".format(type(self).__name__))
def __getattr__(self, item: str):
return self.data[item]
@property
def encodings(self) -> Optional[List[Encoding]]:
"""
Return the list all encoding from the tokenization process
Returns: List[Encoding] or None if input was tokenized through Python tokenizer
"""
return self._encodings
def keys(self):
return self.data.keys()
def values(self):
return self.data.values()
def items(self):
return self.data.items()
def char_to_token_offsets(self, sentence: int, char: int) -> Tuple[int, int]:
"""
Find the Offsets of the token containing the character at the specified position
Args:
sentence: Index of the sentence relative to the batch provided to the tokenizer
char: Char index to get the relative token offsets
Returns:
tuple: (token start, token end)
"""
if not self._encodings:
raise ValueError("char_to_token_offsets() is not available when using Python based tokenizers")
return self[sentence].char_to_token_offsets(char)
def char_to_token(self, sentence: int, char: int) -> int:
"""
Return the index of the token at position of the given char.
Args:
sentence (int): Index of the sentence relative to the batch provided to the tokenizer
char (int): Char index to get the relative token offsets
Returns:
int: Integer referring to the position of the token in the returned set of tokens for the sentence
"""
if not self._encodings:
raise ValueError("char_to_token() is not available when using Python based tokenizers")
return self[sentence].char_to_token(char)
def char_to_word_offsets(self, sentence: int, char: int) -> Tuple[int, int]:
"""
Find the Offsets of the word containing the character at the specified position
Args:
sentence (int): Index of the sentence relative to the batch provided to the tokenizer
char (int): Char index to get the relative token offsets
Returns:
tuple: (word start, word end) representing the first and last characters of the word
"""
if not self._encodings:
raise ValueError("char_to_word_offsets() is not available when using Python based tokenizers")
return self[sentence].char_to_word_offsets(char)
def token_to_word_offsets(self, sentence: int, index: int) -> Optional[Tuple[int, int]]:
"""
Find the Offsets of the word containing the token at the given index
Args:
sentence (int): Index of the sentence relative to the batch provided to the tokenizer
index (int): Index of the token to map to the original word offsets
Returns:
Optional[tuple]: (word start, word end) or None
"""
if not self._encodings:
raise ValueError("token_to_word_offsets() is not available when using Python based tokenizers")
return self[sentence].token_to_word_offsets(index)
class SpecialTokensMixin:
SPECIAL_TOKENS_ATTRIBUTES = [
"bos_token",
"eos_token",
@ -152,18 +254,29 @@ class PreTrainedTokenizer(object):
"additional_special_tokens",
]
padding_side = "right"
def __init__(self, **kwargs):
NO_PAD_TOKEN_FOR_BATCH_MSG = (
"No padding token is set for this model, therefore no batch can be made with uneven "
"sequences. Set a padding token or adjust the lengths of the sequences building the "
"batch so that every sequence is of the same length."
)
self._bos_token = None
self._eos_token = None
self._unk_token = None
self._sep_token = None
self._pad_token = None
self._cls_token = None
self._mask_token = None
self._pad_token_type_id = 0
self._additional_special_tokens = []
UNEVEN_SEQUENCES_FOR_BATCH_MSG = (
"The sequences building the batch are not of the same size, no tensor "
"can be built. Set `pad_to_max_length=True` to pad the smaller sequences"
"up to the larger sequence's length."
for key, value in kwargs.items():
if key in self.SPECIAL_TOKENS_ATTRIBUTES:
if key == "additional_special_tokens":
assert isinstance(value, (list, tuple)) and all(isinstance(t, str) for t in value)
elif isinstance(value, AddedToken):
setattr(self, key, str(value))
elif isinstance(value, str):
setattr(self, key, value)
else:
raise TypeError(
"special token {} has to be either str or AddedToken but got: {}".format(key, type(value))
)
@property
@ -250,10 +363,6 @@ class PreTrainedTokenizer(object):
def mask_token(self, value):
self._mask_token = value
@additional_special_tokens.setter
def additional_special_tokens(self, value):
self._additional_special_tokens = value
@property
def bos_token_id(self):
""" Id of the beginning of sentence token in the vocabulary. Log an error if used while not having been set. """
@ -299,20 +408,112 @@ class PreTrainedTokenizer(object):
""" Ids of all the additional special tokens in the vocabulary (list of integers). Log an error if used while not having been set. """
return self.convert_tokens_to_ids(self.additional_special_tokens)
@property
def special_tokens_map(self):
""" A dictionary mapping special token class attribute (cls_token, unk_token...) to their
values ('<unk>', '<cls>'...)
"""
set_attr = {}
for attr in self.SPECIAL_TOKENS_ATTRIBUTES:
attr_value = getattr(self, "_" + attr)
if attr_value:
set_attr[attr] = attr_value
return set_attr
@property
def all_special_tokens(self):
""" List all the special tokens ('<unk>', '<cls>'...) mapped to class attributes
(cls_token, unk_token...).
"""
all_toks = []
set_attr = self.special_tokens_map
for attr_value in set_attr.values():
all_toks = all_toks + (list(attr_value) if isinstance(attr_value, (list, tuple)) else [attr_value])
all_toks = list(set(all_toks))
return all_toks
@property
def all_special_ids(self):
""" List the vocabulary indices of the special tokens ('<unk>', '<cls>'...) mapped to
class attributes (cls_token, unk_token...).
"""
all_toks = self.all_special_tokens
all_ids = self.convert_tokens_to_ids(all_toks)
return all_ids
@additional_special_tokens.setter
def additional_special_tokens(self, value):
self._additional_special_tokens = value
class PreTrainedTokenizer(SpecialTokensMixin):
""" Base class for all tokenizers.
Handle all the shared methods for tokenization and special tokens as well as methods downloading/caching/loading pretrained tokenizers as well as adding tokens to the vocabulary.
This class also contain the added tokens in a unified way on top of all tokenizers so we don't have to handle the specific vocabulary augmentation methods of the various underlying dictionary structures (BPE, sentencepiece...).
Class attributes (overridden by derived classes):
- ``vocab_files_names``: a python ``dict`` with, as keys, the ``__init__`` keyword name of each vocabulary file required by the model, and as associated values, the filename for saving the associated file (string).
- ``pretrained_vocab_files_map``: a python ``dict of dict`` the high-level keys being the ``__init__`` keyword name of each vocabulary file required by the model, the low-level being the `short-cut-names` (string) of the pretrained models with, as associated values, the `url` (string) to the associated pretrained vocabulary file.
- ``max_model_input_sizes``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained models, and as associated values, the maximum length of the sequence inputs of this model, or None if the model has no maximum input size.
- ``pretrained_init_configuration``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained models, and as associated values, a dictionnary of specific arguments to pass to the ``__init__``method of the tokenizer class for this pretrained model when loading the tokenizer with the ``from_pretrained()`` method.
Parameters:
- ``bos_token``: (`Optional`) string: a beginning of sentence token. Will be associated to ``self.bos_token`` and ``self.bos_token_id``
- ``eos_token``: (`Optional`) string: an end of sentence token. Will be associated to ``self.eos_token`` and ``self.eos_token_id``
- ``unk_token``: (`Optional`) string: an unknown token. Will be associated to ``self.unk_token`` and ``self.unk_token_id``
- ``sep_token``: (`Optional`) string: a separation token (e.g. to separate context and query in an input sequence). Will be associated to ``self.sep_token`` and ``self.sep_token_id``
- ``pad_token``: (`Optional`) string: a padding token. Will be associated to ``self.pad_token`` and ``self.pad_token_id``
- ``cls_token``: (`Optional`) string: a classification token (e.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model). Will be associated to ``self.cls_token`` and ``self.cls_token_id``
- ``mask_token``: (`Optional`) string: a masking token (e.g. when training a model with masked-language modeling). Will be associated to ``self.mask_token`` and ``self.mask_token_id``
- ``additional_special_tokens``: (`Optional`) list: a list of additional special tokens. Adding all special tokens here ensure they won't be split by the tokenization process. Will be associated to ``self.additional_special_tokens`` and ``self.additional_special_tokens_ids``
"""
vocab_files_names = {}
pretrained_vocab_files_map = {}
pretrained_init_configuration = {}
max_model_input_sizes = {}
model_input_names = ["token_type_ids", "attention_mask"]
padding_side = "right"
NO_PAD_TOKEN_FOR_BATCH_MSG = (
"No padding token is set for this model, therefore no batch can be made with uneven "
"sequences. Set a padding token or adjust the lengths of the sequences building the "
"batch so that every sequence is of the same length."
)
UNEVEN_SEQUENCES_FOR_BATCH_MSG = (
"The sequences building the batch are not of the same size, no tensor "
"can be built. Set `pad_to_max_length=True` to pad the smaller sequences"
"up to the larger sequence's length."
)
@property
def vocab_size(self) -> int:
""" Size of the base vocabulary (without the added tokens) """
raise NotImplementedError
@property
def is_fast(self):
return False
def get_vocab(self):
""" Returns the vocabulary as a dict of {token: index} pairs. `tokenizer.get_vocab()[token]` is equivalent to `tokenizer.convert_tokens_to_ids(token)` when `token` is in the vocab. """
raise NotImplementedError()
def __init__(self, max_len=None, **kwargs):
self._bos_token = None
self._eos_token = None
self._unk_token = None
self._sep_token = None
self._pad_token = None
self._cls_token = None
self._mask_token = None
self._pad_token_type_id = 0
self._additional_special_tokens = []
super().__init__(**kwargs)
self.max_len = max_len if max_len is not None else int(1e12)
@ -329,13 +530,9 @@ class PreTrainedTokenizer(object):
self.init_inputs = ()
self.init_kwargs = {}
for key, value in kwargs.items():
if key in self.SPECIAL_TOKENS_ATTRIBUTES:
if key == "additional_special_tokens":
assert isinstance(value, (list, tuple)) and all(isinstance(t, str) for t in value)
else:
assert isinstance(value, str)
setattr(self, key, value)
def __len__(self):
""" Size of the full vocabulary with the added tokens """
return self.vocab_size + len(self.added_tokens_encoder)
@classmethod
def from_pretrained(cls, *inputs, **kwargs):
@ -614,14 +811,6 @@ class PreTrainedTokenizer(object):
"""
raise NotImplementedError
def vocab_size(self):
""" Size of the base vocabulary (without the added tokens) """
raise NotImplementedError
def __len__(self):
""" Size of the full vocabulary with the added tokens """
return self.vocab_size + len(self.added_tokens_encoder)
def add_tokens(self, new_tokens):
"""
Add a list of new tokens to the tokenizer class. If the new tokens are not in the
@ -670,7 +859,7 @@ class PreTrainedTokenizer(object):
return len(to_add_tokens)
def num_added_tokens(self, pair=False):
def num_special_tokens_to_add(self, pair=False):
"""
Returns the number of added tokens when encoding a sequence with special tokens.
@ -743,7 +932,7 @@ class PreTrainedTokenizer(object):
return added_tokens
def tokenize(self, text, **kwargs):
def tokenize(self, text: TextInput, **kwargs):
""" Converts a string in a sequence of tokens (string), using the tokenizer.
Split in words for word-based vocabulary or sub-words for sub-word-based
vocabularies (BPE/SentencePieces/WordPieces).
@ -852,8 +1041,8 @@ class PreTrainedTokenizer(object):
def encode(
self,
text: str,
text_pair: Optional[str] = None,
text: TextInput,
text_pair: Optional[TextInput] = None,
add_special_tokens: bool = True,
max_length: Optional[int] = None,
stride: int = 0,
@ -923,13 +1112,14 @@ class PreTrainedTokenizer(object):
def encode_plus(
self,
text: str,
text_pair: Optional[str] = None,
text: TextInput,
text_pair: Optional[TextInput] = None,
add_special_tokens: bool = True,
max_length: Optional[int] = None,
stride: int = 0,
truncation_strategy: str = "longest_first",
pad_to_max_length: bool = False,
is_pretokenized: bool = False,
return_tensors: Optional[str] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
@ -937,7 +1127,7 @@ class PreTrainedTokenizer(object):
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
**kwargs
):
) -> BatchEncoding:
"""
Returns a dictionary containing the encoded sequence or sequence pair and additional information:
the mask for sequence classification and the overflowing elements if a ``max_length`` is specified.
@ -977,6 +1167,8 @@ class PreTrainedTokenizer(object):
- 'left': pads on the left of the sequences
- 'right': pads on the right of the sequences
Defaults to False: no padding.
is_pretokenized (:obj:`bool`, defaults to :obj:`False`):
Set to True to indicate the input is already tokenized
return_tensors (:obj:`str`, `optional`, defaults to :obj:`None`):
Can be set to 'tf' or 'pt' to return respectively TensorFlow :obj:`tf.constant`
or PyTorch :obj:`torch.Tensor` instead of a list of python integers.
@ -1071,12 +1263,15 @@ class PreTrainedTokenizer(object):
def batch_encode_plus(
self,
batch_text_or_text_pairs: Union[str, List[str]],
batch_text_or_text_pairs: Union[
List[TextInput], List[TextPairInput], List[PreTokenizedInput], List[PreTokenizedInputPair]
],
add_special_tokens: bool = True,
max_length: Optional[int] = None,
stride: int = 0,
truncation_strategy: str = "longest_first",
pad_to_max_length: bool = False,
is_pretokenized: bool = False,
return_tensors: Optional[str] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_masks: Optional[bool] = None,
@ -1085,7 +1280,7 @@ class PreTrainedTokenizer(object):
return_offsets_mapping: bool = False,
return_input_lengths: bool = False,
**kwargs
):
) -> BatchEncoding:
"""
Returns a dictionary containing the encoded sequence or sequence pair and additional information:
the mask for sequence classification and the overflowing elements if a ``max_length`` is specified.
@ -1121,6 +1316,8 @@ class PreTrainedTokenizer(object):
- 'left': pads on the left of the sequences
- 'right': pads on the right of the sequences
Defaults to False: no padding.
is_pretokenized (:obj:`bool`, defaults to :obj:`False`):
Set to True to indicate the input is already tokenized
return_tensors (:obj:`str`, `optional`, defaults to :obj:`None`):
Can be set to 'tf' or 'pt' to return respectively TensorFlow :obj:`tf.constant`
or PyTorch :obj:`torch.Tensor` instead of a list of python integers.
@ -1213,9 +1410,9 @@ class PreTrainedTokenizer(object):
def total_sequence_length(input_pairs):
first_ids, second_ids = input_pairs
return len(first_ids) + (
self.num_added_tokens()
self.num_special_tokens_to_add()
if second_ids is None
else (len(second_ids) + self.num_added_tokens(pair=True))
else (len(second_ids) + self.num_special_tokens_to_add(pair=True))
)
max_length = max([total_sequence_length(ids) for ids in input_ids])
@ -1277,7 +1474,7 @@ class PreTrainedTokenizer(object):
)
)
return batch_outputs
return BatchEncoding(batch_outputs)
def prepare_for_model(
self,
@ -1361,7 +1558,7 @@ class PreTrainedTokenizer(object):
encoded_inputs = {}
# Handle max sequence length
total_len = len_ids + len_pair_ids + (self.num_added_tokens(pair=pair) if add_special_tokens else 0)
total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0)
if max_length and total_len > max_length:
ids, pair_ids, overflowing_tokens = self.truncate_sequences(
ids,
@ -1474,7 +1671,7 @@ class PreTrainedTokenizer(object):
)
)
return encoded_inputs
return BatchEncoding(encoded_inputs)
def prepare_for_tokenization(self, text, **kwargs):
""" Performs any necessary transformations before tokenization """
@ -1629,39 +1826,6 @@ class PreTrainedTokenizer(object):
else:
return text
@property
def special_tokens_map(self):
""" A dictionary mapping special token class attribute (cls_token, unk_token...) to their
values ('<unk>', '<cls>'...)
"""
set_attr = {}
for attr in self.SPECIAL_TOKENS_ATTRIBUTES:
attr_value = getattr(self, "_" + attr)
if attr_value:
set_attr[attr] = attr_value
return set_attr
@property
def all_special_tokens(self):
""" List all the special tokens ('<unk>', '<cls>'...) mapped to class attributes
(cls_token, unk_token...).
"""
all_toks = []
set_attr = self.special_tokens_map
for attr_value in set_attr.values():
all_toks = all_toks + (list(attr_value) if isinstance(attr_value, (list, tuple)) else [attr_value])
all_toks = list(set(all_toks))
return all_toks
@property
def all_special_ids(self):
""" List the vocabulary indices of the special tokens ('<unk>', '<cls>'...) mapped to
class attributes (cls_token, unk_token...).
"""
all_toks = self.all_special_tokens
all_ids = self.convert_tokens_to_ids(all_toks)
return all_ids
@staticmethod
def clean_up_tokenization(out_string):
""" Clean up a list of simple English tokenization artifacts like spaces before punctuations and abreviated forms.
@ -1692,66 +1856,70 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
self._tokenizer = tokenizer
super().__init__(**kwargs)
self.max_len_single_sentence = self.max_len - self.num_added_tokens(False) # take into account special tokens
self.max_len_sentences_pair = self.max_len - self.num_added_tokens(True) # take into account special tokens
self.max_len_single_sentence = self.max_len - self.num_special_tokens_to_add(
False
) # take into account special tokens
self.max_len_sentences_pair = self.max_len - self.num_special_tokens_to_add(
True
) # take into account special tokens
@property
def tokenizer(self):
def tokenizer(self) -> BaseTokenizer:
return self._tokenizer
@property
def decoder(self):
def decoder(self) -> Decoder:
return self._tokenizer._tokenizer.decoder
@property
def vocab_size(self):
def is_fast(self) -> bool:
return True
@property
def vocab_size(self) -> int:
return self._tokenizer.get_vocab_size(with_added_tokens=False)
def __len__(self):
def __len__(self) -> int:
return self._tokenizer.get_vocab_size(with_added_tokens=True)
@PreTrainedTokenizer.bos_token.setter
def bos_token(self, value):
self._bos_token = value
self._update_special_tokens()
self._tokenizer.add_special_tokens([self._bos_token])
@PreTrainedTokenizer.eos_token.setter
def eos_token(self, value):
self._eos_token = value
self._update_special_tokens()
self._tokenizer.add_special_tokens([self._eos_token])
@PreTrainedTokenizer.unk_token.setter
def unk_token(self, value):
self._unk_token = value
self._update_special_tokens()
self._tokenizer.add_special_tokens([self._unk_token])
@PreTrainedTokenizer.sep_token.setter
def sep_token(self, value):
self._sep_token = value
self._update_special_tokens()
self._tokenizer.add_special_tokens([self._sep_token])
@PreTrainedTokenizer.pad_token.setter
def pad_token(self, value):
self._pad_token = value
self._update_special_tokens()
self._tokenizer.add_special_tokens([self._pad_token])
@PreTrainedTokenizer.cls_token.setter
def cls_token(self, value):
self._cls_token = value
self._update_special_tokens()
self._tokenizer.add_special_tokens([self._cls_token])
@PreTrainedTokenizer.mask_token.setter
def mask_token(self, value):
self._mask_token = value
self._update_special_tokens()
self._tokenizer.add_special_tokens([self._mask_token])
@PreTrainedTokenizer.additional_special_tokens.setter
def additional_special_tokens(self, value):
self._additional_special_tokens = value
self._update_special_tokens()
def _update_special_tokens(self):
if self._tokenizer is not None:
self._tokenizer.add_special_tokens(self.all_special_tokens)
def _convert_encoding(
@ -1785,7 +1953,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
if return_special_tokens_mask:
encoding_dict["special_tokens_mask"].append(e.special_tokens_mask)
if return_offsets_mapping:
encoding_dict["offset_mapping"].append([e.original_str.offsets(o) for o in e.offsets])
encoding_dict["offset_mapping"].append(e.offsets)
# Prepare inputs as tensors if asked
if return_tensors == "tf" and is_tf_available():
@ -1818,42 +1986,50 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
return self.unk_token_id
return id
def _convert_id_to_token(self, index):
def _convert_id_to_token(self, index: int) -> str:
return self._tokenizer.id_to_token(int(index))
def convert_tokens_to_string(self, tokens):
return self._tokenizer.decode(tokens)
def convert_tokens_to_string(self, tokens: List[int], skip_special_tokens: bool = False) -> str:
return self._tokenizer.decode(tokens, skip_special_tokens)
def add_tokens(self, new_tokens):
def add_tokens(self, new_tokens: List[Union[str, AddedToken]]) -> int:
if isinstance(new_tokens, str):
new_tokens = [new_tokens]
return self._tokenizer.add_tokens(new_tokens)
def add_special_tokens(self, special_tokens_dict):
def add_special_tokens(self, special_tokens_dict: dict) -> int:
added = super().add_special_tokens(special_tokens_dict)
self._update_special_tokens()
tokens = flatten(special_tokens_dict.values())
self._tokenizer.add_special_tokens(tokens)
return added
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
if token_ids_1 is None:
return token_ids_0
else:
return token_ids_0 + token_ids_1
def num_added_tokens(self, pair=False):
def num_special_tokens_to_add(self, pair: bool = False) -> int:
return self.tokenizer.num_special_tokens_to_add(pair)
def tokenize(self, text, **kwargs):
return self.tokenizer.encode(text).tokens
def tokenize(
self, text: TextInput, pair: Optional[TextInput] = None, add_special_tokens: bool = False
) -> List[str]:
return self.tokenizer.encode(text, pair, add_special_tokens).tokens
def batch_encode_plus(
self,
batch_text_or_text_pairs: Optional[Union[List[str], List[Tuple[str]]]] = None,
batch_text_or_text_pairs: Union[
List[TextInput], List[TextPairInput], List[PreTokenizedInput], List[PreTokenizedInputPair]
] = None,
add_special_tokens: bool = True,
max_length: Optional[int] = None,
stride: int = 0,
truncation_strategy: str = "longest_first",
pad_to_max_length: bool = False,
is_pretokenized: bool = False,
return_tensors: Optional[str] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
@ -1861,12 +2037,14 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
**kwargs
):
if not add_special_tokens:
logger.warning(
"Fast tokenizers add special tokens by default. To remove special tokens, please specify"
"`add_special_tokens=False` during the initialisation rather than when calling `encode`,"
"`encode_plus` or `batch_encode_plus`."
) -> BatchEncoding:
if batch_text_or_text_pairs is None:
raise ValueError(
"None is not a valid input. "
"Should be a list/tuple of strings, "
"a list/tuple of integers, "
"A list of list of strings or tuple of strings."
)
# Needed if we have to return a tensor
@ -1894,15 +2072,67 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
"batch_text_or_text_pairs has to be a list (got {})".format(type(batch_text_or_text_pairs))
)
# Check for the pretokenized path
if is_pretokenized:
encodings = []
# Iterate over each sample (we don't know yet if they are pairs or simple input
for i, sample in enumerate(batch_text_or_text_pairs):
if not isinstance(sample, (list, tuple)):
raise TypeError(
"batch_encode_plus(..., is_pretokenized=True) requires batch_text_or_text_pairs "
"to be either List[List[str]] or List[Tuple[List[str], List[str]]] but sample at "
"index {} is of type {}".format(i, type(sample))
)
# Convert to tuple for convenience
if isinstance(sample, list):
sample = (sample,)
encodings_text = Encoding.merge(self._tokenizer.encode_batch(sample[0], False), True)
# Check if we have pairs
if len(sample) == 2:
encodings_pair = Encoding.merge(
self._tokenizer.encode_batch([("", s) for s in sample[1]], False), True
)
# No pair, default to None
elif len(sample) == 1:
encodings_pair = None
# Something else is invalid
else:
raise ValueError(
"batch_encode_plus(..., is_pretokenized=True) requires batch_text_or_text_pairs "
"to be either List[List[str]] or List[Tuple[List[str], List[str]]] but sample at "
"index {} has too much dimensions (required 1 or 2, got: {}, type {})".format(
i, len(sample), type(sample)
)
)
# Post-process
encoding = self._tokenizer.post_process(encodings_text, encodings_pair, add_special_tokens)
encodings += [encoding]
# Classical path with strings input
else:
# Avoid thread overhead if only one example.
if len(batch_text_or_text_pairs) == 1:
if isinstance(batch_text_or_text_pairs[0], (tuple, list)):
tokens = self._tokenizer.encode(*batch_text_or_text_pairs[0])
encodings = self._tokenizer.encode(
*batch_text_or_text_pairs[0], add_special_tokens=add_special_tokens
)
else:
tokens = self._tokenizer.encode(batch_text_or_text_pairs[0])
tokens = [tokens]
encodings = self._tokenizer.encode(
batch_text_or_text_pairs[0], add_special_tokens=add_special_tokens
)
encodings = [encodings]
else:
tokens = self._tokenizer.encode_batch(batch_text_or_text_pairs)
encodings = self._tokenizer.encode_batch(
batch_text_or_text_pairs, add_special_tokens=add_special_tokens
)
# Convert encoding to dict
tokens = [
@ -1915,7 +2145,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
return_special_tokens_mask=return_special_tokens_mask,
return_offsets_mapping=return_offsets_mapping,
)
for encoding in tokens
for encoding in encodings
]
# Sanitize the output to have dict[list] from list[dict]
@ -1926,8 +2156,8 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
stack = tf.stack(stack, axis=0)
elif return_tensors == "pt":
stack = torch.stack(stack, dim=0)
elif not return_tensors and len(stack) == 1:
stack = stack[0]
# elif not return_tensors and len(stack) == 1:
# stack = stack[0]
sanitized[key] = stack
@ -1938,17 +2168,19 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
i if len(item["input_ids"]) == 1 else [i] * len(item["input_ids"]) for i, item in enumerate(tokens)
]
sanitized["overflow_to_sample_mapping"] = overflow_to_sample_mapping
return sanitized
return BatchEncoding(sanitized, encodings)
def encode_plus(
self,
text: str,
text_pair: Optional[str] = None,
add_special_tokens: bool = False,
text: Union[TextInput, PreTokenizedInput],
text_pair: Optional[Union[TextInput, PreTokenizedInput]] = None,
add_special_tokens: bool = True,
max_length: Optional[int] = None,
pad_to_max_length: bool = False,
stride: int = 0,
truncation_strategy: str = "longest_first",
is_pretokenized: bool = False,
return_tensors: Optional[bool] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
@ -1956,7 +2188,52 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
**kwargs
):
) -> BatchEncoding:
# Check for pretokenized path (ie [token1, token2, ..., tokenN] -> [id1, id2, ..., idN]
if is_pretokenized:
if isinstance(text, list) and len(text) > 0:
# Encode through encode_batch with sequence of only one word which will be merged after hand
encoding = self._tokenizer.encode_batch(text, add_special_tokens=False)
encoding = Encoding.merge(encoding, True)
# Let's do the same for pairs if provided
if isinstance(text_pair, list):
# We prepend empty string before each word so that encoding is aware content is a pair
encoding_pair = self._tokenizer.encode_batch(
[("", p) for p in text_pair], add_special_tokens=False
)
encoding_pair = Encoding.merge(encoding_pair, True)
elif text_pair is None:
encoding_pair = None
else:
raise TypeError(
"encode_plus(..., is_pretokenized=True) requires text and text_pair to be List[str] "
"but got (text={}, text_pair={})".format(type(text), type(text_pair))
)
# Post process and if asked to do so, insert special tokens where needed
encoding = self._tokenizer.post_process(encoding, encoding_pair, add_special_tokens)
batched_output = BatchEncoding(
self._convert_encoding(
encoding,
return_tensors=return_tensors,
return_token_type_ids=return_token_type_ids,
return_attention_mask=return_attention_mask,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
return_offsets_mapping=return_offsets_mapping,
),
encoding,
)
else:
raise TypeError(
"encode_plus(..., is_pretokenized=True) requires text to be List[str] "
"but got (text={}, text_pair={})".format(type(text), type(text_pair))
)
else:
batched_input = [(text, text_pair)] if text_pair else [text]
batched_output = self.batch_encode_plus(
batched_input,
@ -1976,11 +2253,19 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
# Return tensor is None, then we can remove the leading batch axis
if not return_tensors:
return {key: value[0] if isinstance(value[0], list) else value for key, value in batched_output.items()}
else:
batched_output = BatchEncoding(
{
key: value[0] if len(value) > 0 and isinstance(value[0], list) else value
for key, value in batched_output.items()
},
batched_output.encodings,
)
return batched_output
def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
def decode(
self, token_ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True
):
text = self.tokenizer.decode(token_ids, skip_special_tokens)
if clean_up_tokenization_spaces:
@ -1989,7 +2274,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
else:
return text
def save_vocabulary(self, save_directory):
def save_vocabulary(self, save_directory: str) -> Tuple[str]:
if os.path.isdir(save_directory):
files = self._tokenizer.save(save_directory)
else:

View File

@ -64,7 +64,7 @@ TF_TEXT_CLASSIF_FINETUNED_MODELS = {
TEXT_CLASSIF_FINETUNED_MODELS = {
(
"bert-base-uncased",
"distilbert-base-cased",
"distilbert-base-uncased-finetuned-sst-2-english",
"distilbert-base-uncased-finetuned-sst-2-english",
)

View File

@ -82,7 +82,7 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
return
tokenizer = self.get_tokenizer()
rust_tokenizer = self.get_rust_tokenizer(add_special_tokens=False)
rust_tokenizer = self.get_rust_tokenizer()
sequence = "UNwant\u00E9d,running"
@ -91,7 +91,7 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
self.assertListEqual(tokens, rust_tokens)
ids = tokenizer.encode(sequence, add_special_tokens=False)
rust_ids = rust_tokenizer.encode(sequence)
rust_ids = rust_tokenizer.encode(sequence, add_special_tokens=False)
self.assertListEqual(ids, rust_ids)
rust_tokenizer = self.get_rust_tokenizer()

View File

@ -282,7 +282,7 @@ class TokenizerTesterMixin:
# Method is implemented (e.g. not GPT-2)
if len(attached_sequences) != 2:
self.assertEqual(tokenizer.num_added_tokens(pair=True), len(attached_sequences) - len(sequences))
self.assertEqual(tokenizer.num_special_tokens_to_add(pair=True), len(attached_sequences) - len(sequences))
def test_maximum_encoding_length_single_input(self):
tokenizer = self.get_tokenizer()
@ -291,7 +291,7 @@ class TokenizerTesterMixin:
stride = 2
sequence = tokenizer.encode(seq_0, add_special_tokens=False)
num_added_tokens = tokenizer.num_added_tokens()
num_added_tokens = tokenizer.num_special_tokens_to_add()
total_length = len(sequence) + num_added_tokens
information = tokenizer.encode_plus(
seq_0,

View File

@ -1,6 +1,6 @@
import unittest
import numpy as np
from collections import namedtuple
from itertools import takewhile
from tests.utils import require_torch
from transformers import (
@ -21,117 +21,112 @@ from transformers.tokenization_roberta import RobertaTokenizerFast
from transformers.tokenization_transfo_xl import TransfoXLTokenizerFast
class FastTokenizerMatchingTest(unittest.TestCase):
NON_ENGLISH_TAGS = ["chinese", "dutch", "french", "finnish", "german", "multilingual"]
Tokenizer = namedtuple("Tokenizer", ["name", "rust_cls", "python_cls", "vocab_key", "filter"])
def filter_non_english(_: Tokenizer, pretrained_name: str):
""" Filter all the model for non-english language """
return not any([lang in pretrained_name for lang in NON_ENGLISH_TAGS])
def filter_roberta_detectors(_: Tokenizer, pretrained_name: str):
return "detector" not in pretrained_name
class CommonFastTokenizerTest(unittest.TestCase):
TOKENIZERS_CLASSES = frozenset([])
def setUp(self) -> None:
with open("tests/fixtures/sample_text.txt") as f_data:
with open("tests/fixtures/sample_text.txt", encoding="utf-8") as f_data:
self._data = f_data.read().replace("\n\n", "\n").strip()
def assert_sequence_almost_equals(self, a, b, threshold):
def test_all_tokenizers(self):
for tok_case in self.TOKENIZERS_CLASSES:
for pretrained_name in tok_case.python_cls.pretrained_vocab_files_map[tok_case.vocab_key].keys():
# Handle padding
if len(a) != len(b):
max_len = max(len(a), len(b))
# Tokenizer.filter makes it possible to filter which Tokenizer to case based on all the
# information available in Tokenizer (name, rust class, python class, vocab key name)
if tok_case.filter is None or (
tok_case.filter is not None and tok_case.filter(tok_case, pretrained_name)
):
with self.subTest("{} ({})".format(tok_case.name, pretrained_name)):
tokenizer_r = tok_case.rust_cls.from_pretrained(pretrained_name)
tokenizer_p = tok_case.python_cls.from_pretrained(pretrained_name)
# Pad with a negative number as vocab doesnt allow idx < 0
# if will be tracked as differences
if len(a) < max_len:
a += [-1] * (max_len - len(a))
self.fast_align_python(tokenizer_r, tokenizer_p)
self.fast_only(tokenizer_r)
if len(b) < max_len:
b += [-1] * (max_len - len(b))
def fast_align_python(self, tokenizer_r, tokenizer_p):
# Check is_fast is set correctly
self.assertFalse(tokenizer_p.is_fast)
self.assertTrue(tokenizer_r.is_fast)
# Convert to numpy for convenience
a_, b_ = np.array(a), np.array(b)
# Check that Rust and Python align
self.assert_tokenization_python_rust_equals(tokenizer_r, tokenizer_p)
self.assert_num_special_tokens_to_add_equal(tokenizer_r, tokenizer_p)
self.assert_max_length_equal(tokenizer_r, tokenizer_p)
self.assert_special_tokens_map_equal(tokenizer_r, tokenizer_p)
self.assert_embeded_special_tokens(tokenizer_r, tokenizer_p)
self.assert_padding(tokenizer_r, tokenizer_p)
# TODO: enable for v3.0.0
# self.assert_empty_output_no_special_tokens(tokenizer_r, tokenizer_p)
# Compute elementwise difference
inputs_diffs = a_ - b_
inputs_diff = np.count_nonzero(inputs_diffs)
self.assertLessEqual(inputs_diff / a_.shape[0], threshold)
def fast_only(self, tokenizer_r):
# Ensure None raise an error
self.assertRaises(ValueError, tokenizer_r.tokenize, None)
self.assertRaises(ValueError, tokenizer_r.encode, None)
self.assertRaises(ValueError, tokenizer_r.encode_plus, None)
self.assertRaises(ValueError, tokenizer_r.batch_encode_plus, None)
def assert_tokenization_python_rust_almost_equals(self, tokenizer_p, tokenizer_r, threshold: float):
self.assert_add_tokens(tokenizer_r)
self.assert_offsets_mapping(tokenizer_r)
self.assert_add_special_tokens(tokenizer_r)
def assert_tokenization_python_rust_equals(self, tokenizer_p, tokenizer_r):
# Ensure basic input match
input_p = tokenizer_p.encode_plus(self._data)
input_r = tokenizer_r.encode_plus(self._data)
for key in filter(lambda x: x in ["input_ids", "token_type_ids", "attention_mask"], input_p.keys()):
self.assert_sequence_almost_equals(input_p[key], input_r[key], threshold)
self.assertSequenceEqual(input_p[key], input_r[key])
input_pairs_p = tokenizer_p.encode_plus(self._data, self._data)
input_pairs_r = tokenizer_r.encode_plus(self._data, self._data)
for key in filter(lambda x: x in ["input_ids", "token_type_ids", "attention_mask"], input_p.keys()):
self.assert_sequence_almost_equals(input_pairs_p[key], input_pairs_r[key], threshold)
self.assertSequenceEqual(input_pairs_p[key], input_pairs_r[key])
# Ensure truncation match
input_p = tokenizer_p.encode_plus(self._data, max_length=512)
input_r = tokenizer_r.encode_plus(self._data, max_length=512)
for key in filter(lambda x: x in ["input_ids", "token_type_ids", "attention_mask"], input_p.keys()):
self.assert_sequence_almost_equals(input_p[key], input_r[key], threshold)
self.assertSequenceEqual(input_p[key], input_r[key])
# Ensure truncation with stride match
input_p = tokenizer_p.encode_plus(self._data, max_length=512, stride=3, return_overflowing_tokens=True)
input_r = tokenizer_r.encode_plus(self._data, max_length=512, stride=3, return_overflowing_tokens=True)
for key in filter(lambda x: x in ["input_ids", "token_type_ids", "attention_mask"], input_p.keys()):
self.assert_sequence_almost_equals(input_p[key], input_r[key], threshold)
self.assertSequenceEqual(input_p[key], input_r[key])
def assert_padding(self, tokenizer_r, tokenizer_p):
# Simple input
input_r = tokenizer_r.encode("This is a simple input", max_length=15, pad_to_max_length=True)
input_p = tokenizer_p.encode("This is a simple input", max_length=15, pad_to_max_length=True)
def assert_num_special_tokens_to_add_equal(self, tokenizer_r, tokenizer_p):
# Check we have the same number of added_tokens for both pair and non-pair inputs.
self.assertEqual(tokenizer_r.num_special_tokens_to_add(False), tokenizer_p.num_special_tokens_to_add(False))
self.assertEqual(tokenizer_r.num_special_tokens_to_add(True), tokenizer_p.num_special_tokens_to_add(True))
self.assertSequenceEqual(input_r, input_p)
def assert_max_length_equal(self, tokenizer_r, tokenizer_p):
# Check we have the correct max_length for both pair and non-pair inputs.
self.assertEqual(tokenizer_r.max_len_single_sentence, tokenizer_p.max_len_single_sentence)
self.assertEqual(tokenizer_r.max_len_sentences_pair, tokenizer_p.max_len_sentences_pair)
# Simple input
input_r = tokenizer_r.encode_plus("This is a simple input", max_length=15, pad_to_max_length=True)
input_p = tokenizer_p.encode_plus("This is a simple input", max_length=15, pad_to_max_length=True)
self.assertSequenceEqual(input_r, input_p)
# Simple input
# TODO: Re-enable this test when batch_encode_plus with padding correctly handles padding
# input_r = tokenizer_r.batch_encode_plus(
# ["This is a simple input 1", "This is a simple input 2"], max_length=15, pad_to_max_length=True
# )
# input_p = tokenizer_p.batch_encode_plus(
# ["This is a simple input 1", "This is a simple input 2"], max_length=15, pad_to_max_length=True
# )
# self.assertSequenceEqual(input_r, input_p)
# Pair input
input_r = tokenizer_r.encode("This is a simple input", "This is a pair", max_length=15, pad_to_max_length=True)
input_p = tokenizer_p.encode("This is a simple input", "This is a pair", max_length=15, pad_to_max_length=True)
self.assertSequenceEqual(input_r, input_p)
# Pair input
input_r = tokenizer_r.encode_plus(
"This is a simple input", "This is a pair", max_length=15, pad_to_max_length=True
def assert_special_tokens_map_equal(self, tokenizer_r, tokenizer_p):
# Assert the set of special tokens match.
self.assertSequenceEqual(
tokenizer_p.special_tokens_map.items(), tokenizer_r.special_tokens_map.items(),
)
input_p = tokenizer_p.encode_plus(
"This is a simple input", "This is a pair", max_length=15, pad_to_max_length=True
)
self.assertSequenceEqual(input_r, input_p)
# Pair input
# TODO: Re-enable this test when batch_encode_plus with padding correctly handles padding
# input_r = tokenizer_r.batch_encode_plus(
# ["This is a simple input 1", "This is a simple input 2"],
# ["This is a simple pair 1", "This is a simple pair 2"],
# max_length=15,
# pad_to_max_length=True,
# )
# input_p = tokenizer_p.batch_encode_plus(
# ["This is a simple input 1", "This is a simple input 2"],
# ["This is a simple pair 1", "This is a simple pair 2"],
# max_length=15,
# pad_to_max_length=True,
# )
# self.assertSequenceEqual(input_r, input_p)
def assert_add_tokens(self, tokenizer_r):
vocab_size = tokenizer_r.vocab_size
@@ -150,34 +145,34 @@ class FastTokenizerMatchingTest(unittest.TestCase):
)
self.assertEqual(len(tokenizer_r), vocab_size + 6)
def assert_offsets_mapping(self, tokenizer):
def assert_offsets_mapping(self, tokenizer_r):
text = "Wonderful no inspiration example with subtoken"
pair = "Along with an awesome pair"
# No pair
tokens_with_offsets = tokenizer.encode_plus(text, return_special_tokens_mask=True, return_offsets_mapping=True)
added_tokens = tokenizer.num_added_tokens(False)
tokens_with_offsets = tokenizer_r.encode_plus(
text, return_special_tokens_mask=True, return_offsets_mapping=True, add_special_tokens=True
)
added_tokens = tokenizer_r.num_special_tokens_to_add(False)
offsets = tokens_with_offsets["offset_mapping"]
# Assert there is the same number of tokens and offsets
self.assertEqual(len(offsets), len(tokens_with_offsets["input_ids"]))
# Assert there are exactly added_tokens special tokens
self.assertEqual(sum([0 if x else 1 for x in offsets]), added_tokens)
self.assertEqual(sum(tokens_with_offsets["special_tokens_mask"]), added_tokens)
# Pairs
tokens_with_offsets = tokenizer.encode_plus(
text, pair, return_special_tokens_mask=True, return_offsets_mapping=True
tokens_with_offsets = tokenizer_r.encode_plus(
text, pair, return_special_tokens_mask=True, return_offsets_mapping=True, add_special_tokens=True
)
added_tokens = tokenizer.num_added_tokens(True)
added_tokens = tokenizer_r.num_special_tokens_to_add(True)
offsets = tokens_with_offsets["offset_mapping"]
# Assert there is the same number of tokens and offsets
self.assertEqual(len(offsets), len(tokens_with_offsets["input_ids"]))
# Assert there are exactly added_tokens special tokens
self.assertEqual(sum([0 if x else 1 for x in offsets]), added_tokens)
self.assertEqual(sum(tokens_with_offsets["special_tokens_mask"]), added_tokens)
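# Hedged usage sketch of the offsets/special-tokens invariant checked above, assuming the
# bert-base-uncased checkpoint is available locally; names here are illustrative only.
from transformers import BertTokenizerFast
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
encoding = tokenizer.encode_plus(
    "Wonderful no inspiration example with subtoken",
    return_offsets_mapping=True,
    return_special_tokens_mask=True,
    add_special_tokens=True,
)
# One offset per token; special tokens ([CLS]/[SEP]) come back with a falsy offset,
# so counting falsy offsets should match the sum of the special_tokens_mask.
assert len(encoding["offset_mapping"]) == len(encoding["input_ids"])
assert sum(1 for o in encoding["offset_mapping"] if not o) == sum(encoding["special_tokens_mask"])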
def assert_batch_encode_dynamic_overflowing(self, tokenizer: PreTrainedTokenizer):
@@ -258,8 +253,89 @@ class FastTokenizerMatchingTest(unittest.TestCase):
output_p = tokenizer_p.build_inputs_with_special_tokens(input_simple, input_pair)
self.assertEqual(output_p, output_r)
def assert_save_pretrained(self, tokenizer_r, tokenizer_p):
def assert_padding(self, tokenizer_r, tokenizer_p, max_length=15):
def assert_padded_input_match(input_r: list, input_p: list, max_length: int):
# Ensure we match max_length
self.assertEqual(len(input_r), max_length), self.assertEqual(len(input_p), max_length)
# Ensure the number of padded tokens is the same
padded_tokens_r = list(takewhile(lambda i: i == tokenizer_r.pad_token_id, reversed(input_r)))
padded_tokens_p = list(takewhile(lambda i: i == tokenizer_p.pad_token_id, reversed(input_p)))
self.assertSequenceEqual(padded_tokens_r, padded_tokens_p)
def assert_batch_padded_input_match(input_r: dict, input_p: dict):
for i_r in input_r.values():
self.assertEqual(len(i_r), 2), self.assertEqual(len(i_r[0]), max_length), self.assertEqual(len(i_r[1]), max_length)
for i_r, i_p in zip(input_r["input_ids"], input_p["input_ids"]):
assert_padded_input_match(i_r, i_p, max_length)
for i_r, i_p in zip(input_r["attention_mask"], input_p["attention_mask"]):
self.assertSequenceEqual(i_r, i_p)
# Simple input
input_r = tokenizer_r.encode("This is a simple input", max_length=max_length, pad_to_max_length=True)
input_p = tokenizer_p.encode("This is a simple input", max_length=max_length, pad_to_max_length=True)
assert_padded_input_match(input_r, input_p, max_length)
# Pair input
input_r = tokenizer_r.encode(
"This is a simple input", "This is a pair", max_length=max_length, pad_to_max_length=True
)
input_p = tokenizer_p.encode(
"This is a simple input", "This is a pair", max_length=max_length, pad_to_max_length=True
)
assert_padded_input_match(input_r, input_p, max_length)
# Simple input
input_r = tokenizer_r.encode_plus("This is a simple input", max_length=max_length, pad_to_max_length=True)
input_p = tokenizer_p.encode_plus("This is a simple input", max_length=max_length, pad_to_max_length=True)
assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length)
self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
# Pair input
input_r = tokenizer_r.encode_plus(
"This is a simple input", "This is a pair", max_length=max_length, pad_to_max_length=True
)
input_p = tokenizer_p.encode_plus(
"This is a simple input", "This is a pair", max_length=max_length, pad_to_max_length=True
)
assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length)
self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
# Simple input
# TODO: Re-enable this test when batch_encode_plus with padding correctly handles padding
input_r = tokenizer_r.batch_encode_plus(
["This is a simple input 1", "This is a simple input 2"], max_length=max_length, pad_to_max_length=True
)
input_p = tokenizer_p.batch_encode_plus(
["This is a simple input 1", "This is a simple input 2"], max_length=max_length, pad_to_max_length=True
)
assert_batch_padded_input_match(input_r, input_p)
# Pair input
# TODO: Re-enable this test when batch_encode_plus with padding correctly handles padding
input_r = tokenizer_r.batch_encode_plus(
[
("This is a simple input 1", "This is a simple input 2"),
("This is a simple pair 1", "This is a simple pair 2"),
],
max_length=max_length,
pad_to_max_length=True,
)
input_p = tokenizer_p.batch_encode_plus(
[
("This is a simple input 1", "This is a simple input 2"),
("This is a simple pair 1", "This is a simple pair 2"),
],
max_length=max_length,
pad_to_max_length=True,
)
assert_batch_padded_input_match(input_r, input_p)
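# Toy sketch of the padded-suffix comparison used by assert_padded_input_match above,
# with hypothetical ids and the pad_token_id assumed to be 0; no tokenizer is needed.
from itertools import takewhile
pad_token_id = 0
ids_r = [101, 2023, 2003, 102, 0, 0, 0]
ids_p = [101, 2023, 2003, 102, 0, 0, 0]
padded_r = list(takewhile(lambda i: i == pad_token_id, reversed(ids_r)))
padded_p = list(takewhile(lambda i: i == pad_token_id, reversed(ids_p)))
assert padded_r == padded_p == [0, 0, 0]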
def assert_save_pretrained(self, tokenizer_r, tokenizer_p):
# Check it saves the same set of files
self.assertSequenceEqual(tokenizer_r.save_vocabulary("."), tokenizer_p.save_vocabulary("."))
@@ -272,267 +348,178 @@ class FastTokenizerMatchingTest(unittest.TestCase):
# self.assertEqual(getattr(tokenizer_rp, key), getattr(tokenizer_pp, key))
# self.assertEqual(getattr(tokenizer_rp, key + "_id"), getattr(tokenizer_pp, key + "_id"))
def test_bert(self):
for tokenizer_name in BertTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
tokenizer_p = BertTokenizer.from_pretrained(tokenizer_name)
tokenizer_r = BertTokenizerFast.from_pretrained(tokenizer_name)
# Check we have the same number of added_tokens for both pair and non-pair inputs.
self.assertEqual(tokenizer_r.num_added_tokens(False), tokenizer_p.num_added_tokens(False))
self.assertEqual(tokenizer_r.num_added_tokens(True), tokenizer_p.num_added_tokens(True))
# Check we have the correct max_length for both pair and non-pair inputs.
self.assertEqual(tokenizer_r.max_len_single_sentence, tokenizer_p.max_len_single_sentence)
self.assertEqual(tokenizer_r.max_len_sentences_pair, tokenizer_p.max_len_sentences_pair)
# Assert the set of special tokens match.
self.assertSequenceEqual(
tokenizer_p.special_tokens_map.items(),
tokenizer_r.special_tokens_map.items(),
"Bert tokenizers doesn't have the same set of special_tokens",
def assert_embeded_special_tokens(self, tokenizer_r, tokenizer_p):
sentence = "A, <mask> AllenNLP sentence."
tokens_r = tokenizer_r.encode_plus(
sentence, add_special_tokens=True, return_attention_mask=False, return_token_type_ids=True
)
tokens_p = tokenizer_p.encode_plus(
sentence, add_special_tokens=True, return_attention_mask=False, return_token_type_ids=True
)
# Ensure tokenization matches between the Python and Rust implementations.
self.assert_tokenization_python_rust_almost_equals(tokenizer_p, tokenizer_r, 0.0)
for key in tokens_p.keys():
self.assertEqual(tokens_r[key], tokens_p[key])
# Ensure add_tokens and add_special_tokens return the correct vocab size
self.assert_add_tokens(tokenizer_r)
self.assertEqual(sum(tokens_r["token_type_ids"]), 0)
self.assertEqual(sum(tokens_p["token_type_ids"]), 0)
# Check for offsets mapping
self.assert_offsets_mapping(tokenizer_r)
tokens_r = tokenizer_r.convert_ids_to_tokens(tokens_r["input_ids"])
tokens_p = tokenizer_p.convert_ids_to_tokens(tokens_p["input_ids"])
self.assertSequenceEqual(tokens_r, tokens_p)
# Check for dynamic encoding sequence handling in batch_encode_plus
self.assert_batch_encode_dynamic_overflowing(tokenizer_r)
def assert_add_special_tokens(self, tokenizer_r):
simple_num_special_tokens_to_add = tokenizer_r.num_special_tokens_to_add(pair=False)
# pair_num_special_tokens_to_add = tokenizer_r.num_special_tokens_to_add(pair=True)
# Check alignment for build_inputs_with_special_tokens
self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
for text in ["", " "]:
# tokenize()
no_special_tokens = tokenizer_r.tokenize(text, add_special_tokens=False)
with_special_tokens = tokenizer_r.tokenize(text, add_special_tokens=True)
self.assertEqual(len(no_special_tokens), len(with_special_tokens) - simple_num_special_tokens_to_add)
# Check the number of returned files for save_vocabulary
self.assert_save_pretrained(tokenizer_r, tokenizer_p)
# encode()
no_special_tokens = tokenizer_r.encode(text, add_special_tokens=False)
with_special_tokens = tokenizer_r.encode(text, add_special_tokens=True)
self.assertEqual(len(no_special_tokens), len(with_special_tokens) - simple_num_special_tokens_to_add)
# Check for padding
self.assert_padding(tokenizer_r, tokenizer_p)
# encode_plus()
no_special_tokens = tokenizer_r.encode_plus(text, add_special_tokens=False)
with_special_tokens = tokenizer_r.encode_plus(text, add_special_tokens=True)
for key in no_special_tokens.keys():
self.assertEqual(
len(no_special_tokens[key]), len(with_special_tokens[key]) - simple_num_special_tokens_to_add
)
# batch_encode_plus
no_special_tokens = tokenizer_r.batch_encode_plus([text, text], add_special_tokens=False)
with_special_tokens = tokenizer_r.batch_encode_plus([text, text], add_special_tokens=True)
for key in no_special_tokens.keys():
for i_no, i_with in zip(no_special_tokens[key], with_special_tokens[key]):
self.assertEqual(len(i_no), len(i_with) - simple_num_special_tokens_to_add)
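# Hedged usage sketch for num_special_tokens_to_add, assuming bert-base-uncased is available:
# it reports how many tokens the post-processing template will add, which helps budget a
# max_length by hand (the 128 and the variable names below are illustrative only).
from transformers import BertTokenizerFast
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
budget = 128
room_for_text = budget - tokenizer.num_special_tokens_to_add(pair=True)  # 128 - 3 ([CLS] + 2x[SEP]) for BERT pairs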
class WordPieceFastTokenizerTest(CommonFastTokenizerTest):
"""
Override all the specific methods to test WordPiece behavior
"""
TOKENIZERS_CLASSES = frozenset(
[
Tokenizer("Bert", BertTokenizerFast, BertTokenizer, "vocab_file", filter_non_english),
Tokenizer("DistilBert", DistilBertTokenizerFast, DistilBertTokenizer, "vocab_file", filter_non_english),
]
)
def fast_only(self, tokenizer_r):
super().fast_only(tokenizer_r)
self.assert_offsets_with_special_characters(tokenizer_r)
def assert_add_special_tokens(self, tokenizer_r):
super().assert_add_special_tokens(tokenizer_r)
def assert_offsets_with_special_characters(self, tokenizer_r):
sentence = "A, naïve [MASK] AllenNLP sentence."
tokens = tokenizer_r.encode_plus(
sentence,
return_attention_mask=False,
return_token_type_ids=False,
return_offsets_mapping=True,
add_special_tokens=True,
)
expected_results = [
((0, 1), "A"),
((1, 2), ","),
((3, 8), "naive"), # BERT normalizes this away
# Append MASK here after lower-casing
((16, 21), "Allen"),
((22, 24), "##NL"),
((24, 25), "##P"),
((26, 34), "sentence"),
((35, 36), "."),
]
# Check if the tokenizer is uncased
if tokenizer_r.init_kwargs.get("do_lower_case"):
expected_results = [(offset, token.lower()) for (offset, token) in expected_results]
# Append the special tokens
expected_results.insert(3, ((9, 15), "[MASK]"))
expected_results.insert(0, (None, "[CLS]"))
expected_results.append((None, "[SEP]"))
self.assertEqual([e[1] for e in expected_results], tokenizer_r.convert_ids_to_tokens(tokens["input_ids"]))
# self.assertEqual([e[0] for e in expected_results], tokens["offset_mapping"])
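# Hedged sketch of the normalization behaviour asserted above, assuming bert-base-uncased is
# available: the uncased normalizer lower-cases and strips the accent from "naïve", while the
# offsets keep pointing back into the original string.
from transformers import BertTokenizerFast
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
enc = tokenizer.encode_plus("A, naïve [MASK] AllenNLP sentence.", return_offsets_mapping=True)
for offset, token in zip(enc["offset_mapping"], tokenizer.convert_ids_to_tokens(enc["input_ids"])):
    print(token, offset)  # e.g. ('naive', (3, 8)) per expected_results above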
class RobertaFastTokenizerTest(CommonFastTokenizerTest):
TOKENIZERS_CLASSES = frozenset(
[Tokenizer("Roberta", RobertaTokenizerFast, RobertaTokenizer, "vocab_file", filter_roberta_detectors)]
)
def assert_embeded_special_tokens(self, tokenizer_r, tokenizer_p):
sentence = "A, <mask> AllenNLP sentence."
tokens_r = tokenizer_r.encode_plus(sentence, add_special_tokens=True, return_token_type_ids=True)
tokens_p = tokenizer_p.encode_plus(sentence, add_special_tokens=True, return_token_type_ids=True)
# Rust correctly handles the space before the mask while Python doesn't
self.assertSequenceEqual(tokens_r["input_ids"], [0, 83, 6, 50264, 3823, 487, 21992, 3645, 4, 2])
self.assertSequenceEqual(tokens_p["input_ids"], [0, 83, 6, 50264, 3823, 487, 21992, 3645, 4, 2])
# token_type_ids should put 0 everywhere
self.assertEqual(sum(tokens_r["token_type_ids"]), sum(tokens_p["token_type_ids"]))
# attention_mask should put 1 everywhere, so sum over length should be 1
self.assertEqual(
sum(tokens_r["attention_mask"]) / len(tokens_r["attention_mask"]),
sum(tokens_p["attention_mask"]) / len(tokens_p["attention_mask"]),
)
# Rust should have 'Ġ' before <mask> which should be left as an entire token
tokens_r = tokenizer_r.convert_ids_to_tokens(tokens_r["input_ids"])
self.assertSequenceEqual(tokens_r, ["<s>", "ĠA", ",", "<mask>", "ĠAllen", "N", "LP", "Ġsentence", ".", "</s>"])
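# Hedged sketch of the byte-level behaviour asserted above, assuming roberta-base is available:
# regular words keep their leading space as 'Ġ', while <mask> is preserved as a single token.
from transformers import RobertaTokenizerFast
tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
ids = tokenizer.encode_plus("A, <mask> AllenNLP sentence.", add_special_tokens=True)["input_ids"]
print(tokenizer.convert_ids_to_tokens(ids))
# expected, per the assertion above:
# ['<s>', 'ĠA', ',', '<mask>', 'ĠAllen', 'N', 'LP', 'Ġsentence', '.', '</s>']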
class NoPaddingTokenFastTokenizerMatchingTest(CommonFastTokenizerTest):
TOKENIZERS_CLASSES = [
Tokenizer("OpenAI GPT", OpenAIGPTTokenizerFast, OpenAIGPTTokenizer, "vocab_file", None),
Tokenizer("GPT2", GPT2TokenizerFast, GPT2Tokenizer, "vocab_file", None),
]
def assert_padding(self, tokenizer_r, tokenizer_p, max_length=15):
# Simple input
s = "This is a simple input"
s2 = ["This is a simple input 1", "This is a simple input 2"]
p = ("This is a simple input", "This is a pair")
p2 = [
("This is a simple input 1", "This is a simple input 2"),
("This is a simple pair 1", "This is a simple pair 2"),
]
# Simple input tests
self.assertRaises(ValueError, tokenizer_r.encode, s, max_length=max_length, pad_to_max_length=True)
# Simple input
self.assertRaises(ValueError, tokenizer_r.encode_plus, s, max_length=max_length, pad_to_max_length=True)
# Simple input
self.assertRaises(ValueError, tokenizer_r.batch_encode_plus, s2, max_length=max_length, pad_to_max_length=True)
# Pair input
self.assertRaises(ValueError, tokenizer_r.encode, p, max_length=max_length, pad_to_max_length=True)
# Pair input
self.assertRaises(ValueError, tokenizer_r.encode_plus, p, max_length=max_length, pad_to_max_length=True)
# Pair input
self.assertRaises(ValueError, tokenizer_r.batch_encode_plus, p2, max_length=max_length, pad_to_max_length=True)
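# Hedged sketch of the failure mode exercised above, assuming the gpt2 checkpoint is available:
# GPT-2 ships without a pad token, so requesting pad_to_max_length raises a ValueError.
from transformers import GPT2TokenizerFast
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
try:
    tokenizer.encode("This is a simple input", max_length=15, pad_to_max_length=True)
except ValueError as err:
    print("padding without a pad token was rejected:", err)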
class TransfoXLFastTokenizerTest(NoPaddingTokenFastTokenizerMatchingTest):
TOKENIZERS_CLASSES = frozenset(
[Tokenizer("TransfoXL", TransfoXLTokenizerFast, TransfoXLTokenizer, "pretrained_vocab_file", None)]
)
@require_torch
def test_transfoxl(self):
for tokenizer_name in TransfoXLTokenizer.pretrained_vocab_files_map["pretrained_vocab_file"].keys():
tokenizer_p = TransfoXLTokenizer.from_pretrained(tokenizer_name)
tokenizer_r = TransfoXLTokenizerFast.from_pretrained(tokenizer_name)
# Check we have the same number of added_tokens for both pair and non-pair inputs.
self.assertEqual(tokenizer_r.num_added_tokens(False), tokenizer_p.num_added_tokens(False))
self.assertEqual(tokenizer_r.num_added_tokens(True), tokenizer_p.num_added_tokens(True))
# Check we have the correct max_length for both pair and non-pair inputs.
self.assertEqual(tokenizer_r.max_len_single_sentence, tokenizer_p.max_len_single_sentence)
self.assertEqual(tokenizer_r.max_len_sentences_pair, tokenizer_p.max_len_sentences_pair)
# Assert the set of special tokens match.
self.assertSequenceEqual(
tokenizer_p.special_tokens_map.items(),
tokenizer_r.special_tokens_map.items(),
"TransfoXL tokenizers doesn't have the same set of special_tokens",
)
# Ensure tokenization matches between the Python and Rust implementations.
self.assert_tokenization_python_rust_almost_equals(tokenizer_p, tokenizer_r, 0.0)
# Ensure add_tokens and add_special_tokens return the correct vocab size
self.assert_add_tokens(tokenizer_r)
# Check for offsets mapping
self.assert_offsets_mapping(tokenizer_r)
# Check for dynamic encoding sequence handling in batch_encode_plus
self.assertRaises(ValueError, self.assert_batch_encode_dynamic_overflowing, tokenizer_r)
# Check alignment for build_inputs_with_special_tokens
self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
# Check for padding
self.assertRaises(ValueError, self.assert_padding, tokenizer_r, tokenizer_p)
# Check the number of returned files for save_vocabulary
# TransfoXL tokenizers come in a special format which is not compatible at all
# with Rust tokenizers. We ensure the errors are correctly raised.
tokenizer_r_files = tokenizer_r.save_pretrained(".")
self.assertSequenceEqual(
tokenizer_r_files, ["./vocab.json", "./special_tokens_map.json", "./added_tokens.json"]
)
# Check loading a Python-tokenizer save through Rust doesn't work (and the opposite)
self.assertRaises(ValueError, tokenizer_p.from_pretrained, *tokenizer_r_files)
self.assertRaises(ValueError, tokenizer_r.from_pretrained, *tokenizer_p.save_pretrained("."))
# Check loading works for Python to Python and Rust to Rust
# Issue: https://github.com/huggingface/transformers/issues/3000
# self.assertIsNotNone(tokenizer_p.__class__.from_pretrained('./'))
self.assertIsNotNone(tokenizer_r.__class__.from_pretrained("./"))
def test_distilbert(self):
for tokenizer_name in DistilBertTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
tokenizer_p = DistilBertTokenizer.from_pretrained(tokenizer_name)
tokenizer_r = DistilBertTokenizerFast.from_pretrained(tokenizer_name)
# Check we have the same number of added_tokens for both pair and non-pair inputs.
self.assertEqual(tokenizer_r.num_added_tokens(False), tokenizer_p.num_added_tokens(False))
self.assertEqual(tokenizer_r.num_added_tokens(True), tokenizer_p.num_added_tokens(True))
# Check we have the correct max_length for both pair and non-pair inputs.
self.assertEqual(tokenizer_r.max_len_single_sentence, tokenizer_p.max_len_single_sentence)
self.assertEqual(tokenizer_r.max_len_sentences_pair, tokenizer_p.max_len_sentences_pair)
# DistilBert should match 100%
# Assert the set of special tokens match.
self.assertSequenceEqual(
tokenizer_p.special_tokens_map.items(),
tokenizer_r.special_tokens_map.items(),
"DistilBert tokenizers doesn't have the same set of special_tokens",
)
# Ensure tokenization matches between the Python and Rust implementations.
self.assert_tokenization_python_rust_almost_equals(tokenizer_p, tokenizer_r, 0.0)
# Ensure add_tokens and add_special_tokens return the correct vocab size
self.assert_add_tokens(tokenizer_r)
# Check for offsets mapping
self.assert_offsets_mapping(tokenizer_r)
# Check for dynamic encoding sequence handling in batch_encode_plus
self.assert_batch_encode_dynamic_overflowing(tokenizer_r)
# Check alignment for build_inputs_with_special_tokens
self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
# Check the number of returned files for save_vocabulary
self.assert_save_pretrained(tokenizer_r, tokenizer_p)
# Check for padding
self.assert_padding(tokenizer_r, tokenizer_p)
def test_gpt2(self):
for tokenizer_name in GPT2Tokenizer.pretrained_vocab_files_map["vocab_file"].keys():
tokenizer_p = GPT2Tokenizer.from_pretrained(tokenizer_name)
tokenizer_r = GPT2TokenizerFast.from_pretrained(tokenizer_name)
# Check we have the same number of added_tokens for both pair and non-pair inputs.
self.assertEqual(tokenizer_r.num_added_tokens(False), tokenizer_p.num_added_tokens(False))
self.assertEqual(tokenizer_r.num_added_tokens(True), tokenizer_p.num_added_tokens(True))
# Check we have the correct max_length for both pair and non-pair inputs.
self.assertEqual(tokenizer_r.max_len_single_sentence, tokenizer_p.max_len_single_sentence)
self.assertEqual(tokenizer_r.max_len_sentences_pair, tokenizer_p.max_len_sentences_pair)
# Assert the set of special tokens match.
self.assertSequenceEqual(
tokenizer_p.special_tokens_map.items(),
tokenizer_r.special_tokens_map.items(),
"GPT2 tokenizers doesn't have the same set of special_tokens",
)
# Ensure tokenization matches between the Python and Rust implementations.
self.assert_tokenization_python_rust_almost_equals(tokenizer_p, tokenizer_r, 0.0)
# Ensure add_tokens and add_special_tokens return the correct vocab size
self.assert_add_tokens(tokenizer_r)
# Check for offsets mapping
self.assert_offsets_mapping(tokenizer_r)
# Check for dynamic encoding sequence handling in batch_encode_plus
self.assertRaises(ValueError, self.assert_batch_encode_dynamic_overflowing, tokenizer_r)
# Check alignment for build_inputs_with_special_tokens
self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
# Check the number of returned files for save_vocabulary
self.assert_save_pretrained(tokenizer_r, tokenizer_p)
# Check for padding
self.assertRaises(ValueError, self.assert_padding, tokenizer_r, tokenizer_p)
def test_roberta(self):
for tokenizer_name in RobertaTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
tokenizer_p = RobertaTokenizer.from_pretrained(tokenizer_name)
tokenizer_r = RobertaTokenizerFast.from_pretrained(tokenizer_name)
# Check we have the same number of added_tokens for both pair and non-pair inputs.
self.assertEqual(tokenizer_r.num_added_tokens(False), tokenizer_p.num_added_tokens(False))
self.assertEqual(tokenizer_r.num_added_tokens(True), tokenizer_p.num_added_tokens(True))
# Check we have the correct max_length for both pair and non-pair inputs.
self.assertEqual(tokenizer_r.max_len_single_sentence, tokenizer_p.max_len_single_sentence)
self.assertEqual(tokenizer_r.max_len_sentences_pair, tokenizer_p.max_len_sentences_pair)
# Assert the set of special tokens match.
self.assertSequenceEqual(
tokenizer_p.special_tokens_map.items(),
tokenizer_r.special_tokens_map.items(),
"Roberta tokenizers doesn't have the same set of special_tokens",
)
# Ensure tokenization matches between the Python and Rust implementations.
self.assert_tokenization_python_rust_almost_equals(tokenizer_p, tokenizer_r, 0.01)
# Ensure add_tokens and add_special_tokens return the correct vocab size
self.assert_add_tokens(tokenizer_r)
# Check for offsets mapping
self.assert_offsets_mapping(tokenizer_r)
# Check for dynamic encoding sequence handling in batch_encode_plus
self.assert_batch_encode_dynamic_overflowing(tokenizer_r)
# Check alignment for build_inputs_with_special_tokens
self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
# Check the number of returned files for save_vocabulary
self.assert_save_pretrained(tokenizer_r, tokenizer_p)
# Check for padding
# TODO: Re-enable this test as soon as Roberta align with the python tokenizer.
# self.assert_padding(tokenizer_r, tokenizer_p)
def test_openai(self):
for tokenizer_name in OpenAIGPTTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
tokenizer_p = OpenAIGPTTokenizer.from_pretrained(tokenizer_name)
tokenizer_r = OpenAIGPTTokenizerFast.from_pretrained(tokenizer_name)
# Check we have the same number of added_tokens for both pair and non-pair inputs.
self.assertEqual(tokenizer_r.num_added_tokens(False), tokenizer_p.num_added_tokens(False))
self.assertEqual(tokenizer_r.num_added_tokens(True), tokenizer_p.num_added_tokens(True))
# Check we have the correct max_length for both pair and non-pair inputs.
self.assertEqual(tokenizer_r.max_len_single_sentence, tokenizer_p.max_len_single_sentence)
self.assertEqual(tokenizer_r.max_len_sentences_pair, tokenizer_p.max_len_sentences_pair)
# Assert the set of special tokens match.
self.assertSequenceEqual(
tokenizer_p.special_tokens_map.items(),
tokenizer_r.special_tokens_map.items(),
"GPT tokenizers doesn't have the same set of special_tokens",
)
# Ensure tokenization matches between the Python and Rust implementations.
self.assert_tokenization_python_rust_almost_equals(tokenizer_p, tokenizer_r, 0.0)
# Ensure add_tokens and add_special_tokens return the correct vocab size
self.assert_add_tokens(tokenizer_r)
# Check for offsets mapping
self.assert_offsets_mapping(tokenizer_r)
# Check for dynamic encoding sequence handling in batch_encode_plus
self.assertRaises(ValueError, self.assert_batch_encode_dynamic_overflowing, tokenizer_r)
# Check alignment for build_inputs_with_special_tokens
self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
# Check for padding
self.assertRaises(ValueError, self.assert_padding, tokenizer_r, tokenizer_p)
# Check the number of returned files for save_vocabulary
self.assert_save_pretrained(tokenizer_r, tokenizer_p)
def test_all_tokenizers(self):
super().test_all_tokenizers()


@@ -94,7 +94,7 @@ class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
return
tokenizer = self.get_tokenizer()
rust_tokenizer = self.get_rust_tokenizer(add_special_tokens=False, add_prefix_space=True)
rust_tokenizer = self.get_rust_tokenizer(add_prefix_space=True)
sequence = "lower newer"
@@ -105,7 +105,7 @@ class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
# Testing conversion to ids without special tokens
ids = tokenizer.encode(sequence, add_special_tokens=False, add_prefix_space=True)
rust_ids = rust_tokenizer.encode(sequence)
rust_ids = rust_tokenizer.encode(sequence, add_special_tokens=False)
self.assertListEqual(ids, rust_ids)
# Testing conversion to ids with special tokens