Fix mem2reg bug

This commit is contained in:
William S. Moses 2021-12-29 22:52:23 -05:00 committed by William Moses
parent 3bfbd54ee5
commit 92684ed844
5 changed files with 135 additions and 31 deletions

View File

@ -448,6 +448,8 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector<ssize_t> idx,
for (auto &a : block) {
ops.push_back(&a);
}
LLVM_DEBUG( llvm::dbgs() << " starting block: "; block.print(llvm::dbgs()); llvm::dbgs() << " with ";
if (lastVal) llvm::dbgs() << lastVal << "\n"; else llvm::dbgs() << " null\n"; );
for (auto a : ops) {
if (StoringOperations.count(a)) {
if (auto exOp = dyn_cast<mlir::scf::ExecuteRegionOp>(a)) {
@ -455,6 +457,7 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector<ssize_t> idx,
Value thenVal; // = handleBlock(exOp.region().front(), lastVal);
lastVal = nullptr;
seenSubStore = true;
LLVM_DEBUG( llvm::dbgs() << " zeroing val due to " << exOp << "\n"; );
continue;
bool needsAfter = false;
@ -470,8 +473,9 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector<ssize_t> idx,
});
}
}
if (!needsAfter)
if (!needsAfter) {
continue;
}
Block &then = exOp.region().back();
OpBuilder B(exOp.getContext());
@ -550,7 +554,8 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector<ssize_t> idx,
else
newLoad = B.create<LLVM::LoadOp>(ifOp.getLoc(), AI);
loadOps.insert(newLoad);
if (!seenSubStore)
loadOps.insert(newLoad);
lastVal = newLoad->getResult(0);
}
@ -813,6 +818,8 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector<ssize_t> idx,
});
}
}
LLVM_DEBUG( llvm::dbgs() << " ending block: "; block.print(llvm::dbgs()); llvm::dbgs() << " with ";
if (lastVal) llvm::dbgs() << lastVal << "\n"; else llvm::dbgs() << " null\n"; );
return lastStoreInBlock[&block] = lastVal;
};

View File

