[mlir][NFC] Remove deprecated/old build/fold/parser utilities from OpDefinition
These have generally been replaced by better ODS functionality, and do not need to be explicitly provided anymore. Differential Revision: https://reviews.llvm.org/D119065
This commit is contained in:
parent
3c69bc4d6e
commit
60cac0c081
|
@ -2534,19 +2534,15 @@ def fir_StringLitOp : fir_Op<"string_lit", [NoSideEffect]> {
|
||||||
class fir_ArithmeticOp<string mnemonic, list<Trait> traits = []> :
|
class fir_ArithmeticOp<string mnemonic, list<Trait> traits = []> :
|
||||||
fir_Op<mnemonic,
|
fir_Op<mnemonic,
|
||||||
!listconcat(traits, [NoSideEffect, SameOperandsAndResultType])>,
|
!listconcat(traits, [NoSideEffect, SameOperandsAndResultType])>,
|
||||||
Results<(outs AnyType)> {
|
Results<(outs AnyType:$result)> {
|
||||||
let parser = "return impl::parseOneResultSameOperandTypeOp(parser, result);";
|
let assemblyFormat = "operands attr-dict `:` type($result)";
|
||||||
|
|
||||||
let printer = "return printBinaryOp(this->getOperation(), p);";
|
|
||||||
}
|
}
|
||||||
|
|
||||||
class fir_UnaryArithmeticOp<string mnemonic, list<Trait> traits = []> :
|
class fir_UnaryArithmeticOp<string mnemonic, list<Trait> traits = []> :
|
||||||
fir_Op<mnemonic,
|
fir_Op<mnemonic,
|
||||||
!listconcat(traits, [NoSideEffect, SameOperandsAndResultType])>,
|
!listconcat(traits, [NoSideEffect, SameOperandsAndResultType])>,
|
||||||
Results<(outs AnyType)> {
|
Results<(outs AnyType:$result)> {
|
||||||
let parser = "return impl::parseOneResultSameOperandTypeOp(parser, result);";
|
let assemblyFormat = "operands attr-dict `:` type($result)";
|
||||||
|
|
||||||
let printer = "return printUnaryOp(this->getOperation(), p);";
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def fir_ConstcOp : fir_Op<"constc", [NoSideEffect]> {
|
def fir_ConstcOp : fir_Op<"constc", [NoSideEffect]> {
|
||||||
|
|
|
@ -3211,26 +3211,6 @@ mlir::ParseResult fir::parseSelector(mlir::OpAsmParser &parser,
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generic pretty-printer of a binary operation
|
|
||||||
static void printBinaryOp(Operation *op, OpAsmPrinter &p) {
|
|
||||||
assert(op->getNumOperands() == 2 && "binary op must have two operands");
|
|
||||||
assert(op->getNumResults() == 1 && "binary op must have one result");
|
|
||||||
|
|
||||||
p << ' ' << op->getOperand(0) << ", " << op->getOperand(1);
|
|
||||||
p.printOptionalAttrDict(op->getAttrs());
|
|
||||||
p << " : " << op->getResult(0).getType();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generic pretty-printer of an unary operation
|
|
||||||
static void printUnaryOp(Operation *op, OpAsmPrinter &p) {
|
|
||||||
assert(op->getNumOperands() == 1 && "unary op must have one operand");
|
|
||||||
assert(op->getNumResults() == 1 && "unary op must have one result");
|
|
||||||
|
|
||||||
p << ' ' << op->getOperand(0);
|
|
||||||
p.printOptionalAttrDict(op->getAttrs());
|
|
||||||
p << " : " << op->getResult(0).getType();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool fir::isReferenceLike(mlir::Type type) {
|
bool fir::isReferenceLike(mlir::Type type) {
|
||||||
return type.isa<fir::ReferenceType>() || type.isa<fir::HeapType>() ||
|
return type.isa<fir::ReferenceType>() || type.isa<fir::HeapType>() ||
|
||||||
type.isa<fir::PointerType>();
|
type.isa<fir::PointerType>();
|
||||||
|
|
|
@ -419,8 +419,7 @@ class LLVM_CastOp<string mnemonic, string builderFunc, Type type,
|
||||||
let arguments = (ins type:$arg);
|
let arguments = (ins type:$arg);
|
||||||
let results = (outs resultType:$res);
|
let results = (outs resultType:$res);
|
||||||
let builders = [LLVM_OneResultOpBuilder];
|
let builders = [LLVM_OneResultOpBuilder];
|
||||||
let parser = [{ return mlir::impl::parseCastOp(parser, result); }];
|
let assemblyFormat = "$arg attr-dict `:` type($arg) `to` type($res)";
|
||||||
let printer = [{ mlir::impl::printCastOp(this->getOperation(), p); }];
|
|
||||||
}
|
}
|
||||||
def LLVM_BitcastOp : LLVM_CastOp<"bitcast", "CreateBitCast",
|
def LLVM_BitcastOp : LLVM_CastOp<"bitcast", "CreateBitCast",
|
||||||
LLVM_AnyNonAggregate, LLVM_AnyNonAggregate> {
|
LLVM_AnyNonAggregate, LLVM_AnyNonAggregate> {
|
||||||
|
|
|
@ -383,7 +383,6 @@ def MemRef_CastOp : MemRef_Op<"cast", [
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
let hasVerifier = 1;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -34,6 +34,7 @@ class SPV_ArithmeticBinaryOp<string mnemonic, Type type,
|
||||||
let results = (outs
|
let results = (outs
|
||||||
SPV_ScalarOrVectorOrCoopMatrixOf<type>:$result
|
SPV_ScalarOrVectorOrCoopMatrixOf<type>:$result
|
||||||
);
|
);
|
||||||
|
let assemblyFormat = "operands attr-dict `:` type($result)";
|
||||||
}
|
}
|
||||||
|
|
||||||
class SPV_ArithmeticUnaryOp<string mnemonic, Type type,
|
class SPV_ArithmeticUnaryOp<string mnemonic, Type type,
|
||||||
|
|
|
@ -4338,8 +4338,6 @@ class SPV_BinaryOp<string mnemonic, Type resultType, Type operandsType,
|
||||||
SPV_ScalarOrVectorOf<resultType>:$result
|
SPV_ScalarOrVectorOf<resultType>:$result
|
||||||
);
|
);
|
||||||
|
|
||||||
let parser = [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }];
|
|
||||||
let printer = [{ return impl::printOneResultOp(getOperation(), p); }];
|
|
||||||
// No additional verification needed in addition to the ODS-generated ones.
|
// No additional verification needed in addition to the ODS-generated ones.
|
||||||
let hasVerifier = 0;
|
let hasVerifier = 0;
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,7 +21,9 @@ class SPV_BitBinaryOp<string mnemonic, list<Trait> traits = []> :
|
||||||
// All the operands type used in bit instructions are SPV_Integer.
|
// All the operands type used in bit instructions are SPV_Integer.
|
||||||
SPV_BinaryOp<mnemonic, SPV_Integer, SPV_Integer,
|
SPV_BinaryOp<mnemonic, SPV_Integer, SPV_Integer,
|
||||||
!listconcat(traits,
|
!listconcat(traits,
|
||||||
[NoSideEffect, SameOperandsAndResultType])>;
|
[NoSideEffect, SameOperandsAndResultType])> {
|
||||||
|
let assemblyFormat = "operands attr-dict `:` type($result)";
|
||||||
|
}
|
||||||
|
|
||||||
class SPV_BitFieldExtractOp<string mnemonic, list<Trait> traits = []> :
|
class SPV_BitFieldExtractOp<string mnemonic, list<Trait> traits = []> :
|
||||||
SPV_Op<mnemonic, !listconcat(traits,
|
SPV_Op<mnemonic, !listconcat(traits,
|
||||||
|
|
|
@ -29,9 +29,9 @@ class SPV_CastOp<string mnemonic, Type resultType, Type operandType,
|
||||||
let results = (outs
|
let results = (outs
|
||||||
SPV_ScalarOrVectorOrCoopMatrixOf<resultType>:$result
|
SPV_ScalarOrVectorOrCoopMatrixOf<resultType>:$result
|
||||||
);
|
);
|
||||||
|
let assemblyFormat = [{
|
||||||
let parser = [{ return mlir::impl::parseCastOp(parser, result); }];
|
$operand attr-dict `:` type($operand) `to` type($result)
|
||||||
let printer = [{ mlir::impl::printCastOp(this->getOperation(), p); }];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
@ -85,9 +85,9 @@ def SPV_BitcastOp : SPV_Op<"Bitcast", [NoSideEffect]> {
|
||||||
SPV_ScalarOrVectorOrPtr:$result
|
SPV_ScalarOrVectorOrPtr:$result
|
||||||
);
|
);
|
||||||
|
|
||||||
let parser = [{ return mlir::impl::parseCastOp(parser, result); }];
|
let assemblyFormat = [{
|
||||||
let printer = [{ mlir::impl::printCastOp(this->getOperation(), p); }];
|
$operand attr-dict `:` type($operand) `to` type($result)
|
||||||
|
}];
|
||||||
let hasCanonicalizer = 1;
|
let hasCanonicalizer = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -72,10 +72,6 @@ class SPV_GLSLBinaryOp<string mnemonic, int opcode, Type resultType,
|
||||||
SPV_ScalarOrVectorOf<resultType>:$result
|
SPV_ScalarOrVectorOf<resultType>:$result
|
||||||
);
|
);
|
||||||
|
|
||||||
let parser = [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }];
|
|
||||||
|
|
||||||
let printer = [{ return impl::printOneResultOp(getOperation(), p); }];
|
|
||||||
|
|
||||||
let hasVerifier = 0;
|
let hasVerifier = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -83,7 +79,10 @@ class SPV_GLSLBinaryOp<string mnemonic, int opcode, Type resultType,
|
||||||
// return type matches.
|
// return type matches.
|
||||||
class SPV_GLSLBinaryArithmeticOp<string mnemonic, int opcode, Type type,
|
class SPV_GLSLBinaryArithmeticOp<string mnemonic, int opcode, Type type,
|
||||||
list<Trait> traits = []> :
|
list<Trait> traits = []> :
|
||||||
SPV_GLSLBinaryOp<mnemonic, opcode, type, type, traits>;
|
SPV_GLSLBinaryOp<mnemonic, opcode, type, type,
|
||||||
|
traits # [SameOperandsAndResultType]> {
|
||||||
|
let assemblyFormat = "operands attr-dict `:` type($result)";
|
||||||
|
}
|
||||||
|
|
||||||
// Base class for GLSL ternary ops.
|
// Base class for GLSL ternary ops.
|
||||||
class SPV_GLSLTernaryArithmeticOp<string mnemonic, int opcode, Type type,
|
class SPV_GLSLTernaryArithmeticOp<string mnemonic, int opcode, Type type,
|
||||||
|
@ -100,9 +99,8 @@ class SPV_GLSLTernaryArithmeticOp<string mnemonic, int opcode, Type type,
|
||||||
SPV_ScalarOrVectorOf<type>:$result
|
SPV_ScalarOrVectorOf<type>:$result
|
||||||
);
|
);
|
||||||
|
|
||||||
let parser = [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }];
|
let parser = [{ return parseOneResultSameOperandTypeOp(parser, result); }];
|
||||||
|
let printer = [{ return printOneResultOp(getOperation(), p); }];
|
||||||
let printer = [{ return impl::printOneResultOp(getOperation(), p); }];
|
|
||||||
|
|
||||||
let hasVerifier = 0;
|
let hasVerifier = 0;
|
||||||
}
|
}
|
||||||
|
|
|
@ -71,10 +71,6 @@ class SPV_OCLBinaryOp<string mnemonic, int opcode, Type resultType,
|
||||||
SPV_ScalarOrVectorOf<resultType>:$result
|
SPV_ScalarOrVectorOf<resultType>:$result
|
||||||
);
|
);
|
||||||
|
|
||||||
let parser = [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }];
|
|
||||||
|
|
||||||
let printer = [{ return impl::printOneResultOp(getOperation(), p); }];
|
|
||||||
|
|
||||||
let hasVerifier = 0;
|
let hasVerifier = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -82,7 +78,10 @@ class SPV_OCLBinaryOp<string mnemonic, int opcode, Type resultType,
|
||||||
// return type matches.
|
// return type matches.
|
||||||
class SPV_OCLBinaryArithmeticOp<string mnemonic, int opcode, Type type,
|
class SPV_OCLBinaryArithmeticOp<string mnemonic, int opcode, Type type,
|
||||||
list<Trait> traits = []> :
|
list<Trait> traits = []> :
|
||||||
SPV_OCLBinaryOp<mnemonic, opcode, type, type, traits>;
|
SPV_OCLBinaryOp<mnemonic, opcode, type, type,
|
||||||
|
traits # [SameOperandsAndResultType]> {
|
||||||
|
let assemblyFormat = "operands attr-dict `:` type($result)";
|
||||||
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
#define SHAPE_OPS
|
#define SHAPE_OPS
|
||||||
|
|
||||||
include "mlir/Dialect/Shape/IR/ShapeBase.td"
|
include "mlir/Dialect/Shape/IR/ShapeBase.td"
|
||||||
|
include "mlir/Interfaces/CastInterfaces.td"
|
||||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||||
include "mlir/Interfaces/InferTypeOpInterface.td"
|
include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||||
|
@ -331,7 +332,9 @@ def Shape_RankOp : Shape_Op<"rank",
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> {
|
def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [
|
||||||
|
DeclareOpInterfaceMethods<CastOpInterface>, NoSideEffect
|
||||||
|
]> {
|
||||||
let summary = "Creates a dimension tensor from a shape";
|
let summary = "Creates a dimension tensor from a shape";
|
||||||
let description = [{
|
let description = [{
|
||||||
Converts a shape to a 1D integral tensor of extents. The number of elements
|
Converts a shape to a 1D integral tensor of extents. The number of elements
|
||||||
|
@ -624,7 +627,9 @@ def Shape_ShapeOfOp : Shape_Op<"shape_of",
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [NoSideEffect]> {
|
def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [
|
||||||
|
DeclareOpInterfaceMethods<CastOpInterface>, NoSideEffect
|
||||||
|
]> {
|
||||||
let summary = "Casts between index types of the shape and standard dialect";
|
let summary = "Casts between index types of the shape and standard dialect";
|
||||||
let description = [{
|
let description = [{
|
||||||
Converts a `shape.size` to a standard index. This operation and its
|
Converts a `shape.size` to a standard index. This operation and its
|
||||||
|
|
|
@ -1897,26 +1897,9 @@ protected:
|
||||||
};
|
};
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Common Operation Folders/Parsers/Printers
|
// CastOpInterface utilities
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// These functions are out-of-line implementations of the methods in UnaryOp and
|
|
||||||
// BinaryOp, which avoids them being template instantiated/duplicated.
|
|
||||||
namespace impl {
|
|
||||||
ParseResult parseOneResultOneOperandTypeOp(OpAsmParser &parser,
|
|
||||||
OperationState &result);
|
|
||||||
|
|
||||||
void buildBinaryOp(OpBuilder &builder, OperationState &result, Value lhs,
|
|
||||||
Value rhs);
|
|
||||||
ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
|
|
||||||
OperationState &result);
|
|
||||||
|
|
||||||
// Prints the given binary `op` in custom assembly form if both the two operands
|
|
||||||
// and the result have the same time. Otherwise, prints the generic assembly
|
|
||||||
// form.
|
|
||||||
void printOneResultOp(Operation *op, OpAsmPrinter &p);
|
|
||||||
} // namespace impl
|
|
||||||
|
|
||||||
// These functions are out-of-line implementations of the methods in
|
// These functions are out-of-line implementations of the methods in
|
||||||
// CastOpInterface, which avoids them being template instantiated/duplicated.
|
// CastOpInterface, which avoids them being template instantiated/duplicated.
|
||||||
namespace impl {
|
namespace impl {
|
||||||
|
@ -1927,20 +1910,6 @@ LogicalResult foldCastInterfaceOp(Operation *op,
|
||||||
/// Attempt to verify the given cast operation.
|
/// Attempt to verify the given cast operation.
|
||||||
LogicalResult verifyCastInterfaceOp(
|
LogicalResult verifyCastInterfaceOp(
|
||||||
Operation *op, function_ref<bool(TypeRange, TypeRange)> areCastCompatible);
|
Operation *op, function_ref<bool(TypeRange, TypeRange)> areCastCompatible);
|
||||||
|
|
||||||
// TODO: Remove the parse/print/build here (new ODS functionality obsoletes the
|
|
||||||
// need for them, but some older ODS code in `std` still depends on them).
|
|
||||||
void buildCastOp(OpBuilder &builder, OperationState &result, Value source,
|
|
||||||
Type destType);
|
|
||||||
ParseResult parseCastOp(OpAsmParser &parser, OperationState &result);
|
|
||||||
void printCastOp(Operation *op, OpAsmPrinter &p);
|
|
||||||
// TODO: These methods are deprecated in favor of CastOpInterface. Remove them
|
|
||||||
// when all uses have been updated. Also, consider adding functionality to
|
|
||||||
// CastOpInterface to be able to perform the ChainedTensorCast canonicalization
|
|
||||||
// generically.
|
|
||||||
Value foldCastOp(Operation *op);
|
|
||||||
LogicalResult verifyCastOp(Operation *op,
|
|
||||||
function_ref<bool(Type, Type)> areCastCompatible);
|
|
||||||
} // namespace impl
|
} // namespace impl
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
|
|
|
@ -65,10 +65,6 @@ Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
|
||||||
return NoneType::get(type.getContext());
|
return NoneType::get(type.getContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult memref::CastOp::verify() {
|
|
||||||
return impl::verifyCastOp(*this, areCastCompatible);
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// AllocOp / AllocaOp
|
// AllocOp / AllocaOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -64,6 +64,54 @@ static constexpr const char kCompositeSpecConstituentsName[] = "constituents";
|
||||||
// Common utility functions
|
// Common utility functions
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
|
||||||
|
OperationState &result) {
|
||||||
|
SmallVector<OpAsmParser::OperandType, 2> ops;
|
||||||
|
Type type;
|
||||||
|
// If the operand list is in-between parentheses, then we have a generic form.
|
||||||
|
// (see the fallback in `printOneResultOp`).
|
||||||
|
SMLoc loc = parser.getCurrentLocation();
|
||||||
|
if (!parser.parseOptionalLParen()) {
|
||||||
|
if (parser.parseOperandList(ops) || parser.parseRParen() ||
|
||||||
|
parser.parseOptionalAttrDict(result.attributes) ||
|
||||||
|
parser.parseColon() || parser.parseType(type))
|
||||||
|
return failure();
|
||||||
|
auto fnType = type.dyn_cast<FunctionType>();
|
||||||
|
if (!fnType) {
|
||||||
|
parser.emitError(loc, "expected function type");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands))
|
||||||
|
return failure();
|
||||||
|
result.addTypes(fnType.getResults());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
return failure(parser.parseOperandList(ops) ||
|
||||||
|
parser.parseOptionalAttrDict(result.attributes) ||
|
||||||
|
parser.parseColonType(type) ||
|
||||||
|
parser.resolveOperands(ops, type, result.operands) ||
|
||||||
|
parser.addTypeToList(type, result.types));
|
||||||
|
}
|
||||||
|
|
||||||
|
static void printOneResultOp(Operation *op, OpAsmPrinter &p) {
|
||||||
|
assert(op->getNumResults() == 1 && "op should have one result");
|
||||||
|
|
||||||
|
// If not all the operand and result types are the same, just use the
|
||||||
|
// generic assembly form to avoid omitting information in printing.
|
||||||
|
auto resultType = op->getResult(0).getType();
|
||||||
|
if (llvm::any_of(op->getOperandTypes(),
|
||||||
|
[&](Type type) { return type != resultType; })) {
|
||||||
|
p.printGenericOp(op, /*printOpName=*/false);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
p << ' ';
|
||||||
|
p.printOperands(op->getOperands());
|
||||||
|
p.printOptionalAttrDict(op->getAttrs());
|
||||||
|
// Now we can output only one type for all operands and the result.
|
||||||
|
p << " : " << resultType;
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns true if the given op is a function-like op or nested in a
|
/// Returns true if the given op is a function-like op or nested in a
|
||||||
/// function-like op without a module-like op in the middle.
|
/// function-like op without a module-like op in the middle.
|
||||||
static bool isNestedInFunctionOpInterface(Operation *op) {
|
static bool isNestedInFunctionOpInterface(Operation *op) {
|
||||||
|
|
|
@ -1692,7 +1692,7 @@ OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
|
||||||
// `IntegerAttr`s which makes constant folding simple.
|
// `IntegerAttr`s which makes constant folding simple.
|
||||||
if (Attribute arg = operands[0])
|
if (Attribute arg = operands[0])
|
||||||
return arg;
|
return arg;
|
||||||
return impl::foldCastOp(*this);
|
return OpFoldResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||||
|
@ -1700,6 +1700,12 @@ void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||||
patterns.add<IndexToSizeToIndexCanonicalization>(context);
|
patterns.add<IndexToSizeToIndexCanonicalization>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool SizeToIndexOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
|
||||||
|
if (inputs.size() != 1 || outputs.size() != 1)
|
||||||
|
return false;
|
||||||
|
return inputs[0].isa<IndexType, SizeType>() && outputs[0].isa<IndexType>();
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// YieldOp
|
// YieldOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -1750,7 +1756,7 @@ LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
|
||||||
|
|
||||||
OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
|
OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
|
||||||
if (!operands[0])
|
if (!operands[0])
|
||||||
return impl::foldCastOp(*this);
|
return OpFoldResult();
|
||||||
Builder builder(getContext());
|
Builder builder(getContext());
|
||||||
auto shape = llvm::to_vector<6>(
|
auto shape = llvm::to_vector<6>(
|
||||||
operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
|
operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
|
||||||
|
@ -1759,6 +1765,21 @@ OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
|
||||||
return DenseIntElementsAttr::get(type, shape);
|
return DenseIntElementsAttr::get(type, shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
|
||||||
|
if (inputs.size() != 1 || outputs.size() != 1)
|
||||||
|
return false;
|
||||||
|
if (auto inputTensor = inputs[0].dyn_cast<RankedTensorType>()) {
|
||||||
|
if (!inputTensor.getElementType().isa<IndexType>() ||
|
||||||
|
inputTensor.getRank() != 1 || !inputTensor.isDynamicDim(0))
|
||||||
|
return false;
|
||||||
|
} else if (!inputs[0].isa<ShapeType>()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
TensorType outputTensor = outputs[0].dyn_cast<TensorType>();
|
||||||
|
return outputTensor && outputTensor.getElementType().isa<IndexType>();
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// ReduceOp
|
// ReduceOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -1125,69 +1125,7 @@ bool OpTrait::hasElementwiseMappableTraits(Operation *op) {
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// BinaryOp implementation
|
// CastOpInterface
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
// These functions are out-of-line implementations of the methods in BinaryOp,
|
|
||||||
// which avoids them being template instantiated/duplicated.
|
|
||||||
|
|
||||||
void impl::buildBinaryOp(OpBuilder &builder, OperationState &result, Value lhs,
|
|
||||||
Value rhs) {
|
|
||||||
assert(lhs.getType() == rhs.getType());
|
|
||||||
result.addOperands({lhs, rhs});
|
|
||||||
result.types.push_back(lhs.getType());
|
|
||||||
}
|
|
||||||
|
|
||||||
ParseResult impl::parseOneResultSameOperandTypeOp(OpAsmParser &parser,
|
|
||||||
OperationState &result) {
|
|
||||||
SmallVector<OpAsmParser::OperandType, 2> ops;
|
|
||||||
Type type;
|
|
||||||
// If the operand list is in-between parentheses, then we have a generic form.
|
|
||||||
// (see the fallback in `printOneResultOp`).
|
|
||||||
SMLoc loc = parser.getCurrentLocation();
|
|
||||||
if (!parser.parseOptionalLParen()) {
|
|
||||||
if (parser.parseOperandList(ops) || parser.parseRParen() ||
|
|
||||||
parser.parseOptionalAttrDict(result.attributes) ||
|
|
||||||
parser.parseColon() || parser.parseType(type))
|
|
||||||
return failure();
|
|
||||||
auto fnType = type.dyn_cast<FunctionType>();
|
|
||||||
if (!fnType) {
|
|
||||||
parser.emitError(loc, "expected function type");
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands))
|
|
||||||
return failure();
|
|
||||||
result.addTypes(fnType.getResults());
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
return failure(parser.parseOperandList(ops) ||
|
|
||||||
parser.parseOptionalAttrDict(result.attributes) ||
|
|
||||||
parser.parseColonType(type) ||
|
|
||||||
parser.resolveOperands(ops, type, result.operands) ||
|
|
||||||
parser.addTypeToList(type, result.types));
|
|
||||||
}
|
|
||||||
|
|
||||||
void impl::printOneResultOp(Operation *op, OpAsmPrinter &p) {
|
|
||||||
assert(op->getNumResults() == 1 && "op should have one result");
|
|
||||||
|
|
||||||
// If not all the operand and result types are the same, just use the
|
|
||||||
// generic assembly form to avoid omitting information in printing.
|
|
||||||
auto resultType = op->getResult(0).getType();
|
|
||||||
if (llvm::any_of(op->getOperandTypes(),
|
|
||||||
[&](Type type) { return type != resultType; })) {
|
|
||||||
p.printGenericOp(op, /*printOpName=*/false);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
p << ' ';
|
|
||||||
p.printOperands(op->getOperands());
|
|
||||||
p.printOptionalAttrDict(op->getAttrs());
|
|
||||||
// Now we can output only one type for all operands and the result.
|
|
||||||
p << " : " << resultType;
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// CastOp implementation
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
/// Attempt to fold the given cast operation.
|
/// Attempt to fold the given cast operation.
|
||||||
|
@ -1232,50 +1170,6 @@ LogicalResult impl::verifyCastInterfaceOp(
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
void impl::buildCastOp(OpBuilder &builder, OperationState &result, Value source,
|
|
||||||
Type destType) {
|
|
||||||
result.addOperands(source);
|
|
||||||
result.addTypes(destType);
|
|
||||||
}
|
|
||||||
|
|
||||||
ParseResult impl::parseCastOp(OpAsmParser &parser, OperationState &result) {
|
|
||||||
OpAsmParser::OperandType srcInfo;
|
|
||||||
Type srcType, dstType;
|
|
||||||
return failure(parser.parseOperand(srcInfo) ||
|
|
||||||
parser.parseOptionalAttrDict(result.attributes) ||
|
|
||||||
parser.parseColonType(srcType) ||
|
|
||||||
parser.resolveOperand(srcInfo, srcType, result.operands) ||
|
|
||||||
parser.parseKeywordType("to", dstType) ||
|
|
||||||
parser.addTypeToList(dstType, result.types));
|
|
||||||
}
|
|
||||||
|
|
||||||
void impl::printCastOp(Operation *op, OpAsmPrinter &p) {
|
|
||||||
p << ' ' << op->getOperand(0);
|
|
||||||
p.printOptionalAttrDict(op->getAttrs());
|
|
||||||
p << " : " << op->getOperand(0).getType() << " to "
|
|
||||||
<< op->getResult(0).getType();
|
|
||||||
}
|
|
||||||
|
|
||||||
Value impl::foldCastOp(Operation *op) {
|
|
||||||
// Identity cast
|
|
||||||
if (op->getOperand(0).getType() == op->getResult(0).getType())
|
|
||||||
return op->getOperand(0);
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult
|
|
||||||
impl::verifyCastOp(Operation *op,
|
|
||||||
function_ref<bool(Type, Type)> areCastCompatible) {
|
|
||||||
auto opType = op->getOperand(0).getType();
|
|
||||||
auto resType = op->getResult(0).getType();
|
|
||||||
if (!areCastCompatible(opType, resType))
|
|
||||||
return op->emitError("operand type ")
|
|
||||||
<< opType << " and result type " << resType
|
|
||||||
<< " are cast incompatible";
|
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Misc. utils
|
// Misc. utils
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
Loading…
Reference in New Issue