[mlir][ods] Allow filtering of ops

Add option to filter which op the OpDefinitionsGen run on. This enables having multiple ops together in the same TD file but generating different CC files for them (useful if one wants to use multiclasses or split out 1 dialect into multiple different libraries). There is probably more general query here (e.g., split out all ops that don't have a verify method, or that are commutative) but filtering based on op name (e.g., test.a_op) seemed a reasonable start and didn't require inventing a query specification mechanism here.

Differential Revision: https://reviews.llvm.org/D82319
This commit is contained in:
Jacques Pienaar 2020-06-22 14:56:54 -07:00
parent 23654d9e7a
commit ada0d41dbc
2 changed files with 41 additions and 2 deletions

View File

@ -1,4 +1,5 @@
// RUN: mlir-tblgen -gen-op-decls -I %S/../../include %s | FileCheck %s
// RUN: mlir-tblgen -gen-op-decls -op-regex="test.a_op" -I %S/../../include %s | FileCheck %s --check-prefix=REDUCE
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
@ -195,3 +196,5 @@ def _BOp : NS_Op<"_op_with_leading_underscore_and_no_namespace", []>;
// CHECK-LABEL: _BOp declarations
// CHECK: class _BOp : public Op<_BOp
// REDUCE-LABEL: NS::AOp declarations
// REDUCE-NOT: NS::BOp declarations

View File

@ -21,6 +21,8 @@
#include "mlir/TableGen/SideEffects.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/Signals.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
@ -32,6 +34,13 @@ using namespace llvm;
using namespace mlir;
using namespace mlir::tblgen;
cl::OptionCategory opDefGenCat("Options for -gen-op-defs and -gen-op-decls");
static cl::opt<std::string>
opFilter("op-regex",
cl::desc("Regex of name of op's to filter (no filter if empty)"),
cl::cat(opDefGenCat));
static const char *const tblgenNamePrefix = "tblgen_";
static const char *const generatedArgName = "odsArg";
static const char *const builderOpState = "odsState";
@ -2081,10 +2090,37 @@ static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) {
[&os]() { os << ",\n"; });
}
static std::string getOperationName(const Record &def) {
auto prefix = def.getValueAsDef("opDialect")->getValueAsString("name");
auto opName = def.getValueAsString("opName");
if (prefix.empty())
return std::string(opName);
return std::string(llvm::formatv("{0}.{1}", prefix, opName));
}
static std::vector<Record *>
getAllDerivedDefinitions(const RecordKeeper &recordKeeper,
StringRef className) {
Record *classDef = recordKeeper.getClass(className);
if (!classDef)
PrintFatalError("ERROR: Couldn't find the `" + className + "' class!\n");
llvm::Regex includeRegex(opFilter);
std::vector<Record *> defs;
for (const auto &def : recordKeeper.getDefs()) {
if (def.second->isSubClassOf(classDef)) {
if (opFilter.empty() || includeRegex.match(getOperationName(*def.second)))
defs.push_back(def.second.get());
}
}
return defs;
}
static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
emitSourceFileHeader("Op Declarations", os);
const auto &defs = recordKeeper.getAllDerivedDefinitions("Op");
const auto &defs = getAllDerivedDefinitions(recordKeeper, "Op");
emitOpClasses(defs, os, /*emitDecl=*/true);
return false;
@ -2093,7 +2129,7 @@ static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
static bool emitOpDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
emitSourceFileHeader("Op Definitions", os);
const auto &defs = recordKeeper.getAllDerivedDefinitions("Op");
const auto &defs = getAllDerivedDefinitions(recordKeeper, "Op");
emitOpList(defs, os);
emitOpClasses(defs, os, /*emitDecl=*/false);