Cleanup fast tokenizers integration (#3706)

* First pass on utility classes and python tokenizers

* finishing cleanup pass

* style and quality

* Fix tests

* Updating following @mfuntowicz comment

* style and quality

* Fix Roberta

* fix batch_size/seq_length in BatchEncoding

* add alignment methods + tests

* Fix OpenAI and Transfo-XL tokenizers

* adding trim_offsets=True default for GPT2 and RoBERTa

* style and quality

* fix tests

* add_prefix_space in roberta

* bump up tokenizers to rc7

* style

* unfortunately tensorflow does not like these - removing shape/seq_len for now

* Update src/transformers/tokenization_utils.py

Co-Authored-By: Stefan Schweter <stefan@schweter.it>

* Adding doc and docstrings

* making flake8 happy

Co-authored-by: Stefan Schweter <stefan@schweter.it>
Thomas Wolf 2020-04-18 13:43:57 +02:00 committed by GitHub
parent 60a42ef1c0
commit 827d6d6ef0
28 changed files with 1031 additions and 503 deletions

View File

@ -1,16 +1,38 @@
Tokenizer
----------------------------------------------------
The base class ``PreTrainedTokenizer`` implements the common methods for loading/saving a tokenizer either from a local file or directory, or from a pretrained tokenizer provided by the library (downloaded from HuggingFace's AWS S3 repository).
A tokenizer is in charge of preparing the inputs for a model. The library comprises tokenizers for all the models. Most of the tokenizers are available in two flavors: a full python implementation and a "Fast" implementation based on the Rust library `tokenizers`. The "Fast" implementations allow (1) a significant speed-up, in particular when doing batched tokenization, and (2) additional methods to map between the original string (characters and words) and the token space (e.g. getting the index of the token comprising a given character or the span of characters corresponding to a given token). Currently no "Fast" implementation is available for the SentencePiece-based tokenizers (for T5, ALBERT, CamemBERT, XLMRoBERTa and XLNet models).
``PreTrainedTokenizer`` is the main entry point into tokenizers as it also implements the main methods for using all the tokenizers:
The base classes ``PreTrainedTokenizer`` and ``PreTrainedTokenizerFast`` implement the common methods for encoding string inputs into model inputs (see below) and for instantiating/saving python and "Fast" tokenizers either from a local file or directory, or from a pretrained tokenizer provided by the library (downloaded from HuggingFace's AWS S3 repository).
- tokenizing, converting tokens to ids and back and encoding/decoding,
``PreTrainedTokenizer`` and ``PreTrainedTokenizerFast`` thus implement the main methods for using all the tokenizers:
- tokenizing (splitting strings into sub-word token strings), converting token strings to ids and back, and encoding/decoding (i.e. tokenizing and converting to integers),
- adding new tokens to the vocabulary in a way that is independent of the underlying structure (BPE, SentencePiece...),
- managing special tokens (adding them, assigning them to roles, making sure they are not split during tokenization)
- managing special tokens such as mask, beginning-of-sentence, etc. (adding them, assigning them to attributes in the tokenizer for easy access, and making sure they are not split during tokenization)
``BatchEncoding`` holds the output of the tokenizer's encoding methods (``encode_plus`` and ``batch_encode_plus``) and is derived from a Python dictionary. When the tokenizer is a pure python tokenizer, this class behaves just like a standard python dictionary and holds the various model inputs computed by these methods (``input_ids``, ``attention_mask``...). When the tokenizer is a "Fast" tokenizer (i.e. backed by the HuggingFace tokenizers library), this class additionally provides several advanced alignment methods which can be used to map between the original string (characters and words) and the token space (e.g. getting the index of the token comprising a given character or the span of characters corresponding to a given token).
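For illustration, here is a minimal sketch of these alignment methods, assuming the ``bert-base-uncased`` checkpoint is available (any "Fast" tokenizer behaves the same way)::

    from transformers import BertTokenizerFast

    tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
    encoding = tokenizer.encode_plus("Hello world!", add_special_tokens=False)

    # index of the token covering the character at position 6 (the "w" of "world")
    token_index = encoding.char_to_token(6)
    # character span of that token in the original string (a CharSpan with .start and .end)
    print(encoding.token_to_chars(token_index))
    # word index associated with each token
    print(encoding.words(0))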
``PreTrainedTokenizer``
~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.PreTrainedTokenizer
:members:
``PreTrainedTokenizerFast``
~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.PreTrainedTokenizerFast
:members:
``BatchEncoding``
~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.BatchEncoding
:members:
``SpecialTokensMixin``
~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.SpecialTokensMixin
:members:
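As a quick illustration of the special tokens handling mentioned above, a minimal sketch assuming the ``bert-base-uncased`` checkpoint (the attribute access comes from ``SpecialTokensMixin``; the ``<ent>`` token is a made-up example)::

    from transformers import BertTokenizerFast

    tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

    # special tokens are exposed as attributes for easy access
    print(tokenizer.mask_token, tokenizer.mask_token_id)
    print(tokenizer.all_special_tokens)

    # register an additional special token so it is never split during tokenization
    tokenizer.add_special_tokens({"additional_special_tokens": ["<ent>"]})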

View File

@ -52,6 +52,13 @@ BertTokenizer
create_token_type_ids_from_sequences, save_vocabulary
BertTokenizerFast
~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.BertTokenizerFast
:members:
BertModel
~~~~~~~~~~~~~~~~~~~~

View File

@ -44,6 +44,13 @@ DistilBertTokenizer
:members:
DistilBertTokenizerFast
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.DistilBertTokenizerFast
:members:
DistilBertModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -61,6 +61,13 @@ ElectraTokenizer
:members:
ElectraTokenizerFast
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.ElectraTokenizerFast
:members:
ElectraModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -53,6 +53,13 @@ OpenAIGPTTokenizer
:members: save_vocabulary
OpenAIGPTTokenizerFast
~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.OpenAIGPTTokenizerFast
:members:
OpenAIGPTModel
~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -51,6 +51,13 @@ GPT2Tokenizer
:members: save_vocabulary
GPT2TokenizerFast
~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.GPT2TokenizerFast
:members:
GPT2Model
~~~~~~~~~~~~~~~~~~~~~

View File

@ -46,6 +46,13 @@ RobertaTokenizer
create_token_type_ids_from_sequences, save_vocabulary
RobertaTokenizerFast
~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.RobertaTokenizerFast
:members: build_inputs_with_special_tokens
RobertaModel
~~~~~~~~~~~~~~~~~~~~

View File

@ -47,6 +47,13 @@ TransfoXLTokenizer
:members: save_vocabulary
TransfoXLTokenizerFast
~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TransfoXLTokenizerFast
:members:
TransfoXLModel
~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -67,7 +67,7 @@ class TextDataset(Dataset):
def __init__(self, tokenizer: PreTrainedTokenizer, args, file_path: str, block_size=512):
assert os.path.isfile(file_path)
block_size = block_size - (tokenizer.max_len - tokenizer.max_len_single_sentence)
block_size = block_size - tokenizer.num_special_tokens_to_add(pair=False)
directory, filename = os.path.split(file_path)
cached_features_file = os.path.join(
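For context, a rough sketch of what the new helper computes, assuming a BERT-style tokenizer that wraps a single sequence with ``[CLS]`` and ``[SEP]``::

    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    # the old formula, max_len - max_len_single_sentence, counted the special tokens added
    # around a single sequence; the new helper returns that count directly (2 for BERT)
    num_special = tokenizer.num_special_tokens_to_add(pair=False)
    block_size = 512 - num_special  # room left for real tokens in each block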

View File

@ -96,7 +96,7 @@ setup(
packages=find_packages("src"),
install_requires=[
"numpy",
"tokenizers == 0.7.0rc5",
"tokenizers == 0.7.0rc7",
# dataclasses for Python versions that don't have it
"dataclasses;python_version<'3.7'",
# accessing files from S3 directly

View File

@ -137,9 +137,6 @@ class AlbertTokenizer(PreTrainedTokenizer):
**kwargs,
)
self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens
try:
import sentencepiece as spm
except ImportError:

View File

@ -182,8 +182,6 @@ class BertTokenizer(PreTrainedTokenizer):
mask_token=mask_token,
**kwargs,
)
self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens
if not os.path.isfile(vocab_file):
raise ValueError(
@ -583,6 +581,48 @@ def _is_punctuation(char):
class BertTokenizerFast(PreTrainedTokenizerFast):
r"""
Constructs a "Fast" BERT tokenizer (backed by HuggingFace's `tokenizers` library).
Bert tokenization is based on WordPiece.
This tokenizer inherits from :class:`~transformers.PreTrainedTokenizerFast` which contains most of the methods. Users
should refer to the superclass for more information regarding methods.
Args:
vocab_file (:obj:`string`):
File containing the vocabulary.
do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether to lowercase the input when tokenizing.
unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
for sequence classification or for a text and a question for question answering.
It is also used as the last token of a sequence built with special tokens.
pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
The token used for padding, for example when batching sequences of different lengths.
cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
The classifier token which is used when doing sequence classification (classification of the whole
sequence instead of per-token classification). It is the first token of the sequence when built with
special tokens.
mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
The token used for masking values. This is the token used when training this model with masked language
modeling. This is the token which the model will try to predict.
tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether to tokenize Chinese characters.
This should likely be deactivated for Japanese:
see: https://github.com/huggingface/transformers/issues/328
clean_text (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether to clean the text before tokenization by removing any control characters and
replacing all whitespace characters with the classic space character.
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
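A short usage sketch of the fast tokenizer documented above, showing the character offsets it can return, assuming the ``bert-base-uncased`` checkpoint::

    from transformers import BertTokenizerFast

    tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
    encoding = tokenizer.encode_plus("Hello world!", return_offsets_mapping=True)

    # each entry of offset_mapping is a (start, end) character span into the original string
    for token, span in zip(encoding.tokens(0), encoding["offset_mapping"]):
        print(token, span)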

View File

@ -119,8 +119,6 @@ class BertJapaneseTokenizer(BertTokenizer):
**kwargs,
)
# ^^ We call the grandparent's init, not the parent's.
self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens
if not os.path.isfile(vocab_file):
raise ValueError(

View File

@ -129,8 +129,6 @@ class CamembertTokenizer(PreTrainedTokenizer):
additional_special_tokens=additional_special_tokens,
**kwargs,
)
self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
self.max_len_sentences_pair = self.max_len - 4 # take into account special tokens
self.sp_model = spm.SentencePieceProcessor()
self.sp_model.Load(str(vocab_file))
self.vocab_file = vocab_file

View File

@ -140,12 +140,6 @@ class CTRLTokenizer(PreTrainedTokenizer):
def __init__(self, vocab_file, merges_file, unk_token="<unk>", **kwargs):
super().__init__(unk_token=unk_token, **kwargs)
self.max_len_single_sentence = (
self.max_len
) # no default special tokens - you can update this value if you add special tokens
self.max_len_sentences_pair = (
self.max_len
) # no default special tokens - you can update this value if you add special tokens
with open(vocab_file, encoding="utf-8") as vocab_handle:
self.encoder = json.load(vocab_handle)

View File

@ -57,8 +57,9 @@ PRETRAINED_INIT_CONFIGURATION = {
class DistilBertTokenizer(BertTokenizer):
r"""
Constructs a DistilBertTokenizer.
:class:`~transformers.DistilBertTokenizer` is identical to :class:`~transformers.BertTokenizer` and runs end-to-end
Constructs a DistilBertTokenizer.
:class:`~transformers.DistilBertTokenizer` is identical to :class:`~transformers.BertTokenizer` and runs end-to-end
tokenization: punctuation splitting + wordpiece.
Refer to superclass :class:`~transformers.BertTokenizer` for usage examples and documentation concerning
@ -73,6 +74,16 @@ class DistilBertTokenizer(BertTokenizer):
class DistilBertTokenizerFast(BertTokenizerFast):
r"""
Constructs a "Fast" DistilBertTokenizer (backed by HuggingFace's `tokenizers` library).
:class:`~transformers.DistilBertTokenizerFast` is identical to :class:`~transformers.BertTokenizerFast` and runs end-to-end
tokenization: punctuation splitting + wordpiece.
Refer to superclass :class:`~transformers.BertTokenizerFast` for usage examples and documentation concerning
parameters.
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

View File

@ -67,7 +67,8 @@ class ElectraTokenizer(BertTokenizer):
class ElectraTokenizerFast(BertTokenizerFast):
r"""
Constructs an Electra Fast tokenizer.
Constructs a "Fast" Electra Fast tokenizer (backed by HuggingFace's `tokenizers` library).
:class:`~transformers.ElectraTokenizerFast` is identical to :class:`~transformers.BertTokenizerFast` and runs end-to-end
tokenization: punctuation splitting + wordpiece.

View File

@ -147,12 +147,6 @@ class GPT2Tokenizer(PreTrainedTokenizer):
**kwargs
):
super().__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs)
self.max_len_single_sentence = (
self.max_len
) # no default special tokens - you can update this value if you add special tokens
self.max_len_sentences_pair = (
self.max_len
) # no default special tokens - you can update this value if you add special tokens
with open(vocab_file, encoding="utf-8") as vocab_handle:
self.encoder = json.load(vocab_handle)
@ -284,6 +278,47 @@ class GPT2Tokenizer(PreTrainedTokenizer):
class GPT2TokenizerFast(PreTrainedTokenizerFast):
"""
Constructs a "Fast" GPT-2 BPE tokenizer (backed by HuggingFace's `tokenizers` library).
Peculiarities:
- Byte-level Byte-Pair-Encoding
- Requires a space to start the input string => the encoding methods should be called with the
``add_prefix_space`` flag set to ``True``.
Otherwise, this tokenizer's ``encode`` and ``decode`` methods will not preserve
the absence of a space at the beginning of a string:
::
tokenizer.decode(tokenizer.encode("Hello")) = " Hello"
This tokenizer inherits from :class:`~transformers.PreTrainedTokenizerFast` which contains most of the methods. Users
should refer to the superclass for more information regarding methods.
Args:
vocab_file (:obj:`str`):
Path to the vocabulary file.
merges_file (:obj:`str`):
Path to the merges file.
errors (:obj:`str`, `optional`, defaults to "replace"):
Paradigm to follow when decoding bytes to UTF-8. See `bytes.decode
<https://docs.python.org/3/library/stdtypes.html#bytes.decode>`__ for more information.
unk_token (:obj:`string`, `optional`, defaults to `<|endoftext|>`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
bos_token (:obj:`string`, `optional`, defaults to `<|endoftext|>`):
The beginning of sequence token.
eos_token (:obj:`string`, `optional`, defaults to `<|endoftext|>`):
The end of sequence token.
add_prefix_space (:obj:`bool`, `optional`, defaults to `False`):
Whether to add a leading space to the first word.
This allows the leading word to be treated just like any other word.
(The GPT-2 tokenizer detects the beginning of words by the preceding space.)
trim_offsets (:obj:`bool`, `optional`, defaults to `True`):
Whether the post-processing step should trim offsets to avoid including whitespace.
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
@ -296,10 +331,16 @@ class GPT2TokenizerFast(PreTrainedTokenizerFast):
bos_token="<|endoftext|>",
eos_token="<|endoftext|>",
add_prefix_space=False,
trim_offsets=True,
**kwargs
):
super().__init__(
ByteLevelBPETokenizer(vocab_file=vocab_file, merges_file=merges_file, add_prefix_space=add_prefix_space),
ByteLevelBPETokenizer(
vocab_file=vocab_file,
merges_file=merges_file,
add_prefix_space=add_prefix_space,
trim_offsets=trim_offsets,
),
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
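To illustrate the new ``trim_offsets`` default described in the docstring, a hedged sketch assuming the ``gpt2`` checkpoint::

    from transformers import GPT2TokenizerFast

    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")  # trim_offsets=True by default
    encoding = tokenizer.encode_plus("Hello world", return_offsets_mapping=True, add_special_tokens=False)

    # with trim_offsets=True the second token's span should start at the "w" (index 6)
    # rather than at the whitespace preceding it (index 5)
    print(encoding["offset_mapping"])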

View File

@ -19,15 +19,8 @@ import json
import logging
import os
import re
from typing import List, Optional, Union
from tokenizers import Tokenizer
from tokenizers.decoders import BPEDecoder
from tokenizers.implementations import BaseTokenizer
from tokenizers.models import BPE
from tokenizers.normalizers import BertNormalizer, Sequence, unicode_normalizer_from_str
from tokenizers.pre_tokenizers import BertPreTokenizer
from tokenizers.trainers import BpeTrainer
from tokenizers import CharBPETokenizer
from .tokenization_bert import BasicTokenizer
from .tokenization_utils import PreTrainedTokenizer, PreTrainedTokenizerFast
@ -106,13 +99,6 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
def __init__(self, vocab_file, merges_file, unk_token="<unk>", **kwargs):
super().__init__(unk_token=unk_token, **kwargs)
self.max_len_single_sentence = (
self.max_len
) # no default special tokens - you can update this value if you add special tokens
self.max_len_sentences_pair = (
self.max_len
) # no default special tokens - you can update this value if you add special tokens
try:
import ftfy
from spacy.lang.en import English
@ -249,83 +235,28 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
return vocab_file, merge_file
class _OpenAIGPTCharBPETokenizer(BaseTokenizer):
"""
OpenAI character-level BPE Tokenizer
"""
def __init__(
self,
vocab_file: Optional[str] = None,
merges_file: Optional[str] = None,
unk_token: Optional[str] = "<unk>",
suffix: Optional[str] = "</w>",
dropout: Optional[float] = None,
unicode_normalizer: Optional[str] = None,
):
if vocab_file is not None and merges_file is not None:
tokenizer = Tokenizer(
BPE(vocab_file, merges_file, dropout=dropout, unk_token=unk_token, end_of_word_suffix=suffix)
)
else:
tokenizer = Tokenizer(BPE())
# Check for Unicode normalization first (before everything else)
normalizers = []
if unicode_normalizer:
normalizers += [unicode_normalizer_from_str(unicode_normalizer)]
# OpenAI normalization is the same as Bert
normalizers += [BertNormalizer()]
# Create the normalizer structure
if len(normalizers) > 0:
if len(normalizers) > 1:
tokenizer.normalizer = Sequence(normalizers)
else:
tokenizer.normalizer = normalizers[0]
tokenizer.pre_tokenizer = BertPreTokenizer()
tokenizer.decoder = BPEDecoder(suffix=suffix)
parameters = {
"model": "BPE",
"unk_token": unk_token,
"suffix": suffix,
"dropout": dropout,
}
super().__init__(tokenizer, parameters)
def train(
self,
files: Union[str, List[str]],
vocab_size: int = 30000,
min_frequency: int = 2,
special_tokens: List[str] = ["<unk>"],
limit_alphabet: int = 1000,
initial_alphabet: List[str] = [],
suffix: Optional[str] = "</w>",
show_progress: bool = True,
):
""" Train the model using the given files """
trainer = BpeTrainer(
vocab_size=vocab_size,
min_frequency=min_frequency,
special_tokens=special_tokens,
limit_alphabet=limit_alphabet,
initial_alphabet=initial_alphabet,
end_of_word_suffix=suffix,
show_progress=show_progress,
)
if isinstance(files, str):
files = [files]
self._tokenizer.train(trainer, files)
class OpenAIGPTTokenizerFast(PreTrainedTokenizerFast):
"""
Construct a "Fast" BPE tokenizer for OpenAI GPT (backed by HuggingFace's `tokenizers` library).
Peculiarities:
- lower-cases all inputs
- uses the SpaCy tokenizer and ftfy for pre-BPE tokenization if they are installed, falling back to BERT's BasicTokenizer if not.
This tokenizer inherits from :class:`~transformers.PreTrainedTokenizerFast` which contains most of the methods. Users
should refer to the superclass for more information regarding methods.
Args:
vocab_file (:obj:`str`):
Path to the vocabulary file.
merges_file (:obj:`str`):
Path to the merges file.
unk_token (:obj:`string`, `optional`, defaults to "<unk>"):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
@ -333,5 +264,6 @@ class OpenAIGPTTokenizerFast(PreTrainedTokenizerFast):
def __init__(self, vocab_file, merges_file, unk_token="<unk>", **kwargs):
kwargs.setdefault("unk_token", unk_token)
super().__init__(
_OpenAIGPTCharBPETokenizer(vocab_file=vocab_file, merges_file=merges_file, unk_token=unk_token), **kwargs
CharBPETokenizer(vocab_file=vocab_file, merges_file=merges_file, unk_token=unk_token, lowercase=True),
**kwargs,
)
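The custom ``_OpenAIGPTCharBPETokenizer`` wrapper removed above is replaced by the ``CharBPETokenizer`` shipped with the `tokenizers` library; a standalone sketch of the equivalent backend construction (the file paths are placeholders)::

    from tokenizers import CharBPETokenizer

    backend = CharBPETokenizer(
        vocab_file="vocab.json",   # placeholder path
        merges_file="merges.txt",  # placeholder path
        unk_token="<unk>",
        lowercase=True,            # lower-casing is now handled by the backend itself
    )
    print(backend.encode("Hello World").tokens)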

View File

@ -150,8 +150,6 @@ class RobertaTokenizer(GPT2Tokenizer):
mask_token=mask_token,
**kwargs,
)
self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
self.max_len_sentences_pair = self.max_len - 4 # take into account special tokens
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
@ -244,6 +242,47 @@ class RobertaTokenizer(GPT2Tokenizer):
class RobertaTokenizerFast(GPT2TokenizerFast):
"""
Constructs a "Fast" RoBERTa BPE tokenizer (backed by HuggingFace's `tokenizers` library).
Peculiarities:
- Byte-level Byte-Pair-Encoding
- Requires a space to start the input string => the encoding methods should be called with the
``add_prefix_space`` flag set to ``True``.
Otherwise, this tokenizer's ``encode`` and ``decode`` methods will not preserve
the absence of a space at the beginning of a string:
::
tokenizer.decode(tokenizer.encode("Hello")) = " Hello"
This tokenizer inherits from :class:`~transformers.PreTrainedTokenizerFast` which contains most of the methods. Users
should refer to the superclass for more information regarding methods.
Args:
vocab_file (:obj:`str`):
Path to the vocabulary file.
merges_file (:obj:`str`):
Path to the merges file.
errors (:obj:`str`, `optional`, defaults to "replace"):
Paradigm to follow when decoding bytes to UTF-8. See `bytes.decode
<https://docs.python.org/3/library/stdtypes.html#bytes.decode>`__ for more information.
unk_token (:obj:`string`, `optional`, defaults to `<|endoftext|>`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
bos_token (:obj:`string`, `optional`, defaults to `<|endoftext|>`):
The beginning of sequence token.
eos_token (:obj:`string`, `optional`, defaults to `<|endoftext|>`):
The end of sequence token.
add_prefix_space (:obj:`bool`, `optional`, defaults to `True`):
Whether to add a leading space to the first word.
This allows the leading word to be treated just like any other word.
(The GPT-2 tokenizer detects the beginning of words by the preceding space.)
trim_offsets (:obj:`bool`, `optional`, defaults to `True`):
Whether the post-processing step should trim offsets to avoid including whitespace.
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
@ -262,6 +301,7 @@ class RobertaTokenizerFast(GPT2TokenizerFast):
pad_token="<pad>",
mask_token="<mask>",
add_prefix_space=True,
trim_offsets=True,
**kwargs
):
kwargs.setdefault("pad_token", pad_token)
@ -276,23 +316,18 @@ class RobertaTokenizerFast(GPT2TokenizerFast):
bos_token=bos_token,
eos_token=eos_token,
add_prefix_space=add_prefix_space,
trim_offsets=trim_offsets,
**kwargs,
)
self.tokenizer._tokenizer.post_processor = RobertaProcessing(
(sep_token, self.sep_token_id), (cls_token, self.cls_token_id)
self.backend_tokenizer._tokenizer.post_processor = RobertaProcessing(
sep=(sep_token, self.sep_token_id),
cls=(cls_token, self.cls_token_id),
add_prefix_space=add_prefix_space,
trim_offsets=trim_offsets,
)
self.tokenizer.add_special_tokens([kwargs["mask_token"]])
# As we override the post_processor post super.__init__ the computed num_added_tokens is wrong in super().
# We need to recompute max_len according to the newly register post_processor to get real values.
self.max_len_single_sentence = self.max_len - self.num_special_tokens_to_add(
False
) # take into account special tokens
self.max_len_sentences_pair = self.max_len - self.num_special_tokens_to_add(
True
) # take into account special tokens
self.backend_tokenizer.add_special_tokens([kwargs["mask_token"]])
@PreTrainedTokenizer.mask_token.setter
def mask_token(self, value):
@ -300,7 +335,7 @@ class RobertaTokenizerFast(GPT2TokenizerFast):
value = AddedToken(value, lstrip=True)
self._mask_token = str(value)
self.tokenizer.add_special_tokens([value])
self._maybe_update_backend([value])
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
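A hedged usage sketch of the updated defaults (``add_prefix_space=True``, ``trim_offsets=True``) and of the mask token, which is registered with ``lstrip`` so that it absorbs the space preceding it, assuming the ``roberta-base`` checkpoint::

    from transformers import RobertaTokenizerFast

    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
    encoding = tokenizer.encode_plus("Hello <mask>!", return_offsets_mapping=True)

    # tokens include <s>/</s> and the mask token; offsets are trimmed of surrounding whitespace
    print(tokenizer.convert_ids_to_tokens(encoding["input_ids"]))
    print(encoding["offset_mapping"])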

View File

@ -118,12 +118,6 @@ class T5Tokenizer(PreTrainedTokenizer):
additional_special_tokens=additional_special_tokens,
**kwargs,
)
self.max_len_single_sentence = (
self.max_len
) # no default special tokens - you can update this value if you add special tokens
self.max_len_sentences_pair = (
self.max_len
) # no default special tokens - you can update this value if you add special tokens
try:
import sentencepiece as spm

View File

@ -101,13 +101,6 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
unk_token=unk_token, eos_token=eos_token, additional_special_tokens=additional_special_tokens, **kwargs
)
self.max_len_single_sentence = (
self.max_len
) # no default special tokens - you can update this value if you add special tokens
self.max_len_sentences_pair = (
self.max_len
) # no default special tokens - you can update this value if you add special tokens
if never_split is None:
never_split = self.all_special_tokens
if special is None:
@ -410,6 +403,16 @@ class _TransfoXLDelimiterLookupTokenizer(BaseTokenizer):
class TransfoXLTokenizerFast(PreTrainedTokenizerFast):
"""
Construct a "Fast" Transformer-XL tokenizer (backed by HuggingFace's `tokenizers` library).
The Transformer-XL tokenizer is a word-level tokenizer (no sub-word tokenization).
Adapted from Vocab class in https://github.com/kimiyoung/transformer-xl
This tokenizer inherits from :class:`~transformers.PreTrainedTokenizerFast` which contains most of the methods. Users
should refer to the superclass for more information regarding methods.
"""
vocab_files_names = VOCAB_FILES_NAMES_FAST
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP_FAST
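Since this tokenizer is word-level, a short sketch of what tokenization looks like, assuming the ``transfo-xl-wt103`` checkpoint::

    from transformers import TransfoXLTokenizerFast

    tokenizer = TransfoXLTokenizerFast.from_pretrained("transfo-xl-wt103")
    # whole words only; out-of-vocabulary words are typically mapped to <unk> rather than split
    print(tokenizer.tokenize("Hello world"))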

File diff suppressed because it is too large.

View File

@ -629,9 +629,6 @@ class XLMTokenizer(PreTrainedTokenizer):
**kwargs,
)
self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens
# cache of sm.MosesPunctNormalizer instance
self.cache_moses_punct_normalizer = dict()
# cache of sm.MosesTokenizer instance

View File

@ -128,8 +128,6 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
mask_token=mask_token,
**kwargs,
)
self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
self.max_len_sentences_pair = self.max_len - 4 # take into account special tokens
try:
import sentencepiece as spm

View File

@ -138,8 +138,6 @@ class XLNetTokenizer(PreTrainedTokenizer):
**kwargs,
)
self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens
self._pad_token_type_id = 3
try:

View File

@ -117,8 +117,6 @@ class XxxTokenizer(PreTrainedTokenizer):
mask_token=mask_token,
**kwargs,
)
self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens
if not os.path.isfile(vocab_file):
raise ValueError(

View File

@ -1,3 +1,4 @@
import logging
import unittest
from collections import namedtuple
from itertools import takewhile
@ -21,6 +22,10 @@ from transformers.tokenization_roberta import RobertaTokenizerFast
from transformers.tokenization_transfo_xl import TransfoXLTokenizerFast
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
NON_ENGLISH_TAGS = ["chinese", "dutch", "french", "finnish", "german", "multilingual"]
Tokenizer = namedtuple("Tokenizer", ["name", "rust_cls", "python_cls", "vocab_key", "filter"])
@ -83,6 +88,85 @@ class CommonFastTokenizerTest(unittest.TestCase):
self.assert_add_tokens(tokenizer_r)
self.assert_offsets_mapping(tokenizer_r)
self.assert_add_special_tokens(tokenizer_r)
self.assert_alignement_methods(tokenizer_r)
def assert_alignement_methods(self, tokenizer_r):
words = ["Wonderful", "no", "inspiration", "example", "with", "subtoken"]
text = " ".join(words)
batch_size = 3
encoding = tokenizer_r.encode_plus(text, add_special_tokens=False)
batch_encoding = tokenizer_r.batch_encode_plus([text] * batch_size, add_special_tokens=False)
num_tokens = len(encoding["input_ids"])
last_word_index = len(words) - 1
last_token_index = num_tokens - 1
last_batch_index = batch_size - 1
last_char_index = len(text) - 1
# words, tokens
self.assertEqual(len(encoding.words(0)), num_tokens)
self.assertEqual(max(encoding.words(0)), last_word_index)
self.assertEqual(min(encoding.words(0)), 0)
self.assertEqual(len(batch_encoding.words(last_batch_index)), num_tokens)
self.assertEqual(max(batch_encoding.words(last_batch_index)), last_word_index)
self.assertEqual(min(batch_encoding.words(last_batch_index)), 0)
self.assertEqual(len(encoding.tokens(0)), num_tokens)
# Assert token_to_word
self.assertEqual(encoding.token_to_word(0), 0)
self.assertEqual(encoding.token_to_word(0, 0), 0)
self.assertEqual(encoding.token_to_word(last_token_index), last_word_index)
self.assertEqual(encoding.token_to_word(0, last_token_index), last_word_index)
self.assertEqual(batch_encoding.token_to_word(1, 0), 0)
self.assertEqual(batch_encoding.token_to_word(0, last_token_index), last_word_index)
self.assertEqual(batch_encoding.token_to_word(last_batch_index, last_token_index), last_word_index)
# Assert word_to_tokens
self.assertEqual(encoding.word_to_tokens(0).start, 0)
self.assertEqual(encoding.word_to_tokens(0, 0).start, 0)
self.assertEqual(encoding.word_to_tokens(last_word_index).end, last_token_index + 1)
self.assertEqual(encoding.word_to_tokens(0, last_word_index).end, last_token_index + 1)
self.assertEqual(batch_encoding.word_to_tokens(1, 0).start, 0)
self.assertEqual(batch_encoding.word_to_tokens(0, last_word_index).end, last_token_index + 1)
self.assertEqual(batch_encoding.word_to_tokens(last_batch_index, last_word_index).end, last_token_index + 1)
# Assert token_to_chars
self.assertEqual(encoding.token_to_chars(0).start, 0)
self.assertEqual(encoding.token_to_chars(0, 0).start, 0)
self.assertEqual(encoding.token_to_chars(last_token_index).end, last_char_index + 1)
self.assertEqual(encoding.token_to_chars(0, last_token_index).end, last_char_index + 1)
self.assertEqual(batch_encoding.token_to_chars(1, 0).start, 0)
self.assertEqual(batch_encoding.token_to_chars(0, last_token_index).end, last_char_index + 1)
self.assertEqual(batch_encoding.token_to_chars(last_batch_index, last_token_index).end, last_char_index + 1)
# Assert char_to_token
self.assertEqual(encoding.char_to_token(0), 0)
self.assertEqual(encoding.char_to_token(0, 0), 0)
self.assertEqual(encoding.char_to_token(last_char_index), last_token_index)
self.assertEqual(encoding.char_to_token(0, last_char_index), last_token_index)
self.assertEqual(batch_encoding.char_to_token(1, 0), 0)
self.assertEqual(batch_encoding.char_to_token(0, last_char_index), last_token_index)
self.assertEqual(batch_encoding.char_to_token(last_batch_index, last_char_index), last_token_index)
# Assert char_to_word
self.assertEqual(encoding.char_to_word(0), 0)
self.assertEqual(encoding.char_to_word(0, 0), 0)
self.assertEqual(encoding.char_to_word(last_char_index), last_word_index)
self.assertEqual(encoding.char_to_word(0, last_char_index), last_word_index)
self.assertEqual(batch_encoding.char_to_word(1, 0), 0)
self.assertEqual(batch_encoding.char_to_word(0, last_char_index), last_word_index)
self.assertEqual(batch_encoding.char_to_word(last_batch_index, last_char_index), last_word_index)
# Assert word_to_chars
self.assertEqual(encoding.word_to_chars(0).start, 0)
self.assertEqual(encoding.word_to_chars(0, 0).start, 0)
self.assertEqual(encoding.word_to_chars(last_word_index).end, last_char_index + 1)
self.assertEqual(encoding.word_to_chars(0, last_word_index).end, last_char_index + 1)
self.assertEqual(batch_encoding.word_to_chars(1, 0).start, 0)
self.assertEqual(batch_encoding.word_to_chars(0, last_word_index).end, last_char_index + 1)
self.assertEqual(batch_encoding.word_to_chars(last_batch_index, last_word_index).end, last_char_index + 1)
def assert_tokenization_python_rust_equals(self, tokenizer_p, tokenizer_r):
# Ensure basic input match
@ -306,7 +390,6 @@ class CommonFastTokenizerTest(unittest.TestCase):
self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
# Simple input
# TODO: Re-enable this test when batch_encode_plus with padding correctly handles padding
input_r = tokenizer_r.batch_encode_plus(
["This is a simple input 1", "This is a simple input 2"], max_length=max_length, pad_to_max_length=True
)
@ -316,7 +399,6 @@ class CommonFastTokenizerTest(unittest.TestCase):
assert_batch_padded_input_match(input_r, input_p)
# Pair input
# TODO: Re-enable this test when batch_encode_plus with padding correctly handles padding
input_r = tokenizer_r.batch_encode_plus(
[
("This is a simple input 1", "This is a simple input 2"),