[CreateAxiInterface] Factor out this pass from LegalizeToHLSCpp; [Transforms] Add a new Rntime directory to hold runtime-related passes; Move CreateRuntimeMain to the new directory; Remove redundant pass initialization methods; [Support] Add new getTopFunc/getRuntimeFunc APIs

This commit is contained in:
Hanchen Ye 2022-03-19 23:03:47 -05:00
parent 48d90f2e62
commit e0a844c63d
19 changed files with 134 additions and 98 deletions

View File

@ -34,6 +34,7 @@ bool hasPointAttr(AffineForOp loop);
/// Parse function directives.
FuncDirectiveAttr getFuncDirective(Operation *op);
bool hasTopFuncAttr(FuncOp func);
bool hasRuntimeAttr(FuncOp func);
/// Parse array attributes.
SmallVector<int64_t, 8> getIntArrayAttrValue(Operation *op, StringRef name);
@ -105,6 +106,10 @@ bool checkDependence(Operation *A, Operation *B);
/// Localize each tosa/arith constant to right before its each use.
void localizeConstants(Block &block);
FuncOp getTopFunc(ModuleOp module, std::string topFuncName = "");
FuncOp getRuntimeFunc(ModuleOp module, std::string runtimeFuncName = "");
//===----------------------------------------------------------------------===//
// PtrLikeMemRefAccess Struct Declaration
//===----------------------------------------------------------------------===//

View File

@ -25,28 +25,27 @@ void registerTransformsPasses();
/// QoR estimation and DSE passes.
std::unique_ptr<Pass> createQoREstimationPass();
std::unique_ptr<Pass> createQoREstimationPass(std::string qorTargetSpec);
std::unique_ptr<Pass> createMultipleLevelDSEPass();
std::unique_ptr<Pass> createMultipleLevelDSEPass(std::string dseTargetSpec);
std::unique_ptr<Pass>
createMultipleLevelDSEPass(std::string dseTargetSpec = "");
/// Graph optimization passes.
std::unique_ptr<Pass> createFakeQuantizePass();
std::unique_ptr<Pass> createSimplifyTosaGraphPass();
std::unique_ptr<Pass> createHeuristicNodeFusionPass();
std::unique_ptr<Pass> createCreateTokenFlowPass();
std::unique_ptr<Pass> createFuncDataflowPass();
std::unique_ptr<Pass> createFuncDataflowPass(unsigned dataflowGran,
std::unique_ptr<Pass> createFuncDataflowPass(unsigned dataflowGran = 1,
bool dataflowBalance = true);
std::unique_ptr<Pass> createTosaToLinalgCleanupPass();
std::unique_ptr<Pass> createHoistStreamChannelPass();
/// Runtime-related passes.
std::unique_ptr<Pass> createCreateRuntimeMainPass();
std::unique_ptr<Pass> createCreateRuntimeMainPass(std::string hlsTopFunc);
std::unique_ptr<Pass>
createCreateRuntimeMainPass(std::string hlsTopFunc = "forward");
std::unique_ptr<Pass> createCreateAxiInterfacePass();
/// HLSCpp legalization pass.
std::unique_ptr<Pass> createLegalizeToHLSCppPass();
std::unique_ptr<Pass> createLegalizeToHLSCppPass(std::string hlsTopFunc,
bool hlsAxiInterf = false);
std::unique_ptr<Pass>
createLegalizeToHLSCppPass(std::string hlsTopFunc = "forward");
/// Loop optimization passes.
std::unique_ptr<Pass>
@ -55,14 +54,11 @@ std::unique_ptr<Pass> createMaterializeReductionPass();
std::unique_ptr<Pass> createAffineLoopPerfectionPass();
std::unique_ptr<Pass> createRemoveVariableBoundPass();
std::unique_ptr<Pass> createAffineLoopOrderOptPass();
std::unique_ptr<Pass> createAffineLoopTilePass();
std::unique_ptr<Pass> createAffineLoopTilePass(unsigned loopTileSize);
std::unique_ptr<Pass> createAffineLoopUnrollJamPass();
std::unique_ptr<Pass> createAffineLoopTilePass(unsigned loopTileSize = 1);
std::unique_ptr<Pass>
createAffineLoopUnrollJamPass(unsigned loopUnrollSize,
createAffineLoopUnrollJamPass(unsigned loopUnrollSize = 1,
bool unrollPointLoopOnly = false);
std::unique_ptr<Pass> createAffineLoopDataflowPass();
std::unique_ptr<Pass> createAffineLoopDataflowPass(unsigned dataflowGran,
std::unique_ptr<Pass> createAffineLoopDataflowPass(unsigned dataflowGran = 1,
bool dataflowBalance = true);
std::unique_ptr<Pass> createSimplifyAffineIfPass();

View File

@ -129,11 +129,6 @@ def HoistStreamChannel : Pass<"scalehls-hoist-stream-channel", "ModuleOp"> {
}];
let constructor = "mlir::scalehls::createHoistStreamChannelPass()";
let options = [
Option<"topFunc", "top-func", "std::string",
/*default=*/"\"main\"", "The top function to be transformed">
];
}
//===----------------------------------------------------------------------===//
@ -157,6 +152,12 @@ def CreateRuntimeMain : Pass<"scalehls-create-runtime-main", "ModuleOp"> {
];
}
def CreateAxiInterface : Pass<"scalehls-create-axi-interface", "ModuleOp"> {
let summary = "Create AXI interfaces for the top function";
let constructor = "mlir::scalehls::createCreateAxiInterfacePass()";
}
//===----------------------------------------------------------------------===//
// HLSCpp Legalization Passes
//===----------------------------------------------------------------------===//
@ -173,9 +174,7 @@ def LegalizeToHLSCpp : Pass<"scalehls-legalize-to-hlscpp", "FuncOp"> {
let options = [
Option<"topFunc", "top-func", "std::string", /*default=*/"\"main\"",
"The top function for HLS synthesis">,
Option<"axiInterf", "axi-interf", "bool", /*default=*/"false",
"Whether to create AXI interfaces for the top function">
"The top function for HLS synthesis">
];
}

