[mlir][MemRef] Fix SubViewOp canonicalization when a subset of unit-dims are dropped.
The canonical type of the result of the `memref.subview` needs to make sure that the previously dropped unit-dimensions are the ones dropped for the canonicalized type as well. This means the generic `inferRankReducedResultType` cannot be used. Instead, the currently dropped dimensions need to be queried, and the same ones need to be dropped. Reviewed By: nicolasvasilache, ThomasRaoux Differential Revision: https://reviews.llvm.org/D114751
This commit is contained in:
parent
18308e171b
commit
311dd55c9e
|
@ -63,6 +63,8 @@ public:
|
|||
ResultTypeFunc resultTypeFunc;
|
||||
auto resultType =
|
||||
resultTypeFunc(op, mixedOffsets, mixedSizes, mixedStrides);
|
||||
if (!resultType)
|
||||
return failure();
|
||||
auto newOp =
|
||||
rewriter.create<OpType>(op.getLoc(), resultType, op.source(),
|
||||
mixedOffsets, mixedSizes, mixedStrides);
|
||||
|
|
|
@ -511,14 +511,16 @@ static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
|
|||
/// dimension is dropped the stride must be dropped too.
|
||||
static llvm::Optional<llvm::SmallDenseSet<unsigned>>
|
||||
computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
|
||||
ArrayAttr staticSizes) {
|
||||
ArrayRef<OpFoldResult> sizes) {
|
||||
llvm::SmallDenseSet<unsigned> unusedDims;
|
||||
if (originalType.getRank() == reducedType.getRank())
|
||||
return unusedDims;
|
||||
|
||||
for (auto dim : llvm::enumerate(staticSizes))
|
||||
if (dim.value().cast<IntegerAttr>().getInt() == 1)
|
||||
for (auto dim : llvm::enumerate(sizes))
|
||||
if (auto attr = dim.value().dyn_cast<Attribute>())
|
||||
if (attr.cast<IntegerAttr>().getInt() == 1)
|
||||
unusedDims.insert(dim.index());
|
||||
|
||||
SmallVector<int64_t> originalStrides, candidateStrides;
|
||||
int64_t originalOffset, candidateOffset;
|
||||
if (failed(
|
||||
|
@ -574,7 +576,7 @@ llvm::SmallDenseSet<unsigned> SubViewOp::getDroppedDims() {
|
|||
MemRefType sourceType = getSourceType();
|
||||
MemRefType resultType = getType();
|
||||
llvm::Optional<llvm::SmallDenseSet<unsigned>> unusedDims =
|
||||
computeMemRefRankReductionMask(sourceType, resultType, static_sizes());
|
||||
computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
|
||||
assert(unusedDims && "unable to find unused dims of subview");
|
||||
return *unusedDims;
|
||||
}
|
||||
|
@ -1718,7 +1720,7 @@ enum SubViewVerificationResult {
|
|||
/// not matching dimension must be 1.
|
||||
static SubViewVerificationResult
|
||||
isRankReducedType(Type originalType, Type candidateReducedType,
|
||||
ArrayAttr staticSizes, std::string *errMsg = nullptr) {
|
||||
ArrayRef<OpFoldResult> sizes, std::string *errMsg = nullptr) {
|
||||
if (originalType == candidateReducedType)
|
||||
return SubViewVerificationResult::Success;
|
||||
if (!originalType.isa<MemRefType>())
|
||||
|
@ -1743,7 +1745,7 @@ isRankReducedType(Type originalType, Type candidateReducedType,
|
|||
MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
|
||||
|
||||
auto optionalUnusedDimsMask =
|
||||
computeMemRefRankReductionMask(original, candidateReduced, staticSizes);
|
||||
computeMemRefRankReductionMask(original, candidateReduced, sizes);
|
||||
|
||||
// Sizes cannot be matched in case empty vector is returned.
|
||||
if (!optionalUnusedDimsMask.hasValue())
|
||||
|
@ -1813,7 +1815,7 @@ static LogicalResult verify(SubViewOp op) {
|
|||
|
||||
std::string errMsg;
|
||||
auto result =
|
||||
isRankReducedType(expectedType, subViewType, op.static_sizes(), &errMsg);
|
||||
isRankReducedType(expectedType, subViewType, op.getMixedSizes(), &errMsg);
|
||||
return produceSubViewErrorMsg(result, op, expectedType, errMsg);
|
||||
}
|
||||
|
||||
|
@ -1854,21 +1856,29 @@ SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
|
|||
/// Infer the canonical type of the result of a subview operation. Returns a
|
||||
/// type with rank `resultRank` that is either the rank of the rank-reduced
|
||||
/// type, or the non-rank-reduced type.
|
||||
static MemRefType
|
||||
getCanonicalSubViewResultType(unsigned resultRank, MemRefType sourceType,
|
||||
ArrayRef<OpFoldResult> mixedOffsets,
|
||||
ArrayRef<OpFoldResult> mixedSizes,
|
||||
static MemRefType getCanonicalSubViewResultType(
|
||||
MemRefType currentResultType, MemRefType sourceType,
|
||||
ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes,
|
||||
ArrayRef<OpFoldResult> mixedStrides) {
|
||||
auto resultType =
|
||||
SubViewOp::inferRankReducedResultType(
|
||||
resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
|
||||
.cast<MemRefType>();
|
||||
if (resultType.getRank() != resultRank) {
|
||||
resultType = SubViewOp::inferResultType(sourceType, mixedOffsets,
|
||||
auto nonRankReducedType = SubViewOp::inferResultType(sourceType, mixedOffsets,
|
||||
mixedSizes, mixedStrides)
|
||||
.cast<MemRefType>();
|
||||
llvm::Optional<llvm::SmallDenseSet<unsigned>> unusedDims =
|
||||
computeMemRefRankReductionMask(sourceType, currentResultType, mixedSizes);
|
||||
// Return nullptr as failure mode.
|
||||
if (!unusedDims)
|
||||
return nullptr;
|
||||
SmallVector<int64_t> shape;
|
||||
for (auto sizes : llvm::enumerate(nonRankReducedType.getShape())) {
|
||||
if (unusedDims->count(sizes.index()))
|
||||
continue;
|
||||
shape.push_back(sizes.value());
|
||||
}
|
||||
return resultType;
|
||||
AffineMap layoutMap = nonRankReducedType.getLayout().getAffineMap();
|
||||
if (!layoutMap.isIdentity())
|
||||
layoutMap = getProjectedMap(layoutMap, unusedDims.getValue());
|
||||
return MemRefType::get(shape, nonRankReducedType.getElementType(), layoutMap,
|
||||
nonRankReducedType.getMemorySpace());
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -1911,8 +1921,7 @@ public:
|
|||
/// the cast source operand type and the SubViewOp static information. This
|
||||
/// is the resulting type if the MemRefCastOp were folded.
|
||||
auto resultType = getCanonicalSubViewResultType(
|
||||
subViewOp.getType().getRank(),
|
||||
castOp.source().getType().cast<MemRefType>(),
|
||||
subViewOp.getType(), castOp.source().getType().cast<MemRefType>(),
|
||||
subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
|
||||
subViewOp.getMixedStrides());
|
||||
Value newSubView = rewriter.create<SubViewOp>(
|
||||
|
@ -1931,9 +1940,9 @@ struct SubViewReturnTypeCanonicalizer {
|
|||
MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
|
||||
ArrayRef<OpFoldResult> mixedSizes,
|
||||
ArrayRef<OpFoldResult> mixedStrides) {
|
||||
return getCanonicalSubViewResultType(op.getType().getRank(),
|
||||
op.getSourceType(), mixedOffsets,
|
||||
mixedSizes, mixedStrides);
|
||||
return getCanonicalSubViewResultType(op.getType(), op.getSourceType(),
|
||||
mixedOffsets, mixedSizes,
|
||||
mixedStrides);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -47,7 +47,7 @@ func @subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index,
|
|||
|
||||
// -----
|
||||
|
||||
#map0 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
|
||||
#map0 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
|
||||
func @rank_reducing_subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index,
|
||||
%arg2 : index) -> memref<?x?xf32, #map0>
|
||||
{
|
||||
|
@ -395,3 +395,25 @@ func @collapse_after_memref_cast(%arg0 : memref<?x512x1x?xf32>) -> memref<?x?xf3
|
|||
%collapsed = memref.collapse_shape %dynamic [[0], [1, 2, 3]] : memref<?x?x?x?xf32> into memref<?x?xf32>
|
||||
return %collapsed : memref<?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @reduced_memref(%arg0: memref<2x5x7x1xf32>, %arg1 :index)
|
||||
-> memref<1x4x1xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 35 + s0 + d1 * 7 + d2)>> {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c5 = arith.constant 5 : index
|
||||
%c4 = arith.constant 4 : index
|
||||
%c2 = arith.constant 2 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%0 = memref.subview %arg0[%arg1, %arg1, %arg1, 0] [%c1, %c4, %c1, 1] [1, 1, 1, 1]
|
||||
: memref<2x5x7x1xf32> to memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 35 + s0 + d1 * 7 + d2)>>
|
||||
%1 = memref.cast %0
|
||||
: memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 35 + s0 + d1 * 7 + d2)>> to
|
||||
memref<1x4x1xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 35 + s0 + d1 * 7 + d2)>>
|
||||
return %1 : memref<1x4x1xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 35 + s0 + d1 * 7 + d2)>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @reduced_memref
|
||||
// CHECK: %[[RESULT:.+]] = memref.subview
|
||||
// CHECK-SAME: memref<2x5x7x1xf32> to memref<1x4x1xf32, #{{.+}}>
|
||||
// CHECK: return %[[RESULT]]
|
||||
|
|
Loading…
Reference in New Issue