[mlir][NFC] Update Affine operations to use `hasVerifier` instead of `verifier`

The verifier field is deprecated, and slated for removal.

Differential Revision: https://reviews.llvm.org/D118826
This commit is contained in:
River Riddle 2022-02-02 10:23:28 -08:00
parent ef72cf4413
commit 4809da8eaf
2 changed files with 122 additions and 119 deletions

View File

@ -31,12 +31,10 @@ class Affine_Op<string mnemonic, list<Trait> traits = []> :
Op<Affine_Dialect, mnemonic, traits> {
// For every affine op, there needs to be a:
// * void print(OpAsmPrinter &p, ${C++ class of Op} op)
// * LogicalResult verify(${C++ class of Op} op)
// * ParseResult parse${C++ class of Op}(OpAsmParser &parser,
// OperationState &result)
// functions.
let printer = [{ return ::print(p, *this); }];
let verifier = [{ return ::verify(*this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
}
@ -112,6 +110,7 @@ def AffineApplyOp : Affine_Op<"apply", [NoSideEffect]> {
let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
}
def AffineForOp : Affine_Op<"for",
@ -350,6 +349,7 @@ def AffineForOp : Affine_Op<"for",
let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
}
def AffineIfOp : Affine_Op<"if",
@ -473,6 +473,7 @@ def AffineIfOp : Affine_Op<"if",
let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
}
class AffineLoadOpBase<string mnemonic, list<Trait> traits = []> :
@ -538,6 +539,7 @@ def AffineLoadOp : AffineLoadOpBase<"load"> {
let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
}
class AffineMinMaxOpBase<string mnemonic, list<Trait> traits = []> :
@ -565,11 +567,11 @@ class AffineMinMaxOpBase<string mnemonic, list<Trait> traits = []> :
operands().end()};
}
}];
let verifier = [{ return ::verifyAffineMinMaxOp(*this); }];
let printer = [{ return ::printAffineMinMaxOp(p, *this); }];
let parser = [{ return ::parseAffineMinMaxOp<$cppClass>(parser, result); }];
let hasFolder = 1;
let hasCanonicalizer = 1;
let hasVerifier = 1;
}
def AffineMinOp : AffineMinMaxOpBase<"min", [NoSideEffect]> {
@ -753,6 +755,7 @@ def AffineParallelOp : Affine_Op<"parallel",
}];
let hasFolder = 1;
let hasVerifier = 1;
}
def AffinePrefetchOp : Affine_Op<"prefetch",
@ -832,6 +835,7 @@ def AffinePrefetchOp : Affine_Op<"prefetch",
let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
}
class AffineStoreOpBase<string mnemonic, list<Trait> traits = []> :
@ -896,6 +900,7 @@ def AffineStoreOp : AffineStoreOpBase<"store"> {
let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
}
def AffineYieldOp : Affine_Op<"yield", [NoSideEffect, Terminator, ReturnLike,
@ -921,6 +926,7 @@ def AffineYieldOp : Affine_Op<"yield", [NoSideEffect, Terminator, ReturnLike,
];
let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
let hasVerifier = 1;
}
def AffineVectorLoadOp : AffineLoadOpBase<"vector_load"> {
@ -984,6 +990,7 @@ def AffineVectorLoadOp : AffineLoadOpBase<"vector_load"> {
}];
let hasCanonicalizer = 1;
let hasVerifier = 1;
}
def AffineVectorStoreOp : AffineStoreOpBase<"vector_store"> {
@ -1048,6 +1055,7 @@ def AffineVectorStoreOp : AffineStoreOpBase<"vector_store"> {
}];
let hasCanonicalizer = 1;
let hasVerifier = 1;
}
#endif // AFFINE_OPS

View File

