Simplify subview

This commit is contained in:
William S. Moses 2021-12-23 20:02:01 -05:00 committed by William Moses
parent a4ab8d5c86
commit 52d56133ed
1 changed files with 191 additions and 4 deletions

View File

@ -310,7 +310,7 @@ public:
// load subindex(x)
// affine.load subindex(x)
// dealloc subindex(x)
struct SimplifySubViewUsers : public OpRewritePattern<SubIndexOp> {
struct SimplifySubIndexUsers : public OpRewritePattern<SubIndexOp> {
using OpRewritePattern<SubIndexOp>::OpRewritePattern;
LogicalResult matchAndRewrite(SubIndexOp subindex,
@ -469,6 +469,192 @@ struct SimplifySubViewUsers : public OpRewritePattern<SubIndexOp> {
}
};
/// Simplify all uses of subindex, specifically
// store subindex(x) = ...
// affine.store subindex(x) = ...
// load subindex(x)
// affine.load subindex(x)
// dealloc subindex(x)
struct SimplifySubViewUsers : public OpRewritePattern<memref::SubViewOp> {
using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;
LogicalResult matchAndRewrite(memref::SubViewOp subindex,
PatternRewriter &rewriter) const override {
bool changed = false;
int64_t offs = -1;
for (auto tup :
llvm::zip(subindex.static_offsets(), subindex.static_sizes(),
subindex.static_strides())) {
auto sz = std::get<1>(tup).dyn_cast<IntegerAttr>().getValue();
auto stride = std::get<2>(tup).dyn_cast<IntegerAttr>().getValue();
if (stride != 1)
return failure();
if (offs == -1) {
offs = std::get<0>(tup)
.dyn_cast<IntegerAttr>()
.getValue()
.getLimitedValue();
if (sz != 1)
return failure();
}
}
Value off = rewriter.create<ConstantIndexOp>(subindex.getLoc(), offs);
assert(off);
for (OpOperand &use : llvm::make_early_inc_range(subindex->getUses())) {
rewriter.setInsertionPoint(use.getOwner());
if (auto dealloc = dyn_cast<memref::DeallocOp>(use.getOwner())) {
changed = true;
rewriter.replaceOpWithNewOp<memref::DeallocOp>(dealloc,
subindex.source());
} else if (auto loadOp = dyn_cast<memref::LoadOp>(use.getOwner())) {
if (loadOp.memref() == subindex) {
SmallVector<Value, 4> indices = loadOp.indices();
if (subindex.getType().cast<MemRefType>().getShape().size() ==
subindex.source()
.getType()
.cast<MemRefType>()
.getShape()
.size()) {
assert(indices.size() > 0);
indices[0] =
rewriter.create<AddIOp>(subindex.getLoc(), indices[0], off);
} else {
if (subindex.getType().cast<MemRefType>().getShape().size() + 1 ==
subindex.source()
.getType()
.cast<MemRefType>()
.getShape()
.size())
indices.insert(indices.begin(), off);
else {
assert(indices.size() > 0);
indices.erase(indices.begin());
}
}
assert(subindex.source()
.getType()
.cast<MemRefType>()
.getShape()
.size() == indices.size());
rewriter.replaceOpWithNewOp<memref::LoadOp>(loadOp, subindex.source(),
indices);
changed = true;
}
} else if (auto storeOp = dyn_cast<memref::StoreOp>(use.getOwner())) {
if (storeOp.memref() == subindex) {
SmallVector<Value, 4> indices = storeOp.indices();
if (subindex.getType().cast<MemRefType>().getShape().size() ==
subindex.source()
.getType()
.cast<MemRefType>()
.getShape()
.size()) {
assert(indices.size() > 0);
indices[0] =
rewriter.create<AddIOp>(subindex.getLoc(), indices[0], off);
} else {
if (subindex.getType().cast<MemRefType>().getShape().size() + 1 ==
subindex.source()
.getType()
.cast<MemRefType>()
.getShape()
.size())
indices.insert(indices.begin(), off);
else {
if (indices.size() == 0) {
llvm::errs() << " storeOp: " << storeOp
<< " - subidx: " << subindex << "\n";
}
assert(indices.size() > 0);
indices.erase(indices.begin());
}
}
if (subindex.source()
.getType()
.cast<MemRefType>()
.getShape()
.size() != indices.size()) {
llvm::errs() << " storeOp: " << storeOp << " - subidx: " << subindex
<< "\n";
}
assert(subindex.source()
.getType()
.cast<MemRefType>()
.getShape()
.size() == indices.size());
rewriter.replaceOpWithNewOp<memref::StoreOp>(
storeOp, storeOp.value(), subindex.source(), indices);
changed = true;
}
} else if (auto storeOp = dyn_cast<AffineStoreOp>(use.getOwner())) {
if (storeOp.memref() == subindex) {
if (subindex.getType().cast<MemRefType>().getShape().size() + 1 ==
subindex.source()
.getType()
.cast<MemRefType>()
.getShape()
.size()) {
std::vector<Value> indices;
auto map = storeOp.getAffineMap();
indices.push_back(off);
for (size_t i = 0; i < map.getNumResults(); i++) {
auto apply = rewriter.create<AffineApplyOp>(
storeOp.getLoc(), map.getSliceMap(i, 1),
storeOp.getMapOperands());
indices.push_back(apply->getResult(0));
}
assert(subindex.source()
.getType()
.cast<MemRefType>()
.getShape()
.size() == indices.size());
rewriter.replaceOpWithNewOp<memref::StoreOp>(
storeOp, storeOp.value(), subindex.source(), indices);
changed = true;
}
}
} else if (auto storeOp = dyn_cast<AffineLoadOp>(use.getOwner())) {
if (storeOp.memref() == subindex) {
if (subindex.getType().cast<MemRefType>().getShape().size() + 1 ==
subindex.source()
.getType()
.cast<MemRefType>()
.getShape()
.size()) {
std::vector<Value> indices;
auto map = storeOp.getAffineMap();
indices.push_back(off);
for (size_t i = 0; i < map.getNumResults(); i++) {
auto apply = rewriter.create<AffineApplyOp>(
storeOp.getLoc(), map.getSliceMap(i, 1),
storeOp.getMapOperands());
indices.push_back(apply->getResult(0));
}
assert(subindex.source()
.getType()
.cast<MemRefType>()
.getShape()
.size() == indices.size());
rewriter.replaceOpWithNewOp<memref::LoadOp>(
storeOp, subindex.source(), indices);
changed = true;
}
}
}
}
return success(changed);
}
};
/// Simplify select cast(x), cast(y) to cast(select x, y)
struct SelectOfCast : public OpRewritePattern<SelectOp> {
using OpRewritePattern<SelectOp>::OpRewritePattern;
@ -522,9 +708,10 @@ struct SelectOfSubIndex : public OpRewritePattern<SelectOp> {
void SubIndexOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<CastOfSubIndex, SubIndexOpMemRefCastFolder, SubIndex2,
SubToCast, SimplifySubViewUsers, SelectOfCast,
SelectOfSubIndex, SubToSubView, RedundantDynSubIndex>(context);
results
.insert<CastOfSubIndex, SubIndexOpMemRefCastFolder, SubIndex2, SubToCast,
SimplifySubViewUsers, SimplifySubIndexUsers, SelectOfCast,
SelectOfSubIndex, SubToSubView, RedundantDynSubIndex>(context);
}
/// Simplify memref2pointer(cast(x)) to memref2pointer(x)