Rebase LLVM

This commit is contained in:
William S. Moses 2021-12-30 15:45:42 -05:00 committed by William Moses
parent 85e117eacb
commit 239a2c4d82
13 changed files with 315 additions and 314 deletions

View File

@ -29,7 +29,7 @@ static llvm::SmallVector<mlir::Value>
emitIterationCounts(mlir::OpBuilder &rewriter, mlir::scf::ParallelOp op) {
using namespace mlir;
SmallVector<Value> iterationCounts;
for (auto bounds : llvm::zip(op.lowerBound(), op.upperBound(), op.step())) {
for (auto bounds : llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep())) {
Value lowerBound = std::get<0>(bounds);
Value upperBound = std::get<1>(bounds);
Value step = std::get<2>(bounds);

View File

@ -693,6 +693,9 @@ public:
if (src.source().getType().cast<MemRefType>().getElementType() !=
op.getType().cast<MemRefType>().getElementType())
return failure();
if (src.source().getType().cast<MemRefType>().getMemorySpace() !=
op.getType().cast<MemRefType>().getMemorySpace())
return failure();
rewriter.replaceOpWithNewOp<memref::CastOp>(op, op.getType(), src.source());
return success();
@ -732,6 +735,11 @@ OpFoldResult Memref2PointerOp::fold(ArrayRef<Attribute> operands) {
sourceMutable().assign(mc.source());
return result();
}
if (auto mc = source().getDefiningOp<polygeist::Pointer2MemrefOp>()) {
if (mc.source().getType() == getType()) {
return mc.source();
}
}
return nullptr;
}
@ -868,6 +876,11 @@ OpFoldResult Pointer2MemrefOp::fold(ArrayRef<Attribute> operands) {
sourceMutable().assign(mc.getArg());
return result();
}
if (auto mc = source().getDefiningOp<polygeist::Memref2PointerOp>()) {
if (mc.source().getType() == getType()) {
return mc.source();
}
}
return nullptr;
}

View File

