[MSFT] [NFC] Refactor design partition pass to share code (#2416)

Enable code sharing with the future wire cleanup pass. Move utility methods into new super class. Create a "framework" method for updating instances.
This commit is contained in:
John Demme 2022-01-03 16:18:40 -08:00 committed by GitHub
parent 117262aa68
commit 922282ca83
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 84 additions and 37 deletions

View File

@ -275,17 +275,12 @@ std::unique_ptr<Pass> createLowerToHWPass() {
} // namespace circt
namespace {
struct PartitionPass : public PartitionBase<PartitionPass> {
void runOnOperation() override;
private:
struct PassCommon {
protected:
hw::SymbolCache topLevelSyms;
DenseMap<MSFTModuleOp, SmallVector<InstanceOp, 1>> moduleInstantiations;
void partition(MSFTModuleOp mod);
void partition(DesignPartitionOp part, SmallVectorImpl<Operation *> &users);
void bubbleUp(MSFTModuleOp mod, ArrayRef<Operation *> ops);
void populateSymbolCache(ModuleOp topMod);
// Find all the modules and use the partial order of the instantiation DAG
// to sort them. If we use this order when "bubbling" up operations, we
@ -293,17 +288,58 @@ private:
// instantiation sites mapping.
//
// Assumption (unchecked): there is not a cycle in the instantiation graph.
void getAndSortModules(SmallVectorImpl<MSFTModuleOp> &mods);
void getAndSortModules(ModuleOp topMod, SmallVectorImpl<MSFTModuleOp> &mods);
void getAndSortModulesVisitor(MSFTModuleOp mod,
SmallVectorImpl<MSFTModuleOp> &mods,
DenseSet<MSFTModuleOp> &modsSeen);
void updateInstances(
MSFTModuleOp mod, ArrayRef<unsigned> newToOldResultMap,
llvm::function_ref<void(InstanceOp, InstanceOp, SmallVectorImpl<Value> &)>
getOperandsFunc);
};
} // anonymous namespace
/// Update all the instantiations of 'mod' to match the port list. For any
/// output ports which survived, automatically map the result according to
/// `newToOldResultMap`. Calls 'getOperandsFunc' with the new instance op, the
/// old instance op, and expects the operand vector to return filled.
/// `getOperandsFunc` can (and often does) modify other operations. The update
/// call deletes the original instance op, so all references are invalidated
/// after this call.
void PassCommon::updateInstances(
MSFTModuleOp mod, ArrayRef<unsigned> newToOldResultMap,
llvm::function_ref<void(InstanceOp, InstanceOp, SmallVectorImpl<Value> &)>
getOperandsFunc) {
SmallVector<InstanceOp, 1> newInstances;
for (InstanceOp inst : moduleInstantiations[mod]) {
assert(inst->getParentOp());
OpBuilder b(inst);
auto newInst =
b.create<InstanceOp>(inst.getLoc(), mod.getType().getResults(),
inst.getOperands(), inst->getAttrs());
for (size_t portNum = 0, e = newToOldResultMap.size(); portNum < e;
++portNum) {
assert(portNum < newInst.getNumResults());
inst.getResult(newToOldResultMap[portNum])
.replaceAllUsesWith(newInst.getResult(portNum));
}
SmallVector<Value> newOperands;
getOperandsFunc(newInst, inst, newOperands);
newInst->setOperands(newOperands);
newInstances.push_back(newInst);
inst->dropAllUses();
inst->erase();
}
moduleInstantiations[mod].swap(newInstances);
}
// Run a post-order DFS.
void PartitionPass::getAndSortModulesVisitor(
MSFTModuleOp mod, SmallVectorImpl<MSFTModuleOp> &mods,
DenseSet<MSFTModuleOp> &modsSeen) {
void PassCommon::getAndSortModulesVisitor(MSFTModuleOp mod,
SmallVectorImpl<MSFTModuleOp> &mods,
DenseSet<MSFTModuleOp> &modsSeen) {
if (modsSeen.contains(mod))
return;
modsSeen.insert(mod);
@ -320,32 +356,47 @@ void PartitionPass::getAndSortModulesVisitor(
mods.push_back(mod);
}
void PartitionPass::getAndSortModules(SmallVectorImpl<MSFTModuleOp> &mods) {
void PassCommon::getAndSortModules(ModuleOp topMod,
SmallVectorImpl<MSFTModuleOp> &mods) {
// Add here _before_ we go deeper to prevent infinite recursion.
DenseSet<MSFTModuleOp> modsSeen;
getOperation().walk(
mods.clear();
moduleInstantiations.clear();
topMod.walk(
[&](MSFTModuleOp mod) { getAndSortModulesVisitor(mod, mods, modsSeen); });
}
/// Fill a symbol cache with all the top level symbols.
static void populateSymbolCache(mlir::ModuleOp mod, hw::SymbolCache &cache) {
void PassCommon::populateSymbolCache(mlir::ModuleOp mod) {
for (Operation &op : mod.getBody()->getOperations()) {
StringAttr symName = SymbolTable::getSymbolName(&op);
if (!symName)
continue;
// Add the symbol to the cache.
cache.addDefinition(symName, &op);
topLevelSyms.addDefinition(symName, &op);
}
cache.freeze();
topLevelSyms.freeze();
}
namespace {
struct PartitionPass : public PartitionBase<PartitionPass>, PassCommon {
void runOnOperation() override;
private:
void partition(MSFTModuleOp mod);
void partition(DesignPartitionOp part, SmallVectorImpl<Operation *> &users);
void bubbleUp(MSFTModuleOp mod, ArrayRef<Operation *> ops);
};
} // anonymous namespace
void PartitionPass::runOnOperation() {
ModuleOp outerMod = getOperation();
::populateSymbolCache(outerMod, topLevelSyms);
populateSymbolCache(outerMod);
// Get a properly sorted list, then partition the mods in order.
SmallVector<MSFTModuleOp, 64> sortedMods;
getAndSortModules(sortedMods);
getAndSortModules(outerMod, sortedMods);
for (auto mod : sortedMods)
partition(mod);
}
@ -456,6 +507,7 @@ void PartitionPass::bubbleUp(MSFTModuleOp mod, ArrayRef<Operation *> ops) {
// assumes that the order in which the ops, operands, and results are the same
// _every_ time it runs through them. Doing this saves on bookkeeping.
auto *ctxt = mod.getContext();
FunctionType origType = mod.getType();
std::string nameBuffer;
//*************
@ -495,20 +547,16 @@ void PartitionPass::bubbleUp(MSFTModuleOp mod, ArrayRef<Operation *> ops) {
// - Clone in 'ops'.
// - Construct the new instance operands from the old ones + the cloned
// ops' results.
for (InstanceOp inst : moduleInstantiations[mod]) {
OpBuilder b(inst);
SmallVector<unsigned> resValues;
for (size_t i = 0, e = origType.getNumInputs(); i < e; ++i)
resValues.push_back(i);
auto cloneOpsGetOperands = [&](InstanceOp newInst, InstanceOp oldInst,
SmallVectorImpl<Value> &newOperands) {
OpBuilder b(newInst);
// Since we only have to add result types, just copy most everything.
SmallVector<Type, 64> resTypes(inst.getResultTypes());
resTypes.append(newResTypes);
auto newInst = cast<InstanceOp>(b.insert(Operation::create(
OperationState(inst->getLoc(), inst->getName().getStringRef(),
inst->getOperands(), resTypes, inst->getAttrs()))));
size_t resultNum = 0;
for (Value origRes : inst.getResults())
origRes.replaceAllUsesWith(newInst->getResult(resultNum++));
SmallVector<Value, 64> newOperands(inst->getOperands());
size_t resultNum = origType.getNumResults();
auto oldOperands = oldInst->getOperands();
newOperands.append(oldOperands.begin(), oldOperands.end());
for (Operation *op : ops) {
BlockAndValueMapping map;
for (Value oper : op->getOperands())
@ -516,11 +564,10 @@ void PartitionPass::bubbleUp(MSFTModuleOp mod, ArrayRef<Operation *> ops) {
Operation *newOp = b.insert(op->clone(map));
for (Value res : newOp->getResults())
newOperands.push_back(res);
setEntityName(newOp, inst.getName() + "." + ::getOpName(op));
setEntityName(newOp, oldInst.getName() + "." + ::getOpName(op));
}
newInst->setOperands(newOperands);
inst.erase();
}
};
updateInstances(mod, resValues, cloneOpsGetOperands);
//*************
// Done.

View File

@ -25,7 +25,7 @@ msft.module @B {} (%clk : i1) -> (x: i1) {
// CHECK-LABEL: msft.module @top {} (%clk: i1) {
// CHECK: %part1.b.unit1.foo_x, %part1.b.seq.compreg.b.seq.compreg, %part1.b.unit2.foo_x, %part1.unit1.foo_x = msft.instance @part1 @dp(%b.unit1.foo_a, %b.seq.compreg.in0, %b.seq.compreg.in1, %b.unit2.foo_a, %false) : (i1, i1, i1, i1, i1) -> (i1, i1, i1, i1)
// CHECK: %b.x, %b.unit1.foo_a, %b.seq.compreg.in0, %b.seq.compreg.in1, %b.unit2.foo_a = msft.instance @b @B(%clk, %part1.b.unit1.foo_x, %part1.b.seq.compreg.b.seq.compreg, %part1.b.unit2.foo_x) : (i1, i1, i1, i1) -> (i32, i1, i1, i1, i1)
// CHECK: %b.x, %b.unit1.foo_a, %b.seq.compreg.in0, %b.seq.compreg.in1, %b.unit2.foo_a = msft.instance @b @B(%clk, %part1.b.unit1.foo_x, %part1.b.seq.compreg.b.seq.compreg, %part1.b.unit2.foo_x) : (i1, i1, i1, i1) -> (i1, i1, i1, i1, i1)
// CHECK: %false = hw.constant false
// CHECK: msft.output
// CHECK-LABEL: msft.module @B {} (%clk: i1, %unit1.foo_x: i1, %seq.compreg.out0: i1, %unit2.foo_x: i1) -> (x: i1, unit1.foo_a: i1, seq.compreg.in0: i1, seq.compreg.in1: i1, unit2.foo_a: i1) {