//===- InstCombineShifts.cpp ----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the visitShl, visitLShr, and visitAShr functions.
//
//===----------------------------------------------------------------------===//

#include "InstCombineInternal.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/PatternMatch.h"
using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "instcombine"

// Given pattern:
//   (x shiftopcode Q) shiftopcode K
// we should rewrite it as
//   x shiftopcode (Q+K)  iff (Q+K) u< bitwidth(x)
// This is valid for any shift, but they must be identical.
//
// AnalyzeForSignBitExtraction indicates that we will only analyze whether this
// pattern has any 2 right-shifts that sum to 1 less than original bit width.
// In that analysis-only mode no instructions are created; on success the
// base value `x` is returned instead of a replacement instruction.
Value *InstCombiner::reassociateShiftAmtsOfTwoSameDirectionShifts(
    BinaryOperator *Sh0, const SimplifyQuery &SQ,
    bool AnalyzeForSignBitExtraction) {
  // Look for a shift of some instruction, ignore zext of shift amount if any.
  Instruction *Sh0Op0;
  Value *ShAmt0;
  if (!match(Sh0,
             m_Shift(m_Instruction(Sh0Op0), m_ZExtOrSelf(m_Value(ShAmt0)))))
    return nullptr;

  // If there is a truncation between the two shifts, we must make note of it
  // and look through it. The truncation imposes additional constraints on the
  // transform.
  Instruction *Sh1;
  Value *Trunc = nullptr;
  match(Sh0Op0,
        m_CombineOr(m_CombineAnd(m_Trunc(m_Instruction(Sh1)), m_Value(Trunc)),
                    m_Instruction(Sh1)));

  // Inner shift: (x shiftopcode ShAmt1)
  // Like with other shift, ignore zext of shift amount if any.
  Value *X, *ShAmt1;
  if (!match(Sh1, m_Shift(m_Value(X), m_ZExtOrSelf(m_Value(ShAmt1)))))
    return nullptr;

  // We have two shift amounts from two different shifts. The types of those
  // shift amounts may not match. If that's the case let's bail out now.
  if (ShAmt0->getType() != ShAmt1->getType())
    return nullptr;

  // We are only looking for signbit extraction if we have two right shifts.
  bool HadTwoRightShifts = match(Sh0, m_Shr(m_Value(), m_Value())) &&
                           match(Sh1, m_Shr(m_Value(), m_Value()));
  // ... and if it's not two right-shifts, we know the answer already.
  if (AnalyzeForSignBitExtraction && !HadTwoRightShifts)
    return nullptr;

  // The shift opcodes must be identical, unless we are just checking whether
  // this pattern can be interpreted as a sign-bit-extraction.
  Instruction::BinaryOps ShiftOpcode = Sh0->getOpcode();
  bool IdenticalShOpcodes = Sh0->getOpcode() == Sh1->getOpcode();
  if (!IdenticalShOpcodes && !AnalyzeForSignBitExtraction)
    return nullptr;

  // If we saw truncation, we'll need to produce extra instruction,
  // and for that one of the operands of the shift must be one-use,
  // unless of course we don't actually plan to produce any instructions here.
  if (Trunc && !AnalyzeForSignBitExtraction &&
      !match(Sh0, m_c_BinOp(m_OneUse(m_Value()), m_Value())))
    return nullptr;

  // Can we fold (ShAmt0+ShAmt1) ?
  auto *NewShAmt = dyn_cast_or_null<Constant>(
      SimplifyAddInst(ShAmt0, ShAmt1, /*isNSW=*/false, /*isNUW=*/false,
                      SQ.getWithInstruction(Sh0)));
  if (!NewShAmt)
    return nullptr; // Did not simplify.
  unsigned NewShAmtBitWidth = NewShAmt->getType()->getScalarSizeInBits();
  unsigned XBitWidth = X->getType()->getScalarSizeInBits();
  // Is the new shift amount smaller than the bit width of inner/new shift?
  if (!match(NewShAmt, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_ULT,
                                          APInt(NewShAmtBitWidth, XBitWidth))))
    return nullptr; // FIXME: could perform constant-folding.

  // If there was a truncation, and we have a right-shift, we can only fold if
  // we are left with the original sign bit. Likewise, if we were just checking
  // that this is a signbit extraction, this is the place to check it.
  // FIXME: zero shift amount is also legal here, but we can't *easily* check
  // more than one predicate so it's not really worth it.
  if (HadTwoRightShifts && (Trunc || AnalyzeForSignBitExtraction)) {
    // If it's not a sign bit extraction, then we're done.
    if (!match(NewShAmt,
               m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ,
                                  APInt(NewShAmtBitWidth, XBitWidth - 1))))
      return nullptr;
    // If it is, and that was the question, return the base value.
    if (AnalyzeForSignBitExtraction)
      return X;
  }

  assert(IdenticalShOpcodes && "Should not get here with different shifts.");

  // All good, we can do this fold.
  NewShAmt = ConstantExpr::getZExtOrBitCast(NewShAmt, X->getType());

  BinaryOperator *NewShift = BinaryOperator::Create(ShiftOpcode, X, NewShAmt);

  // The flags can only be propagated if there wasn't a trunc.
  if (!Trunc) {
    // If the pattern did not involve trunc, and both of the original shifts
    // had the same flag set, preserve the flag.
    if (ShiftOpcode == Instruction::BinaryOps::Shl) {
      NewShift->setHasNoUnsignedWrap(Sh0->hasNoUnsignedWrap() &&
                                     Sh1->hasNoUnsignedWrap());
      NewShift->setHasNoSignedWrap(Sh0->hasNoSignedWrap() &&
                                   Sh1->hasNoSignedWrap());
    } else {
      NewShift->setIsExact(Sh0->isExact() && Sh1->isExact());
    }
  }

  Instruction *Ret = NewShift;
  if (Trunc) {
    Builder.Insert(NewShift);
    Ret = CastInst::Create(Instruction::Trunc, NewShift, Sh0->getType());
  }

  return Ret;
}

