diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index bc0b39cc29d5..54649fd570ef 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -136,6 +136,18 @@ def TerminatorOp : OpenMP_Op<"terminator", [Terminator]> { let assemblyFormat = "attr-dict"; } +def OMP_SCHEDULE_MOD_None : StrEnumAttrCase<"none", 0>; +def OMP_SCHEDULE_MOD_Monotonic : StrEnumAttrCase<"monotonic", 1>; +def OMP_SCHEDULE_MOD_Nonmonotonic : StrEnumAttrCase<"nonmonotonic", 2>; + +def ScheduleModifier : StrEnumAttr<"ScheduleModifier", "OpenMP Schedule Modifier", + [OMP_SCHEDULE_MOD_None, + OMP_SCHEDULE_MOD_Monotonic, + OMP_SCHEDULE_MOD_Nonmonotonic]> +{ + let cppNamespace = "::mlir::omp"; +} + //===----------------------------------------------------------------------===// // 2.9.2 Workshare Loop Construct //===----------------------------------------------------------------------===// @@ -214,6 +226,7 @@ def WsLoopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments, OptionalAttr:$reductions, OptionalAttr:$schedule_val, Optional:$schedule_chunk_var, + OptionalAttr:$schedule_modifier, Confined, [IntMinValue<0>]>:$collapse_val, UnitAttr:$nowait, Confined, [IntMinValue<0>]>:$ordered_val, diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 462335a03dbf..64292b80e4dd 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -253,6 +253,7 @@ static void printLinearClause(OpAsmPrinter &p, OperandRange linearVars, /// sched-wo-chunk ::= `auto` | `runtime` static ParseResult parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule, + SmallVectorImpl> &modifiers, Optional &chunkSize) { if (parser.parseLParen()) return failure(); @@ -276,6 +277,14 @@ parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule, return parser.emitError(parser.getNameLoc()) << " expected schedule kind"; } + // If there is a comma, we have one or more modifiers.. + if (succeeded(parser.parseOptionalComma())) { + StringRef mod; + if (parser.parseKeyword(&mod)) + return failure(); + modifiers.push_back(mod); + } + if (parser.parseRParen()) return failure(); @@ -284,11 +293,14 @@ parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule, /// Print schedule clause static void printScheduleClause(OpAsmPrinter &p, StringRef &sched, + llvm::Optional modifier, Value scheduleChunkVar) { std::string schedLower = sched.lower(); p << "(" << schedLower; if (scheduleChunkVar) p << " = " << scheduleChunkVar; + if (modifier && modifier.getValue() != "none") + p << ", " << modifier; p << ") "; } @@ -551,6 +563,7 @@ static ParseResult parseClauses(OpAsmParser &parser, OperationState &result, SmallVector linearSteps; SmallString<8> schedule; + SmallVector> modifiers; Optional scheduleChunkSize; // Compute the position of clauses in operand segments @@ -670,7 +683,7 @@ static ParseResult parseClauses(OpAsmParser &parser, OperationState &result, clauseSegments[pos[linearClause] + 1] = linearSteps.size(); } else if (clauseKeyword == "schedule") { if (checkAllowed(scheduleClause) || - parseScheduleClause(parser, schedule, scheduleChunkSize)) + parseScheduleClause(parser, schedule, modifiers, scheduleChunkSize)) return failure(); if (scheduleChunkSize) { clauseSegments[pos[scheduleClause]] = 1; @@ -797,6 +810,10 @@ static ParseResult parseClauses(OpAsmParser &parser, OperationState &result, schedule[0] = llvm::toUpper(schedule[0]); auto attr = parser.getBuilder().getStringAttr(schedule); result.addAttribute("schedule_val", attr); + if (modifiers.size() > 0) { + auto mod = parser.getBuilder().getStringAttr(modifiers[0]); + result.addAttribute("schedule_modifier", mod); + } if (scheduleChunkSize) { auto chunkSizeType = parser.getBuilder().getI32Type(); parser.resolveOperand(*scheduleChunkSize, chunkSizeType, result.operands); @@ -916,7 +933,8 @@ static void printWsLoopOp(OpAsmPrinter &p, WsLoopOp op) { if (auto sched = op.schedule_val()) { p << "schedule"; - printScheduleClause(p, sched.getValue(), op.schedule_chunk_var()); + printScheduleClause(p, sched.getValue(), op.schedule_modifier(), + op.schedule_chunk_var()); } if (auto collapse = op.collapse_val()) diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 9c2654317a43..1bbd654ae800 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -707,8 +707,23 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder, break; } - ompBuilder->applyDynamicWorkshareLoop(ompLoc.DL, loopInfo, allocaIP, - schedType, !loop.nowait(), chunk); + if (loop.schedule_modifier().hasValue()) { + omp::ScheduleModifier modifier = + *omp::symbolizeScheduleModifier(loop.schedule_modifier().getValue()); + switch (modifier) { + case omp::ScheduleModifier::monotonic: + schedType |= llvm::omp::OMPScheduleType::ModifierMonotonic; + break; + case omp::ScheduleModifier::nonmonotonic: + schedType |= llvm::omp::OMPScheduleType::ModifierNonmonotonic; + break; + default: + // Nothing to do here. + break; + } + } + afterIP = ompBuilder->applyDynamicWorkshareLoop( + ompLoc.DL, loopInfo, allocaIP, schedType, !loop.nowait(), chunk); } // Continue building IR after the loop. Note that the LoopInfo returned by diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index b80d7a185b1b..5defcc5d9122 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -177,7 +177,7 @@ func @omp_wsloop_pretty(%lb : index, %ub : index, %step : index, } // CHECK: omp.wsloop (%{{.*}}) : index = (%{{.*}}) to (%{{.*}}) step (%{{.*}}) linear(%{{.*}} = %{{.*}} : memref) schedule(static) - omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) schedule(static) lastprivate(%data_var : memref) linear(%data_var = %linear_var : memref) { + omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) schedule(static, none) lastprivate(%data_var : memref) linear(%data_var = %linear_var : memref) { omp.yield } @@ -188,6 +188,20 @@ func @omp_wsloop_pretty(%lb : index, %ub : index, %step : index, omp.yield } + // CHECK: omp.wsloop (%{{.*}}) : index = (%{{.*}}) to (%{{.*}}) step (%{{.*}}) private(%{{.*}} : memref) firstprivate(%{{.*}} : memref) lastprivate(%{{.*}} : memref) linear(%{{.*}} = %{{.*}} : memref) schedule(dynamic = %{{.*}}, nonmonotonic) collapse(3) ordered(2) + omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) ordered(2) private(%data_var : memref) + firstprivate(%data_var : memref) lastprivate(%data_var : memref) linear(%data_var = %linear_var : memref) + schedule(dynamic = %chunk_var, nonmonotonic) collapse(3) { + omp.yield + } + + // CHECK: omp.wsloop (%{{.*}}) : index = (%{{.*}}) to (%{{.*}}) step (%{{.*}}) private(%{{.*}} : memref) firstprivate(%{{.*}} : memref) lastprivate(%{{.*}} : memref) linear(%{{.*}} = %{{.*}} : memref) schedule(dynamic = %{{.*}}, monotonic) collapse(3) ordered(2) + omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) ordered(2) private(%data_var : memref) + firstprivate(%data_var : memref) lastprivate(%data_var : memref) linear(%data_var = %linear_var : memref) + schedule(dynamic = %chunk_var, monotonic) collapse(3) { + omp.yield + } + // CHECK: omp.wsloop (%{{.*}}) : index = (%{{.*}}) to (%{{.*}}) step (%{{.*}}) private({{.*}} : memref) omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) private(%data_var : memref) { omp.yield diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir index 273a432c2104..f1025857de1b 100644 --- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir +++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir @@ -467,6 +467,30 @@ llvm.func @test_omp_wsloop_guided(%lb : i64, %ub : i64, %step : i64) -> () { llvm.return } +llvm.func @test_omp_wsloop_dynamic_nonmonotonic(%lb : i64, %ub : i64, %step : i64) -> () { + omp.wsloop (%iv) : i64 = (%lb) to (%ub) step (%step) schedule(dynamic, nonmonotonic) { + // CHECK: call void @__kmpc_dispatch_init_8u(%struct.ident_t* @{{.*}}, i32 %{{.*}}, i32 1073741859 + // CHECK: %[[continue:.*]] = call i32 @__kmpc_dispatch_next_8u + // CHECK: %[[cond:.*]] = icmp ne i32 %[[continue]], 0 + // CHECK br i1 %[[cond]], label %omp_loop.header{{.*}}, label %omp_loop.exit{{.*}} + llvm.call @body(%iv) : (i64) -> () + omp.yield + } + llvm.return +} + +llvm.func @test_omp_wsloop_dynamic_monotonic(%lb : i64, %ub : i64, %step : i64) -> () { + omp.wsloop (%iv) : i64 = (%lb) to (%ub) step (%step) schedule(dynamic, monotonic) { + // CHECK: call void @__kmpc_dispatch_init_8u(%struct.ident_t* @{{.*}}, i32 %{{.*}}, i32 536870947 + // CHECK: %[[continue:.*]] = call i32 @__kmpc_dispatch_next_8u + // CHECK: %[[cond:.*]] = icmp ne i32 %[[continue]], 0 + // CHECK br i1 %[[cond]], label %omp_loop.header{{.*}}, label %omp_loop.exit{{.*}} + llvm.call @body(%iv) : (i64) -> () + omp.yield + } + llvm.return +} + // ----- omp.critical.declare @mutex hint(contended)