Consider do_lower_case in PreTrainedTokenizer
As pointed out in #1545, when using an uncased model, and adding a new uncased token, the tokenizer does not correctly identify this in the case that the input text contains the token in a cased format. For instance, if we load bert-base-uncased into BertTokenizer, and then use .add_tokens() to add "cool-token", we get the expected result for .tokenize('this is a cool-token'). However, we get a possibly unexpected result for .tokenize('this is a cOOl-Token'), which in fact mirrors the result for the former from before the new token was added. This commit adds - functionality to PreTrainedTokenizer to handle this situation in case a tokenizer (currently Bert, DistilBert, and XLNet) has the do_lower_case=True kwarg by: 1) lowercasing tokens added with .add_tokens() 2) lowercasing text at the beginning of .tokenize() - new common test case for tokenizers https://github.com/huggingface/transformers/issues/1545
This commit is contained in:
parent
8aba81a0b6
commit
7246d3c2f9
|
@ -110,6 +110,36 @@ class CommonTestCases:
|
||||||
|
|
||||||
self.assertListEqual(subwords, subwords_loaded)
|
self.assertListEqual(subwords, subwords_loaded)
|
||||||
|
|
||||||
|
def test_added_tokens_do_lower_case(self):
|
||||||
|
tokenizer = self.get_tokenizer(do_lower_case=True)
|
||||||
|
|
||||||
|
text = "aaaaa bbbbbb low cccccccccdddddddd l"
|
||||||
|
text2 = "AAAAA BBBBBB low CCCCCCCCCDDDDDDDD l"
|
||||||
|
|
||||||
|
toks0 = tokenizer.tokenize(text) # toks before adding new_toks
|
||||||
|
|
||||||
|
new_toks = ["aaaaa bbbbbb", "cccccccccdddddddd", 'AAAAA BBBBBB', 'CCCCCCCCCDDDDDDDD']
|
||||||
|
added = tokenizer.add_tokens(new_toks)
|
||||||
|
self.assertEqual(added, 2)
|
||||||
|
|
||||||
|
toks = tokenizer.tokenize(text)
|
||||||
|
toks2 = tokenizer.tokenize(text2)
|
||||||
|
|
||||||
|
self.assertEqual(len(toks), len(toks2))
|
||||||
|
self.assertNotEqual(len(toks), len(toks0)) # toks0 should be longer
|
||||||
|
self.assertListEqual(toks, toks2)
|
||||||
|
|
||||||
|
tokenizer = self.get_tokenizer(do_lower_case=False)
|
||||||
|
|
||||||
|
added = tokenizer.add_tokens(new_toks)
|
||||||
|
self.assertEqual(added, 4)
|
||||||
|
|
||||||
|
toks = tokenizer.tokenize(text)
|
||||||
|
toks2 = tokenizer.tokenize(text2)
|
||||||
|
|
||||||
|
self.assertEqual(len(toks), len(toks2)) # Length should still be the same
|
||||||
|
self.assertNotEqual(len(toks), len(toks0))
|
||||||
|
self.assertNotEqual(toks[0], toks2[0]) # But at least the first tokens should differ
|
||||||
|
|
||||||
def test_add_tokens_tokenizer(self):
|
def test_add_tokens_tokenizer(self):
|
||||||
tokenizer = self.get_tokenizer()
|
tokenizer = self.get_tokenizer()
|
||||||
|
@ -160,7 +190,6 @@ class CommonTestCases:
|
||||||
self.assertEqual(tokens[0], tokenizer.eos_token_id)
|
self.assertEqual(tokens[0], tokenizer.eos_token_id)
|
||||||
self.assertEqual(tokens[-2], tokenizer.pad_token_id)
|
self.assertEqual(tokens[-2], tokenizer.pad_token_id)
|
||||||
|
|
||||||
|
|
||||||
def test_required_methods_tokenizer(self):
|
def test_required_methods_tokenizer(self):
|
||||||
tokenizer = self.get_tokenizer()
|
tokenizer = self.get_tokenizer()
|
||||||
input_text, output_text = self.get_input_output_texts()
|
input_text, output_text = self.get_input_output_texts()
|
||||||
|
|
|
@ -512,6 +512,8 @@ class PreTrainedTokenizer(object):
|
||||||
to_add_tokens = []
|
to_add_tokens = []
|
||||||
for token in new_tokens:
|
for token in new_tokens:
|
||||||
assert isinstance(token, str) or (six.PY2 and isinstance(token, unicode))
|
assert isinstance(token, str) or (six.PY2 and isinstance(token, unicode))
|
||||||
|
if self.init_kwargs.get('do_lower_case', False):
|
||||||
|
token = token.lower()
|
||||||
if token != self.unk_token and \
|
if token != self.unk_token and \
|
||||||
self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token) and \
|
self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token) and \
|
||||||
token not in to_add_tokens:
|
token not in to_add_tokens:
|
||||||
|
@ -605,6 +607,9 @@ class PreTrainedTokenizer(object):
|
||||||
|
|
||||||
Take care of added tokens.
|
Take care of added tokens.
|
||||||
"""
|
"""
|
||||||
|
if self.init_kwargs.get('do_lower_case', False):
|
||||||
|
text = text.lower()
|
||||||
|
|
||||||
def split_on_token(tok, text):
|
def split_on_token(tok, text):
|
||||||
result = []
|
result = []
|
||||||
split_text = text.split(tok)
|
split_text = text.split(tok)
|
||||||
|
|
Loading…
Reference in New Issue