942 lines
35 KiB
C++
942 lines
35 KiB
C++
//===- PolygeistOps.cpp - Polygeist dialect ops ---------------*- C++ -*-===//
|
|
//
|
|
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "polygeist/Ops.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/OpImplementation.h"
|
|
#include "polygeist/Dialect.h"
|
|
|
|
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
|
|
|
|
#define GET_OP_CLASSES
|
|
#include "polygeist/PolygeistOps.cpp.inc"
|
|
|
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
|
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
|
|
|
|
using namespace mlir;
|
|
using namespace polygeist;
|
|
using namespace mlir::arith;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// BarrierOp
|
|
//===----------------------------------------------------------------------===//
|
|
void print(OpAsmPrinter &out, BarrierOp) {
|
|
out << BarrierOp::getOperationName();
|
|
}
|
|
|
|
LogicalResult verify(BarrierOp) { return success(); }
|
|
|
|
/// Parse a barrier op. The custom form is just the operation name, so there
/// is nothing to consume from the parser.
ParseResult parseBarrierOp(OpAsmParser &, OperationState &) {
  return success();
}
|
|
|
|
/// Collect the memory effects of the given op in 'effects'. Returns 'true' if
/// it could extract the effect information from the op, otherwise returns
/// 'false' and conservatively populates the list with all possible effects.
|
|
static bool
collectEffects(Operation *op,
               SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // Skip over barriers to avoid infinite recursion (those barriers would ask
  // this barrier again).
  if (isa<BarrierOp>(op))
    return true;

  // Collect effect instances of the operation. Note that the implementation
  // of getEffects erases all effect instances that have a type other than the
  // template parameter, so we collect each category into a local buffer first
  // and then copy into the caller's list.
  SmallVector<MemoryEffects::EffectInstance> localEffects;
  if (auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
    iface.getEffects<MemoryEffects::Read>(localEffects);
    llvm::append_range(effects, localEffects);
    iface.getEffects<MemoryEffects::Write>(localEffects);
    llvm::append_range(effects, localEffects);
    iface.getEffects<MemoryEffects::Allocate>(localEffects);
    llvm::append_range(effects, localEffects);
    iface.getEffects<MemoryEffects::Free>(localEffects);
    llvm::append_range(effects, localEffects);
    return true;
  }

  // We need to be conservative here in case the op doesn't have the interface
  // and assume it can have any possible effect.
  effects.emplace_back(MemoryEffects::Effect::get<MemoryEffects::Read>());
  effects.emplace_back(MemoryEffects::Effect::get<MemoryEffects::Write>());
  effects.emplace_back(MemoryEffects::Effect::get<MemoryEffects::Allocate>());
  effects.emplace_back(MemoryEffects::Effect::get<MemoryEffects::Free>());
  return false;
}
|
|
|
|
void BarrierOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // Report the union of the effects of all sibling operations that precede
  // and follow the barrier in its block. If any sibling's effects could not
  // be determined, collectEffects has already populated `effects`
  // conservatively (all four effect kinds) and we can stop scanning.
  Operation *op = getOperation();
  for (Operation *it = op->getPrevNode(); it != nullptr; it = it->getPrevNode())
    if (!collectEffects(it, effects))
      return;
  for (Operation *it = op->getNextNode(); it != nullptr; it = it->getNextNode())
    if (!collectEffects(it, effects))
      return;

  // TODO: we need to handle regions in case the parent op isn't an SCF parallel
}
|
|
|
|
/// Replace cast(subindex(x, InterimType), FinalType) with subindex(x,
|
|
/// FinalType)
|
|
class CastOfSubIndex final : public OpRewritePattern<memref::CastOp> {
public:
  using OpRewritePattern<memref::CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::CastOp castOp,
                                PatternRewriter &rewriter) const override {
    auto subindex = castOp.source().getDefiningOp<SubIndexOp>();
    if (!subindex)
      return failure();

    // Only fold when the cast changes neither the rank nor the element type;
    // otherwise a single subindex could not produce the final type.
    auto finalTy = castOp.getType().cast<MemRefType>();
    auto interimTy = subindex.result().getType().cast<MemRefType>();
    if (finalTy.getShape().size() != interimTy.getShape().size() ||
        finalTy.getElementType() != interimTy.getElementType())
      return failure();

    rewriter.replaceOpWithNewOp<SubIndexOp>(castOp, finalTy, subindex.source(),
                                            subindex.index());
    return success();
  }
};
|
|
|
|
// Replace subindex(subindex(x)) with subindex(x) with appropriate
|
|
// indexing.
|
|
class SubIndex2 final : public OpRewritePattern<SubIndexOp> {
public:
  using OpRewritePattern<SubIndexOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubIndexOp outer,
                                PatternRewriter &rewriter) const override {
    auto inner = outer.source().getDefiningOp<SubIndexOp>();
    if (!inner)
      return failure();

    size_t baseRank =
        inner.source().getType().cast<MemRefType>().getShape().size();
    size_t midRank = inner.getType().cast<MemRefType>().getShape().size();
    auto resTy = outer.getType().cast<MemRefType>();
    size_t resRank = resTy.getShape().size();

    // Both simplifications require the final rank to match the base rank.
    if (baseRank != resRank)
      return failure();

    // The intermediate has one extra dimension: the outer index alone
    // addresses into the base.
    if (midRank == baseRank + 1) {
      rewriter.replaceOpWithNewOp<SubIndexOp>(outer, resTy, inner.source(),
                                              outer.index());
      return success();
    }

    // All ranks equal: offsets compose additively.
    if (midRank == baseRank) {
      Value summed = rewriter.create<AddIOp>(inner.getLoc(), outer.index(),
                                             inner.index());
      rewriter.replaceOpWithNewOp<SubIndexOp>(outer, resTy, inner.source(),
                                              summed);
      return success();
    }
    return failure();
  }
};
|
|
|
|
// When possible, simplify subindex(x) to cast(x)
|
|
class SubToCast final : public OpRewritePattern<SubIndexOp> {
public:
  using OpRewritePattern<SubIndexOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubIndexOp op,
                                PatternRewriter &rewriter) const override {
    // A rank-preserving subindex at constant offset zero is only a type
    // change, which a memref.cast expresses directly.
    auto srcTy = op.source().getType().cast<MemRefType>();
    auto dstTy = op.getType().cast<MemRefType>();
    if (srcTy.getShape().size() != dstTy.getShape().size())
      return failure();

    auto cst = op.index().getDefiningOp<ConstantIndexOp>();
    if (!cst || cst.value() != 0)
      return failure();

    rewriter.replaceOpWithNewOp<memref::CastOp>(op, op.source(), dstTy);
    return success();
  }
};
|
|
|
|
// Simplify polygeist.subindex to memref.subview.
|
|
// Simplify polygeist.subindex to memref.subview.
class SubToSubView final : public OpRewritePattern<SubIndexOp> {
public:
  using OpRewritePattern<SubIndexOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubIndexOp op,
                                PatternRewriter &rewriter) const override {
    auto srcMemRefType = op.source().getType().cast<MemRefType>();
    auto resMemRefType = op.result().getType().cast<MemRefType>();
    auto dims = srcMemRefType.getShape().size();

    // For now, restrict subview lowering to statically defined memref's.
    if (!srcMemRefType.hasStaticShape() || !resMemRefType.hasStaticShape())
      return failure();

    // For now, restrict to simple rank-reducing indexing.
    if (srcMemRefType.getShape().size() <= resMemRefType.getShape().size())
      return failure();

    // Build offsets, sizes and strides. The subindex selects slice `index`
    // along the outermost dimension, so the subview offsets are
    // (index, 0, ..., 0), the sizes are (1, d1, ..., dn) and all strides
    // are 1. (Previously these two vectors were named the wrong way round
    // and passed in swapped argument positions, which cancelled out; they
    // are now named for what they contain.)
    SmallVector<OpFoldResult> offsets(dims, rewriter.getIndexAttr(0));
    offsets[0] = op.index();
    SmallVector<OpFoldResult> sizes(dims);
    for (auto dim : llvm::enumerate(srcMemRefType.getShape())) {
      if (dim.index() == 0)
        sizes[0] = rewriter.getIndexAttr(1);
      else
        sizes[dim.index()] = rewriter.getIndexAttr(dim.value());
    }
    SmallVector<OpFoldResult> strides(dims, rewriter.getIndexAttr(1));

    // Generate the appropriate (rank-reduced) return type.
    auto subMemRefType = MemRefType::get(srcMemRefType.getShape().drop_front(),
                                         srcMemRefType.getElementType());

    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
        op, subMemRefType, op.source(), offsets, sizes, strides);

    return success();
  }
};
|
|
|
|
// Simplify redundant dynamic subindex patterns which tries to represent
|
|
// rank-reducing indexing:
|
|
// %3 = "polygeist.subindex"(%1, %arg0) : (memref<2x1000xi32>, index) ->
|
|
// memref<?x1000xi32> %4 = "polygeist.subindex"(%3, %c0) :
|
|
// (memref<?x1000xi32>, index) -> memref<1000xi32>
|
|
// simplifies to:
|
|
// %4 = "polygeist.subindex"(%1, %arg0) : (memref<2x1000xi32>, index) ->
|
|
// memref<1000xi32>
|
|
|
|
class RedundantDynSubIndex final : public OpRewritePattern<SubIndexOp> {
public:
  using OpRewritePattern<SubIndexOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubIndexOp op,
                                PatternRewriter &rewriter) const override {
    auto producer = op.source().getDefiningOp<SubIndexOp>();
    if (!producer)
      return failure();

    size_t grandRank =
        producer.source().getType().cast<MemRefType>().getShape().size();
    size_t srcRank =
        op.source().getType().cast<MemRefType>().getShape().size();
    size_t resRank =
        op.result().getType().cast<MemRefType>().getShape().size();

    // `op` must reduce the rank by exactly one...
    if (srcRank != resRank + 1)
      return failure();

    // ...while the producing subindex kept the rank unchanged.
    if (srcRank != grandRank)
      return failure();

    // Fold both offsets into a single rank-reducing subindex on the
    // producer's source.
    Value combined = rewriter.create<arith::AddIOp>(op.getLoc(), op.index(),
                                                    producer.index());
    rewriter.replaceOpWithNewOp<SubIndexOp>(op, op.result().getType(),
                                            producer.source(), combined);
    return success();
  }
};
|
|
|
|
/// Simplify all uses of subindex, specifically
|
|
// store subindex(x) = ...
|
|
// affine.store subindex(x) = ...
|
|
// load subindex(x)
|
|
// affine.load subindex(x)
|
|
// dealloc subindex(x)
|
|
struct SimplifySubIndexUsers : public OpRewritePattern<SubIndexOp> {
  using OpRewritePattern<SubIndexOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubIndexOp subindex,
                                PatternRewriter &rewriter) const override {
    // Walk every use of the subindex and, where possible, rewrite the user
    // to access subindex.source() directly with adjusted indices. Uses are
    // iterated with make_early_inc_range since each rewrite erases the user.
    bool changed = false;

    for (OpOperand &use : llvm::make_early_inc_range(subindex->getUses())) {
      rewriter.setInsertionPoint(use.getOwner());
      if (auto dealloc = dyn_cast<memref::DeallocOp>(use.getOwner())) {
        // Deallocation frees the whole buffer; the offset is irrelevant.
        changed = true;
        rewriter.replaceOpWithNewOp<memref::DeallocOp>(dealloc,
                                                       subindex.source());
      } else if (auto loadOp = dyn_cast<memref::LoadOp>(use.getOwner())) {
        if (loadOp.memref() == subindex) {
          SmallVector<Value, 4> indices = loadOp.indices();
          if (subindex.getType().cast<MemRefType>().getShape().size() ==
              subindex.source()
                  .getType()
                  .cast<MemRefType>()
                  .getShape()
                  .size()) {
            // Rank-preserving subindex: fold its offset into the leading
            // index of the load.
            assert(indices.size() > 0);
            indices[0] = rewriter.create<AddIOp>(subindex.getLoc(), indices[0],
                                                 subindex.index());
          } else {
            // Rank-reducing subindex: its offset becomes a new leading index.
            assert (subindex.getType().cast<MemRefType>().getShape().size() + 1 ==
                    subindex.source()
                        .getType()
                        .cast<MemRefType>()
                        .getShape()
                        .size());
            indices.insert(indices.begin(), subindex.index());
          }

          assert(subindex.source()
                     .getType()
                     .cast<MemRefType>()
                     .getShape()
                     .size() == indices.size());
          rewriter.replaceOpWithNewOp<memref::LoadOp>(loadOp, subindex.source(),
                                                      indices);
          changed = true;
        }
      } else if (auto storeOp = dyn_cast<memref::StoreOp>(use.getOwner())) {
        // Same index adjustment as the load case, applied to memref.store.
        if (storeOp.memref() == subindex) {
          SmallVector<Value, 4> indices = storeOp.indices();
          if (subindex.getType().cast<MemRefType>().getShape().size() ==
              subindex.source()
                  .getType()
                  .cast<MemRefType>()
                  .getShape()
                  .size()) {
            assert(indices.size() > 0);
            indices[0] = rewriter.create<AddIOp>(subindex.getLoc(), indices[0],
                                                 subindex.index());
          } else {
            assert (subindex.getType().cast<MemRefType>().getShape().size() + 1 ==
                    subindex.source()
                        .getType()
                        .cast<MemRefType>()
                        .getShape()
                        .size());
            indices.insert(indices.begin(), subindex.index());
          }
          assert(subindex.source()
                     .getType()
                     .cast<MemRefType>()
                     .getShape()
                     .size() == indices.size());
          rewriter.replaceOpWithNewOp<memref::StoreOp>(
              storeOp, storeOp.value(), subindex.source(), indices);
          changed = true;
        }
      } else if (auto storeOp = dyn_cast<memref::AtomicRMWOp>(use.getOwner())) {
        // Same index adjustment, applied to memref.atomic_rmw (result type
        // and kind are carried over unchanged).
        if (storeOp.memref() == subindex) {
          SmallVector<Value, 4> indices = storeOp.indices();
          if (subindex.getType().cast<MemRefType>().getShape().size() ==
              subindex.source()
                  .getType()
                  .cast<MemRefType>()
                  .getShape()
                  .size()) {
            assert(indices.size() > 0);
            indices[0] = rewriter.create<AddIOp>(subindex.getLoc(), indices[0],
                                                 subindex.index());
          } else {
            assert (subindex.getType().cast<MemRefType>().getShape().size() + 1 ==
                    subindex.source()
                        .getType()
                        .cast<MemRefType>()
                        .getShape()
                        .size());
            indices.insert(indices.begin(), subindex.index());
          }
          assert(subindex.source()
                     .getType()
                     .cast<MemRefType>()
                     .getShape()
                     .size() == indices.size());
          rewriter.replaceOpWithNewOp<memref::AtomicRMWOp>(
              storeOp, storeOp.getType(), storeOp.kind(), storeOp.value(), subindex.source(), indices);
          changed = true;
        }
      } else if (auto storeOp = dyn_cast<AffineStoreOp>(use.getOwner())) {
        // affine.store is handled only in the rank-reducing case: the
        // subindex offset becomes the leading index and each affine map
        // result is materialized with an affine.apply, turning the access
        // into a plain memref.store.
        if (storeOp.memref() == subindex) {
          if (subindex.getType().cast<MemRefType>().getShape().size() + 1 ==
              subindex.source()
                  .getType()
                  .cast<MemRefType>()
                  .getShape()
                  .size()) {

            std::vector<Value> indices;
            auto map = storeOp.getAffineMap();
            indices.push_back(subindex.index());
            for (size_t i = 0; i < map.getNumResults(); i++) {
              auto apply = rewriter.create<AffineApplyOp>(
                  storeOp.getLoc(), map.getSliceMap(i, 1),
                  storeOp.getMapOperands());
              indices.push_back(apply->getResult(0));
            }

            assert(subindex.source()
                       .getType()
                       .cast<MemRefType>()
                       .getShape()
                       .size() == indices.size());
            rewriter.replaceOpWithNewOp<memref::StoreOp>(
                storeOp, storeOp.value(), subindex.source(), indices);
            changed = true;
          }
        }
      } else if (auto storeOp = dyn_cast<AffineLoadOp>(use.getOwner())) {
        // Same as the affine.store case, producing a plain memref.load.
        if (storeOp.memref() == subindex) {
          if (subindex.getType().cast<MemRefType>().getShape().size() + 1 ==
              subindex.source()
                  .getType()
                  .cast<MemRefType>()
                  .getShape()
                  .size()) {

            std::vector<Value> indices;
            auto map = storeOp.getAffineMap();
            indices.push_back(subindex.index());
            for (size_t i = 0; i < map.getNumResults(); i++) {
              auto apply = rewriter.create<AffineApplyOp>(
                  storeOp.getLoc(), map.getSliceMap(i, 1),
                  storeOp.getMapOperands());
              indices.push_back(apply->getResult(0));
            }
            assert(subindex.source()
                       .getType()
                       .cast<MemRefType>()
                       .getShape()
                       .size() == indices.size());
            rewriter.replaceOpWithNewOp<memref::LoadOp>(
                storeOp, subindex.source(), indices);
            changed = true;
          }
        }
      }
    }

    return success(changed);
  }
};
|
|
|
|
/// Simplify all uses of a subview, specifically
// store subview(x) = ...
// affine.store subview(x) = ...
// load subview(x)
// affine.load subview(x)
// dealloc subview(x)
|
struct SimplifySubViewUsers : public OpRewritePattern<memref::SubViewOp> {
  using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp subindex,
                                PatternRewriter &rewriter) const override {
    bool changed = false;
    // Recover a single static offset from the subview. Only unit strides are
    // accepted, and the first dimension must have size 1; its offset is the
    // one that gets folded into the users below.
    // NOTE(review): because `offs` is assigned at most once, only the first
    // dimension's offset/size are inspected — nonzero offsets in later
    // dimensions would be silently dropped. Confirm the frontend only
    // produces subviews of that shape.
    int64_t offs = -1;
    for (auto tup :
         llvm::zip(subindex.static_offsets(), subindex.static_sizes(),
                   subindex.static_strides())) {
      auto sz = std::get<1>(tup).dyn_cast<IntegerAttr>().getValue();

      auto stride = std::get<2>(tup).dyn_cast<IntegerAttr>().getValue();
      if (stride != 1)
        return failure();

      if (offs == -1) {
        offs = std::get<0>(tup)
                   .dyn_cast<IntegerAttr>()
                   .getValue()
                   .getLimitedValue();
        if (sz != 1)
          return failure();
      }
    }
    // NOTE(review): if the loop ran zero times `offs` stays -1 and a
    // constant -1 index is created — confirm a rank-0 subview cannot reach
    // this pattern.
    Value off = rewriter.create<ConstantIndexOp>(subindex.getLoc(), offs);
    assert(off);

    // Rewrite each user to access subindex.source() directly, adjusting its
    // indices by `off`. Uses are iterated with make_early_inc_range since
    // each rewrite erases the user.
    for (OpOperand &use : llvm::make_early_inc_range(subindex->getUses())) {
      rewriter.setInsertionPoint(use.getOwner());
      if (auto dealloc = dyn_cast<memref::DeallocOp>(use.getOwner())) {
        // Deallocation frees the whole buffer; the offset is irrelevant.
        changed = true;
        rewriter.replaceOpWithNewOp<memref::DeallocOp>(dealloc,
                                                       subindex.source());
      } else if (auto loadOp = dyn_cast<memref::LoadOp>(use.getOwner())) {
        if (loadOp.memref() == subindex) {
          SmallVector<Value, 4> indices = loadOp.indices();
          if (subindex.getType().cast<MemRefType>().getShape().size() ==
              subindex.source()
                  .getType()
                  .cast<MemRefType>()
                  .getShape()
                  .size()) {
            // Rank-preserving subview: fold the offset into the leading
            // index.
            assert(indices.size() > 0);
            indices[0] =
                rewriter.create<AddIOp>(subindex.getLoc(), indices[0], off);
          } else {
            if (subindex.getType().cast<MemRefType>().getShape().size() + 1 ==
                subindex.source()
                    .getType()
                    .cast<MemRefType>()
                    .getShape()
                    .size())
              // Rank-reducing subview: the offset becomes a new leading
              // index.
              indices.insert(indices.begin(), off);
            else {
              // Otherwise the subview result has an extra leading dimension;
              // drop the corresponding index. NOTE(review): assumes that
              // extra dimension is the leading unit dimension — confirm.
              assert(indices.size() > 0);
              indices.erase(indices.begin());
            }
          }

          assert(subindex.source()
                     .getType()
                     .cast<MemRefType>()
                     .getShape()
                     .size() == indices.size());
          rewriter.replaceOpWithNewOp<memref::LoadOp>(loadOp, subindex.source(),
                                                      indices);
          changed = true;
        }
      } else if (auto storeOp = dyn_cast<memref::StoreOp>(use.getOwner())) {
        // Same index adjustment as the load case, applied to memref.store.
        if (storeOp.memref() == subindex) {
          SmallVector<Value, 4> indices = storeOp.indices();
          if (subindex.getType().cast<MemRefType>().getShape().size() ==
              subindex.source()
                  .getType()
                  .cast<MemRefType>()
                  .getShape()
                  .size()) {
            assert(indices.size() > 0);
            indices[0] =
                rewriter.create<AddIOp>(subindex.getLoc(), indices[0], off);
          } else {
            if (subindex.getType().cast<MemRefType>().getShape().size() + 1 ==
                subindex.source()
                    .getType()
                    .cast<MemRefType>()
                    .getShape()
                    .size())
              indices.insert(indices.begin(), off);
            else {
              // Debug dump emitted just before the assert below would fire.
              if (indices.size() == 0) {
                llvm::errs() << " storeOp: " << storeOp
                             << " - subidx: " << subindex << "\n";
              }
              assert(indices.size() > 0);
              indices.erase(indices.begin());
            }
          }

          // Debug dump emitted just before the assert below would fire.
          if (subindex.source()
                  .getType()
                  .cast<MemRefType>()
                  .getShape()
                  .size() != indices.size()) {
            llvm::errs() << " storeOp: " << storeOp << " - subidx: " << subindex
                         << "\n";
          }
          assert(subindex.source()
                     .getType()
                     .cast<MemRefType>()
                     .getShape()
                     .size() == indices.size());
          rewriter.replaceOpWithNewOp<memref::StoreOp>(
              storeOp, storeOp.value(), subindex.source(), indices);
          changed = true;
        }
      } else if (auto storeOp = dyn_cast<AffineStoreOp>(use.getOwner())) {
        // affine.store is handled only in the rank-reducing case: the offset
        // becomes the leading index and each affine map result is
        // materialized with an affine.apply, turning the access into a plain
        // memref.store.
        if (storeOp.memref() == subindex) {
          if (subindex.getType().cast<MemRefType>().getShape().size() + 1 ==
              subindex.source()
                  .getType()
                  .cast<MemRefType>()
                  .getShape()
                  .size()) {

            std::vector<Value> indices;
            auto map = storeOp.getAffineMap();
            indices.push_back(off);
            for (size_t i = 0; i < map.getNumResults(); i++) {
              auto apply = rewriter.create<AffineApplyOp>(
                  storeOp.getLoc(), map.getSliceMap(i, 1),
                  storeOp.getMapOperands());
              indices.push_back(apply->getResult(0));
            }

            assert(subindex.source()
                       .getType()
                       .cast<MemRefType>()
                       .getShape()
                       .size() == indices.size());
            rewriter.replaceOpWithNewOp<memref::StoreOp>(
                storeOp, storeOp.value(), subindex.source(), indices);
            changed = true;
          }
        }
      } else if (auto storeOp = dyn_cast<AffineLoadOp>(use.getOwner())) {
        // Same as the affine.store case, producing a plain memref.load.
        if (storeOp.memref() == subindex) {
          if (subindex.getType().cast<MemRefType>().getShape().size() + 1 ==
              subindex.source()
                  .getType()
                  .cast<MemRefType>()
                  .getShape()
                  .size()) {

            std::vector<Value> indices;
            auto map = storeOp.getAffineMap();
            indices.push_back(off);
            for (size_t i = 0; i < map.getNumResults(); i++) {
              auto apply = rewriter.create<AffineApplyOp>(
                  storeOp.getLoc(), map.getSliceMap(i, 1),
                  storeOp.getMapOperands());
              indices.push_back(apply->getResult(0));
            }
            assert(subindex.source()
                       .getType()
                       .cast<MemRefType>()
                       .getShape()
                       .size() == indices.size());
            rewriter.replaceOpWithNewOp<memref::LoadOp>(
                storeOp, subindex.source(), indices);
            changed = true;
          }
        }
      }
    }

    return success(changed);
  }
};
|
|
|
|
/// Simplify select cast(x), cast(y) to cast(select x, y)
|
|
struct SelectOfCast : public OpRewritePattern<SelectOp> {
  using OpRewritePattern<SelectOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SelectOp op,
                                PatternRewriter &rewriter) const override {
    auto trueCast = op.getTrueValue().getDefiningOp<memref::CastOp>();
    if (!trueCast)
      return failure();
    auto falseCast = op.getFalseValue().getDefiningOp<memref::CastOp>();
    if (!falseCast)
      return failure();

    // Both branches must cast from the same type so the hoisted select is
    // well-typed.
    if (trueCast.source().getType() != falseCast.source().getType())
      return failure();

    Value sel =
        rewriter.create<SelectOp>(op.getLoc(), op.getCondition(),
                                  trueCast.source(), falseCast.source());
    rewriter.replaceOpWithNewOp<memref::CastOp>(op, op.getType(), sel);
    return success();
  }
};
|
|
|
|
/// Simplify select subindex(x), subindex(y) to subindex(select x, y)
|
|
struct SelectOfSubIndex : public OpRewritePattern<SelectOp> {
  using OpRewritePattern<SelectOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SelectOp op,
                                PatternRewriter &rewriter) const override {
    auto trueSub = op.getTrueValue().getDefiningOp<SubIndexOp>();
    if (!trueSub)
      return failure();
    auto falseSub = op.getFalseValue().getDefiningOp<SubIndexOp>();
    if (!falseSub)
      return failure();

    // Both branches must index into memrefs of the same type so the hoisted
    // selects are well-typed.
    if (trueSub.source().getType() != falseSub.source().getType())
      return failure();

    // Select the source memref and the index separately, then subindex once.
    auto loc = op.getLoc();
    Value sel = rewriter.create<SelectOp>(loc, op.getCondition(),
                                          trueSub.source(), falseSub.source());
    Value idx = rewriter.create<SelectOp>(loc, op.getCondition(),
                                          trueSub.index(), falseSub.index());
    rewriter.replaceOpWithNewOp<SubIndexOp>(op, op.getType(), sel, idx);
    return success();
  }
};
|
|
|
|
void SubIndexOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                             MLIRContext *context) {
  // Register all subindex-related simplifications. SubToSubView is
  // intentionally not registered (see below).
  results.insert<CastOfSubIndex, SubIndex2,
                 SubToCast, SimplifySubViewUsers, SimplifySubIndexUsers,
                 SelectOfCast, SelectOfSubIndex, RedundantDynSubIndex>(context);
  // Disabled: SubToSubView
}
|
|
|
|
/// Simplify pointer2memref(memref2pointer(x)) to cast(x)
|
|
class Memref2Pointer2MemrefCast final
|
|
: public OpRewritePattern<Pointer2MemrefOp> {
|
|
public:
|
|
using OpRewritePattern<Pointer2MemrefOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(Pointer2MemrefOp op,
|
|
PatternRewriter &rewriter) const override {
|
|
auto src = op.source().getDefiningOp<Memref2PointerOp>();
|
|
if (!src)
|
|
return failure();
|
|
if (src.source().getType().cast<MemRefType>().getShape().size() !=
|
|
op.getType().cast<MemRefType>().getShape().size())
|
|
return failure();
|
|
if (src.source().getType().cast<MemRefType>().getElementType() !=
|
|
op.getType().cast<MemRefType>().getElementType())
|
|
return failure();
|
|
if (src.source().getType().cast<MemRefType>().getMemorySpace() !=
|
|
op.getType().cast<MemRefType>().getMemorySpace())
|
|
return failure();
|
|
|
|
rewriter.replaceOpWithNewOp<memref::CastOp>(op, op.getType(), src.source());
|
|
return success();
|
|
}
|
|
};
|
|
/// Simplify memref2pointer(subindex(x, i)) to gep(memref2pointer(x), [i])
|
|
class Memref2PointerIndex final
|
|
: public OpRewritePattern<Memref2PointerOp> {
|
|
public:
|
|
using OpRewritePattern<Memref2PointerOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(Memref2PointerOp op,
|
|
PatternRewriter &rewriter) const override {
|
|
auto src = op.source().getDefiningOp<SubIndexOp>();
|
|
if (!src)
|
|
return failure();
|
|
|
|
if (src.source().getType().cast<MemRefType>().getShape().size() != 1) return failure();
|
|
|
|
rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, op.getType(), rewriter.create<Memref2PointerOp>(op.getLoc(), op.getType(), src.source()),
|
|
std::vector<Value>({rewriter.create<arith::IndexCastOp>(op.getLoc(), rewriter.getI64Type(), src.index())}));
|
|
return success();
|
|
}
|
|
};
|
|
|
|
OpFoldResult Memref2PointerOp::fold(ArrayRef<Attribute> operands) {
  /// Simplify memref2pointer(subindex(x, 0)) to memref2pointer(x): a
  /// zero-offset subindex does not move the start of the buffer. The operand
  /// is updated in place and this op's own result is returned.
  if (auto subindex = source().getDefiningOp<SubIndexOp>()) {
    if (auto cop = subindex.index().getDefiningOp<ConstantIndexOp>()) {
      if (cop.value() == 0) {
        sourceMutable().assign(subindex.source());
        return result();
      }
    }
  }
  /// Simplify memref2pointer(cast(x)) to memref2pointer(x)
  if (auto mc = source().getDefiningOp<memref::CastOp>()) {
    sourceMutable().assign(mc.source());
    return result();
  }
  /// Simplify memref2pointer(pointer2memref(p)) to p when the original
  /// pointer already has this op's result type.
  if (auto mc = source().getDefiningOp<polygeist::Pointer2MemrefOp>()) {
    if (mc.source().getType() == getType()) {
      return mc.source();
    }
  }
  return nullptr;
}
|
|
|
|
void Memref2PointerOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  // Patterns that peel round-trip conversions and subindex producers off
  // pointer/memref conversion chains.
  results.insert<Memref2Pointer2MemrefCast, Memref2PointerIndex>(context);
}
|
|
|
|
/// Simplify cast(pointer2memref(x)) to pointer2memref(x)
|
|
class Pointer2MemrefCast final : public OpRewritePattern<memref::CastOp> {
public:
  using OpRewritePattern<memref::CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::CastOp op,
                                PatternRewriter &rewriter) const override {
    // cast(pointer2memref(p)) can produce the final memref type directly
    // from the pointer.
    auto p2m = op.source().getDefiningOp<Pointer2MemrefOp>();
    if (!p2m)
      return failure();

    rewriter.replaceOpWithNewOp<polygeist::Pointer2MemrefOp>(op, op.getType(),
                                                             p2m.source());
    return success();
  }
};
|
|
|
|
/// Simplify memref2pointer(pointer2memref(x)) to cast(x)
|
|
class Pointer2Memref2PointerCast final
|
|
: public OpRewritePattern<Memref2PointerOp> {
|
|
public:
|
|
using OpRewritePattern<Memref2PointerOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(Memref2PointerOp op,
|
|
PatternRewriter &rewriter) const override {
|
|
auto src = op.source().getDefiningOp<Pointer2MemrefOp>();
|
|
if (!src)
|
|
return failure();
|
|
|
|
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, op.getType(),
|
|
src.source());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Simplify load (pointer2memref(x)) to llvm.load x
|
|
template<typename Op>
class MetaPointer2Memref final : public OpRewritePattern<Op> {
public:
  using OpRewritePattern<Op>::OpRewritePattern;

  // Produce the value of the i-th access index of `op`; specialized below
  // for each supported memory operation.
  Value computeIndex(Op op, size_t idx, PatternRewriter &rewriter) const;

  // Replace `op` with the equivalent LLVM memory operation on `ptr`;
  // specialized below for each supported memory operation.
  void rewrite(Op op, Value ptr, PatternRewriter &rewriter) const;

  LogicalResult matchAndRewrite(Op op,
                                PatternRewriter &rewriter) const override {
    Value opPtr = op.memref();
    Pointer2MemrefOp src = opPtr.getDefiningOp<polygeist::Pointer2MemrefOp>();
    if (!src)
      return failure();

    // All inner dimensions must be static since their extents are multiplied
    // into the flattened index below. The outermost extent (i == 0) is never
    // used in the computation, so it may stay dynamic.
    auto mt = src.getType().cast<MemRefType>();
    for (size_t i=1; i<mt.getShape().size(); i++)
      if (mt.getShape()[i] == ShapedType::kDynamicSize)
        return failure();

    // Bitcast the raw pointer to pointer-to-element-type if it differs.
    Value val = src.source();
    if (val.getType().cast<LLVM::LLVMPointerType>().getElementType() !=
        mt.getElementType())
      val = rewriter.create<LLVM::BitcastOp>(op.getLoc(), LLVM::LLVMPointerType::get(mt.getElementType(),
          val.getType().cast<LLVM::LLVMPointerType>().getAddressSpace()), val);

    // Flatten the multi-dimensional access into one linear index using a
    // Horner scheme over the static extents.
    // NOTE(review): indices are index_cast'ed to i32 here — confirm 32 bits
    // is sufficient for all intended buffer sizes.
    Value idx = nullptr;
    auto shape = mt.getShape();
    for (size_t i = 0; i < shape.size(); i++) {
      auto off = computeIndex(op, i, rewriter);
      auto cur = rewriter.create<IndexCastOp>(
          op.getLoc(), rewriter.getI32Type(), off);
      if (idx == nullptr) {
        idx = cur;
      } else {
        idx = rewriter.create<AddIOp>(
            op.getLoc(),
            rewriter.create<MulIOp>(
                op.getLoc(), idx,
                rewriter.create<ConstantIntOp>(
                    op.getLoc(),
                    shape[i],
                    32)),
            cur);
      }
    }

    // Apply the flattened index with a GEP (skipped for rank-0 memrefs),
    // then emit the specialized LLVM load/store.
    if (idx) {
      Value idxs[] = {idx};
      val = rewriter.create<LLVM::GEPOp>(op.getLoc(), val.getType(), val, idxs);
    }
    rewrite(op, val, rewriter);
    return success();
  }
};
|
|
|
|
template<>
Value MetaPointer2Memref<memref::LoadOp>::computeIndex(memref::LoadOp op, size_t i, PatternRewriter &rewriter) const {
  // memref.load carries its indices directly as operands.
  return op.indices()[i];
}
|
|
|
|
template<>
void MetaPointer2Memref<memref::LoadOp>::rewrite(memref::LoadOp op, Value ptr, PatternRewriter &rewriter) const {
  // Replace the memref load with an LLVM load of the computed pointer.
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, op.getType(), ptr);
}
|
|
|
|
template<>
Value MetaPointer2Memref<memref::StoreOp>::computeIndex(memref::StoreOp op, size_t i, PatternRewriter &rewriter) const {
  // memref.store carries its indices directly as operands.
  return op.indices()[i];
}
|
|
|
|
template<>
void MetaPointer2Memref<memref::StoreOp>::rewrite(memref::StoreOp op, Value ptr, PatternRewriter &rewriter) const {
  // Replace the memref store with an LLVM store through the computed pointer.
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, op.value(), ptr);
}
|
|
|
|
template<>
Value MetaPointer2Memref<AffineLoadOp>::computeIndex(AffineLoadOp op, size_t i, PatternRewriter &rewriter) const {
  // affine.load indexes through an affine map; materialize the i-th map
  // result with an affine.apply on the op's map operands.
  auto map = op.getAffineMap();
  auto apply = rewriter.create<AffineApplyOp>(
      op.getLoc(), map.getSliceMap(i, 1),
      op.getMapOperands());
  return apply->getResult(0);
}
|
|
|
|
template<>
void MetaPointer2Memref<AffineLoadOp>::rewrite(AffineLoadOp op, Value ptr, PatternRewriter &rewriter) const {
  // Replace the affine load with an LLVM load of the computed pointer.
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, op.getType(), ptr);
}
|
|
|
|
template<>
Value MetaPointer2Memref<AffineStoreOp>::computeIndex(AffineStoreOp op, size_t i, PatternRewriter &rewriter) const {
  // affine.store indexes through an affine map; materialize the i-th map
  // result with an affine.apply on the op's map operands.
  auto map = op.getAffineMap();
  auto apply = rewriter.create<AffineApplyOp>(
      op.getLoc(), map.getSliceMap(i, 1),
      op.getMapOperands());
  return apply->getResult(0);
}
|
|
|
|
template<>
void MetaPointer2Memref<AffineStoreOp>::rewrite(AffineStoreOp op, Value ptr, PatternRewriter &rewriter) const {
  // Replace the affine store with an LLVM store through the computed pointer.
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, op.value(), ptr);
}
|
|
|
|
void Pointer2MemrefOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  // Cast-peeling patterns plus direct lowerings of (affine) loads/stores
  // through a pointer2memref to the corresponding LLVM memory operations.
  results.insert<Pointer2MemrefCast, Pointer2Memref2PointerCast,
                 MetaPointer2Memref<memref::LoadOp>,
                 MetaPointer2Memref<memref::StoreOp>,
                 MetaPointer2Memref<AffineLoadOp>,
                 MetaPointer2Memref<AffineStoreOp>
                 >(context);
}
|
|
|
|
OpFoldResult Pointer2MemrefOp::fold(ArrayRef<Attribute> operands) {
  /// Simplify pointer2memref(bitcast(x)) to pointer2memref(x); the result
  /// type is unchanged, only the operand is rewritten in place.
  if (auto mc = source().getDefiningOp<LLVM::BitcastOp>()) {
    sourceMutable().assign(mc.getArg());
    return result();
  }
  /// Likewise look through address-space casts.
  /// NOTE(review): this drops the cast's address-space change from the
  /// operand — confirm the result memref's memory space is unaffected.
  if (auto mc = source().getDefiningOp<LLVM::AddrSpaceCastOp>()) {
    sourceMutable().assign(mc.getArg());
    return result();
  }
  /// Look through GEPs whose indices are all constant zero, since they do
  /// not move the pointer.
  if (auto mc = source().getDefiningOp<LLVM::GEPOp>()) {
    for (auto idx : mc.getIndices()) {
      if (!matchPattern(idx, m_Zero()))
        return nullptr;
    }
    sourceMutable().assign(mc.getBase());
    return result();
  }
  /// Simplify pointer2memref(memref2pointer(x)) to x when the original
  /// memref already has this op's result type.
  if (auto mc = source().getDefiningOp<polygeist::Memref2PointerOp>()) {
    if (mc.source().getType() == getType()) {
      return mc.source();
    }
  }
  return nullptr;
}
|
|
|
|
OpFoldResult SubIndexOp::fold(ArrayRef<Attribute> operands) {
  // subindex(x, 0) with an unchanged type is the identity.
  if (result().getType() == source().getType()) {
    if (matchPattern(index(), m_Zero()))
      return source();
  }
  /// Replace subindex(cast(x)) with subindex(x)
  /// NOTE(review): only the element types are compared here, not
  /// shapes/ranks — confirm the in-place operand swap always yields a
  /// type-correct subindex.
  if (auto castOp = source().getDefiningOp<memref::CastOp>()) {
    if (castOp.getType().cast<MemRefType>().getElementType() == result().getType().cast<MemRefType>().getElementType()) {
      sourceMutable().assign(castOp.source());
      return result();
    }
  }
  return nullptr;
}
|