1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the OpenMP dialect and its operations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 14 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" 15 #include "mlir/Dialect/Func/IR/FuncOps.h" 16 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 17 #include "mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.h" 18 #include "mlir/IR/Attributes.h" 19 #include "mlir/IR/BuiltinAttributes.h" 20 #include "mlir/IR/DialectImplementation.h" 21 #include "mlir/IR/OpImplementation.h" 22 #include "mlir/IR/OperationSupport.h" 23 #include "mlir/Interfaces/FoldInterfaces.h" 24 25 #include "llvm/ADT/ArrayRef.h" 26 #include "llvm/ADT/BitVector.h" 27 #include "llvm/ADT/STLExtras.h" 28 #include "llvm/ADT/STLForwardCompat.h" 29 #include "llvm/ADT/SmallString.h" 30 #include "llvm/ADT/StringExtras.h" 31 #include "llvm/ADT/StringRef.h" 32 #include "llvm/ADT/TypeSwitch.h" 33 #include "llvm/Frontend/OpenMP/OMPConstants.h" 34 #include "llvm/Frontend/OpenMP/OMPDeviceConstants.h" 35 #include <cstddef> 36 #include <iterator> 37 #include <optional> 38 #include <variant> 39 40 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc" 41 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc" 42 #include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc" 43 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc" 44 45 using namespace mlir; 46 using namespace mlir::omp; 47 48 static ArrayAttr makeArrayAttr(MLIRContext *context, 49 llvm::ArrayRef<Attribute> attrs) { 50 return attrs.empty() ? 
nullptr : ArrayAttr::get(context, attrs); 51 } 52 53 static DenseBoolArrayAttr 54 makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef<bool> boolArray) { 55 return boolArray.empty() ? nullptr : DenseBoolArrayAttr::get(ctx, boolArray); 56 } 57 58 namespace { 59 struct MemRefPointerLikeModel 60 : public PointerLikeType::ExternalModel<MemRefPointerLikeModel, 61 MemRefType> { 62 Type getElementType(Type pointer) const { 63 return llvm::cast<MemRefType>(pointer).getElementType(); 64 } 65 }; 66 67 struct LLVMPointerPointerLikeModel 68 : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel, 69 LLVM::LLVMPointerType> { 70 Type getElementType(Type pointer) const { return Type(); } 71 }; 72 } // namespace 73 74 void OpenMPDialect::initialize() { 75 addOperations< 76 #define GET_OP_LIST 77 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" 78 >(); 79 addAttributes< 80 #define GET_ATTRDEF_LIST 81 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" 82 >(); 83 addTypes< 84 #define GET_TYPEDEF_LIST 85 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc" 86 >(); 87 88 declarePromisedInterface<ConvertToLLVMPatternInterface, OpenMPDialect>(); 89 90 MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext()); 91 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>( 92 *getContext()); 93 94 // Attach default offload module interface to module op to access 95 // offload functionality through 96 mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>( 97 *getContext()); 98 99 // Attach default declare target interfaces to operations which can be marked 100 // as declare target (Global Operations and Functions/Subroutines in dialects 101 // that Fortran (or other languages that lower to MLIR) translates too 102 mlir::LLVM::GlobalOp::attachInterface< 103 mlir::omp::DeclareTargetDefaultModel<mlir::LLVM::GlobalOp>>( 104 *getContext()); 105 mlir::LLVM::LLVMFuncOp::attachInterface< 106 
mlir::omp::DeclareTargetDefaultModel<mlir::LLVM::LLVMFuncOp>>( 107 *getContext()); 108 mlir::func::FuncOp::attachInterface< 109 mlir::omp::DeclareTargetDefaultModel<mlir::func::FuncOp>>(*getContext()); 110 } 111 112 //===----------------------------------------------------------------------===// 113 // Parser and printer for Allocate Clause 114 //===----------------------------------------------------------------------===// 115 116 /// Parse an allocate clause with allocators and a list of operands with types. 117 /// 118 /// allocate-operand-list :: = allocate-operand | 119 /// allocator-operand `,` allocate-operand-list 120 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type 121 /// ssa-id-and-type ::= ssa-id `:` type 122 static ParseResult parseAllocateAndAllocator( 123 OpAsmParser &parser, 124 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &allocateVars, 125 SmallVectorImpl<Type> &allocateTypes, 126 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &allocatorVars, 127 SmallVectorImpl<Type> &allocatorTypes) { 128 129 return parser.parseCommaSeparatedList([&]() { 130 OpAsmParser::UnresolvedOperand operand; 131 Type type; 132 if (parser.parseOperand(operand) || parser.parseColonType(type)) 133 return failure(); 134 allocatorVars.push_back(operand); 135 allocatorTypes.push_back(type); 136 if (parser.parseArrow()) 137 return failure(); 138 if (parser.parseOperand(operand) || parser.parseColonType(type)) 139 return failure(); 140 141 allocateVars.push_back(operand); 142 allocateTypes.push_back(type); 143 return success(); 144 }); 145 } 146 147 /// Print allocate clause 148 static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, 149 OperandRange allocateVars, 150 TypeRange allocateTypes, 151 OperandRange allocatorVars, 152 TypeRange allocatorTypes) { 153 for (unsigned i = 0; i < allocateVars.size(); ++i) { 154 std::string separator = i == allocateVars.size() - 1 ? 
"" : ", "; 155 p << allocatorVars[i] << " : " << allocatorTypes[i] << " -> "; 156 p << allocateVars[i] << " : " << allocateTypes[i] << separator; 157 } 158 } 159 160 //===----------------------------------------------------------------------===// 161 // Parser and printer for a clause attribute (StringEnumAttr) 162 //===----------------------------------------------------------------------===// 163 164 template <typename ClauseAttr> 165 static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr) { 166 using ClauseT = decltype(std::declval<ClauseAttr>().getValue()); 167 StringRef enumStr; 168 SMLoc loc = parser.getCurrentLocation(); 169 if (parser.parseKeyword(&enumStr)) 170 return failure(); 171 if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) { 172 attr = ClauseAttr::get(parser.getContext(), *enumValue); 173 return success(); 174 } 175 return parser.emitError(loc, "invalid clause value: '") << enumStr << "'"; 176 } 177 178 template <typename ClauseAttr> 179 void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) { 180 p << stringifyEnum(attr.getValue()); 181 } 182 183 //===----------------------------------------------------------------------===// 184 // Parser and printer for Linear Clause 185 //===----------------------------------------------------------------------===// 186 187 /// linear ::= `linear` `(` linear-list `)` 188 /// linear-list := linear-val | linear-val linear-list 189 /// linear-val := ssa-id-and-type `=` ssa-id-and-type 190 static ParseResult parseLinearClause( 191 OpAsmParser &parser, 192 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &linearVars, 193 SmallVectorImpl<Type> &linearTypes, 194 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &linearStepVars) { 195 return parser.parseCommaSeparatedList([&]() { 196 OpAsmParser::UnresolvedOperand var; 197 Type type; 198 OpAsmParser::UnresolvedOperand stepVar; 199 if (parser.parseOperand(var) || parser.parseEqual() || 200 parser.parseOperand(stepVar) 
|| parser.parseColonType(type)) 201 return failure(); 202 203 linearVars.push_back(var); 204 linearTypes.push_back(type); 205 linearStepVars.push_back(stepVar); 206 return success(); 207 }); 208 } 209 210 /// Print Linear Clause 211 static void printLinearClause(OpAsmPrinter &p, Operation *op, 212 ValueRange linearVars, TypeRange linearTypes, 213 ValueRange linearStepVars) { 214 size_t linearVarsSize = linearVars.size(); 215 for (unsigned i = 0; i < linearVarsSize; ++i) { 216 std::string separator = i == linearVarsSize - 1 ? "" : ", "; 217 p << linearVars[i]; 218 if (linearStepVars.size() > i) 219 p << " = " << linearStepVars[i]; 220 p << " : " << linearVars[i].getType() << separator; 221 } 222 } 223 224 //===----------------------------------------------------------------------===// 225 // Verifier for Nontemporal Clause 226 //===----------------------------------------------------------------------===// 227 228 static LogicalResult verifyNontemporalClause(Operation *op, 229 OperandRange nontemporalVars) { 230 231 // Check if each var is unique - OpenMP 5.0 -> 2.9.3.1 section 232 DenseSet<Value> nontemporalItems; 233 for (const auto &it : nontemporalVars) 234 if (!nontemporalItems.insert(it).second) 235 return op->emitOpError() << "nontemporal variable used more than once"; 236 237 return success(); 238 } 239 240 //===----------------------------------------------------------------------===// 241 // Parser, verifier and printer for Aligned Clause 242 //===----------------------------------------------------------------------===// 243 static LogicalResult verifyAlignedClause(Operation *op, 244 std::optional<ArrayAttr> alignments, 245 OperandRange alignedVars) { 246 // Check if number of alignment values equals to number of aligned variables 247 if (!alignedVars.empty()) { 248 if (!alignments || alignments->size() != alignedVars.size()) 249 return op->emitOpError() 250 << "expected as many alignment values as aligned variables"; 251 } else { 252 if (alignments) 253 
return op->emitOpError() << "unexpected alignment values attribute"; 254 return success(); 255 } 256 257 // Check if each var is aligned only once - OpenMP 4.5 -> 2.8.1 section 258 DenseSet<Value> alignedItems; 259 for (auto it : alignedVars) 260 if (!alignedItems.insert(it).second) 261 return op->emitOpError() << "aligned variable used more than once"; 262 263 if (!alignments) 264 return success(); 265 266 // Check if all alignment values are positive - OpenMP 4.5 -> 2.8.1 section 267 for (unsigned i = 0; i < (*alignments).size(); ++i) { 268 if (auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignments)[i])) { 269 if (intAttr.getValue().sle(0)) 270 return op->emitOpError() << "alignment should be greater than 0"; 271 } else { 272 return op->emitOpError() << "expected integer alignment"; 273 } 274 } 275 276 return success(); 277 } 278 279 /// aligned ::= `aligned` `(` aligned-list `)` 280 /// aligned-list := aligned-val | aligned-val aligned-list 281 /// aligned-val := ssa-id-and-type `->` alignment 282 static ParseResult 283 parseAlignedClause(OpAsmParser &parser, 284 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &alignedVars, 285 SmallVectorImpl<Type> &alignedTypes, 286 ArrayAttr &alignmentsAttr) { 287 SmallVector<Attribute> alignmentVec; 288 if (failed(parser.parseCommaSeparatedList([&]() { 289 if (parser.parseOperand(alignedVars.emplace_back()) || 290 parser.parseColonType(alignedTypes.emplace_back()) || 291 parser.parseArrow() || 292 parser.parseAttribute(alignmentVec.emplace_back())) { 293 return failure(); 294 } 295 return success(); 296 }))) 297 return failure(); 298 SmallVector<Attribute> alignments(alignmentVec.begin(), alignmentVec.end()); 299 alignmentsAttr = ArrayAttr::get(parser.getContext(), alignments); 300 return success(); 301 } 302 303 /// Print Aligned Clause 304 static void printAlignedClause(OpAsmPrinter &p, Operation *op, 305 ValueRange alignedVars, TypeRange alignedTypes, 306 std::optional<ArrayAttr> alignments) { 307 for (unsigned i = 0; i 
< alignedVars.size(); ++i) { 308 if (i != 0) 309 p << ", "; 310 p << alignedVars[i] << " : " << alignedVars[i].getType(); 311 p << " -> " << (*alignments)[i]; 312 } 313 } 314 315 //===----------------------------------------------------------------------===// 316 // Parser, printer and verifier for Schedule Clause 317 //===----------------------------------------------------------------------===// 318 319 static ParseResult 320 verifyScheduleModifiers(OpAsmParser &parser, 321 SmallVectorImpl<SmallString<12>> &modifiers) { 322 if (modifiers.size() > 2) 323 return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)"; 324 for (const auto &mod : modifiers) { 325 // Translate the string. If it has no value, then it was not a valid 326 // modifier! 327 auto symbol = symbolizeScheduleModifier(mod); 328 if (!symbol) 329 return parser.emitError(parser.getNameLoc()) 330 << " unknown modifier type: " << mod; 331 } 332 333 // If we have one modifier that is "simd", then stick a "none" modiifer in 334 // index 0. 335 if (modifiers.size() == 1) { 336 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) { 337 modifiers.push_back(modifiers[0]); 338 modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none); 339 } 340 } else if (modifiers.size() == 2) { 341 // If there are two modifier: 342 // First modifier should not be simd, second one should be simd 343 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd || 344 symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd) 345 return parser.emitError(parser.getNameLoc()) 346 << " incorrect modifier order"; 347 } 348 return success(); 349 } 350 351 /// schedule ::= `schedule` `(` sched-list `)` 352 /// sched-list ::= sched-val | sched-val sched-list | 353 /// sched-val `,` sched-modifier 354 /// sched-val ::= sched-with-chunk | sched-wo-chunk 355 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)? 
356 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided` 357 /// sched-wo-chunk ::= `auto` | `runtime` 358 /// sched-modifier ::= sched-mod-val | sched-mod-val `,` sched-mod-val 359 /// sched-mod-val ::= `monotonic` | `nonmonotonic` | `simd` | `none` 360 static ParseResult 361 parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, 362 ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd, 363 std::optional<OpAsmParser::UnresolvedOperand> &chunkSize, 364 Type &chunkType) { 365 StringRef keyword; 366 if (parser.parseKeyword(&keyword)) 367 return failure(); 368 std::optional<mlir::omp::ClauseScheduleKind> schedule = 369 symbolizeClauseScheduleKind(keyword); 370 if (!schedule) 371 return parser.emitError(parser.getNameLoc()) << " expected schedule kind"; 372 373 scheduleAttr = ClauseScheduleKindAttr::get(parser.getContext(), *schedule); 374 switch (*schedule) { 375 case ClauseScheduleKind::Static: 376 case ClauseScheduleKind::Dynamic: 377 case ClauseScheduleKind::Guided: 378 if (succeeded(parser.parseOptionalEqual())) { 379 chunkSize = OpAsmParser::UnresolvedOperand{}; 380 if (parser.parseOperand(*chunkSize) || parser.parseColonType(chunkType)) 381 return failure(); 382 } else { 383 chunkSize = std::nullopt; 384 } 385 break; 386 case ClauseScheduleKind::Auto: 387 case ClauseScheduleKind::Runtime: 388 chunkSize = std::nullopt; 389 } 390 391 // If there is a comma, we have one or more modifiers.. 
392 SmallVector<SmallString<12>> modifiers; 393 while (succeeded(parser.parseOptionalComma())) { 394 StringRef mod; 395 if (parser.parseKeyword(&mod)) 396 return failure(); 397 modifiers.push_back(mod); 398 } 399 400 if (verifyScheduleModifiers(parser, modifiers)) 401 return failure(); 402 403 if (!modifiers.empty()) { 404 SMLoc loc = parser.getCurrentLocation(); 405 if (std::optional<ScheduleModifier> mod = 406 symbolizeScheduleModifier(modifiers[0])) { 407 scheduleMod = ScheduleModifierAttr::get(parser.getContext(), *mod); 408 } else { 409 return parser.emitError(loc, "invalid schedule modifier"); 410 } 411 // Only SIMD attribute is allowed here! 412 if (modifiers.size() > 1) { 413 assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd); 414 scheduleSimd = UnitAttr::get(parser.getBuilder().getContext()); 415 } 416 } 417 418 return success(); 419 } 420 421 /// Print schedule clause 422 static void printScheduleClause(OpAsmPrinter &p, Operation *op, 423 ClauseScheduleKindAttr scheduleKind, 424 ScheduleModifierAttr scheduleMod, 425 UnitAttr scheduleSimd, Value scheduleChunk, 426 Type scheduleChunkType) { 427 p << stringifyClauseScheduleKind(scheduleKind.getValue()); 428 if (scheduleChunk) 429 p << " = " << scheduleChunk << " : " << scheduleChunk.getType(); 430 if (scheduleMod) 431 p << ", " << stringifyScheduleModifier(scheduleMod.getValue()); 432 if (scheduleSimd) 433 p << ", simd"; 434 } 435 436 //===----------------------------------------------------------------------===// 437 // Parser and printer for Order Clause 438 //===----------------------------------------------------------------------===// 439 440 // order ::= `order` `(` [order-modifier ':'] concurrent `)` 441 // order-modifier ::= reproducible | unconstrained 442 static ParseResult parseOrderClause(OpAsmParser &parser, 443 ClauseOrderKindAttr &order, 444 OrderModifierAttr &orderMod) { 445 StringRef enumStr; 446 SMLoc loc = parser.getCurrentLocation(); 447 if 
(parser.parseKeyword(&enumStr)) 448 return failure(); 449 if (std::optional<OrderModifier> enumValue = 450 symbolizeOrderModifier(enumStr)) { 451 orderMod = OrderModifierAttr::get(parser.getContext(), *enumValue); 452 if (parser.parseOptionalColon()) 453 return failure(); 454 loc = parser.getCurrentLocation(); 455 if (parser.parseKeyword(&enumStr)) 456 return failure(); 457 } 458 if (std::optional<ClauseOrderKind> enumValue = 459 symbolizeClauseOrderKind(enumStr)) { 460 order = ClauseOrderKindAttr::get(parser.getContext(), *enumValue); 461 return success(); 462 } 463 return parser.emitError(loc, "invalid clause value: '") << enumStr << "'"; 464 } 465 466 static void printOrderClause(OpAsmPrinter &p, Operation *op, 467 ClauseOrderKindAttr order, 468 OrderModifierAttr orderMod) { 469 if (orderMod) 470 p << stringifyOrderModifier(orderMod.getValue()) << ":"; 471 if (order) 472 p << stringifyClauseOrderKind(order.getValue()); 473 } 474 475 //===----------------------------------------------------------------------===// 476 // Parsers for operations including clauses that define entry block arguments. 
477 //===----------------------------------------------------------------------===// 478 479 namespace { 480 struct MapParseArgs { 481 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars; 482 SmallVectorImpl<Type> &types; 483 MapParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars, 484 SmallVectorImpl<Type> &types) 485 : vars(vars), types(types) {} 486 }; 487 struct PrivateParseArgs { 488 llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars; 489 llvm::SmallVectorImpl<Type> &types; 490 ArrayAttr &syms; 491 DenseI64ArrayAttr *mapIndices; 492 PrivateParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars, 493 SmallVectorImpl<Type> &types, ArrayAttr &syms, 494 DenseI64ArrayAttr *mapIndices = nullptr) 495 : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {} 496 }; 497 498 struct ReductionParseArgs { 499 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars; 500 SmallVectorImpl<Type> &types; 501 DenseBoolArrayAttr &byref; 502 ArrayAttr &syms; 503 ReductionModifierAttr *modifier; 504 ReductionParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars, 505 SmallVectorImpl<Type> &types, DenseBoolArrayAttr &byref, 506 ArrayAttr &syms, ReductionModifierAttr *mod = nullptr) 507 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {} 508 }; 509 510 struct AllRegionParseArgs { 511 std::optional<MapParseArgs> hostEvalArgs; 512 std::optional<ReductionParseArgs> inReductionArgs; 513 std::optional<MapParseArgs> mapArgs; 514 std::optional<PrivateParseArgs> privateArgs; 515 std::optional<ReductionParseArgs> reductionArgs; 516 std::optional<ReductionParseArgs> taskReductionArgs; 517 std::optional<MapParseArgs> useDeviceAddrArgs; 518 std::optional<MapParseArgs> useDevicePtrArgs; 519 }; 520 } // namespace 521 522 static ParseResult parseClauseWithRegionArgs( 523 OpAsmParser &parser, 524 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands, 525 SmallVectorImpl<Type> &types, 526 SmallVectorImpl<OpAsmParser::Argument> 
®ionPrivateArgs, 527 ArrayAttr *symbols = nullptr, DenseI64ArrayAttr *mapIndices = nullptr, 528 DenseBoolArrayAttr *byref = nullptr, 529 ReductionModifierAttr *modifier = nullptr) { 530 SmallVector<SymbolRefAttr> symbolVec; 531 SmallVector<int64_t> mapIndicesVec; 532 SmallVector<bool> isByRefVec; 533 unsigned regionArgOffset = regionPrivateArgs.size(); 534 535 if (parser.parseLParen()) 536 return failure(); 537 538 if (modifier && succeeded(parser.parseOptionalKeyword("mod"))) { 539 StringRef enumStr; 540 if (parser.parseColon() || parser.parseKeyword(&enumStr) || 541 parser.parseComma()) 542 return failure(); 543 std::optional<ReductionModifier> enumValue = 544 symbolizeReductionModifier(enumStr); 545 if (!enumValue.has_value()) 546 return failure(); 547 *modifier = ReductionModifierAttr::get(parser.getContext(), *enumValue); 548 if (!*modifier) 549 return failure(); 550 } 551 552 if (parser.parseCommaSeparatedList([&]() { 553 if (byref) 554 isByRefVec.push_back( 555 parser.parseOptionalKeyword("byref").succeeded()); 556 557 if (symbols && parser.parseAttribute(symbolVec.emplace_back())) 558 return failure(); 559 560 if (parser.parseOperand(operands.emplace_back()) || 561 parser.parseArrow() || 562 parser.parseArgument(regionPrivateArgs.emplace_back())) 563 return failure(); 564 565 if (mapIndices) { 566 if (parser.parseOptionalLSquare().succeeded()) { 567 if (parser.parseKeyword("map_idx") || parser.parseEqual() || 568 parser.parseInteger(mapIndicesVec.emplace_back()) || 569 parser.parseRSquare()) 570 return failure(); 571 } else 572 mapIndicesVec.push_back(-1); 573 } 574 575 return success(); 576 })) 577 return failure(); 578 579 if (parser.parseColon()) 580 return failure(); 581 582 if (parser.parseCommaSeparatedList([&]() { 583 if (parser.parseType(types.emplace_back())) 584 return failure(); 585 586 return success(); 587 })) 588 return failure(); 589 590 if (operands.size() != types.size()) 591 return failure(); 592 593 if (parser.parseRParen()) 594 return 
failure(); 595 596 auto *argsBegin = regionPrivateArgs.begin(); 597 MutableArrayRef argsSubrange(argsBegin + regionArgOffset, 598 argsBegin + regionArgOffset + types.size()); 599 for (auto [prv, type] : llvm::zip_equal(argsSubrange, types)) { 600 prv.type = type; 601 } 602 603 if (symbols) { 604 SmallVector<Attribute> symbolAttrs(symbolVec.begin(), symbolVec.end()); 605 *symbols = ArrayAttr::get(parser.getContext(), symbolAttrs); 606 } 607 608 if (!mapIndicesVec.empty()) 609 *mapIndices = 610 mlir::DenseI64ArrayAttr::get(parser.getContext(), mapIndicesVec); 611 612 if (byref) 613 *byref = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec); 614 615 return success(); 616 } 617 618 static ParseResult parseBlockArgClause( 619 OpAsmParser &parser, 620 llvm::SmallVectorImpl<OpAsmParser::Argument> &entryBlockArgs, 621 StringRef keyword, std::optional<MapParseArgs> mapArgs) { 622 if (succeeded(parser.parseOptionalKeyword(keyword))) { 623 if (!mapArgs) 624 return failure(); 625 626 if (failed(parseClauseWithRegionArgs(parser, mapArgs->vars, mapArgs->types, 627 entryBlockArgs))) 628 return failure(); 629 } 630 return success(); 631 } 632 633 static ParseResult parseBlockArgClause( 634 OpAsmParser &parser, 635 llvm::SmallVectorImpl<OpAsmParser::Argument> &entryBlockArgs, 636 StringRef keyword, std::optional<PrivateParseArgs> privateArgs) { 637 if (succeeded(parser.parseOptionalKeyword(keyword))) { 638 if (!privateArgs) 639 return failure(); 640 641 if (failed(parseClauseWithRegionArgs( 642 parser, privateArgs->vars, privateArgs->types, entryBlockArgs, 643 &privateArgs->syms, privateArgs->mapIndices))) 644 return failure(); 645 } 646 return success(); 647 } 648 649 static ParseResult parseBlockArgClause( 650 OpAsmParser &parser, 651 llvm::SmallVectorImpl<OpAsmParser::Argument> &entryBlockArgs, 652 StringRef keyword, std::optional<ReductionParseArgs> reductionArgs) { 653 if (succeeded(parser.parseOptionalKeyword(keyword))) { 654 if (!reductionArgs) 655 return failure(); 
656 if (failed(parseClauseWithRegionArgs( 657 parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs, 658 &reductionArgs->syms, /*mapIndices=*/nullptr, &reductionArgs->byref, 659 reductionArgs->modifier))) 660 return failure(); 661 } 662 return success(); 663 } 664 665 static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region ®ion, 666 AllRegionParseArgs args) { 667 llvm::SmallVector<OpAsmParser::Argument> entryBlockArgs; 668 669 if (failed(parseBlockArgClause(parser, entryBlockArgs, "host_eval", 670 args.hostEvalArgs))) 671 return parser.emitError(parser.getCurrentLocation()) 672 << "invalid `host_eval` format"; 673 674 if (failed(parseBlockArgClause(parser, entryBlockArgs, "in_reduction", 675 args.inReductionArgs))) 676 return parser.emitError(parser.getCurrentLocation()) 677 << "invalid `in_reduction` format"; 678 679 if (failed(parseBlockArgClause(parser, entryBlockArgs, "map_entries", 680 args.mapArgs))) 681 return parser.emitError(parser.getCurrentLocation()) 682 << "invalid `map_entries` format"; 683 684 if (failed(parseBlockArgClause(parser, entryBlockArgs, "private", 685 args.privateArgs))) 686 return parser.emitError(parser.getCurrentLocation()) 687 << "invalid `private` format"; 688 689 if (failed(parseBlockArgClause(parser, entryBlockArgs, "reduction", 690 args.reductionArgs))) 691 return parser.emitError(parser.getCurrentLocation()) 692 << "invalid `reduction` format"; 693 694 if (failed(parseBlockArgClause(parser, entryBlockArgs, "task_reduction", 695 args.taskReductionArgs))) 696 return parser.emitError(parser.getCurrentLocation()) 697 << "invalid `task_reduction` format"; 698 699 if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_addr", 700 args.useDeviceAddrArgs))) 701 return parser.emitError(parser.getCurrentLocation()) 702 << "invalid `use_device_addr` format"; 703 704 if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_ptr", 705 args.useDevicePtrArgs))) 706 return 
parser.emitError(parser.getCurrentLocation()) 707 << "invalid `use_device_addr` format"; 708 709 return parser.parseRegion(region, entryBlockArgs); 710 } 711 712 static ParseResult parseHostEvalInReductionMapPrivateRegion( 713 OpAsmParser &parser, Region ®ion, 714 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &hostEvalVars, 715 SmallVectorImpl<Type> &hostEvalTypes, 716 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars, 717 SmallVectorImpl<Type> &inReductionTypes, 718 DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, 719 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &mapVars, 720 SmallVectorImpl<Type> &mapTypes, 721 llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars, 722 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms, 723 DenseI64ArrayAttr &privateMaps) { 724 AllRegionParseArgs args; 725 args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes); 726 args.inReductionArgs.emplace(inReductionVars, inReductionTypes, 727 inReductionByref, inReductionSyms); 728 args.mapArgs.emplace(mapVars, mapTypes); 729 args.privateArgs.emplace(privateVars, privateTypes, privateSyms, 730 &privateMaps); 731 return parseBlockArgRegion(parser, region, args); 732 } 733 734 static ParseResult parseInReductionPrivateRegion( 735 OpAsmParser &parser, Region ®ion, 736 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars, 737 SmallVectorImpl<Type> &inReductionTypes, 738 DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, 739 llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars, 740 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) { 741 AllRegionParseArgs args; 742 args.inReductionArgs.emplace(inReductionVars, inReductionTypes, 743 inReductionByref, inReductionSyms); 744 args.privateArgs.emplace(privateVars, privateTypes, privateSyms); 745 return parseBlockArgRegion(parser, region, args); 746 } 747 748 static ParseResult parseInReductionPrivateReductionRegion( 749 
OpAsmParser &parser, Region ®ion, 750 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars, 751 SmallVectorImpl<Type> &inReductionTypes, 752 DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, 753 llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars, 754 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms, 755 ReductionModifierAttr &reductionMod, 756 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVars, 757 SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref, 758 ArrayAttr &reductionSyms) { 759 AllRegionParseArgs args; 760 args.inReductionArgs.emplace(inReductionVars, inReductionTypes, 761 inReductionByref, inReductionSyms); 762 args.privateArgs.emplace(privateVars, privateTypes, privateSyms); 763 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref, 764 reductionSyms, &reductionMod); 765 return parseBlockArgRegion(parser, region, args); 766 } 767 768 static ParseResult parsePrivateRegion( 769 OpAsmParser &parser, Region ®ion, 770 llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars, 771 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) { 772 AllRegionParseArgs args; 773 args.privateArgs.emplace(privateVars, privateTypes, privateSyms); 774 return parseBlockArgRegion(parser, region, args); 775 } 776 777 static ParseResult parsePrivateReductionRegion( 778 OpAsmParser &parser, Region ®ion, 779 llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars, 780 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms, 781 ReductionModifierAttr &reductionMod, 782 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVars, 783 SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref, 784 ArrayAttr &reductionSyms) { 785 AllRegionParseArgs args; 786 args.privateArgs.emplace(privateVars, privateTypes, privateSyms); 787 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref, 788 
reductionSyms, &reductionMod); 789 return parseBlockArgRegion(parser, region, args); 790 } 791 792 static ParseResult parseTaskReductionRegion( 793 OpAsmParser &parser, Region ®ion, 794 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &taskReductionVars, 795 SmallVectorImpl<Type> &taskReductionTypes, 796 DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms) { 797 AllRegionParseArgs args; 798 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes, 799 taskReductionByref, taskReductionSyms); 800 return parseBlockArgRegion(parser, region, args); 801 } 802 803 static ParseResult parseUseDeviceAddrUseDevicePtrRegion( 804 OpAsmParser &parser, Region ®ion, 805 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &useDeviceAddrVars, 806 SmallVectorImpl<Type> &useDeviceAddrTypes, 807 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &useDevicePtrVars, 808 SmallVectorImpl<Type> &useDevicePtrTypes) { 809 AllRegionParseArgs args; 810 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes); 811 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes); 812 return parseBlockArgRegion(parser, region, args); 813 } 814 815 //===----------------------------------------------------------------------===// 816 // Printers for operations including clauses that define entry block arguments. 
//===----------------------------------------------------------------------===//

namespace {
/// Printer inputs for clauses whose entries are plain `operand -> block-arg`
/// pairs.
struct MapPrintArgs {
  ValueRange vars;
  TypeRange types;
  MapPrintArgs(ValueRange vars, TypeRange types) : vars(vars), types(types) {}
};
/// Printer inputs for the `private` clause: per-entry symbols and optional
/// map indices.
struct PrivatePrintArgs {
  ValueRange vars;
  TypeRange types;
  ArrayAttr syms;
  DenseI64ArrayAttr mapIndices;
  PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms,
                   DenseI64ArrayAttr mapIndices)
      : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
};
/// Printer inputs for reduction-like clauses: per-entry byref flags and
/// symbols, plus an optional reduction modifier.
struct ReductionPrintArgs {
  ValueRange vars;
  TypeRange types;
  DenseBoolArrayAttr byref;
  ArrayAttr syms;
  ReductionModifierAttr modifier;
  ReductionPrintArgs(ValueRange vars, TypeRange types, DenseBoolArrayAttr byref,
                     ArrayAttr syms, ReductionModifierAttr mod = nullptr)
      : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
};
/// Set of clauses an operation may print before its region; std::nullopt
/// entries are skipped.
struct AllRegionPrintArgs {
  std::optional<MapPrintArgs> hostEvalArgs;
  std::optional<ReductionPrintArgs> inReductionArgs;
  std::optional<MapPrintArgs> mapArgs;
  std::optional<PrivatePrintArgs> privateArgs;
  std::optional<ReductionPrintArgs> reductionArgs;
  std::optional<ReductionPrintArgs> taskReductionArgs;
  std::optional<MapPrintArgs> useDeviceAddrArgs;
  std::optional<MapPrintArgs> useDevicePtrArgs;
};
} // namespace

/// Prints one block-argument-defining clause as
///   clauseName `(` [`mod: ` modifier `, `] entry (`, ` entry)* ` : ` types `) `
/// where each entry is `[byref ][symbol ]operand -> block-arg[ [map_idx=N]]`.
/// Prints nothing when the clause defines no entry block arguments.
static void printClauseWithRegionArgs(
    OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
    ValueRange argsSubrange, ValueRange operands, TypeRange types,
    ArrayAttr symbols = nullptr, DenseI64ArrayAttr mapIndices = nullptr,
    DenseBoolArrayAttr byref = nullptr,
    ReductionModifierAttr modifier = nullptr) {
  if (argsSubrange.empty())
    return;

  p << clauseName << "(";

  if (modifier)
    p << "mod: " << stringifyReductionModifier(modifier.getValue()) << ", ";

  // Missing per-entry attributes are replaced by per-operand defaults (null
  // symbol, map index -1, byref false) so that the zip below can walk all
  // five ranges together; zip_equal expects them to have equal lengths.
  if (!symbols) {
    llvm::SmallVector<Attribute> values(operands.size(), nullptr);
    symbols = ArrayAttr::get(ctx, values);
  }

  if (!mapIndices) {
    llvm::SmallVector<int64_t> values(operands.size(), -1);
    mapIndices = DenseI64ArrayAttr::get(ctx, values);
  }

  if (!byref) {
    mlir::SmallVector<bool> values(operands.size(), false);
    byref = DenseBoolArrayAttr::get(ctx, values);
  }

  llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols,
                                        mapIndices.asArrayRef(),
                                        byref.asArrayRef()),
                        p, [&p](auto t) {
                          auto [op, arg, sym, map, isByRef] = t;
                          if (isByRef)
                            p << "byref ";
                          if (sym)
                            p << sym << " ";

                          p << op << " -> " << arg;

                          // -1 means "no map index" and is not printed.
                          if (map != -1)
                            p << " [map_idx=" << map << "]";
                        });
  p << " : ";
  llvm::interleaveComma(types, p);
  p << ") ";
}

/// Prints a map-like clause if the operation provides one.
static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx,
                                StringRef clauseName, ValueRange argsSubrange,
                                std::optional<MapPrintArgs> mapArgs) {
  if (mapArgs)
    printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange, mapArgs->vars,
                              mapArgs->types);
}

/// Prints a private-like clause (symbols and map indices) if the operation
/// provides one.
static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx,
                                StringRef clauseName, ValueRange argsSubrange,
                                std::optional<PrivatePrintArgs> privateArgs) {
  if (privateArgs)
    printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
                              privateArgs->vars, privateArgs->types,
                              privateArgs->syms, privateArgs->mapIndices);
}

/// Prints a reduction-like clause (symbols, byref flags and modifier) if the
/// operation provides one.
static void
printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
                    ValueRange argsSubrange,
                    std::optional<ReductionPrintArgs> reductionArgs) {
  if (reductionArgs)
    printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
                              reductionArgs->vars, reductionArgs->types,
                              reductionArgs->syms, /*mapIndices=*/nullptr,
                              reductionArgs->byref, reductionArgs->modifier);
}

/// Prints all block-argument-defining clauses of `op` in a fixed order, using
/// BlockArgOpenMPOpInterface to find each clause's entry block arguments, and
/// then prints `region` without its entry block argument list (the arguments
/// were already printed inside the clauses). `op` must implement
/// BlockArgOpenMPOpInterface; llvm::cast asserts this.
static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
                                const AllRegionPrintArgs &args) {
  auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op);
  MLIRContext *ctx = op->getContext();

  printBlockArgClause(p, ctx, "host_eval", iface.getHostEvalBlockArgs(),
                      args.hostEvalArgs);
  printBlockArgClause(p, ctx, "in_reduction", iface.getInReductionBlockArgs(),
                      args.inReductionArgs);
  printBlockArgClause(p, ctx, "map_entries", iface.getMapBlockArgs(),
                      args.mapArgs);
  printBlockArgClause(p, ctx, "private", iface.getPrivateBlockArgs(),
                      args.privateArgs);
  printBlockArgClause(p, ctx, "reduction", iface.getReductionBlockArgs(),
                      args.reductionArgs);
  printBlockArgClause(p, ctx, "task_reduction",
                      iface.getTaskReductionBlockArgs(),
                      args.taskReductionArgs);
  printBlockArgClause(p, ctx, "use_device_addr",
                      iface.getUseDeviceAddrBlockArgs(),
                      args.useDeviceAddrArgs);
  printBlockArgClause(p, ctx, "use_device_ptr",
                      iface.getUseDevicePtrBlockArgs(), args.useDevicePtrArgs);

  p.printRegion(region, /*printEntryBlockArgs=*/false);
}

/// Custom-directive printer for ops with host_eval, in_reduction, map_entries
/// and private (with map indices) block-argument clauses.
static void printHostEvalInReductionMapPrivateRegion(
    OpAsmPrinter &p, Operation *op, Region &region, ValueRange hostEvalVars,
    TypeRange hostEvalTypes, ValueRange inReductionVars,
    TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
    ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes,
    ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms,
    DenseI64ArrayAttr privateMaps) {
  AllRegionPrintArgs args;
  args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                               inReductionByref, inReductionSyms);
  args.mapArgs.emplace(mapVars, mapTypes);
  args.privateArgs.emplace(privateVars, privateTypes, privateSyms, privateMaps);
  printBlockArgRegion(p, op, region, args);
}

static void printInReductionPrivateRegion(
    OpAsmPrinter &p, Operation *op, Region &region, ValueRange
inReductionVars, 978 TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, 979 ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, 980 ArrayAttr privateSyms) { 981 AllRegionPrintArgs args; 982 args.inReductionArgs.emplace(inReductionVars, inReductionTypes, 983 inReductionByref, inReductionSyms); 984 args.privateArgs.emplace(privateVars, privateTypes, privateSyms, 985 /*mapIndices=*/nullptr); 986 printBlockArgRegion(p, op, region, args); 987 } 988 989 static void printInReductionPrivateReductionRegion( 990 OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange inReductionVars, 991 TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, 992 ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, 993 ArrayAttr privateSyms, ReductionModifierAttr reductionMod, 994 ValueRange reductionVars, TypeRange reductionTypes, 995 DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms) { 996 AllRegionPrintArgs args; 997 args.inReductionArgs.emplace(inReductionVars, inReductionTypes, 998 inReductionByref, inReductionSyms); 999 args.privateArgs.emplace(privateVars, privateTypes, privateSyms, 1000 /*mapIndices=*/nullptr); 1001 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref, 1002 reductionSyms, reductionMod); 1003 printBlockArgRegion(p, op, region, args); 1004 } 1005 1006 static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region ®ion, 1007 ValueRange privateVars, TypeRange privateTypes, 1008 ArrayAttr privateSyms) { 1009 AllRegionPrintArgs args; 1010 args.privateArgs.emplace(privateVars, privateTypes, privateSyms, 1011 /*mapIndices=*/nullptr); 1012 printBlockArgRegion(p, op, region, args); 1013 } 1014 1015 static void printPrivateReductionRegion( 1016 OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange privateVars, 1017 TypeRange privateTypes, ArrayAttr privateSyms, 1018 ReductionModifierAttr reductionMod, ValueRange reductionVars, 1019 TypeRange reductionTypes, 
DenseBoolArrayAttr reductionByref, 1020 ArrayAttr reductionSyms) { 1021 AllRegionPrintArgs args; 1022 args.privateArgs.emplace(privateVars, privateTypes, privateSyms, 1023 /*mapIndices=*/nullptr); 1024 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref, 1025 reductionSyms, reductionMod); 1026 printBlockArgRegion(p, op, region, args); 1027 } 1028 1029 static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op, 1030 Region ®ion, 1031 ValueRange taskReductionVars, 1032 TypeRange taskReductionTypes, 1033 DenseBoolArrayAttr taskReductionByref, 1034 ArrayAttr taskReductionSyms) { 1035 AllRegionPrintArgs args; 1036 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes, 1037 taskReductionByref, taskReductionSyms); 1038 printBlockArgRegion(p, op, region, args); 1039 } 1040 1041 static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op, 1042 Region ®ion, 1043 ValueRange useDeviceAddrVars, 1044 TypeRange useDeviceAddrTypes, 1045 ValueRange useDevicePtrVars, 1046 TypeRange useDevicePtrTypes) { 1047 AllRegionPrintArgs args; 1048 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes); 1049 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes); 1050 printBlockArgRegion(p, op, region, args); 1051 } 1052 1053 /// Verifies Reduction Clause 1054 static LogicalResult 1055 verifyReductionVarList(Operation *op, std::optional<ArrayAttr> reductionSyms, 1056 OperandRange reductionVars, 1057 std::optional<ArrayRef<bool>> reductionByref) { 1058 if (!reductionVars.empty()) { 1059 if (!reductionSyms || reductionSyms->size() != reductionVars.size()) 1060 return op->emitOpError() 1061 << "expected as many reduction symbol references " 1062 "as reduction variables"; 1063 if (reductionByref && reductionByref->size() != reductionVars.size()) 1064 return op->emitError() << "expected as many reduction variable by " 1065 "reference attributes as reduction variables"; 1066 } else { 1067 if (reductionSyms) 
1068 return op->emitOpError() << "unexpected reduction symbol references"; 1069 return success(); 1070 } 1071 1072 // TODO: The followings should be done in 1073 // SymbolUserOpInterface::verifySymbolUses. 1074 DenseSet<Value> accumulators; 1075 for (auto args : llvm::zip(reductionVars, *reductionSyms)) { 1076 Value accum = std::get<0>(args); 1077 1078 if (!accumulators.insert(accum).second) 1079 return op->emitOpError() << "accumulator variable used more than once"; 1080 1081 Type varType = accum.getType(); 1082 auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args)); 1083 auto decl = 1084 SymbolTable::lookupNearestSymbolFrom<DeclareReductionOp>(op, symbolRef); 1085 if (!decl) 1086 return op->emitOpError() << "expected symbol reference " << symbolRef 1087 << " to point to a reduction declaration"; 1088 1089 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType) 1090 return op->emitOpError() 1091 << "expected accumulator (" << varType 1092 << ") to be the same type as reduction declaration (" 1093 << decl.getAccumulatorType() << ")"; 1094 } 1095 1096 return success(); 1097 } 1098 1099 //===----------------------------------------------------------------------===// 1100 // Parser, printer and verifier for Copyprivate 1101 //===----------------------------------------------------------------------===// 1102 1103 /// copyprivate-entry-list ::= copyprivate-entry 1104 /// | copyprivate-entry-list `,` copyprivate-entry 1105 /// copyprivate-entry ::= ssa-id `->` symbol-ref `:` type 1106 static ParseResult parseCopyprivate( 1107 OpAsmParser &parser, 1108 SmallVectorImpl<OpAsmParser::UnresolvedOperand> ©privateVars, 1109 SmallVectorImpl<Type> ©privateTypes, ArrayAttr ©privateSyms) { 1110 SmallVector<SymbolRefAttr> symsVec; 1111 if (failed(parser.parseCommaSeparatedList([&]() { 1112 if (parser.parseOperand(copyprivateVars.emplace_back()) || 1113 parser.parseArrow() || 1114 parser.parseAttribute(symsVec.emplace_back()) || 1115 
parser.parseColonType(copyprivateTypes.emplace_back())) 1116 return failure(); 1117 return success(); 1118 }))) 1119 return failure(); 1120 SmallVector<Attribute> syms(symsVec.begin(), symsVec.end()); 1121 copyprivateSyms = ArrayAttr::get(parser.getContext(), syms); 1122 return success(); 1123 } 1124 1125 /// Print Copyprivate clause 1126 static void printCopyprivate(OpAsmPrinter &p, Operation *op, 1127 OperandRange copyprivateVars, 1128 TypeRange copyprivateTypes, 1129 std::optional<ArrayAttr> copyprivateSyms) { 1130 if (!copyprivateSyms.has_value()) 1131 return; 1132 llvm::interleaveComma( 1133 llvm::zip(copyprivateVars, *copyprivateSyms, copyprivateTypes), p, 1134 [&](const auto &args) { 1135 p << std::get<0>(args) << " -> " << std::get<1>(args) << " : " 1136 << std::get<2>(args); 1137 }); 1138 } 1139 1140 /// Verifies CopyPrivate Clause 1141 static LogicalResult 1142 verifyCopyprivateVarList(Operation *op, OperandRange copyprivateVars, 1143 std::optional<ArrayAttr> copyprivateSyms) { 1144 size_t copyprivateSymsSize = 1145 copyprivateSyms.has_value() ? 
copyprivateSyms->size() : 0; 1146 if (copyprivateSymsSize != copyprivateVars.size()) 1147 return op->emitOpError() << "inconsistent number of copyprivate vars (= " 1148 << copyprivateVars.size() 1149 << ") and functions (= " << copyprivateSymsSize 1150 << "), both must be equal"; 1151 if (!copyprivateSyms.has_value()) 1152 return success(); 1153 1154 for (auto copyprivateVarAndSym : 1155 llvm::zip(copyprivateVars, *copyprivateSyms)) { 1156 auto symbolRef = 1157 llvm::cast<SymbolRefAttr>(std::get<1>(copyprivateVarAndSym)); 1158 std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>> 1159 funcOp; 1160 if (mlir::func::FuncOp mlirFuncOp = 1161 SymbolTable::lookupNearestSymbolFrom<mlir::func::FuncOp>(op, 1162 symbolRef)) 1163 funcOp = mlirFuncOp; 1164 else if (mlir::LLVM::LLVMFuncOp llvmFuncOp = 1165 SymbolTable::lookupNearestSymbolFrom<mlir::LLVM::LLVMFuncOp>( 1166 op, symbolRef)) 1167 funcOp = llvmFuncOp; 1168 1169 auto getNumArguments = [&] { 1170 return std::visit([](auto &f) { return f.getNumArguments(); }, *funcOp); 1171 }; 1172 1173 auto getArgumentType = [&](unsigned i) { 1174 return std::visit([i](auto &f) { return f.getArgumentTypes()[i]; }, 1175 *funcOp); 1176 }; 1177 1178 if (!funcOp) 1179 return op->emitOpError() << "expected symbol reference " << symbolRef 1180 << " to point to a copy function"; 1181 1182 if (getNumArguments() != 2) 1183 return op->emitOpError() 1184 << "expected copy function " << symbolRef << " to have 2 operands"; 1185 1186 Type argTy = getArgumentType(0); 1187 if (argTy != getArgumentType(1)) 1188 return op->emitOpError() << "expected copy function " << symbolRef 1189 << " arguments to have the same type"; 1190 1191 Type varType = std::get<0>(copyprivateVarAndSym).getType(); 1192 if (argTy != varType) 1193 return op->emitOpError() 1194 << "expected copy function arguments' type (" << argTy 1195 << ") to be the same as copyprivate variable's type (" << varType 1196 << ")"; 1197 } 1198 1199 return success(); 1200 } 1201 
1202 //===----------------------------------------------------------------------===// 1203 // Parser, printer and verifier for DependVarList 1204 //===----------------------------------------------------------------------===// 1205 1206 /// depend-entry-list ::= depend-entry 1207 /// | depend-entry-list `,` depend-entry 1208 /// depend-entry ::= depend-kind `->` ssa-id `:` type 1209 static ParseResult 1210 parseDependVarList(OpAsmParser &parser, 1211 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dependVars, 1212 SmallVectorImpl<Type> &dependTypes, ArrayAttr &dependKinds) { 1213 SmallVector<ClauseTaskDependAttr> kindsVec; 1214 if (failed(parser.parseCommaSeparatedList([&]() { 1215 StringRef keyword; 1216 if (parser.parseKeyword(&keyword) || parser.parseArrow() || 1217 parser.parseOperand(dependVars.emplace_back()) || 1218 parser.parseColonType(dependTypes.emplace_back())) 1219 return failure(); 1220 if (std::optional<ClauseTaskDepend> keywordDepend = 1221 (symbolizeClauseTaskDepend(keyword))) 1222 kindsVec.emplace_back( 1223 ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend)); 1224 else 1225 return failure(); 1226 return success(); 1227 }))) 1228 return failure(); 1229 SmallVector<Attribute> kinds(kindsVec.begin(), kindsVec.end()); 1230 dependKinds = ArrayAttr::get(parser.getContext(), kinds); 1231 return success(); 1232 } 1233 1234 /// Print Depend clause 1235 static void printDependVarList(OpAsmPrinter &p, Operation *op, 1236 OperandRange dependVars, TypeRange dependTypes, 1237 std::optional<ArrayAttr> dependKinds) { 1238 1239 for (unsigned i = 0, e = dependKinds->size(); i < e; ++i) { 1240 if (i != 0) 1241 p << ", "; 1242 p << stringifyClauseTaskDepend( 1243 llvm::cast<mlir::omp::ClauseTaskDependAttr>((*dependKinds)[i]) 1244 .getValue()) 1245 << " -> " << dependVars[i] << " : " << dependTypes[i]; 1246 } 1247 } 1248 1249 /// Verifies Depend clause 1250 static LogicalResult verifyDependVarList(Operation *op, 1251 std::optional<ArrayAttr> dependKinds, 
1252 OperandRange dependVars) { 1253 if (!dependVars.empty()) { 1254 if (!dependKinds || dependKinds->size() != dependVars.size()) 1255 return op->emitOpError() << "expected as many depend values" 1256 " as depend variables"; 1257 } else { 1258 if (dependKinds && !dependKinds->empty()) 1259 return op->emitOpError() << "unexpected depend values"; 1260 return success(); 1261 } 1262 1263 return success(); 1264 } 1265 1266 //===----------------------------------------------------------------------===// 1267 // Parser, printer and verifier for Synchronization Hint (2.17.12) 1268 //===----------------------------------------------------------------------===// 1269 1270 /// Parses a Synchronization Hint clause. The value of hint is an integer 1271 /// which is a combination of different hints from `omp_sync_hint_t`. 1272 /// 1273 /// hint-clause = `hint` `(` hint-value `)` 1274 static ParseResult parseSynchronizationHint(OpAsmParser &parser, 1275 IntegerAttr &hintAttr) { 1276 StringRef hintKeyword; 1277 int64_t hint = 0; 1278 if (succeeded(parser.parseOptionalKeyword("none"))) { 1279 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0); 1280 return success(); 1281 } 1282 auto parseKeyword = [&]() -> ParseResult { 1283 if (failed(parser.parseKeyword(&hintKeyword))) 1284 return failure(); 1285 if (hintKeyword == "uncontended") 1286 hint |= 1; 1287 else if (hintKeyword == "contended") 1288 hint |= 2; 1289 else if (hintKeyword == "nonspeculative") 1290 hint |= 4; 1291 else if (hintKeyword == "speculative") 1292 hint |= 8; 1293 else 1294 return parser.emitError(parser.getCurrentLocation()) 1295 << hintKeyword << " is not a valid hint"; 1296 return success(); 1297 }; 1298 if (parser.parseCommaSeparatedList(parseKeyword)) 1299 return failure(); 1300 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint); 1301 return success(); 1302 } 1303 1304 /// Prints a Synchronization Hint clause 1305 static void printSynchronizationHint(OpAsmPrinter &p, Operation 
*op, 1306 IntegerAttr hintAttr) { 1307 int64_t hint = hintAttr.getInt(); 1308 1309 if (hint == 0) { 1310 p << "none"; 1311 return; 1312 } 1313 1314 // Helper function to get n-th bit from the right end of `value` 1315 auto bitn = [](int value, int n) -> bool { return value & (1 << n); }; 1316 1317 bool uncontended = bitn(hint, 0); 1318 bool contended = bitn(hint, 1); 1319 bool nonspeculative = bitn(hint, 2); 1320 bool speculative = bitn(hint, 3); 1321 1322 SmallVector<StringRef> hints; 1323 if (uncontended) 1324 hints.push_back("uncontended"); 1325 if (contended) 1326 hints.push_back("contended"); 1327 if (nonspeculative) 1328 hints.push_back("nonspeculative"); 1329 if (speculative) 1330 hints.push_back("speculative"); 1331 1332 llvm::interleaveComma(hints, p); 1333 } 1334 1335 /// Verifies a synchronization hint clause 1336 static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) { 1337 1338 // Helper function to get n-th bit from the right end of `value` 1339 auto bitn = [](int value, int n) -> bool { return value & (1 << n); }; 1340 1341 bool uncontended = bitn(hint, 0); 1342 bool contended = bitn(hint, 1); 1343 bool nonspeculative = bitn(hint, 2); 1344 bool speculative = bitn(hint, 3); 1345 1346 if (uncontended && contended) 1347 return op->emitOpError() << "the hints omp_sync_hint_uncontended and " 1348 "omp_sync_hint_contended cannot be combined"; 1349 if (nonspeculative && speculative) 1350 return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and " 1351 "omp_sync_hint_speculative cannot be combined."; 1352 return success(); 1353 } 1354 1355 //===----------------------------------------------------------------------===// 1356 // Parser, printer and verifier for Target 1357 //===----------------------------------------------------------------------===// 1358 1359 // Helper function to get bitwise AND of `value` and 'flag' 1360 uint64_t mapTypeToBitFlag(uint64_t value, 1361 llvm::omp::OpenMPOffloadMappingFlags flag) { 1362 
return value & llvm::to_underlying(flag); 1363 } 1364 1365 /// Parses a map_entries map type from a string format back into its numeric 1366 /// value. 1367 /// 1368 /// map-clause = `map_clauses ( ( `(` `always, `? `close, `? `present, `? ( 1369 /// `to` | `from` | `delete` `)` )+ `)` ) 1370 static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) { 1371 llvm::omp::OpenMPOffloadMappingFlags mapTypeBits = 1372 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE; 1373 1374 // This simply verifies the correct keyword is read in, the 1375 // keyword itself is stored inside of the operation 1376 auto parseTypeAndMod = [&]() -> ParseResult { 1377 StringRef mapTypeMod; 1378 if (parser.parseKeyword(&mapTypeMod)) 1379 return failure(); 1380 1381 if (mapTypeMod == "always") 1382 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS; 1383 1384 if (mapTypeMod == "implicit") 1385 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT; 1386 1387 if (mapTypeMod == "close") 1388 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE; 1389 1390 if (mapTypeMod == "present") 1391 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT; 1392 1393 if (mapTypeMod == "to") 1394 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO; 1395 1396 if (mapTypeMod == "from") 1397 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; 1398 1399 if (mapTypeMod == "tofrom") 1400 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO | 1401 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; 1402 1403 if (mapTypeMod == "delete") 1404 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE; 1405 1406 return success(); 1407 }; 1408 1409 if (parser.parseCommaSeparatedList(parseTypeAndMod)) 1410 return failure(); 1411 1412 mapType = parser.getBuilder().getIntegerAttr( 1413 parser.getBuilder().getIntegerType(64, /*isSigned=*/false), 1414 llvm::to_underlying(mapTypeBits)); 1415 1416 return 
success(); 1417 } 1418 1419 /// Prints a map_entries map type from its numeric value out into its string 1420 /// format. 1421 static void printMapClause(OpAsmPrinter &p, Operation *op, 1422 IntegerAttr mapType) { 1423 uint64_t mapTypeBits = mapType.getUInt(); 1424 1425 bool emitAllocRelease = true; 1426 llvm::SmallVector<std::string, 4> mapTypeStrs; 1427 1428 // handling of always, close, present placed at the beginning of the string 1429 // to aid readability 1430 if (mapTypeToBitFlag(mapTypeBits, 1431 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS)) 1432 mapTypeStrs.push_back("always"); 1433 if (mapTypeToBitFlag(mapTypeBits, 1434 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT)) 1435 mapTypeStrs.push_back("implicit"); 1436 if (mapTypeToBitFlag(mapTypeBits, 1437 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE)) 1438 mapTypeStrs.push_back("close"); 1439 if (mapTypeToBitFlag(mapTypeBits, 1440 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT)) 1441 mapTypeStrs.push_back("present"); 1442 1443 // special handling of to/from/tofrom/delete and release/alloc, release + 1444 // alloc are the abscense of one of the other flags, whereas tofrom requires 1445 // both the to and from flag to be set. 
1446 bool to = mapTypeToBitFlag(mapTypeBits, 1447 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO); 1448 bool from = mapTypeToBitFlag( 1449 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM); 1450 if (to && from) { 1451 emitAllocRelease = false; 1452 mapTypeStrs.push_back("tofrom"); 1453 } else if (from) { 1454 emitAllocRelease = false; 1455 mapTypeStrs.push_back("from"); 1456 } else if (to) { 1457 emitAllocRelease = false; 1458 mapTypeStrs.push_back("to"); 1459 } 1460 if (mapTypeToBitFlag(mapTypeBits, 1461 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE)) { 1462 emitAllocRelease = false; 1463 mapTypeStrs.push_back("delete"); 1464 } 1465 if (emitAllocRelease) 1466 mapTypeStrs.push_back("exit_release_or_enter_alloc"); 1467 1468 for (unsigned int i = 0; i < mapTypeStrs.size(); ++i) { 1469 p << mapTypeStrs[i]; 1470 if (i + 1 < mapTypeStrs.size()) { 1471 p << ", "; 1472 } 1473 } 1474 } 1475 1476 static ParseResult parseMembersIndex(OpAsmParser &parser, 1477 ArrayAttr &membersIdx) { 1478 SmallVector<Attribute> values, memberIdxs; 1479 1480 auto parseIndices = [&]() -> ParseResult { 1481 int64_t value; 1482 if (parser.parseInteger(value)) 1483 return failure(); 1484 values.push_back(IntegerAttr::get(parser.getBuilder().getIntegerType(64), 1485 APInt(64, value, /*isSigned=*/false))); 1486 return success(); 1487 }; 1488 1489 do { 1490 if (failed(parser.parseLSquare())) 1491 return failure(); 1492 1493 if (parser.parseCommaSeparatedList(parseIndices)) 1494 return failure(); 1495 1496 if (failed(parser.parseRSquare())) 1497 return failure(); 1498 1499 memberIdxs.push_back(ArrayAttr::get(parser.getContext(), values)); 1500 values.clear(); 1501 } while (succeeded(parser.parseOptionalComma())); 1502 1503 if (!memberIdxs.empty()) 1504 membersIdx = ArrayAttr::get(parser.getContext(), memberIdxs); 1505 1506 return success(); 1507 } 1508 1509 static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op, 1510 ArrayAttr membersIdx) { 1511 if (!membersIdx) 1512 
return; 1513 1514 llvm::interleaveComma(membersIdx, p, [&p](Attribute v) { 1515 p << "["; 1516 auto memberIdx = cast<ArrayAttr>(v); 1517 llvm::interleaveComma(memberIdx.getValue(), p, [&p](Attribute v2) { 1518 p << cast<IntegerAttr>(v2).getInt(); 1519 }); 1520 p << "]"; 1521 }); 1522 } 1523 1524 static void printCaptureType(OpAsmPrinter &p, Operation *op, 1525 VariableCaptureKindAttr mapCaptureType) { 1526 std::string typeCapStr; 1527 llvm::raw_string_ostream typeCap(typeCapStr); 1528 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef) 1529 typeCap << "ByRef"; 1530 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy) 1531 typeCap << "ByCopy"; 1532 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType) 1533 typeCap << "VLAType"; 1534 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This) 1535 typeCap << "This"; 1536 p << typeCapStr; 1537 } 1538 1539 static ParseResult parseCaptureType(OpAsmParser &parser, 1540 VariableCaptureKindAttr &mapCaptureType) { 1541 StringRef mapCaptureKey; 1542 if (parser.parseKeyword(&mapCaptureKey)) 1543 return failure(); 1544 1545 if (mapCaptureKey == "This") 1546 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get( 1547 parser.getContext(), mlir::omp::VariableCaptureKind::This); 1548 if (mapCaptureKey == "ByRef") 1549 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get( 1550 parser.getContext(), mlir::omp::VariableCaptureKind::ByRef); 1551 if (mapCaptureKey == "ByCopy") 1552 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get( 1553 parser.getContext(), mlir::omp::VariableCaptureKind::ByCopy); 1554 if (mapCaptureKey == "VLAType") 1555 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get( 1556 parser.getContext(), mlir::omp::VariableCaptureKind::VLAType); 1557 1558 return success(); 1559 } 1560 1561 static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars) { 1562 llvm::DenseSet<mlir::TypedValue<mlir::omp::PointerLikeType>> 
updateToVars; 1563 llvm::DenseSet<mlir::TypedValue<mlir::omp::PointerLikeType>> updateFromVars; 1564 1565 for (auto mapOp : mapVars) { 1566 if (!mapOp.getDefiningOp()) 1567 emitError(op->getLoc(), "missing map operation"); 1568 1569 if (auto mapInfoOp = 1570 mlir::dyn_cast<mlir::omp::MapInfoOp>(mapOp.getDefiningOp())) { 1571 if (!mapInfoOp.getMapType().has_value()) 1572 emitError(op->getLoc(), "missing map type for map operand"); 1573 1574 if (!mapInfoOp.getMapCaptureType().has_value()) 1575 emitError(op->getLoc(), "missing map capture type for map operand"); 1576 1577 uint64_t mapTypeBits = mapInfoOp.getMapType().value(); 1578 1579 bool to = mapTypeToBitFlag( 1580 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO); 1581 bool from = mapTypeToBitFlag( 1582 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM); 1583 bool del = mapTypeToBitFlag( 1584 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE); 1585 1586 bool always = mapTypeToBitFlag( 1587 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS); 1588 bool close = mapTypeToBitFlag( 1589 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE); 1590 bool implicit = mapTypeToBitFlag( 1591 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT); 1592 1593 if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del) 1594 return emitError(op->getLoc(), 1595 "to, from, tofrom and alloc map types are permitted"); 1596 1597 if (isa<TargetEnterDataOp>(op) && (from || del)) 1598 return emitError(op->getLoc(), "to and alloc map types are permitted"); 1599 1600 if (isa<TargetExitDataOp>(op) && to) 1601 return emitError(op->getLoc(), 1602 "from, release and delete map types are permitted"); 1603 1604 if (isa<TargetUpdateOp>(op)) { 1605 if (del) { 1606 return emitError(op->getLoc(), 1607 "at least one of to or from map types must be " 1608 "specified, other map types are not permitted"); 1609 } 1610 1611 if (!to && !from) { 1612 return 
emitError(op->getLoc(), 1613 "at least one of to or from map types must be " 1614 "specified, other map types are not permitted"); 1615 } 1616 1617 auto updateVar = mapInfoOp.getVarPtr(); 1618 1619 if ((to && from) || (to && updateFromVars.contains(updateVar)) || 1620 (from && updateToVars.contains(updateVar))) { 1621 return emitError( 1622 op->getLoc(), 1623 "either to or from map types can be specified, not both"); 1624 } 1625 1626 if (always || close || implicit) { 1627 return emitError( 1628 op->getLoc(), 1629 "present, mapper and iterator map type modifiers are permitted"); 1630 } 1631 1632 to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar); 1633 } 1634 } else { 1635 emitError(op->getLoc(), "map argument is not a map entry operation"); 1636 } 1637 } 1638 1639 return success(); 1640 } 1641 1642 static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp) { 1643 std::optional<DenseI64ArrayAttr> privateMapIndices = 1644 targetOp.getPrivateMapsAttr(); 1645 1646 // None of the private operands are mapped. 
// When present, `private_maps` must have one entry per `private` operand.
  if (!privateMapIndices.has_value() || !privateMapIndices.value())
    return success();

  OperandRange privateVars = targetOp.getPrivateVars();

  if (privateMapIndices.value().size() !=
      static_cast<int64_t>(privateVars.size()))
    return emitError(targetOp.getLoc(), "sizes of `private` operand range and "
                                        "`private_maps` attribute mismatch");

  return success();
}