// If we have some pattern that leaves only some low bits set, and then performs
// left-shift of those bits, if none of the bits that are left after the final
// shift are modified by the mask, we can omit the mask.
//
// There are many variants to this pattern:
//   a)  (x & ((1 << MaskShAmt) - 1)) << ShiftShAmt
//   b)  (x & (~(-1 << MaskShAmt))) << ShiftShAmt
//   c)  (x & (-1 >> MaskShAmt)) << ShiftShAmt
//   d)  (x & ((-1 << MaskShAmt) >> MaskShAmt)) << ShiftShAmt
//   e)  ((x << MaskShAmt) l>> MaskShAmt) << ShiftShAmt
//   f)  ((x << MaskShAmt) a>> MaskShAmt) << ShiftShAmt
// All these patterns can be simplified to just:
//   x << ShiftShAmt
// iff:
//   a,b)     (MaskShAmt+ShiftShAmt) u>= bitwidth(x)
//   c,d,e,f) (ShiftShAmt-MaskShAmt) s>= 0 (i.e. ShiftShAmt u>= MaskShAmt)
static Instruction *
dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift,
                                     const SimplifyQuery &Q,
                                     InstCombiner::BuilderTy &Builder) {
  assert(OuterShift->getOpcode() == Instruction::BinaryOps::Shl &&
         "The input must be 'shl'!");

  Value *Masked, *ShiftShAmt;
  match(OuterShift,
        m_Shift(m_Value(Masked), m_ZExtOrSelf(m_Value(ShiftShAmt))));

  // *If* there is a truncation between an outer shift and a possibly-mask,
  // then said truncation *must* be one-use, else we can't perform the fold.
  Value *Trunc;
  if (match(Masked, m_CombineAnd(m_Trunc(m_Value(Masked)), m_Value(Trunc))) &&
      !Trunc->hasOneUse())
    return nullptr;

  Type *NarrowestTy = OuterShift->getType();
  Type *WidestTy = Masked->getType();
  bool HadTrunc = WidestTy != NarrowestTy;

  // The mask must be computed in a type twice as wide to ensure
  // that no bits are lost if the sum-of-shifts is wider than the base type.
  Type *ExtendedTy = WidestTy->getExtendedType();

  Value *MaskShAmt;

  // ((1 << MaskShAmt) - 1)
  auto MaskA = m_Add(m_Shl(m_One(), m_Value(MaskShAmt)), m_AllOnes());
  // (~(-1 << maskNbits))
  auto MaskB = m_Xor(m_Shl(m_AllOnes(), m_Value(MaskShAmt)), m_AllOnes());
  // (-1 >> MaskShAmt)
  auto MaskC = m_Shr(m_AllOnes(), m_Value(MaskShAmt));
  // ((-1 << MaskShAmt) >> MaskShAmt)
  auto MaskD =
      m_Shr(m_Shl(m_AllOnes(), m_Value(MaskShAmt)), m_Deferred(MaskShAmt));

  Value *X;
  Constant *NewMask;

  if (match(Masked, m_c_And(m_CombineOr(MaskA, MaskB), m_Value(X)))) {
    // Peek through an optional zext of the shift amount.
    match(MaskShAmt, m_ZExtOrSelf(m_Value(MaskShAmt)));

    // We have two shift amounts from two different shifts. The types of those
    // shift amounts may not match. If that's the case let's bail out now.
    if (MaskShAmt->getType() != ShiftShAmt->getType())
      return nullptr;

    // Can we simplify (MaskShAmt+ShiftShAmt) ?
    auto *SumOfShAmts = dyn_cast_or_null<Constant>(SimplifyAddInst(
        MaskShAmt, ShiftShAmt, /*IsNSW=*/false, /*IsNUW=*/false, Q));
    if (!SumOfShAmts)
      return nullptr; // Did not simplify.
    // In this pattern SumOfShAmts correlates with the number of low bits
    // that shall remain in the root value (OuterShift).

    // An extend of an undef value becomes zero because the high bits are never
    // completely unknown. Replace the `undef` shift amounts with final
    // shift bitwidth to ensure that the value remains undef when creating the
    // subsequent shift op.
    SumOfShAmts = Constant::replaceUndefsWith(
        SumOfShAmts, ConstantInt::get(SumOfShAmts->getType()->getScalarType(),
                                      ExtendedTy->getScalarSizeInBits()));
    auto *ExtendedSumOfShAmts = ConstantExpr::getZExt(SumOfShAmts, ExtendedTy);
    // And compute the mask as usual: ~(-1 << (SumOfShAmts))
    auto *ExtendedAllOnes = ConstantExpr::getAllOnesValue(ExtendedTy);
    auto *ExtendedInvertedMask =
        ConstantExpr::getShl(ExtendedAllOnes, ExtendedSumOfShAmts);
    NewMask = ConstantExpr::getNot(ExtendedInvertedMask);
  } else if (match(Masked, m_c_And(m_CombineOr(MaskC, MaskD), m_Value(X))) ||
             match(Masked, m_Shr(m_Shl(m_Value(X), m_Value(MaskShAmt)),
                                 m_Deferred(MaskShAmt)))) {
    // Peek through an optional zext of the shift amount.
    match(MaskShAmt, m_ZExtOrSelf(m_Value(MaskShAmt)));

    // We have two shift amounts from two different shifts. The types of those
    // shift amounts may not match. If that's the case let's bail out now.
    if (MaskShAmt->getType() != ShiftShAmt->getType())
      return nullptr;

    // Can we simplify (ShiftShAmt-MaskShAmt) ?
    auto *ShAmtsDiff = dyn_cast_or_null<Constant>(SimplifySubInst(
        ShiftShAmt, MaskShAmt, /*IsNSW=*/false, /*IsNUW=*/false, Q));
    if (!ShAmtsDiff)
      return nullptr; // Did not simplify.
    // In this pattern ShAmtsDiff correlates with the number of high bits that
    // shall be unset in the root value (OuterShift).

    // An extend of an undef value becomes zero because the high bits are never
    // completely unknown. Replace the `undef` shift amounts with negated
    // bitwidth of innermost shift to ensure that the value remains undef when
    // creating the subsequent shift op.
    unsigned WidestTyBitWidth = WidestTy->getScalarSizeInBits();
    ShAmtsDiff = Constant::replaceUndefsWith(
        ShAmtsDiff, ConstantInt::get(ShAmtsDiff->getType()->getScalarType(),
                                     -WidestTyBitWidth));
    auto *ExtendedNumHighBitsToClear = ConstantExpr::getZExt(
        ConstantExpr::getSub(ConstantInt::get(ShAmtsDiff->getType(),
                                              WidestTyBitWidth,
                                              /*isSigned=*/false),
                             ShAmtsDiff),
        ExtendedTy);
    // And compute the mask as usual: (-1 l>> (NumHighBitsToClear))
    auto *ExtendedAllOnes = ConstantExpr::getAllOnesValue(ExtendedTy);
    NewMask =
        ConstantExpr::getLShr(ExtendedAllOnes, ExtendedNumHighBitsToClear);
  } else
    return nullptr; // Don't know anything about this pattern.

  NewMask = ConstantExpr::getTrunc(NewMask, NarrowestTy);

  // Does this mask have any unset bits? If not then we can just not apply it.
  bool NeedMask = !match(NewMask, m_AllOnes());

  // If we need to apply a mask, there are several more restrictions we have.
  if (NeedMask) {
    // The old masking instruction must go away.
    if (!Masked->hasOneUse())
      return nullptr;
    // The original "masking" instruction must not have been `ashr`.
    if (match(Masked, m_AShr(m_Value(), m_Value())))
      return nullptr;
  }

  // If we need to apply truncation, let's do it first, since we can.
  // We have already ensured that the old truncation will go away.
  if (HadTrunc)
    X = Builder.CreateTrunc(X, NarrowestTy);

  // No 'NUW'/'NSW'! We no longer know that we won't shift-out non-0 bits.
  // We didn't change the Type of this outermost shift, so we can just do it.
  auto *NewShift = BinaryOperator::Create(OuterShift->getOpcode(), X,
                                          OuterShift->getOperand(1));
  if (!NeedMask)
    return NewShift;

  Builder.Insert(NewShift);
  return BinaryOperator::Create(Instruction::And, NewShift, NewMask);
}

/// If we have a shift-by-constant of a bitwise logic op that itself has a
/// shift-by-constant operand with identical opcode, we may be able to convert
/// that into 2 independent shifts followed by the logic op. This eliminates
/// a use of an intermediate value (reduces dependency chain).
static Instruction *foldShiftOfShiftedLogic(BinaryOperator &I,
                                            InstCombiner::BuilderTy &Builder) {
  assert(I.isShift() && "Expected a shift as input");
  auto *LogicInst = dyn_cast<BinaryOperator>(I.getOperand(0));
  if (!LogicInst || !LogicInst->isBitwiseLogicOp() || !LogicInst->hasOneUse())
    return nullptr;

  const APInt *C0, *C1;
  if (!match(I.getOperand(1), m_APInt(C1)))
    return nullptr;

  Instruction::BinaryOps ShiftOpcode = I.getOpcode();
  Type *Ty = I.getType();

  // Find a matching one-use shift by constant. The fold is not valid if the sum
  // of the shift values equals or exceeds bitwidth.
  // TODO: Remove the one-use check if the other logic operand (Y) is constant.
  Value *X, *Y;
  auto matchFirstShift = [&](Value *V) {
    return !isa<ConstantExpr>(V) &&
           match(V, m_OneUse(m_Shift(m_Value(X), m_APInt(C0)))) &&
           cast<BinaryOperator>(V)->getOpcode() == ShiftOpcode &&
           (*C0 + *C1).ult(Ty->getScalarSizeInBits());
  };

  // Logic ops are commutative, so check each operand for a match.
  if (matchFirstShift(LogicInst->getOperand(0)))
    Y = LogicInst->getOperand(1);
  else if (matchFirstShift(LogicInst->getOperand(1)))
    Y = LogicInst->getOperand(0);
  else
    return nullptr;

  // shift (logic (shift X, C0), Y), C1 -> logic (shift X, C0+C1), (shift Y, C1)
  Constant *ShiftSumC = ConstantInt::get(Ty, *C0 + *C1);
  Value *NewShift1 = Builder.CreateBinOp(ShiftOpcode, X, ShiftSumC);
  Value *NewShift2 = Builder.CreateBinOp(ShiftOpcode, Y, I.getOperand(1));
  return BinaryOperator::Create(LogicInst->getOpcode(), NewShift1, NewShift2);
}

