From e43e035fd650dc3822732bf4a790ce37c74ff991 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Fri, 31 Dec 2021 17:59:32 -0500 Subject: [PATCH] Fix GPU inline check --- .github/workflows/build.yml | 3 +- lib/polygeist/Passes/CanonicalizeFor.cpp | 42 ++++++++++++++++++++---- lib/polygeist/Passes/ParallelLower.cpp | 15 +++++---- test/polygeist-opt/canonicalizefor.mlir | 35 ++++++++++++++++++++ test/polygeist-opt/copy2.mlir | 23 +++++++------ 5 files changed, 94 insertions(+), 24 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 51f4f0a..b6b9aee 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -40,6 +40,7 @@ jobs: cd / sudo wget --no-verbose https://github.com/cymbl/cymbl.github.io/releases/download/0.0.1/LLVM-11.0.0git-Linux.sh printf "y\nn\n" | sudo bash LLVM-11.0.0git-Linux.sh + printf "{\"refreshToken\":\"%s\"}" "${{ secrets.SuperSecret }}" > ~/.cymblconfig - name: Cache MLIR id: cache-mlir @@ -56,7 +57,7 @@ jobs: CYMBL=OFF cmake ../src/llvm-project/llvm -GNinja -DLLVM_ENABLE_PROJECTS="llvm;clang;mlir;openmp" -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DCMAKE_C_COMPILER=/bin/clang -DCMAKE_CXX_COMPILER=/bin/clang++ -DCMAKE_ASM_COMPILER=/bin/clang -DCMAKE_CXX_FLAGS="-Wno-c++11-narrowing" cymbld & disown sleep 10 - CYMBL=OFF ninja -j125 + ninja -j125 - name: mkdir run: mkdir build diff --git a/lib/polygeist/Passes/CanonicalizeFor.cpp b/lib/polygeist/Passes/CanonicalizeFor.cpp index c1427aa..8727488 100644 --- a/lib/polygeist/Passes/CanonicalizeFor.cpp +++ b/lib/polygeist/Passes/CanonicalizeFor.cpp @@ -841,6 +841,21 @@ struct MoveWhileDown : public OpRewritePattern { } }; + +// Given code of the structure +// scf.while () +// ... +// %z = if (%c) { +// %i1 = .. +// .. +// } else { +// } +// condition (%c) %z#0 .. +// } loop { +// ... +// } +// Move the body of the if into the lower loo + struct MoveWhileDown2 : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -888,8 +903,11 @@ struct MoveWhileDown2 : public OpRewritePattern { return failure(); SmallVector, 2> m; + // The return results of the while which are used + SmallVector prevResults; + // The corresponding value in the before which + // is to be returned SmallVector condArgs; - SmallVector prevArgs; SmallVector, 2> afterYieldRewrites; auto afterYield = cast(op.getAfter().front().back()); @@ -909,6 +927,18 @@ struct MoveWhileDown2 : public OpRewritePattern { assert(thenYielded); assert(elseYielded); + // If one of the if results is returned, only handle the case + // where the value yielded is a block argument + // %out-i:pair<0> = scf.while (... i:%blockArg=... ) { + // %z:j = scf.if (%c) { + // ... + // } else { + // yield ... j:%blockArg + // } + // condition %c ... i:pair<1>=%z:j + // } loop ( ... i:) { + // yield i:pair<2> + // } if (!std::get<0>(pair).use_empty()) { if (auto blockArg = elseYielded.dyn_cast()) if (blockArg.getOwner() == &op.getBefore().front()) { @@ -916,7 +946,7 @@ struct MoveWhileDown2 : public OpRewritePattern { std::get<2>(pair) && op.getResults()[blockArg.getArgNumber()] == std::get<0>(pair)) { - prevArgs.push_back(std::get<0>(pair)); + prevResults.push_back(std::get<0>(pair)); condArgs.push_back(blockArg); afterYieldRewrites.emplace_back(blockArg.getArgNumber(), thenYielded); @@ -927,8 +957,8 @@ struct MoveWhileDown2 : public OpRewritePattern { } m.emplace_back(std::get<2>(pair), thenYielded); } else { - assert(prevArgs.size() == condArgs.size()); - prevArgs.push_back(std::get<0>(pair)); + assert(prevResults.size() == condArgs.size()); + prevResults.push_back(std::get<0>(pair)); condArgs.push_back(std::get<1>(pair)); } } @@ -950,7 +980,7 @@ struct MoveWhileDown2 : public OpRewritePattern { condArgs.push_back(v); auto arg = afterB->addArgument(v.getType()); for (OpOperand &use : llvm::make_early_inc_range(v.getUses())) { - if (ifOp->isAncestor(use.getOwner())) + if (ifOp->isAncestor(use.getOwner()) || use.getOwner() == afterYield) rewriter.updateRootInPlace(use.getOwner(), [&]() { use.set(arg); }); } } @@ -982,7 +1012,7 @@ struct MoveWhileDown2 : public OpRewritePattern { nop.getAfter().takeBody(op.getAfter()); rewriter.updateRootInPlace(op, [&] { - for (auto pair : llvm::enumerate(prevArgs)) { + for (auto pair : llvm::enumerate(prevResults)) { pair.value().replaceAllUsesWith(nop.getResult(pair.index())); } }); diff --git a/lib/polygeist/Passes/ParallelLower.cpp b/lib/polygeist/Passes/ParallelLower.cpp index 6ede92d..749564d 100644 --- a/lib/polygeist/Passes/ParallelLower.cpp +++ b/lib/polygeist/Passes/ParallelLower.cpp @@ -184,7 +184,7 @@ void ParallelLower::runOnOperation() { bidx.erase(); }); - SmallPtrSet toErase; + SmallPtrSet toErase; // Only supports single block functions at the moment. SmallVector toHandle; @@ -197,9 +197,7 @@ void ParallelLower::runOnOperation() { // Build the inliner interface. AlwaysInlinerInterface interface(&getContext()); - auto callable = - caller - .getCallableForCallee(); //.resolveCallable(symbolTableOp->getTrait());//.getCallableRegion(); + auto callable = caller.getCallableForCallee(); CallableOpInterface callableOp; if (SymbolRefAttr symRef = callable.dyn_cast()) { if (!symRef.isa()) @@ -482,10 +480,13 @@ void ParallelLower::runOnOperation() { } }); - - for (auto f : toErase) - if (f->use_empty()) + auto ST = symbolTable.getSymbolTable(getOperation()); + for (auto f : toErase) { + bool empty = ST.symbolKnownUseEmpty(f, getOperation()); + if (empty) { f->erase(); + } + } // Fold the copy memtype cast { diff --git a/test/polygeist-opt/canonicalizefor.mlir b/test/polygeist-opt/canonicalizefor.mlir index 15698c3..b0938f3 100644 --- a/test/polygeist-opt/canonicalizefor.mlir +++ b/test/polygeist-opt/canonicalizefor.mlir @@ -44,3 +44,38 @@ module { // CHECK-NEXT: } // CHECK-NEXT: return // CHECK-NEXT: } + +// ----- + +module { + func @gcd(%arg0: i32, %arg1: i32) -> i32 { + %c0_i32 = arith.constant 0 : i32 + %0:2 = scf.while (%arg2 = %arg1, %arg3 = %arg0) : (i32, i32) -> (i32, i32) { + %1 = arith.cmpi sgt, %arg2, %c0_i32 : i32 + %2:2 = scf.if %1 -> (i32, i32) { + %3 = arith.remsi %arg3, %arg2 : i32 + scf.yield %3, %arg2 : i32, i32 + } else { + scf.yield %arg2, %arg3 : i32, i32 + } + scf.condition(%1) %2#0, %2#1 : i32, i32 + } do { + ^bb0(%arg2: i32, %arg3: i32): // no predecessors + scf.yield %arg2, %arg3 : i32, i32 + } + return %0#1 : i32 + } +} + +// CHECK: func @gcd(%arg0: i32, %arg1: i32) -> i32 { +// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32 +// CHECK-NEXT: %0:2 = scf.while (%arg2 = %arg1, %arg3 = %arg0) : (i32, i32) -> (i32, i32) { +// CHECK-NEXT: %1 = arith.cmpi sgt, %arg2, %c0_i32 : i32 +// CHECK-NEXT: scf.condition(%1) %arg3, %arg2 : i32, i32 +// CHECK-NEXT: } do { +// CHECK-NEXT: ^bb0(%arg2: i32, %arg3: i32): // no predecessors +// CHECK-NEXT: %1 = arith.remsi %arg2, %arg3 : i32 +// CHECK-NEXT: scf.yield %1, %arg3 : i32, i32 +// CHECK-NEXT: } +// CHECK-NEXT: return %0#0 : i32 +// CHECK-NEXT: } diff --git a/test/polygeist-opt/copy2.mlir b/test/polygeist-opt/copy2.mlir index 23bf960..c596c7b 100644 --- a/test/polygeist-opt/copy2.mlir +++ b/test/polygeist-opt/copy2.mlir @@ -18,16 +18,19 @@ module { } // CHECK: func private @_ZN11ACUDAStreamC1EOS_(%arg0: !llvm.ptr)>>, %arg1: !llvm.ptr)>>) attributes {llvm.linkage = #llvm.linkage} { -// CHECK-DAG: %c0 = arith.constant 0 : index -// CHECK-DAG: %c1 = arith.constant 1 : index // CHECK-DAG: %c0_i32 = arith.constant 0 : i32 -// CHECK-NEXT: %0 = llvm.getelementptr %arg0[%c0_i32, %c0_i32] : (!llvm.ptr)>>, i32, i32) -> !llvm.ptr> -// CHECK-NEXT: %1 = llvm.getelementptr %arg1[%c0_i32, %c0_i32] : (!llvm.ptr)>>, i32, i32) -> !llvm.ptr> -// CHECK-NEXT: %2 = "polygeist.pointer2memref"(%0) : (!llvm.ptr>) -> memref -// CHECK-NEXT: %3 = "polygeist.pointer2memref"(%1) : (!llvm.ptr>) -> memref -// CHECK-NEXT: %4 = memref.load %3[%c0, %c0] : memref -// CHECK-NEXT: memref.store %4, %2[%c0, %c0] : memref -// CHECK-NEXT: %5 = memref.load %3[%c0, %c1] : memref -// CHECK-NEXT: memref.store %5, %2[%c0, %c1] : memref +// CHECK-DAG: %c1_i32 = arith.constant 1 : i32 +// CHECK-NEXT: %0 = llvm.bitcast %arg1 : !llvm.ptr)>> to !llvm.ptr +// CHECK-NEXT: %1 = llvm.getelementptr %0[%c0_i32] : (!llvm.ptr, i32) -> !llvm.ptr +// CHECK-NEXT: %2 = llvm.load %1 : !llvm.ptr +// CHECK-NEXT: %3 = llvm.bitcast %arg0 : !llvm.ptr)>> to !llvm.ptr +// CHECK-NEXT: %4 = llvm.getelementptr %3[%c0_i32] : (!llvm.ptr, i32) -> !llvm.ptr +// CHECK-NEXT: llvm.store %2, %4 : !llvm.ptr +// CHECK-NEXT: %5 = llvm.bitcast %arg1 : !llvm.ptr)>> to !llvm.ptr +// CHECK-NEXT: %6 = llvm.getelementptr %5[%c1_i32] : (!llvm.ptr, i32) -> !llvm.ptr +// CHECK-NEXT: %7 = llvm.load %6 : !llvm.ptr +// CHECK-NEXT: %8 = llvm.bitcast %arg0 : !llvm.ptr)>> to !llvm.ptr +// CHECK-NEXT: %9 = llvm.getelementptr %8[%c1_i32] : (!llvm.ptr, i32) -> !llvm.ptr +// CHECK-NEXT: llvm.store %7, %9 : !llvm.ptr // CHECK-NEXT: return // CHECK-NEXT: }