From c3248cf122014dce10c0c8d0e663a95c948493e3 Mon Sep 17 00:00:00 2001
From: LysandreJik
Date: Wed, 11 Dec 2019 12:36:37 -0500
Subject: [PATCH] Tests for all tokenizers

---
 transformers/tests/tokenization_bert_test.py     | 13 -------------
 transformers/tests/tokenization_gpt2_test.py     | 15 ---------------
 transformers/tests/tokenization_tests_commons.py |  9 +++++++++
 3 files changed, 9 insertions(+), 28 deletions(-)

diff --git a/transformers/tests/tokenization_bert_test.py b/transformers/tests/tokenization_bert_test.py
index 77b124cdf2..c503ea5e1e 100644
--- a/transformers/tests/tokenization_bert_test.py
+++ b/transformers/tests/tokenization_bert_test.py
@@ -99,19 +99,6 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester):
         self.assertListEqual(
             tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])
 
-    def test_encode_decode_with_spaces(self):
-        tokenizer = self.get_tokenizer()
-
-        new_toks = ['[ABC]', '[DEF]', 'GHI IHG']
-        tokenizer.add_tokens(new_toks)
-        input = "unwanted running [ABC] [DEF] running unwanted [ABC] GHI IHG unwanted [DEF]"
-        encoded = tokenizer.encode(input)
-        decoded = tokenizer.decode(encoded)
-        self.assertEqual(
-            decoded.lower(),
-            ("[CLS] " + input + " [SEP]").lower()
-        )
-
     def test_is_whitespace(self):
         self.assertTrue(_is_whitespace(u" "))
         self.assertTrue(_is_whitespace(u"\t"))
diff --git a/transformers/tests/tokenization_gpt2_test.py b/transformers/tests/tokenization_gpt2_test.py
index 1b4fe42874..5eae767bdf 100644
--- a/transformers/tests/tokenization_gpt2_test.py
+++ b/transformers/tests/tokenization_gpt2_test.py
@@ -67,20 +67,5 @@ class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester):
         self.assertListEqual(
             tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
 
-    def test_encode_decode_with_spaces(self):
-        tokenizer = self.get_tokenizer()
-
-        new_toks = ['[ABC]', '[DEF]', 'GHI IHG']
-        tokenizer.add_tokens(new_toks)
-        input = "lower newer [ABC] [DEF] newer lower [ABC] GHI IHG newer lower [DEF]"
-        encoded = tokenizer.encode(input)
-        decoded = tokenizer.decode(encoded)
-        self.assertEqual(
-            decoded.lower(),
-            input.lower()
-        )
-
-
 if __name__ == '__main__':
     unittest.main()
diff --git a/transformers/tests/tokenization_tests_commons.py b/transformers/tests/tokenization_tests_commons.py
index c009958135..13e7ae746a 100644
--- a/transformers/tests/tokenization_tests_commons.py
+++ b/transformers/tests/tokenization_tests_commons.py
@@ -232,6 +232,15 @@ class CommonTestCases:
         self.assertNotEqual(len(tokens_2), 0)
         self.assertIsInstance(text_2, (str, unicode))
 
+    def test_encode_decode_with_spaces(self):
+        tokenizer = self.get_tokenizer()
+
+        new_toks = ['[ABC]', '[DEF]', 'GHI IHG']
+        tokenizer.add_tokens(new_toks)
+        input = "[ABC] [DEF] [ABC] GHI IHG [DEF]"
+        encoded = tokenizer.encode(input, add_special_tokens=False)
+        decoded = tokenizer.decode(encoded)
+        self.assertEqual(decoded, input)
 
     def test_pretrained_model_lists(self):
         weights_list = list(self.tokenizer_class.max_model_input_sizes.keys())