Fix memory leak on ModuleOp

This commit is contained in:
Lorenzo Chelini 2021-08-24 12:34:00 +02:00
parent 376fdc0b48
commit 1b94b95e55
3 changed files with 62 additions and 58 deletions

View File

@ -2298,7 +2298,7 @@ ValueWithOffsets MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
tostore = builder.create<polygeist::SubIndexOp>(
loc, mt0, tostore, getConstantIndex(0));
i++;
auto indexType = mlir::IntegerType::get(module.getContext(), 64);
auto indexType = mlir::IntegerType::get(module->getContext(), 64);
auto one = builder.create<mlir::ConstantOp>(
loc, indexType,
builder.getIntegerAttr(builder.getIntegerType(64), 1));
@ -4372,10 +4372,10 @@ MLIRASTConsumer::GetOrCreateLLVMFunction(const FunctionDecl *FD) {
/*isVarArg=*/FD->isVariadic());
// Insert the function into the body of the parent module.
mlir::OpBuilder builder(module.getContext());
builder.setInsertionPointToStart(module.getBody());
mlir::OpBuilder builder(module->getContext());
builder.setInsertionPointToStart(module->getBody());
return llvmFunctions[name] = builder.create<LLVM::LLVMFuncOp>(
module.getLoc(), name, llvmFnType);
module->getLoc(), name, llvmFnType);
}
mlir::LLVM::GlobalOp
@ -4431,11 +4431,11 @@ MLIRASTConsumer::GetOrCreateLLVMGlobal(const ValueDecl *FD) {
auto rt = typeTranslator.translateType(anonymize(getLLVMType(FD->getType())));
mlir::OpBuilder builder(module.getContext());
builder.setInsertionPointToStart(module.getBody());
mlir::OpBuilder builder(module->getContext());
builder.setInsertionPointToStart(module->getBody());
auto glob = builder.create<LLVM::GlobalOp>(
module.getLoc(), rt, /*constant*/ false, lnk, name, mlir::Attribute());
module->getLoc(), rt, /*constant*/ false, lnk, name, mlir::Attribute());
if (cast<VarDecl>(FD)->isThisDeclarationADefinition() ==
VarDecl::Definition ||
@ -4445,9 +4445,9 @@ MLIRASTConsumer::GetOrCreateLLVMGlobal(const ValueDecl *FD) {
glob.getInitializerRegion().push_back(blk);
builder.setInsertionPointToStart(blk);
builder.create<LLVM::ReturnOp>(
module.getLoc(),
module->getLoc(),
std::vector<mlir::Value>(
{builder.create<LLVM::UndefOp>(module.getLoc(), rt)}));
{builder.create<LLVM::UndefOp>(module->getLoc(), rt)}));
}
return llvmGlobals[name] = glob;
}
@ -4476,8 +4476,8 @@ MLIRASTConsumer::GetOrCreateGlobal(const ValueDecl *FD) {
mlir::SymbolTable::Visibility lnk;
mlir::Attribute initial_value;
mlir::OpBuilder builder(module.getContext());
builder.setInsertionPointToStart(module.getBody());
mlir::OpBuilder builder(module->getContext());
builder.setInsertionPointToStart(module->getBody());
if (cast<VarDecl>(FD)->isThisDeclarationADefinition() ==
VarDecl::Definition) {
@ -4525,7 +4525,7 @@ MLIRASTConsumer::GetOrCreateGlobal(const ValueDecl *FD) {
}
auto globalOp = builder.create<mlir::memref::GlobalOp>(
module.getLoc(), builder.getStringAttr(FD->getName()),
module->getLoc(), builder.getStringAttr(FD->getName()),
/*sym_visibility*/ mlir::StringAttr(), mlir::TypeAttr::get(mr),
initial_value, mlir::UnitAttr());
SymbolTable::setSymbolVisibility(globalOp, lnk);
@ -4539,7 +4539,7 @@ mlir::Value MLIRASTConsumer::GetOrCreateGlobalLLVMString(
// Create the global at the entry of the module.
if (llvmStringGlobals.find(value.str()) == llvmStringGlobals.end()) {
OpBuilder::InsertionGuard insertGuard(builder);
builder.setInsertionPointToStart(module.getBody());
builder.setInsertionPointToStart(module->getBody());
auto type = LLVM::LLVMArrayType::get(
mlir::IntegerType::get(builder.getContext(), 8), value.size() + 1);
llvmStringGlobals[value.str()] = builder.create<LLVM::GlobalOp>(
@ -4660,7 +4660,7 @@ mlir::FuncOp MLIRASTConsumer::GetOrCreateMLIRFunction(const FunctionDecl *FD) {
rettypes.push_back(rt);
}
}
mlir::OpBuilder builder(module.getContext());
mlir::OpBuilder builder(module->getContext());
auto funcType = builder.getFunctionType(types, rettypes);
mlir::FuncOp function = mlir::FuncOp(
mlir::FuncOp::create(builder.getUnknownLoc(), name, funcType));
@ -4673,7 +4673,7 @@ mlir::FuncOp MLIRASTConsumer::GetOrCreateMLIRFunction(const FunctionDecl *FD) {
}
functions[name] = function;
module.push_back(function);
module->push_back(function);
const FunctionDecl *Def = nullptr;
if (FD->isDefined(Def, /*checkforfriend*/ true) && Def->getBody()) {
assert(Def->getTemplatedKind() !=
@ -4815,7 +4815,7 @@ mlir::Location MLIRASTConsumer::getMLIRLocation(clang::SourceLocation loc) {
auto colNumber = SM.getSpellingColumnNumber(spellingLoc);
auto fileId = SM.getFilename(spellingLoc);
auto ctx = module.getContext();
auto ctx = module->getContext();
auto mlirIdentifier = Identifier::get(fileId, ctx);
mlir::OpBuilder builder(ctx);
return builder.getUnknownLoc();
@ -4914,7 +4914,8 @@ mlir::Type MLIRASTConsumer::getMLIRType(clang::QualType qt, bool *implicitRef,
}
if (notAllSame || !allowMerge || innerLLVM) {
return mlir::LLVM::LLVMStructType::getLiteral(module.getContext(), types);
return mlir::LLVM::LLVMStructType::getLiteral(module->getContext(),
types);
}
if (!types.size()) {
@ -4932,7 +4933,7 @@ mlir::Type MLIRASTConsumer::getMLIRType(clang::QualType qt, bool *implicitRef,
auto t = qt->getUnqualifiedDesugaredType();
if (t->isVoidType()) {
mlir::OpBuilder builder(module.getContext());
mlir::OpBuilder builder(module->getContext());
return builder.getNoneType();
}
@ -5009,7 +5010,7 @@ mlir::Type MLIRASTConsumer::getMLIRType(clang::QualType qt, bool *implicitRef,
return LLVM::LLVMPointerType::get(subType);
if (PTT->isBooleanType()) {
OpBuilder builder(module);
OpBuilder builder(module->getContext());
return MemRefType::get(outer, builder.getIntegerType(1), {});
}
@ -5063,7 +5064,7 @@ llvm::Type *MLIRASTConsumer::getLLVMType(clang::QualType t) {
}
mlir::Type MLIRASTConsumer::getMLIRType(llvm::Type *t) {
mlir::OpBuilder builder(module.getContext());
mlir::OpBuilder builder(module->getContext());
if (t->isVoidTy()) {
return builder.getNoneType();
}
@ -5175,7 +5176,8 @@ mlir::Type MLIRASTConsumer::getMLIRType(llvm::Type *t) {
} else
types.push_back(st);
}
return mlir::LLVM::LLVMStructType::getLiteral(module.getContext(), types);
return mlir::LLVM::LLVMStructType::getLiteral(module->getContext(),
types);
}
return mlir::MemRefType::get(ST->getNumElements(),
@ -5218,13 +5220,14 @@ class MLIRAction : public clang::ASTFrontendAction {
public:
std::set<std::string> emitIfFound;
std::set<std::string> done;
mlir::ModuleOp &module;
mlir::OwningOpRef<mlir::ModuleOp> &module;
std::map<std::string, mlir::LLVM::GlobalOp> llvmStringGlobals;
std::map<std::string, std::pair<mlir::memref::GlobalOp, bool>> globals;
std::map<std::string, mlir::FuncOp> functions;
std::map<std::string, mlir::LLVM::GlobalOp> llvmGlobals;
std::map<std::string, mlir::LLVM::LLVMFuncOp> llvmFunctions;
MLIRAction(std::string fn, mlir::ModuleOp &module) : module(module) {
MLIRAction(std::string fn, mlir::OwningOpRef<mlir::ModuleOp> &module)
: module(module) {
emitIfFound.insert(fn);
}
std::unique_ptr<clang::ASTConsumer>
@ -5278,7 +5281,8 @@ std::string GetExecutablePath(const char *Argv0, bool CanonicalPrefixes) {
#include "clang/Frontend/TextDiagnosticBuffer.h"
static bool parseMLIR(const char *Argv0, std::vector<std::string> filenames,
std::string fn, std::vector<std::string> includeDirs,
std::vector<std::string> defines, mlir::ModuleOp &module,
std::vector<std::string> defines,
mlir::OwningOpRef<mlir::ModuleOp> &module,
llvm::Triple &triple, llvm::DataLayout &DL) {
IntrusiveRefCntPtr<DiagnosticIDs> DiagID(new DiagnosticIDs());
@ -5422,12 +5426,13 @@ static bool parseMLIR(const char *Argv0, std::vector<std::string> filenames,
Clang->getTarget().adjustTargetOptions(Clang->getCodeGenOpts(),
Clang->getTargetOpts());
module->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(),
StringAttr::get(module.getContext(),
module.get()->setAttr(
LLVM::LLVMDialect::getDataLayoutAttrName(),
StringAttr::get(module->getContext(),
Clang->getTarget().getDataLayoutString()));
module->setAttr(
module.get()->setAttr(
LLVM::LLVMDialect::getTargetTripleAttrName(),
StringAttr::get(module.getContext(),
StringAttr::get(module->getContext(),
Clang->getTarget().getTriple().getTriple()));
for (const auto &FIF : Clang->getFrontendOpts().Inputs) {

View File

@ -308,7 +308,7 @@ struct MLIRASTConsumer : public ASTConsumer {
std::map<std::string, mlir::LLVM::LLVMFuncOp> &llvmFunctions;
Preprocessor &PP;
ASTContext &astContext;
mlir::ModuleOp &module;
mlir::OwningOpRef<mlir::ModuleOp> &module;
clang::SourceManager &SM;
LLVMContext lcontext;
llvm::Module llvmMod;
@ -329,8 +329,8 @@ struct MLIRASTConsumer : public ASTConsumer {
std::map<std::string, mlir::FuncOp> &functions,
std::map<std::string, mlir::LLVM::GlobalOp> &llvmGlobals,
std::map<std::string, mlir::LLVM::LLVMFuncOp> &llvmFunctions,
Preprocessor &PP, ASTContext &astContext, mlir::ModuleOp &module,
clang::SourceManager &SM)
Preprocessor &PP, ASTContext &astContext,
mlir::OwningOpRef<mlir::ModuleOp> &module, clang::SourceManager &SM)
: emitIfFound(emitIfFound), done(done),
llvmStringGlobals(llvmStringGlobals), globals(globals),
functions(functions), llvmGlobals(llvmGlobals),
@ -339,7 +339,7 @@ struct MLIRASTConsumer : public ASTConsumer {
codegenops(),
CGM(astContext, PP.getHeaderSearchInfo().getHeaderSearchOpts(),
PP.getPreprocessorOpts(), codegenops, llvmMod, PP.getDiagnostics()),
error(false), typeTranslator(*module.getContext()),
error(false), typeTranslator(*module->getContext()),
reverseTypeTranslator(lcontext) {
addPragmaScopHandlers(PP, scopLocList);
addPragmaEndScopHandlers(PP, scopLocList);
@ -385,7 +385,7 @@ struct MLIRScanner : public StmtVisitor<MLIRScanner, ValueWithOffsets> {
private:
MLIRASTConsumer &Glob;
mlir::FuncOp function;
mlir::ModuleOp &module;
mlir::OwningOpRef<mlir::ModuleOp> &module;
mlir::OpBuilder builder;
mlir::Location loc;
mlir::Block *entryBlock;
@ -464,10 +464,10 @@ public:
LowerToInfo &LTInfo;
MLIRScanner(MLIRASTConsumer &Glob, mlir::FuncOp function,
const FunctionDecl *fd, mlir::ModuleOp &module,
const FunctionDecl *fd, mlir::OwningOpRef<mlir::ModuleOp> &module,
LowerToInfo &LTInfo)
: Glob(Glob), function(function), module(module),
builder(module.getContext()), loc(builder.getUnknownLoc()),
builder(module->getContext()), loc(builder.getUnknownLoc()),
EmittingFunctionDecl(fd), ThisCapture(nullptr), LTInfo(LTInfo) {
if (ShowAST) {

View File

@ -153,8 +153,8 @@ int main(int argc, char **argv) {
}
return 0;
}
auto module =
mlir::ModuleOp::create(mlir::OpBuilder(&context).getUnknownLoc());
mlir::OwningOpRef<mlir::ModuleOp> module(
mlir::ModuleOp::create(mlir::OpBuilder(&context).getUnknownLoc()));
llvm::Triple triple;
llvm::DataLayout DL("");
@ -164,7 +164,7 @@ int main(int argc, char **argv) {
if (ImmediateMLIR) {
llvm::errs() << "<immediate: mlir>\n";
module.dump();
module->dump();
llvm::errs() << "</immediate: mlir>\n";
}
pm.enableVerifier(false);
@ -199,12 +199,12 @@ int main(int argc, char **argv) {
if (ScalarReplacement)
optPM.addPass(mlir::createAffineScalarReplacementPass());
}
if (mlir::failed(pm.run(module))) {
module.dump();
if (mlir::failed(pm.run(module.get()))) {
module->dump();
return 4;
}
if (mlir::failed(mlir::verify(module))) {
module.dump();
if (mlir::failed(mlir::verify(module.get()))) {
module->dump();
return 5;
}
#define optPM optPM2
@ -246,15 +246,15 @@ int main(int argc, char **argv) {
if (EmitLLVM) {
pm.addPass(mlir::createLowerAffinePass());
pm.nest<mlir::FuncOp>().addPass(mlir::createConvertMathToLLVMPass());
if (mlir::failed(pm.run(module))) {
module.dump();
if (mlir::failed(pm.run(module.get()))) {
module->dump();
return 4;
}
mlir::PassManager pm2(&context);
if (SCFOpenMP)
pm2.nest<mlir::FuncOp>().addPass(createConvertSCFToOpenMPPass());
if (mlir::failed(pm2.run(module))) {
module.dump();
if (mlir::failed(pm2.run(module.get()))) {
module->dump();
return 4;
}
mlir::PassManager pm3(&context);
@ -265,29 +265,28 @@ int main(int argc, char **argv) {
// invalid for gemm.c init array
// options.useBarePtrCallConv = true;
pm3.addPass(mlir::createLowerToLLVMPass(options));
if (mlir::failed(pm3.run(module))) {
module.dump();
if (mlir::failed(pm3.run(module.get()))) {
module->dump();
return 4;
}
} else {
if (mlir::failed(pm.run(module))) {
module.dump();
if (mlir::failed(pm.run(module.get()))) {
module->dump();
return 4;
}
}
// module.dump();
if (mlir::failed(mlir::verify(module))) {
module.dump();
if (mlir::failed(mlir::verify(module.get()))) {
module->dump();
return 5;
}
}
if (EmitLLVM) {
llvm::LLVMContext llvmContext;
auto llvmModule = mlir::translateModuleToLLVMIR(module, llvmContext);
auto llvmModule = mlir::translateModuleToLLVMIR(module.get(), llvmContext);
if (!llvmModule) {
module.dump();
module->dump();
llvm::errs() << "Failed to emit LLVM IR\n";
return -1;
}
@ -303,11 +302,11 @@ int main(int argc, char **argv) {
} else {
if (Output == "-")
module.print(outs());
module->print(outs());
else {
std::error_code EC;
llvm::raw_fd_ostream out(Output, EC);
module.print(out);
module->print(out);
}
}
return 0;