From e0a844c63d3c30efc7a8051f93ec5769e2a234d3 Mon Sep 17 00:00:00 2001 From: Hanchen Ye Date: Sat, 19 Mar 2022 23:03:47 -0500 Subject: [PATCH] [CreateAxiInterface] Factor out this pass from LegalizeToHLSCpp; [Transforms] Add a new Rntime directory to hold runtime-related passes; Move CreateRuntimeMain to the new directory; Remove redundant pass initialization methods; [Support] Add new getTopFunc/getRuntimeFunc APIs --- include/scalehls/Support/Utils.h | 5 +++ include/scalehls/Transforms/Passes.h | 26 +++++------ include/scalehls/Transforms/Passes.td | 15 +++---- include/scalehls/Transforms/Utils.h | 3 +- lib/Support/Utils.cpp | 28 ++++++++++++ lib/Transforms/CMakeLists.txt | 3 +- lib/Transforms/Directive/ArrayPartition.cpp | 4 +- lib/Transforms/Graph/FuncDataflow.cpp | 3 -- lib/Transforms/Graph/HoistStreamChannel.cpp | 24 +++++----- lib/Transforms/LegalizeToHLSCpp.cpp | 33 +++----------- lib/Transforms/Loop/AffineLoopDataflow.cpp | 3 -- lib/Transforms/Loop/AffineLoopTile.cpp | 3 -- lib/Transforms/Loop/AffineLoopUnrollJam.cpp | 3 -- lib/Transforms/MultipleLevelDSE.cpp | 3 -- lib/Transforms/Passes.cpp | 10 +++-- lib/Transforms/Runtime/CreateAxiInterface.cpp | 45 +++++++++++++++++++ .../{ => Runtime}/CreateRuntimeMain.cpp | 13 ++---- lib/Transforms/Utils.cpp | 4 ++ test/Transforms/create_runtime_main.mlir | 4 +- 19 files changed, 134 insertions(+), 98 deletions(-) create mode 100644 lib/Transforms/Runtime/CreateAxiInterface.cpp rename lib/Transforms/{ => Runtime}/CreateRuntimeMain.cpp (92%) diff --git a/include/scalehls/Support/Utils.h b/include/scalehls/Support/Utils.h index fc2b7ee..6b77ba8 100644 --- a/include/scalehls/Support/Utils.h +++ b/include/scalehls/Support/Utils.h @@ -34,6 +34,7 @@ bool hasPointAttr(AffineForOp loop); /// Parse function directives. FuncDirectiveAttr getFuncDirective(Operation *op); bool hasTopFuncAttr(FuncOp func); +bool hasRuntimeAttr(FuncOp func); /// Parse array attributes. SmallVector getIntArrayAttrValue(Operation *op, StringRef name); @@ -105,6 +106,10 @@ bool checkDependence(Operation *A, Operation *B); /// Localize each tosa/arith constant to right before its each use. void localizeConstants(Block &block); +FuncOp getTopFunc(ModuleOp module, std::string topFuncName = ""); + +FuncOp getRuntimeFunc(ModuleOp module, std::string runtimeFuncName = ""); + //===----------------------------------------------------------------------===// // PtrLikeMemRefAccess Struct Declaration //===----------------------------------------------------------------------===// diff --git a/include/scalehls/Transforms/Passes.h b/include/scalehls/Transforms/Passes.h index 1ccafa0..895337d 100644 --- a/include/scalehls/Transforms/Passes.h +++ b/include/scalehls/Transforms/Passes.h @@ -25,28 +25,27 @@ void registerTransformsPasses(); /// QoR estimation and DSE passes. std::unique_ptr createQoREstimationPass(); std::unique_ptr createQoREstimationPass(std::string qorTargetSpec); -std::unique_ptr createMultipleLevelDSEPass(); -std::unique_ptr createMultipleLevelDSEPass(std::string dseTargetSpec); +std::unique_ptr +createMultipleLevelDSEPass(std::string dseTargetSpec = ""); /// Graph optimization passes. std::unique_ptr createFakeQuantizePass(); std::unique_ptr createSimplifyTosaGraphPass(); std::unique_ptr createHeuristicNodeFusionPass(); std::unique_ptr createCreateTokenFlowPass(); -std::unique_ptr createFuncDataflowPass(); -std::unique_ptr createFuncDataflowPass(unsigned dataflowGran, +std::unique_ptr createFuncDataflowPass(unsigned dataflowGran = 1, bool dataflowBalance = true); std::unique_ptr createTosaToLinalgCleanupPass(); std::unique_ptr createHoistStreamChannelPass(); /// Runtime-related passes. -std::unique_ptr createCreateRuntimeMainPass(); -std::unique_ptr createCreateRuntimeMainPass(std::string hlsTopFunc); +std::unique_ptr +createCreateRuntimeMainPass(std::string hlsTopFunc = "forward"); +std::unique_ptr createCreateAxiInterfacePass(); /// HLSCpp legalization pass. -std::unique_ptr createLegalizeToHLSCppPass(); -std::unique_ptr createLegalizeToHLSCppPass(std::string hlsTopFunc, - bool hlsAxiInterf = false); +std::unique_ptr +createLegalizeToHLSCppPass(std::string hlsTopFunc = "forward"); /// Loop optimization passes. std::unique_ptr @@ -55,14 +54,11 @@ std::unique_ptr createMaterializeReductionPass(); std::unique_ptr createAffineLoopPerfectionPass(); std::unique_ptr createRemoveVariableBoundPass(); std::unique_ptr createAffineLoopOrderOptPass(); -std::unique_ptr createAffineLoopTilePass(); -std::unique_ptr createAffineLoopTilePass(unsigned loopTileSize); -std::unique_ptr createAffineLoopUnrollJamPass(); +std::unique_ptr createAffineLoopTilePass(unsigned loopTileSize = 1); std::unique_ptr -createAffineLoopUnrollJamPass(unsigned loopUnrollSize, +createAffineLoopUnrollJamPass(unsigned loopUnrollSize = 1, bool unrollPointLoopOnly = false); -std::unique_ptr createAffineLoopDataflowPass(); -std::unique_ptr createAffineLoopDataflowPass(unsigned dataflowGran, +std::unique_ptr createAffineLoopDataflowPass(unsigned dataflowGran = 1, bool dataflowBalance = true); std::unique_ptr createSimplifyAffineIfPass(); diff --git a/include/scalehls/Transforms/Passes.td b/include/scalehls/Transforms/Passes.td index 8bfd679..f2de439 100644 --- a/include/scalehls/Transforms/Passes.td +++ b/include/scalehls/Transforms/Passes.td @@ -129,11 +129,6 @@ def HoistStreamChannel : Pass<"scalehls-hoist-stream-channel", "ModuleOp"> { }]; let constructor = "mlir::scalehls::createHoistStreamChannelPass()"; - - let options = [ - Option<"topFunc", "top-func", "std::string", - /*default=*/"\"main\"", "The top function to be transformed"> - ]; } //===----------------------------------------------------------------------===// @@ -157,6 +152,12 @@ def CreateRuntimeMain : Pass<"scalehls-create-runtime-main", "ModuleOp"> { ]; } +def CreateAxiInterface : Pass<"scalehls-create-axi-interface", "ModuleOp"> { + let summary = "Create AXI interfaces for the top function"; + + let constructor = "mlir::scalehls::createCreateAxiInterfacePass()"; +} + //===----------------------------------------------------------------------===// // HLSCpp Legalization Passes //===----------------------------------------------------------------------===// @@ -173,9 +174,7 @@ def LegalizeToHLSCpp : Pass<"scalehls-legalize-to-hlscpp", "FuncOp"> { let options = [ Option<"topFunc", "top-func", "std::string", /*default=*/"\"main\"", - "The top function for HLS synthesis">, - Option<"axiInterf", "axi-interf", "bool", /*default=*/"false", - "Whether to create AXI interfaces for the top function"> + "The top function for HLS synthesis"> ]; } diff --git a/include/scalehls/Transforms/Utils.h b/include/scalehls/Transforms/Utils.h index 48af9c7..ad13924 100644 --- a/include/scalehls/Transforms/Utils.h +++ b/include/scalehls/Transforms/Utils.h @@ -45,6 +45,7 @@ void setFuncDirective(Operation *op, FuncDirectiveAttr FuncDirective); void setFuncDirective(Operation *op, bool pipeline, int64_t targetInterval, bool dataflow); void setTopFuncAttr(FuncOp func); +void setRuntimeAttr(FuncOp func); //===----------------------------------------------------------------------===// // Optimization utils @@ -90,7 +91,7 @@ bool applyArrayPartition(Value array, ArrayRef factors, /// targeted function. bool applyAutoArrayPartition(FuncOp func); -bool applyLegalizeToHLSCpp(FuncOp func, bool topFunc, bool axiInterf = false); +bool applyLegalizeToHLSCpp(FuncOp func, bool topFunc); /// Apply memory optimizations. bool applyMemoryOpts(FuncOp func); diff --git a/lib/Support/Utils.cpp b/lib/Support/Utils.cpp index f5c5d01..01cfaa8 100644 --- a/lib/Support/Utils.cpp +++ b/lib/Support/Utils.cpp @@ -56,6 +56,10 @@ bool scalehls::hasTopFuncAttr(FuncOp func) { return func->hasAttrOfType("top_func"); } +bool scalehls::hasRuntimeAttr(FuncOp func) { + return func->hasAttrOfType("runtime"); +} + /// Parse array attributes. SmallVector scalehls::getIntArrayAttrValue(Operation *op, StringRef name) { @@ -445,6 +449,30 @@ void scalehls::localizeConstants(Block &block) { } } +FuncOp scalehls::getTopFunc(ModuleOp module, std::string topFuncName) { + FuncOp topFunc; + for (auto func : module.getOps()) + if (hasTopFuncAttr(func) || func.getName() == topFuncName) { + if (!topFunc) + topFunc = func; + else + return FuncOp(); + } + return topFunc; +} + +FuncOp scalehls::getRuntimeFunc(ModuleOp module, std::string runtimeFuncName) { + FuncOp runtimeFunc; + for (auto func : module.getOps()) + if (hasRuntimeAttr(func) || func.getName() == runtimeFuncName) { + if (!runtimeFunc) + runtimeFunc = func; + else + return FuncOp(); + } + return runtimeFunc; +} + //===----------------------------------------------------------------------===// // PtrLikeMemRefAccess Struct Definition //===----------------------------------------------------------------------===// diff --git a/lib/Transforms/CMakeLists.txt b/lib/Transforms/CMakeLists.txt index e81ac28..5bf0b82 100644 --- a/lib/Transforms/CMakeLists.txt +++ b/lib/Transforms/CMakeLists.txt @@ -25,7 +25,8 @@ add_mlir_library(MLIRScaleHLSTransforms Memory/RaiseImplicitCopy.cpp Memory/ReduceInitialInterval.cpp Memory/SimplifyMemrefAccess.cpp - CreateRuntimeMain.cpp + Runtime/CreateAxiInterface.cpp + Runtime/CreateRuntimeMain.cpp LegalizeToHLSCpp.cpp MultipleLevelDSE.cpp Passes.cpp diff --git a/lib/Transforms/Directive/ArrayPartition.cpp b/lib/Transforms/Directive/ArrayPartition.cpp index f4356e2..79f8918 100644 --- a/lib/Transforms/Directive/ArrayPartition.cpp +++ b/lib/Transforms/Directive/ArrayPartition.cpp @@ -430,7 +430,7 @@ struct ArrayPartition : public ArrayPartitionBase { // FIXME: A better solution to handle the runtime main function. FuncOp topFunc; for (auto func : module.getOps()) { - if (func.getName() == "main") { + if (hasRuntimeAttr(func)) { topFunc = func; break; } else if (hasTopFuncAttr(func)) @@ -438,7 +438,7 @@ struct ArrayPartition : public ArrayPartitionBase { } if (!topFunc) { - emitError(module.getLoc(), "top function is not found"); + emitError(module.getLoc(), "fail to find the top function"); return signalPassFailure(); } applyAutoArrayPartition(topFunc); diff --git a/lib/Transforms/Graph/FuncDataflow.cpp b/lib/Transforms/Graph/FuncDataflow.cpp index 8518bcd..07382cd 100644 --- a/lib/Transforms/Graph/FuncDataflow.cpp +++ b/lib/Transforms/Graph/FuncDataflow.cpp @@ -426,9 +426,6 @@ struct FuncDataflow : public FuncDataflowBase { }; } // namespace -std::unique_ptr scalehls::createFuncDataflowPass() { - return std::make_unique(); -} std::unique_ptr scalehls::createFuncDataflowPass(unsigned dataflowGran, bool dataflowBalance) { return std::make_unique(dataflowGran, dataflowBalance); diff --git a/lib/Transforms/Graph/HoistStreamChannel.cpp b/lib/Transforms/Graph/HoistStreamChannel.cpp index 6395d1c..1613c36 100644 --- a/lib/Transforms/Graph/HoistStreamChannel.cpp +++ b/lib/Transforms/Graph/HoistStreamChannel.cpp @@ -79,8 +79,7 @@ static void updateReturnOps(FuncOp func, // Updates all CallOps in the scope of the given ModuleOp by allocating // temporary buffers for newly introduced out params. -static LogicalResult updateCalls(ModuleOp module) { - bool didFail = false; +static void updateCalls(ModuleOp module) { module.walk([&](func::CallOp op) { SmallVector replaceWithNewCallResults; SmallVector replaceWithOutParams; @@ -109,8 +108,6 @@ static LogicalResult updateCalls(ModuleOp module) { std::get<0>(t).replaceAllUsesWith(std::get<1>(t)); op.erase(); }); - - return failure(didFail); } namespace { @@ -149,15 +146,18 @@ struct HoistStreamChannel : HoistStreamChannelBase { patterns.add(context); (void)applyPatternsAndFoldGreedily(module, std::move(patterns)); - for (auto func : module.getOps()) { - SmallVector appendedEntryArgs; - updateFuncOp(func, appendedEntryArgs); - if (func.isExternal()) - continue; - updateReturnOps(func, appendedEntryArgs); - } - if (failed(updateCalls(module))) + // Get the top function of the module. + auto func = getTopFunc(module); + if (!func) { + emitError(module.getLoc(), "fail to find the top function"); return signalPassFailure(); + } + + // Hoist stream channels to the top-function. + SmallVector appendedEntryArgs; + updateFuncOp(func, appendedEntryArgs); + updateReturnOps(func, appendedEntryArgs); + updateCalls(module); } }; } // namespace diff --git a/lib/Transforms/LegalizeToHLSCpp.cpp b/lib/Transforms/LegalizeToHLSCpp.cpp index 71c1c77..292fd5f 100644 --- a/lib/Transforms/LegalizeToHLSCpp.cpp +++ b/lib/Transforms/LegalizeToHLSCpp.cpp @@ -54,8 +54,7 @@ struct MemrefStoreRewritePattern : public OpRewritePattern { }; } // namespace -bool scalehls::applyLegalizeToHLSCpp(FuncOp func, bool isTopFunc, - bool axiInterf) { +bool scalehls::applyLegalizeToHLSCpp(FuncOp func, bool isTopFunc) { auto builder = OpBuilder(func); // We constrain functions to only contain one block. @@ -73,21 +72,6 @@ bool scalehls::applyLegalizeToHLSCpp(FuncOp func, bool isTopFunc, setParallelAttr(loop); }); - if (axiInterf) { - // Convert each argument memory kind to DRAM and buffer each of them. - for (auto arg : func.getArguments()) { - if (auto type = arg.getType().dyn_cast()) { - arg.setType(MemRefType::get(type.getShape(), type.getElementType(), - type.getLayout().getAffineMap(), - (unsigned)MemoryKind::DRAM)); - } - } - - // Finally, update the type of the function. - func.setType(builder.getFunctionType(func.front().getArgumentTypes(), - func.getResultTypes())); - } - // Insert BufferOp when an arguments or result of ConstantOp are directly // connected to ReturnOp. auto returnOp = func.front().getTerminator(); @@ -118,24 +102,17 @@ bool scalehls::applyLegalizeToHLSCpp(FuncOp func, bool isTopFunc, namespace { struct LegalizeToHLSCpp : public LegalizeToHLSCppBase { LegalizeToHLSCpp() = default; - LegalizeToHLSCpp(std::string hlsTopFunc, bool hlsAxiInterf) { - topFunc = hlsTopFunc; - axiInterf = hlsAxiInterf; - } + LegalizeToHLSCpp(std::string hlsTopFunc) { topFunc = hlsTopFunc; } void runOnOperation() override { auto func = getOperation(); auto isTop = func.getName() == topFunc; - applyLegalizeToHLSCpp(func, isTop, isTop && axiInterf); + applyLegalizeToHLSCpp(func, isTop); } }; } // namespace -std::unique_ptr scalehls::createLegalizeToHLSCppPass() { - return std::make_unique(); -} std::unique_ptr -scalehls::createLegalizeToHLSCppPass(std::string hlsTopFunc, - bool hlsAxiInterf) { - return std::make_unique(hlsTopFunc, hlsAxiInterf); +scalehls::createLegalizeToHLSCppPass(std::string hlsTopFunc) { + return std::make_unique(hlsTopFunc); } diff --git a/lib/Transforms/Loop/AffineLoopDataflow.cpp b/lib/Transforms/Loop/AffineLoopDataflow.cpp index 0854eec..7c253a9 100644 --- a/lib/Transforms/Loop/AffineLoopDataflow.cpp +++ b/lib/Transforms/Loop/AffineLoopDataflow.cpp @@ -38,9 +38,6 @@ struct AffineLoopDataflow : public AffineLoopDataflowBase { }; } // namespace -std::unique_ptr scalehls::createAffineLoopDataflowPass() { - return std::make_unique(); -} std::unique_ptr scalehls::createAffineLoopDataflowPass(unsigned dataflowGran, bool dataflowBalance) { diff --git a/lib/Transforms/Loop/AffineLoopTile.cpp b/lib/Transforms/Loop/AffineLoopTile.cpp index b032ba0..dc34c73 100644 --- a/lib/Transforms/Loop/AffineLoopTile.cpp +++ b/lib/Transforms/Loop/AffineLoopTile.cpp @@ -112,9 +112,6 @@ struct AffineLoopTile : public AffineLoopTileBase { /// Creates a pass to perform loop tiling on all suitable loop nests of a /// Function. -std::unique_ptr scalehls::createAffineLoopTilePass() { - return std::make_unique(); -} std::unique_ptr scalehls::createAffineLoopTilePass(unsigned loopTileSize) { return std::make_unique(loopTileSize); diff --git a/lib/Transforms/Loop/AffineLoopUnrollJam.cpp b/lib/Transforms/Loop/AffineLoopUnrollJam.cpp index 48c2ec6..197dd99 100644 --- a/lib/Transforms/Loop/AffineLoopUnrollJam.cpp +++ b/lib/Transforms/Loop/AffineLoopUnrollJam.cpp @@ -89,9 +89,6 @@ struct AffineLoopUnrollJam }; } // namespace -std::unique_ptr scalehls::createAffineLoopUnrollJamPass() { - return std::make_unique(); -} std::unique_ptr scalehls::createAffineLoopUnrollJamPass(unsigned loopUnrollSize, bool unrollPointLoopOnly) { diff --git a/lib/Transforms/MultipleLevelDSE.cpp b/lib/Transforms/MultipleLevelDSE.cpp index b2048ed..7581dd0 100644 --- a/lib/Transforms/MultipleLevelDSE.cpp +++ b/lib/Transforms/MultipleLevelDSE.cpp @@ -900,9 +900,6 @@ struct MultipleLevelDSE : public MultipleLevelDSEBase { }; } // namespace -std::unique_ptr scalehls::createMultipleLevelDSEPass() { - return std::make_unique(); -} std::unique_ptr scalehls::createMultipleLevelDSEPass(std::string dseTargetSpec) { return std::make_unique(dseTargetSpec); diff --git a/lib/Transforms/Passes.cpp b/lib/Transforms/Passes.cpp index 9c6ba5c..744d6e1 100644 --- a/lib/Transforms/Passes.cpp +++ b/lib/Transforms/Passes.cpp @@ -46,8 +46,9 @@ void scalehls::registerScaleHLSDSEPipeline() { "Launch design space exploration for C/C++ kernel", [](OpPassManager &pm, const ScaleHLSDSEPipelineOptions &opts) { // Legalize the input program. - pm.addPass(scalehls::createLegalizeToHLSCppPass(opts.hlsTopFunc, - opts.hlsAxiInterf)); + pm.addPass(scalehls::createLegalizeToHLSCppPass(opts.hlsTopFunc)); + if (opts.hlsAxiInterf) + pm.addPass(scalehls::createCreateAxiInterfacePass()); // We first run several passes to simplify the input program. pm.addPass(scalehls::createPromoteBufferPass()); @@ -208,8 +209,9 @@ void scalehls::registerScaleHLSTestPipeline() { "scalehls-test-pipeline", "Launch design space exploration for C/C++ kernel", [](OpPassManager &pm, const ScaleHLSTestPipelineOptions &opts) { - pm.addPass(scalehls::createLegalizeToHLSCppPass(opts.hlsTopFunc, - opts.hlsAxiInterf)); + pm.addPass(scalehls::createLegalizeToHLSCppPass(opts.hlsTopFunc)); + if (opts.hlsAxiInterf) + pm.addPass(scalehls::createCreateAxiInterfacePass()); pm.addPass(scalehls::createMaterializeReductionPass()); pm.addPass(scalehls::createAffineLoopPerfectionPass()); pm.addPass(scalehls::createRemoveVariableBoundPass()); diff --git a/lib/Transforms/Runtime/CreateAxiInterface.cpp b/lib/Transforms/Runtime/CreateAxiInterface.cpp new file mode 100644 index 0000000..cdf5782 --- /dev/null +++ b/lib/Transforms/Runtime/CreateAxiInterface.cpp @@ -0,0 +1,45 @@ +//===----------------------------------------------------------------------===// +// +// Copyright 2020-2021 The ScaleHLS Authors. +// +//===----------------------------------------------------------------------===// + +#include "scalehls/Transforms/Passes.h" +#include "scalehls/Transforms/Utils.h" + +using namespace mlir; +using namespace scalehls; + +namespace { +struct CreateAxiInterface + : public scalehls::CreateAxiInterfaceBase { + void runOnOperation() override { + auto module = getOperation(); + OpBuilder builder(module); + + // Get the top function of the module. + auto func = getTopFunc(module); + if (!func) { + emitError(module.getLoc(), "fail to find the top function"); + return signalPassFailure(); + } + + // Convert each argument memory kind to DRAM and buffer each of them. + for (auto arg : func.getArguments()) { + if (auto type = arg.getType().dyn_cast()) { + arg.setType(MemRefType::get(type.getShape(), type.getElementType(), + type.getLayout().getAffineMap(), + (unsigned)MemoryKind::DRAM)); + } + } + + // Finally, update the type of the function. + func.setType(builder.getFunctionType(func.front().getArgumentTypes(), + func.getResultTypes())); + } +}; +} // namespace + +std::unique_ptr scalehls::createCreateAxiInterfacePass() { + return std::make_unique(); +} diff --git a/lib/Transforms/CreateRuntimeMain.cpp b/lib/Transforms/Runtime/CreateRuntimeMain.cpp similarity index 92% rename from lib/Transforms/CreateRuntimeMain.cpp rename to lib/Transforms/Runtime/CreateRuntimeMain.cpp index 19ac062..ef45400 100644 --- a/lib/Transforms/CreateRuntimeMain.cpp +++ b/lib/Transforms/Runtime/CreateRuntimeMain.cpp @@ -48,17 +48,12 @@ struct CreateRuntimeMain : public CreateRuntimeMainBase { OpBuilder builder(module); // Get the top function of the module. - auto getTopFunc = [&]() { - for (auto func : module.getOps()) - if (func.getName() == topFunc) - return func; - return FuncOp(); - }; - auto func = getTopFunc(); + auto func = getTopFunc(module, topFunc); if (!func) { emitError(module.getLoc(), "fail to find the top function"); return signalPassFailure(); } + setTopFuncAttr(func); // Hoist local constants to the top function. for (auto call : llvm::make_early_inc_range(func.getOps())) { @@ -84,6 +79,7 @@ struct CreateRuntimeMain : public CreateRuntimeMainBase { builder.setInsertionPointAfter(func); auto mainFunc = builder.create(builder.getUnknownLoc(), "main", func.getType()); + setRuntimeAttr(mainFunc); auto entry = mainFunc.addEntryBlock(); auto constants = collectConstantsAndUpdateFuncionType(func); @@ -104,9 +100,6 @@ struct CreateRuntimeMain : public CreateRuntimeMainBase { }; } // namespace -std::unique_ptr scalehls::createCreateRuntimeMainPass() { - return std::make_unique(); -} std::unique_ptr scalehls::createCreateRuntimeMainPass(std::string hlsTopFunc) { return std::make_unique(hlsTopFunc); diff --git a/lib/Transforms/Utils.cpp b/lib/Transforms/Utils.cpp index fe77b6a..520b2b3 100644 --- a/lib/Transforms/Utils.cpp +++ b/lib/Transforms/Utils.cpp @@ -92,6 +92,10 @@ void scalehls::setTopFuncAttr(FuncOp func) { func->setAttr("top_func", UnitAttr::get(func.getContext())); } +void scalehls::setRuntimeAttr(FuncOp func) { + func->setAttr("runtime", UnitAttr::get(func.getContext())); +} + //===----------------------------------------------------------------------===// // Loop transform utils //===----------------------------------------------------------------------===// diff --git a/test/Transforms/create_runtime_main.mlir b/test/Transforms/create_runtime_main.mlir index e8ac0a4..0b383f5 100644 --- a/test/Transforms/create_runtime_main.mlir +++ b/test/Transforms/create_runtime_main.mlir @@ -57,7 +57,7 @@ module { return %1 : tensor<1x32x32x64xi8> } - // CHECK: func @forward(%arg0: tensor<1x3x32x32xi8>, %arg1: tensor<64x3x3x3xi8>, %arg2: tensor<64x3x3x64xi8>, %arg3: tensor<64x3x3x64xi8>, %arg4: tensor<1x64x10xi8>) -> tensor<1x10xi8> attributes {func_directive = #hlscpp.fd} { + // CHECK: func @forward(%arg0: tensor<1x3x32x32xi8>, %arg1: tensor<64x3x3x3xi8>, %arg2: tensor<64x3x3x64xi8>, %arg3: tensor<64x3x3x64xi8>, %arg4: tensor<1x64x10xi8>) -> tensor<1x10xi8> attributes {func_directive = #hlscpp.fd, top_func} { // CHECK: %0 = call @dataflow5(%arg0, %arg1) : (tensor<1x3x32x32xi8>, tensor<64x3x3x3xi8>) -> tensor<1x32x32x64xi8> // CHECK: %1:2 = call @dataflow4(%0, %arg2) : (tensor<1x32x32x64xi8>, tensor<64x3x3x64xi8>) -> (tensor<1x32x32x64xi8>, tensor<1x32x32x64xi8>) // CHECK: %2 = call @dataflow3(%1#0, %1#1, %arg3) : (tensor<1x32x32x64xi8>, tensor<1x32x32x64xi8>, tensor<64x3x3x64xi8>) -> tensor<1x32x32x64xi8> @@ -74,7 +74,7 @@ module { return %4 : tensor<1x10xi8> } - // CHECK: func @main(%arg0: tensor<1x3x32x32xi8>) -> tensor<1x10xi8> { + // CHECK: func @main(%arg0: tensor<1x3x32x32xi8>) -> tensor<1x10xi8> attributes {runtime} { // CHECK: %cst = arith.constant dense<4> : tensor<64x3x3x3xi8> // CHECK: %cst_0 = arith.constant dense<3> : tensor<64x3x3x64xi8> // CHECK: %cst_1 = arith.constant dense<2> : tensor<64x3x3x64xi8>