[mlir][Vector] Add folding for vector.transfer ops
This revision folds vector.transfer operations by updating the `masked` bool array attribute whenever additional dimensions can be statically proven in-bounds and therefore marked unmasked. Differential revision: https://reviews.llvm.org/D83586
This commit is contained in:
parent
cb6c110614
commit
ec2f2cec76
|
@ -919,6 +919,15 @@ def Vector_TransferOpUtils {
|
|||
// Returns the VectorType of the value being transferred; `vector()` is
// expected to yield a value of VectorType (the cast asserts otherwise).
VectorType getVectorType() {
|
||||
return vector().getType().cast<VectorType>();
|
||||
}
|
||||
// Number of dimensions that participate in the permutation map, i.e. the
// number of results of `permutation_map()`.
|
||||
unsigned getTransferRank() {
|
||||
return permutation_map().getNumResults();
|
||||
}
|
||||
// Number of leading memref dimensions that do not participate in the
|
||||
// permutation map: memref rank minus the number of map results.
|
||||
unsigned getLeadingMemRefRank() {
|
||||
return getMemRefType().getRank() - permutation_map().getNumResults();
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
|
|
|
@ -99,9 +99,9 @@ public:
|
|||
/// dimensional identifiers.
|
||||
bool isIdentity() const;
|
||||
|
||||
/// Returns true if the map is a minor identity map, i.e. an identity affine
|
||||
/// map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
|
||||
static bool isMinorIdentity(AffineMap map);
|
||||
/// Returns true if this affine map is a minor identity, i.e. an identity
|
||||
/// affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
|
||||
bool isMinorIdentity() const;
|
||||
|
||||
/// Returns true if this affine map is an empty map, i.e., () -> ().
|
||||
bool isEmpty() const;
|
||||
|
|
|
@ -72,7 +72,7 @@ public:
|
|||
llvm::size(xferOp.indices()) == 0)
|
||||
return failure();
|
||||
|
||||
if (!AffineMap::isMinorIdentity(xferOp.permutation_map()))
|
||||
if (!xferOp.permutation_map().isMinorIdentity())
|
||||
return failure();
|
||||
|
||||
// Have it handled in vector->llvm conversion pass.
|
||||
|
|
|
@ -89,7 +89,7 @@ public:
|
|||
// TODO: when we go to k > 1-D vectors adapt minorRank.
|
||||
minorRank = 1;
|
||||
majorRank = vectorType.getRank() - minorRank;
|
||||
leadingRank = xferOp.getMemRefType().getRank() - (majorRank + minorRank);
|
||||
leadingRank = xferOp.getLeadingMemRefRank();
|
||||
majorVectorType =
|
||||
VectorType::get(vectorType.getShape().take_front(majorRank),
|
||||
vectorType.getElementType());
|
||||
|
@ -538,7 +538,7 @@ LogicalResult VectorTransferRewriter<TransferReadOp>::matchAndRewrite(
|
|||
using namespace mlir::edsc::op;
|
||||
|
||||
TransferReadOp transfer = cast<TransferReadOp>(op);
|
||||
if (AffineMap::isMinorIdentity(transfer.permutation_map())) {
|
||||
if (transfer.permutation_map().isMinorIdentity()) {
|
||||
// If > 1D, emit a bunch of loops around 1-D vector transfers.
|
||||
if (transfer.getVectorType().getRank() > 1)
|
||||
return NDTransferOpHelper<TransferReadOp>(rewriter, transfer, options)
|
||||
|
@ -611,7 +611,7 @@ LogicalResult VectorTransferRewriter<TransferWriteOp>::matchAndRewrite(
|
|||
using namespace edsc::op;
|
||||
|
||||
TransferWriteOp transfer = cast<TransferWriteOp>(op);
|
||||
if (AffineMap::isMinorIdentity(transfer.permutation_map())) {
|
||||
if (transfer.permutation_map().isMinorIdentity()) {
|
||||
// If > 1D, emit a bunch of loops around 1-D vector transfers.
|
||||
if (transfer.getVectorType().getRank() > 1)
|
||||
return NDTransferOpHelper<TransferWriteOp>(rewriter, transfer, options)
|
||||
|
|
|
@ -620,7 +620,7 @@ static LogicalResult foldExtractOpFromTranspose(ExtractOp extractOp) {
|
|||
MLIRContext *ctx = extractOp.getContext();
|
||||
AffineMap permutationMap = AffineMap::getPermutationMap(permutation, ctx);
|
||||
AffineMap minorMap = permutationMap.getMinorSubMap(minorRank);
|
||||
if (minorMap && !AffineMap::isMinorIdentity(minorMap))
|
||||
if (minorMap && !minorMap.isMinorIdentity())
|
||||
return failure();
|
||||
|
||||
// %1 = transpose %0[x, y, z] : vector<axbxcxf32>
|
||||
|
@ -730,7 +730,7 @@ static Value foldExtractOpFromInsertChainAndTranspose(ExtractOp extractOp) {
|
|||
unsigned minorRank =
|
||||
permutationMap.getNumResults() - insertedPos.size();
|
||||
AffineMap minorMap = permutationMap.getMinorSubMap(minorRank);
|
||||
if (!minorMap || AffineMap::isMinorIdentity(minorMap))
|
||||
if (!minorMap || minorMap.isMinorIdentity())
|
||||
return insertOp.source();
|
||||
}
|
||||
}
|
||||
|
@ -1720,8 +1720,68 @@ static LogicalResult foldMemRefCast(Operation *op) {
|
|||
return success(folded);
|
||||
}
|
||||
|
||||
template <typename TransferOp>
|
||||
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
|
||||
// TODO: support more aggressive createOrFold on:
|
||||
// `op.indices()[indicesIdx] + vectorType < dim(op.memref(), indicesIdx)`
|
||||
if (op.getMemRefType().isDynamicDim(indicesIdx))
|
||||
return false;
|
||||
Value index = op.indices()[indicesIdx];
|
||||
auto cstOp = index.getDefiningOp<ConstantIndexOp>();
|
||||
if (!cstOp)
|
||||
return false;
|
||||
|
||||
int64_t memrefSize = op.getMemRefType().getDimSize(indicesIdx);
|
||||
int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
|
||||
return cstOp.getValue() + vectorSize <= memrefSize;
|
||||
}
|
||||
|
||||
template <typename TransferOp>
|
||||
static LogicalResult foldTransferMaskAttribute(TransferOp op) {
|
||||
AffineMap permutationMap = op.permutation_map();
|
||||
if (!permutationMap.isMinorIdentity())
|
||||
return failure();
|
||||
bool changed = false;
|
||||
SmallVector<bool, 4> isMasked;
|
||||
isMasked.reserve(op.getTransferRank());
|
||||
// `permutationMap` results and `op.indices` sizes may not match and may not
|
||||
// be aligned. The first `indicesIdx` may just be indexed and not transferred
|
||||
// from/into the vector.
|
||||
// For example:
|
||||
// vector.transfer %0[%i, %j, %k, %c0] : memref<?x?x?x?xf32>, vector<2x4xf32>
|
||||
// with `permutation_map = (d0, d1, d2, d3) -> (d2, d3)`.
|
||||
// The `permutationMap` results and `op.indices` are however aligned when
|
||||
// iterating in reverse until we exhaust `permutationMap` results.
|
||||
// As a consequence we iterate with 2 running indices: `resultIdx` and
|
||||
// `indicesIdx`, until `resultIdx` reaches 0.
|
||||
for (int64_t resultIdx = permutationMap.getNumResults() - 1,
|
||||
indicesIdx = op.indices().size() - 1;
|
||||
resultIdx >= 0; --resultIdx, --indicesIdx) {
|
||||
// Already marked unmasked, nothing to see here.
|
||||
if (!op.isMaskedDim(resultIdx)) {
|
||||
isMasked.push_back(false);
|
||||
continue;
|
||||
}
|
||||
// Currently masked, check whether we can statically determine it is
|
||||
// inBounds.
|
||||
auto inBounds = isInBounds(op, resultIdx, indicesIdx);
|
||||
isMasked.push_back(!inBounds);
|
||||
// We commit the pattern if it is "more inbounds".
|
||||
changed |= inBounds;
|
||||
}
|
||||
if (!changed)
|
||||
return failure();
|
||||
// OpBuilder is only used as a helper to build an I64ArrayAttr.
|
||||
OpBuilder b(op.getContext());
|
||||
std::reverse(isMasked.begin(), isMasked.end());
|
||||
op.setAttr(TransferOp::getMaskedAttrName(), b.getBoolArrayAttr(isMasked));
|
||||
return success();
|
||||
}
|
||||
|
||||
OpFoldResult TransferReadOp::fold(ArrayRef<Attribute>) {
|
||||
/// transfer_read(memrefcast) -> transfer_read
|
||||
if (succeeded(foldTransferMaskAttribute(*this)))
|
||||
return getResult();
|
||||
if (succeeded(foldMemRefCast(*this)))
|
||||
return getResult();
|
||||
return OpFoldResult();
|
||||
|
@ -1819,6 +1879,8 @@ static LogicalResult verify(TransferWriteOp op) {
|
|||
|
||||
/// Folds transfer_write: first try to tighten the `masked` attribute via
/// foldTransferMaskAttribute; failing that, fold away a producing
/// memref_cast (transfer_write(memref_cast) -> transfer_write).
LogicalResult TransferWriteOp::fold(ArrayRef<Attribute>,
                                    SmallVectorImpl<OpFoldResult> &) {
  if (failed(foldTransferMaskAttribute(*this)))
    return foldMemRefCast(*this);
  return success();
}
|
||||
|
||||
|
|
|
@ -104,11 +104,9 @@ AffineMap AffineMap::getMinorIdentityMap(unsigned dims, unsigned results,
|
|||
return AffineMap::get(dims, 0, id.getResults().take_back(results), context);
|
||||
}
|
||||
|
||||
bool AffineMap::isMinorIdentity(AffineMap map) {
|
||||
if (!map)
|
||||
return false;
|
||||
return map == getMinorIdentityMap(map.getNumDims(), map.getNumResults(),
|
||||
map.getContext());
|
||||
bool AffineMap::isMinorIdentity() const {
|
||||
return *this ==
|
||||
getMinorIdentityMap(getNumDims(), getNumResults(), getContext());
|
||||
}
|
||||
|
||||
/// Returns an AffineMap representing a permutation.
|
||||
|
|
|
@ -168,10 +168,10 @@ func @cast_transfers(%A: memref<4x8xf32>) -> (vector<4x8xf32>) {
|
|||
%f0 = constant 0.0 : f32
|
||||
%0 = memref_cast %A : memref<4x8xf32> to memref<?x?xf32>
|
||||
|
||||
// CHECK: vector.transfer_read %{{.*}} : memref<4x8xf32>, vector<4x8xf32>
|
||||
// CHECK: vector.transfer_read %{{.*}} {masked = [false, false]} : memref<4x8xf32>, vector<4x8xf32>
|
||||
%1 = vector.transfer_read %0[%c0, %c0], %f0 : memref<?x?xf32>, vector<4x8xf32>
|
||||
|
||||
// CHECK: vector.transfer_write %{{.*}} : vector<4x8xf32>, memref<4x8xf32>
|
||||
// CHECK: vector.transfer_write %{{.*}} {masked = [false, false]} : vector<4x8xf32>, memref<4x8xf32>
|
||||
vector.transfer_write %1, %0[%c0, %c0] : vector<4x8xf32>, memref<?x?xf32>
|
||||
return %1 : vector<4x8xf32>
|
||||
}
|
||||
|
@ -345,3 +345,30 @@ func @fold_extract_transpose(
|
|||
|
||||
return %1, %3, %5 : vector<6xf32>, vector<6xf32>, vector<6xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: fold_vector_transfers
|
||||
func @fold_vector_transfers(%A: memref<?x8xf32>) -> (vector<4x8xf32>, vector<4x9xf32>) {
|
||||
%c0 = constant 0 : index
|
||||
%f0 = constant 0.0 : f32
|
||||
|
||||
// CHECK: vector.transfer_read %{{.*}} {masked = [true, false]}
|
||||
%1 = vector.transfer_read %A[%c0, %c0], %f0 : memref<?x8xf32>, vector<4x8xf32>
|
||||
|
||||
// CHECK: vector.transfer_write %{{.*}} {masked = [true, false]}
|
||||
vector.transfer_write %1, %A[%c0, %c0] : vector<4x8xf32>, memref<?x8xf32>
|
||||
|
||||
// Both dims masked, attribute is elided.
|
||||
// CHECK: vector.transfer_read %{{.*}}
|
||||
// CHECK-NOT: masked
|
||||
%2 = vector.transfer_read %A[%c0, %c0], %f0 : memref<?x8xf32>, vector<4x9xf32>
|
||||
|
||||
// Both dims masked, attribute is elided.
|
||||
// CHECK: vector.transfer_write %{{.*}}
|
||||
// CHECK-NOT: masked
|
||||
vector.transfer_write %2, %A[%c0, %c0] : vector<4x9xf32>, memref<?x8xf32>
|
||||
|
||||
// CHECK: return
|
||||
return %1, %2 : vector<4x8xf32>, vector<4x9xf32>
|
||||
}
|
||||
|
|
|
@ -248,10 +248,10 @@ func @contraction4x4_ikj(%arg0 : vector<4x2xf32>, %arg1 : vector<2x4xf32>,
|
|||
// CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
||||
|
||||
// CHECK-NEXT: vector.transfer_write %[[R0]], %{{.*}}[%[[C0]], %[[C0]]] : vector<2x2xf32>, memref<4x4xf32>
|
||||
// CHECK-NEXT: vector.transfer_write %[[R1]], %{{.*}}[%[[C0]], %[[C2]]] : vector<2x2xf32>, memref<4x4xf32>
|
||||
// CHECK-NEXT: vector.transfer_write %[[R2]], %{{.*}}[%[[C2]], %[[C0]]] : vector<2x2xf32>, memref<4x4xf32>
|
||||
// CHECK-NEXT: vector.transfer_write %[[R3]], %{{.*}}[%[[C2]], %[[C2]]] : vector<2x2xf32>, memref<4x4xf32>
|
||||
// CHECK-NEXT: vector.transfer_write %[[R0]], %{{.*}}[%[[C0]], %[[C0]]] {masked = [false, false]} : vector<2x2xf32>, memref<4x4xf32>
|
||||
// CHECK-NEXT: vector.transfer_write %[[R1]], %{{.*}}[%[[C0]], %[[C2]]] {masked = [false, false]} : vector<2x2xf32>, memref<4x4xf32>
|
||||
// CHECK-NEXT: vector.transfer_write %[[R2]], %{{.*}}[%[[C2]], %[[C0]]] {masked = [false, false]} : vector<2x2xf32>, memref<4x4xf32>
|
||||
// CHECK-NEXT: vector.transfer_write %[[R3]], %{{.*}}[%[[C2]], %[[C2]]] {masked = [false, false]} : vector<2x2xf32>, memref<4x4xf32>
|
||||
// CHECK-NEXT: return
|
||||
|
||||
func @contraction4x4_ikj_xfer_read(%arg0 : memref<4x2xf32>,
|
||||
|
@ -296,8 +296,8 @@ func @vector_transfers(%arg0: index, %arg1: index) {
|
|||
%cst_1 = constant 2.000000e+00 : f32
|
||||
affine.for %arg2 = 0 to %arg0 step 4 {
|
||||
affine.for %arg3 = 0 to %arg1 step 4 {
|
||||
%4 = vector.transfer_read %0[%arg2, %arg3], %cst {permutation_map = affine_map<(d0, d1) -> (d0, d1)>} : memref<?x?xf32>, vector<4x4xf32>
|
||||
%5 = vector.transfer_read %1[%arg2, %arg3], %cst {permutation_map = affine_map<(d0, d1) -> (d0, d1)>} : memref<?x?xf32>, vector<4x4xf32>
|
||||
%4 = vector.transfer_read %0[%arg2, %arg3], %cst {permutation_map = affine_map<(d0, d1) -> (d0, d1)>} : memref<?x?xf32>, vector<4x4xf32>
|
||||
%5 = vector.transfer_read %1[%arg2, %arg3], %cst {permutation_map = affine_map<(d0, d1) -> (d0, d1)>} : memref<?x?xf32>, vector<4x4xf32>
|
||||
%6 = addf %4, %5 : vector<4x4xf32>
|
||||
vector.transfer_write %6, %2[%arg2, %arg3] {permutation_map = affine_map<(d0, d1) -> (d0, d1)>} : vector<4x4xf32>, memref<?x?xf32>
|
||||
}
|
||||
|
@ -426,10 +426,10 @@ func @cancelling_shape_cast_ops(%arg0 : vector<2x4xf32>) -> vector<2x4xf32> {
|
|||
// CHECK-LABEL: func @vector_transfers_vector_element_type
|
||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||
// CHECK: %[[VTR0:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %{{.*}} : memref<6x2x1xvector<2x4xf32>>, vector<1x1x2x4xf32>
|
||||
// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C1]], %[[C0]]], %{{.*}} : memref<6x2x1xvector<2x4xf32>>, vector<1x1x2x4xf32>
|
||||
// CHECK-NEXT: vector.transfer_write %[[VTR0]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] : vector<1x1x2x4xf32>, memref<6x2x1xvector<2x4xf32>>
|
||||
// CHECK-NEXT: vector.transfer_write %[[VTR1]], %{{.*}}[%[[C0]], %[[C1]], %[[C0]]] : vector<1x1x2x4xf32>, memref<6x2x1xvector<2x4xf32>>
|
||||
// CHECK: %[[VTR0:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %{{.*}} {masked = [false, false]} : memref<6x2x1xvector<2x4xf32>>, vector<1x1x2x4xf32>
|
||||
// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C1]], %[[C0]]], %{{.*}} {masked = [false, false]} : memref<6x2x1xvector<2x4xf32>>, vector<1x1x2x4xf32>
|
||||
// CHECK-NEXT: vector.transfer_write %[[VTR0]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {masked = [false, false]} : vector<1x1x2x4xf32>, memref<6x2x1xvector<2x4xf32>>
|
||||
// CHECK-NEXT: vector.transfer_write %[[VTR1]], %{{.*}}[%[[C0]], %[[C1]], %[[C0]]] {masked = [false, false]} : vector<1x1x2x4xf32>, memref<6x2x1xvector<2x4xf32>>
|
||||
|
||||
func @vector_transfers_vector_element_type() {
|
||||
%c0 = constant 0 : index
|
||||
|
|
Loading…
Reference in New Issue