Rebase LLVM

This commit is contained in:
William S. Moses 2021-12-30 15:45:42 -05:00 committed by William Moses
parent 85e117eacb
commit 239a2c4d82
13 changed files with 315 additions and 314 deletions

View File

@ -29,7 +29,7 @@ static llvm::SmallVector<mlir::Value>
emitIterationCounts(mlir::OpBuilder &rewriter, mlir::scf::ParallelOp op) { emitIterationCounts(mlir::OpBuilder &rewriter, mlir::scf::ParallelOp op) {
using namespace mlir; using namespace mlir;
SmallVector<Value> iterationCounts; SmallVector<Value> iterationCounts;
for (auto bounds : llvm::zip(op.lowerBound(), op.upperBound(), op.step())) { for (auto bounds : llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep())) {
Value lowerBound = std::get<0>(bounds); Value lowerBound = std::get<0>(bounds);
Value upperBound = std::get<1>(bounds); Value upperBound = std::get<1>(bounds);
Value step = std::get<2>(bounds); Value step = std::get<2>(bounds);

View File

@ -693,6 +693,9 @@ public:
if (src.source().getType().cast<MemRefType>().getElementType() != if (src.source().getType().cast<MemRefType>().getElementType() !=
op.getType().cast<MemRefType>().getElementType()) op.getType().cast<MemRefType>().getElementType())
return failure(); return failure();
if (src.source().getType().cast<MemRefType>().getMemorySpace() !=
op.getType().cast<MemRefType>().getMemorySpace())
return failure();
rewriter.replaceOpWithNewOp<memref::CastOp>(op, op.getType(), src.source()); rewriter.replaceOpWithNewOp<memref::CastOp>(op, op.getType(), src.source());
return success(); return success();
@ -732,6 +735,11 @@ OpFoldResult Memref2PointerOp::fold(ArrayRef<Attribute> operands) {
sourceMutable().assign(mc.source()); sourceMutable().assign(mc.source());
return result(); return result();
} }
if (auto mc = source().getDefiningOp<polygeist::Pointer2MemrefOp>()) {
if (mc.source().getType() == getType()) {
return mc.source();
}
}
return nullptr; return nullptr;
} }
@ -868,6 +876,11 @@ OpFoldResult Pointer2MemrefOp::fold(ArrayRef<Attribute> operands) {
sourceMutable().assign(mc.getArg()); sourceMutable().assign(mc.getArg());
return result(); return result();
} }
if (auto mc = source().getDefiningOp<polygeist::Memref2PointerOp>()) {
if (mc.source().getType() == getType()) {
return mc.source();
}
}
return nullptr; return nullptr;
} }

View File

