From 92684ed8443e1b3592575d97826ecbc8394d29e6 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Wed, 29 Dec 2021 22:52:23 -0500 Subject: [PATCH] Fix mem2reg bug --- lib/polygeist/Passes/Mem2Reg.cpp | 11 ++++- lib/polygeist/Passes/ParallelLower.cpp | 62 ++++++++++++++++++-------- test/polygeist-opt/execmem2reg.mlir | 57 +++++++++++++++++++++++ tools/mlir-clang/Lib/clang-mlir.cc | 34 +++++++++----- tools/mlir-clang/mlir-clang.cc | 2 +- 5 files changed, 135 insertions(+), 31 deletions(-) create mode 100644 test/polygeist-opt/execmem2reg.mlir diff --git a/lib/polygeist/Passes/Mem2Reg.cpp b/lib/polygeist/Passes/Mem2Reg.cpp index 42a925c..fce52d0 100644 --- a/lib/polygeist/Passes/Mem2Reg.cpp +++ b/lib/polygeist/Passes/Mem2Reg.cpp @@ -448,6 +448,8 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector idx, for (auto &a : block) { ops.push_back(&a); } + LLVM_DEBUG( llvm::dbgs() << " starting block: "; block.print(llvm::dbgs()); llvm::dbgs() << " with "; + if (lastVal) llvm::dbgs() << lastVal << "\n"; else llvm::dbgs() << " null\n"; ); for (auto a : ops) { if (StoringOperations.count(a)) { if (auto exOp = dyn_cast(a)) { @@ -455,6 +457,7 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector idx, Value thenVal; // = handleBlock(exOp.region().front(), lastVal); lastVal = nullptr; seenSubStore = true; + LLVM_DEBUG( llvm::dbgs() << " zeroing val due to " << exOp << "\n"; ); continue; bool needsAfter = false; @@ -470,8 +473,9 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector idx, }); } } - if (!needsAfter) + if (!needsAfter) { continue; + } Block &then = exOp.region().back(); OpBuilder B(exOp.getContext()); @@ -550,7 +554,8 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector idx, else newLoad = B.create(ifOp.getLoc(), AI); - loadOps.insert(newLoad); + if (!seenSubStore) + loadOps.insert(newLoad); lastVal = newLoad->getResult(0); } @@ -813,6 +818,8 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector idx, }); } } + LLVM_DEBUG( llvm::dbgs() << " ending block: "; block.print(llvm::dbgs()); llvm::dbgs() << " with "; + if (lastVal) llvm::dbgs() << lastVal << "\n"; else llvm::dbgs() << " null\n"; ); return lastStoreInBlock[&block] = lastVal; }; diff --git a/lib/polygeist/Passes/ParallelLower.cpp b/lib/polygeist/Passes/ParallelLower.cpp index d2f3cee..5b07034 100644 --- a/lib/polygeist/Passes/ParallelLower.cpp +++ b/lib/polygeist/Passes/ParallelLower.cpp @@ -144,24 +144,33 @@ struct AlwaysInlinerInterface : public InlinerInterface { }; // TODO -#if 0 -mlir::LLVM::LLVMFuncOp MLIRASTConsumer::GetOrCreateMallocFunction(ModuleOp* module) { - std::string name = "malloc"; - if (llvmFunctions.find(name) != llvmFunctions.end()) { - return llvmFunctions[name]; - } +mlir::LLVM::LLVMFuncOp GetOrCreateMallocFunction(ModuleOp module) { + mlir::OpBuilder builder(module.getContext()); + SymbolTableCollection symbolTable; + if (auto fn = dyn_cast_or_null(symbolTable.lookupSymbolIn(module, builder.getIdentifier("malloc")))) + return fn; auto ctx = module->getContext(); mlir::Type types[] = {mlir::IntegerType::get(ctx, 64)}; auto llvmFnType = LLVM::LLVMFunctionType::get( LLVM::LLVMPointerType::get(mlir::IntegerType::get(ctx, 8)), types, false); LLVM::Linkage lnk = LLVM::Linkage::External; - mlir::OpBuilder builder(module->getContext()); - builder.setInsertionPointToStart(module->getBody()); - return llvmFunctions[name] = builder.create( - module->getLoc(), name, llvmFnType, lnk); + builder.setInsertionPointToStart(module.getBody()); + return builder.create(module.getLoc(), "malloc", llvmFnType, lnk); +} +mlir::LLVM::LLVMFuncOp GetOrCreateFreeFunction(ModuleOp module) { + mlir::OpBuilder builder(module.getContext()); + SymbolTableCollection symbolTable; + if (auto fn = dyn_cast_or_null(symbolTable.lookupSymbolIn(module, builder.getIdentifier("free")))) + return fn; + auto ctx = module->getContext(); + auto llvmFnType = LLVM::LLVMFunctionType::get( + LLVM::LLVMVoidType::get(ctx), ArrayRef(LLVM::LLVMPointerType::get(builder.getI8Type())), false); + + LLVM::Linkage lnk = LLVM::Linkage::External; + builder.setInsertionPointToStart(module.getBody()); + return builder.create(module.getLoc(), "free", llvmFnType, lnk); } -#endif void ParallelLower::runOnOperation() { // The inliner should only be run on operations that define a symbol table, @@ -419,17 +428,34 @@ void ParallelLower::runOnOperation() { OpBuilder bz(call); auto falsev = bz.create(call.getLoc(), false, 1); bz.create(call.getLoc(), call.getOperand(0), - call.getOperand(1), call.getOperand(2), + bz.create(call.getLoc(), bz.getI8Type(), call.getOperand(1)), call.getOperand(2), /*isVolatile*/ falsev); - Value vals[] = {call.getOperand(2)}; - call.replaceAllUsesWith(ArrayRef(vals)); - call.erase(); - /* - } else if (call.getCallee().getValue() == "cudaMalloc") { Value vals[] = {call.getOperand(0)}; call.replaceAllUsesWith(ArrayRef(vals)); call.erase(); - */ + } else if (call.getCallee().getValue() == "cudaMalloc") { + auto mf = GetOrCreateMallocFunction(getOperation()); + OpBuilder bz(call); + Value args[] = {bz.create(call.getLoc(), bz.getI64Type(), call.getOperand(1))}; + mlir::Value alloc = bz.create(call.getLoc(), mf, args).getResult(0); + bz.create(call.getLoc(), alloc, call.getOperand(0)); + { + auto retv = bz.create(call.getLoc(), 0, call.getResult(0).getType().cast().getWidth()); + Value vals[] = {retv}; + call.replaceAllUsesWith(ArrayRef(vals)); + call.erase(); + } + } else if (call.getCallee().getValue() == "cudaFree") { + auto mf = GetOrCreateFreeFunction(getOperation()); + OpBuilder bz(call); + Value args[] = {call.getOperand(0)}; + bz.create(call.getLoc(), mf, args); + { + auto retv = bz.create(call.getLoc(), 0, call.getResult(0).getType().cast().getWidth()); + Value vals[] = {retv}; + call.replaceAllUsesWith(ArrayRef(vals)); + call.erase(); + } } else if (call.getCallee().getValue() == "cudaDeviceSynchronize") { OpBuilder bz(call); auto retv = bz.create(call.getLoc(), 0, call.getResult(0).getType().cast().getWidth()); diff --git a/test/polygeist-opt/execmem2reg.mlir b/test/polygeist-opt/execmem2reg.mlir new file mode 100644 index 0000000..ce293a5 --- /dev/null +++ b/test/polygeist-opt/execmem2reg.mlir @@ -0,0 +1,57 @@ +// RUN: polygeist-opt --mem2reg --split-input-file %s | FileCheck %s + +module { + llvm.func @print(i32) + func @h(%arg7: i1, %arg8: i1, %arg9 : i1) { + %c1_i32 = arith.constant 1 : i32 + %c5_i32 = arith.constant 5 : i32 + %c0_i32 = arith.constant 0 : i32 + %c-1_i32 = arith.constant -1 : i32 + %2 = memref.alloca() : memref + %3 = llvm.mlir.undef : i32 + memref.store %3, %2[] : memref + scf.if %arg8 { + memref.store %c-1_i32, %2[] : memref + scf.execute_region { + memref.store %c0_i32, %2[] : memref + scf.yield + } + scf.if %arg9 { + memref.store %c5_i32, %2[] : memref + } + %23 = memref.load %2[] : memref + llvm.call @print(%23) : (i32) -> () + } + return + } +} + +// CHECK: func @h(%arg0: i1, %arg1: i1, %arg2: i1) +// CHECK-NEXT: %c1_i32 = arith.constant 1 : i32 +// CHECK-NEXT: %c5_i32 = arith.constant 5 : i32 +// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32 +// CHECK-NEXT: %c-1_i32 = arith.constant -1 : i32 +// CHECK-NEXT: %0 = memref.alloca() : memref +// CHECK-NEXT: %1 = llvm.mlir.undef : i32 +// CHECK-NEXT: memref.store %1, %0[] : memref +// CHECK-NEXT: %2:2 = scf.if %arg1 -> (i32, i32) { +// CHECK-NEXT: memref.store %c-1_i32, %0[] : memref +// CHECK-NEXT: scf.execute_region { +// CHECK-NEXT: memref.store %c0_i32, %0[] : memref +// CHECK-NEXT: scf.yield +// CHECK-NEXT: } +// CHECK-NEXT: %3 = memref.load %0[] : memref +// CHECK-NEXT: %4:2 = scf.if %arg2 -> (i32, i32) { +// CHECK-NEXT: memref.store %c5_i32, %0[] : memref +// CHECK-NEXT: scf.yield %c5_i32, %c5_i32 : i32, i32 +// CHECK-NEXT: } else { +// CHECK-NEXT: scf.yield %3, %3 : i32, i32 +// CHECK-NEXT: } +// CHECK-NEXT: llvm.call @print(%4#0) : (i32) -> () +// CHECK-NEXT: scf.yield %4#0, %4#1 : i32, i32 +// CHECK-NEXT: } else { +// CHECK-NEXT: scf.yield %1, %1 : i32, i32 +// CHECK-NEXT: } +// CHECK-NEXT: return +// CHECK-NEXT: } + diff --git a/tools/mlir-clang/Lib/clang-mlir.cc b/tools/mlir-clang/Lib/clang-mlir.cc index c11563b..3d2b9a8 100644 --- a/tools/mlir-clang/Lib/clang-mlir.cc +++ b/tools/mlir-clang/Lib/clang-mlir.cc @@ -2258,6 +2258,30 @@ ValueCategory MLIRScanner::VisitCallExpr(clang::CallExpr *expr) { args[3]); return ValueCategory(args[0], /*isReference*/ false); } + if (sr->getDecl()->getIdentifier() && + (sr->getDecl()->getName() == "memset" || + sr->getDecl()->getName() == "__builtin_memset")) { + std::vector args = { + getLLVM(expr->getArg(0)), getLLVM(expr->getArg(1)), + getLLVM(expr->getArg(2)), /*isVolatile*/ + builder.create(loc, false, 1)}; + + args[1] = builder.create(loc, builder.getI8Type(), args[1]); + builder.create(loc, args[0], args[1], args[2], + args[3]); + return ValueCategory(args[0], /*isReference*/ false); + } + if (sr->getDecl()->getIdentifier() && + (sr->getDecl()->getName() == "memcpy" || + sr->getDecl()->getName() == "__builtin_memcpy")) { + std::vector args = { + getLLVM(expr->getArg(0)), getLLVM(expr->getArg(1)), + getLLVM(expr->getArg(2)), /*isVolatile*/ + builder.create(loc, false, 1)}; + builder.create(loc, args[0], args[1], args[2], + args[3]); + return ValueCategory(args[0], /*isReference*/ false); + } if (sr->getDecl()->getIdentifier() && (sr->getDecl()->getName() == "cudaMemcpy" || sr->getDecl()->getName() == "cudaMemcpyAsync" || @@ -2271,16 +2295,6 @@ ValueCategory MLIRScanner::VisitCallExpr(clang::CallExpr *expr) { while (auto BC = dyn_cast(srcSub)) srcSub = BC->getSubExpr(); - if (sr->getDecl()->getName() == "memcpy" || - sr->getDecl()->getName() == "__builtin_memcpy") { - std::vector args = { - getLLVM(expr->getArg(0)), getLLVM(expr->getArg(1)), - getLLVM(expr->getArg(2)), /*isVolatile*/ - builder.create(loc, false, 1)}; - builder.create(loc, args[0], args[1], args[2], - args[3]); - return ValueCategory(args[0], /*isReference*/ false); - } #if 0 auto dstst = dstSub->getType()->getUnqualifiedDesugaredType(); if (isa(dstst) || isa(dstst)) { diff --git a/tools/mlir-clang/mlir-clang.cc b/tools/mlir-clang/mlir-clang.cc index f255e52..46e0a58 100644 --- a/tools/mlir-clang/mlir-clang.cc +++ b/tools/mlir-clang/mlir-clang.cc @@ -417,7 +417,7 @@ int main(int argc, char **argv) { llvm::errs() << "\n"; } - bool LinkOMP = false; + bool LinkOMP = FOpenMP; pm.enableVerifier(EarlyVerifier); mlir::OpPassManager &optPM = pm.nest(); if (true) {