Handle dominating values in loop detection

This commit is contained in:
William S. Moses 2021-12-29 00:41:32 -05:00 committed by William Moses
parent f2de97724c
commit 75574626d6
3 changed files with 150 additions and 60 deletions

View File

@ -7,61 +7,6 @@
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// TODO fix uses of induction or inner variables outside of loop
/*
see %2 in
func @kernel_gemm(%arg0: i32, %arg1: memref<?xf64>) {
%c0 = constant 0 : index
%c0_i32 = constant 0 : i32
%c0_i64 = constant 0 : i64
%c1_i32 = constant 1 : i32
%c1_i64 = constant 1 : i64
%c32_i32 = constant 32 : i32
%cst = constant 1.000000e+00 : f64
br ^bb1(%c0_i64 : i64)
^bb1(%0: i64): // 2 preds: ^bb0, ^bb2
%1 = subi %arg0, %c1_i32 : i32
%2 = cmpi "slt", %1, %c0_i32 : i32
%3 = scf.if %2 -> (i32) {
%14 = subi %c0_i32, %1 : i32
%15 = addi %14, %c32_i32 : i32
%16 = subi %15, %c1_i32 : i32
%17 = divi_signed %16, %c32_i32 : i32
%18 = subi %c0_i32, %17 : i32
scf.yield %18 : i32
} else {
%14 = divi_signed %1, %c32_i32 : i32
scf.yield %14 : i32
}
%4 = sexti %3 : i32 to i64
%5 = cmpi "sle", %0, %4 : i64
cond_br %5, ^bb2, ^bb3
^bb2: // pred: ^bb1
%6 = load %arg1[%c0] : memref<?xf64>
%7 = mulf %6, %cst : f64
store %7, %arg1[%c0] : memref<?xf64>
%8 = addi %0, %c1_i64 : i64
br ^bb1(%8 : i64)
^bb3: // pred: ^bb1
%9 = scf.if %2 -> (i32) {
%14 = subi %c0_i32, %1 : i32
%15 = addi %14, %c32_i32 : i32
%16 = subi %15, %c1_i32 : i32
%17 = divi_signed %16, %c32_i32 : i32
%18 = subi %c0_i32, %17 : i32
scf.yield %18 : i32
} else {
%14 = divi_signed %1, %c32_i32 : i32
scf.yield %14 : i32
}
%10 = sexti %9 : i32 to i64
%11 = index_cast %10 : i64 to index
%12 = load %arg1[%11] : memref<?xf64>
%13 = addf %12, %cst : f64
store %13, %arg1[%11] : memref<?xf64>
return
}
*/
#include "PassDetails.h" #include "PassDetails.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@ -445,9 +390,31 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region &region) {
for (auto arg : header->getArguments()) { for (auto arg : header->getArguments()) {
headerArgumentTypes.push_back(arg.getType()); headerArgumentTypes.push_back(arg.getType());
} }
// TODO values used outside loop should be wrapped.
wrapper->addArguments(headerArgumentTypes); wrapper->addArguments(headerArgumentTypes);
SmallVector<Value> valsCallingLoop(wrapper->getArguments().begin(), wrapper->getArguments().end());
SmallVector<std::pair<Value, size_t>> preservedVals;
for (auto B : L->getBlocks()) {
for (auto &O : *(Block*)B) {
for (auto V : O.getResults()) {
if (llvm::any_of(V.getUsers(), [&](Operation *user) {
Block* blk = user->getBlock();
while (blk->getParent() != &region)
blk = blk->getParentOp()->getBlock();
return !L->contains((Wrapper*)blk);
})) {
preservedVals.emplace_back(V, headerArgumentTypes.size());
headerArgumentTypes.push_back(V.getType());
valsCallingLoop.push_back(builder.create<mlir::LLVM::UndefOp>(builder.getUnknownLoc(), V.getType()));
header->addArgument(V.getType());
}
}
}
}
// TODO values used outside loop should be wrapped.
SmallVector<Type, 4> combinedTypes(headerArgumentTypes.begin(), SmallVector<Type, 4> combinedTypes(headerArgumentTypes.begin(),
headerArgumentTypes.end()); headerArgumentTypes.end());
SmallVector<Type, 4> returns; SmallVector<Type, 4> returns;
@ -457,7 +424,7 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region &region) {
} }
auto loop = builder.create<mlir::scf::WhileOp>( auto loop = builder.create<mlir::scf::WhileOp>(
builder.getUnknownLoc(), combinedTypes, wrapper->getArguments()); builder.getUnknownLoc(), combinedTypes, valsCallingLoop);
{ {
SmallVector<Value, 4> RetVals; SmallVector<Value, 4> RetVals;
for (size_t i = 0; i < returns.size(); ++i) { for (size_t i = 0; i < returns.size(); ++i) {
@ -465,6 +432,15 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region &region) {
} }
builder.create<BranchOp>(builder.getUnknownLoc(), target, RetVals); builder.create<BranchOp>(builder.getUnknownLoc(), target, RetVals);
} }
for (auto& pair : preservedVals) {
pair.first.replaceUsesWithIf(loop.getResult(pair.second),
[&](OpOperand &op) -> bool {
Block* blk = op.getOwner()->getBlock();
while (blk->getParent() != &region)
blk = blk->getParentOp()->getBlock();
return !L->contains((Wrapper*)blk);
});
}
SmallVector<Block *, 4> Preds; SmallVector<Block *, 4> Preds;
@ -493,8 +469,7 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region &region) {
pseudoExit->addArguments(tys); pseudoExit->addArguments(tys);
OpBuilder builder(pseudoExit, pseudoExit->begin()); OpBuilder builder(pseudoExit, pseudoExit->begin());
tys.clear(); tys.clear();
builder.create<scf::ConditionOp>(builder.getUnknownLoc(), tys, builder.create<scf::ConditionOp>(builder.getUnknownLoc(), tys, pseudoExit->getArguments());
pseudoExit->getArguments());
} }
for (auto *w : exitingBlocks) { for (auto *w : exitingBlocks) {
@ -511,6 +486,8 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region &region) {
std::vector<Value> args = {vfalse}; std::vector<Value> args = {vfalse};
for (auto arg : header->getArguments()) for (auto arg : header->getArguments())
args.push_back(arg); args.push_back(arg);
for(auto v : preservedVals)
args[v.second + 1] = v.first;
if (auto op = dyn_cast<BranchOp>(terminator)) { if (auto op = dyn_cast<BranchOp>(terminator)) {
args.insert(args.end(), op.getOperands().begin(), args.insert(args.end(), op.getOperands().begin(),
@ -564,6 +541,8 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region &region) {
std::vector<Value> args(op.getOperands().begin(), std::vector<Value> args(op.getOperands().begin(),
op.getOperands().end()); op.getOperands().end());
args.insert(args.begin(), vtrue); args.insert(args.begin(), vtrue);
for (auto pair : preservedVals)
args.push_back(pair.first);
for (auto ty : returns) { for (auto ty : returns) {
// args.push_back(builder.create<mlir::LLVM::UndefOp>(builder.getUnknownLoc(), // args.push_back(builder.create<mlir::LLVM::UndefOp>(builder.getUnknownLoc(),
// ty)); // ty));
@ -581,6 +560,8 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region &region) {
op.getFalseOperands().end()); op.getFalseOperands().end());
if (op.getTrueDest() == header) { if (op.getTrueDest() == header) {
trueargs.insert(trueargs.begin(), vtrue); trueargs.insert(trueargs.begin(), vtrue);
for (auto pair : preservedVals)
trueargs.push_back(pair.first);
for (auto ty : returns) { for (auto ty : returns) {
trueargs.push_back(builder.create<mlir::LLVM::UndefOp>( trueargs.push_back(builder.create<mlir::LLVM::UndefOp>(
builder.getUnknownLoc(), ty)); builder.getUnknownLoc(), ty));
@ -588,6 +569,8 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region &region) {
} }
if (op.getFalseDest() == header) { if (op.getFalseDest() == header) {
falseargs.insert(falseargs.begin(), vtrue); falseargs.insert(falseargs.begin(), vtrue);
for (auto pair : preservedVals)
falseargs.push_back(pair.first);
for (auto ty : returns) { for (auto ty : returns) {
falseargs.push_back(builder.create<mlir::LLVM::UndefOp>( falseargs.push_back(builder.create<mlir::LLVM::UndefOp>(
builder.getUnknownLoc(), ty)); builder.getUnknownLoc(), ty));

@ -1 +1 @@
Subproject commit 982a69616c0000ca5b04abddeddd4b95bb1c590f Subproject commit ca8997eb7f6858768c58f538de3a5c85c8fad7ea

View File

@ -0,0 +1,107 @@
// RUN: polygeist-opt --loop-restructure --split-input-file %s | FileCheck %s
module {
// Test case: a value computed inside a CFG loop is used after the loop.
// The loop is formed by ^bb1 (header) and ^bb2 (latch); ^bb3 is the exit.
// %2 is defined in ^bb1 but returned from ^bb3, so the expected output
// below carries it out of the scf.while as a second result whose initial
// operand is an llvm.mlir.undef.
func @kernel_gemm(%arg0: i64) -> i1 {
%c0_i64 = arith.constant 0 : i64
%c1_i64 = arith.constant 1 : i64
br ^bb1(%c0_i64 : i64)
^bb1(%0: i64): // 2 preds: ^bb0, ^bb2
// %2 is the loop-defined value that is consumed outside the loop (in ^bb3).
%2 = arith.cmpi "slt", %0, %c0_i64 : i64
// %5 is the loop condition: keep iterating while %0 <= %arg0.
%5 = arith.cmpi "sle", %0, %arg0 : i64
cond_br %5, ^bb2, ^bb3
^bb2: // pred: ^bb1
// Increment the induction value and branch back to the header.
%8 = arith.addi %0, %c1_i64 : i64
br ^bb1(%8 : i64)
^bb3: // pred: ^bb1
// Use of %2 outside the loop blocks.
return %2 : i1
}
// CHECK: func @kernel_gemm(%arg0: i64) -> i1 {
// CHECK-NEXT: %c0_i64 = arith.constant 0 : i64
// CHECK-NEXT: %c1_i64 = arith.constant 1 : i64
// CHECK-NEXT: %0 = llvm.mlir.undef : i1
// CHECK-NEXT: %1:2 = scf.while (%arg1 = %c0_i64, %arg2 = %0) : (i64, i1) -> (i64, i1) {
// CHECK-NEXT: %2 = arith.cmpi slt, %arg1, %c0_i64 : i64
// CHECK-NEXT: %3 = arith.cmpi sle, %arg1, %arg0 : i64
// CHECK-NEXT: %false = arith.constant false
// CHECK-NEXT: %4:3 = scf.if %3 -> (i1, i64, i1) {
// CHECK-NEXT: %5 = arith.addi %arg1, %c1_i64 : i64
// CHECK-NEXT: %true = arith.constant true
// CHECK-NEXT: scf.yield %true, %5, %2 : i1, i64, i1
// CHECK-NEXT: } else {
// CHECK-NEXT: scf.yield %false, %arg1, %2 : i1, i64, i1
// CHECK-NEXT: }
// CHECK-NEXT: scf.condition(%4#0) %4#1, %4#2 : i64, i1
// CHECK-NEXT: } do {
// CHECK-NEXT: ^bb0(%arg1: i64, %arg2: i1): // no predecessors
// CHECK-NEXT: scf.yield %arg1, %arg2 : i64, i1
// CHECK-NEXT: }
// CHECK-NEXT: return %1#1 : i1
// CHECK-NEXT: }
// Test case: a CFG loop (^bb1 header, ^bb2 latch, ^bb3 exit) whose state
// lives entirely in memrefs, with an scf.if nested inside the latch block.
// No value defined in ^bb1 or ^bb2 is used in ^bb3; the only use of %9
// outside its own block's top level is inside the scf.if nested in ^bb2.
// The expected output below is an scf.while with no operands and no results.
// NOTE(review): presumably this guards against treating a user nested in a
// region inside the loop as a use outside the loop -- confirm against the pass.
func @gcd(%arg0: i32, %arg1: i32) -> i32 {
%c0_i32 = arith.constant 0 : i32
%true = arith.constant true
%0 = memref.alloca() : memref<i32>
%1 = memref.alloca() : memref<i32>
%2 = memref.alloca() : memref<i32>
memref.store %arg0, %2[] : memref<i32>
memref.store %arg1, %1[] : memref<i32>
br ^bb1
^bb1: // 2 preds: ^bb0, ^bb2
// Loop condition: continue while the value stored in %1 is > 0.
%3 = memref.load %1[] : memref<i32>
%4 = arith.cmpi sgt, %3, %c0_i32 : i32
cond_br %4, ^bb2, ^bb3
^bb2: // pred: ^bb1
%5 = memref.load %0[] : memref<i32>
%8 = memref.load %2[] : memref<i32>
%9 = arith.remsi %8, %3 : i32
// %9 is used inside a region nested in this loop block.
scf.if %true {
memref.store %9, %0[] : memref<i32>
}
memref.store %3, %2[] : memref<i32>
memref.store %9, %1[] : memref<i32>
br ^bb1
^bb3: // pred: ^bb1
// The result is re-loaded from memory; no loop-defined SSA value is used here.
%7 = memref.load %2[] : memref<i32>
return %7 : i32
}
// CHECK: func @gcd(%arg0: i32, %arg1: i32) -> i32 {
// CHECK-DAG: %c0_i32 = arith.constant 0 : i32
// CHECK-DAG: %true = arith.constant true
// CHECK-NEXT: %0 = memref.alloca() : memref<i32>
// CHECK-NEXT: %1 = memref.alloca() : memref<i32>
// CHECK-NEXT: %2 = memref.alloca() : memref<i32>
// CHECK-NEXT: memref.store %arg0, %2[] : memref<i32>
// CHECK-NEXT: memref.store %arg1, %1[] : memref<i32>
// CHECK-NEXT: scf.while : () -> () {
// CHECK-NEXT: %4 = memref.load %1[] : memref<i32>
// CHECK-NEXT: %5 = arith.cmpi sgt, %4, %c0_i32 : i32
// CHECK-NEXT: %false = arith.constant false
// CHECK-NEXT: %6 = scf.if %5 -> (i1) {
// CHECK-NEXT: %7 = memref.load %0[] : memref<i32>
// CHECK-NEXT: %8 = memref.load %2[] : memref<i32>
// CHECK-NEXT: %9 = arith.remsi %8, %4 : i32
// CHECK-NEXT: scf.if %true {
// CHECK-NEXT: memref.store %9, %0[] : memref<i32>
// CHECK-NEXT: }
// CHECK-NEXT: memref.store %4, %2[] : memref<i32>
// CHECK-NEXT: memref.store %9, %1[] : memref<i32>
// CHECK-NEXT: %true_0 = arith.constant true
// CHECK-NEXT: scf.yield %true_0 : i1
// CHECK-NEXT: } else {
// CHECK-NEXT: scf.yield %false : i1
// CHECK-NEXT: }
// CHECK-NEXT: scf.condition(%6)
// CHECK-NEXT: } do {
// CHECK-NEXT: scf.yield
// CHECK-NEXT: }
// CHECK-NEXT: %3 = memref.load %2[] : memref<i32>
// CHECK-NEXT: return %3 : i32
// CHECK-NEXT: }
}