[StandardToHandshake] Initial commit for stateful std-to-hs lowering

This has been a long-standing to-do for me; the intention of this commit is to start removing lowering state from within the IR itself, and make it explicitly managed. In the long term, this intends to remove all of the "helper values" and assumptions on temporary values being placed at fixed indexes for certain operations - something which i don't see as maintainable nor extensible.

One of the first issues that is tackled through this is the managing of the "block entry control value". Previously, this was assumed to be any `control_merge` operation that existed within a basic block (an assumption which is no longer valid in some stuff i'm working on locally).
This commit is contained in:
Morten Borup Petersen 2022-01-09 10:11:19 +01:00
parent f40cf9418a
commit 38cab756ab
1 changed files with 148 additions and 75 deletions

View File

@ -47,6 +47,67 @@ using namespace std;
typedef DenseMap<Block *, std::vector<Value>> BlockValues;
typedef DenseMap<Block *, std::vector<Operation *>> BlockOps;
typedef DenseMap<Value, Operation *> blockArgPairs;
typedef llvm::MapVector<Value, std::vector<Operation *>> MemRefToMemoryAccessOp;
// A class for maintaining state between all the partial handshake lowering
// passes. The std-to-handshake pass was initially written in such a way that
// all temporary state was added to the IR itself. Consequences of this are that
// the IR is in an inconsistent (non-valid) state during lowering, as well as a
// lot of assumptions being put on the IR itself for tracking things; i.e.
// control values for blocks were always assumed to be any control_merge within
// that block (which may not be the case).
// Through this class, state is made explicit, which should hopefully aid in the
// debugability and extensibility of this pass.
class FuncOpLowering {
public:
explicit FuncOpLowering(handshake::FuncOp f) : f(f) {}
LogicalResult addMergeOps(ConversionPatternRewriter &rewriter);
LogicalResult addBranchOps(ConversionPatternRewriter &rewriter);
LogicalResult replaceCallOps(ConversionPatternRewriter &rewriter);
LogicalResult setControlOnlyPath(ConversionPatternRewriter &rewriter);
LogicalResult connectConstantsToControl(ConversionPatternRewriter &rewriter,
bool sourceConstants);
BlockOps insertMergeOps(BlockValues blockLiveIns, blockArgPairs &mergePairs,
ConversionPatternRewriter &rewriter);
Operation *insertMerge(Block *block, Value val,
ConversionPatternRewriter &rewriter);
// Replaces standard memory ops with their handshake version (i.e.,
// ops which connect to memory/LSQ). Returns a map with an ordered
// list of new ops corresponding to each memref. Later, we instantiate
// a memory node for each memref and connect it to its load/store ops
LogicalResult replaceMemoryOps(ConversionPatternRewriter &rewriter,
MemRefToMemoryAccessOp &memRefOps);
LogicalResult connectToMemory(MemRefToMemoryAccessOp memRefOps, bool lsq,
ConversionPatternRewriter &rewriter);
void setMemOpControlInputs(ConversionPatternRewriter &rewriter,
ArrayRef<Operation *> memOps, Operation *memOp,
int offset, ArrayRef<int> cntrlInd);
LogicalResult finalize(ConversionPatternRewriter &rewriter,
TypeRange argTypes, TypeRange resTypes,
mlir::FuncOp origFunc);
private:
// Returns the entry control value for operations contained within this block.
Value getBlockEntryControl(Block *block);
void setBlockEntryControl(Block *block, Value v);
DenseMap<Block *, Value> blockEntryControlMap;
handshake::FuncOp f;
};
Value FuncOpLowering::getBlockEntryControl(Block *block) {
auto it = blockEntryControlMap.find(block);
assert(it != blockEntryControlMap.end() &&
"No block entry control value registerred for this block!");
return it->second;
}
void FuncOpLowering::setBlockEntryControl(Block *block, Value v) {
blockEntryControlMap[block] = v;
}
/// Remove basic blocks inside the given FuncOp. This allows the result to be
/// a valid graph region, since multi-basic block regions are not allowed to
@ -77,14 +138,15 @@ void removeBasicBlocks(handshake::FuncOp funcOp) {
}
}
LogicalResult setControlOnlyPath(handshake::FuncOp f,
ConversionPatternRewriter &rewriter) {
LogicalResult
FuncOpLowering::setControlOnlyPath(ConversionPatternRewriter &rewriter) {
// Creates start and end points of the control-only path
// Temporary start node (removed in later steps) in entry block
Block *entryBlock = &f.front();
rewriter.setInsertionPointToStart(entryBlock);
Operation *startOp = rewriter.create<StartOp>(entryBlock->front().getLoc());
setBlockEntryControl(entryBlock, startOp->getResult(0));
// Replace original return ops with new returns with additional control input
for (auto retOp : llvm::make_early_inc_range(f.getOps<mlir::ReturnOp>())) {
@ -222,15 +284,17 @@ unsigned getBlockPredecessorCount(Block *block) {
// Insert appropriate type of Merge CMerge for control-only path,
// Merge for single-successor blocks, Mux otherwise
Operation *insertMerge(Block *block, Value val,
Operation *FuncOpLowering::insertMerge(Block *block, Value val,
ConversionPatternRewriter &rewriter) {
unsigned numPredecessors = getBlockPredecessorCount(block);
// Control-only path originates from StartOp
if (!val.isa<BlockArgument>()) {
if (isa<StartOp>(val.getDefiningOp())) {
return rewriter.create<handshake::ControlMergeOp>(block->front().getLoc(),
val, numPredecessors);
auto cmerge = rewriter.create<handshake::ControlMergeOp>(
block->front().getLoc(), val, numPredecessors);
setBlockEntryControl(block, cmerge.result());
return cmerge;
}
}
@ -249,7 +313,7 @@ Operation *insertMerge(Block *block, Value val,
// Adds Merge for every live-in and block argument
// Returns DenseMap of all inserted operations
BlockOps insertMergeOps(handshake::FuncOp f, BlockValues blockLiveIns,
BlockOps FuncOpLowering::insertMergeOps(BlockValues blockLiveIns,
blockArgPairs &mergePairs,
ConversionPatternRewriter &rewriter) {
BlockOps blockMerges;
@ -351,8 +415,6 @@ Operation *getControlMerge(Block *block) {
return getFirstOp<ControlMergeOp>(block);
}
Operation *getStartOp(Block *block) { return getFirstOp<StartOp>(block); }
void reconnectMergeOps(handshake::FuncOp f, BlockOps blockMerges,
blockArgPairs &mergePairs) {
// All merge operands are initially set to original (defining) value
@ -407,8 +469,7 @@ void reconnectMergeOps(handshake::FuncOp f, BlockOps blockMerges,
removeBlockOperands(f);
}
LogicalResult addMergeOps(handshake::FuncOp f,
ConversionPatternRewriter &rewriter) {
LogicalResult FuncOpLowering::addMergeOps(ConversionPatternRewriter &rewriter) {
blockArgPairs mergePairs;
@ -416,7 +477,7 @@ LogicalResult addMergeOps(handshake::FuncOp f,
BlockValues liveIns = livenessAnalysis(f);
// Insert merge operations
BlockOps mergeOps = insertMergeOps(f, liveIns, mergePairs, rewriter);
BlockOps mergeOps = insertMergeOps(liveIns, mergePairs, rewriter);
// Set merge operands and uses
reconnectMergeOps(f, mergeOps, mergePairs);
@ -465,8 +526,8 @@ Value getSuccResult(Operation *termOp, Operation *newOp, Block *succBlock) {
return newOp->getResult(0);
}
LogicalResult addBranchOps(handshake::FuncOp f,
ConversionPatternRewriter &rewriter) {
LogicalResult
FuncOpLowering::addBranchOps(ConversionPatternRewriter &rewriter) {
BlockValues liveOuts;
@ -535,8 +596,8 @@ LogicalResult addBranchOps(handshake::FuncOp f,
return success();
}
LogicalResult connectConstantsToControl(handshake::FuncOp f,
ConversionPatternRewriter &rewriter,
LogicalResult
FuncOpLowering::connectConstantsToControl(ConversionPatternRewriter &rewriter,
bool sourceConstants) {
// Create new constants which have a control-only input to trigger them. These
// are conneted to the control network or optionally to a Source operation
@ -553,14 +614,12 @@ LogicalResult connectConstantsToControl(handshake::FuncOp f,
}
} else {
for (Block &block : f) {
Operation *cntrlMg =
block.isEntryBlock() ? getStartOp(&block) : getControlMerge(&block);
assert(cntrlMg != nullptr && "No control operation found in block");
Value blockEntryCtrl = getBlockEntryControl(&block);
for (auto constantOp :
llvm::make_early_inc_range(block.getOps<arith::ConstantOp>())) {
rewriter.setInsertionPointAfter(constantOp);
rewriter.replaceOpWithNewOp<handshake::ConstantOp>(
constantOp, constantOp.value(), cntrlMg->getResult(0));
constantOp, constantOp.value(), blockEntryCtrl);
}
}
}
@ -704,14 +763,8 @@ bool isMemoryOp(Operation *op) {
mlir::AffineWriteOpInterface>(op);
}
typedef llvm::MapVector<Value, std::vector<Operation *>> MemRefToMemoryAccessOp;
// Replaces standard memory ops with their handshake version (i.e.,
// ops which connect to memory/LSQ). Returns a map with an ordered
// list of new ops corresponding to each memref. Later, we instantiate
// a memory node for each memref and connect it to its load/store ops
LogicalResult replaceMemoryOps(handshake::FuncOp f,
ConversionPatternRewriter &rewriter,
LogicalResult
FuncOpLowering::replaceMemoryOps(ConversionPatternRewriter &rewriter,
MemRefToMemoryAccessOp &memRefOps) {
std::vector<Operation *> opsToErase;
@ -965,18 +1018,18 @@ LogicalResult setJoinControlInputs(ArrayRef<Operation *> memOps,
return success();
}
void setMemOpControlInputs(ConversionPatternRewriter &rewriter,
ArrayRef<Operation *> memOps, Operation *memOp,
int offset, ArrayRef<int> cntrlInd) {
void FuncOpLowering::setMemOpControlInputs(ConversionPatternRewriter &rewriter,
ArrayRef<Operation *> memOps,
Operation *memOp, int offset,
ArrayRef<int> cntrlInd) {
for (int i = 0, e = memOps.size(); i < e; ++i) {
std::vector<Value> controlOperands;
Operation *currOp = memOps[i];
Block *currBlock = currOp->getBlock();
// Set load/store control inputs from control merge
Operation *cntrlMg = currBlock->isEntryBlock() ? getStartOp(currBlock)
: getControlMerge(currBlock);
controlOperands.push_back(cntrlMg->getResult(0));
// Set load/store control inputs from the block input control value
Value blockEntryCtrl = getBlockEntryControl(currBlock);
controlOperands.push_back(blockEntryCtrl);
// Set load/store control inputs from predecessors in block
for (int j = 0, f = i; j < f; ++j) {
@ -1003,8 +1056,8 @@ void setMemOpControlInputs(ConversionPatternRewriter &rewriter,
}
}
LogicalResult connectToMemory(handshake::FuncOp f,
MemRefToMemoryAccessOp memRefOps, bool lsq,
LogicalResult
FuncOpLowering::connectToMemory(MemRefToMemoryAccessOp memRefOps, bool lsq,
ConversionPatternRewriter &rewriter) {
// Add MemoryOps which represent the memory interface
// Connect memory operations and control appropriately
@ -1266,19 +1319,17 @@ struct HandshakeCanonicalizePattern : public ConversionPattern {
}
};
LogicalResult replaceCallOps(handshake::FuncOp f,
ConversionPatternRewriter &rewriter) {
LogicalResult
FuncOpLowering::replaceCallOps(ConversionPatternRewriter &rewriter) {
for (Block &block : f) {
/// An instance is activated whenever control arrives at the basic block of
/// the source callOp.
Operation *cntrlMg =
block.isEntryBlock() ? getStartOp(&block) : getControlMerge(&block);
assert(cntrlMg);
Value blockEntryControl = getBlockEntryControl(&block);
for (Operation &op : block) {
if (auto callOp = dyn_cast<CallOp>(op)) {
llvm::SmallVector<Value> operands;
llvm::copy(callOp.getOperands(), std::back_inserter(operands));
operands.push_back(cntrlMg->getResult(0));
operands.push_back(blockEntryControl);
rewriter.setInsertionPoint(callOp);
auto instanceOp = rewriter.create<handshake::InstanceOp>(
callOp.getLoc(), callOp.getCallee(), callOp.getResultTypes(),
@ -1293,6 +1344,28 @@ LogicalResult replaceCallOps(handshake::FuncOp f,
return success();
}
LogicalResult FuncOpLowering::finalize(ConversionPatternRewriter &rewriter,
TypeRange argTypes, TypeRange resTypes,
mlir::FuncOp origFunc) {
SmallVector<Type> newArgTypes(argTypes);
newArgTypes.push_back(rewriter.getNoneType());
auto funcType = rewriter.getFunctionType(newArgTypes, resTypes);
f.setType(funcType);
auto ctrlArg = f.front().addArgument(rewriter.getNoneType());
// We've now added all types to the handshake.funcOp; resolve arg- and
// res names to ensure they are up to date with the final type
// signature.
f.resolveArgAndResNames();
Operation *startOp = findStartOp(&f.getRegion());
startOp->getResult(0).replaceAllUsesWith(ctrlArg);
rewriter.eraseOp(startOp);
rewriter.eraseOp(origFunc);
return success();
}
#define returnOnError(logicalResult) \
if (failed(logicalResult)) \
return logicalResult;
@ -1337,28 +1410,42 @@ LogicalResult lowerFuncOp(mlir::FuncOp funcOp, MLIRContext *ctx,
},
ctx, funcOp));
FuncOpLowering fol(newFuncOp);
// Perform dataflow conversion
MemRefToMemoryAccessOp memOps;
returnOnError(partiallyLowerFuncOp<handshake::FuncOp>(
[&](handshake::FuncOp nfo, ConversionPatternRewriter &rewriter) {
// Map from original memref to new load/store operations.
return replaceMemoryOps(nfo, rewriter, memOps);
return fol.replaceMemoryOps(rewriter, memOps);
},
ctx, newFuncOp));
returnOnError(partiallyLowerFuncOp<handshake::FuncOp>(setControlOnlyPath, ctx,
newFuncOp));
returnOnError(
partiallyLowerFuncOp<handshake::FuncOp>(addMergeOps, ctx, newFuncOp));
returnOnError(
partiallyLowerFuncOp<handshake::FuncOp>(replaceCallOps, ctx, newFuncOp));
returnOnError(
partiallyLowerFuncOp<handshake::FuncOp>(addBranchOps, ctx, newFuncOp));
returnOnError(partiallyLowerFuncOp<handshake::FuncOp>(
[&](handshake::FuncOp nfo, ConversionPatternRewriter &rewriter) {
return fol.setControlOnlyPath(rewriter);
},
ctx, newFuncOp));
returnOnError(partiallyLowerFuncOp<handshake::FuncOp>(
[&](handshake::FuncOp nfo, ConversionPatternRewriter &rewriter) {
return fol.addMergeOps(rewriter);
},
ctx, newFuncOp));
returnOnError(partiallyLowerFuncOp<handshake::FuncOp>(
[&](handshake::FuncOp nfo, ConversionPatternRewriter &rewriter) {
return fol.replaceCallOps(rewriter);
},
ctx, newFuncOp));
returnOnError(partiallyLowerFuncOp<handshake::FuncOp>(
[&](handshake::FuncOp nfo, ConversionPatternRewriter &rewriter) {
return fol.addBranchOps(rewriter);
},
ctx, newFuncOp));
returnOnError(
partiallyLowerFuncOp<handshake::FuncOp>(addSinkOps, ctx, newFuncOp));
returnOnError(partiallyLowerFuncOp<handshake::FuncOp>(
[&](handshake::FuncOp f, ConversionPatternRewriter &rewriter) {
return connectConstantsToControl(f, rewriter, sourceConstants);
return fol.connectConstantsToControl(rewriter, sourceConstants);
},
ctx, newFuncOp));
returnOnError(
@ -1368,7 +1455,7 @@ LogicalResult lowerFuncOp(mlir::FuncOp funcOp, MLIRContext *ctx,
bool lsq = false;
returnOnError(partiallyLowerFuncOp<handshake::FuncOp>(
[&](handshake::FuncOp nfo, ConversionPatternRewriter &rewriter) {
return connectToMemory(nfo, memOps, lsq, rewriter);
return fol.connectToMemory(memOps, lsq, rewriter);
},
ctx, newFuncOp));
@ -1376,22 +1463,8 @@ LogicalResult lowerFuncOp(mlir::FuncOp funcOp, MLIRContext *ctx,
// temporary handshake::StartOp operation, and finally remove the start
// op.
returnOnError(partiallyLowerFuncOp<handshake::FuncOp>(
[&](handshake::FuncOp nfo, PatternRewriter &rewriter) {
argTypes.push_back(rewriter.getNoneType());
auto funcType = rewriter.getFunctionType(argTypes, resTypes);
nfo.setType(funcType);
auto ctrlArg = nfo.front().addArgument(rewriter.getNoneType());
// We've now added all types to the handshake.funcOp; resolve arg- and
// res names to ensure they are up to date with the final type
// signature.
nfo.resolveArgAndResNames();
Operation *startOp = findStartOp(&nfo.getRegion());
startOp->getResult(0).replaceAllUsesWith(ctrlArg);
rewriter.eraseOp(startOp);
rewriter.eraseOp(funcOp);
return success();
[&](handshake::FuncOp nfo, ConversionPatternRewriter &rewriter) {
return fol.finalize(rewriter, argTypes, resTypes, funcOp);
},
ctx, newFuncOp));