From 60db6832b4c2cc2c2d83f279cb93dd47daa11e10 Mon Sep 17 00:00:00 2001 From: Morten Borup Petersen Date: Thu, 21 Sep 2023 10:24:50 +0200 Subject: [PATCH] [CFToHandshake] Refactor towards genericness (#6156) A couple of minor refactors - more will come - to allow for using `CFToHandshake` on things other than just `func.func` as the source operation and `handshake.func` as the target operation. --- include/circt/Conversion/CFToHandshake.h | 47 +++++++++-------- .../circt/Dialect/Handshake/HandshakeOps.td | 5 ++ .../CFToHandshake/CFToHandshake.cpp | 52 +++++++++++-------- 3 files changed, 60 insertions(+), 44 deletions(-) diff --git a/include/circt/Conversion/CFToHandshake.h b/include/circt/Conversion/CFToHandshake.h index d6bd53d68f..85fcb198b3 100644 --- a/include/circt/Conversion/CFToHandshake.h +++ b/include/circt/Conversion/CFToHandshake.h @@ -69,23 +69,22 @@ public: LogicalResult addBranchOps(ConversionPatternRewriter &rewriter); LogicalResult replaceCallOps(ConversionPatternRewriter &rewriter); - template - LogicalResult setControlOnlyPath(ConversionPatternRewriter &rewriter) { + template + LogicalResult setControlOnlyPath(ConversionPatternRewriter &rewriter, + Value entryCtrl) { + assert(entryCtrl.getType().isa() && + "Expected NoneType for entry control value"); // Creates start and end points of the control-only path - - // Add start point of the control-only path to the entry block's arguments Block *entryBlock = &r.front(); - startCtrl = entryBlock->addArgument(rewriter.getNoneType(), - rewriter.getUnknownLoc()); - setBlockEntryControl(entryBlock, startCtrl); + setBlockEntryControl(entryBlock, entryCtrl); // Replace original return ops with new returns with additional control // input - for (auto retOp : llvm::make_early_inc_range(r.getOps())) { + for (auto retOp : llvm::make_early_inc_range(r.getOps())) { rewriter.setInsertionPoint(retOp); SmallVector operands(retOp->getOperands()); - operands.push_back(startCtrl); - rewriter.replaceOpWithNewOp(retOp, operands); + operands.push_back(entryCtrl); + rewriter.replaceOpWithNewOp(retOp, operands); } // Store the number of block arguments in each block @@ -96,7 +95,7 @@ public: // Apply SSA maximization on the newly added entry block argument to // propagate it explicitly between the start-point of the control-only // network and the function's terminators - if (failed(maximizeSSA(startCtrl, rewriter))) + if (failed(maximizeSSA(entryCtrl, rewriter))) return failure(); // Identify all block arguments belonging to the control-only network @@ -170,13 +169,19 @@ LogicalResult runPartialLowering( instance.getContext(), instance.getRegion()); } +/// Remove basic blocks inside the given region. This allows the result to be +/// a valid graph region, since multi-basic block regions are not allowed to +/// be graph regions currently. +void removeBasicBlocks(Region &r); + // Helper to check the validity of the dataflow conversion // Driver that applies the partial lowerings expressed in HandshakeLowering to // the region encapsulated in it. The region is assumed to have a terminator of -// type TTerm. See HandshakeLowering for the different lowering steps. -template +// type TSrcTerm, and will replace it with TDstTerm. See HandshakeLowering for +// the different lowering steps. +template LogicalResult lowerRegion(HandshakeLowering &hl, bool sourceConstants, - bool disableTaskPipelining) { + bool disableTaskPipelining, Value entryCtrl) { // Perform initial dataflow conversion. This process allows for the use of // non-deterministic merge-like operations. HandshakeLowering::MemRefToMemoryAccessOp memOps; @@ -184,8 +189,9 @@ LogicalResult lowerRegion(HandshakeLowering &hl, bool sourceConstants, if (failed( runPartialLowering(hl, &HandshakeLowering::replaceMemoryOps, memOps))) return failure(); - if (failed(runPartialLowering(hl, - &HandshakeLowering::setControlOnlyPath))) + if (failed(runPartialLowering( + hl, &HandshakeLowering::setControlOnlyPath, + entryCtrl))) return failure(); if (failed(runPartialLowering(hl, &HandshakeLowering::addMergeOps))) return failure(); @@ -215,14 +221,13 @@ LogicalResult lowerRegion(HandshakeLowering &hl, bool sourceConstants, lsq))) return failure(); + // Legalize the resulting regions, removing basic blocks and performing + // any simple conversions. + removeBasicBlocks(hl.getRegion()); + return success(); } -/// Remove basic blocks inside the given region. This allows the result to be -/// a valid graph region, since multi-basic block regions are not allowed to -/// be graph regions currently. -void removeBasicBlocks(Region &r); - /// Lowers the mlir operations into handshake that are not part of the dataflow /// conversion. LogicalResult postDataflowConvert(Operation *op); diff --git a/include/circt/Dialect/Handshake/HandshakeOps.td b/include/circt/Dialect/Handshake/HandshakeOps.td index 397f0324b1..5e280e8dbf 100644 --- a/include/circt/Dialect/Handshake/HandshakeOps.td +++ b/include/circt/Dialect/Handshake/HandshakeOps.td @@ -120,6 +120,11 @@ def FuncOp : Op(termOp)) - termOp.erase(); - else if (isa(termOp)) - entryBlock.splice(entryBlock.end(), block.getOperations(), termOp); - } + Block *entryBlock = &r.front(); + auto &entryBlockOps = entryBlock->getOperations(); // Move all operations to entry block and erase other blocks. - for (auto &block : llvm::make_early_inc_range(llvm::drop_begin(r, 1))) { - entryBlock.splice(--entryBlock.end(), block.getOperations()); - } - for (auto &block : llvm::make_early_inc_range(llvm::drop_begin(r, 1))) { + for (Block &block : llvm::make_early_inc_range(llvm::drop_begin(r, 1))) { + entryBlockOps.splice(entryBlockOps.end(), block.getOperations()); + block.clear(); block.dropAllDefinedValueUses(); for (size_t i = 0; i < block.getNumArguments(); i++) { @@ -243,6 +233,21 @@ void handshake::removeBasicBlocks(Region &r) { } block.erase(); } + + // Remove any control flow operations, and move the non-control flow + // terminator op to the end of the entry block. + for (Operation &terminatorLike : llvm::make_early_inc_range(*entryBlock)) { + if (!terminatorLike.hasTrait()) + continue; + + if (isa(terminatorLike)) { + terminatorLike.erase(); + continue; + } + + // Else, assume that this is a return-like terminator op. + terminatorLike.moveBefore(entryBlock, entryBlock->end()); + } } void removeBasicBlocks(handshake::FuncOp funcOp) { @@ -1694,8 +1699,11 @@ static LogicalResult lowerFuncOp(func::FuncOp funcOp, MLIRContext *ctx, funcOp.getLoc(), funcOp.getName(), func_type, attributes); rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), newFuncOp.end()); - if (!newFuncOp.isExternal()) + if (!newFuncOp.isExternal()) { + newFuncOp.getBodyBlock()->addArgument(rewriter.getNoneType(), + funcOp.getLoc()); newFuncOp.resolveArgAndResNames(); + } rewriter.eraseOp(funcOp); return success(); }, @@ -1706,9 +1714,12 @@ static LogicalResult lowerFuncOp(func::FuncOp funcOp, MLIRContext *ctx, partiallyLowerRegion(maximizeSSANoMem, ctx, newFuncOp.getBody())); if (!newFuncOp.isExternal()) { + Block *bodyBlock = newFuncOp.getBodyBlock(); + Value entryCtrl = bodyBlock->getArguments().back(); HandshakeLowering fol(newFuncOp.getBody()); - returnOnError(lowerRegion(fol, sourceConstants, - disableTaskPipelining)); + if (failed(lowerRegion( + fol, sourceConstants, disableTaskPipelining, entryCtrl))) + return failure(); } return success(); @@ -1736,11 +1747,6 @@ struct CFToHandshakePass : public CFToHandshakeBase { return; } } - - // Legalize the resulting regions, removing basic blocks and performing - // any simple conversions. - for (auto func : m.getOps()) - removeBasicBlocks(func); } };