//===----------------------------------------------------------------------===//
// TargetDataOp
//===----------------------------------------------------------------------===//

/// Builds a TargetDataOp from the aggregated clause operands in \p clauses.
void TargetDataOp::build(OpBuilder &builder, OperationState &state,
                         const TargetDataOperands &clauses) {
  TargetDataOp::build(builder, state, clauses.device, clauses.ifExpr,
                      clauses.mapVars, clauses.useDeviceAddrVars,
                      clauses.useDevicePtrVars);
}

/// Requires at least one of the map, use_device_ptr or use_device_addr operand
/// lists to be non-empty, then verifies the map clause operands.
LogicalResult TargetDataOp::verify() {
  if (getMapVars().empty() && getUseDevicePtrVars().empty() &&
      getUseDeviceAddrVars().empty()) {
    return ::emitError(this->getLoc(),
                       "At least one of map, use_device_ptr_vars, or "
                       "use_device_addr_vars operand must be present");
  }
  return verifyMapClause(*this, getMapVars());
}

//===----------------------------------------------------------------------===//
// TargetEnterDataOp
//===----------------------------------------------------------------------===//

/// Builds a TargetEnterDataOp from \p clauses. The depend kinds are wrapped in
/// an ArrayAttr, or left null when the list is empty (see makeArrayAttr).
void TargetEnterDataOp::build(
    OpBuilder &builder, OperationState &state,
    const TargetEnterExitUpdateDataOperands &clauses) {
  MLIRContext *ctx = builder.getContext();
  TargetEnterDataOp::build(builder, state,
                           makeArrayAttr(ctx, clauses.dependKinds),
                           clauses.dependVars, clauses.device, clauses.ifExpr,
                           clauses.mapVars, clauses.nowait);
}