/// Transforms that apply to all three shift opcodes (shl/lshr/ashr).
/// Returns a replacement instruction, &I if I was modified in place, or
/// null when no common transform applies.
Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) {
  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
  assert(Op0->getType() == Op1->getType());

  // If the shift amount is a one-use `sext`, we can demote it to `zext`.
  Value *Y;
  if (match(Op1, m_OneUse(m_SExt(m_Value(Y))))) {
    Value *NewExt = Builder.CreateZExt(Y, I.getType(), Op1->getName());
    return BinaryOperator::Create(I.getOpcode(), Op0, NewExt);
  }

  // See if we can fold away this shift.
  if (SimplifyDemandedInstructionBits(I))
    return &I;

  // Try to fold constant and into select arguments.
  if (isa<Constant>(Op0))
    if (SelectInst *SI = dyn_cast<SelectInst>(Op1))
      if (Instruction *R = FoldOpIntoSelect(I, SI))
        return R;

  if (Constant *CUI = dyn_cast<Constant>(Op1))
    if (Instruction *Res = FoldShiftByConstant(Op0, CUI, I))
      return Res;

  if (auto *NewShift = cast_or_null<Instruction>(
          reassociateShiftAmtsOfTwoSameDirectionShifts(&I, SQ)))
    return NewShift;

  // (C1 shift (A add C2)) -> ((C1 shift C2) shift A)
  // iff A and C2 are both positive.
  Value *A;
  Constant *C;
  if (match(Op0, m_Constant()) && match(Op1, m_Add(m_Value(A), m_Constant(C))))
    if (isKnownNonNegative(A, DL, 0, &AC, &I, &DT) &&
        isKnownNonNegative(C, DL, 0, &AC, &I, &DT))
      return BinaryOperator::Create(
          I.getOpcode(), Builder.CreateBinOp(I.getOpcode(), Op0, C), A);

  // X shift (A srem B) -> X shift (A and B-1) iff B is a power of 2.
  // Because shifts by negative values (which could occur if A were negative)
  // are undefined.
  const APInt *B;
  if (Op1->hasOneUse() && match(Op1, m_SRem(m_Value(A), m_Power2(B)))) {
    // FIXME: Should this get moved into SimplifyDemandedBits by saying we don't
    // demand the sign bit (and many others) here??
    Value *Rem = Builder.CreateAnd(A, ConstantInt::get(I.getType(), *B - 1),
                                   Op1->getName());
    I.setOperand(1, Rem);
    return &I;
  }

  if (Instruction *Logic = foldShiftOfShiftedLogic(I, Builder))
    return Logic;

  return nullptr;
}

/// Return true if we can simplify two logical (either left or right) shifts
/// that have constant shift amounts: OuterShift (InnerShift X, C1), C2.
static bool canEvaluateShiftedShift(unsigned OuterShAmt, bool IsOuterShl,
                                    Instruction *InnerShift, InstCombiner &IC,
                                    Instruction *CxtI) {
  assert(InnerShift->isLogicalShift() && "Unexpected instruction type");

  // We need constant scalar or constant splat shifts.
  const APInt *InnerShiftConst;
  if (!match(InnerShift->getOperand(1), m_APInt(InnerShiftConst)))
    return false;

  // Two logical shifts in the same direction:
  // shl (shl X, C1), C2 -->  shl X, C1 + C2
  // lshr (lshr X, C1), C2 --> lshr X, C1 + C2
  bool IsInnerShl = InnerShift->getOpcode() == Instruction::Shl;
  if (IsInnerShl == IsOuterShl)
    return true;

  // Equal shift amounts in opposite directions become bitwise 'and':
  // lshr (shl X, C), C --> and X, C'
  // shl (lshr X, C), C --> and X, C'
  if (*InnerShiftConst == OuterShAmt)
    return true;

  // If the 2nd shift is bigger than the 1st, we can fold:
  // lshr (shl X, C1), C2 -->  and (shl X, C1 - C2), C3
  // shl (lshr X, C1), C2 --> and (lshr X, C1 - C2), C3
  // but it isn't profitable unless we know the and'd out bits are already zero.
  // Also, check that the inner shift is valid (less than the type width) or
  // we'll crash trying to produce the bit mask for the 'and'.
  unsigned TypeWidth = InnerShift->getType()->getScalarSizeInBits();
  if (InnerShiftConst->ugt(OuterShAmt) && InnerShiftConst->ult(TypeWidth)) {
    unsigned InnerShAmt = InnerShiftConst->getZExtValue();
    unsigned MaskShift =
        IsInnerShl ? TypeWidth - InnerShAmt : InnerShAmt - OuterShAmt;
    APInt Mask = APInt::getLowBitsSet(TypeWidth, OuterShAmt) << MaskShift;
    if (IC.MaskedValueIsZero(InnerShift->getOperand(0), Mask, 0, CxtI))
      return true;
  }

  return false;
}