@ -144,24 +144,33 @@ struct AlwaysInlinerInterface : public InlinerInterface {
};
// TODO
#if 0
mlir::LLVM::LLVMFuncOp MLIRASTConsumer::GetOrCreateMallocFunction(ModuleOp* module) {
std::string name = "malloc";
if (llvmFunctions.find(name) != llvmFunctions.end()) {
return llvmFunctions[name];
}
mlir::LLVM::LLVMFuncOp GetOrCreateMallocFunction(ModuleOp module) {
mlir::OpBuilder builder(module.getContext());
SymbolTableCollection symbolTable;
if (auto fn = dyn_cast_or_null<LLVM::LLVMFuncOp>(symbolTable.lookupSymbolIn(module, builder.getIdentifier("malloc"))))
return fn;
auto ctx = module->getContext();
mlir::Type types[] = {mlir::IntegerType::get(ctx, 64)};
auto llvmFnType = LLVM::LLVMFunctionType::get(
LLVM::LLVMPointerType::get(mlir::IntegerType::get(ctx, 8)), types, false);
LLVM::Linkage lnk = LLVM::Linkage::External;
mlir::OpBuilder builder(module->getContext());
builder.setInsertionPointToStart(module->getBody());
return llvmFunctions[name] = builder.create<LLVM::LLVMFuncOp>(
module->getLoc(), name, llvmFnType, lnk);
builder.setInsertionPointToStart(module.getBody());
return builder.create<LLVM::LLVMFuncOp>(module.getLoc(), "malloc", llvmFnType, lnk);
}
/// Returns the LLVM function named `free` from `module`, creating it first
/// when the symbol lookup does not yield an LLVM function.
///
/// A newly created declaration has external linkage, the non-variadic type
/// `void (i8*)`, and is inserted at the start of the module body.
mlir::LLVM::LLVMFuncOp GetOrCreateFreeFunction(ModuleOp module) {
  mlir::OpBuilder builder(module.getContext());

  // Reuse the existing declaration when `free` already resolves to one.
  SymbolTableCollection symbolTable;
  auto *existing =
      symbolTable.lookupSymbolIn(module, builder.getIdentifier("free"));
  if (auto freeFn = dyn_cast_or_null<LLVM::LLVMFuncOp>(existing))
    return freeFn;

  // Otherwise build the `void (i8*)` signature and declare the function.
  auto ctx = module->getContext();
  mlir::Type argTypes[] = {LLVM::LLVMPointerType::get(builder.getI8Type())};
  auto freeFnType = LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(ctx),
                                                argTypes, false);
  LLVM::Linkage linkage = LLVM::Linkage::External;

  builder.setInsertionPointToStart(module.getBody());
  return builder.create<LLVM::LLVMFuncOp>(module.getLoc(), "free", freeFnType,
                                          linkage);
}
#endif
void ParallelLower::runOnOperation() {
// The inliner should only be run on operations that define a symbol table,
@ -419,17 +428,34 @@ void ParallelLower::runOnOperation() {
OpBuilder bz(call);
auto falsev = bz.create<ConstantIntOp>(call.getLoc(), false, 1);
bz.create<LLVM::MemsetOp>(call.getLoc(), call.getOperand(0),
call.getOperand(1), call.getOperand(2),
bz.create<TruncIOp>(call.getLoc(), bz.getI8Type(), call.getOperand(1)), call.getOperand(2),
/*isVolatile*/ falsev);
Value vals[] = {call.getOperand(2)};
call.replaceAllUsesWith(ArrayRef<Value>(vals));
call.erase();
/*
} else if (call.getCallee().getValue() == "cudaMalloc") {
Value vals[] = {call.getOperand(0)};
call.replaceAllUsesWith(ArrayRef<Value>(vals));
call.erase();
*/
} else if (call.getCallee().getValue() == "cudaMalloc") {
auto mf = GetOrCreateMallocFunction(getOperation());
OpBuilder bz(call);
Value args[] = {bz.create<arith::ExtUIOp>(call.getLoc(), bz.getI64Type(), call.getOperand(1))};
mlir::Value alloc = bz.create<mlir::LLVM::CallOp>(call.getLoc(), mf, args).getResult(0);
bz.create<LLVM::StoreOp>(call.getLoc(), alloc, call.getOperand(0));
{
auto retv = bz.create<ConstantIntOp>(call.getLoc(), 0, call.getResult(0).getType().cast<IntegerType>().getWidth());
Value vals[] = {retv};
call.replaceAllUsesWith(ArrayRef<Value>(vals));
call.erase();
}
} else if (call.getCallee().getValue() == "cudaFree") {
auto mf = GetOrCreateFreeFunction(getOperation());
OpBuilder bz(call);
Value args[] = {call.getOperand(0)};
bz.create<mlir::LLVM::CallOp>(call.getLoc(), mf, args);
{
auto retv = bz.create<ConstantIntOp>(call.getLoc(), 0, call.getResult(0).getType().cast<IntegerType>().getWidth());
Value vals[] = {retv};
call.replaceAllUsesWith(ArrayRef<Value>(vals));
call.erase();
}
} else if (call.getCallee().getValue() == "cudaDeviceSynchronize") {
OpBuilder bz(call);
auto retv = bz.create<ConstantIntOp>(call.getLoc(), 0, call.getResult(0).getType().cast<IntegerType>().getWidth());

View File

@ -0,0 +1,57 @@
// RUN: polygeist-opt --mem2reg --split-input-file %s | FileCheck %s
// Exercises mem2reg store-to-load forwarding through nested regions: the
// alloca is overwritten inside an scf.execute_region and then conditionally
// overwritten again inside a following scf.if before it is finally loaded.
// The expected output keeps the store inside the execute_region, reloads the
// value right after that region, and feeds the print call from the result of
// the inner scf.if (either %c5_i32 or the reloaded value).
module {
llvm.func @print(i32)
func @h(%arg7: i1, %arg8: i1, %arg9 : i1) {
%c1_i32 = arith.constant 1 : i32
%c5_i32 = arith.constant 5 : i32
%c0_i32 = arith.constant 0 : i32
%c-1_i32 = arith.constant -1 : i32
%2 = memref.alloca() : memref<i32>
%3 = llvm.mlir.undef : i32
memref.store %3, %2[] : memref<i32>
scf.if %arg8 {
memref.store %c-1_i32, %2[] : memref<i32>
// This store happens inside a nested region and overwrites the -1 above.
scf.execute_region {
memref.store %c0_i32, %2[] : memref<i32>
scf.yield
}
// Conditional overwrite: the load below sees 5 only when %arg9 is true.
scf.if %arg9 {
memref.store %c5_i32, %2[] : memref<i32>
}
%23 = memref.load %2[] : memref<i32>
llvm.call @print(%23) : (i32) -> ()
}
return
}
}
// CHECK: func @h(%arg0: i1, %arg1: i1, %arg2: i1)
// CHECK-NEXT: %c1_i32 = arith.constant 1 : i32
// CHECK-NEXT: %c5_i32 = arith.constant 5 : i32
// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32
// CHECK-NEXT: %c-1_i32 = arith.constant -1 : i32
// CHECK-NEXT: %0 = memref.alloca() : memref<i32>
// CHECK-NEXT: %1 = llvm.mlir.undef : i32
// CHECK-NEXT: memref.store %1, %0[] : memref<i32>
// CHECK-NEXT: %2:2 = scf.if %arg1 -> (i32, i32) {
// CHECK-NEXT: memref.store %c-1_i32, %0[] : memref<i32>
// CHECK-NEXT: scf.execute_region {
// CHECK-NEXT: memref.store %c0_i32, %0[] : memref<i32>
// CHECK-NEXT: scf.yield
// CHECK-NEXT: }
// CHECK-NEXT: %3 = memref.load %0[] : memref<i32>
// CHECK-NEXT: %4:2 = scf.if %arg2 -> (i32, i32) {
// CHECK-NEXT: memref.store %c5_i32, %0[] : memref<i32>
// CHECK-NEXT: scf.yield %c5_i32, %c5_i32 : i32, i32
// CHECK-NEXT: } else {
// CHECK-NEXT: scf.yield %3, %3 : i32, i32
// CHECK-NEXT: }
// CHECK-NEXT: llvm.call @print(%4#0) : (i32) -> ()
// CHECK-NEXT: scf.yield %4#0, %4#1 : i32, i32
// CHECK-NEXT: } else {
// CHECK-NEXT: scf.yield %1, %1 : i32, i32
// CHECK-NEXT: }
// CHECK-NEXT: return
// CHECK-NEXT: }

View File

@ -2258,6 +2258,30 @@ ValueCategory MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
args[3]);
return ValueCategory(args[0], /*isReference*/ false);
}
if (sr->getDecl()->getIdentifier() &&
(sr->getDecl()->getName() == "memset" ||
sr->getDecl()->getName() == "__builtin_memset")) {
std::vector<mlir::Value> args = {
getLLVM(expr->getArg(0)), getLLVM(expr->getArg(1)),
getLLVM(expr->getArg(2)), /*isVolatile*/
builder.create<ConstantIntOp>(loc, false, 1)};
args[1] = builder.create<TruncIOp>(loc, builder.getI8Type(), args[1]);
builder.create<LLVM::MemsetOp>(loc, args[0], args[1], args[2],
args[3]);
return ValueCategory(args[0], /*isReference*/ false);
}
if (sr->getDecl()->getIdentifier() &&
(sr->getDecl()->getName() == "memcpy" ||
sr->getDecl()->getName() == "__builtin_memcpy")) {
std::vector<mlir::Value> args = {
getLLVM(expr->getArg(0)), getLLVM(expr->getArg(1)),
getLLVM(expr->getArg(2)), /*isVolatile*/
builder.create<ConstantIntOp>(loc, false, 1)};
builder.create<LLVM::MemcpyOp>(loc, args[0], args[1], args[2],
args[3]);
return ValueCategory(args[0], /*isReference*/ false);
}
if (sr->getDecl()->getIdentifier() &&
(sr->getDecl()->getName() == "cudaMemcpy" ||
sr->getDecl()->getName() == "cudaMemcpyAsync" ||
@ -2271,16 +2295,6 @@ ValueCategory MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
while (auto BC = dyn_cast<clang::CastExpr>(srcSub))
srcSub = BC->getSubExpr();
if (sr->getDecl()->getName() == "memcpy" ||
sr->getDecl()->getName() == "__builtin_memcpy") {
std::vector<mlir::Value> args = {
getLLVM(expr->getArg(0)), getLLVM(expr->getArg(1)),
getLLVM(expr->getArg(2)), /*isVolatile*/
builder.create<ConstantIntOp>(loc, false, 1)};
builder.create<LLVM::MemcpyOp>(loc, args[0], args[1], args[2],
args[3]);
return ValueCategory(args[0], /*isReference*/ false);
}
#if 0
auto dstst = dstSub->getType()->getUnqualifiedDesugaredType();
if (isa<clang::PointerType>(dstst) || isa<clang::ArrayType>(dstst)) {

View File

@ -417,7 +417,7 @@ int main(int argc, char **argv) {
llvm::errs() << "</immediate: mlir>\n";
}
bool LinkOMP = false;
bool LinkOMP = FOpenMP;
pm.enableVerifier(EarlyVerifier);
mlir::OpPassManager &optPM = pm.nest<mlir::FuncOp>();
if (true) {