[`PreTrainedTokenizerFast`] Keep properties from fast tokenizer (#25053)
* draft solution
* use `setdefault`
* nits
* add tests and fix truncation issue
* fix test
* test passes locally
* quality
* updates
* update tests
This commit is contained in:
parent 0779fc8eb8
commit f9cc333805
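In practice, the change means that truncation and padding settings configured directly on a `tokenizers.Tokenizer` are no longer lost when the object is wrapped in `PreTrainedTokenizerFast`. A minimal sketch of the resulting behavior, mirroring the new test further down (assumes a `transformers` install that includes this patch, plus network access to fetch `t5-base`):

from tokenizers import Tokenizer
from transformers import PreTrainedTokenizerFast

backend = Tokenizer.from_pretrained("t5-base")
backend.enable_padding(pad_id=0, pad_token="<pad>", length=512, pad_to_multiple_of=8)

# The wrapper now reads the backend's padding state into its own
# attributes and init kwargs instead of starting from blank defaults.
wrapped = PreTrainedTokenizerFast(tokenizer_object=backend)
assert wrapped.pad_token == "<pad>"
assert wrapped.pad_token_id == 0
assert wrapped.padding_side == "right"
assert wrapped.init_kwargs["pad_to_multiple_of"] == 8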
@@ -132,6 +132,26 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
 
         self._decode_use_source_tokenizer = False
 
+        _truncation = self._tokenizer.truncation
+
+        if _truncation is not None:
+            self._tokenizer.enable_truncation(**_truncation)
+            kwargs.setdefault("max_length", _truncation["max_length"])
+            kwargs.setdefault("truncation_side", _truncation["direction"])
+            kwargs.setdefault("stride", _truncation["stride"])
+            kwargs.setdefault("truncation_strategy", _truncation["strategy"])
+        else:
+            self._tokenizer.no_truncation()
+
+        _padding = self._tokenizer.padding
+        if _padding is not None:
+            self._tokenizer.enable_padding(**_padding)
+            kwargs.setdefault("pad_token", _padding["pad_token"])
+            kwargs.setdefault("pad_token_type_id", _padding["pad_type_id"])
+            kwargs.setdefault("padding_side", _padding["direction"])
+            kwargs.setdefault("max_length", _padding["length"])
+            kwargs.setdefault("pad_to_multiple_of", _padding["pad_to_multiple_of"])
+
         # We call this after having initialized the backend tokenizer because we update it.
         super().__init__(**kwargs)
 
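Note the role of `setdefault` in the hunk above: values recovered from the backend tokenizer only fill in kwargs the caller did not pass, so explicit arguments keep priority. A small illustration of that precedence rule using plain `dict` semantics (the concrete values here are hypothetical, chosen for the example):

# Caller passed max_length explicitly; the backend stored different settings.
kwargs = {"max_length": 128}
_truncation = {"max_length": 512, "stride": 0}  # shape of self._tokenizer.truncation

kwargs.setdefault("max_length", _truncation["max_length"])  # caller's 128 wins
kwargs.setdefault("stride", _truncation["stride"])          # missing key is filled in

assert kwargs == {"max_length": 128, "stride": 0}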
@@ -109,6 +109,58 @@ class PreTrainedTokenizationFastTest(TokenizerTesterMixin, unittest.TestCase):
         encoding_ids = new_tokenizer.encode("a🤗")
         self.assertEqual(encoding_ids, [64, 172, 253, 97, 245])
 
+    def test_init_from_tokenizers_model(self):
+        from tokenizers import Tokenizer
+
+        sentences = ["Hello, y'all!", "How are you 😁 ? There should not be any issue right?"]
+
+        tokenizer = Tokenizer.from_pretrained("t5-base")
+        # Enable padding
+        tokenizer.enable_padding(pad_id=0, pad_token="<pad>", length=512, pad_to_multiple_of=8)
+        self.assertEqual(
+            tokenizer.padding,
+            {
+                "length": 512,
+                "pad_to_multiple_of": 8,
+                "pad_id": 0,
+                "pad_token": "<pad>",
+                "pad_type_id": 0,
+                "direction": "right",
+            },
+        )
+        fast_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer)
+        tmpdirname = tempfile.mkdtemp()
+        fast_tokenizer.save_pretrained(tmpdirname)
+        fast_from_saved = PreTrainedTokenizerFast.from_pretrained(tmpdirname)
+        for tok in [fast_tokenizer, fast_from_saved]:
+            self.assertEqual(tok.pad_token_id, 0)
+            self.assertEqual(tok.padding_side, "right")
+            self.assertEqual(tok.pad_token, "<pad>")
+            self.assertEqual(tok.init_kwargs["max_length"], 512)
+            self.assertEqual(tok.init_kwargs["pad_to_multiple_of"], 8)
+            # fmt: off
+            self.assertEqual(tok(sentences, padding=True), {'input_ids': [[8774, 6, 3, 63, 31, 1748, 55, 1, 0, 0, 0, 0, 0, 0, 0, 0], [571, 33, 25, 3, 2, 3, 58, 290, 225, 59, 36, 136, 962, 269, 58, 1]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]})
+            # fmt: on
+
+        tokenizer.enable_truncation(8, stride=0, strategy="longest_first", direction="right")
+        self.assertEqual(
+            tokenizer.truncation, {"max_length": 8, "stride": 0, "strategy": "longest_first", "direction": "right"}
+        )
+        fast_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer)
+        tmpdirname = tempfile.mkdtemp()
+        fast_tokenizer.save_pretrained(tmpdirname)
+        fast_from_saved = PreTrainedTokenizerFast.from_pretrained(tmpdirname)
+        for tok in [fast_tokenizer, fast_from_saved]:
+            self.assertEqual(tok.truncation_side, "right")
+            self.assertEqual(tok.init_kwargs["truncation_strategy"], "longest_first")
+            self.assertEqual(tok.init_kwargs["max_length"], 8)
+            self.assertEqual(tok.init_kwargs["stride"], 0)
+            # NOTE: even if the model has a default max_length, it is not used here,
+            # so tok(sentences, truncation=True) does nothing and does not warn either.
+            # fmt: off
+            self.assertEqual(tok(sentences, truncation=True, max_length=8), {'input_ids': [[8774, 6, 3, 63, 31, 1748, 55, 1], [571, 33, 25, 3, 2, 3, 58, 1]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]]})
+            # fmt: on
+
 
 @require_tokenizers
 class TokenizerVersioningTest(unittest.TestCase):
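For reference, the truncation half of the round trip the test exercises, as a standalone sketch under the same assumptions as the earlier snippet:

import tempfile

from tokenizers import Tokenizer
from transformers import PreTrainedTokenizerFast

backend = Tokenizer.from_pretrained("t5-base")
backend.enable_truncation(8, stride=0, strategy="longest_first", direction="right")

tok = PreTrainedTokenizerFast(tokenizer_object=backend)
assert tok.truncation_side == "right"
assert tok.init_kwargs["truncation_strategy"] == "longest_first"

# The recovered settings also survive a save/reload cycle.
tmpdir = tempfile.mkdtemp()
tok.save_pretrained(tmpdir)
reloaded = PreTrainedTokenizerFast.from_pretrained(tmpdir)
assert reloaded.init_kwargs["max_length"] == 8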