[mlir:Standard] Remove support for creating a `unit` ConstantOp

This is completely unused upstream, and does not really have well defined semantics
on what this is supposed to do/how this fits into the ecosystem. Given that, as part of
splitting up the standard dialect it's best to just remove this behavior, instead of try
to awkwardly fit it somewhere upstream. Downstream users are encouraged to
define their own operations that clearly can define the semantics of this.

This also uncovered several lingering uses of ConstantOp that weren't
updated to use arith::ConstantOp, and worked during conversions because
the constant was removed/converted into something else before
verification.

See https://llvm.discourse.group/t/standard-dialect-the-final-chapter/ for more discussion.

Differential Revision: https://reviews.llvm.org/D118654
This commit is contained in:
River Riddle 2022-01-31 13:53:22 -08:00
parent ead1107257
commit 8e123ca65f
12 changed files with 40 additions and 127 deletions

View File

@ -72,7 +72,7 @@ LLVM_ATTRIBUTE_UNUSED static bool needToMaterialize(mlir::Value str) {
/// Unwrap integer constant from mlir::Value.
static llvm::Optional<std::int64_t> getIntIfConstant(mlir::Value value) {
if (auto *definingOp = value.getDefiningOp())
if (auto cst = mlir::dyn_cast<mlir::ConstantOp>(definingOp))
if (auto cst = mlir::dyn_cast<mlir::arith::ConstantOp>(definingOp))
if (auto intAttr = cst.getValue().dyn_cast<mlir::IntegerAttr>())
return intAttr.getInt();
return {};

View File

@ -376,23 +376,16 @@ def ConstantOp : Std_Op<"constant",
operation ::= ssa-id `=` `std.constant` attribute-value `:` type
```
The `constant` operation produces an SSA value equal to some constant
specified by an attribute. This is the way that MLIR uses to form simple
integer and floating point constants, as well as more exotic things like
references to functions and tensor/vector constants.
The `constant` operation produces an SSA value from a symbol reference to a
`builtin.func` operation
Example:
```mlir
// Complex constant
%1 = constant [1.0 : f32, 1.0 : f32] : complex<f32>
// Reference to function @myfn.
%2 = constant @myfn : (tensor<16xf32>, f32) -> tensor<16xf32>
// Equivalent generic forms
%1 = "std.constant"() {value = [1.0 : f32, 1.0 : f32] : complex<f32>}
: () -> complex<f32>
%2 = "std.constant"() {value = @myfn}
: () -> ((tensor<16xf32>, f32) -> tensor<16xf32>)
```
@ -403,15 +396,9 @@ def ConstantOp : Std_Op<"constant",
([rationale](../Rationale/Rationale.md#multithreading-the-compiler)).
}];
let arguments = (ins AnyAttr:$value);
let arguments = (ins FlatSymbolRefAttr:$value);
let results = (outs AnyType);
let builders = [
OpBuilder<(ins "Attribute":$value),
[{ build($_builder, $_state, value.getType(), value); }]>,
OpBuilder<(ins "Attribute":$value, "Type":$type),
[{ build($_builder, $_state, type, value); }]>,
];
let assemblyFormat = "attr-dict $value `:` type(results)";
let extraClassDeclaration = [{
/// Returns true if a constant operation can be built with the given value

View File

@ -435,31 +435,19 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern<ConstantOp> {
LogicalResult
matchAndRewrite(ConstantOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// If constant refers to a function, convert it to "addressof".
if (auto symbolRef = op.getValue().dyn_cast<FlatSymbolRefAttr>()) {
auto type = typeConverter->convertType(op.getResult().getType());
if (!type || !LLVM::isCompatibleType(type))
return rewriter.notifyMatchFailure(op, "failed to convert result type");
auto type = typeConverter->convertType(op.getResult().getType());
if (!type || !LLVM::isCompatibleType(type))
return rewriter.notifyMatchFailure(op, "failed to convert result type");
auto newOp = rewriter.create<LLVM::AddressOfOp>(op.getLoc(), type,
symbolRef.getValue());
for (const NamedAttribute &attr : op->getAttrs()) {
if (attr.getName().strref() == "value")
continue;
newOp->setAttr(attr.getName(), attr.getValue());
}
rewriter.replaceOp(op, newOp->getResults());
return success();
auto newOp =
rewriter.create<LLVM::AddressOfOp>(op.getLoc(), type, op.getValue());
for (const NamedAttribute &attr : op->getAttrs()) {
if (attr.getName().strref() == "value")
continue;
newOp->setAttr(attr.getName(), attr.getValue());
}
// Calling into other scopes (non-flat reference) is not supported in LLVM.
if (op.getValue().isa<SymbolRefAttr>())
return rewriter.notifyMatchFailure(
op, "referring to a symbol outside of the current module");
return LLVM::detail::oneToOneRewrite(
op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(),
*getTypeConverter(), rewriter);
rewriter.replaceOp(op, newOp->getResults());
return success();
}
};

View File

@ -291,7 +291,7 @@ static ParallelComputeFunction createParallelComputeFunction(
return llvm::to_vector(
llvm::map_range(llvm::zip(args, attrs), [&](auto tuple) -> Value {
if (IntegerAttr attr = std::get<1>(tuple))
return b.create<ConstantOp>(attr);
return b.create<arith::ConstantOp>(attr);
return std::get<0>(tuple);
}));
};

View File

@ -1576,7 +1576,7 @@ public:
isFloat ? DenseElementsAttr::get(outputType, fpOutputValues)
: DenseElementsAttr::get(outputType, intOutputValues);
rewriter.replaceOpWithNewOp<ConstantOp>(genericOp, outputAttr);
rewriter.replaceOpWithNewOp<arith::ConstantOp>(genericOp, outputAttr);
return success();
}

View File

@ -145,7 +145,7 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
// Stitch results together into one large vector.
Type resultEltType = results[0].getType().cast<VectorType>().getElementType();
Type resultExpandedType = VectorType::get(expandedShape, resultEltType);
Value result = builder.create<ConstantOp>(
Value result = builder.create<arith::ConstantOp>(
resultExpandedType, builder.getZeroAttr(resultExpandedType));
for (int64_t i = 0; i < maxLinearIndex; ++i)

View File

@ -115,7 +115,10 @@ Operation *StandardOpsDialect::materializeConstant(OpBuilder &builder,
Location loc) {
if (arith::ConstantOp::isBuildableWith(value, type))
return builder.create<arith::ConstantOp>(loc, type, value);
return builder.create<ConstantOp>(loc, type, value);
if (ConstantOp::isBuildableWith(value, type))
return builder.create<ConstantOp>(loc, type,
value.cast<FlatSymbolRefAttr>());
return nullptr;
}
//===----------------------------------------------------------------------===//
@ -562,97 +565,35 @@ Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
// ConstantOp
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, ConstantOp &op) {
p << " ";
p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"value"});
if (op->getAttrs().size() > 1)
p << ' ';
p << op.getValue();
// If the value is a symbol reference, print a trailing type.
if (op.getValue().isa<SymbolRefAttr>())
p << " : " << op.getType();
}
static ParseResult parseConstantOp(OpAsmParser &parser,
OperationState &result) {
Attribute valueAttr;
if (parser.parseOptionalAttrDict(result.attributes) ||
parser.parseAttribute(valueAttr, "value", result.attributes))
return failure();
// If the attribute is a symbol reference, then we expect a trailing type.
Type type;
if (!valueAttr.isa<SymbolRefAttr>())
type = valueAttr.getType();
else if (parser.parseColonType(type))
return failure();
// Add the attribute type to the list.
return parser.addTypeToList(type, result.types);
}
/// The constant op requires an attribute, and furthermore requires that it
/// matches the return type.
LogicalResult ConstantOp::verify() {
auto value = getValue();
if (!value)
return emitOpError("requires a 'value' attribute");
StringRef fnName = getValue();
Type type = getType();
if (!value.getType().isa<NoneType>() && type != value.getType())
return emitOpError() << "requires attribute's type (" << value.getType()
<< ") to match op's return type (" << type << ")";
if (type.isa<FunctionType>()) {
auto fnAttr = value.dyn_cast<FlatSymbolRefAttr>();
if (!fnAttr)
return emitOpError("requires 'value' to be a function reference");
// Try to find the referenced function.
auto fn = (*this)->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnName);
if (!fn)
return emitOpError() << "reference to undefined function '" << fnName
<< "'";
// Try to find the referenced function.
auto fn = (*this)->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(
fnAttr.getValue());
if (!fn)
return emitOpError() << "reference to undefined function '"
<< fnAttr.getValue() << "'";
// Check that the referenced function has the correct type.
if (fn.getType() != type)
return emitOpError("reference to function with mismatched type");
// Check that the referenced function has the correct type.
if (fn.getType() != type)
return emitOpError("reference to function with mismatched type");
return success();
}
if (type.isa<NoneType>() && value.isa<UnitAttr>())
return success();
return emitOpError("unsupported 'value' attribute: ") << value;
return success();
}
OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
assert(operands.empty() && "constant has no operands");
return getValue();
return getValueAttr();
}
void ConstantOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
Type type = getType();
if (type.isa<FunctionType>()) {
setNameFn(getResult(), "f");
} else {
setNameFn(getResult(), "cst");
}
setNameFn(getResult(), "f");
}
/// Returns true if a constant operation can be built with the given value and
/// result type.
bool ConstantOp::isBuildableWith(Attribute value, Type type) {
// SymbolRefAttr can only be used with a function type.
if (value.isa<SymbolRefAttr>())
return type.isa<FunctionType>();
// Otherwise, this must be a UnitAttr.
return value.isa<UnitAttr>() && type.isa<NoneType>();
return value.isa<FlatSymbolRefAttr>() && type.isa<FunctionType>();
}
//===----------------------------------------------------------------------===//

View File

@ -307,7 +307,7 @@ struct TwoDimMultiReductionToReduction
return failure();
auto loc = multiReductionOp.getLoc();
Value result = rewriter.create<ConstantOp>(
Value result = rewriter.create<arith::ConstantOp>(
loc, multiReductionOp.getDestType(),
rewriter.getZeroAttr(multiReductionOp.getDestType()));
int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];

View File

@ -232,7 +232,7 @@ static LogicalResult printOperation(CppEmitter &emitter,
static LogicalResult printOperation(CppEmitter &emitter,
mlir::ConstantOp constantOp) {
Operation *operation = constantOp.getOperation();
Attribute value = constantOp.getValue();
Attribute value = constantOp.getValueAttr();
return printConstantOp(emitter, operation, value);
}

View File

@ -1,7 +1,7 @@
// RUN: mlir-opt -split-input-file %s -verify-diagnostics
func @unsupported_attribute() {
// expected-error @+1 {{unsupported 'value' attribute: "" : index}}
// expected-error @+1 {{invalid kind of attribute specified}}
%0 = constant "" : index
return
}

View File

@ -99,9 +99,6 @@ func @standard_instrs(tensor<4x4x?xf32>, f32, i32, index, i64, f16) {
// CHECK: %{{.*}} = arith.cmpf oeq, %{{.*}}, %{{.*}}: vector<4xf32>
%70 = arith.cmpf oeq, %vcf32, %vcf32 : vector<4 x f32>
// CHECK: = constant unit
%73 = constant unit
// CHECK: arith.constant true
%74 = arith.constant true

View File

@ -578,7 +578,7 @@ struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> {
LogicalResult matchAndRewrite(ILLegalOpG op,
PatternRewriter &rewriter) const final {
IntegerAttr attr = rewriter.getI32IntegerAttr(0);
Value val = rewriter.create<ConstantOp>(op->getLoc(), attr);
Value val = rewriter.create<arith::ConstantOp>(op->getLoc(), attr);
rewriter.replaceOpWithNewOp<LegalOpC>(op, val);
return success();
};