[mlir][linalg] Remove the StructuredOp capture mechanism.

After https://reviews.llvm.org/D104109, structured ops support scalar inputs. As a result, the capture mechanism meant to pass non-shaped parameters got redundant. The patch removes the capture semantics after the FillOp migrated to use scalar operands https://reviews.llvm.org/D104121.

Differential Revision: https://reviews.llvm.org/D104785
This commit is contained in:
Tobias Gysi 2021-06-28 07:30:02 +00:00
parent a1c0f09a89
commit bbf4436a82
13 changed files with 51 additions and 86 deletions

View File

@ -18,11 +18,9 @@ extern "C" {
#endif
/// Apply the special region builder for the builtin named Linalg op.
/// The list of `capture` MlirValue is passed as-is to the region builder.
/// Assert that `op` is a builtin named Linalg op.
MLIR_CAPI_EXPORTED void
mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect, MlirOperation op,
intptr_t n, MlirValue const *mlirCaptures);
mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect, MlirOperation op);
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg);

View File

@ -49,7 +49,7 @@ def Linalg_Dialect : Dialect {
kInplaceableAttrName = "linalg.inplaceable";
using RegionBuilderFunType =
llvm::function_ref<void(ImplicitLocOpBuilder &b, Block &, ValueRange)>;
llvm::function_ref<void(ImplicitLocOpBuilder &b, Block &)>;
RegionBuilderFunType getRegionBuilder(StringRef name) {
return namedStructuredOpRegionBuilders.lookup(name);
}

View File

@ -901,7 +901,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
Returns a null function if this named op does not define a region
builder.
}],
/*retTy=*/"std::function<void(ImplicitLocOpBuilder &, Block &, ValueRange)>",
/*retTy=*/"std::function<void(ImplicitLocOpBuilder &, Block &)>",
/*methodName=*/"getRegionBuilder",
(ins),
[{ return ConcreteOp::getRegionBuilder(); }]

View File

@ -153,10 +153,8 @@ def CopyOp : LinalgStructured_Op<"copy", [CopyOpInterface]> {
Value getSource() { return input();}
Value getTarget() { return output(); }
static void regionBuilder(
ImplicitLocOpBuilder &b, Block &block, ValueRange captures);
static std::function<
void(ImplicitLocOpBuilder &b, Block &block, ValueRange captures)>
static void regionBuilder(ImplicitLocOpBuilder &b, Block &block);
static std::function<void(ImplicitLocOpBuilder &b, Block &block)>
getRegionBuilder() {
return &regionBuilder;
}
@ -200,10 +198,8 @@ def FillOp : LinalgStructured_Op<"fill", []> {
extractOrIdentityMap(llvm::None, getNumParallelLoops(), context)});
}
static void regionBuilder(
ImplicitLocOpBuilder &b, Block &block, ValueRange captures);
static std::function<
void(ImplicitLocOpBuilder &b, Block &block, ValueRange captures)>
static void regionBuilder(ImplicitLocOpBuilder &b, Block &block);
static std::function<void(ImplicitLocOpBuilder &b, Block &block)>
getRegionBuilder() {
return &regionBuilder;
}
@ -291,8 +287,7 @@ class PoolingBase_Op<string mnemonic, list<OpTrait> props>
return padding().getValue().getValue<int64_t>({i, 1});
}
static std::function<
void(ImplicitLocOpBuilder &b, Block &block, ValueRange captures)>
static std::function<void(ImplicitLocOpBuilder &b, Block &block)>
getRegionBuilder() {
return nullptr;
}
@ -533,8 +528,7 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, [
library_call()->str() : "op_has_no_registered_library_name";
}
static std::function<
void(ImplicitLocOpBuilder &b, Block &block, ValueRange captures)>
static std::function<void(ImplicitLocOpBuilder &b, Block &block)>
getRegionBuilder() {
return nullptr;
}

View File

