Handle dominating values in loop detection
This commit is contained in:
parent
f2de97724c
commit
75574626d6
|
@ -7,61 +7,6 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TODO fix uses of induction or inner variables outside of loop
|
||||
/*
|
||||
see %2 in
|
||||
func @kernel_gemm(%arg0: i32, %arg1: memref<?xf64>) {
|
||||
%c0 = constant 0 : index
|
||||
%c0_i32 = constant 0 : i32
|
||||
%c0_i64 = constant 0 : i64
|
||||
%c1_i32 = constant 1 : i32
|
||||
%c1_i64 = constant 1 : i64
|
||||
%c32_i32 = constant 32 : i32
|
||||
%cst = constant 1.000000e+00 : f64
|
||||
br ^bb1(%c0_i64 : i64)
|
||||
^bb1(%0: i64): // 2 preds: ^bb0, ^bb2
|
||||
%1 = subi %arg0, %c1_i32 : i32
|
||||
%2 = cmpi "slt", %1, %c0_i32 : i32
|
||||
%3 = scf.if %2 -> (i32) {
|
||||
%14 = subi %c0_i32, %1 : i32
|
||||
%15 = addi %14, %c32_i32 : i32
|
||||
%16 = subi %15, %c1_i32 : i32
|
||||
%17 = divi_signed %16, %c32_i32 : i32
|
||||
%18 = subi %c0_i32, %17 : i32
|
||||
scf.yield %18 : i32
|
||||
} else {
|
||||
%14 = divi_signed %1, %c32_i32 : i32
|
||||
scf.yield %14 : i32
|
||||
}
|
||||
%4 = sexti %3 : i32 to i64
|
||||
%5 = cmpi "sle", %0, %4 : i64
|
||||
cond_br %5, ^bb2, ^bb3
|
||||
^bb2: // pred: ^bb1
|
||||
%6 = load %arg1[%c0] : memref<?xf64>
|
||||
%7 = mulf %6, %cst : f64
|
||||
store %7, %arg1[%c0] : memref<?xf64>
|
||||
%8 = addi %0, %c1_i64 : i64
|
||||
br ^bb1(%8 : i64)
|
||||
^bb3: // pred: ^bb1
|
||||
%9 = scf.if %2 -> (i32) {
|
||||
%14 = subi %c0_i32, %1 : i32
|
||||
%15 = addi %14, %c32_i32 : i32
|
||||
%16 = subi %15, %c1_i32 : i32
|
||||
%17 = divi_signed %16, %c32_i32 : i32
|
||||
%18 = subi %c0_i32, %17 : i32
|
||||
scf.yield %18 : i32
|
||||
} else {
|
||||
%14 = divi_signed %1, %c32_i32 : i32
|
||||
scf.yield %14 : i32
|
||||
}
|
||||
%10 = sexti %9 : i32 to i64
|
||||
%11 = index_cast %10 : i64 to index
|
||||
%12 = load %arg1[%11] : memref<?xf64>
|
||||
%13 = addf %12, %cst : f64
|
||||
store %13, %arg1[%11] : memref<?xf64>
|
||||
return
|
||||
}
|
||||
*/
|
||||
#include "PassDetails.h"
|
||||
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
|
@ -445,9 +390,31 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region ®ion) {
|
|||
for (auto arg : header->getArguments()) {
|
||||
headerArgumentTypes.push_back(arg.getType());
|
||||
}
|
||||
// TODO values used outside loop should be wrapped.
|
||||
wrapper->addArguments(headerArgumentTypes);
|
||||
|
||||
SmallVector<Value> valsCallingLoop(wrapper->getArguments().begin(), wrapper->getArguments().end());
|
||||
|
||||
SmallVector<std::pair<Value, size_t>> preservedVals;
|
||||
for (auto B : L->getBlocks()) {
|
||||
for (auto &O : *(Block*)B) {
|
||||
for (auto V : O.getResults()) {
|
||||
if (llvm::any_of(V.getUsers(), [&](Operation *user) {
|
||||
Block* blk = user->getBlock();
|
||||
while (blk->getParent() != ®ion)
|
||||
blk = blk->getParentOp()->getBlock();
|
||||
return !L->contains((Wrapper*)blk);
|
||||
})) {
|
||||
preservedVals.emplace_back(V, headerArgumentTypes.size());
|
||||
headerArgumentTypes.push_back(V.getType());
|
||||
valsCallingLoop.push_back(builder.create<mlir::LLVM::UndefOp>(builder.getUnknownLoc(), V.getType()));
|
||||
header->addArgument(V.getType());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO values used outside loop should be wrapped.
|
||||
|
||||
SmallVector<Type, 4> combinedTypes(headerArgumentTypes.begin(),
|
||||
headerArgumentTypes.end());
|
||||
SmallVector<Type, 4> returns;
|
||||
|
@ -457,7 +424,7 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region ®ion) {
|
|||
}
|
||||
|
||||
auto loop = builder.create<mlir::scf::WhileOp>(
|
||||
builder.getUnknownLoc(), combinedTypes, wrapper->getArguments());
|
||||
builder.getUnknownLoc(), combinedTypes, valsCallingLoop);
|
||||
{
|
||||
SmallVector<Value, 4> RetVals;
|
||||
for (size_t i = 0; i < returns.size(); ++i) {
|
||||
|
@ -465,6 +432,15 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region ®ion) {
|
|||
}
|
||||
builder.create<BranchOp>(builder.getUnknownLoc(), target, RetVals);
|
||||
}
|
||||
for (auto& pair : preservedVals) {
|
||||
pair.first.replaceUsesWithIf(loop.getResult(pair.second),
|
||||
[&](OpOperand &op) -> bool {
|
||||
Block* blk = op.getOwner()->getBlock();
|
||||
while (blk->getParent() != ®ion)
|
||||
blk = blk->getParentOp()->getBlock();
|
||||
return !L->contains((Wrapper*)blk);
|
||||
});
|
||||
}
|
||||
|
||||
SmallVector<Block *, 4> Preds;
|
||||
|
||||
|
@ -493,8 +469,7 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region ®ion) {
|
|||
pseudoExit->addArguments(tys);
|
||||
OpBuilder builder(pseudoExit, pseudoExit->begin());
|
||||
tys.clear();
|
||||
builder.create<scf::ConditionOp>(builder.getUnknownLoc(), tys,
|
||||
pseudoExit->getArguments());
|
||||
builder.create<scf::ConditionOp>(builder.getUnknownLoc(), tys, pseudoExit->getArguments());
|
||||
}
|
||||
|
||||
for (auto *w : exitingBlocks) {
|
||||
|
@ -511,6 +486,8 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region ®ion) {
|
|||
std::vector<Value> args = {vfalse};
|
||||
for (auto arg : header->getArguments())
|
||||
args.push_back(arg);
|
||||
for(auto v : preservedVals)
|
||||
args[v.second + 1] = v.first;
|
||||
|
||||
if (auto op = dyn_cast<BranchOp>(terminator)) {
|
||||
args.insert(args.end(), op.getOperands().begin(),
|
||||
|
@ -564,6 +541,8 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region ®ion) {
|
|||
std::vector<Value> args(op.getOperands().begin(),
|
||||
op.getOperands().end());
|
||||
args.insert(args.begin(), vtrue);
|
||||
for (auto pair : preservedVals)
|
||||
args.push_back(pair.first);
|
||||
for (auto ty : returns) {
|
||||
// args.push_back(builder.create<mlir::LLVM::UndefOp>(builder.getUnknownLoc(),
|
||||
// ty));
|
||||
|
@ -581,6 +560,8 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region ®ion) {
|
|||
op.getFalseOperands().end());
|
||||
if (op.getTrueDest() == header) {
|
||||
trueargs.insert(trueargs.begin(), vtrue);
|
||||
for (auto pair : preservedVals)
|
||||
trueargs.push_back(pair.first);
|
||||
for (auto ty : returns) {
|
||||
trueargs.push_back(builder.create<mlir::LLVM::UndefOp>(
|
||||
builder.getUnknownLoc(), ty));
|
||||
|
@ -588,6 +569,8 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region ®ion) {
|
|||
}
|
||||
if (op.getFalseDest() == header) {
|
||||
falseargs.insert(falseargs.begin(), vtrue);
|
||||
for (auto pair : preservedVals)
|
||||
falseargs.push_back(pair.first);
|
||||
for (auto ty : returns) {
|
||||
falseargs.push_back(builder.create<mlir::LLVM::UndefOp>(
|
||||
builder.getUnknownLoc(), ty));
|
||||
|
|
|
@ -1 +1 @@
|
|||
Subproject commit 982a69616c0000ca5b04abddeddd4b95bb1c590f
|
||||
Subproject commit ca8997eb7f6858768c58f538de3a5c85c8fad7ea
|
|
@ -0,0 +1,107 @@
|
|||
// RUN: polygeist-opt --loop-restructure --split-input-file %s | FileCheck %s
|
||||
|
||||
module {
|
||||
func @kernel_gemm(%arg0: i64) -> i1 {
|
||||
%c0_i64 = arith.constant 0 : i64
|
||||
%c1_i64 = arith.constant 1 : i64
|
||||
br ^bb1(%c0_i64 : i64)
|
||||
^bb1(%0: i64): // 2 preds: ^bb0, ^bb2
|
||||
%2 = arith.cmpi "slt", %0, %c0_i64 : i64
|
||||
%5 = arith.cmpi "sle", %0, %arg0 : i64
|
||||
cond_br %5, ^bb2, ^bb3
|
||||
^bb2: // pred: ^bb1
|
||||
%8 = arith.addi %0, %c1_i64 : i64
|
||||
br ^bb1(%8 : i64)
|
||||
^bb3: // pred: ^bb1
|
||||
return %2 : i1
|
||||
}
|
||||
|
||||
|
||||
// CHECK: func @kernel_gemm(%arg0: i64) -> i1 {
|
||||
// CHECK-NEXT: %c0_i64 = arith.constant 0 : i64
|
||||
// CHECK-NEXT: %c1_i64 = arith.constant 1 : i64
|
||||
// CHECK-NEXT: %0 = llvm.mlir.undef : i1
|
||||
// CHECK-NEXT: %1:2 = scf.while (%arg1 = %c0_i64, %arg2 = %0) : (i64, i1) -> (i64, i1) {
|
||||
// CHECK-NEXT: %2 = arith.cmpi slt, %arg1, %c0_i64 : i64
|
||||
// CHECK-NEXT: %3 = arith.cmpi sle, %arg1, %arg0 : i64
|
||||
// CHECK-NEXT: %false = arith.constant false
|
||||
// CHECK-NEXT: %4:3 = scf.if %3 -> (i1, i64, i1) {
|
||||
// CHECK-NEXT: %5 = arith.addi %arg1, %c1_i64 : i64
|
||||
// CHECK-NEXT: %true = arith.constant true
|
||||
// CHECK-NEXT: scf.yield %true, %5, %2 : i1, i64, i1
|
||||
// CHECK-NEXT: } else {
|
||||
// CHECK-NEXT: scf.yield %false, %arg1, %2 : i1, i64, i1
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.condition(%4#0) %4#1, %4#2 : i64, i1
|
||||
// CHECK-NEXT: } do {
|
||||
// CHECK-NEXT: ^bb0(%arg1: i64, %arg2: i1): // no predecessors
|
||||
// CHECK-NEXT: scf.yield %arg1, %arg2 : i64, i1
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return %1#1 : i1
|
||||
// CHECK-NEXT: }
|
||||
|
||||
|
||||
func @gcd(%arg0: i32, %arg1: i32) -> i32 {
|
||||
%c0_i32 = arith.constant 0 : i32
|
||||
%true = arith.constant true
|
||||
%0 = memref.alloca() : memref<i32>
|
||||
%1 = memref.alloca() : memref<i32>
|
||||
%2 = memref.alloca() : memref<i32>
|
||||
memref.store %arg0, %2[] : memref<i32>
|
||||
memref.store %arg1, %1[] : memref<i32>
|
||||
br ^bb1
|
||||
^bb1: // 2 preds: ^bb0, ^bb2
|
||||
%3 = memref.load %1[] : memref<i32>
|
||||
%4 = arith.cmpi sgt, %3, %c0_i32 : i32
|
||||
cond_br %4, ^bb2, ^bb3
|
||||
^bb2: // pred: ^bb1
|
||||
%5 = memref.load %0[] : memref<i32>
|
||||
%8 = memref.load %2[] : memref<i32>
|
||||
%9 = arith.remsi %8, %3 : i32
|
||||
scf.if %true {
|
||||
memref.store %9, %0[] : memref<i32>
|
||||
}
|
||||
memref.store %3, %2[] : memref<i32>
|
||||
memref.store %9, %1[] : memref<i32>
|
||||
br ^bb1
|
||||
^bb3: // pred: ^bb1
|
||||
%7 = memref.load %2[] : memref<i32>
|
||||
return %7 : i32
|
||||
}
|
||||
|
||||
|
||||
// CHECK: func @gcd(%arg0: i32, %arg1: i32) -> i32 {
|
||||
// CHECK-DAG: %c0_i32 = arith.constant 0 : i32
|
||||
// CHECK-DAG: %true = arith.constant true
|
||||
// CHECK-NEXT: %0 = memref.alloca() : memref<i32>
|
||||
// CHECK-NEXT: %1 = memref.alloca() : memref<i32>
|
||||
// CHECK-NEXT: %2 = memref.alloca() : memref<i32>
|
||||
// CHECK-NEXT: memref.store %arg0, %2[] : memref<i32>
|
||||
// CHECK-NEXT: memref.store %arg1, %1[] : memref<i32>
|
||||
// CHECK-NEXT: scf.while : () -> () {
|
||||
// CHECK-NEXT: %4 = memref.load %1[] : memref<i32>
|
||||
// CHECK-NEXT: %5 = arith.cmpi sgt, %4, %c0_i32 : i32
|
||||
// CHECK-NEXT: %false = arith.constant false
|
||||
// CHECK-NEXT: %6 = scf.if %5 -> (i1) {
|
||||
// CHECK-NEXT: %7 = memref.load %0[] : memref<i32>
|
||||
// CHECK-NEXT: %8 = memref.load %2[] : memref<i32>
|
||||
// CHECK-NEXT: %9 = arith.remsi %8, %4 : i32
|
||||
// CHECK-NEXT: scf.if %true {
|
||||
// CHECK-NEXT: memref.store %9, %0[] : memref<i32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: memref.store %4, %2[] : memref<i32>
|
||||
// CHECK-NEXT: memref.store %9, %1[] : memref<i32>
|
||||
// CHECK-NEXT: %true_0 = arith.constant true
|
||||
// CHECK-NEXT: scf.yield %true_0 : i1
|
||||
// CHECK-NEXT: } else {
|
||||
// CHECK-NEXT: scf.yield %false : i1
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.condition(%6)
|
||||
// CHECK-NEXT: } do {
|
||||
// CHECK-NEXT: scf.yield
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: %3 = memref.load %2[] : memref<i32>
|
||||
// CHECK-NEXT: return %3 : i32
|
||||
// CHECK-NEXT: }
|
||||
|
||||
}
|
Loading…
Reference in New Issue