is_pretokenized -> is_split_into_words (#7236)

* is_pretokenized -> is_split_into_words

* Fix tests
Sylvain Gugger 2020-09-22 09:34:35 -04:00 committed by GitHub
parent 324f361e91
commit 21ca148090
9 changed files with 142 additions and 72 deletions

docs/source/custom_datasets.rst

@@ -324,7 +324,7 @@ which we'll use in a moment:
id2tag = {id: tag for tag, id in tag2id.items()}
To encode the tokens, we'll use a pre-trained DistilBert tokenizer. We can tell the tokenizer that we're dealing
-with ready-split tokens rather than full sentence strings by passing ``is_pretokenized=True``. We'll also pass
+with ready-split tokens rather than full sentence strings by passing ``is_split_into_words=True``. We'll also pass
``padding=True`` and ``truncation=True`` to pad the sequences to be the same length. Lastly, we can tell the model
to return information about the tokens which are split by the wordpiece tokenization process, which we will need in
a moment.
@@ -333,8 +333,8 @@ a moment.
from transformers import DistilBertTokenizerFast
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-cased')
-train_encodings = tokenizer(train_texts, is_pretokenized=True, return_offsets_mapping=True, padding=True, truncation=True)
-val_encodings = tokenizer(val_texts, is_pretokenized=True, return_offsets_mapping=True, padding=True, truncation=True)
+train_encodings = tokenizer(train_texts, is_split_into_words=True, return_offsets_mapping=True, padding=True, truncation=True)
+val_encodings = tokenizer(val_texts, is_split_into_words=True, return_offsets_mapping=True, padding=True, truncation=True)
Great, so now our tokens are nicely encoded in the format that they need to be in to feed them into our DistilBert
model below.
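A minimal sketch of the renamed call (the token batch and printed field are illustrative, not the tutorial's WNUT data):
.. code-block:: python

    from transformers import DistilBertTokenizerFast

    tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-cased')
    texts = [["Hugging", "Face", "is", "based", "in", "NYC"]]
    encodings = tokenizer(texts, is_split_into_words=True, return_offsets_mapping=True, padding=True, truncation=True)
    # Each offset is a (start, end) character span inside the original word; a start of 0
    # on a non-special token marks the first wordpiece, which is where a label would go.
    print(encodings["offset_mapping"][0])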

docs/source/preprocessing.rst

@@ -290,12 +290,12 @@ predictions in `named entity recognition (NER) <https://en.wikipedia.org/wiki/Na
if that was the case) but just split into words (which is often the first step in subword tokenization algorithms
like BPE).
-If you want to use pre-tokenized inputs, just set :obj:`is_pretokenized=True` when passing your inputs to the
+If you want to use pre-tokenized inputs, just set :obj:`is_split_into_words=True` when passing your inputs to the
tokenizer. For instance, we have:
.. code-block::
->>> encoded_input = tokenizer(["Hello", "I'm", "a", "single", "sentence"], is_pretokenized=True)
+>>> encoded_input = tokenizer(["Hello", "I'm", "a", "single", "sentence"], is_split_into_words=True)
>>> print(encoded_input)
{'input_ids': [101, 8667, 146, 112, 182, 170, 1423, 5650, 102],
'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0],
@@ -312,7 +312,7 @@ like this:
batch_sentences = [["Hello", "I'm", "a", "single", "sentence"],
["And", "another", "sentence"],
["And", "the", "very", "very", "last", "one"]]
-encoded_inputs = tokenizer(batch_sentences, is_pretokenized=True)
+encoded_inputs = tokenizer(batch_sentences, is_split_into_words=True)
or a batch of pair sentences like this:
@@ -321,7 +321,7 @@ or a batch of pair sentences like this:
batch_of_second_sentences = [["I'm", "a", "sentence", "that", "goes", "with", "the", "first", "sentence"],
["And", "I", "should", "be", "encoded", "with", "the", "second", "sentence"],
["And", "I", "go", "with", "the", "very", "last", "one"]]
-encoded_inputs = tokenizer(batch_sentences, batch_of_second_sentences, is_pretokenized=True)
+encoded_inputs = tokenizer(batch_sentences, batch_of_second_sentences, is_split_into_words=True)
And you can add padding, truncation as well as directly return tensors like before:
@@ -330,14 +330,14 @@ And you can add padding, truncation as well as directly return tensors like befo
## PYTORCH CODE
batch = tokenizer(batch_sentences,
batch_of_second_sentences,
-is_pretokenized=True,
+is_split_into_words=True,
padding=True,
truncation=True,
return_tensors="pt")
## TENSORFLOW CODE
batch = tokenizer(batch_sentences,
batch_of_second_sentences,
-is_pretokenized=True,
+is_split_into_words=True,
padding=True,
truncation=True,
return_tensors="tf")

src/transformers/tokenization_gpt2.py

@@ -17,6 +17,7 @@
import json
import os
+import warnings
from functools import lru_cache
import regex as re
@@ -121,7 +122,7 @@ class GPT2Tokenizer(PreTrainedTokenizer):
.. note::
-When used with ``is_pretokenized=True``, this tokenizer will add a space before each word (even the first one).
+When used with ``is_split_into_words=True``, this tokenizer will add a space before each word (even the first one).
This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
should refer to the superclass for more information regarding methods.
@@ -288,9 +289,16 @@ class GPT2Tokenizer(PreTrainedTokenizer):
return vocab_file, merge_file
-def prepare_for_tokenization(self, text, is_pretokenized=False, **kwargs):
+def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
+if "is_pretokenized" in kwargs:
+warnings.warn(
+"`is_pretokenized` is deprecated and will be removed in a future version, use `is_split_into_words` instead.",
+FutureWarning,
+)
+is_split_into_words = kwargs.pop("is_pretokenized")
add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
-if is_pretokenized or add_prefix_space:
+if is_split_into_words or add_prefix_space:
text = " " + text
return (text, kwargs)
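The same accept-old-name shim recurs in every override this commit touches; a standalone sketch of the pattern (hypothetical helper name, not from the commit):
.. code-block:: python

    import warnings

    def normalize_tokenizer_kwargs(kwargs):
        # Accept the deprecated keyword, warn, and forward its value under the new name.
        if "is_pretokenized" in kwargs:
            warnings.warn(
                "`is_pretokenized` is deprecated and will be removed in a future version, "
                "use `is_split_into_words` instead.",
                FutureWarning,
            )
            kwargs["is_split_into_words"] = kwargs.pop("is_pretokenized")
        return kwargs

    print(normalize_tokenizer_kwargs({"is_pretokenized": True}))
    # {'is_split_into_words': True}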
@@ -317,7 +325,7 @@ class GPT2TokenizerFast(PreTrainedTokenizerFast):
.. note::
-When used with ``is_pretokenized=True``, this tokenizer needs to be instantiated with
+When used with ``is_split_into_words=True``, this tokenizer needs to be instantiated with
``add_prefix_space=True``.
This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
@@ -377,9 +385,15 @@ class GPT2TokenizerFast(PreTrainedTokenizerFast):
self.add_prefix_space = add_prefix_space
def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding:
+if "is_pretokenized" in kwargs:
+warnings.warn(
+"`is_pretokenized` is deprecated and will be removed in a future version, use `is_split_into_words` instead.",
+FutureWarning,
+)
+kwargs["is_split_into_words"] = kwargs.pop("is_pretokenized")
-is_pretokenized = kwargs.get("is_pretokenized", False)
-assert self.add_prefix_space or not is_pretokenized, (
+is_split_into_words = kwargs.get("is_split_into_words", False)
+assert self.add_prefix_space or not is_split_into_words, (
f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
"to use it with pretokenized inputs."
)
@@ -387,9 +401,15 @@ class GPT2TokenizerFast(PreTrainedTokenizerFast):
return super()._batch_encode_plus(*args, **kwargs)
def _encode_plus(self, *args, **kwargs) -> BatchEncoding:
+if "is_pretokenized" in kwargs:
+warnings.warn(
+"`is_pretokenized` is deprecated and will be removed in a future version, use `is_split_into_words` instead.",
+FutureWarning,
+)
+kwargs["is_split_into_words"] = kwargs.pop("is_pretokenized")
-is_pretokenized = kwargs.get("is_pretokenized", False)
-assert self.add_prefix_space or not is_pretokenized, (
+is_split_into_words = kwargs.get("is_split_into_words", False)
+assert self.add_prefix_space or not is_split_into_words, (
f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
"to use it with pretokenized inputs."
)
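A short usage sketch of the constraint this assert enforces (assumes the standard gpt2 checkpoint):
.. code-block:: python

    from transformers import GPT2TokenizerFast

    # Pre-split input is only allowed when the tokenizer was built with add_prefix_space=True.
    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", add_prefix_space=True)
    enc = tokenizer(["Hello", "world"], is_split_into_words=True)
    print(enc["input_ids"])
    # The same call on a tokenizer created without add_prefix_space=True trips the assert above.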

src/transformers/tokenization_roberta.py

@@ -14,7 +14,7 @@
# limitations under the License.
"""Tokenization classes for RoBERTa."""
+import warnings
from typing import List, Optional
from tokenizers.processors import RobertaProcessing
@@ -81,7 +81,7 @@ class RobertaTokenizer(GPT2Tokenizer):
.. note::
-When used with ``is_pretokenized=True``, this tokenizer will add a space before each word (even the first one).
+When used with ``is_split_into_words=True``, this tokenizer will add a space before each word (even the first one).
This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
should refer to the superclass for more information regarding methods.
@@ -251,9 +251,16 @@ class RobertaTokenizer(GPT2Tokenizer):
return len(cls + token_ids_0 + sep) * [0]
return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
-def prepare_for_tokenization(self, text, is_pretokenized=False, **kwargs):
+def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
+if "is_pretokenized" in kwargs:
+warnings.warn(
+"`is_pretokenized` is deprecated and will be removed in a future version, use `is_split_into_words` instead.",
+FutureWarning,
+)
+is_split_into_words = kwargs.pop("is_pretokenized")
add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
-if (is_pretokenized or add_prefix_space) and (len(text) > 0 and not text[0].isspace()):
+if (is_split_into_words or add_prefix_space) and (len(text) > 0 and not text[0].isspace()):
text = " " + text
return (text, kwargs)
text = " " + text
return (text, kwargs)
@@ -280,7 +287,7 @@ class RobertaTokenizerFast(GPT2TokenizerFast):
.. note::
-When used with ``is_pretokenized=True``, this tokenizer needs to be instantiated with
+When used with ``is_split_into_words=True``, this tokenizer needs to be instantiated with
``add_prefix_space=True``.
This tokenizer inherits from :class:`~transformers.PreTrainedTokenizerFast` which contains most of the methods. Users

src/transformers/tokenization_utils.py

@@ -19,6 +19,7 @@
import itertools
import re
import unicodedata
+import warnings
from typing import Any, Dict, List, Optional, Tuple, Union, overload
from .file_utils import add_end_docstrings
@@ -250,6 +251,12 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
Returns:
:obj:`List[str]`: The list of tokens.
"""
+if "is_pretokenized" in kwargs:
+warnings.warn(
+"`is_pretokenized` is deprecated and will be removed in a future version, use `is_split_into_words` instead.",
+FutureWarning,
+)
+kwargs["is_split_into_words"] = kwargs.pop("is_pretokenized")
# Simple mapping string => AddedToken for special tokens with specific tokenization behaviors
all_special_tokens_extended = dict(
(str(t), t) for t in self.all_special_tokens_extended if isinstance(t, AddedToken)
@@ -402,7 +409,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
max_length: Optional[int] = None,
stride: int = 0,
-is_pretokenized: bool = False,
+is_split_into_words: bool = False,
pad_to_multiple_of: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
@@ -419,17 +426,19 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
tokens = self.tokenize(text, **kwargs)
return self.convert_tokens_to_ids(tokens)
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
-if is_pretokenized:
-tokens = list(itertools.chain(*(self.tokenize(t, is_pretokenized=True, **kwargs) for t in text)))
+if is_split_into_words:
+tokens = list(
+itertools.chain(*(self.tokenize(t, is_split_into_words=True, **kwargs) for t in text))
+)
return self.convert_tokens_to_ids(tokens)
else:
return self.convert_tokens_to_ids(text)
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
return text
else:
-if is_pretokenized:
+if is_split_into_words:
raise ValueError(
-f"Input {text} is not valid. Should be a string or a list/tuple of strings when `is_pretokenized=True`."
+f"Input {text} is not valid. Should be a string or a list/tuple of strings when `is_split_into_words=True`."
)
else:
raise ValueError(
@@ -445,6 +454,13 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
"https://github.com/huggingface/transformers/pull/2674"
)
+if "is_pretokenized" in kwargs:
+warnings.warn(
+"`is_pretokenized` is deprecated and will be removed in a future version, use `is_split_into_words` instead.",
+FutureWarning,
+)
+is_split_into_words = kwargs.pop("is_pretokenized")
first_ids = get_input_ids(text)
second_ids = get_input_ids(text_pair) if text_pair is not None else None
@@ -482,7 +498,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
max_length: Optional[int] = None,
stride: int = 0,
-is_pretokenized: bool = False,
+is_split_into_words: bool = False,
pad_to_multiple_of: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
@@ -499,8 +515,10 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
tokens = self.tokenize(text, **kwargs)
return self.convert_tokens_to_ids(tokens)
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
-if is_pretokenized:
-tokens = list(itertools.chain(*(self.tokenize(t, is_pretokenized=True, **kwargs) for t in text)))
+if is_split_into_words:
+tokens = list(
+itertools.chain(*(self.tokenize(t, is_split_into_words=True, **kwargs) for t in text))
+)
return self.convert_tokens_to_ids(tokens)
else:
return self.convert_tokens_to_ids(text)
@@ -518,11 +536,18 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
"transformers.PreTrainedTokenizerFast."
)
+if "is_pretokenized" in kwargs:
+warnings.warn(
+"`is_pretokenized` is deprecated and will be removed in a future version, use `is_split_into_words` instead.",
+FutureWarning,
+)
+is_split_into_words = kwargs.pop("is_pretokenized")
input_ids = []
for ids_or_pair_ids in batch_text_or_text_pairs:
if not isinstance(ids_or_pair_ids, (list, tuple)):
ids, pair_ids = ids_or_pair_ids, None
-elif is_pretokenized and not isinstance(ids_or_pair_ids[0], (list, tuple)):
+elif is_split_into_words and not isinstance(ids_or_pair_ids[0], (list, tuple)):
ids, pair_ids = ids_or_pair_ids, None
else:
ids, pair_ids = ids_or_pair_ids
@@ -616,7 +641,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
return batch_outputs
def prepare_for_tokenization(
-self, text: str, is_pretokenized: bool = False, **kwargs
+self, text: str, is_split_into_words: bool = False, **kwargs
) -> Tuple[str, Dict[str, Any]]:
"""
Performs any necessary transformations before tokenization.
@@ -627,7 +652,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
Args:
text (:obj:`str`):
The text to prepare.
-is_pretokenized (:obj:`bool`, `optional`, defaults to :obj:`False`):
+is_split_into_words (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not the text has been pretokenized.
kwargs:
Keyword arguments to use for the tokenization.

src/transformers/tokenization_utils_base.py

@@ -1088,7 +1088,7 @@ ENCODE_KWARGS_DOCSTRING = r"""
:obj:`return_overflowing_tokens=True` will contain some tokens from the end of the truncated sequence
returned to provide some overlap between truncated and overflowing sequences. The value of this
argument defines the number of overlapping tokens.
-is_pretokenized (:obj:`bool`, `optional`, defaults to :obj:`False`):
+is_split_into_words (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not the input is already pre-tokenized (e.g., split into words), in which case the tokenizer
will skip the pre-tokenization step. This is useful for NER or token classification.
pad_to_multiple_of (:obj:`int`, `optional`):
@@ -1863,7 +1863,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
truncation: Union[bool, str, TruncationStrategy] = False,
max_length: Optional[int] = None,
stride: int = 0,
-is_pretokenized: bool = False,
+is_split_into_words: bool = False,
pad_to_multiple_of: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
@@ -1884,12 +1884,12 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
The sequence or batch of sequences to be encoded.
Each sequence can be a string or a list of strings (pretokenized string).
If the sequences are provided as list of strings (pretokenized), you must set
-:obj:`is_pretokenized=True` (to lift the ambiguity with a batch of sequences).
+:obj:`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
text_pair (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
The sequence or batch of sequences to be encoded.
Each sequence can be a string or a list of strings (pretokenized string).
If the sequences are provided as list of strings (pretokenized), you must set
-:obj:`is_pretokenized=True` (to lift the ambiguity with a batch of sequences).
+:obj:`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
"""
# Input type checking for clearer error
assert isinstance(text, str) or (
@@ -1928,8 +1928,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
)
is_batched = bool(
-(not is_pretokenized and isinstance(text, (list, tuple)))
-or (is_pretokenized and isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple)))
+(not is_split_into_words and isinstance(text, (list, tuple)))
+or (
+is_split_into_words and isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple))
+)
)
if is_batched:
@@ -1941,7 +1943,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
truncation=truncation,
max_length=max_length,
stride=stride,
-is_pretokenized=is_pretokenized,
+is_split_into_words=is_split_into_words,
pad_to_multiple_of=pad_to_multiple_of,
return_tensors=return_tensors,
return_token_type_ids=return_token_type_ids,
@@ -1962,7 +1964,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
truncation=truncation,
max_length=max_length,
stride=stride,
-is_pretokenized=is_pretokenized,
+is_split_into_words=is_split_into_words,
pad_to_multiple_of=pad_to_multiple_of,
return_tensors=return_tensors,
return_token_type_ids=return_token_type_ids,
@@ -1985,7 +1987,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
truncation: Union[bool, str, TruncationStrategy] = False,
max_length: Optional[int] = None,
stride: int = 0,
-is_pretokenized: bool = False,
+is_split_into_words: bool = False,
pad_to_multiple_of: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
@@ -2032,7 +2034,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
truncation_strategy=truncation_strategy,
max_length=max_length,
stride=stride,
-is_pretokenized=is_pretokenized,
+is_split_into_words=is_split_into_words,
pad_to_multiple_of=pad_to_multiple_of,
return_tensors=return_tensors,
return_token_type_ids=return_token_type_ids,
@@ -2054,7 +2056,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
max_length: Optional[int] = None,
stride: int = 0,
-is_pretokenized: bool = False,
+is_split_into_words: bool = False,
pad_to_multiple_of: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
@@ -2084,7 +2086,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
truncation: Union[bool, str, TruncationStrategy] = False,
max_length: Optional[int] = None,
stride: int = 0,
-is_pretokenized: bool = False,
+is_split_into_words: bool = False,
pad_to_multiple_of: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
@@ -2126,7 +2128,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
truncation_strategy=truncation_strategy,
max_length=max_length,
stride=stride,
-is_pretokenized=is_pretokenized,
+is_split_into_words=is_split_into_words,
pad_to_multiple_of=pad_to_multiple_of,
return_tensors=return_tensors,
return_token_type_ids=return_token_type_ids,
@@ -2154,7 +2156,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
max_length: Optional[int] = None,
stride: int = 0,
-is_pretokenized: bool = False,
+is_split_into_words: bool = False,
pad_to_multiple_of: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,

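The is_batched expression above is the disambiguation rule the docstring mentions; a standalone sketch of the same logic (hypothetical function name):
.. code-block:: python

    def looks_batched(text, is_split_into_words):
        # Mirrors the is_batched test: with is_split_into_words=True, a flat list of
        # strings is ONE pre-split sequence; only a list of lists is a batch.
        return bool(
            (not is_split_into_words and isinstance(text, (list, tuple)))
            or (is_split_into_words and isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple)))
        )

    print(looks_batched(["Hello", "world"], True))    # False: one pre-split sentence
    print(looks_batched([["Hello", "world"]], True))  # True: a batch of pre-split sentences
    print(looks_batched(["Hello", "world"], False))   # True: a batch of two plain strings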
src/transformers/tokenization_utils_fast.py

@@ -17,6 +17,7 @@
"""
import os
+import warnings
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple, Union
@@ -328,7 +329,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
max_length: Optional[int] = None,
stride: int = 0,
-is_pretokenized: bool = False,
+is_split_into_words: bool = False,
pad_to_multiple_of: Optional[int] = None,
return_tensors: Optional[str] = None,
return_token_type_ids: Optional[bool] = None,
@@ -346,6 +347,13 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
"batch_text_or_text_pairs has to be a list (got {})".format(type(batch_text_or_text_pairs))
)
+if "is_pretokenized" in kwargs:
+warnings.warn(
+"`is_pretokenized` is deprecated and will be removed in a future version, use `is_split_into_words` instead.",
+FutureWarning,
+)
+is_split_into_words = kwargs.pop("is_pretokenized")
if kwargs:
raise ValueError(f"Keyword arguments {kwargs} not recognized.")
@@ -365,19 +373,21 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
encodings = self._tokenizer.encode(
*batch_text_or_text_pairs[0],
add_special_tokens=add_special_tokens,
-is_pretokenized=is_pretokenized,
+is_pretokenized=is_split_into_words,
)
else:
# We got a single sequence
encodings = self._tokenizer.encode(
batch_text_or_text_pairs[0],
add_special_tokens=add_special_tokens,
-is_pretokenized=is_pretokenized,
+is_pretokenized=is_split_into_words,
)
encodings = [encodings]
else:
encodings = self._tokenizer.encode_batch(
-batch_text_or_text_pairs, add_special_tokens=add_special_tokens, is_pretokenized=is_pretokenized
+batch_text_or_text_pairs,
+add_special_tokens=add_special_tokens,
+is_pretokenized=is_split_into_words,
)
# Convert encoding to dict
@@ -423,7 +433,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
max_length: Optional[int] = None,
stride: int = 0,
-is_pretokenized: bool = False,
+is_split_into_words: bool = False,
pad_to_multiple_of: Optional[int] = None,
return_tensors: Optional[bool] = None,
return_token_type_ids: Optional[bool] = None,
@@ -435,11 +445,17 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
verbose: bool = True,
**kwargs
) -> BatchEncoding:
+if "is_pretokenized" in kwargs:
+warnings.warn(
+"`is_pretokenized` is deprecated and will be removed in a future version, use `is_split_into_words` instead.",
+FutureWarning,
+)
+is_split_into_words = kwargs.pop("is_pretokenized")
batched_input = [(text, text_pair)] if text_pair else [text]
batched_output = self._batch_encode_plus(
batched_input,
-is_pretokenized=is_pretokenized,
+is_split_into_words=is_split_into_words,
add_special_tokens=add_special_tokens,
padding_strategy=padding_strategy,
truncation_strategy=truncation_strategy,

tests/test_tokenization_common.py

@@ -743,7 +743,7 @@ class TokenizerTesterMixin:
# formatted_input = tokenizer.encode(sequence, add_special_tokens=True, add_prefix_space=False)
# self.assertEqual(
-# tokenizer.encode(tokens, is_pretokenized=True, add_special_tokens=True), formatted_input
+# tokenizer.encode(tokens, is_split_into_words=True, add_special_tokens=True), formatted_input
# )
# # This is not supported with the Rust tokenizers
# # self.assertEqual(tokenizer.encode(input_ids, add_special_tokens=True), formatted_input)
@@ -1250,20 +1250,20 @@ class TokenizerTesterMixin:
# sequence_no_prefix_space = sequence.strip()
# Test encode for pretokenized inputs
-output = tokenizer.encode(token_sequence, is_pretokenized=True, add_special_tokens=False)
+output = tokenizer.encode(token_sequence, is_split_into_words=True, add_special_tokens=False)
output_sequence = tokenizer.encode(sequence, add_special_tokens=False)
self.assertEqual(output, output_sequence)
-output = tokenizer.encode(token_sequence, is_pretokenized=True, add_special_tokens=True)
+output = tokenizer.encode(token_sequence, is_split_into_words=True, add_special_tokens=True)
output_sequence = tokenizer.encode(sequence, add_special_tokens=True)
self.assertEqual(output, output_sequence)
# Test encode_plus for pretokenized inputs
-output = tokenizer.encode_plus(token_sequence, is_pretokenized=True, add_special_tokens=False)
+output = tokenizer.encode_plus(token_sequence, is_split_into_words=True, add_special_tokens=False)
output_sequence = tokenizer.encode_plus(sequence, add_special_tokens=False)
for key in output.keys():
self.assertEqual(output[key], output_sequence[key])
-output = tokenizer.encode_plus(token_sequence, is_pretokenized=True, add_special_tokens=True)
+output = tokenizer.encode_plus(token_sequence, is_split_into_words=True, add_special_tokens=True)
output_sequence = tokenizer.encode_plus(sequence, add_special_tokens=True)
for key in output.keys():
self.assertEqual(output[key], output_sequence[key])
@@ -1274,7 +1274,7 @@ class TokenizerTesterMixin:
sequence_batch_cleaned_up_spaces = [" " + " ".join(s) for s in token_sequence_batch]
output = tokenizer.batch_encode_plus(
-token_sequence_batch, is_pretokenized=True, add_special_tokens=False
+token_sequence_batch, is_split_into_words=True, add_special_tokens=False
)
output_sequence = tokenizer.batch_encode_plus(
sequence_batch_cleaned_up_spaces, add_special_tokens=False
@@ -1282,7 +1282,7 @@ class TokenizerTesterMixin:
for key in output.keys():
self.assertEqual(output[key], output_sequence[key])
output = tokenizer.batch_encode_plus(
-token_sequence_batch, is_pretokenized=True, add_special_tokens=True
+token_sequence_batch, is_split_into_words=True, add_special_tokens=True
)
output_sequence = tokenizer.batch_encode_plus(
sequence_batch_cleaned_up_spaces, add_special_tokens=True
@@ -1292,25 +1292,25 @@ class TokenizerTesterMixin:
# Test encode for pretokenized inputs pairs
output = tokenizer.encode(
-token_sequence, token_sequence, is_pretokenized=True, add_special_tokens=False
+token_sequence, token_sequence, is_split_into_words=True, add_special_tokens=False
)
output_sequence = tokenizer.encode(sequence, sequence, add_special_tokens=False)
self.assertEqual(output, output_sequence)
output = tokenizer.encode(
-token_sequence, token_sequence, is_pretokenized=True, add_special_tokens=True
+token_sequence, token_sequence, is_split_into_words=True, add_special_tokens=True
)
output_sequence = tokenizer.encode(sequence, sequence, add_special_tokens=True)
self.assertEqual(output, output_sequence)
# Test encode_plus for pretokenized inputs pairs
output = tokenizer.encode_plus(
-token_sequence, token_sequence, is_pretokenized=True, add_special_tokens=False
+token_sequence, token_sequence, is_split_into_words=True, add_special_tokens=False
)
output_sequence = tokenizer.encode_plus(sequence, sequence, add_special_tokens=False)
for key in output.keys():
self.assertEqual(output[key], output_sequence[key])
output = tokenizer.encode_plus(
-token_sequence, token_sequence, is_pretokenized=True, add_special_tokens=True
+token_sequence, token_sequence, is_split_into_words=True, add_special_tokens=True
)
output_sequence = tokenizer.encode_plus(sequence, sequence, add_special_tokens=True)
for key in output.keys():
@@ -1326,7 +1326,7 @@ class TokenizerTesterMixin:
]
output = tokenizer.batch_encode_plus(
-token_sequence_pair_batch, is_pretokenized=True, add_special_tokens=False
+token_sequence_pair_batch, is_split_into_words=True, add_special_tokens=False
)
output_sequence = tokenizer.batch_encode_plus(
sequence_pair_batch_cleaned_up_spaces, add_special_tokens=False
@@ -1334,7 +1334,7 @@ class TokenizerTesterMixin:
for key in output.keys():
self.assertEqual(output[key], output_sequence[key])
output = tokenizer.batch_encode_plus(
-token_sequence_pair_batch, is_pretokenized=True, add_special_tokens=True
+token_sequence_pair_batch, is_split_into_words=True, add_special_tokens=True
)
output_sequence = tokenizer.batch_encode_plus(
sequence_pair_batch_cleaned_up_spaces, add_special_tokens=True

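The suite above checks only the new keyword; a sketch of how the deprecation path itself could be exercised (hypothetical test, not part of this commit):
.. code-block:: python

    import pytest
    from transformers import BertTokenizer

    def test_is_pretokenized_still_works_but_warns():
        tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
        # The old keyword should keep working but emit a FutureWarning,
        # and produce the same ids as the new one.
        with pytest.warns(FutureWarning):
            old = tokenizer.encode(["Hello", "world"], is_pretokenized=True)
        new = tokenizer.encode(["Hello", "world"], is_split_into_words=True)
        assert old == new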
tests/test_tokenization_fast.py

@@ -340,12 +340,12 @@ class CommonFastTokenizerTest(unittest.TestCase):
pretokenized_input_pair = "This is a sample pair".split()
# Test encode for pretokenized inputs
-output_r = tokenizer_r.encode(pretokenized_input_simple, is_pretokenized=True)
-output_p = tokenizer_p.encode(pretokenized_input_simple, is_pretokenized=True)
+output_r = tokenizer_r.encode(pretokenized_input_simple, is_split_into_words=True)
+output_p = tokenizer_p.encode(pretokenized_input_simple, is_split_into_words=True)
self.assertEqual(output_p, output_r)
kwargs = {
-"is_pretokenized": True,
+"is_split_into_words": True,
"return_token_type_ids": True,
"return_attention_mask": True,
"return_overflowing_tokens": False,
@@ -353,7 +353,7 @@ class CommonFastTokenizerTest(unittest.TestCase):
"return_offsets_mapping": False, # Not implemented in python tokenizers
}
batch_kwargs = {
-"is_pretokenized": True,
+"is_split_into_words": True,
"return_token_type_ids": True,
"return_attention_mask": True, # we have an 's' here
"return_overflowing_tokens": False,
@@ -374,8 +374,8 @@ class CommonFastTokenizerTest(unittest.TestCase):
self.assertEqual(output_p[key], output_r[key])
# Test encode for pretokenized inputs pairs
-output_r = tokenizer_r.encode(pretokenized_input_simple, pretokenized_input_pair, is_pretokenized=True)
-output_p = tokenizer_p.encode(pretokenized_input_simple, pretokenized_input_pair, is_pretokenized=True)
+output_r = tokenizer_r.encode(pretokenized_input_simple, pretokenized_input_pair, is_split_into_words=True)
+output_p = tokenizer_p.encode(pretokenized_input_simple, pretokenized_input_pair, is_split_into_words=True)
self.assertEqual(output_p, output_r)
# Test encode_plus for pretokenized inputs