Fix special tokens addition in decoder
This commit is contained in:
parent
d409aca326
commit
74d0bcb6ff
|
@ -160,6 +160,26 @@ class CommonTestCases:
|
|||
self.assertEqual(tokens[0], tokenizer.eos_token_id)
|
||||
self.assertEqual(tokens[-2], tokenizer.pad_token_id)
|
||||
|
||||
def test_add_special_tokens(self):
    """Adding a special token makes it atomic in encode and removable in decode.

    Checks three contracts of `add_special_tokens`:
      1. the registered token encodes to exactly one id,
      2. the token stays atomic when embedded in surrounding text,
      3. `decode(..., skip_special_tokens=True)` strips it from the output.
    """
    tokenizer = self.get_tokenizer()
    input_text, output_text = self.get_input_output_texts()

    special_token = "[SPECIAL TOKEN]"

    tokenizer.add_special_tokens({"cls_token": special_token})
    # The special token must map to a single id, not be split by the model's
    # subword algorithm.
    encoded_special_token = tokenizer.encode(special_token, add_special_tokens=False)
    assert len(encoded_special_token) == 1

    text = " ".join([input_text, special_token, output_text])
    encoded = tokenizer.encode(text, add_special_tokens=False)

    input_encoded = tokenizer.encode(input_text, add_special_tokens=False)
    output_encoded = tokenizer.encode(output_text, add_special_tokens=False)
    # Reuse the id computed above instead of re-encoding the same token a
    # second time (the original encoded `special_token` twice with identical
    # arguments).
    assert encoded == input_encoded + encoded_special_token + output_encoded

    # skip_special_tokens must drop the registered token from the decoded text.
    decoded = tokenizer.decode(encoded, skip_special_tokens=True)
    assert special_token not in decoded
||||
def test_required_methods_tokenizer(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
|
|
@ -1055,7 +1055,7 @@ class PreTrainedTokenizer(object):
|
|||
class attributes (cls_token, unk_token...).
|
||||
"""
|
||||
all_toks = self.all_special_tokens
|
||||
all_ids = list(self._convert_token_to_id(t) for t in all_toks)
|
||||
all_ids = self.convert_tokens_to_ids(all_toks)
|
||||
return all_ids
|
||||
|
||||
@staticmethod
|
||||
|
|
Loading…
Reference in New Issue