/// Verifies the depend clause and, only if that succeeds, the map clause.
LogicalResult TargetEnterDataOp::verify() {
  LogicalResult verifyDependVars =
      verifyDependVarList(*this, getDependKinds(), getDependVars());
  return failed(verifyDependVars) ? verifyDependVars
                                  : verifyMapClause(*this, getMapVars());
}

//===----------------------------------------------------------------------===//
// TargetExitDataOp
//===----------------------------------------------------------------------===//

/// Builds a TargetExitDataOp from \p clauses. The depend kinds are wrapped in
/// an ArrayAttr, or left null when the list is empty (see makeArrayAttr).
void TargetExitDataOp::build(OpBuilder &builder, OperationState &state,
                             const TargetEnterExitUpdateDataOperands &clauses) {
  MLIRContext *ctx = builder.getContext();
  TargetExitDataOp::build(builder, state,
                          makeArrayAttr(ctx, clauses.dependKinds),
                          clauses.dependVars, clauses.device, clauses.ifExpr,
                          clauses.mapVars, clauses.nowait);
}

/// Verifies the depend clause and, only if that succeeds, the map clause.
LogicalResult TargetExitDataOp::verify() {
  LogicalResult verifyDependVars =
      verifyDependVarList(*this, getDependKinds(), getDependVars());
  return failed(verifyDependVars) ? verifyDependVars
                                  : verifyMapClause(*this, getMapVars());
}

