[mlir][Vector] Add folding for vector.transfer ops

This revision folds vector.transfer operations by updating their `masked` bool array attribute whenever additional dimensions can be statically proven in-bounds, i.e. when more unmasked dimensions can be discovered.
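For example (a sketch mirroring the new `fold_vector_transfers` test added below; `%A`, `%c0`, and `%f0` are illustrative names), a read whose static dimension is provably in-bounds gets its mask refined:

  %c0 = constant 0 : index
  %f0 = constant 0.0 : f32
  // Before folding: both dimensions are conservatively masked.
  %v = vector.transfer_read %A[%c0, %c0], %f0 : memref<?x8xf32>, vector<4x8xf32>
  // After folding: dimension 1 is statically in-bounds (0 + 8 <= 8), so only
  // the dynamic dimension 0 stays masked.
  %v = vector.transfer_read %A[%c0, %c0], %f0 {masked = [true, false]} : memref<?x8xf32>, vector<4x8xf32>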

Differential revision: https://reviews.llvm.org/D83586
Nicolas Vasilache, 2020-07-10 16:47:51 -04:00
commit ec2f2cec76 (parent cb6c110614)
8 changed files with 122 additions and 26 deletions


@@ -919,6 +919,15 @@ def Vector_TransferOpUtils {
     VectorType getVectorType() {
       return vector().getType().cast<VectorType>();
     }
+    // Number of dimensions that participate in the permutation map.
+    unsigned getTransferRank() {
+      return permutation_map().getNumResults();
+    }
+    // Number of leading dimensions that do not participate in the permutation
+    // map.
+    unsigned getLeadingMemRefRank() {
+      return getMemRefType().getRank() - permutation_map().getNumResults();
+    }
   }];
 }
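As a sketch of what the two new helpers return (the op below is illustrative, adapted from the comment in foldTransferMaskAttribute further down): for

  vector.transfer_read %0[%i, %j, %k, %c0], %f0
    {permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d3)>}
    : memref<?x?x?x?xf32>, vector<2x4xf32>

getTransferRank() is 2 (the permutation map has two results) and getLeadingMemRefRank() is 4 - 2 = 2 (the two leading memref dimensions are only indexed, never transferred from/into the vector).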


@@ -99,9 +99,9 @@ public:
   /// dimensional identifiers.
   bool isIdentity() const;
-  /// Returns true if the map is a minor identity map, i.e. an identity affine
-  /// map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
-  static bool isMinorIdentity(AffineMap map);
+  /// Returns true if this affine map is a minor identity, i.e. an identity
+  /// affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
+  bool isMinorIdentity() const;
   /// Returns true if this affine map is an empty map, i.e., () -> ().
   bool isEmpty() const;
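For reference, a minor identity map is the identity restricted to the most minor dimensions; for instance (illustrative examples):

  affine_map<(d0, d1, d2) -> (d1, d2)>      // minor identity
  affine_map<(d0, d1, d2) -> (d0, d1, d2)>  // full identity, also a minor identity
  affine_map<(d0, d1, d2) -> (d0, d1)>      // not a minor identity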


@@ -72,7 +72,7 @@ public:
         llvm::size(xferOp.indices()) == 0)
       return failure();
-    if (!AffineMap::isMinorIdentity(xferOp.permutation_map()))
+    if (!xferOp.permutation_map().isMinorIdentity())
       return failure();
     // Have it handled in vector->llvm conversion pass.


@@ -89,7 +89,7 @@ public:
     // TODO: when we go to k > 1-D vectors adapt minorRank.
     minorRank = 1;
     majorRank = vectorType.getRank() - minorRank;
-    leadingRank = xferOp.getMemRefType().getRank() - (majorRank + minorRank);
+    leadingRank = xferOp.getLeadingMemRefRank();
     majorVectorType =
         VectorType::get(vectorType.getShape().take_front(majorRank),
                         vectorType.getElementType());
@@ -538,7 +538,7 @@ LogicalResult VectorTransferRewriter<TransferReadOp>::matchAndRewrite(
   using namespace mlir::edsc::op;
   TransferReadOp transfer = cast<TransferReadOp>(op);
-  if (AffineMap::isMinorIdentity(transfer.permutation_map())) {
+  if (transfer.permutation_map().isMinorIdentity()) {
     // If > 1D, emit a bunch of loops around 1-D vector transfers.
     if (transfer.getVectorType().getRank() > 1)
       return NDTransferOpHelper<TransferReadOp>(rewriter, transfer, options)
@@ -611,7 +611,7 @@ LogicalResult VectorTransferRewriter<TransferWriteOp>::matchAndRewrite(
   using namespace edsc::op;
   TransferWriteOp transfer = cast<TransferWriteOp>(op);
-  if (AffineMap::isMinorIdentity(transfer.permutation_map())) {
+  if (transfer.permutation_map().isMinorIdentity()) {
     // If > 1D, emit a bunch of loops around 1-D vector transfers.
     if (transfer.getVectorType().getRank() > 1)
       return NDTransferOpHelper<TransferWriteOp>(rewriter, transfer, options)


@@ -620,7 +620,7 @@ static LogicalResult foldExtractOpFromTranspose(ExtractOp extractOp) {
   MLIRContext *ctx = extractOp.getContext();
   AffineMap permutationMap = AffineMap::getPermutationMap(permutation, ctx);
   AffineMap minorMap = permutationMap.getMinorSubMap(minorRank);
-  if (minorMap && !AffineMap::isMinorIdentity(minorMap))
+  if (minorMap && !minorMap.isMinorIdentity())
     return failure();
   // %1 = transpose %0[x, y, z] : vector<axbxcxf32>
@@ -730,7 +730,7 @@ static Value foldExtractOpFromInsertChainAndTranspose(ExtractOp extractOp) {
       unsigned minorRank =
           permutationMap.getNumResults() - insertedPos.size();
       AffineMap minorMap = permutationMap.getMinorSubMap(minorRank);
-      if (!minorMap || AffineMap::isMinorIdentity(minorMap))
+      if (!minorMap || minorMap.isMinorIdentity())
         return insertOp.source();
     }
   }
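Here getMinorSubMap(minorRank) keeps only the trailing `minorRank` results of the permutation map, so these checks ask whether the transfer is an identity on its most minor dimensions. For instance (an assumed example, not from this commit), the minor sub-map of size 1 of affine_map<(d0, d1, d2) -> (d2, d0)> is affine_map<(d0, d1, d2) -> (d0)>, which is not a minor identity.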
@@ -1720,8 +1720,68 @@ static LogicalResult foldMemRefCast(Operation *op) {
   return success(folded);
 }
+
+template <typename TransferOp>
+static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
+  // TODO: support more aggressive createOrFold on:
+  // `op.indices()[indicesIdx] + vectorType < dim(op.memref(), indicesIdx)`
+  if (op.getMemRefType().isDynamicDim(indicesIdx))
+    return false;
+  Value index = op.indices()[indicesIdx];
+  auto cstOp = index.getDefiningOp<ConstantIndexOp>();
+  if (!cstOp)
+    return false;
+
+  int64_t memrefSize = op.getMemRefType().getDimSize(indicesIdx);
+  int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
+
+  return cstOp.getValue() + vectorSize <= memrefSize;
+}
+
+template <typename TransferOp>
+static LogicalResult foldTransferMaskAttribute(TransferOp op) {
+  AffineMap permutationMap = op.permutation_map();
+  if (!permutationMap.isMinorIdentity())
+    return failure();
+  bool changed = false;
+  SmallVector<bool, 4> isMasked;
+  isMasked.reserve(op.getTransferRank());
+  // `permutationMap` results and `op.indices` sizes may not match and may not
+  // be aligned. The first `indicesIdx` may just be indexed and not transferred
+  // from/into the vector.
+  // For example:
+  //   vector.transfer %0[%i, %j, %k, %c0] : memref<?x?x?x?xf32>, vector<2x4xf32>
+  // with `permutation_map = (d0, d1, d2, d3) -> (d2, d3)`.
+  // The `permutationMap` results and `op.indices` are however aligned when
+  // iterating in reverse until we exhaust `permutationMap` results.
+  // As a consequence we iterate with 2 running indices: `resultIdx` and
+  // `indicesIdx`, until `resultIdx` reaches 0.
+  for (int64_t resultIdx = permutationMap.getNumResults() - 1,
+               indicesIdx = op.indices().size() - 1;
+       resultIdx >= 0; --resultIdx, --indicesIdx) {
+    // Already marked unmasked, nothing to see here.
+    if (!op.isMaskedDim(resultIdx)) {
+      isMasked.push_back(false);
+      continue;
+    }
+    // Currently masked, check whether we can statically determine it is
+    // inBounds.
+    auto inBounds = isInBounds(op, resultIdx, indicesIdx);
+    isMasked.push_back(!inBounds);
+    // We commit the pattern if it is "more inbounds".
+    changed |= inBounds;
+  }
+  if (!changed)
+    return failure();
+  // OpBuilder is only used as a helper to build a BoolArrayAttr.
+  OpBuilder b(op.getContext());
+  std::reverse(isMasked.begin(), isMasked.end());
+  op.setAttr(TransferOp::getMaskedAttrName(), b.getBoolArrayAttr(isMasked));
+  return success();
+}
+
 OpFoldResult TransferReadOp::fold(ArrayRef<Attribute>) {
   /// transfer_read(memrefcast) -> transfer_read
+  if (succeeded(foldTransferMaskAttribute(*this)))
+    return getResult();
   if (succeeded(foldMemRefCast(*this)))
     return getResult();
   return OpFoldResult();
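Concretely, the new isInBounds check only fires for static memref dimensions indexed by constants (hypothetical values below):

  %c12 = constant 12 : index
  %c0 = constant 0 : index
  // dim 0: 12 + 4 <= 16 and dim 1: 0 + 8 <= 8, so both dimensions fold to
  // unmasked and the op is rewritten with masked = [false, false].
  %v = vector.transfer_read %A[%c12, %c0], %f0 : memref<16x8xf32>, vector<4x8xf32>

A dynamic dimension or a non-constant index conservatively keeps the dimension masked.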
@@ -1819,6 +1879,8 @@ static LogicalResult verify(TransferWriteOp op) {
 LogicalResult TransferWriteOp::fold(ArrayRef<Attribute>,
                                     SmallVectorImpl<OpFoldResult> &) {
+  if (succeeded(foldTransferMaskAttribute(*this)))
+    return success();
   return foldMemRefCast(*this);
 }
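Since this runs as part of the ops' fold hooks rather than as a standalone rewrite pattern, the mask refinement applies wherever folding does, e.g. during canonicalization or under OpBuilder::createOrFold, not only in a dedicated pass.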


@@ -104,11 +104,9 @@ AffineMap AffineMap::getMinorIdentityMap(unsigned dims, unsigned results,
   return AffineMap::get(dims, 0, id.getResults().take_back(results), context);
 }
-bool AffineMap::isMinorIdentity(AffineMap map) {
-  if (!map)
-    return false;
-  return map == getMinorIdentityMap(map.getNumDims(), map.getNumResults(),
-                                    map.getContext());
+bool AffineMap::isMinorIdentity() const {
+  return *this ==
+         getMinorIdentityMap(getNumDims(), getNumResults(), getContext());
 }
 /// Returns an AffineMap representing a permutation.


@@ -168,10 +168,10 @@ func @cast_transfers(%A: memref<4x8xf32>) -> (vector<4x8xf32>) {
   %f0 = constant 0.0 : f32
   %0 = memref_cast %A : memref<4x8xf32> to memref<?x?xf32>
-  // CHECK: vector.transfer_read %{{.*}} : memref<4x8xf32>, vector<4x8xf32>
+  // CHECK: vector.transfer_read %{{.*}} {masked = [false, false]} : memref<4x8xf32>, vector<4x8xf32>
   %1 = vector.transfer_read %0[%c0, %c0], %f0 : memref<?x?xf32>, vector<4x8xf32>
-  // CHECK: vector.transfer_write %{{.*}} : vector<4x8xf32>, memref<4x8xf32>
+  // CHECK: vector.transfer_write %{{.*}} {masked = [false, false]} : vector<4x8xf32>, memref<4x8xf32>
   vector.transfer_write %1, %0[%c0, %c0] : vector<4x8xf32>, memref<?x?xf32>
   return %1 : vector<4x8xf32>
 }
@@ -345,3 +345,30 @@ func @fold_extract_transpose(
   return %1, %3, %5 : vector<6xf32>, vector<6xf32>, vector<6xf32>
 }
+
+// -----
+
+// CHECK-LABEL: fold_vector_transfers
+func @fold_vector_transfers(%A: memref<?x8xf32>) -> (vector<4x8xf32>, vector<4x9xf32>) {
+  %c0 = constant 0 : index
+  %f0 = constant 0.0 : f32
+
+  // CHECK: vector.transfer_read %{{.*}} {masked = [true, false]}
+  %1 = vector.transfer_read %A[%c0, %c0], %f0 : memref<?x8xf32>, vector<4x8xf32>
+
+  // CHECK: vector.transfer_write %{{.*}} {masked = [true, false]}
+  vector.transfer_write %1, %A[%c0, %c0] : vector<4x8xf32>, memref<?x8xf32>
+
+  // Both dims masked, attribute is elided.
+  // CHECK: vector.transfer_read %{{.*}}
+  // CHECK-NOT: masked
+  %2 = vector.transfer_read %A[%c0, %c0], %f0 : memref<?x8xf32>, vector<4x9xf32>
+
+  // Both dims masked, attribute is elided.
+  // CHECK: vector.transfer_write %{{.*}}
+  // CHECK-NOT: masked
+  vector.transfer_write %2, %A[%c0, %c0] : vector<4x9xf32>, memref<?x8xf32>
+
+  // CHECK: return
+  return %1, %2 : vector<4x8xf32>, vector<4x9xf32>
+}


@@ -248,10 +248,10 @@ func @contraction4x4_ikj(%arg0 : vector<4x2xf32>, %arg1 : vector<2x4xf32>,
 // CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
 // CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-// CHECK-NEXT: vector.transfer_write %[[R0]], %{{.*}}[%[[C0]], %[[C0]]] : vector<2x2xf32>, memref<4x4xf32>
-// CHECK-NEXT: vector.transfer_write %[[R1]], %{{.*}}[%[[C0]], %[[C2]]] : vector<2x2xf32>, memref<4x4xf32>
-// CHECK-NEXT: vector.transfer_write %[[R2]], %{{.*}}[%[[C2]], %[[C0]]] : vector<2x2xf32>, memref<4x4xf32>
-// CHECK-NEXT: vector.transfer_write %[[R3]], %{{.*}}[%[[C2]], %[[C2]]] : vector<2x2xf32>, memref<4x4xf32>
+// CHECK-NEXT: vector.transfer_write %[[R0]], %{{.*}}[%[[C0]], %[[C0]]] {masked = [false, false]} : vector<2x2xf32>, memref<4x4xf32>
+// CHECK-NEXT: vector.transfer_write %[[R1]], %{{.*}}[%[[C0]], %[[C2]]] {masked = [false, false]} : vector<2x2xf32>, memref<4x4xf32>
+// CHECK-NEXT: vector.transfer_write %[[R2]], %{{.*}}[%[[C2]], %[[C0]]] {masked = [false, false]} : vector<2x2xf32>, memref<4x4xf32>
+// CHECK-NEXT: vector.transfer_write %[[R3]], %{{.*}}[%[[C2]], %[[C2]]] {masked = [false, false]} : vector<2x2xf32>, memref<4x4xf32>
 // CHECK-NEXT: return
 func @contraction4x4_ikj_xfer_read(%arg0 : memref<4x2xf32>,
@@ -296,8 +296,8 @@ func @vector_transfers(%arg0: index, %arg1: index) {
   %cst_1 = constant 2.000000e+00 : f32
   affine.for %arg2 = 0 to %arg0 step 4 {
     affine.for %arg3 = 0 to %arg1 step 4 {
-      %4 = vector.transfer_read %0[%arg2, %arg3], %cst {permutation_map = affine_map<(d0, d1) -> (d0, d1)>} : memref<?x?xf32>, vector<4x4xf32>
-      %5 = vector.transfer_read %1[%arg2, %arg3], %cst {permutation_map = affine_map<(d0, d1) -> (d0, d1)>} : memref<?x?xf32>, vector<4x4xf32>
+      %4 = vector.transfer_read %0[%arg2, %arg3], %cst {permutation_map = affine_map<(d0, d1) -> (d0, d1)>} : memref<?x?xf32>, vector<4x4xf32>
+      %5 = vector.transfer_read %1[%arg2, %arg3], %cst {permutation_map = affine_map<(d0, d1) -> (d0, d1)>} : memref<?x?xf32>, vector<4x4xf32>
       %6 = addf %4, %5 : vector<4x4xf32>
       vector.transfer_write %6, %2[%arg2, %arg3] {permutation_map = affine_map<(d0, d1) -> (d0, d1)>} : vector<4x4xf32>, memref<?x?xf32>
     }
@@ -426,10 +426,10 @@ func @cancelling_shape_cast_ops(%arg0 : vector<2x4xf32>) -> vector<2x4xf32> {
 // CHECK-LABEL: func @vector_transfers_vector_element_type
 // CHECK: %[[C0:.*]] = constant 0 : index
 // CHECK: %[[C1:.*]] = constant 1 : index
-// CHECK: %[[VTR0:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %{{.*}} : memref<6x2x1xvector<2x4xf32>>, vector<1x1x2x4xf32>
-// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C1]], %[[C0]]], %{{.*}} : memref<6x2x1xvector<2x4xf32>>, vector<1x1x2x4xf32>
-// CHECK-NEXT: vector.transfer_write %[[VTR0]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] : vector<1x1x2x4xf32>, memref<6x2x1xvector<2x4xf32>>
-// CHECK-NEXT: vector.transfer_write %[[VTR1]], %{{.*}}[%[[C0]], %[[C1]], %[[C0]]] : vector<1x1x2x4xf32>, memref<6x2x1xvector<2x4xf32>>
+// CHECK: %[[VTR0:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %{{.*}} {masked = [false, false]} : memref<6x2x1xvector<2x4xf32>>, vector<1x1x2x4xf32>
+// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C1]], %[[C0]]], %{{.*}} {masked = [false, false]} : memref<6x2x1xvector<2x4xf32>>, vector<1x1x2x4xf32>
+// CHECK-NEXT: vector.transfer_write %[[VTR0]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {masked = [false, false]} : vector<1x1x2x4xf32>, memref<6x2x1xvector<2x4xf32>>
+// CHECK-NEXT: vector.transfer_write %[[VTR1]], %{{.*}}[%[[C0]], %[[C1]], %[[C0]]] {masked = [false, false]} : vector<1x1x2x4xf32>, memref<6x2x1xvector<2x4xf32>>
 func @vector_transfers_vector_element_type() {
   %c0 = constant 0 : index