[mlir][Vector] Add custom slt / SCF.if folding to VectorToSCF

scf.if currently lacks folding on true / false conditionals.
Such foldings are a bit more involved than can be addressed immediately.
This revision introduces an eager folding  for lowering vector.transfer operations in the presence of unrolling.

Differential revision: https://reviews.llvm.org/D83146
This commit is contained in:
Nicolas Vasilache 2020-07-06 08:16:53 -04:00
parent 05c65dc0fe
commit bd87c6bce1
2 changed files with 43 additions and 3 deletions

View File

@ -174,6 +174,27 @@ void NDTransferOpHelper<ConcreteOp>::emitLoops(Lambda loopBodyBuilder) {
}
}
static Optional<int64_t> extractConstantIndex(Value v) {
if (auto cstOp = v.getDefiningOp<ConstantIndexOp>())
return cstOp.getValue();
if (auto affineApplyOp = v.getDefiningOp<AffineApplyOp>())
if (affineApplyOp.getAffineMap().isSingleConstant())
return affineApplyOp.getAffineMap().getSingleConstantResult();
return None;
}
// Missing foldings of scf.if make it necessary to perform poor man's folding
// eagerly, especially in the case of unrolling. In the future, this should go
// away once scf.if folds properly.
static Value onTheFlyFoldSLT(Value v, Value ub) {
using namespace mlir::edsc::op;
auto maybeCstV = extractConstantIndex(v);
auto maybeCstUb = extractConstantIndex(ub);
if (maybeCstV && maybeCstUb && *maybeCstV < *maybeCstUb)
return Value();
return slt(v, ub);
}
template <typename ConcreteOp>
Value NDTransferOpHelper<ConcreteOp>::emitInBoundsCondition(
ValueRange majorIvs, ValueRange majorOffsets,
@ -187,9 +208,11 @@ Value NDTransferOpHelper<ConcreteOp>::emitInBoundsCondition(
using namespace mlir::edsc::op;
majorIvsPlusOffsets.push_back(iv + off);
if (xferOp.isMaskedDim(leadingRank + idx)) {
Value inBounds = slt(majorIvsPlusOffsets.back(), ub);
inBoundsCondition =
(inBoundsCondition) ? (inBoundsCondition && inBounds) : inBounds;
Value inBoundsCond = onTheFlyFoldSLT(majorIvsPlusOffsets.back(), ub);
if (inBoundsCond)
inBoundsCondition = (inBoundsCondition)
? (inBoundsCondition && inBoundsCond)
: inBoundsCond;
}
++idx;
}

View File

@ -383,3 +383,20 @@ func @transfer_write_progressive_not_masked(%A : memref<?x?xf32>, %base: index,
vector<3x15xf32>, memref<?x?xf32>
return
}
// -----
// FULL-UNROLL-LABEL: transfer_read_simple
func @transfer_read_simple(%A : memref<2x2xf32>) -> vector<2x2xf32> {
%c0 = constant 0 : index
%f0 = constant 0.0 : f32
// FULL-UNROLL-DAG: %[[VC0:.*]] = constant dense<0.000000e+00> : vector<2x2xf32>
// FULL-UNROLL-DAG: %[[C0:.*]] = constant 0 : index
// FULL-UNROLL-DAG: %[[C1:.*]] = constant 1 : index
// FULL-UNROLL: %[[V0:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]]
// FULL-UNROLL: %[[RES0:.*]] = vector.insert %[[V0]], %[[VC0]] [0] : vector<2xf32> into vector<2x2xf32>
// FULL-UNROLL: %[[V1:.*]] = vector.transfer_read %{{.*}}[%[[C1]], %[[C0]]]
// FULL-UNROLL: %[[RES1:.*]] = vector.insert %[[V1]], %[[RES0]] [1] : vector<2xf32> into vector<2x2xf32>
%0 = vector.transfer_read %A[%c0, %c0], %f0 : memref<2x2xf32>, vector<2x2xf32>
return %0 : vector<2x2xf32>
}