[MLIR][PDL] Generalize result type verification

Presently the result type verification checks if the type is used by a `pdl::OperationOp` inside the matcher. This is unnecessarily restrictive; the type could come from a `pdl::OperandOp or `pdl::OperandsOp` and still be inferrable.

Reviewed By: rriddle, Mogball

Differential Revision: https://reviews.llvm.org/D116083
This commit is contained in:
Stanislav Funiak 2022-01-04 08:11:35 +05:30 committed by Uday Bondhugula
parent b4130e9ead
commit de6c82d6fd
2 changed files with 37 additions and 8 deletions

View File

@ -207,16 +207,17 @@ static LogicalResult verifyResultTypesAreInferrable(OperationOp op,
if (isa<ApplyNativeRewriteOp>(resultTypeOp))
continue;
// If the type operation was defined in the matcher and constrains the
// result of an input operation, it can be used.
auto constrainsInputOp = [rewriterBlock](Operation *user) {
return user->getBlock() != rewriterBlock && isa<OperationOp>(user);
// If the type operation was defined in the matcher and constrains an
// operand or the result of an input operation, it can be used.
auto constrainsInput = [rewriterBlock](Operation *user) {
return user->getBlock() != rewriterBlock &&
isa<OperandOp, OperandsOp, OperationOp>(user);
};
if (TypeOp typeOp = dyn_cast<TypeOp>(resultTypeOp)) {
if (typeOp.type() || llvm::any_of(typeOp->getUsers(), constrainsInputOp))
if (typeOp.type() || llvm::any_of(typeOp->getUsers(), constrainsInput))
continue;
} else if (TypesOp typeOp = dyn_cast<TypesOp>(resultTypeOp)) {
if (typeOp.types() || llvm::any_of(typeOp->getUsers(), constrainsInputOp))
if (typeOp.types() || llvm::any_of(typeOp->getUsers(), constrainsInput))
continue;
}

View File

@ -88,7 +88,7 @@ pdl.pattern @infer_type_from_operation_replace : benefit(1) {
// -----
// Check that the result type of an operation within a rewrite can be inferred
// from types used within the match block.
// from the result types of an operation within the match block.
pdl.pattern @infer_type_from_type_used_in_match : benefit(1) {
%type1 = pdl.type : i32
%type2 = pdl.type
@ -101,7 +101,7 @@ pdl.pattern @infer_type_from_type_used_in_match : benefit(1) {
// -----
// Check that the result type of an operation within a rewrite can be inferred
// from types used within the match block.
// from the result types of an operation within the match block.
pdl.pattern @infer_type_from_type_used_in_match : benefit(1) {
%types = pdl.types
%root = pdl.operation -> (%types : !pdl.range<type>)
@ -113,6 +113,34 @@ pdl.pattern @infer_type_from_type_used_in_match : benefit(1) {
// -----
// Check that the result type of an operation within a rewrite can be inferred
// from the type of an operand within the match block.
pdl.pattern @infer_type_from_type_used_in_match : benefit(1) {
%type1 = pdl.type
%type2 = pdl.type
%operand1 = pdl.operand : %type1
%operand2 = pdl.operand : %type2
%root = pdl.operation (%operand1, %operand2 : !pdl.value, !pdl.value)
pdl.rewrite %root {
%newOp = pdl.operation "foo.op" -> (%type1, %type2 : !pdl.type, !pdl.type)
}
}
// -----
// Check that the result type of an operation within a rewrite can be inferred
// from the types of operands within the match block.
pdl.pattern @infer_type_from_type_used_in_match : benefit(1) {
%types = pdl.types
%operands = pdl.operands : %types
%root = pdl.operation (%operands : !pdl.range<value>)
pdl.rewrite %root {
%newOp = pdl.operation "foo.op" -> (%types : !pdl.range<type>)
}
}
// -----
pdl.pattern @apply_rewrite_with_no_results : benefit(1) {
%root = pdl.operation
pdl.rewrite %root {