Fixup lower

This commit is contained in:
William S. Moses 2021-08-25 11:10:00 -04:00 committed by William Moses
parent 8b95da86a9
commit 1a1a1955b6
3 changed files with 51 additions and 19 deletions

View File

@ -55,9 +55,9 @@ static void findValuesUsedBelow(Operation *op,
/// Returns `true` if the given operation has a BarrierOp transitively nested in
/// one of its regions.
static bool hasNestedBarrier(Operation *op, Operation *direct = nullptr) {
auto result = op->walk([=](polygeist::BarrierOp op) {
if (!direct || op->getParentOp() == direct)
static bool hasNestedBarrier(Operation *direct) {
auto result = direct->walk([=](polygeist::BarrierOp op) {
if (op->getParentOp() == direct)
return WalkResult::interrupt();
else
return WalkResult::skip();
@ -96,7 +96,7 @@ struct ReplaceIfWithFors : public OpRewritePattern<scf::IfOp> {
return failure();
}
if (!hasNestedBarrier(op, op)) {
if (!hasNestedBarrier(op)) {
LLVM_DEBUG(DBGS() << "[if-to-for] no nested barrier\n");
return failure();
}
@ -255,7 +255,7 @@ struct NormalizeParallel : public OpRewritePattern<scf::ParallelOp> {
/// Checks if `op` may need to be wrapped in a pair of barriers. This is a
/// necessary but insufficient condition.
static LogicalResult canWrapWithBarriers(Operation *op) {
if (!isa<scf::ParallelOp>(op->getParentOp())) {
if (!op->getParentOfType<scf::ParallelOp>()) {
LLVM_DEBUG(DBGS() << "[wrap] not nested in a pfor\n");
return failure();
}
@ -537,7 +537,8 @@ findInsertionPointAfterLoopOperands(scf::ParallelOp op) {
return findNearestPostDominatingInsertionPoint(operands, postDominanceInfo);
}
static Value allocateTemporaryBuffer(PatternRewriter &rewriter, Value value,
template<typename T>
static T allocateTemporaryBuffer(PatternRewriter &rewriter, Value value,
ValueRange iterationCounts) {
SmallVector<int64_t> bufferSize(iterationCounts.size(),
ShapedType::kDynamicSize);
@ -559,9 +560,7 @@ static Value allocateTemporaryBuffer(PatternRewriter &rewriter, Value value,
}
}
auto type = MemRefType::get(bufferSize, ty);
Value alloc =
rewriter.create<memref::AllocaOp>(value.getLoc(), type, iterationCounts);
return alloc;
return rewriter.create<T>(value.getLoc(), type, iterationCounts);
}
/// Interchanges a parallel for loop with a while loop it contains. The while
@ -617,7 +616,7 @@ struct InterchangeWhilePFor : public OpRewritePattern<scf::ParallelOp> {
findInsertionPointAfterLoopOperands(op);
rewriter.setInsertionPoint(insertionPoint.first, insertionPoint.second);
SmallVector<Value> iterationCounts = emitIterationCounts(rewriter, op);
Value allocated = allocateTemporaryBuffer(
Value allocated = allocateTemporaryBuffer<memref::AllocaOp>(
rewriter, conditionOp.condition(), iterationCounts);
Value zero = rewriter.create<ConstantIndexOp>(op.getLoc(), 0);
@ -729,20 +728,20 @@ struct DistributeAroundBarrier : public OpRewritePattern<scf::ParallelOp> {
}
barrier = &*it;
}
llvm::SetVector<Value> crossing;
findValuesUsedBelow(barrier, crossing);
std::pair<Block *, Block::iterator> insertPoint =
findInsertionPointAfterLoopOperands(op);
rewriter.setInsertionPoint(insertPoint.first, insertPoint.second);
//std::pair<Block *, Block::iterator> insertPoint =
// findInsertionPointAfterLoopOperands(op);
//rewriter.setInsertionPoint(insertPoint.first, insertPoint.second);
rewriter.setInsertionPoint(op);
SmallVector<Value> iterationCounts = emitIterationCounts(rewriter, op);
// Allocate space for values crossing the barrier.
SmallVector<Value> allocations;
SmallVector<memref::AllocOp> allocations;
allocations.reserve(crossing.size());
for (Value v : crossing) {
allocations.push_back(allocateTemporaryBuffer(rewriter, v, iterationCounts));
allocations.push_back(allocateTemporaryBuffer<memref::AllocOp>(rewriter, v, iterationCounts));
}
// Store values crossing the barrier in caches immediately when ready.
@ -781,6 +780,9 @@ struct DistributeAroundBarrier : public OpRewritePattern<scf::ParallelOp> {
op.getLoc(), op.lowerBound(), op.upperBound(), op.step());
rewriter.eraseOp(&newLoop.getBody()->back());
for (auto alloc : allocations)
rewriter.create<memref::DeallocOp>(alloc.getLoc(), alloc);
// Recreate the operations in the new loop with new values.
rewriter.setInsertionPointToStart(newLoop.getBody());
BlockAndValueMapping mapping;
@ -973,7 +975,7 @@ struct CPUifyPass : public SCFCPUifyBase<CPUifyPass> {
NormalizeParallel, RotateWhile, DistributeAroundBarrier>(
&getContext());
GreedyRewriteConfig config;
config.maxIterations = 42;
config.maxIterations = 142;
if (failed(applyPatternsAndFoldGreedily(getFunction(), std::move(patterns),
config)))
signalPassFailure();

View File

@ -2652,6 +2652,7 @@ ValueWithOffsets MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
(funcs.count(sr->getDecl()->getName().str()) ||
sr->getDecl()->getName().startswith("mkl_") ||
sr->getDecl()->getName().startswith("MKL_") ||
sr->getDecl()->getName().startswith("cublas") ||
sr->getDecl()->getName().startswith("cblas_"))) {
std::vector<mlir::Value> args;
@ -4648,10 +4649,13 @@ ValueWithOffsets MLIRScanner::VisitSwitchStmt(clang::SwitchStmt *stmt) {
auto &exitB = *(new Block());
builder.setInsertionPointToStart(&exitB);
builder.create<scf::YieldOp>(loc);
builder.setInsertionPointToStart(&exitB);
SmallVector<Block *> blocks;
bool inCase = false;
Block* defaultB = &exitB;
for (auto cse : stmt->getBody()->children()) {
if (auto cses = dyn_cast<CaseStmt>(cse)) {
auto &condB = *(new Block());
@ -4683,6 +4687,31 @@ ValueWithOffsets MLIRScanner::VisitSwitchStmt(clang::SwitchStmt *stmt) {
builder.create<mlir::memref::StoreOp>(loc, truev,
loops.back().keepRunning);
Visit(cses->getSubStmt());
} else if (auto cses = dyn_cast<DefaultStmt>(cse)) {
auto &condB = *(new Block());
if (inCase) {
auto noBreak =
builder.create<mlir::memref::LoadOp>(loc, loops.back().noBreak);
builder.create<mlir::CondBranchOp>(loc, noBreak, &condB, &exitB);
loops.pop_back();
}
inCase = true;
er.region().getBlocks().push_back(&condB);
builder.setInsertionPointToStart(&condB);
auto i1Ty = builder.getIntegerType(1);
auto type = mlir::MemRefType::get({}, i1Ty, {}, 0);
auto truev = builder.create<mlir::ConstantIntOp>(loc, true, 1);
loops.push_back(
(LoopContext){builder.create<mlir::memref::AllocaOp>(loc, type),
builder.create<mlir::memref::AllocaOp>(loc, type)});
builder.create<mlir::memref::StoreOp>(loc, truev, loops.back().noBreak);
builder.create<mlir::memref::StoreOp>(loc, truev,
loops.back().keepRunning);
defaultB = &condB;
Visit(cses->getSubStmt());
} else {
Visit(cse);
}
@ -4701,7 +4730,7 @@ ValueWithOffsets MLIRScanner::VisitSwitchStmt(clang::SwitchStmt *stmt) {
builder.setInsertionPointToStart(&er.region().front());
builder.create<mlir::SwitchOp>(
loc, cond, &exitB, ArrayRef<mlir::Value>(), caseValuesAttr, blocks,
loc, cond, defaultB, ArrayRef<mlir::Value>(), caseValuesAttr, blocks,
SmallVector<mlir::ValueRange>(caseVals.size(), ArrayRef<mlir::Value>()));
builder.setInsertionPoint(oldblock2, oldpoint2);
return nullptr;

View File

@ -389,6 +389,7 @@ int main(int argc, char **argv) {
module->dump();
llvm::errs() << "</immediate: mlir>\n";
}
pm.enableVerifier(false);
mlir::OpPassManager &optPM = pm.nest<mlir::FuncOp>();
if (true) {