//===----------------------------------------------------------------------===//
// TargetUpdateOp
//===----------------------------------------------------------------------===//

/// Builds a TargetUpdateOp from \p clauses. The depend kinds are wrapped in an
/// ArrayAttr, or left null when the list is empty (see makeArrayAttr).
void TargetUpdateOp::build(OpBuilder &builder, OperationState &state,
                           const TargetEnterExitUpdateDataOperands &clauses) {
  MLIRContext *ctx = builder.getContext();
  TargetUpdateOp::build(builder, state, makeArrayAttr(ctx, clauses.dependKinds),
                        clauses.dependVars, clauses.device, clauses.ifExpr,
                        clauses.mapVars, clauses.nowait);
}

/// Verifies the depend clause and, only if that succeeds, the map clause.
LogicalResult TargetUpdateOp::verify() {
  LogicalResult verifyDependVars =
      verifyDependVarList(*this, getDependKinds(), getDependVars());
  return failed(verifyDependVars) ? verifyDependVars
                                  : verifyMapClause(*this, getMapVars());
}

//===----------------------------------------------------------------------===//
// TargetOp
//===----------------------------------------------------------------------===//

/// Builds a TargetOp from \p clauses. The clause operands listed in the TODO
/// below are not stored yet and are passed as empty/null; `private_maps` is
/// also left null.
void TargetOp::build(OpBuilder &builder, OperationState &state,
                     const TargetOperands &clauses) {
  MLIRContext *ctx = builder.getContext();
  // TODO Store clauses in op: allocateVars, allocatorVars, inReductionVars,
  // inReductionByref, inReductionSyms.
  TargetOp::build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
                  clauses.bare, makeArrayAttr(ctx, clauses.dependKinds),
                  clauses.dependVars, clauses.device, clauses.hasDeviceAddrVars,
                  clauses.hostEvalVars, clauses.ifExpr,
                  /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
                  /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
                  clauses.mapVars, clauses.nowait, clauses.privateVars,
                  makeArrayAttr(ctx, clauses.privateSyms), clauses.threadLimit,
                  /*private_maps=*/nullptr);
}

