[mlir] Use SmallBitVector instead of SmallDenseSet for AffineMap::compressSymbols

This is both more efficient and more ergonomic to use, as inverting a
bit vector is trivial while inverting a set is annoying.
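
For illustration only (not part of this patch), a minimal sketch of that
inversion argument; the helper names and `numDims` are hypothetical:

  // Hypothetical sketch: complementing a "used dimensions" mask.
  #include "llvm/ADT/DenseSet.h"
  #include "llvm/ADT/SmallBitVector.h"

  llvm::SmallBitVector invertMask(const llvm::SmallBitVector &used) {
    llvm::SmallBitVector unused = used;
    unused.flip(); // one call inverts every bit
    return unused;
  }

  llvm::SmallDenseSet<unsigned>
  invertSet(const llvm::SmallDenseSet<unsigned> &used, unsigned numDims) {
    // A set needs the universe size and an explicit membership loop.
    llvm::SmallDenseSet<unsigned> unused;
    for (unsigned d = 0; d < numDims; ++d)
      if (!used.contains(d))
        unused.insert(d);
    return unused;
  }

compressUnusedDims/compressUnusedSymbols below avoid even the flip() by
starting from an all-true vector and reset()ing the used positions.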

Sadly this leaks into a bunch of downstream APIs, which are adapted as well.
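
The caller-side change is mostly mechanical (contains -> test, insert -> set).
A hedged sketch of the pattern, with a hypothetical keepLiveDims helper
mirroring the projected-shape loops below:

  // Hypothetical helper: keep only the dimensions that are not dropped.
  // Assumes `dropped` is sized to shape.size().
  #include "llvm/ADT/ArrayRef.h"
  #include "llvm/ADT/SmallBitVector.h"
  #include "llvm/ADT/SmallVector.h"
  #include <cstdint>

  llvm::SmallVector<int64_t, 4>
  keepLiveDims(llvm::ArrayRef<int64_t> shape,
               const llvm::SmallBitVector &dropped) {
    llvm::SmallVector<int64_t, 4> kept;
    for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
      if (!dropped.test(pos)) // was: !dropped.contains(pos) on a SmallDenseSet
        kept.push_back(shape[pos]);
    return kept;
  }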

This would be NFC, but there is an ordering dependency in MemRefOps's
computeMemRefRankReductionMask: its result is now deterministic, whereas
previously it depended on SmallDenseSet's unspecified iteration order.
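
For illustration only, a minimal sketch of why the bit vector is deterministic
here: set bits are visited in increasing index order, whereas SmallDenseSet
guarantees no particular iteration order.

  // Hypothetical example: set-bit iteration is index-ordered.
  #include "llvm/ADT/SmallBitVector.h"
  #include <cstdio>

  void printDroppedDims(const llvm::SmallBitVector &dropped) {
    for (int dim = dropped.find_first(); dim != -1;
         dim = dropped.find_next(dim))
      std::printf("%d ", dim); // always ascending
    std::printf("\n");
  }

The pruning loop in computeMemRefRankReductionMask below achieves the same by
testing every index in order with test(dim).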

Differential Revision: https://reviews.llvm.org/D119076
Author: Benjamin Kramer
Date:   2022-02-06 14:06:34 +01:00
Commit: 6635c12ada (parent: 330838eb90)

13 changed files with 71 additions and 76 deletions

@@ -31,8 +31,8 @@ detail::op_matcher<arith::ConstantIndexOp> matchConstantIndex();
void canonicalizeSubViewPart(SmallVectorImpl<OpFoldResult> &values,
function_ref<bool(int64_t)> isDynamic);
void getPositionsOfShapeOne(unsigned rank, ArrayRef<int64_t> shape,
llvm::SmallDenseSet<unsigned> &dimsToProject);
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank,
ArrayRef<int64_t> shape);
/// Pattern to rewrite a subview op with constant arguments.
template <typename OpType, typename ResultTypeFunc, typename CastOpFunc>

@@ -1639,7 +1639,7 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
/// Return the dimensions of the source type that are dropped when
/// the result is rank-reduced.
llvm::SmallDenseSet<unsigned> getDroppedDims();
llvm::SmallBitVector getDroppedDims();
}];
let hasCanonicalizer = 1;

@@ -316,8 +316,7 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
/// Return the dimensions of the source that are dropped in the
/// result when the result is rank-reduced.
llvm::SmallDenseSet<unsigned> getDroppedDims();
llvm::SmallBitVector getDroppedDims();
}];
let hasCanonicalizer = 1;

@@ -18,7 +18,10 @@
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/ADT/DenseSet.h"
namespace llvm {
class SmallBitVector;
} // namespace llvm
namespace mlir {
@@ -372,8 +375,7 @@ AffineMap compressUnusedDims(AffineMap map);
SmallVector<AffineMap> compressUnusedDims(ArrayRef<AffineMap> maps);
/// Drop the dims that are not listed in `unusedDims`.
AffineMap compressDims(AffineMap map,
const llvm::SmallDenseSet<unsigned> &unusedDims);
AffineMap compressDims(AffineMap map, const llvm::SmallBitVector &unusedDims);
/// Drop the symbols that are not used.
AffineMap compressUnusedSymbols(AffineMap map);
@@ -385,7 +387,7 @@ SmallVector<AffineMap> compressUnusedSymbols(ArrayRef<AffineMap> maps);
/// Drop the symbols that are not listed in `unusedSymbols`.
AffineMap compressSymbols(AffineMap map,
const llvm::SmallDenseSet<unsigned> &unusedSymbols);
const llvm::SmallBitVector &unusedSymbols);
/// Returns a map with the same dimension and symbol count as `map`, but whose
/// results are the unique affine expressions of `map`.
@@ -521,9 +523,8 @@ AffineMap concatAffineMaps(ArrayRef<AffineMap> maps);
/// result : affine_map<(d0, d1) -> (d0, 0)>
///
/// This function also compresses unused symbols away.
AffineMap
getProjectedMap(AffineMap map,
const llvm::SmallDenseSet<unsigned> &projectedDimensions);
AffineMap getProjectedMap(AffineMap map,
const llvm::SmallBitVector &projectedDimensions);
/// Apply a permutation from `map` to `source` and return the result.
template <typename T>

@@ -18,6 +18,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "llvm/ADT/SmallBitVector.h"
using namespace mlir;
@@ -1503,10 +1504,10 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
assert(mixedSizes.size() == mixedStrides.size() &&
"expected sizes and strides of equal length");
llvm::SmallDenseSet<unsigned> unusedDims = subViewOp.getDroppedDims();
llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
i >= 0 && j >= 0; --i) {
if (unusedDims.contains(i))
if (unusedDims.test(i))
continue;
// `i` may overflow subViewOp.getMixedSizes because of trailing semantics.

@@ -12,6 +12,7 @@
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "llvm/ADT/SmallBitVector.h"
using namespace mlir;
@@ -37,16 +38,16 @@ void mlir::canonicalizeSubViewPart(
}
}
void mlir::getPositionsOfShapeOne(
unsigned rank, ArrayRef<int64_t> shape,
llvm::SmallDenseSet<unsigned> &dimsToProject) {
dimsToProject.reserve(rank);
llvm::SmallBitVector mlir::getPositionsOfShapeOne(unsigned rank,
ArrayRef<int64_t> shape) {
llvm::SmallBitVector dimsToProject(shape.size());
for (unsigned pos = 0, e = shape.size(); pos < e && rank > 0; ++pos) {
if (shape[pos] == 1) {
dimsToProject.insert(pos);
dimsToProject.set(pos);
--rank;
}
}
return dimsToProject;
}
Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,

@@ -520,10 +520,10 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
/// Prune all dimensions that are of reduction iterator type from `map`.
static AffineMap pruneReductionDimsFromMap(ArrayRef<Attribute> iteratorTypes,
AffineMap map) {
llvm::SmallDenseSet<unsigned> projectedDims;
llvm::SmallBitVector projectedDims(iteratorTypes.size());
for (const auto &attr : llvm::enumerate(iteratorTypes)) {
if (!isParallelIterator(attr.value()))
projectedDims.insert(attr.index());
projectedDims.set(attr.index());
}
return getProjectedMap(map, projectedDims);
}

@@ -187,7 +187,7 @@ static LogicalResult padOperandToSmallestStaticBoundingBox(
return failure(hasDynamicShape);
// Compute the dropped dimensions if `sliceOp` is rank-reducing.
llvm::SmallDenseSet<unsigned> droppedDims = sliceOp.getDroppedDims();
llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();
// Upper bound the `sliceOp` sizes to obtain a static bounding box.
SmallVector<int64_t> staticSizes;
@@ -195,7 +195,7 @@ static LogicalResult padOperandToSmallestStaticBoundingBox(
auto shapedOp = cast<OffsetSizeAndStrideOpInterface>(sliceOp.getOperation());
for (const auto &en : enumerate(shapedOp.getMixedSizes())) {
// Skip dropped dimensions.
if (droppedDims.contains(en.index()))
if (droppedDims.test(en.index()))
continue;
// If the size is an attribute add it directly to `staticSizes`.
if (en.value().is<Attribute>()) {

@@ -20,6 +20,7 @@
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
using namespace mlir;
using namespace mlir::memref;
@@ -590,17 +591,17 @@ static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
/// This accounts for cases where there are multiple unit-dims, but only a
/// subset of those are dropped. For MemRefTypes these can be disambiguated
/// using the strides. If a dimension is dropped the stride must be dropped too.
static llvm::Optional<llvm::SmallDenseSet<unsigned>>
static llvm::Optional<llvm::SmallBitVector>
computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
ArrayRef<OpFoldResult> sizes) {
llvm::SmallDenseSet<unsigned> unusedDims;
llvm::SmallBitVector unusedDims(originalType.getRank());
if (originalType.getRank() == reducedType.getRank())
return unusedDims;
for (const auto &dim : llvm::enumerate(sizes))
if (auto attr = dim.value().dyn_cast<Attribute>())
if (attr.cast<IntegerAttr>().getInt() == 1)
unusedDims.insert(dim.index());
unusedDims.set(dim.index());
SmallVector<int64_t> originalStrides, candidateStrides;
int64_t originalOffset, candidateOffset;
@@ -623,8 +624,9 @@ computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
getNumOccurences(originalStrides);
std::map<int64_t, unsigned> candidateStridesNumOccurences =
getNumOccurences(candidateStrides);
llvm::SmallDenseSet<unsigned> prunedUnusedDims;
for (unsigned dim : unusedDims) {
for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
if (!unusedDims.test(dim))
continue;
int64_t originalStride = originalStrides[dim];
if (currUnaccountedStrides[originalStride] >
candidateStridesNumOccurences[originalStride]) {
@@ -635,7 +637,7 @@ computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
if (currUnaccountedStrides[originalStride] ==
candidateStridesNumOccurences[originalStride]) {
// The stride for this is not dropped. Keep as is.
prunedUnusedDims.insert(dim);
unusedDims.reset(dim);
continue;
}
if (currUnaccountedStrides[originalStride] <
@@ -646,17 +648,16 @@ computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
}
}
for (auto prunedDim : prunedUnusedDims)
unusedDims.erase(prunedDim);
if (unusedDims.size() + reducedType.getRank() != originalType.getRank())
if ((int64_t)unusedDims.count() + reducedType.getRank() !=
originalType.getRank())
return llvm::None;
return unusedDims;
}
llvm::SmallDenseSet<unsigned> SubViewOp::getDroppedDims() {
llvm::SmallBitVector SubViewOp::getDroppedDims() {
MemRefType sourceType = getSourceType();
MemRefType resultType = getType();
llvm::Optional<llvm::SmallDenseSet<unsigned>> unusedDims =
llvm::Optional<llvm::SmallBitVector> unusedDims =
computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
assert(unusedDims && "unable to find unused dims of subview");
return *unusedDims;
@@ -698,12 +699,12 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
memrefType.getDynamicDimIndex(unsignedIndex));
if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
llvm::SmallDenseSet<unsigned> unusedDims = subview.getDroppedDims();
llvm::SmallBitVector unusedDims = subview.getDroppedDims();
unsigned resultIndex = 0;
unsigned sourceRank = subview.getSourceType().getRank();
unsigned sourceIndex = 0;
for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
if (unusedDims.count(i))
if (unusedDims.test(i))
continue;
if (resultIndex == unsignedIndex) {
sourceIndex = i;
@@ -1734,11 +1735,11 @@ Type SubViewOp::inferRankReducedResultType(unsigned resultRank,
int rankDiff = inferredType.getRank() - resultRank;
if (rankDiff > 0) {
auto shape = inferredType.getShape();
llvm::SmallDenseSet<unsigned> dimsToProject;
mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject);
llvm::SmallBitVector dimsToProject =
getPositionsOfShapeOne(rankDiff, shape);
SmallVector<int64_t> projectedShape;
for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
if (!dimsToProject.contains(pos))
if (!dimsToProject.test(pos))
projectedShape.push_back(shape[pos]);
AffineMap map = inferredType.getLayout().getAffineMap();
@@ -2015,7 +2016,7 @@ static MemRefType getCanonicalSubViewResultType(
auto nonRankReducedType = SubViewOp::inferResultType(sourceType, mixedOffsets,
mixedSizes, mixedStrides)
.cast<MemRefType>();
llvm::Optional<llvm::SmallDenseSet<unsigned>> unusedDims =
llvm::Optional<llvm::SmallBitVector> unusedDims =
computeMemRefRankReductionMask(currentSourceType, currentResultType,
mixedSizes);
// Return nullptr as failure mode.
@@ -2023,7 +2024,7 @@ static MemRefType getCanonicalSubViewResultType(
return nullptr;
SmallVector<int64_t> shape;
for (const auto &sizes : llvm::enumerate(nonRankReducedType.getShape())) {
if (unusedDims->count(sizes.index()))
if (unusedDims->test(sizes.index()))
continue;
shape.push_back(sizes.value());
}

@@ -18,6 +18,7 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallBitVector.h"
using namespace mlir;
@@ -50,9 +51,9 @@ resolveSourceIndices(Location loc, PatternRewriter &rewriter,
// Check if this is rank-reducing case. Then for every unit-dim size add a
// zero to the indices.
unsigned resultDim = 0;
llvm::SmallDenseSet<unsigned> unusedDims = subViewOp.getDroppedDims();
llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
for (auto dim : llvm::seq<unsigned>(0, subViewOp.getSourceType().getRank())) {
if (unusedDims.count(dim))
if (unusedDims.test(dim))
useIndices.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0));
else
useIndices.push_back(indices[resultDim++]);
@@ -106,11 +107,11 @@ static Value getMemRefOperand(vector::TransferWriteOp op) {
static AffineMapAttr getPermutationMapAttr(MLIRContext *context,
memref::SubViewOp subViewOp,
AffineMap currPermutationMap) {
llvm::SmallDenseSet<unsigned> unusedDims = subViewOp.getDroppedDims();
llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
SmallVector<AffineExpr> exprs;
int64_t sourceRank = subViewOp.getSourceType().getRank();
for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
if (unusedDims.count(dim))
if (unusedDims.test(dim))
continue;
exprs.push_back(getAffineDimExpr(dim, context));
}

@@ -19,6 +19,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
using namespace mlir;
using namespace mlir::tensor;
@@ -935,11 +936,11 @@ RankedTensorType ExtractSliceOp::inferRankReducedResultType(
int rankDiff = inferredType.getRank() - resultRank;
if (rankDiff > 0) {
auto shape = inferredType.getShape();
llvm::SmallDenseSet<unsigned> dimsToProject;
mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject);
llvm::SmallBitVector dimsToProject =
getPositionsOfShapeOne(rankDiff, shape);
SmallVector<int64_t> projectedShape;
for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
if (!dimsToProject.contains(pos))
if (!dimsToProject.test(pos))
projectedShape.push_back(shape[pos]);
inferredType =
RankedTensorType::get(projectedShape, inferredType.getElementType());
@@ -1076,10 +1077,10 @@ getCanonicalSliceResultType(unsigned resultRank, RankedTensorType sourceType,
return resultType;
}
llvm::SmallDenseSet<unsigned> ExtractSliceOp::getDroppedDims() {
llvm::SmallDenseSet<unsigned> droppedDims;
llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
ArrayRef<int64_t> resultShape = getType().getShape();
SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
llvm::SmallBitVector droppedDims(mixedSizes.size());
unsigned shapePos = 0;
for (const auto &size : enumerate(mixedSizes)) {
Optional<int64_t> sizeVal = getConstantIntValue(size.value());
@@ -1091,7 +1092,7 @@ llvm::SmallDenseSet<unsigned> ExtractSliceOp::getDroppedDims() {
shapePos++;
continue;
}
droppedDims.insert(size.index());
droppedDims.set(size.index());
}
return droppedDims;
}
@@ -1101,10 +1102,10 @@ LogicalResult ExtractSliceOp::reifyResultShapes(
reifiedReturnShapes.resize(1);
reifiedReturnShapes[0].reserve(getType().getRank());
SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
llvm::SmallDenseSet<unsigned> droppedDims = getDroppedDims();
llvm::SmallBitVector droppedDims = getDroppedDims();
Location loc = getLoc();
for (const auto &size : enumerate(mixedSizes)) {
if (droppedDims.count(size.index()))
if (droppedDims.test(size.index()))
continue;
if (auto attr = size.value().dyn_cast<Attribute>()) {
reifiedReturnShapes[0].push_back(builder.create<arith::ConstantIndexOp>(

@@ -547,13 +547,13 @@ AffineMap AffineMap::getMinorSubMap(unsigned numResults) const {
}
AffineMap mlir::compressDims(AffineMap map,
const llvm::SmallDenseSet<unsigned> &unusedDims) {
const llvm::SmallBitVector &unusedDims) {
unsigned numDims = 0;
SmallVector<AffineExpr> dimReplacements;
dimReplacements.reserve(map.getNumDims());
MLIRContext *context = map.getContext();
for (unsigned dim = 0, e = map.getNumDims(); dim < e; ++dim) {
if (unusedDims.contains(dim))
if (unusedDims.test(dim))
dimReplacements.push_back(getAffineConstantExpr(0, context));
else
dimReplacements.push_back(getAffineDimExpr(numDims++, context));
@@ -566,15 +566,11 @@ AffineMap mlir::compressDims(AffineMap map,
}
AffineMap mlir::compressUnusedDims(AffineMap map) {
llvm::SmallDenseSet<unsigned> usedDims;
llvm::SmallBitVector unusedDims(map.getNumDims(), true);
map.walkExprs([&](AffineExpr expr) {
if (auto dimExpr = expr.dyn_cast<AffineDimExpr>())
usedDims.insert(dimExpr.getPosition());
unusedDims.reset(dimExpr.getPosition());
});
llvm::SmallDenseSet<unsigned> unusedDims;
for (unsigned d = 0, e = map.getNumDims(); d != e; ++d)
if (!usedDims.contains(d))
unusedDims.insert(d);
return compressDims(map, unusedDims);
}
@@ -613,15 +609,14 @@ SmallVector<AffineMap> mlir::compressUnusedDims(ArrayRef<AffineMap> maps) {
[](AffineMap m) { return compressUnusedDims(m); });
}
AffineMap
mlir::compressSymbols(AffineMap map,
const llvm::SmallDenseSet<unsigned> &unusedSymbols) {
AffineMap mlir::compressSymbols(AffineMap map,
const llvm::SmallBitVector &unusedSymbols) {
unsigned numSymbols = 0;
SmallVector<AffineExpr> symReplacements;
symReplacements.reserve(map.getNumSymbols());
MLIRContext *context = map.getContext();
for (unsigned sym = 0, e = map.getNumSymbols(); sym < e; ++sym) {
if (unusedSymbols.contains(sym))
if (unusedSymbols.test(sym))
symReplacements.push_back(getAffineConstantExpr(0, context));
else
symReplacements.push_back(getAffineSymbolExpr(numSymbols++, context));
@@ -634,15 +629,11 @@ mlir::compressSymbols(AffineMap map,
}
AffineMap mlir::compressUnusedSymbols(AffineMap map) {
llvm::SmallDenseSet<unsigned> usedSymbols;
llvm::SmallBitVector unusedSymbols(map.getNumSymbols(), true);
map.walkExprs([&](AffineExpr expr) {
if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>())
usedSymbols.insert(symExpr.getPosition());
unusedSymbols.reset(symExpr.getPosition());
});
llvm::SmallDenseSet<unsigned> unusedSymbols;
for (unsigned d = 0, e = map.getNumSymbols(); d != e; ++d)
if (!usedSymbols.contains(d))
unusedSymbols.insert(d);
return compressSymbols(map, unusedSymbols);
}
@@ -732,9 +723,8 @@ AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) {
maps.front().getContext());
}
AffineMap
mlir::getProjectedMap(AffineMap map,
const llvm::SmallDenseSet<unsigned> &unusedDims) {
AffineMap mlir::getProjectedMap(AffineMap map,
const llvm::SmallBitVector &unusedDims) {
return compressUnusedSymbols(compressDims(map, unusedDims));
}

@@ -154,8 +154,8 @@ func @fold_rank_reducing_subview_with_load
// CHECK-SAME: %[[ARG16:[a-zA-Z0-9_]+]]: index
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[I1:.+]] = affine.apply #[[MAP]](%[[ARG13]])[%[[ARG7]], %[[ARG1]]]
// CHECK-DAG: %[[I2:.+]] = affine.apply #[[MAP]](%[[ARG14]])[%[[ARG8]], %[[ARG2]]]
// CHECK-DAG: %[[I3:.+]] = affine.apply #[[MAP]](%[[C0]])[%[[ARG9]], %[[ARG3]]]
// CHECK-DAG: %[[I2:.+]] = affine.apply #[[MAP]](%[[C0]])[%[[ARG8]], %[[ARG2]]]
// CHECK-DAG: %[[I3:.+]] = affine.apply #[[MAP]](%[[ARG14]])[%[[ARG9]], %[[ARG3]]]
// CHECK-DAG: %[[I4:.+]] = affine.apply #[[MAP]](%[[ARG15]])[%[[ARG10]], %[[ARG4]]]
// CHECK-DAG: %[[I5:.+]] = affine.apply #[[MAP]](%[[ARG16]])[%[[ARG11]], %[[ARG5]]]
// CHECK-DAG: %[[I6:.+]] = affine.apply #[[MAP]](%[[C0]])[%[[ARG12]], %[[ARG6]]]