xref: /llvm-project/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp (revision cbcb7ad32e6faca1f4c0f2f436e6076774104e17)
1 //===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 // =============================================================================
8 
9 #include "mlir/Dialect/OpenACC/OpenACC.h"
10 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
11 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
12 #include "mlir/Dialect/MemRef/IR/MemRef.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/BuiltinAttributes.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/DialectImplementation.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/OpImplementation.h"
19 #include "mlir/Support/LLVM.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 #include "llvm/ADT/SmallSet.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/LogicalResult.h"
24 
25 using namespace mlir;
26 using namespace acc;
27 
28 #include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
29 #include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
30 #include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
31 #include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
32 #include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
33 
34 namespace {
35 struct MemRefPointerLikeModel
36     : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
37                                             MemRefType> {
38   Type getElementType(Type pointer) const {
39     return llvm::cast<MemRefType>(pointer).getElementType();
40   }
41 };
42 
43 struct LLVMPointerPointerLikeModel
44     : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
45                                             LLVM::LLVMPointerType> {
46   Type getElementType(Type pointer) const { return Type(); }
47 };
48 } // namespace
49 
50 //===----------------------------------------------------------------------===//
51 // OpenACC operations
52 //===----------------------------------------------------------------------===//
53 
54 void OpenACCDialect::initialize() {
55   addOperations<
56 #define GET_OP_LIST
57 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
58       >();
59   addAttributes<
60 #define GET_ATTRDEF_LIST
61 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
62       >();
63   addTypes<
64 #define GET_TYPEDEF_LIST
65 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
66       >();
67 
68   // By attaching interfaces here, we make the OpenACC dialect dependent on
69   // the other dialects. This is probably better than having dialects like LLVM
70   // and memref be dependent on OpenACC.
71   MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext());
72   LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
73       *getContext());
74 }
75 
76 //===----------------------------------------------------------------------===//
77 // device_type support helpers
78 //===----------------------------------------------------------------------===//
79 
80 static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
81   if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
82     return true;
83   return false;
84 }
85 
86 static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr,
87                           mlir::acc::DeviceType deviceType) {
88   if (!hasDeviceTypeValues(arrayAttr))
89     return false;
90 
91   for (auto attr : *arrayAttr) {
92     auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
93     if (deviceTypeAttr.getValue() == deviceType)
94       return true;
95   }
96 
97   return false;
98 }
99 
100 static void printDeviceTypes(mlir::OpAsmPrinter &p,
101                              std::optional<mlir::ArrayAttr> deviceTypes) {
102   if (!hasDeviceTypeValues(deviceTypes))
103     return;
104 
105   p << "[";
106   llvm::interleaveComma(*deviceTypes, p,
107                         [&](mlir::Attribute attr) { p << attr; });
108   p << "]";
109 }
110 
111 static std::optional<unsigned> findSegment(ArrayAttr segments,
112                                            mlir::acc::DeviceType deviceType) {
113   unsigned segmentIdx = 0;
114   for (auto attr : segments) {
115     auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
116     if (deviceTypeAttr.getValue() == deviceType)
117       return std::make_optional(segmentIdx);
118     ++segmentIdx;
119   }
120   return std::nullopt;
121 }
122 
123 static mlir::Operation::operand_range
124 getValuesFromSegments(std::optional<mlir::ArrayAttr> arrayAttr,
125                       mlir::Operation::operand_range range,
126                       std::optional<llvm::ArrayRef<int32_t>> segments,
127                       mlir::acc::DeviceType deviceType) {
128   if (!arrayAttr)
129     return range.take_front(0);
130   if (auto pos = findSegment(*arrayAttr, deviceType)) {
131     int32_t nbOperandsBefore = 0;
132     for (unsigned i = 0; i < *pos; ++i)
133       nbOperandsBefore += (*segments)[i];
134     return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
135   }
136   return range.take_front(0);
137 }
138 
139 static mlir::Value
140 getWaitDevnumValue(std::optional<mlir::ArrayAttr> deviceTypeAttr,
141                    mlir::Operation::operand_range operands,
142                    std::optional<llvm::ArrayRef<int32_t>> segments,
143                    std::optional<mlir::ArrayAttr> hasWaitDevnum,
144                    mlir::acc::DeviceType deviceType) {
145   if (!hasDeviceTypeValues(deviceTypeAttr))
146     return {};
147   if (auto pos = findSegment(*deviceTypeAttr, deviceType))
148     if (hasWaitDevnum->getValue()[*pos])
149       return getValuesFromSegments(deviceTypeAttr, operands, segments,
150                                    deviceType)
151           .front();
152   return {};
153 }
154 
155 static mlir::Operation::operand_range
156 getWaitValuesWithoutDevnum(std::optional<mlir::ArrayAttr> deviceTypeAttr,
157                            mlir::Operation::operand_range operands,
158                            std::optional<llvm::ArrayRef<int32_t>> segments,
159                            std::optional<mlir::ArrayAttr> hasWaitDevnum,
160                            mlir::acc::DeviceType deviceType) {
161   auto range =
162       getValuesFromSegments(deviceTypeAttr, operands, segments, deviceType);
163   if (range.empty())
164     return range;
165   if (auto pos = findSegment(*deviceTypeAttr, deviceType)) {
166     if (hasWaitDevnum && *hasWaitDevnum) {
167       auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
168       if (boolAttr.getValue())
169         return range.drop_front(1); // first value is devnum
170     }
171   }
172   return range;
173 }
174 
175 template <typename Op>
176 static LogicalResult checkWaitAndAsyncConflict(Op op) {
177   for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
178        ++dtypeInt) {
179     auto dtype = static_cast<acc::DeviceType>(dtypeInt);
180 
181     // The async attribute represent the async clause without value. Therefore
182     // the attribute and operand cannot appear at the same time.
183     if (hasDeviceType(op.getAsyncOperandsDeviceType(), dtype) &&
184         op.hasAsyncOnly(dtype))
185       return op.emitError("async attribute cannot appear with asyncOperand");
186 
187     // The wait attribute represent the wait clause without values. Therefore
188     // the attribute and operands cannot appear at the same time.
189     if (hasDeviceType(op.getWaitOperandsDeviceType(), dtype) &&
190         op.hasWaitOnly(dtype))
191       return op.emitError("wait attribute cannot appear with waitOperands");
192   }
193   return success();
194 }
195 
196 template <typename Op>
197 static LogicalResult checkVarAndVarType(Op op) {
198   if (!op.getVar())
199     return op.emitError("must have var operand");
200 
201   if (mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
202       mlir::isa<mlir::acc::MappableType>(op.getVar().getType())) {
203     // TODO: If a type implements both interfaces (mappable and pointer-like),
204     // it is unclear which semantics to apply without additional info which
205     // would need captured in the data operation. For now restrict this case
206     // unless a compelling reason to support disambiguating between the two.
207     return op.emitError("var must be mappable or pointer-like (not both)");
208   }
209 
210   if (!mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
211       !mlir::isa<mlir::acc::MappableType>(op.getVar().getType()))
212     return op.emitError("var must be mappable or pointer-like");
213 
214   if (mlir::isa<mlir::acc::MappableType>(op.getVar().getType()) &&
215       op.getVarType() != op.getVar().getType())
216     return op.emitError("varType must match when var is mappable");
217 
218   return success();
219 }
220 
221 template <typename Op>
222 static LogicalResult checkVarAndAccVar(Op op) {
223   if (op.getVar().getType() != op.getAccVar().getType())
224     return op.emitError("input and output types must match");
225 
226   return success();
227 }
228 
229 static ParseResult parseVar(mlir::OpAsmParser &parser,
230                             OpAsmParser::UnresolvedOperand &var) {
231   // Either `var` or `varPtr` keyword is required.
232   if (failed(parser.parseOptionalKeyword("varPtr"))) {
233     if (failed(parser.parseKeyword("var")))
234       return failure();
235   }
236   if (failed(parser.parseLParen()))
237     return failure();
238   if (failed(parser.parseOperand(var)))
239     return failure();
240 
241   return success();
242 }
243 
244 static void printVar(mlir::OpAsmPrinter &p, mlir::Operation *op,
245                      mlir::Value var) {
246   if (mlir::isa<mlir::acc::PointerLikeType>(var.getType()))
247     p << "varPtr(";
248   else
249     p << "var(";
250   p.printOperand(var);
251 }
252 
253 static ParseResult parseAccVar(mlir::OpAsmParser &parser,
254                                OpAsmParser::UnresolvedOperand &var,
255                                mlir::Type &accVarType) {
256   // Either `accVar` or `accPtr` keyword is required.
257   if (failed(parser.parseOptionalKeyword("accPtr"))) {
258     if (failed(parser.parseKeyword("accVar")))
259       return failure();
260   }
261   if (failed(parser.parseLParen()))
262     return failure();
263   if (failed(parser.parseOperand(var)))
264     return failure();
265   if (failed(parser.parseColon()))
266     return failure();
267   if (failed(parser.parseType(accVarType)))
268     return failure();
269   if (failed(parser.parseRParen()))
270     return failure();
271 
272   return success();
273 }
274 
275 static void printAccVar(mlir::OpAsmPrinter &p, mlir::Operation *op,
276                         mlir::Value accVar, mlir::Type accVarType) {
277   if (mlir::isa<mlir::acc::PointerLikeType>(accVar.getType()))
278     p << "accPtr(";
279   else
280     p << "accVar(";
281   p.printOperand(accVar);
282   p << " : ";
283   p.printType(accVarType);
284   p << ")";
285 }
286 
287 static ParseResult parseVarPtrType(mlir::OpAsmParser &parser,
288                                    mlir::Type &varPtrType,
289                                    mlir::TypeAttr &varTypeAttr) {
290   if (failed(parser.parseType(varPtrType)))
291     return failure();
292   if (failed(parser.parseRParen()))
293     return failure();
294 
295   if (succeeded(parser.parseOptionalKeyword("varType"))) {
296     if (failed(parser.parseLParen()))
297       return failure();
298     mlir::Type varType;
299     if (failed(parser.parseType(varType)))
300       return failure();
301     varTypeAttr = mlir::TypeAttr::get(varType);
302     if (failed(parser.parseRParen()))
303       return failure();
304   } else {
305     // Set `varType` from the element type of the type of `varPtr`.
306     if (mlir::isa<mlir::acc::PointerLikeType>(varPtrType))
307       varTypeAttr = mlir::TypeAttr::get(
308           mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType());
309     else
310       varTypeAttr = mlir::TypeAttr::get(varPtrType);
311   }
312 
313   return success();
314 }
315 
316 static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op,
317                             mlir::Type varPtrType, mlir::TypeAttr varTypeAttr) {
318   p.printType(varPtrType);
319   p << ")";
320 
321   // Print the `varType` only if it differs from the element type of
322   // `varPtr`'s type.
323   mlir::Type varType = varTypeAttr.getValue();
324   mlir::Type typeToCheckAgainst =
325       mlir::isa<mlir::acc::PointerLikeType>(varPtrType)
326           ? mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType()
327           : varPtrType;
328   if (typeToCheckAgainst != varType) {
329     p << " varType(";
330     p.printType(varType);
331     p << ")";
332   }
333 }
334 
335 //===----------------------------------------------------------------------===//
336 // DataBoundsOp
337 //===----------------------------------------------------------------------===//
338 LogicalResult acc::DataBoundsOp::verify() {
339   auto extent = getExtent();
340   auto upperbound = getUpperbound();
341   if (!extent && !upperbound)
342     return emitError("expected extent or upperbound.");
343   return success();
344 }
345 
346 //===----------------------------------------------------------------------===//
347 // PrivateOp
348 //===----------------------------------------------------------------------===//
349 LogicalResult acc::PrivateOp::verify() {
350   if (getDataClause() != acc::DataClause::acc_private)
351     return emitError(
352         "data clause associated with private operation must match its intent");
353   if (failed(checkVarAndVarType(*this)))
354     return failure();
355   return success();
356 }
357 
358 //===----------------------------------------------------------------------===//
359 // FirstprivateOp
360 //===----------------------------------------------------------------------===//
361 LogicalResult acc::FirstprivateOp::verify() {
362   if (getDataClause() != acc::DataClause::acc_firstprivate)
363     return emitError("data clause associated with firstprivate operation must "
364                      "match its intent");
365   if (failed(checkVarAndVarType(*this)))
366     return failure();
367   return success();
368 }
369 
370 //===----------------------------------------------------------------------===//
371 // ReductionOp
372 //===----------------------------------------------------------------------===//
373 LogicalResult acc::ReductionOp::verify() {
374   if (getDataClause() != acc::DataClause::acc_reduction)
375     return emitError("data clause associated with reduction operation must "
376                      "match its intent");
377   if (failed(checkVarAndVarType(*this)))
378     return failure();
379   return success();
380 }
381 
382 //===----------------------------------------------------------------------===//
383 // DevicePtrOp
384 //===----------------------------------------------------------------------===//
385 LogicalResult acc::DevicePtrOp::verify() {
386   if (getDataClause() != acc::DataClause::acc_deviceptr)
387     return emitError("data clause associated with deviceptr operation must "
388                      "match its intent");
389   if (failed(checkVarAndVarType(*this)))
390     return failure();
391   if (failed(checkVarAndAccVar(*this)))
392     return failure();
393   return success();
394 }
395 
396 //===----------------------------------------------------------------------===//
397 // PresentOp
398 //===----------------------------------------------------------------------===//
399 LogicalResult acc::PresentOp::verify() {
400   if (getDataClause() != acc::DataClause::acc_present)
401     return emitError(
402         "data clause associated with present operation must match its intent");
403   if (failed(checkVarAndVarType(*this)))
404     return failure();
405   if (failed(checkVarAndAccVar(*this)))
406     return failure();
407   return success();
408 }
409 
410 //===----------------------------------------------------------------------===//
411 // CopyinOp
412 //===----------------------------------------------------------------------===//
413 LogicalResult acc::CopyinOp::verify() {
414   // Test for all clauses this operation can be decomposed from:
415   if (!getImplicit() && getDataClause() != acc::DataClause::acc_copyin &&
416       getDataClause() != acc::DataClause::acc_copyin_readonly &&
417       getDataClause() != acc::DataClause::acc_copy &&
418       getDataClause() != acc::DataClause::acc_reduction)
419     return emitError(
420         "data clause associated with copyin operation must match its intent"
421         " or specify original clause this operation was decomposed from");
422   if (failed(checkVarAndVarType(*this)))
423     return failure();
424   if (failed(checkVarAndAccVar(*this)))
425     return failure();
426   return success();
427 }
428 
429 bool acc::CopyinOp::isCopyinReadonly() {
430   return getDataClause() == acc::DataClause::acc_copyin_readonly;
431 }
432 
433 //===----------------------------------------------------------------------===//
434 // CreateOp
435 //===----------------------------------------------------------------------===//
436 LogicalResult acc::CreateOp::verify() {
437   // Test for all clauses this operation can be decomposed from:
438   if (getDataClause() != acc::DataClause::acc_create &&
439       getDataClause() != acc::DataClause::acc_create_zero &&
440       getDataClause() != acc::DataClause::acc_copyout &&
441       getDataClause() != acc::DataClause::acc_copyout_zero)
442     return emitError(
443         "data clause associated with create operation must match its intent"
444         " or specify original clause this operation was decomposed from");
445   if (failed(checkVarAndVarType(*this)))
446     return failure();
447   if (failed(checkVarAndAccVar(*this)))
448     return failure();
449   return success();
450 }
451 
452 bool acc::CreateOp::isCreateZero() {
453   // The zero modifier is encoded in the data clause.
454   return getDataClause() == acc::DataClause::acc_create_zero ||
455          getDataClause() == acc::DataClause::acc_copyout_zero;
456 }
457 
458 //===----------------------------------------------------------------------===//
459 // NoCreateOp
460 //===----------------------------------------------------------------------===//
461 LogicalResult acc::NoCreateOp::verify() {
462   if (getDataClause() != acc::DataClause::acc_no_create)
463     return emitError("data clause associated with no_create operation must "
464                      "match its intent");
465   if (failed(checkVarAndVarType(*this)))
466     return failure();
467   if (failed(checkVarAndAccVar(*this)))
468     return failure();
469   return success();
470 }
471 
472 //===----------------------------------------------------------------------===//
473 // AttachOp
474 //===----------------------------------------------------------------------===//
475 LogicalResult acc::AttachOp::verify() {
476   if (getDataClause() != acc::DataClause::acc_attach)
477     return emitError(
478         "data clause associated with attach operation must match its intent");
479   if (failed(checkVarAndVarType(*this)))
480     return failure();
481   if (failed(checkVarAndAccVar(*this)))
482     return failure();
483   return success();
484 }
485 
486 //===----------------------------------------------------------------------===//
487 // DeclareDeviceResidentOp
488 //===----------------------------------------------------------------------===//
489 
490 LogicalResult acc::DeclareDeviceResidentOp::verify() {
491   if (getDataClause() != acc::DataClause::acc_declare_device_resident)
492     return emitError("data clause associated with device_resident operation "
493                      "must match its intent");
494   if (failed(checkVarAndVarType(*this)))
495     return failure();
496   if (failed(checkVarAndAccVar(*this)))
497     return failure();
498   return success();
499 }
500 
501 //===----------------------------------------------------------------------===//
502 // DeclareLinkOp
503 //===----------------------------------------------------------------------===//
504 
505 LogicalResult acc::DeclareLinkOp::verify() {
506   if (getDataClause() != acc::DataClause::acc_declare_link)
507     return emitError(
508         "data clause associated with link operation must match its intent");
509   if (failed(checkVarAndVarType(*this)))
510     return failure();
511   if (failed(checkVarAndAccVar(*this)))
512     return failure();
513   return success();
514 }
515 
516 //===----------------------------------------------------------------------===//
517 // CopyoutOp
518 //===----------------------------------------------------------------------===//
519 LogicalResult acc::CopyoutOp::verify() {
520   // Test for all clauses this operation can be decomposed from:
521   if (getDataClause() != acc::DataClause::acc_copyout &&
522       getDataClause() != acc::DataClause::acc_copyout_zero &&
523       getDataClause() != acc::DataClause::acc_copy &&
524       getDataClause() != acc::DataClause::acc_reduction)
525     return emitError(
526         "data clause associated with copyout operation must match its intent"
527         " or specify original clause this operation was decomposed from");
528   if (!getVar() || !getAccVar())
529     return emitError("must have both host and device pointers");
530   if (failed(checkVarAndVarType(*this)))
531     return failure();
532   if (failed(checkVarAndAccVar(*this)))
533     return failure();
534   return success();
535 }
536 
537 bool acc::CopyoutOp::isCopyoutZero() {
538   return getDataClause() == acc::DataClause::acc_copyout_zero;
539 }
540 
541 //===----------------------------------------------------------------------===//
542 // DeleteOp
543 //===----------------------------------------------------------------------===//
544 LogicalResult acc::DeleteOp::verify() {
545   // Test for all clauses this operation can be decomposed from:
546   if (getDataClause() != acc::DataClause::acc_delete &&
547       getDataClause() != acc::DataClause::acc_create &&
548       getDataClause() != acc::DataClause::acc_create_zero &&
549       getDataClause() != acc::DataClause::acc_copyin &&
550       getDataClause() != acc::DataClause::acc_copyin_readonly &&
551       getDataClause() != acc::DataClause::acc_present &&
552       getDataClause() != acc::DataClause::acc_declare_device_resident &&
553       getDataClause() != acc::DataClause::acc_declare_link)
554     return emitError(
555         "data clause associated with delete operation must match its intent"
556         " or specify original clause this operation was decomposed from");
557   if (!getAccVar())
558     return emitError("must have device pointer");
559   return success();
560 }
561 
562 //===----------------------------------------------------------------------===//
563 // DetachOp
564 //===----------------------------------------------------------------------===//
565 LogicalResult acc::DetachOp::verify() {
566   // Test for all clauses this operation can be decomposed from:
567   if (getDataClause() != acc::DataClause::acc_detach &&
568       getDataClause() != acc::DataClause::acc_attach)
569     return emitError(
570         "data clause associated with detach operation must match its intent"
571         " or specify original clause this operation was decomposed from");
572   if (!getAccVar())
573     return emitError("must have device pointer");
574   return success();
575 }
576 
577 //===----------------------------------------------------------------------===//
// UpdateHostOp
579 //===----------------------------------------------------------------------===//
580 LogicalResult acc::UpdateHostOp::verify() {
581   // Test for all clauses this operation can be decomposed from:
582   if (getDataClause() != acc::DataClause::acc_update_host &&
583       getDataClause() != acc::DataClause::acc_update_self)
584     return emitError(
585         "data clause associated with host operation must match its intent"
586         " or specify original clause this operation was decomposed from");
587   if (!getVar() || !getAccVar())
588     return emitError("must have both host and device pointers");
589   if (failed(checkVarAndVarType(*this)))
590     return failure();
591   if (failed(checkVarAndAccVar(*this)))
592     return failure();
593   return success();
594 }
595 
596 //===----------------------------------------------------------------------===//
// UpdateDeviceOp
598 //===----------------------------------------------------------------------===//
599 LogicalResult acc::UpdateDeviceOp::verify() {
600   // Test for all clauses this operation can be decomposed from:
601   if (getDataClause() != acc::DataClause::acc_update_device)
602     return emitError(
603         "data clause associated with device operation must match its intent"
604         " or specify original clause this operation was decomposed from");
605   if (failed(checkVarAndVarType(*this)))
606     return failure();
607   if (failed(checkVarAndAccVar(*this)))
608     return failure();
609   return success();
610 }
611 
612 //===----------------------------------------------------------------------===//
613 // UseDeviceOp
614 //===----------------------------------------------------------------------===//
615 LogicalResult acc::UseDeviceOp::verify() {
616   // Test for all clauses this operation can be decomposed from:
617   if (getDataClause() != acc::DataClause::acc_use_device)
618     return emitError(
619         "data clause associated with use_device operation must match its intent"
620         " or specify original clause this operation was decomposed from");
621   if (failed(checkVarAndVarType(*this)))
622     return failure();
623   if (failed(checkVarAndAccVar(*this)))
624     return failure();
625   return success();
626 }
627 
628 //===----------------------------------------------------------------------===//
629 // CacheOp
630 //===----------------------------------------------------------------------===//
631 LogicalResult acc::CacheOp::verify() {
632   // Test for all clauses this operation can be decomposed from:
633   if (getDataClause() != acc::DataClause::acc_cache &&
634       getDataClause() != acc::DataClause::acc_cache_readonly)
635     return emitError(
636         "data clause associated with cache operation must match its intent"
637         " or specify original clause this operation was decomposed from");
638   if (failed(checkVarAndVarType(*this)))
639     return failure();
640   if (failed(checkVarAndAccVar(*this)))
641     return failure();
642   return success();
643 }
644 
645 template <typename StructureOp>
646 static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
647                                 unsigned nRegions = 1) {
648 
649   SmallVector<Region *, 2> regions;
650   for (unsigned i = 0; i < nRegions; ++i)
651     regions.push_back(state.addRegion());
652 
653   for (Region *region : regions)
654     if (parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{}))
655       return failure();
656 
657   return success();
658 }
659 
660 static bool isComputeOperation(Operation *op) {
661   return isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(op);
662 }
663 
664 namespace {
665 /// Pattern to remove operation without region that have constant false `ifCond`
666 /// and remove the condition from the operation if the `ifCond` is a true
667 /// constant.
668 template <typename OpTy>
669 struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> {
670   using OpRewritePattern<OpTy>::OpRewritePattern;
671 
672   LogicalResult matchAndRewrite(OpTy op,
673                                 PatternRewriter &rewriter) const override {
674     // Early return if there is no condition.
675     Value ifCond = op.getIfCond();
676     if (!ifCond)
677       return failure();
678 
679     IntegerAttr constAttr;
680     if (!matchPattern(ifCond, m_Constant(&constAttr)))
681       return failure();
682     if (constAttr.getInt())
683       rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
684     else
685       rewriter.eraseOp(op);
686 
687     return success();
688   }
689 };
690 
691 /// Replaces the given op with the contents of the given single-block region,
692 /// using the operands of the block terminator to replace operation results.
693 static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
694                                 Region &region, ValueRange blockArgs = {}) {
695   assert(llvm::hasSingleElement(region) && "expected single-region block");
696   Block *block = &region.front();
697   Operation *terminator = block->getTerminator();
698   ValueRange results = terminator->getOperands();
699   rewriter.inlineBlockBefore(block, op, blockArgs);
700   rewriter.replaceOp(op, results);
701   rewriter.eraseOp(terminator);
702 }
703 
704 /// Pattern to remove operation with region that have constant false `ifCond`
705 /// and remove the condition from the operation if the `ifCond` is constant
706 /// true.
707 template <typename OpTy>
708 struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> {
709   using OpRewritePattern<OpTy>::OpRewritePattern;
710 
711   LogicalResult matchAndRewrite(OpTy op,
712                                 PatternRewriter &rewriter) const override {
713     // Early return if there is no condition.
714     Value ifCond = op.getIfCond();
715     if (!ifCond)
716       return failure();
717 
718     IntegerAttr constAttr;
719     if (!matchPattern(ifCond, m_Constant(&constAttr)))
720       return failure();
721     if (constAttr.getInt())
722       rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
723     else
724       replaceOpWithRegion(rewriter, op, op.getRegion());
725 
726     return success();
727   }
728 };
729 
730 } // namespace
731 
732 //===----------------------------------------------------------------------===//
733 // PrivateRecipeOp
734 //===----------------------------------------------------------------------===//
735 
736 static LogicalResult verifyInitLikeSingleArgRegion(
737     Operation *op, Region &region, StringRef regionType, StringRef regionName,
738     Type type, bool verifyYield, bool optional = false) {
739   if (optional && region.empty())
740     return success();
741 
742   if (region.empty())
743     return op->emitOpError() << "expects non-empty " << regionName << " region";
744   Block &firstBlock = region.front();
745   if (firstBlock.getNumArguments() < 1 ||
746       firstBlock.getArgument(0).getType() != type)
747     return op->emitOpError() << "expects " << regionName
748                              << " region first "
749                                 "argument of the "
750                              << regionType << " type";
751 
752   if (verifyYield) {
753     for (YieldOp yieldOp : region.getOps<acc::YieldOp>()) {
754       if (yieldOp.getOperands().size() != 1 ||
755           yieldOp.getOperands().getTypes()[0] != type)
756         return op->emitOpError() << "expects " << regionName
757                                  << " region to "
758                                     "yield a value of the "
759                                  << regionType << " type";
760     }
761   }
762   return success();
763 }
764 
765 LogicalResult acc::PrivateRecipeOp::verifyRegions() {
766   if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
767                                            "privatization", "init", getType(),
768                                            /*verifyYield=*/false)))
769     return failure();
770   if (failed(verifyInitLikeSingleArgRegion(
771           *this, getDestroyRegion(), "privatization", "destroy", getType(),
772           /*verifyYield=*/false, /*optional=*/true)))
773     return failure();
774   return success();
775 }
776 
777 //===----------------------------------------------------------------------===//
778 // FirstprivateRecipeOp
779 //===----------------------------------------------------------------------===//
780 
781 LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
782   if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
783                                            "privatization", "init", getType(),
784                                            /*verifyYield=*/false)))
785     return failure();
786 
787   if (getCopyRegion().empty())
788     return emitOpError() << "expects non-empty copy region";
789 
790   Block &firstBlock = getCopyRegion().front();
791   if (firstBlock.getNumArguments() < 2 ||
792       firstBlock.getArgument(0).getType() != getType())
793     return emitOpError() << "expects copy region with two arguments of the "
794                             "privatization type";
795 
796   if (getDestroyRegion().empty())
797     return success();
798 
799   if (failed(verifyInitLikeSingleArgRegion(*this, getDestroyRegion(),
800                                            "privatization", "destroy",
801                                            getType(), /*verifyYield=*/false)))
802     return failure();
803 
804   return success();
805 }
806 
807 //===----------------------------------------------------------------------===//
808 // ReductionRecipeOp
809 //===----------------------------------------------------------------------===//
810 
811 LogicalResult acc::ReductionRecipeOp::verifyRegions() {
812   if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(), "reduction",
813                                            "init", getType(),
814                                            /*verifyYield=*/false)))
815     return failure();
816 
817   if (getCombinerRegion().empty())
818     return emitOpError() << "expects non-empty combiner region";
819 
820   Block &reductionBlock = getCombinerRegion().front();
821   if (reductionBlock.getNumArguments() < 2 ||
822       reductionBlock.getArgument(0).getType() != getType() ||
823       reductionBlock.getArgument(1).getType() != getType())
824     return emitOpError() << "expects combiner region with the first two "
825                          << "arguments of the reduction type";
826 
827   for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
828     if (yieldOp.getOperands().size() != 1 ||
829         yieldOp.getOperands().getTypes()[0] != getType())
830       return emitOpError() << "expects combiner region to yield a value "
831                               "of the reduction type";
832   }
833 
834   return success();
835 }
836 
837 //===----------------------------------------------------------------------===//
// Custom parser and printer for private clause symbol/operand lists
839 //===----------------------------------------------------------------------===//
840 
841 static ParseResult parseSymOperandList(
842     mlir::OpAsmParser &parser,
843     llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
844     llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &symbols) {
845   llvm::SmallVector<SymbolRefAttr> attributes;
846   if (failed(parser.parseCommaSeparatedList([&]() {
847         if (parser.parseAttribute(attributes.emplace_back()) ||
848             parser.parseArrow() ||
849             parser.parseOperand(operands.emplace_back()) ||
850             parser.parseColonType(types.emplace_back()))
851           return failure();
852         return success();
853       })))
854     return failure();
855   llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
856                                                attributes.end());
857   symbols = ArrayAttr::get(parser.getContext(), arrayAttr);
858   return success();
859 }
860 
861 static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op,
862                                 mlir::OperandRange operands,
863                                 mlir::TypeRange types,
864                                 std::optional<mlir::ArrayAttr> attributes) {
865   llvm::interleaveComma(llvm::zip(*attributes, operands), p, [&](auto it) {
866     p << std::get<0>(it) << " -> " << std::get<1>(it) << " : "
867       << std::get<1>(it).getType();
868   });
869 }
870 
871 //===----------------------------------------------------------------------===//
872 // ParallelOp
873 //===----------------------------------------------------------------------===//
874 
875 /// Check dataOperands for acc.parallel, acc.serial and acc.kernels.
876 template <typename Op>
877 static LogicalResult checkDataOperands(Op op,
878                                        const mlir::ValueRange &operands) {
879   for (mlir::Value operand : operands)
880     if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
881                    acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
882                    acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
883             operand.getDefiningOp()))
884       return op.emitError(
885           "expect data entry/exit operation or acc.getdeviceptr "
886           "as defining op");
887   return success();
888 }
889 
890 template <typename Op>
891 static LogicalResult
892 checkSymOperandList(Operation *op, std::optional<mlir::ArrayAttr> attributes,
893                     mlir::OperandRange operands, llvm::StringRef operandName,
894                     llvm::StringRef symbolName, bool checkOperandType = true) {
895   if (!operands.empty()) {
896     if (!attributes || attributes->size() != operands.size())
897       return op->emitOpError()
898              << "expected as many " << symbolName << " symbol reference as "
899              << operandName << " operands";
900   } else {
901     if (attributes)
902       return op->emitOpError()
903              << "unexpected " << symbolName << " symbol reference";
904     return success();
905   }
906 
907   llvm::DenseSet<Value> set;
908   for (auto args : llvm::zip(operands, *attributes)) {
909     mlir::Value operand = std::get<0>(args);
910 
911     if (!set.insert(operand).second)
912       return op->emitOpError()
913              << operandName << " operand appears more than once";
914 
915     mlir::Type varType = operand.getType();
916     auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
917     auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef);
918     if (!decl)
919       return op->emitOpError()
920              << "expected symbol reference " << symbolRef << " to point to a "
921              << operandName << " declaration";
922 
923     if (checkOperandType && decl.getType() && decl.getType() != varType)
924       return op->emitOpError() << "expected " << operandName << " (" << varType
925                                << ") to be the same type as " << operandName
926                                << " declaration (" << decl.getType() << ")";
927   }
928 
929   return success();
930 }
931 
932 unsigned ParallelOp::getNumDataOperands() {
933   return getReductionOperands().size() + getPrivateOperands().size() +
934          getFirstprivateOperands().size() + getDataClauseOperands().size();
935 }
936 
937 Value ParallelOp::getDataOperand(unsigned i) {
938   unsigned numOptional = getAsyncOperands().size();
939   numOptional += getNumGangs().size();
940   numOptional += getNumWorkers().size();
941   numOptional += getVectorLength().size();
942   numOptional += getIfCond() ? 1 : 0;
943   numOptional += getSelfCond() ? 1 : 0;
944   return getOperand(getWaitOperands().size() + numOptional + i);
945 }
946 
947 template <typename Op>
948 static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands,
949                                                 ArrayAttr deviceTypes,
950                                                 llvm::StringRef keyword) {
951   if (!operands.empty() && deviceTypes.getValue().size() != operands.size())
952     return op.emitOpError() << keyword << " operands count must match "
953                             << keyword << " device_type count";
954   return success();
955 }
956 
957 template <typename Op>
958 static LogicalResult verifyDeviceTypeAndSegmentCountMatch(
959     Op op, OperandRange operands, DenseI32ArrayAttr segments,
960     ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
961   std::size_t numOperandsInSegments = 0;
962   std::size_t nbOfSegments = 0;
963 
964   if (segments) {
965     for (auto segCount : segments.asArrayRef()) {
966       if (maxInSegment != 0 && segCount > maxInSegment)
967         return op.emitOpError() << keyword << " expects a maximum of "
968                                 << maxInSegment << " values per segment";
969       numOperandsInSegments += segCount;
970       ++nbOfSegments;
971     }
972   }
973 
974   if ((numOperandsInSegments != operands.size()) ||
975       (!deviceTypes && !operands.empty()))
976     return op.emitOpError()
977            << keyword << " operand count does not match count in segments";
978   if (deviceTypes && deviceTypes.getValue().size() != nbOfSegments)
979     return op.emitOpError()
980            << keyword << " segment count does not match device_type count";
981   return success();
982 }
983 
984 LogicalResult acc::ParallelOp::verify() {
985   if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
986           *this, getPrivatizations(), getPrivateOperands(), "private",
987           "privatizations", /*checkOperandType=*/false)))
988     return failure();
989   if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
990           *this, getFirstprivatizations(), getFirstprivateOperands(),
991           "firstprivate", "firstprivatizations", /*checkOperandType=*/false)))
992     return failure();
993   if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
994           *this, getReductionRecipes(), getReductionOperands(), "reduction",
995           "reductions", false)))
996     return failure();
997 
998   if (failed(verifyDeviceTypeAndSegmentCountMatch(
999           *this, getNumGangs(), getNumGangsSegmentsAttr(),
1000           getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
1001     return failure();
1002 
1003   if (failed(verifyDeviceTypeAndSegmentCountMatch(
1004           *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1005           getWaitOperandsDeviceTypeAttr(), "wait")))
1006     return failure();
1007 
1008   if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
1009                                         getNumWorkersDeviceTypeAttr(),
1010                                         "num_workers")))
1011     return failure();
1012 
1013   if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
1014                                         getVectorLengthDeviceTypeAttr(),
1015                                         "vector_length")))
1016     return failure();
1017 
1018   if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
1019                                         getAsyncOperandsDeviceTypeAttr(),
1020                                         "async")))
1021     return failure();
1022 
1023   if (failed(checkWaitAndAsyncConflict<acc::ParallelOp>(*this)))
1024     return failure();
1025 
1026   return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
1027 }
1028 
1029 static mlir::Value
1030 getValueInDeviceTypeSegment(std::optional<mlir::ArrayAttr> arrayAttr,
1031                             mlir::Operation::operand_range range,
1032                             mlir::acc::DeviceType deviceType) {
1033   if (!arrayAttr)
1034     return {};
1035   if (auto pos = findSegment(*arrayAttr, deviceType))
1036     return range[*pos];
1037   return {};
1038 }
1039 
1040 bool acc::ParallelOp::hasAsyncOnly() {
1041   return hasAsyncOnly(mlir::acc::DeviceType::None);
1042 }
1043 
1044 bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1045   return hasDeviceType(getAsyncOnly(), deviceType);
1046 }
1047 
1048 mlir::Value acc::ParallelOp::getAsyncValue() {
1049   return getAsyncValue(mlir::acc::DeviceType::None);
1050 }
1051 
1052 mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1053   return getValueInDeviceTypeSegment(getAsyncOperandsDeviceType(),
1054                                      getAsyncOperands(), deviceType);
1055 }
1056 
1057 mlir::Value acc::ParallelOp::getNumWorkersValue() {
1058   return getNumWorkersValue(mlir::acc::DeviceType::None);
1059 }
1060 
1061 mlir::Value
1062 acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
1063   return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
1064                                      deviceType);
1065 }
1066 
1067 mlir::Value acc::ParallelOp::getVectorLengthValue() {
1068   return getVectorLengthValue(mlir::acc::DeviceType::None);
1069 }
1070 
1071 mlir::Value
1072 acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
1073   return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
1074                                      getVectorLength(), deviceType);
1075 }
1076 
1077 mlir::Operation::operand_range ParallelOp::getNumGangsValues() {
1078   return getNumGangsValues(mlir::acc::DeviceType::None);
1079 }
1080 
1081 mlir::Operation::operand_range
1082 ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
1083   return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
1084                                getNumGangsSegments(), deviceType);
1085 }
1086 
1087 bool acc::ParallelOp::hasWaitOnly() {
1088   return hasWaitOnly(mlir::acc::DeviceType::None);
1089 }
1090 
1091 bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1092   return hasDeviceType(getWaitOnly(), deviceType);
1093 }
1094 
1095 mlir::Operation::operand_range ParallelOp::getWaitValues() {
1096   return getWaitValues(mlir::acc::DeviceType::None);
1097 }
1098 
1099 mlir::Operation::operand_range
1100 ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1101   return getWaitValuesWithoutDevnum(
1102       getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1103       getHasWaitDevnum(), deviceType);
1104 }
1105 
1106 mlir::Value ParallelOp::getWaitDevnum() {
1107   return getWaitDevnum(mlir::acc::DeviceType::None);
1108 }
1109 
1110 mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1111   return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
1112                             getWaitOperandsSegments(), getHasWaitDevnum(),
1113                             deviceType);
1114 }
1115 
/// Operand-only convenience builder for acc.parallel.
///
/// Forwards the given operand values to the generated builder. Every
/// device_type, segment, keyword-only, devnum and recipe-symbol attribute
/// slot is passed as null, and so are `selfAttr`, `defaultAttr` and
/// `combined`. `gangPrivateOperands` and `gangFirstPrivateOperands` fill the
/// operand slots that precede the `privatizations` and `firstprivatizations`
/// attribute slots.
void ParallelOp::build(mlir::OpBuilder &odsBuilder,
                       mlir::OperationState &odsState,
                       mlir::ValueRange numGangs, mlir::ValueRange numWorkers,
                       mlir::ValueRange vectorLength,
                       mlir::ValueRange asyncOperands,
                       mlir::ValueRange waitOperands, mlir::Value ifCond,
                       mlir::Value selfCond, mlir::ValueRange reductionOperands,
                       mlir::ValueRange gangPrivateOperands,
                       mlir::ValueRange gangFirstPrivateOperands,
                       mlir::ValueRange dataClauseOperands) {

  // The generated builder's argument order differs from this builder's
  // parameter order (async and wait come first there), so keep each value
  // next to its /*name=*/ attribute slots.
  ParallelOp::build(
      odsBuilder, odsState, asyncOperands, /*asyncOperandsDeviceType=*/nullptr,
      /*asyncOnly=*/nullptr, waitOperands, /*waitOperandsSegments=*/nullptr,
      /*waitOperandsDeviceType=*/nullptr, /*hasWaitDevnum=*/nullptr,
      /*waitOnly=*/nullptr, numGangs, /*numGangsSegments=*/nullptr,
      /*numGangsDeviceType=*/nullptr, numWorkers,
      /*numWorkersDeviceType=*/nullptr, vectorLength,
      /*vectorLengthDeviceType=*/nullptr, ifCond, selfCond,
      /*selfAttr=*/nullptr, reductionOperands, /*reductionRecipes=*/nullptr,
      gangPrivateOperands, /*privatizations=*/nullptr, gangFirstPrivateOperands,
      /*firstprivatizations=*/nullptr, dataClauseOperands,
      /*defaultAttr=*/nullptr, /*combined=*/nullptr);
}
1140 
1141 static ParseResult parseNumGangs(
1142     mlir::OpAsmParser &parser,
1143     llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
1144     llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1145     mlir::DenseI32ArrayAttr &segments) {
1146   llvm::SmallVector<DeviceTypeAttr> attributes;
1147   llvm::SmallVector<int32_t> seg;
1148 
1149   do {
1150     if (failed(parser.parseLBrace()))
1151       return failure();
1152 
1153     int32_t crtOperandsSize = operands.size();
1154     if (failed(parser.parseCommaSeparatedList(
1155             mlir::AsmParser::Delimiter::None, [&]() {
1156               if (parser.parseOperand(operands.emplace_back()) ||
1157                   parser.parseColonType(types.emplace_back()))
1158                 return failure();
1159               return success();
1160             })))
1161       return failure();
1162     seg.push_back(operands.size() - crtOperandsSize);
1163 
1164     if (failed(parser.parseRBrace()))
1165       return failure();
1166 
1167     if (succeeded(parser.parseOptionalLSquare())) {
1168       if (parser.parseAttribute(attributes.emplace_back()) ||
1169           parser.parseRSquare())
1170         return failure();
1171     } else {
1172       attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1173           parser.getContext(), mlir::acc::DeviceType::None));
1174     }
1175   } while (succeeded(parser.parseOptionalComma()));
1176 
1177   llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1178                                                attributes.end());
1179   deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
1180   segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
1181 
1182   return success();
1183 }
1184 
1185 static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr) {
1186   auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
1187   if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
1188     p << " [" << attr << "]";
1189 }
1190 
1191 static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op,
1192                           mlir::OperandRange operands, mlir::TypeRange types,
1193                           std::optional<mlir::ArrayAttr> deviceTypes,
1194                           std::optional<mlir::DenseI32ArrayAttr> segments) {
1195   unsigned opIdx = 0;
1196   llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
1197     p << "{";
1198     llvm::interleaveComma(
1199         llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
1200           p << operands[opIdx] << " : " << operands[opIdx].getType();
1201           ++opIdx;
1202         });
1203     p << "}";
1204     printSingleDeviceType(p, it.value());
1205   });
1206 }
1207 
1208 static ParseResult parseDeviceTypeOperandsWithSegment(
1209     mlir::OpAsmParser &parser,
1210     llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
1211     llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1212     mlir::DenseI32ArrayAttr &segments) {
1213   llvm::SmallVector<DeviceTypeAttr> attributes;
1214   llvm::SmallVector<int32_t> seg;
1215 
1216   do {
1217     if (failed(parser.parseLBrace()))
1218       return failure();
1219 
1220     int32_t crtOperandsSize = operands.size();
1221 
1222     if (failed(parser.parseCommaSeparatedList(
1223             mlir::AsmParser::Delimiter::None, [&]() {
1224               if (parser.parseOperand(operands.emplace_back()) ||
1225                   parser.parseColonType(types.emplace_back()))
1226                 return failure();
1227               return success();
1228             })))
1229       return failure();
1230 
1231     seg.push_back(operands.size() - crtOperandsSize);
1232 
1233     if (failed(parser.parseRBrace()))
1234       return failure();
1235 
1236     if (succeeded(parser.parseOptionalLSquare())) {
1237       if (parser.parseAttribute(attributes.emplace_back()) ||
1238           parser.parseRSquare())
1239         return failure();
1240     } else {
1241       attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1242           parser.getContext(), mlir::acc::DeviceType::None));
1243     }
1244   } while (succeeded(parser.parseOptionalComma()));
1245 
1246   llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1247                                                attributes.end());
1248   deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
1249   segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
1250 
1251   return success();
1252 }
1253 
1254 static void printDeviceTypeOperandsWithSegment(
1255     mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands,
1256     mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
1257     std::optional<mlir::DenseI32ArrayAttr> segments) {
1258   unsigned opIdx = 0;
1259   llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
1260     p << "{";
1261     llvm::interleaveComma(
1262         llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
1263           p << operands[opIdx] << " : " << operands[opIdx].getType();
1264           ++opIdx;
1265         });
1266     p << "}";
1267     printSingleDeviceType(p, it.value());
1268   });
1269 }
1270 
1271 static ParseResult parseWaitClause(
1272     mlir::OpAsmParser &parser,
1273     llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
1274     llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1275     mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum,
1276     mlir::ArrayAttr &keywordOnly) {
1277   llvm::SmallVector<mlir::Attribute> deviceTypeAttrs, keywordAttrs, devnum;
1278   llvm::SmallVector<int32_t> seg;
1279 
1280   bool needCommaBeforeOperands = false;
1281 
1282   // Keyword only
1283   if (failed(parser.parseOptionalLParen())) {
1284     keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
1285         parser.getContext(), mlir::acc::DeviceType::None));
1286     keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
1287     return success();
1288   }
1289 
1290   // Parse keyword only attributes
1291   if (succeeded(parser.parseOptionalLSquare())) {
1292     if (failed(parser.parseCommaSeparatedList([&]() {
1293           if (parser.parseAttribute(keywordAttrs.emplace_back()))
1294             return failure();
1295           return success();
1296         })))
1297       return failure();
1298     if (parser.parseRSquare())
1299       return failure();
1300     needCommaBeforeOperands = true;
1301   }
1302 
1303   if (needCommaBeforeOperands && failed(parser.parseComma()))
1304     return failure();
1305 
1306   do {
1307     if (failed(parser.parseLBrace()))
1308       return failure();
1309 
1310     int32_t crtOperandsSize = operands.size();
1311 
1312     if (succeeded(parser.parseOptionalKeyword("devnum"))) {
1313       if (failed(parser.parseColon()))
1314         return failure();
1315       devnum.push_back(BoolAttr::get(parser.getContext(), true));
1316     } else {
1317       devnum.push_back(BoolAttr::get(parser.getContext(), false));
1318     }
1319 
1320     if (failed(parser.parseCommaSeparatedList(
1321             mlir::AsmParser::Delimiter::None, [&]() {
1322               if (parser.parseOperand(operands.emplace_back()) ||
1323                   parser.parseColonType(types.emplace_back()))
1324                 return failure();
1325               return success();
1326             })))
1327       return failure();
1328 
1329     seg.push_back(operands.size() - crtOperandsSize);
1330 
1331     if (failed(parser.parseRBrace()))
1332       return failure();
1333 
1334     if (succeeded(parser.parseOptionalLSquare())) {
1335       if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
1336           parser.parseRSquare())
1337         return failure();
1338     } else {
1339       deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
1340           parser.getContext(), mlir::acc::DeviceType::None));
1341     }
1342   } while (succeeded(parser.parseOptionalComma()));
1343 
1344   if (failed(parser.parseRParen()))
1345     return failure();
1346 
1347   deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
1348   keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
1349   segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
1350   hasDevNum = ArrayAttr::get(parser.getContext(), devnum);
1351 
1352   return success();
1353 }
1354 
1355 static bool hasOnlyDeviceTypeNone(std::optional<mlir::ArrayAttr> attrs) {
1356   if (!hasDeviceTypeValues(attrs))
1357     return false;
1358   if (attrs->size() != 1)
1359     return false;
1360   if (auto deviceTypeAttr =
1361           mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
1362     return deviceTypeAttr.getValue() == mlir::acc::DeviceType::None;
1363   return false;
1364 }
1365 
1366 static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op,
1367                             mlir::OperandRange operands, mlir::TypeRange types,
1368                             std::optional<mlir::ArrayAttr> deviceTypes,
1369                             std::optional<mlir::DenseI32ArrayAttr> segments,
1370                             std::optional<mlir::ArrayAttr> hasDevNum,
1371                             std::optional<mlir::ArrayAttr> keywordOnly) {
1372 
1373   if (operands.begin() == operands.end() && hasOnlyDeviceTypeNone(keywordOnly))
1374     return;
1375 
1376   p << "(";
1377 
1378   printDeviceTypes(p, keywordOnly);
1379   if (hasDeviceTypeValues(keywordOnly) && hasDeviceTypeValues(deviceTypes))
1380     p << ", ";
1381 
1382   unsigned opIdx = 0;
1383   llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
1384     p << "{";
1385     auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
1386     if (boolAttr && boolAttr.getValue())
1387       p << "devnum: ";
1388     llvm::interleaveComma(
1389         llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
1390           p << operands[opIdx] << " : " << operands[opIdx].getType();
1391           ++opIdx;
1392         });
1393     p << "}";
1394     printSingleDeviceType(p, it.value());
1395   });
1396 
1397   p << ")";
1398 }
1399 
1400 static ParseResult parseDeviceTypeOperands(
1401     mlir::OpAsmParser &parser,
1402     llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
1403     llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes) {
1404   llvm::SmallVector<DeviceTypeAttr> attributes;
1405   if (failed(parser.parseCommaSeparatedList([&]() {
1406         if (parser.parseOperand(operands.emplace_back()) ||
1407             parser.parseColonType(types.emplace_back()))
1408           return failure();
1409         if (succeeded(parser.parseOptionalLSquare())) {
1410           if (parser.parseAttribute(attributes.emplace_back()) ||
1411               parser.parseRSquare())
1412             return failure();
1413         } else {
1414           attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1415               parser.getContext(), mlir::acc::DeviceType::None));
1416         }
1417         return success();
1418       })))
1419     return failure();
1420   llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1421                                                attributes.end());
1422   deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
1423   return success();
1424 }
1425 
1426 static void
1427 printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op,
1428                         mlir::OperandRange operands, mlir::TypeRange types,
1429                         std::optional<mlir::ArrayAttr> deviceTypes) {
1430   if (!hasDeviceTypeValues(deviceTypes))
1431     return;
1432   llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](auto it) {
1433     p << std::get<1>(it) << " : " << std::get<1>(it).getType();
1434     printSingleDeviceType(p, std::get<0>(it));
1435   });
1436 }
1437 
1438 static ParseResult parseDeviceTypeOperandsWithKeywordOnly(
1439     mlir::OpAsmParser &parser,
1440     llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
1441     llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1442     mlir::ArrayAttr &keywordOnlyDeviceType) {
1443 
1444   llvm::SmallVector<mlir::Attribute> keywordOnlyDeviceTypeAttributes;
1445   bool needCommaBeforeOperands = false;
1446 
1447   if (failed(parser.parseOptionalLParen())) {
1448     // Keyword only
1449     keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
1450         parser.getContext(), mlir::acc::DeviceType::None));
1451     keywordOnlyDeviceType =
1452         ArrayAttr::get(parser.getContext(), keywordOnlyDeviceTypeAttributes);
1453     return success();
1454   }
1455 
1456   // Parse keyword only attributes
1457   if (succeeded(parser.parseOptionalLSquare())) {
1458     // Parse keyword only attributes
1459     if (failed(parser.parseCommaSeparatedList([&]() {
1460           if (parser.parseAttribute(
1461                   keywordOnlyDeviceTypeAttributes.emplace_back()))
1462             return failure();
1463           return success();
1464         })))
1465       return failure();
1466     if (parser.parseRSquare())
1467       return failure();
1468     needCommaBeforeOperands = true;
1469   }
1470 
1471   if (needCommaBeforeOperands && failed(parser.parseComma()))
1472     return failure();
1473 
1474   llvm::SmallVector<DeviceTypeAttr> attributes;
1475   if (failed(parser.parseCommaSeparatedList([&]() {
1476         if (parser.parseOperand(operands.emplace_back()) ||
1477             parser.parseColonType(types.emplace_back()))
1478           return failure();
1479         if (succeeded(parser.parseOptionalLSquare())) {
1480           if (parser.parseAttribute(attributes.emplace_back()) ||
1481               parser.parseRSquare())
1482             return failure();
1483         } else {
1484           attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1485               parser.getContext(), mlir::acc::DeviceType::None));
1486         }
1487         return success();
1488       })))
1489     return failure();
1490 
1491   if (failed(parser.parseRParen()))
1492     return failure();
1493 
1494   llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1495                                                attributes.end());
1496   deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
1497   return success();
1498 }
1499 
1500 static void printDeviceTypeOperandsWithKeywordOnly(
1501     mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands,
1502     mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
1503     std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
1504 
1505   if (operands.begin() == operands.end() &&
1506       hasOnlyDeviceTypeNone(keywordOnlyDeviceTypes)) {
1507     return;
1508   }
1509 
1510   p << "(";
1511   printDeviceTypes(p, keywordOnlyDeviceTypes);
1512   if (hasDeviceTypeValues(keywordOnlyDeviceTypes) &&
1513       hasDeviceTypeValues(deviceTypes))
1514     p << ", ";
1515   printDeviceTypeOperands(p, op, operands, types, deviceTypes);
1516   p << ")";
1517 }
1518 
1519 static ParseResult
1520 parseCombinedConstructsLoop(mlir::OpAsmParser &parser,
1521                             mlir::acc::CombinedConstructsTypeAttr &attr) {
1522   if (succeeded(parser.parseOptionalKeyword("combined"))) {
1523     if (parser.parseLParen())
1524       return failure();
1525     if (succeeded(parser.parseOptionalKeyword("kernels"))) {
1526       attr = mlir::acc::CombinedConstructsTypeAttr::get(
1527           parser.getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
1528     } else if (succeeded(parser.parseOptionalKeyword("parallel"))) {
1529       attr = mlir::acc::CombinedConstructsTypeAttr::get(
1530           parser.getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
1531     } else if (succeeded(parser.parseOptionalKeyword("serial"))) {
1532       attr = mlir::acc::CombinedConstructsTypeAttr::get(
1533           parser.getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
1534     } else {
1535       parser.emitError(parser.getCurrentLocation(),
1536                        "expected compute construct name");
1537       return failure();
1538     }
1539     if (parser.parseRParen())
1540       return failure();
1541   }
1542   return success();
1543 }
1544 
1545 static void
1546 printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op,
1547                             mlir::acc::CombinedConstructsTypeAttr attr) {
1548   if (attr) {
1549     switch (attr.getValue()) {
1550     case mlir::acc::CombinedConstructsType::KernelsLoop:
1551       p << "combined(kernels)";
1552       break;
1553     case mlir::acc::CombinedConstructsType::ParallelLoop:
1554       p << "combined(parallel)";
1555       break;
1556     case mlir::acc::CombinedConstructsType::SerialLoop:
1557       p << "combined(serial)";
1558       break;
1559     };
1560   }
1561 }
1562 
1563 //===----------------------------------------------------------------------===//
1564 // SerialOp
1565 //===----------------------------------------------------------------------===//
1566 
1567 unsigned SerialOp::getNumDataOperands() {
1568   return getReductionOperands().size() + getPrivateOperands().size() +
1569          getFirstprivateOperands().size() + getDataClauseOperands().size();
1570 }
1571 
1572 Value SerialOp::getDataOperand(unsigned i) {
1573   unsigned numOptional = getAsyncOperands().size();
1574   numOptional += getIfCond() ? 1 : 0;
1575   numOptional += getSelfCond() ? 1 : 0;
1576   return getOperand(getWaitOperands().size() + numOptional + i);
1577 }
1578 
1579 bool acc::SerialOp::hasAsyncOnly() {
1580   return hasAsyncOnly(mlir::acc::DeviceType::None);
1581 }
1582 
1583 bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1584   return hasDeviceType(getAsyncOnly(), deviceType);
1585 }
1586 
1587 mlir::Value acc::SerialOp::getAsyncValue() {
1588   return getAsyncValue(mlir::acc::DeviceType::None);
1589 }
1590 
1591 mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1592   return getValueInDeviceTypeSegment(getAsyncOperandsDeviceType(),
1593                                      getAsyncOperands(), deviceType);
1594 }
1595 
1596 bool acc::SerialOp::hasWaitOnly() {
1597   return hasWaitOnly(mlir::acc::DeviceType::None);
1598 }
1599 
1600 bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1601   return hasDeviceType(getWaitOnly(), deviceType);
1602 }
1603 
1604 mlir::Operation::operand_range SerialOp::getWaitValues() {
1605   return getWaitValues(mlir::acc::DeviceType::None);
1606 }
1607 
1608 mlir::Operation::operand_range
1609 SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1610   return getWaitValuesWithoutDevnum(
1611       getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1612       getHasWaitDevnum(), deviceType);
1613 }
1614 
1615 mlir::Value SerialOp::getWaitDevnum() {
1616   return getWaitDevnum(mlir::acc::DeviceType::None);
1617 }
1618 
1619 mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1620   return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
1621                             getWaitOperandsSegments(), getHasWaitDevnum(),
1622                             deviceType);
1623 }
1624 
1625 LogicalResult acc::SerialOp::verify() {
1626   if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
1627           *this, getPrivatizations(), getPrivateOperands(), "private",
1628           "privatizations", /*checkOperandType=*/false)))
1629     return failure();
1630   if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
1631           *this, getFirstprivatizations(), getFirstprivateOperands(),
1632           "firstprivate", "firstprivatizations", /*checkOperandType=*/false)))
1633     return failure();
1634   if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
1635           *this, getReductionRecipes(), getReductionOperands(), "reduction",
1636           "reductions", false)))
1637     return failure();
1638 
1639   if (failed(verifyDeviceTypeAndSegmentCountMatch(
1640           *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1641           getWaitOperandsDeviceTypeAttr(), "wait")))
1642     return failure();
1643 
1644   if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
1645                                         getAsyncOperandsDeviceTypeAttr(),
1646                                         "async")))
1647     return failure();
1648 
1649   if (failed(checkWaitAndAsyncConflict<acc::SerialOp>(*this)))
1650     return failure();
1651 
1652   return checkDataOperands<acc::SerialOp>(*this, getDataClauseOperands());
1653 }
1654 
1655 //===----------------------------------------------------------------------===//
1656 // KernelsOp
1657 //===----------------------------------------------------------------------===//
1658 
1659 unsigned KernelsOp::getNumDataOperands() {
1660   return getDataClauseOperands().size();
1661 }
1662 
1663 Value KernelsOp::getDataOperand(unsigned i) {
1664   unsigned numOptional = getAsyncOperands().size();
1665   numOptional += getWaitOperands().size();
1666   numOptional += getNumGangs().size();
1667   numOptional += getNumWorkers().size();
1668   numOptional += getVectorLength().size();
1669   numOptional += getIfCond() ? 1 : 0;
1670   numOptional += getSelfCond() ? 1 : 0;
1671   return getOperand(numOptional + i);
1672 }
1673 
1674 bool acc::KernelsOp::hasAsyncOnly() {
1675   return hasAsyncOnly(mlir::acc::DeviceType::None);
1676 }
1677 
1678 bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1679   return hasDeviceType(getAsyncOnly(), deviceType);
1680 }
1681 
1682 mlir::Value acc::KernelsOp::getAsyncValue() {
1683   return getAsyncValue(mlir::acc::DeviceType::None);
1684 }
1685 
1686 mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1687   return getValueInDeviceTypeSegment(getAsyncOperandsDeviceType(),
1688                                      getAsyncOperands(), deviceType);
1689 }
1690 
1691 mlir::Value acc::KernelsOp::getNumWorkersValue() {
1692   return getNumWorkersValue(mlir::acc::DeviceType::None);
1693 }
1694 
1695 mlir::Value
1696 acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
1697   return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
1698                                      deviceType);
1699 }
1700 
1701 mlir::Value acc::KernelsOp::getVectorLengthValue() {
1702   return getVectorLengthValue(mlir::acc::DeviceType::None);
1703 }
1704 
1705 mlir::Value
1706 acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
1707   return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
1708                                      getVectorLength(), deviceType);
1709 }
1710 
1711 mlir::Operation::operand_range KernelsOp::getNumGangsValues() {
1712   return getNumGangsValues(mlir::acc::DeviceType::None);
1713 }
1714 
1715 mlir::Operation::operand_range
1716 KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
1717   return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
1718                                getNumGangsSegments(), deviceType);
1719 }
1720 
1721 bool acc::KernelsOp::hasWaitOnly() {
1722   return hasWaitOnly(mlir::acc::DeviceType::None);
1723 }
1724 
1725 bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1726   return hasDeviceType(getWaitOnly(), deviceType);
1727 }
1728 
1729 mlir::Operation::operand_range KernelsOp::getWaitValues() {
1730   return getWaitValues(mlir::acc::DeviceType::None);
1731 }
1732 
1733 mlir::Operation::operand_range
1734 KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1735   return getWaitValuesWithoutDevnum(
1736       getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1737       getHasWaitDevnum(), deviceType);
1738 }
1739 
1740 mlir::Value KernelsOp::getWaitDevnum() {
1741   return getWaitDevnum(mlir::acc::DeviceType::None);
1742 }
1743 
1744 mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1745   return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
1746                             getWaitOperandsSegments(), getHasWaitDevnum(),
1747                             deviceType);
1748 }
1749 
1750 LogicalResult acc::KernelsOp::verify() {
1751   if (failed(verifyDeviceTypeAndSegmentCountMatch(
1752           *this, getNumGangs(), getNumGangsSegmentsAttr(),
1753           getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
1754     return failure();
1755 
1756   if (failed(verifyDeviceTypeAndSegmentCountMatch(
1757           *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1758           getWaitOperandsDeviceTypeAttr(), "wait")))
1759     return failure();
1760 
1761   if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
1762                                         getNumWorkersDeviceTypeAttr(),
1763                                         "num_workers")))
1764     return failure();
1765 
1766   if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
1767                                         getVectorLengthDeviceTypeAttr(),
1768                                         "vector_length")))
1769     return failure();
1770 
1771   if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
1772                                         getAsyncOperandsDeviceTypeAttr(),
1773                                         "async")))
1774     return failure();
1775 
1776   if (failed(checkWaitAndAsyncConflict<acc::KernelsOp>(*this)))
1777     return failure();
1778 
1779   return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands());
1780 }
1781 
1782 //===----------------------------------------------------------------------===//
1783 // HostDataOp
1784 //===----------------------------------------------------------------------===//
1785 
1786 LogicalResult acc::HostDataOp::verify() {
1787   if (getDataClauseOperands().empty())
1788     return emitError("at least one operand must appear on the host_data "
1789                      "operation");
1790 
1791   for (mlir::Value operand : getDataClauseOperands())
1792     if (!mlir::isa<acc::UseDeviceOp>(operand.getDefiningOp()))
1793       return emitError("expect data entry operation as defining op");
1794   return success();
1795 }
1796 
1797 void acc::HostDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
1798                                                   MLIRContext *context) {
1799   results.add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
1800 }
1801 
1802 //===----------------------------------------------------------------------===//
1803 // LoopOp
1804 //===----------------------------------------------------------------------===//
1805 
1806 static ParseResult parseGangValue(
1807     OpAsmParser &parser, llvm::StringRef keyword,
1808     llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
1809     llvm::SmallVectorImpl<Type> &types,
1810     llvm::SmallVector<GangArgTypeAttr> &attributes, GangArgTypeAttr gangArgType,
1811     bool &needCommaBetweenValues, bool &newValue) {
1812   if (succeeded(parser.parseOptionalKeyword(keyword))) {
1813     if (parser.parseEqual())
1814       return failure();
1815     if (parser.parseOperand(operands.emplace_back()) ||
1816         parser.parseColonType(types.emplace_back()))
1817       return failure();
1818     attributes.push_back(gangArgType);
1819     needCommaBetweenValues = true;
1820     newValue = true;
1821   }
1822   return success();
1823 }
1824 
1825 static ParseResult parseGangClause(
1826     OpAsmParser &parser,
1827     llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &gangOperands,
1828     llvm::SmallVectorImpl<Type> &gangOperandsType, mlir::ArrayAttr &gangArgType,
1829     mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments,
1830     mlir::ArrayAttr &gangOnlyDeviceType) {
1831   llvm::SmallVector<GangArgTypeAttr> gangArgTypeAttributes;
1832   llvm::SmallVector<mlir::Attribute> deviceTypeAttributes;
1833   llvm::SmallVector<mlir::Attribute> gangOnlyDeviceTypeAttributes;
1834   llvm::SmallVector<int32_t> seg;
1835   bool needCommaBetweenValues = false;
1836   bool needCommaBeforeOperands = false;
1837 
1838   if (failed(parser.parseOptionalLParen())) {
1839     // Gang only keyword
1840     gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
1841         parser.getContext(), mlir::acc::DeviceType::None));
1842     gangOnlyDeviceType =
1843         ArrayAttr::get(parser.getContext(), gangOnlyDeviceTypeAttributes);
1844     return success();
1845   }
1846 
1847   // Parse gang only attributes
1848   if (succeeded(parser.parseOptionalLSquare())) {
1849     // Parse gang only attributes
1850     if (failed(parser.parseCommaSeparatedList([&]() {
1851           if (parser.parseAttribute(
1852                   gangOnlyDeviceTypeAttributes.emplace_back()))
1853             return failure();
1854           return success();
1855         })))
1856       return failure();
1857     if (parser.parseRSquare())
1858       return failure();
1859     needCommaBeforeOperands = true;
1860   }
1861 
1862   auto argNum = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
1863                                                 mlir::acc::GangArgType::Num);
1864   auto argDim = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
1865                                                 mlir::acc::GangArgType::Dim);
1866   auto argStatic = mlir::acc::GangArgTypeAttr::get(
1867       parser.getContext(), mlir::acc::GangArgType::Static);
1868 
1869   do {
1870     if (needCommaBeforeOperands) {
1871       needCommaBeforeOperands = false;
1872       continue;
1873     }
1874 
1875     if (failed(parser.parseLBrace()))
1876       return failure();
1877 
1878     int32_t crtOperandsSize = gangOperands.size();
1879     while (true) {
1880       bool newValue = false;
1881       bool needValue = false;
1882       if (needCommaBetweenValues) {
1883         if (succeeded(parser.parseOptionalComma()))
1884           needValue = true; // expect a new value after comma.
1885         else
1886           break;
1887       }
1888 
1889       if (failed(parseGangValue(parser, LoopOp::getGangNumKeyword(),
1890                                 gangOperands, gangOperandsType,
1891                                 gangArgTypeAttributes, argNum,
1892                                 needCommaBetweenValues, newValue)))
1893         return failure();
1894       if (failed(parseGangValue(parser, LoopOp::getGangDimKeyword(),
1895                                 gangOperands, gangOperandsType,
1896                                 gangArgTypeAttributes, argDim,
1897                                 needCommaBetweenValues, newValue)))
1898         return failure();
1899       if (failed(parseGangValue(parser, LoopOp::getGangStaticKeyword(),
1900                                 gangOperands, gangOperandsType,
1901                                 gangArgTypeAttributes, argStatic,
1902                                 needCommaBetweenValues, newValue)))
1903         return failure();
1904 
1905       if (!newValue && needValue) {
1906         parser.emitError(parser.getCurrentLocation(),
1907                          "new value expected after comma");
1908         return failure();
1909       }
1910 
1911       if (!newValue)
1912         break;
1913     }
1914 
1915     if (gangOperands.empty())
1916       return parser.emitError(
1917           parser.getCurrentLocation(),
1918           "expect at least one of num, dim or static values");
1919 
1920     if (failed(parser.parseRBrace()))
1921       return failure();
1922 
1923     if (succeeded(parser.parseOptionalLSquare())) {
1924       if (parser.parseAttribute(deviceTypeAttributes.emplace_back()) ||
1925           parser.parseRSquare())
1926         return failure();
1927     } else {
1928       deviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
1929           parser.getContext(), mlir::acc::DeviceType::None));
1930     }
1931 
1932     seg.push_back(gangOperands.size() - crtOperandsSize);
1933 
1934   } while (succeeded(parser.parseOptionalComma()));
1935 
1936   if (failed(parser.parseRParen()))
1937     return failure();
1938 
1939   llvm::SmallVector<mlir::Attribute> arrayAttr(gangArgTypeAttributes.begin(),
1940                                                gangArgTypeAttributes.end());
1941   gangArgType = ArrayAttr::get(parser.getContext(), arrayAttr);
1942   deviceType = ArrayAttr::get(parser.getContext(), deviceTypeAttributes);
1943 
1944   llvm::SmallVector<mlir::Attribute> gangOnlyAttr(
1945       gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
1946   gangOnlyDeviceType = ArrayAttr::get(parser.getContext(), gangOnlyAttr);
1947 
1948   segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
1949   return success();
1950 }
1951 
1952 void printGangClause(OpAsmPrinter &p, Operation *op,
1953                      mlir::OperandRange operands, mlir::TypeRange types,
1954                      std::optional<mlir::ArrayAttr> gangArgTypes,
1955                      std::optional<mlir::ArrayAttr> deviceTypes,
1956                      std::optional<mlir::DenseI32ArrayAttr> segments,
1957                      std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
1958 
1959   if (operands.begin() == operands.end() &&
1960       hasOnlyDeviceTypeNone(gangOnlyDeviceTypes)) {
1961     return;
1962   }
1963 
1964   p << "(";
1965 
1966   printDeviceTypes(p, gangOnlyDeviceTypes);
1967 
1968   if (hasDeviceTypeValues(gangOnlyDeviceTypes) &&
1969       hasDeviceTypeValues(deviceTypes))
1970     p << ", ";
1971 
1972   if (hasDeviceTypeValues(deviceTypes)) {
1973     unsigned opIdx = 0;
1974     llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
1975       p << "{";
1976       llvm::interleaveComma(
1977           llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
1978             auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
1979                 (*gangArgTypes)[opIdx]);
1980             if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
1981               p << LoopOp::getGangNumKeyword();
1982             else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
1983               p << LoopOp::getGangDimKeyword();
1984             else if (gangArgTypeAttr.getValue() ==
1985                      mlir::acc::GangArgType::Static)
1986               p << LoopOp::getGangStaticKeyword();
1987             p << "=" << operands[opIdx] << " : " << operands[opIdx].getType();
1988             ++opIdx;
1989           });
1990       p << "}";
1991       printSingleDeviceType(p, it.value());
1992     });
1993   }
1994   p << ")";
1995 }
1996 
1997 bool hasDuplicateDeviceTypes(
1998     std::optional<mlir::ArrayAttr> segments,
1999     llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
2000   if (!segments)
2001     return false;
2002   for (auto attr : *segments) {
2003     auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2004     if (!deviceTypes.insert(deviceTypeAttr.getValue()).second)
2005       return true;
2006   }
2007   return false;
2008 }
2009 
2010 /// Check for duplicates in the DeviceType array attribute.
2011 LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes) {
2012   llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
2013   if (!deviceTypes)
2014     return success();
2015   for (auto attr : deviceTypes) {
2016     auto deviceTypeAttr =
2017         mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
2018     if (!deviceTypeAttr)
2019       return failure();
2020     if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second)
2021       return failure();
2022   }
2023   return success();
2024 }
2025 
/// Verify acc.loop. The checks run in this order:
/// - inclusiveUpperbound size versus upperbound size;
/// - collapse versus collapseDeviceType;
/// - gang operand, arg-type and device-type bookkeeping;
/// - worker and vector device-type bookkeeping;
/// - tile segments;
/// - mutual exclusion of auto, independent and seq;
/// - incompatibility of seq with gang, worker and vector;
/// - the private and reduction recipe lists;
/// - the combined-construct attribute;
/// - a non-empty region.
LogicalResult acc::LoopOp::verify() {
  // When both are present, inclusiveUpperbound must have one entry per
  // upperbound operand.
  if (!getUpperbound().empty() && getInclusiveUpperbound() &&
      (getUpperbound().size() != getInclusiveUpperbound()->size()))
    return emitError() << "inclusiveUpperbound size is expected to be the same"
                       << " as upperbound size";

  // Check collapse
  // collapse values are paired positionally with collapseDeviceType entries,
  // so the device-type list must exist, match in length and be duplicate-free.
  if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
    return emitOpError() << "collapse device_type attr must be define when"
                         << " collapse attr is present";

  if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
      getCollapseAttr().getValue().size() !=
          getCollapseDeviceTypeAttr().getValue().size())
    return emitOpError() << "collapse attribute count must match collapse"
                         << " device_type count";
  if (failed(checkDeviceTypes(getCollapseDeviceTypeAttr())))
    return emitOpError()
           << "duplicate device_type found in collapseDeviceType attribute";

  // Check gang
  // Every gang operand needs a matching entry in gangOperandsArgType.
  if (!getGangOperands().empty()) {
    if (!getGangOperandsArgType())
      return emitOpError() << "gangOperandsArgType attribute must be defined"
                           << " when gang operands are present";

    if (getGangOperands().size() !=
        getGangOperandsArgTypeAttr().getValue().size())
      return emitOpError() << "gangOperandsArgType attribute count must match"
                           << " gangOperands count";
  }
  if (getGangAttr() && failed(checkDeviceTypes(getGangAttr())))
    return emitOpError() << "duplicate device_type found in gang attribute";

  if (failed(verifyDeviceTypeAndSegmentCountMatch(
          *this, getGangOperands(), getGangOperandsSegmentsAttr(),
          getGangOperandsDeviceTypeAttr(), "gang")))
    return failure();

  // Check worker
  if (failed(checkDeviceTypes(getWorkerAttr())))
    return emitOpError() << "duplicate device_type found in worker attribute";
  if (failed(checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr())))
    return emitOpError() << "duplicate device_type found in "
                            "workerNumOperandsDeviceType attribute";
  if (failed(verifyDeviceTypeCountMatch(*this, getWorkerNumOperands(),
                                        getWorkerNumOperandsDeviceTypeAttr(),
                                        "worker")))
    return failure();

  // Check vector
  if (failed(checkDeviceTypes(getVectorAttr())))
    return emitOpError() << "duplicate device_type found in vector attribute";
  if (failed(checkDeviceTypes(getVectorOperandsDeviceTypeAttr())))
    return emitOpError() << "duplicate device_type found in "
                            "vectorOperandsDeviceType attribute";
  if (failed(verifyDeviceTypeCountMatch(*this, getVectorOperands(),
                                        getVectorOperandsDeviceTypeAttr(),
                                        "vector")))
    return failure();

  // Check tile (segmented per device type).
  if (failed(verifyDeviceTypeAndSegmentCountMatch(
          *this, getTileOperands(), getTileOperandsSegmentsAttr(),
          getTileOperandsDeviceTypeAttr(), "tile")))
    return failure();

  // auto, independent and seq attribute are mutually exclusive.
  // One set is shared by the three calls, so a device type listed in more than
  // one of these attributes (or twice in the same one) is rejected.
  llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
  if (hasDuplicateDeviceTypes(getAuto_(), deviceTypes) ||
      hasDuplicateDeviceTypes(getIndependent(), deviceTypes) ||
      hasDuplicateDeviceTypes(getSeq(), deviceTypes)) {
    return emitError() << "only one of \"" << acc::LoopOp::getAutoAttrStrName()
                       << "\", " << getIndependentAttrName() << ", "
                       << getSeqAttrName()
                       << " can be present at the same time";
  }

  // Gang, worker and vector are incompatible with seq.
  // For each device type listed in seq, reject any gang/worker/vector keyword
  // or value recorded for that same device type.
  if (getSeqAttr()) {
    for (auto attr : getSeqAttr()) {
      auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
      if (hasVector(deviceTypeAttr.getValue()) ||
          getVectorValue(deviceTypeAttr.getValue()) ||
          hasWorker(deviceTypeAttr.getValue()) ||
          getWorkerValue(deviceTypeAttr.getValue()) ||
          hasGang(deviceTypeAttr.getValue()) ||
          getGangValue(mlir::acc::GangArgType::Num,
                       deviceTypeAttr.getValue()) ||
          getGangValue(mlir::acc::GangArgType::Dim,
                       deviceTypeAttr.getValue()) ||
          getGangValue(mlir::acc::GangArgType::Static,
                       deviceTypeAttr.getValue()))
        return emitError()
               << "gang, worker or vector cannot appear with the seq attr";
    }
  }

  // Recipe symbol lists must line up with their operand lists.
  if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
          *this, getPrivatizations(), getPrivateOperands(), "private",
          "privatizations", false)))
    return failure();

  if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
          *this, getReductionRecipes(), getReductionOperands(), "reduction",
          "reductions", false)))
    return failure();

  // Only the three loop-combined constructs are accepted for `combined`.
  if (getCombined().has_value() &&
      (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
       getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
       getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
    return emitError("unexpected combined constructs attribute");
  }

  // Check non-empty body().
  if (getRegion().empty())
    return emitError("expected non-empty body.");

  return success();
}
2146 
2147 unsigned LoopOp::getNumDataOperands() {
2148   return getReductionOperands().size() + getPrivateOperands().size();
2149 }
2150 
2151 Value LoopOp::getDataOperand(unsigned i) {
2152   unsigned numOptional =
2153       getLowerbound().size() + getUpperbound().size() + getStep().size();
2154   numOptional += getGangOperands().size();
2155   numOptional += getVectorOperands().size();
2156   numOptional += getWorkerNumOperands().size();
2157   numOptional += getTileOperands().size();
2158   numOptional += getCacheOperands().size();
2159   return getOperand(numOptional + i);
2160 }
2161 
2162 bool LoopOp::hasAuto() { return hasAuto(mlir::acc::DeviceType::None); }
2163 
2164 bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
2165   return hasDeviceType(getAuto_(), deviceType);
2166 }
2167 
2168 bool LoopOp::hasIndependent() {
2169   return hasIndependent(mlir::acc::DeviceType::None);
2170 }
2171 
2172 bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
2173   return hasDeviceType(getIndependent(), deviceType);
2174 }
2175 
2176 bool LoopOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
2177 
2178 bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
2179   return hasDeviceType(getSeq(), deviceType);
2180 }
2181 
2182 mlir::Value LoopOp::getVectorValue() {
2183   return getVectorValue(mlir::acc::DeviceType::None);
2184 }
2185 
2186 mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
2187   return getValueInDeviceTypeSegment(getVectorOperandsDeviceType(),
2188                                      getVectorOperands(), deviceType);
2189 }
2190 
2191 bool LoopOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
2192 
2193 bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
2194   return hasDeviceType(getVector(), deviceType);
2195 }
2196 
2197 mlir::Value LoopOp::getWorkerValue() {
2198   return getWorkerValue(mlir::acc::DeviceType::None);
2199 }
2200 
2201 mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
2202   return getValueInDeviceTypeSegment(getWorkerNumOperandsDeviceType(),
2203                                      getWorkerNumOperands(), deviceType);
2204 }
2205 
2206 bool LoopOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
2207 
2208 bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
2209   return hasDeviceType(getWorker(), deviceType);
2210 }
2211 
2212 mlir::Operation::operand_range LoopOp::getTileValues() {
2213   return getTileValues(mlir::acc::DeviceType::None);
2214 }
2215 
2216 mlir::Operation::operand_range
2217 LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
2218   return getValuesFromSegments(getTileOperandsDeviceType(), getTileOperands(),
2219                                getTileOperandsSegments(), deviceType);
2220 }
2221 
2222 std::optional<int64_t> LoopOp::getCollapseValue() {
2223   return getCollapseValue(mlir::acc::DeviceType::None);
2224 }
2225 
2226 std::optional<int64_t>
2227 LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
2228   if (!getCollapseAttr())
2229     return std::nullopt;
2230   if (auto pos = findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
2231     auto intAttr =
2232         mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
2233     return intAttr.getValue().getZExtValue();
2234   }
2235   return std::nullopt;
2236 }
2237 
2238 mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
2239   return getGangValue(gangArgType, mlir::acc::DeviceType::None);
2240 }
2241 
2242 mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
2243                                  mlir::acc::DeviceType deviceType) {
2244   if (getGangOperands().empty())
2245     return {};
2246   if (auto pos = findSegment(*getGangOperandsDeviceType(), deviceType)) {
2247     int32_t nbOperandsBefore = 0;
2248     for (unsigned i = 0; i < *pos; ++i)
2249       nbOperandsBefore += (*getGangOperandsSegments())[i];
2250     mlir::Operation::operand_range values =
2251         getGangOperands()
2252             .drop_front(nbOperandsBefore)
2253             .take_front((*getGangOperandsSegments())[*pos]);
2254 
2255     int32_t argTypeIdx = nbOperandsBefore;
2256     for (auto value : values) {
2257       auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
2258           (*getGangOperandsArgType())[argTypeIdx]);
2259       if (gangArgTypeAttr.getValue() == gangArgType)
2260         return value;
2261       ++argTypeIdx;
2262     }
2263   }
2264   return {};
2265 }
2266 
2267 bool LoopOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
2268 
2269 bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
2270   return hasDeviceType(getGang(), deviceType);
2271 }
2272 
2273 llvm::SmallVector<mlir::Region *> acc::LoopOp::getLoopRegions() {
2274   return {&getRegion()};
2275 }
2276 
/// Parses the optional loop-control header of an acc.loop, followed by its
/// region:
///
///   loop-control ::= (`control` `(` ssa-id-and-type-list `)` `=`
///                     `(` ssa-use-list `:` type-list `)` `to`
///                     `(` ssa-use-list `:` type-list `)` `step`
///                     `(` ssa-use-list `:` type-list `)`)?
///                    region
///
/// The induction variables declared in `control(...)` become the entry block
/// arguments of `region`. Each of the lower-bound, upper-bound and step lists
/// must contain exactly as many operands as there are induction variables.
/// When the `control` keyword is absent, the region is parsed without entry
/// block arguments and the bound/step output lists are left untouched.
ParseResult
parseLoopControl(OpAsmParser &parser, Region &region,
                 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &lowerbound,
                 SmallVectorImpl<Type> &lowerboundType,
                 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &upperbound,
                 SmallVectorImpl<Type> &upperboundType,
                 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &step,
                 SmallVectorImpl<Type> &stepType) {

  SmallVector<OpAsmParser::Argument> inductionVars;
  if (succeeded(
          parser.parseOptionalKeyword(acc::LoopOp::getControlKeyword()))) {
    // The parse steps are chained with `||` so that the first failing step
    // short-circuits the rest; the order of the calls is the textual order.
    // The induction variable count fixes the required length of the
    // lower-bound, upper-bound and step operand lists.
    if (parser.parseLParen() ||
        parser.parseArgumentList(inductionVars, OpAsmParser::Delimiter::None,
                                 /*allowType=*/true) ||
        parser.parseRParen() || parser.parseEqual() || parser.parseLParen() ||
        parser.parseOperandList(lowerbound, inductionVars.size(),
                                OpAsmParser::Delimiter::None) ||
        parser.parseColonTypeList(lowerboundType) || parser.parseRParen() ||
        parser.parseKeyword("to") || parser.parseLParen() ||
        parser.parseOperandList(upperbound, inductionVars.size(),
                                OpAsmParser::Delimiter::None) ||
        parser.parseColonTypeList(upperboundType) || parser.parseRParen() ||
        parser.parseKeyword("step") || parser.parseLParen() ||
        parser.parseOperandList(step, inductionVars.size(),
                                OpAsmParser::Delimiter::None) ||
        parser.parseColonTypeList(stepType) || parser.parseRParen())
      return failure();
  }
  // The induction variables (possibly none) become the entry block arguments.
  return parser.parseRegion(region, inductionVars);
}
2312 
2313 void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region,
2314                       ValueRange lowerbound, TypeRange lowerboundType,
2315                       ValueRange upperbound, TypeRange upperboundType,
2316                       ValueRange steps, TypeRange stepType) {
2317   ValueRange regionArgs = region.front().getArguments();
2318   if (!regionArgs.empty()) {
2319     p << acc::LoopOp::getControlKeyword() << "(";
2320     llvm::interleaveComma(regionArgs, p,
2321                           [&p](Value v) { p << v << " : " << v.getType(); });
2322     p << ") = (" << lowerbound << " : " << lowerboundType << ") to ("
2323       << upperbound << " : " << upperboundType << ") " << " step (" << steps
2324       << " : " << stepType << ") ";
2325   }
2326   p.printRegion(region, /*printEntryBlockArgs=*/false);
2327 }
2328 
2329 //===----------------------------------------------------------------------===//
2330 // DataOp
2331 //===----------------------------------------------------------------------===//
2332 
2333 LogicalResult acc::DataOp::verify() {
2334   // 2.6.5. Data Construct restriction
2335   // At least one copy, copyin, copyout, create, no_create, present, deviceptr,
2336   // attach, or default clause must appear on a data construct.
2337   if (getOperands().empty() && !getDefaultAttr())
2338     return emitError("at least one operand or the default attribute "
2339                      "must appear on the data operation");
2340 
2341   for (mlir::Value operand : getDataClauseOperands())
2342     if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
2343                    acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
2344                    acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
2345             operand.getDefiningOp()))
2346       return emitError("expect data entry/exit operation or acc.getdeviceptr "
2347                        "as defining op");
2348 
2349   if (failed(checkWaitAndAsyncConflict<acc::DataOp>(*this)))
2350     return failure();
2351 
2352   return success();
2353 }
2354 
2355 unsigned DataOp::getNumDataOperands() { return getDataClauseOperands().size(); }
2356 
2357 Value DataOp::getDataOperand(unsigned i) {
2358   unsigned numOptional = getIfCond() ? 1 : 0;
2359   numOptional += getAsyncOperands().size() ? 1 : 0;
2360   numOptional += getWaitOperands().size();
2361   return getOperand(numOptional + i);
2362 }
2363 
2364 bool acc::DataOp::hasAsyncOnly() {
2365   return hasAsyncOnly(mlir::acc::DeviceType::None);
2366 }
2367 
2368 bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2369   return hasDeviceType(getAsyncOnly(), deviceType);
2370 }
2371 
2372 mlir::Value DataOp::getAsyncValue() {
2373   return getAsyncValue(mlir::acc::DeviceType::None);
2374 }
2375 
2376 mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2377   return getValueInDeviceTypeSegment(getAsyncOperandsDeviceType(),
2378                                      getAsyncOperands(), deviceType);
2379 }
2380 
2381 bool DataOp::hasWaitOnly() { return hasWaitOnly(mlir::acc::DeviceType::None); }
2382 
2383 bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2384   return hasDeviceType(getWaitOnly(), deviceType);
2385 }
2386 
2387 mlir::Operation::operand_range DataOp::getWaitValues() {
2388   return getWaitValues(mlir::acc::DeviceType::None);
2389 }
2390 
2391 mlir::Operation::operand_range
2392 DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2393   return getWaitValuesWithoutDevnum(
2394       getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2395       getHasWaitDevnum(), deviceType);
2396 }
2397 
2398 mlir::Value DataOp::getWaitDevnum() {
2399   return getWaitDevnum(mlir::acc::DeviceType::None);
2400 }
2401 
2402 mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2403   return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2404                             getWaitOperandsSegments(), getHasWaitDevnum(),
2405                             deviceType);
2406 }
2407 
2408 //===----------------------------------------------------------------------===//
2409 // ExitDataOp
2410 //===----------------------------------------------------------------------===//
2411 
2412 LogicalResult acc::ExitDataOp::verify() {
2413   // 2.6.6. Data Exit Directive restriction
2414   // At least one copyout, delete, or detach clause must appear on an exit data
2415   // directive.
2416   if (getDataClauseOperands().empty())
2417     return emitError("at least one operand must be present in dataOperands on "
2418                      "the exit data operation");
2419 
2420   // The async attribute represent the async clause without value. Therefore the
2421   // attribute and operand cannot appear at the same time.
2422   if (getAsyncOperand() && getAsync())
2423     return emitError("async attribute cannot appear with asyncOperand");
2424 
2425   // The wait attribute represent the wait clause without values. Therefore the
2426   // attribute and operands cannot appear at the same time.
2427   if (!getWaitOperands().empty() && getWait())
2428     return emitError("wait attribute cannot appear with waitOperands");
2429 
2430   if (getWaitDevnum() && getWaitOperands().empty())
2431     return emitError("wait_devnum cannot appear without waitOperands");
2432 
2433   return success();
2434 }
2435 
2436 unsigned ExitDataOp::getNumDataOperands() {
2437   return getDataClauseOperands().size();
2438 }
2439 
2440 Value ExitDataOp::getDataOperand(unsigned i) {
2441   unsigned numOptional = getIfCond() ? 1 : 0;
2442   numOptional += getAsyncOperand() ? 1 : 0;
2443   numOptional += getWaitDevnum() ? 1 : 0;
2444   return getOperand(getWaitOperands().size() + numOptional + i);
2445 }
2446 
2447 void ExitDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
2448                                              MLIRContext *context) {
2449   results.add<RemoveConstantIfCondition<ExitDataOp>>(context);
2450 }
2451 
2452 //===----------------------------------------------------------------------===//
2453 // EnterDataOp
2454 //===----------------------------------------------------------------------===//
2455 
2456 LogicalResult acc::EnterDataOp::verify() {
2457   // 2.6.6. Data Enter Directive restriction
2458   // At least one copyin, create, or attach clause must appear on an enter data
2459   // directive.
2460   if (getDataClauseOperands().empty())
2461     return emitError("at least one operand must be present in dataOperands on "
2462                      "the enter data operation");
2463 
2464   // The async attribute represent the async clause without value. Therefore the
2465   // attribute and operand cannot appear at the same time.
2466   if (getAsyncOperand() && getAsync())
2467     return emitError("async attribute cannot appear with asyncOperand");
2468 
2469   // The wait attribute represent the wait clause without values. Therefore the
2470   // attribute and operands cannot appear at the same time.
2471   if (!getWaitOperands().empty() && getWait())
2472     return emitError("wait attribute cannot appear with waitOperands");
2473 
2474   if (getWaitDevnum() && getWaitOperands().empty())
2475     return emitError("wait_devnum cannot appear without waitOperands");
2476 
2477   for (mlir::Value operand : getDataClauseOperands())
2478     if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
2479             operand.getDefiningOp()))
2480       return emitError("expect data entry operation as defining op");
2481 
2482   return success();
2483 }
2484 
2485 unsigned EnterDataOp::getNumDataOperands() {
2486   return getDataClauseOperands().size();
2487 }
2488 
2489 Value EnterDataOp::getDataOperand(unsigned i) {
2490   unsigned numOptional = getIfCond() ? 1 : 0;
2491   numOptional += getAsyncOperand() ? 1 : 0;
2492   numOptional += getWaitDevnum() ? 1 : 0;
2493   return getOperand(getWaitOperands().size() + numOptional + i);
2494 }
2495 
2496 void EnterDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
2497                                               MLIRContext *context) {
2498   results.add<RemoveConstantIfCondition<EnterDataOp>>(context);
2499 }
2500 
2501 //===----------------------------------------------------------------------===//
2502 // AtomicReadOp
2503 //===----------------------------------------------------------------------===//
2504 
2505 LogicalResult AtomicReadOp::verify() { return verifyCommon(); }
2506 
2507 //===----------------------------------------------------------------------===//
2508 // AtomicWriteOp
2509 //===----------------------------------------------------------------------===//
2510 
2511 LogicalResult AtomicWriteOp::verify() { return verifyCommon(); }
2512 
2513 //===----------------------------------------------------------------------===//
2514 // AtomicUpdateOp
2515 //===----------------------------------------------------------------------===//
2516 
2517 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
2518                                            PatternRewriter &rewriter) {
2519   if (op.isNoOp()) {
2520     rewriter.eraseOp(op);
2521     return success();
2522   }
2523 
2524   if (Value writeVal = op.getWriteOpVal()) {
2525     rewriter.replaceOpWithNewOp<AtomicWriteOp>(op, op.getX(), writeVal);
2526     return success();
2527   }
2528 
2529   return failure();
2530 }
2531 
2532 LogicalResult AtomicUpdateOp::verify() { return verifyCommon(); }
2533 
2534 LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
2535 
2536 //===----------------------------------------------------------------------===//
2537 // AtomicCaptureOp
2538 //===----------------------------------------------------------------------===//
2539 
2540 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
2541   if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
2542     return op;
2543   return dyn_cast<AtomicReadOp>(getSecondOp());
2544 }
2545 
2546 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
2547   if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
2548     return op;
2549   return dyn_cast<AtomicWriteOp>(getSecondOp());
2550 }
2551 
2552 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
2553   if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
2554     return op;
2555   return dyn_cast<AtomicUpdateOp>(getSecondOp());
2556 }
2557 
2558 LogicalResult AtomicCaptureOp::verifyRegions() { return verifyRegionsCommon(); }
2559 
2560 //===----------------------------------------------------------------------===//
2561 // DeclareEnterOp
2562 //===----------------------------------------------------------------------===//
2563 
2564 template <typename Op>
2565 static LogicalResult
2566 checkDeclareOperands(Op &op, const mlir::ValueRange &operands,
2567                      bool requireAtLeastOneOperand = true) {
2568   if (operands.empty() && requireAtLeastOneOperand)
2569     return emitError(
2570         op->getLoc(),
2571         "at least one operand must appear on the declare operation");
2572 
2573   for (mlir::Value operand : operands) {
2574     if (!mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
2575                    acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
2576                    acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
2577             operand.getDefiningOp()))
2578       return op.emitError(
2579           "expect valid declare data entry operation or acc.getdeviceptr "
2580           "as defining op");
2581 
2582     mlir::Value varPtr{getVarPtr(operand.getDefiningOp())};
2583     assert(varPtr && "declare operands can only be data entry operations which "
2584                      "must have varPtr");
2585     std::optional<mlir::acc::DataClause> dataClauseOptional{
2586         getDataClause(operand.getDefiningOp())};
2587     assert(dataClauseOptional.has_value() &&
2588            "declare operands can only be data entry operations which must have "
2589            "dataClause");
2590 
2591     // If varPtr has no defining op - there is nothing to check further.
2592     if (!varPtr.getDefiningOp())
2593       continue;
2594 
2595     // Check that the varPtr has a declare attribute.
2596     auto declareAttribute{
2597         varPtr.getDefiningOp()->getAttr(mlir::acc::getDeclareAttrName())};
2598     if (!declareAttribute)
2599       return op.emitError(
2600           "expect declare attribute on variable in declare operation");
2601 
2602     auto declAttr = mlir::cast<mlir::acc::DeclareAttr>(declareAttribute);
2603     if (declAttr.getDataClause().getValue() != dataClauseOptional.value())
2604       return op.emitError(
2605           "expect matching declare attribute on variable in declare operation");
2606 
2607     // If the variable is marked with implicit attribute, the matching declare
2608     // data action must also be marked implicit. The reverse is not checked
2609     // since implicit data action may be inserted to do actions like updating
2610     // device copy, in which case the variable is not necessarily implicitly
2611     // declare'd.
2612     if (declAttr.getImplicit() &&
2613         declAttr.getImplicit() != acc::getImplicitFlag(operand.getDefiningOp()))
2614       return op.emitError(
2615           "implicitness must match between declare op and flag on variable");
2616   }
2617 
2618   return success();
2619 }
2620 
2621 LogicalResult acc::DeclareEnterOp::verify() {
2622   return checkDeclareOperands(*this, this->getDataClauseOperands());
2623 }
2624 
2625 //===----------------------------------------------------------------------===//
2626 // DeclareExitOp
2627 //===----------------------------------------------------------------------===//
2628 
2629 LogicalResult acc::DeclareExitOp::verify() {
2630   if (getToken())
2631     return checkDeclareOperands(*this, this->getDataClauseOperands(),
2632                                 /*requireAtLeastOneOperand=*/false);
2633   return checkDeclareOperands(*this, this->getDataClauseOperands());
2634 }
2635 
2636 //===----------------------------------------------------------------------===//
2637 // DeclareOp
2638 //===----------------------------------------------------------------------===//
2639 
2640 LogicalResult acc::DeclareOp::verify() {
2641   return checkDeclareOperands(*this, this->getDataClauseOperands());
2642 }
2643 
2644 //===----------------------------------------------------------------------===//
2645 // RoutineOp
2646 //===----------------------------------------------------------------------===//
2647 
2648 static unsigned getParallelismForDeviceType(acc::RoutineOp op,
2649                                             acc::DeviceType dtype) {
2650   unsigned parallelism = 0;
2651   parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
2652   parallelism += op.hasWorker(dtype) ? 1 : 0;
2653   parallelism += op.hasVector(dtype) ? 1 : 0;
2654   parallelism += op.hasSeq(dtype) ? 1 : 0;
2655   return parallelism;
2656 }
2657 
2658 LogicalResult acc::RoutineOp::verify() {
2659   unsigned baseParallelism =
2660       getParallelismForDeviceType(*this, acc::DeviceType::None);
2661 
2662   if (baseParallelism > 1)
2663     return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
2664                           "be present at the same time";
2665 
2666   for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
2667        ++dtypeInt) {
2668     auto dtype = static_cast<acc::DeviceType>(dtypeInt);
2669     if (dtype == acc::DeviceType::None)
2670       continue;
2671     unsigned parallelism = getParallelismForDeviceType(*this, dtype);
2672 
2673     if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
2674       return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
2675                             "be present at the same time";
2676   }
2677 
2678   return success();
2679 }
2680 
/// Parses the value list of the `bind` clause of acc.routine:
///
///   bind-list  ::= bind-entry (`,` bind-entry)*
///   bind-entry ::= attribute (`[` device-type-attribute `]`)?
///
/// Each entry appends one element to `bindName` and one to `deviceTypes`. An
/// entry without a bracketed device type is recorded with DeviceType::None,
/// so both output arrays always have the same length.
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindName,
                                 mlir::ArrayAttr &deviceTypes) {
  llvm::SmallVector<mlir::Attribute> bindNameAttrs;
  llvm::SmallVector<mlir::Attribute> deviceTypeAttrs;

  if (failed(parser.parseCommaSeparatedList([&]() {
        if (parser.parseAttribute(bindNameAttrs.emplace_back()))
          return failure();
        // No `[`: the name applies to the default (none) device type.
        if (failed(parser.parseOptionalLSquare())) {
          deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
              parser.getContext(), mlir::acc::DeviceType::None));
        } else {
          if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
              parser.parseRSquare())
            return failure();
        }
        return success();
      })))
    return failure();

  bindName = ArrayAttr::get(parser.getContext(), bindNameAttrs);
  deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);

  return success();
}
2706 
2707 static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op,
2708                           std::optional<mlir::ArrayAttr> bindName,
2709                           std::optional<mlir::ArrayAttr> deviceTypes) {
2710   llvm::interleaveComma(llvm::zip(*bindName, *deviceTypes), p,
2711                         [&](const auto &pair) {
2712                           p << std::get<0>(pair);
2713                           printSingleDeviceType(p, std::get<1>(pair));
2714                         });
2715 }
2716 
2717 static ParseResult parseRoutineGangClause(OpAsmParser &parser,
2718                                           mlir::ArrayAttr &gang,
2719                                           mlir::ArrayAttr &gangDim,
2720                                           mlir::ArrayAttr &gangDimDeviceTypes) {
2721 
2722   llvm::SmallVector<mlir::Attribute> gangAttrs, gangDimAttrs,
2723       gangDimDeviceTypeAttrs;
2724   bool needCommaBeforeOperands = false;
2725 
2726   // Gang keyword only
2727   if (failed(parser.parseOptionalLParen())) {
2728     gangAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2729         parser.getContext(), mlir::acc::DeviceType::None));
2730     gang = ArrayAttr::get(parser.getContext(), gangAttrs);
2731     return success();
2732   }
2733 
2734   // Parse keyword only attributes
2735   if (succeeded(parser.parseOptionalLSquare())) {
2736     if (failed(parser.parseCommaSeparatedList([&]() {
2737           if (parser.parseAttribute(gangAttrs.emplace_back()))
2738             return failure();
2739           return success();
2740         })))
2741       return failure();
2742     if (parser.parseRSquare())
2743       return failure();
2744     needCommaBeforeOperands = true;
2745   }
2746 
2747   if (needCommaBeforeOperands && failed(parser.parseComma()))
2748     return failure();
2749 
2750   if (failed(parser.parseCommaSeparatedList([&]() {
2751         if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
2752             parser.parseColon() ||
2753             parser.parseAttribute(gangDimAttrs.emplace_back()))
2754           return failure();
2755         if (succeeded(parser.parseOptionalLSquare())) {
2756           if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
2757               parser.parseRSquare())
2758             return failure();
2759         } else {
2760           gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2761               parser.getContext(), mlir::acc::DeviceType::None));
2762         }
2763         return success();
2764       })))
2765     return failure();
2766 
2767   if (failed(parser.parseRParen()))
2768     return failure();
2769 
2770   gang = ArrayAttr::get(parser.getContext(), gangAttrs);
2771   gangDim = ArrayAttr::get(parser.getContext(), gangDimAttrs);
2772   gangDimDeviceTypes =
2773       ArrayAttr::get(parser.getContext(), gangDimDeviceTypeAttrs);
2774 
2775   return success();
2776 }
2777 
2778 void printRoutineGangClause(OpAsmPrinter &p, Operation *op,
2779                             std::optional<mlir::ArrayAttr> gang,
2780                             std::optional<mlir::ArrayAttr> gangDim,
2781                             std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
2782 
2783   if (!hasDeviceTypeValues(gangDimDeviceTypes) && hasDeviceTypeValues(gang) &&
2784       gang->size() == 1) {
2785     auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
2786     if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
2787       return;
2788   }
2789 
2790   p << "(";
2791 
2792   printDeviceTypes(p, gang);
2793 
2794   if (hasDeviceTypeValues(gang) && hasDeviceTypeValues(gangDimDeviceTypes))
2795     p << ", ";
2796 
2797   if (hasDeviceTypeValues(gangDimDeviceTypes))
2798     llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
2799                           [&](const auto &pair) {
2800                             p << acc::RoutineOp::getGangDimKeyword() << ": ";
2801                             p << std::get<0>(pair);
2802                             printSingleDeviceType(p, std::get<1>(pair));
2803                           });
2804 
2805   p << ")";
2806 }
2807 
2808 static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser,
2809                                             mlir::ArrayAttr &deviceTypes) {
2810   llvm::SmallVector<mlir::Attribute> attributes;
2811   // Keyword only
2812   if (failed(parser.parseOptionalLParen())) {
2813     attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2814         parser.getContext(), mlir::acc::DeviceType::None));
2815     deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
2816     return success();
2817   }
2818 
2819   // Parse device type attributes
2820   if (succeeded(parser.parseOptionalLSquare())) {
2821     if (failed(parser.parseCommaSeparatedList([&]() {
2822           if (parser.parseAttribute(attributes.emplace_back()))
2823             return failure();
2824           return success();
2825         })))
2826       return failure();
2827     if (parser.parseRSquare() || parser.parseRParen())
2828       return failure();
2829   }
2830   deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
2831   return success();
2832 }
2833 
2834 static void
2835 printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op,
2836                          std::optional<mlir::ArrayAttr> deviceTypes) {
2837 
2838   if (hasDeviceTypeValues(deviceTypes) && deviceTypes->size() == 1) {
2839     auto deviceTypeAttr =
2840         mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
2841     if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
2842       return;
2843   }
2844 
2845   if (!hasDeviceTypeValues(deviceTypes))
2846     return;
2847 
2848   p << "([";
2849   llvm::interleaveComma(*deviceTypes, p, [&](mlir::Attribute attr) {
2850     auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2851     p << dTypeAttr;
2852   });
2853   p << "])";
2854 }
2855 
2856 bool RoutineOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
2857 
2858 bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
2859   return hasDeviceType(getWorker(), deviceType);
2860 }
2861 
2862 bool RoutineOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
2863 
2864 bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
2865   return hasDeviceType(getVector(), deviceType);
2866 }
2867 
2868 bool RoutineOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
2869 
2870 bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
2871   return hasDeviceType(getSeq(), deviceType);
2872 }
2873 
2874 std::optional<llvm::StringRef> RoutineOp::getBindNameValue() {
2875   return getBindNameValue(mlir::acc::DeviceType::None);
2876 }
2877 
2878 std::optional<llvm::StringRef>
2879 RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
2880   if (!hasDeviceTypeValues(getBindNameDeviceType()))
2881     return std::nullopt;
2882   if (auto pos = findSegment(*getBindNameDeviceType(), deviceType)) {
2883     auto attr = (*getBindName())[*pos];
2884     auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
2885     return stringAttr.getValue();
2886   }
2887   return std::nullopt;
2888 }
2889 
2890 bool RoutineOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
2891 
2892 bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
2893   return hasDeviceType(getGang(), deviceType);
2894 }
2895 
2896 std::optional<int64_t> RoutineOp::getGangDimValue() {
2897   return getGangDimValue(mlir::acc::DeviceType::None);
2898 }
2899 
2900 std::optional<int64_t>
2901 RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
2902   if (!hasDeviceTypeValues(getGangDimDeviceType()))
2903     return std::nullopt;
2904   if (auto pos = findSegment(*getGangDimDeviceType(), deviceType)) {
2905     auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
2906     return intAttr.getInt();
2907   }
2908   return std::nullopt;
2909 }
2910 
2911 //===----------------------------------------------------------------------===//
2912 // InitOp
2913 //===----------------------------------------------------------------------===//
2914 
2915 LogicalResult acc::InitOp::verify() {
2916   Operation *currOp = *this;
2917   while ((currOp = currOp->getParentOp()))
2918     if (isComputeOperation(currOp))
2919       return emitOpError("cannot be nested in a compute operation");
2920   return success();
2921 }
2922 
2923 //===----------------------------------------------------------------------===//
2924 // ShutdownOp
2925 //===----------------------------------------------------------------------===//
2926 
2927 LogicalResult acc::ShutdownOp::verify() {
2928   Operation *currOp = *this;
2929   while ((currOp = currOp->getParentOp()))
2930     if (isComputeOperation(currOp))
2931       return emitOpError("cannot be nested in a compute operation");
2932   return success();
2933 }
2934 
2935 //===----------------------------------------------------------------------===//
2936 // SetOp
2937 //===----------------------------------------------------------------------===//
2938 
2939 LogicalResult acc::SetOp::verify() {
2940   Operation *currOp = *this;
2941   while ((currOp = currOp->getParentOp()))
2942     if (isComputeOperation(currOp))
2943       return emitOpError("cannot be nested in a compute operation");
2944   if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
2945     return emitOpError("at least one default_async, device_num, or device_type "
2946                        "operand must appear");
2947   return success();
2948 }
2949 
2950 //===----------------------------------------------------------------------===//
2951 // UpdateOp
2952 //===----------------------------------------------------------------------===//
2953 
2954 LogicalResult acc::UpdateOp::verify() {
2955   // At least one of host or device should have a value.
2956   if (getDataClauseOperands().empty())
2957     return emitError("at least one value must be present in dataOperands");
2958 
2959   if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
2960                                         getAsyncOperandsDeviceTypeAttr(),
2961                                         "async")))
2962     return failure();
2963 
2964   if (failed(verifyDeviceTypeAndSegmentCountMatch(
2965           *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2966           getWaitOperandsDeviceTypeAttr(), "wait")))
2967     return failure();
2968 
2969   if (failed(checkWaitAndAsyncConflict<acc::UpdateOp>(*this)))
2970     return failure();
2971 
2972   for (mlir::Value operand : getDataClauseOperands())
2973     if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
2974             operand.getDefiningOp()))
2975       return emitError("expect data entry/exit operation or acc.getdeviceptr "
2976                        "as defining op");
2977 
2978   return success();
2979 }
2980 
2981 unsigned UpdateOp::getNumDataOperands() {
2982   return getDataClauseOperands().size();
2983 }
2984 
2985 Value UpdateOp::getDataOperand(unsigned i) {
2986   unsigned numOptional = getAsyncOperands().size();
2987   numOptional += getIfCond() ? 1 : 0;
2988   return getOperand(getWaitOperands().size() + numOptional + i);
2989 }
2990 
2991 void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results,
2992                                            MLIRContext *context) {
2993   results.add<RemoveConstantIfCondition<UpdateOp>>(context);
2994 }
2995 
2996 bool UpdateOp::hasAsyncOnly() {
2997   return hasAsyncOnly(mlir::acc::DeviceType::None);
2998 }
2999 
3000 bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
3001   return hasDeviceType(getAsync(), deviceType);
3002 }
3003 
3004 mlir::Value UpdateOp::getAsyncValue() {
3005   return getAsyncValue(mlir::acc::DeviceType::None);
3006 }
3007 
3008 mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
3009   if (!hasDeviceTypeValues(getAsyncOperandsDeviceType()))
3010     return {};
3011 
3012   if (auto pos = findSegment(*getAsyncOperandsDeviceType(), deviceType))
3013     return getAsyncOperands()[*pos];
3014 
3015   return {};
3016 }
3017 
3018 bool UpdateOp::hasWaitOnly() {
3019   return hasWaitOnly(mlir::acc::DeviceType::None);
3020 }
3021 
3022 bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
3023   return hasDeviceType(getWaitOnly(), deviceType);
3024 }
3025 
3026 mlir::Operation::operand_range UpdateOp::getWaitValues() {
3027   return getWaitValues(mlir::acc::DeviceType::None);
3028 }
3029 
3030 mlir::Operation::operand_range
3031 UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
3032   return getWaitValuesWithoutDevnum(
3033       getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
3034       getHasWaitDevnum(), deviceType);
3035 }
3036 
3037 mlir::Value UpdateOp::getWaitDevnum() {
3038   return getWaitDevnum(mlir::acc::DeviceType::None);
3039 }
3040 
3041 mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
3042   return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
3043                             getWaitOperandsSegments(), getHasWaitDevnum(),
3044                             deviceType);
3045 }
3046 
3047 //===----------------------------------------------------------------------===//
3048 // WaitOp
3049 //===----------------------------------------------------------------------===//
3050 
3051 LogicalResult acc::WaitOp::verify() {
3052   // The async attribute represent the async clause without value. Therefore the
3053   // attribute and operand cannot appear at the same time.
3054   if (getAsyncOperand() && getAsync())
3055     return emitError("async attribute cannot appear with asyncOperand");
3056 
3057   if (getWaitDevnum() && getWaitOperands().empty())
3058     return emitError("wait_devnum cannot appear without waitOperands");
3059 
3060   return success();
3061 }
3062 
3063 #define GET_OP_CLASSES
3064 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
3065 
3066 #define GET_ATTRDEF_CLASSES
3067 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
3068 
3069 #define GET_TYPEDEF_CLASSES
3070 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
3071 
3072 //===----------------------------------------------------------------------===//
3073 // acc dialect utilities
3074 //===----------------------------------------------------------------------===//
3075 
3076 mlir::TypedValue<mlir::acc::PointerLikeType>
3077 mlir::acc::getVarPtr(mlir::Operation *accDataClauseOp) {
3078   auto varPtr{llvm::TypeSwitch<mlir::Operation *,
3079                                mlir::TypedValue<mlir::acc::PointerLikeType>>(
3080                   accDataClauseOp)
3081                   .Case<ACC_DATA_ENTRY_OPS>(
3082                       [&](auto entry) { return entry.getVarPtr(); })
3083                   .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
3084                       [&](auto exit) { return exit.getVarPtr(); })
3085                   .Default([&](mlir::Operation *) {
3086                     return mlir::TypedValue<mlir::acc::PointerLikeType>();
3087                   })};
3088   return varPtr;
3089 }
3090 
3091 mlir::Value mlir::acc::getVar(mlir::Operation *accDataClauseOp) {
3092   auto varPtr{
3093       llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
3094           .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getVar(); })
3095           .Default([&](mlir::Operation *) { return mlir::Value(); })};
3096   return varPtr;
3097 }
3098 
3099 mlir::Type mlir::acc::getVarType(mlir::Operation *accDataClauseOp) {
3100   auto varType{llvm::TypeSwitch<mlir::Operation *, mlir::Type>(accDataClauseOp)
3101                    .Case<ACC_DATA_ENTRY_OPS>(
3102                        [&](auto entry) { return entry.getVarType(); })
3103                    .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
3104                        [&](auto exit) { return exit.getVarType(); })
3105                    .Default([&](mlir::Operation *) { return mlir::Type(); })};
3106   return varType;
3107 }
3108 
3109 mlir::TypedValue<mlir::acc::PointerLikeType>
3110 mlir::acc::getAccPtr(mlir::Operation *accDataClauseOp) {
3111   auto accPtr{llvm::TypeSwitch<mlir::Operation *,
3112                                mlir::TypedValue<mlir::acc::PointerLikeType>>(
3113                   accDataClauseOp)
3114                   .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
3115                       [&](auto dataClause) { return dataClause.getAccPtr(); })
3116                   .Default([&](mlir::Operation *) {
3117                     return mlir::TypedValue<mlir::acc::PointerLikeType>();
3118                   })};
3119   return accPtr;
3120 }
3121 
3122 mlir::Value mlir::acc::getAccVar(mlir::Operation *accDataClauseOp) {
3123   auto accPtr{llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
3124                   .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
3125                       [&](auto dataClause) { return dataClause.getAccVar(); })
3126                   .Default([&](mlir::Operation *) { return mlir::Value(); })};
3127   return accPtr;
3128 }
3129 
3130 mlir::Value mlir::acc::getVarPtrPtr(mlir::Operation *accDataClauseOp) {
3131   auto varPtrPtr{
3132       llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
3133           .Case<ACC_DATA_ENTRY_OPS>(
3134               [&](auto dataClause) { return dataClause.getVarPtrPtr(); })
3135           .Default([&](mlir::Operation *) { return mlir::Value(); })};
3136   return varPtrPtr;
3137 }
3138 
3139 mlir::SmallVector<mlir::Value>
3140 mlir::acc::getBounds(mlir::Operation *accDataClauseOp) {
3141   mlir::SmallVector<mlir::Value> bounds{
3142       llvm::TypeSwitch<mlir::Operation *, mlir::SmallVector<mlir::Value>>(
3143           accDataClauseOp)
3144           .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
3145             return mlir::SmallVector<mlir::Value>(
3146                 dataClause.getBounds().begin(), dataClause.getBounds().end());
3147           })
3148           .Default([&](mlir::Operation *) {
3149             return mlir::SmallVector<mlir::Value, 0>();
3150           })};
3151   return bounds;
3152 }
3153 
3154 mlir::SmallVector<mlir::Value>
3155 mlir::acc::getAsyncOperands(mlir::Operation *accDataClauseOp) {
3156   return llvm::TypeSwitch<mlir::Operation *, mlir::SmallVector<mlir::Value>>(
3157              accDataClauseOp)
3158       .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
3159         return mlir::SmallVector<mlir::Value>(
3160             dataClause.getAsyncOperands().begin(),
3161             dataClause.getAsyncOperands().end());
3162       })
3163       .Default([&](mlir::Operation *) {
3164         return mlir::SmallVector<mlir::Value, 0>();
3165       });
3166 }
3167 
3168 mlir::ArrayAttr
3169 mlir::acc::getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp) {
3170   return llvm::TypeSwitch<mlir::Operation *, mlir::ArrayAttr>(accDataClauseOp)
3171       .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
3172         return dataClause.getAsyncOperandsDeviceTypeAttr();
3173       })
3174       .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
3175 }
3176 
3177 mlir::ArrayAttr mlir::acc::getAsyncOnly(mlir::Operation *accDataClauseOp) {
3178   return llvm::TypeSwitch<mlir::Operation *, mlir::ArrayAttr>(accDataClauseOp)
3179       .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
3180           [&](auto dataClause) { return dataClause.getAsyncOnlyAttr(); })
3181       .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
3182 }
3183 
3184 std::optional<llvm::StringRef> mlir::acc::getVarName(mlir::Operation *accOp) {
3185   auto name{
3186       llvm::TypeSwitch<mlir::Operation *, std::optional<llvm::StringRef>>(accOp)
3187           .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getName(); })
3188           .Default([&](mlir::Operation *) -> std::optional<llvm::StringRef> {
3189             return {};
3190           })};
3191   return name;
3192 }
3193 
3194 std::optional<mlir::acc::DataClause>
3195 mlir::acc::getDataClause(mlir::Operation *accDataEntryOp) {
3196   auto dataClause{
3197       llvm::TypeSwitch<mlir::Operation *, std::optional<mlir::acc::DataClause>>(
3198           accDataEntryOp)
3199           .Case<ACC_DATA_ENTRY_OPS>(
3200               [&](auto entry) { return entry.getDataClause(); })
3201           .Default([&](mlir::Operation *) { return std::nullopt; })};
3202   return dataClause;
3203 }
3204 
3205 bool mlir::acc::getImplicitFlag(mlir::Operation *accDataEntryOp) {
3206   auto implicit{llvm::TypeSwitch<mlir::Operation *, bool>(accDataEntryOp)
3207                     .Case<ACC_DATA_ENTRY_OPS>(
3208                         [&](auto entry) { return entry.getImplicit(); })
3209                     .Default([&](mlir::Operation *) { return false; })};
3210   return implicit;
3211 }
3212 
3213 mlir::ValueRange mlir::acc::getDataOperands(mlir::Operation *accOp) {
3214   auto dataOperands{
3215       llvm::TypeSwitch<mlir::Operation *, mlir::ValueRange>(accOp)
3216           .Case<ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS>(
3217               [&](auto entry) { return entry.getDataClauseOperands(); })
3218           .Default([&](mlir::Operation *) { return mlir::ValueRange(); })};
3219   return dataOperands;
3220 }
3221 
3222 mlir::MutableOperandRange
3223 mlir::acc::getMutableDataOperands(mlir::Operation *accOp) {
3224   auto dataOperands{
3225       llvm::TypeSwitch<mlir::Operation *, mlir::MutableOperandRange>(accOp)
3226           .Case<ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS>(
3227               [&](auto entry) { return entry.getDataClauseOperandsMutable(); })
3228           .Default([&](mlir::Operation *) { return nullptr; })};
3229   return dataOperands;
3230 }
3231