[flang[OpenACC] Lower wait directive
This patch adds lowering for the `!$acc wait` directive from the PFT to OpenACC dialect. This patch is part of the upstreaming effort from fir-dev branch. Reviewed By: PeteSteinfeld Differential Revision: https://reviews.llvm.org/D122399
This commit is contained in:
parent
67eb2f144e
commit
44b0ea44f2
|
@ -898,16 +898,16 @@ static void genACC(Fortran::lower::AbstractConverter &converter,
|
||||||
const auto &accClauseList =
|
const auto &accClauseList =
|
||||||
std::get<Fortran::parser::AccClauseList>(waitConstruct.t);
|
std::get<Fortran::parser::AccClauseList>(waitConstruct.t);
|
||||||
|
|
||||||
mlir::Value ifCond, waitDevnum, async;
|
mlir::Value ifCond, asyncOperand, waitDevnum, async;
|
||||||
SmallVector<mlir::Value, 2> waitOperands;
|
SmallVector<mlir::Value> waitOperands;
|
||||||
|
|
||||||
// Async clause have optional values but can be present with
|
// Async clause have optional values but can be present with
|
||||||
// no value as well. When there is no value, the op has an attribute to
|
// no value as well. When there is no value, the op has an attribute to
|
||||||
// represent the clause.
|
// represent the clause.
|
||||||
bool addAsyncAttr = false;
|
bool addAsyncAttr = false;
|
||||||
|
|
||||||
auto &firOpBuilder = converter.getFirOpBuilder();
|
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
|
||||||
auto currentLocation = converter.getCurrentLocation();
|
mlir::Location currentLocation = converter.getCurrentLocation();
|
||||||
Fortran::lower::StatementContext stmtCtx;
|
Fortran::lower::StatementContext stmtCtx;
|
||||||
|
|
||||||
if (waitArgument) { // wait has a value.
|
if (waitArgument) { // wait has a value.
|
||||||
|
@ -930,35 +930,26 @@ static void genACC(Fortran::lower::AbstractConverter &converter,
|
||||||
// Lower clauses values mapped to operands.
|
// Lower clauses values mapped to operands.
|
||||||
// Keep track of each group of operands separatly as clauses can appear
|
// Keep track of each group of operands separatly as clauses can appear
|
||||||
// more than once.
|
// more than once.
|
||||||
for (const auto &clause : accClauseList.v) {
|
for (const Fortran::parser::AccClause &clause : accClauseList.v) {
|
||||||
if (const auto *ifClause =
|
if (const auto *ifClause =
|
||||||
std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
|
std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
|
||||||
mlir::Value cond = fir::getBase(converter.genExprValue(
|
genIfClause(converter, ifClause, ifCond, stmtCtx);
|
||||||
*Fortran::semantics::GetExpr(ifClause->v), stmtCtx));
|
|
||||||
ifCond = firOpBuilder.createConvert(currentLocation,
|
|
||||||
firOpBuilder.getI1Type(), cond);
|
|
||||||
} else if (const auto *asyncClause =
|
} else if (const auto *asyncClause =
|
||||||
std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
|
std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
|
||||||
const auto &asyncClauseValue = asyncClause->v;
|
genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx);
|
||||||
if (asyncClauseValue) { // async has a value.
|
|
||||||
async = fir::getBase(converter.genExprValue(
|
|
||||||
*Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx));
|
|
||||||
} else {
|
|
||||||
addAsyncAttr = true;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prepare the operand segement size attribute and the operands value range.
|
// Prepare the operand segement size attribute and the operands value range.
|
||||||
SmallVector<mlir::Value, 8> operands;
|
SmallVector<mlir::Value> operands;
|
||||||
SmallVector<int32_t, 4> operandSegments;
|
SmallVector<int32_t> operandSegments;
|
||||||
addOperands(operands, operandSegments, waitOperands);
|
addOperands(operands, operandSegments, waitOperands);
|
||||||
addOperand(operands, operandSegments, async);
|
addOperand(operands, operandSegments, async);
|
||||||
addOperand(operands, operandSegments, waitDevnum);
|
addOperand(operands, operandSegments, waitDevnum);
|
||||||
addOperand(operands, operandSegments, ifCond);
|
addOperand(operands, operandSegments, ifCond);
|
||||||
|
|
||||||
auto waitOp = createSimpleOp<mlir::acc::WaitOp>(firOpBuilder, currentLocation,
|
mlir::acc::WaitOp waitOp = createSimpleOp<mlir::acc::WaitOp>(
|
||||||
operands, operandSegments);
|
firOpBuilder, currentLocation, operands, operandSegments);
|
||||||
|
|
||||||
if (addAsyncAttr)
|
if (addAsyncAttr)
|
||||||
waitOp.asyncAttr(firOpBuilder.getUnitAttr());
|
waitOp.asyncAttr(firOpBuilder.getUnitAttr());
|
||||||
|
|
|
@ -0,0 +1,41 @@
|
||||||
|
! This test checks lowering of OpenACC wait directive.
|
||||||
|
|
||||||
|
! RUN: bbc -fopenacc -emit-fir %s -o - | FileCheck %s
|
||||||
|
|
||||||
|
subroutine acc_update
|
||||||
|
integer :: async = 1
|
||||||
|
logical :: ifCondition = .TRUE.
|
||||||
|
|
||||||
|
!$acc wait
|
||||||
|
!CHECK: acc.wait{{$}}
|
||||||
|
|
||||||
|
!$acc wait if(.true.)
|
||||||
|
!CHECK: [[IF1:%.*]] = arith.constant true
|
||||||
|
!CHECK: acc.wait if([[IF1]]){{$}}
|
||||||
|
|
||||||
|
!$acc wait if(ifCondition)
|
||||||
|
!CHECK: [[IFCOND:%.*]] = fir.load %{{.*}} : !fir.ref<!fir.logical<4>>
|
||||||
|
!CHECK: [[IF2:%.*]] = fir.convert [[IFCOND]] : (!fir.logical<4>) -> i1
|
||||||
|
!CHECK: acc.wait if([[IF2]]){{$}}
|
||||||
|
|
||||||
|
!$acc wait(1, 2)
|
||||||
|
!CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32
|
||||||
|
!CHECK: [[WAIT2:%.*]] = arith.constant 2 : i32
|
||||||
|
!CHECK: acc.wait([[WAIT1]], [[WAIT2]] : i32, i32){{$}}
|
||||||
|
|
||||||
|
!$acc wait(1) async
|
||||||
|
!CHECK: [[WAIT3:%.*]] = arith.constant 1 : i32
|
||||||
|
!CHECK: acc.wait([[WAIT3]] : i32) attributes {async}
|
||||||
|
|
||||||
|
!$acc wait(1) async(async)
|
||||||
|
!CHECK: [[WAIT3:%.*]] = arith.constant 1 : i32
|
||||||
|
!CHECK: [[ASYNC1:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
|
||||||
|
!CHECK: acc.wait([[WAIT3]] : i32) async([[ASYNC1]] : i32){{$}}
|
||||||
|
|
||||||
|
!$acc wait(devnum: 3: queues: 1, 2)
|
||||||
|
!CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32
|
||||||
|
!CHECK: [[WAIT2:%.*]] = arith.constant 2 : i32
|
||||||
|
!CHECK: [[DEVNUM:%.*]] = arith.constant 3 : i32
|
||||||
|
!CHECK: acc.wait([[WAIT1]], [[WAIT2]] : i32, i32) wait_devnum([[DEVNUM]] : i32){{$}}
|
||||||
|
|
||||||
|
end subroutine acc_update
|
Loading…
Reference in New Issue