From 21d8b6a33ebf96680b6a0aabd27fa7eaa068da93 Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Mon, 24 Feb 2020 12:09:46 -0500 Subject: [PATCH] Testing that batch_encode_plus is the same as encode_plus (#2973) * Testing that encode_plus and batch_encode_plus behave the same way Spoiler alert: they don't * Testing rest of arguments in batch_encode_plus * Test tensor return in batch_encode_plus * Addressing Sam's comments * flake8 * Simplified with `num_added_tokens` --- src/transformers/tokenization_t5.py | 6 ++ src/transformers/tokenization_utils.py | 143 ++++++++++++++++++------- tests/test_tokenization_common.py | 112 ++++++++++++++++++- 3 files changed, 222 insertions(+), 39 deletions(-) diff --git a/src/transformers/tokenization_t5.py b/src/transformers/tokenization_t5.py index 1aa5df38ad..74532a1f15 100644 --- a/src/transformers/tokenization_t5.py +++ b/src/transformers/tokenization_t5.py @@ -98,6 +98,12 @@ class T5Tokenizer(PreTrainedTokenizer): additional_special_tokens=additional_special_tokens, **kwargs, ) + self.max_len_single_sentence = ( + self.max_len + ) # no default special tokens - you can update this value if you add special tokens + self.max_len_sentences_pair = ( + self.max_len + ) # no default special tokens - you can update this value if you add special tokens try: import sentencepiece as spm diff --git a/src/transformers/tokenization_utils.py b/src/transformers/tokenization_utils.py index 754a4a0932..73fcb79c97 100644 --- a/src/transformers/tokenization_utils.py +++ b/src/transformers/tokenization_utils.py @@ -14,7 +14,6 @@ # limitations under the License. """Tokenization classes for OpenAI GPT.""" - import copy import itertools import json @@ -153,6 +152,18 @@ class PreTrainedTokenizer(object): padding_side = "right" + NO_PAD_TOKEN_FOR_BATCH_MSG = ( + "No padding token is set for this model, therefore no batch can be made with uneven " + "sequences. Set a padding token or adjust the lengths of the sequences building the " + "batch so that every sequence is of the same length." + ) + + UNEVEN_SEQUENCES_FOR_BATCH_MSG = ( + "The sequences building the batch are not of the same size, no tensor " + "can be built. Set `pad_to_max_length=True` to pad the smaller sequences" + "up to the larger sequence's length." + ) + @property def bos_token(self): """ Beginning of sentence token (string). Log an error if used while not having been set. """ @@ -1020,14 +1031,18 @@ class PreTrainedTokenizer(object): def batch_encode_plus( self, batch_text_or_text_pairs=None, - add_special_tokens=False, + add_special_tokens=True, max_length=None, stride=0, truncation_strategy="longest_first", + pad_to_max_length=False, return_tensors=None, - return_input_lengths=False, - return_attention_masks=False, + return_token_type_ids=True, + return_attention_masks=True, + return_overflowing_tokens=False, + return_special_tokens_masks=False, return_offsets_mapping=False, + return_input_lengths=False, **kwargs ): """ @@ -1050,14 +1065,54 @@ class PreTrainedTokenizer(object): - 'only_first': Only truncate the first sequence - 'only_second': Only truncate the second sequence - 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length) + pad_to_max_length: if set to True, the returned sequences will be padded according to the model's padding side and + padding index, up to their max length. If no max length is specified, the padding is done up to the model's max length. 
+ The tokenizer padding sides are handled by the class attribute `padding_side` which can be set to the following strings: + - 'left': pads on the left of the sequences + - 'right': pads on the right of the sequences + Defaults to False: no padding. return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant or PyTorch torch.Tensor instead of a list of python integers. return_input_lengths: (optional) If set the resulting dictionary will include the length of each sample return_attention_masks: (optional) Set to True to return the attention mask (default False) return_offsets_mapping: (optional) Not available, should be set to False or it will throw NotImplementError **kwargs: passed to the `self.tokenize()` method + + Return: + A Dictionary of shape:: + + { + input_ids: list[List[int]], + token_type_ids: list[List[int]] if return_token_type_ids is True (default) + attention_mask: list[List[int]] if return_attention_mask is True (default) + overflowing_tokens: list[List[int]] if a ``max_length`` is specified and return_overflowing_tokens is True + num_truncated_tokens: List[int] if a ``max_length`` is specified and return_overflowing_tokens is True + special_tokens_mask: list[List[int]] if ``add_special_tokens`` if set to ``True`` and return_special_tokens_mask is True + } + + With the fields: + ``input_ids``: list of token ids to be fed to a model + ``token_type_ids``: list of token type ids to be fed to a model + ``attention_mask``: list of indices specifying which tokens should be attended to by the model + ``overflowing_tokens``: list of overflowing tokens if a max length is specified. + ``num_truncated_tokens``: number of overflowing tokens a ``max_length`` is specified + ``special_tokens_mask``: if adding special tokens, this is a list of [0, 1], with 0 specifying special added + tokens and 1 specifying sequence tokens. """ + def get_input_ids(text): + if isinstance(text, str): + tokens = self.tokenize(text, add_special_tokens=add_special_tokens, **kwargs) + return self.convert_tokens_to_ids(tokens) + elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str): + return self.convert_tokens_to_ids(text) + elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int): + return text + else: + raise ValueError( + "Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers." + ) + if return_offsets_mapping: raise NotImplementedError( "return_offset_mapping is not available when using Python tokenizers." 
@@ -1067,21 +1122,47 @@ class PreTrainedTokenizer(object): "https://github.com/huggingface/transformers/pull/2674" ) - batch_outputs = {} + input_ids = [] for ids_or_pair_ids in batch_text_or_text_pairs: if isinstance(ids_or_pair_ids, (list, tuple)): assert len(ids_or_pair_ids) == 2 ids, pair_ids = ids_or_pair_ids else: ids, pair_ids = ids_or_pair_ids, None - outputs = self.encode_plus( - ids, - pair_ids, - add_special_tokens=add_special_tokens, + + first_ids = get_input_ids(ids) + second_ids = get_input_ids(pair_ids) if pair_ids is not None else None + input_ids.append((first_ids, second_ids)) + + if max_length is None and pad_to_max_length: + + def total_sequence_length(input_pairs): + first_ids, second_ids = input_pairs + return len(first_ids) + ( + self.num_added_tokens() + if second_ids is None + else (len(second_ids) + self.num_added_tokens(pair=True)) + ) + + max_length = max([total_sequence_length(ids) for ids in input_ids]) + + batch_outputs = {} + for first_ids, second_ids in input_ids: + # Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by + # the model. It adds special tokens, truncates sequences if overflowing while taking into account + # the special tokens and manages a window stride for overflowing tokens + outputs = self.prepare_for_model( + first_ids, + pair_ids=second_ids, max_length=max_length, + pad_to_max_length=pad_to_max_length, + add_special_tokens=add_special_tokens, stride=stride, truncation_strategy=truncation_strategy, - return_tensors=None, + return_attention_mask=return_attention_masks, + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_masks, ) # Append the non-padded length to the output @@ -1093,31 +1174,28 @@ class PreTrainedTokenizer(object): batch_outputs[key] = [] batch_outputs[key].append(value) - # Compute longest sequence size - max_seq_len = max(map(len, batch_outputs["input_ids"])) - - if return_attention_masks: - # Allow the model to not give any special attention to padded input - batch_outputs["attention_mask"] = [[0] * len(v) for v in batch_outputs["input_ids"]] - if return_tensors is not None: # Do the tensor conversion in batch for key, value in batch_outputs.items(): - - padded_value = value - # verify that the tokenizer has a pad_token_id - if key != "input_len" and self._pad_token is not None: - # Padding handle - padded_value = [ - v + [self.pad_token_id if key == "input_ids" else 1] * (max_seq_len - len(v)) - for v in padded_value - ] - if return_tensors == "tf" and is_tf_available(): - batch_outputs[key] = tf.constant(padded_value) + try: + batch_outputs[key] = tf.constant(value) + except ValueError: + if None in [item for sequence in value for item in sequence]: + raise ValueError(self.NO_PAD_TOKEN_FOR_BATCH_MSG) + else: + raise ValueError(self.UNEVEN_SEQUENCES_FOR_BATCH_MSG) elif return_tensors == "pt" and is_torch_available(): - batch_outputs[key] = torch.tensor(padded_value) + try: + batch_outputs[key] = torch.tensor(value) + except ValueError: + raise ValueError(self.UNEVEN_SEQUENCES_FOR_BATCH_MSG) + except RuntimeError: + if None in [item for sequence in value for item in sequence]: + raise ValueError(self.NO_PAD_TOKEN_FOR_BATCH_MSG) + else: + raise elif return_tensors is not None: logger.warning( "Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format( @@ -1125,13 +1203,6 @@ class PreTrainedTokenizer(object): ) ) - # encoder_attention_mask 
requires 1 for real token, 0 for padding, just invert value - if return_attention_masks: - if is_tf_available(): - batch_outputs["attention_mask"] = tf.abs(batch_outputs["attention_mask"] - 1) - else: - batch_outputs["attention_mask"] = torch.abs(batch_outputs["attention_mask"] - 1) - return batch_outputs def prepare_for_model( diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index a597d90f04..1ca830004b 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -19,6 +19,8 @@ import pickle import shutil import tempfile +from tests.utils import require_tf, require_torch + class TokenizerTesterMixin: @@ -40,6 +42,15 @@ class TokenizerTesterMixin: def get_input_output_texts(self): raise NotImplementedError + @staticmethod + def convert_batch_encode_plus_format_to_encode_plus(batch_encode_plus_sequences): + # Switch from batch_encode_plus format: {'input_ids': [[...], [...]], ...} + # to the concatenated encode_plus format: [{'input_ids': [...], ...}, {'input_ids': [...], ...}] + return [ + {value: batch_encode_plus_sequences[value][i] for value in batch_encode_plus_sequences.keys()} + for i in range(len(batch_encode_plus_sequences)) + ] + def test_tokenizers_common_properties(self): tokenizer = self.get_tokenizer() attributes_list = [ @@ -535,11 +546,8 @@ class TokenizerTesterMixin: # we're loading an S3 configuration from a pre-trained identifier, and we have no way of testing those today. tokenizer = self.get_tokenizer(random_argument=True) - print(tokenizer.init_kwargs) assert tokenizer.init_kwargs["random_argument"] is True new_tokenizer = self.get_tokenizer(random_argument=False) - print(tokenizer.init_kwargs) - print(new_tokenizer.init_kwargs) assert tokenizer.init_kwargs["random_argument"] is True assert new_tokenizer.init_kwargs["random_argument"] is False @@ -562,3 +570,101 @@ class TokenizerTesterMixin: for word, ind in vocab.items(): self.assertEqual(tokenizer.convert_tokens_to_ids(word), ind) self.assertEqual(tokenizer.convert_ids_to_tokens(ind), word) + + def test_batch_encode_plus_batch_sequence_length(self): + # Tests that all encoded values have the correct size + tokenizer = self.get_tokenizer() + sequences = [ + "Testing batch encode plus", + "Testing batch encode plus with different sequence lengths", + "Testing batch encode plus with different sequence lengths correctly pads", + ] + + encoded_sequences = [tokenizer.encode_plus(sequence, pad_to_max_length=False) for sequence in sequences] + encoded_sequences_batch = tokenizer.batch_encode_plus(sequences) + self.assertListEqual( + encoded_sequences, self.convert_batch_encode_plus_format_to_encode_plus(encoded_sequences_batch) + ) + + maximum_length = len(max([encoded_sequence["input_ids"] for encoded_sequence in encoded_sequences], key=len)) + + encoded_sequences_padded = [ + tokenizer.encode_plus(sequence, pad_to_max_length=True, max_length=maximum_length) + for sequence in sequences + ] + encoded_sequences_batch_padded = tokenizer.batch_encode_plus(sequences, pad_to_max_length=True) + self.assertListEqual( + encoded_sequences_padded, + self.convert_batch_encode_plus_format_to_encode_plus(encoded_sequences_batch_padded), + ) + + def test_batch_encode_plus_padding(self): + # Test that padded sequences are equivalent between batch_encode_plus and encode_plus + + # Right padding tests + tokenizer = self.get_tokenizer() + sequences = [ + "Testing batch encode plus", + "Testing batch encode plus with different sequence lengths", + "Testing batch encode plus with 
different sequence lengths correctly pads", + ] + + max_length = 100 + encoded_sequences = [ + tokenizer.encode_plus(sequence, pad_to_max_length=True, max_length=max_length) for sequence in sequences + ] + encoded_sequences_batch = tokenizer.batch_encode_plus(sequences, pad_to_max_length=True, max_length=max_length) + self.assertListEqual( + encoded_sequences, self.convert_batch_encode_plus_format_to_encode_plus(encoded_sequences_batch) + ) + + # Left padding tests + tokenizer = self.get_tokenizer() + tokenizer.padding_side = "left" + sequences = [ + "Testing batch encode plus", + "Testing batch encode plus with different sequence lengths", + "Testing batch encode plus with different sequence lengths correctly pads", + ] + + max_length = 100 + encoded_sequences = [ + tokenizer.encode_plus(sequence, pad_to_max_length=True, max_length=max_length) for sequence in sequences + ] + encoded_sequences_batch = tokenizer.batch_encode_plus(sequences, pad_to_max_length=True, max_length=max_length) + self.assertListEqual( + encoded_sequences, self.convert_batch_encode_plus_format_to_encode_plus(encoded_sequences_batch) + ) + + @require_torch + @require_tf + def test_batch_encode_plus_tensors(self): + tokenizer = self.get_tokenizer() + sequences = [ + "Testing batch encode plus", + "Testing batch encode plus with different sequence lengths", + "Testing batch encode plus with different sequence lengths correctly pads", + ] + + # A Tensor cannot be build by sequences which are not the same size + self.assertRaises(ValueError, tokenizer.batch_encode_plus, sequences, return_tensors="pt") + self.assertRaises(ValueError, tokenizer.batch_encode_plus, sequences, return_tensors="tf") + + if tokenizer.pad_token_id is None: + self.assertRaises( + ValueError, tokenizer.batch_encode_plus, sequences, pad_to_max_length=True, return_tensors="pt" + ) + self.assertRaises( + ValueError, tokenizer.batch_encode_plus, sequences, pad_to_max_length=True, return_tensors="tf" + ) + else: + pytorch_tensor = tokenizer.batch_encode_plus(sequences, pad_to_max_length=True, return_tensors="pt") + tensorflow_tensor = tokenizer.batch_encode_plus(sequences, pad_to_max_length=True, return_tensors="tf") + encoded_sequences = tokenizer.batch_encode_plus(sequences, pad_to_max_length=True) + + for key in encoded_sequences.keys(): + pytorch_value = pytorch_tensor[key].tolist() + tensorflow_value = tensorflow_tensor[key].numpy().tolist() + encoded_value = encoded_sequences[key] + + self.assertEqual(pytorch_value, tensorflow_value, encoded_value)
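
To illustrate the behaviour this patch pins down, a minimal usage sketch follows (not part of the diff). It assumes transformers at this commit and uses "bert-base-uncased" purely as an illustrative checkpoint: with pad_to_max_length=True and no max_length, batch_encode_plus pads every sequence up to the longest one in the batch and, after this change, agrees with encode_plus.

# Illustrative sketch only; not part of the patch. Assumes transformers at this
# commit; "bert-base-uncased" is just an example checkpoint.
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
sequences = [
    "Testing batch encode plus",
    "Testing batch encode plus with different sequence lengths",
]

# With pad_to_max_length=True and no explicit max_length, every sequence is padded
# up to the longest sequence in the batch (special tokens included).
batch = tokenizer.batch_encode_plus(sequences, pad_to_max_length=True)
batch_max_length = len(batch["input_ids"][0])

# Each row should now match what encode_plus produces for the same sequence padded
# to that length, which is what the new tests assert.
for sequence, input_ids, attention_mask in zip(
    sequences, batch["input_ids"], batch["attention_mask"]
):
    single = tokenizer.encode_plus(sequence, pad_to_max_length=True, max_length=batch_max_length)
    assert input_ids == single["input_ids"]
    assert attention_mask == single["attention_mask"]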
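
The padding_side attribute mentioned in the new docstring decides which end of a sequence receives the pad tokens; the left-padding test above flips it to "left". A short sketch under the same assumptions (transformers at this commit, "bert-base-uncased" as an example checkpoint):

# Illustrative sketch only; not part of the patch.
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
tokenizer.padding_side = "left"  # class default is "right"

batch = tokenizer.batch_encode_plus(
    ["short", "a noticeably longer second sequence"], pad_to_max_length=True
)

# The shorter row is padded at the front, so it starts with pad tokens and its
# attention mask starts with zeros.
assert batch["input_ids"][0][0] == tokenizer.pad_token_id
assert batch["attention_mask"][0][0] == 0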
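
The two new error messages surface in the situations the last test exercises: building a tensor from unpadded rows of different lengths, and asking for padding from a tokenizer that has no pad token. A sketch under the same assumptions; "gpt2" is used only because that checkpoint ships without a pad token:

# Illustrative sketch only; not part of the patch. Assumes transformers at this
# commit; "bert-base-uncased" and "gpt2" are example checkpoints.
from transformers import BertTokenizer, GPT2Tokenizer

sequences = ["a short sequence", "a noticeably longer second sequence"]

bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
try:
    # Unpadded rows of different lengths: no rectangular tensor can be built.
    bert_tokenizer.batch_encode_plus(sequences, return_tensors="pt")
except ValueError as error:
    print(error)  # UNEVEN_SEQUENCES_FOR_BATCH_MSG

gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
try:
    # GPT-2 defines no pad token, so the batch cannot be padded to a common length.
    gpt2_tokenizer.batch_encode_plus(sequences, pad_to_max_length=True, return_tensors="pt")
except ValueError as error:
    print(error)  # NO_PAD_TOKEN_FOR_BATCH_MSG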