@ -1047,7 +1047,7 @@ void AffineCFGPass::runOnFunction() {
OpBuilder b(ifOp);
AffineIfOp affineIfOp;
std::vector<mlir::Type> types;
for (auto v : ifOp.results()) {
for (auto v : ifOp.getResults()) {
types.push_back(v.getType());
}
@ -1055,7 +1055,7 @@ void AffineCFGPass::runOnFunction() {
SmallVector<bool, 2> eqflags;
SmallVector<Value, 4> applies;
std::deque<Value> todo = {ifOp.condition()};
std::deque<Value> todo = {ifOp.getCondition()};
while (todo.size()) {
auto cur = todo.front();
todo.pop_front();
@ -1077,20 +1077,20 @@ void AffineCFGPass::runOnFunction() {
eqflags);
affineIfOp = b.create<AffineIfOp>(ifOp.getLoc(), types, iset, applies,
/*elseBlock=*/true);
affineIfOp.thenRegion().takeBody(ifOp.thenRegion());
affineIfOp.elseRegion().takeBody(ifOp.elseRegion());
affineIfOp.thenRegion().takeBody(ifOp.getThenRegion());
affineIfOp.elseRegion().takeBody(ifOp.getElseRegion());
for (auto &blk : affineIfOp.thenRegion()) {
if (auto yop = dyn_cast<scf::YieldOp>(blk.getTerminator())) {
OpBuilder b(yop);
b.create<AffineYieldOp>(yop.getLoc(), yop.results());
b.create<AffineYieldOp>(yop.getLoc(), yop.getResults());
yop.erase();
}
}
for (auto &blk : affineIfOp.elseRegion()) {
if (auto yop = dyn_cast<scf::YieldOp>(blk.getTerminator())) {
OpBuilder b(yop);
b.create<AffineYieldOp>(yop.getLoc(), yop.results());
b.create<AffineYieldOp>(yop.getLoc(), yop.getResults());
yop.erase();
}
}

View File

@ -57,14 +57,14 @@ static void wrapPersistingLoopBodies(FuncOp function) {
for (scf::ParallelOp op : loops) {
OpBuilder builder = OpBuilder::atBlockBegin(op.getBody());
auto wrapper = builder.create<scf::ExecuteRegionOp>(
op.getLoc(), op.results().getTypes());
builder.createBlock(&wrapper.region(), wrapper.region().begin());
wrapper.region().front().getOperations().splice(
wrapper.region().front().begin(), op.getBody()->getOperations(),
op.getLoc(), op.getResults().getTypes());
builder.createBlock(&wrapper.getRegion(), wrapper.getRegion().begin());
wrapper.getRegion().front().getOperations().splice(
wrapper.getRegion().front().begin(), op.getBody()->getOperations(),
std::next(op.getBody()->begin()), op.getBody()->end());
builder.setInsertionPointToEnd(op.getBody());
builder.create<scf::YieldOp>(
wrapper.region().front().getTerminator()->getLoc(),
wrapper.getRegion().front().getTerminator()->getLoc(),
wrapper.getResults());
}
}
@ -124,7 +124,7 @@ static LogicalResult splitBlocksWithBarrier(FuncOp function) {
}
splitBlocksWithBarrier(
cast<scf::ExecuteRegionOp>(&op.getBody()->front()).region());
cast<scf::ExecuteRegionOp>(&op.getBody()->front()).getRegion());
return success();
});
return success(!result.wasInterrupted());
@ -288,15 +288,15 @@ emitContinuationCase(Value condition, Value storage, scf::ParallelOp parallel,
bn.create<scf::ExecuteRegionOp>(TypeRange(), ValueRange());
BlockAndValueMapping mapping;
mapping.map(parallel.getInductionVars(), ivs);
replicateIntoRegion(executeRegion.region(), storage, ivs,
parallel.lowerBound(), blocks, subgraphEntryPoints,
replicateIntoRegion(executeRegion.getRegion(), storage, ivs,
parallel.getLowerBound(), blocks, subgraphEntryPoints,
mapping, builder);
};
auto thenBuilder = [&](OpBuilder &nested, Location loc) {
ImplicitLocOpBuilder bn(loc, nested);
bn.create<scf::ParallelOp>(parallel.lowerBound(), parallel.upperBound(),
parallel.step(), parallelBuilder);
bn.create<scf::ParallelOp>(parallel.getLowerBound(), parallel.getUpperBound(),
parallel.getStep(), parallelBuilder);
bn.create<scf::YieldOp>();
};
@ -362,9 +362,9 @@ findInsertionPointAfterLoopOperands(scf::ParallelOp op) {
// Find the earliest insertion point where loop bounds are fully defined.
PostDominanceInfo postDominanceInfo(op->getParentOfType<FuncOp>());
SmallVector<Value> operands;
llvm::append_range(operands, op.lowerBound());
llvm::append_range(operands, op.upperBound());
llvm::append_range(operands, op.step());
llvm::append_range(operands, op.getLowerBound());
llvm::append_range(operands, op.getUpperBound());
llvm::append_range(operands, op.getStep());
return findNesrestPostDominatingInsertionPoint(operands, postDominanceInfo);
}
@ -536,8 +536,8 @@ static void createContinuations(scf::ParallelOp parallel, Value storage) {
llvm::SetVector<Block *> startBlocks;
auto outerExecuteRegion =
cast<scf::ExecuteRegionOp>(&parallel.getBody()->front());
startBlocks.insert(&outerExecuteRegion.region().front());
for (Block &block : outerExecuteRegion.region()) {
startBlocks.insert(&outerExecuteRegion.getRegion().front());
for (Block &block : outerExecuteRegion.getRegion()) {
if (!isa_and_nonnull<polygeist::BarrierOp>(
block.getTerminator()->getPrevNode()))
continue;
@ -560,12 +560,12 @@ static void createContinuations(scf::ParallelOp parallel, Value storage) {
OpBuilder allocBuilder(loop);
reg2mem(subgraphs, parallel, allocBuilder, builder);
builder.createBlock(&loop.before(), loop.before().end());
builder.createBlock(&loop.getBefore(), loop.getBefore().end());
Value next = builder.create<memref::LoadOp>(storage);
Value condition = builder.create<CmpIOp>(CmpIPredicate::ne, next, negOne);
builder.create<scf::ConditionOp>(TypeRange(), condition, ValueRange());
builder.createBlock(&loop.after(), loop.after().end());
builder.createBlock(&loop.getAfter(), loop.getAfter().end());
next = builder.create<memref::LoadOp>(storage);
SmallVector<Value> caseConditions;
caseConditions.resize(startBlocks.size());

View File

@ -52,7 +52,7 @@ static bool hasSameInitValue(Value iter, scf::ForOp forOp) {
if (!cst)
return false;
if (auto cstOp = dyn_cast<ConstantIntOp>(cst)) {
Operation *lbDefOp = forOp.lowerBound().getDefiningOp();
Operation *lbDefOp = forOp.getLowerBound().getDefiningOp();
if (!lbDefOp)
return false;
ConstantIndexOp lb = dyn_cast_or_null<ConstantIndexOp>(lbDefOp);
@ -69,7 +69,7 @@ static bool hasSameStepValue(Value regIter, Value yieldOp, scf::ForOp forOp) {
if (!defOpStep)
return false;
if (auto cstStep = dyn_cast<ConstantIntOp>(defOpStep)) {
Operation *stepForDefOp = forOp.step().getDefiningOp();
Operation *stepForDefOp = forOp.getStep().getDefiningOp();
if (!stepForDefOp)
return false;
ConstantIndexOp stepFor = dyn_cast_or_null<ConstantIndexOp>(stepForDefOp);
@ -123,7 +123,7 @@ struct DetectTrivialIndVarInArgs : public OpRewritePattern<scf::ForOp> {
if (!forOp.getNumIterOperands())
return failure();
Block &block = forOp.region().front();
Block &block = forOp.getRegion().front();
auto yieldOp = cast<scf::YieldOp>(block.getTerminator());
bool matched = false;
@ -150,7 +150,7 @@ struct ForOpInductionReplacement : public OpRewritePattern<scf::ForOp> {
LogicalResult matchAndRewrite(scf::ForOp forOp,
PatternRewriter &rewriter) const final {
bool canonicalize = false;
Block &block = forOp.region().front();
Block &block = forOp.getRegion().front();
auto yieldOp = cast<scf::YieldOp>(block.getTerminator());
for (auto it : llvm::zip(forOp.getIterOperands(), // iter from outside
@ -171,10 +171,10 @@ struct ForOpInductionReplacement : public OpRewritePattern<scf::ForOp> {
bool isOne = false;
if (addOp.getOperand(1) == forOp.step()) {
if (addOp.getOperand(1) == forOp.getStep()) {
legalStep = true;
} else if (auto iter_step =
forOp.step().getDefiningOp<ConstantIndexOp>()) {
forOp.getStep().getDefiningOp<ConstantIndexOp>()) {
isOne |= iter_step.value() == 1;
if (auto op = addOp.getOperand(1).getDefiningOp<ConstantIntOp>()) {
if (op.value() == iter_step.value()) {
@ -200,9 +200,9 @@ struct ForOpInductionReplacement : public OpRewritePattern<scf::ForOp> {
Value init = std::get<0>(it);
if (!std::get<1>(it).use_empty()) {
rewriter.setInsertionPointToStart(&forOp.region().front());
rewriter.setInsertionPointToStart(&forOp.getRegion().front());
Value replacement = rewriter.create<SubIOp>(
forOp.getLoc(), forOp.getInductionVar(), forOp.lowerBound());
forOp.getLoc(), forOp.getInductionVar(), forOp.getLowerBound());
if (!std::get<1>(it).getType().isa<IndexType>()) {
replacement = rewriter.create<IndexCastOp>(
forOp.getLoc(), replacement, std::get<1>(it).getType());
@ -223,7 +223,7 @@ struct ForOpInductionReplacement : public OpRewritePattern<scf::ForOp> {
if (isOne && !std::get<2>(it).use_empty()) {
rewriter.setInsertionPoint(forOp);
Value replacement = rewriter.create<SubIOp>(
forOp.getLoc(), forOp.upperBound(), forOp.lowerBound());
forOp.getLoc(), forOp.getUpperBound(), forOp.getLowerBound());
if (!std::get<1>(it).getType().isa<IndexType>()) {
replacement = rewriter.create<IndexCastOp>(
forOp.getLoc(), replacement, std::get<1>(it).getType());
@ -275,7 +275,7 @@ struct RemoveUnusedArgs : public OpRewritePattern<ForOp> {
return failure();
auto newForOp = rewriter.create<ForOp>(
op.getLoc(), op.lowerBound(), op.upperBound(), op.step(), usedOperands);
op.getLoc(), op.getLowerBound(), op.getUpperBound(), op.getStep(), usedOperands);
if (!newForOp.getBody()->empty())
rewriter.eraseOp(newForOp.getBody()->getTerminator());
@ -499,7 +499,7 @@ yop2.results()[idx]);
bool isWhile(WhileOp wop) {
bool hasCondOp = false;
wop.before().walk([&](Operation *op) {
wop.getBefore().walk([&](Operation *op) {
if (isa<scf::ConditionOp>(op))
hasCondOp = true;
});
@ -547,12 +547,12 @@ struct MoveWhileToFor : public OpRewritePattern<WhileOp> {
} loopInfo;
auto condOp = loop.getConditionOp();
SmallVector<Value, 2> results = {condOp.args()};
auto cmpIOp = condOp.condition().getDefiningOp<CmpIOp>();
SmallVector<Value, 2> results = {condOp.getArgs()};
auto cmpIOp = condOp.getCondition().getDefiningOp<CmpIOp>();
if (!cmpIOp) {
return failure();
}
size_t size = loop.before().front().getOperations().size();
size_t size = loop.getBefore().front().getOperations().size();
if (size != 2) {
return failure();
}
@ -560,24 +560,24 @@ struct MoveWhileToFor : public OpRewritePattern<WhileOp> {
BlockArgument indVar = cmpIOp.getLhs().dyn_cast<BlockArgument>();
if (!indVar)
return failure();
if (indVar.getOwner() != &loop.before().front())
if (indVar.getOwner() != &loop.getBefore().front())
return failure();
SmallVector<size_t, 2> afterArgs;
for (auto pair : llvm::enumerate(condOp.args())) {
for (auto pair : llvm::enumerate(condOp.getArgs())) {
if (pair.value() == indVar)
afterArgs.push_back(pair.index());
}
auto endYield = cast<YieldOp>(loop.after().back().getTerminator());
auto endYield = cast<YieldOp>(loop.getAfter().back().getTerminator());
auto addIOp =
endYield.results()[indVar.getArgNumber()].getDefiningOp<AddIOp>();
endYield.getResults()[indVar.getArgNumber()].getDefiningOp<AddIOp>();
if (!addIOp)
return failure();
for (auto afterArg : afterArgs) {
auto arg = loop.after().getArgument(afterArg);
auto arg = loop.getAfter().getArgument(afterArg);
if (addIOp.getOperand(0) == arg) {
step = addIOp.getOperand(1);
break;
@ -679,19 +679,19 @@ struct MoveWhileToFor : public OpRewritePattern<WhileOp> {
// input of the for goes the input of the scf::while plus the output taken
// from the conditionOp.
SmallVector<Value, 8> forArgs;
forArgs.append(loop.inits().begin(), loop.inits().end());
forArgs.append(loop.getInits().begin(), loop.getInits().end());
for (Value arg : condOp.args()) {
for (Value arg : condOp.getArgs()) {
Type cst = nullptr;
if (auto idx = arg.getDefiningOp<IndexCastOp>()) {
cst = idx.getType();
arg = idx.getIn();
}
Value res;
if (isTopLevelArgValue(arg, &loop.before())) {
if (isTopLevelArgValue(arg, &loop.getBefore())) {
auto blockArg = arg.cast<BlockArgument>();
auto pos = blockArg.getArgNumber();
res = loop.inits()[pos];
res = loop.getInits()[pos];
} else
res = arg;
if (cst) {
@ -706,31 +706,31 @@ struct MoveWhileToFor : public OpRewritePattern<WhileOp> {
if (!forloop.getBody()->empty())
rewriter.eraseOp(forloop.getBody()->getTerminator());
auto oldYield = cast<scf::YieldOp>(loop.after().front().getTerminator());
auto oldYield = cast<scf::YieldOp>(loop.getAfter().front().getTerminator());
rewriter.updateRootInPlace(loop, [&] {
for (auto pair : llvm::zip(loop.after().getArguments(), condOp.args())) {
for (auto pair : llvm::zip(loop.getAfter().getArguments(), condOp.getArgs())) {
std::get<0>(pair).replaceAllUsesWith(std::get<1>(pair));
}
});
loop.after().front().eraseArguments([](BlockArgument) { return true; });
loop.getAfter().front().eraseArguments([](BlockArgument) { return true; });
SmallVector<Value, 2> yieldOperands;
for (auto oldYieldArg : oldYield.results())
for (auto oldYieldArg : oldYield.getResults())
yieldOperands.push_back(oldYieldArg);
BlockAndValueMapping outmap;
outmap.map(loop.before().getArguments(), yieldOperands);
for (auto arg : condOp.args())
outmap.map(loop.getBefore().getArguments(), yieldOperands);
for (auto arg : condOp.getArgs())
yieldOperands.push_back(outmap.lookupOrDefault(arg));
rewriter.setInsertionPoint(oldYield);
rewriter.replaceOpWithNewOp<scf::YieldOp>(oldYield, yieldOperands);
size_t pos = loop.inits().size();
size_t pos = loop.getInits().size();
rewriter.updateRootInPlace(loop, [&] {
for (auto pair : llvm::zip(loop.before().getArguments(),
for (auto pair : llvm::zip(loop.getBefore().getArguments(),
forloop.getRegionIterArgs().drop_back(pos))) {
std::get<0>(pair).replaceAllUsesWith(std::get<1>(pair));
}
@ -738,7 +738,7 @@ struct MoveWhileToFor : public OpRewritePattern<WhileOp> {
forloop.getBody()->getOperations().splice(
forloop.getBody()->getOperations().begin(),
loop.after().front().getOperations());
loop.getAfter().front().getOperations());
SmallVector<Value, 2> replacements;
replacements.append(forloop.getResults().begin() + pos,
@ -754,20 +754,20 @@ struct MoveWhileDown : public OpRewritePattern<WhileOp> {
LogicalResult matchAndRewrite(WhileOp op,
PatternRewriter &rewriter) const override {
auto term = cast<scf::ConditionOp>(op.before().front().getTerminator());
if (auto ifOp = term.condition().getDefiningOp<scf::IfOp>()) {
if (ifOp.getNumResults() != term.args().size() + 1)
auto term = cast<scf::ConditionOp>(op.getBefore().front().getTerminator());
if (auto ifOp = term.getCondition().getDefiningOp<scf::IfOp>()) {
if (ifOp.getNumResults() != term.getArgs().size() + 1)
return failure();
if (ifOp.getResult(0) != term.condition())
if (ifOp.getResult(0) != term.getCondition())
return failure();
for (size_t i = 1; i < ifOp.getNumResults(); ++i) {
if (ifOp.getResult(i) != term.args()[i - 1])
if (ifOp.getResult(i) != term.getArgs()[i - 1])
return failure();
}
auto yield1 =
cast<scf::YieldOp>(ifOp.thenRegion().front().getTerminator());
cast<scf::YieldOp>(ifOp.getThenRegion().front().getTerminator());
auto yield2 =
cast<scf::YieldOp>(ifOp.elseRegion().front().getTerminator());
cast<scf::YieldOp>(ifOp.getElseRegion().front().getTerminator());
if (auto cop = yield1.getOperand(0).getDefiningOp<ConstantIntOp>()) {
if (cop.value() == 0)
return failure();
@ -778,31 +778,31 @@ struct MoveWhileDown : public OpRewritePattern<WhileOp> {
return failure();
} else
return failure();
if (ifOp.elseRegion().front().getOperations().size() != 1)
if (ifOp.getElseRegion().front().getOperations().size() != 1)
return failure();
op.after().front().getOperations().splice(
op.after().front().begin(),
ifOp.thenRegion().front().getOperations());
op.getAfter().front().getOperations().splice(
op.getAfter().front().begin(),
ifOp.getThenRegion().front().getOperations());
rewriter.updateRootInPlace(
term, [&] { term.conditionMutable().assign(ifOp.condition()); });
term, [&] { term.getConditionMutable().assign(ifOp.getCondition()); });
SmallVector<Value, 2> args;
for (size_t i = 1; i < yield2.getNumOperands(); ++i) {
args.push_back(yield2.getOperand(i));
}
rewriter.updateRootInPlace(term,
[&] { term.argsMutable().assign(args); });
[&] { term.getArgsMutable().assign(args); });
rewriter.eraseOp(yield2);
rewriter.eraseOp(ifOp);
for (size_t i = 0; i < op.after().front().getNumArguments(); ++i) {
op.after().front().getArgument(i).replaceAllUsesWith(
for (size_t i = 0; i < op.getAfter().front().getNumArguments(); ++i) {
op.getAfter().front().getArgument(i).replaceAllUsesWith(
yield1.getOperand(i + 1));
}
rewriter.eraseOp(yield1);
// TODO move operands from begin to after
SmallVector<Value> todo(op.before().front().getArguments().begin(),
op.before().front().getArguments().end());
for (auto &op : op.before().front()) {
SmallVector<Value> todo(op.getBefore().front().getArguments().begin(),
op.getBefore().front().getArguments().end());
for (auto &op : op.getBefore().front()) {
for (auto res : op.getResults()) {
todo.push_back(res);
}
@ -810,24 +810,24 @@ struct MoveWhileDown : public OpRewritePattern<WhileOp> {
rewriter.updateRootInPlace(op, [&] {
for (auto val : todo) {
auto na = op.after().front().addArgument(val.getType());
auto na = op.getAfter().front().addArgument(val.getType());
val.replaceUsesWithIf(na, [&](OpOperand &u) -> bool {
return op.after().isAncestor(u.getOwner()->getParentRegion());
return op.getAfter().isAncestor(u.getOwner()->getParentRegion());
});
args.push_back(val);
}
});
rewriter.updateRootInPlace(term,
[&] { term.argsMutable().assign(args); });
[&] { term.getArgsMutable().assign(args); });
SmallVector<Type, 4> tys;
for (auto a : args)
tys.push_back(a.getType());
auto op2 = rewriter.create<WhileOp>(op.getLoc(), tys, op.inits());
op2.before().takeBody(op.before());
op2.after().takeBody(op.after());
auto op2 = rewriter.create<WhileOp>(op.getLoc(), tys, op.getInits());
op2.getBefore().takeBody(op.getBefore());
op2.getAfter().takeBody(op.getAfter());
SmallVector<Value, 4> replacements;
for (auto a : op2.getResults()) {
if (replacements.size() == op.getResults().size())
@ -882,9 +882,9 @@ struct MoveWhileDown2 : public OpRewritePattern<WhileOp> {
LogicalResult matchAndRewrite(WhileOp op,
PatternRewriter &rewriter) const override {
auto term = cast<scf::ConditionOp>(op.before().front().getTerminator());
auto term = cast<scf::ConditionOp>(op.getBefore().front().getTerminator());
if (auto ifOp = dyn_cast_or_null<scf::IfOp>(term->getPrevNode())) {
if (ifOp.condition() != term.condition())
if (ifOp.getCondition() != term.getCondition())
return failure();
SmallVector<std::pair<BlockArgument, Value>, 2> m;
@ -892,14 +892,14 @@ struct MoveWhileDown2 : public OpRewritePattern<WhileOp> {
SmallVector<Value, 2> prevArgs;
SmallVector<std::pair<size_t, Value>, 2> afterYieldRewrites;
auto afterYield = cast<YieldOp>(op.after().front().back());
auto afterYield = cast<YieldOp>(op.getAfter().front().back());
for (auto pair :
llvm::zip(op.getResults(), term.args(), op.getAfterArguments())) {
llvm::zip(op.getResults(), term.getArgs(), op.getAfterArguments())) {
if (std::get<1>(pair).getDefiningOp() == ifOp) {
Value thenYielded, elseYielded;
for (auto p : llvm::zip(ifOp.thenYield().results(), ifOp.results(),
ifOp.elseYield().results())) {
for (auto p : llvm::zip(ifOp.thenYield().getResults(), ifOp.getResults(),
ifOp.elseYield().getResults())) {
if (std::get<1>(pair) == std::get<1>(p)) {
thenYielded = std::get<0>(p);
elseYielded = std::get<2>(p);
@ -911,8 +911,8 @@ struct MoveWhileDown2 : public OpRewritePattern<WhileOp> {
if (!std::get<0>(pair).use_empty()) {
if (auto blockArg = elseYielded.dyn_cast<BlockArgument>())
if (blockArg.getOwner() == &op.before().front()) {
if (afterYield.results()[blockArg.getArgNumber()] ==
if (blockArg.getOwner() == &op.getBefore().front()) {
if (afterYield.getResults()[blockArg.getArgNumber()] ==
std::get<2>(pair) &&
op.getResults()[blockArg.getArgNumber()] ==
std::get<0>(pair)) {
@ -933,18 +933,18 @@ struct MoveWhileDown2 : public OpRewritePattern<WhileOp> {
}
}
SmallVector<Value> yieldArgs = afterYield.results();
SmallVector<Value> yieldArgs = afterYield.getResults();
for (auto pair : afterYieldRewrites) {
yieldArgs[pair.first] = pair.second;
}
rewriter.updateRootInPlace(
afterYield, [&] { afterYield.resultsMutable().assign(yieldArgs); });
afterYield, [&] { afterYield.getResultsMutable().assign(yieldArgs); });
llvm::SetVector<Value> sv;
findValuesUsedBelow(ifOp, sv);
Block *afterB = &op.after().front();
Block *afterB = &op.getAfter().front();
for (auto v : sv) {
condArgs.push_back(v);
@ -956,7 +956,7 @@ struct MoveWhileDown2 : public OpRewritePattern<WhileOp> {
}
rewriter.setInsertionPoint(term);
rewriter.replaceOpWithNewOp<ConditionOp>(term, term.condition(),
rewriter.replaceOpWithNewOp<ConditionOp>(term, term.getCondition(),
condArgs);
for (int i = m.size() - 1; i >= 0; i--) {
@ -977,9 +977,9 @@ struct MoveWhileDown2 : public OpRewritePattern<WhileOp> {
}
rewriter.setInsertionPoint(op);
auto nop = rewriter.create<WhileOp>(op.getLoc(), resultTypes, op.inits());
nop.before().takeBody(op.before());
nop.after().takeBody(op.after());
auto nop = rewriter.create<WhileOp>(op.getLoc(), resultTypes, op.getInits());
nop.getBefore().takeBody(op.getBefore());
nop.getAfter().takeBody(op.getAfter());
rewriter.updateRootInPlace(op, [&] {
for (auto pair : llvm::enumerate(prevArgs)) {
@ -1002,24 +1002,24 @@ struct MoveWhileInvariantIfResult : public OpRewritePattern<WhileOp> {
SmallVector<BlockArgument, 2> origAfterArgs(op.getAfterArguments().begin(),
op.getAfterArguments().end());
bool changed = false;
auto term = cast<scf::ConditionOp>(op.before().front().getTerminator());
scf::ConditionOp term = cast<scf::ConditionOp>(op.getBefore().front().getTerminator());
assert(origAfterArgs.size() == op.getResults().size());
assert(origAfterArgs.size() == term.args().size());
assert(origAfterArgs.size() == term.getArgs().size());
for (auto pair : llvm::zip(op.getResults(), term.args(), origAfterArgs)) {
for (auto pair : llvm::zip(op.getResults(), term.getArgs(), origAfterArgs)) {
if (!std::get<0>(pair).use_empty()) {
if (auto ifOp = std::get<1>(pair).getDefiningOp<scf::IfOp>()) {
if (ifOp.condition() == term.condition()) {
if (ifOp.getCondition() == term.getCondition()) {
ssize_t idx = -1;
for (auto tup : llvm::enumerate(ifOp.results())) {
for (auto tup : llvm::enumerate(ifOp.getResults())) {
if (tup.value() == std::get<1>(pair)) {
idx = tup.index();
break;
}
}
assert(idx != -1);
Value returnWith = ifOp.elseYield().results()[idx];
if (!op.before().isAncestor(returnWith.getParentRegion())) {
Value returnWith = ifOp.elseYield().getResults()[idx];
if (!op.getBefore().isAncestor(returnWith.getParentRegion())) {
rewriter.updateRootInPlace(op, [&] {
std::get<0>(pair).replaceAllUsesWith(returnWith);
});
@ -1042,12 +1042,12 @@ struct WhileLogicalNegation : public OpRewritePattern<WhileOp> {
SmallVector<BlockArgument, 2> origAfterArgs(op.getAfterArguments().begin(),
op.getAfterArguments().end());
bool changed = false;
auto term = cast<scf::ConditionOp>(op.before().front().getTerminator());
scf::ConditionOp term = cast<scf::ConditionOp>(op.getBefore().front().getTerminator());
assert(origAfterArgs.size() == op.getResults().size());
assert(origAfterArgs.size() == term.args().size());
assert(origAfterArgs.size() == term.getArgs().size());
if (auto condCmp = term.condition().getDefiningOp<CmpIOp>()) {
for (auto pair : llvm::zip(op.getResults(), term.args(), origAfterArgs)) {
if (auto condCmp = term.getCondition().getDefiningOp<CmpIOp>()) {
for (auto pair : llvm::zip(op.getResults(), term.getArgs(), origAfterArgs)) {
if (!std::get<0>(pair).use_empty()) {
if (auto termCmp = std::get<1>(pair).getDefiningOp<CmpIOp>()) {
if (termCmp.getLhs() == condCmp.getLhs() &&
@ -1082,41 +1082,41 @@ struct WhileCmpOffset : public OpRewritePattern<WhileOp> {
PatternRewriter &rewriter) const override {
SmallVector<BlockArgument, 2> origAfterArgs(op.getAfterArguments().begin(),
op.getAfterArguments().end());
auto term = cast<scf::ConditionOp>(op.before().front().getTerminator());
scf::ConditionOp term = cast<scf::ConditionOp>(op.getBefore().front().getTerminator());
assert(origAfterArgs.size() == op.getResults().size());
assert(origAfterArgs.size() == term.args().size());
assert(origAfterArgs.size() == term.getArgs().size());
if (auto condCmp = term.condition().getDefiningOp<CmpIOp>()) {
if (auto condCmp = term.getCondition().getDefiningOp<CmpIOp>()) {
if (auto addI = condCmp.getLhs().getDefiningOp<AddIOp>()) {
if (addI.getOperand(1).getDefiningOp() &&
!op.before().isAncestor(
!op.getBefore().isAncestor(
addI.getOperand(1).getDefiningOp()->getParentRegion()))
if (auto blockArg = addI.getOperand(0).dyn_cast<BlockArgument>()) {
if (blockArg.getOwner() == &op.before().front()) {
if (blockArg.getOwner() == &op.getBefore().front()) {
auto rng = llvm::make_early_inc_range(blockArg.getUses());
{
rewriter.setInsertionPoint(op);
SmallVector<Value> oldInits = op.inits();
SmallVector<Value> oldInits = op.getInits();
oldInits[blockArg.getArgNumber()] = rewriter.create<AddIOp>(
addI.getLoc(), oldInits[blockArg.getArgNumber()],
addI.getOperand(1));
op.initsMutable().assign(oldInits);
op.getInitsMutable().assign(oldInits);
rewriter.updateRootInPlace(
addI, [&] { addI.replaceAllUsesWith(blockArg); });
}
YieldOp afterYield = cast<YieldOp>(op.after().front().back());
YieldOp afterYield = cast<YieldOp>(op.getAfter().front().back());
rewriter.setInsertionPoint(afterYield);
SmallVector<Value> oldYields = afterYield.results();
SmallVector<Value> oldYields = afterYield.getResults();
oldYields[blockArg.getArgNumber()] = rewriter.create<AddIOp>(
addI.getLoc(), oldYields[blockArg.getArgNumber()],
addI.getOperand(1));
rewriter.updateRootInPlace(afterYield, [&] {
afterYield.resultsMutable().assign(oldYields);
afterYield.getResultsMutable().assign(oldYields);
});
rewriter.setInsertionPointToStart(&op.before().front());
rewriter.setInsertionPointToStart(&op.getBefore().front());
auto sub = rewriter.create<SubIOp>(addI.getLoc(), blockArg,
addI.getOperand(1));
for (OpOperand &use : rng) {
@ -1139,7 +1139,7 @@ struct MoveWhileDown3 : public OpRewritePattern<WhileOp> {
LogicalResult matchAndRewrite(WhileOp op,
PatternRewriter &rewriter) const override {
auto term = cast<scf::ConditionOp>(op.before().front().getTerminator());
scf::ConditionOp term = cast<scf::ConditionOp>(op.getBefore().front().getTerminator());
SmallVector<unsigned, 2> toErase;
SmallVector<Value, 2> newOps;
SmallVector<Value, 2> condOps;
@ -1147,8 +1147,8 @@ struct MoveWhileDown3 : public OpRewritePattern<WhileOp> {
op.getAfterArguments().end());
SmallVector<Value, 2> returns;
assert(origAfterArgs.size() == op.getResults().size());
assert(origAfterArgs.size() == term.args().size());
for (auto pair : llvm::zip(op.getResults(), term.args(), origAfterArgs)) {
assert(origAfterArgs.size() == term.getArgs().size());
for (auto pair : llvm::zip(op.getResults(), term.getArgs(), origAfterArgs)) {
if (std::get<0>(pair).use_empty()) {
if (std::get<2>(pair).use_empty()) {
toErase.push_back(std::get<2>(pair).getArgNumber());
@ -1162,9 +1162,9 @@ struct MoveWhileDown3 : public OpRewritePattern<WhileOp> {
Operation *cloned = std::get<1>(pair).getDefiningOp();
if (!std::get<1>(pair).hasOneUse()) {
cloned = std::get<1>(pair).getDefiningOp()->clone();
op.after().front().push_front(cloned);
op.getAfter().front().push_front(cloned);
} else {
cloned->moveBefore(&op.after().front().front());
cloned->moveBefore(&op.getAfter().front().front());
}
rewriter.updateRootInPlace(std::get<1>(pair).getDefiningOp(), [&] {
std::get<2>(pair).replaceAllUsesWith(cloned->getResult(0));
@ -1174,7 +1174,7 @@ struct MoveWhileDown3 : public OpRewritePattern<WhileOp> {
llvm::make_early_inc_range(cloned->getOpOperands())) {
{
newOps.push_back(o.get());
o.set(op.after().front().addArgument(o.get().getType()));
o.set(op.getAfter().front().addArgument(o.get().getType()));
}
}
continue;
@ -1190,19 +1190,19 @@ struct MoveWhileDown3 : public OpRewritePattern<WhileOp> {
condOps.append(newOps.begin(), newOps.end());
rewriter.updateRootInPlace(
term, [&] { op.after().front().eraseArguments(toErase); });
term, [&] { op.getAfter().front().eraseArguments(toErase); });
rewriter.setInsertionPoint(term);
rewriter.replaceOpWithNewOp<ConditionOp>(term, term.condition(), condOps);
rewriter.replaceOpWithNewOp<ConditionOp>(term, term.getCondition(), condOps);
rewriter.setInsertionPoint(op);
SmallVector<Type, 4> resultTypes;
for (auto v : condOps) {
resultTypes.push_back(v.getType());
}
auto nop = rewriter.create<WhileOp>(op.getLoc(), resultTypes, op.inits());
auto nop = rewriter.create<WhileOp>(op.getLoc(), resultTypes, op.getInits());
nop.before().takeBody(op.before());
nop.after().takeBody(op.after());
nop.getBefore().takeBody(op.getBefore());
nop.getAfter().takeBody(op.getAfter());
rewriter.updateRootInPlace(op, [&] {
for (auto pair : llvm::enumerate(returns)) {
@ -1210,7 +1210,7 @@ struct MoveWhileDown3 : public OpRewritePattern<WhileOp> {
}
});
assert(resultTypes.size() == nop.after().front().getNumArguments());
assert(resultTypes.size() == nop.getAfter().front().getNumArguments());
assert(resultTypes.size() == condOps.size());
rewriter.eraseOp(op);
@ -1269,7 +1269,7 @@ struct WhileLICM : public OpRewritePattern<WhileOp> {
auto definingOp = value.getDefiningOp();
bool definedOutside =
(definingOp && !!willBeMovedSet.count(definingOp)) ||
!op.before().isAncestor(value.getParentRegion());
!op.getBefore().isAncestor(value.getParentRegion());
return definedOutside;
};
@ -1277,7 +1277,7 @@ struct WhileLICM : public OpRewritePattern<WhileOp> {
// hoist operations from there. These regions might have semantics unknown
// to this rewriting. If the nested regions are loops, they will have been
// processed.
for (auto &block : op.before()) {
for (auto &block : op.getBefore()) {
for (auto &op : block.without_terminator()) {
bool legal = canBeHoisted(&op, isDefinedOutsideOfBody);
if (legal) {
@ -1299,7 +1299,7 @@ struct RemoveUnusedCondVar : public OpRewritePattern<WhileOp> {
LogicalResult matchAndRewrite(WhileOp op,
PatternRewriter &rewriter) const override {
auto term = cast<scf::ConditionOp>(op.before().front().getTerminator());
auto term = cast<scf::ConditionOp>(op.getBefore().front().getTerminator());
SmallVector<Value, 4> conds;
SmallVector<unsigned, 4> eraseArgs;
SmallVector<unsigned, 4> keepArgs;
@ -1308,7 +1308,7 @@ struct RemoveUnusedCondVar : public OpRewritePattern<WhileOp> {
std::map<void *, unsigned> valueOffsets;
std::map<unsigned, unsigned> resultOffsets;
SmallVector<Value, 4> resultArgs;
for (auto pair : llvm::zip(term.args(), op.after().front().getArguments(),
for (auto pair : llvm::zip(term.getArgs(), op.getAfter().front().getArguments(),
op.getResults())) {
auto arg = std::get<0>(pair);
auto afarg = std::get<1>(pair);
@ -1331,24 +1331,24 @@ struct RemoveUnusedCondVar : public OpRewritePattern<WhileOp> {
}
i++;
}
assert(i == op.after().front().getArguments().size());
assert(i == op.getAfter().front().getArguments().size());
if (eraseArgs.size() != 0) {
rewriter.setInsertionPoint(term);
rewriter.replaceOpWithNewOp<scf::ConditionOp>(term, term.condition(),
rewriter.replaceOpWithNewOp<scf::ConditionOp>(term, term.getCondition(),
conds);
rewriter.setInsertionPoint(op);
auto op2 = rewriter.create<WhileOp>(op.getLoc(), tys, op.inits());
auto op2 = rewriter.create<WhileOp>(op.getLoc(), tys, op.getInits());
op2.before().takeBody(op.before());
op2.after().takeBody(op.after());
op2.getBefore().takeBody(op.getBefore());
op2.getAfter().takeBody(op.getAfter());
for (auto pair : resultOffsets) {
op.getResult(pair.first).replaceAllUsesWith(op2.getResult(pair.second));
}
rewriter.eraseOp(op);
op2.after().front().eraseArguments(eraseArgs);
op2.getAfter().front().eraseArguments(eraseArgs);
return success();
}
return failure();
@ -1360,19 +1360,19 @@ struct MoveSideEffectFreeWhile : public OpRewritePattern<WhileOp> {
LogicalResult matchAndRewrite(WhileOp op,
PatternRewriter &rewriter) const override {
auto term = cast<scf::ConditionOp>(op.before().front().getTerminator());
SmallVector<Value, 4> conds(term.args().begin(), term.args().end());
scf::ConditionOp term = cast<scf::ConditionOp>(op.getBefore().front().getTerminator());
SmallVector<Value, 4> conds(term.getArgs().begin(), term.getArgs().end());
bool changed = false;
unsigned i = 0;
for (auto arg : term.args()) {
for (auto arg : term.getArgs()) {
if (auto IC = arg.getDefiningOp<IndexCastOp>()) {
if (arg.hasOneUse() && op.getResult(i).use_empty()) {
auto rep =
op.after().front().addArgument(IC->getOperand(0).getType());
IC->moveBefore(&op.after().front(), op.after().front().begin());
op.getAfter().front().addArgument(IC->getOperand(0).getType());
IC->moveBefore(&op.getAfter().front(), op.getAfter().front().begin());
conds.push_back(IC.getIn());
IC.getInMutable().assign(rep);
op.after().front().getArgument(i).replaceAllUsesWith(
op.getAfter().front().getArgument(i).replaceAllUsesWith(
IC->getResult(0));
changed = true;
}
@ -1384,9 +1384,9 @@ struct MoveSideEffectFreeWhile : public OpRewritePattern<WhileOp> {
for (auto arg : conds) {
tys.push_back(arg.getType());
}
auto op2 = rewriter.create<WhileOp>(op.getLoc(), tys, op.inits());
op2.before().takeBody(op.before());
op2.after().takeBody(op.after());
auto op2 = rewriter.create<WhileOp>(op.getLoc(), tys, op.getInits());
op2.getBefore().takeBody(op.getBefore());
op2.getAfter().takeBody(op.getAfter());
unsigned j = 0;
for (auto a : op.getResults()) {
a.replaceAllUsesWith(op2.getResult(j));
@ -1394,7 +1394,7 @@ struct MoveSideEffectFreeWhile : public OpRewritePattern<WhileOp> {
}
rewriter.eraseOp(op);
rewriter.setInsertionPoint(term);
rewriter.replaceOpWithNewOp<scf::ConditionOp>(term, term.condition(),
rewriter.replaceOpWithNewOp<scf::ConditionOp>(term, term.getCondition(),
conds);
return success();
}

View File

@ -295,8 +295,8 @@ bool LoopRestructure::removeIfFromRegion(DominanceInfo &domInfo, Region &region,
/*hasElse*/ true);
Succs[j] = new Block();
if (j == 0) {
ifOp.elseRegion().getBlocks().splice(
ifOp.elseRegion().getBlocks().end(), region.getBlocks(),
ifOp.getElseRegion().getBlocks().splice(
ifOp.getElseRegion().getBlocks().end(), region.getBlocks(),
Succs[1 - j]);
SmallVector<unsigned, 4> idx;
for (size_t i = 0; i < Succs[1 - j]->getNumArguments(); ++i) {
@ -305,18 +305,18 @@ bool LoopRestructure::removeIfFromRegion(DominanceInfo &domInfo, Region &region,
idx.push_back(i);
}
Succs[1 - j]->eraseArguments(idx);
assert(!ifOp.elseRegion().getBlocks().empty());
assert(!ifOp.getElseRegion().getBlocks().empty());
assert(condTys.size() == condBr.getTrueOperands().size());
OpBuilder tbuilder(&ifOp.thenRegion().front(),
ifOp.thenRegion().front().begin());
OpBuilder tbuilder(&ifOp.getThenRegion().front(),
ifOp.getThenRegion().front().begin());
tbuilder.create<scf::YieldOp>(tbuilder.getUnknownLoc(), emptyTys,
condBr.getTrueOperands());
} else {
if (!ifOp.thenRegion().getBlocks().empty()) {
ifOp.thenRegion().front().erase();
if (!ifOp.getThenRegion().getBlocks().empty()) {
ifOp.getThenRegion().front().erase();
}
ifOp.thenRegion().getBlocks().splice(
ifOp.thenRegion().getBlocks().end(), region.getBlocks(),
ifOp.getThenRegion().getBlocks().splice(
ifOp.getThenRegion().getBlocks().end(), region.getBlocks(),
Succs[1 - j]);
SmallVector<unsigned, 4> idx;
for (size_t i = 0; i < Succs[1 - j]->getNumArguments(); ++i) {
@ -325,9 +325,9 @@ bool LoopRestructure::removeIfFromRegion(DominanceInfo &domInfo, Region &region,
idx.push_back(i);
}
Succs[1 - j]->eraseArguments(idx);
assert(!ifOp.elseRegion().getBlocks().empty());
OpBuilder tbuilder(&ifOp.elseRegion().front(),
ifOp.elseRegion().front().begin());
assert(!ifOp.getElseRegion().getBlocks().empty());
OpBuilder tbuilder(&ifOp.getElseRegion().front(),
ifOp.getElseRegion().front().begin());
assert(condTys.size() == condBr.getFalseOperands().size());
tbuilder.create<scf::YieldOp>(tbuilder.getUnknownLoc(), emptyTys,
condBr.getFalseOperands());
@ -449,12 +449,12 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region &region) {
Preds.push_back(block);
}
loop.before().getBlocks().splice(loop.before().getBlocks().begin(),
loop.getBefore().getBlocks().splice(loop.getBefore().getBlocks().begin(),
region.getBlocks(), header);
for (auto *w : L->getBlocks()) {
Block *b = &**w;
if (b != header) {
loop.before().getBlocks().splice(loop.before().getBlocks().end(),
loop.getBefore().getBlocks().splice(loop.getBefore().getBlocks().end(),
region.getBlocks(), b);
}
}
@ -462,7 +462,7 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region &region) {
Block *pseudoExit = new Block();
auto i1Ty = builder.getI1Type();
{
loop.before().push_back(pseudoExit);
loop.getBefore().push_back(pseudoExit);
SmallVector<Type, 4> tys = {i1Ty};
for (auto t : combinedTypes)
tys.push_back(t);
@ -592,7 +592,7 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region &region) {
Block *after = new Block();
after->addArguments(combinedTypes);
loop.after().push_back(after);
loop.getAfter().push_back(after);
OpBuilder builder2(after, after->begin());
SmallVector<Value, 4> yieldargs;
for (auto a : after->getArguments()) {
@ -619,42 +619,42 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region &region) {
}
builder2.create<scf::YieldOp>(builder.getUnknownLoc(), yieldargs);
domInfo.invalidate(&loop.before());
runOnRegion(domInfo, loop.before());
if (!removeIfFromRegion(domInfo, loop.before(), pseudoExit)) {
domInfo.invalidate(&loop.getBefore());
runOnRegion(domInfo, loop.getBefore());
if (!removeIfFromRegion(domInfo, loop.getBefore(), pseudoExit)) {
attemptToFoldIntoPredecessor(pseudoExit);
}
attemptToFoldIntoPredecessor(wrapper);
attemptToFoldIntoPredecessor(target);
if (loop.before().getBlocks().size() != 1) {
if (loop.getBefore().getBlocks().size() != 1) {
Block *blk = new Block();
OpBuilder B(loop.getContext());
B.setInsertionPointToEnd(blk);
auto cop =
cast<scf::ConditionOp>(loop.before().getBlocks().back().back());
cast<scf::ConditionOp>(loop.getBefore().getBlocks().back().back());
auto er = B.create<scf::ExecuteRegionOp>(loop.getLoc(),
cop.getOperandTypes());
er.region().getBlocks().splice(er.region().getBlocks().begin(),
loop.before().getBlocks());
loop.before().push_back(blk);
er.getRegion().getBlocks().splice(er.getRegion().getBlocks().begin(),
loop.getBefore().getBlocks());
loop.getBefore().push_back(blk);
SmallVector<Value> yields;
for (auto a : er.getResults())
yields.push_back(a);
yields.erase(yields.begin());
B.create<scf::ConditionOp>(cop.getLoc(), er.getResult(0), yields);
B.setInsertionPoint(&*cop);
for (auto arg : er.region().front().getArguments()) {
for (auto arg : er.getRegion().front().getArguments()) {
auto na = blk->addArgument(arg.getType());
arg.replaceAllUsesWith(na);
}
er.region().front().eraseArguments([](BlockArgument) { return true; });
er.getRegion().front().eraseArguments([](BlockArgument) { return true; });
B.create<scf::YieldOp>(cop.getLoc(), cop.getOperands());
cop.erase();
}
assert(loop.before().getBlocks().size() == 1);
runOnRegion(domInfo, loop.after());
assert(loop.after().getBlocks().size() == 1);
assert(loop.getBefore().getBlocks().size() == 1);
runOnRegion(domInfo, loop.getAfter());
assert(loop.getAfter().getBlocks().size() == 1);
}
}

View File

@ -453,7 +453,7 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector<ssize_t> idx,
for (auto a : ops) {
if (StoringOperations.count(a)) {
if (auto exOp = dyn_cast<mlir::scf::ExecuteRegionOp>(a)) {
valueAtStartOfBlock[&*exOp.region().begin()] = lastVal;
valueAtStartOfBlock[&*exOp.getRegion().begin()] = lastVal;
Value thenVal; // = handleBlock(exOp.region().front(), lastVal);
lastVal = nullptr;
seenSubStore = true;
@ -477,7 +477,7 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector<ssize_t> idx,
continue;
}
Block &then = exOp.region().back();
Block &then = exOp.getRegion().back();
OpBuilder B(exOp.getContext());
auto yieldOp = cast<mlir::scf::YieldOp>(then.back());
B.setInsertionPoint(yieldOp);
@ -502,13 +502,13 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector<ssize_t> idx,
auto nextIf =
B.create<mlir::scf::ExecuteRegionOp>(exOp.getLoc(), tys);
SmallVector<mlir::Value, 4> thenVals = yieldOp.results();
SmallVector<mlir::Value, 4> thenVals = yieldOp.getResults();
thenVals.push_back(thenVal);
yieldOp->setOperands(thenVals);
nextIf.region().getBlocks().clear();
nextIf.region().getBlocks().splice(
nextIf.region().getBlocks().begin(),
exOp.region().getBlocks());
nextIf.getRegion().getBlocks().clear();
nextIf.getRegion().getBlocks().splice(
nextIf.getRegion().getBlocks().begin(),
exOp.getRegion().getBlocks());
SmallVector<mlir::Value, 3> resvals = (nextIf.getResults());
lastVal = resvals.back();
@ -559,15 +559,15 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector<ssize_t> idx,
lastVal = newLoad->getResult(0);
}
valueAtStartOfBlock[&*ifOp.thenRegion().begin()] = lastVal;
valueAtStartOfBlock[&*ifOp.getThenRegion().begin()] = lastVal;
mlir::Value thenVal =
handleBlock(*ifOp.thenRegion().begin(), lastVal);
handleBlock(*ifOp.getThenRegion().begin(), lastVal);
if (lastVal && ifOp.elseRegion().getBlocks().size())
valueAtStartOfBlock[&*ifOp.elseRegion().begin()] = lastVal;
if (lastVal && ifOp.getElseRegion().getBlocks().size())
valueAtStartOfBlock[&*ifOp.getElseRegion().begin()] = lastVal;
mlir::Value elseVal =
(ifOp.elseRegion().getBlocks().size())
? handleBlock(*ifOp.elseRegion().begin(), lastVal)
(ifOp.getElseRegion().getBlocks().size())
? handleBlock(*ifOp.getElseRegion().begin(), lastVal)
: lastVal;
if (thenVal == elseVal && thenVal != nullptr) {
@ -582,41 +582,37 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector<ssize_t> idx,
ifOp.getResultTypes().end());
tys.push_back(thenVal.getType());
auto nextIf = B.create<mlir::scf::IfOp>(
ifOp.getLoc(), tys, ifOp.condition(), /*hasElse*/ true);
ifOp.getLoc(), tys, ifOp.getCondition(), /*hasElse*/ true);
Block &then = ifOp.thenRegion().back();
Block &then = ifOp.getThenRegion().back();
SmallVector<mlir::Value, 4> thenVals =
cast<mlir::scf::YieldOp>(then.back()).results();
cast<mlir::scf::YieldOp>(then.back()).getResults();
thenVals.push_back(thenVal);
nextIf.thenRegion().getBlocks().clear();
nextIf.thenRegion().getBlocks().splice(
nextIf.thenRegion().getBlocks().begin(),
ifOp.thenRegion().getBlocks());
nextIf.getThenRegion().getBlocks().clear();
nextIf.getThenRegion().takeBody(ifOp.getThenRegion());
cast<mlir::scf::YieldOp>(
nextIf.thenRegion().back().getTerminator())
nextIf.getThenRegion().back().getTerminator())
->setOperands(thenVals);
if (ifOp.elseRegion().getBlocks().size()) {
nextIf.elseRegion().getBlocks().clear();
if (ifOp.getElseRegion().getBlocks().size()) {
nextIf.getElseRegion().getBlocks().clear();
SmallVector<mlir::Value, 4> elseVals =
cast<mlir::scf::YieldOp>(ifOp.elseRegion().back().back())
.results();
cast<mlir::scf::YieldOp>(ifOp.getElseRegion().back().back())
.getResults();
elseVals.push_back(elseVal);
nextIf.elseRegion().getBlocks().splice(
nextIf.elseRegion().getBlocks().begin(),
ifOp.elseRegion().getBlocks());
nextIf.getElseRegion().takeBody(ifOp.getElseRegion());
cast<mlir::scf::YieldOp>(
nextIf.elseRegion().back().getTerminator())
nextIf.getElseRegion().back().getTerminator())
->setOperands(elseVals);
} else {
B.setInsertionPoint(&nextIf.elseRegion().back(),
nextIf.elseRegion().back().begin());
B.setInsertionPoint(&nextIf.getElseRegion().back(),
nextIf.getElseRegion().back().begin());
SmallVector<mlir::Value, 4> elseVals;
elseVals.push_back(elseVal);
B.create<mlir::scf::YieldOp>(ifOp.getLoc(), elseVals);
}
SmallVector<mlir::Value, 3> resvals = (nextIf.results());
SmallVector<mlir::Value, 3> resvals = nextIf.getResults();
lastVal = resvals.back();
resvals.pop_back();
ifOp.replaceAllUsesWith(resvals);

View File

@ -101,7 +101,7 @@ struct ReplaceIfWithFors : public OpRewritePattern<scf::IfOp> {
LogicalResult matchAndRewrite(scf::IfOp op,
PatternRewriter &rewriter) const override {
assert(op.condition().getType().isInteger(1));
assert(op.getCondition().getType().isInteger(1));
if (!hasNestedBarrier(op)) {
LLVM_DEBUG(DBGS() << "[if-to-for] no nested barrier\n");
@ -119,7 +119,7 @@ struct ReplaceIfWithFors : public OpRewritePattern<scf::IfOp> {
auto cond = rewriter.create<IndexCastOp>(
loc, rewriter.getIndexType(),
rewriter.create<ExtUIOp>(loc, op.condition(),
rewriter.create<ExtUIOp>(loc, op.getCondition(),
mlir::IntegerType::get(one.getContext(), 64)));
auto thenLoop = rewriter.create<scf::ForOp>(loc, zero, cond, one, forArgs);
if (forArgs.size() == 0)
@ -128,7 +128,7 @@ struct ReplaceIfWithFors : public OpRewritePattern<scf::IfOp> {
SmallVector<Value> vals;
if (!op.elseRegion().empty()) {
if (!op.getElseRegion().empty()) {
auto negCondition = rewriter.create<SubIOp>(loc, one, cond);
scf::ForOp elseLoop = rewriter.create<scf::ForOp>(loc, zero, negCondition, one, forArgs);
if (forArgs.size() == 0)
@ -136,7 +136,7 @@ struct ReplaceIfWithFors : public OpRewritePattern<scf::IfOp> {
rewriter.mergeBlocks(op.getBody(1), elseLoop.getBody(0));
for (auto tup : llvm::zip(thenLoop.getResults(), elseLoop.getResults())) {
vals.push_back(rewriter.create<SelectOp>(op.getLoc(), op.condition(), std::get<0>(tup), std::get<1>(tup)));
vals.push_back(rewriter.create<SelectOp>(op.getLoc(), op.getCondition(), std::get<0>(tup), std::get<1>(tup)));
}
}
@ -153,7 +153,7 @@ static bool isDefinedAbove(Value value, Operation *user) {
/// Returns `true` if the loop has a form expected by interchange patterns.
static bool isNormalized(scf::ForOp op) {
return isDefinedAbove(op.lowerBound(), op) && isDefinedAbove(op.step(), op);
return isDefinedAbove(op.getLowerBound(), op) && isDefinedAbove(op.getStep(), op);
}
/// Transforms a loop to the normal form expected by interchange patterns, i.e.
@ -179,15 +179,15 @@ struct NormalizeLoop : public OpRewritePattern<scf::ForOp> {
rewriter.restoreInsertionPoint(point);
Value difference =
rewriter.create<SubIOp>(op.getLoc(), op.upperBound(), op.lowerBound());
rewriter.create<SubIOp>(op.getLoc(), op.getUpperBound(), op.getLowerBound());
Value tripCount =
rewriter.create<CeilDivSIOp>(op.getLoc(), difference, op.step());
rewriter.create<CeilDivSIOp>(op.getLoc(), difference, op.getStep());
auto newForOp =
rewriter.create<scf::ForOp>(op.getLoc(), zero, tripCount, one);
rewriter.setInsertionPointToStart(newForOp.getBody());
Value scaled = rewriter.create<MulIOp>(
op.getLoc(), newForOp.getInductionVar(), op.step());
Value iv = rewriter.create<AddIOp>(op.getLoc(), op.lowerBound(), scaled);
op.getLoc(), newForOp.getInductionVar(), op.getStep());
Value iv = rewriter.create<AddIOp>(op.getLoc(), op.getLowerBound(), scaled);
rewriter.mergeBlockBefore(op.getBody(), &newForOp.getBody()->back(), {iv});
rewriter.eraseOp(&newForOp.getBody()->back());
rewriter.eraseOp(op);
@ -205,8 +205,8 @@ static bool isNormalized(scf::ParallelOp op) {
APInt value;
return matchPattern(v, m_ConstantInt(&value)) && value.isOneValue();
};
return llvm::all_of(op.lowerBound(), isZero) &&
llvm::all_of(op.step(), isOne);
return llvm::all_of(op.getLowerBound(), isZero) &&
llvm::all_of(op.getStep(), isOne);
}
/// Transforms a loop to the normal form expected by interchange patterns, i.e.
@ -242,9 +242,9 @@ struct NormalizeParallel : public OpRewritePattern<scf::ParallelOp> {
rewriter.setInsertionPointToStart(newOp.getBody());
for (unsigned i = 0, e = iterationCounts.size(); i < e; ++i) {
Value scaled = rewriter.create<MulIOp>(
op.getLoc(), newOp.getInductionVars()[i], op.step()[i]);
op.getLoc(), newOp.getInductionVars()[i], op.getStep()[i]);
Value shifted =
rewriter.create<AddIOp>(op.getLoc(), op.lowerBound()[i], scaled);
rewriter.create<AddIOp>(op.getLoc(), op.getLowerBound()[i], scaled);
inductionVars.push_back(shifted);
}
@ -333,7 +333,7 @@ struct WrapForWithBarrier : public OpRewritePattern<scf::ForOp> {
return wrapWithBarriers(op, rewriter, [&](Operation *prevOp) {
if (auto loadOp = dyn_cast_or_null<memref::LoadOp>(prevOp)) {
if (loadOp.result() == op.upperBound() &&
if (loadOp.result() == op.getUpperBound() &&
loadOp.indices() ==
cast<scf::ParallelOp>(op->getParentOp()).getInductionVars()) {
prevOp = prevOp->getPrevNode();
@ -353,7 +353,7 @@ struct WrapWhileWithBarrier : public OpRewritePattern<scf::WhileOp> {
LogicalResult matchAndRewrite(scf::WhileOp op,
PatternRewriter &rewriter) const override {
if (op.getNumOperands() != 0 ||
!llvm::hasSingleElement(op.after().front())) {
!llvm::hasSingleElement(op.getAfter().front())) {
LLVM_DEBUG(DBGS() << "[wrap-while] ignoring non-rotated loop\n";);
return failure();
}
@ -372,7 +372,7 @@ static void moveBodies(PatternRewriter &rewriter, scf::ParallelOp op,
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(newForLoop.getBody());
auto newParallel = rewriter.create<scf::ParallelOp>(
op.getLoc(), op.lowerBound(), op.upperBound(), op.step());
op.getLoc(), op.getLowerBound(), op.getUpperBound(), op.getStep());
// Merge in two stages so we can properly replace uses of two induction
// varibales defined in different blocks.
@ -419,8 +419,8 @@ struct InterchangeForPFor : public OpRewritePattern<scf::ParallelOp> {
}
auto newForLoop =
rewriter.create<scf::ForOp>(forLoop.getLoc(), forLoop.lowerBound(),
forLoop.upperBound(), forLoop.step());
rewriter.create<scf::ForOp>(forLoop.getLoc(), forLoop.getLowerBound(),
forLoop.getUpperBound(), forLoop.getStep());
moveBodies(rewriter, op, forLoop, newForLoop);
return success();
}
@ -446,7 +446,7 @@ struct InterchangeForPForLoad : public OpRewritePattern<scf::ParallelOp> {
}
auto loadOp = dyn_cast<memref::LoadOp>(op.getBody()->front());
auto forOp = dyn_cast<scf::ForOp>(op.getBody()->front().getNextNode());
if (!loadOp || !forOp || loadOp.result() != forOp.upperBound() ||
if (!loadOp || !forOp || loadOp.result() != forOp.getUpperBound() ||
loadOp.indices() != op.getInductionVars()) {
LLVM_DEBUG(DBGS() << "[interchange-load] expected pfor(load, for)");
return failure();
@ -470,7 +470,7 @@ struct InterchangeForPForLoad : public OpRewritePattern<scf::ParallelOp> {
loadOp.getLoc(), loadOp.getMemRef(),
SmallVector<Value>(loadOp.getMemRefType().getRank(), zero));
auto newForLoop = rewriter.create<scf::ForOp>(
forOp.getLoc(), forOp.lowerBound(), tripCount, forOp.step());
forOp.getLoc(), forOp.getLowerBound(), tripCount, forOp.getStep());
moveBodies(rewriter, op, forOp, newForLoop);
return success();
@ -535,9 +535,9 @@ findInsertionPointAfterLoopOperands(scf::ParallelOp op) {
// Find the earliest insertion point where loop bounds are fully defined.
PostDominanceInfo postDominanceInfo(op->getParentOfType<FuncOp>());
SmallVector<Value> operands;
llvm::append_range(operands, op.lowerBound());
llvm::append_range(operands, op.upperBound());
llvm::append_range(operands, op.step());
llvm::append_range(operands, op.getLowerBound());
llvm::append_range(operands, op.getUpperBound());
llvm::append_range(operands, op.getStep());
return findNearestPostDominatingInsertionPoint(operands, postDominanceInfo);
}
@ -562,7 +562,7 @@ struct InterchangeWhilePFor : public OpRewritePattern<scf::ParallelOp> {
LLVM_DEBUG(DBGS() << "[interchange-while] loop-carried values\n");
return failure();
}
if (!llvm::hasSingleElement(whileOp.after().front()) || !isNormalized(op)) {
if (!llvm::hasSingleElement(whileOp.getAfter().front()) || !isNormalized(op)) {
LLVM_DEBUG(DBGS() << "[interchange-while] non-normalized loop\n");
return failure();
}
@ -573,37 +573,37 @@ struct InterchangeWhilePFor : public OpRewritePattern<scf::ParallelOp> {
auto newWhileOp = rewriter.create<scf::WhileOp>(whileOp.getLoc(),
TypeRange(), ValueRange());
rewriter.createBlock(&newWhileOp.after());
rewriter.clone(whileOp.after().front().back());
rewriter.createBlock(&newWhileOp.getAfter());
rewriter.clone(whileOp.getAfter().front().back());
rewriter.createBlock(&newWhileOp.before());
rewriter.createBlock(&newWhileOp.getBefore());
auto newParallelOp = rewriter.create<scf::ParallelOp>(
op.getLoc(), op.lowerBound(), op.upperBound(), op.step());
op.getLoc(), op.getLowerBound(), op.getUpperBound(), op.getStep());
auto conditionOp = cast<scf::ConditionOp>(whileOp.before().front().back());
auto conditionOp = cast<scf::ConditionOp>(whileOp.getBefore().front().back());
rewriter.mergeBlockBefore(op.getBody(), &newParallelOp.getBody()->back(),
newParallelOp.getInductionVars());
rewriter.eraseOp(newParallelOp.getBody()->back().getPrevNode());
rewriter.mergeBlockBefore(&whileOp.before().front(),
rewriter.mergeBlockBefore(&whileOp.getBefore().front(),
&newParallelOp.getBody()->back());
Operation *conditionDefiningOp = conditionOp.condition().getDefiningOp();
Operation *conditionDefiningOp = conditionOp.getCondition().getDefiningOp();
if (conditionDefiningOp &&
!isDefinedAbove(conditionOp.condition(), conditionOp)) {
!isDefinedAbove(conditionOp.getCondition(), conditionOp)) {
std::pair<Block *, Block::iterator> insertionPoint =
findInsertionPointAfterLoopOperands(op);
rewriter.setInsertionPoint(insertionPoint.first, insertionPoint.second);
SmallVector<Value> iterationCounts = emitIterationCounts(rewriter, op);
Value allocated = allocateTemporaryBuffer<memref::AllocaOp>(
rewriter, conditionOp.condition(), iterationCounts);
rewriter, conditionOp.getCondition(), iterationCounts);
Value zero = rewriter.create<ConstantIndexOp>(op.getLoc(), 0);
rewriter.setInsertionPointAfter(conditionDefiningOp);
rewriter.create<memref::StoreOp>(conditionDefiningOp->getLoc(),
conditionOp.condition(), allocated,
conditionOp.getCondition(), allocated,
newParallelOp.getInductionVars());
rewriter.setInsertionPointToEnd(&newWhileOp.before().front());
rewriter.setInsertionPointToEnd(&newWhileOp.getBefore().front());
SmallVector<Value> zeros(iterationCounts.size(), zero);
Value reloaded = rewriter.create<memref::LoadOp>(
conditionDefiningOp->getLoc(), allocated, zeros);
@ -646,7 +646,7 @@ struct RotateWhile : public OpRewritePattern<scf::WhileOp> {
LogicalResult matchAndRewrite(scf::WhileOp op,
PatternRewriter &rewriter) const override {
if (llvm::hasSingleElement(op.after().front())) {
if (llvm::hasSingleElement(op.getAfter().front())) {
LLVM_DEBUG(DBGS() << "[rotate-while] the after region is empty");
return failure();
}
@ -659,15 +659,15 @@ struct RotateWhile : public OpRewritePattern<scf::WhileOp> {
return failure();
}
auto condition = cast<scf::ConditionOp>(op.before().front().back());
auto condition = cast<scf::ConditionOp>(op.getBefore().front().back());
rewriter.setInsertionPoint(condition);
auto conditional =
rewriter.create<scf::IfOp>(op.getLoc(), condition.condition());
rewriter.mergeBlockBefore(&op.after().front(),
rewriter.create<scf::IfOp>(op.getLoc(), condition.getCondition());
rewriter.mergeBlockBefore(&op.getAfter().front(),
&conditional.getBody()->back());
rewriter.eraseOp(&conditional.getBody()->back());
rewriter.createBlock(&op.after());
rewriter.createBlock(&op.getAfter());
rewriter.clone(conditional.getBody()->back());
LLVM_DEBUG(DBGS() << "[rotate-while] done\n");
@ -758,7 +758,7 @@ struct DistributeAroundBarrier : public OpRewritePattern<scf::ParallelOp> {
// Create the second loop.
rewriter.setInsertionPointAfter(op);
auto newLoop = rewriter.create<scf::ParallelOp>(
op.getLoc(), op.lowerBound(), op.upperBound(), op.step());
op.getLoc(), op.getLowerBound(), op.getUpperBound(), op.getStep());
rewriter.eraseOp(&newLoop.getBody()->back());
for (auto alloc : allocations)
@ -837,8 +837,8 @@ struct Reg2MemFor : public OpRewritePattern<scf::ForOp> {
rewriter.create<memref::StoreOp>(op.getLoc(), operand, alloc, zero);
}
auto newOp = rewriter.create<scf::ForOp>(op.getLoc(), op.lowerBound(),
op.upperBound(), op.step());
auto newOp = rewriter.create<scf::ForOp>(op.getLoc(), op.getLowerBound(),
op.getUpperBound(), op.getStep());
rewriter.setInsertionPointToStart(newOp.getBody());
SmallVector<Value> newRegionArguments;
newRegionArguments.push_back(newOp.getInductionVar());
@ -849,7 +849,7 @@ struct Reg2MemFor : public OpRewritePattern<scf::ForOp> {
newRegionArguments);
rewriter.setInsertionPoint(newOp.getBody()->getTerminator());
for (auto en : llvm::enumerate(oldTerminator.results())) {
for (auto en : llvm::enumerate(oldTerminator.getResults())) {
rewriter.create<memref::StoreOp>(op.getLoc(), en.value(),
allocated[en.index()], zero);
}
@ -908,35 +908,35 @@ struct Reg2MemWhile : public OpRewritePattern<scf::WhileOp> {
auto newOp =
rewriter.create<scf::WhileOp>(op.getLoc(), TypeRange(), ValueRange());
Block *newBefore =
rewriter.createBlock(&newOp.before(), newOp.before().begin());
rewriter.createBlock(&newOp.getBefore(), newOp.getBefore().begin());
SmallVector<Value> newBeforeArguments;
loadValues(op.getLoc(), beforeAllocated, zero, rewriter,
newBeforeArguments);
rewriter.mergeBlocks(&op.before().front(), newBefore, newBeforeArguments);
rewriter.mergeBlocks(&op.getBefore().front(), newBefore, newBeforeArguments);
auto beforeTerminator =
cast<scf::ConditionOp>(newOp.before().front().getTerminator());
cast<scf::ConditionOp>(newOp.getBefore().front().getTerminator());
rewriter.setInsertionPoint(beforeTerminator);
storeValues(op.getLoc(), beforeTerminator.args(), afterAllocated, zero,
storeValues(op.getLoc(), beforeTerminator.getArgs(), afterAllocated, zero,
rewriter);
rewriter.updateRootInPlace(beforeTerminator,
[&] { beforeTerminator.argsMutable().clear(); });
[&] { beforeTerminator.getArgsMutable().clear(); });
Block *newAfter =
rewriter.createBlock(&newOp.after(), newOp.after().begin());
rewriter.createBlock(&newOp.getAfter(), newOp.getAfter().begin());
SmallVector<Value> newAfterArguments;
loadValues(op.getLoc(), afterAllocated, zero, rewriter, newAfterArguments);
rewriter.mergeBlocks(&op.after().front(), newAfter, newAfterArguments);
rewriter.mergeBlocks(&op.getAfter().front(), newAfter, newAfterArguments);
auto afterTerminator =
cast<scf::YieldOp>(newOp.after().front().getTerminator());
cast<scf::YieldOp>(newOp.getAfter().front().getTerminator());
rewriter.setInsertionPoint(afterTerminator);
storeValues(op.getLoc(), afterTerminator.results(), beforeAllocated, zero,
storeValues(op.getLoc(), afterTerminator.getResults(), beforeAllocated, zero,
rewriter);
rewriter.updateRootInPlace(
afterTerminator, [&] { afterTerminator.resultsMutable().clear(); });
afterTerminator, [&] { afterTerminator.getResultsMutable().clear(); });
rewriter.setInsertionPointAfter(op);
SmallVector<Value> results;

View File

@ -30,7 +30,7 @@ struct ForOpRaising : public OpRewritePattern<scf::ForOp> {
bool isAffine(scf::ForOp loop) const {
// return true;
// enforce step to be a ConstantIndexOp (maybe too restrictive).
return isa_and_nonnull<ConstantIndexOp>(loop.step().getDefiningOp());
return isa_and_nonnull<ConstantIndexOp>(loop.getStep().getDefiningOp());
}
void canonicalizeLoopBounds(AffineForOp forOp) const {
@ -67,23 +67,23 @@ struct ForOpRaising : public OpRewritePattern<scf::ForOp> {
if (isAffine(loop)) {
OpBuilder builder(loop);
if (!isValidIndex(loop.lowerBound())) {
if (!isValidIndex(loop.getLowerBound())) {
return failure();
}
if (!isValidIndex(loop.upperBound())) {
if (!isValidIndex(loop.getUpperBound())) {
return failure();
}
AffineForOp affineLoop = rewriter.create<AffineForOp>(
loop.getLoc(), loop.lowerBound(), builder.getSymbolIdentityMap(),
loop.upperBound(), builder.getSymbolIdentityMap(),
getStep(loop.step()), loop.getIterOperands());
loop.getLoc(), loop.getLowerBound(), builder.getSymbolIdentityMap(),
loop.getUpperBound(), builder.getSymbolIdentityMap(),
getStep(loop.getStep()), loop.getIterOperands());
canonicalizeLoopBounds(affineLoop);
auto mergedYieldOp =
cast<scf::YieldOp>(loop.region().front().getTerminator());
cast<scf::YieldOp>(loop.getRegion().front().getTerminator());
Block &newBlock = affineLoop.region().front();
@ -97,10 +97,10 @@ struct ForOpRaising : public OpRewritePattern<scf::ForOp> {
rewriter.updateRootInPlace(loop, [&] {
affineLoop.region().front().getOperations().splice(
affineLoop.region().front().getOperations().begin(),
loop.region().front().getOperations());
loop.getRegion().front().getOperations());
for (auto pair : llvm::zip(affineLoop.region().front().getArguments(),
loop.region().front().getArguments())) {
loop.getRegion().front().getArguments())) {
std::get<1>(pair).replaceAllUsesWith(std::get<0>(pair));
}
});

@ -1 +1 @@
Subproject commit fe5137a0fe421ee153f115ae3cd7fb51bba65795
Subproject commit a6a583dae40485cacfac56811e6d9131bac6ca74

View File

@ -172,10 +172,10 @@ void MLIRScanner::buildAffineLoopImpl(
builder.setInsertionPointToEnd(&reg.front());
auto er = builder.create<scf::ExecuteRegionOp>(loc, ArrayRef<mlir::Type>());
er.region().push_back(new Block());
builder.setInsertionPointToStart(&er.region().back());
er.getRegion().push_back(new Block());
builder.setInsertionPointToStart(&er.getRegion().back());
builder.create<scf::YieldOp>(loc);
builder.setInsertionPointToStart(&er.region().back());
builder.setInsertionPointToStart(&er.getRegion().back());
if (!descr.getForwardMode()) {
val = builder.create<SubIOp>(loc, val, lb);
@ -355,15 +355,15 @@ ValueCategory MLIRScanner::VisitOMPParallelForDirective(
auto oldpoint = builder.getInsertionPoint();
auto oldblock = builder.getInsertionBlock();
builder.setInsertionPointToStart(&affineOp.region().front());
builder.setInsertionPointToStart(&affineOp.getRegion().front());
auto executeRegion =
builder.create<scf::ExecuteRegionOp>(loc, ArrayRef<mlir::Type>());
executeRegion.region().push_back(new Block());
builder.setInsertionPointToStart(&executeRegion.region().back());
executeRegion.getRegion().push_back(new Block());
builder.setInsertionPointToStart(&executeRegion.getRegion().back());
auto oldScope = allocationScope;
allocationScope = &executeRegion.region().back();
allocationScope = &executeRegion.getRegion().back();
for (auto zp : zip(inds, fors->counters())) {
auto idx = builder.create<IndexCastOp>(
@ -552,13 +552,13 @@ ValueCategory MLIRScanner::VisitIfStmt(clang::IfStmt *stmt) {
bool hasElseRegion = stmt->getElse();
auto ifOp = builder.create<mlir::scf::IfOp>(loc, cond, hasElseRegion);
ifOp.thenRegion().back().clear();
builder.setInsertionPointToStart(&ifOp.thenRegion().back());
ifOp.getThenRegion().back().clear();
builder.setInsertionPointToStart(&ifOp.getThenRegion().back());
Visit(stmt->getThen());
builder.create<scf::YieldOp>(loc);
if (hasElseRegion) {
ifOp.elseRegion().back().clear();
builder.setInsertionPointToStart(&ifOp.elseRegion().back());
ifOp.getElseRegion().back().clear();
builder.setInsertionPointToStart(&ifOp.getElseRegion().back());
Visit(stmt->getElse());
builder.create<scf::YieldOp>(loc);
}
@ -574,7 +574,7 @@ ValueCategory MLIRScanner::VisitSwitchStmt(clang::SwitchStmt *stmt) {
SmallVector<int64_t> caseVals;
auto er = builder.create<scf::ExecuteRegionOp>(loc, ArrayRef<mlir::Type>());
er.region().push_back(new Block());
er.getRegion().push_back(new Block());
auto oldpoint2 = builder.getInsertionPoint();
auto oldblock2 = builder.getInsertionBlock();
@ -613,7 +613,7 @@ ValueCategory MLIRScanner::VisitSwitchStmt(clang::SwitchStmt *stmt) {
}
inCase = true;
er.region().getBlocks().push_back(&condB);
er.getRegion().getBlocks().push_back(&condB);
blocks.push_back(&condB);
builder.setInsertionPointToStart(&condB);
@ -638,7 +638,7 @@ ValueCategory MLIRScanner::VisitSwitchStmt(clang::SwitchStmt *stmt) {
}
inCase = true;
er.region().getBlocks().push_back(&condB);
er.getRegion().getBlocks().push_back(&condB);
builder.setInsertionPointToStart(&condB);
auto i1Ty = builder.getIntegerType(1);
@ -668,7 +668,7 @@ ValueCategory MLIRScanner::VisitSwitchStmt(clang::SwitchStmt *stmt) {
loops.pop_back();
builder.create<mlir::BranchOp>(loc, &exitB);
er.region().getBlocks().push_back(&exitB);
er.getRegion().getBlocks().push_back(&exitB);
DenseIntElementsAttr caseValuesAttr;
ShapedType caseValueType = mlir::VectorType::get(
@ -694,7 +694,7 @@ ValueCategory MLIRScanner::VisitSwitchStmt(clang::SwitchStmt *stmt) {
caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseVals8);
}
builder.setInsertionPointToStart(&er.region().front());
builder.setInsertionPointToStart(&er.getRegion().front());
builder.create<mlir::SwitchOp>(
loc, cond, defaultB, ArrayRef<mlir::Value>(), caseValuesAttr, blocks,
SmallVector<mlir::ValueRange>(caseVals.size(), ArrayRef<mlir::Value>()));

View File

@ -21,13 +21,13 @@ IfScope::IfScope(MLIRScanner &scanner) : scanner(scanner), prevBlock(nullptr) {
/*hasElse*/ false);
prevBlock = scanner.builder.getInsertionBlock();
prevIterator = scanner.builder.getInsertionPoint();
ifOp.thenRegion().back().clear();
scanner.builder.setInsertionPointToStart(&ifOp.thenRegion().back());
ifOp.getThenRegion().back().clear();
scanner.builder.setInsertionPointToStart(&ifOp.getThenRegion().back());
auto er = scanner.builder.create<scf::ExecuteRegionOp>(
scanner.loc, ArrayRef<mlir::Type>());
scanner.builder.create<scf::YieldOp>(scanner.loc);
er.region().push_back(new Block());
scanner.builder.setInsertionPointToStart(&er.region().back());
er.getRegion().push_back(new Block());
scanner.builder.setInsertionPointToStart(&er.getRegion().back());
}
}

View File

@ -68,7 +68,8 @@ static mlir::Value castCallerMemRefArg(mlir::Value callerArg,
if (MemRefType dstTy = calleeArgType.dyn_cast_or_null<MemRefType>()) {
MemRefType srcTy = callerArgType.dyn_cast<MemRefType>();
if (srcTy && dstTy.getElementType() == srcTy.getElementType()) {
if (srcTy && dstTy.getElementType() == srcTy.getElementType()
&& dstTy.getMemorySpace() == srcTy.getMemorySpace()) {
auto srcShape = srcTy.getShape();
auto dstShape = dstTy.getShape();
@ -748,7 +749,7 @@ ValueCategory MLIRScanner::VisitVarDecl(clang::VarDecl *decl) {
auto ifOp = builder.create<scf::IfOp>(varLoc, cond, /*hasElse*/false);
block = builder.getInsertionBlock();
iter = builder.getInsertionPoint();
builder.setInsertionPointToStart(&ifOp.thenRegion().back());
builder.setInsertionPointToStart(&ifOp.getThenRegion().back());
builder.create<memref::StoreOp>(varLoc, builder.create<ConstantIntOp>(varLoc, false, 1), boolop, std::vector<mlir::Value>({getConstantIndex(0)}));
}
} else
@ -2919,15 +2920,6 @@ ValueCategory MLIRScanner::CallHelper(
}
}
assert(val);
/*
if (val.getType() != fnType.getInput(i)) {
if (auto MR1 = val.getType().dyn_cast<MemRefType>()) {
if (auto MR2 = fnType.getInput(i).dyn_cast<MemRefType>()) {
val = builder.create<mlir::memref::CastOp>(loc, val, MR2);
}
}
}
*/
args.push_back(val);
i++;
}
@ -3460,7 +3452,7 @@ ValueCategory MLIRScanner::VisitBinaryOperator(clang::BinaryOperator *BO) {
auto oldpoint = builder.getInsertionPoint();
auto oldblock = builder.getInsertionBlock();
builder.setInsertionPointToStart(&ifOp.thenRegion().back());
builder.setInsertionPointToStart(&ifOp.getThenRegion().back());
auto rhs = Visit(BO->getRHS()).getValue(builder);
assert(rhs != nullptr);
@ -3476,7 +3468,7 @@ ValueCategory MLIRScanner::VisitBinaryOperator(clang::BinaryOperator *BO) {
mlir::Value truearray[] = {rhs};
builder.create<mlir::scf::YieldOp>(loc, truearray);
builder.setInsertionPointToStart(&ifOp.elseRegion().back());
builder.setInsertionPointToStart(&ifOp.getElseRegion().back());
mlir::Value falsearray[] = {
builder.create<ConstantIntOp>(loc, 0, types[0])};
builder.create<mlir::scf::YieldOp>(loc, falsearray);
@ -3497,12 +3489,12 @@ ValueCategory MLIRScanner::VisitBinaryOperator(clang::BinaryOperator *BO) {
auto oldpoint = builder.getInsertionPoint();
auto oldblock = builder.getInsertionBlock();
builder.setInsertionPointToStart(&ifOp.thenRegion().back());
builder.setInsertionPointToStart(&ifOp.getThenRegion().back());
mlir::Value truearray[] = {builder.create<ConstantIntOp>(loc, 1, types[0])};
builder.create<mlir::scf::YieldOp>(loc, truearray);
builder.setInsertionPointToStart(&ifOp.elseRegion().back());
builder.setInsertionPointToStart(&ifOp.getElseRegion().back());
auto rhs = Visit(BO->getRHS()).getValue(builder);
if (!rhs.getType().cast<mlir::IntegerType>().isInteger(1)) {
auto postTy = builder.getI1Type();
@ -4747,7 +4739,7 @@ MLIRScanner::VisitConditionalOperator(clang::ConditionalOperator *E) {
auto oldpoint = builder.getInsertionPoint();
auto oldblock = builder.getInsertionBlock();
builder.setInsertionPointToStart(&ifOp.thenRegion().back());
builder.setInsertionPointToStart(&ifOp.getThenRegion().back());
auto trueExpr = Visit(E->getTrueExpr());
@ -4778,7 +4770,7 @@ MLIRScanner::VisitConditionalOperator(clang::ConditionalOperator *E) {
builder.create<mlir::scf::YieldOp>(loc, truearray);
}
builder.setInsertionPointToStart(&ifOp.elseRegion().back());
builder.setInsertionPointToStart(&ifOp.getElseRegion().back());
auto falseExpr = Visit(E->getFalseExpr());
std::vector<mlir::Value> falsearray;
@ -4800,10 +4792,10 @@ MLIRScanner::VisitConditionalOperator(clang::ConditionalOperator *E) {
types[i] = truearray[i].getType();
auto newIfOp = builder.create<mlir::scf::IfOp>(loc, types, cond,
/*hasElseRegion*/ true);
newIfOp.thenRegion().takeBody(ifOp.thenRegion());
newIfOp.elseRegion().takeBody(ifOp.elseRegion());
newIfOp.getThenRegion().takeBody(ifOp.getThenRegion());
newIfOp.getElseRegion().takeBody(ifOp.getElseRegion());
ifOp.erase();
return ValueCategory(newIfOp.getResult(0), /*isReference*/ isReference);
// return ifOp;
}
ValueCategory MLIRScanner::VisitStmtExpr(clang::StmtExpr *stmt) {