Fix mem2reg bug

This commit is contained in:
William S. Moses 2021-12-29 22:52:23 -05:00 committed by William Moses
parent 3bfbd54ee5
commit 92684ed844
5 changed files with 135 additions and 31 deletions

View File

@ -448,6 +448,8 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector<ssize_t> idx,
for (auto &a : block) { for (auto &a : block) {
ops.push_back(&a); ops.push_back(&a);
} }
LLVM_DEBUG( llvm::dbgs() << " starting block: "; block.print(llvm::dbgs()); llvm::dbgs() << " with ";
if (lastVal) llvm::dbgs() << lastVal << "\n"; else llvm::dbgs() << " null\n"; );
for (auto a : ops) { for (auto a : ops) {
if (StoringOperations.count(a)) { if (StoringOperations.count(a)) {
if (auto exOp = dyn_cast<mlir::scf::ExecuteRegionOp>(a)) { if (auto exOp = dyn_cast<mlir::scf::ExecuteRegionOp>(a)) {
@ -455,6 +457,7 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector<ssize_t> idx,
Value thenVal; // = handleBlock(exOp.region().front(), lastVal); Value thenVal; // = handleBlock(exOp.region().front(), lastVal);
lastVal = nullptr; lastVal = nullptr;
seenSubStore = true; seenSubStore = true;
LLVM_DEBUG( llvm::dbgs() << " zeroing val due to " << exOp << "\n"; );
continue; continue;
bool needsAfter = false; bool needsAfter = false;
@ -470,8 +473,9 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector<ssize_t> idx,
}); });
} }
} }
if (!needsAfter) if (!needsAfter) {
continue; continue;
}
Block &then = exOp.region().back(); Block &then = exOp.region().back();
OpBuilder B(exOp.getContext()); OpBuilder B(exOp.getContext());
@ -550,7 +554,8 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector<ssize_t> idx,
else else
newLoad = B.create<LLVM::LoadOp>(ifOp.getLoc(), AI); newLoad = B.create<LLVM::LoadOp>(ifOp.getLoc(), AI);
loadOps.insert(newLoad); if (!seenSubStore)
loadOps.insert(newLoad);
lastVal = newLoad->getResult(0); lastVal = newLoad->getResult(0);
} }
@ -813,6 +818,8 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector<ssize_t> idx,
}); });
} }
} }
LLVM_DEBUG( llvm::dbgs() << " ending block: "; block.print(llvm::dbgs()); llvm::dbgs() << " with ";
if (lastVal) llvm::dbgs() << lastVal << "\n"; else llvm::dbgs() << " null\n"; );
return lastStoreInBlock[&block] = lastVal; return lastStoreInBlock[&block] = lastVal;
}; };

View File

