106 lines
3.9 KiB
C++
106 lines
3.9 KiB
C++
//===----------------------------------------------------------------------===//
|
|
//
|
|
// Copyright 2020-2021 The ScaleHLS Authors.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
#include "scalehls/Dialect/HLS/HLS.h"
|
|
#include "scalehls/Transforms/Passes.h"
|
|
|
|
using namespace mlir;
|
|
using namespace scalehls;
|
|
using namespace hls;
|
|
|
|
using CallToFuncMap = llvm::SmallDenseMap<func::CallOp, func::FuncOp>;
|
|
|
|
namespace {
|
|
/// Sink memref.subview into its call users recursively.
|
|
struct SubViewSinkPattern : public OpRewritePattern<func::CallOp> {
|
|
using OpRewritePattern<func::CallOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(func::CallOp call,
|
|
PatternRewriter &rewriter) const override {
|
|
auto func = SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
|
|
call, call.getCalleeAttr());
|
|
assert(func && "function definition not found");
|
|
|
|
SmallVector<Value, 16> newInputs;
|
|
bool hasChanged = false;
|
|
for (auto operand : call->getOperands()) {
|
|
if (auto subview = operand.getDefiningOp<memref::SubViewOp>()) {
|
|
// Create a cloned subview at the start of the function.
|
|
rewriter.setInsertionPointToStart(&func.front());
|
|
auto cloneSubview = cast<memref::SubViewOp>(rewriter.clone(*subview));
|
|
|
|
// Get the current argument and replace all its uses.
|
|
auto argIdx = newInputs.size();
|
|
auto arg = func.getArgument(argIdx);
|
|
arg.replaceAllUsesWith(cloneSubview.result());
|
|
func.eraseArgument(argIdx);
|
|
|
|
// Insert new arguments and replace the operand of the cloned subview.
|
|
for (auto type : llvm::enumerate(subview.getOperandTypes())) {
|
|
auto newArg = func.front().insertArgument(
|
|
argIdx + type.index(), type.value(), rewriter.getUnknownLoc());
|
|
cloneSubview.setOperand(type.index(), newArg);
|
|
}
|
|
|
|
newInputs.append(subview.operand_begin(), subview.operand_end());
|
|
hasChanged = true;
|
|
} else
|
|
newInputs.push_back(operand);
|
|
}
|
|
|
|
if (hasChanged) {
|
|
func.setType(rewriter.getFunctionType(ValueRange(newInputs),
|
|
func.getResultTypes()));
|
|
rewriter.setInsertionPoint(call);
|
|
rewriter.replaceOpWithNewOp<func::CallOp>(call, func, newInputs);
|
|
}
|
|
return success(hasChanged);
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
namespace {
|
|
struct FuncDuplication : public FuncDuplicationBase<FuncDuplication> {
|
|
void runOnOperation() override {
|
|
auto module = getOperation();
|
|
auto context = module.getContext();
|
|
auto builder = OpBuilder(module);
|
|
|
|
for (auto func :
|
|
llvm::make_early_inc_range(module.getOps<func::FuncOp>())) {
|
|
unsigned idx = 0;
|
|
if (auto symbolUses = func.getSymbolUses(module)) {
|
|
for (auto use : llvm::make_early_inc_range(symbolUses.getValue())) {
|
|
auto call = cast<func::CallOp>(use.getUser());
|
|
builder.setInsertionPoint(func);
|
|
auto cloneFunc = cast<func::FuncOp>(builder.clone(*func));
|
|
auto newName = func.getName().str() + "_" + std::to_string(idx++);
|
|
cloneFunc.setName(newName);
|
|
call->setAttr(call.getCalleeAttrName(),
|
|
FlatSymbolRefAttr::get(func.getContext(), newName));
|
|
}
|
|
if (!symbolUses.getValue().empty())
|
|
func.erase();
|
|
}
|
|
}
|
|
|
|
// TODO: This should be factored out someday somehow. However, because this
|
|
// must be applied after function duplication, the refactoring has to be
|
|
// done very carefully.
|
|
mlir::RewritePatternSet patterns(context);
|
|
patterns.add<SubViewSinkPattern>(context);
|
|
(void)applyPatternsAndFoldGreedily(module, std::move(patterns));
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
std::unique_ptr<Pass> scalehls::createFuncDuplicationPass() {
|
|
return std::make_unique<FuncDuplication>();
|
|
}
|