Fixing OPT fast tokenizer option. (#18753)

* Fixing OPT fast tokenizer option.

* Remove dependency on `pt`.

* Move it to GPT2 tokenization tests.

* Added a few tests.
This commit is contained in:
Nicolas Patry 2022-09-15 17:12:58 +02:00 committed by GitHub
parent 578e18e002
commit 68bb33d770
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 70 additions and 13 deletions

View File

@ -282,8 +282,20 @@ class GPT2Converter(Converter):
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=self.original_tokenizer.add_prefix_space)
tokenizer.decoder = decoders.ByteLevel()
tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
if self.original_tokenizer.add_bos_token:
bos = self.original_tokenizer.bos_token
bos_token_id = self.original_tokenizer.bos_token_id
tokenizer.post_processor = processors.TemplateProcessing(
single=f"{bos}:0 $A:0", # token_type_id is 2 for Funnel transformer
pair=f"{bos}:0 $A:0 $B:1",
special_tokens=[
(bos, bos_token_id),
],
)
else:
# XXX trim_offsets=False actually means this post_processor doesn't
# really do anything.
tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
return tokenizer

View File

@ -146,16 +146,7 @@ class GPT2TokenizerFast(PreTrainedTokenizerFast):
**kwargs,
)
if kwargs.pop("add_bos_token", False):
model_id = kwargs.pop("name_or_path", "")
raise ValueError(
"Currenty GPT2's fast tokenizer does NOT support adding a BOS token."
"Instead you should use GPT2's slow tokenizer class `GPT2Tokenizer` as follows: \n"
f"`GPT2Tokenizer.from_pretrained('{model_id}')`\nor\n"
f"`AutoTokenizer.from_pretrained('{model_id}', use_fast=False)`\n"
"This issue will be fixed soon, see: https://github.com/huggingface/tokenizers/pull/1005."
" so that the fast tokenizer works correctly."
)
self.add_bos_token = kwargs.pop("add_bos_token", False)
pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:

View File

@ -18,7 +18,7 @@ import json
import os
import unittest
from transformers import GPT2Tokenizer, GPT2TokenizerFast
from transformers import AutoTokenizer, GPT2Tokenizer, GPT2TokenizerFast
from transformers.models.gpt2.tokenization_gpt2 import VOCAB_FILES_NAMES
from transformers.testing_utils import require_tokenizers
@ -275,3 +275,57 @@ class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
]
filtered_sequence = [x for x in filtered_sequence if x is not None]
self.assertEqual(encoded_sequence, filtered_sequence)
@require_tokenizers
class OPTTokenizationTest(unittest.TestCase):
def test_serialize_deserialize_fast_opt(self):
# More context:
# https://huggingface.co/wjmcat/opt-350m-paddle/discussions/1
# https://huggingface.slack.com/archives/C01N44FJDHT/p1653511495183519
# https://github.com/huggingface/transformers/pull/17088#discussion_r871246439
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m", from_slow=True)
text = "A photo of a cat"
tokens_ids = tokenizer.encode(
text,
)
self.assertEqual(tokens_ids, [2, 250, 1345, 9, 10, 4758])
tokenizer.save_pretrained("test_opt")
tokenizer = AutoTokenizer.from_pretrained("./test_opt")
tokens_ids = tokenizer.encode(
text,
)
self.assertEqual(tokens_ids, [2, 250, 1345, 9, 10, 4758])
def test_fast_slow_equivalence(self):
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m", use_slow=True)
text = "A photo of a cat"
tokens_ids = tokenizer.encode(
text,
)
# Same as above
self.assertEqual(tokens_ids, [2, 250, 1345, 9, 10, 4758])
def test_users_can_modify_bos(self):
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m", from_slow=True)
tokenizer.bos_token = "bos"
tokenizer.bos_token_id = tokenizer.get_vocab()["bos"]
text = "A photo of a cat"
tokens_ids = tokenizer.encode(
text,
)
# We changed the bos token
self.assertEqual(tokens_ids, [31957, 250, 1345, 9, 10, 4758])
tokenizer.save_pretrained("./tok")
tokenizer = AutoTokenizer.from_pretrained("./tok")
self.assertTrue(tokenizer.is_fast)
tokens_ids = tokenizer.encode(
text,
)
self.assertEqual(tokens_ids, [31957, 250, 1345, 9, 10, 4758])