/// See if we can compute the specified value, but shifted logically to the left
/// or right by some number of bits. This should return true if the expression
/// can be computed for the same cost as the current expression tree.
This is 448 /// used to eliminate extraneous shifting from things like: 449 /// %C = shl i128 %A, 64 450 /// %D = shl i128 %B, 96 451 /// %E = or i128 %C, %D 452 /// %F = lshr i128 %E, 64 453 /// where the client will ask if E can be computed shifted right by 64-bits. If 454 /// this succeeds, getShiftedValue() will be called to produce the value. 455 static bool canEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift, 456 InstCombiner &IC, Instruction *CxtI) { 457 // We can always evaluate constants shifted. 458 if (isa<Constant>(V)) 459 return true; 460 461 Instruction *I = dyn_cast<Instruction>(V); 462 if (!I) return false; 463 464 // If this is the opposite shift, we can directly reuse the input of the shift 465 // if the needed bits are already zero in the input. This allows us to reuse 466 // the value which means that we don't care if the shift has multiple uses. 467 // TODO: Handle opposite shift by exact value. 468 ConstantInt *CI = nullptr; 469 if ((IsLeftShift && match(I, m_LShr(m_Value(), m_ConstantInt(CI)))) || 470 (!IsLeftShift && match(I, m_Shl(m_Value(), m_ConstantInt(CI))))) { 471 if (CI->getValue() == NumBits) { 472 // TODO: Check that the input bits are already zero with MaskedValueIsZero 473 #if 0 474 // If this is a truncate of a logical shr, we can truncate it to a smaller 475 // lshr iff we know that the bits we would otherwise be shifting in are 476 // already zeros. 477 uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits(); 478 uint32_t BitWidth = Ty->getScalarSizeInBits(); 479 if (MaskedValueIsZero(I->getOperand(0), 480 APInt::getHighBitsSet(OrigBitWidth, OrigBitWidth-BitWidth)) && 481 CI->getLimitedValue(BitWidth) < BitWidth) { 482 return CanEvaluateTruncated(I->getOperand(0), Ty); 483 } 484 #endif 485 486 } 487 } 488 489 // We can't mutate something that has multiple uses: doing so would 490 // require duplicating the instruction in general, which isn't profitable. 
491 if (!I->hasOneUse()) return false; 492 493 switch (I->getOpcode()) { 494 default: return false; 495 case Instruction::And: 496 case Instruction::Or: 497 case Instruction::Xor: 498 // Bitwise operators can all arbitrarily be arbitrarily evaluated shifted. 499 return canEvaluateShifted(I->getOperand(0), NumBits, IsLeftShift, IC, I) && 500 canEvaluateShifted(I->getOperand(1), NumBits, IsLeftShift, IC, I); 501 502 case Instruction::Shl: 503 case Instruction::LShr: 504 return canEvaluateShiftedShift(NumBits, IsLeftShift, I, IC, CxtI); 505 506 case Instruction::Select: { 507 SelectInst *SI = cast<SelectInst>(I); 508 Value *TrueVal = SI->getTrueValue(); 509 Value *FalseVal = SI->getFalseValue(); 510 return canEvaluateShifted(TrueVal, NumBits, IsLeftShift, IC, SI) && 511 canEvaluateShifted(FalseVal, NumBits, IsLeftShift, IC, SI); 512 } 513 case Instruction::PHI: { 514 // We can change a phi if we can change all operands. Note that we never 515 // get into trouble with cyclic PHIs here because we only consider 516 // instructions with a single use. 517 PHINode *PN = cast<PHINode>(I); 518 for (Value *IncValue : PN->incoming_values()) 519 if (!canEvaluateShifted(IncValue, NumBits, IsLeftShift, IC, PN)) 520 return false; 521 return true; 522 } 523 } 524 } 525 526 /// Fold OuterShift (InnerShift X, C1), C2. 527 /// See canEvaluateShiftedShift() for the constraints on these instructions. 528 static Value *foldShiftedShift(BinaryOperator *InnerShift, unsigned OuterShAmt, 529 bool IsOuterShl, 530 InstCombiner::BuilderTy &Builder) { 531 bool IsInnerShl = InnerShift->getOpcode() == Instruction::Shl; 532 Type *ShType = InnerShift->getType(); 533 unsigned TypeWidth = ShType->getScalarSizeInBits(); 534 535 // We only accept shifts-by-a-constant in canEvaluateShifted(). 536 const APInt *C1; 537 match(InnerShift->getOperand(1), m_APInt(C1)); 538 unsigned InnerShAmt = C1->getZExtValue(); 539 540 // Change the shift amount and clear the appropriate IR flags. 
541 auto NewInnerShift = [&](unsigned ShAmt) { 542 InnerShift->setOperand(1, ConstantInt::get(ShType, ShAmt)); 543 if (IsInnerShl) { 544 InnerShift->setHasNoUnsignedWrap(false); 545 InnerShift->setHasNoSignedWrap(false); 546 } else { 547 InnerShift->setIsExact(false); 548 } 549 return InnerShift; 550 }; 551 552 // Two logical shifts in the same direction: 553 // shl (shl X, C1), C2 --> shl X, C1 + C2 554 // lshr (lshr X, C1), C2 --> lshr X, C1 + C2 555 if (IsInnerShl == IsOuterShl) { 556 // If this is an oversized composite shift, then unsigned shifts get 0. 557 if (InnerShAmt + OuterShAmt >= TypeWidth) 558 return Constant::getNullValue(ShType); 559 560 return NewInnerShift(InnerShAmt + OuterShAmt); 561 } 562 563 // Equal shift amounts in opposite directions become bitwise 'and': 564 // lshr (shl X, C), C --> and X, C' 565 // shl (lshr X, C), C --> and X, C' 566 if (InnerShAmt == OuterShAmt) { 567 APInt Mask = IsInnerShl 568 ? APInt::getLowBitsSet(TypeWidth, TypeWidth - OuterShAmt) 569 : APInt::getHighBitsSet(TypeWidth, TypeWidth - OuterShAmt); 570 Value *And = Builder.CreateAnd(InnerShift->getOperand(0), 571 ConstantInt::get(ShType, Mask)); 572 if (auto *AndI = dyn_cast<Instruction>(And)) { 573 AndI->moveBefore(InnerShift); 574 AndI->takeName(InnerShift); 575 } 576 return And; 577 } 578 579 assert(InnerShAmt > OuterShAmt && 580 "Unexpected opposite direction logical shift pair"); 581 582 // In general, we would need an 'and' for this transform, but 583 // canEvaluateShiftedShift() guarantees that the masked-off bits are not used. 584 // lshr (shl X, C1), C2 --> shl X, C1 - C2 585 // shl (lshr X, C1), C2 --> lshr X, C1 - C2 586 return NewInnerShift(InnerShAmt - OuterShAmt); 587 } 588 589 /// When canEvaluateShifted() returns true for an expression, this function 590 /// inserts the new computation that produces the shifted value. 
static Value *getShiftedValue(Value *V, unsigned NumBits, bool isLeftShift,
                              InstCombiner &IC, const DataLayout &DL) {
  // We can always evaluate constants shifted.
  if (Constant *C = dyn_cast<Constant>(V)) {
    if (isLeftShift)
      V = IC.Builder.CreateShl(C, NumBits);
    else
      V = IC.Builder.CreateLShr(C, NumBits);
    // If we got a constantexpr back, try to simplify it with TD info.
    if (auto *C = dyn_cast<Constant>(V))
      if (auto *FoldedC =
              ConstantFoldConstant(C, DL, &IC.getTargetLibraryInfo()))
        V = FoldedC;
    return V;
  }

  Instruction *I = cast<Instruction>(V);
  IC.Worklist.Add(I);

  switch (I->getOpcode()) {
  default: llvm_unreachable("Inconsistency with CanEvaluateShifted");
  case Instruction::And:
  case Instruction::Or:
  case Instruction::Xor:
    // Bitwise operators can all be evaluated shifted arbitrarily.
    I->setOperand(
        0, getShiftedValue(I->getOperand(0), NumBits, isLeftShift, IC, DL));
    I->setOperand(
        1, getShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL));
    return I;

  case Instruction::Shl:
  case Instruction::LShr:
    return foldShiftedShift(cast<BinaryOperator>(I), NumBits, isLeftShift,
                            IC.Builder);

  case Instruction::Select:
    I->setOperand(
        1, getShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL));
    I->setOperand(
        2, getShiftedValue(I->getOperand(2), NumBits, isLeftShift, IC, DL));
    return I;
  case Instruction::PHI: {
    // We can change a phi if we can change all operands. Note that we never
    // get into trouble with cyclic PHIs here because we only consider
    // instructions with a single use.
    PHINode *PN = cast<PHINode>(I);
    for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
      PN->setIncomingValue(i, getShiftedValue(PN->getIncomingValue(i), NumBits,
                                              isLeftShift, IC, DL));
    return PN;
  }
  }
}