@ -21,15 +21,10 @@ using namespace mlir::python;
void mlir::python::populateDialectLinalgSubmodule(py::module m) {
m.def(
"fill_builtin_region",
[](PyDialectDescriptor &dialect, PyOperation &op, py::list captures) {
llvm::SmallVector<MlirValue, 4> mlirOperands;
mlirOperands.reserve(captures.size());
for (auto v : captures)
mlirOperands.push_back(py::cast<PyValue *>(v)->get());
mlirLinalgFillBuiltinNamedOpRegion(
dialect.get(), op.get(), mlirOperands.size(), mlirOperands.data());
[](PyDialectDescriptor &dialect, PyOperation &op) {
mlirLinalgFillBuiltinNamedOpRegion(dialect.get(), op.get());
},
py::arg("dialect"), py::arg("op"), py::arg("captures") = py::list(),
py::arg("dialect"), py::arg("op"),
"Fill the region for `op`, which is assumed to be a builtin named Linalg "
"op.");
}

View File

@ -16,13 +16,8 @@ using namespace mlir::linalg;
/// Apply the special region builder for the builtin named Linalg op.
/// Assert that `op` is a builtin named Linalg op.
void mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect,
MlirOperation mlirOp, intptr_t n,
MlirValue const *mlirCaptures) {
MlirOperation mlirOp) {
Operation *op = unwrap(mlirOp);
SmallVector<Value> captures;
captures.reserve(n);
for (unsigned idx = 0; idx < n; ++idx)
captures.push_back(unwrap(mlirCaptures[idx]));
LinalgDialect::RegionBuilderFunType fun =
static_cast<LinalgDialect *>(unwrap(linalgDialect))
@ -41,7 +36,7 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect,
Region &region = op->getRegion(0);
Block *body = b.createBlock(&region, /*insertPt=*/{}, argTypes);
b.setInsertionPointToStart(body);
fun(b, *body, captures);
fun(b, *body);
}
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect)

View File

