//===-- NVPTXISelDAGToDAG.cpp - A dag to dag inst selector for NVPTX ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines an instruction selector for the NVPTX target.
//
//===----------------------------------------------------------------------===//

#include "NVPTXISelDAGToDAG.h"
#include "NVPTX.h"
#include "NVPTXUtilities.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/ISDOpcodes.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/NVVMIntrinsicUtils.h"
#include "llvm/Support/AtomicOrdering.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Target/TargetIntrinsicInfo.h"

using namespace llvm;

#define DEBUG_TYPE "nvptx-isel"
#define PASS_NAME "NVPTX DAG->DAG Pattern Instruction Selection"

static cl::opt<bool>
    EnableRsqrtOpt("nvptx-rsqrt-approx-opt", cl::init(true), cl::Hidden,
                   cl::desc("Enable reciprocal sqrt optimization"));

/// createNVPTXISelDag - This pass converts a legalized DAG into a
/// NVPTX-specific DAG, ready for instruction scheduling.
FunctionPass *llvm::createNVPTXISelDag(NVPTXTargetMachine &TM,
                                       llvm::CodeGenOptLevel OptLevel) {
  return new NVPTXDAGToDAGISelLegacy(TM, OptLevel);
}

NVPTXDAGToDAGISelLegacy::NVPTXDAGToDAGISelLegacy(NVPTXTargetMachine &tm,
                                                 CodeGenOptLevel OptLevel)
    : SelectionDAGISelLegacy(
          ID, std::make_unique<NVPTXDAGToDAGISel>(tm, OptLevel)) {}

char NVPTXDAGToDAGISelLegacy::ID = 0;

INITIALIZE_PASS(NVPTXDAGToDAGISelLegacy, DEBUG_TYPE, PASS_NAME, false, false)

NVPTXDAGToDAGISel::NVPTXDAGToDAGISel(NVPTXTargetMachine &tm,
                                     CodeGenOptLevel OptLevel)
    : SelectionDAGISel(tm, OptLevel), TM(tm) {
  doMulWide = (OptLevel > CodeGenOptLevel::None);
}

bool NVPTXDAGToDAGISel::runOnMachineFunction(MachineFunction &MF) {
  Subtarget = &MF.getSubtarget<NVPTXSubtarget>();
  Scopes = NVPTXScopes(MF.getFunction().getContext());
  return SelectionDAGISel::runOnMachineFunction(MF);
}

int NVPTXDAGToDAGISel::getDivF32Level() const {
  return Subtarget->getTargetLowering()->getDivF32Level();
}

bool NVPTXDAGToDAGISel::usePrecSqrtF32() const {
  return Subtarget->getTargetLowering()->usePrecSqrtF32();
}

bool NVPTXDAGToDAGISel::useF32FTZ() const {
  return Subtarget->getTargetLowering()->useF32FTZ(*MF);
}

bool NVPTXDAGToDAGISel::allowFMA() const {
  const NVPTXTargetLowering *TL = Subtarget->getTargetLowering();
  return TL->allowFMA(*MF, OptLevel);
}

bool NVPTXDAGToDAGISel::allowUnsafeFPMath() const {
  const NVPTXTargetLowering *TL = Subtarget->getTargetLowering();
  return TL->allowUnsafeFPMath(*MF);
}

bool NVPTXDAGToDAGISel::doRsqrtOpt() const { return EnableRsqrtOpt; }

/// Select - Select instructions not customized! Used for
/// expanded, promoted and normal instructions.
void NVPTXDAGToDAGISel::Select(SDNode *N) {

  if (N->isMachineOpcode()) {
    N->setNodeId(-1);
    return; // Already selected.
  }

  switch (N->getOpcode()) {
  case ISD::LOAD:
  case ISD::ATOMIC_LOAD:
    if (tryLoad(N))
      return;
    break;
  case ISD::STORE:
  case ISD::ATOMIC_STORE:
    if (tryStore(N))
      return;
    break;
  case ISD::ATOMIC_FENCE:
    if (tryFence(N))
      return;
    break;
  case ISD::EXTRACT_VECTOR_ELT:
    if (tryEXTRACT_VECTOR_ELEMENT(N))
      return;
    break;
  case NVPTXISD::SETP_F16X2:
    SelectSETP_F16X2(N);
    return;
  case NVPTXISD::SETP_BF16X2:
    SelectSETP_BF16X2(N);
    return;
  case NVPTXISD::LoadV2:
  case NVPTXISD::LoadV4:
    if (tryLoadVector(N))
      return;
    break;
  case NVPTXISD::LDUV2:
  case NVPTXISD::LDUV4:
    if (tryLDGLDU(N))
      return;
    break;
  case NVPTXISD::StoreV2:
  case NVPTXISD::StoreV4:
    if (tryStoreVector(N))
      return;
    break;
  case NVPTXISD::LoadParam:
  case NVPTXISD::LoadParamV2:
  case NVPTXISD::LoadParamV4:
    if (tryLoadParam(N))
      return;
    break;
  case NVPTXISD::StoreRetval:
  case NVPTXISD::StoreRetvalV2:
  case NVPTXISD::StoreRetvalV4:
    if (tryStoreRetval(N))
      return;
    break;
  case NVPTXISD::StoreParam:
  case NVPTXISD::StoreParamV2:
  case NVPTXISD::StoreParamV4:
  case NVPTXISD::StoreParamS32:
  case NVPTXISD::StoreParamU32:
    if (tryStoreParam(N))
      return;
    break;
  case ISD::INTRINSIC_WO_CHAIN:
    if (tryIntrinsicNoChain(N))
      return;
    break;
  case ISD::INTRINSIC_W_CHAIN:
    if (tryIntrinsicChain(N))
      return;
    break;
  case ISD::INTRINSIC_VOID:
    if (tryIntrinsicVoid(N))
      return;
    break;
  case ISD::AND:
  case ISD::SRA:
  case ISD::SRL:
    // Try to select BFE
    if (tryBFE(N))
      return;
    break;
  case ISD::ADDRSPACECAST:
    SelectAddrSpaceCast(N);
    return;
  case ISD::CopyToReg: {
    if (N->getOperand(1).getValueType() == MVT::i128) {
      SelectV2I64toI128(N);
      return;
    }
    break;
  }
  case ISD::CopyFromReg: {
    if (N->getOperand(1).getValueType() == MVT::i128) {
      SelectI128toV2I64(N);
      return;
    }
    break;
  }
  case ISD::FADD:
  case ISD::FMUL:
  case ISD::FSUB:
    if (tryBF16ArithToFMA(N))
      return;
    break;
  default:
    break;
  }
  SelectCode(N);
}

bool NVPTXDAGToDAGISel::tryIntrinsicChain(SDNode *N) {
  unsigned IID = N->getConstantOperandVal(1);
  switch (IID) {
  default:
    return false;
  case Intrinsic::nvvm_ldu_global_f:
  case Intrinsic::nvvm_ldu_global_i:
  case Intrinsic::nvvm_ldu_global_p:
    return tryLDGLDU(N);
  }
}

// Map ISD::CONDCODE value to appropriate CmpMode expected by
// NVPTXInstPrinter::printCmpMode()
static unsigned getPTXCmpMode(const CondCodeSDNode &CondCode, bool FTZ) {
  using NVPTX::PTXCmpMode::CmpMode;
  unsigned PTXCmpMode = [](ISD::CondCode CC) {
    switch (CC) {
    default:
      llvm_unreachable("Unexpected condition code.");
    case ISD::SETOEQ:
      return CmpMode::EQ;
    case ISD::SETOGT:
      return CmpMode::GT;
    case ISD::SETOGE:
      return CmpMode::GE;
    case ISD::SETOLT:
      return CmpMode::LT;
    case ISD::SETOLE:
      return CmpMode::LE;
    case ISD::SETONE:
      return CmpMode::NE;
    case ISD::SETO:
      return CmpMode::NUM;
    case ISD::SETUO:
      return CmpMode::NotANumber;
    case ISD::SETUEQ:
      return CmpMode::EQU;
    case ISD::SETUGT:
      return CmpMode::GTU;
    case ISD::SETUGE:
      return CmpMode::GEU;
    case ISD::SETULT:
      return CmpMode::LTU;
    case ISD::SETULE:
      return CmpMode::LEU;
    case ISD::SETUNE:
      return CmpMode::NEU;
    case ISD::SETEQ:
      return CmpMode::EQ;
    case ISD::SETGT:
      return CmpMode::GT;
    case ISD::SETGE:
      return CmpMode::GE;
    case ISD::SETLT:
      return CmpMode::LT;
    case ISD::SETLE:
      return CmpMode::LE;
    case ISD::SETNE:
      return CmpMode::NE;
    }
  }(CondCode.get());

  if (FTZ)
    PTXCmpMode |= NVPTX::PTXCmpMode::FTZ_FLAG;

  return PTXCmpMode;
}

bool NVPTXDAGToDAGISel::SelectSETP_F16X2(SDNode *N) {
  unsigned PTXCmpMode =
      getPTXCmpMode(*cast<CondCodeSDNode>(N->getOperand(2)), useF32FTZ());
  SDLoc DL(N);
  SDNode *SetP = CurDAG->getMachineNode(
      NVPTX::SETP_f16x2rr, DL, MVT::i1, MVT::i1, N->getOperand(0),
      N->getOperand(1), CurDAG->getTargetConstant(PTXCmpMode, DL, MVT::i32));
  ReplaceNode(N, SetP);
  return true;
}

bool NVPTXDAGToDAGISel::SelectSETP_BF16X2(SDNode *N) {
  unsigned PTXCmpMode =
      getPTXCmpMode(*cast<CondCodeSDNode>(N->getOperand(2)), useF32FTZ());
  SDLoc DL(N);
  SDNode *SetP = CurDAG->getMachineNode(
      NVPTX::SETP_bf16x2rr, DL, MVT::i1, MVT::i1, N->getOperand(0),
      N->getOperand(1), CurDAG->getTargetConstant(PTXCmpMode, DL, MVT::i32));
  ReplaceNode(N, SetP);
  return true;
}

// Find all instances of extract_vector_elt that use this v2f16 vector
// and coalesce them into a scattering move instruction.
bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT(SDNode *N) {
  SDValue Vector = N->getOperand(0);

  // We only care about 16x2 as it's the only real vector type we
  // need to deal with.
  MVT VT = Vector.getSimpleValueType();
  if (!Isv2x16VT(VT))
    return false;
  // Find and record all uses of this vector that extract element 0 or 1.
  SmallVector<SDNode *, 4> E0, E1;
  for (auto *U : Vector.getNode()->users()) {
    if (U->getOpcode() != ISD::EXTRACT_VECTOR_ELT)
      continue;
    if (U->getOperand(0) != Vector)
      continue;
    if (const ConstantSDNode *IdxConst =
            dyn_cast<ConstantSDNode>(U->getOperand(1))) {
      if (IdxConst->getZExtValue() == 0)
        E0.push_back(U);
      else if (IdxConst->getZExtValue() == 1)
        E1.push_back(U);
      else
        llvm_unreachable("Invalid vector index.");
    }
  }

  // There's no point scattering f16x2 if we only ever access one
  // element of it.
  if (E0.empty() || E1.empty())
    return false;

  // Merge (f16 extractelt(V, 0), f16 extractelt(V,1))
  // into f16,f16 SplitF16x2(V)
  MVT EltVT = VT.getVectorElementType();
  SDNode *ScatterOp =
      CurDAG->getMachineNode(NVPTX::I32toV2I16, SDLoc(N), EltVT, EltVT, Vector);
  for (auto *Node : E0)
    ReplaceUses(SDValue(Node, 0), SDValue(ScatterOp, 0));
  for (auto *Node : E1)
    ReplaceUses(SDValue(Node, 0), SDValue(ScatterOp, 1));

  return true;
}

static unsigned int getCodeAddrSpace(MemSDNode *N) {
  const Value *Src = N->getMemOperand()->getValue();

  if (!Src)
    return NVPTX::AddressSpace::Generic;

  if (auto *PT = dyn_cast<PointerType>(Src->getType())) {
    switch (PT->getAddressSpace()) {
    case llvm::ADDRESS_SPACE_LOCAL:
      return NVPTX::AddressSpace::Local;
    case llvm::ADDRESS_SPACE_GLOBAL:
      return NVPTX::AddressSpace::Global;
    case llvm::ADDRESS_SPACE_SHARED:
      return NVPTX::AddressSpace::Shared;
    case llvm::ADDRESS_SPACE_GENERIC:
      return NVPTX::AddressSpace::Generic;
    case llvm::ADDRESS_SPACE_PARAM:
      return NVPTX::AddressSpace::Param;
    case llvm::ADDRESS_SPACE_CONST:
      return NVPTX::AddressSpace::Const;
    default:
      break;
    }
  }
  return NVPTX::AddressSpace::Generic;
}

namespace {

struct OperationOrderings {
  NVPTX::Ordering InstructionOrdering, FenceOrdering;
  OperationOrderings(NVPTX::Ordering IO = NVPTX::Ordering::NotAtomic,
                     NVPTX::Ordering FO = NVPTX::Ordering::NotAtomic)
      : InstructionOrdering(IO), FenceOrdering(FO) {}
};

static OperationOrderings
getOperationOrderings(MemSDNode *N, const NVPTXSubtarget *Subtarget) {
  AtomicOrdering Ordering = N->getSuccessOrdering();
  auto CodeAddrSpace = getCodeAddrSpace(N);

  bool HasMemoryOrdering = Subtarget->hasMemoryOrdering();
  bool HasRelaxedMMIO = Subtarget->hasRelaxedMMIO();

  // clang-format off

  // Lowering for Load/Store Operations (note: AcquireRelease Loads or Stores error).
  // Note: uses of Relaxed in the Atomic column of this table refer
  // to LLVM AtomicOrdering::Monotonic.
  //
  // | Atomic  | Volatile | Statespace         | PTX sm_60- | PTX sm_70+                   |
  // |---------|----------|--------------------|------------|------------------------------|
  // | No      | No       | All                | plain      | .weak                        |
  // | No      | Yes      | Generic,Shared,    | .volatile  | .volatile                    |
  // |         |          | Global [0]         |            |                              |
  // | No      | Yes      | Local,Const,Param  | plain [1]  | .weak [1]                    |
  // | Unorder | Yes/No   | All                | == Relaxed | == Relaxed                   |
  // | Relaxed | No       | Generic,Shared,    | .volatile  | <atomic sem>                 |
  // |         |          | Global [0]         |            |                              |
  // | Other   | No       | Generic,Shared,    | Error [2]  | <atomic sem>                 |
  // |         |          | Global [0]         |            |                              |
  // | Yes     | No       | Local,Const,Param  | plain [1]  | .weak [1]                    |
  // | Relaxed | Yes      | Generic,Shared [0] | .volatile  | .volatile                    |
  // | Relaxed | Yes      | Global [0]         | .volatile  | .mmio.relaxed.sys (PTX 8.2+) |
  // |         |          |                    |            |  or .volatile (PTX 8.1-)     |
  // | Relaxed | Yes      | Local,Const,Param  | plain [1]  | .weak [1]                    |
  // | Other   | Yes      | Generic, Shared,   | Error [2]  | <atomic sem> [3]             |
  // |         |          | / Global [0]       |            |                              |

  // Lowering of CUDA C++ SequentiallyConsistent Operations and Fences to PTX
  // by following the ABI proven sound in:
  //   Lustig et al, A Formal Analysis of the NVIDIA PTX Memory Consistency Model, ASPLOS’19.
  //   https://dl.acm.org/doi/pdf/10.1145/3297858.3304043
  //
  // | CUDA C++ Atomic Operation or Atomic Fence            | PTX Atomic Operation or Fence |
  // |------------------------------------------------------|-------------------------------|
  // | cuda::atomic_thread_fence                            | fence.sc.<scope>;             |
  // |   (memory_order_seq_cst, cuda::thread_scope_<scope>) |                               |
  // |------------------------------------------------------|-------------------------------|
  // | cuda::atomic_load                                    | fence.sc.<scope>;             |
  // |   (memory_order_seq_cst, cuda::thread_scope_<scope>) | ld.acquire.<scope>;           |
  // |------------------------------------------------------|-------------------------------|
  // | cuda::atomic_store                                   | fence.sc.<scope>;             |
  // |   (memory_order_seq_cst, cuda::thread_scope_<scope>) | st.release.<scope>;           |
  // |------------------------------------------------------|-------------------------------|
  // | cuda::atomic_fetch_<op>                              | fence.sc.<scope>;             |
  // |   (memory_order_seq_cst, cuda::thread_scope_<scope>) | atom.acq_rel.<scope>;         |

  // clang-format on

  // [0]: volatile and atomics are only supported on global or shared
  //      memory locations, accessed via generic/shared/global pointers.
  //      MMIO is only supported on global memory locations,
  //      accessed via generic/global pointers.
  //      TODO: Implement MMIO access via generic pointer to global.
  //      Currently implemented for global pointers only.

  // [1]: Lowering volatile/atomic operations to non-volatile/non-atomic
  //      PTX instructions fails to preserve their C++ side-effects.
  //
  //      Example (https://github.com/llvm/llvm-project/issues/62057):
  //
  //        void example() {
  //          std::atomic<bool> True = true;
  //          while (True.load(std::memory_order_relaxed));
  //        }
  //
  //      A C++ program that calls "example" is well-defined: the infinite loop
  //      performs an atomic operation. By lowering volatile/atomics to
  //      "weak" memory operations, we are transforming the above into:
  //
  //        void undefined_behavior() {
  //          bool True = true;
  //          while (True);
  //        }
  //
  //      which exhibits undefined behavior in both C++ and PTX.
  //
  //      Calling "example" in CUDA C++ compiled for sm_60- exhibits undefined
  //      behavior due to lack of Independent Forward Progress. Lowering these
  //      to weak memory operations in sm_60- is therefore fine.
  //
  //      TODO: lower atomic and volatile operations to memory locations
  //      in local, const, and param to two PTX instructions in sm_70+:
  //        - the "weak" memory instruction we are currently lowering to, and
  //        - some other instruction that preserves the side-effect, e.g.,
  //          a dead dummy volatile load.
  if (CodeAddrSpace == NVPTX::AddressSpace::Local ||
      CodeAddrSpace == NVPTX::AddressSpace::Const ||
      CodeAddrSpace == NVPTX::AddressSpace::Param) {
    return NVPTX::Ordering::NotAtomic;
  }

  // [2]: Atomics with Ordering different than Unordered or Relaxed are not
  //      supported on sm_60 and older; this includes volatile atomics.
  if (!(Ordering == AtomicOrdering::NotAtomic ||
        Ordering == AtomicOrdering::Unordered ||
        Ordering == AtomicOrdering::Monotonic) &&
      !HasMemoryOrdering) {
    report_fatal_error(
        formatv("PTX does not support \"atomic\" for orderings different than "
                "\"NotAtomic\" or \"Monotonic\" for sm_60 or older, but order "
                "is: \"{}\".",
                toIRString(Ordering)));
  }

  // [3]: TODO: these should eventually use .mmio<.atomic sem>; for now we drop
  //      the volatile semantics and preserve the atomic ones.

  // PTX volatile and PTX atomics are not available for statespaces that differ
  // from .generic, .global, or .shared. The behavior of PTX volatile and PTX
  // atomics is undefined if the generic address does not refer to a .global or
  // .shared memory location.
  bool AddrGenericOrGlobalOrShared =
      (CodeAddrSpace == NVPTX::AddressSpace::Generic ||
       CodeAddrSpace == NVPTX::AddressSpace::Global ||
       CodeAddrSpace == NVPTX::AddressSpace::Shared);
  if (!AddrGenericOrGlobalOrShared)
    return NVPTX::Ordering::NotAtomic;

  bool UseRelaxedMMIO =
      HasRelaxedMMIO && CodeAddrSpace == NVPTX::AddressSpace::Global;

  switch (Ordering) {
  case AtomicOrdering::NotAtomic:
    return N->isVolatile() ? NVPTX::Ordering::Volatile
                           : NVPTX::Ordering::NotAtomic;
  case AtomicOrdering::Unordered:
    // We lower unordered in the exact same way as 'monotonic' to respect
    // LLVM IR atomicity requirements.
  case AtomicOrdering::Monotonic:
    if (N->isVolatile())
      return UseRelaxedMMIO ? NVPTX::Ordering::RelaxedMMIO
                            : NVPTX::Ordering::Volatile;
    else
      return HasMemoryOrdering ? NVPTX::Ordering::Relaxed
                               : NVPTX::Ordering::Volatile;
  // case AtomicOrdering::Consume: // If LLVM ever provides this, lower it to
  // Acquire.
  case AtomicOrdering::Acquire:
    if (!N->readMem())
      report_fatal_error(
          formatv("PTX only supports Acquire Ordering on reads: {}",
                  N->getOperationName()));
    return NVPTX::Ordering::Acquire;
  case AtomicOrdering::Release:
    if (!N->writeMem())
      report_fatal_error(
          formatv("PTX only supports Release Ordering on writes: {}",
                  N->getOperationName()));
    return NVPTX::Ordering::Release;
  case AtomicOrdering::AcquireRelease: {
    report_fatal_error(
        formatv("NVPTX does not support AcquireRelease Ordering on "
                "read-modify-write "
                "yet and PTX does not support it on loads or stores: {}",
                N->getOperationName()));
  }
  case AtomicOrdering::SequentiallyConsistent: {
    // LLVM-IR SequentiallyConsistent atomics map to a two-instruction PTX
    // sequence including a "fence.sc.sco" and the memory instruction with an
    // Ordering that differs from "sc": acq, rel, or acq_rel, depending on
    // whether the memory operation is a read, write, or read-modify-write.
    //
    // This sets the ordering of the fence to SequentiallyConsistent, and
    // sets the corresponding ordering for the instruction.
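    //
    // Illustrative example (assumed shape of the final PTX, not emitted from
    // this function directly): on sm_70+, an LLVM IR
    //   "load atomic i32, ptr addrspace(1) %p seq_cst"
    // is expected to lower to roughly:
    //   fence.sc.sys;
    //   ld.acquire.sys.global.u32 %r, [%p];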
    NVPTX::Ordering InstrOrder;
    if (N->readMem())
      InstrOrder = NVPTX::Ordering::Acquire;
    else if (N->writeMem())
      InstrOrder = NVPTX::Ordering::Release;
    else
      report_fatal_error(
          formatv("NVPTX does not support SequentiallyConsistent Ordering on "
                  "read-modify-writes yet: {}",
                  N->getOperationName()));
    return OperationOrderings(InstrOrder,
                              NVPTX::Ordering::SequentiallyConsistent);
  }
  }
  report_fatal_error(
      formatv("NVPTX backend does not support AtomicOrdering \"{}\" yet.",
              toIRString(Ordering)));
}

} // namespace

NVPTX::Scope NVPTXDAGToDAGISel::getOperationScope(MemSDNode *N,
                                                  NVPTX::Ordering O) const {
  switch (O) {
  case NVPTX::Ordering::NotAtomic:
  case NVPTX::Ordering::Volatile: // Non-atomic volatile operations
    // NVPTX uses Thread scope as the scope of non-atomic operations.
    return NVPTX::Scope::Thread;
  case NVPTX::Ordering::RelaxedMMIO:
    // RelaxedMMIO operations are always system scope.
    // If a RelaxedMMIO order was generated from an atomic volatile operation
    // with a smaller thread scope, we bump it here to system scope.
    return NVPTX::Scope::System;
  case NVPTX::Ordering::Relaxed:
  case NVPTX::Ordering::Acquire:
  case NVPTX::Ordering::Release:
  case NVPTX::Ordering::AcquireRelease:
  case NVPTX::Ordering::SequentiallyConsistent:
    auto S = Scopes[N->getSyncScopeID()];

    // Atomic operations must have a scope greater than thread.
    if (S == NVPTX::Scope::Thread)
      report_fatal_error(
          formatv("Atomics need scope > \"{}\".", ScopeToString(S)));

    // If scope is cluster, clusters must be supported.
    if (S == NVPTX::Scope::Cluster)
      Subtarget->failIfClustersUnsupported("cluster scope");

    // If operation is volatile, then its scope is system.
    return N->isVolatile() ? NVPTX::Scope::System : S;
  }
  llvm_unreachable("unhandled ordering");
}

static bool canLowerToLDG(MemSDNode *N, const NVPTXSubtarget &Subtarget,
                          unsigned CodeAddrSpace, MachineFunction *F) {
  // We use ldg (i.e. ld.global.nc) for invariant loads from the global address
  // space.
  //
  // We have two ways of identifying invariant loads: Loads may be explicitly
  // marked as invariant, or we may infer them to be invariant.
  //
  // We currently infer invariance for loads from
  //  - constant global variables, and
  //  - kernel function pointer params that are noalias (i.e. __restrict) and
  //    never written to.
  //
  // TODO: Perform a more powerful invariance analysis (ideally IPO, and ideally
  // not during the SelectionDAG phase).
  //
  // TODO: Infer invariance only at -O2. We still want to use ldg at -O0 for
  // explicitly invariant loads because these are how clang tells us to use ldg
  // when the user uses a builtin.
  if (!Subtarget.hasLDG() || CodeAddrSpace != NVPTX::AddressSpace::Global)
    return false;

  if (N->isInvariant())
    return true;

  bool IsKernelFn = isKernelFunction(F->getFunction());

  // We use getUnderlyingObjects() here instead of getUnderlyingObject() mainly
  // because the former looks through phi nodes while the latter does not. We
  // need to look through phi nodes to handle pointer induction variables.
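  //
  // Illustrative (hypothetical) IR: a pointer induction variable such as
  //   %p = phi ptr addrspace(1) [ %kernel_arg, %entry ], [ %p.next, %loop ]
  // only reveals the underlying kernel argument once phi nodes are looked
  // through.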
  SmallVector<const Value *, 8> Objs;
  getUnderlyingObjects(N->getMemOperand()->getValue(), Objs);

  return all_of(Objs, [&](const Value *V) {
    if (auto *A = dyn_cast<const Argument>(V))
      return IsKernelFn && A->onlyReadsMemory() && A->hasNoAliasAttr();
    if (auto *GV = dyn_cast<const GlobalVariable>(V))
      return GV->isConstant();
    return false;
  });
}

static unsigned int getFenceOp(NVPTX::Ordering O, NVPTX::Scope S,
                               NVPTXSubtarget const *T) {
  if (S == NVPTX::Scope::Cluster)
    T->failIfClustersUnsupported(".cluster scope fence");

  switch (O) {
  case NVPTX::Ordering::Acquire:
  case NVPTX::Ordering::Release:
  case NVPTX::Ordering::AcquireRelease: {
    switch (S) {
    case NVPTX::Scope::System:
      return T->hasMemoryOrdering() ? NVPTX::atomic_thread_fence_acq_rel_sys
                                    : NVPTX::INT_MEMBAR_SYS;
    case NVPTX::Scope::Block:
      return T->hasMemoryOrdering() ? NVPTX::atomic_thread_fence_acq_rel_cta
                                    : NVPTX::INT_MEMBAR_CTA;
    case NVPTX::Scope::Cluster:
      return NVPTX::atomic_thread_fence_acq_rel_cluster;
    case NVPTX::Scope::Device:
      return T->hasMemoryOrdering() ? NVPTX::atomic_thread_fence_acq_rel_gpu
                                    : NVPTX::INT_MEMBAR_GL;
    case NVPTX::Scope::Thread:
      report_fatal_error(
          formatv("Unsupported scope \"{}\" for acquire/release/acq_rel fence.",
                  ScopeToString(S)));
    }
    break;
  }
  case NVPTX::Ordering::SequentiallyConsistent: {
    switch (S) {
    case NVPTX::Scope::System:
      return T->hasMemoryOrdering() ? NVPTX::atomic_thread_fence_seq_cst_sys
                                    : NVPTX::INT_MEMBAR_SYS;
    case NVPTX::Scope::Block:
      return T->hasMemoryOrdering() ? NVPTX::atomic_thread_fence_seq_cst_cta
                                    : NVPTX::INT_MEMBAR_CTA;
    case NVPTX::Scope::Cluster:
      return NVPTX::atomic_thread_fence_seq_cst_cluster;
    case NVPTX::Scope::Device:
      return T->hasMemoryOrdering() ? NVPTX::atomic_thread_fence_seq_cst_gpu
                                    : NVPTX::INT_MEMBAR_GL;
    case NVPTX::Scope::Thread:
      report_fatal_error(formatv("Unsupported scope \"{}\" for seq_cst fence.",
                                 ScopeToString(S)));
    }
    break;
  }
  case NVPTX::Ordering::NotAtomic:
  case NVPTX::Ordering::Relaxed:
  case NVPTX::Ordering::Volatile:
  case NVPTX::Ordering::RelaxedMMIO:
    report_fatal_error(
        formatv("Unsupported \"{}\" ordering and \"{}\" scope for fence.",
                OrderingToString(O), ScopeToString(S)));
  }
  llvm_unreachable("unhandled ordering");
}

// Returns Memory Order and Scope of a memory instruction, and
// inserts any fence before the instruction that's required to
// implement its memory ordering.
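// For example, a sequentially consistent operation is expected to chain a
// fence.sc machine node onto Chain here, while the memory operation itself is
// later emitted with the returned acquire/release ordering.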
std::pair<NVPTX::Ordering, NVPTX::Scope>
NVPTXDAGToDAGISel::insertMemoryInstructionFence(SDLoc DL, SDValue &Chain,
                                                MemSDNode *N) {
  auto [InstructionOrdering, FenceOrdering] =
      getOperationOrderings(N, Subtarget);
  auto Scope = getOperationScope(N, InstructionOrdering);

  // If a fence is required before the operation, insert it:
  switch (NVPTX::Ordering(FenceOrdering)) {
  case NVPTX::Ordering::NotAtomic:
    break;
  case NVPTX::Ordering::SequentiallyConsistent: {
    auto Op = getFenceOp(FenceOrdering, Scope, Subtarget);
    Chain = SDValue(CurDAG->getMachineNode(Op, DL, MVT::Other, Chain), 0);
    break;
  }
  default:
    report_fatal_error(
        formatv("Unexpected fence ordering: \"{}\".",
                OrderingToString(NVPTX::Ordering(FenceOrdering))));
  }
  return {InstructionOrdering, Scope};
}

bool NVPTXDAGToDAGISel::tryIntrinsicNoChain(SDNode *N) {
  unsigned IID = N->getConstantOperandVal(0);
  switch (IID) {
  default:
    return false;
  case Intrinsic::nvvm_texsurf_handle_internal:
    SelectTexSurfHandle(N);
    return true;
  }
}

void NVPTXDAGToDAGISel::SelectTexSurfHandle(SDNode *N) {
  // Op 0 is the intrinsic ID
  SDValue Wrapper = N->getOperand(1);
  SDValue GlobalVal = Wrapper.getOperand(0);
  ReplaceNode(N, CurDAG->getMachineNode(NVPTX::texsurf_handles, SDLoc(N),
                                        MVT::i64, GlobalVal));
}

void NVPTXDAGToDAGISel::SelectAddrSpaceCast(SDNode *N) {
  SDValue Src = N->getOperand(0);
  AddrSpaceCastSDNode *CastN = cast<AddrSpaceCastSDNode>(N);
  unsigned SrcAddrSpace = CastN->getSrcAddressSpace();
  unsigned DstAddrSpace = CastN->getDestAddressSpace();
  SDLoc DL(N);
  assert(SrcAddrSpace != DstAddrSpace &&
         "addrspacecast must be between different address spaces");

  if (DstAddrSpace == ADDRESS_SPACE_GENERIC) {
    // Specific to generic

    if (TM.is64Bit() && TM.getPointerSizeInBits(SrcAddrSpace) == 32) {
      SDValue CvtNone =
          CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL, MVT::i32);
      SDNode *Cvt = CurDAG->getMachineNode(NVPTX::CVT_u64_u32, DL, MVT::i64,
                                           Src, CvtNone);
      Src = SDValue(Cvt, 0);
    }

    unsigned Opc;
    switch (SrcAddrSpace) {
    default: report_fatal_error("Bad address space in addrspacecast");
    case ADDRESS_SPACE_GLOBAL:
      Opc = TM.is64Bit() ? NVPTX::cvta_global_64 : NVPTX::cvta_global;
      break;
    case ADDRESS_SPACE_SHARED:
      Opc = TM.is64Bit() ? NVPTX::cvta_shared_64 : NVPTX::cvta_shared;
      break;
    case ADDRESS_SPACE_CONST:
      Opc = TM.is64Bit() ? NVPTX::cvta_const_64 : NVPTX::cvta_const;
      break;
    case ADDRESS_SPACE_LOCAL:
      Opc = TM.is64Bit() ? NVPTX::cvta_local_64 : NVPTX::cvta_local;
      break;
    }
    ReplaceNode(N, CurDAG->getMachineNode(Opc, DL, N->getValueType(0), Src));
    return;
  } else {
    // Generic to specific
    if (SrcAddrSpace != 0)
      report_fatal_error("Cannot cast between two non-generic address spaces");
    unsigned Opc;
    switch (DstAddrSpace) {
    default: report_fatal_error("Bad address space in addrspacecast");
    case ADDRESS_SPACE_GLOBAL:
      Opc = TM.is64Bit() ? NVPTX::cvta_to_global_64 : NVPTX::cvta_to_global;
      break;
    case ADDRESS_SPACE_SHARED:
      Opc = TM.is64Bit() ? NVPTX::cvta_to_shared_64 : NVPTX::cvta_to_shared;
      break;
    case ADDRESS_SPACE_CONST:
      Opc = TM.is64Bit() ? NVPTX::cvta_to_const_64 : NVPTX::cvta_to_const;
      break;
    case ADDRESS_SPACE_LOCAL:
      Opc = TM.is64Bit() ? NVPTX::cvta_to_local_64 : NVPTX::cvta_to_local;
      break;
    case ADDRESS_SPACE_PARAM:
      Opc = TM.is64Bit() ? NVPTX::IMOV64rr : NVPTX::IMOV32rr;
      break;
    }

    SDNode *CVTA = CurDAG->getMachineNode(Opc, DL, N->getValueType(0), Src);
    if (TM.is64Bit() && TM.getPointerSizeInBits(DstAddrSpace) == 32) {
      SDValue CvtNone =
          CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL, MVT::i32);
      CVTA = CurDAG->getMachineNode(NVPTX::CVT_u32_u64, DL, MVT::i32,
                                    SDValue(CVTA, 0), CvtNone);
    }

    ReplaceNode(N, CVTA);
    return;
  }
}

// Helper function template to reduce amount of boilerplate code for
// opcode selection.
static std::optional<unsigned>
pickOpcodeForVT(MVT::SimpleValueType VT, unsigned Opcode_i8,
                unsigned Opcode_i16, unsigned Opcode_i32,
                std::optional<unsigned> Opcode_i64, unsigned Opcode_f32,
                std::optional<unsigned> Opcode_f64) {
  switch (VT) {
  case MVT::i1:
  case MVT::i8:
    return Opcode_i8;
  case MVT::i16:
    return Opcode_i16;
  case MVT::i32:
    return Opcode_i32;
  case MVT::i64:
    return Opcode_i64;
  case MVT::f16:
  case MVT::bf16:
    return Opcode_i16;
  case MVT::v2f16:
  case MVT::v2bf16:
  case MVT::v2i16:
  case MVT::v4i8:
    return Opcode_i32;
  case MVT::f32:
    return Opcode_f32;
  case MVT::f64:
    return Opcode_f64;
  default:
    return std::nullopt;
  }
}

static int getLdStRegType(EVT VT) {
  if (VT.isFloatingPoint())
    switch (VT.getSimpleVT().SimpleTy) {
    case MVT::f16:
    case MVT::bf16:
    case MVT::v2f16:
    case MVT::v2bf16:
      return NVPTX::PTXLdStInstCode::Untyped;
    default:
      return NVPTX::PTXLdStInstCode::Float;
    }
  else
    return NVPTX::PTXLdStInstCode::Unsigned;
}

bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
  MemSDNode *LD = cast<MemSDNode>(N);
  assert(LD->readMem() && "Expected load");

  // do not support pre/post inc/dec
  LoadSDNode *PlainLoad = dyn_cast<LoadSDNode>(N);
  if (PlainLoad && PlainLoad->isIndexed())
    return false;

  EVT LoadedVT = LD->getMemoryVT();
  if (!LoadedVT.isSimple())
    return false;

  // Address Space Setting
  unsigned int CodeAddrSpace = getCodeAddrSpace(LD);
  if (canLowerToLDG(LD, *Subtarget, CodeAddrSpace, MF)) {
    return tryLDGLDU(N);
  }
  unsigned int PointerSize =
      CurDAG->getDataLayout().getPointerSizeInBits(LD->getAddressSpace());

  SDLoc DL(N);
  SDValue Chain = N->getOperand(0);
  auto [Ordering, Scope] = insertMemoryInstructionFence(DL, Chain, LD);

  // Type Setting: fromType + fromTypeWidth
  //
  // Sign   : ISD::SEXTLOAD
  // Unsign : ISD::ZEXTLOAD, ISD::NON_EXTLOAD or ISD::EXTLOAD and the
  //          type is integer
  // Float  : ISD::NON_EXTLOAD or ISD::EXTLOAD and the type is float
  MVT SimpleVT = LoadedVT.getSimpleVT();
  MVT ScalarVT = SimpleVT.getScalarType();
  // Read at least 8 bits (predicates are stored as 8-bit values)
  unsigned FromTypeWidth = std::max(8U, (unsigned)ScalarVT.getSizeInBits());
  unsigned int FromType;

  // Vector Setting
  unsigned VecType = NVPTX::PTXLdStInstCode::Scalar;
  if (SimpleVT.isVector()) {
    assert((Isv2x16VT(LoadedVT) || LoadedVT == MVT::v4i8) &&
           "Unexpected vector type");
    // v2f16/v2bf16/v2i16 is loaded using ld.b32
    FromTypeWidth = 32;
  }

  if (PlainLoad && (PlainLoad->getExtensionType() == ISD::SEXTLOAD))
    FromType = NVPTX::PTXLdStInstCode::Signed;
  else
    FromType = getLdStRegType(ScalarVT);

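  // The operands assembled below follow the order taken by the selected LD
  // instruction: ordering, scope, address space, vector arity, value type,
  // and access width in bits, then the address operands and the chain.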
  // Create the machine instruction DAG
  SDValue N1 = N->getOperand(1);
  SDValue Addr;
  SDValue Offset, Base;
  std::optional<unsigned> Opcode;
  MVT::SimpleValueType TargetVT = LD->getSimpleValueType(0).SimpleTy;

  SmallVector<SDValue, 12> Ops({getI32Imm(Ordering, DL), getI32Imm(Scope, DL),
                                getI32Imm(CodeAddrSpace, DL),
                                getI32Imm(VecType, DL), getI32Imm(FromType, DL),
                                getI32Imm(FromTypeWidth, DL)});

  if (SelectDirectAddr(N1, Addr)) {
    Opcode = pickOpcodeForVT(TargetVT, NVPTX::LD_i8_avar, NVPTX::LD_i16_avar,
                             NVPTX::LD_i32_avar, NVPTX::LD_i64_avar,
                             NVPTX::LD_f32_avar, NVPTX::LD_f64_avar);
    if (!Opcode)
      return false;
    Ops.append({Addr, Chain});
  } else if (PointerSize == 64 ? SelectADDRsi64(N1.getNode(), N1, Base, Offset)
                               : SelectADDRsi(N1.getNode(), N1, Base, Offset)) {
    Opcode = pickOpcodeForVT(TargetVT, NVPTX::LD_i8_asi, NVPTX::LD_i16_asi,
                             NVPTX::LD_i32_asi, NVPTX::LD_i64_asi,
                             NVPTX::LD_f32_asi, NVPTX::LD_f64_asi);
    if (!Opcode)
      return false;
    Ops.append({Base, Offset, Chain});
  } else if (PointerSize == 64 ? SelectADDRri64(N1.getNode(), N1, Base, Offset)
                               : SelectADDRri(N1.getNode(), N1, Base, Offset)) {
    if (PointerSize == 64)
      Opcode =
          pickOpcodeForVT(TargetVT, NVPTX::LD_i8_ari_64, NVPTX::LD_i16_ari_64,
                          NVPTX::LD_i32_ari_64, NVPTX::LD_i64_ari_64,
                          NVPTX::LD_f32_ari_64, NVPTX::LD_f64_ari_64);
    else
      Opcode = pickOpcodeForVT(TargetVT, NVPTX::LD_i8_ari, NVPTX::LD_i16_ari,
                               NVPTX::LD_i32_ari, NVPTX::LD_i64_ari,
                               NVPTX::LD_f32_ari, NVPTX::LD_f64_ari);
    if (!Opcode)
      return false;
    Ops.append({Base, Offset, Chain});
  } else {
    if (PointerSize == 64)
      Opcode =
          pickOpcodeForVT(TargetVT, NVPTX::LD_i8_areg_64, NVPTX::LD_i16_areg_64,
                          NVPTX::LD_i32_areg_64, NVPTX::LD_i64_areg_64,
                          NVPTX::LD_f32_areg_64, NVPTX::LD_f64_areg_64);
    else
      Opcode = pickOpcodeForVT(TargetVT, NVPTX::LD_i8_areg, NVPTX::LD_i16_areg,
                               NVPTX::LD_i32_areg, NVPTX::LD_i64_areg,
                               NVPTX::LD_f32_areg, NVPTX::LD_f64_areg);
    if (!Opcode)
      return false;
    Ops.append({N1, Chain});
  }

  SDNode *NVPTXLD =
      CurDAG->getMachineNode(*Opcode, DL, TargetVT, MVT::Other, Ops);
  if (!NVPTXLD)
    return false;

  MachineMemOperand *MemRef = cast<MemSDNode>(N)->getMemOperand();
  CurDAG->setNodeMemRefs(cast<MachineSDNode>(NVPTXLD), {MemRef});

  ReplaceNode(N, NVPTXLD);
  return true;
}

static bool isVectorElementTypeUpsized(EVT EltVT) {
  // Despite vectors like v8i8, v16i8, v8i16 being within the bit-limit for
  // total load/store size, PTX syntax only supports v2/v4. Thus, we can't use
  // vectorized loads/stores with the actual element type for i8/i16 as that
  // would require v8/v16 variants that do not exist.
  // In order to load/store such vectors efficiently, in Type Legalization
  // we split the vector into word-sized chunks (v2x16/v4i8). Now, we will
  // lower to PTX as vectors of b32.
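  //
  // For example (illustrative), a v8i16 load is expected to reach this
  // selector as word-sized v2i16 pieces, each of which is then emitted as a
  // b32 lane of a v2/v4 PTX load.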
  return Isv2x16VT(EltVT) || EltVT == MVT::v4i8;
}

bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
  MemSDNode *MemSD = cast<MemSDNode>(N);
  EVT LoadedVT = MemSD->getMemoryVT();
  if (!LoadedVT.isSimple())
    return false;

  // Address Space Setting
  unsigned int CodeAddrSpace = getCodeAddrSpace(MemSD);
  if (canLowerToLDG(MemSD, *Subtarget, CodeAddrSpace, MF)) {
    return tryLDGLDU(N);
  }
  unsigned int PointerSize =
      CurDAG->getDataLayout().getPointerSizeInBits(MemSD->getAddressSpace());

  SDLoc DL(N);
  SDValue Chain = N->getOperand(0);
  auto [Ordering, Scope] = insertMemoryInstructionFence(DL, Chain, MemSD);

  // Vector Setting
  MVT SimpleVT = LoadedVT.getSimpleVT();

  // Type Setting: fromType + fromTypeWidth
  //
  // Sign   : ISD::SEXTLOAD
  // Unsign : ISD::ZEXTLOAD, ISD::NON_EXTLOAD or ISD::EXTLOAD and the
  //          type is integer
  // Float  : ISD::NON_EXTLOAD or ISD::EXTLOAD and the type is float
  MVT ScalarVT = SimpleVT.getScalarType();
  // Read at least 8 bits (predicates are stored as 8-bit values)
  unsigned FromTypeWidth = std::max(8U, (unsigned)ScalarVT.getSizeInBits());
  unsigned int FromType;
  // The last operand holds the original LoadSDNode::getExtensionType() value
  unsigned ExtensionType = cast<ConstantSDNode>(
      N->getOperand(N->getNumOperands() - 1))->getZExtValue();
  if (ExtensionType == ISD::SEXTLOAD)
    FromType = NVPTX::PTXLdStInstCode::Signed;
  else
    FromType = getLdStRegType(ScalarVT);

  unsigned VecType;

  switch (N->getOpcode()) {
  case NVPTXISD::LoadV2:
    VecType = NVPTX::PTXLdStInstCode::V2;
    break;
  case NVPTXISD::LoadV4:
    VecType = NVPTX::PTXLdStInstCode::V4;
    break;
  default:
    return false;
  }

  EVT EltVT = N->getValueType(0);

  if (isVectorElementTypeUpsized(EltVT)) {
    EltVT = MVT::i32;
    FromType = NVPTX::PTXLdStInstCode::Untyped;
    FromTypeWidth = 32;
  }

  SDValue Op1 = N->getOperand(1);
  SDValue Addr, Offset, Base;
  std::optional<unsigned> Opcode;
  SDNode *LD;

  SmallVector<SDValue, 12> Ops({getI32Imm(Ordering, DL), getI32Imm(Scope, DL),
                                getI32Imm(CodeAddrSpace, DL),
                                getI32Imm(VecType, DL), getI32Imm(FromType, DL),
                                getI32Imm(FromTypeWidth, DL)});

  if (SelectDirectAddr(Op1, Addr)) {
    switch (N->getOpcode()) {
    default:
      return false;
    case NVPTXISD::LoadV2:
      Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
                               NVPTX::LDV_i8_v2_avar, NVPTX::LDV_i16_v2_avar,
                               NVPTX::LDV_i32_v2_avar, NVPTX::LDV_i64_v2_avar,
                               NVPTX::LDV_f32_v2_avar, NVPTX::LDV_f64_v2_avar);
      break;
    case NVPTXISD::LoadV4:
      Opcode =
          pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v4_avar,
                          NVPTX::LDV_i16_v4_avar, NVPTX::LDV_i32_v4_avar,
                          std::nullopt, NVPTX::LDV_f32_v4_avar, std::nullopt);
      break;
    }
    if (!Opcode)
      return false;
    Ops.append({Addr, Chain});
  } else if (PointerSize == 64
                 ? SelectADDRsi64(Op1.getNode(), Op1, Base, Offset)
                 : SelectADDRsi(Op1.getNode(), Op1, Base, Offset)) {
    switch (N->getOpcode()) {
    default:
      return false;
    case NVPTXISD::LoadV2:
      Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
                               NVPTX::LDV_i8_v2_asi, NVPTX::LDV_i16_v2_asi,
                               NVPTX::LDV_i32_v2_asi, NVPTX::LDV_i64_v2_asi,
                               NVPTX::LDV_f32_v2_asi, NVPTX::LDV_f64_v2_asi);
      break;
    case NVPTXISD::LoadV4:
      Opcode =
          pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v4_asi,
                          NVPTX::LDV_i16_v4_asi, NVPTX::LDV_i32_v4_asi,
                          std::nullopt, NVPTX::LDV_f32_v4_asi, std::nullopt);
      break;
    }
    if (!Opcode)
      return false;
    Ops.append({Base, Offset, Chain});
  } else if (PointerSize == 64
                 ? SelectADDRri64(Op1.getNode(), Op1, Base, Offset)
                 : SelectADDRri(Op1.getNode(), Op1, Base, Offset)) {
    if (PointerSize == 64) {
      switch (N->getOpcode()) {
      default:
        return false;
      case NVPTXISD::LoadV2:
        Opcode =
            pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
                            NVPTX::LDV_i8_v2_ari_64, NVPTX::LDV_i16_v2_ari_64,
                            NVPTX::LDV_i32_v2_ari_64, NVPTX::LDV_i64_v2_ari_64,
                            NVPTX::LDV_f32_v2_ari_64, NVPTX::LDV_f64_v2_ari_64);
        break;
      case NVPTXISD::LoadV4:
        Opcode = pickOpcodeForVT(
            EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v4_ari_64,
            NVPTX::LDV_i16_v4_ari_64, NVPTX::LDV_i32_v4_ari_64, std::nullopt,
            NVPTX::LDV_f32_v4_ari_64, std::nullopt);
        break;
      }
    } else {
      switch (N->getOpcode()) {
      default:
        return false;
      case NVPTXISD::LoadV2:
        Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
                                 NVPTX::LDV_i8_v2_ari, NVPTX::LDV_i16_v2_ari,
                                 NVPTX::LDV_i32_v2_ari, NVPTX::LDV_i64_v2_ari,
                                 NVPTX::LDV_f32_v2_ari, NVPTX::LDV_f64_v2_ari);
        break;
      case NVPTXISD::LoadV4:
        Opcode =
            pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v4_ari,
                            NVPTX::LDV_i16_v4_ari, NVPTX::LDV_i32_v4_ari,
                            std::nullopt, NVPTX::LDV_f32_v4_ari, std::nullopt);
        break;
      }
    }
    if (!Opcode)
      return false;
    Ops.append({Base, Offset, Chain});
  } else {
    if (PointerSize == 64) {
      switch (N->getOpcode()) {
      default:
        return false;
      case NVPTXISD::LoadV2:
        Opcode = pickOpcodeForVT(
            EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v2_areg_64,
            NVPTX::LDV_i16_v2_areg_64, NVPTX::LDV_i32_v2_areg_64,
            NVPTX::LDV_i64_v2_areg_64, NVPTX::LDV_f32_v2_areg_64,
            NVPTX::LDV_f64_v2_areg_64);
        break;
      case NVPTXISD::LoadV4:
        Opcode = pickOpcodeForVT(
            EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v4_areg_64,
            NVPTX::LDV_i16_v4_areg_64, NVPTX::LDV_i32_v4_areg_64, std::nullopt,
            NVPTX::LDV_f32_v4_areg_64, std::nullopt);
        break;
      }
    } else {
      switch (N->getOpcode()) {
      default:
        return false;
      case NVPTXISD::LoadV2:
        Opcode =
            pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v2_areg,
                            NVPTX::LDV_i16_v2_areg, NVPTX::LDV_i32_v2_areg,
                            NVPTX::LDV_i64_v2_areg, NVPTX::LDV_f32_v2_areg,
                            NVPTX::LDV_f64_v2_areg);
        break;
      case NVPTXISD::LoadV4:
        Opcode =
            pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v4_areg,
                            NVPTX::LDV_i16_v4_areg, NVPTX::LDV_i32_v4_areg,
                            std::nullopt, NVPTX::LDV_f32_v4_areg, std::nullopt);
        break;
      }
    }
    if (!Opcode)
      return false;
    Ops.append({Op1, Chain});
  }
  LD = CurDAG->getMachineNode(*Opcode, DL, N->getVTList(), Ops);

  MachineMemOperand *MemRef = cast<MemSDNode>(N)->getMemOperand();
  CurDAG->setNodeMemRefs(cast<MachineSDNode>(LD), {MemRef});

  ReplaceNode(N, LD);
  return true;
}

bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
  auto *Mem = cast<MemSDNode>(N);

  // If this is an LDG intrinsic, the address is the third operand. If it's an
  // LDG/LDU SD node (from custom vector handling), then it's the second
  // operand.
  SDValue Op1 = N->getOperand(N->getOpcode() == ISD::INTRINSIC_W_CHAIN ? 2 : 1);

  EVT OrigType = N->getValueType(0);
  EVT EltVT = Mem->getMemoryVT();
  unsigned NumElts = 1;
  if (EltVT.isVector()) {
    NumElts = EltVT.getVectorNumElements();
    EltVT = EltVT.getVectorElementType();
    // vectors of 8/16bits type are loaded/stored as multiples of v4i8/v2x16
    // elements.
    if ((EltVT == MVT::f16 && OrigType == MVT::v2f16) ||
        (EltVT == MVT::bf16 && OrigType == MVT::v2bf16) ||
        (EltVT == MVT::i16 && OrigType == MVT::v2i16) ||
        (EltVT == MVT::i8 && OrigType == MVT::v4i8)) {
      assert(NumElts % OrigType.getVectorNumElements() == 0 &&
             "NumElts must be divisible by the number of elts in subvectors");
      EltVT = OrigType;
      NumElts /= OrigType.getVectorNumElements();
    }
  }

  // Build the "promoted" result VTList for the load. If we are really loading
  // i8s, then the return type will be promoted to i16 since we do not expose
  // 8-bit registers in NVPTX.
  EVT NodeVT = (EltVT == MVT::i8) ? MVT::i16 : EltVT;
  SmallVector<EVT, 5> InstVTs;
  for (unsigned i = 0; i != NumElts; ++i) {
    InstVTs.push_back(NodeVT);
  }
  InstVTs.push_back(MVT::Other);
  SDVTList InstVTList = CurDAG->getVTList(InstVTs);
  SDValue Chain = N->getOperand(0);

  std::optional<unsigned> Opcode;
  SDLoc DL(N);
  SDNode *LD;
  SDValue Base, Offset, Addr;

  if (SelectDirectAddr(Op1, Addr)) {
    switch (N->getOpcode()) {
    default:
      return false;
    case ISD::LOAD:
      Opcode = pickOpcodeForVT(
          EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_GLOBAL_i8avar,
          NVPTX::INT_PTX_LDG_GLOBAL_i16avar, NVPTX::INT_PTX_LDG_GLOBAL_i32avar,
          NVPTX::INT_PTX_LDG_GLOBAL_i64avar, NVPTX::INT_PTX_LDG_GLOBAL_f32avar,
          NVPTX::INT_PTX_LDG_GLOBAL_f64avar);
      break;
    case ISD::INTRINSIC_W_CHAIN:
      Opcode = pickOpcodeForVT(
          EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_GLOBAL_i8avar,
          NVPTX::INT_PTX_LDU_GLOBAL_i16avar, NVPTX::INT_PTX_LDU_GLOBAL_i32avar,
          NVPTX::INT_PTX_LDU_GLOBAL_i64avar, NVPTX::INT_PTX_LDU_GLOBAL_f32avar,
          NVPTX::INT_PTX_LDU_GLOBAL_f64avar);
      break;
    case NVPTXISD::LoadV2:
      Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
                               NVPTX::INT_PTX_LDG_G_v2i8_ELE_avar,
                               NVPTX::INT_PTX_LDG_G_v2i16_ELE_avar,
                               NVPTX::INT_PTX_LDG_G_v2i32_ELE_avar,
                               NVPTX::INT_PTX_LDG_G_v2i64_ELE_avar,
                               NVPTX::INT_PTX_LDG_G_v2f32_ELE_avar,
                               NVPTX::INT_PTX_LDG_G_v2f64_ELE_avar);
      break;
    case NVPTXISD::LDUV2:
      Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
                               NVPTX::INT_PTX_LDU_G_v2i8_ELE_avar,
                               NVPTX::INT_PTX_LDU_G_v2i16_ELE_avar,
                               NVPTX::INT_PTX_LDU_G_v2i32_ELE_avar,
                               NVPTX::INT_PTX_LDU_G_v2i64_ELE_avar,
                               NVPTX::INT_PTX_LDU_G_v2f32_ELE_avar,
                               NVPTX::INT_PTX_LDU_G_v2f64_ELE_avar);
      break;
    case NVPTXISD::LoadV4:
      Opcode = pickOpcodeForVT(
          EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v4i8_ELE_avar,
          NVPTX::INT_PTX_LDG_G_v4i16_ELE_avar,
          NVPTX::INT_PTX_LDG_G_v4i32_ELE_avar, std::nullopt,
          NVPTX::INT_PTX_LDG_G_v4f32_ELE_avar, std::nullopt);
      break;
    case NVPTXISD::LDUV4:
      Opcode = pickOpcodeForVT(
          EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v4i8_ELE_avar,
          NVPTX::INT_PTX_LDU_G_v4i16_ELE_avar,
          NVPTX::INT_PTX_LDU_G_v4i32_ELE_avar, std::nullopt,
          NVPTX::INT_PTX_LDU_G_v4f32_ELE_avar, std::nullopt);
      break;
    }
    if (!Opcode)
      return false;
    SDValue Ops[] = {Addr, Chain};
    LD = CurDAG->getMachineNode(*Opcode, DL, InstVTList, Ops);
  } else if (TM.is64Bit() ? SelectADDRri64(Op1.getNode(), Op1, Base, Offset)
                          : SelectADDRri(Op1.getNode(), Op1, Base, Offset)) {
    if (TM.is64Bit()) {
      switch (N->getOpcode()) {
      default:
        return false;
      case ISD::LOAD:
        Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
                                 NVPTX::INT_PTX_LDG_GLOBAL_i8ari64,
                                 NVPTX::INT_PTX_LDG_GLOBAL_i16ari64,
                                 NVPTX::INT_PTX_LDG_GLOBAL_i32ari64,
                                 NVPTX::INT_PTX_LDG_GLOBAL_i64ari64,
                                 NVPTX::INT_PTX_LDG_GLOBAL_f32ari64,
                                 NVPTX::INT_PTX_LDG_GLOBAL_f64ari64);
        break;
      case ISD::INTRINSIC_W_CHAIN:
        Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
                                 NVPTX::INT_PTX_LDU_GLOBAL_i8ari64,
                                 NVPTX::INT_PTX_LDU_GLOBAL_i16ari64,
                                 NVPTX::INT_PTX_LDU_GLOBAL_i32ari64,
                                 NVPTX::INT_PTX_LDU_GLOBAL_i64ari64,
                                 NVPTX::INT_PTX_LDU_GLOBAL_f32ari64,
                                 NVPTX::INT_PTX_LDU_GLOBAL_f64ari64);
        break;
      case NVPTXISD::LoadV2:
        Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
                                 NVPTX::INT_PTX_LDG_G_v2i8_ELE_ari64,
                                 NVPTX::INT_PTX_LDG_G_v2i16_ELE_ari64,
                                 NVPTX::INT_PTX_LDG_G_v2i32_ELE_ari64,
                                 NVPTX::INT_PTX_LDG_G_v2i64_ELE_ari64,
                                 NVPTX::INT_PTX_LDG_G_v2f32_ELE_ari64,
                                 NVPTX::INT_PTX_LDG_G_v2f64_ELE_ari64);
        break;
      case NVPTXISD::LDUV2:
        Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
                                 NVPTX::INT_PTX_LDU_G_v2i8_ELE_ari64,
                                 NVPTX::INT_PTX_LDU_G_v2i16_ELE_ari64,
                                 NVPTX::INT_PTX_LDU_G_v2i32_ELE_ari64,
                                 NVPTX::INT_PTX_LDU_G_v2i64_ELE_ari64,
                                 NVPTX::INT_PTX_LDU_G_v2f32_ELE_ari64,
                                 NVPTX::INT_PTX_LDU_G_v2f64_ELE_ari64);
        break;
      case NVPTXISD::LoadV4:
        Opcode = pickOpcodeForVT(
            EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v4i8_ELE_ari64,
            NVPTX::INT_PTX_LDG_G_v4i16_ELE_ari64,
            NVPTX::INT_PTX_LDG_G_v4i32_ELE_ari64, std::nullopt,
            NVPTX::INT_PTX_LDG_G_v4f32_ELE_ari64, std::nullopt);
        break;
      case NVPTXISD::LDUV4:
        Opcode = pickOpcodeForVT(
            EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v4i8_ELE_ari64,
            NVPTX::INT_PTX_LDU_G_v4i16_ELE_ari64,
            NVPTX::INT_PTX_LDU_G_v4i32_ELE_ari64, std::nullopt,
            NVPTX::INT_PTX_LDU_G_v4f32_ELE_ari64, std::nullopt);
        break;
      }
    } else {
      switch (N->getOpcode()) {
      default:
        return false;
      case ISD::LOAD:
        Opcode = pickOpcodeForVT(
            EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_GLOBAL_i8ari,
            NVPTX::INT_PTX_LDG_GLOBAL_i16ari, NVPTX::INT_PTX_LDG_GLOBAL_i32ari,
            NVPTX::INT_PTX_LDG_GLOBAL_i64ari, NVPTX::INT_PTX_LDG_GLOBAL_f32ari,
            NVPTX::INT_PTX_LDG_GLOBAL_f64ari);
        break;
      case ISD::INTRINSIC_W_CHAIN:
        Opcode = pickOpcodeForVT(
            EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_GLOBAL_i8ari,
            NVPTX::INT_PTX_LDU_GLOBAL_i16ari, NVPTX::INT_PTX_LDU_GLOBAL_i32ari,
            NVPTX::INT_PTX_LDU_GLOBAL_i64ari, NVPTX::INT_PTX_LDU_GLOBAL_f32ari,
            NVPTX::INT_PTX_LDU_GLOBAL_f64ari);
        break;
      case NVPTXISD::LoadV2:
        Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
                                 NVPTX::INT_PTX_LDG_G_v2i8_ELE_ari32,
                                 NVPTX::INT_PTX_LDG_G_v2i16_ELE_ari32,
                                 NVPTX::INT_PTX_LDG_G_v2i32_ELE_ari32,
                                 NVPTX::INT_PTX_LDG_G_v2i64_ELE_ari32,
                                 NVPTX::INT_PTX_LDG_G_v2f32_ELE_ari32,
                                 NVPTX::INT_PTX_LDG_G_v2f64_ELE_ari32);
        break;
      case NVPTXISD::LDUV2:
        Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
                                 NVPTX::INT_PTX_LDU_G_v2i8_ELE_ari32,
                                 NVPTX::INT_PTX_LDU_G_v2i16_ELE_ari32,
                                 NVPTX::INT_PTX_LDU_G_v2i32_ELE_ari32,
                                 NVPTX::INT_PTX_LDU_G_v2i64_ELE_ari32,
                                 NVPTX::INT_PTX_LDU_G_v2f32_ELE_ari32,
                                 NVPTX::INT_PTX_LDU_G_v2f64_ELE_ari32);
        break;
      case NVPTXISD::LoadV4:
        Opcode = pickOpcodeForVT(
            EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v4i8_ELE_ari32,
            NVPTX::INT_PTX_LDG_G_v4i16_ELE_ari32,
            NVPTX::INT_PTX_LDG_G_v4i32_ELE_ari32, std::nullopt,
            NVPTX::INT_PTX_LDG_G_v4f32_ELE_ari32, std::nullopt);
        break;
      case NVPTXISD::LDUV4:
        Opcode = pickOpcodeForVT(
            EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v4i8_ELE_ari32,
            NVPTX::INT_PTX_LDU_G_v4i16_ELE_ari32,
            NVPTX::INT_PTX_LDU_G_v4i32_ELE_ari32, std::nullopt,
            NVPTX::INT_PTX_LDU_G_v4f32_ELE_ari32, std::nullopt);
        break;
      }
    }
    if (!Opcode)
      return false;
    SDValue Ops[] = {Base, Offset, Chain};
    LD = CurDAG->getMachineNode(*Opcode, DL, InstVTList, Ops);
  } else {
    if (TM.is64Bit()) {
      switch (N->getOpcode()) {
      default:
        return false;
      case ISD::LOAD:
        Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
                                 NVPTX::INT_PTX_LDG_GLOBAL_i8areg64,
                                 NVPTX::INT_PTX_LDG_GLOBAL_i16areg64,
                                 NVPTX::INT_PTX_LDG_GLOBAL_i32areg64,
                                 NVPTX::INT_PTX_LDG_GLOBAL_i64areg64,
                                 NVPTX::INT_PTX_LDG_GLOBAL_f32areg64,
                                 NVPTX::INT_PTX_LDG_GLOBAL_f64areg64);
        break;
      case ISD::INTRINSIC_W_CHAIN:
        Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
                                 NVPTX::INT_PTX_LDU_GLOBAL_i8areg64,
                                 NVPTX::INT_PTX_LDU_GLOBAL_i16areg64,
                                 NVPTX::INT_PTX_LDU_GLOBAL_i32areg64,
                                 NVPTX::INT_PTX_LDU_GLOBAL_i64areg64,
                                 NVPTX::INT_PTX_LDU_GLOBAL_f32areg64,
                                 NVPTX::INT_PTX_LDU_GLOBAL_f64areg64);
        break;
      case NVPTXISD::LoadV2:
        Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
                                 NVPTX::INT_PTX_LDG_G_v2i8_ELE_areg64,
                                 NVPTX::INT_PTX_LDG_G_v2i16_ELE_areg64,
                                 NVPTX::INT_PTX_LDG_G_v2i32_ELE_areg64,
                                 NVPTX::INT_PTX_LDG_G_v2i64_ELE_areg64,
                                 NVPTX::INT_PTX_LDG_G_v2f32_ELE_areg64,
                                 NVPTX::INT_PTX_LDG_G_v2f64_ELE_areg64);
        break;
      case NVPTXISD::LDUV2:
        Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
                                 NVPTX::INT_PTX_LDU_G_v2i8_ELE_areg64,
                                 NVPTX::INT_PTX_LDU_G_v2i16_ELE_areg64,
                                 NVPTX::INT_PTX_LDU_G_v2i32_ELE_areg64,
                                 NVPTX::INT_PTX_LDU_G_v2i64_ELE_areg64,
                                 NVPTX::INT_PTX_LDU_G_v2f32_ELE_areg64,
                                 NVPTX::INT_PTX_LDU_G_v2f64_ELE_areg64);
        break;
      case NVPTXISD::LoadV4:
        Opcode = pickOpcodeForVT(
            EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v4i8_ELE_areg64,
            NVPTX::INT_PTX_LDG_G_v4i16_ELE_areg64,
            NVPTX::INT_PTX_LDG_G_v4i32_ELE_areg64, std::nullopt,
            NVPTX::INT_PTX_LDG_G_v4f32_ELE_areg64, std::nullopt);
        break;
      case NVPTXISD::LDUV4:
        Opcode = pickOpcodeForVT(
            EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v4i8_ELE_areg64,
            NVPTX::INT_PTX_LDU_G_v4i16_ELE_areg64,
            NVPTX::INT_PTX_LDU_G_v4i32_ELE_areg64, std::nullopt,
            NVPTX::INT_PTX_LDU_G_v4f32_ELE_areg64, std::nullopt);
        break;
      }
    } else {
      switch (N->getOpcode()) {
      default:
        return false;
      case ISD::LOAD:
        Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
                                 NVPTX::INT_PTX_LDG_GLOBAL_i8areg,
                                 NVPTX::INT_PTX_LDG_GLOBAL_i16areg,
                                 NVPTX::INT_PTX_LDG_GLOBAL_i32areg,
                                 NVPTX::INT_PTX_LDG_GLOBAL_i64areg,
                                 NVPTX::INT_PTX_LDG_GLOBAL_f32areg,
                                 NVPTX::INT_PTX_LDG_GLOBAL_f64areg);
        break;
      case ISD::INTRINSIC_W_CHAIN:
        Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
                                 NVPTX::INT_PTX_LDU_GLOBAL_i8areg,
                                 NVPTX::INT_PTX_LDU_GLOBAL_i16areg,
                                 NVPTX::INT_PTX_LDU_GLOBAL_i32areg,
                                 NVPTX::INT_PTX_LDU_GLOBAL_i64areg,
                                 NVPTX::INT_PTX_LDU_GLOBAL_f32areg,
                                 NVPTX::INT_PTX_LDU_GLOBAL_f64areg);
        break;
      case NVPTXISD::LoadV2:
        Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
                                 NVPTX::INT_PTX_LDG_G_v2i8_ELE_areg32,
                                 NVPTX::INT_PTX_LDG_G_v2i16_ELE_areg32,
                                 NVPTX::INT_PTX_LDG_G_v2i32_ELE_areg32,
                                 NVPTX::INT_PTX_LDG_G_v2i64_ELE_areg32,
                                 NVPTX::INT_PTX_LDG_G_v2f32_ELE_areg32,
                                 NVPTX::INT_PTX_LDG_G_v2f64_ELE_areg32);
        break;
      case NVPTXISD::LDUV2:
        Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
                                 NVPTX::INT_PTX_LDU_G_v2i8_ELE_areg32,
                                 NVPTX::INT_PTX_LDU_G_v2i16_ELE_areg32,
                                 NVPTX::INT_PTX_LDU_G_v2i32_ELE_areg32,
                                 NVPTX::INT_PTX_LDU_G_v2i64_ELE_areg32,
                                 NVPTX::INT_PTX_LDU_G_v2f32_ELE_areg32,
                                 NVPTX::INT_PTX_LDU_G_v2f64_ELE_areg32);
        break;
      case NVPTXISD::LoadV4:
        Opcode = pickOpcodeForVT(
            EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v4i8_ELE_areg32,
            NVPTX::INT_PTX_LDG_G_v4i16_ELE_areg32,
            NVPTX::INT_PTX_LDG_G_v4i32_ELE_areg32, std::nullopt,
            NVPTX::INT_PTX_LDG_G_v4f32_ELE_areg32, std::nullopt);
        break;
      case NVPTXISD::LDUV4:
        Opcode = pickOpcodeForVT(
            EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v4i8_ELE_areg32,
            NVPTX::INT_PTX_LDU_G_v4i16_ELE_areg32,
            NVPTX::INT_PTX_LDU_G_v4i32_ELE_areg32, std::nullopt,
            NVPTX::INT_PTX_LDU_G_v4f32_ELE_areg32, std::nullopt);
        break;
      }
    }
    if (!Opcode)
      return false;
    SDValue Ops[] = {Op1, Chain};
    LD = CurDAG->getMachineNode(*Opcode, DL, InstVTList, Ops);
  }

  // For automatic generation of LDG (through SelectLoad[Vector], not the
  // intrinsics), we may have an extending load like:
  //
  //   i32,ch = load<LD1[%data1(addrspace=1)], zext from i8> t0, t7, undef:i64
  //
  // In this case, the matching logic above will select a load for the original
  // memory type (in this case, i8) and our types will not match (the node needs
  // to return an i32 in this case). Our LDG/LDU nodes do not support the
  // concept of sign-/zero-extension, so emulate it here by adding an explicit
  // CVT instruction. Ptxas should clean up any redundancies here.

  LoadSDNode *LdNode = dyn_cast<LoadSDNode>(N);

  if (OrigType != EltVT &&
      (LdNode || (OrigType.isFloatingPoint() && EltVT.isFloatingPoint()))) {
    // We have an extending-load. The instruction we selected operates on the
    // smaller type, but the SDNode we are replacing has the larger type. We
    // need to emit a CVT to make the types match.
    unsigned CvtOpc =
        GetConvertOpcode(OrigType.getSimpleVT(), EltVT.getSimpleVT(), LdNode);

    // For each output value, apply the manual sign/zero-extension and make sure
    // all users of the load go through that CVT.
    for (unsigned i = 0; i != NumElts; ++i) {
      SDValue Res(LD, i);
      SDValue OrigVal(N, i);

      SDNode *CvtNode =
          CurDAG->getMachineNode(CvtOpc, DL, OrigType, Res,
                                 CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE,
                                                           DL, MVT::i32));
      ReplaceUses(OrigVal, SDValue(CvtNode, 0));
    }
  }

  ReplaceNode(N, LD);
  return true;
}

bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
  MemSDNode *ST = cast<MemSDNode>(N);
  assert(ST->writeMem() && "Expected store");
  StoreSDNode *PlainStore = dyn_cast<StoreSDNode>(N);
  AtomicSDNode *AtomicStore = dyn_cast<AtomicSDNode>(N);
  assert((PlainStore || AtomicStore) && "Expected store");

  // do not support pre/post inc/dec
  if (PlainStore && PlainStore->isIndexed())
    return false;

  EVT StoreVT = ST->getMemoryVT();
  if (!StoreVT.isSimple())
    return false;

  // Address Space Setting
  unsigned int CodeAddrSpace = getCodeAddrSpace(ST);
  unsigned int PointerSize =
      CurDAG->getDataLayout().getPointerSizeInBits(ST->getAddressSpace());

  SDLoc DL(N);
  SDValue Chain = ST->getChain();
  auto [Ordering, Scope] = insertMemoryInstructionFence(DL, Chain, ST);

  // Vector Setting
  MVT SimpleVT = StoreVT.getSimpleVT();
  unsigned VecType = NVPTX::PTXLdStInstCode::Scalar;

  // Type Setting: toType + toTypeWidth
  // - for integer type, always use 'u'
  MVT ScalarVT = SimpleVT.getScalarType();
  unsigned ToTypeWidth = ScalarVT.getSizeInBits();
  if (SimpleVT.isVector()) {
    assert((Isv2x16VT(StoreVT) || StoreVT == MVT::v4i8) &&
           "Unexpected vector type");
    // v2x16 is stored using st.b32
    ToTypeWidth = 32;
  }

  unsigned int ToType = getLdStRegType(ScalarVT);

  // Create the machine instruction DAG
  SDValue Value = PlainStore ? PlainStore->getValue() : AtomicStore->getVal();
  SDValue BasePtr = ST->getBasePtr();
  SDValue Addr;
  SDValue Offset, Base;
  std::optional<unsigned> Opcode;
  MVT::SimpleValueType SourceVT =
      Value.getNode()->getSimpleValueType(0).SimpleTy;

  SmallVector<SDValue, 12> Ops(
      {Value, getI32Imm(Ordering, DL), getI32Imm(Scope, DL),
       getI32Imm(CodeAddrSpace, DL), getI32Imm(VecType, DL),
       getI32Imm(ToType, DL), getI32Imm(ToTypeWidth, DL)});

  if (SelectDirectAddr(BasePtr, Addr)) {
    Opcode = pickOpcodeForVT(SourceVT, NVPTX::ST_i8_avar, NVPTX::ST_i16_avar,
                             NVPTX::ST_i32_avar, NVPTX::ST_i64_avar,
                             NVPTX::ST_f32_avar, NVPTX::ST_f64_avar);
    if (!Opcode)
      return false;
    Ops.append({Addr, Chain});
  } else if (PointerSize == 64
                 ? SelectADDRsi64(BasePtr.getNode(), BasePtr, Base, Offset)
                 : SelectADDRsi(BasePtr.getNode(), BasePtr, Base, Offset)) {
    Opcode = pickOpcodeForVT(SourceVT, NVPTX::ST_i8_asi, NVPTX::ST_i16_asi,
                             NVPTX::ST_i32_asi, NVPTX::ST_i64_asi,
                             NVPTX::ST_f32_asi, NVPTX::ST_f64_asi);
    if (!Opcode)
      return false;
    Ops.append({Base, Offset, Chain});
  } else if (PointerSize == 64
SelectADDRri64(BasePtr.getNode(), BasePtr, Base, Offset) 1650 : SelectADDRri(BasePtr.getNode(), BasePtr, Base, Offset)) { 1651 if (PointerSize == 64) 1652 Opcode = 1653 pickOpcodeForVT(SourceVT, NVPTX::ST_i8_ari_64, NVPTX::ST_i16_ari_64, 1654 NVPTX::ST_i32_ari_64, NVPTX::ST_i64_ari_64, 1655 NVPTX::ST_f32_ari_64, NVPTX::ST_f64_ari_64); 1656 else 1657 Opcode = pickOpcodeForVT(SourceVT, NVPTX::ST_i8_ari, NVPTX::ST_i16_ari, 1658 NVPTX::ST_i32_ari, NVPTX::ST_i64_ari, 1659 NVPTX::ST_f32_ari, NVPTX::ST_f64_ari); 1660 if (!Opcode) 1661 return false; 1662 Ops.append({Base, Offset, Chain}); 1663 } else { 1664 if (PointerSize == 64) 1665 Opcode = 1666 pickOpcodeForVT(SourceVT, NVPTX::ST_i8_areg_64, NVPTX::ST_i16_areg_64, 1667 NVPTX::ST_i32_areg_64, NVPTX::ST_i64_areg_64, 1668 NVPTX::ST_f32_areg_64, NVPTX::ST_f64_areg_64); 1669 else 1670 Opcode = pickOpcodeForVT(SourceVT, NVPTX::ST_i8_areg, NVPTX::ST_i16_areg, 1671 NVPTX::ST_i32_areg, NVPTX::ST_i64_areg, 1672 NVPTX::ST_f32_areg, NVPTX::ST_f64_areg); 1673 if (!Opcode) 1674 return false; 1675 Ops.append({BasePtr, Chain}); 1676 } 1677 1678 SDNode *NVPTXST = CurDAG->getMachineNode(*Opcode, DL, MVT::Other, Ops); 1679 1680 if (!NVPTXST) 1681 return false; 1682 1683 MachineMemOperand *MemRef = cast<MemSDNode>(N)->getMemOperand(); 1684 CurDAG->setNodeMemRefs(cast<MachineSDNode>(NVPTXST), {MemRef}); 1685 ReplaceNode(N, NVPTXST); 1686 return true; 1687 } 1688 1689 bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) { 1690 SDValue Op1 = N->getOperand(1); 1691 SDValue Addr, Offset, Base; 1692 std::optional<unsigned> Opcode; 1693 SDNode *ST; 1694 EVT EltVT = Op1.getValueType(); 1695 MemSDNode *MemSD = cast<MemSDNode>(N); 1696 EVT StoreVT = MemSD->getMemoryVT(); 1697 1698 // Address Space Setting 1699 unsigned CodeAddrSpace = getCodeAddrSpace(MemSD); 1700 if (CodeAddrSpace == NVPTX::AddressSpace::Const) { 1701 report_fatal_error("Cannot store to pointer that points to constant " 1702 "memory space"); 1703 } 1704 unsigned int PointerSize = 1705 CurDAG->getDataLayout().getPointerSizeInBits(MemSD->getAddressSpace()); 1706 1707 SDLoc DL(N); 1708 SDValue Chain = N->getOperand(0); 1709 auto [Ordering, Scope] = insertMemoryInstructionFence(DL, Chain, MemSD); 1710 1711 // Type Setting: toType + toTypeWidth 1712 // - for integer type, always use 'u' 1713 assert(StoreVT.isSimple() && "Store value is not simple"); 1714 MVT ScalarVT = StoreVT.getSimpleVT().getScalarType(); 1715 unsigned ToTypeWidth = ScalarVT.getSizeInBits(); 1716 unsigned ToType = getLdStRegType(ScalarVT); 1717 1718 SmallVector<SDValue, 12> Ops; 1719 SDValue N2; 1720 unsigned VecType; 1721 1722 switch (N->getOpcode()) { 1723 case NVPTXISD::StoreV2: 1724 VecType = NVPTX::PTXLdStInstCode::V2; 1725 Ops.append({N->getOperand(1), N->getOperand(2)}); 1726 N2 = N->getOperand(3); 1727 break; 1728 case NVPTXISD::StoreV4: 1729 VecType = NVPTX::PTXLdStInstCode::V4; 1730 Ops.append({N->getOperand(1), N->getOperand(2), N->getOperand(3), 1731 N->getOperand(4)}); 1732 N2 = N->getOperand(5); 1733 break; 1734 default: 1735 return false; 1736 } 1737 1738 if (isVectorElementTypeUpsized(EltVT)) { 1739 EltVT = MVT::i32; 1740 ToType = NVPTX::PTXLdStInstCode::Untyped; 1741 ToTypeWidth = 32; 1742 } 1743 1744 Ops.append({getI32Imm(Ordering, DL), getI32Imm(Scope, DL), 1745 getI32Imm(CodeAddrSpace, DL), getI32Imm(VecType, DL), 1746 getI32Imm(ToType, DL), getI32Imm(ToTypeWidth, DL)}); 1747 1748 if (SelectDirectAddr(N2, Addr)) { 1749 switch (N->getOpcode()) { 1750 default: 1751 return false; 1752 case NVPTXISD::StoreV2: 1753 Opcode = 
pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, 1754 NVPTX::STV_i8_v2_avar, NVPTX::STV_i16_v2_avar, 1755 NVPTX::STV_i32_v2_avar, NVPTX::STV_i64_v2_avar, 1756 NVPTX::STV_f32_v2_avar, NVPTX::STV_f64_v2_avar); 1757 break; 1758 case NVPTXISD::StoreV4: 1759 Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, 1760 NVPTX::STV_i8_v4_avar, NVPTX::STV_i16_v4_avar, 1761 NVPTX::STV_i32_v4_avar, std::nullopt, 1762 NVPTX::STV_f32_v4_avar, std::nullopt); 1763 break; 1764 } 1765 Ops.push_back(Addr); 1766 } else if (PointerSize == 64 ? SelectADDRsi64(N2.getNode(), N2, Base, Offset) 1767 : SelectADDRsi(N2.getNode(), N2, Base, Offset)) { 1768 switch (N->getOpcode()) { 1769 default: 1770 return false; 1771 case NVPTXISD::StoreV2: 1772 Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, 1773 NVPTX::STV_i8_v2_asi, NVPTX::STV_i16_v2_asi, 1774 NVPTX::STV_i32_v2_asi, NVPTX::STV_i64_v2_asi, 1775 NVPTX::STV_f32_v2_asi, NVPTX::STV_f64_v2_asi); 1776 break; 1777 case NVPTXISD::StoreV4: 1778 Opcode = 1779 pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::STV_i8_v4_asi, 1780 NVPTX::STV_i16_v4_asi, NVPTX::STV_i32_v4_asi, 1781 std::nullopt, NVPTX::STV_f32_v4_asi, std::nullopt); 1782 break; 1783 } 1784 Ops.append({Base, Offset}); 1785 } else if (PointerSize == 64 ? SelectADDRri64(N2.getNode(), N2, Base, Offset) 1786 : SelectADDRri(N2.getNode(), N2, Base, Offset)) { 1787 if (PointerSize == 64) { 1788 switch (N->getOpcode()) { 1789 default: 1790 return false; 1791 case NVPTXISD::StoreV2: 1792 Opcode = 1793 pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, 1794 NVPTX::STV_i8_v2_ari_64, NVPTX::STV_i16_v2_ari_64, 1795 NVPTX::STV_i32_v2_ari_64, NVPTX::STV_i64_v2_ari_64, 1796 NVPTX::STV_f32_v2_ari_64, NVPTX::STV_f64_v2_ari_64); 1797 break; 1798 case NVPTXISD::StoreV4: 1799 Opcode = pickOpcodeForVT( 1800 EltVT.getSimpleVT().SimpleTy, NVPTX::STV_i8_v4_ari_64, 1801 NVPTX::STV_i16_v4_ari_64, NVPTX::STV_i32_v4_ari_64, std::nullopt, 1802 NVPTX::STV_f32_v4_ari_64, std::nullopt); 1803 break; 1804 } 1805 } else { 1806 switch (N->getOpcode()) { 1807 default: 1808 return false; 1809 case NVPTXISD::StoreV2: 1810 Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, 1811 NVPTX::STV_i8_v2_ari, NVPTX::STV_i16_v2_ari, 1812 NVPTX::STV_i32_v2_ari, NVPTX::STV_i64_v2_ari, 1813 NVPTX::STV_f32_v2_ari, NVPTX::STV_f64_v2_ari); 1814 break; 1815 case NVPTXISD::StoreV4: 1816 Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, 1817 NVPTX::STV_i8_v4_ari, NVPTX::STV_i16_v4_ari, 1818 NVPTX::STV_i32_v4_ari, std::nullopt, 1819 NVPTX::STV_f32_v4_ari, std::nullopt); 1820 break; 1821 } 1822 } 1823 Ops.append({Base, Offset}); 1824 } else { 1825 if (PointerSize == 64) { 1826 switch (N->getOpcode()) { 1827 default: 1828 return false; 1829 case NVPTXISD::StoreV2: 1830 Opcode = pickOpcodeForVT( 1831 EltVT.getSimpleVT().SimpleTy, NVPTX::STV_i8_v2_areg_64, 1832 NVPTX::STV_i16_v2_areg_64, NVPTX::STV_i32_v2_areg_64, 1833 NVPTX::STV_i64_v2_areg_64, NVPTX::STV_f32_v2_areg_64, 1834 NVPTX::STV_f64_v2_areg_64); 1835 break; 1836 case NVPTXISD::StoreV4: 1837 Opcode = pickOpcodeForVT( 1838 EltVT.getSimpleVT().SimpleTy, NVPTX::STV_i8_v4_areg_64, 1839 NVPTX::STV_i16_v4_areg_64, NVPTX::STV_i32_v4_areg_64, std::nullopt, 1840 NVPTX::STV_f32_v4_areg_64, std::nullopt); 1841 break; 1842 } 1843 } else { 1844 switch (N->getOpcode()) { 1845 default: 1846 return false; 1847 case NVPTXISD::StoreV2: 1848 Opcode = 1849 pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::STV_i8_v2_areg, 1850 NVPTX::STV_i16_v2_areg, NVPTX::STV_i32_v2_areg, 1851 NVPTX::STV_i64_v2_areg, NVPTX::STV_f32_v2_areg, 
1852 NVPTX::STV_f64_v2_areg); 1853 break; 1854 case NVPTXISD::StoreV4: 1855 Opcode = 1856 pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::STV_i8_v4_areg, 1857 NVPTX::STV_i16_v4_areg, NVPTX::STV_i32_v4_areg, 1858 std::nullopt, NVPTX::STV_f32_v4_areg, std::nullopt); 1859 break; 1860 } 1861 } 1862 Ops.push_back(N2); 1863 } 1864 1865 if (!Opcode) 1866 return false; 1867 1868 Ops.push_back(Chain); 1869 1870 ST = CurDAG->getMachineNode(*Opcode, DL, MVT::Other, Ops); 1871 1872 MachineMemOperand *MemRef = cast<MemSDNode>(N)->getMemOperand(); 1873 CurDAG->setNodeMemRefs(cast<MachineSDNode>(ST), {MemRef}); 1874 1875 ReplaceNode(N, ST); 1876 return true; 1877 } 1878 1879 bool NVPTXDAGToDAGISel::tryLoadParam(SDNode *Node) { 1880 SDValue Chain = Node->getOperand(0); 1881 SDValue Offset = Node->getOperand(2); 1882 SDValue Glue = Node->getOperand(3); 1883 SDLoc DL(Node); 1884 MemSDNode *Mem = cast<MemSDNode>(Node); 1885 1886 unsigned VecSize; 1887 switch (Node->getOpcode()) { 1888 default: 1889 return false; 1890 case NVPTXISD::LoadParam: 1891 VecSize = 1; 1892 break; 1893 case NVPTXISD::LoadParamV2: 1894 VecSize = 2; 1895 break; 1896 case NVPTXISD::LoadParamV4: 1897 VecSize = 4; 1898 break; 1899 } 1900 1901 EVT EltVT = Node->getValueType(0); 1902 EVT MemVT = Mem->getMemoryVT(); 1903 1904 std::optional<unsigned> Opcode; 1905 1906 switch (VecSize) { 1907 default: 1908 return false; 1909 case 1: 1910 Opcode = pickOpcodeForVT(MemVT.getSimpleVT().SimpleTy, 1911 NVPTX::LoadParamMemI8, NVPTX::LoadParamMemI16, 1912 NVPTX::LoadParamMemI32, NVPTX::LoadParamMemI64, 1913 NVPTX::LoadParamMemF32, NVPTX::LoadParamMemF64); 1914 break; 1915 case 2: 1916 Opcode = 1917 pickOpcodeForVT(MemVT.getSimpleVT().SimpleTy, NVPTX::LoadParamMemV2I8, 1918 NVPTX::LoadParamMemV2I16, NVPTX::LoadParamMemV2I32, 1919 NVPTX::LoadParamMemV2I64, NVPTX::LoadParamMemV2F32, 1920 NVPTX::LoadParamMemV2F64); 1921 break; 1922 case 4: 1923 Opcode = 1924 pickOpcodeForVT(MemVT.getSimpleVT().SimpleTy, NVPTX::LoadParamMemV4I8, 1925 NVPTX::LoadParamMemV4I16, NVPTX::LoadParamMemV4I32, 1926 std::nullopt, NVPTX::LoadParamMemV4F32, std::nullopt); 1927 break; 1928 } 1929 if (!Opcode) 1930 return false; 1931 1932 SDVTList VTs; 1933 if (VecSize == 1) { 1934 VTs = CurDAG->getVTList(EltVT, MVT::Other, MVT::Glue); 1935 } else if (VecSize == 2) { 1936 VTs = CurDAG->getVTList(EltVT, EltVT, MVT::Other, MVT::Glue); 1937 } else { 1938 EVT EVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other, MVT::Glue }; 1939 VTs = CurDAG->getVTList(EVTs); 1940 } 1941 1942 unsigned OffsetVal = Offset->getAsZExtVal(); 1943 1944 SmallVector<SDValue, 2> Ops( 1945 {CurDAG->getTargetConstant(OffsetVal, DL, MVT::i32), Chain, Glue}); 1946 1947 ReplaceNode(Node, CurDAG->getMachineNode(*Opcode, DL, VTs, Ops)); 1948 return true; 1949 } 1950 1951 bool NVPTXDAGToDAGISel::tryStoreRetval(SDNode *N) { 1952 SDLoc DL(N); 1953 SDValue Chain = N->getOperand(0); 1954 SDValue Offset = N->getOperand(1); 1955 unsigned OffsetVal = Offset->getAsZExtVal(); 1956 MemSDNode *Mem = cast<MemSDNode>(N); 1957 1958 // How many elements do we have? 
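  // Hedged sketch of the eventual PTX (mnemonics assumed, not taken from this
  // file): a scalar i32 StoreRetval at offset 0 prints roughly as
  //   st.param.b32 [func_retval0+0], %r;
  // and the V2/V4 variants emit a single vector st.param for 2 or 4 elements.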
1959 unsigned NumElts = 1; 1960 switch (N->getOpcode()) { 1961 default: 1962 return false; 1963 case NVPTXISD::StoreRetval: 1964 NumElts = 1; 1965 break; 1966 case NVPTXISD::StoreRetvalV2: 1967 NumElts = 2; 1968 break; 1969 case NVPTXISD::StoreRetvalV4: 1970 NumElts = 4; 1971 break; 1972 } 1973 1974 // Build vector of operands 1975 SmallVector<SDValue, 6> Ops; 1976 for (unsigned i = 0; i < NumElts; ++i) 1977 Ops.push_back(N->getOperand(i + 2)); 1978 Ops.append({CurDAG->getTargetConstant(OffsetVal, DL, MVT::i32), Chain}); 1979 1980 // Determine target opcode 1981 // If we have an i1, use an 8-bit store. The lowering code in 1982 // NVPTXISelLowering will have already emitted an upcast. 1983 std::optional<unsigned> Opcode = 0; 1984 switch (NumElts) { 1985 default: 1986 return false; 1987 case 1: 1988 Opcode = pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy, 1989 NVPTX::StoreRetvalI8, NVPTX::StoreRetvalI16, 1990 NVPTX::StoreRetvalI32, NVPTX::StoreRetvalI64, 1991 NVPTX::StoreRetvalF32, NVPTX::StoreRetvalF64); 1992 if (Opcode == NVPTX::StoreRetvalI8) { 1993 // Fine tune the opcode depending on the size of the operand. 1994 // This helps to avoid creating redundant COPY instructions in 1995 // InstrEmitter::AddRegisterOperand(). 1996 switch (Ops[0].getSimpleValueType().SimpleTy) { 1997 default: 1998 break; 1999 case MVT::i32: 2000 Opcode = NVPTX::StoreRetvalI8TruncI32; 2001 break; 2002 case MVT::i64: 2003 Opcode = NVPTX::StoreRetvalI8TruncI64; 2004 break; 2005 } 2006 } 2007 break; 2008 case 2: 2009 Opcode = pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy, 2010 NVPTX::StoreRetvalV2I8, NVPTX::StoreRetvalV2I16, 2011 NVPTX::StoreRetvalV2I32, NVPTX::StoreRetvalV2I64, 2012 NVPTX::StoreRetvalV2F32, NVPTX::StoreRetvalV2F64); 2013 break; 2014 case 4: 2015 Opcode = pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy, 2016 NVPTX::StoreRetvalV4I8, NVPTX::StoreRetvalV4I16, 2017 NVPTX::StoreRetvalV4I32, std::nullopt, 2018 NVPTX::StoreRetvalV4F32, std::nullopt); 2019 break; 2020 } 2021 if (!Opcode) 2022 return false; 2023 2024 SDNode *Ret = CurDAG->getMachineNode(*Opcode, DL, MVT::Other, Ops); 2025 MachineMemOperand *MemRef = cast<MemSDNode>(N)->getMemOperand(); 2026 CurDAG->setNodeMemRefs(cast<MachineSDNode>(Ret), {MemRef}); 2027 2028 ReplaceNode(N, Ret); 2029 return true; 2030 } 2031 2032 // Helpers for constructing opcode (ex: NVPTX::StoreParamV4F32_iiri) 2033 #define getOpcV2H(ty, opKind0, opKind1) \ 2034 NVPTX::StoreParamV2##ty##_##opKind0##opKind1 2035 2036 #define getOpcV2H1(ty, opKind0, isImm1) \ 2037 (isImm1) ? getOpcV2H(ty, opKind0, i) : getOpcV2H(ty, opKind0, r) 2038 2039 #define getOpcodeForVectorStParamV2(ty, isimm) \ 2040 (isimm[0]) ? getOpcV2H1(ty, i, isimm[1]) : getOpcV2H1(ty, r, isimm[1]) 2041 2042 #define getOpcV4H(ty, opKind0, opKind1, opKind2, opKind3) \ 2043 NVPTX::StoreParamV4##ty##_##opKind0##opKind1##opKind2##opKind3 2044 2045 #define getOpcV4H3(ty, opKind0, opKind1, opKind2, isImm3) \ 2046 (isImm3) ? getOpcV4H(ty, opKind0, opKind1, opKind2, i) \ 2047 : getOpcV4H(ty, opKind0, opKind1, opKind2, r) 2048 2049 #define getOpcV4H2(ty, opKind0, opKind1, isImm2, isImm3) \ 2050 (isImm2) ? getOpcV4H3(ty, opKind0, opKind1, i, isImm3) \ 2051 : getOpcV4H3(ty, opKind0, opKind1, r, isImm3) 2052 2053 #define getOpcV4H1(ty, opKind0, isImm1, isImm2, isImm3) \ 2054 (isImm1) ? getOpcV4H2(ty, opKind0, i, isImm2, isImm3) \ 2055 : getOpcV4H2(ty, opKind0, r, isImm2, isImm3) 2056 2057 #define getOpcodeForVectorStParamV4(ty, isimm) \ 2058 (isimm[0]) ? 
getOpcV4H1(ty, i, isimm[1], isimm[2], isimm[3])                                \
             : getOpcV4H1(ty, r, isimm[1], isimm[2], isimm[3])

#define getOpcodeForVectorStParam(n, ty, isimm)                                \
  (n == 2) ? getOpcodeForVectorStParamV2(ty, isimm)                            \
           : getOpcodeForVectorStParamV4(ty, isimm)

static unsigned pickOpcodeForVectorStParam(SmallVector<SDValue, 8> &Ops,
                                           unsigned NumElts,
                                           MVT::SimpleValueType MemTy,
                                           SelectionDAG *CurDAG, SDLoc DL) {
  // Determine which inputs are registers and which are immediates, and
  // rebuild the immediate operands as target constants.
  SmallVector<bool, 4> IsImm(NumElts, false);
  for (unsigned i = 0; i < NumElts; i++) {
    IsImm[i] = (isa<ConstantSDNode>(Ops[i]) || isa<ConstantFPSDNode>(Ops[i]));
    if (IsImm[i]) {
      SDValue Imm = Ops[i];
      if (MemTy == MVT::f32 || MemTy == MVT::f64) {
        const ConstantFPSDNode *ConstImm = cast<ConstantFPSDNode>(Imm);
        const ConstantFP *CF = ConstImm->getConstantFPValue();
        Imm = CurDAG->getTargetConstantFP(*CF, DL, Imm->getValueType(0));
      } else {
        const ConstantSDNode *ConstImm = cast<ConstantSDNode>(Imm);
        const ConstantInt *CI = ConstImm->getConstantIntValue();
        Imm = CurDAG->getTargetConstant(*CI, DL, Imm->getValueType(0));
      }
      Ops[i] = Imm;
    }
  }

  // Get opcode for MemTy, size, and register/immediate operand ordering
  switch (MemTy) {
  case MVT::i8:
    return getOpcodeForVectorStParam(NumElts, I8, IsImm);
  case MVT::i16:
    return getOpcodeForVectorStParam(NumElts, I16, IsImm);
  case MVT::i32:
    return getOpcodeForVectorStParam(NumElts, I32, IsImm);
  case MVT::i64:
    assert(NumElts == 2 && "MVT too large for NumElts > 2");
    return getOpcodeForVectorStParamV2(I64, IsImm);
  case MVT::f32:
    return getOpcodeForVectorStParam(NumElts, F32, IsImm);
  case MVT::f64:
    assert(NumElts == 2 && "MVT too large for NumElts > 2");
    return getOpcodeForVectorStParamV2(F64, IsImm);

  // These cases don't support immediates; just use the all-register version
  // and generate moves.
  case MVT::i1:
    return (NumElts == 2) ? NVPTX::StoreParamV2I8_rr
                          : NVPTX::StoreParamV4I8_rrrr;
  case MVT::f16:
  case MVT::bf16:
    return (NumElts == 2) ? NVPTX::StoreParamV2I16_rr
                          : NVPTX::StoreParamV4I16_rrrr;
  case MVT::v2f16:
  case MVT::v2bf16:
  case MVT::v2i16:
  case MVT::v4i8:
    return (NumElts == 2) ? NVPTX::StoreParamV2I32_rr
                          : NVPTX::StoreParamV4I32_rrrr;
  default:
    llvm_unreachable("Cannot select st.param for unknown MemTy");
  }
}

bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
  SDLoc DL(N);
  SDValue Chain = N->getOperand(0);
  SDValue Param = N->getOperand(1);
  unsigned ParamVal = Param->getAsZExtVal();
  SDValue Offset = N->getOperand(2);
  unsigned OffsetVal = Offset->getAsZExtVal();
  MemSDNode *Mem = cast<MemSDNode>(N);
  SDValue Glue = N->getOperand(N->getNumOperands() - 1);

  // How many elements do we have?
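  // (A hedged sketch of the eventual PTX for the common scalar i32 case,
  // mnemonics assumed rather than taken from this file:
  //   st.param.b32 [param0+0], %r;
  // ParamVal and OffsetVal above select the parameter symbol and byte
  // offset.)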
2137 unsigned NumElts; 2138 switch (N->getOpcode()) { 2139 default: 2140 llvm_unreachable("Unexpected opcode"); 2141 case NVPTXISD::StoreParamU32: 2142 case NVPTXISD::StoreParamS32: 2143 case NVPTXISD::StoreParam: 2144 NumElts = 1; 2145 break; 2146 case NVPTXISD::StoreParamV2: 2147 NumElts = 2; 2148 break; 2149 case NVPTXISD::StoreParamV4: 2150 NumElts = 4; 2151 break; 2152 } 2153 2154 // Build vector of operands 2155 SmallVector<SDValue, 8> Ops; 2156 for (unsigned i = 0; i < NumElts; ++i) 2157 Ops.push_back(N->getOperand(i + 3)); 2158 Ops.append({CurDAG->getTargetConstant(ParamVal, DL, MVT::i32), 2159 CurDAG->getTargetConstant(OffsetVal, DL, MVT::i32), Chain, Glue}); 2160 2161 // Determine target opcode 2162 // If we have an i1, use an 8-bit store. The lowering code in 2163 // NVPTXISelLowering will have already emitted an upcast. 2164 std::optional<unsigned> Opcode; 2165 switch (N->getOpcode()) { 2166 default: 2167 switch (NumElts) { 2168 default: 2169 llvm_unreachable("Unexpected NumElts"); 2170 case 1: { 2171 MVT::SimpleValueType MemTy = Mem->getMemoryVT().getSimpleVT().SimpleTy; 2172 SDValue Imm = Ops[0]; 2173 if (MemTy != MVT::f16 && MemTy != MVT::v2f16 && 2174 (isa<ConstantSDNode>(Imm) || isa<ConstantFPSDNode>(Imm))) { 2175 // Convert immediate to target constant 2176 if (MemTy == MVT::f32 || MemTy == MVT::f64) { 2177 const ConstantFPSDNode *ConstImm = cast<ConstantFPSDNode>(Imm); 2178 const ConstantFP *CF = ConstImm->getConstantFPValue(); 2179 Imm = CurDAG->getTargetConstantFP(*CF, DL, Imm->getValueType(0)); 2180 } else { 2181 const ConstantSDNode *ConstImm = cast<ConstantSDNode>(Imm); 2182 const ConstantInt *CI = ConstImm->getConstantIntValue(); 2183 Imm = CurDAG->getTargetConstant(*CI, DL, Imm->getValueType(0)); 2184 } 2185 Ops[0] = Imm; 2186 // Use immediate version of store param 2187 Opcode = pickOpcodeForVT(MemTy, NVPTX::StoreParamI8_i, 2188 NVPTX::StoreParamI16_i, NVPTX::StoreParamI32_i, 2189 NVPTX::StoreParamI64_i, NVPTX::StoreParamF32_i, 2190 NVPTX::StoreParamF64_i); 2191 } else 2192 Opcode = 2193 pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy, 2194 NVPTX::StoreParamI8_r, NVPTX::StoreParamI16_r, 2195 NVPTX::StoreParamI32_r, NVPTX::StoreParamI64_r, 2196 NVPTX::StoreParamF32_r, NVPTX::StoreParamF64_r); 2197 if (Opcode == NVPTX::StoreParamI8_r) { 2198 // Fine tune the opcode depending on the size of the operand. 2199 // This helps to avoid creating redundant COPY instructions in 2200 // InstrEmitter::AddRegisterOperand(). 2201 switch (Ops[0].getSimpleValueType().SimpleTy) { 2202 default: 2203 break; 2204 case MVT::i32: 2205 Opcode = NVPTX::StoreParamI8TruncI32_r; 2206 break; 2207 case MVT::i64: 2208 Opcode = NVPTX::StoreParamI8TruncI64_r; 2209 break; 2210 } 2211 } 2212 break; 2213 } 2214 case 2: 2215 case 4: { 2216 MVT::SimpleValueType MemTy = Mem->getMemoryVT().getSimpleVT().SimpleTy; 2217 Opcode = pickOpcodeForVectorStParam(Ops, NumElts, MemTy, CurDAG, DL); 2218 break; 2219 } 2220 } 2221 break; 2222 // Special case: if we have a sign-extend/zero-extend node, insert the 2223 // conversion instruction first, and use that as the value operand to 2224 // the selected StoreParam node. 
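  // For example (grounded in the two cases below), StoreParamU32 of an i16
  // value first widens it with CVT_u32_u16 in the NONE conversion mode and
  // then stores the result with the plain 32-bit StoreParamI32_r opcode;
  // StoreParamS32 does the same with the signed CVT_s32_s16.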
2225 case NVPTXISD::StoreParamU32: { 2226 Opcode = NVPTX::StoreParamI32_r; 2227 SDValue CvtNone = CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL, 2228 MVT::i32); 2229 SDNode *Cvt = CurDAG->getMachineNode(NVPTX::CVT_u32_u16, DL, 2230 MVT::i32, Ops[0], CvtNone); 2231 Ops[0] = SDValue(Cvt, 0); 2232 break; 2233 } 2234 case NVPTXISD::StoreParamS32: { 2235 Opcode = NVPTX::StoreParamI32_r; 2236 SDValue CvtNone = CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL, 2237 MVT::i32); 2238 SDNode *Cvt = CurDAG->getMachineNode(NVPTX::CVT_s32_s16, DL, 2239 MVT::i32, Ops[0], CvtNone); 2240 Ops[0] = SDValue(Cvt, 0); 2241 break; 2242 } 2243 } 2244 2245 SDVTList RetVTs = CurDAG->getVTList(MVT::Other, MVT::Glue); 2246 SDNode *Ret = CurDAG->getMachineNode(*Opcode, DL, RetVTs, Ops); 2247 MachineMemOperand *MemRef = cast<MemSDNode>(N)->getMemOperand(); 2248 CurDAG->setNodeMemRefs(cast<MachineSDNode>(Ret), {MemRef}); 2249 2250 ReplaceNode(N, Ret); 2251 return true; 2252 } 2253 2254 /// SelectBFE - Look for instruction sequences that can be made more efficient 2255 /// by using the 'bfe' (bit-field extract) PTX instruction 2256 bool NVPTXDAGToDAGISel::tryBFE(SDNode *N) { 2257 SDLoc DL(N); 2258 SDValue LHS = N->getOperand(0); 2259 SDValue RHS = N->getOperand(1); 2260 SDValue Len; 2261 SDValue Start; 2262 SDValue Val; 2263 bool IsSigned = false; 2264 2265 if (N->getOpcode() == ISD::AND) { 2266 // Canonicalize the operands 2267 // We want 'and %val, %mask' 2268 if (isa<ConstantSDNode>(LHS) && !isa<ConstantSDNode>(RHS)) { 2269 std::swap(LHS, RHS); 2270 } 2271 2272 ConstantSDNode *Mask = dyn_cast<ConstantSDNode>(RHS); 2273 if (!Mask) { 2274 // We need a constant mask on the RHS of the AND 2275 return false; 2276 } 2277 2278 // Extract the mask bits 2279 uint64_t MaskVal = Mask->getZExtValue(); 2280 if (!isMask_64(MaskVal)) { 2281 // We *could* handle shifted masks here, but doing so would require an 2282 // 'and' operation to fix up the low-order bits so we would trade 2283 // shr+and for bfe+and, which has the same throughput 2284 return false; 2285 } 2286 2287 // How many bits are in our mask? 2288 int64_t NumBits = countr_one(MaskVal); 2289 Len = CurDAG->getTargetConstant(NumBits, DL, MVT::i32); 2290 2291 if (LHS.getOpcode() == ISD::SRL || LHS.getOpcode() == ISD::SRA) { 2292 // We have a 'srl/and' pair, extract the effective start bit and length 2293 Val = LHS.getNode()->getOperand(0); 2294 Start = LHS.getNode()->getOperand(1); 2295 ConstantSDNode *StartConst = dyn_cast<ConstantSDNode>(Start); 2296 if (StartConst) { 2297 uint64_t StartVal = StartConst->getZExtValue(); 2298 // How many "good" bits do we have left? "good" is defined here as bits 2299 // that exist in the original value, not shifted in. 2300 int64_t GoodBits = Start.getValueSizeInBits() - StartVal; 2301 if (NumBits > GoodBits) { 2302 // Do not handle the case where bits have been shifted in. In theory 2303 // we could handle this, but the cost is likely higher than just 2304 // emitting the srl/and pair. 2305 return false; 2306 } 2307 Start = CurDAG->getTargetConstant(StartVal, DL, MVT::i32); 2308 } else { 2309 // Do not handle the case where the shift amount (can be zero if no srl 2310 // was found) is not constant. We could handle this case, but it would 2311 // require run-time logic that would be more expensive than just 2312 // emitting the srl/and pair. 2313 return false; 2314 } 2315 } else { 2316 // Do not handle the case where the LHS of the and is not a shift. 
While 2317 // it would be trivial to handle this case, it would just transform 2318 // 'and' -> 'bfe', but 'and' has higher-throughput. 2319 return false; 2320 } 2321 } else if (N->getOpcode() == ISD::SRL || N->getOpcode() == ISD::SRA) { 2322 if (LHS->getOpcode() == ISD::AND) { 2323 ConstantSDNode *ShiftCnst = dyn_cast<ConstantSDNode>(RHS); 2324 if (!ShiftCnst) { 2325 // Shift amount must be constant 2326 return false; 2327 } 2328 2329 uint64_t ShiftAmt = ShiftCnst->getZExtValue(); 2330 2331 SDValue AndLHS = LHS->getOperand(0); 2332 SDValue AndRHS = LHS->getOperand(1); 2333 2334 // Canonicalize the AND to have the mask on the RHS 2335 if (isa<ConstantSDNode>(AndLHS)) { 2336 std::swap(AndLHS, AndRHS); 2337 } 2338 2339 ConstantSDNode *MaskCnst = dyn_cast<ConstantSDNode>(AndRHS); 2340 if (!MaskCnst) { 2341 // Mask must be constant 2342 return false; 2343 } 2344 2345 uint64_t MaskVal = MaskCnst->getZExtValue(); 2346 uint64_t NumZeros; 2347 uint64_t NumBits; 2348 if (isMask_64(MaskVal)) { 2349 NumZeros = 0; 2350 // The number of bits in the result bitfield will be the number of 2351 // trailing ones (the AND) minus the number of bits we shift off 2352 NumBits = llvm::countr_one(MaskVal) - ShiftAmt; 2353 } else if (isShiftedMask_64(MaskVal)) { 2354 NumZeros = llvm::countr_zero(MaskVal); 2355 unsigned NumOnes = llvm::countr_one(MaskVal >> NumZeros); 2356 // The number of bits in the result bitfield will be the number of 2357 // trailing zeros plus the number of set bits in the mask minus the 2358 // number of bits we shift off 2359 NumBits = NumZeros + NumOnes - ShiftAmt; 2360 } else { 2361 // This is not a mask we can handle 2362 return false; 2363 } 2364 2365 if (ShiftAmt < NumZeros) { 2366 // Handling this case would require extra logic that would make this 2367 // transformation non-profitable 2368 return false; 2369 } 2370 2371 Val = AndLHS; 2372 Start = CurDAG->getTargetConstant(ShiftAmt, DL, MVT::i32); 2373 Len = CurDAG->getTargetConstant(NumBits, DL, MVT::i32); 2374 } else if (LHS->getOpcode() == ISD::SHL) { 2375 // Here, we have a pattern like: 2376 // 2377 // (sra (shl val, NN), MM) 2378 // or 2379 // (srl (shl val, NN), MM) 2380 // 2381 // If MM >= NN, we can efficiently optimize this with bfe 2382 Val = LHS->getOperand(0); 2383 2384 SDValue ShlRHS = LHS->getOperand(1); 2385 ConstantSDNode *ShlCnst = dyn_cast<ConstantSDNode>(ShlRHS); 2386 if (!ShlCnst) { 2387 // Shift amount must be constant 2388 return false; 2389 } 2390 uint64_t InnerShiftAmt = ShlCnst->getZExtValue(); 2391 2392 SDValue ShrRHS = RHS; 2393 ConstantSDNode *ShrCnst = dyn_cast<ConstantSDNode>(ShrRHS); 2394 if (!ShrCnst) { 2395 // Shift amount must be constant 2396 return false; 2397 } 2398 uint64_t OuterShiftAmt = ShrCnst->getZExtValue(); 2399 2400 // To avoid extra codegen and be profitable, we need Outer >= Inner 2401 if (OuterShiftAmt < InnerShiftAmt) { 2402 return false; 2403 } 2404 2405 // If the outer shift is more than the type size, we have no bitfield to 2406 // extract (since we also check that the inner shift is <= the outer shift 2407 // then this also implies that the inner shift is < the type size) 2408 if (OuterShiftAmt >= Val.getValueSizeInBits()) { 2409 return false; 2410 } 2411 2412 Start = CurDAG->getTargetConstant(OuterShiftAmt - InnerShiftAmt, DL, 2413 MVT::i32); 2414 Len = CurDAG->getTargetConstant(Val.getValueSizeInBits() - OuterShiftAmt, 2415 DL, MVT::i32); 2416 2417 if (N->getOpcode() == ISD::SRA) { 2418 // If we have a arithmetic right shift, we need to use the signed bfe 2419 // variant 2420 
IsSigned = true; 2421 } 2422 } else { 2423 // No can do... 2424 return false; 2425 } 2426 } else { 2427 // No can do... 2428 return false; 2429 } 2430 2431 2432 unsigned Opc; 2433 // For the BFE operations we form here from "and" and "srl", always use the 2434 // unsigned variants. 2435 if (Val.getValueType() == MVT::i32) { 2436 if (IsSigned) { 2437 Opc = NVPTX::BFE_S32rii; 2438 } else { 2439 Opc = NVPTX::BFE_U32rii; 2440 } 2441 } else if (Val.getValueType() == MVT::i64) { 2442 if (IsSigned) { 2443 Opc = NVPTX::BFE_S64rii; 2444 } else { 2445 Opc = NVPTX::BFE_U64rii; 2446 } 2447 } else { 2448 // We cannot handle this type 2449 return false; 2450 } 2451 2452 SDValue Ops[] = { 2453 Val, Start, Len 2454 }; 2455 2456 ReplaceNode(N, CurDAG->getMachineNode(Opc, DL, N->getVTList(), Ops)); 2457 return true; 2458 } 2459 2460 // Select bf16/bf16v2 FADD, FSUB, FMUL as fma on targets with only fma 2461 bool NVPTXDAGToDAGISel::tryBF16ArithToFMA(SDNode *N) { 2462 EVT VT = SDValue(N, 0).getValueType(); 2463 if (VT.getScalarType() != MVT::bf16) 2464 return false; 2465 2466 const NVPTXSubtarget *STI = TM.getSubtargetImpl(); 2467 if (STI->hasNativeBF16Support(N->getOpcode())) 2468 return false; 2469 2470 const bool IsVec = VT.isVector(); 2471 assert(!IsVec || VT.getVectorNumElements() == 2); 2472 SDLoc DL(N); 2473 SDValue N0 = N->getOperand(0); 2474 SDValue N1 = N->getOperand(1); 2475 SmallVector<SDValue, 3> Operands; 2476 auto GetConstant = [&](float Value) -> SDValue { 2477 // BF16 immediates must be legalized to integer register values 2478 APFloat APF(Value); 2479 bool LosesInfo; 2480 APF.convert(APFloat::BFloat(), APFloat::rmNearestTiesToEven, &LosesInfo); 2481 assert(!LosesInfo); 2482 if (IsVec) { 2483 auto API = APF.bitcastToAPInt(); 2484 API = API.concat(API); 2485 auto Const = CurDAG->getTargetConstant(API, DL, MVT::i32); 2486 return SDValue(CurDAG->getMachineNode(NVPTX::IMOV32ri, DL, VT, Const), 0); 2487 } 2488 auto Const = CurDAG->getTargetConstantFP(APF, DL, VT); 2489 return SDValue(CurDAG->getMachineNode(NVPTX::BFMOV16ri, DL, VT, Const), 0); 2490 }; 2491 2492 switch (N->getOpcode()) { 2493 case ISD::FADD: 2494 // add(a, b) -> fma(a, 1.0, b) 2495 Operands = {N0, GetConstant(1.0), N1}; 2496 break; 2497 case ISD::FSUB: 2498 // sub(a, b) -> fma(b, -1.0, a) 2499 Operands = {N1, GetConstant(-1.0), N0}; 2500 break; 2501 case ISD::FMUL: 2502 // mul(a, b) -> fma(a, b, -0.0) 2503 // NOTE: The identity is -0, not 0, because -0 + 0 == 0 for floats 2504 Operands = {N0, N1, GetConstant(-0.0)}; 2505 break; 2506 default: 2507 llvm_unreachable("Unexpected opcode"); 2508 }; 2509 2510 int Opcode = IsVec ? NVPTX::BFMA16x2rrr : NVPTX::BFMA16rrr; 2511 MachineSDNode *FMA = CurDAG->getMachineNode(Opcode, DL, VT, Operands); 2512 ReplaceNode(N, FMA); 2513 return true; 2514 } 2515 2516 static inline bool isAddLike(const SDValue V) { 2517 return V.getOpcode() == ISD::ADD || 2518 (V->getOpcode() == ISD::OR && V->getFlags().hasDisjoint()); 2519 } 2520 2521 // SelectDirectAddr - Match a direct address for DAG. 2522 // A direct address could be a globaladdress or externalsymbol. 2523 bool NVPTXDAGToDAGISel::SelectDirectAddr(SDValue N, SDValue &Address) { 2524 // Return true if TGA or ES. 
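  // The shapes accepted below are:
  //   TargetGlobalAddress / TargetExternalSymbol          -> used as-is
  //   NVPTXISD::Wrapper(sym)                              -> unwrapped to sym
  //   addrspacecast(MoveParam(sym)) into the param space  -> recurse on sym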
2525 if (N.getOpcode() == ISD::TargetGlobalAddress || 2526 N.getOpcode() == ISD::TargetExternalSymbol) { 2527 Address = N; 2528 return true; 2529 } 2530 if (N.getOpcode() == NVPTXISD::Wrapper) { 2531 Address = N.getOperand(0); 2532 return true; 2533 } 2534 // addrspacecast(MoveParam(arg_symbol) to addrspace(PARAM)) -> arg_symbol 2535 if (AddrSpaceCastSDNode *CastN = dyn_cast<AddrSpaceCastSDNode>(N)) { 2536 if (CastN->getSrcAddressSpace() == ADDRESS_SPACE_GENERIC && 2537 CastN->getDestAddressSpace() == ADDRESS_SPACE_PARAM && 2538 CastN->getOperand(0).getOpcode() == NVPTXISD::MoveParam) 2539 return SelectDirectAddr(CastN->getOperand(0).getOperand(0), Address); 2540 } 2541 return false; 2542 } 2543 2544 // symbol+offset 2545 bool NVPTXDAGToDAGISel::SelectADDRsi_imp(SDNode *OpNode, SDValue Addr, 2546 SDValue &Base, SDValue &Offset, 2547 MVT VT) { 2548 std::function<std::optional<uint64_t>(SDValue, uint64_t)> 2549 FindRootAddressAndTotalOffset = 2550 [&](SDValue Addr, 2551 uint64_t AccumulatedOffset) -> std::optional<uint64_t> { 2552 if (isAddLike(Addr)) { 2553 if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(Addr.getOperand(1))) { 2554 SDValue PossibleBaseAddr = Addr.getOperand(0); 2555 AccumulatedOffset += CN->getZExtValue(); 2556 if (SelectDirectAddr(PossibleBaseAddr, Base)) 2557 return AccumulatedOffset; 2558 return FindRootAddressAndTotalOffset(PossibleBaseAddr, 2559 AccumulatedOffset); 2560 } 2561 } 2562 return std::nullopt; 2563 }; 2564 if (auto AccumulatedOffset = FindRootAddressAndTotalOffset(Addr, 0)) { 2565 Offset = CurDAG->getTargetConstant(*AccumulatedOffset, SDLoc(OpNode), VT); 2566 return true; 2567 } 2568 return false; 2569 } 2570 2571 // symbol+offset 2572 bool NVPTXDAGToDAGISel::SelectADDRsi(SDNode *OpNode, SDValue Addr, 2573 SDValue &Base, SDValue &Offset) { 2574 return SelectADDRsi_imp(OpNode, Addr, Base, Offset, MVT::i32); 2575 } 2576 2577 // symbol+offset 2578 bool NVPTXDAGToDAGISel::SelectADDRsi64(SDNode *OpNode, SDValue Addr, 2579 SDValue &Base, SDValue &Offset) { 2580 return SelectADDRsi_imp(OpNode, Addr, Base, Offset, MVT::i64); 2581 } 2582 2583 // register+offset 2584 bool NVPTXDAGToDAGISel::SelectADDRri_imp(SDNode *OpNode, SDValue Addr, 2585 SDValue &Base, SDValue &Offset, 2586 MVT VT) { 2587 if (FrameIndexSDNode *FIN = dyn_cast<FrameIndexSDNode>(Addr)) { 2588 Base = CurDAG->getTargetFrameIndex(FIN->getIndex(), VT); 2589 Offset = CurDAG->getTargetConstant(0, SDLoc(OpNode), VT); 2590 return true; 2591 } 2592 if (Addr.getOpcode() == ISD::TargetExternalSymbol || 2593 Addr.getOpcode() == ISD::TargetGlobalAddress) 2594 return false; // direct calls. 2595 2596 if (isAddLike(Addr)) { 2597 if (SelectDirectAddr(Addr.getOperand(0), Addr)) { 2598 return false; 2599 } 2600 if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(Addr.getOperand(1))) { 2601 if (FrameIndexSDNode *FIN = 2602 dyn_cast<FrameIndexSDNode>(Addr.getOperand(0))) 2603 // Constant offset from frame ref. 
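        // Hedged example of the shape handled here: an (add frameindex, 16)
        // address becomes Base = TargetFrameIndex, and Offset = 16 is
        // materialized in the common code below.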
        Base = CurDAG->getTargetFrameIndex(FIN->getIndex(), VT);
      else
        Base = Addr.getOperand(0);

      // Offset must fit in a 32-bit signed int in PTX [register+offset]
      // address mode.
      if (!CN->getAPIntValue().isSignedIntN(32))
        return false;

      Offset = CurDAG->getSignedTargetConstant(CN->getSExtValue(),
                                               SDLoc(OpNode), MVT::i32);
      return true;
    }
  }
  return false;
}

// register+offset
bool NVPTXDAGToDAGISel::SelectADDRri(SDNode *OpNode, SDValue Addr,
                                     SDValue &Base, SDValue &Offset) {
  return SelectADDRri_imp(OpNode, Addr, Base, Offset, MVT::i32);
}

// register+offset
bool NVPTXDAGToDAGISel::SelectADDRri64(SDNode *OpNode, SDValue Addr,
                                       SDValue &Base, SDValue &Offset) {
  return SelectADDRri_imp(OpNode, Addr, Base, Offset, MVT::i64);
}

bool NVPTXDAGToDAGISel::ChkMemSDNodeAddressSpace(SDNode *N,
                                                 unsigned int spN) const {
  const Value *Src = nullptr;
  if (MemSDNode *mN = dyn_cast<MemSDNode>(N)) {
    if (spN == 0 && mN->getMemOperand()->getPseudoValue())
      return true;
    Src = mN->getMemOperand()->getValue();
  }
  if (!Src)
    return false;
  if (auto *PT = dyn_cast<PointerType>(Src->getType()))
    return (PT->getAddressSpace() == spN);
  return false;
}

/// SelectInlineAsmMemoryOperand - Implement addressing mode selection for
/// inline asm expressions.
bool NVPTXDAGToDAGISel::SelectInlineAsmMemoryOperand(
    const SDValue &Op, InlineAsm::ConstraintCode ConstraintID,
    std::vector<SDValue> &OutOps) {
  SDValue Op0, Op1;
  switch (ConstraintID) {
  default:
    return true;
  case InlineAsm::ConstraintCode::m: // memory
    if (SelectDirectAddr(Op, Op0)) {
      OutOps.push_back(Op0);
      OutOps.push_back(CurDAG->getTargetConstant(0, SDLoc(Op), MVT::i32));
      return false;
    }
    if (SelectADDRri(Op.getNode(), Op, Op0, Op1)) {
      OutOps.push_back(Op0);
      OutOps.push_back(Op1);
      return false;
    }
    break;
  }
  return true;
}

void NVPTXDAGToDAGISel::SelectV2I64toI128(SDNode *N) {
  // Lower a CopyToReg with two 64-bit inputs
  // Dst:i128, lo:i64, hi:i64
  //
  // CopyToReg Dst, lo, hi;
  //
  // ==>
  //
  // tmp = V2I64toI128 {lo, hi};
  // CopyToReg Dst, tmp;
  SDValue Dst = N->getOperand(1);
  SDValue Lo = N->getOperand(2);
  SDValue Hi = N->getOperand(3);

  SDLoc DL(N);
  SDNode *Mov =
      CurDAG->getMachineNode(NVPTX::V2I64toI128, DL, MVT::i128, {Lo, Hi});

  SmallVector<SDValue, 4> NewOps(N->getNumOperands() - 1);
  NewOps[0] = N->getOperand(0);
  NewOps[1] = Dst;
  NewOps[2] = SDValue(Mov, 0);
  if (N->getNumOperands() == 5)
    NewOps[3] = N->getOperand(4);
  SDValue NewValue = CurDAG->getNode(ISD::CopyToReg, DL,
                                     SmallVector<EVT>(N->values()), NewOps);

  ReplaceNode(N, NewValue.getNode());
}

void NVPTXDAGToDAGISel::SelectI128toV2I64(SDNode *N) {
  // Lower CopyFromReg from a 128-bit reg to two 64-bit regs
  // Dst:i128, Src:i128
  //
  // {lo, hi} = CopyFromReg Src
  //
  // ==>
  //
  // {lo, hi} = I128toV2I64 Src
  //
  SDValue Ch = N->getOperand(0);
  SDValue Src = N->getOperand(1);
  SDValue Glue = N->getOperand(2);
  SDLoc DL(N);

  // Add Glue and Ch to the operands and results to avoid breaking the
  // execution order.
  SDNode *Mov = CurDAG->getMachineNode(
      NVPTX::I128toV2I64, DL,
2721 {MVT::i64, MVT::i64, Ch.getValueType(), Glue.getValueType()}, 2722 {Src, Ch, Glue}); 2723 2724 ReplaceNode(N, Mov); 2725 } 2726 2727 /// GetConvertOpcode - Returns the CVT_ instruction opcode that implements a 2728 /// conversion from \p SrcTy to \p DestTy. 2729 unsigned NVPTXDAGToDAGISel::GetConvertOpcode(MVT DestTy, MVT SrcTy, 2730 LoadSDNode *LdNode) { 2731 bool IsSigned = LdNode && LdNode->getExtensionType() == ISD::SEXTLOAD; 2732 switch (SrcTy.SimpleTy) { 2733 default: 2734 llvm_unreachable("Unhandled source type"); 2735 case MVT::i8: 2736 switch (DestTy.SimpleTy) { 2737 default: 2738 llvm_unreachable("Unhandled dest type"); 2739 case MVT::i16: 2740 return IsSigned ? NVPTX::CVT_s16_s8 : NVPTX::CVT_u16_u8; 2741 case MVT::i32: 2742 return IsSigned ? NVPTX::CVT_s32_s8 : NVPTX::CVT_u32_u8; 2743 case MVT::i64: 2744 return IsSigned ? NVPTX::CVT_s64_s8 : NVPTX::CVT_u64_u8; 2745 } 2746 case MVT::i16: 2747 switch (DestTy.SimpleTy) { 2748 default: 2749 llvm_unreachable("Unhandled dest type"); 2750 case MVT::i8: 2751 return IsSigned ? NVPTX::CVT_s8_s16 : NVPTX::CVT_u8_u16; 2752 case MVT::i32: 2753 return IsSigned ? NVPTX::CVT_s32_s16 : NVPTX::CVT_u32_u16; 2754 case MVT::i64: 2755 return IsSigned ? NVPTX::CVT_s64_s16 : NVPTX::CVT_u64_u16; 2756 } 2757 case MVT::i32: 2758 switch (DestTy.SimpleTy) { 2759 default: 2760 llvm_unreachable("Unhandled dest type"); 2761 case MVT::i8: 2762 return IsSigned ? NVPTX::CVT_s8_s32 : NVPTX::CVT_u8_u32; 2763 case MVT::i16: 2764 return IsSigned ? NVPTX::CVT_s16_s32 : NVPTX::CVT_u16_u32; 2765 case MVT::i64: 2766 return IsSigned ? NVPTX::CVT_s64_s32 : NVPTX::CVT_u64_u32; 2767 } 2768 case MVT::i64: 2769 switch (DestTy.SimpleTy) { 2770 default: 2771 llvm_unreachable("Unhandled dest type"); 2772 case MVT::i8: 2773 return IsSigned ? NVPTX::CVT_s8_s64 : NVPTX::CVT_u8_u64; 2774 case MVT::i16: 2775 return IsSigned ? NVPTX::CVT_s16_s64 : NVPTX::CVT_u16_u64; 2776 case MVT::i32: 2777 return IsSigned ? NVPTX::CVT_s32_s64 : NVPTX::CVT_u32_u64; 2778 } 2779 case MVT::f16: 2780 switch (DestTy.SimpleTy) { 2781 default: 2782 llvm_unreachable("Unhandled dest type"); 2783 case MVT::f32: 2784 return NVPTX::CVT_f32_f16; 2785 case MVT::f64: 2786 return NVPTX::CVT_f64_f16; 2787 } 2788 } 2789 } 2790 2791 bool NVPTXDAGToDAGISel::tryFence(SDNode *N) { 2792 SDLoc DL(N); 2793 assert(N->getOpcode() == ISD::ATOMIC_FENCE); 2794 unsigned int FenceOp = 2795 getFenceOp(NVPTX::Ordering(N->getConstantOperandVal(1)), 2796 Scopes[N->getConstantOperandVal(2)], Subtarget); 2797 SDValue Chain = N->getOperand(0); 2798 SDNode *FenceNode = CurDAG->getMachineNode(FenceOp, DL, MVT::Other, Chain); 2799 ReplaceNode(N, FenceNode); 2800 return true; 2801 } 2802 2803 NVPTXScopes::NVPTXScopes(LLVMContext &C) { 2804 Scopes[C.getOrInsertSyncScopeID("singlethread")] = NVPTX::Scope::Thread; 2805 Scopes[C.getOrInsertSyncScopeID("")] = NVPTX::Scope::System; 2806 Scopes[C.getOrInsertSyncScopeID("block")] = NVPTX::Scope::Block; 2807 Scopes[C.getOrInsertSyncScopeID("cluster")] = NVPTX::Scope::Cluster; 2808 Scopes[C.getOrInsertSyncScopeID("device")] = NVPTX::Scope::Device; 2809 } 2810 2811 NVPTX::Scope NVPTXScopes::operator[](SyncScope::ID ID) const { 2812 if (Scopes.empty()) 2813 llvm_unreachable("NVPTX Scopes must be initialized before calling " 2814 "NVPTXScopes::operator[]"); 2815 2816 auto S = Scopes.find(ID); 2817 if (S == Scopes.end()) { 2818 // TODO: 2819 // - Add API to LLVMContext to get the name of a single scope. 2820 // - Use that API here to print an error containing the name 2821 // of this Unknown ID. 
2822 report_fatal_error(formatv("Could not find scope ID={}.", int(ID))); 2823 } 2824 return S->second; 2825 } 2826 2827 bool NVPTXScopes::empty() const { return Scopes.size() == 0; } 2828 2829 #define CP_ASYNC_BULK_TENSOR_OPCODE(dir, dim, mode, is_s32, suffix) \ 2830 (is_s32 \ 2831 ? NVPTX::CP_ASYNC_BULK_TENSOR_##dir##_##dim##_SHARED32_##mode##suffix \ 2832 : NVPTX::CP_ASYNC_BULK_TENSOR_##dir##_##dim##_##mode##suffix) 2833 2834 #define CP_ASYNC_BULK_TENSOR_OPCODE_S2G_IMPL(op, dim, mode, is_ch, is_s32) \ 2835 (is_ch ? (CP_ASYNC_BULK_TENSOR_OPCODE(op, dim, mode, is_s32, _CH)) \ 2836 : (CP_ASYNC_BULK_TENSOR_OPCODE(op, dim, mode, is_s32, ))) 2837 2838 #define GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(dim, mode, is_reduce, is_ch, \ 2839 is_s32) \ 2840 (is_reduce \ 2841 ? (CP_ASYNC_BULK_TENSOR_OPCODE_S2G_IMPL(RED, dim, mode, is_ch, is_s32)) \ 2842 : (CP_ASYNC_BULK_TENSOR_OPCODE_S2G_IMPL(S2G, dim, mode, is_ch, \ 2843 is_s32))) 2844 2845 #define GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(dim, mode, is_mc, is_ch, is_s32) \ 2846 [&]() -> auto { \ 2847 if (is_mc && is_ch) \ 2848 return CP_ASYNC_BULK_TENSOR_OPCODE(G2S, dim, mode, is_s32, _MC_CH); \ 2849 if (is_ch) \ 2850 return CP_ASYNC_BULK_TENSOR_OPCODE(G2S, dim, mode, is_s32, _CH); \ 2851 if (is_mc) \ 2852 return CP_ASYNC_BULK_TENSOR_OPCODE(G2S, dim, mode, is_s32, _MC); \ 2853 return CP_ASYNC_BULK_TENSOR_OPCODE(G2S, dim, mode, is_s32, ); \ 2854 }() 2855 2856 #define GET_CP_ASYNC_BULK_TENSOR_OPCODE_PREFETCH(dim, mode, is_ch) \ 2857 (is_ch ? NVPTX::CP_ASYNC_BULK_TENSOR_PREFETCH_##dim##_##mode##_CH \ 2858 : NVPTX::CP_ASYNC_BULK_TENSOR_PREFETCH_##dim##_##mode) 2859 2860 static unsigned GetCpAsyncBulkTensorS2GOpcode(size_t Dim, bool IsShared32, 2861 bool IsCacheHint, bool IsIm2Col, 2862 bool IsReduce = false) { 2863 if (IsIm2Col) { 2864 switch (Dim) { 2865 case 3: 2866 return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(3D, IM2COL, IsReduce, 2867 IsCacheHint, IsShared32); 2868 case 4: 2869 return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(4D, IM2COL, IsReduce, 2870 IsCacheHint, IsShared32); 2871 case 5: 2872 return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(5D, IM2COL, IsReduce, 2873 IsCacheHint, IsShared32); 2874 default: 2875 llvm_unreachable("Invalid Dimension in im2col mode for " 2876 "GetCpAsyncBulkTensorS2GOpcode."); 2877 } 2878 } else { 2879 switch (Dim) { 2880 case 1: 2881 return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(1D, TILE, IsReduce, 2882 IsCacheHint, IsShared32); 2883 case 2: 2884 return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(2D, TILE, IsReduce, 2885 IsCacheHint, IsShared32); 2886 case 3: 2887 return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(3D, TILE, IsReduce, 2888 IsCacheHint, IsShared32); 2889 case 4: 2890 return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(4D, TILE, IsReduce, 2891 IsCacheHint, IsShared32); 2892 case 5: 2893 return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(5D, TILE, IsReduce, 2894 IsCacheHint, IsShared32); 2895 default: 2896 llvm_unreachable( 2897 "Invalid Dimension in tile mode for GetCpAsyncBulkTensorS2GOpcode."); 2898 } 2899 } 2900 } 2901 2902 static unsigned GetCpAsyncBulkTensorG2SOpcode(size_t Dim, bool IsShared32, 2903 bool IsMultiCast, 2904 bool IsCacheHint, bool IsIm2Col) { 2905 if (IsIm2Col) { 2906 switch (Dim) { 2907 case 3: 2908 return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(3D, IM2COL, IsMultiCast, 2909 IsCacheHint, IsShared32); 2910 case 4: 2911 return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(4D, IM2COL, IsMultiCast, 2912 IsCacheHint, IsShared32); 2913 case 5: 2914 return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(5D, IM2COL, IsMultiCast, 2915 IsCacheHint, IsShared32); 2916 default: 
2917 llvm_unreachable("Invalid Dimension in im2col mode for " 2918 "GetCpAsyncBulkTensorG2SOpcode."); 2919 } 2920 } else { 2921 switch (Dim) { 2922 case 1: 2923 return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(1D, TILE, IsMultiCast, 2924 IsCacheHint, IsShared32); 2925 case 2: 2926 return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(2D, TILE, IsMultiCast, 2927 IsCacheHint, IsShared32); 2928 case 3: 2929 return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(3D, TILE, IsMultiCast, 2930 IsCacheHint, IsShared32); 2931 case 4: 2932 return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(4D, TILE, IsMultiCast, 2933 IsCacheHint, IsShared32); 2934 case 5: 2935 return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(5D, TILE, IsMultiCast, 2936 IsCacheHint, IsShared32); 2937 default: 2938 llvm_unreachable( 2939 "Invalid Dimension in tile mode for GetCpAsyncBulkTensorG2SOpcode."); 2940 } 2941 } 2942 } 2943 2944 static unsigned GetCpAsyncBulkTensorPrefetchOpcode(size_t Dim, bool IsCacheHint, 2945 bool IsIm2Col) { 2946 if (IsIm2Col) { 2947 switch (Dim) { 2948 case 3: 2949 return GET_CP_ASYNC_BULK_TENSOR_OPCODE_PREFETCH(3D, IM2COL, IsCacheHint); 2950 case 4: 2951 return GET_CP_ASYNC_BULK_TENSOR_OPCODE_PREFETCH(4D, IM2COL, IsCacheHint); 2952 case 5: 2953 return GET_CP_ASYNC_BULK_TENSOR_OPCODE_PREFETCH(5D, IM2COL, IsCacheHint); 2954 default: 2955 llvm_unreachable("Invalid Dimension in im2col mode for " 2956 "GetCpAsyncBulkTensorPrefetchOpcode."); 2957 } 2958 } else { 2959 switch (Dim) { 2960 case 1: 2961 return GET_CP_ASYNC_BULK_TENSOR_OPCODE_PREFETCH(1D, TILE, IsCacheHint); 2962 case 2: 2963 return GET_CP_ASYNC_BULK_TENSOR_OPCODE_PREFETCH(2D, TILE, IsCacheHint); 2964 case 3: 2965 return GET_CP_ASYNC_BULK_TENSOR_OPCODE_PREFETCH(3D, TILE, IsCacheHint); 2966 case 4: 2967 return GET_CP_ASYNC_BULK_TENSOR_OPCODE_PREFETCH(4D, TILE, IsCacheHint); 2968 case 5: 2969 return GET_CP_ASYNC_BULK_TENSOR_OPCODE_PREFETCH(5D, TILE, IsCacheHint); 2970 default: 2971 llvm_unreachable("Invalid Dimension in tile mode for " 2972 "GetCpAsyncBulkTensorPrefetchOpcode."); 2973 } 2974 } 2975 } 2976 2977 static size_t GetDimsFromIntrinsic(unsigned IID) { 2978 switch (IID) { 2979 case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d: 2980 case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d: 2981 return 3; 2982 case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d: 2983 case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d: 2984 return 4; 2985 case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d: 2986 case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d: 2987 return 5; 2988 default: 2989 llvm_unreachable("Invalid im2col intrinsic in GetDimsFromIntrinsic."); 2990 } 2991 } 2992 2993 void NVPTXDAGToDAGISel::SelectCpAsyncBulkTensorG2SCommon(SDNode *N, 2994 bool IsIm2Col) { 2995 // We have {Chain, Intrinsic-ID} followed by the actual intrisic args: 2996 // {dst, mbar, src, dims{d0...dN}, im2col_offsets{dims-2} 2997 // multicast, cache_hint, 2998 // multicast_flag, cache_hint_flag} 2999 // NumOperands = {Chain, IID} + {Actual intrinsic args} 3000 // = {2} + {7 + dims + im2col_offsets} 3001 size_t NumOps = N->getNumOperands(); 3002 size_t NumDims = IsIm2Col ? GetDimsFromIntrinsic(N->getConstantOperandVal(1)) 3003 : (NumOps - 9); 3004 // Offsets is always 'NumDims - 2' and only for im2col mode 3005 size_t NumOffsets = IsIm2Col ? 
(NumDims - 2) : 0; 3006 bool IsCacheHint = N->getConstantOperandVal(NumOps - 1) == 1; 3007 bool IsMultiCast = N->getConstantOperandVal(NumOps - 2) == 1; 3008 size_t NumBaseArgs = NumDims + NumOffsets + 3; // for {dst, mbar, src} 3009 size_t MultiCastIdx = NumBaseArgs + 2; // for Chain and IID 3010 3011 SDLoc DL(N); 3012 SmallVector<SDValue, 8> Ops(N->ops().slice(2, NumBaseArgs)); 3013 3014 // Push MultiCast operand, if available 3015 if (IsMultiCast) 3016 Ops.push_back(N->getOperand(MultiCastIdx)); 3017 3018 // Push CacheHint operand, if available 3019 if (IsCacheHint) 3020 Ops.push_back(N->getOperand(MultiCastIdx + 1)); 3021 3022 // Finally, the chain operand 3023 Ops.push_back(N->getOperand(0)); 3024 3025 bool IsShared32 = 3026 CurDAG->getDataLayout().getPointerSizeInBits(ADDRESS_SPACE_SHARED) == 32; 3027 unsigned Opcode = GetCpAsyncBulkTensorG2SOpcode( 3028 NumDims, IsShared32, IsMultiCast, IsCacheHint, IsIm2Col); 3029 ReplaceNode(N, CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops)); 3030 } 3031 3032 void NVPTXDAGToDAGISel::SelectCpAsyncBulkTensorS2GCommon(SDNode *N, 3033 bool IsIm2Col) { 3034 // We have {Chain, Intrinsic-ID} followed by the actual intrisic args: 3035 // src, dst, dims{d0...dN}, cache_hint, cache_hint_flag 3036 // NumOperands = {Chain, IID} + {Actual intrinsic args} 3037 // = {2} + {4 + dims} 3038 size_t NumOps = N->getNumOperands(); 3039 size_t NumDims = NumOps - 6; 3040 bool IsCacheHint = N->getConstantOperandVal(NumOps - 1) == 1; 3041 size_t NumArgs = NumDims + (IsCacheHint ? 3 : 2); // src, dst, cache_hint 3042 3043 SDLoc DL(N); 3044 SmallVector<SDValue, 8> Ops(N->ops().slice(2, NumArgs)); 3045 Ops.push_back(N->getOperand(0)); // Chain operand 3046 3047 bool IsShared32 = 3048 CurDAG->getDataLayout().getPointerSizeInBits(ADDRESS_SPACE_SHARED) == 32; 3049 unsigned Opcode = 3050 GetCpAsyncBulkTensorS2GOpcode(NumDims, IsShared32, IsCacheHint, IsIm2Col); 3051 ReplaceNode(N, CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops)); 3052 } 3053 3054 void NVPTXDAGToDAGISel::SelectCpAsyncBulkTensorPrefetchCommon(SDNode *N, 3055 bool IsIm2Col) { 3056 // We have {Chain, Intrinsic-ID} followed by the actual intrisic args: 3057 // {src, dims{d0...dN}, im2col_offsets{dims-2} 3058 // cache_hint, cache_hint_flag} 3059 // NumOperands = {Chain, IID} + {Actual intrinsic args} 3060 // = {2} + {3 + dims + im2col_offsets} 3061 size_t NumOps = N->getNumOperands(); 3062 size_t NumDims = IsIm2Col ? GetDimsFromIntrinsic(N->getConstantOperandVal(1)) 3063 : (NumOps - 5); 3064 // Offsets is always 'NumDims - 2' and only for im2col mode 3065 size_t NumOffsets = IsIm2Col ? (NumDims - 2) : 0; 3066 bool IsCacheHint = N->getConstantOperandVal(NumOps - 1) == 1; 3067 size_t NumArgs = NumDims + NumOffsets + (IsCacheHint ? 
2 : 1); 3068 3069 SDLoc DL(N); 3070 SmallVector<SDValue, 12> Ops(N->ops().slice(2, NumArgs)); 3071 Ops.push_back(N->getOperand(0)); // Chain operand 3072 3073 unsigned Opcode = 3074 GetCpAsyncBulkTensorPrefetchOpcode(NumDims, IsCacheHint, IsIm2Col); 3075 ReplaceNode(N, CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops)); 3076 } 3077 3078 void NVPTXDAGToDAGISel::SelectCpAsyncBulkTensorReduceCommon(SDNode *N, 3079 unsigned RedOp, 3080 bool IsIm2Col) { 3081 // We have {Chain, Intrinsic-ID} followed by the actual intrisic args: 3082 // src, dst, dims{d0...dN}, cache_hint, cache_hint_flag 3083 // NumOperands = {Chain, IID} + {Actual intrinsic args} 3084 // = {2} + {4 + dims} 3085 size_t NumOps = N->getNumOperands(); 3086 size_t NumDims = NumOps - 6; 3087 bool IsCacheHint = N->getConstantOperandVal(NumOps - 1) == 1; 3088 size_t NumArgs = NumDims + (IsCacheHint ? 3 : 2); // src, dst, cache_hint 3089 3090 SDLoc DL(N); 3091 SmallVector<SDValue, 12> Ops(N->ops().slice(2, NumArgs)); 3092 Ops.push_back(getI32Imm(RedOp, DL)); // Reduction Op 3093 Ops.push_back(N->getOperand(0)); // Chain operand 3094 3095 bool IsShared32 = 3096 CurDAG->getDataLayout().getPointerSizeInBits(ADDRESS_SPACE_SHARED) == 32; 3097 unsigned Opcode = GetCpAsyncBulkTensorS2GOpcode( 3098 NumDims, IsShared32, IsCacheHint, IsIm2Col, /*IsReduce=*/true); 3099 ReplaceNode(N, CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops)); 3100 } 3101 3102 void NVPTXDAGToDAGISel::SelectCpAsyncBulkS2G(SDNode *N) { 3103 // We have {Chain, Intrinsic-ID} followed by the actual intrisic args: 3104 // dst, src, size, cache_hint, cache_hint_flag 3105 // NumOperands = {Chain, IID} + {Actual intrinsic args} 3106 // = {2} + {5} 3107 size_t NumOps = N->getNumOperands(); 3108 bool IsCacheHint = N->getConstantOperandVal(NumOps - 1) == 1; 3109 size_t NumArgs = IsCacheHint ? 4 : 3; // src, dst, size, cache_hint 3110 3111 SDLoc DL(N); 3112 SmallVector<SDValue, 8> Ops(N->ops().slice(2, NumArgs)); 3113 Ops.push_back(N->getOperand(0)); // Chain operand 3114 3115 bool IsShared32 = 3116 CurDAG->getDataLayout().getPointerSizeInBits(ADDRESS_SPACE_SHARED) == 32; 3117 unsigned Opcode; 3118 if (IsCacheHint) 3119 Opcode = IsShared32 ? NVPTX::CP_ASYNC_BULK_S2G_SHARED32_CH 3120 : NVPTX::CP_ASYNC_BULK_S2G_CH; 3121 else 3122 Opcode = IsShared32 ? 
NVPTX::CP_ASYNC_BULK_S2G_SHARED32 3123 : NVPTX::CP_ASYNC_BULK_S2G; 3124 ReplaceNode(N, CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops)); 3125 } 3126 3127 void NVPTXDAGToDAGISel::SelectCpAsyncBulkG2S(SDNode *N) { 3128 // We have {Chain, Intrinsic-ID} followed by the actual intrisic args: 3129 // {dst, mbar, src, size, multicast, cache_hint, 3130 // multicast_flag, cache_hint_flag} 3131 // NumOperands = {Chain, IID} + {Actual intrinsic args} 3132 // = {2} + {8} 3133 size_t NumOps = N->getNumOperands(); 3134 bool IsCacheHint = N->getConstantOperandVal(NumOps - 1) == 1; 3135 bool IsMultiCast = N->getConstantOperandVal(NumOps - 2) == 1; 3136 size_t NumBaseArgs = 4; // dst, mbar, src, size 3137 size_t MultiCastIdx = NumBaseArgs + 2; // for Chain and IID 3138 3139 SDLoc DL(N); 3140 SmallVector<SDValue, 8> Ops(N->ops().slice(2, NumBaseArgs)); 3141 3142 // Push MultiCast operand, if available 3143 if (IsMultiCast) 3144 Ops.push_back(N->getOperand(MultiCastIdx)); 3145 3146 // Push CacheHint operand, if available 3147 if (IsCacheHint) 3148 Ops.push_back(N->getOperand(MultiCastIdx + 1)); 3149 3150 // Finally, the chain operand 3151 Ops.push_back(N->getOperand(0)); 3152 3153 bool IsShared32 = 3154 CurDAG->getDataLayout().getPointerSizeInBits(ADDRESS_SPACE_SHARED) == 32; 3155 unsigned Opcode = [&]() { 3156 if (IsMultiCast && IsCacheHint) 3157 return IsShared32 ? NVPTX::CP_ASYNC_BULK_G2S_SHARED32_MC_CH 3158 : NVPTX::CP_ASYNC_BULK_G2S_MC_CH; 3159 if (IsMultiCast) 3160 return IsShared32 ? NVPTX::CP_ASYNC_BULK_G2S_SHARED32_MC 3161 : NVPTX::CP_ASYNC_BULK_G2S_MC; 3162 if (IsCacheHint) 3163 return IsShared32 ? NVPTX::CP_ASYNC_BULK_G2S_SHARED32_CH 3164 : NVPTX::CP_ASYNC_BULK_G2S_CH; 3165 return IsShared32 ? NVPTX::CP_ASYNC_BULK_G2S_SHARED32 3166 : NVPTX::CP_ASYNC_BULK_G2S; 3167 }(); 3168 ReplaceNode(N, CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops)); 3169 } 3170 3171 void NVPTXDAGToDAGISel::SelectCpAsyncBulkPrefetchL2(SDNode *N) { 3172 // We have {Chain, Intrinsic-ID} followed by the actual intrisic args: 3173 // src, size, cache_hint, cache_hint_flag 3174 // NumOperands = {Chain, IID} + {Actual intrinsic args} 3175 // = {2} + {4} 3176 size_t NumOps = N->getNumOperands(); 3177 bool IsCacheHint = N->getConstantOperandVal(NumOps - 1) == 1; 3178 size_t NumArgs = IsCacheHint ? 3 : 2; // src, size, cache_hint 3179 3180 SDLoc DL(N); 3181 SmallVector<SDValue, 4> Ops(N->ops().slice(2, NumArgs)); 3182 Ops.push_back(N->getOperand(0)); // Chain operand 3183 3184 unsigned Opcode = IsCacheHint 3185 ? 
NVPTX::CP_ASYNC_BULK_PREFETCH_CH 3186 : NVPTX::CP_ASYNC_BULK_PREFETCH; 3187 ReplaceNode(N, CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops)); 3188 } 3189 3190 bool NVPTXDAGToDAGISel::tryIntrinsicVoid(SDNode *N) { 3191 unsigned IID = N->getConstantOperandVal(1); 3192 using TMARedTy = llvm::nvvm::TMAReductionOp; 3193 auto CastTy = [](TMARedTy Op) { return static_cast<unsigned>(Op); }; 3194 switch (IID) { 3195 default: 3196 return false; 3197 case Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster: 3198 SelectCpAsyncBulkG2S(N); 3199 return true; 3200 case Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global: 3201 SelectCpAsyncBulkS2G(N); 3202 return true; 3203 case Intrinsic::nvvm_cp_async_bulk_prefetch_L2: 3204 SelectCpAsyncBulkPrefetchL2(N); 3205 return true; 3206 case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_1d: 3207 case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_2d: 3208 case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_3d: 3209 case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_4d: 3210 case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_5d: 3211 SelectCpAsyncBulkTensorS2GCommon(N); 3212 return true; 3213 case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_3d: 3214 case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_4d: 3215 case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_5d: 3216 SelectCpAsyncBulkTensorS2GCommon(N, /*IsIm2Col=*/true); 3217 return true; 3218 case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d: 3219 case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d: 3220 case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d: 3221 case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d: 3222 case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d: 3223 SelectCpAsyncBulkTensorG2SCommon(N); 3224 return true; 3225 case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d: 3226 case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d: 3227 case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d: 3228 SelectCpAsyncBulkTensorG2SCommon(N, /*IsIm2Col=*/true); 3229 return true; 3230 case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d: 3231 case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d: 3232 case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d: 3233 case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d: 3234 case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d: 3235 SelectCpAsyncBulkTensorPrefetchCommon(N); 3236 return true; 3237 case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d: 3238 case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d: 3239 case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d: 3240 SelectCpAsyncBulkTensorPrefetchCommon(N, /*IsIm2Col=*/true); 3241 return true; 3242 case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_1d: 3243 case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_2d: 3244 case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_3d: 3245 case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_4d: 3246 case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_5d: 3247 SelectCpAsyncBulkTensorReduceCommon(N, CastTy(TMARedTy::ADD)); 3248 return true; 3249 case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_3d: 3250 case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_4d: 3251 case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_5d: 3252 SelectCpAsyncBulkTensorReduceCommon(N, CastTy(TMARedTy::ADD), 3253 /*IsIm2Col=*/true); 3254 return true; 3255 case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_1d: 3256 case 
bool NVPTXDAGToDAGISel::tryIntrinsicVoid(SDNode *N) {
  unsigned IID = N->getConstantOperandVal(1);
  using TMARedTy = llvm::nvvm::TMAReductionOp;
  auto CastTy = [](TMARedTy Op) { return static_cast<unsigned>(Op); };
  switch (IID) {
  default:
    return false;
  case Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster:
    SelectCpAsyncBulkG2S(N);
    return true;
  case Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global:
    SelectCpAsyncBulkS2G(N);
    return true;
  case Intrinsic::nvvm_cp_async_bulk_prefetch_L2:
    SelectCpAsyncBulkPrefetchL2(N);
    return true;
  case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_1d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_2d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_3d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_4d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_5d:
    SelectCpAsyncBulkTensorS2GCommon(N);
    return true;
  case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_3d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_4d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_5d:
    SelectCpAsyncBulkTensorS2GCommon(N, /*IsIm2Col=*/true);
    return true;
  case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d:
    SelectCpAsyncBulkTensorG2SCommon(N);
    return true;
  case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d:
    SelectCpAsyncBulkTensorG2SCommon(N, /*IsIm2Col=*/true);
    return true;
  case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d:
    SelectCpAsyncBulkTensorPrefetchCommon(N);
    return true;
  case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d:
    SelectCpAsyncBulkTensorPrefetchCommon(N, /*IsIm2Col=*/true);
    return true;
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_1d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_2d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_3d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_4d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_5d:
    SelectCpAsyncBulkTensorReduceCommon(N, CastTy(TMARedTy::ADD));
    return true;
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_3d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_4d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_5d:
    SelectCpAsyncBulkTensorReduceCommon(N, CastTy(TMARedTy::ADD),
                                        /*IsIm2Col=*/true);
    return true;
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_1d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_2d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_3d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_4d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_5d:
    SelectCpAsyncBulkTensorReduceCommon(N, CastTy(TMARedTy::MIN));
    return true;
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_3d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_4d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_5d:
    SelectCpAsyncBulkTensorReduceCommon(N, CastTy(TMARedTy::MIN),
                                        /*IsIm2Col=*/true);
    return true;
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_1d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_2d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_3d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_4d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_5d:
    SelectCpAsyncBulkTensorReduceCommon(N, CastTy(TMARedTy::MAX));
    return true;
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_3d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_4d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_5d:
    SelectCpAsyncBulkTensorReduceCommon(N, CastTy(TMARedTy::MAX),
                                        /*IsIm2Col=*/true);
    return true;
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_1d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_2d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_3d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_4d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_5d:
    SelectCpAsyncBulkTensorReduceCommon(N, CastTy(TMARedTy::INC));
    return true;
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_3d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_4d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_5d:
    SelectCpAsyncBulkTensorReduceCommon(N, CastTy(TMARedTy::INC),
                                        /*IsIm2Col=*/true);
    return true;
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_1d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_2d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_3d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_4d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_5d:
    SelectCpAsyncBulkTensorReduceCommon(N, CastTy(TMARedTy::DEC));
    return true;
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_3d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_4d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_5d:
    SelectCpAsyncBulkTensorReduceCommon(N, CastTy(TMARedTy::DEC),
                                        /*IsIm2Col=*/true);
    return true;
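  // Bitwise TMA reductions (AND, OR, XOR) below reuse the same common
  // selector, distinguished only by the TMARedTy value and the tile/im2col
  // flag.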
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_1d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_2d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_3d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_4d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_5d:
    SelectCpAsyncBulkTensorReduceCommon(N, CastTy(TMARedTy::AND));
    return true;
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_3d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_4d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_5d:
    SelectCpAsyncBulkTensorReduceCommon(N, CastTy(TMARedTy::AND),
                                        /*IsIm2Col=*/true);
    return true;
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_1d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_2d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_3d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_4d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_5d:
    SelectCpAsyncBulkTensorReduceCommon(N, CastTy(TMARedTy::OR));
    return true;
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_3d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_4d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_5d:
    SelectCpAsyncBulkTensorReduceCommon(N, CastTy(TMARedTy::OR),
                                        /*IsIm2Col=*/true);
    return true;
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_1d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_2d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_3d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_4d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_5d:
    SelectCpAsyncBulkTensorReduceCommon(N, CastTy(TMARedTy::XOR));
    return true;
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_3d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_4d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_5d:
    SelectCpAsyncBulkTensorReduceCommon(N, CastTy(TMARedTy::XOR),
                                        /*IsIm2Col=*/true);
    return true;
  }
}