// If this is a bitwise operator or add with a constant RHS we might be able
// to pull it through a shift.
static bool canShiftBinOpWithConstantRHS(BinaryOperator &Shift,
                                         BinaryOperator *BO) {
  switch (BO->getOpcode()) {
  default:
    return false; // Do not perform transform!
  case Instruction::Add:
    // `add` distributes over `shl` only: (x + c) << s == (x << s) + (c << s).
    return Shift.getOpcode() == Instruction::Shl;
  case Instruction::Or:
  case Instruction::Xor:
  case Instruction::And:
    return true;
  }
}

Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
                                               BinaryOperator &I) {
  bool isLeftShift = I.getOpcode() == Instruction::Shl;

  const APInt *Op1C;
  if (!match(Op1, m_APInt(Op1C)))
    return nullptr;

  // See if we can propagate this shift into the input, this covers the trivial
  // cast of lshr(shl(x,c1),c2) as well as other more complex cases.
  if (I.getOpcode() != Instruction::AShr &&
      canEvaluateShifted(Op0, Op1C->getZExtValue(), isLeftShift, *this, &I)) {
    LLVM_DEBUG(
        dbgs() << "ICE: GetShiftedValue propagating shift through expression"
                  " to eliminate shift:\n IN: "
               << *Op0 << "\n SH: " << I << "\n");

    return replaceInstUsesWith(
        I, getShiftedValue(Op0, Op1C->getZExtValue(), isLeftShift, *this, DL));
  }

  // See if we can simplify any instructions used by the instruction whose sole
  // purpose is to compute bits we don't care about.
  unsigned TypeBits = Op0->getType()->getScalarSizeInBits();

  assert(!Op1C->uge(TypeBits) &&
         "Shift over the type width should have been removed already");

  if (Instruction *FoldedShift = foldBinOpIntoSelectOrPhi(I))
    return FoldedShift;

  // Fold shift2(trunc(shift1(x,c1)), c2) -> trunc(shift2(shift1(x,c1),c2))
  if (TruncInst *TI = dyn_cast<TruncInst>(Op0)) {
    Instruction *TrOp = dyn_cast<Instruction>(TI->getOperand(0));
    // If 'shift2' is an ashr, we would have to get the sign bit into a funny
    // place. Don't try to do this transformation in this case. Also, we
    // require that the input operand is a shift-by-constant so that we have
    // confidence that the shifts will get folded together. We could do this
    // xform in more cases, but it is unlikely to be profitable.
    if (TrOp && I.isLogicalShift() && TrOp->isShift() &&
        isa<ConstantInt>(TrOp->getOperand(1))) {
      // Okay, we'll do this xform. Make the shift of shift.
      Constant *ShAmt =
          ConstantExpr::getZExt(cast<Constant>(Op1), TrOp->getType());
      // (shift2 (shift1 & 0x00FF), c2)
      Value *NSh = Builder.CreateBinOp(I.getOpcode(), TrOp, ShAmt, I.getName());

      // For logical shifts, the truncation has the effect of making the high
      // part of the register be zeros. Emulate this by inserting an AND to
      // clear the top bits as needed. This 'and' will usually be zapped by
      // other xforms later if dead.
      unsigned SrcSize = TrOp->getType()->getScalarSizeInBits();
      unsigned DstSize = TI->getType()->getScalarSizeInBits();
      APInt MaskV(APInt::getLowBitsSet(SrcSize, DstSize));

      // The mask we constructed says what the trunc would do if occurring
      // between the shifts. We want to know the effect *after* the second
      // shift. We know that it is a logical shift by a constant, so adjust the
      // mask as appropriate.
      if (I.getOpcode() == Instruction::Shl)
        MaskV <<= Op1C->getZExtValue();
      else {
        assert(I.getOpcode() == Instruction::LShr && "Unknown logical shift");
        MaskV.lshrInPlace(Op1C->getZExtValue());
      }

      // shift1 & 0x00FF
      Value *And = Builder.CreateAnd(NSh,
                                     ConstantInt::get(I.getContext(), MaskV),
                                     TI->getName());

      // Return the value truncated to the interesting size.
      return new TruncInst(And, I.getType());
    }
  }

  if (Op0->hasOneUse()) {
    if (BinaryOperator *Op0BO = dyn_cast<BinaryOperator>(Op0)) {
      // Turn ((X >> C) + Y) << C  ->  (X + (Y << C)) & (~0 << C)
      Value *V1, *V2;
      ConstantInt *CC;
      switch (Op0BO->getOpcode()) {
      default: break;
      case Instruction::Add:
      case Instruction::And:
      case Instruction::Or:
      case Instruction::Xor: {
        // These operators commute.
        // Turn (Y + (X >> C)) << C  ->  (X + (Y << C)) & (~0 << C)
        if (isLeftShift && Op0BO->getOperand(1)->hasOneUse() &&
            match(Op0BO->getOperand(1), m_Shr(m_Value(V1),
                                              m_Specific(Op1)))) {
          Value *YS = // (Y << C)
              Builder.CreateShl(Op0BO->getOperand(0), Op1, Op0BO->getName());
          // (X + (Y << C))
          Value *X = Builder.CreateBinOp(Op0BO->getOpcode(), YS, V1,
                                         Op0BO->getOperand(1)->getName());
          unsigned Op1Val = Op1C->getLimitedValue(TypeBits);

          APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val);
          Constant *Mask = ConstantInt::get(I.getContext(), Bits);
          if (VectorType *VT = dyn_cast<VectorType>(X->getType()))
            Mask = ConstantVector::getSplat(VT->getNumElements(), Mask);
          return BinaryOperator::CreateAnd(X, Mask);
        }

        // Turn (Y + ((X >> C) & CC)) << C  ->  ((X & (CC << C)) + (Y << C))
        Value *Op0BOOp1 = Op0BO->getOperand(1);
        if (isLeftShift && Op0BOOp1->hasOneUse() &&
            match(Op0BOOp1,
                  m_And(m_OneUse(m_Shr(m_Value(V1), m_Specific(Op1))),
                        m_ConstantInt(CC)))) {
          Value *YS = // (Y << C)
              Builder.CreateShl(Op0BO->getOperand(0),
                                Op1, Op0BO->getName());
          // X & (CC << C)
          Value *XM = Builder.CreateAnd(V1, ConstantExpr::getShl(CC, Op1),
                                        V1->getName()+".mask");
          return BinaryOperator::Create(Op0BO->getOpcode(), YS, XM);
        }
        LLVM_FALLTHROUGH;
      }

      case Instruction::Sub: {
        // Turn ((X >> C) + Y) << C  ->  (X + (Y << C)) & (~0 << C)
        if (isLeftShift && Op0BO->getOperand(0)->hasOneUse() &&
            match(Op0BO->getOperand(0), m_Shr(m_Value(V1),
                                              m_Specific(Op1)))) {
          Value *YS = // (Y << C)
              Builder.CreateShl(Op0BO->getOperand(1), Op1, Op0BO->getName());
          // (X + (Y << C))
          Value *X = Builder.CreateBinOp(Op0BO->getOpcode(), V1, YS,
                                         Op0BO->getOperand(0)->getName());
          unsigned Op1Val = Op1C->getLimitedValue(TypeBits);

          APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val);
          Constant *Mask = ConstantInt::get(I.getContext(), Bits);
          if (VectorType *VT = dyn_cast<VectorType>(X->getType()))
            Mask = ConstantVector::getSplat(VT->getNumElements(), Mask);
          return BinaryOperator::CreateAnd(X, Mask);
        }

        // Turn (((X >> C)&CC) + Y) << C  ->  (X + (Y << C)) & (CC << C)
        if (isLeftShift && Op0BO->getOperand(0)->hasOneUse() &&
            match(Op0BO->getOperand(0),
                  m_And(m_OneUse(m_Shr(m_Value(V1), m_Value(V2))),
                        m_ConstantInt(CC))) && V2 == Op1) {
          Value *YS = // (Y << C)
              Builder.CreateShl(Op0BO->getOperand(1), Op1, Op0BO->getName());
          // X & (CC << C)
          Value *XM = Builder.CreateAnd(V1, ConstantExpr::getShl(CC, Op1),
                                        V1->getName()+".mask");

          return BinaryOperator::Create(Op0BO->getOpcode(), XM, YS);
        }

        break;
      }
      }


      // If the operand is a bitwise operator with a constant RHS, and the
      // shift is the only use, we can pull it out of the shift.
      const APInt *Op0C;
      if (match(Op0BO->getOperand(1), m_APInt(Op0C))) {
        if (canShiftBinOpWithConstantRHS(I, Op0BO)) {
          Constant *NewRHS = ConstantExpr::get(I.getOpcode(),
                                     cast<Constant>(Op0BO->getOperand(1)), Op1);

          Value *NewShift =
              Builder.CreateBinOp(I.getOpcode(), Op0BO->getOperand(0), Op1);
          NewShift->takeName(Op0BO);

          return BinaryOperator::Create(Op0BO->getOpcode(), NewShift,
                                        NewRHS);
        }
      }

      // If the operand is a subtract with a constant LHS, and the shift
      // is the only use, we can pull it out of the shift.
      // This folds (shl (sub C1, X), C2) -> (sub (C1 << C2), (shl X, C2))
      if (isLeftShift && Op0BO->getOpcode() == Instruction::Sub &&
          match(Op0BO->getOperand(0), m_APInt(Op0C))) {
        Constant *NewRHS = ConstantExpr::get(I.getOpcode(),
                                     cast<Constant>(Op0BO->getOperand(0)), Op1);

        Value *NewShift = Builder.CreateShl(Op0BO->getOperand(1), Op1);
        NewShift->takeName(Op0BO);

        return BinaryOperator::CreateSub(NewRHS, NewShift);
      }
    }

    // If we have a select that conditionally executes some binary operator,
    // see if we can pull the select and operator through the shift.
