Testing that batch_encode_plus is the same as encode_plus (#2973)
* Testing that encode_plus and batch_encode_plus behave the same way Spoiler alert: they don't * Testing rest of arguments in batch_encode_plus * Test tensor return in batch_encode_plus * Addressing Sam's comments * flake8 * Simplified with `num_added_tokens`
This commit is contained in:
parent
17c45c39ed
commit
21d8b6a33e
|
@ -98,6 +98,12 @@ class T5Tokenizer(PreTrainedTokenizer):
|
|||
additional_special_tokens=additional_special_tokens,
|
||||
**kwargs,
|
||||
)
|
||||
self.max_len_single_sentence = (
|
||||
self.max_len
|
||||
) # no default special tokens - you can update this value if you add special tokens
|
||||
self.max_len_sentences_pair = (
|
||||
self.max_len
|
||||
) # no default special tokens - you can update this value if you add special tokens
|
||||
|
||||
try:
|
||||
import sentencepiece as spm
|
||||
|
|
|
@ -14,7 +14,6 @@
|
|||
# limitations under the License.
|
||||
"""Tokenization classes for OpenAI GPT."""
|
||||
|
||||
|
||||
import copy
|
||||
import itertools
|
||||
import json
|
||||
|
@ -153,6 +152,18 @@ class PreTrainedTokenizer(object):
|
|||
|
||||
padding_side = "right"
|
||||
|
||||
# Error text used when a batch cannot be converted to a tensor because the
# encoded values contain ``None`` (i.e. padding was requested but the
# tokenizer has no padding token).
NO_PAD_TOKEN_FOR_BATCH_MSG = (
    "No padding token is set for this model, therefore no batch can be made with uneven "
    "sequences. Set a padding token or adjust the lengths of the sequences building the "
    "batch so that every sequence is of the same length."
)

