Also simplify trivial dynamic subindexing patterns
This commit is contained in:
parent
caa0302dbb
commit
e216fc1163
|
@ -24,7 +24,7 @@
|
|||
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::polygeist;
|
||||
using namespace polygeist;
|
||||
using namespace mlir::arith;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -236,6 +236,64 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
// Simplify redundant dynamic subindex patterns which tries to represent
|
||||
// rank-reducing indexing:
|
||||
// %3 = "polygeist.subindex"(%1, %arg0) : (memref<2x1000xi32>, index) ->
|
||||
// memref<?x1000xi32> %4 = "polygeist.subindex"(%3, %c0) :
|
||||
// (memref<?x1000xi32>, index) -> memref<1000xi32>
|
||||
// simplifies to:
|
||||
// %4 = "polygeist.subindex"(%1, %arg0) : (memref<2x1000xi32>, index) ->
|
||||
// memref<1000xi32>
|
||||
|
||||
class RedundantDynSubIndex final : public OpRewritePattern<SubIndexOp> {
|
||||
public:
|
||||
using OpRewritePattern<SubIndexOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(SubIndexOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto srcOp = dyn_cast<SubIndexOp>(op.source().getDefiningOp());
|
||||
if (!srcOp)
|
||||
return failure();
|
||||
|
||||
auto srcMemRefType = op.source().getType().cast<MemRefType>();
|
||||
auto resMemRefType = op.result().getType().cast<MemRefType>();
|
||||
|
||||
// Check if there are multiple users of the dynamically sized memory
|
||||
if (!op.source().hasOneUse())
|
||||
return failure();
|
||||
|
||||
// Check that the source op indeed is a dynamically indexed memory in the
|
||||
// 0'th index.
|
||||
if (srcMemRefType.getShape()[0] != -1)
|
||||
return failure();
|
||||
|
||||
// Check that this is indeed a rank reducing operation
|
||||
if (srcMemRefType.getShape().size() !=
|
||||
(resMemRefType.getShape().size() + 1))
|
||||
return failure();
|
||||
for (auto it : llvm::zip(srcMemRefType.getShape().drop_front(),
|
||||
resMemRefType.getShape())) {
|
||||
if (std::get<0>(it) != std::get<1>(it))
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Check that we're indexing into the 0'th index in the 2nd subindex op
|
||||
auto constIdx = dyn_cast<arith::ConstantOp>(op.index().getDefiningOp());
|
||||
if (!constIdx)
|
||||
return failure();
|
||||
auto constValue = constIdx.value().dyn_cast<IntegerAttr>();
|
||||
if (!constValue || !constValue.getType().isa<IndexType>() ||
|
||||
constValue.getValue().getZExtValue() != 0)
|
||||
return failure();
|
||||
|
||||
// Valid optimization target; perform the substitution.
|
||||
rewriter.replaceOpWithNewOp<SubIndexOp>(op, op.result().getType(),
|
||||
srcOp.source(), srcOp.index());
|
||||
rewriter.eraseOp(srcOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Simplify all uses of subindex, specifically
|
||||
// store subindex(x) = ...
|
||||
// affine.store subindex(x) = ...
|
||||
|
@ -456,7 +514,7 @@ void SubIndexOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
|||
MLIRContext *context) {
|
||||
results.insert<CastOfSubIndex, SubIndexOpMemRefCastFolder, SubIndex2,
|
||||
SubToCast, SimplifySubViewUsers, SelectOfCast,
|
||||
SelectOfSubIndex, SubToSubView>(context);
|
||||
SelectOfSubIndex, SubToSubView, RedundantDynSubIndex>(context);
|
||||
}
|
||||
|
||||
/// Simplify memref2pointer(cast(x)) to memref2pointer(x)
|
||||
|
|
|
@ -12,3 +12,18 @@ module {
|
|||
return %1 : memref<30xi32>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func @main(%arg0: index) -> memref<1000xi32> {
|
||||
// CHECK: %0 = memref.alloca() : memref<2x1000xi32>
|
||||
// CHECK: %1 = memref.subview %0[%arg0, 0] [1, 1000] [1, 1] : memref<2x1000xi32> to memref<1000xi32>
|
||||
// CHECK: return %1 : memref<1000xi32>
|
||||
// CHECK: }
|
||||
func @main(%arg0 : index) -> memref<1000xi32> {
|
||||
%c0 = arith.constant 0 : index
|
||||
%1 = memref.alloca() : memref<2x1000xi32>
|
||||
%3 = "polygeist.subindex"(%1, %arg0) : (memref<2x1000xi32>, index) -> memref<?x1000xi32>
|
||||
%4 = "polygeist.subindex"(%3, %c0) : (memref<?x1000xi32>, index) -> memref<1000xi32>
|
||||
return %4 : memref<1000xi32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue