hanchenye-scalehls/lib/Transforms/Dataflow/LegalizeDataflow.cpp

229 lines
7.8 KiB
C++

//===----------------------------------------------------------------------===//
//
// Copyright 2020-2021 The ScaleHLS Authors.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "scalehls/Transforms/Passes.h"
#include "scalehls/Transforms/Utils.h"
using namespace mlir;
using namespace scalehls;
using namespace hls;
namespace {
/// Rewrite pattern over DataflowNodeOp. For every memref-typed operand of the
/// node's output op that is itself defined by a DataflowNodeOp, a new memref is
/// allocated at the start of the node body, the uses of the operand nested
/// inside the node are redirected to that allocation, and a memref copy from
/// the original operand to the allocation is created.
struct MultiProducerRemovePattern : public OpRewritePattern<DataflowNodeOp> {
  using OpRewritePattern<DataflowNodeOp>::OpRewritePattern;

  /// Succeeds iff at least one operand of the output op was rewritten.
  LogicalResult matchAndRewrite(DataflowNodeOp node,
                                PatternRewriter &rewriter) const override {
    bool rewroteAny = false;
    auto unknownLoc = rewriter.getUnknownLoc();
    auto outputOp = node.getOutputOp();

    for (auto yielded : outputOp.getOperands()) {
      // Only memref operands are of interest.
      auto yieldedType = yielded.getType().dyn_cast<MemRefType>();
      if (!yieldedType)
        continue;

      // Operands that are not results of another dataflow node are left alone.
      if (!yielded.getDefiningOp<DataflowNodeOp>())
        continue;

      // Create the local allocation at the very beginning of the node body.
      rewriter.setInsertionPointToStart(node.getBody());
      auto localAlloc =
          rewriter.create<memref::AllocOp>(unknownLoc, yieldedType);

      // Redirect only the uses that live strictly inside this node.
      auto isNestedInNode = [&](OpOperand &use) {
        return node->isProperAncestor(use.getOwner());
      };
      yielded.replaceUsesWithIf(localAlloc, isNestedInNode);

      // Created after the redirection above, so the copy keeps reading from
      // the original operand.
      rewriter.create<memref::CopyOp>(unknownLoc, yielded, localAlloc);
      rewroteAny = true;
    }
    return success(rewroteAny);
  }
};
} // namespace
/// Get the dataflow level of an operation. The terminator of the enclosing
/// block is level 0. A DataflowNodeOp or DataflowBufferOp reports its own
/// level (which may be unset). Any other operation reports the level of its
/// closest enclosing DataflowNodeOp, or llvm::None when there is none.
static Optional<unsigned> getDataflowLevel(Operation *op) {
  // The block terminator is the sink of the dataflow and sits at level zero.
  if (op->getBlock()->getTerminator() == op)
    return static_cast<unsigned>(0);

  // Dataflow ops carry their own (optional) level.
  if (auto selfNode = dyn_cast<DataflowNodeOp>(op))
    return selfNode.level();
  if (auto selfBuffer = dyn_cast<DataflowBufferOp>(op))
    return selfBuffer.level();

  // Otherwise fall back to the enclosing dataflow node, if any.
  auto enclosingNode = op->getParentOfType<DataflowNodeOp>();
  if (!enclosingNode)
    return llvm::None;
  return enclosingNode.level();
}
/// Schedule the dataflow level of the given operation. Supports DataflowNodeOp
/// and DataflowBufferOp.
///
/// The level is one more than the largest level among the users of `op`, or 0
/// when `op` has no users. Returns failure, leaving `op` untouched, as soon as
/// a user without a known level is encountered. On success the level is stored
/// as an I32 integer attribute under `op.levelAttrName()`.
template <typename OpType>
static LogicalResult scheduleDataflowOp(OpType op, PatternRewriter &rewriter) {
  unsigned level = 0;
  for (auto user : op->getUsers()) {
    auto userLevel = getDataflowLevel(user);
    // A user that is not scheduled yet blocks the scheduling of this op.
    if (!userLevel.hasValue())
      return failure();
    level = std::max(level, userLevel.getValue() + 1);
  }

  // The attribute change is an in-place mutation performed from a rewrite
  // pattern, so it is routed through the rewriter instead of being applied
  // behind its back.
  rewriter.updateRootInPlace(op.getOperation(), [&] {
    op->setAttr(op.levelAttrName(), rewriter.getI32IntegerAttr(level));
  });
  return success();
}
namespace {
/// Rewrite pattern that assigns a dataflow level to a DataflowNodeOp which
/// does not have one yet, by delegating to scheduleDataflowOp.
struct NodeSchedulePattern : public OpRewritePattern<DataflowNodeOp> {
  using OpRewritePattern<DataflowNodeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DataflowNodeOp node,
                                PatternRewriter &rewriter) const override {
    // Nodes that already carry a level are not matched again.
    const bool alreadyScheduled = node.level().hasValue();
    return alreadyScheduled ? failure() : scheduleDataflowOp(node, rewriter);
  }
};
} // namespace
namespace {
/// Rewrite pattern that inserts a DataflowBufferOp between a scheduled
/// DataflowNodeOp and each of its dataflow users whose level is more than one
/// below the node's level. The buffer is created right after the node with a
/// depth of (level difference - 1), and the uses of the value nested in the
/// user are redirected to the buffer output.
struct BufferInsertPattern : public OpRewritePattern<DataflowNodeOp> {
  using OpRewritePattern<DataflowNodeOp>::OpRewritePattern;

