Rebase LLVM
This commit is contained in:
parent
85e117eacb
commit
239a2c4d82
|
@ -29,7 +29,7 @@ static llvm::SmallVector<mlir::Value>
|
||||||
emitIterationCounts(mlir::OpBuilder &rewriter, mlir::scf::ParallelOp op) {
|
emitIterationCounts(mlir::OpBuilder &rewriter, mlir::scf::ParallelOp op) {
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
SmallVector<Value> iterationCounts;
|
SmallVector<Value> iterationCounts;
|
||||||
for (auto bounds : llvm::zip(op.lowerBound(), op.upperBound(), op.step())) {
|
for (auto bounds : llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep())) {
|
||||||
Value lowerBound = std::get<0>(bounds);
|
Value lowerBound = std::get<0>(bounds);
|
||||||
Value upperBound = std::get<1>(bounds);
|
Value upperBound = std::get<1>(bounds);
|
||||||
Value step = std::get<2>(bounds);
|
Value step = std::get<2>(bounds);
|
||||||
|
|
|
@ -693,6 +693,9 @@ public:
|
||||||
if (src.source().getType().cast<MemRefType>().getElementType() !=
|
if (src.source().getType().cast<MemRefType>().getElementType() !=
|
||||||
op.getType().cast<MemRefType>().getElementType())
|
op.getType().cast<MemRefType>().getElementType())
|
||||||
return failure();
|
return failure();
|
||||||
|
if (src.source().getType().cast<MemRefType>().getMemorySpace() !=
|
||||||
|
op.getType().cast<MemRefType>().getMemorySpace())
|
||||||
|
return failure();
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<memref::CastOp>(op, op.getType(), src.source());
|
rewriter.replaceOpWithNewOp<memref::CastOp>(op, op.getType(), src.source());
|
||||||
return success();
|
return success();
|
||||||
|
@ -732,6 +735,11 @@ OpFoldResult Memref2PointerOp::fold(ArrayRef<Attribute> operands) {
|
||||||
sourceMutable().assign(mc.source());
|
sourceMutable().assign(mc.source());
|
||||||
return result();
|
return result();
|
||||||
}
|
}
|
||||||
|
if (auto mc = source().getDefiningOp<polygeist::Pointer2MemrefOp>()) {
|
||||||
|
if (mc.source().getType() == getType()) {
|
||||||
|
return mc.source();
|
||||||
|
}
|
||||||
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -868,6 +876,11 @@ OpFoldResult Pointer2MemrefOp::fold(ArrayRef<Attribute> operands) {
|
||||||
sourceMutable().assign(mc.getArg());
|
sourceMutable().assign(mc.getArg());
|
||||||
return result();
|
return result();
|
||||||
}
|
}
|
||||||
|
if (auto mc = source().getDefiningOp<polygeist::Memref2PointerOp>()) {
|
||||||
|
if (mc.source().getType() == getType()) {
|
||||||
|
return mc.source();
|
||||||
|
}
|
||||||
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1047,7 +1047,7 @@ void AffineCFGPass::runOnFunction() {
|
||||||
OpBuilder b(ifOp);
|
OpBuilder b(ifOp);
|
||||||
AffineIfOp affineIfOp;
|
AffineIfOp affineIfOp;
|
||||||
std::vector<mlir::Type> types;
|
std::vector<mlir::Type> types;
|
||||||
for (auto v : ifOp.results()) {
|
for (auto v : ifOp.getResults()) {
|
||||||
types.push_back(v.getType());
|
types.push_back(v.getType());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1055,7 +1055,7 @@ void AffineCFGPass::runOnFunction() {
|
||||||
SmallVector<bool, 2> eqflags;
|
SmallVector<bool, 2> eqflags;
|
||||||
SmallVector<Value, 4> applies;
|
SmallVector<Value, 4> applies;
|
||||||
|
|
||||||
std::deque<Value> todo = {ifOp.condition()};
|
std::deque<Value> todo = {ifOp.getCondition()};
|
||||||
while (todo.size()) {
|
while (todo.size()) {
|
||||||
auto cur = todo.front();
|
auto cur = todo.front();
|
||||||
todo.pop_front();
|
todo.pop_front();
|
||||||
|
@ -1077,20 +1077,20 @@ void AffineCFGPass::runOnFunction() {
|
||||||
eqflags);
|
eqflags);
|
||||||
affineIfOp = b.create<AffineIfOp>(ifOp.getLoc(), types, iset, applies,
|
affineIfOp = b.create<AffineIfOp>(ifOp.getLoc(), types, iset, applies,
|
||||||
/*elseBlock=*/true);
|
/*elseBlock=*/true);
|
||||||
affineIfOp.thenRegion().takeBody(ifOp.thenRegion());
|
affineIfOp.thenRegion().takeBody(ifOp.getThenRegion());
|
||||||
affineIfOp.elseRegion().takeBody(ifOp.elseRegion());
|
affineIfOp.elseRegion().takeBody(ifOp.getElseRegion());
|
||||||
|
|
||||||
for (auto &blk : affineIfOp.thenRegion()) {
|
for (auto &blk : affineIfOp.thenRegion()) {
|
||||||
if (auto yop = dyn_cast<scf::YieldOp>(blk.getTerminator())) {
|
if (auto yop = dyn_cast<scf::YieldOp>(blk.getTerminator())) {
|
||||||
OpBuilder b(yop);
|
OpBuilder b(yop);
|
||||||
b.create<AffineYieldOp>(yop.getLoc(), yop.results());
|
b.create<AffineYieldOp>(yop.getLoc(), yop.getResults());
|
||||||
yop.erase();
|
yop.erase();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (auto &blk : affineIfOp.elseRegion()) {
|
for (auto &blk : affineIfOp.elseRegion()) {
|
||||||
if (auto yop = dyn_cast<scf::YieldOp>(blk.getTerminator())) {
|
if (auto yop = dyn_cast<scf::YieldOp>(blk.getTerminator())) {
|
||||||
OpBuilder b(yop);
|
OpBuilder b(yop);
|
||||||
b.create<AffineYieldOp>(yop.getLoc(), yop.results());
|
b.create<AffineYieldOp>(yop.getLoc(), yop.getResults());
|
||||||
yop.erase();
|
yop.erase();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -57,14 +57,14 @@ static void wrapPersistingLoopBodies(FuncOp function) {
|
||||||
for (scf::ParallelOp op : loops) {
|
for (scf::ParallelOp op : loops) {
|
||||||
OpBuilder builder = OpBuilder::atBlockBegin(op.getBody());
|
OpBuilder builder = OpBuilder::atBlockBegin(op.getBody());
|
||||||
auto wrapper = builder.create<scf::ExecuteRegionOp>(
|
auto wrapper = builder.create<scf::ExecuteRegionOp>(
|
||||||
op.getLoc(), op.results().getTypes());
|
op.getLoc(), op.getResults().getTypes());
|
||||||
builder.createBlock(&wrapper.region(), wrapper.region().begin());
|
builder.createBlock(&wrapper.getRegion(), wrapper.getRegion().begin());
|
||||||
wrapper.region().front().getOperations().splice(
|
wrapper.getRegion().front().getOperations().splice(
|
||||||
wrapper.region().front().begin(), op.getBody()->getOperations(),
|
wrapper.getRegion().front().begin(), op.getBody()->getOperations(),
|
||||||
std::next(op.getBody()->begin()), op.getBody()->end());
|
std::next(op.getBody()->begin()), op.getBody()->end());
|
||||||
builder.setInsertionPointToEnd(op.getBody());
|
builder.setInsertionPointToEnd(op.getBody());
|
||||||
builder.create<scf::YieldOp>(
|
builder.create<scf::YieldOp>(
|
||||||
wrapper.region().front().getTerminator()->getLoc(),
|
wrapper.getRegion().front().getTerminator()->getLoc(),
|
||||||
wrapper.getResults());
|
wrapper.getResults());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -124,7 +124,7 @@ static LogicalResult splitBlocksWithBarrier(FuncOp function) {
|
||||||
}
|
}
|
||||||
|
|
||||||
splitBlocksWithBarrier(
|
splitBlocksWithBarrier(
|
||||||
cast<scf::ExecuteRegionOp>(&op.getBody()->front()).region());
|
cast<scf::ExecuteRegionOp>(&op.getBody()->front()).getRegion());
|
||||||
return success();
|
return success();
|
||||||
});
|
});
|
||||||
return success(!result.wasInterrupted());
|
return success(!result.wasInterrupted());
|
||||||
|
@ -288,15 +288,15 @@ emitContinuationCase(Value condition, Value storage, scf::ParallelOp parallel,
|
||||||
bn.create<scf::ExecuteRegionOp>(TypeRange(), ValueRange());
|
bn.create<scf::ExecuteRegionOp>(TypeRange(), ValueRange());
|
||||||
BlockAndValueMapping mapping;
|
BlockAndValueMapping mapping;
|
||||||
mapping.map(parallel.getInductionVars(), ivs);
|
mapping.map(parallel.getInductionVars(), ivs);
|
||||||
replicateIntoRegion(executeRegion.region(), storage, ivs,
|
replicateIntoRegion(executeRegion.getRegion(), storage, ivs,
|
||||||
parallel.lowerBound(), blocks, subgraphEntryPoints,
|
parallel.getLowerBound(), blocks, subgraphEntryPoints,
|
||||||
mapping, builder);
|
mapping, builder);
|
||||||
};
|
};
|
||||||
|
|
||||||
auto thenBuilder = [&](OpBuilder &nested, Location loc) {
|
auto thenBuilder = [&](OpBuilder &nested, Location loc) {
|
||||||
ImplicitLocOpBuilder bn(loc, nested);
|
ImplicitLocOpBuilder bn(loc, nested);
|
||||||
bn.create<scf::ParallelOp>(parallel.lowerBound(), parallel.upperBound(),
|
bn.create<scf::ParallelOp>(parallel.getLowerBound(), parallel.getUpperBound(),
|
||||||
parallel.step(), parallelBuilder);
|
parallel.getStep(), parallelBuilder);
|
||||||
bn.create<scf::YieldOp>();
|
bn.create<scf::YieldOp>();
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -362,9 +362,9 @@ findInsertionPointAfterLoopOperands(scf::ParallelOp op) {
|
||||||
// Find the earliest insertion point where loop bounds are fully defined.
|
// Find the earliest insertion point where loop bounds are fully defined.
|
||||||
PostDominanceInfo postDominanceInfo(op->getParentOfType<FuncOp>());
|
PostDominanceInfo postDominanceInfo(op->getParentOfType<FuncOp>());
|
||||||
SmallVector<Value> operands;
|
SmallVector<Value> operands;
|
||||||
llvm::append_range(operands, op.lowerBound());
|
llvm::append_range(operands, op.getLowerBound());
|
||||||
llvm::append_range(operands, op.upperBound());
|
llvm::append_range(operands, op.getUpperBound());
|
||||||
llvm::append_range(operands, op.step());
|
llvm::append_range(operands, op.getStep());
|
||||||
return findNesrestPostDominatingInsertionPoint(operands, postDominanceInfo);
|
return findNesrestPostDominatingInsertionPoint(operands, postDominanceInfo);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -536,8 +536,8 @@ static void createContinuations(scf::ParallelOp parallel, Value storage) {
|
||||||
llvm::SetVector<Block *> startBlocks;
|
llvm::SetVector<Block *> startBlocks;
|
||||||
auto outerExecuteRegion =
|
auto outerExecuteRegion =
|
||||||
cast<scf::ExecuteRegionOp>(¶llel.getBody()->front());
|
cast<scf::ExecuteRegionOp>(¶llel.getBody()->front());
|
||||||
startBlocks.insert(&outerExecuteRegion.region().front());
|
startBlocks.insert(&outerExecuteRegion.getRegion().front());
|
||||||
for (Block &block : outerExecuteRegion.region()) {
|
for (Block &block : outerExecuteRegion.getRegion()) {
|
||||||
if (!isa_and_nonnull<polygeist::BarrierOp>(
|
if (!isa_and_nonnull<polygeist::BarrierOp>(
|
||||||
block.getTerminator()->getPrevNode()))
|
block.getTerminator()->getPrevNode()))
|
||||||
continue;
|
continue;
|
||||||
|
@ -560,12 +560,12 @@ static void createContinuations(scf::ParallelOp parallel, Value storage) {
|
||||||
OpBuilder allocBuilder(loop);
|
OpBuilder allocBuilder(loop);
|
||||||
reg2mem(subgraphs, parallel, allocBuilder, builder);
|
reg2mem(subgraphs, parallel, allocBuilder, builder);
|
||||||
|
|
||||||
builder.createBlock(&loop.before(), loop.before().end());
|
builder.createBlock(&loop.getBefore(), loop.getBefore().end());
|
||||||
Value next = builder.create<memref::LoadOp>(storage);
|
Value next = builder.create<memref::LoadOp>(storage);
|
||||||
Value condition = builder.create<CmpIOp>(CmpIPredicate::ne, next, negOne);
|
Value condition = builder.create<CmpIOp>(CmpIPredicate::ne, next, negOne);
|
||||||
builder.create<scf::ConditionOp>(TypeRange(), condition, ValueRange());
|
builder.create<scf::ConditionOp>(TypeRange(), condition, ValueRange());
|
||||||
|
|
||||||
builder.createBlock(&loop.after(), loop.after().end());
|
builder.createBlock(&loop.getAfter(), loop.getAfter().end());
|
||||||
next = builder.create<memref::LoadOp>(storage);
|
next = builder.create<memref::LoadOp>(storage);
|
||||||
SmallVector<Value> caseConditions;
|
SmallVector<Value> caseConditions;
|
||||||
caseConditions.resize(startBlocks.size());
|
caseConditions.resize(startBlocks.size());
|
||||||
|
|
|
@ -52,7 +52,7 @@ static bool hasSameInitValue(Value iter, scf::ForOp forOp) {
|
||||||
if (!cst)
|
if (!cst)
|
||||||
return false;
|
return false;
|
||||||
if (auto cstOp = dyn_cast<ConstantIntOp>(cst)) {
|
if (auto cstOp = dyn_cast<ConstantIntOp>(cst)) {
|
||||||
Operation *lbDefOp = forOp.lowerBound().getDefiningOp();
|
Operation *lbDefOp = forOp.getLowerBound().getDefiningOp();
|
||||||
if (!lbDefOp)
|
if (!lbDefOp)
|
||||||
return false;
|
return false;
|
||||||
ConstantIndexOp lb = dyn_cast_or_null<ConstantIndexOp>(lbDefOp);
|
ConstantIndexOp lb = dyn_cast_or_null<ConstantIndexOp>(lbDefOp);
|
||||||
|
@ -69,7 +69,7 @@ static bool hasSameStepValue(Value regIter, Value yieldOp, scf::ForOp forOp) {
|
||||||
if (!defOpStep)
|
if (!defOpStep)
|
||||||
return false;
|
return false;
|
||||||
if (auto cstStep = dyn_cast<ConstantIntOp>(defOpStep)) {
|
if (auto cstStep = dyn_cast<ConstantIntOp>(defOpStep)) {
|
||||||
Operation *stepForDefOp = forOp.step().getDefiningOp();
|
Operation *stepForDefOp = forOp.getStep().getDefiningOp();
|
||||||
if (!stepForDefOp)
|
if (!stepForDefOp)
|
||||||
return false;
|
return false;
|
||||||
ConstantIndexOp stepFor = dyn_cast_or_null<ConstantIndexOp>(stepForDefOp);
|
ConstantIndexOp stepFor = dyn_cast_or_null<ConstantIndexOp>(stepForDefOp);
|
||||||
|
@ -123,7 +123,7 @@ struct DetectTrivialIndVarInArgs : public OpRewritePattern<scf::ForOp> {
|
||||||
if (!forOp.getNumIterOperands())
|
if (!forOp.getNumIterOperands())
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
Block &block = forOp.region().front();
|
Block &block = forOp.getRegion().front();
|
||||||
auto yieldOp = cast<scf::YieldOp>(block.getTerminator());
|
auto yieldOp = cast<scf::YieldOp>(block.getTerminator());
|
||||||
|
|
||||||
bool matched = false;
|
bool matched = false;
|
||||||
|
@ -150,7 +150,7 @@ struct ForOpInductionReplacement : public OpRewritePattern<scf::ForOp> {
|
||||||
LogicalResult matchAndRewrite(scf::ForOp forOp,
|
LogicalResult matchAndRewrite(scf::ForOp forOp,
|
||||||
PatternRewriter &rewriter) const final {
|
PatternRewriter &rewriter) const final {
|
||||||
bool canonicalize = false;
|
bool canonicalize = false;
|
||||||
Block &block = forOp.region().front();
|
Block &block = forOp.getRegion().front();
|
||||||
auto yieldOp = cast<scf::YieldOp>(block.getTerminator());
|
auto yieldOp = cast<scf::YieldOp>(block.getTerminator());
|
||||||
|
|
||||||
for (auto it : llvm::zip(forOp.getIterOperands(), // iter from outside
|
for (auto it : llvm::zip(forOp.getIterOperands(), // iter from outside
|
||||||
|
@ -171,10 +171,10 @@ struct ForOpInductionReplacement : public OpRewritePattern<scf::ForOp> {
|
||||||
|
|
||||||
bool isOne = false;
|
bool isOne = false;
|
||||||
|
|
||||||
if (addOp.getOperand(1) == forOp.step()) {
|
if (addOp.getOperand(1) == forOp.getStep()) {
|
||||||
legalStep = true;
|
legalStep = true;
|
||||||
} else if (auto iter_step =
|
} else if (auto iter_step =
|
||||||
forOp.step().getDefiningOp<ConstantIndexOp>()) {
|
forOp.getStep().getDefiningOp<ConstantIndexOp>()) {
|
||||||
isOne |= iter_step.value() == 1;
|
isOne |= iter_step.value() == 1;
|
||||||
if (auto op = addOp.getOperand(1).getDefiningOp<ConstantIntOp>()) {
|
if (auto op = addOp.getOperand(1).getDefiningOp<ConstantIntOp>()) {
|
||||||
if (op.value() == iter_step.value()) {
|
if (op.value() == iter_step.value()) {
|
||||||
|
@ -200,9 +200,9 @@ struct ForOpInductionReplacement : public OpRewritePattern<scf::ForOp> {
|
||||||
Value init = std::get<0>(it);
|
Value init = std::get<0>(it);
|
||||||
|
|
||||||
if (!std::get<1>(it).use_empty()) {
|
if (!std::get<1>(it).use_empty()) {
|
||||||
rewriter.setInsertionPointToStart(&forOp.region().front());
|
rewriter.setInsertionPointToStart(&forOp.getRegion().front());
|
||||||
Value replacement = rewriter.create<SubIOp>(
|
Value replacement = rewriter.create<SubIOp>(
|
||||||
forOp.getLoc(), forOp.getInductionVar(), forOp.lowerBound());
|
forOp.getLoc(), forOp.getInductionVar(), forOp.getLowerBound());
|
||||||
if (!std::get<1>(it).getType().isa<IndexType>()) {
|
if (!std::get<1>(it).getType().isa<IndexType>()) {
|
||||||
replacement = rewriter.create<IndexCastOp>(
|
replacement = rewriter.create<IndexCastOp>(
|
||||||
forOp.getLoc(), replacement, std::get<1>(it).getType());
|
forOp.getLoc(), replacement, std::get<1>(it).getType());
|
||||||
|
@ -223,7 +223,7 @@ struct ForOpInductionReplacement : public OpRewritePattern<scf::ForOp> {
|
||||||
if (isOne && !std::get<2>(it).use_empty()) {
|
if (isOne && !std::get<2>(it).use_empty()) {
|
||||||
rewriter.setInsertionPoint(forOp);
|
rewriter.setInsertionPoint(forOp);
|
||||||
Value replacement = rewriter.create<SubIOp>(
|
Value replacement = rewriter.create<SubIOp>(
|
||||||
forOp.getLoc(), forOp.upperBound(), forOp.lowerBound());
|
forOp.getLoc(), forOp.getUpperBound(), forOp.getLowerBound());
|
||||||
if (!std::get<1>(it).getType().isa<IndexType>()) {
|
if (!std::get<1>(it).getType().isa<IndexType>()) {
|
||||||
replacement = rewriter.create<IndexCastOp>(
|
replacement = rewriter.create<IndexCastOp>(
|
||||||
forOp.getLoc(), replacement, std::get<1>(it).getType());
|
forOp.getLoc(), replacement, std::get<1>(it).getType());
|
||||||
|
@ -275,7 +275,7 @@ struct RemoveUnusedArgs : public OpRewritePattern<ForOp> {
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto newForOp = rewriter.create<ForOp>(
|
auto newForOp = rewriter.create<ForOp>(
|
||||||
op.getLoc(), op.lowerBound(), op.upperBound(), op.step(), usedOperands);
|
op.getLoc(), op.getLowerBound(), op.getUpperBound(), op.getStep(), usedOperands);
|
||||||
|
|
||||||
if (!newForOp.getBody()->empty())
|
if (!newForOp.getBody()->empty())
|
||||||
rewriter.eraseOp(newForOp.getBody()->getTerminator());
|
rewriter.eraseOp(newForOp.getBody()->getTerminator());
|
||||||
|
@ -499,7 +499,7 @@ yop2.results()[idx]);
|
||||||
|
|
||||||
bool isWhile(WhileOp wop) {
|
bool isWhile(WhileOp wop) {
|
||||||
bool hasCondOp = false;
|
bool hasCondOp = false;
|
||||||
wop.before().walk([&](Operation *op) {
|
wop.getBefore().walk([&](Operation *op) {
|
||||||
if (isa<scf::ConditionOp>(op))
|
if (isa<scf::ConditionOp>(op))
|
||||||
hasCondOp = true;
|
hasCondOp = true;
|
||||||
});
|
});
|
||||||
|
@ -547,12 +547,12 @@ struct MoveWhileToFor : public OpRewritePattern<WhileOp> {
|
||||||
} loopInfo;
|
} loopInfo;
|
||||||
|
|
||||||
auto condOp = loop.getConditionOp();
|
auto condOp = loop.getConditionOp();
|
||||||
SmallVector<Value, 2> results = {condOp.args()};
|
SmallVector<Value, 2> results = {condOp.getArgs()};
|
||||||
auto cmpIOp = condOp.condition().getDefiningOp<CmpIOp>();
|
auto cmpIOp = condOp.getCondition().getDefiningOp<CmpIOp>();
|
||||||
if (!cmpIOp) {
|
if (!cmpIOp) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
size_t size = loop.before().front().getOperations().size();
|
size_t size = loop.getBefore().front().getOperations().size();
|
||||||
if (size != 2) {
|
if (size != 2) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
@ -560,24 +560,24 @@ struct MoveWhileToFor : public OpRewritePattern<WhileOp> {
|
||||||
BlockArgument indVar = cmpIOp.getLhs().dyn_cast<BlockArgument>();
|
BlockArgument indVar = cmpIOp.getLhs().dyn_cast<BlockArgument>();
|
||||||
if (!indVar)
|
if (!indVar)
|
||||||
return failure();
|
return failure();
|
||||||
if (indVar.getOwner() != &loop.before().front())
|
if (indVar.getOwner() != &loop.getBefore().front())
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
SmallVector<size_t, 2> afterArgs;
|
SmallVector<size_t, 2> afterArgs;
|
||||||
for (auto pair : llvm::enumerate(condOp.args())) {
|
for (auto pair : llvm::enumerate(condOp.getArgs())) {
|
||||||
if (pair.value() == indVar)
|
if (pair.value() == indVar)
|
||||||
afterArgs.push_back(pair.index());
|
afterArgs.push_back(pair.index());
|
||||||
}
|
}
|
||||||
|
|
||||||
auto endYield = cast<YieldOp>(loop.after().back().getTerminator());
|
auto endYield = cast<YieldOp>(loop.getAfter().back().getTerminator());
|
||||||
|
|
||||||
auto addIOp =
|
auto addIOp =
|
||||||
endYield.results()[indVar.getArgNumber()].getDefiningOp<AddIOp>();
|
endYield.getResults()[indVar.getArgNumber()].getDefiningOp<AddIOp>();
|
||||||
if (!addIOp)
|
if (!addIOp)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
for (auto afterArg : afterArgs) {
|
for (auto afterArg : afterArgs) {
|
||||||
auto arg = loop.after().getArgument(afterArg);
|
auto arg = loop.getAfter().getArgument(afterArg);
|
||||||
if (addIOp.getOperand(0) == arg) {
|
if (addIOp.getOperand(0) == arg) {
|
||||||
step = addIOp.getOperand(1);
|
step = addIOp.getOperand(1);
|
||||||
break;
|
break;
|
||||||
|
@ -679,19 +679,19 @@ struct MoveWhileToFor : public OpRewritePattern<WhileOp> {
|
||||||
// input of the for goes the input of the scf::while plus the output taken
|
// input of the for goes the input of the scf::while plus the output taken
|
||||||
// from the conditionOp.
|
// from the conditionOp.
|
||||||
SmallVector<Value, 8> forArgs;
|
SmallVector<Value, 8> forArgs;
|
||||||
forArgs.append(loop.inits().begin(), loop.inits().end());
|
forArgs.append(loop.getInits().begin(), loop.getInits().end());
|
||||||
|
|
||||||
for (Value arg : condOp.args()) {
|
for (Value arg : condOp.getArgs()) {
|
||||||
Type cst = nullptr;
|
Type cst = nullptr;
|
||||||
if (auto idx = arg.getDefiningOp<IndexCastOp>()) {
|
if (auto idx = arg.getDefiningOp<IndexCastOp>()) {
|
||||||
cst = idx.getType();
|
cst = idx.getType();
|
||||||
arg = idx.getIn();
|
arg = idx.getIn();
|
||||||
}
|
}
|
||||||
Value res;
|
Value res;
|
||||||
if (isTopLevelArgValue(arg, &loop.before())) {
|
if (isTopLevelArgValue(arg, &loop.getBefore())) {
|
||||||
auto blockArg = arg.cast<BlockArgument>();
|
auto blockArg = arg.cast<BlockArgument>();
|
||||||
auto pos = blockArg.getArgNumber();
|
auto pos = blockArg.getArgNumber();
|
||||||
res = loop.inits()[pos];
|
res = loop.getInits()[pos];
|
||||||
} else
|
} else
|
||||||
res = arg;
|
res = arg;
|
||||||
if (cst) {
|
if (cst) {
|
||||||
|
@ -706,31 +706,31 @@ struct MoveWhileToFor : public OpRewritePattern<WhileOp> {
|
||||||
if (!forloop.getBody()->empty())
|
if (!forloop.getBody()->empty())
|
||||||
rewriter.eraseOp(forloop.getBody()->getTerminator());
|
rewriter.eraseOp(forloop.getBody()->getTerminator());
|
||||||
|
|
||||||
auto oldYield = cast<scf::YieldOp>(loop.after().front().getTerminator());
|
auto oldYield = cast<scf::YieldOp>(loop.getAfter().front().getTerminator());
|
||||||
|
|
||||||
rewriter.updateRootInPlace(loop, [&] {
|
rewriter.updateRootInPlace(loop, [&] {
|
||||||
for (auto pair : llvm::zip(loop.after().getArguments(), condOp.args())) {
|
for (auto pair : llvm::zip(loop.getAfter().getArguments(), condOp.getArgs())) {
|
||||||
std::get<0>(pair).replaceAllUsesWith(std::get<1>(pair));
|
std::get<0>(pair).replaceAllUsesWith(std::get<1>(pair));
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
loop.after().front().eraseArguments([](BlockArgument) { return true; });
|
loop.getAfter().front().eraseArguments([](BlockArgument) { return true; });
|
||||||
|
|
||||||
SmallVector<Value, 2> yieldOperands;
|
SmallVector<Value, 2> yieldOperands;
|
||||||
for (auto oldYieldArg : oldYield.results())
|
for (auto oldYieldArg : oldYield.getResults())
|
||||||
yieldOperands.push_back(oldYieldArg);
|
yieldOperands.push_back(oldYieldArg);
|
||||||
|
|
||||||
BlockAndValueMapping outmap;
|
BlockAndValueMapping outmap;
|
||||||
outmap.map(loop.before().getArguments(), yieldOperands);
|
outmap.map(loop.getBefore().getArguments(), yieldOperands);
|
||||||
for (auto arg : condOp.args())
|
for (auto arg : condOp.getArgs())
|
||||||
yieldOperands.push_back(outmap.lookupOrDefault(arg));
|
yieldOperands.push_back(outmap.lookupOrDefault(arg));
|
||||||
|
|
||||||
rewriter.setInsertionPoint(oldYield);
|
rewriter.setInsertionPoint(oldYield);
|
||||||
rewriter.replaceOpWithNewOp<scf::YieldOp>(oldYield, yieldOperands);
|
rewriter.replaceOpWithNewOp<scf::YieldOp>(oldYield, yieldOperands);
|
||||||
|
|
||||||
size_t pos = loop.inits().size();
|
size_t pos = loop.getInits().size();
|
||||||
|
|
||||||
rewriter.updateRootInPlace(loop, [&] {
|
rewriter.updateRootInPlace(loop, [&] {
|
||||||
for (auto pair : llvm::zip(loop.before().getArguments(),
|
for (auto pair : llvm::zip(loop.getBefore().getArguments(),
|
||||||
forloop.getRegionIterArgs().drop_back(pos))) {
|
forloop.getRegionIterArgs().drop_back(pos))) {
|
||||||
std::get<0>(pair).replaceAllUsesWith(std::get<1>(pair));
|
std::get<0>(pair).replaceAllUsesWith(std::get<1>(pair));
|
||||||
}
|
}
|
||||||
|
@ -738,7 +738,7 @@ struct MoveWhileToFor : public OpRewritePattern<WhileOp> {
|
||||||
|
|
||||||
forloop.getBody()->getOperations().splice(
|
forloop.getBody()->getOperations().splice(
|
||||||
forloop.getBody()->getOperations().begin(),
|
forloop.getBody()->getOperations().begin(),
|
||||||
loop.after().front().getOperations());
|
loop.getAfter().front().getOperations());
|
||||||
|
|
||||||
SmallVector<Value, 2> replacements;
|
SmallVector<Value, 2> replacements;
|
||||||
replacements.append(forloop.getResults().begin() + pos,
|
replacements.append(forloop.getResults().begin() + pos,
|
||||||
|
@ -754,20 +754,20 @@ struct MoveWhileDown : public OpRewritePattern<WhileOp> {
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(WhileOp op,
|
LogicalResult matchAndRewrite(WhileOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
auto term = cast<scf::ConditionOp>(op.before().front().getTerminator());
|
auto term = cast<scf::ConditionOp>(op.getBefore().front().getTerminator());
|
||||||
if (auto ifOp = term.condition().getDefiningOp<scf::IfOp>()) {
|
if (auto ifOp = term.getCondition().getDefiningOp<scf::IfOp>()) {
|
||||||
if (ifOp.getNumResults() != term.args().size() + 1)
|
if (ifOp.getNumResults() != term.getArgs().size() + 1)
|
||||||
return failure();
|
return failure();
|
||||||
if (ifOp.getResult(0) != term.condition())
|
if (ifOp.getResult(0) != term.getCondition())
|
||||||
return failure();
|
return failure();
|
||||||
for (size_t i = 1; i < ifOp.getNumResults(); ++i) {
|
for (size_t i = 1; i < ifOp.getNumResults(); ++i) {
|
||||||
if (ifOp.getResult(i) != term.args()[i - 1])
|
if (ifOp.getResult(i) != term.getArgs()[i - 1])
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
auto yield1 =
|
auto yield1 =
|
||||||
cast<scf::YieldOp>(ifOp.thenRegion().front().getTerminator());
|
cast<scf::YieldOp>(ifOp.getThenRegion().front().getTerminator());
|
||||||
auto yield2 =
|
auto yield2 =
|
||||||
cast<scf::YieldOp>(ifOp.elseRegion().front().getTerminator());
|
cast<scf::YieldOp>(ifOp.getElseRegion().front().getTerminator());
|
||||||
if (auto cop = yield1.getOperand(0).getDefiningOp<ConstantIntOp>()) {
|
if (auto cop = yield1.getOperand(0).getDefiningOp<ConstantIntOp>()) {
|
||||||
if (cop.value() == 0)
|
if (cop.value() == 0)
|
||||||
return failure();
|
return failure();
|
||||||
|
@ -778,31 +778,31 @@ struct MoveWhileDown : public OpRewritePattern<WhileOp> {
|
||||||
return failure();
|
return failure();
|
||||||
} else
|
} else
|
||||||
return failure();
|
return failure();
|
||||||
if (ifOp.elseRegion().front().getOperations().size() != 1)
|
if (ifOp.getElseRegion().front().getOperations().size() != 1)
|
||||||
return failure();
|
return failure();
|
||||||
op.after().front().getOperations().splice(
|
op.getAfter().front().getOperations().splice(
|
||||||
op.after().front().begin(),
|
op.getAfter().front().begin(),
|
||||||
ifOp.thenRegion().front().getOperations());
|
ifOp.getThenRegion().front().getOperations());
|
||||||
rewriter.updateRootInPlace(
|
rewriter.updateRootInPlace(
|
||||||
term, [&] { term.conditionMutable().assign(ifOp.condition()); });
|
term, [&] { term.getConditionMutable().assign(ifOp.getCondition()); });
|
||||||
SmallVector<Value, 2> args;
|
SmallVector<Value, 2> args;
|
||||||
for (size_t i = 1; i < yield2.getNumOperands(); ++i) {
|
for (size_t i = 1; i < yield2.getNumOperands(); ++i) {
|
||||||
args.push_back(yield2.getOperand(i));
|
args.push_back(yield2.getOperand(i));
|
||||||
}
|
}
|
||||||
rewriter.updateRootInPlace(term,
|
rewriter.updateRootInPlace(term,
|
||||||
[&] { term.argsMutable().assign(args); });
|
[&] { term.getArgsMutable().assign(args); });
|
||||||
rewriter.eraseOp(yield2);
|
rewriter.eraseOp(yield2);
|
||||||
rewriter.eraseOp(ifOp);
|
rewriter.eraseOp(ifOp);
|
||||||
|
|
||||||
for (size_t i = 0; i < op.after().front().getNumArguments(); ++i) {
|
for (size_t i = 0; i < op.getAfter().front().getNumArguments(); ++i) {
|
||||||
op.after().front().getArgument(i).replaceAllUsesWith(
|
op.getAfter().front().getArgument(i).replaceAllUsesWith(
|
||||||
yield1.getOperand(i + 1));
|
yield1.getOperand(i + 1));
|
||||||
}
|
}
|
||||||
rewriter.eraseOp(yield1);
|
rewriter.eraseOp(yield1);
|
||||||
// TODO move operands from begin to after
|
// TODO move operands from begin to after
|
||||||
SmallVector<Value> todo(op.before().front().getArguments().begin(),
|
SmallVector<Value> todo(op.getBefore().front().getArguments().begin(),
|
||||||
op.before().front().getArguments().end());
|
op.getBefore().front().getArguments().end());
|
||||||
for (auto &op : op.before().front()) {
|
for (auto &op : op.getBefore().front()) {
|
||||||
for (auto res : op.getResults()) {
|
for (auto res : op.getResults()) {
|
||||||
todo.push_back(res);
|
todo.push_back(res);
|
||||||
}
|
}
|
||||||
|
@ -810,24 +810,24 @@ struct MoveWhileDown : public OpRewritePattern<WhileOp> {
|
||||||
|
|
||||||
rewriter.updateRootInPlace(op, [&] {
|
rewriter.updateRootInPlace(op, [&] {
|
||||||
for (auto val : todo) {
|
for (auto val : todo) {
|
||||||
auto na = op.after().front().addArgument(val.getType());
|
auto na = op.getAfter().front().addArgument(val.getType());
|
||||||
val.replaceUsesWithIf(na, [&](OpOperand &u) -> bool {
|
val.replaceUsesWithIf(na, [&](OpOperand &u) -> bool {
|
||||||
return op.after().isAncestor(u.getOwner()->getParentRegion());
|
return op.getAfter().isAncestor(u.getOwner()->getParentRegion());
|
||||||
});
|
});
|
||||||
args.push_back(val);
|
args.push_back(val);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
rewriter.updateRootInPlace(term,
|
rewriter.updateRootInPlace(term,
|
||||||
[&] { term.argsMutable().assign(args); });
|
[&] { term.getArgsMutable().assign(args); });
|
||||||
|
|
||||||
SmallVector<Type, 4> tys;
|
SmallVector<Type, 4> tys;
|
||||||
for (auto a : args)
|
for (auto a : args)
|
||||||
tys.push_back(a.getType());
|
tys.push_back(a.getType());
|
||||||
|
|
||||||
auto op2 = rewriter.create<WhileOp>(op.getLoc(), tys, op.inits());
|
auto op2 = rewriter.create<WhileOp>(op.getLoc(), tys, op.getInits());
|
||||||
op2.before().takeBody(op.before());
|
op2.getBefore().takeBody(op.getBefore());
|
||||||
op2.after().takeBody(op.after());
|
op2.getAfter().takeBody(op.getAfter());
|
||||||
SmallVector<Value, 4> replacements;
|
SmallVector<Value, 4> replacements;
|
||||||
for (auto a : op2.getResults()) {
|
for (auto a : op2.getResults()) {
|
||||||
if (replacements.size() == op.getResults().size())
|
if (replacements.size() == op.getResults().size())
|
||||||
|
@ -882,9 +882,9 @@ struct MoveWhileDown2 : public OpRewritePattern<WhileOp> {
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(WhileOp op,
|
LogicalResult matchAndRewrite(WhileOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
auto term = cast<scf::ConditionOp>(op.before().front().getTerminator());
|
auto term = cast<scf::ConditionOp>(op.getBefore().front().getTerminator());
|
||||||
if (auto ifOp = dyn_cast_or_null<scf::IfOp>(term->getPrevNode())) {
|
if (auto ifOp = dyn_cast_or_null<scf::IfOp>(term->getPrevNode())) {
|
||||||
if (ifOp.condition() != term.condition())
|
if (ifOp.getCondition() != term.getCondition())
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
SmallVector<std::pair<BlockArgument, Value>, 2> m;
|
SmallVector<std::pair<BlockArgument, Value>, 2> m;
|
||||||
|
@ -892,14 +892,14 @@ struct MoveWhileDown2 : public OpRewritePattern<WhileOp> {
|
||||||
SmallVector<Value, 2> prevArgs;
|
SmallVector<Value, 2> prevArgs;
|
||||||
|
|
||||||
SmallVector<std::pair<size_t, Value>, 2> afterYieldRewrites;
|
SmallVector<std::pair<size_t, Value>, 2> afterYieldRewrites;
|
||||||
auto afterYield = cast<YieldOp>(op.after().front().back());
|
auto afterYield = cast<YieldOp>(op.getAfter().front().back());
|
||||||
for (auto pair :
|
for (auto pair :
|
||||||
llvm::zip(op.getResults(), term.args(), op.getAfterArguments())) {
|
llvm::zip(op.getResults(), term.getArgs(), op.getAfterArguments())) {
|
||||||
if (std::get<1>(pair).getDefiningOp() == ifOp) {
|
if (std::get<1>(pair).getDefiningOp() == ifOp) {
|
||||||
|
|
||||||
Value thenYielded, elseYielded;
|
Value thenYielded, elseYielded;
|
||||||
for (auto p : llvm::zip(ifOp.thenYield().results(), ifOp.results(),
|
for (auto p : llvm::zip(ifOp.thenYield().getResults(), ifOp.getResults(),
|
||||||
ifOp.elseYield().results())) {
|
ifOp.elseYield().getResults())) {
|
||||||
if (std::get<1>(pair) == std::get<1>(p)) {
|
if (std::get<1>(pair) == std::get<1>(p)) {
|
||||||
thenYielded = std::get<0>(p);
|
thenYielded = std::get<0>(p);
|
||||||
elseYielded = std::get<2>(p);
|
elseYielded = std::get<2>(p);
|
||||||
|
@ -911,8 +911,8 @@ struct MoveWhileDown2 : public OpRewritePattern<WhileOp> {
|
||||||
|
|
||||||
if (!std::get<0>(pair).use_empty()) {
|
if (!std::get<0>(pair).use_empty()) {
|
||||||
if (auto blockArg = elseYielded.dyn_cast<BlockArgument>())
|
if (auto blockArg = elseYielded.dyn_cast<BlockArgument>())
|
||||||
if (blockArg.getOwner() == &op.before().front()) {
|
if (blockArg.getOwner() == &op.getBefore().front()) {
|
||||||
if (afterYield.results()[blockArg.getArgNumber()] ==
|
if (afterYield.getResults()[blockArg.getArgNumber()] ==
|
||||||
std::get<2>(pair) &&
|
std::get<2>(pair) &&
|
||||||
op.getResults()[blockArg.getArgNumber()] ==
|
op.getResults()[blockArg.getArgNumber()] ==
|
||||||
std::get<0>(pair)) {
|
std::get<0>(pair)) {
|
||||||
|
@ -933,18 +933,18 @@ struct MoveWhileDown2 : public OpRewritePattern<WhileOp> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<Value> yieldArgs = afterYield.results();
|
SmallVector<Value> yieldArgs = afterYield.getResults();
|
||||||
for (auto pair : afterYieldRewrites) {
|
for (auto pair : afterYieldRewrites) {
|
||||||
yieldArgs[pair.first] = pair.second;
|
yieldArgs[pair.first] = pair.second;
|
||||||
}
|
}
|
||||||
|
|
||||||
rewriter.updateRootInPlace(
|
rewriter.updateRootInPlace(
|
||||||
afterYield, [&] { afterYield.resultsMutable().assign(yieldArgs); });
|
afterYield, [&] { afterYield.getResultsMutable().assign(yieldArgs); });
|
||||||
|
|
||||||
llvm::SetVector<Value> sv;
|
llvm::SetVector<Value> sv;
|
||||||
findValuesUsedBelow(ifOp, sv);
|
findValuesUsedBelow(ifOp, sv);
|
||||||
|
|
||||||
Block *afterB = &op.after().front();
|
Block *afterB = &op.getAfter().front();
|
||||||
|
|
||||||
for (auto v : sv) {
|
for (auto v : sv) {
|
||||||
condArgs.push_back(v);
|
condArgs.push_back(v);
|
||||||
|
@ -956,7 +956,7 @@ struct MoveWhileDown2 : public OpRewritePattern<WhileOp> {
|
||||||
}
|
}
|
||||||
|
|
||||||
rewriter.setInsertionPoint(term);
|
rewriter.setInsertionPoint(term);
|
||||||
rewriter.replaceOpWithNewOp<ConditionOp>(term, term.condition(),
|
rewriter.replaceOpWithNewOp<ConditionOp>(term, term.getCondition(),
|
||||||
condArgs);
|
condArgs);
|
||||||
|
|
||||||
for (int i = m.size() - 1; i >= 0; i--) {
|
for (int i = m.size() - 1; i >= 0; i--) {
|
||||||
|
@ -977,9 +977,9 @@ struct MoveWhileDown2 : public OpRewritePattern<WhileOp> {
|
||||||
}
|
}
|
||||||
|
|
||||||
rewriter.setInsertionPoint(op);
|
rewriter.setInsertionPoint(op);
|
||||||
auto nop = rewriter.create<WhileOp>(op.getLoc(), resultTypes, op.inits());
|
auto nop = rewriter.create<WhileOp>(op.getLoc(), resultTypes, op.getInits());
|
||||||
nop.before().takeBody(op.before());
|
nop.getBefore().takeBody(op.getBefore());
|
||||||
nop.after().takeBody(op.after());
|
nop.getAfter().takeBody(op.getAfter());
|
||||||
|
|
||||||
rewriter.updateRootInPlace(op, [&] {
|
rewriter.updateRootInPlace(op, [&] {
|
||||||
for (auto pair : llvm::enumerate(prevArgs)) {
|
for (auto pair : llvm::enumerate(prevArgs)) {
|
||||||
|
@ -1002,24 +1002,24 @@ struct MoveWhileInvariantIfResult : public OpRewritePattern<WhileOp> {
|
||||||
SmallVector<BlockArgument, 2> origAfterArgs(op.getAfterArguments().begin(),
|
SmallVector<BlockArgument, 2> origAfterArgs(op.getAfterArguments().begin(),
|
||||||
op.getAfterArguments().end());
|
op.getAfterArguments().end());
|
||||||
bool changed = false;
|
bool changed = false;
|
||||||
auto term = cast<scf::ConditionOp>(op.before().front().getTerminator());
|
scf::ConditionOp term = cast<scf::ConditionOp>(op.getBefore().front().getTerminator());
|
||||||
assert(origAfterArgs.size() == op.getResults().size());
|
assert(origAfterArgs.size() == op.getResults().size());
|
||||||
assert(origAfterArgs.size() == term.args().size());
|
assert(origAfterArgs.size() == term.getArgs().size());
|
||||||
|
|
||||||
for (auto pair : llvm::zip(op.getResults(), term.args(), origAfterArgs)) {
|
for (auto pair : llvm::zip(op.getResults(), term.getArgs(), origAfterArgs)) {
|
||||||
if (!std::get<0>(pair).use_empty()) {
|
if (!std::get<0>(pair).use_empty()) {
|
||||||
if (auto ifOp = std::get<1>(pair).getDefiningOp<scf::IfOp>()) {
|
if (auto ifOp = std::get<1>(pair).getDefiningOp<scf::IfOp>()) {
|
||||||
if (ifOp.condition() == term.condition()) {
|
if (ifOp.getCondition() == term.getCondition()) {
|
||||||
ssize_t idx = -1;
|
ssize_t idx = -1;
|
||||||
for (auto tup : llvm::enumerate(ifOp.results())) {
|
for (auto tup : llvm::enumerate(ifOp.getResults())) {
|
||||||
if (tup.value() == std::get<1>(pair)) {
|
if (tup.value() == std::get<1>(pair)) {
|
||||||
idx = tup.index();
|
idx = tup.index();
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
assert(idx != -1);
|
assert(idx != -1);
|
||||||
Value returnWith = ifOp.elseYield().results()[idx];
|
Value returnWith = ifOp.elseYield().getResults()[idx];
|
||||||
if (!op.before().isAncestor(returnWith.getParentRegion())) {
|
if (!op.getBefore().isAncestor(returnWith.getParentRegion())) {
|
||||||
rewriter.updateRootInPlace(op, [&] {
|
rewriter.updateRootInPlace(op, [&] {
|
||||||
std::get<0>(pair).replaceAllUsesWith(returnWith);
|
std::get<0>(pair).replaceAllUsesWith(returnWith);
|
||||||
});
|
});
|
||||||
|
@ -1042,12 +1042,12 @@ struct WhileLogicalNegation : public OpRewritePattern<WhileOp> {
|
||||||
SmallVector<BlockArgument, 2> origAfterArgs(op.getAfterArguments().begin(),
|
SmallVector<BlockArgument, 2> origAfterArgs(op.getAfterArguments().begin(),
|
||||||
op.getAfterArguments().end());
|
op.getAfterArguments().end());
|
||||||
bool changed = false;
|
bool changed = false;
|
||||||
auto term = cast<scf::ConditionOp>(op.before().front().getTerminator());
|
scf::ConditionOp term = cast<scf::ConditionOp>(op.getBefore().front().getTerminator());
|
||||||
assert(origAfterArgs.size() == op.getResults().size());
|
assert(origAfterArgs.size() == op.getResults().size());
|
||||||
assert(origAfterArgs.size() == term.args().size());
|
assert(origAfterArgs.size() == term.getArgs().size());
|
||||||
|
|
||||||
if (auto condCmp = term.condition().getDefiningOp<CmpIOp>()) {
|
if (auto condCmp = term.getCondition().getDefiningOp<CmpIOp>()) {
|
||||||
for (auto pair : llvm::zip(op.getResults(), term.args(), origAfterArgs)) {
|
for (auto pair : llvm::zip(op.getResults(), term.getArgs(), origAfterArgs)) {
|
||||||
if (!std::get<0>(pair).use_empty()) {
|
if (!std::get<0>(pair).use_empty()) {
|
||||||
if (auto termCmp = std::get<1>(pair).getDefiningOp<CmpIOp>()) {
|
if (auto termCmp = std::get<1>(pair).getDefiningOp<CmpIOp>()) {
|
||||||
if (termCmp.getLhs() == condCmp.getLhs() &&
|
if (termCmp.getLhs() == condCmp.getLhs() &&
|
||||||
|
@ -1082,41 +1082,41 @@ struct WhileCmpOffset : public OpRewritePattern<WhileOp> {
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
SmallVector<BlockArgument, 2> origAfterArgs(op.getAfterArguments().begin(),
|
SmallVector<BlockArgument, 2> origAfterArgs(op.getAfterArguments().begin(),
|
||||||
op.getAfterArguments().end());
|
op.getAfterArguments().end());
|
||||||
auto term = cast<scf::ConditionOp>(op.before().front().getTerminator());
|
scf::ConditionOp term = cast<scf::ConditionOp>(op.getBefore().front().getTerminator());
|
||||||
assert(origAfterArgs.size() == op.getResults().size());
|
assert(origAfterArgs.size() == op.getResults().size());
|
||||||
assert(origAfterArgs.size() == term.args().size());
|
assert(origAfterArgs.size() == term.getArgs().size());
|
||||||
|
|
||||||
if (auto condCmp = term.condition().getDefiningOp<CmpIOp>()) {
|
if (auto condCmp = term.getCondition().getDefiningOp<CmpIOp>()) {
|
||||||
if (auto addI = condCmp.getLhs().getDefiningOp<AddIOp>()) {
|
if (auto addI = condCmp.getLhs().getDefiningOp<AddIOp>()) {
|
||||||
if (addI.getOperand(1).getDefiningOp() &&
|
if (addI.getOperand(1).getDefiningOp() &&
|
||||||
!op.before().isAncestor(
|
!op.getBefore().isAncestor(
|
||||||
addI.getOperand(1).getDefiningOp()->getParentRegion()))
|
addI.getOperand(1).getDefiningOp()->getParentRegion()))
|
||||||
if (auto blockArg = addI.getOperand(0).dyn_cast<BlockArgument>()) {
|
if (auto blockArg = addI.getOperand(0).dyn_cast<BlockArgument>()) {
|
||||||
if (blockArg.getOwner() == &op.before().front()) {
|
if (blockArg.getOwner() == &op.getBefore().front()) {
|
||||||
auto rng = llvm::make_early_inc_range(blockArg.getUses());
|
auto rng = llvm::make_early_inc_range(blockArg.getUses());
|
||||||
|
|
||||||
{
|
{
|
||||||
rewriter.setInsertionPoint(op);
|
rewriter.setInsertionPoint(op);
|
||||||
SmallVector<Value> oldInits = op.inits();
|
SmallVector<Value> oldInits = op.getInits();
|
||||||
oldInits[blockArg.getArgNumber()] = rewriter.create<AddIOp>(
|
oldInits[blockArg.getArgNumber()] = rewriter.create<AddIOp>(
|
||||||
addI.getLoc(), oldInits[blockArg.getArgNumber()],
|
addI.getLoc(), oldInits[blockArg.getArgNumber()],
|
||||||
addI.getOperand(1));
|
addI.getOperand(1));
|
||||||
op.initsMutable().assign(oldInits);
|
op.getInitsMutable().assign(oldInits);
|
||||||
rewriter.updateRootInPlace(
|
rewriter.updateRootInPlace(
|
||||||
addI, [&] { addI.replaceAllUsesWith(blockArg); });
|
addI, [&] { addI.replaceAllUsesWith(blockArg); });
|
||||||
}
|
}
|
||||||
|
|
||||||
YieldOp afterYield = cast<YieldOp>(op.after().front().back());
|
YieldOp afterYield = cast<YieldOp>(op.getAfter().front().back());
|
||||||
rewriter.setInsertionPoint(afterYield);
|
rewriter.setInsertionPoint(afterYield);
|
||||||
SmallVector<Value> oldYields = afterYield.results();
|
SmallVector<Value> oldYields = afterYield.getResults();
|
||||||
oldYields[blockArg.getArgNumber()] = rewriter.create<AddIOp>(
|
oldYields[blockArg.getArgNumber()] = rewriter.create<AddIOp>(
|
||||||
addI.getLoc(), oldYields[blockArg.getArgNumber()],
|
addI.getLoc(), oldYields[blockArg.getArgNumber()],
|
||||||
addI.getOperand(1));
|
addI.getOperand(1));
|
||||||
rewriter.updateRootInPlace(afterYield, [&] {
|
rewriter.updateRootInPlace(afterYield, [&] {
|
||||||
afterYield.resultsMutable().assign(oldYields);
|
afterYield.getResultsMutable().assign(oldYields);
|
||||||
});
|
});
|
||||||
|
|
||||||
rewriter.setInsertionPointToStart(&op.before().front());
|
rewriter.setInsertionPointToStart(&op.getBefore().front());
|
||||||
auto sub = rewriter.create<SubIOp>(addI.getLoc(), blockArg,
|
auto sub = rewriter.create<SubIOp>(addI.getLoc(), blockArg,
|
||||||
addI.getOperand(1));
|
addI.getOperand(1));
|
||||||
for (OpOperand &use : rng) {
|
for (OpOperand &use : rng) {
|
||||||
|
@ -1139,7 +1139,7 @@ struct MoveWhileDown3 : public OpRewritePattern<WhileOp> {
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(WhileOp op,
|
LogicalResult matchAndRewrite(WhileOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
auto term = cast<scf::ConditionOp>(op.before().front().getTerminator());
|
scf::ConditionOp term = cast<scf::ConditionOp>(op.getBefore().front().getTerminator());
|
||||||
SmallVector<unsigned, 2> toErase;
|
SmallVector<unsigned, 2> toErase;
|
||||||
SmallVector<Value, 2> newOps;
|
SmallVector<Value, 2> newOps;
|
||||||
SmallVector<Value, 2> condOps;
|
SmallVector<Value, 2> condOps;
|
||||||
|
@ -1147,8 +1147,8 @@ struct MoveWhileDown3 : public OpRewritePattern<WhileOp> {
|
||||||
op.getAfterArguments().end());
|
op.getAfterArguments().end());
|
||||||
SmallVector<Value, 2> returns;
|
SmallVector<Value, 2> returns;
|
||||||
assert(origAfterArgs.size() == op.getResults().size());
|
assert(origAfterArgs.size() == op.getResults().size());
|
||||||
assert(origAfterArgs.size() == term.args().size());
|
assert(origAfterArgs.size() == term.getArgs().size());
|
||||||
for (auto pair : llvm::zip(op.getResults(), term.args(), origAfterArgs)) {
|
for (auto pair : llvm::zip(op.getResults(), term.getArgs(), origAfterArgs)) {
|
||||||
if (std::get<0>(pair).use_empty()) {
|
if (std::get<0>(pair).use_empty()) {
|
||||||
if (std::get<2>(pair).use_empty()) {
|
if (std::get<2>(pair).use_empty()) {
|
||||||
toErase.push_back(std::get<2>(pair).getArgNumber());
|
toErase.push_back(std::get<2>(pair).getArgNumber());
|
||||||
|
@ -1162,9 +1162,9 @@ struct MoveWhileDown3 : public OpRewritePattern<WhileOp> {
|
||||||
Operation *cloned = std::get<1>(pair).getDefiningOp();
|
Operation *cloned = std::get<1>(pair).getDefiningOp();
|
||||||
if (!std::get<1>(pair).hasOneUse()) {
|
if (!std::get<1>(pair).hasOneUse()) {
|
||||||
cloned = std::get<1>(pair).getDefiningOp()->clone();
|
cloned = std::get<1>(pair).getDefiningOp()->clone();
|
||||||
op.after().front().push_front(cloned);
|
op.getAfter().front().push_front(cloned);
|
||||||
} else {
|
} else {
|
||||||
cloned->moveBefore(&op.after().front().front());
|
cloned->moveBefore(&op.getAfter().front().front());
|
||||||
}
|
}
|
||||||
rewriter.updateRootInPlace(std::get<1>(pair).getDefiningOp(), [&] {
|
rewriter.updateRootInPlace(std::get<1>(pair).getDefiningOp(), [&] {
|
||||||
std::get<2>(pair).replaceAllUsesWith(cloned->getResult(0));
|
std::get<2>(pair).replaceAllUsesWith(cloned->getResult(0));
|
||||||
|
@ -1174,7 +1174,7 @@ struct MoveWhileDown3 : public OpRewritePattern<WhileOp> {
|
||||||
llvm::make_early_inc_range(cloned->getOpOperands())) {
|
llvm::make_early_inc_range(cloned->getOpOperands())) {
|
||||||
{
|
{
|
||||||
newOps.push_back(o.get());
|
newOps.push_back(o.get());
|
||||||
o.set(op.after().front().addArgument(o.get().getType()));
|
o.set(op.getAfter().front().addArgument(o.get().getType()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
|
@ -1190,19 +1190,19 @@ struct MoveWhileDown3 : public OpRewritePattern<WhileOp> {
|
||||||
condOps.append(newOps.begin(), newOps.end());
|
condOps.append(newOps.begin(), newOps.end());
|
||||||
|
|
||||||
rewriter.updateRootInPlace(
|
rewriter.updateRootInPlace(
|
||||||
term, [&] { op.after().front().eraseArguments(toErase); });
|
term, [&] { op.getAfter().front().eraseArguments(toErase); });
|
||||||
rewriter.setInsertionPoint(term);
|
rewriter.setInsertionPoint(term);
|
||||||
rewriter.replaceOpWithNewOp<ConditionOp>(term, term.condition(), condOps);
|
rewriter.replaceOpWithNewOp<ConditionOp>(term, term.getCondition(), condOps);
|
||||||
|
|
||||||
rewriter.setInsertionPoint(op);
|
rewriter.setInsertionPoint(op);
|
||||||
SmallVector<Type, 4> resultTypes;
|
SmallVector<Type, 4> resultTypes;
|
||||||
for (auto v : condOps) {
|
for (auto v : condOps) {
|
||||||
resultTypes.push_back(v.getType());
|
resultTypes.push_back(v.getType());
|
||||||
}
|
}
|
||||||
auto nop = rewriter.create<WhileOp>(op.getLoc(), resultTypes, op.inits());
|
auto nop = rewriter.create<WhileOp>(op.getLoc(), resultTypes, op.getInits());
|
||||||
|
|
||||||
nop.before().takeBody(op.before());
|
nop.getBefore().takeBody(op.getBefore());
|
||||||
nop.after().takeBody(op.after());
|
nop.getAfter().takeBody(op.getAfter());
|
||||||
|
|
||||||
rewriter.updateRootInPlace(op, [&] {
|
rewriter.updateRootInPlace(op, [&] {
|
||||||
for (auto pair : llvm::enumerate(returns)) {
|
for (auto pair : llvm::enumerate(returns)) {
|
||||||
|
@ -1210,7 +1210,7 @@ struct MoveWhileDown3 : public OpRewritePattern<WhileOp> {
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
assert(resultTypes.size() == nop.after().front().getNumArguments());
|
assert(resultTypes.size() == nop.getAfter().front().getNumArguments());
|
||||||
assert(resultTypes.size() == condOps.size());
|
assert(resultTypes.size() == condOps.size());
|
||||||
|
|
||||||
rewriter.eraseOp(op);
|
rewriter.eraseOp(op);
|
||||||
|
@ -1269,7 +1269,7 @@ struct WhileLICM : public OpRewritePattern<WhileOp> {
|
||||||
auto definingOp = value.getDefiningOp();
|
auto definingOp = value.getDefiningOp();
|
||||||
bool definedOutside =
|
bool definedOutside =
|
||||||
(definingOp && !!willBeMovedSet.count(definingOp)) ||
|
(definingOp && !!willBeMovedSet.count(definingOp)) ||
|
||||||
!op.before().isAncestor(value.getParentRegion());
|
!op.getBefore().isAncestor(value.getParentRegion());
|
||||||
return definedOutside;
|
return definedOutside;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -1277,7 +1277,7 @@ struct WhileLICM : public OpRewritePattern<WhileOp> {
|
||||||
// hoist operations from there. These regions might have semantics unknown
|
// hoist operations from there. These regions might have semantics unknown
|
||||||
// to this rewriting. If the nested regions are loops, they will have been
|
// to this rewriting. If the nested regions are loops, they will have been
|
||||||
// processed.
|
// processed.
|
||||||
for (auto &block : op.before()) {
|
for (auto &block : op.getBefore()) {
|
||||||
for (auto &op : block.without_terminator()) {
|
for (auto &op : block.without_terminator()) {
|
||||||
bool legal = canBeHoisted(&op, isDefinedOutsideOfBody);
|
bool legal = canBeHoisted(&op, isDefinedOutsideOfBody);
|
||||||
if (legal) {
|
if (legal) {
|
||||||
|
@ -1299,7 +1299,7 @@ struct RemoveUnusedCondVar : public OpRewritePattern<WhileOp> {
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(WhileOp op,
|
LogicalResult matchAndRewrite(WhileOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
auto term = cast<scf::ConditionOp>(op.before().front().getTerminator());
|
auto term = cast<scf::ConditionOp>(op.getBefore().front().getTerminator());
|
||||||
SmallVector<Value, 4> conds;
|
SmallVector<Value, 4> conds;
|
||||||
SmallVector<unsigned, 4> eraseArgs;
|
SmallVector<unsigned, 4> eraseArgs;
|
||||||
SmallVector<unsigned, 4> keepArgs;
|
SmallVector<unsigned, 4> keepArgs;
|
||||||
|
@ -1308,7 +1308,7 @@ struct RemoveUnusedCondVar : public OpRewritePattern<WhileOp> {
|
||||||
std::map<void *, unsigned> valueOffsets;
|
std::map<void *, unsigned> valueOffsets;
|
||||||
std::map<unsigned, unsigned> resultOffsets;
|
std::map<unsigned, unsigned> resultOffsets;
|
||||||
SmallVector<Value, 4> resultArgs;
|
SmallVector<Value, 4> resultArgs;
|
||||||
for (auto pair : llvm::zip(term.args(), op.after().front().getArguments(),
|
for (auto pair : llvm::zip(term.getArgs(), op.getAfter().front().getArguments(),
|
||||||
op.getResults())) {
|
op.getResults())) {
|
||||||
auto arg = std::get<0>(pair);
|
auto arg = std::get<0>(pair);
|
||||||
auto afarg = std::get<1>(pair);
|
auto afarg = std::get<1>(pair);
|
||||||
|
@ -1331,24 +1331,24 @@ struct RemoveUnusedCondVar : public OpRewritePattern<WhileOp> {
|
||||||
}
|
}
|
||||||
i++;
|
i++;
|
||||||
}
|
}
|
||||||
assert(i == op.after().front().getArguments().size());
|
assert(i == op.getAfter().front().getArguments().size());
|
||||||
|
|
||||||
if (eraseArgs.size() != 0) {
|
if (eraseArgs.size() != 0) {
|
||||||
|
|
||||||
rewriter.setInsertionPoint(term);
|
rewriter.setInsertionPoint(term);
|
||||||
rewriter.replaceOpWithNewOp<scf::ConditionOp>(term, term.condition(),
|
rewriter.replaceOpWithNewOp<scf::ConditionOp>(term, term.getCondition(),
|
||||||
conds);
|
conds);
|
||||||
|
|
||||||
rewriter.setInsertionPoint(op);
|
rewriter.setInsertionPoint(op);
|
||||||
auto op2 = rewriter.create<WhileOp>(op.getLoc(), tys, op.inits());
|
auto op2 = rewriter.create<WhileOp>(op.getLoc(), tys, op.getInits());
|
||||||
|
|
||||||
op2.before().takeBody(op.before());
|
op2.getBefore().takeBody(op.getBefore());
|
||||||
op2.after().takeBody(op.after());
|
op2.getAfter().takeBody(op.getAfter());
|
||||||
for (auto pair : resultOffsets) {
|
for (auto pair : resultOffsets) {
|
||||||
op.getResult(pair.first).replaceAllUsesWith(op2.getResult(pair.second));
|
op.getResult(pair.first).replaceAllUsesWith(op2.getResult(pair.second));
|
||||||
}
|
}
|
||||||
rewriter.eraseOp(op);
|
rewriter.eraseOp(op);
|
||||||
op2.after().front().eraseArguments(eraseArgs);
|
op2.getAfter().front().eraseArguments(eraseArgs);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
return failure();
|
return failure();
|
||||||
|
@ -1360,19 +1360,19 @@ struct MoveSideEffectFreeWhile : public OpRewritePattern<WhileOp> {
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(WhileOp op,
|
LogicalResult matchAndRewrite(WhileOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
auto term = cast<scf::ConditionOp>(op.before().front().getTerminator());
|
scf::ConditionOp term = cast<scf::ConditionOp>(op.getBefore().front().getTerminator());
|
||||||
SmallVector<Value, 4> conds(term.args().begin(), term.args().end());
|
SmallVector<Value, 4> conds(term.getArgs().begin(), term.getArgs().end());
|
||||||
bool changed = false;
|
bool changed = false;
|
||||||
unsigned i = 0;
|
unsigned i = 0;
|
||||||
for (auto arg : term.args()) {
|
for (auto arg : term.getArgs()) {
|
||||||
if (auto IC = arg.getDefiningOp<IndexCastOp>()) {
|
if (auto IC = arg.getDefiningOp<IndexCastOp>()) {
|
||||||
if (arg.hasOneUse() && op.getResult(i).use_empty()) {
|
if (arg.hasOneUse() && op.getResult(i).use_empty()) {
|
||||||
auto rep =
|
auto rep =
|
||||||
op.after().front().addArgument(IC->getOperand(0).getType());
|
op.getAfter().front().addArgument(IC->getOperand(0).getType());
|
||||||
IC->moveBefore(&op.after().front(), op.after().front().begin());
|
IC->moveBefore(&op.getAfter().front(), op.getAfter().front().begin());
|
||||||
conds.push_back(IC.getIn());
|
conds.push_back(IC.getIn());
|
||||||
IC.getInMutable().assign(rep);
|
IC.getInMutable().assign(rep);
|
||||||
op.after().front().getArgument(i).replaceAllUsesWith(
|
op.getAfter().front().getArgument(i).replaceAllUsesWith(
|
||||||
IC->getResult(0));
|
IC->getResult(0));
|
||||||
changed = true;
|
changed = true;
|
||||||
}
|
}
|
||||||
|
@ -1384,9 +1384,9 @@ struct MoveSideEffectFreeWhile : public OpRewritePattern<WhileOp> {
|
||||||
for (auto arg : conds) {
|
for (auto arg : conds) {
|
||||||
tys.push_back(arg.getType());
|
tys.push_back(arg.getType());
|
||||||
}
|
}
|
||||||
auto op2 = rewriter.create<WhileOp>(op.getLoc(), tys, op.inits());
|
auto op2 = rewriter.create<WhileOp>(op.getLoc(), tys, op.getInits());
|
||||||
op2.before().takeBody(op.before());
|
op2.getBefore().takeBody(op.getBefore());
|
||||||
op2.after().takeBody(op.after());
|
op2.getAfter().takeBody(op.getAfter());
|
||||||
unsigned j = 0;
|
unsigned j = 0;
|
||||||
for (auto a : op.getResults()) {
|
for (auto a : op.getResults()) {
|
||||||
a.replaceAllUsesWith(op2.getResult(j));
|
a.replaceAllUsesWith(op2.getResult(j));
|
||||||
|
@ -1394,7 +1394,7 @@ struct MoveSideEffectFreeWhile : public OpRewritePattern<WhileOp> {
|
||||||
}
|
}
|
||||||
rewriter.eraseOp(op);
|
rewriter.eraseOp(op);
|
||||||
rewriter.setInsertionPoint(term);
|
rewriter.setInsertionPoint(term);
|
||||||
rewriter.replaceOpWithNewOp<scf::ConditionOp>(term, term.condition(),
|
rewriter.replaceOpWithNewOp<scf::ConditionOp>(term, term.getCondition(),
|
||||||
conds);
|
conds);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
|
@ -295,8 +295,8 @@ bool LoopRestructure::removeIfFromRegion(DominanceInfo &domInfo, Region ®ion,
|
||||||
/*hasElse*/ true);
|
/*hasElse*/ true);
|
||||||
Succs[j] = new Block();
|
Succs[j] = new Block();
|
||||||
if (j == 0) {
|
if (j == 0) {
|
||||||
ifOp.elseRegion().getBlocks().splice(
|
ifOp.getElseRegion().getBlocks().splice(
|
||||||
ifOp.elseRegion().getBlocks().end(), region.getBlocks(),
|
ifOp.getElseRegion().getBlocks().end(), region.getBlocks(),
|
||||||
Succs[1 - j]);
|
Succs[1 - j]);
|
||||||
SmallVector<unsigned, 4> idx;
|
SmallVector<unsigned, 4> idx;
|
||||||
for (size_t i = 0; i < Succs[1 - j]->getNumArguments(); ++i) {
|
for (size_t i = 0; i < Succs[1 - j]->getNumArguments(); ++i) {
|
||||||
|
@ -305,18 +305,18 @@ bool LoopRestructure::removeIfFromRegion(DominanceInfo &domInfo, Region ®ion,
|
||||||
idx.push_back(i);
|
idx.push_back(i);
|
||||||
}
|
}
|
||||||
Succs[1 - j]->eraseArguments(idx);
|
Succs[1 - j]->eraseArguments(idx);
|
||||||
assert(!ifOp.elseRegion().getBlocks().empty());
|
assert(!ifOp.getElseRegion().getBlocks().empty());
|
||||||
assert(condTys.size() == condBr.getTrueOperands().size());
|
assert(condTys.size() == condBr.getTrueOperands().size());
|
||||||
OpBuilder tbuilder(&ifOp.thenRegion().front(),
|
OpBuilder tbuilder(&ifOp.getThenRegion().front(),
|
||||||
ifOp.thenRegion().front().begin());
|
ifOp.getThenRegion().front().begin());
|
||||||
tbuilder.create<scf::YieldOp>(tbuilder.getUnknownLoc(), emptyTys,
|
tbuilder.create<scf::YieldOp>(tbuilder.getUnknownLoc(), emptyTys,
|
||||||
condBr.getTrueOperands());
|
condBr.getTrueOperands());
|
||||||
} else {
|
} else {
|
||||||
if (!ifOp.thenRegion().getBlocks().empty()) {
|
if (!ifOp.getThenRegion().getBlocks().empty()) {
|
||||||
ifOp.thenRegion().front().erase();
|
ifOp.getThenRegion().front().erase();
|
||||||
}
|
}
|
||||||
ifOp.thenRegion().getBlocks().splice(
|
ifOp.getThenRegion().getBlocks().splice(
|
||||||
ifOp.thenRegion().getBlocks().end(), region.getBlocks(),
|
ifOp.getThenRegion().getBlocks().end(), region.getBlocks(),
|
||||||
Succs[1 - j]);
|
Succs[1 - j]);
|
||||||
SmallVector<unsigned, 4> idx;
|
SmallVector<unsigned, 4> idx;
|
||||||
for (size_t i = 0; i < Succs[1 - j]->getNumArguments(); ++i) {
|
for (size_t i = 0; i < Succs[1 - j]->getNumArguments(); ++i) {
|
||||||
|
@ -325,9 +325,9 @@ bool LoopRestructure::removeIfFromRegion(DominanceInfo &domInfo, Region ®ion,
|
||||||
idx.push_back(i);
|
idx.push_back(i);
|
||||||
}
|
}
|
||||||
Succs[1 - j]->eraseArguments(idx);
|
Succs[1 - j]->eraseArguments(idx);
|
||||||
assert(!ifOp.elseRegion().getBlocks().empty());
|
assert(!ifOp.getElseRegion().getBlocks().empty());
|
||||||
OpBuilder tbuilder(&ifOp.elseRegion().front(),
|
OpBuilder tbuilder(&ifOp.getElseRegion().front(),
|
||||||
ifOp.elseRegion().front().begin());
|
ifOp.getElseRegion().front().begin());
|
||||||
assert(condTys.size() == condBr.getFalseOperands().size());
|
assert(condTys.size() == condBr.getFalseOperands().size());
|
||||||
tbuilder.create<scf::YieldOp>(tbuilder.getUnknownLoc(), emptyTys,
|
tbuilder.create<scf::YieldOp>(tbuilder.getUnknownLoc(), emptyTys,
|
||||||
condBr.getFalseOperands());
|
condBr.getFalseOperands());
|
||||||
|
@ -449,12 +449,12 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region ®ion) {
|
||||||
Preds.push_back(block);
|
Preds.push_back(block);
|
||||||
}
|
}
|
||||||
|
|
||||||
loop.before().getBlocks().splice(loop.before().getBlocks().begin(),
|
loop.getBefore().getBlocks().splice(loop.getBefore().getBlocks().begin(),
|
||||||
region.getBlocks(), header);
|
region.getBlocks(), header);
|
||||||
for (auto *w : L->getBlocks()) {
|
for (auto *w : L->getBlocks()) {
|
||||||
Block *b = &**w;
|
Block *b = &**w;
|
||||||
if (b != header) {
|
if (b != header) {
|
||||||
loop.before().getBlocks().splice(loop.before().getBlocks().end(),
|
loop.getBefore().getBlocks().splice(loop.getBefore().getBlocks().end(),
|
||||||
region.getBlocks(), b);
|
region.getBlocks(), b);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -462,7 +462,7 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region ®ion) {
|
||||||
Block *pseudoExit = new Block();
|
Block *pseudoExit = new Block();
|
||||||
auto i1Ty = builder.getI1Type();
|
auto i1Ty = builder.getI1Type();
|
||||||
{
|
{
|
||||||
loop.before().push_back(pseudoExit);
|
loop.getBefore().push_back(pseudoExit);
|
||||||
SmallVector<Type, 4> tys = {i1Ty};
|
SmallVector<Type, 4> tys = {i1Ty};
|
||||||
for (auto t : combinedTypes)
|
for (auto t : combinedTypes)
|
||||||
tys.push_back(t);
|
tys.push_back(t);
|
||||||
|
@ -592,7 +592,7 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region ®ion) {
|
||||||
|
|
||||||
Block *after = new Block();
|
Block *after = new Block();
|
||||||
after->addArguments(combinedTypes);
|
after->addArguments(combinedTypes);
|
||||||
loop.after().push_back(after);
|
loop.getAfter().push_back(after);
|
||||||
OpBuilder builder2(after, after->begin());
|
OpBuilder builder2(after, after->begin());
|
||||||
SmallVector<Value, 4> yieldargs;
|
SmallVector<Value, 4> yieldargs;
|
||||||
for (auto a : after->getArguments()) {
|
for (auto a : after->getArguments()) {
|
||||||
|
@ -619,42 +619,42 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region ®ion) {
|
||||||
}
|
}
|
||||||
|
|
||||||
builder2.create<scf::YieldOp>(builder.getUnknownLoc(), yieldargs);
|
builder2.create<scf::YieldOp>(builder.getUnknownLoc(), yieldargs);
|
||||||
domInfo.invalidate(&loop.before());
|
domInfo.invalidate(&loop.getBefore());
|
||||||
runOnRegion(domInfo, loop.before());
|
runOnRegion(domInfo, loop.getBefore());
|
||||||
if (!removeIfFromRegion(domInfo, loop.before(), pseudoExit)) {
|
if (!removeIfFromRegion(domInfo, loop.getBefore(), pseudoExit)) {
|
||||||
attemptToFoldIntoPredecessor(pseudoExit);
|
attemptToFoldIntoPredecessor(pseudoExit);
|
||||||
}
|
}
|
||||||
|
|
||||||
attemptToFoldIntoPredecessor(wrapper);
|
attemptToFoldIntoPredecessor(wrapper);
|
||||||
attemptToFoldIntoPredecessor(target);
|
attemptToFoldIntoPredecessor(target);
|
||||||
if (loop.before().getBlocks().size() != 1) {
|
if (loop.getBefore().getBlocks().size() != 1) {
|
||||||
Block *blk = new Block();
|
Block *blk = new Block();
|
||||||
OpBuilder B(loop.getContext());
|
OpBuilder B(loop.getContext());
|
||||||
B.setInsertionPointToEnd(blk);
|
B.setInsertionPointToEnd(blk);
|
||||||
auto cop =
|
auto cop =
|
||||||
cast<scf::ConditionOp>(loop.before().getBlocks().back().back());
|
cast<scf::ConditionOp>(loop.getBefore().getBlocks().back().back());
|
||||||
auto er = B.create<scf::ExecuteRegionOp>(loop.getLoc(),
|
auto er = B.create<scf::ExecuteRegionOp>(loop.getLoc(),
|
||||||
cop.getOperandTypes());
|
cop.getOperandTypes());
|
||||||
er.region().getBlocks().splice(er.region().getBlocks().begin(),
|
er.getRegion().getBlocks().splice(er.getRegion().getBlocks().begin(),
|
||||||
loop.before().getBlocks());
|
loop.getBefore().getBlocks());
|
||||||
loop.before().push_back(blk);
|
loop.getBefore().push_back(blk);
|
||||||
SmallVector<Value> yields;
|
SmallVector<Value> yields;
|
||||||
for (auto a : er.getResults())
|
for (auto a : er.getResults())
|
||||||
yields.push_back(a);
|
yields.push_back(a);
|
||||||
yields.erase(yields.begin());
|
yields.erase(yields.begin());
|
||||||
B.create<scf::ConditionOp>(cop.getLoc(), er.getResult(0), yields);
|
B.create<scf::ConditionOp>(cop.getLoc(), er.getResult(0), yields);
|
||||||
B.setInsertionPoint(&*cop);
|
B.setInsertionPoint(&*cop);
|
||||||
for (auto arg : er.region().front().getArguments()) {
|
for (auto arg : er.getRegion().front().getArguments()) {
|
||||||
auto na = blk->addArgument(arg.getType());
|
auto na = blk->addArgument(arg.getType());
|
||||||
arg.replaceAllUsesWith(na);
|
arg.replaceAllUsesWith(na);
|
||||||
}
|
}
|
||||||
er.region().front().eraseArguments([](BlockArgument) { return true; });
|
er.getRegion().front().eraseArguments([](BlockArgument) { return true; });
|
||||||
B.create<scf::YieldOp>(cop.getLoc(), cop.getOperands());
|
B.create<scf::YieldOp>(cop.getLoc(), cop.getOperands());
|
||||||
cop.erase();
|
cop.erase();
|
||||||
}
|
}
|
||||||
assert(loop.before().getBlocks().size() == 1);
|
assert(loop.getBefore().getBlocks().size() == 1);
|
||||||
runOnRegion(domInfo, loop.after());
|
runOnRegion(domInfo, loop.getAfter());
|
||||||
assert(loop.after().getBlocks().size() == 1);
|
assert(loop.getAfter().getBlocks().size() == 1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -453,7 +453,7 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector<ssize_t> idx,
|
||||||
for (auto a : ops) {
|
for (auto a : ops) {
|
||||||
if (StoringOperations.count(a)) {
|
if (StoringOperations.count(a)) {
|
||||||
if (auto exOp = dyn_cast<mlir::scf::ExecuteRegionOp>(a)) {
|
if (auto exOp = dyn_cast<mlir::scf::ExecuteRegionOp>(a)) {
|
||||||
valueAtStartOfBlock[&*exOp.region().begin()] = lastVal;
|
valueAtStartOfBlock[&*exOp.getRegion().begin()] = lastVal;
|
||||||
Value thenVal; // = handleBlock(exOp.region().front(), lastVal);
|
Value thenVal; // = handleBlock(exOp.region().front(), lastVal);
|
||||||
lastVal = nullptr;
|
lastVal = nullptr;
|
||||||
seenSubStore = true;
|
seenSubStore = true;
|
||||||
|
@ -477,7 +477,7 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector<ssize_t> idx,
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
Block &then = exOp.region().back();
|
Block &then = exOp.getRegion().back();
|
||||||
OpBuilder B(exOp.getContext());
|
OpBuilder B(exOp.getContext());
|
||||||
auto yieldOp = cast<mlir::scf::YieldOp>(then.back());
|
auto yieldOp = cast<mlir::scf::YieldOp>(then.back());
|
||||||
B.setInsertionPoint(yieldOp);
|
B.setInsertionPoint(yieldOp);
|
||||||
|
@ -502,13 +502,13 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector<ssize_t> idx,
|
||||||
auto nextIf =
|
auto nextIf =
|
||||||
B.create<mlir::scf::ExecuteRegionOp>(exOp.getLoc(), tys);
|
B.create<mlir::scf::ExecuteRegionOp>(exOp.getLoc(), tys);
|
||||||
|
|
||||||
SmallVector<mlir::Value, 4> thenVals = yieldOp.results();
|
SmallVector<mlir::Value, 4> thenVals = yieldOp.getResults();
|
||||||
thenVals.push_back(thenVal);
|
thenVals.push_back(thenVal);
|
||||||
yieldOp->setOperands(thenVals);
|
yieldOp->setOperands(thenVals);
|
||||||
nextIf.region().getBlocks().clear();
|
nextIf.getRegion().getBlocks().clear();
|
||||||
nextIf.region().getBlocks().splice(
|
nextIf.getRegion().getBlocks().splice(
|
||||||
nextIf.region().getBlocks().begin(),
|
nextIf.getRegion().getBlocks().begin(),
|
||||||
exOp.region().getBlocks());
|
exOp.getRegion().getBlocks());
|
||||||
|
|
||||||
SmallVector<mlir::Value, 3> resvals = (nextIf.getResults());
|
SmallVector<mlir::Value, 3> resvals = (nextIf.getResults());
|
||||||
lastVal = resvals.back();
|
lastVal = resvals.back();
|
||||||
|
@ -559,15 +559,15 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector<ssize_t> idx,
|
||||||
lastVal = newLoad->getResult(0);
|
lastVal = newLoad->getResult(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
valueAtStartOfBlock[&*ifOp.thenRegion().begin()] = lastVal;
|
valueAtStartOfBlock[&*ifOp.getThenRegion().begin()] = lastVal;
|
||||||
mlir::Value thenVal =
|
mlir::Value thenVal =
|
||||||
handleBlock(*ifOp.thenRegion().begin(), lastVal);
|
handleBlock(*ifOp.getThenRegion().begin(), lastVal);
|
||||||
|
|
||||||
if (lastVal && ifOp.elseRegion().getBlocks().size())
|
if (lastVal && ifOp.getElseRegion().getBlocks().size())
|
||||||
valueAtStartOfBlock[&*ifOp.elseRegion().begin()] = lastVal;
|
valueAtStartOfBlock[&*ifOp.getElseRegion().begin()] = lastVal;
|
||||||
mlir::Value elseVal =
|
mlir::Value elseVal =
|
||||||
(ifOp.elseRegion().getBlocks().size())
|
(ifOp.getElseRegion().getBlocks().size())
|
||||||
? handleBlock(*ifOp.elseRegion().begin(), lastVal)
|
? handleBlock(*ifOp.getElseRegion().begin(), lastVal)
|
||||||
: lastVal;
|
: lastVal;
|
||||||
|
|
||||||
if (thenVal == elseVal && thenVal != nullptr) {
|
if (thenVal == elseVal && thenVal != nullptr) {
|
||||||
|
@ -582,41 +582,37 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector<ssize_t> idx,
|
||||||
ifOp.getResultTypes().end());
|
ifOp.getResultTypes().end());
|
||||||
tys.push_back(thenVal.getType());
|
tys.push_back(thenVal.getType());
|
||||||
auto nextIf = B.create<mlir::scf::IfOp>(
|
auto nextIf = B.create<mlir::scf::IfOp>(
|
||||||
ifOp.getLoc(), tys, ifOp.condition(), /*hasElse*/ true);
|
ifOp.getLoc(), tys, ifOp.getCondition(), /*hasElse*/ true);
|
||||||
|
|
||||||
Block &then = ifOp.thenRegion().back();
|
Block &then = ifOp.getThenRegion().back();
|
||||||
SmallVector<mlir::Value, 4> thenVals =
|
SmallVector<mlir::Value, 4> thenVals =
|
||||||
cast<mlir::scf::YieldOp>(then.back()).results();
|
cast<mlir::scf::YieldOp>(then.back()).getResults();
|
||||||
thenVals.push_back(thenVal);
|
thenVals.push_back(thenVal);
|
||||||
nextIf.thenRegion().getBlocks().clear();
|
nextIf.getThenRegion().getBlocks().clear();
|
||||||
nextIf.thenRegion().getBlocks().splice(
|
nextIf.getThenRegion().takeBody(ifOp.getThenRegion());
|
||||||
nextIf.thenRegion().getBlocks().begin(),
|
|
||||||
ifOp.thenRegion().getBlocks());
|
|
||||||
cast<mlir::scf::YieldOp>(
|
cast<mlir::scf::YieldOp>(
|
||||||
nextIf.thenRegion().back().getTerminator())
|
nextIf.getThenRegion().back().getTerminator())
|
||||||
->setOperands(thenVals);
|
->setOperands(thenVals);
|
||||||
|
|
||||||
if (ifOp.elseRegion().getBlocks().size()) {
|
if (ifOp.getElseRegion().getBlocks().size()) {
|
||||||
nextIf.elseRegion().getBlocks().clear();
|
nextIf.getElseRegion().getBlocks().clear();
|
||||||
SmallVector<mlir::Value, 4> elseVals =
|
SmallVector<mlir::Value, 4> elseVals =
|
||||||
cast<mlir::scf::YieldOp>(ifOp.elseRegion().back().back())
|
cast<mlir::scf::YieldOp>(ifOp.getElseRegion().back().back())
|
||||||
.results();
|
.getResults();
|
||||||
elseVals.push_back(elseVal);
|
elseVals.push_back(elseVal);
|
||||||
nextIf.elseRegion().getBlocks().splice(
|
nextIf.getElseRegion().takeBody(ifOp.getElseRegion());
|
||||||
nextIf.elseRegion().getBlocks().begin(),
|
|
||||||
ifOp.elseRegion().getBlocks());
|
|
||||||
cast<mlir::scf::YieldOp>(
|
cast<mlir::scf::YieldOp>(
|
||||||
nextIf.elseRegion().back().getTerminator())
|
nextIf.getElseRegion().back().getTerminator())
|
||||||
->setOperands(elseVals);
|
->setOperands(elseVals);
|
||||||
} else {
|
} else {
|
||||||
B.setInsertionPoint(&nextIf.elseRegion().back(),
|
B.setInsertionPoint(&nextIf.getElseRegion().back(),
|
||||||
nextIf.elseRegion().back().begin());
|
nextIf.getElseRegion().back().begin());
|
||||||
SmallVector<mlir::Value, 4> elseVals;
|
SmallVector<mlir::Value, 4> elseVals;
|
||||||
elseVals.push_back(elseVal);
|
elseVals.push_back(elseVal);
|
||||||
B.create<mlir::scf::YieldOp>(ifOp.getLoc(), elseVals);
|
B.create<mlir::scf::YieldOp>(ifOp.getLoc(), elseVals);
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<mlir::Value, 3> resvals = (nextIf.results());
|
SmallVector<mlir::Value, 3> resvals = nextIf.getResults();
|
||||||
lastVal = resvals.back();
|
lastVal = resvals.back();
|
||||||
resvals.pop_back();
|
resvals.pop_back();
|
||||||
ifOp.replaceAllUsesWith(resvals);
|
ifOp.replaceAllUsesWith(resvals);
|
||||||
|
|
|
@ -101,7 +101,7 @@ struct ReplaceIfWithFors : public OpRewritePattern<scf::IfOp> {
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(scf::IfOp op,
|
LogicalResult matchAndRewrite(scf::IfOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
assert(op.condition().getType().isInteger(1));
|
assert(op.getCondition().getType().isInteger(1));
|
||||||
|
|
||||||
if (!hasNestedBarrier(op)) {
|
if (!hasNestedBarrier(op)) {
|
||||||
LLVM_DEBUG(DBGS() << "[if-to-for] no nested barrier\n");
|
LLVM_DEBUG(DBGS() << "[if-to-for] no nested barrier\n");
|
||||||
|
@ -119,7 +119,7 @@ struct ReplaceIfWithFors : public OpRewritePattern<scf::IfOp> {
|
||||||
|
|
||||||
auto cond = rewriter.create<IndexCastOp>(
|
auto cond = rewriter.create<IndexCastOp>(
|
||||||
loc, rewriter.getIndexType(),
|
loc, rewriter.getIndexType(),
|
||||||
rewriter.create<ExtUIOp>(loc, op.condition(),
|
rewriter.create<ExtUIOp>(loc, op.getCondition(),
|
||||||
mlir::IntegerType::get(one.getContext(), 64)));
|
mlir::IntegerType::get(one.getContext(), 64)));
|
||||||
auto thenLoop = rewriter.create<scf::ForOp>(loc, zero, cond, one, forArgs);
|
auto thenLoop = rewriter.create<scf::ForOp>(loc, zero, cond, one, forArgs);
|
||||||
if (forArgs.size() == 0)
|
if (forArgs.size() == 0)
|
||||||
|
@ -128,7 +128,7 @@ struct ReplaceIfWithFors : public OpRewritePattern<scf::IfOp> {
|
||||||
|
|
||||||
SmallVector<Value> vals;
|
SmallVector<Value> vals;
|
||||||
|
|
||||||
if (!op.elseRegion().empty()) {
|
if (!op.getElseRegion().empty()) {
|
||||||
auto negCondition = rewriter.create<SubIOp>(loc, one, cond);
|
auto negCondition = rewriter.create<SubIOp>(loc, one, cond);
|
||||||
scf::ForOp elseLoop = rewriter.create<scf::ForOp>(loc, zero, negCondition, one, forArgs);
|
scf::ForOp elseLoop = rewriter.create<scf::ForOp>(loc, zero, negCondition, one, forArgs);
|
||||||
if (forArgs.size() == 0)
|
if (forArgs.size() == 0)
|
||||||
|
@ -136,7 +136,7 @@ struct ReplaceIfWithFors : public OpRewritePattern<scf::IfOp> {
|
||||||
rewriter.mergeBlocks(op.getBody(1), elseLoop.getBody(0));
|
rewriter.mergeBlocks(op.getBody(1), elseLoop.getBody(0));
|
||||||
|
|
||||||
for (auto tup : llvm::zip(thenLoop.getResults(), elseLoop.getResults())) {
|
for (auto tup : llvm::zip(thenLoop.getResults(), elseLoop.getResults())) {
|
||||||
vals.push_back(rewriter.create<SelectOp>(op.getLoc(), op.condition(), std::get<0>(tup), std::get<1>(tup)));
|
vals.push_back(rewriter.create<SelectOp>(op.getLoc(), op.getCondition(), std::get<0>(tup), std::get<1>(tup)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -153,7 +153,7 @@ static bool isDefinedAbove(Value value, Operation *user) {
|
||||||
|
|
||||||
/// Returns `true` if the loop has a form expected by interchange patterns.
|
/// Returns `true` if the loop has a form expected by interchange patterns.
|
||||||
static bool isNormalized(scf::ForOp op) {
|
static bool isNormalized(scf::ForOp op) {
|
||||||
return isDefinedAbove(op.lowerBound(), op) && isDefinedAbove(op.step(), op);
|
return isDefinedAbove(op.getLowerBound(), op) && isDefinedAbove(op.getStep(), op);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Transforms a loop to the normal form expected by interchange patterns, i.e.
|
/// Transforms a loop to the normal form expected by interchange patterns, i.e.
|
||||||
|
@ -179,15 +179,15 @@ struct NormalizeLoop : public OpRewritePattern<scf::ForOp> {
|
||||||
rewriter.restoreInsertionPoint(point);
|
rewriter.restoreInsertionPoint(point);
|
||||||
|
|
||||||
Value difference =
|
Value difference =
|
||||||
rewriter.create<SubIOp>(op.getLoc(), op.upperBound(), op.lowerBound());
|
rewriter.create<SubIOp>(op.getLoc(), op.getUpperBound(), op.getLowerBound());
|
||||||
Value tripCount =
|
Value tripCount =
|
||||||
rewriter.create<CeilDivSIOp>(op.getLoc(), difference, op.step());
|
rewriter.create<CeilDivSIOp>(op.getLoc(), difference, op.getStep());
|
||||||
auto newForOp =
|
auto newForOp =
|
||||||
rewriter.create<scf::ForOp>(op.getLoc(), zero, tripCount, one);
|
rewriter.create<scf::ForOp>(op.getLoc(), zero, tripCount, one);
|
||||||
rewriter.setInsertionPointToStart(newForOp.getBody());
|
rewriter.setInsertionPointToStart(newForOp.getBody());
|
||||||
Value scaled = rewriter.create<MulIOp>(
|
Value scaled = rewriter.create<MulIOp>(
|
||||||
op.getLoc(), newForOp.getInductionVar(), op.step());
|
op.getLoc(), newForOp.getInductionVar(), op.getStep());
|
||||||
Value iv = rewriter.create<AddIOp>(op.getLoc(), op.lowerBound(), scaled);
|
Value iv = rewriter.create<AddIOp>(op.getLoc(), op.getLowerBound(), scaled);
|
||||||
rewriter.mergeBlockBefore(op.getBody(), &newForOp.getBody()->back(), {iv});
|
rewriter.mergeBlockBefore(op.getBody(), &newForOp.getBody()->back(), {iv});
|
||||||
rewriter.eraseOp(&newForOp.getBody()->back());
|
rewriter.eraseOp(&newForOp.getBody()->back());
|
||||||
rewriter.eraseOp(op);
|
rewriter.eraseOp(op);
|
||||||
|
@ -205,8 +205,8 @@ static bool isNormalized(scf::ParallelOp op) {
|
||||||
APInt value;
|
APInt value;
|
||||||
return matchPattern(v, m_ConstantInt(&value)) && value.isOneValue();
|
return matchPattern(v, m_ConstantInt(&value)) && value.isOneValue();
|
||||||
};
|
};
|
||||||
return llvm::all_of(op.lowerBound(), isZero) &&
|
return llvm::all_of(op.getLowerBound(), isZero) &&
|
||||||
llvm::all_of(op.step(), isOne);
|
llvm::all_of(op.getStep(), isOne);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Transforms a loop to the normal form expected by interchange patterns, i.e.
|
/// Transforms a loop to the normal form expected by interchange patterns, i.e.
|
||||||
|
@ -242,9 +242,9 @@ struct NormalizeParallel : public OpRewritePattern<scf::ParallelOp> {
|
||||||
rewriter.setInsertionPointToStart(newOp.getBody());
|
rewriter.setInsertionPointToStart(newOp.getBody());
|
||||||
for (unsigned i = 0, e = iterationCounts.size(); i < e; ++i) {
|
for (unsigned i = 0, e = iterationCounts.size(); i < e; ++i) {
|
||||||
Value scaled = rewriter.create<MulIOp>(
|
Value scaled = rewriter.create<MulIOp>(
|
||||||
op.getLoc(), newOp.getInductionVars()[i], op.step()[i]);
|
op.getLoc(), newOp.getInductionVars()[i], op.getStep()[i]);
|
||||||
Value shifted =
|
Value shifted =
|
||||||
rewriter.create<AddIOp>(op.getLoc(), op.lowerBound()[i], scaled);
|
rewriter.create<AddIOp>(op.getLoc(), op.getLowerBound()[i], scaled);
|
||||||
inductionVars.push_back(shifted);
|
inductionVars.push_back(shifted);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -333,7 +333,7 @@ struct WrapForWithBarrier : public OpRewritePattern<scf::ForOp> {
|
||||||
|
|
||||||
return wrapWithBarriers(op, rewriter, [&](Operation *prevOp) {
|
return wrapWithBarriers(op, rewriter, [&](Operation *prevOp) {
|
||||||
if (auto loadOp = dyn_cast_or_null<memref::LoadOp>(prevOp)) {
|
if (auto loadOp = dyn_cast_or_null<memref::LoadOp>(prevOp)) {
|
||||||
if (loadOp.result() == op.upperBound() &&
|
if (loadOp.result() == op.getUpperBound() &&
|
||||||
loadOp.indices() ==
|
loadOp.indices() ==
|
||||||
cast<scf::ParallelOp>(op->getParentOp()).getInductionVars()) {
|
cast<scf::ParallelOp>(op->getParentOp()).getInductionVars()) {
|
||||||
prevOp = prevOp->getPrevNode();
|
prevOp = prevOp->getPrevNode();
|
||||||
|
@ -353,7 +353,7 @@ struct WrapWhileWithBarrier : public OpRewritePattern<scf::WhileOp> {
|
||||||
LogicalResult matchAndRewrite(scf::WhileOp op,
|
LogicalResult matchAndRewrite(scf::WhileOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
if (op.getNumOperands() != 0 ||
|
if (op.getNumOperands() != 0 ||
|
||||||
!llvm::hasSingleElement(op.after().front())) {
|
!llvm::hasSingleElement(op.getAfter().front())) {
|
||||||
LLVM_DEBUG(DBGS() << "[wrap-while] ignoring non-rotated loop\n";);
|
LLVM_DEBUG(DBGS() << "[wrap-while] ignoring non-rotated loop\n";);
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
@ -372,7 +372,7 @@ static void moveBodies(PatternRewriter &rewriter, scf::ParallelOp op,
|
||||||
OpBuilder::InsertionGuard guard(rewriter);
|
OpBuilder::InsertionGuard guard(rewriter);
|
||||||
rewriter.setInsertionPointToStart(newForLoop.getBody());
|
rewriter.setInsertionPointToStart(newForLoop.getBody());
|
||||||
auto newParallel = rewriter.create<scf::ParallelOp>(
|
auto newParallel = rewriter.create<scf::ParallelOp>(
|
||||||
op.getLoc(), op.lowerBound(), op.upperBound(), op.step());
|
op.getLoc(), op.getLowerBound(), op.getUpperBound(), op.getStep());
|
||||||
|
|
||||||
// Merge in two stages so we can properly replace uses of two induction
|
// Merge in two stages so we can properly replace uses of two induction
|
||||||
// varibales defined in different blocks.
|
// varibales defined in different blocks.
|
||||||
|
@ -419,8 +419,8 @@ struct InterchangeForPFor : public OpRewritePattern<scf::ParallelOp> {
|
||||||
}
|
}
|
||||||
|
|
||||||
auto newForLoop =
|
auto newForLoop =
|
||||||
rewriter.create<scf::ForOp>(forLoop.getLoc(), forLoop.lowerBound(),
|
rewriter.create<scf::ForOp>(forLoop.getLoc(), forLoop.getLowerBound(),
|
||||||
forLoop.upperBound(), forLoop.step());
|
forLoop.getUpperBound(), forLoop.getStep());
|
||||||
moveBodies(rewriter, op, forLoop, newForLoop);
|
moveBodies(rewriter, op, forLoop, newForLoop);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -446,7 +446,7 @@ struct InterchangeForPForLoad : public OpRewritePattern<scf::ParallelOp> {
|
||||||
}
|
}
|
||||||
auto loadOp = dyn_cast<memref::LoadOp>(op.getBody()->front());
|
auto loadOp = dyn_cast<memref::LoadOp>(op.getBody()->front());
|
||||||
auto forOp = dyn_cast<scf::ForOp>(op.getBody()->front().getNextNode());
|
auto forOp = dyn_cast<scf::ForOp>(op.getBody()->front().getNextNode());
|
||||||
if (!loadOp || !forOp || loadOp.result() != forOp.upperBound() ||
|
if (!loadOp || !forOp || loadOp.result() != forOp.getUpperBound() ||
|
||||||
loadOp.indices() != op.getInductionVars()) {
|
loadOp.indices() != op.getInductionVars()) {
|
||||||
LLVM_DEBUG(DBGS() << "[interchange-load] expected pfor(load, for)");
|
LLVM_DEBUG(DBGS() << "[interchange-load] expected pfor(load, for)");
|
||||||
return failure();
|
return failure();
|
||||||
|
@ -470,7 +470,7 @@ struct InterchangeForPForLoad : public OpRewritePattern<scf::ParallelOp> {
|
||||||
loadOp.getLoc(), loadOp.getMemRef(),
|
loadOp.getLoc(), loadOp.getMemRef(),
|
||||||
SmallVector<Value>(loadOp.getMemRefType().getRank(), zero));
|
SmallVector<Value>(loadOp.getMemRefType().getRank(), zero));
|
||||||
auto newForLoop = rewriter.create<scf::ForOp>(
|
auto newForLoop = rewriter.create<scf::ForOp>(
|
||||||
forOp.getLoc(), forOp.lowerBound(), tripCount, forOp.step());
|
forOp.getLoc(), forOp.getLowerBound(), tripCount, forOp.getStep());
|
||||||
|
|
||||||
moveBodies(rewriter, op, forOp, newForLoop);
|
moveBodies(rewriter, op, forOp, newForLoop);
|
||||||
return success();
|
return success();
|
||||||
|
@ -535,9 +535,9 @@ findInsertionPointAfterLoopOperands(scf::ParallelOp op) {
|
||||||
// Find the earliest insertion point where loop bounds are fully defined.
|
// Find the earliest insertion point where loop bounds are fully defined.
|
||||||
PostDominanceInfo postDominanceInfo(op->getParentOfType<FuncOp>());
|
PostDominanceInfo postDominanceInfo(op->getParentOfType<FuncOp>());
|
||||||
SmallVector<Value> operands;
|
SmallVector<Value> operands;
|
||||||
llvm::append_range(operands, op.lowerBound());
|
llvm::append_range(operands, op.getLowerBound());
|
||||||
llvm::append_range(operands, op.upperBound());
|
llvm::append_range(operands, op.getUpperBound());
|
||||||
llvm::append_range(operands, op.step());
|
llvm::append_range(operands, op.getStep());
|
||||||
return findNearestPostDominatingInsertionPoint(operands, postDominanceInfo);
|
return findNearestPostDominatingInsertionPoint(operands, postDominanceInfo);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -562,7 +562,7 @@ struct InterchangeWhilePFor : public OpRewritePattern<scf::ParallelOp> {
|
||||||
LLVM_DEBUG(DBGS() << "[interchange-while] loop-carried values\n");
|
LLVM_DEBUG(DBGS() << "[interchange-while] loop-carried values\n");
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
if (!llvm::hasSingleElement(whileOp.after().front()) || !isNormalized(op)) {
|
if (!llvm::hasSingleElement(whileOp.getAfter().front()) || !isNormalized(op)) {
|
||||||
LLVM_DEBUG(DBGS() << "[interchange-while] non-normalized loop\n");
|
LLVM_DEBUG(DBGS() << "[interchange-while] non-normalized loop\n");
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
@ -573,37 +573,37 @@ struct InterchangeWhilePFor : public OpRewritePattern<scf::ParallelOp> {
|
||||||
|
|
||||||
auto newWhileOp = rewriter.create<scf::WhileOp>(whileOp.getLoc(),
|
auto newWhileOp = rewriter.create<scf::WhileOp>(whileOp.getLoc(),
|
||||||
TypeRange(), ValueRange());
|
TypeRange(), ValueRange());
|
||||||
rewriter.createBlock(&newWhileOp.after());
|
rewriter.createBlock(&newWhileOp.getAfter());
|
||||||
rewriter.clone(whileOp.after().front().back());
|
rewriter.clone(whileOp.getAfter().front().back());
|
||||||
|
|
||||||
rewriter.createBlock(&newWhileOp.before());
|
rewriter.createBlock(&newWhileOp.getBefore());
|
||||||
auto newParallelOp = rewriter.create<scf::ParallelOp>(
|
auto newParallelOp = rewriter.create<scf::ParallelOp>(
|
||||||
op.getLoc(), op.lowerBound(), op.upperBound(), op.step());
|
op.getLoc(), op.getLowerBound(), op.getUpperBound(), op.getStep());
|
||||||
|
|
||||||
auto conditionOp = cast<scf::ConditionOp>(whileOp.before().front().back());
|
auto conditionOp = cast<scf::ConditionOp>(whileOp.getBefore().front().back());
|
||||||
rewriter.mergeBlockBefore(op.getBody(), &newParallelOp.getBody()->back(),
|
rewriter.mergeBlockBefore(op.getBody(), &newParallelOp.getBody()->back(),
|
||||||
newParallelOp.getInductionVars());
|
newParallelOp.getInductionVars());
|
||||||
rewriter.eraseOp(newParallelOp.getBody()->back().getPrevNode());
|
rewriter.eraseOp(newParallelOp.getBody()->back().getPrevNode());
|
||||||
rewriter.mergeBlockBefore(&whileOp.before().front(),
|
rewriter.mergeBlockBefore(&whileOp.getBefore().front(),
|
||||||
&newParallelOp.getBody()->back());
|
&newParallelOp.getBody()->back());
|
||||||
|
|
||||||
Operation *conditionDefiningOp = conditionOp.condition().getDefiningOp();
|
Operation *conditionDefiningOp = conditionOp.getCondition().getDefiningOp();
|
||||||
if (conditionDefiningOp &&
|
if (conditionDefiningOp &&
|
||||||
!isDefinedAbove(conditionOp.condition(), conditionOp)) {
|
!isDefinedAbove(conditionOp.getCondition(), conditionOp)) {
|
||||||
std::pair<Block *, Block::iterator> insertionPoint =
|
std::pair<Block *, Block::iterator> insertionPoint =
|
||||||
findInsertionPointAfterLoopOperands(op);
|
findInsertionPointAfterLoopOperands(op);
|
||||||
rewriter.setInsertionPoint(insertionPoint.first, insertionPoint.second);
|
rewriter.setInsertionPoint(insertionPoint.first, insertionPoint.second);
|
||||||
SmallVector<Value> iterationCounts = emitIterationCounts(rewriter, op);
|
SmallVector<Value> iterationCounts = emitIterationCounts(rewriter, op);
|
||||||
Value allocated = allocateTemporaryBuffer<memref::AllocaOp>(
|
Value allocated = allocateTemporaryBuffer<memref::AllocaOp>(
|
||||||
rewriter, conditionOp.condition(), iterationCounts);
|
rewriter, conditionOp.getCondition(), iterationCounts);
|
||||||
Value zero = rewriter.create<ConstantIndexOp>(op.getLoc(), 0);
|
Value zero = rewriter.create<ConstantIndexOp>(op.getLoc(), 0);
|
||||||
|
|
||||||
rewriter.setInsertionPointAfter(conditionDefiningOp);
|
rewriter.setInsertionPointAfter(conditionDefiningOp);
|
||||||
rewriter.create<memref::StoreOp>(conditionDefiningOp->getLoc(),
|
rewriter.create<memref::StoreOp>(conditionDefiningOp->getLoc(),
|
||||||
conditionOp.condition(), allocated,
|
conditionOp.getCondition(), allocated,
|
||||||
newParallelOp.getInductionVars());
|
newParallelOp.getInductionVars());
|
||||||
|
|
||||||
rewriter.setInsertionPointToEnd(&newWhileOp.before().front());
|
rewriter.setInsertionPointToEnd(&newWhileOp.getBefore().front());
|
||||||
SmallVector<Value> zeros(iterationCounts.size(), zero);
|
SmallVector<Value> zeros(iterationCounts.size(), zero);
|
||||||
Value reloaded = rewriter.create<memref::LoadOp>(
|
Value reloaded = rewriter.create<memref::LoadOp>(
|
||||||
conditionDefiningOp->getLoc(), allocated, zeros);
|
conditionDefiningOp->getLoc(), allocated, zeros);
|
||||||
|
@ -646,7 +646,7 @@ struct RotateWhile : public OpRewritePattern<scf::WhileOp> {
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(scf::WhileOp op,
|
LogicalResult matchAndRewrite(scf::WhileOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
if (llvm::hasSingleElement(op.after().front())) {
|
if (llvm::hasSingleElement(op.getAfter().front())) {
|
||||||
LLVM_DEBUG(DBGS() << "[rotate-while] the after region is empty");
|
LLVM_DEBUG(DBGS() << "[rotate-while] the after region is empty");
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
@ -659,15 +659,15 @@ struct RotateWhile : public OpRewritePattern<scf::WhileOp> {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
auto condition = cast<scf::ConditionOp>(op.before().front().back());
|
auto condition = cast<scf::ConditionOp>(op.getBefore().front().back());
|
||||||
rewriter.setInsertionPoint(condition);
|
rewriter.setInsertionPoint(condition);
|
||||||
auto conditional =
|
auto conditional =
|
||||||
rewriter.create<scf::IfOp>(op.getLoc(), condition.condition());
|
rewriter.create<scf::IfOp>(op.getLoc(), condition.getCondition());
|
||||||
rewriter.mergeBlockBefore(&op.after().front(),
|
rewriter.mergeBlockBefore(&op.getAfter().front(),
|
||||||
&conditional.getBody()->back());
|
&conditional.getBody()->back());
|
||||||
rewriter.eraseOp(&conditional.getBody()->back());
|
rewriter.eraseOp(&conditional.getBody()->back());
|
||||||
|
|
||||||
rewriter.createBlock(&op.after());
|
rewriter.createBlock(&op.getAfter());
|
||||||
rewriter.clone(conditional.getBody()->back());
|
rewriter.clone(conditional.getBody()->back());
|
||||||
|
|
||||||
LLVM_DEBUG(DBGS() << "[rotate-while] done\n");
|
LLVM_DEBUG(DBGS() << "[rotate-while] done\n");
|
||||||
|
@ -758,7 +758,7 @@ struct DistributeAroundBarrier : public OpRewritePattern<scf::ParallelOp> {
|
||||||
// Create the second loop.
|
// Create the second loop.
|
||||||
rewriter.setInsertionPointAfter(op);
|
rewriter.setInsertionPointAfter(op);
|
||||||
auto newLoop = rewriter.create<scf::ParallelOp>(
|
auto newLoop = rewriter.create<scf::ParallelOp>(
|
||||||
op.getLoc(), op.lowerBound(), op.upperBound(), op.step());
|
op.getLoc(), op.getLowerBound(), op.getUpperBound(), op.getStep());
|
||||||
rewriter.eraseOp(&newLoop.getBody()->back());
|
rewriter.eraseOp(&newLoop.getBody()->back());
|
||||||
|
|
||||||
for (auto alloc : allocations)
|
for (auto alloc : allocations)
|
||||||
|
@ -837,8 +837,8 @@ struct Reg2MemFor : public OpRewritePattern<scf::ForOp> {
|
||||||
rewriter.create<memref::StoreOp>(op.getLoc(), operand, alloc, zero);
|
rewriter.create<memref::StoreOp>(op.getLoc(), operand, alloc, zero);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto newOp = rewriter.create<scf::ForOp>(op.getLoc(), op.lowerBound(),
|
auto newOp = rewriter.create<scf::ForOp>(op.getLoc(), op.getLowerBound(),
|
||||||
op.upperBound(), op.step());
|
op.getUpperBound(), op.getStep());
|
||||||
rewriter.setInsertionPointToStart(newOp.getBody());
|
rewriter.setInsertionPointToStart(newOp.getBody());
|
||||||
SmallVector<Value> newRegionArguments;
|
SmallVector<Value> newRegionArguments;
|
||||||
newRegionArguments.push_back(newOp.getInductionVar());
|
newRegionArguments.push_back(newOp.getInductionVar());
|
||||||
|
@ -849,7 +849,7 @@ struct Reg2MemFor : public OpRewritePattern<scf::ForOp> {
|
||||||
newRegionArguments);
|
newRegionArguments);
|
||||||
|
|
||||||
rewriter.setInsertionPoint(newOp.getBody()->getTerminator());
|
rewriter.setInsertionPoint(newOp.getBody()->getTerminator());
|
||||||
for (auto en : llvm::enumerate(oldTerminator.results())) {
|
for (auto en : llvm::enumerate(oldTerminator.getResults())) {
|
||||||
rewriter.create<memref::StoreOp>(op.getLoc(), en.value(),
|
rewriter.create<memref::StoreOp>(op.getLoc(), en.value(),
|
||||||
allocated[en.index()], zero);
|
allocated[en.index()], zero);
|
||||||
}
|
}
|
||||||
|
@ -908,35 +908,35 @@ struct Reg2MemWhile : public OpRewritePattern<scf::WhileOp> {
|
||||||
auto newOp =
|
auto newOp =
|
||||||
rewriter.create<scf::WhileOp>(op.getLoc(), TypeRange(), ValueRange());
|
rewriter.create<scf::WhileOp>(op.getLoc(), TypeRange(), ValueRange());
|
||||||
Block *newBefore =
|
Block *newBefore =
|
||||||
rewriter.createBlock(&newOp.before(), newOp.before().begin());
|
rewriter.createBlock(&newOp.getBefore(), newOp.getBefore().begin());
|
||||||
SmallVector<Value> newBeforeArguments;
|
SmallVector<Value> newBeforeArguments;
|
||||||
loadValues(op.getLoc(), beforeAllocated, zero, rewriter,
|
loadValues(op.getLoc(), beforeAllocated, zero, rewriter,
|
||||||
newBeforeArguments);
|
newBeforeArguments);
|
||||||
rewriter.mergeBlocks(&op.before().front(), newBefore, newBeforeArguments);
|
rewriter.mergeBlocks(&op.getBefore().front(), newBefore, newBeforeArguments);
|
||||||
|
|
||||||
auto beforeTerminator =
|
auto beforeTerminator =
|
||||||
cast<scf::ConditionOp>(newOp.before().front().getTerminator());
|
cast<scf::ConditionOp>(newOp.getBefore().front().getTerminator());
|
||||||
rewriter.setInsertionPoint(beforeTerminator);
|
rewriter.setInsertionPoint(beforeTerminator);
|
||||||
storeValues(op.getLoc(), beforeTerminator.args(), afterAllocated, zero,
|
storeValues(op.getLoc(), beforeTerminator.getArgs(), afterAllocated, zero,
|
||||||
rewriter);
|
rewriter);
|
||||||
|
|
||||||
rewriter.updateRootInPlace(beforeTerminator,
|
rewriter.updateRootInPlace(beforeTerminator,
|
||||||
[&] { beforeTerminator.argsMutable().clear(); });
|
[&] { beforeTerminator.getArgsMutable().clear(); });
|
||||||
|
|
||||||
Block *newAfter =
|
Block *newAfter =
|
||||||
rewriter.createBlock(&newOp.after(), newOp.after().begin());
|
rewriter.createBlock(&newOp.getAfter(), newOp.getAfter().begin());
|
||||||
SmallVector<Value> newAfterArguments;
|
SmallVector<Value> newAfterArguments;
|
||||||
loadValues(op.getLoc(), afterAllocated, zero, rewriter, newAfterArguments);
|
loadValues(op.getLoc(), afterAllocated, zero, rewriter, newAfterArguments);
|
||||||
rewriter.mergeBlocks(&op.after().front(), newAfter, newAfterArguments);
|
rewriter.mergeBlocks(&op.getAfter().front(), newAfter, newAfterArguments);
|
||||||
|
|
||||||
auto afterTerminator =
|
auto afterTerminator =
|
||||||
cast<scf::YieldOp>(newOp.after().front().getTerminator());
|
cast<scf::YieldOp>(newOp.getAfter().front().getTerminator());
|
||||||
rewriter.setInsertionPoint(afterTerminator);
|
rewriter.setInsertionPoint(afterTerminator);
|
||||||
storeValues(op.getLoc(), afterTerminator.results(), beforeAllocated, zero,
|
storeValues(op.getLoc(), afterTerminator.getResults(), beforeAllocated, zero,
|
||||||
rewriter);
|
rewriter);
|
||||||
|
|
||||||
rewriter.updateRootInPlace(
|
rewriter.updateRootInPlace(
|
||||||
afterTerminator, [&] { afterTerminator.resultsMutable().clear(); });
|
afterTerminator, [&] { afterTerminator.getResultsMutable().clear(); });
|
||||||
|
|
||||||
rewriter.setInsertionPointAfter(op);
|
rewriter.setInsertionPointAfter(op);
|
||||||
SmallVector<Value> results;
|
SmallVector<Value> results;
|
||||||
|
|
|
@ -30,7 +30,7 @@ struct ForOpRaising : public OpRewritePattern<scf::ForOp> {
|
||||||
bool isAffine(scf::ForOp loop) const {
|
bool isAffine(scf::ForOp loop) const {
|
||||||
// return true;
|
// return true;
|
||||||
// enforce step to be a ConstantIndexOp (maybe too restrictive).
|
// enforce step to be a ConstantIndexOp (maybe too restrictive).
|
||||||
return isa_and_nonnull<ConstantIndexOp>(loop.step().getDefiningOp());
|
return isa_and_nonnull<ConstantIndexOp>(loop.getStep().getDefiningOp());
|
||||||
}
|
}
|
||||||
|
|
||||||
void canonicalizeLoopBounds(AffineForOp forOp) const {
|
void canonicalizeLoopBounds(AffineForOp forOp) const {
|
||||||
|
@ -67,23 +67,23 @@ struct ForOpRaising : public OpRewritePattern<scf::ForOp> {
|
||||||
if (isAffine(loop)) {
|
if (isAffine(loop)) {
|
||||||
OpBuilder builder(loop);
|
OpBuilder builder(loop);
|
||||||
|
|
||||||
if (!isValidIndex(loop.lowerBound())) {
|
if (!isValidIndex(loop.getLowerBound())) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!isValidIndex(loop.upperBound())) {
|
if (!isValidIndex(loop.getUpperBound())) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
AffineForOp affineLoop = rewriter.create<AffineForOp>(
|
AffineForOp affineLoop = rewriter.create<AffineForOp>(
|
||||||
loop.getLoc(), loop.lowerBound(), builder.getSymbolIdentityMap(),
|
loop.getLoc(), loop.getLowerBound(), builder.getSymbolIdentityMap(),
|
||||||
loop.upperBound(), builder.getSymbolIdentityMap(),
|
loop.getUpperBound(), builder.getSymbolIdentityMap(),
|
||||||
getStep(loop.step()), loop.getIterOperands());
|
getStep(loop.getStep()), loop.getIterOperands());
|
||||||
|
|
||||||
canonicalizeLoopBounds(affineLoop);
|
canonicalizeLoopBounds(affineLoop);
|
||||||
|
|
||||||
auto mergedYieldOp =
|
auto mergedYieldOp =
|
||||||
cast<scf::YieldOp>(loop.region().front().getTerminator());
|
cast<scf::YieldOp>(loop.getRegion().front().getTerminator());
|
||||||
|
|
||||||
Block &newBlock = affineLoop.region().front();
|
Block &newBlock = affineLoop.region().front();
|
||||||
|
|
||||||
|
@ -97,10 +97,10 @@ struct ForOpRaising : public OpRewritePattern<scf::ForOp> {
|
||||||
rewriter.updateRootInPlace(loop, [&] {
|
rewriter.updateRootInPlace(loop, [&] {
|
||||||
affineLoop.region().front().getOperations().splice(
|
affineLoop.region().front().getOperations().splice(
|
||||||
affineLoop.region().front().getOperations().begin(),
|
affineLoop.region().front().getOperations().begin(),
|
||||||
loop.region().front().getOperations());
|
loop.getRegion().front().getOperations());
|
||||||
|
|
||||||
for (auto pair : llvm::zip(affineLoop.region().front().getArguments(),
|
for (auto pair : llvm::zip(affineLoop.region().front().getArguments(),
|
||||||
loop.region().front().getArguments())) {
|
loop.getRegion().front().getArguments())) {
|
||||||
std::get<1>(pair).replaceAllUsesWith(std::get<0>(pair));
|
std::get<1>(pair).replaceAllUsesWith(std::get<0>(pair));
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
|
@ -1 +1 @@
|
||||||
Subproject commit fe5137a0fe421ee153f115ae3cd7fb51bba65795
|
Subproject commit a6a583dae40485cacfac56811e6d9131bac6ca74
|
|
@ -172,10 +172,10 @@ void MLIRScanner::buildAffineLoopImpl(
|
||||||
builder.setInsertionPointToEnd(®.front());
|
builder.setInsertionPointToEnd(®.front());
|
||||||
|
|
||||||
auto er = builder.create<scf::ExecuteRegionOp>(loc, ArrayRef<mlir::Type>());
|
auto er = builder.create<scf::ExecuteRegionOp>(loc, ArrayRef<mlir::Type>());
|
||||||
er.region().push_back(new Block());
|
er.getRegion().push_back(new Block());
|
||||||
builder.setInsertionPointToStart(&er.region().back());
|
builder.setInsertionPointToStart(&er.getRegion().back());
|
||||||
builder.create<scf::YieldOp>(loc);
|
builder.create<scf::YieldOp>(loc);
|
||||||
builder.setInsertionPointToStart(&er.region().back());
|
builder.setInsertionPointToStart(&er.getRegion().back());
|
||||||
|
|
||||||
if (!descr.getForwardMode()) {
|
if (!descr.getForwardMode()) {
|
||||||
val = builder.create<SubIOp>(loc, val, lb);
|
val = builder.create<SubIOp>(loc, val, lb);
|
||||||
|
@ -355,15 +355,15 @@ ValueCategory MLIRScanner::VisitOMPParallelForDirective(
|
||||||
auto oldpoint = builder.getInsertionPoint();
|
auto oldpoint = builder.getInsertionPoint();
|
||||||
auto oldblock = builder.getInsertionBlock();
|
auto oldblock = builder.getInsertionBlock();
|
||||||
|
|
||||||
builder.setInsertionPointToStart(&affineOp.region().front());
|
builder.setInsertionPointToStart(&affineOp.getRegion().front());
|
||||||
|
|
||||||
auto executeRegion =
|
auto executeRegion =
|
||||||
builder.create<scf::ExecuteRegionOp>(loc, ArrayRef<mlir::Type>());
|
builder.create<scf::ExecuteRegionOp>(loc, ArrayRef<mlir::Type>());
|
||||||
executeRegion.region().push_back(new Block());
|
executeRegion.getRegion().push_back(new Block());
|
||||||
builder.setInsertionPointToStart(&executeRegion.region().back());
|
builder.setInsertionPointToStart(&executeRegion.getRegion().back());
|
||||||
|
|
||||||
auto oldScope = allocationScope;
|
auto oldScope = allocationScope;
|
||||||
allocationScope = &executeRegion.region().back();
|
allocationScope = &executeRegion.getRegion().back();
|
||||||
|
|
||||||
for (auto zp : zip(inds, fors->counters())) {
|
for (auto zp : zip(inds, fors->counters())) {
|
||||||
auto idx = builder.create<IndexCastOp>(
|
auto idx = builder.create<IndexCastOp>(
|
||||||
|
@ -552,13 +552,13 @@ ValueCategory MLIRScanner::VisitIfStmt(clang::IfStmt *stmt) {
|
||||||
bool hasElseRegion = stmt->getElse();
|
bool hasElseRegion = stmt->getElse();
|
||||||
auto ifOp = builder.create<mlir::scf::IfOp>(loc, cond, hasElseRegion);
|
auto ifOp = builder.create<mlir::scf::IfOp>(loc, cond, hasElseRegion);
|
||||||
|
|
||||||
ifOp.thenRegion().back().clear();
|
ifOp.getThenRegion().back().clear();
|
||||||
builder.setInsertionPointToStart(&ifOp.thenRegion().back());
|
builder.setInsertionPointToStart(&ifOp.getThenRegion().back());
|
||||||
Visit(stmt->getThen());
|
Visit(stmt->getThen());
|
||||||
builder.create<scf::YieldOp>(loc);
|
builder.create<scf::YieldOp>(loc);
|
||||||
if (hasElseRegion) {
|
if (hasElseRegion) {
|
||||||
ifOp.elseRegion().back().clear();
|
ifOp.getElseRegion().back().clear();
|
||||||
builder.setInsertionPointToStart(&ifOp.elseRegion().back());
|
builder.setInsertionPointToStart(&ifOp.getElseRegion().back());
|
||||||
Visit(stmt->getElse());
|
Visit(stmt->getElse());
|
||||||
builder.create<scf::YieldOp>(loc);
|
builder.create<scf::YieldOp>(loc);
|
||||||
}
|
}
|
||||||
|
@ -574,7 +574,7 @@ ValueCategory MLIRScanner::VisitSwitchStmt(clang::SwitchStmt *stmt) {
|
||||||
SmallVector<int64_t> caseVals;
|
SmallVector<int64_t> caseVals;
|
||||||
|
|
||||||
auto er = builder.create<scf::ExecuteRegionOp>(loc, ArrayRef<mlir::Type>());
|
auto er = builder.create<scf::ExecuteRegionOp>(loc, ArrayRef<mlir::Type>());
|
||||||
er.region().push_back(new Block());
|
er.getRegion().push_back(new Block());
|
||||||
auto oldpoint2 = builder.getInsertionPoint();
|
auto oldpoint2 = builder.getInsertionPoint();
|
||||||
auto oldblock2 = builder.getInsertionBlock();
|
auto oldblock2 = builder.getInsertionBlock();
|
||||||
|
|
||||||
|
@ -613,7 +613,7 @@ ValueCategory MLIRScanner::VisitSwitchStmt(clang::SwitchStmt *stmt) {
|
||||||
}
|
}
|
||||||
|
|
||||||
inCase = true;
|
inCase = true;
|
||||||
er.region().getBlocks().push_back(&condB);
|
er.getRegion().getBlocks().push_back(&condB);
|
||||||
blocks.push_back(&condB);
|
blocks.push_back(&condB);
|
||||||
builder.setInsertionPointToStart(&condB);
|
builder.setInsertionPointToStart(&condB);
|
||||||
|
|
||||||
|
@ -638,7 +638,7 @@ ValueCategory MLIRScanner::VisitSwitchStmt(clang::SwitchStmt *stmt) {
|
||||||
}
|
}
|
||||||
|
|
||||||
inCase = true;
|
inCase = true;
|
||||||
er.region().getBlocks().push_back(&condB);
|
er.getRegion().getBlocks().push_back(&condB);
|
||||||
builder.setInsertionPointToStart(&condB);
|
builder.setInsertionPointToStart(&condB);
|
||||||
|
|
||||||
auto i1Ty = builder.getIntegerType(1);
|
auto i1Ty = builder.getIntegerType(1);
|
||||||
|
@ -668,7 +668,7 @@ ValueCategory MLIRScanner::VisitSwitchStmt(clang::SwitchStmt *stmt) {
|
||||||
loops.pop_back();
|
loops.pop_back();
|
||||||
builder.create<mlir::BranchOp>(loc, &exitB);
|
builder.create<mlir::BranchOp>(loc, &exitB);
|
||||||
|
|
||||||
er.region().getBlocks().push_back(&exitB);
|
er.getRegion().getBlocks().push_back(&exitB);
|
||||||
|
|
||||||
DenseIntElementsAttr caseValuesAttr;
|
DenseIntElementsAttr caseValuesAttr;
|
||||||
ShapedType caseValueType = mlir::VectorType::get(
|
ShapedType caseValueType = mlir::VectorType::get(
|
||||||
|
@ -694,7 +694,7 @@ ValueCategory MLIRScanner::VisitSwitchStmt(clang::SwitchStmt *stmt) {
|
||||||
caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseVals8);
|
caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseVals8);
|
||||||
}
|
}
|
||||||
|
|
||||||
builder.setInsertionPointToStart(&er.region().front());
|
builder.setInsertionPointToStart(&er.getRegion().front());
|
||||||
builder.create<mlir::SwitchOp>(
|
builder.create<mlir::SwitchOp>(
|
||||||
loc, cond, defaultB, ArrayRef<mlir::Value>(), caseValuesAttr, blocks,
|
loc, cond, defaultB, ArrayRef<mlir::Value>(), caseValuesAttr, blocks,
|
||||||
SmallVector<mlir::ValueRange>(caseVals.size(), ArrayRef<mlir::Value>()));
|
SmallVector<mlir::ValueRange>(caseVals.size(), ArrayRef<mlir::Value>()));
|
||||||
|
|
|
@ -21,13 +21,13 @@ IfScope::IfScope(MLIRScanner &scanner) : scanner(scanner), prevBlock(nullptr) {
|
||||||
/*hasElse*/ false);
|
/*hasElse*/ false);
|
||||||
prevBlock = scanner.builder.getInsertionBlock();
|
prevBlock = scanner.builder.getInsertionBlock();
|
||||||
prevIterator = scanner.builder.getInsertionPoint();
|
prevIterator = scanner.builder.getInsertionPoint();
|
||||||
ifOp.thenRegion().back().clear();
|
ifOp.getThenRegion().back().clear();
|
||||||
scanner.builder.setInsertionPointToStart(&ifOp.thenRegion().back());
|
scanner.builder.setInsertionPointToStart(&ifOp.getThenRegion().back());
|
||||||
auto er = scanner.builder.create<scf::ExecuteRegionOp>(
|
auto er = scanner.builder.create<scf::ExecuteRegionOp>(
|
||||||
scanner.loc, ArrayRef<mlir::Type>());
|
scanner.loc, ArrayRef<mlir::Type>());
|
||||||
scanner.builder.create<scf::YieldOp>(scanner.loc);
|
scanner.builder.create<scf::YieldOp>(scanner.loc);
|
||||||
er.region().push_back(new Block());
|
er.getRegion().push_back(new Block());
|
||||||
scanner.builder.setInsertionPointToStart(&er.region().back());
|
scanner.builder.setInsertionPointToStart(&er.getRegion().back());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -68,7 +68,8 @@ static mlir::Value castCallerMemRefArg(mlir::Value callerArg,
|
||||||
|
|
||||||
if (MemRefType dstTy = calleeArgType.dyn_cast_or_null<MemRefType>()) {
|
if (MemRefType dstTy = calleeArgType.dyn_cast_or_null<MemRefType>()) {
|
||||||
MemRefType srcTy = callerArgType.dyn_cast<MemRefType>();
|
MemRefType srcTy = callerArgType.dyn_cast<MemRefType>();
|
||||||
if (srcTy && dstTy.getElementType() == srcTy.getElementType()) {
|
if (srcTy && dstTy.getElementType() == srcTy.getElementType()
|
||||||
|
&& dstTy.getMemorySpace() == srcTy.getMemorySpace()) {
|
||||||
auto srcShape = srcTy.getShape();
|
auto srcShape = srcTy.getShape();
|
||||||
auto dstShape = dstTy.getShape();
|
auto dstShape = dstTy.getShape();
|
||||||
|
|
||||||
|
@ -748,7 +749,7 @@ ValueCategory MLIRScanner::VisitVarDecl(clang::VarDecl *decl) {
|
||||||
auto ifOp = builder.create<scf::IfOp>(varLoc, cond, /*hasElse*/false);
|
auto ifOp = builder.create<scf::IfOp>(varLoc, cond, /*hasElse*/false);
|
||||||
block = builder.getInsertionBlock();
|
block = builder.getInsertionBlock();
|
||||||
iter = builder.getInsertionPoint();
|
iter = builder.getInsertionPoint();
|
||||||
builder.setInsertionPointToStart(&ifOp.thenRegion().back());
|
builder.setInsertionPointToStart(&ifOp.getThenRegion().back());
|
||||||
builder.create<memref::StoreOp>(varLoc, builder.create<ConstantIntOp>(varLoc, false, 1), boolop, std::vector<mlir::Value>({getConstantIndex(0)}));
|
builder.create<memref::StoreOp>(varLoc, builder.create<ConstantIntOp>(varLoc, false, 1), boolop, std::vector<mlir::Value>({getConstantIndex(0)}));
|
||||||
}
|
}
|
||||||
} else
|
} else
|
||||||
|
@ -2919,15 +2920,6 @@ ValueCategory MLIRScanner::CallHelper(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
assert(val);
|
assert(val);
|
||||||
/*
|
|
||||||
if (val.getType() != fnType.getInput(i)) {
|
|
||||||
if (auto MR1 = val.getType().dyn_cast<MemRefType>()) {
|
|
||||||
if (auto MR2 = fnType.getInput(i).dyn_cast<MemRefType>()) {
|
|
||||||
val = builder.create<mlir::memref::CastOp>(loc, val, MR2);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
args.push_back(val);
|
args.push_back(val);
|
||||||
i++;
|
i++;
|
||||||
}
|
}
|
||||||
|
@ -3460,7 +3452,7 @@ ValueCategory MLIRScanner::VisitBinaryOperator(clang::BinaryOperator *BO) {
|
||||||
|
|
||||||
auto oldpoint = builder.getInsertionPoint();
|
auto oldpoint = builder.getInsertionPoint();
|
||||||
auto oldblock = builder.getInsertionBlock();
|
auto oldblock = builder.getInsertionBlock();
|
||||||
builder.setInsertionPointToStart(&ifOp.thenRegion().back());
|
builder.setInsertionPointToStart(&ifOp.getThenRegion().back());
|
||||||
|
|
||||||
auto rhs = Visit(BO->getRHS()).getValue(builder);
|
auto rhs = Visit(BO->getRHS()).getValue(builder);
|
||||||
assert(rhs != nullptr);
|
assert(rhs != nullptr);
|
||||||
|
@ -3476,7 +3468,7 @@ ValueCategory MLIRScanner::VisitBinaryOperator(clang::BinaryOperator *BO) {
|
||||||
mlir::Value truearray[] = {rhs};
|
mlir::Value truearray[] = {rhs};
|
||||||
builder.create<mlir::scf::YieldOp>(loc, truearray);
|
builder.create<mlir::scf::YieldOp>(loc, truearray);
|
||||||
|
|
||||||
builder.setInsertionPointToStart(&ifOp.elseRegion().back());
|
builder.setInsertionPointToStart(&ifOp.getElseRegion().back());
|
||||||
mlir::Value falsearray[] = {
|
mlir::Value falsearray[] = {
|
||||||
builder.create<ConstantIntOp>(loc, 0, types[0])};
|
builder.create<ConstantIntOp>(loc, 0, types[0])};
|
||||||
builder.create<mlir::scf::YieldOp>(loc, falsearray);
|
builder.create<mlir::scf::YieldOp>(loc, falsearray);
|
||||||
|
@ -3497,12 +3489,12 @@ ValueCategory MLIRScanner::VisitBinaryOperator(clang::BinaryOperator *BO) {
|
||||||
|
|
||||||
auto oldpoint = builder.getInsertionPoint();
|
auto oldpoint = builder.getInsertionPoint();
|
||||||
auto oldblock = builder.getInsertionBlock();
|
auto oldblock = builder.getInsertionBlock();
|
||||||
builder.setInsertionPointToStart(&ifOp.thenRegion().back());
|
builder.setInsertionPointToStart(&ifOp.getThenRegion().back());
|
||||||
|
|
||||||
mlir::Value truearray[] = {builder.create<ConstantIntOp>(loc, 1, types[0])};
|
mlir::Value truearray[] = {builder.create<ConstantIntOp>(loc, 1, types[0])};
|
||||||
builder.create<mlir::scf::YieldOp>(loc, truearray);
|
builder.create<mlir::scf::YieldOp>(loc, truearray);
|
||||||
|
|
||||||
builder.setInsertionPointToStart(&ifOp.elseRegion().back());
|
builder.setInsertionPointToStart(&ifOp.getElseRegion().back());
|
||||||
auto rhs = Visit(BO->getRHS()).getValue(builder);
|
auto rhs = Visit(BO->getRHS()).getValue(builder);
|
||||||
if (!rhs.getType().cast<mlir::IntegerType>().isInteger(1)) {
|
if (!rhs.getType().cast<mlir::IntegerType>().isInteger(1)) {
|
||||||
auto postTy = builder.getI1Type();
|
auto postTy = builder.getI1Type();
|
||||||
|
@ -4747,7 +4739,7 @@ MLIRScanner::VisitConditionalOperator(clang::ConditionalOperator *E) {
|
||||||
|
|
||||||
auto oldpoint = builder.getInsertionPoint();
|
auto oldpoint = builder.getInsertionPoint();
|
||||||
auto oldblock = builder.getInsertionBlock();
|
auto oldblock = builder.getInsertionBlock();
|
||||||
builder.setInsertionPointToStart(&ifOp.thenRegion().back());
|
builder.setInsertionPointToStart(&ifOp.getThenRegion().back());
|
||||||
|
|
||||||
auto trueExpr = Visit(E->getTrueExpr());
|
auto trueExpr = Visit(E->getTrueExpr());
|
||||||
|
|
||||||
|
@ -4778,7 +4770,7 @@ MLIRScanner::VisitConditionalOperator(clang::ConditionalOperator *E) {
|
||||||
builder.create<mlir::scf::YieldOp>(loc, truearray);
|
builder.create<mlir::scf::YieldOp>(loc, truearray);
|
||||||
}
|
}
|
||||||
|
|
||||||
builder.setInsertionPointToStart(&ifOp.elseRegion().back());
|
builder.setInsertionPointToStart(&ifOp.getElseRegion().back());
|
||||||
|
|
||||||
auto falseExpr = Visit(E->getFalseExpr());
|
auto falseExpr = Visit(E->getFalseExpr());
|
||||||
std::vector<mlir::Value> falsearray;
|
std::vector<mlir::Value> falsearray;
|
||||||
|
@ -4800,10 +4792,10 @@ MLIRScanner::VisitConditionalOperator(clang::ConditionalOperator *E) {
|
||||||
types[i] = truearray[i].getType();
|
types[i] = truearray[i].getType();
|
||||||
auto newIfOp = builder.create<mlir::scf::IfOp>(loc, types, cond,
|
auto newIfOp = builder.create<mlir::scf::IfOp>(loc, types, cond,
|
||||||
/*hasElseRegion*/ true);
|
/*hasElseRegion*/ true);
|
||||||
newIfOp.thenRegion().takeBody(ifOp.thenRegion());
|
newIfOp.getThenRegion().takeBody(ifOp.getThenRegion());
|
||||||
newIfOp.elseRegion().takeBody(ifOp.elseRegion());
|
newIfOp.getElseRegion().takeBody(ifOp.getElseRegion());
|
||||||
|
ifOp.erase();
|
||||||
return ValueCategory(newIfOp.getResult(0), /*isReference*/ isReference);
|
return ValueCategory(newIfOp.getResult(0), /*isReference*/ isReference);
|
||||||
// return ifOp;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ValueCategory MLIRScanner::VisitStmtExpr(clang::StmtExpr *stmt) {
|
ValueCategory MLIRScanner::VisitStmtExpr(clang::StmtExpr *stmt) {
|
||||||
|
|
Loading…
Reference in New Issue