Add tests for fast tokenizers
This commit is contained in:
parent
31c56f2e0b
commit
2818e50569
|
@ -21,6 +21,7 @@ from transformers.tokenization_bert import (
|
|||
VOCAB_FILES_NAMES,
|
||||
BasicTokenizer,
|
||||
BertTokenizer,
|
||||
BertTokenizerFast,
|
||||
WordpieceTokenizer,
|
||||
_is_control,
|
||||
_is_punctuation,
|
||||
|
@ -34,6 +35,7 @@ from .utils import slow
|
|||
class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
|
||||
tokenizer_class = BertTokenizer
|
||||
test_rust_tokenizer = True
|
||||
|
||||
def setUp(self):
|
||||
super(BertTokenizationTest, self).setUp()
|
||||
|
@ -60,6 +62,9 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||
def get_tokenizer(self, **kwargs):
|
||||
return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_rust_tokenizer(self, **kwargs):
|
||||
return BertTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_input_output_texts(self):
|
||||
input_text = "UNwant\u00E9d,running"
|
||||
output_text = "unwanted, running"
|
||||
|
@ -72,6 +77,28 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||
self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
|
||||
self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
|
||||
|
||||
def test_rust_and_python_full_tokenizers(self):
|
||||
if not self.test_rust_tokenizer:
|
||||
return
|
||||
|
||||
tokenizer = self.get_tokenizer()
|
||||
rust_tokenizer = self.get_rust_tokenizer(add_special_tokens=False)
|
||||
|
||||
sequence = u"UNwant\u00E9d,running"
|
||||
|
||||
tokens = tokenizer.tokenize(sequence)
|
||||
rust_tokens = rust_tokenizer.tokenize(sequence)
|
||||
self.assertListEqual(tokens, rust_tokens)
|
||||
|
||||
ids = tokenizer.encode(sequence, add_special_tokens=False)
|
||||
rust_ids = rust_tokenizer.encode(sequence)
|
||||
self.assertListEqual(ids, rust_ids)
|
||||
|
||||
rust_tokenizer = self.get_rust_tokenizer()
|
||||
ids = tokenizer.encode(sequence)
|
||||
rust_ids = rust_tokenizer.encode(sequence)
|
||||
self.assertListEqual(ids, rust_ids)
|
||||
|
||||
def test_chinese(self):
|
||||
tokenizer = BasicTokenizer()
|
||||
|
||||
|
|
|
@ -23,6 +23,7 @@ import tempfile
|
|||
class TokenizerTesterMixin:
|
||||
|
||||
tokenizer_class = None
|
||||
test_rust_tokenizer = False
|
||||
|
||||
def setUp(self):
|
||||
self.tmpdirname = tempfile.mkdtemp()
|
||||
|
@ -33,6 +34,9 @@ class TokenizerTesterMixin:
|
|||
def get_tokenizer(self, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_rust_tokenizer(self, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_input_output_texts(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ import json
|
|||
import os
|
||||
import unittest
|
||||
|
||||
from transformers.tokenization_gpt2 import VOCAB_FILES_NAMES, GPT2Tokenizer
|
||||
from transformers.tokenization_gpt2 import VOCAB_FILES_NAMES, GPT2Tokenizer, GPT2TokenizerFast
|
||||
|
||||
from .test_tokenization_common import TokenizerTesterMixin
|
||||
|
||||
|
@ -26,6 +26,7 @@ from .test_tokenization_common import TokenizerTesterMixin
|
|||
class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
|
||||
tokenizer_class = GPT2Tokenizer
|
||||
test_rust_tokenizer = True
|
||||
|
||||
def setUp(self):
|
||||
super(GPT2TokenizationTest, self).setUp()
|
||||
|
@ -68,6 +69,10 @@ class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||
kwargs.update(self.special_tokens_map)
|
||||
return GPT2Tokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_rust_tokenizer(self, **kwargs):
|
||||
kwargs.update(self.special_tokens_map)
|
||||
return GPT2TokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_input_output_texts(self):
|
||||
input_text = "lower newer"
|
||||
output_text = "lower newer"
|
||||
|
@ -83,3 +88,33 @@ class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||
input_tokens = tokens + [tokenizer.unk_token]
|
||||
input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19]
|
||||
self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
||||
|
||||
def test_rust_and_python_full_tokenizers(self):
|
||||
if not self.test_rust_tokenizer:
|
||||
return
|
||||
|
||||
tokenizer = self.get_tokenizer()
|
||||
rust_tokenizer = self.get_rust_tokenizer(add_special_tokens=False, add_prefix_space=True)
|
||||
|
||||
sequence = u"lower newer"
|
||||
|
||||
# Testing tokenization
|
||||
tokens = tokenizer.tokenize(sequence, add_prefix_space=True)
|
||||
rust_tokens = rust_tokenizer.tokenize(sequence)
|
||||
self.assertListEqual(tokens, rust_tokens)
|
||||
|
||||
# Testing conversion to ids without special tokens
|
||||
ids = tokenizer.encode(sequence, add_special_tokens=False, add_prefix_space=True)
|
||||
rust_ids = rust_tokenizer.encode(sequence)
|
||||
self.assertListEqual(ids, rust_ids)
|
||||
|
||||
# Testing conversion to ids with special tokens
|
||||
rust_tokenizer = self.get_rust_tokenizer(add_prefix_space=True)
|
||||
ids = tokenizer.encode(sequence, add_prefix_space=True)
|
||||
rust_ids = rust_tokenizer.encode(sequence)
|
||||
self.assertListEqual(ids, rust_ids)
|
||||
|
||||
# Testing the unknown token
|
||||
input_tokens = tokens + [rust_tokenizer.unk_token]
|
||||
input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19]
|
||||
self.assertListEqual(rust_tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
||||
|
|
Loading…
Reference in New Issue