View File

@ -45,6 +45,7 @@ void setFuncDirective(Operation *op, FuncDirectiveAttr FuncDirective);
void setFuncDirective(Operation *op, bool pipeline, int64_t targetInterval,
bool dataflow);
void setTopFuncAttr(FuncOp func);
void setRuntimeAttr(FuncOp func);
//===----------------------------------------------------------------------===//
// Optimization utils
@ -90,7 +91,7 @@ bool applyArrayPartition(Value array, ArrayRef<unsigned> factors,
/// targeted function.
bool applyAutoArrayPartition(FuncOp func);
bool applyLegalizeToHLSCpp(FuncOp func, bool topFunc, bool axiInterf = false);
bool applyLegalizeToHLSCpp(FuncOp func, bool topFunc);
/// Apply memory optimizations.
bool applyMemoryOpts(FuncOp func);

View File

@ -56,6 +56,10 @@ bool scalehls::hasTopFuncAttr(FuncOp func) {
return func->hasAttrOfType<UnitAttr>("top_func");
}
bool scalehls::hasRuntimeAttr(FuncOp func) {
return func->hasAttrOfType<UnitAttr>("runtime");
}
/// Parse array attributes.
SmallVector<int64_t, 8> scalehls::getIntArrayAttrValue(Operation *op,
StringRef name) {
@ -445,6 +449,30 @@ void scalehls::localizeConstants(Block &block) {
}
}
FuncOp scalehls::getTopFunc(ModuleOp module, std::string topFuncName) {
FuncOp topFunc;
for (auto func : module.getOps<FuncOp>())
if (hasTopFuncAttr(func) || func.getName() == topFuncName) {
if (!topFunc)
topFunc = func;
else
return FuncOp();
}
return topFunc;
}
FuncOp scalehls::getRuntimeFunc(ModuleOp module, std::string runtimeFuncName) {
FuncOp runtimeFunc;
for (auto func : module.getOps<FuncOp>())
if (hasRuntimeAttr(func) || func.getName() == runtimeFuncName) {
if (!runtimeFunc)
runtimeFunc = func;
else
return FuncOp();
}
return runtimeFunc;
}
//===----------------------------------------------------------------------===//
// PtrLikeMemRefAccess Struct Definition
//===----------------------------------------------------------------------===//

View File

@ -25,7 +25,8 @@ add_mlir_library(MLIRScaleHLSTransforms
Memory/RaiseImplicitCopy.cpp
Memory/ReduceInitialInterval.cpp
Memory/SimplifyMemrefAccess.cpp
CreateRuntimeMain.cpp
Runtime/CreateAxiInterface.cpp
Runtime/CreateRuntimeMain.cpp
LegalizeToHLSCpp.cpp
MultipleLevelDSE.cpp
Passes.cpp

View File

@ -430,7 +430,7 @@ struct ArrayPartition : public ArrayPartitionBase<ArrayPartition> {
// FIXME: A better solution to handle the runtime main function.
FuncOp topFunc;
for (auto func : module.getOps<FuncOp>()) {
if (func.getName() == "main") {
if (hasRuntimeAttr(func)) {
topFunc = func;
break;
} else if (hasTopFuncAttr(func))
@ -438,7 +438,7 @@ struct ArrayPartition : public ArrayPartitionBase<ArrayPartition> {
}
if (!topFunc) {
emitError(module.getLoc(), "top function is not found");
emitError(module.getLoc(), "fail to find the top function");
return signalPassFailure();
}
applyAutoArrayPartition(topFunc);

View File

@ -426,9 +426,6 @@ struct FuncDataflow : public FuncDataflowBase<FuncDataflow> {
};
} // namespace
std::unique_ptr<Pass> scalehls::createFuncDataflowPass() {
return std::make_unique<FuncDataflow>();
}
std::unique_ptr<Pass> scalehls::createFuncDataflowPass(unsigned dataflowGran,
bool dataflowBalance) {
return std::make_unique<FuncDataflow>(dataflowGran, dataflowBalance);

View File

@ -79,8 +79,7 @@ static void updateReturnOps(FuncOp func,
// Updates all CallOps in the scope of the given ModuleOp by allocating
// temporary buffers for newly introduced out params.
static LogicalResult updateCalls(ModuleOp module) {
bool didFail = false;
static void updateCalls(ModuleOp module) {
module.walk([&](func::CallOp op) {
SmallVector<Value, 6> replaceWithNewCallResults;
SmallVector<Value, 6> replaceWithOutParams;
@ -109,8 +108,6 @@ static LogicalResult updateCalls(ModuleOp module) {
std::get<0>(t).replaceAllUsesWith(std::get<1>(t));
op.erase();
});
return failure(didFail);
}
namespace {
@ -149,15 +146,18 @@ struct HoistStreamChannel : HoistStreamChannelBase<HoistStreamChannel> {
patterns.add<LowerStreamBufferOpRewritePattern>(context);
(void)applyPatternsAndFoldGreedily(module, std::move(patterns));
for (auto func : module.getOps<FuncOp>()) {
// Get the top function of the module.
auto func = getTopFunc(module);
if (!func) {
emitError(module.getLoc(), "fail to find the top function");
return signalPassFailure();
}
// Hoist stream channels to the top-function.
SmallVector<BlockArgument, 6> appendedEntryArgs;
updateFuncOp(func, appendedEntryArgs);
if (func.isExternal())
continue;
updateReturnOps(func, appendedEntryArgs);
}
if (failed(updateCalls(module)))
return signalPassFailure();
updateCalls(module);
}
};
} // namespace

View File

@ -54,8 +54,7 @@ struct MemrefStoreRewritePattern : public OpRewritePattern<memref::StoreOp> {
};
} // namespace
bool scalehls::applyLegalizeToHLSCpp(FuncOp func, bool isTopFunc,
bool axiInterf) {
bool scalehls::applyLegalizeToHLSCpp(FuncOp func, bool isTopFunc) {
auto builder = OpBuilder(func);
// We constrain functions to only contain one block.
@ -73,21 +72,6 @@ bool scalehls::applyLegalizeToHLSCpp(FuncOp func, bool isTopFunc,
setParallelAttr(loop);
});
if (axiInterf) {
// Convert each argument memory kind to DRAM and buffer each of them.
for (auto arg : func.getArguments()) {
if (auto type = arg.getType().dyn_cast<MemRefType>()) {
arg.setType(MemRefType::get(type.getShape(), type.getElementType(),
type.getLayout().getAffineMap(),
(unsigned)MemoryKind::DRAM));
}
}
// Finally, update the type of the function.
func.setType(builder.getFunctionType(func.front().getArgumentTypes(),
func.getResultTypes()));
}
// Insert BufferOp when an arguments or result of ConstantOp are directly
// connected to ReturnOp.
auto returnOp = func.front().getTerminator();
@ -118,24 +102,17 @@ bool scalehls::applyLegalizeToHLSCpp(FuncOp func, bool isTopFunc,
namespace {
struct LegalizeToHLSCpp : public LegalizeToHLSCppBase<LegalizeToHLSCpp> {
LegalizeToHLSCpp() = default;
LegalizeToHLSCpp(std::string hlsTopFunc, bool hlsAxiInterf) {
topFunc = hlsTopFunc;
axiInterf = hlsAxiInterf;
}
LegalizeToHLSCpp(std::string hlsTopFunc) { topFunc = hlsTopFunc; }
void runOnOperation() override {
auto func = getOperation();
auto isTop = func.getName() == topFunc;
applyLegalizeToHLSCpp(func, isTop, isTop && axiInterf);
applyLegalizeToHLSCpp(func, isTop);
}
};
} // namespace
std::unique_ptr<Pass> scalehls::createLegalizeToHLSCppPass() {
return std::make_unique<LegalizeToHLSCpp>();
}
std::unique_ptr<Pass>
scalehls::createLegalizeToHLSCppPass(std::string hlsTopFunc,
bool hlsAxiInterf) {
return std::make_unique<LegalizeToHLSCpp>(hlsTopFunc, hlsAxiInterf);
scalehls::createLegalizeToHLSCppPass(std::string hlsTopFunc) {
return std::make_unique<LegalizeToHLSCpp>(hlsTopFunc);
}

View File

@ -38,9 +38,6 @@ struct AffineLoopDataflow : public AffineLoopDataflowBase<AffineLoopDataflow> {
};
} // namespace
std::unique_ptr<Pass> scalehls::createAffineLoopDataflowPass() {
return std::make_unique<AffineLoopDataflow>();
}
std::unique_ptr<Pass>
scalehls::createAffineLoopDataflowPass(unsigned dataflowGran,
bool dataflowBalance) {

View File

@ -112,9 +112,6 @@ struct AffineLoopTile : public AffineLoopTileBase<AffineLoopTile> {
/// Creates a pass to perform loop tiling on all suitable loop nests of a
/// Function.
std::unique_ptr<Pass> scalehls::createAffineLoopTilePass() {
return std::make_unique<AffineLoopTile>();
}
std::unique_ptr<Pass>
scalehls::createAffineLoopTilePass(unsigned loopTileSize) {
return std::make_unique<AffineLoopTile>(loopTileSize);

View File

@ -89,9 +89,6 @@ struct AffineLoopUnrollJam
};
} // namespace
std::unique_ptr<Pass> scalehls::createAffineLoopUnrollJamPass() {
return std::make_unique<AffineLoopUnrollJam>();
}
std::unique_ptr<Pass>
scalehls::createAffineLoopUnrollJamPass(unsigned loopUnrollSize,
bool unrollPointLoopOnly) {

View File

@ -900,9 +900,6 @@ struct MultipleLevelDSE : public MultipleLevelDSEBase<MultipleLevelDSE> {
};
} // namespace
std::unique_ptr<Pass> scalehls::createMultipleLevelDSEPass() {
return std::make_unique<MultipleLevelDSE>();
}
std::unique_ptr<Pass>
scalehls::createMultipleLevelDSEPass(std::string dseTargetSpec) {
return std::make_unique<MultipleLevelDSE>(dseTargetSpec);

View File

@ -46,8 +46,9 @@ void scalehls::registerScaleHLSDSEPipeline() {
"Launch design space exploration for C/C++ kernel",
[](OpPassManager &pm, const ScaleHLSDSEPipelineOptions &opts) {
// Legalize the input program.
pm.addPass(scalehls::createLegalizeToHLSCppPass(opts.hlsTopFunc,
opts.hlsAxiInterf));
pm.addPass(scalehls::createLegalizeToHLSCppPass(opts.hlsTopFunc));
if (opts.hlsAxiInterf)
pm.addPass(scalehls::createCreateAxiInterfacePass());
// We first run several passes to simplify the input program.
pm.addPass(scalehls::createPromoteBufferPass());
@ -208,8 +209,9 @@ void scalehls::registerScaleHLSTestPipeline() {
"scalehls-test-pipeline",
"Launch design space exploration for C/C++ kernel",
[](OpPassManager &pm, const ScaleHLSTestPipelineOptions &opts) {
pm.addPass(scalehls::createLegalizeToHLSCppPass(opts.hlsTopFunc,
opts.hlsAxiInterf));
pm.addPass(scalehls::createLegalizeToHLSCppPass(opts.hlsTopFunc));
if (opts.hlsAxiInterf)
pm.addPass(scalehls::createCreateAxiInterfacePass());
pm.addPass(scalehls::createMaterializeReductionPass());
pm.addPass(scalehls::createAffineLoopPerfectionPass());
pm.addPass(scalehls::createRemoveVariableBoundPass());

View File

@ -0,0 +1,45 @@
//===----------------------------------------------------------------------===//
//
// Copyright 2020-2021 The ScaleHLS Authors.
//
//===----------------------------------------------------------------------===//
#include "scalehls/Transforms/Passes.h"
#include "scalehls/Transforms/Utils.h"
using namespace mlir;
using namespace scalehls;
namespace {
struct CreateAxiInterface
: public scalehls::CreateAxiInterfaceBase<CreateAxiInterface> {
void runOnOperation() override {
auto module = getOperation();
OpBuilder builder(module);
// Get the top function of the module.
auto func = getTopFunc(module);
if (!func) {
emitError(module.getLoc(), "fail to find the top function");
return signalPassFailure();
}
// Convert each argument memory kind to DRAM and buffer each of them.
for (auto arg : func.getArguments()) {
if (auto type = arg.getType().dyn_cast<MemRefType>()) {
arg.setType(MemRefType::get(type.getShape(), type.getElementType(),
type.getLayout().getAffineMap(),
(unsigned)MemoryKind::DRAM));
}
}
// Finally, update the type of the function.
func.setType(builder.getFunctionType(func.front().getArgumentTypes(),
func.getResultTypes()));
}
};
} // namespace
std::unique_ptr<Pass> scalehls::createCreateAxiInterfacePass() {
return std::make_unique<CreateAxiInterface>();
}

View File

@ -48,17 +48,12 @@ struct CreateRuntimeMain : public CreateRuntimeMainBase<CreateRuntimeMain> {
OpBuilder builder(module);
// Get the top function of the module.
auto getTopFunc = [&]() {
for (auto func : module.getOps<FuncOp>())
if (func.getName() == topFunc)
return func;
return FuncOp();
};
auto func = getTopFunc();
auto func = getTopFunc(module, topFunc);
if (!func) {
emitError(module.getLoc(), "fail to find the top function");
return signalPassFailure();
}
setTopFuncAttr(func);
// Hoist local constants to the top function.
for (auto call : llvm::make_early_inc_range(func.getOps<func::CallOp>())) {
@ -84,6 +79,7 @@ struct CreateRuntimeMain : public CreateRuntimeMainBase<CreateRuntimeMain> {
builder.setInsertionPointAfter(func);
auto mainFunc =
builder.create<FuncOp>(builder.getUnknownLoc(), "main", func.getType());
setRuntimeAttr(mainFunc);
auto entry = mainFunc.addEntryBlock();
auto constants = collectConstantsAndUpdateFuncionType(func);
@ -104,9 +100,6 @@ struct CreateRuntimeMain : public CreateRuntimeMainBase<CreateRuntimeMain> {
};
} // namespace
std::unique_ptr<Pass> scalehls::createCreateRuntimeMainPass() {
return std::make_unique<CreateRuntimeMain>();
}
std::unique_ptr<Pass>
scalehls::createCreateRuntimeMainPass(std::string hlsTopFunc) {
return std::make_unique<CreateRuntimeMain>(hlsTopFunc);

View File

@ -92,6 +92,10 @@ void scalehls::setTopFuncAttr(FuncOp func) {
func->setAttr("top_func", UnitAttr::get(func.getContext()));
}
void scalehls::setRuntimeAttr(FuncOp func) {
func->setAttr("runtime", UnitAttr::get(func.getContext()));
}
//===----------------------------------------------------------------------===//
// Loop transform utils
//===----------------------------------------------------------------------===//

View File

@ -57,7 +57,7 @@ module {
return %1 : tensor<1x32x32x64xi8>
}
// CHECK: func @forward(%arg0: tensor<1x3x32x32xi8>, %arg1: tensor<64x3x3x3xi8>, %arg2: tensor<64x3x3x64xi8>, %arg3: tensor<64x3x3x64xi8>, %arg4: tensor<1x64x10xi8>) -> tensor<1x10xi8> attributes {func_directive = #hlscpp.fd<pipeline=false, targetInterval=1, dataflow=true>} {
// CHECK: func @forward(%arg0: tensor<1x3x32x32xi8>, %arg1: tensor<64x3x3x3xi8>, %arg2: tensor<64x3x3x64xi8>, %arg3: tensor<64x3x3x64xi8>, %arg4: tensor<1x64x10xi8>) -> tensor<1x10xi8> attributes {func_directive = #hlscpp.fd<pipeline=false, targetInterval=1, dataflow=true>, top_func} {
// CHECK: %0 = call @dataflow5(%arg0, %arg1) : (tensor<1x3x32x32xi8>, tensor<64x3x3x3xi8>) -> tensor<1x32x32x64xi8>
// CHECK: %1:2 = call @dataflow4(%0, %arg2) : (tensor<1x32x32x64xi8>, tensor<64x3x3x64xi8>) -> (tensor<1x32x32x64xi8>, tensor<1x32x32x64xi8>)
// CHECK: %2 = call @dataflow3(%1#0, %1#1, %arg3) : (tensor<1x32x32x64xi8>, tensor<1x32x32x64xi8>, tensor<64x3x3x64xi8>) -> tensor<1x32x32x64xi8>
@ -74,7 +74,7 @@ module {
return %4 : tensor<1x10xi8>
}
// CHECK: func @main(%arg0: tensor<1x3x32x32xi8>) -> tensor<1x10xi8> {
// CHECK: func @main(%arg0: tensor<1x3x32x32xi8>) -> tensor<1x10xi8> attributes {runtime} {
// CHECK: %cst = arith.constant dense<4> : tensor<64x3x3x3xi8>
// CHECK: %cst_0 = arith.constant dense<3> : tensor<64x3x3x64xi8>
// CHECK: %cst_1 = arith.constant dense<2> : tensor<64x3x3x64xi8>