[mlir][linalg] Improve aliasing approximation for hoisting transfer read/write

Improve the logic deciding if it is safe to hoist vector transfer read/write
out of the loop. Change the logic to prevent hoisting operations if there is
any unknown access to the memref in the loop, no matter where that operation is.
For other transfer read/write operations in the loop, check if we can prove that
they access disjoint memory, and ignore them in that case.

Differential Revision: https://reviews.llvm.org/D83538
This commit is contained in:
Thomas Raoux 2020-07-10 14:55:04 -07:00
parent c0bc995429
commit 6d5aeb0dce
2 changed files with 143 additions and 8 deletions

View File

@ -80,6 +80,42 @@ void mlir::linalg::hoistViewAllocOps(FuncOp func) {
}
}
/// Conservative static check that two transfer operations touch disjoint
/// regions of the same memref.
///
/// Returns true only when disjointness can be proven from constant indices;
/// returns false whenever nothing can be proven (different memrefs, different
/// vector types, or no dimension with a provable separation).
///
/// NOTE(review): the permutation map of the transfers is not consulted here;
/// the per-dimension reasoning below presumably assumes that vector dimension
/// `d` maps to memref dimension `leadingRank + d` -- confirm with callers.
template <typename TransferTypeA, typename TransferTypeB>
static bool isDisjoint(TransferTypeA transferA, TransferTypeB transferB) {
  // Only accesses to the very same memref with the very same vector type are
  // analyzed; anything else is treated as "cannot prove".
  const bool comparable =
      transferA.memref() == transferB.memref() &&
      transferA.getVectorType() == transferB.getVectorType();
  if (!comparable)
    return false;

  const unsigned leadingRank = transferA.getLeadingMemRefRank();
  const unsigned numIndices = transferA.indices().size();
  for (unsigned dim = 0; dim != numIndices; ++dim) {
    auto constA = transferA.indices()[dim].template getDefiningOp<ConstantOp>();
    auto constB = transferB.indices()[dim].template getDefiningOp<ConstantOp>();
    // A dynamic index on either side tells us nothing about this dimension;
    // another dimension may still separate the two accesses.
    if (!(constA && constB))
      continue;

    const int64_t offsetA =
        constA.getValue().template cast<IntegerAttr>().getInt();
    const int64_t offsetB =
        constB.getValue().template cast<IntegerAttr>().getInt();

    if (dim < leadingRank) {
      // Leading dimensions are plain indices: two different constants select
      // two different slices.
      if (offsetA != offsetB)
        return true;
      continue;
    }

    // Trailing dimensions are covered by the vector: the accessed intervals
    // are disjoint once the start offsets are at least one vector extent
    // apart.
    const int64_t gap = std::abs(offsetA - offsetB);
    if (gap >= transferA.getVectorType().getDimSize(dim - leadingRank))
      return true;
  }
  return false;
}
void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
bool changed = true;
while (changed) {
@ -129,10 +165,11 @@ void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
// Approximate aliasing by checking that:
// 1. indices are the same,
// 2. no other use either dominates the transfer_read or is dominated
// by the transfer_write (i.e. aliasing between the write and the read
// across the loop).
if (transferRead.indices() != transferWrite.indices())
// 2. no other operations in the loop access the same memref except
// for transfer_read/transfer_write accessing statically disjoint
// slices.
if (transferRead.indices() != transferWrite.indices() &&
transferRead.getVectorType() == transferWrite.getVectorType())
return WalkResult::advance();
// TODO: may want to memoize this information for performance but it
@ -140,11 +177,26 @@ void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
DominanceInfo dom(loop);
if (!dom.properlyDominates(transferRead.getOperation(), transferWrite))
return WalkResult::advance();
for (auto &use : transferRead.memref().getUses())
if (dom.properlyDominates(use.getOwner(),
transferRead.getOperation()) ||
dom.properlyDominates(transferWrite, use.getOwner()))
for (auto &use : transferRead.memref().getUses()) {
if (!dom.properlyDominates(loop, use.getOwner()))
continue;
if (use.getOwner() == transferRead.getOperation() ||
use.getOwner() == transferWrite.getOperation())
continue;
if (auto transferWriteUse =
dyn_cast<vector::TransferWriteOp>(use.getOwner())) {
if (!isDisjoint(transferWrite, transferWriteUse))
return WalkResult::advance();
} else if (auto transferReadUse =
dyn_cast<vector::TransferReadOp>(use.getOwner())) {
if (!isDisjoint(transferWrite, transferReadUse))
return WalkResult::advance();
} else {
// Unknown use, we cannot prove that it doesn't alias with the
// transferRead/transferWrite operations.
return WalkResult::advance();
}
}
// Hoist read before.
if (failed(loop.moveOutOfLoop({transferRead})))

View File

@ -132,18 +132,101 @@ func @hoist_vector_transfer_pairs(
%r3 = vector.transfer_read %memref3[%c0, %c0], %cst: memref<?x?xf32>, vector<4xf32>
"some_crippling_use"(%memref4) : (memref<?x?xf32>) -> ()
%r4 = vector.transfer_read %memref4[%c0, %c0], %cst: memref<?x?xf32>, vector<5xf32>
%r5 = vector.transfer_read %memref5[%c0, %c0], %cst: memref<?x?xf32>, vector<6xf32>
"some_crippling_use"(%memref5) : (memref<?x?xf32>) -> ()
%u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
%u1 = "some_use"(%r1) : (vector<2xf32>) -> vector<2xf32>
%u2 = "some_use"(%memref2) : (memref<?x?xf32>) -> vector<3xf32>
%u3 = "some_use"(%r3) : (vector<4xf32>) -> vector<4xf32>
%u4 = "some_use"(%r4) : (vector<5xf32>) -> vector<5xf32>
%u5 = "some_use"(%r5) : (vector<6xf32>) -> vector<6xf32>
vector.transfer_write %u0, %memref1[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
vector.transfer_write %u1, %memref0[%i, %i] : vector<2xf32>, memref<?x?xf32>
vector.transfer_write %u2, %memref2[%c0, %c0] : vector<3xf32>, memref<?x?xf32>
vector.transfer_write %u3, %memref3[%c0, %c0] : vector<4xf32>, memref<?x?xf32>
vector.transfer_write %u4, %memref4[%c0, %c0] : vector<5xf32>, memref<?x?xf32>
vector.transfer_write %u5, %memref5[%c0, %c0] : vector<6xf32>, memref<?x?xf32>
"some_crippling_use"(%memref3) : (memref<?x?xf32>) -> ()
}
}
return
}
// Test case: a doubly nested scf.for loop containing two transfer_read /
// transfer_write pairs per memref. The expected-output directives below
// describe which pairs end up outside the loops (carried through iter_args)
// and which stay inside.
// VECTOR_TRANSFERS-LABEL: func @hoist_vector_transfer_pairs_disjoint(
// VECTOR_TRANSFERS-SAME: %[[MEMREF0:[a-zA-Z0-9]*]]: memref<?x?xf32>,
// VECTOR_TRANSFERS-SAME: %[[MEMREF1:[a-zA-Z0-9]*]]: memref<?x?xf32>,
// VECTOR_TRANSFERS-SAME: %[[MEMREF2:[a-zA-Z0-9]*]]: memref<?x?xf32>,
// VECTOR_TRANSFERS-SAME: %[[MEMREF3:[a-zA-Z0-9]*]]: memref<?x?xf32>,
// VECTOR_TRANSFERS-SAME: %[[VAL:[a-zA-Z0-9]*]]: index,
// VECTOR_TRANSFERS-SAME: %[[LB:[a-zA-Z0-9]*]]: index,
// VECTOR_TRANSFERS-SAME: %[[UB:[a-zA-Z0-9]*]]: index,
// VECTOR_TRANSFERS-SAME: %[[STEP:[a-zA-Z0-9]*]]: index,
// VECTOR_TRANSFERS-SAME: %[[RANDOM:[a-zA-Z0-9]*]]: index,
// VECTOR_TRANSFERS-SAME: %[[CMP:[a-zA-Z0-9]*]]: i1
func @hoist_vector_transfer_pairs_disjoint(
%memref0: memref<?x?xf32>, %memref1: memref<?x?xf32>,
%memref2: memref<?x?xf32>, %memref3: memref<?x?xf32>, %val: index, %lb : index, %ub : index,
%step: index, %random_index : index, %cmp: i1) {
// Constant indices used to build statically comparable access offsets.
%c0 = constant 0 : index
%c1 = constant 1 : index
%c3 = constant 3 : index
%cst = constant 0.0 : f32
// Expected output: the %memref2 and %memref3 reads precede both loops, the
// loops carry two vector<3xf32> and two vector<4xf32> values, the %memref1
// transfers remain in the inner loop body, and the %memref3 / %memref2 writes
// follow the loops.
// VECTOR_TRANSFERS: vector.transfer_read %[[MEMREF2]]{{.*}} : memref<?x?xf32>, vector<3xf32>
// VECTOR_TRANSFERS: vector.transfer_read %[[MEMREF2]]{{.*}} : memref<?x?xf32>, vector<3xf32>
// VECTOR_TRANSFERS: vector.transfer_read %[[MEMREF3]]{{.*}} : memref<?x?xf32>, vector<4xf32>
// VECTOR_TRANSFERS: vector.transfer_read %[[MEMREF3]]{{.*}} : memref<?x?xf32>, vector<4xf32>
// VECTOR_TRANSFERS: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) ->
// VECTOR_TRANSFERS-SAME: (vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>) {
// VECTOR_TRANSFERS: scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) ->
// VECTOR_TRANSFERS-SAME: (vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>) {
// VECTOR_TRANSFERS: vector.transfer_read %[[MEMREF1]]{{.*}} : memref<?x?xf32>, vector<2xf32>
// VECTOR_TRANSFERS: vector.transfer_read %[[MEMREF1]]{{.*}} : memref<?x?xf32>, vector<2xf32>
// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<3xf32>) -> vector<3xf32>
// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<3xf32>) -> vector<3xf32>
// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32>
// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32>
// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
// VECTOR_TRANSFERS: vector.transfer_write %{{.*}}, %[[MEMREF1]]{{.*}} : vector<2xf32>, memref<?x?xf32>
// VECTOR_TRANSFERS: vector.transfer_write %{{.*}}, %[[MEMREF1]]{{.*}} : vector<2xf32>, memref<?x?xf32>
// VECTOR_TRANSFERS: scf.yield {{.*}} : vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>
// VECTOR_TRANSFERS: }
// VECTOR_TRANSFERS: scf.yield {{.*}} : vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>
// VECTOR_TRANSFERS: }
// VECTOR_TRANSFERS: vector.transfer_write %{{.*}}, %[[MEMREF3]]{{.*}} : vector<4xf32>, memref<?x?xf32>
// VECTOR_TRANSFERS: vector.transfer_write %{{.*}}, %[[MEMREF3]]{{.*}} : vector<4xf32>, memref<?x?xf32>
// VECTOR_TRANSFERS: vector.transfer_write %{{.*}}, %[[MEMREF2]]{{.*}} : vector<3xf32>, memref<?x?xf32>
// VECTOR_TRANSFERS: vector.transfer_write %{{.*}}, %[[MEMREF2]]{{.*}} : vector<3xf32>, memref<?x?xf32>
scf.for %i = %lb to %ub step %step {
scf.for %j = %lb to %ub step %step {
// %memref1: last index 0 vs 1 with vector<2xf32>, so the two accessed ranges
// overlap; the expected output keeps these transfers inside the loops.
%r00 = vector.transfer_read %memref1[%c0, %c0], %cst: memref<?x?xf32>, vector<2xf32>
%r01 = vector.transfer_read %memref1[%c0, %c1], %cst: memref<?x?xf32>, vector<2xf32>
// %memref2: last index 0 vs 3 with vector<3xf32>, so the two accessed ranges
// do not overlap; the expected output places these transfers outside the loops.
%r20 = vector.transfer_read %memref2[%c0, %c0], %cst: memref<?x?xf32>, vector<3xf32>
%r21 = vector.transfer_read %memref2[%c0, %c3], %cst: memref<?x?xf32>, vector<3xf32>
// %memref3: the last index is dynamic but the first index differs (0 vs 1);
// the expected output places these transfers outside the loops.
%r30 = vector.transfer_read %memref3[%c0, %random_index], %cst: memref<?x?xf32>, vector<4xf32>
%r31 = vector.transfer_read %memref3[%c1, %random_index], %cst: memref<?x?xf32>, vector<4xf32>
// %memref0: every index is dynamic (%i or %random_index); the expected loop
// signature carries no vector<2xf32> value, i.e. these are not hoisted.
%r10 = vector.transfer_read %memref0[%i, %i], %cst: memref<?x?xf32>, vector<2xf32>
%r11 = vector.transfer_read %memref0[%random_index, %random_index], %cst: memref<?x?xf32>, vector<2xf32>
%u00 = "some_use"(%r00) : (vector<2xf32>) -> vector<2xf32>
%u01 = "some_use"(%r01) : (vector<2xf32>) -> vector<2xf32>
%u20 = "some_use"(%r20) : (vector<3xf32>) -> vector<3xf32>
%u21 = "some_use"(%r21) : (vector<3xf32>) -> vector<3xf32>
%u30 = "some_use"(%r30) : (vector<4xf32>) -> vector<4xf32>
%u31 = "some_use"(%r31) : (vector<4xf32>) -> vector<4xf32>
%u10 = "some_use"(%r10) : (vector<2xf32>) -> vector<2xf32>
%u11 = "some_use"(%r11) : (vector<2xf32>) -> vector<2xf32>
// Each value is written back to the same memref and indices it was read from.
vector.transfer_write %u00, %memref1[%c0, %c0] : vector<2xf32>, memref<?x?xf32>
vector.transfer_write %u01, %memref1[%c0, %c1] : vector<2xf32>, memref<?x?xf32>
vector.transfer_write %u20, %memref2[%c0, %c0] : vector<3xf32>, memref<?x?xf32>
vector.transfer_write %u21, %memref2[%c0, %c3] : vector<3xf32>, memref<?x?xf32>
vector.transfer_write %u30, %memref3[%c0, %random_index] : vector<4xf32>, memref<?x?xf32>
vector.transfer_write %u31, %memref3[%c1, %random_index] : vector<4xf32>, memref<?x?xf32>
vector.transfer_write %u10, %memref0[%i, %i] : vector<2xf32>, memref<?x?xf32>
vector.transfer_write %u11, %memref0[%random_index, %random_index] : vector<2xf32>, memref<?x?xf32>
}
}
return
}