[mlir][ODS] Add support for specifying the namespace of an interface.

The namespace can be specified using the `cppNamespace` field. This matches the functionality already present on dialects, enums, etc. This fixes problems with using interfaces on operations in a different namespace than the interface was defined in.

Differential Revision: https://reviews.llvm.org/D83604
This commit is contained in:
River Riddle 2020-07-12 14:11:39 -07:00
parent 90c577a113
commit 572c2905ae
32 changed files with 97 additions and 51 deletions

View File

@ -22,6 +22,7 @@ def OpAsmOpInterface : OpInterface<"OpAsmOpInterface"> {
This interface provides hooks to interact with the AsmPrinter and AsmParser
classes.
}];
let cppNamespace = "::mlir";
let methods = [
InterfaceMethod<[{

View File

@ -1803,6 +1803,12 @@ class Interface<string name> {
// The name given to the c++ interface class.
string cppClassName = name;
// The C++ namespace that this interface should be placed into.
//
// To specify nested namespaces, use "::" as the delimiter, e.g., given
// "A::B", ops will be placed in `namespace A { namespace B { <def> } }`.
string cppNamespace = "";
// The list of methods defined by this interface.
list<InterfaceMethod> methods = [];
@ -1838,6 +1844,7 @@ class DeclareOpInterfaceMethods<OpInterface interface,
: OpInterface<interface.cppClassName> {
let description = interface.description;
let cppClassName = interface.cppClassName;
let cppNamespace = interface.cppNamespace;
let methods = interface.methods;
// This field contains a set of method names that should always have their

View File

@ -764,6 +764,7 @@ public:
virtual void getAsmBlockArgumentNames(Block *block,
OpAsmSetValueNameFn setNameFn) const {}
};
} // end namespace mlir
//===--------------------------------------------------------------------===//
// Operation OpAsm interface.
@ -772,6 +773,4 @@ public:
/// The OpAsmOpInterface, see OpAsmInterface.td for more details.
#include "mlir/IR/OpAsmInterface.h.inc"
} // end namespace mlir
#endif

View File

@ -27,6 +27,7 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
See [Symbols and SymbolTables](SymbolsAndSymbolTables.md) for more details
and constraints on `Symbol` operations.
}];
let cppNamespace = "::mlir";
let methods = [
InterfaceMethod<"Returns the name of this symbol.",

View File

@ -252,10 +252,9 @@ public:
};
} // end namespace OpTrait
} // end namespace mlir
/// Include the generated symbol interfaces.
#include "mlir/IR/SymbolInterfaces.h.inc"
} // end namespace mlir
#endif // MLIR_IR_SYMBOLTABLE_H

View File

@ -23,8 +23,9 @@ namespace mlir {
struct CallInterfaceCallable : public PointerUnion<SymbolRefAttr, Value> {
using PointerUnion<SymbolRefAttr, Value>::PointerUnion;
};
#include "mlir/Interfaces/CallInterfaces.h.inc"
} // end namespace mlir
/// Include the generated interface declarations.
#include "mlir/Interfaces/CallInterfaces.h.inc"
#endif // MLIR_INTERFACES_CALLINTERFACES_H

View File

