[ConvertCopyToAffineLoops] Implement this pass; Enable more passes in the scalehls pipeline
This commit is contained in:
parent
4c1e8f177f
commit
543356dac5
|
@ -50,6 +50,7 @@ std::unique_ptr<Pass> createSimplifyTosaGraphPass();
|
|||
std::unique_ptr<Pass> createLegalizeDataflowPass();
|
||||
std::unique_ptr<Pass> createLegalizeDataflowPass(unsigned dataflowGran);
|
||||
std::unique_ptr<Pass> createSplitFunctionPass();
|
||||
std::unique_ptr<Pass> createConvertCopyToAffineLoopsPass();
|
||||
|
||||
/// Loop optimization passes.
|
||||
std::unique_ptr<Pass> createMaterializeReductionPass();
|
||||
|
|
|
@ -108,6 +108,12 @@ def SplitFunction : Pass<"split-function", "ModuleOp"> {
|
|||
let constructor = "mlir::scalehls::createSplitFunctionPass()";
|
||||
}
|
||||
|
||||
def ConvertCopyToAffineLoops : Pass<"convert-copy-to-affine-loops", "FuncOp"> {
|
||||
let summary = "Convert copy and assign to affine loops";
|
||||
|
||||
let constructor = "mlir::scalehls::createConvertCopyToAffineLoopsPass()";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Loop Optimization Passes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -3,6 +3,7 @@ add_mlir_library(MLIRScaleHLSTransforms
|
|||
Directive/CreateHLSCppPrimitive.cpp
|
||||
Directive/FuncPipelining.cpp
|
||||
Directive/LoopPipelining.cpp
|
||||
Graph/ConvertCopyToAffineLoops.cpp
|
||||
Graph/LegalizeDataflow.cpp
|
||||
Graph/SimplifyTosaGraph.cpp
|
||||
Graph/SplitFunction.cpp
|
||||
|
|
|
@ -0,0 +1,150 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Copyright 2020-2021 The ScaleHLS Authors.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/IR/Dominance.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "scalehls/Transforms/Passes.h"
|
||||
#include "scalehls/Transforms/Utils.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace scalehls;
|
||||
|
||||
namespace {
|
||||
struct AllocOpRewritePattern : public OpRewritePattern<memref::AllocOp> {
|
||||
AllocOpRewritePattern(MLIRContext *context, DominanceInfo &DT)
|
||||
: OpRewritePattern(context), DT(DT) {}
|
||||
using OpRewritePattern<memref::AllocOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(memref::AllocOp alloc,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto getCopyUser = [&]() {
|
||||
for (auto user : alloc->getUsers())
|
||||
if (auto copyUser = dyn_cast<memref::CopyOp>(user))
|
||||
return copyUser;
|
||||
return memref::CopyOp();
|
||||
};
|
||||
|
||||
// If the current alloc is not used by any copy, return failure.
|
||||
auto copy = getCopyUser();
|
||||
if (!copy)
|
||||
return failure();
|
||||
|
||||
// If the current alloc dominates another alloc, return failure.
|
||||
auto anotherMemref = alloc.memref() == copy.getSource() ? copy.getTarget()
|
||||
: copy.getSource();
|
||||
if (auto anotherAlloc = anotherMemref.getDefiningOp())
|
||||
if (DT.dominates(alloc.getOperation(), anotherAlloc))
|
||||
return failure();
|
||||
|
||||
// If the source memory is used after the copy op, we cannot eliminate the
|
||||
// target memory. This is conservative?
|
||||
if (llvm::any_of(copy.getSource().getUsers(), [&](Operation *user) {
|
||||
return DT.properlyDominates(copy, user);
|
||||
}))
|
||||
return failure();
|
||||
|
||||
// If the target memory is used before the copy op, we cannot eliminate
|
||||
// the target memory. This is conservative?
|
||||
if (llvm::any_of(copy.getTarget().getUsers(), [&](Operation *user) {
|
||||
return DT.properlyDominates(user, copy);
|
||||
}))
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOp(alloc, anotherMemref);
|
||||
rewriter.eraseOp(copy);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
DominanceInfo &DT;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
struct AssignOpRewritePattern : public OpRewritePattern<AssignOp> {
|
||||
using OpRewritePattern<AssignOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(AssignOp assign,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (!assign->hasOneUse())
|
||||
return failure();
|
||||
|
||||
auto toTensorOp = assign.input().getDefiningOp<bufferization::ToTensorOp>();
|
||||
auto toMemrefOp =
|
||||
dyn_cast<bufferization::ToMemrefOp>(*assign.output().user_begin());
|
||||
if (!toTensorOp || !toMemrefOp)
|
||||
return failure();
|
||||
|
||||
rewriter.setInsertionPointAfter(toMemrefOp);
|
||||
rewriter.create<memref::CopyOp>(assign.getLoc(), toTensorOp.memref(),
|
||||
toMemrefOp.memref());
|
||||
rewriter.replaceOpWithNewOp<memref::AllocOp>(
|
||||
toMemrefOp, toMemrefOp.getType().cast<MemRefType>());
|
||||
rewriter.eraseOp(assign);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
struct CopyOpRewritePattern : public OpRewritePattern<memref::CopyOp> {
|
||||
using OpRewritePattern<memref::CopyOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(memref::CopyOp copy,
|
||||
PatternRewriter &rewriter) const override {
|
||||
rewriter.setInsertionPoint(copy);
|
||||
auto loc = copy.getLoc();
|
||||
auto memrefType = copy.source().getType().cast<MemRefType>();
|
||||
|
||||
// Create explicit memory copy using an affine loop nest.
|
||||
SmallVector<Value, 4> ivs;
|
||||
for (auto dimSize : memrefType.getShape()) {
|
||||
auto loop = rewriter.create<mlir::AffineForOp>(loc, 0, dimSize);
|
||||
rewriter.setInsertionPointToStart(loop.getBody());
|
||||
ivs.push_back(loop.getInductionVar());
|
||||
}
|
||||
|
||||
// Create affine load/store operations.
|
||||
auto value = rewriter.create<mlir::AffineLoadOp>(loc, copy.source(), ivs);
|
||||
rewriter.create<mlir::AffineStoreOp>(loc, value, copy.target(), ivs);
|
||||
|
||||
rewriter.eraseOp(copy);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
struct ConvertCopyToAffineLoops
|
||||
: public ConvertCopyToAffineLoopsBase<ConvertCopyToAffineLoops> {
|
||||
void runOnOperation() override {
|
||||
auto module = getOperation();
|
||||
auto context = module.getContext();
|
||||
auto DT = DominanceInfo(module);
|
||||
|
||||
// Simplify alloc and copy ops.
|
||||
mlir::RewritePatternSet patterns(context);
|
||||
patterns.add<AllocOpRewritePattern>(context, DT);
|
||||
patterns.add<AssignOpRewritePattern>(context);
|
||||
(void)applyPatternsAndFoldGreedily(module, std::move(patterns));
|
||||
|
||||
// Lower copy and assign operation.
|
||||
patterns.clear();
|
||||
patterns.add<CopyOpRewritePattern>(context);
|
||||
(void)applyPatternsAndFoldGreedily(module, std::move(patterns));
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> scalehls::createConvertCopyToAffineLoopsPass() {
|
||||
return std::make_unique<ConvertCopyToAffineLoops>();
|
||||
}
|
|
@ -7,6 +7,7 @@
|
|||
#include "scalehls/Transforms/Passes.h"
|
||||
#include "mlir/Conversion/Passes.h"
|
||||
#include "mlir/Dialect/Affine/Passes.h"
|
||||
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/Linalg/Passes.h"
|
||||
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
|
@ -51,30 +52,33 @@ void scalehls::registerScaleHLSPassPipeline() {
|
|||
// Lower graph to affine.
|
||||
pm.addPass(tosa::createTosaToLinalgNamed());
|
||||
pm.addPass(tosa::createTosaToLinalg());
|
||||
pm.addPass(tosa::createTosaToStandard());
|
||||
pm.addPass(mlir::createLinalgGeneralizationPass());
|
||||
pm.addPass(mlir::createLinalgBufferizePass());
|
||||
pm.addPass(mlir::createFuncBufferizePass());
|
||||
pm.addPass(bufferization::createBufferResultsToOutParamsPass());
|
||||
pm.addPass(mlir::createConvertLinalgToAffineLoopsPass());
|
||||
pm.addPass(scalehls::createConvertCopyToAffineLoopsPass());
|
||||
pm.addPass(mlir::createCanonicalizerPass());
|
||||
|
||||
// // Loop-level optimizations. Loop pipelining is included.
|
||||
// if (vectorSize)
|
||||
// pm.addPass(mlir::createSuperVectorizePass({vectorSize}));
|
||||
// pm.addPass(scalehls::createLegalizeToHLSCppPass(opts));
|
||||
// pm.addPass(scalehls::createMaterializeReductionPass());
|
||||
// if (loopTileSize) {
|
||||
// pm.addPass(scalehls::createAffineLoopPerfectionPass());
|
||||
// pm.addPass(scalehls::createRemoveVariableBoundPass());
|
||||
// pm.addPass(scalehls::createPartialAffineLoopTilePass(loopTileSize));
|
||||
// pm.addPass(mlir::createCanonicalizerPass());
|
||||
// }
|
||||
// Loop-level optimizations. Loop pipelining is included.
|
||||
if (vectorSize)
|
||||
pm.addPass(mlir::createSuperVectorizePass({vectorSize}));
|
||||
pm.addPass(scalehls::createLegalizeToHLSCppPass(opts));
|
||||
pm.addPass(scalehls::createMaterializeReductionPass());
|
||||
if (loopTileSize) {
|
||||
pm.addPass(scalehls::createAffineLoopPerfectionPass());
|
||||
pm.addPass(scalehls::createRemoveVariableBoundPass());
|
||||
pm.addPass(scalehls::createPartialAffineLoopTilePass(loopTileSize));
|
||||
pm.addPass(mlir::createCanonicalizerPass());
|
||||
}
|
||||
|
||||
// // Simplifications.
|
||||
// pm.addPass(scalehls::createSimplifyAffineIfPass());
|
||||
// pm.addPass(scalehls::createAffineStoreForwardPass());
|
||||
// pm.addPass(scalehls::createSimplifyMemrefAccessPass());
|
||||
// pm.addPass(scalehls::createReduceInitialIntervalPass());
|
||||
// pm.addPass(mlir::createCanonicalizerPass());
|
||||
// Simplifications.
|
||||
pm.addPass(scalehls::createSimplifyAffineIfPass());
|
||||
pm.addPass(scalehls::createAffineStoreForwardPass());
|
||||
pm.addPass(scalehls::createSimplifyMemrefAccessPass());
|
||||
pm.addPass(scalehls::createReduceInitialIntervalPass());
|
||||
pm.addPass(mlir::createCanonicalizerPass());
|
||||
|
||||
// // Directive-level optimizations.
|
||||
// pm.addPass(scalehls::createArrayPartitionPass());
|
||||
|
|
Loading…
Reference in New Issue