From 1a1a1955b6d117f728366dd71acdc92ffb8c38b3 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Wed, 25 Aug 2021 11:10:00 -0400 Subject: [PATCH] Fixup lower --- .../Passes/ParallelLoopDistribute.cpp | 38 ++++++++++--------- mlir-clang/Lib/clang-mlir.cc | 31 ++++++++++++++- mlir-clang/mlir-clang.cc | 1 + 3 files changed, 51 insertions(+), 19 deletions(-) diff --git a/lib/polygeist/Passes/ParallelLoopDistribute.cpp b/lib/polygeist/Passes/ParallelLoopDistribute.cpp index f903dc7..052eccc 100644 --- a/lib/polygeist/Passes/ParallelLoopDistribute.cpp +++ b/lib/polygeist/Passes/ParallelLoopDistribute.cpp @@ -55,9 +55,9 @@ static void findValuesUsedBelow(Operation *op, /// Returns `true` if the given operation has a BarrierOp transitively nested in /// one of its regions. -static bool hasNestedBarrier(Operation *op, Operation *direct = nullptr) { - auto result = op->walk([=](polygeist::BarrierOp op) { - if (!direct || op->getParentOp() == direct) +static bool hasNestedBarrier(Operation *direct) { + auto result = direct->walk([=](polygeist::BarrierOp op) { + if (op->getParentOp() == direct) return WalkResult::interrupt(); else return WalkResult::skip(); @@ -96,7 +96,7 @@ struct ReplaceIfWithFors : public OpRewritePattern { return failure(); } - if (!hasNestedBarrier(op, op)) { + if (!hasNestedBarrier(op)) { LLVM_DEBUG(DBGS() << "[if-to-for] no nested barrier\n"); return failure(); } @@ -255,7 +255,7 @@ struct NormalizeParallel : public OpRewritePattern { /// Checks if `op` may need to be wrapped in a pair of barriers. This is a /// necessary but insufficient condition. static LogicalResult canWrapWithBarriers(Operation *op) { - if (!isa(op->getParentOp())) { + if (!op->getParentOfType()) { LLVM_DEBUG(DBGS() << "[wrap] not nested in a pfor\n"); return failure(); } @@ -537,7 +537,8 @@ findInsertionPointAfterLoopOperands(scf::ParallelOp op) { return findNearestPostDominatingInsertionPoint(operands, postDominanceInfo); } -static Value allocateTemporaryBuffer(PatternRewriter &rewriter, Value value, +template +static T allocateTemporaryBuffer(PatternRewriter &rewriter, Value value, ValueRange iterationCounts) { SmallVector bufferSize(iterationCounts.size(), ShapedType::kDynamicSize); @@ -559,9 +560,7 @@ static Value allocateTemporaryBuffer(PatternRewriter &rewriter, Value value, } } auto type = MemRefType::get(bufferSize, ty); - Value alloc = - rewriter.create(value.getLoc(), type, iterationCounts); - return alloc; + return rewriter.create(value.getLoc(), type, iterationCounts); } /// Interchanges a parallel for loop with a while loop it contains. The while @@ -617,7 +616,7 @@ struct InterchangeWhilePFor : public OpRewritePattern { findInsertionPointAfterLoopOperands(op); rewriter.setInsertionPoint(insertionPoint.first, insertionPoint.second); SmallVector iterationCounts = emitIterationCounts(rewriter, op); - Value allocated = allocateTemporaryBuffer( + Value allocated = allocateTemporaryBuffer( rewriter, conditionOp.condition(), iterationCounts); Value zero = rewriter.create(op.getLoc(), 0); @@ -729,20 +728,20 @@ struct DistributeAroundBarrier : public OpRewritePattern { } barrier = &*it; } - + llvm::SetVector crossing; findValuesUsedBelow(barrier, crossing); - std::pair insertPoint = - findInsertionPointAfterLoopOperands(op); - - rewriter.setInsertionPoint(insertPoint.first, insertPoint.second); + //std::pair insertPoint = + // findInsertionPointAfterLoopOperands(op); + //rewriter.setInsertionPoint(insertPoint.first, insertPoint.second); + rewriter.setInsertionPoint(op); SmallVector iterationCounts = emitIterationCounts(rewriter, op); // Allocate space for values crossing the barrier. - SmallVector allocations; + SmallVector allocations; allocations.reserve(crossing.size()); for (Value v : crossing) { - allocations.push_back(allocateTemporaryBuffer(rewriter, v, iterationCounts)); + allocations.push_back(allocateTemporaryBuffer(rewriter, v, iterationCounts)); } // Store values crossing the barrier in caches immediately when ready. @@ -781,6 +780,9 @@ struct DistributeAroundBarrier : public OpRewritePattern { op.getLoc(), op.lowerBound(), op.upperBound(), op.step()); rewriter.eraseOp(&newLoop.getBody()->back()); + for (auto alloc : allocations) + rewriter.create(alloc.getLoc(), alloc); + // Recreate the operations in the new loop with new values. rewriter.setInsertionPointToStart(newLoop.getBody()); BlockAndValueMapping mapping; @@ -973,7 +975,7 @@ struct CPUifyPass : public SCFCPUifyBase { NormalizeParallel, RotateWhile, DistributeAroundBarrier>( &getContext()); GreedyRewriteConfig config; - config.maxIterations = 42; + config.maxIterations = 142; if (failed(applyPatternsAndFoldGreedily(getFunction(), std::move(patterns), config))) signalPassFailure(); diff --git a/mlir-clang/Lib/clang-mlir.cc b/mlir-clang/Lib/clang-mlir.cc index df05a16..7389c9a 100644 --- a/mlir-clang/Lib/clang-mlir.cc +++ b/mlir-clang/Lib/clang-mlir.cc @@ -2652,6 +2652,7 @@ ValueWithOffsets MLIRScanner::VisitCallExpr(clang::CallExpr *expr) { (funcs.count(sr->getDecl()->getName().str()) || sr->getDecl()->getName().startswith("mkl_") || sr->getDecl()->getName().startswith("MKL_") || + sr->getDecl()->getName().startswith("cublas") || sr->getDecl()->getName().startswith("cblas_"))) { std::vector args; @@ -4648,10 +4649,13 @@ ValueWithOffsets MLIRScanner::VisitSwitchStmt(clang::SwitchStmt *stmt) { auto &exitB = *(new Block()); builder.setInsertionPointToStart(&exitB); builder.create(loc); + builder.setInsertionPointToStart(&exitB); SmallVector blocks; bool inCase = false; + Block* defaultB = &exitB; + for (auto cse : stmt->getBody()->children()) { if (auto cses = dyn_cast(cse)) { auto &condB = *(new Block()); @@ -4683,6 +4687,31 @@ ValueWithOffsets MLIRScanner::VisitSwitchStmt(clang::SwitchStmt *stmt) { builder.create(loc, truev, loops.back().keepRunning); Visit(cses->getSubStmt()); + } else if (auto cses = dyn_cast(cse)) { + auto &condB = *(new Block()); + + if (inCase) { + auto noBreak = + builder.create(loc, loops.back().noBreak); + builder.create(loc, noBreak, &condB, &exitB); + loops.pop_back(); + } + + inCase = true; + er.region().getBlocks().push_back(&condB); + builder.setInsertionPointToStart(&condB); + + auto i1Ty = builder.getIntegerType(1); + auto type = mlir::MemRefType::get({}, i1Ty, {}, 0); + auto truev = builder.create(loc, true, 1); + loops.push_back( + (LoopContext){builder.create(loc, type), + builder.create(loc, type)}); + builder.create(loc, truev, loops.back().noBreak); + builder.create(loc, truev, + loops.back().keepRunning); + defaultB = &condB; + Visit(cses->getSubStmt()); } else { Visit(cse); } @@ -4701,7 +4730,7 @@ ValueWithOffsets MLIRScanner::VisitSwitchStmt(clang::SwitchStmt *stmt) { builder.setInsertionPointToStart(&er.region().front()); builder.create( - loc, cond, &exitB, ArrayRef(), caseValuesAttr, blocks, + loc, cond, defaultB, ArrayRef(), caseValuesAttr, blocks, SmallVector(caseVals.size(), ArrayRef())); builder.setInsertionPoint(oldblock2, oldpoint2); return nullptr; diff --git a/mlir-clang/mlir-clang.cc b/mlir-clang/mlir-clang.cc index dcdb2ba..2e938ff 100644 --- a/mlir-clang/mlir-clang.cc +++ b/mlir-clang/mlir-clang.cc @@ -389,6 +389,7 @@ int main(int argc, char **argv) { module->dump(); llvm::errs() << "\n"; } + pm.enableVerifier(false); mlir::OpPassManager &optPM = pm.nest(); if (true) {