From 239a2c4d82bba4bf999db2df51bfef47ffd96e33 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Thu, 30 Dec 2021 15:45:42 -0500 Subject: [PATCH] Rebase LLVM --- include/polygeist/BarrierUtils.h | 2 +- lib/polygeist/Ops.cpp | 13 + lib/polygeist/Passes/AffineCFG.cpp | 12 +- .../Passes/BarrierRemovalContinuation.cpp | 34 +-- lib/polygeist/Passes/CanonicalizeFor.cpp | 250 +++++++++--------- lib/polygeist/Passes/LoopRestructure.cpp | 58 ++-- lib/polygeist/Passes/Mem2Reg.cpp | 60 ++--- .../Passes/ParallelLoopDistribute.cpp | 108 ++++---- lib/polygeist/Passes/RaiseToAffine.cpp | 18 +- llvm-project | 2 +- tools/mlir-clang/Lib/CGStmt.cc | 32 +-- tools/mlir-clang/Lib/IfScope.cc | 8 +- tools/mlir-clang/Lib/clang-mlir.cc | 32 +-- 13 files changed, 315 insertions(+), 314 deletions(-) diff --git a/include/polygeist/BarrierUtils.h b/include/polygeist/BarrierUtils.h index 08fc22f..85591b0 100644 --- a/include/polygeist/BarrierUtils.h +++ b/include/polygeist/BarrierUtils.h @@ -29,7 +29,7 @@ static llvm::SmallVector emitIterationCounts(mlir::OpBuilder &rewriter, mlir::scf::ParallelOp op) { using namespace mlir; SmallVector iterationCounts; - for (auto bounds : llvm::zip(op.lowerBound(), op.upperBound(), op.step())) { + for (auto bounds : llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep())) { Value lowerBound = std::get<0>(bounds); Value upperBound = std::get<1>(bounds); Value step = std::get<2>(bounds); diff --git a/lib/polygeist/Ops.cpp b/lib/polygeist/Ops.cpp index 8ca40c6..9848189 100644 --- a/lib/polygeist/Ops.cpp +++ b/lib/polygeist/Ops.cpp @@ -693,6 +693,9 @@ public: if (src.source().getType().cast().getElementType() != op.getType().cast().getElementType()) return failure(); + if (src.source().getType().cast().getMemorySpace() != + op.getType().cast().getMemorySpace()) + return failure(); rewriter.replaceOpWithNewOp(op, op.getType(), src.source()); return success(); @@ -732,6 +735,11 @@ OpFoldResult Memref2PointerOp::fold(ArrayRef operands) { sourceMutable().assign(mc.source()); return result(); } + if (auto mc = source().getDefiningOp()) { + if (mc.source().getType() == getType()) { + return mc.source(); + } + } return nullptr; } @@ -868,6 +876,11 @@ OpFoldResult Pointer2MemrefOp::fold(ArrayRef operands) { sourceMutable().assign(mc.getArg()); return result(); } + if (auto mc = source().getDefiningOp()) { + if (mc.source().getType() == getType()) { + return mc.source(); + } + } return nullptr; } diff --git a/lib/polygeist/Passes/AffineCFG.cpp b/lib/polygeist/Passes/AffineCFG.cpp index b7d8fbd..31ec0d9 100644 --- a/lib/polygeist/Passes/AffineCFG.cpp +++ b/lib/polygeist/Passes/AffineCFG.cpp @@ -1047,7 +1047,7 @@ void AffineCFGPass::runOnFunction() { OpBuilder b(ifOp); AffineIfOp affineIfOp; std::vector types; - for (auto v : ifOp.results()) { + for (auto v : ifOp.getResults()) { types.push_back(v.getType()); } @@ -1055,7 +1055,7 @@ void AffineCFGPass::runOnFunction() { SmallVector eqflags; SmallVector applies; - std::deque todo = {ifOp.condition()}; + std::deque todo = {ifOp.getCondition()}; while (todo.size()) { auto cur = todo.front(); todo.pop_front(); @@ -1077,20 +1077,20 @@ void AffineCFGPass::runOnFunction() { eqflags); affineIfOp = b.create(ifOp.getLoc(), types, iset, applies, /*elseBlock=*/true); - affineIfOp.thenRegion().takeBody(ifOp.thenRegion()); - affineIfOp.elseRegion().takeBody(ifOp.elseRegion()); + affineIfOp.thenRegion().takeBody(ifOp.getThenRegion()); + affineIfOp.elseRegion().takeBody(ifOp.getElseRegion()); for (auto &blk : affineIfOp.thenRegion()) { if (auto yop = dyn_cast(blk.getTerminator())) { OpBuilder b(yop); - b.create(yop.getLoc(), yop.results()); + b.create(yop.getLoc(), yop.getResults()); yop.erase(); } } for (auto &blk : affineIfOp.elseRegion()) { if (auto yop = dyn_cast(blk.getTerminator())) { OpBuilder b(yop); - b.create(yop.getLoc(), yop.results()); + b.create(yop.getLoc(), yop.getResults()); yop.erase(); } } diff --git a/lib/polygeist/Passes/BarrierRemovalContinuation.cpp b/lib/polygeist/Passes/BarrierRemovalContinuation.cpp index 2d2e3da..68852f5 100644 --- a/lib/polygeist/Passes/BarrierRemovalContinuation.cpp +++ b/lib/polygeist/Passes/BarrierRemovalContinuation.cpp @@ -57,14 +57,14 @@ static void wrapPersistingLoopBodies(FuncOp function) { for (scf::ParallelOp op : loops) { OpBuilder builder = OpBuilder::atBlockBegin(op.getBody()); auto wrapper = builder.create( - op.getLoc(), op.results().getTypes()); - builder.createBlock(&wrapper.region(), wrapper.region().begin()); - wrapper.region().front().getOperations().splice( - wrapper.region().front().begin(), op.getBody()->getOperations(), + op.getLoc(), op.getResults().getTypes()); + builder.createBlock(&wrapper.getRegion(), wrapper.getRegion().begin()); + wrapper.getRegion().front().getOperations().splice( + wrapper.getRegion().front().begin(), op.getBody()->getOperations(), std::next(op.getBody()->begin()), op.getBody()->end()); builder.setInsertionPointToEnd(op.getBody()); builder.create( - wrapper.region().front().getTerminator()->getLoc(), + wrapper.getRegion().front().getTerminator()->getLoc(), wrapper.getResults()); } } @@ -124,7 +124,7 @@ static LogicalResult splitBlocksWithBarrier(FuncOp function) { } splitBlocksWithBarrier( - cast(&op.getBody()->front()).region()); + cast(&op.getBody()->front()).getRegion()); return success(); }); return success(!result.wasInterrupted()); @@ -288,15 +288,15 @@ emitContinuationCase(Value condition, Value storage, scf::ParallelOp parallel, bn.create(TypeRange(), ValueRange()); BlockAndValueMapping mapping; mapping.map(parallel.getInductionVars(), ivs); - replicateIntoRegion(executeRegion.region(), storage, ivs, - parallel.lowerBound(), blocks, subgraphEntryPoints, + replicateIntoRegion(executeRegion.getRegion(), storage, ivs, + parallel.getLowerBound(), blocks, subgraphEntryPoints, mapping, builder); }; auto thenBuilder = [&](OpBuilder &nested, Location loc) { ImplicitLocOpBuilder bn(loc, nested); - bn.create(parallel.lowerBound(), parallel.upperBound(), - parallel.step(), parallelBuilder); + bn.create(parallel.getLowerBound(), parallel.getUpperBound(), + parallel.getStep(), parallelBuilder); bn.create(); }; @@ -362,9 +362,9 @@ findInsertionPointAfterLoopOperands(scf::ParallelOp op) { // Find the earliest insertion point where loop bounds are fully defined. PostDominanceInfo postDominanceInfo(op->getParentOfType()); SmallVector operands; - llvm::append_range(operands, op.lowerBound()); - llvm::append_range(operands, op.upperBound()); - llvm::append_range(operands, op.step()); + llvm::append_range(operands, op.getLowerBound()); + llvm::append_range(operands, op.getUpperBound()); + llvm::append_range(operands, op.getStep()); return findNesrestPostDominatingInsertionPoint(operands, postDominanceInfo); } @@ -536,8 +536,8 @@ static void createContinuations(scf::ParallelOp parallel, Value storage) { llvm::SetVector startBlocks; auto outerExecuteRegion = cast(¶llel.getBody()->front()); - startBlocks.insert(&outerExecuteRegion.region().front()); - for (Block &block : outerExecuteRegion.region()) { + startBlocks.insert(&outerExecuteRegion.getRegion().front()); + for (Block &block : outerExecuteRegion.getRegion()) { if (!isa_and_nonnull( block.getTerminator()->getPrevNode())) continue; @@ -560,12 +560,12 @@ static void createContinuations(scf::ParallelOp parallel, Value storage) { OpBuilder allocBuilder(loop); reg2mem(subgraphs, parallel, allocBuilder, builder); - builder.createBlock(&loop.before(), loop.before().end()); + builder.createBlock(&loop.getBefore(), loop.getBefore().end()); Value next = builder.create(storage); Value condition = builder.create(CmpIPredicate::ne, next, negOne); builder.create(TypeRange(), condition, ValueRange()); - builder.createBlock(&loop.after(), loop.after().end()); + builder.createBlock(&loop.getAfter(), loop.getAfter().end()); next = builder.create(storage); SmallVector caseConditions; caseConditions.resize(startBlocks.size()); diff --git a/lib/polygeist/Passes/CanonicalizeFor.cpp b/lib/polygeist/Passes/CanonicalizeFor.cpp index 66f368f..c1427aa 100644 --- a/lib/polygeist/Passes/CanonicalizeFor.cpp +++ b/lib/polygeist/Passes/CanonicalizeFor.cpp @@ -52,7 +52,7 @@ static bool hasSameInitValue(Value iter, scf::ForOp forOp) { if (!cst) return false; if (auto cstOp = dyn_cast(cst)) { - Operation *lbDefOp = forOp.lowerBound().getDefiningOp(); + Operation *lbDefOp = forOp.getLowerBound().getDefiningOp(); if (!lbDefOp) return false; ConstantIndexOp lb = dyn_cast_or_null(lbDefOp); @@ -69,7 +69,7 @@ static bool hasSameStepValue(Value regIter, Value yieldOp, scf::ForOp forOp) { if (!defOpStep) return false; if (auto cstStep = dyn_cast(defOpStep)) { - Operation *stepForDefOp = forOp.step().getDefiningOp(); + Operation *stepForDefOp = forOp.getStep().getDefiningOp(); if (!stepForDefOp) return false; ConstantIndexOp stepFor = dyn_cast_or_null(stepForDefOp); @@ -123,7 +123,7 @@ struct DetectTrivialIndVarInArgs : public OpRewritePattern { if (!forOp.getNumIterOperands()) return failure(); - Block &block = forOp.region().front(); + Block &block = forOp.getRegion().front(); auto yieldOp = cast(block.getTerminator()); bool matched = false; @@ -150,7 +150,7 @@ struct ForOpInductionReplacement : public OpRewritePattern { LogicalResult matchAndRewrite(scf::ForOp forOp, PatternRewriter &rewriter) const final { bool canonicalize = false; - Block &block = forOp.region().front(); + Block &block = forOp.getRegion().front(); auto yieldOp = cast(block.getTerminator()); for (auto it : llvm::zip(forOp.getIterOperands(), // iter from outside @@ -171,10 +171,10 @@ struct ForOpInductionReplacement : public OpRewritePattern { bool isOne = false; - if (addOp.getOperand(1) == forOp.step()) { + if (addOp.getOperand(1) == forOp.getStep()) { legalStep = true; } else if (auto iter_step = - forOp.step().getDefiningOp()) { + forOp.getStep().getDefiningOp()) { isOne |= iter_step.value() == 1; if (auto op = addOp.getOperand(1).getDefiningOp()) { if (op.value() == iter_step.value()) { @@ -200,9 +200,9 @@ struct ForOpInductionReplacement : public OpRewritePattern { Value init = std::get<0>(it); if (!std::get<1>(it).use_empty()) { - rewriter.setInsertionPointToStart(&forOp.region().front()); + rewriter.setInsertionPointToStart(&forOp.getRegion().front()); Value replacement = rewriter.create( - forOp.getLoc(), forOp.getInductionVar(), forOp.lowerBound()); + forOp.getLoc(), forOp.getInductionVar(), forOp.getLowerBound()); if (!std::get<1>(it).getType().isa()) { replacement = rewriter.create( forOp.getLoc(), replacement, std::get<1>(it).getType()); @@ -223,7 +223,7 @@ struct ForOpInductionReplacement : public OpRewritePattern { if (isOne && !std::get<2>(it).use_empty()) { rewriter.setInsertionPoint(forOp); Value replacement = rewriter.create( - forOp.getLoc(), forOp.upperBound(), forOp.lowerBound()); + forOp.getLoc(), forOp.getUpperBound(), forOp.getLowerBound()); if (!std::get<1>(it).getType().isa()) { replacement = rewriter.create( forOp.getLoc(), replacement, std::get<1>(it).getType()); @@ -275,7 +275,7 @@ struct RemoveUnusedArgs : public OpRewritePattern { return failure(); auto newForOp = rewriter.create( - op.getLoc(), op.lowerBound(), op.upperBound(), op.step(), usedOperands); + op.getLoc(), op.getLowerBound(), op.getUpperBound(), op.getStep(), usedOperands); if (!newForOp.getBody()->empty()) rewriter.eraseOp(newForOp.getBody()->getTerminator()); @@ -499,7 +499,7 @@ yop2.results()[idx]); bool isWhile(WhileOp wop) { bool hasCondOp = false; - wop.before().walk([&](Operation *op) { + wop.getBefore().walk([&](Operation *op) { if (isa(op)) hasCondOp = true; }); @@ -547,12 +547,12 @@ struct MoveWhileToFor : public OpRewritePattern { } loopInfo; auto condOp = loop.getConditionOp(); - SmallVector results = {condOp.args()}; - auto cmpIOp = condOp.condition().getDefiningOp(); + SmallVector results = {condOp.getArgs()}; + auto cmpIOp = condOp.getCondition().getDefiningOp(); if (!cmpIOp) { return failure(); } - size_t size = loop.before().front().getOperations().size(); + size_t size = loop.getBefore().front().getOperations().size(); if (size != 2) { return failure(); } @@ -560,24 +560,24 @@ struct MoveWhileToFor : public OpRewritePattern { BlockArgument indVar = cmpIOp.getLhs().dyn_cast(); if (!indVar) return failure(); - if (indVar.getOwner() != &loop.before().front()) + if (indVar.getOwner() != &loop.getBefore().front()) return failure(); SmallVector afterArgs; - for (auto pair : llvm::enumerate(condOp.args())) { + for (auto pair : llvm::enumerate(condOp.getArgs())) { if (pair.value() == indVar) afterArgs.push_back(pair.index()); } - auto endYield = cast(loop.after().back().getTerminator()); + auto endYield = cast(loop.getAfter().back().getTerminator()); auto addIOp = - endYield.results()[indVar.getArgNumber()].getDefiningOp(); + endYield.getResults()[indVar.getArgNumber()].getDefiningOp(); if (!addIOp) return failure(); for (auto afterArg : afterArgs) { - auto arg = loop.after().getArgument(afterArg); + auto arg = loop.getAfter().getArgument(afterArg); if (addIOp.getOperand(0) == arg) { step = addIOp.getOperand(1); break; @@ -679,19 +679,19 @@ struct MoveWhileToFor : public OpRewritePattern { // input of the for goes the input of the scf::while plus the output taken // from the conditionOp. SmallVector forArgs; - forArgs.append(loop.inits().begin(), loop.inits().end()); + forArgs.append(loop.getInits().begin(), loop.getInits().end()); - for (Value arg : condOp.args()) { + for (Value arg : condOp.getArgs()) { Type cst = nullptr; if (auto idx = arg.getDefiningOp()) { cst = idx.getType(); arg = idx.getIn(); } Value res; - if (isTopLevelArgValue(arg, &loop.before())) { + if (isTopLevelArgValue(arg, &loop.getBefore())) { auto blockArg = arg.cast(); auto pos = blockArg.getArgNumber(); - res = loop.inits()[pos]; + res = loop.getInits()[pos]; } else res = arg; if (cst) { @@ -706,31 +706,31 @@ struct MoveWhileToFor : public OpRewritePattern { if (!forloop.getBody()->empty()) rewriter.eraseOp(forloop.getBody()->getTerminator()); - auto oldYield = cast(loop.after().front().getTerminator()); + auto oldYield = cast(loop.getAfter().front().getTerminator()); rewriter.updateRootInPlace(loop, [&] { - for (auto pair : llvm::zip(loop.after().getArguments(), condOp.args())) { + for (auto pair : llvm::zip(loop.getAfter().getArguments(), condOp.getArgs())) { std::get<0>(pair).replaceAllUsesWith(std::get<1>(pair)); } }); - loop.after().front().eraseArguments([](BlockArgument) { return true; }); + loop.getAfter().front().eraseArguments([](BlockArgument) { return true; }); SmallVector yieldOperands; - for (auto oldYieldArg : oldYield.results()) + for (auto oldYieldArg : oldYield.getResults()) yieldOperands.push_back(oldYieldArg); BlockAndValueMapping outmap; - outmap.map(loop.before().getArguments(), yieldOperands); - for (auto arg : condOp.args()) + outmap.map(loop.getBefore().getArguments(), yieldOperands); + for (auto arg : condOp.getArgs()) yieldOperands.push_back(outmap.lookupOrDefault(arg)); rewriter.setInsertionPoint(oldYield); rewriter.replaceOpWithNewOp(oldYield, yieldOperands); - size_t pos = loop.inits().size(); + size_t pos = loop.getInits().size(); rewriter.updateRootInPlace(loop, [&] { - for (auto pair : llvm::zip(loop.before().getArguments(), + for (auto pair : llvm::zip(loop.getBefore().getArguments(), forloop.getRegionIterArgs().drop_back(pos))) { std::get<0>(pair).replaceAllUsesWith(std::get<1>(pair)); } @@ -738,7 +738,7 @@ struct MoveWhileToFor : public OpRewritePattern { forloop.getBody()->getOperations().splice( forloop.getBody()->getOperations().begin(), - loop.after().front().getOperations()); + loop.getAfter().front().getOperations()); SmallVector replacements; replacements.append(forloop.getResults().begin() + pos, @@ -754,20 +754,20 @@ struct MoveWhileDown : public OpRewritePattern { LogicalResult matchAndRewrite(WhileOp op, PatternRewriter &rewriter) const override { - auto term = cast(op.before().front().getTerminator()); - if (auto ifOp = term.condition().getDefiningOp()) { - if (ifOp.getNumResults() != term.args().size() + 1) + auto term = cast(op.getBefore().front().getTerminator()); + if (auto ifOp = term.getCondition().getDefiningOp()) { + if (ifOp.getNumResults() != term.getArgs().size() + 1) return failure(); - if (ifOp.getResult(0) != term.condition()) + if (ifOp.getResult(0) != term.getCondition()) return failure(); for (size_t i = 1; i < ifOp.getNumResults(); ++i) { - if (ifOp.getResult(i) != term.args()[i - 1]) + if (ifOp.getResult(i) != term.getArgs()[i - 1]) return failure(); } auto yield1 = - cast(ifOp.thenRegion().front().getTerminator()); + cast(ifOp.getThenRegion().front().getTerminator()); auto yield2 = - cast(ifOp.elseRegion().front().getTerminator()); + cast(ifOp.getElseRegion().front().getTerminator()); if (auto cop = yield1.getOperand(0).getDefiningOp()) { if (cop.value() == 0) return failure(); @@ -778,31 +778,31 @@ struct MoveWhileDown : public OpRewritePattern { return failure(); } else return failure(); - if (ifOp.elseRegion().front().getOperations().size() != 1) + if (ifOp.getElseRegion().front().getOperations().size() != 1) return failure(); - op.after().front().getOperations().splice( - op.after().front().begin(), - ifOp.thenRegion().front().getOperations()); + op.getAfter().front().getOperations().splice( + op.getAfter().front().begin(), + ifOp.getThenRegion().front().getOperations()); rewriter.updateRootInPlace( - term, [&] { term.conditionMutable().assign(ifOp.condition()); }); + term, [&] { term.getConditionMutable().assign(ifOp.getCondition()); }); SmallVector args; for (size_t i = 1; i < yield2.getNumOperands(); ++i) { args.push_back(yield2.getOperand(i)); } rewriter.updateRootInPlace(term, - [&] { term.argsMutable().assign(args); }); + [&] { term.getArgsMutable().assign(args); }); rewriter.eraseOp(yield2); rewriter.eraseOp(ifOp); - for (size_t i = 0; i < op.after().front().getNumArguments(); ++i) { - op.after().front().getArgument(i).replaceAllUsesWith( + for (size_t i = 0; i < op.getAfter().front().getNumArguments(); ++i) { + op.getAfter().front().getArgument(i).replaceAllUsesWith( yield1.getOperand(i + 1)); } rewriter.eraseOp(yield1); // TODO move operands from begin to after - SmallVector todo(op.before().front().getArguments().begin(), - op.before().front().getArguments().end()); - for (auto &op : op.before().front()) { + SmallVector todo(op.getBefore().front().getArguments().begin(), + op.getBefore().front().getArguments().end()); + for (auto &op : op.getBefore().front()) { for (auto res : op.getResults()) { todo.push_back(res); } @@ -810,24 +810,24 @@ struct MoveWhileDown : public OpRewritePattern { rewriter.updateRootInPlace(op, [&] { for (auto val : todo) { - auto na = op.after().front().addArgument(val.getType()); + auto na = op.getAfter().front().addArgument(val.getType()); val.replaceUsesWithIf(na, [&](OpOperand &u) -> bool { - return op.after().isAncestor(u.getOwner()->getParentRegion()); + return op.getAfter().isAncestor(u.getOwner()->getParentRegion()); }); args.push_back(val); } }); rewriter.updateRootInPlace(term, - [&] { term.argsMutable().assign(args); }); + [&] { term.getArgsMutable().assign(args); }); SmallVector tys; for (auto a : args) tys.push_back(a.getType()); - auto op2 = rewriter.create(op.getLoc(), tys, op.inits()); - op2.before().takeBody(op.before()); - op2.after().takeBody(op.after()); + auto op2 = rewriter.create(op.getLoc(), tys, op.getInits()); + op2.getBefore().takeBody(op.getBefore()); + op2.getAfter().takeBody(op.getAfter()); SmallVector replacements; for (auto a : op2.getResults()) { if (replacements.size() == op.getResults().size()) @@ -882,9 +882,9 @@ struct MoveWhileDown2 : public OpRewritePattern { LogicalResult matchAndRewrite(WhileOp op, PatternRewriter &rewriter) const override { - auto term = cast(op.before().front().getTerminator()); + auto term = cast(op.getBefore().front().getTerminator()); if (auto ifOp = dyn_cast_or_null(term->getPrevNode())) { - if (ifOp.condition() != term.condition()) + if (ifOp.getCondition() != term.getCondition()) return failure(); SmallVector, 2> m; @@ -892,14 +892,14 @@ struct MoveWhileDown2 : public OpRewritePattern { SmallVector prevArgs; SmallVector, 2> afterYieldRewrites; - auto afterYield = cast(op.after().front().back()); + auto afterYield = cast(op.getAfter().front().back()); for (auto pair : - llvm::zip(op.getResults(), term.args(), op.getAfterArguments())) { + llvm::zip(op.getResults(), term.getArgs(), op.getAfterArguments())) { if (std::get<1>(pair).getDefiningOp() == ifOp) { Value thenYielded, elseYielded; - for (auto p : llvm::zip(ifOp.thenYield().results(), ifOp.results(), - ifOp.elseYield().results())) { + for (auto p : llvm::zip(ifOp.thenYield().getResults(), ifOp.getResults(), + ifOp.elseYield().getResults())) { if (std::get<1>(pair) == std::get<1>(p)) { thenYielded = std::get<0>(p); elseYielded = std::get<2>(p); @@ -911,8 +911,8 @@ struct MoveWhileDown2 : public OpRewritePattern { if (!std::get<0>(pair).use_empty()) { if (auto blockArg = elseYielded.dyn_cast()) - if (blockArg.getOwner() == &op.before().front()) { - if (afterYield.results()[blockArg.getArgNumber()] == + if (blockArg.getOwner() == &op.getBefore().front()) { + if (afterYield.getResults()[blockArg.getArgNumber()] == std::get<2>(pair) && op.getResults()[blockArg.getArgNumber()] == std::get<0>(pair)) { @@ -933,18 +933,18 @@ struct MoveWhileDown2 : public OpRewritePattern { } } - SmallVector yieldArgs = afterYield.results(); + SmallVector yieldArgs = afterYield.getResults(); for (auto pair : afterYieldRewrites) { yieldArgs[pair.first] = pair.second; } rewriter.updateRootInPlace( - afterYield, [&] { afterYield.resultsMutable().assign(yieldArgs); }); + afterYield, [&] { afterYield.getResultsMutable().assign(yieldArgs); }); llvm::SetVector sv; findValuesUsedBelow(ifOp, sv); - Block *afterB = &op.after().front(); + Block *afterB = &op.getAfter().front(); for (auto v : sv) { condArgs.push_back(v); @@ -956,7 +956,7 @@ struct MoveWhileDown2 : public OpRewritePattern { } rewriter.setInsertionPoint(term); - rewriter.replaceOpWithNewOp(term, term.condition(), + rewriter.replaceOpWithNewOp(term, term.getCondition(), condArgs); for (int i = m.size() - 1; i >= 0; i--) { @@ -977,9 +977,9 @@ struct MoveWhileDown2 : public OpRewritePattern { } rewriter.setInsertionPoint(op); - auto nop = rewriter.create(op.getLoc(), resultTypes, op.inits()); - nop.before().takeBody(op.before()); - nop.after().takeBody(op.after()); + auto nop = rewriter.create(op.getLoc(), resultTypes, op.getInits()); + nop.getBefore().takeBody(op.getBefore()); + nop.getAfter().takeBody(op.getAfter()); rewriter.updateRootInPlace(op, [&] { for (auto pair : llvm::enumerate(prevArgs)) { @@ -1002,24 +1002,24 @@ struct MoveWhileInvariantIfResult : public OpRewritePattern { SmallVector origAfterArgs(op.getAfterArguments().begin(), op.getAfterArguments().end()); bool changed = false; - auto term = cast(op.before().front().getTerminator()); + scf::ConditionOp term = cast(op.getBefore().front().getTerminator()); assert(origAfterArgs.size() == op.getResults().size()); - assert(origAfterArgs.size() == term.args().size()); + assert(origAfterArgs.size() == term.getArgs().size()); - for (auto pair : llvm::zip(op.getResults(), term.args(), origAfterArgs)) { + for (auto pair : llvm::zip(op.getResults(), term.getArgs(), origAfterArgs)) { if (!std::get<0>(pair).use_empty()) { if (auto ifOp = std::get<1>(pair).getDefiningOp()) { - if (ifOp.condition() == term.condition()) { + if (ifOp.getCondition() == term.getCondition()) { ssize_t idx = -1; - for (auto tup : llvm::enumerate(ifOp.results())) { + for (auto tup : llvm::enumerate(ifOp.getResults())) { if (tup.value() == std::get<1>(pair)) { idx = tup.index(); break; } } assert(idx != -1); - Value returnWith = ifOp.elseYield().results()[idx]; - if (!op.before().isAncestor(returnWith.getParentRegion())) { + Value returnWith = ifOp.elseYield().getResults()[idx]; + if (!op.getBefore().isAncestor(returnWith.getParentRegion())) { rewriter.updateRootInPlace(op, [&] { std::get<0>(pair).replaceAllUsesWith(returnWith); }); @@ -1042,12 +1042,12 @@ struct WhileLogicalNegation : public OpRewritePattern { SmallVector origAfterArgs(op.getAfterArguments().begin(), op.getAfterArguments().end()); bool changed = false; - auto term = cast(op.before().front().getTerminator()); + scf::ConditionOp term = cast(op.getBefore().front().getTerminator()); assert(origAfterArgs.size() == op.getResults().size()); - assert(origAfterArgs.size() == term.args().size()); + assert(origAfterArgs.size() == term.getArgs().size()); - if (auto condCmp = term.condition().getDefiningOp()) { - for (auto pair : llvm::zip(op.getResults(), term.args(), origAfterArgs)) { + if (auto condCmp = term.getCondition().getDefiningOp()) { + for (auto pair : llvm::zip(op.getResults(), term.getArgs(), origAfterArgs)) { if (!std::get<0>(pair).use_empty()) { if (auto termCmp = std::get<1>(pair).getDefiningOp()) { if (termCmp.getLhs() == condCmp.getLhs() && @@ -1082,41 +1082,41 @@ struct WhileCmpOffset : public OpRewritePattern { PatternRewriter &rewriter) const override { SmallVector origAfterArgs(op.getAfterArguments().begin(), op.getAfterArguments().end()); - auto term = cast(op.before().front().getTerminator()); + scf::ConditionOp term = cast(op.getBefore().front().getTerminator()); assert(origAfterArgs.size() == op.getResults().size()); - assert(origAfterArgs.size() == term.args().size()); + assert(origAfterArgs.size() == term.getArgs().size()); - if (auto condCmp = term.condition().getDefiningOp()) { + if (auto condCmp = term.getCondition().getDefiningOp()) { if (auto addI = condCmp.getLhs().getDefiningOp()) { if (addI.getOperand(1).getDefiningOp() && - !op.before().isAncestor( + !op.getBefore().isAncestor( addI.getOperand(1).getDefiningOp()->getParentRegion())) if (auto blockArg = addI.getOperand(0).dyn_cast()) { - if (blockArg.getOwner() == &op.before().front()) { + if (blockArg.getOwner() == &op.getBefore().front()) { auto rng = llvm::make_early_inc_range(blockArg.getUses()); { rewriter.setInsertionPoint(op); - SmallVector oldInits = op.inits(); + SmallVector oldInits = op.getInits(); oldInits[blockArg.getArgNumber()] = rewriter.create( addI.getLoc(), oldInits[blockArg.getArgNumber()], addI.getOperand(1)); - op.initsMutable().assign(oldInits); + op.getInitsMutable().assign(oldInits); rewriter.updateRootInPlace( addI, [&] { addI.replaceAllUsesWith(blockArg); }); } - YieldOp afterYield = cast(op.after().front().back()); + YieldOp afterYield = cast(op.getAfter().front().back()); rewriter.setInsertionPoint(afterYield); - SmallVector oldYields = afterYield.results(); + SmallVector oldYields = afterYield.getResults(); oldYields[blockArg.getArgNumber()] = rewriter.create( addI.getLoc(), oldYields[blockArg.getArgNumber()], addI.getOperand(1)); rewriter.updateRootInPlace(afterYield, [&] { - afterYield.resultsMutable().assign(oldYields); + afterYield.getResultsMutable().assign(oldYields); }); - rewriter.setInsertionPointToStart(&op.before().front()); + rewriter.setInsertionPointToStart(&op.getBefore().front()); auto sub = rewriter.create(addI.getLoc(), blockArg, addI.getOperand(1)); for (OpOperand &use : rng) { @@ -1139,7 +1139,7 @@ struct MoveWhileDown3 : public OpRewritePattern { LogicalResult matchAndRewrite(WhileOp op, PatternRewriter &rewriter) const override { - auto term = cast(op.before().front().getTerminator()); + scf::ConditionOp term = cast(op.getBefore().front().getTerminator()); SmallVector toErase; SmallVector newOps; SmallVector condOps; @@ -1147,8 +1147,8 @@ struct MoveWhileDown3 : public OpRewritePattern { op.getAfterArguments().end()); SmallVector returns; assert(origAfterArgs.size() == op.getResults().size()); - assert(origAfterArgs.size() == term.args().size()); - for (auto pair : llvm::zip(op.getResults(), term.args(), origAfterArgs)) { + assert(origAfterArgs.size() == term.getArgs().size()); + for (auto pair : llvm::zip(op.getResults(), term.getArgs(), origAfterArgs)) { if (std::get<0>(pair).use_empty()) { if (std::get<2>(pair).use_empty()) { toErase.push_back(std::get<2>(pair).getArgNumber()); @@ -1162,9 +1162,9 @@ struct MoveWhileDown3 : public OpRewritePattern { Operation *cloned = std::get<1>(pair).getDefiningOp(); if (!std::get<1>(pair).hasOneUse()) { cloned = std::get<1>(pair).getDefiningOp()->clone(); - op.after().front().push_front(cloned); + op.getAfter().front().push_front(cloned); } else { - cloned->moveBefore(&op.after().front().front()); + cloned->moveBefore(&op.getAfter().front().front()); } rewriter.updateRootInPlace(std::get<1>(pair).getDefiningOp(), [&] { std::get<2>(pair).replaceAllUsesWith(cloned->getResult(0)); @@ -1174,7 +1174,7 @@ struct MoveWhileDown3 : public OpRewritePattern { llvm::make_early_inc_range(cloned->getOpOperands())) { { newOps.push_back(o.get()); - o.set(op.after().front().addArgument(o.get().getType())); + o.set(op.getAfter().front().addArgument(o.get().getType())); } } continue; @@ -1190,19 +1190,19 @@ struct MoveWhileDown3 : public OpRewritePattern { condOps.append(newOps.begin(), newOps.end()); rewriter.updateRootInPlace( - term, [&] { op.after().front().eraseArguments(toErase); }); + term, [&] { op.getAfter().front().eraseArguments(toErase); }); rewriter.setInsertionPoint(term); - rewriter.replaceOpWithNewOp(term, term.condition(), condOps); + rewriter.replaceOpWithNewOp(term, term.getCondition(), condOps); rewriter.setInsertionPoint(op); SmallVector resultTypes; for (auto v : condOps) { resultTypes.push_back(v.getType()); } - auto nop = rewriter.create(op.getLoc(), resultTypes, op.inits()); + auto nop = rewriter.create(op.getLoc(), resultTypes, op.getInits()); - nop.before().takeBody(op.before()); - nop.after().takeBody(op.after()); + nop.getBefore().takeBody(op.getBefore()); + nop.getAfter().takeBody(op.getAfter()); rewriter.updateRootInPlace(op, [&] { for (auto pair : llvm::enumerate(returns)) { @@ -1210,7 +1210,7 @@ struct MoveWhileDown3 : public OpRewritePattern { } }); - assert(resultTypes.size() == nop.after().front().getNumArguments()); + assert(resultTypes.size() == nop.getAfter().front().getNumArguments()); assert(resultTypes.size() == condOps.size()); rewriter.eraseOp(op); @@ -1269,7 +1269,7 @@ struct WhileLICM : public OpRewritePattern { auto definingOp = value.getDefiningOp(); bool definedOutside = (definingOp && !!willBeMovedSet.count(definingOp)) || - !op.before().isAncestor(value.getParentRegion()); + !op.getBefore().isAncestor(value.getParentRegion()); return definedOutside; }; @@ -1277,7 +1277,7 @@ struct WhileLICM : public OpRewritePattern { // hoist operations from there. These regions might have semantics unknown // to this rewriting. If the nested regions are loops, they will have been // processed. - for (auto &block : op.before()) { + for (auto &block : op.getBefore()) { for (auto &op : block.without_terminator()) { bool legal = canBeHoisted(&op, isDefinedOutsideOfBody); if (legal) { @@ -1299,7 +1299,7 @@ struct RemoveUnusedCondVar : public OpRewritePattern { LogicalResult matchAndRewrite(WhileOp op, PatternRewriter &rewriter) const override { - auto term = cast(op.before().front().getTerminator()); + auto term = cast(op.getBefore().front().getTerminator()); SmallVector conds; SmallVector eraseArgs; SmallVector keepArgs; @@ -1308,7 +1308,7 @@ struct RemoveUnusedCondVar : public OpRewritePattern { std::map valueOffsets; std::map resultOffsets; SmallVector resultArgs; - for (auto pair : llvm::zip(term.args(), op.after().front().getArguments(), + for (auto pair : llvm::zip(term.getArgs(), op.getAfter().front().getArguments(), op.getResults())) { auto arg = std::get<0>(pair); auto afarg = std::get<1>(pair); @@ -1331,24 +1331,24 @@ struct RemoveUnusedCondVar : public OpRewritePattern { } i++; } - assert(i == op.after().front().getArguments().size()); + assert(i == op.getAfter().front().getArguments().size()); if (eraseArgs.size() != 0) { rewriter.setInsertionPoint(term); - rewriter.replaceOpWithNewOp(term, term.condition(), + rewriter.replaceOpWithNewOp(term, term.getCondition(), conds); rewriter.setInsertionPoint(op); - auto op2 = rewriter.create(op.getLoc(), tys, op.inits()); + auto op2 = rewriter.create(op.getLoc(), tys, op.getInits()); - op2.before().takeBody(op.before()); - op2.after().takeBody(op.after()); + op2.getBefore().takeBody(op.getBefore()); + op2.getAfter().takeBody(op.getAfter()); for (auto pair : resultOffsets) { op.getResult(pair.first).replaceAllUsesWith(op2.getResult(pair.second)); } rewriter.eraseOp(op); - op2.after().front().eraseArguments(eraseArgs); + op2.getAfter().front().eraseArguments(eraseArgs); return success(); } return failure(); @@ -1360,19 +1360,19 @@ struct MoveSideEffectFreeWhile : public OpRewritePattern { LogicalResult matchAndRewrite(WhileOp op, PatternRewriter &rewriter) const override { - auto term = cast(op.before().front().getTerminator()); - SmallVector conds(term.args().begin(), term.args().end()); + scf::ConditionOp term = cast(op.getBefore().front().getTerminator()); + SmallVector conds(term.getArgs().begin(), term.getArgs().end()); bool changed = false; unsigned i = 0; - for (auto arg : term.args()) { + for (auto arg : term.getArgs()) { if (auto IC = arg.getDefiningOp()) { if (arg.hasOneUse() && op.getResult(i).use_empty()) { auto rep = - op.after().front().addArgument(IC->getOperand(0).getType()); - IC->moveBefore(&op.after().front(), op.after().front().begin()); + op.getAfter().front().addArgument(IC->getOperand(0).getType()); + IC->moveBefore(&op.getAfter().front(), op.getAfter().front().begin()); conds.push_back(IC.getIn()); IC.getInMutable().assign(rep); - op.after().front().getArgument(i).replaceAllUsesWith( + op.getAfter().front().getArgument(i).replaceAllUsesWith( IC->getResult(0)); changed = true; } @@ -1384,9 +1384,9 @@ struct MoveSideEffectFreeWhile : public OpRewritePattern { for (auto arg : conds) { tys.push_back(arg.getType()); } - auto op2 = rewriter.create(op.getLoc(), tys, op.inits()); - op2.before().takeBody(op.before()); - op2.after().takeBody(op.after()); + auto op2 = rewriter.create(op.getLoc(), tys, op.getInits()); + op2.getBefore().takeBody(op.getBefore()); + op2.getAfter().takeBody(op.getAfter()); unsigned j = 0; for (auto a : op.getResults()) { a.replaceAllUsesWith(op2.getResult(j)); @@ -1394,7 +1394,7 @@ struct MoveSideEffectFreeWhile : public OpRewritePattern { } rewriter.eraseOp(op); rewriter.setInsertionPoint(term); - rewriter.replaceOpWithNewOp(term, term.condition(), + rewriter.replaceOpWithNewOp(term, term.getCondition(), conds); return success(); } diff --git a/lib/polygeist/Passes/LoopRestructure.cpp b/lib/polygeist/Passes/LoopRestructure.cpp index 4545dd4..b03a1cb 100644 --- a/lib/polygeist/Passes/LoopRestructure.cpp +++ b/lib/polygeist/Passes/LoopRestructure.cpp @@ -295,8 +295,8 @@ bool LoopRestructure::removeIfFromRegion(DominanceInfo &domInfo, Region ®ion, /*hasElse*/ true); Succs[j] = new Block(); if (j == 0) { - ifOp.elseRegion().getBlocks().splice( - ifOp.elseRegion().getBlocks().end(), region.getBlocks(), + ifOp.getElseRegion().getBlocks().splice( + ifOp.getElseRegion().getBlocks().end(), region.getBlocks(), Succs[1 - j]); SmallVector idx; for (size_t i = 0; i < Succs[1 - j]->getNumArguments(); ++i) { @@ -305,18 +305,18 @@ bool LoopRestructure::removeIfFromRegion(DominanceInfo &domInfo, Region ®ion, idx.push_back(i); } Succs[1 - j]->eraseArguments(idx); - assert(!ifOp.elseRegion().getBlocks().empty()); + assert(!ifOp.getElseRegion().getBlocks().empty()); assert(condTys.size() == condBr.getTrueOperands().size()); - OpBuilder tbuilder(&ifOp.thenRegion().front(), - ifOp.thenRegion().front().begin()); + OpBuilder tbuilder(&ifOp.getThenRegion().front(), + ifOp.getThenRegion().front().begin()); tbuilder.create(tbuilder.getUnknownLoc(), emptyTys, condBr.getTrueOperands()); } else { - if (!ifOp.thenRegion().getBlocks().empty()) { - ifOp.thenRegion().front().erase(); + if (!ifOp.getThenRegion().getBlocks().empty()) { + ifOp.getThenRegion().front().erase(); } - ifOp.thenRegion().getBlocks().splice( - ifOp.thenRegion().getBlocks().end(), region.getBlocks(), + ifOp.getThenRegion().getBlocks().splice( + ifOp.getThenRegion().getBlocks().end(), region.getBlocks(), Succs[1 - j]); SmallVector idx; for (size_t i = 0; i < Succs[1 - j]->getNumArguments(); ++i) { @@ -325,9 +325,9 @@ bool LoopRestructure::removeIfFromRegion(DominanceInfo &domInfo, Region ®ion, idx.push_back(i); } Succs[1 - j]->eraseArguments(idx); - assert(!ifOp.elseRegion().getBlocks().empty()); - OpBuilder tbuilder(&ifOp.elseRegion().front(), - ifOp.elseRegion().front().begin()); + assert(!ifOp.getElseRegion().getBlocks().empty()); + OpBuilder tbuilder(&ifOp.getElseRegion().front(), + ifOp.getElseRegion().front().begin()); assert(condTys.size() == condBr.getFalseOperands().size()); tbuilder.create(tbuilder.getUnknownLoc(), emptyTys, condBr.getFalseOperands()); @@ -449,12 +449,12 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region ®ion) { Preds.push_back(block); } - loop.before().getBlocks().splice(loop.before().getBlocks().begin(), + loop.getBefore().getBlocks().splice(loop.getBefore().getBlocks().begin(), region.getBlocks(), header); for (auto *w : L->getBlocks()) { Block *b = &**w; if (b != header) { - loop.before().getBlocks().splice(loop.before().getBlocks().end(), + loop.getBefore().getBlocks().splice(loop.getBefore().getBlocks().end(), region.getBlocks(), b); } } @@ -462,7 +462,7 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region ®ion) { Block *pseudoExit = new Block(); auto i1Ty = builder.getI1Type(); { - loop.before().push_back(pseudoExit); + loop.getBefore().push_back(pseudoExit); SmallVector tys = {i1Ty}; for (auto t : combinedTypes) tys.push_back(t); @@ -592,7 +592,7 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region ®ion) { Block *after = new Block(); after->addArguments(combinedTypes); - loop.after().push_back(after); + loop.getAfter().push_back(after); OpBuilder builder2(after, after->begin()); SmallVector yieldargs; for (auto a : after->getArguments()) { @@ -619,42 +619,42 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region ®ion) { } builder2.create(builder.getUnknownLoc(), yieldargs); - domInfo.invalidate(&loop.before()); - runOnRegion(domInfo, loop.before()); - if (!removeIfFromRegion(domInfo, loop.before(), pseudoExit)) { + domInfo.invalidate(&loop.getBefore()); + runOnRegion(domInfo, loop.getBefore()); + if (!removeIfFromRegion(domInfo, loop.getBefore(), pseudoExit)) { attemptToFoldIntoPredecessor(pseudoExit); } attemptToFoldIntoPredecessor(wrapper); attemptToFoldIntoPredecessor(target); - if (loop.before().getBlocks().size() != 1) { + if (loop.getBefore().getBlocks().size() != 1) { Block *blk = new Block(); OpBuilder B(loop.getContext()); B.setInsertionPointToEnd(blk); auto cop = - cast(loop.before().getBlocks().back().back()); + cast(loop.getBefore().getBlocks().back().back()); auto er = B.create(loop.getLoc(), cop.getOperandTypes()); - er.region().getBlocks().splice(er.region().getBlocks().begin(), - loop.before().getBlocks()); - loop.before().push_back(blk); + er.getRegion().getBlocks().splice(er.getRegion().getBlocks().begin(), + loop.getBefore().getBlocks()); + loop.getBefore().push_back(blk); SmallVector yields; for (auto a : er.getResults()) yields.push_back(a); yields.erase(yields.begin()); B.create(cop.getLoc(), er.getResult(0), yields); B.setInsertionPoint(&*cop); - for (auto arg : er.region().front().getArguments()) { + for (auto arg : er.getRegion().front().getArguments()) { auto na = blk->addArgument(arg.getType()); arg.replaceAllUsesWith(na); } - er.region().front().eraseArguments([](BlockArgument) { return true; }); + er.getRegion().front().eraseArguments([](BlockArgument) { return true; }); B.create(cop.getLoc(), cop.getOperands()); cop.erase(); } - assert(loop.before().getBlocks().size() == 1); - runOnRegion(domInfo, loop.after()); - assert(loop.after().getBlocks().size() == 1); + assert(loop.getBefore().getBlocks().size() == 1); + runOnRegion(domInfo, loop.getAfter()); + assert(loop.getAfter().getBlocks().size() == 1); } } diff --git a/lib/polygeist/Passes/Mem2Reg.cpp b/lib/polygeist/Passes/Mem2Reg.cpp index fce52d0..339ab93 100644 --- a/lib/polygeist/Passes/Mem2Reg.cpp +++ b/lib/polygeist/Passes/Mem2Reg.cpp @@ -453,7 +453,7 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector idx, for (auto a : ops) { if (StoringOperations.count(a)) { if (auto exOp = dyn_cast(a)) { - valueAtStartOfBlock[&*exOp.region().begin()] = lastVal; + valueAtStartOfBlock[&*exOp.getRegion().begin()] = lastVal; Value thenVal; // = handleBlock(exOp.region().front(), lastVal); lastVal = nullptr; seenSubStore = true; @@ -477,7 +477,7 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector idx, continue; } - Block &then = exOp.region().back(); + Block &then = exOp.getRegion().back(); OpBuilder B(exOp.getContext()); auto yieldOp = cast(then.back()); B.setInsertionPoint(yieldOp); @@ -502,13 +502,13 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector idx, auto nextIf = B.create(exOp.getLoc(), tys); - SmallVector thenVals = yieldOp.results(); + SmallVector thenVals = yieldOp.getResults(); thenVals.push_back(thenVal); yieldOp->setOperands(thenVals); - nextIf.region().getBlocks().clear(); - nextIf.region().getBlocks().splice( - nextIf.region().getBlocks().begin(), - exOp.region().getBlocks()); + nextIf.getRegion().getBlocks().clear(); + nextIf.getRegion().getBlocks().splice( + nextIf.getRegion().getBlocks().begin(), + exOp.getRegion().getBlocks()); SmallVector resvals = (nextIf.getResults()); lastVal = resvals.back(); @@ -559,15 +559,15 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector idx, lastVal = newLoad->getResult(0); } - valueAtStartOfBlock[&*ifOp.thenRegion().begin()] = lastVal; + valueAtStartOfBlock[&*ifOp.getThenRegion().begin()] = lastVal; mlir::Value thenVal = - handleBlock(*ifOp.thenRegion().begin(), lastVal); + handleBlock(*ifOp.getThenRegion().begin(), lastVal); - if (lastVal && ifOp.elseRegion().getBlocks().size()) - valueAtStartOfBlock[&*ifOp.elseRegion().begin()] = lastVal; + if (lastVal && ifOp.getElseRegion().getBlocks().size()) + valueAtStartOfBlock[&*ifOp.getElseRegion().begin()] = lastVal; mlir::Value elseVal = - (ifOp.elseRegion().getBlocks().size()) - ? handleBlock(*ifOp.elseRegion().begin(), lastVal) + (ifOp.getElseRegion().getBlocks().size()) + ? handleBlock(*ifOp.getElseRegion().begin(), lastVal) : lastVal; if (thenVal == elseVal && thenVal != nullptr) { @@ -582,41 +582,37 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector idx, ifOp.getResultTypes().end()); tys.push_back(thenVal.getType()); auto nextIf = B.create( - ifOp.getLoc(), tys, ifOp.condition(), /*hasElse*/ true); + ifOp.getLoc(), tys, ifOp.getCondition(), /*hasElse*/ true); - Block &then = ifOp.thenRegion().back(); + Block &then = ifOp.getThenRegion().back(); SmallVector thenVals = - cast(then.back()).results(); + cast(then.back()).getResults(); thenVals.push_back(thenVal); - nextIf.thenRegion().getBlocks().clear(); - nextIf.thenRegion().getBlocks().splice( - nextIf.thenRegion().getBlocks().begin(), - ifOp.thenRegion().getBlocks()); + nextIf.getThenRegion().getBlocks().clear(); + nextIf.getThenRegion().takeBody(ifOp.getThenRegion()); cast( - nextIf.thenRegion().back().getTerminator()) + nextIf.getThenRegion().back().getTerminator()) ->setOperands(thenVals); - if (ifOp.elseRegion().getBlocks().size()) { - nextIf.elseRegion().getBlocks().clear(); + if (ifOp.getElseRegion().getBlocks().size()) { + nextIf.getElseRegion().getBlocks().clear(); SmallVector elseVals = - cast(ifOp.elseRegion().back().back()) - .results(); + cast(ifOp.getElseRegion().back().back()) + .getResults(); elseVals.push_back(elseVal); - nextIf.elseRegion().getBlocks().splice( - nextIf.elseRegion().getBlocks().begin(), - ifOp.elseRegion().getBlocks()); + nextIf.getElseRegion().takeBody(ifOp.getElseRegion()); cast( - nextIf.elseRegion().back().getTerminator()) + nextIf.getElseRegion().back().getTerminator()) ->setOperands(elseVals); } else { - B.setInsertionPoint(&nextIf.elseRegion().back(), - nextIf.elseRegion().back().begin()); + B.setInsertionPoint(&nextIf.getElseRegion().back(), + nextIf.getElseRegion().back().begin()); SmallVector elseVals; elseVals.push_back(elseVal); B.create(ifOp.getLoc(), elseVals); } - SmallVector resvals = (nextIf.results()); + SmallVector resvals = nextIf.getResults(); lastVal = resvals.back(); resvals.pop_back(); ifOp.replaceAllUsesWith(resvals); diff --git a/lib/polygeist/Passes/ParallelLoopDistribute.cpp b/lib/polygeist/Passes/ParallelLoopDistribute.cpp index bb0708c..7653298 100644 --- a/lib/polygeist/Passes/ParallelLoopDistribute.cpp +++ b/lib/polygeist/Passes/ParallelLoopDistribute.cpp @@ -101,7 +101,7 @@ struct ReplaceIfWithFors : public OpRewritePattern { LogicalResult matchAndRewrite(scf::IfOp op, PatternRewriter &rewriter) const override { - assert(op.condition().getType().isInteger(1)); + assert(op.getCondition().getType().isInteger(1)); if (!hasNestedBarrier(op)) { LLVM_DEBUG(DBGS() << "[if-to-for] no nested barrier\n"); @@ -119,7 +119,7 @@ struct ReplaceIfWithFors : public OpRewritePattern { auto cond = rewriter.create( loc, rewriter.getIndexType(), - rewriter.create(loc, op.condition(), + rewriter.create(loc, op.getCondition(), mlir::IntegerType::get(one.getContext(), 64))); auto thenLoop = rewriter.create(loc, zero, cond, one, forArgs); if (forArgs.size() == 0) @@ -128,7 +128,7 @@ struct ReplaceIfWithFors : public OpRewritePattern { SmallVector vals; - if (!op.elseRegion().empty()) { + if (!op.getElseRegion().empty()) { auto negCondition = rewriter.create(loc, one, cond); scf::ForOp elseLoop = rewriter.create(loc, zero, negCondition, one, forArgs); if (forArgs.size() == 0) @@ -136,7 +136,7 @@ struct ReplaceIfWithFors : public OpRewritePattern { rewriter.mergeBlocks(op.getBody(1), elseLoop.getBody(0)); for (auto tup : llvm::zip(thenLoop.getResults(), elseLoop.getResults())) { - vals.push_back(rewriter.create(op.getLoc(), op.condition(), std::get<0>(tup), std::get<1>(tup))); + vals.push_back(rewriter.create(op.getLoc(), op.getCondition(), std::get<0>(tup), std::get<1>(tup))); } } @@ -153,7 +153,7 @@ static bool isDefinedAbove(Value value, Operation *user) { /// Returns `true` if the loop has a form expected by interchange patterns. static bool isNormalized(scf::ForOp op) { - return isDefinedAbove(op.lowerBound(), op) && isDefinedAbove(op.step(), op); + return isDefinedAbove(op.getLowerBound(), op) && isDefinedAbove(op.getStep(), op); } /// Transforms a loop to the normal form expected by interchange patterns, i.e. @@ -179,15 +179,15 @@ struct NormalizeLoop : public OpRewritePattern { rewriter.restoreInsertionPoint(point); Value difference = - rewriter.create(op.getLoc(), op.upperBound(), op.lowerBound()); + rewriter.create(op.getLoc(), op.getUpperBound(), op.getLowerBound()); Value tripCount = - rewriter.create(op.getLoc(), difference, op.step()); + rewriter.create(op.getLoc(), difference, op.getStep()); auto newForOp = rewriter.create(op.getLoc(), zero, tripCount, one); rewriter.setInsertionPointToStart(newForOp.getBody()); Value scaled = rewriter.create( - op.getLoc(), newForOp.getInductionVar(), op.step()); - Value iv = rewriter.create(op.getLoc(), op.lowerBound(), scaled); + op.getLoc(), newForOp.getInductionVar(), op.getStep()); + Value iv = rewriter.create(op.getLoc(), op.getLowerBound(), scaled); rewriter.mergeBlockBefore(op.getBody(), &newForOp.getBody()->back(), {iv}); rewriter.eraseOp(&newForOp.getBody()->back()); rewriter.eraseOp(op); @@ -205,8 +205,8 @@ static bool isNormalized(scf::ParallelOp op) { APInt value; return matchPattern(v, m_ConstantInt(&value)) && value.isOneValue(); }; - return llvm::all_of(op.lowerBound(), isZero) && - llvm::all_of(op.step(), isOne); + return llvm::all_of(op.getLowerBound(), isZero) && + llvm::all_of(op.getStep(), isOne); } /// Transforms a loop to the normal form expected by interchange patterns, i.e. @@ -242,9 +242,9 @@ struct NormalizeParallel : public OpRewritePattern { rewriter.setInsertionPointToStart(newOp.getBody()); for (unsigned i = 0, e = iterationCounts.size(); i < e; ++i) { Value scaled = rewriter.create( - op.getLoc(), newOp.getInductionVars()[i], op.step()[i]); + op.getLoc(), newOp.getInductionVars()[i], op.getStep()[i]); Value shifted = - rewriter.create(op.getLoc(), op.lowerBound()[i], scaled); + rewriter.create(op.getLoc(), op.getLowerBound()[i], scaled); inductionVars.push_back(shifted); } @@ -333,7 +333,7 @@ struct WrapForWithBarrier : public OpRewritePattern { return wrapWithBarriers(op, rewriter, [&](Operation *prevOp) { if (auto loadOp = dyn_cast_or_null(prevOp)) { - if (loadOp.result() == op.upperBound() && + if (loadOp.result() == op.getUpperBound() && loadOp.indices() == cast(op->getParentOp()).getInductionVars()) { prevOp = prevOp->getPrevNode(); @@ -353,7 +353,7 @@ struct WrapWhileWithBarrier : public OpRewritePattern { LogicalResult matchAndRewrite(scf::WhileOp op, PatternRewriter &rewriter) const override { if (op.getNumOperands() != 0 || - !llvm::hasSingleElement(op.after().front())) { + !llvm::hasSingleElement(op.getAfter().front())) { LLVM_DEBUG(DBGS() << "[wrap-while] ignoring non-rotated loop\n";); return failure(); } @@ -372,7 +372,7 @@ static void moveBodies(PatternRewriter &rewriter, scf::ParallelOp op, OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(newForLoop.getBody()); auto newParallel = rewriter.create( - op.getLoc(), op.lowerBound(), op.upperBound(), op.step()); + op.getLoc(), op.getLowerBound(), op.getUpperBound(), op.getStep()); // Merge in two stages so we can properly replace uses of two induction // varibales defined in different blocks. @@ -419,8 +419,8 @@ struct InterchangeForPFor : public OpRewritePattern { } auto newForLoop = - rewriter.create(forLoop.getLoc(), forLoop.lowerBound(), - forLoop.upperBound(), forLoop.step()); + rewriter.create(forLoop.getLoc(), forLoop.getLowerBound(), + forLoop.getUpperBound(), forLoop.getStep()); moveBodies(rewriter, op, forLoop, newForLoop); return success(); } @@ -446,7 +446,7 @@ struct InterchangeForPForLoad : public OpRewritePattern { } auto loadOp = dyn_cast(op.getBody()->front()); auto forOp = dyn_cast(op.getBody()->front().getNextNode()); - if (!loadOp || !forOp || loadOp.result() != forOp.upperBound() || + if (!loadOp || !forOp || loadOp.result() != forOp.getUpperBound() || loadOp.indices() != op.getInductionVars()) { LLVM_DEBUG(DBGS() << "[interchange-load] expected pfor(load, for)"); return failure(); @@ -470,7 +470,7 @@ struct InterchangeForPForLoad : public OpRewritePattern { loadOp.getLoc(), loadOp.getMemRef(), SmallVector(loadOp.getMemRefType().getRank(), zero)); auto newForLoop = rewriter.create( - forOp.getLoc(), forOp.lowerBound(), tripCount, forOp.step()); + forOp.getLoc(), forOp.getLowerBound(), tripCount, forOp.getStep()); moveBodies(rewriter, op, forOp, newForLoop); return success(); @@ -535,9 +535,9 @@ findInsertionPointAfterLoopOperands(scf::ParallelOp op) { // Find the earliest insertion point where loop bounds are fully defined. PostDominanceInfo postDominanceInfo(op->getParentOfType()); SmallVector operands; - llvm::append_range(operands, op.lowerBound()); - llvm::append_range(operands, op.upperBound()); - llvm::append_range(operands, op.step()); + llvm::append_range(operands, op.getLowerBound()); + llvm::append_range(operands, op.getUpperBound()); + llvm::append_range(operands, op.getStep()); return findNearestPostDominatingInsertionPoint(operands, postDominanceInfo); } @@ -562,7 +562,7 @@ struct InterchangeWhilePFor : public OpRewritePattern { LLVM_DEBUG(DBGS() << "[interchange-while] loop-carried values\n"); return failure(); } - if (!llvm::hasSingleElement(whileOp.after().front()) || !isNormalized(op)) { + if (!llvm::hasSingleElement(whileOp.getAfter().front()) || !isNormalized(op)) { LLVM_DEBUG(DBGS() << "[interchange-while] non-normalized loop\n"); return failure(); } @@ -573,37 +573,37 @@ struct InterchangeWhilePFor : public OpRewritePattern { auto newWhileOp = rewriter.create(whileOp.getLoc(), TypeRange(), ValueRange()); - rewriter.createBlock(&newWhileOp.after()); - rewriter.clone(whileOp.after().front().back()); + rewriter.createBlock(&newWhileOp.getAfter()); + rewriter.clone(whileOp.getAfter().front().back()); - rewriter.createBlock(&newWhileOp.before()); + rewriter.createBlock(&newWhileOp.getBefore()); auto newParallelOp = rewriter.create( - op.getLoc(), op.lowerBound(), op.upperBound(), op.step()); + op.getLoc(), op.getLowerBound(), op.getUpperBound(), op.getStep()); - auto conditionOp = cast(whileOp.before().front().back()); + auto conditionOp = cast(whileOp.getBefore().front().back()); rewriter.mergeBlockBefore(op.getBody(), &newParallelOp.getBody()->back(), newParallelOp.getInductionVars()); rewriter.eraseOp(newParallelOp.getBody()->back().getPrevNode()); - rewriter.mergeBlockBefore(&whileOp.before().front(), + rewriter.mergeBlockBefore(&whileOp.getBefore().front(), &newParallelOp.getBody()->back()); - Operation *conditionDefiningOp = conditionOp.condition().getDefiningOp(); + Operation *conditionDefiningOp = conditionOp.getCondition().getDefiningOp(); if (conditionDefiningOp && - !isDefinedAbove(conditionOp.condition(), conditionOp)) { + !isDefinedAbove(conditionOp.getCondition(), conditionOp)) { std::pair insertionPoint = findInsertionPointAfterLoopOperands(op); rewriter.setInsertionPoint(insertionPoint.first, insertionPoint.second); SmallVector iterationCounts = emitIterationCounts(rewriter, op); Value allocated = allocateTemporaryBuffer( - rewriter, conditionOp.condition(), iterationCounts); + rewriter, conditionOp.getCondition(), iterationCounts); Value zero = rewriter.create(op.getLoc(), 0); rewriter.setInsertionPointAfter(conditionDefiningOp); rewriter.create(conditionDefiningOp->getLoc(), - conditionOp.condition(), allocated, + conditionOp.getCondition(), allocated, newParallelOp.getInductionVars()); - rewriter.setInsertionPointToEnd(&newWhileOp.before().front()); + rewriter.setInsertionPointToEnd(&newWhileOp.getBefore().front()); SmallVector zeros(iterationCounts.size(), zero); Value reloaded = rewriter.create( conditionDefiningOp->getLoc(), allocated, zeros); @@ -646,7 +646,7 @@ struct RotateWhile : public OpRewritePattern { LogicalResult matchAndRewrite(scf::WhileOp op, PatternRewriter &rewriter) const override { - if (llvm::hasSingleElement(op.after().front())) { + if (llvm::hasSingleElement(op.getAfter().front())) { LLVM_DEBUG(DBGS() << "[rotate-while] the after region is empty"); return failure(); } @@ -659,15 +659,15 @@ struct RotateWhile : public OpRewritePattern { return failure(); } - auto condition = cast(op.before().front().back()); + auto condition = cast(op.getBefore().front().back()); rewriter.setInsertionPoint(condition); auto conditional = - rewriter.create(op.getLoc(), condition.condition()); - rewriter.mergeBlockBefore(&op.after().front(), + rewriter.create(op.getLoc(), condition.getCondition()); + rewriter.mergeBlockBefore(&op.getAfter().front(), &conditional.getBody()->back()); rewriter.eraseOp(&conditional.getBody()->back()); - rewriter.createBlock(&op.after()); + rewriter.createBlock(&op.getAfter()); rewriter.clone(conditional.getBody()->back()); LLVM_DEBUG(DBGS() << "[rotate-while] done\n"); @@ -758,7 +758,7 @@ struct DistributeAroundBarrier : public OpRewritePattern { // Create the second loop. rewriter.setInsertionPointAfter(op); auto newLoop = rewriter.create( - op.getLoc(), op.lowerBound(), op.upperBound(), op.step()); + op.getLoc(), op.getLowerBound(), op.getUpperBound(), op.getStep()); rewriter.eraseOp(&newLoop.getBody()->back()); for (auto alloc : allocations) @@ -837,8 +837,8 @@ struct Reg2MemFor : public OpRewritePattern { rewriter.create(op.getLoc(), operand, alloc, zero); } - auto newOp = rewriter.create(op.getLoc(), op.lowerBound(), - op.upperBound(), op.step()); + auto newOp = rewriter.create(op.getLoc(), op.getLowerBound(), + op.getUpperBound(), op.getStep()); rewriter.setInsertionPointToStart(newOp.getBody()); SmallVector newRegionArguments; newRegionArguments.push_back(newOp.getInductionVar()); @@ -849,7 +849,7 @@ struct Reg2MemFor : public OpRewritePattern { newRegionArguments); rewriter.setInsertionPoint(newOp.getBody()->getTerminator()); - for (auto en : llvm::enumerate(oldTerminator.results())) { + for (auto en : llvm::enumerate(oldTerminator.getResults())) { rewriter.create(op.getLoc(), en.value(), allocated[en.index()], zero); } @@ -908,35 +908,35 @@ struct Reg2MemWhile : public OpRewritePattern { auto newOp = rewriter.create(op.getLoc(), TypeRange(), ValueRange()); Block *newBefore = - rewriter.createBlock(&newOp.before(), newOp.before().begin()); + rewriter.createBlock(&newOp.getBefore(), newOp.getBefore().begin()); SmallVector newBeforeArguments; loadValues(op.getLoc(), beforeAllocated, zero, rewriter, newBeforeArguments); - rewriter.mergeBlocks(&op.before().front(), newBefore, newBeforeArguments); + rewriter.mergeBlocks(&op.getBefore().front(), newBefore, newBeforeArguments); auto beforeTerminator = - cast(newOp.before().front().getTerminator()); + cast(newOp.getBefore().front().getTerminator()); rewriter.setInsertionPoint(beforeTerminator); - storeValues(op.getLoc(), beforeTerminator.args(), afterAllocated, zero, + storeValues(op.getLoc(), beforeTerminator.getArgs(), afterAllocated, zero, rewriter); rewriter.updateRootInPlace(beforeTerminator, - [&] { beforeTerminator.argsMutable().clear(); }); + [&] { beforeTerminator.getArgsMutable().clear(); }); Block *newAfter = - rewriter.createBlock(&newOp.after(), newOp.after().begin()); + rewriter.createBlock(&newOp.getAfter(), newOp.getAfter().begin()); SmallVector newAfterArguments; loadValues(op.getLoc(), afterAllocated, zero, rewriter, newAfterArguments); - rewriter.mergeBlocks(&op.after().front(), newAfter, newAfterArguments); + rewriter.mergeBlocks(&op.getAfter().front(), newAfter, newAfterArguments); auto afterTerminator = - cast(newOp.after().front().getTerminator()); + cast(newOp.getAfter().front().getTerminator()); rewriter.setInsertionPoint(afterTerminator); - storeValues(op.getLoc(), afterTerminator.results(), beforeAllocated, zero, + storeValues(op.getLoc(), afterTerminator.getResults(), beforeAllocated, zero, rewriter); rewriter.updateRootInPlace( - afterTerminator, [&] { afterTerminator.resultsMutable().clear(); }); + afterTerminator, [&] { afterTerminator.getResultsMutable().clear(); }); rewriter.setInsertionPointAfter(op); SmallVector results; diff --git a/lib/polygeist/Passes/RaiseToAffine.cpp b/lib/polygeist/Passes/RaiseToAffine.cpp index 5a414bb..3c83704 100644 --- a/lib/polygeist/Passes/RaiseToAffine.cpp +++ b/lib/polygeist/Passes/RaiseToAffine.cpp @@ -30,7 +30,7 @@ struct ForOpRaising : public OpRewritePattern { bool isAffine(scf::ForOp loop) const { // return true; // enforce step to be a ConstantIndexOp (maybe too restrictive). - return isa_and_nonnull(loop.step().getDefiningOp()); + return isa_and_nonnull(loop.getStep().getDefiningOp()); } void canonicalizeLoopBounds(AffineForOp forOp) const { @@ -67,23 +67,23 @@ struct ForOpRaising : public OpRewritePattern { if (isAffine(loop)) { OpBuilder builder(loop); - if (!isValidIndex(loop.lowerBound())) { + if (!isValidIndex(loop.getLowerBound())) { return failure(); } - if (!isValidIndex(loop.upperBound())) { + if (!isValidIndex(loop.getUpperBound())) { return failure(); } AffineForOp affineLoop = rewriter.create( - loop.getLoc(), loop.lowerBound(), builder.getSymbolIdentityMap(), - loop.upperBound(), builder.getSymbolIdentityMap(), - getStep(loop.step()), loop.getIterOperands()); + loop.getLoc(), loop.getLowerBound(), builder.getSymbolIdentityMap(), + loop.getUpperBound(), builder.getSymbolIdentityMap(), + getStep(loop.getStep()), loop.getIterOperands()); canonicalizeLoopBounds(affineLoop); auto mergedYieldOp = - cast(loop.region().front().getTerminator()); + cast(loop.getRegion().front().getTerminator()); Block &newBlock = affineLoop.region().front(); @@ -97,10 +97,10 @@ struct ForOpRaising : public OpRewritePattern { rewriter.updateRootInPlace(loop, [&] { affineLoop.region().front().getOperations().splice( affineLoop.region().front().getOperations().begin(), - loop.region().front().getOperations()); + loop.getRegion().front().getOperations()); for (auto pair : llvm::zip(affineLoop.region().front().getArguments(), - loop.region().front().getArguments())) { + loop.getRegion().front().getArguments())) { std::get<1>(pair).replaceAllUsesWith(std::get<0>(pair)); } }); diff --git a/llvm-project b/llvm-project index fe5137a..a6a583d 160000 --- a/llvm-project +++ b/llvm-project @@ -1 +1 @@ -Subproject commit fe5137a0fe421ee153f115ae3cd7fb51bba65795 +Subproject commit a6a583dae40485cacfac56811e6d9131bac6ca74 diff --git a/tools/mlir-clang/Lib/CGStmt.cc b/tools/mlir-clang/Lib/CGStmt.cc index 3cf1391..74d4a2a 100644 --- a/tools/mlir-clang/Lib/CGStmt.cc +++ b/tools/mlir-clang/Lib/CGStmt.cc @@ -172,10 +172,10 @@ void MLIRScanner::buildAffineLoopImpl( builder.setInsertionPointToEnd(®.front()); auto er = builder.create(loc, ArrayRef()); - er.region().push_back(new Block()); - builder.setInsertionPointToStart(&er.region().back()); + er.getRegion().push_back(new Block()); + builder.setInsertionPointToStart(&er.getRegion().back()); builder.create(loc); - builder.setInsertionPointToStart(&er.region().back()); + builder.setInsertionPointToStart(&er.getRegion().back()); if (!descr.getForwardMode()) { val = builder.create(loc, val, lb); @@ -355,15 +355,15 @@ ValueCategory MLIRScanner::VisitOMPParallelForDirective( auto oldpoint = builder.getInsertionPoint(); auto oldblock = builder.getInsertionBlock(); - builder.setInsertionPointToStart(&affineOp.region().front()); + builder.setInsertionPointToStart(&affineOp.getRegion().front()); auto executeRegion = builder.create(loc, ArrayRef()); - executeRegion.region().push_back(new Block()); - builder.setInsertionPointToStart(&executeRegion.region().back()); + executeRegion.getRegion().push_back(new Block()); + builder.setInsertionPointToStart(&executeRegion.getRegion().back()); auto oldScope = allocationScope; - allocationScope = &executeRegion.region().back(); + allocationScope = &executeRegion.getRegion().back(); for (auto zp : zip(inds, fors->counters())) { auto idx = builder.create( @@ -552,13 +552,13 @@ ValueCategory MLIRScanner::VisitIfStmt(clang::IfStmt *stmt) { bool hasElseRegion = stmt->getElse(); auto ifOp = builder.create(loc, cond, hasElseRegion); - ifOp.thenRegion().back().clear(); - builder.setInsertionPointToStart(&ifOp.thenRegion().back()); + ifOp.getThenRegion().back().clear(); + builder.setInsertionPointToStart(&ifOp.getThenRegion().back()); Visit(stmt->getThen()); builder.create(loc); if (hasElseRegion) { - ifOp.elseRegion().back().clear(); - builder.setInsertionPointToStart(&ifOp.elseRegion().back()); + ifOp.getElseRegion().back().clear(); + builder.setInsertionPointToStart(&ifOp.getElseRegion().back()); Visit(stmt->getElse()); builder.create(loc); } @@ -574,7 +574,7 @@ ValueCategory MLIRScanner::VisitSwitchStmt(clang::SwitchStmt *stmt) { SmallVector caseVals; auto er = builder.create(loc, ArrayRef()); - er.region().push_back(new Block()); + er.getRegion().push_back(new Block()); auto oldpoint2 = builder.getInsertionPoint(); auto oldblock2 = builder.getInsertionBlock(); @@ -613,7 +613,7 @@ ValueCategory MLIRScanner::VisitSwitchStmt(clang::SwitchStmt *stmt) { } inCase = true; - er.region().getBlocks().push_back(&condB); + er.getRegion().getBlocks().push_back(&condB); blocks.push_back(&condB); builder.setInsertionPointToStart(&condB); @@ -638,7 +638,7 @@ ValueCategory MLIRScanner::VisitSwitchStmt(clang::SwitchStmt *stmt) { } inCase = true; - er.region().getBlocks().push_back(&condB); + er.getRegion().getBlocks().push_back(&condB); builder.setInsertionPointToStart(&condB); auto i1Ty = builder.getIntegerType(1); @@ -668,7 +668,7 @@ ValueCategory MLIRScanner::VisitSwitchStmt(clang::SwitchStmt *stmt) { loops.pop_back(); builder.create(loc, &exitB); - er.region().getBlocks().push_back(&exitB); + er.getRegion().getBlocks().push_back(&exitB); DenseIntElementsAttr caseValuesAttr; ShapedType caseValueType = mlir::VectorType::get( @@ -694,7 +694,7 @@ ValueCategory MLIRScanner::VisitSwitchStmt(clang::SwitchStmt *stmt) { caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseVals8); } - builder.setInsertionPointToStart(&er.region().front()); + builder.setInsertionPointToStart(&er.getRegion().front()); builder.create( loc, cond, defaultB, ArrayRef(), caseValuesAttr, blocks, SmallVector(caseVals.size(), ArrayRef())); diff --git a/tools/mlir-clang/Lib/IfScope.cc b/tools/mlir-clang/Lib/IfScope.cc index 873f6a8..5d1394e 100644 --- a/tools/mlir-clang/Lib/IfScope.cc +++ b/tools/mlir-clang/Lib/IfScope.cc @@ -21,13 +21,13 @@ IfScope::IfScope(MLIRScanner &scanner) : scanner(scanner), prevBlock(nullptr) { /*hasElse*/ false); prevBlock = scanner.builder.getInsertionBlock(); prevIterator = scanner.builder.getInsertionPoint(); - ifOp.thenRegion().back().clear(); - scanner.builder.setInsertionPointToStart(&ifOp.thenRegion().back()); + ifOp.getThenRegion().back().clear(); + scanner.builder.setInsertionPointToStart(&ifOp.getThenRegion().back()); auto er = scanner.builder.create( scanner.loc, ArrayRef()); scanner.builder.create(scanner.loc); - er.region().push_back(new Block()); - scanner.builder.setInsertionPointToStart(&er.region().back()); + er.getRegion().push_back(new Block()); + scanner.builder.setInsertionPointToStart(&er.getRegion().back()); } } diff --git a/tools/mlir-clang/Lib/clang-mlir.cc b/tools/mlir-clang/Lib/clang-mlir.cc index 2a01c0a..ccf5637 100644 --- a/tools/mlir-clang/Lib/clang-mlir.cc +++ b/tools/mlir-clang/Lib/clang-mlir.cc @@ -68,7 +68,8 @@ static mlir::Value castCallerMemRefArg(mlir::Value callerArg, if (MemRefType dstTy = calleeArgType.dyn_cast_or_null()) { MemRefType srcTy = callerArgType.dyn_cast(); - if (srcTy && dstTy.getElementType() == srcTy.getElementType()) { + if (srcTy && dstTy.getElementType() == srcTy.getElementType() + && dstTy.getMemorySpace() == srcTy.getMemorySpace()) { auto srcShape = srcTy.getShape(); auto dstShape = dstTy.getShape(); @@ -748,7 +749,7 @@ ValueCategory MLIRScanner::VisitVarDecl(clang::VarDecl *decl) { auto ifOp = builder.create(varLoc, cond, /*hasElse*/false); block = builder.getInsertionBlock(); iter = builder.getInsertionPoint(); - builder.setInsertionPointToStart(&ifOp.thenRegion().back()); + builder.setInsertionPointToStart(&ifOp.getThenRegion().back()); builder.create(varLoc, builder.create(varLoc, false, 1), boolop, std::vector({getConstantIndex(0)})); } } else @@ -2919,15 +2920,6 @@ ValueCategory MLIRScanner::CallHelper( } } assert(val); - /* - if (val.getType() != fnType.getInput(i)) { - if (auto MR1 = val.getType().dyn_cast()) { - if (auto MR2 = fnType.getInput(i).dyn_cast()) { - val = builder.create(loc, val, MR2); - } - } - } - */ args.push_back(val); i++; } @@ -3460,7 +3452,7 @@ ValueCategory MLIRScanner::VisitBinaryOperator(clang::BinaryOperator *BO) { auto oldpoint = builder.getInsertionPoint(); auto oldblock = builder.getInsertionBlock(); - builder.setInsertionPointToStart(&ifOp.thenRegion().back()); + builder.setInsertionPointToStart(&ifOp.getThenRegion().back()); auto rhs = Visit(BO->getRHS()).getValue(builder); assert(rhs != nullptr); @@ -3476,7 +3468,7 @@ ValueCategory MLIRScanner::VisitBinaryOperator(clang::BinaryOperator *BO) { mlir::Value truearray[] = {rhs}; builder.create(loc, truearray); - builder.setInsertionPointToStart(&ifOp.elseRegion().back()); + builder.setInsertionPointToStart(&ifOp.getElseRegion().back()); mlir::Value falsearray[] = { builder.create(loc, 0, types[0])}; builder.create(loc, falsearray); @@ -3497,12 +3489,12 @@ ValueCategory MLIRScanner::VisitBinaryOperator(clang::BinaryOperator *BO) { auto oldpoint = builder.getInsertionPoint(); auto oldblock = builder.getInsertionBlock(); - builder.setInsertionPointToStart(&ifOp.thenRegion().back()); + builder.setInsertionPointToStart(&ifOp.getThenRegion().back()); mlir::Value truearray[] = {builder.create(loc, 1, types[0])}; builder.create(loc, truearray); - builder.setInsertionPointToStart(&ifOp.elseRegion().back()); + builder.setInsertionPointToStart(&ifOp.getElseRegion().back()); auto rhs = Visit(BO->getRHS()).getValue(builder); if (!rhs.getType().cast().isInteger(1)) { auto postTy = builder.getI1Type(); @@ -4747,7 +4739,7 @@ MLIRScanner::VisitConditionalOperator(clang::ConditionalOperator *E) { auto oldpoint = builder.getInsertionPoint(); auto oldblock = builder.getInsertionBlock(); - builder.setInsertionPointToStart(&ifOp.thenRegion().back()); + builder.setInsertionPointToStart(&ifOp.getThenRegion().back()); auto trueExpr = Visit(E->getTrueExpr()); @@ -4778,7 +4770,7 @@ MLIRScanner::VisitConditionalOperator(clang::ConditionalOperator *E) { builder.create(loc, truearray); } - builder.setInsertionPointToStart(&ifOp.elseRegion().back()); + builder.setInsertionPointToStart(&ifOp.getElseRegion().back()); auto falseExpr = Visit(E->getFalseExpr()); std::vector falsearray; @@ -4800,10 +4792,10 @@ MLIRScanner::VisitConditionalOperator(clang::ConditionalOperator *E) { types[i] = truearray[i].getType(); auto newIfOp = builder.create(loc, types, cond, /*hasElseRegion*/ true); - newIfOp.thenRegion().takeBody(ifOp.thenRegion()); - newIfOp.elseRegion().takeBody(ifOp.elseRegion()); + newIfOp.getThenRegion().takeBody(ifOp.getThenRegion()); + newIfOp.getElseRegion().takeBody(ifOp.getElseRegion()); + ifOp.erase(); return ValueCategory(newIfOp.getResult(0), /*isReference*/ isReference); - // return ifOp; } ValueCategory MLIRScanner::VisitStmtExpr(clang::StmtExpr *stmt) {