[HW] [MSFT] Remove unnecessary assertions in `getReferencedModule` and add SymbolOpUserInterface to MSFT InstanceOp. (#1787)

Remove a lot of assertions and checks that are validated by the ODS and `verifyInstanceOp` function. Secondly,
use SymbolOpUserInterface in MSFT dialect to avoid O(n) lookup when verifying the InstanceOp.
This commit is contained in:
Chris Gyurgyik 2021-09-14 17:30:02 -07:00 committed by GitHub
parent d3164681db
commit ddf3b84fba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 25 additions and 13 deletions

View File

@ -16,6 +16,7 @@
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Pass/PassBase.td"
def MSFTDialect : Dialect {

View File

@ -11,7 +11,9 @@ class MSFTOp<string mnemonic, list<OpTrait> traits = []> :
Op<MSFTDialect, mnemonic, traits>;
def InstanceOp : MSFTOp<"instance", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]> ]> {
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
DeclareOpInterfaceMethods<SymbolUserOpInterface>
]> {
let summary = "Create an instance of a module";
let arguments = (ins StrAttr:$instanceName,

View File

@ -754,16 +754,12 @@ Operation *InstanceOp::getReferencedModule(const SymbolCache *cache) {
return result;
auto topLevelModuleOp = (*this)->getParentOfType<ModuleOp>();
if (!topLevelModuleOp)
return nullptr;
return topLevelModuleOp.lookupSymbol(moduleName());
}
// Helper function to verify instance op types
static LogicalResult verifyInstanceOpTypes(InstanceOp op, Operation *module) {
assert(module && "referenced module must not be null");
// Make sure our port and result names match.
ArrayAttr argNames = op.argNamesAttr();
ArrayAttr modArgNames = module->getAttrOfType<ArrayAttr>("argNames");

View File

@ -24,18 +24,12 @@ using namespace msft;
/// invalid IR.
Operation *InstanceOp::getReferencedModule() {
auto topLevelModuleOp = (*this)->getParentOfType<ModuleOp>();
if (!topLevelModuleOp)
return nullptr;
assert(topLevelModuleOp && "Required to have a ModuleOp parent.");
return topLevelModuleOp.lookupSymbol(moduleName());
}
StringAttr InstanceOp::getResultName(size_t idx) {
auto *module = getReferencedModule();
if (!module)
return {};
return hw::getModuleResultNameAttr(module, idx);
return hw::getModuleResultNameAttr(getReferencedModule(), idx);
}
/// Suggest a name for each result value based on the saved result names
@ -56,5 +50,17 @@ void InstanceOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
}
}
LogicalResult InstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
auto *module = symbolTable.lookupNearestSymbolFrom(*this, moduleNameAttr());
if (module == nullptr)
return emitError("Cannot find module definition '") << moduleName() << "'";
// It must be some sort of module.
if (!hw::isAnyModule(module))
return emitError("symbol reference '")
<< moduleName() << "' isn't a module";
return success();
}
#define GET_OP_CLASSES
#include "circt/Dialect/MSFT/MSFT.cpp.inc"

View File

@ -13,3 +13,10 @@ module {
// expected-error @+1 {{Unexpected msft attribute 'foo'}}
hw.instance "foo1" @Foo() -> () {"loc:" = #msft.foo<""> }
}
// -----
hw.module @M() {
// expected-error @+1 {{Cannot find module definition 'Bar'}}
msft.instance "instance" @Bar () : () -> ()
}