[MLIR] Change ODS collective params build method to provide an empty default value for named attributes

- Provide default value for `ArrayRef<NamedAttribute> attributes` parameter of
  the collective params build method.
- Change the `genSeparateArgParamBuilder` function to not generate build methods
  that may be ambiguous with the new collective params build method.
- This change should help eliminate passing empty NamedAttribue ArrayRef when the
  collective params build method is used
- Extend op-decl.td unit test to make sure the ambiguous build methods are not
  generated.

Differential Revision: https://reviews.llvm.org/D83517
This commit is contained in:
Rahul Joshi 2020-07-13 11:20:27 -07:00
parent f630b8590f
commit 0d988da6d1
4 changed files with 108 additions and 20 deletions

View File

@ -331,8 +331,7 @@ public:
return operation.emitError( return operation.emitError(
"bitwidth emulation is not implemented yet on unsigned op"); "bitwidth emulation is not implemented yet on unsigned op");
} }
rewriter.template replaceOpWithNewOp<SPIRVOp>(operation, dstType, operands, rewriter.template replaceOpWithNewOp<SPIRVOp>(operation, dstType, operands);
ArrayRef<NamedAttribute>());
return success(); return success();
} }
}; };
@ -368,11 +367,11 @@ public:
if (!dstType) if (!dstType)
return failure(); return failure();
if (isBoolScalarOrVector(operands.front().getType())) { if (isBoolScalarOrVector(operands.front().getType())) {
rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>( rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(operation, dstType,
operation, dstType, operands, ArrayRef<NamedAttribute>()); operands);
} else { } else {
rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>( rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(operation, dstType,
operation, dstType, operands, ArrayRef<NamedAttribute>()); operands);
} }
return success(); return success();
} }
@ -529,8 +528,8 @@ public:
// Then we can just erase this operation by forwarding its operand. // Then we can just erase this operation by forwarding its operand.
rewriter.replaceOp(operation, operands.front()); rewriter.replaceOp(operation, operands.front());
} else { } else {
rewriter.template replaceOpWithNewOp<SPIRVOp>( rewriter.template replaceOpWithNewOp<SPIRVOp>(operation, dstType,
operation, dstType, operands, ArrayRef<NamedAttribute>()); operands);
} }
return success(); return success();
} }
@ -1046,8 +1045,7 @@ XOrOpPattern::matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands,
auto dstType = typeConverter.convertType(xorOp.getType()); auto dstType = typeConverter.convertType(xorOp.getType());
if (!dstType) if (!dstType)
return failure(); return failure();
rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(xorOp, dstType, operands, rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(xorOp, dstType, operands);
ArrayRef<NamedAttribute>());
return success(); return success();
} }

View File

@ -418,8 +418,7 @@ Value Importer::processConstant(llvm::Constant *c) {
} }
if (auto *GV = dyn_cast<llvm::GlobalVariable>(c)) if (auto *GV = dyn_cast<llvm::GlobalVariable>(c))
return bEntry.create<AddressOfOp>(UnknownLoc::get(context), return bEntry.create<AddressOfOp>(UnknownLoc::get(context),
processGlobal(GV), processGlobal(GV));
ArrayRef<NamedAttribute>());
if (auto *ce = dyn_cast<llvm::ConstantExpr>(c)) { if (auto *ce = dyn_cast<llvm::ConstantExpr>(c)) {
llvm::Instruction *i = ce->getAsInstruction(); llvm::Instruction *i = ce->getAsInstruction();
@ -727,7 +726,7 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
if (!calledValue) if (!calledValue)
return failure(); return failure();
ops.insert(ops.begin(), calledValue); ops.insert(ops.begin(), calledValue);
op = b.create<CallOp>(loc, tys, ops, ArrayRef<NamedAttribute>()); op = b.create<CallOp>(loc, tys, ops);
} }
if (!ci->getType()->isVoidTy()) if (!ci->getType()->isVoidTy())
v = op->getResult(0); v = op->getResult(0);
@ -809,7 +808,7 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
Type type = processType(inst->getType()); Type type = processType(inst->getType());
if (!type) if (!type)
return failure(); return failure();
v = b.create<GEPOp>(loc, type, ops, ArrayRef<NamedAttribute>()); v = b.create<GEPOp>(loc, type, ops);
return success(); return success();
} }
} }

View File

