From bd3a6499f716254a4bb8dad10873bf9336995037 Mon Sep 17 00:00:00 2001 From: John Demme Date: Tue, 4 Jan 2022 15:12:43 -0800 Subject: [PATCH] [MSFT] Wire cleanup pass (#2410) Another step towards #2365. This pass 'bubbles up' wires which are merely pass throughs in a given module. It then 'sinks down' wires which are looped back in the instantiation. Together with the entity movement piece, this effectively moves the wires as well. Does not handle wire manipulation operations, which also need to be moved/copied. --- include/circt/Dialect/MSFT/MSFTOps.td | 9 ++ include/circt/Dialect/MSFT/MSFTPasses.td | 6 + lib/Dialect/MSFT/MSFTOps.cpp | 31 +++++ lib/Dialect/MSFT/MSFTPasses.cpp | 161 +++++++++++++++++++++++ test/Dialect/MSFT/partition.mlir | 70 ++++++---- 5 files changed, 251 insertions(+), 26 deletions(-) diff --git a/include/circt/Dialect/MSFT/MSFTOps.td b/include/circt/Dialect/MSFT/MSFTOps.td index b342adad9e..d44301e610 100644 --- a/include/circt/Dialect/MSFT/MSFTOps.td +++ b/include/circt/Dialect/MSFT/MSFTOps.td @@ -42,6 +42,10 @@ def InstanceOp : MSFTOp<"instance", [ StringAttr instanceNameAttr() { return sym_nameAttr(); } + + // Update the results. + InstanceOp getWithNewResults(MSFTModuleOp mod, + ArrayRef newToOldMap); }]; /// sym keyword for optional symbol simplifies parsing @@ -97,6 +101,11 @@ def MSFTModuleOp : MSFTOp<"module", ArrayRef> inputs, ArrayRef> outputs); + // Remove the ports at the specified indexes. Returns the new to old result + // mapping. + SmallVector + removePorts(ArrayRef inputs, ArrayRef outputs); + // Get the module's symbolic name as StringAttr. 
StringAttr getNameAttr() {
       return (*this)->getAttrOfType<StringAttr>(
diff --git a/include/circt/Dialect/MSFT/MSFTPasses.td b/include/circt/Dialect/MSFT/MSFTPasses.td
index 0e08dc68b5..853de293be 100644
--- a/include/circt/Dialect/MSFT/MSFTPasses.td
+++ b/include/circt/Dialect/MSFT/MSFTPasses.td
@@ -26,3 +26,9 @@ def Partition: Pass<"msft-partition", "mlir::ModuleOp"> {
   let constructor = "circt::msft::createPartitionPass()";
   let dependentDialects = ["circt::hw::HWDialect"];
 }
+
+def WireCleanup: Pass<"msft-wire-cleanup", "mlir::ModuleOp"> {
+  let summary = "Cleanup unnecessary ports and wires";
+  let constructor = "circt::msft::createWireCleanupPass()";
+  let dependentDialects = [];
+}
diff --git a/lib/Dialect/MSFT/MSFTOps.cpp b/lib/Dialect/MSFT/MSFTOps.cpp
index b808bce1df..05e875a697 100644
--- a/lib/Dialect/MSFT/MSFTOps.cpp
+++ b/lib/Dialect/MSFT/MSFTOps.cpp
@@ -132,6 +132,37 @@ MSFTModuleOp::addPorts(ArrayRef<std::pair<StringAttr, Type>> inputs,
   return newBlockArgs;
 }
 
+// Remove the ports at the specified indexes.
+SmallVector<unsigned> MSFTModuleOp::removePorts(ArrayRef<unsigned> inputs,
+                                                ArrayRef<unsigned> outputs) {
+  FunctionType ftype = getType();
+  Block *body = getBodyBlock();
+  Operation *terminator = body->getTerminator();
+
+  setType(ftype.getWithoutArgsAndResults(inputs, outputs));
+
+  // Build new operand list for output op. Construct an output mapping to return
+  // as a side-effect.
+  unsigned numResults = ftype.getNumResults();
+  llvm::BitVector skipOutputs(numResults);
+  SmallVector<Value> newOutputValues;
+  SmallVector<unsigned> newToOldResultMap;
+  for (unsigned i : outputs)
+    skipOutputs.set(i);
+  for (unsigned i = 0; i < numResults; ++i)
+    if (!skipOutputs.test(i)) {
+      newOutputValues.push_back(terminator->getOperand(i));
+      newToOldResultMap.push_back(i);
+    }
+  terminator->setOperands(newOutputValues);
+
+  // Erase the arguments after setting the new output op operands since the
+  // arguments might be used by output op.
+  body->eraseArguments(inputs);
+
+  return newToOldResultMap;
+}
+
 // Copied nearly exactly from hwops.cpp.
 // TODO: Unify code once a `ModuleLike` op interface exists.
 static void buildModule(OpBuilder &builder, OperationState &result,
diff --git a/lib/Dialect/MSFT/MSFTPasses.cpp b/lib/Dialect/MSFT/MSFTPasses.cpp
index 78d29dd194..b2d6fb80a6 100644
--- a/lib/Dialect/MSFT/MSFTPasses.cpp
+++ b/lib/Dialect/MSFT/MSFTPasses.cpp
@@ -722,6 +722,182 @@ std::unique_ptr<Pass> createPartitionPass() {
 } // namespace msft
 } // namespace circt
 
+namespace {
+/// Removes ports and wires which merely route a value through a module or
+/// loop an instance's result straight back into one of its own operands.
+struct WireCleanupPass : public WireCleanupBase<WireCleanupPass>, PassCommon {
+  void runOnOperation() override;
+
+private:
+  void bubbleWiresUp(MSFTModuleOp mod);
+  void sinkWiresDown(MSFTModuleOp mod);
+};
+} // anonymous namespace
+
+/// Run `bubbleWiresUp` on every module in the order `getAndSortModules`
+/// produces, then `sinkWiresDown` on the same modules in reverse order.
+void WireCleanupPass::runOnOperation() {
+  ModuleOp topMod = getOperation();
+  populateSymbolCache(topMod);
+  SmallVector<MSFTModuleOp> sortedMods;
+  getAndSortModules(topMod, sortedMods);
+
+  for (auto mod : sortedMods)
+    bubbleWiresUp(mod);
+
+  for (auto mod : llvm::reverse(sortedMods))
+    sinkWiresDown(mod);
+}
+
+/// Push up any wires which are simply passed-through. Output ports whose
+/// terminator operand is one of the module's own input arguments are removed,
+/// as are input ports with no users other than the terminator. The lambda
+/// handed to `updateInstances` redirects users of each removed instance result
+/// to the matching instance operand and drops the operands of removed inputs.
+void WireCleanupPass::bubbleWiresUp(MSFTModuleOp mod) {
+  Block *body = mod.getBodyBlock();
+  Operation *terminator = body->getTerminator();
+
+  // Find all "passthrough" internal wires, filling 'inputPortsToRemove' as a
+  // side-effect.
+  DenseMap<Value, hw::PortInfo> passThroughs;
+  SmallVector<unsigned> inputPortsToRemove;
+  for (hw::PortInfo inputPort : mod.getPorts().inputs) {
+    BlockArgument portArg = body->getArgument(inputPort.argNum);
+    bool removePort = true;
+    for (Operation *user : portArg.getUsers()) {
+      if (user == terminator)
+        passThroughs[portArg] = inputPort;
+      else
+        removePort = false;
+    }
+    if (removePort)
+      inputPortsToRemove.push_back(inputPort.argNum);
+  }
+
+  // Find all output ports which we can remove. Fill in 'outputToInputIdx' to
+  // help rewire instantiations later on.
+  DenseMap<unsigned, unsigned> outputToInputIdx;
+  SmallVector<unsigned> outputPortsToRemove;
+  for (hw::PortInfo outputPort : mod.getPorts().outputs) {
+    assert(outputPort.argNum < terminator->getNumOperands() && "Invalid IR");
+    Value outputValue = terminator->getOperand(outputPort.argNum);
+    auto inputNumF = passThroughs.find(outputValue);
+    if (inputNumF == passThroughs.end())
+      continue;
+    hw::PortInfo inputPort = inputNumF->second;
+    outputToInputIdx[outputPort.argNum] = inputPort.argNum;
+    outputPortsToRemove.push_back(outputPort.argNum);
+  }
+
+  // Use MSFTModuleOp's `removePorts` method to remove the ports. It returns a
+  // mapping of the new output port to old output port indices to assist in
+  // updating the instantiations later on.
+  auto newToOldResult =
+      mod.removePorts(inputPortsToRemove, outputPortsToRemove);
+
+  // Update the instantiations.
+  llvm::sort(inputPortsToRemove);
+  auto setPassthroughsGetOperands = [&](InstanceOp newInst, InstanceOp oldInst,
+                                        SmallVectorImpl<Value> &newOperands) {
+    // Re-map the passthrough values around the instance.
+    for (auto idxPair : outputToInputIdx) {
+      size_t outputPortNum = idxPair.first;
+      assert(outputPortNum < oldInst.getNumResults());
+      size_t inputPortNum = idxPair.second;
+      assert(inputPortNum < oldInst.getNumOperands());
+      oldInst.getResult(outputPortNum)
+          .replaceAllUsesWith(oldInst.getOperand(inputPortNum));
+    }
+    // Use a sort-merge-join approach to figure out the operand mapping on the
+    // fly.
+    size_t mergeCtr = 0;
+    for (size_t operNum = 0, e = oldInst.getNumOperands(); operNum < e;
+         ++operNum) {
+      if (mergeCtr < inputPortsToRemove.size() &&
+          operNum == inputPortsToRemove[mergeCtr])
+        ++mergeCtr;
+      else
+        newOperands.push_back(oldInst.getOperand(operNum));
+    }
+  };
+  updateInstances(mod, newToOldResult, setPassthroughsGetOperands);
+}
+
+/// Sink all the instance connections which are loops. Only modules with
+/// exactly one instantiation are handled. Uses of a module input which that
+/// instance feeds from one of its own results are replaced by the value the
+/// module outputs on that result. The looped input ports are then removed,
+/// together with results that have no uses outside the instance itself.
+void WireCleanupPass::sinkWiresDown(MSFTModuleOp mod) {
+  auto instantiations = moduleInstantiations[mod];
+  // TODO: remove this limitation. This would involve looking at the common
+  // loopbacks for all the instances.
+  if (instantiations.size() != 1)
+    return;
+  InstanceOp inst = instantiations[0];
+
+  // Find all the "loopback" connections in the instantiation. Populate
+  // 'inputToOutputLoopback' with a mapping of input port to output port which
+  // drives it. Populate 'resultsToErase' with output ports which only drive
+  // input ports of this same instance (unused outputs qualify too).
+  DenseMap<unsigned, unsigned> inputToOutputLoopback;
+  SmallVector<unsigned> resultsToErase; // This is sorted.
+  for (unsigned resNum = 0, e = inst.getNumResults(); resNum < e; ++resNum) {
+    bool allLoops = true;
+    for (auto &use : inst.getResult(resNum).getUses()) {
+      if (use.getOwner() != inst.getOperation())
+        allLoops = false;
+      else
+        inputToOutputLoopback[use.getOperandNumber()] = resNum;
+    }
+    if (allLoops)
+      resultsToErase.push_back(resNum);
+  }
+
+  // Add internal connections to replace the instantiation's loop back
+  // connections.
+  Block *body = mod.getBodyBlock();
+  Operation *terminator = body->getTerminator();
+  SmallVector<unsigned> argsToErase;
+  for (auto resOper : inputToOutputLoopback) {
+    body->getArgument(resOper.first)
+        .replaceAllUsesWith(terminator->getOperand(resOper.second));
+    argsToErase.push_back(resOper.first);
+  }
+
+  // Remove the ports. 'argsToErase' was filled in DenseMap iteration order, so
+  // sort it first: the merge-join below needs ascending indices, and this way
+  // `removePorts` is handed ascending indices as well.
+  llvm::sort(argsToErase);
+  SmallVector<unsigned> newToOldResultMap =
+      mod.removePorts(argsToErase, resultsToErase);
+
+  // Update the instantiations.
+  auto getOperands = [&](InstanceOp newInst, InstanceOp oldInst,
+                         SmallVectorImpl<Value> &newOperands) {
+    // Use sort-merge-join to compute the new operands.
+    unsigned mergeJoinCtr = 0;
+    for (unsigned argNum = 0, e = oldInst.getNumOperands(); argNum < e;
+         ++argNum) {
+      if (mergeJoinCtr < argsToErase.size() &&
+          argNum == argsToErase[mergeJoinCtr])
+        ++mergeJoinCtr;
+      else
+        newOperands.push_back(oldInst.getOperand(argNum));
+    }
+  };
+  updateInstances(mod, newToOldResultMap, getOperands);
+}
+
+namespace circt {
+namespace msft {
+std::unique_ptr<Pass> createWireCleanupPass() {
+  return std::make_unique<WireCleanupPass>();
+}
+} // namespace msft
+} // namespace circt
+
 namespace {
 #define GEN_PASS_REGISTRATION
 #include "circt/Dialect/MSFT/MSFTPasses.h.inc"
diff --git a/test/Dialect/MSFT/partition.mlir b/test/Dialect/MSFT/partition.mlir
index cc10360ae8..0403a10a57 100644
--- a/test/Dialect/MSFT/partition.mlir
+++ b/test/Dialect/MSFT/partition.mlir
@@ -1,39 +1,57 @@
 // RUN: circt-opt %s --msft-partition -verify-diagnostics -split-input-file | FileCheck %s
+// RUN: circt-opt %s --msft-partition --msft-wire-cleanup -verify-diagnostics -split-input-file | FileCheck --check-prefix=CLEANUP %s
 
-msft.module @top {} (%clk : i1) -> () {
+msft.module @top {} (%clk : i1) -> (out1: i2, out2: i2) {
   msft.partition @part1, "dp"
 
-  msft.instance @b @B(%clk) : (i1) -> (i32)
+  %res1 = msft.instance @b @B(%clk) : (i1) -> (i2)
 
-  %c0 = hw.constant 0 : i1
-  msft.instance @unit1 @Extern(%c0) { targetDesignPartition = @top::@part1 }: (i1) -> (i1)
+  %c0 = hw.constant 0 : i2
+  %res2 = msft.instance @unit1 @Extern(%c0) { targetDesignPartition = @top::@part1 }: (i2) -> (i2)
 
-  msft.output
+  msft.output %res1, %res2 : i2, i2
 }
 
-msft.module.extern @Extern (%foo_a: i1) -> (foo_x: i1)
+msft.module.extern @Extern (%foo_a: i2) -> (foo_x: i2)
 
-msft.module @B {} (%clk : i1) -> (x: i1) {
-  %c1 = hw.constant 1 : i1
-  %0 = msft.instance @unit1 @Extern(%c1) { targetDesignPartition = @top::@part1 }: (i1)
-> (i1) - %1 = seq.compreg %0, %clk { targetDesignPartition = @top::@part1 } : i1 +msft.module @B {} (%clk : i1) -> (x: i2) { + %c1 = hw.constant 1 : i2 + %0 = msft.instance @unit1 @Extern(%c1) { targetDesignPartition = @top::@part1 }: (i2) -> (i2) + %1 = seq.compreg %0, %clk { targetDesignPartition = @top::@part1 } : i2 - %2 = msft.instance @unit2 @Extern(%1) { targetDesignPartition = @top::@part1 }: (i1) -> (i1) + %2 = msft.instance @unit2 @Extern(%1) { targetDesignPartition = @top::@part1 }: (i2) -> (i2) - msft.output %2: i1 + msft.output %2: i2 } -// CHECK-LABEL: msft.module @top {} (%clk: i1) { -// CHECK: %part1.b.unit1.foo_x, %part1.b.seq.compreg.b.seq.compreg, %part1.b.unit2.foo_x, %part1.unit1.foo_x = msft.instance @part1 @dp(%b.unit1.foo_a, %b.seq.compreg.in0, %b.seq.compreg.in1, %b.unit2.foo_a, %false) : (i1, i1, i1, i1, i1) -> (i1, i1, i1, i1) -// CHECK: %b.x, %b.unit1.foo_a, %b.seq.compreg.in0, %b.seq.compreg.in1, %b.unit2.foo_a = msft.instance @b @B(%clk, %part1.b.unit1.foo_x, %part1.b.seq.compreg.b.seq.compreg, %part1.b.unit2.foo_x) : (i1, i1, i1, i1) -> (i1, i1, i1, i1, i1) -// CHECK: %false = hw.constant false -// CHECK: msft.output -// CHECK-LABEL: msft.module @B {} (%clk: i1, %unit1.foo_x: i1, %seq.compreg.out0: i1, %unit2.foo_x: i1) -> (x: i1, unit1.foo_a: i1, seq.compreg.in0: i1, seq.compreg.in1: i1, unit2.foo_a: i1) { -// CHECK: %true = hw.constant true -// CHECK: msft.output %unit2.foo_x, %true, %unit1.foo_x, %clk, %seq.compreg.out0 : i1, i1, i1, i1, i1 -// CHECK-LABEL: msft.module @dp {} (%b.unit1.foo_a: i1, %b.seq.compreg.in0: i1, %b.seq.compreg.in1: i1, %b.unit2.foo_a: i1, %unit1.foo_a: i1) -> (b.unit1.foo_x: i1, b.seq.compreg.b.seq.compreg: i1, b.unit2.foo_x: i1, unit1.foo_x: i1) { -// CHECK: %b.unit1.foo_x = msft.instance @b.unit1 @Extern(%b.unit1.foo_a) : (i1) -> i1 -// CHECK: %b.seq.compreg = seq.compreg %b.seq.compreg.in0, %b.seq.compreg.in1 : i1 -// CHECK: %b.unit2.foo_x = msft.instance @b.unit2 @Extern(%b.unit2.foo_a) : (i1) -> i1 
-// CHECK: %unit1.foo_x = msft.instance @unit1 @Extern(%unit1.foo_a) : (i1) -> i1 -// CHECK: msft.output %b.unit1.foo_x, %b.seq.compreg, %b.unit2.foo_x, %unit1.foo_x : i1, i1, i1, i1 +// CHECK-LABEL: msft.module @top {} (%clk: i1) -> (out1: i2, out2: i2) { +// CHECK: %part1.b.unit1.foo_x, %part1.b.seq.compreg.b.seq.compreg, %part1.b.unit2.foo_x, %part1.unit1.foo_x = msft.instance @part1 @dp(%b.unit1.foo_a, %b.seq.compreg.in0, %b.seq.compreg.in1, %b.unit2.foo_a, %c0_i2) : (i2, i2, i1, i2, i2) -> (i2, i2, i2, i2) +// CHECK: %b.x, %b.unit1.foo_a, %b.seq.compreg.in0, %b.seq.compreg.in1, %b.unit2.foo_a = msft.instance @b @B(%clk, %part1.b.unit1.foo_x, %part1.b.seq.compreg.b.seq.compreg, %part1.b.unit2.foo_x) : (i1, i2, i2, i2) -> (i2, i2, i2, i1, i2) +// CHECK: %c0_i2 = hw.constant 0 : i2 +// CHECK: msft.output %b.x, %part1.unit1.foo_x : i2, i2 +// CHECK-LABEL: msft.module.extern @Extern(%foo_a: i2) -> (foo_x: i2) +// CHECK-LABEL: msft.module @B {} (%clk: i1, %unit1.foo_x: i2, %seq.compreg.out0: i2, %unit2.foo_x: i2) -> (x: i2, unit1.foo_a: i2, seq.compreg.in0: i2, seq.compreg.in1: i1, unit2.foo_a: i2) { +// CHECK: %c1_i2 = hw.constant 1 : i2 +// CHECK: msft.output %unit2.foo_x, %c1_i2, %unit1.foo_x, %clk, %seq.compreg.out0 : i2, i2, i2, i1, i2 +// CHECK-LABEL: msft.module @dp {} (%b.unit1.foo_a: i2, %b.seq.compreg.in0: i2, %b.seq.compreg.in1: i1, %b.unit2.foo_a: i2, %unit1.foo_a: i2) -> (b.unit1.foo_x: i2, b.seq.compreg.b.seq.compreg: i2, b.unit2.foo_x: i2, unit1.foo_x: i2) { +// CHECK: %b.unit1.foo_x = msft.instance @b.unit1 @Extern(%b.unit1.foo_a) : (i2) -> i2 +// CHECK: %b.seq.compreg = seq.compreg %b.seq.compreg.in0, %b.seq.compreg.in1 : i2 +// CHECK: %b.unit2.foo_x = msft.instance @b.unit2 @Extern(%b.unit2.foo_a) : (i2) -> i2 +// CHECK: %unit1.foo_x = msft.instance @unit1 @Extern(%unit1.foo_a) : (i2) -> i2 +// CHECK: msft.output %b.unit1.foo_x, %b.seq.compreg, %b.unit2.foo_x, %unit1.foo_x : i2, i2, i2, i2 + +// CLEANUP-LABEL: msft.module @top {} (%clk: i1) -> 
(out1: i2, out2: i2) { +// CLEANUP: %part1.b.unit1.foo_x, %part1.b.seq.compreg.b.seq.compreg = msft.instance @part1 @dp(%b.x, %clk, %c0_i2) : (i2, i1, i2) -> (i2, i2) +// CLEANUP: %b.x = msft.instance @b @B() : () -> i2 +// CLEANUP: %c0_i2 = hw.constant 0 : i2 +// CLEANUP: msft.output %part1.b.unit1.foo_x, %part1.b.seq.compreg.b.seq.compreg : i2, i2 +// CLEANUP-LABEL: msft.module.extern @Extern(%foo_a: i2) -> (foo_x: i2) +// CLEANUP-LABEL: msft.module @B {} () -> (x: i2) { +// CLEANUP: %c1_i2 = hw.constant 1 : i2 +// CLEANUP: msft.output %c1_i2 : i2 +// CLEANUP-LABEL: msft.module @dp {} (%b.unit1.foo_a: i2, %b.seq.compreg.in0: i1, %b.seq.compreg.in1: i2) -> (b.unit1.foo_x: i2, b.seq.compreg.b.seq.compreg: i2) { +// CLEANUP: %b.unit1.foo_x = msft.instance @b.unit1 @Extern(%b.unit1.foo_a) : (i2) -> i2 +// CLEANUP: %b.seq.compreg = seq.compreg %b.unit1.foo_x, %b.seq.compreg.in0 : i2 +// CLEANUP: %b.unit2.foo_x = msft.instance @b.unit2 @Extern(%b.seq.compreg) : (i2) -> i2 +// CLEANUP: %unit1.foo_x = msft.instance @unit1 @Extern(%b.seq.compreg.in1) : (i2) -> i2 +// CLEANUP: msft.output %b.unit2.foo_x, %unit1.foo_x : i2, i2