⚠️⚠️[`T5Tokenize`] Fix T5 family tokenizers⚠️⚠️ (#24565)

* don't add space before single letter chars that don't have a merge

* fix the fix

* fixup

* add a test

* more testing

* fixup

* hack to make sure fast is also fixed

* update switch transformers test

* revert convert slow

* Update src/transformers/models/t5/tokenization_t5.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* add typechecking

* quality

---------

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Arthur 2023-06-30 14:00:43 +09:00 committed by GitHub
parent 9e28750287
commit b52a03cd3b
3 changed files with 53 additions and 7 deletions

src/transformers/models/t5/tokenization_t5.py

@@ -19,11 +19,15 @@ import os
 import re
 import warnings
 from shutil import copyfile
-from typing import Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

 import sentencepiece as spm

 from ...tokenization_utils import PreTrainedTokenizer
+
+if TYPE_CHECKING:
+    from ...tokenization_utils_base import TextInput
+
 from ...utils import logging
@@ -51,6 +55,8 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
     "t5-11b": 512,
 }

+SPIECE_UNDERLINE = "▁"
+

 class T5Tokenizer(PreTrainedTokenizer):
     """
@@ -294,9 +300,17 @@ class T5Tokenizer(PreTrainedTokenizer):
         self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
         self.sp_model.Load(self.vocab_file)

+    def tokenize(self, text: "TextInput", **kwargs) -> List[str]:
+        if not text.startswith(" "):
+            text = " " + text
+        return super().tokenize(text, **kwargs)
+
     def _tokenize(self, text: str) -> List[str]:
         """Take as input a string and return a list of strings (tokens) for words/sub-words"""
-        return self.sp_model.encode(text, out_type=str)
+        tokens = self.sp_model.encode(text, out_type=str)
+        if not text.startswith(" ") and tokens[0] == SPIECE_UNDERLINE:
+            tokens = tokens[1:]
+        return tokens

     def _convert_token_to_id(self, token):
         """Converts a token (str) in an id using the vocab."""

tests/models/switch_transformers/test_modeling_switch_transformers.py

@@ -1149,7 +1149,7 @@ class SwitchTransformerModelIntegrationTests(unittest.TestCase):
         model = SwitchTransformersForConditionalGeneration.from_pretrained(
             "google/switch-base-8", torch_dtype=torch.bfloat16
         ).eval()
-        tokenizer = AutoTokenizer.from_pretrained("t5-small")
+        tokenizer = AutoTokenizer.from_pretrained("t5-small", use_fast=False)
         model = model.to(torch_device)

         input_ids = tokenizer(
@@ -1160,13 +1160,13 @@ class SwitchTransformerModelIntegrationTests(unittest.TestCase):
         self.assertEqual(output_str, "drink.")

         input_ids = tokenizer(
-            "A <extra_id_0> walks into a bar a orders a <extra_id_1> with <extra_id_2> pinch of <extra_id_3>.",
+            "A <extra_id_0> walks into a bar and orders a <extra_id_1> with <extra_id_2> pinch of <extra_id_3>.",
             return_tensors="pt",
         ).input_ids.to(torch_device)
         sequences = model.generate(input_ids)
         output_str = tokenizer.batch_decode(sequences, skip_special_tokens=False)[0]
-        EXPECTED_OUTPUT = "<pad><extra_id_0> man<extra_id_1> beer<extra_id_2> a<extra_id_3> salt<extra_id_4>.</s>"
+        EXPECTED_OUTPUT = "<pad><extra_id_0> man<extra_id_1> beer<extra_id_2> a<extra_id_3> whiskey<extra_id_4>.</s>"
         self.assertEqual(output_str, EXPECTED_OUTPUT)

     def test_small_batch_generate(self):
@@ -1174,10 +1174,10 @@ class SwitchTransformerModelIntegrationTests(unittest.TestCase):
         model = SwitchTransformersForConditionalGeneration.from_pretrained(
             "google/switch-base-8", torch_dtype=torch.bfloat16
         ).eval()
-        tokenizer = AutoTokenizer.from_pretrained("t5-small")
+        tokenizer = AutoTokenizer.from_pretrained("t5-small", use_fast=False)

         inputs = [
-            "A <extra_id_0> walks into a bar a orders a <extra_id_1> with <extra_id_2> pinch of <extra_id_3>."
+            "A <extra_id_0> walks into a bar and orders a <extra_id_1> with <extra_id_2> pinch of <extra_id_3>."
         ] * BATCH_SIZE

         encoded_input = tokenizer.batch_encode_plus(inputs, return_tensors="pt")
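The switch to `use_fast=False` presumably pins these integration tests to the slow, sentencepiece-backed tokenizer that this commit changes directly; the fast side is only covered by the "hack to make sure fast is also fixed" noted in the commit message. A quick local comparison of the two, assuming the `t5-small` checkpoint can be downloaded (the printed ids depend on the installed tokenizer versions, so they are illustrative rather than guaranteed):

from transformers import AutoTokenizer

prompt = "A <extra_id_0> walks into a bar and orders a <extra_id_1> with <extra_id_2> pinch of <extra_id_3>."

slow = AutoTokenizer.from_pretrained("t5-small", use_fast=False)
fast = AutoTokenizer.from_pretrained("t5-small", use_fast=True)

print(slow(prompt).input_ids)  # slow (sentencepiece) tokenizer, patched by this commit
print(fast(prompt).input_ids)  # fast (tokenizers) tokenizer, for comparison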

tests/models/t5/test_tokenization_t5.py

@@ -399,3 +399,35 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
     def test_get_sentinel_token_ids_for_fasttokenizer(self):
         tokenizer = T5TokenizerFast(SAMPLE_VOCAB, extra_ids=10)
         self.assertListEqual(sorted(tokenizer.get_sentinel_token_ids()), sorted(range(1000, 1010)))
+
+    def test_encode_extra_ids(self):
+        tokenizer = T5Tokenizer(SAMPLE_VOCAB, extra_ids=0)
+        tokenizer.add_special_tokens({"additional_special_tokens": ["<extra_id_0>"]})
+        tokenizer._create_trie(tokenizer.all_special_tokens)
+        # TODO ArthurZ the above is necessary as addedTokens / initialization sucks. Trie is not correctly created
+        # So the extra ids are split....
+        input_ids = tokenizer.encode(". Hello")
+        self.assertEquals(input_ids, [7, 4, 156, 86, 20, 2])
+        tokens = tokenizer.tokenize(". Hello")
+        self.assertEquals(tokens, ["▁", ".", "▁He", "ll", "o"])
+
+        input_ids = tokenizer.encode(" . Hello")
+        self.assertEquals(input_ids, [7, 4, 156, 86, 20, 2])
+        tokens = tokenizer.tokenize(" . Hello")
+        self.assertEquals(tokens, ["▁", ".", "▁He", "ll", "o"])
+
+        input_ids = tokenizer.encode("Hello, <extra_id_0>I")
+        self.assertEquals(input_ids, [156, 86, 20, 3, 999, 8, 2])
+        tokens = tokenizer.tokenize("Hello, <extra_id_0>I")
+        self.assertEquals(tokens, ["▁He", "ll", "o", ",", "<extra_id_0>", "▁I"])
+
+        input_ids = tokenizer.encode("Hello, <extra_id_0>,")
+        self.assertEquals(input_ids, [156, 86, 20, 3, 999, 3, 2])
+        tokens = tokenizer.tokenize("Hello, <extra_id_0>,")
+        self.assertEquals(tokens, ["▁He", "ll", "o", ",", "<extra_id_0>", ","])
+
+        input_ids = tokenizer.encode(" <extra_id_0> ,")
+        self.assertEquals(input_ids, [999, 3, 2])
+        tokens = tokenizer.tokenize(" <extra_id_0> ,")
+        self.assertEquals(tokens, ["<extra_id_0>", ","])  # spaces are eaten by rstrip / lstrip