diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 1953cbd..f70bc2d 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -53,8 +53,8 @@ jobs: run: | cd build ls ../mlir-build/lib/cmake/clang - CYMBL=OFF cmake ../src/ -GNinja -DMLIR_DIR=`pwd`/../mlir-build/lib/cmake/mlir -DLLVM_EXTERNAL_LIT=`pwd`/../mlir-build/bin/llvm-lit -DClang_DIR=`pwd`/../mlir-build/lib/cmake/clang -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DCMAKE_C_COMPILER=/bin/clang -DCMAKE_CXX_COMPILER=/bin/clang++ -DCMAKE_ASM_COMPILER=/bin/clang -DCMAKE_CXX_FLAGS="-Wno-c++11-narrowing" + cmake ../src/ -GNinja -DMLIR_DIR=`pwd`/../mlir-build/lib/cmake/mlir -DLLVM_EXTERNAL_LIT=`pwd`/../mlir-build/bin/llvm-lit -DClang_DIR=`pwd`/../mlir-build/lib/cmake/clang -DCMAKE_BUILD_TYPE=${{ matrix.build }} - name: test mlir-clang run: | cd build - ninja -j125 check-mlir-clang \ No newline at end of file + ninja check-mlir-clang diff --git a/lib/polygeist/Ops.cpp b/lib/polygeist/Ops.cpp index 250ba7a..f640b2d 100644 --- a/lib/polygeist/Ops.cpp +++ b/lib/polygeist/Ops.cpp @@ -7,11 +7,11 @@ //===----------------------------------------------------------------------===// #include "polygeist/Ops.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/OpImplementation.h" #include "polygeist/Dialect.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/LLVMIR/LLVMTypes.h" #define GET_OP_CLASSES #include "polygeist/PolygeistOps.cpp.inc" @@ -275,15 +275,16 @@ public: LogicalResult matchAndRewrite(Memref2PointerOp op, PatternRewriter &rewriter) const override { - auto src = op.source().getDefiningOp(); - if (!src) - return failure(); + auto src = op.source().getDefiningOp(); + if (!src) + return failure(); - rewriter.replaceOpWithNewOp(op, op.getType(), src.source()); - return success(); + rewriter.replaceOpWithNewOp(op, op.getType(), + src.source()); + return success(); } }; -void Memref2PointerOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { +void Memref2PointerOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { results.insert(context); } diff --git a/lib/polygeist/Passes/CanonicalizeFor.cpp b/lib/polygeist/Passes/CanonicalizeFor.cpp index 6c61611..2666551 100644 --- a/lib/polygeist/Passes/CanonicalizeFor.cpp +++ b/lib/polygeist/Passes/CanonicalizeFor.cpp @@ -1428,34 +1428,29 @@ struct ReturnSq : public OpRewritePattern { LogicalResult matchAndRewrite(ReturnOp op, PatternRewriter &rewriter) const override { bool changed = false; - SmallVector toErase; - for (auto iter = op->getBlock()->rbegin(); iter != op->getBlock()->rend() && &*iter != op; iter++) { - changed = true; - toErase.push_back(&*iter); + SmallVector toErase; + for (auto iter = op->getBlock()->rbegin(); + iter != op->getBlock()->rend() && &*iter != op; iter++) { + changed = true; + toErase.push_back(&*iter); } - for(auto op : toErase) { - rewriter.eraseOp(op); + for (auto op : toErase) { + rewriter.eraseOp(op); } return success(changed); } }; void CanonicalizeFor::runOnFunction() { mlir::RewritePatternSet rpl(getFunction().getContext()); - rpl.add< - PropagateInLoopBody, DetectTrivialIndVarInArgs, - ForOpInductionReplacement, RemoveUnusedArgs, MoveWhileToFor, - - MoveWhileDown, - MoveWhileDown2 - - , MoveWhileDown3 + rpl.add(getFunction().getContext()); + MoveWhileDown3, MoveWhileInvariantIfResult, WhileLogicalNegation, + SubToAdd, WhileCmpOffset, WhileLICM, RemoveUnusedCondVar, ReturnSq, + MoveSideEffectFreeWhile>(getFunction().getContext()); GreedyRewriteConfig config; config.maxIterations = 47; applyPatternsAndFoldGreedily(getFunction().getOperation(), std::move(rpl), diff --git a/lib/polygeist/Passes/ParallelLower.cpp b/lib/polygeist/Passes/ParallelLower.cpp index 4bf9737..0681f9c 100644 --- a/lib/polygeist/Passes/ParallelLower.cpp +++ b/lib/polygeist/Passes/ParallelLower.cpp @@ -150,9 +150,9 @@ void ParallelLower::runOnFunction() { symbolTable.getSymbolTable(symbolTableOp); getFunction().walk([&](mlir::CallOp bidx) { - if (bidx.callee() == "cudaThreadSynchronize") - bidx.erase(); - }); + if (bidx.callee() == "cudaThreadSynchronize") + bidx.erase(); + }); // Only supports single block functions at the moment. getFunction().walk([&](gpu::LaunchOp launchOp) { @@ -250,19 +250,20 @@ void ParallelLower::runOnFunction() { }); container.walk([&](mlir::memref::AllocaOp alop) { - if (auto ia = alop.getType().getMemorySpace().dyn_cast_or_null()) - if (ia.getValue() == 5) { - mlir::OpBuilder bz(launchOp.getContext()); - bz.setInsertionPointToStart(blockB); - auto newAlloca = bz.create( - alop.getLoc(), - MemRefType::get(alop.getType().getShape(), - alop.getType().getElementType(), - alop.getType().getAffineMaps(), (uint64_t)0)); - alop.replaceAllUsesWith((mlir::Value)bz.create( - alop.getLoc(), newAlloca, alop.getType())); - alop.erase(); - } + if (auto ia = + alop.getType().getMemorySpace().dyn_cast_or_null()) + if (ia.getValue() == 5) { + mlir::OpBuilder bz(launchOp.getContext()); + bz.setInsertionPointToStart(blockB); + auto newAlloca = bz.create( + alop.getLoc(), + MemRefType::get(alop.getType().getShape(), + alop.getType().getElementType(), + alop.getType().getAffineMaps(), (uint64_t)0)); + alop.replaceAllUsesWith((mlir::Value)bz.create( + alop.getLoc(), newAlloca, alop.getType())); + alop.erase(); + } }); container.walk([&](mlir::gpu::ThreadIdOp bidx) { diff --git a/lib/polygeist/Passes/TrivialUse.cpp b/lib/polygeist/Passes/TrivialUse.cpp index 2c0627d..e09126c 100644 --- a/lib/polygeist/Passes/TrivialUse.cpp +++ b/lib/polygeist/Passes/TrivialUse.cpp @@ -32,9 +32,6 @@ std::unique_ptr> createRemoveTrivialUsePass() { } // namespace polygeist } // namespace mlir - void RemoveTrivialUse::runOnFunction() { - getFunction().walk([&](polygeist::TrivialUseOp bidx) { - bidx.erase(); - }); + getFunction().walk([&](polygeist::TrivialUseOp bidx) { bidx.erase(); }); } diff --git a/mlir-clang/Lib/clang-mlir.cc b/mlir-clang/Lib/clang-mlir.cc index 02e7571..90b1e44 100644 --- a/mlir-clang/Lib/clang-mlir.cc +++ b/mlir-clang/Lib/clang-mlir.cc @@ -64,23 +64,24 @@ mlir::Value MLIRScanner::createAllocOp(mlir::Type t, VarDecl *name, auto pshape = shape[0]; if (name) - if (auto var = dyn_cast(name->getType()->getUnqualifiedDesugaredType())) { + if (auto var = dyn_cast( + name->getType()->getUnqualifiedDesugaredType())) { llvm::errs() << t << "\n"; assert(shape[0] == -1); - mr = mlir::MemRefType::get(shape, mt.getElementType(), mt.getAffineMaps(), - memspace); + mr = mlir::MemRefType::get(shape, mt.getElementType(), + mt.getAffineMaps(), memspace); auto len = Visit(var->getSizeExpr()).getValue(builder); len = builder.create(loc, len, builder.getIndexType()); alloc = builder.create(loc, mr, len); builder.create(loc, alloc); - } + } if (!alloc) { - if (pshape == -1) - shape[0] = 1; - mr = mlir::MemRefType::get(shape, mt.getElementType(), mt.getAffineMaps(), - memspace); - alloc = abuilder.create(loc, mr); + if (pshape == -1) + shape[0] = 1; + mr = mlir::MemRefType::get(shape, mt.getElementType(), mt.getAffineMaps(), + memspace); + alloc = abuilder.create(loc, mr); } shape[0] = pshape; @@ -108,7 +109,8 @@ ValueWithOffsets MLIRScanner::VisitDeclStmt(clang::DeclStmt *decl) { for (auto sub : decl->decls()) { if (auto vd = dyn_cast(sub)) { VisitVarDecl(vd); - } else if (isa(sub)) { + } else if (isa( + sub)) { } else { llvm::errs() << " + visiting unknonwn sub decl stmt\n"; sub->dump(); @@ -466,7 +468,7 @@ MLIRScanner::VisitCXXFunctionalCastExpr(clang::CXXFunctionalCastExpr *expr) { builder.create(loc, scalar, postTy), /*isReference*/ false); } -} + } expr->dump(); assert(0 && "unhandled functional cast type"); } @@ -904,33 +906,44 @@ ValueWithOffsets MLIRScanner::VisitDoStmt(clang::DoStmt *fors) { return nullptr; } -ValueWithOffsets MLIRScanner::VisitOMPParallelForDirective(clang::OMPParallelForDirective* fors) { +ValueWithOffsets MLIRScanner::VisitOMPParallelForDirective( + clang::OMPParallelForDirective *fors) { IfScope scope(*this); fors->dump(); - + Visit(fors->getPreInits()); SmallVector inits; for (auto f : fors->inits()) { - llvm::errs() << " init: "; f->dump(); llvm::errs() << "\n"; - inits.push_back(builder.create(loc, Visit(f).getValue(builder), builder.getIndexType())); + llvm::errs() << " init: "; + f->dump(); + llvm::errs() << "\n"; + inits.push_back(builder.create(loc, Visit(f).getValue(builder), + builder.getIndexType())); } inits.clear(); inits.push_back(getConstantIndex(0)); SmallVector finals; for (auto f : fors->finals()) { - llvm::errs() << " final: "; f->dump(); llvm::errs() << "\n"; - finals.push_back(builder.create(loc, Visit(f).getValue(builder), builder.getIndexType())); + llvm::errs() << " final: "; + f->dump(); + llvm::errs() << "\n"; + finals.push_back(builder.create( + loc, Visit(f).getValue(builder), builder.getIndexType())); } finals.clear(); - finals.push_back(builder.create(loc, Visit(fors->getNumIterations()).getValue(builder), builder.getIndexType())); - + finals.push_back(builder.create( + loc, Visit(fors->getNumIterations()).getValue(builder), + builder.getIndexType())); + SmallVector counters; mlir::Value incs[] = {getConstantIndex(1)}; for (auto m : fors->getInnermostCapturedStmt()->captures()) { - llvm::errs() << " cap: "; m.getCapturedVar()->dump(); llvm::errs() << "\n"; + llvm::errs() << " cap: "; + m.getCapturedVar()->dump(); + llvm::errs() << "\n"; /* if (m->getCaptureKind() == LambdaCaptureKind::LCK_ByCopy) CommonFieldLookup(expr->getCallOperator()->getThisObjectType(), @@ -947,29 +960,57 @@ ValueWithOffsets MLIRScanner::VisitOMPParallelForDirective(clang::OMPParallelFor } for (auto f : fors->counters()) { - llvm::errs() << " counter: "; f->dump(); llvm::errs() << "\n"; + llvm::errs() << " counter: "; + f->dump(); + llvm::errs() << "\n"; cast(f)->getDecl()->dump(); - //counters.push_back(builder.create(loc, Visit(f).getValue(builder), builder.getIndexType())); + // counters.push_back(builder.create(loc, + // Visit(f).getValue(builder), builder.getIndexType())); } - llvm::errs() << " assoc:"; fors->getAssociatedStmt()->dump(); llvm::errs() << "\n"; - llvm::errs() << " sblock:"; fors->getStructuredBlock()->dump(); llvm::errs() << "\n"; + llvm::errs() << " assoc:"; + fors->getAssociatedStmt()->dump(); + llvm::errs() << "\n"; + llvm::errs() << " sblock:"; + fors->getStructuredBlock()->dump(); + llvm::errs() << "\n"; - llvm::errs() << " preinit: "; fors->getPreInits()->dump(); llvm::errs() << "\n"; - llvm::errs() << " init: "; fors->getInit()->dump(); llvm::errs() << "\n"; - llvm::errs() << " lb: "; fors->getLowerBoundVariable()->dump(); llvm::errs() << "\n"; - llvm::errs() << " ub: "; fors->getUpperBoundVariable()->dump(); llvm::errs() << "\n"; - llvm::errs() << " precond: "; fors->getPreCond()->dump(); llvm::errs() << "\n"; - llvm::errs() << " cond: "; fors->getCond()->dump(); llvm::errs() << "\n"; - llvm::errs() << " inc: "; fors->getInc()->dump(); llvm::errs() << "\n"; - llvm::errs() << " last: "; fors->getLastIteration()->dump(); llvm::errs() << "\n"; - llvm::errs() << " stride: "; fors->getStrideVariable()->dump(); llvm::errs() << "\n"; - llvm::errs() << " iters: "; fors->getNumIterations()->dump(); llvm::errs() << "\n"; - llvm::errs() << " body: "; fors->getBody()->dump(); llvm::errs() << "\n"; + llvm::errs() << " preinit: "; + fors->getPreInits()->dump(); + llvm::errs() << "\n"; + llvm::errs() << " init: "; + fors->getInit()->dump(); + llvm::errs() << "\n"; + llvm::errs() << " lb: "; + fors->getLowerBoundVariable()->dump(); + llvm::errs() << "\n"; + llvm::errs() << " ub: "; + fors->getUpperBoundVariable()->dump(); + llvm::errs() << "\n"; + llvm::errs() << " precond: "; + fors->getPreCond()->dump(); + llvm::errs() << "\n"; + llvm::errs() << " cond: "; + fors->getCond()->dump(); + llvm::errs() << "\n"; + llvm::errs() << " inc: "; + fors->getInc()->dump(); + llvm::errs() << "\n"; + llvm::errs() << " last: "; + fors->getLastIteration()->dump(); + llvm::errs() << "\n"; + llvm::errs() << " stride: "; + fors->getStrideVariable()->dump(); + llvm::errs() << "\n"; + llvm::errs() << " iters: "; + fors->getNumIterations()->dump(); + llvm::errs() << "\n"; + llvm::errs() << " body: "; + fors->getBody()->dump(); + llvm::errs() << "\n"; + + auto affineOp = builder.create(loc, inits, finals, incs); - auto affineOp = builder.create( - loc, inits, finals, incs); - fors->getIterationVariable()->dump(); auto inds = affineOp.getInductionVars(); @@ -985,8 +1026,10 @@ ValueWithOffsets MLIRScanner::VisitOMPParallelForDirective(clang::OMPParallelFor builder.create(loc); builder.setInsertionPointToStart(&er.region().back()); - auto idx = builder.create(loc, inds[0], getMLIRType(fors->getIterationVariable()->getType())); - VarDecl* name = cast(cast(fors->counters()[0])->getDecl()); + auto idx = builder.create( + loc, inds[0], getMLIRType(fors->getIterationVariable()->getType())); + VarDecl *name = + cast(cast(fors->counters()[0])->getDecl()); assert(params.find(name) == params.end()); params[name] = ValueWithOffsets(idx, false); @@ -1146,7 +1189,7 @@ MLIRScanner::VisitConstructCommon(clang::CXXConstructExpr *cons, VarDecl *name, isArray = true; } if (op == nullptr) - op = createAllocOp(subType, name, memtype, isArray, LLVMABI); + op = createAllocOp(subType, name, memtype, isArray, LLVMABI); SmallVector args; args.push_back(op); @@ -1393,15 +1436,16 @@ MLIRScanner::VisitArraySubscriptExpr(clang::ArraySubscriptExpr *expr) { // Check the RHS has been successfully emitted assert(rhs); auto idx = castToIndex(getMLIRLocation(expr->getRBracketLoc()), rhs); - if (isa(expr->getLHS()->getType()->getUnqualifiedDesugaredType())) { + if (isa( + expr->getLHS()->getType()->getUnqualifiedDesugaredType())) { assert(moo.isReference); moo.isReference = false; auto mt = moo.val.getType().cast(); - + auto shape = std::vector(mt.getShape()); shape.erase(shape.begin()); auto mt0 = mlir::MemRefType::get(shape, mt.getElementType(), - mt.getAffineMaps(), mt.getMemorySpace()); + mt.getAffineMaps(), mt.getMemorySpace()); moo.val = builder.create(loc, mt0, moo.val, getConstantIndex(0)); } @@ -1657,15 +1701,13 @@ ValueWithOffsets MLIRScanner::VisitCallExpr(clang::CallExpr *expr) { } if (auto ic = dyn_cast(expr->getCallee())) if (auto sr = dyn_cast(ic->getSubExpr())) { - if (sr->getDecl()->getIdentifier() && - sr->getDecl()->getName() == "log") { + if (sr->getDecl()->getIdentifier() && sr->getDecl()->getName() == "log") { std::vector args; for (auto a : expr->arguments()) { args.push_back(Visit(a).getValue(builder)); } - return ValueWithOffsets( - builder.create(loc, args[0]), - /*isReference*/ false); + return ValueWithOffsets(builder.create(loc, args[0]), + /*isReference*/ false); } } if (auto ic = dyn_cast(expr->getCallee())) @@ -1783,14 +1825,14 @@ ValueWithOffsets MLIRScanner::VisitCallExpr(clang::CallExpr *expr) { } } if (auto IC1 = dyn_cast(E)) { - if (auto U0 = dyn_cast(IC1->getSubExpr())) - if (auto IC2 = dyn_cast(U0->getSubExpr())) { - if (auto slit = - dyn_cast(IC2->getFunctionName())) { - return Glob.GetOrCreateGlobalLLVMString(loc, builder, - slit->getString()); + if (auto U0 = dyn_cast(IC1->getSubExpr())) + if (auto IC2 = dyn_cast(U0->getSubExpr())) { + if (auto slit = + dyn_cast(IC2->getFunctionName())) { + return Glob.GetOrCreateGlobalLLVMString(loc, builder, + slit->getString()); + } } - } } if (auto slit = dyn_cast(IC1->getSubExpr())) { return Glob.GetOrCreateGlobalLLVMString(loc, builder, @@ -1811,7 +1853,8 @@ ValueWithOffsets MLIRScanner::VisitCallExpr(clang::CallExpr *expr) { } auto val = sub.getValue(builder); if (auto mt = val.getType().dyn_cast()) { - val = builder.create(loc, LLVM::LLVMPointerType::get(mt.getElementType()), val); + val = builder.create( + loc, LLVM::LLVMPointerType::get(mt.getElementType()), val); } return val; }; @@ -1869,12 +1912,11 @@ ValueWithOffsets MLIRScanner::VisitCallExpr(clang::CallExpr *expr) { builder.create(loc, args[0]); if (sr->getDecl()->getName() == "cudaFree" || - sr->getDecl()->getName() == "cudaFreeHost") { - auto ty = getMLIRType(expr->getType()); - auto op = builder.create( - loc, ty, - builder.getIntegerAttr(ty, /*cudaSuccess*/0)); - return ValueWithOffsets(op, /*isReference*/false); + sr->getDecl()->getName() == "cudaFreeHost") { + auto ty = getMLIRType(expr->getType()); + auto op = builder.create( + loc, ty, builder.getIntegerAttr(ty, /*cudaSuccess*/ 0)); + return ValueWithOffsets(op, /*isReference*/ false); } return nullptr; } @@ -1923,47 +1965,49 @@ ValueWithOffsets MLIRScanner::VisitCallExpr(clang::CallExpr *expr) { if (auto BC = dyn_cast(expr->getArg(0))) { auto dst = Visit(BC->getSubExpr()).getValue(builder); if (auto omt = dst.getType().dyn_cast()) { - auto mt = omt.getElementType().dyn_cast(); - auto shape = std::vector(mt.getShape()); + auto mt = omt.getElementType().dyn_cast(); + auto shape = std::vector(mt.getShape()); - auto elemSize = getTypeSize( - cast( - cast(BC->getSubExpr() - ->getType() - ->getUnqualifiedDesugaredType()) - ->getPointeeType()) - ->getPointeeType()); - mlir::Value allocSize = builder.create( - loc, Visit(expr->getArg(1)).getValue(builder), - mlir::IndexType::get(builder.getContext())); - mlir::Value args[1] = {builder.create( - loc, allocSize, - builder.create( - loc, allocSize.getType(), - builder.getIntegerAttr(allocSize.getType(), elemSize)))}; - auto alloc = builder.create( - loc, - (sr->getDecl()->getName() == "cudaMalloc" && !CudaLower) - ? mlir::MemRefType::get(shape, mt.getElementType(), - mt.getAffineMaps(), 1) - : mt, - args); - ValueWithOffsets(dst, /*isReference*/ true) - .store(builder, - builder.create(loc, alloc, mt)); - return ValueWithOffsets(getConstantIndex(0), /*isReference*/ false); - } } + auto elemSize = + getTypeSize(cast( + cast( + BC->getSubExpr() + ->getType() + ->getUnqualifiedDesugaredType()) + ->getPointeeType()) + ->getPointeeType()); + mlir::Value allocSize = builder.create( + loc, Visit(expr->getArg(1)).getValue(builder), + mlir::IndexType::get(builder.getContext())); + mlir::Value args[1] = {builder.create( + loc, allocSize, + builder.create( + loc, allocSize.getType(), + builder.getIntegerAttr(allocSize.getType(), elemSize)))}; + auto alloc = builder.create( + loc, + (sr->getDecl()->getName() == "cudaMalloc" && !CudaLower) + ? mlir::MemRefType::get(shape, mt.getElementType(), + mt.getAffineMaps(), 1) + : mt, + args); + ValueWithOffsets(dst, /*isReference*/ true) + .store(builder, + builder.create(loc, alloc, mt)); + return ValueWithOffsets(getConstantIndex(0), /*isReference*/ false); + } + } if (auto ic = dyn_cast(expr->getCallee())) if (auto sr = dyn_cast(ic->getSubExpr())) if (sr->getDecl()->getIdentifier() && - (sr->getDecl()->getName() == "cudaMemcpy" || sr->getDecl()->getName() == "memcpy")) { + (sr->getDecl()->getName() == "cudaMemcpy" || + sr->getDecl()->getName() == "memcpy")) { if (auto BCdst = dyn_cast(expr->getArg(0))) { auto elem = - cast(BCdst->getSubExpr() - ->getType() - ->getUnqualifiedDesugaredType()) - ->getPointeeType(); + cast( + BCdst->getSubExpr()->getType()->getUnqualifiedDesugaredType()) + ->getPointeeType(); if (auto BCsrc = dyn_cast(expr->getArg(1))) { auto selem = cast(BCsrc->getSubExpr() @@ -1971,65 +2015,80 @@ ValueWithOffsets MLIRScanner::VisitCallExpr(clang::CallExpr *expr) { ->getUnqualifiedDesugaredType()) ->getPointeeType(); if (elem == selem) { - auto dst = Visit(BCdst->getSubExpr()).getValue(builder); - if (dst.getType().isa()) { - auto src = Visit(BCsrc->getSubExpr()).getValue(builder); + auto dst = Visit(BCdst->getSubExpr()).getValue(builder); + if (dst.getType().isa()) { + auto src = Visit(BCsrc->getSubExpr()).getValue(builder); - bool dstArray = false; - Glob.getMLIRType(elem, &dstArray); - auto elemSize = getTypeSize(elem); - mlir::Value size = builder.create( - loc, Visit(expr->getArg( sr->getDecl()->getName() == "cudaMemcpy" ? 2 : 1)).getValue(builder), - mlir::IndexType::get(builder.getContext())); - size = builder.create( - loc, size, - builder.create(loc, elemSize)); + bool dstArray = false; + Glob.getMLIRType(elem, &dstArray); + auto elemSize = getTypeSize(elem); + mlir::Value size = builder.create( + loc, + Visit(expr->getArg( + sr->getDecl()->getName() == "cudaMemcpy" ? 2 : 1)) + .getValue(builder), + mlir::IndexType::get(builder.getContext())); + size = builder.create( + loc, size, + builder.create(loc, elemSize)); - std::vector start = {getConstantIndex(0)}; - std::vector sizes = {size}; - AffineMap map = builder.getSymbolIdentityMap(); - auto affineOp = - builder.create(loc, start, map, sizes, map); + std::vector start = {getConstantIndex(0)}; + std::vector sizes = {size}; + AffineMap map = builder.getSymbolIdentityMap(); + auto affineOp = + builder.create(loc, start, map, sizes, map); - auto oldpoint = builder.getInsertionPoint(); - auto oldblock = builder.getInsertionBlock(); + auto oldpoint = builder.getInsertionPoint(); + auto oldblock = builder.getInsertionBlock(); - std::vector args = {affineOp.getInductionVar()}; + std::vector args = {affineOp.getInductionVar()}; - builder.setInsertionPointToStart(&affineOp.getLoopBody().front()); + builder.setInsertionPointToStart( + &affineOp.getLoopBody().front()); - if (dstArray) { - std::vector start = {getConstantIndex(0)}; - auto mt = - Glob.getMLIRType(Glob.CGM.getContext().getPointerType(elem)) - .cast(); - auto shape = std::vector(mt.getShape()); - std::vector sizes = {getConstantIndex(shape[1])}; - AffineMap map = builder.getSymbolIdentityMap(); - auto affineOp = - builder.create(loc, start, map, sizes, map); - args.push_back(affineOp.getInductionVar()); - builder.setInsertionPointToStart(&affineOp.getLoopBody().front()); + if (dstArray) { + std::vector start = {getConstantIndex(0)}; + auto mt = Glob.getMLIRType( + Glob.CGM.getContext().getPointerType(elem)) + .cast(); + auto shape = std::vector(mt.getShape()); + std::vector sizes = {getConstantIndex(shape[1])}; + AffineMap map = builder.getSymbolIdentityMap(); + auto affineOp = + builder.create(loc, start, map, sizes, map); + args.push_back(affineOp.getInductionVar()); + builder.setInsertionPointToStart( + &affineOp.getLoopBody().front()); + } + + builder.create( + loc, builder.create(loc, src, args), dst, + args); + + // TODO: set the value of the iteration value to the final bound + // at the end of the loop. + builder.setInsertionPoint(oldblock, oldpoint); + + return ValueWithOffsets(getConstantIndex(0), + /*isReference*/ false); + } } - - builder.create( - loc, builder.create(loc, src, args), dst, args); - - // TODO: set the value of the iteration value to the final bound at - // the end of the loop. - builder.setInsertionPoint(oldblock, oldpoint); - - return ValueWithOffsets(getConstantIndex(0), /*isReference*/ false); - } } } } + } + } } auto callee = EmitCallee(expr->getCallee()); - std::set funcs = {"strcmp", "sprintf", "fputs", "puts", - "memcpy", "cudaMalloc", + std::set funcs = {"strcmp", + "sprintf", + "fputs", + "puts", + "memcpy", + "cudaMalloc", "open", "fopen", - "memset", "cudaMemset", + "memset", + "cudaMemset", "strcpy", "close", "fclose", @@ -2090,7 +2149,6 @@ ValueWithOffsets MLIRScanner::VisitCallExpr(clang::CallExpr *expr) { expr->isXValue()); } - if (auto ic = dyn_cast(expr->getCallee())) if (auto sr = dyn_cast(ic->getSubExpr())) { if (sr->getDecl()->getIdentifier() && @@ -2466,8 +2524,9 @@ ValueWithOffsets MLIRScanner::VisitUnaryOperator(clang::UnaryOperator *U) { val.dump(); } auto ty = val.getType().cast(); - auto c1 = builder.create(loc, ty, - builder.getIntegerAttr(ty, APInt::getAllOnesValue(ty.getWidth()))); + auto c1 = builder.create( + loc, ty, + builder.getIntegerAttr(ty, APInt::getAllOnesValue(ty.getWidth()))); return ValueWithOffsets(builder.create(loc, val, c1), /*isReference*/ false); } @@ -2490,14 +2549,14 @@ ValueWithOffsets MLIRScanner::VisitUnaryOperator(clang::UnaryOperator *U) { shape[0] = -1; } else { shape.insert(shape.begin(), -1); - } auto mt0 = mlir::MemRefType::get(shape, mt.getElementType(), mt.getAffineMaps(), mt.getMemorySpace()); if (!isArray) - res = builder.create(loc, sub.val, mt0); + res = builder.create(loc, sub.val, mt0); else - res = builder.create(loc, mt0, sub.val, getConstantIndex(-1)); + res = builder.create(loc, mt0, sub.val, + getConstantIndex(-1)); return ValueWithOffsets(res, /*isReference*/ false); } @@ -2527,11 +2586,11 @@ ValueWithOffsets MLIRScanner::VisitUnaryOperator(clang::UnaryOperator *U) { mlir::Value next; if (auto ft = ty.dyn_cast()) { - if (prev.getType() != ty) { - U->dump(); - llvm::errs() << " ty: " << ty << "prev: " << prev << "\n"; - } - assert(prev.getType() == ty); + if (prev.getType() != ty) { + U->dump(); + llvm::errs() << " ty: " << ty << "prev: " << prev << "\n"; + } + assert(prev.getType() == ty); next = builder.create( loc, prev, builder.create( @@ -2555,11 +2614,11 @@ ValueWithOffsets MLIRScanner::VisitUnaryOperator(clang::UnaryOperator *U) { llvm::errs() << ty << " - " << prev << "\n"; U->dump(); } - if (prev.getType() != ty) { - U->dump(); - llvm::errs() << " ty: " << ty << "prev: " << prev << "\n"; - } - assert(prev.getType() == ty); + if (prev.getType() != ty) { + U->dump(); + llvm::errs() << " ty: " << ty << "prev: " << prev << "\n"; + } + assert(prev.getType() == ty); next = builder.create( loc, prev, builder.create(loc, 1, @@ -3136,14 +3195,14 @@ ValueWithOffsets MLIRScanner::VisitBinaryOperator(clang::BinaryOperator *BO) { mlir::Value result; if (auto postTy = prev.getType().dyn_cast()) { - mlir::Value rhsV = rhs.getValue(builder); - auto prevTy = rhsV.getType().cast(); - if (prevTy == postTy) {} - else if (prevTy.getWidth() < postTy.getWidth()) { - rhsV = builder.create(loc, rhsV, postTy); - } else { - rhsV = builder.create(loc, rhsV, postTy); - } + mlir::Value rhsV = rhs.getValue(builder); + auto prevTy = rhsV.getType().cast(); + if (prevTy == postTy) { + } else if (prevTy.getWidth() < postTy.getWidth()) { + rhsV = builder.create(loc, rhsV, postTy); + } else { + rhsV = builder.create(loc, rhsV, postTy); + } assert(rhsV.getType() == prev.getType()); result = builder.create(loc, prev, rhsV); } else if (auto pt = @@ -3299,7 +3358,9 @@ ValueWithOffsets MLIRScanner::CommonFieldLookup(clang::QualType CT, size_t fnum = 0; auto CXRD = dyn_cast(rd); - if ((CXRD && (!CXRD->hasDefinition() || CXRD->isPolymorphic() || CXRD->getNumBases() > 0)) || rd->isUnion()) { + if ((CXRD && (!CXRD->hasDefinition() || CXRD->isPolymorphic() || + CXRD->getNumBases() > 0)) || + rd->isUnion()) { auto &layout = Glob.CGM.getTypes().getCGRecordLayout(rd); fnum = layout.getLLVMFieldNo(FD); } else { @@ -3503,7 +3564,8 @@ ValueWithOffsets MLIRScanner::VisitMemberExpr(MemberExpr *ME) { ME->dump(); } base = base.dereference(builder); - OT = cast(OT->getUnqualifiedDesugaredType())->getPointeeType(); + OT = cast(OT->getUnqualifiedDesugaredType()) + ->getPointeeType(); } if (!base.isReference) { EmittingFunctionDecl->dump(); @@ -3746,8 +3808,10 @@ ValueWithOffsets MLIRScanner::VisitCastExpr(CastExpr *E) { builder.create(loc, postTy, scalar), /*isReference*/ false); } - if (scalar.getType().isa() || postTy.isa()) { - return ValueWithOffsets(builder.create(loc, scalar, postTy), false); + if (scalar.getType().isa() || + postTy.isa()) { + return ValueWithOffsets(builder.create(loc, scalar, postTy), + false); } if (!scalar.getType().isa()) { E->dump(); @@ -3845,7 +3909,8 @@ ValueWithOffsets MLIRScanner::VisitCastExpr(CastExpr *E) { case clang::CastKind::CK_PointerToBoolean: { auto scalar = Visit(E->getSubExpr()).getValue(builder); if (auto mt = scalar.getType().dyn_cast()) { - scalar = builder.create(loc, LLVM::LLVMPointerType::get(mt.getElementType()), scalar); + scalar = builder.create( + loc, LLVM::LLVMPointerType::get(mt.getElementType()), scalar); } if (auto LT = scalar.getType().dyn_cast()) { auto nullptr_llvm = builder.create(loc, LT); @@ -3864,7 +3929,8 @@ ValueWithOffsets MLIRScanner::VisitCastExpr(CastExpr *E) { case clang::CastKind::CK_PointerToIntegral: { auto scalar = Visit(E->getSubExpr()).getValue(builder); if (auto mt = scalar.getType().dyn_cast()) { - scalar = builder.create(loc, LLVM::LLVMPointerType::get(mt.getElementType()), scalar); + scalar = builder.create( + loc, LLVM::LLVMPointerType::get(mt.getElementType()), scalar); } if (auto LT = scalar.getType().dyn_cast()) { auto mlirType = getMLIRType(E->getType()); @@ -3946,63 +4012,70 @@ ValueWithOffsets MLIRScanner::VisitSwitchStmt(clang::SwitchStmt *stmt) { assert(cond != nullptr); stmt->dump(); SmallVector caseVals; - + auto er = builder.create(loc, ArrayRef()); er.region().push_back(new Block()); auto oldpoint2 = builder.getInsertionPoint(); auto oldblock2 = builder.getInsertionBlock(); - + auto &exitB = *(new Block()); builder.setInsertionPointToStart(&exitB); builder.create(loc); - SmallVector blocks; + SmallVector blocks; bool inCase = false; - for(auto cse : stmt->getBody()->children()) { + for (auto cse : stmt->getBody()->children()) { if (auto cses = dyn_cast(cse)) { - auto &condB = *(new Block()); - + auto &condB = *(new Block()); - caseVals.push_back((int32_t)Visit(cses->getLHS()).getValue(builder).getDefiningOp().getValue()); - - if (inCase) { - auto noBreak = builder.create(loc, loops.back().noBreak); + caseVals.push_back((int32_t)Visit(cses->getLHS()) + .getValue(builder) + .getDefiningOp() + .getValue()); + + if (inCase) { + auto noBreak = + builder.create(loc, loops.back().noBreak); builder.create(loc, noBreak, &condB, &exitB); loops.pop_back(); - } - - inCase = true; - er.region().getBlocks().push_back(&condB); - blocks.push_back(&condB); - builder.setInsertionPointToStart(&condB); + } - auto i1Ty = builder.getIntegerType(1); - auto type = mlir::MemRefType::get({}, i1Ty, {}, 0); - auto truev = builder.create(loc, true, 1); - loops.push_back( - (LoopContext){builder.create(loc, type), - builder.create(loc, type)}); - builder.create(loc, truev, loops.back().noBreak); - builder.create(loc, truev, loops.back().keepRunning); - Visit(cses->getSubStmt()); + inCase = true; + er.region().getBlocks().push_back(&condB); + blocks.push_back(&condB); + builder.setInsertionPointToStart(&condB); + + auto i1Ty = builder.getIntegerType(1); + auto type = mlir::MemRefType::get({}, i1Ty, {}, 0); + auto truev = builder.create(loc, true, 1); + loops.push_back( + (LoopContext){builder.create(loc, type), + builder.create(loc, type)}); + builder.create(loc, truev, loops.back().noBreak); + builder.create(loc, truev, + loops.back().keepRunning); + Visit(cses->getSubStmt()); } else { - Visit(cse); + Visit(cse); } } - if (inCase) loops.pop_back(); + if (inCase) + loops.pop_back(); builder.create(loc, &exitB); - + er.region().getBlocks().push_back(&exitB); - + DenseIntElementsAttr caseValuesAttr; - ShapedType caseValueType = mlir::VectorType::get( - static_cast(caseVals.size()), cond.getType()); - caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseVals); + ShapedType caseValueType = mlir::VectorType::get( + static_cast(caseVals.size()), cond.getType()); + caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseVals); builder.setInsertionPointToStart(&er.region().front()); - builder.create(loc, cond, &exitB, ArrayRef(), caseValuesAttr, blocks, SmallVector(caseVals.size(), ArrayRef())); + builder.create( + loc, cond, &exitB, ArrayRef(), caseValuesAttr, blocks, + SmallVector(caseVals.size(), ArrayRef())); builder.setInsertionPoint(oldblock2, oldpoint2); return nullptr; } @@ -4038,9 +4111,10 @@ MLIRScanner::VisitConditionalOperator(clang::ConditionalOperator *E) { assert(trueExpr.val); mlir::Value truev; if (isReference) { - assert(trueExpr.isReference); - truev = trueExpr.val; - } else truev = trueExpr.getValue(builder); + assert(trueExpr.isReference); + truev = trueExpr.val; + } else + truev = trueExpr.getValue(builder); assert(truev != nullptr); truearray.push_back(truev); } @@ -4053,9 +4127,10 @@ MLIRScanner::VisitConditionalOperator(clang::ConditionalOperator *E) { if (!E->getType()->isVoidType()) { mlir::Value falsev; if (isReference) { - assert(falseExpr.isReference); - falsev = falseExpr.val; - } else falsev = falseExpr.getValue(builder); + assert(falseExpr.isReference); + falsev = falseExpr.val; + } else + falsev = falseExpr.getValue(builder); assert(falsev != nullptr); falsearray.push_back(falsev); } @@ -4141,16 +4216,17 @@ ValueWithOffsets MLIRScanner::VisitReturnStmt(clang::ReturnStmt *stmt) { } else if (stmt->getRetValue()) { auto rv = Visit(stmt->getRetValue()); if (!stmt->getRetValue()->getType()->isVoidType()) { - if (!rv.val) { + if (!rv.val) { stmt->dump(); - } - assert(rv.val); - if (stmt->getRetValue()->isLValue() || stmt->getRetValue()->isXValue()) { - assert(rv.isReference); - builder.create(loc, rv.val, returnVal); - } else { - builder.create(loc, rv.getValue(builder), returnVal); - } + } + assert(rv.val); + if (stmt->getRetValue()->isLValue() || stmt->getRetValue()->isXValue()) { + assert(rv.isReference); + builder.create(loc, rv.val, returnVal); + } else { + builder.create(loc, rv.getValue(builder), + returnVal); + } } } @@ -4162,7 +4238,7 @@ ValueWithOffsets MLIRScanner::VisitReturnStmt(clang::ReturnStmt *stmt) { builder.create(loc, vfalse, l.keepRunning); builder.create(loc, vfalse, l.noBreak); } - + return nullptr; } @@ -4204,14 +4280,14 @@ MLIRASTConsumer::GetOrCreateLLVMFunction(const FunctionDecl *FD) { mlir::LLVM::GlobalOp MLIRASTConsumer::GetOrCreateLLVMGlobal(const ValueDecl *FD) { std::string name = CGM.getMangledName(FD).str(); - + if (llvmGlobals.find(name) != llvmGlobals.end()) { return llvmGlobals[name]; } LLVM::Linkage lnk; if (!isa(FD)) - FD->dump(); + FD->dump(); switch (CGM.getLLVMLinkageVarDefinition(cast(FD), /*isConstant*/ false)) { case llvm::GlobalValue::LinkageTypes::InternalLinkage: @@ -4443,7 +4519,8 @@ mlir::FuncOp MLIRASTConsumer::GetOrCreateMLIRFunction(const FunctionDecl *FD) { functions[name] = function; module.push_back(function); const FunctionDecl *Def = nullptr; - if (FD->isDefined(Def, /*checkforfriend*/ true) && Def->getTemplatedKind() != FunctionDecl::TK_FunctionTemplate) + if (FD->isDefined(Def, /*checkforfriend*/ true) && + Def->getTemplatedKind() != FunctionDecl::TK_FunctionTemplate) functionsToEmit.push_back(Def); else if (FD->getIdentifier()) emitIfFound.insert(FD->getName().str()); @@ -4461,7 +4538,7 @@ void MLIRASTConsumer::run() { while (functionsToEmit.size()) { const FunctionDecl *FD = functionsToEmit.front(); functionsToEmit.pop_front(); - assert (FD->getTemplatedKind() != FunctionDecl::TK_FunctionTemplate); + assert(FD->getTemplatedKind() != FunctionDecl::TK_FunctionTemplate); std::string name; if (auto CC = dyn_cast(FD)) @@ -4491,7 +4568,9 @@ void MLIRASTConsumer::HandleDeclContext(DeclContext *DC) { continue; if (fd->getIdentifier() == nullptr) continue; - if ((emitIfFound.count("*") && fd->getName() != "fpclassify" && !fd->isStatic())|| emitIfFound.count(fd->getName().str())) { + if ((emitIfFound.count("*") && fd->getName() != "fpclassify" && + !fd->isStatic()) || + emitIfFound.count(fd->getName().str())) { if (fd->getTemplatedKind() != FunctionDecl::TK_FunctionTemplate) functionsToEmit.push_back(fd); } else { @@ -4517,7 +4596,9 @@ bool MLIRASTConsumer::HandleTopLevelDecl(DeclGroupRef dg) { continue; if (fd->getIdentifier() == nullptr) continue; - if ((emitIfFound.count("*") && fd->getName() != "fpclassify" && !fd->isStatic())|| emitIfFound.count(fd->getName().str())) { + if ((emitIfFound.count("*") && fd->getName() != "fpclassify" && + !fd->isStatic()) || + emitIfFound.count(fd->getName().str())) { if (fd->getTemplatedKind() != FunctionDecl::TK_FunctionTemplate) functionsToEmit.push_back(fd); } else { @@ -4603,7 +4684,9 @@ mlir::Type MLIRASTConsumer::getMLIRType(clang::QualType qt, bool *implicitRef, } auto CXRD = dyn_cast(RT->getDecl()); - if (RT->getDecl()->isUnion() || (CXRD && (!CXRD->hasDefinition() || CXRD->isPolymorphic() || CXRD->getDefinition()->getNumBases() > 0)) || + if (RT->getDecl()->isUnion() || + (CXRD && (!CXRD->hasDefinition() || CXRD->isPolymorphic() || + CXRD->getDefinition()->getNumBases() > 0)) || ST->getNumElements() == 0 || recursive || (!ST->isLiteral() && (ST->getName().contains("SmallVector") || ST->getName() == "struct._IO_FILE" || @@ -4657,9 +4740,9 @@ mlir::Type MLIRASTConsumer::getMLIRType(clang::QualType qt, bool *implicitRef, return builder.getNoneType(); } - //if (auto AT = dyn_cast(t)) { - // return getMLIRType(AT->getElementType(), implicitRef, allowMerge); - //} + // if (auto AT = dyn_cast(t)) { + // return getMLIRType(AT->getElementType(), implicitRef, allowMerge); + // } if (auto AT = dyn_cast(t)) { bool subRef = false; @@ -4939,10 +5022,10 @@ public: } std::unique_ptr CreateASTConsumer(CompilerInstance &CI, StringRef InFile) override { - return std::unique_ptr( - new MLIRASTConsumer(emitIfFound, done, llvmStringGlobals, globals, - functions, llvmGlobals, llvmFunctions, CI.getPreprocessor(), - CI.getASTContext(), module, CI.getSourceManager())); + return std::unique_ptr(new MLIRASTConsumer( + emitIfFound, done, llvmStringGlobals, globals, functions, llvmGlobals, + llvmFunctions, CI.getPreprocessor(), CI.getASTContext(), module, + CI.getSourceManager())); } }; @@ -5162,4 +5245,3 @@ static bool parseMLIR(const char *Argv0, std::vector filenames, } return true; } - diff --git a/mlir-clang/Lib/clang-mlir.h b/mlir-clang/Lib/clang-mlir.h index 70451e9..b0cf7e7 100644 --- a/mlir-clang/Lib/clang-mlir.h +++ b/mlir-clang/Lib/clang-mlir.h @@ -108,7 +108,7 @@ struct ValueWithOffsets { // return ValueWithOffsets(builder.create(loc, mt0, val, // c0), /*isReference*/true); if (val.getType().cast().getShape().size() != 1) { - llvm::errs() << " val: " << val << " ty: " << val.getType() << "\n"; + llvm::errs() << " val: " << val << " ty: " << val.getType() << "\n"; } assert(val.getType().cast().getShape().size() == 1); return builder.create(loc, val, @@ -223,15 +223,17 @@ struct ValueWithOffsets { if (auto PT = val.getType().dyn_cast()) { if (toStore.getType() != PT.getElementType()) { if (auto mt = toStore.getType().dyn_cast()) { - if (auto spt = PT.getElementType().dyn_cast()) { + if (auto spt = + PT.getElementType().dyn_cast()) { if (mt.getElementType() == spt.getElementType()) { - toStore = builder.create(loc, spt, toStore); + toStore = builder.create(loc, spt, + toStore); } - } + } } } if (toStore.getType() != PT.getElementType()) { - llvm::errs() << " toStore: " << toStore << " PT: " << PT + llvm::errs() << " toStore: " << toStore << " PT: " << PT << " val: " << val << "\n"; } assert(toStore.getType() == PT.getElementType()); @@ -424,9 +426,10 @@ struct MLIRASTConsumer : public ASTConsumer { clang::SourceManager &SM) : emitIfFound(emitIfFound), done(done), llvmStringGlobals(llvmStringGlobals), globals(globals), - functions(functions), llvmGlobals(llvmGlobals), llvmFunctions(llvmFunctions), PP(PP), - astContext(astContext), module(module), SM(SM), lcontext(), - llvmMod("tmp", lcontext), codegenops(), + functions(functions), llvmGlobals(llvmGlobals), + llvmFunctions(llvmFunctions), PP(PP), astContext(astContext), + module(module), SM(SM), lcontext(), llvmMod("tmp", lcontext), + codegenops(), CGM(astContext, PP.getHeaderSearchInfo().getHeaderSearchOpts(), PP.getPreprocessorOpts(), codegenops, llvmMod, PP.getDiagnostics()), error(false), typeTranslator(*module.getContext()), @@ -666,17 +669,18 @@ public: builder.create(loc, truev, loops.back().noBreak); builder.create(loc, truev, loops.back().keepRunning); if (function.getType().getResults().size()) { - auto type = mlir::MemRefType::get({}, function.getType().getResult(0), {}, 0); - returnVal = builder.create(loc, type); + auto type = + mlir::MemRefType::get({}, function.getType().getResult(0), {}, 0); + returnVal = builder.create(loc, type); } Visit(stmt); - if (function.getType().getResults().size()) { - mlir::Value vals[1] = {builder.create(loc, returnVal)}; - builder.create(loc,vals); - } - else - builder.create(loc); + if (function.getType().getResults().size()) { + mlir::Value vals[1] = { + builder.create(loc, returnVal)}; + builder.create(loc, vals); + } else + builder.create(loc); // function.dump(); } @@ -701,7 +705,8 @@ public: ValueWithOffsets VisitVarDecl(clang::VarDecl *decl); ValueWithOffsets VisitForStmt(clang::ForStmt *fors); - ValueWithOffsets VisitOMPParallelForDirective(clang::OMPParallelForDirective *fors); + ValueWithOffsets + VisitOMPParallelForDirective(clang::OMPParallelForDirective *fors); ValueWithOffsets VisitWhileStmt(clang::WhileStmt *fors); ValueWithOffsets VisitDoStmt(clang::DoStmt *fors); @@ -713,7 +718,7 @@ public: ValueWithOffsets VisitCXXConstructExpr(clang::CXXConstructExpr *expr); ValueWithOffsets VisitConstructCommon(clang::CXXConstructExpr *expr, VarDecl *name, unsigned space, - mlir::Value mem=nullptr); + mlir::Value mem = nullptr); ValueWithOffsets VisitMSPropertyRefExpr(MSPropertyRefExpr *expr); diff --git a/mlir-clang/mlir-clang.cc b/mlir-clang/mlir-clang.cc index 19d608a..9a54433 100644 --- a/mlir-clang/mlir-clang.cc +++ b/mlir-clang/mlir-clang.cc @@ -1,17 +1,17 @@ #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" -#include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h" -#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" +#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h" +#include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h" #include "mlir/Dialect/Affine/Passes.h" #include "mlir/Dialect/GPU/GPUDialect.h" +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Dialect/SCF/Passes.h" #include "mlir/Dialect/SCF/SCF.h" -#include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Verifier.h" -#include "mlir/Target/LLVMIR/Export.h" #include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Export.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/InitLLVM.h" #include @@ -30,7 +30,7 @@ static cl::opt EmitLLVM("emit-llvm", cl::init(false), cl::desc("Emit llvm")); static cl::opt SCFOpenMP("scf-openmp", cl::init(false), - cl::desc("Emit llvm")); + cl::desc("Emit llvm")); static cl::opt ShowAST("show-ast", cl::init(false), cl::desc("Show AST")); @@ -41,7 +41,7 @@ static cl::opt RaiseToAffine("raise-scf-to-affine", cl::init(false), cl::desc("Raise SCF to Affine")); static cl::opt ScalarReplacement("scal-rep", cl::init(true), - cl::desc("Raise SCF to Affine")); + cl::desc("Raise SCF to Affine")); static cl::opt DetectReduction("detect-reduction", cl::init(false), @@ -67,7 +67,7 @@ static cl::opt FOpenMP("fopenmp", cl::init(false), cl::desc("Enable OpenMP")); static cl::opt ToCPU("cpuify", cl::init(false), - cl::desc("Convert to cpu")); + cl::desc("Convert to cpu")); static cl::opt MArch("march", cl::init(""), cl::desc("Architecture")); @@ -171,10 +171,10 @@ int main(int argc, char **argv) { optPM.addPass(mlir::createCanonicalizerPass()); optPM.addPass(polygeist::createLoopRestructurePass()); if (!CudaLower) - optPM.addPass(polygeist::replaceAffineCFGPass()); + optPM.addPass(polygeist::replaceAffineCFGPass()); optPM.addPass(mlir::createCanonicalizerPass()); if (ScalarReplacement) - optPM.addPass(mlir::createAffineScalarReplacementPass()); + optPM.addPass(mlir::createAffineScalarReplacementPass()); optPM.addPass(mlir::createLoopInvariantCodeMotionPass()); optPM.addPass(mlir::createCanonicalizerPass()); optPM.addPass(polygeist::createCanonicalizeForPass()); @@ -186,7 +186,7 @@ int main(int argc, char **argv) { optPM.addPass(polygeist::createRaiseSCFToAffinePass()); optPM.addPass(polygeist::replaceAffineCFGPass()); if (ScalarReplacement) - optPM.addPass(mlir::createAffineScalarReplacementPass()); + optPM.addPass(mlir::createAffineScalarReplacementPass()); } if (mlir::failed(pm.run(module))) { module.dump(); @@ -209,7 +209,6 @@ int main(int argc, char **argv) { optPM.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createSymbolDCEPass()); - if (CudaLower) { optPM.addPass(polygeist::createParallelLowerPass()); optPM.addPass(polygeist::replaceAffineCFGPass()); @@ -219,21 +218,19 @@ int main(int argc, char **argv) { optPM.addPass(mlir::createCanonicalizerPass()); optPM.addPass(polygeist::createCanonicalizeForPass()); optPM.addPass(mlir::createCanonicalizerPass()); - - if (RaiseToAffine) { - optPM.addPass(polygeist::createCanonicalizeForPass()); - optPM.addPass(mlir::createCanonicalizerPass()); - optPM.addPass(mlir::createLoopInvariantCodeMotionPass()); - optPM.addPass(polygeist::createRaiseSCFToAffinePass()); - optPM.addPass(polygeist::replaceAffineCFGPass()); - if (ScalarReplacement) - optPM.addPass(mlir::createAffineScalarReplacementPass()); - } + + if (RaiseToAffine) { + optPM.addPass(polygeist::createCanonicalizeForPass()); + optPM.addPass(mlir::createCanonicalizerPass()); + optPM.addPass(mlir::createLoopInvariantCodeMotionPass()); + optPM.addPass(polygeist::createRaiseSCFToAffinePass()); + optPM.addPass(polygeist::replaceAffineCFGPass()); + if (ScalarReplacement) + optPM.addPass(mlir::createAffineScalarReplacementPass()); + } if (ToCPU) optPM.addPass(polygeist::createCPUifyPass()); - } - if (EmitLLVM) { pm.addPass(mlir::createLowerAffinePass()); @@ -263,11 +260,10 @@ int main(int argc, char **argv) { } } else { - if (mlir::failed(pm.run(module))) { - module.dump(); - return 4; - } - + if (mlir::failed(pm.run(module))) { + module.dump(); + return 4; + } } // module.dump(); if (mlir::failed(mlir::verify(module))) {