Allow user to set `per_channel` and `reduce_range` quantization parameters (#156)

* Also save the quantization options used
* Collect operators from the graph and its subgraphs
parent d90f58110a
commit f628b841a8
@@ -1,4 +1,5 @@
+import json
 import os
 import shutil
 from dataclasses import dataclass, field
@@ -77,8 +78,36 @@ class ConversionArguments:
         }
     )

+    per_channel: bool = field(
+        default=True,
+        metadata={
+            "help": "Whether to quantize weights per channel"
+        }
+    )
+    reduce_range: bool = field(
+        default=True,
+        metadata={
+            "help": "Whether to quantize weights with 7-bits. It may improve the accuracy for some models running on non-VNNI machine, especially for per-channel mode"
+        }
+    )
+
+
+def get_operators(model):
+    operators = set()
+
+    def traverse_graph(graph):
+        for node in graph.node:
+            operators.add(node.op_type)
+            for attr in node.attribute:
+                if attr.type == onnx.AttributeProto.GRAPH:
+                    subgraph = attr.g
+                    traverse_graph(subgraph)
+
+    traverse_graph(model.graph)
+    return operators
+
+
-def quantize(model_names_or_paths):
+def quantize(model_names_or_paths, conv_args: ConversionArguments):
     """
     Quantize the weights of the model from float32 to int8 to allow very efficient inference on modern CPU
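Note: the new `get_operators` helper recurses into graph-valued node attributes, so operators hidden inside control-flow subgraphs (e.g. the branches of an `If` node) are collected too, which a flat scan of `model.graph.node` would miss. A minimal sketch of that behaviour, assuming only that the `onnx` package is installed; the toy model and names below are illustrative:

import onnx
from onnx import TensorProto, helper

def get_operators(model):  # same helper as in the patch above
    operators = set()

    def traverse_graph(graph):
        for node in graph.node:
            operators.add(node.op_type)
            for attr in node.attribute:
                if attr.type == onnx.AttributeProto.GRAPH:
                    traverse_graph(attr.g)

    traverse_graph(model.graph)
    return operators

# Toy model: the only top-level node is an If; each branch holds a
# Constant node that only a recursive traversal will see.
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1])
branch = helper.make_graph(
    [helper.make_node('Constant', [], ['y'],
                      value=helper.make_tensor('c', TensorProto.FLOAT, [1], [0.0]))],
    'branch', [], [y])
cond = helper.make_tensor_value_info('cond', TensorProto.BOOL, [])
out = helper.make_tensor_value_info('out', TensorProto.FLOAT, [1])
model = helper.make_model(helper.make_graph(
    [helper.make_node('If', ['cond'], ['out'],
                      then_branch=branch, else_branch=branch)],
    'main', [cond], [out]))

print(get_operators(model))  # expected: {'If', 'Constant'}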
@@ -90,6 +119,12 @@ def quantize(model_names_or_paths):
     Returns: The Path generated for the quantized
     """

+    quant_config = dict(
+        per_channel=conv_args.per_channel,
+        reduce_range=conv_args.reduce_range,
+        per_model_config={}
+    )
+
     for model in tqdm(model_names_or_paths, desc='Quantizing'):
         directory_path = os.path.dirname(model)
         file_name_without_extension = os.path.splitext(
@@ -104,8 +139,8 @@ def quantize(model_names_or_paths):
         # - https://github.com/microsoft/onnxruntime/issues/3130#issuecomment-1105200621
         # - https://github.com/microsoft/onnxruntime/issues/2339

-        model_nodes = onnx.load_model(model).graph.node
-        op_types = set([node.op_type for node in model_nodes])
+        loaded_model = onnx.load_model(model)
+        op_types = get_operators(loaded_model)
         weight_type = QuantType.QUInt8 if 'Conv' in op_types else QuantType.QInt8

         quantize_dynamic(
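The `weight_type` line keeps the earlier heuristic: if any (sub)graph contains a `Conv`, unsigned weights are used, since onnxruntime's integer Conv kernels only accept uint8 (see the two issues linked in the comments above). A small standalone illustration, assuming `onnxruntime` is installed:

from onnxruntime.quantization import QuantType

for op_types in ({'MatMul', 'Add'}, {'Conv', 'MatMul'}):
    weight_type = QuantType.QUInt8 if 'Conv' in op_types else QuantType.QInt8
    print(sorted(op_types), '->', weight_type)
# expected:
# ['Add', 'MatMul'] -> QuantType.QInt8
# ['Conv', 'MatMul'] -> QuantType.QUInt8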
@@ -113,9 +148,8 @@ def quantize(model_names_or_paths):
             model_output=os.path.join(
                 directory_path, f'{file_name_without_extension}_quantized.onnx'),

-            # TODO allow user to specify these or choose based on hardware
-            per_channel=True,
-            reduce_range=True,
+            per_channel=conv_args.per_channel,
+            reduce_range=conv_args.reduce_range,

             weight_type=weight_type,
             optimize_model=False,
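With the TODO resolved, the two knobs flow straight from the user's arguments into `quantize_dynamic`. For reference, a minimal standalone call using the same keyword arguments, which are part of `onnxruntime.quantization`'s public API; the paths are placeholders:

from onnxruntime.quantization import QuantType, quantize_dynamic

quantize_dynamic(
    model_input='model.onnx',             # placeholder input path
    model_output='model_quantized.onnx',  # placeholder output path
    per_channel=False,   # quantize whole tensors instead of per channel
    reduce_range=False,  # full 8-bit range; 7-bit can help on non-VNNI CPUs
    weight_type=QuantType.QInt8,
)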
@@ -127,6 +161,15 @@ def quantize(model_names_or_paths):
             )
         )

+        quant_config['per_model_config'][file_name_without_extension] = dict(
+            op_types=list(op_types),
+            weight_type=str(weight_type),
+        )
+
+    # Save quantization config
+    with open(os.path.join(directory_path, 'quant_config.json'), 'w') as fp:
+        json.dump(quant_config, fp, indent=4)
+

 def copy_if_exists(model_path, file_name, destination):
     file = cached_file(model_path, file_name,
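The saved `quant_config.json` records both the global flags and the per-model decisions, so downstream consumers can tell how each file was quantized. Reading it back is straightforward; the commented structure below is illustrative (the model name and operator list are hypothetical):

import json

with open('quant_config.json') as fp:  # placeholder path
    quant_config = json.load(fp)

# Expected shape (values illustrative):
# {
#     "per_channel": true,
#     "reduce_range": true,
#     "per_model_config": {
#         "model": {
#             "op_types": ["Add", "MatMul", "Softmax"],
#             "weight_type": "QuantType.QInt8"
#         }
#     }
# }
print(quant_config['per_model_config'])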
@@ -192,7 +235,7 @@ def main():
             os.path.join(output_model_folder, x)
             for x in os.listdir(output_model_folder)
             if x.endswith('.onnx') and not x.endswith('_quantized.onnx')
-        ])
+        ], conv_args)

     # Step 3. Move .onnx files to the 'onnx' subfolder
     os.makedirs(os.path.join(output_model_folder, 'onnx'), exist_ok=True)
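Since `main()` now threads `conv_args` through, the same options can also be set when calling `quantize` directly. A hypothetical driver, assuming the script is importable as `convert` and that `model_id` is the only other required `ConversionArguments` field (neither assumption is shown in this diff):

# Hypothetical programmatic use; `convert` and `model_id` are assumptions.
from convert import ConversionArguments, quantize

conv_args = ConversionArguments(
    model_id='bert-base-uncased',  # assumed required field; placeholder value
    per_channel=False,
    reduce_range=False,
)
quantize(['./models/bert-base-uncased/model.onnx'], conv_args)  # placeholder path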