constrain dynamic indexing based on downstream casts

This commit is contained in:
Morten Borup Petersen 2021-11-30 13:56:18 +00:00 committed by William Moses
parent e216fc1163
commit 1f6f1e97e7
2 changed files with 49 additions and 0 deletions

View File

@ -271,6 +271,16 @@ public:
if (srcMemRefType.getShape().size() !=
(resMemRefType.getShape().size() + 1))
return failure();
// Check that there is not a downstream cast of subindex result. This is a
// bit dubious, but allowing cast canonicalizations - when possible - to
// convert subindexes will ultimately result in fewer memref.subview
// operations to be inferred.
for (auto user : op.getResult().getUsers()) {
if (isa<memref::CastOp>(user))
return failure();
}
for (auto it : llvm::zip(srcMemRefType.getShape().drop_front(),
resMemRefType.getShape())) {
if (std::get<0>(it) != std::get<1>(it))

View File

@ -0,0 +1,39 @@
// RUN: mlir-clang --S --function=* --memref-fullrank %s | FileCheck %s
// The following should be able to fully lower to memref ops without memref
// subviews.
// CHECK-LABEL: func @matrix_power(
// CHECK: %[[VAL_0:.*]]: memref<20x20xi32>, %[[VAL_1:.*]]: memref<20xi32>, %[[VAL_2:.*]]: memref<20xi32>, %[[VAL_3:.*]]: memref<20xi32>)
// CHECK: %[[VAL_4:.*]] = arith.constant 20 : index
// CHECK: %[[VAL_5:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_7:.*]] = arith.constant 1 : i32
// CHECK: scf.for %[[VAL_8:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_6]] {
// CHECK: %[[VAL_9:.*]] = arith.subi %[[VAL_8]], %[[VAL_6]] : index
// CHECK: %[[VAL_10:.*]] = arith.index_cast %[[VAL_9]] : index to i32
// CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_10]], %[[VAL_7]] : i32
// CHECK: %[[VAL_12:.*]] = arith.index_cast %[[VAL_11]] : i32 to index
// CHECK: scf.for %[[VAL_13:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] {
// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_13]]] : memref<20xi32>
// CHECK: %[[VAL_15:.*]] = arith.index_cast %[[VAL_14]] : i32 to index
// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_13]]] : memref<20xi32>
// CHECK: %[[VAL_17:.*]] = arith.subi %[[VAL_12]], %[[VAL_6]] : index
// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_13]]] : memref<20xi32>
// CHECK: %[[VAL_19:.*]] = arith.index_cast %[[VAL_18]] : i32 to index
// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_17]], %[[VAL_19]]] : memref<20x20xi32>
// CHECK: %[[VAL_21:.*]] = arith.muli %[[VAL_16]], %[[VAL_20]] : i32
// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_12]], %[[VAL_15]]] : memref<20x20xi32>
// CHECK: %[[VAL_23:.*]] = arith.addi %[[VAL_22]], %[[VAL_21]] : i32
// CHECK: memref.store %[[VAL_23]], %[[VAL_0]]{{\[}}%[[VAL_12]], %[[VAL_15]]] : memref<20x20xi32>
// CHECK: }
// CHECK: }
// CHECK: return
// CHECK: }
void matrix_power(int x[20][20], int row[20], int col[20], int a[20]) {
for (int k = 1; k < 20; k++) {
for (int p = 0; p < 20; p++) {
x[k][row[p]] += a[p] * x[k - 1][col[p]];
}
}
}