Fix mem2reg bug
This commit is contained in:
parent
3bfbd54ee5
commit
92684ed844
|
@ -448,6 +448,8 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector<ssize_t> idx,
|
|||
for (auto &a : block) {
|
||||
ops.push_back(&a);
|
||||
}
|
||||
LLVM_DEBUG( llvm::dbgs() << " starting block: "; block.print(llvm::dbgs()); llvm::dbgs() << " with ";
|
||||
if (lastVal) llvm::dbgs() << lastVal << "\n"; else llvm::dbgs() << " null\n"; );
|
||||
for (auto a : ops) {
|
||||
if (StoringOperations.count(a)) {
|
||||
if (auto exOp = dyn_cast<mlir::scf::ExecuteRegionOp>(a)) {
|
||||
|
@ -455,6 +457,7 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector<ssize_t> idx,
|
|||
Value thenVal; // = handleBlock(exOp.region().front(), lastVal);
|
||||
lastVal = nullptr;
|
||||
seenSubStore = true;
|
||||
LLVM_DEBUG( llvm::dbgs() << " zeroing val due to " << exOp << "\n"; );
|
||||
continue;
|
||||
|
||||
bool needsAfter = false;
|
||||
|
@ -470,8 +473,9 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector<ssize_t> idx,
|
|||
});
|
||||
}
|
||||
}
|
||||
if (!needsAfter)
|
||||
if (!needsAfter) {
|
||||
continue;
|
||||
}
|
||||
|
||||
Block &then = exOp.region().back();
|
||||
OpBuilder B(exOp.getContext());
|
||||
|
@ -550,7 +554,8 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector<ssize_t> idx,
|
|||
else
|
||||
newLoad = B.create<LLVM::LoadOp>(ifOp.getLoc(), AI);
|
||||
|
||||
loadOps.insert(newLoad);
|
||||
if (!seenSubStore)
|
||||
loadOps.insert(newLoad);
|
||||
lastVal = newLoad->getResult(0);
|
||||
}
|
||||
|
||||
|
@ -813,6 +818,8 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector<ssize_t> idx,
|
|||
});
|
||||
}
|
||||
}
|
||||
LLVM_DEBUG( llvm::dbgs() << " ending block: "; block.print(llvm::dbgs()); llvm::dbgs() << " with ";
|
||||
if (lastVal) llvm::dbgs() << lastVal << "\n"; else llvm::dbgs() << " null\n"; );
|
||||
return lastStoreInBlock[&block] = lastVal;
|
||||
};
|
||||
|
||||
|
|
|
@ -144,24 +144,33 @@ struct AlwaysInlinerInterface : public InlinerInterface {
|
|||
};
|
||||
|
||||
// TODO
|
||||
#if 0
|
||||
mlir::LLVM::LLVMFuncOp MLIRASTConsumer::GetOrCreateMallocFunction(ModuleOp* module) {
|
||||
std::string name = "malloc";
|
||||
if (llvmFunctions.find(name) != llvmFunctions.end()) {
|
||||
return llvmFunctions[name];
|
||||
}
|
||||
mlir::LLVM::LLVMFuncOp GetOrCreateMallocFunction(ModuleOp module) {
|
||||
mlir::OpBuilder builder(module.getContext());
|
||||
SymbolTableCollection symbolTable;
|
||||
if (auto fn = dyn_cast_or_null<LLVM::LLVMFuncOp>(symbolTable.lookupSymbolIn(module, builder.getIdentifier("malloc"))))
|
||||
return fn;
|
||||
auto ctx = module->getContext();
|
||||
mlir::Type types[] = {mlir::IntegerType::get(ctx, 64)};
|
||||
auto llvmFnType = LLVM::LLVMFunctionType::get(
|
||||
LLVM::LLVMPointerType::get(mlir::IntegerType::get(ctx, 8)), types, false);
|
||||
|
||||
LLVM::Linkage lnk = LLVM::Linkage::External;
|
||||
mlir::OpBuilder builder(module->getContext());
|
||||
builder.setInsertionPointToStart(module->getBody());
|
||||
return llvmFunctions[name] = builder.create<LLVM::LLVMFuncOp>(
|
||||
module->getLoc(), name, llvmFnType, lnk);
|
||||
builder.setInsertionPointToStart(module.getBody());
|
||||
return builder.create<LLVM::LLVMFuncOp>(module.getLoc(), "malloc", llvmFnType, lnk);
|
||||
}
|
||||
mlir::LLVM::LLVMFuncOp GetOrCreateFreeFunction(ModuleOp module) {
|
||||
mlir::OpBuilder builder(module.getContext());
|
||||
SymbolTableCollection symbolTable;
|
||||
if (auto fn = dyn_cast_or_null<LLVM::LLVMFuncOp>(symbolTable.lookupSymbolIn(module, builder.getIdentifier("free"))))
|
||||
return fn;
|
||||
auto ctx = module->getContext();
|
||||
auto llvmFnType = LLVM::LLVMFunctionType::get(
|
||||
LLVM::LLVMVoidType::get(ctx), ArrayRef<mlir::Type>(LLVM::LLVMPointerType::get(builder.getI8Type())), false);
|
||||
|
||||
LLVM::Linkage lnk = LLVM::Linkage::External;
|
||||
builder.setInsertionPointToStart(module.getBody());
|
||||
return builder.create<LLVM::LLVMFuncOp>(module.getLoc(), "free", llvmFnType, lnk);
|
||||
}
|
||||
#endif
|
||||
|
||||
void ParallelLower::runOnOperation() {
|
||||
// The inliner should only be run on operations that define a symbol table,
|
||||
|
@ -419,17 +428,34 @@ void ParallelLower::runOnOperation() {
|
|||
OpBuilder bz(call);
|
||||
auto falsev = bz.create<ConstantIntOp>(call.getLoc(), false, 1);
|
||||
bz.create<LLVM::MemsetOp>(call.getLoc(), call.getOperand(0),
|
||||
call.getOperand(1), call.getOperand(2),
|
||||
bz.create<TruncIOp>(call.getLoc(), bz.getI8Type(), call.getOperand(1)), call.getOperand(2),
|
||||
/*isVolatile*/ falsev);
|
||||
Value vals[] = {call.getOperand(2)};
|
||||
call.replaceAllUsesWith(ArrayRef<Value>(vals));
|
||||
call.erase();
|
||||
/*
|
||||
} else if (call.getCallee().getValue() == "cudaMalloc") {
|
||||
Value vals[] = {call.getOperand(0)};
|
||||
call.replaceAllUsesWith(ArrayRef<Value>(vals));
|
||||
call.erase();
|
||||
*/
|
||||
} else if (call.getCallee().getValue() == "cudaMalloc") {
|
||||
auto mf = GetOrCreateMallocFunction(getOperation());
|
||||
OpBuilder bz(call);
|
||||
Value args[] = {bz.create<arith::ExtUIOp>(call.getLoc(), bz.getI64Type(), call.getOperand(1))};
|
||||
mlir::Value alloc = bz.create<mlir::LLVM::CallOp>(call.getLoc(), mf, args).getResult(0);
|
||||
bz.create<LLVM::StoreOp>(call.getLoc(), alloc, call.getOperand(0));
|
||||
{
|
||||
auto retv = bz.create<ConstantIntOp>(call.getLoc(), 0, call.getResult(0).getType().cast<IntegerType>().getWidth());
|
||||
Value vals[] = {retv};
|
||||
call.replaceAllUsesWith(ArrayRef<Value>(vals));
|
||||
call.erase();
|
||||
}
|
||||
} else if (call.getCallee().getValue() == "cudaFree") {
|
||||
auto mf = GetOrCreateFreeFunction(getOperation());
|
||||
OpBuilder bz(call);
|
||||
Value args[] = {call.getOperand(0)};
|
||||
bz.create<mlir::LLVM::CallOp>(call.getLoc(), mf, args);
|
||||
{
|
||||
auto retv = bz.create<ConstantIntOp>(call.getLoc(), 0, call.getResult(0).getType().cast<IntegerType>().getWidth());
|
||||
Value vals[] = {retv};
|
||||
call.replaceAllUsesWith(ArrayRef<Value>(vals));
|
||||
call.erase();
|
||||
}
|
||||
} else if (call.getCallee().getValue() == "cudaDeviceSynchronize") {
|
||||
OpBuilder bz(call);
|
||||
auto retv = bz.create<ConstantIntOp>(call.getLoc(), 0, call.getResult(0).getType().cast<IntegerType>().getWidth());
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
// RUN: polygeist-opt --mem2reg --split-input-file %s | FileCheck %s
|
||||
|
||||
module {
|
||||
llvm.func @print(i32)
|
||||
func @h(%arg7: i1, %arg8: i1, %arg9 : i1) {
|
||||
%c1_i32 = arith.constant 1 : i32
|
||||
%c5_i32 = arith.constant 5 : i32
|
||||
%c0_i32 = arith.constant 0 : i32
|
||||
%c-1_i32 = arith.constant -1 : i32
|
||||
%2 = memref.alloca() : memref<i32>
|
||||
%3 = llvm.mlir.undef : i32
|
||||
memref.store %3, %2[] : memref<i32>
|
||||
scf.if %arg8 {
|
||||
memref.store %c-1_i32, %2[] : memref<i32>
|
||||
scf.execute_region {
|
||||
memref.store %c0_i32, %2[] : memref<i32>
|
||||
scf.yield
|
||||
}
|
||||
scf.if %arg9 {
|
||||
memref.store %c5_i32, %2[] : memref<i32>
|
||||
}
|
||||
%23 = memref.load %2[] : memref<i32>
|
||||
llvm.call @print(%23) : (i32) -> ()
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK: func @h(%arg0: i1, %arg1: i1, %arg2: i1)
|
||||
// CHECK-NEXT: %c1_i32 = arith.constant 1 : i32
|
||||
// CHECK-NEXT: %c5_i32 = arith.constant 5 : i32
|
||||
// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32
|
||||
// CHECK-NEXT: %c-1_i32 = arith.constant -1 : i32
|
||||
// CHECK-NEXT: %0 = memref.alloca() : memref<i32>
|
||||
// CHECK-NEXT: %1 = llvm.mlir.undef : i32
|
||||
// CHECK-NEXT: memref.store %1, %0[] : memref<i32>
|
||||
// CHECK-NEXT: %2:2 = scf.if %arg1 -> (i32, i32) {
|
||||
// CHECK-NEXT: memref.store %c-1_i32, %0[] : memref<i32>
|
||||
// CHECK-NEXT: scf.execute_region {
|
||||
// CHECK-NEXT: memref.store %c0_i32, %0[] : memref<i32>
|
||||
// CHECK-NEXT: scf.yield
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: %3 = memref.load %0[] : memref<i32>
|
||||
// CHECK-NEXT: %4:2 = scf.if %arg2 -> (i32, i32) {
|
||||
// CHECK-NEXT: memref.store %c5_i32, %0[] : memref<i32>
|
||||
// CHECK-NEXT: scf.yield %c5_i32, %c5_i32 : i32, i32
|
||||
// CHECK-NEXT: } else {
|
||||
// CHECK-NEXT: scf.yield %3, %3 : i32, i32
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: llvm.call @print(%4#0) : (i32) -> ()
|
||||
// CHECK-NEXT: scf.yield %4#0, %4#1 : i32, i32
|
||||
// CHECK-NEXT: } else {
|
||||
// CHECK-NEXT: scf.yield %1, %1 : i32, i32
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return
|
||||
// CHECK-NEXT: }
|
||||
|
|
@ -2258,6 +2258,30 @@ ValueCategory MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
|
|||
args[3]);
|
||||
return ValueCategory(args[0], /*isReference*/ false);
|
||||
}
|
||||
if (sr->getDecl()->getIdentifier() &&
|
||||
(sr->getDecl()->getName() == "memset" ||
|
||||
sr->getDecl()->getName() == "__builtin_memset")) {
|
||||
std::vector<mlir::Value> args = {
|
||||
getLLVM(expr->getArg(0)), getLLVM(expr->getArg(1)),
|
||||
getLLVM(expr->getArg(2)), /*isVolatile*/
|
||||
builder.create<ConstantIntOp>(loc, false, 1)};
|
||||
|
||||
args[1] = builder.create<TruncIOp>(loc, builder.getI8Type(), args[1]);
|
||||
builder.create<LLVM::MemsetOp>(loc, args[0], args[1], args[2],
|
||||
args[3]);
|
||||
return ValueCategory(args[0], /*isReference*/ false);
|
||||
}
|
||||
if (sr->getDecl()->getIdentifier() &&
|
||||
(sr->getDecl()->getName() == "memcpy" ||
|
||||
sr->getDecl()->getName() == "__builtin_memcpy")) {
|
||||
std::vector<mlir::Value> args = {
|
||||
getLLVM(expr->getArg(0)), getLLVM(expr->getArg(1)),
|
||||
getLLVM(expr->getArg(2)), /*isVolatile*/
|
||||
builder.create<ConstantIntOp>(loc, false, 1)};
|
||||
builder.create<LLVM::MemcpyOp>(loc, args[0], args[1], args[2],
|
||||
args[3]);
|
||||
return ValueCategory(args[0], /*isReference*/ false);
|
||||
}
|
||||
if (sr->getDecl()->getIdentifier() &&
|
||||
(sr->getDecl()->getName() == "cudaMemcpy" ||
|
||||
sr->getDecl()->getName() == "cudaMemcpyAsync" ||
|
||||
|
@ -2271,16 +2295,6 @@ ValueCategory MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
|
|||
while (auto BC = dyn_cast<clang::CastExpr>(srcSub))
|
||||
srcSub = BC->getSubExpr();
|
||||
|
||||
if (sr->getDecl()->getName() == "memcpy" ||
|
||||
sr->getDecl()->getName() == "__builtin_memcpy") {
|
||||
std::vector<mlir::Value> args = {
|
||||
getLLVM(expr->getArg(0)), getLLVM(expr->getArg(1)),
|
||||
getLLVM(expr->getArg(2)), /*isVolatile*/
|
||||
builder.create<ConstantIntOp>(loc, false, 1)};
|
||||
builder.create<LLVM::MemcpyOp>(loc, args[0], args[1], args[2],
|
||||
args[3]);
|
||||
return ValueCategory(args[0], /*isReference*/ false);
|
||||
}
|
||||
#if 0
|
||||
auto dstst = dstSub->getType()->getUnqualifiedDesugaredType();
|
||||
if (isa<clang::PointerType>(dstst) || isa<clang::ArrayType>(dstst)) {
|
||||
|
|
|
@ -417,7 +417,7 @@ int main(int argc, char **argv) {
|
|||
llvm::errs() << "</immediate: mlir>\n";
|
||||
}
|
||||
|
||||
bool LinkOMP = false;
|
||||
bool LinkOMP = FOpenMP;
|
||||
pm.enableVerifier(EarlyVerifier);
|
||||
mlir::OpPassManager &optPM = pm.nest<mlir::FuncOp>();
|
||||
if (true) {
|
||||
|
|
Loading…
Reference in New Issue