xref: /llvm-project/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp (revision fa7f0e582bc25a91d89dab7c488a1619060f9bef)
1 //===-- NVPTXISelDAGToDAG.cpp - A dag to dag inst selector for NVPTX ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines an instruction selector for the NVPTX target.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "NVPTXISelDAGToDAG.h"
14 #include "NVPTX.h"
15 #include "NVPTXUtilities.h"
16 #include "llvm/Analysis/ValueTracking.h"
17 #include "llvm/CodeGen/ISDOpcodes.h"
18 #include "llvm/CodeGen/SelectionDAGNodes.h"
19 #include "llvm/IR/GlobalValue.h"
20 #include "llvm/IR/Instructions.h"
21 #include "llvm/IR/IntrinsicsNVPTX.h"
22 #include "llvm/IR/NVVMIntrinsicUtils.h"
23 #include "llvm/Support/AtomicOrdering.h"
24 #include "llvm/Support/CommandLine.h"
25 #include "llvm/Support/ErrorHandling.h"
26 #include "llvm/Support/FormatVariadic.h"
27 #include "llvm/Target/TargetIntrinsicInfo.h"
28 
29 using namespace llvm;
30 
31 #define DEBUG_TYPE "nvptx-isel"
32 #define PASS_NAME "NVPTX DAG->DAG Pattern Instruction Selection"
33 
34 static cl::opt<bool>
35     EnableRsqrtOpt("nvptx-rsqrt-approx-opt", cl::init(true), cl::Hidden,
36                    cl::desc("Enable reciprocal sqrt optimization"));
37 
38 /// createNVPTXISelDag - This pass converts a legalized DAG into a
39 /// NVPTX-specific DAG, ready for instruction scheduling.
40 FunctionPass *llvm::createNVPTXISelDag(NVPTXTargetMachine &TM,
41                                        llvm::CodeGenOptLevel OptLevel) {
42   return new NVPTXDAGToDAGISelLegacy(TM, OptLevel);
43 }
44 
45 NVPTXDAGToDAGISelLegacy::NVPTXDAGToDAGISelLegacy(NVPTXTargetMachine &tm,
46                                                  CodeGenOptLevel OptLevel)
47     : SelectionDAGISelLegacy(
48           ID, std::make_unique<NVPTXDAGToDAGISel>(tm, OptLevel)) {}
49 
50 char NVPTXDAGToDAGISelLegacy::ID = 0;
51 
52 INITIALIZE_PASS(NVPTXDAGToDAGISelLegacy, DEBUG_TYPE, PASS_NAME, false, false)
53 
54 NVPTXDAGToDAGISel::NVPTXDAGToDAGISel(NVPTXTargetMachine &tm,
55                                      CodeGenOptLevel OptLevel)
56     : SelectionDAGISel(tm, OptLevel), TM(tm) {
57   doMulWide = (OptLevel > CodeGenOptLevel::None);
58 }
59 
60 bool NVPTXDAGToDAGISel::runOnMachineFunction(MachineFunction &MF) {
61   Subtarget = &MF.getSubtarget<NVPTXSubtarget>();
62   Scopes = NVPTXScopes(MF.getFunction().getContext());
63   return SelectionDAGISel::runOnMachineFunction(MF);
64 }
65 
66 int NVPTXDAGToDAGISel::getDivF32Level() const {
67   return Subtarget->getTargetLowering()->getDivF32Level();
68 }
69 
70 bool NVPTXDAGToDAGISel::usePrecSqrtF32() const {
71   return Subtarget->getTargetLowering()->usePrecSqrtF32();
72 }
73 
74 bool NVPTXDAGToDAGISel::useF32FTZ() const {
75   return Subtarget->getTargetLowering()->useF32FTZ(*MF);
76 }
77 
78 bool NVPTXDAGToDAGISel::allowFMA() const {
79   const NVPTXTargetLowering *TL = Subtarget->getTargetLowering();
80   return TL->allowFMA(*MF, OptLevel);
81 }
82 
83 bool NVPTXDAGToDAGISel::allowUnsafeFPMath() const {
84   const NVPTXTargetLowering *TL = Subtarget->getTargetLowering();
85   return TL->allowUnsafeFPMath(*MF);
86 }
87 
88 bool NVPTXDAGToDAGISel::doRsqrtOpt() const { return EnableRsqrtOpt; }
89 
90 /// Select - Select instructions not customized! Used for
91 /// expanded, promoted and normal instructions.
92 void NVPTXDAGToDAGISel::Select(SDNode *N) {
93 
94   if (N->isMachineOpcode()) {
95     N->setNodeId(-1);
96     return; // Already selected.
97   }
98 
99   switch (N->getOpcode()) {
100   case ISD::LOAD:
101   case ISD::ATOMIC_LOAD:
102     if (tryLoad(N))
103       return;
104     break;
105   case ISD::STORE:
106   case ISD::ATOMIC_STORE:
107     if (tryStore(N))
108       return;
109     break;
110   case ISD::ATOMIC_FENCE:
111     if (tryFence(N))
112       return;
113     break;
114   case ISD::EXTRACT_VECTOR_ELT:
115     if (tryEXTRACT_VECTOR_ELEMENT(N))
116       return;
117     break;
118   case NVPTXISD::SETP_F16X2:
119     SelectSETP_F16X2(N);
120     return;
121   case NVPTXISD::SETP_BF16X2:
122     SelectSETP_BF16X2(N);
123     return;
124   case NVPTXISD::LoadV2:
125   case NVPTXISD::LoadV4:
126     if (tryLoadVector(N))
127       return;
128     break;
129   case NVPTXISD::LDUV2:
130   case NVPTXISD::LDUV4:
131     if (tryLDGLDU(N))
132       return;
133     break;
134   case NVPTXISD::StoreV2:
135   case NVPTXISD::StoreV4:
136     if (tryStoreVector(N))
137       return;
138     break;
139   case NVPTXISD::LoadParam:
140   case NVPTXISD::LoadParamV2:
141   case NVPTXISD::LoadParamV4:
142     if (tryLoadParam(N))
143       return;
144     break;
145   case NVPTXISD::StoreRetval:
146   case NVPTXISD::StoreRetvalV2:
147   case NVPTXISD::StoreRetvalV4:
148     if (tryStoreRetval(N))
149       return;
150     break;
151   case NVPTXISD::StoreParam:
152   case NVPTXISD::StoreParamV2:
153   case NVPTXISD::StoreParamV4:
154   case NVPTXISD::StoreParamS32:
155   case NVPTXISD::StoreParamU32:
156     if (tryStoreParam(N))
157       return;
158     break;
159   case ISD::INTRINSIC_WO_CHAIN:
160     if (tryIntrinsicNoChain(N))
161       return;
162     break;
163   case ISD::INTRINSIC_W_CHAIN:
164     if (tryIntrinsicChain(N))
165       return;
166     break;
167   case ISD::INTRINSIC_VOID:
168     if (tryIntrinsicVoid(N))
169       return;
170     break;
171   case ISD::AND:
172   case ISD::SRA:
173   case ISD::SRL:
174     // Try to select BFE
175     if (tryBFE(N))
176       return;
177     break;
178   case ISD::ADDRSPACECAST:
179     SelectAddrSpaceCast(N);
180     return;
181   case ISD::CopyToReg: {
182     if (N->getOperand(1).getValueType() == MVT::i128) {
183       SelectV2I64toI128(N);
184       return;
185     }
186     break;
187   }
188   case ISD::CopyFromReg: {
189     if (N->getOperand(1).getValueType() == MVT::i128) {
190       SelectI128toV2I64(N);
191       return;
192     }
193     break;
194   }
195   case ISD::FADD:
196   case ISD::FMUL:
197   case ISD::FSUB:
198     if (tryBF16ArithToFMA(N))
199       return;
200     break;
201   default:
202     break;
203   }
204   SelectCode(N);
205 }
206 
207 bool NVPTXDAGToDAGISel::tryIntrinsicChain(SDNode *N) {
208   unsigned IID = N->getConstantOperandVal(1);
209   switch (IID) {
210   default:
211     return false;
212   case Intrinsic::nvvm_ldu_global_f:
213   case Intrinsic::nvvm_ldu_global_i:
214   case Intrinsic::nvvm_ldu_global_p:
215     return tryLDGLDU(N);
216   }
217 }
218 
219 // Map an ISD::CONDCODE value to the appropriate CmpMode expected by
220 // NVPTXInstPrinter::printCmpMode()
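// For example (illustrative; "CC" here stands for the SETP node's
// CondCodeSDNode operand): an ordered less-than compare with FTZ enabled,
//   getPTXCmpMode(*CC /* ISD::SETOLT */, /*FTZ=*/true)
// yields NVPTX::PTXCmpMode::LT | NVPTX::PTXCmpMode::FTZ_FLAG, while the
// unordered ISD::SETULT maps to CmpMode::LTU instead.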
221 static unsigned getPTXCmpMode(const CondCodeSDNode &CondCode, bool FTZ) {
222   using NVPTX::PTXCmpMode::CmpMode;
223   unsigned PTXCmpMode = [](ISD::CondCode CC) {
224     switch (CC) {
225     default:
226       llvm_unreachable("Unexpected condition code.");
227     case ISD::SETOEQ:
228       return CmpMode::EQ;
229     case ISD::SETOGT:
230       return CmpMode::GT;
231     case ISD::SETOGE:
232       return CmpMode::GE;
233     case ISD::SETOLT:
234       return CmpMode::LT;
235     case ISD::SETOLE:
236       return CmpMode::LE;
237     case ISD::SETONE:
238       return CmpMode::NE;
239     case ISD::SETO:
240       return CmpMode::NUM;
241     case ISD::SETUO:
242       return CmpMode::NotANumber;
243     case ISD::SETUEQ:
244       return CmpMode::EQU;
245     case ISD::SETUGT:
246       return CmpMode::GTU;
247     case ISD::SETUGE:
248       return CmpMode::GEU;
249     case ISD::SETULT:
250       return CmpMode::LTU;
251     case ISD::SETULE:
252       return CmpMode::LEU;
253     case ISD::SETUNE:
254       return CmpMode::NEU;
255     case ISD::SETEQ:
256       return CmpMode::EQ;
257     case ISD::SETGT:
258       return CmpMode::GT;
259     case ISD::SETGE:
260       return CmpMode::GE;
261     case ISD::SETLT:
262       return CmpMode::LT;
263     case ISD::SETLE:
264       return CmpMode::LE;
265     case ISD::SETNE:
266       return CmpMode::NE;
267     }
268   }(CondCode.get());
269 
270   if (FTZ)
271     PTXCmpMode |= NVPTX::PTXCmpMode::FTZ_FLAG;
272 
273   return PTXCmpMode;
274 }
275 
276 bool NVPTXDAGToDAGISel::SelectSETP_F16X2(SDNode *N) {
277   unsigned PTXCmpMode =
278       getPTXCmpMode(*cast<CondCodeSDNode>(N->getOperand(2)), useF32FTZ());
279   SDLoc DL(N);
280   SDNode *SetP = CurDAG->getMachineNode(
281       NVPTX::SETP_f16x2rr, DL, MVT::i1, MVT::i1, N->getOperand(0),
282       N->getOperand(1), CurDAG->getTargetConstant(PTXCmpMode, DL, MVT::i32));
283   ReplaceNode(N, SetP);
284   return true;
285 }
286 
287 bool NVPTXDAGToDAGISel::SelectSETP_BF16X2(SDNode *N) {
288   unsigned PTXCmpMode =
289       getPTXCmpMode(*cast<CondCodeSDNode>(N->getOperand(2)), useF32FTZ());
290   SDLoc DL(N);
291   SDNode *SetP = CurDAG->getMachineNode(
292       NVPTX::SETP_bf16x2rr, DL, MVT::i1, MVT::i1, N->getOperand(0),
293       N->getOperand(1), CurDAG->getTargetConstant(PTXCmpMode, DL, MVT::i32));
294   ReplaceNode(N, SetP);
295   return true;
296 }
297 
298 // Find all instances of extract_vector_elt that use this v2f16 vector
299 // and coalesce them into a scattering move instruction.
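// For example (an illustrative DAG shape, not taken from a real test): given
//   %e0 = extract_vector_elt %v:v2f16, 0
//   %e1 = extract_vector_elt %v:v2f16, 1
// both extracts are rewritten to use results 0 and 1 of a single I32toV2I16
// machine node fed by %v, so the vector is unpacked exactly once.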
300 bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT(SDNode *N) {
301   SDValue Vector = N->getOperand(0);
302 
303   // We only care about 16x2 as it's the only real vector type we
304   // need to deal with.
305   MVT VT = Vector.getSimpleValueType();
306   if (!Isv2x16VT(VT))
307     return false;
308   // Find and record all uses of this vector that extract element 0 or 1.
309   SmallVector<SDNode *, 4> E0, E1;
310   for (auto *U : Vector.getNode()->users()) {
311     if (U->getOpcode() != ISD::EXTRACT_VECTOR_ELT)
312       continue;
313     if (U->getOperand(0) != Vector)
314       continue;
315     if (const ConstantSDNode *IdxConst =
316             dyn_cast<ConstantSDNode>(U->getOperand(1))) {
317       if (IdxConst->getZExtValue() == 0)
318         E0.push_back(U);
319       else if (IdxConst->getZExtValue() == 1)
320         E1.push_back(U);
321       else
322         llvm_unreachable("Invalid vector index.");
323     }
324   }
325 
326   // There's no point scattering f16x2 if we only ever access one
327   // element of it.
328   if (E0.empty() || E1.empty())
329     return false;
330 
331   // Merge (f16 extractelt(V, 0), f16 extractelt(V, 1))
332   // into the two results of a single I32toV2I16(V) machine node.
333   MVT EltVT = VT.getVectorElementType();
334   SDNode *ScatterOp =
335       CurDAG->getMachineNode(NVPTX::I32toV2I16, SDLoc(N), EltVT, EltVT, Vector);
336   for (auto *Node : E0)
337     ReplaceUses(SDValue(Node, 0), SDValue(ScatterOp, 0));
338   for (auto *Node : E1)
339     ReplaceUses(SDValue(Node, 0), SDValue(ScatterOp, 1));
340 
341   return true;
342 }
343 
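// Infer the PTX statespace of a memory access from the address space of the
// pointer value recorded in its MachineMemOperand; falls back to Generic when
// no pointer value is available or the address space is not recognized.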
344 static unsigned int getCodeAddrSpace(MemSDNode *N) {
345   const Value *Src = N->getMemOperand()->getValue();
346 
347   if (!Src)
348     return NVPTX::AddressSpace::Generic;
349 
350   if (auto *PT = dyn_cast<PointerType>(Src->getType())) {
351     switch (PT->getAddressSpace()) {
352     case llvm::ADDRESS_SPACE_LOCAL:
353       return NVPTX::AddressSpace::Local;
354     case llvm::ADDRESS_SPACE_GLOBAL:
355       return NVPTX::AddressSpace::Global;
356     case llvm::ADDRESS_SPACE_SHARED:
357       return NVPTX::AddressSpace::Shared;
358     case llvm::ADDRESS_SPACE_GENERIC:
359       return NVPTX::AddressSpace::Generic;
360     case llvm::ADDRESS_SPACE_PARAM:
361       return NVPTX::AddressSpace::Param;
362     case llvm::ADDRESS_SPACE_CONST:
363       return NVPTX::AddressSpace::Const;
364     default: break;
365     }
366   }
367   return NVPTX::AddressSpace::Generic;
368 }
369 
370 namespace {
371 
372 struct OperationOrderings {
373   NVPTX::Ordering InstructionOrdering, FenceOrdering;
374   OperationOrderings(NVPTX::Ordering IO = NVPTX::Ordering::NotAtomic,
375                      NVPTX::Ordering FO = NVPTX::Ordering::NotAtomic)
376       : InstructionOrdering(IO), FenceOrdering(FO) {}
377 };
378 
379 static OperationOrderings
380 getOperationOrderings(MemSDNode *N, const NVPTXSubtarget *Subtarget) {
381   AtomicOrdering Ordering = N->getSuccessOrdering();
382   auto CodeAddrSpace = getCodeAddrSpace(N);
383 
384   bool HasMemoryOrdering = Subtarget->hasMemoryOrdering();
385   bool HasRelaxedMMIO = Subtarget->hasRelaxedMMIO();
386 
387   // clang-format off
388 
389   // Lowering for Load/Store Operations (note: AcquireRelease loads or stores are an error).
390   // Note: uses of Relaxed in the Atomic column of this table refer
391   // to LLVM AtomicOrdering::Monotonic.
392   //
393   // | Atomic  | Volatile | Statespace         | PTX sm_60- | PTX sm_70+                   |
394   // |---------|----------|--------------------|------------|------------------------------|
395   // | No      | No       | All                | plain      | .weak                        |
396   // | No      | Yes      | Generic,Shared,    | .volatile  | .volatile                    |
397   // |         |          | Global [0]         |            |                              |
398   // | No      | Yes      | Local,Const,Param  | plain [1]  | .weak [1]                    |
399   // | Unorder | Yes/No   | All                | == Relaxed | == Relaxed                   |
400   // | Relaxed | No       | Generic,Shared,    | .volatile  | <atomic sem>                 |
401   // |         |          | Global [0]         |            |                              |
402   // | Other   | No       | Generic,Shared,    | Error [2]  | <atomic sem>                 |
403   // |         |          | Global [0]         |            |                              |
404   // | Yes     | No       | Local,Const,Param  | plain [1]  | .weak [1]                    |
405   // | Relaxed | Yes      | Generic,Shared [0] | .volatile  | .volatile                    |
406   // | Relaxed | Yes      | Global [0]         | .volatile  | .mmio.relaxed.sys (PTX 8.2+) |
407   // |         |          |                    |            |  or .volatile (PTX 8.1-)     |
408   // | Relaxed | Yes      | Local,Const,Param  | plain [1]  | .weak [1]                    |
409   // | Other   | Yes      | Generic,Shared,    | Error [2]  | <atomic sem> [3]             |
410   // |         |          | Global [0]         |            |                              |
411 
412   // Lowering of CUDA C++ SequentiallyConsistent Operations and Fences to PTX
413   // by following the ABI proven sound in:
414   //   Lustig et al, A Formal Analysis of the NVIDIA PTX Memory Consistency Model, ASPLOS’19.
415   //   https://dl.acm.org/doi/pdf/10.1145/3297858.3304043
416   //
417   // | CUDA C++ Atomic Operation or Atomic Fence            | PTX Atomic Operation or Fence |
418   // |------------------------------------------------------|-------------------------------|
419   // | cuda::atomic_thread_fence                            | fence.sc.<scope>;             |
420   // |   (memory_order_seq_cst, cuda::thread_scope_<scope>) |                               |
421   // |------------------------------------------------------|-------------------------------|
422   // | cuda::atomic_load                                    | fence.sc.<scope>;             |
423   // |   (memory_order_seq_cst, cuda::thread_scope_<scope>) | ld.acquire.<scope>;           |
424   // |------------------------------------------------------|-------------------------------|
425   // | cuda::atomic_store                                   | fence.sc.<scope>;             |
426   // |   (memory_order_seq_cst, cuda::thread_scope_<scope>) | st.release.<scope>;           |
427   // |------------------------------------------------------|-------------------------------|
428   // | cuda::atomic_fetch_<op>                              | fence.sc.<scope>;             |
429   // |   (memory_order_seq_cst, cuda::thread_scope_<scope>) | atom.acq_rel.<scope>;         |
430 
431   // clang-format on
432 
433   // [0]: volatile and atomics are only supported on global or shared
434   //      memory locations, accessed via generic/shared/global pointers.
435   //      MMIO is only supported on global memory locations,
436   //      accessed via generic/global pointers.
437   // TODO: Implement MMIO access via generic pointer to global.
438   //       Currently implemented for global pointers only.
439 
440   // [1]: Lowering volatile/atomic operations to non-volatile/non-atomic
441   //      PTX instructions fails to preserve their C++ side-effects.
442   //
443   //      Example (https://github.com/llvm/llvm-project/issues/62057):
444   //
445   //          void example() {
446   //              std::atomic<bool> True = true;
447   //              while (True.load(std::memory_order_relaxed));
448   //          }
449   //
450   //      A C++ program that calls "example" is well-defined: the infinite loop
451   //      performs an atomic operation. By lowering volatile/atomics to
452   //      "weak" memory operations, we are transforming the above into:
453   //
454   //          void undefined_behavior() {
455   //              bool True = true;
456   //              while (True);
457   //          }
458   //
459   //      which exhibits undefined behavior in both C++ and PTX.
460   //
461   //      Calling "example" in CUDA C++ compiled for sm_60- exhibits undefined
462   //      behavior due to lack of Independent Forward Progress. Lowering these
463   //      to weak memory operations in sm_60- is therefore fine.
464   //
465   //      TODO: lower atomic and volatile operations to memory locations
466   //      in local, const, and param to two PTX instructions in sm_70+:
467   //        - the "weak" memory instruction we are currently lowering to, and
468   //        - some other instruction that preserves the side-effect, e.g.,
469   //          a dead dummy volatile load.
470   if (CodeAddrSpace == NVPTX::AddressSpace::Local ||
471       CodeAddrSpace == NVPTX::AddressSpace::Const ||
472       CodeAddrSpace == NVPTX::AddressSpace::Param) {
473     return NVPTX::Ordering::NotAtomic;
474   }
475 
476   // [2]: Atomics with Ordering different than Unordered or Relaxed are not
477   //      supported on sm_60 and older; this includes volatile atomics.
478   if (!(Ordering == AtomicOrdering::NotAtomic ||
479         Ordering == AtomicOrdering::Unordered ||
480         Ordering == AtomicOrdering::Monotonic) &&
481       !HasMemoryOrdering) {
482     report_fatal_error(
483         formatv("PTX does not support \"atomic\" for orderings different than "
484                 "\"NotAtomic\" or \"Monotonic\" for sm_60 or older, but order "
485                 "is: \"{}\".",
486                 toIRString(Ordering)));
487   }
488 
489   // [3]: TODO: these should eventually use .mmio<.atomic sem>; for now we drop
490   // the volatile semantics and preserve the atomic ones.
491 
492   // PTX volatile and PTX atomics are not available for statespaces other than
493   // .generic, .global, and .shared. The behavior of PTX volatile and PTX
494   // atomics is undefined if the generic address does not refer to a .global or
495   // .shared memory location.
496   bool AddrGenericOrGlobalOrShared =
497       (CodeAddrSpace == NVPTX::AddressSpace::Generic ||
498        CodeAddrSpace == NVPTX::AddressSpace::Global ||
499        CodeAddrSpace == NVPTX::AddressSpace::Shared);
500   if (!AddrGenericOrGlobalOrShared)
501     return NVPTX::Ordering::NotAtomic;
502 
503   bool UseRelaxedMMIO =
504       HasRelaxedMMIO && CodeAddrSpace == NVPTX::AddressSpace::Global;
505 
506   switch (Ordering) {
507   case AtomicOrdering::NotAtomic:
508     return N->isVolatile() ? NVPTX::Ordering::Volatile
509                            : NVPTX::Ordering::NotAtomic;
510   case AtomicOrdering::Unordered:
511     // We lower unordered in the exact same way as 'monotonic' to respect
512     // LLVM IR atomicity requirements.
513   case AtomicOrdering::Monotonic:
514     if (N->isVolatile())
515       return UseRelaxedMMIO ? NVPTX::Ordering::RelaxedMMIO
516                             : NVPTX::Ordering::Volatile;
517     else
518       return HasMemoryOrdering ? NVPTX::Ordering::Relaxed
519                                : NVPTX::Ordering::Volatile;
520   // case AtomicOrdering::Consume: // If LLVM ever provides this, lower it to
521   // Acquire.
522   case AtomicOrdering::Acquire:
523     if (!N->readMem())
524       report_fatal_error(
525           formatv("PTX only supports Acquire Ordering on reads: {}",
526                   N->getOperationName()));
527     return NVPTX::Ordering::Acquire;
528   case AtomicOrdering::Release:
529     if (!N->writeMem())
530       report_fatal_error(
531           formatv("PTX only supports Release Ordering on writes: {}",
532                   N->getOperationName()));
533     return NVPTX::Ordering::Release;
534   case AtomicOrdering::AcquireRelease: {
535     report_fatal_error(
536         formatv("NVPTX does not support AcquireRelease Ordering on "
537                 "read-modify-write "
538                 "yet and PTX does not support it on loads or stores: {}",
539                 N->getOperationName()));
540   }
541   case AtomicOrdering::SequentiallyConsistent: {
542     // LLVM-IR SequentiallyConsistent atomics map to a two-instruction PTX
543     // sequence including a "fence.sc.<scope>" and the memory instruction with an
544     // Ordering that differs from "sc": acq, rel, or acq_rel, depending on
545     // whether the memory operation is a read, write, or read-modify-write.
546     //
547     // This sets the ordering of the fence to SequentiallyConsistent, and
548     // sets the corresponding ordering for the instruction.
549     NVPTX::Ordering InstrOrder;
550     if (N->readMem())
551       InstrOrder = NVPTX::Ordering::Acquire;
552     else if (N->writeMem())
553       InstrOrder = NVPTX::Ordering::Release;
554     else
555       report_fatal_error(
556           formatv("NVPTX does not support SequentiallyConsistent Ordering on "
557                   "read-modify-writes yet: {}",
558                   N->getOperationName()));
559     return OperationOrderings(InstrOrder,
560                               NVPTX::Ordering::SequentiallyConsistent);
561   }
562   }
563   report_fatal_error(
564       formatv("NVPTX backend does not support AtomicOrdering \"{}\" yet.",
565               toIRString(Ordering)));
566 }
567 
568 } // namespace
569 
570 NVPTX::Scope NVPTXDAGToDAGISel::getOperationScope(MemSDNode *N,
571                                                   NVPTX::Ordering O) const {
572   switch (O) {
573   case NVPTX::Ordering::NotAtomic:
574   case NVPTX::Ordering::Volatile: // Non-atomic volatile operations
575     // NVPTX uses Thread scope as the scope of non-atomic operations.
576     return NVPTX::Scope::Thread;
577   case NVPTX::Ordering::RelaxedMMIO:
578     // RelaxedMMIO operations are always system scope.
579     // If a RelaxedMMIO order was generated from an atomic volatile operation
580     // with a smaller thread scope, we bump it here to system scope.
581     return NVPTX::Scope::System;
582   case NVPTX::Ordering::Relaxed:
583   case NVPTX::Ordering::Acquire:
584   case NVPTX::Ordering::Release:
585   case NVPTX::Ordering::AcquireRelease:
586   case NVPTX::Ordering::SequentiallyConsistent:
587     auto S = Scopes[N->getSyncScopeID()];
588 
589     // Atomic operations must have a scope greater than thread.
590     if (S == NVPTX::Scope::Thread)
591       report_fatal_error(
592           formatv("Atomics need scope > \"{}\".", ScopeToString(S)));
593 
594     // If scope is cluster, clusters must be supported.
595     if (S == NVPTX::Scope::Cluster)
596       Subtarget->failIfClustersUnsupported("cluster scope");
597 
598     // If the operation is volatile, its scope is system.
599     return N->isVolatile() ? NVPTX::Scope::System : S;
600   }
601   llvm_unreachable("unhandled ordering");
602 }
603 
604 static bool canLowerToLDG(MemSDNode *N, const NVPTXSubtarget &Subtarget,
605                           unsigned CodeAddrSpace, MachineFunction *F) {
606   // We use ldg (i.e. ld.global.nc) for invariant loads from the global address
607   // space.
608   //
609   // We have two ways of identifying invariant loads: Loads may be explicitly
610   // marked as invariant, or we may infer them to be invariant.
611   //
612   // We currently infer invariance for loads from
613   //  - constant global variables, and
614   //  - kernel function pointer params that are noalias (i.e. __restrict) and
615   //    never written to.
616   //
617   // TODO: Perform a more powerful invariance analysis (ideally IPO, and ideally
618   // not during the SelectionDAG phase).
619   //
620   // TODO: Infer invariance only at -O2.  We still want to use ldg at -O0 for
621   // explicitly invariant loads because these are how clang tells us to use ldg
622   // when the user uses a builtin.
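  // For example (illustrative): a load tagged !invariant from an addrspace(1)
  // (global) pointer, or a load through a kernel pointer parameter that is
  // both noalias (__restrict) and never written to, is handed off to
  // tryLDGLDU() and ultimately selected as ld.global.nc.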
623   if (!Subtarget.hasLDG() || CodeAddrSpace != NVPTX::AddressSpace::Global)
624     return false;
625 
626   if (N->isInvariant())
627     return true;
628 
629   bool IsKernelFn = isKernelFunction(F->getFunction());
630 
631   // We use getUnderlyingObjects() here instead of getUnderlyingObject() mainly
632   // because the former looks through phi nodes while the latter does not. We
633   // need to look through phi nodes to handle pointer induction variables.
634   SmallVector<const Value *, 8> Objs;
635   getUnderlyingObjects(N->getMemOperand()->getValue(), Objs);
636 
637   return all_of(Objs, [&](const Value *V) {
638     if (auto *A = dyn_cast<const Argument>(V))
639       return IsKernelFn && A->onlyReadsMemory() && A->hasNoAliasAttr();
640     if (auto *GV = dyn_cast<const GlobalVariable>(V))
641       return GV->isConstant();
642     return false;
643   });
644 }
645 
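// Maps a fence ordering and scope to the corresponding fence machine opcode.
// Illustrative example (the PTX spellings are shown for orientation only):
//   getFenceOp(NVPTX::Ordering::AcquireRelease, NVPTX::Scope::Device, T)
// returns atomic_thread_fence_acq_rel_gpu ("fence.acq_rel.gpu") when
// T->hasMemoryOrdering(), and INT_MEMBAR_GL ("membar.gl") otherwise.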
646 static unsigned int getFenceOp(NVPTX::Ordering O, NVPTX::Scope S,
647                                NVPTXSubtarget const *T) {
648   if (S == NVPTX::Scope::Cluster)
649     T->failIfClustersUnsupported(".cluster scope fence");
650 
651   switch (O) {
652   case NVPTX::Ordering::Acquire:
653   case NVPTX::Ordering::Release:
654   case NVPTX::Ordering::AcquireRelease: {
655     switch (S) {
656     case NVPTX::Scope::System:
657       return T->hasMemoryOrdering() ? NVPTX::atomic_thread_fence_acq_rel_sys
658                                     : NVPTX::INT_MEMBAR_SYS;
659     case NVPTX::Scope::Block:
660       return T->hasMemoryOrdering() ? NVPTX::atomic_thread_fence_acq_rel_cta
661                                     : NVPTX::INT_MEMBAR_CTA;
662     case NVPTX::Scope::Cluster:
663       return NVPTX::atomic_thread_fence_acq_rel_cluster;
664     case NVPTX::Scope::Device:
665       return T->hasMemoryOrdering() ? NVPTX::atomic_thread_fence_acq_rel_gpu
666                                     : NVPTX::INT_MEMBAR_GL;
667     case NVPTX::Scope::Thread:
668       report_fatal_error(
669           formatv("Unsupported scope \"{}\" for acquire/release/acq_rel fence.",
670                   ScopeToString(S)));
671     }
672     break;
673   }
674   case NVPTX::Ordering::SequentiallyConsistent: {
675     switch (S) {
676     case NVPTX::Scope::System:
677       return T->hasMemoryOrdering() ? NVPTX::atomic_thread_fence_seq_cst_sys
678                                     : NVPTX::INT_MEMBAR_SYS;
679     case NVPTX::Scope::Block:
680       return T->hasMemoryOrdering() ? NVPTX::atomic_thread_fence_seq_cst_cta
681                                     : NVPTX::INT_MEMBAR_CTA;
682     case NVPTX::Scope::Cluster:
683       return NVPTX::atomic_thread_fence_seq_cst_cluster;
684     case NVPTX::Scope::Device:
685       return T->hasMemoryOrdering() ? NVPTX::atomic_thread_fence_seq_cst_gpu
686                                     : NVPTX::INT_MEMBAR_GL;
687     case NVPTX::Scope::Thread:
688       report_fatal_error(formatv("Unsupported scope \"{}\" for seq_cst fence.",
689                                  ScopeToString(S)));
690     }
691     break;
692   }
693   case NVPTX::Ordering::NotAtomic:
694   case NVPTX::Ordering::Relaxed:
695   case NVPTX::Ordering::Volatile:
696   case NVPTX::Ordering::RelaxedMMIO:
697     report_fatal_error(
698         formatv("Unsupported \"{}\" ordering and \"{}\" scope for fence.",
699                 OrderingToString(O), ScopeToString(S)));
700   }
701   llvm_unreachable("unhandled ordering");
702 }
703 
704 // Returns the memory order and scope of a memory instruction, and inserts
705 // before the instruction any fence that is required to implement its memory
706 // ordering.
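// Illustrative example (assuming sm_70+, i.e. PTX memory-ordering support):
// for an IR "load atomic i32, ptr addrspace(1) %p seq_cst", the helpers above
// return {Acquire, SequentiallyConsistent}, so a fence.sc.<scope> machine node
// is threaded onto the chain here and the load itself is later selected as an
// acquire load (e.g. "ld.acquire.sys.global.u32"), matching the
// SequentiallyConsistent table above.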
707 std::pair<NVPTX::Ordering, NVPTX::Scope>
708 NVPTXDAGToDAGISel::insertMemoryInstructionFence(SDLoc DL, SDValue &Chain,
709                                                 MemSDNode *N) {
710   auto [InstructionOrdering, FenceOrdering] =
711       getOperationOrderings(N, Subtarget);
712   auto Scope = getOperationScope(N, InstructionOrdering);
713 
714   // If a fence is required before the operation, insert it:
715   switch (NVPTX::Ordering(FenceOrdering)) {
716   case NVPTX::Ordering::NotAtomic:
717     break;
718   case NVPTX::Ordering::SequentiallyConsistent: {
719     auto Op = getFenceOp(FenceOrdering, Scope, Subtarget);
720     Chain = SDValue(CurDAG->getMachineNode(Op, DL, MVT::Other, Chain), 0);
721     break;
722   }
723   default:
724     report_fatal_error(
725         formatv("Unexpected fence ordering: \"{}\".",
726                 OrderingToString(NVPTX::Ordering(FenceOrdering))));
727   }
728   return {InstructionOrdering, Scope};
729 }
730 
731 bool NVPTXDAGToDAGISel::tryIntrinsicNoChain(SDNode *N) {
732   unsigned IID = N->getConstantOperandVal(0);
733   switch (IID) {
734   default:
735     return false;
736   case Intrinsic::nvvm_texsurf_handle_internal:
737     SelectTexSurfHandle(N);
738     return true;
739   }
740 }
741 
742 void NVPTXDAGToDAGISel::SelectTexSurfHandle(SDNode *N) {
743   // Op 0 is the intrinsic ID
744   SDValue Wrapper = N->getOperand(1);
745   SDValue GlobalVal = Wrapper.getOperand(0);
746   ReplaceNode(N, CurDAG->getMachineNode(NVPTX::texsurf_handles, SDLoc(N),
747                                         MVT::i64, GlobalVal));
748 }
749 
750 void NVPTXDAGToDAGISel::SelectAddrSpaceCast(SDNode *N) {
751   SDValue Src = N->getOperand(0);
752   AddrSpaceCastSDNode *CastN = cast<AddrSpaceCastSDNode>(N);
753   unsigned SrcAddrSpace = CastN->getSrcAddressSpace();
754   unsigned DstAddrSpace = CastN->getDestAddressSpace();
755   SDLoc DL(N);
756   assert(SrcAddrSpace != DstAddrSpace &&
757          "addrspacecast must be between different address spaces");
758 
759   if (DstAddrSpace == ADDRESS_SPACE_GENERIC) {
760     // Specific to generic
761 
762     if (TM.is64Bit() && TM.getPointerSizeInBits(SrcAddrSpace) == 32) {
763       SDValue CvtNone =
764           CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL, MVT::i32);
765       SDNode *Cvt = CurDAG->getMachineNode(NVPTX::CVT_u64_u32, DL, MVT::i64,
766                                            Src, CvtNone);
767       Src = SDValue(Cvt, 0);
768     }
769 
770     unsigned Opc;
771     switch (SrcAddrSpace) {
772     default: report_fatal_error("Bad address space in addrspacecast");
773     case ADDRESS_SPACE_GLOBAL:
774       Opc = TM.is64Bit() ? NVPTX::cvta_global_64 : NVPTX::cvta_global;
775       break;
776     case ADDRESS_SPACE_SHARED:
777       Opc = TM.is64Bit() ? NVPTX::cvta_shared_64 : NVPTX::cvta_shared;
778       break;
779     case ADDRESS_SPACE_CONST:
780       Opc = TM.is64Bit() ? NVPTX::cvta_const_64 : NVPTX::cvta_const;
781       break;
782     case ADDRESS_SPACE_LOCAL:
783       Opc = TM.is64Bit() ? NVPTX::cvta_local_64 : NVPTX::cvta_local;
784       break;
785     }
786     ReplaceNode(N, CurDAG->getMachineNode(Opc, DL, N->getValueType(0), Src));
787     return;
788   } else {
789     // Generic to specific
790     if (SrcAddrSpace != 0)
791       report_fatal_error("Cannot cast between two non-generic address spaces");
792     unsigned Opc;
793     switch (DstAddrSpace) {
794     default: report_fatal_error("Bad address space in addrspacecast");
795     case ADDRESS_SPACE_GLOBAL:
796       Opc = TM.is64Bit() ? NVPTX::cvta_to_global_64 : NVPTX::cvta_to_global;
797       break;
798     case ADDRESS_SPACE_SHARED:
799       Opc = TM.is64Bit() ? NVPTX::cvta_to_shared_64 : NVPTX::cvta_to_shared;
800       break;
801     case ADDRESS_SPACE_CONST:
802       Opc = TM.is64Bit() ? NVPTX::cvta_to_const_64 : NVPTX::cvta_to_const;
803       break;
804     case ADDRESS_SPACE_LOCAL:
805       Opc = TM.is64Bit() ? NVPTX::cvta_to_local_64 : NVPTX::cvta_to_local;
806       break;
807     case ADDRESS_SPACE_PARAM:
808       Opc = TM.is64Bit() ? NVPTX::IMOV64rr : NVPTX::IMOV32rr;
809       break;
810     }
811 
812     SDNode *CVTA = CurDAG->getMachineNode(Opc, DL, N->getValueType(0), Src);
813     if (TM.is64Bit() && TM.getPointerSizeInBits(DstAddrSpace) == 32) {
814       SDValue CvtNone =
815           CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL, MVT::i32);
816       CVTA = CurDAG->getMachineNode(NVPTX::CVT_u32_u64, DL, MVT::i32,
817                                     SDValue(CVTA, 0), CvtNone);
818     }
819 
820     ReplaceNode(N, CVTA);
821     return;
822   }
823 }
824 
825 // Helper function to reduce the amount of boilerplate code needed for opcode
826 // selection.
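// For example (illustrative, mirroring the call made in tryLoad() below):
//   pickOpcodeForVT(MVT::f32, NVPTX::LD_i8_avar, NVPTX::LD_i16_avar,
//                   NVPTX::LD_i32_avar, NVPTX::LD_i64_avar,
//                   NVPTX::LD_f32_avar, NVPTX::LD_f64_avar)
// returns NVPTX::LD_f32_avar, while an unhandled type such as MVT::i128
// yields std::nullopt and the caller bails out of custom selection.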
827 static std::optional<unsigned>
828 pickOpcodeForVT(MVT::SimpleValueType VT, unsigned Opcode_i8,
829                 unsigned Opcode_i16, unsigned Opcode_i32,
830                 std::optional<unsigned> Opcode_i64, unsigned Opcode_f32,
831                 std::optional<unsigned> Opcode_f64) {
832   switch (VT) {
833   case MVT::i1:
834   case MVT::i8:
835     return Opcode_i8;
836   case MVT::i16:
837     return Opcode_i16;
838   case MVT::i32:
839     return Opcode_i32;
840   case MVT::i64:
841     return Opcode_i64;
842   case MVT::f16:
843   case MVT::bf16:
844     return Opcode_i16;
845   case MVT::v2f16:
846   case MVT::v2bf16:
847   case MVT::v2i16:
848   case MVT::v4i8:
849     return Opcode_i32;
850   case MVT::f32:
851     return Opcode_f32;
852   case MVT::f64:
853     return Opcode_f64;
854   default:
855     return std::nullopt;
856   }
857 }
858 
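// Classify the PTX register type used for a load/store of the given value
// type: f16/bf16 and their x2 vectors go through untyped (bN) registers,
// other floating-point types through float registers, and everything else
// through unsigned integer registers.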
859 static int getLdStRegType(EVT VT) {
860   if (VT.isFloatingPoint())
861     switch (VT.getSimpleVT().SimpleTy) {
862     case MVT::f16:
863     case MVT::bf16:
864     case MVT::v2f16:
865     case MVT::v2bf16:
866       return NVPTX::PTXLdStInstCode::Untyped;
867     default:
868       return NVPTX::PTXLdStInstCode::Float;
869     }
870   else
871     return NVPTX::PTXLdStInstCode::Unsigned;
872 }
873 
874 bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
875   MemSDNode *LD = cast<MemSDNode>(N);
876   assert(LD->readMem() && "Expected load");
877 
878   // We do not support pre/post increment/decrement addressing.
879   LoadSDNode *PlainLoad = dyn_cast<LoadSDNode>(N);
880   if (PlainLoad && PlainLoad->isIndexed())
881     return false;
882 
883   EVT LoadedVT = LD->getMemoryVT();
884   if (!LoadedVT.isSimple())
885     return false;
886 
887   // Address Space Setting
888   unsigned int CodeAddrSpace = getCodeAddrSpace(LD);
889   if (canLowerToLDG(LD, *Subtarget, CodeAddrSpace, MF)) {
890     return tryLDGLDU(N);
891   }
892   unsigned int PointerSize =
893       CurDAG->getDataLayout().getPointerSizeInBits(LD->getAddressSpace());
894 
895   SDLoc DL(N);
896   SDValue Chain = N->getOperand(0);
897   auto [Ordering, Scope] = insertMemoryInstructionFence(DL, Chain, LD);
898 
899   // Type Setting: fromType + fromTypeWidth
900   //
901   // Signed   : ISD::SEXTLOAD
902   // Unsigned : ISD::ZEXTLOAD, ISD::NON_EXTLOAD or ISD::EXTLOAD and the
903   //            type is integer
904   // Float    : ISD::NON_EXTLOAD or ISD::EXTLOAD and the type is float
905   MVT SimpleVT = LoadedVT.getSimpleVT();
906   MVT ScalarVT = SimpleVT.getScalarType();
907   // Read at least 8 bits (predicates are stored as 8-bit values)
908   unsigned FromTypeWidth = std::max(8U, (unsigned)ScalarVT.getSizeInBits());
909   unsigned int FromType;
910 
911   // Vector Setting
912   unsigned VecType = NVPTX::PTXLdStInstCode::Scalar;
913   if (SimpleVT.isVector()) {
914     assert((Isv2x16VT(LoadedVT) || LoadedVT == MVT::v4i8) &&
915            "Unexpected vector type");
916     // v2f16/v2bf16/v2i16 is loaded using ld.b32
917     FromTypeWidth = 32;
918   }
919 
920   if (PlainLoad && (PlainLoad->getExtensionType() == ISD::SEXTLOAD))
921     FromType = NVPTX::PTXLdStInstCode::Signed;
922   else
923     FromType = getLdStRegType(ScalarVT);
924 
925   // Create the machine instruction DAG
926   SDValue N1 = N->getOperand(1);
927   SDValue Addr;
928   SDValue Offset, Base;
929   std::optional<unsigned> Opcode;
930   MVT::SimpleValueType TargetVT = LD->getSimpleValueType(0).SimpleTy;
931 
932   SmallVector<SDValue, 12> Ops({getI32Imm(Ordering, DL), getI32Imm(Scope, DL),
933                                 getI32Imm(CodeAddrSpace, DL),
934                                 getI32Imm(VecType, DL), getI32Imm(FromType, DL),
935                                 getI32Imm(FromTypeWidth, DL)});
936 
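  // Pick an addressing mode for the load address, trying the forms below in
  // order (the suffixes describe the opcode naming used in this file):
  //   _avar: direct symbol address, _asi: symbol + immediate offset,
  //   _ari: register + immediate offset, _areg: plain register.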
937   if (SelectDirectAddr(N1, Addr)) {
938     Opcode = pickOpcodeForVT(TargetVT, NVPTX::LD_i8_avar, NVPTX::LD_i16_avar,
939                              NVPTX::LD_i32_avar, NVPTX::LD_i64_avar,
940                              NVPTX::LD_f32_avar, NVPTX::LD_f64_avar);
941     if (!Opcode)
942       return false;
943     Ops.append({Addr, Chain});
944   } else if (PointerSize == 64 ? SelectADDRsi64(N1.getNode(), N1, Base, Offset)
945                                : SelectADDRsi(N1.getNode(), N1, Base, Offset)) {
946     Opcode = pickOpcodeForVT(TargetVT, NVPTX::LD_i8_asi, NVPTX::LD_i16_asi,
947                              NVPTX::LD_i32_asi, NVPTX::LD_i64_asi,
948                              NVPTX::LD_f32_asi, NVPTX::LD_f64_asi);
949     if (!Opcode)
950       return false;
951     Ops.append({Base, Offset, Chain});
952   } else if (PointerSize == 64 ? SelectADDRri64(N1.getNode(), N1, Base, Offset)
953                                : SelectADDRri(N1.getNode(), N1, Base, Offset)) {
954     if (PointerSize == 64)
955       Opcode =
956           pickOpcodeForVT(TargetVT, NVPTX::LD_i8_ari_64, NVPTX::LD_i16_ari_64,
957                           NVPTX::LD_i32_ari_64, NVPTX::LD_i64_ari_64,
958                           NVPTX::LD_f32_ari_64, NVPTX::LD_f64_ari_64);
959     else
960       Opcode = pickOpcodeForVT(TargetVT, NVPTX::LD_i8_ari, NVPTX::LD_i16_ari,
961                                NVPTX::LD_i32_ari, NVPTX::LD_i64_ari,
962                                NVPTX::LD_f32_ari, NVPTX::LD_f64_ari);
963     if (!Opcode)
964       return false;
965     Ops.append({Base, Offset, Chain});
966   } else {
967     if (PointerSize == 64)
968       Opcode =
969           pickOpcodeForVT(TargetVT, NVPTX::LD_i8_areg_64, NVPTX::LD_i16_areg_64,
970                           NVPTX::LD_i32_areg_64, NVPTX::LD_i64_areg_64,
971                           NVPTX::LD_f32_areg_64, NVPTX::LD_f64_areg_64);
972     else
973       Opcode = pickOpcodeForVT(TargetVT, NVPTX::LD_i8_areg, NVPTX::LD_i16_areg,
974                                NVPTX::LD_i32_areg, NVPTX::LD_i64_areg,
975                                NVPTX::LD_f32_areg, NVPTX::LD_f64_areg);
976     if (!Opcode)
977       return false;
978     Ops.append({N1, Chain});
979   }
980 
981   SDNode *NVPTXLD =
982       CurDAG->getMachineNode(*Opcode, DL, TargetVT, MVT::Other, Ops);
983   if (!NVPTXLD)
984     return false;
985 
986   MachineMemOperand *MemRef = cast<MemSDNode>(N)->getMemOperand();
987   CurDAG->setNodeMemRefs(cast<MachineSDNode>(NVPTXLD), {MemRef});
988 
989   ReplaceNode(N, NVPTXLD);
990   return true;
991 }
992 
993 static bool isVectorElementTypeUpsized(EVT EltVT) {
994   // Despite vectors like v8i8, v16i8, v8i16 being within the bit-limit for
995   // total load/store size, PTX syntax only supports v2/v4. Thus, we can't use
996   // vectorized loads/stores with the actual element type for i8/i16 as that
997   // would require v8/v16 variants that do not exist.
998   // In order to load/store such vectors efficiently, in Type Legalization
999   // we split the vector into word-sized chunks (v2x16/v4i8). These chunks
1000   // are then lowered to PTX as vectors of b32.
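  // For example (illustrative): a v8i16 load is legalized into an
  // NVPTXISD::LoadV4 whose elements are v2i16 chunks; since v2i16 is
  // "upsized" here, the chunks are loaded as four b32 values (ld.v4.b32) and
  // unpacked afterwards.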
1001   return Isv2x16VT(EltVT) || EltVT == MVT::v4i8;
1002 }
1003 
1004 bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
1005   MemSDNode *MemSD = cast<MemSDNode>(N);
1006   EVT LoadedVT = MemSD->getMemoryVT();
1007   if (!LoadedVT.isSimple())
1008     return false;
1009 
1010   // Address Space Setting
1011   unsigned int CodeAddrSpace = getCodeAddrSpace(MemSD);
1012   if (canLowerToLDG(MemSD, *Subtarget, CodeAddrSpace, MF)) {
1013     return tryLDGLDU(N);
1014   }
1015   unsigned int PointerSize =
1016       CurDAG->getDataLayout().getPointerSizeInBits(MemSD->getAddressSpace());
1017 
1018   SDLoc DL(N);
1019   SDValue Chain = N->getOperand(0);
1020   auto [Ordering, Scope] = insertMemoryInstructionFence(DL, Chain, MemSD);
1021 
1022   // Vector Setting
1023   MVT SimpleVT = LoadedVT.getSimpleVT();
1024 
1025   // Type Setting: fromType + fromTypeWidth
1026   //
1027   // Signed   : ISD::SEXTLOAD
1028   // Unsigned : ISD::ZEXTLOAD, ISD::NON_EXTLOAD or ISD::EXTLOAD and the
1029   //            type is integer
1030   // Float    : ISD::NON_EXTLOAD or ISD::EXTLOAD and the type is float
1031   MVT ScalarVT = SimpleVT.getScalarType();
1032   // Read at least 8 bits (predicates are stored as 8-bit values)
1033   unsigned FromTypeWidth = std::max(8U, (unsigned)ScalarVT.getSizeInBits());
1034   unsigned int FromType;
1035   // The last operand holds the original LoadSDNode::getExtensionType() value
1036   unsigned ExtensionType = cast<ConstantSDNode>(
1037       N->getOperand(N->getNumOperands() - 1))->getZExtValue();
1038   if (ExtensionType == ISD::SEXTLOAD)
1039     FromType = NVPTX::PTXLdStInstCode::Signed;
1040   else
1041     FromType = getLdStRegType(ScalarVT);
1042 
1043   unsigned VecType;
1044 
1045   switch (N->getOpcode()) {
1046   case NVPTXISD::LoadV2:
1047     VecType = NVPTX::PTXLdStInstCode::V2;
1048     break;
1049   case NVPTXISD::LoadV4:
1050     VecType = NVPTX::PTXLdStInstCode::V4;
1051     break;
1052   default:
1053     return false;
1054   }
1055 
1056   EVT EltVT = N->getValueType(0);
1057 
1058   if (isVectorElementTypeUpsized(EltVT)) {
1059     EltVT = MVT::i32;
1060     FromType = NVPTX::PTXLdStInstCode::Untyped;
1061     FromTypeWidth = 32;
1062   }
1063 
1064   SDValue Op1 = N->getOperand(1);
1065   SDValue Addr, Offset, Base;
1066   std::optional<unsigned> Opcode;
1067   SDNode *LD;
1068 
1069   SmallVector<SDValue, 12> Ops({getI32Imm(Ordering, DL), getI32Imm(Scope, DL),
1070                                 getI32Imm(CodeAddrSpace, DL),
1071                                 getI32Imm(VecType, DL), getI32Imm(FromType, DL),
1072                                 getI32Imm(FromTypeWidth, DL)});
1073 
1074   if (SelectDirectAddr(Op1, Addr)) {
1075     switch (N->getOpcode()) {
1076     default:
1077       return false;
1078     case NVPTXISD::LoadV2:
1079       Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1080                                NVPTX::LDV_i8_v2_avar, NVPTX::LDV_i16_v2_avar,
1081                                NVPTX::LDV_i32_v2_avar, NVPTX::LDV_i64_v2_avar,
1082                                NVPTX::LDV_f32_v2_avar, NVPTX::LDV_f64_v2_avar);
1083       break;
1084     case NVPTXISD::LoadV4:
1085       Opcode =
1086           pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v4_avar,
1087                           NVPTX::LDV_i16_v4_avar, NVPTX::LDV_i32_v4_avar,
1088                           std::nullopt, NVPTX::LDV_f32_v4_avar, std::nullopt);
1089       break;
1090     }
1091     if (!Opcode)
1092       return false;
1093     Ops.append({Addr, Chain});
1094   } else if (PointerSize == 64
1095                  ? SelectADDRsi64(Op1.getNode(), Op1, Base, Offset)
1096                  : SelectADDRsi(Op1.getNode(), Op1, Base, Offset)) {
1097     switch (N->getOpcode()) {
1098     default:
1099       return false;
1100     case NVPTXISD::LoadV2:
1101       Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1102                                NVPTX::LDV_i8_v2_asi, NVPTX::LDV_i16_v2_asi,
1103                                NVPTX::LDV_i32_v2_asi, NVPTX::LDV_i64_v2_asi,
1104                                NVPTX::LDV_f32_v2_asi, NVPTX::LDV_f64_v2_asi);
1105       break;
1106     case NVPTXISD::LoadV4:
1107       Opcode =
1108           pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v4_asi,
1109                           NVPTX::LDV_i16_v4_asi, NVPTX::LDV_i32_v4_asi,
1110                           std::nullopt, NVPTX::LDV_f32_v4_asi, std::nullopt);
1111       break;
1112     }
1113     if (!Opcode)
1114       return false;
1115     Ops.append({Base, Offset, Chain});
1116   } else if (PointerSize == 64
1117                  ? SelectADDRri64(Op1.getNode(), Op1, Base, Offset)
1118                  : SelectADDRri(Op1.getNode(), Op1, Base, Offset)) {
1119     if (PointerSize == 64) {
1120       switch (N->getOpcode()) {
1121       default:
1122         return false;
1123       case NVPTXISD::LoadV2:
1124         Opcode =
1125             pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1126                             NVPTX::LDV_i8_v2_ari_64, NVPTX::LDV_i16_v2_ari_64,
1127                             NVPTX::LDV_i32_v2_ari_64, NVPTX::LDV_i64_v2_ari_64,
1128                             NVPTX::LDV_f32_v2_ari_64, NVPTX::LDV_f64_v2_ari_64);
1129         break;
1130       case NVPTXISD::LoadV4:
1131         Opcode = pickOpcodeForVT(
1132             EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v4_ari_64,
1133             NVPTX::LDV_i16_v4_ari_64, NVPTX::LDV_i32_v4_ari_64, std::nullopt,
1134             NVPTX::LDV_f32_v4_ari_64, std::nullopt);
1135         break;
1136       }
1137     } else {
1138       switch (N->getOpcode()) {
1139       default:
1140         return false;
1141       case NVPTXISD::LoadV2:
1142         Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1143                                  NVPTX::LDV_i8_v2_ari, NVPTX::LDV_i16_v2_ari,
1144                                  NVPTX::LDV_i32_v2_ari, NVPTX::LDV_i64_v2_ari,
1145                                  NVPTX::LDV_f32_v2_ari, NVPTX::LDV_f64_v2_ari);
1146         break;
1147       case NVPTXISD::LoadV4:
1148         Opcode =
1149             pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v4_ari,
1150                             NVPTX::LDV_i16_v4_ari, NVPTX::LDV_i32_v4_ari,
1151                             std::nullopt, NVPTX::LDV_f32_v4_ari, std::nullopt);
1152         break;
1153       }
1154     }
1155     if (!Opcode)
1156       return false;
1157     Ops.append({Base, Offset, Chain});
1158   } else {
1159     if (PointerSize == 64) {
1160       switch (N->getOpcode()) {
1161       default:
1162         return false;
1163       case NVPTXISD::LoadV2:
1164         Opcode = pickOpcodeForVT(
1165             EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v2_areg_64,
1166             NVPTX::LDV_i16_v2_areg_64, NVPTX::LDV_i32_v2_areg_64,
1167             NVPTX::LDV_i64_v2_areg_64, NVPTX::LDV_f32_v2_areg_64,
1168             NVPTX::LDV_f64_v2_areg_64);
1169         break;
1170       case NVPTXISD::LoadV4:
1171         Opcode = pickOpcodeForVT(
1172             EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v4_areg_64,
1173             NVPTX::LDV_i16_v4_areg_64, NVPTX::LDV_i32_v4_areg_64, std::nullopt,
1174             NVPTX::LDV_f32_v4_areg_64, std::nullopt);
1175         break;
1176       }
1177     } else {
1178       switch (N->getOpcode()) {
1179       default:
1180         return false;
1181       case NVPTXISD::LoadV2:
1182         Opcode =
1183             pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v2_areg,
1184                             NVPTX::LDV_i16_v2_areg, NVPTX::LDV_i32_v2_areg,
1185                             NVPTX::LDV_i64_v2_areg, NVPTX::LDV_f32_v2_areg,
1186                             NVPTX::LDV_f64_v2_areg);
1187         break;
1188       case NVPTXISD::LoadV4:
1189         Opcode =
1190             pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v4_areg,
1191                             NVPTX::LDV_i16_v4_areg, NVPTX::LDV_i32_v4_areg,
1192                             std::nullopt, NVPTX::LDV_f32_v4_areg, std::nullopt);
1193         break;
1194       }
1195     }
1196     if (!Opcode)
1197       return false;
1198     Ops.append({Op1, Chain});
1199   }
1200   LD = CurDAG->getMachineNode(*Opcode, DL, N->getVTList(), Ops);
1201 
1202   MachineMemOperand *MemRef = cast<MemSDNode>(N)->getMemOperand();
1203   CurDAG->setNodeMemRefs(cast<MachineSDNode>(LD), {MemRef});
1204 
1205   ReplaceNode(N, LD);
1206   return true;
1207 }
1208 
1209 bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
1210   auto *Mem = cast<MemSDNode>(N);
1211 
1212   // If this is an LDG intrinsic, the address is the third operand. If it is an
1213   // LDG/LDU SD node (from custom vector handling), then it is the second operand.
1214   SDValue Op1 = N->getOperand(N->getOpcode() == ISD::INTRINSIC_W_CHAIN ? 2 : 1);
1215 
1216   EVT OrigType = N->getValueType(0);
1217   EVT EltVT = Mem->getMemoryVT();
1218   unsigned NumElts = 1;
1219   if (EltVT.isVector()) {
1220     NumElts = EltVT.getVectorNumElements();
1221     EltVT = EltVT.getVectorElementType();
1222     // Vectors of 8-/16-bit elements are loaded/stored as multiples of v4i8/v2x16
1223     // sub-vectors.
1224     if ((EltVT == MVT::f16 && OrigType == MVT::v2f16) ||
1225         (EltVT == MVT::bf16 && OrigType == MVT::v2bf16) ||
1226         (EltVT == MVT::i16 && OrigType == MVT::v2i16) ||
1227         (EltVT == MVT::i8 && OrigType == MVT::v4i8)) {
1228       assert(NumElts % OrigType.getVectorNumElements() == 0 &&
1229              "NumElts must be divisible by the number of elts in subvectors");
1230       EltVT = OrigType;
1231       NumElts /= OrigType.getVectorNumElements();
1232     }
1233   }
1234 
1235   // Build the "promoted" result VTList for the load. If we are really loading
1236   // i8s, then the return type will be promoted to i16 since we do not expose
1237   // 8-bit registers in NVPTX.
1238   EVT NodeVT = (EltVT == MVT::i8) ? MVT::i16 : EltVT;
1239   SmallVector<EVT, 5> InstVTs;
1240   for (unsigned i = 0; i != NumElts; ++i) {
1241     InstVTs.push_back(NodeVT);
1242   }
1243   InstVTs.push_back(MVT::Other);
1244   SDVTList InstVTList = CurDAG->getVTList(InstVTs);
1245   SDValue Chain = N->getOperand(0);
1246 
1247   std::optional<unsigned> Opcode;
1248   SDLoc DL(N);
1249   SDNode *LD;
1250   SDValue Base, Offset, Addr;
1251 
1252   if (SelectDirectAddr(Op1, Addr)) {
1253     switch (N->getOpcode()) {
1254     default:
1255       return false;
1256     case ISD::LOAD:
1257       Opcode = pickOpcodeForVT(
1258           EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_GLOBAL_i8avar,
1259           NVPTX::INT_PTX_LDG_GLOBAL_i16avar, NVPTX::INT_PTX_LDG_GLOBAL_i32avar,
1260           NVPTX::INT_PTX_LDG_GLOBAL_i64avar, NVPTX::INT_PTX_LDG_GLOBAL_f32avar,
1261           NVPTX::INT_PTX_LDG_GLOBAL_f64avar);
1262       break;
1263     case ISD::INTRINSIC_W_CHAIN:
1264       Opcode = pickOpcodeForVT(
1265           EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_GLOBAL_i8avar,
1266           NVPTX::INT_PTX_LDU_GLOBAL_i16avar, NVPTX::INT_PTX_LDU_GLOBAL_i32avar,
1267           NVPTX::INT_PTX_LDU_GLOBAL_i64avar, NVPTX::INT_PTX_LDU_GLOBAL_f32avar,
1268           NVPTX::INT_PTX_LDU_GLOBAL_f64avar);
1269       break;
1270     case NVPTXISD::LoadV2:
1271       Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1272                                NVPTX::INT_PTX_LDG_G_v2i8_ELE_avar,
1273                                NVPTX::INT_PTX_LDG_G_v2i16_ELE_avar,
1274                                NVPTX::INT_PTX_LDG_G_v2i32_ELE_avar,
1275                                NVPTX::INT_PTX_LDG_G_v2i64_ELE_avar,
1276                                NVPTX::INT_PTX_LDG_G_v2f32_ELE_avar,
1277                                NVPTX::INT_PTX_LDG_G_v2f64_ELE_avar);
1278       break;
1279     case NVPTXISD::LDUV2:
1280       Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1281                                NVPTX::INT_PTX_LDU_G_v2i8_ELE_avar,
1282                                NVPTX::INT_PTX_LDU_G_v2i16_ELE_avar,
1283                                NVPTX::INT_PTX_LDU_G_v2i32_ELE_avar,
1284                                NVPTX::INT_PTX_LDU_G_v2i64_ELE_avar,
1285                                NVPTX::INT_PTX_LDU_G_v2f32_ELE_avar,
1286                                NVPTX::INT_PTX_LDU_G_v2f64_ELE_avar);
1287       break;
1288     case NVPTXISD::LoadV4:
1289       Opcode = pickOpcodeForVT(
1290           EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v4i8_ELE_avar,
1291           NVPTX::INT_PTX_LDG_G_v4i16_ELE_avar,
1292           NVPTX::INT_PTX_LDG_G_v4i32_ELE_avar, std::nullopt,
1293           NVPTX::INT_PTX_LDG_G_v4f32_ELE_avar, std::nullopt);
1294       break;
1295     case NVPTXISD::LDUV4:
1296       Opcode = pickOpcodeForVT(
1297           EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v4i8_ELE_avar,
1298           NVPTX::INT_PTX_LDU_G_v4i16_ELE_avar,
1299           NVPTX::INT_PTX_LDU_G_v4i32_ELE_avar, std::nullopt,
1300           NVPTX::INT_PTX_LDU_G_v4f32_ELE_avar, std::nullopt);
1301       break;
1302     }
1303     if (!Opcode)
1304       return false;
1305     SDValue Ops[] = { Addr, Chain };
1306     LD = CurDAG->getMachineNode(*Opcode, DL, InstVTList, Ops);
1307   } else if (TM.is64Bit() ? SelectADDRri64(Op1.getNode(), Op1, Base, Offset)
1308                           : SelectADDRri(Op1.getNode(), Op1, Base, Offset)) {
1309     if (TM.is64Bit()) {
1310       switch (N->getOpcode()) {
1311       default:
1312         return false;
1313       case ISD::LOAD:
1314         Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1315                                  NVPTX::INT_PTX_LDG_GLOBAL_i8ari64,
1316                                  NVPTX::INT_PTX_LDG_GLOBAL_i16ari64,
1317                                  NVPTX::INT_PTX_LDG_GLOBAL_i32ari64,
1318                                  NVPTX::INT_PTX_LDG_GLOBAL_i64ari64,
1319                                  NVPTX::INT_PTX_LDG_GLOBAL_f32ari64,
1320                                  NVPTX::INT_PTX_LDG_GLOBAL_f64ari64);
1321         break;
1322       case ISD::INTRINSIC_W_CHAIN:
1323         Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1324                                  NVPTX::INT_PTX_LDU_GLOBAL_i8ari64,
1325                                  NVPTX::INT_PTX_LDU_GLOBAL_i16ari64,
1326                                  NVPTX::INT_PTX_LDU_GLOBAL_i32ari64,
1327                                  NVPTX::INT_PTX_LDU_GLOBAL_i64ari64,
1328                                  NVPTX::INT_PTX_LDU_GLOBAL_f32ari64,
1329                                  NVPTX::INT_PTX_LDU_GLOBAL_f64ari64);
1330         break;
1331       case NVPTXISD::LoadV2:
1332         Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1333                                      NVPTX::INT_PTX_LDG_G_v2i8_ELE_ari64,
1334                                      NVPTX::INT_PTX_LDG_G_v2i16_ELE_ari64,
1335                                      NVPTX::INT_PTX_LDG_G_v2i32_ELE_ari64,
1336                                      NVPTX::INT_PTX_LDG_G_v2i64_ELE_ari64,
1337                                      NVPTX::INT_PTX_LDG_G_v2f32_ELE_ari64,
1338                                      NVPTX::INT_PTX_LDG_G_v2f64_ELE_ari64);
1339         break;
1340       case NVPTXISD::LDUV2:
1341         Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1342                                      NVPTX::INT_PTX_LDU_G_v2i8_ELE_ari64,
1343                                      NVPTX::INT_PTX_LDU_G_v2i16_ELE_ari64,
1344                                      NVPTX::INT_PTX_LDU_G_v2i32_ELE_ari64,
1345                                      NVPTX::INT_PTX_LDU_G_v2i64_ELE_ari64,
1346                                      NVPTX::INT_PTX_LDU_G_v2f32_ELE_ari64,
1347                                      NVPTX::INT_PTX_LDU_G_v2f64_ELE_ari64);
1348         break;
1349       case NVPTXISD::LoadV4:
1350         Opcode = pickOpcodeForVT(
1351             EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v4i8_ELE_ari64,
1352             NVPTX::INT_PTX_LDG_G_v4i16_ELE_ari64,
1353             NVPTX::INT_PTX_LDG_G_v4i32_ELE_ari64, std::nullopt,
1354             NVPTX::INT_PTX_LDG_G_v4f32_ELE_ari64, std::nullopt);
1355         break;
1356       case NVPTXISD::LDUV4:
1357         Opcode = pickOpcodeForVT(
1358             EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v4i8_ELE_ari64,
1359             NVPTX::INT_PTX_LDU_G_v4i16_ELE_ari64,
1360             NVPTX::INT_PTX_LDU_G_v4i32_ELE_ari64, std::nullopt,
1361             NVPTX::INT_PTX_LDU_G_v4f32_ELE_ari64, std::nullopt);
1362         break;
1363       }
1364     } else {
1365       switch (N->getOpcode()) {
1366       default:
1367         return false;
1368       case ISD::LOAD:
1369         Opcode = pickOpcodeForVT(
1370             EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_GLOBAL_i8ari,
1371             NVPTX::INT_PTX_LDG_GLOBAL_i16ari, NVPTX::INT_PTX_LDG_GLOBAL_i32ari,
1372             NVPTX::INT_PTX_LDG_GLOBAL_i64ari, NVPTX::INT_PTX_LDG_GLOBAL_f32ari,
1373             NVPTX::INT_PTX_LDG_GLOBAL_f64ari);
1374         break;
1375       case ISD::INTRINSIC_W_CHAIN:
1376         Opcode = pickOpcodeForVT(
1377             EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_GLOBAL_i8ari,
1378             NVPTX::INT_PTX_LDU_GLOBAL_i16ari, NVPTX::INT_PTX_LDU_GLOBAL_i32ari,
1379             NVPTX::INT_PTX_LDU_GLOBAL_i64ari, NVPTX::INT_PTX_LDU_GLOBAL_f32ari,
1380             NVPTX::INT_PTX_LDU_GLOBAL_f64ari);
1381         break;
1382       case NVPTXISD::LoadV2:
1383         Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1384                                  NVPTX::INT_PTX_LDG_G_v2i8_ELE_ari32,
1385                                  NVPTX::INT_PTX_LDG_G_v2i16_ELE_ari32,
1386                                  NVPTX::INT_PTX_LDG_G_v2i32_ELE_ari32,
1387                                  NVPTX::INT_PTX_LDG_G_v2i64_ELE_ari32,
1388                                  NVPTX::INT_PTX_LDG_G_v2f32_ELE_ari32,
1389                                  NVPTX::INT_PTX_LDG_G_v2f64_ELE_ari32);
1390         break;
1391       case NVPTXISD::LDUV2:
1392         Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1393                                  NVPTX::INT_PTX_LDU_G_v2i8_ELE_ari32,
1394                                  NVPTX::INT_PTX_LDU_G_v2i16_ELE_ari32,
1395                                  NVPTX::INT_PTX_LDU_G_v2i32_ELE_ari32,
1396                                  NVPTX::INT_PTX_LDU_G_v2i64_ELE_ari32,
1397                                  NVPTX::INT_PTX_LDU_G_v2f32_ELE_ari32,
1398                                  NVPTX::INT_PTX_LDU_G_v2f64_ELE_ari32);
1399         break;
1400       case NVPTXISD::LoadV4:
1401         Opcode = pickOpcodeForVT(
1402             EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v4i8_ELE_ari32,
1403             NVPTX::INT_PTX_LDG_G_v4i16_ELE_ari32,
1404             NVPTX::INT_PTX_LDG_G_v4i32_ELE_ari32, std::nullopt,
1405             NVPTX::INT_PTX_LDG_G_v4f32_ELE_ari32, std::nullopt);
1406         break;
1407       case NVPTXISD::LDUV4:
1408         Opcode = pickOpcodeForVT(
1409             EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v4i8_ELE_ari32,
1410             NVPTX::INT_PTX_LDU_G_v4i16_ELE_ari32,
1411             NVPTX::INT_PTX_LDU_G_v4i32_ELE_ari32, std::nullopt,
1412             NVPTX::INT_PTX_LDU_G_v4f32_ELE_ari32, std::nullopt);
1413         break;
1414       }
1415     }
1416     if (!Opcode)
1417       return false;
1418     SDValue Ops[] = {Base, Offset, Chain};
1419     LD = CurDAG->getMachineNode(*Opcode, DL, InstVTList, Ops);
1420   } else {
1421     if (TM.is64Bit()) {
1422       switch (N->getOpcode()) {
1423       default:
1424         return false;
1425       case ISD::LOAD:
1426         Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1427                                  NVPTX::INT_PTX_LDG_GLOBAL_i8areg64,
1428                                  NVPTX::INT_PTX_LDG_GLOBAL_i16areg64,
1429                                  NVPTX::INT_PTX_LDG_GLOBAL_i32areg64,
1430                                  NVPTX::INT_PTX_LDG_GLOBAL_i64areg64,
1431                                  NVPTX::INT_PTX_LDG_GLOBAL_f32areg64,
1432                                  NVPTX::INT_PTX_LDG_GLOBAL_f64areg64);
1433         break;
1434       case ISD::INTRINSIC_W_CHAIN:
1435         Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1436                                  NVPTX::INT_PTX_LDU_GLOBAL_i8areg64,
1437                                  NVPTX::INT_PTX_LDU_GLOBAL_i16areg64,
1438                                  NVPTX::INT_PTX_LDU_GLOBAL_i32areg64,
1439                                  NVPTX::INT_PTX_LDU_GLOBAL_i64areg64,
1440                                  NVPTX::INT_PTX_LDU_GLOBAL_f32areg64,
1441                                  NVPTX::INT_PTX_LDU_GLOBAL_f64areg64);
1442         break;
1443       case NVPTXISD::LoadV2:
1444         Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1445                                  NVPTX::INT_PTX_LDG_G_v2i8_ELE_areg64,
1446                                  NVPTX::INT_PTX_LDG_G_v2i16_ELE_areg64,
1447                                  NVPTX::INT_PTX_LDG_G_v2i32_ELE_areg64,
1448                                  NVPTX::INT_PTX_LDG_G_v2i64_ELE_areg64,
1449                                  NVPTX::INT_PTX_LDG_G_v2f32_ELE_areg64,
1450                                  NVPTX::INT_PTX_LDG_G_v2f64_ELE_areg64);
1451         break;
1452       case NVPTXISD::LDUV2:
1453         Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1454                                  NVPTX::INT_PTX_LDU_G_v2i8_ELE_areg64,
1455                                  NVPTX::INT_PTX_LDU_G_v2i16_ELE_areg64,
1456                                  NVPTX::INT_PTX_LDU_G_v2i32_ELE_areg64,
1457                                  NVPTX::INT_PTX_LDU_G_v2i64_ELE_areg64,
1458                                  NVPTX::INT_PTX_LDU_G_v2f32_ELE_areg64,
1459                                  NVPTX::INT_PTX_LDU_G_v2f64_ELE_areg64);
1460         break;
1461       case NVPTXISD::LoadV4:
1462         Opcode = pickOpcodeForVT(
1463             EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v4i8_ELE_areg64,
1464             NVPTX::INT_PTX_LDG_G_v4i16_ELE_areg64,
1465             NVPTX::INT_PTX_LDG_G_v4i32_ELE_areg64, std::nullopt,
1466             NVPTX::INT_PTX_LDG_G_v4f32_ELE_areg64, std::nullopt);
1467         break;
1468       case NVPTXISD::LDUV4:
1469         Opcode = pickOpcodeForVT(
1470             EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v4i8_ELE_areg64,
1471             NVPTX::INT_PTX_LDU_G_v4i16_ELE_areg64,
1472             NVPTX::INT_PTX_LDU_G_v4i32_ELE_areg64, std::nullopt,
1473             NVPTX::INT_PTX_LDU_G_v4f32_ELE_areg64, std::nullopt);
1474         break;
1475       }
1476     } else {
1477       switch (N->getOpcode()) {
1478       default:
1479         return false;
1480       case ISD::LOAD:
1481         Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1482                                  NVPTX::INT_PTX_LDG_GLOBAL_i8areg,
1483                                  NVPTX::INT_PTX_LDG_GLOBAL_i16areg,
1484                                  NVPTX::INT_PTX_LDG_GLOBAL_i32areg,
1485                                  NVPTX::INT_PTX_LDG_GLOBAL_i64areg,
1486                                  NVPTX::INT_PTX_LDG_GLOBAL_f32areg,
1487                                  NVPTX::INT_PTX_LDG_GLOBAL_f64areg);
1488         break;
1489       case ISD::INTRINSIC_W_CHAIN:
1490         Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1491                                  NVPTX::INT_PTX_LDU_GLOBAL_i8areg,
1492                                  NVPTX::INT_PTX_LDU_GLOBAL_i16areg,
1493                                  NVPTX::INT_PTX_LDU_GLOBAL_i32areg,
1494                                  NVPTX::INT_PTX_LDU_GLOBAL_i64areg,
1495                                  NVPTX::INT_PTX_LDU_GLOBAL_f32areg,
1496                                  NVPTX::INT_PTX_LDU_GLOBAL_f64areg);
1497         break;
1498       case NVPTXISD::LoadV2:
1499         Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1500                                  NVPTX::INT_PTX_LDG_G_v2i8_ELE_areg32,
1501                                  NVPTX::INT_PTX_LDG_G_v2i16_ELE_areg32,
1502                                  NVPTX::INT_PTX_LDG_G_v2i32_ELE_areg32,
1503                                  NVPTX::INT_PTX_LDG_G_v2i64_ELE_areg32,
1504                                  NVPTX::INT_PTX_LDG_G_v2f32_ELE_areg32,
1505                                  NVPTX::INT_PTX_LDG_G_v2f64_ELE_areg32);
1506         break;
1507       case NVPTXISD::LDUV2:
1508         Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1509                                  NVPTX::INT_PTX_LDU_G_v2i8_ELE_areg32,
1510                                  NVPTX::INT_PTX_LDU_G_v2i16_ELE_areg32,
1511                                  NVPTX::INT_PTX_LDU_G_v2i32_ELE_areg32,
1512                                  NVPTX::INT_PTX_LDU_G_v2i64_ELE_areg32,
1513                                  NVPTX::INT_PTX_LDU_G_v2f32_ELE_areg32,
1514                                  NVPTX::INT_PTX_LDU_G_v2f64_ELE_areg32);
1515         break;
1516       case NVPTXISD::LoadV4:
1517         Opcode = pickOpcodeForVT(
1518             EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v4i8_ELE_areg32,
1519             NVPTX::INT_PTX_LDG_G_v4i16_ELE_areg32,
1520             NVPTX::INT_PTX_LDG_G_v4i32_ELE_areg32, std::nullopt,
1521             NVPTX::INT_PTX_LDG_G_v4f32_ELE_areg32, std::nullopt);
1522         break;
1523       case NVPTXISD::LDUV4:
1524         Opcode = pickOpcodeForVT(
1525             EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v4i8_ELE_areg32,
1526             NVPTX::INT_PTX_LDU_G_v4i16_ELE_areg32,
1527             NVPTX::INT_PTX_LDU_G_v4i32_ELE_areg32, std::nullopt,
1528             NVPTX::INT_PTX_LDU_G_v4f32_ELE_areg32, std::nullopt);
1529         break;
1530       }
1531     }
1532     if (!Opcode)
1533       return false;
1534     SDValue Ops[] = {Op1, Chain};
1535     LD = CurDAG->getMachineNode(*Opcode, DL, InstVTList, Ops);
1536   }
1537 
1538   // For automatic generation of LDG (through SelectLoad[Vector], not the
1539   // intrinsics), we may have an extending load like:
1540   //
1541   //   i32,ch = load<LD1[%data1(addrspace=1)], zext from i8> t0, t7, undef:i64
1542   //
1543   // In this case, the matching logic above will select a load for the original
1544   // memory type (here, i8) and our types will not match (the node needs to
1545   // return an i32). Our LDG/LDU nodes do not support the concept of
1546   // sign-/zero-extension, so emulate it here by adding an explicit CVT
1547   // instruction. Ptxas should clean up any redundancies here.
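  // For example, a zero-extending i8-to-i32 load keeps the selected i8 LDG
  // and feeds its result through CVT_u32_u8 (see GetConvertOpcode) so that
  // users receive the i32 value they expect.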
1548 
1549   LoadSDNode *LdNode = dyn_cast<LoadSDNode>(N);
1550 
1551   if (OrigType != EltVT &&
1552       (LdNode || (OrigType.isFloatingPoint() && EltVT.isFloatingPoint()))) {
1553     // We have an extending-load. The instruction we selected operates on the
1554     // smaller type, but the SDNode we are replacing has the larger type. We
1555     // need to emit a CVT to make the types match.
1556     unsigned CvtOpc =
1557         GetConvertOpcode(OrigType.getSimpleVT(), EltVT.getSimpleVT(), LdNode);
1558 
1559     // For each output value, apply the manual sign/zero-extension and make sure
1560     // all users of the load go through that CVT.
1561     for (unsigned i = 0; i != NumElts; ++i) {
1562       SDValue Res(LD, i);
1563       SDValue OrigVal(N, i);
1564 
1565       SDNode *CvtNode =
1566         CurDAG->getMachineNode(CvtOpc, DL, OrigType, Res,
1567                                CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE,
1568                                                          DL, MVT::i32));
1569       ReplaceUses(OrigVal, SDValue(CvtNode, 0));
1570     }
1571   }
1572 
1573   ReplaceNode(N, LD);
1574   return true;
1575 }
1576 
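// Select a scalar store (plain or atomic) to the matching NVPTX ST_* opcode,
// using whichever addressing mode (direct, symbol+offset, register+offset, or
// register) matches the base pointer.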
1577 bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
1578   MemSDNode *ST = cast<MemSDNode>(N);
1579   assert(ST->writeMem() && "Expected store");
1580   StoreSDNode *PlainStore = dyn_cast<StoreSDNode>(N);
1581   AtomicSDNode *AtomicStore = dyn_cast<AtomicSDNode>(N);
1582   assert((PlainStore || AtomicStore) && "Expected store");
1583 
1584   // Indexed (pre/post increment/decrement) stores are not supported.
1585   if (PlainStore && PlainStore->isIndexed())
1586     return false;
1587 
1588   EVT StoreVT = ST->getMemoryVT();
1589   if (!StoreVT.isSimple())
1590     return false;
1591 
1592   // Address Space Setting
1593   unsigned int CodeAddrSpace = getCodeAddrSpace(ST);
1594   unsigned int PointerSize =
1595       CurDAG->getDataLayout().getPointerSizeInBits(ST->getAddressSpace());
1596 
1597   SDLoc DL(N);
1598   SDValue Chain = ST->getChain();
1599   auto [Ordering, Scope] = insertMemoryInstructionFence(DL, Chain, ST);
1600 
1601   // Vector Setting
1602   MVT SimpleVT = StoreVT.getSimpleVT();
1603   unsigned VecType = NVPTX::PTXLdStInstCode::Scalar;
1604 
1605   // Type Setting: toType + toTypeWidth
1606   // - for integer type, always use 'u'
1607   MVT ScalarVT = SimpleVT.getScalarType();
1608   unsigned ToTypeWidth = ScalarVT.getSizeInBits();
1609   if (SimpleVT.isVector()) {
1610     assert((Isv2x16VT(StoreVT) || StoreVT == MVT::v4i8) &&
1611            "Unexpected vector type");
1612     // v2x16 is stored using st.b32
1613     ToTypeWidth = 32;
1614   }
1615 
1616   unsigned int ToType = getLdStRegType(ScalarVT);
1617 
1618   // Create the machine instruction DAG
1619   SDValue Value = PlainStore ? PlainStore->getValue() : AtomicStore->getVal();
1620   SDValue BasePtr = ST->getBasePtr();
1621   SDValue Addr;
1622   SDValue Offset, Base;
1623   std::optional<unsigned> Opcode;
1624   MVT::SimpleValueType SourceVT =
1625       Value.getNode()->getSimpleValueType(0).SimpleTy;
1626 
1627   SmallVector<SDValue, 12> Ops(
1628       {Value, getI32Imm(Ordering, DL), getI32Imm(Scope, DL),
1629        getI32Imm(CodeAddrSpace, DL), getI32Imm(VecType, DL),
1630        getI32Imm(ToType, DL), getI32Imm(ToTypeWidth, DL)});
1631 
1632   if (SelectDirectAddr(BasePtr, Addr)) {
1633     Opcode = pickOpcodeForVT(SourceVT, NVPTX::ST_i8_avar, NVPTX::ST_i16_avar,
1634                              NVPTX::ST_i32_avar, NVPTX::ST_i64_avar,
1635                              NVPTX::ST_f32_avar, NVPTX::ST_f64_avar);
1636     if (!Opcode)
1637       return false;
1638     Ops.append({Addr, Chain});
1639   } else if (PointerSize == 64
1640                  ? SelectADDRsi64(BasePtr.getNode(), BasePtr, Base, Offset)
1641                  : SelectADDRsi(BasePtr.getNode(), BasePtr, Base, Offset)) {
1642     Opcode = pickOpcodeForVT(SourceVT, NVPTX::ST_i8_asi, NVPTX::ST_i16_asi,
1643                              NVPTX::ST_i32_asi, NVPTX::ST_i64_asi,
1644                              NVPTX::ST_f32_asi, NVPTX::ST_f64_asi);
1645     if (!Opcode)
1646       return false;
1647     Ops.append({Base, Offset, Chain});
1648   } else if (PointerSize == 64
1649                  ? SelectADDRri64(BasePtr.getNode(), BasePtr, Base, Offset)
1650                  : SelectADDRri(BasePtr.getNode(), BasePtr, Base, Offset)) {
1651     if (PointerSize == 64)
1652       Opcode =
1653           pickOpcodeForVT(SourceVT, NVPTX::ST_i8_ari_64, NVPTX::ST_i16_ari_64,
1654                           NVPTX::ST_i32_ari_64, NVPTX::ST_i64_ari_64,
1655                           NVPTX::ST_f32_ari_64, NVPTX::ST_f64_ari_64);
1656     else
1657       Opcode = pickOpcodeForVT(SourceVT, NVPTX::ST_i8_ari, NVPTX::ST_i16_ari,
1658                                NVPTX::ST_i32_ari, NVPTX::ST_i64_ari,
1659                                NVPTX::ST_f32_ari, NVPTX::ST_f64_ari);
1660     if (!Opcode)
1661       return false;
1662     Ops.append({Base, Offset, Chain});
1663   } else {
1664     if (PointerSize == 64)
1665       Opcode =
1666           pickOpcodeForVT(SourceVT, NVPTX::ST_i8_areg_64, NVPTX::ST_i16_areg_64,
1667                           NVPTX::ST_i32_areg_64, NVPTX::ST_i64_areg_64,
1668                           NVPTX::ST_f32_areg_64, NVPTX::ST_f64_areg_64);
1669     else
1670       Opcode = pickOpcodeForVT(SourceVT, NVPTX::ST_i8_areg, NVPTX::ST_i16_areg,
1671                                NVPTX::ST_i32_areg, NVPTX::ST_i64_areg,
1672                                NVPTX::ST_f32_areg, NVPTX::ST_f64_areg);
1673     if (!Opcode)
1674       return false;
1675     Ops.append({BasePtr, Chain});
1676   }
1677 
1678   SDNode *NVPTXST = CurDAG->getMachineNode(*Opcode, DL, MVT::Other, Ops);
1679 
1680   if (!NVPTXST)
1681     return false;
1682 
1683   MachineMemOperand *MemRef = cast<MemSDNode>(N)->getMemOperand();
1684   CurDAG->setNodeMemRefs(cast<MachineSDNode>(NVPTXST), {MemRef});
1685   ReplaceNode(N, NVPTXST);
1686   return true;
1687 }
1688 
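// Select NVPTXISD::StoreV2/StoreV4 nodes to the vector STV_* opcodes, using
// the same addressing-mode choices as scalar stores.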
1689 bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
1690   SDValue Op1 = N->getOperand(1);
1691   SDValue Addr, Offset, Base;
1692   std::optional<unsigned> Opcode;
1693   SDNode *ST;
1694   EVT EltVT = Op1.getValueType();
1695   MemSDNode *MemSD = cast<MemSDNode>(N);
1696   EVT StoreVT = MemSD->getMemoryVT();
1697 
1698   // Address Space Setting
1699   unsigned CodeAddrSpace = getCodeAddrSpace(MemSD);
1700   if (CodeAddrSpace == NVPTX::AddressSpace::Const) {
1701     report_fatal_error("Cannot store to pointer that points to constant "
1702                        "memory space");
1703   }
1704   unsigned int PointerSize =
1705       CurDAG->getDataLayout().getPointerSizeInBits(MemSD->getAddressSpace());
1706 
1707   SDLoc DL(N);
1708   SDValue Chain = N->getOperand(0);
1709   auto [Ordering, Scope] = insertMemoryInstructionFence(DL, Chain, MemSD);
1710 
1711   // Type Setting: toType + toTypeWidth
1712   // - for integer type, always use 'u'
1713   assert(StoreVT.isSimple() && "Store value is not simple");
1714   MVT ScalarVT = StoreVT.getSimpleVT().getScalarType();
1715   unsigned ToTypeWidth = ScalarVT.getSizeInBits();
1716   unsigned ToType = getLdStRegType(ScalarVT);
1717 
1718   SmallVector<SDValue, 12> Ops;
1719   SDValue N2;
1720   unsigned VecType;
1721 
1722   switch (N->getOpcode()) {
1723   case NVPTXISD::StoreV2:
1724     VecType = NVPTX::PTXLdStInstCode::V2;
1725     Ops.append({N->getOperand(1), N->getOperand(2)});
1726     N2 = N->getOperand(3);
1727     break;
1728   case NVPTXISD::StoreV4:
1729     VecType = NVPTX::PTXLdStInstCode::V4;
1730     Ops.append({N->getOperand(1), N->getOperand(2), N->getOperand(3),
1731                 N->getOperand(4)});
1732     N2 = N->getOperand(5);
1733     break;
1734   default:
1735     return false;
1736   }
1737 
1738   if (isVectorElementTypeUpsized(EltVT)) {
1739     EltVT = MVT::i32;
1740     ToType = NVPTX::PTXLdStInstCode::Untyped;
1741     ToTypeWidth = 32;
1742   }
1743 
1744   Ops.append({getI32Imm(Ordering, DL), getI32Imm(Scope, DL),
1745               getI32Imm(CodeAddrSpace, DL), getI32Imm(VecType, DL),
1746               getI32Imm(ToType, DL), getI32Imm(ToTypeWidth, DL)});
1747 
1748   if (SelectDirectAddr(N2, Addr)) {
1749     switch (N->getOpcode()) {
1750     default:
1751       return false;
1752     case NVPTXISD::StoreV2:
1753       Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1754                                NVPTX::STV_i8_v2_avar, NVPTX::STV_i16_v2_avar,
1755                                NVPTX::STV_i32_v2_avar, NVPTX::STV_i64_v2_avar,
1756                                NVPTX::STV_f32_v2_avar, NVPTX::STV_f64_v2_avar);
1757       break;
1758     case NVPTXISD::StoreV4:
1759       Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1760                                NVPTX::STV_i8_v4_avar, NVPTX::STV_i16_v4_avar,
1761                                NVPTX::STV_i32_v4_avar, std::nullopt,
1762                                NVPTX::STV_f32_v4_avar, std::nullopt);
1763       break;
1764     }
1765     Ops.push_back(Addr);
1766   } else if (PointerSize == 64 ? SelectADDRsi64(N2.getNode(), N2, Base, Offset)
1767                                : SelectADDRsi(N2.getNode(), N2, Base, Offset)) {
1768     switch (N->getOpcode()) {
1769     default:
1770       return false;
1771     case NVPTXISD::StoreV2:
1772       Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1773                                NVPTX::STV_i8_v2_asi, NVPTX::STV_i16_v2_asi,
1774                                NVPTX::STV_i32_v2_asi, NVPTX::STV_i64_v2_asi,
1775                                NVPTX::STV_f32_v2_asi, NVPTX::STV_f64_v2_asi);
1776       break;
1777     case NVPTXISD::StoreV4:
1778       Opcode =
1779           pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::STV_i8_v4_asi,
1780                           NVPTX::STV_i16_v4_asi, NVPTX::STV_i32_v4_asi,
1781                           std::nullopt, NVPTX::STV_f32_v4_asi, std::nullopt);
1782       break;
1783     }
1784     Ops.append({Base, Offset});
1785   } else if (PointerSize == 64 ? SelectADDRri64(N2.getNode(), N2, Base, Offset)
1786                                : SelectADDRri(N2.getNode(), N2, Base, Offset)) {
1787     if (PointerSize == 64) {
1788       switch (N->getOpcode()) {
1789       default:
1790         return false;
1791       case NVPTXISD::StoreV2:
1792         Opcode =
1793             pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1794                             NVPTX::STV_i8_v2_ari_64, NVPTX::STV_i16_v2_ari_64,
1795                             NVPTX::STV_i32_v2_ari_64, NVPTX::STV_i64_v2_ari_64,
1796                             NVPTX::STV_f32_v2_ari_64, NVPTX::STV_f64_v2_ari_64);
1797         break;
1798       case NVPTXISD::StoreV4:
1799         Opcode = pickOpcodeForVT(
1800             EltVT.getSimpleVT().SimpleTy, NVPTX::STV_i8_v4_ari_64,
1801             NVPTX::STV_i16_v4_ari_64, NVPTX::STV_i32_v4_ari_64, std::nullopt,
1802             NVPTX::STV_f32_v4_ari_64, std::nullopt);
1803         break;
1804       }
1805     } else {
1806       switch (N->getOpcode()) {
1807       default:
1808         return false;
1809       case NVPTXISD::StoreV2:
1810         Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1811                                  NVPTX::STV_i8_v2_ari, NVPTX::STV_i16_v2_ari,
1812                                  NVPTX::STV_i32_v2_ari, NVPTX::STV_i64_v2_ari,
1813                                  NVPTX::STV_f32_v2_ari, NVPTX::STV_f64_v2_ari);
1814         break;
1815       case NVPTXISD::StoreV4:
1816         Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1817                                  NVPTX::STV_i8_v4_ari, NVPTX::STV_i16_v4_ari,
1818                                  NVPTX::STV_i32_v4_ari, std::nullopt,
1819                                  NVPTX::STV_f32_v4_ari, std::nullopt);
1820         break;
1821       }
1822     }
1823     Ops.append({Base, Offset});
1824   } else {
1825     if (PointerSize == 64) {
1826       switch (N->getOpcode()) {
1827       default:
1828         return false;
1829       case NVPTXISD::StoreV2:
1830         Opcode = pickOpcodeForVT(
1831             EltVT.getSimpleVT().SimpleTy, NVPTX::STV_i8_v2_areg_64,
1832             NVPTX::STV_i16_v2_areg_64, NVPTX::STV_i32_v2_areg_64,
1833             NVPTX::STV_i64_v2_areg_64, NVPTX::STV_f32_v2_areg_64,
1834             NVPTX::STV_f64_v2_areg_64);
1835         break;
1836       case NVPTXISD::StoreV4:
1837         Opcode = pickOpcodeForVT(
1838             EltVT.getSimpleVT().SimpleTy, NVPTX::STV_i8_v4_areg_64,
1839             NVPTX::STV_i16_v4_areg_64, NVPTX::STV_i32_v4_areg_64, std::nullopt,
1840             NVPTX::STV_f32_v4_areg_64, std::nullopt);
1841         break;
1842       }
1843     } else {
1844       switch (N->getOpcode()) {
1845       default:
1846         return false;
1847       case NVPTXISD::StoreV2:
1848         Opcode =
1849             pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::STV_i8_v2_areg,
1850                             NVPTX::STV_i16_v2_areg, NVPTX::STV_i32_v2_areg,
1851                             NVPTX::STV_i64_v2_areg, NVPTX::STV_f32_v2_areg,
1852                             NVPTX::STV_f64_v2_areg);
1853         break;
1854       case NVPTXISD::StoreV4:
1855         Opcode =
1856             pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::STV_i8_v4_areg,
1857                             NVPTX::STV_i16_v4_areg, NVPTX::STV_i32_v4_areg,
1858                             std::nullopt, NVPTX::STV_f32_v4_areg, std::nullopt);
1859         break;
1860       }
1861     }
1862     Ops.push_back(N2);
1863   }
1864 
1865   if (!Opcode)
1866     return false;
1867 
1868   Ops.push_back(Chain);
1869 
1870   ST = CurDAG->getMachineNode(*Opcode, DL, MVT::Other, Ops);
1871 
1872   MachineMemOperand *MemRef = cast<MemSDNode>(N)->getMemOperand();
1873   CurDAG->setNodeMemRefs(cast<MachineSDNode>(ST), {MemRef});
1874 
1875   ReplaceNode(N, ST);
1876   return true;
1877 }
1878 
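// Select NVPTXISD::LoadParam{,V2,V4} nodes to LoadParamMem* opcodes that read
// one, two, or four values from the parameter space at a constant offset.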
1879 bool NVPTXDAGToDAGISel::tryLoadParam(SDNode *Node) {
1880   SDValue Chain = Node->getOperand(0);
1881   SDValue Offset = Node->getOperand(2);
1882   SDValue Glue = Node->getOperand(3);
1883   SDLoc DL(Node);
1884   MemSDNode *Mem = cast<MemSDNode>(Node);
1885 
1886   unsigned VecSize;
1887   switch (Node->getOpcode()) {
1888   default:
1889     return false;
1890   case NVPTXISD::LoadParam:
1891     VecSize = 1;
1892     break;
1893   case NVPTXISD::LoadParamV2:
1894     VecSize = 2;
1895     break;
1896   case NVPTXISD::LoadParamV4:
1897     VecSize = 4;
1898     break;
1899   }
1900 
1901   EVT EltVT = Node->getValueType(0);
1902   EVT MemVT = Mem->getMemoryVT();
1903 
1904   std::optional<unsigned> Opcode;
1905 
1906   switch (VecSize) {
1907   default:
1908     return false;
1909   case 1:
1910     Opcode = pickOpcodeForVT(MemVT.getSimpleVT().SimpleTy,
1911                              NVPTX::LoadParamMemI8, NVPTX::LoadParamMemI16,
1912                              NVPTX::LoadParamMemI32, NVPTX::LoadParamMemI64,
1913                              NVPTX::LoadParamMemF32, NVPTX::LoadParamMemF64);
1914     break;
1915   case 2:
1916     Opcode =
1917         pickOpcodeForVT(MemVT.getSimpleVT().SimpleTy, NVPTX::LoadParamMemV2I8,
1918                         NVPTX::LoadParamMemV2I16, NVPTX::LoadParamMemV2I32,
1919                         NVPTX::LoadParamMemV2I64, NVPTX::LoadParamMemV2F32,
1920                         NVPTX::LoadParamMemV2F64);
1921     break;
1922   case 4:
1923     Opcode =
1924         pickOpcodeForVT(MemVT.getSimpleVT().SimpleTy, NVPTX::LoadParamMemV4I8,
1925                         NVPTX::LoadParamMemV4I16, NVPTX::LoadParamMemV4I32,
1926                         std::nullopt, NVPTX::LoadParamMemV4F32, std::nullopt);
1927     break;
1928   }
1929   if (!Opcode)
1930     return false;
1931 
1932   SDVTList VTs;
1933   if (VecSize == 1) {
1934     VTs = CurDAG->getVTList(EltVT, MVT::Other, MVT::Glue);
1935   } else if (VecSize == 2) {
1936     VTs = CurDAG->getVTList(EltVT, EltVT, MVT::Other, MVT::Glue);
1937   } else {
1938     EVT EVTs[] = {EltVT, EltVT, EltVT, EltVT, MVT::Other, MVT::Glue};
1939     VTs = CurDAG->getVTList(EVTs);
1940   }
1941 
1942   unsigned OffsetVal = Offset->getAsZExtVal();
1943 
1944   SmallVector<SDValue, 2> Ops(
1945       {CurDAG->getTargetConstant(OffsetVal, DL, MVT::i32), Chain, Glue});
1946 
1947   ReplaceNode(Node, CurDAG->getMachineNode(*Opcode, DL, VTs, Ops));
1948   return true;
1949 }
1950 
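// Select NVPTXISD::StoreRetval{,V2,V4} nodes to StoreRetval* opcodes that
// write one, two, or four return values at a constant offset.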
1951 bool NVPTXDAGToDAGISel::tryStoreRetval(SDNode *N) {
1952   SDLoc DL(N);
1953   SDValue Chain = N->getOperand(0);
1954   SDValue Offset = N->getOperand(1);
1955   unsigned OffsetVal = Offset->getAsZExtVal();
1956   MemSDNode *Mem = cast<MemSDNode>(N);
1957 
1958   // How many elements do we have?
1959   unsigned NumElts = 1;
1960   switch (N->getOpcode()) {
1961   default:
1962     return false;
1963   case NVPTXISD::StoreRetval:
1964     NumElts = 1;
1965     break;
1966   case NVPTXISD::StoreRetvalV2:
1967     NumElts = 2;
1968     break;
1969   case NVPTXISD::StoreRetvalV4:
1970     NumElts = 4;
1971     break;
1972   }
1973 
1974   // Build vector of operands
1975   SmallVector<SDValue, 6> Ops;
1976   for (unsigned i = 0; i < NumElts; ++i)
1977     Ops.push_back(N->getOperand(i + 2));
1978   Ops.append({CurDAG->getTargetConstant(OffsetVal, DL, MVT::i32), Chain});
1979 
1980   // Determine target opcode
1981   // If we have an i1, use an 8-bit store. The lowering code in
1982   // NVPTXISelLowering will have already emitted an upcast.
1983   std::optional<unsigned> Opcode;
1984   switch (NumElts) {
1985   default:
1986     return false;
1987   case 1:
1988     Opcode = pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy,
1989                              NVPTX::StoreRetvalI8, NVPTX::StoreRetvalI16,
1990                              NVPTX::StoreRetvalI32, NVPTX::StoreRetvalI64,
1991                              NVPTX::StoreRetvalF32, NVPTX::StoreRetvalF64);
1992     if (Opcode == NVPTX::StoreRetvalI8) {
1993       // Fine tune the opcode depending on the size of the operand.
1994       // This helps to avoid creating redundant COPY instructions in
1995       // InstrEmitter::AddRegisterOperand().
1996       switch (Ops[0].getSimpleValueType().SimpleTy) {
1997       default:
1998         break;
1999       case MVT::i32:
2000         Opcode = NVPTX::StoreRetvalI8TruncI32;
2001         break;
2002       case MVT::i64:
2003         Opcode = NVPTX::StoreRetvalI8TruncI64;
2004         break;
2005       }
2006     }
2007     break;
2008   case 2:
2009     Opcode = pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy,
2010                              NVPTX::StoreRetvalV2I8, NVPTX::StoreRetvalV2I16,
2011                              NVPTX::StoreRetvalV2I32, NVPTX::StoreRetvalV2I64,
2012                              NVPTX::StoreRetvalV2F32, NVPTX::StoreRetvalV2F64);
2013     break;
2014   case 4:
2015     Opcode = pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy,
2016                              NVPTX::StoreRetvalV4I8, NVPTX::StoreRetvalV4I16,
2017                              NVPTX::StoreRetvalV4I32, std::nullopt,
2018                              NVPTX::StoreRetvalV4F32, std::nullopt);
2019     break;
2020   }
2021   if (!Opcode)
2022     return false;
2023 
2024   SDNode *Ret = CurDAG->getMachineNode(*Opcode, DL, MVT::Other, Ops);
2025   MachineMemOperand *MemRef = cast<MemSDNode>(N)->getMemOperand();
2026   CurDAG->setNodeMemRefs(cast<MachineSDNode>(Ret), {MemRef});
2027 
2028   ReplaceNode(N, Ret);
2029   return true;
2030 }
2031 
2032 // Helpers for constructing an opcode (e.g. NVPTX::StoreParamV4F32_iiri)
2033 #define getOpcV2H(ty, opKind0, opKind1)                                        \
2034   NVPTX::StoreParamV2##ty##_##opKind0##opKind1
2035 
2036 #define getOpcV2H1(ty, opKind0, isImm1)                                        \
2037   (isImm1) ? getOpcV2H(ty, opKind0, i) : getOpcV2H(ty, opKind0, r)
2038 
2039 #define getOpcodeForVectorStParamV2(ty, isimm)                                 \
2040   (isimm[0]) ? getOpcV2H1(ty, i, isimm[1]) : getOpcV2H1(ty, r, isimm[1])
2041 
2042 #define getOpcV4H(ty, opKind0, opKind1, opKind2, opKind3)                      \
2043   NVPTX::StoreParamV4##ty##_##opKind0##opKind1##opKind2##opKind3
2044 
2045 #define getOpcV4H3(ty, opKind0, opKind1, opKind2, isImm3)                      \
2046   (isImm3) ? getOpcV4H(ty, opKind0, opKind1, opKind2, i)                       \
2047            : getOpcV4H(ty, opKind0, opKind1, opKind2, r)
2048 
2049 #define getOpcV4H2(ty, opKind0, opKind1, isImm2, isImm3)                       \
2050   (isImm2) ? getOpcV4H3(ty, opKind0, opKind1, i, isImm3)                       \
2051            : getOpcV4H3(ty, opKind0, opKind1, r, isImm3)
2052 
2053 #define getOpcV4H1(ty, opKind0, isImm1, isImm2, isImm3)                        \
2054   (isImm1) ? getOpcV4H2(ty, opKind0, i, isImm2, isImm3)                        \
2055            : getOpcV4H2(ty, opKind0, r, isImm2, isImm3)
2056 
2057 #define getOpcodeForVectorStParamV4(ty, isimm)                                 \
2058   (isimm[0]) ? getOpcV4H1(ty, i, isimm[1], isimm[2], isimm[3])                 \
2059              : getOpcV4H1(ty, r, isimm[1], isimm[2], isimm[3])
2060 
2061 #define getOpcodeForVectorStParam(n, ty, isimm)                                \
2062   (n == 2) ? getOpcodeForVectorStParamV2(ty, isimm)                            \
2063            : getOpcodeForVectorStParamV4(ty, isimm)
2064 
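// Pick the StoreParamV2/V4 opcode whose operand-kind suffix ('r' register,
// 'i' immediate) matches Ops, rewriting constant operands into target
// constants. For example, an i32 v2 store whose operands are {imm, reg}
// selects NVPTX::StoreParamV2I32_ir.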
2065 static unsigned pickOpcodeForVectorStParam(SmallVector<SDValue, 8> &Ops,
2066                                            unsigned NumElts,
2067                                            MVT::SimpleValueType MemTy,
2068                                            SelectionDAG *CurDAG, SDLoc DL) {
2069   // Determine which inputs are registers and which are immediates, and
2070   // replace the immediate operands with target-constant nodes.
2071   SmallVector<bool, 4> IsImm(NumElts, false);
2072   for (unsigned i = 0; i < NumElts; i++) {
2073     IsImm[i] = (isa<ConstantSDNode>(Ops[i]) || isa<ConstantFPSDNode>(Ops[i]));
2074     if (IsImm[i]) {
2075       SDValue Imm = Ops[i];
2076       if (MemTy == MVT::f32 || MemTy == MVT::f64) {
2077         const ConstantFPSDNode *ConstImm = cast<ConstantFPSDNode>(Imm);
2078         const ConstantFP *CF = ConstImm->getConstantFPValue();
2079         Imm = CurDAG->getTargetConstantFP(*CF, DL, Imm->getValueType(0));
2080       } else {
2081         const ConstantSDNode *ConstImm = cast<ConstantSDNode>(Imm);
2082         const ConstantInt *CI = ConstImm->getConstantIntValue();
2083         Imm = CurDAG->getTargetConstant(*CI, DL, Imm->getValueType(0));
2084       }
2085       Ops[i] = Imm;
2086     }
2087   }
2088 
2089   // Get opcode for MemTy, size, and register/immediate operand ordering
2090   switch (MemTy) {
2091   case MVT::i8:
2092     return getOpcodeForVectorStParam(NumElts, I8, IsImm);
2093   case MVT::i16:
2094     return getOpcodeForVectorStParam(NumElts, I16, IsImm);
2095   case MVT::i32:
2096     return getOpcodeForVectorStParam(NumElts, I32, IsImm);
2097   case MVT::i64:
2098     assert(NumElts == 2 && "MVT too large for NumElts > 2");
2099     return getOpcodeForVectorStParamV2(I64, IsImm);
2100   case MVT::f32:
2101     return getOpcodeForVectorStParam(NumElts, F32, IsImm);
2102   case MVT::f64:
2103     assert(NumElts == 2 && "MVT too large for NumElts > 2");
2104     return getOpcodeForVectorStParamV2(F64, IsImm);
2105 
2106   // These cases don't support immediates, just use the all register version
2107   // and generate moves.
2108   case MVT::i1:
2109     return (NumElts == 2) ? NVPTX::StoreParamV2I8_rr
2110                           : NVPTX::StoreParamV4I8_rrrr;
2111   case MVT::f16:
2112   case MVT::bf16:
2113     return (NumElts == 2) ? NVPTX::StoreParamV2I16_rr
2114                           : NVPTX::StoreParamV4I16_rrrr;
2115   case MVT::v2f16:
2116   case MVT::v2bf16:
2117   case MVT::v2i16:
2118   case MVT::v4i8:
2119     return (NumElts == 2) ? NVPTX::StoreParamV2I32_rr
2120                           : NVPTX::StoreParamV4I32_rrrr;
2121   default:
2122     llvm_unreachable("Cannot select st.param for unknown MemTy");
2123   }
2124 }
2125 
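// Select NVPTXISD::StoreParam* nodes to StoreParam* opcodes, preferring the
// immediate operand forms when the stored values are constants.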
2126 bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
2127   SDLoc DL(N);
2128   SDValue Chain = N->getOperand(0);
2129   SDValue Param = N->getOperand(1);
2130   unsigned ParamVal = Param->getAsZExtVal();
2131   SDValue Offset = N->getOperand(2);
2132   unsigned OffsetVal = Offset->getAsZExtVal();
2133   MemSDNode *Mem = cast<MemSDNode>(N);
2134   SDValue Glue = N->getOperand(N->getNumOperands() - 1);
2135 
2136   // How many elements do we have?
2137   unsigned NumElts;
2138   switch (N->getOpcode()) {
2139   default:
2140     llvm_unreachable("Unexpected opcode");
2141   case NVPTXISD::StoreParamU32:
2142   case NVPTXISD::StoreParamS32:
2143   case NVPTXISD::StoreParam:
2144     NumElts = 1;
2145     break;
2146   case NVPTXISD::StoreParamV2:
2147     NumElts = 2;
2148     break;
2149   case NVPTXISD::StoreParamV4:
2150     NumElts = 4;
2151     break;
2152   }
2153 
2154   // Build vector of operands
2155   SmallVector<SDValue, 8> Ops;
2156   for (unsigned i = 0; i < NumElts; ++i)
2157     Ops.push_back(N->getOperand(i + 3));
2158   Ops.append({CurDAG->getTargetConstant(ParamVal, DL, MVT::i32),
2159               CurDAG->getTargetConstant(OffsetVal, DL, MVT::i32), Chain, Glue});
2160 
2161   // Determine target opcode
2162   // If we have an i1, use an 8-bit store. The lowering code in
2163   // NVPTXISelLowering will have already emitted an upcast.
2164   std::optional<unsigned> Opcode;
2165   switch (N->getOpcode()) {
2166   default:
2167     switch (NumElts) {
2168     default:
2169       llvm_unreachable("Unexpected NumElts");
2170     case 1: {
2171       MVT::SimpleValueType MemTy = Mem->getMemoryVT().getSimpleVT().SimpleTy;
2172       SDValue Imm = Ops[0];
2173       if (MemTy != MVT::f16 && MemTy != MVT::v2f16 &&
2174           (isa<ConstantSDNode>(Imm) || isa<ConstantFPSDNode>(Imm))) {
2175         // Convert immediate to target constant
2176         if (MemTy == MVT::f32 || MemTy == MVT::f64) {
2177           const ConstantFPSDNode *ConstImm = cast<ConstantFPSDNode>(Imm);
2178           const ConstantFP *CF = ConstImm->getConstantFPValue();
2179           Imm = CurDAG->getTargetConstantFP(*CF, DL, Imm->getValueType(0));
2180         } else {
2181           const ConstantSDNode *ConstImm = cast<ConstantSDNode>(Imm);
2182           const ConstantInt *CI = ConstImm->getConstantIntValue();
2183           Imm = CurDAG->getTargetConstant(*CI, DL, Imm->getValueType(0));
2184         }
2185         Ops[0] = Imm;
2186         // Use immediate version of store param
2187         Opcode = pickOpcodeForVT(MemTy, NVPTX::StoreParamI8_i,
2188                                  NVPTX::StoreParamI16_i, NVPTX::StoreParamI32_i,
2189                                  NVPTX::StoreParamI64_i, NVPTX::StoreParamF32_i,
2190                                  NVPTX::StoreParamF64_i);
2191       } else
2192         Opcode =
2193             pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy,
2194                             NVPTX::StoreParamI8_r, NVPTX::StoreParamI16_r,
2195                             NVPTX::StoreParamI32_r, NVPTX::StoreParamI64_r,
2196                             NVPTX::StoreParamF32_r, NVPTX::StoreParamF64_r);
2197       if (Opcode == NVPTX::StoreParamI8_r) {
2198         // Fine tune the opcode depending on the size of the operand.
2199         // This helps to avoid creating redundant COPY instructions in
2200         // InstrEmitter::AddRegisterOperand().
2201         switch (Ops[0].getSimpleValueType().SimpleTy) {
2202         default:
2203           break;
2204         case MVT::i32:
2205           Opcode = NVPTX::StoreParamI8TruncI32_r;
2206           break;
2207         case MVT::i64:
2208           Opcode = NVPTX::StoreParamI8TruncI64_r;
2209           break;
2210         }
2211       }
2212       break;
2213     }
2214     case 2:
2215     case 4: {
2216       MVT::SimpleValueType MemTy = Mem->getMemoryVT().getSimpleVT().SimpleTy;
2217       Opcode = pickOpcodeForVectorStParam(Ops, NumElts, MemTy, CurDAG, DL);
2218       break;
2219     }
2220     }
2221     break;
2222   // Special case: if we have a sign-extend/zero-extend node, insert the
2223   // conversion instruction first, and use that as the value operand to
2224   // the selected StoreParam node.
2225   case NVPTXISD::StoreParamU32: {
2226     Opcode = NVPTX::StoreParamI32_r;
2227     SDValue CvtNone = CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL,
2228                                                 MVT::i32);
2229     SDNode *Cvt = CurDAG->getMachineNode(NVPTX::CVT_u32_u16, DL,
2230                                          MVT::i32, Ops[0], CvtNone);
2231     Ops[0] = SDValue(Cvt, 0);
2232     break;
2233   }
2234   case NVPTXISD::StoreParamS32: {
2235     Opcode = NVPTX::StoreParamI32_r;
2236     SDValue CvtNone = CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL,
2237                                                 MVT::i32);
2238     SDNode *Cvt = CurDAG->getMachineNode(NVPTX::CVT_s32_s16, DL,
2239                                          MVT::i32, Ops[0], CvtNone);
2240     Ops[0] = SDValue(Cvt, 0);
2241     break;
2242   }
2243   }
2244 
2245   SDVTList RetVTs = CurDAG->getVTList(MVT::Other, MVT::Glue);
2246   SDNode *Ret = CurDAG->getMachineNode(*Opcode, DL, RetVTs, Ops);
2247   MachineMemOperand *MemRef = cast<MemSDNode>(N)->getMemOperand();
2248   CurDAG->setNodeMemRefs(cast<MachineSDNode>(Ret), {MemRef});
2249 
2250   ReplaceNode(N, Ret);
2251   return true;
2252 }
2253 
2254 /// SelectBFE - Look for instruction sequences that can be made more efficient
2255 /// by using the 'bfe' (bit-field extract) PTX instruction
2256 bool NVPTXDAGToDAGISel::tryBFE(SDNode *N) {
2257   SDLoc DL(N);
2258   SDValue LHS = N->getOperand(0);
2259   SDValue RHS = N->getOperand(1);
2260   SDValue Len;
2261   SDValue Start;
2262   SDValue Val;
2263   bool IsSigned = false;
2264 
2265   if (N->getOpcode() == ISD::AND) {
2266     // Canonicalize the operands
2267     // We want 'and %val, %mask'
2268     if (isa<ConstantSDNode>(LHS) && !isa<ConstantSDNode>(RHS)) {
2269       std::swap(LHS, RHS);
2270     }
2271 
2272     ConstantSDNode *Mask = dyn_cast<ConstantSDNode>(RHS);
2273     if (!Mask) {
2274       // We need a constant mask on the RHS of the AND
2275       return false;
2276     }
2277 
2278     // Extract the mask bits
2279     uint64_t MaskVal = Mask->getZExtValue();
2280     if (!isMask_64(MaskVal)) {
2281       // We *could* handle shifted masks here, but doing so would require an
2282       // 'and' operation to fix up the low-order bits so we would trade
2283       // shr+and for bfe+and, which has the same throughput
2284       return false;
2285     }
2286 
2287     // How many bits are in our mask?
2288     int64_t NumBits = countr_one(MaskVal);
2289     Len = CurDAG->getTargetConstant(NumBits, DL, MVT::i32);
2290 
2291     if (LHS.getOpcode() == ISD::SRL || LHS.getOpcode() == ISD::SRA) {
2292       // We have a 'srl/and' pair, extract the effective start bit and length
2293       Val = LHS.getNode()->getOperand(0);
2294       Start = LHS.getNode()->getOperand(1);
2295       ConstantSDNode *StartConst = dyn_cast<ConstantSDNode>(Start);
2296       if (StartConst) {
2297         uint64_t StartVal = StartConst->getZExtValue();
2298         // How many "good" bits do we have left?  "good" is defined here as bits
2299         // that exist in the original value, not shifted in.
2300         int64_t GoodBits = Start.getValueSizeInBits() - StartVal;
2301         if (NumBits > GoodBits) {
2302           // Do not handle the case where bits have been shifted in. In theory
2303           // we could handle this, but the cost is likely higher than just
2304           // emitting the srl/and pair.
2305           return false;
2306         }
2307         Start = CurDAG->getTargetConstant(StartVal, DL, MVT::i32);
2308       } else {
2309         // Do not handle the case where the shift amount (can be zero if no srl
2310         // was found) is not constant. We could handle this case, but it would
2311         // require run-time logic that would be more expensive than just
2312         // emitting the srl/and pair.
2313         return false;
2314       }
2315     } else {
2316       // Do not handle the case where the LHS of the and is not a shift. While
2317       // it would be trivial to handle this case, it would just transform
2318       // 'and' -> 'bfe', but 'and' has higher throughput.
2319       return false;
2320     }
2321   } else if (N->getOpcode() == ISD::SRL || N->getOpcode() == ISD::SRA) {
2322     if (LHS->getOpcode() == ISD::AND) {
2323       ConstantSDNode *ShiftCnst = dyn_cast<ConstantSDNode>(RHS);
2324       if (!ShiftCnst) {
2325         // Shift amount must be constant
2326         return false;
2327       }
2328 
2329       uint64_t ShiftAmt = ShiftCnst->getZExtValue();
2330 
2331       SDValue AndLHS = LHS->getOperand(0);
2332       SDValue AndRHS = LHS->getOperand(1);
2333 
2334       // Canonicalize the AND to have the mask on the RHS
2335       if (isa<ConstantSDNode>(AndLHS)) {
2336         std::swap(AndLHS, AndRHS);
2337       }
2338 
2339       ConstantSDNode *MaskCnst = dyn_cast<ConstantSDNode>(AndRHS);
2340       if (!MaskCnst) {
2341         // Mask must be constant
2342         return false;
2343       }
2344 
2345       uint64_t MaskVal = MaskCnst->getZExtValue();
2346       uint64_t NumZeros;
2347       uint64_t NumBits;
2348       if (isMask_64(MaskVal)) {
2349         NumZeros = 0;
2350         // The number of bits in the result bitfield will be the number of
2351         // trailing ones (the AND) minus the number of bits we shift off
2352         NumBits = llvm::countr_one(MaskVal) - ShiftAmt;
2353       } else if (isShiftedMask_64(MaskVal)) {
2354         NumZeros = llvm::countr_zero(MaskVal);
2355         unsigned NumOnes = llvm::countr_one(MaskVal >> NumZeros);
2356         // The number of bits in the result bitfield will be the number of
2357         // trailing zeros plus the number of set bits in the mask minus the
2358         // number of bits we shift off
2359         NumBits = NumZeros + NumOnes - ShiftAmt;
2360       } else {
2361         // This is not a mask we can handle
2362         return false;
2363       }
2364 
2365       if (ShiftAmt < NumZeros) {
2366         // Handling this case would require extra logic that would make this
2367         // transformation non-profitable
2368         return false;
2369       }
2370 
2371       Val = AndLHS;
2372       Start = CurDAG->getTargetConstant(ShiftAmt, DL, MVT::i32);
2373       Len = CurDAG->getTargetConstant(NumBits, DL, MVT::i32);
2374     } else if (LHS->getOpcode() == ISD::SHL) {
2375       // Here, we have a pattern like:
2376       //
2377       // (sra (shl val, NN), MM)
2378       // or
2379       // (srl (shl val, NN), MM)
2380       //
2381       // If MM >= NN, we can efficiently optimize this with bfe
2382       Val = LHS->getOperand(0);
2383 
2384       SDValue ShlRHS = LHS->getOperand(1);
2385       ConstantSDNode *ShlCnst = dyn_cast<ConstantSDNode>(ShlRHS);
2386       if (!ShlCnst) {
2387         // Shift amount must be constant
2388         return false;
2389       }
2390       uint64_t InnerShiftAmt = ShlCnst->getZExtValue();
2391 
2392       SDValue ShrRHS = RHS;
2393       ConstantSDNode *ShrCnst = dyn_cast<ConstantSDNode>(ShrRHS);
2394       if (!ShrCnst) {
2395         // Shift amount must be constant
2396         return false;
2397       }
2398       uint64_t OuterShiftAmt = ShrCnst->getZExtValue();
2399 
2400       // To avoid extra codegen and be profitable, we need Outer >= Inner
2401       if (OuterShiftAmt < InnerShiftAmt) {
2402         return false;
2403       }
2404 
2405       // If the outer shift is more than the type size, we have no bitfield to
2406       // extract (since we also check that the inner shift is <= the outer shift
2407       // then this also implies that the inner shift is < the type size)
2408       if (OuterShiftAmt >= Val.getValueSizeInBits()) {
2409         return false;
2410       }
2411 
2412       Start = CurDAG->getTargetConstant(OuterShiftAmt - InnerShiftAmt, DL,
2413                                         MVT::i32);
2414       Len = CurDAG->getTargetConstant(Val.getValueSizeInBits() - OuterShiftAmt,
2415                                       DL, MVT::i32);
2416 
2417       if (N->getOpcode() == ISD::SRA) {
2418       // If we have an arithmetic right shift, we need to use the signed bfe
2419         // variant
2420         IsSigned = true;
2421       }
2422     } else {
2423       // No can do...
2424       return false;
2425     }
2426   } else {
2427     // No can do...
2428     return false;
2429   }
2430 
2431 
2432   unsigned Opc;
2433   // For the BFE operations we form here from "and" and "srl", always use the
2434   // unsigned variants.
2435   if (Val.getValueType() == MVT::i32) {
2436     if (IsSigned) {
2437       Opc = NVPTX::BFE_S32rii;
2438     } else {
2439       Opc = NVPTX::BFE_U32rii;
2440     }
2441   } else if (Val.getValueType() == MVT::i64) {
2442     if (IsSigned) {
2443       Opc = NVPTX::BFE_S64rii;
2444     } else {
2445       Opc = NVPTX::BFE_U64rii;
2446     }
2447   } else {
2448     // We cannot handle this type
2449     return false;
2450   }
2451 
2452   SDValue Ops[] = {
2453     Val, Start, Len
2454   };
2455 
2456   ReplaceNode(N, CurDAG->getMachineNode(Opc, DL, N->getVTList(), Ops));
2457   return true;
2458 }
2459 
2460 // Select bf16/bf16v2 FADD, FSUB, FMUL as fma on targets without native support
2461 bool NVPTXDAGToDAGISel::tryBF16ArithToFMA(SDNode *N) {
2462   EVT VT = SDValue(N, 0).getValueType();
2463   if (VT.getScalarType() != MVT::bf16)
2464     return false;
2465 
2466   const NVPTXSubtarget *STI = TM.getSubtargetImpl();
2467   if (STI->hasNativeBF16Support(N->getOpcode()))
2468     return false;
2469 
2470   const bool IsVec = VT.isVector();
2471   assert(!IsVec || VT.getVectorNumElements() == 2);
2472   SDLoc DL(N);
2473   SDValue N0 = N->getOperand(0);
2474   SDValue N1 = N->getOperand(1);
2475   SmallVector<SDValue, 3> Operands;
2476   auto GetConstant = [&](float Value) -> SDValue {
2477     // BF16 immediates must be legalized to integer register values
2478     APFloat APF(Value);
2479     bool LosesInfo;
2480     APF.convert(APFloat::BFloat(), APFloat::rmNearestTiesToEven, &LosesInfo);
2481     assert(!LosesInfo);
2482     if (IsVec) {
2483       auto API = APF.bitcastToAPInt();
2484       API = API.concat(API);
2485       auto Const = CurDAG->getTargetConstant(API, DL, MVT::i32);
2486       return SDValue(CurDAG->getMachineNode(NVPTX::IMOV32ri, DL, VT, Const), 0);
2487     }
2488     auto Const = CurDAG->getTargetConstantFP(APF, DL, VT);
2489     return SDValue(CurDAG->getMachineNode(NVPTX::BFMOV16ri, DL, VT, Const), 0);
2490   };
2491 
2492   switch (N->getOpcode()) {
2493   case ISD::FADD:
2494     // add(a, b) -> fma(a, 1.0, b)
2495     Operands = {N0, GetConstant(1.0), N1};
2496     break;
2497   case ISD::FSUB:
2498     // sub(a, b) -> fma(b, -1.0, a)
2499     Operands = {N1, GetConstant(-1.0), N0};
2500     break;
2501   case ISD::FMUL:
2502     // mul(a, b) -> fma(a, b, -0.0)
2503     // NOTE: Use -0, not 0, as the identity: adding +0 would turn -0 into +0.
2504     Operands = {N0, N1, GetConstant(-0.0)};
2505     break;
2506   default:
2507     llvm_unreachable("Unexpected opcode");
2508   }
2509 
2510   int Opcode = IsVec ? NVPTX::BFMA16x2rrr : NVPTX::BFMA16rrr;
2511   MachineSDNode *FMA = CurDAG->getMachineNode(Opcode, DL, VT, Operands);
2512   ReplaceNode(N, FMA);
2513   return true;
2514 }
2515 
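// An ISD::OR with the disjoint flag has no operand bits in common, so it
// computes the same value as an ADD and can be treated as addition here.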
2516 static inline bool isAddLike(const SDValue V) {
2517   return V.getOpcode() == ISD::ADD ||
2518          (V->getOpcode() == ISD::OR && V->getFlags().hasDisjoint());
2519 }
2520 
2521 // SelectDirectAddr - Match a direct address for DAG.
2522 // A direct address could be a globaladdress or externalsymbol.
2523 bool NVPTXDAGToDAGISel::SelectDirectAddr(SDValue N, SDValue &Address) {
2524   // Return true if TGA or ES.
2525   if (N.getOpcode() == ISD::TargetGlobalAddress ||
2526       N.getOpcode() == ISD::TargetExternalSymbol) {
2527     Address = N;
2528     return true;
2529   }
2530   if (N.getOpcode() == NVPTXISD::Wrapper) {
2531     Address = N.getOperand(0);
2532     return true;
2533   }
2534   // addrspacecast(MoveParam(arg_symbol) to addrspace(PARAM)) -> arg_symbol
2535   if (AddrSpaceCastSDNode *CastN = dyn_cast<AddrSpaceCastSDNode>(N)) {
2536     if (CastN->getSrcAddressSpace() == ADDRESS_SPACE_GENERIC &&
2537         CastN->getDestAddressSpace() == ADDRESS_SPACE_PARAM &&
2538         CastN->getOperand(0).getOpcode() == NVPTXISD::MoveParam)
2539       return SelectDirectAddr(CastN->getOperand(0).getOperand(0), Address);
2540   }
2541   return false;
2542 }
2543 
2544 // symbol+offset
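// Match a [symbol+imm] address: walk nested add-like nodes, accumulating their
// constant offsets, until the base resolves to a direct address. For example,
// (add (add (Wrapper @sym), 4), 8) yields Base = @sym and Offset = 12.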
2545 bool NVPTXDAGToDAGISel::SelectADDRsi_imp(SDNode *OpNode, SDValue Addr,
2546                                          SDValue &Base, SDValue &Offset,
2547                                          MVT VT) {
2548   std::function<std::optional<uint64_t>(SDValue, uint64_t)>
2549       FindRootAddressAndTotalOffset =
2550           [&](SDValue Addr,
2551               uint64_t AccumulatedOffset) -> std::optional<uint64_t> {
2552     if (isAddLike(Addr)) {
2553       if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(Addr.getOperand(1))) {
2554         SDValue PossibleBaseAddr = Addr.getOperand(0);
2555         AccumulatedOffset += CN->getZExtValue();
2556         if (SelectDirectAddr(PossibleBaseAddr, Base))
2557           return AccumulatedOffset;
2558         return FindRootAddressAndTotalOffset(PossibleBaseAddr,
2559                                              AccumulatedOffset);
2560       }
2561     }
2562     return std::nullopt;
2563   };
2564   if (auto AccumulatedOffset = FindRootAddressAndTotalOffset(Addr, 0)) {
2565     Offset = CurDAG->getTargetConstant(*AccumulatedOffset, SDLoc(OpNode), VT);
2566     return true;
2567   }
2568   return false;
2569 }
2570 
2571 // symbol+offset
2572 bool NVPTXDAGToDAGISel::SelectADDRsi(SDNode *OpNode, SDValue Addr,
2573                                      SDValue &Base, SDValue &Offset) {
2574   return SelectADDRsi_imp(OpNode, Addr, Base, Offset, MVT::i32);
2575 }
2576 
2577 // symbol+offset
2578 bool NVPTXDAGToDAGISel::SelectADDRsi64(SDNode *OpNode, SDValue Addr,
2579                                        SDValue &Base, SDValue &Offset) {
2580   return SelectADDRsi_imp(OpNode, Addr, Base, Offset, MVT::i64);
2581 }
2582 
2583 // register+offset
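// Match a [reg+imm] address: a frame index gets a zero offset; otherwise an
// add-like node whose constant RHS fits in a signed 32-bit immediate becomes
// base+offset, e.g. (add %r, 16) yields Base = %r and Offset = 16.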
2584 bool NVPTXDAGToDAGISel::SelectADDRri_imp(SDNode *OpNode, SDValue Addr,
2585                                          SDValue &Base, SDValue &Offset,
2586                                          MVT VT) {
2587   if (FrameIndexSDNode *FIN = dyn_cast<FrameIndexSDNode>(Addr)) {
2588     Base = CurDAG->getTargetFrameIndex(FIN->getIndex(), VT);
2589     Offset = CurDAG->getTargetConstant(0, SDLoc(OpNode), VT);
2590     return true;
2591   }
2592   if (Addr.getOpcode() == ISD::TargetExternalSymbol ||
2593       Addr.getOpcode() == ISD::TargetGlobalAddress)
2594     return false; // direct calls.
2595 
2596   if (isAddLike(Addr)) {
2597     if (SelectDirectAddr(Addr.getOperand(0), Addr)) {
2598       return false;
2599     }
2600     if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(Addr.getOperand(1))) {
2601       if (FrameIndexSDNode *FIN =
2602               dyn_cast<FrameIndexSDNode>(Addr.getOperand(0)))
2603         // Constant offset from frame ref.
2604         Base = CurDAG->getTargetFrameIndex(FIN->getIndex(), VT);
2605       else
2606         Base = Addr.getOperand(0);
2607 
2608       // Offset must fit in a 32-bit signed int in PTX [register+offset] address
2609       // mode
2610       if (!CN->getAPIntValue().isSignedIntN(32))
2611         return false;
2612 
2613       Offset = CurDAG->getSignedTargetConstant(CN->getSExtValue(),
2614                                                SDLoc(OpNode), MVT::i32);
2615       return true;
2616     }
2617   }
2618   return false;
2619 }
2620 
2621 // register+offset
2622 bool NVPTXDAGToDAGISel::SelectADDRri(SDNode *OpNode, SDValue Addr,
2623                                      SDValue &Base, SDValue &Offset) {
2624   return SelectADDRri_imp(OpNode, Addr, Base, Offset, MVT::i32);
2625 }
2626 
2627 // register+offset
2628 bool NVPTXDAGToDAGISel::SelectADDRri64(SDNode *OpNode, SDValue Addr,
2629                                        SDValue &Base, SDValue &Offset) {
2630   return SelectADDRri_imp(OpNode, Addr, Base, Offset, MVT::i64);
2631 }
2632 
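// Return true if the memory operand of N is known to be in address space spN.
// Pseudo source values are accepted only for the generic (0) address space.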
2633 bool NVPTXDAGToDAGISel::ChkMemSDNodeAddressSpace(SDNode *N,
2634                                                  unsigned int spN) const {
2635   const Value *Src = nullptr;
2636   if (MemSDNode *mN = dyn_cast<MemSDNode>(N)) {
2637     if (spN == 0 && mN->getMemOperand()->getPseudoValue())
2638       return true;
2639     Src = mN->getMemOperand()->getValue();
2640   }
2641   if (!Src)
2642     return false;
2643   if (auto *PT = dyn_cast<PointerType>(Src->getType()))
2644     return (PT->getAddressSpace() == spN);
2645   return false;
2646 }
2647 
2648 /// SelectInlineAsmMemoryOperand - Implement addressing mode selection for
2649 /// inline asm expressions.
2650 bool NVPTXDAGToDAGISel::SelectInlineAsmMemoryOperand(
2651     const SDValue &Op, InlineAsm::ConstraintCode ConstraintID,
2652     std::vector<SDValue> &OutOps) {
2653   SDValue Op0, Op1;
2654   switch (ConstraintID) {
2655   default:
2656     return true;
2657   case InlineAsm::ConstraintCode::m: // memory
2658     if (SelectDirectAddr(Op, Op0)) {
2659       OutOps.push_back(Op0);
2660       OutOps.push_back(CurDAG->getTargetConstant(0, SDLoc(Op), MVT::i32));
2661       return false;
2662     }
2663     if (SelectADDRri(Op.getNode(), Op, Op0, Op1)) {
2664       OutOps.push_back(Op0);
2665       OutOps.push_back(Op1);
2666       return false;
2667     }
2668     break;
2669   }
2670   return true;
2671 }
2672 
2673 void NVPTXDAGToDAGISel::SelectV2I64toI128(SDNode *N) {
2674   // Lower a CopyToReg with two 64-bit inputs
2675   // Dst:i128, lo:i64, hi:i64
2676   //
2677   // CopyToReg Dst, lo, hi;
2678   //
2679   // ==>
2680   //
2681   // tmp = V2I64toI128 {lo, hi};
2682   // CopyToReg Dst, tmp;
2683   SDValue Dst = N->getOperand(1);
2684   SDValue Lo = N->getOperand(2);
2685   SDValue Hi = N->getOperand(3);
2686 
2687   SDLoc DL(N);
2688   SDNode *Mov =
2689       CurDAG->getMachineNode(NVPTX::V2I64toI128, DL, MVT::i128, {Lo, Hi});
2690 
2691   SmallVector<SDValue, 4> NewOps(N->getNumOperands() - 1);
2692   NewOps[0] = N->getOperand(0);
2693   NewOps[1] = Dst;
2694   NewOps[2] = SDValue(Mov, 0);
2695   if (N->getNumOperands() == 5)
2696     NewOps[3] = N->getOperand(4);
2697   SDValue NewValue = CurDAG->getNode(ISD::CopyToReg, DL, SmallVector<EVT>(N->values()), NewOps);
2698 
2699   ReplaceNode(N, NewValue.getNode());
2700 }
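
// Operand layout handled above (derived from the code): N carries
//   {Chain, Dst, lo, hi [, Glue]}
// and the replacement CopyToReg keeps the same chain/glue while taking the
// single i128 produced by the V2I64toI128 pseudo in place of {lo, hi}.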
2701 
2702 void NVPTXDAGToDAGISel::SelectI128toV2I64(SDNode *N) {
2703   // Lower a CopyFromReg from a 128-bit register to two 64-bit registers
2704   // Dst:i128, Src:i128
2705   //
2706   // {lo, hi} = CopyFromReg Src
2707   //
2708   // ==>
2709   //
2710   // {lo, hi} = I128toV2I64 Src
2711   //
2712   SDValue Ch = N->getOperand(0);
2713   SDValue Src = N->getOperand(1);
2714   SDValue Glue = N->getOperand(2);
2715   SDLoc DL(N);
2716 
2717   // Add Glue and Ch to the operands and results to avoid breaking the
2718   // execution order
2719   SDNode *Mov = CurDAG->getMachineNode(
2720       NVPTX::I128toV2I64, DL,
2721       {MVT::i64, MVT::i64, Ch.getValueType(), Glue.getValueType()},
2722       {Src, Ch, Glue});
2723 
2724   ReplaceNode(N, Mov);
2725 }
2726 
2727 /// GetConvertOpcode - Returns the CVT_ instruction opcode that implements a
2728 /// conversion from \p SrcTy to \p DestTy.
2729 unsigned NVPTXDAGToDAGISel::GetConvertOpcode(MVT DestTy, MVT SrcTy,
2730                                              LoadSDNode *LdNode) {
2731   bool IsSigned = LdNode && LdNode->getExtensionType() == ISD::SEXTLOAD;
2732   switch (SrcTy.SimpleTy) {
2733   default:
2734     llvm_unreachable("Unhandled source type");
2735   case MVT::i8:
2736     switch (DestTy.SimpleTy) {
2737     default:
2738       llvm_unreachable("Unhandled dest type");
2739     case MVT::i16:
2740       return IsSigned ? NVPTX::CVT_s16_s8 : NVPTX::CVT_u16_u8;
2741     case MVT::i32:
2742       return IsSigned ? NVPTX::CVT_s32_s8 : NVPTX::CVT_u32_u8;
2743     case MVT::i64:
2744       return IsSigned ? NVPTX::CVT_s64_s8 : NVPTX::CVT_u64_u8;
2745     }
2746   case MVT::i16:
2747     switch (DestTy.SimpleTy) {
2748     default:
2749       llvm_unreachable("Unhandled dest type");
2750     case MVT::i8:
2751       return IsSigned ? NVPTX::CVT_s8_s16 : NVPTX::CVT_u8_u16;
2752     case MVT::i32:
2753       return IsSigned ? NVPTX::CVT_s32_s16 : NVPTX::CVT_u32_u16;
2754     case MVT::i64:
2755       return IsSigned ? NVPTX::CVT_s64_s16 : NVPTX::CVT_u64_u16;
2756     }
2757   case MVT::i32:
2758     switch (DestTy.SimpleTy) {
2759     default:
2760       llvm_unreachable("Unhandled dest type");
2761     case MVT::i8:
2762       return IsSigned ? NVPTX::CVT_s8_s32 : NVPTX::CVT_u8_u32;
2763     case MVT::i16:
2764       return IsSigned ? NVPTX::CVT_s16_s32 : NVPTX::CVT_u16_u32;
2765     case MVT::i64:
2766       return IsSigned ? NVPTX::CVT_s64_s32 : NVPTX::CVT_u64_u32;
2767     }
2768   case MVT::i64:
2769     switch (DestTy.SimpleTy) {
2770     default:
2771       llvm_unreachable("Unhandled dest type");
2772     case MVT::i8:
2773       return IsSigned ? NVPTX::CVT_s8_s64 : NVPTX::CVT_u8_u64;
2774     case MVT::i16:
2775       return IsSigned ? NVPTX::CVT_s16_s64 : NVPTX::CVT_u16_u64;
2776     case MVT::i32:
2777       return IsSigned ? NVPTX::CVT_s32_s64 : NVPTX::CVT_u32_u64;
2778     }
2779   case MVT::f16:
2780     switch (DestTy.SimpleTy) {
2781     default:
2782       llvm_unreachable("Unhandled dest type");
2783     case MVT::f32:
2784       return NVPTX::CVT_f32_f16;
2785     case MVT::f64:
2786       return NVPTX::CVT_f64_f16;
2787     }
2788   }
2789 }
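
// Example (taken from the table above): widening an i8 extending load to i32
// selects CVT_s32_s8 when the load is sign-extending and CVT_u32_u8
// otherwise; f16 sources only support widening to f32 or f64 here.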
2790 
2791 bool NVPTXDAGToDAGISel::tryFence(SDNode *N) {
2792   SDLoc DL(N);
2793   assert(N->getOpcode() == ISD::ATOMIC_FENCE);
2794   unsigned int FenceOp =
2795       getFenceOp(NVPTX::Ordering(N->getConstantOperandVal(1)),
2796                  Scopes[N->getConstantOperandVal(2)], Subtarget);
2797   SDValue Chain = N->getOperand(0);
2798   SDNode *FenceNode = CurDAG->getMachineNode(FenceOp, DL, MVT::Other, Chain);
2799   ReplaceNode(N, FenceNode);
2800   return true;
2801 }
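
// A minimal IR-level example of what reaches this path (illustrative only):
//   fence syncscope("device") acquire
// arrives as ISD::ATOMIC_FENCE with the ordering encoded in operand 1 and
// the sync-scope ID in operand 2; getFenceOp maps that (ordering, scope,
// subtarget) triple onto a concrete fence machine opcode.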
2802 
2803 NVPTXScopes::NVPTXScopes(LLVMContext &C) {
2804   Scopes[C.getOrInsertSyncScopeID("singlethread")] = NVPTX::Scope::Thread;
2805   Scopes[C.getOrInsertSyncScopeID("")] = NVPTX::Scope::System;
2806   Scopes[C.getOrInsertSyncScopeID("block")] = NVPTX::Scope::Block;
2807   Scopes[C.getOrInsertSyncScopeID("cluster")] = NVPTX::Scope::Cluster;
2808   Scopes[C.getOrInsertSyncScopeID("device")] = NVPTX::Scope::Device;
2809 }
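
// The empty string is LLVM's default ("system") sync scope, so an
// unqualified "fence seq_cst" maps to NVPTX::Scope::System, while, for
// example, syncscope("cluster") maps to NVPTX::Scope::Cluster.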
2810 
2811 NVPTX::Scope NVPTXScopes::operator[](SyncScope::ID ID) const {
2812   if (Scopes.empty())
2813     llvm_unreachable("NVPTX Scopes must be initialized before calling "
2814                      "NVPTXScopes::operator[]");
2815 
2816   auto S = Scopes.find(ID);
2817   if (S == Scopes.end()) {
2818     // TODO:
2819     // - Add API to LLVMContext to get the name of a single scope.
2820     // - Use that API here to print an error containing the name
2821     //   of this Unknown ID.
2822     report_fatal_error(formatv("Could not find scope ID={}.", int(ID)));
2823   }
2824   return S->second;
2825 }
2826 
2827 bool NVPTXScopes::empty() const { return Scopes.empty(); }
2828 
2829 #define CP_ASYNC_BULK_TENSOR_OPCODE(dir, dim, mode, is_s32, suffix)            \
2830   (is_s32                                                                      \
2831        ? NVPTX::CP_ASYNC_BULK_TENSOR_##dir##_##dim##_SHARED32_##mode##suffix   \
2832        : NVPTX::CP_ASYNC_BULK_TENSOR_##dir##_##dim##_##mode##suffix)
2833 
2834 #define CP_ASYNC_BULK_TENSOR_OPCODE_S2G_IMPL(op, dim, mode, is_ch, is_s32)     \
2835   (is_ch ? (CP_ASYNC_BULK_TENSOR_OPCODE(op, dim, mode, is_s32, _CH))           \
2836          : (CP_ASYNC_BULK_TENSOR_OPCODE(op, dim, mode, is_s32, )))
2837 
2838 #define GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(dim, mode, is_reduce, is_ch,       \
2839                                             is_s32)                            \
2840   (is_reduce                                                                   \
2841        ? (CP_ASYNC_BULK_TENSOR_OPCODE_S2G_IMPL(RED, dim, mode, is_ch, is_s32)) \
2842        : (CP_ASYNC_BULK_TENSOR_OPCODE_S2G_IMPL(S2G, dim, mode, is_ch,          \
2843                                                is_s32)))
2844 
2845 #define GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(dim, mode, is_mc, is_ch, is_s32)   \
2846   [&]() -> auto {                                                              \
2847     if (is_mc && is_ch)                                                        \
2848       return CP_ASYNC_BULK_TENSOR_OPCODE(G2S, dim, mode, is_s32, _MC_CH);      \
2849     if (is_ch)                                                                 \
2850       return CP_ASYNC_BULK_TENSOR_OPCODE(G2S, dim, mode, is_s32, _CH);         \
2851     if (is_mc)                                                                 \
2852       return CP_ASYNC_BULK_TENSOR_OPCODE(G2S, dim, mode, is_s32, _MC);         \
2853     return CP_ASYNC_BULK_TENSOR_OPCODE(G2S, dim, mode, is_s32, );              \
2854   }()
2855 
2856 #define GET_CP_ASYNC_BULK_TENSOR_OPCODE_PREFETCH(dim, mode, is_ch)             \
2857   (is_ch ? NVPTX::CP_ASYNC_BULK_TENSOR_PREFETCH_##dim##_##mode##_CH            \
2858          : NVPTX::CP_ASYNC_BULK_TENSOR_PREFETCH_##dim##_##mode)
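
// Worked expansion of the helpers above (3D tile copy with a cache hint and
// a 64-bit shared pointer):
//   GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(3D, TILE, /*is_reduce=*/false,
//                                       /*is_ch=*/true, /*is_s32=*/false)
//     -> CP_ASYNC_BULK_TENSOR_OPCODE_S2G_IMPL(S2G, 3D, TILE, true, false)
//     -> CP_ASYNC_BULK_TENSOR_OPCODE(S2G, 3D, TILE, false, _CH)
//     -> NVPTX::CP_ASYNC_BULK_TENSOR_S2G_3D_TILE_CH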
2859 
2860 static unsigned GetCpAsyncBulkTensorS2GOpcode(size_t Dim, bool IsShared32,
2861                                               bool IsCacheHint, bool IsIm2Col,
2862                                               bool IsReduce = false) {
2863   if (IsIm2Col) {
2864     switch (Dim) {
2865     case 3:
2866       return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(3D, IM2COL, IsReduce,
2867                                                  IsCacheHint, IsShared32);
2868     case 4:
2869       return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(4D, IM2COL, IsReduce,
2870                                                  IsCacheHint, IsShared32);
2871     case 5:
2872       return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(5D, IM2COL, IsReduce,
2873                                                  IsCacheHint, IsShared32);
2874     default:
2875       llvm_unreachable("Invalid Dimension in im2col mode for "
2876                        "GetCpAsyncBulkTensorS2GOpcode.");
2877     }
2878   } else {
2879     switch (Dim) {
2880     case 1:
2881       return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(1D, TILE, IsReduce,
2882                                                  IsCacheHint, IsShared32);
2883     case 2:
2884       return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(2D, TILE, IsReduce,
2885                                                  IsCacheHint, IsShared32);
2886     case 3:
2887       return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(3D, TILE, IsReduce,
2888                                                  IsCacheHint, IsShared32);
2889     case 4:
2890       return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(4D, TILE, IsReduce,
2891                                                  IsCacheHint, IsShared32);
2892     case 5:
2893       return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(5D, TILE, IsReduce,
2894                                                  IsCacheHint, IsShared32);
2895     default:
2896       llvm_unreachable(
2897           "Invalid Dimension in tile mode for GetCpAsyncBulkTensorS2GOpcode.");
2898     }
2899   }
2900 }
2901 
2902 static unsigned GetCpAsyncBulkTensorG2SOpcode(size_t Dim, bool IsShared32,
2903                                               bool IsMultiCast,
2904                                               bool IsCacheHint, bool IsIm2Col) {
2905   if (IsIm2Col) {
2906     switch (Dim) {
2907     case 3:
2908       return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(3D, IM2COL, IsMultiCast,
2909                                                  IsCacheHint, IsShared32);
2910     case 4:
2911       return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(4D, IM2COL, IsMultiCast,
2912                                                  IsCacheHint, IsShared32);
2913     case 5:
2914       return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(5D, IM2COL, IsMultiCast,
2915                                                  IsCacheHint, IsShared32);
2916     default:
2917       llvm_unreachable("Invalid Dimension in im2col mode for "
2918                        "GetCpAsyncBulkTensorG2SOpcode.");
2919     }
2920   } else {
2921     switch (Dim) {
2922     case 1:
2923       return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(1D, TILE, IsMultiCast,
2924                                                  IsCacheHint, IsShared32);
2925     case 2:
2926       return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(2D, TILE, IsMultiCast,
2927                                                  IsCacheHint, IsShared32);
2928     case 3:
2929       return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(3D, TILE, IsMultiCast,
2930                                                  IsCacheHint, IsShared32);
2931     case 4:
2932       return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(4D, TILE, IsMultiCast,
2933                                                  IsCacheHint, IsShared32);
2934     case 5:
2935       return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(5D, TILE, IsMultiCast,
2936                                                  IsCacheHint, IsShared32);
2937     default:
2938       llvm_unreachable(
2939           "Invalid Dimension in tile mode for GetCpAsyncBulkTensorG2SOpcode.");
2940     }
2941   }
2942 }
2943 
2944 static unsigned GetCpAsyncBulkTensorPrefetchOpcode(size_t Dim, bool IsCacheHint,
2945                                                    bool IsIm2Col) {
2946   if (IsIm2Col) {
2947     switch (Dim) {
2948     case 3:
2949       return GET_CP_ASYNC_BULK_TENSOR_OPCODE_PREFETCH(3D, IM2COL, IsCacheHint);
2950     case 4:
2951       return GET_CP_ASYNC_BULK_TENSOR_OPCODE_PREFETCH(4D, IM2COL, IsCacheHint);
2952     case 5:
2953       return GET_CP_ASYNC_BULK_TENSOR_OPCODE_PREFETCH(5D, IM2COL, IsCacheHint);
2954     default:
2955       llvm_unreachable("Invalid Dimension in im2col mode for "
2956                        "GetCpAsyncBulkTensorPrefetchOpcode.");
2957     }
2958   } else {
2959     switch (Dim) {
2960     case 1:
2961       return GET_CP_ASYNC_BULK_TENSOR_OPCODE_PREFETCH(1D, TILE, IsCacheHint);
2962     case 2:
2963       return GET_CP_ASYNC_BULK_TENSOR_OPCODE_PREFETCH(2D, TILE, IsCacheHint);
2964     case 3:
2965       return GET_CP_ASYNC_BULK_TENSOR_OPCODE_PREFETCH(3D, TILE, IsCacheHint);
2966     case 4:
2967       return GET_CP_ASYNC_BULK_TENSOR_OPCODE_PREFETCH(4D, TILE, IsCacheHint);
2968     case 5:
2969       return GET_CP_ASYNC_BULK_TENSOR_OPCODE_PREFETCH(5D, TILE, IsCacheHint);
2970     default:
2971       llvm_unreachable("Invalid Dimension in tile mode for "
2972                        "GetCpAsyncBulkTensorPrefetchOpcode.");
2973     }
2974   }
2975 }
2976 
2977 static size_t GetDimsFromIntrinsic(unsigned IID) {
2978   switch (IID) {
2979   case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d:
2980   case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d:
2981     return 3;
2982   case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d:
2983   case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d:
2984     return 4;
2985   case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d:
2986   case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d:
2987     return 5;
2988   default:
2989     llvm_unreachable("Invalid im2col intrinsic in GetDimsFromIntrinsic.");
2990   }
2991 }
2992 
2993 void NVPTXDAGToDAGISel::SelectCpAsyncBulkTensorG2SCommon(SDNode *N,
2994                                                          bool IsIm2Col) {
2995   // We have {Chain, Intrinsic-ID} followed by the actual intrinsic args:
2996   // {dst, mbar, src, dims{d0...dN}, im2col_offsets{dims-2},
2997   // multicast, cache_hint,
2998   // multicast_flag, cache_hint_flag}
2999   // NumOperands = {Chain, IID} + {Actual intrinsic args}
3000   //             = {2}          + {7 + dims + im2col_offsets}
3001   size_t NumOps = N->getNumOperands();
3002   size_t NumDims = IsIm2Col ? GetDimsFromIntrinsic(N->getConstantOperandVal(1))
3003                             : (NumOps - 9);
3004   // There are always 'NumDims - 2' offsets, and only in im2col mode
3005   size_t NumOffsets = IsIm2Col ? (NumDims - 2) : 0;
3006   bool IsCacheHint = N->getConstantOperandVal(NumOps - 1) == 1;
3007   bool IsMultiCast = N->getConstantOperandVal(NumOps - 2) == 1;
3008   size_t NumBaseArgs = NumDims + NumOffsets + 3; // for {dst, mbar, src}
3009   size_t MultiCastIdx = NumBaseArgs + 2;         // for Chain and IID
3010 
3011   SDLoc DL(N);
3012   SmallVector<SDValue, 8> Ops(N->ops().slice(2, NumBaseArgs));
3013 
3014   // Push MultiCast operand, if available
3015   if (IsMultiCast)
3016     Ops.push_back(N->getOperand(MultiCastIdx));
3017 
3018   // Push CacheHint operand, if available
3019   if (IsCacheHint)
3020     Ops.push_back(N->getOperand(MultiCastIdx + 1));
3021 
3022   // Finally, the chain operand
3023   Ops.push_back(N->getOperand(0));
3024 
3025   bool IsShared32 =
3026       CurDAG->getDataLayout().getPointerSizeInBits(ADDRESS_SPACE_SHARED) == 32;
3027   unsigned Opcode = GetCpAsyncBulkTensorG2SOpcode(
3028       NumDims, IsShared32, IsMultiCast, IsCacheHint, IsIm2Col);
3029   ReplaceNode(N, CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops));
3030 }
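
// Worked example for the 3D tile form (no im2col): the node carries
//   {Chain, IID, dst, mbar, src, d0, d1, d2, multicast, cache_hint,
//    multicast_flag, cache_hint_flag}
// so NumOps = 12, NumDims = 12 - 9 = 3, NumOffsets = 0, NumBaseArgs = 6 and
// MultiCastIdx = 8, which is exactly where the multicast operand sits.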
3031 
3032 void NVPTXDAGToDAGISel::SelectCpAsyncBulkTensorS2GCommon(SDNode *N,
3033                                                          bool IsIm2Col) {
3034   // We have {Chain, Intrinsic-ID} followed by the actual intrinsic args:
3035   // src, dst, dims{d0...dN}, cache_hint, cache_hint_flag
3036   // NumOperands = {Chain, IID} + {Actual intrinsic args}
3037   //             = {2}          + {4 + dims}
3038   size_t NumOps = N->getNumOperands();
3039   size_t NumDims = NumOps - 6;
3040   bool IsCacheHint = N->getConstantOperandVal(NumOps - 1) == 1;
3041   size_t NumArgs = NumDims + (IsCacheHint ? 3 : 2); // src, dst, cache_hint
3042 
3043   SDLoc DL(N);
3044   SmallVector<SDValue, 8> Ops(N->ops().slice(2, NumArgs));
3045   Ops.push_back(N->getOperand(0)); // Chain operand
3046 
3047   bool IsShared32 =
3048       CurDAG->getDataLayout().getPointerSizeInBits(ADDRESS_SPACE_SHARED) == 32;
3049   unsigned Opcode =
3050       GetCpAsyncBulkTensorS2GOpcode(NumDims, IsShared32, IsCacheHint, IsIm2Col);
3051   ReplaceNode(N, CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops));
3052 }
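
// Worked example for the 2D tile form: {Chain, IID, src, dst, d0, d1,
// cache_hint, cache_hint_flag} gives NumOps = 8 and NumDims = 2; with the
// cache-hint flag set, NumArgs = 5 so the slice keeps {src, dst, d0, d1,
// cache_hint}, otherwise the trailing cache_hint is dropped.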
3053 
3054 void NVPTXDAGToDAGISel::SelectCpAsyncBulkTensorPrefetchCommon(SDNode *N,
3055                                                               bool IsIm2Col) {
3056   // We have {Chain, Intrinsic-ID} followed by the actual intrinsic args:
3057   // {src, dims{d0...dN}, im2col_offsets{dims-2},
3058   // cache_hint, cache_hint_flag}
3059   // NumOperands = {Chain, IID} + {Actual intrinsic args}
3060   //             = {2}          + {3 + dims + im2col_offsets}
3061   size_t NumOps = N->getNumOperands();
3062   size_t NumDims = IsIm2Col ? GetDimsFromIntrinsic(N->getConstantOperandVal(1))
3063                             : (NumOps - 5);
3064   // Offsets is always 'NumDims - 2' and only for im2col mode
3065   // There are always 'NumDims - 2' offsets, and only in im2col mode
3066   bool IsCacheHint = N->getConstantOperandVal(NumOps - 1) == 1;
3067   size_t NumArgs = NumDims + NumOffsets + (IsCacheHint ? 2 : 1);
3068 
3069   SDLoc DL(N);
3070   SmallVector<SDValue, 12> Ops(N->ops().slice(2, NumArgs));
3071   Ops.push_back(N->getOperand(0)); // Chain operand
3072 
3073   unsigned Opcode =
3074       GetCpAsyncBulkTensorPrefetchOpcode(NumDims, IsCacheHint, IsIm2Col);
3075   ReplaceNode(N, CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops));
3076 }
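
// Worked example for the im2col 3D prefetch: the node carries
//   {Chain, IID, src, d0, d1, d2, im2col_off0, cache_hint, cache_hint_flag}
// so NumDims = 3 (from the intrinsic ID), NumOffsets = 1, and NumArgs is 6
// with a cache hint or 5 without one.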
3077 
3078 void NVPTXDAGToDAGISel::SelectCpAsyncBulkTensorReduceCommon(SDNode *N,
3079                                                             unsigned RedOp,
3080                                                             bool IsIm2Col) {
3081   // We have {Chain, Intrinsic-ID} followed by the actual intrinsic args:
3082   // src, dst, dims{d0...dN}, cache_hint, cache_hint_flag
3083   // NumOperands = {Chain, IID} + {Actual intrinsic args}
3084   //             = {2}          + {4 + dims}
3085   size_t NumOps = N->getNumOperands();
3086   size_t NumDims = NumOps - 6;
3087   bool IsCacheHint = N->getConstantOperandVal(NumOps - 1) == 1;
3088   size_t NumArgs = NumDims + (IsCacheHint ? 3 : 2); // src, dst, cache_hint
3089 
3090   SDLoc DL(N);
3091   SmallVector<SDValue, 12> Ops(N->ops().slice(2, NumArgs));
3092   Ops.push_back(getI32Imm(RedOp, DL)); // Reduction Op
3093   Ops.push_back(N->getOperand(0));     // Chain operand
3094 
3095   bool IsShared32 =
3096       CurDAG->getDataLayout().getPointerSizeInBits(ADDRESS_SPACE_SHARED) == 32;
3097   unsigned Opcode = GetCpAsyncBulkTensorS2GOpcode(
3098       NumDims, IsShared32, IsCacheHint, IsIm2Col, /*IsReduce=*/true);
3099   ReplaceNode(N, CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops));
3100 }
3101 
3102 void NVPTXDAGToDAGISel::SelectCpAsyncBulkS2G(SDNode *N) {
3103   // We have {Chain, Intrinsic-ID} followed by the actual intrinsic args:
3104   // dst, src, size, cache_hint, cache_hint_flag
3105   // NumOperands = {Chain, IID} + {Actual intrinsic args}
3106   //             = {2}          + {5}
3107   size_t NumOps = N->getNumOperands();
3108   bool IsCacheHint = N->getConstantOperandVal(NumOps - 1) == 1;
3109   size_t NumArgs = IsCacheHint ? 4 : 3; // dst, src, size, cache_hint
3110 
3111   SDLoc DL(N);
3112   SmallVector<SDValue, 8> Ops(N->ops().slice(2, NumArgs));
3113   Ops.push_back(N->getOperand(0)); // Chain operand
3114 
3115   bool IsShared32 =
3116       CurDAG->getDataLayout().getPointerSizeInBits(ADDRESS_SPACE_SHARED) == 32;
3117   unsigned Opcode;
3118   if (IsCacheHint)
3119     Opcode = IsShared32 ? NVPTX::CP_ASYNC_BULK_S2G_SHARED32_CH
3120                         : NVPTX::CP_ASYNC_BULK_S2G_CH;
3121   else
3122     Opcode = IsShared32 ? NVPTX::CP_ASYNC_BULK_S2G_SHARED32
3123                         : NVPTX::CP_ASYNC_BULK_S2G;
3124   ReplaceNode(N, CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops));
3125 }
3126 
3127 void NVPTXDAGToDAGISel::SelectCpAsyncBulkG2S(SDNode *N) {
3128   // We have {Chain, Intrinsic-ID} followed by the actual intrinsic args:
3129   // {dst, mbar, src, size, multicast, cache_hint,
3130   // multicast_flag, cache_hint_flag}
3131   // NumOperands = {Chain, IID} + {Actual intrinsic args}
3132   //             = {2}          + {8}
3133   size_t NumOps = N->getNumOperands();
3134   bool IsCacheHint = N->getConstantOperandVal(NumOps - 1) == 1;
3135   bool IsMultiCast = N->getConstantOperandVal(NumOps - 2) == 1;
3136   size_t NumBaseArgs = 4;                // dst, mbar, src, size
3137   size_t MultiCastIdx = NumBaseArgs + 2; // for Chain and IID
3138 
3139   SDLoc DL(N);
3140   SmallVector<SDValue, 8> Ops(N->ops().slice(2, NumBaseArgs));
3141 
3142   // Push MultiCast operand, if available
3143   if (IsMultiCast)
3144     Ops.push_back(N->getOperand(MultiCastIdx));
3145 
3146   // Push CacheHint operand, if available
3147   if (IsCacheHint)
3148     Ops.push_back(N->getOperand(MultiCastIdx + 1));
3149 
3150   // Finally, the chain operand
3151   Ops.push_back(N->getOperand(0));
3152 
3153   bool IsShared32 =
3154       CurDAG->getDataLayout().getPointerSizeInBits(ADDRESS_SPACE_SHARED) == 32;
3155   unsigned Opcode = [&]() {
3156     if (IsMultiCast && IsCacheHint)
3157       return IsShared32 ? NVPTX::CP_ASYNC_BULK_G2S_SHARED32_MC_CH
3158                         : NVPTX::CP_ASYNC_BULK_G2S_MC_CH;
3159     if (IsMultiCast)
3160       return IsShared32 ? NVPTX::CP_ASYNC_BULK_G2S_SHARED32_MC
3161                         : NVPTX::CP_ASYNC_BULK_G2S_MC;
3162     if (IsCacheHint)
3163       return IsShared32 ? NVPTX::CP_ASYNC_BULK_G2S_SHARED32_CH
3164                         : NVPTX::CP_ASYNC_BULK_G2S_CH;
3165     return IsShared32 ? NVPTX::CP_ASYNC_BULK_G2S_SHARED32
3166                       : NVPTX::CP_ASYNC_BULK_G2S;
3167   }();
3168   ReplaceNode(N, CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops));
3169 }
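
// Operand layout assumed above: {Chain, IID, dst, mbar, src, size, multicast,
// cache_hint, multicast_flag, cache_hint_flag}, i.e. NumOps = 10 with the
// multicast operand at index MultiCastIdx = 6 and the cache hint at index 7.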
3170 
3171 void NVPTXDAGToDAGISel::SelectCpAsyncBulkPrefetchL2(SDNode *N) {
3172   // We have {Chain, Intrinsic-ID} followed by the actual intrinsic args:
3173   // src, size, cache_hint, cache_hint_flag
3174   // NumOperands = {Chain, IID} + {Actual intrinsic args}
3175   //             = {2}          + {4}
3176   size_t NumOps = N->getNumOperands();
3177   bool IsCacheHint = N->getConstantOperandVal(NumOps - 1) == 1;
3178   size_t NumArgs = IsCacheHint ? 3 : 2; // src, size, cache_hint
3179 
3180   SDLoc DL(N);
3181   SmallVector<SDValue, 4> Ops(N->ops().slice(2, NumArgs));
3182   Ops.push_back(N->getOperand(0)); // Chain operand
3183 
3184   unsigned Opcode = IsCacheHint
3185                         ? NVPTX::CP_ASYNC_BULK_PREFETCH_CH
3186                         : NVPTX::CP_ASYNC_BULK_PREFETCH;
3187   ReplaceNode(N, CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops));
3188 }
3189 
3190 bool NVPTXDAGToDAGISel::tryIntrinsicVoid(SDNode *N) {
3191   unsigned IID = N->getConstantOperandVal(1);
3192   using TMARedTy = llvm::nvvm::TMAReductionOp;
3193   auto CastTy = [](TMARedTy Op) { return static_cast<unsigned>(Op); };
3194   switch (IID) {
3195   default:
3196     return false;
3197   case Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster:
3198     SelectCpAsyncBulkG2S(N);
3199     return true;
3200   case Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global:
3201     SelectCpAsyncBulkS2G(N);
3202     return true;
3203   case Intrinsic::nvvm_cp_async_bulk_prefetch_L2:
3204     SelectCpAsyncBulkPrefetchL2(N);
3205     return true;
3206   case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_1d:
3207   case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_2d:
3208   case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_3d:
3209   case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_4d:
3210   case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_5d:
3211     SelectCpAsyncBulkTensorS2GCommon(N);
3212     return true;
3213   case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_3d:
3214   case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_4d:
3215   case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_5d:
3216     SelectCpAsyncBulkTensorS2GCommon(N, /*IsIm2Col=*/true);
3217     return true;
3218   case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d:
3219   case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d:
3220   case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d:
3221   case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d:
3222   case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d:
3223     SelectCpAsyncBulkTensorG2SCommon(N);
3224     return true;
3225   case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d:
3226   case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d:
3227   case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d:
3228     SelectCpAsyncBulkTensorG2SCommon(N, /*IsIm2Col=*/true);
3229     return true;
3230   case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d:
3231   case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d:
3232   case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d:
3233   case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d:
3234   case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d:
3235     SelectCpAsyncBulkTensorPrefetchCommon(N);
3236     return true;
3237   case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d:
3238   case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d:
3239   case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d:
3240     SelectCpAsyncBulkTensorPrefetchCommon(N, /*IsIm2Col=*/true);
3241     return true;
3242   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_1d:
3243   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_2d:
3244   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_3d:
3245   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_4d:
3246   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_5d:
3247     SelectCpAsyncBulkTensorReduceCommon(N, CastTy(TMARedTy::ADD));
3248     return true;
3249   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_3d:
3250   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_4d:
3251   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_5d:
3252     SelectCpAsyncBulkTensorReduceCommon(N, CastTy(TMARedTy::ADD),
3253                                         /*IsIm2Col=*/true);
3254     return true;
3255   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_1d:
3256   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_2d:
3257   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_3d:
3258   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_4d:
3259   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_5d:
3260     SelectCpAsyncBulkTensorReduceCommon(N, CastTy(TMARedTy::MIN));
3261     return true;
3262   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_3d:
3263   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_4d:
3264   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_5d:
3265     SelectCpAsyncBulkTensorReduceCommon(N, CastTy(TMARedTy::MIN),
3266                                         /*IsIm2Col=*/true);
3267     return true;
3268   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_1d:
3269   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_2d:
3270   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_3d:
3271   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_4d:
3272   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_5d:
3273     SelectCpAsyncBulkTensorReduceCommon(N, CastTy(TMARedTy::MAX));
3274     return true;
3275   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_3d:
3276   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_4d:
3277   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_5d:
3278     SelectCpAsyncBulkTensorReduceCommon(N, CastTy(TMARedTy::MAX),
3279                                         /*IsIm2Col=*/true);
3280     return true;
3281   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_1d:
3282   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_2d:
3283   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_3d:
3284   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_4d:
3285   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_5d:
3286     SelectCpAsyncBulkTensorReduceCommon(N, CastTy(TMARedTy::INC));
3287     return true;
3288   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_3d:
3289   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_4d:
3290   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_5d:
3291     SelectCpAsyncBulkTensorReduceCommon(N, CastTy(TMARedTy::INC),
3292                                         /*IsIm2Col=*/true);
3293     return true;
3294   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_1d:
3295   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_2d:
3296   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_3d:
3297   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_4d:
3298   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_5d:
3299     SelectCpAsyncBulkTensorReduceCommon(N, CastTy(TMARedTy::DEC));
3300     return true;
3301   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_3d:
3302   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_4d:
3303   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_5d:
3304     SelectCpAsyncBulkTensorReduceCommon(N, CastTy(TMARedTy::DEC),
3305                                         /*IsIm2Col=*/true);
3306     return true;
3307   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_1d:
3308   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_2d:
3309   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_3d:
3310   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_4d:
3311   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_5d:
3312     SelectCpAsyncBulkTensorReduceCommon(N, CastTy(TMARedTy::AND));
3313     return true;
3314   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_3d:
3315   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_4d:
3316   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_5d:
3317     SelectCpAsyncBulkTensorReduceCommon(N, CastTy(TMARedTy::AND),
3318                                         /*IsIm2Col=*/true);
3319     return true;
3320   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_1d:
3321   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_2d:
3322   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_3d:
3323   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_4d:
3324   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_5d:
3325     SelectCpAsyncBulkTensorReduceCommon(N, CastTy(TMARedTy::OR));
3326     return true;
3327   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_3d:
3328   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_4d:
3329   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_5d:
3330     SelectCpAsyncBulkTensorReduceCommon(N, CastTy(TMARedTy::OR),
3331                                         /*IsIm2Col=*/true);
3332     return true;
3333   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_1d:
3334   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_2d:
3335   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_3d:
3336   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_4d:
3337   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_5d:
3338     SelectCpAsyncBulkTensorReduceCommon(N, CastTy(TMARedTy::XOR));
3339     return true;
3340   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_3d:
3341   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_4d:
3342   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_5d:
3343     SelectCpAsyncBulkTensorReduceCommon(N, CastTy(TMARedTy::XOR),
3344                                         /*IsIm2Col=*/true);
3345     return true;
3346   }
3347 }
3348