Force pad_token_id to be set before padding for standard tokenizer (#3035)

* force pad_token_id to be set before padding

* fix tests and forbid padding without having a padding_token_id set
Patrick von Platen 2020-03-02 16:53:55 +01:00 committed by GitHub
parent b54ef78d0c
commit c0135194eb
2 changed files with 47 additions and 1 deletion


@@ -1012,6 +1012,12 @@ class PreTrainedTokenizer(object):
"https://github.com/huggingface/transformers/pull/2674"
)
# Throw an error if we cannot pad because there is no padding token
if pad_to_max_length and self.pad_token_id is None:
    raise ValueError(
        "Unable to set proper padding strategy as the tokenizer does not have a padding token. In this case, please set the `pad_token` (e.g. `tokenizer.pad_token = tokenizer.eos_token`) or add a new pad token via the `add_special_tokens` method if you want to use a padding strategy."
    )
first_ids = get_input_ids(text)
second_ids = get_input_ids(text_pair) if text_pair is not None else None
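To make the new behaviour concrete, here is a minimal sketch of what this check enforces. It assumes a tokenizer that ships without a pad token (GPT-2 is used purely for illustration) and the `pad_to_max_length` / `encode_plus` API of this release:

from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
assert tokenizer.pad_token_id is None  # GPT-2 defines no pad token out of the box

try:
    # Padding is requested but no pad token exists -> ValueError after this commit
    tokenizer.encode_plus("Hello world", max_length=10, pad_to_max_length=True)
except ValueError as err:
    print(err)

# Fix 1: reuse an existing special token as the pad token
tokenizer.pad_token = tokenizer.eos_token
# Fix 2 (alternative): register a brand-new pad token
# tokenizer.add_special_tokens({"pad_token": "<PAD>"})

encoded = tokenizer.encode_plus("Hello world", max_length=10, pad_to_max_length=True)
print(len(encoded["input_ids"]))  # 10, padded with tokenizer.pad_token_id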
@@ -1115,6 +1121,12 @@ class PreTrainedTokenizer(object):
"Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
)
# Throw an error if we cannot pad because there is no padding token
if pad_to_max_length and self.pad_token_id is None:
    raise ValueError(
        "Unable to set proper padding strategy as the tokenizer does not have a padding token. In this case, please set the `pad_token` (e.g. `tokenizer.pad_token = tokenizer.eos_token`) or add a new pad token via the `add_special_tokens` method if you want to use a padding strategy."
    )
if return_offsets_mapping:
raise NotImplementedError(
"return_offset_mapping is not available when using Python tokenizers."
@@ -1788,7 +1800,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
# Throw an error if we cannot pad because there is no padding token
if pad_to_max_length and self.pad_token_id is None:
-    raise ValueError("Unable to set proper padding strategy as the tokenizer does have padding token")
+    raise ValueError("Unable to set proper padding strategy as the tokenizer does not have a padding token")
# Set the truncation and padding strategy and restore the initial configuration
with truncate_and_pad(
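For the fast tokenizers the guard already existed and only the error message is corrected. A hedged sketch of that path, assuming `GPT2TokenizerFast` (which also ships without a pad token) and that assigning `pad_token` propagates to the fast backend the same way it does for the slow tokenizer:

from transformers import GPT2TokenizerFast

fast_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
assert fast_tokenizer.pad_token_id is None

try:
    # Same ValueError as the slow tokenizer, now with the corrected message
    fast_tokenizer.encode_plus("Hello world", max_length=10, pad_to_max_length=True)
except ValueError as err:
    print(err)

# The remedy is identical: define a pad token before asking for padding
fast_tokenizer.pad_token = fast_tokenizer.eos_token
encoded = fast_tokenizer.encode_plus("Hello world", max_length=10, pad_to_max_length=True)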


@@ -449,6 +449,10 @@ class TokenizerTesterMixin:
sequence = "Sequence"
padding_size = 10
# check correct behaviour if no pad_token_id exists, and add one if necessary
self._check_no_pad_token_padding(tokenizer, sequence)
padding_idx = tokenizer.pad_token_id
# RIGHT PADDING - Check that it correctly pads when a maximum length is specified along with the padding flag set to True
@@ -490,6 +494,10 @@ class TokenizerTesterMixin:
tokenizer = self.get_tokenizer()
sequence = "Sequence"
# check correct behaviour if no pad_token_id exists, and add one if necessary
self._check_no_pad_token_padding(tokenizer, sequence)
padding_size = 10
padding_idx = tokenizer.pad_token_id
token_type_padding_idx = tokenizer.pad_token_type_id
@@ -503,6 +511,7 @@ class TokenizerTesterMixin:
# Test right padding
tokenizer.padding_side = "right"
padded_sequence = tokenizer.encode_plus(
    sequence,
    max_length=sequence_length + padding_size,
@@ -588,10 +597,14 @@ class TokenizerTesterMixin:
maximum_length = len(max([encoded_sequence["input_ids"] for encoded_sequence in encoded_sequences], key=len))
# check correct behaviour if no pad_token_id exists, and add one if necessary
self._check_no_pad_token_padding(tokenizer, sequences)
encoded_sequences_padded = [
    tokenizer.encode_plus(sequence, pad_to_max_length=True, max_length=maximum_length)
    for sequence in sequences
]
encoded_sequences_batch_padded = tokenizer.batch_encode_plus(sequences, pad_to_max_length=True)
self.assertListEqual(
    encoded_sequences_padded,
@@ -610,6 +623,10 @@ class TokenizerTesterMixin:
]
max_length = 100
# check correct behaviour if no pad_token_id exists, and add one if necessary
self._check_no_pad_token_padding(tokenizer, sequences)
encoded_sequences = [
    tokenizer.encode_plus(sequence, pad_to_max_length=True, max_length=max_length) for sequence in sequences
]
@@ -620,6 +637,7 @@ class TokenizerTesterMixin:
# Left padding tests
tokenizer = self.get_tokenizer()
tokenizer.padding_side = "left"
sequences = [
"Testing batch encode plus",
@@ -628,6 +646,10 @@ class TokenizerTesterMixin:
]
max_length = 100
# check correct behaviour if no pad_token_id exists, and add one if necessary
self._check_no_pad_token_padding(tokenizer, sequences)
encoded_sequences = [
    tokenizer.encode_plus(sequence, pad_to_max_length=True, max_length=max_length) for sequence in sequences
]
@@ -668,3 +690,15 @@ class TokenizerTesterMixin:
encoded_value = encoded_sequences[key]
self.assertEqual(pytorch_value, tensorflow_value, encoded_value)
def _check_no_pad_token_padding(self, tokenizer, sequences):
    # if the tokenizer does not have a pad_token_id, padding requests must raise a ValueError
    if tokenizer.pad_token_id is None:
        with self.assertRaises(ValueError):
            if isinstance(sequences, list):
                tokenizer.batch_encode_plus(sequences, pad_to_max_length=True)
            else:
                tokenizer.encode_plus(sequences, pad_to_max_length=True)

        # add a pad token so that the subsequent padding checks can run
        tokenizer.add_special_tokens({"pad_token": "<PAD>"})
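For context, here is a sketch of how a padding test in this mixin is expected to call the new helper; the test name and assertions are illustrative rather than part of this commit, and the final check assumes the tokenizer pads on the right (the default for most tokenizers):

def test_padding_when_pad_token_is_missing(self):
    # Illustrative only: mirrors the pattern of the padding tests in this mixin
    tokenizer = self.get_tokenizer()
    sequence = "Sequence"

    # Asserts that padding raises while pad_token_id is None, then registers "<PAD>"
    self._check_no_pad_token_padding(tokenizer, sequence)

    padding_idx = tokenizer.pad_token_id
    encoded = tokenizer.encode_plus(sequence, max_length=20, pad_to_max_length=True)
    self.assertEqual(len(encoded["input_ids"]), 20)
    self.assertEqual(encoded["input_ids"][-1], padding_idx)  # right padding by default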