[mlir] Refactor InterfaceGen to support generating interfaces for Attributes and Types.

This revision adds support to ODS for generating interfaces for attributes and types, in addition to operations. These interfaces can be specified using `AttrInterface` and `TypeInterface` in place of `OpInterface`. All of the features of `OpInterface` are supported except for the `verify` method, which does not have a matching representation in the Attribute/Type world. Generating these interface can be done using `gen-(attr|type)-interface-(defs|decls|docs)`.

Differential Revision: https://reviews.llvm.org/D81884
This commit is contained in:
River Riddle 2020-06-30 15:42:52 -07:00
parent 9fbb2de8e4
commit 2e2cdd0a52
28 changed files with 730 additions and 383 deletions

View File

@ -21,16 +21,16 @@ transformations/analyses.
### Dialect Interfaces
Dialect interfaces are generally useful for transformation passes or
analyses that want to operate generically on a set of operations,
which might even be defined in different dialects. These
interfaces generally involve wide coverage over the entire dialect and are only
used for a handful of transformations/analyses. In these cases, registering the
interface directly on each operation is overly complex and cumbersome. The
interface is not core to the operation, just to the specific transformation. An
example of where this type of interface would be used is inlining. Inlining
generally queries high-level information about the operations within a dialect,
like legality and cost modeling, that often is not specific to one operation.
Dialect interfaces are generally useful for transformation passes or analyses
that want to operate generically on a set of attributes/operations/types, which
might even be defined in different dialects. These interfaces generally involve
wide coverage over the entire dialect and are only used for a handful of
transformations/analyses. In these cases, registering the interface directly on
each operation is overly complex and cumbersome. The interface is not core to
the operation, just to the specific transformation. An example of where this
type of interface would be used is inlining. Inlining generally queries
high-level information about the operations within a dialect, like legality and
cost modeling, that often is not specific to one operation.
A dialect interface can be defined by inheriting from the CRTP base class
`DialectInterfaceBase::Base`. This class provides the necessary utilities for
@ -106,24 +106,25 @@ if(!interface.isLegalToInline(...))
...
```
### Operation Interfaces
### Attribute/Operation/Type Interfaces
Operation interfaces, as the name suggests, are those registered at the
Operation level. These interfaces provide access to derived operations
by providing a virtual interface that must be implemented. As an example, the
`Linalg` dialect may implement an interface that provides general queries about
some of the dialects library operations. These queries may provide things like:
the number of parallel loops; the number of inputs and outputs; etc.
Attribute/Operation/Type interfaces, as the names suggest, are those registered
at the level of a specific attribute/operation/type. These interfaces provide
access to derived objects by providing a virtual interface that must be
implemented. As an example, the `Linalg` dialect may implement an interface that
provides general queries about some of the dialects library operations. These
queries may provide things like: the number of parallel loops; the number of
inputs and outputs; etc.
Operation interfaces are defined by overriding the CRTP base class
`OpInterface`. This class takes, as a template parameter, a `Traits` class that
defines a `Concept` and a `Model` class. These classes provide an implementation
of concept-based polymorphism, where the Concept defines a set of virtual
methods that are overridden by the Model that is templated on the concrete
operation type. It is important to note that these classes should be pure in
that they contain no non-static data members. Operations that wish to override
this interface should add the provided trait `OpInterface<..>::Trait` upon
registration.
These interfaces are defined by overriding the CRTP base class `AttrInterface`,
`OpInterface`, or `TypeInterface` respectively. These classes take, as a
template parameter, a `Traits` class that defines a `Concept` and a `Model`
class. These classes provide an implementation of concept-based polymorphism,
where the Concept defines a set of virtual methods that are overridden by the
Model that is templated on the concrete object type. It is important to note
that these classes should be pure in that they contain no non-static data
members. Objects that wish to override this interface should add the provided
trait `*Interface<..>::Trait` to the trait list upon registration.
```c++
struct ExampleOpInterfaceTraits {
@ -182,8 +183,7 @@ if (ExampleOpInterface example = dyn_cast<ExampleOpInterface>(op))
Operation interfaces require a bit of boiler plate to connect all of the pieces
together. The ODS(Operation Definition Specification) framework provides
simplified mechanisms for
[defining interfaces](OpDefinitions.md#operation-interfaces).
simplified mechanisms for [defining interfaces](OpDefinitions.md#interfaces).
As an example, using the ODS framework would allow for defining the example
interface above as:

View File

@ -346,20 +346,20 @@ involving multiple operands/attributes/results are provided as the second
template parameter to the `Op` class. They should be deriving from the `OpTrait`
class. See [Constraints](#constraints) for more information.
### Operation interfaces
### Interfaces
[Operation interfaces](Interfaces.md#operation-interfaces) allow
operations to expose method calls without the
caller needing to know the exact operation type. Operation interfaces
defined in C++ can be accessed in the ODS framework via the
`OpInterfaceTrait` class. Aside from using pre-existing interfaces in
the C++ API, the ODS framework also provides a simplified mechanism
for defining such interfaces which removes much of the boilerplate
necessary.
[Interfaces](Interfaces.md#attribute-operation-type-interfaces) allow for
attributes, operations, and types to expose method calls without the caller
needing to know the derived type. Operation interfaces defined in C++ can be
accessed in the ODS framework via the `OpInterfaceTrait` class. Aside from using
pre-existing interfaces in the C++ API, the ODS framework also provides a
simplified mechanism for defining such interfaces which removes much of the
boilerplate necessary.
Providing a definition of the `OpInterface` class will auto-generate the C++
classes for the interface. An `OpInterface` includes a name, for the C++ class,
a description, and a list of interface methods.
Providing a definition of the `AttrInterface`, `OpInterface`, or `TypeInterface`
class will auto-generate the C++ classes for the interface. An interface
includes a name, for the C++ class, a description, and a list of interface
methods.
```tablegen
def MyInterface : OpInterface<"MyInterface"> {
@ -450,10 +450,11 @@ def MyInterface : OpInterface<"MyInterface"> {
];
}
// Interfaces can optionally be wrapped inside DeclareOpInterfaceMethods. This
// would result in autogenerating declarations for members `foo`, `bar` and
// `fooStatic`. Methods with bodies are not declared inside the op
// declaration but instead handled by the op interface trait directly.
// Operation interfaces can optionally be wrapped inside
// DeclareOpInterfaceMethods. This would result in autogenerating declarations
// for members `foo`, `bar` and `fooStatic`. Methods with bodies are not
// declared inside the op declaration but instead handled by the op interface
// trait directly.
def OpWithInferTypeInterfaceOp : Op<...
[DeclareOpInterfaceMethods<MyInterface>]> { ... }
@ -465,9 +466,9 @@ def OpWithOverrideInferTypeInterfaceOp : Op<...
[DeclareOpInterfaceMethods<MyInterface, ["getNumWithDefault"]>]> { ... }
```
A verification method can also be specified on the `OpInterface` by setting
`verify`. Setting `verify` results in the generated trait having a `verifyTrait`
method that is applied to all operations implementing the trait.
Operation interfaces may also provide a verification method on `OpInterface` by
setting `verify`. Setting `verify` results in the generated trait having a
`verifyTrait` method that is applied to all operations implementing the trait.
### Builder methods

View File

@ -1,24 +1,28 @@
# Operation Traits
# Traits
[TOC]
MLIR allows for a truly open operation ecosystem, as any dialect may define
operations that suit a specific level of abstraction. `Traits` are a mechanism
which abstracts implementation details and properties that are common
across many different operations. `Traits` may be used to specify special
properties and constraints of the operation, including whether the operation has
side effects or whether its output has the same type as the input. Some examples
of traits are `Commutative`, `SingleResult`, `Terminator`, etc. See the more
[comprehensive list](#trait-list) below for more examples of what is possible.
MLIR allows for a truly open ecosystem, as any dialect may define attributes,
operations, and types that suit a specific level of abstraction. `Traits` are a
mechanism which abstracts implementation details and properties that are common
across many different attributes/operations/types/etc.. `Traits` may be used to
specify special properties and constraints of the object, including whether an
operation has side effects or that its output has the same type as the input.
Some examples of operation traits are `Commutative`, `SingleResult`,
`Terminator`, etc. See the more comprehensive list of
[operation traits](#operation-traits-list) below for more examples of what is
possible.
## Defining a Trait
Traits may be defined in C++ by inheriting from the
`OpTrait::TraitBase<ConcreteType, TraitType>` class. This base class takes as
template parameters:
Traits may be defined in C++ by inheriting from the `TraitBase<ConcreteType,
TraitType>` class for the specific IR type. For attributes, this is
`AttributeTrait::TraitBase`. For operations, this is `OpTrait::TraitBase`. For
types, this is `TypeTrait::TraitBase`. This base class takes as template
parameters:
* ConcreteType
- The concrete operation type that this trait was attached to.
- The concrete class type that this trait was attached to.
* TraitType
- The type of the trait class that is being defined, for use with the
[`Curiously Recurring Template Pattern`](https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern).
@ -28,11 +32,11 @@ the `ConcreteType`. An example trait definition is shown below:
```c++
template <typename ConcreteType>
class MyTrait : public OpTrait::TraitBase<ConcreteType, MyTrait> {
class MyTrait : public TraitBase<ConcreteType, MyTrait> {
};
```
Derived traits may also provide a `verifyTrait` hook, that is called when
Operation traits may also provide a `verifyTrait` hook, that is called when
verifying the concrete operation. The trait verifiers will currently always be
invoked before the main `Op::verify`.
@ -57,15 +61,15 @@ instantiating the implementation for every concrete operation type.
The above demonstrates the definition of a simple self-contained trait. It is
also often useful to provide some static parameters to the trait to control its
behavior. Given that the definition of the trait class is rigid, i.e. we must
have a single template argument for the concrete operation, the templates for
the parameters will need to be split out. An example is shown below:
have a single template argument for the concrete object, the templates for the
parameters will need to be split out. An example is shown below:
```c++
template <int Parameter>
class MyParametricTrait {
public:
template <typename ConcreteType>
class Impl : public OpTrait::TraitBase<ConcreteType, Impl> {
class Impl : public TraitBase<ConcreteType, Impl> {
// Inside of 'Impl' we have full access to the template parameters
// specified above.
};
@ -74,19 +78,28 @@ public:
## Attaching a Trait
Traits may be used when defining a derived operation type, by simply adding the
name of the trait class to the `Op` class after the concrete operation type:
Traits may be used when defining a derived object type, by simply appending the
name of the trait class to the end of the base object class operation type:
```c++
/// Here we define 'MyAttr' along with the 'MyTrait' and `MyParametric trait
/// classes we defined previously.
class MyAttr : public Attribute::AttrBase<MyAttr, ..., MyTrait, MyParametricTrait<10>::Impl> {};
/// Here we define 'MyOp' along with the 'MyTrait' and `MyParametric trait
/// classes we defined previously.
class MyOp : public Op<MyOp, MyTrait, MyParametricTrait<10>::Impl> {};
/// Here we define 'MyType' along with the 'MyTrait' and `MyParametric trait
/// classes we defined previously.
class MyType : public Type::TypeBase<MyType, ..., MyTrait, MyParametricTrait<10>::Impl> {};
```
To use a trait in the [ODS](OpDefinitions.md) framework, we need to provide a
definition of the trait class. This can be done using the `NativeOpTrait` and
`ParamNativeOpTrait` classes. `ParamNativeOpTrait` provides a mechanism in which
to specify arguments to a parametric trait class with an internal `Impl`.
### Attaching Operation Traits in ODS
To use an operation trait in the [ODS](OpDefinitions.md) framework, we need to
provide a definition of the trait class. This can be done using the
`NativeOpTrait` and `ParamNativeOpTrait` classes. `ParamNativeOpTrait` provides
a mechanism in which to specify arguments to a parametric trait class with an
internal `Impl`.
```tablegen
// The argument is the c++ trait class name.
@ -110,14 +123,14 @@ details.
## Using a Trait
Traits may be used to provide additional methods, static fields, or other
information directly on the concrete operation. `Traits` internally become
`Base` classes of the concrete operation, so all of these are directly
accessible. To expose this information opaquely to transformations and analyses,
information directly on the concrete object. `Traits` internally become `Base`
classes of the concrete operation, so all of these are directly accessible. To
expose this information opaquely to transformations and analyses,
[`interfaces`](Interfaces.md) may be used.
To query if a specific operation contains a specific trait, the `hasTrait<>`
method may be used. This takes as a template parameter the trait class, which is
the same as the one passed when attaching the trait to an operation.
To query if a specific object contains a specific trait, the `hasTrait<>` method
may be used. This takes as a template parameter the trait class, which is the
same as the one passed when attaching the trait to an operation.
```c++
Operation *op = ..;
@ -125,7 +138,7 @@ if (op->hasTrait<MyTrait>() || op->hasTrait<MyParametricTrait<10>::Impl>())
...;
```
## Trait List
## Operation Traits List
MLIR provides a suite of traits that provide various functionalities that are
common across many different operations. Below is a list of some key traits that

View File

@ -199,11 +199,11 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
"Operation *", "clone",
(ins "OpBuilder &":$b, "Location":$loc, "ValueRange":$operands), [{
BlockAndValueMapping map;
unsigned numRegions = op.getOperation()->getNumRegions();
Operation *res = create(b, loc, operands, op.getAttrs());
unsigned numRegions = $_op.getOperation()->getNumRegions();
Operation *res = create(b, loc, operands, $_op.getAttrs());
assert(res->getNumRegions() == numRegions && "inconsistent # regions");
for (unsigned ridx = 0; ridx < numRegions; ++ridx)
op.getOperation()->getRegion(ridx).cloneInto(
$_op.getOperation()->getRegion(ridx).cloneInto(
&res->getRegion(ridx), map);
return res;
}]

View File

@ -1761,8 +1761,8 @@ class OpInterfaceTrait<string name, code verifyBody = [{}]>
}
// This class represents a single, optionally static, interface method.
// Note: non-static interface methods have an implicit 'op' parameter
// corresponding to an instance of the derived operation.
// Note: non-static interface methods have an implicit parameter, either
// $_op/$_attr/$_type corresponding to an instance of the derived value.
class InterfaceMethod<string desc, string retTy, string methodName,
dag args = (ins), code methodBody = [{}],
code defaultImplementation = [{}]> {
@ -1792,8 +1792,8 @@ class StaticInterfaceMethod<string desc, string retTy, string methodName,
: InterfaceMethod<desc, retTy, methodName, args, methodBody,
defaultImplementation>;
// OpInterface represents an interface regarding an op.
class OpInterface<string name> : OpInterfaceTrait<name> {
// Interface represents a base interface.
class Interface<string name> {
// A human-readable description of what this interface does.
string description = "";
@ -1808,6 +1808,23 @@ class OpInterface<string name> : OpInterfaceTrait<name> {
code extraClassDeclaration = "";
}
// AttrInterface represents an interface registered to an attribute.
class AttrInterface<string name> : Interface<name> {
// An optional code block containing extra declarations to place in the
// interface trait declaration.
code extraTraitClassDeclaration = "";
}
// OpInterface represents an interface registered to an operation.
class OpInterface<string name> : Interface<name>, OpInterfaceTrait<name>;
// TypeInterface represents an interface registered to a type.
class TypeInterface<string name> : Interface<name> {
// An optional code block containing extra declarations to place in the
// interface trait declaration.
code extraTraitClassDeclaration = "";
}
// Whether to declare the op interface methods in the op's header. This class
// simply wraps an OpInterface but is used to indicate that the method
// declarations should be generated. This class takes an optional set of methods

View File

@ -33,7 +33,7 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
"StringRef", "getName", (ins), [{
// Don't rely on the trait implementation as optional symbol operations
// may override this.
return mlir::SymbolTable::getSymbolName(op);
return mlir::SymbolTable::getSymbolName($_op);
}], /*defaultImplementation=*/[{
return mlir::SymbolTable::getSymbolName(this->getOperation());
}]

View File

@ -51,9 +51,9 @@ def CallOpInterface : OpInterface<"CallOpInterface"> {
}],
"Operation *", "resolveCallable", (ins), [{
// If the callable isn't a value, lookup the symbol reference.
CallInterfaceCallable callable = op.getCallableForCallee();
CallInterfaceCallable callable = $_op.getCallableForCallee();
if (auto symbolRef = callable.dyn_cast<SymbolRefAttr>())
return SymbolTable::lookupNearestSymbolFrom(op, symbolRef);
return SymbolTable::lookupNearestSymbolFrom($_op, symbolRef);
return callable.get<Value>().getDefiningOp();
}]
>,

View File

@ -55,10 +55,10 @@ def BranchOpInterface : OpInterface<"BranchOpInterface"> {
}],
"Optional<BlockArgument>", "getSuccessorBlockArgument",
(ins "unsigned":$operandIndex), [{
Operation *opaqueOp = op;
Operation *opaqueOp = $_op;
for (unsigned i = 0, e = opaqueOp->getNumSuccessors(); i != e; ++i) {
if (Optional<BlockArgument> arg = detail::getBranchSuccessorArgument(
op.getSuccessorOperands(i), operandIndex,
$_op.getSuccessorOperands(i), operandIndex,
opaqueOp->getSuccessor(i)))
return arg;
}

View File

@ -61,7 +61,7 @@ class EffectOpInterfaceBase<string name, string baseEffect>
(ins "Value":$value,
"SmallVectorImpl<SideEffects::EffectInstance<"
# baseEffect # ">> &":$effects), [{
op.getEffects(effects);
$_op.getEffects(effects);
llvm::erase_if(effects, [&](auto &it) {
return it.getValue() != value;
});
@ -75,7 +75,7 @@ class EffectOpInterfaceBase<string name, string baseEffect>
(ins "SideEffects::Resource *":$resource,
"SmallVectorImpl<SideEffects::EffectInstance<"
# baseEffect # ">> &":$effects), [{
op.getEffects(effects);
$_op.getEffects(effects);
llvm::erase_if(effects, [&](auto &it) {
return it.getResource() != resource;
});

View File

@ -1,17 +1,13 @@
//===- OpInterfaces.h - OpInterfaces wrapper class --------------*- C++ -*-===//
//===- Interfaces.h - Interface wrapper classes -----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// OpInterfaces wrapper to simplify using TableGen OpInterfaces.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TABLEGEN_OPINTERFACES_H_
#define MLIR_TABLEGEN_OPINTERFACES_H_
#ifndef MLIR_TABLEGEN_INTERFACES_H_
#define MLIR_TABLEGEN_INTERFACES_H_
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SmallVector.h"
@ -25,9 +21,9 @@ class Record;
namespace mlir {
namespace tblgen {
// Wrapper class with helper methods for accessing OpInterfaceMethod defined
// Wrapper class with helper methods for accessing InterfaceMethod defined
// in TableGen.
class OpInterfaceMethod {
class InterfaceMethod {
public:
// This struct represents a single method argument.
struct Argument {
@ -35,7 +31,7 @@ public:
StringRef name;
};
explicit OpInterfaceMethod(const llvm::Record *def);
explicit InterfaceMethod(const llvm::Record *def);
// Return the return type of this method.
StringRef getReturnType() const;
@ -68,20 +64,20 @@ private:
};
//===----------------------------------------------------------------------===//
// OpInterface
// Interface
//===----------------------------------------------------------------------===//
// Wrapper class with helper methods for accessing OpInterfaces defined in
// Wrapper class with helper methods for accessing Interfaces defined in
// TableGen.
class OpInterface {
class Interface {
public:
explicit OpInterface(const llvm::Record *def);
explicit Interface(const llvm::Record *def);
// Return the name of this interface.
StringRef getName() const;
// Return the methods of this interface.
ArrayRef<OpInterfaceMethod> getMethods() const;
ArrayRef<InterfaceMethod> getMethods() const;
// Return the description of this method if it has one.
llvm::Optional<StringRef> getDescription() const;
@ -95,15 +91,36 @@ public:
// Return the verify method body if it has one.
llvm::Optional<StringRef> getVerify() const;
// Returns the Tablegen definition this interface was constructed from.
const llvm::Record &getDef() const { return *def; }
private:
// The TableGen definition of this interface.
const llvm::Record *def;
// The methods of this interface.
SmallVector<OpInterfaceMethod, 8> methods;
SmallVector<InterfaceMethod, 8> methods;
};
// An interface that is registered to an Attribute.
struct AttrInterface : public Interface {
using Interface::Interface;
static bool classof(const Interface *interface);
};
// An interface that is registered to an Operation.
struct OpInterface : public Interface {
using Interface::Interface;
static bool classof(const Interface *interface);
};
// An interface that is registered to a Type.
struct TypeInterface : public Interface {
using Interface::Interface;
static bool classof(const Interface *interface);
};
} // end namespace tblgen
} // end namespace mlir
#endif // MLIR_TABLEGEN_OPINTERFACES_H_
#endif // MLIR_TABLEGEN_INTERFACES_H_

View File

@ -25,7 +25,7 @@ class Record;
namespace mlir {
namespace tblgen {
class OpInterface;
struct OpInterface;
// Wrapper class with helper methods for accessing OpTrait constraints defined
// in TableGen.

View File

@ -14,9 +14,9 @@ llvm_add_library(MLIRTableGen STATIC
Constraint.cpp
Dialect.cpp
Format.cpp
Interfaces.cpp
Operator.cpp
OpClass.cpp
OpInterfaces.cpp
OpTrait.cpp
Pass.cpp
Pattern.cpp

View File

@ -1,16 +1,12 @@
//===- OpInterfaces.cpp - OpInterfaces class ------------------------------===//
//===- Interfaces.cpp - Interface classes ---------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// OpInterfaces wrapper to simplify using TableGen OpInterfaces.
//
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/OpInterfaces.h"
#include "mlir/TableGen/Interfaces.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"
@ -19,7 +15,11 @@
using namespace mlir;
using namespace mlir::tblgen;
OpInterfaceMethod::OpInterfaceMethod(const llvm::Record *def) : def(def) {
//===----------------------------------------------------------------------===//
// InterfaceMethod
//===----------------------------------------------------------------------===//
InterfaceMethod::InterfaceMethod(const llvm::Record *def) : def(def) {
llvm::DagInit *args = def->getValueAsDag("arguments");
for (unsigned i = 0, e = args->getNumArgs(); i != e; ++i) {
arguments.push_back(
@ -28,78 +28,112 @@ OpInterfaceMethod::OpInterfaceMethod(const llvm::Record *def) : def(def) {
}
}
StringRef OpInterfaceMethod::getReturnType() const {
StringRef InterfaceMethod::getReturnType() const {
return def->getValueAsString("returnType");
}
// Return the name of this method.
StringRef OpInterfaceMethod::getName() const {
StringRef InterfaceMethod::getName() const {
return def->getValueAsString("name");
}
// Return if this method is static.
bool OpInterfaceMethod::isStatic() const {
bool InterfaceMethod::isStatic() const {
return def->isSubClassOf("StaticInterfaceMethod");
}
// Return the body for this method if it has one.
llvm::Optional<StringRef> OpInterfaceMethod::getBody() const {
llvm::Optional<StringRef> InterfaceMethod::getBody() const {
auto value = def->getValueAsString("body");
return value.empty() ? llvm::Optional<StringRef>() : value;
}
// Return the default implementation for this method if it has one.
llvm::Optional<StringRef> OpInterfaceMethod::getDefaultImplementation() const {
llvm::Optional<StringRef> InterfaceMethod::getDefaultImplementation() const {
auto value = def->getValueAsString("defaultBody");
return value.empty() ? llvm::Optional<StringRef>() : value;
}
// Return the description of this method if it has one.
llvm::Optional<StringRef> OpInterfaceMethod::getDescription() const {
llvm::Optional<StringRef> InterfaceMethod::getDescription() const {
auto value = def->getValueAsString("description");
return value.empty() ? llvm::Optional<StringRef>() : value;
}
ArrayRef<OpInterfaceMethod::Argument> OpInterfaceMethod::getArguments() const {
ArrayRef<InterfaceMethod::Argument> InterfaceMethod::getArguments() const {
return arguments;
}
bool OpInterfaceMethod::arg_empty() const { return arguments.empty(); }
bool InterfaceMethod::arg_empty() const { return arguments.empty(); }
//===----------------------------------------------------------------------===//
// Interface
//===----------------------------------------------------------------------===//
Interface::Interface(const llvm::Record *def) : def(def) {
assert(def->isSubClassOf("Interface") &&
"must be subclass of TableGen 'Interface' class");
OpInterface::OpInterface(const llvm::Record *def) : def(def) {
auto *listInit = dyn_cast<llvm::ListInit>(def->getValueInit("methods"));
for (llvm::Init *init : listInit->getValues())
methods.emplace_back(cast<llvm::DefInit>(init)->getDef());
}
// Return the name of this interface.
StringRef OpInterface::getName() const {
StringRef Interface::getName() const {
return def->getValueAsString("cppClassName");
}
// Return the methods of this interface.
ArrayRef<OpInterfaceMethod> OpInterface::getMethods() const { return methods; }
ArrayRef<InterfaceMethod> Interface::getMethods() const { return methods; }
// Return the description of this method if it has one.
llvm::Optional<StringRef> OpInterface::getDescription() const {
llvm::Optional<StringRef> Interface::getDescription() const {
auto value = def->getValueAsString("description");
return value.empty() ? llvm::Optional<StringRef>() : value;
}
// Return the interfaces extra class declaration code.
llvm::Optional<StringRef> OpInterface::getExtraClassDeclaration() const {
llvm::Optional<StringRef> Interface::getExtraClassDeclaration() const {
auto value = def->getValueAsString("extraClassDeclaration");
return value.empty() ? llvm::Optional<StringRef>() : value;
}
// Return the traits extra class declaration code.
llvm::Optional<StringRef> OpInterface::getExtraTraitClassDeclaration() const {
llvm::Optional<StringRef> Interface::getExtraTraitClassDeclaration() const {
auto value = def->getValueAsString("extraTraitClassDeclaration");
return value.empty() ? llvm::Optional<StringRef>() : value;
}
// Return the body for this method if it has one.
llvm::Optional<StringRef> OpInterface::getVerify() const {
llvm::Optional<StringRef> Interface::getVerify() const {
// Only OpInterface supports the verify method.
if (!isa<OpInterface>(this))
return llvm::None;
auto value = def->getValueAsString("verify");
return value.empty() ? llvm::Optional<StringRef>() : value;
}
//===----------------------------------------------------------------------===//
// AttrInterface
//===----------------------------------------------------------------------===//
bool AttrInterface::classof(const Interface *interface) {
return interface->getDef().isSubClassOf("AttrInterface");
}
//===----------------------------------------------------------------------===//
// OpInterface
//===----------------------------------------------------------------------===//
bool OpInterface::classof(const Interface *interface) {
return interface->getDef().isSubClassOf("OpInterface");
}
//===----------------------------------------------------------------------===//
// TypeInterface
//===----------------------------------------------------------------------===//
bool TypeInterface::classof(const Interface *interface) {
return interface->getDef().isSubClassOf("TypeInterface");
}

View File

@ -11,7 +11,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/OpTrait.h"
#include "mlir/TableGen/OpInterfaces.h"
#include "mlir/TableGen/Interfaces.h"
#include "mlir/TableGen/Predicate.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"

View File

@ -3,6 +3,11 @@ set(LLVM_OPTIONAL_SOURCES
TestPatterns.cpp
)
set(LLVM_TARGET_DEFINITIONS TestInterfaces.td)
mlir_tablegen(TestTypeInterfaces.h.inc -gen-type-interface-decls)
mlir_tablegen(TestTypeInterfaces.cpp.inc -gen-type-interface-defs)
add_public_tablegen_target(MLIRTestInterfaceIncGen)
set(LLVM_TARGET_DEFINITIONS TestOps.td)
mlir_tablegen(TestOps.h.inc -gen-op-decls)
mlir_tablegen(TestOps.cpp.inc -gen-op-defs)
@ -22,6 +27,7 @@ add_mlir_library(MLIRTestDialect
EXCLUDE_FROM_LIBMLIR
DEPENDS
MLIRTestInterfaceIncGen
MLIRTestOpsIncGen
LINK_LIBS PUBLIC

View File

@ -7,7 +7,9 @@
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
#include "TestTypes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/PatternMatch.h"
@ -135,9 +137,21 @@ TestDialect::TestDialect(MLIRContext *context)
>();
addInterfaces<TestOpAsmInterface, TestOpFolderDialectInterface,
TestInlinerInterface>();
addTypes<TestType>();
allowUnknownOperations();
}
Type TestDialect::parseType(DialectAsmParser &parser) const {
if (failed(parser.parseKeyword("test_type")))
return Type();
return TestType::get(getContext());
}
void TestDialect::printType(Type type, DialectAsmPrinter &printer) const {
assert(type.isa<TestType>() && "unexpected type");
printer << "test_type";
}
LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
NamedAttribute namedAttr) {
if (namedAttr.first == "test.invalid_attr")
@ -598,6 +612,7 @@ static mlir::DialectRegistration<mlir::TestDialect> testDialect;
#include "TestOpEnums.cpp.inc"
#include "TestOpStructs.cpp.inc"
#include "TestTypeInterfaces.cpp.inc"
#define GET_OP_CLASSES
#include "TestOps.cpp.inc"

View File

@ -0,0 +1,46 @@
//===-- TestInterfaces.td - Test dialect interfaces --------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef TEST_INTERFACES
#define TEST_INTERFACES
include "mlir/IR/OpBase.td"
// A type interface used to test the ODS generation of type interfaces.
def TestTypeInterface : TypeInterface<"TestTypeInterface"> {
let methods = [
InterfaceMethod<"Prints the type name.",
"void", "printTypeA", (ins "Location":$loc), [{
emitRemark(loc) << $_type << " - TestA";
}]
>,
InterfaceMethod<"Prints the type name.",
"void", "printTypeB", (ins "Location":$loc),
[{}], /*defaultImplementation=*/[{
emitRemark(loc) << $_type << " - TestB";
}]
>,
InterfaceMethod<"Prints the type name.",
"void", "printTypeC", (ins "Location":$loc)
>,
];
let extraClassDeclaration = [{
/// Prints the type name.
void printTypeD(Location loc) const {
emitRemark(loc) << *this << " - TestD";
}
}];
let extraTraitClassDeclaration = [{
/// Prints the type name.
void printTypeE(Location loc) const {
emitRemark(loc) << $_type << " - TestE";
}
}];
}
#endif // TEST_INTERFACES

View File

@ -90,6 +90,10 @@ def MultiTensorRankOf : TEST_Op<"multi_tensor_rank_of"> {
);
}
def TEST_TestType : DialectType<Test_Dialect,
CPred<"$_self.isa<::mlir::TestType>()">, "test">,
BuildableType<"$_builder.getType<::mlir::TestType>()">;
//===----------------------------------------------------------------------===//
// Test Symbols
//===----------------------------------------------------------------------===//

View File

@ -0,0 +1,44 @@
//===- TestTypes.h - MLIR Test Dialect Types --------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains types defined by the TestDialect for testing various
// features of MLIR.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TESTTYPES_H
#define MLIR_TESTTYPES_H
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Types.h"
namespace mlir {
#include "TestTypeInterfaces.h.inc"
/// This class is a simple test type that uses a generated interface.
struct TestType : public Type::TypeBase<TestType, Type, TypeStorage,
TestTypeInterface::Trait> {
using Base::Base;
static bool kindof(unsigned kind) {
return kind == Type::Kind::FIRST_PRIVATE_EXPERIMENTAL_9_TYPE;
}
static TestType get(MLIRContext *context) {
return Base::get(context, Type::Kind::FIRST_PRIVATE_EXPERIMENTAL_9_TYPE);
}
/// Provide a definition for the necessary interface methods.
void printTypeC(Location loc) const {
emitRemark(loc) << *this << " - TestC";
}
};
} // end namespace mlir
#endif // MLIR_TESTTYPES_H

View File

@ -1,6 +1,7 @@
# Exclude tests from libMLIR.so
add_mlir_library(MLIRTestIR
TestFunc.cpp
TestInterfaces.cpp
TestMatchers.cpp
TestSideEffects.cpp
TestSymbolUses.cpp

View File

@ -0,0 +1,41 @@
//===- TestInterfaces.cpp - Test interface generation and application -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "TestTypes.h"
#include "mlir/Pass/Pass.h"
using namespace mlir;
namespace {
/// This test checks various aspects of Type interface generation and
/// application.
struct TestTypeInterfaces
: public PassWrapper<TestTypeInterfaces, OperationPass<ModuleOp>> {
void runOnOperation() override {
getOperation().walk([](Operation *op) {
for (Type type : op->getResultTypes()) {
if (auto testInterface = type.dyn_cast<TestTypeInterface>()) {
testInterface.printTypeA(op->getLoc());
testInterface.printTypeB(op->getLoc());
testInterface.printTypeC(op->getLoc());
testInterface.printTypeD(op->getLoc());
}
if (auto testType = type.dyn_cast<TestType>())
testType.printTypeE(op->getLoc());
}
});
}
};
} // end anonymous namespace
namespace mlir {
void registerTestInterfaces() {
PassRegistration<TestTypeInterfaces> pass("test-type-interfaces",
"Test type interface support.");
}
} // namespace mlir

View File

@ -0,0 +1,11 @@
// RUN: mlir-opt -test-type-interfaces -allow-unregistered-dialect -verify-diagnostics %s
// expected-remark@below {{'!test.test_type' - TestA}}
// expected-remark@below {{'!test.test_type' - TestB}}
// expected-remark@below {{'!test.test_type' - TestC}}
// expected-remark@below {{'!test.test_type' - TestD}}
// expected-remark@below {{'!test.test_type' - TestE}}
%foo0 = "foo.test"() : () -> (!test.test_type)
// Type without the test interface.
%foo1 = "foo.test"() : () -> (i32)

View File

@ -51,6 +51,7 @@ void registerTestDominancePass();
void registerTestExpandTanhPass();
void registerTestFunc();
void registerTestGpuMemoryPromotionPass();
void registerTestInterfaces();
void registerTestLinalgHoisting();
void registerTestLinalgTransforms();
void registerTestLivenessPass();
@ -125,6 +126,7 @@ void registerTestPasses() {
registerTestFunc();
registerTestExpandTanhPass();
registerTestGpuMemoryPromotionPass();
registerTestInterfaces();
registerTestLinalgHoisting();
registerTestLinalgTransforms();
registerTestLivenessPass();

View File

@ -12,8 +12,8 @@
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Interfaces.h"
#include "mlir/TableGen/OpClass.h"
#include "mlir/TableGen/OpInterfaces.h"
#include "mlir/TableGen/OpTrait.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/Sequence.h"

View File

@ -14,8 +14,8 @@
#include "OpFormatGen.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Interfaces.h"
#include "mlir/TableGen/OpClass.h"
#include "mlir/TableGen/OpInterfaces.h"
#include "mlir/TableGen/OpTrait.h"
#include "mlir/TableGen/Operator.h"
#include "mlir/TableGen/SideEffects.h"
@ -1469,7 +1469,7 @@ void OpEmitter::genOpInterfaceMethod(const tblgen::InterfaceOpTrait *opTrait) {
alwaysDeclaredMethods.insert(alwaysDeclaredMethodsVec.begin(),
alwaysDeclaredMethodsVec.end());
for (const OpInterfaceMethod &method : interface.getMethods()) {
for (const InterfaceMethod &method : interface.getMethods()) {
// Don't declare if the method has a body.
if (method.getBody())
continue;
@ -1482,7 +1482,7 @@ void OpEmitter::genOpInterfaceMethod(const tblgen::InterfaceOpTrait *opTrait) {
std::string args;
llvm::raw_string_ostream os(args);
interleaveComma(method.getArguments(), os,
[&](const OpInterfaceMethod::Argument &arg) {
[&](const InterfaceMethod::Argument &arg) {
os << arg.type << " " << arg.name;
});
opClass.newMethod(method.getReturnType(), method.getName(), os.str(),

View File

@ -10,8 +10,8 @@
#include "mlir/Support/LogicalResult.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Interfaces.h"
#include "mlir/TableGen/OpClass.h"
#include "mlir/TableGen/OpInterfaces.h"
#include "mlir/TableGen/OpTrait.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/MapVector.h"

View File

@ -13,7 +13,7 @@
#include "DocGenUtilities.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/OpInterfaces.h"
#include "mlir/TableGen/Interfaces.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"
@ -22,218 +22,10 @@
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
using namespace llvm;
using namespace mlir;
using mlir::tblgen::Interface;
using mlir::tblgen::InterfaceMethod;
using mlir::tblgen::OpInterface;
using mlir::tblgen::OpInterfaceMethod;
// Emit the method name and argument list for the given method. If
// 'addOperationArg' is true, then an Operation* argument is added to the
// beginning of the argument list.
static void emitMethodNameAndArgs(const OpInterfaceMethod &method,
raw_ostream &os, bool addOperationArg) {
os << method.getName() << '(';
if (addOperationArg)
os << "Operation *tablegen_opaque_op" << (method.arg_empty() ? "" : ", ");
llvm::interleaveComma(method.getArguments(), os,
[&](const OpInterfaceMethod::Argument &arg) {
os << arg.type << " " << arg.name;
});
os << ')';
}
// Get an array of all OpInterface definitions but exclude those subclassing
// "DeclareOpInterfaceMethods".
static std::vector<Record *>
getAllOpInterfaceDefinitions(const RecordKeeper &recordKeeper) {
std::vector<Record *> defs =
recordKeeper.getAllDerivedDefinitions("OpInterface");
llvm::erase_if(defs, [](const Record *def) {
return def->isSubClassOf("DeclareOpInterfaceMethods");
});
return defs;
}
//===----------------------------------------------------------------------===//
// GEN: Interface definitions
//===----------------------------------------------------------------------===//
static void emitInterfaceDef(OpInterface &interface, raw_ostream &os) {
StringRef interfaceName = interface.getName();
// Insert the method definitions.
for (auto &method : interface.getMethods()) {
os << method.getReturnType() << " " << interfaceName << "::";
emitMethodNameAndArgs(method, os, /*addOperationArg=*/false);
// Forward to the method on the concrete operation type.
os << " {\n return getImpl()->" << method.getName() << '(';
if (!method.isStatic())
os << "getOperation()" << (method.arg_empty() ? "" : ", ");
llvm::interleaveComma(
method.getArguments(), os,
[&](const OpInterfaceMethod::Argument &arg) { os << arg.name; });
os << ");\n }\n";
}
}
static bool emitInterfaceDefs(const RecordKeeper &recordKeeper,
raw_ostream &os) {
llvm::emitSourceFileHeader("Operation Interface Definitions", os);
for (const auto *def : getAllOpInterfaceDefinitions(recordKeeper)) {
OpInterface interface(def);
emitInterfaceDef(interface, os);
}
return false;
}
//===----------------------------------------------------------------------===//
// GEN: Interface declarations
//===----------------------------------------------------------------------===//
static void emitConceptDecl(OpInterface &interface, raw_ostream &os) {
os << " class Concept {\n"
<< " public:\n"
<< " virtual ~Concept() = default;\n";
// Insert each of the pure virtual concept methods.
for (auto &method : interface.getMethods()) {
os << " virtual " << method.getReturnType() << " ";
emitMethodNameAndArgs(method, os, /*addOperationArg=*/!method.isStatic());
os << " = 0;\n";
}
os << " };\n";
}
static void emitModelDecl(OpInterface &interface, raw_ostream &os) {
os << " template<typename ConcreteOp>\n";
os << " class Model : public Concept {\npublic:\n";
// Insert each of the virtual method overrides.
for (auto &method : interface.getMethods()) {
os << " " << method.getReturnType() << " ";
emitMethodNameAndArgs(method, os, /*addOperationArg=*/!method.isStatic());
os << " final {\n";
// Provide a definition of the concrete op if this is non static.
if (!method.isStatic()) {
os << " auto op = ::mlir::cast<ConcreteOp>(tablegen_opaque_op);\n"
<< " (void)op;\n";
}
// Check for a provided body to the function.
if (auto body = method.getBody()) {
os << body << "\n }\n";
continue;
}
// Forward to the method on the concrete operation type.
os << " return " << (method.isStatic() ? "ConcreteOp::" : "op.");
// Add the arguments to the call.
os << method.getName() << '(';
llvm::interleaveComma(
method.getArguments(), os,
[&](const OpInterfaceMethod::Argument &arg) { os << arg.name; });
os << ");\n }\n";
}
os << " };\n";
}
static void emitTraitDecl(OpInterface &interface, raw_ostream &os,
StringRef interfaceName,
StringRef interfaceTraitsName) {
os << " template <typename ConcreteOp>\n "
<< llvm::formatv("struct {0}Trait : public ::mlir::OpInterface<{0},"
" detail::{1}>::Trait<ConcreteOp> {{\n",
interfaceName, interfaceTraitsName);
// Insert the default implementation for any methods.
for (auto &method : interface.getMethods()) {
// Flag interface methods named verifyTrait.
if (method.getName() == "verifyTrait")
PrintFatalError(
formatv("'verifyTrait' method cannot be specified as interface "
"method for '{0}'; set 'verify' on OpInterfaceTrait instead",
interfaceName));
auto defaultImpl = method.getDefaultImplementation();
if (!defaultImpl)
continue;
os << " " << (method.isStatic() ? "static " : "") << method.getReturnType()
<< " ";
emitMethodNameAndArgs(method, os, /*addOperationArg=*/false);
os << " {\n" << defaultImpl.getValue() << " }\n";
}
tblgen::FmtContext traitCtx;
traitCtx.withOp("op");
if (auto verify = interface.getVerify()) {
os << " static ::mlir::LogicalResult verifyTrait(Operation* op) {\n"
<< std::string(tblgen::tgfmt(*verify, &traitCtx)) << "\n }\n";
}
if (auto extraTraitDecls = interface.getExtraTraitClassDeclaration())
os << extraTraitDecls << "\n";
os << " };\n";
// Emit a utility wrapper trait class.
os << " template <typename ConcreteOp>\n "
<< llvm::formatv("struct Trait : public {0}Trait<ConcreteOp> {{};\n",
interfaceName);
}
static void emitInterfaceDecl(OpInterface &interface, raw_ostream &os) {
StringRef interfaceName = interface.getName();
auto interfaceTraitsName = (interfaceName + "InterfaceTraits").str();
// Emit the traits struct containing the concept and model declarations.
os << "namespace detail {\n"
<< "struct " << interfaceTraitsName << " {\n";
emitConceptDecl(interface, os);
emitModelDecl(interface, os);
os << "};\n} // end namespace detail\n";
// Emit the main interface class declaration.
os << llvm::formatv(
"class {0} : public ::mlir::OpInterface<{1}, detail::{2}> {\n"
"public:\n"
" using ::mlir::OpInterface<{1}, detail::{2}>::OpInterface;\n",
interfaceName, interfaceName, interfaceTraitsName);
// Emit the derived trait for the interface.
emitTraitDecl(interface, os, interfaceName, interfaceTraitsName);
// Insert the method declarations.
for (auto &method : interface.getMethods()) {
os << " " << method.getReturnType() << " ";
emitMethodNameAndArgs(method, os, /*addOperationArg=*/false);
os << ";\n";
}
// Emit any extra declarations.
if (Optional<StringRef> extraDecls = interface.getExtraClassDeclaration())
os << *extraDecls << "\n";
os << "};\n";
}
static bool emitInterfaceDecls(const RecordKeeper &recordKeeper,
raw_ostream &os) {
llvm::emitSourceFileHeader("Operation Interface Declarations", os);
for (const auto *def : getAllOpInterfaceDefinitions(recordKeeper)) {
OpInterface interface(def);
emitInterfaceDecl(interface, os);
}
return false;
}
//===----------------------------------------------------------------------===//
// GEN: Interface documentation
//===----------------------------------------------------------------------===//
/// Emit a string corresponding to a C++ type, followed by a space if necessary.
static raw_ostream &emitCPPType(StringRef type, raw_ostream &os) {
@ -244,8 +36,308 @@ static raw_ostream &emitCPPType(StringRef type, raw_ostream &os) {
return os;
}
static void emitInterfaceDoc(const Record &interfaceDef, raw_ostream &os) {
OpInterface interface(&interfaceDef);
/// Emit the method name and argument list for the given method. If 'addThisArg'
/// is true, then an argument is added to the beginning of the argument list for
/// the concrete value.
static void emitMethodNameAndArgs(const InterfaceMethod &method,
raw_ostream &os, StringRef valueType,
bool addThisArg, bool addConst) {
os << method.getName() << '(';
if (addThisArg)
emitCPPType(valueType, os)
<< "tablegen_opaque_val" << (method.arg_empty() ? "" : ", ");
llvm::interleaveComma(method.getArguments(), os,
[&](const InterfaceMethod::Argument &arg) {
os << arg.type << " " << arg.name;
});
os << ')';
if (addConst)
os << " const";
}
/// Get an array of all OpInterface definitions but exclude those subclassing
/// "DeclareOpInterfaceMethods".
static std::vector<llvm::Record *>
getAllOpInterfaceDefinitions(const llvm::RecordKeeper &recordKeeper) {
std::vector<llvm::Record *> defs =
recordKeeper.getAllDerivedDefinitions("OpInterface");
llvm::erase_if(defs, [](const llvm::Record *def) {
return def->isSubClassOf("DeclareOpInterfaceMethods");
});
return defs;
}
namespace {
/// This struct is the base generator used when processing tablegen interfaces.
class InterfaceGenerator {
public:
bool emitInterfaceDefs();
bool emitInterfaceDecls();
bool emitInterfaceDocs();
protected:
InterfaceGenerator(std::vector<llvm::Record *> &&defs, raw_ostream &os)
: defs(std::move(defs)), os(os) {}
void emitConceptDecl(Interface &interface);
void emitModelDecl(Interface &interface);
void emitTraitDecl(Interface &interface, StringRef interfaceName,
StringRef interfaceTraitsName);
void emitInterfaceDecl(Interface interface);
/// The set of interface records to emit.
std::vector<llvm::Record *> defs;
// The stream to emit to.
raw_ostream &os;
/// The C++ value type of the interface, e.g. Operation*.
StringRef valueType;
/// The C++ base interface type.
StringRef interfaceBaseType;
/// The name of the typename for the value template.
StringRef valueTemplate;
/// The format context to use for methods.
tblgen::FmtContext nonStaticMethodFmt;
tblgen::FmtContext traitMethodFmt;
};
/// A specialized generator for attribute interfaces.
struct AttrInterfaceGenerator : public InterfaceGenerator {
AttrInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
: InterfaceGenerator(records.getAllDerivedDefinitions("AttrInterface"),
os) {
valueType = "::mlir::Attribute";
interfaceBaseType = "AttrInterface";
valueTemplate = "ConcreteAttr";
StringRef castCode = "(tablegen_opaque_val.cast<ConcreteAttr>())";
nonStaticMethodFmt.addSubst("_attr", castCode).withSelf(castCode);
traitMethodFmt.addSubst("_attr",
"(*static_cast<const ConcreteAttr *>(this))");
}
};
/// A specialized generator for operaton interfaces.
struct OpInterfaceGenerator : public InterfaceGenerator {
OpInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
: InterfaceGenerator(getAllOpInterfaceDefinitions(records), os) {
valueType = "::mlir::Operation *";
interfaceBaseType = "OpInterface";
valueTemplate = "ConcreteOp";
StringRef castCode = "(llvm::cast<ConcreteOp>(tablegen_opaque_val))";
nonStaticMethodFmt.withOp(castCode).withSelf(castCode);
traitMethodFmt.withOp("(*static_cast<ConcreteOp *>(this))");
}
};
/// A specialized generator for type interfaces.
struct TypeInterfaceGenerator : public InterfaceGenerator {
TypeInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
: InterfaceGenerator(records.getAllDerivedDefinitions("TypeInterface"),
os) {
valueType = "::mlir::Type";
interfaceBaseType = "TypeInterface";
valueTemplate = "ConcreteType";
StringRef castCode = "(tablegen_opaque_val.cast<ConcreteType>())";
nonStaticMethodFmt.addSubst("_type", castCode).withSelf(castCode);
traitMethodFmt.addSubst("_type",
"(*static_cast<const ConcreteType *>(this))");
}
};
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// GEN: Interface definitions
//===----------------------------------------------------------------------===//
static void emitInterfaceDef(Interface interface, StringRef valueType,
raw_ostream &os) {
StringRef interfaceName = interface.getName();
// Insert the method definitions.
bool isOpInterface = isa<OpInterface>(interface);
for (auto &method : interface.getMethods()) {
emitCPPType(method.getReturnType(), os) << interfaceName << "::";
emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
/*addConst=*/!isOpInterface);
// Forward to the method on the concrete operation type.
os << " {\n return getImpl()->" << method.getName() << '(';
if (!method.isStatic()) {
os << (isOpInterface ? "getOperation()" : "*this");
os << (method.arg_empty() ? "" : ", ");
}
llvm::interleaveComma(
method.getArguments(), os,
[&](const InterfaceMethod::Argument &arg) { os << arg.name; });
os << ");\n }\n";
}
}
bool InterfaceGenerator::emitInterfaceDefs() {
llvm::emitSourceFileHeader("Interface Definitions", os);
for (const auto *def : defs)
emitInterfaceDef(Interface(def), valueType, os);
return false;
}
//===----------------------------------------------------------------------===//
// GEN: Interface declarations
//===----------------------------------------------------------------------===//
void InterfaceGenerator::emitConceptDecl(Interface &interface) {
os << " class Concept {\n"
<< " public:\n"
<< " virtual ~Concept() = default;\n";
// Insert each of the pure virtual concept methods.
for (auto &method : interface.getMethods()) {
os << " virtual ";
emitCPPType(method.getReturnType(), os);
emitMethodNameAndArgs(method, os, valueType,
/*addThisArg=*/!method.isStatic(), /*addConst=*/true);
os << " = 0;\n";
}
os << " };\n";
}
void InterfaceGenerator::emitModelDecl(Interface &interface) {
os << " template<typename " << valueTemplate << ">\n";
os << " class Model : public Concept {\n public:\n";
// Insert each of the virtual method overrides.
for (auto &method : interface.getMethods()) {
emitCPPType(method.getReturnType(), os << " ");
emitMethodNameAndArgs(method, os, valueType,
/*addThisArg=*/!method.isStatic(), /*addConst=*/true);
os << " final {\n ";
// Check for a provided body to the function.
if (Optional<StringRef> body = method.getBody()) {
if (method.isStatic())
os << body->trim();
else
os << tblgen::tgfmt(body->trim(), &nonStaticMethodFmt);
os << "\n }\n";
continue;
}
// Forward to the method on the concrete operation type.
if (method.isStatic())
os << "return " << valueTemplate << "::";
else
os << tblgen::tgfmt("return $_self.", &nonStaticMethodFmt);
// Add the arguments to the call.
os << method.getName() << '(';
llvm::interleaveComma(
method.getArguments(), os,
[&](const InterfaceMethod::Argument &arg) { os << arg.name; });
os << ");\n }\n";
}
os << " };\n";
}
void InterfaceGenerator::emitTraitDecl(Interface &interface,
StringRef interfaceName,
StringRef interfaceTraitsName) {
os << llvm::formatv(" template <typename {3}>\n"
" struct {0}Trait : public ::mlir::{2}<{0},"
" detail::{1}>::Trait<{3}> {{\n",
interfaceName, interfaceTraitsName, interfaceBaseType,
valueTemplate);
// Insert the default implementation for any methods.
bool isOpInterface = isa<OpInterface>(interface);
for (auto &method : interface.getMethods()) {
// Flag interface methods named verifyTrait.
if (method.getName() == "verifyTrait")
PrintFatalError(
formatv("'verifyTrait' method cannot be specified as interface "
"method for '{0}'; use the 'verify' field instead",
interfaceName));
auto defaultImpl = method.getDefaultImplementation();
if (!defaultImpl)
continue;
os << " " << (method.isStatic() ? "static " : "");
emitCPPType(method.getReturnType(), os);
emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
/*addConst=*/!isOpInterface);
os << " {\n " << tblgen::tgfmt(defaultImpl->trim(), &traitMethodFmt)
<< "\n }\n";
}
if (auto verify = interface.getVerify()) {
assert(isa<OpInterface>(interface) && "only OpInterface supports 'verify'");
tblgen::FmtContext verifyCtx;
verifyCtx.withOp("op");
os << " static ::mlir::LogicalResult verifyTrait(::mlir::Operation *op) "
"{\n "
<< tblgen::tgfmt(verify->trim(), &verifyCtx) << "\n }\n";
}
if (auto extraTraitDecls = interface.getExtraTraitClassDeclaration())
os << tblgen::tgfmt(*extraTraitDecls, &traitMethodFmt) << "\n";
os << " };\n";
// Emit a utility wrapper trait class.
os << llvm::formatv(" template <typename {1}>\n"
" struct Trait : public {0}Trait<{1}> {{};\n",
interfaceName, valueTemplate);
}
void InterfaceGenerator::emitInterfaceDecl(Interface interface) {
StringRef interfaceName = interface.getName();
auto interfaceTraitsName = (interfaceName + "InterfaceTraits").str();
// Emit the traits struct containing the concept and model declarations.
os << "namespace detail {\n"
<< "struct " << interfaceTraitsName << " {\n";
emitConceptDecl(interface);
emitModelDecl(interface);
os << "};\n} // end namespace detail\n";
// Emit the main interface class declaration.
os << llvm::formatv("class {0} : public ::mlir::{3}<{1}, detail::{2}> {\n"
"public:\n"
" using ::mlir::{3}<{1}, detail::{2}>::{3};\n",
interfaceName, interfaceName, interfaceTraitsName,
interfaceBaseType);
// Emit the derived trait for the interface.
emitTraitDecl(interface, interfaceName, interfaceTraitsName);
// Insert the method declarations.
bool isOpInterface = isa<OpInterface>(interface);
for (auto &method : interface.getMethods()) {
emitCPPType(method.getReturnType(), os << " ");
emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
/*addConst=*/!isOpInterface);
os << ";\n";
}
// Emit any extra declarations.
if (Optional<StringRef> extraDecls = interface.getExtraClassDeclaration())
os << *extraDecls << "\n";
os << "};\n";
}
bool InterfaceGenerator::emitInterfaceDecls() {
llvm::emitSourceFileHeader("Interface Declarations", os);
for (const auto *def : defs)
emitInterfaceDecl(Interface(def));
return false;
}
//===----------------------------------------------------------------------===//
// GEN: Interface documentation
//===----------------------------------------------------------------------===//
static void emitInterfaceDoc(const llvm::Record &interfaceDef,
raw_ostream &os) {
Interface interface(&interfaceDef);
// Emit the interface name followed by the description.
os << "## " << interface.getName() << " (" << interfaceDef.getName() << ")";
@ -263,7 +355,7 @@ static void emitInterfaceDoc(const Record &interfaceDef, raw_ostream &os) {
os << "static ";
emitCPPType(method.getReturnType(), os) << method.getName() << '(';
llvm::interleaveComma(method.getArguments(), os,
[&](const OpInterfaceMethod::Argument &arg) {
[&](const InterfaceMethod::Argument &arg) {
emitCPPType(arg.type, os) << arg.name;
});
os << ");\n```\n";
@ -272,19 +364,17 @@ static void emitInterfaceDoc(const Record &interfaceDef, raw_ostream &os) {
if (auto description = method.getDescription())
mlir::tblgen::emitDescription(*description, os);
// If the body is not provided, this method must be provided by the
// operation.
// If the body is not provided, this method must be provided by the user.
if (!method.getBody())
os << "\nNOTE: This method *must* be implemented by the operation.\n\n";
os << "\nNOTE: This method *must* be implemented by the user.\n\n";
}
}
static bool emitInterfaceDocs(const RecordKeeper &recordKeeper,
raw_ostream &os) {
bool InterfaceGenerator::emitInterfaceDocs() {
os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
os << "# Operation Interface definition\n";
os << "# " << interfaceBaseType << " definitions\n";
for (const auto *def : getAllOpInterfaceDefinitions(recordKeeper))
for (const auto *def : defs)
emitInterfaceDoc(*def, os);
return false;
}
@ -293,26 +383,30 @@ static bool emitInterfaceDocs(const RecordKeeper &recordKeeper,
// GEN: Interface registration hooks
//===----------------------------------------------------------------------===//
// Registers the operation interface generator to mlir-tblgen.
static mlir::GenRegistration
genInterfaceDecls("gen-op-interface-decls",
"Generate op interface declarations",
[](const RecordKeeper &records, raw_ostream &os) {
return emitInterfaceDecls(records, os);
});
namespace {
template <typename GeneratorT> struct InterfaceGenRegistration {
InterfaceGenRegistration(StringRef genArg)
: genDeclArg(("gen-" + genArg + "-interface-decls").str()),
genDefArg(("gen-" + genArg + "-interface-defs").str()),
genDocArg(("gen-" + genArg + "-interface-docs").str()),
genDecls(genDeclArg, "Generate interface declarations",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
return GeneratorT(records, os).emitInterfaceDecls();
}),
genDefs(genDefArg, "Generate interface definitions",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
return GeneratorT(records, os).emitInterfaceDefs();
}),
genDocs(genDocArg, "Generate interface documentation",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
return GeneratorT(records, os).emitInterfaceDocs();
}) {}
// Registers the operation interface generator to mlir-tblgen.
static mlir::GenRegistration
genInterfaceDefs("gen-op-interface-defs",
"Generate op interface definitions",
[](const RecordKeeper &records, raw_ostream &os) {
return emitInterfaceDefs(records, os);
});
std::string genDeclArg, genDefArg, genDocArg;
mlir::GenRegistration genDecls, genDefs, genDocs;
};
} // end anonymous namespace
// Registers the operation interface document generator to mlir-tblgen.
static mlir::GenRegistration
genInterfaceDocs("gen-op-interface-doc",
"Generate op interface documentation",
[](const RecordKeeper &records, raw_ostream &os) {
return emitInterfaceDocs(records, os);
});
static InterfaceGenRegistration<AttrInterfaceGenerator> attrGen("attr");
static InterfaceGenRegistration<OpInterfaceGenerator> opGen("op");
static InterfaceGenRegistration<TypeInterfaceGenerator> typeGen("type");

View File

@ -205,7 +205,8 @@ static void emitConceptDecl(const Availability &availability, raw_ostream &os) {
<< " public:\n"
<< " virtual ~Concept() = default;\n"
<< " virtual " << availability.getQueryFnRetType() << " "
<< availability.getQueryFnName() << "(Operation *tblgen_opaque_op) = 0;\n"
<< availability.getQueryFnName()
<< "(Operation *tblgen_opaque_op) const = 0;\n"
<< " };\n";
}
@ -215,7 +216,7 @@ static void emitModelDecl(const Availability &availability, raw_ostream &os) {
<< " public:\n"
<< " " << availability.getQueryFnRetType() << " "
<< availability.getQueryFnName()
<< "(Operation *tblgen_opaque_op) final {\n"
<< "(Operation *tblgen_opaque_op) const final {\n"
<< " auto op = llvm::cast<ConcreteOp>(tblgen_opaque_op);\n"
<< " (void)op;\n"
// Forward to the method on the concrete operation type.