/// Verifies, in order, the depend clause, the map clause and the mapping of
/// private variables, returning the first failure encountered.
LogicalResult TargetOp::verify() {
  LogicalResult verifyDependVars =
      verifyDependVarList(*this, getDependKinds(), getDependVars());

  if (failed(verifyDependVars))
    return verifyDependVars;

  LogicalResult verifyMapVars = verifyMapClause(*this, getMapVars());

  if (failed(verifyMapVars))
    return verifyMapVars;

  return verifyPrivateVarsMapping(*this);
}

/// Region verification for omp.target:
///  - at most one omp.teams may appear directly in the target region;
///  - every use of a host_eval block argument must be one of the legal uses
///    listed in the loop below; any other use is an error.
LogicalResult TargetOp::verifyRegions() {
  auto teamsOps = getOps<TeamsOp>();
  if (std::distance(teamsOps.begin(), teamsOps.end()) > 1)
    return emitError("target containing multiple 'omp.teams' nested ops");

  // Check that host_eval values are only used in legal ways.
  // Some uses are only legal for particular kernel execution modes, so the
  // mode is computed once up front.
  llvm::omp::OMPTgtExecModeFlags execFlags = getKernelExecFlags();
  for (Value hostEvalArg :
       cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
    for (Operation *user : hostEvalArg.getUsers()) {
      // omp.teams: only as a num_teams bound or as thread_limit.
      if (auto teamsOp = dyn_cast<TeamsOp>(user)) {
        if (llvm::is_contained({teamsOp.getNumTeamsLower(),
                                teamsOp.getNumTeamsUpper(),
                                teamsOp.getThreadLimit()},
                               hostEvalArg))
          continue;

        return emitOpError() << "host_eval argument only legal as 'num_teams' "
                                "and 'thread_limit' in 'omp.teams'";
      }
      // omp.parallel: only as num_threads, and only for SPMD kernels.
      if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
        if (execFlags == llvm::omp::OMP_TGT_EXEC_MODE_SPMD &&
            hostEvalArg == parallelOp.getNumThreads())
          continue;

        return emitOpError()
               << "host_eval argument only legal as 'num_threads' in "
                  "'omp.parallel' when representing target SPMD";
      }
      // omp.loop_nest: only as a lower/upper bound or step, and only when the
      // kernel is not Generic (i.e. SPMD or Generic-SPMD).
      if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
        if (execFlags != llvm::omp::OMP_TGT_EXEC_MODE_GENERIC &&
            (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
             llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
             llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
          continue;

        return emitOpError() << "host_eval argument only legal as loop bounds "
                                "and steps in 'omp.loop_nest' when "
                                "representing target SPMD or Generic-SPMD";
      }

      // Any other kind of user is illegal.
      return emitOpError() << "host_eval argument illegal use in '"
                           << user->getName() << "' operation";
    }
  }
  return success();
}

/// Returns whether \p op may appear next to an OpenMP operation that is being
/// considered for capture in TargetOp::getInnermostCapturedOmpOp():
///  - a null \p op is rejected;
///  - OpenMP dialect ops are allowed only if they are terminators;
///  - other ops implementing MemoryEffectOpInterface are rejected if any of
///    their effects is a write to an AutomaticAllocationScopeResource;
///  - other ops that do not implement MemoryEffectOpInterface are allowed.
/// NOTE(review): an earlier description said only ops with *known* memory
/// effects and no write effect are allowed; the implementation is more
/// permissive (see the last two points) -- confirm which is intended.
static bool siblingAllowedInCapture(Operation *op) {
  if (!op)
    return false;

  // Identify OpenMP ops by comparing against the loaded OpenMP dialect.
  bool isOmpDialect =
      op->getContext()->getLoadedDialect<omp::OpenMPDialect>() ==
      op->getDialect();

  // OpenMP siblings are only tolerated when they are terminators.
  if (isOmpDialect)
    return op->hasTrait<OpTrait::IsTerminator>();

  // Reject ops reporting a write to an automatic allocation scope resource;
  // every other reported effect is tolerated.
  if (auto memOp = dyn_cast<MemoryEffectOpInterface>(op)) {
    SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4> effects;
    memOp.getEffects(effects);
    return !llvm::any_of(effects, [&](MemoryEffects::EffectInstance &effect) {
      return isa<MemoryEffects::Write>(effect.getEffect()) &&
             isa<SideEffects::AutomaticAllocationScopeResource>(
                 effect.getResource());
    });
  }
  // NOTE(review): ops that do not implement MemoryEffectOpInterface (unknown
  // effects) are allowed.
  return true;
}

/// Returns the innermost OpenMP operation with regions that this target
/// region captures, or null if there is none.
///
/// Ops are visited from outermost to innermost. An op is recorded as captured
/// if all of the following hold:
///  - no successor of its block can reach that block (it runs at most once);
///  - its block dominates every reachable block without successors in its
///    parent region (it runs before every exit);
///  - every other op in its parent region passes siblingAllowedInCapture().
/// The walk stops at the first op failing these checks, or right after an
/// omp.loop_nest is recorded. Ops of other dialects, and OpenMP ops without
/// regions, are skipped, so nothing nested inside them is visited.
Operation *TargetOp::getInnermostCapturedOmpOp() {
  Dialect *ompDialect = (*this)->getDialect();
  Operation *capturedOp = nullptr;
  DominanceInfo domInfo;

  // Process in pre-order to check operations from outermost to innermost,
  // ensuring we only enter the region of an operation if it meets the criteria
  // for being captured. We stop the exploration of nested operations as soon as
  // we process a region holding no operations to be captured.
  walk<WalkOrder::PreOrder>([&](Operation *op) {
    if (op == *this)
      return WalkResult::advance();

    // Ignore operations of other dialects or omp operations with no regions,
    // because these will only be checked if they are siblings of an omp
    // operation that can potentially be captured.
    bool isOmpDialect = op->getDialect() == ompDialect;
    bool hasRegions = op->getNumRegions() > 0;
    if (!isOmpDialect || !hasRegions)
      return WalkResult::skip();

    // This operation cannot be captured if it can be executed more than once
    // (i.e. its block's successors can reach it) or if it's not guaranteed to
    // be executed before all exits of the region (i.e. it doesn't dominate all
    // blocks with no successors reachable from the entry block).
    Region *parentRegion = op->getParentRegion();
    Block *parentBlock = op->getBlock();

    for (Block *successor : parentBlock->getSuccessors())
      if (successor->isReachable(parentBlock))
        return WalkResult::interrupt();

    for (Block &block : *parentRegion)
      if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() &&
          !domInfo.dominates(parentBlock, &block))
        return WalkResult::interrupt();

    // Don't capture this op if it has a not-allowed sibling, and stop recursing
    // into nested operations.
    for (Operation &sibling : op->getParentRegion()->getOps())
      if (&sibling != op && !siblingAllowedInCapture(&sibling))
        return WalkResult::interrupt();

    // Don't continue capturing nested operations if we reach an omp.loop_nest.
    // Otherwise, process the contents of this operation.
    capturedOp = op;
    return llvm::isa<LoopNestOp>(op) ? WalkResult::interrupt()
                                     : WalkResult::advance();
  });

  return capturedOp;
}

/// Determines the kernel execution mode of this target region from its
/// innermost captured OpenMP op (see getInnermostCapturedOmpOp()):
///  - SPMD for target-teams-distribute-parallel-wsloop[-simd];
///  - Generic-SPMD for target-teams-distribute[-simd];
///  - Generic otherwise, including when no omp.loop_nest is captured.
/// In both non-Generic cases the omp.teams must be a direct child of this op.
llvm::omp::OMPTgtExecModeFlags TargetOp::getKernelExecFlags() {
  using namespace llvm::omp;

  // Make sure this region is capturing a loop. Otherwise, it's a generic
  // kernel.
  Operation *capturedOp = getInnermostCapturedOmpOp();
  if (!isa_and_present<LoopNestOp>(capturedOp))
    return OMP_TGT_EXEC_MODE_GENERIC;

  SmallVector<LoopWrapperInterface> wrappers;
  cast<LoopNestOp>(capturedOp).gatherWrappers(wrappers);
  assert(!wrappers.empty());

  // The code below treats wrappers.begin() as the innermost wrapper and
  // advances outward.
  // Ignore optional SIMD leaf construct.
  auto *innermostWrapper = wrappers.begin();
  if (isa<SimdOp>(innermostWrapper))
    innermostWrapper = std::next(innermostWrapper);

  // Number of wrappers left once an innermost omp.simd has been skipped.
  long numWrappers = std::distance(innermostWrapper, wrappers.end());

  // Detect Generic-SPMD: target-teams-distribute[-simd].
  if (numWrappers == 1) {
    if (!isa<DistributeOp>(innermostWrapper))
      return OMP_TGT_EXEC_MODE_GENERIC;

    Operation *teamsOp = (*innermostWrapper)->getParentOp();
    if (!isa_and_present<TeamsOp>(teamsOp))
      return OMP_TGT_EXEC_MODE_GENERIC;

    if (teamsOp->getParentOp() == *this)
      return OMP_TGT_EXEC_MODE_GENERIC_SPMD;
  }

  // Detect SPMD: target-teams-distribute-parallel-wsloop[-simd].
  if (numWrappers == 2) {
    if (!isa<WsloopOp>(innermostWrapper))
      return OMP_TGT_EXEC_MODE_GENERIC;

    innermostWrapper = std::next(innermostWrapper);
    if (!isa<DistributeOp>(innermostWrapper))
      return OMP_TGT_EXEC_MODE_GENERIC;

    Operation *parallelOp = (*innermostWrapper)->getParentOp();
    if (!isa_and_present<ParallelOp>(parallelOp))
      return OMP_TGT_EXEC_MODE_GENERIC;

    Operation *teamsOp = parallelOp->getParentOp();
    if (!isa_and_present<TeamsOp>(teamsOp))
      return OMP_TGT_EXEC_MODE_GENERIC;

    if (teamsOp->getParentOp() == *this)
      return OMP_TGT_EXEC_MODE_SPMD;
  }

  return OMP_TGT_EXEC_MODE_GENERIC;
}

//===----------------------------------------------------------------------===//
// ParallelOp
//===----------------------------------------------------------------------===//

/// Builds a ParallelOp with every clause operand empty or null, then attaches
/// \p attributes to the operation state.
void ParallelOp::build(OpBuilder &builder, OperationState &state,
                       ArrayRef<NamedAttribute> attributes) {
  ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(),
                    /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
                    /*num_threads=*/nullptr, /*private_vars=*/ValueRange(),
                    /*private_syms=*/nullptr, /*proc_bind_kind=*/nullptr,
                    /*reduction_mod=*/nullptr, /*reduction_vars=*/ValueRange(),
                    /*reduction_byref=*/nullptr, /*reduction_syms=*/nullptr);
  state.addAttributes(attributes);
}

/// Builds a ParallelOp from \p clauses. Empty symbol and byref lists become
/// null attributes (see makeArrayAttr / makeDenseBoolArrayAttr).
void ParallelOp::build(OpBuilder &builder, OperationState &state,
                       const ParallelOperands &clauses) {
  MLIRContext *ctx = builder.getContext();
  ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
                    clauses.ifExpr, clauses.numThreads, clauses.privateVars,
                    makeArrayAttr(ctx, clauses.privateSyms),
                    clauses.procBindKind, clauses.reductionMod,
                    clauses.reductionVars,
                    makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
                    makeArrayAttr(ctx, clauses.reductionSyms));
}

/// Verifies the private clause of \p op:
///  - succeeds trivially if there are no private vars and no private symbols;
///  - the number of private vars must equal the number of privatizer symbols;
///  - each symbol must resolve (nearest-symbol lookup from \p op) to a
///    PrivateClauseOp;
///  - each private var's type must equal its privatizer's type.
template <typename OpType>
static LogicalResult verifyPrivateVarList(OpType &op) {
  auto privateVars = op.getPrivateVars();
  auto privateSyms = op.getPrivateSymsAttr();

  if (privateVars.empty() && (privateSyms == nullptr || privateSyms.empty()))
    return success();

  auto numPrivateVars = privateVars.size();
  auto numPrivateSyms = (privateSyms == nullptr) ? 0 : privateSyms.size();

  // This check also guarantees privateSyms is non-null before zipping below.
  if (numPrivateVars != numPrivateSyms)
    return op.emitError() << "inconsistent number of private variables and "
                             "privatizer op symbols, private vars: "
                          << numPrivateVars
                          << " vs. privatizer op symbols: " << numPrivateSyms;

  for (auto privateVarInfo : llvm::zip_equal(privateVars, privateSyms)) {
    Type varType = std::get<0>(privateVarInfo).getType();
    SymbolRefAttr privateSym = cast<SymbolRefAttr>(std::get<1>(privateVarInfo));
    PrivateClauseOp privatizerOp =
        SymbolTable::lookupNearestSymbolFrom<PrivateClauseOp>(op, privateSym);

    if (privatizerOp == nullptr)
      return op.emitError() << "failed to lookup privatizer op with symbol: '"
                            << privateSym << "'";

    Type privatizerType = privatizerOp.getType();

    if (varType != privatizerType)
      return op.emitError()
             << "type mismatch between a "
             << (privatizerOp.getDataSharingType() ==
                         DataSharingClauseType::Private
                     ? "private"
                     : "firstprivate")
             << " variable and its privatizer op, var type: " << varType
             << " vs. privatizer op type: " << privatizerType;
  }

  return success();
}

