From 8a8de8b4c2cee96134f1062497bfa9619d2b7db6 Mon Sep 17 00:00:00 2001 From: Hanchen Ye Date: Mon, 7 Feb 2022 00:38:25 -0600 Subject: [PATCH] [MaterializeReduction] Implement this pass --- include/scalehls/Transforms/Passes.h | 3 +- include/scalehls/Transforms/Passes.td | 12 ++- lib/Transforms/Loop/MaterializeReduction.cpp | 90 ++++++++++++++++++++ 3 files changed, 103 insertions(+), 2 deletions(-) create mode 100644 lib/Transforms/Loop/MaterializeReduction.cpp diff --git a/include/scalehls/Transforms/Passes.h b/include/scalehls/Transforms/Passes.h index b73ac82..236495b 100644 --- a/include/scalehls/Transforms/Passes.h +++ b/include/scalehls/Transforms/Passes.h @@ -28,6 +28,7 @@ std::unique_ptr createLegalizeDataflowPass(); std::unique_ptr createSplitFunctionPass(); /// Loop optimization passes. +std::unique_ptr createMaterializeReductionPass(); std::unique_ptr createAffineLoopPerfectionPass(); std::unique_ptr createRemoveVariableBoundPass(); std::unique_ptr createAffineLoopOrderOptPass(); @@ -38,7 +39,7 @@ std::unique_ptr createFuncPipeliningPass(); std::unique_ptr createLoopPipeliningPass(); std::unique_ptr createArrayPartitionPass(); -/// Standard operation optimization passes. +/// Simplification passes. std::unique_ptr createSimplifyAffineIfPass(); std::unique_ptr createAffineStoreForwardPass(); std::unique_ptr createSimplifyMemrefAccessPass(); diff --git a/include/scalehls/Transforms/Passes.td b/include/scalehls/Transforms/Passes.td index 6d4191c..1b95fde 100644 --- a/include/scalehls/Transforms/Passes.td +++ b/include/scalehls/Transforms/Passes.td @@ -112,6 +112,16 @@ def SplitFunction : Pass<"split-function", "ModuleOp"> { // Loop Optimization Passes //===----------------------------------------------------------------------===// +def MaterializeReduction : Pass<"materialize-reduction", "FuncOp"> { + let summary = "Materialize loop reductions"; + let description = [{ + This materialize-reduction pass will materialize loop reductions with local + buffer read/writes in order to expose more optimizations targeting HLS. + }]; + + let constructor = "mlir::scalehls::createMaterializeReductionPass()"; +} + def AffineLoopPerfection : Pass<"affine-loop-perfection", "FuncOp"> { let summary = "Try to perfect a nested loop"; let description = [{ @@ -217,7 +227,7 @@ def ArrayPartition : Pass<"array-partition", "ModuleOp"> { } //===----------------------------------------------------------------------===// -// Standard Operation Optimization Passes +// Simplification Passes //===----------------------------------------------------------------------===// def SimplifyAffineIf : Pass<"simplify-affine-if", "FuncOp"> { diff --git a/lib/Transforms/Loop/MaterializeReduction.cpp b/lib/Transforms/Loop/MaterializeReduction.cpp new file mode 100644 index 0000000..8cb28ca --- /dev/null +++ b/lib/Transforms/Loop/MaterializeReduction.cpp @@ -0,0 +1,90 @@ +//===----------------------------------------------------------------------===// +// +// Copyright 2020-2021 The ScaleHLS Authors. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "scalehls/Transforms/Passes.h" + +using namespace mlir; +using namespace scalehls; + +namespace { +struct MaterializeReductionPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(AffineForOp loop, + PatternRewriter &rewriter) const override { + if (!loop.getNumIterOperands()) + return success(); + auto loc = rewriter.getUnknownLoc(); + auto yield = cast(loop.getBody()->getTerminator()); + + // Traverse all iteration values. + for (auto zip : llvm::zip(loop.getIterOperands(), loop.getRegionIterArgs(), + yield.getOperands(), loop.getResults())) { + auto iterOperand = std::get<0>(zip); + auto iterArg = std::get<1>(zip); + auto yieldOperand = std::get<2>(zip); + auto yieldResult = std::get<3>(zip); + + // Create a buffer for the iteration value before the loop and set the + // initial state. + auto memrefType = MemRefType::get({1}, iterOperand.getType()); + auto map = rewriter.getConstantAffineMap(0); + rewriter.setInsertionPoint(loop); + auto buf = rewriter.create(loc, memrefType); + rewriter.create(loc, iterOperand, buf, map, ValueRange()); + + // Load the iteration value from the buffer at the begining of loop and + // replace all uses. + rewriter.setInsertionPointToStart(loop.getBody()); + auto partial = rewriter.create(loc, buf, map, ValueRange()); + iterArg.replaceAllUsesWith(partial); + + // Update the state of the buffer at the end of loop. + rewriter.setInsertionPoint(yield); + rewriter.create(loc, yieldOperand, buf, map, ValueRange()); + + // Load from the buffer after the loop and replace all uses. + rewriter.setInsertionPointAfter(loop); + auto result = rewriter.create(loc, buf, map, ValueRange()); + yieldResult.replaceAllUsesWith(result); + } + + // Create a new loop without iteration operands. + rewriter.setInsertionPoint(loop); + auto newLoop = rewriter.create( + loop.getLoc(), loop.getLowerBoundOperands(), loop.getLowerBoundMap(), + loop.getUpperBoundOperands(), loop.getUpperBoundMap(), loop.getStep()); + auto &loopOps = loop.getBody()->getOperations(); + auto &newLoopOps = newLoop.getBody()->getOperations(); + newLoopOps.splice(newLoopOps.begin(), loopOps, loopOps.begin(), + std::prev(loopOps.end())); + loop.getInductionVar().replaceAllUsesWith(newLoop.getInductionVar()); + + // Remove the original loop now. + rewriter.eraseOp(loop); + return success(); + } +}; +} // namespace + +namespace { +struct MaterializeReduction + : public MaterializeReductionBase { + void runOnOperation() override { + auto func = getOperation(); + mlir::RewritePatternSet patterns(func.getContext()); + patterns.add(func.getContext()); + (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); + } +}; +} // namespace + +std::unique_ptr scalehls::createMaterializeReductionPass() { + return std::make_unique(); +}