diff --git a/lib/polygeist/Passes/LoopRestructure.cpp b/lib/polygeist/Passes/LoopRestructure.cpp index e040d52..43a6945 100644 --- a/lib/polygeist/Passes/LoopRestructure.cpp +++ b/lib/polygeist/Passes/LoopRestructure.cpp @@ -644,6 +644,27 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region ®ion) { attemptToFoldIntoPredecessor(wrapper); attemptToFoldIntoPredecessor(target); + if (loop.before().getBlocks().size() != 1) { + Block* blk = new Block(); + OpBuilder B(loop.getContext()); + B.setInsertionPointToEnd(blk); + auto cop = cast(loop.before().getBlocks().back().back()); + auto er = B.create(loop.getLoc(), cop.getOperandTypes()); + er.region().getBlocks().splice(er.region().getBlocks().begin(), loop.before().getBlocks()); + loop.before().push_back(blk); + SmallVector yields; + for(auto a : er.getResults()) yields.push_back(a); + yields.erase(yields.begin()); + B.create(cop.getLoc(), er.getResult(0), yields); + B.setInsertionPoint(&*cop); + for(auto arg : er.region().front().getArguments()) { + auto na = blk->addArgument(arg.getType()); + arg.replaceAllUsesWith(na); + } + er.region().front().eraseArguments([](BlockArgument){ return true; }); + B.create(cop.getLoc(), cop.getOperands()); + cop.erase(); + } assert(loop.before().getBlocks().size() == 1); runOnRegion(domInfo, loop.after()); assert(loop.after().getBlocks().size() == 1); diff --git a/tools/mlir-clang/Lib/clang-mlir.cc b/tools/mlir-clang/Lib/clang-mlir.cc index a4eb757..7da35c8 100644 --- a/tools/mlir-clang/Lib/clang-mlir.cc +++ b/tools/mlir-clang/Lib/clang-mlir.cc @@ -408,6 +408,14 @@ mlir::Value MLIRScanner::createAllocOp(mlir::Type t, VarDecl *name, ValueCategory MLIRScanner::VisitConstantExpr(clang::ConstantExpr *expr) { auto sv = Visit(expr->getSubExpr()); + if (auto ty = getMLIRType(expr->getType()).dyn_cast()) { + if (expr->hasAPValueResult()) { + return ValueCategory( + builder.create(getMLIRLocation(expr->getExprLoc()), + expr->getResultAsAPSInt().getExtValue(), ty), + /*isReference*/ false); + } + } assert(sv.val); return sv; } @@ -3597,20 +3605,25 @@ ValueCategory MLIRScanner::VisitBinaryOperator(clang::BinaryOperator *BO) { } case clang::BinaryOperator::Opcode::BO_Sub: { auto lhs_v = lhs.getValue(builder); + auto rhs_v = rhs.getValue(builder); if (lhs_v.getType().isa()) { - auto right = rhs.getValue(builder); - assert(right.getType() == lhs_v.getType()); - return ValueCategory(builder.create(loc, lhs_v, right), + assert(rhs_v.getType() == lhs_v.getType()); + return ValueCategory(builder.create(loc, lhs_v, rhs_v), /*isReference*/ false); } else if (auto pt = lhs_v.getType().dyn_cast()) { + if (auto IT = rhs_v.getType().dyn_cast()) { + mlir::Value vals[1] = {builder.create(loc, builder.create(loc, 0, IT.getWidth()), rhs_v)}; + return ValueCategory(builder.create(loc, lhs_v.getType(), + lhs_v, ArrayRef(vals)), false); + } return ValueCategory( builder.create( loc, builder.create(loc, getMLIRType(BO->getType()), lhs_v), builder.create(loc, getMLIRType(BO->getType()), - rhs.getValue(builder))), + rhs_v)), /*isReference*/ false); } else if (auto mt = lhs_v.getType().dyn_cast()) { llvm::errs() << " memref ptrtoint: " << mt << "\n"; @@ -3620,11 +3633,11 @@ ValueCategory MLIRScanner::VisitBinaryOperator(clang::BinaryOperator *BO) { builder.create(loc, getMLIRType(BO->getType()), lhs_v), builder.create(loc, getMLIRType(BO->getType()), - rhs.getValue(builder))), + rhs_v)), /*isReference*/ false); } else { return ValueCategory( - builder.create(loc, lhs_v, rhs.getValue(builder)), + builder.create(loc, lhs_v, rhs_v), /*isReference*/ false); } }