856 // 857 // For example, turning: 858 // shl (select C, (add X, C1), X), C2 859 // Into: 860 // Y = shl X, C2 861 // select C, (add Y, C1 << C2), Y 862 Value *Cond; 863 BinaryOperator *TBO; 864 Value *FalseVal; 865 if (match(Op0, m_Select(m_Value(Cond), m_OneUse(m_BinOp(TBO)), 866 m_Value(FalseVal)))) { 867 const APInt *C; 868 if (!isa<Constant>(FalseVal) && TBO->getOperand(0) == FalseVal && 869 match(TBO->getOperand(1), m_APInt(C)) && 870 canShiftBinOpWithConstantRHS(I, TBO)) { 871 Constant *NewRHS = ConstantExpr::get(I.getOpcode(), 872 cast<Constant>(TBO->getOperand(1)), Op1); 873 874 Value *NewShift = 875 Builder.CreateBinOp(I.getOpcode(), FalseVal, Op1); 876 Value *NewOp = Builder.CreateBinOp(TBO->getOpcode(), NewShift, 877 NewRHS); 878 return SelectInst::Create(Cond, NewOp, NewShift); 879 } 880 } 881 882 BinaryOperator *FBO; 883 Value *TrueVal; 884 if (match(Op0, m_Select(m_Value(Cond), m_Value(TrueVal), 885 m_OneUse(m_BinOp(FBO))))) { 886 const APInt *C; 887 if (!isa<Constant>(TrueVal) && FBO->getOperand(0) == TrueVal && 888 match(FBO->getOperand(1), m_APInt(C)) && 889 canShiftBinOpWithConstantRHS(I, FBO)) { 890 Constant *NewRHS = ConstantExpr::get(I.getOpcode(), 891 cast<Constant>(FBO->getOperand(1)), Op1); 892 893 Value *NewShift = 894 Builder.CreateBinOp(I.getOpcode(), TrueVal, Op1); 895 Value *NewOp = Builder.CreateBinOp(FBO->getOpcode(), NewShift, 896 NewRHS); 897 return SelectInst::Create(Cond, NewShift, NewOp); 898 } 899 } 900 } 901 902 return nullptr; 903 } 904 905 Instruction *InstCombiner::visitShl(BinaryOperator &I) { 906 const SimplifyQuery Q = SQ.getWithInstruction(&I); 907 908 if (Value *V = SimplifyShlInst(I.getOperand(0), I.getOperand(1), 909 I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), Q)) 910 return replaceInstUsesWith(I, V); 911 912 if (Instruction *X = foldVectorBinop(I)) 913 return X; 914 915 if (Instruction *V = commonShiftTransforms(I)) 916 return V; 917 918 if (Instruction *V = dropRedundantMaskingOfLeftShiftInput(&I, Q, Builder)) 919 
    return V;

  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
  Type *Ty = I.getType();
  unsigned BitWidth = Ty->getScalarSizeInBits();

  // The folds in this section require a constant shift amount.
  const APInt *ShAmtAPInt;
  if (match(Op1, m_APInt(ShAmtAPInt))) {
    unsigned ShAmt = ShAmtAPInt->getZExtValue();

    // shl (zext X), ShAmt --> zext (shl X, ShAmt)
    // This is only valid if X would have zeros shifted out.
    Value *X;
    if (match(Op0, m_OneUse(m_ZExt(m_Value(X))))) {
      unsigned SrcWidth = X->getType()->getScalarSizeInBits();
      if (ShAmt < SrcWidth &&
          MaskedValueIsZero(X, APInt::getHighBitsSet(SrcWidth, ShAmt), 0, &I))
        return new ZExtInst(Builder.CreateShl(X, ShAmt), Ty);
    }

    // (X >> C) << C --> X & (-1 << C)
    if (match(Op0, m_Shr(m_Value(X), m_Specific(Op1)))) {
      APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmt));
      return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask));
    }

    // FIXME: we do not yet transform non-exact shr's. The backend (DAGCombine)
    // needs a few fixes for the rotate pattern recognition first.
    const APInt *ShOp1;
    if (match(Op0, m_Exact(m_Shr(m_Value(X), m_APInt(ShOp1))))) {
      unsigned ShrAmt = ShOp1->getZExtValue();
      if (ShrAmt < ShAmt) {
        // If C1 < C2: (X >>?,exact C1) << C2 --> X << (C2 - C1)
        Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShrAmt);
        auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff);
        // The new shl shifts out a subset of what the original did, so the
        // original wrap flags still hold.
        NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap());
        NewShl->setHasNoSignedWrap(I.hasNoSignedWrap());
        return NewShl;
      }
      if (ShrAmt > ShAmt) {
        // If C1 > C2: (X >>?exact C1) << C2 --> X >>?exact (C1 - C2)
        Constant *ShiftDiff = ConstantInt::get(Ty, ShrAmt - ShAmt);
        auto *NewShr = BinaryOperator::Create(
            cast<BinaryOperator>(Op0)->getOpcode(), X, ShiftDiff);
        NewShr->setIsExact(true);
        return NewShr;
      }
    }

    if (match(Op0, m_Shl(m_Value(X), m_APInt(ShOp1)))) {
      unsigned AmtSum = ShAmt + ShOp1->getZExtValue();
      // Oversized shifts are simplified to zero in InstSimplify.
      if (AmtSum < BitWidth)
        // (X << C1) << C2 --> X << (C1 + C2)
        return BinaryOperator::CreateShl(X, ConstantInt::get(Ty, AmtSum));
    }

    // If the shifted-out value is known-zero, then this is a NUW shift.
    if (!I.hasNoUnsignedWrap() &&
        MaskedValueIsZero(Op0, APInt::getHighBitsSet(BitWidth, ShAmt), 0, &I)) {
      I.setHasNoUnsignedWrap();
      return &I;
    }

    // If the shifted-out value is all signbits, then this is a NSW shift.
    if (!I.hasNoSignedWrap() && ComputeNumSignBits(Op0, 0, &I) > ShAmt) {
      I.setHasNoSignedWrap();
      return &I;
    }
  }

  // Transform (x >> y) << y to x & (-1 << y)
  // Valid for any type of right-shift.
  Value *X;
  if (match(Op0, m_OneUse(m_Shr(m_Value(X), m_Specific(Op1))))) {
    Constant *AllOnes = ConstantInt::getAllOnesValue(Ty);
    Value *Mask = Builder.CreateShl(AllOnes, Op1);
    return BinaryOperator::CreateAnd(Mask, X);
  }

  Constant *C1;
  if (match(Op1, m_Constant(C1))) {
    Constant *C2;
    Value *X;
    // (C2 << X) << C1 --> (C2 << C1) << X
    if (match(Op0, m_OneUse(m_Shl(m_Constant(C2), m_Value(X)))))
      return BinaryOperator::CreateShl(ConstantExpr::getShl(C2, C1), X);

    // (X * C2) << C1 --> X * (C2 << C1)
    if (match(Op0, m_Mul(m_Value(X), m_Constant(C2))))
      return BinaryOperator::CreateMul(X, ConstantExpr::getShl(C2, C1));

    // shl (zext i1 X), C1 --> select (X, 1 << C1, 0)
    if (match(Op0, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) {
      auto *NewC = ConstantExpr::getShl(ConstantInt::get(Ty, 1), C1);
      return SelectInst::Create(X, NewC, ConstantInt::getNullValue(Ty));
    }
  }

  // (1 << (C - x)) -> ((1 << C) >> x) if C is bitwidth - 1
  if (match(Op0, m_One()) &&
      match(Op1, m_Sub(m_SpecificInt(BitWidth - 1), m_Value(X))))
    return BinaryOperator::CreateLShr(
        ConstantInt::get(Ty, APInt::getSignMask(BitWidth)), X);

  return nullptr;
}

