Fix whisper quantization

Joshua Lochner 2023-03-07 23:53:50 +02:00
parent e5c2782db8
commit ef37343897
1 changed file with 22 additions and 3 deletions

@@ -83,7 +83,12 @@ class ConversionArguments:
     )
-def quantize(models_name_or_path):
+UNSIGNED_MODEL_TYPES = [
+    'whisper'
+]
+def quantize(models_name_or_path, model_type):
     """
     Quantize the weights of the model from float32 to int8 to allow very efficient inference on modern CPU
@@ -95,6 +100,17 @@ def quantize(models_name_or_path):
     Returns: The Path generated for the quantized
     """
+    # As per docs, signed weight type (QInt8) is faster on most CPUs
+    # However, for some model types (e.g., whisper), we have to use
+    # unsigned weight type (QUInt8). For more info:
+    # https://github.com/microsoft/onnxruntime/issues/3130#issuecomment-1105200621
+    if model_type in UNSIGNED_MODEL_TYPES:
+        weight_type = QuantType.QUInt8
+    else:
+        # Default
+        weight_type = QuantType.QInt8
     for model in tqdm(models_name_or_path, desc='Quantizing'):
         # model_name = os.path.splitext(os.path.basename(model))[0]
         quantize_dynamic(
@@ -103,7 +119,8 @@ def quantize(models_name_or_path):
             per_channel=True,
             reduce_range=True, # should be the same as per_channel
             activation_type=QuantType.QUInt8,
-            weight_type=QuantType.QInt8, # per docs, signed is faster on most CPUs
+            weight_type=weight_type,
             optimize_model=False,
         ) # op_types_to_quantize=['MatMul', 'Relu', 'Add', 'Mul' ],
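For context, a minimal standalone sketch of the weight-type selection this commit introduces, trimmed to quantize_dynamic arguments that are stable across onnxruntime versions (the script in this diff additionally passes activation_type and optimize_model, which depend on the installed onnxruntime). The quantize_one helper and the file paths are placeholders for illustration, not part of the converter script.

# Sketch only: placeholder helper name and paths, not from the converter script.
from onnxruntime.quantization import QuantType, quantize_dynamic

UNSIGNED_MODEL_TYPES = ['whisper']

def quantize_one(onnx_path, model_type):
    # Signed weights (QInt8) are the faster default on most CPUs, but
    # whisper-type models need unsigned weights (QUInt8); see the
    # onnxruntime issue linked in the diff above.
    if model_type in UNSIGNED_MODEL_TYPES:
        weight_type = QuantType.QUInt8
    else:
        weight_type = QuantType.QInt8

    quantize_dynamic(
        model_input=onnx_path,
        model_output=onnx_path.replace('.onnx', '_quantized.onnx'),
        per_channel=True,
        reduce_range=True,  # kept in sync with per_channel, as in the diff
        weight_type=weight_type,
    )

quantize_one('encoder_model.onnx', 'whisper')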
@@ -219,7 +236,9 @@ def main():
     # Step 2. (optional, recommended) quantize the converted model for fast inference and to reduce model size.
     if conv_args.quantize:
-        quantize(onnx_model_paths)
+        quantize(onnx_model_paths, model.config.model_type)
     # TODO copy all other .json files
 if __name__ == '__main__':
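The model_type string passed here comes from the loaded model's Hugging Face config, so whisper checkpoints are routed to the unsigned weight type automatically. A quick illustrative check of that value, using 'openai/whisper-tiny.en' purely as an example checkpoint:

# Illustrative only: 'openai/whisper-tiny.en' is just an example checkpoint.
from transformers import AutoConfig

config = AutoConfig.from_pretrained('openai/whisper-tiny.en')
print(config.model_type)  # 'whisper' -> quantize() selects QuantType.QUInt8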