From f628b841a872b82b09454bc54e2bf7e9abfdd955 Mon Sep 17 00:00:00 2001
From: Joshua Lochner
Date: Thu, 22 Jun 2023 00:43:43 +0200
Subject: [PATCH] Allow user to set `per_channel` and `reduce_range` quantization params (#156) (#157)

* Allow user to set `per_channel` and `reduce_range` quantization parameters (#156)

Also save quantization options

* Get operators of graph and subgraphs
---
 scripts/convert.py | 57 ++++++++++++++++++++++++++++++++++++++++------
 1 file changed, 50 insertions(+), 7 deletions(-)

diff --git a/scripts/convert.py b/scripts/convert.py
index 6f8938f..8daa43e 100644
--- a/scripts/convert.py
+++ b/scripts/convert.py
@@ -1,4 +1,5 @@
+import json
 import os
 import shutil
 from dataclasses import dataclass, field
@@ -77,8 +78,36 @@ class ConversionArguments:
         }
     )
 
+    per_channel: bool = field(
+        default=True,
+        metadata={
+            "help": "Whether to quantize weights per channel"
+        }
+    )
+    reduce_range: bool = field(
+        default=True,
+        metadata={
+            "help": "Whether to quantize weights with 7-bits. It may improve the accuracy for some models running on non-VNNI machine, especially for per-channel mode"
+        }
+    )
 
-def quantize(model_names_or_paths):
+
+def get_operators(model):
+    operators = set()
+
+    def traverse_graph(graph):
+        for node in graph.node:
+            operators.add(node.op_type)
+            for attr in node.attribute:
+                if attr.type == onnx.AttributeProto.GRAPH:
+                    subgraph = attr.g
+                    traverse_graph(subgraph)
+
+    traverse_graph(model.graph)
+    return operators
+
+
+def quantize(model_names_or_paths, conv_args: ConversionArguments):
     """
     Quantize the weights of the model from float32 to int8 to allow very efficient inference on modern CPU
 
@@ -90,6 +119,12 @@ def quantize(model_names_or_paths):
 
     Returns: The Path generated for the quantized
     """
+    quant_config = dict(
+        per_channel=conv_args.per_channel,
+        reduce_range=conv_args.reduce_range,
+        per_model_config={}
+    )
+
     for model in tqdm(model_names_or_paths, desc='Quantizing'):
         directory_path = os.path.dirname(model)
         file_name_without_extension = os.path.splitext(
@@ -104,8 +139,8 @@ def quantize(model_names_or_paths):
         #  - https://github.com/microsoft/onnxruntime/issues/3130#issuecomment-1105200621
         #  - https://github.com/microsoft/onnxruntime/issues/2339
 
-        model_nodes = onnx.load_model(model).graph.node
-        op_types = set([node.op_type for node in model_nodes])
+        loaded_model = onnx.load_model(model)
+        op_types = get_operators(loaded_model)
         weight_type = QuantType.QUInt8 if 'Conv' in op_types else QuantType.QInt8
 
         quantize_dynamic(
@@ -113,9 +148,8 @@ def quantize(model_names_or_paths):
             model_input=model,
             model_output=os.path.join(
                 directory_path, f'{file_name_without_extension}_quantized.onnx'),
-            # TODO allow user to specify these or choose based on hardware
-            per_channel=True,
-            reduce_range=True,
+            per_channel=conv_args.per_channel,
+            reduce_range=conv_args.reduce_range,
 
             weight_type=weight_type,
             optimize_model=False,
@@ -127,6 +161,15 @@ def quantize(model_names_or_paths):
             )
         )
 
+        quant_config['per_model_config'][file_name_without_extension] = dict(
+            op_types=list(op_types),
+            weight_type=str(weight_type),
+        )
+
+    # Save quantization config
+    with open(os.path.join(directory_path, 'quant_config.json'), 'w') as fp:
+        json.dump(quant_config, fp, indent=4)
+
 
 def copy_if_exists(model_path, file_name, destination):
     file = cached_file(model_path, file_name,
@@ -192,7 +235,7 @@ def main():
         os.path.join(output_model_folder, x)
        for x in os.listdir(output_model_folder)
         if x.endswith('.onnx') and not x.endswith('_quantized.onnx')
-    ])
+    ], conv_args)
 
     # Step 3. Move .onnx files to the 'onnx' subfolder
     os.makedirs(os.path.join(output_model_folder, 'onnx'), exist_ok=True)
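
Note (not part of the patch): with the changes above, quantize() also writes a quant_config.json recording the chosen options. A rough sketch of its contents for a single file named model.onnx is shown below; the file name, op_types list, and weight_type are illustrative and depend on the actual graph:

    {
        "per_channel": true,
        "reduce_range": true,
        "per_model_config": {
            "model": {
                "op_types": ["Gather", "MatMul", "Add", "Softmax"],
                "weight_type": "QuantType.QInt8"
            }
        }
    }

Both new options default to True, so existing conversions behave as before; assuming main() continues to parse ConversionArguments with HfArgumentParser (not shown in this diff), they can be disabled from the command line with flags along the lines of --per_channel False --reduce_range False.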