[mlir] Add canonicalizer to merge shape.assuming_all ops
Depends On D119021 Reviewed By: frgossen Differential Revision: https://reviews.llvm.org/D119025
This commit is contained in:
parent
598833c987
commit
981f0a14f1
|
@ -460,6 +460,39 @@ LogicalResult shape::AddOp::verify() { return verifySizeOrIndexOp(*this); }
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
|
||||
// Merge multiple `shape.assuming_all` operations together.
|
||||
//
|
||||
// %0 = shape.assuming_all %w0, %w1
|
||||
// %1 = shape.assuming_all %w2, %0
|
||||
//
|
||||
// to:
|
||||
//
|
||||
// %0 = shape.assuming_all %w0, %w2, %w2
|
||||
struct MergeAssumingAllOps : public OpRewritePattern<AssumingAllOp> {
|
||||
using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(AssumingAllOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
SmallVector<Value> operands;
|
||||
|
||||
for (Value operand : op.getInputs()) {
|
||||
if (auto assume_all = operand.getDefiningOp<AssumingAllOp>())
|
||||
operands.append(assume_all.operand_begin(), assume_all->operand_end());
|
||||
else
|
||||
operands.push_back(operand);
|
||||
}
|
||||
|
||||
// We didn't find any other `assuming_all` ops to merge with.
|
||||
if (operands.size() == op.getNumOperands())
|
||||
return failure();
|
||||
|
||||
// Replace with a new `assuming_all` operation with merged constraints.
|
||||
rewriter.replaceOpWithNewOp<AssumingAllOp>(op, operands);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct AssumingAllToCstrEqCanonicalization
|
||||
: public OpRewritePattern<AssumingAllOp> {
|
||||
using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
|
||||
|
@ -506,7 +539,8 @@ struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
|
|||
|
||||
void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||
MLIRContext *context) {
|
||||
patterns.add<AssumingAllOneOp, AssumingAllToCstrEqCanonicalization,
|
||||
patterns.add<MergeAssumingAllOps, AssumingAllOneOp,
|
||||
AssumingAllToCstrEqCanonicalization,
|
||||
RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
|
||||
}
|
||||
|
||||
|
|
|
@ -463,6 +463,26 @@ func @cstr_require_no_fold(%arg0: i1) {
|
|||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// merge assuming_all operations
|
||||
// CHECK-LABEL: func @f
|
||||
func @f() {
|
||||
// CHECK-NEXT: %[[W0:.*]] = "test.source"
|
||||
// CHECK-NEXT: %[[W1:.*]] = "test.source"
|
||||
// CHECK-NEXT: %[[W2:.*]] = "test.source"
|
||||
// CHECK-NEXT: shape.assuming_all %[[W0]], %[[W1]], %[[W2]]
|
||||
// CHECK-NEXT: consume.witness
|
||||
// CHECK-NEXT: return
|
||||
%0 = "test.source"() : () -> !shape.witness
|
||||
%1 = "test.source"() : () -> !shape.witness
|
||||
%2 = "test.source"() : () -> !shape.witness
|
||||
%3 = shape.assuming_all %0, %1
|
||||
%4 = shape.assuming_all %3, %2
|
||||
"consume.witness"(%4) : (!shape.witness) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
// `assuming_all` with all `cstr_eq` and shared operands can be collapsed.
|
||||
// CHECK-LABEL: func @assuming_all_to_cstr_eq
|
||||
|
|
Loading…
Reference in New Issue