Fix failing tokenizer tests (#31083)

* Fix failing tokenizer tests

* Use small tokenizer

* Fix remaining reference
Lysandre Debut 2024-05-28 13:34:23 +02:00 committed by GitHub
parent 90da0b1c9f
commit a3c7b59e31
1 changed file with 15 additions and 29 deletions


@@ -29,7 +29,7 @@ class CohereTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
     test_rust_tokenizer = True
     test_slow_tokenizer = False
     from_pretrained_vocab_key = "tokenizer_file"
-    from_pretrained_id = "CohereForAI/c4ai-command-r-v01"
+    from_pretrained_id = "hf-internal-testing/tiny-random-CohereForCausalLM"
     special_tokens_map = {
         "bos_token": "<BOS_TOKEN>",
         "eos_token": "<|END_OF_TURN_TOKEN|>",
@@ -39,7 +39,7 @@ class CohereTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
 
     def setUp(self):
         super().setUp()
-        tokenizer = CohereTokenizerFast.from_pretrained("CohereForAI/c4ai-command-r-v01")
+        tokenizer = CohereTokenizerFast.from_pretrained("hf-internal-testing/tiny-random-CohereForCausalLM")
         tokenizer.save_pretrained(self.tmpdirname)
 
     def get_rust_tokenizer(self, **kwargs):
@@ -57,7 +57,10 @@ class CohereTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         tokenizer = self.get_rust_tokenizer()
 
         INPUT_SENTENCES = ["The quick brown fox<|END_OF_TURN_TOKEN|>", "jumps over the lazy dog<|END_OF_TURN_TOKEN|>"]
-        TARGET_TOKENS = [[5, 2162, 6629, 19883, 73388, 255001], [5, 81, 25092, 2515, 1690, 46189, 9507, 255001]]
+        TARGET_TOKENS = [
+            [5, 60, 203, 746, 666, 980, 571, 222, 87, 96, 8],
+            [5, 82, 332, 88, 91, 544, 206, 257, 930, 97, 239, 435, 8],
+        ]
 
         computed_tokens = tokenizer.batch_encode_plus(INPUT_SENTENCES)["input_ids"]
         self.assertListEqual(TARGET_TOKENS, computed_tokens)
@@ -141,34 +144,17 @@ class CohereTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
             ],
         ]
         tokenized_chats = [tokenizer.apply_chat_template(test_chat) for test_chat in test_chats]
+        # fmt: off
         expected_tokens = [
-            [5, 255000, 255008, 5659, 1955, 1671, 19264, 171597, 21, 255001, 255000, 255006, 28339, 8, 255001],
-            [
-                5,
-                255000,
-                255008,
-                5659,
-                1955,
-                1671,
-                19264,
-                171597,
-                21,
-                255001,
-                255000,
-                255006,
-                28339,
-                8,
-                255001,
-                255000,
-                255007,
-                97190,
-                1726,
-                5694,
-                1933,
-                21,
-                255001,
-            ],
+            [5, 36, 99, 59, 60, 41, 58, 60, 71, 55, 46, 71, 60, 61, 58, 54, 71, 60, 55, 51, 45, 54, 99, 38, 36, 99, 59, 65, 59, 60, 45, 53, 71, 60, 55, 51, 45, 54, 99, 38, 65, 243, 394, 204, 336, 84, 88, 887, 374, 216, 74, 286, 22, 8, 36, 99, 59, 60, 41, 58, 60, 71, 55, 46, 71, 60, 61, 58, 54, 71, 60, 55, 51, 45, 54, 99, 38, 36, 99, 61, 59, 45, 58, 71, 60, 55, 51, 45, 54, 99, 38, 48, 420, 87, 9, 8],
+            [5, 36, 99, 59, 60, 41, 58, 60, 71, 55, 46, 71, 60, 61, 58, 54, 71, 60, 55, 51, 45, 54, 99, 38, 36, 99, 59, 65,
+             59, 60, 45, 53, 71, 60, 55, 51, 45, 54, 99, 38, 65, 243, 394, 204, 336, 84, 88, 887, 374, 216, 74, 286, 22, 8,
+             36, 99, 59, 60, 41, 58, 60, 71, 55, 46, 71, 60, 61, 58, 54, 71, 60, 55, 51, 45, 54, 99, 38, 36, 99, 61, 59,
+             45, 58, 71, 60, 55, 51, 45, 54, 99, 38, 48, 420, 87, 9, 8, 36, 99, 59, 60, 41, 58, 60, 71, 55, 46, 71, 60, 61,
+             58, 54, 71, 60, 55, 51, 45, 54, 99, 38, 36, 99, 43, 48, 41, 60, 42, 55, 60, 71, 60, 55, 51, 45, 54, 99, 38,
+             54, 567, 235, 693, 276, 411, 243, 22, 8]
         ]
+        # fmt: on
         for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens):
             self.assertListEqual(tokenized_chat, expected_tokens)
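
Note (not part of the commit): a minimal sketch of how the hard-coded ids above could be regenerated against the tiny test checkpoint. The chat passed to apply_chat_template below is a hypothetical stand-in, not necessarily the test_chats defined earlier in the test file.

# Sketch: regenerate expected token ids with the tiny test tokenizer.
# Assumes a transformers install with Cohere support; `chat` is a hypothetical example.
from transformers import CohereTokenizerFast

tokenizer = CohereTokenizerFast.from_pretrained("hf-internal-testing/tiny-random-CohereForCausalLM")

sentences = ["The quick brown fox<|END_OF_TURN_TOKEN|>", "jumps over the lazy dog<|END_OF_TURN_TOKEN|>"]
print(tokenizer.batch_encode_plus(sentences)["input_ids"])  # candidate TARGET_TOKENS

chat = [
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Nice to meet you."},
]
print(tokenizer.apply_chat_template(chat))  # candidate entry for expected_tokens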