Tokenizers API developments (#5103)

* Add return lengths

* make pad a bit more flexible so it can be used as collate_fn

* check all kwargs sent to encoding method are known

* fixing kwargs in encodings

* New AddedToken class in python

This class let you specify specifique tokenization behaviors for some special tokens. Used in particular for GPT2 and Roberta, to control how white spaces are stripped around special tokens.

* style and quality

* switched to hugginface tokenizers library for AddedTokens

* up to tokenizer 0.8.0-rc3 - update API to use AddedToken state

* style and quality

* do not raise an error on additional or unused kwargs for tokenize() but only a warning

* transfo-xl pretrained model requires torch

* Update src/transformers/tokenization_utils.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
Thomas Wolf 2020-06-23 13:36:57 +02:00 committed by GitHub
parent 1ae132a07d
commit 11fdde0271
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 414 additions and 230 deletions

View File

@ -109,7 +109,7 @@ setup(
packages=find_packages("src"),
install_requires=[
"numpy",
"tokenizers == 0.8.0-rc1",
"tokenizers == 0.8.0-rc3",
# dataclasses for Python versions that don't have it
"dataclasses;python_version<'3.7'",
# utilities from PyPA to e.g. compare versions

View File

@ -23,7 +23,7 @@ from typing import List, Optional
from tokenizers import BertWordPieceTokenizer
from .tokenization_utils import PreTrainedTokenizer
from .tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
from .tokenization_utils_fast import PreTrainedTokenizerFast
@ -547,45 +547,6 @@ class WordpieceTokenizer(object):
return output_tokens
def _is_whitespace(char):
"""Checks whether `chars` is a whitespace character."""
# \t, \n, and \r are technically contorl characters but we treat them
# as whitespace since they are generally considered as such.
if char == " " or char == "\t" or char == "\n" or char == "\r":
return True
cat = unicodedata.category(char)
if cat == "Zs":
return True
return False
def _is_control(char):
"""Checks whether `chars` is a control character."""
# These are technically control characters but we count them as whitespace
# characters.
if char == "\t" or char == "\n" or char == "\r":
return False
cat = unicodedata.category(char)
if cat.startswith("C"):
return True
return False
def _is_punctuation(char):
"""Checks whether `chars` is a punctuation character."""
cp = ord(char)
# We treat all non-letter/number ASCII as punctuation.
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126):
return True
cat = unicodedata.category(char)
if cat.startswith("P"):
return True
return False
class BertTokenizerFast(PreTrainedTokenizerFast):
r"""
Constructs a "Fast" BERT tokenizer (backed by HuggingFace's `tokenizers` library).

View File

@ -146,6 +146,7 @@ class GPT2Tokenizer(PreTrainedTokenizer):
unk_token="<|endoftext|>",
bos_token="<|endoftext|>",
eos_token="<|endoftext|>",
add_prefix_space=False,
**kwargs
):
super().__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs)
@ -161,6 +162,7 @@ class GPT2Tokenizer(PreTrainedTokenizer):
bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
self.cache = {}
self.add_prefix_space = add_prefix_space
# Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
@ -273,10 +275,11 @@ class GPT2Tokenizer(PreTrainedTokenizer):
return vocab_file, merge_file
def prepare_for_tokenization(self, text, **kwargs):
if "add_prefix_space" in kwargs and kwargs["add_prefix_space"]:
return " " + text
return text
def prepare_for_tokenization(self, text, is_pretokenized=False, **kwargs):
add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
if is_pretokenized or add_prefix_space:
text = " " + text
return (text, kwargs)
class GPT2TokenizerFast(PreTrainedTokenizerFast):
@ -354,7 +357,7 @@ class GPT2TokenizerFast(PreTrainedTokenizerFast):
is_pretokenized = kwargs.get("is_pretokenized", False)
assert self.add_prefix_space or not is_pretokenized, (
"You need to instantiate GPT2TokenizerFast with add_prefix_space=False "
f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
"to use it with pretokenized inputs."
)
@ -364,7 +367,7 @@ class GPT2TokenizerFast(PreTrainedTokenizerFast):
is_pretokenized = kwargs.get("is_pretokenized", False)
assert self.add_prefix_space or not is_pretokenized, (
"You need to instantiate GPT2TokenizerFast with add_prefix_space=False "
f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
"to use it with pretokenized inputs."
)

View File

@ -18,11 +18,10 @@
import logging
from typing import List, Optional
from tokenizers import AddedToken
from tokenizers.processors import RobertaProcessing
from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
from .tokenization_utils import PreTrainedTokenizer
from .tokenization_utils import AddedToken, PreTrainedTokenizer
logger = logging.getLogger(__name__)
@ -135,6 +134,7 @@ class RobertaTokenizer(GPT2Tokenizer):
unk_token="<unk>",
pad_token="<pad>",
mask_token="<mask>",
add_prefix_space=False,
**kwargs
):
super().__init__(
@ -148,9 +148,17 @@ class RobertaTokenizer(GPT2Tokenizer):
cls_token=cls_token,
pad_token=pad_token,
mask_token=mask_token,
add_prefix_space=add_prefix_space,
**kwargs,
)
@PreTrainedTokenizer.mask_token.setter
def mask_token(self, value):
if not isinstance(value, AddedToken):
value = AddedToken(value, lstrip=True)
self._mask_token = value
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
@ -231,14 +239,11 @@ class RobertaTokenizer(GPT2Tokenizer):
return len(cls + token_ids_0 + sep) * [0]
return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
def prepare_for_tokenization(self, text, add_special_tokens=False, **kwargs):
if "add_prefix_space" in kwargs:
add_prefix_space = kwargs["add_prefix_space"]
else:
add_prefix_space = add_special_tokens
if add_prefix_space and len(text) > 0 and not text[0].isspace():
def prepare_for_tokenization(self, text, is_pretokenized=False, **kwargs):
add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
if (is_pretokenized or add_prefix_space) and text:
text = " " + text
return text
return (text, kwargs)
class RobertaTokenizerFast(GPT2TokenizerFast):
@ -300,7 +305,7 @@ class RobertaTokenizerFast(GPT2TokenizerFast):
unk_token="<unk>",
pad_token="<pad>",
mask_token="<mask>",
add_prefix_space=True,
add_prefix_space=False,
trim_offsets=True,
**kwargs
):
@ -327,15 +332,14 @@ class RobertaTokenizerFast(GPT2TokenizerFast):
trim_offsets=trim_offsets,
)
self.backend_tokenizer.add_special_tokens([kwargs["mask_token"]])
self.sanitize_special_tokens() # This will add the necessary special tokens to the vocabulary if needed.
@PreTrainedTokenizer.mask_token.setter
def mask_token(self, value):
if not isinstance(value, AddedToken):
value = AddedToken(value, lstrip=True)
self._mask_token = str(value)
self._maybe_update_backend([value])
self._mask_token = value
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]

View File

@ -355,10 +355,10 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
else:
return symbols
def prepare_for_tokenization(self, text, **kwargs):
def prepare_for_tokenization(self, text, is_pretokenized=False, **kwargs):
# add spaces before punctuation symbols as should be done in transfo-xl
if "add_space_before_punct_symbol" in kwargs and kwargs["add_space_before_punct_symbol"]:
add_space_before_punct_symbol = kwargs.pop("add_space_before_punct_symbol", False)
if add_space_before_punct_symbol:
text = self.punctuation_with_space_around_pattern.sub(r" ", text)
elif self.punction_without_space_before_pattern.search(text):
# searches until the first occurence of a punctuation symbol without surrounding spaces
@ -366,7 +366,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
"You might want to consider setting `add_space_before_punct_symbol=True` as an argument to the `tokenizer.encode()` to avoid tokenizing words with punctuation symbols to the `<unk>` token"
)
return text
return (text, kwargs)
class _TransfoXLDelimiterLookupTokenizer(BaseTokenizer):

View File

@ -19,12 +19,14 @@
import itertools
import logging
import re
import unicodedata
from typing import List, Optional, Tuple, Union
from .file_utils import add_end_docstrings
from .tokenization_utils_base import (
ENCODE_KWARGS_DOCSTRING,
ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING,
AddedToken,
BatchEncoding,
EncodedInput,
EncodedInputPair,
@ -42,6 +44,57 @@ from .tokenization_utils_base import (
logger = logging.getLogger(__name__)
def _is_whitespace(char):
"""Checks whether `chars` is a whitespace character."""
# \t, \n, and \r are technically contorl characters but we treat them
# as whitespace since they are generally considered as such.
if char == " " or char == "\t" or char == "\n" or char == "\r":
return True
cat = unicodedata.category(char)
if cat == "Zs":
return True
return False
def _is_control(char):
"""Checks whether `chars` is a control character."""
# These are technically control characters but we count them as whitespace
# characters.
if char == "\t" or char == "\n" or char == "\r":
return False
cat = unicodedata.category(char)
if cat.startswith("C"):
return True
return False
def _is_punctuation(char):
"""Checks whether `chars` is a punctuation character."""
cp = ord(char)
# We treat all non-letter/number ASCII as punctuation.
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126):
return True
cat = unicodedata.category(char)
if cat.startswith("P"):
return True
return False
def _is_end_of_word(text):
"""Checks whether the last character in text is one of a punctuation, control or whitespace character."""
last_char = text[-1]
return bool(_is_control(last_char) | _is_punctuation(last_char) | _is_whitespace(last_char))
def _is_start_of_word(text):
"""Checks whether the first character in text is one of a punctuation, control or whitespace character."""
first_char = text[0]
return bool(_is_control(first_char) | _is_punctuation(first_char) | _is_whitespace(first_char))
class PreTrainedTokenizer(PreTrainedTokenizerBase):
""" Base class for all slow tokenizers.
@ -104,7 +157,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
super().__init__(**kwargs)
# Added tokens
self.added_tokens_encoder = {}
self.unique_added_tokens_encoder = set()
self.unique_added_tokens_encoder = []
self.added_tokens_decoder = {}
@property
@ -124,7 +177,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
""" Size of the full vocabulary with the added tokens """
return self.vocab_size + len(self.added_tokens_encoder)
def add_tokens(self, new_tokens: Union[str, List[str]]) -> int:
def add_tokens(self, new_tokens: Union[str, List[str]], special_token=False) -> int:
"""
Add a list of new tokens to the tokenizer class. If the new tokens are not in the
vocabulary, they are added to it with indices starting from length of the current vocabulary.
@ -154,7 +207,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
tokens_to_add = []
for token in new_tokens:
assert isinstance(token, str)
assert isinstance(token, (str, AddedToken))
if self.init_kwargs.get("do_lower_case", False) and token not in self.all_special_tokens:
token = token.lower()
if (
@ -169,7 +222,9 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(tokens_to_add))
added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
self.added_tokens_encoder.update(added_tok_encoder)
self.unique_added_tokens_encoder = set(self.added_tokens_encoder.keys()).union(set(self.all_special_tokens))
self.unique_added_tokens_encoder = list(
set(self.added_tokens_encoder.keys()).union(set(self.all_special_tokens))
)
self.added_tokens_decoder.update(added_tok_decoder)
return len(tokens_to_add)
@ -204,24 +259,63 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
text (:obj:`string`): The sequence to be encoded.
**kwargs (:obj: `dict`): Arguments passed to the model-specific `prepare_for_tokenization` preprocessing method.
"""
all_special_tokens = self.all_special_tokens
text = self.prepare_for_tokenization(text, **kwargs)
# Simple mapping string => AddedToken for special tokens with specific tokenization behaviors
all_special_tokens_extended = dict(
(str(t), t) for t in self.all_special_tokens_extended if isinstance(t, AddedToken)
)
text, kwargs = self.prepare_for_tokenization(text, **kwargs)
if kwargs:
logger.warning(f"Keyword arguments {kwargs} not recognized.")
# TODO: should this be in the base class?
def lowercase_text(t):
# convert non-special tokens to lowercase
escaped_special_toks = [re.escape(s_tok) for s_tok in all_special_tokens]
pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
return re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), t)
if self.init_kwargs.get("do_lower_case", False):
text = lowercase_text(text)
# convert non-special tokens to lowercase
escaped_special_toks = [re.escape(s_tok) for s_tok in self.all_special_tokens]
pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text)
def split_on_token(tok, text):
result = []
tok_extended = all_special_tokens_extended.get(tok, None)
split_text = text.split(tok)
full_word = ""
for i, sub_text in enumerate(split_text):
sub_text = sub_text.rstrip()
# AddedToken can control whitespace stripping around them.
# We use them for GPT2 and Roberta to have different behavior depending on the special token
# Cf. https://github.com/huggingface/transformers/pull/2778
# and https://github.com/huggingface/transformers/issues/3788
if isinstance(tok_extended, AddedToken):
if tok_extended.single_word:
# Try to avoid splitting on token
if (
i < len(split_text) - 1
and not _is_end_of_word(sub_text)
and not _is_start_of_word(split_text[i + 1])
):
# Don't extract the special token
full_word += sub_text + tok
elif full_word:
full_word += sub_text
result += [full_word]
full_word = ""
continue
# Strip white spaces on the right
if tok_extended.rstrip and i > 0:
# A bit counter-intuitive but we strip the left of the string
# since tok_extended.rstrip means the special token is eating all white spaces on its right
sub_text = sub_text.lstrip()
# Strip white spaces on the left
if tok_extended.lstrip and i < len(split_text) - 1:
sub_text = sub_text.rstrip() # Opposite here
else:
# We strip left and right by default
if i < len(split_text) - 1:
sub_text = sub_text.rstrip()
if i > 0:
sub_text = sub_text.lstrip()
if i == 0 and not sub_text:
result += [tok]
elif i == len(split_text) - 1:
@ -316,23 +410,17 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
return_length: bool = False,
verbose: bool = True,
**kwargs
) -> BatchEncoding:
def get_input_ids(text):
if isinstance(text, str):
tokens = self.tokenize(text, add_special_tokens=add_special_tokens, **kwargs)
tokens = self.tokenize(text, **kwargs)
return self.convert_tokens_to_ids(tokens)
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
if is_pretokenized:
tokens = list(
itertools.chain(
*(
self.tokenize(t, add_special_tokens=False, add_prefix_space=True, **kwargs)
for t in text
)
)
)
tokens = list(itertools.chain(*(self.tokenize(t, is_pretokenized=True, **kwargs) for t in text)))
return self.convert_tokens_to_ids(tokens)
else:
return self.convert_tokens_to_ids(text)
@ -369,6 +457,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
return_token_type_ids=return_token_type_ids,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
return_length=return_length,
verbose=verbose,
)
@ -390,28 +479,21 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
is_pretokenized: bool = False,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_masks: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
return_special_tokens_masks: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
return_lengths: bool = False,
return_length: bool = False,
verbose: bool = True,
**kwargs
) -> BatchEncoding:
def get_input_ids(text):
if isinstance(text, str):
tokens = self.tokenize(text, add_special_tokens=add_special_tokens, **kwargs)
tokens = self.tokenize(text, **kwargs)
return self.convert_tokens_to_ids(tokens)
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
if is_pretokenized:
tokens = list(
itertools.chain(
*(
self.tokenize(t, add_special_tokens=False, add_prefix_space=True, **kwargs)
for t in text
)
)
)
tokens = list(itertools.chain(*(self.tokenize(t, is_pretokenized=True, **kwargs) for t in text)))
return self.convert_tokens_to_ids(tokens)
else:
return self.convert_tokens_to_ids(text)
@ -449,11 +531,11 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
truncation_strategy=truncation_strategy,
max_length=max_length,
stride=stride,
return_attention_masks=return_attention_masks,
return_attention_mask=return_attention_mask,
return_token_type_ids=return_token_type_ids,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_masks=return_special_tokens_masks,
return_lengths=return_lengths,
return_special_tokens_mask=return_special_tokens_mask,
return_length=return_length,
return_tensors=return_tensors,
verbose=verbose,
)
@ -471,10 +553,10 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
stride: int = 0,
return_tensors: Optional[str] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_masks: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
return_special_tokens_masks: bool = False,
return_lengths: bool = False,
return_special_tokens_mask: bool = False,
return_length: bool = False,
verbose: bool = True,
) -> BatchEncoding:
""" Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model.
@ -507,11 +589,11 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
truncation_strategy=truncation_strategy,
max_length=max_length,
stride=stride,
return_attention_mask=return_attention_masks,
return_attention_mask=return_attention_mask,
return_token_type_ids=return_token_type_ids,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_masks,
return_lengths=return_lengths,
return_special_tokens_mask=return_special_tokens_mask,
return_length=return_length,
return_tensors=None, # We will convert the whole batch to tensors at the end
prepend_batch_axis=False,
verbose=verbose,
@ -542,7 +624,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_lengths: bool = False,
return_length: bool = False,
verbose: bool = True,
) -> BatchEncoding:
""" Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model.
@ -615,7 +697,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
return_attention_mask=return_attention_mask,
)
if return_lengths:
if return_length:
encoded_inputs["length"] = len(encoded_inputs["input_ids"])
batch_outputs = BatchEncoding(
@ -624,9 +706,13 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
return batch_outputs
def prepare_for_tokenization(self, text: str, **kwargs) -> str:
""" Performs any necessary transformations before tokenization """
return text
def prepare_for_tokenization(self, text: str, is_pretokenized=False, **kwargs) -> (str, dict):
""" Performs any necessary transformations before tokenization.
This method should pop the arguments from kwargs and return kwargs as well.
We test kwargs at the end of the encoding process to be sure all the arguments have been used.
"""
return (text, kwargs)
def truncate_sequences(
self,

View File

@ -28,7 +28,7 @@ from enum import Enum
from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
import numpy as np
from tokenizers import AddedToken as AddedTokenFast
from tokenizers import AddedToken
from tokenizers import Encoding as EncodingFast
from .file_utils import (
@ -522,6 +522,43 @@ class BatchEncoding(UserDict):
return self
# class AddedToken(UserString):
# """ AddedToken represents a token to be added to a Tokenizer
# An AddedToken can have special options defining the way it should behave.
# Args:
# content: str:
# The content of the token
# single_word: bool
# Whether this token should only match against single word. If True,
# this token will never match inside of a word.
# lstrip: bool
# Whether this token should strip all potential whitespaces on the left side.
# If True, this token will greedily match any whitespace on the left and then strip
# them out.
# rstrip: bool
# Whether this token should strip all potential whitespaces on the right side.
# If True, this token will greedily match any whitespace on the right and then strip
# them out.
# """
# def __init__(
# self, data: str, single_word: bool = False, lstrip: bool = False, rstrip: bool = False,
# ):
# super().__init__(data)
# self._single_word = single_word
# self._lstrip = lstrip
# self._rstrip = rstrip
# def lower(self):
# return AddedToken(self.data.lower(), self._single_word, self._lstrip, self._rstrip)
class SpecialTokensMixin:
""" SpecialTokensMixin is derived by ``PreTrainedTokenizer`` and ``PreTrainedTokenizerFast`` and
handles specific behaviors related to special tokens. In particular, this class hold the
@ -556,15 +593,22 @@ class SpecialTokensMixin:
if key in self.SPECIAL_TOKENS_ATTRIBUTES:
if key == "additional_special_tokens":
assert isinstance(value, (list, tuple)) and all(isinstance(t, str) for t in value)
elif isinstance(value, AddedTokenFast):
setattr(self, key, str(value))
elif isinstance(value, str):
elif isinstance(value, (str, AddedToken)):
setattr(self, key, value)
else:
raise TypeError(
"special token {} has to be either str or AddedTokenFast but got: {}".format(key, type(value))
"special token {} has to be either str or AddedToken but got: {}".format(key, type(value))
)
def sanitize_special_tokens(self):
""" Make sure that all the special tokens attributes of the tokenizer (tokenizer.mask_token, tokenizer.cls_token, ...)
are in the vocabulary. Add the missing ones to the vocabulary if needed.
Return:
Number of tokens added in the vocaulary during the operation.
"""
return self.add_tokens(self.all_special_tokens_extended, special_token=True)
def add_special_tokens(self, special_tokens_dict):
"""
Add a dictionary of special tokens (eos, pad, cls...) to the encoder and link them
@ -608,121 +652,118 @@ class SpecialTokensMixin:
added_tokens = 0
for key, value in special_tokens_dict.items():
assert key in self.SPECIAL_TOKENS_ATTRIBUTES
if key == "additional_special_tokens":
assert isinstance(value, (list, tuple)) and all(isinstance(t, str) for t in value)
added_tokens += self.add_tokens(value)
else:
assert isinstance(value, str)
added_tokens += self.add_tokens([value])
if self.verbose:
logger.info("Assigning %s to the %s key of the tokenizer", value, key)
setattr(self, key, value)
if key == "additional_special_tokens":
assert isinstance(value, (list, tuple)) and all(isinstance(t, str) for t in value)
added_tokens += self.add_tokens(value, special_token=True)
else:
assert isinstance(value, str)
added_tokens += self.add_tokens([value], special_token=True)
return added_tokens
def add_tokens(self, value):
def add_tokens(self, value, special_token=False):
""" To be overriden by derived class to add a token in the vocabulary. """
pass
def _maybe_update_backend(self, value):
""" To be overriden by derived class if a backend tokenizer has to be updated. """
pass
@property
def bos_token(self):
""" Beginning of sentence token (string). Log an error if used while not having been set. """
if self._bos_token is None and self.verbose:
logger.error("Using bos_token, but it is not set yet.")
return self._bos_token
return None
return str(self._bos_token)
@property
def eos_token(self):
""" End of sentence token (string). Log an error if used while not having been set. """
if self._eos_token is None and self.verbose:
logger.error("Using eos_token, but it is not set yet.")
return self._eos_token
return None
return str(self._eos_token)
@property
def unk_token(self):
""" Unknown token (string). Log an error if used while not having been set. """
if self._unk_token is None and self.verbose:
logger.error("Using unk_token, but it is not set yet.")
return self._unk_token
return None
return str(self._unk_token)
@property
def sep_token(self):
""" Separation token (string). E.g. separate context and query in an input sequence. Log an error if used while not having been set. """
if self._sep_token is None and self.verbose:
logger.error("Using sep_token, but it is not set yet.")
return self._sep_token
return None
return str(self._sep_token)
@property
def pad_token(self):
""" Padding token (string). Log an error if used while not having been set. """
if self._pad_token is None and self.verbose:
logger.error("Using pad_token, but it is not set yet.")
return self._pad_token
return None
return str(self._pad_token)
@property
def cls_token(self):
""" Classification token (string). E.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model. Log an error if used while not having been set. """
if self._cls_token is None and self.verbose:
logger.error("Using cls_token, but it is not set yet.")
return self._cls_token
return None
return str(self._cls_token)
@property
def mask_token(self):
""" Mask token (string). E.g. when training a model with masked-language modeling. Log an error if used while not having been set. """
if self._mask_token is None and self.verbose:
logger.error("Using mask_token, but it is not set yet.")
return self._mask_token
return None
return str(self._mask_token)
@property
def additional_special_tokens(self):
""" All the additional special tokens you may want to use (list of strings). Log an error if used while not having been set. """
if self._additional_special_tokens is None and self.verbose:
logger.error("Using additional_special_tokens, but it is not set yet.")
return self._additional_special_tokens
return None
return [str(tok) for tok in self._additional_special_tokens]
@bos_token.setter
def bos_token(self, value):
self._bos_token = value
self._maybe_update_backend([value])
@eos_token.setter
def eos_token(self, value):
self._eos_token = value
self._maybe_update_backend([value])
@unk_token.setter
def unk_token(self, value):
self._unk_token = value
self._maybe_update_backend([value])
@sep_token.setter
def sep_token(self, value):
self._sep_token = value
self._maybe_update_backend([value])
@pad_token.setter
def pad_token(self, value):
self._pad_token = value
self._maybe_update_backend([value])
@cls_token.setter
def cls_token(self, value):
self._cls_token = value
self._maybe_update_backend([value])
@mask_token.setter
def mask_token(self, value):
self._mask_token = value
self._maybe_update_backend([value])
@additional_special_tokens.setter
def additional_special_tokens(self, value):
self._additional_special_tokens = value
self._maybe_update_backend(value)
@property
def bos_token_id(self):
@ -773,6 +814,23 @@ class SpecialTokensMixin:
def special_tokens_map(self):
""" A dictionary mapping special token class attribute (cls_token, unk_token...) to their
values ('<unk>', '<cls>'...)
Convert tokens of AddedToken type in string.
All returned tokens are strings
"""
set_attr = {}
for attr in self.SPECIAL_TOKENS_ATTRIBUTES:
attr_value = getattr(self, "_" + attr)
if attr_value:
set_attr[attr] = str(attr_value)
return set_attr
@property
def special_tokens_map_extended(self):
""" A dictionary mapping special token class attribute (cls_token, unk_token...) to their
values ('<unk>', '<cls>'...)
Keep the tokens as AddedToken if they are of this type.
AddedToken can be used to control more finely how special tokens are tokenized.
"""
set_attr = {}
for attr in self.SPECIAL_TOKENS_ATTRIBUTES:
@ -784,10 +842,22 @@ class SpecialTokensMixin:
@property
def all_special_tokens(self):
""" List all the special tokens ('<unk>', '<cls>'...) mapped to class attributes
Convert tokens of AddedToken type in string.
All returned tokens are strings
(cls_token, unk_token...).
"""
all_toks = [str(s) for s in self.all_special_tokens_extended]
return all_toks
@property
def all_special_tokens_extended(self):
""" List all the special tokens ('<unk>', '<cls>'...) mapped to class attributes
Keep the tokens as AddedToken if they are of this type.
AddedToken can be used to control more finely how special tokens are tokenized.
"""
all_toks = []
set_attr = self.special_tokens_map
set_attr = self.special_tokens_map_extended
for attr_value in set_attr.values():
all_toks = all_toks + (list(attr_value) if isinstance(attr_value, (list, tuple)) else [attr_value])
all_toks = list(set(all_toks))
@ -1153,6 +1223,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
with open(special_tokens_map_file, encoding="utf-8") as special_tokens_map_handle:
special_tokens_map = json.load(special_tokens_map_handle)
for key, value in special_tokens_map.items():
if isinstance(value, dict):
value = AddedToken(**value)
if key not in init_kwargs:
init_kwargs[key] = value
@ -1171,7 +1243,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
# update unique_added_tokens_encoder with special tokens for correct tokenization
if hasattr(tokenizer, "unique_added_tokens_encoder"):
tokenizer.unique_added_tokens_encoder.update(set(tokenizer.all_special_tokens))
union = set(tokenizer.unique_added_tokens_encoder).union(tokenizer.all_special_tokens)
tokenizer.unique_added_tokens_encoder = list(union)
# Add supplementary tokens.
if added_tokens_file is not None:
@ -1180,7 +1253,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
tokenizer.added_tokens_encoder.update(added_tok_encoder)
tokenizer.added_tokens_decoder.update(added_tok_decoder)
tokenizer.unique_added_tokens_encoder.update(set(tokenizer.added_tokens_encoder.keys()))
union = set(tokenizer.unique_added_tokens_encoder).union(tokenizer.added_tokens_encoder.keys())
tokenizer.unique_added_tokens_encoder = list(union)
return tokenizer
@ -1214,7 +1288,13 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
f.write(json.dumps(tokenizer_config, ensure_ascii=False))
with open(special_tokens_map_file, "w", encoding="utf-8") as f:
f.write(json.dumps(self.special_tokens_map, ensure_ascii=False))
write_dict = {}
for key, value in self.special_tokens_map_extended.items():
if isinstance(value, AddedToken):
write_dict[key] = value.__getstate__()
else:
write_dict[key] = value
f.write(json.dumps(write_dict, ensure_ascii=False))
if hasattr(self, "added_tokens_encoder") and len(self.added_tokens_encoder) > 0:
with open(added_tokens_file, "w", encoding="utf-8") as f:
@ -1395,7 +1475,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
return_lengths: bool = False,
return_length: bool = False,
verbose: bool = True,
**kwargs
) -> BatchEncoding:
@ -1432,11 +1512,11 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
is_pretokenized=is_pretokenized,
return_tensors=return_tensors,
return_token_type_ids=return_token_type_ids,
return_attention_masks=return_attention_mask,
return_attention_mask=return_attention_mask,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_masks=return_special_tokens_mask,
return_special_tokens_mask=return_special_tokens_mask,
return_offsets_mapping=return_offsets_mapping,
return_lengths=return_lengths,
return_length=return_length,
verbose=verbose,
**kwargs,
)
@ -1456,6 +1536,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
return_offsets_mapping=return_offsets_mapping,
return_length=return_length,
verbose=verbose,
**kwargs,
)
@ -1477,7 +1558,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
return_lengths: bool = False,
return_length: bool = False,
verbose: bool = True,
**kwargs
) -> BatchEncoding:
@ -1516,7 +1597,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
return_offsets_mapping=return_offsets_mapping,
return_lengths=return_lengths,
return_length=return_length,
verbose=verbose,
**kwargs,
)
@ -1537,6 +1618,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
return_length: bool = False,
verbose: bool = True,
**kwargs
) -> BatchEncoding:
@ -1561,11 +1643,11 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
is_pretokenized: bool = False,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_masks: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
return_special_tokens_masks: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
return_lengths: bool = False,
return_length: bool = False,
verbose: bool = True,
**kwargs
) -> BatchEncoding:
@ -1598,11 +1680,11 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
is_pretokenized=is_pretokenized,
return_tensors=return_tensors,
return_token_type_ids=return_token_type_ids,
return_attention_masks=return_attention_masks,
return_attention_mask=return_attention_mask,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_masks=return_special_tokens_masks,
return_special_tokens_mask=return_special_tokens_mask,
return_offsets_mapping=return_offsets_mapping,
return_lengths=return_lengths,
return_length=return_length,
verbose=verbose,
**kwargs,
)
@ -1625,11 +1707,11 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
is_pretokenized: bool = False,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_masks: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
return_special_tokens_masks: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
return_lengths: bool = False,
return_length: bool = False,
verbose: bool = True,
**kwargs
) -> BatchEncoding:
@ -1637,63 +1719,84 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
def pad(
self,
encoding_or_batch: Dict[str, Union[List[EncodedInput], EncodedInput]],
encoded_inputs: Union[
BatchEncoding,
List[BatchEncoding],
Dict[str, EncodedInput],
Dict[str, List[EncodedInput]],
List[Dict[str, EncodedInput]],
],
padding: Union[bool, str] = True,
max_length: Optional[int] = None,
return_attention_mask: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
verbose: bool = True,
) -> dict:
""" Pad encoded inputs (on left/right and up to predefined legnth or max length in the batch)
) -> BatchEncoding:
""" Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length in the batch.
Padding side (left/right) padding token ids are defined at the tokenizer level
(with ``self.padding_side``, ``self.pad_token_id`` and ``self.pad_token_type_id``)
Args:
batch_ids: Dictionary of batch of tokenized inputs (`List[List[int]]`).
max_length: maximum length of the returned list and optionally padding length (see below).
Will truncate by taking into account the special tokens.
encoded_inputs: Dictionary of tokenized inputs (`Dict[str, List[int]]`) or batch of tokenized inputs.
Batch of tokenized inputs can be given as dicts of lists or lists of dicts, both work so you can
use ``tokenizer.pad()`` during pre-processing as well as in a PyTorch Dataloader collate function.
(`Dict[str, List[List[int]]]` or `List[Dict[str, List[int]]]`).
padding: Boolean or specific strategy to use for padding.
Select a strategy to pad the returned sequences (according to the model's padding side and padding index) among:
- 'longest' (or `True`) Pad to the longest sequence in the batch
- 'max_length': Pad to the max length (default)
- 'do_not_pad' (or `False`): Do not pad
The tokenizer padding sides are defined in self.padding_side:
- 'left': pads on the left of the sequences
- 'right': pads on the right of the sequences
max_length: maximum length of the returned list and optionally padding length (see below).
Will truncate by taking into account the special tokens.
return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics)
return_tensors (:obj:`str`, `optional`, defaults to :obj:`None`):
Can be set to 'tf', 'pt' or 'np' to return respectively TensorFlow :obj:`tf.constant`,
PyTorch :obj:`torch.Tensor` or Numpy :oj: `np.ndarray` instead of a list of python integers.
verbose (:obj:`bool`, `optional`, defaults to :obj:`True`):
Set to ``False`` to avoid printing infos and warnings.
"""
assert "input_ids" in encoding_or_batch, (
"You should supply an encoding to this method (a dict of lists/batch of int). "
"This is the output of encode/encode_plus/batch_encode_plus/__call__. "
# If we have a list of dicts, let's convert it in a dict of lists
if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)):
encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}
assert "input_ids" in encoded_inputs, (
"You should supply an encoding or a list of encodings to this method. "
"An encoding is the output of one the encoding methods of the tokenizer, i.e. "
"__call__/encode_plus/batch_encode_plus. "
)
if not encoding_or_batch["input_ids"]:
if not encoded_inputs["input_ids"]:
if return_attention_mask:
encoding_or_batch["attention_mask"] = []
return encoding_or_batch
encoded_inputs["attention_mask"] = []
return encoded_inputs
# Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies(
padding=padding, max_length=max_length, verbose=verbose
)
if encoding_or_batch["input_ids"] and not isinstance(encoding_or_batch["input_ids"][0], (list, tuple)):
return self._pad(
encoding_or_batch,
if encoded_inputs["input_ids"] and not isinstance(encoded_inputs["input_ids"][0], (list, tuple)):
encoded_inputs = self._pad(
encoded_inputs,
max_length=max_length,
padding_strategy=padding_strategy,
return_attention_mask=return_attention_mask,
)
return BatchEncoding(encoded_inputs, tensor_type=return_tensors)
batch_size = len(encoding_or_batch["input_ids"])
batch_size = len(encoded_inputs["input_ids"])
assert all(
len(v) == batch_size for v in encoding_or_batch.values()
len(v) == batch_size for v in encoded_inputs.values()
), "Some items in the output dictionnary have a different batch size than others."
if padding_strategy == PaddingStrategy.LONGEST:
max_length = max(len(inputs) for inputs in encoding_or_batch["input_ids"])
max_length = max(len(inputs) for inputs in encoded_inputs["input_ids"])
padding_strategy = PaddingStrategy.MAX_LENGTH
batch_outputs = {}
for i in range(batch_size):
inputs = dict((k, v[i]) for k, v in encoding_or_batch.items())
inputs = dict((k, v[i]) for k, v in encoded_inputs.items())
outputs = self._pad(
inputs,
max_length=max_length,
@ -1706,11 +1809,11 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
batch_outputs[key] = []
batch_outputs[key].append(value)
return batch_outputs
return BatchEncoding(batch_outputs, tensor_type=return_tensors)
def _pad(
self,
encoded_inputs: Dict[str, EncodedInput],
encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
max_length: Optional[int] = None,
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
return_attention_mask: Optional[bool] = None,

View File

@ -21,12 +21,12 @@ import os
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple, Union
from tokenizers import AddedToken as AddedTokenFast
from tokenizers import Encoding as EncodingFast
from tokenizers.decoders import Decoder as DecoderFast
from tokenizers.implementations import BaseTokenizer as BaseTokenizerFast
from .tokenization_utils_base import (
AddedToken,
BatchEncoding,
PaddingStrategy,
PreTokenizedInput,
@ -134,11 +134,6 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
def decoder(self) -> DecoderFast:
return self._tokenizer._tokenizer.decoder
def _maybe_update_backend(self, value):
""" Update the backend fast tokenizer.
Override method from base class SpecialTokensMixin """
self._tokenizer.add_special_tokens(value)
def _convert_encoding(
self,
encoding: EncodingFast,
@ -147,6 +142,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
return_length: bool = False,
verbose: bool = True,
) -> Dict[str, Any]:
""" Convert the encoding representation (from low-level HuggingFace tokenizer output) to a python Dict.
@ -178,6 +174,8 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
encoding_dict["special_tokens_mask"].append(e.special_tokens_mask)
if return_offsets_mapping:
encoding_dict["offset_mapping"].append(e.offsets)
if return_length:
encoding_dict["length"].append(len(e.ids))
return encoding_dict
@ -208,14 +206,14 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
def convert_tokens_to_string(self, tokens: List[int], skip_special_tokens: bool = False) -> str:
return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
def add_tokens(self, new_tokens: List[Union[str, AddedTokenFast]]) -> int:
def add_tokens(self, new_tokens: List[Union[str, AddedToken]], special_token=False) -> int:
"""
Add a list of new tokens to the tokenizer class. If the new tokens are not in the
vocabulary, they are added to it with indices starting from length of the current vocabulary.
Args:
new_tokens: string or list of string or :class:`~transformers.AddedTokenFast`. Each string is a token to add.
Tokens are only added if they are not already in the vocabulary. AddedTokenFast wrap a string token to
new_tokens: string or list of string or :class:`~transformers.AddedToken`. Each string is a token to add.
Tokens are only added if they are not already in the vocabulary. AddedToken wrap a string token to
let you personnalize it's behavior (Whether this token should only match against single word, whether
this token should strip all potential whitespaces on the left side, Whether this token should strip
all potential whitespaces on the right side...).
@ -235,16 +233,12 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
print('We have added', num_added_toks, 'tokens')
model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
"""
if isinstance(new_tokens, str):
if not isinstance(new_tokens, (list, tuple)):
new_tokens = [new_tokens]
# TODO This should be done in tokenizers to be really clean.
# Removing for now
# tokens = []
# for token in new_tokens:
# if self.init_kwargs.get("do_lower_case", False) and token not in self.all_special_tokens:
# token = token.lower()
# if token not in tokens:
# tokens.append(token)
if special_token:
return self._tokenizer.add_special_tokens(new_tokens)
return self._tokenizer.add_tokens(new_tokens)
def num_special_tokens_to_add(self, pair: bool = False) -> int:
@ -330,7 +324,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
return_lengths: bool = False,
return_length: bool = False,
verbose: bool = True,
**kwargs
) -> BatchEncoding:
@ -340,6 +334,9 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
"batch_text_or_text_pairs has to be a list (got {})".format(type(batch_text_or_text_pairs))
)
if kwargs:
raise ValueError(f"Keyword arguments {kwargs} not recognized.")
# Set the truncation and padding strategy and restore the initial configuration
self.set_truncation_and_padding(
padding_strategy=padding_strategy,
@ -381,6 +378,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
return_offsets_mapping=return_offsets_mapping,
return_length=return_length,
verbose=verbose,
)
for encoding in encodings
@ -419,6 +417,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
return_length: bool = False,
verbose: bool = True,
**kwargs
) -> BatchEncoding:
@ -438,6 +437,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
return_offsets_mapping=return_offsets_mapping,
return_length=return_length,
verbose=verbose,
**kwargs,
)

