Permit debug build
This commit is contained in:
parent
721c54b271
commit
03354e8285
|
@ -255,7 +255,6 @@ struct RemoveUnusedArgs : public OpRewritePattern<ForOp> {
|
|||
});
|
||||
rewriter.setInsertionPoint(yieldOp);
|
||||
rewriter.replaceOpWithNewOp<YieldOp>(yieldOp, usedYieldOperands);
|
||||
rewriter.eraseOp(yieldOp);
|
||||
|
||||
// Replace the operation's results with the new ones.
|
||||
SmallVector<Value, 4> repResults(op.getNumResults());
|
||||
|
@ -489,14 +488,6 @@ struct MoveWhileToFor : public OpRewritePattern<WhileOp> {
|
|||
return true;
|
||||
}
|
||||
|
||||
unsigned countOperations(Region ®) const {
|
||||
unsigned count = 0;
|
||||
for (auto &block : reg)
|
||||
for (auto &nested : block)
|
||||
count++;
|
||||
return count;
|
||||
}
|
||||
|
||||
LogicalResult matchAndRewrite(WhileOp loop,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (!isWhile(loop))
|
||||
|
@ -518,6 +509,12 @@ struct MoveWhileToFor : public OpRewritePattern<WhileOp> {
|
|||
llvm::errs() << condOp.condition() << "\n";
|
||||
return failure();
|
||||
}
|
||||
size_t size = 0;
|
||||
for(auto &m : loop.before().front())
|
||||
size++;
|
||||
if (size != 2) {
|
||||
return failure();
|
||||
}
|
||||
size_t beforeArgNum;
|
||||
|
||||
Value maybeIndVar = cmpIOp.lhs();
|
||||
|
@ -618,48 +615,63 @@ struct MoveWhileToFor : public OpRewritePattern<WhileOp> {
|
|||
// loop.dump();
|
||||
|
||||
for (Value arg : condOp.args()) {
|
||||
Type cst = nullptr;
|
||||
if (auto idx = arg.getDefiningOp<IndexCastOp>()) {
|
||||
cst = idx.getType();
|
||||
arg = idx.in();
|
||||
}
|
||||
Value res;
|
||||
if (isTopLevelArgValue(arg, &loop.before())) {
|
||||
auto blockArg = arg.dyn_cast<BlockArgument>();
|
||||
auto pos = blockArg.getArgNumber();
|
||||
forArgs.push_back(loop.inits()[pos]);
|
||||
res = loop.inits()[pos];
|
||||
} else
|
||||
forArgs.push_back(arg);
|
||||
res = arg;
|
||||
if (cst) {
|
||||
res = rewriter.create<IndexCastOp>(rewriter.getUnknownLoc(), res, cst);
|
||||
}
|
||||
forArgs.push_back(res);
|
||||
}
|
||||
|
||||
auto forloop = rewriter.create<scf::ForOp>(
|
||||
loop.getLoc(), lb, ub, step, forArgs,
|
||||
[&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
|
||||
// map for the conditionOp value.
|
||||
size_t pos = loop.inits().size();
|
||||
SmallVector<Value, 2> mappedValues;
|
||||
mappedValues.append(args.begin() + pos, args.end());
|
||||
loop.getLoc(), lb, ub, step, forArgs);
|
||||
|
||||
BlockAndValueMapping mapping;
|
||||
mapping.map(loop.after().getArguments(), mappedValues);
|
||||
for (auto &block : loop.after().getBlocks())
|
||||
for (auto &nested : block.without_terminator())
|
||||
b.clone(nested, mapping);
|
||||
if (!forloop.getBody()->empty())
|
||||
rewriter.eraseOp(forloop.getBody()->getTerminator());
|
||||
|
||||
auto oldYield =
|
||||
cast<scf::YieldOp>(loop.after().front().getTerminator());
|
||||
SmallVector<Value, 2> yieldOperands;
|
||||
for (auto oldYieldArg : oldYield.results())
|
||||
yieldOperands.push_back(mapping.lookupOrDefault(oldYieldArg));
|
||||
size_t pos = loop.inits().size();
|
||||
for (auto pair : llvm::zip(loop.after().getArguments(), forloop.getRegionIterArgs().drop_front(pos))) {
|
||||
std::get<0>(pair).replaceAllUsesWith(std::get<1>(pair));
|
||||
}
|
||||
for (auto pair : llvm::zip(loop.before().getArguments(), forloop.getRegionIterArgs().drop_back(pos))) {
|
||||
std::get<0>(pair).replaceAllUsesWith(std::get<1>(pair));
|
||||
}
|
||||
|
||||
BlockAndValueMapping outmap;
|
||||
outmap.map(loop.before().getArguments(), yieldOperands);
|
||||
for (auto arg : condOp.args())
|
||||
yieldOperands.push_back(outmap.lookupOrDefault(arg));
|
||||
forloop.getBody()->getOperations().splice(forloop.getBody()->getOperations().begin(), loop.after().front().getOperations());
|
||||
|
||||
b.create<scf::YieldOp>(loop.getLoc(), yieldOperands);
|
||||
});
|
||||
auto oldYield = cast<scf::YieldOp>(forloop.getBody()->getTerminator());
|
||||
|
||||
SmallVector<Value, 2> yieldOperands;
|
||||
for (auto oldYieldArg : oldYield.results())
|
||||
yieldOperands.push_back(oldYieldArg);
|
||||
|
||||
BlockAndValueMapping outmap;
|
||||
outmap.map(loop.before().getArguments(), yieldOperands);
|
||||
for (auto arg : condOp.args())
|
||||
yieldOperands.push_back(outmap.lookupOrDefault(arg));
|
||||
|
||||
rewriter.setInsertionPoint(oldYield);
|
||||
rewriter.replaceOpWithNewOp<scf::YieldOp>(oldYield, yieldOperands);
|
||||
|
||||
SmallVector<Value, 2> replacements;
|
||||
size_t pos = loop.inits().size();
|
||||
replacements.append(forloop.getResults().begin() + pos,
|
||||
forloop.getResults().end());
|
||||
|
||||
llvm::errs() << " func: " << *loop->getParentOfType<FuncOp>() << "\n";
|
||||
llvm::errs() << " op2: " << forloop << "\n";
|
||||
llvm::errs() << " op: " << loop << "\n";
|
||||
|
||||
rewriter.replaceOp(loop, replacements);
|
||||
auto m = forloop->getParentOfType<ModuleOp>();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -749,7 +761,6 @@ struct MoveWhileDown : public OpRewritePattern<WhileOp> {
|
|||
}
|
||||
};
|
||||
|
||||
#if 1
|
||||
struct MoveWhileDown2 : public OpRewritePattern<WhileOp> {
|
||||
using OpRewritePattern<WhileOp>::OpRewritePattern;
|
||||
|
||||
|
@ -796,7 +807,6 @@ struct MoveWhileDown2 : public OpRewritePattern<WhileOp> {
|
|||
if (ifOp.condition() != term.condition())
|
||||
return failure();
|
||||
|
||||
llvm::errs() << " moving while to for2: " << op <<"\n";
|
||||
SmallVector<std::pair<BlockArgument, Value>, 2> m;
|
||||
SmallVector<Value, 2> condArgs;
|
||||
SmallVector<Value, 2> prevArgs;
|
||||
|
@ -838,7 +848,7 @@ struct MoveWhileDown2 : public OpRewritePattern<WhileOp> {
|
|||
|
||||
|
||||
rewriter.setInsertionPoint(term);
|
||||
auto ncond = rewriter.replaceOpWithNewOp<ConditionOp>(term, term.condition(), condArgs);
|
||||
rewriter.replaceOpWithNewOp<ConditionOp>(term, term.condition(), condArgs);
|
||||
|
||||
for(int i=m.size()-1; i>=0; i--) {
|
||||
m[i].first.replaceAllUsesWith(m[i].second);
|
||||
|
@ -864,27 +874,68 @@ struct MoveWhileDown2 : public OpRewritePattern<WhileOp> {
|
|||
|
||||
for(auto pair : llvm::enumerate(prevArgs)) {
|
||||
pair.value().replaceAllUsesWith(nop.getResult(pair.index()));
|
||||
/*
|
||||
for (OpOperand &use :
|
||||
llvm::make_early_inc_range(pair.value().getUses())) {
|
||||
if (nop.after().isAncestor(use.getOwner()->getParentRegion()))
|
||||
rewriter.updateRootInPlace(use.getOwner(),
|
||||
[&]() { use.set(nop.getResult(pair.index())); });
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
rewriter.eraseOp(op);
|
||||
|
||||
|
||||
llvm::errs() << " nop: " << nop << "\n";
|
||||
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
struct MoveWhileDown3 : public OpRewritePattern<WhileOp> {
|
||||
using OpRewritePattern<WhileOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(WhileOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto term = cast<scf::ConditionOp>(op.before().front().getTerminator());
|
||||
SmallVector<BlockArgument, 2> toErase;
|
||||
SmallVector<Value, 2> newOps;
|
||||
SmallVector<Value, 2> condOps;
|
||||
for (auto pair : llvm::zip(op.getResults(), term.args(), llvm::make_early_inc_range(op.getAfterArguments()))) {
|
||||
if (std::get<0>(pair).use_empty() && std::get<1>(pair).hasOneUse()) {
|
||||
// todo generalize to any non memory
|
||||
if (auto idx = std::get<1>(pair).getDefiningOp<IndexCastOp>()) {
|
||||
std::get<2>(pair).replaceAllUsesWith(idx);
|
||||
idx->moveBefore(&op.after().front().front());
|
||||
toErase.push_back(std::get<2>(pair));
|
||||
for(auto& o : llvm::make_early_inc_range(idx->getOpOperands())) {
|
||||
newOps.push_back(o.get());
|
||||
o.set(op.after().front().addArgument(o.get().getType()));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
condOps.push_back(std::get<1>(pair));
|
||||
}
|
||||
if (toErase.size() == 0) return failure();
|
||||
|
||||
SmallVector<Value, 2> returns(condOps.begin(), condOps.end());
|
||||
|
||||
condOps.append(newOps.begin(), newOps.end());
|
||||
for (int i=toErase.size()-1; i>=0; i--) {
|
||||
op.after().front().eraseArgument(i);
|
||||
}
|
||||
rewriter.setInsertionPoint(term);
|
||||
rewriter.replaceOpWithNewOp<ConditionOp>(term, term.condition(), condOps);
|
||||
|
||||
rewriter.setInsertionPoint(op);
|
||||
SmallVector<Type, 4> resultTypes;
|
||||
for(auto v : returns) {
|
||||
resultTypes.push_back(v.getType());
|
||||
}
|
||||
auto nop = rewriter.create<WhileOp>(op.getLoc(), resultTypes, op.inits());
|
||||
|
||||
nop.before().takeBody(op.before());
|
||||
nop.after().takeBody(op.after());
|
||||
|
||||
for(auto pair : llvm::enumerate(returns)) {
|
||||
pair.value().replaceAllUsesWith(nop.getResult(pair.index()));
|
||||
}
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct RemoveUnusedCondVar : public OpRewritePattern<WhileOp> {
|
||||
using OpRewritePattern<WhileOp>::OpRewritePattern;
|
||||
|
@ -994,7 +1045,7 @@ void CanonicalizeFor::runOnFunction() {
|
|||
|
||||
mlir::RewritePatternSet rpl(getFunction().getContext());
|
||||
rpl.add<ForOpInductionReplacement, RemoveUnusedArgs,
|
||||
MoveWhileToFor, MoveWhileDown, MoveWhileDown2,
|
||||
MoveWhileToFor, MoveWhileDown, MoveWhileDown2, MoveWhileDown3,
|
||||
RemoveUnusedCondVar, MoveSideEffectFreeWhile>(getFunction().getContext());
|
||||
GreedyRewriteConfig config;
|
||||
applyPatternsAndFoldGreedily(getFunction().getOperation(), std::move(rpl),
|
||||
|
|
|
@ -128,13 +128,13 @@ llvm::raw_ostream& operator<<(llvm::raw_ostream& os, Wrapper& w) {
|
|||
template<typename T>
|
||||
struct Iter : public std::iterator<
|
||||
std::input_iterator_tag, // iterator_category
|
||||
Wrapper*,
|
||||
std::ptrdiff_t,
|
||||
Wrapper**,
|
||||
Wrapper* > {
|
||||
T it;
|
||||
Iter(T it) : it(it) {}
|
||||
Wrapper* operator*() const {
|
||||
Block* B = *it;
|
||||
return (Wrapper*)B;
|
||||
}
|
||||
Wrapper* operator*() const;
|
||||
bool operator!=(Iter I) const {
|
||||
return it != I.it;
|
||||
}
|
||||
|
@ -144,6 +144,9 @@ struct Iter : public std::iterator<
|
|||
void operator++() {
|
||||
++it;
|
||||
}
|
||||
Iter<T> operator--() {
|
||||
return --it;
|
||||
}
|
||||
Iter<T> operator++(int) {
|
||||
auto prev = *this;
|
||||
it++;
|
||||
|
@ -151,8 +154,47 @@ struct Iter : public std::iterator<
|
|||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
Wrapper* Iter<Region::iterator>::operator*() const {
|
||||
Block& B = *it;
|
||||
return (Wrapper*)&B;
|
||||
}
|
||||
template<>
|
||||
Wrapper* Iter<Region::reverse_iterator>::operator*() const {
|
||||
Block& B = *it;
|
||||
return (Wrapper*)&B;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
Wrapper* Iter<T>::operator*() const {
|
||||
Block* B = *it;
|
||||
return (Wrapper*)B;
|
||||
}
|
||||
|
||||
namespace llvm {
|
||||
template <>
|
||||
struct GraphTraits<RWrapper *> {
|
||||
using nodes_iterator = Iter<Region::iterator>;
|
||||
static Wrapper* getEntryNode(RWrapper* bb) { return (Wrapper*)&((Region*)bb)->front(); }
|
||||
static nodes_iterator nodes_begin(RWrapper* bb) {
|
||||
return ((Region*)bb)->begin();
|
||||
}
|
||||
static nodes_iterator nodes_end(RWrapper* bb) {
|
||||
return ((Region*)bb)->end();
|
||||
}
|
||||
};
|
||||
template <>
|
||||
struct GraphTraits<Inverse<RWrapper *>> {
|
||||
using nodes_iterator = Iter<Region::reverse_iterator>;
|
||||
static Wrapper* getEntryNode(RWrapper* bb) { return (Wrapper*)&((Region*)bb)->front(); }
|
||||
static nodes_iterator nodes_begin(RWrapper* bb) {
|
||||
return ((Region*)bb)->rbegin();
|
||||
}
|
||||
static nodes_iterator nodes_end(RWrapper* bb) {
|
||||
return ((Region*)bb)->rend();
|
||||
}
|
||||
};
|
||||
template <>
|
||||
struct GraphTraits<const Wrapper *> {
|
||||
using ChildIteratorType = Iter<Block::succ_iterator>;
|
||||
using Node = const Wrapper;
|
||||
|
@ -231,9 +273,11 @@ struct LoopRestructure : public mlir::LoopRestructureBase<LoopRestructure> {
|
|||
// Instantiate a variant of LLVM LoopInfo that works on mlir::Block
|
||||
#include "llvm/Analysis/LoopInfo.h"
|
||||
#include "llvm/Analysis/LoopInfoImpl.h"
|
||||
#include "llvm/Support/GenericDomTreeConstruction.h"
|
||||
|
||||
template class llvm::DominatorTreeBase<Wrapper, false>;
|
||||
template class llvm::DomTreeNodeBase<Wrapper>;
|
||||
//template void llvm::DomTreeBuilder::ApplyUpdates<llvm::DominatorTreeBase<Wrapper, false>>;
|
||||
|
||||
namespace mlir {
|
||||
class Loop : public llvm::LoopBase<Wrapper, mlir::Loop> {
|
||||
|
|
Loading…
Reference in New Issue