Adding optimizations block from ONNXRuntime. (#4431)

* Adding optimizations block from ONNXRuntime.

* Turn off external data format by default for PyTorch export.

* Correct the way use_external_format is passed through the cmdline args.
Author: Funtowicz Morgan
Date: 2020-05-18 18:32:33 +00:00
Committed by: GitHub
parent 24538df919
commit ca4a3f4da9
2 changed files with 44 additions and 6 deletions


@@ -125,9 +125,39 @@
     "- **Deadcode Elimination**: Remove nodes never accessed in the graph\n",
     "- **Operator Fusing**: Merge multiple instructions into one (Linear -> ReLU can be fused to be LinearReLU)\n",
     "\n",
-    "All of this is done on **onnxruntime** by settings specific `SessionOptions`:"
+    "ONNX Runtime automatically applies most optimizations by setting specific `SessionOptions`.\n",
+    "\n",
+    "Note: Some of the latest optimizations that are not yet integrated into ONNX Runtime are available in the [optimization script](https://github.com/microsoft/onnxruntime/tree/master/onnxruntime/python/tools/transformers) that tunes models for the best performance."
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "outputs": [],
+   "source": [
+    "# # An optional step unless\n",
+    "# # you want to get a model with mixed precision for perf acceleration on newer GPUs\n",
+    "# # or you are working with TensorFlow (tf.keras) models or PyTorch models other than BERT\n",
+    "\n",
+    "# !pip install onnxruntime-tools\n",
+    "# from onnxruntime_tools import optimizer\n",
+    "\n",
+    "# # Mixed precision conversion for the bert-base-cased model converted from PyTorch\n",
+    "# optimized_model = optimizer.optimize_model(\"bert-base-cased.onnx\", model_type='bert', num_heads=12, hidden_size=768)\n",
+    "# optimized_model.convert_model_float32_to_float16()\n",
+    "# optimized_model.save_model_to_file(\"bert-base-cased.onnx\")\n",
+    "\n",
+    "# # Optimizations for the bert-base-cased model converted from TensorFlow (tf.keras)\n",
+    "# optimized_model = optimizer.optimize_model(\"bert-base-cased.onnx\", model_type='bert_keras', num_heads=12, hidden_size=768)\n",
+    "# optimized_model.save_model_to_file(\"bert-base-cased.onnx\")\n"
+   ],
+   "metadata": {
+    "collapsed": false,
+    "pycharm": {
+     "name": "#%%\n"
+    }
+   }
+  },
   {
    "cell_type": "code",
    "execution_count": 2,


@@ -22,6 +22,7 @@ class OnnxConverterArgumentParser(ArgumentParser):
         self.add_argument("--framework", type=str, choices=["pt", "tf"], help="Framework for loading the model")
         self.add_argument("--opset", type=int, default=11, help="ONNX opset to use")
         self.add_argument("--check-loading", action="store_true", help="Check ONNX is able to load the model")
+        self.add_argument("--use-external-format", action="store_true", help="Allow exporting models >= 2 GB")
         self.add_argument("output")
@@ -105,7 +106,7 @@ def load_graph_from_args(framework: str, model: str, tokenizer: Optional[str] =
     return pipeline("feature-extraction", model=model, framework=framework)


-def convert_pytorch(nlp: Pipeline, opset: int, output: str):
+def convert_pytorch(nlp: Pipeline, opset: int, output: str, use_external_format: bool):
     if not is_torch_available():
         raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.")
@@ -126,7 +127,7 @@ def convert_pytorch(nlp: Pipeline, opset: int, output: str):
             output_names=output_names,
             dynamic_axes=dynamic_axes,
             do_constant_folding=True,
-            use_external_data_format=True,
+            use_external_data_format=use_external_format,
             enable_onnx_checker=True,
             opset_version=opset,
         )
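
Context for this change: ONNX serializes the whole graph into a single protobuf, and protobuf files are capped at 2 GB, so the weights of very large models must be spilled into side files via the external data format. Turning it off by default keeps the common case (models well under 2 GB) as a single .onnx file. The helper below is hypothetical, not part of the commit; it only illustrates when the flag becomes necessary.

# Hypothetical helper: estimate whether a model's parameters would push a
# single-file ONNX export past the 2 GB protobuf cap.
import torch

def needs_external_format(model: torch.nn.Module, limit: int = 2 * 1024 ** 3) -> bool:
    # Serialized weight size is roughly the in-memory parameter size.
    param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    return param_bytes >= limit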
@@ -160,7 +161,14 @@ def convert_tensorflow(nlp: Pipeline, opset: int, output: str):
         )


-def convert(framework: str, model: str, output: str, opset: int, tokenizer: Optional[str] = None):
+def convert(
+    framework: str,
+    model: str,
+    output: str,
+    opset: int,
+    tokenizer: Optional[str] = None,
+    use_external_format: bool = False,
+):
     print("ONNX opset version set to: {}".format(opset))

     # Load the pipeline
@@ -175,7 +183,7 @@ def convert(framework: str, model: str, output: str, opset: int, tokenizer: Optional[str] = None):
     # Export the graph
     if framework == "pt":
-        convert_pytorch(nlp, opset, output)
+        convert_pytorch(nlp, opset, output, use_external_format)
     else:
         convert_tensorflow(nlp, opset, output)
@@ -202,7 +210,7 @@ if __name__ == "__main__":
     try:
         # Convert
-        convert(args.framework, args.model, args.output, args.opset, args.tokenizer)
+        convert(args.framework, args.model, args.output, args.opset, args.tokenizer, args.use_external_format)

         # And verify
         if args.check_loading:
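
Taken together, the new flag is exposed both on the command line (--use-external-format, off by default) and on the Python API (use_external_format=False). Below is a sketch of calling the updated converter, assuming it is importable as transformers.convert_graph_to_onnx; the model and output paths are placeholders.

# Sketch of the updated converter API; paths are placeholders.
from transformers.convert_graph_to_onnx import convert

# Default: single-file export (use_external_data_format=False under the hood).
convert(framework="pt", model="bert-base-cased", output="onnx/bert-base-cased.onnx", opset=11)

# Opt in only when the exported weights would exceed the 2 GB protobuf limit.
convert(
    framework="pt",
    model="bert-base-cased",
    output="onnx/bert-base-cased.onnx",
    opset=11,
    use_external_format=True,
)

The equivalent CLI call would pass --use-external-format before the positional output argument parsed by OnnxConverterArgumentParser.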