@ -524,18 +524,18 @@ static void print(OpAsmPrinter &p, AffineApplyOp op) {
p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"map"});
}
static LogicalResult verify(AffineApplyOp op) {
LogicalResult AffineApplyOp::verify() {
// Check input and output dimensions match.
auto map = op.map();
AffineMap affineMap = map();
// Verify that operand count matches affine map dimension and symbol count.
if (op.getNumOperands() != map.getNumDims() + map.getNumSymbols())
return op.emitOpError(
if (getNumOperands() != affineMap.getNumDims() + affineMap.getNumSymbols())
return emitOpError(
"operand count and affine map dimension and symbol count must match");
// Verify that the map only produces one result.
if (map.getNumResults() != 1)
return op.emitOpError("mapping must produce one value");
if (affineMap.getNumResults() != 1)
return emitOpError("mapping must produce one value");
return success();
}
@ -1306,41 +1306,38 @@ void AffineForOp::build(OpBuilder &builder, OperationState &result, int64_t lb,
bodyBuilder);
}
static LogicalResult verify(AffineForOp op) {
LogicalResult AffineForOp::verify() {
// Check that the body defines as single block argument for the induction
// variable.
auto *body = op.getBody();
auto *body = getBody();
if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex())
return op.emitOpError(
"expected body to have a single index argument for the "
"induction variable");
return emitOpError("expected body to have a single index argument for the "
"induction variable");
// Verify that the bound operands are valid dimension/symbols.
/// Lower bound.
if (op.getLowerBoundMap().getNumInputs() > 0)
if (failed(
verifyDimAndSymbolIdentifiers(op, op.getLowerBoundOperands(),
op.getLowerBoundMap().getNumDims())))
if (getLowerBoundMap().getNumInputs() > 0)
if (failed(verifyDimAndSymbolIdentifiers(*this, getLowerBoundOperands(),
getLowerBoundMap().getNumDims())))
return failure();
/// Upper bound.
if (op.getUpperBoundMap().getNumInputs() > 0)
if (failed(
verifyDimAndSymbolIdentifiers(op, op.getUpperBoundOperands(),
op.getUpperBoundMap().getNumDims())))
if (getUpperBoundMap().getNumInputs() > 0)
if (failed(verifyDimAndSymbolIdentifiers(*this, getUpperBoundOperands(),
getUpperBoundMap().getNumDims())))
return failure();
unsigned opNumResults = op.getNumResults();
unsigned opNumResults = getNumResults();
if (opNumResults == 0)
return success();
// If ForOp defines values, check that the number and types of the defined
// values match ForOp initial iter operands and backedge basic block
// arguments.
if (op.getNumIterOperands() != opNumResults)
return op.emitOpError(
if (getNumIterOperands() != opNumResults)
return emitOpError(
"mismatch between the number of loop-carried values and results");
if (op.getNumRegionIterArgs() != opNumResults)
return op.emitOpError(
if (getNumRegionIterArgs() != opNumResults)
return emitOpError(
"mismatch between the number of basic block args and results");
return success();
@ -2063,23 +2060,22 @@ struct AlwaysTrueOrFalseIf : public OpRewritePattern<AffineIfOp> {
};
} // namespace
static LogicalResult verify(AffineIfOp op) {
LogicalResult AffineIfOp::verify() {
// Verify that we have a condition attribute.
// FIXME: This should be specified in the arguments list in ODS.
auto conditionAttr =
op->getAttrOfType<IntegerSetAttr>(op.getConditionAttrName());
(*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrName());
if (!conditionAttr)
return op.emitOpError(
"requires an integer set attribute named 'condition'");
return emitOpError("requires an integer set attribute named 'condition'");
// Verify that there are enough operands for the condition.
IntegerSet condition = conditionAttr.getValue();
if (op.getNumOperands() != condition.getNumInputs())
return op.emitOpError(
"operand count and condition integer set dimension and "
"symbol count must match");
if (getNumOperands() != condition.getNumInputs())
return emitOpError("operand count and condition integer set dimension and "
"symbol count must match");
// Verify that the operands are valid dimension/symbols.
if (failed(verifyDimAndSymbolIdentifiers(op, op.getOperands(),
if (failed(verifyDimAndSymbolIdentifiers(*this, getOperands(),
condition.getNumDims())))
return failure();
@ -2325,16 +2321,16 @@ verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr,
return success();
}
LogicalResult verify(AffineLoadOp op) {
auto memrefType = op.getMemRefType();
if (op.getType() != memrefType.getElementType())
return op.emitOpError("result type must match element type of memref");
LogicalResult AffineLoadOp::verify() {
auto memrefType = getMemRefType();
if (getType() != memrefType.getElementType())
return emitOpError("result type must match element type of memref");
if (failed(verifyMemoryOpIndexing(
op.getOperation(),
op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
op.getMapOperands(), memrefType,
/*numIndexOperands=*/op.getNumOperands() - 1)))
getOperation(),
(*this)->getAttrOfType<AffineMapAttr>(getMapAttrName()),
getMapOperands(), memrefType,
/*numIndexOperands=*/getNumOperands() - 1)))
return failure();
return success();
@ -2413,18 +2409,18 @@ static void print(OpAsmPrinter &p, AffineStoreOp op) {
p << " : " << op.getMemRefType();
}
LogicalResult verify(AffineStoreOp op) {
LogicalResult AffineStoreOp::verify() {
// The value to store must have the same type as memref element type.
auto memrefType = op.getMemRefType();
if (op.getValueToStore().getType() != memrefType.getElementType())
return op.emitOpError(
auto memrefType = getMemRefType();
if (getValueToStore().getType() != memrefType.getElementType())
return emitOpError(
"value to store must have the same type as memref element type");
if (failed(verifyMemoryOpIndexing(
op.getOperation(),
op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
op.getMapOperands(), memrefType,
/*numIndexOperands=*/op.getNumOperands() - 2)))
getOperation(),
(*this)->getAttrOfType<AffineMapAttr>(getMapAttrName()),
getMapOperands(), memrefType,
/*numIndexOperands=*/getNumOperands() - 2)))
return failure();
return success();
@ -2672,6 +2668,8 @@ void AffineMinOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
context);
}
LogicalResult AffineMinOp::verify() { return verifyAffineMinMaxOp(*this); }
//===----------------------------------------------------------------------===//
// AffineMaxOp
//===----------------------------------------------------------------------===//
@ -2691,6 +2689,8 @@ void AffineMaxOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
context);
}
LogicalResult AffineMaxOp::verify() { return verifyAffineMinMaxOp(*this); }
//===----------------------------------------------------------------------===//
// AffinePrefetchOp
//===----------------------------------------------------------------------===//
@ -2764,24 +2764,24 @@ static void print(OpAsmPrinter &p, AffinePrefetchOp op) {
p << " : " << op.getMemRefType();
}
static LogicalResult verify(AffinePrefetchOp op) {
auto mapAttr = op->getAttrOfType<AffineMapAttr>(op.getMapAttrName());
LogicalResult AffinePrefetchOp::verify() {
auto mapAttr = (*this)->getAttrOfType<AffineMapAttr>(getMapAttrName());
if (mapAttr) {
AffineMap map = mapAttr.getValue();
if (map.getNumResults() != op.getMemRefType().getRank())
return op.emitOpError("affine.prefetch affine map num results must equal"
" memref rank");
if (map.getNumInputs() + 1 != op.getNumOperands())
return op.emitOpError("too few operands");
if (map.getNumResults() != getMemRefType().getRank())
return emitOpError("affine.prefetch affine map num results must equal"
" memref rank");
if (map.getNumInputs() + 1 != getNumOperands())
return emitOpError("too few operands");
} else {
if (op.getNumOperands() != 1)
return op.emitOpError("too few operands");
if (getNumOperands() != 1)
return emitOpError("too few operands");
}
Region *scope = getAffineScope(op);
for (auto idx : op.getMapOperands()) {
Region *scope = getAffineScope(*this);
for (auto idx : getMapOperands()) {
if (!isValidAffineIndexOperand(idx, scope))
return op.emitOpError("index must be a dimension or symbol identifier");
return emitOpError("index must be a dimension or symbol identifier");
}
return success();
}
@ -3018,53 +3018,52 @@ void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) {
stepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
}
static LogicalResult verify(AffineParallelOp op) {
auto numDims = op.getNumDims();
if (op.lowerBoundsGroups().getNumElements() != numDims ||
op.upperBoundsGroups().getNumElements() != numDims ||
op.steps().size() != numDims ||
op.getBody()->getNumArguments() != numDims) {
return op.emitOpError()
<< "the number of region arguments ("
<< op.getBody()->getNumArguments()
<< ") and the number of map groups for lower ("
<< op.lowerBoundsGroups().getNumElements() << ") and upper bound ("
<< op.upperBoundsGroups().getNumElements()
<< "), and the number of steps (" << op.steps().size()
<< ") must all match";
LogicalResult AffineParallelOp::verify() {
auto numDims = getNumDims();
if (lowerBoundsGroups().getNumElements() != numDims ||
upperBoundsGroups().getNumElements() != numDims ||
steps().size() != numDims || getBody()->getNumArguments() != numDims) {
return emitOpError() << "the number of region arguments ("
<< getBody()->getNumArguments()
<< ") and the number of map groups for lower ("
<< lowerBoundsGroups().getNumElements()
<< ") and upper bound ("
<< upperBoundsGroups().getNumElements()
<< "), and the number of steps (" << steps().size()
<< ") must all match";
}
unsigned expectedNumLBResults = 0;
for (APInt v : op.lowerBoundsGroups())
for (APInt v : lowerBoundsGroups())
expectedNumLBResults += v.getZExtValue();
if (expectedNumLBResults != op.lowerBoundsMap().getNumResults())
return op.emitOpError() << "expected lower bounds map to have "
<< expectedNumLBResults << " results";
if (expectedNumLBResults != lowerBoundsMap().getNumResults())
return emitOpError() << "expected lower bounds map to have "
<< expectedNumLBResults << " results";
unsigned expectedNumUBResults = 0;
for (APInt v : op.upperBoundsGroups())
for (APInt v : upperBoundsGroups())
expectedNumUBResults += v.getZExtValue();
if (expectedNumUBResults != op.upperBoundsMap().getNumResults())
return op.emitOpError() << "expected upper bounds map to have "
<< expectedNumUBResults << " results";
if (expectedNumUBResults != upperBoundsMap().getNumResults())
return emitOpError() << "expected upper bounds map to have "
<< expectedNumUBResults << " results";
if (op.reductions().size() != op.getNumResults())
return op.emitOpError("a reduction must be specified for each output");
if (reductions().size() != getNumResults())
return emitOpError("a reduction must be specified for each output");
// Verify reduction ops are all valid
for (Attribute attr : op.reductions()) {
for (Attribute attr : reductions()) {
auto intAttr = attr.dyn_cast<IntegerAttr>();
if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt()))
return op.emitOpError("invalid reduction attribute");
return emitOpError("invalid reduction attribute");
}
// Verify that the bound operands are valid dimension/symbols.
/// Lower bounds.
if (failed(verifyDimAndSymbolIdentifiers(op, op.getLowerBoundsOperands(),
op.lowerBoundsMap().getNumDims())))
if (failed(verifyDimAndSymbolIdentifiers(*this, getLowerBoundsOperands(),
lowerBoundsMap().getNumDims())))
return failure();
/// Upper bounds.
if (failed(verifyDimAndSymbolIdentifiers(op, op.getUpperBoundsOperands(),
op.upperBoundsMap().getNumDims())))
if (failed(verifyDimAndSymbolIdentifiers(*this, getUpperBoundsOperands(),
upperBoundsMap().getNumDims())))
return failure();
return success();
}
@ -3412,20 +3411,19 @@ static ParseResult parseAffineParallelOp(OpAsmParser &parser,
// AffineYieldOp
//===----------------------------------------------------------------------===//
static LogicalResult verify(AffineYieldOp op) {
auto *parentOp = op->getParentOp();
LogicalResult AffineYieldOp::verify() {
auto *parentOp = (*this)->getParentOp();
auto results = parentOp->getResults();
auto operands = op.getOperands();
auto operands = getOperands();
if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp))
return op.emitOpError() << "only terminates affine.if/for/parallel regions";
if (parentOp->getNumResults() != op.getNumOperands())
return op.emitOpError() << "parent of yield must have same number of "
"results as the yield operands";
return emitOpError() << "only terminates affine.if/for/parallel regions";
if (parentOp->getNumResults() != getNumOperands())
return emitOpError() << "parent of yield must have same number of "
"results as the yield operands";
for (auto it : llvm::zip(results, operands)) {
if (std::get<0>(it).getType() != std::get<1>(it).getType())
return op.emitOpError()
<< "types mismatch between yield op and its parent";
return emitOpError() << "types mismatch between yield op and its parent";
}
return success();
@ -3516,17 +3514,16 @@ static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType,
return success();
}
static LogicalResult verify(AffineVectorLoadOp op) {
MemRefType memrefType = op.getMemRefType();
LogicalResult AffineVectorLoadOp::verify() {
MemRefType memrefType = getMemRefType();
if (failed(verifyMemoryOpIndexing(
op.getOperation(),
op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
op.getMapOperands(), memrefType,
/*numIndexOperands=*/op.getNumOperands() - 1)))
getOperation(),
(*this)->getAttrOfType<AffineMapAttr>(getMapAttrName()),
getMapOperands(), memrefType,
/*numIndexOperands=*/getNumOperands() - 1)))
return failure();
if (failed(verifyVectorMemoryOp(op.getOperation(), memrefType,
op.getVectorType())))
if (failed(verifyVectorMemoryOp(getOperation(), memrefType, getVectorType())))
return failure();
return success();
@ -3599,17 +3596,15 @@ static void print(OpAsmPrinter &p, AffineVectorStoreOp op) {
p << " : " << op.getMemRefType() << ", " << op.getValueToStore().getType();
}
static LogicalResult verify(AffineVectorStoreOp op) {
MemRefType memrefType = op.getMemRefType();
LogicalResult AffineVectorStoreOp::verify() {
MemRefType memrefType = getMemRefType();
if (failed(verifyMemoryOpIndexing(
op.getOperation(),
op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
op.getMapOperands(), memrefType,
/*numIndexOperands=*/op.getNumOperands() - 2)))
*this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrName()),
getMapOperands(), memrefType,
/*numIndexOperands=*/getNumOperands() - 2)))
return failure();
if (failed(verifyVectorMemoryOp(op.getOperation(), memrefType,
op.getVectorType())))
if (failed(verifyVectorMemoryOp(*this, memrefType, getVectorType())))
return failure();
return success();