View File

@ -390,7 +390,7 @@ class TokenizerTesterMixin:
seq_1 = "With these inputs."
sequences = tokenizer.encode(seq_0, seq_1, add_special_tokens=False)
attached_sequences = tokenizer.encode(seq_0, seq_1, add_special_tokens=True, add_prefix_space=False)
attached_sequences = tokenizer.encode(seq_0, seq_1, add_special_tokens=True)
# Method is implemented (e.g. not GPT-2)
if len(attached_sequences) != 2:
@ -416,7 +416,7 @@ class TokenizerTesterMixin:
stride=stride,
truncation="longest_first",
return_overflowing_tokens=True,
add_prefix_space=False,
# add_prefix_space=False,
)
# Overflowing tokens are handled quite differently in slow and fast tokenizers
@ -468,7 +468,7 @@ class TokenizerTesterMixin:
# We are not using the special tokens - a bit too hard to test all the tokenizers with this
# TODO try this again later
sequence = tokenizer.encode(seq_0, seq_1, add_special_tokens=False, add_prefix_space=False)
sequence = tokenizer.encode(seq_0, seq_1, add_special_tokens=False) # , add_prefix_space=False)
truncated_first_sequence = tokenizer.encode(seq_0, add_special_tokens=False)[:-2] + tokenizer.encode(
seq_1, add_special_tokens=False
)
@ -499,7 +499,7 @@ class TokenizerTesterMixin:
stride=stride,
truncation="longest_first",
return_overflowing_tokens=True,
add_prefix_space=False,
# add_prefix_space=False,
)
# Overflowing tokens are handled quite differently in slow and fast tokenizers
if isinstance(tokenizer, PreTrainedTokenizerFast):
@ -531,7 +531,7 @@ class TokenizerTesterMixin:
stride=stride,
truncation=True,
return_overflowing_tokens=True,
add_prefix_space=False,
# add_prefix_space=False,
)
# Overflowing tokens are handled quite differently in slow and fast tokenizers
if isinstance(tokenizer, PreTrainedTokenizerFast):
@ -562,7 +562,7 @@ class TokenizerTesterMixin:
stride=stride,
truncation="only_second",
return_overflowing_tokens=True,
add_prefix_space=False,
# add_prefix_space=False,
)
# Overflowing tokens are handled quite differently in slow and fast tokenizers
if isinstance(tokenizer, PreTrainedTokenizerFast):
@ -638,7 +638,7 @@ class TokenizerTesterMixin:
# Testing single inputs
encoded_sequence = tokenizer.encode(sequence_0, add_special_tokens=False)
encoded_sequence_dict = tokenizer.encode_plus(
sequence_0, add_special_tokens=True, return_special_tokens_mask=True, add_prefix_space=False
sequence_0, add_special_tokens=True, return_special_tokens_mask=True # , add_prefix_space=False
)
encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]
@ -660,7 +660,7 @@ class TokenizerTesterMixin:
sequence_1,
add_special_tokens=True,
return_special_tokens_mask=True,
add_prefix_space=False,
# add_prefix_space=False,
)
encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]
@ -1042,7 +1042,7 @@ class TokenizerTesterMixin:
def test_pretokenized_inputs(self):
# Test when inputs are pretokenized
tokenizers = self.get_tokenizers(do_lower_case=False, add_prefix_space=True)
tokenizers = self.get_tokenizers(do_lower_case=False) # , add_prefix_space=True)
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):