# Error text used when a batch cannot be converted to a tensor because its
# sequences have different lengths.
# NOTE: adjacent string literals are concatenated verbatim, so every fragment
# except the last must end with a space.
UNEVEN_SEQUENCES_FOR_BATCH_MSG = (
    "The sequences building the batch are not of the same size, no tensor "
    "can be built. Set `pad_to_max_length=True` to pad the smaller sequences "
    "up to the larger sequence's length."
)
|
||||
|
||||
@property
|
||||
def bos_token(self):
|
||||
""" Beginning of sentence token (string). Log an error if used while not having been set. """
|
||||
|
@ -1020,14 +1031,18 @@ class PreTrainedTokenizer(object):
|
|||
def batch_encode_plus(
|
||||
self,
|
||||
batch_text_or_text_pairs=None,
|
||||
add_special_tokens=False,
|
||||
add_special_tokens=True,
|
||||
max_length=None,
|
||||
stride=0,
|
||||
truncation_strategy="longest_first",
|
||||
pad_to_max_length=False,
|
||||
return_tensors=None,
|
||||
return_input_lengths=False,
|
||||
return_attention_masks=False,
|
||||
return_token_type_ids=True,
|
||||
return_attention_masks=True,
|
||||
return_overflowing_tokens=False,
|
||||
return_special_tokens_masks=False,
|
||||
return_offsets_mapping=False,
|
||||
return_input_lengths=False,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
|
@ -1050,14 +1065,54 @@ class PreTrainedTokenizer(object):
|
|||
- 'only_first': Only truncate the first sequence
|
||||
- 'only_second': Only truncate the second sequence
|
||||
- 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length)
|
||||
pad_to_max_length: if set to True, the returned sequences will be padded according to the model's padding side and
|
||||
padding index, up to their max length. If no max length is specified, the padding is done up to the model's max length.
|
||||
The tokenizer padding sides are handled by the class attribute `padding_side` which can be set to the following strings:
|
||||
- 'left': pads on the left of the sequences
|
||||
- 'right': pads on the right of the sequences
|
||||
Defaults to False: no padding.
|
||||
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
|
||||
or PyTorch torch.Tensor instead of a list of python integers.
|
||||
return_input_lengths: (optional) If set the resulting dictionary will include the length of each sample
|
||||
return_attention_masks: (optional) Set to False to avoid returning the attention mask (default True)
|
||||
return_offsets_mapping: (optional) Not available, should be set to False or it will throw NotImplementedError
|
||||
**kwargs: passed to the `self.tokenize()` method
|
||||
|
||||
Return:
|
||||
A Dictionary of shape::
|
||||
|
||||
{
|
||||
input_ids: list[List[int]],
|
||||
token_type_ids: list[List[int]] if return_token_type_ids is True (default)
|
||||
attention_mask: list[List[int]] if return_attention_mask is True (default)
|
||||
overflowing_tokens: list[List[int]] if a ``max_length`` is specified and return_overflowing_tokens is True
|
||||
num_truncated_tokens: List[int] if a ``max_length`` is specified and return_overflowing_tokens is True
|
||||
special_tokens_mask: list[List[int]] if ``add_special_tokens`` is set to ``True`` and return_special_tokens_mask is True
|
||||
}
|
||||
|
||||
With the fields:
|
||||
``input_ids``: list of token ids to be fed to a model
|
||||
``token_type_ids``: list of token type ids to be fed to a model
|
||||
``attention_mask``: list of indices specifying which tokens should be attended to by the model
|
||||
``overflowing_tokens``: list of overflowing tokens if a max length is specified.
|
||||
``num_truncated_tokens``: number of overflowing tokens if a ``max_length`` is specified
|
||||
``special_tokens_mask``: if adding special tokens, this is a list of [0, 1], with 0 specifying special added
|
||||
tokens and 1 specifying sequence tokens.
|
||||
"""
|
||||
|
||||
def get_input_ids(text):
|
||||
if isinstance(text, str):
|
||||
tokens = self.tokenize(text, add_special_tokens=add_special_tokens, **kwargs)
|
||||
return self.convert_tokens_to_ids(tokens)
|
||||
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
|
||||
return self.convert_tokens_to_ids(text)
|
||||
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
|
||||
return text
|
||||
else:
|
||||
raise ValueError(
|
||||
"Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
|
||||
)
|
||||
|
||||
if return_offsets_mapping:
|
||||
raise NotImplementedError(
|
||||
"return_offset_mapping is not available when using Python tokenizers."
|
||||
|
@ -1067,21 +1122,47 @@ class PreTrainedTokenizer(object):
|
|||
"https://github.com/huggingface/transformers/pull/2674"
|
||||
)
|
||||
|
||||
batch_outputs = {}
|
||||
input_ids = []
|
||||
for ids_or_pair_ids in batch_text_or_text_pairs:
|
||||
if isinstance(ids_or_pair_ids, (list, tuple)):
|
||||
assert len(ids_or_pair_ids) == 2
|
||||
ids, pair_ids = ids_or_pair_ids
|
||||
else:
|
||||
ids, pair_ids = ids_or_pair_ids, None
|
||||
outputs = self.encode_plus(
|
||||
ids,
|
||||
pair_ids,
|
||||
add_special_tokens=add_special_tokens,
|
||||
|
||||
first_ids = get_input_ids(ids)
|
||||
second_ids = get_input_ids(pair_ids) if pair_ids is not None else None
|
||||
input_ids.append((first_ids, second_ids))
|
||||
|
||||
if max_length is None and pad_to_max_length:
|
||||
|
||||
def total_sequence_length(input_pairs):
|
||||
first_ids, second_ids = input_pairs
|
||||
return len(first_ids) + (
|
||||
self.num_added_tokens()
|
||||
if second_ids is None
|
||||
else (len(second_ids) + self.num_added_tokens(pair=True))
|
||||
)
|
||||
|
||||
max_length = max([total_sequence_length(ids) for ids in input_ids])
|
||||
|
||||
batch_outputs = {}
|
||||
for first_ids, second_ids in input_ids:
|
||||
# Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by
|
||||
# the model. It adds special tokens, truncates sequences if overflowing while taking into account
|
||||
# the special tokens and manages a window stride for overflowing tokens
|
||||
outputs = self.prepare_for_model(
|
||||
first_ids,
|
||||
pair_ids=second_ids,
|
||||
max_length=max_length,
|
||||
pad_to_max_length=pad_to_max_length,
|
||||
add_special_tokens=add_special_tokens,
|
||||
stride=stride,
|
||||
truncation_strategy=truncation_strategy,
|
||||
return_tensors=None,
|
||||
return_attention_mask=return_attention_masks,
|
||||
return_token_type_ids=return_token_type_ids,
|
||||
return_overflowing_tokens=return_overflowing_tokens,
|
||||
return_special_tokens_mask=return_special_tokens_masks,
|
||||
)
|
||||
|
||||
# Append the non-padded length to the output
|
||||
|
@ -1093,31 +1174,28 @@ class PreTrainedTokenizer(object):
|
|||
batch_outputs[key] = []
|
||||
batch_outputs[key].append(value)
|
||||
|
||||
# Compute longest sequence size
|
||||
max_seq_len = max(map(len, batch_outputs["input_ids"]))
|
||||
|
||||
if return_attention_masks:
|
||||
# Allow the model to not give any special attention to padded input
|
||||
batch_outputs["attention_mask"] = [[0] * len(v) for v in batch_outputs["input_ids"]]
|
||||
|
||||
if return_tensors is not None:
|
||||
|
||||
# Do the tensor conversion in batch
|
||||
for key, value in batch_outputs.items():
|
||||
|
||||
padded_value = value
|
||||
# verify that the tokenizer has a pad_token_id
|
||||
if key != "input_len" and self._pad_token is not None:
|
||||
# Padding handle
|
||||
padded_value = [
|
||||
v + [self.pad_token_id if key == "input_ids" else 1] * (max_seq_len - len(v))
|
||||
for v in padded_value
|
||||
]
|
||||
|
||||
if return_tensors == "tf" and is_tf_available():
|
||||
batch_outputs[key] = tf.constant(padded_value)
|
||||
try:
|
||||
batch_outputs[key] = tf.constant(value)
|
||||
except ValueError:
|
||||
if None in [item for sequence in value for item in sequence]:
|
||||
raise ValueError(self.NO_PAD_TOKEN_FOR_BATCH_MSG)
|
||||
else:
|
||||
raise ValueError(self.UNEVEN_SEQUENCES_FOR_BATCH_MSG)
|
||||
elif return_tensors == "pt" and is_torch_available():
|
||||
batch_outputs[key] = torch.tensor(padded_value)
|
||||
try:
|
||||
batch_outputs[key] = torch.tensor(value)
|
||||
except ValueError:
|
||||
raise ValueError(self.UNEVEN_SEQUENCES_FOR_BATCH_MSG)
|
||||
except RuntimeError:
|
||||
if None in [item for sequence in value for item in sequence]:
|
||||
raise ValueError(self.NO_PAD_TOKEN_FOR_BATCH_MSG)
|
||||
else:
|
||||
raise
|
||||
elif return_tensors is not None:
|
||||
logger.warning(
|
||||
"Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format(
|
||||
|
@ -1125,13 +1203,6 @@ class PreTrainedTokenizer(object):
|
|||
)
|
||||
)
|
||||
|
||||
# encoder_attention_mask requires 1 for real token, 0 for padding, just invert value
|
||||
if return_attention_masks:
|
||||
if is_tf_available():
|
||||
batch_outputs["attention_mask"] = tf.abs(batch_outputs["attention_mask"] - 1)
|
||||
else:
|
||||
batch_outputs["attention_mask"] = torch.abs(batch_outputs["attention_mask"] - 1)
|
||||
|
||||
return batch_outputs
|
||||
|
||||
def prepare_for_model(
|
||||
|
|
|
@ -19,6 +19,8 @@ import pickle
|
|||
import shutil
|
||||
import tempfile
|
||||
|
||||
from tests.utils import require_tf, require_torch
|
||||
|
||||
|
||||
class TokenizerTesterMixin:
|
||||
|
||||
|
@ -40,6 +42,15 @@ class TokenizerTesterMixin:
|
|||
def get_input_output_texts(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
def convert_batch_encode_plus_format_to_encode_plus(batch_encode_plus_sequences):
    """Split a ``batch_encode_plus`` result into one ``encode_plus``-style dict per sample.

    Switches from the batch format ``{'input_ids': [[...], [...]], ...}`` to the
    per-sample format ``[{'input_ids': [...], ...}, {'input_ids': [...], ...}]``.

    Args:
        batch_encode_plus_sequences: mapping from output key to a list holding
            one entry per sample of the batch.

    Returns:
        A list with one dict per sample, each carrying the same keys as the
        input mapping. An empty mapping yields an empty list.

    Raises:
        IndexError: if some key holds fewer entries than the first key.
    """
    if not batch_encode_plus_sequences:
        return []

    # The number of samples is the length of a per-key list, NOT the number of
    # keys in the mapping (the two only coincide by accident).
    batch_size = len(next(iter(batch_encode_plus_sequences.values())))
    return [
        {key: values[index] for key, values in batch_encode_plus_sequences.items()}
        for index in range(batch_size)
    ]
|
||||
|
||||
def test_tokenizers_common_properties(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
attributes_list = [
|
||||
|
@ -535,11 +546,8 @@ class TokenizerTesterMixin:
|
|||
# we're loading an S3 configuration from a pre-trained identifier, and we have no way of testing those today.
|
||||
|
||||
tokenizer = self.get_tokenizer(random_argument=True)
|
||||
print(tokenizer.init_kwargs)
|
||||
assert tokenizer.init_kwargs["random_argument"] is True
|
||||
new_tokenizer = self.get_tokenizer(random_argument=False)
|
||||
print(tokenizer.init_kwargs)
|
||||
print(new_tokenizer.init_kwargs)
|
||||
assert tokenizer.init_kwargs["random_argument"] is True
|
||||
assert new_tokenizer.init_kwargs["random_argument"] is False
|
||||
|
||||
|
@ -562,3 +570,101 @@ class TokenizerTesterMixin:
|
|||
for word, ind in vocab.items():
|
||||
self.assertEqual(tokenizer.convert_tokens_to_ids(word), ind)
|
||||
self.assertEqual(tokenizer.convert_ids_to_tokens(ind), word)
|
||||
|
||||
def test_batch_encode_plus_batch_sequence_length(self):
    """Check that ``batch_encode_plus`` matches per-sequence ``encode_plus``, unpadded and padded."""
    tok = self.get_tokenizer()
    texts = [
        "Testing batch encode plus",
        "Testing batch encode plus with different sequence lengths",
        "Testing batch encode plus with different sequence lengths correctly pads",
    ]

    # Unpadded: every sample of the batch must equal its single-sequence encoding
    singles = []
    for text in texts:
        singles.append(tok.encode_plus(text, pad_to_max_length=False))
    batch = tok.batch_encode_plus(texts)
    batch_as_singles = self.convert_batch_encode_plus_format_to_encode_plus(batch)
    self.assertListEqual(singles, batch_as_singles)

    # Padded: the batch call is given no max_length, while each single call is
    # padded explicitly to the longest unpadded encoding; both must agree
    longest = max(len(single["input_ids"]) for single in singles)

    padded_singles = []
    for text in texts:
        padded_singles.append(tok.encode_plus(text, pad_to_max_length=True, max_length=longest))
    padded_batch = tok.batch_encode_plus(texts, pad_to_max_length=True)
    padded_batch_as_singles = self.convert_batch_encode_plus_format_to_encode_plus(padded_batch)
    self.assertListEqual(padded_singles, padded_batch_as_singles)
|
||||
|
||||
def test_batch_encode_plus_padding(self):
    """Check that padded outputs of ``batch_encode_plus`` and ``encode_plus`` are equivalent."""
    texts = [
        "Testing batch encode plus",
        "Testing batch encode plus with different sequence lengths",
        "Testing batch encode plus with different sequence lengths correctly pads",
    ]
    target_length = 100

    # First pass: the tokenizer's own padding side is left untouched (the
    # "right padding" case). Second pass: padding side forced to "left".
    for forced_side in (None, "left"):
        tok = self.get_tokenizer()
        if forced_side is not None:
            tok.padding_side = forced_side

        per_sequence = []
        for text in texts:
            per_sequence.append(tok.encode_plus(text, pad_to_max_length=True, max_length=target_length))

        batched = tok.batch_encode_plus(texts, pad_to_max_length=True, max_length=target_length)
        batched_as_singles = self.convert_batch_encode_plus_format_to_encode_plus(batched)
        self.assertListEqual(per_sequence, batched_as_singles)
|
||||
|
||||
@require_torch
@require_tf
def test_batch_encode_plus_tensors(self):
    """Check tensor output of ``batch_encode_plus`` for both PyTorch and TensorFlow.

    Without padding, uneven sequences are expected to raise ``ValueError``.
    With padding, a tokenizer without a pad token id is expected to raise
    ``ValueError``; otherwise the PyTorch, TensorFlow and plain-list outputs
    must all hold the same values.
    """
    tokenizer = self.get_tokenizer()
    sequences = [
        "Testing batch encode plus",
        "Testing batch encode plus with different sequence lengths",
        "Testing batch encode plus with different sequence lengths correctly pads",
    ]

    # A tensor cannot be built from sequences which are not the same size
    self.assertRaises(ValueError, tokenizer.batch_encode_plus, sequences, return_tensors="pt")
    self.assertRaises(ValueError, tokenizer.batch_encode_plus, sequences, return_tensors="tf")

    if tokenizer.pad_token_id is None:
        self.assertRaises(
            ValueError, tokenizer.batch_encode_plus, sequences, pad_to_max_length=True, return_tensors="pt"
        )
        self.assertRaises(
            ValueError, tokenizer.batch_encode_plus, sequences, pad_to_max_length=True, return_tensors="tf"
        )
    else:
        pytorch_tensor = tokenizer.batch_encode_plus(sequences, pad_to_max_length=True, return_tensors="pt")
        tensorflow_tensor = tokenizer.batch_encode_plus(sequences, pad_to_max_length=True, return_tensors="tf")
        encoded_sequences = tokenizer.batch_encode_plus(sequences, pad_to_max_length=True)

        for key in encoded_sequences.keys():
            pytorch_value = pytorch_tensor[key].tolist()
            tensorflow_value = tensorflow_tensor[key].numpy().tolist()
            encoded_value = encoded_sequences[key]

            # assertEqual's third positional argument is the failure message,
            # so the three values have to be compared pairwise.
            self.assertEqual(pytorch_value, encoded_value)
            self.assertEqual(tensorflow_value, encoded_value)
|
||||
|
|
Loading…
Reference in New Issue