From 15478c1287a4e7b52c01730ffb0718243d153600 Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Wed, 9 Sep 2020 12:55:17 +0200 Subject: [PATCH] Batch encode plus and overflowing tokens fails when non-existing overflowing tokens for a sequence (#6677) * Patch and test * Fix tests --- src/transformers/tokenization_utils_base.py | 8 +++++--- tests/test_tokenization_common.py | 12 ++++++++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 43b77951fb..57e578990c 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -2440,6 +2440,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0) # Truncation: Handle max sequence length + overflowing_tokens = [] if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length: ids, pair_ids, overflowing_tokens = self.truncate_sequences( ids, @@ -2448,9 +2449,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): truncation_strategy=truncation_strategy, stride=stride, ) - if return_overflowing_tokens: - encoded_inputs["overflowing_tokens"] = overflowing_tokens - encoded_inputs["num_truncated_tokens"] = total_len - max_length + + if return_overflowing_tokens: + encoded_inputs["overflowing_tokens"] = overflowing_tokens + encoded_inputs["num_truncated_tokens"] = total_len - max_length # Add special tokens if add_special_tokens: diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index 22e845d2c4..26c48b475d 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -1352,6 +1352,18 @@ class TokenizerTesterMixin: self.assertEqual(input_dict, prepared_input_dict) + def test_batch_encode_plus_overflowing_tokens(self): + tokenizers = 
self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + string_sequences = ["Testing the prepare_for_model method.", "Test"] + + if tokenizer.pad_token is None: + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + + tokenizer.batch_encode_plus( + string_sequences, return_overflowing_tokens=True, truncation=True, padding=True, max_length=3 + ) + @require_torch @require_tf def test_batch_encode_plus_tensors(self):