[CFToHandshake] Refactor towards genericness (#6156)

A couple of minor refactors - more will come - to allow for using `CFToHandshake` on things other than just `func.func` as the source operation and `handshake.func` as the target operation.
This commit is contained in:
Morten Borup Petersen 2023-09-21 10:24:50 +02:00 committed by GitHub
parent 75c00d78d8
commit 60db6832b4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 60 additions and 44 deletions

View File

@ -69,23 +69,22 @@ public:
LogicalResult addBranchOps(ConversionPatternRewriter &rewriter);
LogicalResult replaceCallOps(ConversionPatternRewriter &rewriter);
template <typename TTerm>
LogicalResult setControlOnlyPath(ConversionPatternRewriter &rewriter) {
template <typename TSrcTerm, typename TDstTerm>
LogicalResult setControlOnlyPath(ConversionPatternRewriter &rewriter,
Value entryCtrl) {
assert(entryCtrl.getType().isa<NoneType>() &&
"Expected NoneType for entry control value");
// Creates start and end points of the control-only path
// Add start point of the control-only path to the entry block's arguments
Block *entryBlock = &r.front();
startCtrl = entryBlock->addArgument(rewriter.getNoneType(),
rewriter.getUnknownLoc());
setBlockEntryControl(entryBlock, startCtrl);
setBlockEntryControl(entryBlock, entryCtrl);
// Replace original return ops with new returns with additional control
// input
for (auto retOp : llvm::make_early_inc_range(r.getOps<TTerm>())) {
for (auto retOp : llvm::make_early_inc_range(r.getOps<TSrcTerm>())) {
rewriter.setInsertionPoint(retOp);
SmallVector<Value, 8> operands(retOp->getOperands());
operands.push_back(startCtrl);
rewriter.replaceOpWithNewOp<handshake::ReturnOp>(retOp, operands);
operands.push_back(entryCtrl);
rewriter.replaceOpWithNewOp<TDstTerm>(retOp, operands);
}
// Store the number of block arguments in each block
@ -96,7 +95,7 @@ public:
// Apply SSA maximization on the newly added entry block argument to
// propagate it explicitly between the start-point of the control-only
// network and the function's terminators
if (failed(maximizeSSA(startCtrl, rewriter)))
if (failed(maximizeSSA(entryCtrl, rewriter)))
return failure();
// Identify all block arguments belonging to the control-only network
@ -170,13 +169,19 @@ LogicalResult runPartialLowering(
instance.getContext(), instance.getRegion());
}
/// Remove basic blocks inside the given region. This allows the result to be
/// a valid graph region, since multi-basic block regions are not allowed to
/// be graph regions currently.
void removeBasicBlocks(Region &r);
// Helper to check the validity of the dataflow conversion
// Driver that applies the partial lowerings expressed in HandshakeLowering to
// the region encapsulated in it. The region is assumed to have a terminator of
// type TTerm. See HandshakeLowering for the different lowering steps.
template <typename TTerm>
// type TSrcTerm, and will replace it with TDstTerm. See HandshakeLowering for
// the different lowering steps.
template <typename TSrcTerm, typename TDstTerm>
LogicalResult lowerRegion(HandshakeLowering &hl, bool sourceConstants,
bool disableTaskPipelining) {
bool disableTaskPipelining, Value entryCtrl) {
// Perform initial dataflow conversion. This process allows for the use of
// non-deterministic merge-like operations.
HandshakeLowering::MemRefToMemoryAccessOp memOps;
@ -184,8 +189,9 @@ LogicalResult lowerRegion(HandshakeLowering &hl, bool sourceConstants,
if (failed(
runPartialLowering(hl, &HandshakeLowering::replaceMemoryOps, memOps)))
return failure();
if (failed(runPartialLowering(hl,
&HandshakeLowering::setControlOnlyPath<TTerm>)))
if (failed(runPartialLowering(
hl, &HandshakeLowering::setControlOnlyPath<TSrcTerm, TDstTerm>,
entryCtrl)))
return failure();
if (failed(runPartialLowering(hl, &HandshakeLowering::addMergeOps)))
return failure();
@ -215,14 +221,13 @@ LogicalResult lowerRegion(HandshakeLowering &hl, bool sourceConstants,
lsq)))
return failure();
// Legalize the resulting regions, removing basic blocks and performing
// any simple conversions.
removeBasicBlocks(hl.getRegion());
return success();
}
/// Remove basic blocks inside the given region. This allows the result to be
/// a valid graph region, since multi-basic block regions are not allowed to
/// be graph regions currently.
void removeBasicBlocks(Region &r);
/// Lowers the mlir operations into handshake that are not part of the dataflow
/// conversion.
LogicalResult postDataflowConvert(Operation *op);

View File

@ -120,6 +120,11 @@ def FuncOp : Op<Handshake_Dialect, "func", [
return success();
}
/// Returns the body block of the function.
Block* getBodyBlock() {
return &getBody().front();
}
//===------------------------------------------------------------------===//
// OpAsmOpInterface Methods
//===------------------------------------------------------------------===//

View File

@ -219,23 +219,13 @@ void HandshakeLowering::setBlockEntryControl(Block *block, Value v) {
}
void handshake::removeBasicBlocks(Region &r) {
auto &entryBlock = r.front().getOperations();
// Now that basic blocks are going to be removed, we can erase all cf-dialect
// branches, and move ReturnOp to the entry block's end
for (auto &block : r) {
Operation &termOp = block.back();
if (isa<mlir::cf::CondBranchOp, mlir::cf::BranchOp>(termOp))
termOp.erase();
else if (isa<handshake::ReturnOp>(termOp))
entryBlock.splice(entryBlock.end(), block.getOperations(), termOp);
}
Block *entryBlock = &r.front();
auto &entryBlockOps = entryBlock->getOperations();
// Move all operations to entry block and erase other blocks.
for (auto &block : llvm::make_early_inc_range(llvm::drop_begin(r, 1))) {
entryBlock.splice(--entryBlock.end(), block.getOperations());
}
for (auto &block : llvm::make_early_inc_range(llvm::drop_begin(r, 1))) {
for (Block &block : llvm::make_early_inc_range(llvm::drop_begin(r, 1))) {
entryBlockOps.splice(entryBlockOps.end(), block.getOperations());
block.clear();
block.dropAllDefinedValueUses();
for (size_t i = 0; i < block.getNumArguments(); i++) {
@ -243,6 +233,21 @@ void handshake::removeBasicBlocks(Region &r) {
}
block.erase();
}
// Remove any control flow operations, and move the non-control flow
// terminator op to the end of the entry block.
for (Operation &terminatorLike : llvm::make_early_inc_range(*entryBlock)) {
if (!terminatorLike.hasTrait<OpTrait::IsTerminator>())
continue;
if (isa<mlir::cf::CondBranchOp, mlir::cf::BranchOp>(terminatorLike)) {
terminatorLike.erase();
continue;
}
// Else, assume that this is a return-like terminator op.
terminatorLike.moveBefore(entryBlock, entryBlock->end());
}
}
void removeBasicBlocks(handshake::FuncOp funcOp) {
@ -1694,8 +1699,11 @@ static LogicalResult lowerFuncOp(func::FuncOp funcOp, MLIRContext *ctx,
funcOp.getLoc(), funcOp.getName(), func_type, attributes);
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
newFuncOp.end());
if (!newFuncOp.isExternal())
if (!newFuncOp.isExternal()) {
newFuncOp.getBodyBlock()->addArgument(rewriter.getNoneType(),
funcOp.getLoc());
newFuncOp.resolveArgAndResNames();
}
rewriter.eraseOp(funcOp);
return success();
},
@ -1706,9 +1714,12 @@ static LogicalResult lowerFuncOp(func::FuncOp funcOp, MLIRContext *ctx,
partiallyLowerRegion(maximizeSSANoMem, ctx, newFuncOp.getBody()));
if (!newFuncOp.isExternal()) {
Block *bodyBlock = newFuncOp.getBodyBlock();
Value entryCtrl = bodyBlock->getArguments().back();
HandshakeLowering fol(newFuncOp.getBody());
returnOnError(lowerRegion<func::ReturnOp>(fol, sourceConstants,
disableTaskPipelining));
if (failed(lowerRegion<func::ReturnOp, handshake::ReturnOp>(
fol, sourceConstants, disableTaskPipelining, entryCtrl)))
return failure();
}
return success();
@ -1736,11 +1747,6 @@ struct CFToHandshakePass : public CFToHandshakeBase<CFToHandshakePass> {
return;
}
}
// Legalize the resulting regions, removing basic blocks and performing
// any simple conversions.
for (auto func : m.getOps<handshake::FuncOp>())
removeBasicBlocks(func);
}
};