Exposing prepare_for_model for both slow & fast tokenizers (#5479)

* Exposing prepare_for_model for both slow & fast tokenizers

* Update method signature

* The traditional style commit

* Hide the warnings behind the verbose flag

* Update default truncation strategy and prepare_for_model

* Fix tests and prepare_for_model methods

Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>
commit 17ade127b9 (parent 814ed7ee76)
Lysandre Debut, 2020-07-03 10:51:21 -04:00, committed by GitHub
4 changed files with 285 additions and 205 deletions
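
As a quick orientation before the diffs: after this change, prepare_for_model is part of the shared public API, so slow (Python) and fast (Rust-backed) tokenizers can both turn pre-tokenized ids into model-ready inputs with the same call. A minimal sketch of that intent, assuming a BERT checkpoint purely for illustration:

# Illustrative sketch; the checkpoint and tokenizer classes are only an example.
from transformers import BertTokenizer, BertTokenizerFast

text = "Exposing prepare_for_model for both slow and fast tokenizers"
for tokenizer_cls in (BertTokenizer, BertTokenizerFast):
    tokenizer = tokenizer_cls.from_pretrained("bert-base-uncased")
    ids = tokenizer.encode(text, add_special_tokens=False)
    # prepare_for_model adds special tokens, applies truncation/padding and
    # builds the attention mask from already-converted ids.
    prepared = tokenizer.prepare_for_model(ids)
    assert prepared["input_ids"] == tokenizer.encode_plus(text)["input_ids"]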


@@ -454,12 +454,12 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
first_ids = get_input_ids(text)
second_ids = get_input_ids(text_pair) if text_pair is not None else None
return self._prepare_for_model(
return self.prepare_for_model(
first_ids,
pair_ids=second_ids,
add_special_tokens=add_special_tokens,
padding_strategy=padding_strategy,
truncation_strategy=truncation_strategy,
padding=padding_strategy.value,
truncation=truncation_strategy.value,
max_length=max_length,
stride=stride,
pad_to_multiple_of=pad_to_multiple_of,
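
The caller above now unwraps the strategy enums into their string values, because the public prepare_for_model accepts padding and truncation as bool/str. A small sketch of that mapping, assuming the enums are importable from tokenization_utils_base as in this version of the library:

from transformers.tokenization_utils_base import PaddingStrategy, TruncationStrategy

# Each strategy enum's .value is the string form the public API accepts.
assert PaddingStrategy.MAX_LENGTH.value == "max_length"
assert TruncationStrategy.LONGEST_FIRST.value == "longest_first"
# The enums can be rebuilt from those strings, which is what the base class does internally.
assert TruncationStrategy("only_second") is TruncationStrategy.ONLY_SECOND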
@@ -584,7 +584,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
batch_outputs = {}
for first_ids, second_ids in batch_ids_pairs:
outputs = self._prepare_for_model(
outputs = self.prepare_for_model(
first_ids,
second_ids,
add_special_tokens=add_special_tokens,
@@ -620,109 +620,6 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
return batch_outputs
@add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
def _prepare_for_model(
self,
ids: List[int],
pair_ids: Optional[List[int]] = None,
add_special_tokens: bool = True,
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
max_length: Optional[int] = None,
stride: int = 0,
pad_to_multiple_of: Optional[int] = None,
return_tensors: Optional[str] = None,
prepend_batch_axis: bool = False,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_length: bool = False,
verbose: bool = True,
) -> BatchEncoding:
""" Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model.
It adds special tokens, truncates sequences if overflowing while taking into account the special tokens and
manages a moving window (with user defined stride) for overflowing tokens
Args:
ids: list of tokenized input ids. Can be obtained from a string by chaining the
`tokenize` and `convert_tokens_to_ids` methods.
pair_ids: Optional second list of input ids. Can be obtained from a string by chaining the
`tokenize` and `convert_tokens_to_ids` methods.
"""
pair = bool(pair_ids is not None)
len_ids = len(ids)
len_pair_ids = len(pair_ids) if pair else 0
# Load from model defaults
if return_token_type_ids is None:
return_token_type_ids = "token_type_ids" in self.model_input_names
if return_attention_mask is None:
return_attention_mask = "attention_mask" in self.model_input_names
encoded_inputs = {}
# Compute the total size of the returned encodings
total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0)
# Truncation: Handle max sequence length
if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length:
ids, pair_ids, overflowing_tokens = self.truncate_sequences(
ids,
pair_ids=pair_ids,
num_tokens_to_remove=total_len - max_length,
truncation_strategy=truncation_strategy,
stride=stride,
)
if return_overflowing_tokens:
encoded_inputs["overflowing_tokens"] = overflowing_tokens
encoded_inputs["num_truncated_tokens"] = total_len - max_length
# Add special tokens
if add_special_tokens:
sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
else:
sequence = ids + pair_ids if pair else ids
token_type_ids = [0] * len(ids) + ([1] * len(pair_ids) if pair else [])
# Build output dictionary
encoded_inputs["input_ids"] = sequence
if return_token_type_ids:
encoded_inputs["token_type_ids"] = token_type_ids
if return_special_tokens_mask:
if add_special_tokens:
encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids)
else:
encoded_inputs["special_tokens_mask"] = [0] * len(sequence)
# Check lengths
if max_length is None and len(encoded_inputs["input_ids"]) > self.model_max_length and verbose:
logger.warning(
"Token indices sequence length is longer than the specified maximum sequence length "
"for this model ({} > {}). Running this sequence through the model will result in "
"indexing errors".format(len(ids), self.model_max_length)
)
# Padding
if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask:
encoded_inputs = self.pad(
encoded_inputs,
max_length=max_length,
padding=padding_strategy.value,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
)
if return_length:
encoded_inputs["length"] = len(encoded_inputs["input_ids"])
batch_outputs = BatchEncoding(
encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis
)
return batch_outputs
def prepare_for_tokenization(self, text: str, is_pretokenized=False, **kwargs) -> (str, dict):
""" Performs any necessary transformations before tokenization.
@@ -731,90 +628,6 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
"""
return (text, kwargs)
def truncate_sequences(
self,
ids: List[int],
pair_ids: Optional[List[int]] = None,
num_tokens_to_remove: int = 0,
truncation_strategy: Union[str, TruncationStrategy] = "only_first",
stride: int = 0,
) -> Tuple[List[int], List[int], List[int]]:
""" Truncates a sequence pair in place to the maximum length.
Args:
ids: list of tokenized input ids. Can be obtained from a string by chaining the
`tokenize` and `convert_tokens_to_ids` methods.
pair_ids: Optional second list of input ids. Can be obtained from a string by chaining the
`tokenize` and `convert_tokens_to_ids` methods.
num_tokens_to_remove (:obj:`int`, `optional`, defaults to ``0``):
number of tokens to remove using the truncation strategy
truncation_strategy (:obj:`string`, `optional`, defaults to "only_first"):
String selected in the following options:
- 'only_first' (default): Only truncate the first sequence. An error is logged if the first sequence is shorter than or equal to num_tokens_to_remove.
- 'only_second': Only truncate the second sequence.
- 'longest_first': Iteratively reduce the input sequences until the total length is under max_length,
removing a token from the longest sequence at each step (when there is a pair of input sequences).
Overflowing tokens only contain overflow from the first sequence.
- 'do_not_truncate'
stride (:obj:`int`, `optional`, defaults to ``0``):
If set to a number along with max_length, the overflowing tokens returned will contain some tokens
from the main sequence returned. The value of this argument defines the number of additional tokens.
"""
if num_tokens_to_remove <= 0:
return ids, pair_ids, []
if not isinstance(truncation_strategy, TruncationStrategy):
truncation_strategy = TruncationStrategy(truncation_strategy)
overflowing_tokens = []
if truncation_strategy == TruncationStrategy.LONGEST_FIRST:
for _ in range(num_tokens_to_remove):
if pair_ids is None or len(ids) > len(pair_ids):
ids = ids[:-1]
else:
pair_ids = pair_ids[:-1]
elif truncation_strategy == TruncationStrategy.ONLY_FIRST:
if len(ids) > num_tokens_to_remove:
window_len = min(len(ids), stride + num_tokens_to_remove)
overflowing_tokens = ids[-window_len:]
ids = ids[:-num_tokens_to_remove]
else:
logger.error(
f"We need to remove {num_tokens_to_remove} to truncate the input"
f"but the first sequence has a length {len(ids)}. "
f"Please select another truncation strategy than {truncation_strategy}, "
f"for instance 'longest_first' or 'only_second'."
)
elif truncation_strategy == TruncationStrategy.ONLY_SECOND and pair_ids is not None:
if len(pair_ids) > num_tokens_to_remove:
window_len = min(len(pair_ids), stride + num_tokens_to_remove)
overflowing_tokens = pair_ids[-window_len:]
pair_ids = pair_ids[:-num_tokens_to_remove]
else:
logger.error(
f"We need to remove {num_tokens_to_remove} to truncate the input"
f"but the second sequence has a length {len(pair_ids)}. "
f"Please select another truncation strategy than {truncation_strategy}, "
f"for instance 'longest_first' or 'only_first'."
)
return (ids, pair_ids, overflowing_tokens)
def create_token_type_ids_from_sequences(self, token_ids_0: List, token_ids_1: Optional[List] = None) -> List[int]:
if token_ids_1 is None:
return len(token_ids_0) * [0]
return [0] * len(token_ids_0) + [1] * len(token_ids_1)
def build_inputs_with_special_tokens(self, token_ids_0: List, token_ids_1: Optional[List] = None) -> List:
"""
Build model inputs from a sequence or a pair of sequences for sequence classification tasks
by concatenating and adding special tokens. This implementation does not add special tokens.
"""
if token_ids_1 is None:
return token_ids_0
return token_ids_0 + token_ids_1
def get_special_tokens_mask(
self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False
) -> List[int]:


@@ -945,9 +945,9 @@ ENCODE_KWARGS_DOCSTRING = r"""
`truncation` (:obj:`Union[bool, str]`, `optional`, defaults to :obj:`False`):
Activate and control truncation. Accepts the following values:
* `True` or `'only_first'`: truncate to a max length specified in `max_length` or to the max acceptable input length for the model if no length is provided (`max_length=None`). This will only truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided,
* `True` or `'longest_first'`: truncate to a max length specified in `max_length` or to the max acceptable input length for the model if no length is provided (`max_length=None`). This will truncate token by token, removing a token from the longest sequence in the pair if a pair of sequences (or a batch of pairs) is provided,
* `'only_first'`: truncate to a max length specified in `max_length` or to the max acceptable input length for the model if no length is provided (`max_length=None`). This will only truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided,
* `'only_second'`: truncate to a max length specified in `max_length` or to the max acceptable input length for the model if no length is provided (`max_length=None`). This will only truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided,
* `'longest_first'`: truncate to a max length specified in `max_length` or to the max acceptable input length for the model if no length is provided (`max_length=None`). This will truncate token by token, removing a token from the longest sequence in the pair if a pair of sequences (or a batch of pairs) is provided,
* `False` or `'do_not_truncate'` (default): No truncation (i.e. the output batch can contain sequences longer than the model's maximum admissible input size)
`max_length` (:obj:`Union[int, None]`, `optional`, defaults to :obj:`None`):
Control the length for padding/truncation. Accepts the following values
@@ -1446,10 +1446,11 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
logger.warning(
"Truncation was not explicitely activated but `max_length` is provided a specific value, "
"please use `truncation=True` to explicitely truncate examples to max length. "
"Defaulting to 'only_first' truncation strategy. "
"If you encode pairs of sequences (GLUE-style) with the tokenizer you may want to check this is the right behavior."
"Defaulting to 'longest_first' truncation strategy. "
"If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy "
"more precisely by providing a specific strategy to `truncation`."
)
truncation = "only_first"
truncation = "longest_first"
# Get padding strategy
if padding is False and old_pad_to_max_length:
@@ -1469,7 +1470,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
elif padding is not False:
if padding is True:
padding_strategy = PaddingStrategy.LONGEST # Default to pad to the longest sequence in the batch
else:
elif not isinstance(padding, PaddingStrategy):
padding_strategy = PaddingStrategy(padding)
else:
padding_strategy = PaddingStrategy.DO_NOT_PAD
@@ -1492,9 +1493,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
elif truncation is not False:
if truncation is True:
truncation_strategy = (
TruncationStrategy.ONLY_FIRST
) # Default to truncate the first sequences in pairs of inputs
else:
TruncationStrategy.LONGEST_FIRST
) # Default to truncate the longest sequences in pairs of inputs
elif not isinstance(truncation, TruncationStrategy):
truncation_strategy = TruncationStrategy(truncation)
else:
truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
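
Following the branches above, a short sketch of how the boolean shortcuts resolve after this change (padding=True pads to the longest sequence in the batch; truncation=True now means 'longest_first'); the checkpoint is illustrative:

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# padding=True -> PaddingStrategy.LONGEST: pad to the longest sequence in the batch.
batch = tokenizer.batch_encode_plus(["short", "a slightly longer example sentence"], padding=True)
assert len(set(len(ids) for ids in batch["input_ids"])) == 1

# truncation=True -> TruncationStrategy.LONGEST_FIRST (previously 'only_first').
pair = ("first " * 40, "second " * 40)
enc_true = tokenizer.encode_plus(*pair, max_length=32, truncation=True)
enc_longest = tokenizer.encode_plus(*pair, max_length=32, truncation="longest_first")
assert enc_true["input_ids"] == enc_longest["input_ids"]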
@@ -1960,6 +1961,225 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
return BatchEncoding(batch_outputs, tensor_type=return_tensors)
def create_token_type_ids_from_sequences(self, token_ids_0: List, token_ids_1: Optional[List] = None) -> List[int]:
if token_ids_1 is None:
return len(token_ids_0) * [0]
return [0] * len(token_ids_0) + [1] * len(token_ids_1)
def build_inputs_with_special_tokens(self, token_ids_0: List, token_ids_1: Optional[List] = None) -> List:
"""
Build model inputs from a sequence or a pair of sequences for sequence classification tasks
by concatenating and adding special tokens. This implementation does not add special tokens.
"""
if token_ids_1 is None:
return token_ids_0
return token_ids_0 + token_ids_1
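
These base implementations deliberately add nothing, and model-specific tokenizers override them. A minimal BERT-style sketch of what such overrides typically look like (hypothetical class and hard-coded special token ids, not the actual BertTokenizer code):

from typing import List, Optional

class BertLikeTokenizerSketch:
    # Hypothetical special token ids, hard-coded only for the example.
    cls_token_id = 101
    sep_token_id = 102

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        # [CLS] A [SEP] for a single sequence, [CLS] A [SEP] B [SEP] for a pair.
        cls, sep = [self.cls_token_id], [self.sep_token_id]
        if token_ids_1 is None:
            return cls + token_ids_0 + sep
        return cls + token_ids_0 + sep + token_ids_1 + sep

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        # Segment 0 covers [CLS] A [SEP]; segment 1 covers B [SEP].
        if token_ids_1 is None:
            return [0] * (len(token_ids_0) + 2)
        return [0] * (len(token_ids_0) + 2) + [1] * (len(token_ids_1) + 1)

sketch = BertLikeTokenizerSketch()
assert sketch.build_inputs_with_special_tokens([7, 8], [20]) == [101, 7, 8, 102, 20, 102]
assert sketch.create_token_type_ids_from_sequences([7, 8], [20]) == [0, 0, 0, 0, 1, 1]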
@add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
def prepare_for_model(
self,
ids: List[int],
pair_ids: Optional[List[int]] = None,
add_special_tokens: bool = True,
padding: Union[bool, str] = False,
truncation: Union[bool, str] = False,
max_length: Optional[int] = None,
stride: int = 0,
pad_to_multiple_of: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
return_length: bool = False,
verbose: bool = True,
prepend_batch_axis: bool = False,
**kwargs
) -> BatchEncoding:
""" Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model.
It adds special tokens, truncates sequences if overflowing while taking into account the special tokens and
manages a moving window (with user defined stride) for overflowing tokens
Args:
ids: list of tokenized input ids. Can be obtained from a string by chaining the
`tokenize` and `convert_tokens_to_ids` methods.
pair_ids: Optional second list of input ids. Can be obtained from a string by chaining the
`tokenize` and `convert_tokens_to_ids` methods.
"""
if "return_lengths" in kwargs:
if verbose:
warnings.warn(
"The PreTrainedTokenizerBase.prepare_for_model `return_lengths` parameter is deprecated. "
"Please use `return_length` instead.",
FutureWarning,
)
return_length = kwargs["return_lengths"]
# Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
padding=padding,
truncation=truncation,
max_length=max_length,
pad_to_multiple_of=pad_to_multiple_of,
verbose=verbose,
**kwargs,
)
pair = bool(pair_ids is not None)
len_ids = len(ids)
len_pair_ids = len(pair_ids) if pair else 0
# Load from model defaults
if return_token_type_ids is None:
return_token_type_ids = "token_type_ids" in self.model_input_names
if return_attention_mask is None:
return_attention_mask = "attention_mask" in self.model_input_names
encoded_inputs = {}
# Compute the total size of the returned encodings
total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0)
# Truncation: Handle max sequence length
if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length:
ids, pair_ids, overflowing_tokens = self.truncate_sequences(
ids,
pair_ids=pair_ids,
num_tokens_to_remove=total_len - max_length,
truncation_strategy=truncation_strategy,
stride=stride,
)
if return_overflowing_tokens:
encoded_inputs["overflowing_tokens"] = overflowing_tokens
encoded_inputs["num_truncated_tokens"] = total_len - max_length
# Add special tokens
if add_special_tokens:
sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
else:
sequence = ids + pair_ids if pair else ids
token_type_ids = [0] * len(ids) + ([1] * len(pair_ids) if pair else [])
# Build output dictionary
encoded_inputs["input_ids"] = sequence
if return_token_type_ids:
encoded_inputs["token_type_ids"] = token_type_ids
if return_special_tokens_mask:
if add_special_tokens:
encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids)
else:
encoded_inputs["special_tokens_mask"] = [0] * len(sequence)
# Check lengths
if max_length is None and len(encoded_inputs["input_ids"]) > self.model_max_length and verbose:
logger.warning(
"Token indices sequence length is longer than the specified maximum sequence length "
"for this model ({} > {}). Running this sequence through the model will result in "
"indexing errors".format(len(ids), self.model_max_length)
)
# Padding
if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask:
encoded_inputs = self.pad(
encoded_inputs,
max_length=max_length,
padding=padding_strategy.value,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
)
if return_length:
encoded_inputs["length"] = len(encoded_inputs["input_ids"])
batch_outputs = BatchEncoding(
encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis
)
return batch_outputs
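
A hedged usage sketch of the now-public method: starting from ids that were already converted, it adds special tokens, truncates with the requested strategy, and can return the overflow (checkpoint and text are arbitrary):

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
ids = tokenizer.encode("a sentence that is long enough to overflow the budget " * 4, add_special_tokens=False)

prepared = tokenizer.prepare_for_model(
    ids,
    max_length=16,
    truncation="only_first",
    stride=2,
    return_overflowing_tokens=True,
)
# input_ids is the truncated sequence plus special tokens; overflowing_tokens is
# the removed tail preceded by `stride` tokens of overlap with the kept part.
print(len(prepared["input_ids"]), len(prepared["overflowing_tokens"]))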
def truncate_sequences(
self,
ids: List[int],
pair_ids: Optional[List[int]] = None,
num_tokens_to_remove: int = 0,
truncation_strategy: Union[str, TruncationStrategy] = "longest_first",
stride: int = 0,
) -> Tuple[List[int], List[int], List[int]]:
""" Truncates a sequence pair in place to the maximum length.
Args:
ids: list of tokenized input ids. Can be obtained from a string by chaining the
`tokenize` and `convert_tokens_to_ids` methods.
pair_ids: Optional second list of input ids. Can be obtained from a string by chaining the
`tokenize` and `convert_tokens_to_ids` methods.
num_tokens_to_remove (:obj:`int`, `optional`, defaults to ``0``):
number of tokens to remove using the truncation strategy
truncation_strategy (:obj:`string`, `optional`, defaults to "longest_first"):
String selected in the following options:
- 'longest_first' (default): Iteratively reduce the input sequences until the total length is under max_length,
removing a token from the longest sequence at each step (when a pair of input sequences is provided).
Overflowing tokens are collected from whichever sequence is truncated at each step.
- 'only_first': Only truncate the first sequence. An error is logged if the first sequence is shorter than or equal to num_tokens_to_remove.
- 'only_second': Only truncate the second sequence.
- 'do_not_truncate': Keep the full sequences without truncating.
stride (:obj:`int`, `optional`, defaults to ``0``):
If set to a number along with max_length, the overflowing tokens returned will contain some tokens
from the main sequence returned. The value of this argument defines the number of additional tokens.
"""
if num_tokens_to_remove <= 0:
return ids, pair_ids, []
if not isinstance(truncation_strategy, TruncationStrategy):
truncation_strategy = TruncationStrategy(truncation_strategy)
overflowing_tokens = []
if truncation_strategy == TruncationStrategy.LONGEST_FIRST:
for _ in range(num_tokens_to_remove):
if pair_ids is None or len(ids) > len(pair_ids):
if not overflowing_tokens:
window_len = min(len(ids), stride + 1)
else:
window_len = 1
overflowing_tokens.extend(ids[-window_len:])
ids = ids[:-1]
else:
if not overflowing_tokens:
window_len = min(len(pair_ids), stride + 1)
else:
window_len = 1
overflowing_tokens.extend(pair_ids[-window_len:])
pair_ids = pair_ids[:-1]
elif truncation_strategy == TruncationStrategy.ONLY_FIRST:
if len(ids) > num_tokens_to_remove:
window_len = min(len(ids), stride + num_tokens_to_remove)
overflowing_tokens = ids[-window_len:]
ids = ids[:-num_tokens_to_remove]
else:
logger.error(
f"We need to remove {num_tokens_to_remove} to truncate the input"
f"but the first sequence has a length {len(ids)}. "
f"Please select another truncation strategy than {truncation_strategy}, "
f"for instance 'longest_first' or 'only_second'."
)
elif truncation_strategy == TruncationStrategy.ONLY_SECOND and pair_ids is not None:
if len(pair_ids) > num_tokens_to_remove:
window_len = min(len(pair_ids), stride + num_tokens_to_remove)
overflowing_tokens = pair_ids[-window_len:]
pair_ids = pair_ids[:-num_tokens_to_remove]
else:
logger.error(
f"We need to remove {num_tokens_to_remove} to truncate the input"
f"but the second sequence has a length {len(pair_ids)}. "
f"Please select another truncation strategy than {truncation_strategy}, "
f"for instance 'longest_first' or 'only_first'."
)
return (ids, pair_ids, overflowing_tokens)
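
For completeness, a small sketch of calling truncate_sequences directly. Note that 'longest_first' now also collects the removed tokens as overflow (the old implementation returned an empty list for this strategy); the ids below are stand-ins, and the checkpoint is only there to get a tokenizer instance:

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
ids = list(range(10))             # stand-in ids for the first sequence
pair_ids = list(range(100, 106))  # stand-in ids for the second sequence

kept_ids, kept_pair, overflow = tokenizer.truncate_sequences(
    ids,
    pair_ids=pair_ids,
    num_tokens_to_remove=4,
    truncation_strategy="longest_first",
    stride=1,
)
# The first sequence is the longer one, so all four removals come from it, and
# num_tokens_to_remove + stride tokens of overflow are collected.
assert kept_ids == list(range(6)) and kept_pair == pair_ids
assert len(overflow) == 4 + 1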
def _pad(
self,
encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],


@@ -508,9 +508,7 @@ class TokenizerTesterMixin:
self.assertEqual(len(truncated_sequence), total_length - 2)
self.assertEqual(truncated_sequence, sequence[:-2])
self.assertEqual(
len(overflowing_tokens), 0
) # No overflowing tokens when using 'longest' in python tokenizers
self.assertEqual(len(overflowing_tokens), 2 + stride)
def test_maximum_encoding_length_pair_input(self):
tokenizers = self.get_tokenizers(do_lower_case=False, model_max_length=100)
@@ -634,7 +632,39 @@ class TokenizerTesterMixin:
self.assertEqual(truncated_sequence, truncated_longest_sequence)
self.assertEqual(
len(overflowing_tokens), 0
len(overflowing_tokens), 2 + stride
) # Python (slow) tokenizers now return 2 + stride overflowing tokens here
information = tokenizer.encode_plus(
seq_0,
seq_1,
max_length=len(sequence) - 2,
add_special_tokens=False,
stride=stride,
truncation=True,
return_overflowing_tokens=True,
# add_prefix_space=False,
)
# Overflowing tokens are handled quite differently in slow and fast tokenizers
if isinstance(tokenizer, PreTrainedTokenizerFast):
truncated_sequence = information["input_ids"][0]
overflowing_tokens = information["input_ids"][1]
self.assertEqual(len(information["input_ids"]), 2)
self.assertEqual(len(truncated_sequence), len(sequence) - 2)
self.assertEqual(truncated_sequence, truncated_longest_sequence)
self.assertEqual(len(overflowing_tokens), 2 + stride + len(smallest))
self.assertEqual(overflowing_tokens, overflow_longest_sequence)
else:
truncated_sequence = information["input_ids"]
overflowing_tokens = information["overflowing_tokens"]
self.assertEqual(len(truncated_sequence), len(sequence) - 2)
self.assertEqual(truncated_sequence, truncated_longest_sequence)
self.assertEqual(
len(overflowing_tokens), 2 + stride
) # Python (slow) tokenizers return 2 + stride overflowing tokens with 'longest_first'
information_first_truncated = tokenizer.encode_plus(
@@ -643,7 +673,7 @@ class TokenizerTesterMixin:
max_length=len(sequence) - 2,
add_special_tokens=False,
stride=stride,
truncation=True,
truncation="only_first",
return_overflowing_tokens=True,
# add_prefix_space=False,
)
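
The branching above reflects a real API difference worth spelling out: with return_overflowing_tokens=True a fast tokenizer returns the overflow as extra rows of input_ids, while a slow tokenizer returns a flat overflowing_tokens list. A hedged sketch (checkpoint illustrative):

from transformers import BertTokenizer, BertTokenizerFast

text = "a sentence repeated until it overflows the budget " * 4
for tokenizer_cls in (BertTokenizer, BertTokenizerFast):
    tokenizer = tokenizer_cls.from_pretrained("bert-base-uncased")
    enc = tokenizer.encode_plus(
        text, max_length=16, truncation=True, stride=2, return_overflowing_tokens=True
    )
    if isinstance(enc["input_ids"][0], list):
        # Fast (Rust) tokenizer: overflow comes back as additional input_ids rows.
        print("fast:", [len(row) for row in enc["input_ids"]])
    else:
        # Slow (Python) tokenizer: one input_ids row plus a flat overflow list.
        print("slow:", len(enc["input_ids"]), len(enc["overflowing_tokens"]))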
@@ -1293,6 +1323,16 @@ class TokenizerTesterMixin:
for key in output.keys():
self.assertEqual(output[key], output_sequence[key])
def test_prepare_for_model(self):
tokenizers = self.get_tokenizers(do_lower_case=False)
for tokenizer in tokenizers:
string_sequence = "Testing the prepare_for_model method."
ids = tokenizer.encode(string_sequence, add_special_tokens=False)
input_dict = tokenizer.encode_plus(string_sequence)
prepared_input_dict = tokenizer.prepare_for_model(ids)
self.assertEqual(input_dict, prepared_input_dict)
@require_torch
@require_tf
def test_batch_encode_plus_tensors(self):


@@ -90,6 +90,7 @@ class CommonFastTokenizerTest(unittest.TestCase):
self.assert_embeded_special_tokens(tokenizer_r, tokenizer_p)
self.assert_padding(tokenizer_r, tokenizer_p)
self.assert_create_token_type_ids(tokenizer_r, tokenizer_p)
self.assert_prepare_for_model(tokenizer_r, tokenizer_p)
# TODO: enable for v3.0.0
# self.assert_empty_output_no_special_tokens(tokenizer_r, tokenizer_p)
@@ -709,6 +710,12 @@ class CommonFastTokenizerTest(unittest.TestCase):
for i_no, i_with in zip(no_special_tokens[key], with_special_tokens[key]):
self.assertEqual(len(i_no), len(i_with) - simple_num_special_tokens_to_add)
def assert_prepare_for_model(self, tokenizer_r, tokenizer_p):
string_sequence = "Asserting that both tokenizers are equal"
python_output = tokenizer_p.prepare_for_model(tokenizer_p.encode(string_sequence))
rust_output = tokenizer_r.prepare_for_model(tokenizer_r.encode(string_sequence))
self.assertEqual(python_output, rust_output)
class WordPieceFastTokenizerTest(CommonFastTokenizerTest):
"""