[flang[OpenACC] Lower wait directive

This patch adds lowering for the `!$acc wait` directive
from the PFT to OpenACC dialect.

This patch is part of the upstreaming effort from fir-dev branch.

Reviewed By: PeteSteinfeld

Differential Revision: https://reviews.llvm.org/D122399
This commit is contained in:
Valentin Clement 2022-03-24 17:15:00 +01:00
parent 67eb2f144e
commit 44b0ea44f2
No known key found for this signature in database
GPG Key ID: 086D54783C928776
2 changed files with 52 additions and 20 deletions

View File

@ -898,16 +898,16 @@ static void genACC(Fortran::lower::AbstractConverter &converter,
const auto &accClauseList =
std::get<Fortran::parser::AccClauseList>(waitConstruct.t);
mlir::Value ifCond, waitDevnum, async;
SmallVector<mlir::Value, 2> waitOperands;
mlir::Value ifCond, asyncOperand, waitDevnum, async;
SmallVector<mlir::Value> waitOperands;
// Async clause have optional values but can be present with
// no value as well. When there is no value, the op has an attribute to
// represent the clause.
bool addAsyncAttr = false;
auto &firOpBuilder = converter.getFirOpBuilder();
auto currentLocation = converter.getCurrentLocation();
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
mlir::Location currentLocation = converter.getCurrentLocation();
Fortran::lower::StatementContext stmtCtx;
if (waitArgument) { // wait has a value.
@ -930,35 +930,26 @@ static void genACC(Fortran::lower::AbstractConverter &converter,
// Lower clauses values mapped to operands.
// Keep track of each group of operands separatly as clauses can appear
// more than once.
for (const auto &clause : accClauseList.v) {
for (const Fortran::parser::AccClause &clause : accClauseList.v) {
if (const auto *ifClause =
std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
mlir::Value cond = fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(ifClause->v), stmtCtx));
ifCond = firOpBuilder.createConvert(currentLocation,
firOpBuilder.getI1Type(), cond);
genIfClause(converter, ifClause, ifCond, stmtCtx);
} else if (const auto *asyncClause =
std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
const auto &asyncClauseValue = asyncClause->v;
if (asyncClauseValue) { // async has a value.
async = fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx));
} else {
addAsyncAttr = true;
}
genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx);
}
}
// Prepare the operand segement size attribute and the operands value range.
SmallVector<mlir::Value, 8> operands;
SmallVector<int32_t, 4> operandSegments;
SmallVector<mlir::Value> operands;
SmallVector<int32_t> operandSegments;
addOperands(operands, operandSegments, waitOperands);
addOperand(operands, operandSegments, async);
addOperand(operands, operandSegments, waitDevnum);
addOperand(operands, operandSegments, ifCond);
auto waitOp = createSimpleOp<mlir::acc::WaitOp>(firOpBuilder, currentLocation,
operands, operandSegments);
mlir::acc::WaitOp waitOp = createSimpleOp<mlir::acc::WaitOp>(
firOpBuilder, currentLocation, operands, operandSegments);
if (addAsyncAttr)
waitOp.asyncAttr(firOpBuilder.getUnitAttr());

View File

@ -0,0 +1,41 @@
! This test checks lowering of OpenACC wait directive.
! RUN: bbc -fopenacc -emit-fir %s -o - | FileCheck %s
subroutine acc_update
integer :: async = 1
logical :: ifCondition = .TRUE.
!$acc wait
!CHECK: acc.wait{{$}}
!$acc wait if(.true.)
!CHECK: [[IF1:%.*]] = arith.constant true
!CHECK: acc.wait if([[IF1]]){{$}}
!$acc wait if(ifCondition)
!CHECK: [[IFCOND:%.*]] = fir.load %{{.*}} : !fir.ref<!fir.logical<4>>
!CHECK: [[IF2:%.*]] = fir.convert [[IFCOND]] : (!fir.logical<4>) -> i1
!CHECK: acc.wait if([[IF2]]){{$}}
!$acc wait(1, 2)
!CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32
!CHECK: [[WAIT2:%.*]] = arith.constant 2 : i32
!CHECK: acc.wait([[WAIT1]], [[WAIT2]] : i32, i32){{$}}
!$acc wait(1) async
!CHECK: [[WAIT3:%.*]] = arith.constant 1 : i32
!CHECK: acc.wait([[WAIT3]] : i32) attributes {async}
!$acc wait(1) async(async)
!CHECK: [[WAIT3:%.*]] = arith.constant 1 : i32
!CHECK: [[ASYNC1:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
!CHECK: acc.wait([[WAIT3]] : i32) async([[ASYNC1]] : i32){{$}}
!$acc wait(devnum: 3: queues: 1, 2)
!CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32
!CHECK: [[WAIT2:%.*]] = arith.constant 2 : i32
!CHECK: [[DEVNUM:%.*]] = arith.constant 3 : i32
!CHECK: acc.wait([[WAIT1]], [[WAIT2]] : i32, i32) wait_devnum([[DEVNUM]] : i32){{$}}
end subroutine acc_update