Improve conversion script

Determine method to quantize based on supported operations
Joshua Lochner 2023-04-20 18:56:25 +02:00
parent 9989d80a33
commit 1900a42154
1 changed file with 67 additions and 169 deletions
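The key change: instead of keying the quantization weight type off a hard-coded list of model types, the script now inspects the exported graph and picks unsigned weights whenever a Conv node is present, since onnxruntime-web 1.14.0 cannot run INT8 Conv layers. A minimal sketch of that check, assuming onnx and onnxruntime are installed; the quantize_model name and model_path argument are illustrative, not the script's actual interface:

import os
import onnx
from onnxruntime.quantization import quantize_dynamic, QuantType

def quantize_model(model_path):
    # Collect the operator types used anywhere in the exported graph.
    op_types = {node.op_type for node in onnx.load_model(model_path).graph.node}

    # Signed weights (QInt8) are faster on most CPUs, but onnxruntime-web 1.14.0
    # cannot run INT8 Conv nodes, so fall back to unsigned weights (QUInt8)
    # whenever the graph contains a Conv layer.
    weight_type = QuantType.QUInt8 if 'Conv' in op_types else QuantType.QInt8

    base, _ = os.path.splitext(model_path)
    quantize_dynamic(
        model_input=model_path,
        model_output=f'{base}_quantized.onnx',
        per_channel=False,
        reduce_range=False,
        weight_type=weight_type,
        extra_options=dict(EnableSubgraph=True),
    )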


@@ -4,29 +4,22 @@ import shutil
from dataclasses import dataclass, field
from typing import Optional
from transformers import AutoTokenizer, HfArgumentParser
from transformers import (
AutoConfig,
AutoTokenizer,
HfArgumentParser
)
from transformers.utils import cached_file
from tqdm import tqdm
from optimum.utils import DEFAULT_DUMMY_SHAPES
import onnx
from optimum.exporters.onnx import main_export
from optimum.exporters.tasks import TasksManager
from optimum.exporters.onnx.utils import (
get_decoder_models_for_export,
get_encoder_decoder_models_for_export
)
from optimum.exporters.onnx.convert import export_models
from optimum.onnx.graph_transformations import merge_decoders
from optimum.onnxruntime.utils import (
ONNX_WEIGHTS_NAME,
ONNX_ENCODER_NAME,
ONNX_DECODER_NAME,
ONNX_DECODER_WITH_PAST_NAME,
ONNX_DECODER_MERGED_NAME
)
from onnxruntime.quantization import (
quantize_dynamic,
QuantType
)
from onnxruntime.quantization.registry import IntegerOpsRegistry
@dataclass
@@ -46,21 +39,15 @@ class ConversionArguments:
"help": "Whether to quantize the model."
}
)
input_parent_dir: str = field(
default='./models/pytorch/',
metadata={
"help": "Path where the original model will be loaded from."
}
)
output_parent_dir: str = field(
default='./models/onnx/',
default='./models/',
metadata={
"help": "Path where the converted model will be saved to."
}
)
task: Optional[str] = field(
default='default',
default='auto',
metadata={
"help": (
"The task to export the model for. If not specified, the task will be auto-inferred based on the model. Available tasks depend on the model, but are among:"
@@ -84,12 +71,6 @@ class ConversionArguments:
"help": 'The device to use to do the export.'
}
)
from_hub: bool = field(
default=False,
metadata={
"help": "Whether to use local files, or from the HuggingFace Hub."
}
)
merge_decoders: bool = field(
default=True,
metadata={
@@ -104,17 +85,7 @@
)
UNSIGNED_MODEL_TYPES = [
'whisper',
'vision-encoder-decoder',
'vit',
'clip',
'detr',
'squeezebert',
]
def quantize(models_name_or_path, model_type):
def quantize(model_names_or_paths):
"""
Quantize the weights of the model from float32 to int8 to allow very efficient inference on modern CPU
@@ -126,28 +97,40 @@ def quantize(models_name_or_path, model_type):
Returns: The path generated for the quantized model
"""
# As per docs, signed weight type (QInt8) is faster on most CPUs
# However, for some model types (e.g., whisper), we have to use
# unsigned weight type (QUInt8). For more info:
# https://github.com/microsoft/onnxruntime/issues/3130#issuecomment-1105200621
for model in tqdm(model_names_or_paths, desc='Quantizing'):
directory_path = os.path.dirname(model)
file_name_without_extension = os.path.splitext(
os.path.basename(model))[0]
if model_type in UNSIGNED_MODEL_TYPES:
weight_type = QuantType.QUInt8
else:
# Default
weight_type = QuantType.QInt8
# NOTE:
# As of 2023/04/20, the latest version of onnxruntime-web is 1.14.0, which does not support INT8 weights for Conv layers.
# For this reason, we choose model weight types to ensure compatibility with onnxruntime-web.
#
# As per docs, signed weight type (QInt8) is faster on most CPUs, so, we use that unless the model contains a Conv layer.
# For more information, see:
# - https://github.com/microsoft/onnxruntime/issues/3130#issuecomment-1105200621
# - https://github.com/microsoft/onnxruntime/issues/2339
model_nodes = onnx.load_model(model).graph.node
op_types = set([node.op_type for node in model_nodes])
weight_type = QuantType.QUInt8 if 'Conv' in op_types else QuantType.QInt8
for model in tqdm(models_name_or_path, desc='Quantizing'):
# model_name = os.path.splitext(os.path.basename(model))[0]
quantize_dynamic(
model_input=model,
model_output=model,
per_channel=True,
reduce_range=True, # should be the same as per_channel
model_output=os.path.join(
directory_path, f'{file_name_without_extension}_quantized.onnx'),
per_channel=False,
reduce_range=False,
weight_type=weight_type,
optimize_model=False,
) # op_types_to_quantize=['MatMul', 'Relu', 'Add', 'Mul' ],
# TODO allow user to specify these
# op_types_to_quantize=['MatMul', 'Add', 'Conv'],
extra_options=dict(
EnableSubgraph=True
)
)
def copy_if_exists(model_path, file_name, destination):
@@ -164,146 +147,61 @@ def main():
)
conv_args, = parser.parse_args_into_dataclasses()
input_model_path = os.path.join(
conv_args.input_parent_dir,
conv_args.model_id
)
if conv_args.from_hub:
model_path = conv_args.model_id
else:
model_path = input_model_path
model_id = conv_args.model_id
# Infer the task
task = conv_args.task
if task == "auto":
try:
task = TasksManager.infer_task_from_model(model_path)
except KeyError as e:
raise KeyError(
f"The task could not be automatically inferred. Please provide the argument --task with the task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
)
output_model_folder = os.path.join(
conv_args.output_parent_dir,
'quantized' if conv_args.quantize else 'unquantized',
conv_args.model_id,
task
)
# get the shapes to be used to generate dummy inputs
input_shapes = DEFAULT_DUMMY_SHAPES.copy()
model = TasksManager.get_model_from_task(
task, model_path,
framework='pt',
)
onnx_config_constructor = TasksManager.get_exporter_config_constructor(
model=model, exporter='onnx', task=task)
onnx_config = onnx_config_constructor(model.config)
# Ensure the requested opset is sufficient
if conv_args.opset is None:
conv_args.opset = onnx_config.DEFAULT_ONNX_OPSET
elif conv_args.opset < onnx_config.DEFAULT_ONNX_OPSET:
raise ValueError(
f"Opset {conv_args.opset} is not sufficient to export {model.config.model_type}. "
f"At least {onnx_config.DEFAULT_ONNX_OPSET} is required."
)
output_model_folder = os.path.join(conv_args.output_parent_dir, model_id)
# Create output folder
os.makedirs(output_model_folder, exist_ok=True)
# Copy certain JSON files, which save_pretrained doesn't handle
copy_if_exists(model_path, 'tokenizer.json', output_model_folder)
copy_if_exists(model_path, 'preprocessor_config.json', output_model_folder)
# copy_if_exists(model_id, 'tokenizer.json', output_model_folder)
if model.can_generate():
copy_if_exists(model_path, 'generation_config.json',
output_model_folder)
# copy_if_exists(model_id, 'preprocessor_config.json', output_model_folder)
# copy_if_exists(model_id, 'generation_config.json', output_model_folder)
# Saving the model config
model.config.save_pretrained(output_model_folder)
# # Saving the model config
config = AutoConfig.from_pretrained(model_id)
# config.save_pretrained(output_model_folder)
try:
# Save tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.save_pretrained(output_model_folder)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# tokenizer.save_pretrained(output_model_folder)
# Handle special cases
if model.config.model_type == 'marian':
if config.model_type == 'marian':
import json
from .extra.marian import generate_tokenizer_json
tokenizer_json = generate_tokenizer_json(model_path, tokenizer)
tokenizer_json = generate_tokenizer_json(model_id, tokenizer)
with open(os.path.join(output_model_folder, 'tokenizer.json'), 'w', encoding='utf-8') as fp:
json.dump(tokenizer_json, fp)
except KeyError:
pass # No Tokenizer
# Specify output paths
OUTPUT_WEIGHTS_PATH = os.path.join(output_model_folder, ONNX_WEIGHTS_NAME)
OUTPUT_ENCODER_PATH = os.path.join(output_model_folder, ONNX_ENCODER_NAME)
OUTPUT_DECODER_PATH = os.path.join(output_model_folder, ONNX_DECODER_NAME)
OUTPUT_DECODER_WITH_PAST_PATH = os.path.join(
output_model_folder, ONNX_DECODER_WITH_PAST_NAME)
OUTPUT_DECODER_MERGED_PATH = os.path.join(
output_model_folder, ONNX_DECODER_MERGED_NAME)
# Step 1. convert huggingface model to onnx
if model.config.is_encoder_decoder and task.startswith("causal-lm"):
raise ValueError(
f"model.config.is_encoder_decoder is True and task is `{task}`, which are incompatible. If the task was auto-inferred, please fill a bug report"
f"at https://github.com/huggingface/optimum, if --task was explicitely passed, make sure you selected the right task for the model,"
f" referring to `optimum.exporters.tasks.TaskManager`'s `_TASKS_TO_AUTOMODELS`."
)
if (
model.config.is_encoder_decoder
and task.startswith(("seq2seq-lm", "speech2seq-lm", "vision2seq-lm", "default-with-past"))
):
models_and_onnx_configs = get_encoder_decoder_models_for_export(
model, onnx_config)
elif task.startswith("causal-lm"):
models_and_onnx_configs = get_decoder_models_for_export(
model, onnx_config)
else:
models_and_onnx_configs = {"model": (model, onnx_config)}
onnx_model_paths = [
os.path.join(output_model_folder, f'{x}.onnx')
for x in models_and_onnx_configs
]
# Check if at least one model doesn't exist, or user requests to overwrite
if any(
not os.path.exists(x) for x in onnx_model_paths
) or conv_args.overwrite:
_, onnx_outputs = export_models(
models_and_onnx_configs=models_and_onnx_configs,
opset=conv_args.opset,
output_dir=output_model_folder,
input_shapes=input_shapes,
device=conv_args.device,
# dtype="fp16" if fp16 is True else None, # TODO
)
main_export(
model_name_or_path=model_id,
output=output_model_folder,
task=conv_args.task,
)
# Step 2. (optional, recommended) quantize the converted model for fast inference and to reduce model size.
if conv_args.quantize:
quantize(onnx_model_paths, model.config.model_type)
quantize([
os.path.join(output_model_folder, x)
for x in os.listdir(output_model_folder)
if x.endswith('.onnx') and not x.endswith('_quantized.onnx')
])
# Step 3. merge decoders.
if conv_args.merge_decoders and (
os.path.exists(OUTPUT_DECODER_PATH) and
os.path.exists(OUTPUT_DECODER_WITH_PAST_PATH)
) and (not os.path.exists(OUTPUT_DECODER_MERGED_PATH) or conv_args.overwrite):
print('Merging decoders')
merge_decoders(
OUTPUT_DECODER_PATH,
OUTPUT_DECODER_WITH_PAST_PATH,
save_path=OUTPUT_DECODER_MERGED_PATH,
strict=False
)
# Step 3. Move .onnx files to the 'onnx' subfolder
os.makedirs(os.path.join(output_model_folder, 'onnx'), exist_ok=True)
for file in os.listdir(output_model_folder):
if file.endswith('.onnx'):
shutil.move(os.path.join(output_model_folder, file),
os.path.join(output_model_folder, 'onnx', file))
if __name__ == '__main__':
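Taken together, the new flow is: export with optimum's main_export, optionally quantize each exported .onnx file, then move the ONNX files into an onnx subfolder. A rough usage sketch under the same assumptions; the model id is only an example and normally comes from the script's model_id argument:

import os
import shutil
from optimum.exporters.onnx import main_export

model_id = 'bert-base-uncased'  # example only; the script takes this from its arguments
output_model_folder = os.path.join('./models/', model_id)

# Step 1. Convert the Hugging Face model to ONNX (the task is auto-inferred).
main_export(
    model_name_or_path=model_id,
    output=output_model_folder,
    task='auto',
)

# Step 2. (optional) Dynamically quantize every exported .onnx file, choosing the
# weight type per the Conv check sketched above.

# Step 3. Move the .onnx files into an 'onnx' subfolder.
onnx_folder = os.path.join(output_model_folder, 'onnx')
os.makedirs(onnx_folder, exist_ok=True)
for file in os.listdir(output_model_folder):
    if file.endswith('.onnx'):
        shutil.move(os.path.join(output_model_folder, file),
                    os.path.join(onnx_folder, file))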