//===- RISCVTargetTransformInfo.h - RISC-V specific TTI ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
/// \file
/// This file defines a TargetTransformInfo::Concept conforming object specific
/// to the RISC-V target machine. It uses the target's detailed information to
/// provide more precise answers to certain TTI queries, while letting the
/// target independent and default TTI implementations handle the rest.
///
//===----------------------------------------------------------------------===//

#ifndef LLVM_LIB_TARGET_RISCV_RISCVTARGETTRANSFORMINFO_H
#define LLVM_LIB_TARGET_RISCV_RISCVTARGETTRANSFORMINFO_H

#include "RISCVSubtarget.h"
#include "RISCVTargetMachine.h"
#include "llvm/Analysis/IVDescriptors.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/BasicTTIImpl.h"
#include "llvm/IR/Function.h"
#include <optional>

namespace llvm {

class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
  using BaseT = BasicTTIImplBase<RISCVTTIImpl>;
  using TTI = TargetTransformInfo;

  friend BaseT;

  const RISCVSubtarget *ST;
  const RISCVTargetLowering *TLI;

  const RISCVSubtarget *getST() const { return ST; }
  const RISCVTargetLowering *getTLI() const { return TLI; }

  /// This function returns an estimate of VL to be used in VL-based terms
  /// of the cost model.  For fixed-length vectors, this is simply the
  /// vector length.  For scalable vectors, we return results consistent
  /// with getVScaleForTuning, under the assumption that clients are also
  /// using it when comparing costs between the scalar and vector
  /// representations.  This does unfortunately mean that we can both
  /// undershoot and overshoot the true cost significantly if
  /// getVScaleForTuning is wildly off for the actual target hardware.
  unsigned getEstimatedVLFor(VectorType *Ty);
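  // Illustrative sketch only (the real logic lives in the .cpp file): for a
  // scalable type the estimate is roughly
  //   getVScaleForTuning() * Ty->getElementCount().getKnownMinValue()
  // while for a fixed-length type it is simply the number of elements.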

  InstructionCost getRISCVInstructionCost(ArrayRef<unsigned> OpCodes, MVT VT,
                                          TTI::TargetCostKind CostKind);

  /// Return the cost of accessing a constant pool entry of the specified
  /// type.
  InstructionCost getConstantPoolLoadCost(Type *Ty,
                                          TTI::TargetCostKind CostKind);

public:
  explicit RISCVTTIImpl(const RISCVTargetMachine *TM, const Function &F)
      : BaseT(TM, F.getDataLayout()), ST(TM->getSubtargetImpl(F)),
        TLI(ST->getTargetLowering()) {}

  /// Return the cost of materializing an immediate for a value operand of
  /// a store instruction.
  InstructionCost getStoreImmCost(Type *VecTy, TTI::OperandValueInfo OpInfo,
                                  TTI::TargetCostKind CostKind);

  InstructionCost getIntImmCost(const APInt &Imm, Type *Ty,
                                TTI::TargetCostKind CostKind);
  InstructionCost getIntImmCostInst(unsigned Opcode, unsigned Idx,
                                    const APInt &Imm, Type *Ty,
                                    TTI::TargetCostKind CostKind,
                                    Instruction *Inst = nullptr);
  InstructionCost getIntImmCostIntrin(Intrinsic::ID IID, unsigned Idx,
                                      const APInt &Imm, Type *Ty,
                                      TTI::TargetCostKind CostKind);

  /// \name EVL Support for predicated vectorization.
  /// Whether the target supports the %evl parameter of VP intrinsics
  /// efficiently in hardware, for the given opcode and type/alignment. (See
  /// LLVM Language Reference - "Vector Predication Intrinsics",
  /// https://llvm.org/docs/LangRef.html#vector-predication-intrinsics and
  /// "IR-level VP intrinsics",
  /// https://llvm.org/docs/Proposals/VectorPredication.html#ir-level-vp-intrinsics).
  /// \param Opcode the opcode of the instruction checked for predicated
  /// version support.
  /// \param DataType the type of the instruction with the \p Opcode checked
  /// for predication support.
  /// \param Alignment the alignment of the memory access operation checked
  /// for predicated version support.
  bool hasActiveVectorLength(unsigned Opcode, Type *DataType,
                             Align Alignment) const;
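  // For illustration, a VP intrinsic carrying the %evl parameter this hook
  // reports on (example IR, not specific to this header):
  //   %r = call <vscale x 2 x i32> @llvm.vp.add.nxv2i32(
  //            <vscale x 2 x i32> %a, <vscale x 2 x i32> %b,
  //            <vscale x 2 x i1> %m, i32 %evl)
  // On subtargets with V instructions, %evl can map directly onto the
  // hardware vector length (vsetvli).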

  TargetTransformInfo::PopcntSupportKind getPopcntSupport(unsigned TyWidth);

  bool shouldExpandReduction(const IntrinsicInst *II) const;
  bool supportsScalableVectors() const { return ST->hasVInstructions(); }
  bool enableOrderedReductions() const { return true; }
  bool enableScalableVectorization() const { return ST->hasVInstructions(); }
  TailFoldingStyle
  getPreferredTailFoldingStyle(bool IVUpdateMayOverflow) const {
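    // With V instructions, fold the tail by masking data operations; the
    // fallback style merely avoids the llvm.get.active.lane.mask intrinsic.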
    return ST->hasVInstructions() ? TailFoldingStyle::Data
                                  : TailFoldingStyle::DataWithoutLaneMask;
  }
  std::optional<unsigned> getMaxVScale() const;
  std::optional<unsigned> getVScaleForTuning() const;

  TypeSize getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const;

  unsigned getRegUsageForType(Type *Ty);

  unsigned getMaximumVF(unsigned ElemWidth, unsigned Opcode) const;

  bool preferEpilogueVectorization() const {
    // Epilogue vectorization is usually unprofitable - tail folding or
    // a smaller VF would have been better.  This is a blunt hammer - we
    // should re-examine this once vectorization is better tuned.
    return false;
  }

  InstructionCost getMaskedMemoryOpCost(unsigned Opcode, Type *Src,
                                        Align Alignment, unsigned AddressSpace,
                                        TTI::TargetCostKind CostKind);

  InstructionCost getPointersChainCost(ArrayRef<const Value *> Ptrs,
                                       const Value *Base,
                                       const TTI::PointersChainInfo &Info,
                                       Type *AccessTy,
                                       TTI::TargetCostKind CostKind);

  void getUnrollingPreferences(Loop *L, ScalarEvolution &SE,
                               TTI::UnrollingPreferences &UP,
                               OptimizationRemarkEmitter *ORE);

  void getPeelingPreferences(Loop *L, ScalarEvolution &SE,
                             TTI::PeelingPreferences &PP);

  unsigned getMinVectorRegisterBitWidth() const {
    return ST->useRVVForFixedLengthVectors() ? 16 : 0;
  }

  InstructionCost getShuffleCost(TTI::ShuffleKind Kind, VectorType *Tp,
                                 ArrayRef<int> Mask,
                                 TTI::TargetCostKind CostKind, int Index,
                                 VectorType *SubTp,
                                 ArrayRef<const Value *> Args = {},
                                 const Instruction *CxtI = nullptr);

  InstructionCost getScalarizationOverhead(VectorType *Ty,
                                           const APInt &DemandedElts,
                                           bool Insert, bool Extract,
                                           TTI::TargetCostKind CostKind,
                                           ArrayRef<Value *> VL = {});

  InstructionCost getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
                                        TTI::TargetCostKind CostKind);

  InstructionCost getInterleavedMemoryOpCost(
      unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
      Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,
      bool UseMaskForCond = false, bool UseMaskForGaps = false);

  InstructionCost getGatherScatterOpCost(unsigned Opcode, Type *DataTy,
                                         const Value *Ptr, bool VariableMask,
                                         Align Alignment,
                                         TTI::TargetCostKind CostKind,
                                         const Instruction *I);

  InstructionCost getStridedMemoryOpCost(unsigned Opcode, Type *DataTy,
                                         const Value *Ptr, bool VariableMask,
                                         Align Alignment,
                                         TTI::TargetCostKind CostKind,
                                         const Instruction *I);

  InstructionCost getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys);

  InstructionCost getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
                                   TTI::CastContextHint CCH,
                                   TTI::TargetCostKind CostKind,
                                   const Instruction *I = nullptr);

  InstructionCost getMinMaxReductionCost(Intrinsic::ID IID, VectorType *Ty,
                                         FastMathFlags FMF,
                                         TTI::TargetCostKind CostKind);

  InstructionCost getArithmeticReductionCost(unsigned Opcode, VectorType *Ty,
                                             std::optional<FastMathFlags> FMF,
                                             TTI::TargetCostKind CostKind);

  InstructionCost getExtendedReductionCost(unsigned Opcode, bool IsUnsigned,
                                           Type *ResTy, VectorType *ValTy,
                                           FastMathFlags FMF,
                                           TTI::TargetCostKind CostKind);

  InstructionCost
  getMemoryOpCost(unsigned Opcode, Type *Src, MaybeAlign Alignment,
                  unsigned AddressSpace, TTI::TargetCostKind CostKind,
                  TTI::OperandValueInfo OpdInfo = {TTI::OK_AnyValue, TTI::OP_None},
                  const Instruction *I = nullptr);

  InstructionCost getCmpSelInstrCost(
      unsigned Opcode, Type *ValTy, Type *CondTy, CmpInst::Predicate VecPred,
      TTI::TargetCostKind CostKind,
      TTI::OperandValueInfo Op1Info = {TTI::OK_AnyValue, TTI::OP_None},
      TTI::OperandValueInfo Op2Info = {TTI::OK_AnyValue, TTI::OP_None},
      const Instruction *I = nullptr);

  InstructionCost getCFInstrCost(unsigned Opcode, TTI::TargetCostKind CostKind,
                                 const Instruction *I = nullptr);

  using BaseT::getVectorInstrCost;
  InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val,
                                     TTI::TargetCostKind CostKind,
                                     unsigned Index, Value *Op0, Value *Op1);

  InstructionCost getArithmeticInstrCost(
      unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
      TTI::OperandValueInfo Op1Info = {TTI::OK_AnyValue, TTI::OP_None},
      TTI::OperandValueInfo Op2Info = {TTI::OK_AnyValue, TTI::OP_None},
      ArrayRef<const Value *> Args = {}, const Instruction *CxtI = nullptr);

  bool isElementTypeLegalForScalableVector(Type *Ty) const {
    return TLI->isLegalElementTypeForRVV(TLI->getValueType(DL, Ty));
  }

  bool isLegalMaskedLoadStore(Type *DataType, Align Alignment) {
    if (!ST->hasVInstructions())
      return false;

    EVT DataTypeVT = TLI->getValueType(DL, DataType);

    // Only support fixed vectors if we know the minimum vector size.
    if (DataTypeVT.isFixedLengthVector() && !ST->useRVVForFixedLengthVectors())
      return false;

    EVT ElemType = DataTypeVT.getScalarType();
    if (!ST->enableUnalignedVectorMem() && Alignment < ElemType.getStoreSize())
      return false;

    return TLI->isLegalElementTypeForRVV(ElemType);
  }

  bool isLegalMaskedLoad(Type *DataType, Align Alignment) {
    return isLegalMaskedLoadStore(DataType, Alignment);
  }
  bool isLegalMaskedStore(Type *DataType, Align Alignment) {
    return isLegalMaskedLoadStore(DataType, Alignment);
  }
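
  // For illustration, the kind of operation these hooks legalize (example
  // IR, not part of this interface):
  //   %v = call <4 x i32> @llvm.masked.load.v4i32.p0(
  //            ptr %p, i32 4, <4 x i1> %mask, <4 x i32> %passthru)
  // With RVV this can lower to a unit-stride load under a v0 mask, subject
  // to the element-type and alignment checks above.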

  bool isLegalMaskedGatherScatter(Type *DataType, Align Alignment) {
    if (!ST->hasVInstructions())
      return false;

    EVT DataTypeVT = TLI->getValueType(DL, DataType);

    // Only support fixed vectors if we know the minimum vector size.
    if (DataTypeVT.isFixedLengthVector() && !ST->useRVVForFixedLengthVectors())
      return false;

    // We also need to check that the vector of addresses is valid.
    EVT PointerTypeVT = EVT(TLI->getPointerTy(DL));
    if (DataTypeVT.isScalableVector() &&
        !TLI->isLegalElementTypeForRVV(PointerTypeVT))
      return false;

    EVT ElemType = DataTypeVT.getScalarType();
    if (!ST->enableUnalignedVectorMem() && Alignment < ElemType.getStoreSize())
      return false;

    return TLI->isLegalElementTypeForRVV(ElemType);
  }

  bool isLegalMaskedGather(Type *DataType, Align Alignment) {
    return isLegalMaskedGatherScatter(DataType, Alignment);
  }
  bool isLegalMaskedScatter(Type *DataType, Align Alignment) {
    return isLegalMaskedGatherScatter(DataType, Alignment);
  }
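
  // For illustration (example IR): a gather such as
  //   %v = call <4 x i32> @llvm.masked.gather.v4i32.v4p0(
  //            <4 x ptr> %ptrs, i32 4, <4 x i1> %mask, <4 x i32> %passthru)
  // can map onto an indexed load when the checks above succeed; the address
  // vector must itself be legal for RVV, hence the pointer-type check.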

  bool forceScalarizeMaskedGather(VectorType *VTy, Align Alignment) {
    // Scalarize masked gather for RV64 if EEW=64 indices aren't supported.
    return ST->is64Bit() && !ST->hasVInstructionsI64();
  }

  bool forceScalarizeMaskedScatter(VectorType *VTy, Align Alignment) {
    // Scalarize masked scatter for RV64 if EEW=64 indices aren't supported.
    return ST->is64Bit() && !ST->hasVInstructionsI64();
  }

  bool isLegalStridedLoadStore(Type *DataType, Align Alignment) {
    EVT DataTypeVT = TLI->getValueType(DL, DataType);
    return TLI->isLegalStridedLoadStore(DataTypeVT, Alignment);
  }
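
  // For illustration (example IR): strided accesses are expressed with the
  // experimental VP intrinsics, e.g.
  //   %v = call <vscale x 2 x i32>
  //        @llvm.experimental.vp.strided.load.nxv2i32.p0.i64(
  //            ptr %p, i64 %stride, <vscale x 2 x i1> %m, i32 %evl)
  // which RVV can implement with its strided load/store forms when legal.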

  bool isLegalInterleavedAccessType(VectorType *VTy, unsigned Factor,
                                    Align Alignment, unsigned AddrSpace) {
    return TLI->isLegalInterleavedAccessType(VTy, Factor, Alignment, AddrSpace,
                                             DL);
  }

  bool isLegalMaskedExpandLoad(Type *DataType, Align Alignment);

  bool isLegalMaskedCompressStore(Type *DataTy, Align Alignment);

  bool isVScaleKnownToBeAPowerOfTwo() const {
    return TLI->isVScaleKnownToBeAPowerOfTwo();
  }

  /// \returns How the target needs this vector-predicated operation to be
  /// transformed.
  TargetTransformInfo::VPLegalization
  getVPLegalizationStrategy(const VPIntrinsic &PI) const {
    using VPLegalization = TargetTransformInfo::VPLegalization;
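    // Without V instructions every VP op needs conversion; vp.reduce.mul is
    // likewise unsupported except on i1 elements. In those cases, discard
    // the EVL parameter and convert the operation to an unpredicated one.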
    if (!ST->hasVInstructions() ||
        (PI.getIntrinsicID() == Intrinsic::vp_reduce_mul &&
         cast<VectorType>(PI.getArgOperand(1)->getType())
                 ->getElementType()
                 ->getIntegerBitWidth() != 1))
      return VPLegalization(VPLegalization::Discard, VPLegalization::Convert);
    return VPLegalization(VPLegalization::Legal, VPLegalization::Legal);
  }

  bool isLegalToVectorizeReduction(const RecurrenceDescriptor &RdxDesc,
                                   ElementCount VF) const {
    if (!VF.isScalable())
      return true;

    Type *Ty = RdxDesc.getRecurrenceType();
    if (!TLI->isLegalElementTypeForRVV(TLI->getValueType(DL, Ty)))
      return false;

    // We can't promote f16/bf16 fadd reductions, and scalable vectors can't
    // be expanded as a fallback.
    // TODO: Promote f16/bf16 fmin/fmax reductions
    if (Ty->isBFloatTy() || (Ty->isHalfTy() && !ST->hasVInstructionsF16()))
      return false;

    switch (RdxDesc.getRecurrenceKind()) {
    case RecurKind::Add:
    case RecurKind::FAdd:
    case RecurKind::And:
    case RecurKind::Or:
    case RecurKind::Xor:
    case RecurKind::SMin:
    case RecurKind::SMax:
    case RecurKind::UMin:
    case RecurKind::UMax:
    case RecurKind::FMin:
    case RecurKind::FMax:
    case RecurKind::FMulAdd:
    case RecurKind::IAnyOf:
    case RecurKind::FAnyOf:
      return true;
    default:
      return false;
    }
  }

  unsigned getMaxInterleaveFactor(ElementCount VF) {
    // Don't interleave if the loop has been vectorized with scalable vectors.
    if (VF.isScalable())
      return 1;
    // If the loop will not be vectorized, don't interleave it either;
    // leave the unrolling to the regular unroller.
    return VF.isScalar() ? 1 : ST->getMaxInterleaveFactor();
  }

  bool enableInterleavedAccessVectorization() { return true; }

  enum RISCVRegisterClass { GPRRC, FPRRC, VRRC };
  unsigned getNumberOfRegisters(unsigned ClassID) const {
    switch (ClassID) {
    case RISCVRegisterClass::GPRRC:
      // 31 = 32 GPR - x0 (zero register)
      // FIXME: Should we exclude fixed registers like SP, TP or GP?
      return 31;
    case RISCVRegisterClass::FPRRC:
      if (ST->hasStdExtF())
        return 32;
      return 0;
    case RISCVRegisterClass::VRRC:
      // Although there are 32 vector registers, v0 is special in that it is the
      // only register that can be used to hold a mask.
      // FIXME: Should we conservatively return 31 as the number of usable
      // vector registers?
      return ST->hasVInstructions() ? 32 : 0;
    }
    llvm_unreachable("unknown register class");
  }

  TTI::AddressingModeKind getPreferredAddressingMode(const Loop *L,
                                                     ScalarEvolution *SE) const;

  unsigned getRegisterClassForType(bool Vector, Type *Ty = nullptr) const {
    if (Vector)
      return RISCVRegisterClass::VRRC;
    if (!Ty)
      return RISCVRegisterClass::GPRRC;

    Type *ScalarTy = Ty->getScalarType();
    if ((ScalarTy->isHalfTy() && ST->hasStdExtZfhmin()) ||
        (ScalarTy->isFloatTy() && ST->hasStdExtF()) ||
        (ScalarTy->isDoubleTy() && ST->hasStdExtD())) {
      return RISCVRegisterClass::FPRRC;
    }

    return RISCVRegisterClass::GPRRC;
  }

  const char *getRegisterClassName(unsigned ClassID) const {
    switch (ClassID) {
    case RISCVRegisterClass::GPRRC:
      return "RISCV::GPRRC";
    case RISCVRegisterClass::FPRRC:
      return "RISCV::FPRRC";
    case RISCVRegisterClass::VRRC:
      return "RISCV::VRRC";
    }
    llvm_unreachable("unknown register class");
  }

  bool isLSRCostLess(const TargetTransformInfo::LSRCost &C1,
                     const TargetTransformInfo::LSRCost &C2);

  bool
  shouldConsiderAddressTypePromotion(const Instruction &I,
                                     bool &AllowPromotionWithoutCommonHeader);

  std::optional<unsigned> getMinPageSize() const { return 4096; }

  /// Return true if the (vector) instruction I will be lowered to an
  /// instruction with a scalar splat operand for the given Operand number.
  bool canSplatOperand(Instruction *I, int Operand) const;
  /// Return true if a vector instruction will lower to a target instruction
  /// able to splat the given operand.
  bool canSplatOperand(unsigned Opcode, int Operand) const;
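
  // For illustration: many RVV instructions have scalar-operand forms (e.g.
  // vadd.vx), so a splatted operand built as (example IR)
  //   %head = insertelement <vscale x 4 x i32> poison, i32 %x, i64 0
  //   %splat = shufflevector <vscale x 4 x i32> %head,
  //                          <vscale x 4 x i32> poison,
  //                          <vscale x 4 x i32> zeroinitializer
  // need not be materialized as a full vector. These hooks identify such
  // operands so the splat can be sunk next to its user for better lowering.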

  bool isProfitableToSinkOperands(Instruction *I,
                                  SmallVectorImpl<Use *> &Ops) const;

  TTI::MemCmpExpansionOptions enableMemCmpExpansion(bool OptSize,
                                                    bool IsZeroCmp) const;
};

} // end namespace llvm

#endif // LLVM_LIB_TARGET_RISCV_RISCVTARGETTRANSFORMINFO_H