diff --git a/scripts/convert.py b/scripts/convert.py
index 9ef0a08..d38c401 100644
--- a/scripts/convert.py
+++ b/scripts/convert.py
@@ -4,29 +4,22 @@ import shutil
 from dataclasses import dataclass, field
 from typing import Optional
 
-from transformers import AutoTokenizer, HfArgumentParser
+from transformers import (
+    AutoConfig,
+    AutoTokenizer,
+    HfArgumentParser
+)
 from transformers.utils import cached_file
 
 from tqdm import tqdm
-from optimum.utils import DEFAULT_DUMMY_SHAPES
+import onnx
+from optimum.exporters.onnx import main_export
 from optimum.exporters.tasks import TasksManager
-from optimum.exporters.onnx.utils import (
-    get_decoder_models_for_export,
-    get_encoder_decoder_models_for_export
-)
-from optimum.exporters.onnx.convert import export_models
-from optimum.onnx.graph_transformations import merge_decoders
-from optimum.onnxruntime.utils import (
-    ONNX_WEIGHTS_NAME,
-    ONNX_ENCODER_NAME,
-    ONNX_DECODER_NAME,
-    ONNX_DECODER_WITH_PAST_NAME,
-    ONNX_DECODER_MERGED_NAME
-)
 from onnxruntime.quantization import (
     quantize_dynamic,
     QuantType
 )
+from onnxruntime.quantization.registry import IntegerOpsRegistry
 
 
 @dataclass
@@ -46,21 +39,15 @@ class ConversionArguments:
             "help": "Whether to quantize the model."
         }
     )
-    input_parent_dir: str = field(
-        default='./models/pytorch/',
-        metadata={
-            "help": "Path where the original model will be loaded from."
-        }
-    )
     output_parent_dir: str = field(
-        default='./models/onnx/',
+        default='./models/',
         metadata={
            "help": "Path where the converted model will be saved to."
        }
    )
     task: Optional[str] = field(
-        default='default',
+        default='auto',
         metadata={
            "help": (
                "The task to export the model for. If not specified, the task will be auto-inferred based on the model. Available tasks depend on the model, but are among:"
@@ -84,12 +71,6 @@ class ConversionArguments:
            "help": 'The device to use to do the export.'
        }
    )
-    from_hub: bool = field(
-        default=False,
-        metadata={
-            "help": "Whether to use local files, or from the HuggingFace Hub."
-        }
-    )
     merge_decoders: bool = field(
         default=True,
         metadata={
@@ -104,17 +85,7 @@ class ConversionArguments:
     )
 
 
-UNSIGNED_MODEL_TYPES = [
-    'whisper',
-    'vision-encoder-decoder',
-    'vit',
-    'clip',
-    'detr',
-    'squeezebert',
-]
-
-
-def quantize(models_name_or_path, model_type):
+def quantize(model_names_or_paths):
     """
     Quantize the weights of the model from float32 to int8 to allow very efficient inference on modern CPU
 
@@ -126,28 +97,40 @@ def quantize(models_name_or_path, model_type):
     Returns: The Path generated for the quantized
     """
-    # As per docs, signed weight type (QInt8) is faster on most CPUs
-    # However, for some model types (e.g., whisper), we have to use
-    # unsigned weight type (QUInt8). For more info:
-    # https://github.com/microsoft/onnxruntime/issues/3130#issuecomment-1105200621
+    for model in tqdm(model_names_or_paths, desc='Quantizing'):
+        directory_path = os.path.dirname(model)
+        file_name_without_extension = os.path.splitext(
+            os.path.basename(model))[0]
 
-    if model_type in UNSIGNED_MODEL_TYPES:
-        weight_type = QuantType.QUInt8
-    else:
-        # Default
-        weight_type = QuantType.QInt8
+        # NOTE:
+        # As of 2023/04/20, the current latest version of onnxruntime-web is 1.14.0, and does not support INT8 weights for Conv layers.
+        # For this reason, we choose model weight types to ensure compatibility with onnxruntime-web.
+        #
+        # As per docs, signed weight type (QInt8) is faster on most CPUs, so, we use that unless the model contains a Conv layer.
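+        # (In practice this means Conv-free encoder models such as BERT keep the faster QInt8 weights,
+        # while models whose exported graphs contain Conv nodes, e.g. ViT or Whisper, fall back to QUInt8.)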
+        # For more information, see:
+        #  - https://github.com/microsoft/onnxruntime/issues/3130#issuecomment-1105200621
+        #  - https://github.com/microsoft/onnxruntime/issues/2339
+
+        model_nodes = onnx.load_model(model).graph.node
+        op_types = set([node.op_type for node in model_nodes])
+        weight_type = QuantType.QUInt8 if 'Conv' in op_types else QuantType.QInt8
 
-    for model in tqdm(models_name_or_path, desc='Quantizing'):
-        # model_name = os.path.splitext(os.path.basename(model))[0]
         quantize_dynamic(
             model_input=model,
-            model_output=model,
-            per_channel=True,
-            reduce_range=True,  # should be the same as per_channel
+            model_output=os.path.join(
+                directory_path, f'{file_name_without_extension}_quantized.onnx'),
+            per_channel=False,
+            reduce_range=False,
             weight_type=weight_type,
             optimize_model=False,
-        )  # op_types_to_quantize=['MatMul', 'Relu', 'Add', 'Mul' ],
+
+            # TODO allow user to specify these
+            # op_types_to_quantize=['MatMul', 'Add', 'Conv'],
+            extra_options=dict(
+                EnableSubgraph=True
+            )
+        )
 
 
 def copy_if_exists(model_path, file_name, destination):
@@ -164,146 +147,61 @@ def main():
     )
     conv_args, = parser.parse_args_into_dataclasses()
 
-    input_model_path = os.path.join(
-        conv_args.input_parent_dir,
-        conv_args.model_id
-    )
-    if conv_args.from_hub:
-        model_path = conv_args.model_id
-    else:
-        model_path = input_model_path
+    model_id = conv_args.model_id
 
-    # Infer the task
-    task = conv_args.task
-    if task == "auto":
-        try:
-            task = TasksManager.infer_task_from_model(model_path)
-        except KeyError as e:
-            raise KeyError(
-                f"The task could not be automatically inferred. Please provide the argument --task with the task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
-            )
-
-    output_model_folder = os.path.join(
-        conv_args.output_parent_dir,
-        'quantized' if conv_args.quantize else 'unquantized',
-        conv_args.model_id,
-        task
-    )
-
-    # get the shapes to be used to generate dummy inputs
-    input_shapes = DEFAULT_DUMMY_SHAPES.copy()
-
-    model = TasksManager.get_model_from_task(
-        task, model_path,
-        framework='pt',
-    )
-
-    onnx_config_constructor = TasksManager.get_exporter_config_constructor(
-        model=model, exporter='onnx', task=task)
-    onnx_config = onnx_config_constructor(model.config)
-
-    # Ensure the requested opset is sufficient
-    if conv_args.opset is None:
-        conv_args.opset = onnx_config.DEFAULT_ONNX_OPSET
-    elif conv_args.opset < onnx_config.DEFAULT_ONNX_OPSET:
-        raise ValueError(
-            f"Opset {conv_args.opset} is not sufficient to export {model.config.model_type}. "
-            f"At least {onnx_config.DEFAULT_ONNX_OPSET} is required."
- ) + output_model_folder = os.path.join(conv_args.output_parent_dir, model_id) # Create output folder os.makedirs(output_model_folder, exist_ok=True) # Copy certain JSON files, which save_pretrained doesn't handle - copy_if_exists(model_path, 'tokenizer.json', output_model_folder) - copy_if_exists(model_path, 'preprocessor_config.json', output_model_folder) + # copy_if_exists(model_id, 'tokenizer.json', output_model_folder) - if model.can_generate(): - copy_if_exists(model_path, 'generation_config.json', - output_model_folder) + # copy_if_exists(model_id, 'preprocessor_config.json', output_model_folder) + # copy_if_exists(model_id, 'generation_config.json', output_model_folder) - # Saving the model config - model.config.save_pretrained(output_model_folder) + # # Saving the model config + config = AutoConfig.from_pretrained(model_id) + # config.save_pretrained(output_model_folder) try: # Save tokenizer - tokenizer = AutoTokenizer.from_pretrained(model_path) - tokenizer.save_pretrained(output_model_folder) + tokenizer = AutoTokenizer.from_pretrained(model_id) + # tokenizer.save_pretrained(output_model_folder) # Handle special cases - if model.config.model_type == 'marian': + if config.model_type == 'marian': import json from .extra.marian import generate_tokenizer_json - tokenizer_json = generate_tokenizer_json(model_path, tokenizer) + tokenizer_json = generate_tokenizer_json(model_id, tokenizer) with open(os.path.join(output_model_folder, 'tokenizer.json'), 'w', encoding='utf-8') as fp: json.dump(tokenizer_json, fp) + except KeyError: pass # No Tokenizer - # Specify output paths - OUTPUT_WEIGHTS_PATH = os.path.join(output_model_folder, ONNX_WEIGHTS_NAME) - OUTPUT_ENCODER_PATH = os.path.join(output_model_folder, ONNX_ENCODER_NAME) - OUTPUT_DECODER_PATH = os.path.join(output_model_folder, ONNX_DECODER_NAME) - OUTPUT_DECODER_WITH_PAST_PATH = os.path.join( - output_model_folder, ONNX_DECODER_WITH_PAST_NAME) - OUTPUT_DECODER_MERGED_PATH = os.path.join( - output_model_folder, ONNX_DECODER_MERGED_NAME) - # Step 1. convert huggingface model to onnx - if model.config.is_encoder_decoder and task.startswith("causal-lm"): - raise ValueError( - f"model.config.is_encoder_decoder is True and task is `{task}`, which are incompatible. If the task was auto-inferred, please fill a bug report" - f"at https://github.com/huggingface/optimum, if --task was explicitely passed, make sure you selected the right task for the model," - f" referring to `optimum.exporters.tasks.TaskManager`'s `_TASKS_TO_AUTOMODELS`." 
-        )
-
-    if (
-        model.config.is_encoder_decoder
-        and task.startswith(("seq2seq-lm", "speech2seq-lm", "vision2seq-lm", "default-with-past"))
-    ):
-        models_and_onnx_configs = get_encoder_decoder_models_for_export(
-            model, onnx_config)
-    elif task.startswith("causal-lm"):
-        models_and_onnx_configs = get_decoder_models_for_export(
-            model, onnx_config)
-    else:
-        models_and_onnx_configs = {"model": (model, onnx_config)}
-
-    onnx_model_paths = [
-        os.path.join(output_model_folder, f'{x}.onnx')
-        for x in models_and_onnx_configs
-    ]
-
-    # Check if at least one model doesn't exist, or user requests to overwrite
-    if any(
-        not os.path.exists(x) for x in onnx_model_paths
-    ) or conv_args.overwrite:
-        _, onnx_outputs = export_models(
-            models_and_onnx_configs=models_and_onnx_configs,
-            opset=conv_args.opset,
-            output_dir=output_model_folder,
-            input_shapes=input_shapes,
-            device=conv_args.device,
-            # dtype="fp16" if fp16 is True else None,  # TODO
-        )
+    main_export(
+        model_name_or_path=model_id,
+        output=output_model_folder,
+        task=conv_args.task,
+    )
 
     # Step 2. (optional, recommended) quantize the converted model for fast inference and to reduce model size.
     if conv_args.quantize:
-        quantize(onnx_model_paths, model.config.model_type)
+        quantize([
+            os.path.join(output_model_folder, x)
+            for x in os.listdir(output_model_folder)
+            if x.endswith('.onnx') and not x.endswith('_quantized.onnx')
+        ])
 
-    # Step 3. merge decoders.
-    if conv_args.merge_decoders and (
-        os.path.exists(OUTPUT_DECODER_PATH) and
-        os.path.exists(OUTPUT_DECODER_WITH_PAST_PATH)
-    ) and (not os.path.exists(OUTPUT_DECODER_MERGED_PATH) or conv_args.overwrite):
-        print('Merging decoders')
-        merge_decoders(
-            OUTPUT_DECODER_PATH,
-            OUTPUT_DECODER_WITH_PAST_PATH,
-            save_path=OUTPUT_DECODER_MERGED_PATH,
-            strict=False
-        )
+    # Step 3. Move .onnx files to the 'onnx' subfolder
+    os.makedirs(os.path.join(output_model_folder, 'onnx'), exist_ok=True)
+    for file in os.listdir(output_model_folder):
+        if file.endswith('.onnx'):
+            shutil.move(os.path.join(output_model_folder, file),
+                        os.path.join(output_model_folder, 'onnx', file))
 
 
 if __name__ == '__main__':
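For reference, the quantization path added above can be exercised on its own. Below is a minimal sketch, assuming an already-exported ONNX file at ./models/my-model/onnx/model.onnx (a hypothetical path); the per_channel, reduce_range and extra_options values mirror the quantize() call in the patch, while optimize_model is left at its default for compatibility with newer onnxruntime releases:

    import os
    import onnx
    from onnxruntime.quantization import QuantType, quantize_dynamic

    model_path = './models/my-model/onnx/model.onnx'  # hypothetical exported model

    # QInt8 weights are faster on most CPUs, but onnxruntime-web (1.14.0 at the time
    # of this patch) cannot run Conv with INT8 weights, so any graph containing a
    # Conv node falls back to unsigned QUInt8 weights.
    op_types = {node.op_type for node in onnx.load_model(model_path).graph.node}
    weight_type = QuantType.QUInt8 if 'Conv' in op_types else QuantType.QInt8

    base, _ = os.path.splitext(model_path)
    quantize_dynamic(
        model_input=model_path,
        model_output=f'{base}_quantized.onnx',  # written next to the original file
        per_channel=False,
        reduce_range=False,
        weight_type=weight_type,
        extra_options=dict(EnableSubgraph=True),
    )

Running the converted script end to end (assuming it is invoked as a module so the relative .extra.marian import resolves) would look roughly like:

    python -m scripts.convert --model_id bert-base-uncased --quantize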