Fixup lower
This commit is contained in:
parent
8b95da86a9
commit
1a1a1955b6
|
@ -55,9 +55,9 @@ static void findValuesUsedBelow(Operation *op,
|
|||
|
||||
/// Returns `true` if the given operation has a BarrierOp nested in one of its
|
||||
/// regions whose immediate parent is the given operation itself.
|
||||
static bool hasNestedBarrier(Operation *op, Operation *direct = nullptr) {
|
||||
auto result = op->walk([=](polygeist::BarrierOp op) {
|
||||
if (!direct || op->getParentOp() == direct)
|
||||
static bool hasNestedBarrier(Operation *direct) {
|
||||
auto result = direct->walk([=](polygeist::BarrierOp op) {
|
||||
if (op->getParentOp() == direct)
|
||||
return WalkResult::interrupt();
|
||||
else
|
||||
return WalkResult::skip();
|
||||
|
@ -96,7 +96,7 @@ struct ReplaceIfWithFors : public OpRewritePattern<scf::IfOp> {
|
|||
return failure();
|
||||
}
|
||||
|
||||
if (!hasNestedBarrier(op, op)) {
|
||||
if (!hasNestedBarrier(op)) {
|
||||
LLVM_DEBUG(DBGS() << "[if-to-for] no nested barrier\n");
|
||||
return failure();
|
||||
}
|
||||
|
@ -255,7 +255,7 @@ struct NormalizeParallel : public OpRewritePattern<scf::ParallelOp> {
|
|||
/// Checks if `op` may need to be wrapped in a pair of barriers. This is a
|
||||
/// necessary but insufficient condition.
|
||||
static LogicalResult canWrapWithBarriers(Operation *op) {
|
||||
if (!isa<scf::ParallelOp>(op->getParentOp())) {
|
||||
if (!op->getParentOfType<scf::ParallelOp>()) {
|
||||
LLVM_DEBUG(DBGS() << "[wrap] not nested in a pfor\n");
|
||||
return failure();
|
||||
}
|
||||
|
@ -537,7 +537,8 @@ findInsertionPointAfterLoopOperands(scf::ParallelOp op) {
|
|||
return findNearestPostDominatingInsertionPoint(operands, postDominanceInfo);
|
||||
}
|
||||
|
||||
static Value allocateTemporaryBuffer(PatternRewriter &rewriter, Value value,
|
||||
template<typename T>
|
||||
static T allocateTemporaryBuffer(PatternRewriter &rewriter, Value value,
|
||||
ValueRange iterationCounts) {
|
||||
SmallVector<int64_t> bufferSize(iterationCounts.size(),
|
||||
ShapedType::kDynamicSize);
|
||||
|
@ -559,9 +560,7 @@ static Value allocateTemporaryBuffer(PatternRewriter &rewriter, Value value,
|
|||
}
|
||||
}
|
||||
auto type = MemRefType::get(bufferSize, ty);
|
||||
Value alloc =
|
||||
rewriter.create<memref::AllocaOp>(value.getLoc(), type, iterationCounts);
|
||||
return alloc;
|
||||
return rewriter.create<T>(value.getLoc(), type, iterationCounts);
|
||||
}
|
||||
|
||||
/// Interchanges a parallel for loop with a while loop it contains. The while
|
||||
|
@ -617,7 +616,7 @@ struct InterchangeWhilePFor : public OpRewritePattern<scf::ParallelOp> {
|
|||
findInsertionPointAfterLoopOperands(op);
|
||||
rewriter.setInsertionPoint(insertionPoint.first, insertionPoint.second);
|
||||
SmallVector<Value> iterationCounts = emitIterationCounts(rewriter, op);
|
||||
Value allocated = allocateTemporaryBuffer(
|
||||
Value allocated = allocateTemporaryBuffer<memref::AllocaOp>(
|
||||
rewriter, conditionOp.condition(), iterationCounts);
|
||||
Value zero = rewriter.create<ConstantIndexOp>(op.getLoc(), 0);
|
||||
|
||||
|
@ -729,20 +728,20 @@ struct DistributeAroundBarrier : public OpRewritePattern<scf::ParallelOp> {
|
|||
}
|
||||
barrier = &*it;
|
||||
}
|
||||
|
||||
|
||||
llvm::SetVector<Value> crossing;
|
||||
findValuesUsedBelow(barrier, crossing);
|
||||
std::pair<Block *, Block::iterator> insertPoint =
|
||||
findInsertionPointAfterLoopOperands(op);
|
||||
|
||||
rewriter.setInsertionPoint(insertPoint.first, insertPoint.second);
|
||||
//std::pair<Block *, Block::iterator> insertPoint =
|
||||
// findInsertionPointAfterLoopOperands(op);
|
||||
//rewriter.setInsertionPoint(insertPoint.first, insertPoint.second);
|
||||
rewriter.setInsertionPoint(op);
|
||||
SmallVector<Value> iterationCounts = emitIterationCounts(rewriter, op);
|
||||
|
||||
// Allocate space for values crossing the barrier.
|
||||
SmallVector<Value> allocations;
|
||||
SmallVector<memref::AllocOp> allocations;
|
||||
allocations.reserve(crossing.size());
|
||||
for (Value v : crossing) {
|
||||
allocations.push_back(allocateTemporaryBuffer(rewriter, v, iterationCounts));
|
||||
allocations.push_back(allocateTemporaryBuffer<memref::AllocOp>(rewriter, v, iterationCounts));
|
||||
}
|
||||
|
||||
// Store values crossing the barrier in caches immediately when ready.
|
||||
|
@ -781,6 +780,9 @@ struct DistributeAroundBarrier : public OpRewritePattern<scf::ParallelOp> {
|
|||
op.getLoc(), op.lowerBound(), op.upperBound(), op.step());
|
||||
rewriter.eraseOp(&newLoop.getBody()->back());
|
||||
|
||||
for (auto alloc : allocations)
|
||||
rewriter.create<memref::DeallocOp>(alloc.getLoc(), alloc);
|
||||
|
||||
// Recreate the operations in the new loop with new values.
|
||||
rewriter.setInsertionPointToStart(newLoop.getBody());
|
||||
BlockAndValueMapping mapping;
|
||||
|
@ -973,7 +975,7 @@ struct CPUifyPass : public SCFCPUifyBase<CPUifyPass> {
|
|||
NormalizeParallel, RotateWhile, DistributeAroundBarrier>(
|
||||
&getContext());
|
||||
GreedyRewriteConfig config;
|
||||
config.maxIterations = 42;
|
||||
config.maxIterations = 142;
|
||||
if (failed(applyPatternsAndFoldGreedily(getFunction(), std::move(patterns),
|
||||
config)))
|
||||
signalPassFailure();
|
||||
|
|
|
@ -2652,6 +2652,7 @@ ValueWithOffsets MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
|
|||
(funcs.count(sr->getDecl()->getName().str()) ||
|
||||
sr->getDecl()->getName().startswith("mkl_") ||
|
||||
sr->getDecl()->getName().startswith("MKL_") ||
|
||||
sr->getDecl()->getName().startswith("cublas") ||
|
||||
sr->getDecl()->getName().startswith("cblas_"))) {
|
||||
|
||||
std::vector<mlir::Value> args;
|
||||
|
@ -4648,10 +4649,13 @@ ValueWithOffsets MLIRScanner::VisitSwitchStmt(clang::SwitchStmt *stmt) {
|
|||
auto &exitB = *(new Block());
|
||||
builder.setInsertionPointToStart(&exitB);
|
||||
builder.create<scf::YieldOp>(loc);
|
||||
builder.setInsertionPointToStart(&exitB);
|
||||
|
||||
SmallVector<Block *> blocks;
|
||||
bool inCase = false;
|
||||
|
||||
Block* defaultB = &exitB;
|
||||
|
||||
for (auto cse : stmt->getBody()->children()) {
|
||||
if (auto cses = dyn_cast<CaseStmt>(cse)) {
|
||||
auto &condB = *(new Block());
|
||||
|
@ -4683,6 +4687,31 @@ ValueWithOffsets MLIRScanner::VisitSwitchStmt(clang::SwitchStmt *stmt) {
|
|||
builder.create<mlir::memref::StoreOp>(loc, truev,
|
||||
loops.back().keepRunning);
|
||||
Visit(cses->getSubStmt());
|
||||
} else if (auto cses = dyn_cast<DefaultStmt>(cse)) {
|
||||
auto &condB = *(new Block());
|
||||
|
||||
if (inCase) {
|
||||
auto noBreak =
|
||||
builder.create<mlir::memref::LoadOp>(loc, loops.back().noBreak);
|
||||
builder.create<mlir::CondBranchOp>(loc, noBreak, &condB, &exitB);
|
||||
loops.pop_back();
|
||||
}
|
||||
|
||||
inCase = true;
|
||||
er.region().getBlocks().push_back(&condB);
|
||||
builder.setInsertionPointToStart(&condB);
|
||||
|
||||
auto i1Ty = builder.getIntegerType(1);
|
||||
auto type = mlir::MemRefType::get({}, i1Ty, {}, 0);
|
||||
auto truev = builder.create<mlir::ConstantIntOp>(loc, true, 1);
|
||||
loops.push_back(
|
||||
(LoopContext){builder.create<mlir::memref::AllocaOp>(loc, type),
|
||||
builder.create<mlir::memref::AllocaOp>(loc, type)});
|
||||
builder.create<mlir::memref::StoreOp>(loc, truev, loops.back().noBreak);
|
||||
builder.create<mlir::memref::StoreOp>(loc, truev,
|
||||
loops.back().keepRunning);
|
||||
defaultB = &condB;
|
||||
Visit(cses->getSubStmt());
|
||||
} else {
|
||||
Visit(cse);
|
||||
}
|
||||
|
@ -4701,7 +4730,7 @@ ValueWithOffsets MLIRScanner::VisitSwitchStmt(clang::SwitchStmt *stmt) {
|
|||
|
||||
builder.setInsertionPointToStart(&er.region().front());
|
||||
builder.create<mlir::SwitchOp>(
|
||||
loc, cond, &exitB, ArrayRef<mlir::Value>(), caseValuesAttr, blocks,
|
||||
loc, cond, defaultB, ArrayRef<mlir::Value>(), caseValuesAttr, blocks,
|
||||
SmallVector<mlir::ValueRange>(caseVals.size(), ArrayRef<mlir::Value>()));
|
||||
builder.setInsertionPoint(oldblock2, oldpoint2);
|
||||
return nullptr;
|
||||
|
|
|
@ -389,6 +389,7 @@ int main(int argc, char **argv) {
|
|||
module->dump();
|
||||
llvm::errs() << "</immediate: mlir>\n";
|
||||
}
|
||||
|
||||
pm.enableVerifier(false);
|
||||
mlir::OpPassManager &optPM = pm.nest<mlir::FuncOp>();
|
||||
if (true) {
|
||||
|
|
Loading…
Reference in New Issue