  /// Succeeds iff at least one buffer was inserted. Unscheduled nodes are not
  /// matched.
  LogicalResult matchAndRewrite(DataflowNodeOp node,
                                PatternRewriter &rewriter) const override {
    if (!node.level().hasValue())
      return failure();

    // The node level is not modified in this pattern, so read it once.
    const auto nodeLevel = node.level().getValue();
    auto loc = rewriter.getUnknownLoc();
    bool hasChanged = false;

    // Each entry pairs a value (first) with the operation using it (second).
    for (auto use : node.getDataflowUses()) {
      auto userLevel = getDataflowLevel(use.second);
      if (!userLevel.hasValue())
        continue;

      // Levels are unsigned: if the node is not strictly above its user the
      // subtraction below would wrap around to a huge value and request an
      // absurdly deep buffer. Such a pair needs no buffer, so skip it.
      if (nodeLevel <= userLevel.getValue())
        continue;

      const auto levelDiff = nodeLevel - userLevel.getValue();
      if (levelDiff > 1) {
        rewriter.setInsertionPointAfter(node);
        auto buffer = rewriter.create<DataflowBufferOp>(
            loc, use.first.getType(), use.first, levelDiff - 1);
        // Only redirect the uses located in (or nested under) this user; the
        // new buffer itself keeps consuming the original value.
        use.first.replaceUsesWithIf(buffer.output(), [&](OpOperand &operand) {
          return use.second->isAncestor(operand.getOwner());
        });
        hasChanged = true;
      }
    }
    return success(hasChanged);
  }
};
} // namespace
namespace {
/// Rewrite pattern that replaces a DataflowBufferOp of depth N with a chain of
/// N depth-1 DataflowBufferOps, each consuming the output of the previous one,
/// starting from the input of the original buffer.
struct BufferSplitPattern : public OpRewritePattern<DataflowBufferOp> {
  using OpRewritePattern<DataflowBufferOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DataflowBufferOp buffer,
                                PatternRewriter &rewriter) const override {
    const auto totalDepth = buffer.depth();

    // Single-element and external buffer don't need to split.
    if (totalDepth == 1 || buffer.isExternal())
      return failure();

    // All new buffers are created in front of the buffer being replaced.
    rewriter.setInsertionPoint(buffer);

    Value chainTail = buffer.input();
    for (auto remaining = totalDepth; remaining > 0; --remaining) {
      auto stage = rewriter.create<DataflowBufferOp>(
          buffer.getLoc(), buffer.getType(), chainTail, /*depth=*/1);
      chainTail = stage.output();
    }

    // The last stage of the chain takes over all uses of the original buffer.
    rewriter.replaceOp(buffer, chainTail);
    return success();
  }
};
} // namespace
namespace {
/// Rewrite pattern that assigns a dataflow level to a not yet scheduled,
/// depth-1, non-external DataflowBufferOp by delegating to scheduleDataflowOp.
struct BufferSchedulePattern : public OpRewritePattern<DataflowBufferOp> {
  using OpRewritePattern<DataflowBufferOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DataflowBufferOp buffer,
                                PatternRewriter &rewriter) const override {
    // Buffers that already carry a level are not matched again.
    if (buffer.level().hasValue())
      return failure();

    // Multi-elements and external buffer should not be scheduled.
    if (buffer.depth() != 1)
      return failure();
    if (buffer.isExternal())
      return failure();

    return scheduleDataflowOp(buffer, rewriter);
  }
};
} // namespace
namespace {
/// Rewrite pattern, parameterized on the type of the op holding the dataflow
/// ops, that groups the DataflowNodeOps and DataflowBufferOps returned by
/// `target.getOps()` by dataflow level and fuses each group into a new node
/// through fuseOpsIntoNewNode. The new node is tagged with the level of its
/// group.
template <typename OpType>
struct DataflowMergePattern : public OpRewritePattern<OpType> {
  using OpRewritePattern<OpType>::OpRewritePattern;

