import json
import os
import shutil
from dataclasses import dataclass, field
from typing import Optional, Set

from tqdm import tqdm

from transformers import (
    AutoConfig,
    AutoTokenizer,
    HfArgumentParser
)

import onnx
from optimum.exporters.onnx import main_export, export_models
from optimum.exporters.tasks import TasksManager
from onnxruntime.quantization import (
    quantize_dynamic,
    QuantType
)
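
# Defaults used for dynamic quantization. MODEL_SPECIFIC_QUANTIZE_PARAMS below
# overrides these per model type (looked up in main() via config.model_type).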
DEFAULT_QUANTIZE_PARAMS = {
    'per_channel': True,
    'reduce_range': True,
}

MODEL_SPECIFIC_QUANTIZE_PARAMS = {
    'whisper': {
        'per_channel': False,
        'reduce_range': False,
    }
}


@dataclass
class ConversionArguments:
    """
    Arguments used for converting HuggingFace models to ONNX.
    """
    model_id: str = field(
        metadata={
            "help": "Model identifier"
        }
    )
    quantize: bool = field(
        default=False,
        metadata={
            "help": "Whether to quantize the model."
        }
    )
    output_parent_dir: str = field(
        default='./models/',
        metadata={
            "help": "Path where the converted model will be saved to."
        }
    )
    task: Optional[str] = field(
        default='auto',
        metadata={
            "help": (
                "The task to export the model for. If not specified, the task will be auto-inferred based on the model. Available tasks depend on the model, but are among: "
                f"{str(list(TasksManager._TASKS_TO_AUTOMODELS.keys()))}. For decoder models, use `xxx-with-past` to export the model using past key values in the decoder."
            )
        }
    )
    opset: int = field(
        default=None,
        metadata={
            "help": "If specified, ONNX opset version to export the model with. Otherwise, the default opset will be used."
        }
    )
    device: str = field(
        default='cpu',
        metadata={
            "help": "The device to use to do the export."
        }
    )
    skip_validation: bool = field(
        default=False,
        metadata={
            "help": "Whether to skip validation of the converted model."
        }
    )
    per_channel: bool = field(
        default=None,
        metadata={
            "help": "Whether to quantize weights per channel."
        }
    )
    reduce_range: bool = field(
        default=None,
        metadata={
            "help": "Whether to quantize weights with 7 bits. This may improve accuracy for some models running on non-VNNI machines, especially in per-channel mode."
        }
    )
    output_attentions: bool = field(
        default=False,
        metadata={
            "help": "Whether to output attentions from the model. NOTE: This is only supported for whisper models right now."
        }
    )
    split_modalities: bool = field(
        default=False,
        metadata={
            "help": "Whether to split multimodal models. NOTE: This is only supported for CLIP models right now."
        }
    )
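
# HfArgumentParser (see main() below) turns each dataclass field above into a
# command-line flag, e.g. --model_id, --quantize, --per_channel.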


def get_operators(model: onnx.ModelProto) -> Set[str]:
    """
    Collect the set of operator types used anywhere in an ONNX model,
    including nested subgraphs (e.g. the bodies of If/Loop nodes).
    """
    operators = set()

    def traverse_graph(graph):
        for node in graph.node:
            operators.add(node.op_type)
            # Recurse into any subgraph attributes
            for attr in node.attribute:
                if attr.type == onnx.AttributeProto.GRAPH:
                    traverse_graph(attr.g)

    traverse_graph(model.graph)
    return operators
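
# For example, get_operators(onnx.load_model('model.onnx')) might return
# {'MatMul', 'Add', 'Conv'}; quantize() below uses this to detect Conv layers.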


def quantize(model_names_or_paths, **quantize_kwargs):
    """
    Quantize the weights of the model from float32 to int8 to allow very efficient inference on modern CPUs.

    Uses unsigned ints for activation values and signed ints for weights, since,
    per https://onnxruntime.ai/docs/performance/quantization.html#data-type-selection,
    this is faster on most CPU architectures.

    Args:
        model_names_or_paths: Paths to the exported ONNX models to quantize.
        **quantize_kwargs: Extra keyword arguments passed on to `quantize_dynamic`
            (e.g. `per_channel`, `reduce_range`).
    """
    quantize_config = dict(
        **quantize_kwargs,
        per_model_config={}
    )

    for model in tqdm(model_names_or_paths, desc='Quantizing'):
        directory_path = os.path.dirname(model)
        file_name_without_extension = os.path.splitext(
            os.path.basename(model))[0]

        # NOTE:
        # As of 2023/04/20, the latest version of onnxruntime-web is 1.14.0, which does not support INT8 weights for Conv layers.
        # For this reason, we choose model weight types that ensure compatibility with onnxruntime-web.
        #
        # As per the docs, the signed weight type (QInt8) is faster on most CPUs, so we use that unless the model contains a Conv layer.
        # For more information, see:
        # - https://github.com/microsoft/onnxruntime/issues/3130#issuecomment-1105200621
        # - https://github.com/microsoft/onnxruntime/issues/2339
        loaded_model = onnx.load_model(model)
        op_types = get_operators(loaded_model)
        weight_type = QuantType.QUInt8 if 'Conv' in op_types else QuantType.QInt8

        quantize_dynamic(
            model_input=model,
            model_output=os.path.join(
                directory_path, f'{file_name_without_extension}_quantized.onnx'),
            weight_type=weight_type,
            optimize_model=False,

            # TODO allow user to specify these
            # op_types_to_quantize=['MatMul', 'Add', 'Conv'],
            extra_options=dict(
                EnableSubgraph=True
            ),
            **quantize_kwargs
        )

        # Record per-model quantization details for quantize_config.json
        quantize_config['per_model_config'][file_name_without_extension] = dict(
            op_types=list(op_types),
            weight_type=str(weight_type),
        )

    # Save the quantization config (all models share the same output directory)
    with open(os.path.join(directory_path, 'quantize_config.json'), 'w') as fp:
        json.dump(quantize_config, fp, indent=4)
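
# A minimal sketch of calling quantize() directly (path is illustrative):
#   quantize(['./models/my-model/model.onnx'], per_channel=True, reduce_range=True)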


def main():
    parser = HfArgumentParser(
        (ConversionArguments, )
    )
    conv_args, = parser.parse_args_into_dataclasses()

    model_id = conv_args.model_id

    output_model_folder = os.path.join(conv_args.output_parent_dir, model_id)

    # Create output folder
    os.makedirs(output_model_folder, exist_ok=True)

    # Load the model config
    config = AutoConfig.from_pretrained(model_id)

    tokenizer = None
    try:
        # Load the tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_id)
    except KeyError:
        pass  # No tokenizer for this model

    export_kwargs = dict(
        model_name_or_path=model_id,
        output=output_model_folder,
        task=conv_args.task,
        opset=conv_args.opset,
        device=conv_args.device,
        do_validation=not conv_args.skip_validation,
    )
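
    # These kwargs are forwarded to optimum's main_export (the programmatic
    # equivalent of `optimum-cli export onnx`), unless the CLIP split path
    # below takes over.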

    # Handle special cases
    if config.model_type == 'marian':
        from .extra.marian import generate_tokenizer_json
        tokenizer_json = generate_tokenizer_json(model_id, tokenizer)

        with open(os.path.join(output_model_folder, 'tokenizer.json'), 'w', encoding='utf-8') as fp:
            json.dump(tokenizer_json, fp)

    elif config.model_type == 'whisper':
        if conv_args.output_attentions:
            from .extra.whisper import get_main_export_kwargs
            export_kwargs.update(
                **get_main_export_kwargs(config, "automatic-speech-recognition")
            )
    else:
        pass  # TODO

    # Step 1. Convert the HuggingFace model to ONNX
    if config.model_type == 'clip' and conv_args.split_modalities:
        # Handle special case for exporting the text and vision models separately
        from .extra.clip import CLIPTextModelWithProjectionOnnxConfig, CLIPVisionModelWithProjectionOnnxConfig
        from transformers.models.clip import CLIPTextModelWithProjection, CLIPVisionModelWithProjection

        text_model = CLIPTextModelWithProjection.from_pretrained(model_id)
        vision_model = CLIPVisionModelWithProjection.from_pretrained(model_id)

        export_models(
            models_and_onnx_configs={
                "text_model": (text_model, CLIPTextModelWithProjectionOnnxConfig(text_model.config)),
                "vision_model": (vision_model, CLIPVisionModelWithProjectionOnnxConfig(vision_model.config)),
            },
            output_dir=output_model_folder,
            opset=conv_args.opset,
            device=conv_args.device,
        )
    else:
        main_export(**export_kwargs)

    # Step 2. (optional, recommended) Quantize the converted model for fast inference and to reduce model size.
    if conv_args.quantize:
        # Update the quantize config with model-specific defaults
        quantize_config = MODEL_SPECIFIC_QUANTIZE_PARAMS.get(
            config.model_type, DEFAULT_QUANTIZE_PARAMS)

        quantize([
            os.path.join(output_model_folder, x)
            for x in os.listdir(output_model_folder)
            if x.endswith('.onnx') and not x.endswith('_quantized.onnx')
        ], **quantize_config)

    # Step 3. Move .onnx files to the 'onnx' subfolder
    os.makedirs(os.path.join(output_model_folder, 'onnx'), exist_ok=True)
    for file in os.listdir(output_model_folder):
        if file.endswith(('.onnx', '.onnx_data')):
            shutil.move(os.path.join(output_model_folder, file),
                        os.path.join(output_model_folder, 'onnx', file))

    # Step 4. Update the generation config if necessary
    if config.model_type == 'whisper':
        from transformers import GenerationConfig
        from .extra.whisper import get_alignment_heads

        generation_config = GenerationConfig.from_pretrained(model_id)
        generation_config.alignment_heads = get_alignment_heads(config)
        generation_config.save_pretrained(output_model_folder)
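

# Example invocation (illustrative; the script uses relative imports, so run
# it as a module from the repository root):
#   python -m scripts.convert --model_id bert-base-uncased --quantize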
if __name__ == '__main__':
    main()