Permit debug build

This commit is contained in:
William S. Moses 2021-06-25 01:14:04 -04:00
parent 721c54b271
commit 03354e8285
2 changed files with 151 additions and 56 deletions

View File

@ -255,7 +255,6 @@ struct RemoveUnusedArgs : public OpRewritePattern<ForOp> {
});
rewriter.setInsertionPoint(yieldOp);
rewriter.replaceOpWithNewOp<YieldOp>(yieldOp, usedYieldOperands);
rewriter.eraseOp(yieldOp);
// Replace the operation's results with the new ones.
SmallVector<Value, 4> repResults(op.getNumResults());
@ -489,14 +488,6 @@ struct MoveWhileToFor : public OpRewritePattern<WhileOp> {
return true;
}
unsigned countOperations(Region &reg) const {
unsigned count = 0;
for (auto &block : reg)
for (auto &nested : block)
count++;
return count;
}
LogicalResult matchAndRewrite(WhileOp loop,
PatternRewriter &rewriter) const override {
if (!isWhile(loop))
@ -518,6 +509,12 @@ struct MoveWhileToFor : public OpRewritePattern<WhileOp> {
llvm::errs() << condOp.condition() << "\n";
return failure();
}
size_t size = 0;
for(auto &m : loop.before().front())
size++;
if (size != 2) {
return failure();
}
size_t beforeArgNum;
Value maybeIndVar = cmpIOp.lhs();
@ -618,48 +615,63 @@ struct MoveWhileToFor : public OpRewritePattern<WhileOp> {
// loop.dump();
for (Value arg : condOp.args()) {
Type cst = nullptr;
if (auto idx = arg.getDefiningOp<IndexCastOp>()) {
cst = idx.getType();
arg = idx.in();
}
Value res;
if (isTopLevelArgValue(arg, &loop.before())) {
auto blockArg = arg.dyn_cast<BlockArgument>();
auto pos = blockArg.getArgNumber();
forArgs.push_back(loop.inits()[pos]);
res = loop.inits()[pos];
} else
forArgs.push_back(arg);
res = arg;
if (cst) {
res = rewriter.create<IndexCastOp>(rewriter.getUnknownLoc(), res, cst);
}
forArgs.push_back(res);
}
auto forloop = rewriter.create<scf::ForOp>(
loop.getLoc(), lb, ub, step, forArgs,
[&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
// map for the conditionOp value.
size_t pos = loop.inits().size();
SmallVector<Value, 2> mappedValues;
mappedValues.append(args.begin() + pos, args.end());
loop.getLoc(), lb, ub, step, forArgs);
if (!forloop.getBody()->empty())
rewriter.eraseOp(forloop.getBody()->getTerminator());
BlockAndValueMapping mapping;
mapping.map(loop.after().getArguments(), mappedValues);
for (auto &block : loop.after().getBlocks())
for (auto &nested : block.without_terminator())
b.clone(nested, mapping);
size_t pos = loop.inits().size();
for (auto pair : llvm::zip(loop.after().getArguments(), forloop.getRegionIterArgs().drop_front(pos))) {
std::get<0>(pair).replaceAllUsesWith(std::get<1>(pair));
}
for (auto pair : llvm::zip(loop.before().getArguments(), forloop.getRegionIterArgs().drop_back(pos))) {
std::get<0>(pair).replaceAllUsesWith(std::get<1>(pair));
}
auto oldYield =
cast<scf::YieldOp>(loop.after().front().getTerminator());
SmallVector<Value, 2> yieldOperands;
for (auto oldYieldArg : oldYield.results())
yieldOperands.push_back(mapping.lookupOrDefault(oldYieldArg));
forloop.getBody()->getOperations().splice(forloop.getBody()->getOperations().begin(), loop.after().front().getOperations());
BlockAndValueMapping outmap;
outmap.map(loop.before().getArguments(), yieldOperands);
for (auto arg : condOp.args())
yieldOperands.push_back(outmap.lookupOrDefault(arg));
auto oldYield = cast<scf::YieldOp>(forloop.getBody()->getTerminator());
b.create<scf::YieldOp>(loop.getLoc(), yieldOperands);
});
SmallVector<Value, 2> yieldOperands;
for (auto oldYieldArg : oldYield.results())
yieldOperands.push_back(oldYieldArg);
BlockAndValueMapping outmap;
outmap.map(loop.before().getArguments(), yieldOperands);
for (auto arg : condOp.args())
yieldOperands.push_back(outmap.lookupOrDefault(arg));
rewriter.setInsertionPoint(oldYield);
rewriter.replaceOpWithNewOp<scf::YieldOp>(oldYield, yieldOperands);
SmallVector<Value, 2> replacements;
size_t pos = loop.inits().size();
replacements.append(forloop.getResults().begin() + pos,
forloop.getResults().end());
llvm::errs() << " func: " << *loop->getParentOfType<FuncOp>() << "\n";
llvm::errs() << " op2: " << forloop << "\n";
llvm::errs() << " op: " << loop << "\n";
rewriter.replaceOp(loop, replacements);
auto m = forloop->getParentOfType<ModuleOp>();
return success();
}
};
@ -749,7 +761,6 @@ struct MoveWhileDown : public OpRewritePattern<WhileOp> {
}
};
#if 1
struct MoveWhileDown2 : public OpRewritePattern<WhileOp> {
using OpRewritePattern<WhileOp>::OpRewritePattern;
@ -796,7 +807,6 @@ struct MoveWhileDown2 : public OpRewritePattern<WhileOp> {
if (ifOp.condition() != term.condition())
return failure();
llvm::errs() << " moving while to for2: " << op <<"\n";
SmallVector<std::pair<BlockArgument, Value>, 2> m;
SmallVector<Value, 2> condArgs;
SmallVector<Value, 2> prevArgs;
@ -838,7 +848,7 @@ struct MoveWhileDown2 : public OpRewritePattern<WhileOp> {
rewriter.setInsertionPoint(term);
auto ncond = rewriter.replaceOpWithNewOp<ConditionOp>(term, term.condition(), condArgs);
rewriter.replaceOpWithNewOp<ConditionOp>(term, term.condition(), condArgs);
for(int i=m.size()-1; i>=0; i--) {
m[i].first.replaceAllUsesWith(m[i].second);
@ -864,27 +874,68 @@ struct MoveWhileDown2 : public OpRewritePattern<WhileOp> {
for(auto pair : llvm::enumerate(prevArgs)) {
pair.value().replaceAllUsesWith(nop.getResult(pair.index()));
/*
for (OpOperand &use :
llvm::make_early_inc_range(pair.value().getUses())) {
if (nop.after().isAncestor(use.getOwner()->getParentRegion()))
rewriter.updateRootInPlace(use.getOwner(),
[&]() { use.set(nop.getResult(pair.index())); });
}
*/
}
rewriter.eraseOp(op);
llvm::errs() << " nop: " << nop << "\n";
return success();
}
return failure();
}
};
#endif
struct MoveWhileDown3 : public OpRewritePattern<WhileOp> {
using OpRewritePattern<WhileOp>::OpRewritePattern;
LogicalResult matchAndRewrite(WhileOp op,
PatternRewriter &rewriter) const override {
auto term = cast<scf::ConditionOp>(op.before().front().getTerminator());
SmallVector<BlockArgument, 2> toErase;
SmallVector<Value, 2> newOps;
SmallVector<Value, 2> condOps;
for (auto pair : llvm::zip(op.getResults(), term.args(), llvm::make_early_inc_range(op.getAfterArguments()))) {
if (std::get<0>(pair).use_empty() && std::get<1>(pair).hasOneUse()) {
// todo generalize to any non memory
if (auto idx = std::get<1>(pair).getDefiningOp<IndexCastOp>()) {
std::get<2>(pair).replaceAllUsesWith(idx);
idx->moveBefore(&op.after().front().front());
toErase.push_back(std::get<2>(pair));
for(auto& o : llvm::make_early_inc_range(idx->getOpOperands())) {
newOps.push_back(o.get());
o.set(op.after().front().addArgument(o.get().getType()));
}
continue;
}
}
condOps.push_back(std::get<1>(pair));
}
if (toErase.size() == 0) return failure();
SmallVector<Value, 2> returns(condOps.begin(), condOps.end());
condOps.append(newOps.begin(), newOps.end());
for (int i=toErase.size()-1; i>=0; i--) {
op.after().front().eraseArgument(i);
}
rewriter.setInsertionPoint(term);
rewriter.replaceOpWithNewOp<ConditionOp>(term, term.condition(), condOps);
rewriter.setInsertionPoint(op);
SmallVector<Type, 4> resultTypes;
for(auto v : returns) {
resultTypes.push_back(v.getType());
}
auto nop = rewriter.create<WhileOp>(op.getLoc(), resultTypes, op.inits());
nop.before().takeBody(op.before());
nop.after().takeBody(op.after());
for(auto pair : llvm::enumerate(returns)) {
pair.value().replaceAllUsesWith(nop.getResult(pair.index()));
}
rewriter.eraseOp(op);
return success();
}
};
struct RemoveUnusedCondVar : public OpRewritePattern<WhileOp> {
using OpRewritePattern<WhileOp>::OpRewritePattern;
@ -994,7 +1045,7 @@ void CanonicalizeFor::runOnFunction() {
mlir::RewritePatternSet rpl(getFunction().getContext());
rpl.add<ForOpInductionReplacement, RemoveUnusedArgs,
MoveWhileToFor, MoveWhileDown, MoveWhileDown2,
MoveWhileToFor, MoveWhileDown, MoveWhileDown2, MoveWhileDown3,
RemoveUnusedCondVar, MoveSideEffectFreeWhile>(getFunction().getContext());
GreedyRewriteConfig config;
applyPatternsAndFoldGreedily(getFunction().getOperation(), std::move(rpl),

View File

@ -128,13 +128,13 @@ llvm::raw_ostream& operator<<(llvm::raw_ostream& os, Wrapper& w) {
template<typename T>
struct Iter : public std::iterator<
std::input_iterator_tag, // iterator_category
Wrapper*,
std::ptrdiff_t,
Wrapper**,
Wrapper* > {
T it;
Iter(T it) : it(it) {}
Wrapper* operator*() const {
Block* B = *it;
return (Wrapper*)B;
}
Wrapper* operator*() const;
bool operator!=(Iter I) const {
return it != I.it;
}
@ -144,6 +144,9 @@ struct Iter : public std::iterator<
void operator++() {
++it;
}
Iter<T> operator--() {
return --it;
}
Iter<T> operator++(int) {
auto prev = *this;
it++;
@ -151,8 +154,47 @@ struct Iter : public std::iterator<
}
};
template<>
Wrapper* Iter<Region::iterator>::operator*() const {
Block& B = *it;
return (Wrapper*)&B;
}
template<>
Wrapper* Iter<Region::reverse_iterator>::operator*() const {
Block& B = *it;
return (Wrapper*)&B;
}
template<typename T>
Wrapper* Iter<T>::operator*() const {
Block* B = *it;
return (Wrapper*)B;
}
namespace llvm {
template <>
struct GraphTraits<RWrapper *> {
using nodes_iterator = Iter<Region::iterator>;
static Wrapper* getEntryNode(RWrapper* bb) { return (Wrapper*)&((Region*)bb)->front(); }
static nodes_iterator nodes_begin(RWrapper* bb) {
return ((Region*)bb)->begin();
}
static nodes_iterator nodes_end(RWrapper* bb) {
return ((Region*)bb)->end();
}
};
template <>
struct GraphTraits<Inverse<RWrapper *>> {
using nodes_iterator = Iter<Region::reverse_iterator>;
static Wrapper* getEntryNode(RWrapper* bb) { return (Wrapper*)&((Region*)bb)->front(); }
static nodes_iterator nodes_begin(RWrapper* bb) {
return ((Region*)bb)->rbegin();
}
static nodes_iterator nodes_end(RWrapper* bb) {
return ((Region*)bb)->rend();
}
};
template <>
struct GraphTraits<const Wrapper *> {
using ChildIteratorType = Iter<Block::succ_iterator>;
using Node = const Wrapper;
@ -231,9 +273,11 @@ struct LoopRestructure : public mlir::LoopRestructureBase<LoopRestructure> {
// Instantiate a variant of LLVM LoopInfo that works on mlir::Block
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/LoopInfoImpl.h"
#include "llvm/Support/GenericDomTreeConstruction.h"
template class llvm::DominatorTreeBase<Wrapper, false>;
template class llvm::DomTreeNodeBase<Wrapper>;
//template void llvm::DomTreeBuilder::ApplyUpdates<llvm::DominatorTreeBase<Wrapper, false>>;
namespace mlir {
class Loop : public llvm::LoopBase<Wrapper, mlir::Loop> {