[mlir][MemRef] Fix SubViewOp canonicalization when a subset of unit-dims are dropped.
The canonical type of the result of the `memref.subview` needs to make sure that the previously dropped unit-dimensions are the ones dropped for the canonicalized type as well. This means the generic `inferRankReducedResultType` cannot be used. Instead, the currently dropped dimensions need to be queried, and the same ones need to be dropped. Reviewed By: nicolasvasilache, ThomasRaoux Differential Revision: https://reviews.llvm.org/D114751
This commit is contained in:
parent
18308e171b
commit
311dd55c9e
|
@ -63,6 +63,8 @@ public:
|
||||||
ResultTypeFunc resultTypeFunc;
|
ResultTypeFunc resultTypeFunc;
|
||||||
auto resultType =
|
auto resultType =
|
||||||
resultTypeFunc(op, mixedOffsets, mixedSizes, mixedStrides);
|
resultTypeFunc(op, mixedOffsets, mixedSizes, mixedStrides);
|
||||||
|
if (!resultType)
|
||||||
|
return failure();
|
||||||
auto newOp =
|
auto newOp =
|
||||||
rewriter.create<OpType>(op.getLoc(), resultType, op.source(),
|
rewriter.create<OpType>(op.getLoc(), resultType, op.source(),
|
||||||
mixedOffsets, mixedSizes, mixedStrides);
|
mixedOffsets, mixedSizes, mixedStrides);
|
||||||
|
|
|
@ -511,14 +511,16 @@ static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
|
||||||
/// dimension is dropped the stride must be dropped too.
|
/// dimension is dropped the stride must be dropped too.
|
||||||
static llvm::Optional<llvm::SmallDenseSet<unsigned>>
|
static llvm::Optional<llvm::SmallDenseSet<unsigned>>
|
||||||
computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
|
computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
|
||||||
ArrayAttr staticSizes) {
|
ArrayRef<OpFoldResult> sizes) {
|
||||||
llvm::SmallDenseSet<unsigned> unusedDims;
|
llvm::SmallDenseSet<unsigned> unusedDims;
|
||||||
if (originalType.getRank() == reducedType.getRank())
|
if (originalType.getRank() == reducedType.getRank())
|
||||||
return unusedDims;
|
return unusedDims;
|
||||||
|
|
||||||
for (auto dim : llvm::enumerate(staticSizes))
|
for (auto dim : llvm::enumerate(sizes))
|
||||||
if (dim.value().cast<IntegerAttr>().getInt() == 1)
|
if (auto attr = dim.value().dyn_cast<Attribute>())
|
||||||
|
if (attr.cast<IntegerAttr>().getInt() == 1)
|
||||||
unusedDims.insert(dim.index());
|
unusedDims.insert(dim.index());
|
||||||
|
|
||||||
SmallVector<int64_t> originalStrides, candidateStrides;
|
SmallVector<int64_t> originalStrides, candidateStrides;
|
||||||
int64_t originalOffset, candidateOffset;
|
int64_t originalOffset, candidateOffset;
|
||||||
if (failed(
|
if (failed(
|
||||||
|
@ -574,7 +576,7 @@ llvm::SmallDenseSet<unsigned> SubViewOp::getDroppedDims() {
|
||||||
MemRefType sourceType = getSourceType();
|
MemRefType sourceType = getSourceType();
|
||||||
MemRefType resultType = getType();
|
MemRefType resultType = getType();
|
||||||
llvm::Optional<llvm::SmallDenseSet<unsigned>> unusedDims =
|
llvm::Optional<llvm::SmallDenseSet<unsigned>> unusedDims =
|
||||||
computeMemRefRankReductionMask(sourceType, resultType, static_sizes());
|
computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
|
||||||
assert(unusedDims && "unable to find unused dims of subview");
|
assert(unusedDims && "unable to find unused dims of subview");
|
||||||
return *unusedDims;
|
return *unusedDims;
|
||||||
}
|
}
|
||||||
|
@ -1718,7 +1720,7 @@ enum SubViewVerificationResult {
|
||||||
/// not matching dimension must be 1.
|
/// not matching dimension must be 1.
|
||||||
static SubViewVerificationResult
|
static SubViewVerificationResult
|
||||||
isRankReducedType(Type originalType, Type candidateReducedType,
|
isRankReducedType(Type originalType, Type candidateReducedType,
|
||||||
ArrayAttr staticSizes, std::string *errMsg = nullptr) {
|
ArrayRef<OpFoldResult> sizes, std::string *errMsg = nullptr) {
|
||||||
if (originalType == candidateReducedType)
|
if (originalType == candidateReducedType)
|
||||||
return SubViewVerificationResult::Success;
|
return SubViewVerificationResult::Success;
|
||||||
if (!originalType.isa<MemRefType>())
|
if (!originalType.isa<MemRefType>())
|
||||||
|
@ -1743,7 +1745,7 @@ isRankReducedType(Type originalType, Type candidateReducedType,
|
||||||
MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
|
MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
|
||||||
|
|
||||||
auto optionalUnusedDimsMask =
|
auto optionalUnusedDimsMask =
|
||||||
computeMemRefRankReductionMask(original, candidateReduced, staticSizes);
|
computeMemRefRankReductionMask(original, candidateReduced, sizes);
|
||||||
|
|
||||||
// Sizes cannot be matched in case empty vector is returned.
|
// Sizes cannot be matched in case empty vector is returned.
|
||||||
if (!optionalUnusedDimsMask.hasValue())
|
if (!optionalUnusedDimsMask.hasValue())
|
||||||
|
@ -1813,7 +1815,7 @@ static LogicalResult verify(SubViewOp op) {
|
||||||
|
|
||||||
std::string errMsg;
|
std::string errMsg;
|
||||||
auto result =
|
auto result =
|
||||||
isRankReducedType(expectedType, subViewType, op.static_sizes(), &errMsg);
|
isRankReducedType(expectedType, subViewType, op.getMixedSizes(), &errMsg);
|
||||||
return produceSubViewErrorMsg(result, op, expectedType, errMsg);
|
return produceSubViewErrorMsg(result, op, expectedType, errMsg);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1854,21 +1856,29 @@ SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
|
||||||
/// Infer the canonical type of the result of a subview operation. Returns a
|
/// Infer the canonical type of the result of a subview operation. Returns a
|
||||||
/// type with rank `resultRank` that is either the rank of the rank-reduced
|
/// type with rank `resultRank` that is either the rank of the rank-reduced
|
||||||
/// type, or the non-rank-reduced type.
|
/// type, or the non-rank-reduced type.
|
||||||
static MemRefType
|
static MemRefType getCanonicalSubViewResultType(
|
||||||
getCanonicalSubViewResultType(unsigned resultRank, MemRefType sourceType,
|
MemRefType currentResultType, MemRefType sourceType,
|
||||||
ArrayRef<OpFoldResult> mixedOffsets,
|
ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes,
|
||||||
ArrayRef<OpFoldResult> mixedSizes,
|
|
||||||
ArrayRef<OpFoldResult> mixedStrides) {
|
ArrayRef<OpFoldResult> mixedStrides) {
|
||||||
auto resultType =
|
auto nonRankReducedType = SubViewOp::inferResultType(sourceType, mixedOffsets,
|
||||||
SubViewOp::inferRankReducedResultType(
|
|
||||||
resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
|
|
||||||
.cast<MemRefType>();
|
|
||||||
if (resultType.getRank() != resultRank) {
|
|
||||||
resultType = SubViewOp::inferResultType(sourceType, mixedOffsets,
|
|
||||||
mixedSizes, mixedStrides)
|
mixedSizes, mixedStrides)
|
||||||
.cast<MemRefType>();
|
.cast<MemRefType>();
|
||||||
|
llvm::Optional<llvm::SmallDenseSet<unsigned>> unusedDims =
|
||||||
|
computeMemRefRankReductionMask(sourceType, currentResultType, mixedSizes);
|
||||||
|
// Return nullptr as failure mode.
|
||||||
|
if (!unusedDims)
|
||||||
|
return nullptr;
|
||||||
|
SmallVector<int64_t> shape;
|
||||||
|
for (auto sizes : llvm::enumerate(nonRankReducedType.getShape())) {
|
||||||
|
if (unusedDims->count(sizes.index()))
|
||||||
|
continue;
|
||||||
|
shape.push_back(sizes.value());
|
||||||
}
|
}
|
||||||
return resultType;
|
AffineMap layoutMap = nonRankReducedType.getLayout().getAffineMap();
|
||||||
|
if (!layoutMap.isIdentity())
|
||||||
|
layoutMap = getProjectedMap(layoutMap, unusedDims.getValue());
|
||||||
|
return MemRefType::get(shape, nonRankReducedType.getElementType(), layoutMap,
|
||||||
|
nonRankReducedType.getMemorySpace());
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -1911,8 +1921,7 @@ public:
|
||||||
/// the cast source operand type and the SubViewOp static information. This
|
/// the cast source operand type and the SubViewOp static information. This
|
||||||
/// is the resulting type if the MemRefCastOp were folded.
|
/// is the resulting type if the MemRefCastOp were folded.
|
||||||
auto resultType = getCanonicalSubViewResultType(
|
auto resultType = getCanonicalSubViewResultType(
|
||||||
subViewOp.getType().getRank(),
|
subViewOp.getType(), castOp.source().getType().cast<MemRefType>(),
|
||||||
castOp.source().getType().cast<MemRefType>(),
|
|
||||||
subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
|
subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
|
||||||
subViewOp.getMixedStrides());
|
subViewOp.getMixedStrides());
|
||||||
Value newSubView = rewriter.create<SubViewOp>(
|
Value newSubView = rewriter.create<SubViewOp>(
|
||||||
|
@ -1931,9 +1940,9 @@ struct SubViewReturnTypeCanonicalizer {
|
||||||
MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
|
MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
|
||||||
ArrayRef<OpFoldResult> mixedSizes,
|
ArrayRef<OpFoldResult> mixedSizes,
|
||||||
ArrayRef<OpFoldResult> mixedStrides) {
|
ArrayRef<OpFoldResult> mixedStrides) {
|
||||||
return getCanonicalSubViewResultType(op.getType().getRank(),
|
return getCanonicalSubViewResultType(op.getType(), op.getSourceType(),
|
||||||
op.getSourceType(), mixedOffsets,
|
mixedOffsets, mixedSizes,
|
||||||
mixedSizes, mixedStrides);
|
mixedStrides);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -47,7 +47,7 @@ func @subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index,
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
#map0 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
|
#map0 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
|
||||||
func @rank_reducing_subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index,
|
func @rank_reducing_subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index,
|
||||||
%arg2 : index) -> memref<?x?xf32, #map0>
|
%arg2 : index) -> memref<?x?xf32, #map0>
|
||||||
{
|
{
|
||||||
|
@ -395,3 +395,25 @@ func @collapse_after_memref_cast(%arg0 : memref<?x512x1x?xf32>) -> memref<?x?xf3
|
||||||
%collapsed = memref.collapse_shape %dynamic [[0], [1, 2, 3]] : memref<?x?x?x?xf32> into memref<?x?xf32>
|
%collapsed = memref.collapse_shape %dynamic [[0], [1, 2, 3]] : memref<?x?x?x?xf32> into memref<?x?xf32>
|
||||||
return %collapsed : memref<?x?xf32>
|
return %collapsed : memref<?x?xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @reduced_memref(%arg0: memref<2x5x7x1xf32>, %arg1 :index)
|
||||||
|
-> memref<1x4x1xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 35 + s0 + d1 * 7 + d2)>> {
|
||||||
|
%c0 = arith.constant 0 : index
|
||||||
|
%c5 = arith.constant 5 : index
|
||||||
|
%c4 = arith.constant 4 : index
|
||||||
|
%c2 = arith.constant 2 : index
|
||||||
|
%c1 = arith.constant 1 : index
|
||||||
|
%0 = memref.subview %arg0[%arg1, %arg1, %arg1, 0] [%c1, %c4, %c1, 1] [1, 1, 1, 1]
|
||||||
|
: memref<2x5x7x1xf32> to memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 35 + s0 + d1 * 7 + d2)>>
|
||||||
|
%1 = memref.cast %0
|
||||||
|
: memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 35 + s0 + d1 * 7 + d2)>> to
|
||||||
|
memref<1x4x1xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 35 + s0 + d1 * 7 + d2)>>
|
||||||
|
return %1 : memref<1x4x1xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 35 + s0 + d1 * 7 + d2)>>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @reduced_memref
|
||||||
|
// CHECK: %[[RESULT:.+]] = memref.subview
|
||||||
|
// CHECK-SAME: memref<2x5x7x1xf32> to memref<1x4x1xf32, #{{.+}}>
|
||||||
|
// CHECK: return %[[RESULT]]
|
||||||
|
|
Loading…
Reference in New Issue