Always truncate argument in the encode method

This commit is contained in:
LysandreJik 2019-09-30 10:20:14 -04:00
parent 7af0777910
commit 7c789c337d
2 changed files with 48 additions and 12 deletions

View File

@ -232,6 +232,23 @@ class CommonTestCases:
assert len(truncated_sequence) == total_length - 2
assert truncated_sequence == tokenizer.add_special_tokens_single_sequence(sequence[:-2])
def test_always_truncate(self):
tokenizer = self.get_tokenizer()
seq_0 = "This is a sentence to be encoded."
length_single_sequence = len(tokenizer.encode(seq_0))
length = len(tokenizer.encode(seq_0, seq_0, add_special_tokens=True))
not_truncated = tokenizer.encode(seq_0, seq_0, add_special_tokens=True, max_length=length_single_sequence)
truncated = tokenizer.encode(
seq_0, seq_0,
max_length=length_single_sequence,
add_special_tokens=True,
always_truncate=True
)
assert truncated == not_truncated[:length_single_sequence - length]
def test_maximum_encoding_length_pair_input(self):
tokenizer = self.get_tokenizer()

View File

@ -693,14 +693,15 @@ class PreTrainedTokenizer(object):
raise NotImplementedError
def encode(self,
text,
text_pair=None,
add_special_tokens=False,
max_length=None,
stride=0,
truncate_first_sequence=True,
return_tensors=None,
**kwargs):
text,
text_pair=None,
add_special_tokens=False,
max_length=None,
stride=0,
truncate_first_sequence=True,
return_tensors=None,
always_truncate=False,
**kwargs):
"""
Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
@ -721,6 +722,8 @@ class PreTrainedTokenizer(object):
from the main sequence returned. The value of this argument defined the number of additional tokens.
truncate_first_sequence: if there is a specified max_length, this flag will choose which sequence
will be truncated.
always_truncate: if set to True, will always truncate the sequences when overflowing, even if one of the
sequences may be lost in the process.
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
or PyTorch torch.Tensor instead of a list of python integers.
**kwargs: passed to the `self.tokenize()` method
@ -732,6 +735,7 @@ class PreTrainedTokenizer(object):
stride=stride,
truncate_first_sequence=truncate_first_sequence,
return_tensors=return_tensors,
always_truncate=always_truncate,
**kwargs)
return encoded_inputs["input_ids"]
@ -744,6 +748,7 @@ class PreTrainedTokenizer(object):
stride=0,
truncate_first_sequence=True,
return_tensors=None,
always_truncate=False,
**kwargs):
"""
Returns a dictionary containing the encoded sequence or sequence pair and additional informations:
@ -764,6 +769,8 @@ class PreTrainedTokenizer(object):
from the main sequence returned. The value of this argument defined the number of additional tokens.
truncate_first_sequence: if there is a specified max_length, this flag will choose which sequence
will be truncated.
always_truncate: if set to True, will always truncate the sequences when overflowing, even if one of the
sequences may be lost in the process.
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
or PyTorch torch.Tensor instead of a list of python integers.
**kwargs: passed to the `self.tokenize()` method
@ -788,11 +795,12 @@ class PreTrainedTokenizer(object):
add_special_tokens=add_special_tokens,
stride=stride,
truncate_first_sequence=truncate_first_sequence,
always_truncate=always_truncate,
return_tensors=return_tensors)
def prepare_for_model(self, ids, pair_ids=None, max_length=None, add_special_tokens=False, stride=0,
truncate_first_sequence=True, return_tensors=None):
truncate_first_sequence=True, always_truncate=False, return_tensors=None):
"""
Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model.
It adds special tokens, truncates
@ -812,6 +820,8 @@ class PreTrainedTokenizer(object):
truncate_first_sequence: if set to `True` and an optional second list of input ids is provided,
alongside a specified `max_length`, will truncate the first sequence if the total size is superior
than the specified `max_length`. If set to `False`, will truncate the second sequence instead.
always_truncate: if set to True, will always truncate the sequences when overflowing, even if one of the
sequences may be lost in the process.
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
or PyTorch torch.Tensor instead of a list of python integers.
@ -826,9 +836,14 @@ class PreTrainedTokenizer(object):
if max_length:
n_added_tokens = self.num_added_tokens(pair=pair) if add_special_tokens else 0
if pair and n_added_tokens + (len_pair_ids if truncate_first_sequence else len_ids) >= max_length:
logger.warning(
"You supplied a pair of sequence in which the sequence that will not be truncated is longer than the maximum specified length."
"This pair of sequences will not be truncated.")
if always_truncate:
logger.warning(
"You supplied a pair of sequence in which the sequence that will not be truncated is longer than the maximum specified length. "
"This pair of sequences will be truncated but one of the sequences may not be present in the resulting list of ids.")
else:
logger.warning(
"You supplied a pair of sequence in which the sequence that will not be truncated is longer than the maximum specified length. "
"This pair of sequences will not be truncated.")
else:
if n_added_tokens + len_ids + len_pair_ids > max_length:
if truncate_first_sequence or not pair:
@ -860,6 +875,10 @@ class PreTrainedTokenizer(object):
encoded_inputs["input_ids"] = sequence
encoded_inputs["token_type_ids"] = token_type_ids
if always_truncate and len(encoded_inputs["input_ids"]) > max_length:
encoded_inputs["input_ids"] = encoded_inputs["input_ids"][:max_length]
encoded_inputs["token_type_ids"] = encoded_inputs["token_type_ids"][:max_length]
return encoded_inputs
def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1):