/// Verifies, in order, the allocate/allocator pairing, the private clause
/// and the reduction clause of an omp.parallel.
LogicalResult ParallelOp::verify() {
  if (getAllocateVars().size() != getAllocatorVars().size())
    return emitError(
        "expected equal sizes for allocate and allocator variables");

  if (failed(verifyPrivateVarList(*this)))
    return failure();

  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
                                getReductionByref());
}

/// Composite checks for omp.parallel:
///  - if any omp.distribute is directly nested, the op must carry
///    'omp.composite', and every other OpenMP op in the region must be a
///    terminator;
///  - otherwise the op must not carry 'omp.composite'.
LogicalResult ParallelOp::verifyRegions() {
  auto distributeChildOps = getOps<DistributeOp>();
  if (!distributeChildOps.empty()) {
    if (!isComposite())
      return emitError()
             << "'omp.composite' attribute missing from composite operation";

    auto *ompDialect = getContext()->getLoadedDialect<OpenMPDialect>();
    // Only the first omp.distribute is exempted; a second one is reported as
    // an unexpected OpenMP operation below.
    Operation &distributeOp = **distributeChildOps.begin();
    for (Operation &childOp : getOps()) {
      if (&childOp == &distributeOp || ompDialect != childOp.getDialect())
        continue;

      if (!childOp.hasTrait<OpTrait::IsTerminator>())
        return emitError() << "unexpected OpenMP operation inside of composite "
                              "'omp.parallel'";
    }
  } else if (isComposite()) {
    return emitError()
           << "'omp.composite' attribute present in non-composite operation";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// TeamsOp
//===----------------------------------------------------------------------===//

/// Returns true if none of \p op's ancestors belongs to the OpenMP dialect.
static bool opInGlobalImplicitParallelRegion(Operation *op) {
  while ((op = op->getParentOp()))
    if (isa<OpenMPDialect>(op->getDialect()))
      return false;
  return true;
}

/// Builds a TeamsOp from \p clauses; private clause operands are not stored
/// yet (see TODO) and are passed as empty/null.
void TeamsOp::build(OpBuilder &builder, OperationState &state,
                    const TeamsOperands &clauses) {
  MLIRContext *ctx = builder.getContext();
  // TODO Store clauses in op: privateVars, privateSyms.
  TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
                 clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,
                 /*private_vars=*/{}, /*private_syms=*/nullptr,
                 clauses.reductionMod, clauses.reductionVars,
                 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
                 makeArrayAttr(ctx, clauses.reductionSyms),
                 clauses.threadLimit);
}

/// Verifies omp.teams placement and clauses:
///  - the direct parent is omp.target, or no ancestor is an OpenMP op;
///  - a num_teams lower bound requires an upper bound of the same type;
///  - allocate and allocator lists have equal sizes;
///  - the reduction clause is well formed.
LogicalResult TeamsOp::verify() {
  // Check parent region
  // TODO If nested inside of a target region, also check that it does not
  // contain any statements, declarations or directives other than this
  // omp.teams construct. The issue is how to support the initialization of
  // this operation's own arguments (allow SSA values across omp.target?).
  Operation *op = getOperation();
  if (!isa<TargetOp>(op->getParentOp()) &&
      !opInGlobalImplicitParallelRegion(op))
    return emitError("expected to be nested inside of omp.target or not nested "
                     "in any OpenMP dialect operations");

  // Check for num_teams clause restrictions
  if (auto numTeamsLowerBound = getNumTeamsLower()) {
    auto numTeamsUpperBound = getNumTeamsUpper();
    if (!numTeamsUpperBound)
      return emitError("expected num_teams upper bound to be defined if the "
                       "lower bound is defined");
    if (numTeamsLowerBound.getType() != numTeamsUpperBound.getType())
      return emitError(
          "expected num_teams upper bound and lower bound to be the same type");
  }

  // Check for allocate clause restrictions
  if (getAllocateVars().size() != getAllocatorVars().size())
    return emitError(
        "expected equal sizes for allocate and allocator variables");

  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
                                getReductionByref());
}

//===----------------------------------------------------------------------===//
// SectionOp
//===----------------------------------------------------------------------===//

/// Returns the number of private block arguments of the parent op.
unsigned SectionOp::numPrivateBlockArgs() {
  return getParentOp().numPrivateBlockArgs();
}

/// Returns the number of reduction block arguments of the parent op.
unsigned SectionOp::numReductionBlockArgs() {
  return getParentOp().numReductionBlockArgs();
}

//===----------------------------------------------------------------------===//
// SectionsOp
//===----------------------------------------------------------------------===//

/// Builds a SectionsOp from \p clauses; private clause operands are not stored
/// yet (see TODO) and are passed as empty/null.
void SectionsOp::build(OpBuilder &builder, OperationState &state,
                       const SectionsOperands &clauses) {
  MLIRContext *ctx = builder.getContext();
  // TODO Store clauses in op: privateVars, privateSyms.
  SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
                    clauses.nowait, /*private_vars=*/{},
                    /*private_syms=*/nullptr, clauses.reductionMod,
                    clauses.reductionVars,
                    makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
                    makeArrayAttr(ctx, clauses.reductionSyms));
}

/// Verifies the allocate/allocator pairing and the reduction clause.
LogicalResult SectionsOp::verify() {
  if (getAllocateVars().size() != getAllocatorVars().size())
    return emitError(
        "expected equal sizes for allocate and allocator variables");

  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
                                getReductionByref());
}

/// Requires every op in the first block of the region to be an omp.section
/// or an omp.terminator.
LogicalResult SectionsOp::verifyRegions() {
  for (auto &inst : *getRegion().begin()) {
    if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
      return emitOpError()
             << "expected omp.section op or terminator op inside region";
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// SingleOp
//===----------------------------------------------------------------------===//

/// Builds a SingleOp from \p clauses; private clause operands are not stored
/// yet (see TODO) and are passed as empty/null.
void SingleOp::build(OpBuilder &builder, OperationState &state,
                     const SingleOperands &clauses) {
  MLIRContext *ctx = builder.getContext();
  // TODO Store clauses in op: privateVars, privateSyms.
  SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
                  clauses.copyprivateVars,
                  makeArrayAttr(ctx, clauses.copyprivateSyms), clauses.nowait,
                  /*private_vars=*/{}, /*private_syms=*/nullptr);
}

/// Verifies the allocate/allocator pairing and the copyprivate clause.
LogicalResult SingleOp::verify() {
  // Check for allocate clause restrictions
  if (getAllocateVars().size() != getAllocatorVars().size())
    return emitError(
        "expected equal sizes for allocate and allocator variables");

  return verifyCopyprivateVarList(*this, getCopyprivateVars(),
                                  getCopyprivateSyms());
}

//===----------------------------------------------------------------------===//
// WorkshareOp
//===----------------------------------------------------------------------===//

/// Builds a WorkshareOp from \p clauses (only the nowait clause).
void WorkshareOp::build(OpBuilder &builder, OperationState &state,
                        const WorkshareOperands &clauses) {
  WorkshareOp::build(builder, state, clauses.nowait);
}

//===----------------------------------------------------------------------===//
// WorkshareLoopWrapperOp
//===----------------------------------------------------------------------===//

/// Requires an enclosing omp.workshare ancestor and forbids a nested loop
/// wrapper (the op cannot be part of a composite construct).
LogicalResult WorkshareLoopWrapperOp::verify() {
  if (!(*this)->getParentOfType<WorkshareOp>())
    return emitError() << "must be nested in an omp.workshare";
  if (getNestedWrapper())
    return emitError() << "cannot be composite";
  return success();
}

//===----------------------------------------------------------------------===//
// LoopWrapperInterface
//===----------------------------------------------------------------------===//

/// Structural checks shared by all loop wrappers: the op must have the
/// NoTerminator and SingleBlock traits and exactly one region, and that
/// region must hold exactly one op, which is either another loop wrapper or
/// an omp.loop_nest.
LogicalResult LoopWrapperInterface::verifyImpl() {
  Operation *op = this->getOperation();
  if (!op->hasTrait<OpTrait::NoTerminator>() ||
      !op->hasTrait<OpTrait::SingleBlock>())
    return emitOpError() << "loop wrapper must also have the `NoTerminator` "
                            "and `SingleBlock` traits";

  if (op->getNumRegions() != 1)
    return emitOpError() << "loop wrapper does not contain exactly one region";

  Region &region = op->getRegion(0);
  if (range_size(region.getOps()) != 1)
    return emitOpError()
           << "loop wrapper does not contain exactly one nested op";

  Operation &firstOp = *region.op_begin();
  if (!isa<LoopNestOp, LoopWrapperInterface>(firstOp))
    return emitOpError() << "op nested in loop wrapper is not another loop "
                            "wrapper or `omp.loop_nest`";

  return success();
}

//===----------------------------------------------------------------------===//
// LoopOp
//===----------------------------------------------------------------------===//

/// Builds a LoopOp from \p clauses. Empty symbol and byref lists become null
/// attributes (see makeArrayAttr / makeDenseBoolArrayAttr).
void LoopOp::build(OpBuilder &builder, OperationState &state,
                   const LoopOperands &clauses) {
  MLIRContext *ctx = builder.getContext();

  LoopOp::build(builder, state, clauses.bindKind, clauses.privateVars,
                makeArrayAttr(ctx, clauses.privateSyms), clauses.order,
                clauses.orderMod, clauses.reductionMod, clauses.reductionVars,
                makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
                makeArrayAttr(ctx, clauses.reductionSyms));
}

/// Verifies the reduction clause.
LogicalResult LoopOp::verify() {
  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
                                getReductionByref());
}

/// omp.loop must be standalone: its parent may not be a loop wrapper and it
/// may not wrap another loop wrapper.
LogicalResult LoopOp::verifyRegions() {
  if (llvm::isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
      getNestedWrapper())
    return emitError() << "`omp.loop` expected to be a standalone loop wrapper";

  return success();
}

//===----------------------------------------------------------------------===//
// WsloopOp
//===----------------------------------------------------------------------===//

/// Builds a WsloopOp with every clause operand empty, null or false, then
/// attaches \p attributes to the operation state.
void WsloopOp::build(OpBuilder &builder, OperationState &state,
                     ArrayRef<NamedAttribute> attributes) {
  build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
        /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
        /*nowait=*/false, /*order=*/nullptr, /*order_mod=*/nullptr,
        /*ordered=*/nullptr, /*private_vars=*/{}, /*private_syms=*/nullptr,
        /*reduction_mod=*/nullptr, /*reduction_vars=*/ValueRange(),
        /*reduction_byref=*/nullptr,
        /*reduction_syms=*/nullptr, /*schedule_kind=*/nullptr,
        /*schedule_chunk=*/nullptr, /*schedule_mod=*/nullptr,
        /*schedule_simd=*/false);
  state.addAttributes(attributes);
}

/// Builds a WsloopOp from \p clauses; allocate clause operands are not stored
/// yet (see TODO) and are passed as empty.
void WsloopOp::build(OpBuilder &builder, OperationState &state,
                     const WsloopOperands &clauses) {
  MLIRContext *ctx = builder.getContext();
  // TODO: Store clauses in op: allocateVars, allocatorVars.
  WsloopOp::build(builder, state,
                  /*allocate_vars=*/{}, /*allocator_vars=*/{},
                  clauses.linearVars, clauses.linearStepVars, clauses.nowait,
                  clauses.order, clauses.orderMod, clauses.ordered,
                  clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
                  clauses.reductionMod, clauses.reductionVars,
                  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
                  makeArrayAttr(ctx, clauses.reductionSyms),
                  clauses.scheduleKind, clauses.scheduleChunk,
                  clauses.scheduleMod, clauses.scheduleSimd);
}

/// Verifies the reduction clause.
LogicalResult WsloopOp::verify() {
  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
                                getReductionByref());
}

/// Composite checks for omp.wsloop. The op must carry 'omp.composite' exactly
/// when it wraps another loop wrapper (which must be omp.simd) or its parent
/// is a loop wrapper.
LogicalResult WsloopOp::verifyRegions() {
  // True when the direct parent is itself a loop wrapper.
  bool isCompositeChildLeaf =
      llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());

  if (LoopWrapperInterface nested = getNestedWrapper()) {
    if (!isComposite())
      return emitError()
             << "'omp.composite' attribute missing from composite wrapper";

    // Check for the allowed leaf constructs that may appear in a composite
    // construct directly after DO/FOR.
    if (!isa<SimdOp>(nested))
      return emitError() << "only supported nested wrapper is 'omp.simd'";

  } else if (isComposite() && !isCompositeChildLeaf) {
    return emitError()
           << "'omp.composite' attribute present in non-composite wrapper";
  } else if (!isComposite() && isCompositeChildLeaf) {
    return emitError()
           << "'omp.composite' attribute missing from composite wrapper";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// Simd construct [2.9.3.1]
//===----------------------------------------------------------------------===//

/// Builds a SimdOp from \p clauses; linear clause operands are not stored yet
/// (see TODO) and are passed as empty.
void SimdOp::build(OpBuilder &builder, OperationState &state,
                   const SimdOperands &clauses) {
  MLIRContext *ctx = builder.getContext();
  // TODO Store clauses in op: linearVars, linearStepVars.
  SimdOp::build(builder, state, clauses.alignedVars,
                makeArrayAttr(ctx, clauses.alignments), clauses.ifExpr,
                /*linear_vars=*/{}, /*linear_step_vars=*/{},
                clauses.nontemporalVars, clauses.order, clauses.orderMod,
                clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
                clauses.reductionMod, clauses.reductionVars,
                makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
                makeArrayAttr(ctx, clauses.reductionSyms), clauses.safelen,
                clauses.simdlen);
}

/// Verifies omp.simd:
///  - simdlen must not exceed safelen when both are present;
///  - the aligned and nontemporal clauses are well formed;
///  - 'omp.composite' is set exactly when the parent is a loop wrapper.
LogicalResult SimdOp::verify() {
  if (getSimdlen().has_value() && getSafelen().has_value() &&
      getSimdlen().value() > getSafelen().value())
    return emitOpError()
           << "simdlen clause and safelen clause are both present, but the "
              "simdlen value is not less than or equal to safelen value";

  if (verifyAlignedClause(*this, getAlignments(), getAlignedVars()).failed())
    return failure();

  if (verifyNontemporalClause(*this, getNontemporalVars()).failed())
    return failure();

  // True when the direct parent is itself a loop wrapper.
  bool isCompositeChildLeaf =
      llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());

  if (!isComposite() && isCompositeChildLeaf)
    return emitError()
           << "'omp.composite' attribute missing from composite wrapper";

  if (isComposite() && !isCompositeChildLeaf)
    return emitError()
           << "'omp.composite' attribute present in non-composite wrapper";

  return success();
}

/// omp.simd must wrap an omp.loop_nest directly, not another loop wrapper.
LogicalResult SimdOp::verifyRegions() {
  if (getNestedWrapper())
    return emitOpError() << "must wrap an 'omp.loop_nest' directly";

  return success();
}

//===----------------------------------------------------------------------===//
// Distribute construct [2.9.4.1]
//===----------------------------------------------------------------------===//

/// Builds a DistributeOp from \p clauses. Empty private symbol lists become a
/// null attribute (see makeArrayAttr).
void DistributeOp::build(OpBuilder &builder, OperationState &state,
                         const DistributeOperands &clauses) {
  DistributeOp::build(builder, state, clauses.allocateVars,
                      clauses.allocatorVars, clauses.distScheduleStatic,
                      clauses.distScheduleChunkSize, clauses.order,
                      clauses.orderMod, clauses.privateVars,
                      makeArrayAttr(builder.getContext(), clauses.privateSyms));
}

