Author: thomwolf
Date:   2019-09-02 02:27:39 +02:00
Parent: b6cd856b08
Commit: fede4ef45d
2 changed files with 25 additions and 4 deletions

View File

@@ -104,7 +104,7 @@ class CommonTestCases:
             self.assertNotEqual(vocab_size, 0)
             self.assertEqual(vocab_size, all_size)
-            new_toks = ["aaaaabbbbbb", "cccccccccdddddddd"]
+            new_toks = ["aaaaa bbbbbb", "cccccccccdddddddd"]
             added_toks = tokenizer.add_tokens(new_toks)
             vocab_size_2 = tokenizer.vocab_size
             all_size_2 = len(tokenizer)
@@ -114,7 +114,9 @@ class CommonTestCases:
             self.assertEqual(added_toks, len(new_toks))
             self.assertEqual(all_size_2, all_size + len(new_toks))
-            tokens = tokenizer.encode("aaaaabbbbbb low cccccccccdddddddd l")
+            tokens = tokenizer.encode("aaaaa bbbbbb low cccccccccdddddddd l")
+            out_string = tokenizer.decode(tokens)
             self.assertGreaterEqual(len(tokens), 4)
             self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
             self.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
@@ -131,6 +133,7 @@ class CommonTestCases:
             self.assertEqual(all_size_3, all_size_2 + len(new_toks_2))
             tokens = tokenizer.encode(">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l")
+            out_string = tokenizer.decode(tokens)
             self.assertGreaterEqual(len(tokens), 6)
             self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
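The new out_string = tokenizer.decode(tokens) lines exercise the decode path reworked in the second file below. As a rough illustration of what the shared test now covers, here is a minimal round-trip sketch. The choice of GPT2Tokenizer (the byte-level BPE tokenizer behind issue #1133) and the "gpt2" checkpoint are assumptions for the example; the common test itself runs against every tokenizer class in the repository.

from pytorch_transformers import GPT2Tokenizer  # assumed: any byte-level BPE tokenizer from this repo

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Added tokens may now contain spaces; they are matched as whole units
# before the byte-level BPE ever sees the text.
tokenizer.add_tokens(["aaaaa bbbbbb", "cccccccccdddddddd"])

tokens = tokenizer.encode("aaaaa bbbbbb low cccccccccdddddddd l")
out_string = tokenizer.decode(tokens)  # the added tokens should survive the round trip intact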

View File

@@ -722,7 +722,7 @@ class PreTrainedTokenizer(object):
                 return self._convert_id_to_token(ids)
         tokens = []
         for index in ids:
-            if index in self.all_special_ids and skip_special_tokens:
+            if skip_special_tokens and index in self.all_special_ids:
                 continue
            if index in self.added_tokens_decoder:
                tokens.append(self.added_tokens_decoder[index])
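The only change in this hunk is the order of the operands: putting skip_special_tokens first lets Python short-circuit, so the membership test against all_special_ids (a property that is recomputed on each access, if I read the surrounding code correctly) is never evaluated when special tokens are not being skipped. A tiny standalone sketch of that behaviour, with made-up names rather than library API:

# Illustrative only; `Toy` and `keep` are invented names, not part of pytorch-transformers.
class Toy:
    @property
    def all_special_ids(self):
        print("computing special ids")  # stands in for the recomputed property
        return [0, 1, 2]

    def keep(self, index, skip_special_tokens):
        # Short-circuit: the property is only touched when skipping is requested.
        return not (skip_special_tokens and index in self.all_special_ids)

t = Toy()
print(t.keep(5, skip_special_tokens=False))  # True  -- property never computed
print(t.keep(1, skip_special_tokens=True))   # False -- property computed once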
@@ -747,7 +747,25 @@ class PreTrainedTokenizer(object):
             Similar to doing ``self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))``.
         """
         filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
-        text = self.convert_tokens_to_string(filtered_tokens)
+        # To avoid mixing byte-level and unicode for byte-level BPE,
+        # we need to build the string separately for added tokens and byte-level tokens
+        # cf. https://github.com/huggingface/pytorch-transformers/issues/1133
+        sub_texts = []
+        current_sub_text = []
+        for token in filtered_tokens:
+            if skip_special_tokens and token in self.all_special_ids:
+                continue
+            if token in self.added_tokens_encoder:
+                if current_sub_text:
+                    sub_texts.append(self.convert_tokens_to_string(current_sub_text))
+                    current_sub_text = []
+                sub_texts.append(token)
+            else:
+                current_sub_text.append(token)
+        if current_sub_text:
+            sub_texts.append(self.convert_tokens_to_string(current_sub_text))
+        text = ''.join(sub_texts)
         if self._sep_token is not None and self._sep_token in text:
             text = text.replace(self._cls_token, self._sep_token)
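The rewritten body of decode addresses issue #1133: tokens from the regular vocabulary have to go through convert_tokens_to_string (which, for byte-level BPE, maps byte-level symbols back to unicode), whereas user-added tokens are stored as plain unicode strings and get garbled if pushed through that same mapping. The new code therefore walks the token list, converts each run of vocabulary tokens as a group, passes added tokens through untouched, and concatenates the pieces. Below is a self-contained sketch of that grouping idea; decode_mixed and the toy converter are illustrative stand-ins, not the library's API.

def decode_mixed(filtered_tokens, added_tokens, convert_tokens_to_string):
    """Group vocabulary tokens into sub-texts; keep added tokens as raw strings."""
    sub_texts, current_sub_text = [], []
    for token in filtered_tokens:
        if token in added_tokens:
            if current_sub_text:
                sub_texts.append(convert_tokens_to_string(current_sub_text))
                current_sub_text = []
            sub_texts.append(token)          # added token: already plain unicode
        else:
            current_sub_text.append(token)   # vocab token: decode later, as part of a run
    if current_sub_text:
        sub_texts.append(convert_tokens_to_string(current_sub_text))
    return ''.join(sub_texts)

def byte_level_to_string(toks):
    # Toy stand-in for GPT-2's byte-level decoder; 'Ġ' marks a leading space.
    return "".join(toks).replace("Ġ", " ")

print(decode_mixed(
    ["aaaaa bbbbbb", "Ġlow", "Ġl"],
    added_tokens={"aaaaa bbbbbb"},
    convert_tokens_to_string=byte_level_to_string,
))  # -> 'aaaaa bbbbbb low l'

Note that joining with an empty string, as the committed code does, relies on the byte-level pieces carrying their own leading-space markers ('Ġ' in the sketch above).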