transformers.js/scripts/extra/esm.py

from transformers.convert_slow_tokenizer import Converter
from tokenizers import Tokenizer, pre_tokenizers, processors
from tokenizers.models import WordPiece
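
# Builds a fast `tokenizers` tokenizer from the slow ESM tokenizer so that it
# can be exported (e.g. as tokenizer.json) for use with transformers.js.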


class EsmConverter(Converter):
    def converted(self) -> Tokenizer:
        vocab = self.original_tokenizer.vocab
        tokenizer = Tokenizer(WordPiece(
            vocab,
            continuing_subword_prefix='',
            max_input_chars_per_word=int(1e10),
            unk_token=str(self.original_tokenizer.unk_token),
        ))

        tokenizer.pre_tokenizer = pre_tokenizers.WhitespaceSplit()

        cls = str(self.original_tokenizer.cls_token)
        cls_token_id = self.original_tokenizer.cls_token_id

        # No sep token in ESM vocabulary
        sep = str(self.original_tokenizer.eos_token)
        sep_token_id = self.original_tokenizer.eos_token_id
        if sep_token_id is None:
            tokenizer.post_processor = processors.TemplateProcessing(
                single=f"{cls}:0 $A:0",
                special_tokens=[
                    (cls, cls_token_id),
                ],
            )
        else:
            tokenizer.post_processor = processors.TemplateProcessing(
                single=f"{cls}:0 $A:0 {sep}:0",
                pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1",
                special_tokens=[
                    (cls, cls_token_id),
                    (sep, sep_token_id),
                ],
            )

        # For some reason, all tokens are added: none of them are special, but they all need special splitting.
        # See https://github.com/huggingface/transformers/blob/df5c5c62ae253055336f5bb0828ca8e3e15ab6bd/src/transformers/models/esm/tokenization_esm.py#L79-L80
        special_tokens = []
        other_tokens = []
        for token, token_id in vocab.items():
            if token[0] == '<' and token[-1] == '>' and token_id <= 3:
                special_tokens.append(token)
            else:
                other_tokens.append(token)

        tokenizer.add_special_tokens(special_tokens)
        tokenizer.add_tokens(other_tokens)

        return tokenizer


def generate_fast_tokenizer(tokenizer):
    # The slow EsmTokenizer keeps its vocabulary in `_token_to_id`; expose it as
    # `.vocab` so that `EsmConverter.converted()` can read it.
    tokenizer.vocab = tokenizer._token_to_id
    return EsmConverter(tokenizer).converted()
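

# Minimal usage sketch (illustrative only; the checkpoint name below is just an
# example of an ESM model hosted on the Hugging Face Hub).
if __name__ == '__main__':
    from transformers import EsmTokenizer

    slow_tokenizer = EsmTokenizer.from_pretrained('facebook/esm2_t6_8M_UR50D')
    fast_tokenizer = generate_fast_tokenizer(slow_tokenizer)

    # The resulting `tokenizers.Tokenizer` can be serialized to tokenizer.json.
    fast_tokenizer.save('tokenizer.json')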