/// A dist_schedule chunk size requires dist_schedule_static, and allocate and
/// allocator lists must have equal sizes.
LogicalResult DistributeOp::verify() {
  if (this->getDistScheduleChunkSize() && !this->getDistScheduleStatic())
    return emitOpError() << "chunk size set without "
                            "dist_schedule_static being present";

  if (getAllocateVars().size() != getAllocatorVars().size())
    return emitError(
        "expected equal sizes for allocate and allocator variables");

  return success();
}

/// Composite checks for omp.distribute:
///  - a nested wrapper requires 'omp.composite' and must be omp.simd or
///    omp.wsloop; omp.wsloop additionally requires omp.parallel as the direct
///    parent of this op;
///  - without a nested wrapper, 'omp.composite' must not be set.
LogicalResult DistributeOp::verifyRegions() {
  if (LoopWrapperInterface nested = getNestedWrapper()) {
    if (!isComposite())
      return emitError()
             << "'omp.composite' attribute missing from composite wrapper";
    // Check for the allowed leaf constructs that may appear in a composite
    // construct directly after DISTRIBUTE.
    if (isa<WsloopOp>(nested)) {
      if (!llvm::dyn_cast_if_present<ParallelOp>((*this)->getParentOp()))
        return emitError() << "an 'omp.wsloop' nested wrapper is only allowed "
                              "when 'omp.parallel' is the direct parent";
    } else if (!isa<SimdOp>(nested))
      return emitError() << "only supported nested wrappers are 'omp.simd' and "
                            "'omp.wsloop'";
  } else if (isComposite()) {
    return emitError()
           << "'omp.composite' attribute present in non-composite wrapper";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// DeclareReductionOp
//===----------------------------------------------------------------------===//

/// Verifies the regions of a reduction declaration against its reduction type
/// (getType()):
///  - alloc (optional): every yield returns one value of the reduction type;
///  - initializer (required): takes 2 arguments if an alloc region exists,
///    1 otherwise, all of the reduction type; yields one such value;
///  - reduction (required): two arguments of the reduction type; yields one
///    value of the reduction type;
///  - atomic reduction (optional): two arguments of the same PointerLikeType
///    whose element type, if known, is the reduction type;
///  - cleanup (optional): one argument of the reduction type.
LogicalResult DeclareReductionOp::verifyRegions() {
  if (!getAllocRegion().empty()) {
    for (YieldOp yieldOp : getAllocRegion().getOps<YieldOp>()) {
      if (yieldOp.getResults().size() != 1 ||
          yieldOp.getResults().getTypes()[0] != getType())
        return emitOpError() << "expects alloc region to yield a value "
                                "of the reduction type";
    }
  }

  if (getInitializerRegion().empty())
    return emitOpError() << "expects non-empty initializer region";
  Block &initializerEntryBlock = getInitializerRegion().front();

  // The initializer arity depends on whether an alloc region is present.
  if (initializerEntryBlock.getNumArguments() == 1) {
    if (!getAllocRegion().empty())
      return emitOpError() << "expects two arguments to the initializer region "
                              "when an allocation region is used";
  } else if (initializerEntryBlock.getNumArguments() == 2) {
    if (getAllocRegion().empty())
      return emitOpError() << "expects one argument to the initializer region "
                              "when no allocation region is used";
  } else {
    return emitOpError()
           << "expects one or two arguments to the initializer region";
  }

  for (mlir::Value arg : initializerEntryBlock.getArguments())
    if (arg.getType() != getType())
      return emitOpError() << "expects initializer region argument to match "
                              "the reduction type";

  for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
    if (yieldOp.getResults().size() != 1 ||
        yieldOp.getResults().getTypes()[0] != getType())
      return emitOpError() << "expects initializer region to yield a value "
                              "of the reduction type";
  }

  if (getReductionRegion().empty())
    return emitOpError() << "expects non-empty reduction region";
  Block &reductionEntryBlock = getReductionRegion().front();
  if (reductionEntryBlock.getNumArguments() != 2 ||
      reductionEntryBlock.getArgumentTypes()[0] !=
          reductionEntryBlock.getArgumentTypes()[1] ||
      reductionEntryBlock.getArgumentTypes()[0] != getType())
    return emitOpError() << "expects reduction region with two arguments of "
                            "the reduction type";
  for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
    if (yieldOp.getResults().size() != 1 ||
        yieldOp.getResults().getTypes()[0] != getType())
      return emitOpError() << "expects reduction region to yield a value "
                              "of the reduction type";
  }

  if (!getAtomicReductionRegion().empty()) {
    Block &atomicReductionEntryBlock = getAtomicReductionRegion().front();
    if (atomicReductionEntryBlock.getNumArguments() != 2 ||
        atomicReductionEntryBlock.getArgumentTypes()[0] !=
            atomicReductionEntryBlock.getArgumentTypes()[1])
      return emitOpError() << "expects atomic reduction region with two "
                              "arguments of the same type";
    // A null element type (e.g. from the LLVM pointer PointerLikeType model
    // attached in initialize(), which returns Type()) skips the element check.
    auto ptrType = llvm::dyn_cast<PointerLikeType>(
        atomicReductionEntryBlock.getArgumentTypes()[0]);
    if (!ptrType ||
        (ptrType.getElementType() && ptrType.getElementType() != getType()))
      return emitOpError() << "expects atomic reduction region arguments to "
                              "be accumulators containing the reduction type";
  }

  if (getCleanupRegion().empty())
    return success();
  Block &cleanupEntryBlock = getCleanupRegion().front();
  if (cleanupEntryBlock.getNumArguments() != 1 ||
      cleanupEntryBlock.getArgument(0).getType() != getType())
    return emitOpError() << "expects cleanup region with one argument "
                            "of the reduction type";

  return success();
}

//===----------------------------------------------------------------------===//
// TaskOp
//===----------------------------------------------------------------------===//

/// Builds a TaskOp from \p clauses. Empty symbol/kind and byref lists become
/// null attributes (see makeArrayAttr / makeDenseBoolArrayAttr).
void TaskOp::build(OpBuilder &builder, OperationState &state,
                   const TaskOperands &clauses) {
  MLIRContext *ctx = builder.getContext();
  TaskOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
                makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
                clauses.final, clauses.ifExpr, clauses.inReductionVars,
                makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
                makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
                clauses.priority, /*private_vars=*/clauses.privateVars,
                /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
                clauses.untied, clauses.eventHandle);
}

/// Verifies the depend clause and, only if that succeeds, the in_reduction
/// clause.
LogicalResult TaskOp::verify() {
  LogicalResult verifyDependVars =
      verifyDependVarList(*this, getDependKinds(), getDependVars());
  return failed(verifyDependVars)
             ? verifyDependVars
             : verifyReductionVarList(*this, getInReductionSyms(),
                                      getInReductionVars(),
                                      getInReductionByref());
}

//===----------------------------------------------------------------------===//
// TaskgroupOp
//===----------------------------------------------------------------------===//

/// Builds a TaskgroupOp from \p clauses. Empty symbol and byref lists become
/// null attributes (see makeArrayAttr / makeDenseBoolArrayAttr).
void TaskgroupOp::build(OpBuilder &builder, OperationState &state,
                        const TaskgroupOperands &clauses) {
  MLIRContext *ctx = builder.getContext();
  TaskgroupOp::build(builder, state, clauses.allocateVars,
                     clauses.allocatorVars, clauses.taskReductionVars,
                     makeDenseBoolArrayAttr(ctx, clauses.taskReductionByref),
                     makeArrayAttr(ctx, clauses.taskReductionSyms));
}

/// Verifies the task_reduction clause.
LogicalResult TaskgroupOp::verify() {
  return verifyReductionVarList(*this, getTaskReductionSyms(),
                                getTaskReductionVars(),
                                getTaskReductionByref());
}

//===----------------------------------------------------------------------===//
// TaskloopOp
//===----------------------------------------------------------------------===//

/// Builds a TaskloopOp from \p clauses; private clause operands are not stored
/// yet (see TODO) and are passed as empty/null.
void TaskloopOp::build(OpBuilder &builder, OperationState &state,
                       const TaskloopOperands &clauses) {
  MLIRContext *ctx = builder.getContext();
  // TODO Store clauses in op: privateVars, privateSyms.
  TaskloopOp::build(
      builder, state, clauses.allocateVars, clauses.allocatorVars,
      clauses.final, clauses.grainsize, clauses.ifExpr, clauses.inReductionVars,
      makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
      makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
      clauses.nogroup, clauses.numTasks, clauses.priority, /*private_vars=*/{},
      /*private_syms=*/nullptr, clauses.reductionMod, clauses.reductionVars,
      makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
      makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied);
}

/// Returns the in_reduction variables followed by the reduction variables.
SmallVector<Value> TaskloopOp::getAllReductionVars() {
  SmallVector<Value> allReductionNvars(getInReductionVars().begin(),
                                       getInReductionVars().end());
  allReductionNvars.insert(allReductionNvars.end(), getReductionVars().begin(),
                           getReductionVars().end());
  return allReductionNvars;
}