View File

@ -63,6 +63,21 @@ class CommonFastTokenizerTest(unittest.TestCase):
self.fast_align_python(tokenizer_r, tokenizer_p, tok_case, pretrained_name)
self.fast_only(tokenizer_r)
def test_pretokenized_tokenizers(self):
for tok_case in self.TOKENIZERS_CLASSES:
for pretrained_name in tok_case.python_cls.pretrained_vocab_files_map[tok_case.vocab_key].keys():
# Tokenizer.filter makes it possible to filter which Tokenizer to case based on all the
# information available in Tokenizer (name, rust class, python class, vocab key name)
if tok_case.filter is None or (
tok_case.filter is not None and tok_case.filter(tok_case, pretrained_name)
):
with self.subTest("{} ({})".format(tok_case.name, pretrained_name)):
tokenizer_r = tok_case.rust_cls.from_pretrained(pretrained_name, add_prefix_space=True)
tokenizer_p = tok_case.python_cls.from_pretrained(pretrained_name, add_prefix_space=True)
self.assert_pretokenized_inputs(tokenizer_r, tokenizer_p)
def fast_align_python(self, tokenizer_r, tokenizer_p, tok_case, pretrained_name):
# Check is_fast is set correctly
self.assertFalse(tokenizer_p.is_fast)
@ -75,7 +90,6 @@ class CommonFastTokenizerTest(unittest.TestCase):
self.assert_special_tokens_map_equal(tokenizer_r, tokenizer_p)
self.assert_embeded_special_tokens(tokenizer_r, tokenizer_p)
self.assert_padding(tokenizer_r, tokenizer_p)
self.assert_pretokenized_inputs(tokenizer_r, tokenizer_p)
self.assert_create_token_type_ids(tokenizer_r, tokenizer_p)
# TODO: enable for v3.0.0
# self.assert_empty_output_no_special_tokens(tokenizer_r, tokenizer_p)
@ -341,6 +355,14 @@ class CommonFastTokenizerTest(unittest.TestCase):
"return_special_tokens_mask": True,
"return_offsets_mapping": False, # Not implemented in python tokenizers
}
batch_kwargs = {
"is_pretokenized": True,
"return_token_type_ids": True,
"return_attention_mask": True, # we have an 's' here
"return_overflowing_tokens": False,
"return_special_tokens_mask": True, # we have an 's' here
"return_offsets_mapping": False, # Not implemented in python tokenizers
}
# Test encode_plus for pretokenized inputs
output_r = tokenizer_r.encode_plus(pretokenized_input_simple, **kwargs)
output_p = tokenizer_p.encode_plus(pretokenized_input_simple, **kwargs)
@ -349,8 +371,8 @@ class CommonFastTokenizerTest(unittest.TestCase):
# Test batch_encode_plus for pretokenized inputs
input_batch = ([pretokenized_input_simple] * 2) + [pretokenized_input_simple + pretokenized_input_pair]
output_r = tokenizer_r.batch_encode_plus(input_batch, **kwargs)
output_p = tokenizer_p.batch_encode_plus(input_batch, **kwargs)
output_r = tokenizer_r.batch_encode_plus(input_batch, **batch_kwargs)
output_p = tokenizer_p.batch_encode_plus(input_batch, **batch_kwargs)
for key in output_p.keys():
self.assertEqual(output_p[key], output_r[key])
@ -370,8 +392,8 @@ class CommonFastTokenizerTest(unittest.TestCase):
pretokenized_input_simple + pretokenized_input_pair,
pretokenized_input_pair,
]
output_r = tokenizer_r.batch_encode_plus(input_batch_pair, **kwargs)
output_p = tokenizer_p.batch_encode_plus(input_batch_pair, **kwargs)
output_r = tokenizer_r.batch_encode_plus(input_batch_pair, **batch_kwargs)
output_p = tokenizer_p.batch_encode_plus(input_batch_pair, **batch_kwargs)
for key in output_p.keys():
self.assertEqual(output_p[key], output_r[key])
@ -756,8 +778,8 @@ class RobertaFastTokenizerTest(CommonFastTokenizerTest):
tokens_p = tokenizer_p.encode_plus(sentence, add_special_tokens=True, return_token_type_ids=True)
# Rust correctly handles the space before the mask while python doesnt
self.assertSequenceEqual(tokens_r["input_ids"], [0, 83, 6, 50264, 3823, 487, 21992, 3645, 4, 2])
self.assertSequenceEqual(tokens_p["input_ids"], [0, 83, 6, 50264, 3823, 487, 21992, 3645, 4, 2])
self.assertSequenceEqual(tokens_r["input_ids"], [0, 250, 6, 50264, 3823, 487, 21992, 3645, 4, 2])
self.assertSequenceEqual(tokens_p["input_ids"], [0, 250, 6, 50264, 3823, 487, 21992, 3645, 4, 2])
# token_type_ids should put 0 everywhere
self.assertEquals(sum(tokens_r["token_type_ids"]), sum(tokens_p["token_type_ids"]))
@ -768,9 +790,10 @@ class RobertaFastTokenizerTest(CommonFastTokenizerTest):
sum(tokens_p["attention_mask"]) / len(tokens_p["attention_mask"]),
)
# Rust should have 'Ġ' before <mask> which should be left as an entire token
tokens_r = tokenizer_r.convert_ids_to_tokens(tokens_r["input_ids"])
self.assertSequenceEqual(tokens_r, ["<s>", "ĠA", ",", "<mask>", "ĠAllen", "N", "LP", "Ġsentence", ".", "</s>"])
tokens_p = tokenizer_p.convert_ids_to_tokens(tokens_p["input_ids"])
self.assertSequenceEqual(tokens_r, ["<s>", "A", ",", "<mask>", "ĠAllen", "N", "LP", "Ġsentence", ".", "</s>"])
self.assertSequenceEqual(tokens_p, ["<s>", "A", ",", "<mask>", "ĠAllen", "N", "LP", "Ġsentence", ".", "</s>"])
class NoPaddingTokenFastTokenizerMatchingTest(CommonFastTokenizerTest):
@ -840,3 +863,7 @@ class TransfoXLFastTokenizerTest(NoPaddingTokenFastTokenizerMatchingTest):
@require_torch
def test_all_tokenizers(self):
super().test_all_tokenizers()
@require_torch
def test_pretokenized_tokenizers(self):
super().test_pretokenized_tokenizers()

