Nit-added-tokens (#26538)

* fix stripping

* nits

* fix another test

* styling

* fix?

* update

* revert bad merge

* found the bug

* YES SIR

* is that change really required?

* make fast even faster

* re order functions
Arthur 2023-10-03 12:23:46 +02:00 committed by GitHub
parent 245da7ed38
commit 1a2e966cfe
3 changed files with 46 additions and 27 deletions


@@ -367,6 +367,25 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
        self._decode_use_source_tokenizer = False

    @property
    def is_fast(self) -> bool:
        return False

    @property
    def vocab_size(self) -> int:
        """
        `int`: Size of the base vocabulary (without the added tokens).
        """
        raise NotImplementedError

    @property
    def added_tokens_encoder(self) -> Dict[str, int]:
        """
        Returns the sorted mapping from string to index. The added tokens encoder is cached for performance
        optimisation in `self._added_tokens_encoder` for the slow tokenizers.
        """
        return {k.content: v for v, k in sorted(self._added_tokens_decoder.items(), key=lambda item: item[0])}

    @property
    def added_tokens_decoder(self) -> Dict[int, AddedToken]:
        """
@@ -389,17 +408,6 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
            self._added_tokens_decoder[index] = AddedToken(token) if isinstance(token, str) else token
            self._added_tokens_encoder[str(token)] = index

    @property
    def is_fast(self) -> bool:
        return False

    @property
    def vocab_size(self) -> int:
        """
        `int`: Size of the base vocabulary (without the added tokens).
        """
        raise NotImplementedError

    def get_added_vocab(self) -> Dict[str, int]:
        """
        Returns the added tokens in the vocabulary as a dictionary of token to index. Results might be different from

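The hunks above reorder `is_fast` and `vocab_size` and give the slow `PreTrainedTokenizer` its own `added_tokens_encoder` property. A minimal, self-contained sketch of the inversion that property performs, using a stand-in dataclass instead of the real `tokenizers.AddedToken` (the ids and contents below are made up):

from dataclasses import dataclass
from typing import Dict


@dataclass
class AddedToken:
    # Stand-in for tokenizers.AddedToken; only the `content` field matters here.
    content: str


# Hypothetical contents of `self._added_tokens_decoder`: index -> AddedToken
_added_tokens_decoder: Dict[int, AddedToken] = {
    32001: AddedToken("<pad>"),
    32000: AddedToken("<cls>"),
}

# Same expression as the new property: sort by the integer index (item[0]),
# then invert into a content -> index mapping.
added_tokens_encoder = {
    k.content: v for v, k in sorted(_added_tokens_decoder.items(), key=lambda item: item[0])
}
print(added_tokens_encoder)  # {'<cls>': 32000, '<pad>': 32001}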

@@ -846,15 +846,26 @@ class SpecialTokensMixin:
        # We directly set the hidden value to allow initialization with special tokens
        # which are not yet in the vocabulary. Necessary for serialization/de-serialization
        # TODO clean this up at some point (probably by switching to fast tokenizers)
        for key, value in kwargs.items():
            if value is None:
                continue
            if key in self.SPECIAL_TOKENS_ATTRIBUTES:
                if key == "additional_special_tokens":
                    # TODO THIS IS NASTY! Will always reset tokens to default rstrip and lstrip because self.set_attr on strings
                    # will not check the addedtokens decoder. WILL FIX TOMORROW
                    assert isinstance(value, (list, tuple)), f"Value {value} is not a list or tuple"
                    assert all(
                        isinstance(t, (str, AddedToken)) for t in value
                    ), "One of the tokens is not a string or an AddedToken"
                    if hasattr(self, "added_tokens_encoder"):
                        extended_token = []
                        for token in value:
                            if isinstance(token, str) and str(token) in self.added_tokens_encoder:
                                extended_token.append(self.added_tokens_decoder[self.added_tokens_encoder[str(token)]])
                            else:
                                extended_token.append(token)
                        value = extended_token
                    setattr(self, key, value)
                elif isinstance(value, (str)):
                    value = AddedToken(value, normalized=False, special=True)
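The new `hasattr(self, "added_tokens_encoder")` block re-uses an existing `AddedToken` when `additional_special_tokens` contains a plain string that is already in the added vocabulary, so the stored lstrip/rstrip flags are not reset (the issue the NASTY TODO comment describes). A rough sketch of that behaviour in isolation, with hypothetical encoder/decoder contents:

from typing import Dict, List, Union

from tokenizers import AddedToken

# Hypothetical added-token state of a tokenizer.
added_tokens_decoder: Dict[int, AddedToken] = {32000: AddedToken("<extra>", lstrip=True, rstrip=False)}
added_tokens_encoder: Dict[str, int] = {"<extra>": 32000}


def upgrade_strings(value: List[Union[str, AddedToken]]) -> List[Union[str, AddedToken]]:
    # Mirrors the new block: plain strings that already exist as added tokens are
    # swapped for their AddedToken objects so the stored stripping flags survive.
    extended_token = []
    for token in value:
        if isinstance(token, str) and token in added_tokens_encoder:
            extended_token.append(added_tokens_decoder[added_tokens_encoder[token]])
        else:
            extended_token.append(token)
    return extended_token


print(upgrade_strings(["<extra>", "<brand_new>"]))  # AddedToken for '<extra>', raw string for '<brand_new>'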
@@ -1674,14 +1685,6 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
        """Sets processor class as an attribute."""
        self._processor_class = processor_class

    @property
    def added_tokens_encoder(self) -> Dict[str, int]:
        """
        Returns the sorted mapping from string to index. The added tokens encoder is cached for performance
        optimisation in `self._added_tokens_encoder` for the slow tokenizers.
        """
        return {k.content: v for v, k in sorted(self._added_tokens_decoder.items(), key=lambda item: item[0])}

    @property
    def added_tokens_decoder(self) -> Dict[int, AddedToken]:
        raise NotImplementedError()
@@ -2196,9 +2199,13 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
            for idx, token in init_kwargs["added_tokens_decoder"].items():
                if isinstance(token, dict):
                    token = AddedToken(**token)
                if isinstance(token, AddedToken):
                    added_tokens_decoder[int(idx)] = token
                    if str(token) in additional_special_tokens:
                        # at this point the token is in `additional_special_tokens` as an str, let's add the AddedToken info
                        additional_special_tokens.remove(str(token))
                    if token.special and token not in additional_special_tokens:
                        additional_special_tokens.append(token)
                else:
                    raise ValueError(
                        f"Found a {token.__class__} in the saved `added_tokens_decoder`, should be a dictionary."
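For reference, a stripped-down version of the loading path this hunk adjusts: serialized `added_tokens_decoder` entries are turned back into `AddedToken` objects, and any matching plain string in `additional_special_tokens` is replaced by the richer object. The payload below is a made-up example, not taken from a real checkpoint, and assumes a `tokenizers` version whose `AddedToken` exposes the `special` flag:

from tokenizers import AddedToken

# Illustrative serialized form, as it could appear under "added_tokens_decoder"
# in a tokenizer_config.json (ids and flags are hypothetical).
serialized = {
    "32000": {"content": "<extra>", "lstrip": False, "rstrip": False, "normalized": False, "special": True},
}
additional_special_tokens = ["<extra>"]

added_tokens_decoder = {}
for idx, token in serialized.items():
    if isinstance(token, dict):
        token = AddedToken(**token)
    added_tokens_decoder[int(idx)] = token
    # Swap the bare string for the AddedToken so its flags survive the round-trip.
    if str(token) in additional_special_tokens:
        additional_special_tokens.remove(str(token))
    if token.special and token not in additional_special_tokens:
        additional_special_tokens.append(token)

print(added_tokens_decoder, additional_special_tokens)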
@@ -2381,9 +2388,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
        tokenizer_config = copy.deepcopy(self.init_kwargs)

        # TODO: Ensure the modified attributes (those are also in the __init__ kwargs) will give identical tokenizers
        # target_keys = self.init_kwargs.keys()
        target_keys = ["model_max_length", "clean_up_tokenization_spaces", "additional_special_tokens"]
        target_keys = list(self.init_kwargs.keys())
        target_keys += ["model_max_length", "clean_up_tokenization_spaces", "additional_special_tokens"]
        for k in target_keys:
            if hasattr(self, k):
                tokenizer_config[k] = getattr(self, k)
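A stripped-down illustration of the `save_pretrained` change above: instead of refreshing only three hard-coded keys, every key from `init_kwargs` is now re-read from the live attribute before the config is written out. The `DummyTokenizer` class and its attributes are purely illustrative:

import copy


class DummyTokenizer:
    # Purely illustrative stand-in: `init_kwargs` holds the values the object was
    # constructed with, the attributes hold the values as they are now.
    def __init__(self):
        self.init_kwargs = {"model_max_length": 512, "padding_side": "right"}
        self.model_max_length = 1024  # changed after __init__
        self.padding_side = "left"    # changed after __init__


tokenizer = DummyTokenizer()
tokenizer_config = copy.deepcopy(tokenizer.init_kwargs)

# New behaviour: refresh every init kwarg from the live attribute, plus the
# three keys that were previously the only ones refreshed.
target_keys = list(tokenizer.init_kwargs.keys())
target_keys += ["model_max_length", "clean_up_tokenization_spaces", "additional_special_tokens"]
for k in target_keys:
    if hasattr(tokenizer, k):
        tokenizer_config[k] = getattr(tokenizer, k)

print(tokenizer_config)  # {'model_max_length': 1024, 'padding_side': 'left'}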


@@ -185,6 +185,14 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
    def vocab(self) -> Dict[str, int]:
        return self.get_vocab()

    @property
    def added_tokens_encoder(self) -> Dict[str, int]:
        """
        Returns the sorted mapping from string to index. The added tokens encoder is cached for performance
        optimisation in `self._added_tokens_encoder` for the slow tokenizers.
        """
        return {k.content: v for v, k in sorted(self.added_tokens_decoder.items(), key=lambda item: item[0])}

    @property
    def added_tokens_decoder(self) -> Dict[int, AddedToken]:
        """
@@ -202,10 +210,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
        Returns:
            `Dict[str, int]`: The added tokens.
        """
        base_vocab = self._tokenizer.get_vocab(with_added_tokens=False)
        full_vocab = self._tokenizer.get_vocab(with_added_tokens=True)
        added_vocab = {tok: index for tok, index in full_vocab.items() if tok not in base_vocab}
        return added_vocab
        return {k.content: v for v, k in sorted(self.added_tokens_decoder.items(), key=lambda item: item[0])}

    def __len__(self) -> int:
        """