[ `PreTrainedTokenizerFast`] Keep properties from fast tokenizer (#25053)

* draft solution

* use `setdefault`

* nits

* add tests and fix truncation issue

* fix test

* test passes locally

* quality

* updates

* update tsets
This commit is contained in:
Arthur 2023-07-25 18:45:01 +02:00 committed by GitHub
parent 0779fc8eb8
commit f9cc333805
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 72 additions and 0 deletions

View File

@ -132,6 +132,26 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
self._decode_use_source_tokenizer = False
_truncation = self._tokenizer.truncation
if _truncation is not None:
self._tokenizer.enable_truncation(**_truncation)
kwargs.setdefault("max_length", _truncation["max_length"])
kwargs.setdefault("truncation_side", _truncation["direction"])
kwargs.setdefault("stride", _truncation["stride"])
kwargs.setdefault("truncation_strategy", _truncation["strategy"])
else:
self._tokenizer.no_truncation()
_padding = self._tokenizer.padding
if _padding is not None:
self._tokenizer.enable_padding(**_padding)
kwargs.setdefault("pad_token", _padding["pad_token"])
kwargs.setdefault("pad_token_type_id", _padding["pad_type_id"])
kwargs.setdefault("padding_side", _padding["direction"])
kwargs.setdefault("max_length", _padding["length"])
kwargs.setdefault("pad_to_multiple_of", _padding["pad_to_multiple_of"])
# We call this after having initialized the backend tokenizer because we update it.
super().__init__(**kwargs)

View File

@ -109,6 +109,58 @@ class PreTrainedTokenizationFastTest(TokenizerTesterMixin, unittest.TestCase):
encoding_ids = new_tokenizer.encode("a🤗")
self.assertEqual(encoding_ids, [64, 172, 253, 97, 245])
def test_init_from_tokenizers_model(self):
from tokenizers import Tokenizer
sentences = ["Hello, y'all!", "How are you 😁 ? There should not be any issue right?"]
tokenizer = Tokenizer.from_pretrained("t5-base")
# Enable padding
tokenizer.enable_padding(pad_id=0, pad_token="<pad>", length=512, pad_to_multiple_of=8)
self.assertEqual(
tokenizer.padding,
{
"length": 512,
"pad_to_multiple_of": 8,
"pad_id": 0,
"pad_token": "<pad>",
"pad_type_id": 0,
"direction": "right",
},
)
fast_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer)
tmpdirname = tempfile.mkdtemp()
fast_tokenizer.save_pretrained(tmpdirname)
fast_from_saved = PreTrainedTokenizerFast.from_pretrained(tmpdirname)
for tok in [fast_tokenizer, fast_from_saved]:
self.assertEqual(tok.pad_token_id, 0)
self.assertEqual(tok.padding_side, "right")
self.assertEqual(tok.pad_token, "<pad>")
self.assertEqual(tok.init_kwargs["max_length"], 512)
self.assertEqual(tok.init_kwargs["pad_to_multiple_of"], 8)
# fmt: off
self.assertEqual(tok(sentences, padding = True), {'input_ids': [[8774, 6, 3, 63, 31, 1748, 55, 1, 0, 0, 0, 0,0, 0, 0, 0],[ 571, 33, 25, 3, 2, 3, 58, 290, 225, 59, 36, 136, 962, 269, 58, 1]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]})
# fmt: on
tokenizer.enable_truncation(8, stride=0, strategy="longest_first", direction="right")
self.assertEqual(
tokenizer.truncation, {"max_length": 8, "stride": 0, "strategy": "longest_first", "direction": "right"}
)
fast_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer)
tmpdirname = tempfile.mkdtemp()
fast_tokenizer.save_pretrained(tmpdirname)
fast_from_saved = PreTrainedTokenizerFast.from_pretrained(tmpdirname)
for tok in [fast_tokenizer, fast_from_saved]:
self.assertEqual(tok.truncation_side, "right")
self.assertEqual(tok.init_kwargs["truncation_strategy"], "longest_first")
self.assertEqual(tok.init_kwargs["max_length"], 8)
self.assertEqual(tok.init_kwargs["stride"], 0)
# NOTE even if the model has a default max_length, it is not used...
# thus tok(sentences, truncation = True) does nothing and does not warn either
# fmt: off
self.assertEqual(tok(sentences, truncation = True, max_length = 8), {'input_ids': [[8774, 6, 3, 63, 31, 1748, 55, 1],[ 571, 33, 25, 3, 2, 3, 58, 1]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1],[1, 1, 1, 1, 1, 1, 1, 1]]})
# fmt: on
@require_tokenizers
class TokenizerVersioningTest(unittest.TestCase):