Also simplify trivial dynamic subindexing patterns

This commit is contained in:
Morten Borup Petersen 2021-11-30 09:55:32 +00:00 committed by William Moses
parent caa0302dbb
commit e216fc1163
2 changed files with 75 additions and 2 deletions

View File

@ -24,7 +24,7 @@
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
using namespace mlir;
using namespace mlir::polygeist;
using namespace polygeist;
using namespace mlir::arith;
//===----------------------------------------------------------------------===//
@ -236,6 +236,64 @@ public:
}
};
// Simplify redundant dynamic subindex patterns which tries to represent
// rank-reducing indexing:
// %3 = "polygeist.subindex"(%1, %arg0) : (memref<2x1000xi32>, index) ->
// memref<?x1000xi32> %4 = "polygeist.subindex"(%3, %c0) :
// (memref<?x1000xi32>, index) -> memref<1000xi32>
// simplifies to:
// %4 = "polygeist.subindex"(%1, %arg0) : (memref<2x1000xi32>, index) ->
// memref<1000xi32>
class RedundantDynSubIndex final : public OpRewritePattern<SubIndexOp> {
public:
using OpRewritePattern<SubIndexOp>::OpRewritePattern;
LogicalResult matchAndRewrite(SubIndexOp op,
PatternRewriter &rewriter) const override {
auto srcOp = dyn_cast<SubIndexOp>(op.source().getDefiningOp());
if (!srcOp)
return failure();
auto srcMemRefType = op.source().getType().cast<MemRefType>();
auto resMemRefType = op.result().getType().cast<MemRefType>();
// Check if there are multiple users of the dynamically sized memory
if (!op.source().hasOneUse())
return failure();
// Check that the source op indeed is a dynamically indexed memory in the
// 0'th index.
if (srcMemRefType.getShape()[0] != -1)
return failure();
// Check that this is indeed a rank reducing operation
if (srcMemRefType.getShape().size() !=
(resMemRefType.getShape().size() + 1))
return failure();
for (auto it : llvm::zip(srcMemRefType.getShape().drop_front(),
resMemRefType.getShape())) {
if (std::get<0>(it) != std::get<1>(it))
return failure();
}
// Check that we're indexing into the 0'th index in the 2nd subindex op
auto constIdx = dyn_cast<arith::ConstantOp>(op.index().getDefiningOp());
if (!constIdx)
return failure();
auto constValue = constIdx.value().dyn_cast<IntegerAttr>();
if (!constValue || !constValue.getType().isa<IndexType>() ||
constValue.getValue().getZExtValue() != 0)
return failure();
// Valid optimization target; perform the substitution.
rewriter.replaceOpWithNewOp<SubIndexOp>(op, op.result().getType(),
srcOp.source(), srcOp.index());
rewriter.eraseOp(srcOp);
return success();
}
};
/// Simplify all uses of subindex, specifically
// store subindex(x) = ...
// affine.store subindex(x) = ...
@ -456,7 +514,7 @@ void SubIndexOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<CastOfSubIndex, SubIndexOpMemRefCastFolder, SubIndex2,
SubToCast, SimplifySubViewUsers, SelectOfCast,
SelectOfSubIndex, SubToSubView>(context);
SelectOfSubIndex, SubToSubView, RedundantDynSubIndex>(context);
}
/// Simplify memref2pointer(cast(x)) to memref2pointer(x)

View File

@ -12,3 +12,18 @@ module {
return %1 : memref<30xi32>
}
}
// -----
// CHECK: func @main(%arg0: index) -> memref<1000xi32> {
// CHECK: %0 = memref.alloca() : memref<2x1000xi32>
// CHECK: %1 = memref.subview %0[%arg0, 0] [1, 1000] [1, 1] : memref<2x1000xi32> to memref<1000xi32>
// CHECK: return %1 : memref<1000xi32>
// CHECK: }
func @main(%arg0 : index) -> memref<1000xi32> {
%c0 = arith.constant 0 : index
%1 = memref.alloca() : memref<2x1000xi32>
%3 = "polygeist.subindex"(%1, %arg0) : (memref<2x1000xi32>, index) -> memref<?x1000xi32>
%4 = "polygeist.subindex"(%3, %c0) : (memref<?x1000xi32>, index) -> memref<1000xi32>
return %4 : memref<1000xi32>
}