[Reformer] remove reformer pad_token_id (#7991)
* remove reformer pad_token_id * fix pegasus
This commit is contained in:
parent
3a40cdf58d
commit
4acfd1a8dc
|
@ -47,8 +47,8 @@ class PegasusTokenizer(ReformerTokenizer):
|
|||
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
def __init__(self, *args, pad_token="<pad>", **kwargs):
|
||||
super().__init__(*args, **kwargs, pad_token="<pad>")
|
||||
# Don't use reserved words added_token_encoder, added_tokens_decoder because of
|
||||
# AssertionError: Non-consecutive added token '1' found. in from_pretrained
|
||||
assert len(self.added_tokens_decoder) == 0
|
||||
|
|
|
@ -86,19 +86,10 @@ class ReformerTokenizer(PreTrainedTokenizer):
|
|||
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||
model_input_names = ["attention_mask"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
eos_token="</s>",
|
||||
unk_token="<unk>",
|
||||
pad_token="<pad>",
|
||||
additional_special_tokens=[],
|
||||
**kwargs
|
||||
):
|
||||
def __init__(self, vocab_file, eos_token="</s>", unk_token="<unk>", additional_special_tokens=[], **kwargs):
|
||||
super().__init__(
|
||||
eos_token=eos_token,
|
||||
unk_token=unk_token,
|
||||
pad_token=pad_token,
|
||||
additional_special_tokens=additional_special_tokens,
|
||||
**kwargs,
|
||||
)
|
||||
|
|
|
@ -102,7 +102,6 @@ class ReformerTokenizerFast(PreTrainedTokenizerFast):
|
|||
tokenizer_file=None,
|
||||
eos_token="</s>",
|
||||
unk_token="<unk>",
|
||||
pad_token="<pad>",
|
||||
additional_special_tokens=[],
|
||||
**kwargs
|
||||
):
|
||||
|
@ -111,7 +110,6 @@ class ReformerTokenizerFast(PreTrainedTokenizerFast):
|
|||
tokenizer_file=tokenizer_file,
|
||||
eos_token=eos_token,
|
||||
unk_token=unk_token,
|
||||
pad_token=pad_token,
|
||||
additional_special_tokens=additional_special_tokens,
|
||||
**kwargs,
|
||||
)
|
||||
|
|
|
@ -63,6 +63,50 @@ class ReformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||
rust_ids = rust_tokenizer.encode(sequence)
|
||||
self.assertListEqual(ids, rust_ids)
|
||||
|
||||
def test_padding(self, max_length=15):
|
||||
for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
|
||||
with self.subTest("{} ({})".format(tokenizer.__class__.__name__, pretrained_name)):
|
||||
tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
|
||||
|
||||
# Simple input
|
||||
s = "This is a simple input"
|
||||
s2 = ["This is a simple input 1", "This is a simple input 2"]
|
||||
p = ("This is a simple input", "This is a pair")
|
||||
p2 = [
|
||||
("This is a simple input 1", "This is a simple input 2"),
|
||||
("This is a simple pair 1", "This is a simple pair 2"),
|
||||
]
|
||||
|
||||
# Simple input tests
|
||||
self.assertRaises(ValueError, tokenizer_r.encode, s, max_length=max_length, padding="max_length")
|
||||
|
||||
# Simple input
|
||||
self.assertRaises(ValueError, tokenizer_r.encode_plus, s, max_length=max_length, padding="max_length")
|
||||
|
||||
# Simple input
|
||||
self.assertRaises(
|
||||
ValueError,
|
||||
tokenizer_r.batch_encode_plus,
|
||||
s2,
|
||||
max_length=max_length,
|
||||
padding="max_length",
|
||||
)
|
||||
|
||||
# Pair input
|
||||
self.assertRaises(ValueError, tokenizer_r.encode, p, max_length=max_length, padding="max_length")
|
||||
|
||||
# Pair input
|
||||
self.assertRaises(ValueError, tokenizer_r.encode_plus, p, max_length=max_length, padding="max_length")
|
||||
|
||||
# Pair input
|
||||
self.assertRaises(
|
||||
ValueError,
|
||||
tokenizer_r.batch_encode_plus,
|
||||
p2,
|
||||
max_length=max_length,
|
||||
padding="max_length",
|
||||
)
|
||||
|
||||
def test_full_tokenizer(self):
|
||||
tokenizer = ReformerTokenizer(SAMPLE_VOCAB, keep_accents=True)
|
||||
|
||||
|
|
Loading…
Reference in New Issue