From 827d6d6ef071029cfe82838a18dab046b5813976 Mon Sep 17 00:00:00 2001 From: Thomas Wolf Date: Sat, 18 Apr 2020 13:43:57 +0200 Subject: [PATCH] Cleanup fast tokenizers integration (#3706) * First pass on utility classes and python tokenizers * finishing cleanup pass * style and quality * Fix tests * Updating following @mfuntowicz comment * style and quality * Fix Roberta * fix batch_size/seq_length inBatchEncoding * add alignement methods + tests * Fix OpenAI and Transfo-XL tokenizers * adding trim_offsets=True default for GPT2 et RoBERTa * style and quality * fix tests * add_prefix_space in roberta * bump up tokenizers to rc7 * style * unfortunately tensorfow does like these - removing shape/seq_len for now * Update src/transformers/tokenization_utils.py Co-Authored-By: Stefan Schweter * Adding doc and docstrings * making flake8 happy Co-authored-by: Stefan Schweter --- docs/source/main_classes/tokenizer.rst | 30 +- docs/source/model_doc/bert.rst | 7 + docs/source/model_doc/distilbert.rst | 7 + docs/source/model_doc/electra.rst | 7 + docs/source/model_doc/gpt.rst | 7 + docs/source/model_doc/gpt2.rst | 7 + docs/source/model_doc/roberta.rst | 7 + docs/source/model_doc/transformerxl.rst | 7 + examples/run_language_modeling.py | 2 +- setup.py | 2 +- src/transformers/tokenization_albert.py | 3 - src/transformers/tokenization_bert.py | 44 +- .../tokenization_bert_japanese.py | 2 - src/transformers/tokenization_camembert.py | 2 - src/transformers/tokenization_ctrl.py | 6 - src/transformers/tokenization_distilbert.py | 15 +- src/transformers/tokenization_electra.py | 3 +- src/transformers/tokenization_gpt2.py | 55 +- src/transformers/tokenization_openai.py | 116 +- src/transformers/tokenization_roberta.py | 65 +- src/transformers/tokenization_t5.py | 6 - src/transformers/tokenization_transfo_xl.py | 17 +- src/transformers/tokenization_utils.py | 1022 +++++++++++------ src/transformers/tokenization_xlm.py | 3 - src/transformers/tokenization_xlm_roberta.py | 2 - src/transformers/tokenization_xlnet.py | 2 - .../adding_a_new_model/tokenization_xxx.py | 2 - tests/test_tokenization_fast.py | 86 +- 28 files changed, 1031 insertions(+), 503 deletions(-) diff --git a/docs/source/main_classes/tokenizer.rst b/docs/source/main_classes/tokenizer.rst index c33eb45829..b826114fd5 100644 --- a/docs/source/main_classes/tokenizer.rst +++ b/docs/source/main_classes/tokenizer.rst @@ -1,16 +1,38 @@ Tokenizer ---------------------------------------------------- -The base class ``PreTrainedTokenizer`` implements the common methods for loading/saving a tokenizer either from a local file or directory, or from a pretrained tokenizer provided by the library (downloaded from HuggingFace's AWS S3 repository). +A tokenizer is in charge of preparing the inputs for a model. The library comprise tokenizers for all the models. Most of the tokenizers are available in two flavors: a full python implementation and a "Fast" implementation based on the Rust library `tokenizers`. The "Fast" implementations allows (1) a significant speed-up in particular when doing batched tokenization and (2) additional methods to map between the original string (character and words) and the token space (e.g. getting the index of the token comprising a given character or the span of characters corresponding to a given token). Currently no "Fast" implementation is available for the SentencePiece-based tokenizers (for T5, ALBERT, CamemBERT, XLMRoBERTa and XLNet models). -``PreTrainedTokenizer`` is the main entry point into tokenizers as it also implements the main methods for using all the tokenizers: +The base classes ``PreTrainedTokenizer`` and ``PreTrainedTokenizerFast`` implements the common methods for encoding string inputs in model inputs (see below) and instantiating/saving python and "Fast" tokenizers either from a local file or directory or from a pretrained tokenizer provided by the library (downloaded from HuggingFace's AWS S3 repository). -- tokenizing, converting tokens to ids and back and encoding/decoding, +``PreTrainedTokenizer`` and ``PreTrainedTokenizerFast`` thus implements the main methods for using all the tokenizers: + +- tokenizing (spliting strings in sub-word token strings), converting tokens strings to ids and back, and encoding/decoding (i.e. tokenizing + convert to integers), - adding new tokens to the vocabulary in a way that is independant of the underlying structure (BPE, SentencePiece...), -- managing special tokens (adding them, assigning them to roles, making sure they are not split during tokenization) +- managing special tokens like mask, beginning-of-sentence, etc tokens (adding them, assigning them to attributes in the tokenizer for easy access and making sure they are not split during tokenization) + +``BatchEncoding`` holds the output of the tokenizer's encoding methods (``encode_plus`` and ``batch_encode_plus``) and is derived from a Python dictionary. When the tokenizer is a pure python tokenizer, this class behave just like a standard python dictionary and hold the various model inputs computed by these methodes (``input_ids``, ``attention_mask``...). When the tokenizer is a "Fast" tokenizer (i.e. backed by HuggingFace tokenizers library), this class provides in addition several advanced alignement methods which can be used to map between the original string (character and words) and the token space (e.g. getting the index of the token comprising a given character or the span of characters corresponding to a given token). ``PreTrainedTokenizer`` ~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: transformers.PreTrainedTokenizer :members: + +``PreTrainedTokenizerFast`` +~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.PreTrainedTokenizerFast + :members: + +``BatchEncoding`` +~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.BatchEncoding + :members: + +``SpecialTokensMixin`` +~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.SpecialTokensMixin + :members: diff --git a/docs/source/model_doc/bert.rst b/docs/source/model_doc/bert.rst index 30d821d410..b77a241a8c 100644 --- a/docs/source/model_doc/bert.rst +++ b/docs/source/model_doc/bert.rst @@ -52,6 +52,13 @@ BertTokenizer create_token_type_ids_from_sequences, save_vocabulary +BertTokenizerFast +~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.BertTokenizerFast + :members: + + BertModel ~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/model_doc/distilbert.rst b/docs/source/model_doc/distilbert.rst index 55eaa9a088..9eb9fa151d 100644 --- a/docs/source/model_doc/distilbert.rst +++ b/docs/source/model_doc/distilbert.rst @@ -44,6 +44,13 @@ DistilBertTokenizer :members: +DistilBertTokenizerFast +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.DistilBertTokenizerFast + :members: + + DistilBertModel ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/model_doc/electra.rst b/docs/source/model_doc/electra.rst index a48eb7562b..3dbac2cee2 100644 --- a/docs/source/model_doc/electra.rst +++ b/docs/source/model_doc/electra.rst @@ -61,6 +61,13 @@ ElectraTokenizer :members: +ElectraTokenizerFast +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.ElectraTokenizerFast + :members: + + ElectraModel ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/model_doc/gpt.rst b/docs/source/model_doc/gpt.rst index d856c93554..449a85c3fe 100644 --- a/docs/source/model_doc/gpt.rst +++ b/docs/source/model_doc/gpt.rst @@ -53,6 +53,13 @@ OpenAIGPTTokenizer :members: save_vocabulary +OpenAIGPTTokenizerFast +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.OpenAIGPTTokenizerFast + :members: + + OpenAIGPTModel ~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/model_doc/gpt2.rst b/docs/source/model_doc/gpt2.rst index d3f9953106..45ac90ec27 100644 --- a/docs/source/model_doc/gpt2.rst +++ b/docs/source/model_doc/gpt2.rst @@ -51,6 +51,13 @@ GPT2Tokenizer :members: save_vocabulary +GPT2TokenizerFast +~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.GPT2TokenizerFast + :members: + + GPT2Model ~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/model_doc/roberta.rst b/docs/source/model_doc/roberta.rst index 727ff1f07e..07e511228a 100644 --- a/docs/source/model_doc/roberta.rst +++ b/docs/source/model_doc/roberta.rst @@ -46,6 +46,13 @@ RobertaTokenizer create_token_type_ids_from_sequences, save_vocabulary +RobertaTokenizerFast +~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.RobertaTokenizerFast + :members: build_inputs_with_special_tokens + + RobertaModel ~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/model_doc/transformerxl.rst b/docs/source/model_doc/transformerxl.rst index 42a8d109e9..336bfdcd69 100644 --- a/docs/source/model_doc/transformerxl.rst +++ b/docs/source/model_doc/transformerxl.rst @@ -47,6 +47,13 @@ TransfoXLTokenizer :members: save_vocabulary +TransfoXLTokenizerFast +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.TransfoXLTokenizerFast + :members: + + TransfoXLModel ~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/examples/run_language_modeling.py b/examples/run_language_modeling.py index 5d451e7612..92807dd79c 100644 --- a/examples/run_language_modeling.py +++ b/examples/run_language_modeling.py @@ -67,7 +67,7 @@ class TextDataset(Dataset): def __init__(self, tokenizer: PreTrainedTokenizer, args, file_path: str, block_size=512): assert os.path.isfile(file_path) - block_size = block_size - (tokenizer.max_len - tokenizer.max_len_single_sentence) + block_size = block_size - tokenizer.num_special_tokens_to_add(pair=False) directory, filename = os.path.split(file_path) cached_features_file = os.path.join( diff --git a/setup.py b/setup.py index 10026a0ad8..1e6fa26456 100644 --- a/setup.py +++ b/setup.py @@ -96,7 +96,7 @@ setup( packages=find_packages("src"), install_requires=[ "numpy", - "tokenizers == 0.7.0rc5", + "tokenizers == 0.7.0rc7", # dataclasses for Python versions that don't have it "dataclasses;python_version<'3.7'", # accessing files from S3 directly diff --git a/src/transformers/tokenization_albert.py b/src/transformers/tokenization_albert.py index c8a458a489..953c250b3a 100644 --- a/src/transformers/tokenization_albert.py +++ b/src/transformers/tokenization_albert.py @@ -137,9 +137,6 @@ class AlbertTokenizer(PreTrainedTokenizer): **kwargs, ) - self.max_len_single_sentence = self.max_len - 2 # take into account special tokens - self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens - try: import sentencepiece as spm except ImportError: diff --git a/src/transformers/tokenization_bert.py b/src/transformers/tokenization_bert.py index 21e3973234..376196c314 100644 --- a/src/transformers/tokenization_bert.py +++ b/src/transformers/tokenization_bert.py @@ -182,8 +182,6 @@ class BertTokenizer(PreTrainedTokenizer): mask_token=mask_token, **kwargs, ) - self.max_len_single_sentence = self.max_len - 2 # take into account special tokens - self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens if not os.path.isfile(vocab_file): raise ValueError( @@ -583,6 +581,48 @@ def _is_punctuation(char): class BertTokenizerFast(PreTrainedTokenizerFast): + r""" + Constructs a "Fast" BERT tokenizer (backed by HuggingFace's `tokenizers` library). + + Bert tokenization is Based on WordPiece. + + This tokenizer inherits from :class:`~transformers.PreTrainedTokenizerFast` which contains most of the methods. Users + should refer to the superclass for more information regarding methods. + + Args: + vocab_file (:obj:`string`): + File containing the vocabulary. + do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether to lowercase the input when tokenizing. + unk_token (:obj:`string`, `optional`, defaults to "[UNK]"): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + sep_token (:obj:`string`, `optional`, defaults to "[SEP]"): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences + for sequence classification or for a text and a question for question answering. + It is also used as the last token of a sequence built with special tokens. + pad_token (:obj:`string`, `optional`, defaults to "[PAD]"): + The token used for padding, for example when batching sequences of different lengths. + cls_token (:obj:`string`, `optional`, defaults to "[CLS]"): + The classifier token which is used when doing sequence classification (classification of the whole + sequence instead of per-token classification). It is the first token of the sequence when built with + special tokens. + mask_token (:obj:`string`, `optional`, defaults to "[MASK]"): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether to tokenize Chinese characters. + This should likely be deactivated for Japanese: + see: https://github.com/huggingface/transformers/issues/328 + clean_text (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether to clean the text before tokenization by removing any control characters and + replacing all whitespaces by the classic one. + tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether to tokenize Chinese characters. + This should likely be deactivated for Japanese: + see: https://github.com/huggingface/transformers/issues/328 + """ + vocab_files_names = VOCAB_FILES_NAMES pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION diff --git a/src/transformers/tokenization_bert_japanese.py b/src/transformers/tokenization_bert_japanese.py index 6c91240435..a88c17b413 100644 --- a/src/transformers/tokenization_bert_japanese.py +++ b/src/transformers/tokenization_bert_japanese.py @@ -119,8 +119,6 @@ class BertJapaneseTokenizer(BertTokenizer): **kwargs, ) # ^^ We call the grandparent's init, not the parent's. - self.max_len_single_sentence = self.max_len - 2 # take into account special tokens - self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens if not os.path.isfile(vocab_file): raise ValueError( diff --git a/src/transformers/tokenization_camembert.py b/src/transformers/tokenization_camembert.py index cc4fdc650b..0179d020bf 100644 --- a/src/transformers/tokenization_camembert.py +++ b/src/transformers/tokenization_camembert.py @@ -129,8 +129,6 @@ class CamembertTokenizer(PreTrainedTokenizer): additional_special_tokens=additional_special_tokens, **kwargs, ) - self.max_len_single_sentence = self.max_len - 2 # take into account special tokens - self.max_len_sentences_pair = self.max_len - 4 # take into account special tokens self.sp_model = spm.SentencePieceProcessor() self.sp_model.Load(str(vocab_file)) self.vocab_file = vocab_file diff --git a/src/transformers/tokenization_ctrl.py b/src/transformers/tokenization_ctrl.py index 5c487952c4..9757b05803 100644 --- a/src/transformers/tokenization_ctrl.py +++ b/src/transformers/tokenization_ctrl.py @@ -140,12 +140,6 @@ class CTRLTokenizer(PreTrainedTokenizer): def __init__(self, vocab_file, merges_file, unk_token="", **kwargs): super().__init__(unk_token=unk_token, **kwargs) - self.max_len_single_sentence = ( - self.max_len - ) # no default special tokens - you can update this value if you add special tokens - self.max_len_sentences_pair = ( - self.max_len - ) # no default special tokens - you can update this value if you add special tokens with open(vocab_file, encoding="utf-8") as vocab_handle: self.encoder = json.load(vocab_handle) diff --git a/src/transformers/tokenization_distilbert.py b/src/transformers/tokenization_distilbert.py index 626e65486b..1c34f3f80e 100644 --- a/src/transformers/tokenization_distilbert.py +++ b/src/transformers/tokenization_distilbert.py @@ -57,8 +57,9 @@ PRETRAINED_INIT_CONFIGURATION = { class DistilBertTokenizer(BertTokenizer): r""" - Constructs a DistilBertTokenizer. - :class:`~transformers.DistilBertTokenizer` is identical to :class:`~transformers.BertTokenizer` and runs end-to-end + Constructs a DistilBertTokenizer. + + :class:`~transformers.DistilBertTokenizer is identical to :class:`~transformers.BertTokenizer` and runs end-to-end tokenization: punctuation splitting + wordpiece. Refer to superclass :class:`~transformers.BertTokenizer` for usage examples and documentation concerning @@ -73,6 +74,16 @@ class DistilBertTokenizer(BertTokenizer): class DistilBertTokenizerFast(BertTokenizerFast): + r""" + Constructs a "Fast" DistilBertTokenizer (backed by HuggingFace's `tokenizers` library). + + :class:`~transformers.DistilBertTokenizerFast` is identical to :class:`~transformers.BertTokenizerFast` and runs end-to-end + tokenization: punctuation splitting + wordpiece. + + Refer to superclass :class:`~transformers.BertTokenizerFast` for usage examples and documentation concerning + parameters. + """ + vocab_files_names = VOCAB_FILES_NAMES pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES diff --git a/src/transformers/tokenization_electra.py b/src/transformers/tokenization_electra.py index ceb7eedccb..80fb6a53b7 100644 --- a/src/transformers/tokenization_electra.py +++ b/src/transformers/tokenization_electra.py @@ -67,7 +67,8 @@ class ElectraTokenizer(BertTokenizer): class ElectraTokenizerFast(BertTokenizerFast): r""" - Constructs an Electra Fast tokenizer. + Constructs a "Fast" Electra Fast tokenizer (backed by HuggingFace's `tokenizers` library). + :class:`~transformers.ElectraTokenizerFast` is identical to :class:`~transformers.BertTokenizerFast` and runs end-to-end tokenization: punctuation splitting + wordpiece. diff --git a/src/transformers/tokenization_gpt2.py b/src/transformers/tokenization_gpt2.py index 45331445b1..e587968d6b 100644 --- a/src/transformers/tokenization_gpt2.py +++ b/src/transformers/tokenization_gpt2.py @@ -147,12 +147,6 @@ class GPT2Tokenizer(PreTrainedTokenizer): **kwargs ): super().__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs) - self.max_len_single_sentence = ( - self.max_len - ) # no default special tokens - you can update this value if you add special tokens - self.max_len_sentences_pair = ( - self.max_len - ) # no default special tokens - you can update this value if you add special tokens with open(vocab_file, encoding="utf-8") as vocab_handle: self.encoder = json.load(vocab_handle) @@ -284,6 +278,47 @@ class GPT2Tokenizer(PreTrainedTokenizer): class GPT2TokenizerFast(PreTrainedTokenizerFast): + """ + Constructs a "Fast" GPT-2 BPE tokenizer (backed by HuggingFace's `tokenizers` library). + + Peculiarities: + + - Byte-level Byte-Pair-Encoding + - Requires a space to start the input string => the encoding methods should be called with the + ``add_prefix_space`` flag set to ``True``. + Otherwise, this tokenizer ``encode`` and ``decode`` method will not conserve + the absence of a space at the beginning of a string: + + :: + + tokenizer.decode(tokenizer.encode("Hello")) = " Hello" + + This tokenizer inherits from :class:`~transformers.PreTrainedTokenizerFast` which contains most of the methods. Users + should refer to the superclass for more information regarding methods. + + Args: + vocab_file (:obj:`str`): + Path to the vocabulary file. + merges_file (:obj:`str`): + Path to the merges file. + errors (:obj:`str`, `optional`, defaults to "replace"): + Paradigm to follow when decoding bytes to UTF-8. See `bytes.decode + `__ for more information. + unk_token (:obj:`string`, `optional`, defaults to `<|endoftext|>`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + bos_token (:obj:`string`, `optional`, defaults to `<|endoftext|>`): + The beginning of sequence token. + eos_token (:obj:`string`, `optional`, defaults to `<|endoftext|>`): + The end of sequence token. + add_prefix_space (:obj:`bool`, `optional`, defaults to `False`): + Whether to add a leading space to the first word. + This allows to treat the leading word just as any other word. + (GPT2 tokenizer detect beginning of words by the preceeding space) + trim_offsets (:obj:`bool`, `optional`, defaults to `True`): + Whether the post processing step should trim offsets to avoid including whitespaces. + """ + vocab_files_names = VOCAB_FILES_NAMES pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES @@ -296,10 +331,16 @@ class GPT2TokenizerFast(PreTrainedTokenizerFast): bos_token="<|endoftext|>", eos_token="<|endoftext|>", add_prefix_space=False, + trim_offsets=True, **kwargs ): super().__init__( - ByteLevelBPETokenizer(vocab_file=vocab_file, merges_file=merges_file, add_prefix_space=add_prefix_space), + ByteLevelBPETokenizer( + vocab_file=vocab_file, + merges_file=merges_file, + add_prefix_space=add_prefix_space, + trim_offsets=trim_offsets, + ), bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, diff --git a/src/transformers/tokenization_openai.py b/src/transformers/tokenization_openai.py index 9b150d4772..4e71c0a964 100644 --- a/src/transformers/tokenization_openai.py +++ b/src/transformers/tokenization_openai.py @@ -19,15 +19,8 @@ import json import logging import os import re -from typing import List, Optional, Union -from tokenizers import Tokenizer -from tokenizers.decoders import BPEDecoder -from tokenizers.implementations import BaseTokenizer -from tokenizers.models import BPE -from tokenizers.normalizers import BertNormalizer, Sequence, unicode_normalizer_from_str -from tokenizers.pre_tokenizers import BertPreTokenizer -from tokenizers.trainers import BpeTrainer +from tokenizers import CharBPETokenizer from .tokenization_bert import BasicTokenizer from .tokenization_utils import PreTrainedTokenizer, PreTrainedTokenizerFast @@ -106,13 +99,6 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer): def __init__(self, vocab_file, merges_file, unk_token="", **kwargs): super().__init__(unk_token=unk_token, **kwargs) - self.max_len_single_sentence = ( - self.max_len - ) # no default special tokens - you can update this value if you add special tokens - self.max_len_sentences_pair = ( - self.max_len - ) # no default special tokens - you can update this value if you add special tokens - try: import ftfy from spacy.lang.en import English @@ -249,83 +235,28 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer): return vocab_file, merge_file -class _OpenAIGPTCharBPETokenizer(BaseTokenizer): - """ - OpenAI character-level BPE Tokenizer - """ - - def __init__( - self, - vocab_file: Optional[str] = None, - merges_file: Optional[str] = None, - unk_token: Optional[str] = "", - suffix: Optional[str] = "", - dropout: Optional[float] = None, - unicode_normalizer: Optional[str] = None, - ): - if vocab_file is not None and merges_file is not None: - tokenizer = Tokenizer( - BPE(vocab_file, merges_file, dropout=dropout, unk_token=unk_token, end_of_word_suffix=suffix) - ) - else: - tokenizer = Tokenizer(BPE()) - - # Check for Unicode normalization first (before everything else) - normalizers = [] - - if unicode_normalizer: - normalizers += [unicode_normalizer_from_str(unicode_normalizer)] - - # OpenAI normalization is the same as Bert - normalizers += [BertNormalizer()] - - # Create the normalizer structure - if len(normalizers) > 0: - if len(normalizers) > 1: - tokenizer.normalizer = Sequence(normalizers) - else: - tokenizer.normalizer = normalizers[0] - - tokenizer.pre_tokenizer = BertPreTokenizer() - tokenizer.decoder = BPEDecoder(suffix=suffix) - - parameters = { - "model": "BPE", - "unk_token": unk_token, - "suffix": suffix, - "dropout": dropout, - } - - super().__init__(tokenizer, parameters) - - def train( - self, - files: Union[str, List[str]], - vocab_size: int = 30000, - min_frequency: int = 2, - special_tokens: List[str] = [""], - limit_alphabet: int = 1000, - initial_alphabet: List[str] = [], - suffix: Optional[str] = "", - show_progress: bool = True, - ): - """ Train the model using the given files """ - - trainer = BpeTrainer( - vocab_size=vocab_size, - min_frequency=min_frequency, - special_tokens=special_tokens, - limit_alphabet=limit_alphabet, - initial_alphabet=initial_alphabet, - end_of_word_suffix=suffix, - show_progress=show_progress, - ) - if isinstance(files, str): - files = [files] - self._tokenizer.train(trainer, files) - - class OpenAIGPTTokenizerFast(PreTrainedTokenizerFast): + """ + Construct a "Fast" BPE tokenizer for OpenAI GPT (backed by HuggingFace's `tokenizers` library). + + Peculiarities: + + - lower case all inputs + - uses SpaCy tokenizer and ftfy for pre-BPE tokenization if they are installed, fallback to BERT's BasicTokenizer if not. + + This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users + should refer to the superclass for more information regarding methods. + + Args: + vocab_file (:obj:`str`): + Path to the vocabulary file. + merges_file (:obj:`str`): + Path to the merges file. + unk_token (:obj:`string`, `optional`, defaults to ""): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + """ + vocab_files_names = VOCAB_FILES_NAMES pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES @@ -333,5 +264,6 @@ class OpenAIGPTTokenizerFast(PreTrainedTokenizerFast): def __init__(self, vocab_file, merges_file, unk_token="", **kwargs): kwargs.setdefault("unk_token", unk_token) super().__init__( - _OpenAIGPTCharBPETokenizer(vocab_file=vocab_file, merges_file=merges_file, unk_token=unk_token), **kwargs + CharBPETokenizer(vocab_file=vocab_file, merges_file=merges_file, unk_token=unk_token, lowercase=True), + **kwargs, ) diff --git a/src/transformers/tokenization_roberta.py b/src/transformers/tokenization_roberta.py index 02f352b696..6cc96dec4f 100644 --- a/src/transformers/tokenization_roberta.py +++ b/src/transformers/tokenization_roberta.py @@ -150,8 +150,6 @@ class RobertaTokenizer(GPT2Tokenizer): mask_token=mask_token, **kwargs, ) - self.max_len_single_sentence = self.max_len - 2 # take into account special tokens - self.max_len_sentences_pair = self.max_len - 4 # take into account special tokens def build_inputs_with_special_tokens( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None @@ -244,6 +242,47 @@ class RobertaTokenizer(GPT2Tokenizer): class RobertaTokenizerFast(GPT2TokenizerFast): + """ + Constructs a "Fast" RoBERTa BPE tokenizer (backed by HuggingFace's `tokenizers` library). + + Peculiarities: + + - Byte-level Byte-Pair-Encoding + - Requires a space to start the input string => the encoding methods should be called with the + ``add_prefix_space`` flag set to ``True``. + Otherwise, this tokenizer ``encode`` and ``decode`` method will not conserve + the absence of a space at the beginning of a string: + + :: + + tokenizer.decode(tokenizer.encode("Hello")) = " Hello" + + This tokenizer inherits from :class:`~transformers.PreTrainedTokenizerFast` which contains most of the methods. Users + should refer to the superclass for more information regarding methods. + + Args: + vocab_file (:obj:`str`): + Path to the vocabulary file. + merges_file (:obj:`str`): + Path to the merges file. + errors (:obj:`str`, `optional`, defaults to "replace"): + Paradigm to follow when decoding bytes to UTF-8. See `bytes.decode + `__ for more information. + unk_token (:obj:`string`, `optional`, defaults to `<|endoftext|>`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + bos_token (:obj:`string`, `optional`, defaults to `<|endoftext|>`): + The beginning of sequence token. + eos_token (:obj:`string`, `optional`, defaults to `<|endoftext|>`): + The end of sequence token. + add_prefix_space (:obj:`bool`, `optional`, defaults to `False`): + Whether to add a leading space to the first word. + This allows to treat the leading word just as any other word. + (GPT2 tokenizer detect beginning of words by the preceeding space) + trim_offsets (:obj:`bool`, `optional`, defaults to `True`): + Whether the post processing step should trim offsets to avoid including whitespaces. + """ + vocab_files_names = VOCAB_FILES_NAMES pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES @@ -262,6 +301,7 @@ class RobertaTokenizerFast(GPT2TokenizerFast): pad_token="", mask_token="", add_prefix_space=True, + trim_offsets=True, **kwargs ): kwargs.setdefault("pad_token", pad_token) @@ -276,23 +316,18 @@ class RobertaTokenizerFast(GPT2TokenizerFast): bos_token=bos_token, eos_token=eos_token, add_prefix_space=add_prefix_space, + trim_offsets=trim_offsets, **kwargs, ) - self.tokenizer._tokenizer.post_processor = RobertaProcessing( - (sep_token, self.sep_token_id), (cls_token, self.cls_token_id) + self.backend_tokenizer._tokenizer.post_processor = RobertaProcessing( + sep=(sep_token, self.sep_token_id), + cls=(cls_token, self.cls_token_id), + add_prefix_space=add_prefix_space, + trim_offsets=trim_offsets, ) - self.tokenizer.add_special_tokens([kwargs["mask_token"]]) - - # As we override the post_processor post super.__init__ the computed num_added_tokens is wrong in super(). - # We need to recompute max_len according to the newly register post_processor to get real values. - self.max_len_single_sentence = self.max_len - self.num_special_tokens_to_add( - False - ) # take into account special tokens - self.max_len_sentences_pair = self.max_len - self.num_special_tokens_to_add( - True - ) # take into account special tokens + self.backend_tokenizer.add_special_tokens([kwargs["mask_token"]]) @PreTrainedTokenizer.mask_token.setter def mask_token(self, value): @@ -300,7 +335,7 @@ class RobertaTokenizerFast(GPT2TokenizerFast): value = AddedToken(value, lstrip=True) self._mask_token = str(value) - self.tokenizer.add_special_tokens([value]) + self._maybe_update_backend([value]) def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id] diff --git a/src/transformers/tokenization_t5.py b/src/transformers/tokenization_t5.py index bc1cf9fd03..df25eab1dd 100644 --- a/src/transformers/tokenization_t5.py +++ b/src/transformers/tokenization_t5.py @@ -118,12 +118,6 @@ class T5Tokenizer(PreTrainedTokenizer): additional_special_tokens=additional_special_tokens, **kwargs, ) - self.max_len_single_sentence = ( - self.max_len - ) # no default special tokens - you can update this value if you add special tokens - self.max_len_sentences_pair = ( - self.max_len - ) # no default special tokens - you can update this value if you add special tokens try: import sentencepiece as spm diff --git a/src/transformers/tokenization_transfo_xl.py b/src/transformers/tokenization_transfo_xl.py index 2392394bbe..ea6c7deee1 100644 --- a/src/transformers/tokenization_transfo_xl.py +++ b/src/transformers/tokenization_transfo_xl.py @@ -101,13 +101,6 @@ class TransfoXLTokenizer(PreTrainedTokenizer): unk_token=unk_token, eos_token=eos_token, additional_special_tokens=additional_special_tokens, **kwargs ) - self.max_len_single_sentence = ( - self.max_len - ) # no default special tokens - you can update this value if you add special tokens - self.max_len_sentences_pair = ( - self.max_len - ) # no default special tokens - you can update this value if you add special tokens - if never_split is None: never_split = self.all_special_tokens if special is None: @@ -410,6 +403,16 @@ class _TransfoXLDelimiterLookupTokenizer(BaseTokenizer): class TransfoXLTokenizerFast(PreTrainedTokenizerFast): + """ + Construct a "Fast" Transformer-XL tokenizer (backed by HuggingFace's `tokenizers` library). + + The Transformer-XL tokenizer is a word-level tokenizer (no sub-word tokenization). + + Adapted from Vocab class in https://github.com/kimiyoung/transformer-xl + + This tokenizer inherits from :class:`~transformers.PreTrainedTokenizerFast` which contains most of the methods. Users + should refer to the superclass for more information regarding methods. + """ vocab_files_names = VOCAB_FILES_NAMES_FAST pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP_FAST diff --git a/src/transformers/tokenization_utils.py b/src/transformers/tokenization_utils.py index bdfda4f835..8655167e10 100644 --- a/src/transformers/tokenization_utils.py +++ b/src/transformers/tokenization_utils.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. +# Copyright 2020 The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tokenization classes for OpenAI GPT.""" +"""Tokenization classes for python and fast tokenizers. Fast tokenizers are provided by HuggingFace's tokenizers library.""" import copy import functools @@ -24,11 +24,12 @@ import os import re from collections import UserDict, defaultdict from contextlib import contextmanager -from typing import List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union -from tokenizers import AddedToken, Encoding -from tokenizers.decoders import Decoder -from tokenizers.implementations import BaseTokenizer +from tokenizers import AddedToken as AddedTokenFast +from tokenizers import Encoding as EncodingFast +from tokenizers.decoders import Decoder as DecoderFast +from tokenizers.implementations import BaseTokenizer as BaseTokenizerFast from .file_utils import cached_path, hf_bucket_url, is_remote_url, is_tf_available, is_torch_available @@ -44,12 +45,40 @@ SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json" ADDED_TOKENS_FILE = "added_tokens.json" TOKENIZER_CONFIG_FILE = "tokenizer_config.json" +VERY_LARGE_INTEGER = int(1e30) # This is used to set the max input length for a model with infinite size input +LARGE_INTEGER = int(1e20) # This is used when we need something big but slightly smaller than VERY_LARGE_INTEGER -# Define type aliases +# Define type aliases and NamedTuples TextInput = str -TextPairInput = Tuple[str, str] PreTokenizedInput = List[str] +EncodedInput = List[int] +TextInputPair = Tuple[str, str] PreTokenizedInputPair = Tuple[List[str], List[str]] +EncodedInputPair = Tuple[List[int], List[int]] + + +class CharSpan(NamedTuple): + """ Character span in the original string + + Args: + start: index of the first character in the original string + end: index of the character following the last character in the original string + """ + + start: int + end: int + + +class TokenSpan(NamedTuple): + """ Token span in an encoded string (list of tokens) + + Args: + start: index of the first token in the span + end: index of the token following the last token in the span + """ + + start: int + end: int def flatten(x: Sequence): @@ -68,7 +97,7 @@ def flatten(x: Sequence): @contextmanager def truncate_and_pad( - tokenizer: BaseTokenizer, + tokenizer: BaseTokenizerFast, max_length: int, stride: int, strategy: str, @@ -78,26 +107,23 @@ def truncate_and_pad( pad_token_type_id: int, pad_token: str, ): - """ - This contextmanager is in charge of defining the truncation and the padding strategies and then - restore the tokenizer settings afterwards. + """ This contextmanager is in charge of defining the truncation and the padding strategies for fast tokenizers + (provided by HuggingFace tokenizers library) and restore the tokenizer settings afterwards. - This contextmanager assumes the provider tokenizer has no padding / truncation strategy - before the managed section. If your tokenizer set a padding / truncation strategy before, - then it will be reset to no padding/truncation when exiting the managed section. + This contextmanager assumes the provider tokenizer has no padding / truncation strategy + before the managed section. If your tokenizer set a padding / truncation strategy before, + then it will be reset to no padding/truncation when exiting the managed section. - Args: - tokenizer (BaseTokenizer): The tokenizer which will be used - max_length (int): The maximum size of the sequence - stride (int): The stride to use when handling overflow - strategy (str): Overflowing logic to use - pad_to_max_length (bool): Boolean indicating if the output needs to be padded up to max_length - padding_side (str): "left" or "right" indicating the direction the output sequence will be padded - pad_token_id (int): The integer representation of the padding token to use - pad_token_type_id (int): The integer representation of the padding token type to use - pad_token (str): The string representation of the padding token to use - - Returns: + Args: + tokenizer (BaseTokenizerFast): The tokenizer which will be used + max_length (int): The maximum size of the sequence + stride (int): The stride to use when handling overflow + strategy (str): Overflowing logic to use + pad_to_max_length (bool): Boolean indicating if the output needs to be padded up to max_length + padding_side (str): "left" or "right" indicating the direction the output sequence will be padded + pad_token_id (int): The integer representation of the padding token to use + pad_token_type_id (int): The integer representation of the padding token type to use + pad_token (str): The string representation of the padding token to use """ @@ -124,6 +150,9 @@ def truncate_and_pad( yield + # TODO(morgan, anthony): once we have a simple way to serialize tokenizers maybe store and restore the state afterward + # to avoid destructing the padding / truncation strategy as we do now. + if max_length is not None: tokenizer.no_truncation() @@ -132,41 +161,43 @@ def truncate_and_pad( class BatchEncoding(UserDict): - """ - Data structure derived from Dictionary holding all the required information to forward through - a model. + """ BatchEncoding hold the output of the encode and batch_encode methods (tokens, attention_masks, etc). + This class is derived from a python Dictionary and can be used as a dictionnary. + In addition, this class expose utility methods to map from word/char space to token space. + + Args: + data (:obj:`dict`): Dictionary of lists/arrays returned by the encode/batch_encode methods ('input_ids', 'attention_mask'...) + encoding (:obj:`EncodingFast`, :obj:`list(EncodingFast)`, `optional`, defaults to :obj:`None`): + If the tokenizer is a fast tokenizer which outputs additional informations like mapping from word/char space to token space + the `EncodingFast` instance or list of instance (for batches) hold these informations. - In addition, this structure expose utility methods to map from word/char space to token space. """ - def __init__(self, data: dict, encoding: Optional[Union[Encoding, Sequence[Encoding]]] = None): + def __init__(self, data: Dict[str, Any], encoding: Optional[Union[EncodingFast, Sequence[EncodingFast]]] = None): super().__init__(data) - if isinstance(encoding, Encoding): + if isinstance(encoding, EncodingFast): encoding = [encoding] self._encodings = encoding - def __getitem__(self, item: Union[int, str]) -> Encoding: + def __getitem__(self, item: Union[int, str]) -> EncodingFast: + """ If the key is a string, get the value of the dict associated to `key` ('input_ids', 'attention_mask'...) + If the key is an integer, get the EncodingFast for batch item with index `key` + """ if isinstance(item, str): return self.data[item] elif self._encodings is not None: return self._encodings[item] else: - raise KeyError("int index is supported only on {} from a Rust tokenizer".format(type(self).__name__)) + raise KeyError( + "Indexing with integers (to access backend Encoding for a given batch index) " + "is not available when using Python based tokenizers" + ) def __getattr__(self, item: str): return self.data[item] - @property - def encodings(self) -> Optional[List[Encoding]]: - """ - Return the list all encoding from the tokenization process - - Returns: List[Encoding] or None if input was tokenized through Python tokenizer - """ - return self._encodings - def keys(self): return self.data.keys() @@ -176,73 +207,265 @@ class BatchEncoding(UserDict): def items(self): return self.data.items() - def char_to_token_offsets(self, sentence: int, char: int) -> Tuple[int, int]: + # After this point: + # Extended properties and methods only available for fast (Rust-based) tokenizers + # provided by HuggingFace tokenizers library. + + @property + def encodings(self) -> Optional[List[EncodingFast]]: """ - Find the Offsets of the token containing the character at the specified position + Return the list all encoding from the tokenization process + + Returns: List[EncodingFast] or None if input was tokenized through Python (i.e. not fast) tokenizer + """ + return self._encodings + + def tokens(self, batch_index: int = 0) -> List[int]: + if not self._encodings: + raise ValueError("tokens() is not available when using Python based tokenizers") + return self._encodings[batch_index].tokens + + def words(self, batch_index: int = 0) -> List[Optional[int]]: + if not self._encodings: + raise ValueError("words() is not available when using Python based tokenizers") + return self._encodings[batch_index].words + + def token_to_word(self, batch_or_token_index: int, token_index: Optional[int] = None) -> int: + """ Get the index of the word corresponding (i.e. comprising) to an encoded token + in a sequence of the batch. + + Can be called as: + - self.token_to_word(token_index) if batch size is 1 + - self.token_to_word(batch_index, token_index) if batch size is greater than 1 + + This method is particularly suited when the input sequences are provided as + pre-tokenized sequences (i.e. words are defined by the user). In this case it allows + to easily associate encoded tokens with provided tokenized words. Args: - sentence: Index of the sentence relative to the batch provided to the tokenizer - char: Char index to get the relative token offsets + batch_or_token_index (:obj:`int`): + Index of the sequence in the batch. If the batch only comprise one sequence, + this can be the index of the token in the sequence + token_index (:obj:`int`, `optional`): + If a batch index is provided in `batch_or_token_index`, this can be the index + of the token in the sequence. Returns: - tuple: (token start, token end) + word_index (:obj:`int`): + index of the word in the input sequence. """ if not self._encodings: - raise ValueError("char_to_token_offsets() is not available when using Python based tokenizers") - return self[sentence].char_to_token_offsets(char) + raise ValueError("token_to_word() is not available when using Python based tokenizers") + if token_index is not None: + batch_index = batch_or_token_index + else: + batch_index = 0 + token_index = batch_or_token_index + if batch_index < 0: + batch_index = self._batch_size + batch_index + if token_index < 0: + token_index = self._seq_len + token_index + return self._encodings[batch_index].token_to_word(token_index) - def char_to_token(self, sentence: int, char: int) -> int: - """ - Return the index of the token at position of the given char. + def word_to_tokens(self, batch_or_word_index: int, word_index: Optional[int] = None) -> TokenSpan: + """ Get the encoded token span corresponding to a word in the sequence of the batch. + + Token spans are returned as a TokenSpan NamedTuple with: + start: index of the first token + end: index of the token following the last token + + Can be called as: + - self.word_to_tokens(word_index) if batch size is 1 + - self.word_to_tokens(batch_index, word_index) if batch size is greater or equal to 1 + + This method is particularly suited when the input sequences are provided as + pre-tokenized sequences (i.e. words are defined by the user). In this case it allows + to easily associate encoded tokens with provided tokenized words. Args: - sentence (int): Index of the sentence relative to the batch provided to the tokenizer - char (int): Char index to get the relative token offsets + batch_or_word_index (:obj:`int`): + Index of the sequence in the batch. If the batch only comprises one sequence, + this can be the index of the word in the sequence + word_index (:obj:`int`, `optional`): + If a batch index is provided in `batch_or_token_index`, this can be the index + of the word in the sequence. Returns: - int: Integer referring to the position of the token in the returned set of tokens for the sentence + token_span (:obj:`TokenSpan`): + Span of tokens in the encoded sequence. + + TokenSpan are NamedTuple with: + start: index of the first token + end: index of the token following the last token + """ + + if not self._encodings: + raise ValueError("word_to_tokens() is not available when using Python based tokenizers") + if word_index is not None: + batch_index = batch_or_word_index + else: + batch_index = 0 + word_index = batch_or_word_index + if batch_index < 0: + batch_index = self._batch_size + batch_index + if word_index < 0: + word_index = self._seq_len + word_index + return TokenSpan(*(self._encodings[batch_index].word_to_tokens(word_index))) + + def token_to_chars(self, batch_or_token_index: int, token_index: Optional[int] = None) -> CharSpan: + """ Get the character span corresponding to an encoded token in a sequence of the batch. + + Character spans are returned as a CharSpan NamedTuple with: + start: index of the first character in the original string associated to the token + end: index of the character following the last character in the original string associated to the token + + Can be called as: + - self.token_to_chars(token_index) if batch size is 1 + - self.token_to_chars(batch_index, token_index) if batch size is greater or equal to 1 + + Args: + batch_or_token_index (:obj:`int`): + Index of the sequence in the batch. If the batch only comprise one sequence, + this can be the index of the token in the sequence + token_index (:obj:`int`, `optional`): + If a batch index is provided in `batch_or_token_index`, this can be the index + of the token or tokens in the sequence. + + Returns: + char_span (:obj:`CharSpan`): + Span of characters in the original string. + + CharSpan are NamedTuple with: + start: index of the first character in the original string + end: index of the character following the last character in the original string + """ + + if not self._encodings: + raise ValueError("token_to_chars() is not available when using Python based tokenizers") + if token_index is not None: + batch_index = batch_or_token_index + else: + batch_index = 0 + token_index = batch_or_token_index + return CharSpan(*(self._encodings[batch_index].token_to_chars(token_index))) + + def char_to_token(self, batch_or_char_index: int, char_index: Optional[int] = None) -> int: + """ Get the index of the token in the encoded output comprising a character + in the original string for a sequence of the batch. + + Can be called as: + - self.char_to_token(char_index) if batch size is 1 + - self.char_to_token(batch_index, char_index) if batch size is greater or equal to 1 + + This method is particularly suited when the input sequences are provided as + pre-tokenized sequences (i.e. words are defined by the user). In this case it allows + to easily associate encoded tokens with provided tokenized words. + + Args: + batch_or_char_index (:obj:`int`): + Index of the sequence in the batch. If the batch only comprise one sequence, + this can be the index of the word in the sequence + char_index (:obj:`int`, `optional`): + If a batch index is provided in `batch_or_token_index`, this can be the index + of the word in the sequence. + + + Returns: + token_index (:obj:`int`): + Index of the token. """ if not self._encodings: raise ValueError("char_to_token() is not available when using Python based tokenizers") - return self[sentence].char_to_token(char) + if char_index is not None: + batch_index = batch_or_char_index + else: + batch_index = 0 + char_index = batch_or_char_index + return self._encodings[batch_index].char_to_token(char_index) - def char_to_word_offsets(self, sentence: int, char: int) -> Tuple[int, int]: - """ - Find the Offsets of the word containing the character at the specified position + def word_to_chars(self, batch_or_word_index: int, word_index: Optional[int] = None) -> CharSpan: + """ Get the character span in the original string corresponding to given word in a sequence + of the batch. + + Character spans are returned as a CharSpan NamedTuple with: + start: index of the first character in the original string + end: index of the character following the last character in the original string + + Can be called as: + - self.word_to_chars(word_index) if batch size is 1 + - self.word_to_chars(batch_index, word_index) if batch size is greater or equal to 1 Args: - sentence (int): Index of the sentence relative to the batch provided to the tokenizer - char (int): Char index to get the relative token offsets + batch_or_word_index (:obj:`int`): + Index of the sequence in the batch. If the batch only comprise one sequence, + this can be the index of the word in the sequence + word_index (:obj:`int`, `optional`): + If a batch index is provided in `batch_or_token_index`, this can be the index + of the word in the sequence. Returns: - tuple: (word start, word end) representing the first and last characters of the word + char_span (:obj:`CharSpan` or :obj:`List[CharSpan]`): + Span(s) of the associated character or characters in the string. + CharSpan are NamedTuple with: + start: index of the first character associated to the token in the original string + end: index of the character following the last character associated to the token in the original string """ if not self._encodings: - raise ValueError("char_to_word_offsets() is not available when using Python based tokenizers") - return self[sentence].char_to_word_offsets(char) + raise ValueError("word_to_chars() is not available when using Python based tokenizers") + if word_index is not None: + batch_index = batch_or_word_index + else: + batch_index = 0 + word_index = batch_or_word_index + return CharSpan(*(self._encodings[batch_index].word_to_chars(word_index))) - def token_to_word_offsets(self, sentence: int, index: int) -> Optional[Tuple[int, int]]: - """ - Find the Offsets of the word containing the token at the given index + def char_to_word(self, batch_or_char_index: int, char_index: Optional[int] = None) -> int: + """ Get the word in the original string corresponding to a character in the original string of + a sequence of the batch. + + Can be called as: + - self.char_to_word(char_index) if batch size is 1 + - self.char_to_word(batch_index, char_index) if batch size is greater than 1 + + This method is particularly suited when the input sequences are provided as + pre-tokenized sequences (i.e. words are defined by the user). In this case it allows + to easily associate encoded tokens with provided tokenized words. Args: - sentence (int): Index of the sentence relative to the batch provided to the tokenizer - index (int): Index of the token to map to the original word offsets + batch_or_char_index (:obj:`int`): + Index of the sequence in the batch. If the batch only comprise one sequence, + this can be the index of the character in the orginal string. + char_index (:obj:`int`, `optional`): + If a batch index is provided in `batch_or_token_index`, this can be the index + of the character in the orginal string. + Returns: - Optional[tuple]: (word start, word end) or None + token_index (:obj:`int` or :obj:`List[int]`): + Index or indices of the associated encoded token(s). """ if not self._encodings: - raise ValueError("token_to_word_offsets() is not available when using Python based tokenizers") - return self[sentence].token_to_word_offsets(index) + raise ValueError("char_to_word() is not available when using Python based tokenizers") + if char_index is not None: + batch_index = batch_or_char_index + else: + batch_index = 0 + char_index = batch_or_char_index + return self._encodings[batch_index].char_to_word(char_index) class SpecialTokensMixin: + """ SpecialTokensMixin is derived by ``PreTrainedTokenizer`` and ``PreTrainedTokenizerFast`` and + handles specific behaviors related to special tokens. In particular, this class hold the + attributes which can be used to directly access to these special tokens in a + model-independant manner and allow to set and update the special tokens. + """ + SPECIAL_TOKENS_ATTRIBUTES = [ "bos_token", "eos_token", @@ -255,7 +478,6 @@ class SpecialTokensMixin: ] def __init__(self, **kwargs): - self._bos_token = None self._eos_token = None self._unk_token = None @@ -270,13 +492,13 @@ class SpecialTokensMixin: if key in self.SPECIAL_TOKENS_ATTRIBUTES: if key == "additional_special_tokens": assert isinstance(value, (list, tuple)) and all(isinstance(t, str) for t in value) - elif isinstance(value, AddedToken): + elif isinstance(value, AddedTokenFast): setattr(self, key, str(value)) elif isinstance(value, str): setattr(self, key, value) else: raise TypeError( - "special token {} has to be either str or AddedToken but got: {}".format(key, type(value)) + "special token {} has to be either str or AddedTokenFast but got: {}".format(key, type(value)) ) @property @@ -335,33 +557,49 @@ class SpecialTokensMixin: logger.error("Using additional_special_tokens, but it is not set yet.") return self._additional_special_tokens + def _maybe_update_backend(self, value): + """ To be overriden by derived class if a backend tokenizer has to be updated. """ + pass + @bos_token.setter def bos_token(self, value): self._bos_token = value + self._maybe_update_backend([value]) @eos_token.setter def eos_token(self, value): self._eos_token = value + self._maybe_update_backend([value]) @unk_token.setter def unk_token(self, value): self._unk_token = value + self._maybe_update_backend([value]) @sep_token.setter def sep_token(self, value): self._sep_token = value + self._maybe_update_backend([value]) @pad_token.setter def pad_token(self, value): self._pad_token = value + self._maybe_update_backend([value]) @cls_token.setter def cls_token(self, value): self._cls_token = value + self._maybe_update_backend([value]) @mask_token.setter def mask_token(self, value): self._mask_token = value + self._maybe_update_backend([value]) + + @additional_special_tokens.setter + def additional_special_tokens(self, value): + self._additional_special_tokens = value + self._maybe_update_backend(value) @property def bos_token_id(self): @@ -441,50 +679,69 @@ class SpecialTokensMixin: all_ids = self.convert_tokens_to_ids(all_toks) return all_ids - @additional_special_tokens.setter - def additional_special_tokens(self, value): - self._additional_special_tokens = value - class PreTrainedTokenizer(SpecialTokensMixin): """ Base class for all tokenizers. - Handle all the shared methods for tokenization and special tokens as well as methods downloading/caching/loading pretrained tokenizers as well as adding tokens to the vocabulary. - This class also contain the added tokens in a unified way on top of all tokenizers so we don't have to handle the specific vocabulary augmentation methods of the various underlying dictionary structures (BPE, sentencepiece...). + Handle all the shared methods for tokenization and special tokens as well as methods + downloading/caching/loading pretrained tokenizers as well as adding tokens to the vocabulary. + + This class also contain the added tokens in a unified way on top of all tokenizers so we don't + have to handle the specific vocabulary augmentation methods of the various underlying + dictionary structures (BPE, sentencepiece...). Class attributes (overridden by derived classes): - - ``vocab_files_names``: a python ``dict`` with, as keys, the ``__init__`` keyword name of each vocabulary file required by the model, and as associated values, the filename for saving the associated file (string). - - ``pretrained_vocab_files_map``: a python ``dict of dict`` the high-level keys being the ``__init__`` keyword name of each vocabulary file required by the model, the low-level being the `short-cut-names` (string) of the pretrained models with, as associated values, the `url` (string) to the associated pretrained vocabulary file. - - ``max_model_input_sizes``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained models, and as associated values, the maximum length of the sequence inputs of this model, or None if the model has no maximum input size. - - ``pretrained_init_configuration``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained models, and as associated values, a dictionnary of specific arguments to pass to the ``__init__``method of the tokenizer class for this pretrained model when loading the tokenizer with the ``from_pretrained()`` method. + - ``vocab_files_names``: a python ``dict`` with, as keys, the ``__init__`` keyword name of each vocabulary file + required by the model, and as associated values, the filename for saving the associated file (string). + - ``pretrained_vocab_files_map``: a python ``dict of dict`` the high-level keys + being the ``__init__`` keyword name of each vocabulary file required by the model, the low-level being the + `short-cut-names` (string) of the pretrained models with, as associated values, the `url` (string) to the + associated pretrained vocabulary file. + - ``max_model_input_sizes``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained + models, and as associated values, the maximum length of the sequence inputs of this model, or None if the + model has no maximum input size. + - ``pretrained_init_configuration``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the + pretrained models, and as associated values, a dictionnary of specific arguments to pass to the + ``__init__``method of the tokenizer class for this pretrained model when loading the tokenizer with the + ``from_pretrained()`` method. - Parameters: - - - ``bos_token``: (`Optional`) string: a beginning of sentence token. Will be associated to ``self.bos_token`` and ``self.bos_token_id`` - - - ``eos_token``: (`Optional`) string: an end of sentence token. Will be associated to ``self.eos_token`` and ``self.eos_token_id`` - - - ``unk_token``: (`Optional`) string: an unknown token. Will be associated to ``self.unk_token`` and ``self.unk_token_id`` - - - ``sep_token``: (`Optional`) string: a separation token (e.g. to separate context and query in an input sequence). Will be associated to ``self.sep_token`` and ``self.sep_token_id`` - - - ``pad_token``: (`Optional`) string: a padding token. Will be associated to ``self.pad_token`` and ``self.pad_token_id`` - - - ``cls_token``: (`Optional`) string: a classification token (e.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model). Will be associated to ``self.cls_token`` and ``self.cls_token_id`` - - - ``mask_token``: (`Optional`) string: a masking token (e.g. when training a model with masked-language modeling). Will be associated to ``self.mask_token`` and ``self.mask_token_id`` - - - ``additional_special_tokens``: (`Optional`) list: a list of additional special tokens. Adding all special tokens here ensure they won't be split by the tokenization process. Will be associated to ``self.additional_special_tokens`` and ``self.additional_special_tokens_ids`` + Args: + - ``model_max_length``: (`Optional`) int: the maximum length in number of tokens for the inputs to the transformer model. + When the tokenizer is loaded with `from_pretrained`, this will be set to the value stored for the associated + model in ``max_model_input_sizes`` (see above). If no value is provided, will default to VERY_LARGE_INTEGER (`int(1e30)`). + no associated max_length can be found in ``max_model_input_sizes``. + - ``padding_side``: (`Optional`) string: the side on which the model should have padding applied. + Should be selected between ['right', 'left'] + - ``model_input_names``: (`Optional`) List[string]: the list of the forward pass inputs accepted by the + model ("token_type_ids", "attention_mask"...). + - ``bos_token``: (`Optional`) string: a beginning of sentence token. + Will be associated to ``self.bos_token`` and ``self.bos_token_id`` + - ``eos_token``: (`Optional`) string: an end of sentence token. + Will be associated to ``self.eos_token`` and ``self.eos_token_id`` + - ``unk_token``: (`Optional`) string: an unknown token. + Will be associated to ``self.unk_token`` and ``self.unk_token_id`` + - ``sep_token``: (`Optional`) string: a separation token (e.g. to separate context and query in an input sequence). + Will be associated to ``self.sep_token`` and ``self.sep_token_id`` + - ``pad_token``: (`Optional`) string: a padding token. + Will be associated to ``self.pad_token`` and ``self.pad_token_id`` + - ``cls_token``: (`Optional`) string: a classification token (e.g. to extract a summary of an input sequence + leveraging self-attention along the full depth of the model). + Will be associated to ``self.cls_token`` and ``self.cls_token_id`` + - ``mask_token``: (`Optional`) string: a masking token (e.g. when training a model with masked-language + modeling). Will be associated to ``self.mask_token`` and ``self.mask_token_id`` + - ``additional_special_tokens``: (`Optional`) list: a list of additional special tokens. + Adding all special tokens here ensure they won't be split by the tokenization process. + Will be associated to ``self.additional_special_tokens`` and ``self.additional_special_tokens_ids`` """ - vocab_files_names = {} - pretrained_vocab_files_map = {} - pretrained_init_configuration = {} - max_model_input_sizes = {} - model_input_names = ["token_type_ids", "attention_mask"] + vocab_files_names: Dict[str, str] = {} + pretrained_vocab_files_map: Dict[str, Dict[str, str]] = {} + pretrained_init_configuration: Dict[str, Dict[str, Any]] = {} + max_model_input_sizes: Dict[str, int] = {} + model_input_names: List[str] = ["token_type_ids", "attention_mask"] - padding_side = "right" + padding_side: str = "right" NO_PAD_TOKEN_FOR_BATCH_MSG = ( "No padding token is set for this model, therefore no batch can be made with uneven " @@ -507,18 +764,39 @@ class PreTrainedTokenizer(SpecialTokensMixin): def is_fast(self): return False + @property + def max_len(self): + """ Kept here for backward compatibility. + Now renamed to `model_max_length` to avoid ambiguity. + """ + return self.model_max_length + + @property + def max_len_single_sentence(self): + return self.model_max_length - self.num_special_tokens_to_add(pair=False) + + @property + def max_len_sentences_pair(self): + return self.model_max_length - self.num_special_tokens_to_add(pair=True) + def get_vocab(self): """ Returns the vocabulary as a dict of {token: index} pairs. `tokenizer.get_vocab()[token]` is equivalent to `tokenizer.convert_tokens_to_ids(token)` when `token` is in the vocab. """ raise NotImplementedError() - def __init__(self, max_len=None, **kwargs): + def __init__(self, model_max_length=None, **kwargs): super().__init__(**kwargs) - self.max_len = max_len if max_len is not None else int(1e12) + # For backward compatibility we fallback to set model_max_length from max_len if provided + model_max_length = model_max_length if model_max_length is not None else kwargs.pop("max_len", None) + self.model_max_length = model_max_length if model_max_length is not None else VERY_LARGE_INTEGER # Padding side is right by default and overridden in subclasses. If specified in the kwargs, it is changed. self.padding_side = kwargs.pop("padding_side", self.padding_side) + assert self.padding_side in [ + "right", + "left", + ], f"Padding side should be selected between 'right' and 'left', current value: {self.padding_side}" self.model_input_names = kwargs.pop("model_input_names", self.model_input_names) # Added tokens @@ -719,9 +997,9 @@ class PreTrainedTokenizer(SpecialTokensMixin): if pretrained_model_name_or_path in cls.max_model_input_sizes: # if we're using a pretrained model, ensure the tokenizer # wont index sequences longer than the number of positional embeddings - max_len = cls.max_model_input_sizes[pretrained_model_name_or_path] - if max_len is not None and isinstance(max_len, (int, float)): - init_kwargs["max_len"] = min(init_kwargs.get("max_len", int(1e12)), max_len) + model_max_length = cls.max_model_input_sizes[pretrained_model_name_or_path] + if model_max_length is not None and isinstance(model_max_length, (int, float)): + init_kwargs["model_max_length"] = min(init_kwargs.get("model_max_length", int(1e30)), model_max_length) # Merge resolved_vocab_files arguments in init_kwargs. added_tokens_file = resolved_vocab_files.pop("added_tokens_file", None) @@ -769,10 +1047,11 @@ class PreTrainedTokenizer(SpecialTokensMixin): - special-tokens-to-class-attributes-mapping, - tokenizer instantiation positional and keywords inputs (e.g. do_lower_case for Bert). - This won't save modifications other than (added tokens and special token mapping) you may have - applied to the tokenizer after the instantiation (e.g. modifying tokenizer.do_lower_case after creation). + Warning: This won't save modifications you may have applied to the tokenizer after the instantiation + (e.g. modifying tokenizer.do_lower_case after creation). - This method make sure the full tokenizer can then be re-loaded using the :func:`~transformers.PreTrainedTokenizer.from_pretrained` class method. + This method make sure the full tokenizer can then be re-loaded using the + :func:`~transformers.PreTrainedTokenizer.from_pretrained` class method. """ if not os.path.isdir(save_directory): logger.error("Saving directory ({}) should be a directory".format(save_directory)) @@ -807,7 +1086,9 @@ class PreTrainedTokenizer(SpecialTokensMixin): """ Save the tokenizer vocabulary to a directory. This method does *NOT* save added tokens and special token mappings. - Please use :func:`~transformers.PreTrainedTokenizer.save_pretrained` `()` to save the full Tokenizer state if you want to reload it using the :func:`~transformers.PreTrainedTokenizer.from_pretrained` class method. + Please use :func:`~transformers.PreTrainedTokenizer.save_pretrained` `()` to save the full + Tokenizer state if you want to reload it using the :func:`~transformers.PreTrainedTokenizer.from_pretrained` + class method. """ raise NotImplementedError @@ -817,7 +1098,8 @@ class PreTrainedTokenizer(SpecialTokensMixin): vocabulary, they are added to it with indices starting from length of the current vocabulary. Args: - new_tokens: string or list of string. Each string is a token to add. Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assign the index of the ``unk_token`` to them). + new_tokens: string or list of string. Each string is a token to add. Tokens are only added if they are not + already in the vocabulary (tested by checking if the tokenizer assign the index of the ``unk_token`` to them). Returns: Number of tokens added to the vocabulary. @@ -939,14 +1221,14 @@ class PreTrainedTokenizer(SpecialTokensMixin): Take care of added tokens. - text: The sequence to be encoded. - add_prefix_space: Only applies to GPT-2 and RoBERTa tokenizers. When `True`, this ensures that the sequence - begins with an empty space. False by default except for when using RoBERTa with `add_special_tokens=True`. - **kwargs: passed to the `prepare_for_tokenization` preprocessing method. + Args: + text (:obj:`string`): The sequence to be encoded. + **kwargs (:obj: `dict`): Arguments passed to the model-specific `prepare_for_tokenization` preprocessing method. """ all_special_tokens = self.all_special_tokens text = self.prepare_for_tokenization(text, **kwargs) + # TODO: should this be in the base class? def lowercase_text(t): # convert non-special tokens to lowercase escaped_special_toks = [re.escape(s_tok) for s_tok in all_special_tokens] @@ -1014,8 +1296,8 @@ class PreTrainedTokenizer(SpecialTokensMixin): raise NotImplementedError def convert_tokens_to_ids(self, tokens): - """ Converts a single token, or a sequence of tokens, (str) in a single integer id - (resp. a sequence of ids), using the vocabulary. + """ Converts a token string (or a sequence of tokens) in a single integer id + (or a sequence of ids), using the vocabulary. """ if tokens is None: return None @@ -1041,8 +1323,8 @@ class PreTrainedTokenizer(SpecialTokensMixin): def encode( self, - text: TextInput, - text_pair: Optional[TextInput] = None, + text: Union[TextInput, PreTokenizedInput, EncodedInput], + text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None, add_special_tokens: bool = True, max_length: Optional[int] = None, stride: int = 0, @@ -1057,11 +1339,11 @@ class PreTrainedTokenizer(SpecialTokensMixin): Same as doing ``self.convert_tokens_to_ids(self.tokenize(text))``. Args: - text (:obj:`str` or :obj:`List[str]`): + text (:obj:`str`, :obj:`List[str]` or :obj:`List[int]`): The first sequence to be encoded. This can be a string, a list of strings (tokenized string using the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids` method) - text_pair (:obj:`str` or :obj:`List[str]`, `optional`, defaults to :obj:`None`): + text_pair (:obj:`str`, :obj:`List[str]` or :obj:`List[int]`, `optional`, defaults to :obj:`None`): Optional second sequence to be encoded. This can be a string, a list of strings (tokenized string using the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids` method) @@ -1070,7 +1352,8 @@ class PreTrainedTokenizer(SpecialTokensMixin): to their model. max_length (:obj:`int`, `optional`, defaults to :obj:`None`): If set to a number, will limit the total sequence returned so that it has a maximum length. - If there are overflowing tokens, those will be added to the returned dictionary + If there are overflowing tokens, those will be added to the returned dictionary. + You can set it to the maximal input size of the model with `max_length = tokenizer.model_max_length`. stride (:obj:`int`, `optional`, defaults to ``0``): If set to a number along with max_length, the overflowing tokens returned will contain some tokens from the main sequence returned. The value of this argument defines the number of additional tokens. @@ -1112,8 +1395,8 @@ class PreTrainedTokenizer(SpecialTokensMixin): def encode_plus( self, - text: TextInput, - text_pair: Optional[TextInput] = None, + text: Union[TextInput, PreTokenizedInput, EncodedInput], + text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None, add_special_tokens: bool = True, max_length: Optional[int] = None, stride: int = 0, @@ -1133,11 +1416,11 @@ class PreTrainedTokenizer(SpecialTokensMixin): the mask for sequence classification and the overflowing elements if a ``max_length`` is specified. Args: - text (:obj:`str` or :obj:`List[str]`): + text (:obj:`str`, :obj:`List[str]` or :obj:`List[int]` (the later only for not-fast tokenizers)): The first sequence to be encoded. This can be a string, a list of strings (tokenized string using the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids` method) - text_pair (:obj:`str` or :obj:`List[str]`, `optional`, defaults to :obj:`None`): + text_pair (:obj:`str`, :obj:`List[str]` or :obj:`List[int]`, `optional`, defaults to :obj:`None`): Optional second sequence to be encoded. This can be a string, a list of strings (tokenized string using the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids` method) @@ -1147,6 +1430,7 @@ class PreTrainedTokenizer(SpecialTokensMixin): max_length (:obj:`int`, `optional`, defaults to :obj:`None`): If set to a number, will limit the total sequence returned so that it has a maximum length. If there are overflowing tokens, those will be added to the returned dictionary + You can set it to the maximal input size of the model with `max_length = tokenizer.model_max_length`. stride (:obj:`int`, `optional`, defaults to ``0``): If set to a number along with max_length, the overflowing tokens returned will contain some tokens from the main sequence returned. The value of this argument defines the number of additional tokens. @@ -1188,8 +1472,8 @@ class PreTrainedTokenizer(SpecialTokensMixin): Set to True to return special tokens mask information (default False). return_offsets_mapping (:obj:`bool`, `optional`, defaults to :obj:`False`): Set to True to return (char_start, char_end) for each token (default False). - If using Python's tokenizer, this method will raise NotImplementedError. This one is only available on - Rust-based tokenizers inheriting from PreTrainedTokenizerFast. + If using Python's tokenizer, this method will raise NotImplementedError. + This one is only available on fast tokenizers inheriting from PreTrainedTokenizerFast. **kwargs: passed to the `self.tokenize()` method Return: @@ -1201,7 +1485,8 @@ class PreTrainedTokenizer(SpecialTokensMixin): attention_mask: list[int] if return_attention_mask is True (default) overflowing_tokens: list[int] if a ``max_length`` is specified and return_overflowing_tokens is True num_truncated_tokens: int if a ``max_length`` is specified and return_overflowing_tokens is True - special_tokens_mask: list[int] if ``add_special_tokens`` if set to ``True`` and return_special_tokens_mask is True + special_tokens_mask: list[int] if ``add_special_tokens`` if set to ``True`` + and return_special_tokens_mask is True } With the fields: @@ -1240,7 +1525,9 @@ class PreTrainedTokenizer(SpecialTokensMixin): # Throw an error if we can pad because there is no padding token if pad_to_max_length and self.pad_token_id is None: raise ValueError( - "Unable to set proper padding strategy as the tokenizer does not have a padding token. In this case please set the `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` or add a new pad token via the function add_special_tokens if you want to use a padding strategy" + "Unable to set proper padding strategy as the tokenizer does not have a padding token. " + "In this case please set the `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` " + "or add a new pad token via the function add_special_tokens if you want to use a padding strategy" ) first_ids = get_input_ids(text) @@ -1264,7 +1551,12 @@ class PreTrainedTokenizer(SpecialTokensMixin): def batch_encode_plus( self, batch_text_or_text_pairs: Union[ - List[TextInput], List[TextPairInput], List[PreTokenizedInput], List[PreTokenizedInputPair] + List[TextInput], + List[TextInputPair], + List[PreTokenizedInput], + List[PreTokenizedInputPair], + List[EncodedInput], + List[EncodedInputPair], ], add_special_tokens: bool = True, max_length: Optional[int] = None, @@ -1278,7 +1570,7 @@ class PreTrainedTokenizer(SpecialTokensMixin): return_overflowing_tokens: bool = False, return_special_tokens_masks: bool = False, return_offsets_mapping: bool = False, - return_input_lengths: bool = False, + return_lengths: bool = False, **kwargs ) -> BatchEncoding: """ @@ -1286,7 +1578,10 @@ class PreTrainedTokenizer(SpecialTokensMixin): the mask for sequence classification and the overflowing elements if a ``max_length`` is specified. Args: - batch_text_or_text_pairs (:obj:`List[str]` or :obj:`List[List[str]]`): + batch_text_or_text_pairs (:obj:`List[str]`, :obj:`List[Tuple[str, str]]`, + :obj:`List[List[str]]`, :obj:`List[Tuple[List[str], List[str]]]`, + and for not-fast tokenizers, also: + :obj:`List[List[int]]`, :obj:`List[Tuple[List[int], List[int]]]`): Batch of sequences or pair of sequences to be encoded. This can be a list of string/string-sequences/int-sequences or a list of pair of string/string-sequences/int-sequence (see details in encode_plus) @@ -1339,8 +1634,8 @@ class PreTrainedTokenizer(SpecialTokensMixin): Set to True to return (char_start, char_end) for each token (default False). If using Python's tokenizer, this method will raise NotImplementedError. This one is only available on Rust-based tokenizers inheriting from PreTrainedTokenizerFast. - return_input_lengths (:obj:`bool`, `optional`, defaults to :obj:`False`): - If set the resulting dictionary will include the length of each sample + return_lengths (:obj:`bool`, `optional`, defaults to :obj:`False`): + If set the resulting dictionary will include the length of each encoded inputs **kwargs: passed to the `self.tokenize()` method Return: @@ -1434,12 +1729,10 @@ class PreTrainedTokenizer(SpecialTokensMixin): return_token_type_ids=return_token_type_ids, return_overflowing_tokens=return_overflowing_tokens, return_special_tokens_mask=return_special_tokens_masks, + return_lengths=return_lengths, + return_tensors=None, # We will convert the whole batch to tensors at the end ) - # Append the non-padded length to the output - if return_input_lengths: - outputs["input_len"] = len(outputs["input_ids"]) - for key, value in outputs.items(): if key not in batch_outputs: batch_outputs[key] = [] @@ -1493,12 +1786,11 @@ class PreTrainedTokenizer(SpecialTokensMixin): return_attention_mask: Optional[bool] = None, return_overflowing_tokens: bool = False, return_special_tokens_mask: bool = False, - ): - """ - Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. - It adds special tokens, truncates - sequences if overflowing while taking into account the special tokens and manages a window stride for - overflowing tokens + return_lengths: bool = False, + ) -> BatchEncoding: + """ Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. + It adds special tokens, truncates sequences if overflowing while taking into account the special tokens and + manages a moving window (with user defined stride) for overflowing tokens Args: ids: list of tokenized input ids. Can be obtained from a string by chaining the @@ -1508,8 +1800,8 @@ class PreTrainedTokenizer(SpecialTokensMixin): max_length: maximum length of the returned list. Will truncate by taking into account the special tokens. add_special_tokens: if set to ``True``, the sequences will be encoded with the special tokens relative to their model. - stride: window stride for overflowing tokens. Can be useful for edge effect removal when using sequential - list of inputs. + stride: window stride for overflowing tokens. Can be useful to remove edge effect when using sequential + list of inputs. The overflowing token will contains a part of the previous window of tokens. truncation_strategy: string selected in the following options: - 'longest_first' (default) Iteratively reduce the inputs sequence until the input is under max_length starting from the longest one at each token (when there is a pair of input sequences) @@ -1524,10 +1816,12 @@ class PreTrainedTokenizer(SpecialTokensMixin): Defaults to False: no padding. return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant or PyTorch torch.Tensor instead of a list of python integers. - return_token_type_ids: (optional) Set to False to avoid returning token_type_ids (default True). - return_attention_mask: (optional) Set to False to avoid returning attention mask (default True) + return_token_type_ids: (optional) Set to False to avoid returning token_type_ids (default: set to model specifics). + return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics) return_overflowing_tokens: (optional) Set to True to return overflowing token information (default False). return_special_tokens_mask: (optional) Set to True to return special tokens mask information (default False). + return_lengths (:obj:`bool`, `optional`, defaults to :obj:`False`): + If set the resulting dictionary will include the length of each encoded inputs Return: A Dictionary of shape:: @@ -1538,21 +1832,24 @@ class PreTrainedTokenizer(SpecialTokensMixin): overflowing_tokens: list[int] if a ``max_length`` is specified and return_overflowing_tokens is True num_truncated_tokens: int if a ``max_length`` is specified and return_overflowing_tokens is True special_tokens_mask: list[int] if ``add_special_tokens`` if set to ``True`` and return_special_tokens_mask is True + length: int if return_lengths is True } With the fields: - ``input_ids``: list of token ids to be fed to a model - ``token_type_ids``: list of token type ids to be fed to a model + - ``input_ids``: list of token ids to be fed to a model + - ``token_type_ids``: list of token type ids to be fed to a model - ``overflowing_tokens``: list of overflowing tokens if a max length is specified. - ``num_truncated_tokens``: number of overflowing tokens a ``max_length`` is specified - ``special_tokens_mask``: if adding special tokens, this is a list of [0, 1], with 0 specifying special added - tokens and 1 specifying sequence tokens. + - ``overflowing_tokens``: list of overflowing tokens if a max length is specified. + - ``num_truncated_tokens``: number of overflowing tokens a ``max_length`` is specified + - ``special_tokens_mask``: if adding special tokens, this is a list of [0, 1], with 0 specifying special added + tokens and 1 specifying sequence tokens. + - ``length``: this is the length of ``input_ids`` """ pair = bool(pair_ids is not None) len_ids = len(ids) len_pair_ids = len(pair_ids) if pair else 0 + # Load from model defaults if return_token_type_ids is None: return_token_type_ids = "token_type_ids" in self.model_input_names if return_attention_mask is None: @@ -1560,7 +1857,7 @@ class PreTrainedTokenizer(SpecialTokensMixin): encoded_inputs = {} - # Handle max sequence length + # Truncation: Handle max sequence length total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0) if max_length and total_len > max_length: ids, pair_ids, overflowing_tokens = self.truncate_sequences( @@ -1574,7 +1871,7 @@ class PreTrainedTokenizer(SpecialTokensMixin): encoded_inputs["overflowing_tokens"] = overflowing_tokens encoded_inputs["num_truncated_tokens"] = total_len - max_length - # Handle special_tokens + # Add special tokens if add_special_tokens: sequence = self.build_inputs_with_special_tokens(ids, pair_ids) token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids) @@ -1582,46 +1879,43 @@ class PreTrainedTokenizer(SpecialTokensMixin): sequence = ids + pair_ids if pair else ids token_type_ids = [0] * len(ids) + ([1] * len(pair_ids) if pair else []) + # Build output dictionnary + encoded_inputs["input_ids"] = sequence + if return_token_type_ids: + encoded_inputs["token_type_ids"] = token_type_ids if return_special_tokens_mask: if add_special_tokens: encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids) else: encoded_inputs["special_tokens_mask"] = [0] * len(sequence) - encoded_inputs["input_ids"] = sequence - if return_token_type_ids: - encoded_inputs["token_type_ids"] = token_type_ids - - if max_length and len(encoded_inputs["input_ids"]) > max_length: - encoded_inputs["input_ids"] = encoded_inputs["input_ids"][:max_length] - if return_token_type_ids: - encoded_inputs["token_type_ids"] = encoded_inputs["token_type_ids"][:max_length] - if return_special_tokens_mask: - encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"][:max_length] - - if max_length is None and len(encoded_inputs["input_ids"]) > self.max_len: + # Check lengths + assert max_length is None or len(encoded_inputs["input_ids"]) <= max_length + if max_length is None and len(encoded_inputs["input_ids"]) > self.model_max_length: logger.warning( "Token indices sequence length is longer than the specified maximum sequence length " "for this model ({} > {}). Running this sequence through the model will result in " - "indexing errors".format(len(ids), self.max_len) + "indexing errors".format(len(ids), self.model_max_length) ) + # Padding needs_to_be_padded = pad_to_max_length and ( max_length and len(encoded_inputs["input_ids"]) < max_length or max_length is None - and len(encoded_inputs["input_ids"]) < self.max_len - and self.max_len <= 10000 + and len(encoded_inputs["input_ids"]) < self.model_max_length + and self.model_max_length <= LARGE_INTEGER ) - if pad_to_max_length and max_length is None and self.max_len > 10000: + if pad_to_max_length and max_length is None and self.model_max_length > LARGE_INTEGER: logger.warning( "Sequence can't be padded as no maximum length is specified and the model maximum length is too high." ) if needs_to_be_padded: - difference = (max_length if max_length is not None else self.max_len) - len(encoded_inputs["input_ids"]) - + difference = (max_length if max_length is not None else self.model_max_length) - len( + encoded_inputs["input_ids"] + ) if self.padding_side == "right": if return_attention_mask: encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"]) + [0] * difference @@ -1642,14 +1936,16 @@ class PreTrainedTokenizer(SpecialTokensMixin): if return_special_tokens_mask: encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"] encoded_inputs["input_ids"] = [self.pad_token_id] * difference + encoded_inputs["input_ids"] - else: raise ValueError("Invalid padding strategy:" + str(self.padding_side)) + else: + if return_attention_mask: + encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"]) - elif return_attention_mask: - encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"]) + if return_lengths: + encoded_inputs["length"] = len(encoded_inputs["input_ids"]) - # Prepare inputs as tensors if asked + # Prepare model inputs as tensors if asked if return_tensors == "tf" and is_tf_available(): encoded_inputs["input_ids"] = tf.constant([encoded_inputs["input_ids"]]) @@ -1676,14 +1972,27 @@ class PreTrainedTokenizer(SpecialTokensMixin): return BatchEncoding(encoded_inputs) - def prepare_for_tokenization(self, text, **kwargs): + def prepare_for_tokenization(self, text: str, **kwargs) -> str: """ Performs any necessary transformations before tokenization """ return text def truncate_sequences( - self, ids, pair_ids=None, num_tokens_to_remove=0, truncation_strategy="longest_first", stride=0 - ): - """Truncates a sequence pair in place to the maximum length. + self, + ids: List[int], + pair_ids: Optional[List[int]] = None, + num_tokens_to_remove: int = 0, + truncation_strategy: str = "longest_first", + stride: int = 0, + ) -> Tuple[List[int], List[int], List[int]]: + """ Truncates a sequence pair in place to the maximum length. + + Args: + ids: list of tokenized input ids. Can be obtained from a string by chaining the + `tokenize` and `convert_tokens_to_ids` methods. + pair_ids: Optional second list of input ids. Can be obtained from a string by chaining the + `tokenize` and `convert_tokens_to_ids` methods. + num_tokens_to_remove (:obj:`int`, `optional`, defaults to ``0``): + number of tokens to remove using the truncation strategy truncation_strategy: string selected in the following options: - 'longest_first' (default) Iteratively reduce the inputs sequence until the input is under max_length starting from the longest one at each token (when there is a pair of input sequences). @@ -1691,6 +2000,9 @@ class PreTrainedTokenizer(SpecialTokensMixin): - 'only_first': Only truncate the first sequence. raise an error if the first sequence is shorter or equal to than num_tokens_to_remove. - 'only_second': Only truncate the second sequence - 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length) + stride (:obj:`int`, `optional`, defaults to ``0``): + If set to a number along with max_length, the overflowing tokens returned will contain some tokens + from the main sequence returned. The value of this argument defines the number of additional tokens. """ if num_tokens_to_remove <= 0: return ids, pair_ids, [] @@ -1724,12 +2036,12 @@ class PreTrainedTokenizer(SpecialTokensMixin): ) return (ids, pair_ids, overflowing_tokens) - def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None): + def create_token_type_ids_from_sequences(self, token_ids_0: List, token_ids_1: Optional[List] = None) -> List[int]: if token_ids_1 is None: return len(token_ids_0) * [0] return [0] * len(token_ids_0) + [1] * len(token_ids_1) - def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + def build_inputs_with_special_tokens(self, token_ids_0: List, token_ids_1: Optional[List] = None) -> List: """ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and adding special tokens. @@ -1741,7 +2053,9 @@ class PreTrainedTokenizer(SpecialTokensMixin): return token_ids_0 return token_ids_0 + token_ids_1 - def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False): + def get_special_tokens_mask( + self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False + ) -> List[int]: """ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods. @@ -1758,7 +2072,9 @@ class PreTrainedTokenizer(SpecialTokensMixin): """ return [0] * ((len(token_ids_1) if token_ids_1 else 0) + len(token_ids_0)) - def convert_ids_to_tokens(self, ids, skip_special_tokens=False): + def convert_ids_to_tokens( + self, ids: Union[int, List[int]], skip_special_tokens: bool = False + ) -> Union[int, List[int]]: """ Converts a single index or a sequence of indices (integers) in a token " (resp.) a sequence of tokens (str), using the vocabulary and added tokens. @@ -1781,17 +2097,19 @@ class PreTrainedTokenizer(SpecialTokensMixin): tokens.append(self._convert_id_to_token(index)) return tokens - def _convert_id_to_token(self, index): + def _convert_id_to_token(self, index: int) -> str: raise NotImplementedError - def convert_tokens_to_string(self, tokens): + def convert_tokens_to_string(self, tokens: List[str]) -> str: """ Converts a sequence of tokens (string) in a single string. The most simple way to do it is ' '.join(self.convert_ids_to_tokens(token_ids)) but we often want to remove sub-word tokenization artifacts at the same time. """ return " ".join(self.convert_ids_to_tokens(tokens)) - def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True): + def decode( + self, token_ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True + ) -> str: """ Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary with options to remove special tokens and clean up tokenization spaces. @@ -1830,7 +2148,7 @@ class PreTrainedTokenizer(SpecialTokensMixin): return text @staticmethod - def clean_up_tokenization(out_string): + def clean_up_tokenization(out_string: str) -> str: """ Clean up a list of simple English tokenization artifacts like spaces before punctuations and abreviated forms. """ out_string = ( @@ -1850,28 +2168,79 @@ class PreTrainedTokenizer(SpecialTokensMixin): class PreTrainedTokenizerFast(PreTrainedTokenizer): + """ Base class for all fast tokenizers (wrapping HuggingFace tokenizers library). - model_input_names = ["token_type_ids", "attention_mask"] + Inherit from PreTrainedTokenizer. - def __init__(self, tokenizer: BaseTokenizer, **kwargs): - if tokenizer is None: - raise ValueError("Provided tokenizer cannot be None") - self._tokenizer = tokenizer + Handle all the shared methods for tokenization and special tokens as well as methods + downloading/caching/loading pretrained tokenizers as well as adding tokens to the vocabulary. + This class also contain the added tokens in a unified way on top of all tokenizers so we don't + have to handle the specific vocabulary augmentation methods of the various underlying + dictionary structures (BPE, sentencepiece...). + + Class attributes (overridden by derived classes): + + - ``vocab_files_names``: a python ``dict`` with, as keys, the ``__init__`` keyword name of each vocabulary file + required by the model, and as associated values, the filename for saving the associated file (string). + - ``pretrained_vocab_files_map``: a python ``dict of dict`` the high-level keys + being the ``__init__`` keyword name of each vocabulary file required by the model, the low-level being the + `short-cut-names` (string) of the pretrained models with, as associated values, the `url` (string) to the + associated pretrained vocabulary file. + - ``max_model_input_sizes``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained + models, and as associated values, the maximum length of the sequence inputs of this model, or None if the + model has no maximum input size. + - ``pretrained_init_configuration``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the + pretrained models, and as associated values, a dictionnary of specific arguments to pass to the + ``__init__``method of the tokenizer class for this pretrained model when loading the tokenizer with the + ``from_pretrained()`` method. + + Args: + - ``tokenizer`` (`BaseTokenizerFast`): A Fast tokenizer from the HuggingFace tokenizer library (in low level Rust language) + - ``model_max_length``: (`Optional`) int: the maximum length in number of tokens for the inputs to the transformer model. + When the tokenizer is loaded with `from_pretrained`, this will be set to the value stored for the associated + model in ``max_model_input_sizes`` (see above). If no value is provided, will default to VERY_LARGE_INTEGER (`int(1e30)`). + no associated max_length can be found in ``max_model_input_sizes``. + - ``padding_side``: (`Optional`) string: the side on which the model should have padding applied. + Should be selected between ['right', 'left'] + - ``model_input_names``: (`Optional`) List[string]: the list of the forward pass inputs accepted by the + model ("token_type_ids", "attention_mask"...). + - ``bos_token``: (`Optional`) string: a beginning of sentence token. + Will be associated to ``self.bos_token`` and ``self.bos_token_id`` + - ``eos_token``: (`Optional`) string: an end of sentence token. + Will be associated to ``self.eos_token`` and ``self.eos_token_id`` + - ``unk_token``: (`Optional`) string: an unknown token. + Will be associated to ``self.unk_token`` and ``self.unk_token_id`` + - ``sep_token``: (`Optional`) string: a separation token (e.g. to separate context and query in an input sequence). + Will be associated to ``self.sep_token`` and ``self.sep_token_id`` + - ``pad_token``: (`Optional`) string: a padding token. + Will be associated to ``self.pad_token`` and ``self.pad_token_id`` + - ``cls_token``: (`Optional`) string: a classification token (e.g. to extract a summary of an input sequence + leveraging self-attention along the full depth of the model). + Will be associated to ``self.cls_token`` and ``self.cls_token_id`` + - ``mask_token``: (`Optional`) string: a masking token (e.g. when training a model with masked-language + modeling). Will be associated to ``self.mask_token`` and ``self.mask_token_id`` + - ``additional_special_tokens``: (`Optional`) list: a list of additional special tokens. + Adding all special tokens here ensure they won't be split by the tokenization process. + Will be associated to ``self.additional_special_tokens`` and ``self.additional_special_tokens_ids`` + """ + + def __init__(self, tokenizer: BaseTokenizerFast, **kwargs): + if not isinstance(tokenizer, BaseTokenizerFast): + raise ValueError( + "Tokenizer should be an instance of a Tokenizer " "provided by HuggingFace tokenizers library." + ) + self._tokenizer: BaseTokenizerFast = tokenizer + + # Initialize all the rest of the kwargs super().__init__(**kwargs) - self.max_len_single_sentence = self.max_len - self.num_special_tokens_to_add( - False - ) # take into account special tokens - self.max_len_sentences_pair = self.max_len - self.num_special_tokens_to_add( - True - ) # take into account special tokens @property - def tokenizer(self) -> BaseTokenizer: + def backend_tokenizer(self) -> BaseTokenizerFast: return self._tokenizer @property - def decoder(self) -> Decoder: + def decoder(self) -> DecoderFast: return self._tokenizer._tokenizer.decoder @property @@ -1885,56 +2254,30 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer): def __len__(self) -> int: return self._tokenizer.get_vocab_size(with_added_tokens=True) - @PreTrainedTokenizer.bos_token.setter - def bos_token(self, value): - self._bos_token = value - self._tokenizer.add_special_tokens([self._bos_token]) - - @PreTrainedTokenizer.eos_token.setter - def eos_token(self, value): - self._eos_token = value - self._tokenizer.add_special_tokens([self._eos_token]) - - @PreTrainedTokenizer.unk_token.setter - def unk_token(self, value): - self._unk_token = value - self._tokenizer.add_special_tokens([self._unk_token]) - - @PreTrainedTokenizer.sep_token.setter - def sep_token(self, value): - self._sep_token = value - self._tokenizer.add_special_tokens([self._sep_token]) - - @PreTrainedTokenizer.pad_token.setter - def pad_token(self, value): - self._pad_token = value - self._tokenizer.add_special_tokens([self._pad_token]) - - @PreTrainedTokenizer.cls_token.setter - def cls_token(self, value): - self._cls_token = value - self._tokenizer.add_special_tokens([self._cls_token]) - - @PreTrainedTokenizer.mask_token.setter - def mask_token(self, value): - self._mask_token = value - self._tokenizer.add_special_tokens([self._mask_token]) - - @PreTrainedTokenizer.additional_special_tokens.setter - def additional_special_tokens(self, value): - self._additional_special_tokens = value - self._tokenizer.add_special_tokens(self.all_special_tokens) + def _maybe_update_backend(self, value): + """ Update the backend fast tokenizer. + Override method from base class SpecialTokensMixin """ + self._tokenizer.add_special_tokens(value) def _convert_encoding( self, - encoding, - return_tensors=None, - return_token_type_ids=None, - return_attention_mask=None, - return_overflowing_tokens=False, - return_special_tokens_mask=False, - return_offsets_mapping=False, - ): + encoding: EncodingFast, + return_tensors: Optional[bool] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + ) -> Dict[str, Any]: + """ Convert the encoding representation (from low-level HuggingFace tokenizer output) to a python Dict. + + Overflowing tokens are converted to additional examples (like batches) so the output values of + the dict are lists (overflows) of lists (tokens). + + If return_tensors is not None, these lists of lists are converted to 2-D tensors + for input_ids, token_type_ids and attention_mask. + Output shape: (overflows, sequence length) + """ if return_token_type_ids is None: return_token_type_ids = "token_type_ids" in self.model_input_names if return_attention_mask is None: @@ -1958,75 +2301,86 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer): if return_offsets_mapping: encoding_dict["offset_mapping"].append(e.offsets) - # Prepare inputs as tensors if asked - if return_tensors == "tf" and is_tf_available(): - encoding_dict["input_ids"] = tf.constant(encoding_dict["input_ids"]) - if "token_type_ids" in encoding_dict: - encoding_dict["token_type_ids"] = tf.constant(encoding_dict["token_type_ids"]) - - if "attention_mask" in encoding_dict: - encoding_dict["attention_mask"] = tf.constant(encoding_dict["attention_mask"]) - - elif return_tensors == "pt" and is_torch_available(): - encoding_dict["input_ids"] = torch.tensor(encoding_dict["input_ids"]) - if "token_type_ids" in encoding_dict: - encoding_dict["token_type_ids"] = torch.tensor(encoding_dict["token_type_ids"]) - - if "attention_mask" in encoding_dict: - encoding_dict["attention_mask"] = torch.tensor(encoding_dict["attention_mask"]) - elif return_tensors is not None: - logger.warning( - "Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format( - return_tensors - ) - ) + if return_tensors is not None: + for key, value in encoding_dict.items(): + if return_tensors == "tf" and is_tf_available(): + encoding_dict[key] = tf.constant(value) + elif return_tensors == "pt" and is_torch_available(): + encoding_dict[key] = torch.tensor(value) + elif return_tensors is not None: + logger.warning( + "Unable to convert output to tensors format {}, " + "PyTorch or TensorFlow is not available.".format(return_tensors) + ) return encoding_dict - def _convert_token_to_id_with_added_voc(self, token): - id = self._tokenizer.token_to_id(token) - if id is None: + def _convert_token_to_id_with_added_voc(self, token: int) -> str: + index = self._tokenizer.token_to_id(token) + if index is None: return self.unk_token_id - return id + return index - def _convert_id_to_token(self, index: int) -> str: + def _convert_id_to_token(self, index: int) -> Optional[str]: return self._tokenizer.id_to_token(int(index)) def convert_tokens_to_string(self, tokens: List[int], skip_special_tokens: bool = False) -> str: return self._tokenizer.decode(tokens, skip_special_tokens) - def add_tokens(self, new_tokens: List[Union[str, AddedToken]]) -> int: + def add_tokens(self, new_tokens: List[Union[str, AddedTokenFast]]) -> int: + """ + Add a list of new tokens to the tokenizer class. If the new tokens are not in the + vocabulary, they are added to it with indices starting from length of the current vocabulary. + + Args: + new_tokens: string or list of string or AddedTokenFast. Each string is a token to add. + Tokens are only added if they are not already in the vocabulary. AddedTokenFast wrap a string token to let you personnalize it's behavior (Whether this token should only match against single word, whether this token should strip all potential whitespaces on the left side, Whether this token should strip all potential whitespaces on the right side...). + See details for AddedToken in HuggingFace tokenizers library. + + Returns: + Number of tokens added to the vocabulary. + + Examples:: + + # Let's see how to increase the vocabulary of Bert model and tokenizer + tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased') + model = BertModel.from_pretrained('bert-base-uncased') + + num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2']) + print('We have added', num_added_toks, 'tokens') + model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer. + """ if isinstance(new_tokens, str): new_tokens = [new_tokens] return self._tokenizer.add_tokens(new_tokens) def add_special_tokens(self, special_tokens_dict: dict) -> int: - added = super().add_special_tokens(special_tokens_dict) + # Map special tokens to class attributes (self.pad_token...) + num_added_tokens = super().add_special_tokens(special_tokens_dict) + + # If the backend tokenizer the only specificities of special tokens are that + # - they will never be processed by the model, and + # - they will be removed while decoding. + # But they are not mapped to special attributes in the backend so we can just + # send a list. tokens = flatten(special_tokens_dict.values()) self._tokenizer.add_special_tokens(tokens) - return added - def build_inputs_with_special_tokens( - self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None - ) -> List[int]: - if token_ids_1 is None: - return token_ids_0 - else: - return token_ids_0 + token_ids_1 + return num_added_tokens def num_special_tokens_to_add(self, pair: bool = False) -> int: - return self.tokenizer.num_special_tokens_to_add(pair) + return self._tokenizer.num_special_tokens_to_add(pair) def tokenize( self, text: TextInput, pair: Optional[TextInput] = None, add_special_tokens: bool = False ) -> List[str]: - return self.tokenizer.encode(text, pair, add_special_tokens).tokens + return self._tokenizer.encode(text, pair, add_special_tokens).tokens def batch_encode_plus( self, batch_text_or_text_pairs: Union[ - List[TextInput], List[TextPairInput], List[PreTokenizedInput], List[PreTokenizedInputPair] - ] = None, + List[TextInput], List[TextInputPair], List[PreTokenizedInput], List[PreTokenizedInputPair] + ], add_special_tokens: bool = True, max_length: Optional[int] = None, stride: int = 0, @@ -2039,15 +2393,13 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer): return_overflowing_tokens: bool = False, return_special_tokens_mask: bool = False, return_offsets_mapping: bool = False, + return_lengths: bool = False, **kwargs ) -> BatchEncoding: - if batch_text_or_text_pairs is None: + if not isinstance(batch_text_or_text_pairs, list): raise ValueError( - "None is not a valid input. " - "Should be a list/tuple of strings, " - "a list/tuple of integers, " - "A list of list of strings or tuple of strings." + "batch_text_or_text_pairs has to be a list (got {})".format(type(batch_text_or_text_pairs)) ) # Needed if we have to return a tensor @@ -2070,11 +2422,6 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer): pad_token=self._pad_token, ): - if not isinstance(batch_text_or_text_pairs, list): - raise TypeError( - "batch_text_or_text_pairs has to be a list (got {})".format(type(batch_text_or_text_pairs)) - ) - # Check for the pretokenized path if is_pretokenized: encodings = [] @@ -2089,35 +2436,27 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer): "index {} is of type {}".format(i, type(sample)) ) - # Convert to tuple for convenience - if isinstance(sample, list): - sample = (sample,) + # Test if we have a pair of sentences by checking the depth of nesting + is_pair = bool(len(sample) > 0 and isinstance(sample[0], (list, tuple))) - encodings_text = Encoding.merge(self._tokenizer.encode_batch(sample[0], False), True) + # Take care of the first sequence - we multi-thread over the words + encodings_text = EncodingFast.merge( + self._tokenizer.encode_batch(sample[0] if is_pair else sample, add_special_tokens=False), + growing_offsets=True, + ) - # Check if we have pairs - if len(sample) == 2: - encodings_pair = Encoding.merge( - self._tokenizer.encode_batch([("", s) for s in sample[1]], False), True + # Take care of the second sequence if we have a pair + if is_pair: + encodings_pair = EncodingFast.merge( + self._tokenizer.encode_batch([("", s) for s in sample[1]], add_special_tokens=False), + growing_offsets=True, ) - - # No pair, default to None - elif len(sample) == 1: + else: encodings_pair = None - # Something else is invalid - else: - raise ValueError( - "batch_encode_plus(..., is_pretokenized=True) requires batch_text_or_text_pairs " - "to be either List[List[str]] or List[Tuple[List[str], List[str]]] but sample at " - "index {} has too much dimensions (required 1 or 2, got: {}, type {})".format( - i, len(sample), type(sample) - ) - ) - - # Post-process + # Post-process - truncate/pad and add special tokens encoding = self._tokenizer.post_process(encodings_text, encodings_pair, add_special_tokens) - encodings += [encoding] + encodings.append(encoding) # Classical path with strings input else: @@ -2138,6 +2477,8 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer): ) # Convert encoding to dict + # `Tokens` has type: List[Dict[str, List[List[int]]]] or List[Dict[str, 2D-Tensor]] + # with nested dimensions corresponding to batch, overflows, sequence length tokens = [ self._convert_encoding( encoding=encoding, @@ -2154,6 +2495,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer): # Sanitize the output to have dict[list] from list[dict] sanitized = {} for key in tokens[0].keys(): + # To List[List[List[int]]] of shape (batch, overflows, sequence length) stack = [e for item in tokens for e in item[key]] if return_tensors == "tf": stack = tf.stack(stack, axis=0) @@ -2167,9 +2509,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer): # If returning overflowing tokens, we need to return a mapping # from the batch idx to the original sample if return_overflowing_tokens: - overflow_to_sample_mapping = [ - i if len(item["input_ids"]) == 1 else [i] * len(item["input_ids"]) for i, item in enumerate(tokens) - ] + overflow_to_sample_mapping = flatten([[i] * len(enc["input_ids"]) for i, enc in enumerate(tokens)]) sanitized["overflow_to_sample_mapping"] = overflow_to_sample_mapping return BatchEncoding(sanitized, encodings) @@ -2199,7 +2539,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer): # Encode through encode_batch with sequence of only one word which will be merged after hand encoding = self._tokenizer.encode_batch(text, add_special_tokens=False) - encoding = Encoding.merge(encoding, True) + encoding = EncodingFast.merge(encoding, growing_offsets=True) # Let's do the same for pairs if provided if isinstance(text_pair, list): @@ -2207,7 +2547,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer): encoding_pair = self._tokenizer.encode_batch( [("", p) for p in text_pair], add_special_tokens=False ) - encoding_pair = Encoding.merge(encoding_pair, True) + encoding_pair = EncodingFast.merge(encoding_pair, growing_offsets=True) elif text_pair is None: encoding_pair = None else: @@ -2268,8 +2608,8 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer): def decode( self, token_ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True - ): - text = self.tokenizer.decode(token_ids, skip_special_tokens) + ) -> str: + text = self._tokenizer.decode(token_ids, skip_special_tokens) if clean_up_tokenization_spaces: clean_text = self.clean_up_tokenization(text) diff --git a/src/transformers/tokenization_xlm.py b/src/transformers/tokenization_xlm.py index a9b79cec82..b0a5204bfb 100644 --- a/src/transformers/tokenization_xlm.py +++ b/src/transformers/tokenization_xlm.py @@ -629,9 +629,6 @@ class XLMTokenizer(PreTrainedTokenizer): **kwargs, ) - self.max_len_single_sentence = self.max_len - 2 # take into account special tokens - self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens - # cache of sm.MosesPunctNormalizer instance self.cache_moses_punct_normalizer = dict() # cache of sm.MosesTokenizer instance diff --git a/src/transformers/tokenization_xlm_roberta.py b/src/transformers/tokenization_xlm_roberta.py index 2e70ed60a3..f26634410e 100644 --- a/src/transformers/tokenization_xlm_roberta.py +++ b/src/transformers/tokenization_xlm_roberta.py @@ -128,8 +128,6 @@ class XLMRobertaTokenizer(PreTrainedTokenizer): mask_token=mask_token, **kwargs, ) - self.max_len_single_sentence = self.max_len - 2 # take into account special tokens - self.max_len_sentences_pair = self.max_len - 4 # take into account special tokens try: import sentencepiece as spm diff --git a/src/transformers/tokenization_xlnet.py b/src/transformers/tokenization_xlnet.py index 800ef09c99..995b35ed8b 100644 --- a/src/transformers/tokenization_xlnet.py +++ b/src/transformers/tokenization_xlnet.py @@ -138,8 +138,6 @@ class XLNetTokenizer(PreTrainedTokenizer): **kwargs, ) - self.max_len_single_sentence = self.max_len - 2 # take into account special tokens - self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens self._pad_token_type_id = 3 try: diff --git a/templates/adding_a_new_model/tokenization_xxx.py b/templates/adding_a_new_model/tokenization_xxx.py index 667a130a9b..6a96b0ff9d 100644 --- a/templates/adding_a_new_model/tokenization_xxx.py +++ b/templates/adding_a_new_model/tokenization_xxx.py @@ -117,8 +117,6 @@ class XxxTokenizer(PreTrainedTokenizer): mask_token=mask_token, **kwargs, ) - self.max_len_single_sentence = self.max_len - 2 # take into account special tokens - self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens if not os.path.isfile(vocab_file): raise ValueError( diff --git a/tests/test_tokenization_fast.py b/tests/test_tokenization_fast.py index b34552cd03..2439f58e54 100644 --- a/tests/test_tokenization_fast.py +++ b/tests/test_tokenization_fast.py @@ -1,3 +1,4 @@ +import logging import unittest from collections import namedtuple from itertools import takewhile @@ -21,6 +22,10 @@ from transformers.tokenization_roberta import RobertaTokenizerFast from transformers.tokenization_transfo_xl import TransfoXLTokenizerFast +logging.basicConfig(level=logging.INFO) + +logger = logging.getLogger(__name__) + NON_ENGLISH_TAGS = ["chinese", "dutch", "french", "finnish", "german", "multilingual"] Tokenizer = namedtuple("Tokenizer", ["name", "rust_cls", "python_cls", "vocab_key", "filter"]) @@ -83,6 +88,85 @@ class CommonFastTokenizerTest(unittest.TestCase): self.assert_add_tokens(tokenizer_r) self.assert_offsets_mapping(tokenizer_r) self.assert_add_special_tokens(tokenizer_r) + self.assert_alignement_methods(tokenizer_r) + + def assert_alignement_methods(self, tokenizer_r): + words = ["Wonderful", "no", "inspiration", "example", "with", "subtoken"] + text = " ".join(words) + batch_size = 3 + + encoding = tokenizer_r.encode_plus(text, add_special_tokens=False) + + batch_encoding = tokenizer_r.batch_encode_plus([text] * batch_size, add_special_tokens=False) + num_tokens = len(encoding["input_ids"]) + + last_word_index = len(words) - 1 + last_token_index = num_tokens - 1 + last_batch_index = batch_size - 1 + last_char_index = len(text) - 1 + + # words, tokens + self.assertEqual(len(encoding.words(0)), num_tokens) + self.assertEqual(max(encoding.words(0)), last_word_index) + self.assertEqual(min(encoding.words(0)), 0) + self.assertEqual(len(batch_encoding.words(last_batch_index)), num_tokens) + self.assertEqual(max(batch_encoding.words(last_batch_index)), last_word_index) + self.assertEqual(min(batch_encoding.words(last_batch_index)), 0) + self.assertEqual(len(encoding.tokens(0)), num_tokens) + + # Assert token_to_word + self.assertEqual(encoding.token_to_word(0), 0) + self.assertEqual(encoding.token_to_word(0, 0), 0) + self.assertEqual(encoding.token_to_word(last_token_index), last_word_index) + self.assertEqual(encoding.token_to_word(0, last_token_index), last_word_index) + self.assertEqual(batch_encoding.token_to_word(1, 0), 0) + self.assertEqual(batch_encoding.token_to_word(0, last_token_index), last_word_index) + self.assertEqual(batch_encoding.token_to_word(last_batch_index, last_token_index), last_word_index) + + # Assert word_to_tokens + self.assertEqual(encoding.word_to_tokens(0).start, 0) + self.assertEqual(encoding.word_to_tokens(0, 0).start, 0) + self.assertEqual(encoding.word_to_tokens(last_word_index).end, last_token_index + 1) + self.assertEqual(encoding.word_to_tokens(0, last_word_index).end, last_token_index + 1) + self.assertEqual(batch_encoding.word_to_tokens(1, 0).start, 0) + self.assertEqual(batch_encoding.word_to_tokens(0, last_word_index).end, last_token_index + 1) + self.assertEqual(batch_encoding.word_to_tokens(last_batch_index, last_word_index).end, last_token_index + 1) + + # Assert token_to_chars + self.assertEqual(encoding.token_to_chars(0).start, 0) + self.assertEqual(encoding.token_to_chars(0, 0).start, 0) + self.assertEqual(encoding.token_to_chars(last_token_index).end, last_char_index + 1) + self.assertEqual(encoding.token_to_chars(0, last_token_index).end, last_char_index + 1) + self.assertEqual(batch_encoding.token_to_chars(1, 0).start, 0) + self.assertEqual(batch_encoding.token_to_chars(0, last_token_index).end, last_char_index + 1) + self.assertEqual(batch_encoding.token_to_chars(last_batch_index, last_token_index).end, last_char_index + 1) + + # Assert char_to_token + self.assertEqual(encoding.char_to_token(0), 0) + self.assertEqual(encoding.char_to_token(0, 0), 0) + self.assertEqual(encoding.char_to_token(last_char_index), last_token_index) + self.assertEqual(encoding.char_to_token(0, last_char_index), last_token_index) + self.assertEqual(batch_encoding.char_to_token(1, 0), 0) + self.assertEqual(batch_encoding.char_to_token(0, last_char_index), last_token_index) + self.assertEqual(batch_encoding.char_to_token(last_batch_index, last_char_index), last_token_index) + + # Assert char_to_word + self.assertEqual(encoding.char_to_word(0), 0) + self.assertEqual(encoding.char_to_word(0, 0), 0) + self.assertEqual(encoding.char_to_word(last_char_index), last_word_index) + self.assertEqual(encoding.char_to_word(0, last_char_index), last_word_index) + self.assertEqual(batch_encoding.char_to_word(1, 0), 0) + self.assertEqual(batch_encoding.char_to_word(0, last_char_index), last_word_index) + self.assertEqual(batch_encoding.char_to_word(last_batch_index, last_char_index), last_word_index) + + # Assert word_to_chars + self.assertEqual(encoding.word_to_chars(0).start, 0) + self.assertEqual(encoding.word_to_chars(0, 0).start, 0) + self.assertEqual(encoding.word_to_chars(last_word_index).end, last_char_index + 1) + self.assertEqual(encoding.word_to_chars(0, last_word_index).end, last_char_index + 1) + self.assertEqual(batch_encoding.word_to_chars(1, 0).start, 0) + self.assertEqual(batch_encoding.word_to_chars(0, last_word_index).end, last_char_index + 1) + self.assertEqual(batch_encoding.word_to_chars(last_batch_index, last_word_index).end, last_char_index + 1) def assert_tokenization_python_rust_equals(self, tokenizer_p, tokenizer_r): # Ensure basic input match @@ -306,7 +390,6 @@ class CommonFastTokenizerTest(unittest.TestCase): self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"]) # Simple input - # TODO: Re-enable this test when batch_encode_plus with padding correctly handles padding input_r = tokenizer_r.batch_encode_plus( ["This is a simple input 1", "This is a simple input 2"], max_length=max_length, pad_to_max_length=True ) @@ -316,7 +399,6 @@ class CommonFastTokenizerTest(unittest.TestCase): assert_batch_padded_input_match(input_r, input_p) # Pair input - # TODO: Re-enable this test when batch_encode_plus with padding correctly handles padding input_r = tokenizer_r.batch_encode_plus( [ ("This is a simple input 1", "This is a simple input 2"),