@ -43,20 +43,19 @@ using namespace mlir::linalg;
/// defined C++ ops.
/// This is used by both builders and parsers.
/// This function creates the block in the region with arguments corresponding
/// to the elemental types of `inputTypes` and `outputTypes`, which are asserted
/// to be ShapedType.
/// to the elemental types of `inputTypes` and `outputTypes`. The latter are
/// asserted to be of ShapedType.
template <typename NamedStructuredOpType>
static void fillStructuredOpRegion(
OpBuilder &opBuilder, Region &region, TypeRange inputTypes,
TypeRange outputTypes, ValueRange captures = {},
TypeRange outputTypes,
std::function<void(unsigned, unsigned)> errorHandler = nullptr);
/// Generic entry point to create both the region and the block of a LinalgOp.
template <typename NamedStructuredOpType>
static void
createAndFillStructuredOpRegion(OpBuilder &opBuilder, OperationState &result,
TypeRange inputTypes, TypeRange outputTypes,
ValueRange captures = {});
TypeRange inputTypes, TypeRange outputTypes);
/// Common parsing and printing used for both named structured ops created by
/// ods-gen and by manually defined C++ ops. Does not handle regions.
@ -72,17 +71,15 @@ static void printCommonStructuredOpParts(OpAsmPrinter &p,
template <typename NamedStructuredOpType>
static ParseResult
parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region,
TypeRange inputTypes, TypeRange outputTypes,
ArrayRef<OpAsmParser::OperandType> captures = {});
TypeRange inputTypes, TypeRange outputTypes);
static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
SmallVectorImpl<Type> &resultTypes);
template <typename NamedStructuredOpType>
static ParseResult
parseNamedStructuredOp(OpAsmParser &parser, OperationState &result,
ArrayRef<OpAsmParser::OperandType> captures = {});
static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
OperationState &result);
static void printNamedStructuredOpResults(OpAsmPrinter &p,
TypeRange resultTypes);
@ -323,8 +320,7 @@ private:
//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//
void CopyOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
ValueRange captures) {
void CopyOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block) {
assert(block.getNumArguments() == 2 && "CopyOp regionBuilder expects 2 args");
b.create<linalg::YieldOp>(block.getArgument(0));
}
@ -403,8 +399,7 @@ void CopyOp::getEffects(
//===----------------------------------------------------------------------===//
// FillOp
//===----------------------------------------------------------------------===//
void FillOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
ValueRange captures) {
void FillOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block) {
assert(block.getNumArguments() == 2 && "FillOp regionBuilder expects 2 args");
b.create<linalg::YieldOp>(block.getArgument(0));
}
@ -2799,7 +2794,6 @@ template <typename NamedStructuredOpType>
static void
fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
TypeRange inputTypes, TypeRange outputTypes,
ValueRange captures,
std::function<void(unsigned, unsigned)> errorHandler) {
assert(llvm::all_of(outputTypes, [](Type t) { return t.isa<ShapedType>(); }));
@ -2823,7 +2817,7 @@ fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
opBuilder.setInsertionPointToStart(body);
ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
NamedStructuredOpType::regionBuilder(b, *body, captures);
NamedStructuredOpType::regionBuilder(b, *body);
// indexing_maps is an auto-generated method.
@ -2835,11 +2829,10 @@ template <typename NamedStructuredOpType>
void createAndFillStructuredOpRegion(OpBuilder &opBuilder,
OperationState &result,
TypeRange inputTypes,
TypeRange outputTypes,
ValueRange captures) {
TypeRange outputTypes) {
Region &region = *result.addRegion();
fillStructuredOpRegion<NamedStructuredOpType>(
opBuilder, region, inputTypes, outputTypes, captures,
opBuilder, region, inputTypes, outputTypes,
[&](unsigned expected, unsigned actual) {
assert(expected != actual && "incorrect number of arguments");
});
@ -2902,15 +2895,14 @@ static void printCommonStructuredOpParts(OpAsmPrinter &p,
template <typename NamedStructuredOpType>
static ParseResult
parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region,
TypeRange inputTypes, TypeRange outputTypes,
ArrayRef<OpAsmParser::OperandType> captures) {
TypeRange inputTypes, TypeRange outputTypes) {
ParseResult res = success();
OpBuilder opBuilder(parser.getBuilder().getContext());
// Resolve `captures` into `capturedValues` at parse time so we can build the
// region with captures.
SmallVector<Value> capturedValues;
fillStructuredOpRegion<NamedStructuredOpType>(
opBuilder, region, inputTypes, outputTypes, capturedValues,
opBuilder, region, inputTypes, outputTypes,
[&](unsigned expected, unsigned actual) {
res = parser.emitError(
parser.getCurrentLocation(),
@ -2931,11 +2923,9 @@ parseNamedStructuredOpResults(OpAsmParser &parser,
}
template <typename NamedStructuredOpType>
static ParseResult
parseNamedStructuredOp(OpAsmParser &parser, OperationState &result,
ArrayRef<OpAsmParser::OperandType> captures) {
static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
OperationState &result) {
// TODO: Enable when ods-gen supports captures.
assert(captures.empty() && "unexpected captures for named structured ops");
SmallVector<Type, 1> inputTypes, outputTypes;
if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
return failure();
@ -2949,7 +2939,7 @@ parseNamedStructuredOp(OpAsmParser &parser, OperationState &result,
std::unique_ptr<Region> region = std::make_unique<Region>();
if (parseNamedStructuredOpRegion<NamedStructuredOpType>(
parser, *region, inputTypes, outputTypes, captures))
parser, *region, inputTypes, outputTypes))
return failure();
result.addRegion(std::move(region));

View File

@ -63,8 +63,7 @@ static GenericOp createGenericOpFromNamedOp(LinalgOp namedOp,
iterators,
[&regionBuilder](OpBuilder &bodyBuilder, Location loc, ValueRange) {
ImplicitLocOpBuilder b(loc, bodyBuilder);
regionBuilder(b, *bodyBuilder.getBlock(),
/*captures=*/{});
regionBuilder(b, *bodyBuilder.getBlock());
});
}

View File

@ -33,7 +33,7 @@ class FillOp:
ip=ip)
OpView.__init__(self, op)
linalgDialect = Context.current.get_dialect_descriptor("linalg")
fill_builtin_region(linalgDialect, self.operation, [])
fill_builtin_region(linalgDialect, self.operation)
# TODO: self.result is None. When len(results) == 1 we expect it to be
# results[0] as per _linalg_ops_gen.py. This seems like an orthogonal bug
# in the generator of _linalg_ops_gen.py where we have:

View File

@ -24,7 +24,7 @@
// IMPL-NEXT: return {{.+}}.getAffineMapArrayAttr({ map0, map1, map2 });
//
// IMPL: void Test1Op::regionBuilder(ImplicitLocOpBuilder &b,
// IMPL: Block &block, ValueRange captures) {
// IMPL: Block &block) {
// IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
// IMPL: Value [[d:.*]] = b.create<MulFOp>([[a]], [[b]]);
// IMPL: Value [[e:.*]] = b.create<AddFOp>([[c]], [[d]]);
@ -49,7 +49,7 @@ def test1(A: f32(M, K), B: f32(K)) -> (C: f32(M)) {
// IMPL: AffineMap::get(3, 3, {d0, d1}, context)
//
// IMPL: Test2Op::regionBuilder(ImplicitLocOpBuilder &b,
// IMPL: Block &block, ValueRange captures) {
// IMPL: Block &block) {
// IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
// IMPL: Value [[d:.*]] = b.create<MulFOp>([[a]], [[b]]);
// IMPL: Value [[e:.*]] = b.create<AddFOp>([[c]], [[d]]);
@ -74,7 +74,7 @@ def test2(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) {
// IMPL: AffineMap::get(4, 4, {d0, d1, d2}, context)
//
// IMPL: Test3Op::regionBuilder(ImplicitLocOpBuilder &b,
// IMPL: Block &block, ValueRange captures) {
// IMPL: Block &block) {
// IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
// IMPL: Value [[d:.*]] = b.create<MulFOp>([[a]], [[b]]);
// IMPL: Value [[e:.*]] = b.create<AddFOp>([[c]], [[d]]);
@ -182,7 +182,7 @@ def test7(A: f32(M, K), B: f32(K)) -> (C: f32(M))
// Test output arg order.
// IMPL-LABEL: void Test8Op::regionBuilder(ImplicitLocOpBuilder &b,
// IMPL: Block &block, ValueRange captures) {
// IMPL: Block &block) {
// IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
// IMPL: Value [[d:.*]] = b.create<MulFOp>([[a]], [[b]]);
// IMPL: Value [[e:.*]] = b.create<SubFOp>([[d]], [[c]]);
@ -199,7 +199,7 @@ def test8(A: f32(M, K), B: f32(K)) -> (C: f32(M))
// IMPL: auto map1 = AffineMap::get(2, 2, {d1}, context);
// IMPL: auto map2 = AffineMap::get(2, 2, {d0}, context);
// IMPL-LABEL: void Test9Op::regionBuilder(ImplicitLocOpBuilder &b,
// IMPL: Block &block, ValueRange captures) {
// IMPL: Block &block) {
// IMPL: Value [[a:.*]](args[0]), [[c:.*]](args[2]);
ods_def<Test9Op>:
def test9(A: f32(M, K), B: f32(K)) -> (C: f32(M))

View File

@ -76,7 +76,7 @@ structured_op: !LinalgStructuredOpConfig
# ODS-NEXT: TypeRange(outputs)
# IMPL-LABEL: void Test1Op::regionBuilder(
# IMPL: ImplicitLocOpBuilder &b, Block &block, ValueRange captures)
# IMPL: ImplicitLocOpBuilder &b, Block &block)
# IMPL: Value [[VAL0:[a-z0-9]+]] = helper.constant("42 : i64");
# IMPL-DAG: Value [[VAL1:[a-z0-9]+]] = helper.cast(block.getArgument(0).getType(), [[VAL0]]);
# IMPL-DAG: Value [[VAL2:[a-z0-9]+]] = helper.index(1);
@ -163,8 +163,7 @@ structured_op: !LinalgStructuredOpConfig
# IMPL: auto attr = op->getAttrOfType<DenseElementsAttr>("strides")
# IMPL: "missing indexing map required attribute 'strides'"
# IMPL: void Test2Op::regionBuilder(
# IMPL-NEXT: ImplicitLocOpBuilder &b, Block &block, ValueRange captures)
# IMPL: void Test2Op::regionBuilder(ImplicitLocOpBuilder &b, Block &block)
# IMPL-NEXT: assert(2 > 0 && block.getNumArguments() == 2 &&
# IMPL: yields.push_back(block.getArgument(0));

View File

@ -1923,7 +1923,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
$_builder,
$_state,
TypeRange(inputs),
TypeRange(outputs)/*, TODO: support captures*/);
TypeRange(outputs));
}]>,
OpBuilder<
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
@ -1941,7 +1941,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
$_builder,
$_state,
TypeRange(inputs),
TypeRange(outputs)/*, TODO: support captures*/);
TypeRange(outputs));
}]>,
OpBuilder<
(ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
@ -1956,7 +1956,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
];
let printer = [{{ return ::printNamedStructuredOp(p, *this); }];
let parser = [{{
return ::parseNamedStructuredOp<{0}>(parser, result/*TODO:, captures*/);
return ::parseNamedStructuredOp<{0}>(parser, result);
}];
let hasFolder = 1;
@ -1964,10 +1964,9 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
// Auto-generated.
ArrayAttr iterator_types();
ArrayAttr indexing_maps();
static void regionBuilder(ImplicitLocOpBuilder &b,
Block &block, ValueRange captures);
static std::function<void(ImplicitLocOpBuilder &b,
Block &, ValueRange)> getRegionBuilder() {{
static void regionBuilder(ImplicitLocOpBuilder &b, Block &block);
static std::function<void(ImplicitLocOpBuilder &b, Block &)>
getRegionBuilder() {{
return regionBuilder;
}
@ -2035,7 +2034,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
$_builder,
$_state,
TypeRange(inputs),
TypeRange(outputs)/*, TODO: support captures*/);
TypeRange(outputs));
{2}
}]>
)FMT";
@ -2354,8 +2353,7 @@ void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName,
};
const char *regionBuilderFmt = R"FMT(
void {0}::regionBuilder(ImplicitLocOpBuilder &b,
Block &block, ValueRange captures) {
void {0}::regionBuilder(ImplicitLocOpBuilder &b, Block &block) {
auto args = block.getArguments();
Value {1};
{2}

View File

@ -511,10 +511,8 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([
// Auto-generated.
ArrayAttr iterator_types();
ArrayAttr indexing_maps();
static void regionBuilder(
ImplicitLocOpBuilder &b, Block &block, ValueRange captures);
static std::function<
void(ImplicitLocOpBuilder &b, Block &, ValueRange)>
static void regionBuilder(ImplicitLocOpBuilder &b, Block &block);
static std::function<void(ImplicitLocOpBuilder &b, Block &)>
getRegionBuilder() {{
return regionBuilder;
}
@ -883,8 +881,7 @@ LogicalResult {0}::verifyIndexingMapRequiredAttributes() {{
// {1}: Number of args
// {2}: Statements
static const char structuredOpRegionBuilderFormat[] = R"FMT(
void {0}::regionBuilder(
ImplicitLocOpBuilder &b, Block &block, ValueRange captures) {{
void {0}::regionBuilder(ImplicitLocOpBuilder &b, Block &block) {{
assert({1} > 0 && block.getNumArguments() == {1} &&
"{0} regionBuilder expects {1} (>=0) args");
RegionBuilderHelper helper(block.getArgument(0).getContext(), block);