// Combine a lshr instruction.  Generic simplification and the transforms
// shared by all three shift opcodes run first; the rest handles
// lshr-specific patterns.
Instruction *InstCombiner::visitLShr(BinaryOperator &I) {
  if (Value *V = SimplifyLShrInst(I.getOperand(0), I.getOperand(1), I.isExact(),
                                  SQ.getWithInstruction(&I)))
    return replaceInstUsesWith(I, V);

  if (Instruction *X = foldVectorBinop(I))
    return X;

  if (Instruction *R = commonShiftTransforms(I))
    return R;

  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
  Type *Ty = I.getType();
  const APInt *ShAmtAPInt;
  if (match(Op1, m_APInt(ShAmtAPInt))) {
    unsigned ShAmt = ShAmtAPInt->getZExtValue();
    unsigned BitWidth = Ty->getScalarSizeInBits();
    auto *II = dyn_cast<IntrinsicInst>(Op0);
    if (II &&
        isPowerOf2_32(BitWidth) && Log2_32(BitWidth) == ShAmt &&
        (II->getIntrinsicID() == Intrinsic::ctlz ||
         II->getIntrinsicID() == Intrinsic::cttz ||
         II->getIntrinsicID() == Intrinsic::ctpop)) {
      // Shifting the bit-count result right by log2(BitWidth) leaves a 1 only
      // when the count equals BitWidth itself:
      // ctlz.i32(x)>>5  --> zext(x == 0)
      // cttz.i32(x)>>5  --> zext(x == 0)
      // ctpop.i32(x)>>5 --> zext(x == -1)
      bool IsPop = II->getIntrinsicID() == Intrinsic::ctpop;
      Constant *RHS = ConstantInt::getSigned(Ty, IsPop ? -1 : 0);
      Value *Cmp = Builder.CreateICmpEQ(II->getArgOperand(0), RHS);
      return new ZExtInst(Cmp, Ty);
    }

    Value *X;
    const APInt *ShOp1;
    if (match(Op0, m_Shl(m_Value(X), m_APInt(ShOp1))) && ShOp1->ult(BitWidth)) {
      if (ShOp1->ult(ShAmt)) {
        unsigned ShlAmt = ShOp1->getZExtValue();
        Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShlAmt);
        if (cast<BinaryOperator>(Op0)->hasNoUnsignedWrap()) {
          // (X <<nuw C1) >>u C2 --> X >>u (C2 - C1)
          auto *NewLShr = BinaryOperator::CreateLShr(X, ShiftDiff);
          NewLShr->setIsExact(I.isExact());
          return NewLShr;
        }
        // (X << C1) >>u C2 --> (X >>u (C2 - C1)) & (-1 >> C2)
        Value *NewLShr = Builder.CreateLShr(X, ShiftDiff, "", I.isExact());
        APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt));
        return BinaryOperator::CreateAnd(NewLShr, ConstantInt::get(Ty, Mask));
      }
      if (ShOp1->ugt(ShAmt)) {
        unsigned ShlAmt = ShOp1->getZExtValue();
        Constant *ShiftDiff = ConstantInt::get(Ty, ShlAmt - ShAmt);
        if (cast<BinaryOperator>(Op0)->hasNoUnsignedWrap()) {
          // (X <<nuw C1) >>u C2 --> X <<nuw (C1 - C2)
          auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff);
          NewShl->setHasNoUnsignedWrap(true);
          return NewShl;
        }
        // (X << C1) >>u C2 --> X << (C1 - C2) & (-1 >> C2)
        Value *NewShl = Builder.CreateShl(X, ShiftDiff);
        APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt));
        return BinaryOperator::CreateAnd(NewShl, ConstantInt::get(Ty, Mask));
      }
      // Remaining case: equal shift amounts.
      assert(*ShOp1 == ShAmt);
      // (X << C) >>u C --> X & (-1 >>u C)
      APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt));
      return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask));
    }

    if (match(Op0, m_OneUse(m_ZExt(m_Value(X)))) &&
        (!Ty->isIntegerTy() || shouldChangeType(Ty, X->getType()))) {
      assert(ShAmt < X->getType()->getScalarSizeInBits() &&
             "Big shift not simplified to zero?");
      // lshr (zext iM X to iN), C --> zext (lshr X, C) to iN
      Value *NewLShr = Builder.CreateLShr(X, ShAmt);
      return new ZExtInst(NewLShr, Ty);
    }

    if (match(Op0, m_SExt(m_Value(X))) &&
        (!Ty->isIntegerTy() || shouldChangeType(Ty, X->getType()))) {
      // Are we moving the sign bit to the low bit and widening with high zeros?
      unsigned SrcTyBitWidth = X->getType()->getScalarSizeInBits();
      if (ShAmt == BitWidth - 1) {
        // lshr (sext i1 X to iN), N-1 --> zext X to iN
        if (SrcTyBitWidth == 1)
          return new ZExtInst(X, Ty);

        // lshr (sext iM X to iN), N-1 --> zext (lshr X, M-1) to iN
        if (Op0->hasOneUse()) {
          Value *NewLShr = Builder.CreateLShr(X, SrcTyBitWidth - 1);
          return new ZExtInst(NewLShr, Ty);
        }
      }

      // lshr (sext iM X to iN), N-M --> zext (ashr X, min(N-M, M-1)) to iN
      if (ShAmt == BitWidth - SrcTyBitWidth && Op0->hasOneUse()) {
        // The new shift amount can't be more than the narrow source type.
        unsigned NewShAmt = std::min(ShAmt, SrcTyBitWidth - 1);
        Value *AShr = Builder.CreateAShr(X, NewShAmt);
        return new ZExtInst(AShr, Ty);
      }
    }

    if (match(Op0, m_LShr(m_Value(X), m_APInt(ShOp1)))) {
      unsigned AmtSum = ShAmt + ShOp1->getZExtValue();
      // Oversized shifts are simplified to zero in InstSimplify.
      if (AmtSum < BitWidth)
        // (X >>u C1) >>u C2 --> X >>u (C1 + C2)
        return BinaryOperator::CreateLShr(X, ConstantInt::get(Ty, AmtSum));
    }

    // If the shifted-out value is known-zero, then this is an exact shift.
    if (!I.isExact() &&
        MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmt), 0, &I)) {
      I.setIsExact();
      return &I;
    }
  }

  // Transform (x << y) >> y to x & (-1 >> y)
  Value *X;
  if (match(Op0, m_OneUse(m_Shl(m_Value(X), m_Specific(Op1))))) {
    Constant *AllOnes = ConstantInt::getAllOnesValue(Ty);
    Value *Mask = Builder.CreateLShr(AllOnes, Op1);
    return BinaryOperator::CreateAnd(Mask, X);
  }

  return nullptr;
}

