Exposing prepare_for_model for both slow & fast tokenizers (#5479)
* Exposing prepare_for_model for both slow & fast tokenizers
* Update method signature
* The traditional style commit
* Hide the warnings behind the verbose flag
* Update default truncation strategy and prepare_for_model
* Fix tests and prepare_for_model methods

Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>
commit 17ade127b9 (parent 814ed7ee76)
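In short: the previously private `_prepare_for_model` of the slow tokenizers becomes the public `prepare_for_model` on `PreTrainedTokenizerBase`, shared by slow and fast tokenizers, and the implicit truncation default moves from 'only_first' to 'longest_first'. A minimal usage sketch of the public method, assuming transformers v3.x (the `bert-base-uncased` checkpoint is chosen only for illustration):

from transformers import BertTokenizer, BertTokenizerFast

slow = BertTokenizer.from_pretrained("bert-base-uncased")
fast = BertTokenizerFast.from_pretrained("bert-base-uncased")

# Ids without special tokens, e.g. from tokenize() + convert_tokens_to_ids()
ids = slow.encode("Hello world", add_special_tokens=False)

# prepare_for_model adds special tokens, then truncates and pads as requested;
# after this commit it is available on both tokenizer flavours.
assert slow.prepare_for_model(ids)["input_ids"] == fast.prepare_for_model(ids)["input_ids"]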
@@ -454,12 +454,12 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
         first_ids = get_input_ids(text)
         second_ids = get_input_ids(text_pair) if text_pair is not None else None
-        return self._prepare_for_model(
+        return self.prepare_for_model(
             first_ids,
             pair_ids=second_ids,
             add_special_tokens=add_special_tokens,
-            padding_strategy=padding_strategy,
-            truncation_strategy=truncation_strategy,
+            padding=padding_strategy.value,
+            truncation=truncation_strategy.value,
             max_length=max_length,
             stride=stride,
             pad_to_multiple_of=pad_to_multiple_of,
@@ -584,7 +584,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
         batch_outputs = {}
         for first_ids, second_ids in batch_ids_pairs:
-            outputs = self._prepare_for_model(
+            outputs = self.prepare_for_model(
                 first_ids,
                 second_ids,
                 add_special_tokens=add_special_tokens,
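The batched path above now routes every (first_ids, second_ids) pair through the same public method and pads afterwards. An illustrative sketch of that pattern against the public API (not the library's exact internals; `tokenizer`, `ids_a` and `ids_b` are hypothetical):

batch = {}
for first_ids, second_ids in [(ids_a, ids_b), (ids_a, None)]:
    # Truncate per example, but defer padding until the whole batch is known
    outputs = tokenizer.prepare_for_model(first_ids, second_ids, truncation=True, max_length=16)
    for key, value in outputs.items():
        batch.setdefault(key, []).append(value)
# Pad every field to the longest example in the batch
batch = tokenizer.pad(batch, padding=True)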
@@ -620,109 +620,6 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
 
         return batch_outputs
 
-    @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
-    def _prepare_for_model(
-        self,
-        ids: List[int],
-        pair_ids: Optional[List[int]] = None,
-        add_special_tokens: bool = True,
-        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
-        truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
-        max_length: Optional[int] = None,
-        stride: int = 0,
-        pad_to_multiple_of: Optional[int] = None,
-        return_tensors: Optional[str] = None,
-        prepend_batch_axis: bool = False,
-        return_token_type_ids: Optional[bool] = None,
-        return_attention_mask: Optional[bool] = None,
-        return_overflowing_tokens: bool = False,
-        return_special_tokens_mask: bool = False,
-        return_length: bool = False,
-        verbose: bool = True,
-    ) -> BatchEncoding:
-        """ Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model.
-        It adds special tokens, truncates sequences if overflowing while taking into account the special tokens and
-        manages a moving window (with user defined stride) for overflowing tokens
-
-        Args:
-            ids: list of tokenized input ids. Can be obtained from a string by chaining the
-                `tokenize` and `convert_tokens_to_ids` methods.
-            pair_ids: Optional second list of input ids. Can be obtained from a string by chaining the
-                `tokenize` and `convert_tokens_to_ids` methods.
-        """
-        pair = bool(pair_ids is not None)
-        len_ids = len(ids)
-        len_pair_ids = len(pair_ids) if pair else 0
-
-        # Load from model defaults
-        if return_token_type_ids is None:
-            return_token_type_ids = "token_type_ids" in self.model_input_names
-        if return_attention_mask is None:
-            return_attention_mask = "attention_mask" in self.model_input_names
-
-        encoded_inputs = {}
-
-        # Compute the total size of the returned encodings
-        total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0)
-
-        # Truncation: Handle max sequence length
-        if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length:
-            ids, pair_ids, overflowing_tokens = self.truncate_sequences(
-                ids,
-                pair_ids=pair_ids,
-                num_tokens_to_remove=total_len - max_length,
-                truncation_strategy=truncation_strategy,
-                stride=stride,
-            )
-            if return_overflowing_tokens:
-                encoded_inputs["overflowing_tokens"] = overflowing_tokens
-                encoded_inputs["num_truncated_tokens"] = total_len - max_length
-
-        # Add special tokens
-        if add_special_tokens:
-            sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
-            token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
-        else:
-            sequence = ids + pair_ids if pair else ids
-            token_type_ids = [0] * len(ids) + ([1] * len(pair_ids) if pair else [])
-
-        # Build output dictionnary
-        encoded_inputs["input_ids"] = sequence
-        if return_token_type_ids:
-            encoded_inputs["token_type_ids"] = token_type_ids
-        if return_special_tokens_mask:
-            if add_special_tokens:
-                encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids)
-            else:
-                encoded_inputs["special_tokens_mask"] = [0] * len(sequence)
-
-        # Check lengths
-        if max_length is None and len(encoded_inputs["input_ids"]) > self.model_max_length and verbose:
-            logger.warning(
-                "Token indices sequence length is longer than the specified maximum sequence length "
-                "for this model ({} > {}). Running this sequence through the model will result in "
-                "indexing errors".format(len(ids), self.model_max_length)
-            )
-
-        # Padding
-        if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask:
-            encoded_inputs = self.pad(
-                encoded_inputs,
-                max_length=max_length,
-                padding=padding_strategy.value,
-                pad_to_multiple_of=pad_to_multiple_of,
-                return_attention_mask=return_attention_mask,
-            )
-
-        if return_length:
-            encoded_inputs["length"] = len(encoded_inputs["input_ids"])
-
-        batch_outputs = BatchEncoding(
-            encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis
-        )
-
-        return batch_outputs
-
     def prepare_for_tokenization(self, text: str, is_pretokenized=False, **kwargs) -> (str, dict):
         """ Performs any necessary transformations before tokenization.
 
@@ -731,90 +628,6 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
         """
         return (text, kwargs)
 
-    def truncate_sequences(
-        self,
-        ids: List[int],
-        pair_ids: Optional[List[int]] = None,
-        num_tokens_to_remove: int = 0,
-        truncation_strategy: Union[str, TruncationStrategy] = "only_first",
-        stride: int = 0,
-    ) -> Tuple[List[int], List[int], List[int]]:
-        """ Truncates a sequence pair in place to the maximum length.
-
-        Args:
-            ids: list of tokenized input ids. Can be obtained from a string by chaining the
-                `tokenize` and `convert_tokens_to_ids` methods.
-            pair_ids: Optional second list of input ids. Can be obtained from a string by chaining the
-                `tokenize` and `convert_tokens_to_ids` methods.
-            num_tokens_to_remove (:obj:`int`, `optional`, defaults to ``0``):
-                number of tokens to remove using the truncation strategy
-            truncation_strategy (:obj:`string`, `optional`, defaults to "only_first"):
-                String selected in the following options:
-
-                - 'only_first' (default): Only truncate the first sequence. raise an error if the first sequence is shorter or equal to than num_tokens_to_remove.
-                - 'only_second': Only truncate the second sequence
-                - 'longest_first': Iteratively reduce the inputs sequence until the input is under max_length
-                  starting from the longest one at each token (when there is a pair of input sequences).
-                  Overflowing tokens only contains overflow from the first sequence.
-                - 'do_not_truncate'
-            stride (:obj:`int`, `optional`, defaults to ``0``):
-                If set to a number along with max_length, the overflowing tokens returned will contain some tokens
-                from the main sequence returned. The value of this argument defines the number of additional tokens.
-        """
-        if num_tokens_to_remove <= 0:
-            return ids, pair_ids, []
-
-        if not isinstance(truncation_strategy, TruncationStrategy):
-            truncation_strategy = TruncationStrategy(truncation_strategy)
-
-        overflowing_tokens = []
-        if truncation_strategy == TruncationStrategy.LONGEST_FIRST:
-            for _ in range(num_tokens_to_remove):
-                if pair_ids is None or len(ids) > len(pair_ids):
-                    ids = ids[:-1]
-                else:
-                    pair_ids = pair_ids[:-1]
-        elif truncation_strategy == TruncationStrategy.ONLY_FIRST:
-            if len(ids) > num_tokens_to_remove:
-                window_len = min(len(ids), stride + num_tokens_to_remove)
-                overflowing_tokens = ids[-window_len:]
-                ids = ids[:-num_tokens_to_remove]
-            else:
-                logger.error(
-                    f"We need to remove {num_tokens_to_remove} to truncate the input"
-                    f"but the first sequence has a length {len(ids)}. "
-                    f"Please select another truncation strategy than {truncation_strategy}, "
-                    f"for instance 'longest_first' or 'only_second'."
-                )
-        elif truncation_strategy == TruncationStrategy.ONLY_SECOND and pair_ids is not None:
-            if len(pair_ids) > num_tokens_to_remove:
-                window_len = min(len(pair_ids), stride + num_tokens_to_remove)
-                overflowing_tokens = pair_ids[-window_len:]
-                pair_ids = pair_ids[:-num_tokens_to_remove]
-            else:
-                logger.error(
-                    f"We need to remove {num_tokens_to_remove} to truncate the input"
-                    f"but the second sequence has a length {len(pair_ids)}. "
-                    f"Please select another truncation strategy than {truncation_strategy}, "
-                    f"for instance 'longest_first' or 'only_first'."
-                )
-
-        return (ids, pair_ids, overflowing_tokens)
-
-    def create_token_type_ids_from_sequences(self, token_ids_0: List, token_ids_1: Optional[List] = None) -> List[int]:
-        if token_ids_1 is None:
-            return len(token_ids_0) * [0]
-        return [0] * len(token_ids_0) + [1] * len(token_ids_1)
-
-    def build_inputs_with_special_tokens(self, token_ids_0: List, token_ids_1: Optional[List] = None) -> List:
-        """
-        Build model inputs from a sequence or a pair of sequence for sequence classification tasks
-        by concatenating and adding special tokens. This implementation does not add special tokens.
-        """
-        if token_ids_1 is None:
-            return token_ids_0
-        return token_ids_0 + token_ids_1
-
     def get_special_tokens_mask(
         self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False
     ) -> List[int]:
@@ -945,9 +945,9 @@ ENCODE_KWARGS_DOCSTRING = r"""
             `truncation` (:obj:`Union[bool, str]`, `optional`, defaults to :obj:`False`):
                 Activate and control truncation. Accepts the following values:
 
-                * `True` or `'only_first'`: truncate to a max length specified in `max_length` or to the max acceptable input length for the model if no length is provided (`max_length=None`). This will only truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided,
+                * `True` or `'longest_first'`: truncate to a max length specified in `max_length` or to the max acceptable input length for the model if no length is provided (`max_length=None`). This will truncate token by token, removing a token from the longest sequence in the pair if a pair of sequences (or a batch of pairs) is provided,
+                * `'only_first'`: truncate to a max length specified in `max_length` or to the max acceptable input length for the model if no length is provided (`max_length=None`). This will only truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided,
                 * `'only_second'`: truncate to a max length specified in `max_length` or to the max acceptable input length for the model if no length is provided (`max_length=None`). This will only truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided,
-                * `'longest_first'`: truncate to a max length specified in `max_length` or to the max acceptable input length for the model if no length is provided (`max_length=None`). This will truncate token by token, removing a token from the longest sequence in the pair if a pair of sequences (or a batch of pairs) is provided,
                 * `False` or `'do_not_truncate'` (default): No truncation (i.e. can output batch with sequences length greater than the model max admissible input size)
             `max_length` (:obj:`Union[int, None]`, `optional`, defaults to :obj:`None`):
                 Control the length for padding/truncation. Accepts the following values
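A short sketch of how these `truncation` values differ on a sequence pair (hypothetical `tokenizer`, `question` and `context` variables):

# truncation=True now means 'longest_first': tokens are dropped one at a time
# from whichever sequence of the pair is currently longer.
enc = tokenizer.encode_plus(question, context, truncation=True, max_length=64)

# 'only_second' shortens only the second sequence, e.g. trimming a long
# context while keeping the question intact.
enc = tokenizer.encode_plus(question, context, truncation="only_second", max_length=64)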
@@ -1446,10 +1446,11 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
                 logger.warning(
                     "Truncation was not explicitely activated but `max_length` is provided a specific value, "
                     "please use `truncation=True` to explicitely truncate examples to max length. "
-                    "Defaulting to 'only_first' truncation strategy. "
-                    "If you encode pairs of sequences (GLUE-style) with the tokenizer you may want to check this is the right behavior."
+                    "Defaulting to 'longest_first' truncation strategy. "
+                    "If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy "
+                    "more precisely by providing a specific strategy to `truncation`."
                 )
-                truncation = "only_first"
+                truncation = "longest_first"
 
         # Get padding strategy
         if padding is False and old_pad_to_max_length:
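The behavior the new warning text describes, sketched with hypothetical variables:

# max_length given but truncation left unset: logs the warning above and now
# falls back to 'longest_first' (the previous fallback was 'only_first').
enc = tokenizer.encode_plus(seq_a, seq_b, max_length=16)

# Explicit equivalent that avoids the warning:
enc = tokenizer.encode_plus(seq_a, seq_b, max_length=16, truncation="longest_first")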
@@ -1469,7 +1470,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
         elif padding is not False:
             if padding is True:
                 padding_strategy = PaddingStrategy.LONGEST  # Default to pad to the longest sequence in the batch
-            else:
+            elif not isinstance(padding, PaddingStrategy):
                 padding_strategy = PaddingStrategy(padding)
         else:
             padding_strategy = PaddingStrategy.DO_NOT_PAD
@@ -1492,9 +1493,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
         elif truncation is not False:
             if truncation is True:
                 truncation_strategy = (
-                    TruncationStrategy.ONLY_FIRST
-                )  # Default to truncate the first sequences in pairs of inputs
-            else:
+                    TruncationStrategy.LONGEST_FIRST
+                )  # Default to truncate the longest sequences in pairs of inputs
+            elif not isinstance(truncation, TruncationStrategy):
                 truncation_strategy = TruncationStrategy(truncation)
         else:
             truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
@@ -1960,6 +1961,225 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
 
         return BatchEncoding(batch_outputs, tensor_type=return_tensors)
 
+    def create_token_type_ids_from_sequences(self, token_ids_0: List, token_ids_1: Optional[List] = None) -> List[int]:
+        if token_ids_1 is None:
+            return len(token_ids_0) * [0]
+        return [0] * len(token_ids_0) + [1] * len(token_ids_1)
+
+    def build_inputs_with_special_tokens(self, token_ids_0: List, token_ids_1: Optional[List] = None) -> List:
+        """
+        Build model inputs from a sequence or a pair of sequence for sequence classification tasks
+        by concatenating and adding special tokens. This implementation does not add special tokens.
+        """
+        if token_ids_1 is None:
+            return token_ids_0
+        return token_ids_0 + token_ids_1
+
+    @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+    def prepare_for_model(
+        self,
+        ids: List[int],
+        pair_ids: Optional[List[int]] = None,
+        add_special_tokens: bool = True,
+        padding: Union[bool, str] = False,
+        truncation: Union[bool, str] = False,
+        max_length: Optional[int] = None,
+        stride: int = 0,
+        pad_to_multiple_of: Optional[int] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        return_token_type_ids: Optional[bool] = None,
+        return_attention_mask: Optional[bool] = None,
+        return_overflowing_tokens: bool = False,
+        return_special_tokens_mask: bool = False,
+        return_offsets_mapping: bool = False,
+        return_length: bool = False,
+        verbose: bool = True,
+        prepend_batch_axis: bool = False,
+        **kwargs
+    ) -> BatchEncoding:
+        """ Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model.
+        It adds special tokens, truncates sequences if overflowing while taking into account the special tokens and
+        manages a moving window (with user defined stride) for overflowing tokens
+
+        Args:
+            ids: list of tokenized input ids. Can be obtained from a string by chaining the
+                `tokenize` and `convert_tokens_to_ids` methods.
+            pair_ids: Optional second list of input ids. Can be obtained from a string by chaining the
+                `tokenize` and `convert_tokens_to_ids` methods.
+        """
+
+        if "return_lengths" in kwargs:
+            if verbose:
+                warnings.warn(
+                    "The PreTrainedTokenizerBase.prepare_for_model `return_lengths` parameter is deprecated. "
+                    "Please use `return_length` instead.",
+                    FutureWarning,
+                )
+            return_length = kwargs["return_lengths"]
+
+        # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
+        padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
+            padding=padding,
+            truncation=truncation,
+            max_length=max_length,
+            pad_to_multiple_of=pad_to_multiple_of,
+            verbose=verbose,
+            **kwargs,
+        )
+
+        pair = bool(pair_ids is not None)
+        len_ids = len(ids)
+        len_pair_ids = len(pair_ids) if pair else 0
+
+        # Load from model defaults
+        if return_token_type_ids is None:
+            return_token_type_ids = "token_type_ids" in self.model_input_names
+        if return_attention_mask is None:
+            return_attention_mask = "attention_mask" in self.model_input_names
+
+        encoded_inputs = {}
+
+        # Compute the total size of the returned encodings
+        total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0)
+
+        # Truncation: Handle max sequence length
+        if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length:
+            ids, pair_ids, overflowing_tokens = self.truncate_sequences(
+                ids,
+                pair_ids=pair_ids,
+                num_tokens_to_remove=total_len - max_length,
+                truncation_strategy=truncation_strategy,
+                stride=stride,
+            )
+            if return_overflowing_tokens:
+                encoded_inputs["overflowing_tokens"] = overflowing_tokens
+                encoded_inputs["num_truncated_tokens"] = total_len - max_length
+
+        # Add special tokens
+        if add_special_tokens:
+            sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
+            token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
+        else:
+            sequence = ids + pair_ids if pair else ids
+            token_type_ids = [0] * len(ids) + ([1] * len(pair_ids) if pair else [])
+
+        # Build output dictionnary
+        encoded_inputs["input_ids"] = sequence
+        if return_token_type_ids:
+            encoded_inputs["token_type_ids"] = token_type_ids
+        if return_special_tokens_mask:
+            if add_special_tokens:
+                encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids)
+            else:
+                encoded_inputs["special_tokens_mask"] = [0] * len(sequence)
+
+        # Check lengths
+        if max_length is None and len(encoded_inputs["input_ids"]) > self.model_max_length and verbose:
+            logger.warning(
+                "Token indices sequence length is longer than the specified maximum sequence length "
+                "for this model ({} > {}). Running this sequence through the model will result in "
+                "indexing errors".format(len(ids), self.model_max_length)
+            )
+
+        # Padding
+        if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask:
+            encoded_inputs = self.pad(
+                encoded_inputs,
+                max_length=max_length,
+                padding=padding_strategy.value,
+                pad_to_multiple_of=pad_to_multiple_of,
+                return_attention_mask=return_attention_mask,
+            )
+
+        if return_length:
+            encoded_inputs["length"] = len(encoded_inputs["input_ids"])
+
+        batch_outputs = BatchEncoding(
+            encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis
+        )
+
+        return batch_outputs
+
+    def truncate_sequences(
+        self,
+        ids: List[int],
+        pair_ids: Optional[List[int]] = None,
+        num_tokens_to_remove: int = 0,
+        truncation_strategy: Union[str, TruncationStrategy] = "longest_first",
+        stride: int = 0,
+    ) -> Tuple[List[int], List[int], List[int]]:
+        """ Truncates a sequence pair in place to the maximum length.
+
+        Args:
+            ids: list of tokenized input ids. Can be obtained from a string by chaining the
+                `tokenize` and `convert_tokens_to_ids` methods.
+            pair_ids: Optional second list of input ids. Can be obtained from a string by chaining the
+                `tokenize` and `convert_tokens_to_ids` methods.
+            num_tokens_to_remove (:obj:`int`, `optional`, defaults to ``0``):
+                number of tokens to remove using the truncation strategy
+            truncation_strategy (:obj:`string`, `optional`, defaults to "longest_first"):
+                String selected in the following options:
+
+                - 'longest_first' (default): Iteratively reduce the inputs sequence until the input is under max_length
+                  starting from the longest one at each token (when there is a pair of input sequences).
+                  Overflowing tokens only contains overflow from the first sequence.
+                - 'only_first': Only truncate the first sequence. raise an error if the first sequence is shorter or equal to than num_tokens_to_remove.
+                - 'only_second': Only truncate the second sequence
+                - 'do_not_truncate'
+            stride (:obj:`int`, `optional`, defaults to ``0``):
+                If set to a number along with max_length, the overflowing tokens returned will contain some tokens
+                from the main sequence returned. The value of this argument defines the number of additional tokens.
+        """
+        if num_tokens_to_remove <= 0:
+            return ids, pair_ids, []
+
+        if not isinstance(truncation_strategy, TruncationStrategy):
+            truncation_strategy = TruncationStrategy(truncation_strategy)
+
+        overflowing_tokens = []
+        if truncation_strategy == TruncationStrategy.LONGEST_FIRST:
+            for _ in range(num_tokens_to_remove):
+                if pair_ids is None or len(ids) > len(pair_ids):
+                    if not overflowing_tokens:
+                        window_len = min(len(ids), stride + 1)
+                    else:
+                        window_len = 1
+                    overflowing_tokens.extend(ids[-window_len:])
+                    ids = ids[:-1]
+                else:
+                    if not overflowing_tokens:
+                        window_len = min(len(pair_ids), stride + 1)
+                    else:
+                        window_len = 1
+                    overflowing_tokens.extend(pair_ids[-window_len:])
+                    pair_ids = pair_ids[:-1]
+        elif truncation_strategy == TruncationStrategy.ONLY_FIRST:
+            if len(ids) > num_tokens_to_remove:
+                window_len = min(len(ids), stride + num_tokens_to_remove)
+                overflowing_tokens = ids[-window_len:]
+                ids = ids[:-num_tokens_to_remove]
+            else:
+                logger.error(
+                    f"We need to remove {num_tokens_to_remove} to truncate the input"
+                    f"but the first sequence has a length {len(ids)}. "
+                    f"Please select another truncation strategy than {truncation_strategy}, "
+                    f"for instance 'longest_first' or 'only_second'."
+                )
+        elif truncation_strategy == TruncationStrategy.ONLY_SECOND and pair_ids is not None:
+            if len(pair_ids) > num_tokens_to_remove:
+                window_len = min(len(pair_ids), stride + num_tokens_to_remove)
+                overflowing_tokens = pair_ids[-window_len:]
+                pair_ids = pair_ids[:-num_tokens_to_remove]
+            else:
+                logger.error(
+                    f"We need to remove {num_tokens_to_remove} to truncate the input"
+                    f"but the second sequence has a length {len(pair_ids)}. "
+                    f"Please select another truncation strategy than {truncation_strategy}, "
+                    f"for instance 'longest_first' or 'only_first'."
+                )
+
+        return (ids, pair_ids, overflowing_tokens)
+
     def _pad(
         self,
         encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
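A worked sketch of the `longest_first` overflow handling added above (illustrative token ids; `tokenizer` is any slow tokenizer exposing this method):

ids = [1, 2, 3, 4, 5, 6]
pair_ids = [7, 8]

# Remove 2 tokens with stride 1: both come off `ids`, the longer sequence.
# The first removal takes a (stride + 1)-token window from the end, each later
# removal one more token, so len(overflowing) == num_tokens_to_remove + stride.
new_ids, new_pair_ids, overflowing = tokenizer.truncate_sequences(
    ids, pair_ids=pair_ids, num_tokens_to_remove=2, truncation_strategy="longest_first", stride=1
)
assert new_ids == [1, 2, 3, 4]
assert new_pair_ids == [7, 8]
assert overflowing == [5, 6, 5]  # window [5, 6], then the newly exposed 5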
@@ -508,9 +508,7 @@ class TokenizerTesterMixin:
                 self.assertEqual(len(truncated_sequence), total_length - 2)
                 self.assertEqual(truncated_sequence, sequence[:-2])
 
-                self.assertEqual(
-                    len(overflowing_tokens), 0
-                )  # No overflowing tokens when using 'longest' in python tokenizers
+                self.assertEqual(len(overflowing_tokens), 2 + stride)
 
     def test_maximum_encoding_length_pair_input(self):
         tokenizers = self.get_tokenizers(do_lower_case=False, model_max_length=100)
@@ -634,7 +632,39 @@ class TokenizerTesterMixin:
                 self.assertEqual(truncated_sequence, truncated_longest_sequence)
 
                 self.assertEqual(
-                    len(overflowing_tokens), 0
+                    len(overflowing_tokens), 2 + stride
                 )  # No overflowing tokens when using 'longest' in python tokenizers
 
+                information = tokenizer.encode_plus(
+                    seq_0,
+                    seq_1,
+                    max_length=len(sequence) - 2,
+                    add_special_tokens=False,
+                    stride=stride,
+                    truncation=True,
+                    return_overflowing_tokens=True,
+                    # add_prefix_space=False,
+                )
+                # Overflowing tokens are handled quite differently in slow and fast tokenizers
+                if isinstance(tokenizer, PreTrainedTokenizerFast):
+                    truncated_sequence = information["input_ids"][0]
+                    overflowing_tokens = information["input_ids"][1]
+                    self.assertEqual(len(information["input_ids"]), 2)
+
+                    self.assertEqual(len(truncated_sequence), len(sequence) - 2)
+                    self.assertEqual(truncated_sequence, truncated_longest_sequence)
+
+                    self.assertEqual(len(overflowing_tokens), 2 + stride + len(smallest))
+                    self.assertEqual(overflowing_tokens, overflow_longest_sequence)
+                else:
+                    truncated_sequence = information["input_ids"]
+                    overflowing_tokens = information["overflowing_tokens"]
+
+                    self.assertEqual(len(truncated_sequence), len(sequence) - 2)
+                    self.assertEqual(truncated_sequence, truncated_longest_sequence)
+
+                    self.assertEqual(
+                        len(overflowing_tokens), 2 + stride
+                    )  # No overflowing tokens when using 'longest' in python tokenizers
+
                 information_first_truncated = tokenizer.encode_plus(
@@ -643,7 +673,7 @@ class TokenizerTesterMixin:
                     max_length=len(sequence) - 2,
                     add_special_tokens=False,
                     stride=stride,
-                    truncation=True,
+                    truncation="only_first",
                     return_overflowing_tokens=True,
                     # add_prefix_space=False,
                 )
@@ -1293,6 +1323,16 @@ class TokenizerTesterMixin:
                 for key in output.keys():
                     self.assertEqual(output[key], output_sequence[key])
 
+    def test_prepare_for_model(self):
+        tokenizers = self.get_tokenizers(do_lower_case=False)
+        for tokenizer in tokenizers:
+            string_sequence = "Testing the prepare_for_model method."
+            ids = tokenizer.encode(string_sequence, add_special_tokens=False)
+            input_dict = tokenizer.encode_plus(string_sequence)
+            prepared_input_dict = tokenizer.prepare_for_model(ids)
+
+            self.assertEqual(input_dict, prepared_input_dict)
+
     @require_torch
     @require_tf
     def test_batch_encode_plus_tensors(self):
@@ -90,6 +90,7 @@ class CommonFastTokenizerTest(unittest.TestCase):
         self.assert_embeded_special_tokens(tokenizer_r, tokenizer_p)
         self.assert_padding(tokenizer_r, tokenizer_p)
         self.assert_create_token_type_ids(tokenizer_r, tokenizer_p)
+        self.assert_prepare_for_model(tokenizer_r, tokenizer_p)
         # TODO: enable for v3.0.0
         # self.assert_empty_output_no_special_tokens(tokenizer_r, tokenizer_p)
 
@@ -709,6 +710,12 @@ class CommonFastTokenizerTest(unittest.TestCase):
             for i_no, i_with in zip(no_special_tokens[key], with_special_tokens[key]):
                 self.assertEqual(len(i_no), len(i_with) - simple_num_special_tokens_to_add)
 
+    def assert_prepare_for_model(self, tokenizer_r, tokenizer_p):
+        string_sequence = "Asserting that both tokenizers are equal"
+        python_output = tokenizer_p.prepare_for_model(tokenizer_p.encode(string_sequence))
+        rust_output = tokenizer_r.prepare_for_model(tokenizer_r.encode(string_sequence))
+        self.assertEqual(python_output, rust_output)
+
 
 class WordPieceFastTokenizerTest(CommonFastTokenizerTest):
     """