Fix GPU inline check

This commit is contained in:
William S. Moses 2021-12-31 17:59:32 -05:00 committed by William Moses
parent 25f43edce7
commit e43e035fd6
5 changed files with 94 additions and 24 deletions

View File

@ -40,6 +40,7 @@ jobs:
cd / cd /
sudo wget --no-verbose https://github.com/cymbl/cymbl.github.io/releases/download/0.0.1/LLVM-11.0.0git-Linux.sh sudo wget --no-verbose https://github.com/cymbl/cymbl.github.io/releases/download/0.0.1/LLVM-11.0.0git-Linux.sh
printf "y\nn\n" | sudo bash LLVM-11.0.0git-Linux.sh printf "y\nn\n" | sudo bash LLVM-11.0.0git-Linux.sh
printf "{\"refreshToken\":\"%s\"}" "${{ secrets.SuperSecret }}" > ~/.cymblconfig
- name: Cache MLIR - name: Cache MLIR
id: cache-mlir id: cache-mlir
@ -56,7 +57,7 @@ jobs:
CYMBL=OFF cmake ../src/llvm-project/llvm -GNinja -DLLVM_ENABLE_PROJECTS="llvm;clang;mlir;openmp" -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DCMAKE_C_COMPILER=/bin/clang -DCMAKE_CXX_COMPILER=/bin/clang++ -DCMAKE_ASM_COMPILER=/bin/clang -DCMAKE_CXX_FLAGS="-Wno-c++11-narrowing" CYMBL=OFF cmake ../src/llvm-project/llvm -GNinja -DLLVM_ENABLE_PROJECTS="llvm;clang;mlir;openmp" -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DCMAKE_C_COMPILER=/bin/clang -DCMAKE_CXX_COMPILER=/bin/clang++ -DCMAKE_ASM_COMPILER=/bin/clang -DCMAKE_CXX_FLAGS="-Wno-c++11-narrowing"
cymbld & disown cymbld & disown
sleep 10 sleep 10
CYMBL=OFF ninja -j125 ninja -j125
- name: mkdir - name: mkdir
run: mkdir build run: mkdir build

View File

