xref: /llvm-project/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (revision 892a804d93d44ddfd7cd351852fe6aef32d4dcd0)
1 //===-- NVPTXISelLowering.cpp - NVPTX DAG Lowering Implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the interfaces that NVPTX uses to lower LLVM code into a
10 // selection DAG.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "NVPTXISelLowering.h"
15 #include "MCTargetDesc/NVPTXBaseInfo.h"
16 #include "NVPTX.h"
17 #include "NVPTXSubtarget.h"
18 #include "NVPTXTargetMachine.h"
19 #include "NVPTXTargetObjectFile.h"
20 #include "NVPTXUtilities.h"
21 #include "llvm/ADT/APInt.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/SmallVector.h"
24 #include "llvm/ADT/StringRef.h"
25 #include "llvm/CodeGen/Analysis.h"
26 #include "llvm/CodeGen/ISDOpcodes.h"
27 #include "llvm/CodeGen/MachineFunction.h"
28 #include "llvm/CodeGen/MachineJumpTableInfo.h"
29 #include "llvm/CodeGen/MachineMemOperand.h"
30 #include "llvm/CodeGen/SelectionDAG.h"
31 #include "llvm/CodeGen/SelectionDAGNodes.h"
32 #include "llvm/CodeGen/TargetCallingConv.h"
33 #include "llvm/CodeGen/TargetLowering.h"
34 #include "llvm/CodeGen/ValueTypes.h"
35 #include "llvm/CodeGenTypes/MachineValueType.h"
36 #include "llvm/IR/Argument.h"
37 #include "llvm/IR/Attributes.h"
38 #include "llvm/IR/Constants.h"
39 #include "llvm/IR/DataLayout.h"
40 #include "llvm/IR/DerivedTypes.h"
41 #include "llvm/IR/DiagnosticInfo.h"
42 #include "llvm/IR/FPEnv.h"
43 #include "llvm/IR/Function.h"
44 #include "llvm/IR/GlobalValue.h"
45 #include "llvm/IR/Instruction.h"
46 #include "llvm/IR/Instructions.h"
47 #include "llvm/IR/IntrinsicsNVPTX.h"
48 #include "llvm/IR/Module.h"
49 #include "llvm/IR/Type.h"
50 #include "llvm/IR/Value.h"
51 #include "llvm/Support/Alignment.h"
52 #include "llvm/Support/Casting.h"
53 #include "llvm/Support/CodeGen.h"
54 #include "llvm/Support/CommandLine.h"
55 #include "llvm/Support/ErrorHandling.h"
56 #include "llvm/Support/NVPTXAddrSpace.h"
57 #include "llvm/Support/raw_ostream.h"
58 #include "llvm/Target/TargetMachine.h"
59 #include "llvm/Target/TargetOptions.h"
60 #include <algorithm>
61 #include <cassert>
62 #include <cmath>
63 #include <cstdint>
64 #include <iterator>
65 #include <optional>
66 #include <string>
67 #include <utility>
68 #include <vector>
69 
70 #define DEBUG_TYPE "nvptx-lower"
71 
72 using namespace llvm;
73 
74 static std::atomic<unsigned> GlobalUniqueCallSite;
75 
76 static cl::opt<bool> sched4reg(
77     "nvptx-sched4reg",
78     cl::desc("NVPTX Specific: schedule for register pressue"), cl::init(false));
79 
80 static cl::opt<unsigned> FMAContractLevelOpt(
81     "nvptx-fma-level", cl::Hidden,
82     cl::desc("NVPTX Specific: FMA contraction (0: don't do it"
83              " 1: do it  2: do it aggressively"),
84     cl::init(2));
85 
86 static cl::opt<int> UsePrecDivF32(
87     "nvptx-prec-divf32", cl::Hidden,
88     cl::desc("NVPTX Specifies: 0 use div.approx, 1 use div.full, 2 use"
89              " IEEE Compliant F32 div.rnd if available."),
90     cl::init(2));
91 
92 static cl::opt<bool> UsePrecSqrtF32(
93     "nvptx-prec-sqrtf32", cl::Hidden,
94     cl::desc("NVPTX Specific: 0 use sqrt.approx, 1 use sqrt.rn."),
95     cl::init(true));
96 
97 /// Whereas CUDA's implementation (see libdevice) uses ex2.approx for exp2(), it
98 /// does NOT use lg2.approx for log2, so this is disabled by default.
99 static cl::opt<bool> UseApproxLog2F32(
100     "nvptx-approx-log2f32",
101     cl::desc("NVPTX Specific: whether to use lg2.approx for log2"),
102     cl::init(false));
103 
104 static cl::opt<bool> ForceMinByValParamAlign(
105     "nvptx-force-min-byval-param-align", cl::Hidden,
106     cl::desc("NVPTX Specific: force 4-byte minimal alignment for byval"
107              " params of device functions."),
108     cl::init(false));
109 
110 int NVPTXTargetLowering::getDivF32Level() const {
111   if (UsePrecDivF32.getNumOccurrences() > 0) {
112     // If nvptx-prec-divf32=N is used on the command line, always honor it.
113     return UsePrecDivF32;
114   } else {
115     // Otherwise, use div.approx if fast math is enabled
116     if (getTargetMachine().Options.UnsafeFPMath)
117       return 0;
118     else
119       return 2;
120   }
121 }
122 
123 bool NVPTXTargetLowering::usePrecSqrtF32() const {
124   if (UsePrecSqrtF32.getNumOccurrences() > 0) {
125     // If nvptx-prec-sqrtf32 is used on the command line, always honor it.
126     return UsePrecSqrtF32;
127   } else {
128     // Otherwise, use sqrt.approx if fast math is enabled
129     return !getTargetMachine().Options.UnsafeFPMath;
130   }
131 }
132 
133 bool NVPTXTargetLowering::useF32FTZ(const MachineFunction &MF) const {
134   return MF.getDenormalMode(APFloat::IEEEsingle()).Output ==
135          DenormalMode::PreserveSign;
136 }
137 
138 static bool IsPTXVectorType(MVT VT) {
139   switch (VT.SimpleTy) {
140   default:
141     return false;
142   case MVT::v2i1:
143   case MVT::v4i1:
144   case MVT::v2i8:
145   case MVT::v4i8:
146   case MVT::v8i8:  // <2 x i8x4>
147   case MVT::v16i8: // <4 x i8x4>
148   case MVT::v2i16:
149   case MVT::v4i16:
150   case MVT::v8i16: // <4 x i16x2>
151   case MVT::v2i32:
152   case MVT::v4i32:
153   case MVT::v2i64:
154   case MVT::v2f16:
155   case MVT::v4f16:
156   case MVT::v8f16: // <4 x f16x2>
157   case MVT::v2bf16:
158   case MVT::v4bf16:
159   case MVT::v8bf16: // <4 x bf16x2>
160   case MVT::v2f32:
161   case MVT::v4f32:
162   case MVT::v2f64:
163     return true;
164   }
165 }
166 
167 static bool Is16bitsType(MVT VT) {
168   return (VT.SimpleTy == MVT::f16 || VT.SimpleTy == MVT::bf16 ||
169           VT.SimpleTy == MVT::i16);
170 }
171 
172 // This function is called when legalizing vector loads/stores. It does two
173 // things:
174 // 1. Determines whether the vector is something we want to custom lower;
175 //    std::nullopt is returned if we do not want to custom lower it.
176 // 2. If we do want to handle it, returns two parameters:
177 //    - unsigned int NumElts - the number of elements in the final vector
178 //    - EVT EltVT            - the type of the elements in the final vector
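// For example, a v8f16 load is reported here as 4 x v2f16; the v2f16 chunks
// are later lowered to PTX as b32 values, typically a single ld.v4/st.v4.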
179 static std::optional<std::pair<unsigned int, EVT>>
180 getVectorLoweringShape(EVT VectorVT) {
181   if (!VectorVT.isVector() || !VectorVT.isSimple())
182     return std::nullopt;
183 
184   EVT EltVT = VectorVT.getVectorElementType();
185   unsigned NumElts = VectorVT.getVectorNumElements();
186 
187   // We only handle "native" vector sizes for now, e.g. <4 x double> is not
188   // legal.  We can (and should) split that into 2 stores of <2 x double> here
189   // but I'm leaving that as a TODO for now.
190   switch (VectorVT.getSimpleVT().SimpleTy) {
191   default:
192     return std::nullopt;
193   case MVT::v2i8:
194   case MVT::v2i16:
195   case MVT::v2i32:
196   case MVT::v2i64:
197   case MVT::v2f16:
198   case MVT::v2bf16:
199   case MVT::v2f32:
200   case MVT::v2f64:
201   case MVT::v4i8:
202   case MVT::v4i16:
203   case MVT::v4i32:
204   case MVT::v4f16:
205   case MVT::v4bf16:
206   case MVT::v4f32:
207     // This is a "native" vector type
208     return std::pair(NumElts, EltVT);
209   case MVT::v8i8:   // <2 x i8x4>
210   case MVT::v8f16:  // <4 x f16x2>
211   case MVT::v8bf16: // <4 x bf16x2>
212   case MVT::v8i16:  // <4 x i16x2>
213   case MVT::v16i8:  // <4 x i8x4>
214     // This can be upsized into a "native" vector type.
215     // Despite vectors like v8i8, v16i8, v8i16 being within the bit-limit for
216     // total load/store size, PTX syntax only supports v2/v4. Thus, we can't use
217     // vectorized loads/stores with the actual element type for i8/i16 as that
218     // would require v8/v16 variants that do not exist.
219     // In order to load/store such vectors efficiently, here in Type
220     // Legalization, we split the vector into word-sized chunks (v2x16/v4i8).
221     // Later, we will lower to PTX as vectors of b32.
222 
223     // Number of elements to pack in one word.
224     unsigned NPerWord = 32 / EltVT.getSizeInBits();
225 
226     return std::pair(NumElts / NPerWord,
227                       MVT::getVectorVT(EltVT.getSimpleVT(), NPerWord));
228   }
229 
230   llvm_unreachable("All cases in switch should return.");
231 }
232 
233 /// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
234 /// EVTs that compose it.  Unlike ComputeValueVTs, this will break apart vectors
235 /// into their primitive components.
236 /// NOTE: This is a band-aid for code that expects ComputeValueVTs to return the
237 /// same number of types as the Ins/Outs arrays in LowerFormalArguments,
238 /// LowerCall, and LowerReturn.
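/// For example, { i32, <2 x float> } is flattened to the EVTs {i32, f32, f32}
/// at offsets {0, 4, 8}, while <4 x half> becomes {v2f16, v2f16} at {0, 4}.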
239 static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
240                                Type *Ty, SmallVectorImpl<EVT> &ValueVTs,
241                                SmallVectorImpl<uint64_t> *Offsets = nullptr,
242                                uint64_t StartingOffset = 0) {
243   SmallVector<EVT, 16> TempVTs;
244   SmallVector<uint64_t, 16> TempOffsets;
245 
246   // Special case for i128 - decompose to (i64, i64)
247   if (Ty->isIntegerTy(128)) {
248     ValueVTs.push_back(EVT(MVT::i64));
249     ValueVTs.push_back(EVT(MVT::i64));
250 
251     if (Offsets) {
252       Offsets->push_back(StartingOffset + 0);
253       Offsets->push_back(StartingOffset + 8);
254     }
255 
256     return;
257   }
258 
259   // Given a struct type, recursively traverse the elements with custom ComputePTXValueVTs.
260   if (StructType *STy = dyn_cast<StructType>(Ty)) {
261     auto const *SL = DL.getStructLayout(STy);
262     auto ElementNum = 0;
263     for(auto *EI : STy->elements()) {
264       ComputePTXValueVTs(TLI, DL, EI, ValueVTs, Offsets,
265                          StartingOffset + SL->getElementOffset(ElementNum));
266       ++ElementNum;
267     }
268     return;
269   }
270 
271   // Given an array type, recursively traverse the elements with custom ComputePTXValueVTs.
272   if (ArrayType *ATy = dyn_cast<ArrayType>(Ty)) {
273     Type *EltTy = ATy->getElementType();
274     uint64_t EltSize = DL.getTypeAllocSize(EltTy);
275     for (int I : llvm::seq<int>(ATy->getNumElements()))
276       ComputePTXValueVTs(TLI, DL, EltTy, ValueVTs, Offsets, StartingOffset + I * EltSize);
277     return;
278   }
279 
280   ComputeValueVTs(TLI, DL, Ty, TempVTs, &TempOffsets, StartingOffset);
281   for (unsigned i = 0, e = TempVTs.size(); i != e; ++i) {
282     EVT VT = TempVTs[i];
283     uint64_t Off = TempOffsets[i];
284     // Split vectors into individual elements, except for packed types such
285     // as v2f16/v2bf16/v2i16/v4i8, which we will pass as a single scalar.
286     if (VT.isVector()) {
287       unsigned NumElts = VT.getVectorNumElements();
288       EVT EltVT = VT.getVectorElementType();
289       // We require power-of-2 sized vectors because
290       // TargetLoweringBase::getVectorTypeBreakdown(), which is invoked in
291       // ComputePTXValueVTs(), cannot currently break down non-power-of-2 sized
292       // vectors.
293       if ((Is16bitsType(EltVT.getSimpleVT())) && NumElts % 2 == 0 &&
294           isPowerOf2_32(NumElts)) {
295         // Vectors with an even number of 16-bit elements will be passed to
296         // us as an array of v2f16/v2bf16/v2i16 elements. We must match this
297         // so we stay in sync with Ins/Outs.
298         switch (EltVT.getSimpleVT().SimpleTy) {
299         case MVT::f16:
300           EltVT = MVT::v2f16;
301           break;
302         case MVT::bf16:
303           EltVT = MVT::v2bf16;
304           break;
305         case MVT::i16:
306           EltVT = MVT::v2i16;
307           break;
308         default:
309           llvm_unreachable("Unexpected type");
310         }
311         NumElts /= 2;
312       } else if (EltVT.getSimpleVT() == MVT::i8 &&
313                  ((NumElts % 4 == 0 && isPowerOf2_32(NumElts)) ||
314                   NumElts == 3)) {
315         // v*i8 are formally lowered as v4i8
316         EltVT = MVT::v4i8;
317         NumElts = (NumElts + 3) / 4;
318       } else if (EltVT.getSimpleVT() == MVT::i8 && NumElts == 2) {
319         // v2i8 is promoted to v2i16
320         NumElts = 1;
321         EltVT = MVT::v2i16;
322       }
323       for (unsigned j = 0; j != NumElts; ++j) {
324         ValueVTs.push_back(EltVT);
325         if (Offsets)
326           Offsets->push_back(Off + j * EltVT.getStoreSize());
327       }
328     } else {
329       ValueVTs.push_back(VT);
330       if (Offsets)
331         Offsets->push_back(Off);
332     }
333   }
334 }
335 
336 /// PromoteScalarIntegerPTX
337 /// Used to make sure the arguments/returns are suitable for passing
338 /// and promote them to a larger size if they're not.
339 ///
340 /// The promoted type is placed in \p PromotedVT if the function returns true.
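/// For example, i7 is promoted to i8 and i33 to i64, while i32 is already a
/// suitable size and is left unchanged (the function then returns false).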
341 static bool PromoteScalarIntegerPTX(const EVT &VT, MVT *PromotedVT) {
342   if (VT.isScalarInteger()) {
343     switch (PowerOf2Ceil(VT.getFixedSizeInBits())) {
344     default:
345       llvm_unreachable(
346           "Promotion is not suitable for scalars of size larger than 64-bits");
347     case 1:
348       *PromotedVT = MVT::i1;
349       break;
350     case 2:
351     case 4:
352     case 8:
353       *PromotedVT = MVT::i8;
354       break;
355     case 16:
356       *PromotedVT = MVT::i16;
357       break;
358     case 32:
359       *PromotedVT = MVT::i32;
360       break;
361     case 64:
362       *PromotedVT = MVT::i64;
363       break;
364     }
365     return EVT(*PromotedVT) != VT;
366   }
367   return false;
368 }
369 
370 // Check whether we can merge loads/stores of some of the pieces of a
371 // flattened function parameter or return value into a single vector
372 // load/store.
373 //
374 // The flattened parameter is represented as a list of EVTs and
375 // offsets, and the whole structure is aligned to ParamAlignment. This
376 // function determines whether we can load/store pieces of the
377 // parameter starting at index Idx using a single vectorized op of
378 // size AccessSize. If so, it returns the number of param pieces
379 // covered by the vector op. Otherwise, it returns 1.
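// For example, four contiguous f32 pieces at offsets {0, 4, 8, 12} within a
// 16-byte-aligned parameter can be merged starting at Idx 0 with an AccessSize
// of 16, in which case this function returns 4.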
380 static unsigned CanMergeParamLoadStoresStartingAt(
381     unsigned Idx, uint32_t AccessSize, const SmallVectorImpl<EVT> &ValueVTs,
382     const SmallVectorImpl<uint64_t> &Offsets, Align ParamAlignment) {
383 
384   // Can't vectorize if param alignment is not sufficient.
385   if (ParamAlignment < AccessSize)
386     return 1;
387   // Can't vectorize if offset is not aligned.
388   if (Offsets[Idx] & (AccessSize - 1))
389     return 1;
390 
391   EVT EltVT = ValueVTs[Idx];
392   unsigned EltSize = EltVT.getStoreSize();
393 
394   // Element is too large to vectorize.
395   if (EltSize >= AccessSize)
396     return 1;
397 
398   unsigned NumElts = AccessSize / EltSize;
399   // Can't vectorize if AccessSize is not a multiple of EltSize.
400   if (AccessSize != EltSize * NumElts)
401     return 1;
402 
403   // We don't have enough elements to vectorize.
404   if (Idx + NumElts > ValueVTs.size())
405     return 1;
406 
407   // PTX ISA can only deal with 2- and 4-element vector ops.
408   if (NumElts != 4 && NumElts != 2)
409     return 1;
410 
411   for (unsigned j = Idx + 1; j < Idx + NumElts; ++j) {
412     // Types do not match.
413     if (ValueVTs[j] != EltVT)
414       return 1;
415 
416     // Elements are not contiguous.
417     if (Offsets[j] - Offsets[j - 1] != EltSize)
418       return 1;
419   }
420   // OK. We can vectorize ValueVTs[Idx..Idx+NumElts).
421   return NumElts;
422 }
423 
424 // Flags for tracking per-element vectorization state of loads/stores
425 // of a flattened function parameter or return value.
426 enum ParamVectorizationFlags {
427   PVF_INNER = 0x0, // Middle elements of a vector.
428   PVF_FIRST = 0x1, // First element of the vector.
429   PVF_LAST = 0x2,  // Last element of the vector.
430   // Scalar is effectively a 1-element vector.
431   PVF_SCALAR = PVF_FIRST | PVF_LAST
432 };
433 
434 // Computes whether and how we can vectorize the loads/stores of a
435 // flattened function parameter or return value.
436 //
437 // The flattened parameter is represented as the list of ValueVTs and
438 // Offsets, and is aligned to ParamAlignment bytes. We return a vector
439 // of the same size as ValueVTs indicating how each piece should be
440 // loaded/stored (i.e. as a scalar, or as part of a vector
441 // load/store).
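// For example, four contiguous, 16-byte-aligned f32 pieces are marked
// {PVF_FIRST, PVF_INNER, PVF_INNER, PVF_LAST}, i.e. a single v4 access.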
442 static SmallVector<ParamVectorizationFlags, 16>
443 VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs,
444                      const SmallVectorImpl<uint64_t> &Offsets,
445                      Align ParamAlignment, bool IsVAArg = false) {
446   // Set vector size to match ValueVTs and mark all elements as
447   // scalars by default.
448   SmallVector<ParamVectorizationFlags, 16> VectorInfo;
449   VectorInfo.assign(ValueVTs.size(), PVF_SCALAR);
450 
451   if (IsVAArg)
452     return VectorInfo;
453 
454   // Check what we can vectorize using 128/64/32-bit accesses.
455   for (int I = 0, E = ValueVTs.size(); I != E; ++I) {
456     // Skip elements we've already processed.
457     assert(VectorInfo[I] == PVF_SCALAR && "Unexpected vector info state.");
458     for (unsigned AccessSize : {16, 8, 4, 2}) {
459       unsigned NumElts = CanMergeParamLoadStoresStartingAt(
460           I, AccessSize, ValueVTs, Offsets, ParamAlignment);
461       // Mark vectorized elements.
462       switch (NumElts) {
463       default:
464         llvm_unreachable("Unexpected return value");
465       case 1:
466         // Can't vectorize using this size, try next smaller size.
467         continue;
468       case 2:
469         assert(I + 1 < E && "Not enough elements.");
470         VectorInfo[I] = PVF_FIRST;
471         VectorInfo[I + 1] = PVF_LAST;
472         I += 1;
473         break;
474       case 4:
475         assert(I + 3 < E && "Not enough elements.");
476         VectorInfo[I] = PVF_FIRST;
477         VectorInfo[I + 1] = PVF_INNER;
478         VectorInfo[I + 2] = PVF_INNER;
479         VectorInfo[I + 3] = PVF_LAST;
480         I += 3;
481         break;
482       }
483       // Break out of the inner loop because we've already succeeded
484       // using the largest possible AccessSize.
485       break;
486     }
487   }
488   return VectorInfo;
489 }
490 
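// Returns Value unchanged if it already has type VT; otherwise wraps it in a
// BITCAST node to VT.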
491 static SDValue MaybeBitcast(SelectionDAG &DAG, SDLoc DL, EVT VT,
492                             SDValue Value) {
493   if (Value->getValueType(0) == VT)
494     return Value;
495   return DAG.getNode(ISD::BITCAST, DL, VT, Value);
496 }
497 
498 // NVPTXTargetLowering Constructor.
499 NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
500                                          const NVPTXSubtarget &STI)
501     : TargetLowering(TM), nvTM(&TM), STI(STI) {
502   // Always lower memset, memcpy, and memmove intrinsics to load/store
503   // instructions, rather than generating calls to memset, memcpy, or
504   // memmove.
505   MaxStoresPerMemset = MaxStoresPerMemsetOptSize = (unsigned)0xFFFFFFFF;
506   MaxStoresPerMemcpy = MaxStoresPerMemcpyOptSize = (unsigned) 0xFFFFFFFF;
507   MaxStoresPerMemmove = MaxStoresPerMemmoveOptSize = (unsigned) 0xFFFFFFFF;
508 
509   setBooleanContents(ZeroOrNegativeOneBooleanContent);
510   setBooleanVectorContents(ZeroOrNegativeOneBooleanContent);
511 
512   // Jump is Expensive. Don't create extra control flow for 'and', 'or'
513   // condition branches.
514   setJumpIsExpensive(true);
515 
516   // Wide divides are _very_ slow. Try to reduce the width of the divide if
517   // possible.
518   addBypassSlowDiv(64, 32);
519 
520   // By default, use the Source scheduling
521   if (sched4reg)
522     setSchedulingPreference(Sched::RegPressure);
523   else
524     setSchedulingPreference(Sched::Source);
525 
526   auto setFP16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
527                                     LegalizeAction NoF16Action) {
528     bool IsOpSupported = STI.allowFP16Math();
529     switch (Op) {
530     // Several FP16 instructions are available on sm_80 only.
531     case ISD::FMINNUM:
532     case ISD::FMAXNUM:
533     case ISD::FMAXNUM_IEEE:
534     case ISD::FMINNUM_IEEE:
535     case ISD::FMAXIMUM:
536     case ISD::FMINIMUM:
537       IsOpSupported &= STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70;
538       break;
539     case ISD::FEXP2:
540       IsOpSupported &= STI.getSmVersion() >= 75 && STI.getPTXVersion() >= 70;
541       break;
542     }
543     setOperationAction(Op, VT, IsOpSupported ? Action : NoF16Action);
544   };
545 
546   auto setBF16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
547                                     LegalizeAction NoBF16Action) {
548     bool IsOpSupported = STI.hasNativeBF16Support(Op);
549     setOperationAction(
550         Op, VT, IsOpSupported ? Action : NoBF16Action);
551   };
552 
553   auto setI16x2OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
554                                      LegalizeAction NoI16x2Action) {
555     bool IsOpSupported = false;
556     // These instructions are available on sm_90 only.
557     switch (Op) {
558     case ISD::ADD:
559     case ISD::SMAX:
560     case ISD::SMIN:
561     case ISD::UMIN:
562     case ISD::UMAX:
563       IsOpSupported = STI.getSmVersion() >= 90 && STI.getPTXVersion() >= 80;
564       break;
565     }
566     setOperationAction(Op, VT, IsOpSupported ? Action : NoI16x2Action);
567   };
568 
569   addRegisterClass(MVT::i1, &NVPTX::Int1RegsRegClass);
570   addRegisterClass(MVT::i16, &NVPTX::Int16RegsRegClass);
571   addRegisterClass(MVT::v2i16, &NVPTX::Int32RegsRegClass);
572   addRegisterClass(MVT::v4i8, &NVPTX::Int32RegsRegClass);
573   addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass);
574   addRegisterClass(MVT::i64, &NVPTX::Int64RegsRegClass);
575   addRegisterClass(MVT::f32, &NVPTX::Float32RegsRegClass);
576   addRegisterClass(MVT::f64, &NVPTX::Float64RegsRegClass);
577   addRegisterClass(MVT::f16, &NVPTX::Int16RegsRegClass);
578   addRegisterClass(MVT::v2f16, &NVPTX::Int32RegsRegClass);
579   addRegisterClass(MVT::bf16, &NVPTX::Int16RegsRegClass);
580   addRegisterClass(MVT::v2bf16, &NVPTX::Int32RegsRegClass);
581 
582   // Conversion to/from FP16/FP16x2 is always legal.
583   setOperationAction(ISD::BUILD_VECTOR, MVT::v2f16, Custom);
584   setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v2f16, Custom);
585   setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v2f16, Expand);
586   setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v2f16, Expand);
587 
588   setOperationAction(ISD::READCYCLECOUNTER, MVT::i64, Legal);
589   if (STI.getSmVersion() >= 30 && STI.getPTXVersion() > 31)
590     setOperationAction(ISD::READSTEADYCOUNTER, MVT::i64, Legal);
591 
592   setFP16OperationAction(ISD::SETCC, MVT::f16, Legal, Promote);
593   setFP16OperationAction(ISD::SETCC, MVT::v2f16, Legal, Expand);
594 
595   // Conversion to/from BF16/BF16x2 is always legal.
596   setOperationAction(ISD::BUILD_VECTOR, MVT::v2bf16, Custom);
597   setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v2bf16, Custom);
598   setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v2bf16, Expand);
599   setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v2bf16, Expand);
600 
601   setBF16OperationAction(ISD::SETCC, MVT::v2bf16, Legal, Expand);
602   setBF16OperationAction(ISD::SETCC, MVT::bf16, Legal, Promote);
603   if (getOperationAction(ISD::SETCC, MVT::bf16) == Promote)
604     AddPromotedToType(ISD::SETCC, MVT::bf16, MVT::f32);
605 
606   // Conversion to/from i16/i16x2 is always legal.
607   setOperationAction(ISD::BUILD_VECTOR, MVT::v2i16, Custom);
608   setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v2i16, Custom);
609   setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v2i16, Expand);
610   setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v2i16, Expand);
611 
612   setOperationAction(ISD::BUILD_VECTOR, MVT::v4i8, Custom);
613   setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v4i8, Custom);
614   setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v4i8, Custom);
615   setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v4i8, Custom);
616 
617   // Custom conversions to/from v2i8.
618   setOperationAction(ISD::BITCAST, MVT::v2i8, Custom);
619 
620   // Only logical ops can be done on v4i8 directly; others must be done
621   // elementwise.
622   setOperationAction(
623       {ISD::ABS,         ISD::ADD,        ISD::ADDC,        ISD::ADDE,
624        ISD::BITREVERSE,  ISD::CTLZ,       ISD::CTPOP,       ISD::CTTZ,
625        ISD::FP_TO_SINT,  ISD::FP_TO_UINT, ISD::FSHL,        ISD::FSHR,
626        ISD::MUL,         ISD::MULHS,      ISD::MULHU,       ISD::PARITY,
627        ISD::ROTL,        ISD::ROTR,       ISD::SADDO,       ISD::SADDO_CARRY,
628        ISD::SADDSAT,     ISD::SDIV,       ISD::SDIVREM,     ISD::SELECT_CC,
629        ISD::SETCC,       ISD::SHL,        ISD::SINT_TO_FP,  ISD::SMAX,
630        ISD::SMIN,        ISD::SMULO,      ISD::SMUL_LOHI,   ISD::SRA,
631        ISD::SREM,        ISD::SRL,        ISD::SSHLSAT,     ISD::SSUBO,
632        ISD::SSUBO_CARRY, ISD::SSUBSAT,    ISD::SUB,         ISD::SUBC,
633        ISD::SUBE,        ISD::UADDO,      ISD::UADDO_CARRY, ISD::UADDSAT,
634        ISD::UDIV,        ISD::UDIVREM,    ISD::UINT_TO_FP,  ISD::UMAX,
635        ISD::UMIN,        ISD::UMULO,      ISD::UMUL_LOHI,   ISD::UREM,
636        ISD::USHLSAT,     ISD::USUBO,      ISD::USUBO_CARRY, ISD::VSELECT,
637        ISD::USUBSAT},
638       MVT::v4i8, Expand);
639 
640   // Operations not directly supported by NVPTX.
641   for (MVT VT : {MVT::bf16, MVT::f16, MVT::v2bf16, MVT::v2f16, MVT::f32,
642                  MVT::f64, MVT::i1, MVT::i8, MVT::i16, MVT::v2i16, MVT::v4i8,
643                  MVT::i32, MVT::i64}) {
644     setOperationAction(ISD::SELECT_CC, VT, Expand);
645     setOperationAction(ISD::BR_CC, VT, Expand);
646   }
647 
648   // Some SIGN_EXTEND_INREG can be done using the cvt instruction.
649   // For others we will expand to a SHL/SRA pair.
650   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i64, Legal);
651   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i32, Legal);
652   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i16, Legal);
653   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i8 , Legal);
654   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i1, Expand);
655   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::v2i16, Expand);
656 
657   setOperationAction(ISD::SHL_PARTS, MVT::i32  , Custom);
658   setOperationAction(ISD::SRA_PARTS, MVT::i32  , Custom);
659   setOperationAction(ISD::SRL_PARTS, MVT::i32  , Custom);
660   setOperationAction(ISD::SHL_PARTS, MVT::i64  , Custom);
661   setOperationAction(ISD::SRA_PARTS, MVT::i64  , Custom);
662   setOperationAction(ISD::SRL_PARTS, MVT::i64  , Custom);
663 
664   setOperationAction(ISD::BITREVERSE, MVT::i32, Legal);
665   setOperationAction(ISD::BITREVERSE, MVT::i64, Legal);
666 
667   setOperationAction({ISD::ROTL, ISD::ROTR},
668                      {MVT::i8, MVT::i16, MVT::v2i16, MVT::i32, MVT::i64},
669                      Expand);
670 
671   if (STI.hasHWROT32())
672     setOperationAction({ISD::FSHL, ISD::FSHR}, MVT::i32, Legal);
673 
674   setOperationAction(ISD::BSWAP, MVT::i16, Expand);
675 
676   setOperationAction(ISD::BR_JT, MVT::Other, Custom);
677   setOperationAction(ISD::BRIND, MVT::Other, Expand);
678 
679   setOperationAction(ISD::GlobalAddress, MVT::i32, Custom);
680   setOperationAction(ISD::GlobalAddress, MVT::i64, Custom);
681 
682   // We want to legalize constant-related memmove and memcpy
683   // intrinsics.
684   setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::Other, Custom);
685 
686   // Turn FP extload into load/fpextend
687   setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::f16, Expand);
688   setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f16, Expand);
689   setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::bf16, Expand);
690   setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::bf16, Expand);
691   setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f32, Expand);
692   setLoadExtAction(ISD::EXTLOAD, MVT::v2f32, MVT::v2f16, Expand);
693   setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f16, Expand);
694   setLoadExtAction(ISD::EXTLOAD, MVT::v2f32, MVT::v2bf16, Expand);
695   setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2bf16, Expand);
696   setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f32, Expand);
697   setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4f16, Expand);
698   setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f16, Expand);
699   setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4bf16, Expand);
700   setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4bf16, Expand);
701   setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f32, Expand);
702   setLoadExtAction(ISD::EXTLOAD, MVT::v8f32, MVT::v8f16, Expand);
703   setLoadExtAction(ISD::EXTLOAD, MVT::v8f64, MVT::v8f16, Expand);
704   setLoadExtAction(ISD::EXTLOAD, MVT::v8f32, MVT::v8bf16, Expand);
705   setLoadExtAction(ISD::EXTLOAD, MVT::v8f64, MVT::v8bf16, Expand);
706   // Turn FP truncstore into trunc + store.
707   // FIXME: vector types should also be expanded
708   setTruncStoreAction(MVT::f32, MVT::f16, Expand);
709   setTruncStoreAction(MVT::f64, MVT::f16, Expand);
710   setTruncStoreAction(MVT::f32, MVT::bf16, Expand);
711   setTruncStoreAction(MVT::f64, MVT::bf16, Expand);
712   setTruncStoreAction(MVT::f64, MVT::f32, Expand);
713 
714   // PTX does not support load/store of predicate registers.
715   setOperationAction(ISD::LOAD, MVT::i1, Custom);
716   setOperationAction(ISD::STORE, MVT::i1, Custom);
717 
718   for (MVT VT : MVT::integer_valuetypes()) {
719     setLoadExtAction(ISD::SEXTLOAD, VT, MVT::i1, Promote);
720     setLoadExtAction(ISD::ZEXTLOAD, VT, MVT::i1, Promote);
721     setLoadExtAction(ISD::EXTLOAD, VT, MVT::i1, Promote);
722     setTruncStoreAction(VT, MVT::i1, Expand);
723   }
724 
725   setCondCodeAction({ISD::SETNE, ISD::SETEQ, ISD::SETUGE, ISD::SETULE,
726                      ISD::SETUGT, ISD::SETULT, ISD::SETGT, ISD::SETLT,
727                      ISD::SETGE, ISD::SETLE},
728                     MVT::i1, Expand);
729 
730   // Expand extloads of integer vectors.
731   setLoadExtAction({ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, MVT::v2i16,
732                    MVT::v2i8, Expand);
733   setTruncStoreAction(MVT::v2i16, MVT::v2i8, Expand);
734 
735   // This is legal in NVPTX
736   setOperationAction(ISD::ConstantFP, MVT::f64, Legal);
737   setOperationAction(ISD::ConstantFP, MVT::f32, Legal);
738   setOperationAction(ISD::ConstantFP, MVT::f16, Legal);
739   setOperationAction(ISD::ConstantFP, MVT::bf16, Legal);
740 
741   setOperationAction(ISD::DYNAMIC_STACKALLOC, {MVT::i32, MVT::i64}, Custom);
742   setOperationAction({ISD::STACKRESTORE, ISD::STACKSAVE}, MVT::Other, Custom);
743 
744   // TRAP can be lowered to PTX trap
745   setOperationAction(ISD::TRAP, MVT::Other, Legal);
746   // DEBUGTRAP can be lowered to PTX brkpt
747   setOperationAction(ISD::DEBUGTRAP, MVT::Other, Legal);
748 
749   // Register custom handling for vector loads/stores
750   for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
751     if (IsPTXVectorType(VT)) {
752       setOperationAction(ISD::LOAD, VT, Custom);
753       setOperationAction(ISD::STORE, VT, Custom);
754       setOperationAction(ISD::INTRINSIC_W_CHAIN, VT, Custom);
755     }
756   }
757 
758   // Support varargs.
759   setOperationAction(ISD::VASTART, MVT::Other, Custom);
760   setOperationAction(ISD::VAARG, MVT::Other, Custom);
761   setOperationAction(ISD::VACOPY, MVT::Other, Expand);
762   setOperationAction(ISD::VAEND, MVT::Other, Expand);
763 
764   // Custom handling for i8 intrinsics
765   setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::i8, Custom);
766 
767   for (const auto& Ty : {MVT::i16, MVT::i32, MVT::i64}) {
768     setOperationAction(ISD::ABS,  Ty, Legal);
769     setOperationAction(ISD::SMIN, Ty, Legal);
770     setOperationAction(ISD::SMAX, Ty, Legal);
771     setOperationAction(ISD::UMIN, Ty, Legal);
772     setOperationAction(ISD::UMAX, Ty, Legal);
773 
774     setOperationAction(ISD::CTPOP, Ty, Legal);
775     setOperationAction(ISD::CTLZ, Ty, Legal);
776   }
777 
778   setI16x2OperationAction(ISD::ABS, MVT::v2i16, Legal, Custom);
779   setI16x2OperationAction(ISD::SMIN, MVT::v2i16, Legal, Custom);
780   setI16x2OperationAction(ISD::SMAX, MVT::v2i16, Legal, Custom);
781   setI16x2OperationAction(ISD::UMIN, MVT::v2i16, Legal, Custom);
782   setI16x2OperationAction(ISD::UMAX, MVT::v2i16, Legal, Custom);
783   setI16x2OperationAction(ISD::CTPOP, MVT::v2i16, Legal, Expand);
784   setI16x2OperationAction(ISD::CTLZ, MVT::v2i16, Legal, Expand);
785 
786   setI16x2OperationAction(ISD::ADD, MVT::v2i16, Legal, Custom);
787   setI16x2OperationAction(ISD::SUB, MVT::v2i16, Legal, Custom);
788   setI16x2OperationAction(ISD::MUL, MVT::v2i16, Legal, Custom);
789   setI16x2OperationAction(ISD::SHL, MVT::v2i16, Legal, Custom);
790   setI16x2OperationAction(ISD::SREM, MVT::v2i16, Legal, Custom);
791   setI16x2OperationAction(ISD::UREM, MVT::v2i16, Legal, Custom);
792 
793   // Other arithmetic and logic ops are unsupported.
794   setOperationAction({ISD::SDIV, ISD::UDIV, ISD::SRA, ISD::SRL, ISD::MULHS,
795                       ISD::MULHU, ISD::FP_TO_SINT, ISD::FP_TO_UINT,
796                       ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::SETCC},
797                      MVT::v2i16, Expand);
798 
799   setOperationAction(ISD::ADDC, MVT::i32, Legal);
800   setOperationAction(ISD::ADDE, MVT::i32, Legal);
801   setOperationAction(ISD::SUBC, MVT::i32, Legal);
802   setOperationAction(ISD::SUBE, MVT::i32, Legal);
803   if (STI.getPTXVersion() >= 43) {
804     setOperationAction(ISD::ADDC, MVT::i64, Legal);
805     setOperationAction(ISD::ADDE, MVT::i64, Legal);
806     setOperationAction(ISD::SUBC, MVT::i64, Legal);
807     setOperationAction(ISD::SUBE, MVT::i64, Legal);
808   }
809 
810   setOperationAction(ISD::CTTZ, MVT::i16, Expand);
811   setOperationAction(ISD::CTTZ, MVT::v2i16, Expand);
812   setOperationAction(ISD::CTTZ, MVT::i32, Expand);
813   setOperationAction(ISD::CTTZ, MVT::i64, Expand);
814 
815   // PTX does not directly support SELP of i1, so promote to i32 first
816   setOperationAction(ISD::SELECT, MVT::i1, Custom);
817 
818   // PTX cannot multiply two i64s in a single instruction.
819   setOperationAction(ISD::SMUL_LOHI, MVT::i64, Expand);
820   setOperationAction(ISD::UMUL_LOHI, MVT::i64, Expand);
821 
822   // We have some custom DAG combine patterns for these nodes
823   setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
824                        ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
825                        ISD::BUILD_VECTOR});
826 
827   // setcc for f16x2 and bf16x2 needs special handling to prevent
828   // legalizer's attempt to scalarize it due to v2i1 not being legal.
829   if (STI.allowFP16Math() || STI.hasBF16Math())
830     setTargetDAGCombine(ISD::SETCC);
831 
832   // Promote fp16 arithmetic if fp16 hardware isn't available or the
833   // user passed --nvptx-no-fp16-math. The flag is useful because,
834   // although sm_53+ GPUs have some sort of FP16 support in
835   // hardware, only sm_53 and sm_60 have a full implementation. Others
836   // have only a token amount of hardware and are likely to run faster
837   // by using fp32 units instead.
838   for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB, ISD::FMA}) {
839     setFP16OperationAction(Op, MVT::f16, Legal, Promote);
840     setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
841     setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
842     // bf16 must be promoted to f32.
843     setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
844     if (getOperationAction(Op, MVT::bf16) == Promote)
845       AddPromotedToType(Op, MVT::bf16, MVT::f32);
846   }
847 
848   // On SM80, we select add/mul/sub as fma to avoid promotion to float
849   for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB}) {
850     for (const auto &VT : {MVT::bf16, MVT::v2bf16}) {
851       if (!STI.hasNativeBF16Support(Op) && STI.hasNativeBF16Support(ISD::FMA)) {
852         setOperationAction(Op, VT, Custom);
853       }
854     }
855   }
856 
857   // f16/f16x2 negation was introduced in PTX 6.0 for sm_53+.
858   const bool IsFP16FP16x2NegAvailable = STI.getSmVersion() >= 53 &&
859                                         STI.getPTXVersion() >= 60 &&
860                                         STI.allowFP16Math();
861   for (const auto &VT : {MVT::f16, MVT::v2f16})
862     setOperationAction(ISD::FNEG, VT,
863                        IsFP16FP16x2NegAvailable ? Legal : Expand);
864 
865   setBF16OperationAction(ISD::FNEG, MVT::bf16, Legal, Expand);
866   setBF16OperationAction(ISD::FNEG, MVT::v2bf16, Legal, Expand);
867   // (would be) Library functions.
868 
869   // These map to conversion instructions for scalar FP types.
870   for (const auto &Op : {ISD::FCEIL, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT,
871                          ISD::FROUNDEVEN, ISD::FTRUNC}) {
872     setOperationAction(Op, MVT::f16, Legal);
873     setOperationAction(Op, MVT::f32, Legal);
874     setOperationAction(Op, MVT::f64, Legal);
875     setOperationAction(Op, MVT::v2f16, Expand);
876     setOperationAction(Op, MVT::v2bf16, Expand);
877     setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
878     if (getOperationAction(Op, MVT::bf16) == Promote)
879       AddPromotedToType(Op, MVT::bf16, MVT::f32);
880   }
881 
882   if (STI.getSmVersion() < 80 || STI.getPTXVersion() < 71) {
883     setOperationAction(ISD::BF16_TO_FP, MVT::f32, Expand);
884   }
885   if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
886     for (MVT VT : {MVT::bf16, MVT::f32, MVT::f64}) {
887       setOperationAction(ISD::FP_EXTEND, VT, Custom);
888       setOperationAction(ISD::FP_ROUND, VT, Custom);
889     }
890   }
891 
892   // sm_80 only has conversions between f32 and bf16. Custom lower all other
893   // bf16 conversions.
894   if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
895     for (MVT VT : {MVT::i1, MVT::i16, MVT::i32, MVT::i64}) {
896       setOperationAction(
897           {ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::FP_TO_SINT, ISD::FP_TO_UINT},
898           VT, Custom);
899     }
900     setOperationAction(
901         {ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::FP_TO_SINT, ISD::FP_TO_UINT},
902         MVT::bf16, Custom);
903   }
904 
905   setOperationAction(ISD::FROUND, MVT::f16, Promote);
906   setOperationAction(ISD::FROUND, MVT::v2f16, Expand);
907   setOperationAction(ISD::FROUND, MVT::v2bf16, Expand);
908   setOperationAction(ISD::FROUND, MVT::f32, Custom);
909   setOperationAction(ISD::FROUND, MVT::f64, Custom);
910   setOperationAction(ISD::FROUND, MVT::bf16, Promote);
911   AddPromotedToType(ISD::FROUND, MVT::bf16, MVT::f32);
912 
913   // 'Expand' implements FCOPYSIGN without calling an external library.
914   setOperationAction(ISD::FCOPYSIGN, MVT::f16, Expand);
915   setOperationAction(ISD::FCOPYSIGN, MVT::v2f16, Expand);
916   setOperationAction(ISD::FCOPYSIGN, MVT::bf16, Expand);
917   setOperationAction(ISD::FCOPYSIGN, MVT::v2bf16, Expand);
918   setOperationAction(ISD::FCOPYSIGN, MVT::f32, Custom);
919   setOperationAction(ISD::FCOPYSIGN, MVT::f64, Custom);
920 
921   // These map to corresponding instructions for f32/f64. f16 must be
922   // promoted to f32. v2f16 is expanded to f16, which is then promoted
923   // to f32.
924   for (const auto &Op :
925        {ISD::FDIV, ISD::FREM, ISD::FSQRT, ISD::FSIN, ISD::FCOS}) {
926     setOperationAction(Op, MVT::f16, Promote);
927     setOperationAction(Op, MVT::f32, Legal);
928     setOperationAction(Op, MVT::f64, Legal);
929     setOperationAction(Op, MVT::v2f16, Expand);
930     setOperationAction(Op, MVT::v2bf16, Expand);
931     setOperationAction(Op, MVT::bf16, Promote);
932     AddPromotedToType(Op, MVT::bf16, MVT::f32);
933   }
934 
935   setOperationAction(ISD::FABS, {MVT::f32, MVT::f64}, Legal);
936   if (STI.getPTXVersion() >= 65) {
937     setFP16OperationAction(ISD::FABS, MVT::f16, Legal, Promote);
938     setFP16OperationAction(ISD::FABS, MVT::v2f16, Legal, Expand);
939   } else {
940     setOperationAction(ISD::FABS, MVT::f16, Promote);
941     setOperationAction(ISD::FABS, MVT::v2f16, Expand);
942   }
943   setBF16OperationAction(ISD::FABS, MVT::v2bf16, Legal, Expand);
944   setBF16OperationAction(ISD::FABS, MVT::bf16, Legal, Promote);
945   if (getOperationAction(ISD::FABS, MVT::bf16) == Promote)
946     AddPromotedToType(ISD::FABS, MVT::bf16, MVT::f32);
947 
948   for (const auto &Op : {ISD::FMINNUM, ISD::FMAXNUM}) {
949     setOperationAction(Op, MVT::f32, Legal);
950     setOperationAction(Op, MVT::f64, Legal);
951     setFP16OperationAction(Op, MVT::f16, Legal, Promote);
952     setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
953     setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
954     setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
955     if (getOperationAction(Op, MVT::bf16) == Promote)
956       AddPromotedToType(Op, MVT::bf16, MVT::f32);
957   }
958   bool SupportsF32MinMaxNaN =
959       STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70;
960   for (const auto &Op : {ISD::FMINIMUM, ISD::FMAXIMUM}) {
961     setOperationAction(Op, MVT::f32, SupportsF32MinMaxNaN ? Legal : Expand);
962     setFP16OperationAction(Op, MVT::f16, Legal, Expand);
963     setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
964     setBF16OperationAction(Op, MVT::bf16, Legal, Expand);
965     setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
966   }
967 
968   // Custom lowering for inline asm with 128-bit operands
969   setOperationAction(ISD::CopyToReg, MVT::i128, Custom);
970   setOperationAction(ISD::CopyFromReg, MVT::i128, Custom);
971 
972   // FEXP2 support:
973   // - f32
974   // - f16/f16x2 (sm_70+, PTX 7.0+)
975   // - bf16/bf16x2 (sm_90+, PTX 7.8+)
976   // When f16/bf16 types aren't supported, they are promoted/expanded to f32.
977   setOperationAction(ISD::FEXP2, MVT::f32, Legal);
978   setFP16OperationAction(ISD::FEXP2, MVT::f16, Legal, Promote);
979   setFP16OperationAction(ISD::FEXP2, MVT::v2f16, Legal, Expand);
980   setBF16OperationAction(ISD::FEXP2, MVT::bf16, Legal, Promote);
981   setBF16OperationAction(ISD::FEXP2, MVT::v2bf16, Legal, Expand);
982 
983   // FLOG2 supports f32 only
984   // f16/bf16 types aren't supported, but they are promoted/expanded to f32.
985   if (UseApproxLog2F32) {
986     setOperationAction(ISD::FLOG2, MVT::f32, Legal);
987     setOperationPromotedToType(ISD::FLOG2, MVT::f16, MVT::f32);
988     setOperationPromotedToType(ISD::FLOG2, MVT::bf16, MVT::f32);
989     setOperationAction(ISD::FLOG2, {MVT::v2f16, MVT::v2bf16}, Expand);
990   }
991 
992   // No FPOW or FREM in PTX.
993 
994   // Now deduce the information based on the above-mentioned
995   // actions.
996   computeRegisterProperties(STI.getRegisterInfo());
997 
998   setMinCmpXchgSizeInBits(STI.getMinCmpXchgSizeInBits());
999   setMaxAtomicSizeInBitsSupported(64);
1000   setMaxDivRemBitWidthSupported(64);
1001 }
1002 
1003 const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
1004 
1005 #define MAKE_CASE(V)                                                           \
1006   case V:                                                                      \
1007     return #V;
1008 
1009   switch ((NVPTXISD::NodeType)Opcode) {
1010   case NVPTXISD::FIRST_NUMBER:
1011     break;
1012 
1013     MAKE_CASE(NVPTXISD::CALL)
1014     MAKE_CASE(NVPTXISD::RET_GLUE)
1015     MAKE_CASE(NVPTXISD::LOAD_PARAM)
1016     MAKE_CASE(NVPTXISD::Wrapper)
1017     MAKE_CASE(NVPTXISD::DeclareParam)
1018     MAKE_CASE(NVPTXISD::DeclareScalarParam)
1019     MAKE_CASE(NVPTXISD::DeclareRet)
1020     MAKE_CASE(NVPTXISD::DeclareScalarRet)
1021     MAKE_CASE(NVPTXISD::DeclareRetParam)
1022     MAKE_CASE(NVPTXISD::PrintCall)
1023     MAKE_CASE(NVPTXISD::PrintConvergentCall)
1024     MAKE_CASE(NVPTXISD::PrintCallUni)
1025     MAKE_CASE(NVPTXISD::PrintConvergentCallUni)
1026     MAKE_CASE(NVPTXISD::LoadParam)
1027     MAKE_CASE(NVPTXISD::LoadParamV2)
1028     MAKE_CASE(NVPTXISD::LoadParamV4)
1029     MAKE_CASE(NVPTXISD::StoreParam)
1030     MAKE_CASE(NVPTXISD::StoreParamV2)
1031     MAKE_CASE(NVPTXISD::StoreParamV4)
1032     MAKE_CASE(NVPTXISD::StoreParamS32)
1033     MAKE_CASE(NVPTXISD::StoreParamU32)
1034     MAKE_CASE(NVPTXISD::CallArgBegin)
1035     MAKE_CASE(NVPTXISD::CallArg)
1036     MAKE_CASE(NVPTXISD::LastCallArg)
1037     MAKE_CASE(NVPTXISD::CallArgEnd)
1038     MAKE_CASE(NVPTXISD::CallVoid)
1039     MAKE_CASE(NVPTXISD::CallVal)
1040     MAKE_CASE(NVPTXISD::CallSymbol)
1041     MAKE_CASE(NVPTXISD::Prototype)
1042     MAKE_CASE(NVPTXISD::MoveParam)
1043     MAKE_CASE(NVPTXISD::StoreRetval)
1044     MAKE_CASE(NVPTXISD::StoreRetvalV2)
1045     MAKE_CASE(NVPTXISD::StoreRetvalV4)
1046     MAKE_CASE(NVPTXISD::PseudoUseParam)
1047     MAKE_CASE(NVPTXISD::RETURN)
1048     MAKE_CASE(NVPTXISD::CallSeqBegin)
1049     MAKE_CASE(NVPTXISD::CallSeqEnd)
1050     MAKE_CASE(NVPTXISD::CallPrototype)
1051     MAKE_CASE(NVPTXISD::ProxyReg)
1052     MAKE_CASE(NVPTXISD::LoadV2)
1053     MAKE_CASE(NVPTXISD::LoadV4)
1054     MAKE_CASE(NVPTXISD::LDUV2)
1055     MAKE_CASE(NVPTXISD::LDUV4)
1056     MAKE_CASE(NVPTXISD::StoreV2)
1057     MAKE_CASE(NVPTXISD::StoreV4)
1058     MAKE_CASE(NVPTXISD::FSHL_CLAMP)
1059     MAKE_CASE(NVPTXISD::FSHR_CLAMP)
1060     MAKE_CASE(NVPTXISD::BFE)
1061     MAKE_CASE(NVPTXISD::BFI)
1062     MAKE_CASE(NVPTXISD::PRMT)
1063     MAKE_CASE(NVPTXISD::FCOPYSIGN)
1064     MAKE_CASE(NVPTXISD::DYNAMIC_STACKALLOC)
1065     MAKE_CASE(NVPTXISD::STACKRESTORE)
1066     MAKE_CASE(NVPTXISD::STACKSAVE)
1067     MAKE_CASE(NVPTXISD::SETP_F16X2)
1068     MAKE_CASE(NVPTXISD::SETP_BF16X2)
1069     MAKE_CASE(NVPTXISD::Dummy)
1070     MAKE_CASE(NVPTXISD::MUL_WIDE_SIGNED)
1071     MAKE_CASE(NVPTXISD::MUL_WIDE_UNSIGNED)
1072     MAKE_CASE(NVPTXISD::BrxEnd)
1073     MAKE_CASE(NVPTXISD::BrxItem)
1074     MAKE_CASE(NVPTXISD::BrxStart)
1075   }
1076   return nullptr;
1077 
1078 #undef MAKE_CASE
1079 }
1080 
1081 TargetLoweringBase::LegalizeTypeAction
1082 NVPTXTargetLowering::getPreferredVectorAction(MVT VT) const {
1083   if (!VT.isScalableVector() && VT.getVectorNumElements() != 1 &&
1084       VT.getScalarType() == MVT::i1)
1085     return TypeSplitVector;
1086   return TargetLoweringBase::getPreferredVectorAction(VT);
1087 }
1088 
1089 SDValue NVPTXTargetLowering::getSqrtEstimate(SDValue Operand, SelectionDAG &DAG,
1090                                              int Enabled, int &ExtraSteps,
1091                                              bool &UseOneConst,
1092                                              bool Reciprocal) const {
1093   if (!(Enabled == ReciprocalEstimate::Enabled ||
1094         (Enabled == ReciprocalEstimate::Unspecified && !usePrecSqrtF32())))
1095     return SDValue();
1096 
1097   if (ExtraSteps == ReciprocalEstimate::Unspecified)
1098     ExtraSteps = 0;
1099 
1100   SDLoc DL(Operand);
1101   EVT VT = Operand.getValueType();
1102   bool Ftz = useF32FTZ(DAG.getMachineFunction());
1103 
1104   auto MakeIntrinsicCall = [&](Intrinsic::ID IID) {
1105     return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
1106                        DAG.getConstant(IID, DL, MVT::i32), Operand);
1107   };
1108 
1109   // The sqrt and rsqrt refinement processes assume we always start out with an
1110   // approximation of the rsqrt.  Therefore, if we're going to do any refinement
1111   // (i.e. ExtraSteps > 0), we must return an rsqrt.  But if we're *not* doing
1112   // any refinement, we must return a regular sqrt.
1113   if (Reciprocal || ExtraSteps > 0) {
1114     if (VT == MVT::f32)
1115       return MakeIntrinsicCall(Ftz ? Intrinsic::nvvm_rsqrt_approx_ftz_f
1116                                    : Intrinsic::nvvm_rsqrt_approx_f);
1117     else if (VT == MVT::f64)
1118       return MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d);
1119     else
1120       return SDValue();
1121   } else {
1122     if (VT == MVT::f32)
1123       return MakeIntrinsicCall(Ftz ? Intrinsic::nvvm_sqrt_approx_ftz_f
1124                                    : Intrinsic::nvvm_sqrt_approx_f);
1125     else {
1126       // There's no sqrt.approx.f64 instruction, so we emit
1127       // reciprocal(rsqrt(x)).  This is faster than
1128       // select(x == 0, 0, x * rsqrt(x)).  (In fact, it's faster than plain
1129       // x * rsqrt(x).)
1130       return DAG.getNode(
1131           ISD::INTRINSIC_WO_CHAIN, DL, VT,
1132           DAG.getConstant(Intrinsic::nvvm_rcp_approx_ftz_d, DL, MVT::i32),
1133           MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d));
1134     }
1135   }
1136 }
1137 
1138 SDValue
1139 NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
1140   SDLoc dl(Op);
1141   const GlobalAddressSDNode *GAN = cast<GlobalAddressSDNode>(Op);
1142   auto PtrVT = getPointerTy(DAG.getDataLayout(), GAN->getAddressSpace());
1143   Op = DAG.getTargetGlobalAddress(GAN->getGlobal(), dl, PtrVT);
1144   return DAG.getNode(NVPTXISD::Wrapper, dl, PtrVT, Op);
1145 }
1146 
1147 static bool IsTypePassedAsArray(const Type *Ty) {
1148   return Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128) ||
1149          Ty->isHalfTy() || Ty->isBFloatTy();
1150 }
1151 
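// Builds the PTX ".callprototype" string used when the callee's prototype must
// be declared explicitly (e.g. for indirect calls). For example, a call to
// "i32 @foo(float, ptr)" would, assuming 64-bit pointers, produce something
// like:
//   prototype_0 : .callprototype (.param .b32 _) _ (.param .b32 _, .param .b64 _);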
1152 std::string NVPTXTargetLowering::getPrototype(
1153     const DataLayout &DL, Type *retTy, const ArgListTy &Args,
1154     const SmallVectorImpl<ISD::OutputArg> &Outs, MaybeAlign retAlignment,
1155     std::optional<std::pair<unsigned, const APInt &>> VAInfo,
1156     const CallBase &CB, unsigned UniqueCallSite) const {
1157   auto PtrVT = getPointerTy(DL);
1158 
1159   bool isABI = (STI.getSmVersion() >= 20);
1160   assert(isABI && "Non-ABI compilation is not supported");
1161   if (!isABI)
1162     return "";
1163 
1164   std::string Prototype;
1165   raw_string_ostream O(Prototype);
1166   O << "prototype_" << UniqueCallSite << " : .callprototype ";
1167 
1168   if (retTy->getTypeID() == Type::VoidTyID) {
1169     O << "()";
1170   } else {
1171     O << "(";
1172     if ((retTy->isFloatingPointTy() || retTy->isIntegerTy()) &&
1173         !IsTypePassedAsArray(retTy)) {
1174       unsigned size = 0;
1175       if (auto *ITy = dyn_cast<IntegerType>(retTy)) {
1176         size = ITy->getBitWidth();
1177       } else {
1178         assert(retTy->isFloatingPointTy() &&
1179                "Floating point type expected here");
1180         size = retTy->getPrimitiveSizeInBits();
1181       }
1182       // PTX ABI requires all scalar return values to be at least 32
1183       // bits in size.  fp16 normally uses .b16 as its storage type in
1184       // PTX, so its size must be adjusted here, too.
1185       size = promoteScalarArgumentSize(size);
1186 
1187       O << ".param .b" << size << " _";
1188     } else if (isa<PointerType>(retTy)) {
1189       O << ".param .b" << PtrVT.getSizeInBits() << " _";
1190     } else if (IsTypePassedAsArray(retTy)) {
1191       O << ".param .align " << (retAlignment ? retAlignment->value() : 0)
1192         << " .b8 _[" << DL.getTypeAllocSize(retTy) << "]";
1193     } else {
1194       llvm_unreachable("Unknown return type");
1195     }
1196     O << ") ";
1197   }
1198   O << "_ (";
1199 
1200   bool first = true;
1201 
1202   unsigned NumArgs = VAInfo ? VAInfo->first : Args.size();
1203   for (unsigned i = 0, OIdx = 0; i != NumArgs; ++i, ++OIdx) {
1204     Type *Ty = Args[i].Ty;
1205     if (!first) {
1206       O << ", ";
1207     }
1208     first = false;
1209 
1210     if (!Outs[OIdx].Flags.isByVal()) {
1211       if (IsTypePassedAsArray(Ty)) {
1212         Align ParamAlign =
1213             getArgumentAlignment(&CB, Ty, i + AttributeList::FirstArgIndex, DL);
1214         O << ".param .align " << ParamAlign.value() << " .b8 ";
1215         O << "_";
1216         O << "[" << DL.getTypeAllocSize(Ty) << "]";
1217         // update the index for Outs
1218         SmallVector<EVT, 16> vtparts;
1219         ComputeValueVTs(*this, DL, Ty, vtparts);
1220         if (unsigned len = vtparts.size())
1221           OIdx += len - 1;
1222         continue;
1223       }
1224       // i8 types in IR will be i16 types in SDAG
1225       assert((getValueType(DL, Ty) == Outs[OIdx].VT ||
1226               (getValueType(DL, Ty) == MVT::i8 && Outs[OIdx].VT == MVT::i16)) &&
1227              "type mismatch between callee prototype and arguments");
1228       // scalar type
1229       unsigned sz = 0;
1230       if (isa<IntegerType>(Ty)) {
1231         sz = cast<IntegerType>(Ty)->getBitWidth();
1232         sz = promoteScalarArgumentSize(sz);
1233       } else if (isa<PointerType>(Ty)) {
1234         sz = PtrVT.getSizeInBits();
1235       } else {
1236         sz = Ty->getPrimitiveSizeInBits();
1237       }
1238       O << ".param .b" << sz << " ";
1239       O << "_";
1240       continue;
1241     }
1242 
1243     // Indirect calls need strict ABI alignment so we disable optimizations by
1244     // not providing a function to optimize.
1245     Type *ETy = Args[i].IndirectType;
1246     Align InitialAlign = Outs[OIdx].Flags.getNonZeroByValAlign();
1247     Align ParamByValAlign =
1248         getFunctionByValParamAlign(/*F=*/nullptr, ETy, InitialAlign, DL);
1249 
1250     O << ".param .align " << ParamByValAlign.value() << " .b8 ";
1251     O << "_";
1252     O << "[" << Outs[OIdx].Flags.getByValSize() << "]";
1253   }
1254 
1255   if (VAInfo)
1256     O << (first ? "" : ",") << " .param .align " << VAInfo->second
1257       << " .b8 _[]\n";
1258   O << ")";
1259   if (shouldEmitPTXNoReturn(&CB, *nvTM))
1260     O << " .noreturn";
1261   O << ";";
1262 
1263   return Prototype;
1264 }
1265 
1266 Align NVPTXTargetLowering::getFunctionArgumentAlignment(
1267     const Function *F, Type *Ty, unsigned Idx, const DataLayout &DL) const {
1268   return getAlign(*F, Idx).value_or(getFunctionParamOptimizedAlign(F, Ty, DL));
1269 }
1270 
1271 Align NVPTXTargetLowering::getArgumentAlignment(const CallBase *CB, Type *Ty,
1272                                                 unsigned Idx,
1273                                                 const DataLayout &DL) const {
1274   if (!CB) {
1275     // No call site available; fall back to the ABI type alignment.
1276     return DL.getABITypeAlign(Ty);
1277   }
1278 
1279   const Function *DirectCallee = CB->getCalledFunction();
1280 
1281   if (!DirectCallee) {
1282     // We don't have a direct function symbol, but that may be because of
1283     // constant cast instructions in the call.
1284 
1285     // With bitcast'd call targets, the instruction will be the call
1286     if (const auto *CI = dyn_cast<CallInst>(CB)) {
1287       // Check if we have call alignment metadata
1288       if (MaybeAlign StackAlign = getAlign(*CI, Idx))
1289         return StackAlign.value();
1290     }
1291     DirectCallee = getMaybeBitcastedCallee(CB);
1292   }
1293 
1294   // Check for function alignment information if we found that the
1295   // ultimate target is a Function
1296   if (DirectCallee)
1297     return getFunctionArgumentAlignment(DirectCallee, Ty, Idx, DL);
1298 
1299   // Call is indirect, fall back to the ABI type alignment
1300   return DL.getABITypeAlign(Ty);
1301 }
1302 
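// Maps an FP element type to the integer type of the same width so that the
// byte-wise shift/mask logic below can operate on it. Returns true if the type
// was changed (i.e. a BITCAST of the value is required).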
1303 static bool adjustElementType(EVT &ElementType) {
1304   switch (ElementType.getSimpleVT().SimpleTy) {
1305   default:
1306     return false;
1307   case MVT::f16:
1308   case MVT::bf16:
1309     ElementType = MVT::i16;
1310     return true;
1311   case MVT::f32:
1312   case MVT::v2f16:
1313   case MVT::v2bf16:
1314     ElementType = MVT::i32;
1315     return true;
1316   case MVT::f64:
1317     ElementType = MVT::i64;
1318     return true;
1319   }
1320 }
1321 
1322 // Use byte-store when the param address of the argument value is unaligned.
1323 // This may happen when the argument value is a field of a packed structure.
1324 //
1325 // This is called in LowerCall() when passing the param values.
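// For example, an f32 argument piece at Offset 5 is emitted as four
// st.param.b8 stores covering byte offsets 5 through 8.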
1326 static SDValue LowerUnalignedStoreParam(SelectionDAG &DAG, SDValue Chain,
1327                                         uint64_t Offset, EVT ElementType,
1328                                         SDValue StVal, SDValue &InGlue,
1329                                         unsigned ArgID, const SDLoc &dl) {
1330   // Bit logic only works on integer types
1331   if (adjustElementType(ElementType))
1332     StVal = DAG.getNode(ISD::BITCAST, dl, ElementType, StVal);
1333 
1334   // Store each byte
1335   SDVTList StoreVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1336   for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) {
1337     // Shift the byte to the last byte position
1338     SDValue ShiftVal = DAG.getNode(ISD::SRL, dl, ElementType, StVal,
1339                                    DAG.getConstant(i * 8, dl, MVT::i32));
1340     SDValue StoreOperands[] = {Chain, DAG.getConstant(ArgID, dl, MVT::i32),
1341                                DAG.getConstant(Offset + i, dl, MVT::i32),
1342                                ShiftVal, InGlue};
1343     // Trunc store only the last byte by using
1344     //     st.param.b8
1345     // The register type can be larger than b8.
1346     Chain = DAG.getMemIntrinsicNode(
1347         NVPTXISD::StoreParam, dl, StoreVTs, StoreOperands, MVT::i8,
1348         MachinePointerInfo(), Align(1), MachineMemOperand::MOStore);
1349     InGlue = Chain.getValue(1);
1350   }
1351   return Chain;
1352 }
1353 
1354 // Use byte-load when the param address of the returned value is unaligned.
1355 // This may happen when the returned value is a field of a packed structure.
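// The value is reassembled as the OR of (zext(byte_i) << (8 * i)) over all of
// its bytes, each byte coming from a single-byte LoadParam.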
1356 static SDValue
1357 LowerUnalignedLoadRetParam(SelectionDAG &DAG, SDValue &Chain, uint64_t Offset,
1358                            EVT ElementType, SDValue &InGlue,
1359                            SmallVectorImpl<SDValue> &TempProxyRegOps,
1360                            const SDLoc &dl) {
1361   // Bit logic only works on integer types
1362   EVT MergedType = ElementType;
1363   adjustElementType(MergedType);
1364 
1365   // Load each byte and reconstruct the whole value. The initial value is 0.
1366   SDValue RetVal = DAG.getConstant(0, dl, MergedType);
1367   // LoadParamMemI8 loads into i16 register only
1368   SDVTList LoadVTs = DAG.getVTList(MVT::i16, MVT::Other, MVT::Glue);
1369   for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) {
1370     SDValue LoadOperands[] = {Chain, DAG.getConstant(1, dl, MVT::i32),
1371                               DAG.getConstant(Offset + i, dl, MVT::i32),
1372                               InGlue};
1373     // This will be selected to LoadParamMemI8
1374     SDValue LdVal =
1375         DAG.getMemIntrinsicNode(NVPTXISD::LoadParam, dl, LoadVTs, LoadOperands,
1376                                 MVT::i8, MachinePointerInfo(), Align(1));
1377     SDValue TmpLdVal = LdVal.getValue(0);
1378     Chain = LdVal.getValue(1);
1379     InGlue = LdVal.getValue(2);
1380 
1381     TmpLdVal = DAG.getNode(NVPTXISD::ProxyReg, dl,
1382                            TmpLdVal.getSimpleValueType(), TmpLdVal);
1383     TempProxyRegOps.push_back(TmpLdVal);
1384 
1385     SDValue CMask = DAG.getConstant(255, dl, MergedType);
1386     SDValue CShift = DAG.getConstant(i * 8, dl, MVT::i32);
1387     // Need to extend the i16 register to the whole width.
1388     TmpLdVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MergedType, TmpLdVal);
1389     // Mask off the high bits. Leave only the lower 8 bits.
1390     // Do this because we are using loadparam.b8.
1391     TmpLdVal = DAG.getNode(ISD::AND, dl, MergedType, TmpLdVal, CMask);
1392     // Shift and merge
1393     TmpLdVal = DAG.getNode(ISD::SHL, dl, MergedType, TmpLdVal, CShift);
1394     RetVal = DAG.getNode(ISD::OR, dl, MergedType, RetVal, TmpLdVal);
1395   }
1396   if (ElementType != MergedType)
1397     RetVal = DAG.getNode(ISD::BITCAST, dl, ElementType, RetVal);
1398 
1399   return RetVal;
1400 }
1401 
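// A call site whose function type does not match the callee's declared type
// cannot be emitted as a direct call; LowerCall converts it to an indirect
// call through a register instead (see ConvertToIndirectCall below).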
1402 static bool shouldConvertToIndirectCall(const CallBase *CB,
1403                                         const GlobalAddressSDNode *Func) {
1404   if (!Func)
1405     return false;
1406   if (auto *CalleeFunc = dyn_cast<Function>(Func->getGlobal()))
1407     return CB->getFunctionType() != CalleeFunc->getFunctionType();
1408   return false;
1409 }
1410 
1411 SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
1412                                        SmallVectorImpl<SDValue> &InVals) const {
1413 
1414   if (CLI.IsVarArg && (STI.getPTXVersion() < 60 || STI.getSmVersion() < 30))
1415     report_fatal_error(
1416         "Support for variadic functions (unsized array parameter) was "
1417         "introduced in PTX ISA version 6.0 and requires target sm_30.");
1418 
1419   SelectionDAG &DAG = CLI.DAG;
1420   SDLoc dl = CLI.DL;
1421   SmallVectorImpl<ISD::OutputArg> &Outs = CLI.Outs;
1422   SmallVectorImpl<SDValue> &OutVals = CLI.OutVals;
1423   SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins;
1424   SDValue Chain = CLI.Chain;
1425   SDValue Callee = CLI.Callee;
1426   bool &isTailCall = CLI.IsTailCall;
1427   ArgListTy &Args = CLI.getArgs();
1428   Type *RetTy = CLI.RetTy;
1429   const CallBase *CB = CLI.CB;
1430   const DataLayout &DL = DAG.getDataLayout();
1431 
1432   bool isABI = (STI.getSmVersion() >= 20);
1433   assert(isABI && "Non-ABI compilation is not supported");
1434   if (!isABI)
1435     return Chain;
1436 
1437   // Variadic arguments.
1438   //
1439   // Normally, for each argument, we declare a param scalar or a param
1440   // byte array in the .param space, and store the argument value to that
1441   // param scalar or array starting at offset 0.
1442   //
1443   // In the case of the first variadic argument, we declare a vararg byte array
1444   // with size 0. The exact size of this array isn't known at this point, so
1445   // it'll be patched later. All the variadic arguments will be stored to this
1446   // array at a certain offset (which gets tracked by 'VAOffset'). The offset is
1447   // initially set to 0, so it can be used for non-variadic arguments (which use
1448   // 0 offset) to simplify the code.
1449   //
1450   // After all varargs are processed, 'VAOffset' holds the size of the
1451   // vararg byte array.
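  //
  // For example, a variadic call with one fixed pointer argument and 12 bytes
  // of variadic data would declare its parameters roughly as (illustrative):
  //   .param .b64 param0;
  //   .param .align 8 .b8 param1[12];  // size patched in after all varargs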
1452 
1453   SDValue VADeclareParam;                 // vararg byte array
1454   unsigned FirstVAArg = CLI.NumFixedArgs; // position of the first variadic
1455   unsigned VAOffset = 0;                  // current offset in the param array
1456 
1457   unsigned UniqueCallSite = GlobalUniqueCallSite.fetch_add(1);
1458   SDValue TempChain = Chain;
1459   Chain = DAG.getCALLSEQ_START(Chain, UniqueCallSite, 0, dl);
1460   SDValue InGlue = Chain.getValue(1);
1461 
1462   unsigned ParamCount = 0;
1463   // Args.size() and Outs.size() need not match.
1464   // Outs.size() will be larger
1465   //   * if there is an aggregate argument with multiple fields (each field
1466   //     showing up separately in Outs)
1467   //   * if there is a vector argument with more than typical vector-length
1468   //     elements (generally if more than 4) where each vector element is
1469   //     individually present in Outs.
1470   // So a different index should be used for indexing into Outs/OutVals.
1471   // See similar issue in LowerFormalArguments.
1472   unsigned OIdx = 0;
1473   // Declare the .param or .reg variables needed to pass values
1474   // to the function.
1475   for (unsigned i = 0, e = Args.size(); i != e; ++i, ++OIdx) {
1476     EVT VT = Outs[OIdx].VT;
1477     Type *Ty = Args[i].Ty;
1478     bool IsVAArg = (i >= CLI.NumFixedArgs);
1479     bool IsByVal = Outs[OIdx].Flags.isByVal();
1480 
1481     SmallVector<EVT, 16> VTs;
1482     SmallVector<uint64_t, 16> Offsets;
1483 
1484     assert((!IsByVal || Args[i].IndirectType) &&
1485            "byval arg must have indirect type");
1486     Type *ETy = (IsByVal ? Args[i].IndirectType : Ty);
1487     ComputePTXValueVTs(*this, DL, ETy, VTs, &Offsets, IsByVal ? 0 : VAOffset);
1488 
1489     Align ArgAlign;
1490     if (IsByVal) {
1491       // The ByValAlign in the Outs[OIdx].Flags is always set at this point,
1492       // so we don't need to worry whether it's naturally aligned or not.
1493       // See TargetLowering::LowerCallTo().
1494       Align InitialAlign = Outs[OIdx].Flags.getNonZeroByValAlign();
1495       ArgAlign = getFunctionByValParamAlign(CB->getCalledFunction(), ETy,
1496                                             InitialAlign, DL);
1497       if (IsVAArg)
1498         VAOffset = alignTo(VAOffset, ArgAlign);
1499     } else {
1500       ArgAlign = getArgumentAlignment(CB, Ty, ParamCount + 1, DL);
1501     }
1502 
1503     unsigned TypeSize =
1504         (IsByVal ? Outs[OIdx].Flags.getByValSize() : DL.getTypeAllocSize(Ty));
1505     SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1506 
1507     bool NeedAlign; // Does argument declaration specify alignment?
1508     bool PassAsArray = IsByVal || IsTypePassedAsArray(Ty);
1509     if (IsVAArg) {
1510       if (ParamCount == FirstVAArg) {
1511         SDValue DeclareParamOps[] = {
1512             Chain, DAG.getConstant(STI.getMaxRequiredAlignment(), dl, MVT::i32),
1513             DAG.getConstant(ParamCount, dl, MVT::i32),
1514             DAG.getConstant(1, dl, MVT::i32), InGlue};
1515         VADeclareParam = Chain = DAG.getNode(NVPTXISD::DeclareParam, dl,
1516                                              DeclareParamVTs, DeclareParamOps);
1517       }
1518       NeedAlign = PassAsArray;
1519     } else if (PassAsArray) {
1520       // declare .param .align <align> .b8 .param<n>[<size>];
1521       SDValue DeclareParamOps[] = {
1522           Chain, DAG.getConstant(ArgAlign.value(), dl, MVT::i32),
1523           DAG.getConstant(ParamCount, dl, MVT::i32),
1524           DAG.getConstant(TypeSize, dl, MVT::i32), InGlue};
1525       Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
1526                           DeclareParamOps);
1527       NeedAlign = true;
1528     } else {
1529       // declare .param .b<size> .param<n>;
1530       if (VT.isInteger() || VT.isFloatingPoint()) {
1531         // PTX ABI requires integral types to be at least 32 bits in
1532         // size. FP16 is loaded/stored using i16, so it's handled
1533         // here as well.
1534         TypeSize = promoteScalarArgumentSize(TypeSize * 8) / 8;
1535       }
1536       SDValue DeclareScalarParamOps[] = {
1537           Chain, DAG.getConstant(ParamCount, dl, MVT::i32),
1538           DAG.getConstant(TypeSize * 8, dl, MVT::i32),
1539           DAG.getConstant(0, dl, MVT::i32), InGlue};
1540       Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs,
1541                           DeclareScalarParamOps);
1542       NeedAlign = false;
1543     }
1544     InGlue = Chain.getValue(1);
1545 
1546     // PTX Interoperability Guide 3.3(A): [Integer] Values shorter
1547     // than 32-bits are sign extended or zero extended, depending on
1548     // whether they are signed or unsigned types. This case applies
1549     // only to scalar parameters and not to aggregate values.
1550     bool ExtendIntegerParam =
1551         Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Ty) < 32;
1552 
1553     auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign, IsVAArg);
1554     SmallVector<SDValue, 6> StoreOperands;
1555     for (unsigned j = 0, je = VTs.size(); j != je; ++j) {
1556       EVT EltVT = VTs[j];
1557       int CurOffset = Offsets[j];
1558       MaybeAlign PartAlign;
1559       if (NeedAlign)
1560         PartAlign = commonAlignment(ArgAlign, CurOffset);
1561 
1562       SDValue StVal = OutVals[OIdx];
1563 
1564       MVT PromotedVT;
1565       if (PromoteScalarIntegerPTX(EltVT, &PromotedVT)) {
1566         EltVT = EVT(PromotedVT);
1567       }
1568       if (PromoteScalarIntegerPTX(StVal.getValueType(), &PromotedVT)) {
1569         llvm::ISD::NodeType Ext =
1570             Outs[OIdx].Flags.isSExt() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
1571         StVal = DAG.getNode(Ext, dl, PromotedVT, StVal);
1572       }
1573 
1574       if (IsByVal) {
1575         auto PtrVT = getPointerTy(DL);
1576         SDValue srcAddr = DAG.getNode(ISD::ADD, dl, PtrVT, StVal,
1577                                       DAG.getConstant(CurOffset, dl, PtrVT));
1578         StVal = DAG.getLoad(EltVT, dl, TempChain, srcAddr, MachinePointerInfo(),
1579                             PartAlign);
1580       } else if (ExtendIntegerParam) {
1581         assert(VTs.size() == 1 && "Scalar can't have multiple parts.");
1582         // zext/sext to i32
1583         StVal = DAG.getNode(Outs[OIdx].Flags.isSExt() ? ISD::SIGN_EXTEND
1584                                                       : ISD::ZERO_EXTEND,
1585                             dl, MVT::i32, StVal);
1586       }
1587 
1588       if (!ExtendIntegerParam && EltVT.getSizeInBits() < 16) {
1589         // Use 16-bit registers for small stores as it's the
1590         // smallest general purpose register size supported by NVPTX.
1591         StVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, StVal);
1592       }
1593 
1594       // If we have a PVF_SCALAR entry, it may not be sufficiently aligned for a
1595       // scalar store. In such cases, fall back to byte stores.
1596       if (VectorInfo[j] == PVF_SCALAR && !IsVAArg && PartAlign.has_value() &&
1597           PartAlign.value() <
1598               DL.getABITypeAlign(EltVT.getTypeForEVT(*DAG.getContext()))) {
1599         assert(StoreOperands.empty() && "Unfinished preceding store.");
1600         Chain = LowerUnalignedStoreParam(
1601             DAG, Chain, IsByVal ? CurOffset + VAOffset : CurOffset, EltVT,
1602             StVal, InGlue, ParamCount, dl);
1603 
1604         // LowerUnalignedStoreParam took care of inserting the necessary nodes
1605         // into the SDAG, so just move on to the next element.
1606         if (!IsByVal)
1607           ++OIdx;
1608         continue;
1609       }
1610 
1611       // New store.
1612       if (VectorInfo[j] & PVF_FIRST) {
1613         assert(StoreOperands.empty() && "Unfinished preceding store.");
1614         StoreOperands.push_back(Chain);
1615         StoreOperands.push_back(
1616             DAG.getConstant(IsVAArg ? FirstVAArg : ParamCount, dl, MVT::i32));
1617 
1618         StoreOperands.push_back(DAG.getConstant(
1619             IsByVal ? CurOffset + VAOffset : (IsVAArg ? VAOffset : CurOffset),
1620             dl, MVT::i32));
1621       }
1622 
1623       // Record the value to store.
1624       StoreOperands.push_back(StVal);
1625 
1626       if (VectorInfo[j] & PVF_LAST) {
1627         unsigned NumElts = StoreOperands.size() - 3;
1628         NVPTXISD::NodeType Op;
1629         switch (NumElts) {
1630         case 1:
1631           Op = NVPTXISD::StoreParam;
1632           break;
1633         case 2:
1634           Op = NVPTXISD::StoreParamV2;
1635           break;
1636         case 4:
1637           Op = NVPTXISD::StoreParamV4;
1638           break;
1639         default:
1640           llvm_unreachable("Invalid vector info.");
1641         }
1642 
1643         StoreOperands.push_back(InGlue);
1644 
1645         // Adjust type of the store op if we've extended the scalar
1646         // return value.
1647         EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : EltVT;
1648 
1649         Chain = DAG.getMemIntrinsicNode(
1650             Op, dl, DAG.getVTList(MVT::Other, MVT::Glue), StoreOperands,
1651             TheStoreType, MachinePointerInfo(), PartAlign,
1652             MachineMemOperand::MOStore);
1653         InGlue = Chain.getValue(1);
1654 
1655         // Cleanup.
1656         StoreOperands.clear();
1657 
1658         // TODO: We may need to support vector types that can be passed
1659         // as scalars in variadic arguments.
1660         if (!IsByVal && IsVAArg) {
1661           assert(NumElts == 1 &&
1662                  "Vectorization is expected to be disabled for variadics.");
1663           VAOffset += DL.getTypeAllocSize(
1664               TheStoreType.getTypeForEVT(*DAG.getContext()));
1665         }
1666       }
1667       if (!IsByVal)
1668         ++OIdx;
1669     }
1670     assert(StoreOperands.empty() && "Unfinished parameter store.");
1671     if (!IsByVal && VTs.size() > 0)
1672       --OIdx;
1673     ++ParamCount;
1674     if (IsByVal && IsVAArg)
1675       VAOffset += TypeSize;
1676   }
1677 
1678   GlobalAddressSDNode *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode());
1679   MaybeAlign retAlignment = std::nullopt;
1680 
1681   // Handle Result
1682   if (Ins.size() > 0) {
1683     SmallVector<EVT, 16> resvtparts;
1684     ComputeValueVTs(*this, DL, RetTy, resvtparts);
1685 
1686     // Declare
1687     //  .param .align N .b8 retval0[<size-in-bytes>], or
1688     //  .param .b<size-in-bits> retval0
1689     unsigned resultsz = DL.getTypeAllocSizeInBits(RetTy);
1690     if (!IsTypePassedAsArray(RetTy)) {
1691       resultsz = promoteScalarArgumentSize(resultsz);
1692       SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1693       SDValue DeclareRetOps[] = { Chain, DAG.getConstant(1, dl, MVT::i32),
1694                                   DAG.getConstant(resultsz, dl, MVT::i32),
1695                                   DAG.getConstant(0, dl, MVT::i32), InGlue };
1696       Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, DeclareRetVTs,
1697                           DeclareRetOps);
1698       InGlue = Chain.getValue(1);
1699     } else {
1700       retAlignment = getArgumentAlignment(CB, RetTy, 0, DL);
1701       assert(retAlignment && "retAlignment is guaranteed to be set");
1702       SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1703       SDValue DeclareRetOps[] = {
1704           Chain, DAG.getConstant(retAlignment->value(), dl, MVT::i32),
1705           DAG.getConstant(resultsz / 8, dl, MVT::i32),
1706           DAG.getConstant(0, dl, MVT::i32), InGlue};
1707       Chain = DAG.getNode(NVPTXISD::DeclareRetParam, dl, DeclareRetVTs,
1708                           DeclareRetOps);
1709       InGlue = Chain.getValue(1);
1710     }
1711   }
1712 
1713   bool HasVAArgs = CLI.IsVarArg && (CLI.Args.size() > CLI.NumFixedArgs);
1714   // Set the size of the vararg param byte array if the callee is a variadic
1715   // function and the variadic part is not empty.
1716   if (HasVAArgs) {
1717     SDValue DeclareParamOps[] = {
1718         VADeclareParam.getOperand(0), VADeclareParam.getOperand(1),
1719         VADeclareParam.getOperand(2), DAG.getConstant(VAOffset, dl, MVT::i32),
1720         VADeclareParam.getOperand(4)};
1721     DAG.MorphNodeTo(VADeclareParam.getNode(), VADeclareParam.getOpcode(),
1722                     VADeclareParam->getVTList(), DeclareParamOps);
1723   }
1724 
1725   // If the type of the callsite does not match that of the function, convert
1726   // the callsite to an indirect call.
1727   bool ConvertToIndirectCall = shouldConvertToIndirectCall(CB, Func);
1728 
1729   // Both indirect calls and libcalls have nullptr Func. In order to distinguish
1730   // between them we must rely on the call site value which is valid for
1731   // indirect calls but is always null for libcalls.
1732   bool isIndirectCall = (!Func && CB) || ConvertToIndirectCall;
1733 
1734   if (isa<ExternalSymbolSDNode>(Callee)) {
1735     Function* CalleeFunc = nullptr;
1736 
1737     // Try to find the callee in the current module.
1738     Callee = DAG.getSymbolFunctionGlobalAddress(Callee, &CalleeFunc);
1739     assert(CalleeFunc != nullptr && "Libcall callee must be set.");
1740 
1741     // Set the "libcall callee" attribute to indicate that the function
1742     // must always have a declaration.
1743     CalleeFunc->addFnAttr("nvptx-libcall-callee", "true");
1744   }
1745 
1746   if (isIndirectCall) {
1747     // This is the indirect function call case: PTX requires a prototype of
1748     // the form
1749     // proto_0 : .callprototype(.param .b32 _) _ (.param .b32 _);
1750     // to be emitted, and the label has to be used as the last arg of the
1751     // call instruction.
1752     // The prototype is embedded in a string and used as the operand of a
1753     // CallPrototype SDNode, which is printed as the contents of the string.
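    // For example, for a callee of type "float (float, int)" the emitted
    // prototype looks roughly like (illustrative):
    //   proto_0 : .callprototype (.param .b32 _) _ (.param .b32 _, .param .b32 _);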
1754     SDVTList ProtoVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1755     std::string Proto = getPrototype(
1756         DL, RetTy, Args, Outs, retAlignment,
1757         HasVAArgs
1758             ? std::optional<std::pair<unsigned, const APInt &>>(std::make_pair(
1759                   CLI.NumFixedArgs, VADeclareParam->getConstantOperandAPInt(1)))
1760             : std::nullopt,
1761         *CB, UniqueCallSite);
1762     const char *ProtoStr = nvTM->getStrPool().save(Proto).data();
1763     SDValue ProtoOps[] = {
1764         Chain,
1765         DAG.getTargetExternalSymbol(ProtoStr, MVT::i32),
1766         InGlue,
1767     };
1768     Chain = DAG.getNode(NVPTXISD::CallPrototype, dl, ProtoVTs, ProtoOps);
1769     InGlue = Chain.getValue(1);
1770   }
1771   // Op to just print "call"
1772   SDVTList PrintCallVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1773   SDValue PrintCallOps[] = {
1774     Chain, DAG.getConstant((Ins.size() == 0) ? 0 : 1, dl, MVT::i32), InGlue
1775   };
1776   // We model convergent calls as separate opcodes.
1777   unsigned Opcode = isIndirectCall ? NVPTXISD::PrintCall : NVPTXISD::PrintCallUni;
1778   if (CLI.IsConvergent)
1779     Opcode = Opcode == NVPTXISD::PrintCallUni ? NVPTXISD::PrintConvergentCallUni
1780                                               : NVPTXISD::PrintConvergentCall;
1781   Chain = DAG.getNode(Opcode, dl, PrintCallVTs, PrintCallOps);
1782   InGlue = Chain.getValue(1);
1783 
1784   if (ConvertToIndirectCall) {
1785     // Copy the function ptr to a ptx register and use the register to call the
1786     // function.
1787     EVT DestVT = Callee.getValueType();
1788     MachineRegisterInfo &RegInfo = DAG.getMachineFunction().getRegInfo();
1789     const TargetLowering &TLI = DAG.getTargetLoweringInfo();
1790     unsigned DestReg =
1791         RegInfo.createVirtualRegister(TLI.getRegClassFor(DestVT.getSimpleVT()));
1792     auto RegCopy = DAG.getCopyToReg(DAG.getEntryNode(), dl, DestReg, Callee);
1793     Callee = DAG.getCopyFromReg(RegCopy, dl, DestReg, DestVT);
1794   }
1795 
1796   // Ops to print out the function name
1797   SDVTList CallVoidVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1798   SDValue CallVoidOps[] = { Chain, Callee, InGlue };
1799   Chain = DAG.getNode(NVPTXISD::CallVoid, dl, CallVoidVTs, CallVoidOps);
1800   InGlue = Chain.getValue(1);
1801 
1802   // Ops to print out the param list
1803   SDVTList CallArgBeginVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1804   SDValue CallArgBeginOps[] = { Chain, InGlue };
1805   Chain = DAG.getNode(NVPTXISD::CallArgBegin, dl, CallArgBeginVTs,
1806                       CallArgBeginOps);
1807   InGlue = Chain.getValue(1);
1808 
1809   for (unsigned i = 0, e = std::min(CLI.NumFixedArgs + 1, ParamCount); i != e;
1810        ++i) {
1811     unsigned opcode;
1812     if (i == (e - 1))
1813       opcode = NVPTXISD::LastCallArg;
1814     else
1815       opcode = NVPTXISD::CallArg;
1816     SDVTList CallArgVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1817     SDValue CallArgOps[] = { Chain, DAG.getConstant(1, dl, MVT::i32),
1818                              DAG.getConstant(i, dl, MVT::i32), InGlue };
1819     Chain = DAG.getNode(opcode, dl, CallArgVTs, CallArgOps);
1820     InGlue = Chain.getValue(1);
1821   }
1822   SDVTList CallArgEndVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1823   SDValue CallArgEndOps[] = { Chain,
1824                               DAG.getConstant(isIndirectCall ? 0 : 1, dl, MVT::i32),
1825                               InGlue };
1826   Chain = DAG.getNode(NVPTXISD::CallArgEnd, dl, CallArgEndVTs, CallArgEndOps);
1827   InGlue = Chain.getValue(1);
1828 
1829   if (isIndirectCall) {
1830     SDVTList PrototypeVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1831     SDValue PrototypeOps[] = {
1832         Chain, DAG.getConstant(UniqueCallSite, dl, MVT::i32), InGlue};
1833     Chain = DAG.getNode(NVPTXISD::Prototype, dl, PrototypeVTs, PrototypeOps);
1834     InGlue = Chain.getValue(1);
1835   }
1836 
1837   SmallVector<SDValue, 16> ProxyRegOps;
1838   SmallVector<std::optional<MVT>, 16> ProxyRegTruncates;
1839   // An item of the vector is filled if the element does not need a ProxyReg
1840   // operation on it and should be added to InVals as is. ProxyRegOps and
1841   // ProxyRegTruncates contain empty/none items at the same index.
1842   SmallVector<SDValue, 16> RetElts;
1843   // Temporary ProxyReg operations inserted in `LowerUnalignedLoadRetParam()`
1844   // to use the values of `LoadParam`s; they are replaced later, when
1845   // `CALLSEQ_END` is added.
1846   SmallVector<SDValue, 16> TempProxyRegOps;
1847 
1848   // Generate loads from param memory/moves from registers for result
1849   if (Ins.size() > 0) {
1850     SmallVector<EVT, 16> VTs;
1851     SmallVector<uint64_t, 16> Offsets;
1852     ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets, 0);
1853     assert(VTs.size() == Ins.size() && "Bad value decomposition");
1854 
1855     Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL);
1856     auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
1857 
1858     SmallVector<EVT, 6> LoadVTs;
1859     int VecIdx = -1; // Index of the first element of the vector.
1860 
1861     // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
1862     // 32-bits are sign extended or zero extended, depending on whether
1863     // they are signed or unsigned types.
1864     bool ExtendIntegerRetVal =
1865         RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32;
1866 
1867     for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
1868       bool needTruncate = false;
1869       EVT TheLoadType = VTs[i];
1870       EVT EltType = Ins[i].VT;
1871       Align EltAlign = commonAlignment(RetAlign, Offsets[i]);
1872       MVT PromotedVT;
1873 
1874       if (PromoteScalarIntegerPTX(TheLoadType, &PromotedVT)) {
1875         TheLoadType = EVT(PromotedVT);
1876         EltType = EVT(PromotedVT);
1877         needTruncate = true;
1878       }
1879 
1880       if (ExtendIntegerRetVal) {
1881         TheLoadType = MVT::i32;
1882         EltType = MVT::i32;
1883         needTruncate = true;
1884       } else if (TheLoadType.getSizeInBits() < 16) {
1885         if (VTs[i].isInteger())
1886           needTruncate = true;
1887         EltType = MVT::i16;
1888       }
1889 
1890       // If we have a PVF_SCALAR entry, it may not be sufficiently aligned for a
1891       // scalar load. In such cases, fall back to byte loads.
1892       if (VectorInfo[i] == PVF_SCALAR && RetTy->isAggregateType() &&
1893           EltAlign < DL.getABITypeAlign(
1894                          TheLoadType.getTypeForEVT(*DAG.getContext()))) {
1895         assert(VecIdx == -1 && LoadVTs.empty() && "Orphaned operand list.");
1896         SDValue Ret = LowerUnalignedLoadRetParam(
1897             DAG, Chain, Offsets[i], TheLoadType, InGlue, TempProxyRegOps, dl);
1898         ProxyRegOps.push_back(SDValue());
1899         ProxyRegTruncates.push_back(std::optional<MVT>());
1900         RetElts.resize(i);
1901         RetElts.push_back(Ret);
1902 
1903         continue;
1904       }
1905 
1906       // Record index of the very first element of the vector.
1907       if (VectorInfo[i] & PVF_FIRST) {
1908         assert(VecIdx == -1 && LoadVTs.empty() && "Orphaned operand list.");
1909         VecIdx = i;
1910       }
1911 
1912       LoadVTs.push_back(EltType);
1913 
1914       if (VectorInfo[i] & PVF_LAST) {
1915         unsigned NumElts = LoadVTs.size();
1916         LoadVTs.push_back(MVT::Other);
1917         LoadVTs.push_back(MVT::Glue);
1918         NVPTXISD::NodeType Op;
1919         switch (NumElts) {
1920         case 1:
1921           Op = NVPTXISD::LoadParam;
1922           break;
1923         case 2:
1924           Op = NVPTXISD::LoadParamV2;
1925           break;
1926         case 4:
1927           Op = NVPTXISD::LoadParamV4;
1928           break;
1929         default:
1930           llvm_unreachable("Invalid vector info.");
1931         }
1932 
1933         SDValue LoadOperands[] = {
1934             Chain, DAG.getConstant(1, dl, MVT::i32),
1935             DAG.getConstant(Offsets[VecIdx], dl, MVT::i32), InGlue};
1936         SDValue RetVal = DAG.getMemIntrinsicNode(
1937             Op, dl, DAG.getVTList(LoadVTs), LoadOperands, TheLoadType,
1938             MachinePointerInfo(), EltAlign,
1939             MachineMemOperand::MOLoad);
1940 
1941         for (unsigned j = 0; j < NumElts; ++j) {
1942           ProxyRegOps.push_back(RetVal.getValue(j));
1943 
1944           if (needTruncate)
1945             ProxyRegTruncates.push_back(std::optional<MVT>(Ins[VecIdx + j].VT));
1946           else
1947             ProxyRegTruncates.push_back(std::optional<MVT>());
1948         }
1949 
1950         Chain = RetVal.getValue(NumElts);
1951         InGlue = RetVal.getValue(NumElts + 1);
1952 
1953         // Cleanup
1954         VecIdx = -1;
1955         LoadVTs.clear();
1956       }
1957     }
1958   }
1959 
1960   Chain =
1961       DAG.getCALLSEQ_END(Chain, UniqueCallSite, UniqueCallSite + 1, InGlue, dl);
1962   InGlue = Chain.getValue(1);
1963 
1964   // Append ProxyReg instructions to the chain to make sure that `callseq_end`
1965   // will not get lost. Otherwise, during libcalls expansion, the nodes can become
1966   // dangling.
1967   for (unsigned i = 0; i < ProxyRegOps.size(); ++i) {
1968     if (i < RetElts.size() && RetElts[i]) {
1969       InVals.push_back(RetElts[i]);
1970       continue;
1971     }
1972 
1973     SDValue Ret = DAG.getNode(
1974       NVPTXISD::ProxyReg, dl,
1975       DAG.getVTList(ProxyRegOps[i].getSimpleValueType(), MVT::Other, MVT::Glue),
1976       { Chain, ProxyRegOps[i], InGlue }
1977     );
1978 
1979     Chain = Ret.getValue(1);
1980     InGlue = Ret.getValue(2);
1981 
1982     if (ProxyRegTruncates[i]) {
1983       Ret = DAG.getNode(ISD::TRUNCATE, dl, *ProxyRegTruncates[i], Ret);
1984     }
1985 
1986     InVals.push_back(Ret);
1987   }
1988 
1989   for (SDValue &T : TempProxyRegOps) {
1990     SDValue Repl = DAG.getNode(
1991         NVPTXISD::ProxyReg, dl,
1992         DAG.getVTList(T.getSimpleValueType(), MVT::Other, MVT::Glue),
1993         {Chain, T.getOperand(0), InGlue});
1994     DAG.ReplaceAllUsesWith(T, Repl);
1995     DAG.RemoveDeadNode(T.getNode());
1996 
1997     Chain = Repl.getValue(1);
1998     InGlue = Repl.getValue(2);
1999   }
2000 
2001   // set isTailCall to false for now, until we figure out how to express
2002   // tail call optimization in PTX
2003   isTailCall = false;
2004   return Chain;
2005 }
2006 
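// Lower dynamic stack allocation to NVPTXISD::DYNAMIC_STACKALLOC, which maps
// to the PTX alloca instruction (PTX ISA 7.3+, sm_52+). On older targets,
// emit a diagnostic and return a null pointer instead.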
2007 SDValue NVPTXTargetLowering::LowerDYNAMIC_STACKALLOC(SDValue Op,
2008                                                      SelectionDAG &DAG) const {
2009 
2010   if (STI.getPTXVersion() < 73 || STI.getSmVersion() < 52) {
2011     const Function &Fn = DAG.getMachineFunction().getFunction();
2012 
2013     DiagnosticInfoUnsupported NoDynamicAlloca(
2014         Fn,
2015         "Support for dynamic alloca was introduced in PTX ISA version 7.3 "
2016         "and requires target sm_52.",
2017         SDLoc(Op).getDebugLoc());
2018     DAG.getContext()->diagnose(NoDynamicAlloca);
2019     auto Ops = {DAG.getConstant(0, SDLoc(), Op.getValueType()),
2020                 Op.getOperand(0)};
2021     return DAG.getMergeValues(Ops, SDLoc());
2022   }
2023 
2024   SDValue Chain = Op.getOperand(0);
2025   SDValue Size = Op.getOperand(1);
2026   uint64_t Align = cast<ConstantSDNode>(Op.getOperand(2))->getZExtValue();
2027   SDLoc DL(Op.getNode());
2028 
2029   // The size for the PTX alloca instruction is 64-bit for m64 and 32-bit for m32.
2030   MVT ValueSizeTy = nvTM->is64Bit() ? MVT::i64 : MVT::i32;
2031 
2032   SDValue AllocOps[] = {Chain, DAG.getZExtOrTrunc(Size, DL, ValueSizeTy),
2033                         DAG.getTargetConstant(Align, DL, MVT::i32)};
2034   EVT RetTypes[] = {ValueSizeTy, MVT::Other};
2035   return DAG.getNode(NVPTXISD::DYNAMIC_STACKALLOC, DL, RetTypes, AllocOps);
2036 }
2037 
2038 SDValue NVPTXTargetLowering::LowerSTACKRESTORE(SDValue Op,
2039                                                SelectionDAG &DAG) const {
2040   SDLoc DL(Op.getNode());
2041   if (STI.getPTXVersion() < 73 || STI.getSmVersion() < 52) {
2042     const Function &Fn = DAG.getMachineFunction().getFunction();
2043 
2044     DiagnosticInfoUnsupported NoStackRestore(
2045         Fn,
2046         "Support for stackrestore requires PTX ISA version >= 7.3 and target "
2047         ">= sm_52.",
2048         DL.getDebugLoc());
2049     DAG.getContext()->diagnose(NoStackRestore);
2050     return Op.getOperand(0);
2051   }
2052 
2053   const MVT LocalVT = getPointerTy(DAG.getDataLayout(), ADDRESS_SPACE_LOCAL);
2054   SDValue Chain = Op.getOperand(0);
2055   SDValue Ptr = Op.getOperand(1);
2056   SDValue ASC = DAG.getAddrSpaceCast(DL, LocalVT, Ptr, ADDRESS_SPACE_GENERIC,
2057                                      ADDRESS_SPACE_LOCAL);
2058   return DAG.getNode(NVPTXISD::STACKRESTORE, DL, MVT::Other, {Chain, ASC});
2059 }
2060 
2061 SDValue NVPTXTargetLowering::LowerSTACKSAVE(SDValue Op,
2062                                             SelectionDAG &DAG) const {
2063   SDLoc DL(Op.getNode());
2064   if (STI.getPTXVersion() < 73 || STI.getSmVersion() < 52) {
2065     const Function &Fn = DAG.getMachineFunction().getFunction();
2066 
2067     DiagnosticInfoUnsupported NoStackSave(
2068         Fn,
2069         "Support for stacksave requires PTX ISA version >= 7.3 and target >= "
2070         "sm_52.",
2071         DL.getDebugLoc());
2072     DAG.getContext()->diagnose(NoStackSave);
2073     auto Ops = {DAG.getConstant(0, DL, Op.getValueType()), Op.getOperand(0)};
2074     return DAG.getMergeValues(Ops, DL);
2075   }
2076 
2077   const MVT LocalVT = getPointerTy(DAG.getDataLayout(), ADDRESS_SPACE_LOCAL);
2078   SDValue Chain = Op.getOperand(0);
2079   SDValue SS =
2080       DAG.getNode(NVPTXISD::STACKSAVE, DL, {LocalVT, MVT::Other}, Chain);
2081   SDValue ASC = DAG.getAddrSpaceCast(
2082       DL, Op.getValueType(), SS, ADDRESS_SPACE_LOCAL, ADDRESS_SPACE_GENERIC);
2083   return DAG.getMergeValues({ASC, SDValue(SS.getNode(), 1)}, DL);
2084 }
2085 
2086 // By default CONCAT_VECTORS is lowered by ExpandVectorBuildThroughStack()
2087 // (see LegalizeDAG.cpp). This is slow and uses local memory.
2088 // We use extract/insert/build vector, just as LegalizeOp() did in llvm 2.5.
2089 SDValue
2090 NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
2091   SDNode *Node = Op.getNode();
2092   SDLoc dl(Node);
2093   SmallVector<SDValue, 8> Ops;
2094   unsigned NumOperands = Node->getNumOperands();
2095   for (unsigned i = 0; i < NumOperands; ++i) {
2096     SDValue SubOp = Node->getOperand(i);
2097     EVT VVT = SubOp.getNode()->getValueType(0);
2098     EVT EltVT = VVT.getVectorElementType();
2099     unsigned NumSubElem = VVT.getVectorNumElements();
2100     for (unsigned j = 0; j < NumSubElem; ++j) {
2101       Ops.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, SubOp,
2102                                 DAG.getIntPtrConstant(j, dl)));
2103     }
2104   }
2105   return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
2106 }
2107 
2108 SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
2109   // Handle bitcasting from v2i8 without hitting the default promotion
2110   // strategy which goes through stack memory.
2111   EVT FromVT = Op->getOperand(0)->getValueType(0);
2112   if (FromVT != MVT::v2i8) {
2113     return Op;
2114   }
2115 
2116   // Pack vector elements into i16 and bitcast to final type
2117   SDLoc DL(Op);
2118   SDValue Vec0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8,
2119                              Op->getOperand(0), DAG.getIntPtrConstant(0, DL));
2120   SDValue Vec1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8,
2121                              Op->getOperand(0), DAG.getIntPtrConstant(1, DL));
2122   SDValue Extend0 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i16, Vec0);
2123   SDValue Extend1 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i16, Vec1);
2124   SDValue Const8 = DAG.getConstant(8, DL, MVT::i16);
2125   SDValue AsInt = DAG.getNode(
2126       ISD::OR, DL, MVT::i16,
2127       {Extend0, DAG.getNode(ISD::SHL, DL, MVT::i16, {Extend1, Const8})});
2128   EVT ToVT = Op->getValueType(0);
2129   return MaybeBitcast(DAG, DL, ToVT, AsInt);
2130 }
2131 
2132 // We can initialize a constant f16x2/v2i16/v4i8 with a single .b32 move.
2133 // Normally it would be lowered as two constant loads and a packing move.
2134 // Instead we want just a constant move:
2135 //        mov.b32         %r2, 0x40003C00
2136 SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
2137                                                SelectionDAG &DAG) const {
2138   EVT VT = Op->getValueType(0);
2139   if (!(Isv2x16VT(VT) || VT == MVT::v4i8))
2140     return Op;
2141   SDLoc DL(Op);
2142 
2143   if (!llvm::all_of(Op->ops(), [](SDValue Operand) {
2144         return Operand->isUndef() || isa<ConstantSDNode>(Operand) ||
2145                isa<ConstantFPSDNode>(Operand);
2146       })) {
2147     if (VT != MVT::v4i8)
2148       return Op;
2149     // Lower a non-const v4i8 vector as a byte-wise constructed i32, which
2150     // allows us to optimize the calculation of its constant parts.
2151     auto GetPRMT = [&](const SDValue Left, const SDValue Right, bool Cast,
2152                        uint64_t SelectionValue) -> SDValue {
2153       SDValue L = Left;
2154       SDValue R = Right;
2155       if (Cast) {
2156         L = DAG.getAnyExtOrTrunc(L, DL, MVT::i32);
2157         R = DAG.getAnyExtOrTrunc(R, DL, MVT::i32);
2158       }
2159       return DAG.getNode(
2160           NVPTXISD::PRMT, DL, MVT::v4i8,
2161           {L, R, DAG.getConstant(SelectionValue, DL, MVT::i32),
2162            DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32)});
2163     };
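    // Each selector nibble picks one byte of the concatenated {L, R} pair
    // (bytes 0-3 from L, 4-7 from R, per the PTX prmt convention). Roughly:
    // 0x3340 packs the low bytes of its two inputs into result bytes 0-1, and
    // 0x5410 then merges the two partial results into the final {b0..b3}.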
2164     auto PRMT__10 = GetPRMT(Op->getOperand(0), Op->getOperand(1), true, 0x3340);
2165     auto PRMT__32 = GetPRMT(Op->getOperand(2), Op->getOperand(3), true, 0x3340);
2166     auto PRMT3210 = GetPRMT(PRMT__10, PRMT__32, false, 0x5410);
2167     return DAG.getNode(ISD::BITCAST, DL, VT, PRMT3210);
2168   }
2169 
2170   // Get the value of the Nth operand as an APInt(32). Undef is treated as 0.
2171   auto GetOperand = [](SDValue Op, int N) -> APInt {
2172     const SDValue &Operand = Op->getOperand(N);
2173     EVT VT = Op->getValueType(0);
2174     if (Operand->isUndef())
2175       return APInt(32, 0);
2176     APInt Value;
2177     if (VT == MVT::v2f16 || VT == MVT::v2bf16)
2178       Value = cast<ConstantFPSDNode>(Operand)->getValueAPF().bitcastToAPInt();
2179     else if (VT == MVT::v2i16 || VT == MVT::v4i8)
2180       Value = Operand->getAsAPIntVal();
2181     else
2182       llvm_unreachable("Unsupported type");
2183     // i8 values are carried around as i16, so we need to zero out the upper
2184     // bits so they do not get in the way of combining individual byte values.
2185     if (VT == MVT::v4i8)
2186       Value = Value.trunc(8);
2187     return Value.zext(32);
2188   };
2189   APInt Value;
2190   if (Isv2x16VT(VT)) {
2191     Value = GetOperand(Op, 0) | GetOperand(Op, 1).shl(16);
2192   } else if (VT == MVT::v4i8) {
2193     Value = GetOperand(Op, 0) | GetOperand(Op, 1).shl(8) |
2194             GetOperand(Op, 2).shl(16) | GetOperand(Op, 3).shl(24);
2195   } else {
2196     llvm_unreachable("Unsupported type");
2197   }
2198   SDValue Const = DAG.getConstant(Value, DL, MVT::i32);
2199   return DAG.getNode(ISD::BITCAST, DL, Op->getValueType(0), Const);
2200 }
2201 
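// For v4i8 the vector lives in an i32 register, so a dynamic extract becomes
// a BFE of the 8-bit field at bit offset Index*8. For v2x16 vectors, constant
// indices are matched by tablegen; a variable index extracts both halves and
// selects between them.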
2202 SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
2203                                                      SelectionDAG &DAG) const {
2204   SDValue Index = Op->getOperand(1);
2205   SDValue Vector = Op->getOperand(0);
2206   SDLoc DL(Op);
2207   EVT VectorVT = Vector.getValueType();
2208 
2209   if (VectorVT == MVT::v4i8) {
2210     SDValue BFE =
2211         DAG.getNode(NVPTXISD::BFE, DL, MVT::i32,
2212                     {Vector,
2213                      DAG.getNode(ISD::MUL, DL, MVT::i32,
2214                                  DAG.getZExtOrTrunc(Index, DL, MVT::i32),
2215                                  DAG.getConstant(8, DL, MVT::i32)),
2216                      DAG.getConstant(8, DL, MVT::i32)});
2217     return DAG.getAnyExtOrTrunc(BFE, DL, Op->getValueType(0));
2218   }
2219 
2220   // Constant index will be matched by tablegen.
2221   if (isa<ConstantSDNode>(Index.getNode()))
2222     return Op;
2223 
2224   // Extract individual elements and select one of them.
2225   assert(Isv2x16VT(VectorVT) && "Unexpected vector type.");
2226   EVT EltVT = VectorVT.getVectorElementType();
2227 
2228   SDLoc dl(Op.getNode());
2229   SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, Vector,
2230                            DAG.getIntPtrConstant(0, dl));
2231   SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, Vector,
2232                            DAG.getIntPtrConstant(1, dl));
2233   return DAG.getSelectCC(dl, Index, DAG.getIntPtrConstant(0, dl), E0, E1,
2234                          ISD::CondCode::SETEQ);
2235 }
2236 
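// For v4i8 the vector is an i32 register, so an insert becomes a BFI that
// replaces the 8-bit field at bit offset Index*8 with the new value.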
2237 SDValue NVPTXTargetLowering::LowerINSERT_VECTOR_ELT(SDValue Op,
2238                                                     SelectionDAG &DAG) const {
2239   SDValue Vector = Op->getOperand(0);
2240   EVT VectorVT = Vector.getValueType();
2241 
2242   if (VectorVT != MVT::v4i8)
2243     return Op;
2244   SDLoc DL(Op);
2245   SDValue Value = Op->getOperand(1);
2246   if (Value->isUndef())
2247     return Vector;
2248 
2249   SDValue Index = Op->getOperand(2);
2250 
2251   SDValue BFI =
2252       DAG.getNode(NVPTXISD::BFI, DL, MVT::i32,
2253                   {DAG.getZExtOrTrunc(Value, DL, MVT::i32), Vector,
2254                    DAG.getNode(ISD::MUL, DL, MVT::i32,
2255                                DAG.getZExtOrTrunc(Index, DL, MVT::i32),
2256                                DAG.getConstant(8, DL, MVT::i32)),
2257                    DAG.getConstant(8, DL, MVT::i32)});
2258   return DAG.getNode(ISD::BITCAST, DL, Op->getValueType(0), BFI);
2259 }
2260 
2261 SDValue NVPTXTargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
2262                                                  SelectionDAG &DAG) const {
2263   SDValue V1 = Op.getOperand(0);
2264   EVT VectorVT = V1.getValueType();
2265   if (VectorVT != MVT::v4i8 || Op.getValueType() != MVT::v4i8)
2266     return Op;
2267 
2268   // Lower shuffle to PRMT instruction.
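  // Each nibble of the PRMT selector picks the source byte (0-7 across the
  // concatenated V1/V2 pair) for the corresponding result lane; undef mask
  // entries leave their nibble as 0 and thus default to byte 0.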
2269   const ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Op.getNode());
2270   SDValue V2 = Op.getOperand(1);
2271   uint32_t Selector = 0;
2272   for (auto I : llvm::enumerate(SVN->getMask())) {
2273     if (I.value() != -1) // -1 is a placeholder for undef.
2274       Selector |= (I.value() << (I.index() * 4));
2275   }
2276 
2277   SDLoc DL(Op);
2278   return DAG.getNode(NVPTXISD::PRMT, DL, MVT::v4i8, V1, V2,
2279                      DAG.getConstant(Selector, DL, MVT::i32),
2280                      DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32));
2281 }
2282 /// LowerShiftRightParts - Lower SRL_PARTS, SRA_PARTS, which
2283 /// 1) return two i32 values and take a 2 x i32 value to shift plus a shift
2284 ///    amount, or
2285 /// 2) return two i64 values and take a 2 x i64 value to shift plus a shift
2286 ///    amount.
2287 SDValue NVPTXTargetLowering::LowerShiftRightParts(SDValue Op,
2288                                                   SelectionDAG &DAG) const {
2289   assert(Op.getNumOperands() == 3 && "Not a double-shift!");
2290   assert(Op.getOpcode() == ISD::SRA_PARTS || Op.getOpcode() == ISD::SRL_PARTS);
2291 
2292   EVT VT = Op.getValueType();
2293   unsigned VTBits = VT.getSizeInBits();
2294   SDLoc dl(Op);
2295   SDValue ShOpLo = Op.getOperand(0);
2296   SDValue ShOpHi = Op.getOperand(1);
2297   SDValue ShAmt  = Op.getOperand(2);
2298   unsigned Opc = (Op.getOpcode() == ISD::SRA_PARTS) ? ISD::SRA : ISD::SRL;
2299 
2300   if (VTBits == 32 && STI.getSmVersion() >= 35) {
2301     // For 32bit and sm35, we can use the funnel shift 'shf' instruction.
2302     // {dHi, dLo} = {aHi, aLo} >> Amt
2303     //   dHi = aHi >> Amt
2304     //   dLo = shf.r.clamp aLo, aHi, Amt
2305 
2306     SDValue Hi = DAG.getNode(Opc, dl, VT, ShOpHi, ShAmt);
2307     SDValue Lo =
2308         DAG.getNode(NVPTXISD::FSHR_CLAMP, dl, VT, ShOpHi, ShOpLo, ShAmt);
2309 
2310     SDValue Ops[2] = { Lo, Hi };
2311     return DAG.getMergeValues(Ops, dl);
2312   }
2313   else {
2314     // {dHi, dLo} = {aHi, aLo} >> Amt
2315     // - if (Amt>=size) then
2316     //      dLo = aHi >> (Amt-size)
2317     //      dHi = aHi >> Amt (this is either all 0 or all 1)
2318     //   else
2319     //      dLo = (aLo >>logic Amt) | (aHi << (size-Amt))
2320     //      dHi = aHi >> Amt
2321 
2322     SDValue RevShAmt = DAG.getNode(ISD::SUB, dl, MVT::i32,
2323                                    DAG.getConstant(VTBits, dl, MVT::i32),
2324                                    ShAmt);
2325     SDValue Tmp1 = DAG.getNode(ISD::SRL, dl, VT, ShOpLo, ShAmt);
2326     SDValue ExtraShAmt = DAG.getNode(ISD::SUB, dl, MVT::i32, ShAmt,
2327                                      DAG.getConstant(VTBits, dl, MVT::i32));
2328     SDValue Tmp2 = DAG.getNode(ISD::SHL, dl, VT, ShOpHi, RevShAmt);
2329     SDValue FalseVal = DAG.getNode(ISD::OR, dl, VT, Tmp1, Tmp2);
2330     SDValue TrueVal = DAG.getNode(Opc, dl, VT, ShOpHi, ExtraShAmt);
2331 
2332     SDValue Cmp = DAG.getSetCC(dl, MVT::i1, ShAmt,
2333                                DAG.getConstant(VTBits, dl, MVT::i32),
2334                                ISD::SETGE);
2335     SDValue Hi = DAG.getNode(Opc, dl, VT, ShOpHi, ShAmt);
2336     SDValue Lo = DAG.getNode(ISD::SELECT, dl, VT, Cmp, TrueVal, FalseVal);
2337 
2338     SDValue Ops[2] = { Lo, Hi };
2339     return DAG.getMergeValues(Ops, dl);
2340   }
2341 }
2342 
2343 /// LowerShiftLeftParts - Lower SHL_PARTS, which
2344 /// 1) returns two i32 values and takes a 2 x i32 value to shift plus a shift
2345 ///    amount, or
2346 /// 2) returns two i64 values and takes a 2 x i64 value to shift plus a shift
2347 ///    amount.
2348 SDValue NVPTXTargetLowering::LowerShiftLeftParts(SDValue Op,
2349                                                  SelectionDAG &DAG) const {
2350   assert(Op.getNumOperands() == 3 && "Not a double-shift!");
2351   assert(Op.getOpcode() == ISD::SHL_PARTS);
2352 
2353   EVT VT = Op.getValueType();
2354   unsigned VTBits = VT.getSizeInBits();
2355   SDLoc dl(Op);
2356   SDValue ShOpLo = Op.getOperand(0);
2357   SDValue ShOpHi = Op.getOperand(1);
2358   SDValue ShAmt  = Op.getOperand(2);
2359 
2360   if (VTBits == 32 && STI.getSmVersion() >= 35) {
2361     // For 32bit and sm35, we can use the funnel shift 'shf' instruction.
2362     // {dHi, dLo} = {aHi, aLo} << Amt
2363     //   dHi = shf.l.clamp aLo, aHi, Amt
2364     //   dLo = aLo << Amt
2365 
2366     SDValue Hi =
2367         DAG.getNode(NVPTXISD::FSHL_CLAMP, dl, VT, ShOpHi, ShOpLo, ShAmt);
2368     SDValue Lo = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ShAmt);
2369 
2370     SDValue Ops[2] = { Lo, Hi };
2371     return DAG.getMergeValues(Ops, dl);
2372   }
2373   else {
2374     // {dHi, dLo} = {aHi, aLo} << Amt
2375     // - if (Amt>=size) then
2376     //      dLo = aLo << Amt (all 0)
2377     //      dHi = aLo << (Amt-size)
2378     //   else
2379     //      dLo = aLo << Amt
2380     //      dHi = (aHi << Amt) | (aLo >> (size-Amt))
2381 
2382     SDValue RevShAmt = DAG.getNode(ISD::SUB, dl, MVT::i32,
2383                                    DAG.getConstant(VTBits, dl, MVT::i32),
2384                                    ShAmt);
2385     SDValue Tmp1 = DAG.getNode(ISD::SHL, dl, VT, ShOpHi, ShAmt);
2386     SDValue ExtraShAmt = DAG.getNode(ISD::SUB, dl, MVT::i32, ShAmt,
2387                                      DAG.getConstant(VTBits, dl, MVT::i32));
2388     SDValue Tmp2 = DAG.getNode(ISD::SRL, dl, VT, ShOpLo, RevShAmt);
2389     SDValue FalseVal = DAG.getNode(ISD::OR, dl, VT, Tmp1, Tmp2);
2390     SDValue TrueVal = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ExtraShAmt);
2391 
2392     SDValue Cmp = DAG.getSetCC(dl, MVT::i1, ShAmt,
2393                                DAG.getConstant(VTBits, dl, MVT::i32),
2394                                ISD::SETGE);
2395     SDValue Lo = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ShAmt);
2396     SDValue Hi = DAG.getNode(ISD::SELECT, dl, VT, Cmp, TrueVal, FalseVal);
2397 
2398     SDValue Ops[2] = { Lo, Hi };
2399     return DAG.getMergeValues(Ops, dl);
2400   }
2401 }
2402 
2403 /// If the types match, convert the generic copysign to the NVPTXISD version,
2404 /// otherwise bail, ensuring that mismatched cases are properly expanded.
2405 SDValue NVPTXTargetLowering::LowerFCOPYSIGN(SDValue Op,
2406                                             SelectionDAG &DAG) const {
2407   EVT VT = Op.getValueType();
2408   SDLoc DL(Op);
2409 
2410   SDValue In1 = Op.getOperand(0);
2411   SDValue In2 = Op.getOperand(1);
2412   EVT SrcVT = In2.getValueType();
2413 
2414   if (!SrcVT.bitsEq(VT))
2415     return SDValue();
2416 
2417   return DAG.getNode(NVPTXISD::FCOPYSIGN, DL, VT, In1, In2);
2418 }
2419 
2420 SDValue NVPTXTargetLowering::LowerFROUND(SDValue Op, SelectionDAG &DAG) const {
2421   EVT VT = Op.getValueType();
2422 
2423   if (VT == MVT::f32)
2424     return LowerFROUND32(Op, DAG);
2425 
2426   if (VT == MVT::f64)
2427     return LowerFROUND64(Op, DAG);
2428 
2429   llvm_unreachable("unhandled type");
2430 }
2431 
2432 // This is the rounding method used in CUDA libdevice, in C-like code:
2433 // float roundf(float A)
2434 // {
2435 //   float RoundedA = (float) (int) ( A > 0 ? (A + 0.5f) : (A - 0.5f));
2436 //   RoundedA = abs(A) > 0x1.0p23 ? A : RoundedA;
2437 //   return abs(A) < 0.5 ? (float)(int)A : RoundedA;
2438 // }
2439 SDValue NVPTXTargetLowering::LowerFROUND32(SDValue Op,
2440                                            SelectionDAG &DAG) const {
2441   SDLoc SL(Op);
2442   SDValue A = Op.getOperand(0);
2443   EVT VT = Op.getValueType();
2444 
2445   SDValue AbsA = DAG.getNode(ISD::FABS, SL, VT, A);
2446 
2447   // RoundedA = (float) (int) ( A > 0 ? (A + 0.5f) : (A - 0.5f))
2448   SDValue Bitcast  = DAG.getNode(ISD::BITCAST, SL, MVT::i32, A);
2449   const unsigned SignBitMask = 0x80000000;
2450   SDValue Sign = DAG.getNode(ISD::AND, SL, MVT::i32, Bitcast,
2451                              DAG.getConstant(SignBitMask, SL, MVT::i32));
2452   const unsigned PointFiveInBits = 0x3F000000;
2453   SDValue PointFiveWithSignRaw =
2454       DAG.getNode(ISD::OR, SL, MVT::i32, Sign,
2455                   DAG.getConstant(PointFiveInBits, SL, MVT::i32));
2456   SDValue PointFiveWithSign =
2457       DAG.getNode(ISD::BITCAST, SL, VT, PointFiveWithSignRaw);
2458   SDValue AdjustedA = DAG.getNode(ISD::FADD, SL, VT, A, PointFiveWithSign);
2459   SDValue RoundedA = DAG.getNode(ISD::FTRUNC, SL, VT, AdjustedA);
2460 
2461   // RoundedA = abs(A) > 0x1.0p23 ? A : RoundedA;
2462   EVT SetCCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
2463   SDValue IsLarge =
2464       DAG.getSetCC(SL, SetCCVT, AbsA, DAG.getConstantFP(pow(2.0, 23.0), SL, VT),
2465                    ISD::SETOGT);
2466   RoundedA = DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA);
2467 
2468   // return abs(A) < 0.5 ? (float)(int)A : RoundedA;
2469   SDValue IsSmall =DAG.getSetCC(SL, SetCCVT, AbsA,
2470                                 DAG.getConstantFP(0.5, SL, VT), ISD::SETOLT);
2471   SDValue RoundedAForSmallA = DAG.getNode(ISD::FTRUNC, SL, VT, A);
2472   return DAG.getNode(ISD::SELECT, SL, VT, IsSmall, RoundedAForSmallA, RoundedA);
2473 }
2474 
2475 // The implementation of round(double) is similar to that of round(float) in
2476 // that they both separate the value range into three regions and use a method
2477 // specific to the region to round the values. However, round(double) first
2478 // calculates the round of the absolute value and then adds the sign back while
2479 // round(float) directly rounds the value with sign.
2480 SDValue NVPTXTargetLowering::LowerFROUND64(SDValue Op,
2481                                            SelectionDAG &DAG) const {
2482   SDLoc SL(Op);
2483   SDValue A = Op.getOperand(0);
2484   EVT VT = Op.getValueType();
2485 
2486   SDValue AbsA = DAG.getNode(ISD::FABS, SL, VT, A);
2487 
2488   // double RoundedA = (double) (int) (abs(A) + 0.5f);
2489   SDValue AdjustedA = DAG.getNode(ISD::FADD, SL, VT, AbsA,
2490                                   DAG.getConstantFP(0.5, SL, VT));
2491   SDValue RoundedA = DAG.getNode(ISD::FTRUNC, SL, VT, AdjustedA);
2492 
2493   // RoundedA = abs(A) < 0.5 ? (double)0 : RoundedA;
2494   EVT SetCCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
2495   SDValue IsSmall =DAG.getSetCC(SL, SetCCVT, AbsA,
2496                                 DAG.getConstantFP(0.5, SL, VT), ISD::SETOLT);
2497   RoundedA = DAG.getNode(ISD::SELECT, SL, VT, IsSmall,
2498                          DAG.getConstantFP(0, SL, VT),
2499                          RoundedA);
2500 
2501   // Add sign to rounded_A
2502   RoundedA = DAG.getNode(ISD::FCOPYSIGN, SL, VT, RoundedA, A);
2503   DAG.getNode(ISD::FTRUNC, SL, VT, A);
2504 
2505   // RoundedA = abs(A) > 0x1.0p52 ? A : RoundedA;
2506   SDValue IsLarge =
2507       DAG.getSetCC(SL, SetCCVT, AbsA, DAG.getConstantFP(pow(2.0, 52.0), SL, VT),
2508                    ISD::SETOGT);
2509   return DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA);
2510 }
2511 
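// Promote an f16/bf16 binary op (scalar or vector) to f32, perform the
// operation in f32, and round the result back to the original type.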
2512 static SDValue PromoteBinOpToF32(SDNode *N, SelectionDAG &DAG) {
2513   EVT VT = N->getValueType(0);
2514   EVT NVT = MVT::f32;
2515   if (VT.isVector()) {
2516     NVT = EVT::getVectorVT(*DAG.getContext(), NVT, VT.getVectorElementCount());
2517   }
2518   SDLoc DL(N);
2519   SDValue Tmp0 = DAG.getFPExtendOrRound(N->getOperand(0), DL, NVT);
2520   SDValue Tmp1 = DAG.getFPExtendOrRound(N->getOperand(1), DL, NVT);
2521   SDValue Res = DAG.getNode(N->getOpcode(), DL, NVT, Tmp0, Tmp1, N->getFlags());
2522   return DAG.getFPExtendOrRound(Res, DL, VT);
2523 }
2524 
2525 SDValue NVPTXTargetLowering::PromoteBinOpIfF32FTZ(SDValue Op,
2526                                                   SelectionDAG &DAG) const {
2527   if (useF32FTZ(DAG.getMachineFunction())) {
2528     return PromoteBinOpToF32(Op.getNode(), DAG);
2529   }
2530   return Op;
2531 }
2532 
2533 SDValue NVPTXTargetLowering::LowerINT_TO_FP(SDValue Op,
2534                                             SelectionDAG &DAG) const {
2535   assert(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78);
2536 
2537   if (Op.getValueType() == MVT::bf16) {
2538     SDLoc Loc(Op);
2539     return DAG.getNode(
2540         ISD::FP_ROUND, Loc, MVT::bf16,
2541         DAG.getNode(Op.getOpcode(), Loc, MVT::f32, Op.getOperand(0)),
2542         DAG.getIntPtrConstant(0, Loc, /*isTarget=*/true));
2543   }
2544 
2545   // Everything else is considered legal.
2546   return Op;
2547 }
2548 
2549 SDValue NVPTXTargetLowering::LowerFP_TO_INT(SDValue Op,
2550                                             SelectionDAG &DAG) const {
2551   assert(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78);
2552 
2553   if (Op.getOperand(0).getValueType() == MVT::bf16) {
2554     SDLoc Loc(Op);
2555     return DAG.getNode(
2556         Op.getOpcode(), Loc, Op.getValueType(),
2557         DAG.getNode(ISD::FP_EXTEND, Loc, MVT::f32, Op.getOperand(0)));
2558   }
2559 
2560   // Everything else is considered legal.
2561   return Op;
2562 }
2563 
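// bf16 rounding support depends on the target: f32 -> bf16 needs sm_80 and
// PTX 7.0, while f64 -> bf16 needs sm_90 and PTX 7.8. Older combinations are
// expanded in software, and f64 sources on the intermediate targets are first
// rounded (inexact-to-odd) to f32 before the hardware f32 -> bf16 conversion.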
2564 SDValue NVPTXTargetLowering::LowerFP_ROUND(SDValue Op,
2565                                            SelectionDAG &DAG) const {
2566   EVT NarrowVT = Op.getValueType();
2567   SDValue Wide = Op.getOperand(0);
2568   EVT WideVT = Wide.getValueType();
2569   if (NarrowVT.getScalarType() == MVT::bf16) {
2570     const TargetLowering *TLI = STI.getTargetLowering();
2571     if (STI.getSmVersion() < 80 || STI.getPTXVersion() < 70) {
2572       return TLI->expandFP_ROUND(Op.getNode(), DAG);
2573     }
2574     if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
2575       // This combination was the first to support f32 -> bf16.
2576       if (STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70) {
2577         if (WideVT.getScalarType() == MVT::f32) {
2578           return Op;
2579         }
2580         if (WideVT.getScalarType() == MVT::f64) {
2581           SDLoc Loc(Op);
2582           // Round-inexact-to-odd f64 to f32, then do the final rounding using
2583           // the hardware f32 -> bf16 instruction.
2584           SDValue rod = TLI->expandRoundInexactToOdd(
2585               WideVT.isVector() ? WideVT.changeVectorElementType(MVT::f32)
2586                                 : MVT::f32,
2587               Wide, Loc, DAG);
2588           return DAG.getFPExtendOrRound(rod, Loc, NarrowVT);
2589         }
2590       }
2591       return TLI->expandFP_ROUND(Op.getNode(), DAG);
2592     }
2593   }
2594 
2595   // Everything else is considered legal.
2596   return Op;
2597 }
2598 
2599 SDValue NVPTXTargetLowering::LowerFP_EXTEND(SDValue Op,
2600                                             SelectionDAG &DAG) const {
2601   SDValue Narrow = Op.getOperand(0);
2602   EVT NarrowVT = Narrow.getValueType();
2603   EVT WideVT = Op.getValueType();
2604   if (NarrowVT.getScalarType() == MVT::bf16) {
2605     if (WideVT.getScalarType() == MVT::f32 &&
2606         (STI.getSmVersion() < 80 || STI.getPTXVersion() < 71)) {
2607       SDLoc Loc(Op);
2608       return DAG.getNode(ISD::BF16_TO_FP, Loc, WideVT, Narrow);
2609     }
2610     if (WideVT.getScalarType() == MVT::f64 &&
2611         (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78)) {
2612       EVT F32 = NarrowVT.isVector() ? NarrowVT.changeVectorElementType(MVT::f32)
2613                                     : MVT::f32;
2614       SDLoc Loc(Op);
2615       if (STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 71) {
2616         Op = DAG.getNode(ISD::FP_EXTEND, Loc, F32, Narrow);
2617       } else {
2618         Op = DAG.getNode(ISD::BF16_TO_FP, Loc, F32, Narrow);
2619       }
2620       return DAG.getNode(ISD::FP_EXTEND, Loc, WideVT, Op);
2621     }
2622   }
2623 
2624   // Everything else is considered legal.
2625   return Op;
2626 }
2627 
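// Scalarize v2i16 arithmetic that has no direct PTX support: extract both
// lanes, apply the scalar operation to each, and rebuild the result with a
// BUILD_VECTOR. Any other type is returned unchanged.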
2628 static SDValue LowerVectorArith(SDValue Op, SelectionDAG &DAG) {
2629   SDLoc DL(Op);
2630   if (Op.getValueType() != MVT::v2i16)
2631     return Op;
2632   EVT EltVT = Op.getValueType().getVectorElementType();
2633   SmallVector<SDValue> VecElements;
2634   for (int I = 0, E = Op.getValueType().getVectorNumElements(); I < E; I++) {
2635     SmallVector<SDValue> ScalarArgs;
2636     llvm::transform(Op->ops(), std::back_inserter(ScalarArgs),
2637                     [&](const SDUse &O) {
2638                       return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT,
2639                                          O.get(), DAG.getIntPtrConstant(I, DL));
2640                     });
2641     VecElements.push_back(DAG.getNode(Op.getOpcode(), DL, EltVT, ScalarArgs));
2642   }
2643   SDValue V =
2644       DAG.getNode(ISD::BUILD_VECTOR, DL, Op.getValueType(), VecElements);
2645   return V;
2646 }
2647 
2648 SDValue
2649 NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
2650   switch (Op.getOpcode()) {
2651   case ISD::RETURNADDR:
2652     return SDValue();
2653   case ISD::FRAMEADDR:
2654     return SDValue();
2655   case ISD::GlobalAddress:
2656     return LowerGlobalAddress(Op, DAG);
2657   case ISD::INTRINSIC_W_CHAIN:
2658     return Op;
2659   case ISD::BUILD_VECTOR:
2660     return LowerBUILD_VECTOR(Op, DAG);
2661   case ISD::BITCAST:
2662     return LowerBITCAST(Op, DAG);
2663   case ISD::EXTRACT_SUBVECTOR:
2664     return Op;
2665   case ISD::EXTRACT_VECTOR_ELT:
2666     return LowerEXTRACT_VECTOR_ELT(Op, DAG);
2667   case ISD::INSERT_VECTOR_ELT:
2668     return LowerINSERT_VECTOR_ELT(Op, DAG);
2669   case ISD::VECTOR_SHUFFLE:
2670     return LowerVECTOR_SHUFFLE(Op, DAG);
2671   case ISD::CONCAT_VECTORS:
2672     return LowerCONCAT_VECTORS(Op, DAG);
2673   case ISD::STORE:
2674     return LowerSTORE(Op, DAG);
2675   case ISD::LOAD:
2676     return LowerLOAD(Op, DAG);
2677   case ISD::SHL_PARTS:
2678     return LowerShiftLeftParts(Op, DAG);
2679   case ISD::SRA_PARTS:
2680   case ISD::SRL_PARTS:
2681     return LowerShiftRightParts(Op, DAG);
2682   case ISD::SELECT:
2683     return LowerSelect(Op, DAG);
2684   case ISD::FROUND:
2685     return LowerFROUND(Op, DAG);
2686   case ISD::FCOPYSIGN:
2687     return LowerFCOPYSIGN(Op, DAG);
2688   case ISD::SINT_TO_FP:
2689   case ISD::UINT_TO_FP:
2690     return LowerINT_TO_FP(Op, DAG);
2691   case ISD::FP_TO_SINT:
2692   case ISD::FP_TO_UINT:
2693     return LowerFP_TO_INT(Op, DAG);
2694   case ISD::FP_ROUND:
2695     return LowerFP_ROUND(Op, DAG);
2696   case ISD::FP_EXTEND:
2697     return LowerFP_EXTEND(Op, DAG);
2698   case ISD::BR_JT:
2699     return LowerBR_JT(Op, DAG);
2700   case ISD::VAARG:
2701     return LowerVAARG(Op, DAG);
2702   case ISD::VASTART:
2703     return LowerVASTART(Op, DAG);
2704   case ISD::ABS:
2705   case ISD::SMIN:
2706   case ISD::SMAX:
2707   case ISD::UMIN:
2708   case ISD::UMAX:
2709   case ISD::ADD:
2710   case ISD::SUB:
2711   case ISD::MUL:
2712   case ISD::SHL:
2713   case ISD::SREM:
2714   case ISD::UREM:
2715     return LowerVectorArith(Op, DAG);
2716   case ISD::DYNAMIC_STACKALLOC:
2717     return LowerDYNAMIC_STACKALLOC(Op, DAG);
2718   case ISD::STACKRESTORE:
2719     return LowerSTACKRESTORE(Op, DAG);
2720   case ISD::STACKSAVE:
2721     return LowerSTACKSAVE(Op, DAG);
2722   case ISD::CopyToReg:
2723     return LowerCopyToReg_128(Op, DAG);
2724   case ISD::FADD:
2725   case ISD::FSUB:
2726   case ISD::FMUL:
2727     // Used only for bf16 on SM80, where we select fma for non-ftz operations.
2728     return PromoteBinOpIfF32FTZ(Op, DAG);
2729 
2730   default:
2731     llvm_unreachable("Custom lowering not defined for operation");
2732   }
2733 }
2734 
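// Lower BR_JT into the target's inline jump-table sequence: a BrxStart node,
// one BrxItem node per target block except the last, and a BrxEnd node that
// carries the last block, the index operand, and the jump-table id.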
2735 SDValue NVPTXTargetLowering::LowerBR_JT(SDValue Op, SelectionDAG &DAG) const {
2736   SDLoc DL(Op);
2737   SDValue Chain = Op.getOperand(0);
2738   const auto *JT = cast<JumpTableSDNode>(Op.getOperand(1));
2739   SDValue Index = Op.getOperand(2);
2740 
2741   unsigned JId = JT->getIndex();
2742   MachineJumpTableInfo *MJTI = DAG.getMachineFunction().getJumpTableInfo();
2743   ArrayRef<MachineBasicBlock *> MBBs = MJTI->getJumpTables()[JId].MBBs;
2744 
2745   SDValue IdV = DAG.getConstant(JId, DL, MVT::i32);
2746 
2747   // Generate BrxStart node
2748   SDVTList VTs = DAG.getVTList(MVT::Other, MVT::Glue);
2749   Chain = DAG.getNode(NVPTXISD::BrxStart, DL, VTs, Chain, IdV);
2750 
2751   // Generate BrxItem nodes
2752   assert(!MBBs.empty());
2753   for (MachineBasicBlock *MBB : MBBs.drop_back())
2754     Chain = DAG.getNode(NVPTXISD::BrxItem, DL, VTs, Chain.getValue(0),
2755                         DAG.getBasicBlock(MBB), Chain.getValue(1));
2756 
2757   // Generate the BrxEnd node
2758   SDValue EndOps[] = {Chain.getValue(0), DAG.getBasicBlock(MBBs.back()), Index,
2759                       IdV, Chain.getValue(1)};
2760   SDValue BrxEnd = DAG.getNode(NVPTXISD::BrxEnd, DL, VTs, EndOps);
2761 
2762   return BrxEnd;
2763 }
2764 
2765 // This prevents the AsmPrinter from trying to print the jump tables itself.
2766 unsigned NVPTXTargetLowering::getJumpTableEncoding() const {
2767   return MachineJumpTableInfo::EK_Inline;
2768 }
2769 
2770 // This function is almost a copy of SelectionDAG::expandVAArg(). The only
2771 // difference is that this one produces loads from the local address space.
2772 SDValue NVPTXTargetLowering::LowerVAARG(SDValue Op, SelectionDAG &DAG) const {
2773   const TargetLowering *TLI = STI.getTargetLowering();
2774   SDLoc DL(Op);
2775 
2776   SDNode *Node = Op.getNode();
2777   const Value *V = cast<SrcValueSDNode>(Node->getOperand(2))->getValue();
2778   EVT VT = Node->getValueType(0);
2779   auto *Ty = VT.getTypeForEVT(*DAG.getContext());
2780   SDValue Tmp1 = Node->getOperand(0);
2781   SDValue Tmp2 = Node->getOperand(1);
2782   const MaybeAlign MA(Node->getConstantOperandVal(3));
2783 
2784   SDValue VAListLoad = DAG.getLoad(TLI->getPointerTy(DAG.getDataLayout()), DL,
2785                                    Tmp1, Tmp2, MachinePointerInfo(V));
2786   SDValue VAList = VAListLoad;
2787 
2788   if (MA && *MA > TLI->getMinStackArgumentAlignment()) {
2789     VAList = DAG.getNode(
2790         ISD::ADD, DL, VAList.getValueType(), VAList,
2791         DAG.getConstant(MA->value() - 1, DL, VAList.getValueType()));
2792 
2793     VAList = DAG.getNode(ISD::AND, DL, VAList.getValueType(), VAList,
2794                          DAG.getSignedConstant(-(int64_t)MA->value(), DL,
2795                                                VAList.getValueType()));
2796   }
2797 
2798   // Increment the pointer, VAList, to the next vaarg
2799   Tmp1 = DAG.getNode(ISD::ADD, DL, VAList.getValueType(), VAList,
2800                      DAG.getConstant(DAG.getDataLayout().getTypeAllocSize(Ty),
2801                                      DL, VAList.getValueType()));
2802 
2803   // Store the incremented VAList to the legalized pointer
2804   Tmp1 = DAG.getStore(VAListLoad.getValue(1), DL, Tmp1, Tmp2,
2805                       MachinePointerInfo(V));
2806 
2807   const Value *SrcV = Constant::getNullValue(
2808       PointerType::get(*DAG.getContext(), ADDRESS_SPACE_LOCAL));
2809 
2810   // Load the actual argument out of the pointer VAList
2811   return DAG.getLoad(VT, DL, Tmp1, VAList, MachinePointerInfo(SrcV));
2812 }
2813 
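// va_start: store the address of the per-function <name>_vararg[] array (the
// special parameter with index -1, see getParamSymbol) into the va_list.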
2814 SDValue NVPTXTargetLowering::LowerVASTART(SDValue Op, SelectionDAG &DAG) const {
2815   const TargetLowering *TLI = STI.getTargetLowering();
2816   SDLoc DL(Op);
2817   EVT PtrVT = TLI->getPointerTy(DAG.getDataLayout());
2818 
2819   // Store the address of unsized array <function>_vararg[] in the ap object.
2820   SDValue Arg = getParamSymbol(DAG, /* vararg */ -1, PtrVT);
2821   SDValue VAReg = DAG.getNode(NVPTXISD::Wrapper, DL, PtrVT, Arg);
2822 
2823   const Value *SV = cast<SrcValueSDNode>(Op.getOperand(2))->getValue();
2824   return DAG.getStore(Op.getOperand(0), DL, VAReg, Op.getOperand(1),
2825                       MachinePointerInfo(SV));
2826 }
2827 
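// i1 select is not directly supported, so widen it to an i32 select:
//   r = select i1 c, i1 a, i1 b
//     =>
//   a32 = anyext a to i32;  b32 = anyext b to i32
//   r = trunc (select i1 c, i32 a32, i32 b32) to i1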
2828 SDValue NVPTXTargetLowering::LowerSelect(SDValue Op, SelectionDAG &DAG) const {
2829   SDValue Op0 = Op->getOperand(0);
2830   SDValue Op1 = Op->getOperand(1);
2831   SDValue Op2 = Op->getOperand(2);
2832   SDLoc DL(Op.getNode());
2833 
2834   assert(Op.getValueType() == MVT::i1 && "Custom lowering enabled only for i1");
2835 
2836   Op1 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Op1);
2837   Op2 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Op2);
2838   SDValue Select = DAG.getNode(ISD::SELECT, DL, MVT::i32, Op0, Op1, Op2);
2839   SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, Select);
2840 
2841   return Trunc;
2842 }
2843 
2844 SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
2845   if (Op.getValueType() == MVT::i1)
2846     return LowerLOADi1(Op, DAG);
2847 
2848   // v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on the legalizer to
2849   // handle unaligned loads and have to handle them here.
2850   EVT VT = Op.getValueType();
2851   if (Isv2x16VT(VT) || VT == MVT::v4i8) {
2852     LoadSDNode *Load = cast<LoadSDNode>(Op);
2853     EVT MemVT = Load->getMemoryVT();
2854     if (!allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
2855                                         MemVT, *Load->getMemOperand())) {
2856       SDValue Ops[2];
2857       std::tie(Ops[0], Ops[1]) = expandUnalignedLoad(Load, DAG);
2858       return DAG.getMergeValues(Ops, SDLoc(Op));
2859     }
2860   }
2861 
2862   return SDValue();
2863 }
2864 
2865 // v = ld i1* addr
2866 //   =>
2867 // v1 = ld i8* addr (-> i16)
2868 // v = trunc i16 to i1
2869 SDValue NVPTXTargetLowering::LowerLOADi1(SDValue Op, SelectionDAG &DAG) const {
2870   SDNode *Node = Op.getNode();
2871   LoadSDNode *LD = cast<LoadSDNode>(Node);
2872   SDLoc dl(Node);
2873   assert(LD->getExtensionType() == ISD::NON_EXTLOAD);
2874   assert(Node->getValueType(0) == MVT::i1 &&
2875          "Custom lowering for i1 load only");
2876   SDValue newLD = DAG.getExtLoad(ISD::ZEXTLOAD, dl, MVT::i16, LD->getChain(),
2877                                  LD->getBasePtr(), LD->getPointerInfo(),
2878                                  MVT::i8, LD->getAlign(),
2879                                  LD->getMemOperand()->getFlags());
2880   SDValue result = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, newLD);
2881   // The legalizer (the caller) is expecting two values from the legalized
2882   // load, so we build a MergeValues node for it. See ExpandUnalignedLoad()
2883   // in LegalizeDAG.cpp which also uses MergeValues.
2884   SDValue Ops[] = { result, LD->getChain() };
2885   return DAG.getMergeValues(Ops, dl);
2886 }
2887 
2888 SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
2889   StoreSDNode *Store = cast<StoreSDNode>(Op);
2890   EVT VT = Store->getMemoryVT();
2891 
2892   if (VT == MVT::i1)
2893     return LowerSTOREi1(Op, DAG);
2894 
2895   // v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on the legalizer to
2896   // handle unaligned stores and have to handle them here.
2897   if ((Isv2x16VT(VT) || VT == MVT::v4i8) &&
2898       !allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
2899                                       VT, *Store->getMemOperand()))
2900     return expandUnalignedStore(Store, DAG);
2901 
2902   // v2f16, v2bf16, v2i16 and v4i8 don't need any further special handling.
2903   if (Isv2x16VT(VT) || VT == MVT::v4i8)
2904     return SDValue();
2905 
2906   if (VT.isVector())
2907     return LowerSTOREVector(Op, DAG);
2908 
2909   return SDValue();
2910 }
2911 
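// Lower a sufficiently aligned vector store to a single NVPTXISD::StoreV2 or
// StoreV4 memory intrinsic. Sub-16-bit elements are any-extended to i16, and
// vectors with more elements have them packed into v2x16/v4i8 subvectors that
// are stored as b32 (see getVectorLoweringShape). Unsupported or under-aligned
// cases return SDValue() so the generic legalizer can split or scalarize them.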
2912 SDValue
2913 NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
2914   SDNode *N = Op.getNode();
2915   SDValue Val = N->getOperand(1);
2916   SDLoc DL(N);
2917   EVT ValVT = Val.getValueType();
2918 
2919   auto NumEltsAndEltVT = getVectorLoweringShape(ValVT);
2920   if (!NumEltsAndEltVT)
2921     return SDValue();
2922   auto [NumElts, EltVT] = NumEltsAndEltVT.value();
2923 
2924   MemSDNode *MemSD = cast<MemSDNode>(N);
2925   const DataLayout &TD = DAG.getDataLayout();
2926 
2927   Align Alignment = MemSD->getAlign();
2928   Align PrefAlign = TD.getPrefTypeAlign(ValVT.getTypeForEVT(*DAG.getContext()));
2929   if (Alignment < PrefAlign) {
2930     // This store is not sufficiently aligned, so bail out and let this vector
2931     // store be scalarized.  Note that we may still be able to emit smaller
2932     // vector stores.  For example, if we are storing a <4 x float> with an
2933     // alignment of 8, this check will fail but the legalizer will try again
2934     // with 2 x <2 x float>, which will succeed with an alignment of 8.
2935     return SDValue();
2936   }
2937 
2938   // Since StoreV2 is a target node, we cannot rely on DAG type legalization.
2939   // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
2940   // stored type to i16 and propagate the "real" type as the memory type.
2941   bool NeedExt = false;
2942   if (EltVT.getSizeInBits() < 16)
2943     NeedExt = true;
2944 
2945   unsigned Opcode = 0;
2946   switch (NumElts) {
2947   default:
2948     return SDValue();
2949   case 2:
2950     Opcode = NVPTXISD::StoreV2;
2951     break;
2952   case 4:
2953     Opcode = NVPTXISD::StoreV4;
2954     break;
2955   }
2956 
2957   SmallVector<SDValue, 8> Ops;
2958 
2959   // First is the chain
2960   Ops.push_back(N->getOperand(0));
2961 
2962   // Then the split values
2963   assert(NumElts <= ValVT.getVectorNumElements() &&
2964          "NumElts should not increase, only decrease or stay the same.");
2965   if (NumElts < ValVT.getVectorNumElements()) {
2966     // If the number of elements has decreased, getVectorLoweringShape has
2967     // upsized the element types
2968     assert(EltVT.isVector() && EltVT.getSizeInBits() == 32 &&
2969            EltVT.getVectorNumElements() <= 4 && "Unexpected upsized type.");
2970     // Combine individual elements into v2[i,f,bf]16/v4i8 subvectors to be
2971     // stored as b32s
2972     unsigned NumEltsPerSubVector = EltVT.getVectorNumElements();
2973     for (unsigned i = 0; i < NumElts; ++i) {
2974       SmallVector<SDValue, 4> SubVectorElts;
2975       DAG.ExtractVectorElements(Val, SubVectorElts, i * NumEltsPerSubVector,
2976                                 NumEltsPerSubVector);
2977       SDValue SubVector = DAG.getBuildVector(EltVT, DL, SubVectorElts);
2978       Ops.push_back(SubVector);
2979     }
2980   } else {
2981     for (unsigned i = 0; i < NumElts; ++i) {
2982       SDValue ExtVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
2983                                    DAG.getIntPtrConstant(i, DL));
2984       if (NeedExt)
2985         ExtVal = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i16, ExtVal);
2986       Ops.push_back(ExtVal);
2987     }
2988   }
2989 
2990   // Then any remaining arguments
2991   Ops.append(N->op_begin() + 2, N->op_end());
2992 
2993   SDValue NewSt =
2994       DAG.getMemIntrinsicNode(Opcode, DL, DAG.getVTList(MVT::Other), Ops,
2995                               MemSD->getMemoryVT(), MemSD->getMemOperand());
2996 
2998   return NewSt;
2999 }
3000 
3001 // st i1 v, addr
3002 //    =>
3003 // v1 = zxt v to i16
3004 // st.u8 i16, addr
3005 SDValue NVPTXTargetLowering::LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const {
3006   SDNode *Node = Op.getNode();
3007   SDLoc dl(Node);
3008   StoreSDNode *ST = cast<StoreSDNode>(Node);
3009   SDValue Tmp1 = ST->getChain();
3010   SDValue Tmp2 = ST->getBasePtr();
3011   SDValue Tmp3 = ST->getValue();
3012   assert(Tmp3.getValueType() == MVT::i1 && "Custom lowering for i1 store only");
3013   Tmp3 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Tmp3);
3014   SDValue Result =
3015       DAG.getTruncStore(Tmp1, dl, Tmp3, Tmp2, ST->getPointerInfo(), MVT::i8,
3016                         ST->getAlign(), ST->getMemOperand()->getFlags());
3017   return Result;
3018 }
3019 
3020 SDValue NVPTXTargetLowering::LowerCopyToReg_128(SDValue Op,
3021                                                 SelectionDAG &DAG) const {
3022   // Change the CopyToReg to take in two 64-bit operands instead of a 128-bit
3023   // operand so that it can pass the legalization.
3024 
3025   assert(Op.getOperand(1).getValueType() == MVT::i128 &&
3026          "Custom lowering for 128-bit CopyToReg only");
3027 
3028   SDNode *Node = Op.getNode();
3029   SDLoc DL(Node);
3030 
3031   SDValue Cast = DAG.getBitcast(MVT::v2i64, Op->getOperand(2));
3032   SDValue Lo = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i64, Cast,
3033                            DAG.getIntPtrConstant(0, DL));
3034   SDValue Hi = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i64, Cast,
3035                            DAG.getIntPtrConstant(1, DL));
3036 
3037   SmallVector<SDValue, 5> NewOps(Op->getNumOperands() + 1);
3038   SmallVector<EVT, 3> ResultsType(Node->values());
3039 
3040   NewOps[0] = Op->getOperand(0); // Chain
3041   NewOps[1] = Op->getOperand(1); // Dst Reg
3042   NewOps[2] = Lo;                // Lower 64-bit
3043   NewOps[3] = Hi;                // Higher 64-bit
3044   if (Op.getNumOperands() == 4)
3045     NewOps[4] = Op->getOperand(3); // Glue if exists
3046 
3047   return DAG.getNode(ISD::CopyToReg, DL, ResultsType, NewOps);
3048 }
3049 
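// An i128 value is kept in a single 128-bit register when the caller
// explicitly asks for i128 parts; everything else uses the default splitting.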
3050 unsigned NVPTXTargetLowering::getNumRegisters(
3051     LLVMContext &Context, EVT VT,
3052     std::optional<MVT> RegisterVT = std::nullopt) const {
3053   if (VT == MVT::i128 && RegisterVT == MVT::i128)
3054     return 1;
3055   return TargetLoweringBase::getNumRegisters(Context, VT, RegisterVT);
3056 }
3057 
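// Keep an i128 value whole when it is to be passed as a single part;
// returning false for all other cases falls back to the default splitting.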
3058 bool NVPTXTargetLowering::splitValueIntoRegisterParts(
3059     SelectionDAG &DAG, const SDLoc &DL, SDValue Val, SDValue *Parts,
3060     unsigned NumParts, MVT PartVT, std::optional<CallingConv::ID> CC) const {
3061   if (Val.getValueType() == MVT::i128 && NumParts == 1) {
3062     Parts[0] = Val;
3063     return true;
3064   }
3065   return false;
3066 }
3067 
3068 // This creates a target external symbol for a function parameter.
3069 // The name of the symbol is composed from its index and the function name.
3070 // A negative index corresponds to the special parameter (unsized array) used
3071 // for passing variable arguments.
3072 SDValue NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int idx,
3073                                             EVT v) const {
3074   StringRef SavedStr = nvTM->getStrPool().save(
3075       getParamName(&DAG.getMachineFunction().getFunction(), idx));
3076   return DAG.getTargetExternalSymbol(SavedStr.data(), v);
3077 }
3078 
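// Lower incoming formal arguments. Dead arguments become UNDEF nodes. Each
// non-byval argument is loaded from its .param symbol, vectorizing adjacent
// pieces when the parameter alignment allows it. Byval arguments are returned
// as MoveParam of the corresponding param symbol.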
3079 SDValue NVPTXTargetLowering::LowerFormalArguments(
3080     SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
3081     const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &dl,
3082     SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals) const {
3083   MachineFunction &MF = DAG.getMachineFunction();
3084   const DataLayout &DL = DAG.getDataLayout();
3085   auto PtrVT = getPointerTy(DAG.getDataLayout());
3086 
3087   const Function *F = &MF.getFunction();
3088   const AttributeList &PAL = F->getAttributes();
3089   const TargetLowering *TLI = STI.getTargetLowering();
3090 
3091   SDValue Root = DAG.getRoot();
3092   std::vector<SDValue> OutChains;
3093 
3094   bool isABI = (STI.getSmVersion() >= 20);
3095   assert(isABI && "Non-ABI compilation is not supported");
3096   if (!isABI)
3097     return Chain;
3098 
3099   std::vector<Type *> argTypes;
3100   std::vector<const Argument *> theArgs;
3101   for (const Argument &I : F->args()) {
3102     theArgs.push_back(&I);
3103     argTypes.push_back(I.getType());
3104   }
3105   // argTypes.size() (or theArgs.size()) and Ins.size() need not match.
3106   // Ins.size() will be larger
3107   //   * if there is an aggregate argument with multiple fields (each field
3108   //     showing up separately in Ins)
3109   //   * if there is a vector argument with more elements than the typical
3110   //     vector length (generally more than 4), where each vector element is
3111   //     individually present in Ins.
3112   // So a different index should be used for indexing into Ins.
3113   // See similar issue in LowerCall.
3114   unsigned InsIdx = 0;
3115 
3116   for (unsigned i = 0, e = theArgs.size(); i != e; ++i, ++InsIdx) {
3117     Type *Ty = argTypes[i];
3118 
3119     if (theArgs[i]->use_empty()) {
3120       // argument is dead
3121       if (IsTypePassedAsArray(Ty) && !Ty->isVectorTy()) {
3122         SmallVector<EVT, 16> vtparts;
3123 
3124         ComputePTXValueVTs(*this, DAG.getDataLayout(), Ty, vtparts);
3125         if (vtparts.empty())
3126           report_fatal_error("Empty parameter types are not supported");
3127 
3128         for (unsigned parti = 0, parte = vtparts.size(); parti != parte;
3129              ++parti) {
3130           InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
3131           ++InsIdx;
3132         }
3133         if (vtparts.size() > 0)
3134           --InsIdx;
3135         continue;
3136       }
3137       if (Ty->isVectorTy()) {
3138         EVT ObjectVT = getValueType(DL, Ty);
3139         unsigned NumRegs = TLI->getNumRegisters(F->getContext(), ObjectVT);
3140         for (unsigned parti = 0; parti < NumRegs; ++parti) {
3141           InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
3142           ++InsIdx;
3143         }
3144         if (NumRegs > 0)
3145           --InsIdx;
3146         continue;
3147       }
3148       InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
3149       continue;
3150     }
3151 
3152     // In the following cases, assign a node order of "i+1"
3153     // to newly created nodes. The SDNodes for params have to
3154     // appear in the same order as their order of appearance
3155     // in the original function. "i+1" holds that order.
3156     if (!PAL.hasParamAttr(i, Attribute::ByVal)) {
3157       bool aggregateIsPacked = false;
3158       if (StructType *STy = dyn_cast<StructType>(Ty))
3159         aggregateIsPacked = STy->isPacked();
3160 
3161       SmallVector<EVT, 16> VTs;
3162       SmallVector<uint64_t, 16> Offsets;
3163       ComputePTXValueVTs(*this, DL, Ty, VTs, &Offsets, 0);
3164       if (VTs.empty())
3165         report_fatal_error("Empty parameter types are not supported");
3166 
3167       Align ArgAlign = getFunctionArgumentAlignment(
3168           F, Ty, i + AttributeList::FirstArgIndex, DL);
3169       auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign);
3170 
3171       SDValue Arg = getParamSymbol(DAG, i, PtrVT);
3172       int VecIdx = -1; // Index of the first element of the current vector.
3173       for (unsigned parti = 0, parte = VTs.size(); parti != parte; ++parti) {
3174         if (VectorInfo[parti] & PVF_FIRST) {
3175           assert(VecIdx == -1 && "Orphaned vector.");
3176           VecIdx = parti;
3177         }
3178 
3179         // That's the last element of this store op.
3180         // That's the last element of this load op.
3181           unsigned NumElts = parti - VecIdx + 1;
3182           EVT EltVT = VTs[parti];
3183           // i1 is loaded/stored as i8.
3184           EVT LoadVT = EltVT;
3185           if (EltVT == MVT::i1)
3186             LoadVT = MVT::i8;
3187           else if (Isv2x16VT(EltVT) || EltVT == MVT::v4i8)
3188             // getLoad needs a vector type, but it can't handle
3189             // vectors which contain v2f16 or v2bf16 elements. So we must load
3190             // using i32 here and then bitcast back.
3191             LoadVT = MVT::i32;
3192 
3193           EVT VecVT = EVT::getVectorVT(F->getContext(), LoadVT, NumElts);
3194           SDValue VecAddr =
3195               DAG.getNode(ISD::ADD, dl, PtrVT, Arg,
3196                           DAG.getConstant(Offsets[VecIdx], dl, PtrVT));
3197           Value *srcValue = Constant::getNullValue(
3198               PointerType::get(F->getContext(), ADDRESS_SPACE_PARAM));
3199 
3200           const MaybeAlign PartAlign = [&]() -> MaybeAlign {
3201             if (aggregateIsPacked)
3202               return Align(1);
3203             if (NumElts != 1)
3204               return std::nullopt;
3205             Align PartAlign =
3206                 DL.getABITypeAlign(EltVT.getTypeForEVT(F->getContext()));
3207             return commonAlignment(PartAlign, Offsets[parti]);
3208           }();
3209           SDValue P = DAG.getLoad(VecVT, dl, Root, VecAddr,
3210                                   MachinePointerInfo(srcValue), PartAlign,
3211                                   MachineMemOperand::MODereferenceable |
3212                                       MachineMemOperand::MOInvariant);
3213           if (P.getNode())
3214             P.getNode()->setIROrder(i + 1);
3215           for (unsigned j = 0; j < NumElts; ++j) {
3216             SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, LoadVT, P,
3217                                       DAG.getIntPtrConstant(j, dl));
3218             // We've loaded i1 as an i8 and now must truncate it back to i1
3219             if (EltVT == MVT::i1)
3220               Elt = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Elt);
3221             // v2f16 was loaded as an i32. Now we must bitcast it back.
3222             else if (EltVT != LoadVT)
3223               Elt = DAG.getNode(ISD::BITCAST, dl, EltVT, Elt);
3224 
3225             // If a promoted integer type is used, truncate down to the original type.
3226             MVT PromotedVT;
3227             if (PromoteScalarIntegerPTX(EltVT, &PromotedVT)) {
3228               Elt = DAG.getNode(ISD::TRUNCATE, dl, EltVT, Elt);
3229             }
3230 
3231             // Extend the element if necessary (e.g. an i8 is loaded
3232             // into an i16 register)
3233             if (Ins[InsIdx].VT.isInteger() &&
3234                 Ins[InsIdx].VT.getFixedSizeInBits() >
3235                     LoadVT.getFixedSizeInBits()) {
3236               unsigned Extend = Ins[InsIdx].Flags.isSExt() ? ISD::SIGN_EXTEND
3237                                                            : ISD::ZERO_EXTEND;
3238               Elt = DAG.getNode(Extend, dl, Ins[InsIdx].VT, Elt);
3239             }
3240             InVals.push_back(Elt);
3241           }
3242 
3243           // Reset vector tracking state.
3244           VecIdx = -1;
3245         }
3246         ++InsIdx;
3247       }
3248       if (VTs.size() > 0)
3249         --InsIdx;
3250       continue;
3251     }
3252 
3253     // Param has the ByVal attribute.
3254     // Return MoveParam(param symbol).
3255     // Ideally, the param symbol could be returned directly,
3256     // but when the SDNode builder decides to use it in a CopyToReg(),
3257     // the machine instruction fails because a TargetExternalSymbol
3258     // (not lowered) is target dependent, and CopyToReg assumes
3259     // the source is lowered.
3260     EVT ObjectVT = getValueType(DL, Ty);
3261     assert(ObjectVT == Ins[InsIdx].VT &&
3262            "Ins type did not match function type");
3263     SDValue Arg = getParamSymbol(DAG, i, PtrVT);
3264     SDValue p = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
3265     if (p.getNode())
3266       p.getNode()->setIROrder(i + 1);
3267     InVals.push_back(p);
3268   }
3269 
3270   if (!OutChains.empty())
3271     DAG.setRoot(DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChains));
3272 
3273   return Chain;
3274 }
3275 
3276 // Use byte stores when the param address of the return value is unaligned.
3277 // This may happen when the return value is a field of a packed structure.
3278 static SDValue LowerUnalignedStoreRet(SelectionDAG &DAG, SDValue Chain,
3279                                       uint64_t Offset, EVT ElementType,
3280                                       SDValue RetVal, const SDLoc &dl) {
3281   // Bit logic only works on integer types
3282   if (adjustElementType(ElementType))
3283     RetVal = DAG.getNode(ISD::BITCAST, dl, ElementType, RetVal);
3284 
3285   // Store each byte
3286   for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) {
3287     // Shift the byte to the last byte position
3288     SDValue ShiftVal = DAG.getNode(ISD::SRL, dl, ElementType, RetVal,
3289                                    DAG.getConstant(i * 8, dl, MVT::i32));
3290     SDValue StoreOperands[] = {Chain, DAG.getConstant(Offset + i, dl, MVT::i32),
3291                                ShiftVal};
3292     // Truncating store of only the last byte using
3293     //     st.param.b8
3294     // The register type can be larger than b8.
3295     Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetval, dl,
3296                                     DAG.getVTList(MVT::Other), StoreOperands,
3297                                     MVT::i8, MachinePointerInfo(), std::nullopt,
3298                                     MachineMemOperand::MOStore);
3299   }
3300   return Chain;
3301 }
3302 
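// Lower the return value into StoreRetval/StoreRetvalV2/StoreRetvalV4 nodes,
// extending sub-32-bit integer returns as the PTX ABI requires and falling
// back to byte-wise stores (LowerUnalignedStoreRet) for insufficiently
// aligned aggregate fields.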
3303 SDValue
3304 NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
3305                                  bool isVarArg,
3306                                  const SmallVectorImpl<ISD::OutputArg> &Outs,
3307                                  const SmallVectorImpl<SDValue> &OutVals,
3308                                  const SDLoc &dl, SelectionDAG &DAG) const {
3309   const MachineFunction &MF = DAG.getMachineFunction();
3310   const Function &F = MF.getFunction();
3311   Type *RetTy = MF.getFunction().getReturnType();
3312 
3313   bool isABI = (STI.getSmVersion() >= 20);
3314   assert(isABI && "Non-ABI compilation is not supported");
3315   if (!isABI)
3316     return Chain;
3317 
3318   const DataLayout &DL = DAG.getDataLayout();
3319   SmallVector<SDValue, 16> PromotedOutVals;
3320   SmallVector<EVT, 16> VTs;
3321   SmallVector<uint64_t, 16> Offsets;
3322   ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets);
3323   assert(VTs.size() == OutVals.size() && "Bad return value decomposition");
3324 
3325   for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
3326     SDValue PromotedOutVal = OutVals[i];
3327     MVT PromotedVT;
3328     if (PromoteScalarIntegerPTX(VTs[i], &PromotedVT)) {
3329       VTs[i] = EVT(PromotedVT);
3330     }
3331     if (PromoteScalarIntegerPTX(PromotedOutVal.getValueType(), &PromotedVT)) {
3332       llvm::ISD::NodeType Ext =
3333           Outs[i].Flags.isSExt() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
3334       PromotedOutVal = DAG.getNode(Ext, dl, PromotedVT, PromotedOutVal);
3335     }
3336     PromotedOutVals.push_back(PromotedOutVal);
3337   }
3338 
3339   auto VectorInfo = VectorizePTXValueVTs(
3340       VTs, Offsets,
3341       RetTy->isSized() ? getFunctionParamOptimizedAlign(&F, RetTy, DL)
3342                        : Align(1));
3343 
3344   // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
3345   // 32 bits are sign-extended or zero-extended, depending on whether
3346   // they are signed or unsigned types.
3347   bool ExtendIntegerRetVal =
3348       RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32;
3349 
3350   SmallVector<SDValue, 6> StoreOperands;
3351   for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
3352     SDValue OutVal = OutVals[i];
3353     SDValue RetVal = PromotedOutVals[i];
3354 
3355     if (ExtendIntegerRetVal) {
3356       RetVal = DAG.getNode(Outs[i].Flags.isSExt() ? ISD::SIGN_EXTEND
3357                                                   : ISD::ZERO_EXTEND,
3358                            dl, MVT::i32, RetVal);
3359     } else if (OutVal.getValueSizeInBits() < 16) {
3360       // Use 16-bit registers for small loads and stores, as that's the
3361       // smallest general-purpose register size supported by NVPTX.
3362       RetVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, RetVal);
3363     }
3364 
3365     // If we have a PVF_SCALAR entry, it may not even be sufficiently aligned
3366     // for a scalar store. In such cases, fall back to byte stores.
3367     if (VectorInfo[i] == PVF_SCALAR && RetTy->isAggregateType()) {
3368       EVT ElementType = ExtendIntegerRetVal ? MVT::i32 : VTs[i];
3369       Align ElementTypeAlign =
3370           DL.getABITypeAlign(ElementType.getTypeForEVT(RetTy->getContext()));
3371       Align ElementAlign =
3372           commonAlignment(DL.getABITypeAlign(RetTy), Offsets[i]);
3373       if (ElementAlign < ElementTypeAlign) {
3374         assert(StoreOperands.empty() && "Orphaned operand list.");
3375         Chain = LowerUnalignedStoreRet(DAG, Chain, Offsets[i], ElementType,
3376                                        RetVal, dl);
3377 
3378         // The call to LowerUnalignedStoreRet inserted the necessary SDAG nodes
3379         // into the graph, so just move on to the next element.
3380         continue;
3381       }
3382     }
3383 
3384     // New load/store. Record chain and offset operands.
3385     if (VectorInfo[i] & PVF_FIRST) {
3386       assert(StoreOperands.empty() && "Orphaned operand list.");
3387       StoreOperands.push_back(Chain);
3388       StoreOperands.push_back(DAG.getConstant(Offsets[i], dl, MVT::i32));
3389     }
3390 
3391     // Record the value to return.
3392     StoreOperands.push_back(RetVal);
3393 
3394     // That's the last element of this store op.
3395     if (VectorInfo[i] & PVF_LAST) {
3396       NVPTXISD::NodeType Op;
3397       unsigned NumElts = StoreOperands.size() - 2;
3398       switch (NumElts) {
3399       case 1:
3400         Op = NVPTXISD::StoreRetval;
3401         break;
3402       case 2:
3403         Op = NVPTXISD::StoreRetvalV2;
3404         break;
3405       case 4:
3406         Op = NVPTXISD::StoreRetvalV4;
3407         break;
3408       default:
3409         llvm_unreachable("Invalid vector info.");
3410       }
3411 
3412       // Adjust type of load/store op if we've extended the scalar
3413       // return value.
3414       EVT TheStoreType = ExtendIntegerRetVal ? MVT::i32 : VTs[i];
3415       Chain = DAG.getMemIntrinsicNode(
3416           Op, dl, DAG.getVTList(MVT::Other), StoreOperands, TheStoreType,
3417           MachinePointerInfo(), Align(1), MachineMemOperand::MOStore);
3418       // Cleanup vector state.
3419       StoreOperands.clear();
3420     }
3421   }
3422 
3423   return DAG.getNode(NVPTXISD::RET_GLUE, dl, MVT::Other, Chain);
3424 }
3425 
3426 void NVPTXTargetLowering::LowerAsmOperandForConstraint(
3427     SDValue Op, StringRef Constraint, std::vector<SDValue> &Ops,
3428     SelectionDAG &DAG) const {
3429   if (Constraint.size() > 1)
3430     return;
3431   TargetLowering::LowerAsmOperandForConstraint(Op, Constraint, Ops, DAG);
3432 }
3433 
3434 // llvm.ptx.memcpy.const and llvm.ptx.memmove.const need to be modeled as
3435 // TgtMemIntrinsic because we need information that is only available in the
3436 // "Value" type of the destination pointer. In particular, the address space
3437 // information.
3439 bool NVPTXTargetLowering::getTgtMemIntrinsic(
3440     IntrinsicInfo &Info, const CallInst &I,
3441     MachineFunction &MF, unsigned Intrinsic) const {
3442   switch (Intrinsic) {
3443   default:
3444     return false;
3445   case Intrinsic::nvvm_match_all_sync_i32p:
3446   case Intrinsic::nvvm_match_all_sync_i64p:
3447     Info.opc = ISD::INTRINSIC_W_CHAIN;
3448     // memVT is bogus. These intrinsics have the IntrInaccessibleMemOnly attribute
3449     // in order to model data exchange with other threads, but perform no real
3450     // memory accesses.
3451     Info.memVT = MVT::i1;
3452 
3453     // Our result depends on both our and other thread's arguments.
3454     Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MOStore;
3455     return true;
3456   case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_col:
3457   case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row:
3458   case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_col_stride:
3459   case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row_stride:
3460   case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col:
3461   case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row:
3462   case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col_stride:
3463   case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row_stride:
3464   case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_col:
3465   case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_row:
3466   case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_col_stride:
3467   case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_row_stride:
3468   case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_col:
3469   case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_row:
3470   case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_col_stride:
3471   case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_row_stride:
3472   case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_col:
3473   case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_row:
3474   case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_col_stride:
3475   case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_row_stride:
3476   case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_col:
3477   case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_row:
3478   case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_col_stride:
3479   case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_row_stride: {
3480     Info.opc = ISD::INTRINSIC_W_CHAIN;
3481     Info.memVT = MVT::v8f16;
3482     Info.ptrVal = I.getArgOperand(0);
3483     Info.offset = 0;
3484     Info.flags = MachineMemOperand::MOLoad;
3485     Info.align = Align(16);
3486     return true;
3487   }
3488   case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_col:
3489   case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_col_stride:
3490   case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_col_stride:
3491   case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_col:
3492   case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_row:
3493   case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_row_stride:
3494   case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row_stride:
3495   case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row:
3496   case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_col:
3497   case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_col_stride:
3498   case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_row:
3499   case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_row_stride:
3500   case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col:
3501   case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col_stride:
3502   case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_col_stride:
3503   case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_col:
3504   case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row:
3505   case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row_stride:
3506   case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row_stride:
3507   case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row:
3508   case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_col:
3509   case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_col_stride:
3510   case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_row:
3511   case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_row_stride: {
3512     Info.opc = ISD::INTRINSIC_W_CHAIN;
3513     Info.memVT = MVT::v2i32;
3514     Info.ptrVal = I.getArgOperand(0);
3515     Info.offset = 0;
3516     Info.flags = MachineMemOperand::MOLoad;
3517     Info.align = Align(8);
3518     return true;
3519   }
3520 
3521   case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_col:
3522   case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_col_stride:
3523   case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_col_stride:
3524   case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_col:
3525   case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_row:
3526   case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_row_stride:
3527   case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row_stride:
3528   case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row:
3529   case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_col:
3530   case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_col_stride:
3531   case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_row:
3532   case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_row_stride:
3533   case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_col:
3534   case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_col_stride:
3535   case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_row:
3536   case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_row_stride:
3537 
3538   case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col:
3539   case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col_stride:
3540   case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_col_stride:
3541   case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_col:
3542   case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row:
3543   case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row_stride:
3544   case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row_stride:
3545   case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row:
3546   case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_col:
3547   case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_col_stride:
3548   case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_row:
3549   case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_row_stride:
3550   case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_col:
3551   case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_col_stride:
3552   case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row:
3553   case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row_stride:
3554   case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16:
3555   case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16: {
3556     Info.opc = ISD::INTRINSIC_W_CHAIN;
3557     Info.memVT = MVT::v4i32;
3558     Info.ptrVal = I.getArgOperand(0);
3559     Info.offset = 0;
3560     Info.flags = MachineMemOperand::MOLoad;
3561     Info.align = Align(16);
3562     return true;
3563   }
3564 
3565   case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_col:
3566   case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_col_stride:
3567   case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_col_stride:
3568   case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_col:
3569   case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_row:
3570   case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_row_stride:
3571   case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_row_stride:
3572   case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_row:
3573 
3574   case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_col:
3575   case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_col_stride:
3576   case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_col_stride:
3577   case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_col:
3578   case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_row:
3579   case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_row_stride:
3580   case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_row_stride:
3581   case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_row:
3582   case Intrinsic::nvvm_wmma_m8n8k128_load_a_b1_row:
3583   case Intrinsic::nvvm_wmma_m8n8k128_load_a_b1_row_stride:
3584   case Intrinsic::nvvm_wmma_m8n8k128_load_b_b1_col:
3585   case Intrinsic::nvvm_wmma_m8n8k128_load_b_b1_col_stride:
3586   case Intrinsic::nvvm_wmma_m8n8k32_load_a_s4_row:
3587   case Intrinsic::nvvm_wmma_m8n8k32_load_a_s4_row_stride:
3588   case Intrinsic::nvvm_wmma_m8n8k32_load_a_u4_row_stride:
3589   case Intrinsic::nvvm_wmma_m8n8k32_load_a_u4_row:
3590   case Intrinsic::nvvm_wmma_m8n8k32_load_b_s4_col:
3591   case Intrinsic::nvvm_wmma_m8n8k32_load_b_s4_col_stride:
3592   case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col_stride:
3593   case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col:
3594   case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16:
3595   case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16: {
3596     Info.opc = ISD::INTRINSIC_W_CHAIN;
3597     Info.memVT = MVT::i32;
3598     Info.ptrVal = I.getArgOperand(0);
3599     Info.offset = 0;
3600     Info.flags = MachineMemOperand::MOLoad;
3601     Info.align = Align(4);
3602     return true;
3603   }
3604 
3605   case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col:
3606   case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row:
3607   case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col_stride:
3608   case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row_stride:
3609   case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_col:
3610   case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_row:
3611   case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_col_stride:
3612   case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_row_stride:
3613   case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_col:
3614   case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_row:
3615   case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_col_stride:
3616   case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_row_stride: {
3617     Info.opc = ISD::INTRINSIC_W_CHAIN;
3618     Info.memVT = MVT::v4f16;
3619     Info.ptrVal = I.getArgOperand(0);
3620     Info.offset = 0;
3621     Info.flags = MachineMemOperand::MOLoad;
3622     Info.align = Align(16);
3623     return true;
3624   }
3625 
3626   case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col:
3627   case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row:
3628   case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col_stride:
3629   case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row_stride:
3630   case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_col:
3631   case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_row:
3632   case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_col_stride:
3633   case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_row_stride:
3634   case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col:
3635   case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row:
3636   case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col_stride:
3637   case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row_stride:
3638   case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_col:
3639   case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_row:
3640   case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_col_stride:
3641   case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_row_stride: {
3642     Info.opc = ISD::INTRINSIC_W_CHAIN;
3643     Info.memVT = MVT::v8f32;
3644     Info.ptrVal = I.getArgOperand(0);
3645     Info.offset = 0;
3646     Info.flags = MachineMemOperand::MOLoad;
3647     Info.align = Align(16);
3648     return true;
3649   }
3650 
3651   case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_col:
3652   case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_col_stride:
3653   case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_row:
3654   case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_row_stride:
3655 
3656   case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_col:
3657   case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_col_stride:
3658   case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_row:
3659   case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_row_stride:
3660 
3661   case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col:
3662   case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col_stride:
3663   case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_row:
3664   case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_row_stride:
3665   case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_col:
3666   case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_col_stride:
3667   case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_row:
3668   case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_row_stride:
3669   case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_col:
3670   case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_col_stride:
3671   case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_row:
3672   case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_row_stride: {
3673     Info.opc = ISD::INTRINSIC_W_CHAIN;
3674     Info.memVT = MVT::v8i32;
3675     Info.ptrVal = I.getArgOperand(0);
3676     Info.offset = 0;
3677     Info.flags = MachineMemOperand::MOLoad;
3678     Info.align = Align(16);
3679     return true;
3680   }
3681 
3682   case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_col:
3683   case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_col_stride:
3684   case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_row:
3685   case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_row_stride:
3686   case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_col:
3687   case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_col_stride:
3688   case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row:
3689   case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row_stride:
3690   case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16:
3691   case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16: {
3692     Info.opc = ISD::INTRINSIC_W_CHAIN;
3693     Info.memVT = MVT::v2i32;
3694     Info.ptrVal = I.getArgOperand(0);
3695     Info.offset = 0;
3696     Info.flags = MachineMemOperand::MOLoad;
3697     Info.align = Align(8);
3698     return true;
3699   }
3700 
3701   case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_col:
3702   case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_col_stride:
3703   case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_row:
3704   case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_row_stride:
3705 
3706   case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_col:
3707   case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_col_stride:
3708   case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_row:
3709   case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_row_stride: {
3710     Info.opc = ISD::INTRINSIC_W_CHAIN;
3711     Info.memVT = MVT::f64;
3712     Info.ptrVal = I.getArgOperand(0);
3713     Info.offset = 0;
3714     Info.flags = MachineMemOperand::MOLoad;
3715     Info.align = Align(8);
3716     return true;
3717   }
3718 
3719   case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_col:
3720   case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_col_stride:
3721   case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_row:
3722   case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_row_stride: {
3723     Info.opc = ISD::INTRINSIC_W_CHAIN;
3724     Info.memVT = MVT::v2f64;
3725     Info.ptrVal = I.getArgOperand(0);
3726     Info.offset = 0;
3727     Info.flags = MachineMemOperand::MOLoad;
3728     Info.align = Align(16);
3729     return true;
3730   }
3731 
3732   case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col:
3733   case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row:
3734   case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col_stride:
3735   case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row_stride:
3736   case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_col:
3737   case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_row:
3738   case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_col_stride:
3739   case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_row_stride:
3740   case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_col:
3741   case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_row:
3742   case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_col_stride:
3743   case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_row_stride: {
3744     Info.opc = ISD::INTRINSIC_VOID;
3745     Info.memVT = MVT::v4f16;
3746     Info.ptrVal = I.getArgOperand(0);
3747     Info.offset = 0;
3748     Info.flags = MachineMemOperand::MOStore;
3749     Info.align = Align(16);
3750     return true;
3751   }
3752 
3753   case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col:
3754   case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row:
3755   case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col_stride:
3756   case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row_stride:
3757   case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_col:
3758   case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_row:
3759   case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_col_stride:
3760   case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_row_stride:
3761   case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col:
3762   case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row:
3763   case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col_stride:
3764   case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row_stride:
3765   case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_col:
3766   case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_row:
3767   case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_col_stride:
3768   case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_row_stride: {
3769     Info.opc = ISD::INTRINSIC_VOID;
3770     Info.memVT = MVT::v8f32;
3771     Info.ptrVal = I.getArgOperand(0);
3772     Info.offset = 0;
3773     Info.flags = MachineMemOperand::MOStore;
3774     Info.align = Align(16);
3775     return true;
3776   }
3777 
3778   case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_col:
3779   case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_col_stride:
3780   case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_row:
3781   case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_row_stride:
3782   case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_col:
3783   case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_col_stride:
3784   case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_row:
3785   case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_row_stride:
3786   case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_col:
3787   case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_col_stride:
3788   case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_row:
3789   case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_row_stride: {
3790     Info.opc = ISD::INTRINSIC_VOID;
3791     Info.memVT = MVT::v8i32;
3792     Info.ptrVal = I.getArgOperand(0);
3793     Info.offset = 0;
3794     Info.flags = MachineMemOperand::MOStore;
3795     Info.align = Align(16);
3796     return true;
3797   }
3798 
3799   case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_col:
3800   case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_col_stride:
3801   case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_row:
3802   case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_row_stride:
3803   case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col:
3804   case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col_stride:
3805   case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row:
3806   case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride: {
3807     Info.opc = ISD::INTRINSIC_VOID;
3808     Info.memVT = MVT::v2i32;
3809     Info.ptrVal = I.getArgOperand(0);
3810     Info.offset = 0;
3811     Info.flags = MachineMemOperand::MOStore;
3812     Info.align = Align(8);
3813     return true;
3814   }
3815 
3816   case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_col:
3817   case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_col_stride:
3818   case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_row:
3819   case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_row_stride: {
3820     Info.opc = ISD::INTRINSIC_VOID;
3821     Info.memVT = MVT::v2f64;
3822     Info.ptrVal = I.getArgOperand(0);
3823     Info.offset = 0;
3824     Info.flags = MachineMemOperand::MOStore;
3825     Info.align = Align(16);
3826     return true;
3827   }
3828 
3829   case Intrinsic::nvvm_atomic_load_inc_32:
3830   case Intrinsic::nvvm_atomic_load_dec_32:
3831 
3832   case Intrinsic::nvvm_atomic_add_gen_f_cta:
3833   case Intrinsic::nvvm_atomic_add_gen_f_sys:
3834   case Intrinsic::nvvm_atomic_add_gen_i_cta:
3835   case Intrinsic::nvvm_atomic_add_gen_i_sys:
3836   case Intrinsic::nvvm_atomic_and_gen_i_cta:
3837   case Intrinsic::nvvm_atomic_and_gen_i_sys:
3838   case Intrinsic::nvvm_atomic_cas_gen_i_cta:
3839   case Intrinsic::nvvm_atomic_cas_gen_i_sys:
3840   case Intrinsic::nvvm_atomic_dec_gen_i_cta:
3841   case Intrinsic::nvvm_atomic_dec_gen_i_sys:
3842   case Intrinsic::nvvm_atomic_inc_gen_i_cta:
3843   case Intrinsic::nvvm_atomic_inc_gen_i_sys:
3844   case Intrinsic::nvvm_atomic_max_gen_i_cta:
3845   case Intrinsic::nvvm_atomic_max_gen_i_sys:
3846   case Intrinsic::nvvm_atomic_min_gen_i_cta:
3847   case Intrinsic::nvvm_atomic_min_gen_i_sys:
3848   case Intrinsic::nvvm_atomic_or_gen_i_cta:
3849   case Intrinsic::nvvm_atomic_or_gen_i_sys:
3850   case Intrinsic::nvvm_atomic_exch_gen_i_cta:
3851   case Intrinsic::nvvm_atomic_exch_gen_i_sys:
3852   case Intrinsic::nvvm_atomic_xor_gen_i_cta:
3853   case Intrinsic::nvvm_atomic_xor_gen_i_sys: {
3854     auto &DL = I.getDataLayout();
3855     Info.opc = ISD::INTRINSIC_W_CHAIN;
3856     Info.memVT = getValueType(DL, I.getType());
3857     Info.ptrVal = I.getArgOperand(0);
3858     Info.offset = 0;
3859     Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MOStore;
3860     Info.align.reset();
3861     return true;
3862   }
3863 
3864   case Intrinsic::nvvm_ldu_global_i:
3865   case Intrinsic::nvvm_ldu_global_f:
3866   case Intrinsic::nvvm_ldu_global_p: {
3867     auto &DL = I.getDataLayout();
3868     Info.opc = ISD::INTRINSIC_W_CHAIN;
3869     if (Intrinsic == Intrinsic::nvvm_ldu_global_i)
3870       Info.memVT = getValueType(DL, I.getType());
3871     else if (Intrinsic == Intrinsic::nvvm_ldu_global_p)
3872       Info.memVT = getPointerTy(DL);
3873     else
3874       Info.memVT = getValueType(DL, I.getType());
3875     Info.ptrVal = I.getArgOperand(0);
3876     Info.offset = 0;
3877     Info.flags = MachineMemOperand::MOLoad;
3878     Info.align = cast<ConstantInt>(I.getArgOperand(1))->getMaybeAlignValue();
3879 
3880     return true;
3881   }
3882   case Intrinsic::nvvm_tex_1d_v4f32_s32:
3883   case Intrinsic::nvvm_tex_1d_v4f32_f32:
3884   case Intrinsic::nvvm_tex_1d_level_v4f32_f32:
3885   case Intrinsic::nvvm_tex_1d_grad_v4f32_f32:
3886   case Intrinsic::nvvm_tex_1d_array_v4f32_s32:
3887   case Intrinsic::nvvm_tex_1d_array_v4f32_f32:
3888   case Intrinsic::nvvm_tex_1d_array_level_v4f32_f32:
3889   case Intrinsic::nvvm_tex_1d_array_grad_v4f32_f32:
3890   case Intrinsic::nvvm_tex_2d_v4f32_s32:
3891   case Intrinsic::nvvm_tex_2d_v4f32_f32:
3892   case Intrinsic::nvvm_tex_2d_level_v4f32_f32:
3893   case Intrinsic::nvvm_tex_2d_grad_v4f32_f32:
3894   case Intrinsic::nvvm_tex_2d_array_v4f32_s32:
3895   case Intrinsic::nvvm_tex_2d_array_v4f32_f32:
3896   case Intrinsic::nvvm_tex_2d_array_level_v4f32_f32:
3897   case Intrinsic::nvvm_tex_2d_array_grad_v4f32_f32:
3898   case Intrinsic::nvvm_tex_3d_v4f32_s32:
3899   case Intrinsic::nvvm_tex_3d_v4f32_f32:
3900   case Intrinsic::nvvm_tex_3d_level_v4f32_f32:
3901   case Intrinsic::nvvm_tex_3d_grad_v4f32_f32:
3902   case Intrinsic::nvvm_tex_cube_v4f32_f32:
3903   case Intrinsic::nvvm_tex_cube_level_v4f32_f32:
3904   case Intrinsic::nvvm_tex_cube_array_v4f32_f32:
3905   case Intrinsic::nvvm_tex_cube_array_level_v4f32_f32:
3906   case Intrinsic::nvvm_tld4_r_2d_v4f32_f32:
3907   case Intrinsic::nvvm_tld4_g_2d_v4f32_f32:
3908   case Intrinsic::nvvm_tld4_b_2d_v4f32_f32:
3909   case Intrinsic::nvvm_tld4_a_2d_v4f32_f32:
3910   case Intrinsic::nvvm_tex_unified_1d_v4f32_s32:
3911   case Intrinsic::nvvm_tex_unified_1d_v4f32_f32:
3912   case Intrinsic::nvvm_tex_unified_1d_level_v4f32_f32:
3913   case Intrinsic::nvvm_tex_unified_1d_grad_v4f32_f32:
3914   case Intrinsic::nvvm_tex_unified_1d_array_v4f32_s32:
3915   case Intrinsic::nvvm_tex_unified_1d_array_v4f32_f32:
3916   case Intrinsic::nvvm_tex_unified_1d_array_level_v4f32_f32:
3917   case Intrinsic::nvvm_tex_unified_1d_array_grad_v4f32_f32:
3918   case Intrinsic::nvvm_tex_unified_2d_v4f32_s32:
3919   case Intrinsic::nvvm_tex_unified_2d_v4f32_f32:
3920   case Intrinsic::nvvm_tex_unified_2d_level_v4f32_f32:
3921   case Intrinsic::nvvm_tex_unified_2d_grad_v4f32_f32:
3922   case Intrinsic::nvvm_tex_unified_2d_array_v4f32_s32:
3923   case Intrinsic::nvvm_tex_unified_2d_array_v4f32_f32:
3924   case Intrinsic::nvvm_tex_unified_2d_array_level_v4f32_f32:
3925   case Intrinsic::nvvm_tex_unified_2d_array_grad_v4f32_f32:
3926   case Intrinsic::nvvm_tex_unified_3d_v4f32_s32:
3927   case Intrinsic::nvvm_tex_unified_3d_v4f32_f32:
3928   case Intrinsic::nvvm_tex_unified_3d_level_v4f32_f32:
3929   case Intrinsic::nvvm_tex_unified_3d_grad_v4f32_f32:
3930   case Intrinsic::nvvm_tex_unified_cube_v4f32_f32:
3931   case Intrinsic::nvvm_tex_unified_cube_level_v4f32_f32:
3932   case Intrinsic::nvvm_tex_unified_cube_array_v4f32_f32:
3933   case Intrinsic::nvvm_tex_unified_cube_array_level_v4f32_f32:
3934   case Intrinsic::nvvm_tex_unified_cube_grad_v4f32_f32:
3935   case Intrinsic::nvvm_tex_unified_cube_array_grad_v4f32_f32:
3936   case Intrinsic::nvvm_tld4_unified_r_2d_v4f32_f32:
3937   case Intrinsic::nvvm_tld4_unified_g_2d_v4f32_f32:
3938   case Intrinsic::nvvm_tld4_unified_b_2d_v4f32_f32:
3939   case Intrinsic::nvvm_tld4_unified_a_2d_v4f32_f32:
3940     Info.opc = ISD::INTRINSIC_W_CHAIN;
3941     Info.memVT = MVT::v4f32;
3942     Info.ptrVal = nullptr;
3943     Info.offset = 0;
3944     Info.flags = MachineMemOperand::MOLoad;
3945     Info.align = Align(16);
3946     return true;
3947 
3948   case Intrinsic::nvvm_tex_1d_v4s32_s32:
3949   case Intrinsic::nvvm_tex_1d_v4s32_f32:
3950   case Intrinsic::nvvm_tex_1d_level_v4s32_f32:
3951   case Intrinsic::nvvm_tex_1d_grad_v4s32_f32:
3952   case Intrinsic::nvvm_tex_1d_array_v4s32_s32:
3953   case Intrinsic::nvvm_tex_1d_array_v4s32_f32:
3954   case Intrinsic::nvvm_tex_1d_array_level_v4s32_f32:
3955   case Intrinsic::nvvm_tex_1d_array_grad_v4s32_f32:
3956   case Intrinsic::nvvm_tex_2d_v4s32_s32:
3957   case Intrinsic::nvvm_tex_2d_v4s32_f32:
3958   case Intrinsic::nvvm_tex_2d_level_v4s32_f32:
3959   case Intrinsic::nvvm_tex_2d_grad_v4s32_f32:
3960   case Intrinsic::nvvm_tex_2d_array_v4s32_s32:
3961   case Intrinsic::nvvm_tex_2d_array_v4s32_f32:
3962   case Intrinsic::nvvm_tex_2d_array_level_v4s32_f32:
3963   case Intrinsic::nvvm_tex_2d_array_grad_v4s32_f32:
3964   case Intrinsic::nvvm_tex_3d_v4s32_s32:
3965   case Intrinsic::nvvm_tex_3d_v4s32_f32:
3966   case Intrinsic::nvvm_tex_3d_level_v4s32_f32:
3967   case Intrinsic::nvvm_tex_3d_grad_v4s32_f32:
3968   case Intrinsic::nvvm_tex_cube_v4s32_f32:
3969   case Intrinsic::nvvm_tex_cube_level_v4s32_f32:
3970   case Intrinsic::nvvm_tex_cube_array_v4s32_f32:
3971   case Intrinsic::nvvm_tex_cube_array_level_v4s32_f32:
3972   case Intrinsic::nvvm_tex_cube_v4u32_f32:
3973   case Intrinsic::nvvm_tex_cube_level_v4u32_f32:
3974   case Intrinsic::nvvm_tex_cube_array_v4u32_f32:
3975   case Intrinsic::nvvm_tex_cube_array_level_v4u32_f32:
3976   case Intrinsic::nvvm_tex_1d_v4u32_s32:
3977   case Intrinsic::nvvm_tex_1d_v4u32_f32:
3978   case Intrinsic::nvvm_tex_1d_level_v4u32_f32:
3979   case Intrinsic::nvvm_tex_1d_grad_v4u32_f32:
3980   case Intrinsic::nvvm_tex_1d_array_v4u32_s32:
3981   case Intrinsic::nvvm_tex_1d_array_v4u32_f32:
3982   case Intrinsic::nvvm_tex_1d_array_level_v4u32_f32:
3983   case Intrinsic::nvvm_tex_1d_array_grad_v4u32_f32:
3984   case Intrinsic::nvvm_tex_2d_v4u32_s32:
3985   case Intrinsic::nvvm_tex_2d_v4u32_f32:
3986   case Intrinsic::nvvm_tex_2d_level_v4u32_f32:
3987   case Intrinsic::nvvm_tex_2d_grad_v4u32_f32:
3988   case Intrinsic::nvvm_tex_2d_array_v4u32_s32:
3989   case Intrinsic::nvvm_tex_2d_array_v4u32_f32:
3990   case Intrinsic::nvvm_tex_2d_array_level_v4u32_f32:
3991   case Intrinsic::nvvm_tex_2d_array_grad_v4u32_f32:
3992   case Intrinsic::nvvm_tex_3d_v4u32_s32:
3993   case Intrinsic::nvvm_tex_3d_v4u32_f32:
3994   case Intrinsic::nvvm_tex_3d_level_v4u32_f32:
3995   case Intrinsic::nvvm_tex_3d_grad_v4u32_f32:
3996   case Intrinsic::nvvm_tld4_r_2d_v4s32_f32:
3997   case Intrinsic::nvvm_tld4_g_2d_v4s32_f32:
3998   case Intrinsic::nvvm_tld4_b_2d_v4s32_f32:
3999   case Intrinsic::nvvm_tld4_a_2d_v4s32_f32:
4000   case Intrinsic::nvvm_tld4_r_2d_v4u32_f32:
4001   case Intrinsic::nvvm_tld4_g_2d_v4u32_f32:
4002   case Intrinsic::nvvm_tld4_b_2d_v4u32_f32:
4003   case Intrinsic::nvvm_tld4_a_2d_v4u32_f32:
4004   case Intrinsic::nvvm_tex_unified_1d_v4s32_s32:
4005   case Intrinsic::nvvm_tex_unified_1d_v4s32_f32:
4006   case Intrinsic::nvvm_tex_unified_1d_level_v4s32_f32:
4007   case Intrinsic::nvvm_tex_unified_1d_grad_v4s32_f32:
4008   case Intrinsic::nvvm_tex_unified_1d_array_v4s32_s32:
4009   case Intrinsic::nvvm_tex_unified_1d_array_v4s32_f32:
4010   case Intrinsic::nvvm_tex_unified_1d_array_level_v4s32_f32:
4011   case Intrinsic::nvvm_tex_unified_1d_array_grad_v4s32_f32:
4012   case Intrinsic::nvvm_tex_unified_2d_v4s32_s32:
4013   case Intrinsic::nvvm_tex_unified_2d_v4s32_f32:
4014   case Intrinsic::nvvm_tex_unified_2d_level_v4s32_f32:
4015   case Intrinsic::nvvm_tex_unified_2d_grad_v4s32_f32:
4016   case Intrinsic::nvvm_tex_unified_2d_array_v4s32_s32:
4017   case Intrinsic::nvvm_tex_unified_2d_array_v4s32_f32:
4018   case Intrinsic::nvvm_tex_unified_2d_array_level_v4s32_f32:
4019   case Intrinsic::nvvm_tex_unified_2d_array_grad_v4s32_f32:
4020   case Intrinsic::nvvm_tex_unified_3d_v4s32_s32:
4021   case Intrinsic::nvvm_tex_unified_3d_v4s32_f32:
4022   case Intrinsic::nvvm_tex_unified_3d_level_v4s32_f32:
4023   case Intrinsic::nvvm_tex_unified_3d_grad_v4s32_f32:
4024   case Intrinsic::nvvm_tex_unified_1d_v4u32_s32:
4025   case Intrinsic::nvvm_tex_unified_1d_v4u32_f32:
4026   case Intrinsic::nvvm_tex_unified_1d_level_v4u32_f32:
4027   case Intrinsic::nvvm_tex_unified_1d_grad_v4u32_f32:
4028   case Intrinsic::nvvm_tex_unified_1d_array_v4u32_s32:
4029   case Intrinsic::nvvm_tex_unified_1d_array_v4u32_f32:
4030   case Intrinsic::nvvm_tex_unified_1d_array_level_v4u32_f32:
4031   case Intrinsic::nvvm_tex_unified_1d_array_grad_v4u32_f32:
4032   case Intrinsic::nvvm_tex_unified_2d_v4u32_s32:
4033   case Intrinsic::nvvm_tex_unified_2d_v4u32_f32:
4034   case Intrinsic::nvvm_tex_unified_2d_level_v4u32_f32:
4035   case Intrinsic::nvvm_tex_unified_2d_grad_v4u32_f32:
4036   case Intrinsic::nvvm_tex_unified_2d_array_v4u32_s32:
4037   case Intrinsic::nvvm_tex_unified_2d_array_v4u32_f32:
4038   case Intrinsic::nvvm_tex_unified_2d_array_level_v4u32_f32:
4039   case Intrinsic::nvvm_tex_unified_2d_array_grad_v4u32_f32:
4040   case Intrinsic::nvvm_tex_unified_3d_v4u32_s32:
4041   case Intrinsic::nvvm_tex_unified_3d_v4u32_f32:
4042   case Intrinsic::nvvm_tex_unified_3d_level_v4u32_f32:
4043   case Intrinsic::nvvm_tex_unified_3d_grad_v4u32_f32:
4044   case Intrinsic::nvvm_tex_unified_cube_v4s32_f32:
4045   case Intrinsic::nvvm_tex_unified_cube_level_v4s32_f32:
4046   case Intrinsic::nvvm_tex_unified_cube_array_v4s32_f32:
4047   case Intrinsic::nvvm_tex_unified_cube_array_level_v4s32_f32:
4048   case Intrinsic::nvvm_tex_unified_cube_v4u32_f32:
4049   case Intrinsic::nvvm_tex_unified_cube_level_v4u32_f32:
4050   case Intrinsic::nvvm_tex_unified_cube_array_v4u32_f32:
4051   case Intrinsic::nvvm_tex_unified_cube_array_level_v4u32_f32:
4052   case Intrinsic::nvvm_tex_unified_cube_grad_v4s32_f32:
4053   case Intrinsic::nvvm_tex_unified_cube_grad_v4u32_f32:
4054   case Intrinsic::nvvm_tex_unified_cube_array_grad_v4s32_f32:
4055   case Intrinsic::nvvm_tex_unified_cube_array_grad_v4u32_f32:
4056   case Intrinsic::nvvm_tld4_unified_r_2d_v4s32_f32:
4057   case Intrinsic::nvvm_tld4_unified_g_2d_v4s32_f32:
4058   case Intrinsic::nvvm_tld4_unified_b_2d_v4s32_f32:
4059   case Intrinsic::nvvm_tld4_unified_a_2d_v4s32_f32:
4060   case Intrinsic::nvvm_tld4_unified_r_2d_v4u32_f32:
4061   case Intrinsic::nvvm_tld4_unified_g_2d_v4u32_f32:
4062   case Intrinsic::nvvm_tld4_unified_b_2d_v4u32_f32:
4063   case Intrinsic::nvvm_tld4_unified_a_2d_v4u32_f32:
4064     Info.opc = ISD::INTRINSIC_W_CHAIN;
4065     Info.memVT = MVT::v4i32;
4066     Info.ptrVal = nullptr;
4067     Info.offset = 0;
4068     Info.flags = MachineMemOperand::MOLoad;
4069     Info.align = Align(16);
4070     return true;
4071 
4072   case Intrinsic::nvvm_suld_1d_i8_clamp:
4073   case Intrinsic::nvvm_suld_1d_v2i8_clamp:
4074   case Intrinsic::nvvm_suld_1d_v4i8_clamp:
4075   case Intrinsic::nvvm_suld_1d_array_i8_clamp:
4076   case Intrinsic::nvvm_suld_1d_array_v2i8_clamp:
4077   case Intrinsic::nvvm_suld_1d_array_v4i8_clamp:
4078   case Intrinsic::nvvm_suld_2d_i8_clamp:
4079   case Intrinsic::nvvm_suld_2d_v2i8_clamp:
4080   case Intrinsic::nvvm_suld_2d_v4i8_clamp:
4081   case Intrinsic::nvvm_suld_2d_array_i8_clamp:
4082   case Intrinsic::nvvm_suld_2d_array_v2i8_clamp:
4083   case Intrinsic::nvvm_suld_2d_array_v4i8_clamp:
4084   case Intrinsic::nvvm_suld_3d_i8_clamp:
4085   case Intrinsic::nvvm_suld_3d_v2i8_clamp:
4086   case Intrinsic::nvvm_suld_3d_v4i8_clamp:
4087   case Intrinsic::nvvm_suld_1d_i8_trap:
4088   case Intrinsic::nvvm_suld_1d_v2i8_trap:
4089   case Intrinsic::nvvm_suld_1d_v4i8_trap:
4090   case Intrinsic::nvvm_suld_1d_array_i8_trap:
4091   case Intrinsic::nvvm_suld_1d_array_v2i8_trap:
4092   case Intrinsic::nvvm_suld_1d_array_v4i8_trap:
4093   case Intrinsic::nvvm_suld_2d_i8_trap:
4094   case Intrinsic::nvvm_suld_2d_v2i8_trap:
4095   case Intrinsic::nvvm_suld_2d_v4i8_trap:
4096   case Intrinsic::nvvm_suld_2d_array_i8_trap:
4097   case Intrinsic::nvvm_suld_2d_array_v2i8_trap:
4098   case Intrinsic::nvvm_suld_2d_array_v4i8_trap:
4099   case Intrinsic::nvvm_suld_3d_i8_trap:
4100   case Intrinsic::nvvm_suld_3d_v2i8_trap:
4101   case Intrinsic::nvvm_suld_3d_v4i8_trap:
4102   case Intrinsic::nvvm_suld_1d_i8_zero:
4103   case Intrinsic::nvvm_suld_1d_v2i8_zero:
4104   case Intrinsic::nvvm_suld_1d_v4i8_zero:
4105   case Intrinsic::nvvm_suld_1d_array_i8_zero:
4106   case Intrinsic::nvvm_suld_1d_array_v2i8_zero:
4107   case Intrinsic::nvvm_suld_1d_array_v4i8_zero:
4108   case Intrinsic::nvvm_suld_2d_i8_zero:
4109   case Intrinsic::nvvm_suld_2d_v2i8_zero:
4110   case Intrinsic::nvvm_suld_2d_v4i8_zero:
4111   case Intrinsic::nvvm_suld_2d_array_i8_zero:
4112   case Intrinsic::nvvm_suld_2d_array_v2i8_zero:
4113   case Intrinsic::nvvm_suld_2d_array_v4i8_zero:
4114   case Intrinsic::nvvm_suld_3d_i8_zero:
4115   case Intrinsic::nvvm_suld_3d_v2i8_zero:
4116   case Intrinsic::nvvm_suld_3d_v4i8_zero:
4117     Info.opc = ISD::INTRINSIC_W_CHAIN;
4118     Info.memVT = MVT::i8;
4119     Info.ptrVal = nullptr;
4120     Info.offset = 0;
4121     Info.flags = MachineMemOperand::MOLoad;
4122     Info.align = Align(16);
4123     return true;
4124 
4125   case Intrinsic::nvvm_suld_1d_i16_clamp:
4126   case Intrinsic::nvvm_suld_1d_v2i16_clamp:
4127   case Intrinsic::nvvm_suld_1d_v4i16_clamp:
4128   case Intrinsic::nvvm_suld_1d_array_i16_clamp:
4129   case Intrinsic::nvvm_suld_1d_array_v2i16_clamp:
4130   case Intrinsic::nvvm_suld_1d_array_v4i16_clamp:
4131   case Intrinsic::nvvm_suld_2d_i16_clamp:
4132   case Intrinsic::nvvm_suld_2d_v2i16_clamp:
4133   case Intrinsic::nvvm_suld_2d_v4i16_clamp:
4134   case Intrinsic::nvvm_suld_2d_array_i16_clamp:
4135   case Intrinsic::nvvm_suld_2d_array_v2i16_clamp:
4136   case Intrinsic::nvvm_suld_2d_array_v4i16_clamp:
4137   case Intrinsic::nvvm_suld_3d_i16_clamp:
4138   case Intrinsic::nvvm_suld_3d_v2i16_clamp:
4139   case Intrinsic::nvvm_suld_3d_v4i16_clamp:
4140   case Intrinsic::nvvm_suld_1d_i16_trap:
4141   case Intrinsic::nvvm_suld_1d_v2i16_trap:
4142   case Intrinsic::nvvm_suld_1d_v4i16_trap:
4143   case Intrinsic::nvvm_suld_1d_array_i16_trap:
4144   case Intrinsic::nvvm_suld_1d_array_v2i16_trap:
4145   case Intrinsic::nvvm_suld_1d_array_v4i16_trap:
4146   case Intrinsic::nvvm_suld_2d_i16_trap:
4147   case Intrinsic::nvvm_suld_2d_v2i16_trap:
4148   case Intrinsic::nvvm_suld_2d_v4i16_trap:
4149   case Intrinsic::nvvm_suld_2d_array_i16_trap:
4150   case Intrinsic::nvvm_suld_2d_array_v2i16_trap:
4151   case Intrinsic::nvvm_suld_2d_array_v4i16_trap:
4152   case Intrinsic::nvvm_suld_3d_i16_trap:
4153   case Intrinsic::nvvm_suld_3d_v2i16_trap:
4154   case Intrinsic::nvvm_suld_3d_v4i16_trap:
4155   case Intrinsic::nvvm_suld_1d_i16_zero:
4156   case Intrinsic::nvvm_suld_1d_v2i16_zero:
4157   case Intrinsic::nvvm_suld_1d_v4i16_zero:
4158   case Intrinsic::nvvm_suld_1d_array_i16_zero:
4159   case Intrinsic::nvvm_suld_1d_array_v2i16_zero:
4160   case Intrinsic::nvvm_suld_1d_array_v4i16_zero:
4161   case Intrinsic::nvvm_suld_2d_i16_zero:
4162   case Intrinsic::nvvm_suld_2d_v2i16_zero:
4163   case Intrinsic::nvvm_suld_2d_v4i16_zero:
4164   case Intrinsic::nvvm_suld_2d_array_i16_zero:
4165   case Intrinsic::nvvm_suld_2d_array_v2i16_zero:
4166   case Intrinsic::nvvm_suld_2d_array_v4i16_zero:
4167   case Intrinsic::nvvm_suld_3d_i16_zero:
4168   case Intrinsic::nvvm_suld_3d_v2i16_zero:
4169   case Intrinsic::nvvm_suld_3d_v4i16_zero:
4170     Info.opc = ISD::INTRINSIC_W_CHAIN;
4171     Info.memVT = MVT::i16;
4172     Info.ptrVal = nullptr;
4173     Info.offset = 0;
4174     Info.flags = MachineMemOperand::MOLoad;
4175     Info.align = Align(16);
4176     return true;
4177 
4178   case Intrinsic::nvvm_suld_1d_i32_clamp:
4179   case Intrinsic::nvvm_suld_1d_v2i32_clamp:
4180   case Intrinsic::nvvm_suld_1d_v4i32_clamp:
4181   case Intrinsic::nvvm_suld_1d_array_i32_clamp:
4182   case Intrinsic::nvvm_suld_1d_array_v2i32_clamp:
4183   case Intrinsic::nvvm_suld_1d_array_v4i32_clamp:
4184   case Intrinsic::nvvm_suld_2d_i32_clamp:
4185   case Intrinsic::nvvm_suld_2d_v2i32_clamp:
4186   case Intrinsic::nvvm_suld_2d_v4i32_clamp:
4187   case Intrinsic::nvvm_suld_2d_array_i32_clamp:
4188   case Intrinsic::nvvm_suld_2d_array_v2i32_clamp:
4189   case Intrinsic::nvvm_suld_2d_array_v4i32_clamp:
4190   case Intrinsic::nvvm_suld_3d_i32_clamp:
4191   case Intrinsic::nvvm_suld_3d_v2i32_clamp:
4192   case Intrinsic::nvvm_suld_3d_v4i32_clamp:
4193   case Intrinsic::nvvm_suld_1d_i32_trap:
4194   case Intrinsic::nvvm_suld_1d_v2i32_trap:
4195   case Intrinsic::nvvm_suld_1d_v4i32_trap:
4196   case Intrinsic::nvvm_suld_1d_array_i32_trap:
4197   case Intrinsic::nvvm_suld_1d_array_v2i32_trap:
4198   case Intrinsic::nvvm_suld_1d_array_v4i32_trap:
4199   case Intrinsic::nvvm_suld_2d_i32_trap:
4200   case Intrinsic::nvvm_suld_2d_v2i32_trap:
4201   case Intrinsic::nvvm_suld_2d_v4i32_trap:
4202   case Intrinsic::nvvm_suld_2d_array_i32_trap:
4203   case Intrinsic::nvvm_suld_2d_array_v2i32_trap:
4204   case Intrinsic::nvvm_suld_2d_array_v4i32_trap:
4205   case Intrinsic::nvvm_suld_3d_i32_trap:
4206   case Intrinsic::nvvm_suld_3d_v2i32_trap:
4207   case Intrinsic::nvvm_suld_3d_v4i32_trap:
4208   case Intrinsic::nvvm_suld_1d_i32_zero:
4209   case Intrinsic::nvvm_suld_1d_v2i32_zero:
4210   case Intrinsic::nvvm_suld_1d_v4i32_zero:
4211   case Intrinsic::nvvm_suld_1d_array_i32_zero:
4212   case Intrinsic::nvvm_suld_1d_array_v2i32_zero:
4213   case Intrinsic::nvvm_suld_1d_array_v4i32_zero:
4214   case Intrinsic::nvvm_suld_2d_i32_zero:
4215   case Intrinsic::nvvm_suld_2d_v2i32_zero:
4216   case Intrinsic::nvvm_suld_2d_v4i32_zero:
4217   case Intrinsic::nvvm_suld_2d_array_i32_zero:
4218   case Intrinsic::nvvm_suld_2d_array_v2i32_zero:
4219   case Intrinsic::nvvm_suld_2d_array_v4i32_zero:
4220   case Intrinsic::nvvm_suld_3d_i32_zero:
4221   case Intrinsic::nvvm_suld_3d_v2i32_zero:
4222   case Intrinsic::nvvm_suld_3d_v4i32_zero:
4223     Info.opc = ISD::INTRINSIC_W_CHAIN;
4224     Info.memVT = MVT::i32;
4225     Info.ptrVal = nullptr;
4226     Info.offset = 0;
4227     Info.flags = MachineMemOperand::MOLoad;
4228     Info.align = Align(16);
4229     return true;
4230 
4231   case Intrinsic::nvvm_suld_1d_i64_clamp:
4232   case Intrinsic::nvvm_suld_1d_v2i64_clamp:
4233   case Intrinsic::nvvm_suld_1d_array_i64_clamp:
4234   case Intrinsic::nvvm_suld_1d_array_v2i64_clamp:
4235   case Intrinsic::nvvm_suld_2d_i64_clamp:
4236   case Intrinsic::nvvm_suld_2d_v2i64_clamp:
4237   case Intrinsic::nvvm_suld_2d_array_i64_clamp:
4238   case Intrinsic::nvvm_suld_2d_array_v2i64_clamp:
4239   case Intrinsic::nvvm_suld_3d_i64_clamp:
4240   case Intrinsic::nvvm_suld_3d_v2i64_clamp:
4241   case Intrinsic::nvvm_suld_1d_i64_trap:
4242   case Intrinsic::nvvm_suld_1d_v2i64_trap:
4243   case Intrinsic::nvvm_suld_1d_array_i64_trap:
4244   case Intrinsic::nvvm_suld_1d_array_v2i64_trap:
4245   case Intrinsic::nvvm_suld_2d_i64_trap:
4246   case Intrinsic::nvvm_suld_2d_v2i64_trap:
4247   case Intrinsic::nvvm_suld_2d_array_i64_trap:
4248   case Intrinsic::nvvm_suld_2d_array_v2i64_trap:
4249   case Intrinsic::nvvm_suld_3d_i64_trap:
4250   case Intrinsic::nvvm_suld_3d_v2i64_trap:
4251   case Intrinsic::nvvm_suld_1d_i64_zero:
4252   case Intrinsic::nvvm_suld_1d_v2i64_zero:
4253   case Intrinsic::nvvm_suld_1d_array_i64_zero:
4254   case Intrinsic::nvvm_suld_1d_array_v2i64_zero:
4255   case Intrinsic::nvvm_suld_2d_i64_zero:
4256   case Intrinsic::nvvm_suld_2d_v2i64_zero:
4257   case Intrinsic::nvvm_suld_2d_array_i64_zero:
4258   case Intrinsic::nvvm_suld_2d_array_v2i64_zero:
4259   case Intrinsic::nvvm_suld_3d_i64_zero:
4260   case Intrinsic::nvvm_suld_3d_v2i64_zero:
4261     Info.opc = ISD::INTRINSIC_W_CHAIN;
4262     Info.memVT = MVT::i64;
4263     Info.ptrVal = nullptr;
4264     Info.offset = 0;
4265     Info.flags = MachineMemOperand::MOLoad;
4266     Info.align = Align(16);
4267     return true;
4268   }
4269   return false;
4270 }
4271 
4272 /// getFunctionParamOptimizedAlign - since function arguments are passed via
4273 /// .param space, we may want to increase their alignment in a way that
4274 /// ensures that we can effectively vectorize their loads & stores. We can
4275 /// increase alignment only if the function has internal or private
4276 /// linkage, as for other linkage types callers may already rely on the default
4277 /// alignment. To allow using 128-bit vectorized loads/stores, this function
4278 /// ensures that alignment is 16 or greater.
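/// For example (illustrative): an i32 argument of a function with local
/// linkage whose address is not taken is bumped from its ABI alignment of 4
/// to 16, while the same argument of an externally visible function keeps
/// its ABI alignment.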
4279 Align NVPTXTargetLowering::getFunctionParamOptimizedAlign(
4280     const Function *F, Type *ArgTy, const DataLayout &DL) const {
4281   // Capping the alignment to 128 bytes as that is the maximum alignment
4282   // supported by PTX.
4283   const Align ABITypeAlign = std::min(Align(128), DL.getABITypeAlign(ArgTy));
4284 
4285   // If a function has linkage different from internal or private, we
4286   // must use default ABI alignment as external users rely on it. Same
4287   // for a function that may be called from a function pointer.
4288   if (!F || !F->hasLocalLinkage() ||
4289       F->hasAddressTaken(/*Users=*/nullptr,
4290                          /*IgnoreCallbackUses=*/false,
4291                          /*IgnoreAssumeLikeCalls=*/true,
4292                          /*IgnoreLLVMUsed=*/true))
4293     return ABITypeAlign;
4294 
4295   assert(!isKernelFunction(*F) && "Expect kernels to have non-local linkage");
4296   return std::max(Align(16), ABITypeAlign);
4297 }
4298 
4299 /// Helper for computing alignment of a device function byval parameter.
4300 Align NVPTXTargetLowering::getFunctionByValParamAlign(
4301     const Function *F, Type *ArgTy, Align InitialAlign,
4302     const DataLayout &DL) const {
4303   Align ArgAlign = InitialAlign;
4304   // Try to increase alignment to enhance vectorization options.
4305   if (F)
4306     ArgAlign = std::max(ArgAlign, getFunctionParamOptimizedAlign(F, ArgTy, DL));
4307 
4308   // Old ptxas versions have a bug: when PTX code takes the address of
4309   // a byval parameter with alignment < 4, ptxas generates code to
4310   // spill the argument into memory. Alas, on sm_50+ ptxas generates
4311   // SASS code that fails with a misaligned access. To work around
4312   // the problem, make sure that byval parameters are aligned to at
4313   // least 4. This bug seems to be fixed at least starting from
4314   // ptxas > 9.0.
4315   // TODO: remove this after verifying the bug is not reproduced
4316   // on non-deprecated ptxas versions.
4317   if (ForceMinByValParamAlign)
4318     ArgAlign = std::max(ArgAlign, Align(4));
4319 
4320   return ArgAlign;
4321 }
4322 
4323 // Helper for getting a function parameter name. The name is composed from
4324 // the function name and the parameter index. A negative index corresponds to
4325 // the special parameter (an unsized array) used for passing variable arguments.
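// For example (illustrative): parameter 1 of a function whose symbol is
// "foo" is named "foo_param_1", and its vararg parameter is named
// "foo_vararg".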
4326 std::string NVPTXTargetLowering::getParamName(const Function *F,
4327                                               int Idx) const {
4328   std::string ParamName;
4329   raw_string_ostream ParamStr(ParamName);
4330 
4331   ParamStr << getTargetMachine().getSymbol(F)->getName();
4332   if (Idx < 0)
4333     ParamStr << "_vararg";
4334   else
4335     ParamStr << "_param_" << Idx;
4336 
4337   return ParamName;
4338 }
4339 
4340 /// isLegalAddressingMode - Return true if the addressing mode represented
4341 /// by AM is legal for this target, for a load/store of the specified type.
4342 /// Used to guide target specific optimizations, like loop strength reduction
4343 /// (LoopStrengthReduce.cpp) and memory optimization for address mode
4344 /// (CodeGenPrepare.cpp)
4345 bool NVPTXTargetLowering::isLegalAddressingMode(const DataLayout &DL,
4346                                                 const AddrMode &AM, Type *Ty,
4347                                                 unsigned AS, Instruction *I) const {
4348   // AddrMode - This represents an addressing mode of:
4349   //    BaseGV + BaseOffs + BaseReg + Scale*ScaleReg
4350   //
4351   // The legal address modes are
4352   // - [avar]
4353   // - [areg]
4354   // - [areg+immoff]
4355   // - [immAddr]
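  //
  // For example (illustrative), in terms of the AddrMode fields:
  //   {BaseGV=@g}                -> [avar]
  //   {HasBaseReg}               -> [areg]
  //   {HasBaseReg, BaseOffs=16}  -> [areg+immoff]
  //   {BaseOffs=16}              -> [immAddr]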
4356 
4357   // immoff must fit in a signed 32-bit int
4358   if (!APInt(64, AM.BaseOffs).isSignedIntN(32))
4359     return false;
4360 
4361   if (AM.BaseGV)
4362     return !AM.BaseOffs && !AM.HasBaseReg && !AM.Scale;
4363 
4364   switch (AM.Scale) {
4365   case 0: // "r", "r+i" or "i" is allowed
4366     break;
4367   case 1:
4368     if (AM.HasBaseReg) // "r+r+i" or "r+r" is not allowed.
4369       return false;
4370     // Otherwise we have r+i.
4371     break;
4372   default:
4373     // No scale > 1 is allowed
4374     return false;
4375   }
4376   return true;
4377 }
4378 
4379 //===----------------------------------------------------------------------===//
4380 //                         NVPTX Inline Assembly Support
4381 //===----------------------------------------------------------------------===//
4382 
4383 /// getConstraintType - Given a constraint letter, return the type of
4384 /// constraint it is for this target.
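/// For example (illustrative), CUDA inline PTX such as
///   asm("add.s32 %0, %1, %2" : "=r"(x) : "r"(a), "r"(b));
/// uses the 'r' constraint, which is classified as C_RegisterClass here and
/// mapped to the 32-bit integer register class below.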
4385 NVPTXTargetLowering::ConstraintType
4386 NVPTXTargetLowering::getConstraintType(StringRef Constraint) const {
4387   if (Constraint.size() == 1) {
4388     switch (Constraint[0]) {
4389     default:
4390       break;
4391     case 'b':
4392     case 'r':
4393     case 'h':
4394     case 'c':
4395     case 'l':
4396     case 'f':
4397     case 'd':
4398     case 'q':
4399     case '0':
4400     case 'N':
4401       return C_RegisterClass;
4402     }
4403   }
4404   return TargetLowering::getConstraintType(Constraint);
4405 }
4406 
4407 std::pair<unsigned, const TargetRegisterClass *>
4408 NVPTXTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
4409                                                   StringRef Constraint,
4410                                                   MVT VT) const {
4411   if (Constraint.size() == 1) {
4412     switch (Constraint[0]) {
4413     case 'b':
4414       return std::make_pair(0U, &NVPTX::Int1RegsRegClass);
4415     case 'c':
4416       return std::make_pair(0U, &NVPTX::Int16RegsRegClass);
4417     case 'h':
4418       return std::make_pair(0U, &NVPTX::Int16RegsRegClass);
4419     case 'r':
4420       return std::make_pair(0U, &NVPTX::Int32RegsRegClass);
4421     case 'l':
4422     case 'N':
4423       return std::make_pair(0U, &NVPTX::Int64RegsRegClass);
4424     case 'q': {
4425       if (STI.getSmVersion() < 70)
4426         report_fatal_error("Inline asm with 128 bit operands is only "
4427                            "supported for sm_70 and higher!");
4428       return std::make_pair(0U, &NVPTX::Int128RegsRegClass);
4429     }
4430     case 'f':
4431       return std::make_pair(0U, &NVPTX::Float32RegsRegClass);
4432     case 'd':
4433       return std::make_pair(0U, &NVPTX::Float64RegsRegClass);
4434     }
4435   }
4436   return TargetLowering::getRegForInlineAsmConstraint(TRI, Constraint, VT);
4437 }
4438 
4439 //===----------------------------------------------------------------------===//
4440 //                         NVPTX DAG Combining
4441 //===----------------------------------------------------------------------===//
4442 
4443 bool NVPTXTargetLowering::allowFMA(MachineFunction &MF,
4444                                    CodeGenOptLevel OptLevel) const {
4445   // Always honor command-line argument
4446   if (FMAContractLevelOpt.getNumOccurrences() > 0)
4447     return FMAContractLevelOpt > 0;
4448 
4449   // Do not contract if we're not optimizing the code.
4450   if (OptLevel == CodeGenOptLevel::None)
4451     return false;
4452 
4453   // Honor TargetOptions flags that explicitly say fusion is okay.
4454   if (MF.getTarget().Options.AllowFPOpFusion == FPOpFusion::Fast)
4455     return true;
4456 
4457   return allowUnsafeFPMath(MF);
4458 }
4459 
4460 bool NVPTXTargetLowering::allowUnsafeFPMath(MachineFunction &MF) const {
4461   // Honor TargetOptions flags that explicitly say unsafe math is okay.
4462   if (MF.getTarget().Options.UnsafeFPMath)
4463     return true;
4464 
4465   // Allow unsafe math if unsafe-fp-math attribute explicitly says so.
4466   const Function &F = MF.getFunction();
4467   return F.getFnAttribute("unsafe-fp-math").getValueAsBool();
4468 }
4469 
4470 static bool isConstZero(const SDValue &Operand) {
4471   const auto *Const = dyn_cast<ConstantSDNode>(Operand);
4472   return Const && Const->getZExtValue() == 0;
4473 }
4474 
4475 /// PerformADDCombineWithOperands - Try DAG combinations for an ADD with
4476 /// operands N0 and N1.  This is a helper for PerformADDCombine that is
4477 /// called with the default operands, and if that fails, with commuted
4478 /// operands.
4479 static SDValue
4480 PerformADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
4481                               TargetLowering::DAGCombinerInfo &DCI) {
4482   EVT VT = N0.getValueType();
4483 
4484   // Since integer multiply-add costs the same as integer multiply
4485   // but is more costly than integer add, do the fusion only when
4486   // the mul is only used in the add.
4487   // TODO: this may not be true for later architectures, consider relaxing this
4488   if (!N0.getNode()->hasOneUse())
4489     return SDValue();
4490 
4491   // fold (add (select cond, 0, (mul a, b)), c)
4492   //   -> (select cond, c, (add (mul a, b), c))
4493   //
4494   if (N0.getOpcode() == ISD::SELECT) {
4495     unsigned ZeroOpNum;
4496     if (isConstZero(N0->getOperand(1)))
4497       ZeroOpNum = 1;
4498     else if (isConstZero(N0->getOperand(2)))
4499       ZeroOpNum = 2;
4500     else
4501       return SDValue();
4502 
4503     SDValue M = N0->getOperand((ZeroOpNum == 1) ? 2 : 1);
4504     if (M->getOpcode() != ISD::MUL || !M.getNode()->hasOneUse())
4505       return SDValue();
4506 
4507     SDLoc DL(N);
4508     SDValue Mul =
4509         DCI.DAG.getNode(ISD::MUL, DL, VT, M->getOperand(0), M->getOperand(1));
4510     SDValue MAD = DCI.DAG.getNode(ISD::ADD, DL, VT, Mul, N1);
4511     return DCI.DAG.getSelect(SDLoc(N), VT, N0->getOperand(0),
4512                              ((ZeroOpNum == 1) ? N1 : MAD),
4513                              ((ZeroOpNum == 1) ? MAD : N1));
4514   }
4515 
4516   return SDValue();
4517 }
4518 
4519 static SDValue
4520 PerformFADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
4521                                TargetLowering::DAGCombinerInfo &DCI,
4522                                CodeGenOptLevel OptLevel) {
4523   EVT VT = N0.getValueType();
4524   if (N0.getOpcode() == ISD::FMUL) {
4525     const auto *TLI = static_cast<const NVPTXTargetLowering *>(
4526         &DCI.DAG.getTargetLoweringInfo());
4527     if (!TLI->allowFMA(DCI.DAG.getMachineFunction(), OptLevel))
4528       return SDValue();
4529 
4530     // For floating point:
4531     // Do the fusion only when the mul has fewer than 5 uses, all of
4532     // which are adds.
4533     // The heuristic is that if a use is not an add, then that use
4534     // cannot be fused into an fma, so the mul is still needed anyway.
4535     // If there are more than 4 uses, even if they are all adds, fusing
4536     // them will increase register pressure.
4537     //
4538     int numUses = 0;
4539     int nonAddCount = 0;
4540     for (const SDNode *User : N0.getNode()->users()) {
4541       numUses++;
4542       if (User->getOpcode() != ISD::FADD)
4543         ++nonAddCount;
4544       if (numUses >= 5)
4545         return SDValue();
4546     }
4547     if (nonAddCount) {
4548       int orderNo = N->getIROrder();
4549       int orderNo2 = N0.getNode()->getIROrder();
4550       // Simple heuristic for estimating potential register pressure:
4551       // the difference in IR order approximates the distance between
4552       // def and use, and a longer distance is more likely to cause
4553       // register pressure.
4554       if (orderNo - orderNo2 < 500)
4555         return SDValue();
4556 
4557       // Now, check if at least one of the FMUL's operands is live beyond the
4558       // node N, which guarantees that the FMA will not increase register
4559       // pressure at node N.
4560       bool opIsLive = false;
4561       const SDNode *left = N0.getOperand(0).getNode();
4562       const SDNode *right = N0.getOperand(1).getNode();
4563 
4564       if (isa<ConstantSDNode>(left) || isa<ConstantSDNode>(right))
4565         opIsLive = true;
4566 
4567       if (!opIsLive)
4568         for (const SDNode *User : left->users()) {
4569           int orderNo3 = User->getIROrder();
4570           if (orderNo3 > orderNo) {
4571             opIsLive = true;
4572             break;
4573           }
4574         }
4575 
4576       if (!opIsLive)
4577         for (const SDNode *User : right->users()) {
4578           int orderNo3 = User->getIROrder();
4579           if (orderNo3 > orderNo) {
4580             opIsLive = true;
4581             break;
4582           }
4583         }
4584 
4585       if (!opIsLive)
4586         return SDValue();
4587     }
4588 
4589     return DCI.DAG.getNode(ISD::FMA, SDLoc(N), VT, N0.getOperand(0),
4590                            N0.getOperand(1), N1);
4591   }
4592 
4593   return SDValue();
4594 }
4595 
4596 static SDValue PerformStoreCombineHelper(SDNode *N, std::size_t Front,
4597                                          std::size_t Back) {
4598   if (all_of(N->ops().drop_front(Front).drop_back(Back),
4599              [](const SDUse &U) { return U.get()->isUndef(); }))
4600     // Operand 0 is the previous value in the chain. Cannot return EntryToken
4601     // as the previous value will become unused and eliminated later.
4602     return N->getOperand(0);
4603 
4604   return SDValue();
4605 }
4606 
4607 static SDValue PerformStoreParamCombine(SDNode *N) {
4608   // Operands from the 3rd to the 2nd last one are the values to be stored.
4609   //   {Chain, ArgID, Offset, Val, Glue}
4610   return PerformStoreCombineHelper(N, 3, 1);
4611 }
4612 
4613 static SDValue PerformStoreRetvalCombine(SDNode *N) {
4614   // Operands from the 2nd to the last one are the values to be stored
4615   return PerformStoreCombineHelper(N, 2, 0);
4616 }
4617 
4618 /// PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
4619 ///
4620 static SDValue PerformADDCombine(SDNode *N,
4621                                  TargetLowering::DAGCombinerInfo &DCI,
4622                                  CodeGenOptLevel OptLevel) {
4623   if (OptLevel == CodeGenOptLevel::None)
4624     return SDValue();
4625 
4626   SDValue N0 = N->getOperand(0);
4627   SDValue N1 = N->getOperand(1);
4628 
4629   // Only handle the scalar i32 case; skip vectors and other types.
4630   EVT VT = N0.getValueType();
4631   if (VT.isVector() || VT != MVT::i32)
4632     return SDValue();
4633 
4634   // First try with the default operand order.
4635   if (SDValue Result = PerformADDCombineWithOperands(N, N0, N1, DCI))
4636     return Result;
4637 
4638   // If that didn't work, try again with the operands commuted.
4639   return PerformADDCombineWithOperands(N, N1, N0, DCI);
4640 }
4641 
4642 /// PerformFADDCombine - Target-specific dag combine xforms for ISD::FADD.
4643 ///
4644 static SDValue PerformFADDCombine(SDNode *N,
4645                                  TargetLowering::DAGCombinerInfo &DCI,
4646                                  CodeGenOptLevel OptLevel) {
4647   SDValue N0 = N->getOperand(0);
4648   SDValue N1 = N->getOperand(1);
4649 
4650   EVT VT = N0.getValueType();
4651   if (VT.isVector() || !(VT == MVT::f32 || VT == MVT::f64))
4652     return SDValue();
4653 
4654   // First try with the default operand order.
4655   if (SDValue Result = PerformFADDCombineWithOperands(N, N0, N1, DCI, OptLevel))
4656     return Result;
4657 
4658   // If that didn't work, try again with the operands commuted.
4659   return PerformFADDCombineWithOperands(N, N1, N0, DCI, OptLevel);
4660 }
4661 
4662 static SDValue PerformANDCombine(SDNode *N,
4663                                  TargetLowering::DAGCombinerInfo &DCI) {
4664   // The type legalizer turns a vector load of i8 values into a zextload to i16
4665   // registers, optionally ANY_EXTENDs it (if target type is integer),
4666   // and ANDs off the high 8 bits. Since we turn this load into a
4667   // target-specific DAG node, the DAG combiner fails to eliminate these AND
4668   // nodes. Do that here.
4669   SDValue Val = N->getOperand(0);
4670   SDValue Mask = N->getOperand(1);
4671 
4672   if (isa<ConstantSDNode>(Val)) {
4673     std::swap(Val, Mask);
4674   }
4675 
4676   SDValue AExt;
4677 
4678   // Convert (BFE -> truncate i16 -> and 255)
4679   // to just (BFE -> truncate i16), as the value already has all the bits in
4680   // the right places.
4681   if (Val.getOpcode() == ISD::TRUNCATE) {
4682     SDValue BFE = Val.getOperand(0);
4683     if (BFE.getOpcode() != NVPTXISD::BFE)
4684       return SDValue();
4685 
4686     ConstantSDNode *BFEBits = dyn_cast<ConstantSDNode>(BFE.getOperand(0));
4687     if (!BFEBits)
4688       return SDValue();
4689     uint64_t BFEBitsVal = BFEBits->getZExtValue();
4690 
4691     ConstantSDNode *MaskCnst = dyn_cast<ConstantSDNode>(Mask);
4692     if (!MaskCnst) {
4693       // Not an AND with a constant
4694       return SDValue();
4695     }
4696     uint64_t MaskVal = MaskCnst->getZExtValue();
4697 
4698     if (MaskVal != (uint64_t(1) << BFEBitsVal) - 1)
4699       return SDValue();
4700     // If we get here, the AND is unnecessary.  Just replace it with the trunc
4701     DCI.CombineTo(N, Val, false);
4702   }
4703   // Generally, we will see zextload -> IMOV16rr -> ANY_EXTEND -> and
4704   if (Val.getOpcode() == ISD::ANY_EXTEND) {
4705     AExt = Val;
4706     Val = Val->getOperand(0);
4707   }
4708 
4709   if (Val->getOpcode() == NVPTXISD::LoadV2 ||
4710       Val->getOpcode() == NVPTXISD::LoadV4) {
4711     ConstantSDNode *MaskCnst = dyn_cast<ConstantSDNode>(Mask);
4712     if (!MaskCnst) {
4713       // Not an AND with a constant
4714       return SDValue();
4715     }
4716 
4717     uint64_t MaskVal = MaskCnst->getZExtValue();
4718     if (MaskVal != 0xff) {
4719       // Not an AND that chops off top 8 bits
4720       return SDValue();
4721     }
4722 
4723     MemSDNode *Mem = dyn_cast<MemSDNode>(Val);
4724     if (!Mem) {
4725       // Not a MemSDNode?!?
4726       return SDValue();
4727     }
4728 
4729     EVT MemVT = Mem->getMemoryVT();
4730     if (MemVT != MVT::v2i8 && MemVT != MVT::v4i8) {
4731       // We only handle the i8 case
4732       return SDValue();
4733     }
4734 
4735     unsigned ExtType = Val->getConstantOperandVal(Val->getNumOperands() - 1);
4736     if (ExtType == ISD::SEXTLOAD) {
4737       // If for some reason the load is a sextload, the and is needed to zero
4738       // out the high 8 bits
4739       return SDValue();
4740     }
4741 
4742     bool AddTo = false;
4743     if (AExt.getNode() != nullptr) {
4744       // Re-insert the ext as a zext.
4745       Val = DCI.DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N),
4746                             AExt.getValueType(), Val);
4747       AddTo = true;
4748     }
4749 
4750     // If we get here, the AND is unnecessary.  Just replace it with the load
4751     DCI.CombineTo(N, Val, AddTo);
4752   }
4753 
4754   return SDValue();
4755 }
4756 
4757 static SDValue PerformREMCombine(SDNode *N,
4758                                  TargetLowering::DAGCombinerInfo &DCI,
4759                                  CodeGenOptLevel OptLevel) {
4760   assert(N->getOpcode() == ISD::SREM || N->getOpcode() == ISD::UREM);
4761 
4762   // Don't do anything at less than -O2.
4763   if (OptLevel < CodeGenOptLevel::Default)
4764     return SDValue();
4765 
4766   SelectionDAG &DAG = DCI.DAG;
4767   SDLoc DL(N);
4768   EVT VT = N->getValueType(0);
4769   bool IsSigned = N->getOpcode() == ISD::SREM;
4770   unsigned DivOpc = IsSigned ? ISD::SDIV : ISD::UDIV;
4771 
4772   const SDValue &Num = N->getOperand(0);
4773   const SDValue &Den = N->getOperand(1);
4774 
4775   for (const SDNode *U : Num->users()) {
4776     if (U->getOpcode() == DivOpc && U->getOperand(0) == Num &&
4777         U->getOperand(1) == Den) {
4778       // Num % Den -> Num - (Num / Den) * Den
4779       return DAG.getNode(ISD::SUB, DL, VT, Num,
4780                          DAG.getNode(ISD::MUL, DL, VT,
4781                                      DAG.getNode(DivOpc, DL, VT, Num, Den),
4782                                      Den));
4783     }
4784   }
4785   return SDValue();
4786 }
4787 
4788 enum OperandSignedness {
4789   Signed = 0,
4790   Unsigned,
4791   Unknown
4792 };
4793 
4794 /// IsMulWideOperandDemotable - Checks if the provided DAG node is an operand
4795 /// that can be demoted to \p OptSize bits without loss of information. The
4796 /// signedness of the operand, if determinable, is placed in \p S.
4797 static bool IsMulWideOperandDemotable(SDValue Op,
4798                                       unsigned OptSize,
4799                                       OperandSignedness &S) {
4800   S = Unknown;
4801 
4802   if (Op.getOpcode() == ISD::SIGN_EXTEND ||
4803       Op.getOpcode() == ISD::SIGN_EXTEND_INREG) {
4804     EVT OrigVT = Op.getOperand(0).getValueType();
4805     if (OrigVT.getFixedSizeInBits() <= OptSize) {
4806       S = Signed;
4807       return true;
4808     }
4809   } else if (Op.getOpcode() == ISD::ZERO_EXTEND) {
4810     EVT OrigVT = Op.getOperand(0).getValueType();
4811     if (OrigVT.getFixedSizeInBits() <= OptSize) {
4812       S = Unsigned;
4813       return true;
4814     }
4815   }
4816 
4817   return false;
4818 }
4819 
4820 /// AreMulWideOperandsDemotable - Checks if the given LHS and RHS operands can
4821 /// be demoted to \p OptSize bits without loss of information. If the operands
4822 /// contain a constant, it should appear as the RHS operand. The signedness of
4823 /// the operands is placed in \p IsSigned.
4824 static bool AreMulWideOperandsDemotable(SDValue LHS, SDValue RHS,
4825                                         unsigned OptSize,
4826                                         bool &IsSigned) {
4827   OperandSignedness LHSSign;
4828 
4829   // The LHS operand must be a demotable op
4830   if (!IsMulWideOperandDemotable(LHS, OptSize, LHSSign))
4831     return false;
4832 
4833   // We should have been able to determine the signedness from the LHS
4834   if (LHSSign == Unknown)
4835     return false;
4836 
4837   IsSigned = (LHSSign == Signed);
4838 
4839   // The RHS can be a demotable op or a constant
4840   if (ConstantSDNode *CI = dyn_cast<ConstantSDNode>(RHS)) {
4841     const APInt &Val = CI->getAPIntValue();
4842     if (LHSSign == Unsigned) {
4843       return Val.isIntN(OptSize);
4844     } else {
4845       return Val.isSignedIntN(OptSize);
4846     }
4847   } else {
4848     OperandSignedness RHSSign;
4849     if (!IsMulWideOperandDemotable(RHS, OptSize, RHSSign))
4850       return false;
4851 
4852     return LHSSign == RHSSign;
4853   }
4854 }
4855 
4856 /// TryMULWIDECombine - Attempt to replace a multiply of M bits with a multiply
4857 /// of M/2 bits that produces an M-bit result (i.e. mul.wide). This transform
4858 /// works on both multiply DAG nodes and SHL DAG nodes with a constant shift
4859 /// amount.
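/// For example (illustrative):
///   (mul i32 (sext i16 %a), (sext i16 %b)) -> (MUL_WIDE_SIGNED i32 %a, %b)
///   (shl i32 (zext i16 %a), 4)             -> (MUL_WIDE_UNSIGNED i32 %a, 16)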
4860 static SDValue TryMULWIDECombine(SDNode *N,
4861                                  TargetLowering::DAGCombinerInfo &DCI) {
4862   EVT MulType = N->getValueType(0);
4863   if (MulType != MVT::i32 && MulType != MVT::i64) {
4864     return SDValue();
4865   }
4866 
4867   SDLoc DL(N);
4868   unsigned OptSize = MulType.getSizeInBits() >> 1;
4869   SDValue LHS = N->getOperand(0);
4870   SDValue RHS = N->getOperand(1);
4871 
4872   // Canonicalize the multiply so the constant (if any) is on the right
4873   if (N->getOpcode() == ISD::MUL) {
4874     if (isa<ConstantSDNode>(LHS)) {
4875       std::swap(LHS, RHS);
4876     }
4877   }
4878 
4879   // If we have a SHL, determine the actual multiply amount
4880   if (N->getOpcode() == ISD::SHL) {
4881     ConstantSDNode *ShlRHS = dyn_cast<ConstantSDNode>(RHS);
4882     if (!ShlRHS) {
4883       return SDValue();
4884     }
4885 
4886     APInt ShiftAmt = ShlRHS->getAPIntValue();
4887     unsigned BitWidth = MulType.getSizeInBits();
4888     if (ShiftAmt.sge(0) && ShiftAmt.slt(BitWidth)) {
4889       APInt MulVal = APInt(BitWidth, 1) << ShiftAmt;
4890       RHS = DCI.DAG.getConstant(MulVal, DL, MulType);
4891     } else {
4892       return SDValue();
4893     }
4894   }
4895 
4896   bool Signed;
4897   // Verify that our operands are demotable
4898   if (!AreMulWideOperandsDemotable(LHS, RHS, OptSize, Signed)) {
4899     return SDValue();
4900   }
4901 
4902   EVT DemotedVT;
4903   if (MulType == MVT::i32) {
4904     DemotedVT = MVT::i16;
4905   } else {
4906     DemotedVT = MVT::i32;
4907   }
4908 
4909   // Truncate the operands to the correct size. Note that these are just for
4910   // type consistency and will (likely) be eliminated in later phases.
4911   SDValue TruncLHS =
4912     DCI.DAG.getNode(ISD::TRUNCATE, DL, DemotedVT, LHS);
4913   SDValue TruncRHS =
4914     DCI.DAG.getNode(ISD::TRUNCATE, DL, DemotedVT, RHS);
4915 
4916   unsigned Opc;
4917   if (Signed) {
4918     Opc = NVPTXISD::MUL_WIDE_SIGNED;
4919   } else {
4920     Opc = NVPTXISD::MUL_WIDE_UNSIGNED;
4921   }
4922 
4923   return DCI.DAG.getNode(Opc, DL, MulType, TruncLHS, TruncRHS);
4924 }
4925 
4926 static bool isConstOne(const SDValue &Operand) {
4927   const auto *Const = dyn_cast<ConstantSDNode>(Operand);
4928   return Const && Const->getZExtValue() == 1;
4929 }
4930 
4931 static SDValue matchMADConstOnePattern(SDValue Add) {
4932   if (Add->getOpcode() != ISD::ADD)
4933     return SDValue();
4934 
4935   if (isConstOne(Add->getOperand(0)))
4936     return Add->getOperand(1);
4937 
4938   if (isConstOne(Add->getOperand(1)))
4939     return Add->getOperand(0);
4940 
4941   return SDValue();
4942 }
4943 
4944 static SDValue combineMADConstOne(SDValue X, SDValue Add, EVT VT, SDLoc DL,
4945                                   TargetLowering::DAGCombinerInfo &DCI) {
4946 
4947   if (SDValue Y = matchMADConstOnePattern(Add)) {
4948     SDValue Mul = DCI.DAG.getNode(ISD::MUL, DL, VT, X, Y);
4949     return DCI.DAG.getNode(ISD::ADD, DL, VT, Mul, X);
4950   }
4951 
4952   return SDValue();
4953 }
4954 
4955 static SDValue combineMulSelectConstOne(SDValue X, SDValue Select, EVT VT,
4956                                         SDLoc DL,
4957                                         TargetLowering::DAGCombinerInfo &DCI) {
4958   if (Select->getOpcode() != ISD::SELECT)
4959     return SDValue();
4960 
4961   SDValue Cond = Select->getOperand(0);
4962 
4963   unsigned ConstOpNo;
4964   if (isConstOne(Select->getOperand(1)))
4965     ConstOpNo = 1;
4966   else if (isConstOne(Select->getOperand(2)))
4967     ConstOpNo = 2;
4968   else
4969     return SDValue();
4970 
4971   SDValue Y = Select->getOperand((ConstOpNo == 1) ? 2 : 1);
4972 
4973   // Do not combine if the resulting sequence is not obviously profitable.
4974   if (!matchMADConstOnePattern(Y))
4975     return SDValue();
4976 
4977   SDValue NewMul = DCI.DAG.getNode(ISD::MUL, DL, VT, X, Y);
4978 
4979   return DCI.DAG.getNode(ISD::SELECT, DL, VT, Cond,
4980                          (ConstOpNo == 1) ? X : NewMul,
4981                          (ConstOpNo == 1) ? NewMul : X);
4982 }
4983 
4984 static SDValue
4985 PerformMULCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
4986                               TargetLowering::DAGCombinerInfo &DCI) {
4987 
4988   EVT VT = N0.getValueType();
4989   if (VT.isVector())
4990     return SDValue();
4991 
4992   if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64)
4993     return SDValue();
4994 
4995   SDLoc DL(N);
4996 
4997   // (mul x, (add y, 1)) -> (add (mul x, y), x)
4998   if (SDValue Res = combineMADConstOne(N0, N1, VT, DL, DCI))
4999     return Res;
5000   if (SDValue Res = combineMADConstOne(N1, N0, VT, DL, DCI))
5001     return Res;
5002 
5003   // (mul x, (select cond, 1, y)) -> (select cond, x, (mul x, y))
5004   if (SDValue Res = combineMulSelectConstOne(N0, N1, VT, DL, DCI))
5005     return Res;
5006   if (SDValue Res = combineMulSelectConstOne(N1, N0, VT, DL, DCI))
5007     return Res;
5008 
5009   return SDValue();
5010 }
5011 
5012 /// PerformMULCombine - Runs PTX-specific DAG combine patterns on MUL nodes.
5013 static SDValue PerformMULCombine(SDNode *N,
5014                                  TargetLowering::DAGCombinerInfo &DCI,
5015                                  CodeGenOptLevel OptLevel) {
5016   if (OptLevel == CodeGenOptLevel::None)
5017     return SDValue();
5018 
5019   if (SDValue Ret = TryMULWIDECombine(N, DCI))
5020     return Ret;
5021 
5022   SDValue N0 = N->getOperand(0);
5023   SDValue N1 = N->getOperand(1);
5024   return PerformMULCombineWithOperands(N, N0, N1, DCI);
5025 }
5026 
5027 /// PerformSHLCombine - Runs PTX-specific DAG combine patterns on SHL nodes.
5028 static SDValue PerformSHLCombine(SDNode *N,
5029                                  TargetLowering::DAGCombinerInfo &DCI,
5030                                  CodeGenOptLevel OptLevel) {
5031   if (OptLevel > CodeGenOptLevel::None) {
5032     // Try mul.wide combining at OptLevel > 0
5033     if (SDValue Ret = TryMULWIDECombine(N, DCI))
5034       return Ret;
5035   }
5036 
5037   return SDValue();
5038 }
5039 
5040 static SDValue PerformSETCCCombine(SDNode *N,
5041                                    TargetLowering::DAGCombinerInfo &DCI,
5042                                    unsigned int SmVersion) {
5043   EVT CCType = N->getValueType(0);
5044   SDValue A = N->getOperand(0);
5045   SDValue B = N->getOperand(1);
5046 
5047   EVT AType = A.getValueType();
5048   if (!(CCType == MVT::v2i1 && (AType == MVT::v2f16 || AType == MVT::v2bf16)))
5049     return SDValue();
5050 
5051   if (A.getValueType() == MVT::v2bf16 && SmVersion < 90)
5052     return SDValue();
5053 
5054   SDLoc DL(N);
5055   // setp.f16x2 returns two scalar predicates, which we need to
5056   // convert back to v2i1. The returned result will be scalarized by
5057   // the legalizer, but the comparison will remain a single vector
5058   // instruction.
5059   SDValue CCNode = DCI.DAG.getNode(
5060       A.getValueType() == MVT::v2f16 ? NVPTXISD::SETP_F16X2
5061                                      : NVPTXISD::SETP_BF16X2,
5062       DL, DCI.DAG.getVTList(MVT::i1, MVT::i1), {A, B, N->getOperand(2)});
5063   return DCI.DAG.getNode(ISD::BUILD_VECTOR, DL, CCType, CCNode.getValue(0),
5064                          CCNode.getValue(1));
5065 }
5066 
5067 static SDValue PerformEXTRACTCombine(SDNode *N,
5068                                      TargetLowering::DAGCombinerInfo &DCI) {
5069   SDValue Vector = N->getOperand(0);
5070   if (Vector->getOpcode() == ISD::FREEZE)
5071     Vector = Vector->getOperand(0);
5072   SDLoc DL(N);
5073   EVT VectorVT = Vector.getValueType();
5074   if (Vector->getOpcode() == ISD::LOAD && VectorVT.isSimple() &&
5075       IsPTXVectorType(VectorVT.getSimpleVT()))
5076     return SDValue(); // Native vector loads already combine nicely w/
5077                       // extract_vector_elt.
5078   // Don't mess with singletons or v2*16, v4i8 and v8i8 types, we already
5079   // handle them OK.
5080   if (VectorVT.getVectorNumElements() == 1 || Isv2x16VT(VectorVT) ||
5081       VectorVT == MVT::v4i8 || VectorVT == MVT::v8i8)
5082     return SDValue();
5083 
5084   // Don't mess with undef values as sra may be simplified to 0, not undef.
5085   if (Vector->isUndef() || ISD::allOperandsUndef(Vector.getNode()))
5086     return SDValue();
5087 
5088   uint64_t VectorBits = VectorVT.getSizeInBits();
5089   // We only handle the types we can extract in-register.
5090   if (!(VectorBits == 16 || VectorBits == 32 || VectorBits == 64))
5091     return SDValue();
5092 
5093   ConstantSDNode *Index = dyn_cast<ConstantSDNode>(N->getOperand(1));
5094   // Index == 0 is handled by generic DAG combiner.
5095   if (!Index || Index->getZExtValue() == 0)
5096     return SDValue();
5097 
5098   MVT IVT = MVT::getIntegerVT(VectorBits);
5099   EVT EltVT = VectorVT.getVectorElementType();
5100   EVT EltIVT = EltVT.changeTypeToInteger();
5101   uint64_t EltBits = EltVT.getScalarSizeInBits();
5102 
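  // View the whole vector as an integer, shift the selected element down to
  // bit 0, and truncate. For example (illustrative), extracting element 1 of
  // a v2f16 becomes:
  //   (bitcast f16 (trunc i16 (sra i32 (bitcast i32 %vec), 16)))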
5103   SDValue Result = DCI.DAG.getNode(
5104       ISD::TRUNCATE, DL, EltIVT,
5105       DCI.DAG.getNode(
5106           ISD::SRA, DL, IVT, DCI.DAG.getNode(ISD::BITCAST, DL, IVT, Vector),
5107           DCI.DAG.getConstant(Index->getZExtValue() * EltBits, DL, IVT)));
5108 
5109   // If element has non-integer type, bitcast it back to the expected type.
5110   if (EltVT != EltIVT)
5111     Result = DCI.DAG.getNode(ISD::BITCAST, DL, EltVT, Result);
5112   // Past the legalizer, we may need to extend i8 -> i16 to match the register type.
5113   if (EltVT != N->getValueType(0))
5114     Result = DCI.DAG.getNode(ISD::ANY_EXTEND, DL, N->getValueType(0), Result);
5115 
5116   return Result;
5117 }
5118 
5119 static SDValue PerformVSELECTCombine(SDNode *N,
5120                                      TargetLowering::DAGCombinerInfo &DCI) {
5121   SDValue VA = N->getOperand(1);
5122   EVT VectorVT = VA.getValueType();
5123   if (VectorVT != MVT::v4i8)
5124     return SDValue();
5125 
5126   // We need to split the vselect into individual per-element operations. Because
5127   // we use BFE/BFI instructions for byte extraction/insertion, we end up with
5128   // 32-bit values anyway, so we may as well do the comparison as i32 to avoid
5129   // the conversions to/from i16 normally used for i8 values.
5130   SmallVector<SDValue, 4> E;
5131   SDLoc DL(N);
5132   SDValue VCond = N->getOperand(0);
5133   SDValue VB = N->getOperand(2);
5134   for (int I = 0; I < 4; ++I) {
5135     SDValue C = DCI.DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i1, VCond,
5136                                 DCI.DAG.getConstant(I, DL, MVT::i32));
5137     SDValue EA = DCI.DAG.getAnyExtOrTrunc(
5138         DCI.DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8, VA,
5139                         DCI.DAG.getConstant(I, DL, MVT::i32)),
5140         DL, MVT::i32);
5141     SDValue EB = DCI.DAG.getAnyExtOrTrunc(
5142         DCI.DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8, VB,
5143                         DCI.DAG.getConstant(I, DL, MVT::i32)),
5144         DL, MVT::i32);
5145     E.push_back(DCI.DAG.getAnyExtOrTrunc(
5146         DCI.DAG.getNode(ISD::SELECT, DL, MVT::i32, C, EA, EB), DL, MVT::i8));
5147   }
5148   return DCI.DAG.getNode(ISD::BUILD_VECTOR, DL, MVT::v4i8, E);
5149 }
5150 
5151 static SDValue
5152 PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
5153   auto VT = N->getValueType(0);
5154   if (!DCI.isAfterLegalizeDAG() || !Isv2x16VT(VT))
5155     return SDValue();
5156 
5157   auto Op0 = N->getOperand(0);
5158   auto Op1 = N->getOperand(1);
5159 
5160   // Start out by assuming we want to take the lower 2 bytes of each i32
5161   // operand.
5162   uint64_t Op0Bytes = 0x10;
5163   uint64_t Op1Bytes = 0x54;
5164 
5165   std::pair<SDValue *, uint64_t *> OpData[2] = {{&Op0, &Op0Bytes},
5166                                                 {&Op1, &Op1Bytes}};
5167 
5168   // Check that each operand is an i16, truncated from an i32 operand. We'll
5169   // select individual bytes from those original operands. Optionally, fold in a
5170   // shift right of that original operand.
5171   for (auto &[Op, OpBytes] : OpData) {
5172     // Eat up any bitcast
5173     if (Op->getOpcode() == ISD::BITCAST)
5174       *Op = Op->getOperand(0);
5175 
5176     if (!(Op->getValueType() == MVT::i16 && Op->getOpcode() == ISD::TRUNCATE &&
5177           Op->getOperand(0).getValueType() == MVT::i32))
5178       return SDValue();
5179 
5180     // If the truncate has multiple uses, this optimization can increase
5181     // register pressure
5182     if (!Op->hasOneUse())
5183       return SDValue();
5184 
5185     *Op = Op->getOperand(0);
5186 
5187     // Optionally, fold in a shift-right of the original operand and let permute
5188     // pick the two higher bytes of the original value directly.
5189     if (Op->getOpcode() == ISD::SRL && isa<ConstantSDNode>(Op->getOperand(1))) {
5190       if (cast<ConstantSDNode>(Op->getOperand(1))->getZExtValue() == 16) {
5191         // Shift the PRMT byte selector to pick upper bytes from each respective
5192         // value, instead of the lower ones: 0x10 -> 0x32, 0x54 -> 0x76
5193         assert((*OpBytes == 0x10 || *OpBytes == 0x54) &&
5194                "PRMT selector values out of range");
5195         *OpBytes += 0x22;
5196         *Op = Op->getOperand(0);
5197       }
5198     }
5199   }
5200 
5201   SDLoc DL(N);
5202   auto &DAG = DCI.DAG;
5203 
5204   auto PRMT = DAG.getNode(
5205       NVPTXISD::PRMT, DL, MVT::v4i8,
5206       {Op0, Op1, DAG.getConstant((Op1Bytes << 8) | Op0Bytes, DL, MVT::i32),
5207        DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32)});
5208   return DAG.getNode(ISD::BITCAST, DL, VT, PRMT);
5209 }
5210 
5211 SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
5212                                                DAGCombinerInfo &DCI) const {
5213   CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
5214   switch (N->getOpcode()) {
5215   default: break;
5216   case ISD::ADD:
5217     return PerformADDCombine(N, DCI, OptLevel);
5218   case ISD::FADD:
5219     return PerformFADDCombine(N, DCI, OptLevel);
5220   case ISD::MUL:
5221     return PerformMULCombine(N, DCI, OptLevel);
5222   case ISD::SHL:
5223     return PerformSHLCombine(N, DCI, OptLevel);
5224   case ISD::AND:
5225     return PerformANDCombine(N, DCI);
5226   case ISD::UREM:
5227   case ISD::SREM:
5228     return PerformREMCombine(N, DCI, OptLevel);
5229   case ISD::SETCC:
5230     return PerformSETCCCombine(N, DCI, STI.getSmVersion());
5231   case NVPTXISD::StoreRetval:
5232   case NVPTXISD::StoreRetvalV2:
5233   case NVPTXISD::StoreRetvalV4:
5234     return PerformStoreRetvalCombine(N);
5235   case NVPTXISD::StoreParam:
5236   case NVPTXISD::StoreParamV2:
5237   case NVPTXISD::StoreParamV4:
5238     return PerformStoreParamCombine(N);
5239   case ISD::EXTRACT_VECTOR_ELT:
5240     return PerformEXTRACTCombine(N, DCI);
5241   case ISD::VSELECT:
5242     return PerformVSELECTCombine(N, DCI);
5243   case ISD::BUILD_VECTOR:
5244     return PerformBUILD_VECTORCombine(N, DCI);
5245   }
5246   return SDValue();
5247 }
5248 
5249 static void ReplaceBITCAST(SDNode *Node, SelectionDAG &DAG,
5250                            SmallVectorImpl<SDValue> &Results) {
5251   // Handle bitcasting to v2i8 without hitting the default promotion
5252   // strategy, which goes through stack memory.
5253   SDValue Op(Node, 0);
5254   EVT ToVT = Op->getValueType(0);
5255   if (ToVT != MVT::v2i8)
5256     return;
5258 
5259   // Bitcast to i16 and unpack elements into a vector
5260   SDLoc DL(Node);
5261   SDValue AsInt = MaybeBitcast(DAG, DL, MVT::i16, Op->getOperand(0));
5262   SDValue Vec0 = DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, AsInt);
5263   SDValue Const8 = DAG.getConstant(8, DL, MVT::i16);
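       // The second element is the high byte: shift right by 8, then truncate.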
5264   SDValue Vec1 =
5265       DAG.getNode(ISD::TRUNCATE, DL, MVT::i8,
5266                   DAG.getNode(ISD::SRL, DL, MVT::i16, {AsInt, Const8}));
5267   Results.push_back(
5268       DAG.getNode(ISD::BUILD_VECTOR, DL, MVT::v2i8, {Vec0, Vec1}));
5269 }
5270 
5271 /// ReplaceLoadVector - Convert vector loads into multi-output scalar loads.
5272 static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
5273                               SmallVectorImpl<SDValue> &Results) {
5274   EVT ResVT = N->getValueType(0);
5275   SDLoc DL(N);
5276 
5277   assert(ResVT.isVector() && "Vector load must have vector type");
5278 
5279   auto NumEltsAndEltVT = getVectorLoweringShape(ResVT);
5280   if (!NumEltsAndEltVT)
5281     return;
5282   auto [NumElts, EltVT] = NumEltsAndEltVT.value();
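       // NumElts is the number of results the LoadV2/LoadV4 node will produce.
       // EltVT may itself be a packed 32-bit vector (v2[i,f,bf]16 or v4i8), in
       // which case the elements are split back out after the load below.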
5283 
5284   LoadSDNode *LD = cast<LoadSDNode>(N);
5285 
5286   Align Alignment = LD->getAlign();
5287   auto &TD = DAG.getDataLayout();
5288   Align PrefAlign =
5289       TD.getPrefTypeAlign(LD->getMemoryVT().getTypeForEVT(*DAG.getContext()));
5290   if (Alignment < PrefAlign) {
5291     // This load is not sufficiently aligned, so bail out and let this vector
5292     // load be scalarized.  Note that we may still be able to emit smaller
5293     // vector loads.  For example, if we are loading a <4 x float> with an
5294     // alignment of 8, this check will fail but the legalizer will try again
5295     // with 2 x <2 x float>, which will succeed with an alignment of 8.
5296     return;
5297   }
5298 
5299   // Since LoadV2 is a target node, we cannot rely on DAG type legalization.
5300   // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
5301   // loaded type to i16 and propagate the "real" type as the memory type.
5302   bool NeedTrunc = false;
5303   if (EltVT.getSizeInBits() < 16) {
5304     EltVT = MVT::i16;
5305     NeedTrunc = true;
5306   }
5307 
5308   unsigned Opcode = 0;
5309   SDVTList LdResVTs;
5310 
5311   switch (NumElts) {
5312   default:
5313     return;
5314   case 2:
5315     Opcode = NVPTXISD::LoadV2;
5316     LdResVTs = DAG.getVTList(EltVT, EltVT, MVT::Other);
5317     break;
5318   case 4: {
5319     Opcode = NVPTXISD::LoadV4;
5320     EVT ListVTs[] = {EltVT, EltVT, EltVT, EltVT, MVT::Other};
5321     LdResVTs = DAG.getVTList(ListVTs);
5322     break;
5323   }
5324   }
5325 
5326   // Copy regular operands
5327   SmallVector<SDValue, 8> OtherOps(N->ops());
5328 
5329   // The select routine does not have access to the LoadSDNode instance, so
5330   // pass along the extension information.
5331   OtherOps.push_back(DAG.getIntPtrConstant(LD->getExtensionType(), DL));
5332 
5333   SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, OtherOps,
5334                                           LD->getMemoryVT(),
5335                                           LD->getMemOperand());
5336 
5337   SmallVector<SDValue> ScalarRes;
5338   assert(NumElts <= ResVT.getVectorNumElements() &&
5339          "NumElts should not increase, only decrease or stay the same.");
5340   if (NumElts < ResVT.getVectorNumElements()) {
5341     // If the number of elements has decreased, getVectorLoweringShape has
5342     // upsized the element types
5343     assert(EltVT.isVector() && EltVT.getSizeInBits() == 32 &&
5344            EltVT.getVectorNumElements() <= 4 && "Unexpected upsized type.");
5345     // Generate EXTRACT_VECTOR_ELTs to split v2[i,f,bf]16/v4i8 subvectors back
5346     // into individual elements.
5347     for (unsigned i = 0; i < NumElts; ++i) {
5348       SDValue SubVector = NewLD.getValue(i);
5349       DAG.ExtractVectorElements(SubVector, ScalarRes);
5350     }
5351   } else {
5352     for (unsigned i = 0; i < NumElts; ++i) {
5353       SDValue Res = NewLD.getValue(i);
5354       if (NeedTrunc)
5355         Res = DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res);
5356       ScalarRes.push_back(Res);
5357     }
5358   }
5359 
5360   SDValue LoadChain = NewLD.getValue(NumElts);
5361 
5362   SDValue BuildVec = DAG.getBuildVector(ResVT, DL, ScalarRes);
5363 
5364   Results.push_back(BuildVec);
5365   Results.push_back(LoadChain);
5366 }
5367 
5368 static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG,
5369                                      SmallVectorImpl<SDValue> &Results) {
5370   SDValue Chain = N->getOperand(0);
5371   SDValue Intrin = N->getOperand(1);
5372   SDLoc DL(N);
5373 
5374   // Get the intrinsic ID
5375   unsigned IntrinNo = Intrin.getNode()->getAsZExtVal();
5376   switch (IntrinNo) {
5377   default:
5378     return;
5379   case Intrinsic::nvvm_ldu_global_i:
5380   case Intrinsic::nvvm_ldu_global_f:
5381   case Intrinsic::nvvm_ldu_global_p: {
5382     EVT ResVT = N->getValueType(0);
5383 
5384     if (ResVT.isVector()) {
5385       // Vector LDG/LDU
5386 
5387       unsigned NumElts = ResVT.getVectorNumElements();
5388       EVT EltVT = ResVT.getVectorElementType();
5389 
5390       // Since LDU/LDG are target nodes, we cannot rely on DAG type
5391       // legalization. Therefore, we must ensure the type is legal. For i1 and
5392       // i8, we set the loaded type to i16 and propagate the "real" type as
5393       // the memory type.
5394       bool NeedTrunc = false;
5395       if (EltVT.getSizeInBits() < 16) {
5396         EltVT = MVT::i16;
5397         NeedTrunc = true;
5398       }
5399 
5400       unsigned Opcode = 0;
5401       SDVTList LdResVTs;
5402 
5403       switch (NumElts) {
5404       default:
5405         return;
5406       case 2:
5407         Opcode = NVPTXISD::LDUV2;
5408         LdResVTs = DAG.getVTList(EltVT, EltVT, MVT::Other);
5409         break;
5410       case 4: {
5411         Opcode = NVPTXISD::LDUV4;
5412         EVT ListVTs[] = {EltVT, EltVT, EltVT, EltVT, MVT::Other};
5413         LdResVTs = DAG.getVTList(ListVTs);
5414         break;
5415       }
5416       }
5417 
5418       SmallVector<SDValue, 8> OtherOps;
5419 
5420       // Copy the regular operands: the chain first, then every operand
5421       // after operand 1, skipping the intrinsic ID itself.
5422       OtherOps.push_back(Chain);
5425       OtherOps.append(N->op_begin() + 2, N->op_end());
5426 
5427       MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);
5428 
5429       SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, OtherOps,
5430                                               MemSD->getMemoryVT(),
5431                                               MemSD->getMemOperand());
5432 
5433       SmallVector<SDValue, 4> ScalarRes;
5434 
5435       for (unsigned i = 0; i < NumElts; ++i) {
5436         SDValue Res = NewLD.getValue(i);
5437         if (NeedTrunc)
5438           Res =
5439               DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res);
5440         ScalarRes.push_back(Res);
5441       }
5442 
5443       SDValue LoadChain = NewLD.getValue(NumElts);
5444 
5445       SDValue BuildVec = DAG.getBuildVector(ResVT, DL, ScalarRes);
5447 
5448       Results.push_back(BuildVec);
5449       Results.push_back(LoadChain);
5450     } else {
5451       // i8 LDG/LDU
5452       assert(ResVT.isSimple() && ResVT.getSimpleVT().SimpleTy == MVT::i8 &&
5453              "Custom handling of non-i8 ldu/ldg?");
5454 
5455       // Just copy all operands as-is
5456       SmallVector<SDValue, 4> Ops(N->ops());
5457 
5458       // Force output to i16
5459       SDVTList LdResVTs = DAG.getVTList(MVT::i16, MVT::Other);
5460 
5461       MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);
5462 
5463       // We make sure the memory type is i8, which will be used during isel
5464       // to select the proper instruction.
5465       SDValue NewLD =
5466           DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, LdResVTs, Ops,
5467                                   MVT::i8, MemSD->getMemOperand());
5468 
5469       Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i8,
5470                                     NewLD.getValue(0)));
5471       Results.push_back(NewLD.getValue(1));
5472     }
5473   }
5474   }
5475 }
5476 
5477 static void ReplaceCopyFromReg_128(SDNode *N, SelectionDAG &DAG,
5478                                    SmallVectorImpl<SDValue> &Results) {
5479   // Change the CopyFromReg to output two 64-bit results instead of a single
5480   // 128-bit result so that it can pass legalization.
5481   SDLoc DL(N);
5482   SDValue Chain = N->getOperand(0);
5483   SDValue Reg = N->getOperand(1);
5484   SDValue Glue = N->getOperand(2);
5485 
5486   assert(Reg.getValueType() == MVT::i128 &&
5487          "Custom lowering for CopyFromReg with 128-bit reg only");
5488   SmallVector<EVT, 4> ResultsType = {MVT::i64, MVT::i64, N->getValueType(1),
5489                                      N->getValueType(2)};
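       // Preserve the original node's trailing results (chain and glue) so
       // their users remain valid.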
5490   SmallVector<SDValue, 3> NewOps = {Chain, Reg, Glue};
5491 
5492   SDValue NewValue = DAG.getNode(ISD::CopyFromReg, DL, ResultsType, NewOps);
5493   SDValue Pair = DAG.getNode(ISD::BUILD_PAIR, DL, MVT::i128,
5494                              {NewValue.getValue(0), NewValue.getValue(1)});
5495 
5496   Results.push_back(Pair);
5497   Results.push_back(NewValue.getValue(2));
5498   Results.push_back(NewValue.getValue(3));
5499 }
5500 
5501 void NVPTXTargetLowering::ReplaceNodeResults(
5502     SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
5503   switch (N->getOpcode()) {
5504   default:
5505     report_fatal_error("Unhandled custom legalization");
5506   case ISD::BITCAST:
5507     ReplaceBITCAST(N, DAG, Results);
5508     return;
5509   case ISD::LOAD:
5510     ReplaceLoadVector(N, DAG, Results);
5511     return;
5512   case ISD::INTRINSIC_W_CHAIN:
5513     ReplaceINTRINSIC_W_CHAIN(N, DAG, Results);
5514     return;
5515   case ISD::CopyFromReg:
5516     ReplaceCopyFromReg_128(N, DAG, Results);
5517     return;
5518   }
5519 }
5520 
5521 NVPTXTargetLowering::AtomicExpansionKind
5522 NVPTXTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
5523   Type *Ty = AI->getValOperand()->getType();
5524 
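       // Floating-point RMW: fadd is the only operation lowered to a native
       // atomic here, and only when the SM/PTX versions support the element
       // type; everything else is expanded to a cmpxchg loop.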
5525   if (AI->isFloatingPointOperation()) {
5526     if (AI->getOperation() == AtomicRMWInst::BinOp::FAdd) {
5527       if (Ty->isHalfTy() && STI.getSmVersion() >= 70 &&
5528           STI.getPTXVersion() >= 63)
5529         return AtomicExpansionKind::None;
5530       if (Ty->isBFloatTy() && STI.getSmVersion() >= 90 &&
5531           STI.getPTXVersion() >= 78)
5532         return AtomicExpansionKind::None;
5533       if (Ty->isFloatTy())
5534         return AtomicExpansionKind::None;
5535       if (Ty->isDoubleTy() && STI.hasAtomAddF64())
5536         return AtomicExpansionKind::None;
5537     }
5538     return AtomicExpansionKind::CmpXChg;
5539   }
5540 
5541   assert(Ty->isIntegerTy() && "Ty should be integer at this point");
5542   auto ITy = cast<llvm::IntegerType>(Ty);
5543 
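       // Integer RMW: 32-bit forms are always native; 64-bit bitwise and
       // min/max forms depend on the corresponding subtarget features, and
       // 8/16-bit forms are expanded to a cmpxchg loop.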
5544   switch (AI->getOperation()) {
5545   default:
5546     return AtomicExpansionKind::CmpXChg;
5547   case AtomicRMWInst::BinOp::And:
5548   case AtomicRMWInst::BinOp::Or:
5549   case AtomicRMWInst::BinOp::Xor:
5550   case AtomicRMWInst::BinOp::Xchg:
5551     switch (ITy->getBitWidth()) {
5552     case 8:
5553     case 16:
5554       return AtomicExpansionKind::CmpXChg;
5555     case 32:
5556       return AtomicExpansionKind::None;
5557     case 64:
5558       if (STI.hasAtomBitwise64())
5559         return AtomicExpansionKind::None;
5560       return AtomicExpansionKind::CmpXChg;
5561     default:
5562       llvm_unreachable("unsupported width encountered");
5563     }
5564   case AtomicRMWInst::BinOp::Add:
5565   case AtomicRMWInst::BinOp::Sub:
5566   case AtomicRMWInst::BinOp::Max:
5567   case AtomicRMWInst::BinOp::Min:
5568   case AtomicRMWInst::BinOp::UMax:
5569   case AtomicRMWInst::BinOp::UMin:
5570     switch (ITy->getBitWidth()) {
5571     case 8:
5572     case 16:
5573       return AtomicExpansionKind::CmpXChg;
5574     case 32:
5575       return AtomicExpansionKind::None;
5576     case 64:
5577       if (STI.hasAtomMinMax64())
5578         return AtomicExpansionKind::None;
5579       return AtomicExpansionKind::CmpXChg;
5580     default:
5581       llvm_unreachable("unsupported width encountered");
5582     }
5583   }
5584 
5585   return AtomicExpansionKind::CmpXChg;
5586 }
5587 
5588 // Pin NVPTXTargetObjectFile's vtable to this file.
5589 NVPTXTargetObjectFile::~NVPTXTargetObjectFile() = default;
5590 
5591 MCSection *NVPTXTargetObjectFile::SelectSectionForGlobal(
5592     const GlobalObject *GO, SectionKind Kind, const TargetMachine &TM) const {
5593   return getDataSection();
5594 }
5595