Update conversion script to support `marian` models
This commit is contained in:
parent
84b931f2f1
commit
645e5ceba3
|
@ -197,6 +197,7 @@ def main():
|
|||
|
||||
model = TasksManager.get_model_from_task(
|
||||
task, model_path,
|
||||
framework='pt',
|
||||
)
|
||||
|
||||
onnx_config_constructor = TasksManager.get_exporter_config_constructor(
|
||||
|
@ -230,6 +231,15 @@ def main():
|
|||
# Save tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
tokenizer.save_pretrained(output_model_folder)
|
||||
|
||||
# Handle special cases
|
||||
if model.config.model_type == 'marian':
|
||||
import json
|
||||
from .tokenizers.marian import generate_tokenizer_json
|
||||
tokenizer_json = generate_tokenizer_json(model_path, tokenizer)
|
||||
|
||||
with open(os.path.join(output_model_folder, 'tokenizer.json'), 'w', encoding='utf-8') as fp:
|
||||
json.dump(tokenizer_json, fp)
|
||||
except KeyError:
|
||||
pass # No Tokenizer
|
||||
|
||||
|
|
102
scripts/tasks.py
102
scripts/tasks.py
|
@ -149,6 +149,106 @@ SUPPORTED_MODELS = {
|
|||
'token-classification',
|
||||
]
|
||||
},
|
||||
'marian': {
|
||||
'Helsinki-NLP/opus-mt-en-es': [
|
||||
'default',
|
||||
'seq2seq-lm-with-past',
|
||||
],
|
||||
'Helsinki-NLP/opus-mt-es-en': [
|
||||
'default',
|
||||
'seq2seq-lm-with-past',
|
||||
],
|
||||
'Helsinki-NLP/opus-mt-en-fr': [
|
||||
'default',
|
||||
'seq2seq-lm-with-past',
|
||||
],
|
||||
'Helsinki-NLP/opus-mt-fr-en': [
|
||||
'default',
|
||||
'seq2seq-lm-with-past',
|
||||
],
|
||||
'Helsinki-NLP/opus-mt-en-hi': [
|
||||
'default',
|
||||
'seq2seq-lm-with-past',
|
||||
],
|
||||
'Helsinki-NLP/opus-mt-hi-en': [
|
||||
'default',
|
||||
'seq2seq-lm-with-past',
|
||||
],
|
||||
'Helsinki-NLP/opus-mt-en-de': [
|
||||
'default',
|
||||
'seq2seq-lm-with-past',
|
||||
],
|
||||
'Helsinki-NLP/opus-mt-de-en': [
|
||||
'default',
|
||||
'seq2seq-lm-with-past',
|
||||
],
|
||||
'Helsinki-NLP/opus-mt-en-ru': [
|
||||
'default',
|
||||
'seq2seq-lm-with-past',
|
||||
],
|
||||
'Helsinki-NLP/opus-mt-ru-en': [
|
||||
'default',
|
||||
'seq2seq-lm-with-past',
|
||||
],
|
||||
'Helsinki-NLP/opus-mt-en-it': [
|
||||
'default',
|
||||
'seq2seq-lm-with-past',
|
||||
],
|
||||
'Helsinki-NLP/opus-mt-it-en': [
|
||||
'default',
|
||||
'seq2seq-lm-with-past',
|
||||
],
|
||||
'Helsinki-NLP/opus-mt-en-ar': [
|
||||
'default',
|
||||
'seq2seq-lm-with-past',
|
||||
],
|
||||
'Helsinki-NLP/opus-mt-ar-en': [
|
||||
'default',
|
||||
'seq2seq-lm-with-past',
|
||||
],
|
||||
'Helsinki-NLP/opus-mt-en-zh': [
|
||||
'default',
|
||||
'seq2seq-lm-with-past',
|
||||
],
|
||||
'Helsinki-NLP/opus-mt-zh-en': [
|
||||
'default',
|
||||
'seq2seq-lm-with-past',
|
||||
],
|
||||
'Helsinki-NLP/opus-mt-en-sv': [
|
||||
'default',
|
||||
'seq2seq-lm-with-past',
|
||||
],
|
||||
'Helsinki-NLP/opus-mt-sv-en': [
|
||||
'default',
|
||||
'seq2seq-lm-with-past',
|
||||
],
|
||||
'Helsinki-NLP/opus-mt-en-mul': [
|
||||
'default',
|
||||
'seq2seq-lm-with-past',
|
||||
],
|
||||
'Helsinki-NLP/opus-mt-mul-en': [
|
||||
'default',
|
||||
'seq2seq-lm-with-past',
|
||||
],
|
||||
'Helsinki-NLP/opus-mt-en-nl': [
|
||||
'default',
|
||||
'seq2seq-lm-with-past',
|
||||
],
|
||||
'Helsinki-NLP/opus-mt-nl-en': [
|
||||
'default',
|
||||
'seq2seq-lm-with-past',
|
||||
],
|
||||
'Helsinki-NLP/opus-mt-en-fi': [
|
||||
'default',
|
||||
'seq2seq-lm-with-past',
|
||||
],
|
||||
'Helsinki-NLP/opus-mt-fi-en': [
|
||||
'default',
|
||||
'seq2seq-lm-with-past',
|
||||
],
|
||||
|
||||
# TODO add more models, or dynamically generate this list
|
||||
},
|
||||
'mobilebert': {
|
||||
'google/mobilebert-uncased': [
|
||||
'default',
|
||||
|
@ -271,7 +371,7 @@ def main():
|
|||
for model_id, tasks in model_ids.items():
|
||||
for task in tasks:
|
||||
print(
|
||||
f'python ./scripts/convert.py --model_id {model_id} --from_hub --quantize --task {task}')
|
||||
f'python -m scripts.convert --model_id {model_id} --from_hub --quantize --task {task}')
|
||||
print()
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,100 @@
|
|||
import json
|
||||
from transformers.utils import cached_file
|
||||
|
||||
|
||||
def generate_tokenizer_json(model_path, tokenizer):
|
||||
# Marian models use two separate tokenizers for source and target languages.
|
||||
# So, we merge them into a single tokenizer.
|
||||
|
||||
vocab_file = cached_file(model_path, 'vocab.json')
|
||||
with open(vocab_file) as fp:
|
||||
vocab = json.load(fp)
|
||||
|
||||
added_tokens = [
|
||||
dict(
|
||||
id=vocab.get(x),
|
||||
special=True,
|
||||
content=x,
|
||||
single_word=False,
|
||||
lstrip=False,
|
||||
rstrip=False,
|
||||
normalized=False
|
||||
)
|
||||
for x in tokenizer.all_special_tokens
|
||||
]
|
||||
|
||||
tokenizer_json = {
|
||||
'version': '1.0',
|
||||
'truncation': None,
|
||||
'padding': None,
|
||||
'added_tokens': added_tokens,
|
||||
'normalizer': {
|
||||
'type': 'Precompiled',
|
||||
'precompiled_charsmap': None # TODO add this
|
||||
},
|
||||
'pre_tokenizer': {
|
||||
'type': 'Sequence',
|
||||
'pretokenizers': [
|
||||
{
|
||||
'type': 'WhitespaceSplit'
|
||||
},
|
||||
{
|
||||
'type': 'Metaspace',
|
||||
'replacement': '\u2581',
|
||||
'add_prefix_space': True
|
||||
}
|
||||
]
|
||||
},
|
||||
'post_processor': {
|
||||
'type': 'TemplateProcessing', 'single': [
|
||||
{'Sequence': {'id': 'A', 'type_id': 0}},
|
||||
{'SpecialToken': {'id': tokenizer.eos_token, 'type_id': 0}}
|
||||
],
|
||||
'pair': [
|
||||
{'Sequence': {'id': 'A', 'type_id': 0}},
|
||||
{'SpecialToken': {'id': tokenizer.eos_token, 'type_id': 0}},
|
||||
{'Sequence': {'id': 'B', 'type_id': 0}},
|
||||
{'SpecialToken': {'id': tokenizer.eos_token, 'type_id': 0}}
|
||||
],
|
||||
'special_tokens': {
|
||||
tokenizer.eos_token: {
|
||||
'id': tokenizer.eos_token,
|
||||
'ids': [tokenizer.eos_token_id],
|
||||
'tokens': [tokenizer.eos_token]
|
||||
}
|
||||
}
|
||||
},
|
||||
'decoder': {
|
||||
'type': 'Metaspace',
|
||||
'replacement': '\u2581',
|
||||
'add_prefix_space': True
|
||||
},
|
||||
'model': {
|
||||
'type': 'Unigram',
|
||||
'unk_id': 2,
|
||||
}
|
||||
}
|
||||
|
||||
# NOTE: Must have sentencepiece installed
|
||||
spm_source = tokenizer.spm_source
|
||||
spm_target = tokenizer.spm_target
|
||||
|
||||
src_vocab_dict = {
|
||||
spm_source.IdToPiece(i): spm_source.GetScore(i)
|
||||
for i in range(spm_source.GetPieceSize())
|
||||
}
|
||||
tgt_vocab_dict = {
|
||||
spm_target.IdToPiece(i): spm_target.GetScore(i)
|
||||
for i in range(spm_target.GetPieceSize())
|
||||
}
|
||||
|
||||
tokenizer_json['model']['vocab'] = [
|
||||
[
|
||||
k,
|
||||
0.0 if k in tokenizer.all_special_tokens else max(
|
||||
src_vocab_dict.get(k, -100), tgt_vocab_dict.get(k, -100))
|
||||
]
|
||||
for k in vocab
|
||||
]
|
||||
|
||||
return tokenizer_json
|
Loading…
Reference in New Issue