View File

@ -80,12 +80,12 @@ class RobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
def test_full_tokenizer(self):
tokenizer = RobertaTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
text = "lower newer"
bpe_tokens = ["\u0120low", "er", "\u0120", "n", "e", "w", "er"]
tokens = tokenizer.tokenize(text, add_prefix_space=True)
bpe_tokens = ["l", "o", "w", "er", "\u0120", "n", "e", "w", "er"]
tokens = tokenizer.tokenize(text) # , add_prefix_space=True)
self.assertListEqual(tokens, bpe_tokens)
input_tokens = tokens + [tokenizer.unk_token]
input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19]
input_bpe_tokens = [0, 1, 2, 15, 10, 9, 3, 2, 15, 19]
self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
def roberta_dict_integration_testing(self):
@ -124,7 +124,7 @@ class RobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
space_encoding = tokenizer.byte_encoder[" ".encode("utf-8")[0]]
# Testing encoder arguments
encoded = tokenizer.encode(sequence, add_special_tokens=False)
encoded = tokenizer.encode(sequence, add_special_tokens=False, add_prefix_space=False)
first_char = tokenizer.convert_ids_to_tokens(encoded[0])[0]
self.assertNotEqual(first_char, space_encoding)
@ -135,7 +135,7 @@ class RobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer.add_special_tokens({"bos_token": "<s>"})
encoded = tokenizer.encode(sequence, add_special_tokens=True)
first_char = tokenizer.convert_ids_to_tokens(encoded[1])[0]
self.assertEqual(first_char, space_encoding)
self.assertNotEqual(first_char, space_encoding)
# Testing spaces after special tokenss
mask = "<mask>"