[FuncDuplication] Implement this pass, which supports the subview sink after duplication; [Transforms] Move FuncPreprocess out of the Loop directory

This commit is contained in:
Hanchen Ye 2022-04-21 14:57:51 -05:00
parent 27b616ef34
commit 1b342cf6f1
7 changed files with 227 additions and 108 deletions

View File

@ -24,30 +24,32 @@ void registerScaleHLSDSEPipeline();
void registerScaleHLSPyTorchPipelineV2();
void registerTransformsPasses();
/// Design space exploration passes.
std::unique_ptr<Pass>
createDesignSpaceExplorePass(std::string dseTargetSpec = "");
std::unique_ptr<Pass>
createFuncPreprocessPass(std::string hlsTopFunc = "forward");
/// Graph optimization passes.
std::unique_ptr<Pass> createTosaFakeQuantizePass();
std::unique_ptr<Pass> createTosaSimplifyGraphPass();
std::unique_ptr<Pass> createTosaNodeFusionPass();
std::unique_ptr<Pass> createCreateTokenDependsPass();
/// Dataflow-related passes.
std::unique_ptr<Pass> createFuncDuplicationPass();
std::unique_ptr<Pass> createCreateFuncDataflowPass();
std::unique_ptr<Pass> createCreateLoopDataflowPass();
std::unique_ptr<Pass> createLegalizeDataflowPass();
std::unique_ptr<Pass> createCreateTokenDependsPass();
std::unique_ptr<Pass> createBufferizeDataflowPass();
std::unique_ptr<Pass> createConvertDataflowToFuncPass();
/// Graph-related passes.
std::unique_ptr<Pass> createTosaFakeQuantizePass();
std::unique_ptr<Pass> createTosaSimplifyGraphPass();
std::unique_ptr<Pass> createTosaNodeFusionPass();
std::unique_ptr<Pass> createTosaToLinalgCleanupPass();
/// Runtime-related passes.
std::unique_ptr<Pass> createCreateAxiInterfacePass();
std::unique_ptr<Pass>
createCreateRuntimeMainPass(std::string hlsTopFunc = "forward");
std::unique_ptr<Pass> createCreateAxiInterfacePass();
/// Loop optimization passes.
std::unique_ptr<Pass>
createFuncPreprocessPass(std::string hlsTopFunc = "forward");
/// Loop-related passes.
std::unique_ptr<Pass>
createConvertCopyToAffineLoopsPass(bool convertInternCopyOnly = true);
std::unique_ptr<Pass> createMaterializeReductionPass();
@ -60,7 +62,7 @@ createAffineLoopUnrollJamPass(unsigned loopUnrollFactor = 1,
bool unrollPointLoopOnly = false);
std::unique_ptr<Pass> createSimplifyAffineIfPass();
/// Memory optimization passes.
/// Memory-related passes.
std::unique_ptr<Pass> createCreateMemrefSubviewPass();
std::unique_ptr<Pass> createPromoteBufferPass();
std::unique_ptr<Pass> createAffineStoreForwardPass();
@ -68,7 +70,7 @@ std::unique_ptr<Pass> createSimplifyMemrefAccessPass();
std::unique_ptr<Pass> createRaiseImplicitCopyPass();
std::unique_ptr<Pass> createReduceInitialIntervalPass();
/// Directive optimization passes.
/// Directive-related passes.
std::unique_ptr<Pass> createFuncPipeliningPass();
std::unique_ptr<Pass> createLoopPipeliningPass();
std::unique_ptr<Pass> createArrayPartitionPass();

View File