// Recognize a variable-width "extract high bits, then sign/zero-extend"
// sandwich around an ashr and bypass the redundant outer shift pair.
// (Continued below.)
Instruction *
InstCombiner::foldVariableSignZeroExtensionOfVariableHighBitExtract(
    BinaryOperator &OldAShr) {
  assert(OldAShr.getOpcode() == Instruction::AShr &&
         "Must be called with arithmetic right-shift instruction only.");

  // Check that constant C is a splat of the element-wise bitwidth of V.
  auto BitWidthSplat = [](Constant *C, Value *V) {
    return match(
        C, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ,
                              APInt(C->getType()->getScalarSizeInBits(),
                                    V->getType()->getScalarSizeInBits())));
  };

  // It should look like variable-length sign-extension on the outside:
  //   (Val << (bitwidth(Val)-Nbits)) a>> (bitwidth(Val)-Nbits)
  Value *NBits;
  Instruction *MaybeTrunc;
  Constant *C1, *C2;
  if (!match(&OldAShr,
             m_AShr(m_Shl(m_Instruction(MaybeTrunc),
                          m_ZExtOrSelf(m_Sub(m_Constant(C1),
                                             m_ZExtOrSelf(m_Value(NBits))))),
                    m_ZExtOrSelf(m_Sub(m_Constant(C2),
                                       m_ZExtOrSelf(m_Deferred(NBits)))))) ||
      !BitWidthSplat(C1, &OldAShr) || !BitWidthSplat(C2, &OldAShr))
    return nullptr;

  // There may or may not be a truncation after outer two shifts.
  // There may or may not be a truncation between the outer shifts and the
  // inner high-bit extraction; look through it if present.
  Instruction *HighBitExtract;
  match(MaybeTrunc, m_TruncOrSelf(m_Instruction(HighBitExtract)));
  bool HadTrunc = MaybeTrunc != HighBitExtract;

  // And finally, the innermost part of the pattern must be a right-shift.
  Value *X, *NumLowBitsToSkip;
  if (!match(HighBitExtract, m_Shr(m_Value(X), m_Value(NumLowBitsToSkip))))
    return nullptr;

  // Said right-shift must extract high NBits bits - C0 must be its bitwidth.
  Constant *C0;
  if (!match(NumLowBitsToSkip,
             m_ZExtOrSelf(
                 m_Sub(m_Constant(C0), m_ZExtOrSelf(m_Specific(NBits))))) ||
      !BitWidthSplat(C0, HighBitExtract))
    return nullptr;

  // Since the NBits is identical for all shifts, if the outermost and
  // innermost shifts are identical, then outermost shifts are redundant.
  // If we had truncation, do keep it though.
  if (HighBitExtract->getOpcode() == OldAShr.getOpcode())
    return replaceInstUsesWith(OldAShr, MaybeTrunc);

  // Else, if there was a truncation, then we need to ensure that one
  // instruction will go away.
  if (HadTrunc && !match(&OldAShr, m_c_BinOp(m_OneUse(m_Value()), m_Value())))
    return nullptr;

  // Finally, bypass two innermost shifts, and perform the outermost shift on
  // the operands of the innermost shift.
  Instruction *NewAShr =
      BinaryOperator::Create(OldAShr.getOpcode(), X, NumLowBitsToSkip);
  NewAShr->copyIRFlags(HighBitExtract); // We can preserve 'exact'-ness.
  // Without a truncation the new shift replaces OldAShr directly.
  if (!HadTrunc)
    return NewAShr;

  // With a truncation we emit the shift ourselves and return the trunc as
  // the replacement instruction.
  Builder.Insert(NewAShr);
  return TruncInst::CreateTruncOrBitCast(NewAShr, OldAShr.getType());
}

// Combine an ashr instruction.  Generic simplification and the transforms
// shared by all three shift opcodes run first; the remainder (continued
// below) handles ashr-specific patterns.
Instruction *InstCombiner::visitAShr(BinaryOperator &I) {
  if (Value *V = SimplifyAShrInst(I.getOperand(0), I.getOperand(1), I.isExact(),
                                  SQ.getWithInstruction(&I)))
    return replaceInstUsesWith(I, V);

  if (Instruction *X = foldVectorBinop(I))
    return X;

  if (Instruction *R = commonShiftTransforms(I))
    return R;

  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
  Type *Ty = I.getType();
  unsigned BitWidth = Ty->getScalarSizeInBits();
  const APInt *ShAmtAPInt;
  if (match(Op1, m_APInt(ShAmtAPInt)) && ShAmtAPInt->ult(BitWidth)) {
    unsigned ShAmt = ShAmtAPInt->getZExtValue();

    // If the shift amount equals the difference in width of the destination
    // and source scalar types:
    // ashr (shl (zext X), C), C --> sext X
    Value *X;
    if (match(Op0, m_Shl(m_ZExt(m_Value(X)), m_Specific(Op1))) &&
        ShAmt == BitWidth - X->getType()->getScalarSizeInBits())
      return new SExtInst(X, Ty);

    // We can't handle (X << C1) >>s C2. It shifts arbitrary bits in. However,
    // we can handle (X <<nsw C1) >>s C2 since it only shifts in sign bits.
    const APInt *ShOp1;
    if (match(Op0, m_NSWShl(m_Value(X), m_APInt(ShOp1))) &&
        ShOp1->ult(BitWidth)) {
      unsigned ShlAmt = ShOp1->getZExtValue();
      if (ShlAmt < ShAmt) {
        // (X <<nsw C1) >>s C2 --> X >>s (C2 - C1)
        Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShlAmt);
        auto *NewAShr = BinaryOperator::CreateAShr(X, ShiftDiff);
        NewAShr->setIsExact(I.isExact());
        return NewAShr;
      }
      if (ShlAmt > ShAmt) {
        // (X <<nsw C1) >>s C2 --> X <<nsw (C1 - C2)
        Constant *ShiftDiff = ConstantInt::get(Ty, ShlAmt - ShAmt);
        auto *NewShl = BinaryOperator::Create(Instruction::Shl, X, ShiftDiff);
        NewShl->setHasNoSignedWrap(true);
        return NewShl;
      }
      // ShlAmt == ShAmt is not handled here; it falls through to the other
      // folds below.
    }

    if (match(Op0, m_AShr(m_Value(X), m_APInt(ShOp1))) &&
        ShOp1->ult(BitWidth)) {
      unsigned AmtSum = ShAmt + ShOp1->getZExtValue();
      // Oversized arithmetic shifts replicate the sign bit.
      AmtSum = std::min(AmtSum, BitWidth - 1);
      // (X >>s C1) >>s C2 --> X >>s (C1 + C2)
      return BinaryOperator::CreateAShr(X, ConstantInt::get(Ty, AmtSum));
    }

    if (match(Op0, m_OneUse(m_SExt(m_Value(X)))) &&
        (Ty->isVectorTy() || shouldChangeType(Ty, X->getType()))) {
      // ashr (sext X), C --> sext (ashr X, C')
      // Clamp the shift amount so it stays in range for the narrow source.
      Type *SrcTy = X->getType();
      ShAmt = std::min(ShAmt, SrcTy->getScalarSizeInBits() - 1);
      Value *NewSh = Builder.CreateAShr(X, ConstantInt::get(SrcTy, ShAmt));
      return new SExtInst(NewSh, Ty);
    }

    // If the shifted-out value is known-zero, then this is an exact shift.
    if (!I.isExact() &&
        MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmt), 0, &I)) {
      I.setIsExact();
      return &I;
    }
  }

  if (Instruction *R = foldVariableSignZeroExtensionOfVariableHighBitExtract(I))
    return R;

  // See if we can turn a signed shr into an unsigned shr: if the sign bit is
  // known zero, ashr and lshr agree.
  if (MaskedValueIsZero(Op0, APInt::getSignMask(BitWidth), 0, &I))
    return BinaryOperator::CreateLShr(Op0, Op1);

  return nullptr;
}