[mlir][MemRef] Fix SubViewOp canonicalization when a subset of unit-dims are dropped.

The canonical type of the result of the `memref.subview` needs to make
sure that the previously dropped unit-dimensions are the ones dropped
for the canonicalized type as well. This means the generic
`inferRankReducedResultType` cannot be used. Instead, the currently
dropped dimensions need to be queried and the same ones need to be dropped.

Reviewed By: nicolasvasilache, ThomasRaoux

Differential Revision: https://reviews.llvm.org/D114751
This commit is contained in:
MaheshRavishankar 2021-11-30 15:46:21 +00:00 committed by Nicolas Vasilache
parent 18308e171b
commit 311dd55c9e
3 changed files with 61 additions and 28 deletions

View File

@@ -63,6 +63,8 @@ public:
ResultTypeFunc resultTypeFunc;
auto resultType =
resultTypeFunc(op, mixedOffsets, mixedSizes, mixedStrides);
if (!resultType)
return failure();
auto newOp =
rewriter.create<OpType>(op.getLoc(), resultType, op.source(),
mixedOffsets, mixedSizes, mixedStrides);

View File

@@ -511,14 +511,16 @@ static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
/// dimension is dropped the stride must be dropped too.
static llvm::Optional<llvm::SmallDenseSet<unsigned>>
computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
ArrayAttr staticSizes) {
ArrayRef<OpFoldResult> sizes) {
llvm::SmallDenseSet<unsigned> unusedDims;
if (originalType.getRank() == reducedType.getRank())
return unusedDims;
for (auto dim : llvm::enumerate(staticSizes))
if (dim.value().cast<IntegerAttr>().getInt() == 1)
unusedDims.insert(dim.index());
for (auto dim : llvm::enumerate(sizes))
if (auto attr = dim.value().dyn_cast<Attribute>())
if (attr.cast<IntegerAttr>().getInt() == 1)
unusedDims.insert(dim.index());
SmallVector<int64_t> originalStrides, candidateStrides;
int64_t originalOffset, candidateOffset;
if (failed(
@@ -574,7 +576,7 @@ llvm::SmallDenseSet<unsigned> SubViewOp::getDroppedDims() {
MemRefType sourceType = getSourceType();
MemRefType resultType = getType();
llvm::Optional<llvm::SmallDenseSet<unsigned>> unusedDims =
computeMemRefRankReductionMask(sourceType, resultType, static_sizes());
computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
assert(unusedDims && "unable to find unused dims of subview");
return *unusedDims;
}
@@ -1718,7 +1720,7 @@ enum SubViewVerificationResult {
/// not matching dimension must be 1.
static SubViewVerificationResult
isRankReducedType(Type originalType, Type candidateReducedType,
ArrayAttr staticSizes, std::string *errMsg = nullptr) {
ArrayRef<OpFoldResult> sizes, std::string *errMsg = nullptr) {
if (originalType == candidateReducedType)
return SubViewVerificationResult::Success;
if (!originalType.isa<MemRefType>())
@@ -1743,7 +1745,7 @@ isRankReducedType(Type originalType, Type candidateReducedType,
MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
auto optionalUnusedDimsMask =
computeMemRefRankReductionMask(original, candidateReduced, staticSizes);
computeMemRefRankReductionMask(original, candidateReduced, sizes);
// Sizes cannot be matched in case empty vector is returned.
if (!optionalUnusedDimsMask.hasValue())
@@ -1813,7 +1815,7 @@ static LogicalResult verify(SubViewOp op) {
std::string errMsg;
auto result =
isRankReducedType(expectedType, subViewType, op.static_sizes(), &errMsg);
isRankReducedType(expectedType, subViewType, op.getMixedSizes(), &errMsg);
return produceSubViewErrorMsg(result, op, expectedType, errMsg);
}
@@ -1854,21 +1856,29 @@ SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
/// Infer the canonical type of the result of a subview operation. Returns a
/// type with rank `resultRank` that is either the rank of the rank-reduced
/// type, or the non-rank-reduced type.
static MemRefType
getCanonicalSubViewResultType(unsigned resultRank, MemRefType sourceType,
ArrayRef<OpFoldResult> mixedOffsets,
ArrayRef<OpFoldResult> mixedSizes,
ArrayRef<OpFoldResult> mixedStrides) {
auto resultType =
SubViewOp::inferRankReducedResultType(
resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
.cast<MemRefType>();
if (resultType.getRank() != resultRank) {
resultType = SubViewOp::inferResultType(sourceType, mixedOffsets,
mixedSizes, mixedStrides)
.cast<MemRefType>();
static MemRefType getCanonicalSubViewResultType(
MemRefType currentResultType, MemRefType sourceType,
ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes,
ArrayRef<OpFoldResult> mixedStrides) {
auto nonRankReducedType = SubViewOp::inferResultType(sourceType, mixedOffsets,
mixedSizes, mixedStrides)
.cast<MemRefType>();
llvm::Optional<llvm::SmallDenseSet<unsigned>> unusedDims =
computeMemRefRankReductionMask(sourceType, currentResultType, mixedSizes);
// Return nullptr as failure mode.
if (!unusedDims)
return nullptr;
SmallVector<int64_t> shape;
for (auto sizes : llvm::enumerate(nonRankReducedType.getShape())) {
if (unusedDims->count(sizes.index()))
continue;
shape.push_back(sizes.value());
}
return resultType;
AffineMap layoutMap = nonRankReducedType.getLayout().getAffineMap();
if (!layoutMap.isIdentity())
layoutMap = getProjectedMap(layoutMap, unusedDims.getValue());
return MemRefType::get(shape, nonRankReducedType.getElementType(), layoutMap,
nonRankReducedType.getMemorySpace());
}
namespace {
@@ -1911,8 +1921,7 @@ public:
/// the cast source operand type and the SubViewOp static information. This
/// is the resulting type if the MemRefCastOp were folded.
auto resultType = getCanonicalSubViewResultType(
subViewOp.getType().getRank(),
castOp.source().getType().cast<MemRefType>(),
subViewOp.getType(), castOp.source().getType().cast<MemRefType>(),
subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
subViewOp.getMixedStrides());
Value newSubView = rewriter.create<SubViewOp>(
@@ -1931,9 +1940,9 @@ struct SubViewReturnTypeCanonicalizer {
MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
ArrayRef<OpFoldResult> mixedSizes,
ArrayRef<OpFoldResult> mixedStrides) {
return getCanonicalSubViewResultType(op.getType().getRank(),
op.getSourceType(), mixedOffsets,
mixedSizes, mixedStrides);
return getCanonicalSubViewResultType(op.getType(), op.getSourceType(),
mixedOffsets, mixedSizes,
mixedStrides);
}
};

View File

@@ -47,7 +47,7 @@ func @subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index,
// -----
#map0 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
#map0 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
func @rank_reducing_subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index,
%arg2 : index) -> memref<?x?xf32, #map0>
{
@@ -395,3 +395,25 @@ func @collapse_after_memref_cast(%arg0 : memref<?x512x1x?xf32>) -> memref<?x?xf3
%collapsed = memref.collapse_shape %dynamic [[0], [1, 2, 3]] : memref<?x?x?x?xf32> into memref<?x?xf32>
return %collapsed : memref<?x?xf32>
}
// -----
// Regression test: canonicalizing a subview whose cast result drops only a
// subset of its unit dimensions. The source is 4-D (2x5x7x1) and the subview
// result is 3-D, so exactly one dimension is rank-reduced — the trailing
// static unit size (the literal 1); the %c1-sized dimensions are dynamic and
// are kept. Folding the memref.cast into the subview must drop that same
// dimension in the canonicalized result type (see CHECK lines below).
func @reduced_memref(%arg0: memref<2x5x7x1xf32>, %arg1 :index)
-> memref<1x4x1xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 35 + s0 + d1 * 7 + d2)>> {
%c0 = arith.constant 0 : index
%c5 = arith.constant 5 : index
%c4 = arith.constant 4 : index
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
%0 = memref.subview %arg0[%arg1, %arg1, %arg1, 0] [%c1, %c4, %c1, 1] [1, 1, 1, 1]
: memref<2x5x7x1xf32> to memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 35 + s0 + d1 * 7 + d2)>>
%1 = memref.cast %0
: memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 35 + s0 + d1 * 7 + d2)>> to
memref<1x4x1xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 35 + s0 + d1 * 7 + d2)>>
return %1 : memref<1x4x1xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 35 + s0 + d1 * 7 + d2)>>
}
// CHECK-LABEL: func @reduced_memref
// CHECK: %[[RESULT:.+]] = memref.subview
// CHECK-SAME: memref<2x5x7x1xf32> to memref<1x4x1xf32, #{{.+}}>
// CHECK: return %[[RESULT]]