xref: /llvm-project/llvm/include/llvm/CodeGen/SwitchLoweringUtils.h (revision 292ee93a87018bfef519ceff7de676e4792aa8d9)
1 //===- SwitchLoweringUtils.h - Switch Lowering ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef LLVM_CODEGEN_SWITCHLOWERINGUTILS_H
10 #define LLVM_CODEGEN_SWITCHLOWERINGUTILS_H
11 
#include "llvm/ADT/SmallVector.h"
#include "llvm/CodeGen/ISDOpcodes.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/Support/BranchProbability.h"
#include <optional>
#include <utility>
#include <vector>
18 
19 namespace llvm {
20 
21 class BlockFrequencyInfo;
22 class ConstantInt;
23 class FunctionLoweringInfo;
24 class MachineBasicBlock;
25 class ProfileSummaryInfo;
26 class TargetLowering;
27 class TargetMachine;
28 
29 namespace SwitchCG {
30 
/// How a CaseCluster is going to be lowered. CaseCluster's factories pair each
/// kind with one member of its MBB/index union.
enum CaseClusterKind {
  /// A single case, or a run of adjacent case labels that all share one
  /// destination.
  CC_Range = 0,
  /// Cases that will be lowered together through a jump table.
  CC_JumpTable = 1,
  /// Cases that will be lowered together through bit tests.
  CC_BitTests = 2
};
40 
41 /// A cluster of case labels.
42 struct CaseCluster {
43   CaseClusterKind Kind;
44   const ConstantInt *Low, *High;
45   union {
46     MachineBasicBlock *MBB;
47     unsigned JTCasesIndex;
48     unsigned BTCasesIndex;
49   };
50   BranchProbability Prob;
51 
52   static CaseCluster range(const ConstantInt *Low, const ConstantInt *High,
53                            MachineBasicBlock *MBB, BranchProbability Prob) {
54     CaseCluster C;
55     C.Kind = CC_Range;
56     C.Low = Low;
57     C.High = High;
58     C.MBB = MBB;
59     C.Prob = Prob;
60     return C;
61   }
62 
63   static CaseCluster jumpTable(const ConstantInt *Low, const ConstantInt *High,
64                                unsigned JTCasesIndex, BranchProbability Prob) {
65     CaseCluster C;
66     C.Kind = CC_JumpTable;
67     C.Low = Low;
68     C.High = High;
69     C.JTCasesIndex = JTCasesIndex;
70     C.Prob = Prob;
71     return C;
72   }
73 
74   static CaseCluster bitTests(const ConstantInt *Low, const ConstantInt *High,
75                               unsigned BTCasesIndex, BranchProbability Prob) {
76     CaseCluster C;
77     C.Kind = CC_BitTests;
78     C.Low = Low;
79     C.High = High;
80     C.BTCasesIndex = BTCasesIndex;
81     C.Prob = Prob;
82     return C;
83   }
84 };
85 
86 using CaseClusterVector = std::vector<CaseCluster>;
87 using CaseClusterIt = CaseClusterVector::iterator;
88 
89 /// Sort Clusters and merge adjacent cases.
90 void sortAndRangeify(CaseClusterVector &Clusters);
91 
92 struct CaseBits {
93   uint64_t Mask = 0;
94   MachineBasicBlock *BB = nullptr;
95   unsigned Bits = 0;
96   BranchProbability ExtraProb;
97 
98   CaseBits() = default;
99   CaseBits(uint64_t mask, MachineBasicBlock *bb, unsigned bits,
100            BranchProbability Prob)
101       : Mask(mask), BB(bb), Bits(bits), ExtraProb(Prob) {}
102 };
103 
104 using CaseBitsVector = std::vector<CaseBits>;
105 
106 /// This structure is used to communicate between SelectionDAGBuilder and
107 /// SDISel for the code generation of additional basic blocks needed by
108 /// multi-case switch statements.
109 struct CaseBlock {
110   // For the GISel interface.
111   struct PredInfoPair {
112     CmpInst::Predicate Pred;
113     // Set when no comparison should be emitted.
114     bool NoCmp;
115   };
116   union {
117     // The condition code to use for the case block's setcc node.
118     // Besides the integer condition codes, this can also be SETTRUE, in which
119     // case no comparison gets emitted.
120     ISD::CondCode CC;
121     struct PredInfoPair PredInfo;
122   };
123 
124   // The LHS/MHS/RHS of the comparison to emit.
125   // Emit by default LHS op RHS. MHS is used for range comparisons:
126   // If MHS is not null: (LHS <= MHS) and (MHS <= RHS).
127   const Value *CmpLHS, *CmpMHS, *CmpRHS;
128 
129   // The block to branch to if the setcc is true/false.
130   MachineBasicBlock *TrueBB, *FalseBB;
131 
132   // The block into which to emit the code for the setcc and branches.
133   MachineBasicBlock *ThisBB;
134 
135   /// The debug location of the instruction this CaseBlock was
136   /// produced from.
137   SDLoc DL;
138   DebugLoc DbgLoc;
139 
140   // Branch weights and predictability.
141   BranchProbability TrueProb, FalseProb;
142   bool IsUnpredictable;
143 
144   // Constructor for SelectionDAG.
145   CaseBlock(ISD::CondCode cc, const Value *cmplhs, const Value *cmprhs,
146             const Value *cmpmiddle, MachineBasicBlock *truebb,
147             MachineBasicBlock *falsebb, MachineBasicBlock *me, SDLoc dl,
148             BranchProbability trueprob = BranchProbability::getUnknown(),
149             BranchProbability falseprob = BranchProbability::getUnknown(),
150             bool isunpredictable = false)
151       : CC(cc), CmpLHS(cmplhs), CmpMHS(cmpmiddle), CmpRHS(cmprhs),
152         TrueBB(truebb), FalseBB(falsebb), ThisBB(me), DL(dl),
153         TrueProb(trueprob), FalseProb(falseprob),
154         IsUnpredictable(isunpredictable) {}
155 
156   // Constructor for GISel.
157   CaseBlock(CmpInst::Predicate pred, bool nocmp, const Value *cmplhs,
158             const Value *cmprhs, const Value *cmpmiddle,
159             MachineBasicBlock *truebb, MachineBasicBlock *falsebb,
160             MachineBasicBlock *me, DebugLoc dl,
161             BranchProbability trueprob = BranchProbability::getUnknown(),
162             BranchProbability falseprob = BranchProbability::getUnknown(),
163             bool isunpredictable = false)
164       : PredInfo({pred, nocmp}), CmpLHS(cmplhs), CmpMHS(cmpmiddle),
165         CmpRHS(cmprhs), TrueBB(truebb), FalseBB(falsebb), ThisBB(me),
166         DbgLoc(dl), TrueProb(trueprob), FalseProb(falseprob),
167         IsUnpredictable(isunpredictable) {}
168 };
169 
170 struct JumpTable {
171   /// The virtual register containing the index of the jump table entry
172   /// to jump to.
173   Register Reg;
174   /// The JumpTableIndex for this jump table in the function.
175   unsigned JTI;
176   /// The MBB into which to emit the code for the indirect jump.
177   MachineBasicBlock *MBB;
178   /// The MBB of the default bb, which is a successor of the range
179   /// check MBB.  This is when updating PHI nodes in successors.
180   MachineBasicBlock *Default;
181 
182   /// The debug location of the instruction this JumpTable was produced from.
183   std::optional<SDLoc> SL; // For SelectionDAG
184 
185   JumpTable(Register R, unsigned J, MachineBasicBlock *M, MachineBasicBlock *D,
186             std::optional<SDLoc> SL)
187       : Reg(R), JTI(J), MBB(M), Default(D), SL(SL) {}
188 };
189 struct JumpTableHeader {
190   APInt First;
191   APInt Last;
192   const Value *SValue;
193   MachineBasicBlock *HeaderBB;
194   bool Emitted;
195   bool FallthroughUnreachable = false;
196 
197   JumpTableHeader(APInt F, APInt L, const Value *SV, MachineBasicBlock *H,
198                   bool E = false)
199       : First(std::move(F)), Last(std::move(L)), SValue(SV), HeaderBB(H),
200         Emitted(E) {}
201 };
202 using JumpTableBlock = std::pair<JumpTableHeader, JumpTable>;
203 
204 struct BitTestCase {
205   uint64_t Mask;
206   MachineBasicBlock *ThisBB;
207   MachineBasicBlock *TargetBB;
208   BranchProbability ExtraProb;
209 
210   BitTestCase(uint64_t M, MachineBasicBlock *T, MachineBasicBlock *Tr,
211               BranchProbability Prob)
212       : Mask(M), ThisBB(T), TargetBB(Tr), ExtraProb(Prob) {}
213 };
214 
215 using BitTestInfo = SmallVector<BitTestCase, 3>;
216 
217 struct BitTestBlock {
218   APInt First;
219   APInt Range;
220   const Value *SValue;
221   Register Reg;
222   MVT RegVT;
223   bool Emitted;
224   bool ContiguousRange;
225   MachineBasicBlock *Parent;
226   MachineBasicBlock *Default;
227   BitTestInfo Cases;
228   BranchProbability Prob;
229   BranchProbability DefaultProb;
230   bool FallthroughUnreachable = false;
231 
232   BitTestBlock(APInt F, APInt R, const Value *SV, Register Rg, MVT RgVT, bool E,
233                bool CR, MachineBasicBlock *P, MachineBasicBlock *D,
234                BitTestInfo C, BranchProbability Pr)
235       : First(std::move(F)), Range(std::move(R)), SValue(SV), Reg(Rg),
236         RegVT(RgVT), Emitted(E), ContiguousRange(CR), Parent(P), Default(D),
237         Cases(std::move(C)), Prob(Pr) {}
238 };
239 
240 /// Return the range of values within a range.
241 uint64_t getJumpTableRange(const CaseClusterVector &Clusters, unsigned First,
242                            unsigned Last);
243 
244 /// Return the number of cases within a range.
245 uint64_t getJumpTableNumCases(const SmallVectorImpl<unsigned> &TotalCases,
246                               unsigned First, unsigned Last);
247 
248 struct SwitchWorkListItem {
249   MachineBasicBlock *MBB = nullptr;
250   CaseClusterIt FirstCluster;
251   CaseClusterIt LastCluster;
252   const ConstantInt *GE = nullptr;
253   const ConstantInt *LT = nullptr;
254   BranchProbability DefaultProb;
255 };
256 using SwitchWorkList = SmallVector<SwitchWorkListItem, 4>;
257 
258 class SwitchLowering {
259 public:
260   SwitchLowering(FunctionLoweringInfo &funcinfo) : FuncInfo(funcinfo) {}
261 
262   void init(const TargetLowering &tli, const TargetMachine &tm,
263             const DataLayout &dl) {
264     TLI = &tli;
265     TM = &tm;
266     DL = &dl;
267   }
268 
269   /// Vector of CaseBlock structures used to communicate SwitchInst code
270   /// generation information.
271   std::vector<CaseBlock> SwitchCases;
272 
273   /// Vector of JumpTable structures used to communicate SwitchInst code
274   /// generation information.
275   std::vector<JumpTableBlock> JTCases;
276 
277   /// Vector of BitTestBlock structures used to communicate SwitchInst code
278   /// generation information.
279   std::vector<BitTestBlock> BitTestCases;
280 
281   void findJumpTables(CaseClusterVector &Clusters, const SwitchInst *SI,
282                       std::optional<SDLoc> SL, MachineBasicBlock *DefaultMBB,
283                       ProfileSummaryInfo *PSI, BlockFrequencyInfo *BFI);
284 
285   bool buildJumpTable(const CaseClusterVector &Clusters, unsigned First,
286                       unsigned Last, const SwitchInst *SI,
287                       const std::optional<SDLoc> &SL,
288                       MachineBasicBlock *DefaultMBB, CaseCluster &JTCluster);
289 
290   void findBitTestClusters(CaseClusterVector &Clusters, const SwitchInst *SI);
291 
292   /// Build a bit test cluster from Clusters[First..Last]. Returns false if it
293   /// decides it's not a good idea.
294   bool buildBitTests(CaseClusterVector &Clusters, unsigned First, unsigned Last,
295                      const SwitchInst *SI, CaseCluster &BTCluster);
296 
297   virtual void addSuccessorWithProb(
298       MachineBasicBlock *Src, MachineBasicBlock *Dst,
299       BranchProbability Prob = BranchProbability::getUnknown()) = 0;
300 
301   /// Determine the rank by weight of CC in [First,Last]. If CC has more weight
302   /// than each cluster in the range, its rank is 0.
303   unsigned caseClusterRank(const CaseCluster &CC, CaseClusterIt First,
304                            CaseClusterIt Last);
305 
306   struct SplitWorkItemInfo {
307     CaseClusterIt LastLeft;
308     CaseClusterIt FirstRight;
309     BranchProbability LeftProb;
310     BranchProbability RightProb;
311   };
312   /// Compute information to balance the tree based on branch probabilities to
313   /// create a near-optimal (in terms of search time given key frequency) binary
314   /// search tree. See e.g. Kurt Mehlhorn "Nearly Optimal Binary Search Trees"
315   /// (1975).
316   SplitWorkItemInfo computeSplitWorkItemInfo(const SwitchWorkListItem &W);
317   virtual ~SwitchLowering() = default;
318 
319 private:
320   const TargetLowering *TLI = nullptr;
321   const TargetMachine *TM = nullptr;
322   const DataLayout *DL = nullptr;
323   FunctionLoweringInfo &FuncInfo;
324 };
325 
326 } // namespace SwitchCG
327 } // namespace llvm
328 
329 #endif // LLVM_CODEGEN_SWITCHLOWERINGUTILS_H
330