[MSFT] Add input, output verification to instance op

Design partitioning pass was making unverified assumptions about
instances matching their modules. Adds a function to check.
This commit is contained in:
John Demme 2022-01-12 05:03:36 +00:00
parent 0b6b9f3afd
commit 8cc1a1a514
4 changed files with 63 additions and 14 deletions

View File

@ -42,6 +42,8 @@ def InstanceOp : MSFTOp<"instance", [
StringAttr instanceNameAttr() {
return sym_nameAttr();
}
/// Check that the operands and results match the module specified.
LogicalResult verifySignatureMatch(const circt::hw::ModulePortInfo&);
// Update the results.
InstanceOp getWithNewResults(MSFTModuleOp mod,

View File

@ -60,6 +60,28 @@ void InstanceOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
}
}
LogicalResult
InstanceOp::verifySignatureMatch(const hw::ModulePortInfo &ports) {
if (ports.inputs.size() != getNumOperands())
return emitOpError("wrong number of inputs (expected ")
<< ports.inputs.size() << ")";
if (ports.outputs.size() != getNumResults())
return emitOpError("wrong number of outputs (expected ")
<< ports.outputs.size() << ")";
for (auto port : ports.inputs)
if (getOperand(port.argNum).getType() != port.type)
return emitOpError("in input port ")
<< port.name << ", expected type " << port.type << " got "
<< getOperand(port.argNum).getType();
for (auto port : ports.outputs)
if (getResult(port.argNum).getType() != port.type)
return emitOpError("in output port ")
<< port.name << ", expected type " << port.type << " got "
<< getResult(port.argNum).getType();
return success();
}
/// Return an encapsulated set of information about input and output ports of
/// the specified module or instance. The input ports always come before the
/// output ports in the list.

View File

@ -36,6 +36,19 @@ namespace msft {
} // namespace msft
} // namespace circt
/// TODO: Migrate these to some sort of OpInterface shared with hw.
static bool isAnyModule(Operation *module) {
return isa<MSFTModuleOp, MSFTModuleExternOp>(module) ||
hw::isAnyModule(module);
}
hw::ModulePortInfo getModulePortInfo(Operation *op) {
if (auto mod = dyn_cast<MSFTModuleOp>(op))
return mod.getPorts();
if (auto mod = dyn_cast<MSFTModuleExternOp>(op))
return mod.getPorts();
return hw::getModulePortInfo(op);
}
//===----------------------------------------------------------------------===//
// Lower MSFT to HW.
//===----------------------------------------------------------------------===//
@ -281,6 +294,7 @@ protected:
DenseMap<MSFTModuleOp, SmallVector<InstanceOp, 1>> moduleInstantiations;
void populateSymbolCache(ModuleOp topMod);
LogicalResult verifyInstances(ModuleOp topMod);
// Find all the modules and use the partial order of the instantiation DAG
// to sort them. If we use this order when "bubbling" up operations, we
@ -328,6 +342,7 @@ void PassCommon::updateInstances(
for (size_t portNum = 0, e = newToOldResultMap.size(); portNum < e;
++portNum) {
assert(portNum < newInst.getNumResults());
assert(newToOldResultMap[portNum] < inst.getNumResults());
inst.getResult(newToOldResultMap[portNum])
.replaceAllUsesWith(newInst.getResult(portNum));
}
@ -385,6 +400,20 @@ void PassCommon::populateSymbolCache(mlir::ModuleOp mod) {
topLevelSyms.freeze();
}
LogicalResult PassCommon::verifyInstances(mlir::ModuleOp mod) {
WalkResult r = mod.walk([&](InstanceOp inst) {
Operation *modOp = topLevelSyms.getDefinition(inst.moduleNameAttr());
if (!isAnyModule(modOp))
return WalkResult::interrupt();
hw::ModulePortInfo ports = getModulePortInfo(modOp);
return succeeded(inst.verifySignatureMatch(ports))
? WalkResult::advance()
: WalkResult::interrupt();
});
return failure(r.wasInterrupted());
}
namespace {
struct PartitionPass : public PartitionBase<PartitionPass>, PassCommon {
void runOnOperation() override;
@ -406,6 +435,10 @@ private:
void PartitionPass::runOnOperation() {
ModuleOp outerMod = getOperation();
populateSymbolCache(outerMod);
if (failed(verifyInstances(outerMod))) {
signalPassFailure();
return;
}
// Get a properly sorted list, then partition the mods in order.
SmallVector<MSFTModuleOp, 64> sortedMods;
@ -473,19 +506,6 @@ void PartitionPass::partition(MSFTModuleOp mod) {
}
}
/// TODO: Migrate these to some sort of OpInterface shared with hw.
static bool isAnyModule(Operation *module) {
return isa<MSFTModuleOp, MSFTModuleExternOp>(module) ||
hw::isAnyModule(module);
}
hw::ModulePortInfo getModulePortInfo(Operation *op) {
if (auto mod = dyn_cast<MSFTModuleOp>(op))
return mod.getPorts();
if (auto mod = dyn_cast<MSFTModuleExternOp>(op))
return mod.getPorts();
return hw::getModulePortInfo(op);
}
/// Heuristics to get the entity name.
static StringRef getOpName(Operation *op) {
StringAttr name;
@ -899,6 +919,11 @@ private:
void WireCleanupPass::runOnOperation() {
ModuleOp topMod = getOperation();
populateSymbolCache(topMod);
if (failed(verifyInstances(topMod))) {
signalPassFailure();
return;
}
SmallVector<MSFTModuleOp> sortedMods;
getAndSortModules(topMod, sortedMods);

View File

@ -11,7 +11,7 @@ hw.globalRef @ref2 [#hw.innerNameRef<@top::@b>, #hw.innerNameRef<@B::@c>, #hw.in
msft.module @top {} (%clk : i1) -> (out1: i2, out2: i2) {
msft.partition @part1, "dp"
%res1 = msft.instance @b @B(%clk) { circt.globalRef = [#hw.globalNameRef<@ref1>, #hw.globalNameRef<@ref2>], inner_sym = "b" } : (i1) -> (i2)
%res1, %_ = msft.instance @b @B(%clk) { circt.globalRef = [#hw.globalNameRef<@ref1>, #hw.globalNameRef<@ref2>], inner_sym = "b" } : (i1) -> (i2, i2)
%c0 = hw.constant 0 : i2
%res2 = msft.instance @unit1 @Extern(%c0) { targetDesignPartition = @top::@part1 }: (i2) -> (i2)