  /// Emits an "is not scheduled" op error and fails when a candidate op has no
  /// level; otherwise succeeds iff at least one group was fused.
  LogicalResult matchAndRewrite(OpType target,
                                PatternRewriter &rewriter) const override {
    // Level -> ops scheduled at that level.
    llvm::SmallDenseMap<unsigned, SmallVector<Operation *>> opsByLevel;

    for (auto &candidate : target.getOps()) {
      if (!isa<DataflowNodeOp, DataflowBufferOp>(candidate))
        continue;

      // Multi-elements and external buffer should not be merged.
      if (auto asBuffer = dyn_cast<DataflowBufferOp>(candidate))
        if (asBuffer.depth() != 1 || asBuffer.isExternal())
          continue;

      auto candidateLevel = getDataflowLevel(&candidate);
      if (!candidateLevel)
        return candidate.emitOpError("is not scheduled");
      opsByLevel[candidateLevel.getValue()].push_back(&candidate);
    }

    for (const auto &group : opsByLevel) {
      auto fused = fuseOpsIntoNewNode(group.second, rewriter);
      fused->setAttr(fused.levelAttrName(),
                     rewriter.getI32IntegerAttr(group.first));
    }

    // Every map entry holds at least one op, so something was fused exactly
    // when the map is not empty.
    return success(!opsByLevel.empty());
  }
};
} // namespace
namespace {
/// Pass driver. It runs three phases on the operation it is scheduled on:
///   1. greedily apply the multi-producer removal, node/buffer scheduling,
///      buffer insertion and buffer splitting patterns;
///   2. merge the dataflow ops located directly in the function and, if any
///      DataflowNodeOp remains there, call setFuncDirective on the function;
///   3. do the same merge on the last loop of every collected loop band and
///      call setLoopDirective on it if it directly holds a DataflowNodeOp.
struct LegalizeDataflow : public LegalizeDataflowBase<LegalizeDataflow> {
  void runOnOperation() override {
    auto func = getOperation();

    // A function declaration has no body: there is nothing to legalize, and
    // `func.front()` below must not be called on an empty region.
    if (func.isExternal())
      return;

    auto context = func.getContext();

    // Each phase owns its pattern set; a set is never reused after it has been
    // moved into a rewrite driver.
    mlir::RewritePatternSet schedulePatterns(context);
    schedulePatterns.add<MultiProducerRemovePattern>(context);
    schedulePatterns.add<NodeSchedulePattern>(context);
    schedulePatterns.add<BufferInsertPattern>(context);
    schedulePatterns.add<BufferSplitPattern>(context);
    schedulePatterns.add<BufferSchedulePattern>(context);
    (void)applyPatternsAndFoldGreedily(func, std::move(schedulePatterns));

    // Legalize function dataflow.
    mlir::RewritePatternSet funcMergePatterns(context);
    funcMergePatterns.add<DataflowMergePattern<func::FuncOp>>(context);
    (void)applyOpPatternsAndFold(func, std::move(funcMergePatterns));
    if (!func.getOps<DataflowNodeOp>().empty())
      setFuncDirective(func, false, 1, true);

    // Collect all target loop bands.
    AffineLoopBands targetBands;
    getLoopBands(func.front(), targetBands, /*allowHavingChilds=*/true);

    // Legalize loop dataflow to each innermost loop.
    mlir::RewritePatternSet loopMergePatterns(context);
    loopMergePatterns.add<DataflowMergePattern<mlir::AffineForOp>>(context);
    FrozenRewritePatternSet frozenPatterns(std::move(loopMergePatterns));
    for (auto &band : targetBands) {
      (void)applyOpPatternsAndFold(band.back(), frozenPatterns);
      if (!band.back().getOps<DataflowNodeOp>().empty())
        setLoopDirective(band.back(), false, 1, true, false);
    }
  }
};
} // namespace
/// Factory for the dataflow legalization pass: returns a newly created
/// LegalizeDataflow instance owned by the caller.
std::unique_ptr<Pass> scalehls::createLegalizeDataflowPass() {
  std::unique_ptr<Pass> pass = std::make_unique<LegalizeDataflow>();
  return pass;
}