1 //===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===// 2 // 3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 // ============================================================================= 8 9 #include "mlir/Dialect/OpenACC/OpenACC.h" 10 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 11 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 12 #include "mlir/Dialect/MemRef/IR/MemRef.h" 13 #include "mlir/IR/Builders.h" 14 #include "mlir/IR/BuiltinAttributes.h" 15 #include "mlir/IR/BuiltinTypes.h" 16 #include "mlir/IR/DialectImplementation.h" 17 #include "mlir/IR/Matchers.h" 18 #include "mlir/IR/OpImplementation.h" 19 #include "mlir/Support/LLVM.h" 20 #include "mlir/Transforms/DialectConversion.h" 21 #include "llvm/ADT/SmallSet.h" 22 #include "llvm/ADT/TypeSwitch.h" 23 #include "llvm/Support/LogicalResult.h" 24 25 using namespace mlir; 26 using namespace acc; 27 28 #include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc" 29 #include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc" 30 #include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc" 31 #include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc" 32 #include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc" 33 34 namespace { 35 struct MemRefPointerLikeModel 36 : public PointerLikeType::ExternalModel<MemRefPointerLikeModel, 37 MemRefType> { 38 Type getElementType(Type pointer) const { 39 return llvm::cast<MemRefType>(pointer).getElementType(); 40 } 41 }; 42 43 struct LLVMPointerPointerLikeModel 44 : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel, 45 LLVM::LLVMPointerType> { 46 Type getElementType(Type pointer) const { return Type(); } 47 }; 48 } // namespace 49 50 //===----------------------------------------------------------------------===// 51 // OpenACC operations 52 //===----------------------------------------------------------------------===// 53 54 void OpenACCDialect::initialize() { 55 addOperations< 56 #define GET_OP_LIST 57 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc" 58 >(); 59 addAttributes< 60 #define GET_ATTRDEF_LIST 61 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc" 62 >(); 63 addTypes< 64 #define GET_TYPEDEF_LIST 65 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc" 66 >(); 67 68 // By attaching interfaces here, we make the OpenACC dialect dependent on 69 // the other dialects. This is probably better than having dialects like LLVM 70 // and memref be dependent on OpenACC. 71 MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext()); 72 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>( 73 *getContext()); 74 } 75 76 //===----------------------------------------------------------------------===// 77 // device_type support helpers 78 //===----------------------------------------------------------------------===// 79 80 static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) { 81 if (arrayAttr && *arrayAttr && arrayAttr->size() > 0) 82 return true; 83 return false; 84 } 85 86 static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr, 87 mlir::acc::DeviceType deviceType) { 88 if (!hasDeviceTypeValues(arrayAttr)) 89 return false; 90 91 for (auto attr : *arrayAttr) { 92 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr); 93 if (deviceTypeAttr.getValue() == deviceType) 94 return true; 95 } 96 97 return false; 98 } 99 100 static void printDeviceTypes(mlir::OpAsmPrinter &p, 101 std::optional<mlir::ArrayAttr> deviceTypes) { 102 if (!hasDeviceTypeValues(deviceTypes)) 103 return; 104 105 p << "["; 106 llvm::interleaveComma(*deviceTypes, p, 107 [&](mlir::Attribute attr) { p << attr; }); 108 p << "]"; 109 } 110 111 static std::optional<unsigned> findSegment(ArrayAttr segments, 112 mlir::acc::DeviceType deviceType) { 113 unsigned segmentIdx = 0; 114 for (auto attr : segments) { 115 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr); 116 if (deviceTypeAttr.getValue() == deviceType) 117 return std::make_optional(segmentIdx); 118 ++segmentIdx; 119 } 120 return std::nullopt; 121 } 122 123 static mlir::Operation::operand_range 124 getValuesFromSegments(std::optional<mlir::ArrayAttr> arrayAttr, 125 mlir::Operation::operand_range range, 126 std::optional<llvm::ArrayRef<int32_t>> segments, 127 mlir::acc::DeviceType deviceType) { 128 if (!arrayAttr) 129 return range.take_front(0); 130 if (auto pos = findSegment(*arrayAttr, deviceType)) { 131 int32_t nbOperandsBefore = 0; 132 for (unsigned i = 0; i < *pos; ++i) 133 nbOperandsBefore += (*segments)[i]; 134 return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]); 135 } 136 return range.take_front(0); 137 } 138 139 static mlir::Value 140 getWaitDevnumValue(std::optional<mlir::ArrayAttr> deviceTypeAttr, 141 mlir::Operation::operand_range operands, 142 std::optional<llvm::ArrayRef<int32_t>> segments, 143 std::optional<mlir::ArrayAttr> hasWaitDevnum, 144 mlir::acc::DeviceType deviceType) { 145 if (!hasDeviceTypeValues(deviceTypeAttr)) 146 return {}; 147 if (auto pos = findSegment(*deviceTypeAttr, deviceType)) 148 if (hasWaitDevnum->getValue()[*pos]) 149 return getValuesFromSegments(deviceTypeAttr, operands, segments, 150 deviceType) 151 .front(); 152 return {}; 153 } 154 155 static mlir::Operation::operand_range 156 getWaitValuesWithoutDevnum(std::optional<mlir::ArrayAttr> deviceTypeAttr, 157 mlir::Operation::operand_range operands, 158 std::optional<llvm::ArrayRef<int32_t>> segments, 159 std::optional<mlir::ArrayAttr> hasWaitDevnum, 160 mlir::acc::DeviceType deviceType) { 161 auto range = 162 getValuesFromSegments(deviceTypeAttr, operands, segments, deviceType); 163 if (range.empty()) 164 return range; 165 if (auto pos = findSegment(*deviceTypeAttr, deviceType)) { 166 if (hasWaitDevnum && *hasWaitDevnum) { 167 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]); 168 if (boolAttr.getValue()) 169 return range.drop_front(1); // first value is devnum 170 } 171 } 172 return range; 173 } 174 175 template <typename Op> 176 static LogicalResult checkWaitAndAsyncConflict(Op op) { 177 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType(); 178 ++dtypeInt) { 179 auto dtype = static_cast<acc::DeviceType>(dtypeInt); 180 181 // The async attribute represent the async clause without value. Therefore 182 // the attribute and operand cannot appear at the same time. 183 if (hasDeviceType(op.getAsyncOperandsDeviceType(), dtype) && 184 op.hasAsyncOnly(dtype)) 185 return op.emitError("async attribute cannot appear with asyncOperand"); 186 187 // The wait attribute represent the wait clause without values. Therefore 188 // the attribute and operands cannot appear at the same time. 189 if (hasDeviceType(op.getWaitOperandsDeviceType(), dtype) && 190 op.hasWaitOnly(dtype)) 191 return op.emitError("wait attribute cannot appear with waitOperands"); 192 } 193 return success(); 194 } 195 196 template <typename Op> 197 static LogicalResult checkVarAndVarType(Op op) { 198 if (!op.getVar()) 199 return op.emitError("must have var operand"); 200 201 if (mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) && 202 mlir::isa<mlir::acc::MappableType>(op.getVar().getType())) { 203 // TODO: If a type implements both interfaces (mappable and pointer-like), 204 // it is unclear which semantics to apply without additional info which 205 // would need captured in the data operation. For now restrict this case 206 // unless a compelling reason to support disambiguating between the two. 207 return op.emitError("var must be mappable or pointer-like (not both)"); 208 } 209 210 if (!mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) && 211 !mlir::isa<mlir::acc::MappableType>(op.getVar().getType())) 212 return op.emitError("var must be mappable or pointer-like"); 213 214 if (mlir::isa<mlir::acc::MappableType>(op.getVar().getType()) && 215 op.getVarType() != op.getVar().getType()) 216 return op.emitError("varType must match when var is mappable"); 217 218 return success(); 219 } 220 221 template <typename Op> 222 static LogicalResult checkVarAndAccVar(Op op) { 223 if (op.getVar().getType() != op.getAccVar().getType()) 224 return op.emitError("input and output types must match"); 225 226 return success(); 227 } 228 229 static ParseResult parseVar(mlir::OpAsmParser &parser, 230 OpAsmParser::UnresolvedOperand &var) { 231 // Either `var` or `varPtr` keyword is required. 232 if (failed(parser.parseOptionalKeyword("varPtr"))) { 233 if (failed(parser.parseKeyword("var"))) 234 return failure(); 235 } 236 if (failed(parser.parseLParen())) 237 return failure(); 238 if (failed(parser.parseOperand(var))) 239 return failure(); 240 241 return success(); 242 } 243 244 static void printVar(mlir::OpAsmPrinter &p, mlir::Operation *op, 245 mlir::Value var) { 246 if (mlir::isa<mlir::acc::PointerLikeType>(var.getType())) 247 p << "varPtr("; 248 else 249 p << "var("; 250 p.printOperand(var); 251 } 252 253 static ParseResult parseAccVar(mlir::OpAsmParser &parser, 254 OpAsmParser::UnresolvedOperand &var, 255 mlir::Type &accVarType) { 256 // Either `accVar` or `accPtr` keyword is required. 257 if (failed(parser.parseOptionalKeyword("accPtr"))) { 258 if (failed(parser.parseKeyword("accVar"))) 259 return failure(); 260 } 261 if (failed(parser.parseLParen())) 262 return failure(); 263 if (failed(parser.parseOperand(var))) 264 return failure(); 265 if (failed(parser.parseColon())) 266 return failure(); 267 if (failed(parser.parseType(accVarType))) 268 return failure(); 269 if (failed(parser.parseRParen())) 270 return failure(); 271 272 return success(); 273 } 274 275 static void printAccVar(mlir::OpAsmPrinter &p, mlir::Operation *op, 276 mlir::Value accVar, mlir::Type accVarType) { 277 if (mlir::isa<mlir::acc::PointerLikeType>(accVar.getType())) 278 p << "accPtr("; 279 else 280 p << "accVar("; 281 p.printOperand(accVar); 282 p << " : "; 283 p.printType(accVarType); 284 p << ")"; 285 } 286 287 static ParseResult parseVarPtrType(mlir::OpAsmParser &parser, 288 mlir::Type &varPtrType, 289 mlir::TypeAttr &varTypeAttr) { 290 if (failed(parser.parseType(varPtrType))) 291 return failure(); 292 if (failed(parser.parseRParen())) 293 return failure(); 294 295 if (succeeded(parser.parseOptionalKeyword("varType"))) { 296 if (failed(parser.parseLParen())) 297 return failure(); 298 mlir::Type varType; 299 if (failed(parser.parseType(varType))) 300 return failure(); 301 varTypeAttr = mlir::TypeAttr::get(varType); 302 if (failed(parser.parseRParen())) 303 return failure(); 304 } else { 305 // Set `varType` from the element type of the type of `varPtr`. 306 if (mlir::isa<mlir::acc::PointerLikeType>(varPtrType)) 307 varTypeAttr = mlir::TypeAttr::get( 308 mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType()); 309 else 310 varTypeAttr = mlir::TypeAttr::get(varPtrType); 311 } 312 313 return success(); 314 } 315 316 static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op, 317 mlir::Type varPtrType, mlir::TypeAttr varTypeAttr) { 318 p.printType(varPtrType); 319 p << ")"; 320 321 // Print the `varType` only if it differs from the element type of 322 // `varPtr`'s type. 323 mlir::Type varType = varTypeAttr.getValue(); 324 mlir::Type typeToCheckAgainst = 325 mlir::isa<mlir::acc::PointerLikeType>(varPtrType) 326 ? mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType() 327 : varPtrType; 328 if (typeToCheckAgainst != varType) { 329 p << " varType("; 330 p.printType(varType); 331 p << ")"; 332 } 333 } 334 335 //===----------------------------------------------------------------------===// 336 // DataBoundsOp 337 //===----------------------------------------------------------------------===// 338 LogicalResult acc::DataBoundsOp::verify() { 339 auto extent = getExtent(); 340 auto upperbound = getUpperbound(); 341 if (!extent && !upperbound) 342 return emitError("expected extent or upperbound."); 343 return success(); 344 } 345 346 //===----------------------------------------------------------------------===// 347 // PrivateOp 348 //===----------------------------------------------------------------------===// 349 LogicalResult acc::PrivateOp::verify() { 350 if (getDataClause() != acc::DataClause::acc_private) 351 return emitError( 352 "data clause associated with private operation must match its intent"); 353 if (failed(checkVarAndVarType(*this))) 354 return failure(); 355 return success(); 356 } 357 358 //===----------------------------------------------------------------------===// 359 // FirstprivateOp 360 //===----------------------------------------------------------------------===// 361 LogicalResult acc::FirstprivateOp::verify() { 362 if (getDataClause() != acc::DataClause::acc_firstprivate) 363 return emitError("data clause associated with firstprivate operation must " 364 "match its intent"); 365 if (failed(checkVarAndVarType(*this))) 366 return failure(); 367 return success(); 368 } 369 370 //===----------------------------------------------------------------------===// 371 // ReductionOp 372 //===----------------------------------------------------------------------===// 373 LogicalResult acc::ReductionOp::verify() { 374 if (getDataClause() != acc::DataClause::acc_reduction) 375 return emitError("data clause associated with reduction operation must " 376 "match its intent"); 377 if (failed(checkVarAndVarType(*this))) 378 return failure(); 379 return success(); 380 } 381 382 //===----------------------------------------------------------------------===// 383 // DevicePtrOp 384 //===----------------------------------------------------------------------===// 385 LogicalResult acc::DevicePtrOp::verify() { 386 if (getDataClause() != acc::DataClause::acc_deviceptr) 387 return emitError("data clause associated with deviceptr operation must " 388 "match its intent"); 389 if (failed(checkVarAndVarType(*this))) 390 return failure(); 391 if (failed(checkVarAndAccVar(*this))) 392 return failure(); 393 return success(); 394 } 395 396 //===----------------------------------------------------------------------===// 397 // PresentOp 398 //===----------------------------------------------------------------------===// 399 LogicalResult acc::PresentOp::verify() { 400 if (getDataClause() != acc::DataClause::acc_present) 401 return emitError( 402 "data clause associated with present operation must match its intent"); 403 if (failed(checkVarAndVarType(*this))) 404 return failure(); 405 if (failed(checkVarAndAccVar(*this))) 406 return failure(); 407 return success(); 408 } 409 410 //===----------------------------------------------------------------------===// 411 // CopyinOp 412 //===----------------------------------------------------------------------===// 413 LogicalResult acc::CopyinOp::verify() { 414 // Test for all clauses this operation can be decomposed from: 415 if (!getImplicit() && getDataClause() != acc::DataClause::acc_copyin && 416 getDataClause() != acc::DataClause::acc_copyin_readonly && 417 getDataClause() != acc::DataClause::acc_copy && 418 getDataClause() != acc::DataClause::acc_reduction) 419 return emitError( 420 "data clause associated with copyin operation must match its intent" 421 " or specify original clause this operation was decomposed from"); 422 if (failed(checkVarAndVarType(*this))) 423 return failure(); 424 if (failed(checkVarAndAccVar(*this))) 425 return failure(); 426 return success(); 427 } 428 429 bool acc::CopyinOp::isCopyinReadonly() { 430 return getDataClause() == acc::DataClause::acc_copyin_readonly; 431 } 432 433 //===----------------------------------------------------------------------===// 434 // CreateOp 435 //===----------------------------------------------------------------------===// 436 LogicalResult acc::CreateOp::verify() { 437 // Test for all clauses this operation can be decomposed from: 438 if (getDataClause() != acc::DataClause::acc_create && 439 getDataClause() != acc::DataClause::acc_create_zero && 440 getDataClause() != acc::DataClause::acc_copyout && 441 getDataClause() != acc::DataClause::acc_copyout_zero) 442 return emitError( 443 "data clause associated with create operation must match its intent" 444 " or specify original clause this operation was decomposed from"); 445 if (failed(checkVarAndVarType(*this))) 446 return failure(); 447 if (failed(checkVarAndAccVar(*this))) 448 return failure(); 449 return success(); 450 } 451 452 bool acc::CreateOp::isCreateZero() { 453 // The zero modifier is encoded in the data clause. 454 return getDataClause() == acc::DataClause::acc_create_zero || 455 getDataClause() == acc::DataClause::acc_copyout_zero; 456 } 457 458 //===----------------------------------------------------------------------===// 459 // NoCreateOp 460 //===----------------------------------------------------------------------===// 461 LogicalResult acc::NoCreateOp::verify() { 462 if (getDataClause() != acc::DataClause::acc_no_create) 463 return emitError("data clause associated with no_create operation must " 464 "match its intent"); 465 if (failed(checkVarAndVarType(*this))) 466 return failure(); 467 if (failed(checkVarAndAccVar(*this))) 468 return failure(); 469 return success(); 470 } 471 472 //===----------------------------------------------------------------------===// 473 // AttachOp 474 //===----------------------------------------------------------------------===// 475 LogicalResult acc::AttachOp::verify() { 476 if (getDataClause() != acc::DataClause::acc_attach) 477 return emitError( 478 "data clause associated with attach operation must match its intent"); 479 if (failed(checkVarAndVarType(*this))) 480 return failure(); 481 if (failed(checkVarAndAccVar(*this))) 482 return failure(); 483 return success(); 484 } 485 486 //===----------------------------------------------------------------------===// 487 // DeclareDeviceResidentOp 488 //===----------------------------------------------------------------------===// 489 490 LogicalResult acc::DeclareDeviceResidentOp::verify() { 491 if (getDataClause() != acc::DataClause::acc_declare_device_resident) 492 return emitError("data clause associated with device_resident operation " 493 "must match its intent"); 494 if (failed(checkVarAndVarType(*this))) 495 return failure(); 496 if (failed(checkVarAndAccVar(*this))) 497 return failure(); 498 return success(); 499 } 500 501 //===----------------------------------------------------------------------===// 502 // DeclareLinkOp 503 //===----------------------------------------------------------------------===// 504 505 LogicalResult acc::DeclareLinkOp::verify() { 506 if (getDataClause() != acc::DataClause::acc_declare_link) 507 return emitError( 508 "data clause associated with link operation must match its intent"); 509 if (failed(checkVarAndVarType(*this))) 510 return failure(); 511 if (failed(checkVarAndAccVar(*this))) 512 return failure(); 513 return success(); 514 } 515 516 //===----------------------------------------------------------------------===// 517 // CopyoutOp 518 //===----------------------------------------------------------------------===// 519 LogicalResult acc::CopyoutOp::verify() { 520 // Test for all clauses this operation can be decomposed from: 521 if (getDataClause() != acc::DataClause::acc_copyout && 522 getDataClause() != acc::DataClause::acc_copyout_zero && 523 getDataClause() != acc::DataClause::acc_copy && 524 getDataClause() != acc::DataClause::acc_reduction) 525 return emitError( 526 "data clause associated with copyout operation must match its intent" 527 " or specify original clause this operation was decomposed from"); 528 if (!getVar() || !getAccVar()) 529 return emitError("must have both host and device pointers"); 530 if (failed(checkVarAndVarType(*this))) 531 return failure(); 532 if (failed(checkVarAndAccVar(*this))) 533 return failure(); 534 return success(); 535 } 536 537 bool acc::CopyoutOp::isCopyoutZero() { 538 return getDataClause() == acc::DataClause::acc_copyout_zero; 539 } 540 541 //===----------------------------------------------------------------------===// 542 // DeleteOp 543 //===----------------------------------------------------------------------===// 544 LogicalResult acc::DeleteOp::verify() { 545 // Test for all clauses this operation can be decomposed from: 546 if (getDataClause() != acc::DataClause::acc_delete && 547 getDataClause() != acc::DataClause::acc_create && 548 getDataClause() != acc::DataClause::acc_create_zero && 549 getDataClause() != acc::DataClause::acc_copyin && 550 getDataClause() != acc::DataClause::acc_copyin_readonly && 551 getDataClause() != acc::DataClause::acc_present && 552 getDataClause() != acc::DataClause::acc_declare_device_resident && 553 getDataClause() != acc::DataClause::acc_declare_link) 554 return emitError( 555 "data clause associated with delete operation must match its intent" 556 " or specify original clause this operation was decomposed from"); 557 if (!getAccVar()) 558 return emitError("must have device pointer"); 559 return success(); 560 } 561 562 //===----------------------------------------------------------------------===// 563 // DetachOp 564 //===----------------------------------------------------------------------===// 565 LogicalResult acc::DetachOp::verify() { 566 // Test for all clauses this operation can be decomposed from: 567 if (getDataClause() != acc::DataClause::acc_detach && 568 getDataClause() != acc::DataClause::acc_attach) 569 return emitError( 570 "data clause associated with detach operation must match its intent" 571 " or specify original clause this operation was decomposed from"); 572 if (!getAccVar()) 573 return emitError("must have device pointer"); 574 return success(); 575 } 576 577 //===----------------------------------------------------------------------===// 578 // HostOp 579 //===----------------------------------------------------------------------===// 580 LogicalResult acc::UpdateHostOp::verify() { 581 // Test for all clauses this operation can be decomposed from: 582 if (getDataClause() != acc::DataClause::acc_update_host && 583 getDataClause() != acc::DataClause::acc_update_self) 584 return emitError( 585 "data clause associated with host operation must match its intent" 586 " or specify original clause this operation was decomposed from"); 587 if (!getVar() || !getAccVar()) 588 return emitError("must have both host and device pointers"); 589 if (failed(checkVarAndVarType(*this))) 590 return failure(); 591 if (failed(checkVarAndAccVar(*this))) 592 return failure(); 593 return success(); 594 } 595 596 //===----------------------------------------------------------------------===// 597 // DeviceOp 598 //===----------------------------------------------------------------------===// 599 LogicalResult acc::UpdateDeviceOp::verify() { 600 // Test for all clauses this operation can be decomposed from: 601 if (getDataClause() != acc::DataClause::acc_update_device) 602 return emitError( 603 "data clause associated with device operation must match its intent" 604 " or specify original clause this operation was decomposed from"); 605 if (failed(checkVarAndVarType(*this))) 606 return failure(); 607 if (failed(checkVarAndAccVar(*this))) 608 return failure(); 609 return success(); 610 } 611 612 //===----------------------------------------------------------------------===// 613 // UseDeviceOp 614 //===----------------------------------------------------------------------===// 615 LogicalResult acc::UseDeviceOp::verify() { 616 // Test for all clauses this operation can be decomposed from: 617 if (getDataClause() != acc::DataClause::acc_use_device) 618 return emitError( 619 "data clause associated with use_device operation must match its intent" 620 " or specify original clause this operation was decomposed from"); 621 if (failed(checkVarAndVarType(*this))) 622 return failure(); 623 if (failed(checkVarAndAccVar(*this))) 624 return failure(); 625 return success(); 626 } 627 628 //===----------------------------------------------------------------------===// 629 // CacheOp 630 //===----------------------------------------------------------------------===// 631 LogicalResult acc::CacheOp::verify() { 632 // Test for all clauses this operation can be decomposed from: 633 if (getDataClause() != acc::DataClause::acc_cache && 634 getDataClause() != acc::DataClause::acc_cache_readonly) 635 return emitError( 636 "data clause associated with cache operation must match its intent" 637 " or specify original clause this operation was decomposed from"); 638 if (failed(checkVarAndVarType(*this))) 639 return failure(); 640 if (failed(checkVarAndAccVar(*this))) 641 return failure(); 642 return success(); 643 } 644 645 template <typename StructureOp> 646 static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, 647 unsigned nRegions = 1) { 648 649 SmallVector<Region *, 2> regions; 650 for (unsigned i = 0; i < nRegions; ++i) 651 regions.push_back(state.addRegion()); 652 653 for (Region *region : regions) 654 if (parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{})) 655 return failure(); 656 657 return success(); 658 } 659 660 static bool isComputeOperation(Operation *op) { 661 return isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(op); 662 } 663 664 namespace { 665 /// Pattern to remove operation without region that have constant false `ifCond` 666 /// and remove the condition from the operation if the `ifCond` is a true 667 /// constant. 668 template <typename OpTy> 669 struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> { 670 using OpRewritePattern<OpTy>::OpRewritePattern; 671 672 LogicalResult matchAndRewrite(OpTy op, 673 PatternRewriter &rewriter) const override { 674 // Early return if there is no condition. 675 Value ifCond = op.getIfCond(); 676 if (!ifCond) 677 return failure(); 678 679 IntegerAttr constAttr; 680 if (!matchPattern(ifCond, m_Constant(&constAttr))) 681 return failure(); 682 if (constAttr.getInt()) 683 rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); }); 684 else 685 rewriter.eraseOp(op); 686 687 return success(); 688 } 689 }; 690 691 /// Replaces the given op with the contents of the given single-block region, 692 /// using the operands of the block terminator to replace operation results. 693 static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, 694 Region ®ion, ValueRange blockArgs = {}) { 695 assert(llvm::hasSingleElement(region) && "expected single-region block"); 696 Block *block = ®ion.front(); 697 Operation *terminator = block->getTerminator(); 698 ValueRange results = terminator->getOperands(); 699 rewriter.inlineBlockBefore(block, op, blockArgs); 700 rewriter.replaceOp(op, results); 701 rewriter.eraseOp(terminator); 702 } 703 704 /// Pattern to remove operation with region that have constant false `ifCond` 705 /// and remove the condition from the operation if the `ifCond` is constant 706 /// true. 707 template <typename OpTy> 708 struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> { 709 using OpRewritePattern<OpTy>::OpRewritePattern; 710 711 LogicalResult matchAndRewrite(OpTy op, 712 PatternRewriter &rewriter) const override { 713 // Early return if there is no condition. 714 Value ifCond = op.getIfCond(); 715 if (!ifCond) 716 return failure(); 717 718 IntegerAttr constAttr; 719 if (!matchPattern(ifCond, m_Constant(&constAttr))) 720 return failure(); 721 if (constAttr.getInt()) 722 rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); }); 723 else 724 replaceOpWithRegion(rewriter, op, op.getRegion()); 725 726 return success(); 727 } 728 }; 729 730 } // namespace 731 732 //===----------------------------------------------------------------------===// 733 // PrivateRecipeOp 734 //===----------------------------------------------------------------------===// 735 736 static LogicalResult verifyInitLikeSingleArgRegion( 737 Operation *op, Region ®ion, StringRef regionType, StringRef regionName, 738 Type type, bool verifyYield, bool optional = false) { 739 if (optional && region.empty()) 740 return success(); 741 742 if (region.empty()) 743 return op->emitOpError() << "expects non-empty " << regionName << " region"; 744 Block &firstBlock = region.front(); 745 if (firstBlock.getNumArguments() < 1 || 746 firstBlock.getArgument(0).getType() != type) 747 return op->emitOpError() << "expects " << regionName 748 << " region first " 749 "argument of the " 750 << regionType << " type"; 751 752 if (verifyYield) { 753 for (YieldOp yieldOp : region.getOps<acc::YieldOp>()) { 754 if (yieldOp.getOperands().size() != 1 || 755 yieldOp.getOperands().getTypes()[0] != type) 756 return op->emitOpError() << "expects " << regionName 757 << " region to " 758 "yield a value of the " 759 << regionType << " type"; 760 } 761 } 762 return success(); 763 } 764 765 LogicalResult acc::PrivateRecipeOp::verifyRegions() { 766 if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(), 767 "privatization", "init", getType(), 768 /*verifyYield=*/false))) 769 return failure(); 770 if (failed(verifyInitLikeSingleArgRegion( 771 *this, getDestroyRegion(), "privatization", "destroy", getType(), 772 /*verifyYield=*/false, /*optional=*/true))) 773 return failure(); 774 return success(); 775 } 776 777 //===----------------------------------------------------------------------===// 778 // FirstprivateRecipeOp 779 //===----------------------------------------------------------------------===// 780 781 LogicalResult acc::FirstprivateRecipeOp::verifyRegions() { 782 if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(), 783 "privatization", "init", getType(), 784 /*verifyYield=*/false))) 785 return failure(); 786 787 if (getCopyRegion().empty()) 788 return emitOpError() << "expects non-empty copy region"; 789 790 Block &firstBlock = getCopyRegion().front(); 791 if (firstBlock.getNumArguments() < 2 || 792 firstBlock.getArgument(0).getType() != getType()) 793 return emitOpError() << "expects copy region with two arguments of the " 794 "privatization type"; 795 796 if (getDestroyRegion().empty()) 797 return success(); 798 799 if (failed(verifyInitLikeSingleArgRegion(*this, getDestroyRegion(), 800 "privatization", "destroy", 801 getType(), /*verifyYield=*/false))) 802 return failure(); 803 804 return success(); 805 } 806 807 //===----------------------------------------------------------------------===// 808 // ReductionRecipeOp 809 //===----------------------------------------------------------------------===// 810 811 LogicalResult acc::ReductionRecipeOp::verifyRegions() { 812 if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(), "reduction", 813 "init", getType(), 814 /*verifyYield=*/false))) 815 return failure(); 816 817 if (getCombinerRegion().empty()) 818 return emitOpError() << "expects non-empty combiner region"; 819 820 Block &reductionBlock = getCombinerRegion().front(); 821 if (reductionBlock.getNumArguments() < 2 || 822 reductionBlock.getArgument(0).getType() != getType() || 823 reductionBlock.getArgument(1).getType() != getType()) 824 return emitOpError() << "expects combiner region with the first two " 825 << "arguments of the reduction type"; 826 827 for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) { 828 if (yieldOp.getOperands().size() != 1 || 829 yieldOp.getOperands().getTypes()[0] != getType()) 830 return emitOpError() << "expects combiner region to yield a value " 831 "of the reduction type"; 832 } 833 834 return success(); 835 } 836 837 //===----------------------------------------------------------------------===// 838 // Custom parser and printer verifier for private clause 839 //===----------------------------------------------------------------------===// 840 841 static ParseResult parseSymOperandList( 842 mlir::OpAsmParser &parser, 843 llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands, 844 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &symbols) { 845 llvm::SmallVector<SymbolRefAttr> attributes; 846 if (failed(parser.parseCommaSeparatedList([&]() { 847 if (parser.parseAttribute(attributes.emplace_back()) || 848 parser.parseArrow() || 849 parser.parseOperand(operands.emplace_back()) || 850 parser.parseColonType(types.emplace_back())) 851 return failure(); 852 return success(); 853 }))) 854 return failure(); 855 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(), 856 attributes.end()); 857 symbols = ArrayAttr::get(parser.getContext(), arrayAttr); 858 return success(); 859 } 860 861 static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op, 862 mlir::OperandRange operands, 863 mlir::TypeRange types, 864 std::optional<mlir::ArrayAttr> attributes) { 865 llvm::interleaveComma(llvm::zip(*attributes, operands), p, [&](auto it) { 866 p << std::get<0>(it) << " -> " << std::get<1>(it) << " : " 867 << std::get<1>(it).getType(); 868 }); 869 } 870 871 //===----------------------------------------------------------------------===// 872 // ParallelOp 873 //===----------------------------------------------------------------------===// 874 875 /// Check dataOperands for acc.parallel, acc.serial and acc.kernels. 876 template <typename Op> 877 static LogicalResult checkDataOperands(Op op, 878 const mlir::ValueRange &operands) { 879 for (mlir::Value operand : operands) 880 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp, 881 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp, 882 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>( 883 operand.getDefiningOp())) 884 return op.emitError( 885 "expect data entry/exit operation or acc.getdeviceptr " 886 "as defining op"); 887 return success(); 888 } 889 890 template <typename Op> 891 static LogicalResult 892 checkSymOperandList(Operation *op, std::optional<mlir::ArrayAttr> attributes, 893 mlir::OperandRange operands, llvm::StringRef operandName, 894 llvm::StringRef symbolName, bool checkOperandType = true) { 895 if (!operands.empty()) { 896 if (!attributes || attributes->size() != operands.size()) 897 return op->emitOpError() 898 << "expected as many " << symbolName << " symbol reference as " 899 << operandName << " operands"; 900 } else { 901 if (attributes) 902 return op->emitOpError() 903 << "unexpected " << symbolName << " symbol reference"; 904 return success(); 905 } 906 907 llvm::DenseSet<Value> set; 908 for (auto args : llvm::zip(operands, *attributes)) { 909 mlir::Value operand = std::get<0>(args); 910 911 if (!set.insert(operand).second) 912 return op->emitOpError() 913 << operandName << " operand appears more than once"; 914 915 mlir::Type varType = operand.getType(); 916 auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args)); 917 auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef); 918 if (!decl) 919 return op->emitOpError() 920 << "expected symbol reference " << symbolRef << " to point to a " 921 << operandName << " declaration"; 922 923 if (checkOperandType && decl.getType() && decl.getType() != varType) 924 return op->emitOpError() << "expected " << operandName << " (" << varType 925 << ") to be the same type as " << operandName 926 << " declaration (" << decl.getType() << ")"; 927 } 928 929 return success(); 930 } 931 932 unsigned ParallelOp::getNumDataOperands() { 933 return getReductionOperands().size() + getPrivateOperands().size() + 934 getFirstprivateOperands().size() + getDataClauseOperands().size(); 935 } 936 937 Value ParallelOp::getDataOperand(unsigned i) { 938 unsigned numOptional = getAsyncOperands().size(); 939 numOptional += getNumGangs().size(); 940 numOptional += getNumWorkers().size(); 941 numOptional += getVectorLength().size(); 942 numOptional += getIfCond() ? 1 : 0; 943 numOptional += getSelfCond() ? 1 : 0; 944 return getOperand(getWaitOperands().size() + numOptional + i); 945 } 946 947 template <typename Op> 948 static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, 949 ArrayAttr deviceTypes, 950 llvm::StringRef keyword) { 951 if (!operands.empty() && deviceTypes.getValue().size() != operands.size()) 952 return op.emitOpError() << keyword << " operands count must match " 953 << keyword << " device_type count"; 954 return success(); 955 } 956 957 template <typename Op> 958 static LogicalResult verifyDeviceTypeAndSegmentCountMatch( 959 Op op, OperandRange operands, DenseI32ArrayAttr segments, 960 ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) { 961 std::size_t numOperandsInSegments = 0; 962 std::size_t nbOfSegments = 0; 963 964 if (segments) { 965 for (auto segCount : segments.asArrayRef()) { 966 if (maxInSegment != 0 && segCount > maxInSegment) 967 return op.emitOpError() << keyword << " expects a maximum of " 968 << maxInSegment << " values per segment"; 969 numOperandsInSegments += segCount; 970 ++nbOfSegments; 971 } 972 } 973 974 if ((numOperandsInSegments != operands.size()) || 975 (!deviceTypes && !operands.empty())) 976 return op.emitOpError() 977 << keyword << " operand count does not match count in segments"; 978 if (deviceTypes && deviceTypes.getValue().size() != nbOfSegments) 979 return op.emitOpError() 980 << keyword << " segment count does not match device_type count"; 981 return success(); 982 } 983 984 LogicalResult acc::ParallelOp::verify() { 985 if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>( 986 *this, getPrivatizations(), getPrivateOperands(), "private", 987 "privatizations", /*checkOperandType=*/false))) 988 return failure(); 989 if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>( 990 *this, getFirstprivatizations(), getFirstprivateOperands(), 991 "firstprivate", "firstprivatizations", /*checkOperandType=*/false))) 992 return failure(); 993 if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>( 994 *this, getReductionRecipes(), getReductionOperands(), "reduction", 995 "reductions", false))) 996 return failure(); 997 998 if (failed(verifyDeviceTypeAndSegmentCountMatch( 999 *this, getNumGangs(), getNumGangsSegmentsAttr(), 1000 getNumGangsDeviceTypeAttr(), "num_gangs", 3))) 1001 return failure(); 1002 1003 if (failed(verifyDeviceTypeAndSegmentCountMatch( 1004 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(), 1005 getWaitOperandsDeviceTypeAttr(), "wait"))) 1006 return failure(); 1007 1008 if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(), 1009 getNumWorkersDeviceTypeAttr(), 1010 "num_workers"))) 1011 return failure(); 1012 1013 if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(), 1014 getVectorLengthDeviceTypeAttr(), 1015 "vector_length"))) 1016 return failure(); 1017 1018 if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(), 1019 getAsyncOperandsDeviceTypeAttr(), 1020 "async"))) 1021 return failure(); 1022 1023 if (failed(checkWaitAndAsyncConflict<acc::ParallelOp>(*this))) 1024 return failure(); 1025 1026 return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands()); 1027 } 1028 1029 static mlir::Value 1030 getValueInDeviceTypeSegment(std::optional<mlir::ArrayAttr> arrayAttr, 1031 mlir::Operation::operand_range range, 1032 mlir::acc::DeviceType deviceType) { 1033 if (!arrayAttr) 1034 return {}; 1035 if (auto pos = findSegment(*arrayAttr, deviceType)) 1036 return range[*pos]; 1037 return {}; 1038 } 1039 1040 bool acc::ParallelOp::hasAsyncOnly() { 1041 return hasAsyncOnly(mlir::acc::DeviceType::None); 1042 } 1043 1044 bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) { 1045 return hasDeviceType(getAsyncOnly(), deviceType); 1046 } 1047 1048 mlir::Value acc::ParallelOp::getAsyncValue() { 1049 return getAsyncValue(mlir::acc::DeviceType::None); 1050 } 1051 1052 mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) { 1053 return getValueInDeviceTypeSegment(getAsyncOperandsDeviceType(), 1054 getAsyncOperands(), deviceType); 1055 } 1056 1057 mlir::Value acc::ParallelOp::getNumWorkersValue() { 1058 return getNumWorkersValue(mlir::acc::DeviceType::None); 1059 } 1060 1061 mlir::Value 1062 acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) { 1063 return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(), 1064 deviceType); 1065 } 1066 1067 mlir::Value acc::ParallelOp::getVectorLengthValue() { 1068 return getVectorLengthValue(mlir::acc::DeviceType::None); 1069 } 1070 1071 mlir::Value 1072 acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) { 1073 return getValueInDeviceTypeSegment(getVectorLengthDeviceType(), 1074 getVectorLength(), deviceType); 1075 } 1076 1077 mlir::Operation::operand_range ParallelOp::getNumGangsValues() { 1078 return getNumGangsValues(mlir::acc::DeviceType::None); 1079 } 1080 1081 mlir::Operation::operand_range 1082 ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) { 1083 return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(), 1084 getNumGangsSegments(), deviceType); 1085 } 1086 1087 bool acc::ParallelOp::hasWaitOnly() { 1088 return hasWaitOnly(mlir::acc::DeviceType::None); 1089 } 1090 1091 bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) { 1092 return hasDeviceType(getWaitOnly(), deviceType); 1093 } 1094 1095 mlir::Operation::operand_range ParallelOp::getWaitValues() { 1096 return getWaitValues(mlir::acc::DeviceType::None); 1097 } 1098 1099 mlir::Operation::operand_range 1100 ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) { 1101 return getWaitValuesWithoutDevnum( 1102 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(), 1103 getHasWaitDevnum(), deviceType); 1104 } 1105 1106 mlir::Value ParallelOp::getWaitDevnum() { 1107 return getWaitDevnum(mlir::acc::DeviceType::None); 1108 } 1109 1110 mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) { 1111 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(), 1112 getWaitOperandsSegments(), getHasWaitDevnum(), 1113 deviceType); 1114 } 1115 1116 void ParallelOp::build(mlir::OpBuilder &odsBuilder, 1117 mlir::OperationState &odsState, 1118 mlir::ValueRange numGangs, mlir::ValueRange numWorkers, 1119 mlir::ValueRange vectorLength, 1120 mlir::ValueRange asyncOperands, 1121 mlir::ValueRange waitOperands, mlir::Value ifCond, 1122 mlir::Value selfCond, mlir::ValueRange reductionOperands, 1123 mlir::ValueRange gangPrivateOperands, 1124 mlir::ValueRange gangFirstPrivateOperands, 1125 mlir::ValueRange dataClauseOperands) { 1126 1127 ParallelOp::build( 1128 odsBuilder, odsState, asyncOperands, /*asyncOperandsDeviceType=*/nullptr, 1129 /*asyncOnly=*/nullptr, waitOperands, /*waitOperandsSegments=*/nullptr, 1130 /*waitOperandsDeviceType=*/nullptr, /*hasWaitDevnum=*/nullptr, 1131 /*waitOnly=*/nullptr, numGangs, /*numGangsSegments=*/nullptr, 1132 /*numGangsDeviceType=*/nullptr, numWorkers, 1133 /*numWorkersDeviceType=*/nullptr, vectorLength, 1134 /*vectorLengthDeviceType=*/nullptr, ifCond, selfCond, 1135 /*selfAttr=*/nullptr, reductionOperands, /*reductionRecipes=*/nullptr, 1136 gangPrivateOperands, /*privatizations=*/nullptr, gangFirstPrivateOperands, 1137 /*firstprivatizations=*/nullptr, dataClauseOperands, 1138 /*defaultAttr=*/nullptr, /*combined=*/nullptr); 1139 } 1140 1141 static ParseResult parseNumGangs( 1142 mlir::OpAsmParser &parser, 1143 llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands, 1144 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes, 1145 mlir::DenseI32ArrayAttr &segments) { 1146 llvm::SmallVector<DeviceTypeAttr> attributes; 1147 llvm::SmallVector<int32_t> seg; 1148 1149 do { 1150 if (failed(parser.parseLBrace())) 1151 return failure(); 1152 1153 int32_t crtOperandsSize = operands.size(); 1154 if (failed(parser.parseCommaSeparatedList( 1155 mlir::AsmParser::Delimiter::None, [&]() { 1156 if (parser.parseOperand(operands.emplace_back()) || 1157 parser.parseColonType(types.emplace_back())) 1158 return failure(); 1159 return success(); 1160 }))) 1161 return failure(); 1162 seg.push_back(operands.size() - crtOperandsSize); 1163 1164 if (failed(parser.parseRBrace())) 1165 return failure(); 1166 1167 if (succeeded(parser.parseOptionalLSquare())) { 1168 if (parser.parseAttribute(attributes.emplace_back()) || 1169 parser.parseRSquare()) 1170 return failure(); 1171 } else { 1172 attributes.push_back(mlir::acc::DeviceTypeAttr::get( 1173 parser.getContext(), mlir::acc::DeviceType::None)); 1174 } 1175 } while (succeeded(parser.parseOptionalComma())); 1176 1177 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(), 1178 attributes.end()); 1179 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr); 1180 segments = DenseI32ArrayAttr::get(parser.getContext(), seg); 1181 1182 return success(); 1183 } 1184 1185 static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr) { 1186 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr); 1187 if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None) 1188 p << " [" << attr << "]"; 1189 } 1190 1191 static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, 1192 mlir::OperandRange operands, mlir::TypeRange types, 1193 std::optional<mlir::ArrayAttr> deviceTypes, 1194 std::optional<mlir::DenseI32ArrayAttr> segments) { 1195 unsigned opIdx = 0; 1196 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) { 1197 p << "{"; 1198 llvm::interleaveComma( 1199 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) { 1200 p << operands[opIdx] << " : " << operands[opIdx].getType(); 1201 ++opIdx; 1202 }); 1203 p << "}"; 1204 printSingleDeviceType(p, it.value()); 1205 }); 1206 } 1207 1208 static ParseResult parseDeviceTypeOperandsWithSegment( 1209 mlir::OpAsmParser &parser, 1210 llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands, 1211 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes, 1212 mlir::DenseI32ArrayAttr &segments) { 1213 llvm::SmallVector<DeviceTypeAttr> attributes; 1214 llvm::SmallVector<int32_t> seg; 1215 1216 do { 1217 if (failed(parser.parseLBrace())) 1218 return failure(); 1219 1220 int32_t crtOperandsSize = operands.size(); 1221 1222 if (failed(parser.parseCommaSeparatedList( 1223 mlir::AsmParser::Delimiter::None, [&]() { 1224 if (parser.parseOperand(operands.emplace_back()) || 1225 parser.parseColonType(types.emplace_back())) 1226 return failure(); 1227 return success(); 1228 }))) 1229 return failure(); 1230 1231 seg.push_back(operands.size() - crtOperandsSize); 1232 1233 if (failed(parser.parseRBrace())) 1234 return failure(); 1235 1236 if (succeeded(parser.parseOptionalLSquare())) { 1237 if (parser.parseAttribute(attributes.emplace_back()) || 1238 parser.parseRSquare()) 1239 return failure(); 1240 } else { 1241 attributes.push_back(mlir::acc::DeviceTypeAttr::get( 1242 parser.getContext(), mlir::acc::DeviceType::None)); 1243 } 1244 } while (succeeded(parser.parseOptionalComma())); 1245 1246 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(), 1247 attributes.end()); 1248 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr); 1249 segments = DenseI32ArrayAttr::get(parser.getContext(), seg); 1250 1251 return success(); 1252 } 1253 1254 static void printDeviceTypeOperandsWithSegment( 1255 mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, 1256 mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes, 1257 std::optional<mlir::DenseI32ArrayAttr> segments) { 1258 unsigned opIdx = 0; 1259 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) { 1260 p << "{"; 1261 llvm::interleaveComma( 1262 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) { 1263 p << operands[opIdx] << " : " << operands[opIdx].getType(); 1264 ++opIdx; 1265 }); 1266 p << "}"; 1267 printSingleDeviceType(p, it.value()); 1268 }); 1269 } 1270 1271 static ParseResult parseWaitClause( 1272 mlir::OpAsmParser &parser, 1273 llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands, 1274 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes, 1275 mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, 1276 mlir::ArrayAttr &keywordOnly) { 1277 llvm::SmallVector<mlir::Attribute> deviceTypeAttrs, keywordAttrs, devnum; 1278 llvm::SmallVector<int32_t> seg; 1279 1280 bool needCommaBeforeOperands = false; 1281 1282 // Keyword only 1283 if (failed(parser.parseOptionalLParen())) { 1284 keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get( 1285 parser.getContext(), mlir::acc::DeviceType::None)); 1286 keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs); 1287 return success(); 1288 } 1289 1290 // Parse keyword only attributes 1291 if (succeeded(parser.parseOptionalLSquare())) { 1292 if (failed(parser.parseCommaSeparatedList([&]() { 1293 if (parser.parseAttribute(keywordAttrs.emplace_back())) 1294 return failure(); 1295 return success(); 1296 }))) 1297 return failure(); 1298 if (parser.parseRSquare()) 1299 return failure(); 1300 needCommaBeforeOperands = true; 1301 } 1302 1303 if (needCommaBeforeOperands && failed(parser.parseComma())) 1304 return failure(); 1305 1306 do { 1307 if (failed(parser.parseLBrace())) 1308 return failure(); 1309 1310 int32_t crtOperandsSize = operands.size(); 1311 1312 if (succeeded(parser.parseOptionalKeyword("devnum"))) { 1313 if (failed(parser.parseColon())) 1314 return failure(); 1315 devnum.push_back(BoolAttr::get(parser.getContext(), true)); 1316 } else { 1317 devnum.push_back(BoolAttr::get(parser.getContext(), false)); 1318 } 1319 1320 if (failed(parser.parseCommaSeparatedList( 1321 mlir::AsmParser::Delimiter::None, [&]() { 1322 if (parser.parseOperand(operands.emplace_back()) || 1323 parser.parseColonType(types.emplace_back())) 1324 return failure(); 1325 return success(); 1326 }))) 1327 return failure(); 1328 1329 seg.push_back(operands.size() - crtOperandsSize); 1330 1331 if (failed(parser.parseRBrace())) 1332 return failure(); 1333 1334 if (succeeded(parser.parseOptionalLSquare())) { 1335 if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) || 1336 parser.parseRSquare()) 1337 return failure(); 1338 } else { 1339 deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get( 1340 parser.getContext(), mlir::acc::DeviceType::None)); 1341 } 1342 } while (succeeded(parser.parseOptionalComma())); 1343 1344 if (failed(parser.parseRParen())) 1345 return failure(); 1346 1347 deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs); 1348 keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs); 1349 segments = DenseI32ArrayAttr::get(parser.getContext(), seg); 1350 hasDevNum = ArrayAttr::get(parser.getContext(), devnum); 1351 1352 return success(); 1353 } 1354 1355 static bool hasOnlyDeviceTypeNone(std::optional<mlir::ArrayAttr> attrs) { 1356 if (!hasDeviceTypeValues(attrs)) 1357 return false; 1358 if (attrs->size() != 1) 1359 return false; 1360 if (auto deviceTypeAttr = 1361 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0])) 1362 return deviceTypeAttr.getValue() == mlir::acc::DeviceType::None; 1363 return false; 1364 } 1365 1366 static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, 1367 mlir::OperandRange operands, mlir::TypeRange types, 1368 std::optional<mlir::ArrayAttr> deviceTypes, 1369 std::optional<mlir::DenseI32ArrayAttr> segments, 1370 std::optional<mlir::ArrayAttr> hasDevNum, 1371 std::optional<mlir::ArrayAttr> keywordOnly) { 1372 1373 if (operands.begin() == operands.end() && hasOnlyDeviceTypeNone(keywordOnly)) 1374 return; 1375 1376 p << "("; 1377 1378 printDeviceTypes(p, keywordOnly); 1379 if (hasDeviceTypeValues(keywordOnly) && hasDeviceTypeValues(deviceTypes)) 1380 p << ", "; 1381 1382 unsigned opIdx = 0; 1383 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) { 1384 p << "{"; 1385 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]); 1386 if (boolAttr && boolAttr.getValue()) 1387 p << "devnum: "; 1388 llvm::interleaveComma( 1389 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) { 1390 p << operands[opIdx] << " : " << operands[opIdx].getType(); 1391 ++opIdx; 1392 }); 1393 p << "}"; 1394 printSingleDeviceType(p, it.value()); 1395 }); 1396 1397 p << ")"; 1398 } 1399 1400 static ParseResult parseDeviceTypeOperands( 1401 mlir::OpAsmParser &parser, 1402 llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands, 1403 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes) { 1404 llvm::SmallVector<DeviceTypeAttr> attributes; 1405 if (failed(parser.parseCommaSeparatedList([&]() { 1406 if (parser.parseOperand(operands.emplace_back()) || 1407 parser.parseColonType(types.emplace_back())) 1408 return failure(); 1409 if (succeeded(parser.parseOptionalLSquare())) { 1410 if (parser.parseAttribute(attributes.emplace_back()) || 1411 parser.parseRSquare()) 1412 return failure(); 1413 } else { 1414 attributes.push_back(mlir::acc::DeviceTypeAttr::get( 1415 parser.getContext(), mlir::acc::DeviceType::None)); 1416 } 1417 return success(); 1418 }))) 1419 return failure(); 1420 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(), 1421 attributes.end()); 1422 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr); 1423 return success(); 1424 } 1425 1426 static void 1427 printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, 1428 mlir::OperandRange operands, mlir::TypeRange types, 1429 std::optional<mlir::ArrayAttr> deviceTypes) { 1430 if (!hasDeviceTypeValues(deviceTypes)) 1431 return; 1432 llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](auto it) { 1433 p << std::get<1>(it) << " : " << std::get<1>(it).getType(); 1434 printSingleDeviceType(p, std::get<0>(it)); 1435 }); 1436 } 1437 1438 static ParseResult parseDeviceTypeOperandsWithKeywordOnly( 1439 mlir::OpAsmParser &parser, 1440 llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands, 1441 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes, 1442 mlir::ArrayAttr &keywordOnlyDeviceType) { 1443 1444 llvm::SmallVector<mlir::Attribute> keywordOnlyDeviceTypeAttributes; 1445 bool needCommaBeforeOperands = false; 1446 1447 if (failed(parser.parseOptionalLParen())) { 1448 // Keyword only 1449 keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get( 1450 parser.getContext(), mlir::acc::DeviceType::None)); 1451 keywordOnlyDeviceType = 1452 ArrayAttr::get(parser.getContext(), keywordOnlyDeviceTypeAttributes); 1453 return success(); 1454 } 1455 1456 // Parse keyword only attributes 1457 if (succeeded(parser.parseOptionalLSquare())) { 1458 // Parse keyword only attributes 1459 if (failed(parser.parseCommaSeparatedList([&]() { 1460 if (parser.parseAttribute( 1461 keywordOnlyDeviceTypeAttributes.emplace_back())) 1462 return failure(); 1463 return success(); 1464 }))) 1465 return failure(); 1466 if (parser.parseRSquare()) 1467 return failure(); 1468 needCommaBeforeOperands = true; 1469 } 1470 1471 if (needCommaBeforeOperands && failed(parser.parseComma())) 1472 return failure(); 1473 1474 llvm::SmallVector<DeviceTypeAttr> attributes; 1475 if (failed(parser.parseCommaSeparatedList([&]() { 1476 if (parser.parseOperand(operands.emplace_back()) || 1477 parser.parseColonType(types.emplace_back())) 1478 return failure(); 1479 if (succeeded(parser.parseOptionalLSquare())) { 1480 if (parser.parseAttribute(attributes.emplace_back()) || 1481 parser.parseRSquare()) 1482 return failure(); 1483 } else { 1484 attributes.push_back(mlir::acc::DeviceTypeAttr::get( 1485 parser.getContext(), mlir::acc::DeviceType::None)); 1486 } 1487 return success(); 1488 }))) 1489 return failure(); 1490 1491 if (failed(parser.parseRParen())) 1492 return failure(); 1493 1494 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(), 1495 attributes.end()); 1496 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr); 1497 return success(); 1498 } 1499 1500 static void printDeviceTypeOperandsWithKeywordOnly( 1501 mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, 1502 mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes, 1503 std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) { 1504 1505 if (operands.begin() == operands.end() && 1506 hasOnlyDeviceTypeNone(keywordOnlyDeviceTypes)) { 1507 return; 1508 } 1509 1510 p << "("; 1511 printDeviceTypes(p, keywordOnlyDeviceTypes); 1512 if (hasDeviceTypeValues(keywordOnlyDeviceTypes) && 1513 hasDeviceTypeValues(deviceTypes)) 1514 p << ", "; 1515 printDeviceTypeOperands(p, op, operands, types, deviceTypes); 1516 p << ")"; 1517 } 1518 1519 static ParseResult 1520 parseCombinedConstructsLoop(mlir::OpAsmParser &parser, 1521 mlir::acc::CombinedConstructsTypeAttr &attr) { 1522 if (succeeded(parser.parseOptionalKeyword("combined"))) { 1523 if (parser.parseLParen()) 1524 return failure(); 1525 if (succeeded(parser.parseOptionalKeyword("kernels"))) { 1526 attr = mlir::acc::CombinedConstructsTypeAttr::get( 1527 parser.getContext(), mlir::acc::CombinedConstructsType::KernelsLoop); 1528 } else if (succeeded(parser.parseOptionalKeyword("parallel"))) { 1529 attr = mlir::acc::CombinedConstructsTypeAttr::get( 1530 parser.getContext(), mlir::acc::CombinedConstructsType::ParallelLoop); 1531 } else if (succeeded(parser.parseOptionalKeyword("serial"))) { 1532 attr = mlir::acc::CombinedConstructsTypeAttr::get( 1533 parser.getContext(), mlir::acc::CombinedConstructsType::SerialLoop); 1534 } else { 1535 parser.emitError(parser.getCurrentLocation(), 1536 "expected compute construct name"); 1537 return failure(); 1538 } 1539 if (parser.parseRParen()) 1540 return failure(); 1541 } 1542 return success(); 1543 } 1544 1545 static void 1546 printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, 1547 mlir::acc::CombinedConstructsTypeAttr attr) { 1548 if (attr) { 1549 switch (attr.getValue()) { 1550 case mlir::acc::CombinedConstructsType::KernelsLoop: 1551 p << "combined(kernels)"; 1552 break; 1553 case mlir::acc::CombinedConstructsType::ParallelLoop: 1554 p << "combined(parallel)"; 1555 break; 1556 case mlir::acc::CombinedConstructsType::SerialLoop: 1557 p << "combined(serial)"; 1558 break; 1559 }; 1560 } 1561 } 1562 1563 //===----------------------------------------------------------------------===// 1564 // SerialOp 1565 //===----------------------------------------------------------------------===// 1566 1567 unsigned SerialOp::getNumDataOperands() { 1568 return getReductionOperands().size() + getPrivateOperands().size() + 1569 getFirstprivateOperands().size() + getDataClauseOperands().size(); 1570 } 1571 1572 Value SerialOp::getDataOperand(unsigned i) { 1573 unsigned numOptional = getAsyncOperands().size(); 1574 numOptional += getIfCond() ? 1 : 0; 1575 numOptional += getSelfCond() ? 1 : 0; 1576 return getOperand(getWaitOperands().size() + numOptional + i); 1577 } 1578 1579 bool acc::SerialOp::hasAsyncOnly() { 1580 return hasAsyncOnly(mlir::acc::DeviceType::None); 1581 } 1582 1583 bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) { 1584 return hasDeviceType(getAsyncOnly(), deviceType); 1585 } 1586 1587 mlir::Value acc::SerialOp::getAsyncValue() { 1588 return getAsyncValue(mlir::acc::DeviceType::None); 1589 } 1590 1591 mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) { 1592 return getValueInDeviceTypeSegment(getAsyncOperandsDeviceType(), 1593 getAsyncOperands(), deviceType); 1594 } 1595 1596 bool acc::SerialOp::hasWaitOnly() { 1597 return hasWaitOnly(mlir::acc::DeviceType::None); 1598 } 1599 1600 bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) { 1601 return hasDeviceType(getWaitOnly(), deviceType); 1602 } 1603 1604 mlir::Operation::operand_range SerialOp::getWaitValues() { 1605 return getWaitValues(mlir::acc::DeviceType::None); 1606 } 1607 1608 mlir::Operation::operand_range 1609 SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) { 1610 return getWaitValuesWithoutDevnum( 1611 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(), 1612 getHasWaitDevnum(), deviceType); 1613 } 1614 1615 mlir::Value SerialOp::getWaitDevnum() { 1616 return getWaitDevnum(mlir::acc::DeviceType::None); 1617 } 1618 1619 mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) { 1620 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(), 1621 getWaitOperandsSegments(), getHasWaitDevnum(), 1622 deviceType); 1623 } 1624 1625 LogicalResult acc::SerialOp::verify() { 1626 if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>( 1627 *this, getPrivatizations(), getPrivateOperands(), "private", 1628 "privatizations", /*checkOperandType=*/false))) 1629 return failure(); 1630 if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>( 1631 *this, getFirstprivatizations(), getFirstprivateOperands(), 1632 "firstprivate", "firstprivatizations", /*checkOperandType=*/false))) 1633 return failure(); 1634 if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>( 1635 *this, getReductionRecipes(), getReductionOperands(), "reduction", 1636 "reductions", false))) 1637 return failure(); 1638 1639 if (failed(verifyDeviceTypeAndSegmentCountMatch( 1640 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(), 1641 getWaitOperandsDeviceTypeAttr(), "wait"))) 1642 return failure(); 1643 1644 if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(), 1645 getAsyncOperandsDeviceTypeAttr(), 1646 "async"))) 1647 return failure(); 1648 1649 if (failed(checkWaitAndAsyncConflict<acc::SerialOp>(*this))) 1650 return failure(); 1651 1652 return checkDataOperands<acc::SerialOp>(*this, getDataClauseOperands()); 1653 } 1654 1655 //===----------------------------------------------------------------------===// 1656 // KernelsOp 1657 //===----------------------------------------------------------------------===// 1658 1659 unsigned KernelsOp::getNumDataOperands() { 1660 return getDataClauseOperands().size(); 1661 } 1662 1663 Value KernelsOp::getDataOperand(unsigned i) { 1664 unsigned numOptional = getAsyncOperands().size(); 1665 numOptional += getWaitOperands().size(); 1666 numOptional += getNumGangs().size(); 1667 numOptional += getNumWorkers().size(); 1668 numOptional += getVectorLength().size(); 1669 numOptional += getIfCond() ? 1 : 0; 1670 numOptional += getSelfCond() ? 1 : 0; 1671 return getOperand(numOptional + i); 1672 } 1673 1674 bool acc::KernelsOp::hasAsyncOnly() { 1675 return hasAsyncOnly(mlir::acc::DeviceType::None); 1676 } 1677 1678 bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) { 1679 return hasDeviceType(getAsyncOnly(), deviceType); 1680 } 1681 1682 mlir::Value acc::KernelsOp::getAsyncValue() { 1683 return getAsyncValue(mlir::acc::DeviceType::None); 1684 } 1685 1686 mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) { 1687 return getValueInDeviceTypeSegment(getAsyncOperandsDeviceType(), 1688 getAsyncOperands(), deviceType); 1689 } 1690 1691 mlir::Value acc::KernelsOp::getNumWorkersValue() { 1692 return getNumWorkersValue(mlir::acc::DeviceType::None); 1693 } 1694 1695 mlir::Value 1696 acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) { 1697 return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(), 1698 deviceType); 1699 } 1700 1701 mlir::Value acc::KernelsOp::getVectorLengthValue() { 1702 return getVectorLengthValue(mlir::acc::DeviceType::None); 1703 } 1704 1705 mlir::Value 1706 acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) { 1707 return getValueInDeviceTypeSegment(getVectorLengthDeviceType(), 1708 getVectorLength(), deviceType); 1709 } 1710 1711 mlir::Operation::operand_range KernelsOp::getNumGangsValues() { 1712 return getNumGangsValues(mlir::acc::DeviceType::None); 1713 } 1714 1715 mlir::Operation::operand_range 1716 KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) { 1717 return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(), 1718 getNumGangsSegments(), deviceType); 1719 } 1720 1721 bool acc::KernelsOp::hasWaitOnly() { 1722 return hasWaitOnly(mlir::acc::DeviceType::None); 1723 } 1724 1725 bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) { 1726 return hasDeviceType(getWaitOnly(), deviceType); 1727 } 1728 1729 mlir::Operation::operand_range KernelsOp::getWaitValues() { 1730 return getWaitValues(mlir::acc::DeviceType::None); 1731 } 1732 1733 mlir::Operation::operand_range 1734 KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) { 1735 return getWaitValuesWithoutDevnum( 1736 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(), 1737 getHasWaitDevnum(), deviceType); 1738 } 1739 1740 mlir::Value KernelsOp::getWaitDevnum() { 1741 return getWaitDevnum(mlir::acc::DeviceType::None); 1742 } 1743 1744 mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) { 1745 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(), 1746 getWaitOperandsSegments(), getHasWaitDevnum(), 1747 deviceType); 1748 } 1749 1750 LogicalResult acc::KernelsOp::verify() { 1751 if (failed(verifyDeviceTypeAndSegmentCountMatch( 1752 *this, getNumGangs(), getNumGangsSegmentsAttr(), 1753 getNumGangsDeviceTypeAttr(), "num_gangs", 3))) 1754 return failure(); 1755 1756 if (failed(verifyDeviceTypeAndSegmentCountMatch( 1757 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(), 1758 getWaitOperandsDeviceTypeAttr(), "wait"))) 1759 return failure(); 1760 1761 if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(), 1762 getNumWorkersDeviceTypeAttr(), 1763 "num_workers"))) 1764 return failure(); 1765 1766 if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(), 1767 getVectorLengthDeviceTypeAttr(), 1768 "vector_length"))) 1769 return failure(); 1770 1771 if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(), 1772 getAsyncOperandsDeviceTypeAttr(), 1773 "async"))) 1774 return failure(); 1775 1776 if (failed(checkWaitAndAsyncConflict<acc::KernelsOp>(*this))) 1777 return failure(); 1778 1779 return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands()); 1780 } 1781 1782 //===----------------------------------------------------------------------===// 1783 // HostDataOp 1784 //===----------------------------------------------------------------------===// 1785 1786 LogicalResult acc::HostDataOp::verify() { 1787 if (getDataClauseOperands().empty()) 1788 return emitError("at least one operand must appear on the host_data " 1789 "operation"); 1790 1791 for (mlir::Value operand : getDataClauseOperands()) 1792 if (!mlir::isa<acc::UseDeviceOp>(operand.getDefiningOp())) 1793 return emitError("expect data entry operation as defining op"); 1794 return success(); 1795 } 1796 1797 void acc::HostDataOp::getCanonicalizationPatterns(RewritePatternSet &results, 1798 MLIRContext *context) { 1799 results.add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context); 1800 } 1801 1802 //===----------------------------------------------------------------------===// 1803 // LoopOp 1804 //===----------------------------------------------------------------------===// 1805 1806 static ParseResult parseGangValue( 1807 OpAsmParser &parser, llvm::StringRef keyword, 1808 llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands, 1809 llvm::SmallVectorImpl<Type> &types, 1810 llvm::SmallVector<GangArgTypeAttr> &attributes, GangArgTypeAttr gangArgType, 1811 bool &needCommaBetweenValues, bool &newValue) { 1812 if (succeeded(parser.parseOptionalKeyword(keyword))) { 1813 if (parser.parseEqual()) 1814 return failure(); 1815 if (parser.parseOperand(operands.emplace_back()) || 1816 parser.parseColonType(types.emplace_back())) 1817 return failure(); 1818 attributes.push_back(gangArgType); 1819 needCommaBetweenValues = true; 1820 newValue = true; 1821 } 1822 return success(); 1823 } 1824 1825 static ParseResult parseGangClause( 1826 OpAsmParser &parser, 1827 llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &gangOperands, 1828 llvm::SmallVectorImpl<Type> &gangOperandsType, mlir::ArrayAttr &gangArgType, 1829 mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, 1830 mlir::ArrayAttr &gangOnlyDeviceType) { 1831 llvm::SmallVector<GangArgTypeAttr> gangArgTypeAttributes; 1832 llvm::SmallVector<mlir::Attribute> deviceTypeAttributes; 1833 llvm::SmallVector<mlir::Attribute> gangOnlyDeviceTypeAttributes; 1834 llvm::SmallVector<int32_t> seg; 1835 bool needCommaBetweenValues = false; 1836 bool needCommaBeforeOperands = false; 1837 1838 if (failed(parser.parseOptionalLParen())) { 1839 // Gang only keyword 1840 gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get( 1841 parser.getContext(), mlir::acc::DeviceType::None)); 1842 gangOnlyDeviceType = 1843 ArrayAttr::get(parser.getContext(), gangOnlyDeviceTypeAttributes); 1844 return success(); 1845 } 1846 1847 // Parse gang only attributes 1848 if (succeeded(parser.parseOptionalLSquare())) { 1849 // Parse gang only attributes 1850 if (failed(parser.parseCommaSeparatedList([&]() { 1851 if (parser.parseAttribute( 1852 gangOnlyDeviceTypeAttributes.emplace_back())) 1853 return failure(); 1854 return success(); 1855 }))) 1856 return failure(); 1857 if (parser.parseRSquare()) 1858 return failure(); 1859 needCommaBeforeOperands = true; 1860 } 1861 1862 auto argNum = mlir::acc::GangArgTypeAttr::get(parser.getContext(), 1863 mlir::acc::GangArgType::Num); 1864 auto argDim = mlir::acc::GangArgTypeAttr::get(parser.getContext(), 1865 mlir::acc::GangArgType::Dim); 1866 auto argStatic = mlir::acc::GangArgTypeAttr::get( 1867 parser.getContext(), mlir::acc::GangArgType::Static); 1868 1869 do { 1870 if (needCommaBeforeOperands) { 1871 needCommaBeforeOperands = false; 1872 continue; 1873 } 1874 1875 if (failed(parser.parseLBrace())) 1876 return failure(); 1877 1878 int32_t crtOperandsSize = gangOperands.size(); 1879 while (true) { 1880 bool newValue = false; 1881 bool needValue = false; 1882 if (needCommaBetweenValues) { 1883 if (succeeded(parser.parseOptionalComma())) 1884 needValue = true; // expect a new value after comma. 1885 else 1886 break; 1887 } 1888 1889 if (failed(parseGangValue(parser, LoopOp::getGangNumKeyword(), 1890 gangOperands, gangOperandsType, 1891 gangArgTypeAttributes, argNum, 1892 needCommaBetweenValues, newValue))) 1893 return failure(); 1894 if (failed(parseGangValue(parser, LoopOp::getGangDimKeyword(), 1895 gangOperands, gangOperandsType, 1896 gangArgTypeAttributes, argDim, 1897 needCommaBetweenValues, newValue))) 1898 return failure(); 1899 if (failed(parseGangValue(parser, LoopOp::getGangStaticKeyword(), 1900 gangOperands, gangOperandsType, 1901 gangArgTypeAttributes, argStatic, 1902 needCommaBetweenValues, newValue))) 1903 return failure(); 1904 1905 if (!newValue && needValue) { 1906 parser.emitError(parser.getCurrentLocation(), 1907 "new value expected after comma"); 1908 return failure(); 1909 } 1910 1911 if (!newValue) 1912 break; 1913 } 1914 1915 if (gangOperands.empty()) 1916 return parser.emitError( 1917 parser.getCurrentLocation(), 1918 "expect at least one of num, dim or static values"); 1919 1920 if (failed(parser.parseRBrace())) 1921 return failure(); 1922 1923 if (succeeded(parser.parseOptionalLSquare())) { 1924 if (parser.parseAttribute(deviceTypeAttributes.emplace_back()) || 1925 parser.parseRSquare()) 1926 return failure(); 1927 } else { 1928 deviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get( 1929 parser.getContext(), mlir::acc::DeviceType::None)); 1930 } 1931 1932 seg.push_back(gangOperands.size() - crtOperandsSize); 1933 1934 } while (succeeded(parser.parseOptionalComma())); 1935 1936 if (failed(parser.parseRParen())) 1937 return failure(); 1938 1939 llvm::SmallVector<mlir::Attribute> arrayAttr(gangArgTypeAttributes.begin(), 1940 gangArgTypeAttributes.end()); 1941 gangArgType = ArrayAttr::get(parser.getContext(), arrayAttr); 1942 deviceType = ArrayAttr::get(parser.getContext(), deviceTypeAttributes); 1943 1944 llvm::SmallVector<mlir::Attribute> gangOnlyAttr( 1945 gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end()); 1946 gangOnlyDeviceType = ArrayAttr::get(parser.getContext(), gangOnlyAttr); 1947 1948 segments = DenseI32ArrayAttr::get(parser.getContext(), seg); 1949 return success(); 1950 } 1951 1952 void printGangClause(OpAsmPrinter &p, Operation *op, 1953 mlir::OperandRange operands, mlir::TypeRange types, 1954 std::optional<mlir::ArrayAttr> gangArgTypes, 1955 std::optional<mlir::ArrayAttr> deviceTypes, 1956 std::optional<mlir::DenseI32ArrayAttr> segments, 1957 std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) { 1958 1959 if (operands.begin() == operands.end() && 1960 hasOnlyDeviceTypeNone(gangOnlyDeviceTypes)) { 1961 return; 1962 } 1963 1964 p << "("; 1965 1966 printDeviceTypes(p, gangOnlyDeviceTypes); 1967 1968 if (hasDeviceTypeValues(gangOnlyDeviceTypes) && 1969 hasDeviceTypeValues(deviceTypes)) 1970 p << ", "; 1971 1972 if (hasDeviceTypeValues(deviceTypes)) { 1973 unsigned opIdx = 0; 1974 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) { 1975 p << "{"; 1976 llvm::interleaveComma( 1977 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) { 1978 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>( 1979 (*gangArgTypes)[opIdx]); 1980 if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num) 1981 p << LoopOp::getGangNumKeyword(); 1982 else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim) 1983 p << LoopOp::getGangDimKeyword(); 1984 else if (gangArgTypeAttr.getValue() == 1985 mlir::acc::GangArgType::Static) 1986 p << LoopOp::getGangStaticKeyword(); 1987 p << "=" << operands[opIdx] << " : " << operands[opIdx].getType(); 1988 ++opIdx; 1989 }); 1990 p << "}"; 1991 printSingleDeviceType(p, it.value()); 1992 }); 1993 } 1994 p << ")"; 1995 } 1996 1997 bool hasDuplicateDeviceTypes( 1998 std::optional<mlir::ArrayAttr> segments, 1999 llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) { 2000 if (!segments) 2001 return false; 2002 for (auto attr : *segments) { 2003 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr); 2004 if (!deviceTypes.insert(deviceTypeAttr.getValue()).second) 2005 return true; 2006 } 2007 return false; 2008 } 2009 2010 /// Check for duplicates in the DeviceType array attribute. 2011 LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes) { 2012 llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes; 2013 if (!deviceTypes) 2014 return success(); 2015 for (auto attr : deviceTypes) { 2016 auto deviceTypeAttr = 2017 mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr); 2018 if (!deviceTypeAttr) 2019 return failure(); 2020 if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second) 2021 return failure(); 2022 } 2023 return success(); 2024 } 2025 2026 LogicalResult acc::LoopOp::verify() { 2027 if (!getUpperbound().empty() && getInclusiveUpperbound() && 2028 (getUpperbound().size() != getInclusiveUpperbound()->size())) 2029 return emitError() << "inclusiveUpperbound size is expected to be the same" 2030 << " as upperbound size"; 2031 2032 // Check collapse 2033 if (getCollapseAttr() && !getCollapseDeviceTypeAttr()) 2034 return emitOpError() << "collapse device_type attr must be define when" 2035 << " collapse attr is present"; 2036 2037 if (getCollapseAttr() && getCollapseDeviceTypeAttr() && 2038 getCollapseAttr().getValue().size() != 2039 getCollapseDeviceTypeAttr().getValue().size()) 2040 return emitOpError() << "collapse attribute count must match collapse" 2041 << " device_type count"; 2042 if (failed(checkDeviceTypes(getCollapseDeviceTypeAttr()))) 2043 return emitOpError() 2044 << "duplicate device_type found in collapseDeviceType attribute"; 2045 2046 // Check gang 2047 if (!getGangOperands().empty()) { 2048 if (!getGangOperandsArgType()) 2049 return emitOpError() << "gangOperandsArgType attribute must be defined" 2050 << " when gang operands are present"; 2051 2052 if (getGangOperands().size() != 2053 getGangOperandsArgTypeAttr().getValue().size()) 2054 return emitOpError() << "gangOperandsArgType attribute count must match" 2055 << " gangOperands count"; 2056 } 2057 if (getGangAttr() && failed(checkDeviceTypes(getGangAttr()))) 2058 return emitOpError() << "duplicate device_type found in gang attribute"; 2059 2060 if (failed(verifyDeviceTypeAndSegmentCountMatch( 2061 *this, getGangOperands(), getGangOperandsSegmentsAttr(), 2062 getGangOperandsDeviceTypeAttr(), "gang"))) 2063 return failure(); 2064 2065 // Check worker 2066 if (failed(checkDeviceTypes(getWorkerAttr()))) 2067 return emitOpError() << "duplicate device_type found in worker attribute"; 2068 if (failed(checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr()))) 2069 return emitOpError() << "duplicate device_type found in " 2070 "workerNumOperandsDeviceType attribute"; 2071 if (failed(verifyDeviceTypeCountMatch(*this, getWorkerNumOperands(), 2072 getWorkerNumOperandsDeviceTypeAttr(), 2073 "worker"))) 2074 return failure(); 2075 2076 // Check vector 2077 if (failed(checkDeviceTypes(getVectorAttr()))) 2078 return emitOpError() << "duplicate device_type found in vector attribute"; 2079 if (failed(checkDeviceTypes(getVectorOperandsDeviceTypeAttr()))) 2080 return emitOpError() << "duplicate device_type found in " 2081 "vectorOperandsDeviceType attribute"; 2082 if (failed(verifyDeviceTypeCountMatch(*this, getVectorOperands(), 2083 getVectorOperandsDeviceTypeAttr(), 2084 "vector"))) 2085 return failure(); 2086 2087 if (failed(verifyDeviceTypeAndSegmentCountMatch( 2088 *this, getTileOperands(), getTileOperandsSegmentsAttr(), 2089 getTileOperandsDeviceTypeAttr(), "tile"))) 2090 return failure(); 2091 2092 // auto, independent and seq attribute are mutually exclusive. 2093 llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes; 2094 if (hasDuplicateDeviceTypes(getAuto_(), deviceTypes) || 2095 hasDuplicateDeviceTypes(getIndependent(), deviceTypes) || 2096 hasDuplicateDeviceTypes(getSeq(), deviceTypes)) { 2097 return emitError() << "only one of \"" << acc::LoopOp::getAutoAttrStrName() 2098 << "\", " << getIndependentAttrName() << ", " 2099 << getSeqAttrName() 2100 << " can be present at the same time"; 2101 } 2102 2103 // Gang, worker and vector are incompatible with seq. 2104 if (getSeqAttr()) { 2105 for (auto attr : getSeqAttr()) { 2106 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr); 2107 if (hasVector(deviceTypeAttr.getValue()) || 2108 getVectorValue(deviceTypeAttr.getValue()) || 2109 hasWorker(deviceTypeAttr.getValue()) || 2110 getWorkerValue(deviceTypeAttr.getValue()) || 2111 hasGang(deviceTypeAttr.getValue()) || 2112 getGangValue(mlir::acc::GangArgType::Num, 2113 deviceTypeAttr.getValue()) || 2114 getGangValue(mlir::acc::GangArgType::Dim, 2115 deviceTypeAttr.getValue()) || 2116 getGangValue(mlir::acc::GangArgType::Static, 2117 deviceTypeAttr.getValue())) 2118 return emitError() 2119 << "gang, worker or vector cannot appear with the seq attr"; 2120 } 2121 } 2122 2123 if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>( 2124 *this, getPrivatizations(), getPrivateOperands(), "private", 2125 "privatizations", false))) 2126 return failure(); 2127 2128 if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>( 2129 *this, getReductionRecipes(), getReductionOperands(), "reduction", 2130 "reductions", false))) 2131 return failure(); 2132 2133 if (getCombined().has_value() && 2134 (getCombined().value() != acc::CombinedConstructsType::ParallelLoop && 2135 getCombined().value() != acc::CombinedConstructsType::KernelsLoop && 2136 getCombined().value() != acc::CombinedConstructsType::SerialLoop)) { 2137 return emitError("unexpected combined constructs attribute"); 2138 } 2139 2140 // Check non-empty body(). 2141 if (getRegion().empty()) 2142 return emitError("expected non-empty body."); 2143 2144 return success(); 2145 } 2146 2147 unsigned LoopOp::getNumDataOperands() { 2148 return getReductionOperands().size() + getPrivateOperands().size(); 2149 } 2150 2151 Value LoopOp::getDataOperand(unsigned i) { 2152 unsigned numOptional = 2153 getLowerbound().size() + getUpperbound().size() + getStep().size(); 2154 numOptional += getGangOperands().size(); 2155 numOptional += getVectorOperands().size(); 2156 numOptional += getWorkerNumOperands().size(); 2157 numOptional += getTileOperands().size(); 2158 numOptional += getCacheOperands().size(); 2159 return getOperand(numOptional + i); 2160 } 2161 2162 bool LoopOp::hasAuto() { return hasAuto(mlir::acc::DeviceType::None); } 2163 2164 bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) { 2165 return hasDeviceType(getAuto_(), deviceType); 2166 } 2167 2168 bool LoopOp::hasIndependent() { 2169 return hasIndependent(mlir::acc::DeviceType::None); 2170 } 2171 2172 bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) { 2173 return hasDeviceType(getIndependent(), deviceType); 2174 } 2175 2176 bool LoopOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); } 2177 2178 bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) { 2179 return hasDeviceType(getSeq(), deviceType); 2180 } 2181 2182 mlir::Value LoopOp::getVectorValue() { 2183 return getVectorValue(mlir::acc::DeviceType::None); 2184 } 2185 2186 mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) { 2187 return getValueInDeviceTypeSegment(getVectorOperandsDeviceType(), 2188 getVectorOperands(), deviceType); 2189 } 2190 2191 bool LoopOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); } 2192 2193 bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) { 2194 return hasDeviceType(getVector(), deviceType); 2195 } 2196 2197 mlir::Value LoopOp::getWorkerValue() { 2198 return getWorkerValue(mlir::acc::DeviceType::None); 2199 } 2200 2201 mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) { 2202 return getValueInDeviceTypeSegment(getWorkerNumOperandsDeviceType(), 2203 getWorkerNumOperands(), deviceType); 2204 } 2205 2206 bool LoopOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); } 2207 2208 bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) { 2209 return hasDeviceType(getWorker(), deviceType); 2210 } 2211 2212 mlir::Operation::operand_range LoopOp::getTileValues() { 2213 return getTileValues(mlir::acc::DeviceType::None); 2214 } 2215 2216 mlir::Operation::operand_range 2217 LoopOp::getTileValues(mlir::acc::DeviceType deviceType) { 2218 return getValuesFromSegments(getTileOperandsDeviceType(), getTileOperands(), 2219 getTileOperandsSegments(), deviceType); 2220 } 2221 2222 std::optional<int64_t> LoopOp::getCollapseValue() { 2223 return getCollapseValue(mlir::acc::DeviceType::None); 2224 } 2225 2226 std::optional<int64_t> 2227 LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) { 2228 if (!getCollapseAttr()) 2229 return std::nullopt; 2230 if (auto pos = findSegment(getCollapseDeviceTypeAttr(), deviceType)) { 2231 auto intAttr = 2232 mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]); 2233 return intAttr.getValue().getZExtValue(); 2234 } 2235 return std::nullopt; 2236 } 2237 2238 mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) { 2239 return getGangValue(gangArgType, mlir::acc::DeviceType::None); 2240 } 2241 2242 mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType, 2243 mlir::acc::DeviceType deviceType) { 2244 if (getGangOperands().empty()) 2245 return {}; 2246 if (auto pos = findSegment(*getGangOperandsDeviceType(), deviceType)) { 2247 int32_t nbOperandsBefore = 0; 2248 for (unsigned i = 0; i < *pos; ++i) 2249 nbOperandsBefore += (*getGangOperandsSegments())[i]; 2250 mlir::Operation::operand_range values = 2251 getGangOperands() 2252 .drop_front(nbOperandsBefore) 2253 .take_front((*getGangOperandsSegments())[*pos]); 2254 2255 int32_t argTypeIdx = nbOperandsBefore; 2256 for (auto value : values) { 2257 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>( 2258 (*getGangOperandsArgType())[argTypeIdx]); 2259 if (gangArgTypeAttr.getValue() == gangArgType) 2260 return value; 2261 ++argTypeIdx; 2262 } 2263 } 2264 return {}; 2265 } 2266 2267 bool LoopOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); } 2268 2269 bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) { 2270 return hasDeviceType(getGang(), deviceType); 2271 } 2272 2273 llvm::SmallVector<mlir::Region *> acc::LoopOp::getLoopRegions() { 2274 return {&getRegion()}; 2275 } 2276 2277 /// loop-control ::= `control` `(` ssa-id-and-type-list `)` `=` 2278 /// `(` ssa-id-and-type-list `)` `to` `(` ssa-id-and-type-list `)` `step` 2279 /// `(` ssa-id-and-type-list `)` 2280 /// region 2281 ParseResult 2282 parseLoopControl(OpAsmParser &parser, Region ®ion, 2283 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &lowerbound, 2284 SmallVectorImpl<Type> &lowerboundType, 2285 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &upperbound, 2286 SmallVectorImpl<Type> &upperboundType, 2287 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &step, 2288 SmallVectorImpl<Type> &stepType) { 2289 2290 SmallVector<OpAsmParser::Argument> inductionVars; 2291 if (succeeded( 2292 parser.parseOptionalKeyword(acc::LoopOp::getControlKeyword()))) { 2293 if (parser.parseLParen() || 2294 parser.parseArgumentList(inductionVars, OpAsmParser::Delimiter::None, 2295 /*allowType=*/true) || 2296 parser.parseRParen() || parser.parseEqual() || parser.parseLParen() || 2297 parser.parseOperandList(lowerbound, inductionVars.size(), 2298 OpAsmParser::Delimiter::None) || 2299 parser.parseColonTypeList(lowerboundType) || parser.parseRParen() || 2300 parser.parseKeyword("to") || parser.parseLParen() || 2301 parser.parseOperandList(upperbound, inductionVars.size(), 2302 OpAsmParser::Delimiter::None) || 2303 parser.parseColonTypeList(upperboundType) || parser.parseRParen() || 2304 parser.parseKeyword("step") || parser.parseLParen() || 2305 parser.parseOperandList(step, inductionVars.size(), 2306 OpAsmParser::Delimiter::None) || 2307 parser.parseColonTypeList(stepType) || parser.parseRParen()) 2308 return failure(); 2309 } 2310 return parser.parseRegion(region, inductionVars); 2311 } 2312 2313 void printLoopControl(OpAsmPrinter &p, Operation *op, Region ®ion, 2314 ValueRange lowerbound, TypeRange lowerboundType, 2315 ValueRange upperbound, TypeRange upperboundType, 2316 ValueRange steps, TypeRange stepType) { 2317 ValueRange regionArgs = region.front().getArguments(); 2318 if (!regionArgs.empty()) { 2319 p << acc::LoopOp::getControlKeyword() << "("; 2320 llvm::interleaveComma(regionArgs, p, 2321 [&p](Value v) { p << v << " : " << v.getType(); }); 2322 p << ") = (" << lowerbound << " : " << lowerboundType << ") to (" 2323 << upperbound << " : " << upperboundType << ") " << " step (" << steps 2324 << " : " << stepType << ") "; 2325 } 2326 p.printRegion(region, /*printEntryBlockArgs=*/false); 2327 } 2328 2329 //===----------------------------------------------------------------------===// 2330 // DataOp 2331 //===----------------------------------------------------------------------===// 2332 2333 LogicalResult acc::DataOp::verify() { 2334 // 2.6.5. Data Construct restriction 2335 // At least one copy, copyin, copyout, create, no_create, present, deviceptr, 2336 // attach, or default clause must appear on a data construct. 2337 if (getOperands().empty() && !getDefaultAttr()) 2338 return emitError("at least one operand or the default attribute " 2339 "must appear on the data operation"); 2340 2341 for (mlir::Value operand : getDataClauseOperands()) 2342 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp, 2343 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp, 2344 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>( 2345 operand.getDefiningOp())) 2346 return emitError("expect data entry/exit operation or acc.getdeviceptr " 2347 "as defining op"); 2348 2349 if (failed(checkWaitAndAsyncConflict<acc::DataOp>(*this))) 2350 return failure(); 2351 2352 return success(); 2353 } 2354 2355 unsigned DataOp::getNumDataOperands() { return getDataClauseOperands().size(); } 2356 2357 Value DataOp::getDataOperand(unsigned i) { 2358 unsigned numOptional = getIfCond() ? 1 : 0; 2359 numOptional += getAsyncOperands().size() ? 1 : 0; 2360 numOptional += getWaitOperands().size(); 2361 return getOperand(numOptional + i); 2362 } 2363 2364 bool acc::DataOp::hasAsyncOnly() { 2365 return hasAsyncOnly(mlir::acc::DeviceType::None); 2366 } 2367 2368 bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) { 2369 return hasDeviceType(getAsyncOnly(), deviceType); 2370 } 2371 2372 mlir::Value DataOp::getAsyncValue() { 2373 return getAsyncValue(mlir::acc::DeviceType::None); 2374 } 2375 2376 mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) { 2377 return getValueInDeviceTypeSegment(getAsyncOperandsDeviceType(), 2378 getAsyncOperands(), deviceType); 2379 } 2380 2381 bool DataOp::hasWaitOnly() { return hasWaitOnly(mlir::acc::DeviceType::None); } 2382 2383 bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) { 2384 return hasDeviceType(getWaitOnly(), deviceType); 2385 } 2386 2387 mlir::Operation::operand_range DataOp::getWaitValues() { 2388 return getWaitValues(mlir::acc::DeviceType::None); 2389 } 2390 2391 mlir::Operation::operand_range 2392 DataOp::getWaitValues(mlir::acc::DeviceType deviceType) { 2393 return getWaitValuesWithoutDevnum( 2394 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(), 2395 getHasWaitDevnum(), deviceType); 2396 } 2397 2398 mlir::Value DataOp::getWaitDevnum() { 2399 return getWaitDevnum(mlir::acc::DeviceType::None); 2400 } 2401 2402 mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) { 2403 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(), 2404 getWaitOperandsSegments(), getHasWaitDevnum(), 2405 deviceType); 2406 } 2407 2408 //===----------------------------------------------------------------------===// 2409 // ExitDataOp 2410 //===----------------------------------------------------------------------===// 2411 2412 LogicalResult acc::ExitDataOp::verify() { 2413 // 2.6.6. Data Exit Directive restriction 2414 // At least one copyout, delete, or detach clause must appear on an exit data 2415 // directive. 2416 if (getDataClauseOperands().empty()) 2417 return emitError("at least one operand must be present in dataOperands on " 2418 "the exit data operation"); 2419 2420 // The async attribute represent the async clause without value. Therefore the 2421 // attribute and operand cannot appear at the same time. 2422 if (getAsyncOperand() && getAsync()) 2423 return emitError("async attribute cannot appear with asyncOperand"); 2424 2425 // The wait attribute represent the wait clause without values. Therefore the 2426 // attribute and operands cannot appear at the same time. 2427 if (!getWaitOperands().empty() && getWait()) 2428 return emitError("wait attribute cannot appear with waitOperands"); 2429 2430 if (getWaitDevnum() && getWaitOperands().empty()) 2431 return emitError("wait_devnum cannot appear without waitOperands"); 2432 2433 return success(); 2434 } 2435 2436 unsigned ExitDataOp::getNumDataOperands() { 2437 return getDataClauseOperands().size(); 2438 } 2439 2440 Value ExitDataOp::getDataOperand(unsigned i) { 2441 unsigned numOptional = getIfCond() ? 1 : 0; 2442 numOptional += getAsyncOperand() ? 1 : 0; 2443 numOptional += getWaitDevnum() ? 1 : 0; 2444 return getOperand(getWaitOperands().size() + numOptional + i); 2445 } 2446 2447 void ExitDataOp::getCanonicalizationPatterns(RewritePatternSet &results, 2448 MLIRContext *context) { 2449 results.add<RemoveConstantIfCondition<ExitDataOp>>(context); 2450 } 2451 2452 //===----------------------------------------------------------------------===// 2453 // EnterDataOp 2454 //===----------------------------------------------------------------------===// 2455 2456 LogicalResult acc::EnterDataOp::verify() { 2457 // 2.6.6. Data Enter Directive restriction 2458 // At least one copyin, create, or attach clause must appear on an enter data 2459 // directive. 2460 if (getDataClauseOperands().empty()) 2461 return emitError("at least one operand must be present in dataOperands on " 2462 "the enter data operation"); 2463 2464 // The async attribute represent the async clause without value. Therefore the 2465 // attribute and operand cannot appear at the same time. 2466 if (getAsyncOperand() && getAsync()) 2467 return emitError("async attribute cannot appear with asyncOperand"); 2468 2469 // The wait attribute represent the wait clause without values. Therefore the 2470 // attribute and operands cannot appear at the same time. 2471 if (!getWaitOperands().empty() && getWait()) 2472 return emitError("wait attribute cannot appear with waitOperands"); 2473 2474 if (getWaitDevnum() && getWaitOperands().empty()) 2475 return emitError("wait_devnum cannot appear without waitOperands"); 2476 2477 for (mlir::Value operand : getDataClauseOperands()) 2478 if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>( 2479 operand.getDefiningOp())) 2480 return emitError("expect data entry operation as defining op"); 2481 2482 return success(); 2483 } 2484 2485 unsigned EnterDataOp::getNumDataOperands() { 2486 return getDataClauseOperands().size(); 2487 } 2488 2489 Value EnterDataOp::getDataOperand(unsigned i) { 2490 unsigned numOptional = getIfCond() ? 1 : 0; 2491 numOptional += getAsyncOperand() ? 1 : 0; 2492 numOptional += getWaitDevnum() ? 1 : 0; 2493 return getOperand(getWaitOperands().size() + numOptional + i); 2494 } 2495 2496 void EnterDataOp::getCanonicalizationPatterns(RewritePatternSet &results, 2497 MLIRContext *context) { 2498 results.add<RemoveConstantIfCondition<EnterDataOp>>(context); 2499 } 2500 2501 //===----------------------------------------------------------------------===// 2502 // AtomicReadOp 2503 //===----------------------------------------------------------------------===// 2504 2505 LogicalResult AtomicReadOp::verify() { return verifyCommon(); } 2506 2507 //===----------------------------------------------------------------------===// 2508 // AtomicWriteOp 2509 //===----------------------------------------------------------------------===// 2510 2511 LogicalResult AtomicWriteOp::verify() { return verifyCommon(); } 2512 2513 //===----------------------------------------------------------------------===// 2514 // AtomicUpdateOp 2515 //===----------------------------------------------------------------------===// 2516 2517 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op, 2518 PatternRewriter &rewriter) { 2519 if (op.isNoOp()) { 2520 rewriter.eraseOp(op); 2521 return success(); 2522 } 2523 2524 if (Value writeVal = op.getWriteOpVal()) { 2525 rewriter.replaceOpWithNewOp<AtomicWriteOp>(op, op.getX(), writeVal); 2526 return success(); 2527 } 2528 2529 return failure(); 2530 } 2531 2532 LogicalResult AtomicUpdateOp::verify() { return verifyCommon(); } 2533 2534 LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); } 2535 2536 //===----------------------------------------------------------------------===// 2537 // AtomicCaptureOp 2538 //===----------------------------------------------------------------------===// 2539 2540 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() { 2541 if (auto op = dyn_cast<AtomicReadOp>(getFirstOp())) 2542 return op; 2543 return dyn_cast<AtomicReadOp>(getSecondOp()); 2544 } 2545 2546 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() { 2547 if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp())) 2548 return op; 2549 return dyn_cast<AtomicWriteOp>(getSecondOp()); 2550 } 2551 2552 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() { 2553 if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp())) 2554 return op; 2555 return dyn_cast<AtomicUpdateOp>(getSecondOp()); 2556 } 2557 2558 LogicalResult AtomicCaptureOp::verifyRegions() { return verifyRegionsCommon(); } 2559 2560 //===----------------------------------------------------------------------===// 2561 // DeclareEnterOp 2562 //===----------------------------------------------------------------------===// 2563 2564 template <typename Op> 2565 static LogicalResult 2566 checkDeclareOperands(Op &op, const mlir::ValueRange &operands, 2567 bool requireAtLeastOneOperand = true) { 2568 if (operands.empty() && requireAtLeastOneOperand) 2569 return emitError( 2570 op->getLoc(), 2571 "at least one operand must appear on the declare operation"); 2572 2573 for (mlir::Value operand : operands) { 2574 if (!mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp, 2575 acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp, 2576 acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>( 2577 operand.getDefiningOp())) 2578 return op.emitError( 2579 "expect valid declare data entry operation or acc.getdeviceptr " 2580 "as defining op"); 2581 2582 mlir::Value varPtr{getVarPtr(operand.getDefiningOp())}; 2583 assert(varPtr && "declare operands can only be data entry operations which " 2584 "must have varPtr"); 2585 std::optional<mlir::acc::DataClause> dataClauseOptional{ 2586 getDataClause(operand.getDefiningOp())}; 2587 assert(dataClauseOptional.has_value() && 2588 "declare operands can only be data entry operations which must have " 2589 "dataClause"); 2590 2591 // If varPtr has no defining op - there is nothing to check further. 2592 if (!varPtr.getDefiningOp()) 2593 continue; 2594 2595 // Check that the varPtr has a declare attribute. 2596 auto declareAttribute{ 2597 varPtr.getDefiningOp()->getAttr(mlir::acc::getDeclareAttrName())}; 2598 if (!declareAttribute) 2599 return op.emitError( 2600 "expect declare attribute on variable in declare operation"); 2601 2602 auto declAttr = mlir::cast<mlir::acc::DeclareAttr>(declareAttribute); 2603 if (declAttr.getDataClause().getValue() != dataClauseOptional.value()) 2604 return op.emitError( 2605 "expect matching declare attribute on variable in declare operation"); 2606 2607 // If the variable is marked with implicit attribute, the matching declare 2608 // data action must also be marked implicit. The reverse is not checked 2609 // since implicit data action may be inserted to do actions like updating 2610 // device copy, in which case the variable is not necessarily implicitly 2611 // declare'd. 2612 if (declAttr.getImplicit() && 2613 declAttr.getImplicit() != acc::getImplicitFlag(operand.getDefiningOp())) 2614 return op.emitError( 2615 "implicitness must match between declare op and flag on variable"); 2616 } 2617 2618 return success(); 2619 } 2620 2621 LogicalResult acc::DeclareEnterOp::verify() { 2622 return checkDeclareOperands(*this, this->getDataClauseOperands()); 2623 } 2624 2625 //===----------------------------------------------------------------------===// 2626 // DeclareExitOp 2627 //===----------------------------------------------------------------------===// 2628 2629 LogicalResult acc::DeclareExitOp::verify() { 2630 if (getToken()) 2631 return checkDeclareOperands(*this, this->getDataClauseOperands(), 2632 /*requireAtLeastOneOperand=*/false); 2633 return checkDeclareOperands(*this, this->getDataClauseOperands()); 2634 } 2635 2636 //===----------------------------------------------------------------------===// 2637 // DeclareOp 2638 //===----------------------------------------------------------------------===// 2639 2640 LogicalResult acc::DeclareOp::verify() { 2641 return checkDeclareOperands(*this, this->getDataClauseOperands()); 2642 } 2643 2644 //===----------------------------------------------------------------------===// 2645 // RoutineOp 2646 //===----------------------------------------------------------------------===// 2647 2648 static unsigned getParallelismForDeviceType(acc::RoutineOp op, 2649 acc::DeviceType dtype) { 2650 unsigned parallelism = 0; 2651 parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0; 2652 parallelism += op.hasWorker(dtype) ? 1 : 0; 2653 parallelism += op.hasVector(dtype) ? 1 : 0; 2654 parallelism += op.hasSeq(dtype) ? 1 : 0; 2655 return parallelism; 2656 } 2657 2658 LogicalResult acc::RoutineOp::verify() { 2659 unsigned baseParallelism = 2660 getParallelismForDeviceType(*this, acc::DeviceType::None); 2661 2662 if (baseParallelism > 1) 2663 return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can " 2664 "be present at the same time"; 2665 2666 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType(); 2667 ++dtypeInt) { 2668 auto dtype = static_cast<acc::DeviceType>(dtypeInt); 2669 if (dtype == acc::DeviceType::None) 2670 continue; 2671 unsigned parallelism = getParallelismForDeviceType(*this, dtype); 2672 2673 if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1)) 2674 return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can " 2675 "be present at the same time"; 2676 } 2677 2678 return success(); 2679 } 2680 2681 static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindName, 2682 mlir::ArrayAttr &deviceTypes) { 2683 llvm::SmallVector<mlir::Attribute> bindNameAttrs; 2684 llvm::SmallVector<mlir::Attribute> deviceTypeAttrs; 2685 2686 if (failed(parser.parseCommaSeparatedList([&]() { 2687 if (parser.parseAttribute(bindNameAttrs.emplace_back())) 2688 return failure(); 2689 if (failed(parser.parseOptionalLSquare())) { 2690 deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get( 2691 parser.getContext(), mlir::acc::DeviceType::None)); 2692 } else { 2693 if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) || 2694 parser.parseRSquare()) 2695 return failure(); 2696 } 2697 return success(); 2698 }))) 2699 return failure(); 2700 2701 bindName = ArrayAttr::get(parser.getContext(), bindNameAttrs); 2702 deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs); 2703 2704 return success(); 2705 } 2706 2707 static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, 2708 std::optional<mlir::ArrayAttr> bindName, 2709 std::optional<mlir::ArrayAttr> deviceTypes) { 2710 llvm::interleaveComma(llvm::zip(*bindName, *deviceTypes), p, 2711 [&](const auto &pair) { 2712 p << std::get<0>(pair); 2713 printSingleDeviceType(p, std::get<1>(pair)); 2714 }); 2715 } 2716 2717 static ParseResult parseRoutineGangClause(OpAsmParser &parser, 2718 mlir::ArrayAttr &gang, 2719 mlir::ArrayAttr &gangDim, 2720 mlir::ArrayAttr &gangDimDeviceTypes) { 2721 2722 llvm::SmallVector<mlir::Attribute> gangAttrs, gangDimAttrs, 2723 gangDimDeviceTypeAttrs; 2724 bool needCommaBeforeOperands = false; 2725 2726 // Gang keyword only 2727 if (failed(parser.parseOptionalLParen())) { 2728 gangAttrs.push_back(mlir::acc::DeviceTypeAttr::get( 2729 parser.getContext(), mlir::acc::DeviceType::None)); 2730 gang = ArrayAttr::get(parser.getContext(), gangAttrs); 2731 return success(); 2732 } 2733 2734 // Parse keyword only attributes 2735 if (succeeded(parser.parseOptionalLSquare())) { 2736 if (failed(parser.parseCommaSeparatedList([&]() { 2737 if (parser.parseAttribute(gangAttrs.emplace_back())) 2738 return failure(); 2739 return success(); 2740 }))) 2741 return failure(); 2742 if (parser.parseRSquare()) 2743 return failure(); 2744 needCommaBeforeOperands = true; 2745 } 2746 2747 if (needCommaBeforeOperands && failed(parser.parseComma())) 2748 return failure(); 2749 2750 if (failed(parser.parseCommaSeparatedList([&]() { 2751 if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) || 2752 parser.parseColon() || 2753 parser.parseAttribute(gangDimAttrs.emplace_back())) 2754 return failure(); 2755 if (succeeded(parser.parseOptionalLSquare())) { 2756 if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) || 2757 parser.parseRSquare()) 2758 return failure(); 2759 } else { 2760 gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get( 2761 parser.getContext(), mlir::acc::DeviceType::None)); 2762 } 2763 return success(); 2764 }))) 2765 return failure(); 2766 2767 if (failed(parser.parseRParen())) 2768 return failure(); 2769 2770 gang = ArrayAttr::get(parser.getContext(), gangAttrs); 2771 gangDim = ArrayAttr::get(parser.getContext(), gangDimAttrs); 2772 gangDimDeviceTypes = 2773 ArrayAttr::get(parser.getContext(), gangDimDeviceTypeAttrs); 2774 2775 return success(); 2776 } 2777 2778 void printRoutineGangClause(OpAsmPrinter &p, Operation *op, 2779 std::optional<mlir::ArrayAttr> gang, 2780 std::optional<mlir::ArrayAttr> gangDim, 2781 std::optional<mlir::ArrayAttr> gangDimDeviceTypes) { 2782 2783 if (!hasDeviceTypeValues(gangDimDeviceTypes) && hasDeviceTypeValues(gang) && 2784 gang->size() == 1) { 2785 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]); 2786 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None) 2787 return; 2788 } 2789 2790 p << "("; 2791 2792 printDeviceTypes(p, gang); 2793 2794 if (hasDeviceTypeValues(gang) && hasDeviceTypeValues(gangDimDeviceTypes)) 2795 p << ", "; 2796 2797 if (hasDeviceTypeValues(gangDimDeviceTypes)) 2798 llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p, 2799 [&](const auto &pair) { 2800 p << acc::RoutineOp::getGangDimKeyword() << ": "; 2801 p << std::get<0>(pair); 2802 printSingleDeviceType(p, std::get<1>(pair)); 2803 }); 2804 2805 p << ")"; 2806 } 2807 2808 static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, 2809 mlir::ArrayAttr &deviceTypes) { 2810 llvm::SmallVector<mlir::Attribute> attributes; 2811 // Keyword only 2812 if (failed(parser.parseOptionalLParen())) { 2813 attributes.push_back(mlir::acc::DeviceTypeAttr::get( 2814 parser.getContext(), mlir::acc::DeviceType::None)); 2815 deviceTypes = ArrayAttr::get(parser.getContext(), attributes); 2816 return success(); 2817 } 2818 2819 // Parse device type attributes 2820 if (succeeded(parser.parseOptionalLSquare())) { 2821 if (failed(parser.parseCommaSeparatedList([&]() { 2822 if (parser.parseAttribute(attributes.emplace_back())) 2823 return failure(); 2824 return success(); 2825 }))) 2826 return failure(); 2827 if (parser.parseRSquare() || parser.parseRParen()) 2828 return failure(); 2829 } 2830 deviceTypes = ArrayAttr::get(parser.getContext(), attributes); 2831 return success(); 2832 } 2833 2834 static void 2835 printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, 2836 std::optional<mlir::ArrayAttr> deviceTypes) { 2837 2838 if (hasDeviceTypeValues(deviceTypes) && deviceTypes->size() == 1) { 2839 auto deviceTypeAttr = 2840 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]); 2841 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None) 2842 return; 2843 } 2844 2845 if (!hasDeviceTypeValues(deviceTypes)) 2846 return; 2847 2848 p << "(["; 2849 llvm::interleaveComma(*deviceTypes, p, [&](mlir::Attribute attr) { 2850 auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr); 2851 p << dTypeAttr; 2852 }); 2853 p << "])"; 2854 } 2855 2856 bool RoutineOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); } 2857 2858 bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) { 2859 return hasDeviceType(getWorker(), deviceType); 2860 } 2861 2862 bool RoutineOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); } 2863 2864 bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) { 2865 return hasDeviceType(getVector(), deviceType); 2866 } 2867 2868 bool RoutineOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); } 2869 2870 bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) { 2871 return hasDeviceType(getSeq(), deviceType); 2872 } 2873 2874 std::optional<llvm::StringRef> RoutineOp::getBindNameValue() { 2875 return getBindNameValue(mlir::acc::DeviceType::None); 2876 } 2877 2878 std::optional<llvm::StringRef> 2879 RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) { 2880 if (!hasDeviceTypeValues(getBindNameDeviceType())) 2881 return std::nullopt; 2882 if (auto pos = findSegment(*getBindNameDeviceType(), deviceType)) { 2883 auto attr = (*getBindName())[*pos]; 2884 auto stringAttr = dyn_cast<mlir::StringAttr>(attr); 2885 return stringAttr.getValue(); 2886 } 2887 return std::nullopt; 2888 } 2889 2890 bool RoutineOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); } 2891 2892 bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) { 2893 return hasDeviceType(getGang(), deviceType); 2894 } 2895 2896 std::optional<int64_t> RoutineOp::getGangDimValue() { 2897 return getGangDimValue(mlir::acc::DeviceType::None); 2898 } 2899 2900 std::optional<int64_t> 2901 RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) { 2902 if (!hasDeviceTypeValues(getGangDimDeviceType())) 2903 return std::nullopt; 2904 if (auto pos = findSegment(*getGangDimDeviceType(), deviceType)) { 2905 auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]); 2906 return intAttr.getInt(); 2907 } 2908 return std::nullopt; 2909 } 2910 2911 //===----------------------------------------------------------------------===// 2912 // InitOp 2913 //===----------------------------------------------------------------------===// 2914 2915 LogicalResult acc::InitOp::verify() { 2916 Operation *currOp = *this; 2917 while ((currOp = currOp->getParentOp())) 2918 if (isComputeOperation(currOp)) 2919 return emitOpError("cannot be nested in a compute operation"); 2920 return success(); 2921 } 2922 2923 //===----------------------------------------------------------------------===// 2924 // ShutdownOp 2925 //===----------------------------------------------------------------------===// 2926 2927 LogicalResult acc::ShutdownOp::verify() { 2928 Operation *currOp = *this; 2929 while ((currOp = currOp->getParentOp())) 2930 if (isComputeOperation(currOp)) 2931 return emitOpError("cannot be nested in a compute operation"); 2932 return success(); 2933 } 2934 2935 //===----------------------------------------------------------------------===// 2936 // SetOp 2937 //===----------------------------------------------------------------------===// 2938 2939 LogicalResult acc::SetOp::verify() { 2940 Operation *currOp = *this; 2941 while ((currOp = currOp->getParentOp())) 2942 if (isComputeOperation(currOp)) 2943 return emitOpError("cannot be nested in a compute operation"); 2944 if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum()) 2945 return emitOpError("at least one default_async, device_num, or device_type " 2946 "operand must appear"); 2947 return success(); 2948 } 2949 2950 //===----------------------------------------------------------------------===// 2951 // UpdateOp 2952 //===----------------------------------------------------------------------===// 2953 2954 LogicalResult acc::UpdateOp::verify() { 2955 // At least one of host or device should have a value. 2956 if (getDataClauseOperands().empty()) 2957 return emitError("at least one value must be present in dataOperands"); 2958 2959 if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(), 2960 getAsyncOperandsDeviceTypeAttr(), 2961 "async"))) 2962 return failure(); 2963 2964 if (failed(verifyDeviceTypeAndSegmentCountMatch( 2965 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(), 2966 getWaitOperandsDeviceTypeAttr(), "wait"))) 2967 return failure(); 2968 2969 if (failed(checkWaitAndAsyncConflict<acc::UpdateOp>(*this))) 2970 return failure(); 2971 2972 for (mlir::Value operand : getDataClauseOperands()) 2973 if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>( 2974 operand.getDefiningOp())) 2975 return emitError("expect data entry/exit operation or acc.getdeviceptr " 2976 "as defining op"); 2977 2978 return success(); 2979 } 2980 2981 unsigned UpdateOp::getNumDataOperands() { 2982 return getDataClauseOperands().size(); 2983 } 2984 2985 Value UpdateOp::getDataOperand(unsigned i) { 2986 unsigned numOptional = getAsyncOperands().size(); 2987 numOptional += getIfCond() ? 1 : 0; 2988 return getOperand(getWaitOperands().size() + numOptional + i); 2989 } 2990 2991 void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results, 2992 MLIRContext *context) { 2993 results.add<RemoveConstantIfCondition<UpdateOp>>(context); 2994 } 2995 2996 bool UpdateOp::hasAsyncOnly() { 2997 return hasAsyncOnly(mlir::acc::DeviceType::None); 2998 } 2999 3000 bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) { 3001 return hasDeviceType(getAsync(), deviceType); 3002 } 3003 3004 mlir::Value UpdateOp::getAsyncValue() { 3005 return getAsyncValue(mlir::acc::DeviceType::None); 3006 } 3007 3008 mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) { 3009 if (!hasDeviceTypeValues(getAsyncOperandsDeviceType())) 3010 return {}; 3011 3012 if (auto pos = findSegment(*getAsyncOperandsDeviceType(), deviceType)) 3013 return getAsyncOperands()[*pos]; 3014 3015 return {}; 3016 } 3017 3018 bool UpdateOp::hasWaitOnly() { 3019 return hasWaitOnly(mlir::acc::DeviceType::None); 3020 } 3021 3022 bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) { 3023 return hasDeviceType(getWaitOnly(), deviceType); 3024 } 3025 3026 mlir::Operation::operand_range UpdateOp::getWaitValues() { 3027 return getWaitValues(mlir::acc::DeviceType::None); 3028 } 3029 3030 mlir::Operation::operand_range 3031 UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) { 3032 return getWaitValuesWithoutDevnum( 3033 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(), 3034 getHasWaitDevnum(), deviceType); 3035 } 3036 3037 mlir::Value UpdateOp::getWaitDevnum() { 3038 return getWaitDevnum(mlir::acc::DeviceType::None); 3039 } 3040 3041 mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) { 3042 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(), 3043 getWaitOperandsSegments(), getHasWaitDevnum(), 3044 deviceType); 3045 } 3046 3047 //===----------------------------------------------------------------------===// 3048 // WaitOp 3049 //===----------------------------------------------------------------------===// 3050 3051 LogicalResult acc::WaitOp::verify() { 3052 // The async attribute represent the async clause without value. Therefore the 3053 // attribute and operand cannot appear at the same time. 3054 if (getAsyncOperand() && getAsync()) 3055 return emitError("async attribute cannot appear with asyncOperand"); 3056 3057 if (getWaitDevnum() && getWaitOperands().empty()) 3058 return emitError("wait_devnum cannot appear without waitOperands"); 3059 3060 return success(); 3061 } 3062 3063 #define GET_OP_CLASSES 3064 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc" 3065 3066 #define GET_ATTRDEF_CLASSES 3067 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc" 3068 3069 #define GET_TYPEDEF_CLASSES 3070 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc" 3071 3072 //===----------------------------------------------------------------------===// 3073 // acc dialect utilities 3074 //===----------------------------------------------------------------------===// 3075 3076 mlir::TypedValue<mlir::acc::PointerLikeType> 3077 mlir::acc::getVarPtr(mlir::Operation *accDataClauseOp) { 3078 auto varPtr{llvm::TypeSwitch<mlir::Operation *, 3079 mlir::TypedValue<mlir::acc::PointerLikeType>>( 3080 accDataClauseOp) 3081 .Case<ACC_DATA_ENTRY_OPS>( 3082 [&](auto entry) { return entry.getVarPtr(); }) 3083 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>( 3084 [&](auto exit) { return exit.getVarPtr(); }) 3085 .Default([&](mlir::Operation *) { 3086 return mlir::TypedValue<mlir::acc::PointerLikeType>(); 3087 })}; 3088 return varPtr; 3089 } 3090 3091 mlir::Value mlir::acc::getVar(mlir::Operation *accDataClauseOp) { 3092 auto varPtr{ 3093 llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp) 3094 .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getVar(); }) 3095 .Default([&](mlir::Operation *) { return mlir::Value(); })}; 3096 return varPtr; 3097 } 3098 3099 mlir::Type mlir::acc::getVarType(mlir::Operation *accDataClauseOp) { 3100 auto varType{llvm::TypeSwitch<mlir::Operation *, mlir::Type>(accDataClauseOp) 3101 .Case<ACC_DATA_ENTRY_OPS>( 3102 [&](auto entry) { return entry.getVarType(); }) 3103 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>( 3104 [&](auto exit) { return exit.getVarType(); }) 3105 .Default([&](mlir::Operation *) { return mlir::Type(); })}; 3106 return varType; 3107 } 3108 3109 mlir::TypedValue<mlir::acc::PointerLikeType> 3110 mlir::acc::getAccPtr(mlir::Operation *accDataClauseOp) { 3111 auto accPtr{llvm::TypeSwitch<mlir::Operation *, 3112 mlir::TypedValue<mlir::acc::PointerLikeType>>( 3113 accDataClauseOp) 3114 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>( 3115 [&](auto dataClause) { return dataClause.getAccPtr(); }) 3116 .Default([&](mlir::Operation *) { 3117 return mlir::TypedValue<mlir::acc::PointerLikeType>(); 3118 })}; 3119 return accPtr; 3120 } 3121 3122 mlir::Value mlir::acc::getAccVar(mlir::Operation *accDataClauseOp) { 3123 auto accPtr{llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp) 3124 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>( 3125 [&](auto dataClause) { return dataClause.getAccVar(); }) 3126 .Default([&](mlir::Operation *) { return mlir::Value(); })}; 3127 return accPtr; 3128 } 3129 3130 mlir::Value mlir::acc::getVarPtrPtr(mlir::Operation *accDataClauseOp) { 3131 auto varPtrPtr{ 3132 llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp) 3133 .Case<ACC_DATA_ENTRY_OPS>( 3134 [&](auto dataClause) { return dataClause.getVarPtrPtr(); }) 3135 .Default([&](mlir::Operation *) { return mlir::Value(); })}; 3136 return varPtrPtr; 3137 } 3138 3139 mlir::SmallVector<mlir::Value> 3140 mlir::acc::getBounds(mlir::Operation *accDataClauseOp) { 3141 mlir::SmallVector<mlir::Value> bounds{ 3142 llvm::TypeSwitch<mlir::Operation *, mlir::SmallVector<mlir::Value>>( 3143 accDataClauseOp) 3144 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) { 3145 return mlir::SmallVector<mlir::Value>( 3146 dataClause.getBounds().begin(), dataClause.getBounds().end()); 3147 }) 3148 .Default([&](mlir::Operation *) { 3149 return mlir::SmallVector<mlir::Value, 0>(); 3150 })}; 3151 return bounds; 3152 } 3153 3154 mlir::SmallVector<mlir::Value> 3155 mlir::acc::getAsyncOperands(mlir::Operation *accDataClauseOp) { 3156 return llvm::TypeSwitch<mlir::Operation *, mlir::SmallVector<mlir::Value>>( 3157 accDataClauseOp) 3158 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) { 3159 return mlir::SmallVector<mlir::Value>( 3160 dataClause.getAsyncOperands().begin(), 3161 dataClause.getAsyncOperands().end()); 3162 }) 3163 .Default([&](mlir::Operation *) { 3164 return mlir::SmallVector<mlir::Value, 0>(); 3165 }); 3166 } 3167 3168 mlir::ArrayAttr 3169 mlir::acc::getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp) { 3170 return llvm::TypeSwitch<mlir::Operation *, mlir::ArrayAttr>(accDataClauseOp) 3171 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) { 3172 return dataClause.getAsyncOperandsDeviceTypeAttr(); 3173 }) 3174 .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; }); 3175 } 3176 3177 mlir::ArrayAttr mlir::acc::getAsyncOnly(mlir::Operation *accDataClauseOp) { 3178 return llvm::TypeSwitch<mlir::Operation *, mlir::ArrayAttr>(accDataClauseOp) 3179 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>( 3180 [&](auto dataClause) { return dataClause.getAsyncOnlyAttr(); }) 3181 .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; }); 3182 } 3183 3184 std::optional<llvm::StringRef> mlir::acc::getVarName(mlir::Operation *accOp) { 3185 auto name{ 3186 llvm::TypeSwitch<mlir::Operation *, std::optional<llvm::StringRef>>(accOp) 3187 .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getName(); }) 3188 .Default([&](mlir::Operation *) -> std::optional<llvm::StringRef> { 3189 return {}; 3190 })}; 3191 return name; 3192 } 3193 3194 std::optional<mlir::acc::DataClause> 3195 mlir::acc::getDataClause(mlir::Operation *accDataEntryOp) { 3196 auto dataClause{ 3197 llvm::TypeSwitch<mlir::Operation *, std::optional<mlir::acc::DataClause>>( 3198 accDataEntryOp) 3199 .Case<ACC_DATA_ENTRY_OPS>( 3200 [&](auto entry) { return entry.getDataClause(); }) 3201 .Default([&](mlir::Operation *) { return std::nullopt; })}; 3202 return dataClause; 3203 } 3204 3205 bool mlir::acc::getImplicitFlag(mlir::Operation *accDataEntryOp) { 3206 auto implicit{llvm::TypeSwitch<mlir::Operation *, bool>(accDataEntryOp) 3207 .Case<ACC_DATA_ENTRY_OPS>( 3208 [&](auto entry) { return entry.getImplicit(); }) 3209 .Default([&](mlir::Operation *) { return false; })}; 3210 return implicit; 3211 } 3212 3213 mlir::ValueRange mlir::acc::getDataOperands(mlir::Operation *accOp) { 3214 auto dataOperands{ 3215 llvm::TypeSwitch<mlir::Operation *, mlir::ValueRange>(accOp) 3216 .Case<ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS>( 3217 [&](auto entry) { return entry.getDataClauseOperands(); }) 3218 .Default([&](mlir::Operation *) { return mlir::ValueRange(); })}; 3219 return dataOperands; 3220 } 3221 3222 mlir::MutableOperandRange 3223 mlir::acc::getMutableDataOperands(mlir::Operation *accOp) { 3224 auto dataOperands{ 3225 llvm::TypeSwitch<mlir::Operation *, mlir::MutableOperandRange>(accOp) 3226 .Case<ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS>( 3227 [&](auto entry) { return entry.getDataClauseOperandsMutable(); }) 3228 .Default([&](mlir::Operation *) { return nullptr; })}; 3229 return dataOperands; 3230 } 3231