@ -144,24 +144,33 @@ struct AlwaysInlinerInterface : public InlinerInterface {
}; };
// TODO // TODO
#if 0 mlir::LLVM::LLVMFuncOp GetOrCreateMallocFunction(ModuleOp module) {
mlir::LLVM::LLVMFuncOp MLIRASTConsumer::GetOrCreateMallocFunction(ModuleOp* module) { mlir::OpBuilder builder(module.getContext());
std::string name = "malloc"; SymbolTableCollection symbolTable;
if (llvmFunctions.find(name) != llvmFunctions.end()) { if (auto fn = dyn_cast_or_null<LLVM::LLVMFuncOp>(symbolTable.lookupSymbolIn(module, builder.getIdentifier("malloc"))))
return llvmFunctions[name]; return fn;
}
auto ctx = module->getContext(); auto ctx = module->getContext();
mlir::Type types[] = {mlir::IntegerType::get(ctx, 64)}; mlir::Type types[] = {mlir::IntegerType::get(ctx, 64)};
auto llvmFnType = LLVM::LLVMFunctionType::get( auto llvmFnType = LLVM::LLVMFunctionType::get(
LLVM::LLVMPointerType::get(mlir::IntegerType::get(ctx, 8)), types, false); LLVM::LLVMPointerType::get(mlir::IntegerType::get(ctx, 8)), types, false);
LLVM::Linkage lnk = LLVM::Linkage::External; LLVM::Linkage lnk = LLVM::Linkage::External;
mlir::OpBuilder builder(module->getContext()); builder.setInsertionPointToStart(module.getBody());
builder.setInsertionPointToStart(module->getBody()); return builder.create<LLVM::LLVMFuncOp>(module.getLoc(), "malloc", llvmFnType, lnk);
return llvmFunctions[name] = builder.create<LLVM::LLVMFuncOp>( }
module->getLoc(), name, llvmFnType, lnk); mlir::LLVM::LLVMFuncOp GetOrCreateFreeFunction(ModuleOp module) {
mlir::OpBuilder builder(module.getContext());
SymbolTableCollection symbolTable;
if (auto fn = dyn_cast_or_null<LLVM::LLVMFuncOp>(symbolTable.lookupSymbolIn(module, builder.getIdentifier("free"))))
return fn;
auto ctx = module->getContext();
auto llvmFnType = LLVM::LLVMFunctionType::get(
LLVM::LLVMVoidType::get(ctx), ArrayRef<mlir::Type>(LLVM::LLVMPointerType::get(builder.getI8Type())), false);
LLVM::Linkage lnk = LLVM::Linkage::External;
builder.setInsertionPointToStart(module.getBody());
return builder.create<LLVM::LLVMFuncOp>(module.getLoc(), "free", llvmFnType, lnk);
} }
#endif
void ParallelLower::runOnOperation() { void ParallelLower::runOnOperation() {
// The inliner should only be run on operations that define a symbol table, // The inliner should only be run on operations that define a symbol table,
@ -419,17 +428,34 @@ void ParallelLower::runOnOperation() {
OpBuilder bz(call); OpBuilder bz(call);
auto falsev = bz.create<ConstantIntOp>(call.getLoc(), false, 1); auto falsev = bz.create<ConstantIntOp>(call.getLoc(), false, 1);
bz.create<LLVM::MemsetOp>(call.getLoc(), call.getOperand(0), bz.create<LLVM::MemsetOp>(call.getLoc(), call.getOperand(0),
call.getOperand(1), call.getOperand(2), bz.create<TruncIOp>(call.getLoc(), bz.getI8Type(), call.getOperand(1)), call.getOperand(2),
/*isVolatile*/ falsev); /*isVolatile*/ falsev);
Value vals[] = {call.getOperand(2)};
call.replaceAllUsesWith(ArrayRef<Value>(vals));
call.erase();
/*
} else if (call.getCallee().getValue() == "cudaMalloc") {
Value vals[] = {call.getOperand(0)}; Value vals[] = {call.getOperand(0)};
call.replaceAllUsesWith(ArrayRef<Value>(vals)); call.replaceAllUsesWith(ArrayRef<Value>(vals));
call.erase(); call.erase();
*/ } else if (call.getCallee().getValue() == "cudaMalloc") {
auto mf = GetOrCreateMallocFunction(getOperation());
OpBuilder bz(call);
Value args[] = {bz.create<arith::ExtUIOp>(call.getLoc(), bz.getI64Type(), call.getOperand(1))};
mlir::Value alloc = bz.create<mlir::LLVM::CallOp>(call.getLoc(), mf, args).getResult(0);
bz.create<LLVM::StoreOp>(call.getLoc(), alloc, call.getOperand(0));
{
auto retv = bz.create<ConstantIntOp>(call.getLoc(), 0, call.getResult(0).getType().cast<IntegerType>().getWidth());
Value vals[] = {retv};
call.replaceAllUsesWith(ArrayRef<Value>(vals));
call.erase();
}
} else if (call.getCallee().getValue() == "cudaFree") {
auto mf = GetOrCreateFreeFunction(getOperation());
OpBuilder bz(call);
Value args[] = {call.getOperand(0)};
bz.create<mlir::LLVM::CallOp>(call.getLoc(), mf, args);
{
auto retv = bz.create<ConstantIntOp>(call.getLoc(), 0, call.getResult(0).getType().cast<IntegerType>().getWidth());
Value vals[] = {retv};
call.replaceAllUsesWith(ArrayRef<Value>(vals));
call.erase();
}
} else if (call.getCallee().getValue() == "cudaDeviceSynchronize") { } else if (call.getCallee().getValue() == "cudaDeviceSynchronize") {
OpBuilder bz(call); OpBuilder bz(call);
auto retv = bz.create<ConstantIntOp>(call.getLoc(), 0, call.getResult(0).getType().cast<IntegerType>().getWidth()); auto retv = bz.create<ConstantIntOp>(call.getLoc(), 0, call.getResult(0).getType().cast<IntegerType>().getWidth());

View File

@ -0,0 +1,57 @@
// RUN: polygeist-opt --mem2reg --split-input-file %s | FileCheck %s
// Test input: a rank-0 i32 stack slot is written from several nested regions
// (a plain store, a store inside scf.execute_region, and a conditional store
// inside a nested scf.if) and then read back once and passed to @print.
module {
// External function that receives the value read back from the slot.
llvm.func @print(i32)
// %arg8 guards the outer scf.if and %arg9 guards the inner one; %arg7 is not
// referenced anywhere in the body.
func @h(%arg7: i1, %arg8: i1, %arg9 : i1) {
// Constants used as stored values below (%c1_i32 is defined but never used).
%c1_i32 = arith.constant 1 : i32
%c5_i32 = arith.constant 5 : i32
%c0_i32 = arith.constant 0 : i32
%c-1_i32 = arith.constant -1 : i32
// The scalar slot under test, initialised with an undef value.
%2 = memref.alloca() : memref<i32>
%3 = llvm.mlir.undef : i32
memref.store %3, %2[] : memref<i32>
scf.if %arg8 {
// First store in the branch; it is overwritten by the store inside the
// execute_region that follows before any load happens.
memref.store %c-1_i32, %2[] : memref<i32>
// Store nested in an scf.execute_region: the slot holds 0 once the region
// has run.
scf.execute_region {
memref.store %c0_i32, %2[] : memref<i32>
scf.yield
}
// Conditional overwrite: the slot holds 5 when %arg9 is true, otherwise it
// keeps the 0 written inside the execute_region.
scf.if %arg9 {
memref.store %c5_i32, %2[] : memref<i32>
}
// The only load of the slot; its result is the value passed to @print.
%23 = memref.load %2[] : memref<i32>
llvm.call @print(%23) : (i32) -> ()
}
return
}
}
// CHECK: func @h(%arg0: i1, %arg1: i1, %arg2: i1)
// CHECK-NEXT: %c1_i32 = arith.constant 1 : i32
// CHECK-NEXT: %c5_i32 = arith.constant 5 : i32
// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32
// CHECK-NEXT: %c-1_i32 = arith.constant -1 : i32
// CHECK-NEXT: %0 = memref.alloca() : memref<i32>
// CHECK-NEXT: %1 = llvm.mlir.undef : i32
// CHECK-NEXT: memref.store %1, %0[] : memref<i32>
// CHECK-NEXT: %2:2 = scf.if %arg1 -> (i32, i32) {
// CHECK-NEXT: memref.store %c-1_i32, %0[] : memref<i32>
// CHECK-NEXT: scf.execute_region {
// CHECK-NEXT: memref.store %c0_i32, %0[] : memref<i32>
// CHECK-NEXT: scf.yield
// CHECK-NEXT: }
// CHECK-NEXT: %3 = memref.load %0[] : memref<i32>
// CHECK-NEXT: %4:2 = scf.if %arg2 -> (i32, i32) {
// CHECK-NEXT: memref.store %c5_i32, %0[] : memref<i32>
// CHECK-NEXT: scf.yield %c5_i32, %c5_i32 : i32, i32
// CHECK-NEXT: } else {
// CHECK-NEXT: scf.yield %3, %3 : i32, i32
// CHECK-NEXT: }
// CHECK-NEXT: llvm.call @print(%4#0) : (i32) -> ()
// CHECK-NEXT: scf.yield %4#0, %4#1 : i32, i32
// CHECK-NEXT: } else {
// CHECK-NEXT: scf.yield %1, %1 : i32, i32
// CHECK-NEXT: }
// CHECK-NEXT: return
// CHECK-NEXT: }

View File

@ -2258,6 +2258,30 @@ ValueCategory MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
args[3]); args[3]);
return ValueCategory(args[0], /*isReference*/ false); return ValueCategory(args[0], /*isReference*/ false);
} }
if (sr->getDecl()->getIdentifier() &&
(sr->getDecl()->getName() == "memset" ||
sr->getDecl()->getName() == "__builtin_memset")) {
std::vector<mlir::Value> args = {
getLLVM(expr->getArg(0)), getLLVM(expr->getArg(1)),
getLLVM(expr->getArg(2)), /*isVolatile*/
builder.create<ConstantIntOp>(loc, false, 1)};
args[1] = builder.create<TruncIOp>(loc, builder.getI8Type(), args[1]);
builder.create<LLVM::MemsetOp>(loc, args[0], args[1], args[2],
args[3]);
return ValueCategory(args[0], /*isReference*/ false);
}
if (sr->getDecl()->getIdentifier() &&
(sr->getDecl()->getName() == "memcpy" ||
sr->getDecl()->getName() == "__builtin_memcpy")) {
std::vector<mlir::Value> args = {
getLLVM(expr->getArg(0)), getLLVM(expr->getArg(1)),
getLLVM(expr->getArg(2)), /*isVolatile*/
builder.create<ConstantIntOp>(loc, false, 1)};
builder.create<LLVM::MemcpyOp>(loc, args[0], args[1], args[2],
args[3]);
return ValueCategory(args[0], /*isReference*/ false);
}
if (sr->getDecl()->getIdentifier() && if (sr->getDecl()->getIdentifier() &&
(sr->getDecl()->getName() == "cudaMemcpy" || (sr->getDecl()->getName() == "cudaMemcpy" ||
sr->getDecl()->getName() == "cudaMemcpyAsync" || sr->getDecl()->getName() == "cudaMemcpyAsync" ||
@ -2271,16 +2295,6 @@ ValueCategory MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
while (auto BC = dyn_cast<clang::CastExpr>(srcSub)) while (auto BC = dyn_cast<clang::CastExpr>(srcSub))
srcSub = BC->getSubExpr(); srcSub = BC->getSubExpr();
if (sr->getDecl()->getName() == "memcpy" ||
sr->getDecl()->getName() == "__builtin_memcpy") {
std::vector<mlir::Value> args = {
getLLVM(expr->getArg(0)), getLLVM(expr->getArg(1)),
getLLVM(expr->getArg(2)), /*isVolatile*/
builder.create<ConstantIntOp>(loc, false, 1)};
builder.create<LLVM::MemcpyOp>(loc, args[0], args[1], args[2],
args[3]);
return ValueCategory(args[0], /*isReference*/ false);
}
#if 0 #if 0
auto dstst = dstSub->getType()->getUnqualifiedDesugaredType(); auto dstst = dstSub->getType()->getUnqualifiedDesugaredType();
if (isa<clang::PointerType>(dstst) || isa<clang::ArrayType>(dstst)) { if (isa<clang::PointerType>(dstst) || isa<clang::ArrayType>(dstst)) {

View File

@ -417,7 +417,7 @@ int main(int argc, char **argv) {
llvm::errs() << "</immediate: mlir>\n"; llvm::errs() << "</immediate: mlir>\n";
} }
bool LinkOMP = false; bool LinkOMP = FOpenMP;
pm.enableVerifier(EarlyVerifier); pm.enableVerifier(EarlyVerifier);
mlir::OpPassManager &optPM = pm.nest<mlir::FuncOp>(); mlir::OpPassManager &optPM = pm.nest<mlir::FuncOp>();
if (true) { if (true) {