[LegalizeOnnx] support to legalize the output of onnx-mlir flow to pure standard/affine IR

This commit is contained in:
Hanchen Ye 2020-12-23 16:45:38 -06:00
parent 9067cb6383
commit 21347b2665
6 changed files with 440 additions and 5 deletions

View File

@ -15,9 +15,13 @@ class Pass;
namespace mlir {
namespace scalehls {
// HLSKernel and HLSCpp conversion passes.
std::unique_ptr<Pass> createConvertToHLSCppPass();
std::unique_ptr<Pass> createHLSKernelToAffinePass();
/// Onnx kernel legalization pass.
std::unique_ptr<Pass> createLegalizeOnnxPass();
void registerConversionPasses();
#define GEN_PASS_CLASSES

View File

@ -7,11 +7,15 @@
include "mlir/Pass/PassBase.td"
//===----------------------------------------------------------------------===//
// HLSKernel and HLSCpp Conversion passes
//===----------------------------------------------------------------------===//
def ConvertToHLSCpp : Pass<"convert-to-hlscpp", "FuncOp"> {
let summary = "Convert to emittable MLIR code";
let description = [{
This convert-to-hlscpp converts MLIR code in Affine/Standard/SCF level to
emittable MLIR code.
emittable and estimatable MLIR code.
}];
let constructor = "mlir::scalehls::createConvertToHLSCppPass()";
@ -32,4 +36,18 @@ def HLSKernelToAffine : Pass<"hlskernel-to-affine", "FuncOp"> {
let constructor = "mlir::scalehls::createHLSKernelToAffinePass()";
}
//===----------------------------------------------------------------------===//
// Onnx Kernel Legalization Pass
//===----------------------------------------------------------------------===//
def LegalizeOnnx : Pass<"legalize-onnx", "ModuleOp"> {
let summary = "Legalize model lowered from onnx-mlir flow";
let description = [{
This legalize-onnx pass will legalize all operations lowered from onnx-mlir
flow, e.g. krnl.packed_const, krnl.global, and krnl.memcpy.
}];
let constructor = "mlir::scalehls::createLegalizeOnnxPass()";
}
#endif // SCALEHLS_CONVERSION_PASSES_TD

View File

@ -0,0 +1,139 @@
//===------------------------------------------------------------*- C++ -*-===//
//
//===----------------------------------------------------------------------===//
#include "Conversion/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h"
using namespace std;
using namespace mlir;
using namespace scalehls;
namespace {
struct LegalizeOnnx : public LegalizeOnnxBase<LegalizeOnnx> {
void runOnOperation() override;
};
} // namespace
void LegalizeOnnx::runOnOperation() {
auto module = getOperation();
auto builder = OpBuilder(module);
StringRef weightFileName = "";
int64_t weightSizeInBytes = 0;
int64_t numInputs = 0;
int64_t numOutputs = 0;
StringRef topFunction = "";
SmallVector<FuncOp, 2> funcs;
SmallVector<Operation *, 16> opsToErase;
for (auto &op : module) {
if (op.getName().getStringRef() == "krnl.packed_const") {
// Fetch weight information and erase packed_const operation.
weightFileName = op.getAttrOfType<StringAttr>("file_name").getValue();
weightSizeInBytes =
op.getAttrOfType<IntegerAttr>("size_in_bytes").getInt();
opsToErase.push_back(&op);
} else if (op.getName().getStringRef() == "krnl.entry_point") {
// Fetch top function information and erase entry_point operation.
topFunction = op.getAttrOfType<FlatSymbolRefAttr>("func").getValue();
numInputs = op.getAttrOfType<IntegerAttr>("numInputs").getInt();
numOutputs = op.getAttrOfType<IntegerAttr>("numOutputs").getInt();
opsToErase.push_back(&op);
} else if (auto func = dyn_cast<FuncOp>(op))
funcs.push_back(func);
}
for (auto func : funcs) {
SmallVector<Type, 16> weightTypes;
SmallVector<int64_t, 16> weightOffsets;
SmallVector<Value, 16> weightValues;
for (auto &op : func.front()) {
if (op.getName().getStringRef() == "krnl.global") {
if (auto value = op.getAttrOfType<DenseFPElementsAttr>("value")) {
// If the kernel global operation gets a value, create a standard
// constant operation to substitute it.
builder.setInsertionPoint(&op);
auto tensor = builder.create<mlir::ConstantOp>(op.getLoc(), value);
auto memref = builder.create<mlir::TensorToMemrefOp>(
op.getLoc(), op.getResult(0).getType(), tensor);
op.getResult(0).replaceAllUsesWith(memref);
} else {
// If value attribute doesn't exist, record the type and offset.
weightTypes.push_back(op.getResult(0).getType());
weightOffsets.push_back(
op.getAttrOfType<IntegerAttr>("offset").getInt());
weightValues.push_back(op.getResult(0));
}
// Erase the kernel global operation.
opsToErase.push_back(&op);
} else if (op.getName().getStringRef() == "krnl.memcpy") {
// Replace kernel memcpy with standard memref reshape operation.
auto result = op.getOperand(1);
builder.setInsertionPointAfterValue(result);
auto resultRank = result.getType().cast<MemRefType>().getRank();
auto shape = builder.create<mlir::AllocOp>(
op.getLoc(), MemRefType::get(resultRank, builder.getIndexType()));
auto newResult = builder.create<mlir::MemRefReshapeOp>(
op.getLoc(), result.getType(), op.getOperand(0), shape);
result.replaceAllUsesWith(newResult);
opsToErase.push_back(&op);
} else if (isa<mlir::DeallocOp>(op))
opsToErase.push_back(&op);
}
// Construct new function type.
SmallVector<Type, 16> inputTypes(func.getArgumentTypes());
inputTypes.append(weightTypes.begin(), weightTypes.end());
auto newType = FunctionType::get(inputTypes, func.getType().getResults(),
func.getContext());
// Record the argument number of the old function.
auto oldArgNum = func.getNumArguments();
// Set function type to newType.
func.setType(newType);
// Add new arguments to the entry block.
func.front().addArguments(weightTypes);
// Replace all uses of the kernel global operation with corresponding entry
// block argument.
SmallVector<int64_t, 16> weightIndex;
for (unsigned i = 0, e = weightOffsets.size(); i < e; ++i) {
weightValues[i].replaceAllUsesWith(
func.front().getArgument(i + oldArgNum));
weightIndex.push_back(i + oldArgNum);
}
// Set weight offset and index attribute.
func.setAttr("weight_offsets", builder.getI64ArrayAttr(weightOffsets));
func.setAttr("weight_index", builder.getI64ArrayAttr(weightIndex));
// Set other function attributes if the current function is top function.
if (func.getName() == topFunction) {
func.setAttr("top_function", builder.getBoolAttr(true));
func.setAttr("weight_file_name", builder.getStringAttr(weightFileName));
func.setAttr("weight_size_in_bytes",
builder.getI64IntegerAttr(weightSizeInBytes));
func.setAttr("inputs_num", builder.getI64IntegerAttr(numInputs));
func.setAttr("outputs_num", builder.getI64IntegerAttr(numOutputs));
}
}
// Erase all operations marked as erase.
for (auto op : opsToErase)
op->erase();
}
std::unique_ptr<mlir::Pass> scalehls::createLegalizeOnnxPass() {
return std::make_unique<LegalizeOnnx>();
}

View File

@ -1,8 +1,8 @@
// RUN: scalehls-opt -convert-to-hlscpp="top-function=test_conversion" %s | FileCheck %s
// RUN: scalehls-opt -convert-to-hlscpp="top-function=test_array_assign" %s | FileCheck %s
// CHECK-LABEL: func @test_conversion(
// CHECK-LABEL: func @test_array_assign(
// CHECK-SAME: %arg0: f32, %arg1: memref<16xf32>) -> (f32, memref<16xf32>, i32, tensor<2x2xi32>) attributes {dataflow = false, top_function = true} {
func @test_conversion(%arg0: f32, %arg1: memref<16xf32>) -> (f32, memref<16xf32>, i32, tensor<2x2xi32>) {
func @test_array_assign(%arg0: f32, %arg1: memref<16xf32>) -> (f32, memref<16xf32>, i32, tensor<2x2xi32>) {
// CHECK: %[[VAL_0:.*]] = "hlscpp.array"(%[[ARG_1:.*]]) {interface = true, partition = false, storage = true, storage_type = "ram_1p_bram"} : (memref<16xf32>) -> memref<16xf32>
%c11_i32 = constant 11 : i32
%cst = constant dense<[[11, 0], [0, -42]]> : tensor<2x2xi32>

274
test/onnx-mlir/mnist.mlir Normal file
View File

@ -0,0 +1,274 @@
// RUN: scalehls-opt -legalize-onnx %s | FileCheck %s
// CHECK: module {
#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map1 = affine_map<() -> (0)>
#map2 = affine_map<() -> (32)>
#map3 = affine_map<() -> (1)>
#map4 = affine_map<(d0) -> (d0 + 2)>
#map5 = affine_map<() -> (28)>
#map6 = affine_map<(d0, d1) -> (d0 + d1)>
#map7 = affine_map<() -> (5)>
#map8 = affine_map<() -> (8)>
#map9 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map10 = affine_map<(d0)[s0, s1, s2, s3, s4] -> (0, d0 * s3 - s2)>
#map11 = affine_map<(d0) -> (28, d0 * -2 + 28, d0 * 2 + 2, 2)>
#map12 = affine_map<() -> (14)>
#map13 = affine_map<() -> (18)>
#map14 = affine_map<() -> (16)>
#map15 = affine_map<(d0) -> (14, d0 * -3 + 14, d0 * 3 + 3, 3)>
#map16 = affine_map<() -> (4)>
#map17 = affine_map<(d0, d1) -> (d0, d1)>
#map18 = affine_map<() -> (256)>
#map19 = affine_map<() -> (10)>
module {
%0 = "krnl.packed_const"() {file_name = "/tmp/packed_const-b8d8d9.tmp", is_le = true, size_in_bytes = 23840 : i64} : () -> i64
func @main_graph(%arg0: memref<1x1x28x28xf32>) -> memref<1x10xf32> attributes {input_names = ["Input3"], output_names = ["Plus214_Output_0"]} {
%c10240_i64 = constant 10240 : i64
%c28 = constant 28 : index
%c2 = constant 2 : index
%cst = constant 0xFF800000 : f32
%c14 = constant 14 : index
%c3 = constant 3 : index
%c1 = constant 1 : index
%c1024_i64 = constant 1024 : i64
%cst_0 = constant 1.000000e+00 : f32
%cst_1 = constant 0.000000e+00 : f32
%c0 = constant 0 : index
%1 = alloc() : memref<1x10xf32>
%2 = alloc() : memref<1x256xf32>
%3 = alloc() : memref<1x16x4x4xf32>
%4 = alloc() : memref<1x16x14x14xf32>
%5 = alloc() : memref<1x16x14x14xf32>
%6 = alloc() : memref<1x16x14x14xf32>
%7 = alloc() : memref<1x8x18x18xf32>
%8 = alloc() : memref<1x8x14x14xf32>
%9 = alloc() : memref<1x8x28x28xf32>
%10 = alloc() : memref<1x8x28x28xf32>
%11 = alloc() : memref<1x8x28x28xf32>
%12 = alloc() : memref<1x1x32x32xf32>
%13 = alloc() : memref<256x10xf32>
%14 = "krnl.global"() {name = "constant_0", offset = 0 : i64, shape = [16, 4, 4, 10]} : () -> memref<16x4x4x10xf32>
"krnl.memcpy"(%13, %14, %c10240_i64) : (memref<256x10xf32>, memref<16x4x4x10xf32>, i64) -> ()
%15 = "krnl.global"() {name = "constant_1", offset = 10240 : i64, shape = [8, 1, 5, 5]} : () -> memref<8x1x5x5xf32>
affine.for %arg1 = 0 to 1 {
affine.for %arg2 = 0 to 1 {
affine.for %arg3 = 0 to 32 {
affine.for %arg4 = 0 to 32 {
affine.store %cst_1, %12[%arg1, %arg2, %arg3, %arg4] : memref<1x1x32x32xf32>
}
}
}
}
affine.for %arg1 = 0 to 1 {
affine.for %arg2 = 0 to 1 {
affine.for %arg3 = 0 to 28 {
affine.for %arg4 = 0 to 28 {
%20 = affine.apply #map4(%arg3)
%21 = affine.apply #map4(%arg4)
%22 = affine.load %arg0[%arg1, %arg2, %arg3, %arg4] : memref<1x1x28x28xf32>
affine.store %22, %12[%arg1, %arg2, %20, %21] : memref<1x1x32x32xf32>
}
}
}
}
affine.for %arg1 = 0 to 1 {
affine.for %arg2 = 0 to 8 {
affine.for %arg3 = 0 to 28 {
affine.for %arg4 = 0 to 28 {
affine.store %cst_1, %11[%arg1, %arg2, %arg3, %arg4] : memref<1x8x28x28xf32>
affine.for %arg5 = 0 to 1 {
affine.for %arg6 = 0 to 5 {
affine.for %arg7 = 0 to 5 {
%20 = affine.apply #map6(%arg3, %arg6)
%21 = affine.apply #map6(%arg4, %arg7)
%22 = affine.load %12[%arg1, %arg5, %20, %21] : memref<1x1x32x32xf32>
%23 = affine.load %15[%arg2, %arg5, %arg6, %arg7] : memref<8x1x5x5xf32>
%24 = affine.load %11[%arg1, %arg2, %arg3, %arg4] : memref<1x8x28x28xf32>
%25 = mulf %22, %23 : f32
%26 = addf %24, %25 : f32
affine.store %26, %11[%arg1, %arg2, %arg3, %arg4] : memref<1x8x28x28xf32>
}
}
}
}
}
}
}
%16 = "krnl.global"() {name = "constant_2", offset = 11040 : i64, shape = [8, 1, 1], value = dense<[[[-0.161539719]], [[-0.433835655]], [[0.091641359]], [[-0.0168522168]], [[-0.0650264397]], [[-0.131737873]], [[0.0204175506]], [[-0.121110231]]]> : tensor<8x1x1xf32>} : () -> memref<8x1x1xf32>
affine.for %arg1 = 0 to 1 {
affine.for %arg2 = 0 to 8 {
affine.for %arg3 = 0 to 28 {
affine.for %arg4 = 0 to 28 {
%20 = affine.load %11[%arg1, %arg2, %arg3, %arg4] : memref<1x8x28x28xf32>
%21 = affine.load %16[%arg2, %c0, %c0] : memref<8x1x1xf32>
%22 = addf %20, %21 : f32
affine.store %22, %10[%arg1, %arg2, %arg3, %arg4] : memref<1x8x28x28xf32>
}
}
}
}
affine.for %arg1 = 0 to 1 {
affine.for %arg2 = 0 to 8 {
affine.for %arg3 = 0 to 28 {
affine.for %arg4 = 0 to 28 {
%20 = affine.load %10[%arg1, %arg2, %arg3, %arg4] : memref<1x8x28x28xf32>
%21 = cmpf "olt", %20, %cst_1 : f32
%22 = select %21, %cst_1, %20 : f32
affine.store %22, %9[%arg1, %arg2, %arg3, %arg4] : memref<1x8x28x28xf32>
}
}
}
}
affine.for %arg1 = 0 to 1 {
affine.for %arg2 = 0 to 8 {
affine.for %arg3 = 0 to 14 {
affine.for %arg4 = 0 to 14 {
affine.store %cst, %8[%arg1, %arg2, %arg3, %arg4] : memref<1x8x14x14xf32>
%20 = affine.max #map10(%arg3)[%c28, %c2, %c0, %c2, %c1]
%21 = affine.max #map10(%arg4)[%c28, %c2, %c0, %c2, %c1]
affine.for %arg5 = 0 to min #map11(%arg3) {
affine.for %arg6 = 0 to min #map11(%arg4) {
%22 = addi %arg5, %20 : index
%23 = addi %arg6, %21 : index
%24 = load %9[%arg1, %arg2, %22, %23] : memref<1x8x28x28xf32>
%25 = affine.load %8[%arg1, %arg2, %arg3, %arg4] : memref<1x8x14x14xf32>
%26 = cmpf "ogt", %25, %24 : f32
%27 = select %26, %25, %24 : f32
affine.store %27, %8[%arg1, %arg2, %arg3, %arg4] : memref<1x8x14x14xf32>
}
}
}
}
}
}
%17 = "krnl.global"() {name = "constant_3", offset = 11040 : i64, shape = [16, 8, 5, 5]} : () -> memref<16x8x5x5xf32>
affine.for %arg1 = 0 to 1 {
affine.for %arg2 = 0 to 8 {
affine.for %arg3 = 0 to 18 {
affine.for %arg4 = 0 to 18 {
affine.store %cst_1, %7[%arg1, %arg2, %arg3, %arg4] : memref<1x8x18x18xf32>
}
}
}
}
affine.for %arg1 = 0 to 1 {
affine.for %arg2 = 0 to 8 {
affine.for %arg3 = 0 to 14 {
affine.for %arg4 = 0 to 14 {
%20 = affine.apply #map4(%arg3)
%21 = affine.apply #map4(%arg4)
%22 = affine.load %8[%arg1, %arg2, %arg3, %arg4] : memref<1x8x14x14xf32>
affine.store %22, %7[%arg1, %arg2, %20, %21] : memref<1x8x18x18xf32>
}
}
}
}
affine.for %arg1 = 0 to 1 {
affine.for %arg2 = 0 to 16 {
affine.for %arg3 = 0 to 14 {
affine.for %arg4 = 0 to 14 {
affine.store %cst_1, %6[%arg1, %arg2, %arg3, %arg4] : memref<1x16x14x14xf32>
affine.for %arg5 = 0 to 8 {
affine.for %arg6 = 0 to 5 {
affine.for %arg7 = 0 to 5 {
%20 = affine.apply #map6(%arg3, %arg6)
%21 = affine.apply #map6(%arg4, %arg7)
%22 = affine.load %7[%arg1, %arg5, %20, %21] : memref<1x8x18x18xf32>
%23 = affine.load %17[%arg2, %arg5, %arg6, %arg7] : memref<16x8x5x5xf32>
%24 = affine.load %6[%arg1, %arg2, %arg3, %arg4] : memref<1x16x14x14xf32>
%25 = mulf %22, %23 : f32
%26 = addf %24, %25 : f32
affine.store %26, %6[%arg1, %arg2, %arg3, %arg4] : memref<1x16x14x14xf32>
}
}
}
}
}
}
}
%18 = "krnl.global"() {name = "constant_4", offset = 23840 : i64, shape = [16, 1, 1], value = dense<[[[-0.0822488219]], [[-0.108868778]], [[-0.141039595]], [[-0.204869166]], [[-0.17913565]], [[-0.215438381]], [[-0.133805066]], [[-0.195724562]], [[-0.268250644]], [[-0.258212209]], [[-0.0761560649]], [[0.0132841459]], [[-0.00444464432]], [[-0.414740831]], [[-0.17879115]], [[-0.0386558883]]]> : tensor<16x1x1xf32>} : () -> memref<16x1x1xf32>
affine.for %arg1 = 0 to 1 {
affine.for %arg2 = 0 to 16 {
affine.for %arg3 = 0 to 14 {
affine.for %arg4 = 0 to 14 {
%20 = affine.load %6[%arg1, %arg2, %arg3, %arg4] : memref<1x16x14x14xf32>
%21 = affine.load %18[%arg2, %c0, %c0] : memref<16x1x1xf32>
%22 = addf %20, %21 : f32
affine.store %22, %5[%arg1, %arg2, %arg3, %arg4] : memref<1x16x14x14xf32>
}
}
}
}
affine.for %arg1 = 0 to 1 {
affine.for %arg2 = 0 to 16 {
affine.for %arg3 = 0 to 14 {
affine.for %arg4 = 0 to 14 {
%20 = affine.load %5[%arg1, %arg2, %arg3, %arg4] : memref<1x16x14x14xf32>
%21 = cmpf "olt", %20, %cst_1 : f32
%22 = select %21, %cst_1, %20 : f32
affine.store %22, %4[%arg1, %arg2, %arg3, %arg4] : memref<1x16x14x14xf32>
}
}
}
}
affine.for %arg1 = 0 to 1 {
affine.for %arg2 = 0 to 16 {
affine.for %arg3 = 0 to 4 {
affine.for %arg4 = 0 to 4 {
affine.store %cst, %3[%arg1, %arg2, %arg3, %arg4] : memref<1x16x4x4xf32>
%20 = affine.max #map10(%arg3)[%c14, %c3, %c0, %c3, %c1]
%21 = affine.max #map10(%arg4)[%c14, %c3, %c0, %c3, %c1]
affine.for %arg5 = 0 to min #map15(%arg3) {
affine.for %arg6 = 0 to min #map15(%arg4) {
%22 = addi %arg5, %20 : index
%23 = addi %arg6, %21 : index
%24 = load %4[%arg1, %arg2, %22, %23] : memref<1x16x14x14xf32>
%25 = affine.load %3[%arg1, %arg2, %arg3, %arg4] : memref<1x16x4x4xf32>
%26 = cmpf "ogt", %25, %24 : f32
%27 = select %26, %25, %24 : f32
affine.store %27, %3[%arg1, %arg2, %arg3, %arg4] : memref<1x16x4x4xf32>
}
}
}
}
}
}
"krnl.memcpy"(%2, %3, %c1024_i64) : (memref<1x256xf32>, memref<1x16x4x4xf32>, i64) -> ()
%19 = "krnl.global"() {name = "constant_5", offset = 23840 : i64, shape = [1, 10], value = dense<[[-0.0448560268, 0.00779166119, 0.0681008175, 0.0299937408, -0.126409635, 0.14021875, -0.0552849025, -0.0493838154, 0.0843220502, -0.0545404144]]> : tensor<1x10xf32>} : () -> memref<1x10xf32>
affine.for %arg1 = 0 to 1 {
affine.for %arg2 = 0 to 10 {
affine.store %cst_1, %1[%arg1, %arg2] : memref<1x10xf32>
affine.for %arg3 = 0 to 256 {
%25 = affine.load %2[%arg1, %arg3] : memref<1x256xf32>
%26 = affine.load %13[%arg3, %arg2] : memref<256x10xf32>
%27 = affine.load %1[%arg1, %arg2] : memref<1x10xf32>
%28 = mulf %25, %26 : f32
%29 = addf %27, %28 : f32
affine.store %29, %1[%arg1, %arg2] : memref<1x10xf32>
}
%20 = affine.load %1[%arg1, %arg2] : memref<1x10xf32>
%21 = mulf %cst_0, %20 : f32
%22 = affine.load %19[%c0, %arg2] : memref<1x10xf32>
%23 = mulf %cst_0, %22 : f32
%24 = addf %21, %23 : f32
affine.store %24, %1[%arg1, %arg2] : memref<1x10xf32>
}
}
dealloc %13 : memref<256x10xf32>
dealloc %12 : memref<1x1x32x32xf32>
dealloc %11 : memref<1x8x28x28xf32>
dealloc %10 : memref<1x8x28x28xf32>
dealloc %9 : memref<1x8x28x28xf32>
dealloc %8 : memref<1x8x14x14xf32>
dealloc %7 : memref<1x8x18x18xf32>
dealloc %6 : memref<1x16x14x14xf32>
dealloc %5 : memref<1x16x14x14xf32>
dealloc %4 : memref<1x16x14x14xf32>
dealloc %3 : memref<1x16x4x4xf32>
dealloc %2 : memref<1x256xf32>
return %1 : memref<1x10xf32>
}
"krnl.entry_point"() {func = @main_graph, numInputs = 1 : i32, numOutputs = 1 : i32} : () -> ()
}

View File

@ -48,7 +48,7 @@ static llvm::cl::opt<bool> verifyPasses(
static llvm::cl::opt<bool> allowUnregisteredDialects(
"allow-unregistered-dialect",
llvm::cl::desc("Allow operation with no registered dialects"),
llvm::cl::init(false));
llvm::cl::init(true));
static llvm::cl::opt<bool>
showDialects("show-dialects",