Consider do_lower_case in PreTrainedTokenizer

As pointed out in #1545, when using an uncased model and adding a
new uncased token, the tokenizer fails to match that token whenever
the input text contains it in a cased form.

For instance, if we load bert-base-uncased into BertTokenizer and
then use .add_tokens() to add "cool-token", we get the expected
result for .tokenize('this is a cool-token'). However,
.tokenize('this is a cOOl-Token') gives a possibly unexpected
result: the same output the first call produced before the new
token was added.
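
A minimal sketch of the reported behaviour (assumes the transformers
library as of this commit; the cased call's exact wordpiece split
depends on the vocab, so it is only described):

    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    tokenizer.add_tokens(['cool-token'])

    # Lowercased input matches the added token as expected:
    tokenizer.tokenize('this is a cool-token')
    # -> ['this', 'is', 'a', 'cool-token']

    # Cased input fails to match the added token and falls through to
    # the regular wordpiece split, as if the token was never added:
    tokenizer.tokenize('this is a cOOl-Token')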

This commit adds
- functionality to PreTrainedTokenizer to handle this situation
  when a tokenizer (currently Bert, DistilBert, and XLNet) is
  instantiated with the do_lower_case=True kwarg, by:
    1) lowercasing tokens added with .add_tokens()
    2) lowercasing text at the beginning of .tokenize()
  (see the sketch below)
- a new common test case for tokenizers
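
A sketch of the behaviour after this change (with do_lower_case=True,
added tokens and input text are both lowercased before matching):

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    # Both spellings lowercase to 'cool-token', so only one is added:
    added = tokenizer.add_tokens(['cool-token', 'COOL-TOKEN'])
    assert added == 1

    # Cased input now matches, because .tokenize() lowercases first:
    tokenizer.tokenize('this is a cOOl-Token')
    # -> ['this', 'is', 'a', 'cool-token']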

https://github.com/huggingface/transformers/issues/1545
Michael Watkins 2019-11-06 13:18:16 +02:00
parent 8aba81a0b6
commit 7246d3c2f9
2 changed files with 35 additions and 1 deletion

@@ -110,6 +110,36 @@ class CommonTestCases:
 
         self.assertListEqual(subwords, subwords_loaded)
 
+    def test_added_tokens_do_lower_case(self):
+        tokenizer = self.get_tokenizer(do_lower_case=True)
+
+        text = "aaaaa bbbbbb low cccccccccdddddddd l"
+        text2 = "AAAAA BBBBBB low CCCCCCCCCDDDDDDDD l"
+
+        toks0 = tokenizer.tokenize(text)  # toks before adding new_toks
+
+        new_toks = ["aaaaa bbbbbb", "cccccccccdddddddd", 'AAAAA BBBBBB', 'CCCCCCCCCDDDDDDDD']
+        added = tokenizer.add_tokens(new_toks)
+        self.assertEqual(added, 2)
+
+        toks = tokenizer.tokenize(text)
+        toks2 = tokenizer.tokenize(text2)
+
+        self.assertEqual(len(toks), len(toks2))
+        self.assertNotEqual(len(toks), len(toks0))  # toks0 should be longer
+        self.assertListEqual(toks, toks2)
+
+        tokenizer = self.get_tokenizer(do_lower_case=False)
+
+        added = tokenizer.add_tokens(new_toks)
+        self.assertEqual(added, 4)
+
+        toks = tokenizer.tokenize(text)
+        toks2 = tokenizer.tokenize(text2)
+
+        self.assertEqual(len(toks), len(toks2))  # Length should still be the same
+        self.assertNotEqual(len(toks), len(toks0))
+        self.assertNotEqual(toks[0], toks2[0])  # But at least the first tokens should differ
+
     def test_add_tokens_tokenizer(self):
         tokenizer = self.get_tokenizer()
@@ -160,7 +190,6 @@ class CommonTestCases:
         self.assertEqual(tokens[0], tokenizer.eos_token_id)
         self.assertEqual(tokens[-2], tokenizer.pad_token_id)
-
 
     def test_required_methods_tokenizer(self):
         tokenizer = self.get_tokenizer()
         input_text, output_text = self.get_input_output_texts()

@@ -512,6 +512,8 @@ class PreTrainedTokenizer(object):
         to_add_tokens = []
         for token in new_tokens:
             assert isinstance(token, str) or (six.PY2 and isinstance(token, unicode))
+            if self.init_kwargs.get('do_lower_case', False):
+                token = token.lower()
             if token != self.unk_token and \
                     self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token) and \
                     token not in to_add_tokens:
@@ -605,6 +607,9 @@ class PreTrainedTokenizer(object):
             Take care of added tokens.
         """
+        if self.init_kwargs.get('do_lower_case', False):
+            text = text.lower()
+
         def split_on_token(tok, text):
             result = []
             split_text = text.split(tok)