From 74d0bcb6ff692dbaa52da1fdc2b80ece06f5fbe5 Mon Sep 17 00:00:00 2001
From: Lysandre
Date: Tue, 12 Nov 2019 15:27:57 -0500
Subject: [PATCH] Fix special tokens addition in decoder

---
 .../tests/tokenization_tests_commons.py | 20 +++++++++++++++++++
 transformers/tokenization_utils.py      |  2 +-
 2 files changed, 21 insertions(+), 1 deletion(-)

diff --git a/transformers/tests/tokenization_tests_commons.py b/transformers/tests/tokenization_tests_commons.py
index a921696b77..fdaf8cc137 100644
--- a/transformers/tests/tokenization_tests_commons.py
+++ b/transformers/tests/tokenization_tests_commons.py
@@ -160,6 +160,26 @@ class CommonTestCases:
             self.assertEqual(tokens[0], tokenizer.eos_token_id)
             self.assertEqual(tokens[-2], tokenizer.pad_token_id)
 
+        def test_add_special_tokens(self):
+            tokenizer = self.get_tokenizer()
+            input_text, output_text = self.get_input_output_texts()
+
+            special_token = "[SPECIAL TOKEN]"
+
+            tokenizer.add_special_tokens({"cls_token": special_token})
+            encoded_special_token = tokenizer.encode(special_token, add_special_tokens=False)
+            assert len(encoded_special_token) == 1
+
+            text = " ".join([input_text, special_token, output_text])
+            encoded = tokenizer.encode(text, add_special_tokens=False)
+
+            input_encoded = tokenizer.encode(input_text, add_special_tokens=False)
+            output_encoded = tokenizer.encode(output_text, add_special_tokens=False)
+            special_token_id = tokenizer.encode(special_token, add_special_tokens=False)
+            assert encoded == input_encoded + special_token_id + output_encoded
+
+            decoded = tokenizer.decode(encoded, skip_special_tokens=True)
+            assert special_token not in decoded
+
         def test_required_methods_tokenizer(self):
             tokenizer = self.get_tokenizer()
 
diff --git a/transformers/tokenization_utils.py b/transformers/tokenization_utils.py
index cd14cc4582..f37f6f3206 100644
--- a/transformers/tokenization_utils.py
+++ b/transformers/tokenization_utils.py
@@ -1055,7 +1055,7 @@ class PreTrainedTokenizer(object):
             class attributes (cls_token, unk_token...).
         """
         all_toks = self.all_special_tokens
-        all_ids = list(self._convert_token_to_id(t) for t in all_toks)
+        all_ids = self.convert_tokens_to_ids(all_toks)
        return all_ids
 
     @staticmethod
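
Note (not part of the patch): a minimal sketch of the behaviour this change addresses,
assuming a transformers 2.x install and the public "bert-base-uncased" checkpoint; the
checkpoint and the replacement cls_token below are illustrative choices, not taken from
the patch itself.

    # Illustrative only: overriding cls_token with a new string registers it as an
    # added special token rather than one from the base vocabulary.
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    tokenizer.add_special_tokens({"cls_token": "[SPECIAL TOKEN]"})

    ids = tokenizer.encode("hello [SPECIAL TOKEN] world", add_special_tokens=False)

    # Before this patch, all_special_ids was built with _convert_token_to_id, which
    # only consults the base vocabulary, so the added token's id was missing from the
    # list and the token survived decoding. convert_tokens_to_ids also checks the
    # added-tokens map, so skip_special_tokens=True can now strip it.
    print(tokenizer.decode(ids, skip_special_tokens=True))  # expected: "hello world"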