mirror of https://github.com/llvm/circt.git
[CFToHandshake] Refactor towards genericness (#6156)
A couple of minor refactors - more will come - to allow for using `CFToHandshake` on things other than just `func.func` as the source operation and `handshake.func` as the target operation.
This commit is contained in:
parent
75c00d78d8
commit
60db6832b4
|
@ -69,23 +69,22 @@ public:
|
|||
LogicalResult addBranchOps(ConversionPatternRewriter &rewriter);
|
||||
LogicalResult replaceCallOps(ConversionPatternRewriter &rewriter);
|
||||
|
||||
template <typename TTerm>
|
||||
LogicalResult setControlOnlyPath(ConversionPatternRewriter &rewriter) {
|
||||
template <typename TSrcTerm, typename TDstTerm>
|
||||
LogicalResult setControlOnlyPath(ConversionPatternRewriter &rewriter,
|
||||
Value entryCtrl) {
|
||||
assert(entryCtrl.getType().isa<NoneType>() &&
|
||||
"Expected NoneType for entry control value");
|
||||
// Creates start and end points of the control-only path
|
||||
|
||||
// Add start point of the control-only path to the entry block's arguments
|
||||
Block *entryBlock = &r.front();
|
||||
startCtrl = entryBlock->addArgument(rewriter.getNoneType(),
|
||||
rewriter.getUnknownLoc());
|
||||
setBlockEntryControl(entryBlock, startCtrl);
|
||||
setBlockEntryControl(entryBlock, entryCtrl);
|
||||
|
||||
// Replace original return ops with new returns with additional control
|
||||
// input
|
||||
for (auto retOp : llvm::make_early_inc_range(r.getOps<TTerm>())) {
|
||||
for (auto retOp : llvm::make_early_inc_range(r.getOps<TSrcTerm>())) {
|
||||
rewriter.setInsertionPoint(retOp);
|
||||
SmallVector<Value, 8> operands(retOp->getOperands());
|
||||
operands.push_back(startCtrl);
|
||||
rewriter.replaceOpWithNewOp<handshake::ReturnOp>(retOp, operands);
|
||||
operands.push_back(entryCtrl);
|
||||
rewriter.replaceOpWithNewOp<TDstTerm>(retOp, operands);
|
||||
}
|
||||
|
||||
// Store the number of block arguments in each block
|
||||
|
@ -96,7 +95,7 @@ public:
|
|||
// Apply SSA maximization on the newly added entry block argument to
|
||||
// propagate it explicitly between the start-point of the control-only
|
||||
// network and the function's terminators
|
||||
if (failed(maximizeSSA(startCtrl, rewriter)))
|
||||
if (failed(maximizeSSA(entryCtrl, rewriter)))
|
||||
return failure();
|
||||
|
||||
// Identify all block arguments belonging to the control-only network
|
||||
|
@ -170,13 +169,19 @@ LogicalResult runPartialLowering(
|
|||
instance.getContext(), instance.getRegion());
|
||||
}
|
||||
|
||||
/// Remove basic blocks inside the given region. This allows the result to be
|
||||
/// a valid graph region, since multi-basic block regions are not allowed to
|
||||
/// be graph regions currently.
|
||||
void removeBasicBlocks(Region &r);
|
||||
|
||||
// Helper to check the validity of the dataflow conversion
|
||||
// Driver that applies the partial lowerings expressed in HandshakeLowering to
|
||||
// the region encapsulated in it. The region is assumed to have a terminator of
|
||||
// type TTerm. See HandshakeLowering for the different lowering steps.
|
||||
template <typename TTerm>
|
||||
// type TSrcTerm, and will replace it with TDstTerm. See HandshakeLowering for
|
||||
// the different lowering steps.
|
||||
template <typename TSrcTerm, typename TDstTerm>
|
||||
LogicalResult lowerRegion(HandshakeLowering &hl, bool sourceConstants,
|
||||
bool disableTaskPipelining) {
|
||||
bool disableTaskPipelining, Value entryCtrl) {
|
||||
// Perform initial dataflow conversion. This process allows for the use of
|
||||
// non-deterministic merge-like operations.
|
||||
HandshakeLowering::MemRefToMemoryAccessOp memOps;
|
||||
|
@ -184,8 +189,9 @@ LogicalResult lowerRegion(HandshakeLowering &hl, bool sourceConstants,
|
|||
if (failed(
|
||||
runPartialLowering(hl, &HandshakeLowering::replaceMemoryOps, memOps)))
|
||||
return failure();
|
||||
if (failed(runPartialLowering(hl,
|
||||
&HandshakeLowering::setControlOnlyPath<TTerm>)))
|
||||
if (failed(runPartialLowering(
|
||||
hl, &HandshakeLowering::setControlOnlyPath<TSrcTerm, TDstTerm>,
|
||||
entryCtrl)))
|
||||
return failure();
|
||||
if (failed(runPartialLowering(hl, &HandshakeLowering::addMergeOps)))
|
||||
return failure();
|
||||
|
@ -215,14 +221,13 @@ LogicalResult lowerRegion(HandshakeLowering &hl, bool sourceConstants,
|
|||
lsq)))
|
||||
return failure();
|
||||
|
||||
// Legalize the resulting regions, removing basic blocks and performing
|
||||
// any simple conversions.
|
||||
removeBasicBlocks(hl.getRegion());
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Remove basic blocks inside the given region. This allows the result to be
|
||||
/// a valid graph region, since multi-basic block regions are not allowed to
|
||||
/// be graph regions currently.
|
||||
void removeBasicBlocks(Region &r);
|
||||
|
||||
/// Lowers the mlir operations into handshake that are not part of the dataflow
|
||||
/// conversion.
|
||||
LogicalResult postDataflowConvert(Operation *op);
|
||||
|
|
|
@ -120,6 +120,11 @@ def FuncOp : Op<Handshake_Dialect, "func", [
|
|||
return success();
|
||||
}
|
||||
|
||||
/// Returns the body block of the function.
|
||||
Block* getBodyBlock() {
|
||||
return &getBody().front();
|
||||
}
|
||||
|
||||
//===------------------------------------------------------------------===//
|
||||
// OpAsmOpInterface Methods
|
||||
//===------------------------------------------------------------------===//
|
||||
|
|
|
@ -219,23 +219,13 @@ void HandshakeLowering::setBlockEntryControl(Block *block, Value v) {
|
|||
}
|
||||
|
||||
void handshake::removeBasicBlocks(Region &r) {
|
||||
auto &entryBlock = r.front().getOperations();
|
||||
|
||||
// Now that basic blocks are going to be removed, we can erase all cf-dialect
|
||||
// branches, and move ReturnOp to the entry block's end
|
||||
for (auto &block : r) {
|
||||
Operation &termOp = block.back();
|
||||
if (isa<mlir::cf::CondBranchOp, mlir::cf::BranchOp>(termOp))
|
||||
termOp.erase();
|
||||
else if (isa<handshake::ReturnOp>(termOp))
|
||||
entryBlock.splice(entryBlock.end(), block.getOperations(), termOp);
|
||||
}
|
||||
Block *entryBlock = &r.front();
|
||||
auto &entryBlockOps = entryBlock->getOperations();
|
||||
|
||||
// Move all operations to entry block and erase other blocks.
|
||||
for (auto &block : llvm::make_early_inc_range(llvm::drop_begin(r, 1))) {
|
||||
entryBlock.splice(--entryBlock.end(), block.getOperations());
|
||||
}
|
||||
for (auto &block : llvm::make_early_inc_range(llvm::drop_begin(r, 1))) {
|
||||
for (Block &block : llvm::make_early_inc_range(llvm::drop_begin(r, 1))) {
|
||||
entryBlockOps.splice(entryBlockOps.end(), block.getOperations());
|
||||
|
||||
block.clear();
|
||||
block.dropAllDefinedValueUses();
|
||||
for (size_t i = 0; i < block.getNumArguments(); i++) {
|
||||
|
@ -243,6 +233,21 @@ void handshake::removeBasicBlocks(Region &r) {
|
|||
}
|
||||
block.erase();
|
||||
}
|
||||
|
||||
// Remove any control flow operations, and move the non-control flow
|
||||
// terminator op to the end of the entry block.
|
||||
for (Operation &terminatorLike : llvm::make_early_inc_range(*entryBlock)) {
|
||||
if (!terminatorLike.hasTrait<OpTrait::IsTerminator>())
|
||||
continue;
|
||||
|
||||
if (isa<mlir::cf::CondBranchOp, mlir::cf::BranchOp>(terminatorLike)) {
|
||||
terminatorLike.erase();
|
||||
continue;
|
||||
}
|
||||
|
||||
// Else, assume that this is a return-like terminator op.
|
||||
terminatorLike.moveBefore(entryBlock, entryBlock->end());
|
||||
}
|
||||
}
|
||||
|
||||
void removeBasicBlocks(handshake::FuncOp funcOp) {
|
||||
|
@ -1694,8 +1699,11 @@ static LogicalResult lowerFuncOp(func::FuncOp funcOp, MLIRContext *ctx,
|
|||
funcOp.getLoc(), funcOp.getName(), func_type, attributes);
|
||||
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
|
||||
newFuncOp.end());
|
||||
if (!newFuncOp.isExternal())
|
||||
if (!newFuncOp.isExternal()) {
|
||||
newFuncOp.getBodyBlock()->addArgument(rewriter.getNoneType(),
|
||||
funcOp.getLoc());
|
||||
newFuncOp.resolveArgAndResNames();
|
||||
}
|
||||
rewriter.eraseOp(funcOp);
|
||||
return success();
|
||||
},
|
||||
|
@ -1706,9 +1714,12 @@ static LogicalResult lowerFuncOp(func::FuncOp funcOp, MLIRContext *ctx,
|
|||
partiallyLowerRegion(maximizeSSANoMem, ctx, newFuncOp.getBody()));
|
||||
|
||||
if (!newFuncOp.isExternal()) {
|
||||
Block *bodyBlock = newFuncOp.getBodyBlock();
|
||||
Value entryCtrl = bodyBlock->getArguments().back();
|
||||
HandshakeLowering fol(newFuncOp.getBody());
|
||||
returnOnError(lowerRegion<func::ReturnOp>(fol, sourceConstants,
|
||||
disableTaskPipelining));
|
||||
if (failed(lowerRegion<func::ReturnOp, handshake::ReturnOp>(
|
||||
fol, sourceConstants, disableTaskPipelining, entryCtrl)))
|
||||
return failure();
|
||||
}
|
||||
|
||||
return success();
|
||||
|
@ -1736,11 +1747,6 @@ struct CFToHandshakePass : public CFToHandshakeBase<CFToHandshakePass> {
|
|||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Legalize the resulting regions, removing basic blocks and performing
|
||||
// any simple conversions.
|
||||
for (auto func : m.getOps<handshake::FuncOp>())
|
||||
removeBasicBlocks(func);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
Loading…
Reference in New Issue