[mlir][NFC] Remove deprecated/old build/fold/parser utilities from OpDefinition

These have generally been replaced by better ODS functionality, and do not
need to be explicitly provided anymore.

Differential Revision: https://reviews.llvm.org/D119065
This commit is contained in:
River Riddle 2022-02-06 12:33:08 -08:00
parent 3c69bc4d6e
commit 60cac0c081
16 changed files with 105 additions and 200 deletions

View File

@ -2534,19 +2534,15 @@ def fir_StringLitOp : fir_Op<"string_lit", [NoSideEffect]> {
class fir_ArithmeticOp<string mnemonic, list<Trait> traits = []> :
fir_Op<mnemonic,
!listconcat(traits, [NoSideEffect, SameOperandsAndResultType])>,
Results<(outs AnyType)> {
let parser = "return impl::parseOneResultSameOperandTypeOp(parser, result);";
let printer = "return printBinaryOp(this->getOperation(), p);";
Results<(outs AnyType:$result)> {
let assemblyFormat = "operands attr-dict `:` type($result)";
}
class fir_UnaryArithmeticOp<string mnemonic, list<Trait> traits = []> :
fir_Op<mnemonic,
!listconcat(traits, [NoSideEffect, SameOperandsAndResultType])>,
Results<(outs AnyType)> {
let parser = "return impl::parseOneResultSameOperandTypeOp(parser, result);";
let printer = "return printUnaryOp(this->getOperation(), p);";
Results<(outs AnyType:$result)> {
let assemblyFormat = "operands attr-dict `:` type($result)";
}
def fir_ConstcOp : fir_Op<"constc", [NoSideEffect]> {

View File

@ -3211,26 +3211,6 @@ mlir::ParseResult fir::parseSelector(mlir::OpAsmParser &parser,
return mlir::success();
}
/// Generic pretty-printer of a binary operation
static void printBinaryOp(Operation *op, OpAsmPrinter &p) {
assert(op->getNumOperands() == 2 && "binary op must have two operands");
assert(op->getNumResults() == 1 && "binary op must have one result");
p << ' ' << op->getOperand(0) << ", " << op->getOperand(1);
p.printOptionalAttrDict(op->getAttrs());
p << " : " << op->getResult(0).getType();
}
/// Generic pretty-printer of an unary operation
static void printUnaryOp(Operation *op, OpAsmPrinter &p) {
assert(op->getNumOperands() == 1 && "unary op must have one operand");
assert(op->getNumResults() == 1 && "unary op must have one result");
p << ' ' << op->getOperand(0);
p.printOptionalAttrDict(op->getAttrs());
p << " : " << op->getResult(0).getType();
}
bool fir::isReferenceLike(mlir::Type type) {
return type.isa<fir::ReferenceType>() || type.isa<fir::HeapType>() ||
type.isa<fir::PointerType>();

View File

@ -419,8 +419,7 @@ class LLVM_CastOp<string mnemonic, string builderFunc, Type type,
let arguments = (ins type:$arg);
let results = (outs resultType:$res);
let builders = [LLVM_OneResultOpBuilder];
let parser = [{ return mlir::impl::parseCastOp(parser, result); }];
let printer = [{ mlir::impl::printCastOp(this->getOperation(), p); }];
let assemblyFormat = "$arg attr-dict `:` type($arg) `to` type($res)";
}
def LLVM_BitcastOp : LLVM_CastOp<"bitcast", "CreateBitCast",
LLVM_AnyNonAggregate, LLVM_AnyNonAggregate> {

View File

@ -383,7 +383,6 @@ def MemRef_CastOp : MemRef_Op<"cast", [
}];
let hasFolder = 1;
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//

View File

@ -34,6 +34,7 @@ class SPV_ArithmeticBinaryOp<string mnemonic, Type type,
let results = (outs
SPV_ScalarOrVectorOrCoopMatrixOf<type>:$result
);
let assemblyFormat = "operands attr-dict `:` type($result)";
}
class SPV_ArithmeticUnaryOp<string mnemonic, Type type,

View File

@ -4338,8 +4338,6 @@ class SPV_BinaryOp<string mnemonic, Type resultType, Type operandsType,
SPV_ScalarOrVectorOf<resultType>:$result
);
let parser = [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }];
let printer = [{ return impl::printOneResultOp(getOperation(), p); }];
// No additional verification needed in addition to the ODS-generated ones.
let hasVerifier = 0;
}

View File