@ -29,6 +29,7 @@ def CallOpInterface : OpInterface<"CallOpInterface"> {
indirect calls to other operations `call_indirect %foo`. An operation that
uses this interface, must *not* also provide the `CallableOpInterface`.
}];
let cppNamespace = "::mlir";
let methods = [
InterfaceMethod<[{
@ -70,6 +71,7 @@ def CallableOpInterface : OpInterface<"CallableOpInterface"> {
`%foo = dialect.create_function(...)`. These operations may only contain a
single region, or subroutine.
}];
let cppNamespace = "::mlir";
let methods = [
InterfaceMethod<[{

View File

@ -70,12 +70,6 @@ private:
ValueRange inputs;
};
//===----------------------------------------------------------------------===//
// ControlFlow Interfaces
//===----------------------------------------------------------------------===//
#include "mlir/Interfaces/ControlFlowInterfaces.h.inc"
//===----------------------------------------------------------------------===//
// ControlFlow Traits
//===----------------------------------------------------------------------===//
@ -101,4 +95,11 @@ struct ReturnLike : public TraitBase<ConcreteType, ReturnLike> {
} // end namespace mlir
//===----------------------------------------------------------------------===//
// ControlFlow Interfaces
//===----------------------------------------------------------------------===//
/// Include the generated interface declarations.
#include "mlir/Interfaces/ControlFlowInterfaces.h.inc"
#endif // MLIR_INTERFACES_CONTROLFLOWINTERFACES_H

View File

@ -25,6 +25,8 @@ def BranchOpInterface : OpInterface<"BranchOpInterface"> {
This interface provides information for branching terminator operations,
i.e. terminator operations with successors.
}];
let cppNamespace = "::mlir";
let methods = [
InterfaceMethod<[{
Returns a mutable range of operands that correspond to the arguments of
@ -96,6 +98,8 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
branching behavior between held regions, i.e. this interface allows for
expressing control flow information for region holding operations.
}];
let cppNamespace = "::mlir";
let methods = [
InterfaceMethod<[{
Returns the operands of this operation used as the entry arguments when

View File

@ -15,10 +15,7 @@
#include "mlir/IR/OpDefinition.h"
namespace mlir {
/// Include the generated interface declarations.
#include "mlir/Interfaces/CopyOpInterface.h.inc"
} // namespace mlir
#endif // MLIR_INTERFACES_COPYOPINTERFACE_H_

View File

@ -19,6 +19,7 @@ def CopyOpInterface : OpInterface<"CopyOpInterface"> {
let description = [{
A copy-like operation is one that copies from source value to target value.
}];
let cppNamespace = "::mlir";
let methods = [
InterfaceMethod<

View File

@ -15,8 +15,7 @@
#include "mlir/IR/OpDefinition.h"
namespace mlir {
/// Include the generated interface declarations.
#include "mlir/Interfaces/DerivedAttributeOpInterface.h.inc"
} // namespace mlir
#endif // MLIR_INTERFACES_DERIVEDATTRIBUTEOPINTERFACE_H_

View File

@ -23,6 +23,7 @@ def DerivedAttributeOpInterface : OpInterface<"DerivedAttributeOpInterface"> {
from information of the operation. ODS generates convenience accessors for
derived attributes and can be used to simplify translations.
}];
let cppNamespace = "::mlir";
let methods = [
StaticInterfaceMethod<

View File

@ -95,8 +95,6 @@ LogicalResult inferReturnTensorTypes(
LogicalResult verifyInferredResultTypes(Operation *op);
} // namespace detail
#include "mlir/Interfaces/InferTypeOpInterface.h.inc"
namespace OpTrait {
/// Tensor type inference trait that constructs a tensor from the inferred
@ -119,4 +117,7 @@ public:
} // namespace OpTrait
} // namespace mlir
/// Include the generated interface declarations.
#include "mlir/Interfaces/InferTypeOpInterface.h.inc"
#endif // MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_

View File

@ -25,6 +25,7 @@ def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> {
Interface to infer the return types for an operation that could be used
during op construction, verification or type inference.
}];
let cppNamespace = "::mlir";
let methods = [
StaticInterfaceMethod<
@ -73,6 +74,7 @@ def InferShapedTypeOpInterface : OpInterface<"InferShapedTypeOpInterface"> {
The components consists of element type, shape and raw attribute.
}];
let cppNamespace = "::mlir";
let methods = [
StaticInterfaceMethod<

View File

@ -15,10 +15,7 @@
#include "mlir/IR/OpDefinition.h"
namespace mlir {
/// Include the generated interface declarations.
#include "mlir/Interfaces/LoopLikeInterface.h.inc"
} // namespace mlir
#endif // MLIR_INTERFACES_LOOPLIKEINTERFACE_H_

View File

@ -20,6 +20,7 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
Encodes properties of a loop. Operations that implement this interface will
be considered by loop-invariant code motion.
}];
let cppNamespace = "::mlir";
let methods = [
InterfaceMethod<[{

View File

@ -215,13 +215,6 @@ struct Read : public Effect::Base<Read> {};
struct Write : public Effect::Base<Write> {};
} // namespace MemoryEffects
//===----------------------------------------------------------------------===//
// SideEffect Interfaces
//===----------------------------------------------------------------------===//
/// Include the definitions of the side effect interfaces.
#include "mlir/Interfaces/SideEffectInterfaces.h.inc"
//===----------------------------------------------------------------------===//
// SideEffect Utilities
//===----------------------------------------------------------------------===//
@ -237,4 +230,11 @@ bool wouldOpBeTriviallyDead(Operation *op);
} // end namespace mlir
//===----------------------------------------------------------------------===//
// SideEffect Interfaces
//===----------------------------------------------------------------------===//
/// Include the definitions of the side effect interfaces.
#include "mlir/Interfaces/SideEffectInterfaces.h.inc"
#endif // MLIR_INTERFACES_SIDEEFFECTS_H

View File

@ -142,6 +142,9 @@ class SideEffect<EffectOpInterfaceBase interface, string effectName,
/// The parent interface that the effect belongs to.
string interfaceTrait = interface.trait;
/// The cpp namespace of the interface trait.
string cppNamespace = interface.cppNamespace;
/// The derived effect that is being applied.
string effect = effectName;
@ -156,6 +159,9 @@ class SideEffectsTraitBase<EffectOpInterfaceBase parentInterface,
/// The name of the interface trait to use.
let trait = parentInterface.trait;
/// The cpp namespace of the interface trait.
string cppNamespace = parentInterface.cppNamespace;
/// The name of the base effects class.
string baseEffectName = parentInterface.baseEffectName;
@ -177,6 +183,7 @@ def MemoryEffectsOpInterface
An interface used to query information about the memory effects applied by
an operation.
}];
let cppNamespace = "::mlir";
}
// The base class for defining specific memory effects.

View File

@ -17,10 +17,7 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"
namespace mlir {
/// Include the generated interface declarations.
#include "mlir/Interfaces/VectorUnrollInterface.h.inc"
} // namespace mlir
#endif // MLIR_INTERFACES_VECTORUNROLLINTERFACE_H

View File

@ -19,6 +19,7 @@ def VectorUnrollOpInterface : OpInterface<"VectorUnrollOpInterface"> {
let description = [{
Encodes properties of an operation on vectors that can be unrolled.
}];
let cppNamespace = "::mlir";
let methods = [
InterfaceMethod<[{

View File

@ -15,10 +15,7 @@
#include "mlir/IR/OpDefinition.h"
namespace mlir {
/// Include the generated interface declarations.
#include "mlir/Interfaces/ViewLikeInterface.h.inc"
} // namespace mlir
#endif // MLIR_INTERFACES_VIEWLIKEINTERFACE_H_

View File

@ -21,6 +21,7 @@ def ViewLikeOpInterface : OpInterface<"ViewLikeOpInterface"> {
takes in a (view of) buffer (and potentially some other operands) and returns
another view of buffer.
}];
let cppNamespace = "::mlir";
let methods = [
InterfaceMethod<

View File

@ -76,6 +76,9 @@ public:
// Return the name of this interface.
StringRef getName() const;
// Return the C++ namespace of this interface.
StringRef getCppNamespace() const;
// Return the methods of this interface.
ArrayRef<InterfaceMethod> getMethods() const;

View File

@ -98,7 +98,7 @@ public:
OpInterface getOpInterface() const;
// Returns the trait corresponding to a C++ trait class.
StringRef getTrait() const;
std::string getTrait() const;
static bool classof(const OpTrait *t) {
return t->getKind() == Kind::Interface;

View File

@ -30,7 +30,7 @@ public:
StringRef getBaseEffectName() const;
// Return the name of the Interface that the effect belongs to.
StringRef getInterfaceTrait() const;
std::string getInterfaceTrait() const;
// Return the name of the resource class.
StringRef getResource() const;

View File

@ -84,6 +84,11 @@ StringRef Interface::getName() const {
return def->getValueAsString("cppClassName");
}
// Return the C++ namespace of this interface.
StringRef Interface::getCppNamespace() const {
return def->getValueAsString("cppNamespace");
}
// Return the methods of this interface.
ArrayRef<InterfaceMethod> Interface::getMethods() const { return methods; }

View File

@ -27,7 +27,7 @@ OpTrait OpTrait::create(const llvm::Init *init) {
return OpTrait(Kind::Pred, def);
if (def->isSubClassOf("GenInternalOpTrait"))
return OpTrait(Kind::Internal, def);
if (def->isSubClassOf("OpInterface"))
if (def->isSubClassOf("OpInterfaceTrait"))
return OpTrait(Kind::Interface, def);
assert(def->isSubClassOf("NativeOpTrait"));
return OpTrait(Kind::Native, def);
@ -56,8 +56,11 @@ OpInterface InterfaceOpTrait::getOpInterface() const {
return OpInterface(def);
}
llvm::StringRef InterfaceOpTrait::getTrait() const {
return def->getValueAsString("trait");
std::string InterfaceOpTrait::getTrait() const {
llvm::StringRef trait = def->getValueAsString("trait");
llvm::StringRef cppNamespace = def->getValueAsString("cppNamespace");
return cppNamespace.empty() ? trait.str()
: (cppNamespace + "::" + trait).str();
}
bool InterfaceOpTrait::shouldDeclareMethods() const {

View File

@ -336,7 +336,7 @@ void tblgen::Operator::populateTypeInferenceInfo(
llvm::formatv("{0}::Trait", inferTypeOpInterface).str()))
return;
if (const auto *opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&trait))
if (opTrait->getTrait().startswith(inferTypeOpInterface))
if (&opTrait->getDef() == inferTrait)
return;
if (!def.isSubClassOf("AllTypesMatch"))

View File

@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/SideEffects.h"
#include "llvm/ADT/Twine.h"
#include "llvm/TableGen/Record.h"
using namespace mlir;
@ -24,8 +25,11 @@ StringRef SideEffect::getBaseEffectName() const {
return def->getValueAsString("baseEffectName");
}
StringRef SideEffect::getInterfaceTrait() const {
return def->getValueAsString("interfaceTrait");
std::string SideEffect::getInterfaceTrait() const {
StringRef trait = def->getValueAsString("interfaceTrait");
StringRef cppNamespace = def->getValueAsString("cppNamespace");
return cppNamespace.empty() ? trait.str()
: (cppNamespace + "::" + trait).str();
}
StringRef SideEffect::getResource() const {

View File

@ -887,7 +887,8 @@ static bool canGenerateUnwrappedBuilder(Operator &op) {
}
static bool canInferType(Operator &op) {
return op.getTrait("InferTypeOpInterface::Trait") && op.getNumRegions() == 0;
return op.getTrait("::mlir::InferTypeOpInterface::Trait") &&
op.getNumRegions() == 0;
}
void OpEmitter::genSeparateArgParamBuilder() {
@ -1917,7 +1918,7 @@ void OpEmitter::genOpAsmInterface() {
// TODO: We could also add a flag to allow operations to opt in to this
// generation, even if they only have a single operation.
int numResults = op.getNumResults();
if (numResults <= 1 || op.getTrait("OpAsmOpInterface::Trait"))
if (numResults <= 1 || op.getTrait("::mlir::OpAsmOpInterface::Trait"))
return;
SmallVector<StringRef, 4> resultNames(numResults);
@ -1927,7 +1928,7 @@ void OpEmitter::genOpAsmInterface() {
// Don't add the trait if none of the results have a valid name.
if (llvm::all_of(resultNames, [](StringRef name) { return name.empty(); }))
return;
opClass.addTrait("OpAsmOpInterface::Trait");
opClass.addTrait("::mlir::OpAsmOpInterface::Trait");
// Generate the right accessor for the number of results.
auto &method = opClass.newMethod("void", "getAsmResultNames",

View File

@ -150,11 +150,16 @@ struct TypeInterfaceGenerator : public InterfaceGenerator {
static void emitInterfaceDef(Interface interface, StringRef valueType,
raw_ostream &os) {
StringRef interfaceName = interface.getName();
StringRef cppNamespace = interface.getCppNamespace();
cppNamespace.consume_front("::");
// Insert the method definitions.
bool isOpInterface = isa<OpInterface>(interface);
for (auto &method : interface.getMethods()) {
emitCPPType(method.getReturnType(), os) << interfaceName << "::";
emitCPPType(method.getReturnType(), os);
if (!cppNamespace.empty())
os << cppNamespace << "::";
os << interfaceName << "::";
emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
/*addConst=*/!isOpInterface);
@ -287,6 +292,11 @@ void InterfaceGenerator::emitTraitDecl(Interface &interface,
}
void InterfaceGenerator::emitInterfaceDecl(Interface interface) {
llvm::SmallVector<StringRef, 2> namespaces;
llvm::SplitString(interface.getCppNamespace(), namespaces, "::");
for (StringRef ns : namespaces)
os << "namespace " << ns << " {\n";
StringRef interfaceName = interface.getName();
auto interfaceTraitsName = (interfaceName + "InterfaceTraits").str();
@ -321,6 +331,9 @@ void InterfaceGenerator::emitInterfaceDecl(Interface interface) {
os << *extraDecls << "\n";
os << "};\n";
for (StringRef ns : llvm::reverse(namespaces))
os << "} // namespace " << ns << "\n";
}
bool InterfaceGenerator::emitInterfaceDecls() {