[mlir] Add canonicalizer to merge shape.assuming_all ops

Depends On D119021

Reviewed By: frgossen

Differential Revision: https://reviews.llvm.org/D119025
This commit is contained in:
Eugene Zhulenev 2022-02-04 11:31:59 -08:00
parent 598833c987
commit 981f0a14f1
2 changed files with 55 additions and 1 deletions

View File

@ -460,6 +460,39 @@ LogicalResult shape::AddOp::verify() { return verifySizeOrIndexOp(*this); }
//===----------------------------------------------------------------------===//
namespace {
// Merge multiple `shape.assuming_all` operations together.
//
// %0 = shape.assuming_all %w0, %w1
// %1 = shape.assuming_all %w2, %0
//
// to:
//
// %0 = shape.assuming_all %w0, %w2, %w2
struct MergeAssumingAllOps : public OpRewritePattern<AssumingAllOp> {
using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AssumingAllOp op,
PatternRewriter &rewriter) const override {
SmallVector<Value> operands;
for (Value operand : op.getInputs()) {
if (auto assume_all = operand.getDefiningOp<AssumingAllOp>())
operands.append(assume_all.operand_begin(), assume_all->operand_end());
else
operands.push_back(operand);
}
// We didn't find any other `assuming_all` ops to merge with.
if (operands.size() == op.getNumOperands())
return failure();
// Replace with a new `assuming_all` operation with merged constraints.
rewriter.replaceOpWithNewOp<AssumingAllOp>(op, operands);
return success();
}
};
struct AssumingAllToCstrEqCanonicalization
: public OpRewritePattern<AssumingAllOp> {
using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
@ -506,7 +539,8 @@ struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add<AssumingAllOneOp, AssumingAllToCstrEqCanonicalization,
patterns.add<MergeAssumingAllOps, AssumingAllOneOp,
AssumingAllToCstrEqCanonicalization,
RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
}

View File

@ -463,6 +463,26 @@ func @cstr_require_no_fold(%arg0: i1) {
return
}
// -----
// merge assuming_all operations
// CHECK-LABEL: func @f
func @f() {
// CHECK-NEXT: %[[W0:.*]] = "test.source"
// CHECK-NEXT: %[[W1:.*]] = "test.source"
// CHECK-NEXT: %[[W2:.*]] = "test.source"
// CHECK-NEXT: shape.assuming_all %[[W0]], %[[W1]], %[[W2]]
// CHECK-NEXT: consume.witness
// CHECK-NEXT: return
%0 = "test.source"() : () -> !shape.witness
%1 = "test.source"() : () -> !shape.witness
%2 = "test.source"() : () -> !shape.witness
%3 = shape.assuming_all %0, %1
%4 = shape.assuming_all %3, %2
"consume.witness"(%4) : (!shape.witness) -> ()
return
}
// -----
// `assuming_all` with all `cstr_eq` and shared operands can be collapsed.
// CHECK-LABEL: func @assuming_all_to_cstr_eq