@ -21,7 +21,9 @@ class SPV_BitBinaryOp<string mnemonic, list<Trait> traits = []> :
// All the operands type used in bit instructions are SPV_Integer.
SPV_BinaryOp<mnemonic, SPV_Integer, SPV_Integer,
!listconcat(traits,
[NoSideEffect, SameOperandsAndResultType])>;
[NoSideEffect, SameOperandsAndResultType])> {
let assemblyFormat = "operands attr-dict `:` type($result)";
}
class SPV_BitFieldExtractOp<string mnemonic, list<Trait> traits = []> :
SPV_Op<mnemonic, !listconcat(traits,

View File

@ -29,9 +29,9 @@ class SPV_CastOp<string mnemonic, Type resultType, Type operandType,
let results = (outs
SPV_ScalarOrVectorOrCoopMatrixOf<resultType>:$result
);
let parser = [{ return mlir::impl::parseCastOp(parser, result); }];
let printer = [{ mlir::impl::printCastOp(this->getOperation(), p); }];
let assemblyFormat = [{
$operand attr-dict `:` type($operand) `to` type($result)
}];
}
// -----
@ -85,9 +85,9 @@ def SPV_BitcastOp : SPV_Op<"Bitcast", [NoSideEffect]> {
SPV_ScalarOrVectorOrPtr:$result
);
let parser = [{ return mlir::impl::parseCastOp(parser, result); }];
let printer = [{ mlir::impl::printCastOp(this->getOperation(), p); }];
let assemblyFormat = [{
$operand attr-dict `:` type($operand) `to` type($result)
}];
let hasCanonicalizer = 1;
}

View File

@ -72,10 +72,6 @@ class SPV_GLSLBinaryOp<string mnemonic, int opcode, Type resultType,
SPV_ScalarOrVectorOf<resultType>:$result
);
let parser = [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }];
let printer = [{ return impl::printOneResultOp(getOperation(), p); }];
let hasVerifier = 0;
}
@ -83,7 +79,10 @@ class SPV_GLSLBinaryOp<string mnemonic, int opcode, Type resultType,
// return type matches.
class SPV_GLSLBinaryArithmeticOp<string mnemonic, int opcode, Type type,
list<Trait> traits = []> :
SPV_GLSLBinaryOp<mnemonic, opcode, type, type, traits>;
SPV_GLSLBinaryOp<mnemonic, opcode, type, type,
traits # [SameOperandsAndResultType]> {
let assemblyFormat = "operands attr-dict `:` type($result)";
}
// Base class for GLSL ternary ops.
class SPV_GLSLTernaryArithmeticOp<string mnemonic, int opcode, Type type,
@ -100,9 +99,8 @@ class SPV_GLSLTernaryArithmeticOp<string mnemonic, int opcode, Type type,
SPV_ScalarOrVectorOf<type>:$result
);
let parser = [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }];
let printer = [{ return impl::printOneResultOp(getOperation(), p); }];
let parser = [{ return parseOneResultSameOperandTypeOp(parser, result); }];
let printer = [{ return printOneResultOp(getOperation(), p); }];
let hasVerifier = 0;
}

View File

@ -71,10 +71,6 @@ class SPV_OCLBinaryOp<string mnemonic, int opcode, Type resultType,
SPV_ScalarOrVectorOf<resultType>:$result
);
let parser = [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }];
let printer = [{ return impl::printOneResultOp(getOperation(), p); }];
let hasVerifier = 0;
}
@ -82,7 +78,10 @@ class SPV_OCLBinaryOp<string mnemonic, int opcode, Type resultType,
// return type matches.
class SPV_OCLBinaryArithmeticOp<string mnemonic, int opcode, Type type,
list<Trait> traits = []> :
SPV_OCLBinaryOp<mnemonic, opcode, type, type, traits>;
SPV_OCLBinaryOp<mnemonic, opcode, type, type,
traits # [SameOperandsAndResultType]> {
let assemblyFormat = "operands attr-dict `:` type($result)";
}
// -----

View File

