From f7ff0832968539f6f1a5c4e3e351a8326cef8674 Mon Sep 17 00:00:00 2001 From: Chris Lattner Date: Sun, 12 Sep 2021 15:24:49 -0700 Subject: [PATCH] [HW] Generalize methods like getModulePortInfo to work on instances. This is NFC other than changing a verification error due to an earlier check. This is plumbing to make way for port names being stored on instances. --- include/circt/Dialect/HW/HWOps.h | 3 +- include/circt/Dialect/HW/HWStructure.td | 8 ++- lib/Dialect/HW/HWOps.cpp | 87 ++++++++++++++++++------- test/Dialect/HW/errors.mlir | 2 +- 4 files changed, 72 insertions(+), 28 deletions(-) diff --git a/include/circt/Dialect/HW/HWOps.h b/include/circt/Dialect/HW/HWOps.h index cba1a4bc4a..07caf71964 100644 --- a/include/circt/Dialect/HW/HWOps.h +++ b/include/circt/Dialect/HW/HWOps.h @@ -75,7 +75,8 @@ static inline StringRef getModuleResultName(Operation *module, void setModuleArgumentNames(Operation *module, ArrayRef names); void setModuleResultNames(Operation *module, ArrayRef names); -/// Return an encapsulated set of information about input and output ports. +/// Return an encapsulated set of information about input and output ports of +/// the specified module or instance. SmallVector getModulePortInfo(Operation *op); /// Return true if the specified operation is a combinatorial logic op. diff --git a/include/circt/Dialect/HW/HWStructure.td b/include/circt/Dialect/HW/HWStructure.td index b066a5e3f8..12e1bf5cbc 100644 --- a/include/circt/Dialect/HW/HWStructure.td +++ b/include/circt/Dialect/HW/HWStructure.td @@ -311,8 +311,12 @@ def InstanceOp : HWOp<"instance", [HasParent<"HWModuleOp">, Symbol, ]; let extraClassDeclaration = [{ - // Return the name of the specified result or empty string if it cannot be - // determined. + /// Return the name of the specified input port or null if it cannot be + /// determined. + StringAttr getArgumentName(size_t i, const SymbolCache *cache = nullptr); + + /// Return the name of the specified result or null if it cannot be + /// determined. StringAttr getResultName(size_t i, const SymbolCache *cache = nullptr); /// Lookup the module or extmodule for the symbol. This returns null on diff --git a/lib/Dialect/HW/HWOps.cpp b/lib/Dialect/HW/HWOps.cpp index 5dfb3548cf..df0ec9fed0 100644 --- a/lib/Dialect/HW/HWOps.cpp +++ b/lib/Dialect/HW/HWOps.cpp @@ -119,10 +119,19 @@ bool hw::isAnyModule(Operation *module) { isa(module); } -/// Return the signature for the specified module as a function type. -FunctionType hw::getModuleType(Operation *module) { +/// Return the signature for a module as a function type from the module itself +/// or from an hw::InstanceOp. +FunctionType hw::getModuleType(Operation *moduleOrInstance) { + if (auto instance = dyn_cast(moduleOrInstance)) { + SmallVector inputs(instance->getOperandTypes()); + SmallVector results(instance->getResultTypes()); + return FunctionType::get(instance->getContext(), inputs, results); + } + + assert(isAnyModule(moduleOrInstance) && + "must be called on instance or module"); auto typeAttr = - module->getAttrOfType(HWModuleOp::getTypeAttrName()); + moduleOrInstance->getAttrOfType(HWModuleOp::getTypeAttrName()); return typeAttr.getValue().cast(); } @@ -155,12 +164,14 @@ StringAttr hw::getModuleResultNameAttr(Operation *module, size_t resultNo) { } void hw::setModuleArgumentNames(Operation *module, ArrayRef names) { + assert(isAnyModule(module) && "Must be called on a module"); assert(getModuleType(module).getNumInputs() == names.size() && "incorrect number of arguments names specified"); module->setAttr("argNames", ArrayAttr::get(module->getContext(), names)); } void hw::setModuleResultNames(Operation *module, ArrayRef names) { + assert(isAnyModule(module) && "Must be called on a module"); assert(getModuleType(module).getNumResults() == names.size() && "incorrect number of arguments names specified"); module->setAttr("resultNames", ArrayAttr::get(module->getContext(), names)); @@ -268,8 +279,16 @@ void HWModuleGeneratedOp::build(OpBuilder &builder, OperationState &result, result.addAttribute("verilogName", builder.getStringAttr(verilogName)); } +/// Return an encapsulated set of information about input and output ports of +/// the specified module or instance. SmallVector hw::getModulePortInfo(Operation *op) { - assert(isAnyModule(op) && "Can only get module ports from a module"); + assert((isa(op) || isAnyModule(op)) && + "Can only get module ports from an instance or module"); + + // TODO: Remove when argNames/resultNames are stored on instances. + if (auto instance = dyn_cast(op)) + op = instance.getReferencedModule(); + SmallVector results; auto argTypes = getModuleType(op).getInputs(); @@ -737,17 +756,21 @@ static LogicalResult verifyInstanceOpTypes(InstanceOp op, } LogicalResult InstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) { - auto *referencedModule = - symbolTable.lookupNearestSymbolFrom(*this, moduleNameAttr()); - if (referencedModule == nullptr) + auto *module = symbolTable.lookupNearestSymbolFrom(*this, moduleNameAttr()); + if (module == nullptr) return emitError("Cannot find module definition '") << moduleName() << "'"; - if (!isa(referencedModule)) - return success(); + // It must be some sort of module. + if (!isAnyModule(module)) + return emitError("symbol reference '") + << moduleName() << "' isn't a module"; // If the referenced module is internal, check that input and result types are // consistent with the referenced module. - return verifyInstanceOpTypes(*this, referencedModule); + if (isa(module)) + return verifyInstanceOpTypes(*this, module); + + return success(); } static ParseResult parseInstanceOp(OpAsmParser &parser, @@ -810,9 +833,9 @@ static ParseResult parseInstanceOp(OpAsmParser &parser, static void printInstanceOp(OpAsmPrinter &p, InstanceOp op) { p << ' '; p.printAttributeWithoutType(op.instanceNameAttr()); - if (op->getAttr("sym_name")) { - p << ' ' << "sym" << ' '; - p.printSymbolName(op.sym_nameAttr().getValue()); + if (auto attr = op.sym_nameAttr()) { + p << " sym "; + p.printSymbolName(attr.getValue()); } p << ' '; p.printAttributeWithoutType(op.moduleNameAttr()); @@ -831,31 +854,47 @@ static void printInstanceOp(OpAsmPrinter &p, InstanceOp op) { p << ')'; } -StringAttr InstanceOp::getResultName(size_t idx, - const SymbolCache *symbolCache) { - auto *module = getReferencedModule(symbolCache); +/// Return the name of the specified input port or null if it cannot be +/// determined. +StringAttr InstanceOp::getArgumentName(size_t idx, const SymbolCache *cache) { + // TODO: Remove when argNames/resultNames are stored on instances. + auto *module = getReferencedModule(cache); + if (!module) + return {}; + auto argNames = module->getAttrOfType("argNames"); + // Tolerate malformed IR here to enable debug printing etc. + if (argNames && idx < argNames.size()) + return argNames[idx].cast(); + return StringAttr(); +} + +/// Return the name of the specified result or null if it cannot be +/// determined. +StringAttr InstanceOp::getResultName(size_t idx, const SymbolCache *cache) { + // TODO: Remove when argNames/resultNames are stored on instances. + auto *module = getReferencedModule(cache); if (!module) return {}; - return getModuleResultNameAttr(module, idx); + auto resultNames = module->getAttrOfType("resultNames"); + // Tolerate malformed IR here to enable debug printing etc. + if (resultNames && idx < resultNames.size()) + return resultNames[idx].cast(); + return StringAttr(); } /// Suggest a name for each result value based on the saved result names /// attribute. void InstanceOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) { - auto *module = getReferencedModule(); - if (!module) - return; - // Provide default names for instance results. std::string name = instanceName().str() + "."; size_t baseNameLen = name.size(); for (size_t i = 0, e = getNumResults(); i != e; ++i) { - auto resName = getModuleResultName(module, i); + auto resName = getResultName(i); name.resize(baseNameLen); - if (!resName.empty()) - name += resName.str(); + if (resName && !resName.getValue().empty()) + name += resName.getValue().str(); else name += std::to_string(i); setNameFn(getResult(i), name); diff --git a/test/Dialect/HW/errors.mlir b/test/Dialect/HW/errors.mlir index 408f5fd02b..9ec222267f 100644 --- a/test/Dialect/HW/errors.mlir +++ b/test/Dialect/HW/errors.mlir @@ -34,7 +34,7 @@ func private @notModule () { } hw.module @A(%arg0: i1) { - // expected-error @+1 {{'hw.instance' op attribute 'moduleName' failed to satisfy constraint: flat symbol reference attribute is module like}} + // expected-error @+1 {{symbol reference 'notModule' isn't a module}} hw.instance "foo" @notModule(%arg0) : (i1) -> () }