Tests for all tokenizers

This commit is contained in:
LysandreJik 2019-12-11 12:36:37 -05:00 committed by Lysandre Debut
parent f2ac50cb55
commit c3248cf122
3 changed files with 9 additions and 28 deletions

View File

@@ -99,19 +99,6 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester):
self.assertListEqual(
tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])
def test_encode_decode_with_spaces(self):
    """Round-trip encode/decode of text mixing added tokens and regular words.

    Adds custom tokens — including one that contains a space ('GHI IHG') —
    and checks that decode(encode(text)) reproduces the text with spacing
    intact. The comparison is case-insensitive because this tokenizer
    lowercases regular words (visible in the .lower() calls), and encode()
    wraps the sequence in [CLS] ... [SEP].
    """
    tokenizer = self.get_tokenizer()
    new_toks = ['[ABC]', '[DEF]', 'GHI IHG']
    tokenizer.add_tokens(new_toks)
    # Renamed from `input`: never shadow the builtin of the same name.
    text = "unwanted running [ABC] [DEF] running unwanted [ABC] GHI IHG unwanted [DEF]"
    encoded = tokenizer.encode(text)
    decoded = tokenizer.decode(encoded)
    self.assertEqual(
        decoded.lower(),
        ("[CLS] " + text + " [SEP]").lower()
    )
def test_is_whitespace(self):
self.assertTrue(_is_whitespace(u" "))
self.assertTrue(_is_whitespace(u"\t"))

View File

@@ -67,20 +67,5 @@ class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester):
self.assertListEqual(
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
def test_encode_decode_with_spaces(self):
    """Round-trip encode/decode of text mixing added tokens and regular words.

    Adds custom tokens — including one that contains a space ('GHI IHG') —
    and checks that decode(encode(text)) reproduces the text with spacing
    intact. Compared case-insensitively; unlike the BERT variant, no special
    tokens are expected around the decoded sequence.
    """
    tokenizer = self.get_tokenizer()
    new_toks = ['[ABC]', '[DEF]', 'GHI IHG']
    tokenizer.add_tokens(new_toks)
    # Renamed from `input`: never shadow the builtin of the same name.
    text = "lower newer [ABC] [DEF] newer lower [ABC] GHI IHG newer lower [DEF]"
    encoded = tokenizer.encode(text)
    decoded = tokenizer.decode(encoded)
    self.assertEqual(
        decoded.lower(),
        text.lower()
    )
# Allow running this test module directly (e.g. `python test_tokenization_gpt2.py`).
if __name__ == '__main__':
    unittest.main()

View File

@@ -232,6 +232,15 @@ class CommonTestCases:
self.assertNotEqual(len(tokens_2), 0)
self.assertIsInstance(text_2, (str, unicode))
def test_encode_decode_with_spaces(self):
    """Shared round-trip test: decode(encode(text)) must equal text exactly.

    Adds custom tokens — including one containing a space ('GHI IHG') — and
    encodes with add_special_tokens=False so the comparison can be an exact
    string equality rather than the case-insensitive checks used by
    tokenizer-specific overrides.
    """
    tokenizer = self.get_tokenizer()
    new_toks = ['[ABC]', '[DEF]', 'GHI IHG']
    tokenizer.add_tokens(new_toks)
    # Renamed from `input`: never shadow the builtin of the same name.
    text = "[ABC] [DEF] [ABC] GHI IHG [DEF]"
    encoded = tokenizer.encode(text, add_special_tokens=False)
    decoded = tokenizer.decode(encoded)
    self.assertEqual(decoded, text)
def test_pretrained_model_lists(self):
weights_list = list(self.tokenizer_class.max_model_input_sizes.keys())