@ -14,6 +14,7 @@
#define SHAPE_OPS
include "mlir/Dialect/Shape/IR/ShapeBase.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@ -331,7 +332,9 @@ def Shape_RankOp : Shape_Op<"rank",
}];
}
def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> {
def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [
DeclareOpInterfaceMethods<CastOpInterface>, NoSideEffect
]> {
let summary = "Creates a dimension tensor from a shape";
let description = [{
Converts a shape to a 1D integral tensor of extents. The number of elements
@ -624,7 +627,9 @@ def Shape_ShapeOfOp : Shape_Op<"shape_of",
}];
}
def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [NoSideEffect]> {
def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [
DeclareOpInterfaceMethods<CastOpInterface>, NoSideEffect
]> {
let summary = "Casts between index types of the shape and standard dialect";
let description = [{
Converts a `shape.size` to a standard index. This operation and its

View File

@ -1897,26 +1897,9 @@ protected:
};
//===----------------------------------------------------------------------===//
// Common Operation Folders/Parsers/Printers
// CastOpInterface utilities
//===----------------------------------------------------------------------===//
// These functions are out-of-line implementations of the methods in UnaryOp and
// BinaryOp, which avoids them being template instantiated/duplicated.
namespace impl {
ParseResult parseOneResultOneOperandTypeOp(OpAsmParser &parser,
OperationState &result);
void buildBinaryOp(OpBuilder &builder, OperationState &result, Value lhs,
Value rhs);
ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
OperationState &result);
// Prints the given binary `op` in custom assembly form if both the two operands
// and the result have the same time. Otherwise, prints the generic assembly
// form.
void printOneResultOp(Operation *op, OpAsmPrinter &p);
} // namespace impl
// These functions are out-of-line implementations of the methods in
// CastOpInterface, which avoids them being template instantiated/duplicated.
namespace impl {
@ -1927,20 +1910,6 @@ LogicalResult foldCastInterfaceOp(Operation *op,
/// Attempt to verify the given cast operation.
LogicalResult verifyCastInterfaceOp(
Operation *op, function_ref<bool(TypeRange, TypeRange)> areCastCompatible);
// TODO: Remove the parse/print/build here (new ODS functionality obsoletes the
// need for them, but some older ODS code in `std` still depends on them).
void buildCastOp(OpBuilder &builder, OperationState &result, Value source,
Type destType);
ParseResult parseCastOp(OpAsmParser &parser, OperationState &result);
void printCastOp(Operation *op, OpAsmPrinter &p);
// TODO: These methods are deprecated in favor of CastOpInterface. Remove them
// when all uses have been updated. Also, consider adding functionality to
// CastOpInterface to be able to perform the ChainedTensorCast canonicalization
// generically.
Value foldCastOp(Operation *op);
LogicalResult verifyCastOp(Operation *op,
function_ref<bool(Type, Type)> areCastCompatible);
} // namespace impl
} // namespace mlir

View File

@ -65,10 +65,6 @@ Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
return NoneType::get(type.getContext());
}
LogicalResult memref::CastOp::verify() {
return impl::verifyCastOp(*this, areCastCompatible);
}
//===----------------------------------------------------------------------===//
// AllocOp / AllocaOp
//===----------------------------------------------------------------------===//

View File

@ -64,6 +64,54 @@ static constexpr const char kCompositeSpecConstituentsName[] = "constituents";
// Common utility functions
//===----------------------------------------------------------------------===//
static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
OperationState &result) {
SmallVector<OpAsmParser::OperandType, 2> ops;
Type type;
// If the operand list is in-between parentheses, then we have a generic form.
// (see the fallback in `printOneResultOp`).
SMLoc loc = parser.getCurrentLocation();
if (!parser.parseOptionalLParen()) {
if (parser.parseOperandList(ops) || parser.parseRParen() ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColon() || parser.parseType(type))
return failure();
auto fnType = type.dyn_cast<FunctionType>();
if (!fnType) {
parser.emitError(loc, "expected function type");
return failure();
}
if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands))
return failure();
result.addTypes(fnType.getResults());
return success();
}
return failure(parser.parseOperandList(ops) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(type) ||
parser.resolveOperands(ops, type, result.operands) ||
parser.addTypeToList(type, result.types));
}
static void printOneResultOp(Operation *op, OpAsmPrinter &p) {
assert(op->getNumResults() == 1 && "op should have one result");
// If not all the operand and result types are the same, just use the
// generic assembly form to avoid omitting information in printing.
auto resultType = op->getResult(0).getType();
if (llvm::any_of(op->getOperandTypes(),
[&](Type type) { return type != resultType; })) {
p.printGenericOp(op, /*printOpName=*/false);
return;
}
p << ' ';
p.printOperands(op->getOperands());
p.printOptionalAttrDict(op->getAttrs());
// Now we can output only one type for all operands and the result.
p << " : " << resultType;
}
/// Returns true if the given op is a function-like op or nested in a
/// function-like op without a module-like op in the middle.
static bool isNestedInFunctionOpInterface(Operation *op) {

View File

@ -1692,7 +1692,7 @@ OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
// `IntegerAttr`s which makes constant folding simple.
if (Attribute arg = operands[0])
return arg;
return impl::foldCastOp(*this);
return OpFoldResult();
}
void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@ -1700,6 +1700,12 @@ void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<IndexToSizeToIndexCanonicalization>(context);
}
bool SizeToIndexOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
if (inputs.size() != 1 || outputs.size() != 1)
return false;
return inputs[0].isa<IndexType, SizeType>() && outputs[0].isa<IndexType>();
}
//===----------------------------------------------------------------------===//
// YieldOp
//===----------------------------------------------------------------------===//
@ -1750,7 +1756,7 @@ LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
if (!operands[0])
return impl::foldCastOp(*this);
return OpFoldResult();
Builder builder(getContext());
auto shape = llvm::to_vector<6>(
operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
@ -1759,6 +1765,21 @@ OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
return DenseIntElementsAttr::get(type, shape);
}
bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
if (inputs.size() != 1 || outputs.size() != 1)
return false;
if (auto inputTensor = inputs[0].dyn_cast<RankedTensorType>()) {
if (!inputTensor.getElementType().isa<IndexType>() ||
inputTensor.getRank() != 1 || !inputTensor.isDynamicDim(0))
return false;
} else if (!inputs[0].isa<ShapeType>()) {
return false;
}
TensorType outputTensor = outputs[0].dyn_cast<TensorType>();
return outputTensor && outputTensor.getElementType().isa<IndexType>();
}
//===----------------------------------------------------------------------===//
// ReduceOp
//===----------------------------------------------------------------------===//

