Nit-added-tokens (#26538)
* fix stripping
* nits
* fix another test
* styling
* fix?
* update
* revert bad merge
* found the bug
* YES SIR
* is that change really required?
* make fast even faster
* re order functions
This commit is contained in:
parent 245da7ed38
commit 1a2e966cfe
@@ -367,6 +367,25 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
         self._decode_use_source_tokenizer = False

+    @property
+    def is_fast(self) -> bool:
+        return False
+
+    @property
+    def vocab_size(self) -> int:
+        """
+        `int`: Size of the base vocabulary (without the added tokens).
+        """
+        raise NotImplementedError
+
+    @property
+    def added_tokens_encoder(self) -> Dict[str, int]:
+        """
+        Returns the sorted mapping from string to index. The added tokens encoder is cached for performance
+        optimisation in `self._added_tokens_encoder` for the slow tokenizers.
+        """
+        return {k.content: v for v, k in sorted(self._added_tokens_decoder.items(), key=lambda item: item[0])}
+
     @property
     def added_tokens_decoder(self) -> Dict[int, AddedToken]:
         """
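For illustration, the new `added_tokens_encoder` property simply inverts and sorts the index-to-token decoder map. A minimal standalone sketch (the sample tokens and ids are made up):

    from tokenizers import AddedToken

    # Stand-in for the tokenizer's internal `_added_tokens_decoder` mapping.
    _added_tokens_decoder = {
        32001: AddedToken("<extra_1>", lstrip=True),
        32000: AddedToken("<extra_0>"),
    }

    # Sorting by index keeps the derived encoder deterministic regardless of insertion order.
    encoder = {k.content: v for v, k in sorted(_added_tokens_decoder.items(), key=lambda item: item[0])}
    print(encoder)  # {'<extra_0>': 32000, '<extra_1>': 32001}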
@@ -389,17 +408,6 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
             self._added_tokens_decoder[index] = AddedToken(token) if isinstance(token, str) else token
             self._added_tokens_encoder[str(token)] = index

-    @property
-    def is_fast(self) -> bool:
-        return False
-
-    @property
-    def vocab_size(self) -> int:
-        """
-        `int`: Size of the base vocabulary (without the added tokens).
-        """
-        raise NotImplementedError
-
     def get_added_vocab(self) -> Dict[str, int]:
         """
         Returns the added tokens in the vocabulary as a dictionary of token to index. Results might be different from
@@ -846,15 +846,26 @@ class SpecialTokensMixin:
         # We directly set the hidden value to allow initialization with special tokens
        # which are not yet in the vocabulary. Necessary for serialization/de-serialization
        # TODO clean this up at some point (probably by switching to fast tokenizers)

        for key, value in kwargs.items():
            if value is None:
                continue
            if key in self.SPECIAL_TOKENS_ATTRIBUTES:
                if key == "additional_special_tokens":
+                    # TODO THIS IS NASTY! Will always reset tokens to default rstrip and lstrip because self.set_attr on strings
+                    # will not check the addedtokens decoder. WILL FIX TOMORROW
                    assert isinstance(value, (list, tuple)), f"Value {value} is not a list or tuple"
                    assert all(
                        isinstance(t, (str, AddedToken)) for t in value
                    ), "One of the tokens is not a string or an AddedToken"
+                    if hasattr(self, "added_tokens_encoder"):
+                        extended_token = []
+                        for token in value:
+                            if isinstance(token, str) and str(token) in self.added_tokens_encoder:
+                                extended_token.append(self.added_tokens_decoder[self.added_tokens_encoder[str(token)]])
+                            else:
+                                extended_token.append(token)
+                        value = extended_token
                    setattr(self, key, value)
                elif isinstance(value, (str)):
                    value = AddedToken(value, normalized=False, special=True)
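In plain terms, the new `hasattr` branch above reuses already-registered `AddedToken` objects when `additional_special_tokens` are passed as bare strings, so their `lstrip`/`rstrip` configuration is not silently reset. A minimal standalone sketch of that lookup (the registry dicts are hypothetical stand-ins for the tokenizer's real attributes):

    from tokenizers import AddedToken

    # Hypothetical registries standing in for the tokenizer's attributes.
    added_tokens_decoder = {32000: AddedToken("<pad>", rstrip=True)}
    added_tokens_encoder = {"<pad>": 32000}

    value = ["<pad>", AddedToken("<mask>")]
    extended_token = []
    for token in value:
        if isinstance(token, str) and token in added_tokens_encoder:
            # Reuse the registered AddedToken so its rstrip/lstrip settings survive.
            extended_token.append(added_tokens_decoder[added_tokens_encoder[token]])
        else:
            extended_token.append(token)
    # extended_token[0] is the registered AddedToken("<pad>", rstrip=True), not a plain string.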
@@ -1674,14 +1685,6 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         """Sets processor class as an attribute."""
         self._processor_class = processor_class

-    @property
-    def added_tokens_encoder(self) -> Dict[str, int]:
-        """
-        Returns the sorted mapping from string to index. The added tokens encoder is cached for performance
-        optimisation in `self._added_tokens_encoder` for the slow tokenizers.
-        """
-        return {k.content: v for v, k in sorted(self._added_tokens_decoder.items(), key=lambda item: item[0])}
-
     @property
     def added_tokens_decoder(self) -> Dict[int, AddedToken]:
         raise NotImplementedError()
@@ -2196,9 +2199,13 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         for idx, token in init_kwargs["added_tokens_decoder"].items():
             if isinstance(token, dict):
                 token = AddedToken(**token)

             if isinstance(token, AddedToken):
                 added_tokens_decoder[int(idx)] = token
+                if str(token) in additional_special_tokens:
+                    # at this point the token is in `additional_special_tokens` as an str, let's add the AddedToken info
+                    additional_special_tokens.remove(str(token))
+                if token.special and token not in additional_special_tokens:
+                    additional_special_tokens.append(token)
             else:
                 raise ValueError(
                     f"Found a {token.__class__} in the saved `added_tokens_decoder`, should be a dictionary."
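For context, `added_tokens_decoder` entries arrive from `tokenizer_config.json` as plain dicts; the loop above rehydrates them into `AddedToken` objects. A rough sketch of that round-trip (the dict keys shown are an assumption about the serialized format):

    from tokenizers import AddedToken

    # Hypothetical serialized entry, roughly as it appears in tokenizer_config.json.
    saved = {
        "32000": {"content": "<extra_0>", "lstrip": False, "rstrip": False,
                  "normalized": False, "single_word": False, "special": True},
    }

    added_tokens_decoder = {}
    for idx, token in saved.items():
        if isinstance(token, dict):
            token = AddedToken(**token)
        added_tokens_decoder[int(idx)] = token  # keys become ints, values AddedToken objects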
@@ -2381,9 +2388,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):

         tokenizer_config = copy.deepcopy(self.init_kwargs)

-        # TODO: Ensure the modified attributes (those are also in the __init__ kwargs) will give identical tokenizers
-        # target_keys = self.init_kwargs.keys()
-        target_keys = ["model_max_length", "clean_up_tokenization_spaces", "additional_special_tokens"]
+        target_keys = list(self.init_kwargs.keys())
+        target_keys += ["model_max_length", "clean_up_tokenization_spaces", "additional_special_tokens"]
         for k in target_keys:
             if hasattr(self, k):
                 tokenizer_config[k] = getattr(self, k)
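The effect of the change: every `__init__` kwarg is now refreshed from the live attribute before serialization, instead of only the three hard-coded keys, so mutations made after init get saved. A toy illustration with a simplified stand-in class (not the real `PreTrainedTokenizerBase`):

    import copy

    class ToyTokenizer:
        def __init__(self, model_max_length=512, padding_side="right"):
            self.init_kwargs = {"model_max_length": model_max_length, "padding_side": padding_side}
            self.model_max_length = model_max_length
            self.padding_side = padding_side

    tok = ToyTokenizer()
    tok.padding_side = "left"  # attribute mutated after init

    tokenizer_config = copy.deepcopy(tok.init_kwargs)
    target_keys = list(tok.init_kwargs.keys())
    target_keys += ["model_max_length", "clean_up_tokenization_spaces", "additional_special_tokens"]
    for k in target_keys:
        if hasattr(tok, k):
            tokenizer_config[k] = getattr(tok, k)

    print(tokenizer_config["padding_side"])  # 'left': the current value is saved, not the stale init value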
@@ -185,6 +185,14 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
     def vocab(self) -> Dict[str, int]:
         return self.get_vocab()

+    @property
+    def added_tokens_encoder(self) -> Dict[str, int]:
+        """
+        Returns the sorted mapping from string to index. The added tokens encoder is cached for performance
+        optimisation in `self._added_tokens_encoder` for the slow tokenizers.
+        """
+        return {k.content: v for v, k in sorted(self.added_tokens_decoder.items(), key=lambda item: item[0])}
+
     @property
     def added_tokens_decoder(self) -> Dict[int, AddedToken]:
         """
@@ -202,10 +210,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
         Returns:
             `Dict[str, int]`: The added tokens.
         """
-        base_vocab = self._tokenizer.get_vocab(with_added_tokens=False)
-        full_vocab = self._tokenizer.get_vocab(with_added_tokens=True)
-        added_vocab = {tok: index for tok, index in full_vocab.items() if tok not in base_vocab}
-        return added_vocab
+        return {k.content: v for v, k in sorted(self.added_tokens_decoder.items(), key=lambda item: item[0])}

     def __len__(self) -> int:
         """
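The rewritten fast `get_added_vocab` no longer materializes two full vocabularies just to diff them; it reads the added tokens straight from the backend tokenizer. A rough sketch of both approaches using the `tokenizers` library directly (the model name is only an example, and `get_added_tokens_decoder` assumes tokenizers >= 0.14):

    from tokenizers import Tokenizer

    tok = Tokenizer.from_pretrained("bert-base-uncased")

    # Old approach: build both vocabularies and diff them (cost scales with vocab size).
    base_vocab = tok.get_vocab(with_added_tokens=False)
    full_vocab = tok.get_vocab(with_added_tokens=True)
    added_by_diff = {t: i for t, i in full_vocab.items() if t not in base_vocab}

    # New approach: read only the added-tokens map (cost scales with the number of added tokens).
    # Note: this also reports special tokens that overlap the base vocabulary, so the
    # two results can legitimately differ for some tokenizers.
    added_by_decoder = {t.content: i for i, t in sorted(tok.get_added_tokens_decoder().items())}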