@ -841,6 +841,21 @@ struct MoveWhileDown : public OpRewritePattern<WhileOp> {
} }
}; };
// Given code of the structure
// scf.while ()
// ...
// %z = if (%c) {
// %i1 = ..
// ..
// } else {
// }
// condition (%c) %z#0 ..
// } loop {
// ...
// }
// Move the body of the if into the lower loop
struct MoveWhileDown2 : public OpRewritePattern<WhileOp> { struct MoveWhileDown2 : public OpRewritePattern<WhileOp> {
using OpRewritePattern<WhileOp>::OpRewritePattern; using OpRewritePattern<WhileOp>::OpRewritePattern;
@ -888,8 +903,11 @@ struct MoveWhileDown2 : public OpRewritePattern<WhileOp> {
return failure(); return failure();
SmallVector<std::pair<BlockArgument, Value>, 2> m; SmallVector<std::pair<BlockArgument, Value>, 2> m;
// The return results of the while which are used
SmallVector<Value, 2> prevResults;
// The corresponding value in the before region that
// is to be returned
SmallVector<Value, 2> condArgs; SmallVector<Value, 2> condArgs;
SmallVector<Value, 2> prevArgs;
SmallVector<std::pair<size_t, Value>, 2> afterYieldRewrites; SmallVector<std::pair<size_t, Value>, 2> afterYieldRewrites;
auto afterYield = cast<YieldOp>(op.getAfter().front().back()); auto afterYield = cast<YieldOp>(op.getAfter().front().back());
@ -909,6 +927,18 @@ struct MoveWhileDown2 : public OpRewritePattern<WhileOp> {
assert(thenYielded); assert(thenYielded);
assert(elseYielded); assert(elseYielded);
// If one of the if results is returned, only handle the case
// where the value yielded is a block argument
// %out-i:pair<0> = scf.while (... i:%blockArg=... ) {
// %z:j = scf.if (%c) {
// ...
// } else {
// yield ... j:%blockArg
// }
// condition %c ... i:pair<1>=%z:j
// } loop ( ... i:) {
// yield i:pair<2>
// }
if (!std::get<0>(pair).use_empty()) { if (!std::get<0>(pair).use_empty()) {
if (auto blockArg = elseYielded.dyn_cast<BlockArgument>()) if (auto blockArg = elseYielded.dyn_cast<BlockArgument>())
if (blockArg.getOwner() == &op.getBefore().front()) { if (blockArg.getOwner() == &op.getBefore().front()) {
@ -916,7 +946,7 @@ struct MoveWhileDown2 : public OpRewritePattern<WhileOp> {
std::get<2>(pair) && std::get<2>(pair) &&
op.getResults()[blockArg.getArgNumber()] == op.getResults()[blockArg.getArgNumber()] ==
std::get<0>(pair)) { std::get<0>(pair)) {
prevArgs.push_back(std::get<0>(pair)); prevResults.push_back(std::get<0>(pair));
condArgs.push_back(blockArg); condArgs.push_back(blockArg);
afterYieldRewrites.emplace_back(blockArg.getArgNumber(), afterYieldRewrites.emplace_back(blockArg.getArgNumber(),
thenYielded); thenYielded);
@ -927,8 +957,8 @@ struct MoveWhileDown2 : public OpRewritePattern<WhileOp> {
} }
m.emplace_back(std::get<2>(pair), thenYielded); m.emplace_back(std::get<2>(pair), thenYielded);
} else { } else {
assert(prevArgs.size() == condArgs.size()); assert(prevResults.size() == condArgs.size());
prevArgs.push_back(std::get<0>(pair)); prevResults.push_back(std::get<0>(pair));
condArgs.push_back(std::get<1>(pair)); condArgs.push_back(std::get<1>(pair));
} }
} }
@ -950,7 +980,7 @@ struct MoveWhileDown2 : public OpRewritePattern<WhileOp> {
condArgs.push_back(v); condArgs.push_back(v);
auto arg = afterB->addArgument(v.getType()); auto arg = afterB->addArgument(v.getType());
for (OpOperand &use : llvm::make_early_inc_range(v.getUses())) { for (OpOperand &use : llvm::make_early_inc_range(v.getUses())) {
if (ifOp->isAncestor(use.getOwner())) if (ifOp->isAncestor(use.getOwner()) || use.getOwner() == afterYield)
rewriter.updateRootInPlace(use.getOwner(), [&]() { use.set(arg); }); rewriter.updateRootInPlace(use.getOwner(), [&]() { use.set(arg); });
} }
} }
@ -982,7 +1012,7 @@ struct MoveWhileDown2 : public OpRewritePattern<WhileOp> {
nop.getAfter().takeBody(op.getAfter()); nop.getAfter().takeBody(op.getAfter());
rewriter.updateRootInPlace(op, [&] { rewriter.updateRootInPlace(op, [&] {
for (auto pair : llvm::enumerate(prevArgs)) { for (auto pair : llvm::enumerate(prevResults)) {
pair.value().replaceAllUsesWith(nop.getResult(pair.index())); pair.value().replaceAllUsesWith(nop.getResult(pair.index()));
} }
}); });

View File

@ -184,7 +184,7 @@ void ParallelLower::runOnOperation() {
bidx.erase(); bidx.erase();
}); });
SmallPtrSet<Operation *, 2> toErase; SmallPtrSet<Operation*, 2> toErase;
// Only supports single block functions at the moment. // Only supports single block functions at the moment.
SmallVector<gpu::LaunchOp> toHandle; SmallVector<gpu::LaunchOp> toHandle;
@ -197,9 +197,7 @@ void ParallelLower::runOnOperation() {
// Build the inliner interface. // Build the inliner interface.
AlwaysInlinerInterface interface(&getContext()); AlwaysInlinerInterface interface(&getContext());
auto callable = auto callable = caller.getCallableForCallee();
caller
.getCallableForCallee(); //.resolveCallable(symbolTableOp->getTrait<OpTrait::SymbolTable>());//.getCallableRegion();
CallableOpInterface callableOp; CallableOpInterface callableOp;
if (SymbolRefAttr symRef = callable.dyn_cast<SymbolRefAttr>()) { if (SymbolRefAttr symRef = callable.dyn_cast<SymbolRefAttr>()) {
if (!symRef.isa<FlatSymbolRefAttr>()) if (!symRef.isa<FlatSymbolRefAttr>())
@ -482,10 +480,13 @@ void ParallelLower::runOnOperation() {
} }
}); });
auto ST = symbolTable.getSymbolTable(getOperation());
for (auto f : toErase) for (auto f : toErase) {
if (f->use_empty()) bool empty = ST.symbolKnownUseEmpty(f, getOperation());
if (empty) {
f->erase(); f->erase();
}
}
// Fold the copy memtype cast // Fold the copy memtype cast
{ {

View File

@ -44,3 +44,38 @@ module {
// CHECK-NEXT: } // CHECK-NEXT: }
// CHECK-NEXT: return // CHECK-NEXT: return
// CHECK-NEXT: } // CHECK-NEXT: }
// -----
module {
func @gcd(%arg0: i32, %arg1: i32) -> i32 {
%c0_i32 = arith.constant 0 : i32
%0:2 = scf.while (%arg2 = %arg1, %arg3 = %arg0) : (i32, i32) -> (i32, i32) {
%1 = arith.cmpi sgt, %arg2, %c0_i32 : i32
%2:2 = scf.if %1 -> (i32, i32) {
%3 = arith.remsi %arg3, %arg2 : i32
scf.yield %3, %arg2 : i32, i32
} else {
scf.yield %arg2, %arg3 : i32, i32
}
scf.condition(%1) %2#0, %2#1 : i32, i32
} do {
^bb0(%arg2: i32, %arg3: i32): // no predecessors
scf.yield %arg2, %arg3 : i32, i32
}
return %0#1 : i32
}
}
// CHECK: func @gcd(%arg0: i32, %arg1: i32) -> i32 {
// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32
// CHECK-NEXT: %0:2 = scf.while (%arg2 = %arg1, %arg3 = %arg0) : (i32, i32) -> (i32, i32) {
// CHECK-NEXT: %1 = arith.cmpi sgt, %arg2, %c0_i32 : i32
// CHECK-NEXT: scf.condition(%1) %arg3, %arg2 : i32, i32
// CHECK-NEXT: } do {
// CHECK-NEXT: ^bb0(%arg2: i32, %arg3: i32): // no predecessors
// CHECK-NEXT: %1 = arith.remsi %arg2, %arg3 : i32
// CHECK-NEXT: scf.yield %1, %arg3 : i32, i32
// CHECK-NEXT: }
// CHECK-NEXT: return %0#0 : i32
// CHECK-NEXT: }

View File

@ -18,16 +18,19 @@ module {
} }
// CHECK: func private @_ZN11ACUDAStreamC1EOS_(%arg0: !llvm.ptr<struct<(struct<(i32, i32)>)>>, %arg1: !llvm.ptr<struct<(struct<(i32, i32)>)>>) attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} { // CHECK: func private @_ZN11ACUDAStreamC1EOS_(%arg0: !llvm.ptr<struct<(struct<(i32, i32)>)>>, %arg1: !llvm.ptr<struct<(struct<(i32, i32)>)>>) attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
// CHECK-DAG: %c0 = arith.constant 0 : index
// CHECK-DAG: %c1 = arith.constant 1 : index
// CHECK-DAG: %c0_i32 = arith.constant 0 : i32 // CHECK-DAG: %c0_i32 = arith.constant 0 : i32
// CHECK-NEXT: %0 = llvm.getelementptr %arg0[%c0_i32, %c0_i32] : (!llvm.ptr<struct<(struct<(i32, i32)>)>>, i32, i32) -> !llvm.ptr<struct<(i32, i32)>> // CHECK-DAG: %c1_i32 = arith.constant 1 : i32
// CHECK-NEXT: %1 = llvm.getelementptr %arg1[%c0_i32, %c0_i32] : (!llvm.ptr<struct<(struct<(i32, i32)>)>>, i32, i32) -> !llvm.ptr<struct<(i32, i32)>> // CHECK-NEXT: %0 = llvm.bitcast %arg1 : !llvm.ptr<struct<(struct<(i32, i32)>)>> to !llvm.ptr<i32>
// CHECK-NEXT: %2 = "polygeist.pointer2memref"(%0) : (!llvm.ptr<struct<(i32, i32)>>) -> memref<?x2xi32> // CHECK-NEXT: %1 = llvm.getelementptr %0[%c0_i32] : (!llvm.ptr<i32>, i32) -> !llvm.ptr<i32>
// CHECK-NEXT: %3 = "polygeist.pointer2memref"(%1) : (!llvm.ptr<struct<(i32, i32)>>) -> memref<?x2xi32> // CHECK-NEXT: %2 = llvm.load %1 : !llvm.ptr<i32>
// CHECK-NEXT: %4 = memref.load %3[%c0, %c0] : memref<?x2xi32> // CHECK-NEXT: %3 = llvm.bitcast %arg0 : !llvm.ptr<struct<(struct<(i32, i32)>)>> to !llvm.ptr<i32>
// CHECK-NEXT: memref.store %4, %2[%c0, %c0] : memref<?x2xi32> // CHECK-NEXT: %4 = llvm.getelementptr %3[%c0_i32] : (!llvm.ptr<i32>, i32) -> !llvm.ptr<i32>
// CHECK-NEXT: %5 = memref.load %3[%c0, %c1] : memref<?x2xi32> // CHECK-NEXT: llvm.store %2, %4 : !llvm.ptr<i32>
// CHECK-NEXT: memref.store %5, %2[%c0, %c1] : memref<?x2xi32> // CHECK-NEXT: %5 = llvm.bitcast %arg1 : !llvm.ptr<struct<(struct<(i32, i32)>)>> to !llvm.ptr<i32>
// CHECK-NEXT: %6 = llvm.getelementptr %5[%c1_i32] : (!llvm.ptr<i32>, i32) -> !llvm.ptr<i32>
// CHECK-NEXT: %7 = llvm.load %6 : !llvm.ptr<i32>
// CHECK-NEXT: %8 = llvm.bitcast %arg0 : !llvm.ptr<struct<(struct<(i32, i32)>)>> to !llvm.ptr<i32>
// CHECK-NEXT: %9 = llvm.getelementptr %8[%c1_i32] : (!llvm.ptr<i32>, i32) -> !llvm.ptr<i32>
// CHECK-NEXT: llvm.store %7, %9 : !llvm.ptr<i32>
// CHECK-NEXT: return // CHECK-NEXT: return
// CHECK-NEXT: } // CHECK-NEXT: }