xref: /llvm-project/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (revision afcbcae668f1d8061974247f2828190173aef742)
1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the OpenMP dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
14 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
15 #include "mlir/Dialect/Func/IR/FuncOps.h"
16 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
17 #include "mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.h"
18 #include "mlir/IR/Attributes.h"
19 #include "mlir/IR/BuiltinAttributes.h"
20 #include "mlir/IR/DialectImplementation.h"
21 #include "mlir/IR/OpImplementation.h"
22 #include "mlir/IR/OperationSupport.h"
23 #include "mlir/Interfaces/FoldInterfaces.h"
24 
25 #include "llvm/ADT/ArrayRef.h"
26 #include "llvm/ADT/BitVector.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/ADT/STLForwardCompat.h"
29 #include "llvm/ADT/SmallString.h"
30 #include "llvm/ADT/StringExtras.h"
31 #include "llvm/ADT/StringRef.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 #include "llvm/Frontend/OpenMP/OMPConstants.h"
34 #include "llvm/Frontend/OpenMP/OMPDeviceConstants.h"
35 #include <cstddef>
36 #include <iterator>
37 #include <optional>
38 #include <variant>
39 
40 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
41 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
42 #include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
43 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
44 
45 using namespace mlir;
46 using namespace mlir::omp;
47 
48 static ArrayAttr makeArrayAttr(MLIRContext *context,
49                                llvm::ArrayRef<Attribute> attrs) {
50   return attrs.empty() ? nullptr : ArrayAttr::get(context, attrs);
51 }
52 
53 static DenseBoolArrayAttr
54 makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef<bool> boolArray) {
55   return boolArray.empty() ? nullptr : DenseBoolArrayAttr::get(ctx, boolArray);
56 }
57 
58 namespace {
59 struct MemRefPointerLikeModel
60     : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
61                                             MemRefType> {
62   Type getElementType(Type pointer) const {
63     return llvm::cast<MemRefType>(pointer).getElementType();
64   }
65 };
66 
67 struct LLVMPointerPointerLikeModel
68     : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
69                                             LLVM::LLVMPointerType> {
70   Type getElementType(Type pointer) const { return Type(); }
71 };
72 } // namespace
73 
74 void OpenMPDialect::initialize() {
75   addOperations<
76 #define GET_OP_LIST
77 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
78       >();
79   addAttributes<
80 #define GET_ATTRDEF_LIST
81 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
82       >();
83   addTypes<
84 #define GET_TYPEDEF_LIST
85 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
86       >();
87 
88   declarePromisedInterface<ConvertToLLVMPatternInterface, OpenMPDialect>();
89 
90   MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext());
91   LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
92       *getContext());
93 
94   // Attach default offload module interface to module op to access
95   // offload functionality through
96   mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>(
97       *getContext());
98 
99   // Attach default declare target interfaces to operations which can be marked
100   // as declare target (Global Operations and Functions/Subroutines in dialects
101   // that Fortran (or other languages that lower to MLIR) translates too
102   mlir::LLVM::GlobalOp::attachInterface<
103       mlir::omp::DeclareTargetDefaultModel<mlir::LLVM::GlobalOp>>(
104       *getContext());
105   mlir::LLVM::LLVMFuncOp::attachInterface<
106       mlir::omp::DeclareTargetDefaultModel<mlir::LLVM::LLVMFuncOp>>(
107       *getContext());
108   mlir::func::FuncOp::attachInterface<
109       mlir::omp::DeclareTargetDefaultModel<mlir::func::FuncOp>>(*getContext());
110 }
111 
112 //===----------------------------------------------------------------------===//
113 // Parser and printer for Allocate Clause
114 //===----------------------------------------------------------------------===//
115 
116 /// Parse an allocate clause with allocators and a list of operands with types.
117 ///
118 /// allocate-operand-list :: = allocate-operand |
119 ///                            allocator-operand `,` allocate-operand-list
120 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type
121 /// ssa-id-and-type ::= ssa-id `:` type
122 static ParseResult parseAllocateAndAllocator(
123     OpAsmParser &parser,
124     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &allocateVars,
125     SmallVectorImpl<Type> &allocateTypes,
126     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &allocatorVars,
127     SmallVectorImpl<Type> &allocatorTypes) {
128 
129   return parser.parseCommaSeparatedList([&]() {
130     OpAsmParser::UnresolvedOperand operand;
131     Type type;
132     if (parser.parseOperand(operand) || parser.parseColonType(type))
133       return failure();
134     allocatorVars.push_back(operand);
135     allocatorTypes.push_back(type);
136     if (parser.parseArrow())
137       return failure();
138     if (parser.parseOperand(operand) || parser.parseColonType(type))
139       return failure();
140 
141     allocateVars.push_back(operand);
142     allocateTypes.push_back(type);
143     return success();
144   });
145 }
146 
147 /// Print allocate clause
148 static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op,
149                                       OperandRange allocateVars,
150                                       TypeRange allocateTypes,
151                                       OperandRange allocatorVars,
152                                       TypeRange allocatorTypes) {
153   for (unsigned i = 0; i < allocateVars.size(); ++i) {
154     std::string separator = i == allocateVars.size() - 1 ? "" : ", ";
155     p << allocatorVars[i] << " : " << allocatorTypes[i] << " -> ";
156     p << allocateVars[i] << " : " << allocateTypes[i] << separator;
157   }
158 }
159 
160 //===----------------------------------------------------------------------===//
161 // Parser and printer for a clause attribute (StringEnumAttr)
162 //===----------------------------------------------------------------------===//
163 
164 template <typename ClauseAttr>
165 static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr) {
166   using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
167   StringRef enumStr;
168   SMLoc loc = parser.getCurrentLocation();
169   if (parser.parseKeyword(&enumStr))
170     return failure();
171   if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
172     attr = ClauseAttr::get(parser.getContext(), *enumValue);
173     return success();
174   }
175   return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
176 }
177 
178 template <typename ClauseAttr>
179 void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) {
180   p << stringifyEnum(attr.getValue());
181 }
182 
183 //===----------------------------------------------------------------------===//
184 // Parser and printer for Linear Clause
185 //===----------------------------------------------------------------------===//
186 
187 /// linear ::= `linear` `(` linear-list `)`
188 /// linear-list := linear-val | linear-val linear-list
189 /// linear-val := ssa-id-and-type `=` ssa-id-and-type
190 static ParseResult parseLinearClause(
191     OpAsmParser &parser,
192     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &linearVars,
193     SmallVectorImpl<Type> &linearTypes,
194     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &linearStepVars) {
195   return parser.parseCommaSeparatedList([&]() {
196     OpAsmParser::UnresolvedOperand var;
197     Type type;
198     OpAsmParser::UnresolvedOperand stepVar;
199     if (parser.parseOperand(var) || parser.parseEqual() ||
200         parser.parseOperand(stepVar) || parser.parseColonType(type))
201       return failure();
202 
203     linearVars.push_back(var);
204     linearTypes.push_back(type);
205     linearStepVars.push_back(stepVar);
206     return success();
207   });
208 }
209 
210 /// Print Linear Clause
211 static void printLinearClause(OpAsmPrinter &p, Operation *op,
212                               ValueRange linearVars, TypeRange linearTypes,
213                               ValueRange linearStepVars) {
214   size_t linearVarsSize = linearVars.size();
215   for (unsigned i = 0; i < linearVarsSize; ++i) {
216     std::string separator = i == linearVarsSize - 1 ? "" : ", ";
217     p << linearVars[i];
218     if (linearStepVars.size() > i)
219       p << " = " << linearStepVars[i];
220     p << " : " << linearVars[i].getType() << separator;
221   }
222 }
223 
224 //===----------------------------------------------------------------------===//
225 // Verifier for Nontemporal Clause
226 //===----------------------------------------------------------------------===//
227 
228 static LogicalResult verifyNontemporalClause(Operation *op,
229                                              OperandRange nontemporalVars) {
230 
231   // Check if each var is unique - OpenMP 5.0 -> 2.9.3.1 section
232   DenseSet<Value> nontemporalItems;
233   for (const auto &it : nontemporalVars)
234     if (!nontemporalItems.insert(it).second)
235       return op->emitOpError() << "nontemporal variable used more than once";
236 
237   return success();
238 }
239 
240 //===----------------------------------------------------------------------===//
241 // Parser, verifier and printer for Aligned Clause
242 //===----------------------------------------------------------------------===//
243 static LogicalResult verifyAlignedClause(Operation *op,
244                                          std::optional<ArrayAttr> alignments,
245                                          OperandRange alignedVars) {
246   // Check if number of alignment values equals to number of aligned variables
247   if (!alignedVars.empty()) {
248     if (!alignments || alignments->size() != alignedVars.size())
249       return op->emitOpError()
250              << "expected as many alignment values as aligned variables";
251   } else {
252     if (alignments)
253       return op->emitOpError() << "unexpected alignment values attribute";
254     return success();
255   }
256 
257   // Check if each var is aligned only once - OpenMP 4.5 -> 2.8.1 section
258   DenseSet<Value> alignedItems;
259   for (auto it : alignedVars)
260     if (!alignedItems.insert(it).second)
261       return op->emitOpError() << "aligned variable used more than once";
262 
263   if (!alignments)
264     return success();
265 
266   // Check if all alignment values are positive - OpenMP 4.5 -> 2.8.1 section
267   for (unsigned i = 0; i < (*alignments).size(); ++i) {
268     if (auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignments)[i])) {
269       if (intAttr.getValue().sle(0))
270         return op->emitOpError() << "alignment should be greater than 0";
271     } else {
272       return op->emitOpError() << "expected integer alignment";
273     }
274   }
275 
276   return success();
277 }
278 
279 /// aligned ::= `aligned` `(` aligned-list `)`
280 /// aligned-list := aligned-val | aligned-val aligned-list
281 /// aligned-val := ssa-id-and-type `->` alignment
282 static ParseResult
283 parseAlignedClause(OpAsmParser &parser,
284                    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &alignedVars,
285                    SmallVectorImpl<Type> &alignedTypes,
286                    ArrayAttr &alignmentsAttr) {
287   SmallVector<Attribute> alignmentVec;
288   if (failed(parser.parseCommaSeparatedList([&]() {
289         if (parser.parseOperand(alignedVars.emplace_back()) ||
290             parser.parseColonType(alignedTypes.emplace_back()) ||
291             parser.parseArrow() ||
292             parser.parseAttribute(alignmentVec.emplace_back())) {
293           return failure();
294         }
295         return success();
296       })))
297     return failure();
298   SmallVector<Attribute> alignments(alignmentVec.begin(), alignmentVec.end());
299   alignmentsAttr = ArrayAttr::get(parser.getContext(), alignments);
300   return success();
301 }
302 
303 /// Print Aligned Clause
304 static void printAlignedClause(OpAsmPrinter &p, Operation *op,
305                                ValueRange alignedVars, TypeRange alignedTypes,
306                                std::optional<ArrayAttr> alignments) {
307   for (unsigned i = 0; i < alignedVars.size(); ++i) {
308     if (i != 0)
309       p << ", ";
310     p << alignedVars[i] << " : " << alignedVars[i].getType();
311     p << " -> " << (*alignments)[i];
312   }
313 }
314 
315 //===----------------------------------------------------------------------===//
316 // Parser, printer and verifier for Schedule Clause
317 //===----------------------------------------------------------------------===//
318 
319 static ParseResult
320 verifyScheduleModifiers(OpAsmParser &parser,
321                         SmallVectorImpl<SmallString<12>> &modifiers) {
322   if (modifiers.size() > 2)
323     return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)";
324   for (const auto &mod : modifiers) {
325     // Translate the string. If it has no value, then it was not a valid
326     // modifier!
327     auto symbol = symbolizeScheduleModifier(mod);
328     if (!symbol)
329       return parser.emitError(parser.getNameLoc())
330              << " unknown modifier type: " << mod;
331   }
332 
333   // If we have one modifier that is "simd", then stick a "none" modiifer in
334   // index 0.
335   if (modifiers.size() == 1) {
336     if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
337       modifiers.push_back(modifiers[0]);
338       modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
339     }
340   } else if (modifiers.size() == 2) {
341     // If there are two modifier:
342     // First modifier should not be simd, second one should be simd
343     if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
344         symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
345       return parser.emitError(parser.getNameLoc())
346              << " incorrect modifier order";
347   }
348   return success();
349 }
350 
351 /// schedule ::= `schedule` `(` sched-list `)`
352 /// sched-list ::= sched-val | sched-val sched-list |
353 ///                sched-val `,` sched-modifier
354 /// sched-val ::= sched-with-chunk | sched-wo-chunk
355 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
356 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
357 /// sched-wo-chunk ::=  `auto` | `runtime`
358 /// sched-modifier ::=  sched-mod-val | sched-mod-val `,` sched-mod-val
359 /// sched-mod-val ::=  `monotonic` | `nonmonotonic` | `simd` | `none`
360 static ParseResult
361 parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr,
362                     ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd,
363                     std::optional<OpAsmParser::UnresolvedOperand> &chunkSize,
364                     Type &chunkType) {
365   StringRef keyword;
366   if (parser.parseKeyword(&keyword))
367     return failure();
368   std::optional<mlir::omp::ClauseScheduleKind> schedule =
369       symbolizeClauseScheduleKind(keyword);
370   if (!schedule)
371     return parser.emitError(parser.getNameLoc()) << " expected schedule kind";
372 
373   scheduleAttr = ClauseScheduleKindAttr::get(parser.getContext(), *schedule);
374   switch (*schedule) {
375   case ClauseScheduleKind::Static:
376   case ClauseScheduleKind::Dynamic:
377   case ClauseScheduleKind::Guided:
378     if (succeeded(parser.parseOptionalEqual())) {
379       chunkSize = OpAsmParser::UnresolvedOperand{};
380       if (parser.parseOperand(*chunkSize) || parser.parseColonType(chunkType))
381         return failure();
382     } else {
383       chunkSize = std::nullopt;
384     }
385     break;
386   case ClauseScheduleKind::Auto:
387   case ClauseScheduleKind::Runtime:
388     chunkSize = std::nullopt;
389   }
390 
391   // If there is a comma, we have one or more modifiers..
392   SmallVector<SmallString<12>> modifiers;
393   while (succeeded(parser.parseOptionalComma())) {
394     StringRef mod;
395     if (parser.parseKeyword(&mod))
396       return failure();
397     modifiers.push_back(mod);
398   }
399 
400   if (verifyScheduleModifiers(parser, modifiers))
401     return failure();
402 
403   if (!modifiers.empty()) {
404     SMLoc loc = parser.getCurrentLocation();
405     if (std::optional<ScheduleModifier> mod =
406             symbolizeScheduleModifier(modifiers[0])) {
407       scheduleMod = ScheduleModifierAttr::get(parser.getContext(), *mod);
408     } else {
409       return parser.emitError(loc, "invalid schedule modifier");
410     }
411     // Only SIMD attribute is allowed here!
412     if (modifiers.size() > 1) {
413       assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
414       scheduleSimd = UnitAttr::get(parser.getBuilder().getContext());
415     }
416   }
417 
418   return success();
419 }
420 
421 /// Print schedule clause
422 static void printScheduleClause(OpAsmPrinter &p, Operation *op,
423                                 ClauseScheduleKindAttr scheduleKind,
424                                 ScheduleModifierAttr scheduleMod,
425                                 UnitAttr scheduleSimd, Value scheduleChunk,
426                                 Type scheduleChunkType) {
427   p << stringifyClauseScheduleKind(scheduleKind.getValue());
428   if (scheduleChunk)
429     p << " = " << scheduleChunk << " : " << scheduleChunk.getType();
430   if (scheduleMod)
431     p << ", " << stringifyScheduleModifier(scheduleMod.getValue());
432   if (scheduleSimd)
433     p << ", simd";
434 }
435 
436 //===----------------------------------------------------------------------===//
437 // Parser and printer for Order Clause
438 //===----------------------------------------------------------------------===//
439 
440 // order ::= `order` `(` [order-modifier ':'] concurrent `)`
441 // order-modifier ::= reproducible | unconstrained
442 static ParseResult parseOrderClause(OpAsmParser &parser,
443                                     ClauseOrderKindAttr &order,
444                                     OrderModifierAttr &orderMod) {
445   StringRef enumStr;
446   SMLoc loc = parser.getCurrentLocation();
447   if (parser.parseKeyword(&enumStr))
448     return failure();
449   if (std::optional<OrderModifier> enumValue =
450           symbolizeOrderModifier(enumStr)) {
451     orderMod = OrderModifierAttr::get(parser.getContext(), *enumValue);
452     if (parser.parseOptionalColon())
453       return failure();
454     loc = parser.getCurrentLocation();
455     if (parser.parseKeyword(&enumStr))
456       return failure();
457   }
458   if (std::optional<ClauseOrderKind> enumValue =
459           symbolizeClauseOrderKind(enumStr)) {
460     order = ClauseOrderKindAttr::get(parser.getContext(), *enumValue);
461     return success();
462   }
463   return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
464 }
465 
466 static void printOrderClause(OpAsmPrinter &p, Operation *op,
467                              ClauseOrderKindAttr order,
468                              OrderModifierAttr orderMod) {
469   if (orderMod)
470     p << stringifyOrderModifier(orderMod.getValue()) << ":";
471   if (order)
472     p << stringifyClauseOrderKind(order.getValue());
473 }
474 
475 //===----------------------------------------------------------------------===//
476 // Parsers for operations including clauses that define entry block arguments.
477 //===----------------------------------------------------------------------===//
478 
479 namespace {
480 struct MapParseArgs {
481   SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
482   SmallVectorImpl<Type> &types;
483   MapParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
484                SmallVectorImpl<Type> &types)
485       : vars(vars), types(types) {}
486 };
487 struct PrivateParseArgs {
488   llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
489   llvm::SmallVectorImpl<Type> &types;
490   ArrayAttr &syms;
491   DenseI64ArrayAttr *mapIndices;
492   PrivateParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
493                    SmallVectorImpl<Type> &types, ArrayAttr &syms,
494                    DenseI64ArrayAttr *mapIndices = nullptr)
495       : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
496 };
497 
498 struct ReductionParseArgs {
499   SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
500   SmallVectorImpl<Type> &types;
501   DenseBoolArrayAttr &byref;
502   ArrayAttr &syms;
503   ReductionModifierAttr *modifier;
504   ReductionParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
505                      SmallVectorImpl<Type> &types, DenseBoolArrayAttr &byref,
506                      ArrayAttr &syms, ReductionModifierAttr *mod = nullptr)
507       : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
508 };
509 
510 struct AllRegionParseArgs {
511   std::optional<MapParseArgs> hostEvalArgs;
512   std::optional<ReductionParseArgs> inReductionArgs;
513   std::optional<MapParseArgs> mapArgs;
514   std::optional<PrivateParseArgs> privateArgs;
515   std::optional<ReductionParseArgs> reductionArgs;
516   std::optional<ReductionParseArgs> taskReductionArgs;
517   std::optional<MapParseArgs> useDeviceAddrArgs;
518   std::optional<MapParseArgs> useDevicePtrArgs;
519 };
520 } // namespace
521 
522 static ParseResult parseClauseWithRegionArgs(
523     OpAsmParser &parser,
524     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
525     SmallVectorImpl<Type> &types,
526     SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs,
527     ArrayAttr *symbols = nullptr, DenseI64ArrayAttr *mapIndices = nullptr,
528     DenseBoolArrayAttr *byref = nullptr,
529     ReductionModifierAttr *modifier = nullptr) {
530   SmallVector<SymbolRefAttr> symbolVec;
531   SmallVector<int64_t> mapIndicesVec;
532   SmallVector<bool> isByRefVec;
533   unsigned regionArgOffset = regionPrivateArgs.size();
534 
535   if (parser.parseLParen())
536     return failure();
537 
538   if (modifier && succeeded(parser.parseOptionalKeyword("mod"))) {
539     StringRef enumStr;
540     if (parser.parseColon() || parser.parseKeyword(&enumStr) ||
541         parser.parseComma())
542       return failure();
543     std::optional<ReductionModifier> enumValue =
544         symbolizeReductionModifier(enumStr);
545     if (!enumValue.has_value())
546       return failure();
547     *modifier = ReductionModifierAttr::get(parser.getContext(), *enumValue);
548     if (!*modifier)
549       return failure();
550   }
551 
552   if (parser.parseCommaSeparatedList([&]() {
553         if (byref)
554           isByRefVec.push_back(
555               parser.parseOptionalKeyword("byref").succeeded());
556 
557         if (symbols && parser.parseAttribute(symbolVec.emplace_back()))
558           return failure();
559 
560         if (parser.parseOperand(operands.emplace_back()) ||
561             parser.parseArrow() ||
562             parser.parseArgument(regionPrivateArgs.emplace_back()))
563           return failure();
564 
565         if (mapIndices) {
566           if (parser.parseOptionalLSquare().succeeded()) {
567             if (parser.parseKeyword("map_idx") || parser.parseEqual() ||
568                 parser.parseInteger(mapIndicesVec.emplace_back()) ||
569                 parser.parseRSquare())
570               return failure();
571           } else
572             mapIndicesVec.push_back(-1);
573         }
574 
575         return success();
576       }))
577     return failure();
578 
579   if (parser.parseColon())
580     return failure();
581 
582   if (parser.parseCommaSeparatedList([&]() {
583         if (parser.parseType(types.emplace_back()))
584           return failure();
585 
586         return success();
587       }))
588     return failure();
589 
590   if (operands.size() != types.size())
591     return failure();
592 
593   if (parser.parseRParen())
594     return failure();
595 
596   auto *argsBegin = regionPrivateArgs.begin();
597   MutableArrayRef argsSubrange(argsBegin + regionArgOffset,
598                                argsBegin + regionArgOffset + types.size());
599   for (auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
600     prv.type = type;
601   }
602 
603   if (symbols) {
604     SmallVector<Attribute> symbolAttrs(symbolVec.begin(), symbolVec.end());
605     *symbols = ArrayAttr::get(parser.getContext(), symbolAttrs);
606   }
607 
608   if (!mapIndicesVec.empty())
609     *mapIndices =
610         mlir::DenseI64ArrayAttr::get(parser.getContext(), mapIndicesVec);
611 
612   if (byref)
613     *byref = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);
614 
615   return success();
616 }
617 
618 static ParseResult parseBlockArgClause(
619     OpAsmParser &parser,
620     llvm::SmallVectorImpl<OpAsmParser::Argument> &entryBlockArgs,
621     StringRef keyword, std::optional<MapParseArgs> mapArgs) {
622   if (succeeded(parser.parseOptionalKeyword(keyword))) {
623     if (!mapArgs)
624       return failure();
625 
626     if (failed(parseClauseWithRegionArgs(parser, mapArgs->vars, mapArgs->types,
627                                          entryBlockArgs)))
628       return failure();
629   }
630   return success();
631 }
632 
633 static ParseResult parseBlockArgClause(
634     OpAsmParser &parser,
635     llvm::SmallVectorImpl<OpAsmParser::Argument> &entryBlockArgs,
636     StringRef keyword, std::optional<PrivateParseArgs> privateArgs) {
637   if (succeeded(parser.parseOptionalKeyword(keyword))) {
638     if (!privateArgs)
639       return failure();
640 
641     if (failed(parseClauseWithRegionArgs(
642             parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
643             &privateArgs->syms, privateArgs->mapIndices)))
644       return failure();
645   }
646   return success();
647 }
648 
649 static ParseResult parseBlockArgClause(
650     OpAsmParser &parser,
651     llvm::SmallVectorImpl<OpAsmParser::Argument> &entryBlockArgs,
652     StringRef keyword, std::optional<ReductionParseArgs> reductionArgs) {
653   if (succeeded(parser.parseOptionalKeyword(keyword))) {
654     if (!reductionArgs)
655       return failure();
656     if (failed(parseClauseWithRegionArgs(
657             parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
658             &reductionArgs->syms, /*mapIndices=*/nullptr, &reductionArgs->byref,
659             reductionArgs->modifier)))
660       return failure();
661   }
662   return success();
663 }
664 
665 static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
666                                        AllRegionParseArgs args) {
667   llvm::SmallVector<OpAsmParser::Argument> entryBlockArgs;
668 
669   if (failed(parseBlockArgClause(parser, entryBlockArgs, "host_eval",
670                                  args.hostEvalArgs)))
671     return parser.emitError(parser.getCurrentLocation())
672            << "invalid `host_eval` format";
673 
674   if (failed(parseBlockArgClause(parser, entryBlockArgs, "in_reduction",
675                                  args.inReductionArgs)))
676     return parser.emitError(parser.getCurrentLocation())
677            << "invalid `in_reduction` format";
678 
679   if (failed(parseBlockArgClause(parser, entryBlockArgs, "map_entries",
680                                  args.mapArgs)))
681     return parser.emitError(parser.getCurrentLocation())
682            << "invalid `map_entries` format";
683 
684   if (failed(parseBlockArgClause(parser, entryBlockArgs, "private",
685                                  args.privateArgs)))
686     return parser.emitError(parser.getCurrentLocation())
687            << "invalid `private` format";
688 
689   if (failed(parseBlockArgClause(parser, entryBlockArgs, "reduction",
690                                  args.reductionArgs)))
691     return parser.emitError(parser.getCurrentLocation())
692            << "invalid `reduction` format";
693 
694   if (failed(parseBlockArgClause(parser, entryBlockArgs, "task_reduction",
695                                  args.taskReductionArgs)))
696     return parser.emitError(parser.getCurrentLocation())
697            << "invalid `task_reduction` format";
698 
699   if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_addr",
700                                  args.useDeviceAddrArgs)))
701     return parser.emitError(parser.getCurrentLocation())
702            << "invalid `use_device_addr` format";
703 
704   if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_ptr",
705                                  args.useDevicePtrArgs)))
706     return parser.emitError(parser.getCurrentLocation())
707            << "invalid `use_device_addr` format";
708 
709   return parser.parseRegion(region, entryBlockArgs);
710 }
711 
712 static ParseResult parseHostEvalInReductionMapPrivateRegion(
713     OpAsmParser &parser, Region &region,
714     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &hostEvalVars,
715     SmallVectorImpl<Type> &hostEvalTypes,
716     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars,
717     SmallVectorImpl<Type> &inReductionTypes,
718     DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
719     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &mapVars,
720     SmallVectorImpl<Type> &mapTypes,
721     llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
722     llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
723     DenseI64ArrayAttr &privateMaps) {
724   AllRegionParseArgs args;
725   args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
726   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
727                                inReductionByref, inReductionSyms);
728   args.mapArgs.emplace(mapVars, mapTypes);
729   args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
730                            &privateMaps);
731   return parseBlockArgRegion(parser, region, args);
732 }
733 
734 static ParseResult parseInReductionPrivateRegion(
735     OpAsmParser &parser, Region &region,
736     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars,
737     SmallVectorImpl<Type> &inReductionTypes,
738     DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
739     llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
740     llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
741   AllRegionParseArgs args;
742   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
743                                inReductionByref, inReductionSyms);
744   args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
745   return parseBlockArgRegion(parser, region, args);
746 }
747 
748 static ParseResult parseInReductionPrivateReductionRegion(
749     OpAsmParser &parser, Region &region,
750     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars,
751     SmallVectorImpl<Type> &inReductionTypes,
752     DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
753     llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
754     llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
755     ReductionModifierAttr &reductionMod,
756     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVars,
757     SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
758     ArrayAttr &reductionSyms) {
759   AllRegionParseArgs args;
760   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
761                                inReductionByref, inReductionSyms);
762   args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
763   args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
764                              reductionSyms, &reductionMod);
765   return parseBlockArgRegion(parser, region, args);
766 }
767 
768 static ParseResult parsePrivateRegion(
769     OpAsmParser &parser, Region &region,
770     llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
771     llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
772   AllRegionParseArgs args;
773   args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
774   return parseBlockArgRegion(parser, region, args);
775 }
776 
777 static ParseResult parsePrivateReductionRegion(
778     OpAsmParser &parser, Region &region,
779     llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
780     llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
781     ReductionModifierAttr &reductionMod,
782     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVars,
783     SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
784     ArrayAttr &reductionSyms) {
785   AllRegionParseArgs args;
786   args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
787   args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
788                              reductionSyms, &reductionMod);
789   return parseBlockArgRegion(parser, region, args);
790 }
791 
792 static ParseResult parseTaskReductionRegion(
793     OpAsmParser &parser, Region &region,
794     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &taskReductionVars,
795     SmallVectorImpl<Type> &taskReductionTypes,
796     DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms) {
797   AllRegionParseArgs args;
798   args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
799                                  taskReductionByref, taskReductionSyms);
800   return parseBlockArgRegion(parser, region, args);
801 }
802 
803 static ParseResult parseUseDeviceAddrUseDevicePtrRegion(
804     OpAsmParser &parser, Region &region,
805     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &useDeviceAddrVars,
806     SmallVectorImpl<Type> &useDeviceAddrTypes,
807     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &useDevicePtrVars,
808     SmallVectorImpl<Type> &useDevicePtrTypes) {
809   AllRegionParseArgs args;
810   args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
811   args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
812   return parseBlockArgRegion(parser, region, args);
813 }
814 
815 //===----------------------------------------------------------------------===//
816 // Printers for operations including clauses that define entry block arguments.
817 //===----------------------------------------------------------------------===//
818 
819 namespace {
820 struct MapPrintArgs {
821   ValueRange vars;
822   TypeRange types;
823   MapPrintArgs(ValueRange vars, TypeRange types) : vars(vars), types(types) {}
824 };
825 struct PrivatePrintArgs {
826   ValueRange vars;
827   TypeRange types;
828   ArrayAttr syms;
829   DenseI64ArrayAttr mapIndices;
830   PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms,
831                    DenseI64ArrayAttr mapIndices)
832       : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
833 };
834 struct ReductionPrintArgs {
835   ValueRange vars;
836   TypeRange types;
837   DenseBoolArrayAttr byref;
838   ArrayAttr syms;
839   ReductionModifierAttr modifier;
840   ReductionPrintArgs(ValueRange vars, TypeRange types, DenseBoolArrayAttr byref,
841                      ArrayAttr syms, ReductionModifierAttr mod = nullptr)
842       : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
843 };
844 struct AllRegionPrintArgs {
845   std::optional<MapPrintArgs> hostEvalArgs;
846   std::optional<ReductionPrintArgs> inReductionArgs;
847   std::optional<MapPrintArgs> mapArgs;
848   std::optional<PrivatePrintArgs> privateArgs;
849   std::optional<ReductionPrintArgs> reductionArgs;
850   std::optional<ReductionPrintArgs> taskReductionArgs;
851   std::optional<MapPrintArgs> useDeviceAddrArgs;
852   std::optional<MapPrintArgs> useDevicePtrArgs;
853 };
854 } // namespace
855 
856 static void printClauseWithRegionArgs(
857     OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
858     ValueRange argsSubrange, ValueRange operands, TypeRange types,
859     ArrayAttr symbols = nullptr, DenseI64ArrayAttr mapIndices = nullptr,
860     DenseBoolArrayAttr byref = nullptr,
861     ReductionModifierAttr modifier = nullptr) {
862   if (argsSubrange.empty())
863     return;
864 
865   p << clauseName << "(";
866 
867   if (modifier)
868     p << "mod: " << stringifyReductionModifier(modifier.getValue()) << ", ";
869 
870   if (!symbols) {
871     llvm::SmallVector<Attribute> values(operands.size(), nullptr);
872     symbols = ArrayAttr::get(ctx, values);
873   }
874 
875   if (!mapIndices) {
876     llvm::SmallVector<int64_t> values(operands.size(), -1);
877     mapIndices = DenseI64ArrayAttr::get(ctx, values);
878   }
879 
880   if (!byref) {
881     mlir::SmallVector<bool> values(operands.size(), false);
882     byref = DenseBoolArrayAttr::get(ctx, values);
883   }
884 
885   llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols,
886                                         mapIndices.asArrayRef(),
887                                         byref.asArrayRef()),
888                         p, [&p](auto t) {
889                           auto [op, arg, sym, map, isByRef] = t;
890                           if (isByRef)
891                             p << "byref ";
892                           if (sym)
893                             p << sym << " ";
894 
895                           p << op << " -> " << arg;
896 
897                           if (map != -1)
898                             p << " [map_idx=" << map << "]";
899                         });
900   p << " : ";
901   llvm::interleaveComma(types, p);
902   p << ") ";
903 }
904 
905 static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx,
906                                 StringRef clauseName, ValueRange argsSubrange,
907                                 std::optional<MapPrintArgs> mapArgs) {
908   if (mapArgs)
909     printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange, mapArgs->vars,
910                               mapArgs->types);
911 }
912 
913 static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx,
914                                 StringRef clauseName, ValueRange argsSubrange,
915                                 std::optional<PrivatePrintArgs> privateArgs) {
916   if (privateArgs)
917     printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
918                               privateArgs->vars, privateArgs->types,
919                               privateArgs->syms, privateArgs->mapIndices);
920 }
921 
922 static void
923 printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
924                     ValueRange argsSubrange,
925                     std::optional<ReductionPrintArgs> reductionArgs) {
926   if (reductionArgs)
927     printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
928                               reductionArgs->vars, reductionArgs->types,
929                               reductionArgs->syms, /*mapIndices=*/nullptr,
930                               reductionArgs->byref, reductionArgs->modifier);
931 }
932 
933 static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
934                                 const AllRegionPrintArgs &args) {
935   auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op);
936   MLIRContext *ctx = op->getContext();
937 
938   printBlockArgClause(p, ctx, "host_eval", iface.getHostEvalBlockArgs(),
939                       args.hostEvalArgs);
940   printBlockArgClause(p, ctx, "in_reduction", iface.getInReductionBlockArgs(),
941                       args.inReductionArgs);
942   printBlockArgClause(p, ctx, "map_entries", iface.getMapBlockArgs(),
943                       args.mapArgs);
944   printBlockArgClause(p, ctx, "private", iface.getPrivateBlockArgs(),
945                       args.privateArgs);
946   printBlockArgClause(p, ctx, "reduction", iface.getReductionBlockArgs(),
947                       args.reductionArgs);
948   printBlockArgClause(p, ctx, "task_reduction",
949                       iface.getTaskReductionBlockArgs(),
950                       args.taskReductionArgs);
951   printBlockArgClause(p, ctx, "use_device_addr",
952                       iface.getUseDeviceAddrBlockArgs(),
953                       args.useDeviceAddrArgs);
954   printBlockArgClause(p, ctx, "use_device_ptr",
955                       iface.getUseDevicePtrBlockArgs(), args.useDevicePtrArgs);
956 
957   p.printRegion(region, /*printEntryBlockArgs=*/false);
958 }
959 
960 static void printHostEvalInReductionMapPrivateRegion(
961     OpAsmPrinter &p, Operation *op, Region &region, ValueRange hostEvalVars,
962     TypeRange hostEvalTypes, ValueRange inReductionVars,
963     TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
964     ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes,
965     ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms,
966     DenseI64ArrayAttr privateMaps) {
967   AllRegionPrintArgs args;
968   args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
969   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
970                                inReductionByref, inReductionSyms);
971   args.mapArgs.emplace(mapVars, mapTypes);
972   args.privateArgs.emplace(privateVars, privateTypes, privateSyms, privateMaps);
973   printBlockArgRegion(p, op, region, args);
974 }
975 
976 static void printInReductionPrivateRegion(
977     OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
978     TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
979     ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
980     ArrayAttr privateSyms) {
981   AllRegionPrintArgs args;
982   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
983                                inReductionByref, inReductionSyms);
984   args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
985                            /*mapIndices=*/nullptr);
986   printBlockArgRegion(p, op, region, args);
987 }
988 
989 static void printInReductionPrivateReductionRegion(
990     OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
991     TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
992     ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
993     ArrayAttr privateSyms, ReductionModifierAttr reductionMod,
994     ValueRange reductionVars, TypeRange reductionTypes,
995     DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms) {
996   AllRegionPrintArgs args;
997   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
998                                inReductionByref, inReductionSyms);
999   args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1000                            /*mapIndices=*/nullptr);
1001   args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1002                              reductionSyms, reductionMod);
1003   printBlockArgRegion(p, op, region, args);
1004 }
1005 
1006 static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region,
1007                                ValueRange privateVars, TypeRange privateTypes,
1008                                ArrayAttr privateSyms) {
1009   AllRegionPrintArgs args;
1010   args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1011                            /*mapIndices=*/nullptr);
1012   printBlockArgRegion(p, op, region, args);
1013 }
1014 
1015 static void printPrivateReductionRegion(
1016     OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars,
1017     TypeRange privateTypes, ArrayAttr privateSyms,
1018     ReductionModifierAttr reductionMod, ValueRange reductionVars,
1019     TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
1020     ArrayAttr reductionSyms) {
1021   AllRegionPrintArgs args;
1022   args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1023                            /*mapIndices=*/nullptr);
1024   args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1025                              reductionSyms, reductionMod);
1026   printBlockArgRegion(p, op, region, args);
1027 }
1028 
1029 static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op,
1030                                      Region &region,
1031                                      ValueRange taskReductionVars,
1032                                      TypeRange taskReductionTypes,
1033                                      DenseBoolArrayAttr taskReductionByref,
1034                                      ArrayAttr taskReductionSyms) {
1035   AllRegionPrintArgs args;
1036   args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1037                                  taskReductionByref, taskReductionSyms);
1038   printBlockArgRegion(p, op, region, args);
1039 }
1040 
1041 static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op,
1042                                                  Region &region,
1043                                                  ValueRange useDeviceAddrVars,
1044                                                  TypeRange useDeviceAddrTypes,
1045                                                  ValueRange useDevicePtrVars,
1046                                                  TypeRange useDevicePtrTypes) {
1047   AllRegionPrintArgs args;
1048   args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1049   args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1050   printBlockArgRegion(p, op, region, args);
1051 }
1052 
1053 /// Verifies Reduction Clause
1054 static LogicalResult
1055 verifyReductionVarList(Operation *op, std::optional<ArrayAttr> reductionSyms,
1056                        OperandRange reductionVars,
1057                        std::optional<ArrayRef<bool>> reductionByref) {
1058   if (!reductionVars.empty()) {
1059     if (!reductionSyms || reductionSyms->size() != reductionVars.size())
1060       return op->emitOpError()
1061              << "expected as many reduction symbol references "
1062                 "as reduction variables";
1063     if (reductionByref && reductionByref->size() != reductionVars.size())
1064       return op->emitError() << "expected as many reduction variable by "
1065                                 "reference attributes as reduction variables";
1066   } else {
1067     if (reductionSyms)
1068       return op->emitOpError() << "unexpected reduction symbol references";
1069     return success();
1070   }
1071 
1072   // TODO: The followings should be done in
1073   // SymbolUserOpInterface::verifySymbolUses.
1074   DenseSet<Value> accumulators;
1075   for (auto args : llvm::zip(reductionVars, *reductionSyms)) {
1076     Value accum = std::get<0>(args);
1077 
1078     if (!accumulators.insert(accum).second)
1079       return op->emitOpError() << "accumulator variable used more than once";
1080 
1081     Type varType = accum.getType();
1082     auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
1083     auto decl =
1084         SymbolTable::lookupNearestSymbolFrom<DeclareReductionOp>(op, symbolRef);
1085     if (!decl)
1086       return op->emitOpError() << "expected symbol reference " << symbolRef
1087                                << " to point to a reduction declaration";
1088 
1089     if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
1090       return op->emitOpError()
1091              << "expected accumulator (" << varType
1092              << ") to be the same type as reduction declaration ("
1093              << decl.getAccumulatorType() << ")";
1094   }
1095 
1096   return success();
1097 }
1098 
1099 //===----------------------------------------------------------------------===//
1100 // Parser, printer and verifier for Copyprivate
1101 //===----------------------------------------------------------------------===//
1102 
1103 /// copyprivate-entry-list ::= copyprivate-entry
1104 ///                          | copyprivate-entry-list `,` copyprivate-entry
1105 /// copyprivate-entry ::= ssa-id `->` symbol-ref `:` type
1106 static ParseResult parseCopyprivate(
1107     OpAsmParser &parser,
1108     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &copyprivateVars,
1109     SmallVectorImpl<Type> &copyprivateTypes, ArrayAttr &copyprivateSyms) {
1110   SmallVector<SymbolRefAttr> symsVec;
1111   if (failed(parser.parseCommaSeparatedList([&]() {
1112         if (parser.parseOperand(copyprivateVars.emplace_back()) ||
1113             parser.parseArrow() ||
1114             parser.parseAttribute(symsVec.emplace_back()) ||
1115             parser.parseColonType(copyprivateTypes.emplace_back()))
1116           return failure();
1117         return success();
1118       })))
1119     return failure();
1120   SmallVector<Attribute> syms(symsVec.begin(), symsVec.end());
1121   copyprivateSyms = ArrayAttr::get(parser.getContext(), syms);
1122   return success();
1123 }
1124 
1125 /// Print Copyprivate clause
1126 static void printCopyprivate(OpAsmPrinter &p, Operation *op,
1127                              OperandRange copyprivateVars,
1128                              TypeRange copyprivateTypes,
1129                              std::optional<ArrayAttr> copyprivateSyms) {
1130   if (!copyprivateSyms.has_value())
1131     return;
1132   llvm::interleaveComma(
1133       llvm::zip(copyprivateVars, *copyprivateSyms, copyprivateTypes), p,
1134       [&](const auto &args) {
1135         p << std::get<0>(args) << " -> " << std::get<1>(args) << " : "
1136           << std::get<2>(args);
1137       });
1138 }
1139 
1140 /// Verifies CopyPrivate Clause
1141 static LogicalResult
1142 verifyCopyprivateVarList(Operation *op, OperandRange copyprivateVars,
1143                          std::optional<ArrayAttr> copyprivateSyms) {
1144   size_t copyprivateSymsSize =
1145       copyprivateSyms.has_value() ? copyprivateSyms->size() : 0;
1146   if (copyprivateSymsSize != copyprivateVars.size())
1147     return op->emitOpError() << "inconsistent number of copyprivate vars (= "
1148                              << copyprivateVars.size()
1149                              << ") and functions (= " << copyprivateSymsSize
1150                              << "), both must be equal";
1151   if (!copyprivateSyms.has_value())
1152     return success();
1153 
1154   for (auto copyprivateVarAndSym :
1155        llvm::zip(copyprivateVars, *copyprivateSyms)) {
1156     auto symbolRef =
1157         llvm::cast<SymbolRefAttr>(std::get<1>(copyprivateVarAndSym));
1158     std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>>
1159         funcOp;
1160     if (mlir::func::FuncOp mlirFuncOp =
1161             SymbolTable::lookupNearestSymbolFrom<mlir::func::FuncOp>(op,
1162                                                                      symbolRef))
1163       funcOp = mlirFuncOp;
1164     else if (mlir::LLVM::LLVMFuncOp llvmFuncOp =
1165                  SymbolTable::lookupNearestSymbolFrom<mlir::LLVM::LLVMFuncOp>(
1166                      op, symbolRef))
1167       funcOp = llvmFuncOp;
1168 
1169     auto getNumArguments = [&] {
1170       return std::visit([](auto &f) { return f.getNumArguments(); }, *funcOp);
1171     };
1172 
1173     auto getArgumentType = [&](unsigned i) {
1174       return std::visit([i](auto &f) { return f.getArgumentTypes()[i]; },
1175                         *funcOp);
1176     };
1177 
1178     if (!funcOp)
1179       return op->emitOpError() << "expected symbol reference " << symbolRef
1180                                << " to point to a copy function";
1181 
1182     if (getNumArguments() != 2)
1183       return op->emitOpError()
1184              << "expected copy function " << symbolRef << " to have 2 operands";
1185 
1186     Type argTy = getArgumentType(0);
1187     if (argTy != getArgumentType(1))
1188       return op->emitOpError() << "expected copy function " << symbolRef
1189                                << " arguments to have the same type";
1190 
1191     Type varType = std::get<0>(copyprivateVarAndSym).getType();
1192     if (argTy != varType)
1193       return op->emitOpError()
1194              << "expected copy function arguments' type (" << argTy
1195              << ") to be the same as copyprivate variable's type (" << varType
1196              << ")";
1197   }
1198 
1199   return success();
1200 }
1201 
1202 //===----------------------------------------------------------------------===//
1203 // Parser, printer and verifier for DependVarList
1204 //===----------------------------------------------------------------------===//
1205 
1206 /// depend-entry-list ::= depend-entry
1207 ///                     | depend-entry-list `,` depend-entry
1208 /// depend-entry ::= depend-kind `->` ssa-id `:` type
1209 static ParseResult
1210 parseDependVarList(OpAsmParser &parser,
1211                    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dependVars,
1212                    SmallVectorImpl<Type> &dependTypes, ArrayAttr &dependKinds) {
1213   SmallVector<ClauseTaskDependAttr> kindsVec;
1214   if (failed(parser.parseCommaSeparatedList([&]() {
1215         StringRef keyword;
1216         if (parser.parseKeyword(&keyword) || parser.parseArrow() ||
1217             parser.parseOperand(dependVars.emplace_back()) ||
1218             parser.parseColonType(dependTypes.emplace_back()))
1219           return failure();
1220         if (std::optional<ClauseTaskDepend> keywordDepend =
1221                 (symbolizeClauseTaskDepend(keyword)))
1222           kindsVec.emplace_back(
1223               ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend));
1224         else
1225           return failure();
1226         return success();
1227       })))
1228     return failure();
1229   SmallVector<Attribute> kinds(kindsVec.begin(), kindsVec.end());
1230   dependKinds = ArrayAttr::get(parser.getContext(), kinds);
1231   return success();
1232 }
1233 
1234 /// Print Depend clause
1235 static void printDependVarList(OpAsmPrinter &p, Operation *op,
1236                                OperandRange dependVars, TypeRange dependTypes,
1237                                std::optional<ArrayAttr> dependKinds) {
1238 
1239   for (unsigned i = 0, e = dependKinds->size(); i < e; ++i) {
1240     if (i != 0)
1241       p << ", ";
1242     p << stringifyClauseTaskDepend(
1243              llvm::cast<mlir::omp::ClauseTaskDependAttr>((*dependKinds)[i])
1244                  .getValue())
1245       << " -> " << dependVars[i] << " : " << dependTypes[i];
1246   }
1247 }
1248 
1249 /// Verifies Depend clause
1250 static LogicalResult verifyDependVarList(Operation *op,
1251                                          std::optional<ArrayAttr> dependKinds,
1252                                          OperandRange dependVars) {
1253   if (!dependVars.empty()) {
1254     if (!dependKinds || dependKinds->size() != dependVars.size())
1255       return op->emitOpError() << "expected as many depend values"
1256                                   " as depend variables";
1257   } else {
1258     if (dependKinds && !dependKinds->empty())
1259       return op->emitOpError() << "unexpected depend values";
1260     return success();
1261   }
1262 
1263   return success();
1264 }
1265 
1266 //===----------------------------------------------------------------------===//
1267 // Parser, printer and verifier for Synchronization Hint (2.17.12)
1268 //===----------------------------------------------------------------------===//
1269 
1270 /// Parses a Synchronization Hint clause. The value of hint is an integer
1271 /// which is a combination of different hints from `omp_sync_hint_t`.
1272 ///
1273 /// hint-clause = `hint` `(` hint-value `)`
1274 static ParseResult parseSynchronizationHint(OpAsmParser &parser,
1275                                             IntegerAttr &hintAttr) {
1276   StringRef hintKeyword;
1277   int64_t hint = 0;
1278   if (succeeded(parser.parseOptionalKeyword("none"))) {
1279     hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
1280     return success();
1281   }
1282   auto parseKeyword = [&]() -> ParseResult {
1283     if (failed(parser.parseKeyword(&hintKeyword)))
1284       return failure();
1285     if (hintKeyword == "uncontended")
1286       hint |= 1;
1287     else if (hintKeyword == "contended")
1288       hint |= 2;
1289     else if (hintKeyword == "nonspeculative")
1290       hint |= 4;
1291     else if (hintKeyword == "speculative")
1292       hint |= 8;
1293     else
1294       return parser.emitError(parser.getCurrentLocation())
1295              << hintKeyword << " is not a valid hint";
1296     return success();
1297   };
1298   if (parser.parseCommaSeparatedList(parseKeyword))
1299     return failure();
1300   hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
1301   return success();
1302 }
1303 
1304 /// Prints a Synchronization Hint clause
1305 static void printSynchronizationHint(OpAsmPrinter &p, Operation *op,
1306                                      IntegerAttr hintAttr) {
1307   int64_t hint = hintAttr.getInt();
1308 
1309   if (hint == 0) {
1310     p << "none";
1311     return;
1312   }
1313 
1314   // Helper function to get n-th bit from the right end of `value`
1315   auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
1316 
1317   bool uncontended = bitn(hint, 0);
1318   bool contended = bitn(hint, 1);
1319   bool nonspeculative = bitn(hint, 2);
1320   bool speculative = bitn(hint, 3);
1321 
1322   SmallVector<StringRef> hints;
1323   if (uncontended)
1324     hints.push_back("uncontended");
1325   if (contended)
1326     hints.push_back("contended");
1327   if (nonspeculative)
1328     hints.push_back("nonspeculative");
1329   if (speculative)
1330     hints.push_back("speculative");
1331 
1332   llvm::interleaveComma(hints, p);
1333 }
1334 
1335 /// Verifies a synchronization hint clause
1336 static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
1337 
1338   // Helper function to get n-th bit from the right end of `value`
1339   auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
1340 
1341   bool uncontended = bitn(hint, 0);
1342   bool contended = bitn(hint, 1);
1343   bool nonspeculative = bitn(hint, 2);
1344   bool speculative = bitn(hint, 3);
1345 
1346   if (uncontended && contended)
1347     return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
1348                                 "omp_sync_hint_contended cannot be combined";
1349   if (nonspeculative && speculative)
1350     return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
1351                                 "omp_sync_hint_speculative cannot be combined.";
1352   return success();
1353 }
1354 
1355 //===----------------------------------------------------------------------===//
1356 // Parser, printer and verifier for Target
1357 //===----------------------------------------------------------------------===//
1358 
1359 // Helper function to get bitwise AND of `value` and 'flag'
1360 uint64_t mapTypeToBitFlag(uint64_t value,
1361                           llvm::omp::OpenMPOffloadMappingFlags flag) {
1362   return value & llvm::to_underlying(flag);
1363 }
1364 
1365 /// Parses a map_entries map type from a string format back into its numeric
1366 /// value.
1367 ///
1368 /// map-clause = `map_clauses (  ( `(` `always, `? `close, `? `present, `? (
1369 /// `to` | `from` | `delete` `)` )+ `)` )
1370 static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) {
1371   llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
1372       llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
1373 
1374   // This simply verifies the correct keyword is read in, the
1375   // keyword itself is stored inside of the operation
1376   auto parseTypeAndMod = [&]() -> ParseResult {
1377     StringRef mapTypeMod;
1378     if (parser.parseKeyword(&mapTypeMod))
1379       return failure();
1380 
1381     if (mapTypeMod == "always")
1382       mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
1383 
1384     if (mapTypeMod == "implicit")
1385       mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
1386 
1387     if (mapTypeMod == "close")
1388       mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
1389 
1390     if (mapTypeMod == "present")
1391       mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
1392 
1393     if (mapTypeMod == "to")
1394       mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
1395 
1396     if (mapTypeMod == "from")
1397       mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1398 
1399     if (mapTypeMod == "tofrom")
1400       mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
1401                      llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1402 
1403     if (mapTypeMod == "delete")
1404       mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
1405 
1406     return success();
1407   };
1408 
1409   if (parser.parseCommaSeparatedList(parseTypeAndMod))
1410     return failure();
1411 
1412   mapType = parser.getBuilder().getIntegerAttr(
1413       parser.getBuilder().getIntegerType(64, /*isSigned=*/false),
1414       llvm::to_underlying(mapTypeBits));
1415 
1416   return success();
1417 }
1418 
1419 /// Prints a map_entries map type from its numeric value out into its string
1420 /// format.
1421 static void printMapClause(OpAsmPrinter &p, Operation *op,
1422                            IntegerAttr mapType) {
1423   uint64_t mapTypeBits = mapType.getUInt();
1424 
1425   bool emitAllocRelease = true;
1426   llvm::SmallVector<std::string, 4> mapTypeStrs;
1427 
1428   // handling of always, close, present placed at the beginning of the string
1429   // to aid readability
1430   if (mapTypeToBitFlag(mapTypeBits,
1431                        llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS))
1432     mapTypeStrs.push_back("always");
1433   if (mapTypeToBitFlag(mapTypeBits,
1434                        llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT))
1435     mapTypeStrs.push_back("implicit");
1436   if (mapTypeToBitFlag(mapTypeBits,
1437                        llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE))
1438     mapTypeStrs.push_back("close");
1439   if (mapTypeToBitFlag(mapTypeBits,
1440                        llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT))
1441     mapTypeStrs.push_back("present");
1442 
1443   // special handling of to/from/tofrom/delete and release/alloc, release +
1444   // alloc are the abscense of one of the other flags, whereas tofrom requires
1445   // both the to and from flag to be set.
1446   bool to = mapTypeToBitFlag(mapTypeBits,
1447                              llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
1448   bool from = mapTypeToBitFlag(
1449       mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
1450   if (to && from) {
1451     emitAllocRelease = false;
1452     mapTypeStrs.push_back("tofrom");
1453   } else if (from) {
1454     emitAllocRelease = false;
1455     mapTypeStrs.push_back("from");
1456   } else if (to) {
1457     emitAllocRelease = false;
1458     mapTypeStrs.push_back("to");
1459   }
1460   if (mapTypeToBitFlag(mapTypeBits,
1461                        llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE)) {
1462     emitAllocRelease = false;
1463     mapTypeStrs.push_back("delete");
1464   }
1465   if (emitAllocRelease)
1466     mapTypeStrs.push_back("exit_release_or_enter_alloc");
1467 
1468   for (unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
1469     p << mapTypeStrs[i];
1470     if (i + 1 < mapTypeStrs.size()) {
1471       p << ", ";
1472     }
1473   }
1474 }
1475 
1476 static ParseResult parseMembersIndex(OpAsmParser &parser,
1477                                      ArrayAttr &membersIdx) {
1478   SmallVector<Attribute> values, memberIdxs;
1479 
1480   auto parseIndices = [&]() -> ParseResult {
1481     int64_t value;
1482     if (parser.parseInteger(value))
1483       return failure();
1484     values.push_back(IntegerAttr::get(parser.getBuilder().getIntegerType(64),
1485                                       APInt(64, value, /*isSigned=*/false)));
1486     return success();
1487   };
1488 
1489   do {
1490     if (failed(parser.parseLSquare()))
1491       return failure();
1492 
1493     if (parser.parseCommaSeparatedList(parseIndices))
1494       return failure();
1495 
1496     if (failed(parser.parseRSquare()))
1497       return failure();
1498 
1499     memberIdxs.push_back(ArrayAttr::get(parser.getContext(), values));
1500     values.clear();
1501   } while (succeeded(parser.parseOptionalComma()));
1502 
1503   if (!memberIdxs.empty())
1504     membersIdx = ArrayAttr::get(parser.getContext(), memberIdxs);
1505 
1506   return success();
1507 }
1508 
1509 static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op,
1510                               ArrayAttr membersIdx) {
1511   if (!membersIdx)
1512     return;
1513 
1514   llvm::interleaveComma(membersIdx, p, [&p](Attribute v) {
1515     p << "[";
1516     auto memberIdx = cast<ArrayAttr>(v);
1517     llvm::interleaveComma(memberIdx.getValue(), p, [&p](Attribute v2) {
1518       p << cast<IntegerAttr>(v2).getInt();
1519     });
1520     p << "]";
1521   });
1522 }
1523 
1524 static void printCaptureType(OpAsmPrinter &p, Operation *op,
1525                              VariableCaptureKindAttr mapCaptureType) {
1526   std::string typeCapStr;
1527   llvm::raw_string_ostream typeCap(typeCapStr);
1528   if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef)
1529     typeCap << "ByRef";
1530   if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy)
1531     typeCap << "ByCopy";
1532   if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType)
1533     typeCap << "VLAType";
1534   if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This)
1535     typeCap << "This";
1536   p << typeCapStr;
1537 }
1538 
1539 static ParseResult parseCaptureType(OpAsmParser &parser,
1540                                     VariableCaptureKindAttr &mapCaptureType) {
1541   StringRef mapCaptureKey;
1542   if (parser.parseKeyword(&mapCaptureKey))
1543     return failure();
1544 
1545   if (mapCaptureKey == "This")
1546     mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1547         parser.getContext(), mlir::omp::VariableCaptureKind::This);
1548   if (mapCaptureKey == "ByRef")
1549     mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1550         parser.getContext(), mlir::omp::VariableCaptureKind::ByRef);
1551   if (mapCaptureKey == "ByCopy")
1552     mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1553         parser.getContext(), mlir::omp::VariableCaptureKind::ByCopy);
1554   if (mapCaptureKey == "VLAType")
1555     mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1556         parser.getContext(), mlir::omp::VariableCaptureKind::VLAType);
1557 
1558   return success();
1559 }
1560 
1561 static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars) {
1562   llvm::DenseSet<mlir::TypedValue<mlir::omp::PointerLikeType>> updateToVars;
1563   llvm::DenseSet<mlir::TypedValue<mlir::omp::PointerLikeType>> updateFromVars;
1564 
1565   for (auto mapOp : mapVars) {
1566     if (!mapOp.getDefiningOp())
1567       emitError(op->getLoc(), "missing map operation");
1568 
1569     if (auto mapInfoOp =
1570             mlir::dyn_cast<mlir::omp::MapInfoOp>(mapOp.getDefiningOp())) {
1571       if (!mapInfoOp.getMapType().has_value())
1572         emitError(op->getLoc(), "missing map type for map operand");
1573 
1574       if (!mapInfoOp.getMapCaptureType().has_value())
1575         emitError(op->getLoc(), "missing map capture type for map operand");
1576 
1577       uint64_t mapTypeBits = mapInfoOp.getMapType().value();
1578 
1579       bool to = mapTypeToBitFlag(
1580           mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
1581       bool from = mapTypeToBitFlag(
1582           mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
1583       bool del = mapTypeToBitFlag(
1584           mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE);
1585 
1586       bool always = mapTypeToBitFlag(
1587           mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);
1588       bool close = mapTypeToBitFlag(
1589           mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE);
1590       bool implicit = mapTypeToBitFlag(
1591           mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT);
1592 
1593       if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del)
1594         return emitError(op->getLoc(),
1595                          "to, from, tofrom and alloc map types are permitted");
1596 
1597       if (isa<TargetEnterDataOp>(op) && (from || del))
1598         return emitError(op->getLoc(), "to and alloc map types are permitted");
1599 
1600       if (isa<TargetExitDataOp>(op) && to)
1601         return emitError(op->getLoc(),
1602                          "from, release and delete map types are permitted");
1603 
1604       if (isa<TargetUpdateOp>(op)) {
1605         if (del) {
1606           return emitError(op->getLoc(),
1607                            "at least one of to or from map types must be "
1608                            "specified, other map types are not permitted");
1609         }
1610 
1611         if (!to && !from) {
1612           return emitError(op->getLoc(),
1613                            "at least one of to or from map types must be "
1614                            "specified, other map types are not permitted");
1615         }
1616 
1617         auto updateVar = mapInfoOp.getVarPtr();
1618 
1619         if ((to && from) || (to && updateFromVars.contains(updateVar)) ||
1620             (from && updateToVars.contains(updateVar))) {
1621           return emitError(
1622               op->getLoc(),
1623               "either to or from map types can be specified, not both");
1624         }
1625 
1626         if (always || close || implicit) {
1627           return emitError(
1628               op->getLoc(),
1629               "present, mapper and iterator map type modifiers are permitted");
1630         }
1631 
1632         to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
1633       }
1634     } else {
1635       emitError(op->getLoc(), "map argument is not a map entry operation");
1636     }
1637   }
1638 
1639   return success();
1640 }
1641 
1642 static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp) {
1643   std::optional<DenseI64ArrayAttr> privateMapIndices =
1644       targetOp.getPrivateMapsAttr();
1645 
1646   // None of the private operands are mapped.
1647   if (!privateMapIndices.has_value() || !privateMapIndices.value())
1648     return success();
1649 
1650   OperandRange privateVars = targetOp.getPrivateVars();
1651 
1652   if (privateMapIndices.value().size() !=
1653       static_cast<int64_t>(privateVars.size()))
1654     return emitError(targetOp.getLoc(), "sizes of `private` operand range and "
1655                                         "`private_maps` attribute mismatch");
1656 
1657   return success();
1658 }
1659 
1660 //===----------------------------------------------------------------------===//
1661 // TargetDataOp
1662 //===----------------------------------------------------------------------===//
1663 
1664 void TargetDataOp::build(OpBuilder &builder, OperationState &state,
1665                          const TargetDataOperands &clauses) {
1666   TargetDataOp::build(builder, state, clauses.device, clauses.ifExpr,
1667                       clauses.mapVars, clauses.useDeviceAddrVars,
1668                       clauses.useDevicePtrVars);
1669 }
1670 
1671 LogicalResult TargetDataOp::verify() {
1672   if (getMapVars().empty() && getUseDevicePtrVars().empty() &&
1673       getUseDeviceAddrVars().empty()) {
1674     return ::emitError(this->getLoc(),
1675                        "At least one of map, use_device_ptr_vars, or "
1676                        "use_device_addr_vars operand must be present");
1677   }
1678   return verifyMapClause(*this, getMapVars());
1679 }
1680 
1681 //===----------------------------------------------------------------------===//
1682 // TargetEnterDataOp
1683 //===----------------------------------------------------------------------===//
1684 
1685 void TargetEnterDataOp::build(
1686     OpBuilder &builder, OperationState &state,
1687     const TargetEnterExitUpdateDataOperands &clauses) {
1688   MLIRContext *ctx = builder.getContext();
1689   TargetEnterDataOp::build(builder, state,
1690                            makeArrayAttr(ctx, clauses.dependKinds),
1691                            clauses.dependVars, clauses.device, clauses.ifExpr,
1692                            clauses.mapVars, clauses.nowait);
1693 }
1694 
1695 LogicalResult TargetEnterDataOp::verify() {
1696   LogicalResult verifyDependVars =
1697       verifyDependVarList(*this, getDependKinds(), getDependVars());
1698   return failed(verifyDependVars) ? verifyDependVars
1699                                   : verifyMapClause(*this, getMapVars());
1700 }
1701 
1702 //===----------------------------------------------------------------------===//
1703 // TargetExitDataOp
1704 //===----------------------------------------------------------------------===//
1705 
1706 void TargetExitDataOp::build(OpBuilder &builder, OperationState &state,
1707                              const TargetEnterExitUpdateDataOperands &clauses) {
1708   MLIRContext *ctx = builder.getContext();
1709   TargetExitDataOp::build(builder, state,
1710                           makeArrayAttr(ctx, clauses.dependKinds),
1711                           clauses.dependVars, clauses.device, clauses.ifExpr,
1712                           clauses.mapVars, clauses.nowait);
1713 }
1714 
1715 LogicalResult TargetExitDataOp::verify() {
1716   LogicalResult verifyDependVars =
1717       verifyDependVarList(*this, getDependKinds(), getDependVars());
1718   return failed(verifyDependVars) ? verifyDependVars
1719                                   : verifyMapClause(*this, getMapVars());
1720 }
1721 
1722 //===----------------------------------------------------------------------===//
1723 // TargetUpdateOp
1724 //===----------------------------------------------------------------------===//
1725 
1726 void TargetUpdateOp::build(OpBuilder &builder, OperationState &state,
1727                            const TargetEnterExitUpdateDataOperands &clauses) {
1728   MLIRContext *ctx = builder.getContext();
1729   TargetUpdateOp::build(builder, state, makeArrayAttr(ctx, clauses.dependKinds),
1730                         clauses.dependVars, clauses.device, clauses.ifExpr,
1731                         clauses.mapVars, clauses.nowait);
1732 }
1733 
1734 LogicalResult TargetUpdateOp::verify() {
1735   LogicalResult verifyDependVars =
1736       verifyDependVarList(*this, getDependKinds(), getDependVars());
1737   return failed(verifyDependVars) ? verifyDependVars
1738                                   : verifyMapClause(*this, getMapVars());
1739 }
1740 
1741 //===----------------------------------------------------------------------===//
1742 // TargetOp
1743 //===----------------------------------------------------------------------===//
1744 
1745 void TargetOp::build(OpBuilder &builder, OperationState &state,
1746                      const TargetOperands &clauses) {
1747   MLIRContext *ctx = builder.getContext();
1748   // TODO Store clauses in op: allocateVars, allocatorVars, inReductionVars,
1749   // inReductionByref, inReductionSyms.
1750   TargetOp::build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
1751                   clauses.bare, makeArrayAttr(ctx, clauses.dependKinds),
1752                   clauses.dependVars, clauses.device, clauses.hasDeviceAddrVars,
1753                   clauses.hostEvalVars, clauses.ifExpr,
1754                   /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
1755                   /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
1756                   clauses.mapVars, clauses.nowait, clauses.privateVars,
1757                   makeArrayAttr(ctx, clauses.privateSyms), clauses.threadLimit,
1758                   /*private_maps=*/nullptr);
1759 }
1760 
1761 LogicalResult TargetOp::verify() {
1762   LogicalResult verifyDependVars =
1763       verifyDependVarList(*this, getDependKinds(), getDependVars());
1764 
1765   if (failed(verifyDependVars))
1766     return verifyDependVars;
1767 
1768   LogicalResult verifyMapVars = verifyMapClause(*this, getMapVars());
1769 
1770   if (failed(verifyMapVars))
1771     return verifyMapVars;
1772 
1773   return verifyPrivateVarsMapping(*this);
1774 }
1775 
1776 LogicalResult TargetOp::verifyRegions() {
1777   auto teamsOps = getOps<TeamsOp>();
1778   if (std::distance(teamsOps.begin(), teamsOps.end()) > 1)
1779     return emitError("target containing multiple 'omp.teams' nested ops");
1780 
1781   // Check that host_eval values are only used in legal ways.
1782   llvm::omp::OMPTgtExecModeFlags execFlags = getKernelExecFlags();
1783   for (Value hostEvalArg :
1784        cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
1785     for (Operation *user : hostEvalArg.getUsers()) {
1786       if (auto teamsOp = dyn_cast<TeamsOp>(user)) {
1787         if (llvm::is_contained({teamsOp.getNumTeamsLower(),
1788                                 teamsOp.getNumTeamsUpper(),
1789                                 teamsOp.getThreadLimit()},
1790                                hostEvalArg))
1791           continue;
1792 
1793         return emitOpError() << "host_eval argument only legal as 'num_teams' "
1794                                 "and 'thread_limit' in 'omp.teams'";
1795       }
1796       if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
1797         if (execFlags == llvm::omp::OMP_TGT_EXEC_MODE_SPMD &&
1798             hostEvalArg == parallelOp.getNumThreads())
1799           continue;
1800 
1801         return emitOpError()
1802                << "host_eval argument only legal as 'num_threads' in "
1803                   "'omp.parallel' when representing target SPMD";
1804       }
1805       if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
1806         if (execFlags != llvm::omp::OMP_TGT_EXEC_MODE_GENERIC &&
1807             (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
1808              llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
1809              llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
1810           continue;
1811 
1812         return emitOpError() << "host_eval argument only legal as loop bounds "
1813                                 "and steps in 'omp.loop_nest' when "
1814                                 "representing target SPMD or Generic-SPMD";
1815       }
1816 
1817       return emitOpError() << "host_eval argument illegal use in '"
1818                            << user->getName() << "' operation";
1819     }
1820   }
1821   return success();
1822 }
1823 
1824 /// Only allow OpenMP terminators and non-OpenMP ops that have known memory
1825 /// effects, but don't include a memory write effect.
1826 static bool siblingAllowedInCapture(Operation *op) {
1827   if (!op)
1828     return false;
1829 
1830   bool isOmpDialect =
1831       op->getContext()->getLoadedDialect<omp::OpenMPDialect>() ==
1832       op->getDialect();
1833 
1834   if (isOmpDialect)
1835     return op->hasTrait<OpTrait::IsTerminator>();
1836 
1837   if (auto memOp = dyn_cast<MemoryEffectOpInterface>(op)) {
1838     SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4> effects;
1839     memOp.getEffects(effects);
1840     return !llvm::any_of(effects, [&](MemoryEffects::EffectInstance &effect) {
1841       return isa<MemoryEffects::Write>(effect.getEffect()) &&
1842              isa<SideEffects::AutomaticAllocationScopeResource>(
1843                  effect.getResource());
1844     });
1845   }
1846   return true;
1847 }
1848 
/// Returns the innermost OpenMP operation captured by this target region, or
/// null if no nested operation qualifies.
///
/// The walk descends only through OpenMP operations that own regions. Such an
/// operation becomes the current capture candidate only if all of these hold:
///   - no successor of its block can reach that block again;
///   - its block dominates every block of the parent region that is reachable
///     from the entry and has no successors;
///   - every other operation in its parent region passes
///     siblingAllowedInCapture.
/// The first operation failing these checks stops the whole walk, and so does
/// capturing an omp.loop_nest. The last candidate recorded is returned.
Operation *TargetOp::getInnermostCapturedOmpOp() {
  Dialect *ompDialect = (*this)->getDialect();
  Operation *capturedOp = nullptr;
  DominanceInfo domInfo;

  // Process in pre-order to check operations from outermost to innermost,
  // ensuring we only enter the region of an operation if it meets the criteria
  // for being captured. We stop the exploration of nested operations as soon as
  // we process a region holding no operations to be captured.
  walk<WalkOrder::PreOrder>([&](Operation *op) {
    // The target op itself is the walk root; always descend into it.
    if (op == *this)
      return WalkResult::advance();

    // Ignore operations of other dialects or omp operations with no regions,
    // because these will only be checked if they are siblings of an omp
    // operation that can potentially be captured.
    // Note that skip() also prevents visiting anything nested inside them.
    bool isOmpDialect = op->getDialect() == ompDialect;
    bool hasRegions = op->getNumRegions() > 0;
    if (!isOmpDialect || !hasRegions)
      return WalkResult::skip();

    // This operation cannot be captured if it can be executed more than once
    // (i.e. its block's successors can reach it) or if it's not guaranteed to
    // be executed before all exits of the region (i.e. it doesn't dominate all
    // blocks with no successors reachable from the entry block).
    Region *parentRegion = op->getParentRegion();
    Block *parentBlock = op->getBlock();

    for (Block *successor : parentBlock->getSuccessors())
      if (successor->isReachable(parentBlock))
        return WalkResult::interrupt();

    for (Block &block : *parentRegion)
      if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() &&
          !domInfo.dominates(parentBlock, &block))
        return WalkResult::interrupt();

    // Don't capture this op if it has a not-allowed sibling, and stop recursing
    // into nested operations.
    // Siblings are all operations in every block of the parent region.
    for (Operation &sibling : op->getParentRegion()->getOps())
      if (&sibling != op && !siblingAllowedInCapture(&sibling))
        return WalkResult::interrupt();

    // Don't continue capturing nested operations if we reach an omp.loop_nest.
    // Otherwise, process the contents of this operation.
    capturedOp = op;
    return llvm::isa<LoopNestOp>(op) ? WalkResult::interrupt()
                                     : WalkResult::advance();
  });

  return capturedOp;
}
1901 
/// Classifies the kernel execution mode of this target region from the shape
/// of the OpenMP constructs it captures (see getInnermostCapturedOmpOp).
///
/// The result is GENERIC unless the innermost captured op is an omp.loop_nest.
/// In that case its loop wrappers are inspected, ignoring a leading omp.simd:
///   - one wrapper, an omp.distribute whose parent is an omp.teams directly
///     nested in this op -> GENERIC_SPMD;
///   - two wrappers, omp.wsloop then omp.distribute, where the distribute's
///     parent is an omp.parallel, whose parent is an omp.teams directly
///     nested in this op -> SPMD;
///   - anything else -> GENERIC.
/// NOTE(review): this assumes gatherWrappers lists wrappers innermost-first,
/// as the `innermostWrapper` naming below relies on. Confirm against
/// LoopNestOp::gatherWrappers.
llvm::omp::OMPTgtExecModeFlags TargetOp::getKernelExecFlags() {
  using namespace llvm::omp;

  // Make sure this region is capturing a loop. Otherwise, it's a generic
  // kernel.
  Operation *capturedOp = getInnermostCapturedOmpOp();
  if (!isa_and_present<LoopNestOp>(capturedOp))
    return OMP_TGT_EXEC_MODE_GENERIC;

  SmallVector<LoopWrapperInterface> wrappers;
  cast<LoopNestOp>(capturedOp).gatherWrappers(wrappers);
  assert(!wrappers.empty());

  // Ignore optional SIMD leaf construct.
  auto *innermostWrapper = wrappers.begin();
  if (isa<SimdOp>(innermostWrapper))
    innermostWrapper = std::next(innermostWrapper);

  // Number of wrappers left once the optional SIMD has been dropped.
  long numWrappers = std::distance(innermostWrapper, wrappers.end());

  // Detect Generic-SPMD: target-teams-distribute[-simd].
  if (numWrappers == 1) {
    if (!isa<DistributeOp>(innermostWrapper))
      return OMP_TGT_EXEC_MODE_GENERIC;

    Operation *teamsOp = (*innermostWrapper)->getParentOp();
    if (!isa_and_present<TeamsOp>(teamsOp))
      return OMP_TGT_EXEC_MODE_GENERIC;

    // The teams op must be an immediate child of this target op.
    if (teamsOp->getParentOp() == *this)
      return OMP_TGT_EXEC_MODE_GENERIC_SPMD;
  }

  // Detect SPMD: target-teams-distribute-parallel-wsloop[-simd].
  if (numWrappers == 2) {
    if (!isa<WsloopOp>(innermostWrapper))
      return OMP_TGT_EXEC_MODE_GENERIC;

    innermostWrapper = std::next(innermostWrapper);
    if (!isa<DistributeOp>(innermostWrapper))
      return OMP_TGT_EXEC_MODE_GENERIC;

    Operation *parallelOp = (*innermostWrapper)->getParentOp();
    if (!isa_and_present<ParallelOp>(parallelOp))
      return OMP_TGT_EXEC_MODE_GENERIC;

    Operation *teamsOp = parallelOp->getParentOp();
    if (!isa_and_present<TeamsOp>(teamsOp))
      return OMP_TGT_EXEC_MODE_GENERIC;

    // The teams op must be an immediate child of this target op.
    if (teamsOp->getParentOp() == *this)
      return OMP_TGT_EXEC_MODE_SPMD;
  }

  return OMP_TGT_EXEC_MODE_GENERIC;
}
1958 
1959 //===----------------------------------------------------------------------===//
1960 // ParallelOp
1961 //===----------------------------------------------------------------------===//
1962 
1963 void ParallelOp::build(OpBuilder &builder, OperationState &state,
1964                        ArrayRef<NamedAttribute> attributes) {
1965   ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(),
1966                     /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
1967                     /*num_threads=*/nullptr, /*private_vars=*/ValueRange(),
1968                     /*private_syms=*/nullptr, /*proc_bind_kind=*/nullptr,
1969                     /*reduction_mod =*/nullptr, /*reduction_vars=*/ValueRange(),
1970                     /*reduction_byref=*/nullptr, /*reduction_syms=*/nullptr);
1971   state.addAttributes(attributes);
1972 }
1973 
1974 void ParallelOp::build(OpBuilder &builder, OperationState &state,
1975                        const ParallelOperands &clauses) {
1976   MLIRContext *ctx = builder.getContext();
1977   ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
1978                     clauses.ifExpr, clauses.numThreads, clauses.privateVars,
1979                     makeArrayAttr(ctx, clauses.privateSyms),
1980                     clauses.procBindKind, clauses.reductionMod,
1981                     clauses.reductionVars,
1982                     makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
1983                     makeArrayAttr(ctx, clauses.reductionSyms));
1984 }
1985 
1986 template <typename OpType>
1987 static LogicalResult verifyPrivateVarList(OpType &op) {
1988   auto privateVars = op.getPrivateVars();
1989   auto privateSyms = op.getPrivateSymsAttr();
1990 
1991   if (privateVars.empty() && (privateSyms == nullptr || privateSyms.empty()))
1992     return success();
1993 
1994   auto numPrivateVars = privateVars.size();
1995   auto numPrivateSyms = (privateSyms == nullptr) ? 0 : privateSyms.size();
1996 
1997   if (numPrivateVars != numPrivateSyms)
1998     return op.emitError() << "inconsistent number of private variables and "
1999                              "privatizer op symbols, private vars: "
2000                           << numPrivateVars
2001                           << " vs. privatizer op symbols: " << numPrivateSyms;
2002 
2003   for (auto privateVarInfo : llvm::zip_equal(privateVars, privateSyms)) {
2004     Type varType = std::get<0>(privateVarInfo).getType();
2005     SymbolRefAttr privateSym = cast<SymbolRefAttr>(std::get<1>(privateVarInfo));
2006     PrivateClauseOp privatizerOp =
2007         SymbolTable::lookupNearestSymbolFrom<PrivateClauseOp>(op, privateSym);
2008 
2009     if (privatizerOp == nullptr)
2010       return op.emitError() << "failed to lookup privatizer op with symbol: '"
2011                             << privateSym << "'";
2012 
2013     Type privatizerType = privatizerOp.getType();
2014 
2015     if (varType != privatizerType)
2016       return op.emitError()
2017              << "type mismatch between a "
2018              << (privatizerOp.getDataSharingType() ==
2019                          DataSharingClauseType::Private
2020                      ? "private"
2021                      : "firstprivate")
2022              << " variable and its privatizer op, var type: " << varType
2023              << " vs. privatizer op type: " << privatizerType;
2024   }
2025 
2026   return success();
2027 }
2028 
2029 LogicalResult ParallelOp::verify() {
2030   if (getAllocateVars().size() != getAllocatorVars().size())
2031     return emitError(
2032         "expected equal sizes for allocate and allocator variables");
2033 
2034   if (failed(verifyPrivateVarList(*this)))
2035     return failure();
2036 
2037   return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2038                                 getReductionByref());
2039 }
2040 
2041 LogicalResult ParallelOp::verifyRegions() {
2042   auto distributeChildOps = getOps<DistributeOp>();
2043   if (!distributeChildOps.empty()) {
2044     if (!isComposite())
2045       return emitError()
2046              << "'omp.composite' attribute missing from composite operation";
2047 
2048     auto *ompDialect = getContext()->getLoadedDialect<OpenMPDialect>();
2049     Operation &distributeOp = **distributeChildOps.begin();
2050     for (Operation &childOp : getOps()) {
2051       if (&childOp == &distributeOp || ompDialect != childOp.getDialect())
2052         continue;
2053 
2054       if (!childOp.hasTrait<OpTrait::IsTerminator>())
2055         return emitError() << "unexpected OpenMP operation inside of composite "
2056                               "'omp.parallel'";
2057     }
2058   } else if (isComposite()) {
2059     return emitError()
2060            << "'omp.composite' attribute present in non-composite operation";
2061   }
2062   return success();
2063 }
2064 
2065 //===----------------------------------------------------------------------===//
2066 // TeamsOp
2067 //===----------------------------------------------------------------------===//
2068 
2069 static bool opInGlobalImplicitParallelRegion(Operation *op) {
2070   while ((op = op->getParentOp()))
2071     if (isa<OpenMPDialect>(op->getDialect()))
2072       return false;
2073   return true;
2074 }
2075 
2076 void TeamsOp::build(OpBuilder &builder, OperationState &state,
2077                     const TeamsOperands &clauses) {
2078   MLIRContext *ctx = builder.getContext();
2079   // TODO Store clauses in op: privateVars, privateSyms.
2080   TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2081                  clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,
2082                  /*private_vars=*/{}, /*private_syms=*/nullptr,
2083                  clauses.reductionMod, clauses.reductionVars,
2084                  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2085                  makeArrayAttr(ctx, clauses.reductionSyms),
2086                  clauses.threadLimit);
2087 }
2088 
2089 LogicalResult TeamsOp::verify() {
2090   // Check parent region
2091   // TODO If nested inside of a target region, also check that it does not
2092   // contain any statements, declarations or directives other than this
2093   // omp.teams construct. The issue is how to support the initialization of
2094   // this operation's own arguments (allow SSA values across omp.target?).
2095   Operation *op = getOperation();
2096   if (!isa<TargetOp>(op->getParentOp()) &&
2097       !opInGlobalImplicitParallelRegion(op))
2098     return emitError("expected to be nested inside of omp.target or not nested "
2099                      "in any OpenMP dialect operations");
2100 
2101   // Check for num_teams clause restrictions
2102   if (auto numTeamsLowerBound = getNumTeamsLower()) {
2103     auto numTeamsUpperBound = getNumTeamsUpper();
2104     if (!numTeamsUpperBound)
2105       return emitError("expected num_teams upper bound to be defined if the "
2106                        "lower bound is defined");
2107     if (numTeamsLowerBound.getType() != numTeamsUpperBound.getType())
2108       return emitError(
2109           "expected num_teams upper bound and lower bound to be the same type");
2110   }
2111 
2112   // Check for allocate clause restrictions
2113   if (getAllocateVars().size() != getAllocatorVars().size())
2114     return emitError(
2115         "expected equal sizes for allocate and allocator variables");
2116 
2117   return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2118                                 getReductionByref());
2119 }
2120 
2121 //===----------------------------------------------------------------------===//
2122 // SectionOp
2123 //===----------------------------------------------------------------------===//
2124 
2125 unsigned SectionOp::numPrivateBlockArgs() {
2126   return getParentOp().numPrivateBlockArgs();
2127 }
2128 
2129 unsigned SectionOp::numReductionBlockArgs() {
2130   return getParentOp().numReductionBlockArgs();
2131 }
2132 
2133 //===----------------------------------------------------------------------===//
2134 // SectionsOp
2135 //===----------------------------------------------------------------------===//
2136 
2137 void SectionsOp::build(OpBuilder &builder, OperationState &state,
2138                        const SectionsOperands &clauses) {
2139   MLIRContext *ctx = builder.getContext();
2140   // TODO Store clauses in op: privateVars, privateSyms.
2141   SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2142                     clauses.nowait, /*private_vars=*/{},
2143                     /*private_syms=*/nullptr, clauses.reductionMod,
2144                     clauses.reductionVars,
2145                     makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2146                     makeArrayAttr(ctx, clauses.reductionSyms));
2147 }
2148 
2149 LogicalResult SectionsOp::verify() {
2150   if (getAllocateVars().size() != getAllocatorVars().size())
2151     return emitError(
2152         "expected equal sizes for allocate and allocator variables");
2153 
2154   return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2155                                 getReductionByref());
2156 }
2157 
2158 LogicalResult SectionsOp::verifyRegions() {
2159   for (auto &inst : *getRegion().begin()) {
2160     if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
2161       return emitOpError()
2162              << "expected omp.section op or terminator op inside region";
2163     }
2164   }
2165 
2166   return success();
2167 }
2168 
2169 //===----------------------------------------------------------------------===//
2170 // SingleOp
2171 //===----------------------------------------------------------------------===//
2172 
2173 void SingleOp::build(OpBuilder &builder, OperationState &state,
2174                      const SingleOperands &clauses) {
2175   MLIRContext *ctx = builder.getContext();
2176   // TODO Store clauses in op: privateVars, privateSyms.
2177   SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2178                   clauses.copyprivateVars,
2179                   makeArrayAttr(ctx, clauses.copyprivateSyms), clauses.nowait,
2180                   /*private_vars=*/{}, /*private_syms=*/nullptr);
2181 }
2182 
2183 LogicalResult SingleOp::verify() {
2184   // Check for allocate clause restrictions
2185   if (getAllocateVars().size() != getAllocatorVars().size())
2186     return emitError(
2187         "expected equal sizes for allocate and allocator variables");
2188 
2189   return verifyCopyprivateVarList(*this, getCopyprivateVars(),
2190                                   getCopyprivateSyms());
2191 }
2192 
2193 //===----------------------------------------------------------------------===//
2194 // WorkshareOp
2195 //===----------------------------------------------------------------------===//
2196 
2197 void WorkshareOp::build(OpBuilder &builder, OperationState &state,
2198                         const WorkshareOperands &clauses) {
2199   WorkshareOp::build(builder, state, clauses.nowait);
2200 }
2201 
2202 //===----------------------------------------------------------------------===//
2203 // WorkshareLoopWrapperOp
2204 //===----------------------------------------------------------------------===//
2205 
2206 LogicalResult WorkshareLoopWrapperOp::verify() {
2207   if (!(*this)->getParentOfType<WorkshareOp>())
2208     return emitError() << "must be nested in an omp.workshare";
2209   if (getNestedWrapper())
2210     return emitError() << "cannot be composite";
2211   return success();
2212 }
2213 
2214 //===----------------------------------------------------------------------===//
2215 // LoopWrapperInterface
2216 //===----------------------------------------------------------------------===//
2217 
2218 LogicalResult LoopWrapperInterface::verifyImpl() {
2219   Operation *op = this->getOperation();
2220   if (!op->hasTrait<OpTrait::NoTerminator>() ||
2221       !op->hasTrait<OpTrait::SingleBlock>())
2222     return emitOpError() << "loop wrapper must also have the `NoTerminator` "
2223                             "and `SingleBlock` traits";
2224 
2225   if (op->getNumRegions() != 1)
2226     return emitOpError() << "loop wrapper does not contain exactly one region";
2227 
2228   Region &region = op->getRegion(0);
2229   if (range_size(region.getOps()) != 1)
2230     return emitOpError()
2231            << "loop wrapper does not contain exactly one nested op";
2232 
2233   Operation &firstOp = *region.op_begin();
2234   if (!isa<LoopNestOp, LoopWrapperInterface>(firstOp))
2235     return emitOpError() << "op nested in loop wrapper is not another loop "
2236                             "wrapper or `omp.loop_nest`";
2237 
2238   return success();
2239 }
2240 
2241 //===----------------------------------------------------------------------===//
2242 // LoopOp
2243 //===----------------------------------------------------------------------===//
2244 
2245 void LoopOp::build(OpBuilder &builder, OperationState &state,
2246                    const LoopOperands &clauses) {
2247   MLIRContext *ctx = builder.getContext();
2248 
2249   LoopOp::build(builder, state, clauses.bindKind, clauses.privateVars,
2250                 makeArrayAttr(ctx, clauses.privateSyms), clauses.order,
2251                 clauses.orderMod, clauses.reductionMod, clauses.reductionVars,
2252                 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2253                 makeArrayAttr(ctx, clauses.reductionSyms));
2254 }
2255 
2256 LogicalResult LoopOp::verify() {
2257   return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2258                                 getReductionByref());
2259 }
2260 
2261 LogicalResult LoopOp::verifyRegions() {
2262   if (llvm::isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
2263       getNestedWrapper())
2264     return emitError() << "`omp.loop` expected to be a standalone loop wrapper";
2265 
2266   return success();
2267 }
2268 
2269 //===----------------------------------------------------------------------===//
2270 // WsloopOp
2271 //===----------------------------------------------------------------------===//
2272 
2273 void WsloopOp::build(OpBuilder &builder, OperationState &state,
2274                      ArrayRef<NamedAttribute> attributes) {
2275   build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
2276         /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
2277         /*nowait=*/false, /*order=*/nullptr, /*order_mod=*/nullptr,
2278         /*ordered=*/nullptr, /*private_vars=*/{}, /*private_syms=*/nullptr,
2279         /*reduction_mod=*/nullptr, /*reduction_vars=*/ValueRange(),
2280         /*reduction_byref=*/nullptr,
2281         /*reduction_syms=*/nullptr, /*schedule_kind=*/nullptr,
2282         /*schedule_chunk=*/nullptr, /*schedule_mod=*/nullptr,
2283         /*schedule_simd=*/false);
2284   state.addAttributes(attributes);
2285 }
2286 
2287 void WsloopOp::build(OpBuilder &builder, OperationState &state,
2288                      const WsloopOperands &clauses) {
2289   MLIRContext *ctx = builder.getContext();
2290   // TODO: Store clauses in op: allocateVars, allocatorVars, privateVars,
2291   // privateSyms.
2292   WsloopOp::build(builder, state,
2293                   /*allocate_vars=*/{}, /*allocator_vars=*/{},
2294                   clauses.linearVars, clauses.linearStepVars, clauses.nowait,
2295                   clauses.order, clauses.orderMod, clauses.ordered,
2296                   clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
2297                   clauses.reductionMod, clauses.reductionVars,
2298                   makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2299                   makeArrayAttr(ctx, clauses.reductionSyms),
2300                   clauses.scheduleKind, clauses.scheduleChunk,
2301                   clauses.scheduleMod, clauses.scheduleSimd);
2302 }
2303 
2304 LogicalResult WsloopOp::verify() {
2305   return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2306                                 getReductionByref());
2307 }
2308 
2309 LogicalResult WsloopOp::verifyRegions() {
2310   bool isCompositeChildLeaf =
2311       llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
2312 
2313   if (LoopWrapperInterface nested = getNestedWrapper()) {
2314     if (!isComposite())
2315       return emitError()
2316              << "'omp.composite' attribute missing from composite wrapper";
2317 
2318     // Check for the allowed leaf constructs that may appear in a composite
2319     // construct directly after DO/FOR.
2320     if (!isa<SimdOp>(nested))
2321       return emitError() << "only supported nested wrapper is 'omp.simd'";
2322 
2323   } else if (isComposite() && !isCompositeChildLeaf) {
2324     return emitError()
2325            << "'omp.composite' attribute present in non-composite wrapper";
2326   } else if (!isComposite() && isCompositeChildLeaf) {
2327     return emitError()
2328            << "'omp.composite' attribute missing from composite wrapper";
2329   }
2330 
2331   return success();
2332 }
2333 
2334 //===----------------------------------------------------------------------===//
2335 // Simd construct [2.9.3.1]
2336 //===----------------------------------------------------------------------===//
2337 
2338 void SimdOp::build(OpBuilder &builder, OperationState &state,
2339                    const SimdOperands &clauses) {
2340   MLIRContext *ctx = builder.getContext();
2341   // TODO Store clauses in op: linearVars, linearStepVars, privateVars,
2342   // privateSyms.
2343   SimdOp::build(builder, state, clauses.alignedVars,
2344                 makeArrayAttr(ctx, clauses.alignments), clauses.ifExpr,
2345                 /*linear_vars=*/{}, /*linear_step_vars=*/{},
2346                 clauses.nontemporalVars, clauses.order, clauses.orderMod,
2347                 clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
2348                 clauses.reductionMod, clauses.reductionVars,
2349                 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2350                 makeArrayAttr(ctx, clauses.reductionSyms), clauses.safelen,
2351                 clauses.simdlen);
2352 }
2353 
2354 LogicalResult SimdOp::verify() {
2355   if (getSimdlen().has_value() && getSafelen().has_value() &&
2356       getSimdlen().value() > getSafelen().value())
2357     return emitOpError()
2358            << "simdlen clause and safelen clause are both present, but the "
2359               "simdlen value is not less than or equal to safelen value";
2360 
2361   if (verifyAlignedClause(*this, getAlignments(), getAlignedVars()).failed())
2362     return failure();
2363 
2364   if (verifyNontemporalClause(*this, getNontemporalVars()).failed())
2365     return failure();
2366 
2367   bool isCompositeChildLeaf =
2368       llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
2369 
2370   if (!isComposite() && isCompositeChildLeaf)
2371     return emitError()
2372            << "'omp.composite' attribute missing from composite wrapper";
2373 
2374   if (isComposite() && !isCompositeChildLeaf)
2375     return emitError()
2376            << "'omp.composite' attribute present in non-composite wrapper";
2377 
2378   return success();
2379 }
2380 
2381 LogicalResult SimdOp::verifyRegions() {
2382   if (getNestedWrapper())
2383     return emitOpError() << "must wrap an 'omp.loop_nest' directly";
2384 
2385   return success();
2386 }
2387 
2388 //===----------------------------------------------------------------------===//
2389 // Distribute construct [2.9.4.1]
2390 //===----------------------------------------------------------------------===//
2391 
2392 void DistributeOp::build(OpBuilder &builder, OperationState &state,
2393                          const DistributeOperands &clauses) {
2394   DistributeOp::build(builder, state, clauses.allocateVars,
2395                       clauses.allocatorVars, clauses.distScheduleStatic,
2396                       clauses.distScheduleChunkSize, clauses.order,
2397                       clauses.orderMod, clauses.privateVars,
2398                       makeArrayAttr(builder.getContext(), clauses.privateSyms));
2399 }
2400 
2401 LogicalResult DistributeOp::verify() {
2402   if (this->getDistScheduleChunkSize() && !this->getDistScheduleStatic())
2403     return emitOpError() << "chunk size set without "
2404                             "dist_schedule_static being present";
2405 
2406   if (getAllocateVars().size() != getAllocatorVars().size())
2407     return emitError(
2408         "expected equal sizes for allocate and allocator variables");
2409 
2410   return success();
2411 }
2412 
2413 LogicalResult DistributeOp::verifyRegions() {
2414   if (LoopWrapperInterface nested = getNestedWrapper()) {
2415     if (!isComposite())
2416       return emitError()
2417              << "'omp.composite' attribute missing from composite wrapper";
2418     // Check for the allowed leaf constructs that may appear in a composite
2419     // construct directly after DISTRIBUTE.
2420     if (isa<WsloopOp>(nested)) {
2421       if (!llvm::dyn_cast_if_present<ParallelOp>((*this)->getParentOp()))
2422         return emitError() << "an 'omp.wsloop' nested wrapper is only allowed "
2423                               "when 'omp.parallel' is the direct parent";
2424     } else if (!isa<SimdOp>(nested))
2425       return emitError() << "only supported nested wrappers are 'omp.simd' and "
2426                             "'omp.wsloop'";
2427   } else if (isComposite()) {
2428     return emitError()
2429            << "'omp.composite' attribute present in non-composite wrapper";
2430   }
2431 
2432   return success();
2433 }
2434 
2435 //===----------------------------------------------------------------------===//
2436 // DeclareReductionOp
2437 //===----------------------------------------------------------------------===//
2438 
2439 LogicalResult DeclareReductionOp::verifyRegions() {
2440   if (!getAllocRegion().empty()) {
2441     for (YieldOp yieldOp : getAllocRegion().getOps<YieldOp>()) {
2442       if (yieldOp.getResults().size() != 1 ||
2443           yieldOp.getResults().getTypes()[0] != getType())
2444         return emitOpError() << "expects alloc region to yield a value "
2445                                 "of the reduction type";
2446     }
2447   }
2448 
2449   if (getInitializerRegion().empty())
2450     return emitOpError() << "expects non-empty initializer region";
2451   Block &initializerEntryBlock = getInitializerRegion().front();
2452 
2453   if (initializerEntryBlock.getNumArguments() == 1) {
2454     if (!getAllocRegion().empty())
2455       return emitOpError() << "expects two arguments to the initializer region "
2456                               "when an allocation region is used";
2457   } else if (initializerEntryBlock.getNumArguments() == 2) {
2458     if (getAllocRegion().empty())
2459       return emitOpError() << "expects one argument to the initializer region "
2460                               "when no allocation region is used";
2461   } else {
2462     return emitOpError()
2463            << "expects one or two arguments to the initializer region";
2464   }
2465 
2466   for (mlir::Value arg : initializerEntryBlock.getArguments())
2467     if (arg.getType() != getType())
2468       return emitOpError() << "expects initializer region argument to match "
2469                               "the reduction type";
2470 
2471   for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
2472     if (yieldOp.getResults().size() != 1 ||
2473         yieldOp.getResults().getTypes()[0] != getType())
2474       return emitOpError() << "expects initializer region to yield a value "
2475                               "of the reduction type";
2476   }
2477 
2478   if (getReductionRegion().empty())
2479     return emitOpError() << "expects non-empty reduction region";
2480   Block &reductionEntryBlock = getReductionRegion().front();
2481   if (reductionEntryBlock.getNumArguments() != 2 ||
2482       reductionEntryBlock.getArgumentTypes()[0] !=
2483           reductionEntryBlock.getArgumentTypes()[1] ||
2484       reductionEntryBlock.getArgumentTypes()[0] != getType())
2485     return emitOpError() << "expects reduction region with two arguments of "
2486                             "the reduction type";
2487   for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
2488     if (yieldOp.getResults().size() != 1 ||
2489         yieldOp.getResults().getTypes()[0] != getType())
2490       return emitOpError() << "expects reduction region to yield a value "
2491                               "of the reduction type";
2492   }
2493 
2494   if (!getAtomicReductionRegion().empty()) {
2495     Block &atomicReductionEntryBlock = getAtomicReductionRegion().front();
2496     if (atomicReductionEntryBlock.getNumArguments() != 2 ||
2497         atomicReductionEntryBlock.getArgumentTypes()[0] !=
2498             atomicReductionEntryBlock.getArgumentTypes()[1])
2499       return emitOpError() << "expects atomic reduction region with two "
2500                               "arguments of the same type";
2501     auto ptrType = llvm::dyn_cast<PointerLikeType>(
2502         atomicReductionEntryBlock.getArgumentTypes()[0]);
2503     if (!ptrType ||
2504         (ptrType.getElementType() && ptrType.getElementType() != getType()))
2505       return emitOpError() << "expects atomic reduction region arguments to "
2506                               "be accumulators containing the reduction type";
2507   }
2508 
2509   if (getCleanupRegion().empty())
2510     return success();
2511   Block &cleanupEntryBlock = getCleanupRegion().front();
2512   if (cleanupEntryBlock.getNumArguments() != 1 ||
2513       cleanupEntryBlock.getArgument(0).getType() != getType())
2514     return emitOpError() << "expects cleanup region with one argument "
2515                             "of the reduction type";
2516 
2517   return success();
2518 }
2519 
2520 //===----------------------------------------------------------------------===//
2521 // TaskOp
2522 //===----------------------------------------------------------------------===//
2523 
2524 void TaskOp::build(OpBuilder &builder, OperationState &state,
2525                    const TaskOperands &clauses) {
2526   MLIRContext *ctx = builder.getContext();
2527   TaskOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2528                 makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
2529                 clauses.final, clauses.ifExpr, clauses.inReductionVars,
2530                 makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
2531                 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
2532                 clauses.priority, /*private_vars=*/clauses.privateVars,
2533                 /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
2534                 clauses.untied, clauses.eventHandle);
2535 }
2536 
2537 LogicalResult TaskOp::verify() {
2538   LogicalResult verifyDependVars =
2539       verifyDependVarList(*this, getDependKinds(), getDependVars());
2540   return failed(verifyDependVars)
2541              ? verifyDependVars
2542              : verifyReductionVarList(*this, getInReductionSyms(),
2543                                       getInReductionVars(),
2544                                       getInReductionByref());
2545 }
2546 
2547 //===----------------------------------------------------------------------===//
2548 // TaskgroupOp
2549 //===----------------------------------------------------------------------===//
2550 
2551 void TaskgroupOp::build(OpBuilder &builder, OperationState &state,
2552                         const TaskgroupOperands &clauses) {
2553   MLIRContext *ctx = builder.getContext();
2554   TaskgroupOp::build(builder, state, clauses.allocateVars,
2555                      clauses.allocatorVars, clauses.taskReductionVars,
2556                      makeDenseBoolArrayAttr(ctx, clauses.taskReductionByref),
2557                      makeArrayAttr(ctx, clauses.taskReductionSyms));
2558 }
2559 
2560 LogicalResult TaskgroupOp::verify() {
2561   return verifyReductionVarList(*this, getTaskReductionSyms(),
2562                                 getTaskReductionVars(),
2563                                 getTaskReductionByref());
2564 }
2565 
2566 //===----------------------------------------------------------------------===//
2567 // TaskloopOp
2568 //===----------------------------------------------------------------------===//
2569 
2570 void TaskloopOp::build(OpBuilder &builder, OperationState &state,
2571                        const TaskloopOperands &clauses) {
2572   MLIRContext *ctx = builder.getContext();
2573   // TODO Store clauses in op: privateVars, privateSyms.
2574   TaskloopOp::build(
2575       builder, state, clauses.allocateVars, clauses.allocatorVars,
2576       clauses.final, clauses.grainsize, clauses.ifExpr, clauses.inReductionVars,
2577       makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
2578       makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
2579       clauses.nogroup, clauses.numTasks, clauses.priority, /*private_vars=*/{},
2580       /*private_syms=*/nullptr, clauses.reductionMod, clauses.reductionVars,
2581       makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2582       makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied);
2583 }
2584 
2585 SmallVector<Value> TaskloopOp::getAllReductionVars() {
2586   SmallVector<Value> allReductionNvars(getInReductionVars().begin(),
2587                                        getInReductionVars().end());
2588   allReductionNvars.insert(allReductionNvars.end(), getReductionVars().begin(),
2589                            getReductionVars().end());
2590   return allReductionNvars;
2591 }
2592 
2593 LogicalResult TaskloopOp::verify() {
2594   if (getAllocateVars().size() != getAllocatorVars().size())
2595     return emitError(
2596         "expected equal sizes for allocate and allocator variables");
2597   if (failed(verifyReductionVarList(*this, getReductionSyms(),
2598                                     getReductionVars(), getReductionByref())) ||
2599       failed(verifyReductionVarList(*this, getInReductionSyms(),
2600                                     getInReductionVars(),
2601                                     getInReductionByref())))
2602     return failure();
2603 
2604   if (!getReductionVars().empty() && getNogroup())
2605     return emitError("if a reduction clause is present on the taskloop "
2606                      "directive, the nogroup clause must not be specified");
2607   for (auto var : getReductionVars()) {
2608     if (llvm::is_contained(getInReductionVars(), var))
2609       return emitError("the same list item cannot appear in both a reduction "
2610                        "and an in_reduction clause");
2611   }
2612 
2613   if (getGrainsize() && getNumTasks()) {
2614     return emitError(
2615         "the grainsize clause and num_tasks clause are mutually exclusive and "
2616         "may not appear on the same taskloop directive");
2617   }
2618 
2619   return success();
2620 }
2621 
2622 LogicalResult TaskloopOp::verifyRegions() {
2623   if (LoopWrapperInterface nested = getNestedWrapper()) {
2624     if (!isComposite())
2625       return emitError()
2626              << "'omp.composite' attribute missing from composite wrapper";
2627 
2628     // Check for the allowed leaf constructs that may appear in a composite
2629     // construct directly after TASKLOOP.
2630     if (!isa<SimdOp>(nested))
2631       return emitError() << "only supported nested wrapper is 'omp.simd'";
2632   } else if (isComposite()) {
2633     return emitError()
2634            << "'omp.composite' attribute present in non-composite wrapper";
2635   }
2636 
2637   return success();
2638 }
2639 
2640 //===----------------------------------------------------------------------===//
2641 // LoopNestOp
2642 //===----------------------------------------------------------------------===//
2643 
2644 ParseResult LoopNestOp::parse(OpAsmParser &parser, OperationState &result) {
2645   // Parse an opening `(` followed by induction variables followed by `)`
2646   SmallVector<OpAsmParser::Argument> ivs;
2647   SmallVector<OpAsmParser::UnresolvedOperand> lbs, ubs;
2648   Type loopVarType;
2649   if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren) ||
2650       parser.parseColonType(loopVarType) ||
2651       // Parse loop bounds.
2652       parser.parseEqual() ||
2653       parser.parseOperandList(lbs, ivs.size(), OpAsmParser::Delimiter::Paren) ||
2654       parser.parseKeyword("to") ||
2655       parser.parseOperandList(ubs, ivs.size(), OpAsmParser::Delimiter::Paren))
2656     return failure();
2657 
2658   for (auto &iv : ivs)
2659     iv.type = loopVarType;
2660 
2661   // Parse "inclusive" flag.
2662   if (succeeded(parser.parseOptionalKeyword("inclusive")))
2663     result.addAttribute("loop_inclusive",
2664                         UnitAttr::get(parser.getBuilder().getContext()));
2665 
2666   // Parse step values.
2667   SmallVector<OpAsmParser::UnresolvedOperand> steps;
2668   if (parser.parseKeyword("step") ||
2669       parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren))
2670     return failure();
2671 
2672   // Parse the body.
2673   Region *region = result.addRegion();
2674   if (parser.parseRegion(*region, ivs))
2675     return failure();
2676 
2677   // Resolve operands.
2678   if (parser.resolveOperands(lbs, loopVarType, result.operands) ||
2679       parser.resolveOperands(ubs, loopVarType, result.operands) ||
2680       parser.resolveOperands(steps, loopVarType, result.operands))
2681     return failure();
2682 
2683   // Parse the optional attribute list.
2684   return parser.parseOptionalAttrDict(result.attributes);
2685 }
2686 
2687 void LoopNestOp::print(OpAsmPrinter &p) {
2688   Region &region = getRegion();
2689   auto args = region.getArguments();
2690   p << " (" << args << ") : " << args[0].getType() << " = ("
2691     << getLoopLowerBounds() << ") to (" << getLoopUpperBounds() << ") ";
2692   if (getLoopInclusive())
2693     p << "inclusive ";
2694   p << "step (" << getLoopSteps() << ") ";
2695   p.printRegion(region, /*printEntryBlockArgs=*/false);
2696 }
2697 
2698 void LoopNestOp::build(OpBuilder &builder, OperationState &state,
2699                        const LoopNestOperands &clauses) {
2700   LoopNestOp::build(builder, state, clauses.loopLowerBounds,
2701                     clauses.loopUpperBounds, clauses.loopSteps,
2702                     clauses.loopInclusive);
2703 }
2704 
2705 LogicalResult LoopNestOp::verify() {
2706   if (getLoopLowerBounds().empty())
2707     return emitOpError() << "must represent at least one loop";
2708 
2709   if (getLoopLowerBounds().size() != getIVs().size())
2710     return emitOpError() << "number of range arguments and IVs do not match";
2711 
2712   for (auto [lb, iv] : llvm::zip_equal(getLoopLowerBounds(), getIVs())) {
2713     if (lb.getType() != iv.getType())
2714       return emitOpError()
2715              << "range argument type does not match corresponding IV type";
2716   }
2717 
2718   if (!llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp()))
2719     return emitOpError() << "expects parent op to be a loop wrapper";
2720 
2721   return success();
2722 }
2723 
2724 void LoopNestOp::gatherWrappers(
2725     SmallVectorImpl<LoopWrapperInterface> &wrappers) {
2726   Operation *parent = (*this)->getParentOp();
2727   while (auto wrapper =
2728              llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
2729     wrappers.push_back(wrapper);
2730     parent = parent->getParentOp();
2731   }
2732 }
2733 
2734 //===----------------------------------------------------------------------===//
2735 // Critical construct (2.17.1)
2736 //===----------------------------------------------------------------------===//
2737 
2738 void CriticalDeclareOp::build(OpBuilder &builder, OperationState &state,
2739                               const CriticalDeclareOperands &clauses) {
2740   CriticalDeclareOp::build(builder, state, clauses.symName, clauses.hint);
2741 }
2742 
2743 LogicalResult CriticalDeclareOp::verify() {
2744   return verifySynchronizationHint(*this, getHint());
2745 }
2746 
2747 LogicalResult CriticalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2748   if (getNameAttr()) {
2749     SymbolRefAttr symbolRef = getNameAttr();
2750     auto decl = symbolTable.lookupNearestSymbolFrom<CriticalDeclareOp>(
2751         *this, symbolRef);
2752     if (!decl) {
2753       return emitOpError() << "expected symbol reference " << symbolRef
2754                            << " to point to a critical declaration";
2755     }
2756   }
2757 
2758   return success();
2759 }
2760 
2761 //===----------------------------------------------------------------------===//
2762 // Ordered construct
2763 //===----------------------------------------------------------------------===//
2764 
2765 static LogicalResult verifyOrderedParent(Operation &op) {
2766   bool hasRegion = op.getNumRegions() > 0;
2767   auto loopOp = op.getParentOfType<LoopNestOp>();
2768   if (!loopOp) {
2769     if (hasRegion)
2770       return success();
2771 
2772     // TODO: Consider if this needs to be the case only for the standalone
2773     // variant of the ordered construct.
2774     return op.emitOpError() << "must be nested inside of a loop";
2775   }
2776 
2777   Operation *wrapper = loopOp->getParentOp();
2778   if (auto wsloopOp = dyn_cast<WsloopOp>(wrapper)) {
2779     IntegerAttr orderedAttr = wsloopOp.getOrderedAttr();
2780     if (!orderedAttr)
2781       return op.emitOpError() << "the enclosing worksharing-loop region must "
2782                                  "have an ordered clause";
2783 
2784     if (hasRegion && orderedAttr.getInt() != 0)
2785       return op.emitOpError() << "the enclosing loop's ordered clause must not "
2786                                  "have a parameter present";
2787 
2788     if (!hasRegion && orderedAttr.getInt() == 0)
2789       return op.emitOpError() << "the enclosing loop's ordered clause must "
2790                                  "have a parameter present";
2791   } else if (!isa<SimdOp>(wrapper)) {
2792     return op.emitOpError() << "must be nested inside of a worksharing, simd "
2793                                "or worksharing simd loop";
2794   }
2795   return success();
2796 }
2797 
2798 void OrderedOp::build(OpBuilder &builder, OperationState &state,
2799                       const OrderedOperands &clauses) {
2800   OrderedOp::build(builder, state, clauses.doacrossDependType,
2801                    clauses.doacrossNumLoops, clauses.doacrossDependVars);
2802 }
2803 
2804 LogicalResult OrderedOp::verify() {
2805   if (failed(verifyOrderedParent(**this)))
2806     return failure();
2807 
2808   auto wrapper = (*this)->getParentOfType<WsloopOp>();
2809   if (!wrapper || *wrapper.getOrdered() != *getDoacrossNumLoops())
2810     return emitOpError() << "number of variables in depend clause does not "
2811                          << "match number of iteration variables in the "
2812                          << "doacross loop";
2813 
2814   return success();
2815 }
2816 
2817 void OrderedRegionOp::build(OpBuilder &builder, OperationState &state,
2818                             const OrderedRegionOperands &clauses) {
2819   OrderedRegionOp::build(builder, state, clauses.parLevelSimd);
2820 }
2821 
2822 LogicalResult OrderedRegionOp::verify() { return verifyOrderedParent(**this); }
2823 
2824 //===----------------------------------------------------------------------===//
2825 // TaskwaitOp
2826 //===----------------------------------------------------------------------===//
2827 
2828 void TaskwaitOp::build(OpBuilder &builder, OperationState &state,
2829                        const TaskwaitOperands &clauses) {
2830   // TODO Store clauses in op: dependKinds, dependVars, nowait.
2831   TaskwaitOp::build(builder, state, /*depend_kinds=*/nullptr,
2832                     /*depend_vars=*/{}, /*nowait=*/nullptr);
2833 }
2834 
2835 //===----------------------------------------------------------------------===//
2836 // Verifier for AtomicReadOp
2837 //===----------------------------------------------------------------------===//
2838 
2839 LogicalResult AtomicReadOp::verify() {
2840   if (verifyCommon().failed())
2841     return mlir::failure();
2842 
2843   if (auto mo = getMemoryOrder()) {
2844     if (*mo == ClauseMemoryOrderKind::Acq_rel ||
2845         *mo == ClauseMemoryOrderKind::Release) {
2846       return emitError(
2847           "memory-order must not be acq_rel or release for atomic reads");
2848     }
2849   }
2850   return verifySynchronizationHint(*this, getHint());
2851 }
2852 
2853 //===----------------------------------------------------------------------===//
2854 // Verifier for AtomicWriteOp
2855 //===----------------------------------------------------------------------===//
2856 
2857 LogicalResult AtomicWriteOp::verify() {
2858   if (verifyCommon().failed())
2859     return mlir::failure();
2860 
2861   if (auto mo = getMemoryOrder()) {
2862     if (*mo == ClauseMemoryOrderKind::Acq_rel ||
2863         *mo == ClauseMemoryOrderKind::Acquire) {
2864       return emitError(
2865           "memory-order must not be acq_rel or acquire for atomic writes");
2866     }
2867   }
2868   return verifySynchronizationHint(*this, getHint());
2869 }
2870 
2871 //===----------------------------------------------------------------------===//
2872 // Verifier for AtomicUpdateOp
2873 //===----------------------------------------------------------------------===//
2874 
2875 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
2876                                            PatternRewriter &rewriter) {
2877   if (op.isNoOp()) {
2878     rewriter.eraseOp(op);
2879     return success();
2880   }
2881   if (Value writeVal = op.getWriteOpVal()) {
2882     rewriter.replaceOpWithNewOp<AtomicWriteOp>(
2883         op, op.getX(), writeVal, op.getHintAttr(), op.getMemoryOrderAttr());
2884     return success();
2885   }
2886   return failure();
2887 }
2888 
2889 LogicalResult AtomicUpdateOp::verify() {
2890   if (verifyCommon().failed())
2891     return mlir::failure();
2892 
2893   if (auto mo = getMemoryOrder()) {
2894     if (*mo == ClauseMemoryOrderKind::Acq_rel ||
2895         *mo == ClauseMemoryOrderKind::Acquire) {
2896       return emitError(
2897           "memory-order must not be acq_rel or acquire for atomic updates");
2898     }
2899   }
2900 
2901   return verifySynchronizationHint(*this, getHint());
2902 }
2903 
2904 LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
2905 
2906 //===----------------------------------------------------------------------===//
2907 // Verifier for AtomicCaptureOp
2908 //===----------------------------------------------------------------------===//
2909 
2910 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
2911   if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
2912     return op;
2913   return dyn_cast<AtomicReadOp>(getSecondOp());
2914 }
2915 
2916 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
2917   if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
2918     return op;
2919   return dyn_cast<AtomicWriteOp>(getSecondOp());
2920 }
2921 
2922 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
2923   if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
2924     return op;
2925   return dyn_cast<AtomicUpdateOp>(getSecondOp());
2926 }
2927 
2928 LogicalResult AtomicCaptureOp::verify() {
2929   return verifySynchronizationHint(*this, getHint());
2930 }
2931 
2932 LogicalResult AtomicCaptureOp::verifyRegions() {
2933   if (verifyRegionsCommon().failed())
2934     return mlir::failure();
2935 
2936   if (getFirstOp()->getAttr("hint") || getSecondOp()->getAttr("hint"))
2937     return emitOpError(
2938         "operations inside capture region must not have hint clause");
2939 
2940   if (getFirstOp()->getAttr("memory_order") ||
2941       getSecondOp()->getAttr("memory_order"))
2942     return emitOpError(
2943         "operations inside capture region must not have memory_order clause");
2944   return success();
2945 }
2946 
2947 //===----------------------------------------------------------------------===//
2948 // CancelOp
2949 //===----------------------------------------------------------------------===//
2950 
2951 void CancelOp::build(OpBuilder &builder, OperationState &state,
2952                      const CancelOperands &clauses) {
2953   CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifExpr);
2954 }
2955 
2956 LogicalResult CancelOp::verify() {
2957   ClauseCancellationConstructType cct = getCancelDirective();
2958   Operation *parentOp = (*this)->getParentOp();
2959 
2960   if (!parentOp) {
2961     return emitOpError() << "must be used within a region supporting "
2962                             "cancel directive";
2963   }
2964 
2965   if ((cct == ClauseCancellationConstructType::Parallel) &&
2966       !isa<ParallelOp>(parentOp)) {
2967     return emitOpError() << "cancel parallel must appear "
2968                          << "inside a parallel region";
2969   }
2970   if (cct == ClauseCancellationConstructType::Loop) {
2971     auto loopOp = dyn_cast<LoopNestOp>(parentOp);
2972     auto wsloopOp = llvm::dyn_cast_if_present<WsloopOp>(
2973         loopOp ? loopOp->getParentOp() : nullptr);
2974 
2975     if (!wsloopOp) {
2976       return emitOpError()
2977              << "cancel loop must appear inside a worksharing-loop region";
2978     }
2979     if (wsloopOp.getNowaitAttr()) {
2980       return emitError() << "A worksharing construct that is canceled "
2981                          << "must not have a nowait clause";
2982     }
2983     if (wsloopOp.getOrderedAttr()) {
2984       return emitError() << "A worksharing construct that is canceled "
2985                          << "must not have an ordered clause";
2986     }
2987 
2988   } else if (cct == ClauseCancellationConstructType::Sections) {
2989     if (!(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) {
2990       return emitOpError() << "cancel sections must appear "
2991                            << "inside a sections region";
2992     }
2993     if (isa_and_nonnull<SectionsOp>(parentOp->getParentOp()) &&
2994         cast<SectionsOp>(parentOp->getParentOp()).getNowaitAttr()) {
2995       return emitError() << "A sections construct that is canceled "
2996                          << "must not have a nowait clause";
2997     }
2998   }
2999   // TODO : Add more when we support taskgroup.
3000   return success();
3001 }
3002 
3003 //===----------------------------------------------------------------------===//
3004 // CancellationPointOp
3005 //===----------------------------------------------------------------------===//
3006 
3007 void CancellationPointOp::build(OpBuilder &builder, OperationState &state,
3008                                 const CancellationPointOperands &clauses) {
3009   CancellationPointOp::build(builder, state, clauses.cancelDirective);
3010 }
3011 
3012 LogicalResult CancellationPointOp::verify() {
3013   ClauseCancellationConstructType cct = getCancelDirective();
3014   Operation *parentOp = (*this)->getParentOp();
3015 
3016   if (!parentOp) {
3017     return emitOpError() << "must be used within a region supporting "
3018                             "cancellation point directive";
3019   }
3020 
3021   if ((cct == ClauseCancellationConstructType::Parallel) &&
3022       !(isa<ParallelOp>(parentOp))) {
3023     return emitOpError() << "cancellation point parallel must appear "
3024                          << "inside a parallel region";
3025   }
3026   if ((cct == ClauseCancellationConstructType::Loop) &&
3027       (!isa<LoopNestOp>(parentOp) || !isa<WsloopOp>(parentOp->getParentOp()))) {
3028     return emitOpError() << "cancellation point loop must appear "
3029                          << "inside a worksharing-loop region";
3030   }
3031   if ((cct == ClauseCancellationConstructType::Sections) &&
3032       !(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) {
3033     return emitOpError() << "cancellation point sections must appear "
3034                          << "inside a sections region";
3035   }
3036   // TODO : Add more when we support taskgroup.
3037   return success();
3038 }
3039 
3040 //===----------------------------------------------------------------------===//
3041 // MapBoundsOp
3042 //===----------------------------------------------------------------------===//
3043 
3044 LogicalResult MapBoundsOp::verify() {
3045   auto extent = getExtent();
3046   auto upperbound = getUpperBound();
3047   if (!extent && !upperbound)
3048     return emitError("expected extent or upperbound.");
3049   return success();
3050 }
3051 
3052 void PrivateClauseOp::build(OpBuilder &odsBuilder, OperationState &odsState,
3053                             TypeRange /*result_types*/, StringAttr symName,
3054                             TypeAttr type) {
3055   PrivateClauseOp::build(
3056       odsBuilder, odsState, symName, type,
3057       DataSharingClauseTypeAttr::get(odsBuilder.getContext(),
3058                                      DataSharingClauseType::Private));
3059 }
3060 
/// Verifies the regions of a privatizer against its type (getType()) and its
/// data-sharing kind (getDataSharingType()):
///  - `alloc`: asserted non-empty. It must have exactly one entry-block
///    argument, and its exit blocks must yield exactly one value of getType().
///  - `copy`: must be empty for Private and non-empty for FirstPrivate. For
///    FirstPrivate it must have exactly two entry-block arguments and yield
///    exactly one value of getType().
///  - `dealloc`: optional. When present, it must have exactly one entry-block
///    argument and its exit blocks must yield no values.
/// An "exit block" is a block whose terminator has no successors; such a
/// terminator must be a YieldOp.
LogicalResult PrivateClauseOp::verifyRegions() {
  Type symType = getType();

  // Checks a single block terminator. Terminators that branch to other blocks
  // are not exit points and are accepted. Exit terminators must be YieldOps
  // whose yielded types match `yieldsValue`.
  auto verifyTerminator = [&](Operation *terminator,
                              bool yieldsValue) -> LogicalResult {
    if (!terminator->getBlock()->getSuccessors().empty())
      return success();

    if (!llvm::isa<YieldOp>(terminator))
      return mlir::emitError(terminator->getLoc())
             << "expected exit block terminator to be an `omp.yield` op.";

    YieldOp yieldOp = llvm::cast<YieldOp>(terminator);
    TypeRange yieldedTypes = yieldOp.getResults().getTypes();

    // Regions that must not produce a value (e.g. `dealloc`).
    if (!yieldsValue) {
      if (yieldedTypes.empty())
        return success();

      return mlir::emitError(terminator->getLoc())
             << "Did not expect any values to be yielded.";
    }

    // Regions that produce a value must yield exactly one value of symType.
    if (yieldedTypes.size() == 1 && yieldedTypes.front() == symType)
      return success();

    auto error = mlir::emitError(yieldOp.getLoc())
                 << "Invalid yielded value. Expected type: " << symType
                 << ", got: ";

    if (yieldedTypes.empty())
      error << "None";
    else
      error << yieldedTypes;

    return error;
  };

  // Checks the entry-block argument count of `region` and then every block
  // terminator via verifyTerminator. `regionName` is used only in the
  // diagnostic.
  auto verifyRegion = [&](Region &region, unsigned expectedNumArgs,
                          StringRef regionName,
                          bool yieldsValue) -> LogicalResult {
    assert(!region.empty());

    if (region.getNumArguments() != expectedNumArgs)
      return mlir::emitError(region.getLoc())
             << "`" << regionName << "`: "
             << "expected " << expectedNumArgs
             << " region arguments, got: " << region.getNumArguments();

    for (Block &block : region) {
      // Skip blocks whose last op may not be a terminator. A missing
      // terminator is left to MLIR's generic verification.
      if (!block.mightHaveTerminator())
        continue;

      if (failed(verifyTerminator(block.getTerminator(), yieldsValue)))
        return failure();
    }

    return success();
  };

  if (failed(verifyRegion(getAllocRegion(), /*expectedNumArgs=*/1, "alloc",
                          /*yieldsValue=*/true)))
    return failure();

  DataSharingClauseType dsType = getDataSharingType();

  // Presence of the `copy` region must agree with the data-sharing kind.
  if (dsType == DataSharingClauseType::Private && !getCopyRegion().empty())
    return emitError("`private` clauses require only an `alloc` region.");

  if (dsType == DataSharingClauseType::FirstPrivate && getCopyRegion().empty())
    return emitError(
        "`firstprivate` clauses require both `alloc` and `copy` regions.");

  if (dsType == DataSharingClauseType::FirstPrivate &&
      failed(verifyRegion(getCopyRegion(), /*expectedNumArgs=*/2, "copy",
                          /*yieldsValue=*/true)))
    return failure();

  // `dealloc` is optional and only verified when present.
  if (!getDeallocRegion().empty() &&
      failed(verifyRegion(getDeallocRegion(), /*expectedNumArgs=*/1, "dealloc",
                          /*yieldsValue=*/false)))
    return failure();

  return success();
}
3147 
3148 //===----------------------------------------------------------------------===//
3149 // Spec 5.2: Masked construct (10.5)
3150 //===----------------------------------------------------------------------===//
3151 
3152 void MaskedOp::build(OpBuilder &builder, OperationState &state,
3153                      const MaskedOperands &clauses) {
3154   MaskedOp::build(builder, state, clauses.filteredThreadId);
3155 }
3156 
3157 //===----------------------------------------------------------------------===//
3158 // Spec 5.2: Scan construct (5.6)
3159 //===----------------------------------------------------------------------===//
3160 
3161 void ScanOp::build(OpBuilder &builder, OperationState &state,
3162                    const ScanOperands &clauses) {
3163   ScanOp::build(builder, state, clauses.inclusiveVars, clauses.exclusiveVars);
3164 }
3165 
3166 LogicalResult ScanOp::verify() {
3167   if (hasExclusiveVars() == hasInclusiveVars())
3168     return emitError(
3169         "Exactly one of EXCLUSIVE or INCLUSIVE clause is expected");
3170   if (WsloopOp parentWsLoopOp = (*this)->getParentOfType<WsloopOp>()) {
3171     if (parentWsLoopOp.getReductionModAttr() &&
3172         parentWsLoopOp.getReductionModAttr().getValue() ==
3173             ReductionModifier::inscan)
3174       return success();
3175   }
3176   if (SimdOp parentSimdOp = (*this)->getParentOfType<SimdOp>()) {
3177     if (parentSimdOp.getReductionModAttr() &&
3178         parentSimdOp.getReductionModAttr().getValue() ==
3179             ReductionModifier::inscan)
3180       return success();
3181   }
3182   return emitError("SCAN directive needs to be enclosed within a parent "
3183                    "worksharing loop construct or SIMD construct with INSCAN "
3184                    "reduction modifier");
3185 }
3186 
3187 #define GET_ATTRDEF_CLASSES
3188 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
3189 
3190 #define GET_OP_CLASSES
3191 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
3192 
3193 #define GET_TYPEDEF_CLASSES
3194 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
3195