diff --git a/include/circt/Dialect/MSFT/MSFTOps.td b/include/circt/Dialect/MSFT/MSFTOps.td index b342adad9e..d44301e610 100644 --- a/include/circt/Dialect/MSFT/MSFTOps.td +++ b/include/circt/Dialect/MSFT/MSFTOps.td @@ -42,6 +42,10 @@ def InstanceOp : MSFTOp<"instance", [ StringAttr instanceNameAttr() { return sym_nameAttr(); } + + // Update the results. + InstanceOp getWithNewResults(MSFTModuleOp mod, + ArrayRef newToOldMap); }]; /// sym keyword for optional symbol simplifies parsing @@ -97,6 +101,11 @@ def MSFTModuleOp : MSFTOp<"module", ArrayRef> inputs, ArrayRef> outputs); + // Remove the ports at the specified indexes. Returns the new to old result + // mapping. + SmallVector + removePorts(ArrayRef inputs, ArrayRef outputs); + // Get the module's symbolic name as StringAttr. StringAttr getNameAttr() { return (*this)->getAttrOfType( diff --git a/include/circt/Dialect/MSFT/MSFTPasses.td b/include/circt/Dialect/MSFT/MSFTPasses.td index 0e08dc68b5..853de293be 100644 --- a/include/circt/Dialect/MSFT/MSFTPasses.td +++ b/include/circt/Dialect/MSFT/MSFTPasses.td @@ -26,3 +26,9 @@ def Partition: Pass<"msft-partition", "mlir::ModuleOp"> { let constructor = "circt::msft::createPartitionPass()"; let dependentDialects = ["circt::hw::HWDialect"]; } + +def WireCleanup: Pass<"msft-wire-cleanup", "mlir::ModuleOp"> { + let summary = "Cleanup unnecessary ports and wires"; + let constructor = "circt::msft::createWireCleanupPass()"; + let dependentDialects = []; +} diff --git a/lib/Dialect/MSFT/MSFTOps.cpp b/lib/Dialect/MSFT/MSFTOps.cpp index b808bce1df..05e875a697 100644 --- a/lib/Dialect/MSFT/MSFTOps.cpp +++ b/lib/Dialect/MSFT/MSFTOps.cpp @@ -132,6 +132,37 @@ MSFTModuleOp::addPorts(ArrayRef> inputs, return newBlockArgs; } +// Remove the ports at the specified indexes. +SmallVector MSFTModuleOp::removePorts(ArrayRef inputs, + ArrayRef outputs) { + FunctionType ftype = getType(); + Block *body = getBodyBlock(); + Operation *terminator = body->getTerminator(); + + setType(ftype.getWithoutArgsAndResults(inputs, outputs)); + + // Build new operand list for output op. Construct an output mapping to return + // as a side-effect. + unsigned numResults = ftype.getNumResults(); + llvm::BitVector skipOutputs(numResults); + SmallVector newOutputValues; + SmallVector newToOldResultMap; + for (unsigned i : outputs) + skipOutputs.set(i); + for (unsigned i = 0; i < numResults; ++i) + if (!skipOutputs.test(i)) { + newOutputValues.push_back(terminator->getOperand(i)); + newToOldResultMap.push_back(i); + } + terminator->setOperands(newOutputValues); + + // Erase the arguments after setting the new output op operands since the + // arguments might be used by output op. + body->eraseArguments(inputs); + + return newToOldResultMap; +} + // Copied nearly exactly from hwops.cpp. // TODO: Unify code once a `ModuleLike` op interface exists. static void buildModule(OpBuilder &builder, OperationState &result, diff --git a/lib/Dialect/MSFT/MSFTPasses.cpp b/lib/Dialect/MSFT/MSFTPasses.cpp index 78d29dd194..b2d6fb80a6 100644 --- a/lib/Dialect/MSFT/MSFTPasses.cpp +++ b/lib/Dialect/MSFT/MSFTPasses.cpp @@ -722,6 +722,167 @@ std::unique_ptr createPartitionPass() { } // namespace msft } // namespace circt +namespace { +struct WireCleanupPass : public WireCleanupBase, PassCommon { + void runOnOperation() override; + +private: + void bubbleWiresUp(MSFTModuleOp mod); + void sinkWiresDown(MSFTModuleOp mod); +}; +} // anonymous namespace + +void WireCleanupPass::runOnOperation() { + ModuleOp topMod = getOperation(); + populateSymbolCache(topMod); + SmallVector sortedMods; + getAndSortModules(topMod, sortedMods); + + for (auto mod : sortedMods) + bubbleWiresUp(mod); + + for (auto mod : llvm::reverse(sortedMods)) + sinkWiresDown(mod); +} + +/// Push up any wires which are simply passed-through. +void WireCleanupPass::bubbleWiresUp(MSFTModuleOp mod) { + Block *body = mod.getBodyBlock(); + Operation *terminator = body->getTerminator(); + + // Find all "passthough" internal wires, filling 'inputPortsToRemove' as a + // side-effect. + DenseMap passThroughs; + SmallVector inputPortsToRemove; + for (hw::PortInfo inputPort : mod.getPorts().inputs) { + BlockArgument portArg = body->getArgument(inputPort.argNum); + bool removePort = true; + for (OpOperand user : portArg.getUsers()) { + if (user.getOwner() == terminator) + passThroughs[portArg] = inputPort; + else + removePort = false; + } + if (removePort) + inputPortsToRemove.push_back(inputPort.argNum); + } + + // Find all output ports which we can remove. Fill in 'outputToInputIdx' to + // help rewire instantiations later on. + DenseMap outputToInputIdx; + SmallVector outputPortsToRemove; + for (hw::PortInfo outputPort : mod.getPorts().outputs) { + assert(outputPort.argNum < terminator->getNumOperands() && "Invalid IR"); + Value outputValue = terminator->getOperand(outputPort.argNum); + auto inputNumF = passThroughs.find(outputValue); + if (inputNumF == passThroughs.end()) + continue; + hw::PortInfo inputPort = inputNumF->second; + outputToInputIdx[outputPort.argNum] = inputPort.argNum; + outputPortsToRemove.push_back(outputPort.argNum); + } + + // Use MSFTModuleOp's `removePorts` method to remove the ports. It returns a + // mapping of the new output port to old output port indices to assist in + // updating the instantiations later on. + auto newToOldResult = + mod.removePorts(inputPortsToRemove, outputPortsToRemove); + + // Update the instantiations. + llvm::sort(inputPortsToRemove); + auto setPassthroughsGetOperands = [&](InstanceOp newInst, InstanceOp oldInst, + SmallVectorImpl &newOperands) { + // Re-map the passthrough values around the instance. + for (auto idxPair : outputToInputIdx) { + size_t outputPortNum = idxPair.first; + assert(outputPortNum <= oldInst.getNumResults()); + size_t inputPortNum = idxPair.second; + assert(inputPortNum <= oldInst.getNumOperands()); + oldInst.getResult(outputPortNum) + .replaceAllUsesWith(oldInst.getOperand(inputPortNum)); + } + // Use a sort-merge-join approach to figure out the operand mapping on the + // fly. + size_t mergeCtr = 0; + for (size_t operNum = 0, e = oldInst.getNumOperands(); operNum < e; + ++operNum) { + if (mergeCtr < inputPortsToRemove.size() && + operNum == inputPortsToRemove[mergeCtr]) + ++mergeCtr; + else + newOperands.push_back(oldInst.getOperand(operNum)); + } + }; + updateInstances(mod, newToOldResult, setPassthroughsGetOperands); +} + +/// Sink all the instance connections which are loops. +void WireCleanupPass::sinkWiresDown(MSFTModuleOp mod) { + auto instantiations = moduleInstantiations[mod]; + // TODO: remove this limitation. This would involve looking at the common + // loopbacks for all the instances. + if (instantiations.size() != 1) + return; + InstanceOp inst = instantiations[0]; + + // Find all the "loopback" connections in the instantiation. Populate + // 'inputToOutputLoopback' with a mapping of input port to output port which + // drives it. Populate 'resultsToErase' with output ports which only drive + // input ports. + DenseMap inputToOutputLoopback; + SmallVector resultsToErase; // This is sorted. + for (unsigned resNum = 0, e = inst.getNumResults(); resNum < e; ++resNum) { + bool allLoops = true; + for (auto &use : inst.getResult(resNum).getUses()) { + if (use.getOwner() != inst.getOperation()) + allLoops = false; + else + inputToOutputLoopback[use.getOperandNumber()] = resNum; + } + if (allLoops) + resultsToErase.push_back(resNum); + } + + // Add internal connections to replace the instantiation's loop back + // connections. + Block *body = mod.getBodyBlock(); + Operation *terminator = body->getTerminator(); + SmallVector argsToErase; + for (auto resOper : inputToOutputLoopback) { + body->getArgument(resOper.first) + .replaceAllUsesWith(terminator->getOperand(resOper.second)); + argsToErase.push_back(resOper.first); + } + + // Remove the ports. + SmallVector newToOldResultMap = + mod.removePorts(argsToErase, resultsToErase); + // and update the instantiations. + llvm::sort(argsToErase); + auto getOperands = [&](InstanceOp newInst, InstanceOp oldInst, + SmallVectorImpl &newOperands) { + // Use sort-merge-join to compute the new operands; + unsigned mergeJoinCtr = 0; + for (unsigned argNum = 0, e = oldInst.getNumOperands(); argNum < e; + ++argNum) { + if (mergeJoinCtr < argsToErase.size() && + argNum == argsToErase[mergeJoinCtr]) + ++mergeJoinCtr; + else + newOperands.push_back(oldInst.getOperand(argNum)); + } + }; + updateInstances(mod, newToOldResultMap, getOperands); +} + +namespace circt { +namespace msft { +std::unique_ptr createWireCleanupPass() { + return std::make_unique(); +} +} // namespace msft +} // namespace circt + namespace { #define GEN_PASS_REGISTRATION #include "circt/Dialect/MSFT/MSFTPasses.h.inc" diff --git a/test/Dialect/MSFT/partition.mlir b/test/Dialect/MSFT/partition.mlir index cc10360ae8..0403a10a57 100644 --- a/test/Dialect/MSFT/partition.mlir +++ b/test/Dialect/MSFT/partition.mlir @@ -1,39 +1,57 @@ // RUN: circt-opt %s --msft-partition -verify-diagnostics -split-input-file | FileCheck %s +// RUN: circt-opt %s --msft-partition --msft-wire-cleanup -verify-diagnostics -split-input-file | FileCheck --check-prefix=CLEANUP %s -msft.module @top {} (%clk : i1) -> () { +msft.module @top {} (%clk : i1) -> (out1: i2, out2: i2) { msft.partition @part1, "dp" - msft.instance @b @B(%clk) : (i1) -> (i32) + %res1 = msft.instance @b @B(%clk) : (i1) -> (i2) - %c0 = hw.constant 0 : i1 - msft.instance @unit1 @Extern(%c0) { targetDesignPartition = @top::@part1 }: (i1) -> (i1) + %c0 = hw.constant 0 : i2 + %res2 = msft.instance @unit1 @Extern(%c0) { targetDesignPartition = @top::@part1 }: (i2) -> (i2) - msft.output + msft.output %res1, %res2 : i2, i2 } -msft.module.extern @Extern (%foo_a: i1) -> (foo_x: i1) +msft.module.extern @Extern (%foo_a: i2) -> (foo_x: i2) -msft.module @B {} (%clk : i1) -> (x: i1) { - %c1 = hw.constant 1 : i1 - %0 = msft.instance @unit1 @Extern(%c1) { targetDesignPartition = @top::@part1 }: (i1) -> (i1) - %1 = seq.compreg %0, %clk { targetDesignPartition = @top::@part1 } : i1 +msft.module @B {} (%clk : i1) -> (x: i2) { + %c1 = hw.constant 1 : i2 + %0 = msft.instance @unit1 @Extern(%c1) { targetDesignPartition = @top::@part1 }: (i2) -> (i2) + %1 = seq.compreg %0, %clk { targetDesignPartition = @top::@part1 } : i2 - %2 = msft.instance @unit2 @Extern(%1) { targetDesignPartition = @top::@part1 }: (i1) -> (i1) + %2 = msft.instance @unit2 @Extern(%1) { targetDesignPartition = @top::@part1 }: (i2) -> (i2) - msft.output %2: i1 + msft.output %2: i2 } -// CHECK-LABEL: msft.module @top {} (%clk: i1) { -// CHECK: %part1.b.unit1.foo_x, %part1.b.seq.compreg.b.seq.compreg, %part1.b.unit2.foo_x, %part1.unit1.foo_x = msft.instance @part1 @dp(%b.unit1.foo_a, %b.seq.compreg.in0, %b.seq.compreg.in1, %b.unit2.foo_a, %false) : (i1, i1, i1, i1, i1) -> (i1, i1, i1, i1) -// CHECK: %b.x, %b.unit1.foo_a, %b.seq.compreg.in0, %b.seq.compreg.in1, %b.unit2.foo_a = msft.instance @b @B(%clk, %part1.b.unit1.foo_x, %part1.b.seq.compreg.b.seq.compreg, %part1.b.unit2.foo_x) : (i1, i1, i1, i1) -> (i1, i1, i1, i1, i1) -// CHECK: %false = hw.constant false -// CHECK: msft.output -// CHECK-LABEL: msft.module @B {} (%clk: i1, %unit1.foo_x: i1, %seq.compreg.out0: i1, %unit2.foo_x: i1) -> (x: i1, unit1.foo_a: i1, seq.compreg.in0: i1, seq.compreg.in1: i1, unit2.foo_a: i1) { -// CHECK: %true = hw.constant true -// CHECK: msft.output %unit2.foo_x, %true, %unit1.foo_x, %clk, %seq.compreg.out0 : i1, i1, i1, i1, i1 -// CHECK-LABEL: msft.module @dp {} (%b.unit1.foo_a: i1, %b.seq.compreg.in0: i1, %b.seq.compreg.in1: i1, %b.unit2.foo_a: i1, %unit1.foo_a: i1) -> (b.unit1.foo_x: i1, b.seq.compreg.b.seq.compreg: i1, b.unit2.foo_x: i1, unit1.foo_x: i1) { -// CHECK: %b.unit1.foo_x = msft.instance @b.unit1 @Extern(%b.unit1.foo_a) : (i1) -> i1 -// CHECK: %b.seq.compreg = seq.compreg %b.seq.compreg.in0, %b.seq.compreg.in1 : i1 -// CHECK: %b.unit2.foo_x = msft.instance @b.unit2 @Extern(%b.unit2.foo_a) : (i1) -> i1 -// CHECK: %unit1.foo_x = msft.instance @unit1 @Extern(%unit1.foo_a) : (i1) -> i1 -// CHECK: msft.output %b.unit1.foo_x, %b.seq.compreg, %b.unit2.foo_x, %unit1.foo_x : i1, i1, i1, i1 +// CHECK-LABEL: msft.module @top {} (%clk: i1) -> (out1: i2, out2: i2) { +// CHECK: %part1.b.unit1.foo_x, %part1.b.seq.compreg.b.seq.compreg, %part1.b.unit2.foo_x, %part1.unit1.foo_x = msft.instance @part1 @dp(%b.unit1.foo_a, %b.seq.compreg.in0, %b.seq.compreg.in1, %b.unit2.foo_a, %c0_i2) : (i2, i2, i1, i2, i2) -> (i2, i2, i2, i2) +// CHECK: %b.x, %b.unit1.foo_a, %b.seq.compreg.in0, %b.seq.compreg.in1, %b.unit2.foo_a = msft.instance @b @B(%clk, %part1.b.unit1.foo_x, %part1.b.seq.compreg.b.seq.compreg, %part1.b.unit2.foo_x) : (i1, i2, i2, i2) -> (i2, i2, i2, i1, i2) +// CHECK: %c0_i2 = hw.constant 0 : i2 +// CHECK: msft.output %b.x, %part1.unit1.foo_x : i2, i2 +// CHECK-LABEL: msft.module.extern @Extern(%foo_a: i2) -> (foo_x: i2) +// CHECK-LABEL: msft.module @B {} (%clk: i1, %unit1.foo_x: i2, %seq.compreg.out0: i2, %unit2.foo_x: i2) -> (x: i2, unit1.foo_a: i2, seq.compreg.in0: i2, seq.compreg.in1: i1, unit2.foo_a: i2) { +// CHECK: %c1_i2 = hw.constant 1 : i2 +// CHECK: msft.output %unit2.foo_x, %c1_i2, %unit1.foo_x, %clk, %seq.compreg.out0 : i2, i2, i2, i1, i2 +// CHECK-LABEL: msft.module @dp {} (%b.unit1.foo_a: i2, %b.seq.compreg.in0: i2, %b.seq.compreg.in1: i1, %b.unit2.foo_a: i2, %unit1.foo_a: i2) -> (b.unit1.foo_x: i2, b.seq.compreg.b.seq.compreg: i2, b.unit2.foo_x: i2, unit1.foo_x: i2) { +// CHECK: %b.unit1.foo_x = msft.instance @b.unit1 @Extern(%b.unit1.foo_a) : (i2) -> i2 +// CHECK: %b.seq.compreg = seq.compreg %b.seq.compreg.in0, %b.seq.compreg.in1 : i2 +// CHECK: %b.unit2.foo_x = msft.instance @b.unit2 @Extern(%b.unit2.foo_a) : (i2) -> i2 +// CHECK: %unit1.foo_x = msft.instance @unit1 @Extern(%unit1.foo_a) : (i2) -> i2 +// CHECK: msft.output %b.unit1.foo_x, %b.seq.compreg, %b.unit2.foo_x, %unit1.foo_x : i2, i2, i2, i2 + +// CLEANUP-LABEL: msft.module @top {} (%clk: i1) -> (out1: i2, out2: i2) { +// CLEANUP: %part1.b.unit1.foo_x, %part1.b.seq.compreg.b.seq.compreg = msft.instance @part1 @dp(%b.x, %clk, %c0_i2) : (i2, i1, i2) -> (i2, i2) +// CLEANUP: %b.x = msft.instance @b @B() : () -> i2 +// CLEANUP: %c0_i2 = hw.constant 0 : i2 +// CLEANUP: msft.output %part1.b.unit1.foo_x, %part1.b.seq.compreg.b.seq.compreg : i2, i2 +// CLEANUP-LABEL: msft.module.extern @Extern(%foo_a: i2) -> (foo_x: i2) +// CLEANUP-LABEL: msft.module @B {} () -> (x: i2) { +// CLEANUP: %c1_i2 = hw.constant 1 : i2 +// CLEANUP: msft.output %c1_i2 : i2 +// CLEANUP-LABEL: msft.module @dp {} (%b.unit1.foo_a: i2, %b.seq.compreg.in0: i1, %b.seq.compreg.in1: i2) -> (b.unit1.foo_x: i2, b.seq.compreg.b.seq.compreg: i2) { +// CLEANUP: %b.unit1.foo_x = msft.instance @b.unit1 @Extern(%b.unit1.foo_a) : (i2) -> i2 +// CLEANUP: %b.seq.compreg = seq.compreg %b.unit1.foo_x, %b.seq.compreg.in0 : i2 +// CLEANUP: %b.unit2.foo_x = msft.instance @b.unit2 @Extern(%b.seq.compreg) : (i2) -> i2 +// CLEANUP: %unit1.foo_x = msft.instance @unit1 @Extern(%b.seq.compreg.in1) : (i2) -> i2 +// CLEANUP: msft.output %b.unit2.foo_x, %unit1.foo_x : i2, i2