//===-- NVPTXISelLowering.h - NVPTX DAG Lowering Interface ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the interfaces that NVPTX uses to lower LLVM code into a
// selection DAG.
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_LIB_TARGET_NVPTX_NVPTXISELLOWERING_H
#define LLVM_LIB_TARGET_NVPTX_NVPTXISELLOWERING_H

#include "NVPTX.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/TargetLowering.h"

namespace llvm {
namespace NVPTXISD {
enum NodeType : unsigned {
  // Start the numbering from where ISD NodeType finishes.
  FIRST_NUMBER = ISD::BUILTIN_OP_END,
  Wrapper,
  CALL,
  RET_GLUE,
  LOAD_PARAM,
  DeclareParam,
  DeclareScalarParam,
  DeclareRetParam,
  DeclareRet,
  DeclareScalarRet,
  PrintCall,
  PrintConvergentCall,
  PrintCallUni,
  PrintConvergentCallUni,
  CallArgBegin,
  CallArg,
  LastCallArg,
  CallArgEnd,
  CallVoid,
  CallVal,
  CallSymbol,
  Prototype,
  MoveParam,
  PseudoUseParam,
  RETURN,
  CallSeqBegin,
  CallSeqEnd,
  CallPrototype,
  ProxyReg,
  FSHL_CLAMP,
  FSHR_CLAMP,
  MUL_WIDE_SIGNED,
  MUL_WIDE_UNSIGNED,
  SETP_F16X2,
  SETP_BF16X2,
  BFE,
  BFI,
  PRMT,
  FCOPYSIGN,
  DYNAMIC_STACKALLOC,
  STACKRESTORE,
  STACKSAVE,
  BrxStart,
  BrxItem,
  BrxEnd,
  Dummy,

  FIRST_MEMORY_OPCODE,
  LoadV2 = FIRST_MEMORY_OPCODE,
  LoadV4,
  LDUV2, // LDU.v2
  LDUV4, // LDU.v4
  StoreV2,
  StoreV4,
  LoadParam,
  LoadParamV2,
  LoadParamV4,
  StoreParam,
  StoreParamV2,
  StoreParamV4,
  StoreParamS32, // sign-extend and store a sub-32-bit value; currently unused
  StoreParamU32, // zero-extend and store a sub-32-bit value; currently unused
  StoreRetval,
  StoreRetvalV2,
  StoreRetvalV4,
  LAST_MEMORY_OPCODE = StoreRetvalV4,
};
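
// The FIRST_MEMORY_OPCODE / LAST_MEMORY_OPCODE bounds above allow memory
// nodes to be classified with a plain range check, along these lines
// (an illustrative sketch, not a helper this header actually declares):
//
//   inline bool isMemoryOpcode(unsigned Opcode) {
//     return Opcode >= NVPTXISD::FIRST_MEMORY_OPCODE &&
//            Opcode <= NVPTXISD::LAST_MEMORY_OPCODE;
//   }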
}

class NVPTXSubtarget;

//===--------------------------------------------------------------------===//
// TargetLowering Implementation
//===--------------------------------------------------------------------===//
class NVPTXTargetLowering : public TargetLowering {
public:
  explicit NVPTXTargetLowering(const NVPTXTargetMachine &TM,
                               const NVPTXSubtarget &STI);
  SDValue LowerOperation(SDValue Op, SelectionDAG &DAG) const override;

  SDValue LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const;

  const char *getTargetNodeName(unsigned Opcode) const override;

  bool getTgtMemIntrinsic(IntrinsicInfo &Info, const CallInst &I,
                          MachineFunction &MF,
                          unsigned Intrinsic) const override;

  Align getFunctionArgumentAlignment(const Function *F, Type *Ty, unsigned Idx,
                                     const DataLayout &DL) const;

  /// getFunctionParamOptimizedAlign - since function arguments are passed via
  /// .param space, we may want to increase their alignment in a way that
  /// ensures that we can effectively vectorize their loads & stores. We can
  /// increase alignment only if the function has internal or private linkage,
  /// as for other linkage types callers may already rely on the default
  /// alignment. To allow using 128-bit vectorized loads/stores, this function
  /// ensures that alignment is 16 or greater.
  Align getFunctionParamOptimizedAlign(const Function *F, Type *ArgTy,
                                       const DataLayout &DL) const;
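
  // Usage sketch (caller side; names illustrative): the returned alignment
  // may exceed the plain ABI alignment to enable vectorized access.
  //
  //   Align ABIAlign = DL.getABITypeAlign(ArgTy);  // e.g. 4 for float data
  //   Align OptAlign = TLI.getFunctionParamOptimizedAlign(F, ArgTy, DL);
  //   // For internal/private functions, OptAlign is at least 16, which
  //   // permits 128-bit (.v4) loads/stores of the .param data.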

  /// Helper for computing alignment of a device function byval parameter.
  Align getFunctionByValParamAlign(const Function *F, Type *ArgTy,
                                   Align InitialAlign,
                                   const DataLayout &DL) const;

  // Helper for getting a function parameter name. The name is composed from
  // the parameter's index and the function name. A negative index corresponds
  // to the special parameter (an unsized array) used for passing variable
  // arguments.
  std::string getParamName(const Function *F, int Idx) const;
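
  // For example (illustrative; the exact spelling is defined by the
  // implementation, not this header):
  //
  //   std::string P = TLI.getParamName(F, 2);   // e.g. "foo_param_2" for @foo
  //   std::string V = TLI.getParamName(F, -1);  // name of the vararg parameter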

  /// isLegalAddressingMode - Return true if the addressing mode represented
  /// by AM is legal for this target, for a load/store of the specified type.
  /// Used to guide target-specific optimizations, like loop strength
  /// reduction (LoopStrengthReduce.cpp) and memory optimization for
  /// address mode (CodeGenPrepare.cpp).
  bool isLegalAddressingMode(const DataLayout &DL, const AddrMode &AM, Type *Ty,
                             unsigned AS,
                             Instruction *I = nullptr) const override;
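
  // A sketch of the queries this answers (values illustrative): PTX
  // addressing is limited to forms like [symbol], [reg], and [reg+imm], so
  // modes requiring a scaled index register are rejected.
  //
  //   TargetLowering::AddrMode AM;
  //   AM.HasBaseReg = true;
  //   AM.BaseOffs = 16;  // [reg+16] is representable in PTX
  //   bool OK = TLI.isLegalAddressingMode(DL, AM, Ty, AS);   // expect true
  //   AM.Scale = 2;      // scaled indexing has no PTX addressing form
  //   bool Bad = TLI.isLegalAddressingMode(DL, AM, Ty, AS);  // expect false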

  bool isTruncateFree(Type *SrcTy, Type *DstTy) const override {
    // Truncating 64-bit to 32-bit is free in SASS.
    if (!SrcTy->isIntegerTy() || !DstTy->isIntegerTy())
      return false;
    return SrcTy->getPrimitiveSizeInBits() == 64 &&
           DstTy->getPrimitiveSizeInBits() == 32;
  }

  EVT getSetCCResultType(const DataLayout &DL, LLVMContext &Ctx,
                         EVT VT) const override {
    if (VT.isVector())
      return EVT::getVectorVT(Ctx, MVT::i1, VT.getVectorNumElements());
    return MVT::i1;
  }

  ConstraintType getConstraintType(StringRef Constraint) const override;
  std::pair<unsigned, const TargetRegisterClass *>
  getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
                               StringRef Constraint, MVT VT) const override;

  SDValue LowerFormalArguments(SDValue Chain, CallingConv::ID CallConv,
                               bool isVarArg,
                               const SmallVectorImpl<ISD::InputArg> &Ins,
                               const SDLoc &dl, SelectionDAG &DAG,
                               SmallVectorImpl<SDValue> &InVals) const override;

  SDValue LowerCall(CallLoweringInfo &CLI,
                    SmallVectorImpl<SDValue> &InVals) const override;

  SDValue LowerDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const;
  SDValue LowerSTACKSAVE(SDValue Op, SelectionDAG &DAG) const;
  SDValue LowerSTACKRESTORE(SDValue Op, SelectionDAG &DAG) const;

  std::string
  getPrototype(const DataLayout &DL, Type *, const ArgListTy &,
               const SmallVectorImpl<ISD::OutputArg> &, MaybeAlign retAlignment,
               std::optional<std::pair<unsigned, const APInt &>> VAInfo,
               const CallBase &CB, unsigned UniqueCallSite) const;

  SDValue LowerReturn(SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
                      const SmallVectorImpl<ISD::OutputArg> &Outs,
                      const SmallVectorImpl<SDValue> &OutVals, const SDLoc &dl,
                      SelectionDAG &DAG) const override;

  void LowerAsmOperandForConstraint(SDValue Op, StringRef Constraint,
                                    std::vector<SDValue> &Ops,
                                    SelectionDAG &DAG) const override;

  const NVPTXTargetMachine *nvTM;

  // PTX always uses 32-bit shift amounts
  MVT getScalarShiftAmountTy(const DataLayout &, EVT) const override {
    return MVT::i32;
  }

  TargetLoweringBase::LegalizeTypeAction
  getPreferredVectorAction(MVT VT) const override;

  // Get the degree of precision we want from 32-bit floating point division
  // operations.
  //
  //  0 - Use ptx div.approx
  //  1 - Use ptx div.full (approximate, but less so than div.approx)
  //  2 - Use IEEE-compliant div instructions, if available.
  int getDivF32Level() const;
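
  // Illustrative mapping of the levels above onto PTX (assumed from PTX ISA
  // naming, not restated from this header; d = a / b):
  //
  //   0:  div.approx.f32 %d, %a, %b;  // fastest, least accurate
  //   1:  div.full.f32   %d, %a, %b;  // approximate, tighter error bound
  //   2:  div.rn.f32     %d, %a, %b;  // IEEE 754 round-to-nearest-even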

  // Get whether we should use a precise or approximate 32-bit floating point
  // sqrt instruction.
  bool usePrecSqrtF32() const;

  // Get whether we should use instructions that flush floating-point denormals
  // to sign-preserving zero.
  bool useF32FTZ(const MachineFunction &MF) const;

  SDValue getSqrtEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled,
                          int &ExtraSteps, bool &UseOneConst,
                          bool Reciprocal) const override;

  unsigned combineRepeatedFPDivisors() const override { return 2; }

  bool allowFMA(MachineFunction &MF, CodeGenOptLevel OptLevel) const;
  bool allowUnsafeFPMath(MachineFunction &MF) const;

  bool isFMAFasterThanFMulAndFAdd(const MachineFunction &MF,
                                  EVT) const override {
    return true;
  }

  // The default is the same as pointer type, but brx.idx only accepts i32
  MVT getJumpTableRegTy(const DataLayout &) const override { return MVT::i32; }

  unsigned getJumpTableEncoding() const override;

  bool enableAggressiveFMAFusion(EVT VT) const override { return true; }

  // The default is to transform llvm.ctlz(x, false) (where false indicates
  // that x == 0 is not undefined behavior) into a branch that checks whether
  // x is 0 and avoids calling ctlz in that case.  We have a dedicated ctlz
  // instruction, so we say that ctlz is cheap to speculate.
  bool isCheapToSpeculateCtlz(Type *Ty) const override { return true; }
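
  // In C terms (a sketch): without a native instruction, the guarded form
  // produced by the default transform would look like
  //
  //   unsigned CtlzX = X == 0 ? 32u : (unsigned)__builtin_clz(X);
  //
  // whereas PTX's clz instruction already defines clz(0) == 32, so no guard
  // is needed and speculation costs nothing extra.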

  AtomicExpansionKind shouldCastAtomicLoadInIR(LoadInst *LI) const override {
    return AtomicExpansionKind::None;
  }

  AtomicExpansionKind shouldCastAtomicStoreInIR(StoreInst *SI) const override {
    return AtomicExpansionKind::None;
  }

  AtomicExpansionKind
  shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const override;

  bool aggressivelyPreferBuildVectorSources(EVT VecVT) const override {
    // There's rarely any point in packing something into a vector type if we
    // already have the source data.
    return true;
  }

private:
  const NVPTXSubtarget &STI; // cache the subtarget here
  SDValue getParamSymbol(SelectionDAG &DAG, int idx, EVT) const;

  SDValue LowerBITCAST(SDValue Op, SelectionDAG &DAG) const;

  SDValue LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const;
  SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const;
  SDValue LowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
  SDValue LowerINSERT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
  SDValue LowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG) const;

  SDValue LowerFCOPYSIGN(SDValue Op, SelectionDAG &DAG) const;

  SDValue LowerFROUND(SDValue Op, SelectionDAG &DAG) const;
  SDValue LowerFROUND32(SDValue Op, SelectionDAG &DAG) const;
  SDValue LowerFROUND64(SDValue Op, SelectionDAG &DAG) const;

  SDValue PromoteBinOpIfF32FTZ(SDValue Op, SelectionDAG &DAG) const;

  SDValue LowerINT_TO_FP(SDValue Op, SelectionDAG &DAG) const;
  SDValue LowerFP_TO_INT(SDValue Op, SelectionDAG &DAG) const;

  SDValue LowerFP_ROUND(SDValue Op, SelectionDAG &DAG) const;
  SDValue LowerFP_EXTEND(SDValue Op, SelectionDAG &DAG) const;

  SDValue LowerLOAD(SDValue Op, SelectionDAG &DAG) const;
  SDValue LowerLOADi1(SDValue Op, SelectionDAG &DAG) const;

  SDValue LowerSTORE(SDValue Op, SelectionDAG &DAG) const;
  SDValue LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const;
  SDValue LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const;

  SDValue LowerShiftRightParts(SDValue Op, SelectionDAG &DAG) const;
  SDValue LowerShiftLeftParts(SDValue Op, SelectionDAG &DAG) const;

  SDValue LowerSelect(SDValue Op, SelectionDAG &DAG) const;

  SDValue LowerBR_JT(SDValue Op, SelectionDAG &DAG) const;

  SDValue LowerVAARG(SDValue Op, SelectionDAG &DAG) const;
  SDValue LowerVASTART(SDValue Op, SelectionDAG &DAG) const;

  SDValue LowerCopyToReg_128(SDValue Op, SelectionDAG &DAG) const;
  unsigned getNumRegisters(LLVMContext &Context, EVT VT,
                           std::optional<MVT> RegisterVT) const override;
  bool
  splitValueIntoRegisterParts(SelectionDAG &DAG, const SDLoc &DL, SDValue Val,
                              SDValue *Parts, unsigned NumParts, MVT PartVT,
                              std::optional<CallingConv::ID> CC) const override;

  void ReplaceNodeResults(SDNode *N, SmallVectorImpl<SDValue> &Results,
                          SelectionDAG &DAG) const override;
  SDValue PerformDAGCombine(SDNode *N, DAGCombinerInfo &DCI) const override;

  Align getArgumentAlignment(const CallBase *CB, Type *Ty, unsigned Idx,
                             const DataLayout &DL) const;
};

} // namespace llvm

#endif