[mlir][ods] Allow filtering of ops
Add option to filter which op the OpDefinitionsGen run on. This enables having multiple ops together in the same TD file but generating different CC files for them (useful if one wants to use multiclasses or split out 1 dialect into multiple different libraries). There is probably more general query here (e.g., split out all ops that don't have a verify method, or that are commutative) but filtering based on op name (e.g., test.a_op) seemed a reasonable start and didn't require inventing a query specification mechanism here. Differential Revision: https://reviews.llvm.org/D82319
This commit is contained in:
parent
23654d9e7a
commit
ada0d41dbc
|
@ -1,4 +1,5 @@
|
|||
// RUN: mlir-tblgen -gen-op-decls -I %S/../../include %s | FileCheck %s
|
||||
// RUN: mlir-tblgen -gen-op-decls -op-regex="test.a_op" -I %S/../../include %s | FileCheck %s --check-prefix=REDUCE
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||
|
@ -195,3 +196,5 @@ def _BOp : NS_Op<"_op_with_leading_underscore_and_no_namespace", []>;
|
|||
// CHECK-LABEL: _BOp declarations
|
||||
// CHECK: class _BOp : public Op<_BOp
|
||||
|
||||
// REDUCE-LABEL: NS::AOp declarations
|
||||
// REDUCE-NOT: NS::BOp declarations
|
||||
|
|
|
@ -21,6 +21,8 @@
|
|||
#include "mlir/TableGen/SideEffects.h"
|
||||
#include "llvm/ADT/Sequence.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/Regex.h"
|
||||
#include "llvm/Support/Signals.h"
|
||||
#include "llvm/TableGen/Error.h"
|
||||
#include "llvm/TableGen/Record.h"
|
||||
|
@ -32,6 +34,13 @@ using namespace llvm;
|
|||
using namespace mlir;
|
||||
using namespace mlir::tblgen;
|
||||
|
||||
cl::OptionCategory opDefGenCat("Options for -gen-op-defs and -gen-op-decls");
|
||||
|
||||
static cl::opt<std::string>
|
||||
opFilter("op-regex",
|
||||
cl::desc("Regex of name of op's to filter (no filter if empty)"),
|
||||
cl::cat(opDefGenCat));
|
||||
|
||||
static const char *const tblgenNamePrefix = "tblgen_";
|
||||
static const char *const generatedArgName = "odsArg";
|
||||
static const char *const builderOpState = "odsState";
|
||||
|
@ -2081,10 +2090,37 @@ static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) {
|
|||
[&os]() { os << ",\n"; });
|
||||
}
|
||||
|
||||
static std::string getOperationName(const Record &def) {
|
||||
auto prefix = def.getValueAsDef("opDialect")->getValueAsString("name");
|
||||
auto opName = def.getValueAsString("opName");
|
||||
if (prefix.empty())
|
||||
return std::string(opName);
|
||||
return std::string(llvm::formatv("{0}.{1}", prefix, opName));
|
||||
}
|
||||
|
||||
static std::vector<Record *>
|
||||
getAllDerivedDefinitions(const RecordKeeper &recordKeeper,
|
||||
StringRef className) {
|
||||
Record *classDef = recordKeeper.getClass(className);
|
||||
if (!classDef)
|
||||
PrintFatalError("ERROR: Couldn't find the `" + className + "' class!\n");
|
||||
|
||||
llvm::Regex includeRegex(opFilter);
|
||||
std::vector<Record *> defs;
|
||||
for (const auto &def : recordKeeper.getDefs()) {
|
||||
if (def.second->isSubClassOf(classDef)) {
|
||||
if (opFilter.empty() || includeRegex.match(getOperationName(*def.second)))
|
||||
defs.push_back(def.second.get());
|
||||
}
|
||||
}
|
||||
|
||||
return defs;
|
||||
}
|
||||
|
||||
static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
|
||||
emitSourceFileHeader("Op Declarations", os);
|
||||
|
||||
const auto &defs = recordKeeper.getAllDerivedDefinitions("Op");
|
||||
const auto &defs = getAllDerivedDefinitions(recordKeeper, "Op");
|
||||
emitOpClasses(defs, os, /*emitDecl=*/true);
|
||||
|
||||
return false;
|
||||
|
@ -2093,7 +2129,7 @@ static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
|
|||
static bool emitOpDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
|
||||
emitSourceFileHeader("Op Definitions", os);
|
||||
|
||||
const auto &defs = recordKeeper.getAllDerivedDefinitions("Op");
|
||||
const auto &defs = getAllDerivedDefinitions(recordKeeper, "Op");
|
||||
emitOpList(defs, os);
|
||||
emitOpClasses(defs, os, /*emitDecl=*/false);
|
||||
|
||||
|
|
Loading…
Reference in New Issue