View File

@ -1125,69 +1125,7 @@ bool OpTrait::hasElementwiseMappableTraits(Operation *op) {
}
//===----------------------------------------------------------------------===//
// BinaryOp implementation
//===----------------------------------------------------------------------===//
// These functions are out-of-line implementations of the methods in BinaryOp,
// which avoids them being template instantiated/duplicated.
void impl::buildBinaryOp(OpBuilder &builder, OperationState &result, Value lhs,
Value rhs) {
assert(lhs.getType() == rhs.getType());
result.addOperands({lhs, rhs});
result.types.push_back(lhs.getType());
}
ParseResult impl::parseOneResultSameOperandTypeOp(OpAsmParser &parser,
OperationState &result) {
SmallVector<OpAsmParser::OperandType, 2> ops;
Type type;
// If the operand list is in-between parentheses, then we have a generic form.
// (see the fallback in `printOneResultOp`).
SMLoc loc = parser.getCurrentLocation();
if (!parser.parseOptionalLParen()) {
if (parser.parseOperandList(ops) || parser.parseRParen() ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColon() || parser.parseType(type))
return failure();
auto fnType = type.dyn_cast<FunctionType>();
if (!fnType) {
parser.emitError(loc, "expected function type");
return failure();
}
if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands))
return failure();
result.addTypes(fnType.getResults());
return success();
}
return failure(parser.parseOperandList(ops) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(type) ||
parser.resolveOperands(ops, type, result.operands) ||
parser.addTypeToList(type, result.types));
}
void impl::printOneResultOp(Operation *op, OpAsmPrinter &p) {
assert(op->getNumResults() == 1 && "op should have one result");
// If not all the operand and result types are the same, just use the
// generic assembly form to avoid omitting information in printing.
auto resultType = op->getResult(0).getType();
if (llvm::any_of(op->getOperandTypes(),
[&](Type type) { return type != resultType; })) {
p.printGenericOp(op, /*printOpName=*/false);
return;
}
p << ' ';
p.printOperands(op->getOperands());
p.printOptionalAttrDict(op->getAttrs());
// Now we can output only one type for all operands and the result.
p << " : " << resultType;
}
//===----------------------------------------------------------------------===//
// CastOp implementation
// CastOpInterface
//===----------------------------------------------------------------------===//
/// Attempt to fold the given cast operation.
@ -1232,50 +1170,6 @@ LogicalResult impl::verifyCastInterfaceOp(
return success();
}
void impl::buildCastOp(OpBuilder &builder, OperationState &result, Value source,
Type destType) {
result.addOperands(source);
result.addTypes(destType);
}
ParseResult impl::parseCastOp(OpAsmParser &parser, OperationState &result) {
OpAsmParser::OperandType srcInfo;
Type srcType, dstType;
return failure(parser.parseOperand(srcInfo) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(srcType) ||
parser.resolveOperand(srcInfo, srcType, result.operands) ||
parser.parseKeywordType("to", dstType) ||
parser.addTypeToList(dstType, result.types));
}
void impl::printCastOp(Operation *op, OpAsmPrinter &p) {
p << ' ' << op->getOperand(0);
p.printOptionalAttrDict(op->getAttrs());
p << " : " << op->getOperand(0).getType() << " to "
<< op->getResult(0).getType();
}
Value impl::foldCastOp(Operation *op) {
// Identity cast
if (op->getOperand(0).getType() == op->getResult(0).getType())
return op->getOperand(0);
return nullptr;
}
LogicalResult
impl::verifyCastOp(Operation *op,
function_ref<bool(Type, Type)> areCastCompatible) {
auto opType = op->getOperand(0).getType();
auto resType = op->getResult(0).getType();
if (!areCastCompatible(opType, resType))
return op->emitError("operand type ")
<< opType << " and result type " << resType
<< " are cast incompatible";
return success();
}
//===----------------------------------------------------------------------===//
// Misc. utils
//===----------------------------------------------------------------------===//