Lowering of OpenMP Parallel operation to LLVM IR 1/n

This patch introduces lowering of the OpenMP parallel operation to LLVM
IR using the OpenMPIRBuilder.

Functions topologicalSort and connectPhiNodes are generalised so that
they work with operations also. connectPhiNodes is also made static.

Lowering works for a parallel region with multiple blocks. Clauses and
arguments of the OpenMP operation are not handled.

Reviewed By: rriddle, anchu-rajendran

Differential Revision: https://reviews.llvm.org/D81660
This commit is contained in:
Kiran Chandramohan 2020-07-13 23:13:04 +01:00
parent 004bf35ba0
commit d9067dca7b
4 changed files with 210 additions and 75 deletions

View File

@ -24,7 +24,6 @@ def OpenMP_Dialect : Dialect {
class OpenMP_Op<string mnemonic, list<OpTrait> traits = []> :
Op<OpenMP_Dialect, mnemonic, traits>;
//===----------------------------------------------------------------------===//
// 2.6 parallel Construct
//===----------------------------------------------------------------------===//
@ -81,8 +80,8 @@ def ParallelOp : OpenMP_Op<"parallel", [AttrSizedOperandSegments]> {
of the parallel region.
}];
let arguments = (ins Optional<I1>:$if_expr_var,
Optional<AnyInteger>:$num_threads_var,
let arguments = (ins Optional<AnyType>:$if_expr_var,
Optional<AnyType>:$num_threads_var,
OptionalAttr<ClauseDefault>:$default_val,
Variadic<AnyType>:$private_vars,
Variadic<AnyType>:$firstprivate_vars,

View File

@ -87,6 +87,8 @@ protected:
llvm::IRBuilder<> &builder);
virtual LogicalResult convertOmpOperation(Operation &op,
llvm::IRBuilder<> &builder);
virtual LogicalResult convertOmpParallel(Operation &op,
llvm::IRBuilder<> &builder);
static std::unique_ptr<llvm::Module> prepareLLVMModule(Operation *m);
/// A helper to look up remapped operands in the value remapping table.
@ -100,7 +102,6 @@ private:
LogicalResult convertFunctions();
LogicalResult convertGlobals();
LogicalResult convertOneFunction(LLVMFuncOp func);
void connectPHINodes(LLVMFuncOp func);
LogicalResult convertBlock(Block &bb, bool ignoreArguments);
llvm::Constant *getLLVMConstant(llvm::Type *llvmType, Attribute attr,

View File

@ -25,11 +25,13 @@
#include "llvm/ADT/SetVector.h"
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
using namespace mlir;
@ -304,7 +306,160 @@ ModuleTranslation::ModuleTranslation(Operation *module,
assert(satisfiesLLVMModule(mlirModule) &&
"mlirModule should honor LLVM's module semantics.");
}
ModuleTranslation::~ModuleTranslation() {}
ModuleTranslation::~ModuleTranslation() {
if (ompBuilder)
ompBuilder->finalize();
}
/// Get the SSA value passed to the current block from the terminator operation
/// of its predecessor.
static Value getPHISourceValue(Block *current, Block *pred,
unsigned numArguments, unsigned index) {
Operation &terminator = *pred->getTerminator();
if (isa<LLVM::BrOp>(terminator))
return terminator.getOperand(index);
// For conditional branches, we need to check if the current block is reached
// through the "true" or the "false" branch and take the relevant operands.
auto condBranchOp = dyn_cast<LLVM::CondBrOp>(terminator);
assert(condBranchOp &&
"only branch operations can be terminators of a block that "
"has successors");
assert((condBranchOp.getSuccessor(0) != condBranchOp.getSuccessor(1)) &&
"successors with arguments in LLVM conditional branches must be "
"different blocks");
return condBranchOp.getSuccessor(0) == current
? condBranchOp.trueDestOperands()[index]
: condBranchOp.falseDestOperands()[index];
}
/// Connect the PHI nodes to the results of preceding blocks.
template <typename T>
static void
connectPHINodes(T &func, const DenseMap<Value, llvm::Value *> &valueMapping,
const DenseMap<Block *, llvm::BasicBlock *> &blockMapping) {
// Skip the first block, it cannot be branched to and its arguments correspond
// to the arguments of the LLVM function.
for (auto it = std::next(func.begin()), eit = func.end(); it != eit; ++it) {
Block *bb = &*it;
llvm::BasicBlock *llvmBB = blockMapping.lookup(bb);
auto phis = llvmBB->phis();
auto numArguments = bb->getNumArguments();
assert(numArguments == std::distance(phis.begin(), phis.end()));
for (auto &numberedPhiNode : llvm::enumerate(phis)) {
auto &phiNode = numberedPhiNode.value();
unsigned index = numberedPhiNode.index();
for (auto *pred : bb->getPredecessors()) {
phiNode.addIncoming(valueMapping.lookup(getPHISourceValue(
bb, pred, numArguments, index)),
blockMapping.lookup(pred));
}
}
}
}
// TODO: implement an iterative version
static void topologicalSortImpl(llvm::SetVector<Block *> &blocks, Block *b) {
blocks.insert(b);
for (Block *bb : b->getSuccessors()) {
if (blocks.count(bb) == 0)
topologicalSortImpl(blocks, bb);
}
}
/// Sort function blocks topologically.
template <typename T>
static llvm::SetVector<Block *> topologicalSort(T &f) {
// For each blocks that has not been visited yet (i.e. that has no
// predecessors), add it to the list and traverse its successors in DFS
// preorder.
llvm::SetVector<Block *> blocks;
for (Block &b : f) {
if (blocks.count(&b) == 0)
topologicalSortImpl(blocks, &b);
}
assert(blocks.size() == f.getBlocks().size() && "some blocks are not sorted");
return blocks;
}
/// Convert the OpenMP parallel Operation to LLVM IR.
LogicalResult
ModuleTranslation::convertOmpParallel(Operation &opInst,
llvm::IRBuilder<> &builder) {
using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
llvm::BasicBlock &continuationIP) {
llvm::LLVMContext &llvmContext = llvmModule->getContext();
llvm::BasicBlock *codeGenIPBB = codeGenIP.getBlock();
llvm::Instruction *codeGenIPBBTI = codeGenIPBB->getTerminator();
builder.SetInsertPoint(codeGenIPBB);
for (auto &region : opInst.getRegions()) {
for (auto &bb : region) {
auto *llvmBB = llvm::BasicBlock::Create(
llvmContext, "omp.par.region", codeGenIP.getBlock()->getParent());
blockMapping[&bb] = llvmBB;
}
// Then, convert blocks one by one in topological order to ensure
// defs are converted before uses.
llvm::SetVector<Block *> blocks = topologicalSort(region);
for (auto indexedBB : llvm::enumerate(blocks)) {
Block *bb = indexedBB.value();
llvm::BasicBlock *curLLVMBB = blockMapping[bb];
if (bb->isEntryBlock())
codeGenIPBBTI->setSuccessor(0, curLLVMBB);
// TODO: Error not returned up the hierarchy
if (failed(
convertBlock(*bb, /*ignoreArguments=*/indexedBB.index() == 0)))
return;
// If this block has the terminator then add a jump to
// continuation bb
for (auto &op : *bb) {
if (isa<omp::TerminatorOp>(op)) {
builder.SetInsertPoint(curLLVMBB);
builder.CreateBr(&continuationIP);
}
}
}
// Finally, after all blocks have been traversed and values mapped,
// connect the PHI nodes to the results of preceding blocks.
connectPHINodes(region, valueMapping, blockMapping);
}
};
// TODO: Perform appropriate actions according to the data-sharing
// attribute (shared, private, firstprivate, ...) of variables.
// Currently defaults to shared.
auto privCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
llvm::Value &vPtr,
llvm::Value *&replacementValue) -> InsertPointTy {
replacementValue = &vPtr;
return codeGenIP;
};
// TODO: Perform finalization actions for variables. This has to be
// called for variables which have destructors/finalizers.
auto finiCB = [&](InsertPointTy codeGenIP) {};
// TODO: The various operands of parallel operation are not handled.
// Parallel operation is created with some default options for now.
llvm::Value *ifCond = nullptr;
llvm::Value *numThreads = nullptr;
bool isCancellable = false;
builder.restoreIP(ompBuilder->CreateParallel(
builder, bodyGenCB, privCB, finiCB, ifCond, numThreads,
llvm::omp::OMP_PROC_BIND_default, isCancellable));
return success();
}
/// Given an OpenMP MLIR operation, create the corresponding LLVM IR
/// (including OpenMP runtime calls).
@ -340,6 +495,9 @@ ModuleTranslation::convertOmpOperation(Operation &opInst,
ompBuilder->CreateFlush(builder.saveIP());
return success();
})
.Case([&](omp::TerminatorOp) { return success(); })
.Case(
[&](omp::ParallelOp) { return convertOmpParallel(opInst, builder); })
.Default([&](Operation *inst) {
return inst->emitError("unsupported OpenMP operation: ")
<< inst->getName();
@ -556,75 +714,6 @@ LogicalResult ModuleTranslation::convertGlobals() {
return success();
}
/// Get the SSA value passed to the current block from the terminator operation
/// of its predecessor.
static Value getPHISourceValue(Block *current, Block *pred,
unsigned numArguments, unsigned index) {
auto &terminator = *pred->getTerminator();
if (isa<LLVM::BrOp>(terminator)) {
return terminator.getOperand(index);
}
// For conditional branches, we need to check if the current block is reached
// through the "true" or the "false" branch and take the relevant operands.
auto condBranchOp = dyn_cast<LLVM::CondBrOp>(terminator);
assert(condBranchOp &&
"only branch operations can be terminators of a block that "
"has successors");
assert((condBranchOp.getSuccessor(0) != condBranchOp.getSuccessor(1)) &&
"successors with arguments in LLVM conditional branches must be "
"different blocks");
return condBranchOp.getSuccessor(0) == current
? condBranchOp.trueDestOperands()[index]
: condBranchOp.falseDestOperands()[index];
}
void ModuleTranslation::connectPHINodes(LLVMFuncOp func) {
// Skip the first block, it cannot be branched to and its arguments correspond
// to the arguments of the LLVM function.
for (auto it = std::next(func.begin()), eit = func.end(); it != eit; ++it) {
Block *bb = &*it;
llvm::BasicBlock *llvmBB = blockMapping.lookup(bb);
auto phis = llvmBB->phis();
auto numArguments = bb->getNumArguments();
assert(numArguments == std::distance(phis.begin(), phis.end()));
for (auto &numberedPhiNode : llvm::enumerate(phis)) {
auto &phiNode = numberedPhiNode.value();
unsigned index = numberedPhiNode.index();
for (auto *pred : bb->getPredecessors()) {
phiNode.addIncoming(valueMapping.lookup(getPHISourceValue(
bb, pred, numArguments, index)),
blockMapping.lookup(pred));
}
}
}
}
// TODO: implement an iterative version
static void topologicalSortImpl(llvm::SetVector<Block *> &blocks, Block *b) {
blocks.insert(b);
for (Block *bb : b->getSuccessors()) {
if (blocks.count(bb) == 0)
topologicalSortImpl(blocks, bb);
}
}
/// Sort function blocks topologically.
static llvm::SetVector<Block *> topologicalSort(LLVMFuncOp f) {
// For each blocks that has not been visited yet (i.e. that has no
// predecessors), add it to the list and traverse its successors in DFS
// preorder.
llvm::SetVector<Block *> blocks;
for (Block &b : f) {
if (blocks.count(&b) == 0)
topologicalSortImpl(blocks, &b);
}
assert(blocks.size() == f.getBlocks().size() && "some blocks are not sorted");
return blocks;
}
/// Attempts to add an attribute identified by `key`, optionally with the given
/// `value` to LLVM function `llvmFunc`. Reports errors at `loc` if any. If the
/// attribute has a kind known to LLVM IR, create the attribute of this kind,
@ -772,7 +861,7 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
// Finally, after all blocks have been traversed and values mapped, connect
// the PHI nodes to the results of preceding blocks.
connectPHINodes(func);
connectPHINodes(func, valueMapping, blockMapping);
return success();
}

View File

@ -32,3 +32,49 @@ llvm.func @test_flush_construct(%arg0: !llvm.i32) {
// CHECK-NEXT: ret void
llvm.return
}
// CHECK-LABEL: define void @test_omp_parallel_1()
llvm.func @test_omp_parallel_1() -> () {
// CHECK: call void{{.*}}@__kmpc_fork_call{{.*}}@[[OMP_OUTLINED_FN_1:.*]] to {{.*}}
omp.parallel {
omp.barrier
omp.terminator
}
llvm.return
}
// CHECK: define internal void @[[OMP_OUTLINED_FN_1]]
// CHECK: call void @__kmpc_barrier
llvm.func @body(!llvm.i64)
// CHECK-LABEL: define void @test_omp_parallel_2()
llvm.func @test_omp_parallel_2() -> () {
// CHECK: call void{{.*}}@__kmpc_fork_call{{.*}}@[[OMP_OUTLINED_FN_2:.*]] to {{.*}}
omp.parallel {
^bb0:
%0 = llvm.mlir.constant(1 : index) : !llvm.i64
%1 = llvm.mlir.constant(42 : index) : !llvm.i64
llvm.call @body(%0) : (!llvm.i64) -> ()
llvm.call @body(%1) : (!llvm.i64) -> ()
llvm.br ^bb1
^bb1:
%2 = llvm.add %0, %1 : !llvm.i64
llvm.call @body(%2) : (!llvm.i64) -> ()
omp.terminator
}
llvm.return
}
// CHECK: define internal void @[[OMP_OUTLINED_FN_2]]
// CHECK-LABEL: omp.par.region:
// CHECK: br label %omp.par.region1
// CHECK-LABEL: omp.par.region1:
// CHECK: call void @body(i64 1)
// CHECK: call void @body(i64 42)
// CHECK: br label %omp.par.region2
// CHECK-LABEL: omp.par.region2:
// CHECK: call void @body(i64 43)
// CHECK: br label %omp.par.pre_finalize