fixing #1133
This commit is contained in:
parent
b6cd856b08
commit
fede4ef45d
|
@ -104,7 +104,7 @@ class CommonTestCases:
|
||||||
self.assertNotEqual(vocab_size, 0)
|
self.assertNotEqual(vocab_size, 0)
|
||||||
self.assertEqual(vocab_size, all_size)
|
self.assertEqual(vocab_size, all_size)
|
||||||
|
|
||||||
new_toks = ["aaaaabbbbbb", "cccccccccdddddddd"]
|
new_toks = ["aaaaa bbbbbb", "cccccccccdddddddd"]
|
||||||
added_toks = tokenizer.add_tokens(new_toks)
|
added_toks = tokenizer.add_tokens(new_toks)
|
||||||
vocab_size_2 = tokenizer.vocab_size
|
vocab_size_2 = tokenizer.vocab_size
|
||||||
all_size_2 = len(tokenizer)
|
all_size_2 = len(tokenizer)
|
||||||
|
@ -114,7 +114,9 @@ class CommonTestCases:
|
||||||
self.assertEqual(added_toks, len(new_toks))
|
self.assertEqual(added_toks, len(new_toks))
|
||||||
self.assertEqual(all_size_2, all_size + len(new_toks))
|
self.assertEqual(all_size_2, all_size + len(new_toks))
|
||||||
|
|
||||||
tokens = tokenizer.encode("aaaaabbbbbb low cccccccccdddddddd l")
|
tokens = tokenizer.encode("aaaaa bbbbbb low cccccccccdddddddd l")
|
||||||
|
out_string = tokenizer.decode(tokens)
|
||||||
|
|
||||||
self.assertGreaterEqual(len(tokens), 4)
|
self.assertGreaterEqual(len(tokens), 4)
|
||||||
self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
|
self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
|
||||||
self.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
|
self.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
|
||||||
|
@ -131,6 +133,7 @@ class CommonTestCases:
|
||||||
self.assertEqual(all_size_3, all_size_2 + len(new_toks_2))
|
self.assertEqual(all_size_3, all_size_2 + len(new_toks_2))
|
||||||
|
|
||||||
tokens = tokenizer.encode(">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l")
|
tokens = tokenizer.encode(">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l")
|
||||||
|
out_string = tokenizer.decode(tokens)
|
||||||
|
|
||||||
self.assertGreaterEqual(len(tokens), 6)
|
self.assertGreaterEqual(len(tokens), 6)
|
||||||
self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
|
self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
|
||||||
|
|
|
@ -722,7 +722,7 @@ class PreTrainedTokenizer(object):
|
||||||
return self._convert_id_to_token(ids)
|
return self._convert_id_to_token(ids)
|
||||||
tokens = []
|
tokens = []
|
||||||
for index in ids:
|
for index in ids:
|
||||||
if index in self.all_special_ids and skip_special_tokens:
|
if skip_special_tokens and index in self.all_special_ids:
|
||||||
continue
|
continue
|
||||||
if index in self.added_tokens_decoder:
|
if index in self.added_tokens_decoder:
|
||||||
tokens.append(self.added_tokens_decoder[index])
|
tokens.append(self.added_tokens_decoder[index])
|
||||||
|
@ -747,7 +747,25 @@ class PreTrainedTokenizer(object):
|
||||||
Similar to doing ``self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))``.
|
Similar to doing ``self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))``.
|
||||||
"""
|
"""
|
||||||
filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
|
filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
|
||||||
text = self.convert_tokens_to_string(filtered_tokens)
|
|
||||||
|
# To avoid mixing byte-level and unicode for byte-level BPT
|
||||||
|
# we need to build string separatly for added tokens and byte-level tokens
|
||||||
|
# cf. https://github.com/huggingface/pytorch-transformers/issues/1133
|
||||||
|
sub_texts = []
|
||||||
|
current_sub_text = []
|
||||||
|
for token in filtered_tokens:
|
||||||
|
if skip_special_tokens and token in self.all_special_ids:
|
||||||
|
continue
|
||||||
|
if token in self.added_tokens_encoder:
|
||||||
|
if current_sub_text:
|
||||||
|
sub_texts.append(self.convert_tokens_to_string(current_sub_text))
|
||||||
|
current_sub_text = []
|
||||||
|
sub_texts.append(token)
|
||||||
|
else:
|
||||||
|
current_sub_text.append(token)
|
||||||
|
if current_sub_text:
|
||||||
|
sub_texts.append(self.convert_tokens_to_string(current_sub_text))
|
||||||
|
text = ''.join(sub_texts)
|
||||||
|
|
||||||
if self._sep_token is not None and self._sep_token in text:
|
if self._sep_token is not None and self._sep_token in text:
|
||||||
text = text.replace(self._cls_token, self._sep_token)
|
text = text.replace(self._cls_token, self._sep_token)
|
||||||
|
|
Loading…
Reference in New Issue