From cf4e36e626463a18f7ce432cdebc6b8a83014820 Mon Sep 17 00:00:00 2001 From: Hanchen Ye Date: Thu, 5 Nov 2020 15:48:47 -0600 Subject: [PATCH] [HLSKernelToAffine] impl lowering of maxpool, relu, gemm ops --- include/Dialect/HLSKernel/Ops.td | 3 +- .../HLSKernelToAffine/HLSKernelToAffine.cpp | 195 +++++++++++++++++- 2 files changed, 192 insertions(+), 6 deletions(-) diff --git a/include/Dialect/HLSKernel/Ops.td b/include/Dialect/HLSKernel/Ops.td index 9c15dab..870cb0d 100644 --- a/include/Dialect/HLSKernel/Ops.td +++ b/include/Dialect/HLSKernel/Ops.td @@ -11,8 +11,7 @@ def ConvOp : HLSKernelOp<"conv", [HLSKernelOpInterface]> { Convolution operation. For now, only static shaped 4-dims X (batch, channel, row, and col), 4-dims W (input channel, output channel, kernel row, and kernel col), 1-dim B (output channel), and 4-dims Y (batch, channel, row, - and col) are supported. Meanwhile, in the current lowering, padding is not - allowed. Verifiers will ensure the legalness of the operation. + and col) are supported. Verifiers will ensure the legalness of the operation. }]; let arguments = (ins diff --git a/lib/Conversion/HLSKernelToAffine/HLSKernelToAffine.cpp b/lib/Conversion/HLSKernelToAffine/HLSKernelToAffine.cpp index 300480d..410f023 100644 --- a/lib/Conversion/HLSKernelToAffine/HLSKernelToAffine.cpp +++ b/lib/Conversion/HLSKernelToAffine/HLSKernelToAffine.cpp @@ -33,6 +33,7 @@ public: }; } // namespace +/// Padding is not supported. bool HLSKernelVisitor::visitOp(ConvOp op) { OpBuilder builder(op); @@ -115,11 +116,197 @@ bool HLSKernelVisitor::visitOp(ConvOp op) { return true; } -bool HLSKernelVisitor::visitOp(MaxPoolOp op) { return true; } +// Only supported when kernel size is equal to the stride size. 
+bool HLSKernelVisitor::visitOp(MaxPoolOp op) { + OpBuilder builder(op); -bool HLSKernelVisitor::visitOp(ReluOp op) { return true; } + SmallVector kernelShape; + for (auto shape : op.getAttrOfType("kernel_shape")) + kernelShape.push_back(shape.cast().getInt()); -bool HLSKernelVisitor::visitOp(GemmOp op) { return true; } + auto X = op.getOperand(); + auto Y = op.getResult(); + + auto YShape = Y.getType().cast().getShape(); + auto dataType = Y.getType().cast().getElementType(); + + auto newY = builder.create(op.getLoc(), + Y.getType().cast()); + Y.replaceAllUsesWith(newY); + + // Create batch loop. + auto gLoop = builder.create(op.getLoc(), 0, YShape[0]); + builder.setInsertionPointToStart(&gLoop.getLoopBody().front()); + auto g = gLoop.getInductionVar(); + + // Create channel loop. + auto cLoop = builder.create(op.getLoc(), 0, YShape[1]); + builder.setInsertionPointToStart(&cLoop.getLoopBody().front()); + auto c = cLoop.getInductionVar(); + + // Create height loop. + auto hLoop = builder.create(op.getLoc(), 0, YShape[2]); + builder.setInsertionPointToStart(&hLoop.getLoopBody().front()); + auto h = hLoop.getInductionVar(); + + // Create width loop. + auto wLoop = builder.create(op.getLoc(), 0, YShape[3]); + builder.setInsertionPointToStart(&wLoop.getLoopBody().front()); + auto w = wLoop.getInductionVar(); + + // Set largest value as zero. + auto zeroConstant = builder.create( + op.getLoc(), builder.getZeroAttr(dataType)); + builder.create(op.getLoc(), zeroConstant, newY, + ArrayRef({g, c, h, w})); + + // Create kernel height loop. + auto rLoop = + builder.create(op.getLoc(), 0, kernelShape[0]); + builder.setInsertionPointToStart(&rLoop.getLoopBody().front()); + auto r = rLoop.getInductionVar(); + + // Create kernel width loop. + auto sLoop = + builder.create(op.getLoc(), 0, kernelShape[1]); + builder.setInsertionPointToStart(&sLoop.getLoopBody().front()); + auto s = sLoop.getInductionVar(); + + // Fetch feature map. 
+ SmallVector idxExprs; + idxExprs.push_back(builder.getAffineDimExpr(0)); + idxExprs.push_back(builder.getAffineDimExpr(1)); + idxExprs.push_back(builder.getAffineDimExpr(2) * + builder.getAffineConstantExpr(kernelShape[0]) + + builder.getAffineDimExpr(4)); + idxExprs.push_back(builder.getAffineDimExpr(3) * + builder.getAffineConstantExpr(kernelShape[1]) + + builder.getAffineDimExpr(5)); + auto fmap = builder.create( + op.getLoc(), X, AffineMap::get(6, 0, idxExprs, op.getContext()), + ArrayRef({g, c, h, w, r, s})); + + // Fetch current greatest value. + auto tmpGreatest = builder.create( + op.getLoc(), newY, ArrayRef({g, c, h, w})); + auto greaterThanTmp = builder.create( + op.getLoc(), CmpFPredicate::OGT, fmap, tmpGreatest); + + auto newGreatest = builder.create(op.getLoc(), greaterThanTmp, + fmap, tmpGreatest); + + // Store back the greater value. + builder.create(op.getLoc(), newGreatest, newY, + ArrayRef({g, c, h, w})); + + return true; +} + +bool HLSKernelVisitor::visitOp(ReluOp op) { + OpBuilder builder(op); + + auto X = op.getOperand(); + auto Y = op.getResult(); + + auto YShape = Y.getType().cast().getShape(); + + auto newY = builder.create(op.getLoc(), + Y.getType().cast()); + Y.replaceAllUsesWith(newY); + + // Create batch loop. + auto gLoop = builder.create(op.getLoc(), 0, YShape[0]); + builder.setInsertionPointToStart(&gLoop.getLoopBody().front()); + auto g = gLoop.getInductionVar(); + + // Create channel loop. + auto cLoop = builder.create(op.getLoc(), 0, YShape[1]); + builder.setInsertionPointToStart(&cLoop.getLoopBody().front()); + auto c = cLoop.getInductionVar(); + + // Create height loop. + auto hLoop = builder.create(op.getLoc(), 0, YShape[2]); + builder.setInsertionPointToStart(&hLoop.getLoopBody().front()); + auto h = hLoop.getInductionVar(); + + // Create width loop. 
+ auto wLoop = builder.create(op.getLoc(), 0, YShape[3]); + builder.setInsertionPointToStart(&wLoop.getLoopBody().front()); + auto w = wLoop.getInductionVar(); + + // Load original value from input array. + auto fmap = builder.create(op.getLoc(), X, + ArrayRef({g, c, h, w})); + + // Carry out activation. + auto zeroConstant = builder.create( + op.getLoc(), builder.getZeroAttr(fmap.getType())); + auto greaterThanZero = builder.create( + op.getLoc(), CmpFPredicate::OGT, fmap, zeroConstant); + + auto activ = builder.create(op.getLoc(), greaterThanZero, + fmap, zeroConstant); + + // Store back the activations. + builder.create(op.getLoc(), activ, newY, + ArrayRef({g, c, h, w})); + + return true; +} + +bool HLSKernelVisitor::visitOp(GemmOp op) { + OpBuilder builder(op); + + auto X = op.getOperand(0); + auto W = op.getOperand(1); + auto B = op.getOperand(2); + auto Y = op.getResult(); + + auto WShape = W.getType().cast().getShape(); + auto YShape = Y.getType().cast().getShape(); + + auto newY = builder.create(op.getLoc(), + Y.getType().cast()); + Y.replaceAllUsesWith(newY); + + // Create batch loop. + auto gLoop = builder.create(op.getLoc(), 0, YShape[0]); + builder.setInsertionPointToStart(&gLoop.getLoopBody().front()); + auto g = gLoop.getInductionVar(); + + // Create output channel loop. + auto kLoop = builder.create(op.getLoc(), 0, WShape[0]); + builder.setInsertionPointToStart(&kLoop.getLoopBody().front()); + auto k = kLoop.getInductionVar(); + + // Load bias into newY array. + auto bias = builder.create(op.getLoc(), B, k); + builder.create(op.getLoc(), bias, newY, + ArrayRef({g, k})); + + // Create input channel loop. + auto cLoop = builder.create(op.getLoc(), 0, WShape[1]); + builder.setInsertionPointToStart(&cLoop.getLoopBody().front()); + auto c = cLoop.getInductionVar(); + + // Fetch feature map, weight and carry out multiplication. 
+ auto fmap = builder.create(op.getLoc(), X, + ArrayRef({g, c})); + auto weight = builder.create(op.getLoc(), W, + ArrayRef({k, c})); + auto multi = + builder.create(op.getLoc(), fmap.getType(), fmap, weight); + + // Fetch partial result and carry out accumulation. + auto partial = builder.create(op.getLoc(), newY, + ArrayRef({g, k})); + auto accum = + builder.create(op.getLoc(), fmap.getType(), partial, multi); + builder.create(op.getLoc(), accum, newY, + ArrayRef({g, k})); + + return true; +} //===----------------------------------------------------------------------===// // HLSkernel to Affine Lowering Pass @@ -140,7 +327,7 @@ void HLSKernelToAffinePass::runOnOperation() { if (auto func = dyn_cast(op)) { func.walk([&](HLSKernelOpInterface kernelOp) { if (visitor.dispatchVisitor(kernelOp)) { - // kernelOp.erase(); + kernelOp.erase(); } else kernelOp.emitError("can't be correctly lowered."); });