LogicalResult TaskloopOp::verify() {
  if (getAllocateVars().size() != getAllocatorVars().size())
    return emitError(
        "expected equal sizes for allocate and allocator variables");
  if (failed(verifyReductionVarList(*this, getReductionSyms(),
                                    getReductionVars(), getReductionByref())) ||
      failed(verifyReductionVarList(*this, getInReductionSyms(),
                                    getInReductionVars(),
                                    getInReductionByref())))
    return failure();

  if (!getReductionVars().empty() && getNogroup())
    return emitError("if a reduction clause is present on the taskloop "
                     "directive, the nogroup clause must not be specified");
  for (auto var : getReductionVars()) {
    if (llvm::is_contained(getInReductionVars(), var))
      return emitError("the same list item cannot appear in both a reduction "
                       "and an in_reduction clause");
  }

  if (getGrainsize() && getNumTasks()) {
    return emitError(
        "the grainsize clause and num_tasks clause are mutually exclusive and "
        "may not appear on the same taskloop directive");
  }
2619 return success(); 2620 } 2621 2622 LogicalResult TaskloopOp::verifyRegions() { 2623 if (LoopWrapperInterface nested = getNestedWrapper()) { 2624 if (!isComposite()) 2625 return emitError() 2626 << "'omp.composite' attribute missing from composite wrapper"; 2627 2628 // Check for the allowed leaf constructs that may appear in a composite 2629 // construct directly after TASKLOOP. 2630 if (!isa<SimdOp>(nested)) 2631 return emitError() << "only supported nested wrapper is 'omp.simd'"; 2632 } else if (isComposite()) { 2633 return emitError() 2634 << "'omp.composite' attribute present in non-composite wrapper"; 2635 } 2636 2637 return success(); 2638 } 2639 2640 //===----------------------------------------------------------------------===// 2641 // LoopNestOp 2642 //===----------------------------------------------------------------------===// 2643 2644 ParseResult LoopNestOp::parse(OpAsmParser &parser, OperationState &result) { 2645 // Parse an opening `(` followed by induction variables followed by `)` 2646 SmallVector<OpAsmParser::Argument> ivs; 2647 SmallVector<OpAsmParser::UnresolvedOperand> lbs, ubs; 2648 Type loopVarType; 2649 if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren) || 2650 parser.parseColonType(loopVarType) || 2651 // Parse loop bounds. 2652 parser.parseEqual() || 2653 parser.parseOperandList(lbs, ivs.size(), OpAsmParser::Delimiter::Paren) || 2654 parser.parseKeyword("to") || 2655 parser.parseOperandList(ubs, ivs.size(), OpAsmParser::Delimiter::Paren)) 2656 return failure(); 2657 2658 for (auto &iv : ivs) 2659 iv.type = loopVarType; 2660 2661 // Parse "inclusive" flag. 2662 if (succeeded(parser.parseOptionalKeyword("inclusive"))) 2663 result.addAttribute("loop_inclusive", 2664 UnitAttr::get(parser.getBuilder().getContext())); 2665 2666 // Parse step values. 
2667 SmallVector<OpAsmParser::UnresolvedOperand> steps; 2668 if (parser.parseKeyword("step") || 2669 parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren)) 2670 return failure(); 2671 2672 // Parse the body. 2673 Region *region = result.addRegion(); 2674 if (parser.parseRegion(*region, ivs)) 2675 return failure(); 2676 2677 // Resolve operands. 2678 if (parser.resolveOperands(lbs, loopVarType, result.operands) || 2679 parser.resolveOperands(ubs, loopVarType, result.operands) || 2680 parser.resolveOperands(steps, loopVarType, result.operands)) 2681 return failure(); 2682 2683 // Parse the optional attribute list. 2684 return parser.parseOptionalAttrDict(result.attributes); 2685 } 2686 2687 void LoopNestOp::print(OpAsmPrinter &p) { 2688 Region ®ion = getRegion(); 2689 auto args = region.getArguments(); 2690 p << " (" << args << ") : " << args[0].getType() << " = (" 2691 << getLoopLowerBounds() << ") to (" << getLoopUpperBounds() << ") "; 2692 if (getLoopInclusive()) 2693 p << "inclusive "; 2694 p << "step (" << getLoopSteps() << ") "; 2695 p.printRegion(region, /*printEntryBlockArgs=*/false); 2696 } 2697 2698 void LoopNestOp::build(OpBuilder &builder, OperationState &state, 2699 const LoopNestOperands &clauses) { 2700 LoopNestOp::build(builder, state, clauses.loopLowerBounds, 2701 clauses.loopUpperBounds, clauses.loopSteps, 2702 clauses.loopInclusive); 2703 } 2704 2705 LogicalResult LoopNestOp::verify() { 2706 if (getLoopLowerBounds().empty()) 2707 return emitOpError() << "must represent at least one loop"; 2708 2709 if (getLoopLowerBounds().size() != getIVs().size()) 2710 return emitOpError() << "number of range arguments and IVs do not match"; 2711 2712 for (auto [lb, iv] : llvm::zip_equal(getLoopLowerBounds(), getIVs())) { 2713 if (lb.getType() != iv.getType()) 2714 return emitOpError() 2715 << "range argument type does not match corresponding IV type"; 2716 } 2717 2718 if 
(!llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp())) 2719 return emitOpError() << "expects parent op to be a loop wrapper"; 2720 2721 return success(); 2722 } 2723 2724 void LoopNestOp::gatherWrappers( 2725 SmallVectorImpl<LoopWrapperInterface> &wrappers) { 2726 Operation *parent = (*this)->getParentOp(); 2727 while (auto wrapper = 2728 llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) { 2729 wrappers.push_back(wrapper); 2730 parent = parent->getParentOp(); 2731 } 2732 } 2733 2734 //===----------------------------------------------------------------------===// 2735 // Critical construct (2.17.1) 2736 //===----------------------------------------------------------------------===// 2737 2738 void CriticalDeclareOp::build(OpBuilder &builder, OperationState &state, 2739 const CriticalDeclareOperands &clauses) { 2740 CriticalDeclareOp::build(builder, state, clauses.symName, clauses.hint); 2741 } 2742 2743 LogicalResult CriticalDeclareOp::verify() { 2744 return verifySynchronizationHint(*this, getHint()); 2745 } 2746 2747 LogicalResult CriticalOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 2748 if (getNameAttr()) { 2749 SymbolRefAttr symbolRef = getNameAttr(); 2750 auto decl = symbolTable.lookupNearestSymbolFrom<CriticalDeclareOp>( 2751 *this, symbolRef); 2752 if (!decl) { 2753 return emitOpError() << "expected symbol reference " << symbolRef 2754 << " to point to a critical declaration"; 2755 } 2756 } 2757 2758 return success(); 2759 } 2760 2761 //===----------------------------------------------------------------------===// 2762 // Ordered construct 2763 //===----------------------------------------------------------------------===// 2764 2765 static LogicalResult verifyOrderedParent(Operation &op) { 2766 bool hasRegion = op.getNumRegions() > 0; 2767 auto loopOp = op.getParentOfType<LoopNestOp>(); 2768 if (!loopOp) { 2769 if (hasRegion) 2770 return success(); 2771 2772 // TODO: Consider if this needs to be the case only for the 
standalone 2773 // variant of the ordered construct. 2774 return op.emitOpError() << "must be nested inside of a loop"; 2775 } 2776 2777 Operation *wrapper = loopOp->getParentOp(); 2778 if (auto wsloopOp = dyn_cast<WsloopOp>(wrapper)) { 2779 IntegerAttr orderedAttr = wsloopOp.getOrderedAttr(); 2780 if (!orderedAttr) 2781 return op.emitOpError() << "the enclosing worksharing-loop region must " 2782 "have an ordered clause"; 2783 2784 if (hasRegion && orderedAttr.getInt() != 0) 2785 return op.emitOpError() << "the enclosing loop's ordered clause must not " 2786 "have a parameter present"; 2787 2788 if (!hasRegion && orderedAttr.getInt() == 0) 2789 return op.emitOpError() << "the enclosing loop's ordered clause must " 2790 "have a parameter present"; 2791 } else if (!isa<SimdOp>(wrapper)) { 2792 return op.emitOpError() << "must be nested inside of a worksharing, simd " 2793 "or worksharing simd loop"; 2794 } 2795 return success(); 2796 } 2797 2798 void OrderedOp::build(OpBuilder &builder, OperationState &state, 2799 const OrderedOperands &clauses) { 2800 OrderedOp::build(builder, state, clauses.doacrossDependType, 2801 clauses.doacrossNumLoops, clauses.doacrossDependVars); 2802 } 2803 2804 LogicalResult OrderedOp::verify() { 2805 if (failed(verifyOrderedParent(**this))) 2806 return failure(); 2807 2808 auto wrapper = (*this)->getParentOfType<WsloopOp>(); 2809 if (!wrapper || *wrapper.getOrdered() != *getDoacrossNumLoops()) 2810 return emitOpError() << "number of variables in depend clause does not " 2811 << "match number of iteration variables in the " 2812 << "doacross loop"; 2813 2814 return success(); 2815 } 2816 2817 void OrderedRegionOp::build(OpBuilder &builder, OperationState &state, 2818 const OrderedRegionOperands &clauses) { 2819 OrderedRegionOp::build(builder, state, clauses.parLevelSimd); 2820 } 2821 2822 LogicalResult OrderedRegionOp::verify() { return verifyOrderedParent(**this); } 2823 2824 
//===----------------------------------------------------------------------===// 2825 // TaskwaitOp 2826 //===----------------------------------------------------------------------===// 2827 2828 void TaskwaitOp::build(OpBuilder &builder, OperationState &state, 2829 const TaskwaitOperands &clauses) { 2830 // TODO Store clauses in op: dependKinds, dependVars, nowait. 2831 TaskwaitOp::build(builder, state, /*depend_kinds=*/nullptr, 2832 /*depend_vars=*/{}, /*nowait=*/nullptr); 2833 } 2834 2835 //===----------------------------------------------------------------------===// 2836 // Verifier for AtomicReadOp 2837 //===----------------------------------------------------------------------===// 2838 2839 LogicalResult AtomicReadOp::verify() { 2840 if (verifyCommon().failed()) 2841 return mlir::failure(); 2842 2843 if (auto mo = getMemoryOrder()) { 2844 if (*mo == ClauseMemoryOrderKind::Acq_rel || 2845 *mo == ClauseMemoryOrderKind::Release) { 2846 return emitError( 2847 "memory-order must not be acq_rel or release for atomic reads"); 2848 } 2849 } 2850 return verifySynchronizationHint(*this, getHint()); 2851 } 2852 2853 //===----------------------------------------------------------------------===// 2854 // Verifier for AtomicWriteOp 2855 //===----------------------------------------------------------------------===// 2856 2857 LogicalResult AtomicWriteOp::verify() { 2858 if (verifyCommon().failed()) 2859 return mlir::failure(); 2860 2861 if (auto mo = getMemoryOrder()) { 2862 if (*mo == ClauseMemoryOrderKind::Acq_rel || 2863 *mo == ClauseMemoryOrderKind::Acquire) { 2864 return emitError( 2865 "memory-order must not be acq_rel or acquire for atomic writes"); 2866 } 2867 } 2868 return verifySynchronizationHint(*this, getHint()); 2869 } 2870 2871 //===----------------------------------------------------------------------===// 2872 // Verifier for AtomicUpdateOp 2873 //===----------------------------------------------------------------------===// 2874 2875 LogicalResult 
AtomicUpdateOp::canonicalize(AtomicUpdateOp op, 2876 PatternRewriter &rewriter) { 2877 if (op.isNoOp()) { 2878 rewriter.eraseOp(op); 2879 return success(); 2880 } 2881 if (Value writeVal = op.getWriteOpVal()) { 2882 rewriter.replaceOpWithNewOp<AtomicWriteOp>( 2883 op, op.getX(), writeVal, op.getHintAttr(), op.getMemoryOrderAttr()); 2884 return success(); 2885 } 2886 return failure(); 2887 } 2888 2889 LogicalResult AtomicUpdateOp::verify() { 2890 if (verifyCommon().failed()) 2891 return mlir::failure(); 2892 2893 if (auto mo = getMemoryOrder()) { 2894 if (*mo == ClauseMemoryOrderKind::Acq_rel || 2895 *mo == ClauseMemoryOrderKind::Acquire) { 2896 return emitError( 2897 "memory-order must not be acq_rel or acquire for atomic updates"); 2898 } 2899 } 2900 2901 return verifySynchronizationHint(*this, getHint()); 2902 } 2903 2904 LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); } 2905 2906 //===----------------------------------------------------------------------===// 2907 // Verifier for AtomicCaptureOp 2908 //===----------------------------------------------------------------------===// 2909 2910 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() { 2911 if (auto op = dyn_cast<AtomicReadOp>(getFirstOp())) 2912 return op; 2913 return dyn_cast<AtomicReadOp>(getSecondOp()); 2914 } 2915 2916 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() { 2917 if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp())) 2918 return op; 2919 return dyn_cast<AtomicWriteOp>(getSecondOp()); 2920 } 2921 2922 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() { 2923 if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp())) 2924 return op; 2925 return dyn_cast<AtomicUpdateOp>(getSecondOp()); 2926 } 2927 2928 LogicalResult AtomicCaptureOp::verify() { 2929 return verifySynchronizationHint(*this, getHint()); 2930 } 2931 2932 LogicalResult AtomicCaptureOp::verifyRegions() { 2933 if (verifyRegionsCommon().failed()) 2934 return mlir::failure(); 2935 2936 if 
(getFirstOp()->getAttr("hint") || getSecondOp()->getAttr("hint")) 2937 return emitOpError( 2938 "operations inside capture region must not have hint clause"); 2939 2940 if (getFirstOp()->getAttr("memory_order") || 2941 getSecondOp()->getAttr("memory_order")) 2942 return emitOpError( 2943 "operations inside capture region must not have memory_order clause"); 2944 return success(); 2945 } 2946 2947 //===----------------------------------------------------------------------===// 2948 // CancelOp 2949 //===----------------------------------------------------------------------===// 2950 2951 void CancelOp::build(OpBuilder &builder, OperationState &state, 2952 const CancelOperands &clauses) { 2953 CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifExpr); 2954 } 2955 2956 LogicalResult CancelOp::verify() { 2957 ClauseCancellationConstructType cct = getCancelDirective(); 2958 Operation *parentOp = (*this)->getParentOp(); 2959 2960 if (!parentOp) { 2961 return emitOpError() << "must be used within a region supporting " 2962 "cancel directive"; 2963 } 2964 2965 if ((cct == ClauseCancellationConstructType::Parallel) && 2966 !isa<ParallelOp>(parentOp)) { 2967 return emitOpError() << "cancel parallel must appear " 2968 << "inside a parallel region"; 2969 } 2970 if (cct == ClauseCancellationConstructType::Loop) { 2971 auto loopOp = dyn_cast<LoopNestOp>(parentOp); 2972 auto wsloopOp = llvm::dyn_cast_if_present<WsloopOp>( 2973 loopOp ? 
loopOp->getParentOp() : nullptr); 2974 2975 if (!wsloopOp) { 2976 return emitOpError() 2977 << "cancel loop must appear inside a worksharing-loop region"; 2978 } 2979 if (wsloopOp.getNowaitAttr()) { 2980 return emitError() << "A worksharing construct that is canceled " 2981 << "must not have a nowait clause"; 2982 } 2983 if (wsloopOp.getOrderedAttr()) { 2984 return emitError() << "A worksharing construct that is canceled " 2985 << "must not have an ordered clause"; 2986 } 2987 2988 } else if (cct == ClauseCancellationConstructType::Sections) { 2989 if (!(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) { 2990 return emitOpError() << "cancel sections must appear " 2991 << "inside a sections region"; 2992 } 2993 if (isa_and_nonnull<SectionsOp>(parentOp->getParentOp()) && 2994 cast<SectionsOp>(parentOp->getParentOp()).getNowaitAttr()) { 2995 return emitError() << "A sections construct that is canceled " 2996 << "must not have a nowait clause"; 2997 } 2998 } 2999 // TODO : Add more when we support taskgroup. 
3000 return success(); 3001 } 3002 3003 //===----------------------------------------------------------------------===// 3004 // CancellationPointOp 3005 //===----------------------------------------------------------------------===// 3006 3007 void CancellationPointOp::build(OpBuilder &builder, OperationState &state, 3008 const CancellationPointOperands &clauses) { 3009 CancellationPointOp::build(builder, state, clauses.cancelDirective); 3010 } 3011 3012 LogicalResult CancellationPointOp::verify() { 3013 ClauseCancellationConstructType cct = getCancelDirective(); 3014 Operation *parentOp = (*this)->getParentOp(); 3015 3016 if (!parentOp) { 3017 return emitOpError() << "must be used within a region supporting " 3018 "cancellation point directive"; 3019 } 3020 3021 if ((cct == ClauseCancellationConstructType::Parallel) && 3022 !(isa<ParallelOp>(parentOp))) { 3023 return emitOpError() << "cancellation point parallel must appear " 3024 << "inside a parallel region"; 3025 } 3026 if ((cct == ClauseCancellationConstructType::Loop) && 3027 (!isa<LoopNestOp>(parentOp) || !isa<WsloopOp>(parentOp->getParentOp()))) { 3028 return emitOpError() << "cancellation point loop must appear " 3029 << "inside a worksharing-loop region"; 3030 } 3031 if ((cct == ClauseCancellationConstructType::Sections) && 3032 !(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) { 3033 return emitOpError() << "cancellation point sections must appear " 3034 << "inside a sections region"; 3035 } 3036 // TODO : Add more when we support taskgroup. 
3037 return success(); 3038 } 3039 3040 //===----------------------------------------------------------------------===// 3041 // MapBoundsOp 3042 //===----------------------------------------------------------------------===// 3043 3044 LogicalResult MapBoundsOp::verify() { 3045 auto extent = getExtent(); 3046 auto upperbound = getUpperBound(); 3047 if (!extent && !upperbound) 3048 return emitError("expected extent or upperbound."); 3049 return success(); 3050 } 3051 3052 void PrivateClauseOp::build(OpBuilder &odsBuilder, OperationState &odsState, 3053 TypeRange /*result_types*/, StringAttr symName, 3054 TypeAttr type) { 3055 PrivateClauseOp::build( 3056 odsBuilder, odsState, symName, type, 3057 DataSharingClauseTypeAttr::get(odsBuilder.getContext(), 3058 DataSharingClauseType::Private)); 3059 } 3060 3061 LogicalResult PrivateClauseOp::verifyRegions() { 3062 Type symType = getType(); 3063 3064 auto verifyTerminator = [&](Operation *terminator, 3065 bool yieldsValue) -> LogicalResult { 3066 if (!terminator->getBlock()->getSuccessors().empty()) 3067 return success(); 3068 3069 if (!llvm::isa<YieldOp>(terminator)) 3070 return mlir::emitError(terminator->getLoc()) 3071 << "expected exit block terminator to be an `omp.yield` op."; 3072 3073 YieldOp yieldOp = llvm::cast<YieldOp>(terminator); 3074 TypeRange yieldedTypes = yieldOp.getResults().getTypes(); 3075 3076 if (!yieldsValue) { 3077 if (yieldedTypes.empty()) 3078 return success(); 3079 3080 return mlir::emitError(terminator->getLoc()) 3081 << "Did not expect any values to be yielded."; 3082 } 3083 3084 if (yieldedTypes.size() == 1 && yieldedTypes.front() == symType) 3085 return success(); 3086 3087 auto error = mlir::emitError(yieldOp.getLoc()) 3088 << "Invalid yielded value. 
Expected type: " << symType 3089 << ", got: "; 3090 3091 if (yieldedTypes.empty()) 3092 error << "None"; 3093 else 3094 error << yieldedTypes; 3095 3096 return error; 3097 }; 3098 3099 auto verifyRegion = [&](Region ®ion, unsigned expectedNumArgs, 3100 StringRef regionName, 3101 bool yieldsValue) -> LogicalResult { 3102 assert(!region.empty()); 3103 3104 if (region.getNumArguments() != expectedNumArgs) 3105 return mlir::emitError(region.getLoc()) 3106 << "`" << regionName << "`: " 3107 << "expected " << expectedNumArgs 3108 << " region arguments, got: " << region.getNumArguments(); 3109 3110 for (Block &block : region) { 3111 // MLIR will verify the absence of the terminator for us. 3112 if (!block.mightHaveTerminator()) 3113 continue; 3114 3115 if (failed(verifyTerminator(block.getTerminator(), yieldsValue))) 3116 return failure(); 3117 } 3118 3119 return success(); 3120 }; 3121 3122 if (failed(verifyRegion(getAllocRegion(), /*expectedNumArgs=*/1, "alloc", 3123 /*yieldsValue=*/true))) 3124 return failure(); 3125 3126 DataSharingClauseType dsType = getDataSharingType(); 3127 3128 if (dsType == DataSharingClauseType::Private && !getCopyRegion().empty()) 3129 return emitError("`private` clauses require only an `alloc` region."); 3130 3131 if (dsType == DataSharingClauseType::FirstPrivate && getCopyRegion().empty()) 3132 return emitError( 3133 "`firstprivate` clauses require both `alloc` and `copy` regions."); 3134 3135 if (dsType == DataSharingClauseType::FirstPrivate && 3136 failed(verifyRegion(getCopyRegion(), /*expectedNumArgs=*/2, "copy", 3137 /*yieldsValue=*/true))) 3138 return failure(); 3139 3140 if (!getDeallocRegion().empty() && 3141 failed(verifyRegion(getDeallocRegion(), /*expectedNumArgs=*/1, "dealloc", 3142 /*yieldsValue=*/false))) 3143 return failure(); 3144 3145 return success(); 3146 } 3147 3148 //===----------------------------------------------------------------------===// 3149 // Spec 5.2: Masked construct (10.5) 3150 
//===----------------------------------------------------------------------===// 3151 3152 void MaskedOp::build(OpBuilder &builder, OperationState &state, 3153 const MaskedOperands &clauses) { 3154 MaskedOp::build(builder, state, clauses.filteredThreadId); 3155 } 3156 3157 //===----------------------------------------------------------------------===// 3158 // Spec 5.2: Scan construct (5.6) 3159 //===----------------------------------------------------------------------===// 3160 3161 void ScanOp::build(OpBuilder &builder, OperationState &state, 3162 const ScanOperands &clauses) { 3163 ScanOp::build(builder, state, clauses.inclusiveVars, clauses.exclusiveVars); 3164 } 3165 3166 LogicalResult ScanOp::verify() { 3167 if (hasExclusiveVars() == hasInclusiveVars()) 3168 return emitError( 3169 "Exactly one of EXCLUSIVE or INCLUSIVE clause is expected"); 3170 if (WsloopOp parentWsLoopOp = (*this)->getParentOfType<WsloopOp>()) { 3171 if (parentWsLoopOp.getReductionModAttr() && 3172 parentWsLoopOp.getReductionModAttr().getValue() == 3173 ReductionModifier::inscan) 3174 return success(); 3175 } 3176 if (SimdOp parentSimdOp = (*this)->getParentOfType<SimdOp>()) { 3177 if (parentSimdOp.getReductionModAttr() && 3178 parentSimdOp.getReductionModAttr().getValue() == 3179 ReductionModifier::inscan) 3180 return success(); 3181 } 3182 return emitError("SCAN directive needs to be enclosed within a parent " 3183 "worksharing loop construct or SIMD construct with INSCAN " 3184 "reduction modifier"); 3185 } 3186 3187 #define GET_ATTRDEF_CLASSES 3188 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" 3189 3190 #define GET_OP_CLASSES 3191 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" 3192 3193 #define GET_TYPEDEF_CLASSES 3194 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc" 3195