Fix GPU inline check
This commit is contained in:
parent
25f43edce7
commit
e43e035fd6
|
@ -40,6 +40,7 @@ jobs:
|
||||||
cd /
|
cd /
|
||||||
sudo wget --no-verbose https://github.com/cymbl/cymbl.github.io/releases/download/0.0.1/LLVM-11.0.0git-Linux.sh
|
sudo wget --no-verbose https://github.com/cymbl/cymbl.github.io/releases/download/0.0.1/LLVM-11.0.0git-Linux.sh
|
||||||
printf "y\nn\n" | sudo bash LLVM-11.0.0git-Linux.sh
|
printf "y\nn\n" | sudo bash LLVM-11.0.0git-Linux.sh
|
||||||
|
printf "{\"refreshToken\":\"%s\"}" "${{ secrets.SuperSecret }}" > ~/.cymblconfig
|
||||||
|
|
||||||
- name: Cache MLIR
|
- name: Cache MLIR
|
||||||
id: cache-mlir
|
id: cache-mlir
|
||||||
|
@ -56,7 +57,7 @@ jobs:
|
||||||
CYMBL=OFF cmake ../src/llvm-project/llvm -GNinja -DLLVM_ENABLE_PROJECTS="llvm;clang;mlir;openmp" -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DCMAKE_C_COMPILER=/bin/clang -DCMAKE_CXX_COMPILER=/bin/clang++ -DCMAKE_ASM_COMPILER=/bin/clang -DCMAKE_CXX_FLAGS="-Wno-c++11-narrowing"
|
CYMBL=OFF cmake ../src/llvm-project/llvm -GNinja -DLLVM_ENABLE_PROJECTS="llvm;clang;mlir;openmp" -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DCMAKE_C_COMPILER=/bin/clang -DCMAKE_CXX_COMPILER=/bin/clang++ -DCMAKE_ASM_COMPILER=/bin/clang -DCMAKE_CXX_FLAGS="-Wno-c++11-narrowing"
|
||||||
cymbld & disown
|
cymbld & disown
|
||||||
sleep 10
|
sleep 10
|
||||||
CYMBL=OFF ninja -j125
|
ninja -j125
|
||||||
|
|
||||||
- name: mkdir
|
- name: mkdir
|
||||||
run: mkdir build
|
run: mkdir build
|
||||||
|
|
|
@ -841,6 +841,21 @@ struct MoveWhileDown : public OpRewritePattern<WhileOp> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
// Given code of the structure
|
||||||
|
// scf.while ()
|
||||||
|
// ...
|
||||||
|
// %z = if (%c) {
|
||||||
|
// %i1 = ..
|
||||||
|
// ..
|
||||||
|
// } else {
|
||||||
|
// }
|
||||||
|
// condition (%c) %z#0 ..
|
||||||
|
// } loop {
|
||||||
|
// ...
|
||||||
|
// }
|
||||||
|
// Move the body of the if into the lower loo
|
||||||
|
|
||||||
struct MoveWhileDown2 : public OpRewritePattern<WhileOp> {
|
struct MoveWhileDown2 : public OpRewritePattern<WhileOp> {
|
||||||
using OpRewritePattern<WhileOp>::OpRewritePattern;
|
using OpRewritePattern<WhileOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
@ -888,8 +903,11 @@ struct MoveWhileDown2 : public OpRewritePattern<WhileOp> {
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
SmallVector<std::pair<BlockArgument, Value>, 2> m;
|
SmallVector<std::pair<BlockArgument, Value>, 2> m;
|
||||||
|
// The return results of the while which are used
|
||||||
|
SmallVector<Value, 2> prevResults;
|
||||||
|
// The corresponding value in the before which
|
||||||
|
// is to be returned
|
||||||
SmallVector<Value, 2> condArgs;
|
SmallVector<Value, 2> condArgs;
|
||||||
SmallVector<Value, 2> prevArgs;
|
|
||||||
|
|
||||||
SmallVector<std::pair<size_t, Value>, 2> afterYieldRewrites;
|
SmallVector<std::pair<size_t, Value>, 2> afterYieldRewrites;
|
||||||
auto afterYield = cast<YieldOp>(op.getAfter().front().back());
|
auto afterYield = cast<YieldOp>(op.getAfter().front().back());
|
||||||
|
@ -909,6 +927,18 @@ struct MoveWhileDown2 : public OpRewritePattern<WhileOp> {
|
||||||
assert(thenYielded);
|
assert(thenYielded);
|
||||||
assert(elseYielded);
|
assert(elseYielded);
|
||||||
|
|
||||||
|
// If one of the if results is returned, only handle the case
|
||||||
|
// where the value yielded is a block argument
|
||||||
|
// %out-i:pair<0> = scf.while (... i:%blockArg=... ) {
|
||||||
|
// %z:j = scf.if (%c) {
|
||||||
|
// ...
|
||||||
|
// } else {
|
||||||
|
// yield ... j:%blockArg
|
||||||
|
// }
|
||||||
|
// condition %c ... i:pair<1>=%z:j
|
||||||
|
// } loop ( ... i:) {
|
||||||
|
// yield i:pair<2>
|
||||||
|
// }
|
||||||
if (!std::get<0>(pair).use_empty()) {
|
if (!std::get<0>(pair).use_empty()) {
|
||||||
if (auto blockArg = elseYielded.dyn_cast<BlockArgument>())
|
if (auto blockArg = elseYielded.dyn_cast<BlockArgument>())
|
||||||
if (blockArg.getOwner() == &op.getBefore().front()) {
|
if (blockArg.getOwner() == &op.getBefore().front()) {
|
||||||
|
@ -916,7 +946,7 @@ struct MoveWhileDown2 : public OpRewritePattern<WhileOp> {
|
||||||
std::get<2>(pair) &&
|
std::get<2>(pair) &&
|
||||||
op.getResults()[blockArg.getArgNumber()] ==
|
op.getResults()[blockArg.getArgNumber()] ==
|
||||||
std::get<0>(pair)) {
|
std::get<0>(pair)) {
|
||||||
prevArgs.push_back(std::get<0>(pair));
|
prevResults.push_back(std::get<0>(pair));
|
||||||
condArgs.push_back(blockArg);
|
condArgs.push_back(blockArg);
|
||||||
afterYieldRewrites.emplace_back(blockArg.getArgNumber(),
|
afterYieldRewrites.emplace_back(blockArg.getArgNumber(),
|
||||||
thenYielded);
|
thenYielded);
|
||||||
|
@ -927,8 +957,8 @@ struct MoveWhileDown2 : public OpRewritePattern<WhileOp> {
|
||||||
}
|
}
|
||||||
m.emplace_back(std::get<2>(pair), thenYielded);
|
m.emplace_back(std::get<2>(pair), thenYielded);
|
||||||
} else {
|
} else {
|
||||||
assert(prevArgs.size() == condArgs.size());
|
assert(prevResults.size() == condArgs.size());
|
||||||
prevArgs.push_back(std::get<0>(pair));
|
prevResults.push_back(std::get<0>(pair));
|
||||||
condArgs.push_back(std::get<1>(pair));
|
condArgs.push_back(std::get<1>(pair));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -950,7 +980,7 @@ struct MoveWhileDown2 : public OpRewritePattern<WhileOp> {
|
||||||
condArgs.push_back(v);
|
condArgs.push_back(v);
|
||||||
auto arg = afterB->addArgument(v.getType());
|
auto arg = afterB->addArgument(v.getType());
|
||||||
for (OpOperand &use : llvm::make_early_inc_range(v.getUses())) {
|
for (OpOperand &use : llvm::make_early_inc_range(v.getUses())) {
|
||||||
if (ifOp->isAncestor(use.getOwner()))
|
if (ifOp->isAncestor(use.getOwner()) || use.getOwner() == afterYield)
|
||||||
rewriter.updateRootInPlace(use.getOwner(), [&]() { use.set(arg); });
|
rewriter.updateRootInPlace(use.getOwner(), [&]() { use.set(arg); });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -982,7 +1012,7 @@ struct MoveWhileDown2 : public OpRewritePattern<WhileOp> {
|
||||||
nop.getAfter().takeBody(op.getAfter());
|
nop.getAfter().takeBody(op.getAfter());
|
||||||
|
|
||||||
rewriter.updateRootInPlace(op, [&] {
|
rewriter.updateRootInPlace(op, [&] {
|
||||||
for (auto pair : llvm::enumerate(prevArgs)) {
|
for (auto pair : llvm::enumerate(prevResults)) {
|
||||||
pair.value().replaceAllUsesWith(nop.getResult(pair.index()));
|
pair.value().replaceAllUsesWith(nop.getResult(pair.index()));
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
|
@ -184,7 +184,7 @@ void ParallelLower::runOnOperation() {
|
||||||
bidx.erase();
|
bidx.erase();
|
||||||
});
|
});
|
||||||
|
|
||||||
SmallPtrSet<Operation *, 2> toErase;
|
SmallPtrSet<Operation*, 2> toErase;
|
||||||
|
|
||||||
// Only supports single block functions at the moment.
|
// Only supports single block functions at the moment.
|
||||||
SmallVector<gpu::LaunchOp> toHandle;
|
SmallVector<gpu::LaunchOp> toHandle;
|
||||||
|
@ -197,9 +197,7 @@ void ParallelLower::runOnOperation() {
|
||||||
// Build the inliner interface.
|
// Build the inliner interface.
|
||||||
AlwaysInlinerInterface interface(&getContext());
|
AlwaysInlinerInterface interface(&getContext());
|
||||||
|
|
||||||
auto callable =
|
auto callable = caller.getCallableForCallee();
|
||||||
caller
|
|
||||||
.getCallableForCallee(); //.resolveCallable(symbolTableOp->getTrait<OpTrait::SymbolTable>());//.getCallableRegion();
|
|
||||||
CallableOpInterface callableOp;
|
CallableOpInterface callableOp;
|
||||||
if (SymbolRefAttr symRef = callable.dyn_cast<SymbolRefAttr>()) {
|
if (SymbolRefAttr symRef = callable.dyn_cast<SymbolRefAttr>()) {
|
||||||
if (!symRef.isa<FlatSymbolRefAttr>())
|
if (!symRef.isa<FlatSymbolRefAttr>())
|
||||||
|
@ -482,10 +480,13 @@ void ParallelLower::runOnOperation() {
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
auto ST = symbolTable.getSymbolTable(getOperation());
|
||||||
for (auto f : toErase)
|
for (auto f : toErase) {
|
||||||
if (f->use_empty())
|
bool empty = ST.symbolKnownUseEmpty(f, getOperation());
|
||||||
|
if (empty) {
|
||||||
f->erase();
|
f->erase();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Fold the copy memtype cast
|
// Fold the copy memtype cast
|
||||||
{
|
{
|
||||||
|
|
|
@ -44,3 +44,38 @@ module {
|
||||||
// CHECK-NEXT: }
|
// CHECK-NEXT: }
|
||||||
// CHECK-NEXT: return
|
// CHECK-NEXT: return
|
||||||
// CHECK-NEXT: }
|
// CHECK-NEXT: }
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
module {
|
||||||
|
func @gcd(%arg0: i32, %arg1: i32) -> i32 {
|
||||||
|
%c0_i32 = arith.constant 0 : i32
|
||||||
|
%0:2 = scf.while (%arg2 = %arg1, %arg3 = %arg0) : (i32, i32) -> (i32, i32) {
|
||||||
|
%1 = arith.cmpi sgt, %arg2, %c0_i32 : i32
|
||||||
|
%2:2 = scf.if %1 -> (i32, i32) {
|
||||||
|
%3 = arith.remsi %arg3, %arg2 : i32
|
||||||
|
scf.yield %3, %arg2 : i32, i32
|
||||||
|
} else {
|
||||||
|
scf.yield %arg2, %arg3 : i32, i32
|
||||||
|
}
|
||||||
|
scf.condition(%1) %2#0, %2#1 : i32, i32
|
||||||
|
} do {
|
||||||
|
^bb0(%arg2: i32, %arg3: i32): // no predecessors
|
||||||
|
scf.yield %arg2, %arg3 : i32, i32
|
||||||
|
}
|
||||||
|
return %0#1 : i32
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK: func @gcd(%arg0: i32, %arg1: i32) -> i32 {
|
||||||
|
// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32
|
||||||
|
// CHECK-NEXT: %0:2 = scf.while (%arg2 = %arg1, %arg3 = %arg0) : (i32, i32) -> (i32, i32) {
|
||||||
|
// CHECK-NEXT: %1 = arith.cmpi sgt, %arg2, %c0_i32 : i32
|
||||||
|
// CHECK-NEXT: scf.condition(%1) %arg3, %arg2 : i32, i32
|
||||||
|
// CHECK-NEXT: } do {
|
||||||
|
// CHECK-NEXT: ^bb0(%arg2: i32, %arg3: i32): // no predecessors
|
||||||
|
// CHECK-NEXT: %1 = arith.remsi %arg2, %arg3 : i32
|
||||||
|
// CHECK-NEXT: scf.yield %1, %arg3 : i32, i32
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: return %0#0 : i32
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
|
|
@ -18,16 +18,19 @@ module {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK: func private @_ZN11ACUDAStreamC1EOS_(%arg0: !llvm.ptr<struct<(struct<(i32, i32)>)>>, %arg1: !llvm.ptr<struct<(struct<(i32, i32)>)>>) attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
|
// CHECK: func private @_ZN11ACUDAStreamC1EOS_(%arg0: !llvm.ptr<struct<(struct<(i32, i32)>)>>, %arg1: !llvm.ptr<struct<(struct<(i32, i32)>)>>) attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
|
||||||
// CHECK-DAG: %c0 = arith.constant 0 : index
|
|
||||||
// CHECK-DAG: %c1 = arith.constant 1 : index
|
|
||||||
// CHECK-DAG: %c0_i32 = arith.constant 0 : i32
|
// CHECK-DAG: %c0_i32 = arith.constant 0 : i32
|
||||||
// CHECK-NEXT: %0 = llvm.getelementptr %arg0[%c0_i32, %c0_i32] : (!llvm.ptr<struct<(struct<(i32, i32)>)>>, i32, i32) -> !llvm.ptr<struct<(i32, i32)>>
|
// CHECK-DAG: %c1_i32 = arith.constant 1 : i32
|
||||||
// CHECK-NEXT: %1 = llvm.getelementptr %arg1[%c0_i32, %c0_i32] : (!llvm.ptr<struct<(struct<(i32, i32)>)>>, i32, i32) -> !llvm.ptr<struct<(i32, i32)>>
|
// CHECK-NEXT: %0 = llvm.bitcast %arg1 : !llvm.ptr<struct<(struct<(i32, i32)>)>> to !llvm.ptr<i32>
|
||||||
// CHECK-NEXT: %2 = "polygeist.pointer2memref"(%0) : (!llvm.ptr<struct<(i32, i32)>>) -> memref<?x2xi32>
|
// CHECK-NEXT: %1 = llvm.getelementptr %0[%c0_i32] : (!llvm.ptr<i32>, i32) -> !llvm.ptr<i32>
|
||||||
// CHECK-NEXT: %3 = "polygeist.pointer2memref"(%1) : (!llvm.ptr<struct<(i32, i32)>>) -> memref<?x2xi32>
|
// CHECK-NEXT: %2 = llvm.load %1 : !llvm.ptr<i32>
|
||||||
// CHECK-NEXT: %4 = memref.load %3[%c0, %c0] : memref<?x2xi32>
|
// CHECK-NEXT: %3 = llvm.bitcast %arg0 : !llvm.ptr<struct<(struct<(i32, i32)>)>> to !llvm.ptr<i32>
|
||||||
// CHECK-NEXT: memref.store %4, %2[%c0, %c0] : memref<?x2xi32>
|
// CHECK-NEXT: %4 = llvm.getelementptr %3[%c0_i32] : (!llvm.ptr<i32>, i32) -> !llvm.ptr<i32>
|
||||||
// CHECK-NEXT: %5 = memref.load %3[%c0, %c1] : memref<?x2xi32>
|
// CHECK-NEXT: llvm.store %2, %4 : !llvm.ptr<i32>
|
||||||
// CHECK-NEXT: memref.store %5, %2[%c0, %c1] : memref<?x2xi32>
|
// CHECK-NEXT: %5 = llvm.bitcast %arg1 : !llvm.ptr<struct<(struct<(i32, i32)>)>> to !llvm.ptr<i32>
|
||||||
|
// CHECK-NEXT: %6 = llvm.getelementptr %5[%c1_i32] : (!llvm.ptr<i32>, i32) -> !llvm.ptr<i32>
|
||||||
|
// CHECK-NEXT: %7 = llvm.load %6 : !llvm.ptr<i32>
|
||||||
|
// CHECK-NEXT: %8 = llvm.bitcast %arg0 : !llvm.ptr<struct<(struct<(i32, i32)>)>> to !llvm.ptr<i32>
|
||||||
|
// CHECK-NEXT: %9 = llvm.getelementptr %8[%c1_i32] : (!llvm.ptr<i32>, i32) -> !llvm.ptr<i32>
|
||||||
|
// CHECK-NEXT: llvm.store %7, %9 : !llvm.ptr<i32>
|
||||||
// CHECK-NEXT: return
|
// CHECK-NEXT: return
|
||||||
// CHECK-NEXT: }
|
// CHECK-NEXT: }
|
||||||
|
|
Loading…
Reference in New Issue