@ -171,6 +171,56 @@ def NS_GOp : NS_Op<"op_with_fixed_return_type", []> {
// CHECK-LABEL: class GOp : // CHECK-LABEL: class GOp :
// CHECK: static ::mlir::LogicalResult inferReturnTypes // CHECK: static ::mlir::LogicalResult inferReturnTypes
// Check default value for collective params builder. Check that other builders
// are generated as well.
def NS_HCollectiveParamsOp : NS_Op<"op_collective_params", []> {
let arguments = (ins AnyType:$a);
let results = (outs AnyType:$b);
}
// CHECK_LABEL: class NS_HCollectiveParamsOp :
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type b, ::mlir::Value a);
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::Value a);
// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {})
// Check suppression of "separate arg, separate result" build method for an op
// with single variadic arg and single variadic result (since it will be
// ambiguous with the collective params build method).
def NS_HCollectiveParamsSuppress0Op : NS_Op<"op_collective_suppress0", []> {
let arguments = (ins Variadic<I32>:$a);
let results = (outs Variadic<I32>:$b);
}
// CHECK_LABEL: class NS_HCollectiveParamsSuppress0Op :
// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> b, ::mlir::ValueRange a);
// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
// Check suppression of "separate arg, collective result" build method for an op
// with single variadic arg and non variadic result (since it will be
// ambiguous with the collective params build method).
def NS_HCollectiveParamsSuppress1Op : NS_Op<"op_collective_suppress1", []> {
let arguments = (ins Variadic<I32>:$a);
let results = (outs I32:$b);
}
// CHECK_LABEL: class NS_HCollectiveParamsSuppress1Op :
// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> b, ::mlir::ValueRange a);
// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
// Check suppression of "separate arg, collective result" build method for an op
// with single variadic arg and > 1 variadic result (since it will be
// ambiguous with the collective params build method). Note that "separate arg,
// separate result" build method should be generated in this case as its not
// ambiguous with the collective params build method.
def NS_HCollectiveParamsSuppress2Op : NS_Op<"op_collective_suppress2", [SameVariadicResultSize]> {
let arguments = (ins Variadic<I32>:$a);
let results = (outs Variadic<I32>:$b, Variadic<F32>:$c);
}
// CHECK_LABEL: class NS_HCollectiveParamsSuppress2Op :
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> b, ::llvm::ArrayRef<::mlir::Type> c, ::mlir::ValueRange a);
// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> b, ::mlir::ValueRange a);
// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
// Check that default builders can be suppressed. // Check that default builders can be suppressed.
// --- // ---

View File

@ -955,14 +955,51 @@ void OpEmitter::genSeparateArgParamBuilder() {
llvm_unreachable("unhandled TypeParamKind"); llvm_unreachable("unhandled TypeParamKind");
}; };
// A separate arg param builder method will have a signature which is
// ambiguous with the collective params build method (generated in
// `genCollectiveParamBuilder` function below) if it has a single
// `ArrayReg<Type>` parameter for result types and a single `ArrayRef<Value>`
// parameter for the operands, no parameters after that, and the collective
// params build method has `attributes` as its last parameter (with
// a default value). This will happen when all of the following are true:
// 1. [`attributes` as last parameter in collective params build method]:
// getNumVariadicRegions must be 0 (otherwise the collective params build
// method ends with a `numRegions` param, and we don't specify default
// value for attributes).
// 2. [single `ArrayRef<Value>` parameter for operands, and no parameters
// after that]: numArgs() must be 1 (if not, each arg gets a separate param
// in the build methods generated here) and the single arg must be a
// non-attribute variadic argument.
// 3. [single `ArrayReg<Type>` parameter for result types]:
// 3a. paramKind should be Collective, or
// 3b. paramKind should be Separate and there should be a single variadic
// result
//
// In that case, skip generating such ambiguous build methods here.
bool hasSingleVariadicResult =
op.getNumResults() == 1 && op.getResult(0).isVariadic();
bool hasSingleVariadicArg =
op.getNumArgs() == 1 &&
op.getArg(0).is<tblgen::NamedTypeConstraint *>() &&
op.getOperand(0).isVariadic();
bool hasNoVariadicRegions = op.getNumVariadicRegions() == 0;
for (auto attrType : attrBuilderType) { for (auto attrType : attrBuilderType) {
emit(attrType, TypeParamKind::Separate, /*inferType=*/false); // Case 3b above.
if (!(hasNoVariadicRegions && hasSingleVariadicArg &&
hasSingleVariadicResult))
emit(attrType, TypeParamKind::Separate, /*inferType=*/false);
if (canInferType(op)) if (canInferType(op))
emit(attrType, TypeParamKind::None, /*inferType=*/true); emit(attrType, TypeParamKind::None, /*inferType=*/true);
// Emit separate arg build with collective type, unless there is only one // The separate arg + collective param kind method will be:
// variadic result, in which case the above would have already generated // (a) Same as the separate arg + separate param kind method if there is
// the same build method. // only one variadic result.
if (!(op.getNumResults() == 1 && op.getResult(0).isVariableLength())) // (b) Ambiguous with the collective params method under conditions in (3a)
// above.
// In either case, skip generating such build method.
if (!hasSingleVariadicResult &&
!(hasNoVariadicRegions && hasSingleVariadicArg))
emit(attrType, TypeParamKind::Collective, /*inferType=*/false); emit(attrType, TypeParamKind::Collective, /*inferType=*/false);
} }
} }
@ -1184,8 +1221,12 @@ void OpEmitter::genCollectiveParamBuilder() {
", ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange " ", ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange "
"operands, " "operands, "
"::llvm::ArrayRef<::mlir::NamedAttribute> attributes"; "::llvm::ArrayRef<::mlir::NamedAttribute> attributes";
if (op.getNumVariadicRegions()) if (op.getNumVariadicRegions()) {
params += ", unsigned numRegions"; params += ", unsigned numRegions";
} else {
// Provide default value for `attributes` since its the last parameter
params += " = {}";
}
auto &m = opClass.newMethod("void", "build", params, OpMethod::MP_Static); auto &m = opClass.newMethod("void", "build", params, OpMethod::MP_Static);
auto &body = m.body(); auto &body = m.body();