[HLSKernelToAffine] lowering impl of MergeOp
This commit is contained in:
parent
8f1eadd913
commit
10050bc6a3
|
@ -32,7 +32,7 @@ public:
|
|||
bool visitOp(ConvOp op);
|
||||
bool visitOp(MaxPoolOp op);
|
||||
bool visitOp(ReluOp op);
|
||||
// bool visitOp(MergeOp op);
|
||||
bool visitOp(MergeOp op);
|
||||
|
||||
private:
|
||||
OpBuilder &builder;
|
||||
|
@ -65,6 +65,7 @@ private:
|
|||
}
|
||||
}
|
||||
|
||||
// Helpers for getting dimension or constant affine expression.
|
||||
AffineExpr getDim(unsigned pos) { return builder.getAffineDimExpr(pos); }
|
||||
AffineExpr getConst(int64_t val) {
|
||||
return builder.getAffineConstantExpr(val);
|
||||
|
@ -270,6 +271,39 @@ bool HLSKernelVisitor::visitOp(ReluOp op) {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool HLSKernelVisitor::visitOp(MergeOp op) {
|
||||
auto I0 = op.getOperand(0);
|
||||
auto I1 = op.getOperand(1);
|
||||
auto O = op.getOperand(2);
|
||||
|
||||
auto OShape = O.getType().cast<MemRefType>().getShape();
|
||||
|
||||
// Set insertion point of builder.
|
||||
builder.setInsertionPoint(op);
|
||||
|
||||
// Create batch loop.
|
||||
auto n = createLoop(OShape[0]);
|
||||
|
||||
// Create height loop.
|
||||
auto h = createLoop(OShape[2]);
|
||||
|
||||
// Create width loop.
|
||||
auto w = createLoop(OShape[3]);
|
||||
|
||||
// Create channel loop.
|
||||
auto c = createLoop(OShape[1]);
|
||||
|
||||
// Load original value from input array.
|
||||
auto fmap0 = createLoad(I0, {n, c, h, w});
|
||||
auto fmap1 = createLoad(I1, {n, c, h, w});
|
||||
|
||||
// Carry out add and store the result.
|
||||
auto result = createBinaryOp<mlir::AddFOp>(fmap0, fmap1);
|
||||
createStore(result, O, {n, c, h, w});
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// HLSkernel to Affine Lowering Pass
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
// RUN: scalehls-opt -hlskernel-to-affine %s | FileCheck %s
|
||||
|
||||
// CHECK: module {
|
||||
func @test_merge(%I0: memref<10x16x28x28xf32>, %I1: memref<10x16x28x28xf32>, %O: memref<10x16x28x28xf32>) -> () {
|
||||
"hlskernel.merge" (%I0, %I1, %O) {} : (memref<10x16x28x28xf32>, memref<10x16x28x28xf32>, memref<10x16x28x28xf32>) -> ()
|
||||
return
|
||||
}
|
Loading…
Reference in New Issue