@ -9,10 +9,6 @@
include "mlir/Pass/PassBase.td"
//===----------------------------------------------------------------------===//
// Design Space Exploration Pass
//===----------------------------------------------------------------------===//
def DesignSpaceExplore : Pass<"scalehls-dse", "ModuleOp"> {
let summary = "Optimize HLS design at multiple abstraction level";
let description = [{
@ -38,27 +34,23 @@ def DesignSpaceExplore : Pass<"scalehls-dse", "ModuleOp"> {
];
}
//===----------------------------------------------------------------------===//
// Graph Optimization Passes
//===----------------------------------------------------------------------===//
def FuncPreprocess : Pass<"scalehls-func-preprocess", "func::FuncOp"> {
let summary = "Preprocess the functions subsequent ScaleHLS optimizations";
let constructor = "mlir::scalehls::createFuncPreprocessPass()";
def TosaFakeQuantize : Pass<"scalehls-tosa-fake-quantize", "ModuleOp"> {
let summary = "Convert to 8-bits quantized model (only for testing use)";
let constructor = "mlir::scalehls::createTosaFakeQuantizePass()";
let options = [
Option<"topFunc", "top-func", "std::string", /*default=*/"\"main\"",
"The top function for HLS synthesis">
];
}
def TosaSimplifyGraph : Pass<"scalehls-tosa-simplify-graph", "func::FuncOp"> {
let summary = "Remove redundant TOSA operations";
let description = [{
This simplify-tosa-graph pass will try to remove redundant transpose ops
through pattern matching.
}];
let constructor = "mlir::scalehls::createTosaSimplifyGraphPass()";
}
//===----------------------------------------------------------------------===//
// Dataflow-related Passes
//===----------------------------------------------------------------------===//
def TosaNodeFusion : Pass<"scalehls-tosa-node-fusion", "func::FuncOp"> {
let summary = "Node fusion on TOSA operations";
let constructor = "mlir::scalehls::createTosaNodeFusionPass()";
def FuncDuplication : Pass<"scalehls-func-duplication", "mlir::ModuleOp"> {
let summary = "Duplicate function for each function call";
let constructor = "mlir::scalehls::createFuncDuplicationPass()";
}
def CreateFuncDataflow : Pass<"scalehls-create-func-dataflow", "func::FuncOp"> {
@ -102,6 +94,29 @@ def ConvertDataflowToFunc :
let constructor = "mlir::scalehls::createConvertDataflowToFuncPass()";
}
//===----------------------------------------------------------------------===//
// Graph-related Passes
//===----------------------------------------------------------------------===//
def TosaFakeQuantize : Pass<"scalehls-tosa-fake-quantize", "ModuleOp"> {
let summary = "Convert to 8-bits quantized model (only for testing use)";
let constructor = "mlir::scalehls::createTosaFakeQuantizePass()";
}
def TosaSimplifyGraph : Pass<"scalehls-tosa-simplify-graph", "func::FuncOp"> {
let summary = "Remove redundant TOSA operations";
let description = [{
This simplify-tosa-graph pass will try to remove redundant transpose ops
through pattern matching.
}];
let constructor = "mlir::scalehls::createTosaSimplifyGraphPass()";
}
def TosaNodeFusion : Pass<"scalehls-tosa-node-fusion", "func::FuncOp"> {
let summary = "Node fusion on TOSA operations";
let constructor = "mlir::scalehls::createTosaNodeFusionPass()";
}
def TosaToLinalgCleanup :
Pass<"scalehls-tosa-to-linalg-cleanup", "func::FuncOp"> {
let summary = "Lower tosa::ReshapeOp and tensor::PadOp";
@ -112,6 +127,11 @@ def TosaToLinalgCleanup :
// Runtime-related Passes
//===----------------------------------------------------------------------===//
def CreateAxiInterface : Pass<"scalehls-create-axi-interface", "ModuleOp"> {
let summary = "Create AXI interfaces for the top function";
let constructor = "mlir::scalehls::createCreateAxiInterfacePass()";
}
def CreateRuntimeMain : Pass<"scalehls-create-runtime-main", "ModuleOp"> {
let summary = "Create the main function of runtime";
let description = [{
@ -128,25 +148,10 @@ def CreateRuntimeMain : Pass<"scalehls-create-runtime-main", "ModuleOp"> {
];
}
def CreateAxiInterface : Pass<"scalehls-create-axi-interface", "ModuleOp"> {
let summary = "Create AXI interfaces for the top function";
let constructor = "mlir::scalehls::createCreateAxiInterfacePass()";
}
//===----------------------------------------------------------------------===//
// Loop Optimization Passes
// Loop-related Passes
//===----------------------------------------------------------------------===//
def FuncPreprocess : Pass<"scalehls-func-preprocess", "func::FuncOp"> {
let summary = "Preprocess the functions for loop and directive optimizations";
let constructor = "mlir::scalehls::createFuncPreprocessPass()";
let options = [
Option<"topFunc", "top-func", "std::string", /*default=*/"\"main\"",
"The top function for HLS synthesis">
];
}
def ConvertCopyToAffineLoops :
Pass<"scalehls-convert-copy-to-affine-loops", "func::FuncOp"> {
let summary = "Convert copy and assign to affine loops";
@ -252,62 +257,7 @@ def SimplifyAffineIf : Pass<"scalehls-simplify-affine-if", "func::FuncOp"> {
}
//===----------------------------------------------------------------------===//
// Directive Optimization Passes
//===----------------------------------------------------------------------===//
def FuncPipelining : Pass<"scalehls-func-pipelining", "func::FuncOp"> {
let summary = "Apply function pipelining";
let description = [{
This func-pipelining pass will insert pipeline pragma to the specified
function, all contained loops will be automatically unrolled.
}];
let constructor = "mlir::scalehls::createFuncPipeliningPass()";
let options = [
Option<"targetFunc", "target-func", "std::string",
/*default=*/"\"main\"", "The target function to be pipelined">,
Option<"targetII", "target-ii", "unsigned", /*default=*/"1",
"Positive number: the targeted II to achieve">
];
}
def LoopPipelining : Pass<"scalehls-loop-pipelining", "func::FuncOp"> {
let summary = "Apply loop pipelining";
let description = [{
This loop-pipelining pass will insert pipeline pragma to the target loop
level, and automatically unroll all inner loops.
}];
let constructor = "mlir::scalehls::createLoopPipeliningPass()";
let options = [
Option<"pipelineLevel", "pipeline-level", "unsigned", /*default=*/"0",
"Positive number: loop level to be pipelined (from innermost)">,
Option<"targetII", "target-ii", "unsigned", /*default=*/"1",
"Positive number: the targeted II to achieve">
];
}
def ArrayPartition : Pass<"scalehls-array-partition", "ModuleOp"> {
let summary = "Apply optimized array partition strategy";
let description = [{
This array-partition pass will automatically search for the best array
partition solution for each on-chip memory instance and apply the solution
through changing the layout of the corresponding memref.
}];
let constructor = "mlir::scalehls::createArrayPartitionPass()";
}
def CreateHLSPrimitive : Pass<"scalehls-create-hls-primitive", "func::FuncOp"> {
let summary = "Create HLS C++ multiplification primitives";
let description = [{
This create-hls-primitive pass will convert 8-bits multiplifications to HLS
C++ primitives in order to utilize DSP instances in FPGA.
}];
let constructor = "mlir::scalehls::createCreateHLSPrimitivePass()";
}
//===----------------------------------------------------------------------===//
// Memory Optimization Passes
// Memory-related Passes
//===----------------------------------------------------------------------===//
def CreateMemrefSubview :
@ -365,6 +315,61 @@ def ReduceInitialInterval :
let constructor = "mlir::scalehls::createReduceInitialIntervalPass()";
}
//===----------------------------------------------------------------------===//
// Directive-related Passes
//===----------------------------------------------------------------------===//
def FuncPipelining : Pass<"scalehls-func-pipelining", "func::FuncOp"> {
let summary = "Apply function pipelining";
let description = [{
This func-pipelining pass will insert pipeline pragma to the specified
function, all contained loops will be automatically unrolled.
}];
let constructor = "mlir::scalehls::createFuncPipeliningPass()";
let options = [
Option<"targetFunc", "target-func", "std::string",
/*default=*/"\"main\"", "The target function to be pipelined">,
Option<"targetII", "target-ii", "unsigned", /*default=*/"1",
"Positive number: the targeted II to achieve">
];
}
def LoopPipelining : Pass<"scalehls-loop-pipelining", "func::FuncOp"> {
let summary = "Apply loop pipelining";
let description = [{
This loop-pipelining pass will insert pipeline pragma to the target loop
level, and automatically unroll all inner loops.
}];
let constructor = "mlir::scalehls::createLoopPipeliningPass()";
let options = [
Option<"pipelineLevel", "pipeline-level", "unsigned", /*default=*/"0",
"Positive number: loop level to be pipelined (from innermost)">,
Option<"targetII", "target-ii", "unsigned", /*default=*/"1",
"Positive number: the targeted II to achieve">
];
}
def ArrayPartition : Pass<"scalehls-array-partition", "ModuleOp"> {
let summary = "Apply optimized array partition strategy";
let description = [{
This array-partition pass will automatically search for the best array
partition solution for each on-chip memory instance and apply the solution
through changing the layout of the corresponding memref.
}];
let constructor = "mlir::scalehls::createArrayPartitionPass()";
}
def CreateHLSPrimitive : Pass<"scalehls-create-hls-primitive", "func::FuncOp"> {
let summary = "Create HLS C++ multiplification primitives";
let description = [{
This create-hls-primitive pass will convert 8-bits multiplifications to HLS
C++ primitives in order to utilize DSP instances in FPGA.
}];
let constructor = "mlir::scalehls::createCreateHLSPrimitivePass()";
}
def QoREstimation : Pass<"scalehls-qor-estimation", "ModuleOp"> {
let summary = "Estimate the performance and resource utilization";
let description = [{

View File

@ -3,6 +3,7 @@ add_mlir_library(MLIRScaleHLSTransforms
Dataflow/ConvertDataflowToFunc.cpp
Dataflow/CreateDataflow.cpp
Dataflow/CreateTokenDepends.cpp
Dataflow/FuncDuplication.cpp
Dataflow/LegalizeDataflow.cpp
Directive/ArrayPartition.cpp
Directive/CreateHLSPrimitive.cpp
@ -18,7 +19,6 @@ add_mlir_library(MLIRScaleHLSTransforms
Loop/AffineLoopTile.cpp
Loop/AffineLoopUnrollJam.cpp
Loop/ConvertCopyToAffineLoops.cpp
Loop/FuncPreprocess.cpp
Loop/MaterializeReduction.cpp
Loop/RemoveVariableBound.cpp
Loop/SimplifyAffineIf.cpp
@ -31,6 +31,7 @@ add_mlir_library(MLIRScaleHLSTransforms
Runtime/CreateAxiInterface.cpp
Runtime/CreateRuntimeMain.cpp
DesignSpaceExplore.cpp
FuncPreprocess.cpp
Passes.cpp
Utils.cpp

View File

@ -0,0 +1,105 @@
//===----------------------------------------------------------------------===//
//
// Copyright 2020-2021 The ScaleHLS Authors.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "scalehls/Dialect/HLS/HLS.h"
#include "scalehls/Transforms/Passes.h"
using namespace mlir;
using namespace scalehls;
using namespace hls;
using CallToFuncMap = llvm::SmallDenseMap<func::CallOp, func::FuncOp>;
namespace {
/// Sink memref.subview into its call users recursively.
struct SubViewSinkPattern : public OpRewritePattern<func::CallOp> {
using OpRewritePattern<func::CallOp>::OpRewritePattern;
LogicalResult matchAndRewrite(func::CallOp call,
PatternRewriter &rewriter) const override {
auto func = SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
call, call.getCalleeAttr());
assert(func && "function definition not found");
SmallVector<Value, 16> newInputs;
bool hasChanged = false;
for (auto operand : call->getOperands()) {
if (auto subview = operand.getDefiningOp<memref::SubViewOp>()) {
// Create a cloned subview at the start of the function.
rewriter.setInsertionPointToStart(&func.front());
auto cloneSubview = cast<memref::SubViewOp>(rewriter.clone(*subview));
// Get the current argument and replace all its uses.
auto argIdx = newInputs.size();
auto arg = func.getArgument(argIdx);
arg.replaceAllUsesWith(cloneSubview.result());
func.eraseArgument(argIdx);
// Insert new arguments and replace the operand of the cloned subview.
for (auto type : llvm::enumerate(subview.getOperandTypes())) {
auto newArg = func.front().insertArgument(
argIdx + type.index(), type.value(), rewriter.getUnknownLoc());
cloneSubview.setOperand(type.index(), newArg);
}
newInputs.append(subview.operand_begin(), subview.operand_end());
hasChanged = true;
} else
newInputs.push_back(operand);
}
if (hasChanged) {
func.setType(rewriter.getFunctionType(ValueRange(newInputs),
func.getResultTypes()));
rewriter.setInsertionPoint(call);
rewriter.replaceOpWithNewOp<func::CallOp>(call, func, newInputs);
}
return success(hasChanged);
}
};
} // namespace
namespace {
struct FuncDuplication : public FuncDuplicationBase<FuncDuplication> {
void runOnOperation() override {
auto module = getOperation();
auto context = module.getContext();
auto builder = OpBuilder(module);
for (auto func :
llvm::make_early_inc_range(module.getOps<func::FuncOp>())) {
unsigned idx = 0;
if (auto symbolUses = func.getSymbolUses(module)) {
for (auto use : llvm::make_early_inc_range(symbolUses.getValue())) {
auto call = cast<func::CallOp>(use.getUser());
builder.setInsertionPoint(func);
auto cloneFunc = cast<func::FuncOp>(builder.clone(*func));
auto newName = func.getName().str() + "_" + std::to_string(idx++);
cloneFunc.setName(newName);
call->setAttr(call.getCalleeAttrName(),
FlatSymbolRefAttr::get(func.getContext(), newName));
}
if (!symbolUses.getValue().empty())
func.erase();
}
}
// TODO: This should be factored out someday somehow. However, because this
// must be applied after function duplication, the refactoring has to be
// done very carefully.
mlir::RewritePatternSet patterns(context);
patterns.add<SubViewSinkPattern>(context);
(void)applyPatternsAndFoldGreedily(module, std::move(patterns));
}
};
} // namespace
std::unique_ptr<Pass> scalehls::createFuncDuplicationPass() {
return std::make_unique<FuncDuplication>();
}

View File

@ -62,8 +62,8 @@ struct GetGlobalConvertPattern : public OpRewritePattern<memref::GetGlobalOp> {
LogicalResult matchAndRewrite(memref::GetGlobalOp getGlobal,
PatternRewriter &rewriter) const override {
auto global = cast<memref::GlobalOp>(
SymbolTable::lookupNearestSymbolFrom(getGlobal, getGlobal.nameAttr()));
auto global = SymbolTable::lookupNearestSymbolFrom<memref::GlobalOp>(
getGlobal, getGlobal.nameAttr());
rewriter.setInsertionPoint(getGlobal);
rewriter.replaceOpWithNewOp<PrimConstOp>(
getGlobal, global.type(),

View File

@ -0,0 +1,6 @@
// RUN: scalehls-opt -scalehls-func-duplication %s | FileCheck %s
// CHECK-LABEL: func @test
func @test() {
return
}