@ -1047,7 +1047,7 @@ void AffineCFGPass::runOnFunction() {
OpBuilder b(ifOp); OpBuilder b(ifOp);
AffineIfOp affineIfOp; AffineIfOp affineIfOp;
std::vector<mlir::Type> types; std::vector<mlir::Type> types;
for (auto v : ifOp.results()) { for (auto v : ifOp.getResults()) {
types.push_back(v.getType()); types.push_back(v.getType());
} }
@ -1055,7 +1055,7 @@ void AffineCFGPass::runOnFunction() {
SmallVector<bool, 2> eqflags; SmallVector<bool, 2> eqflags;
SmallVector<Value, 4> applies; SmallVector<Value, 4> applies;
std::deque<Value> todo = {ifOp.condition()}; std::deque<Value> todo = {ifOp.getCondition()};
while (todo.size()) { while (todo.size()) {
auto cur = todo.front(); auto cur = todo.front();
todo.pop_front(); todo.pop_front();
@ -1077,20 +1077,20 @@ void AffineCFGPass::runOnFunction() {
eqflags); eqflags);
affineIfOp = b.create<AffineIfOp>(ifOp.getLoc(), types, iset, applies, affineIfOp = b.create<AffineIfOp>(ifOp.getLoc(), types, iset, applies,
/*elseBlock=*/true); /*elseBlock=*/true);
affineIfOp.thenRegion().takeBody(ifOp.thenRegion()); affineIfOp.thenRegion().takeBody(ifOp.getThenRegion());
affineIfOp.elseRegion().takeBody(ifOp.elseRegion()); affineIfOp.elseRegion().takeBody(ifOp.getElseRegion());
for (auto &blk : affineIfOp.thenRegion()) { for (auto &blk : affineIfOp.thenRegion()) {
if (auto yop = dyn_cast<scf::YieldOp>(blk.getTerminator())) { if (auto yop = dyn_cast<scf::YieldOp>(blk.getTerminator())) {
OpBuilder b(yop); OpBuilder b(yop);
b.create<AffineYieldOp>(yop.getLoc(), yop.results()); b.create<AffineYieldOp>(yop.getLoc(), yop.getResults());
yop.erase(); yop.erase();
} }
} }
for (auto &blk : affineIfOp.elseRegion()) { for (auto &blk : affineIfOp.elseRegion()) {
if (auto yop = dyn_cast<scf::YieldOp>(blk.getTerminator())) { if (auto yop = dyn_cast<scf::YieldOp>(blk.getTerminator())) {
OpBuilder b(yop); OpBuilder b(yop);
b.create<AffineYieldOp>(yop.getLoc(), yop.results()); b.create<AffineYieldOp>(yop.getLoc(), yop.getResults());
yop.erase(); yop.erase();
} }
} }

View File

@ -57,14 +57,14 @@ static void wrapPersistingLoopBodies(FuncOp function) {
for (scf::ParallelOp op : loops) { for (scf::ParallelOp op : loops) {
OpBuilder builder = OpBuilder::atBlockBegin(op.getBody()); OpBuilder builder = OpBuilder::atBlockBegin(op.getBody());
auto wrapper = builder.create<scf::ExecuteRegionOp>( auto wrapper = builder.create<scf::ExecuteRegionOp>(
op.getLoc(), op.results().getTypes()); op.getLoc(), op.getResults().getTypes());
builder.createBlock(&wrapper.region(), wrapper.region().begin()); builder.createBlock(&wrapper.getRegion(), wrapper.getRegion().begin());
wrapper.region().front().getOperations().splice( wrapper.getRegion().front().getOperations().splice(
wrapper.region().front().begin(), op.getBody()->getOperations(), wrapper.getRegion().front().begin(), op.getBody()->getOperations(),
std::next(op.getBody()->begin()), op.getBody()->end()); std::next(op.getBody()->begin()), op.getBody()->end());
builder.setInsertionPointToEnd(op.getBody()); builder.setInsertionPointToEnd(op.getBody());
builder.create<scf::YieldOp>( builder.create<scf::YieldOp>(
wrapper.region().front().getTerminator()->getLoc(), wrapper.getRegion().front().getTerminator()->getLoc(),
wrapper.getResults()); wrapper.getResults());
} }
} }
@ -124,7 +124,7 @@ static LogicalResult splitBlocksWithBarrier(FuncOp function) {
} }
splitBlocksWithBarrier( splitBlocksWithBarrier(
cast<scf::ExecuteRegionOp>(&op.getBody()->front()).region()); cast<scf::ExecuteRegionOp>(&op.getBody()->front()).getRegion());
return success(); return success();
}); });
return success(!result.wasInterrupted()); return success(!result.wasInterrupted());
@ -288,15 +288,15 @@ emitContinuationCase(Value condition, Value storage, scf::ParallelOp parallel,
bn.create<scf::ExecuteRegionOp>(TypeRange(), ValueRange()); bn.create<scf::ExecuteRegionOp>(TypeRange(), ValueRange());
BlockAndValueMapping mapping; BlockAndValueMapping mapping;
mapping.map(parallel.getInductionVars(), ivs); mapping.map(parallel.getInductionVars(), ivs);
replicateIntoRegion(executeRegion.region(), storage, ivs, replicateIntoRegion(executeRegion.getRegion(), storage, ivs,
parallel.lowerBound(), blocks, subgraphEntryPoints, parallel.getLowerBound(), blocks, subgraphEntryPoints,
mapping, builder); mapping, builder);
}; };
auto thenBuilder = [&](OpBuilder &nested, Location loc) { auto thenBuilder = [&](OpBuilder &nested, Location loc) {
ImplicitLocOpBuilder bn(loc, nested); ImplicitLocOpBuilder bn(loc, nested);
bn.create<scf::ParallelOp>(parallel.lowerBound(), parallel.upperBound(), bn.create<scf::ParallelOp>(parallel.getLowerBound(), parallel.getUpperBound(),
parallel.step(), parallelBuilder); parallel.getStep(), parallelBuilder);
bn.create<scf::YieldOp>(); bn.create<scf::YieldOp>();
}; };
@ -362,9 +362,9 @@ findInsertionPointAfterLoopOperands(scf::ParallelOp op) {
// Find the earliest insertion point where loop bounds are fully defined. // Find the earliest insertion point where loop bounds are fully defined.
PostDominanceInfo postDominanceInfo(op->getParentOfType<FuncOp>()); PostDominanceInfo postDominanceInfo(op->getParentOfType<FuncOp>());
SmallVector<Value> operands; SmallVector<Value> operands;
llvm::append_range(operands, op.lowerBound()); llvm::append_range(operands, op.getLowerBound());
llvm::append_range(operands, op.upperBound()); llvm::append_range(operands, op.getUpperBound());
llvm::append_range(operands, op.step()); llvm::append_range(operands, op.getStep());
return findNesrestPostDominatingInsertionPoint(operands, postDominanceInfo); return findNesrestPostDominatingInsertionPoint(operands, postDominanceInfo);
} }
@ -536,8 +536,8 @@ static void createContinuations(scf::ParallelOp parallel, Value storage) {
llvm::SetVector<Block *> startBlocks; llvm::SetVector<Block *> startBlocks;
auto outerExecuteRegion = auto outerExecuteRegion =
cast<scf::ExecuteRegionOp>(&parallel.getBody()->front()); cast<scf::ExecuteRegionOp>(&parallel.getBody()->front());
startBlocks.insert(&outerExecuteRegion.region().front()); startBlocks.insert(&outerExecuteRegion.getRegion().front());
for (Block &block : outerExecuteRegion.region()) { for (Block &block : outerExecuteRegion.getRegion()) {
if (!isa_and_nonnull<polygeist::BarrierOp>( if (!isa_and_nonnull<polygeist::BarrierOp>(
block.getTerminator()->getPrevNode())) block.getTerminator()->getPrevNode()))
continue; continue;
@ -560,12 +560,12 @@ static void createContinuations(scf::ParallelOp parallel, Value storage) {
OpBuilder allocBuilder(loop); OpBuilder allocBuilder(loop);
reg2mem(subgraphs, parallel, allocBuilder, builder); reg2mem(subgraphs, parallel, allocBuilder, builder);
builder.createBlock(&loop.before(), loop.before().end()); builder.createBlock(&loop.getBefore(), loop.getBefore().end());
Value next = builder.create<memref::LoadOp>(storage); Value next = builder.create<memref::LoadOp>(storage);
Value condition = builder.create<CmpIOp>(CmpIPredicate::ne, next, negOne); Value condition = builder.create<CmpIOp>(CmpIPredicate::ne, next, negOne);
builder.create<scf::ConditionOp>(TypeRange(), condition, ValueRange()); builder.create<scf::ConditionOp>(TypeRange(), condition, ValueRange());
builder.createBlock(&loop.after(), loop.after().end()); builder.createBlock(&loop.getAfter(), loop.getAfter().end());
next = builder.create<memref::LoadOp>(storage); next = builder.create<memref::LoadOp>(storage);
SmallVector<Value> caseConditions; SmallVector<Value> caseConditions;
caseConditions.resize(startBlocks.size()); caseConditions.resize(startBlocks.size());

View File

@ -52,7 +52,7 @@ static bool hasSameInitValue(Value iter, scf::ForOp forOp) {
if (!cst) if (!cst)
return false; return false;
if (auto cstOp = dyn_cast<ConstantIntOp>(cst)) { if (auto cstOp = dyn_cast<ConstantIntOp>(cst)) {
Operation *lbDefOp = forOp.lowerBound().getDefiningOp(); Operation *lbDefOp = forOp.getLowerBound().getDefiningOp();
if (!lbDefOp) if (!lbDefOp)
return false; return false;
ConstantIndexOp lb = dyn_cast_or_null<ConstantIndexOp>(lbDefOp); ConstantIndexOp lb = dyn_cast_or_null<ConstantIndexOp>(lbDefOp);
@ -69,7 +69,7 @@ static bool hasSameStepValue(Value regIter, Value yieldOp, scf::ForOp forOp) {
if (!defOpStep) if (!defOpStep)
return false; return false;
if (auto cstStep = dyn_cast<ConstantIntOp>(defOpStep)) { if (auto cstStep = dyn_cast<ConstantIntOp>(defOpStep)) {
Operation *stepForDefOp = forOp.step().getDefiningOp(); Operation *stepForDefOp = forOp.getStep().getDefiningOp();
if (!stepForDefOp) if (!stepForDefOp)
return false; return false;
ConstantIndexOp stepFor = dyn_cast_or_null<ConstantIndexOp>(stepForDefOp); ConstantIndexOp stepFor = dyn_cast_or_null<ConstantIndexOp>(stepForDefOp);
@ -123,7 +123,7 @@ struct DetectTrivialIndVarInArgs : public OpRewritePattern<scf::ForOp> {
if (!forOp.getNumIterOperands()) if (!forOp.getNumIterOperands())
return failure(); return failure();
Block &block = forOp.region().front(); Block &block = forOp.getRegion().front();
auto yieldOp = cast<scf::YieldOp>(block.getTerminator()); auto yieldOp = cast<scf::YieldOp>(block.getTerminator());
bool matched = false; bool matched = false;
@ -150,7 +150,7 @@ struct ForOpInductionReplacement : public OpRewritePattern<scf::ForOp> {
LogicalResult matchAndRewrite(scf::ForOp forOp, LogicalResult matchAndRewrite(scf::ForOp forOp,
PatternRewriter &rewriter) const final { PatternRewriter &rewriter) const final {
bool canonicalize = false; bool canonicalize = false;
Block &block = forOp.region().front(); Block &block = forOp.getRegion().front();
auto yieldOp = cast<scf::YieldOp>(block.getTerminator()); auto yieldOp = cast<scf::YieldOp>(block.getTerminator());
for (auto it : llvm::zip(forOp.getIterOperands(), // iter from outside for (auto it : llvm::zip(forOp.getIterOperands(), // iter from outside
@ -171,10 +171,10 @@ struct ForOpInductionReplacement : public OpRewritePattern<scf::ForOp> {
bool isOne = false; bool isOne = false;
if (addOp.getOperand(1) == forOp.step()) { if (addOp.getOperand(1) == forOp.getStep()) {
legalStep = true; legalStep = true;
} else if (auto iter_step = } else if (auto iter_step =
forOp.step().getDefiningOp<ConstantIndexOp>()) { forOp.getStep().getDefiningOp<ConstantIndexOp>()) {
isOne |= iter_step.value() == 1; isOne |= iter_step.value() == 1;
if (auto op = addOp.getOperand(1).getDefiningOp<ConstantIntOp>()) { if (auto op = addOp.getOperand(1).getDefiningOp<ConstantIntOp>()) {
if (op.value() == iter_step.value()) { if (op.value() == iter_step.value()) {
@ -200,9 +200,9 @@ struct ForOpInductionReplacement : public OpRewritePattern<scf::ForOp> {
Value init = std::get<0>(it); Value init = std::get<0>(it);
if (!std::get<1>(it).use_empty()) { if (!std::get<1>(it).use_empty()) {
rewriter.setInsertionPointToStart(&forOp.region().front()); rewriter.setInsertionPointToStart(&forOp.getRegion().front());
Value replacement = rewriter.create<SubIOp>( Value replacement = rewriter.create<SubIOp>(
forOp.getLoc(), forOp.getInductionVar(), forOp.lowerBound()); forOp.getLoc(), forOp.getInductionVar(), forOp.getLowerBound());
if (!std::get<1>(it).getType().isa<IndexType>()) { if (!std::get<1>(it).getType().isa<IndexType>()) {
replacement = rewriter.create<IndexCastOp>( replacement = rewriter.create<IndexCastOp>(
forOp.getLoc(), replacement, std::get<1>(it).getType()); forOp.getLoc(), replacement, std::get<1>(it).getType());
@ -223,7 +223,7 @@ struct ForOpInductionReplacement : public OpRewritePattern<scf::ForOp> {
if (isOne && !std::get<2>(it).use_empty()) { if (isOne && !std::get<2>(it).use_empty()) {
rewriter.setInsertionPoint(forOp); rewriter.setInsertionPoint(forOp);
Value replacement = rewriter.create<SubIOp>( Value replacement = rewriter.create<SubIOp>(
forOp.getLoc(), forOp.upperBound(), forOp.lowerBound()); forOp.getLoc(), forOp.getUpperBound(), forOp.getLowerBound());
if (!std::get<1>(it).getType().isa<IndexType>()) { if (!std::get<1>(it).getType().isa<IndexType>()) {
replacement = rewriter.create<IndexCastOp>( replacement = rewriter.create<IndexCastOp>(
forOp.getLoc(), replacement, std::get<1>(it).getType()); forOp.getLoc(), replacement, std::get<1>(it).getType());
@ -275,7 +275,7 @@ struct RemoveUnusedArgs : public OpRewritePattern<ForOp> {
return failure(); return failure();
auto newForOp = rewriter.create<ForOp>( auto newForOp = rewriter.create<ForOp>(
op.getLoc(), op.lowerBound(), op.upperBound(), op.step(), usedOperands); op.getLoc(), op.getLowerBound(), op.getUpperBound(), op.getStep(), usedOperands);
if (!newForOp.getBody()->empty()) if (!newForOp.getBody()->empty())
rewriter.eraseOp(newForOp.getBody()->getTerminator()); rewriter.eraseOp(newForOp.getBody()->getTerminator());
@ -499,7 +499,7 @@ yop2.results()[idx]);
bool isWhile(WhileOp wop) { bool isWhile(WhileOp wop) {
bool hasCondOp = false; bool hasCondOp = false;
wop.before().walk([&](Operation *op) { wop.getBefore().walk([&](Operation *op) {
if (isa<scf::ConditionOp>(op)) if (isa<scf::ConditionOp>(op))
hasCondOp = true; hasCondOp = true;
}); });
@ -547,12 +547,12 @@ struct MoveWhileToFor : public OpRewritePattern<WhileOp> {
} loopInfo; } loopInfo;
auto condOp = loop.getConditionOp(); auto condOp = loop.getConditionOp();
SmallVector<Value, 2> results = {condOp.args()}; SmallVector<Value, 2> results = {condOp.getArgs()};
auto cmpIOp = condOp.condition().getDefiningOp<CmpIOp>(); auto cmpIOp = condOp.getCondition().getDefiningOp<CmpIOp>();
if (!cmpIOp) { if (!cmpIOp) {
return failure(); return failure();
} }
size_t size = loop.before().front().getOperations().size(); size_t size = loop.getBefore().front().getOperations().size();
if (size != 2) { if (size != 2) {
return failure(); return failure();
} }
@ -560,24 +560,24 @@ struct MoveWhileToFor : public OpRewritePattern<WhileOp> {
BlockArgument indVar = cmpIOp.getLhs().dyn_cast<BlockArgument>(); BlockArgument indVar = cmpIOp.getLhs().dyn_cast<BlockArgument>();
if (!indVar) if (!indVar)
return failure(); return failure();
if (indVar.getOwner() != &loop.before().front()) if (indVar.getOwner() != &loop.getBefore().front())
return failure(); return failure();
SmallVector<size_t, 2> afterArgs; SmallVector<size_t, 2> afterArgs;
for (auto pair : llvm::enumerate(condOp.args())) { for (auto pair : llvm::enumerate(condOp.getArgs())) {
if (pair.value() == indVar) if (pair.value() == indVar)
afterArgs.push_back(pair.index()); afterArgs.push_back(pair.index());
} }
auto endYield = cast<YieldOp>(loop.after().back().getTerminator()); auto endYield = cast<YieldOp>(loop.getAfter().back().getTerminator());
auto addIOp = auto addIOp =
endYield.results()[indVar.getArgNumber()].getDefiningOp<AddIOp>(); endYield.getResults()[indVar.getArgNumber()].getDefiningOp<AddIOp>();
if (!addIOp) if (!addIOp)
return failure(); return failure();
for (auto afterArg : afterArgs) { for (auto afterArg : afterArgs) {
auto arg = loop.after().getArgument(afterArg); auto arg = loop.getAfter().getArgument(afterArg);
if (addIOp.getOperand(0) == arg) { if (addIOp.getOperand(0) == arg) {
step = addIOp.getOperand(1); step = addIOp.getOperand(1);
break; break;
@ -679,19 +679,19 @@ struct MoveWhileToFor : public OpRewritePattern<WhileOp> {
// input of the for goes the input of the scf::while plus the output taken // input of the for goes the input of the scf::while plus the output taken
// from the conditionOp. // from the conditionOp.
SmallVector<Value, 8> forArgs; SmallVector<Value, 8> forArgs;
forArgs.append(loop.inits().begin(), loop.inits().end()); forArgs.append(loop.getInits().begin(), loop.getInits().end());
for (Value arg : condOp.args()) { for (Value arg : condOp.getArgs()) {
Type cst = nullptr; Type cst = nullptr;
if (auto idx = arg.getDefiningOp<IndexCastOp>()) { if (auto idx = arg.getDefiningOp<IndexCastOp>()) {
cst = idx.getType(); cst = idx.getType();
arg = idx.getIn(); arg = idx.getIn();
} }
Value res; Value res;
if (isTopLevelArgValue(arg, &loop.before())) { if (isTopLevelArgValue(arg, &loop.getBefore())) {
auto blockArg = arg.cast<BlockArgument>(); auto blockArg = arg.cast<BlockArgument>();
auto pos = blockArg.getArgNumber(); auto pos = blockArg.getArgNumber();
res = loop.inits()[pos]; res = loop.getInits()[pos];
} else } else
res = arg; res = arg;
if (cst) { if (cst) {
@ -706,31 +706,31 @@ struct MoveWhileToFor : public OpRewritePattern<WhileOp> {
if (!forloop.getBody()->empty()) if (!forloop.getBody()->empty())
rewriter.eraseOp(forloop.getBody()->getTerminator()); rewriter.eraseOp(forloop.getBody()->getTerminator());
auto oldYield = cast<scf::YieldOp>(loop.after().front().getTerminator()); auto oldYield = cast<scf::YieldOp>(loop.getAfter().front().getTerminator());
rewriter.updateRootInPlace(loop, [&] { rewriter.updateRootInPlace(loop, [&] {
for (auto pair : llvm::zip(loop.after().getArguments(), condOp.args())) { for (auto pair : llvm::zip(loop.getAfter().getArguments(), condOp.getArgs())) {
std::get<0>(pair).replaceAllUsesWith(std::get<1>(pair)); std::get<0>(pair).replaceAllUsesWith(std::get<1>(pair));
} }
}); });
loop.after().front().eraseArguments([](BlockArgument) { return true; }); loop.getAfter().front().eraseArguments([](BlockArgument) { return true; });
SmallVector<Value, 2> yieldOperands; SmallVector<Value, 2> yieldOperands;
for (auto oldYieldArg : oldYield.results()) for (auto oldYieldArg : oldYield.getResults())
yieldOperands.push_back(oldYieldArg); yieldOperands.push_back(oldYieldArg);
BlockAndValueMapping outmap; BlockAndValueMapping outmap;
outmap.map(loop.before().getArguments(), yieldOperands); outmap.map(loop.getBefore().getArguments(), yieldOperands);
for (auto arg : condOp.args()) for (auto arg : condOp.getArgs())
yieldOperands.push_back(outmap.lookupOrDefault(arg)); yieldOperands.push_back(outmap.lookupOrDefault(arg));
rewriter.setInsertionPoint(oldYield); rewriter.setInsertionPoint(oldYield);
rewriter.replaceOpWithNewOp<scf::YieldOp>(oldYield, yieldOperands); rewriter.replaceOpWithNewOp<scf::YieldOp>(oldYield, yieldOperands);
size_t pos = loop.inits().size(); size_t pos = loop.getInits().size();
rewriter.updateRootInPlace(loop, [&] { rewriter.updateRootInPlace(loop, [&] {
for (auto pair : llvm::zip(loop.before().getArguments(), for (auto pair : llvm::zip(loop.getBefore().getArguments(),
forloop.getRegionIterArgs().drop_back(pos))) { forloop.getRegionIterArgs().drop_back(pos))) {
std::get<0>(pair).replaceAllUsesWith(std::get<1>(pair)); std::get<0>(pair).replaceAllUsesWith(std::get<1>(pair));
} }
@ -738,7 +738,7 @@ struct MoveWhileToFor : public OpRewritePattern<WhileOp> {
forloop.getBody()->getOperations().splice( forloop.getBody()->getOperations().splice(
forloop.getBody()->getOperations().begin(), forloop.getBody()->getOperations().begin(),
loop.after().front().getOperations()); loop.getAfter().front().getOperations());
SmallVector<Value, 2> replacements; SmallVector<Value, 2> replacements;
replacements.append(forloop.getResults().begin() + pos, replacements.append(forloop.getResults().begin() + pos,
@ -754,20 +754,20 @@ struct MoveWhileDown : public OpRewritePattern<WhileOp> {
LogicalResult matchAndRewrite(WhileOp op, LogicalResult matchAndRewrite(WhileOp op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
auto term = cast<scf::ConditionOp>(op.before().front().getTerminator()); auto term = cast<scf::ConditionOp>(op.getBefore().front().getTerminator());
if (auto ifOp = term.condition().getDefiningOp<scf::IfOp>()) { if (auto ifOp = term.getCondition().getDefiningOp<scf::IfOp>()) {
if (ifOp.getNumResults() != term.args().size() + 1) if (ifOp.getNumResults() != term.getArgs().size() + 1)
return failure(); return failure();
if (ifOp.getResult(0) != term.condition()) if (ifOp.getResult(0) != term.getCondition())
return failure(); return failure();
for (size_t i = 1; i < ifOp.getNumResults(); ++i) { for (size_t i = 1; i < ifOp.getNumResults(); ++i) {
if (ifOp.getResult(i) != term.args()[i - 1]) if (ifOp.getResult(i) != term.getArgs()[i - 1])
return failure(); return failure();
} }
auto yield1 = auto yield1 =
cast<scf::YieldOp>(ifOp.thenRegion().front().getTerminator()); cast<scf::YieldOp>(ifOp.getThenRegion().front().getTerminator());
auto yield2 = auto yield2 =
cast<scf::YieldOp>(ifOp.elseRegion().front().getTerminator()); cast<scf::YieldOp>(ifOp.getElseRegion().front().getTerminator());
if (auto cop = yield1.getOperand(0).getDefiningOp<ConstantIntOp>()) { if (auto cop = yield1.getOperand(0).getDefiningOp<ConstantIntOp>()) {
if (cop.value() == 0) if (cop.value() == 0)
return failure(); return failure();
@ -778,31 +778,31 @@ struct MoveWhileDown : public OpRewritePattern<WhileOp> {
return failure(); return failure();
} else } else
return failure(); return failure();
if (ifOp.elseRegion().front().getOperations().size() != 1) if (ifOp.getElseRegion().front().getOperations().size() != 1)
return failure(); return failure();
op.after().front().getOperations().splice( op.getAfter().front().getOperations().splice(
op.after().front().begin(), op.getAfter().front().begin(),
ifOp.thenRegion().front().getOperations()); ifOp.getThenRegion().front().getOperations());
rewriter.updateRootInPlace( rewriter.updateRootInPlace(
term, [&] { term.conditionMutable().assign(ifOp.condition()); }); term, [&] { term.getConditionMutable().assign(ifOp.getCondition()); });
SmallVector<Value, 2> args; SmallVector<Value, 2> args;
for (size_t i = 1; i < yield2.getNumOperands(); ++i) { for (size_t i = 1; i < yield2.getNumOperands(); ++i) {
args.push_back(yield2.getOperand(i)); args.push_back(yield2.getOperand(i));
} }
rewriter.updateRootInPlace(term, rewriter.updateRootInPlace(term,
[&] { term.argsMutable().assign(args); }); [&] { term.getArgsMutable().assign(args); });
rewriter.eraseOp(yield2); rewriter.eraseOp(yield2);
rewriter.eraseOp(ifOp); rewriter.eraseOp(ifOp);
for (size_t i = 0; i < op.after().front().getNumArguments(); ++i) { for (size_t i = 0; i < op.getAfter().front().getNumArguments(); ++i) {
op.after().front().getArgument(i).replaceAllUsesWith( op.getAfter().front().getArgument(i).replaceAllUsesWith(
yield1.getOperand(i + 1)); yield1.getOperand(i + 1));
} }
rewriter.eraseOp(yield1); rewriter.eraseOp(yield1);
// TODO move operands from begin to after // TODO move operands from begin to after
SmallVector<Value> todo(op.before().front().getArguments().begin(), SmallVector<Value> todo(op.getBefore().front().getArguments().begin(),
op.before().front().getArguments().end()); op.getBefore().front().getArguments().end());
for (auto &op : op.before().front()) { for (auto &op : op.getBefore().front()) {
for (auto res : op.getResults()) { for (auto res : op.getResults()) {
todo.push_back(res); todo.push_back(res);
} }
@ -810,24 +810,24 @@ struct MoveWhileDown : public OpRewritePattern<WhileOp> {
rewriter.updateRootInPlace(op, [&] { rewriter.updateRootInPlace(op, [&] {
for (auto val : todo) { for (auto val : todo) {
auto na = op.after().front().addArgument(val.getType()); auto na = op.getAfter().front().addArgument(val.getType());
val.replaceUsesWithIf(na, [&](OpOperand &u) -> bool { val.replaceUsesWithIf(na, [&](OpOperand &u) -> bool {
return op.after().isAncestor(u.getOwner()->getParentRegion()); return op.getAfter().isAncestor(u.getOwner()->getParentRegion());
}); });
args.push_back(val); args.push_back(val);
} }
}); });
rewriter.updateRootInPlace(term, rewriter.updateRootInPlace(term,
[&] { term.argsMutable().assign(args); }); [&] { term.getArgsMutable().assign(args); });
SmallVector<Type, 4> tys; SmallVector<Type, 4> tys;
for (auto a : args) for (auto a : args)
tys.push_back(a.getType()); tys.push_back(a.getType());
auto op2 = rewriter.create<WhileOp>(op.getLoc(), tys, op.inits()); auto op2 = rewriter.create<WhileOp>(op.getLoc(), tys, op.getInits());
op2.before().takeBody(op.before()); op2.getBefore().takeBody(op.getBefore());
op2.after().takeBody(op.after()); op2.getAfter().takeBody(op.getAfter());
SmallVector<Value, 4> replacements; SmallVector<Value, 4> replacements;
for (auto a : op2.getResults()) { for (auto a : op2.getResults()) {
if (replacements.size() == op.getResults().size()) if (replacements.size() == op.getResults().size())
@ -882,9 +882,9 @@ struct MoveWhileDown2 : public OpRewritePattern<WhileOp> {
LogicalResult matchAndRewrite(WhileOp op, LogicalResult matchAndRewrite(WhileOp op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
auto term = cast<scf::ConditionOp>(op.before().front().getTerminator()); auto term = cast<scf::ConditionOp>(op.getBefore().front().getTerminator());
if (auto ifOp = dyn_cast_or_null<scf::IfOp>(term->getPrevNode())) { if (auto ifOp = dyn_cast_or_null<scf::IfOp>(term->getPrevNode())) {
if (ifOp.condition() != term.condition()) if (ifOp.getCondition() != term.getCondition())
return failure(); return failure();
SmallVector<std::pair<BlockArgument, Value>, 2> m; SmallVector<std::pair<BlockArgument, Value>, 2> m;
@ -892,14 +892,14 @@ struct MoveWhileDown2 : public OpRewritePattern<WhileOp> {
SmallVector<Value, 2> prevArgs; SmallVector<Value, 2> prevArgs;
SmallVector<std::pair<size_t, Value>, 2> afterYieldRewrites; SmallVector<std::pair<size_t, Value>, 2> afterYieldRewrites;
auto afterYield = cast<YieldOp>(op.after().front().back()); auto afterYield = cast<YieldOp>(op.getAfter().front().back());
for (auto pair : for (auto pair :
llvm::zip(op.getResults(), term.args(), op.getAfterArguments())) { llvm::zip(op.getResults(), term.getArgs(), op.getAfterArguments())) {
if (std::get<1>(pair).getDefiningOp() == ifOp) { if (std::get<1>(pair).getDefiningOp() == ifOp) {
Value thenYielded, elseYielded; Value thenYielded, elseYielded;
for (auto p : llvm::zip(ifOp.thenYield().results(), ifOp.results(), for (auto p : llvm::zip(ifOp.thenYield().getResults(), ifOp.getResults(),
ifOp.elseYield().results())) { ifOp.elseYield().getResults())) {
if (std::get<1>(pair) == std::get<1>(p)) { if (std::get<1>(pair) == std::get<1>(p)) {
thenYielded = std::get<0>(p); thenYielded = std::get<0>(p);
elseYielded = std::get<2>(p); elseYielded = std::get<2>(p);
@ -911,8 +911,8 @@ struct MoveWhileDown2 : public OpRewritePattern<WhileOp> {
if (!std::get<0>(pair).use_empty()) { if (!std::get<0>(pair).use_empty()) {
if (auto blockArg = elseYielded.dyn_cast<BlockArgument>()) if (auto blockArg = elseYielded.dyn_cast<BlockArgument>())
if (blockArg.getOwner() == &op.before().front()) { if (blockArg.getOwner() == &op.getBefore().front()) {
if (afterYield.results()[blockArg.getArgNumber()] == if (afterYield.getResults()[blockArg.getArgNumber()] ==
std::get<2>(pair) && std::get<2>(pair) &&
op.getResults()[blockArg.getArgNumber()] == op.getResults()[blockArg.getArgNumber()] ==
std::get<0>(pair)) { std::get<0>(pair)) {
@ -933,18 +933,18 @@ struct MoveWhileDown2 : public OpRewritePattern<WhileOp> {
} }
} }
SmallVector<Value> yieldArgs = afterYield.results(); SmallVector<Value> yieldArgs = afterYield.getResults();
for (auto pair : afterYieldRewrites) { for (auto pair : afterYieldRewrites) {
yieldArgs[pair.first] = pair.second; yieldArgs[pair.first] = pair.second;
} }
rewriter.updateRootInPlace( rewriter.updateRootInPlace(
afterYield, [&] { afterYield.resultsMutable().assign(yieldArgs); }); afterYield, [&] { afterYield.getResultsMutable().assign(yieldArgs); });
llvm::SetVector<Value> sv; llvm::SetVector<Value> sv;
findValuesUsedBelow(ifOp, sv); findValuesUsedBelow(ifOp, sv);
Block *afterB = &op.after().front(); Block *afterB = &op.getAfter().front();
for (auto v : sv) { for (auto v : sv) {
condArgs.push_back(v); condArgs.push_back(v);
@ -956,7 +956,7 @@ struct MoveWhileDown2 : public OpRewritePattern<WhileOp> {
} }
rewriter.setInsertionPoint(term); rewriter.setInsertionPoint(term);
rewriter.replaceOpWithNewOp<ConditionOp>(term, term.condition(), rewriter.replaceOpWithNewOp<ConditionOp>(term, term.getCondition(),
condArgs); condArgs);
for (int i = m.size() - 1; i >= 0; i--) { for (int i = m.size() - 1; i >= 0; i--) {
@ -977,9 +977,9 @@ struct MoveWhileDown2 : public OpRewritePattern<WhileOp> {
} }
rewriter.setInsertionPoint(op); rewriter.setInsertionPoint(op);
auto nop = rewriter.create<WhileOp>(op.getLoc(), resultTypes, op.inits()); auto nop = rewriter.create<WhileOp>(op.getLoc(), resultTypes, op.getInits());
nop.before().takeBody(op.before()); nop.getBefore().takeBody(op.getBefore());
nop.after().takeBody(op.after()); nop.getAfter().takeBody(op.getAfter());
rewriter.updateRootInPlace(op, [&] { rewriter.updateRootInPlace(op, [&] {
for (auto pair : llvm::enumerate(prevArgs)) { for (auto pair : llvm::enumerate(prevArgs)) {
@ -1002,24 +1002,24 @@ struct MoveWhileInvariantIfResult : public OpRewritePattern<WhileOp> {
SmallVector<BlockArgument, 2> origAfterArgs(op.getAfterArguments().begin(), SmallVector<BlockArgument, 2> origAfterArgs(op.getAfterArguments().begin(),
op.getAfterArguments().end()); op.getAfterArguments().end());
bool changed = false; bool changed = false;
auto term = cast<scf::ConditionOp>(op.before().front().getTerminator()); scf::ConditionOp term = cast<scf::ConditionOp>(op.getBefore().front().getTerminator());
assert(origAfterArgs.size() == op.getResults().size()); assert(origAfterArgs.size() == op.getResults().size());
assert(origAfterArgs.size() == term.args().size()); assert(origAfterArgs.size() == term.getArgs().size());
for (auto pair : llvm::zip(op.getResults(), term.args(), origAfterArgs)) { for (auto pair : llvm::zip(op.getResults(), term.getArgs(), origAfterArgs)) {
if (!std::get<0>(pair).use_empty()) { if (!std::get<0>(pair).use_empty()) {
if (auto ifOp = std::get<1>(pair).getDefiningOp<scf::IfOp>()) { if (auto ifOp = std::get<1>(pair).getDefiningOp<scf::IfOp>()) {
if (ifOp.condition() == term.condition()) { if (ifOp.getCondition() == term.getCondition()) {
ssize_t idx = -1; ssize_t idx = -1;
for (auto tup : llvm::enumerate(ifOp.results())) { for (auto tup : llvm::enumerate(ifOp.getResults())) {
if (tup.value() == std::get<1>(pair)) { if (tup.value() == std::get<1>(pair)) {
idx = tup.index(); idx = tup.index();
break; break;
} }
} }
assert(idx != -1); assert(idx != -1);
Value returnWith = ifOp.elseYield().results()[idx]; Value returnWith = ifOp.elseYield().getResults()[idx];
if (!op.before().isAncestor(returnWith.getParentRegion())) { if (!op.getBefore().isAncestor(returnWith.getParentRegion())) {
rewriter.updateRootInPlace(op, [&] { rewriter.updateRootInPlace(op, [&] {
std::get<0>(pair).replaceAllUsesWith(returnWith); std::get<0>(pair).replaceAllUsesWith(returnWith);
}); });
@ -1042,12 +1042,12 @@ struct WhileLogicalNegation : public OpRewritePattern<WhileOp> {
SmallVector<BlockArgument, 2> origAfterArgs(op.getAfterArguments().begin(), SmallVector<BlockArgument, 2> origAfterArgs(op.getAfterArguments().begin(),
op.getAfterArguments().end()); op.getAfterArguments().end());
bool changed = false; bool changed = false;
auto term = cast<scf::ConditionOp>(op.before().front().getTerminator()); scf::ConditionOp term = cast<scf::ConditionOp>(op.getBefore().front().getTerminator());
assert(origAfterArgs.size() == op.getResults().size()); assert(origAfterArgs.size() == op.getResults().size());
assert(origAfterArgs.size() == term.args().size()); assert(origAfterArgs.size() == term.getArgs().size());
if (auto condCmp = term.condition().getDefiningOp<CmpIOp>()) { if (auto condCmp = term.getCondition().getDefiningOp<CmpIOp>()) {
for (auto pair : llvm::zip(op.getResults(), term.args(), origAfterArgs)) { for (auto pair : llvm::zip(op.getResults(), term.getArgs(), origAfterArgs)) {
if (!std::get<0>(pair).use_empty()) { if (!std::get<0>(pair).use_empty()) {
if (auto termCmp = std::get<1>(pair).getDefiningOp<CmpIOp>()) { if (auto termCmp = std::get<1>(pair).getDefiningOp<CmpIOp>()) {
if (termCmp.getLhs() == condCmp.getLhs() && if (termCmp.getLhs() == condCmp.getLhs() &&
@ -1082,41 +1082,41 @@ struct WhileCmpOffset : public OpRewritePattern<WhileOp> {
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
SmallVector<BlockArgument, 2> origAfterArgs(op.getAfterArguments().begin(), SmallVector<BlockArgument, 2> origAfterArgs(op.getAfterArguments().begin(),
op.getAfterArguments().end()); op.getAfterArguments().end());
auto term = cast<scf::ConditionOp>(op.before().front().getTerminator()); scf::ConditionOp term = cast<scf::ConditionOp>(op.getBefore().front().getTerminator());
assert(origAfterArgs.size() == op.getResults().size()); assert(origAfterArgs.size() == op.getResults().size());
assert(origAfterArgs.size() == term.args().size()); assert(origAfterArgs.size() == term.getArgs().size());
if (auto condCmp = term.condition().getDefiningOp<CmpIOp>()) { if (auto condCmp = term.getCondition().getDefiningOp<CmpIOp>()) {
if (auto addI = condCmp.getLhs().getDefiningOp<AddIOp>()) { if (auto addI = condCmp.getLhs().getDefiningOp<AddIOp>()) {
if (addI.getOperand(1).getDefiningOp() && if (addI.getOperand(1).getDefiningOp() &&
!op.before().isAncestor( !op.getBefore().isAncestor(
addI.getOperand(1).getDefiningOp()->getParentRegion())) addI.getOperand(1).getDefiningOp()->getParentRegion()))
if (auto blockArg = addI.getOperand(0).dyn_cast<BlockArgument>()) { if (auto blockArg = addI.getOperand(0).dyn_cast<BlockArgument>()) {
if (blockArg.getOwner() == &op.before().front()) { if (blockArg.getOwner() == &op.getBefore().front()) {
auto rng = llvm::make_early_inc_range(blockArg.getUses()); auto rng = llvm::make_early_inc_range(blockArg.getUses());
{ {
rewriter.setInsertionPoint(op); rewriter.setInsertionPoint(op);
SmallVector<Value> oldInits = op.inits(); SmallVector<Value> oldInits = op.getInits();
oldInits[blockArg.getArgNumber()] = rewriter.create<AddIOp>( oldInits[blockArg.getArgNumber()] = rewriter.create<AddIOp>(
addI.getLoc(), oldInits[blockArg.getArgNumber()], addI.getLoc(), oldInits[blockArg.getArgNumber()],
addI.getOperand(1)); addI.getOperand(1));
op.initsMutable().assign(oldInits); op.getInitsMutable().assign(oldInits);
rewriter.updateRootInPlace( rewriter.updateRootInPlace(
addI, [&] { addI.replaceAllUsesWith(blockArg); }); addI, [&] { addI.replaceAllUsesWith(blockArg); });
} }
YieldOp afterYield = cast<YieldOp>(op.after().front().back()); YieldOp afterYield = cast<YieldOp>(op.getAfter().front().back());
rewriter.setInsertionPoint(afterYield); rewriter.setInsertionPoint(afterYield);
SmallVector<Value> oldYields = afterYield.results(); SmallVector<Value> oldYields = afterYield.getResults();
oldYields[blockArg.getArgNumber()] = rewriter.create<AddIOp>( oldYields[blockArg.getArgNumber()] = rewriter.create<AddIOp>(
addI.getLoc(), oldYields[blockArg.getArgNumber()], addI.getLoc(), oldYields[blockArg.getArgNumber()],
addI.getOperand(1)); addI.getOperand(1));
rewriter.updateRootInPlace(afterYield, [&] { rewriter.updateRootInPlace(afterYield, [&] {
afterYield.resultsMutable().assign(oldYields); afterYield.getResultsMutable().assign(oldYields);
}); });
rewriter.setInsertionPointToStart(&op.before().front()); rewriter.setInsertionPointToStart(&op.getBefore().front());
auto sub = rewriter.create<SubIOp>(addI.getLoc(), blockArg, auto sub = rewriter.create<SubIOp>(addI.getLoc(), blockArg,
addI.getOperand(1)); addI.getOperand(1));
for (OpOperand &use : rng) { for (OpOperand &use : rng) {
@ -1139,7 +1139,7 @@ struct MoveWhileDown3 : public OpRewritePattern<WhileOp> {
LogicalResult matchAndRewrite(WhileOp op, LogicalResult matchAndRewrite(WhileOp op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
auto term = cast<scf::ConditionOp>(op.before().front().getTerminator()); scf::ConditionOp term = cast<scf::ConditionOp>(op.getBefore().front().getTerminator());
SmallVector<unsigned, 2> toErase; SmallVector<unsigned, 2> toErase;
SmallVector<Value, 2> newOps; SmallVector<Value, 2> newOps;
SmallVector<Value, 2> condOps; SmallVector<Value, 2> condOps;
@ -1147,8 +1147,8 @@ struct MoveWhileDown3 : public OpRewritePattern<WhileOp> {
op.getAfterArguments().end()); op.getAfterArguments().end());
SmallVector<Value, 2> returns; SmallVector<Value, 2> returns;
assert(origAfterArgs.size() == op.getResults().size()); assert(origAfterArgs.size() == op.getResults().size());
assert(origAfterArgs.size() == term.args().size()); assert(origAfterArgs.size() == term.getArgs().size());
for (auto pair : llvm::zip(op.getResults(), term.args(), origAfterArgs)) { for (auto pair : llvm::zip(op.getResults(), term.getArgs(), origAfterArgs)) {
if (std::get<0>(pair).use_empty()) { if (std::get<0>(pair).use_empty()) {
if (std::get<2>(pair).use_empty()) { if (std::get<2>(pair).use_empty()) {
toErase.push_back(std::get<2>(pair).getArgNumber()); toErase.push_back(std::get<2>(pair).getArgNumber());
@ -1162,9 +1162,9 @@ struct MoveWhileDown3 : public OpRewritePattern<WhileOp> {
Operation *cloned = std::get<1>(pair).getDefiningOp(); Operation *cloned = std::get<1>(pair).getDefiningOp();
if (!std::get<1>(pair).hasOneUse()) { if (!std::get<1>(pair).hasOneUse()) {
cloned = std::get<1>(pair).getDefiningOp()->clone(); cloned = std::get<1>(pair).getDefiningOp()->clone();
op.after().front().push_front(cloned); op.getAfter().front().push_front(cloned);
} else { } else {
cloned->moveBefore(&op.after().front().front()); cloned->moveBefore(&op.getAfter().front().front());
} }
rewriter.updateRootInPlace(std::get<1>(pair).getDefiningOp(), [&] { rewriter.updateRootInPlace(std::get<1>(pair).getDefiningOp(), [&] {
std::get<2>(pair).replaceAllUsesWith(cloned->getResult(0)); std::get<2>(pair).replaceAllUsesWith(cloned->getResult(0));
@ -1174,7 +1174,7 @@ struct MoveWhileDown3 : public OpRewritePattern<WhileOp> {
llvm::make_early_inc_range(cloned->getOpOperands())) { llvm::make_early_inc_range(cloned->getOpOperands())) {
{ {
newOps.push_back(o.get()); newOps.push_back(o.get());
o.set(op.after().front().addArgument(o.get().getType())); o.set(op.getAfter().front().addArgument(o.get().getType()));
} }
} }
continue; continue;
@ -1190,19 +1190,19 @@ struct MoveWhileDown3 : public OpRewritePattern<WhileOp> {
condOps.append(newOps.begin(), newOps.end()); condOps.append(newOps.begin(), newOps.end());
rewriter.updateRootInPlace( rewriter.updateRootInPlace(
term, [&] { op.after().front().eraseArguments(toErase); }); term, [&] { op.getAfter().front().eraseArguments(toErase); });
rewriter.setInsertionPoint(term); rewriter.setInsertionPoint(term);
rewriter.replaceOpWithNewOp<ConditionOp>(term, term.condition(), condOps); rewriter.replaceOpWithNewOp<ConditionOp>(term, term.getCondition(), condOps);
rewriter.setInsertionPoint(op); rewriter.setInsertionPoint(op);
SmallVector<Type, 4> resultTypes; SmallVector<Type, 4> resultTypes;
for (auto v : condOps) { for (auto v : condOps) {
resultTypes.push_back(v.getType()); resultTypes.push_back(v.getType());
} }
auto nop = rewriter.create<WhileOp>(op.getLoc(), resultTypes, op.inits()); auto nop = rewriter.create<WhileOp>(op.getLoc(), resultTypes, op.getInits());
nop.before().takeBody(op.before()); nop.getBefore().takeBody(op.getBefore());
nop.after().takeBody(op.after()); nop.getAfter().takeBody(op.getAfter());
rewriter.updateRootInPlace(op, [&] { rewriter.updateRootInPlace(op, [&] {
for (auto pair : llvm::enumerate(returns)) { for (auto pair : llvm::enumerate(returns)) {
@ -1210,7 +1210,7 @@ struct MoveWhileDown3 : public OpRewritePattern<WhileOp> {
} }
}); });
assert(resultTypes.size() == nop.after().front().getNumArguments()); assert(resultTypes.size() == nop.getAfter().front().getNumArguments());
assert(resultTypes.size() == condOps.size()); assert(resultTypes.size() == condOps.size());
rewriter.eraseOp(op); rewriter.eraseOp(op);
@ -1269,7 +1269,7 @@ struct WhileLICM : public OpRewritePattern<WhileOp> {
auto definingOp = value.getDefiningOp(); auto definingOp = value.getDefiningOp();
bool definedOutside = bool definedOutside =
(definingOp && !!willBeMovedSet.count(definingOp)) || (definingOp && !!willBeMovedSet.count(definingOp)) ||
!op.before().isAncestor(value.getParentRegion()); !op.getBefore().isAncestor(value.getParentRegion());
return definedOutside; return definedOutside;
}; };
@ -1277,7 +1277,7 @@ struct WhileLICM : public OpRewritePattern<WhileOp> {
// hoist operations from there. These regions might have semantics unknown // hoist operations from there. These regions might have semantics unknown
// to this rewriting. If the nested regions are loops, they will have been // to this rewriting. If the nested regions are loops, they will have been
// processed. // processed.
for (auto &block : op.before()) { for (auto &block : op.getBefore()) {
for (auto &op : block.without_terminator()) { for (auto &op : block.without_terminator()) {
bool legal = canBeHoisted(&op, isDefinedOutsideOfBody); bool legal = canBeHoisted(&op, isDefinedOutsideOfBody);
if (legal) { if (legal) {
@ -1299,7 +1299,7 @@ struct RemoveUnusedCondVar : public OpRewritePattern<WhileOp> {
LogicalResult matchAndRewrite(WhileOp op, LogicalResult matchAndRewrite(WhileOp op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
auto term = cast<scf::ConditionOp>(op.before().front().getTerminator()); auto term = cast<scf::ConditionOp>(op.getBefore().front().getTerminator());
SmallVector<Value, 4> conds; SmallVector<Value, 4> conds;
SmallVector<unsigned, 4> eraseArgs; SmallVector<unsigned, 4> eraseArgs;
SmallVector<unsigned, 4> keepArgs; SmallVector<unsigned, 4> keepArgs;
@ -1308,7 +1308,7 @@ struct RemoveUnusedCondVar : public OpRewritePattern<WhileOp> {
std::map<void *, unsigned> valueOffsets; std::map<void *, unsigned> valueOffsets;
std::map<unsigned, unsigned> resultOffsets; std::map<unsigned, unsigned> resultOffsets;
SmallVector<Value, 4> resultArgs; SmallVector<Value, 4> resultArgs;
for (auto pair : llvm::zip(term.args(), op.after().front().getArguments(), for (auto pair : llvm::zip(term.getArgs(), op.getAfter().front().getArguments(),
op.getResults())) { op.getResults())) {
auto arg = std::get<0>(pair); auto arg = std::get<0>(pair);
auto afarg = std::get<1>(pair); auto afarg = std::get<1>(pair);
@ -1331,24 +1331,24 @@ struct RemoveUnusedCondVar : public OpRewritePattern<WhileOp> {
} }
i++; i++;
} }
assert(i == op.after().front().getArguments().size()); assert(i == op.getAfter().front().getArguments().size());
if (eraseArgs.size() != 0) { if (eraseArgs.size() != 0) {
rewriter.setInsertionPoint(term); rewriter.setInsertionPoint(term);
rewriter.replaceOpWithNewOp<scf::ConditionOp>(term, term.condition(), rewriter.replaceOpWithNewOp<scf::ConditionOp>(term, term.getCondition(),
conds); conds);
rewriter.setInsertionPoint(op); rewriter.setInsertionPoint(op);
auto op2 = rewriter.create<WhileOp>(op.getLoc(), tys, op.inits()); auto op2 = rewriter.create<WhileOp>(op.getLoc(), tys, op.getInits());
op2.before().takeBody(op.before()); op2.getBefore().takeBody(op.getBefore());
op2.after().takeBody(op.after()); op2.getAfter().takeBody(op.getAfter());
for (auto pair : resultOffsets) { for (auto pair : resultOffsets) {
op.getResult(pair.first).replaceAllUsesWith(op2.getResult(pair.second)); op.getResult(pair.first).replaceAllUsesWith(op2.getResult(pair.second));
} }
rewriter.eraseOp(op); rewriter.eraseOp(op);
op2.after().front().eraseArguments(eraseArgs); op2.getAfter().front().eraseArguments(eraseArgs);
return success(); return success();
} }
return failure(); return failure();
@ -1360,19 +1360,19 @@ struct MoveSideEffectFreeWhile : public OpRewritePattern<WhileOp> {
LogicalResult matchAndRewrite(WhileOp op, LogicalResult matchAndRewrite(WhileOp op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
auto term = cast<scf::ConditionOp>(op.before().front().getTerminator()); scf::ConditionOp term = cast<scf::ConditionOp>(op.getBefore().front().getTerminator());
SmallVector<Value, 4> conds(term.args().begin(), term.args().end()); SmallVector<Value, 4> conds(term.getArgs().begin(), term.getArgs().end());
bool changed = false; bool changed = false;
unsigned i = 0; unsigned i = 0;
for (auto arg : term.args()) { for (auto arg : term.getArgs()) {
if (auto IC = arg.getDefiningOp<IndexCastOp>()) { if (auto IC = arg.getDefiningOp<IndexCastOp>()) {
if (arg.hasOneUse() && op.getResult(i).use_empty()) { if (arg.hasOneUse() && op.getResult(i).use_empty()) {
auto rep = auto rep =
op.after().front().addArgument(IC->getOperand(0).getType()); op.getAfter().front().addArgument(IC->getOperand(0).getType());
IC->moveBefore(&op.after().front(), op.after().front().begin()); IC->moveBefore(&op.getAfter().front(), op.getAfter().front().begin());
conds.push_back(IC.getIn()); conds.push_back(IC.getIn());
IC.getInMutable().assign(rep); IC.getInMutable().assign(rep);
op.after().front().getArgument(i).replaceAllUsesWith( op.getAfter().front().getArgument(i).replaceAllUsesWith(
IC->getResult(0)); IC->getResult(0));
changed = true; changed = true;
} }
@ -1384,9 +1384,9 @@ struct MoveSideEffectFreeWhile : public OpRewritePattern<WhileOp> {
for (auto arg : conds) { for (auto arg : conds) {
tys.push_back(arg.getType()); tys.push_back(arg.getType());
} }
auto op2 = rewriter.create<WhileOp>(op.getLoc(), tys, op.inits()); auto op2 = rewriter.create<WhileOp>(op.getLoc(), tys, op.getInits());
op2.before().takeBody(op.before()); op2.getBefore().takeBody(op.getBefore());
op2.after().takeBody(op.after()); op2.getAfter().takeBody(op.getAfter());
unsigned j = 0; unsigned j = 0;
for (auto a : op.getResults()) { for (auto a : op.getResults()) {
a.replaceAllUsesWith(op2.getResult(j)); a.replaceAllUsesWith(op2.getResult(j));
@ -1394,7 +1394,7 @@ struct MoveSideEffectFreeWhile : public OpRewritePattern<WhileOp> {
} }
rewriter.eraseOp(op); rewriter.eraseOp(op);
rewriter.setInsertionPoint(term); rewriter.setInsertionPoint(term);
rewriter.replaceOpWithNewOp<scf::ConditionOp>(term, term.condition(), rewriter.replaceOpWithNewOp<scf::ConditionOp>(term, term.getCondition(),
conds); conds);
return success(); return success();
} }

View File

@ -295,8 +295,8 @@ bool LoopRestructure::removeIfFromRegion(DominanceInfo &domInfo, Region &region,
/*hasElse*/ true); /*hasElse*/ true);
Succs[j] = new Block(); Succs[j] = new Block();
if (j == 0) { if (j == 0) {
ifOp.elseRegion().getBlocks().splice( ifOp.getElseRegion().getBlocks().splice(
ifOp.elseRegion().getBlocks().end(), region.getBlocks(), ifOp.getElseRegion().getBlocks().end(), region.getBlocks(),
Succs[1 - j]); Succs[1 - j]);
SmallVector<unsigned, 4> idx; SmallVector<unsigned, 4> idx;
for (size_t i = 0; i < Succs[1 - j]->getNumArguments(); ++i) { for (size_t i = 0; i < Succs[1 - j]->getNumArguments(); ++i) {
@ -305,18 +305,18 @@ bool LoopRestructure::removeIfFromRegion(DominanceInfo &domInfo, Region &region,
idx.push_back(i); idx.push_back(i);
} }
Succs[1 - j]->eraseArguments(idx); Succs[1 - j]->eraseArguments(idx);
assert(!ifOp.elseRegion().getBlocks().empty()); assert(!ifOp.getElseRegion().getBlocks().empty());
assert(condTys.size() == condBr.getTrueOperands().size()); assert(condTys.size() == condBr.getTrueOperands().size());
OpBuilder tbuilder(&ifOp.thenRegion().front(), OpBuilder tbuilder(&ifOp.getThenRegion().front(),
ifOp.thenRegion().front().begin()); ifOp.getThenRegion().front().begin());
tbuilder.create<scf::YieldOp>(tbuilder.getUnknownLoc(), emptyTys, tbuilder.create<scf::YieldOp>(tbuilder.getUnknownLoc(), emptyTys,
condBr.getTrueOperands()); condBr.getTrueOperands());
} else { } else {
if (!ifOp.thenRegion().getBlocks().empty()) { if (!ifOp.getThenRegion().getBlocks().empty()) {
ifOp.thenRegion().front().erase(); ifOp.getThenRegion().front().erase();
} }
ifOp.thenRegion().getBlocks().splice( ifOp.getThenRegion().getBlocks().splice(
ifOp.thenRegion().getBlocks().end(), region.getBlocks(), ifOp.getThenRegion().getBlocks().end(), region.getBlocks(),
Succs[1 - j]); Succs[1 - j]);
SmallVector<unsigned, 4> idx; SmallVector<unsigned, 4> idx;
for (size_t i = 0; i < Succs[1 - j]->getNumArguments(); ++i) { for (size_t i = 0; i < Succs[1 - j]->getNumArguments(); ++i) {
@ -325,9 +325,9 @@ bool LoopRestructure::removeIfFromRegion(DominanceInfo &domInfo, Region &region,
idx.push_back(i); idx.push_back(i);
} }
Succs[1 - j]->eraseArguments(idx); Succs[1 - j]->eraseArguments(idx);
assert(!ifOp.elseRegion().getBlocks().empty()); assert(!ifOp.getElseRegion().getBlocks().empty());
OpBuilder tbuilder(&ifOp.elseRegion().front(), OpBuilder tbuilder(&ifOp.getElseRegion().front(),
ifOp.elseRegion().front().begin()); ifOp.getElseRegion().front().begin());
assert(condTys.size() == condBr.getFalseOperands().size()); assert(condTys.size() == condBr.getFalseOperands().size());
tbuilder.create<scf::YieldOp>(tbuilder.getUnknownLoc(), emptyTys, tbuilder.create<scf::YieldOp>(tbuilder.getUnknownLoc(), emptyTys,
condBr.getFalseOperands()); condBr.getFalseOperands());
@ -449,12 +449,12 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region &region) {
Preds.push_back(block); Preds.push_back(block);
} }
loop.before().getBlocks().splice(loop.before().getBlocks().begin(), loop.getBefore().getBlocks().splice(loop.getBefore().getBlocks().begin(),
region.getBlocks(), header); region.getBlocks(), header);
for (auto *w : L->getBlocks()) { for (auto *w : L->getBlocks()) {
Block *b = &**w; Block *b = &**w;
if (b != header) { if (b != header) {
loop.before().getBlocks().splice(loop.before().getBlocks().end(), loop.getBefore().getBlocks().splice(loop.getBefore().getBlocks().end(),
region.getBlocks(), b); region.getBlocks(), b);
} }
} }
@ -462,7 +462,7 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region &region) {
Block *pseudoExit = new Block(); Block *pseudoExit = new Block();
auto i1Ty = builder.getI1Type(); auto i1Ty = builder.getI1Type();
{ {
loop.before().push_back(pseudoExit); loop.getBefore().push_back(pseudoExit);
SmallVector<Type, 4> tys = {i1Ty}; SmallVector<Type, 4> tys = {i1Ty};
for (auto t : combinedTypes) for (auto t : combinedTypes)
tys.push_back(t); tys.push_back(t);
@ -592,7 +592,7 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region &region) {
Block *after = new Block(); Block *after = new Block();
after->addArguments(combinedTypes); after->addArguments(combinedTypes);
loop.after().push_back(after); loop.getAfter().push_back(after);
OpBuilder builder2(after, after->begin()); OpBuilder builder2(after, after->begin());
SmallVector<Value, 4> yieldargs; SmallVector<Value, 4> yieldargs;
for (auto a : after->getArguments()) { for (auto a : after->getArguments()) {
@ -619,42 +619,42 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region &region) {
} }
builder2.create<scf::YieldOp>(builder.getUnknownLoc(), yieldargs); builder2.create<scf::YieldOp>(builder.getUnknownLoc(), yieldargs);
domInfo.invalidate(&loop.before()); domInfo.invalidate(&loop.getBefore());
runOnRegion(domInfo, loop.before()); runOnRegion(domInfo, loop.getBefore());
if (!removeIfFromRegion(domInfo, loop.before(), pseudoExit)) { if (!removeIfFromRegion(domInfo, loop.getBefore(), pseudoExit)) {
attemptToFoldIntoPredecessor(pseudoExit); attemptToFoldIntoPredecessor(pseudoExit);
} }
attemptToFoldIntoPredecessor(wrapper); attemptToFoldIntoPredecessor(wrapper);
attemptToFoldIntoPredecessor(target); attemptToFoldIntoPredecessor(target);
if (loop.before().getBlocks().size() != 1) { if (loop.getBefore().getBlocks().size() != 1) {
Block *blk = new Block(); Block *blk = new Block();
OpBuilder B(loop.getContext()); OpBuilder B(loop.getContext());
B.setInsertionPointToEnd(blk); B.setInsertionPointToEnd(blk);
auto cop = auto cop =
cast<scf::ConditionOp>(loop.before().getBlocks().back().back()); cast<scf::ConditionOp>(loop.getBefore().getBlocks().back().back());
auto er = B.create<scf::ExecuteRegionOp>(loop.getLoc(), auto er = B.create<scf::ExecuteRegionOp>(loop.getLoc(),
cop.getOperandTypes()); cop.getOperandTypes());
er.region().getBlocks().splice(er.region().getBlocks().begin(), er.getRegion().getBlocks().splice(er.getRegion().getBlocks().begin(),
loop.before().getBlocks()); loop.getBefore().getBlocks());
loop.before().push_back(blk); loop.getBefore().push_back(blk);
SmallVector<Value> yields; SmallVector<Value> yields;
for (auto a : er.getResults()) for (auto a : er.getResults())
yields.push_back(a); yields.push_back(a);
yields.erase(yields.begin()); yields.erase(yields.begin());
B.create<scf::ConditionOp>(cop.getLoc(), er.getResult(0), yields); B.create<scf::ConditionOp>(cop.getLoc(), er.getResult(0), yields);
B.setInsertionPoint(&*cop); B.setInsertionPoint(&*cop);
for (auto arg : er.region().front().getArguments()) { for (auto arg : er.getRegion().front().getArguments()) {
auto na = blk->addArgument(arg.getType()); auto na = blk->addArgument(arg.getType());
arg.replaceAllUsesWith(na); arg.replaceAllUsesWith(na);
} }
er.region().front().eraseArguments([](BlockArgument) { return true; }); er.getRegion().front().eraseArguments([](BlockArgument) { return true; });
B.create<scf::YieldOp>(cop.getLoc(), cop.getOperands()); B.create<scf::YieldOp>(cop.getLoc(), cop.getOperands());
cop.erase(); cop.erase();
} }
assert(loop.before().getBlocks().size() == 1); assert(loop.getBefore().getBlocks().size() == 1);
runOnRegion(domInfo, loop.after()); runOnRegion(domInfo, loop.getAfter());
assert(loop.after().getBlocks().size() == 1); assert(loop.getAfter().getBlocks().size() == 1);
} }
} }

View File

@ -453,7 +453,7 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector<ssize_t> idx,
for (auto a : ops) { for (auto a : ops) {
if (StoringOperations.count(a)) { if (StoringOperations.count(a)) {
if (auto exOp = dyn_cast<mlir::scf::ExecuteRegionOp>(a)) { if (auto exOp = dyn_cast<mlir::scf::ExecuteRegionOp>(a)) {
valueAtStartOfBlock[&*exOp.region().begin()] = lastVal; valueAtStartOfBlock[&*exOp.getRegion().begin()] = lastVal;
Value thenVal; // = handleBlock(exOp.region().front(), lastVal); Value thenVal; // = handleBlock(exOp.region().front(), lastVal);
lastVal = nullptr; lastVal = nullptr;
seenSubStore = true; seenSubStore = true;
@ -477,7 +477,7 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector<ssize_t> idx,
continue; continue;
} }
Block &then = exOp.region().back(); Block &then = exOp.getRegion().back();
OpBuilder B(exOp.getContext()); OpBuilder B(exOp.getContext());
auto yieldOp = cast<mlir::scf::YieldOp>(then.back()); auto yieldOp = cast<mlir::scf::YieldOp>(then.back());
B.setInsertionPoint(yieldOp); B.setInsertionPoint(yieldOp);
@ -502,13 +502,13 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector<ssize_t> idx,
auto nextIf = auto nextIf =
B.create<mlir::scf::ExecuteRegionOp>(exOp.getLoc(), tys); B.create<mlir::scf::ExecuteRegionOp>(exOp.getLoc(), tys);
SmallVector<mlir::Value, 4> thenVals = yieldOp.results(); SmallVector<mlir::Value, 4> thenVals = yieldOp.getResults();
thenVals.push_back(thenVal); thenVals.push_back(thenVal);
yieldOp->setOperands(thenVals); yieldOp->setOperands(thenVals);
nextIf.region().getBlocks().clear(); nextIf.getRegion().getBlocks().clear();
nextIf.region().getBlocks().splice( nextIf.getRegion().getBlocks().splice(
nextIf.region().getBlocks().begin(), nextIf.getRegion().getBlocks().begin(),
exOp.region().getBlocks()); exOp.getRegion().getBlocks());
SmallVector<mlir::Value, 3> resvals = (nextIf.getResults()); SmallVector<mlir::Value, 3> resvals = (nextIf.getResults());
lastVal = resvals.back(); lastVal = resvals.back();
@ -559,15 +559,15 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector<ssize_t> idx,
lastVal = newLoad->getResult(0); lastVal = newLoad->getResult(0);
} }
valueAtStartOfBlock[&*ifOp.thenRegion().begin()] = lastVal; valueAtStartOfBlock[&*ifOp.getThenRegion().begin()] = lastVal;
mlir::Value thenVal = mlir::Value thenVal =
handleBlock(*ifOp.thenRegion().begin(), lastVal); handleBlock(*ifOp.getThenRegion().begin(), lastVal);
if (lastVal && ifOp.elseRegion().getBlocks().size()) if (lastVal && ifOp.getElseRegion().getBlocks().size())
valueAtStartOfBlock[&*ifOp.elseRegion().begin()] = lastVal; valueAtStartOfBlock[&*ifOp.getElseRegion().begin()] = lastVal;
mlir::Value elseVal = mlir::Value elseVal =
(ifOp.elseRegion().getBlocks().size()) (ifOp.getElseRegion().getBlocks().size())
? handleBlock(*ifOp.elseRegion().begin(), lastVal) ? handleBlock(*ifOp.getElseRegion().begin(), lastVal)
: lastVal; : lastVal;
if (thenVal == elseVal && thenVal != nullptr) { if (thenVal == elseVal && thenVal != nullptr) {
@ -582,41 +582,37 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector<ssize_t> idx,
ifOp.getResultTypes().end()); ifOp.getResultTypes().end());
tys.push_back(thenVal.getType()); tys.push_back(thenVal.getType());
auto nextIf = B.create<mlir::scf::IfOp>( auto nextIf = B.create<mlir::scf::IfOp>(
ifOp.getLoc(), tys, ifOp.condition(), /*hasElse*/ true); ifOp.getLoc(), tys, ifOp.getCondition(), /*hasElse*/ true);
Block &then = ifOp.thenRegion().back(); Block &then = ifOp.getThenRegion().back();
SmallVector<mlir::Value, 4> thenVals = SmallVector<mlir::Value, 4> thenVals =
cast<mlir::scf::YieldOp>(then.back()).results(); cast<mlir::scf::YieldOp>(then.back()).getResults();
thenVals.push_back(thenVal); thenVals.push_back(thenVal);
nextIf.thenRegion().getBlocks().clear(); nextIf.getThenRegion().getBlocks().clear();
nextIf.thenRegion().getBlocks().splice( nextIf.getThenRegion().takeBody(ifOp.getThenRegion());
nextIf.thenRegion().getBlocks().begin(),
ifOp.thenRegion().getBlocks());
cast<mlir::scf::YieldOp>( cast<mlir::scf::YieldOp>(
nextIf.thenRegion().back().getTerminator()) nextIf.getThenRegion().back().getTerminator())
->setOperands(thenVals); ->setOperands(thenVals);
if (ifOp.elseRegion().getBlocks().size()) { if (ifOp.getElseRegion().getBlocks().size()) {
nextIf.elseRegion().getBlocks().clear(); nextIf.getElseRegion().getBlocks().clear();
SmallVector<mlir::Value, 4> elseVals = SmallVector<mlir::Value, 4> elseVals =
cast<mlir::scf::YieldOp>(ifOp.elseRegion().back().back()) cast<mlir::scf::YieldOp>(ifOp.getElseRegion().back().back())
.results(); .getResults();
elseVals.push_back(elseVal); elseVals.push_back(elseVal);
nextIf.elseRegion().getBlocks().splice( nextIf.getElseRegion().takeBody(ifOp.getElseRegion());
nextIf.elseRegion().getBlocks().begin(),
ifOp.elseRegion().getBlocks());
cast<mlir::scf::YieldOp>( cast<mlir::scf::YieldOp>(
nextIf.elseRegion().back().getTerminator()) nextIf.getElseRegion().back().getTerminator())
->setOperands(elseVals); ->setOperands(elseVals);
} else { } else {
B.setInsertionPoint(&nextIf.elseRegion().back(), B.setInsertionPoint(&nextIf.getElseRegion().back(),
nextIf.elseRegion().back().begin()); nextIf.getElseRegion().back().begin());
SmallVector<mlir::Value, 4> elseVals; SmallVector<mlir::Value, 4> elseVals;
elseVals.push_back(elseVal); elseVals.push_back(elseVal);
B.create<mlir::scf::YieldOp>(ifOp.getLoc(), elseVals); B.create<mlir::scf::YieldOp>(ifOp.getLoc(), elseVals);
} }
SmallVector<mlir::Value, 3> resvals = (nextIf.results()); SmallVector<mlir::Value, 3> resvals = nextIf.getResults();
lastVal = resvals.back(); lastVal = resvals.back();
resvals.pop_back(); resvals.pop_back();
ifOp.replaceAllUsesWith(resvals); ifOp.replaceAllUsesWith(resvals);

View File

@ -101,7 +101,7 @@ struct ReplaceIfWithFors : public OpRewritePattern<scf::IfOp> {
LogicalResult matchAndRewrite(scf::IfOp op, LogicalResult matchAndRewrite(scf::IfOp op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
assert(op.condition().getType().isInteger(1)); assert(op.getCondition().getType().isInteger(1));
if (!hasNestedBarrier(op)) { if (!hasNestedBarrier(op)) {
LLVM_DEBUG(DBGS() << "[if-to-for] no nested barrier\n"); LLVM_DEBUG(DBGS() << "[if-to-for] no nested barrier\n");
@ -119,7 +119,7 @@ struct ReplaceIfWithFors : public OpRewritePattern<scf::IfOp> {
auto cond = rewriter.create<IndexCastOp>( auto cond = rewriter.create<IndexCastOp>(
loc, rewriter.getIndexType(), loc, rewriter.getIndexType(),
rewriter.create<ExtUIOp>(loc, op.condition(), rewriter.create<ExtUIOp>(loc, op.getCondition(),
mlir::IntegerType::get(one.getContext(), 64))); mlir::IntegerType::get(one.getContext(), 64)));
auto thenLoop = rewriter.create<scf::ForOp>(loc, zero, cond, one, forArgs); auto thenLoop = rewriter.create<scf::ForOp>(loc, zero, cond, one, forArgs);
if (forArgs.size() == 0) if (forArgs.size() == 0)
@ -128,7 +128,7 @@ struct ReplaceIfWithFors : public OpRewritePattern<scf::IfOp> {
SmallVector<Value> vals; SmallVector<Value> vals;
if (!op.elseRegion().empty()) { if (!op.getElseRegion().empty()) {
auto negCondition = rewriter.create<SubIOp>(loc, one, cond); auto negCondition = rewriter.create<SubIOp>(loc, one, cond);
scf::ForOp elseLoop = rewriter.create<scf::ForOp>(loc, zero, negCondition, one, forArgs); scf::ForOp elseLoop = rewriter.create<scf::ForOp>(loc, zero, negCondition, one, forArgs);
if (forArgs.size() == 0) if (forArgs.size() == 0)
@ -136,7 +136,7 @@ struct ReplaceIfWithFors : public OpRewritePattern<scf::IfOp> {
rewriter.mergeBlocks(op.getBody(1), elseLoop.getBody(0)); rewriter.mergeBlocks(op.getBody(1), elseLoop.getBody(0));
for (auto tup : llvm::zip(thenLoop.getResults(), elseLoop.getResults())) { for (auto tup : llvm::zip(thenLoop.getResults(), elseLoop.getResults())) {
vals.push_back(rewriter.create<SelectOp>(op.getLoc(), op.condition(), std::get<0>(tup), std::get<1>(tup))); vals.push_back(rewriter.create<SelectOp>(op.getLoc(), op.getCondition(), std::get<0>(tup), std::get<1>(tup)));
} }
} }
@ -153,7 +153,7 @@ static bool isDefinedAbove(Value value, Operation *user) {
/// Returns `true` if the loop has a form expected by interchange patterns. /// Returns `true` if the loop has a form expected by interchange patterns.
static bool isNormalized(scf::ForOp op) { static bool isNormalized(scf::ForOp op) {
return isDefinedAbove(op.lowerBound(), op) && isDefinedAbove(op.step(), op); return isDefinedAbove(op.getLowerBound(), op) && isDefinedAbove(op.getStep(), op);
} }
/// Transforms a loop to the normal form expected by interchange patterns, i.e. /// Transforms a loop to the normal form expected by interchange patterns, i.e.
@ -179,15 +179,15 @@ struct NormalizeLoop : public OpRewritePattern<scf::ForOp> {
rewriter.restoreInsertionPoint(point); rewriter.restoreInsertionPoint(point);
Value difference = Value difference =
rewriter.create<SubIOp>(op.getLoc(), op.upperBound(), op.lowerBound()); rewriter.create<SubIOp>(op.getLoc(), op.getUpperBound(), op.getLowerBound());
Value tripCount = Value tripCount =
rewriter.create<CeilDivSIOp>(op.getLoc(), difference, op.step()); rewriter.create<CeilDivSIOp>(op.getLoc(), difference, op.getStep());
auto newForOp = auto newForOp =
rewriter.create<scf::ForOp>(op.getLoc(), zero, tripCount, one); rewriter.create<scf::ForOp>(op.getLoc(), zero, tripCount, one);
rewriter.setInsertionPointToStart(newForOp.getBody()); rewriter.setInsertionPointToStart(newForOp.getBody());
Value scaled = rewriter.create<MulIOp>( Value scaled = rewriter.create<MulIOp>(
op.getLoc(), newForOp.getInductionVar(), op.step()); op.getLoc(), newForOp.getInductionVar(), op.getStep());
Value iv = rewriter.create<AddIOp>(op.getLoc(), op.lowerBound(), scaled); Value iv = rewriter.create<AddIOp>(op.getLoc(), op.getLowerBound(), scaled);
rewriter.mergeBlockBefore(op.getBody(), &newForOp.getBody()->back(), {iv}); rewriter.mergeBlockBefore(op.getBody(), &newForOp.getBody()->back(), {iv});
rewriter.eraseOp(&newForOp.getBody()->back()); rewriter.eraseOp(&newForOp.getBody()->back());
rewriter.eraseOp(op); rewriter.eraseOp(op);
@ -205,8 +205,8 @@ static bool isNormalized(scf::ParallelOp op) {
APInt value; APInt value;
return matchPattern(v, m_ConstantInt(&value)) && value.isOneValue(); return matchPattern(v, m_ConstantInt(&value)) && value.isOneValue();
}; };
return llvm::all_of(op.lowerBound(), isZero) && return llvm::all_of(op.getLowerBound(), isZero) &&
llvm::all_of(op.step(), isOne); llvm::all_of(op.getStep(), isOne);
} }
/// Transforms a loop to the normal form expected by interchange patterns, i.e. /// Transforms a loop to the normal form expected by interchange patterns, i.e.
@ -242,9 +242,9 @@ struct NormalizeParallel : public OpRewritePattern<scf::ParallelOp> {
rewriter.setInsertionPointToStart(newOp.getBody()); rewriter.setInsertionPointToStart(newOp.getBody());
for (unsigned i = 0, e = iterationCounts.size(); i < e; ++i) { for (unsigned i = 0, e = iterationCounts.size(); i < e; ++i) {
Value scaled = rewriter.create<MulIOp>( Value scaled = rewriter.create<MulIOp>(
op.getLoc(), newOp.getInductionVars()[i], op.step()[i]); op.getLoc(), newOp.getInductionVars()[i], op.getStep()[i]);
Value shifted = Value shifted =
rewriter.create<AddIOp>(op.getLoc(), op.lowerBound()[i], scaled); rewriter.create<AddIOp>(op.getLoc(), op.getLowerBound()[i], scaled);
inductionVars.push_back(shifted); inductionVars.push_back(shifted);
} }
@ -333,7 +333,7 @@ struct WrapForWithBarrier : public OpRewritePattern<scf::ForOp> {
return wrapWithBarriers(op, rewriter, [&](Operation *prevOp) { return wrapWithBarriers(op, rewriter, [&](Operation *prevOp) {
if (auto loadOp = dyn_cast_or_null<memref::LoadOp>(prevOp)) { if (auto loadOp = dyn_cast_or_null<memref::LoadOp>(prevOp)) {
if (loadOp.result() == op.upperBound() && if (loadOp.result() == op.getUpperBound() &&
loadOp.indices() == loadOp.indices() ==
cast<scf::ParallelOp>(op->getParentOp()).getInductionVars()) { cast<scf::ParallelOp>(op->getParentOp()).getInductionVars()) {
prevOp = prevOp->getPrevNode(); prevOp = prevOp->getPrevNode();
@ -353,7 +353,7 @@ struct WrapWhileWithBarrier : public OpRewritePattern<scf::WhileOp> {
LogicalResult matchAndRewrite(scf::WhileOp op, LogicalResult matchAndRewrite(scf::WhileOp op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
if (op.getNumOperands() != 0 || if (op.getNumOperands() != 0 ||
!llvm::hasSingleElement(op.after().front())) { !llvm::hasSingleElement(op.getAfter().front())) {
LLVM_DEBUG(DBGS() << "[wrap-while] ignoring non-rotated loop\n";); LLVM_DEBUG(DBGS() << "[wrap-while] ignoring non-rotated loop\n";);
return failure(); return failure();
} }
@ -372,7 +372,7 @@ static void moveBodies(PatternRewriter &rewriter, scf::ParallelOp op,
OpBuilder::InsertionGuard guard(rewriter); OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(newForLoop.getBody()); rewriter.setInsertionPointToStart(newForLoop.getBody());
auto newParallel = rewriter.create<scf::ParallelOp>( auto newParallel = rewriter.create<scf::ParallelOp>(
op.getLoc(), op.lowerBound(), op.upperBound(), op.step()); op.getLoc(), op.getLowerBound(), op.getUpperBound(), op.getStep());
// Merge in two stages so we can properly replace uses of two induction // Merge in two stages so we can properly replace uses of two induction
// varibales defined in different blocks. // varibales defined in different blocks.
@ -419,8 +419,8 @@ struct InterchangeForPFor : public OpRewritePattern<scf::ParallelOp> {
} }
auto newForLoop = auto newForLoop =
rewriter.create<scf::ForOp>(forLoop.getLoc(), forLoop.lowerBound(), rewriter.create<scf::ForOp>(forLoop.getLoc(), forLoop.getLowerBound(),
forLoop.upperBound(), forLoop.step()); forLoop.getUpperBound(), forLoop.getStep());
moveBodies(rewriter, op, forLoop, newForLoop); moveBodies(rewriter, op, forLoop, newForLoop);
return success(); return success();
} }
@ -446,7 +446,7 @@ struct InterchangeForPForLoad : public OpRewritePattern<scf::ParallelOp> {
} }
auto loadOp = dyn_cast<memref::LoadOp>(op.getBody()->front()); auto loadOp = dyn_cast<memref::LoadOp>(op.getBody()->front());
auto forOp = dyn_cast<scf::ForOp>(op.getBody()->front().getNextNode()); auto forOp = dyn_cast<scf::ForOp>(op.getBody()->front().getNextNode());
if (!loadOp || !forOp || loadOp.result() != forOp.upperBound() || if (!loadOp || !forOp || loadOp.result() != forOp.getUpperBound() ||
loadOp.indices() != op.getInductionVars()) { loadOp.indices() != op.getInductionVars()) {
LLVM_DEBUG(DBGS() << "[interchange-load] expected pfor(load, for)"); LLVM_DEBUG(DBGS() << "[interchange-load] expected pfor(load, for)");
return failure(); return failure();
@ -470,7 +470,7 @@ struct InterchangeForPForLoad : public OpRewritePattern<scf::ParallelOp> {
loadOp.getLoc(), loadOp.getMemRef(), loadOp.getLoc(), loadOp.getMemRef(),
SmallVector<Value>(loadOp.getMemRefType().getRank(), zero)); SmallVector<Value>(loadOp.getMemRefType().getRank(), zero));
auto newForLoop = rewriter.create<scf::ForOp>( auto newForLoop = rewriter.create<scf::ForOp>(
forOp.getLoc(), forOp.lowerBound(), tripCount, forOp.step()); forOp.getLoc(), forOp.getLowerBound(), tripCount, forOp.getStep());
moveBodies(rewriter, op, forOp, newForLoop); moveBodies(rewriter, op, forOp, newForLoop);
return success(); return success();
@ -535,9 +535,9 @@ findInsertionPointAfterLoopOperands(scf::ParallelOp op) {
// Find the earliest insertion point where loop bounds are fully defined. // Find the earliest insertion point where loop bounds are fully defined.
PostDominanceInfo postDominanceInfo(op->getParentOfType<FuncOp>()); PostDominanceInfo postDominanceInfo(op->getParentOfType<FuncOp>());
SmallVector<Value> operands; SmallVector<Value> operands;
llvm::append_range(operands, op.lowerBound()); llvm::append_range(operands, op.getLowerBound());
llvm::append_range(operands, op.upperBound()); llvm::append_range(operands, op.getUpperBound());
llvm::append_range(operands, op.step()); llvm::append_range(operands, op.getStep());
return findNearestPostDominatingInsertionPoint(operands, postDominanceInfo); return findNearestPostDominatingInsertionPoint(operands, postDominanceInfo);
} }
@ -562,7 +562,7 @@ struct InterchangeWhilePFor : public OpRewritePattern<scf::ParallelOp> {
LLVM_DEBUG(DBGS() << "[interchange-while] loop-carried values\n"); LLVM_DEBUG(DBGS() << "[interchange-while] loop-carried values\n");
return failure(); return failure();
} }
if (!llvm::hasSingleElement(whileOp.after().front()) || !isNormalized(op)) { if (!llvm::hasSingleElement(whileOp.getAfter().front()) || !isNormalized(op)) {
LLVM_DEBUG(DBGS() << "[interchange-while] non-normalized loop\n"); LLVM_DEBUG(DBGS() << "[interchange-while] non-normalized loop\n");
return failure(); return failure();
} }
@ -573,37 +573,37 @@ struct InterchangeWhilePFor : public OpRewritePattern<scf::ParallelOp> {
auto newWhileOp = rewriter.create<scf::WhileOp>(whileOp.getLoc(), auto newWhileOp = rewriter.create<scf::WhileOp>(whileOp.getLoc(),
TypeRange(), ValueRange()); TypeRange(), ValueRange());
rewriter.createBlock(&newWhileOp.after()); rewriter.createBlock(&newWhileOp.getAfter());
rewriter.clone(whileOp.after().front().back()); rewriter.clone(whileOp.getAfter().front().back());
rewriter.createBlock(&newWhileOp.before()); rewriter.createBlock(&newWhileOp.getBefore());
auto newParallelOp = rewriter.create<scf::ParallelOp>( auto newParallelOp = rewriter.create<scf::ParallelOp>(
op.getLoc(), op.lowerBound(), op.upperBound(), op.step()); op.getLoc(), op.getLowerBound(), op.getUpperBound(), op.getStep());
auto conditionOp = cast<scf::ConditionOp>(whileOp.before().front().back()); auto conditionOp = cast<scf::ConditionOp>(whileOp.getBefore().front().back());
rewriter.mergeBlockBefore(op.getBody(), &newParallelOp.getBody()->back(), rewriter.mergeBlockBefore(op.getBody(), &newParallelOp.getBody()->back(),
newParallelOp.getInductionVars()); newParallelOp.getInductionVars());
rewriter.eraseOp(newParallelOp.getBody()->back().getPrevNode()); rewriter.eraseOp(newParallelOp.getBody()->back().getPrevNode());
rewriter.mergeBlockBefore(&whileOp.before().front(), rewriter.mergeBlockBefore(&whileOp.getBefore().front(),
&newParallelOp.getBody()->back()); &newParallelOp.getBody()->back());
Operation *conditionDefiningOp = conditionOp.condition().getDefiningOp(); Operation *conditionDefiningOp = conditionOp.getCondition().getDefiningOp();
if (conditionDefiningOp && if (conditionDefiningOp &&
!isDefinedAbove(conditionOp.condition(), conditionOp)) { !isDefinedAbove(conditionOp.getCondition(), conditionOp)) {
std::pair<Block *, Block::iterator> insertionPoint = std::pair<Block *, Block::iterator> insertionPoint =
findInsertionPointAfterLoopOperands(op); findInsertionPointAfterLoopOperands(op);
rewriter.setInsertionPoint(insertionPoint.first, insertionPoint.second); rewriter.setInsertionPoint(insertionPoint.first, insertionPoint.second);
SmallVector<Value> iterationCounts = emitIterationCounts(rewriter, op); SmallVector<Value> iterationCounts = emitIterationCounts(rewriter, op);
Value allocated = allocateTemporaryBuffer<memref::AllocaOp>( Value allocated = allocateTemporaryBuffer<memref::AllocaOp>(
rewriter, conditionOp.condition(), iterationCounts); rewriter, conditionOp.getCondition(), iterationCounts);
Value zero = rewriter.create<ConstantIndexOp>(op.getLoc(), 0); Value zero = rewriter.create<ConstantIndexOp>(op.getLoc(), 0);
rewriter.setInsertionPointAfter(conditionDefiningOp); rewriter.setInsertionPointAfter(conditionDefiningOp);
rewriter.create<memref::StoreOp>(conditionDefiningOp->getLoc(), rewriter.create<memref::StoreOp>(conditionDefiningOp->getLoc(),
conditionOp.condition(), allocated, conditionOp.getCondition(), allocated,
newParallelOp.getInductionVars()); newParallelOp.getInductionVars());
rewriter.setInsertionPointToEnd(&newWhileOp.before().front()); rewriter.setInsertionPointToEnd(&newWhileOp.getBefore().front());
SmallVector<Value> zeros(iterationCounts.size(), zero); SmallVector<Value> zeros(iterationCounts.size(), zero);
Value reloaded = rewriter.create<memref::LoadOp>( Value reloaded = rewriter.create<memref::LoadOp>(
conditionDefiningOp->getLoc(), allocated, zeros); conditionDefiningOp->getLoc(), allocated, zeros);
@ -646,7 +646,7 @@ struct RotateWhile : public OpRewritePattern<scf::WhileOp> {
LogicalResult matchAndRewrite(scf::WhileOp op, LogicalResult matchAndRewrite(scf::WhileOp op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
if (llvm::hasSingleElement(op.after().front())) { if (llvm::hasSingleElement(op.getAfter().front())) {
LLVM_DEBUG(DBGS() << "[rotate-while] the after region is empty"); LLVM_DEBUG(DBGS() << "[rotate-while] the after region is empty");
return failure(); return failure();
} }
@ -659,15 +659,15 @@ struct RotateWhile : public OpRewritePattern<scf::WhileOp> {
return failure(); return failure();
} }
auto condition = cast<scf::ConditionOp>(op.before().front().back()); auto condition = cast<scf::ConditionOp>(op.getBefore().front().back());
rewriter.setInsertionPoint(condition); rewriter.setInsertionPoint(condition);
auto conditional = auto conditional =
rewriter.create<scf::IfOp>(op.getLoc(), condition.condition()); rewriter.create<scf::IfOp>(op.getLoc(), condition.getCondition());
rewriter.mergeBlockBefore(&op.after().front(), rewriter.mergeBlockBefore(&op.getAfter().front(),
&conditional.getBody()->back()); &conditional.getBody()->back());
rewriter.eraseOp(&conditional.getBody()->back()); rewriter.eraseOp(&conditional.getBody()->back());
rewriter.createBlock(&op.after()); rewriter.createBlock(&op.getAfter());
rewriter.clone(conditional.getBody()->back()); rewriter.clone(conditional.getBody()->back());
LLVM_DEBUG(DBGS() << "[rotate-while] done\n"); LLVM_DEBUG(DBGS() << "[rotate-while] done\n");
@ -758,7 +758,7 @@ struct DistributeAroundBarrier : public OpRewritePattern<scf::ParallelOp> {
// Create the second loop. // Create the second loop.
rewriter.setInsertionPointAfter(op); rewriter.setInsertionPointAfter(op);
auto newLoop = rewriter.create<scf::ParallelOp>( auto newLoop = rewriter.create<scf::ParallelOp>(
op.getLoc(), op.lowerBound(), op.upperBound(), op.step()); op.getLoc(), op.getLowerBound(), op.getUpperBound(), op.getStep());
rewriter.eraseOp(&newLoop.getBody()->back()); rewriter.eraseOp(&newLoop.getBody()->back());
for (auto alloc : allocations) for (auto alloc : allocations)
@ -837,8 +837,8 @@ struct Reg2MemFor : public OpRewritePattern<scf::ForOp> {
rewriter.create<memref::StoreOp>(op.getLoc(), operand, alloc, zero); rewriter.create<memref::StoreOp>(op.getLoc(), operand, alloc, zero);
} }
auto newOp = rewriter.create<scf::ForOp>(op.getLoc(), op.lowerBound(), auto newOp = rewriter.create<scf::ForOp>(op.getLoc(), op.getLowerBound(),
op.upperBound(), op.step()); op.getUpperBound(), op.getStep());
rewriter.setInsertionPointToStart(newOp.getBody()); rewriter.setInsertionPointToStart(newOp.getBody());
SmallVector<Value> newRegionArguments; SmallVector<Value> newRegionArguments;
newRegionArguments.push_back(newOp.getInductionVar()); newRegionArguments.push_back(newOp.getInductionVar());
@ -849,7 +849,7 @@ struct Reg2MemFor : public OpRewritePattern<scf::ForOp> {
newRegionArguments); newRegionArguments);
rewriter.setInsertionPoint(newOp.getBody()->getTerminator()); rewriter.setInsertionPoint(newOp.getBody()->getTerminator());
for (auto en : llvm::enumerate(oldTerminator.results())) { for (auto en : llvm::enumerate(oldTerminator.getResults())) {
rewriter.create<memref::StoreOp>(op.getLoc(), en.value(), rewriter.create<memref::StoreOp>(op.getLoc(), en.value(),
allocated[en.index()], zero); allocated[en.index()], zero);
} }
@ -908,35 +908,35 @@ struct Reg2MemWhile : public OpRewritePattern<scf::WhileOp> {
auto newOp = auto newOp =
rewriter.create<scf::WhileOp>(op.getLoc(), TypeRange(), ValueRange()); rewriter.create<scf::WhileOp>(op.getLoc(), TypeRange(), ValueRange());
Block *newBefore = Block *newBefore =
rewriter.createBlock(&newOp.before(), newOp.before().begin()); rewriter.createBlock(&newOp.getBefore(), newOp.getBefore().begin());
SmallVector<Value> newBeforeArguments; SmallVector<Value> newBeforeArguments;
loadValues(op.getLoc(), beforeAllocated, zero, rewriter, loadValues(op.getLoc(), beforeAllocated, zero, rewriter,
newBeforeArguments); newBeforeArguments);
rewriter.mergeBlocks(&op.before().front(), newBefore, newBeforeArguments); rewriter.mergeBlocks(&op.getBefore().front(), newBefore, newBeforeArguments);
auto beforeTerminator = auto beforeTerminator =
cast<scf::ConditionOp>(newOp.before().front().getTerminator()); cast<scf::ConditionOp>(newOp.getBefore().front().getTerminator());
rewriter.setInsertionPoint(beforeTerminator); rewriter.setInsertionPoint(beforeTerminator);
storeValues(op.getLoc(), beforeTerminator.args(), afterAllocated, zero, storeValues(op.getLoc(), beforeTerminator.getArgs(), afterAllocated, zero,
rewriter); rewriter);
rewriter.updateRootInPlace(beforeTerminator, rewriter.updateRootInPlace(beforeTerminator,
[&] { beforeTerminator.argsMutable().clear(); }); [&] { beforeTerminator.getArgsMutable().clear(); });
Block *newAfter = Block *newAfter =
rewriter.createBlock(&newOp.after(), newOp.after().begin()); rewriter.createBlock(&newOp.getAfter(), newOp.getAfter().begin());
SmallVector<Value> newAfterArguments; SmallVector<Value> newAfterArguments;
loadValues(op.getLoc(), afterAllocated, zero, rewriter, newAfterArguments); loadValues(op.getLoc(), afterAllocated, zero, rewriter, newAfterArguments);
rewriter.mergeBlocks(&op.after().front(), newAfter, newAfterArguments); rewriter.mergeBlocks(&op.getAfter().front(), newAfter, newAfterArguments);
auto afterTerminator = auto afterTerminator =
cast<scf::YieldOp>(newOp.after().front().getTerminator()); cast<scf::YieldOp>(newOp.getAfter().front().getTerminator());
rewriter.setInsertionPoint(afterTerminator); rewriter.setInsertionPoint(afterTerminator);
storeValues(op.getLoc(), afterTerminator.results(), beforeAllocated, zero, storeValues(op.getLoc(), afterTerminator.getResults(), beforeAllocated, zero,
rewriter); rewriter);
rewriter.updateRootInPlace( rewriter.updateRootInPlace(
afterTerminator, [&] { afterTerminator.resultsMutable().clear(); }); afterTerminator, [&] { afterTerminator.getResultsMutable().clear(); });
rewriter.setInsertionPointAfter(op); rewriter.setInsertionPointAfter(op);
SmallVector<Value> results; SmallVector<Value> results;

View File

@ -30,7 +30,7 @@ struct ForOpRaising : public OpRewritePattern<scf::ForOp> {
bool isAffine(scf::ForOp loop) const { bool isAffine(scf::ForOp loop) const {
// return true; // return true;
// enforce step to be a ConstantIndexOp (maybe too restrictive). // enforce step to be a ConstantIndexOp (maybe too restrictive).
return isa_and_nonnull<ConstantIndexOp>(loop.step().getDefiningOp()); return isa_and_nonnull<ConstantIndexOp>(loop.getStep().getDefiningOp());
} }
void canonicalizeLoopBounds(AffineForOp forOp) const { void canonicalizeLoopBounds(AffineForOp forOp) const {
@ -67,23 +67,23 @@ struct ForOpRaising : public OpRewritePattern<scf::ForOp> {
if (isAffine(loop)) { if (isAffine(loop)) {
OpBuilder builder(loop); OpBuilder builder(loop);
if (!isValidIndex(loop.lowerBound())) { if (!isValidIndex(loop.getLowerBound())) {
return failure(); return failure();
} }
if (!isValidIndex(loop.upperBound())) { if (!isValidIndex(loop.getUpperBound())) {
return failure(); return failure();
} }
AffineForOp affineLoop = rewriter.create<AffineForOp>( AffineForOp affineLoop = rewriter.create<AffineForOp>(
loop.getLoc(), loop.lowerBound(), builder.getSymbolIdentityMap(), loop.getLoc(), loop.getLowerBound(), builder.getSymbolIdentityMap(),
loop.upperBound(), builder.getSymbolIdentityMap(), loop.getUpperBound(), builder.getSymbolIdentityMap(),
getStep(loop.step()), loop.getIterOperands()); getStep(loop.getStep()), loop.getIterOperands());
canonicalizeLoopBounds(affineLoop); canonicalizeLoopBounds(affineLoop);
auto mergedYieldOp = auto mergedYieldOp =
cast<scf::YieldOp>(loop.region().front().getTerminator()); cast<scf::YieldOp>(loop.getRegion().front().getTerminator());
Block &newBlock = affineLoop.region().front(); Block &newBlock = affineLoop.region().front();
@ -97,10 +97,10 @@ struct ForOpRaising : public OpRewritePattern<scf::ForOp> {
rewriter.updateRootInPlace(loop, [&] { rewriter.updateRootInPlace(loop, [&] {
affineLoop.region().front().getOperations().splice( affineLoop.region().front().getOperations().splice(
affineLoop.region().front().getOperations().begin(), affineLoop.region().front().getOperations().begin(),
loop.region().front().getOperations()); loop.getRegion().front().getOperations());
for (auto pair : llvm::zip(affineLoop.region().front().getArguments(), for (auto pair : llvm::zip(affineLoop.region().front().getArguments(),
loop.region().front().getArguments())) { loop.getRegion().front().getArguments())) {
std::get<1>(pair).replaceAllUsesWith(std::get<0>(pair)); std::get<1>(pair).replaceAllUsesWith(std::get<0>(pair));
} }
}); });

@ -1 +1 @@
Subproject commit fe5137a0fe421ee153f115ae3cd7fb51bba65795 Subproject commit a6a583dae40485cacfac56811e6d9131bac6ca74

View File

@ -172,10 +172,10 @@ void MLIRScanner::buildAffineLoopImpl(
builder.setInsertionPointToEnd(&reg.front()); builder.setInsertionPointToEnd(&reg.front());
auto er = builder.create<scf::ExecuteRegionOp>(loc, ArrayRef<mlir::Type>()); auto er = builder.create<scf::ExecuteRegionOp>(loc, ArrayRef<mlir::Type>());
er.region().push_back(new Block()); er.getRegion().push_back(new Block());
builder.setInsertionPointToStart(&er.region().back()); builder.setInsertionPointToStart(&er.getRegion().back());
builder.create<scf::YieldOp>(loc); builder.create<scf::YieldOp>(loc);
builder.setInsertionPointToStart(&er.region().back()); builder.setInsertionPointToStart(&er.getRegion().back());
if (!descr.getForwardMode()) { if (!descr.getForwardMode()) {
val = builder.create<SubIOp>(loc, val, lb); val = builder.create<SubIOp>(loc, val, lb);
@ -355,15 +355,15 @@ ValueCategory MLIRScanner::VisitOMPParallelForDirective(
auto oldpoint = builder.getInsertionPoint(); auto oldpoint = builder.getInsertionPoint();
auto oldblock = builder.getInsertionBlock(); auto oldblock = builder.getInsertionBlock();
builder.setInsertionPointToStart(&affineOp.region().front()); builder.setInsertionPointToStart(&affineOp.getRegion().front());
auto executeRegion = auto executeRegion =
builder.create<scf::ExecuteRegionOp>(loc, ArrayRef<mlir::Type>()); builder.create<scf::ExecuteRegionOp>(loc, ArrayRef<mlir::Type>());
executeRegion.region().push_back(new Block()); executeRegion.getRegion().push_back(new Block());
builder.setInsertionPointToStart(&executeRegion.region().back()); builder.setInsertionPointToStart(&executeRegion.getRegion().back());
auto oldScope = allocationScope; auto oldScope = allocationScope;
allocationScope = &executeRegion.region().back(); allocationScope = &executeRegion.getRegion().back();
for (auto zp : zip(inds, fors->counters())) { for (auto zp : zip(inds, fors->counters())) {
auto idx = builder.create<IndexCastOp>( auto idx = builder.create<IndexCastOp>(
@ -552,13 +552,13 @@ ValueCategory MLIRScanner::VisitIfStmt(clang::IfStmt *stmt) {
bool hasElseRegion = stmt->getElse(); bool hasElseRegion = stmt->getElse();
auto ifOp = builder.create<mlir::scf::IfOp>(loc, cond, hasElseRegion); auto ifOp = builder.create<mlir::scf::IfOp>(loc, cond, hasElseRegion);
ifOp.thenRegion().back().clear(); ifOp.getThenRegion().back().clear();
builder.setInsertionPointToStart(&ifOp.thenRegion().back()); builder.setInsertionPointToStart(&ifOp.getThenRegion().back());
Visit(stmt->getThen()); Visit(stmt->getThen());
builder.create<scf::YieldOp>(loc); builder.create<scf::YieldOp>(loc);
if (hasElseRegion) { if (hasElseRegion) {
ifOp.elseRegion().back().clear(); ifOp.getElseRegion().back().clear();
builder.setInsertionPointToStart(&ifOp.elseRegion().back()); builder.setInsertionPointToStart(&ifOp.getElseRegion().back());
Visit(stmt->getElse()); Visit(stmt->getElse());
builder.create<scf::YieldOp>(loc); builder.create<scf::YieldOp>(loc);
} }
@ -574,7 +574,7 @@ ValueCategory MLIRScanner::VisitSwitchStmt(clang::SwitchStmt *stmt) {
SmallVector<int64_t> caseVals; SmallVector<int64_t> caseVals;
auto er = builder.create<scf::ExecuteRegionOp>(loc, ArrayRef<mlir::Type>()); auto er = builder.create<scf::ExecuteRegionOp>(loc, ArrayRef<mlir::Type>());
er.region().push_back(new Block()); er.getRegion().push_back(new Block());
auto oldpoint2 = builder.getInsertionPoint(); auto oldpoint2 = builder.getInsertionPoint();
auto oldblock2 = builder.getInsertionBlock(); auto oldblock2 = builder.getInsertionBlock();
@ -613,7 +613,7 @@ ValueCategory MLIRScanner::VisitSwitchStmt(clang::SwitchStmt *stmt) {
} }
inCase = true; inCase = true;
er.region().getBlocks().push_back(&condB); er.getRegion().getBlocks().push_back(&condB);
blocks.push_back(&condB); blocks.push_back(&condB);
builder.setInsertionPointToStart(&condB); builder.setInsertionPointToStart(&condB);
@ -638,7 +638,7 @@ ValueCategory MLIRScanner::VisitSwitchStmt(clang::SwitchStmt *stmt) {
} }
inCase = true; inCase = true;
er.region().getBlocks().push_back(&condB); er.getRegion().getBlocks().push_back(&condB);
builder.setInsertionPointToStart(&condB); builder.setInsertionPointToStart(&condB);
auto i1Ty = builder.getIntegerType(1); auto i1Ty = builder.getIntegerType(1);
@ -668,7 +668,7 @@ ValueCategory MLIRScanner::VisitSwitchStmt(clang::SwitchStmt *stmt) {
loops.pop_back(); loops.pop_back();
builder.create<mlir::BranchOp>(loc, &exitB); builder.create<mlir::BranchOp>(loc, &exitB);
er.region().getBlocks().push_back(&exitB); er.getRegion().getBlocks().push_back(&exitB);
DenseIntElementsAttr caseValuesAttr; DenseIntElementsAttr caseValuesAttr;
ShapedType caseValueType = mlir::VectorType::get( ShapedType caseValueType = mlir::VectorType::get(
@ -694,7 +694,7 @@ ValueCategory MLIRScanner::VisitSwitchStmt(clang::SwitchStmt *stmt) {
caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseVals8); caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseVals8);
} }
builder.setInsertionPointToStart(&er.region().front()); builder.setInsertionPointToStart(&er.getRegion().front());
builder.create<mlir::SwitchOp>( builder.create<mlir::SwitchOp>(
loc, cond, defaultB, ArrayRef<mlir::Value>(), caseValuesAttr, blocks, loc, cond, defaultB, ArrayRef<mlir::Value>(), caseValuesAttr, blocks,
SmallVector<mlir::ValueRange>(caseVals.size(), ArrayRef<mlir::Value>())); SmallVector<mlir::ValueRange>(caseVals.size(), ArrayRef<mlir::Value>()));

View File

@ -21,13 +21,13 @@ IfScope::IfScope(MLIRScanner &scanner) : scanner(scanner), prevBlock(nullptr) {
/*hasElse*/ false); /*hasElse*/ false);
prevBlock = scanner.builder.getInsertionBlock(); prevBlock = scanner.builder.getInsertionBlock();
prevIterator = scanner.builder.getInsertionPoint(); prevIterator = scanner.builder.getInsertionPoint();
ifOp.thenRegion().back().clear(); ifOp.getThenRegion().back().clear();
scanner.builder.setInsertionPointToStart(&ifOp.thenRegion().back()); scanner.builder.setInsertionPointToStart(&ifOp.getThenRegion().back());
auto er = scanner.builder.create<scf::ExecuteRegionOp>( auto er = scanner.builder.create<scf::ExecuteRegionOp>(
scanner.loc, ArrayRef<mlir::Type>()); scanner.loc, ArrayRef<mlir::Type>());
scanner.builder.create<scf::YieldOp>(scanner.loc); scanner.builder.create<scf::YieldOp>(scanner.loc);
er.region().push_back(new Block()); er.getRegion().push_back(new Block());
scanner.builder.setInsertionPointToStart(&er.region().back()); scanner.builder.setInsertionPointToStart(&er.getRegion().back());
} }
} }

View File

@ -68,7 +68,8 @@ static mlir::Value castCallerMemRefArg(mlir::Value callerArg,
if (MemRefType dstTy = calleeArgType.dyn_cast_or_null<MemRefType>()) { if (MemRefType dstTy = calleeArgType.dyn_cast_or_null<MemRefType>()) {
MemRefType srcTy = callerArgType.dyn_cast<MemRefType>(); MemRefType srcTy = callerArgType.dyn_cast<MemRefType>();
if (srcTy && dstTy.getElementType() == srcTy.getElementType()) { if (srcTy && dstTy.getElementType() == srcTy.getElementType()
&& dstTy.getMemorySpace() == srcTy.getMemorySpace()) {
auto srcShape = srcTy.getShape(); auto srcShape = srcTy.getShape();
auto dstShape = dstTy.getShape(); auto dstShape = dstTy.getShape();
@ -748,7 +749,7 @@ ValueCategory MLIRScanner::VisitVarDecl(clang::VarDecl *decl) {
auto ifOp = builder.create<scf::IfOp>(varLoc, cond, /*hasElse*/false); auto ifOp = builder.create<scf::IfOp>(varLoc, cond, /*hasElse*/false);
block = builder.getInsertionBlock(); block = builder.getInsertionBlock();
iter = builder.getInsertionPoint(); iter = builder.getInsertionPoint();
builder.setInsertionPointToStart(&ifOp.thenRegion().back()); builder.setInsertionPointToStart(&ifOp.getThenRegion().back());
builder.create<memref::StoreOp>(varLoc, builder.create<ConstantIntOp>(varLoc, false, 1), boolop, std::vector<mlir::Value>({getConstantIndex(0)})); builder.create<memref::StoreOp>(varLoc, builder.create<ConstantIntOp>(varLoc, false, 1), boolop, std::vector<mlir::Value>({getConstantIndex(0)}));
} }
} else } else
@ -2919,15 +2920,6 @@ ValueCategory MLIRScanner::CallHelper(
} }
} }
assert(val); assert(val);
/*
if (val.getType() != fnType.getInput(i)) {
if (auto MR1 = val.getType().dyn_cast<MemRefType>()) {
if (auto MR2 = fnType.getInput(i).dyn_cast<MemRefType>()) {
val = builder.create<mlir::memref::CastOp>(loc, val, MR2);
}
}
}
*/
args.push_back(val); args.push_back(val);
i++; i++;
} }
@ -3460,7 +3452,7 @@ ValueCategory MLIRScanner::VisitBinaryOperator(clang::BinaryOperator *BO) {
auto oldpoint = builder.getInsertionPoint(); auto oldpoint = builder.getInsertionPoint();
auto oldblock = builder.getInsertionBlock(); auto oldblock = builder.getInsertionBlock();
builder.setInsertionPointToStart(&ifOp.thenRegion().back()); builder.setInsertionPointToStart(&ifOp.getThenRegion().back());
auto rhs = Visit(BO->getRHS()).getValue(builder); auto rhs = Visit(BO->getRHS()).getValue(builder);
assert(rhs != nullptr); assert(rhs != nullptr);
@ -3476,7 +3468,7 @@ ValueCategory MLIRScanner::VisitBinaryOperator(clang::BinaryOperator *BO) {
mlir::Value truearray[] = {rhs}; mlir::Value truearray[] = {rhs};
builder.create<mlir::scf::YieldOp>(loc, truearray); builder.create<mlir::scf::YieldOp>(loc, truearray);
builder.setInsertionPointToStart(&ifOp.elseRegion().back()); builder.setInsertionPointToStart(&ifOp.getElseRegion().back());
mlir::Value falsearray[] = { mlir::Value falsearray[] = {
builder.create<ConstantIntOp>(loc, 0, types[0])}; builder.create<ConstantIntOp>(loc, 0, types[0])};
builder.create<mlir::scf::YieldOp>(loc, falsearray); builder.create<mlir::scf::YieldOp>(loc, falsearray);
@ -3497,12 +3489,12 @@ ValueCategory MLIRScanner::VisitBinaryOperator(clang::BinaryOperator *BO) {
auto oldpoint = builder.getInsertionPoint(); auto oldpoint = builder.getInsertionPoint();
auto oldblock = builder.getInsertionBlock(); auto oldblock = builder.getInsertionBlock();
builder.setInsertionPointToStart(&ifOp.thenRegion().back()); builder.setInsertionPointToStart(&ifOp.getThenRegion().back());
mlir::Value truearray[] = {builder.create<ConstantIntOp>(loc, 1, types[0])}; mlir::Value truearray[] = {builder.create<ConstantIntOp>(loc, 1, types[0])};
builder.create<mlir::scf::YieldOp>(loc, truearray); builder.create<mlir::scf::YieldOp>(loc, truearray);
builder.setInsertionPointToStart(&ifOp.elseRegion().back()); builder.setInsertionPointToStart(&ifOp.getElseRegion().back());
auto rhs = Visit(BO->getRHS()).getValue(builder); auto rhs = Visit(BO->getRHS()).getValue(builder);
if (!rhs.getType().cast<mlir::IntegerType>().isInteger(1)) { if (!rhs.getType().cast<mlir::IntegerType>().isInteger(1)) {
auto postTy = builder.getI1Type(); auto postTy = builder.getI1Type();
@ -4747,7 +4739,7 @@ MLIRScanner::VisitConditionalOperator(clang::ConditionalOperator *E) {
auto oldpoint = builder.getInsertionPoint(); auto oldpoint = builder.getInsertionPoint();
auto oldblock = builder.getInsertionBlock(); auto oldblock = builder.getInsertionBlock();
builder.setInsertionPointToStart(&ifOp.thenRegion().back()); builder.setInsertionPointToStart(&ifOp.getThenRegion().back());
auto trueExpr = Visit(E->getTrueExpr()); auto trueExpr = Visit(E->getTrueExpr());
@ -4778,7 +4770,7 @@ MLIRScanner::VisitConditionalOperator(clang::ConditionalOperator *E) {
builder.create<mlir::scf::YieldOp>(loc, truearray); builder.create<mlir::scf::YieldOp>(loc, truearray);
} }
builder.setInsertionPointToStart(&ifOp.elseRegion().back()); builder.setInsertionPointToStart(&ifOp.getElseRegion().back());
auto falseExpr = Visit(E->getFalseExpr()); auto falseExpr = Visit(E->getFalseExpr());
std::vector<mlir::Value> falsearray; std::vector<mlir::Value> falsearray;
@ -4800,10 +4792,10 @@ MLIRScanner::VisitConditionalOperator(clang::ConditionalOperator *E) {
types[i] = truearray[i].getType(); types[i] = truearray[i].getType();
auto newIfOp = builder.create<mlir::scf::IfOp>(loc, types, cond, auto newIfOp = builder.create<mlir::scf::IfOp>(loc, types, cond,
/*hasElseRegion*/ true); /*hasElseRegion*/ true);
newIfOp.thenRegion().takeBody(ifOp.thenRegion()); newIfOp.getThenRegion().takeBody(ifOp.getThenRegion());
newIfOp.elseRegion().takeBody(ifOp.elseRegion()); newIfOp.getElseRegion().takeBody(ifOp.getElseRegion());
ifOp.erase();
return ValueCategory(newIfOp.getResult(0), /*isReference*/ isReference); return ValueCategory(newIfOp.getResult(0), /*isReference*/ isReference);
// return ifOp;
} }
ValueCategory MLIRScanner::VisitStmtExpr(clang::StmtExpr *stmt) { ValueCategory MLIRScanner::VisitStmtExpr(clang::StmtExpr *stmt) {