diff --git a/flang/lib/Optimizer/Builder/Character.cpp b/flang/lib/Optimizer/Builder/Character.cpp index 87faa3b42c44..e4719133f3fa 100644 --- a/flang/lib/Optimizer/Builder/Character.cpp +++ b/flang/lib/Optimizer/Builder/Character.cpp @@ -72,7 +72,7 @@ LLVM_ATTRIBUTE_UNUSED static bool needToMaterialize(mlir::Value str) { /// Unwrap integer constant from mlir::Value. static llvm::Optional getIntIfConstant(mlir::Value value) { if (auto *definingOp = value.getDefiningOp()) - if (auto cst = mlir::dyn_cast(definingOp)) + if (auto cst = mlir::dyn_cast(definingOp)) if (auto intAttr = cst.getValue().dyn_cast()) return intAttr.getInt(); return {}; diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td index 9efe4ceb2147..2aca33eda3c4 100644 --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -376,23 +376,16 @@ def ConstantOp : Std_Op<"constant", operation ::= ssa-id `=` `std.constant` attribute-value `:` type ``` - The `constant` operation produces an SSA value equal to some constant - specified by an attribute. This is the way that MLIR uses to form simple - integer and floating point constants, as well as more exotic things like - references to functions and tensor/vector constants. + The `constant` operation produces an SSA value from a symbol reference to a + `builtin.func` operation Example: ```mlir - // Complex constant - %1 = constant [1.0 : f32, 1.0 : f32] : complex - // Reference to function @myfn. %2 = constant @myfn : (tensor<16xf32>, f32) -> tensor<16xf32> // Equivalent generic forms - %1 = "std.constant"() {value = [1.0 : f32, 1.0 : f32] : complex} - : () -> complex %2 = "std.constant"() {value = @myfn} : () -> ((tensor<16xf32>, f32) -> tensor<16xf32>) ``` @@ -403,15 +396,9 @@ def ConstantOp : Std_Op<"constant", ([rationale](../Rationale/Rationale.md#multithreading-the-compiler)). }]; - let arguments = (ins AnyAttr:$value); + let arguments = (ins FlatSymbolRefAttr:$value); let results = (outs AnyType); - - let builders = [ - OpBuilder<(ins "Attribute":$value), - [{ build($_builder, $_state, value.getType(), value); }]>, - OpBuilder<(ins "Attribute":$value, "Type":$type), - [{ build($_builder, $_state, type, value); }]>, - ]; + let assemblyFormat = "attr-dict $value `:` type(results)"; let extraClassDeclaration = [{ /// Returns true if a constant operation can be built with the given value diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp index b2e18ab8196f..04c51422ed11 100644 --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -435,31 +435,19 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern { LogicalResult matchAndRewrite(ConstantOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // If constant refers to a function, convert it to "addressof". - if (auto symbolRef = op.getValue().dyn_cast()) { - auto type = typeConverter->convertType(op.getResult().getType()); - if (!type || !LLVM::isCompatibleType(type)) - return rewriter.notifyMatchFailure(op, "failed to convert result type"); + auto type = typeConverter->convertType(op.getResult().getType()); + if (!type || !LLVM::isCompatibleType(type)) + return rewriter.notifyMatchFailure(op, "failed to convert result type"); - auto newOp = rewriter.create(op.getLoc(), type, - symbolRef.getValue()); - for (const NamedAttribute &attr : op->getAttrs()) { - if (attr.getName().strref() == "value") - continue; - newOp->setAttr(attr.getName(), attr.getValue()); - } - rewriter.replaceOp(op, newOp->getResults()); - return success(); + auto newOp = + rewriter.create(op.getLoc(), type, op.getValue()); + for (const NamedAttribute &attr : op->getAttrs()) { + if (attr.getName().strref() == "value") + continue; + newOp->setAttr(attr.getName(), attr.getValue()); } - - // Calling into other scopes (non-flat reference) is not supported in LLVM. - if (op.getValue().isa()) - return rewriter.notifyMatchFailure( - op, "referring to a symbol outside of the current module"); - - return LLVM::detail::oneToOneRewrite( - op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(), - *getTypeConverter(), rewriter); + rewriter.replaceOp(op, newOp->getResults()); + return success(); } }; diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp index e1c91fbbc1d9..74d6d42e2b9b 100644 --- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp @@ -291,7 +291,7 @@ static ParallelComputeFunction createParallelComputeFunction( return llvm::to_vector( llvm::map_range(llvm::zip(args, attrs), [&](auto tuple) -> Value { if (IntegerAttr attr = std::get<1>(tuple)) - return b.create(attr); + return b.create(attr); return std::get<0>(tuple); })); }; diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index 19e3c0318f57..32fd370012c4 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1576,7 +1576,7 @@ public: isFloat ? DenseElementsAttr::get(outputType, fpOutputValues) : DenseElementsAttr::get(outputType, intOutputValues); - rewriter.replaceOpWithNewOp(genericOp, outputAttr); + rewriter.replaceOpWithNewOp(genericOp, outputAttr); return success(); } diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp index 1237f0b47cf7..d47d6ead0273 100644 --- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp +++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp @@ -145,7 +145,7 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, // Stitch results together into one large vector. Type resultEltType = results[0].getType().cast().getElementType(); Type resultExpandedType = VectorType::get(expandedShape, resultEltType); - Value result = builder.create( + Value result = builder.create( resultExpandedType, builder.getZeroAttr(resultExpandedType)); for (int64_t i = 0; i < maxLinearIndex; ++i) diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp index 65e72293ed3f..bf35625adb62 100644 --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -115,7 +115,10 @@ Operation *StandardOpsDialect::materializeConstant(OpBuilder &builder, Location loc) { if (arith::ConstantOp::isBuildableWith(value, type)) return builder.create(loc, type, value); - return builder.create(loc, type, value); + if (ConstantOp::isBuildableWith(value, type)) + return builder.create(loc, type, + value.cast()); + return nullptr; } //===----------------------------------------------------------------------===// @@ -562,97 +565,35 @@ Block *CondBranchOp::getSuccessorForOperands(ArrayRef operands) { // ConstantOp //===----------------------------------------------------------------------===// -static void print(OpAsmPrinter &p, ConstantOp &op) { - p << " "; - p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"value"}); - - if (op->getAttrs().size() > 1) - p << ' '; - p << op.getValue(); - - // If the value is a symbol reference, print a trailing type. - if (op.getValue().isa()) - p << " : " << op.getType(); -} - -static ParseResult parseConstantOp(OpAsmParser &parser, - OperationState &result) { - Attribute valueAttr; - if (parser.parseOptionalAttrDict(result.attributes) || - parser.parseAttribute(valueAttr, "value", result.attributes)) - return failure(); - - // If the attribute is a symbol reference, then we expect a trailing type. - Type type; - if (!valueAttr.isa()) - type = valueAttr.getType(); - else if (parser.parseColonType(type)) - return failure(); - - // Add the attribute type to the list. - return parser.addTypeToList(type, result.types); -} - -/// The constant op requires an attribute, and furthermore requires that it -/// matches the return type. LogicalResult ConstantOp::verify() { - auto value = getValue(); - if (!value) - return emitOpError("requires a 'value' attribute"); - + StringRef fnName = getValue(); Type type = getType(); - if (!value.getType().isa() && type != value.getType()) - return emitOpError() << "requires attribute's type (" << value.getType() - << ") to match op's return type (" << type << ")"; - if (type.isa()) { - auto fnAttr = value.dyn_cast(); - if (!fnAttr) - return emitOpError("requires 'value' to be a function reference"); + // Try to find the referenced function. + auto fn = (*this)->getParentOfType().lookupSymbol(fnName); + if (!fn) + return emitOpError() << "reference to undefined function '" << fnName + << "'"; - // Try to find the referenced function. - auto fn = (*this)->getParentOfType().lookupSymbol( - fnAttr.getValue()); - if (!fn) - return emitOpError() << "reference to undefined function '" - << fnAttr.getValue() << "'"; + // Check that the referenced function has the correct type. + if (fn.getType() != type) + return emitOpError("reference to function with mismatched type"); - // Check that the referenced function has the correct type. - if (fn.getType() != type) - return emitOpError("reference to function with mismatched type"); - - return success(); - } - - if (type.isa() && value.isa()) - return success(); - - return emitOpError("unsupported 'value' attribute: ") << value; + return success(); } OpFoldResult ConstantOp::fold(ArrayRef operands) { assert(operands.empty() && "constant has no operands"); - return getValue(); + return getValueAttr(); } void ConstantOp::getAsmResultNames( function_ref setNameFn) { - Type type = getType(); - if (type.isa()) { - setNameFn(getResult(), "f"); - } else { - setNameFn(getResult(), "cst"); - } + setNameFn(getResult(), "f"); } -/// Returns true if a constant operation can be built with the given value and -/// result type. bool ConstantOp::isBuildableWith(Attribute value, Type type) { - // SymbolRefAttr can only be used with a function type. - if (value.isa()) - return type.isa(); - // Otherwise, this must be a UnitAttr. - return value.isa() && type.isa(); + return value.isa() && type.isa(); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp index 495de25662db..52b52763b0dc 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp @@ -307,7 +307,7 @@ struct TwoDimMultiReductionToReduction return failure(); auto loc = multiReductionOp.getLoc(); - Value result = rewriter.create( + Value result = rewriter.create( loc, multiReductionOp.getDestType(), rewriter.getZeroAttr(multiReductionOp.getDestType())); int outerDim = multiReductionOp.getSourceVectorType().getShape()[0]; diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index a332029a4cf8..5d7ef65fcad2 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -232,7 +232,7 @@ static LogicalResult printOperation(CppEmitter &emitter, static LogicalResult printOperation(CppEmitter &emitter, mlir::ConstantOp constantOp) { Operation *operation = constantOp.getOperation(); - Attribute value = constantOp.getValue(); + Attribute value = constantOp.getValueAttr(); return printConstantOp(emitter, operation, value); } diff --git a/mlir/test/Dialect/Standard/invalid.mlir b/mlir/test/Dialect/Standard/invalid.mlir index 836158dd2160..e9359936a196 100644 --- a/mlir/test/Dialect/Standard/invalid.mlir +++ b/mlir/test/Dialect/Standard/invalid.mlir @@ -1,7 +1,7 @@ // RUN: mlir-opt -split-input-file %s -verify-diagnostics func @unsupported_attribute() { - // expected-error @+1 {{unsupported 'value' attribute: "" : index}} + // expected-error @+1 {{invalid kind of attribute specified}} %0 = constant "" : index return } diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir index fefe7387f284..55280b2ac8b8 100644 --- a/mlir/test/IR/core-ops.mlir +++ b/mlir/test/IR/core-ops.mlir @@ -99,9 +99,6 @@ func @standard_instrs(tensor<4x4x?xf32>, f32, i32, index, i64, f16) { // CHECK: %{{.*}} = arith.cmpf oeq, %{{.*}}, %{{.*}}: vector<4xf32> %70 = arith.cmpf oeq, %vcf32, %vcf32 : vector<4 x f32> - // CHECK: = constant unit - %73 = constant unit - // CHECK: arith.constant true %74 = arith.constant true diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index 34dd14176b45..53661511ee32 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -578,7 +578,7 @@ struct TestCreateUnregisteredOp : public OpRewritePattern { LogicalResult matchAndRewrite(ILLegalOpG op, PatternRewriter &rewriter) const final { IntegerAttr attr = rewriter.getI32IntegerAttr(0); - Value val = rewriter.create(op->getLoc(), attr); + Value val = rewriter.create(op->getLoc(), attr); rewriter.replaceOpWithNewOp(op, val); return success(); };