1 //===- SwitchLoweringUtils.h - Switch Lowering ------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef LLVM_CODEGEN_SWITCHLOWERINGUTILS_H 10 #define LLVM_CODEGEN_SWITCHLOWERINGUTILS_H 11 12 #include "llvm/ADT/SmallVector.h" 13 #include "llvm/CodeGen/ISDOpcodes.h" 14 #include "llvm/CodeGen/SelectionDAGNodes.h" 15 #include "llvm/IR/InstrTypes.h" 16 #include "llvm/Support/BranchProbability.h" 17 #include <vector> 18 19 namespace llvm { 20 21 class BlockFrequencyInfo; 22 class ConstantInt; 23 class FunctionLoweringInfo; 24 class MachineBasicBlock; 25 class ProfileSummaryInfo; 26 class TargetLowering; 27 class TargetMachine; 28 29 namespace SwitchCG { 30 31 enum CaseClusterKind { 32 /// A cluster of adjacent case labels with the same destination, or just one 33 /// case. 34 CC_Range, 35 /// A cluster of cases suitable for jump table lowering. 36 CC_JumpTable, 37 /// A cluster of cases suitable for bit test lowering. 38 CC_BitTests 39 }; 40 41 /// A cluster of case labels. 42 struct CaseCluster { 43 CaseClusterKind Kind; 44 const ConstantInt *Low, *High; 45 union { 46 MachineBasicBlock *MBB; 47 unsigned JTCasesIndex; 48 unsigned BTCasesIndex; 49 }; 50 BranchProbability Prob; 51 52 static CaseCluster range(const ConstantInt *Low, const ConstantInt *High, 53 MachineBasicBlock *MBB, BranchProbability Prob) { 54 CaseCluster C; 55 C.Kind = CC_Range; 56 C.Low = Low; 57 C.High = High; 58 C.MBB = MBB; 59 C.Prob = Prob; 60 return C; 61 } 62 63 static CaseCluster jumpTable(const ConstantInt *Low, const ConstantInt *High, 64 unsigned JTCasesIndex, BranchProbability Prob) { 65 CaseCluster C; 66 C.Kind = CC_JumpTable; 67 C.Low = Low; 68 C.High = High; 69 C.JTCasesIndex = JTCasesIndex; 70 C.Prob = Prob; 71 return C; 72 } 73 74 static CaseCluster bitTests(const ConstantInt *Low, const ConstantInt *High, 75 unsigned BTCasesIndex, BranchProbability Prob) { 76 CaseCluster C; 77 C.Kind = CC_BitTests; 78 C.Low = Low; 79 C.High = High; 80 C.BTCasesIndex = BTCasesIndex; 81 C.Prob = Prob; 82 return C; 83 } 84 }; 85 86 using CaseClusterVector = std::vector<CaseCluster>; 87 using CaseClusterIt = CaseClusterVector::iterator; 88 89 /// Sort Clusters and merge adjacent cases. 90 void sortAndRangeify(CaseClusterVector &Clusters); 91 92 struct CaseBits { 93 uint64_t Mask = 0; 94 MachineBasicBlock *BB = nullptr; 95 unsigned Bits = 0; 96 BranchProbability ExtraProb; 97 98 CaseBits() = default; 99 CaseBits(uint64_t mask, MachineBasicBlock *bb, unsigned bits, 100 BranchProbability Prob) 101 : Mask(mask), BB(bb), Bits(bits), ExtraProb(Prob) {} 102 }; 103 104 using CaseBitsVector = std::vector<CaseBits>; 105 106 /// This structure is used to communicate between SelectionDAGBuilder and 107 /// SDISel for the code generation of additional basic blocks needed by 108 /// multi-case switch statements. 109 struct CaseBlock { 110 // For the GISel interface. 111 struct PredInfoPair { 112 CmpInst::Predicate Pred; 113 // Set when no comparison should be emitted. 114 bool NoCmp; 115 }; 116 union { 117 // The condition code to use for the case block's setcc node. 118 // Besides the integer condition codes, this can also be SETTRUE, in which 119 // case no comparison gets emitted. 120 ISD::CondCode CC; 121 struct PredInfoPair PredInfo; 122 }; 123 124 // The LHS/MHS/RHS of the comparison to emit. 125 // Emit by default LHS op RHS. MHS is used for range comparisons: 126 // If MHS is not null: (LHS <= MHS) and (MHS <= RHS). 127 const Value *CmpLHS, *CmpMHS, *CmpRHS; 128 129 // The block to branch to if the setcc is true/false. 130 MachineBasicBlock *TrueBB, *FalseBB; 131 132 // The block into which to emit the code for the setcc and branches. 133 MachineBasicBlock *ThisBB; 134 135 /// The debug location of the instruction this CaseBlock was 136 /// produced from. 137 SDLoc DL; 138 DebugLoc DbgLoc; 139 140 // Branch weights and predictability. 141 BranchProbability TrueProb, FalseProb; 142 bool IsUnpredictable; 143 144 // Constructor for SelectionDAG. 145 CaseBlock(ISD::CondCode cc, const Value *cmplhs, const Value *cmprhs, 146 const Value *cmpmiddle, MachineBasicBlock *truebb, 147 MachineBasicBlock *falsebb, MachineBasicBlock *me, SDLoc dl, 148 BranchProbability trueprob = BranchProbability::getUnknown(), 149 BranchProbability falseprob = BranchProbability::getUnknown(), 150 bool isunpredictable = false) 151 : CC(cc), CmpLHS(cmplhs), CmpMHS(cmpmiddle), CmpRHS(cmprhs), 152 TrueBB(truebb), FalseBB(falsebb), ThisBB(me), DL(dl), 153 TrueProb(trueprob), FalseProb(falseprob), 154 IsUnpredictable(isunpredictable) {} 155 156 // Constructor for GISel. 157 CaseBlock(CmpInst::Predicate pred, bool nocmp, const Value *cmplhs, 158 const Value *cmprhs, const Value *cmpmiddle, 159 MachineBasicBlock *truebb, MachineBasicBlock *falsebb, 160 MachineBasicBlock *me, DebugLoc dl, 161 BranchProbability trueprob = BranchProbability::getUnknown(), 162 BranchProbability falseprob = BranchProbability::getUnknown(), 163 bool isunpredictable = false) 164 : PredInfo({pred, nocmp}), CmpLHS(cmplhs), CmpMHS(cmpmiddle), 165 CmpRHS(cmprhs), TrueBB(truebb), FalseBB(falsebb), ThisBB(me), 166 DbgLoc(dl), TrueProb(trueprob), FalseProb(falseprob), 167 IsUnpredictable(isunpredictable) {} 168 }; 169 170 struct JumpTable { 171 /// The virtual register containing the index of the jump table entry 172 /// to jump to. 173 Register Reg; 174 /// The JumpTableIndex for this jump table in the function. 175 unsigned JTI; 176 /// The MBB into which to emit the code for the indirect jump. 177 MachineBasicBlock *MBB; 178 /// The MBB of the default bb, which is a successor of the range 179 /// check MBB. This is when updating PHI nodes in successors. 180 MachineBasicBlock *Default; 181 182 /// The debug location of the instruction this JumpTable was produced from. 183 std::optional<SDLoc> SL; // For SelectionDAG 184 185 JumpTable(Register R, unsigned J, MachineBasicBlock *M, MachineBasicBlock *D, 186 std::optional<SDLoc> SL) 187 : Reg(R), JTI(J), MBB(M), Default(D), SL(SL) {} 188 }; 189 struct JumpTableHeader { 190 APInt First; 191 APInt Last; 192 const Value *SValue; 193 MachineBasicBlock *HeaderBB; 194 bool Emitted; 195 bool FallthroughUnreachable = false; 196 197 JumpTableHeader(APInt F, APInt L, const Value *SV, MachineBasicBlock *H, 198 bool E = false) 199 : First(std::move(F)), Last(std::move(L)), SValue(SV), HeaderBB(H), 200 Emitted(E) {} 201 }; 202 using JumpTableBlock = std::pair<JumpTableHeader, JumpTable>; 203 204 struct BitTestCase { 205 uint64_t Mask; 206 MachineBasicBlock *ThisBB; 207 MachineBasicBlock *TargetBB; 208 BranchProbability ExtraProb; 209 210 BitTestCase(uint64_t M, MachineBasicBlock *T, MachineBasicBlock *Tr, 211 BranchProbability Prob) 212 : Mask(M), ThisBB(T), TargetBB(Tr), ExtraProb(Prob) {} 213 }; 214 215 using BitTestInfo = SmallVector<BitTestCase, 3>; 216 217 struct BitTestBlock { 218 APInt First; 219 APInt Range; 220 const Value *SValue; 221 Register Reg; 222 MVT RegVT; 223 bool Emitted; 224 bool ContiguousRange; 225 MachineBasicBlock *Parent; 226 MachineBasicBlock *Default; 227 BitTestInfo Cases; 228 BranchProbability Prob; 229 BranchProbability DefaultProb; 230 bool FallthroughUnreachable = false; 231 232 BitTestBlock(APInt F, APInt R, const Value *SV, Register Rg, MVT RgVT, bool E, 233 bool CR, MachineBasicBlock *P, MachineBasicBlock *D, 234 BitTestInfo C, BranchProbability Pr) 235 : First(std::move(F)), Range(std::move(R)), SValue(SV), Reg(Rg), 236 RegVT(RgVT), Emitted(E), ContiguousRange(CR), Parent(P), Default(D), 237 Cases(std::move(C)), Prob(Pr) {} 238 }; 239 240 /// Return the range of values within a range. 241 uint64_t getJumpTableRange(const CaseClusterVector &Clusters, unsigned First, 242 unsigned Last); 243 244 /// Return the number of cases within a range. 245 uint64_t getJumpTableNumCases(const SmallVectorImpl<unsigned> &TotalCases, 246 unsigned First, unsigned Last); 247 248 struct SwitchWorkListItem { 249 MachineBasicBlock *MBB = nullptr; 250 CaseClusterIt FirstCluster; 251 CaseClusterIt LastCluster; 252 const ConstantInt *GE = nullptr; 253 const ConstantInt *LT = nullptr; 254 BranchProbability DefaultProb; 255 }; 256 using SwitchWorkList = SmallVector<SwitchWorkListItem, 4>; 257 258 class SwitchLowering { 259 public: 260 SwitchLowering(FunctionLoweringInfo &funcinfo) : FuncInfo(funcinfo) {} 261 262 void init(const TargetLowering &tli, const TargetMachine &tm, 263 const DataLayout &dl) { 264 TLI = &tli; 265 TM = &tm; 266 DL = &dl; 267 } 268 269 /// Vector of CaseBlock structures used to communicate SwitchInst code 270 /// generation information. 271 std::vector<CaseBlock> SwitchCases; 272 273 /// Vector of JumpTable structures used to communicate SwitchInst code 274 /// generation information. 275 std::vector<JumpTableBlock> JTCases; 276 277 /// Vector of BitTestBlock structures used to communicate SwitchInst code 278 /// generation information. 279 std::vector<BitTestBlock> BitTestCases; 280 281 void findJumpTables(CaseClusterVector &Clusters, const SwitchInst *SI, 282 std::optional<SDLoc> SL, MachineBasicBlock *DefaultMBB, 283 ProfileSummaryInfo *PSI, BlockFrequencyInfo *BFI); 284 285 bool buildJumpTable(const CaseClusterVector &Clusters, unsigned First, 286 unsigned Last, const SwitchInst *SI, 287 const std::optional<SDLoc> &SL, 288 MachineBasicBlock *DefaultMBB, CaseCluster &JTCluster); 289 290 void findBitTestClusters(CaseClusterVector &Clusters, const SwitchInst *SI); 291 292 /// Build a bit test cluster from Clusters[First..Last]. Returns false if it 293 /// decides it's not a good idea. 294 bool buildBitTests(CaseClusterVector &Clusters, unsigned First, unsigned Last, 295 const SwitchInst *SI, CaseCluster &BTCluster); 296 297 virtual void addSuccessorWithProb( 298 MachineBasicBlock *Src, MachineBasicBlock *Dst, 299 BranchProbability Prob = BranchProbability::getUnknown()) = 0; 300 301 /// Determine the rank by weight of CC in [First,Last]. If CC has more weight 302 /// than each cluster in the range, its rank is 0. 303 unsigned caseClusterRank(const CaseCluster &CC, CaseClusterIt First, 304 CaseClusterIt Last); 305 306 struct SplitWorkItemInfo { 307 CaseClusterIt LastLeft; 308 CaseClusterIt FirstRight; 309 BranchProbability LeftProb; 310 BranchProbability RightProb; 311 }; 312 /// Compute information to balance the tree based on branch probabilities to 313 /// create a near-optimal (in terms of search time given key frequency) binary 314 /// search tree. See e.g. Kurt Mehlhorn "Nearly Optimal Binary Search Trees" 315 /// (1975). 316 SplitWorkItemInfo computeSplitWorkItemInfo(const SwitchWorkListItem &W); 317 virtual ~SwitchLowering() = default; 318 319 private: 320 const TargetLowering *TLI = nullptr; 321 const TargetMachine *TM = nullptr; 322 const DataLayout *DL = nullptr; 323 FunctionLoweringInfo &FuncInfo; 324 }; 325 326 } // namespace SwitchCG 327 } // namespace llvm 328 329 #endif // LLVM_CODEGEN_SWITCHLOWERINGUTILS_H 330