//===-- NVPTXTargetTransformInfo.cpp - NVPTX specific TTI -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "NVPTXTargetTransformInfo.h"
#include "NVPTXUtilities.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/BasicTTIImpl.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Transforms/InstCombine/InstCombiner.h"
#include <optional>
using namespace llvm;

#define DEBUG_TYPE "NVPTXtti"

// Whether the given intrinsic reads threadIdx.x/y/z.
static bool readsThreadIndex(const IntrinsicInst *II) {
  switch (II->getIntrinsicID()) {
    default: return false;
    case Intrinsic::nvvm_read_ptx_sreg_tid_x:
    case Intrinsic::nvvm_read_ptx_sreg_tid_y:
    case Intrinsic::nvvm_read_ptx_sreg_tid_z:
      return true;
  }
}

static bool readsLaneId(const IntrinsicInst *II) {
  return II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_laneid;
}

// Whether the given intrinsic is an atomic instruction in PTX.
static bool isNVVMAtomic(const IntrinsicInst *II) {
  switch (II->getIntrinsicID()) {
    default: return false;
    case Intrinsic::nvvm_atomic_load_inc_32:
    case Intrinsic::nvvm_atomic_load_dec_32:

    case Intrinsic::nvvm_atomic_add_gen_f_cta:
    case Intrinsic::nvvm_atomic_add_gen_f_sys:
    case Intrinsic::nvvm_atomic_add_gen_i_cta:
    case Intrinsic::nvvm_atomic_add_gen_i_sys:
    case Intrinsic::nvvm_atomic_and_gen_i_cta:
    case Intrinsic::nvvm_atomic_and_gen_i_sys:
    case Intrinsic::nvvm_atomic_cas_gen_i_cta:
    case Intrinsic::nvvm_atomic_cas_gen_i_sys:
    case Intrinsic::nvvm_atomic_dec_gen_i_cta:
    case Intrinsic::nvvm_atomic_dec_gen_i_sys:
    case Intrinsic::nvvm_atomic_inc_gen_i_cta:
    case Intrinsic::nvvm_atomic_inc_gen_i_sys:
    case Intrinsic::nvvm_atomic_max_gen_i_cta:
    case Intrinsic::nvvm_atomic_max_gen_i_sys:
    case Intrinsic::nvvm_atomic_min_gen_i_cta:
    case Intrinsic::nvvm_atomic_min_gen_i_sys:
    case Intrinsic::nvvm_atomic_or_gen_i_cta:
    case Intrinsic::nvvm_atomic_or_gen_i_sys:
    case Intrinsic::nvvm_atomic_exch_gen_i_cta:
    case Intrinsic::nvvm_atomic_exch_gen_i_sys:
    case Intrinsic::nvvm_atomic_xor_gen_i_cta:
    case Intrinsic::nvvm_atomic_xor_gen_i_sys:
      return true;
  }
}

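// Returns true if V may evaluate to different values in different threads of
// the same warp, i.e. V is a potential source of divergence.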
bool NVPTXTTIImpl::isSourceOfDivergence(const Value *V) {
  // Without inter-procedural analysis, we conservatively assume that arguments
  // to __device__ functions are divergent.
  if (const Argument *Arg = dyn_cast<Argument>(V))
    return !isKernelFunction(*Arg->getParent());

  if (const Instruction *I = dyn_cast<Instruction>(V)) {
    // Without pointer analysis, we conservatively assume values loaded from
    // generic or local address space are divergent.
    if (const LoadInst *LI = dyn_cast<LoadInst>(I)) {
      unsigned AS = LI->getPointerAddressSpace();
      return AS == ADDRESS_SPACE_GENERIC || AS == ADDRESS_SPACE_LOCAL;
    }
    // Atomic instructions may cause divergence. They are executed sequentially
    // across all threads in a warp, so an earlier-executed thread may see
    // different memory inputs than a later-executed thread. For example,
    // suppose *a = 0 initially.
    //
    //   atom.global.add.s32 d, [a], 1
    //
    // returns 0 for the first thread that enters the critical region and 1 for
    // the second thread.
    if (I->isAtomic())
      return true;
    if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
      // Instructions that read threadIdx or laneid are obviously divergent.
      if (readsThreadIndex(II) || readsLaneId(II))
        return true;
      // Handle the NVPTX atomic intrinsics that cannot be represented as an
      // atomic IR instruction.
      if (isNVVMAtomic(II))
        return true;
    }
    // Conservatively consider the return value of function calls as divergent.
    // We could analyze callees with bodies more precisely using
    // inter-procedural analysis.
    if (isa<CallInst>(I))
      return true;
  }

  return false;
}

// Convert NVVM intrinsics to target-generic LLVM code where possible.
static Instruction *convertNvvmIntrinsicToLlvm(InstCombiner &IC,
                                               IntrinsicInst *II) {
  // Each NVVM intrinsic we can simplify can be replaced with one of:
  //
  //  * an LLVM intrinsic,
  //  * an LLVM cast operation,
  //  * an LLVM binary operation, or
  //  * ad-hoc LLVM IR for the particular operation.
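  //
  // For example, a call to @llvm.nvvm.ceil.f can generally be rewritten as a
  // call to the target-generic @llvm.ceil.f32, subject to the ftz check below.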

  // Some transformations are only valid when the module's
  // flush-denormals-to-zero (ftz) setting is true/false, whereas other
  // transformations are valid regardless of the module's ftz setting.
  enum FtzRequirementTy {
    FTZ_Any,       // Any ftz setting is ok.
    FTZ_MustBeOn,  // Transformation is valid only if ftz is on.
    FTZ_MustBeOff, // Transformation is valid only if ftz is off.
  };
  // Classes of NVVM intrinsics that can't be replaced one-to-one with a
  // target-generic intrinsic, cast op, or binary op but that we can nonetheless
  // simplify.
  enum SpecialCase {
    SPC_Reciprocal,
    SPC_FunnelShiftClamp,
  };

  // SimplifyAction is a poor-man's variant (plus an additional flag) that
  // represents how to replace an NVVM intrinsic with target-generic LLVM IR.
  struct SimplifyAction {
    // Invariant: At most one of these Optionals has a value.
    std::optional<Intrinsic::ID> IID;
    std::optional<Instruction::CastOps> CastOp;
    std::optional<Instruction::BinaryOps> BinaryOp;
    std::optional<SpecialCase> Special;

    FtzRequirementTy FtzRequirement = FTZ_Any;
    // Denormal handling is guarded by different attributes depending on the
    // type (denormal-fp-math vs denormal-fp-math-f32), so take note of half
    // types.
    bool IsHalfTy = false;

    SimplifyAction() = default;

    SimplifyAction(Intrinsic::ID IID, FtzRequirementTy FtzReq,
                   bool IsHalfTy = false)
        : IID(IID), FtzRequirement(FtzReq), IsHalfTy(IsHalfTy) {}

    // Cast operations don't have anything to do with FTZ, so we skip that
    // argument.
    SimplifyAction(Instruction::CastOps CastOp) : CastOp(CastOp) {}

    SimplifyAction(Instruction::BinaryOps BinaryOp, FtzRequirementTy FtzReq)
        : BinaryOp(BinaryOp), FtzRequirement(FtzReq) {}

    SimplifyAction(SpecialCase Special, FtzRequirementTy FtzReq)
        : Special(Special), FtzRequirement(FtzReq) {}
  };

  // Try to generate a SimplifyAction describing how to replace our
  // IntrinsicInst with target-generic LLVM IR.
  const SimplifyAction Action = [II]() -> SimplifyAction {
    switch (II->getIntrinsicID()) {
    // NVVM intrinsics that map directly to LLVM intrinsics.
    case Intrinsic::nvvm_ceil_d:
      return {Intrinsic::ceil, FTZ_Any};
    case Intrinsic::nvvm_ceil_f:
      return {Intrinsic::ceil, FTZ_MustBeOff};
    case Intrinsic::nvvm_ceil_ftz_f:
      return {Intrinsic::ceil, FTZ_MustBeOn};
    case Intrinsic::nvvm_fabs_d:
      return {Intrinsic::fabs, FTZ_Any};
    case Intrinsic::nvvm_floor_d:
      return {Intrinsic::floor, FTZ_Any};
    case Intrinsic::nvvm_floor_f:
      return {Intrinsic::floor, FTZ_MustBeOff};
    case Intrinsic::nvvm_floor_ftz_f:
      return {Intrinsic::floor, FTZ_MustBeOn};
    case Intrinsic::nvvm_fma_rn_d:
      return {Intrinsic::fma, FTZ_Any};
    case Intrinsic::nvvm_fma_rn_f:
      return {Intrinsic::fma, FTZ_MustBeOff};
    case Intrinsic::nvvm_fma_rn_ftz_f:
      return {Intrinsic::fma, FTZ_MustBeOn};
    case Intrinsic::nvvm_fma_rn_f16:
      return {Intrinsic::fma, FTZ_MustBeOff, true};
    case Intrinsic::nvvm_fma_rn_ftz_f16:
      return {Intrinsic::fma, FTZ_MustBeOn, true};
    case Intrinsic::nvvm_fma_rn_f16x2:
      return {Intrinsic::fma, FTZ_MustBeOff, true};
    case Intrinsic::nvvm_fma_rn_ftz_f16x2:
      return {Intrinsic::fma, FTZ_MustBeOn, true};
    case Intrinsic::nvvm_fma_rn_bf16:
      return {Intrinsic::fma, FTZ_MustBeOff, true};
    case Intrinsic::nvvm_fma_rn_ftz_bf16:
      return {Intrinsic::fma, FTZ_MustBeOn, true};
    case Intrinsic::nvvm_fma_rn_bf16x2:
      return {Intrinsic::fma, FTZ_MustBeOff, true};
    case Intrinsic::nvvm_fma_rn_ftz_bf16x2:
      return {Intrinsic::fma, FTZ_MustBeOn, true};
    case Intrinsic::nvvm_fmax_d:
      return {Intrinsic::maxnum, FTZ_Any};
    case Intrinsic::nvvm_fmax_f:
      return {Intrinsic::maxnum, FTZ_MustBeOff};
    case Intrinsic::nvvm_fmax_ftz_f:
      return {Intrinsic::maxnum, FTZ_MustBeOn};
    case Intrinsic::nvvm_fmax_nan_f:
      return {Intrinsic::maximum, FTZ_MustBeOff};
    case Intrinsic::nvvm_fmax_ftz_nan_f:
      return {Intrinsic::maximum, FTZ_MustBeOn};
    case Intrinsic::nvvm_fmax_f16:
      return {Intrinsic::maxnum, FTZ_MustBeOff, true};
    case Intrinsic::nvvm_fmax_ftz_f16:
      return {Intrinsic::maxnum, FTZ_MustBeOn, true};
    case Intrinsic::nvvm_fmax_f16x2:
      return {Intrinsic::maxnum, FTZ_MustBeOff, true};
    case Intrinsic::nvvm_fmax_ftz_f16x2:
      return {Intrinsic::maxnum, FTZ_MustBeOn, true};
    case Intrinsic::nvvm_fmax_nan_f16:
      return {Intrinsic::maximum, FTZ_MustBeOff, true};
    case Intrinsic::nvvm_fmax_ftz_nan_f16:
      return {Intrinsic::maximum, FTZ_MustBeOn, true};
    case Intrinsic::nvvm_fmax_nan_f16x2:
      return {Intrinsic::maximum, FTZ_MustBeOff, true};
    case Intrinsic::nvvm_fmax_ftz_nan_f16x2:
      return {Intrinsic::maximum, FTZ_MustBeOn, true};
    case Intrinsic::nvvm_fmin_d:
      return {Intrinsic::minnum, FTZ_Any};
    case Intrinsic::nvvm_fmin_f:
      return {Intrinsic::minnum, FTZ_MustBeOff};
    case Intrinsic::nvvm_fmin_ftz_f:
      return {Intrinsic::minnum, FTZ_MustBeOn};
    case Intrinsic::nvvm_fmin_nan_f:
      return {Intrinsic::minimum, FTZ_MustBeOff};
    case Intrinsic::nvvm_fmin_ftz_nan_f:
      return {Intrinsic::minimum, FTZ_MustBeOn};
    case Intrinsic::nvvm_fmin_f16:
      return {Intrinsic::minnum, FTZ_MustBeOff, true};
    case Intrinsic::nvvm_fmin_ftz_f16:
      return {Intrinsic::minnum, FTZ_MustBeOn, true};
    case Intrinsic::nvvm_fmin_f16x2:
      return {Intrinsic::minnum, FTZ_MustBeOff, true};
    case Intrinsic::nvvm_fmin_ftz_f16x2:
      return {Intrinsic::minnum, FTZ_MustBeOn, true};
    case Intrinsic::nvvm_fmin_nan_f16:
      return {Intrinsic::minimum, FTZ_MustBeOff, true};
    case Intrinsic::nvvm_fmin_ftz_nan_f16:
      return {Intrinsic::minimum, FTZ_MustBeOn, true};
    case Intrinsic::nvvm_fmin_nan_f16x2:
      return {Intrinsic::minimum, FTZ_MustBeOff, true};
    case Intrinsic::nvvm_fmin_ftz_nan_f16x2:
      return {Intrinsic::minimum, FTZ_MustBeOn, true};
    case Intrinsic::nvvm_sqrt_rn_d:
      return {Intrinsic::sqrt, FTZ_Any};
    case Intrinsic::nvvm_sqrt_f:
      // nvvm_sqrt_f is a special case.  For most intrinsics, foo_ftz_f is the
      // ftz version, and foo_f is the non-ftz version.  But nvvm_sqrt_f adopts
      // the ftz-ness of the surrounding code.  sqrt_rn_f and sqrt_rn_ftz_f are
      // the versions with explicit ftz-ness.
      return {Intrinsic::sqrt, FTZ_Any};
    case Intrinsic::nvvm_trunc_d:
      return {Intrinsic::trunc, FTZ_Any};
    case Intrinsic::nvvm_trunc_f:
      return {Intrinsic::trunc, FTZ_MustBeOff};
    case Intrinsic::nvvm_trunc_ftz_f:
      return {Intrinsic::trunc, FTZ_MustBeOn};

    // NVVM intrinsics that map to LLVM cast operations.
    //
    // Note that llvm's target-generic conversion operators correspond to the rz
    // (round to zero) versions of the nvvm conversion intrinsics, even though
    // most of the other nvvm ops handled here use rn (round to nearest even).
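    //
    // For example, @llvm.nvvm.d2i.rz(%x) becomes "fptosi double %x to i32".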
    case Intrinsic::nvvm_d2i_rz:
    case Intrinsic::nvvm_f2i_rz:
    case Intrinsic::nvvm_d2ll_rz:
    case Intrinsic::nvvm_f2ll_rz:
      return {Instruction::FPToSI};
    case Intrinsic::nvvm_d2ui_rz:
    case Intrinsic::nvvm_f2ui_rz:
    case Intrinsic::nvvm_d2ull_rz:
    case Intrinsic::nvvm_f2ull_rz:
      return {Instruction::FPToUI};
    // Integer to floating-point uses RN rounding, not RZ
    case Intrinsic::nvvm_i2d_rn:
    case Intrinsic::nvvm_i2f_rn:
    case Intrinsic::nvvm_ll2d_rn:
    case Intrinsic::nvvm_ll2f_rn:
      return {Instruction::SIToFP};
    case Intrinsic::nvvm_ui2d_rn:
    case Intrinsic::nvvm_ui2f_rn:
    case Intrinsic::nvvm_ull2d_rn:
    case Intrinsic::nvvm_ull2f_rn:
      return {Instruction::UIToFP};

    // NVVM intrinsics that map to LLVM binary ops.
    case Intrinsic::nvvm_div_rn_d:
      return {Instruction::FDiv, FTZ_Any};

    // The remaining cases are NVVM intrinsics that map to LLVM idioms but need
    // special handling.
    //
    // We seem to be missing intrinsics for rcp.approx.{ftz.}f32, which is just
    // as well.
    case Intrinsic::nvvm_rcp_rn_d:
      return {SPC_Reciprocal, FTZ_Any};

    case Intrinsic::nvvm_fshl_clamp:
    case Intrinsic::nvvm_fshr_clamp:
      return {SPC_FunnelShiftClamp, FTZ_Any};

      // We do not currently simplify intrinsics that give an approximate
      // answer. These include:
      //
      //   - nvvm_cos_approx_{f,ftz_f}
      //   - nvvm_ex2_approx_{d,f,ftz_f}
      //   - nvvm_lg2_approx_{d,f,ftz_f}
      //   - nvvm_sin_approx_{f,ftz_f}
      //   - nvvm_sqrt_approx_{f,ftz_f}
      //   - nvvm_rsqrt_approx_{d,f,ftz_f}
      //   - nvvm_div_approx_{ftz_d,ftz_f,f}
      //   - nvvm_rcp_approx_ftz_d
      //
      // Ideally we'd encode them as e.g. "fast call @llvm.cos", where "fast"
      // means that fastmath is enabled in the intrinsic.  Unfortunately only
      // binary operators (currently) have a fastmath bit in SelectionDAG, so
      // this information gets lost and we can't select on it.
      //
      // TODO: div and rcp are lowered to a binary op, so in theory we could
      // lower them to "fast fdiv".

    default:
      return {};
    }
  }();

  // If Action.FtzRequirement is not satisfied by the function's denormal-mode
  // state, we can bail out now.  (Notice that in the case that IID is not an
  // NVVM intrinsic, we don't have to look up any attributes, as
  // FtzRequirement will be FTZ_Any.)
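  //
  // For example, nvvm_floor_ftz_f is only rewritten to llvm.floor.f32 when
  // the function's f32 denormal output mode is "preserve-sign", i.e. when ftz
  // is in effect.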
  if (Action.FtzRequirement != FTZ_Any) {
    // FIXME: Broken for f64
    DenormalMode Mode = II->getFunction()->getDenormalMode(
        Action.IsHalfTy ? APFloat::IEEEhalf() : APFloat::IEEEsingle());
    bool FtzEnabled = Mode.Output == DenormalMode::PreserveSign;

    if (FtzEnabled != (Action.FtzRequirement == FTZ_MustBeOn))
      return nullptr;
  }

  // Simplify to target-generic intrinsic.
  if (Action.IID) {
    SmallVector<Value *, 4> Args(II->args());
    // All the target-generic intrinsics currently of interest to us have one
    // type argument, equal to that of the nvvm intrinsic's argument.
    Type *Tys[] = {II->getArgOperand(0)->getType()};
    return CallInst::Create(
        Intrinsic::getOrInsertDeclaration(II->getModule(), *Action.IID, Tys),
        Args);
  }

  // Simplify to target-generic binary op.
  if (Action.BinaryOp)
    return BinaryOperator::Create(*Action.BinaryOp, II->getArgOperand(0),
                                  II->getArgOperand(1), II->getName());

  // Simplify to target-generic cast op.
  if (Action.CastOp)
    return CastInst::Create(*Action.CastOp, II->getArgOperand(0), II->getType(),
                            II->getName());

  // All that's left are the special cases.
  if (!Action.Special)
    return nullptr;

  switch (*Action.Special) {
  case SPC_Reciprocal:
    // Simplify reciprocal.
    return BinaryOperator::Create(
        Instruction::FDiv, ConstantFP::get(II->getArgOperand(0)->getType(), 1),
        II->getArgOperand(0), II->getName());

  case SPC_FunnelShiftClamp: {
    // Canonicalize a clamping funnel shift to the generic llvm funnel shift
    // when possible, as this is easier for llvm to optimize further.
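    //
    // For example, with i32 operands, a constant shift amount >= 32 clamps to
    // 32, so nvvm.fshl.clamp simply yields its second operand; smaller
    // constant shifts map directly onto @llvm.fshl.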
    if (const auto *ShiftConst = dyn_cast<ConstantInt>(II->getArgOperand(2))) {
      const bool IsLeft = II->getIntrinsicID() == Intrinsic::nvvm_fshl_clamp;
      if (ShiftConst->getZExtValue() >= II->getType()->getIntegerBitWidth())
        return IC.replaceInstUsesWith(*II, II->getArgOperand(IsLeft ? 1 : 0));

      const unsigned FshIID = IsLeft ? Intrinsic::fshl : Intrinsic::fshr;
      return CallInst::Create(Intrinsic::getOrInsertDeclaration(
                                  II->getModule(), FshIID, II->getType()),
                              SmallVector<Value *, 3>(II->args()));
    }
    return nullptr;
  }
  }
  llvm_unreachable("All SpecialCase enumerators should be handled in switch.");
}

// Returns true/false when we know the answer, nullopt otherwise.
static std::optional<bool> evaluateIsSpace(Intrinsic::ID IID, unsigned AS) {
  if (AS == NVPTXAS::ADDRESS_SPACE_GENERIC ||
      AS == NVPTXAS::ADDRESS_SPACE_PARAM)
    return std::nullopt; // Got to check at run-time.
  switch (IID) {
  case Intrinsic::nvvm_isspacep_global:
    return AS == NVPTXAS::ADDRESS_SPACE_GLOBAL;
  case Intrinsic::nvvm_isspacep_local:
    return AS == NVPTXAS::ADDRESS_SPACE_LOCAL;
  case Intrinsic::nvvm_isspacep_shared:
    return AS == NVPTXAS::ADDRESS_SPACE_SHARED;
  case Intrinsic::nvvm_isspacep_shared_cluster:
    // We can't tell shared from shared_cluster at compile time from AS alone,
    // but it can't be either if AS is not shared.
    return AS == NVPTXAS::ADDRESS_SPACE_SHARED ? std::nullopt
                                               : std::optional{false};
  case Intrinsic::nvvm_isspacep_const:
    return AS == NVPTXAS::ADDRESS_SPACE_CONST;
  default:
    llvm_unreachable("Unexpected intrinsic");
  }
}

// Returns an instruction pointer (may be nullptr if we do not know the answer).
// Returns nullopt if `II` is not one of the `isspacep` intrinsics.
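//
// For example, isspacep.global applied to an addrspacecast from the global
// address space to generic folds to "true".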
//
// TODO: If InferAddressSpaces were run early enough in the pipeline this could
// be removed in favor of the constant folding that occurs there through
// rewriteIntrinsicWithAddressSpace
static std::optional<Instruction *>
handleSpaceCheckIntrinsics(InstCombiner &IC, IntrinsicInst &II) {

  switch (auto IID = II.getIntrinsicID()) {
  case Intrinsic::nvvm_isspacep_global:
  case Intrinsic::nvvm_isspacep_local:
  case Intrinsic::nvvm_isspacep_shared:
  case Intrinsic::nvvm_isspacep_shared_cluster:
  case Intrinsic::nvvm_isspacep_const: {
    Value *Op0 = II.getArgOperand(0);
    unsigned AS = Op0->getType()->getPointerAddressSpace();
    // Peek through ASC to generic AS.
    // TODO: we could dig deeper through both ASCs and GEPs.
    if (AS == NVPTXAS::ADDRESS_SPACE_GENERIC)
      if (auto *ASCO = dyn_cast<AddrSpaceCastOperator>(Op0))
        AS = ASCO->getOperand(0)->getType()->getPointerAddressSpace();

    if (std::optional<bool> Answer = evaluateIsSpace(IID, AS))
      return IC.replaceInstUsesWith(II,
                                    ConstantInt::get(II.getType(), *Answer));
    return nullptr; // Don't know the answer, got to check at run time.
  }
  default:
    return std::nullopt;
  }
}

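// InstCombine hook: first try to fold the isspacep intrinsics, then try to
// rewrite other NVVM intrinsics into target-generic IR.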
std::optional<Instruction *>
NVPTXTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
  if (std::optional<Instruction *> I = handleSpaceCheckIntrinsics(IC, II))
    return *I;
  if (Instruction *I = convertNvvmIntrinsicToLlvm(IC, &II))
    return I;

  return std::nullopt;
}

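// Cost model for arithmetic instructions. The only NVPTX-specific adjustment
// is for simple integer ops on i64, which SASS emulates with two i32 halves.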
InstructionCost NVPTXTTIImpl::getArithmeticInstrCost(
    unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
    TTI::OperandValueInfo Op1Info, TTI::OperandValueInfo Op2Info,
    ArrayRef<const Value *> Args,
    const Instruction *CxtI) {
  // Legalize the type.
  std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Ty);

  int ISD = TLI->InstructionOpcodeToISD(Opcode);

  switch (ISD) {
  default:
    return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
                                         Op2Info);
  case ISD::ADD:
  case ISD::MUL:
  case ISD::XOR:
  case ISD::OR:
  case ISD::AND:
    // The machine code (SASS) simulates an i64 with two i32. Therefore, we
    // estimate that arithmetic operations on i64 are twice as expensive as
    // those on types that can fit into one machine register.
    if (LT.second.SimpleTy == MVT::i64)
      return 2 * LT.first;
    // Delegate other cases to the basic TTI.
    return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
                                         Op2Info);
  }
}

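// Unrolling preferences: start from the common TTI defaults, then enable
// partial and runtime unrolling with a reduced threshold (see below).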
void NVPTXTTIImpl::getUnrollingPreferences(Loop *L, ScalarEvolution &SE,
                                           TTI::UnrollingPreferences &UP,
                                           OptimizationRemarkEmitter *ORE) {
  BaseT::getUnrollingPreferences(L, SE, UP, ORE);

  // Enable partial unrolling and runtime unrolling, but reduce the threshold.
  // This partially unrolls small loops, which are often unrolled by the
  // PTX-to-SASS compiler anyway, and unrolling them earlier can be beneficial.
  UP.Partial = UP.Runtime = true;
  UP.PartialThreshold = UP.Threshold / 4;
}

void NVPTXTTIImpl::getPeelingPreferences(Loop *L, ScalarEvolution &SE,
                                         TTI::PeelingPreferences &PP) {
  BaseT::getPeelingPreferences(L, SE, PP);
}

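// The isspacep intrinsics take a flat (generic) pointer as operand 0; report
// that operand so InferAddressSpaces can attempt to rewrite it.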
bool NVPTXTTIImpl::collectFlatAddressOperands(SmallVectorImpl<int> &OpIndexes,
                                              Intrinsic::ID IID) const {
  switch (IID) {
  case Intrinsic::nvvm_isspacep_const:
  case Intrinsic::nvvm_isspacep_global:
  case Intrinsic::nvvm_isspacep_local:
  case Intrinsic::nvvm_isspacep_shared:
  case Intrinsic::nvvm_isspacep_shared_cluster: {
    OpIndexes.push_back(0);
    return true;
  }
  }
  return false;
}

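// Once InferAddressSpaces has rewritten the pointer operand of an isspacep
// intrinsic to a concrete address space, try to fold the check to a constant.
// Returns nullptr if the answer still depends on the pointer's runtime value.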
Value *NVPTXTTIImpl::rewriteIntrinsicWithAddressSpace(IntrinsicInst *II,
                                                      Value *OldV,
                                                      Value *NewV) const {
  const Intrinsic::ID IID = II->getIntrinsicID();
  switch (IID) {
  case Intrinsic::nvvm_isspacep_const:
  case Intrinsic::nvvm_isspacep_global:
  case Intrinsic::nvvm_isspacep_local:
  case Intrinsic::nvvm_isspacep_shared:
  case Intrinsic::nvvm_isspacep_shared_cluster: {
    const unsigned NewAS = NewV->getType()->getPointerAddressSpace();
    if (const auto R = evaluateIsSpace(IID, NewAS))
      return ConstantInt::get(II->getType(), *R);
    return nullptr;
  }
  }
  return nullptr;
}

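// Report the launch-bound annotations (maxclusterrank and maxntid{x,y,z})
// attached to F, if any.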
void NVPTXTTIImpl::collectKernelLaunchBounds(
    const Function &F,
    SmallVectorImpl<std::pair<StringRef, int64_t>> &LB) const {
  std::optional<unsigned> Val;
  if ((Val = getMaxClusterRank(F)))
    LB.push_back({"maxclusterrank", *Val});
  if ((Val = getMaxNTIDx(F)))
    LB.push_back({"maxntidx", *Val});
  if ((Val = getMaxNTIDy(F)))
    LB.push_back({"maxntidy", *Val});
  if ((Val = getMaxNTIDz(F)))
    LB.push_back({"maxntidz", *Val});
}
580