Simplify subview
This commit is contained in:
parent
a4ab8d5c86
commit
52d56133ed
|
@ -310,7 +310,7 @@ public:
|
|||
// load subindex(x)
|
||||
// affine.load subindex(x)
|
||||
// dealloc subindex(x)
|
||||
struct SimplifySubViewUsers : public OpRewritePattern<SubIndexOp> {
|
||||
struct SimplifySubIndexUsers : public OpRewritePattern<SubIndexOp> {
|
||||
using OpRewritePattern<SubIndexOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(SubIndexOp subindex,
|
||||
|
@ -469,6 +469,192 @@ struct SimplifySubViewUsers : public OpRewritePattern<SubIndexOp> {
|
|||
}
|
||||
};
|
||||
|
||||
/// Simplify all uses of subindex, specifically
|
||||
// store subindex(x) = ...
|
||||
// affine.store subindex(x) = ...
|
||||
// load subindex(x)
|
||||
// affine.load subindex(x)
|
||||
// dealloc subindex(x)
|
||||
struct SimplifySubViewUsers : public OpRewritePattern<memref::SubViewOp> {
|
||||
using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(memref::SubViewOp subindex,
|
||||
PatternRewriter &rewriter) const override {
|
||||
bool changed = false;
|
||||
int64_t offs = -1;
|
||||
for (auto tup :
|
||||
llvm::zip(subindex.static_offsets(), subindex.static_sizes(),
|
||||
subindex.static_strides())) {
|
||||
auto sz = std::get<1>(tup).dyn_cast<IntegerAttr>().getValue();
|
||||
|
||||
auto stride = std::get<2>(tup).dyn_cast<IntegerAttr>().getValue();
|
||||
if (stride != 1)
|
||||
return failure();
|
||||
|
||||
if (offs == -1) {
|
||||
offs = std::get<0>(tup)
|
||||
.dyn_cast<IntegerAttr>()
|
||||
.getValue()
|
||||
.getLimitedValue();
|
||||
if (sz != 1)
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
Value off = rewriter.create<ConstantIndexOp>(subindex.getLoc(), offs);
|
||||
assert(off);
|
||||
|
||||
for (OpOperand &use : llvm::make_early_inc_range(subindex->getUses())) {
|
||||
rewriter.setInsertionPoint(use.getOwner());
|
||||
if (auto dealloc = dyn_cast<memref::DeallocOp>(use.getOwner())) {
|
||||
changed = true;
|
||||
rewriter.replaceOpWithNewOp<memref::DeallocOp>(dealloc,
|
||||
subindex.source());
|
||||
} else if (auto loadOp = dyn_cast<memref::LoadOp>(use.getOwner())) {
|
||||
if (loadOp.memref() == subindex) {
|
||||
SmallVector<Value, 4> indices = loadOp.indices();
|
||||
if (subindex.getType().cast<MemRefType>().getShape().size() ==
|
||||
subindex.source()
|
||||
.getType()
|
||||
.cast<MemRefType>()
|
||||
.getShape()
|
||||
.size()) {
|
||||
assert(indices.size() > 0);
|
||||
indices[0] =
|
||||
rewriter.create<AddIOp>(subindex.getLoc(), indices[0], off);
|
||||
} else {
|
||||
if (subindex.getType().cast<MemRefType>().getShape().size() + 1 ==
|
||||
subindex.source()
|
||||
.getType()
|
||||
.cast<MemRefType>()
|
||||
.getShape()
|
||||
.size())
|
||||
indices.insert(indices.begin(), off);
|
||||
else {
|
||||
assert(indices.size() > 0);
|
||||
indices.erase(indices.begin());
|
||||
}
|
||||
}
|
||||
|
||||
assert(subindex.source()
|
||||
.getType()
|
||||
.cast<MemRefType>()
|
||||
.getShape()
|
||||
.size() == indices.size());
|
||||
rewriter.replaceOpWithNewOp<memref::LoadOp>(loadOp, subindex.source(),
|
||||
indices);
|
||||
changed = true;
|
||||
}
|
||||
} else if (auto storeOp = dyn_cast<memref::StoreOp>(use.getOwner())) {
|
||||
if (storeOp.memref() == subindex) {
|
||||
SmallVector<Value, 4> indices = storeOp.indices();
|
||||
if (subindex.getType().cast<MemRefType>().getShape().size() ==
|
||||
subindex.source()
|
||||
.getType()
|
||||
.cast<MemRefType>()
|
||||
.getShape()
|
||||
.size()) {
|
||||
assert(indices.size() > 0);
|
||||
indices[0] =
|
||||
rewriter.create<AddIOp>(subindex.getLoc(), indices[0], off);
|
||||
} else {
|
||||
if (subindex.getType().cast<MemRefType>().getShape().size() + 1 ==
|
||||
subindex.source()
|
||||
.getType()
|
||||
.cast<MemRefType>()
|
||||
.getShape()
|
||||
.size())
|
||||
indices.insert(indices.begin(), off);
|
||||
else {
|
||||
if (indices.size() == 0) {
|
||||
llvm::errs() << " storeOp: " << storeOp
|
||||
<< " - subidx: " << subindex << "\n";
|
||||
}
|
||||
assert(indices.size() > 0);
|
||||
indices.erase(indices.begin());
|
||||
}
|
||||
}
|
||||
|
||||
if (subindex.source()
|
||||
.getType()
|
||||
.cast<MemRefType>()
|
||||
.getShape()
|
||||
.size() != indices.size()) {
|
||||
llvm::errs() << " storeOp: " << storeOp << " - subidx: " << subindex
|
||||
<< "\n";
|
||||
}
|
||||
assert(subindex.source()
|
||||
.getType()
|
||||
.cast<MemRefType>()
|
||||
.getShape()
|
||||
.size() == indices.size());
|
||||
rewriter.replaceOpWithNewOp<memref::StoreOp>(
|
||||
storeOp, storeOp.value(), subindex.source(), indices);
|
||||
changed = true;
|
||||
}
|
||||
} else if (auto storeOp = dyn_cast<AffineStoreOp>(use.getOwner())) {
|
||||
if (storeOp.memref() == subindex) {
|
||||
if (subindex.getType().cast<MemRefType>().getShape().size() + 1 ==
|
||||
subindex.source()
|
||||
.getType()
|
||||
.cast<MemRefType>()
|
||||
.getShape()
|
||||
.size()) {
|
||||
|
||||
std::vector<Value> indices;
|
||||
auto map = storeOp.getAffineMap();
|
||||
indices.push_back(off);
|
||||
for (size_t i = 0; i < map.getNumResults(); i++) {
|
||||
auto apply = rewriter.create<AffineApplyOp>(
|
||||
storeOp.getLoc(), map.getSliceMap(i, 1),
|
||||
storeOp.getMapOperands());
|
||||
indices.push_back(apply->getResult(0));
|
||||
}
|
||||
|
||||
assert(subindex.source()
|
||||
.getType()
|
||||
.cast<MemRefType>()
|
||||
.getShape()
|
||||
.size() == indices.size());
|
||||
rewriter.replaceOpWithNewOp<memref::StoreOp>(
|
||||
storeOp, storeOp.value(), subindex.source(), indices);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
} else if (auto storeOp = dyn_cast<AffineLoadOp>(use.getOwner())) {
|
||||
if (storeOp.memref() == subindex) {
|
||||
if (subindex.getType().cast<MemRefType>().getShape().size() + 1 ==
|
||||
subindex.source()
|
||||
.getType()
|
||||
.cast<MemRefType>()
|
||||
.getShape()
|
||||
.size()) {
|
||||
|
||||
std::vector<Value> indices;
|
||||
auto map = storeOp.getAffineMap();
|
||||
indices.push_back(off);
|
||||
for (size_t i = 0; i < map.getNumResults(); i++) {
|
||||
auto apply = rewriter.create<AffineApplyOp>(
|
||||
storeOp.getLoc(), map.getSliceMap(i, 1),
|
||||
storeOp.getMapOperands());
|
||||
indices.push_back(apply->getResult(0));
|
||||
}
|
||||
assert(subindex.source()
|
||||
.getType()
|
||||
.cast<MemRefType>()
|
||||
.getShape()
|
||||
.size() == indices.size());
|
||||
rewriter.replaceOpWithNewOp<memref::LoadOp>(
|
||||
storeOp, subindex.source(), indices);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return success(changed);
|
||||
}
|
||||
};
|
||||
|
||||
/// Simplify select cast(x), cast(y) to cast(select x, y)
|
||||
struct SelectOfCast : public OpRewritePattern<SelectOp> {
|
||||
using OpRewritePattern<SelectOp>::OpRewritePattern;
|
||||
|
@ -522,8 +708,9 @@ struct SelectOfSubIndex : public OpRewritePattern<SelectOp> {
|
|||
|
||||
void SubIndexOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
results.insert<CastOfSubIndex, SubIndexOpMemRefCastFolder, SubIndex2,
|
||||
SubToCast, SimplifySubViewUsers, SelectOfCast,
|
||||
results
|
||||
.insert<CastOfSubIndex, SubIndexOpMemRefCastFolder, SubIndex2, SubToCast,
|
||||
SimplifySubViewUsers, SimplifySubIndexUsers, SelectOfCast,
|
||||
SelectOfSubIndex, SubToSubView, RedundantDynSubIndex>(context);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue