//===- AggressiveInstCombine.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the aggressive expression pattern combiner classes.
// Currently, it handles expression patterns for:
//  * Truncate instruction
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h"
#include "AggressiveInstCombineInternal.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BasicAliasAnalysis.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/BuildLibCalls.h"
#include "llvm/Transforms/Utils/Local.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "aggressive-instcombine"

STATISTIC(NumAnyOrAllBitsSet, "Number of any/all-bits-set patterns folded");
STATISTIC(NumGuardedRotates,
          "Number of guarded rotates transformed into funnel shifts");
STATISTIC(NumGuardedFunnelShifts,
          "Number of guarded funnel shifts transformed into funnel shifts");
STATISTIC(NumPopCountRecognized, "Number of popcount idioms recognized");

static cl::opt<unsigned> MaxInstrsToScan(
    "aggressive-instcombine-max-scan-instrs", cl::init(64), cl::Hidden,
    cl::desc("Max number of instructions to scan for aggressive instcombine."));

static cl::opt<unsigned> StrNCmpInlineThreshold(
    "strncmp-inline-threshold", cl::init(3), cl::Hidden,
    cl::desc("The maximum length of a constant string for a builtin string cmp "
             "call eligible for inlining. The default value is 3."));

static cl::opt<unsigned>
    MemChrInlineThreshold("memchr-inline-threshold", cl::init(3), cl::Hidden,
                          cl::desc("The maximum length of a constant string to "
                                   "inline a memchr call."));

/// Match a pattern for a bitwise funnel/rotate operation that partially guards
/// against undefined behavior by branching around the funnel-shift/rotation
/// when the shift amount is 0.
static bool foldGuardedFunnelShift(Instruction &I, const DominatorTree &DT) {
  if (I.getOpcode() != Instruction::PHI || I.getNumOperands() != 2)
    return false;

  // As with the one-use checks below, this is not strictly necessary, but we
  // are being cautious to avoid potential perf regressions on targets that
  // do not actually have a funnel/rotate instruction (where the funnel shift
  // would be expanded back into math/shift/logic ops).
  if (!isPowerOf2_32(I.getType()->getScalarSizeInBits()))
    return false;

  // Match V to funnel shift left/right and capture the source operands and
  // shift amount.
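  // As an illustrative example (made-up values, Width == 32):
  //   fshl(a, b, 5) == (a << 5) | (b >> 27)
  // so the guarded IR matched below computes exactly a funnel shift whenever
  // the shift amount is nonzero.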
  auto matchFunnelShift = [](Value *V, Value *&ShVal0, Value *&ShVal1,
                             Value *&ShAmt) {
    unsigned Width = V->getType()->getScalarSizeInBits();

    // fshl(ShVal0, ShVal1, ShAmt)
    // == (ShVal0 << ShAmt) | (ShVal1 >> (Width - ShAmt))
    if (match(V, m_OneUse(m_c_Or(
                     m_Shl(m_Value(ShVal0), m_Value(ShAmt)),
                     m_LShr(m_Value(ShVal1),
                            m_Sub(m_SpecificInt(Width), m_Deferred(ShAmt))))))) {
      return Intrinsic::fshl;
    }

    // fshr(ShVal0, ShVal1, ShAmt)
    // == (ShVal0 >> ShAmt) | (ShVal1 << (Width - ShAmt))
    if (match(V,
              m_OneUse(m_c_Or(m_Shl(m_Value(ShVal0), m_Sub(m_SpecificInt(Width),
                                                           m_Value(ShAmt))),
                              m_LShr(m_Value(ShVal1), m_Deferred(ShAmt)))))) {
      return Intrinsic::fshr;
    }

    return Intrinsic::not_intrinsic;
  };

  // One phi operand must be a funnel/rotate operation, and the other phi
  // operand must be the source value of that funnel/rotate operation:
  // phi [ rotate(RotSrc, ShAmt), FunnelBB ], [ RotSrc, GuardBB ]
  // phi [ fshl(ShVal0, ShVal1, ShAmt), FunnelBB ], [ ShVal0, GuardBB ]
  // phi [ fshr(ShVal0, ShVal1, ShAmt), FunnelBB ], [ ShVal1, GuardBB ]
  PHINode &Phi = cast<PHINode>(I);
  unsigned FunnelOp = 0, GuardOp = 1;
  Value *P0 = Phi.getOperand(0), *P1 = Phi.getOperand(1);
  Value *ShVal0, *ShVal1, *ShAmt;
  Intrinsic::ID IID = matchFunnelShift(P0, ShVal0, ShVal1, ShAmt);
  if (IID == Intrinsic::not_intrinsic ||
      (IID == Intrinsic::fshl && ShVal0 != P1) ||
      (IID == Intrinsic::fshr && ShVal1 != P1)) {
    IID = matchFunnelShift(P1, ShVal0, ShVal1, ShAmt);
    if (IID == Intrinsic::not_intrinsic ||
        (IID == Intrinsic::fshl && ShVal0 != P0) ||
        (IID == Intrinsic::fshr && ShVal1 != P0))
      return false;
    assert((IID == Intrinsic::fshl || IID == Intrinsic::fshr) &&
           "Pattern must match funnel shift left or right");
    std::swap(FunnelOp, GuardOp);
  }

  // The incoming block with our source operand must be the "guard" block.
  // That must contain a cmp+branch to avoid the funnel/rotate when the shift
  // amount is equal to 0. The other incoming block is the block with the
  // funnel/rotate.
  BasicBlock *GuardBB = Phi.getIncomingBlock(GuardOp);
  BasicBlock *FunnelBB = Phi.getIncomingBlock(FunnelOp);
  Instruction *TermI = GuardBB->getTerminator();

  // Ensure that the shift values dominate each block.
  if (!DT.dominates(ShVal0, TermI) || !DT.dominates(ShVal1, TermI))
    return false;

  BasicBlock *PhiBB = Phi.getParent();
  if (!match(TermI, m_Br(m_SpecificICmp(CmpInst::ICMP_EQ, m_Specific(ShAmt),
                                        m_ZeroInt()),
                         m_SpecificBB(PhiBB), m_SpecificBB(FunnelBB))))
    return false;

  IRBuilder<> Builder(PhiBB, PhiBB->getFirstInsertionPt());

  if (ShVal0 == ShVal1)
    ++NumGuardedRotates;
  else
    ++NumGuardedFunnelShifts;

  // If this is not a rotate, then the phi was blocking poison from the
  // 'shift-by-zero' non-selected value, but a funnel shift will not - so
  // freeze that operand.
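  // For instance (illustrative): when ShAmt == 0 the original phi yields
  // ShVal0 even if ShVal1 is poison, but fshl(ShVal0, poison, 0) would itself
  // be poison, so the operand that was previously unused must be frozen.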
  bool IsFshl = IID == Intrinsic::fshl;
  if (ShVal0 != ShVal1) {
    if (IsFshl && !llvm::isGuaranteedNotToBePoison(ShVal1))
      ShVal1 = Builder.CreateFreeze(ShVal1);
    else if (!IsFshl && !llvm::isGuaranteedNotToBePoison(ShVal0))
      ShVal0 = Builder.CreateFreeze(ShVal0);
  }

  // We matched a variation of this IR pattern:
  // GuardBB:
  //   %cmp = icmp eq i32 %ShAmt, 0
  //   br i1 %cmp, label %PhiBB, label %FunnelBB
  // FunnelBB:
  //   %sub = sub i32 32, %ShAmt
  //   %shr = lshr i32 %ShVal1, %sub
  //   %shl = shl i32 %ShVal0, %ShAmt
  //   %fsh = or i32 %shr, %shl
  //   br label %PhiBB
  // PhiBB:
  //   %cond = phi i32 [ %fsh, %FunnelBB ], [ %ShVal0, %GuardBB ]
  // -->
  // llvm.fshl.i32(i32 %ShVal0, i32 %ShVal1, i32 %ShAmt)
  Phi.replaceAllUsesWith(
      Builder.CreateIntrinsic(IID, Phi.getType(), {ShVal0, ShVal1, ShAmt}));
  return true;
}

/// This is used by foldAnyOrAllBitsSet() to capture a source value (Root) and
/// the bit indexes (Mask) needed by a masked compare. If we're matching a chain
/// of 'and' ops, then we also need to capture the fact that we saw an
/// "and X, 1", so that's an extra return value for that case.
namespace {
struct MaskOps {
  Value *Root = nullptr;
  APInt Mask;
  bool MatchAndChain;
  bool FoundAnd1 = false;

  MaskOps(unsigned BitWidth, bool MatchAnds)
      : Mask(APInt::getZero(BitWidth)), MatchAndChain(MatchAnds) {}
};
} // namespace

/// This is a recursive helper for foldAnyOrAllBitsSet() that walks through a
/// chain of 'and' or 'or' instructions looking for shift ops of a common source
/// value. Examples:
///   or (or (or X, (X >> 3)), (X >> 5)), (X >> 8)
/// returns { X, 0x129 }
///   and (and (X >> 1), 1), (X >> 4)
/// returns { X, 0x12 }
static bool matchAndOrChain(Value *V, MaskOps &MOps) {
  Value *Op0, *Op1;
  if (MOps.MatchAndChain) {
    // Recurse through a chain of 'and' operands. This requires an extra check
    // vs. the 'or' matcher: we must find an "and X, 1" instruction somewhere
    // in the chain to know that all of the high bits are cleared.
    if (match(V, m_And(m_Value(Op0), m_One()))) {
      MOps.FoundAnd1 = true;
      return matchAndOrChain(Op0, MOps);
    }
    if (match(V, m_And(m_Value(Op0), m_Value(Op1))))
      return matchAndOrChain(Op0, MOps) && matchAndOrChain(Op1, MOps);
  } else {
    // Recurse through a chain of 'or' operands.
    if (match(V, m_Or(m_Value(Op0), m_Value(Op1))))
      return matchAndOrChain(Op0, MOps) && matchAndOrChain(Op1, MOps);
  }

  // We need a shift-right or a bare value representing a compare of bit 0 of
  // the original source operand.
  Value *Candidate;
  const APInt *BitIndex = nullptr;
  if (!match(V, m_LShr(m_Value(Candidate), m_APInt(BitIndex))))
    Candidate = V;

  // Initialize result source operand.
  if (!MOps.Root)
    MOps.Root = Candidate;

  // If the shift constant is out-of-range, this code has not been simplified
  // yet; bail out.
  if (BitIndex && BitIndex->uge(MOps.Mask.getBitWidth()))
    return false;

  // Fill in the mask bit derived from the shift constant.
  MOps.Mask.setBit(BitIndex ? BitIndex->getZExtValue() : 0);
  return MOps.Root == Candidate;
}

/// Match patterns that correspond to "any-bits-set" and "all-bits-set".
/// These will include a chain of 'or' or 'and'-shifted bits from a
/// common source value:
///   and (or (lshr X, C), ...), 1 --> (X & CMask) != 0
///   and (and (lshr X, C), ...), 1 --> (X & CMask) == CMask
/// Note: "any-bits-clear" and "all-bits-clear" are variations of these patterns
/// that differ only with a final 'not' of the result. We expect that final
/// 'not' to be folded with the compare that we create here (invert predicate).
static bool foldAnyOrAllBitsSet(Instruction &I) {
  // The 'any-bits-set' ('or' chain) pattern is simpler to match because the
  // final "and X, 1" instruction must be the final op in the sequence.
  bool MatchAllBitsSet;
  if (match(&I, m_c_And(m_OneUse(m_And(m_Value(), m_Value())), m_Value())))
    MatchAllBitsSet = true;
  else if (match(&I, m_And(m_OneUse(m_Or(m_Value(), m_Value())), m_One())))
    MatchAllBitsSet = false;
  else
    return false;

  MaskOps MOps(I.getType()->getScalarSizeInBits(), MatchAllBitsSet);
  if (MatchAllBitsSet) {
    if (!matchAndOrChain(cast<BinaryOperator>(&I), MOps) || !MOps.FoundAnd1)
      return false;
  } else {
    if (!matchAndOrChain(cast<BinaryOperator>(&I)->getOperand(0), MOps))
      return false;
  }

  // The pattern was found. Create a masked compare that replaces all of the
  // shift and logic ops.
  IRBuilder<> Builder(&I);
  Constant *Mask = ConstantInt::get(I.getType(), MOps.Mask);
  Value *And = Builder.CreateAnd(MOps.Root, Mask);
  Value *Cmp = MatchAllBitsSet ? Builder.CreateICmpEQ(And, Mask)
                               : Builder.CreateIsNotNull(And);
  Value *Zext = Builder.CreateZExt(Cmp, I.getType());
  I.replaceAllUsesWith(Zext);
  ++NumAnyOrAllBitsSet;
  return true;
}

// Try to recognize the function below as a popcount intrinsic.
// This is the "best" algorithm from
// http://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel
// Also used in TargetLowering::expandCTPOP().
//
// int popcount(unsigned int i) {
//   i = i - ((i >> 1) & 0x55555555);
//   i = (i & 0x33333333) + ((i >> 2) & 0x33333333);
//   i = ((i + (i >> 4)) & 0x0F0F0F0F);
//   return (i * 0x01010101) >> 24;
// }
static bool tryToRecognizePopCount(Instruction &I) {
  if (I.getOpcode() != Instruction::LShr)
    return false;

  Type *Ty = I.getType();
  if (!Ty->isIntOrIntVectorTy())
    return false;

  unsigned Len = Ty->getScalarSizeInBits();
  // FIXME: fix Len == 8 and other irregular type lengths.
  if (!(Len <= 128 && Len > 8 && Len % 8 == 0))
    return false;

  APInt Mask55 = APInt::getSplat(Len, APInt(8, 0x55));
  APInt Mask33 = APInt::getSplat(Len, APInt(8, 0x33));
  APInt Mask0F = APInt::getSplat(Len, APInt(8, 0x0F));
  APInt Mask01 = APInt::getSplat(Len, APInt(8, 0x01));
  APInt MaskShift = APInt(Len, Len - 8);

  Value *Op0 = I.getOperand(0);
  Value *Op1 = I.getOperand(1);
  Value *MulOp0;
  // Matching "(i * 0x01010101...) >> 24".
  if ((match(Op0, m_Mul(m_Value(MulOp0), m_SpecificInt(Mask01)))) &&
      match(Op1, m_SpecificInt(MaskShift))) {
    Value *ShiftOp0;
    // Matching "((i + (i >> 4)) & 0x0F0F0F0F...)".
    if (match(MulOp0, m_And(m_c_Add(m_LShr(m_Value(ShiftOp0), m_SpecificInt(4)),
                                    m_Deferred(ShiftOp0)),
                            m_SpecificInt(Mask0F)))) {
      Value *AndOp0;
      // Matching "(i & 0x33333333...) + ((i >> 2) & 0x33333333...)".
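      // As an illustrative example (value names made up), for a 32-bit input
      // this step matches IR of the form:
      //   %a   = and i32 %v, 858993459     ; 0x33333333
      //   %s   = lshr i32 %v, 2
      //   %b   = and i32 %s, 858993459
      //   %sum = add i32 %a, %b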
      if (match(ShiftOp0,
                m_c_Add(m_And(m_Value(AndOp0), m_SpecificInt(Mask33)),
                        m_And(m_LShr(m_Deferred(AndOp0), m_SpecificInt(2)),
                              m_SpecificInt(Mask33))))) {
        Value *Root, *SubOp1;
        // Matching "i - ((i >> 1) & 0x55555555...)".
        if (match(AndOp0, m_Sub(m_Value(Root), m_Value(SubOp1))) &&
            match(SubOp1, m_And(m_LShr(m_Specific(Root), m_SpecificInt(1)),
                                m_SpecificInt(Mask55)))) {
          LLVM_DEBUG(dbgs() << "Recognized popcount intrinsic\n");
          IRBuilder<> Builder(&I);
          I.replaceAllUsesWith(
              Builder.CreateIntrinsic(Intrinsic::ctpop, I.getType(), {Root}));
          ++NumPopCountRecognized;
          return true;
        }
      }
    }
  }

  return false;
}

/// Fold smin(smax(fptosi(x), C1), C2) to llvm.fptosi.sat(x), provided C1 and
/// C2 saturate the value of the fp conversion. The transform is not reversible,
/// as the fptosi.sat is more defined than the input: every input produces a
/// valid value for the fptosi.sat, whereas inputs that were out of range of the
/// integer conversion produce poison for the original sequence. The reversed
/// pattern may use fmax and fmin instead. As we cannot directly reverse the
/// transform, and it is not always profitable, we make it conditional on the
/// cost being reported as lower by TTI.
static bool tryToFPToSat(Instruction &I, TargetTransformInfo &TTI) {
  // Look for smin(smax(fptosi(x), C1), C2), converting to fptosi_sat.
  Value *In;
  const APInt *MinC, *MaxC;
  if (!match(&I, m_SMax(m_OneUse(m_SMin(m_OneUse(m_FPToSI(m_Value(In))),
                                        m_APInt(MinC))),
                        m_APInt(MaxC))) &&
      !match(&I, m_SMin(m_OneUse(m_SMax(m_OneUse(m_FPToSI(m_Value(In))),
                                        m_APInt(MaxC))),
                        m_APInt(MinC))))
    return false;

  // Check that the constants clamp a saturate.
  if (!(*MinC + 1).isPowerOf2() || -*MaxC != *MinC + 1)
    return false;

  Type *IntTy = I.getType();
  Type *FpTy = In->getType();
  Type *SatTy =
      IntegerType::get(IntTy->getContext(), (*MinC + 1).exactLogBase2() + 1);
  if (auto *VecTy = dyn_cast<VectorType>(IntTy))
    SatTy = VectorType::get(SatTy, VecTy->getElementCount());

  // Get the cost of the intrinsic, and check that against the cost of
  // fptosi+smin+smax.
  InstructionCost SatCost = TTI.getIntrinsicInstrCost(
      IntrinsicCostAttributes(Intrinsic::fptosi_sat, SatTy, {In}, {FpTy}),
      TTI::TCK_RecipThroughput);
  SatCost += TTI.getCastInstrCost(Instruction::SExt, IntTy, SatTy,
                                  TTI::CastContextHint::None,
                                  TTI::TCK_RecipThroughput);

  InstructionCost MinMaxCost = TTI.getCastInstrCost(
      Instruction::FPToSI, IntTy, FpTy, TTI::CastContextHint::None,
      TTI::TCK_RecipThroughput);
  MinMaxCost += TTI.getIntrinsicInstrCost(
      IntrinsicCostAttributes(Intrinsic::smin, IntTy, {IntTy}),
      TTI::TCK_RecipThroughput);
  MinMaxCost += TTI.getIntrinsicInstrCost(
      IntrinsicCostAttributes(Intrinsic::smax, IntTy, {IntTy}),
      TTI::TCK_RecipThroughput);

  if (SatCost >= MinMaxCost)
    return false;

  IRBuilder<> Builder(&I);
  Value *Sat =
      Builder.CreateIntrinsic(Intrinsic::fptosi_sat, {SatTy, FpTy}, In);
  I.replaceAllUsesWith(Builder.CreateSExt(Sat, IntTy));
  return true;
}

/// Try to replace a mathlib call to sqrt with the LLVM intrinsic. This avoids
/// pessimistic codegen that has to account for setting errno and can enable
/// vectorization.
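/// For example (illustrative), a call such as
///   %r = call double @sqrt(double %x)
/// where %x is known not to be less than -0.0 (or the call is marked nnan)
/// can be rewritten as
///   %r = call double @llvm.sqrt.f64(double %x)
/// and the original libcall is erased.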
static bool foldSqrt(CallInst *Call, LibFunc Func, TargetTransformInfo &TTI,
                     TargetLibraryInfo &TLI, AssumptionCache &AC,
                     DominatorTree &DT) {
  // If (1) this is a sqrt libcall, (2) we can assume that NAN is not created
  // (because NNAN or the operand arg must not be less than -0.0), and (3) we
  // would not end up lowering to a libcall anyway (which could change the
  // value of errno), then:
  // (1) errno won't be set.
  // (2) it is safe to convert this to an intrinsic call.
  Type *Ty = Call->getType();
  Value *Arg = Call->getArgOperand(0);
  if (TTI.haveFastSqrt(Ty) &&
      (Call->hasNoNaNs() ||
       cannotBeOrderedLessThanZero(
           Arg, 0,
           SimplifyQuery(Call->getDataLayout(), &TLI, &DT, &AC, Call)))) {
    IRBuilder<> Builder(Call);
    Value *NewSqrt =
        Builder.CreateIntrinsic(Intrinsic::sqrt, Ty, Arg, Call, "sqrt");
    Call->replaceAllUsesWith(NewSqrt);

    // Explicitly erase the old call because a call with side effects is not
    // trivially dead.
    Call->eraseFromParent();
    return true;
  }

  return false;
}

// Check if this array of constants represents a cttz table.
// Iterate over the elements from \p Table by trying to find/match all
// the numbers from 0 to \p InputBits that should represent cttz results.
static bool isCTTZTable(const ConstantDataArray &Table, uint64_t Mul,
                        uint64_t Shift, uint64_t InputBits) {
  unsigned Length = Table.getNumElements();
  if (Length < InputBits || Length > InputBits * 2)
    return false;

  APInt Mask = APInt::getBitsSetFrom(InputBits, Shift);
  unsigned Matched = 0;

  for (unsigned i = 0; i < Length; i++) {
    uint64_t Element = Table.getElementAsInteger(i);
    if (Element >= InputBits)
      continue;

    // Check if \p Element matches a concrete answer. It could fail for some
    // elements that are never accessed, so we keep iterating over each element
    // from the table. The number of matched elements should be equal to the
    // number of potential right answers, which is \p InputBits.
    if ((((Mul << Element) & Mask.getZExtValue()) >> Shift) == i)
      Matched++;
  }

  return Matched == InputBits;
}

// Try to recognize a table-based ctz implementation.
// E.g., an example in C (for more cases please see the llvm/tests):
// int f(unsigned x) {
//   static const char table[32] =
//     {0,  1, 28,  2, 29, 14, 24,  3, 30,
//      22, 20, 15, 25, 17,  4,  8, 31, 27,
//      13, 23, 21, 19, 16,  7, 26, 12, 18, 6, 11, 5, 10, 9};
//   return table[((unsigned)((x & -x) * 0x077CB531U)) >> 27];
// }
// This can be lowered to a `cttz` instruction.
// There is also a special case when the element is 0.
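// Sketching the idea behind the pattern: (x & -x) isolates the lowest set bit
// of x, the multiplication by a de Bruijn-like constant produces a distinct
// value in the top bits for every possible position of that bit, and the final
// shift turns those bits into an index into a table of precomputed cttz
// results.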
//
// Here are some examples of LLVM IR for a 64-bit target:
//
// CASE 1:
// %sub = sub i32 0, %x
// %and = and i32 %sub, %x
// %mul = mul i32 %and, 125613361
// %shr = lshr i32 %mul, 27
// %idxprom = zext i32 %shr to i64
// %arrayidx = getelementptr inbounds [32 x i8], [32 x i8]* @ctz1.table, i64 0,
//     i64 %idxprom
// %0 = load i8, i8* %arrayidx, align 1, !tbaa !8
//
// CASE 2:
// %sub = sub i32 0, %x
// %and = and i32 %sub, %x
// %mul = mul i32 %and, 72416175
// %shr = lshr i32 %mul, 26
// %idxprom = zext i32 %shr to i64
// %arrayidx = getelementptr inbounds [64 x i16], [64 x i16]* @ctz2.table,
//     i64 0, i64 %idxprom
// %0 = load i16, i16* %arrayidx, align 2, !tbaa !8
//
// CASE 3:
// %sub = sub i32 0, %x
// %and = and i32 %sub, %x
// %mul = mul i32 %and, 81224991
// %shr = lshr i32 %mul, 27
// %idxprom = zext i32 %shr to i64
// %arrayidx = getelementptr inbounds [32 x i32], [32 x i32]* @ctz3.table,
//     i64 0, i64 %idxprom
// %0 = load i32, i32* %arrayidx, align 4, !tbaa !8
//
// CASE 4:
// %sub = sub i64 0, %x
// %and = and i64 %sub, %x
// %mul = mul i64 %and, 283881067100198605
// %shr = lshr i64 %mul, 58
// %arrayidx = getelementptr inbounds [64 x i8], [64 x i8]* @table, i64 0,
//     i64 %shr
// %0 = load i8, i8* %arrayidx, align 1, !tbaa !8
//
// All of this can be lowered to the @llvm.cttz.i32/i64 intrinsic.
static bool tryToRecognizeTableBasedCttz(Instruction &I) {
  LoadInst *LI = dyn_cast<LoadInst>(&I);
  if (!LI)
    return false;

  Type *AccessType = LI->getType();
  if (!AccessType->isIntegerTy())
    return false;

  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(LI->getPointerOperand());
  if (!GEP || !GEP->isInBounds() || GEP->getNumIndices() != 2)
    return false;

  if (!GEP->getSourceElementType()->isArrayTy())
    return false;

  uint64_t ArraySize = GEP->getSourceElementType()->getArrayNumElements();
  if (ArraySize != 32 && ArraySize != 64)
    return false;

  GlobalVariable *GVTable = dyn_cast<GlobalVariable>(GEP->getPointerOperand());
  if (!GVTable || !GVTable->hasInitializer() || !GVTable->isConstant())
    return false;

  ConstantDataArray *ConstData =
      dyn_cast<ConstantDataArray>(GVTable->getInitializer());
  if (!ConstData)
    return false;

  if (!match(GEP->idx_begin()->get(), m_ZeroInt()))
    return false;

  Value *Idx2 = std::next(GEP->idx_begin())->get();
  Value *X1;
  uint64_t MulConst, ShiftConst;
  // FIXME: 64-bit targets have `i64` type for the GEP index, so this match will
  // probably fail for other (e.g. 32-bit) targets.
  if (!match(Idx2, m_ZExtOrSelf(
                       m_LShr(m_Mul(m_c_And(m_Neg(m_Value(X1)), m_Deferred(X1)),
                                    m_ConstantInt(MulConst)),
                              m_ConstantInt(ShiftConst)))))
    return false;

  unsigned InputBits = X1->getType()->getScalarSizeInBits();
  if (InputBits != 32 && InputBits != 64)
    return false;

  // The shift should extract the top 5..7 bits.
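  // e.g. (illustrative) for InputBits == 32 the expected ShiftConst is 27 for
  // a 32-entry table or 26 for a 64-entry table; for InputBits == 64 it is 58
  // or 57, respectively.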
  if (InputBits - Log2_32(InputBits) != ShiftConst &&
      InputBits - Log2_32(InputBits) - 1 != ShiftConst)
    return false;

  if (!isCTTZTable(*ConstData, MulConst, ShiftConst, InputBits))
    return false;

  auto ZeroTableElem = ConstData->getElementAsInteger(0);
  bool DefinedForZero = ZeroTableElem == InputBits;

  IRBuilder<> B(LI);
  ConstantInt *BoolConst = B.getInt1(!DefinedForZero);
  Type *XType = X1->getType();
  auto Cttz = B.CreateIntrinsic(Intrinsic::cttz, {XType}, {X1, BoolConst});
  Value *ZExtOrTrunc = nullptr;

  if (DefinedForZero) {
    ZExtOrTrunc = B.CreateZExtOrTrunc(Cttz, AccessType);
  } else {
    // If the value in elem 0 isn't the same as InputBits, we still want to
    // produce the value from the table.
    auto Cmp = B.CreateICmpEQ(X1, ConstantInt::get(XType, 0));
    auto Select =
        B.CreateSelect(Cmp, ConstantInt::get(XType, ZeroTableElem), Cttz);

    // NOTE: If the table[0] is 0, but the cttz(0) is defined by the Target
    // it should be handled as: `cttz(x) & (typeSize - 1)`.

    ZExtOrTrunc = B.CreateZExtOrTrunc(Select, AccessType);
  }

  LI->replaceAllUsesWith(ZExtOrTrunc);

  return true;
}

/// This is used by foldLoadsRecursive() to capture a Root Load node, which is
/// of type or(load, load), and recursively build the wide load. Also capture
/// the shift amount, zero-extend type and load size.
struct LoadOps {
  LoadInst *Root = nullptr;
  LoadInst *RootInsert = nullptr;
  bool FoundRoot = false;
  uint64_t LoadSize = 0;
  const APInt *Shift = nullptr;
  Type *ZextType;
  AAMDNodes AATags;
};

// Identify and merge consecutive loads recursively. The pattern has the form:
//   (ZExt(L1) << shift1) | (ZExt(L2) << shift2) -> ZExt(L3) << shift1
//   (ZExt(L1) << shift1) | ZExt(L2)             -> ZExt(L3)
static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
                               AliasAnalysis &AA) {
  const APInt *ShAmt2 = nullptr;
  Value *X;
  Instruction *L1, *L2;

  // Go to the last node with loads.
  if (match(V, m_OneUse(m_c_Or(
                   m_Value(X),
                   m_OneUse(m_Shl(m_OneUse(m_ZExt(m_OneUse(m_Instruction(L2)))),
                                  m_APInt(ShAmt2)))))) ||
      match(V, m_OneUse(m_Or(m_Value(X),
                             m_OneUse(m_ZExt(m_OneUse(m_Instruction(L2)))))))) {
    if (!foldLoadsRecursive(X, LOps, DL, AA) && LOps.FoundRoot)
      // Avoid a partial chain merge.
      return false;
  } else
    return false;

  // Check if the pattern has loads.
  LoadInst *LI1 = LOps.Root;
  const APInt *ShAmt1 = LOps.Shift;
  if (LOps.FoundRoot == false &&
      (match(X, m_OneUse(m_ZExt(m_Instruction(L1)))) ||
       match(X, m_OneUse(m_Shl(m_OneUse(m_ZExt(m_OneUse(m_Instruction(L1)))),
                               m_APInt(ShAmt1)))))) {
    LI1 = dyn_cast<LoadInst>(L1);
  }
  LoadInst *LI2 = dyn_cast<LoadInst>(L2);

  // Bail out if the two loads are the same instruction, are missing, are not
  // simple (i.e. are atomic or volatile), or live in different address spaces.
  if (LI1 == LI2 || !LI1 || !LI2 || !LI1->isSimple() || !LI2->isSimple() ||
      LI1->getPointerAddressSpace() != LI2->getPointerAddressSpace())
    return false;

  // Check that both loads come from the same basic block.
  if (LI1->getParent() != LI2->getParent())
    return false;

  // Query the data layout for endianness.
  bool IsBigEndian = DL.isBigEndian();

  // Check if the loads are consecutive and have the same size.
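  // Illustrative little-endian example of what the remaining checks verify
  // (value names made up):
  //   %b0 = load i8, ptr %p
  //   %b1 = load i8, ptr %p1        ; %p1 == %p + 1
  //   %w  = or (zext %b0 to i16), (shl (zext %b1 to i16), 8)
  // which foldConsecutiveLoads() can later rewrite as a single load i16 from
  // %p.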
  Value *Load1Ptr = LI1->getPointerOperand();
  APInt Offset1(DL.getIndexTypeSizeInBits(Load1Ptr->getType()), 0);
  Load1Ptr =
      Load1Ptr->stripAndAccumulateConstantOffsets(DL, Offset1,
                                                  /* AllowNonInbounds */ true);

  Value *Load2Ptr = LI2->getPointerOperand();
  APInt Offset2(DL.getIndexTypeSizeInBits(Load2Ptr->getType()), 0);
  Load2Ptr =
      Load2Ptr->stripAndAccumulateConstantOffsets(DL, Offset2,
                                                  /* AllowNonInbounds */ true);

  // Verify that both loads have the same base pointer and the same load size.
  uint64_t LoadSize1 = LI1->getType()->getPrimitiveSizeInBits();
  uint64_t LoadSize2 = LI2->getType()->getPrimitiveSizeInBits();
  if (Load1Ptr != Load2Ptr || LoadSize1 != LoadSize2)
    return false;

  // Only support load sizes that are at least 8 bits and a power of 2.
  if (LoadSize1 < 8 || !isPowerOf2_64(LoadSize1))
    return false;

  // Use alias analysis to check for stores between the loads.
  LoadInst *Start = LOps.FoundRoot ? LOps.RootInsert : LI1, *End = LI2;
  MemoryLocation Loc;
  if (!Start->comesBefore(End)) {
    std::swap(Start, End);
    Loc = MemoryLocation::get(End);
    if (LOps.FoundRoot)
      Loc = Loc.getWithNewSize(LOps.LoadSize);
  } else
    Loc = MemoryLocation::get(End);
  unsigned NumScanned = 0;
  for (Instruction &Inst :
       make_range(Start->getIterator(), End->getIterator())) {
    if (Inst.mayWriteToMemory() && isModSet(AA.getModRefInfo(&Inst, Loc)))
      return false;

    // Ignore debug info so that it is not counted against MaxInstrsToScan;
    // otherwise debug info could affect codegen.
    if (!isa<DbgInfoIntrinsic>(Inst) && ++NumScanned > MaxInstrsToScan)
      return false;
  }

  // Make sure the load with the lower offset is LI1.
  bool Reverse = false;
  if (Offset2.slt(Offset1)) {
    std::swap(LI1, LI2);
    std::swap(ShAmt1, ShAmt2);
    std::swap(Offset1, Offset2);
    std::swap(Load1Ptr, Load2Ptr);
    std::swap(LoadSize1, LoadSize2);
    Reverse = true;
  }

  // On big-endian targets, swap the shift amounts.
  if (IsBigEndian)
    std::swap(ShAmt1, ShAmt2);

  // Extract the shift amounts.
  uint64_t Shift1 = 0, Shift2 = 0;
  if (ShAmt1)
    Shift1 = ShAmt1->getZExtValue();
  if (ShAmt2)
    Shift2 = ShAmt2->getZExtValue();

  // The first load is always LI1. This is where we put the new load.
  // Use the merged load size available from LI1 for forward loads.
  if (LOps.FoundRoot) {
    if (!Reverse)
      LoadSize1 = LOps.LoadSize;
    else
      LoadSize2 = LOps.LoadSize;
  }

  // Verify that the shift amounts and load offsets line up, i.e. that the
  // loads are consecutive.
  uint64_t ShiftDiff = IsBigEndian ? LoadSize2 : LoadSize1;
  uint64_t PrevSize =
      DL.getTypeStoreSize(IntegerType::get(LI1->getContext(), LoadSize1));
  if ((Shift2 - Shift1) != ShiftDiff || (Offset2 - Offset1) != PrevSize)
    return false;

  // Update LOps.
  AAMDNodes AATags1 = LOps.AATags;
  AAMDNodes AATags2 = LI2->getAAMetadata();
  if (LOps.FoundRoot == false) {
    LOps.FoundRoot = true;
    AATags1 = LI1->getAAMetadata();
  }
  LOps.LoadSize = LoadSize1 + LoadSize2;
  LOps.RootInsert = Start;

  // Concatenate the AATags of the merged loads.
  LOps.AATags = AATags1.concat(AATags2);

  LOps.Root = LI1;
  LOps.Shift = ShAmt1;
  LOps.ZextType = X->getType();
  return true;
}

// For a given instruction in BB, evaluate all loads in the chain that form a
// pattern which suggests that the loads can be combined. The one and only use
// of the loads is to form a wider load.
static bool foldConsecutiveLoads(Instruction &I, const DataLayout &DL,
                                 TargetTransformInfo &TTI, AliasAnalysis &AA,
                                 const DominatorTree &DT) {
  // Only consider load chains of scalar values.
  if (isa<VectorType>(I.getType()))
    return false;

  LoadOps LOps;
  if (!foldLoadsRecursive(&I, LOps, DL, AA) || !LOps.FoundRoot)
    return false;

  IRBuilder<> Builder(&I);
  LoadInst *NewLoad = nullptr, *LI1 = LOps.Root;

  IntegerType *WiderType = IntegerType::get(I.getContext(), LOps.LoadSize);
  // TTI-based check of whether we want to proceed with the wider load.
  bool Allowed = TTI.isTypeLegal(WiderType);
  if (!Allowed)
    return false;

  unsigned AS = LI1->getPointerAddressSpace();
  unsigned Fast = 0;
  Allowed = TTI.allowsMisalignedMemoryAccesses(I.getContext(), LOps.LoadSize,
                                               AS, LI1->getAlign(), &Fast);
  if (!Allowed || !Fast)
    return false;

  // Get the index and pointer for the new GEP.
  Value *Load1Ptr = LI1->getPointerOperand();
  Builder.SetInsertPoint(LOps.RootInsert);
  if (!DT.dominates(Load1Ptr, LOps.RootInsert)) {
    APInt Offset1(DL.getIndexTypeSizeInBits(Load1Ptr->getType()), 0);
    Load1Ptr = Load1Ptr->stripAndAccumulateConstantOffsets(
        DL, Offset1, /* AllowNonInbounds */ true);
    Load1Ptr = Builder.CreatePtrAdd(Load1Ptr, Builder.getInt(Offset1));
  }
  // Generate the wider load.
  NewLoad = Builder.CreateAlignedLoad(WiderType, Load1Ptr, LI1->getAlign(),
                                      LI1->isVolatile(), "");
  NewLoad->takeName(LI1);
  // Set the new load's AATags metadata.
  if (LOps.AATags)
    NewLoad->setAAMetadata(LOps.AATags);

  Value *NewOp = NewLoad;
  // Check if a zero extend is needed.
  if (LOps.ZextType)
    NewOp = Builder.CreateZExt(NewOp, LOps.ZextType);

  // Check if a shift is needed. We need to shift by the shift amount of the
  // first load if it is not zero.
  if (LOps.Shift)
    NewOp = Builder.CreateShl(NewOp,
                              ConstantInt::get(I.getContext(), *LOps.Shift));
  I.replaceAllUsesWith(NewOp);

  return true;
}

// Calculate the GEP stride and the accumulated constant ModOffset. Return
// {Stride, ModOffset}.
static std::pair<APInt, APInt>
getStrideAndModOffsetOfGEP(Value *PtrOp, const DataLayout &DL) {
  unsigned BW = DL.getIndexTypeSizeInBits(PtrOp->getType());
  std::optional<APInt> Stride;
  APInt ModOffset(BW, 0);
  // Return a minimum GEP stride: the greatest common divisor of consecutive
  // GEP index scales (cf. Bézout's identity).
  while (auto *GEP = dyn_cast<GEPOperator>(PtrOp)) {
    SmallMapVector<Value *, APInt, 4> VarOffsets;
    if (!GEP->collectOffset(DL, BW, VarOffsets, ModOffset))
      break;

    for (auto [V, Scale] : VarOffsets) {
      // Only keep a power-of-two factor for non-inbounds GEPs.
      if (!GEP->isInBounds())
        Scale = APInt::getOneBitSet(Scale.getBitWidth(), Scale.countr_zero());

      if (!Stride)
        Stride = Scale;
      else
        Stride = APIntOps::GreatestCommonDivisor(*Stride, Scale);
    }

    PtrOp = GEP->getPointerOperand();
  }

  // Check whether the pointer arrives back at a global variable via at least
  // one GEP. Even if it doesn't, the caller can still reason about the offset
  // using the load alignment.
  if (!isa<GlobalVariable>(PtrOp) || !Stride)
    return {APInt(BW, 1), APInt(BW, 0)};

  // Because GEP indices are signed, the non-negligible part of the offset is
  // its remainder modulo the minimum GEP stride.
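  // e.g. (illustrative) with Stride == 4 and an accumulated ModOffset of -3,
  // srem yields -3 and the adjustment below normalizes it to 1.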
  ModOffset = ModOffset.srem(*Stride);
  if (ModOffset.isNegative())
    ModOffset += *Stride;

  return {*Stride, ModOffset};
}

/// If C is a constant patterned array and all valid loaded results for the
/// given alignment are equal to the same constant, fold the load to that
/// constant.
static bool foldPatternedLoads(Instruction &I, const DataLayout &DL) {
  auto *LI = dyn_cast<LoadInst>(&I);
  if (!LI || LI->isVolatile())
    return false;

  // We can only fold the load if it is from a constant global with a
  // definitive initializer. Skip expensive logic if this is not the case.
  auto *PtrOp = LI->getPointerOperand();
  auto *GV = dyn_cast<GlobalVariable>(getUnderlyingObject(PtrOp));
  if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer())
    return false;

  // Bail for large initializers in excess of 4K to avoid too many scans.
  Constant *C = GV->getInitializer();
  uint64_t GVSize = DL.getTypeAllocSize(C->getType());
  if (!GVSize || 4096 < GVSize)
    return false;

  Type *LoadTy = LI->getType();
  unsigned BW = DL.getIndexTypeSizeInBits(PtrOp->getType());
  auto [Stride, ConstOffset] = getStrideAndModOffsetOfGEP(PtrOp, DL);

  // Any possible offset is a multiple of the GEP stride, and any valid offset
  // is a multiple of the load alignment, so checking only multiples of the
  // larger of the two is sufficient to prove that the results are equal.
  if (auto LA = LI->getAlign();
      LA <= GV->getAlign().valueOrOne() && Stride.getZExtValue() < LA.value()) {
    ConstOffset = APInt(BW, 0);
    Stride = APInt(BW, LA.value());
  }

  Constant *Ca = ConstantFoldLoadFromConst(C, LoadTy, ConstOffset, DL);
  if (!Ca)
    return false;

  unsigned E = GVSize - DL.getTypeStoreSize(LoadTy);
  for (; ConstOffset.getZExtValue() <= E; ConstOffset += Stride)
    if (Ca != ConstantFoldLoadFromConst(C, LoadTy, ConstOffset, DL))
      return false;

  I.replaceAllUsesWith(Ca);

  return true;
}

namespace {
class StrNCmpInliner {
public:
  StrNCmpInliner(CallInst *CI, LibFunc Func, DomTreeUpdater *DTU,
                 const DataLayout &DL)
      : CI(CI), Func(Func), DTU(DTU), DL(DL) {}

  bool optimizeStrNCmp();

private:
  void inlineCompare(Value *LHS, StringRef RHS, uint64_t N, bool Swapped);

  CallInst *CI;
  LibFunc Func;
  DomTreeUpdater *DTU;
  const DataLayout &DL;
};

} // namespace

/// First we normalize calls to strncmp/strcmp to the form of
/// compare(s1, s2, N), which means comparing the first N bytes of s1 and s2
/// (without considering '\0').
///
/// Examples:
///
/// \code
/// strncmp(s, "a", 3) -> compare(s, "a", 2)
/// strncmp(s, "abc", 3) -> compare(s, "abc", 3)
/// strncmp(s, "a\0b", 3) -> compare(s, "a\0b", 2)
/// strcmp(s, "a") -> compare(s, "a", 2)
///
/// char s2[] = {'a'}
/// strncmp(s, s2, 3) -> compare(s, s2, 3)
///
/// char s2[] = {'a', 'b', 'c', 'd'}
/// strncmp(s, s2, 3) -> compare(s, s2, 3)
/// \endcode
///
/// We only handle cases where N and exactly one of s1 and s2 are constant.
/// Cases where s1 and s2 are both constant are already handled by the
/// instcombine pass.
///
/// We do not handle cases where N > StrNCmpInlineThreshold.
///
/// We also do not handle cases where N < 2, which are already
/// handled by the instcombine pass.
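/// As an illustrative end-to-end example (with the default threshold of 3),
///
/// \code
/// strncmp(s, "ab", 2) == 0
/// \endcode
///
/// is normalized to compare(s, "ab", 2) and then expanded by inlineCompare()
/// into two byte-wise subtractions of s[i] and the constant characters, with
/// an early exit to the "ne" block on the first nonzero difference.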
///
bool StrNCmpInliner::optimizeStrNCmp() {
  if (StrNCmpInlineThreshold < 2)
    return false;

  if (!isOnlyUsedInZeroComparison(CI))
    return false;

  Value *Str1P = CI->getArgOperand(0);
  Value *Str2P = CI->getArgOperand(1);
  // Should be handled elsewhere.
  if (Str1P == Str2P)
    return false;

  StringRef Str1, Str2;
  bool HasStr1 = getConstantStringInfo(Str1P, Str1, /*TrimAtNul=*/false);
  bool HasStr2 = getConstantStringInfo(Str2P, Str2, /*TrimAtNul=*/false);
  if (HasStr1 == HasStr2)
    return false;

  // Note that '\0' and characters after it are not trimmed.
  StringRef Str = HasStr1 ? Str1 : Str2;
  Value *StrP = HasStr1 ? Str2P : Str1P;

  size_t Idx = Str.find('\0');
  uint64_t N = Idx == StringRef::npos ? UINT64_MAX : Idx + 1;
  if (Func == LibFunc_strncmp) {
    if (auto *ConstInt = dyn_cast<ConstantInt>(CI->getArgOperand(2)))
      N = std::min(N, ConstInt->getZExtValue());
    else
      return false;
  }
  // Now N is the maximum number of bytes we need to compare.
  if (N > Str.size() || N < 2 || N > StrNCmpInlineThreshold)
    return false;

  // Cases where StrP has two or more dereferenceable bytes might be better
  // optimized elsewhere.
  bool CanBeNull = false, CanBeFreed = false;
  if (StrP->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed) > 1)
    return false;
  inlineCompare(StrP, Str, N, HasStr1);
  return true;
}

/// Convert
///
/// \code
/// ret = compare(s1, s2, N)
/// \endcode
///
/// into
///
/// \code
/// ret = (int)s1[0] - (int)s2[0]
/// if (ret != 0)
///   goto NE
/// ...
/// ret = (int)s1[N-2] - (int)s2[N-2]
/// if (ret != 0)
///   goto NE
/// ret = (int)s1[N-1] - (int)s2[N-1]
/// NE:
/// \endcode
///
/// CFG before and after the transformation:
///
/// (before)
/// BBCI
///
/// (after)
/// BBCI -> BBSubs[0] (sub,icmp) --NE-> BBNE -> BBTail
///            |                          ^
///            E                          |
///            |                          |
///         BBSubs[1] (sub,icmp) --NE-----+
///           ...                         |
///         BBSubs[N-1]    (sub) ---------+
///
void StrNCmpInliner::inlineCompare(Value *LHS, StringRef RHS, uint64_t N,
                                   bool Swapped) {
  auto &Ctx = CI->getContext();
  IRBuilder<> B(Ctx);
  // We want these instructions to be recognized as inlined instructions for the
  // compare call, but we don't have a source location for the definition of
  // that function, since we're generating that code now. Because the generated
  // code is a viable point for a memory access error, we make the pragmatic
  // choice here to directly use CI's location so that we have useful
  // attribution for the generated code.
  B.SetCurrentDebugLocation(CI->getDebugLoc());

  BasicBlock *BBCI = CI->getParent();
  BasicBlock *BBTail =
      SplitBlock(BBCI, CI, DTU, nullptr, nullptr, BBCI->getName() + ".tail");

  SmallVector<BasicBlock *> BBSubs;
  for (uint64_t I = 0; I < N; ++I)
    BBSubs.push_back(
        BasicBlock::Create(Ctx, "sub_" + Twine(I), BBCI->getParent(), BBTail));
  BasicBlock *BBNE = BasicBlock::Create(Ctx, "ne", BBCI->getParent(), BBTail);

  cast<BranchInst>(BBCI->getTerminator())->setSuccessor(0, BBSubs[0]);

  B.SetInsertPoint(BBNE);
  PHINode *Phi = B.CreatePHI(CI->getType(), N);
  B.CreateBr(BBTail);

  Value *Base = LHS;
  for (uint64_t i = 0; i < N; ++i) {
    B.SetInsertPoint(BBSubs[i]);
    Value *VL =
        B.CreateZExt(B.CreateLoad(B.getInt8Ty(),
                                  B.CreateInBoundsPtrAdd(Base, B.getInt64(i))),
                     CI->getType());
    Value *VR =
        ConstantInt::get(CI->getType(), static_cast<unsigned char>(RHS[i]));
    Value *Sub = Swapped ? B.CreateSub(VR, VL) : B.CreateSub(VL, VR);
    if (i < N - 1)
      B.CreateCondBr(B.CreateICmpNE(Sub, ConstantInt::get(CI->getType(), 0)),
                     BBNE, BBSubs[i + 1]);
    else
      B.CreateBr(BBNE);

    Phi->addIncoming(Sub, BBSubs[i]);
  }

  CI->replaceAllUsesWith(Phi);
  CI->eraseFromParent();

  if (DTU) {
    SmallVector<DominatorTree::UpdateType, 8> Updates;
    Updates.push_back({DominatorTree::Insert, BBCI, BBSubs[0]});
    for (uint64_t i = 0; i < N; ++i) {
      if (i < N - 1)
        Updates.push_back({DominatorTree::Insert, BBSubs[i], BBSubs[i + 1]});
      Updates.push_back({DominatorTree::Insert, BBSubs[i], BBNE});
    }
    Updates.push_back({DominatorTree::Insert, BBNE, BBTail});
    Updates.push_back({DominatorTree::Delete, BBCI, BBTail});
    DTU->applyUpdates(Updates);
  }
}

/// Convert a memchr call with a small constant string into a switch.
static bool foldMemChr(CallInst *Call, DomTreeUpdater *DTU,
                       const DataLayout &DL) {
  if (isa<Constant>(Call->getArgOperand(1)))
    return false;

  StringRef Str;
  Value *Base = Call->getArgOperand(0);
  if (!getConstantStringInfo(Base, Str, /*TrimAtNul=*/false))
    return false;

  uint64_t N = Str.size();
  if (auto *ConstInt = dyn_cast<ConstantInt>(Call->getArgOperand(2))) {
    uint64_t Val = ConstInt->getZExtValue();
    // Ignore the case where n is larger than the size of the string.
    if (Val > N)
      return false;
    N = Val;
  } else
    return false;

  if (N > MemChrInlineThreshold)
    return false;

  BasicBlock *BB = Call->getParent();
  BasicBlock *BBNext = SplitBlock(BB, Call, DTU);
  IRBuilder<> IRB(BB);
  IntegerType *ByteTy = IRB.getInt8Ty();
  BB->getTerminator()->eraseFromParent();
  SwitchInst *SI = IRB.CreateSwitch(
      IRB.CreateTrunc(Call->getArgOperand(1), ByteTy), BBNext, N);
  Type *IndexTy = DL.getIndexType(Call->getType());
  SmallVector<DominatorTree::UpdateType, 8> Updates;

  BasicBlock *BBSuccess = BasicBlock::Create(
      Call->getContext(), "memchr.success", BB->getParent(), BBNext);
  IRB.SetInsertPoint(BBSuccess);
  PHINode *IndexPHI = IRB.CreatePHI(IndexTy, N, "memchr.idx");
  Value *FirstOccursLocation = IRB.CreateInBoundsPtrAdd(Base, IndexPHI);
  IRB.CreateBr(BBNext);
  if (DTU)
    Updates.push_back({DominatorTree::Insert, BBSuccess, BBNext});

  SmallPtrSet<ConstantInt *, 4> Cases;
  for (uint64_t I = 0; I < N; ++I) {
    ConstantInt *CaseVal = ConstantInt::get(ByteTy, Str[I]);
    if (!Cases.insert(CaseVal).second)
      continue;

    BasicBlock *BBCase = BasicBlock::Create(Call->getContext(), "memchr.case",
                                            BB->getParent(), BBSuccess);
    SI->addCase(CaseVal, BBCase);
    IRB.SetInsertPoint(BBCase);
    IndexPHI->addIncoming(ConstantInt::get(IndexTy, I), BBCase);
    IRB.CreateBr(BBSuccess);
    if (DTU) {
      Updates.push_back({DominatorTree::Insert, BB, BBCase});
      Updates.push_back({DominatorTree::Insert, BBCase, BBSuccess});
    }
  }

  PHINode *PHI =
      PHINode::Create(Call->getType(), 2, Call->getName(), BBNext->begin());
  PHI->addIncoming(Constant::getNullValue(Call->getType()), BB);
  PHI->addIncoming(FirstOccursLocation, BBSuccess);

  Call->replaceAllUsesWith(PHI);
  Call->eraseFromParent();

  if (DTU)
    DTU->applyUpdates(Updates);

  return true;
}

static bool foldLibCalls(Instruction &I, TargetTransformInfo &TTI,
                         TargetLibraryInfo &TLI, AssumptionCache &AC,
                         DominatorTree &DT, const DataLayout &DL,
                         bool &MadeCFGChange) {

  auto *CI = dyn_cast<CallInst>(&I);
  if (!CI || CI->isNoBuiltin())
    return false;

  Function *CalledFunc = CI->getCalledFunction();
  if (!CalledFunc)
    return false;

  LibFunc LF;
  if (!TLI.getLibFunc(*CalledFunc, LF) ||
      !isLibFuncEmittable(CI->getModule(), &TLI, LF))
    return false;

  DomTreeUpdater DTU(&DT, DomTreeUpdater::UpdateStrategy::Lazy);

  switch (LF) {
  case LibFunc_sqrt:
  case LibFunc_sqrtf:
  case LibFunc_sqrtl:
    return foldSqrt(CI, LF, TTI, TLI, AC, DT);
  case LibFunc_strcmp:
  case LibFunc_strncmp:
    if (StrNCmpInliner(CI, LF, &DTU, DL).optimizeStrNCmp()) {
      MadeCFGChange = true;
      return true;
    }
    break;
  case LibFunc_memchr:
    if (foldMemChr(CI, &DTU, DL)) {
      MadeCFGChange = true;
      return true;
    }
    break;
  default:;
  }
  return false;
}

/// This is the entry point for folds that could be implemented in regular
/// InstCombine, but they are separated because they are not expected to
/// occur frequently and/or have more than a constant-length pattern match.
static bool foldUnusualPatterns(Function &F, DominatorTree &DT,
                                TargetTransformInfo &TTI,
                                TargetLibraryInfo &TLI, AliasAnalysis &AA,
                                AssumptionCache &AC, bool &MadeCFGChange) {
  bool MadeChange = false;
  for (BasicBlock &BB : F) {
    // Ignore unreachable basic blocks.
    if (!DT.isReachableFromEntry(&BB))
      continue;

    const DataLayout &DL = F.getDataLayout();

    // Walk the block backwards for efficiency. We're matching a chain of
    // use->defs, so we're more likely to succeed by starting from the bottom.
    // Also, we want to avoid matching partial patterns.
    // TODO: It would be more efficient if we removed dead instructions
    // iteratively in this loop rather than waiting until the end.
    for (Instruction &I : make_early_inc_range(llvm::reverse(BB))) {
      MadeChange |= foldAnyOrAllBitsSet(I);
      MadeChange |= foldGuardedFunnelShift(I, DT);
      MadeChange |= tryToRecognizePopCount(I);
      MadeChange |= tryToFPToSat(I, TTI);
      MadeChange |= tryToRecognizeTableBasedCttz(I);
      MadeChange |= foldConsecutiveLoads(I, DL, TTI, AA, DT);
      MadeChange |= foldPatternedLoads(I, DL);
      // NOTE: This call may erase the instruction `I`, so it needs to be the
      // last one in this sequence; otherwise we may introduce bugs.
      MadeChange |= foldLibCalls(I, TTI, TLI, AC, DT, DL, MadeCFGChange);
    }
  }

  // We're done with transforms, so remove dead instructions.
  if (MadeChange)
    for (BasicBlock &BB : F)
      SimplifyInstructionsInBlock(&BB);

  return MadeChange;
}

/// This is the entry point for all transforms. Pass manager differences are
/// handled in the callers of this function.
static bool runImpl(Function &F, AssumptionCache &AC, TargetTransformInfo &TTI,
                    TargetLibraryInfo &TLI, DominatorTree &DT,
                    AliasAnalysis &AA, bool &MadeCFGChange) {
  bool MadeChange = false;
  const DataLayout &DL = F.getDataLayout();
  TruncInstCombine TIC(AC, TLI, DL, DT);
  MadeChange |= TIC.run(F);
  MadeChange |= foldUnusualPatterns(F, DT, TTI, TLI, AA, AC, MadeCFGChange);
  return MadeChange;
}

PreservedAnalyses AggressiveInstCombinePass::run(Function &F,
                                                 FunctionAnalysisManager &AM) {
  auto &AC = AM.getResult<AssumptionAnalysis>(F);
  auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
  auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  auto &AA = AM.getResult<AAManager>(F);
  bool MadeCFGChange = false;
  if (!runImpl(F, AC, TTI, TLI, DT, AA, MadeCFGChange)) {
    // No changes, all analyses are preserved.
    return PreservedAnalyses::all();
  }
  // Mark all the analyses that instcombine updates as preserved.
  PreservedAnalyses PA;
  if (MadeCFGChange)
    PA.preserve<DominatorTreeAnalysis>();
  else
    PA.preserveSet<CFGAnalyses>();
  return PA;
}