//===- ValueTracking.cpp - Walk computations to compute properties --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains routines that help analyze properties that chains of
// computations have.
//
//===----------------------------------------------------------------------===//

#include "llvm/Analysis/ValueTracking.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/AssumeBundleQueries.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/DomConditionCache.h"
#include "llvm/Analysis/GuardUtils.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/Analysis/WithCache.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/EHPersonalities.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/GlobalAlias.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsAArch64.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/IntrinsicsRISCV.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/TargetParser/RISCVTargetParser.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>
#include <utility>

using namespace llvm;
using namespace llvm::PatternMatch;

// Controls the number of uses of the value searched for possible
// dominating comparisons.
static cl::opt<unsigned> DomConditionsMaxUses("dom-conditions-max-uses",
                                              cl::Hidden, cl::init(20));


/// Returns the bitwidth of the given scalar or pointer type. For vector types,
/// returns the element type's bitwidth.
static unsigned getBitWidth(Type *Ty, const DataLayout &DL) {
  if (unsigned BitWidth = Ty->getScalarSizeInBits())
    return BitWidth;

  return DL.getPointerTypeSizeInBits(Ty);
}

// Given the provided Value and, potentially, a context instruction, return
// the preferred context instruction (if any).
static const Instruction *safeCxtI(const Value *V, const Instruction *CxtI) {
  // If we've been provided with a context instruction, then use that (provided
  // it has been inserted).
  if (CxtI && CxtI->getParent())
    return CxtI;

  // If the value is really an already-inserted instruction, then use that.
  CxtI = dyn_cast<Instruction>(V);
  if (CxtI && CxtI->getParent())
    return CxtI;

  return nullptr;
}

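// Variant of the above for queries involving two values: prefer the explicitly
// provided context instruction; otherwise fall back to whichever of V1 or V2
// is an already-inserted instruction.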
static const Instruction *safeCxtI(const Value *V1, const Value *V2,
                                   const Instruction *CxtI) {
  // If we've been provided with a context instruction, then use that (provided
  // it has been inserted).
  if (CxtI && CxtI->getParent())
    return CxtI;

  // If the value is really an already-inserted instruction, then use that.
  CxtI = dyn_cast<Instruction>(V1);
  if (CxtI && CxtI->getParent())
    return CxtI;

  CxtI = dyn_cast<Instruction>(V2);
  if (CxtI && CxtI->getParent())
    return CxtI;

  return nullptr;
}

static bool getShuffleDemandedElts(const ShuffleVectorInst *Shuf,
                                   const APInt &DemandedElts,
                                   APInt &DemandedLHS, APInt &DemandedRHS) {
  if (isa<ScalableVectorType>(Shuf->getType())) {
    assert(DemandedElts == APInt(1, 1));
    DemandedLHS = DemandedRHS = DemandedElts;
    return true;
  }

  int NumElts =
      cast<FixedVectorType>(Shuf->getOperand(0)->getType())->getNumElements();
  return llvm::getShuffleDemandedElts(NumElts, Shuf->getShuffleMask(),
                                      DemandedElts, DemandedLHS, DemandedRHS);
}

static void computeKnownBits(const Value *V, const APInt &DemandedElts,
                             KnownBits &Known, unsigned Depth,
                             const SimplifyQuery &Q);

void llvm::computeKnownBits(const Value *V, KnownBits &Known, unsigned Depth,
                            const SimplifyQuery &Q) {
  // Since the number of lanes in a scalable vector is unknown at compile time,
  // we track one bit which is implicitly broadcast to all lanes.  This means
  // that all lanes in a scalable vector are considered demanded.
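  // For example, a <vscale x 4 x i32> value is queried with a single demanded
  // element (APInt(1, 1)) rather than one bit per lane.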
  auto *FVTy = dyn_cast<FixedVectorType>(V->getType());
  APInt DemandedElts =
      FVTy ? APInt::getAllOnes(FVTy->getNumElements()) : APInt(1, 1);
  ::computeKnownBits(V, DemandedElts, Known, Depth, Q);
}

void llvm::computeKnownBits(const Value *V, KnownBits &Known,
                            const DataLayout &DL, unsigned Depth,
                            AssumptionCache *AC, const Instruction *CxtI,
                            const DominatorTree *DT, bool UseInstrInfo) {
  computeKnownBits(
      V, Known, Depth,
      SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo));
}

KnownBits llvm::computeKnownBits(const Value *V, const DataLayout &DL,
                                 unsigned Depth, AssumptionCache *AC,
                                 const Instruction *CxtI,
                                 const DominatorTree *DT, bool UseInstrInfo) {
  return computeKnownBits(
      V, Depth, SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo));
}

KnownBits llvm::computeKnownBits(const Value *V, const APInt &DemandedElts,
                                 const DataLayout &DL, unsigned Depth,
                                 AssumptionCache *AC, const Instruction *CxtI,
                                 const DominatorTree *DT, bool UseInstrInfo) {
  return computeKnownBits(
      V, DemandedElts, Depth,
      SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo));
}

static bool haveNoCommonBitsSetSpecialCases(const Value *LHS, const Value *RHS,
                                            const SimplifyQuery &SQ) {
  // Look for an inverted mask: (X & ~M) op (Y & M).
  {
    Value *M;
    if (match(LHS, m_c_And(m_Not(m_Value(M)), m_Value())) &&
        match(RHS, m_c_And(m_Specific(M), m_Value())) &&
        isGuaranteedNotToBeUndef(M, SQ.AC, SQ.CxtI, SQ.DT))
      return true;
  }

  // X op (Y & ~X)
  if (match(RHS, m_c_And(m_Not(m_Specific(LHS)), m_Value())) &&
      isGuaranteedNotToBeUndef(LHS, SQ.AC, SQ.CxtI, SQ.DT))
    return true;

  // X op ((X & Y) ^ Y) -- this is the canonical form of the previous pattern
  // for constant Y.
  Value *Y;
  if (match(RHS,
            m_c_Xor(m_c_And(m_Specific(LHS), m_Value(Y)), m_Deferred(Y))) &&
      isGuaranteedNotToBeUndef(LHS, SQ.AC, SQ.CxtI, SQ.DT) &&
      isGuaranteedNotToBeUndef(Y, SQ.AC, SQ.CxtI, SQ.DT))
    return true;

  // Peek through extends to find a 'not' of the other side:
  // (ext Y) op ext(~Y)
  if (match(LHS, m_ZExtOrSExt(m_Value(Y))) &&
      match(RHS, m_ZExtOrSExt(m_Not(m_Specific(Y)))) &&
      isGuaranteedNotToBeUndef(Y, SQ.AC, SQ.CxtI, SQ.DT))
    return true;

  // Look for: (A & B) op ~(A | B)
  {
    Value *A, *B;
    if (match(LHS, m_And(m_Value(A), m_Value(B))) &&
        match(RHS, m_Not(m_c_Or(m_Specific(A), m_Specific(B)))) &&
        isGuaranteedNotToBeUndef(A, SQ.AC, SQ.CxtI, SQ.DT) &&
        isGuaranteedNotToBeUndef(B, SQ.AC, SQ.CxtI, SQ.DT))
      return true;
  }

  // Look for: (X << V) op (Y >> (BitWidth - V))
  // or        (X >> V) op (Y << (BitWidth - V))
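  // In either form the left-shift clears exactly the low bits that the
  // logical right-shift can still populate (and vice versa), so for in-range
  // shift amounts the two sides cannot share a set bit.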
  {
    const Value *V;
    const APInt *R;
    if (((match(RHS, m_Shl(m_Value(), m_Sub(m_APInt(R), m_Value(V)))) &&
          match(LHS, m_LShr(m_Value(), m_Specific(V)))) ||
         (match(RHS, m_LShr(m_Value(), m_Sub(m_APInt(R), m_Value(V)))) &&
          match(LHS, m_Shl(m_Value(), m_Specific(V))))) &&
        R->uge(LHS->getType()->getScalarSizeInBits()))
      return true;
  }

  return false;
}

bool llvm::haveNoCommonBitsSet(const WithCache<const Value *> &LHSCache,
                               const WithCache<const Value *> &RHSCache,
                               const SimplifyQuery &SQ) {
  const Value *LHS = LHSCache.getValue();
  const Value *RHS = RHSCache.getValue();

  assert(LHS->getType() == RHS->getType() &&
         "LHS and RHS should have the same type");
  assert(LHS->getType()->isIntOrIntVectorTy() &&
         "LHS and RHS should be integers");

  if (haveNoCommonBitsSetSpecialCases(LHS, RHS, SQ) ||
      haveNoCommonBitsSetSpecialCases(RHS, LHS, SQ))
    return true;

  return KnownBits::haveNoCommonBitsSet(LHSCache.getKnownBits(SQ),
                                        RHSCache.getKnownBits(SQ));
}

bool llvm::isOnlyUsedInZeroComparison(const Instruction *I) {
  return !I->user_empty() && all_of(I->users(), [](const User *U) {
    return match(U, m_ICmp(m_Value(), m_Zero()));
  });
}

bool llvm::isOnlyUsedInZeroEqualityComparison(const Instruction *I) {
  return !I->user_empty() && all_of(I->users(), [](const User *U) {
    CmpPredicate P;
    return match(U, m_ICmp(P, m_Value(), m_Zero())) && ICmpInst::isEquality(P);
  });
}

bool llvm::isKnownToBeAPowerOfTwo(const Value *V, const DataLayout &DL,
                                  bool OrZero, unsigned Depth,
                                  AssumptionCache *AC, const Instruction *CxtI,
                                  const DominatorTree *DT, bool UseInstrInfo) {
  return ::isKnownToBeAPowerOfTwo(
      V, OrZero, Depth,
      SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo));
}

static bool isKnownNonZero(const Value *V, const APInt &DemandedElts,
                           const SimplifyQuery &Q, unsigned Depth);

bool llvm::isKnownNonNegative(const Value *V, const SimplifyQuery &SQ,
                              unsigned Depth) {
  return computeKnownBits(V, Depth, SQ).isNonNegative();
}

bool llvm::isKnownPositive(const Value *V, const SimplifyQuery &SQ,
                           unsigned Depth) {
  if (auto *CI = dyn_cast<ConstantInt>(V))
    return CI->getValue().isStrictlyPositive();

  // If `isKnownNonNegative` ever becomes more sophisticated, make sure to keep
  // this updated.
  KnownBits Known = computeKnownBits(V, Depth, SQ);
  return Known.isNonNegative() &&
         (Known.isNonZero() || isKnownNonZero(V, SQ, Depth));
}

bool llvm::isKnownNegative(const Value *V, const SimplifyQuery &SQ,
                           unsigned Depth) {
  return computeKnownBits(V, Depth, SQ).isNegative();
}

static bool isKnownNonEqual(const Value *V1, const Value *V2,
                            const APInt &DemandedElts, unsigned Depth,
                            const SimplifyQuery &Q);

bool llvm::isKnownNonEqual(const Value *V1, const Value *V2,
                           const DataLayout &DL, AssumptionCache *AC,
                           const Instruction *CxtI, const DominatorTree *DT,
                           bool UseInstrInfo) {
  // We don't support looking through casts.
  if (V1 == V2 || V1->getType() != V2->getType())
    return false;
  auto *FVTy = dyn_cast<FixedVectorType>(V1->getType());
  APInt DemandedElts =
      FVTy ? APInt::getAllOnes(FVTy->getNumElements()) : APInt(1, 1);
  return ::isKnownNonEqual(
      V1, V2, DemandedElts, 0,
      SimplifyQuery(DL, DT, AC, safeCxtI(V2, V1, CxtI), UseInstrInfo));
}

bool llvm::MaskedValueIsZero(const Value *V, const APInt &Mask,
                             const SimplifyQuery &SQ, unsigned Depth) {
  KnownBits Known(Mask.getBitWidth());
  computeKnownBits(V, Known, Depth, SQ);
  return Mask.isSubsetOf(Known.Zero);
}

static unsigned ComputeNumSignBits(const Value *V, const APInt &DemandedElts,
                                   unsigned Depth, const SimplifyQuery &Q);

static unsigned ComputeNumSignBits(const Value *V, unsigned Depth,
                                   const SimplifyQuery &Q) {
  auto *FVTy = dyn_cast<FixedVectorType>(V->getType());
  APInt DemandedElts =
      FVTy ? APInt::getAllOnes(FVTy->getNumElements()) : APInt(1, 1);
  return ComputeNumSignBits(V, DemandedElts, Depth, Q);
}

unsigned llvm::ComputeNumSignBits(const Value *V, const DataLayout &DL,
                                  unsigned Depth, AssumptionCache *AC,
                                  const Instruction *CxtI,
                                  const DominatorTree *DT, bool UseInstrInfo) {
  return ::ComputeNumSignBits(
      V, Depth, SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo));
}

unsigned llvm::ComputeMaxSignificantBits(const Value *V, const DataLayout &DL,
                                         unsigned Depth, AssumptionCache *AC,
                                         const Instruction *CxtI,
                                         const DominatorTree *DT) {
  unsigned SignBits = ComputeNumSignBits(V, DL, Depth, AC, CxtI, DT);
  return V->getType()->getScalarSizeInBits() - SignBits + 1;
}

static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1,
                                   bool NSW, bool NUW,
                                   const APInt &DemandedElts,
                                   KnownBits &KnownOut, KnownBits &Known2,
                                   unsigned Depth, const SimplifyQuery &Q) {
  computeKnownBits(Op1, DemandedElts, KnownOut, Depth + 1, Q);

  // If one operand is unknown and we have no nowrap information,
  // the result will be unknown independently of the second operand.
  if (KnownOut.isUnknown() && !NSW && !NUW)
    return;

  computeKnownBits(Op0, DemandedElts, Known2, Depth + 1, Q);
  KnownOut = KnownBits::computeForAddSub(Add, NSW, NUW, Known2, KnownOut);
}

static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
                                bool NUW, const APInt &DemandedElts,
                                KnownBits &Known, KnownBits &Known2,
                                unsigned Depth, const SimplifyQuery &Q) {
  computeKnownBits(Op1, DemandedElts, Known, Depth + 1, Q);
  computeKnownBits(Op0, DemandedElts, Known2, Depth + 1, Q);

  bool isKnownNegative = false;
  bool isKnownNonNegative = false;
  // If the multiplication is known not to overflow, compute the sign bit.
  if (NSW) {
    if (Op0 == Op1) {
      // The product of a number with itself is non-negative.
      isKnownNonNegative = true;
    } else {
      bool isKnownNonNegativeOp1 = Known.isNonNegative();
      bool isKnownNonNegativeOp0 = Known2.isNonNegative();
      bool isKnownNegativeOp1 = Known.isNegative();
      bool isKnownNegativeOp0 = Known2.isNegative();
      // The product of two numbers with the same sign is non-negative.
      isKnownNonNegative = (isKnownNegativeOp1 && isKnownNegativeOp0) ||
                           (isKnownNonNegativeOp1 && isKnownNonNegativeOp0);
      if (!isKnownNonNegative && NUW) {
        // mul nuw nsw with a factor > 1 is non-negative.
        KnownBits One = KnownBits::makeConstant(APInt(Known.getBitWidth(), 1));
        isKnownNonNegative = KnownBits::sgt(Known, One).value_or(false) ||
                             KnownBits::sgt(Known2, One).value_or(false);
      }

      // The product of a negative number and a non-negative number is either
      // negative or zero.
      if (!isKnownNonNegative)
        isKnownNegative =
            (isKnownNegativeOp1 && isKnownNonNegativeOp0 &&
             Known2.isNonZero()) ||
            (isKnownNegativeOp0 && isKnownNonNegativeOp1 && Known.isNonZero());
    }
  }

  bool SelfMultiply = Op0 == Op1;
  if (SelfMultiply)
    SelfMultiply &=
        isGuaranteedNotToBeUndef(Op0, Q.AC, Q.CxtI, Q.DT, Depth + 1);
  Known = KnownBits::mul(Known, Known2, SelfMultiply);

  // Only make use of no-wrap flags if we failed to compute the sign bit
  // directly.  This matters if the multiplication always overflows, in
  // which case we prefer to follow the result of the direct computation,
  // though as the program is invoking undefined behaviour we can choose
  // whatever we like here.
  if (isKnownNonNegative && !Known.isNegative())
    Known.makeNonNegative();
  else if (isKnownNegative && !Known.isNonNegative())
    Known.makeNegative();
}

void llvm::computeKnownBitsFromRangeMetadata(const MDNode &Ranges,
                                             KnownBits &Known) {
  unsigned BitWidth = Known.getBitWidth();
  unsigned NumRanges = Ranges.getNumOperands() / 2;
  assert(NumRanges >= 1);

  Known.Zero.setAllBits();
  Known.One.setAllBits();

  for (unsigned i = 0; i < NumRanges; ++i) {
    ConstantInt *Lower =
        mdconst::extract<ConstantInt>(Ranges.getOperand(2 * i + 0));
    ConstantInt *Upper =
        mdconst::extract<ConstantInt>(Ranges.getOperand(2 * i + 1));
    ConstantRange Range(Lower->getValue(), Upper->getValue());

    // The first CommonPrefixBits of all values in Range are equal.
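    // E.g., for the 8-bit range [0x50, 0x54) the max and min differ only in
    // their low two bits (0x53 ^ 0x50 == 0x03), so the top six bits are known
    // to be 0b010100.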
    unsigned CommonPrefixBits =
        (Range.getUnsignedMax() ^ Range.getUnsignedMin()).countl_zero();
    APInt Mask = APInt::getHighBitsSet(BitWidth, CommonPrefixBits);
    APInt UnsignedMax = Range.getUnsignedMax().zextOrTrunc(BitWidth);
    Known.One &= UnsignedMax & Mask;
    Known.Zero &= ~UnsignedMax & Mask;
  }
}

static bool isEphemeralValueOf(const Instruction *I, const Value *E) {
  SmallVector<const Value *, 16> WorkSet(1, I);
  SmallPtrSet<const Value *, 32> Visited;
  SmallPtrSet<const Value *, 16> EphValues;

  // The instruction defining an assumption's condition itself is always
  // considered ephemeral to that assumption (even if it has other
  // non-ephemeral users). See r246696's test case for an example.
  if (is_contained(I->operands(), E))
    return true;

  while (!WorkSet.empty()) {
    const Value *V = WorkSet.pop_back_val();
    if (!Visited.insert(V).second)
      continue;

    // If all uses of this value are ephemeral, then so is this value.
    if (llvm::all_of(V->users(), [&](const User *U) {
                                   return EphValues.count(U);
                                 })) {
      if (V == E)
        return true;

      if (V == I || (isa<Instruction>(V) &&
                     !cast<Instruction>(V)->mayHaveSideEffects() &&
                     !cast<Instruction>(V)->isTerminator())) {
       EphValues.insert(V);
       if (const User *U = dyn_cast<User>(V))
         append_range(WorkSet, U->operands());
      }
    }
  }

  return false;
}

// Is this an intrinsic that cannot be speculated but also cannot trap?
bool llvm::isAssumeLikeIntrinsic(const Instruction *I) {
  if (const IntrinsicInst *CI = dyn_cast<IntrinsicInst>(I))
    return CI->isAssumeLikeIntrinsic();

  return false;
}

bool llvm::isValidAssumeForContext(const Instruction *Inv,
                                   const Instruction *CxtI,
                                   const DominatorTree *DT,
                                   bool AllowEphemerals) {
  // There are two restrictions on the use of an assume:
  //  1. The assume must dominate the context (or the control flow must
  //     reach the assume whenever it reaches the context).
  //  2. The context must not be in the assume's set of ephemeral values
  //     (otherwise we will use the assume to prove that the condition
  //     feeding the assume is trivially true, thus causing the removal of
  //     the assume).

  if (Inv->getParent() == CxtI->getParent()) {
    // If Inv and CxtI are in the same block, check if the assume (Inv) comes
    // first in the BB.
    if (Inv->comesBefore(CxtI))
      return true;

    // Don't let an assume affect itself - this would cause the problems
    // `isEphemeralValueOf` is trying to prevent, and it would also make
    // the loop below go out of bounds.
    if (!AllowEphemerals && Inv == CxtI)
      return false;

    // The context comes first, but they're both in the same block.
    // Make sure there is nothing in between that might interrupt
    // the control flow, not even CxtI itself.
    // We limit the scan distance between the assume and its context instruction
    // to avoid a compile-time explosion. This limit is chosen arbitrarily, so
    // it can be adjusted if needed (could be turned into a cl::opt).
    auto Range = make_range(CxtI->getIterator(), Inv->getIterator());
    if (!isGuaranteedToTransferExecutionToSuccessor(Range, 15))
      return false;

    return AllowEphemerals || !isEphemeralValueOf(Inv, CxtI);
  }

  // Inv and CxtI are in different blocks.
  if (DT) {
    if (DT->dominates(Inv, CxtI))
      return true;
  } else if (Inv->getParent() == CxtI->getParent()->getSinglePredecessor() ||
             Inv->getParent()->isEntryBlock()) {
    // We don't have a DT, but this trivially dominates.
    return true;
  }

  return false;
}

// TODO: cmpExcludesZero misses many cases where `RHS` is non-constant but
// we still have enough information about `RHS` to conclude non-zero. For
// example Pred=EQ, RHS=isKnownNonZero. cmpExcludesZero is called in loops
// so the extra compile time may not be worth it, but possibly a second API
// should be created for use outside of loops.
static bool cmpExcludesZero(CmpInst::Predicate Pred, const Value *RHS) {
  // v u> y implies v != 0.
  if (Pred == ICmpInst::ICMP_UGT)
    return true;

  // Special-case v != 0 to also handle v != null.
  if (Pred == ICmpInst::ICMP_NE)
    return match(RHS, m_Zero());

  // All other predicates - rely on generic ConstantRange handling.
  const APInt *C;
  auto Zero = APInt::getZero(RHS->getType()->getScalarSizeInBits());
  if (match(RHS, m_APInt(C))) {
    ConstantRange TrueValues = ConstantRange::makeExactICmpRegion(Pred, *C);
    return !TrueValues.contains(Zero);
  }

  auto *VC = dyn_cast<ConstantDataVector>(RHS);
  if (VC == nullptr)
    return false;

  for (unsigned ElemIdx = 0, NElem = VC->getNumElements(); ElemIdx < NElem;
       ++ElemIdx) {
    ConstantRange TrueValues = ConstantRange::makeExactICmpRegion(
        Pred, VC->getElementAsAPInt(ElemIdx));
    if (TrueValues.contains(Zero))
      return false;
  }
  return true;
}

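// Given a use \p U of \p PHI (possibly a self-referential phi), return in
// \p ValOut the value to analyze and in \p CtxIOut a context instruction for
// that analysis. If the incoming value is a select or a two-operand phi that
// refers back to \p PHI, look through it to the other arm / incoming value to
// break the recursion.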
static void breakSelfRecursivePHI(const Use *U, const PHINode *PHI,
                                  Value *&ValOut, Instruction *&CtxIOut,
                                  const PHINode **PhiOut = nullptr) {
  ValOut = U->get();
  if (ValOut == PHI)
    return;
  CtxIOut = PHI->getIncomingBlock(*U)->getTerminator();
  if (PhiOut)
    *PhiOut = PHI;
  Value *V;
  // If the Use is a select of this phi, compute the analysis on the other arm
  // to break the recursion.
  // TODO: Min/Max
  if (match(ValOut, m_Select(m_Value(), m_Specific(PHI), m_Value(V))) ||
      match(ValOut, m_Select(m_Value(), m_Value(V), m_Specific(PHI))))
    ValOut = V;

  // Likewise, if this phi is a two-operand phi, compute the analysis on the
  // other incoming value to break the recursion.
  // TODO: We could handle any number of incoming edges as long as we only have
  // two unique values.
  if (auto *IncPhi = dyn_cast<PHINode>(ValOut);
      IncPhi && IncPhi->getNumIncomingValues() == 2) {
    for (int Idx = 0; Idx < 2; ++Idx) {
      if (IncPhi->getIncomingValue(Idx) == PHI) {
        ValOut = IncPhi->getIncomingValue(1 - Idx);
        if (PhiOut)
          *PhiOut = IncPhi;
        CtxIOut = IncPhi->getIncomingBlock(1 - Idx)->getTerminator();
        break;
      }
    }
  }
}

static bool isKnownNonZeroFromAssume(const Value *V, const SimplifyQuery &Q) {
  // Use of assumptions is context-sensitive. If we don't have a context, we
  // cannot use them!
  if (!Q.AC || !Q.CxtI)
    return false;

  for (AssumptionCache::ResultElem &Elem : Q.AC->assumptionsFor(V)) {
    if (!Elem.Assume)
      continue;

    AssumeInst *I = cast<AssumeInst>(Elem.Assume);
    assert(I->getFunction() == Q.CxtI->getFunction() &&
           "Got assumption for the wrong function!");

    if (Elem.Index != AssumptionCache::ExprResultIdx) {
      if (!V->getType()->isPointerTy())
        continue;
      if (RetainedKnowledge RK = getKnowledgeFromBundle(
              *I, I->bundle_op_info_begin()[Elem.Index])) {
        if (RK.WasOn == V &&
            (RK.AttrKind == Attribute::NonNull ||
             (RK.AttrKind == Attribute::Dereferenceable &&
              !NullPointerIsDefined(Q.CxtI->getFunction(),
                                    V->getType()->getPointerAddressSpace()))) &&
            isValidAssumeForContext(I, Q.CxtI, Q.DT))
          return true;
      }
      continue;
    }

    // Warning: This loop can end up being somewhat performance sensitive.
    // We're running this loop once for each value queried, resulting in a
    // runtime of ~O(#assumes * #values).

    Value *RHS;
    CmpPredicate Pred;
    auto m_V = m_CombineOr(m_Specific(V), m_PtrToInt(m_Specific(V)));
    if (!match(I->getArgOperand(0), m_c_ICmp(Pred, m_V, m_Value(RHS))))
      continue;

    if (cmpExcludesZero(Pred, RHS) && isValidAssumeForContext(I, Q.CxtI, Q.DT))
      return true;
  }

  return false;
}

static void computeKnownBitsFromCmp(const Value *V, CmpInst::Predicate Pred,
                                    Value *LHS, Value *RHS, KnownBits &Known,
                                    const SimplifyQuery &Q) {
  if (RHS->getType()->isPointerTy()) {
    // Handle comparison of pointer to null explicitly, as it will not be
    // covered by the m_APInt() logic below.
    if (LHS == V && match(RHS, m_Zero())) {
      switch (Pred) {
      case ICmpInst::ICMP_EQ:
        Known.setAllZero();
        break;
      case ICmpInst::ICMP_SGE:
      case ICmpInst::ICMP_SGT:
        Known.makeNonNegative();
        break;
      case ICmpInst::ICMP_SLT:
        Known.makeNegative();
        break;
      default:
        break;
      }
    }
    return;
  }

  unsigned BitWidth = Known.getBitWidth();
  auto m_V =
      m_CombineOr(m_Specific(V), m_PtrToIntSameSize(Q.DL, m_Specific(V)));

  Value *Y;
  const APInt *Mask, *C;
  uint64_t ShAmt;
  switch (Pred) {
  case ICmpInst::ICMP_EQ:
    // assume(V = C)
    if (match(LHS, m_V) && match(RHS, m_APInt(C))) {
      Known = Known.unionWith(KnownBits::makeConstant(*C));
      // assume(V & Mask = C)
    } else if (match(LHS, m_c_And(m_V, m_Value(Y))) &&
               match(RHS, m_APInt(C))) {
      // For one bits in Mask, we can propagate bits from C to V.
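      // E.g., assume((V & 0xF0) == 0x30) makes bits 5:4 of V known one and
      // bits 7:6 known zero.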
      Known.One |= *C;
      if (match(Y, m_APInt(Mask)))
        Known.Zero |= ~*C & *Mask;
      // assume(V | Mask = C)
    } else if (match(LHS, m_c_Or(m_V, m_Value(Y))) && match(RHS, m_APInt(C))) {
      // For zero bits in Mask, we can propagate bits from C to V.
      Known.Zero |= ~*C;
      if (match(Y, m_APInt(Mask)))
        Known.One |= *C & ~*Mask;
      // assume(V ^ Mask = C)
    } else if (match(LHS, m_Xor(m_V, m_APInt(Mask))) &&
               match(RHS, m_APInt(C))) {
      // Equivalent to assume(V == Mask ^ C)
      Known = Known.unionWith(KnownBits::makeConstant(*C ^ *Mask));
      // assume(V << ShAmt = C)
    } else if (match(LHS, m_Shl(m_V, m_ConstantInt(ShAmt))) &&
               match(RHS, m_APInt(C)) && ShAmt < BitWidth) {
      // For those bits in C that are known, we can propagate them to known
      // bits in V shifted to the right by ShAmt.
      KnownBits RHSKnown = KnownBits::makeConstant(*C);
      RHSKnown.Zero.lshrInPlace(ShAmt);
      RHSKnown.One.lshrInPlace(ShAmt);
      Known = Known.unionWith(RHSKnown);
      // assume(V >> ShAmt = C)
    } else if (match(LHS, m_Shr(m_V, m_ConstantInt(ShAmt))) &&
               match(RHS, m_APInt(C)) && ShAmt < BitWidth) {
      KnownBits RHSKnown = KnownBits::makeConstant(*C);
      // For those bits in RHS that are known, we can propagate them to known
      // bits in V shifted to the left by ShAmt.
      Known.Zero |= RHSKnown.Zero << ShAmt;
      Known.One |= RHSKnown.One << ShAmt;
    }
    break;
  case ICmpInst::ICMP_NE: {
    // assume (V & B != 0) where B is a power of 2
    const APInt *BPow2;
    if (match(LHS, m_And(m_V, m_Power2(BPow2))) && match(RHS, m_Zero()))
      Known.One |= *BPow2;
    break;
  }
  default:
    if (match(RHS, m_APInt(C))) {
      const APInt *Offset = nullptr;
      if (match(LHS, m_CombineOr(m_V, m_AddLike(m_V, m_APInt(Offset))))) {
        ConstantRange LHSRange = ConstantRange::makeAllowedICmpRegion(Pred, *C);
        if (Offset)
          LHSRange = LHSRange.sub(*Offset);
        Known = Known.unionWith(LHSRange.toKnownBits());
      }
      if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE) {
        // X & Y u> C     -> X u> C && Y u> C
        // X nuw- Y u> C  -> X u> C
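        // E.g., X u> 0xEF implies X >= 0xF0, so the top four bits of an i8 X
        // become known one.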
        if (match(LHS, m_c_And(m_V, m_Value())) ||
            match(LHS, m_NUWSub(m_V, m_Value())))
          Known.One.setHighBits(
              (*C + (Pred == ICmpInst::ICMP_UGT)).countLeadingOnes());
      }
      if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE) {
        // X | Y u< C    -> X u< C && Y u< C
        // X nuw+ Y u< C -> X u< C && Y u< C
        if (match(LHS, m_c_Or(m_V, m_Value())) ||
            match(LHS, m_c_NUWAdd(m_V, m_Value()))) {
          Known.Zero.setHighBits(
              (*C - (Pred == ICmpInst::ICMP_ULT)).countLeadingZeros());
        }
      }
    }
    break;
  }
}

static void computeKnownBitsFromICmpCond(const Value *V, ICmpInst *Cmp,
                                         KnownBits &Known,
                                         const SimplifyQuery &SQ, bool Invert) {
  ICmpInst::Predicate Pred =
      Invert ? Cmp->getInversePredicate() : Cmp->getPredicate();
  Value *LHS = Cmp->getOperand(0);
  Value *RHS = Cmp->getOperand(1);

  // Handle icmp pred (trunc V), C
  if (match(LHS, m_Trunc(m_Specific(V)))) {
    KnownBits DstKnown(LHS->getType()->getScalarSizeInBits());
    computeKnownBitsFromCmp(LHS, Pred, LHS, RHS, DstKnown, SQ);
    Known = Known.unionWith(DstKnown.anyext(Known.getBitWidth()));
    return;
  }

  computeKnownBitsFromCmp(V, Pred, LHS, RHS, Known, SQ);
}

static void computeKnownBitsFromCond(const Value *V, Value *Cond,
                                     KnownBits &Known, unsigned Depth,
                                     const SimplifyQuery &SQ, bool Invert) {
  Value *A, *B;
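  // If the (possibly inverted) condition is a logical AND, both operands hold,
  // so information derived from each can be combined (union); for a logical OR
  // only the common information (intersection) is valid. Inverting the
  // condition swaps the two cases.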
  if (Depth < MaxAnalysisRecursionDepth &&
      match(Cond, m_LogicalOp(m_Value(A), m_Value(B)))) {
    KnownBits Known2(Known.getBitWidth());
    KnownBits Known3(Known.getBitWidth());
    computeKnownBitsFromCond(V, A, Known2, Depth + 1, SQ, Invert);
    computeKnownBitsFromCond(V, B, Known3, Depth + 1, SQ, Invert);
    if (Invert ? match(Cond, m_LogicalOr(m_Value(), m_Value()))
               : match(Cond, m_LogicalAnd(m_Value(), m_Value())))
      Known2 = Known2.unionWith(Known3);
    else
      Known2 = Known2.intersectWith(Known3);
    Known = Known.unionWith(Known2);
  }

  if (auto *Cmp = dyn_cast<ICmpInst>(Cond))
    computeKnownBitsFromICmpCond(V, Cmp, Known, SQ, Invert);
}

void llvm::computeKnownBitsFromContext(const Value *V, KnownBits &Known,
                                       unsigned Depth, const SimplifyQuery &Q) {
  // Handle injected condition.
  if (Q.CC && Q.CC->AffectedValues.contains(V))
    computeKnownBitsFromCond(V, Q.CC->Cond, Known, Depth, Q, Q.CC->Invert);

  if (!Q.CxtI)
    return;

  if (Q.DC && Q.DT) {
    // Handle dominating conditions.
    for (BranchInst *BI : Q.DC->conditionsFor(V)) {
      BasicBlockEdge Edge0(BI->getParent(), BI->getSuccessor(0));
      if (Q.DT->dominates(Edge0, Q.CxtI->getParent()))
        computeKnownBitsFromCond(V, BI->getCondition(), Known, Depth, Q,
                                 /*Invert*/ false);

      BasicBlockEdge Edge1(BI->getParent(), BI->getSuccessor(1));
      if (Q.DT->dominates(Edge1, Q.CxtI->getParent()))
        computeKnownBitsFromCond(V, BI->getCondition(), Known, Depth, Q,
                                 /*Invert*/ true);
    }

    if (Known.hasConflict())
      Known.resetAll();
  }

  if (!Q.AC)
    return;

  unsigned BitWidth = Known.getBitWidth();

  // Note that the patterns below need to be kept in sync with the code
  // in AssumptionCache::updateAffectedValues.

  for (AssumptionCache::ResultElem &Elem : Q.AC->assumptionsFor(V)) {
    if (!Elem.Assume)
      continue;

    AssumeInst *I = cast<AssumeInst>(Elem.Assume);
    assert(I->getParent()->getParent() == Q.CxtI->getParent()->getParent() &&
           "Got assumption for the wrong function!");

    if (Elem.Index != AssumptionCache::ExprResultIdx) {
      if (!V->getType()->isPointerTy())
        continue;
      if (RetainedKnowledge RK = getKnowledgeFromBundle(
              *I, I->bundle_op_info_begin()[Elem.Index])) {
        // Allow AllowEphemerals in isValidAssumeForContext, as the CxtI might
        // be the producer of the pointer in the bundle. At the moment, align
        // assumptions aren't optimized away.
        if (RK.WasOn == V && RK.AttrKind == Attribute::Alignment &&
            isPowerOf2_64(RK.ArgValue) &&
            isValidAssumeForContext(I, Q.CxtI, Q.DT, /*AllowEphemerals*/ true))
          Known.Zero.setLowBits(Log2_64(RK.ArgValue));
      }
      continue;
    }

    // Warning: This loop can end up being somewhat performance sensitive.
    // We're running this loop once for each value queried, resulting in a
    // runtime of ~O(#assumes * #values).

    Value *Arg = I->getArgOperand(0);

    if (Arg == V && isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
      assert(BitWidth == 1 && "assume operand is not i1?");
      (void)BitWidth;
      Known.setAllOnes();
      return;
    }
    if (match(Arg, m_Not(m_Specific(V))) &&
        isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
      assert(BitWidth == 1 && "assume operand is not i1?");
      (void)BitWidth;
      Known.setAllZero();
      return;
    }

    // The remaining tests are all recursive, so bail out if we hit the limit.
    if (Depth == MaxAnalysisRecursionDepth)
      continue;

    ICmpInst *Cmp = dyn_cast<ICmpInst>(Arg);
    if (!Cmp)
      continue;

    if (!isValidAssumeForContext(I, Q.CxtI, Q.DT))
      continue;

    computeKnownBitsFromICmpCond(V, Cmp, Known, Q, /*Invert=*/false);
  }

  // Conflicting assumption: Undefined behavior will occur on this execution
  // path.
  if (Known.hasConflict())
    Known.resetAll();
}

/// Compute known bits from a shift operator, including those with a
/// non-constant shift amount. Known is the output of this function. Known2 is a
/// pre-allocated temporary with the same bit width as Known and on return
/// contains the known bits of the shifted value's source. KF is an
/// operator-specific function that, given the known bits of the value and of
/// the shift amount, computes the implied known bits of the shift operator's
/// result for that shift amount. The results from calling KF are conservatively
/// combined for all permitted shift amounts.
static void computeKnownBitsFromShiftOperator(
    const Operator *I, const APInt &DemandedElts, KnownBits &Known,
    KnownBits &Known2, unsigned Depth, const SimplifyQuery &Q,
    function_ref<KnownBits(const KnownBits &, const KnownBits &, bool)> KF) {
  computeKnownBits(I->getOperand(0), DemandedElts, Known2, Depth + 1, Q);
  computeKnownBits(I->getOperand(1), DemandedElts, Known, Depth + 1, Q);
  // To limit compile-time impact, only query isKnownNonZero() if we know at
  // least something about the shift amount.
  bool ShAmtNonZero =
      Known.isNonZero() ||
      (Known.getMaxValue().ult(Known.getBitWidth()) &&
       isKnownNonZero(I->getOperand(1), DemandedElts, Q, Depth + 1));
  Known = KF(Known2, Known, ShAmtNonZero);
}

static KnownBits
getKnownBitsFromAndXorOr(const Operator *I, const APInt &DemandedElts,
                         const KnownBits &KnownLHS, const KnownBits &KnownRHS,
                         unsigned Depth, const SimplifyQuery &Q) {
  unsigned BitWidth = KnownLHS.getBitWidth();
  KnownBits KnownOut(BitWidth);
  bool IsAnd = false;
  bool HasKnownOne = !KnownLHS.One.isZero() || !KnownRHS.One.isZero();
  Value *X = nullptr, *Y = nullptr;

  switch (I->getOpcode()) {
  case Instruction::And:
    KnownOut = KnownLHS & KnownRHS;
    IsAnd = true;
    // and(x, -x) is a common idiom that clears all but the lowest set
    // bit. If we have a single known bit in x, we can clear all bits
    // above it.
    // TODO: instcombine often reassociates independent `and` which can hide
    // this pattern. Try to match and(x, and(-x, y)) / and(and(x, y), -x).
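    // E.g., if x is known to be exactly 0b0110, then x & -x is 0b0010.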
    if (HasKnownOne && match(I, m_c_And(m_Value(X), m_Neg(m_Deferred(X))))) {
      // -(-x) == x, so use whichever of LHS/RHS gives the better result.
      if (KnownLHS.countMaxTrailingZeros() <= KnownRHS.countMaxTrailingZeros())
        KnownOut = KnownLHS.blsi();
      else
        KnownOut = KnownRHS.blsi();
    }
    break;
  case Instruction::Or:
    KnownOut = KnownLHS | KnownRHS;
    break;
  case Instruction::Xor:
    KnownOut = KnownLHS ^ KnownRHS;
    // xor(x, x-1) is a common idiom that clears all but the lowest set
    // bit. If we have a single known bit in x, we can clear all bits
    // above it.
    // TODO: xor(x, x-1) is often rewritten as xor(x, x-C) where C !=
    // -1, but for the purpose of demanded bits (xor(x, x-C) &
    // Demanded) == (xor(x, x-1) & Demanded). Extend the xor pattern
    // to use arbitrary C when xor(x, x-C) is the same as xor(x, x-1).
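    // E.g., if x is known to be exactly 0b0110, then x ^ (x - 1) is 0b0011.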
    if (HasKnownOne &&
        match(I, m_c_Xor(m_Value(X), m_Add(m_Deferred(X), m_AllOnes())))) {
      const KnownBits &XBits = I->getOperand(0) == X ? KnownLHS : KnownRHS;
      KnownOut = XBits.blsmsk();
    }
    break;
  default:
    llvm_unreachable("Invalid Op used in 'analyzeKnownBitsFromAndXorOr'");
  }

  // and(x, add (x, -1)) is a common idiom that always clears the low bit;
  // xor/or(x, add (x, -1)) is an idiom that always sets the low bit.
  // Here we handle the more general case of adding any odd number by
  // matching the form and/xor/or(x, add(x, y)) where y is odd.
  // TODO: This could be generalized to clearing any bit set in y where the
  // following bit is known to be unset in y.
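  // E.g., for odd y, x and x + y always differ in bit 0, so bit 0 of
  // and(x, x + y) is 0 and bit 0 of xor/or(x, x + y) is 1.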
  if (!KnownOut.Zero[0] && !KnownOut.One[0] &&
      (match(I, m_c_BinOp(m_Value(X), m_c_Add(m_Deferred(X), m_Value(Y)))) ||
       match(I, m_c_BinOp(m_Value(X), m_Sub(m_Deferred(X), m_Value(Y)))) ||
       match(I, m_c_BinOp(m_Value(X), m_Sub(m_Value(Y), m_Deferred(X)))))) {
    KnownBits KnownY(BitWidth);
    computeKnownBits(Y, DemandedElts, KnownY, Depth + 1, Q);
    if (KnownY.countMinTrailingOnes() > 0) {
      if (IsAnd)
        KnownOut.Zero.setBit(0);
      else
        KnownOut.One.setBit(0);
    }
  }
  return KnownOut;
}

static KnownBits computeKnownBitsForHorizontalOperation(
    const Operator *I, const APInt &DemandedElts, unsigned Depth,
    const SimplifyQuery &Q,
    const function_ref<KnownBits(const KnownBits &, const KnownBits &)>
        KnownBitsFunc) {
  APInt DemandedEltsLHS, DemandedEltsRHS;
  getHorizDemandedEltsForFirstOperand(Q.DL.getTypeSizeInBits(I->getType()),
                                      DemandedElts, DemandedEltsLHS,
                                      DemandedEltsRHS);

  const auto ComputeForSingleOpFunc =
      [Depth, &Q, KnownBitsFunc](const Value *Op, APInt &DemandedEltsOp) {
        return KnownBitsFunc(
            computeKnownBits(Op, DemandedEltsOp, Depth + 1, Q),
            computeKnownBits(Op, DemandedEltsOp << 1, Depth + 1, Q));
      };

  if (DemandedEltsRHS.isZero())
    return ComputeForSingleOpFunc(I->getOperand(0), DemandedEltsLHS);
  if (DemandedEltsLHS.isZero())
    return ComputeForSingleOpFunc(I->getOperand(1), DemandedEltsRHS);

  return ComputeForSingleOpFunc(I->getOperand(0), DemandedEltsLHS)
      .intersectWith(ComputeForSingleOpFunc(I->getOperand(1), DemandedEltsRHS));
}

// Public so this can be used in `SimplifyDemandedUseBits`.
KnownBits llvm::analyzeKnownBitsFromAndXorOr(const Operator *I,
                                             const KnownBits &KnownLHS,
                                             const KnownBits &KnownRHS,
                                             unsigned Depth,
                                             const SimplifyQuery &SQ) {
  auto *FVTy = dyn_cast<FixedVectorType>(I->getType());
  APInt DemandedElts =
      FVTy ? APInt::getAllOnes(FVTy->getNumElements()) : APInt(1, 1);

  return getKnownBitsFromAndXorOr(I, DemandedElts, KnownLHS, KnownRHS, Depth,
                                  SQ);
}

ConstantRange llvm::getVScaleRange(const Function *F, unsigned BitWidth) {
  Attribute Attr = F->getFnAttribute(Attribute::VScaleRange);
  // Without vscale_range, we only know that vscale is non-zero.
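  // This is the wrapped range [1, 0), i.e. every value except zero.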
  if (!Attr.isValid())
    return ConstantRange(APInt(BitWidth, 1), APInt::getZero(BitWidth));

  unsigned AttrMin = Attr.getVScaleRangeMin();
  // Minimum is larger than vscale width, result is always poison.
  if ((unsigned)llvm::bit_width(AttrMin) > BitWidth)
    return ConstantRange::getEmpty(BitWidth);

  APInt Min(BitWidth, AttrMin);
  std::optional<unsigned> AttrMax = Attr.getVScaleRangeMax();
  if (!AttrMax || (unsigned)llvm::bit_width(*AttrMax) > BitWidth)
    return ConstantRange(Min, APInt::getZero(BitWidth));

  return ConstantRange(Min, APInt(BitWidth, *AttrMax) + 1);
}

void llvm::adjustKnownBitsForSelectArm(KnownBits &Known, Value *Cond,
                                       Value *Arm, bool Invert, unsigned Depth,
                                       const SimplifyQuery &Q) {
  // If we have a constant arm, we are done.
  if (Known.isConstant())
    return;

  // See what the condition implies about the bits of the select arm.
  KnownBits CondRes(Known.getBitWidth());
  computeKnownBitsFromCond(Arm, Cond, CondRes, Depth + 1, Q, Invert);
  // If we don't get any information from the condition, no reason to
  // proceed.
  if (CondRes.isUnknown())
    return;

  // We can have a conflict if the condition is dead. I.e., if we have
  // (x | 64) < 32 ? (x | 64) : y
  // we will have a conflict at bit 6 from the condition/the `or`.
  // In that case just return. It's not particularly important
  // what we do, as this select is going to be simplified soon.
  CondRes = CondRes.unionWith(Known);
  if (CondRes.hasConflict())
    return;

  // Finally make sure the information we found is valid. This is relatively
  // expensive so it's left for the very end.
  if (!isGuaranteedNotToBeUndef(Arm, Q.AC, Q.CxtI, Q.DT, Depth + 1))
    return;

  // Finally, we know we get information from the condition and it's valid,
  // so return it.
  Known = CondRes;
}

// Match a signed min+max clamp pattern like smax(smin(In, CHigh), CLow).
// Returns the input and lower/upper bounds.
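// E.g., smax(smin(%x, 255), 0) clamps %x to [0, 255], with In = %x,
// CLow = 0 and CHigh = 255.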
static bool isSignedMinMaxClamp(const Value *Select, const Value *&In,
                                const APInt *&CLow, const APInt *&CHigh) {
  assert(isa<Operator>(Select) &&
         cast<Operator>(Select)->getOpcode() == Instruction::Select &&
         "Input should be a Select!");

  const Value *LHS = nullptr, *RHS = nullptr;
  SelectPatternFlavor SPF = matchSelectPattern(Select, LHS, RHS).Flavor;
  if (SPF != SPF_SMAX && SPF != SPF_SMIN)
    return false;

  if (!match(RHS, m_APInt(CLow)))
    return false;

  const Value *LHS2 = nullptr, *RHS2 = nullptr;
  SelectPatternFlavor SPF2 = matchSelectPattern(LHS, LHS2, RHS2).Flavor;
  if (getInverseMinMaxFlavor(SPF) != SPF2)
    return false;

  if (!match(RHS2, m_APInt(CHigh)))
    return false;

  if (SPF == SPF_SMIN)
    std::swap(CLow, CHigh);

  In = LHS2;
  return CLow->sle(*CHigh);
}

static bool isSignedMinMaxIntrinsicClamp(const IntrinsicInst *II,
                                         const APInt *&CLow,
                                         const APInt *&CHigh) {
  assert((II->getIntrinsicID() == Intrinsic::smin ||
          II->getIntrinsicID() == Intrinsic::smax) &&
         "Must be smin/smax");

  Intrinsic::ID InverseID = getInverseMinMaxIntrinsic(II->getIntrinsicID());
  auto *InnerII = dyn_cast<IntrinsicInst>(II->getArgOperand(0));
  if (!InnerII || InnerII->getIntrinsicID() != InverseID ||
      !match(II->getArgOperand(1), m_APInt(CLow)) ||
      !match(InnerII->getArgOperand(1), m_APInt(CHigh)))
    return false;

  if (II->getIntrinsicID() == Intrinsic::smin)
    std::swap(CLow, CHigh);
  return CLow->sle(*CHigh);
}

static void unionWithMinMaxIntrinsicClamp(const IntrinsicInst *II,
                                          KnownBits &Known) {
  const APInt *CLow, *CHigh;
  if (isSignedMinMaxIntrinsicClamp(II, CLow, CHigh))
    Known = Known.unionWith(
        ConstantRange::getNonEmpty(*CLow, *CHigh + 1).toKnownBits());
}

static void computeKnownBitsFromOperator(const Operator *I,
                                         const APInt &DemandedElts,
                                         KnownBits &Known, unsigned Depth,
                                         const SimplifyQuery &Q) {
  unsigned BitWidth = Known.getBitWidth();

  KnownBits Known2(BitWidth);
  switch (I->getOpcode()) {
  default: break;
  case Instruction::Load:
    if (MDNode *MD =
            Q.IIQ.getMetadata(cast<LoadInst>(I), LLVMContext::MD_range))
      computeKnownBitsFromRangeMetadata(*MD, Known);
    break;
  case Instruction::And:
    computeKnownBits(I->getOperand(1), DemandedElts, Known, Depth + 1, Q);
    computeKnownBits(I->getOperand(0), DemandedElts, Known2, Depth + 1, Q);

    Known = getKnownBitsFromAndXorOr(I, DemandedElts, Known2, Known, Depth, Q);
    break;
  case Instruction::Or:
    computeKnownBits(I->getOperand(1), DemandedElts, Known, Depth + 1, Q);
    computeKnownBits(I->getOperand(0), DemandedElts, Known2, Depth + 1, Q);

    Known = getKnownBitsFromAndXorOr(I, DemandedElts, Known2, Known, Depth, Q);
    break;
  case Instruction::Xor:
    computeKnownBits(I->getOperand(1), DemandedElts, Known, Depth + 1, Q);
    computeKnownBits(I->getOperand(0), DemandedElts, Known2, Depth + 1, Q);

    Known = getKnownBitsFromAndXorOr(I, DemandedElts, Known2, Known, Depth, Q);
    break;
  case Instruction::Mul: {
    bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I));
    bool NUW = Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(I));
    computeKnownBitsMul(I->getOperand(0), I->getOperand(1), NSW, NUW,
                        DemandedElts, Known, Known2, Depth, Q);
    break;
  }
  case Instruction::UDiv: {
    computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
    computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
    Known =
        KnownBits::udiv(Known, Known2, Q.IIQ.isExact(cast<BinaryOperator>(I)));
    break;
  }
  case Instruction::SDiv: {
    computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
    computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
    Known =
        KnownBits::sdiv(Known, Known2, Q.IIQ.isExact(cast<BinaryOperator>(I)));
    break;
  }
  case Instruction::Select: {
    auto ComputeForArm = [&](Value *Arm, bool Invert) {
      KnownBits Res(Known.getBitWidth());
      computeKnownBits(Arm, DemandedElts, Res, Depth + 1, Q);
      adjustKnownBitsForSelectArm(Res, I->getOperand(0), Arm, Invert, Depth, Q);
      return Res;
    };
    // Only known if known in both the LHS and RHS.
    Known =
        ComputeForArm(I->getOperand(1), /*Invert=*/false)
            .intersectWith(ComputeForArm(I->getOperand(2), /*Invert=*/true));
    break;
  }
  case Instruction::FPTrunc:
  case Instruction::FPExt:
  case Instruction::FPToUI:
  case Instruction::FPToSI:
  case Instruction::SIToFP:
  case Instruction::UIToFP:
    break; // Can't work with floating point.
  case Instruction::PtrToInt:
  case Instruction::IntToPtr:
    // Fall through and handle them the same as zext/trunc.
    [[fallthrough]];
  case Instruction::ZExt:
  case Instruction::Trunc: {
    Type *SrcTy = I->getOperand(0)->getType();

    unsigned SrcBitWidth;
    // Note that we handle pointer operands here because of inttoptr/ptrtoint
    // which fall through here.
    Type *ScalarTy = SrcTy->getScalarType();
    SrcBitWidth = ScalarTy->isPointerTy() ?
      Q.DL.getPointerTypeSizeInBits(ScalarTy) :
      Q.DL.getTypeSizeInBits(ScalarTy);

    assert(SrcBitWidth && "SrcBitWidth can't be zero");
    Known = Known.anyextOrTrunc(SrcBitWidth);
    computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
    if (auto *Inst = dyn_cast<PossiblyNonNegInst>(I);
        Inst && Inst->hasNonNeg() && !Known.isNegative())
      Known.makeNonNegative();
    Known = Known.zextOrTrunc(BitWidth);
    break;
  }
  case Instruction::BitCast: {
    Type *SrcTy = I->getOperand(0)->getType();
    if (SrcTy->isIntOrPtrTy() &&
        // TODO: For now, not handling conversions like:
        // (bitcast i64 %x to <2 x i32>)
        !I->getType()->isVectorTy()) {
      computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
      break;
    }

    const Value *V;
    // Handle bitcast from floating point to integer.
    if (match(I, m_ElementWiseBitCast(m_Value(V))) &&
        V->getType()->isFPOrFPVectorTy()) {
      Type *FPType = V->getType()->getScalarType();
      KnownFPClass Result =
          computeKnownFPClass(V, DemandedElts, fcAllFlags, Depth + 1, Q);
      FPClassTest FPClasses = Result.KnownFPClasses;

      // TODO: Treat it as zero/poison if the use of I is unreachable.
      if (FPClasses == fcNone)
        break;

      if (Result.isKnownNever(fcNormal | fcSubnormal | fcNan)) {
        Known.Zero.setAllBits();
        Known.One.setAllBits();

        if (FPClasses & fcInf)
          Known = Known.intersectWith(KnownBits::makeConstant(
              APFloat::getInf(FPType->getFltSemantics()).bitcastToAPInt()));

        if (FPClasses & fcZero)
          Known = Known.intersectWith(KnownBits::makeConstant(
              APInt::getZero(FPType->getScalarSizeInBits())));

        Known.Zero.clearSignBit();
        Known.One.clearSignBit();
      }

      if (Result.SignBit) {
        if (*Result.SignBit)
          Known.makeNegative();
        else
          Known.makeNonNegative();
      }

      break;
    }

    // Handle cast from vector integer type to scalar or vector integer.
    auto *SrcVecTy = dyn_cast<FixedVectorType>(SrcTy);
    if (!SrcVecTy || !SrcVecTy->getElementType()->isIntegerTy() ||
        !I->getType()->isIntOrIntVectorTy() ||
        isa<ScalableVectorType>(I->getType()))
      break;

    // Look through a cast from narrow vector elements to wider type.
1329     // Examples: v4i32 -> v2i64, v3i8 -> i24
1330     unsigned SubBitWidth = SrcVecTy->getScalarSizeInBits();
1331     if (BitWidth % SubBitWidth == 0) {
1332       // Known bits are automatically intersected across demanded elements of a
1333       // vector. So for example, if a bit is computed as known zero, it must be
1334       // zero across all demanded elements of the vector.
1335       //
1336       // For this bitcast, each demanded element of the output is sub-divided
1337       // across a set of smaller vector elements in the source vector. To get
1338       // the known bits for an entire element of the output, compute the known
1339       // bits for each sub-element sequentially. This is done by shifting the
1340       // one-set-bit demanded elements parameter across the sub-elements for
1341       // consecutive calls to computeKnownBits. We are using the demanded
1342       // elements parameter as a mask operator.
1343       //
1344       // The known bits of each sub-element are then inserted into place
1345       // (dependent on endian) to form the full result of known bits.
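           // For example (illustrative): bitcast <4 x i16> %src to <2 x i32>
           // gives SubScale = 2. If only output element 1 is demanded,
           // SubDemandedElts starts as 0b0100; the two calls then use 0b0100
           // and 0b1000, and on a little-endian target the two sub-results
           // are inserted at bit offsets 0 and 16 of the known bits.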
1346       unsigned NumElts = DemandedElts.getBitWidth();
1347       unsigned SubScale = BitWidth / SubBitWidth;
1348       APInt SubDemandedElts = APInt::getZero(NumElts * SubScale);
1349       for (unsigned i = 0; i != NumElts; ++i) {
1350         if (DemandedElts[i])
1351           SubDemandedElts.setBit(i * SubScale);
1352       }
1353 
1354       KnownBits KnownSrc(SubBitWidth);
1355       for (unsigned i = 0; i != SubScale; ++i) {
1356         computeKnownBits(I->getOperand(0), SubDemandedElts.shl(i), KnownSrc,
1357                          Depth + 1, Q);
1358         unsigned ShiftElt = Q.DL.isLittleEndian() ? i : SubScale - 1 - i;
1359         Known.insertBits(KnownSrc, ShiftElt * SubBitWidth);
1360       }
1361     }
1362     break;
1363   }
1364   case Instruction::SExt: {
1365     // Compute the bits in the result that are not present in the input.
1366     unsigned SrcBitWidth = I->getOperand(0)->getType()->getScalarSizeInBits();
1367 
1368     Known = Known.trunc(SrcBitWidth);
1369     computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
1370     // If the sign bit of the input is known set or clear, then we know the
1371     // top bits of the result.
1372     Known = Known.sext(BitWidth);
1373     break;
1374   }
1375   case Instruction::Shl: {
1376     bool NUW = Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(I));
1377     bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I));
1378     auto KF = [NUW, NSW](const KnownBits &KnownVal, const KnownBits &KnownAmt,
1379                          bool ShAmtNonZero) {
1380       return KnownBits::shl(KnownVal, KnownAmt, NUW, NSW, ShAmtNonZero);
1381     };
1382     computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Depth, Q,
1383                                       KF);
1384     // Trailing zeros of a left-shifted constant never decrease.
1385     const APInt *C;
1386     if (match(I->getOperand(0), m_APInt(C)))
1387       Known.Zero.setLowBits(C->countr_zero());
1388     break;
1389   }
1390   case Instruction::LShr: {
1391     bool Exact = Q.IIQ.isExact(cast<BinaryOperator>(I));
1392     auto KF = [Exact](const KnownBits &KnownVal, const KnownBits &KnownAmt,
1393                       bool ShAmtNonZero) {
1394       return KnownBits::lshr(KnownVal, KnownAmt, ShAmtNonZero, Exact);
1395     };
1396     computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Depth, Q,
1397                                       KF);
1398     // Leading zeros of a right-shifted constant never decrease.
1399     const APInt *C;
1400     if (match(I->getOperand(0), m_APInt(C)))
1401       Known.Zero.setHighBits(C->countl_zero());
1402     break;
1403   }
1404   case Instruction::AShr: {
1405     bool Exact = Q.IIQ.isExact(cast<BinaryOperator>(I));
1406     auto KF = [Exact](const KnownBits &KnownVal, const KnownBits &KnownAmt,
1407                       bool ShAmtNonZero) {
1408       return KnownBits::ashr(KnownVal, KnownAmt, ShAmtNonZero, Exact);
1409     };
1410     computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Depth, Q,
1411                                       KF);
1412     break;
1413   }
1414   case Instruction::Sub: {
1415     bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I));
1416     bool NUW = Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(I));
1417     computeKnownBitsAddSub(false, I->getOperand(0), I->getOperand(1), NSW, NUW,
1418                            DemandedElts, Known, Known2, Depth, Q);
1419     break;
1420   }
1421   case Instruction::Add: {
1422     bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I));
1423     bool NUW = Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(I));
1424     computeKnownBitsAddSub(true, I->getOperand(0), I->getOperand(1), NSW, NUW,
1425                            DemandedElts, Known, Known2, Depth, Q);
1426     break;
1427   }
1428   case Instruction::SRem:
1429     computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
1430     computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
1431     Known = KnownBits::srem(Known, Known2);
1432     break;
1433 
1434   case Instruction::URem:
1435     computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
1436     computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
1437     Known = KnownBits::urem(Known, Known2);
1438     break;
1439   case Instruction::Alloca:
1440     Known.Zero.setLowBits(Log2(cast<AllocaInst>(I)->getAlign()));
1441     break;
1442   case Instruction::GetElementPtr: {
1443     // Analyze all of the subscripts of this getelementptr instruction
1444     // to determine if we can prove known low zero bits.
1445     computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
1446     // Accumulate the constant indices in a separate variable
1447     // to minimize the number of calls to computeForAddSub.
1448     APInt AccConstIndices(BitWidth, 0, /*IsSigned*/ true);
1449 
1450     gep_type_iterator GTI = gep_type_begin(I);
1451     for (unsigned i = 1, e = I->getNumOperands(); i != e; ++i, ++GTI) {
1452       // Known low zero bits can only shrink; short-circuit once nothing is known.
1453       if (Known.isUnknown())
1454         break;
1455 
1456       Value *Index = I->getOperand(i);
1457 
1458       // Handle case when index is zero.
1459       Constant *CIndex = dyn_cast<Constant>(Index);
1460       if (CIndex && CIndex->isZeroValue())
1461         continue;
1462 
1463       if (StructType *STy = GTI.getStructTypeOrNull()) {
1464         // Handle struct member offset arithmetic.
1465 
1466         assert(CIndex &&
1467                "Access to structure field must be known at compile time");
1468 
1469         if (CIndex->getType()->isVectorTy())
1470           Index = CIndex->getSplatValue();
1471 
1472         unsigned Idx = cast<ConstantInt>(Index)->getZExtValue();
1473         const StructLayout *SL = Q.DL.getStructLayout(STy);
1474         uint64_t Offset = SL->getElementOffset(Idx);
1475         AccConstIndices += Offset;
1476         continue;
1477       }
1478 
1479       // Handle array index arithmetic.
1480       Type *IndexedTy = GTI.getIndexedType();
1481       if (!IndexedTy->isSized()) {
1482         Known.resetAll();
1483         break;
1484       }
1485 
1486       unsigned IndexBitWidth = Index->getType()->getScalarSizeInBits();
1487       KnownBits IndexBits(IndexBitWidth);
1488       computeKnownBits(Index, IndexBits, Depth + 1, Q);
1489       TypeSize IndexTypeSize = GTI.getSequentialElementStride(Q.DL);
1490       uint64_t TypeSizeInBytes = IndexTypeSize.getKnownMinValue();
1491       KnownBits ScalingFactor(IndexBitWidth);
1492       // Multiply by current sizeof type.
1493       // &A[i] == A + i * sizeof(*A[i]).
1494       if (IndexTypeSize.isScalable()) {
1495         // For scalable types the only thing we know about sizeof is
1496         // that this is a multiple of the minimum size.
1497         ScalingFactor.Zero.setLowBits(llvm::countr_zero(TypeSizeInBytes));
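             // For example (illustrative), a GEP over <vscale x 4 x i32> has a
             // stride of vscale x 16 bytes; the known minimum is 16 bytes, so
             // countr_zero(16) = 4 low bits of the scaled offset are known
             // zero for any index value.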
1498       } else if (IndexBits.isConstant()) {
1499         APInt IndexConst = IndexBits.getConstant();
1500         APInt ScalingFactor(IndexBitWidth, TypeSizeInBytes);
1501         IndexConst *= ScalingFactor;
1502         AccConstIndices += IndexConst.sextOrTrunc(BitWidth);
1503         continue;
1504       } else {
1505         ScalingFactor =
1506             KnownBits::makeConstant(APInt(IndexBitWidth, TypeSizeInBytes));
1507       }
1508       IndexBits = KnownBits::mul(IndexBits, ScalingFactor);
1509 
1510       // If the offsets have a different width from the pointer, according
1511       // to the language reference we need to sign-extend or truncate them
1512       // to the width of the pointer.
1513       IndexBits = IndexBits.sextOrTrunc(BitWidth);
1514 
1515       // Note that inbounds does *not* guarantee nsw for the addition, as only
1516       // the offset is signed, while the base address is unsigned.
1517       Known = KnownBits::add(Known, IndexBits);
1518     }
1519     if (!Known.isUnknown() && !AccConstIndices.isZero()) {
1520       KnownBits Index = KnownBits::makeConstant(AccConstIndices);
1521       Known = KnownBits::add(Known, Index);
1522     }
1523     break;
1524   }
1525   case Instruction::PHI: {
1526     const PHINode *P = cast<PHINode>(I);
1527     BinaryOperator *BO = nullptr;
1528     Value *R = nullptr, *L = nullptr;
1529     if (matchSimpleRecurrence(P, BO, R, L)) {
1530       // Handle the case of a simple two-predecessor recurrence PHI.
1531       // There's a lot more that could theoretically be done here, but
1532       // this is sufficient to catch some interesting cases.
1533       unsigned Opcode = BO->getOpcode();
1534 
1535       switch (Opcode) {
1536       // If this is a shift recurrence, we know the bits being shifted in. We
1537       // can combine that with information about the start value of the
1538       // recurrence to conclude facts about the result. If this is a udiv
1539       // recurrence, we know that the result can never exceed either the
1540       // numerator or the start value, whichever is greater.
1541       case Instruction::LShr:
1542       case Instruction::AShr:
1543       case Instruction::Shl:
1544       case Instruction::UDiv:
1545         if (BO->getOperand(0) != I)
1546           break;
1547         [[fallthrough]];
1548 
1549       // For a urem recurrence, the result can never exceed the start value. The
1550       // phi could either be the numerator or the denominator.
1551       case Instruction::URem: {
1552         // We have matched a recurrence of the form:
1553         // %iv = [R, %entry], [%iv.next, %backedge]
1554         // %iv = phi [R, %entry], [%iv.next, %backedge]
1555 
1556         // Recurse with the phi context to avoid concern about whether facts
1557         // inferred hold at original context instruction.  TODO: It may be
1558         // correct to use the original context.  IF warranted, explore and
1559         // correct to use the original context.  If warranted, explore and
1560         SimplifyQuery RecQ = Q.getWithoutCondContext();
1561         RecQ.CxtI = P;
1562         computeKnownBits(R, DemandedElts, Known2, Depth + 1, RecQ);
1563         switch (Opcode) {
1564         case Instruction::Shl:
1565           // A shl recurrence will only increase the trailing zeros.
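               // For example, if the start value is known to have 3 trailing
               // zero bits, every value of the recurrence keeps at least
               // those 3 low zero bits.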
1566           Known.Zero.setLowBits(Known2.countMinTrailingZeros());
1567           break;
1568         case Instruction::LShr:
1569         case Instruction::UDiv:
1570         case Instruction::URem:
1571           // lshr, udiv, and urem recurrences will preserve the leading zeros of
1572           // the start value.
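               // For example, if the start value is known to be u< 16 (at
               // least 28 leading zero bits in i32), the recurrence result
               // also stays u< 16.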
1573           Known.Zero.setHighBits(Known2.countMinLeadingZeros());
1574           break;
1575         case Instruction::AShr:
1576           // An ashr recurrence will extend the initial sign bit
1577           Known.Zero.setHighBits(Known2.countMinLeadingZeros());
1578           Known.One.setHighBits(Known2.countMinLeadingOnes());
1579           break;
1580         }
1581         break;
1582       }
1583 
1584       // Check for operations that have the property that if
1585       // both their operands have low zero bits, the result
1586       // will have low zero bits.
1587       case Instruction::Add:
1588       case Instruction::Sub:
1589       case Instruction::And:
1590       case Instruction::Or:
1591       case Instruction::Mul: {
1592         // Change the context instruction to the "edge" that flows into the
1593         // phi. This is important because that is where the value is actually
1594         // "evaluated" even though it is used later somewhere else. (see also
1595         // D69571).
1596         SimplifyQuery RecQ = Q.getWithoutCondContext();
1597 
1598         unsigned OpNum = P->getOperand(0) == R ? 0 : 1;
1599         Instruction *RInst = P->getIncomingBlock(OpNum)->getTerminator();
1600         Instruction *LInst = P->getIncomingBlock(1 - OpNum)->getTerminator();
1601 
1602         // Ok, we have a PHI of the form L op= R. Check for low
1603         // zero bits.
1604         RecQ.CxtI = RInst;
1605         computeKnownBits(R, DemandedElts, Known2, Depth + 1, RecQ);
1606 
1607         // We need to take the minimum number of known bits
1608         KnownBits Known3(BitWidth);
1609         RecQ.CxtI = LInst;
1610         computeKnownBits(L, DemandedElts, Known3, Depth + 1, RecQ);
1611 
1612         Known.Zero.setLowBits(std::min(Known2.countMinTrailingZeros(),
1613                                        Known3.countMinTrailingZeros()));
1614 
1615         auto *OverflowOp = dyn_cast<OverflowingBinaryOperator>(BO);
1616         if (!OverflowOp || !Q.IIQ.hasNoSignedWrap(OverflowOp))
1617           break;
1618 
1619         switch (Opcode) {
1620         // If initial value of recurrence is nonnegative, and we are adding
1621         // a nonnegative number with nsw, the result can only be nonnegative
1622         // or poison value regardless of the number of times we execute the
1623         // add in phi recurrence. If initial value is negative and we are
1624         // adding a negative number with nsw, the result can only be
1625         // negative or poison value. Similar arguments apply to sub and mul.
1626         //
1627         // (add non-negative, non-negative) --> non-negative
1628         // (add negative, negative) --> negative
1629         case Instruction::Add: {
1630           if (Known2.isNonNegative() && Known3.isNonNegative())
1631             Known.makeNonNegative();
1632           else if (Known2.isNegative() && Known3.isNegative())
1633             Known.makeNegative();
1634           break;
1635         }
1636 
1637         // (sub nsw non-negative, negative) --> non-negative
1638         // (sub nsw negative, non-negative) --> negative
1639         case Instruction::Sub: {
1640           if (BO->getOperand(0) != I)
1641             break;
1642           if (Known2.isNonNegative() && Known3.isNegative())
1643             Known.makeNonNegative();
1644           else if (Known2.isNegative() && Known3.isNonNegative())
1645             Known.makeNegative();
1646           break;
1647         }
1648 
1649         // (mul nsw non-negative, non-negative) --> non-negative
1650         case Instruction::Mul:
1651           if (Known2.isNonNegative() && Known3.isNonNegative())
1652             Known.makeNonNegative();
1653           break;
1654 
1655         default:
1656           break;
1657         }
1658         break;
1659       }
1660 
1661       default:
1662         break;
1663       }
1664     }
1665 
1666     // Unreachable blocks may have zero-operand PHI nodes.
1667     if (P->getNumIncomingValues() == 0)
1668       break;
1669 
1670     // Otherwise take the unions of the known bit sets of the operands,
1671     // taking conservative care to avoid excessive recursion.
1672     if (Depth < MaxAnalysisRecursionDepth - 1 && Known.isUnknown()) {
1673       // Skip if every incoming value refers back to the PHI node itself.
1674       if (isa_and_nonnull<UndefValue>(P->hasConstantValue()))
1675         break;
1676 
1677       Known.Zero.setAllBits();
1678       Known.One.setAllBits();
1679       for (const Use &U : P->operands()) {
1680         Value *IncValue;
1681         const PHINode *CxtPhi;
1682         Instruction *CxtI;
1683         breakSelfRecursivePHI(&U, P, IncValue, CxtI, &CxtPhi);
1684         // Skip direct self references.
1685         if (IncValue == P)
1686           continue;
1687 
1688         // Change the context instruction to the "edge" that flows into the
1689         // phi. This is important because that is where the value is actually
1690         // "evaluated" even though it is used later somewhere else. (see also
1691         // D69571).
1692         SimplifyQuery RecQ = Q.getWithoutCondContext().getWithInstruction(CxtI);
1693 
1694         Known2 = KnownBits(BitWidth);
1695 
1696         // Recurse, but cap the recursion to one level, because we don't
1697         // want to waste time spinning around in loops.
1698         // TODO: See if we can base recursion limiter on number of incoming phi
1699         // edges so we don't overly clamp analysis.
1700         computeKnownBits(IncValue, DemandedElts, Known2,
1701                          MaxAnalysisRecursionDepth - 1, RecQ);
1702 
1703         // See if we can further use a conditional branch into the phi
1704         // to help us determine the range of the value.
1705         if (!Known2.isConstant()) {
1706           CmpPredicate Pred;
1707           const APInt *RHSC;
1708           BasicBlock *TrueSucc, *FalseSucc;
1709           // TODO: Use RHS Value and compute range from its known bits.
1710           if (match(RecQ.CxtI,
1711                     m_Br(m_c_ICmp(Pred, m_Specific(IncValue), m_APInt(RHSC)),
1712                          m_BasicBlock(TrueSucc), m_BasicBlock(FalseSucc)))) {
1713             // Check for cases of duplicate successors.
1714             if ((TrueSucc == CxtPhi->getParent()) !=
1715                 (FalseSucc == CxtPhi->getParent())) {
1716               // If we're using the false successor, invert the predicate.
1717               if (FalseSucc == CxtPhi->getParent())
1718                 Pred = CmpInst::getInversePredicate(Pred);
1719               // Get the knownbits implied by the incoming phi condition.
1720               auto CR = ConstantRange::makeExactICmpRegion(Pred, *RHSC);
1721               KnownBits KnownUnion = Known2.unionWith(CR.toKnownBits());
1722               // We can have conflicts here if we are analyzing dead code (it is
1723               // impossible to reach this BB based on the icmp).
1724               if (KnownUnion.hasConflict()) {
1725                 // No reason to continue analyzing in a known dead region, so
1726                 // just resetAll and break. This will cause us to also exit the
1727                 // outer loop.
1728                 Known.resetAll();
1729                 break;
1730               }
1731               Known2 = KnownUnion;
1732             }
1733           }
1734         }
1735 
1736         Known = Known.intersectWith(Known2);
1737         // If all bits have been ruled out, there's no need to check
1738         // more operands.
1739         if (Known.isUnknown())
1740           break;
1741       }
1742     }
1743     break;
1744   }
1745   case Instruction::Call:
1746   case Instruction::Invoke: {
1747     // If range metadata is attached to this call, set known bits from that,
1748     // and then intersect with known bits based on other properties of the
1749     // function.
1750     if (MDNode *MD =
1751             Q.IIQ.getMetadata(cast<Instruction>(I), LLVMContext::MD_range))
1752       computeKnownBitsFromRangeMetadata(*MD, Known);
1753 
1754     const auto *CB = cast<CallBase>(I);
1755 
1756     if (std::optional<ConstantRange> Range = CB->getRange())
1757       Known = Known.unionWith(Range->toKnownBits());
1758 
1759     if (const Value *RV = CB->getReturnedArgOperand()) {
1760       if (RV->getType() == I->getType()) {
1761         computeKnownBits(RV, Known2, Depth + 1, Q);
1762         Known = Known.unionWith(Known2);
1763         // If the function doesn't return properly for all input values
1764         // (e.g. unreachable exits) then there might be conflicts between the
1765         // argument value and the range metadata. Simply discard the known bits
1766         // in case of conflicts.
1767         if (Known.hasConflict())
1768           Known.resetAll();
1769       }
1770     }
1771     if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
1772       switch (II->getIntrinsicID()) {
1773       default:
1774         break;
1775       case Intrinsic::abs: {
1776         computeKnownBits(I->getOperand(0), DemandedElts, Known2, Depth + 1, Q);
1777         bool IntMinIsPoison = match(II->getArgOperand(1), m_One());
1778         Known = Known2.abs(IntMinIsPoison);
1779         break;
1780       }
1781       case Intrinsic::bitreverse:
1782         computeKnownBits(I->getOperand(0), DemandedElts, Known2, Depth + 1, Q);
1783         Known.Zero |= Known2.Zero.reverseBits();
1784         Known.One |= Known2.One.reverseBits();
1785         break;
1786       case Intrinsic::bswap:
1787         computeKnownBits(I->getOperand(0), DemandedElts, Known2, Depth + 1, Q);
1788         Known.Zero |= Known2.Zero.byteSwap();
1789         Known.One |= Known2.One.byteSwap();
1790         break;
1791       case Intrinsic::ctlz: {
1792         computeKnownBits(I->getOperand(0), DemandedElts, Known2, Depth + 1, Q);
1793         // If we have a known 1, its position is our upper bound.
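             // For example (illustrative), if bit 3 of an i32 operand is known
             // to be one, countMaxLeadingZeros() is 28, so the result needs at
             // most bit_width(28) = 5 bits and the bits above are known zero.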
1794         unsigned PossibleLZ = Known2.countMaxLeadingZeros();
1795         // If this call is poison for 0 input, the result will be less than 2^n.
1796         if (II->getArgOperand(1) == ConstantInt::getTrue(II->getContext()))
1797           PossibleLZ = std::min(PossibleLZ, BitWidth - 1);
1798         unsigned LowBits = llvm::bit_width(PossibleLZ);
1799         Known.Zero.setBitsFrom(LowBits);
1800         break;
1801       }
1802       case Intrinsic::cttz: {
1803         computeKnownBits(I->getOperand(0), DemandedElts, Known2, Depth + 1, Q);
1804         // If we have a known 1, its position is our upper bound.
1805         unsigned PossibleTZ = Known2.countMaxTrailingZeros();
1806         // If this call is poison for 0 input, the result will be less than 2^n.
1807         if (II->getArgOperand(1) == ConstantInt::getTrue(II->getContext()))
1808           PossibleTZ = std::min(PossibleTZ, BitWidth - 1);
1809         unsigned LowBits = llvm::bit_width(PossibleTZ);
1810         Known.Zero.setBitsFrom(LowBits);
1811         break;
1812       }
1813       case Intrinsic::ctpop: {
1814         computeKnownBits(I->getOperand(0), DemandedElts, Known2, Depth + 1, Q);
1815         // We can bound the space the count needs.  Also, bits known to be zero
1816         // can't contribute to the population.
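             // For example (illustrative), if at most 5 bits of an i32 operand
             // can be set, the popcount is at most 5 and fits in
             // bit_width(5) = 3 bits, so the bits above are known zero.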
1817         unsigned BitsPossiblySet = Known2.countMaxPopulation();
1818         unsigned LowBits = llvm::bit_width(BitsPossiblySet);
1819         Known.Zero.setBitsFrom(LowBits);
1820         // TODO: we could bound KnownOne using the lower bound on the number
1821         // of bits which might be set provided by popcnt KnownOne2.
1822         break;
1823       }
1824       case Intrinsic::fshr:
1825       case Intrinsic::fshl: {
1826         const APInt *SA;
1827         if (!match(I->getOperand(2), m_APInt(SA)))
1828           break;
1829 
1830         // Normalize to funnel shift left.
1831         uint64_t ShiftAmt = SA->urem(BitWidth);
1832         if (II->getIntrinsicID() == Intrinsic::fshr)
1833           ShiftAmt = BitWidth - ShiftAmt;
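             // For example, on i8 fshr(%a, %b, 3) equals fshl(%a, %b, 5): both
             // compute (%a << 5) | (%b lshr 3).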
1834 
1835         KnownBits Known3(BitWidth);
1836         computeKnownBits(I->getOperand(0), DemandedElts, Known2, Depth + 1, Q);
1837         computeKnownBits(I->getOperand(1), DemandedElts, Known3, Depth + 1, Q);
1838 
1839         Known.Zero =
1840             Known2.Zero.shl(ShiftAmt) | Known3.Zero.lshr(BitWidth - ShiftAmt);
1841         Known.One =
1842             Known2.One.shl(ShiftAmt) | Known3.One.lshr(BitWidth - ShiftAmt);
1843         break;
1844       }
1845       case Intrinsic::uadd_sat:
1846         computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
1847         computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
1848         Known = KnownBits::uadd_sat(Known, Known2);
1849         break;
1850       case Intrinsic::usub_sat:
1851         computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
1852         computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
1853         Known = KnownBits::usub_sat(Known, Known2);
1854         break;
1855       case Intrinsic::sadd_sat:
1856         computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
1857         computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
1858         Known = KnownBits::sadd_sat(Known, Known2);
1859         break;
1860       case Intrinsic::ssub_sat:
1861         computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
1862         computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
1863         Known = KnownBits::ssub_sat(Known, Known2);
1864         break;
1865         // Vec reverse preserves bits from input vec.
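             // For example, in a 4-element vector, demanding only output
             // element 0 demands only input element 3.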
1866       case Intrinsic::vector_reverse:
1867         computeKnownBits(I->getOperand(0), DemandedElts.reverseBits(), Known,
1868                          Depth + 1, Q);
1869         break;
1870         // for min/max/and/or reduce, any bit common to each element in the
1871         // input vec is set in the output.
1872       case Intrinsic::vector_reduce_and:
1873       case Intrinsic::vector_reduce_or:
1874       case Intrinsic::vector_reduce_umax:
1875       case Intrinsic::vector_reduce_umin:
1876       case Intrinsic::vector_reduce_smax:
1877       case Intrinsic::vector_reduce_smin:
1878         computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
1879         break;
1880       case Intrinsic::vector_reduce_xor: {
1881         computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
1882         // The zeros common to all elements are zero in the output.
1883         // If the number of elements is odd, then the common ones remain. If the
1884         // number of elements is even, then the common ones become zeros.
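             // For example, if bit 0 is known one in every element, the
             // reduced bit 0 is one for 3 elements (1^1^1) but zero for 4
             // elements (1^1^1^1).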
1885         auto *VecTy = cast<VectorType>(I->getOperand(0)->getType());
1886         // Even, so the ones become zeros.
1887         bool EvenCnt = VecTy->getElementCount().isKnownEven();
1888         if (EvenCnt)
1889           Known.Zero |= Known.One;
1890         // The element count may be even (e.g. scalable), so clear the ones.
1891         if (VecTy->isScalableTy() || EvenCnt)
1892           Known.One.clearAllBits();
1893         break;
1894       }
1895       case Intrinsic::umin:
1896         computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
1897         computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
1898         Known = KnownBits::umin(Known, Known2);
1899         break;
1900       case Intrinsic::umax:
1901         computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
1902         computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
1903         Known = KnownBits::umax(Known, Known2);
1904         break;
1905       case Intrinsic::smin:
1906         computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
1907         computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
1908         Known = KnownBits::smin(Known, Known2);
1909         unionWithMinMaxIntrinsicClamp(II, Known);
1910         break;
1911       case Intrinsic::smax:
1912         computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
1913         computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
1914         Known = KnownBits::smax(Known, Known2);
1915         unionWithMinMaxIntrinsicClamp(II, Known);
1916         break;
1917       case Intrinsic::ptrmask: {
1918         computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
1919 
1920         const Value *Mask = I->getOperand(1);
1921         Known2 = KnownBits(Mask->getType()->getScalarSizeInBits());
1922         computeKnownBits(Mask, DemandedElts, Known2, Depth + 1, Q);
1923         // TODO: 1-extend would be more precise.
1924         Known &= Known2.anyextOrTrunc(BitWidth);
1925         break;
1926       }
1927       case Intrinsic::x86_sse2_pmulh_w:
1928       case Intrinsic::x86_avx2_pmulh_w:
1929       case Intrinsic::x86_avx512_pmulh_w_512:
1930         computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
1931         computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
1932         Known = KnownBits::mulhs(Known, Known2);
1933         break;
1934       case Intrinsic::x86_sse2_pmulhu_w:
1935       case Intrinsic::x86_avx2_pmulhu_w:
1936       case Intrinsic::x86_avx512_pmulhu_w_512:
1937         computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
1938         computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
1939         Known = KnownBits::mulhu(Known, Known2);
1940         break;
1941       case Intrinsic::x86_sse42_crc32_64_64:
1942         Known.Zero.setBitsFrom(32);
1943         break;
1944       case Intrinsic::x86_ssse3_phadd_d_128:
1945       case Intrinsic::x86_ssse3_phadd_w_128:
1946       case Intrinsic::x86_avx2_phadd_d:
1947       case Intrinsic::x86_avx2_phadd_w: {
1948         Known = computeKnownBitsForHorizontalOperation(
1949             I, DemandedElts, Depth, Q,
1950             [](const KnownBits &KnownLHS, const KnownBits &KnownRHS) {
1951               return KnownBits::add(KnownLHS, KnownRHS);
1952             });
1953         break;
1954       }
1955       case Intrinsic::x86_ssse3_phadd_sw_128:
1956       case Intrinsic::x86_avx2_phadd_sw: {
1957         Known = computeKnownBitsForHorizontalOperation(I, DemandedElts, Depth,
1958                                                        Q, KnownBits::sadd_sat);
1959         break;
1960       }
1961       case Intrinsic::x86_ssse3_phsub_d_128:
1962       case Intrinsic::x86_ssse3_phsub_w_128:
1963       case Intrinsic::x86_avx2_phsub_d:
1964       case Intrinsic::x86_avx2_phsub_w: {
1965         Known = computeKnownBitsForHorizontalOperation(
1966             I, DemandedElts, Depth, Q,
1967             [](const KnownBits &KnownLHS, const KnownBits &KnownRHS) {
1968               return KnownBits::sub(KnownLHS, KnownRHS);
1969             });
1970         break;
1971       }
1972       case Intrinsic::x86_ssse3_phsub_sw_128:
1973       case Intrinsic::x86_avx2_phsub_sw: {
1974         Known = computeKnownBitsForHorizontalOperation(I, DemandedElts, Depth,
1975                                                        Q, KnownBits::ssub_sat);
1976         break;
1977       }
1978       case Intrinsic::riscv_vsetvli:
1979       case Intrinsic::riscv_vsetvlimax: {
1980         bool HasAVL = II->getIntrinsicID() == Intrinsic::riscv_vsetvli;
1981         const ConstantRange Range = getVScaleRange(II->getFunction(), BitWidth);
1982         uint64_t SEW = RISCVVType::decodeVSEW(
1983             cast<ConstantInt>(II->getArgOperand(HasAVL))->getZExtValue());
1984         RISCVII::VLMUL VLMUL = static_cast<RISCVII::VLMUL>(
1985             cast<ConstantInt>(II->getArgOperand(1 + HasAVL))->getZExtValue());
1986         uint64_t MaxVLEN =
1987             Range.getUnsignedMax().getZExtValue() * RISCV::RVVBitsPerBlock;
1988         uint64_t MaxVL = MaxVLEN / RISCVVType::getSEWLMULRatio(SEW, VLMUL);
1989 
1990         // The result of vsetvli must not be larger than the AVL.
1991         if (HasAVL)
1992           if (auto *CI = dyn_cast<ConstantInt>(II->getArgOperand(0)))
1993             MaxVL = std::min(MaxVL, CI->getZExtValue());
1994 
1995         unsigned KnownZeroFirstBit = Log2_32(MaxVL) + 1;
1996         if (BitWidth > KnownZeroFirstBit)
1997           Known.Zero.setBitsFrom(KnownZeroFirstBit);
1998         break;
1999       }
2000       case Intrinsic::vscale: {
2001         if (!II->getParent() || !II->getFunction())
2002           break;
2003 
2004         Known = getVScaleRange(II->getFunction(), BitWidth).toKnownBits();
2005         break;
2006       }
2007       }
2008     }
2009     break;
2010   }
2011   case Instruction::ShuffleVector: {
2012     auto *Shuf = dyn_cast<ShuffleVectorInst>(I);
2013     // FIXME: Do we need to handle ConstantExpr involving shufflevectors?
2014     if (!Shuf) {
2015       Known.resetAll();
2016       return;
2017     }
2018     // For undef elements, we don't know anything about the common state of
2019     // the shuffle result.
2020     APInt DemandedLHS, DemandedRHS;
2021     if (!getShuffleDemandedElts(Shuf, DemandedElts, DemandedLHS, DemandedRHS)) {
2022       Known.resetAll();
2023       return;
2024     }
2025     Known.One.setAllBits();
2026     Known.Zero.setAllBits();
2027     if (!!DemandedLHS) {
2028       const Value *LHS = Shuf->getOperand(0);
2029       computeKnownBits(LHS, DemandedLHS, Known, Depth + 1, Q);
2030       // If we don't know any bits, early out.
2031       if (Known.isUnknown())
2032         break;
2033     }
2034     if (!!DemandedRHS) {
2035       const Value *RHS = Shuf->getOperand(1);
2036       computeKnownBits(RHS, DemandedRHS, Known2, Depth + 1, Q);
2037       Known = Known.intersectWith(Known2);
2038     }
2039     break;
2040   }
2041   case Instruction::InsertElement: {
2042     if (isa<ScalableVectorType>(I->getType())) {
2043       Known.resetAll();
2044       return;
2045     }
2046     const Value *Vec = I->getOperand(0);
2047     const Value *Elt = I->getOperand(1);
2048     auto *CIdx = dyn_cast<ConstantInt>(I->getOperand(2));
2049     unsigned NumElts = DemandedElts.getBitWidth();
2050     APInt DemandedVecElts = DemandedElts;
2051     bool NeedsElt = true;
2052     // If we know the index we are inserting to, clear it from the Vec check.
2053     if (CIdx && CIdx->getValue().ult(NumElts)) {
2054       DemandedVecElts.clearBit(CIdx->getZExtValue());
2055       NeedsElt = DemandedElts[CIdx->getZExtValue()];
2056     }
2057 
2058     Known.One.setAllBits();
2059     Known.Zero.setAllBits();
2060     if (NeedsElt) {
2061       computeKnownBits(Elt, Known, Depth + 1, Q);
2062       // If we don't know any bits, early out.
2063       if (Known.isUnknown())
2064         break;
2065     }
2066 
2067     if (!DemandedVecElts.isZero()) {
2068       computeKnownBits(Vec, DemandedVecElts, Known2, Depth + 1, Q);
2069       Known = Known.intersectWith(Known2);
2070     }
2071     break;
2072   }
2073   case Instruction::ExtractElement: {
2074     // Look through extract element. If the index is non-constant or
2075     // out-of-range, demand all elements; otherwise just the extracted element.
2076     const Value *Vec = I->getOperand(0);
2077     const Value *Idx = I->getOperand(1);
2078     auto *CIdx = dyn_cast<ConstantInt>(Idx);
2079     if (isa<ScalableVectorType>(Vec->getType())) {
2080       // FIXME: there's probably *something* we can do with scalable vectors
2081       Known.resetAll();
2082       break;
2083     }
2084     unsigned NumElts = cast<FixedVectorType>(Vec->getType())->getNumElements();
2085     APInt DemandedVecElts = APInt::getAllOnes(NumElts);
2086     if (CIdx && CIdx->getValue().ult(NumElts))
2087       DemandedVecElts = APInt::getOneBitSet(NumElts, CIdx->getZExtValue());
2088     computeKnownBits(Vec, DemandedVecElts, Known, Depth + 1, Q);
2089     break;
2090   }
2091   case Instruction::ExtractValue:
2092     if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I->getOperand(0))) {
2093       const ExtractValueInst *EVI = cast<ExtractValueInst>(I);
2094       if (EVI->getNumIndices() != 1) break;
2095       if (EVI->getIndices()[0] == 0) {
2096         switch (II->getIntrinsicID()) {
2097         default: break;
2098         case Intrinsic::uadd_with_overflow:
2099         case Intrinsic::sadd_with_overflow:
2100           computeKnownBitsAddSub(
2101               true, II->getArgOperand(0), II->getArgOperand(1), /*NSW=*/false,
2102               /* NUW=*/false, DemandedElts, Known, Known2, Depth, Q);
2103           break;
2104         case Intrinsic::usub_with_overflow:
2105         case Intrinsic::ssub_with_overflow:
2106           computeKnownBitsAddSub(
2107               false, II->getArgOperand(0), II->getArgOperand(1), /*NSW=*/false,
2108               /* NUW=*/false, DemandedElts, Known, Known2, Depth, Q);
2109           break;
2110         case Intrinsic::umul_with_overflow:
2111         case Intrinsic::smul_with_overflow:
2112           computeKnownBitsMul(II->getArgOperand(0), II->getArgOperand(1), false,
2113                               false, DemandedElts, Known, Known2, Depth, Q);
2114           break;
2115         }
2116       }
2117     }
2118     break;
2119   case Instruction::Freeze:
2120     if (isGuaranteedNotToBePoison(I->getOperand(0), Q.AC, Q.CxtI, Q.DT,
2121                                   Depth + 1))
2122       computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
2123     break;
2124   }
2125 }
2126 
2127 /// Determine which bits of V are known to be either zero or one and return
2128 /// them.
2129 KnownBits llvm::computeKnownBits(const Value *V, const APInt &DemandedElts,
2130                                  unsigned Depth, const SimplifyQuery &Q) {
2131   KnownBits Known(getBitWidth(V->getType(), Q.DL));
2132   ::computeKnownBits(V, DemandedElts, Known, Depth, Q);
2133   return Known;
2134 }
2135 
2136 /// Determine which bits of V are known to be either zero or one and return
2137 /// them.
2138 KnownBits llvm::computeKnownBits(const Value *V, unsigned Depth,
2139                                  const SimplifyQuery &Q) {
2140   KnownBits Known(getBitWidth(V->getType(), Q.DL));
2141   computeKnownBits(V, Known, Depth, Q);
2142   return Known;
2143 }
2144 
2145 /// Determine which bits of V are known to be either zero or one and return
2146 /// them in the Known bit set.
2147 ///
2148 /// NOTE: we cannot consider 'undef' to be "IsZero" here.  The problem is that
2149 /// we cannot optimize based on the assumption that it is zero without changing
2150 /// it to be an explicit zero.  If we don't change it to zero, other code could
2151 /// be optimized based on the contradictory assumption that it is non-zero.
2152 /// Because instcombine aggressively folds operations with undef args anyway,
2153 /// this won't lose us code quality.
2154 ///
2155 /// This function is defined on values with integer type, values with pointer
2156 /// type, and vectors of integers.  In the case
2157 /// where V is a vector, the known zero and known one values are the
2158 /// same width as the vector element, and the bit is set only if it is true
2159 /// for all of the demanded elements in the vector specified by DemandedElts.
2160 void computeKnownBits(const Value *V, const APInt &DemandedElts,
2161                       KnownBits &Known, unsigned Depth,
2162                       const SimplifyQuery &Q) {
2163   if (!DemandedElts) {
2164     // No demanded elts, better to assume we don't know anything.
2165     Known.resetAll();
2166     return;
2167   }
2168 
2169   assert(V && "No Value?");
2170   assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth");
2171 
2172 #ifndef NDEBUG
2173   Type *Ty = V->getType();
2174   unsigned BitWidth = Known.getBitWidth();
2175 
2176   assert((Ty->isIntOrIntVectorTy(BitWidth) || Ty->isPtrOrPtrVectorTy()) &&
2177          "Not integer or pointer type!");
2178 
2179   if (auto *FVTy = dyn_cast<FixedVectorType>(Ty)) {
2180     assert(
2181         FVTy->getNumElements() == DemandedElts.getBitWidth() &&
2182         "DemandedElt width should equal the fixed vector number of elements");
2183   } else {
2184     assert(DemandedElts == APInt(1, 1) &&
2185            "DemandedElt width should be 1 for scalars or scalable vectors");
2186   }
2187 
2188   Type *ScalarTy = Ty->getScalarType();
2189   if (ScalarTy->isPointerTy()) {
2190     assert(BitWidth == Q.DL.getPointerTypeSizeInBits(ScalarTy) &&
2191            "V and Known should have same BitWidth");
2192   } else {
2193     assert(BitWidth == Q.DL.getTypeSizeInBits(ScalarTy) &&
2194            "V and Known should have same BitWidth");
2195   }
2196 #endif
2197 
2198   const APInt *C;
2199   if (match(V, m_APInt(C))) {
2200     // We know all of the bits for a scalar constant or a splat vector constant!
2201     Known = KnownBits::makeConstant(*C);
2202     return;
2203   }
2204   // Null and aggregate-zero are all-zeros.
2205   if (isa<ConstantPointerNull>(V) || isa<ConstantAggregateZero>(V)) {
2206     Known.setAllZero();
2207     return;
2208   }
2209   // Handle a constant vector by taking the intersection of the known bits of
2210   // each element.
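       // For example (illustrative), for elements 0b0101 and 0b0100 the result
       // has bit 2 known one, bits 1 and 3 known zero, and bit 0 unknown.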
2211   if (const ConstantDataVector *CDV = dyn_cast<ConstantDataVector>(V)) {
2212     assert(!isa<ScalableVectorType>(V->getType()));
2213     // We know that CDV must be a vector of integers. Take the intersection of
2214     // each element.
2215     Known.Zero.setAllBits(); Known.One.setAllBits();
2216     for (unsigned i = 0, e = CDV->getNumElements(); i != e; ++i) {
2217       if (!DemandedElts[i])
2218         continue;
2219       APInt Elt = CDV->getElementAsAPInt(i);
2220       Known.Zero &= ~Elt;
2221       Known.One &= Elt;
2222     }
2223     if (Known.hasConflict())
2224       Known.resetAll();
2225     return;
2226   }
2227 
2228   if (const auto *CV = dyn_cast<ConstantVector>(V)) {
2229     assert(!isa<ScalableVectorType>(V->getType()));
2230     // We know that CV must be a vector of integers. Take the intersection of
2231     // each element.
2232     Known.Zero.setAllBits(); Known.One.setAllBits();
2233     for (unsigned i = 0, e = CV->getNumOperands(); i != e; ++i) {
2234       if (!DemandedElts[i])
2235         continue;
2236       Constant *Element = CV->getAggregateElement(i);
2237       if (isa<PoisonValue>(Element))
2238         continue;
2239       auto *ElementCI = dyn_cast_or_null<ConstantInt>(Element);
2240       if (!ElementCI) {
2241         Known.resetAll();
2242         return;
2243       }
2244       const APInt &Elt = ElementCI->getValue();
2245       Known.Zero &= ~Elt;
2246       Known.One &= Elt;
2247     }
2248     if (Known.hasConflict())
2249       Known.resetAll();
2250     return;
2251   }
2252 
2253   // Start out not knowing anything.
2254   Known.resetAll();
2255 
2256   // We can't imply anything about undefs.
2257   if (isa<UndefValue>(V))
2258     return;
2259 
2260   // There's no point in looking through other users of ConstantData for
2261   // assumptions.  Confirm that we've handled them all.
2262   assert(!isa<ConstantData>(V) && "Unhandled constant data!");
2263 
2264   if (const auto *A = dyn_cast<Argument>(V))
2265     if (std::optional<ConstantRange> Range = A->getRange())
2266       Known = Range->toKnownBits();
2267 
2268   // All recursive calls that increase depth must come after this.
2269   if (Depth == MaxAnalysisRecursionDepth)
2270     return;
2271 
2272   // A weak GlobalAlias is totally unknown. A non-weak GlobalAlias has
2273   // the bits of its aliasee.
2274   if (const GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) {
2275     if (!GA->isInterposable())
2276       computeKnownBits(GA->getAliasee(), Known, Depth + 1, Q);
2277     return;
2278   }
2279 
2280   if (const Operator *I = dyn_cast<Operator>(V))
2281     computeKnownBitsFromOperator(I, DemandedElts, Known, Depth, Q);
2282   else if (const GlobalValue *GV = dyn_cast<GlobalValue>(V)) {
2283     if (std::optional<ConstantRange> CR = GV->getAbsoluteSymbolRange())
2284       Known = CR->toKnownBits();
2285   }
2286 
2287   // Aligned pointers have trailing zeros - refine Known.Zero set
2288   if (isa<PointerType>(V->getType())) {
2289     Align Alignment = V->getPointerAlignment(Q.DL);
2290     Known.Zero.setLowBits(Log2(Alignment));
2291   }
2292 
2293   // computeKnownBitsFromContext strictly refines Known.
2294   // Therefore, we run them after computeKnownBitsFromOperator.
2295 
2296   // Check whether we can determine known bits from context such as assumes.
2297   computeKnownBitsFromContext(V, Known, Depth, Q);
2298 }
2299 
2300 /// Try to detect a recurrence in which the value of the induction variable is
2301 /// always a power of two (or zero).
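     /// A hypothetical example: %p = phi i32 [ 4, %entry ], [ %p.next, %loop ]
     /// with %p.next = shl nuw i32 %p, 1 is recognized as always a power of two.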
2302 static bool isPowerOfTwoRecurrence(const PHINode *PN, bool OrZero,
2303                                    unsigned Depth, SimplifyQuery &Q) {
2304   BinaryOperator *BO = nullptr;
2305   Value *Start = nullptr, *Step = nullptr;
2306   if (!matchSimpleRecurrence(PN, BO, Start, Step))
2307     return false;
2308 
2309   // Initial value must be a power of two.
2310   for (const Use &U : PN->operands()) {
2311     if (U.get() == Start) {
2312       // Initial value comes from a different BB, need to adjust context
2313       // instruction for analysis.
2314       Q.CxtI = PN->getIncomingBlock(U)->getTerminator();
2315       if (!isKnownToBeAPowerOfTwo(Start, OrZero, Depth, Q))
2316         return false;
2317     }
2318   }
2319 
2320   // Except for Mul, the induction variable must be on the left side of the
2321   // increment expression, otherwise its value can be arbitrary.
2322   if (BO->getOpcode() != Instruction::Mul && BO->getOperand(1) != Step)
2323     return false;
2324 
2325   Q.CxtI = BO->getParent()->getTerminator();
2326   switch (BO->getOpcode()) {
2327   case Instruction::Mul:
2328     // Power of two is closed under multiplication.
2329     return (OrZero || Q.IIQ.hasNoUnsignedWrap(BO) ||
2330             Q.IIQ.hasNoSignedWrap(BO)) &&
2331            isKnownToBeAPowerOfTwo(Step, OrZero, Depth, Q);
2332   case Instruction::SDiv:
2333     // Start value must not be signmask for signed division, so simply being a
2334     // power of two is not sufficient, and it has to be a constant.
2335     if (!match(Start, m_Power2()) || match(Start, m_SignMask()))
2336       return false;
2337     [[fallthrough]];
2338   case Instruction::UDiv:
2339     // Divisor must be a power of two.
2340     // If OrZero is false, cannot guarantee induction variable is non-zero after
2341     // division, same for Shr, unless it is exact division.
2342     return (OrZero || Q.IIQ.isExact(BO)) &&
2343            isKnownToBeAPowerOfTwo(Step, false, Depth, Q);
2344   case Instruction::Shl:
2345     return OrZero || Q.IIQ.hasNoUnsignedWrap(BO) || Q.IIQ.hasNoSignedWrap(BO);
2346   case Instruction::AShr:
2347     if (!match(Start, m_Power2()) || match(Start, m_SignMask()))
2348       return false;
2349     [[fallthrough]];
2350   case Instruction::LShr:
2351     return OrZero || Q.IIQ.isExact(BO);
2352   default:
2353     return false;
2354   }
2355 }
2356 
2357 /// Return true if we can infer that \p V is known to be a power of 2 from
2358 /// dominating condition \p Cond (e.g., ctpop(V) == 1).
2359 static bool isImpliedToBeAPowerOfTwoFromCond(const Value *V, bool OrZero,
2360                                              const Value *Cond,
2361                                              bool CondIsTrue) {
2362   CmpPredicate Pred;
2363   const APInt *RHSC;
2364   if (!match(Cond, m_ICmp(Pred, m_Intrinsic<Intrinsic::ctpop>(m_Specific(V)),
2365                           m_APInt(RHSC))))
2366     return false;
2367   if (!CondIsTrue)
2368     Pred = ICmpInst::getInversePredicate(Pred);
2369   // ctpop(V) u< 2
2370   if (OrZero && Pred == ICmpInst::ICMP_ULT && *RHSC == 2)
2371     return true;
2372   // ctpop(V) == 1
2373   return Pred == ICmpInst::ICMP_EQ && *RHSC == 1;
2374 }
2375 
2376 /// Return true if the given value is known to have exactly one
2377 /// bit set when defined. For vectors return true if every element is known to
2378 /// be a power of two when defined. Supports values with integer or pointer
2379 /// types and vectors of integers.
2380 bool llvm::isKnownToBeAPowerOfTwo(const Value *V, bool OrZero, unsigned Depth,
2381                                   const SimplifyQuery &Q) {
2382   assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth");
2383 
2384   if (isa<Constant>(V))
2385     return OrZero ? match(V, m_Power2OrZero()) : match(V, m_Power2());
2386 
2387   // i1 is by definition a power of 2 or zero.
2388   if (OrZero && V->getType()->getScalarSizeInBits() == 1)
2389     return true;
2390 
2391   // Try to infer from assumptions.
2392   if (Q.AC && Q.CxtI) {
2393     for (auto &AssumeVH : Q.AC->assumptionsFor(V)) {
2394       if (!AssumeVH)
2395         continue;
2396       CallInst *I = cast<CallInst>(AssumeVH);
2397       if (isImpliedToBeAPowerOfTwoFromCond(V, OrZero, I->getArgOperand(0),
2398                                            /*CondIsTrue=*/true) &&
2399           isValidAssumeForContext(I, Q.CxtI, Q.DT))
2400         return true;
2401     }
2402   }
2403 
2404   // Handle dominating conditions.
2405   if (Q.DC && Q.CxtI && Q.DT) {
2406     for (BranchInst *BI : Q.DC->conditionsFor(V)) {
2407       Value *Cond = BI->getCondition();
2408 
2409       BasicBlockEdge Edge0(BI->getParent(), BI->getSuccessor(0));
2410       if (isImpliedToBeAPowerOfTwoFromCond(V, OrZero, Cond,
2411                                            /*CondIsTrue=*/true) &&
2412           Q.DT->dominates(Edge0, Q.CxtI->getParent()))
2413         return true;
2414 
2415       BasicBlockEdge Edge1(BI->getParent(), BI->getSuccessor(1));
2416       if (isImpliedToBeAPowerOfTwoFromCond(V, OrZero, Cond,
2417                                            /*CondIsTrue=*/false) &&
2418           Q.DT->dominates(Edge1, Q.CxtI->getParent()))
2419         return true;
2420     }
2421   }
2422 
2423   auto *I = dyn_cast<Instruction>(V);
2424   if (!I)
2425     return false;
2426 
2427   if (Q.CxtI && match(V, m_VScale())) {
2428     const Function *F = Q.CxtI->getFunction();
2429     // The vscale_range indicates vscale is a power-of-two.
2430     return F->hasFnAttribute(Attribute::VScaleRange);
2431   }
2432 
2433   // 1 << X is clearly a power of two if the one is not shifted off the end.  If
2434   // it is shifted off the end then the result is undefined.
2435   if (match(I, m_Shl(m_One(), m_Value())))
2436     return true;
2437 
2438   // (signmask) >>l X is clearly a power of two if the one is not shifted off
2439   // the bottom.  If it is shifted off the bottom then the result is undefined.
2440   if (match(I, m_LShr(m_SignMask(), m_Value())))
2441     return true;
2442 
2443   // The remaining tests are all recursive, so bail out if we hit the limit.
2444   if (Depth++ == MaxAnalysisRecursionDepth)
2445     return false;
2446 
2447   switch (I->getOpcode()) {
2448   case Instruction::ZExt:
2449     return isKnownToBeAPowerOfTwo(I->getOperand(0), OrZero, Depth, Q);
2450   case Instruction::Trunc:
2451     return OrZero && isKnownToBeAPowerOfTwo(I->getOperand(0), OrZero, Depth, Q);
2452   case Instruction::Shl:
2453     if (OrZero || Q.IIQ.hasNoUnsignedWrap(I) || Q.IIQ.hasNoSignedWrap(I))
2454       return isKnownToBeAPowerOfTwo(I->getOperand(0), OrZero, Depth, Q);
2455     return false;
2456   case Instruction::LShr:
2457     if (OrZero || Q.IIQ.isExact(cast<BinaryOperator>(I)))
2458       return isKnownToBeAPowerOfTwo(I->getOperand(0), OrZero, Depth, Q);
2459     return false;
2460   case Instruction::UDiv:
2461     if (Q.IIQ.isExact(cast<BinaryOperator>(I)))
2462       return isKnownToBeAPowerOfTwo(I->getOperand(0), OrZero, Depth, Q);
2463     return false;
2464   case Instruction::Mul:
2465     return isKnownToBeAPowerOfTwo(I->getOperand(1), OrZero, Depth, Q) &&
2466            isKnownToBeAPowerOfTwo(I->getOperand(0), OrZero, Depth, Q) &&
2467            (OrZero || isKnownNonZero(I, Q, Depth));
2468   case Instruction::And:
2469     // A power of two and'd with anything is a power of two or zero.
2470     if (OrZero &&
2471         (isKnownToBeAPowerOfTwo(I->getOperand(1), /*OrZero*/ true, Depth, Q) ||
2472          isKnownToBeAPowerOfTwo(I->getOperand(0), /*OrZero*/ true, Depth, Q)))
2473       return true;
2474     // X & (-X) is always a power of two or zero.
2475     if (match(I->getOperand(0), m_Neg(m_Specific(I->getOperand(1)))) ||
2476         match(I->getOperand(1), m_Neg(m_Specific(I->getOperand(0)))))
2477       return OrZero || isKnownNonZero(I->getOperand(0), Q, Depth);
2478     return false;
2479   case Instruction::Add: {
2480     // Adding a power-of-two or zero to the same power-of-two or zero yields
2481     // either the original power-of-two, a larger power-of-two or zero.
2482     const OverflowingBinaryOperator *VOBO = cast<OverflowingBinaryOperator>(V);
2483     if (OrZero || Q.IIQ.hasNoUnsignedWrap(VOBO) ||
2484         Q.IIQ.hasNoSignedWrap(VOBO)) {
2485       if (match(I->getOperand(0),
2486                 m_c_And(m_Specific(I->getOperand(1)), m_Value())) &&
2487           isKnownToBeAPowerOfTwo(I->getOperand(1), OrZero, Depth, Q))
2488         return true;
2489       if (match(I->getOperand(1),
2490                 m_c_And(m_Specific(I->getOperand(0)), m_Value())) &&
2491           isKnownToBeAPowerOfTwo(I->getOperand(0), OrZero, Depth, Q))
2492         return true;
2493 
2494       unsigned BitWidth = V->getType()->getScalarSizeInBits();
2495       KnownBits LHSBits(BitWidth);
2496       computeKnownBits(I->getOperand(0), LHSBits, Depth, Q);
2497 
2498       KnownBits RHSBits(BitWidth);
2499       computeKnownBits(I->getOperand(1), RHSBits, Depth, Q);
2500       // If i8 V is a power of two or zero:
2501       //  ZeroBits: 1 1 1 0 1 1 1 1
2502       // ~ZeroBits: 0 0 0 1 0 0 0 0
2503       if ((~(LHSBits.Zero & RHSBits.Zero)).isPowerOf2())
2504         // If OrZero isn't set, we cannot give back a zero result.
2505         // Make sure either the LHS or RHS has a bit set.
2506         if (OrZero || RHSBits.One.getBoolValue() || LHSBits.One.getBoolValue())
2507           return true;
2508     }
2509 
2510     // LShr(UINT_MAX, Y) + 1 is a power of two (if add is nuw) or zero.
2511     if (OrZero || Q.IIQ.hasNoUnsignedWrap(VOBO))
2512       if (match(I, m_Add(m_LShr(m_AllOnes(), m_Value()), m_One())))
2513         return true;
2514     return false;
2515   }
2516   case Instruction::Select:
2517     return isKnownToBeAPowerOfTwo(I->getOperand(1), OrZero, Depth, Q) &&
2518            isKnownToBeAPowerOfTwo(I->getOperand(2), OrZero, Depth, Q);
2519   case Instruction::PHI: {
2520     // A PHI node is a power of two if all incoming values are powers of two, or if
2521     // it is an induction variable where in each step its value is a power of
2522     // two.
2523     auto *PN = cast<PHINode>(I);
2524     SimplifyQuery RecQ = Q.getWithoutCondContext();
2525 
2526     // Check if it is an induction variable and always power of two.
2527     if (isPowerOfTwoRecurrence(PN, OrZero, Depth, RecQ))
2528       return true;
2529 
2530     // Recursively check all incoming values. Limit recursion to 2 levels, so
2531     // that search complexity is limited to number of operands^2.
2532     unsigned NewDepth = std::max(Depth, MaxAnalysisRecursionDepth - 1);
2533     return llvm::all_of(PN->operands(), [&](const Use &U) {
2534       // Value is power of 2 if it is coming from PHI node itself by induction.
2535       if (U.get() == PN)
2536         return true;
2537 
2538       // Change the context instruction to the incoming block where it is
2539       // evaluated.
2540       RecQ.CxtI = PN->getIncomingBlock(U)->getTerminator();
2541       return isKnownToBeAPowerOfTwo(U.get(), OrZero, NewDepth, RecQ);
2542     });
2543   }
2544   case Instruction::Invoke:
2545   case Instruction::Call: {
2546     if (auto *II = dyn_cast<IntrinsicInst>(I)) {
2547       switch (II->getIntrinsicID()) {
2548       case Intrinsic::umax:
2549       case Intrinsic::smax:
2550       case Intrinsic::umin:
2551       case Intrinsic::smin:
2552         return isKnownToBeAPowerOfTwo(II->getArgOperand(1), OrZero, Depth, Q) &&
2553                isKnownToBeAPowerOfTwo(II->getArgOperand(0), OrZero, Depth, Q);
2554       // bswap/bitreverse just move around bits, but don't change any 1s/0s,
2555       // thus they don't change pow2/non-pow2 status.
2556       case Intrinsic::bitreverse:
2557       case Intrinsic::bswap:
2558         return isKnownToBeAPowerOfTwo(II->getArgOperand(0), OrZero, Depth, Q);
2559       case Intrinsic::fshr:
2560       case Intrinsic::fshl:
2561         // If Op0 == Op1, this is a rotate. is_pow2(rotate(x, y)) == is_pow2(x)
2562         if (II->getArgOperand(0) == II->getArgOperand(1))
2563           return isKnownToBeAPowerOfTwo(II->getArgOperand(0), OrZero, Depth, Q);
2564         break;
2565       default:
2566         break;
2567       }
2568     }
2569     return false;
2570   }
2571   default:
2572     return false;
2573   }
2574 }
2575 
2576 /// Test whether a GEP's result is known to be non-null.
2577 ///
2578 /// Uses properties inherent in a GEP to try to determine whether it is known
2579 /// to be non-null.
2580 ///
2581 /// Currently this routine does not support vector GEPs.
2582 static bool isGEPKnownNonNull(const GEPOperator *GEP, unsigned Depth,
2583                               const SimplifyQuery &Q) {
2584   const Function *F = nullptr;
2585   if (const Instruction *I = dyn_cast<Instruction>(GEP))
2586     F = I->getFunction();
2587 
2588   // If the gep is nuw, or is inbounds with the null pointer invalid, then the
2589   // GEP can be null only if the base pointer is null and the offset is zero.
2590   if (!GEP->hasNoUnsignedWrap() &&
2591       !(GEP->isInBounds() &&
2592         !NullPointerIsDefined(F, GEP->getPointerAddressSpace())))
2593     return false;
2594 
2595   // FIXME: Support vector-GEPs.
2596   assert(GEP->getType()->isPointerTy() && "We only support plain pointer GEP");
2597 
2598   // If the base pointer is non-null, we cannot walk to a null address with an
2599   // inbounds GEP in address space zero.
2600   if (isKnownNonZero(GEP->getPointerOperand(), Q, Depth))
2601     return true;
2602 
2603   // Walk the GEP operands and see if any operand introduces a non-zero offset.
2604   // If so, then the GEP cannot produce a null pointer, as doing so would
2605   // inherently violate the inbounds contract within address space zero.
2606   for (gep_type_iterator GTI = gep_type_begin(GEP), GTE = gep_type_end(GEP);
2607        GTI != GTE; ++GTI) {
2608     // Struct types are easy -- they must always be indexed by a constant.
2609     if (StructType *STy = GTI.getStructTypeOrNull()) {
2610       ConstantInt *OpC = cast<ConstantInt>(GTI.getOperand());
2611       unsigned ElementIdx = OpC->getZExtValue();
2612       const StructLayout *SL = Q.DL.getStructLayout(STy);
2613       uint64_t ElementOffset = SL->getElementOffset(ElementIdx);
2614       if (ElementOffset > 0)
2615         return true;
2616       continue;
2617     }
2618 
2619     // If we have a zero-sized type, the index doesn't matter. Keep looping.
2620     if (GTI.getSequentialElementStride(Q.DL).isZero())
2621       continue;
2622 
2623     // Fast path the constant operand case both for efficiency and so we don't
2624     // increment Depth when just zipping down an all-constant GEP.
2625     if (ConstantInt *OpC = dyn_cast<ConstantInt>(GTI.getOperand())) {
2626       if (!OpC->isZero())
2627         return true;
2628       continue;
2629     }
2630 
2631     // We post-increment Depth here because while isKnownNonZero increments it
2632     // as well, when we pop back up that increment won't persist. We don't want
2633     // to recurse 10k times just because we have 10k GEP operands. We don't
2634     // bail completely out because we want to handle constant GEPs regardless
2635     // of depth.
2636     if (Depth++ >= MaxAnalysisRecursionDepth)
2637       continue;
2638 
2639     if (isKnownNonZero(GTI.getOperand(), Q, Depth))
2640       return true;
2641   }
2642 
2643   return false;
2644 }
2645 
2646 static bool isKnownNonNullFromDominatingCondition(const Value *V,
2647                                                   const Instruction *CtxI,
2648                                                   const DominatorTree *DT) {
2649   assert(!isa<Constant>(V) && "Called for constant?");
2650 
2651   if (!CtxI || !DT)
2652     return false;
2653 
2654   unsigned NumUsesExplored = 0;
2655   for (auto &U : V->uses()) {
2656     // Avoid massive lists
2657     if (NumUsesExplored >= DomConditionsMaxUses)
2658       break;
2659     NumUsesExplored++;
2660 
2661     const Instruction *UI = cast<Instruction>(U.getUser());
2662     // If the value is used as an argument to a call or invoke, then argument
2663     // attributes may provide an answer about null-ness.
2664     if (V->getType()->isPointerTy()) {
2665       if (const auto *CB = dyn_cast<CallBase>(UI)) {
2666         if (CB->isArgOperand(&U) &&
2667             CB->paramHasNonNullAttr(CB->getArgOperandNo(&U),
2668                                     /*AllowUndefOrPoison=*/false) &&
2669             DT->dominates(CB, CtxI))
2670           return true;
2671       }
2672     }
2673 
2674     // If the value is used as a load/store pointer operand, it must be non-null.
2675     if (V == getLoadStorePointerOperand(UI)) {
2676       if (!NullPointerIsDefined(UI->getFunction(),
2677                                 V->getType()->getPointerAddressSpace()) &&
2678           DT->dominates(UI, CtxI))
2679         return true;
2680     }
2681 
2682     if ((match(UI, m_IDiv(m_Value(), m_Specific(V))) ||
2683          match(UI, m_IRem(m_Value(), m_Specific(V)))) &&
2684         isValidAssumeForContext(UI, CtxI, DT))
2685       return true;
2686 
2687     // Consider only compare instructions uniquely controlling a branch
2688     Value *RHS;
2689     CmpPredicate Pred;
2690     if (!match(UI, m_c_ICmp(Pred, m_Specific(V), m_Value(RHS))))
2691       continue;
2692 
2693     bool NonNullIfTrue;
2694     if (cmpExcludesZero(Pred, RHS))
2695       NonNullIfTrue = true;
2696     else if (cmpExcludesZero(CmpInst::getInversePredicate(Pred), RHS))
2697       NonNullIfTrue = false;
2698     else
2699       continue;
2700 
2701     SmallVector<const User *, 4> WorkList;
2702     SmallPtrSet<const User *, 4> Visited;
2703     for (const auto *CmpU : UI->users()) {
2704       assert(WorkList.empty() && "Should be!");
2705       if (Visited.insert(CmpU).second)
2706         WorkList.push_back(CmpU);
2707 
2708       while (!WorkList.empty()) {
2709         auto *Curr = WorkList.pop_back_val();
2710 
2711         // If a user is an AND, add all its users to the work list. We only
2712         // propagate the "pred != null" condition through AND because it is only
2713         // correct to assume that all conditions of an AND are met in the true
2714         // branch. TODO: Support similar logic for OR and the EQ predicate?
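             // For example (illustrative IR): if %and = and i1 (icmp ne ptr %p, null), %c
             // feeds a conditional branch, %p is known non-null in the true
             // successor, since both operands of the AND must be true there.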
2715         if (NonNullIfTrue)
2716           if (match(Curr, m_LogicalAnd(m_Value(), m_Value()))) {
2717             for (const auto *CurrU : Curr->users())
2718               if (Visited.insert(CurrU).second)
2719                 WorkList.push_back(CurrU);
2720             continue;
2721           }
2722 
2723         if (const BranchInst *BI = dyn_cast<BranchInst>(Curr)) {
2724           assert(BI->isConditional() && "uses a comparison!");
2725 
2726           BasicBlock *NonNullSuccessor =
2727               BI->getSuccessor(NonNullIfTrue ? 0 : 1);
2728           BasicBlockEdge Edge(BI->getParent(), NonNullSuccessor);
2729           if (Edge.isSingleEdge() && DT->dominates(Edge, CtxI->getParent()))
2730             return true;
2731         } else if (NonNullIfTrue && isGuard(Curr) &&
2732                    DT->dominates(cast<Instruction>(Curr), CtxI)) {
2733           return true;
2734         }
2735       }
2736     }
2737   }
2738 
2739   return false;
2740 }
2741 
2742 /// Does the 'Range' metadata (which must be a valid MD_range operand list)
2743 /// ensure that the value it's attached to is never equal to the given
2744 /// 'Value'?
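     /// For example, !range metadata of { i64 1, i64 256 } describes the half-open
     /// interval [1, 256), so it excludes a Value of 0 but not a Value of 1.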
2745 static bool rangeMetadataExcludesValue(const MDNode* Ranges, const APInt& Value) {
2746   const unsigned NumRanges = Ranges->getNumOperands() / 2;
2747   assert(NumRanges >= 1);
2748   for (unsigned i = 0; i < NumRanges; ++i) {
2749     ConstantInt *Lower =
2750         mdconst::extract<ConstantInt>(Ranges->getOperand(2 * i + 0));
2751     ConstantInt *Upper =
2752         mdconst::extract<ConstantInt>(Ranges->getOperand(2 * i + 1));
2753     ConstantRange Range(Lower->getValue(), Upper->getValue());
2754     if (Range.contains(Value))
2755       return false;
2756   }
2757   return true;
2758 }
2759 
2760 /// Try to detect a recurrence that monotonically increases/decreases from a
2761 /// non-zero starting value. These are common as induction variables.
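     /// For example (illustrative IR): %iv = phi i32 [ 1, %entry ], [ %iv.next, %loop ]
     /// with %iv.next = add nuw i32 %iv, 4 starts at a non-zero value and can never
     /// wrap back to zero, so %iv is known non-zero.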
2762 static bool isNonZeroRecurrence(const PHINode *PN) {
2763   BinaryOperator *BO = nullptr;
2764   Value *Start = nullptr, *Step = nullptr;
2765   const APInt *StartC, *StepC;
2766   if (!matchSimpleRecurrence(PN, BO, Start, Step) ||
2767       !match(Start, m_APInt(StartC)) || StartC->isZero())
2768     return false;
2769 
2770   switch (BO->getOpcode()) {
2771   case Instruction::Add:
2772     // Starting from non-zero and stepping away from zero can never wrap back
2773     // to zero.
2774     return BO->hasNoUnsignedWrap() ||
2775            (BO->hasNoSignedWrap() && match(Step, m_APInt(StepC)) &&
2776             StartC->isNegative() == StepC->isNegative());
2777   case Instruction::Mul:
2778     return (BO->hasNoUnsignedWrap() || BO->hasNoSignedWrap()) &&
2779            match(Step, m_APInt(StepC)) && !StepC->isZero();
2780   case Instruction::Shl:
2781     return BO->hasNoUnsignedWrap() || BO->hasNoSignedWrap();
2782   case Instruction::AShr:
2783   case Instruction::LShr:
2784     return BO->isExact();
2785   default:
2786     return false;
2787   }
2788 }
2789 
2790 static bool matchOpWithOpEqZero(Value *Op0, Value *Op1) {
2791   return match(Op0, m_ZExtOrSExt(m_SpecificICmp(ICmpInst::ICMP_EQ,
2792                                                 m_Specific(Op1), m_Zero()))) ||
2793          match(Op1, m_ZExtOrSExt(m_SpecificICmp(ICmpInst::ICMP_EQ,
2794                                                 m_Specific(Op0), m_Zero())));
2795 }
2796 
2797 static bool isNonZeroAdd(const APInt &DemandedElts, unsigned Depth,
2798                          const SimplifyQuery &Q, unsigned BitWidth, Value *X,
2799                          Value *Y, bool NSW, bool NUW) {
2800   // (X + (X == 0)) is non zero
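       // E.g. for i8: if X == 0 the zext/sext of the compare is 1 or -1, so the sum
       // is non-zero; otherwise the compare contributes 0 and the sum is X itself.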
2801   if (matchOpWithOpEqZero(X, Y))
2802     return true;
2803 
2804   if (NUW)
2805     return isKnownNonZero(Y, DemandedElts, Q, Depth) ||
2806            isKnownNonZero(X, DemandedElts, Q, Depth);
2807 
2808   KnownBits XKnown = computeKnownBits(X, DemandedElts, Depth, Q);
2809   KnownBits YKnown = computeKnownBits(Y, DemandedElts, Depth, Q);
2810 
2811   // If X and Y are both non-negative (as signed values) then their sum is not
2812   // zero unless both X and Y are zero.
2813   if (XKnown.isNonNegative() && YKnown.isNonNegative())
2814     if (isKnownNonZero(Y, DemandedElts, Q, Depth) ||
2815         isKnownNonZero(X, DemandedElts, Q, Depth))
2816       return true;
2817 
2818   // If X and Y are both negative (as signed values) then their sum is not
2819   // zero unless both X and Y equal INT_MIN.
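       // (X + Y == 0 would require X == -Y, and the negation of a negative value is
       // non-negative unless that value is INT_MIN.)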
2820   if (XKnown.isNegative() && YKnown.isNegative()) {
2821     APInt Mask = APInt::getSignedMaxValue(BitWidth);
2822     // The sign bit of X is set.  If some other bit is set then X is not equal
2823     // to INT_MIN.
2824     if (XKnown.One.intersects(Mask))
2825       return true;
2826     // The sign bit of Y is set.  If some other bit is set then Y is not equal
2827     // to INT_MIN.
2828     if (YKnown.One.intersects(Mask))
2829       return true;
2830   }
2831 
2832   // The sum of a non-negative number and a power of two is not zero.
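       // (X + Y == 0 would require X == -Y; for Y == 2^K the value -Y has its sign
       // bit set, contradicting X being non-negative.)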
2833   if (XKnown.isNonNegative() &&
2834       isKnownToBeAPowerOfTwo(Y, /*OrZero*/ false, Depth, Q))
2835     return true;
2836   if (YKnown.isNonNegative() &&
2837       isKnownToBeAPowerOfTwo(X, /*OrZero*/ false, Depth, Q))
2838     return true;
2839 
2840   return KnownBits::add(XKnown, YKnown, NSW, NUW).isNonZero();
2841 }
2842 
2843 static bool isNonZeroSub(const APInt &DemandedElts, unsigned Depth,
2844                          const SimplifyQuery &Q, unsigned BitWidth, Value *X,
2845                          Value *Y) {
2846   // (X - (X == 0)) is non zero
2847   // ((X == 0) - X) is non zero
2848   if (matchOpWithOpEqZero(X, Y))
2849     return true;
2850 
2851   // TODO: Move this case into isKnownNonEqual().
2852   if (auto *C = dyn_cast<Constant>(X))
2853     if (C->isNullValue() && isKnownNonZero(Y, DemandedElts, Q, Depth))
2854       return true;
2855 
2856   return ::isKnownNonEqual(X, Y, DemandedElts, Depth, Q);
2857 }
2858 
2859 static bool isNonZeroMul(const APInt &DemandedElts, unsigned Depth,
2860                          const SimplifyQuery &Q, unsigned BitWidth, Value *X,
2861                          Value *Y, bool NSW, bool NUW) {
2862   // If X and Y are non-zero then so is X * Y as long as the multiplication
2863   // does not overflow.
2864   if (NSW || NUW)
2865     return isKnownNonZero(X, DemandedElts, Q, Depth) &&
2866            isKnownNonZero(Y, DemandedElts, Q, Depth);
2867 
2868   // If either X or Y is odd, then if the other is non-zero the result can't
2869   // be zero.
2870   KnownBits XKnown = computeKnownBits(X, DemandedElts, Depth, Q);
2871   if (XKnown.One[0])
2872     return isKnownNonZero(Y, DemandedElts, Q, Depth);
2873 
2874   KnownBits YKnown = computeKnownBits(Y, DemandedElts, Depth, Q);
2875   if (YKnown.One[0])
2876     return XKnown.isNonZero() || isKnownNonZero(X, DemandedElts, Q, Depth);
2877 
2878   // If there exists any subset of X (sX) and subset of Y (sY) s.t. sX * sY is
2879   // non-zero, then X * Y is non-zero. We can find sX and sY by just taking
2880   // the lowest known one bit of X and Y. If their product is non-zero, the
2881   // result must be non-zero. We can check that LSB(X) * LSB(Y) != 0 by checking
2882   // X.countMaxTrailingZeros() + Y.countMaxTrailingZeros() < BitWidth.
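       // E.g. for i8: if bit 2 of X and bit 3 of Y are known to be one, the lowest
       // set bit of X * Y is at position at most 2 + 3 = 5 < 8, so the product is
       // non-zero.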
2883   return (XKnown.countMaxTrailingZeros() + YKnown.countMaxTrailingZeros()) <
2884          BitWidth;
2885 }
2886 
2887 static bool isNonZeroShift(const Operator *I, const APInt &DemandedElts,
2888                            unsigned Depth, const SimplifyQuery &Q,
2889                            const KnownBits &KnownVal) {
2890   auto ShiftOp = [&](const APInt &Lhs, const APInt &Rhs) {
2891     switch (I->getOpcode()) {
2892     case Instruction::Shl:
2893       return Lhs.shl(Rhs);
2894     case Instruction::LShr:
2895       return Lhs.lshr(Rhs);
2896     case Instruction::AShr:
2897       return Lhs.ashr(Rhs);
2898     default:
2899       llvm_unreachable("Unknown Shift Opcode");
2900     }
2901   };
2902 
2903   auto InvShiftOp = [&](const APInt &Lhs, const APInt &Rhs) {
2904     switch (I->getOpcode()) {
2905     case Instruction::Shl:
2906       return Lhs.lshr(Rhs);
2907     case Instruction::LShr:
2908     case Instruction::AShr:
2909       return Lhs.shl(Rhs);
2910     default:
2911       llvm_unreachable("Unknown Shift Opcode");
2912     }
2913   };
2914 
2915   if (KnownVal.isUnknown())
2916     return false;
2917 
2918   KnownBits KnownCnt =
2919       computeKnownBits(I->getOperand(1), DemandedElts, Depth, Q);
2920   APInt MaxShift = KnownCnt.getMaxValue();
2921   unsigned NumBits = KnownVal.getBitWidth();
2922   if (MaxShift.uge(NumBits))
2923     return false;
2924 
2925   if (!ShiftOp(KnownVal.One, MaxShift).isZero())
2926     return true;
2927 
2928   // If all of the bits shifted out are known to be zero, and Val is known
2929   // non-zero then at least one non-zero bit must remain.
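       // E.g. for lshr i8 %x, %y with %y known to be at most 3: if the low 3 bits of
       // %x are known zero and %x is known non-zero, no set bit is shifted out, so at
       // least one set bit survives the shift.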
2930   if (InvShiftOp(KnownVal.Zero, NumBits - MaxShift)
2931           .eq(InvShiftOp(APInt::getAllOnes(NumBits), NumBits - MaxShift)) &&
2932       isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth))
2933     return true;
2934 
2935   return false;
2936 }
2937 
2938 static bool isKnownNonZeroFromOperator(const Operator *I,
2939                                        const APInt &DemandedElts,
2940                                        unsigned Depth, const SimplifyQuery &Q) {
2941   unsigned BitWidth = getBitWidth(I->getType()->getScalarType(), Q.DL);
2942   switch (I->getOpcode()) {
2943   case Instruction::Alloca:
2944     // Alloca never returns null, malloc might.
2945     return I->getType()->getPointerAddressSpace() == 0;
2946   case Instruction::GetElementPtr:
2947     if (I->getType()->isPointerTy())
2948       return isGEPKnownNonNull(cast<GEPOperator>(I), Depth, Q);
2949     break;
2950   case Instruction::BitCast: {
2951     // We need to be a bit careful here. We can only peek through the bitcast
2952     // if the scalar size of elements in the operand is smaller than, and evenly
2953     // divides, the scalar size they are casting to. Take three cases:
2954     //
2955     // 1) Unsafe:
2956     //        bitcast <2 x i16> %NonZero to <4 x i8>
2957     //
2958     //    %NonZero can have 2 non-zero i16 elements, but isKnownNonZero on a
2959     //    <4 x i8> requires that all 4 i8 elements be non-zero which isn't
2960     //    guaranteed (imagine just the sign bit set in the 2 i16 elements).
2961     //
2962     // 2) Unsafe:
2963     //        bitcast <4 x i3> %NonZero to <3 x i4>
2964     //
2965     //    Even though the scalar size of the src (`i3`) is smaller than the
2966     //    scalar size of the dst (`i4`), because `i4` is not a multiple of `i3`
2967     //    it's possible for the `3 x i4` elements to be zero, because there are
2968     //    some elements in the destination that don't contain any full src
2969     //    element.
2970     //
2971     // 3) Safe:
2972     //        bitcast <4 x i8> %NonZero to <2 x i16>
2973     //
2974     //    This is always safe as non-zero in the 4 i8 elements implies
2975     //    non-zero in the combination of any two adjacent ones. Since i16 is a
2976     //    multiple of i8, each i16 is guaranteed to have 2 full i8 elements.
2977     //    This all implies the 2 i16 elements are non-zero.
2978     Type *FromTy = I->getOperand(0)->getType();
2979     if ((FromTy->isIntOrIntVectorTy() || FromTy->isPtrOrPtrVectorTy()) &&
2980         (BitWidth % getBitWidth(FromTy->getScalarType(), Q.DL)) == 0)
2981       return isKnownNonZero(I->getOperand(0), Q, Depth);
2982   } break;
2983   case Instruction::IntToPtr:
2984     // Note that we have to take special care to avoid looking through
2985     // truncating casts, e.g., int2ptr/ptr2int with appropriate sizes, as well
2986     // as casts that can alter the value, e.g., AddrSpaceCasts.
2987     if (!isa<ScalableVectorType>(I->getType()) &&
2988         Q.DL.getTypeSizeInBits(I->getOperand(0)->getType()).getFixedValue() <=
2989             Q.DL.getTypeSizeInBits(I->getType()).getFixedValue())
2990       return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);
2991     break;
2992   case Instruction::PtrToInt:
2993     // Similar to int2ptr above, we can look through ptr2int here if the cast
2994     // is a no-op or an extend and not a truncate.
2995     if (!isa<ScalableVectorType>(I->getType()) &&
2996         Q.DL.getTypeSizeInBits(I->getOperand(0)->getType()).getFixedValue() <=
2997             Q.DL.getTypeSizeInBits(I->getType()).getFixedValue())
2998       return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);
2999     break;
3000   case Instruction::Trunc:
3001     // nuw/nsw trunc preserves zero/non-zero status of input.
3002     if (auto *TI = dyn_cast<TruncInst>(I))
3003       if (TI->hasNoSignedWrap() || TI->hasNoUnsignedWrap())
3004         return isKnownNonZero(TI->getOperand(0), DemandedElts, Q, Depth);
3005     break;
3006 
3007   case Instruction::Sub:
3008     return isNonZeroSub(DemandedElts, Depth, Q, BitWidth, I->getOperand(0),
3009                         I->getOperand(1));
3010   case Instruction::Xor:
3011     // (X ^ (X == 0)) is non zero
3012     if (matchOpWithOpEqZero(I->getOperand(0), I->getOperand(1)))
3013       return true;
3014     break;
3015   case Instruction::Or:
3016     // (X | (X == 0)) is non zero
3017     if (matchOpWithOpEqZero(I->getOperand(0), I->getOperand(1)))
3018       return true;
3019     // X | Y != 0 if X != 0 or Y != 0.
3020     return isKnownNonZero(I->getOperand(1), DemandedElts, Q, Depth) ||
3021            isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);
3022   case Instruction::SExt:
3023   case Instruction::ZExt:
3024     // ext X != 0 if X != 0.
3025     return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);
3026 
3027   case Instruction::Shl: {
3028     // shl nsw/nuw can't remove any non-zero bits.
3029     const OverflowingBinaryOperator *BO = cast<OverflowingBinaryOperator>(I);
3030     if (Q.IIQ.hasNoUnsignedWrap(BO) || Q.IIQ.hasNoSignedWrap(BO))
3031       return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);
3032 
3033     // shl X, Y != 0 if X is odd.  Note that the value of the shift is undefined
3034     // if the lowest bit is shifted off the end.
3035     KnownBits Known(BitWidth);
3036     computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth, Q);
3037     if (Known.One[0])
3038       return true;
3039 
3040     return isNonZeroShift(I, DemandedElts, Depth, Q, Known);
3041   }
3042   case Instruction::LShr:
3043   case Instruction::AShr: {
3044     // shr exact can only shift out zero bits.
3045     const PossiblyExactOperator *BO = cast<PossiblyExactOperator>(I);
3046     if (BO->isExact())
3047       return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);
3048 
3049     // shr X, Y != 0 if X is negative.  Note that the value of the shift is not
3050     // defined if the sign bit is shifted off the end.
3051     KnownBits Known =
3052         computeKnownBits(I->getOperand(0), DemandedElts, Depth, Q);
3053     if (Known.isNegative())
3054       return true;
3055 
3056     return isNonZeroShift(I, DemandedElts, Depth, Q, Known);
3057   }
3058   case Instruction::UDiv:
3059   case Instruction::SDiv: {
3060     // X / Y
3061     // div exact can only produce a zero if the dividend is zero.
3062     if (cast<PossiblyExactOperator>(I)->isExact())
3063       return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);
3064 
3065     KnownBits XKnown =
3066         computeKnownBits(I->getOperand(0), DemandedElts, Depth, Q);
3067     // If X is fully unknown we won't be able to figure anything out, so don't
3068     // bother computing known bits for Y.
3069     if (XKnown.isUnknown())
3070       return false;
3071 
3072     KnownBits YKnown =
3073         computeKnownBits(I->getOperand(1), DemandedElts, Depth, Q);
3074     if (I->getOpcode() == Instruction::SDiv) {
3075       // For signed division need to compare abs value of the operands.
3076       XKnown = XKnown.abs(/*IntMinIsPoison*/ false);
3077       YKnown = YKnown.abs(/*IntMinIsPoison*/ false);
3078     }
3079     // If X u>= Y then div is non zero (0/0 is UB).
3080     std::optional<bool> XUgeY = KnownBits::uge(XKnown, YKnown);
3081     // If X is totally unknown or X u< Y, we won't be able to prove non-zero
3082     // with computeKnownBits, so just return early.
3083     return XUgeY && *XUgeY;
3084   }
3085   case Instruction::Add: {
3086     // X + Y.
3087 
3088     // If the Add has the nuw flag, then if either X or Y is non-zero the result
3089     // is non-zero.
3090     auto *BO = cast<OverflowingBinaryOperator>(I);
3091     return isNonZeroAdd(DemandedElts, Depth, Q, BitWidth, I->getOperand(0),
3092                         I->getOperand(1), Q.IIQ.hasNoSignedWrap(BO),
3093                         Q.IIQ.hasNoUnsignedWrap(BO));
3094   }
3095   case Instruction::Mul: {
3096     const OverflowingBinaryOperator *BO = cast<OverflowingBinaryOperator>(I);
3097     return isNonZeroMul(DemandedElts, Depth, Q, BitWidth, I->getOperand(0),
3098                         I->getOperand(1), Q.IIQ.hasNoSignedWrap(BO),
3099                         Q.IIQ.hasNoUnsignedWrap(BO));
3100   }
3101   case Instruction::Select: {
3102     // (C ? X : Y) != 0 if X != 0 and Y != 0.
3103 
3104     // First check if the arm is non-zero using `isKnownNonZero`. If that fails,
3105     // then see if the select condition implies the arm is non-zero. For example
3106     // (X != 0 ? X : Y), we know the true arm is non-zero as the `X` "return" is
3107     // dominated by `X != 0`.
3108     auto SelectArmIsNonZero = [&](bool IsTrueArm) {
3109       Value *Op;
3110       Op = IsTrueArm ? I->getOperand(1) : I->getOperand(2);
3111       // Op is trivially non-zero.
3112       if (isKnownNonZero(Op, DemandedElts, Q, Depth))
3113         return true;
3114 
3115       // The condition of the select dominates the true/false arm. Check if the
3116       // condition implies that a given arm is non-zero.
3117       Value *X;
3118       CmpPredicate Pred;
3119       if (!match(I->getOperand(0), m_c_ICmp(Pred, m_Specific(Op), m_Value(X))))
3120         return false;
3121 
3122       if (!IsTrueArm)
3123         Pred = ICmpInst::getInversePredicate(Pred);
3124 
3125       return cmpExcludesZero(Pred, X);
3126     };
3127 
3128     if (SelectArmIsNonZero(/* IsTrueArm */ true) &&
3129         SelectArmIsNonZero(/* IsTrueArm */ false))
3130       return true;
3131     break;
3132   }
3133   case Instruction::PHI: {
3134     auto *PN = cast<PHINode>(I);
3135     if (Q.IIQ.UseInstrInfo && isNonZeroRecurrence(PN))
3136       return true;
3137 
3138     // Check if all incoming values are non-zero using recursion.
3139     SimplifyQuery RecQ = Q.getWithoutCondContext();
3140     unsigned NewDepth = std::max(Depth, MaxAnalysisRecursionDepth - 1);
3141     return llvm::all_of(PN->operands(), [&](const Use &U) {
3142       if (U.get() == PN)
3143         return true;
3144       RecQ.CxtI = PN->getIncomingBlock(U)->getTerminator();
3145       // Check if the branch on the phi excludes zero.
3146       CmpPredicate Pred;
3147       Value *X;
3148       BasicBlock *TrueSucc, *FalseSucc;
3149       if (match(RecQ.CxtI,
3150                 m_Br(m_c_ICmp(Pred, m_Specific(U.get()), m_Value(X)),
3151                      m_BasicBlock(TrueSucc), m_BasicBlock(FalseSucc)))) {
3152         // Check for cases of duplicate successors.
3153         if ((TrueSucc == PN->getParent()) != (FalseSucc == PN->getParent())) {
3154           // If we're using the false successor, invert the predicate.
3155           if (FalseSucc == PN->getParent())
3156             Pred = CmpInst::getInversePredicate(Pred);
3157           if (cmpExcludesZero(Pred, X))
3158             return true;
3159         }
3160       }
3161       // Finally recurse on the edge and check it directly.
3162       return isKnownNonZero(U.get(), DemandedElts, RecQ, NewDepth);
3163     });
3164   }
3165   case Instruction::InsertElement: {
3166     if (isa<ScalableVectorType>(I->getType()))
3167       break;
3168 
3169     const Value *Vec = I->getOperand(0);
3170     const Value *Elt = I->getOperand(1);
3171     auto *CIdx = dyn_cast<ConstantInt>(I->getOperand(2));
3172 
3173     unsigned NumElts = DemandedElts.getBitWidth();
3174     APInt DemandedVecElts = DemandedElts;
3175     bool SkipElt = false;
3176     // If we know the index we are inserting to, clear it from the Vec check.
3177     if (CIdx && CIdx->getValue().ult(NumElts)) {
3178       DemandedVecElts.clearBit(CIdx->getZExtValue());
3179       SkipElt = !DemandedElts[CIdx->getZExtValue()];
3180     }
3181 
3182     // Result is non-zero if Elt is non-zero and the rest of the demanded elts in
3183     // Vec are non-zero.
3184     return (SkipElt || isKnownNonZero(Elt, Q, Depth)) &&
3185            (DemandedVecElts.isZero() ||
3186             isKnownNonZero(Vec, DemandedVecElts, Q, Depth));
3187   }
3188   case Instruction::ExtractElement:
3189     if (const auto *EEI = dyn_cast<ExtractElementInst>(I)) {
3190       const Value *Vec = EEI->getVectorOperand();
3191       const Value *Idx = EEI->getIndexOperand();
3192       auto *CIdx = dyn_cast<ConstantInt>(Idx);
3193       if (auto *VecTy = dyn_cast<FixedVectorType>(Vec->getType())) {
3194         unsigned NumElts = VecTy->getNumElements();
3195         APInt DemandedVecElts = APInt::getAllOnes(NumElts);
3196         if (CIdx && CIdx->getValue().ult(NumElts))
3197           DemandedVecElts = APInt::getOneBitSet(NumElts, CIdx->getZExtValue());
3198         return isKnownNonZero(Vec, DemandedVecElts, Q, Depth);
3199       }
3200     }
3201     break;
3202   case Instruction::ShuffleVector: {
3203     auto *Shuf = dyn_cast<ShuffleVectorInst>(I);
3204     if (!Shuf)
3205       break;
3206     APInt DemandedLHS, DemandedRHS;
3207     // For undef elements, we don't know anything about the common state of
3208     // the shuffle result.
3209     if (!getShuffleDemandedElts(Shuf, DemandedElts, DemandedLHS, DemandedRHS))
3210       break;
3211     // If demanded elements for both vecs are non-zero, the shuffle is non-zero.
3212     return (DemandedRHS.isZero() ||
3213             isKnownNonZero(Shuf->getOperand(1), DemandedRHS, Q, Depth)) &&
3214            (DemandedLHS.isZero() ||
3215             isKnownNonZero(Shuf->getOperand(0), DemandedLHS, Q, Depth));
3216   }
3217   case Instruction::Freeze:
3218     return isKnownNonZero(I->getOperand(0), Q, Depth) &&
3219            isGuaranteedNotToBePoison(I->getOperand(0), Q.AC, Q.CxtI, Q.DT,
3220                                      Depth);
3221   case Instruction::Load: {
3222     auto *LI = cast<LoadInst>(I);
3223     // A load tagged with nonnull metadata, or with dereferenceable metadata when
3224     // the null pointer is undefined, is never null.
3225     if (auto *PtrT = dyn_cast<PointerType>(I->getType())) {
3226       if (Q.IIQ.getMetadata(LI, LLVMContext::MD_nonnull) ||
3227           (Q.IIQ.getMetadata(LI, LLVMContext::MD_dereferenceable) &&
3228            !NullPointerIsDefined(LI->getFunction(), PtrT->getAddressSpace())))
3229         return true;
3230     } else if (MDNode *Ranges = Q.IIQ.getMetadata(LI, LLVMContext::MD_range)) {
3231       return rangeMetadataExcludesValue(Ranges, APInt::getZero(BitWidth));
3232     }
3233 
3234     // No need to fall through to computeKnownBits as range metadata is already
3235     // handled in isKnownNonZero.
3236     return false;
3237   }
3238   case Instruction::ExtractValue: {
3239     const WithOverflowInst *WO;
3240     if (match(I, m_ExtractValue<0>(m_WithOverflowInst(WO)))) {
3241       switch (WO->getBinaryOp()) {
3242       default:
3243         break;
3244       case Instruction::Add:
3245         return isNonZeroAdd(DemandedElts, Depth, Q, BitWidth,
3246                             WO->getArgOperand(0), WO->getArgOperand(1),
3247                             /*NSW=*/false,
3248                             /*NUW=*/false);
3249       case Instruction::Sub:
3250         return isNonZeroSub(DemandedElts, Depth, Q, BitWidth,
3251                             WO->getArgOperand(0), WO->getArgOperand(1));
3252       case Instruction::Mul:
3253         return isNonZeroMul(DemandedElts, Depth, Q, BitWidth,
3254                             WO->getArgOperand(0), WO->getArgOperand(1),
3255                             /*NSW=*/false, /*NUW=*/false);
3256         break;
3257       }
3258     }
3259     break;
3260   }
3261   case Instruction::Call:
3262   case Instruction::Invoke: {
3263     const auto *Call = cast<CallBase>(I);
3264     if (I->getType()->isPointerTy()) {
3265       if (Call->isReturnNonNull())
3266         return true;
3267       if (const auto *RP = getArgumentAliasingToReturnedPointer(Call, true))
3268         return isKnownNonZero(RP, Q, Depth);
3269     } else {
3270       if (MDNode *Ranges = Q.IIQ.getMetadata(Call, LLVMContext::MD_range))
3271         return rangeMetadataExcludesValue(Ranges, APInt::getZero(BitWidth));
3272       if (std::optional<ConstantRange> Range = Call->getRange()) {
3273         const APInt ZeroValue(Range->getBitWidth(), 0);
3274         if (!Range->contains(ZeroValue))
3275           return true;
3276       }
3277       if (const Value *RV = Call->getReturnedArgOperand())
3278         if (RV->getType() == I->getType() && isKnownNonZero(RV, Q, Depth))
3279           return true;
3280     }
3281 
3282     if (auto *II = dyn_cast<IntrinsicInst>(I)) {
3283       switch (II->getIntrinsicID()) {
3284       case Intrinsic::sshl_sat:
3285       case Intrinsic::ushl_sat:
3286       case Intrinsic::abs:
3287       case Intrinsic::bitreverse:
3288       case Intrinsic::bswap:
3289       case Intrinsic::ctpop:
3290         return isKnownNonZero(II->getArgOperand(0), DemandedElts, Q, Depth);
3291         // NB: We don't handle usub_sat here because in any case where we can
3292         // prove it's non-zero, we will fold it to `sub nuw` in InstCombine.
3293       case Intrinsic::ssub_sat:
3294         return isNonZeroSub(DemandedElts, Depth, Q, BitWidth,
3295                             II->getArgOperand(0), II->getArgOperand(1));
3296       case Intrinsic::sadd_sat:
3297         return isNonZeroAdd(DemandedElts, Depth, Q, BitWidth,
3298                             II->getArgOperand(0), II->getArgOperand(1),
3299                             /*NSW=*/true, /* NUW=*/false);
3300         // Vec reverse preserves zero/non-zero status from input vec.
3301       case Intrinsic::vector_reverse:
3302         return isKnownNonZero(II->getArgOperand(0), DemandedElts.reverseBits(),
3303                               Q, Depth);
3304         // or/umax/umin/smax/smin reduction of all non-zero elements is always non-zero.
3305       case Intrinsic::vector_reduce_or:
3306       case Intrinsic::vector_reduce_umax:
3307       case Intrinsic::vector_reduce_umin:
3308       case Intrinsic::vector_reduce_smax:
3309       case Intrinsic::vector_reduce_smin:
3310         return isKnownNonZero(II->getArgOperand(0), Q, Depth);
3311       case Intrinsic::umax:
3312       case Intrinsic::uadd_sat:
3313         // umax(X, (X == 0)) is non zero
3314         // X +usat (X == 0) is non zero
3315         if (matchOpWithOpEqZero(II->getArgOperand(0), II->getArgOperand(1)))
3316           return true;
3317 
3318         return isKnownNonZero(II->getArgOperand(1), DemandedElts, Q, Depth) ||
3319                isKnownNonZero(II->getArgOperand(0), DemandedElts, Q, Depth);
3320       case Intrinsic::smax: {
3321         // If either arg is strictly positive the result is non-zero. Otherwise
3322         // the result is non-zero if both ops are non-zero.
3323         auto IsNonZero = [&](Value *Op, std::optional<bool> &OpNonZero,
3324                              const KnownBits &OpKnown) {
3325           if (!OpNonZero.has_value())
3326             OpNonZero = OpKnown.isNonZero() ||
3327                         isKnownNonZero(Op, DemandedElts, Q, Depth);
3328           return *OpNonZero;
3329         };
3330         // Avoid re-computing isKnownNonZero.
3331         std::optional<bool> Op0NonZero, Op1NonZero;
3332         KnownBits Op1Known =
3333             computeKnownBits(II->getArgOperand(1), DemandedElts, Depth, Q);
3334         if (Op1Known.isNonNegative() &&
3335             IsNonZero(II->getArgOperand(1), Op1NonZero, Op1Known))
3336           return true;
3337         KnownBits Op0Known =
3338             computeKnownBits(II->getArgOperand(0), DemandedElts, Depth, Q);
3339         if (Op0Known.isNonNegative() &&
3340             IsNonZero(II->getArgOperand(0), Op0NonZero, Op0Known))
3341           return true;
3342         return IsNonZero(II->getArgOperand(1), Op1NonZero, Op1Known) &&
3343                IsNonZero(II->getArgOperand(0), Op0NonZero, Op0Known);
3344       }
3345       case Intrinsic::smin: {
3346         // If either arg is negative the result is non-zero. Otherwise
3347         // the result is non-zero if both ops are non-zero.
3348         KnownBits Op1Known =
3349             computeKnownBits(II->getArgOperand(1), DemandedElts, Depth, Q);
3350         if (Op1Known.isNegative())
3351           return true;
3352         KnownBits Op0Known =
3353             computeKnownBits(II->getArgOperand(0), DemandedElts, Depth, Q);
3354         if (Op0Known.isNegative())
3355           return true;
3356 
3357         if (Op1Known.isNonZero() && Op0Known.isNonZero())
3358           return true;
3359       }
3360         [[fallthrough]];
3361       case Intrinsic::umin:
3362         return isKnownNonZero(II->getArgOperand(0), DemandedElts, Q, Depth) &&
3363                isKnownNonZero(II->getArgOperand(1), DemandedElts, Q, Depth);
3364       case Intrinsic::cttz:
3365         return computeKnownBits(II->getArgOperand(0), DemandedElts, Depth, Q)
3366             .Zero[0];
3367       case Intrinsic::ctlz:
3368         return computeKnownBits(II->getArgOperand(0), DemandedElts, Depth, Q)
3369             .isNonNegative();
3370       case Intrinsic::fshr:
3371       case Intrinsic::fshl:
3372         // If Op0 == Op1, this is a rotate. rotate(x, y) != 0 iff x != 0.
3373         if (II->getArgOperand(0) == II->getArgOperand(1))
3374           return isKnownNonZero(II->getArgOperand(0), DemandedElts, Q, Depth);
3375         break;
3376       case Intrinsic::vscale:
3377         return true;
3378       case Intrinsic::experimental_get_vector_length:
3379         return isKnownNonZero(I->getOperand(0), Q, Depth);
3380       default:
3381         break;
3382       }
3383       break;
3384     }
3385 
3386     return false;
3387   }
3388   }
3389 
3390   KnownBits Known(BitWidth);
3391   computeKnownBits(I, DemandedElts, Known, Depth, Q);
3392   return Known.One != 0;
3393 }
3394 
3395 /// Return true if the given value is known to be non-zero when defined. For
3396 /// vectors, return true if every demanded element is known to be non-zero when
3397 /// defined. For pointers, if the context instruction and dominator tree are
3398 /// specified, perform context-sensitive analysis and return true if the
3399 /// pointer couldn't possibly be null at the specified instruction.
3400 /// Supports values with integer or pointer type and vectors of integers.
3401 bool isKnownNonZero(const Value *V, const APInt &DemandedElts,
3402                     const SimplifyQuery &Q, unsigned Depth) {
3403   Type *Ty = V->getType();
3404 
3405 #ifndef NDEBUG
3406   assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth");
3407 
3408   if (auto *FVTy = dyn_cast<FixedVectorType>(Ty)) {
3409     assert(
3410         FVTy->getNumElements() == DemandedElts.getBitWidth() &&
3411         "DemandedElt width should equal the fixed vector number of elements");
3412   } else {
3413     assert(DemandedElts == APInt(1, 1) &&
3414            "DemandedElt width should be 1 for scalars");
3415   }
3416 #endif
3417 
3418   if (auto *C = dyn_cast<Constant>(V)) {
3419     if (C->isNullValue())
3420       return false;
3421     if (isa<ConstantInt>(C))
3422       // Must be non-zero due to null test above.
3423       return true;
3424 
3425     // For constant vectors, check that all elements are poison or known
3426     // non-zero to determine that the whole vector is known non-zero.
3427     if (auto *VecTy = dyn_cast<FixedVectorType>(Ty)) {
3428       for (unsigned i = 0, e = VecTy->getNumElements(); i != e; ++i) {
3429         if (!DemandedElts[i])
3430           continue;
3431         Constant *Elt = C->getAggregateElement(i);
3432         if (!Elt || Elt->isNullValue())
3433           return false;
3434         if (!isa<PoisonValue>(Elt) && !isa<ConstantInt>(Elt))
3435           return false;
3436       }
3437       return true;
3438     }
3439 
3440     // A constant ptrauth can be null iff its base pointer can be.
3441     if (auto *CPA = dyn_cast<ConstantPtrAuth>(V))
3442       return isKnownNonZero(CPA->getPointer(), DemandedElts, Q, Depth);
3443 
3444     // A global variable in address space 0 is non-null unless it is extern weak
3445     // or an absolute symbol reference. Other address spaces may have null as a
3446     // valid address for a global, so we can't assume anything.
3447     if (const GlobalValue *GV = dyn_cast<GlobalValue>(V)) {
3448       if (!GV->isAbsoluteSymbolRef() && !GV->hasExternalWeakLinkage() &&
3449           GV->getType()->getAddressSpace() == 0)
3450         return true;
3451     }
3452 
3453     // For constant expressions, fall through to the Operator code below.
3454     if (!isa<ConstantExpr>(V))
3455       return false;
3456   }
3457 
3458   if (const auto *A = dyn_cast<Argument>(V))
3459     if (std::optional<ConstantRange> Range = A->getRange()) {
3460       const APInt ZeroValue(Range->getBitWidth(), 0);
3461       if (!Range->contains(ZeroValue))
3462         return true;
3463     }
3464 
3465   if (!isa<Constant>(V) && isKnownNonZeroFromAssume(V, Q))
3466     return true;
3467 
3468   // Some of the tests below are recursive, so bail out if we hit the limit.
3469   if (Depth++ >= MaxAnalysisRecursionDepth)
3470     return false;
3471 
3472   // Check for pointer simplifications.
3473 
3474   if (PointerType *PtrTy = dyn_cast<PointerType>(Ty)) {
3475     // A byval or inalloca argument is not necessarily non-null in a non-default
3476     // address space. A nonnull argument is assumed never 0.
3477     if (const Argument *A = dyn_cast<Argument>(V)) {
3478       if (((A->hasPassPointeeByValueCopyAttr() &&
3479             !NullPointerIsDefined(A->getParent(), PtrTy->getAddressSpace())) ||
3480            A->hasNonNullAttr()))
3481         return true;
3482     }
3483   }
3484 
3485   if (const auto *I = dyn_cast<Operator>(V))
3486     if (isKnownNonZeroFromOperator(I, DemandedElts, Depth, Q))
3487       return true;
3488 
3489   if (!isa<Constant>(V) &&
3490       isKnownNonNullFromDominatingCondition(V, Q.CxtI, Q.DT))
3491     return true;
3492 
3493   return false;
3494 }
3495 
3496 bool llvm::isKnownNonZero(const Value *V, const SimplifyQuery &Q,
3497                           unsigned Depth) {
3498   auto *FVTy = dyn_cast<FixedVectorType>(V->getType());
3499   APInt DemandedElts =
3500       FVTy ? APInt::getAllOnes(FVTy->getNumElements()) : APInt(1, 1);
3501   return ::isKnownNonZero(V, DemandedElts, Q, Depth);
3502 }
3503 
3504 /// If the pair of operators are the same invertible function, return the
3505 /// operands of the function corresponding to each input. Otherwise,
3506 /// return std::nullopt.  An invertible function is one that is 1-to-1 and maps
3507 /// every input value to exactly one output value.  This is equivalent to
3508 /// saying that Op1 and Op2 are equal exactly when the specified pair of
3509 /// operands are equal (except that Op1 and Op2 may be poison more often).
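     /// For example, two `add` instructions that share one operand are invertible in
     /// the other operand: (A + C) == (B + C) exactly when A == B.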
3510 static std::optional<std::pair<Value*, Value*>>
3511 getInvertibleOperands(const Operator *Op1,
3512                       const Operator *Op2) {
3513   if (Op1->getOpcode() != Op2->getOpcode())
3514     return std::nullopt;
3515 
3516   auto getOperands = [&](unsigned OpNum) -> auto {
3517     return std::make_pair(Op1->getOperand(OpNum), Op2->getOperand(OpNum));
3518   };
3519 
3520   switch (Op1->getOpcode()) {
3521   default:
3522     break;
3523   case Instruction::Or:
3524     if (!cast<PossiblyDisjointInst>(Op1)->isDisjoint() ||
3525         !cast<PossiblyDisjointInst>(Op2)->isDisjoint())
3526       break;
3527     [[fallthrough]];
3528   case Instruction::Xor:
3529   case Instruction::Add: {
3530     Value *Other;
3531     if (match(Op2, m_c_BinOp(m_Specific(Op1->getOperand(0)), m_Value(Other))))
3532       return std::make_pair(Op1->getOperand(1), Other);
3533     if (match(Op2, m_c_BinOp(m_Specific(Op1->getOperand(1)), m_Value(Other))))
3534       return std::make_pair(Op1->getOperand(0), Other);
3535     break;
3536   }
3537   case Instruction::Sub:
3538     if (Op1->getOperand(0) == Op2->getOperand(0))
3539       return getOperands(1);
3540     if (Op1->getOperand(1) == Op2->getOperand(1))
3541       return getOperands(0);
3542     break;
3543   case Instruction::Mul: {
3544     // invertible if A * B == (A * B) mod 2^N, where A and B are integers
3545     // and N is the bitwidth.  The nsw case is non-obvious, but proven by
3546     // alive2: https://alive2.llvm.org/ce/z/Z6D5qK
3547     auto *OBO1 = cast<OverflowingBinaryOperator>(Op1);
3548     auto *OBO2 = cast<OverflowingBinaryOperator>(Op2);
3549     if ((!OBO1->hasNoUnsignedWrap() || !OBO2->hasNoUnsignedWrap()) &&
3550         (!OBO1->hasNoSignedWrap() || !OBO2->hasNoSignedWrap()))
3551       break;
3552 
3553     // Assume operand order has been canonicalized
3554     if (Op1->getOperand(1) == Op2->getOperand(1) &&
3555         isa<ConstantInt>(Op1->getOperand(1)) &&
3556         !cast<ConstantInt>(Op1->getOperand(1))->isZero())
3557       return getOperands(0);
3558     break;
3559   }
3560   case Instruction::Shl: {
3561     // Same as multiplies, with the difference that we don't need to check
3562     // for a non-zero multiply. Shifts always multiply by non-zero.
3563     auto *OBO1 = cast<OverflowingBinaryOperator>(Op1);
3564     auto *OBO2 = cast<OverflowingBinaryOperator>(Op2);
3565     if ((!OBO1->hasNoUnsignedWrap() || !OBO2->hasNoUnsignedWrap()) &&
3566         (!OBO1->hasNoSignedWrap() || !OBO2->hasNoSignedWrap()))
3567       break;
3568 
3569     if (Op1->getOperand(1) == Op2->getOperand(1))
3570       return getOperands(0);
3571     break;
3572   }
3573   case Instruction::AShr:
3574   case Instruction::LShr: {
3575     auto *PEO1 = cast<PossiblyExactOperator>(Op1);
3576     auto *PEO2 = cast<PossiblyExactOperator>(Op2);
3577     if (!PEO1->isExact() || !PEO2->isExact())
3578       break;
3579 
3580     if (Op1->getOperand(1) == Op2->getOperand(1))
3581       return getOperands(0);
3582     break;
3583   }
3584   case Instruction::SExt:
3585   case Instruction::ZExt:
3586     if (Op1->getOperand(0)->getType() == Op2->getOperand(0)->getType())
3587       return getOperands(0);
3588     break;
3589   case Instruction::PHI: {
3590     const PHINode *PN1 = cast<PHINode>(Op1);
3591     const PHINode *PN2 = cast<PHINode>(Op2);
3592 
3593     // If PN1 and PN2 are both recurrences, can we prove the entire recurrences
3594     // are a single invertible function of the start values? Note that repeated
3595     // application of an invertible function is also invertible.
3596     BinaryOperator *BO1 = nullptr;
3597     Value *Start1 = nullptr, *Step1 = nullptr;
3598     BinaryOperator *BO2 = nullptr;
3599     Value *Start2 = nullptr, *Step2 = nullptr;
3600     if (PN1->getParent() != PN2->getParent() ||
3601         !matchSimpleRecurrence(PN1, BO1, Start1, Step1) ||
3602         !matchSimpleRecurrence(PN2, BO2, Start2, Step2))
3603       break;
3604 
3605     auto Values = getInvertibleOperands(cast<Operator>(BO1),
3606                                         cast<Operator>(BO2));
3607     if (!Values)
3608        break;
3609 
3610     // We have to be careful of mutually defined recurrences here.  Ex:
3611     // * X_i = X_(i-1) OP Y_(i-1), and Y_i = X_(i-1) OP V
3612     // * X_i = Y_i = X_(i-1) OP Y_(i-1)
3613     // The invertibility of these is complicated, and not worth reasoning
3614     // about (yet?).
3615     if (Values->first != PN1 || Values->second != PN2)
3616       break;
3617 
3618     return std::make_pair(Start1, Start2);
3619   }
3620   }
3621   return std::nullopt;
3622 }
3623 
3624 /// Return true if V1 == (binop V2, X), where X is known non-zero.
3625 /// Only handle a small subset of binops where (binop V2, X) with non-zero X
3626 /// implies V2 != V1.
3627 static bool isModifyingBinopOfNonZero(const Value *V1, const Value *V2,
3628                                       const APInt &DemandedElts, unsigned Depth,
3629                                       const SimplifyQuery &Q) {
3630   const BinaryOperator *BO = dyn_cast<BinaryOperator>(V1);
3631   if (!BO)
3632     return false;
3633   switch (BO->getOpcode()) {
3634   default:
3635     break;
3636   case Instruction::Or:
3637     if (!cast<PossiblyDisjointInst>(V1)->isDisjoint())
3638       break;
3639     [[fallthrough]];
3640   case Instruction::Xor:
3641   case Instruction::Add:
3642     Value *Op = nullptr;
3643     if (V2 == BO->getOperand(0))
3644       Op = BO->getOperand(1);
3645     else if (V2 == BO->getOperand(1))
3646       Op = BO->getOperand(0);
3647     else
3648       return false;
3649     return isKnownNonZero(Op, DemandedElts, Q, Depth + 1);
3650   }
3651   return false;
3652 }
3653 
3654 /// Return true if V2 == V1 * C, where V1 is known non-zero, C is not 0/1 and
3655 /// the multiplication is nuw or nsw.
3656 static bool isNonEqualMul(const Value *V1, const Value *V2,
3657                           const APInt &DemandedElts, unsigned Depth,
3658                           const SimplifyQuery &Q) {
3659   if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(V2)) {
3660     const APInt *C;
3661     return match(OBO, m_Mul(m_Specific(V1), m_APInt(C))) &&
3662            (OBO->hasNoUnsignedWrap() || OBO->hasNoSignedWrap()) &&
3663            !C->isZero() && !C->isOne() &&
3664            isKnownNonZero(V1, DemandedElts, Q, Depth + 1);
3665   }
3666   return false;
3667 }
3668 
3669 /// Return true if V2 == V1 << C, where V1 is known non-zero, C is not 0 and
3670 /// the shift is nuw or nsw.
3671 static bool isNonEqualShl(const Value *V1, const Value *V2,
3672                           const APInt &DemandedElts, unsigned Depth,
3673                           const SimplifyQuery &Q) {
3674   if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(V2)) {
3675     const APInt *C;
3676     return match(OBO, m_Shl(m_Specific(V1), m_APInt(C))) &&
3677            (OBO->hasNoUnsignedWrap() || OBO->hasNoSignedWrap()) &&
3678            !C->isZero() && isKnownNonZero(V1, DemandedElts, Q, Depth + 1);
3679   }
3680   return false;
3681 }
3682 
3683 static bool isNonEqualPHIs(const PHINode *PN1, const PHINode *PN2,
3684                            const APInt &DemandedElts, unsigned Depth,
3685                            const SimplifyQuery &Q) {
3686   // Check two PHIs are in same block.
3687   if (PN1->getParent() != PN2->getParent())
3688     return false;
3689 
3690   SmallPtrSet<const BasicBlock *, 8> VisitedBBs;
3691   bool UsedFullRecursion = false;
3692   for (const BasicBlock *IncomBB : PN1->blocks()) {
3693     if (!VisitedBBs.insert(IncomBB).second)
3694       continue; // Don't reprocess blocks that we have dealt with already.
3695     const Value *IV1 = PN1->getIncomingValueForBlock(IncomBB);
3696     const Value *IV2 = PN2->getIncomingValueForBlock(IncomBB);
3697     const APInt *C1, *C2;
3698     if (match(IV1, m_APInt(C1)) && match(IV2, m_APInt(C2)) && *C1 != *C2)
3699       continue;
3700 
3701     // Only one pair of phi operands is allowed for full recursion.
3702     if (UsedFullRecursion)
3703       return false;
3704 
3705     SimplifyQuery RecQ = Q.getWithoutCondContext();
3706     RecQ.CxtI = IncomBB->getTerminator();
3707     if (!isKnownNonEqual(IV1, IV2, DemandedElts, Depth + 1, RecQ))
3708       return false;
3709     UsedFullRecursion = true;
3710   }
3711   return true;
3712 }
3713 
3714 static bool isNonEqualSelect(const Value *V1, const Value *V2,
3715                              const APInt &DemandedElts, unsigned Depth,
3716                              const SimplifyQuery &Q) {
3717   const SelectInst *SI1 = dyn_cast<SelectInst>(V1);
3718   if (!SI1)
3719     return false;
3720 
3721   if (const SelectInst *SI2 = dyn_cast<SelectInst>(V2)) {
3722     const Value *Cond1 = SI1->getCondition();
3723     const Value *Cond2 = SI2->getCondition();
3724     if (Cond1 == Cond2)
3725       return isKnownNonEqual(SI1->getTrueValue(), SI2->getTrueValue(),
3726                              DemandedElts, Depth + 1, Q) &&
3727              isKnownNonEqual(SI1->getFalseValue(), SI2->getFalseValue(),
3728                              DemandedElts, Depth + 1, Q);
3729   }
3730   return isKnownNonEqual(SI1->getTrueValue(), V2, DemandedElts, Depth + 1, Q) &&
3731          isKnownNonEqual(SI1->getFalseValue(), V2, DemandedElts, Depth + 1, Q);
3732 }
3733 
3734 // Check to see if A is both a GEP and the incoming value for a PHI in the
3735 // loop, and B is either a ptr or another GEP. If the PHI has 2 incoming
3736 // values, one of them being the recursive GEP A and the other a ptr with the
3737 // same base as B and at the same or a higher offset, then A can never equal B,
3738 // because the recursive GEP only moves the pointer further when its offset is positive.
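     // Illustrative IR shape:
     //   %p = phi ptr [ %start, %entry ], [ %a, %loop ]
     //   %a = getelementptr inbounds i8, ptr %p, i64 4
     // If %start and B share the same base and %start's offset is >= B's offset, then
     // %a, which is always at a strictly larger offset than %start, can never equal B.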
3739 static bool isNonEqualPointersWithRecursiveGEP(const Value *A, const Value *B,
3740                                                const SimplifyQuery &Q) {
3741   if (!A->getType()->isPointerTy() || !B->getType()->isPointerTy())
3742     return false;
3743 
3744   auto *GEPA = dyn_cast<GEPOperator>(A);
3745   if (!GEPA || GEPA->getNumIndices() != 1 || !isa<Constant>(GEPA->idx_begin()))
3746     return false;
3747 
3748   // Handle 2 incoming PHI values with one being a recursive GEP.
3749   auto *PN = dyn_cast<PHINode>(GEPA->getPointerOperand());
3750   if (!PN || PN->getNumIncomingValues() != 2)
3751     return false;
3752 
3753   // Search for the recursive GEP as an incoming operand, and record that as
3754   // Step.
3755   Value *Start = nullptr;
3756   Value *Step = const_cast<Value *>(A);
3757   if (PN->getIncomingValue(0) == Step)
3758     Start = PN->getIncomingValue(1);
3759   else if (PN->getIncomingValue(1) == Step)
3760     Start = PN->getIncomingValue(0);
3761   else
3762     return false;
3763 
3764   // Other incoming node base should match the B base.
3765   // StartOffset >= OffsetB && StepOffset > 0?
3766   // StartOffset <= OffsetB && StepOffset < 0?
3767   // Is non-equal if above are true.
3768   // We use stripAndAccumulateInBoundsConstantOffsets to restrict the
3769   // optimisation to inbounds GEPs only.
3770   unsigned IndexWidth = Q.DL.getIndexTypeSizeInBits(Start->getType());
3771   APInt StartOffset(IndexWidth, 0);
3772   Start = Start->stripAndAccumulateInBoundsConstantOffsets(Q.DL, StartOffset);
3773   APInt StepOffset(IndexWidth, 0);
3774   Step = Step->stripAndAccumulateInBoundsConstantOffsets(Q.DL, StepOffset);
3775 
3776   // Check if Base Pointer of Step matches the PHI.
3777   if (Step != PN)
3778     return false;
3779   APInt OffsetB(IndexWidth, 0);
3780   B = B->stripAndAccumulateInBoundsConstantOffsets(Q.DL, OffsetB);
3781   return Start == B &&
3782          ((StartOffset.sge(OffsetB) && StepOffset.isStrictlyPositive()) ||
3783           (StartOffset.sle(OffsetB) && StepOffset.isNegative()));
3784 }
3785 
3786 /// Return true if it is known that V1 != V2.
3787 static bool isKnownNonEqual(const Value *V1, const Value *V2,
3788                             const APInt &DemandedElts, unsigned Depth,
3789                             const SimplifyQuery &Q) {
3790   if (V1 == V2)
3791     return false;
3792   if (V1->getType() != V2->getType())
3793     // We can't look through casts yet.
3794     return false;
3795 
3796   if (Depth >= MaxAnalysisRecursionDepth)
3797     return false;
3798 
3799   // See if we can recurse through (exactly one of) our operands.  This
3800   // requires our operation be 1-to-1 and map every input value to exactly
3801   // one output value.  Such an operation is invertible.
3802   auto *O1 = dyn_cast<Operator>(V1);
3803   auto *O2 = dyn_cast<Operator>(V2);
3804   if (O1 && O2 && O1->getOpcode() == O2->getOpcode()) {
3805     if (auto Values = getInvertibleOperands(O1, O2))
3806       return isKnownNonEqual(Values->first, Values->second, DemandedElts,
3807                              Depth + 1, Q);
3808 
3809     if (const PHINode *PN1 = dyn_cast<PHINode>(V1)) {
3810       const PHINode *PN2 = cast<PHINode>(V2);
3811       // FIXME: This is missing a generalization to handle the case where one is
3812       // a PHI and another one isn't.
3813       if (isNonEqualPHIs(PN1, PN2, DemandedElts, Depth, Q))
3814         return true;
3815     };
3816   }
3817 
3818   if (isModifyingBinopOfNonZero(V1, V2, DemandedElts, Depth, Q) ||
3819       isModifyingBinopOfNonZero(V2, V1, DemandedElts, Depth, Q))
3820     return true;
3821 
3822   if (isNonEqualMul(V1, V2, DemandedElts, Depth, Q) ||
3823       isNonEqualMul(V2, V1, DemandedElts, Depth, Q))
3824     return true;
3825 
3826   if (isNonEqualShl(V1, V2, DemandedElts, Depth, Q) ||
3827       isNonEqualShl(V2, V1, DemandedElts, Depth, Q))
3828     return true;
3829 
3830   if (V1->getType()->isIntOrIntVectorTy()) {
3831     // Are any known bits in V1 contradictory to known bits in V2? If V1
3832     // has a known zero where V2 has a known one, they must not be equal.
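         // E.g. if V1 is known to be even (bit 0 zero) and V2 is known to be odd
         // (bit 0 one), they cannot be equal.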
3833     KnownBits Known1 = computeKnownBits(V1, DemandedElts, Depth, Q);
3834     if (!Known1.isUnknown()) {
3835       KnownBits Known2 = computeKnownBits(V2, DemandedElts, Depth, Q);
3836       if (Known1.Zero.intersects(Known2.One) ||
3837           Known2.Zero.intersects(Known1.One))
3838         return true;
3839     }
3840   }
3841 
3842   if (isNonEqualSelect(V1, V2, DemandedElts, Depth, Q) ||
3843       isNonEqualSelect(V2, V1, DemandedElts, Depth, Q))
3844     return true;
3845 
3846   if (isNonEqualPointersWithRecursiveGEP(V1, V2, Q) ||
3847       isNonEqualPointersWithRecursiveGEP(V2, V1, Q))
3848     return true;
3849 
3850   Value *A, *B;
3851   // PtrToInts are NonEqual if their Ptrs are NonEqual.
3852   // Check PtrToInt type matches the pointer size.
3853   if (match(V1, m_PtrToIntSameSize(Q.DL, m_Value(A))) &&
3854       match(V2, m_PtrToIntSameSize(Q.DL, m_Value(B))))
3855     return isKnownNonEqual(A, B, DemandedElts, Depth + 1, Q);
3856 
3857   return false;
3858 }
3859 
3860 /// For vector constants, loop over the elements and find the constant with the
3861 /// minimum number of sign bits. Return 0 if the value is not a vector constant
3862 /// or if any element was not analyzed; otherwise, return the count for the
3863 /// element with the minimum number of sign bits.
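     /// For example, <2 x i8> <i8 -4, i8 1> has elements with 6 and 7 sign bits
     /// respectively, so the result is 6.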
3864 static unsigned computeNumSignBitsVectorConstant(const Value *V,
3865                                                  const APInt &DemandedElts,
3866                                                  unsigned TyBits) {
3867   const auto *CV = dyn_cast<Constant>(V);
3868   if (!CV || !isa<FixedVectorType>(CV->getType()))
3869     return 0;
3870 
3871   unsigned MinSignBits = TyBits;
3872   unsigned NumElts = cast<FixedVectorType>(CV->getType())->getNumElements();
3873   for (unsigned i = 0; i != NumElts; ++i) {
3874     if (!DemandedElts[i])
3875       continue;
3876     // If we find a non-ConstantInt, bail out.
3877     auto *Elt = dyn_cast_or_null<ConstantInt>(CV->getAggregateElement(i));
3878     if (!Elt)
3879       return 0;
3880 
3881     MinSignBits = std::min(MinSignBits, Elt->getValue().getNumSignBits());
3882   }
3883 
3884   return MinSignBits;
3885 }
3886 
3887 static unsigned ComputeNumSignBitsImpl(const Value *V,
3888                                        const APInt &DemandedElts,
3889                                        unsigned Depth, const SimplifyQuery &Q);
3890 
3891 static unsigned ComputeNumSignBits(const Value *V, const APInt &DemandedElts,
3892                                    unsigned Depth, const SimplifyQuery &Q) {
3893   unsigned Result = ComputeNumSignBitsImpl(V, DemandedElts, Depth, Q);
3894   assert(Result > 0 && "At least one sign bit needs to be present!");
3895   return Result;
3896 }
3897 
3898 /// Return the number of times the sign bit of the register is replicated into
3899 /// the other bits. We know that at least 1 bit is always equal to the sign bit
3900 /// (itself), but other cases can give us information. For example, immediately
3901 /// after an "ashr X, 2", we know that the top 3 bits are all equal to each
3902 /// other, so we return 3. For vectors, return the number of sign bits for the
3903 /// vector element with the minimum number of known sign bits of the demanded
3904 /// elements in the vector specified by DemandedElts.
3905 static unsigned ComputeNumSignBitsImpl(const Value *V,
3906                                        const APInt &DemandedElts,
3907                                        unsigned Depth, const SimplifyQuery &Q) {
3908   Type *Ty = V->getType();
3909 #ifndef NDEBUG
3910   assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth");
3911 
3912   if (auto *FVTy = dyn_cast<FixedVectorType>(Ty)) {
3913     assert(
3914         FVTy->getNumElements() == DemandedElts.getBitWidth() &&
3915         "DemandedElt width should equal the fixed vector number of elements");
3916   } else {
3917     assert(DemandedElts == APInt(1, 1) &&
3918            "DemandedElt width should be 1 for scalars");
3919   }
3920 #endif
3921 
3922   // We return the minimum number of sign bits that are guaranteed to be present
3923   // in V, so for undef we have to conservatively return 1.  We don't have the
3924   // same behavior for poison though -- that's a FIXME today.
3925 
3926   Type *ScalarTy = Ty->getScalarType();
3927   unsigned TyBits = ScalarTy->isPointerTy() ?
3928     Q.DL.getPointerTypeSizeInBits(ScalarTy) :
3929     Q.DL.getTypeSizeInBits(ScalarTy);
3930 
3931   unsigned Tmp, Tmp2;
3932   unsigned FirstAnswer = 1;
3933 
3934   // Note that ConstantInt is handled by the general computeKnownBits case
3935   // below.
3936 
3937   if (Depth == MaxAnalysisRecursionDepth)
3938     return 1;
3939 
3940   if (auto *U = dyn_cast<Operator>(V)) {
3941     switch (Operator::getOpcode(V)) {
3942     default: break;
3943     case Instruction::SExt:
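           // The new high bits are all copies of the sign bit, e.g. sext i8 %x to
           // i32 has at least 1 + 24 = 25 sign bits.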
3944       Tmp = TyBits - U->getOperand(0)->getType()->getScalarSizeInBits();
3945       return ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q) +
3946              Tmp;
3947 
3948     case Instruction::SDiv: {
3949       const APInt *Denominator;
3950       // sdiv X, C -> adds log(C) sign bits.
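           // e.g. sdiv i32 X, 16: if X has 3 sign bits, the result has at least
           // min(32, 3 + 4) = 7 sign bits.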
3951       if (match(U->getOperand(1), m_APInt(Denominator))) {
3952 
3953         // Ignore non-positive denominator.
3954         if (!Denominator->isStrictlyPositive())
3955           break;
3956 
3957         // Calculate the incoming numerator bits.
3958         unsigned NumBits =
3959             ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q);
3960 
3961         // Add floor(log(C)) bits to the numerator bits.
3962         return std::min(TyBits, NumBits + Denominator->logBase2());
3963       }
3964       break;
3965     }
3966 
3967     case Instruction::SRem: {
3968       Tmp = ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q);
3969 
3970       const APInt *Denominator;
3971       // srem X, C -> we know that the result is within [-C+1,C) when C is a
3972       // positive constant.  This let us put a lower bound on the number of sign
3973       // positive constant.  This lets us put a lower bound on the number of sign
3974       if (match(U->getOperand(1), m_APInt(Denominator))) {
3975 
3976         // Ignore non-positive denominator.
3977         if (Denominator->isStrictlyPositive()) {
3978           // Calculate the leading sign bit constraints by examining the
3979           // denominator.  Given that the denominator is positive, there are two
3980           // cases:
3981           //
3982           //  1. The numerator is positive. The result range is [0,C) and
3983           //     [0,C) u< (1 << ceilLogBase2(C)).
3984           //
3985           //  2. The numerator is negative. Then the result range is (-C,0] and
3986           //     integers in (-C,0] are either 0 or >u (-1 << ceilLogBase2(C)).
3987           //
3988           // Thus a lower bound on the number of sign bits is `TyBits -
3989           // ceilLogBase2(C)`.
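               //
               // e.g. srem i32 X, 7: the result lies in (-7, 7), giving at least
               // 32 - ceilLogBase2(7) = 29 sign bits.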
3990 
3991           unsigned ResBits = TyBits - Denominator->ceilLogBase2();
3992           Tmp = std::max(Tmp, ResBits);
3993         }
3994       }
3995       return Tmp;
3996     }
3997 
3998     case Instruction::AShr: {
3999       Tmp = ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q);
4000       // ashr X, C   -> adds C sign bits.  Vectors too.
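           // e.g. ashr i32 X, 5 has at least 1 + 5 = 6 sign bits (clamped to 32).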
4001       const APInt *ShAmt;
4002       if (match(U->getOperand(1), m_APInt(ShAmt))) {
4003         if (ShAmt->uge(TyBits))
4004           break; // Bad shift.
4005         unsigned ShAmtLimited = ShAmt->getZExtValue();
4006         Tmp += ShAmtLimited;
4007         if (Tmp > TyBits) Tmp = TyBits;
4008       }
4009       return Tmp;
4010     }
4011     case Instruction::Shl: {
4012       const APInt *ShAmt;
4013       Value *X = nullptr;
4014       if (match(U->getOperand(1), m_APInt(ShAmt))) {
4015         // shl destroys sign bits.
4016         if (ShAmt->uge(TyBits))
4017           break; // Bad shift.
4018         // We can look through a zext (more or less treating it as a sext) if
4019         // all extended bits are shifted out.
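             // e.g. for shl (zext i8 %x to i32), 24 the high 8 bits of the result
             // are exactly %x, so the count is just ComputeNumSignBits(%x).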
4020         if (match(U->getOperand(0), m_ZExt(m_Value(X))) &&
4021             ShAmt->uge(TyBits - X->getType()->getScalarSizeInBits())) {
4022           Tmp = ComputeNumSignBits(X, DemandedElts, Depth + 1, Q);
4023           Tmp += TyBits - X->getType()->getScalarSizeInBits();
4024         } else
4025           Tmp =
4026               ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q);
4027         if (ShAmt->uge(Tmp))
4028           break; // Shifted all sign bits out.
4029         Tmp2 = ShAmt->getZExtValue();
4030         return Tmp - Tmp2;
4031       }
4032       break;
4033     }
4034     case Instruction::And:
4035     case Instruction::Or:
4036     case Instruction::Xor: // NOT is handled here.
4037       // Logical binary ops preserve the number of sign bits at the worst.
4038       Tmp = ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q);
4039       if (Tmp != 1) {
4040         Tmp2 = ComputeNumSignBits(U->getOperand(1), DemandedElts, Depth + 1, Q);
4041         FirstAnswer = std::min(Tmp, Tmp2);
4042         // We computed what we know about the sign bits as our first
4043         // answer. Now proceed to the generic code that uses
4044         // computeKnownBits, and pick whichever answer is better.
4045       }
4046       break;
4047 
4048     case Instruction::Select: {
4049       // If we have a clamp pattern, we know that the number of sign bits will
4050       // be the minimum of the clamp min/max range.
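           // e.g. a signed clamp to [-100, 100] in i32 guarantees
           // min(25, 25) = 25 sign bits.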
4051       const Value *X;
4052       const APInt *CLow, *CHigh;
4053       if (isSignedMinMaxClamp(U, X, CLow, CHigh))
4054         return std::min(CLow->getNumSignBits(), CHigh->getNumSignBits());
4055 
4056       Tmp = ComputeNumSignBits(U->getOperand(1), DemandedElts, Depth + 1, Q);
4057       if (Tmp == 1)
4058         break;
4059       Tmp2 = ComputeNumSignBits(U->getOperand(2), DemandedElts, Depth + 1, Q);
4060       return std::min(Tmp, Tmp2);
4061     }
4062 
4063     case Instruction::Add:
4064       // Add can have at most one carry bit.  Thus we know that the output
4065       // is, at worst, one more bit than the inputs.
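           // e.g. adding two i32 values that each have at least 4 sign bits
           // leaves at least 4 - 1 = 3 sign bits.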
4066       Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
4067       if (Tmp == 1) break;
4068 
4069       // Special case decrementing a value (ADD X, -1):
4070       if (const auto *CRHS = dyn_cast<Constant>(U->getOperand(1)))
4071         if (CRHS->isAllOnesValue()) {
4072           KnownBits Known(TyBits);
4073           computeKnownBits(U->getOperand(0), DemandedElts, Known, Depth + 1, Q);
4074 
4075           // If the input is known to be 0 or 1, the output is 0/-1, which is
4076           // all sign bits set.
4077           if ((Known.Zero | 1).isAllOnes())
4078             return TyBits;
4079 
4080           // If we are subtracting one from a positive number, there is no carry
4081           // out of the result.
4082           if (Known.isNonNegative())
4083             return Tmp;
4084         }
4085 
4086       Tmp2 = ComputeNumSignBits(U->getOperand(1), DemandedElts, Depth + 1, Q);
4087       if (Tmp2 == 1)
4088         break;
4089       return std::min(Tmp, Tmp2) - 1;
4090 
4091     case Instruction::Sub:
4092       Tmp2 = ComputeNumSignBits(U->getOperand(1), DemandedElts, Depth + 1, Q);
4093       if (Tmp2 == 1)
4094         break;
4095 
4096       // Handle NEG.
4097       if (const auto *CLHS = dyn_cast<Constant>(U->getOperand(0)))
4098         if (CLHS->isNullValue()) {
4099           KnownBits Known(TyBits);
4100           computeKnownBits(U->getOperand(1), DemandedElts, Known, Depth + 1, Q);
4101           // If the input is known to be 0 or 1, the output is 0/-1, which is
4102           // all sign bits set.
4103           if ((Known.Zero | 1).isAllOnes())
4104             return TyBits;
4105 
4106           // If the input is known to be positive (the sign bit is known clear),
4107           // the output of the NEG has the same number of sign bits as the
4108           // input.
4109           if (Known.isNonNegative())
4110             return Tmp2;
4111 
4112           // Otherwise, we treat this like a SUB.
4113         }
4114 
4115       // Sub can have at most one carry bit.  Thus we know that the output
4116       // is, at worst, one more bit than the inputs.
4117       Tmp = ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q);
4118       if (Tmp == 1)
4119         break;
4120       return std::min(Tmp, Tmp2) - 1;
4121 
4122     case Instruction::Mul: {
4123       // The output of the Mul can be at most twice the valid bits in the
4124       // inputs.
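           // e.g. in i32, operands with 20 and 16 sign bits occupy at most 13 and
           // 17 significant bits, so the product needs at most 30 bits and keeps
           // at least 32 - 30 + 1 = 3 sign bits.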
4125       unsigned SignBitsOp0 =
4126           ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q);
4127       if (SignBitsOp0 == 1)
4128         break;
4129       unsigned SignBitsOp1 =
4130           ComputeNumSignBits(U->getOperand(1), DemandedElts, Depth + 1, Q);
4131       if (SignBitsOp1 == 1)
4132         break;
4133       unsigned OutValidBits =
4134           (TyBits - SignBitsOp0 + 1) + (TyBits - SignBitsOp1 + 1);
4135       return OutValidBits > TyBits ? 1 : TyBits - OutValidBits + 1;
4136     }
4137 
4138     case Instruction::PHI: {
4139       const PHINode *PN = cast<PHINode>(U);
4140       unsigned NumIncomingValues = PN->getNumIncomingValues();
4141       // Don't analyze large in-degree PHIs.
4142       if (NumIncomingValues > 4) break;
4143       // Unreachable blocks may have zero-operand PHI nodes.
4144       if (NumIncomingValues == 0) break;
4145 
4146       // Take the minimum of all incoming values.  This can't infinitely loop
4147       // because of our depth threshold.
4148       SimplifyQuery RecQ = Q.getWithoutCondContext();
4149       Tmp = TyBits;
4150       for (unsigned i = 0, e = NumIncomingValues; i != e; ++i) {
4151         if (Tmp == 1) return Tmp;
4152         RecQ.CxtI = PN->getIncomingBlock(i)->getTerminator();
4153         Tmp = std::min(Tmp, ComputeNumSignBits(PN->getIncomingValue(i),
4154                                                DemandedElts, Depth + 1, RecQ));
4155       }
4156       return Tmp;
4157     }
4158 
4159     case Instruction::Trunc: {
4160       // If the input contained enough sign bits that some remain after the
4161       // truncation, then we can make use of that. Otherwise we don't know
4162       // anything.
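           // e.g. truncating i64 to i32 discards 32 bits, so an input with 40 sign
           // bits leaves at least 40 - 32 = 8.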
4163       Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
4164       unsigned OperandTyBits = U->getOperand(0)->getType()->getScalarSizeInBits();
4165       if (Tmp > (OperandTyBits - TyBits))
4166         return Tmp - (OperandTyBits - TyBits);
4167 
4168       return 1;
4169     }
4170 
4171     case Instruction::ExtractElement:
4172       // Look through extract element. At the moment we keep this simple and
4173       // skip tracking the specific element. But at least we might find
4174       // information valid for all elements of the vector (for example if vector
4175       // is sign extended, shifted, etc).
4176       return ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
4177 
4178     case Instruction::ShuffleVector: {
4179       // Collect the minimum number of sign bits that are shared by every vector
4180       // element referenced by the shuffle.
4181       auto *Shuf = dyn_cast<ShuffleVectorInst>(U);
4182       if (!Shuf) {
4183         // FIXME: Add support for shufflevector constant expressions.
4184         return 1;
4185       }
4186       APInt DemandedLHS, DemandedRHS;
4187       // For undef elements, we don't know anything about the common state of
4188       // the shuffle result.
4189       if (!getShuffleDemandedElts(Shuf, DemandedElts, DemandedLHS, DemandedRHS))
4190         return 1;
4191       Tmp = std::numeric_limits<unsigned>::max();
4192       if (!!DemandedLHS) {
4193         const Value *LHS = Shuf->getOperand(0);
4194         Tmp = ComputeNumSignBits(LHS, DemandedLHS, Depth + 1, Q);
4195       }
4196       // If we don't know anything, early out and try computeKnownBits
4197       // fall-back.
4198       if (Tmp == 1)
4199         break;
4200       if (!!DemandedRHS) {
4201         const Value *RHS = Shuf->getOperand(1);
4202         Tmp2 = ComputeNumSignBits(RHS, DemandedRHS, Depth + 1, Q);
4203         Tmp = std::min(Tmp, Tmp2);
4204       }
4205       // If we don't know anything, early out and try computeKnownBits
4206       // fall-back.
4207       if (Tmp == 1)
4208         break;
4209       assert(Tmp <= TyBits && "Failed to determine minimum sign bits");
4210       return Tmp;
4211     }
4212     case Instruction::Call: {
4213       if (const auto *II = dyn_cast<IntrinsicInst>(U)) {
4214         switch (II->getIntrinsicID()) {
4215         default:
4216           break;
4217         case Intrinsic::abs:
4218           Tmp =
4219               ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q);
4220           if (Tmp == 1)
4221             break;
4222 
4223           // Absolute value reduces number of sign bits by at most 1.
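               // e.g. i8 -64 (0b11000000) has 2 sign bits while abs gives
               // 64 (0b01000000) with 1.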
4224           return Tmp - 1;
4225         case Intrinsic::smin:
4226         case Intrinsic::smax: {
4227           const APInt *CLow, *CHigh;
4228           if (isSignedMinMaxIntrinsicClamp(II, CLow, CHigh))
4229             return std::min(CLow->getNumSignBits(), CHigh->getNumSignBits());
4230         }
4231         }
4232       }
4233     }
4234     }
4235   }
4236 
4237   // Finally, if we can prove that the top bits of the result are 0's or 1's,
4238   // use this information.
4239 
4240   // If we can examine all elements of a vector constant successfully, we're
4241   // done (we can't do any better than that). If not, keep trying.
4242   if (unsigned VecSignBits =
4243           computeNumSignBitsVectorConstant(V, DemandedElts, TyBits))
4244     return VecSignBits;
4245 
4246   KnownBits Known(TyBits);
4247   computeKnownBits(V, DemandedElts, Known, Depth, Q);
4248 
4249   // If we know that the sign bit is either zero or one, determine the number of
4250   // identical bits in the top of the input value.
4251   return std::max(FirstAnswer, Known.countMinSignBits());
4252 }
4253 
4254 Intrinsic::ID llvm::getIntrinsicForCallSite(const CallBase &CB,
4255                                             const TargetLibraryInfo *TLI) {
4256   const Function *F = CB.getCalledFunction();
4257   if (!F)
4258     return Intrinsic::not_intrinsic;
4259 
4260   if (F->isIntrinsic())
4261     return F->getIntrinsicID();
4262 
4263   // We are going to infer semantics of a library function based on mapping it
4264   // to an LLVM intrinsic. Check that the library function is available from
4265   // this callbase and in this environment.
4266   LibFunc Func;
4267   if (F->hasLocalLinkage() || !TLI || !TLI->getLibFunc(CB, Func) ||
4268       !CB.onlyReadsMemory())
4269     return Intrinsic::not_intrinsic;
4270 
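       // Each case below maps the double/float/long double variants of a libcall
       // onto one intrinsic, e.g. sin, sinf and sinl all become Intrinsic::sin.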
4271   switch (Func) {
4272   default:
4273     break;
4274   case LibFunc_sin:
4275   case LibFunc_sinf:
4276   case LibFunc_sinl:
4277     return Intrinsic::sin;
4278   case LibFunc_cos:
4279   case LibFunc_cosf:
4280   case LibFunc_cosl:
4281     return Intrinsic::cos;
4282   case LibFunc_tan:
4283   case LibFunc_tanf:
4284   case LibFunc_tanl:
4285     return Intrinsic::tan;
4286   case LibFunc_asin:
4287   case LibFunc_asinf:
4288   case LibFunc_asinl:
4289     return Intrinsic::asin;
4290   case LibFunc_acos:
4291   case LibFunc_acosf:
4292   case LibFunc_acosl:
4293     return Intrinsic::acos;
4294   case LibFunc_atan:
4295   case LibFunc_atanf:
4296   case LibFunc_atanl:
4297     return Intrinsic::atan;
4298   case LibFunc_atan2:
4299   case LibFunc_atan2f:
4300   case LibFunc_atan2l:
4301     return Intrinsic::atan2;
4302   case LibFunc_sinh:
4303   case LibFunc_sinhf:
4304   case LibFunc_sinhl:
4305     return Intrinsic::sinh;
4306   case LibFunc_cosh:
4307   case LibFunc_coshf:
4308   case LibFunc_coshl:
4309     return Intrinsic::cosh;
4310   case LibFunc_tanh:
4311   case LibFunc_tanhf:
4312   case LibFunc_tanhl:
4313     return Intrinsic::tanh;
4314   case LibFunc_exp:
4315   case LibFunc_expf:
4316   case LibFunc_expl:
4317     return Intrinsic::exp;
4318   case LibFunc_exp2:
4319   case LibFunc_exp2f:
4320   case LibFunc_exp2l:
4321     return Intrinsic::exp2;
4322   case LibFunc_exp10:
4323   case LibFunc_exp10f:
4324   case LibFunc_exp10l:
4325     return Intrinsic::exp10;
4326   case LibFunc_log:
4327   case LibFunc_logf:
4328   case LibFunc_logl:
4329     return Intrinsic::log;
4330   case LibFunc_log10:
4331   case LibFunc_log10f:
4332   case LibFunc_log10l:
4333     return Intrinsic::log10;
4334   case LibFunc_log2:
4335   case LibFunc_log2f:
4336   case LibFunc_log2l:
4337     return Intrinsic::log2;
4338   case LibFunc_fabs:
4339   case LibFunc_fabsf:
4340   case LibFunc_fabsl:
4341     return Intrinsic::fabs;
4342   case LibFunc_fmin:
4343   case LibFunc_fminf:
4344   case LibFunc_fminl:
4345     return Intrinsic::minnum;
4346   case LibFunc_fmax:
4347   case LibFunc_fmaxf:
4348   case LibFunc_fmaxl:
4349     return Intrinsic::maxnum;
4350   case LibFunc_copysign:
4351   case LibFunc_copysignf:
4352   case LibFunc_copysignl:
4353     return Intrinsic::copysign;
4354   case LibFunc_floor:
4355   case LibFunc_floorf:
4356   case LibFunc_floorl:
4357     return Intrinsic::floor;
4358   case LibFunc_ceil:
4359   case LibFunc_ceilf:
4360   case LibFunc_ceill:
4361     return Intrinsic::ceil;
4362   case LibFunc_trunc:
4363   case LibFunc_truncf:
4364   case LibFunc_truncl:
4365     return Intrinsic::trunc;
4366   case LibFunc_rint:
4367   case LibFunc_rintf:
4368   case LibFunc_rintl:
4369     return Intrinsic::rint;
4370   case LibFunc_nearbyint:
4371   case LibFunc_nearbyintf:
4372   case LibFunc_nearbyintl:
4373     return Intrinsic::nearbyint;
4374   case LibFunc_round:
4375   case LibFunc_roundf:
4376   case LibFunc_roundl:
4377     return Intrinsic::round;
4378   case LibFunc_roundeven:
4379   case LibFunc_roundevenf:
4380   case LibFunc_roundevenl:
4381     return Intrinsic::roundeven;
4382   case LibFunc_pow:
4383   case LibFunc_powf:
4384   case LibFunc_powl:
4385     return Intrinsic::pow;
4386   case LibFunc_sqrt:
4387   case LibFunc_sqrtf:
4388   case LibFunc_sqrtl:
4389     return Intrinsic::sqrt;
4390   }
4391 
4392   return Intrinsic::not_intrinsic;
4393 }
4394 
4395 /// Return true if it's possible to assume IEEE treatment of input denormals in
4396 /// \p F for \p Val.
4397 static bool inputDenormalIsIEEE(const Function &F, const Type *Ty) {
4398   Ty = Ty->getScalarType();
4399   return F.getDenormalMode(Ty->getFltSemantics()).Input == DenormalMode::IEEE;
4400 }
4401 
4402 static bool inputDenormalIsIEEEOrPosZero(const Function &F, const Type *Ty) {
4403   Ty = Ty->getScalarType();
4404   DenormalMode Mode = F.getDenormalMode(Ty->getFltSemantics());
4405   return Mode.Input == DenormalMode::IEEE ||
4406          Mode.Input == DenormalMode::PositiveZero;
4407 }
4408 
4409 static bool outputDenormalIsIEEEOrPosZero(const Function &F, const Type *Ty) {
4410   Ty = Ty->getScalarType();
4411   DenormalMode Mode = F.getDenormalMode(Ty->getFltSemantics());
4412   return Mode.Output == DenormalMode::IEEE ||
4413          Mode.Output == DenormalMode::PositiveZero;
4414 }
4415 
4416 bool KnownFPClass::isKnownNeverLogicalZero(const Function &F, Type *Ty) const {
4417   return isKnownNeverZero() &&
4418          (isKnownNeverSubnormal() || inputDenormalIsIEEE(F, Ty));
4419 }
4420 
4421 bool KnownFPClass::isKnownNeverLogicalNegZero(const Function &F,
4422                                               Type *Ty) const {
4423   return isKnownNeverNegZero() &&
4424          (isKnownNeverNegSubnormal() || inputDenormalIsIEEEOrPosZero(F, Ty));
4425 }
4426 
4427 bool KnownFPClass::isKnownNeverLogicalPosZero(const Function &F,
4428                                               Type *Ty) const {
4429   if (!isKnownNeverPosZero())
4430     return false;
4431 
4432   // If we know there are no denormals, nothing can be flushed to zero.
4433   if (isKnownNeverSubnormal())
4434     return true;
4435 
4436   DenormalMode Mode = F.getDenormalMode(Ty->getScalarType()->getFltSemantics());
4437   switch (Mode.Input) {
4438   case DenormalMode::IEEE:
4439     return true;
4440   case DenormalMode::PreserveSign:
4441     // Negative subnormal won't flush to +0
4442     return isKnownNeverPosSubnormal();
4443   case DenormalMode::PositiveZero:
4444   default:
4445     // Both positive and negative subnormal could flush to +0
4446     return false;
4447   }
4448 
4449   llvm_unreachable("covered switch over denormal mode");
4450 }
4451 
4452 void KnownFPClass::propagateDenormal(const KnownFPClass &Src, const Function &F,
4453                                      Type *Ty) {
4454   KnownFPClasses = Src.KnownFPClasses;
4455   // If both zeros are already possible for the source, a flushed denormal
4456   // input cannot add any new class, so there is nothing to check.
4457   if (!Src.isKnownNeverPosZero() && !Src.isKnownNeverNegZero())
4458     return;
4459 
4460   // If we know the input can't be a denormal, it can't be flushed to 0.
4461   if (Src.isKnownNeverSubnormal())
4462     return;
4463 
4464   DenormalMode Mode = F.getDenormalMode(Ty->getScalarType()->getFltSemantics());
4465 
4466   if (!Src.isKnownNeverPosSubnormal() && Mode != DenormalMode::getIEEE())
4467     KnownFPClasses |= fcPosZero;
4468 
4469   if (!Src.isKnownNeverNegSubnormal() && Mode != DenormalMode::getIEEE()) {
4470     if (Mode != DenormalMode::getPositiveZero())
4471       KnownFPClasses |= fcNegZero;
4472 
4473     if (Mode.Input == DenormalMode::PositiveZero ||
4474         Mode.Output == DenormalMode::PositiveZero ||
4475         Mode.Input == DenormalMode::Dynamic ||
4476         Mode.Output == DenormalMode::Dynamic)
4477       KnownFPClasses |= fcPosZero;
4478   }
4479 }
4480 
4481 void KnownFPClass::propagateCanonicalizingSrc(const KnownFPClass &Src,
4482                                               const Function &F, Type *Ty) {
4483   propagateDenormal(Src, F, Ty);
4484   propagateNaN(Src, /*PreserveSign=*/true);
4485 }
4486 
4487 /// Given an exploded icmp instruction, return true if the comparison only
4488 /// checks the sign bit. If it only checks the sign bit, set TrueIfSigned if
4489 /// the result of the comparison is true when the input value is signed.
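     /// For example, for i8 the comparison (icmp ugt %x, 127) only checks the sign
     /// bit: it is true exactly when the sign bit of %x is set.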
4490 bool llvm::isSignBitCheck(ICmpInst::Predicate Pred, const APInt &RHS,
4491                           bool &TrueIfSigned) {
4492   switch (Pred) {
4493   case ICmpInst::ICMP_SLT: // True if LHS s< 0
4494     TrueIfSigned = true;
4495     return RHS.isZero();
4496   case ICmpInst::ICMP_SLE: // True if LHS s<= -1
4497     TrueIfSigned = true;
4498     return RHS.isAllOnes();
4499   case ICmpInst::ICMP_SGT: // True if LHS s> -1
4500     TrueIfSigned = false;
4501     return RHS.isAllOnes();
4502   case ICmpInst::ICMP_SGE: // True if LHS s>= 0
4503     TrueIfSigned = false;
4504     return RHS.isZero();
4505   case ICmpInst::ICMP_UGT:
4506     // True if LHS u> RHS and RHS == sign-bit-mask - 1
4507     TrueIfSigned = true;
4508     return RHS.isMaxSignedValue();
4509   case ICmpInst::ICMP_UGE:
4510     // True if LHS u>= RHS and RHS == sign-bit-mask (2^7, 2^15, 2^31, etc)
4511     TrueIfSigned = true;
4512     return RHS.isMinSignedValue();
4513   case ICmpInst::ICMP_ULT:
4514     // True if LHS u< RHS and RHS == sign-bit-mask (2^7, 2^15, 2^31, etc)
4515     TrueIfSigned = false;
4516     return RHS.isMinSignedValue();
4517   case ICmpInst::ICMP_ULE:
4518     // True if LHS u<= RHS and RHS == sign-bit-mask - 1
4519     TrueIfSigned = false;
4520     return RHS.isMaxSignedValue();
4521   default:
4522     return false;
4523   }
4524 }
4525 
4526 /// Returns a pair of values, which if passed to llvm.is.fpclass, returns the
4527 /// same result as an fcmp with the given operands.
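     /// For example, with IEEE input denormals, (fcmp oeq %x, 0.0) corresponds to
     /// (llvm.is.fpclass %x, fcZero).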
4528 std::pair<Value *, FPClassTest> llvm::fcmpToClassTest(FCmpInst::Predicate Pred,
4529                                                       const Function &F,
4530                                                       Value *LHS, Value *RHS,
4531                                                       bool LookThroughSrc) {
4532   const APFloat *ConstRHS;
4533   if (!match(RHS, m_APFloatAllowPoison(ConstRHS)))
4534     return {nullptr, fcAllFlags};
4535 
4536   return fcmpToClassTest(Pred, F, LHS, ConstRHS, LookThroughSrc);
4537 }
4538 
4539 std::pair<Value *, FPClassTest>
4540 llvm::fcmpToClassTest(FCmpInst::Predicate Pred, const Function &F, Value *LHS,
4541                       const APFloat *ConstRHS, bool LookThroughSrc) {
4542 
4543   auto [Src, ClassIfTrue, ClassIfFalse] =
4544       fcmpImpliesClass(Pred, F, LHS, *ConstRHS, LookThroughSrc);
4545   if (Src && ClassIfTrue == ~ClassIfFalse)
4546     return {Src, ClassIfTrue};
4547   return {nullptr, fcAllFlags};
4548 }
4549 
4550 /// Return the return value for fcmpImpliesClass for a compare that produces an
4551 /// exact class test.
4552 static std::tuple<Value *, FPClassTest, FPClassTest> exactClass(Value *V,
4553                                                                 FPClassTest M) {
4554   return {V, M, ~M};
4555 }
4556 
4557 std::tuple<Value *, FPClassTest, FPClassTest>
4558 llvm::fcmpImpliesClass(CmpInst::Predicate Pred, const Function &F, Value *LHS,
4559                        FPClassTest RHSClass, bool LookThroughSrc) {
4560   assert(RHSClass != fcNone);
4561   Value *Src = LHS;
4562 
4563   if (Pred == FCmpInst::FCMP_TRUE)
4564     return exactClass(Src, fcAllFlags);
4565 
4566   if (Pred == FCmpInst::FCMP_FALSE)
4567     return exactClass(Src, fcNone);
4568 
4569   const FPClassTest OrigClass = RHSClass;
4570 
4571   const bool IsNegativeRHS = (RHSClass & fcNegative) == RHSClass;
4572   const bool IsPositiveRHS = (RHSClass & fcPositive) == RHSClass;
4573   const bool IsNaN = (RHSClass & ~fcNan) == fcNone;
4574 
4575   if (IsNaN) {
4576     // fcmp o__ x, nan -> false
4577     // fcmp u__ x, nan -> true
4578     return exactClass(Src, CmpInst::isOrdered(Pred) ? fcNone : fcAllFlags);
4579   }
4580 
4581   // fcmp ord x, zero|normal|subnormal|inf -> ~fcNan
4582   if (Pred == FCmpInst::FCMP_ORD)
4583     return exactClass(Src, ~fcNan);
4584 
4585   // fcmp uno x, zero|normal|subnormal|inf -> fcNan
4586   if (Pred == FCmpInst::FCMP_UNO)
4587     return exactClass(Src, fcNan);
4588 
4589   const bool IsFabs = LookThroughSrc && match(LHS, m_FAbs(m_Value(Src)));
4590   if (IsFabs)
4591     RHSClass = llvm::inverse_fabs(RHSClass);
4592 
4593   const bool IsZero = (OrigClass & fcZero) == OrigClass;
4594   if (IsZero) {
4595     assert(Pred != FCmpInst::FCMP_ORD && Pred != FCmpInst::FCMP_UNO);
4596     // A compare against zero is only an exact test for fcZero if input
4597     // denormals are not flushed.
4598     // TODO: Handle DAZ by expanding masks to cover subnormal cases.
4599     if (!inputDenormalIsIEEE(F, LHS->getType()))
4600       return {nullptr, fcAllFlags, fcAllFlags};
4601 
4602     switch (Pred) {
4603     case FCmpInst::FCMP_OEQ: // Match x == 0.0
4604       return exactClass(Src, fcZero);
4605     case FCmpInst::FCMP_UEQ: // Match isnan(x) || (x == 0.0)
4606       return exactClass(Src, fcZero | fcNan);
4607     case FCmpInst::FCMP_UNE: // Match (x != 0.0)
4608       return exactClass(Src, ~fcZero);
4609     case FCmpInst::FCMP_ONE: // Match !isnan(x) && x != 0.0
4610       return exactClass(Src, ~fcNan & ~fcZero);
4611     case FCmpInst::FCMP_ORD:
4612       // Canonical form of ord/uno is with a zero. We could also handle
4613       // non-canonical other non-NaN constants or LHS == RHS.
4614       return exactClass(Src, ~fcNan);
4615     case FCmpInst::FCMP_UNO:
4616       return exactClass(Src, fcNan);
4617     case FCmpInst::FCMP_OGT: // x > 0
4618       return exactClass(Src, fcPosSubnormal | fcPosNormal | fcPosInf);
4619     case FCmpInst::FCMP_UGT: // isnan(x) || x > 0
4620       return exactClass(Src, fcPosSubnormal | fcPosNormal | fcPosInf | fcNan);
4621     case FCmpInst::FCMP_OGE: // x >= 0
4622       return exactClass(Src, fcPositive | fcNegZero);
4623     case FCmpInst::FCMP_UGE: // isnan(x) || x >= 0
4624       return exactClass(Src, fcPositive | fcNegZero | fcNan);
4625     case FCmpInst::FCMP_OLT: // x < 0
4626       return exactClass(Src, fcNegSubnormal | fcNegNormal | fcNegInf);
4627     case FCmpInst::FCMP_ULT: // isnan(x) || x < 0
4628       return exactClass(Src, fcNegSubnormal | fcNegNormal | fcNegInf | fcNan);
4629     case FCmpInst::FCMP_OLE: // x <= 0
4630       return exactClass(Src, fcNegative | fcPosZero);
4631     case FCmpInst::FCMP_ULE: // isnan(x) || x <= 0
4632       return exactClass(Src, fcNegative | fcPosZero | fcNan);
4633     default:
4634       llvm_unreachable("all compare types are handled");
4635     }
4636 
4637     return {nullptr, fcAllFlags, fcAllFlags};
4638   }
4639 
4640   const bool IsDenormalRHS = (OrigClass & fcSubnormal) == OrigClass;
4641 
4642   const bool IsInf = (OrigClass & fcInf) == OrigClass;
4643   if (IsInf) {
4644     FPClassTest Mask = fcAllFlags;
4645 
4646     switch (Pred) {
4647     case FCmpInst::FCMP_OEQ:
4648     case FCmpInst::FCMP_UNE: {
4649       // Match __builtin_isinf patterns
4650       //
4651       //   fcmp oeq x, +inf -> is_fpclass x, fcPosInf
4652       //   fcmp oeq fabs(x), +inf -> is_fpclass x, fcInf
4653       //   fcmp oeq x, -inf -> is_fpclass x, fcNegInf
4654       //   fcmp oeq fabs(x), -inf -> is_fpclass x, 0 -> false
4655       //
4656       //   fcmp une x, +inf -> is_fpclass x, ~fcPosInf
4657       //   fcmp une fabs(x), +inf -> is_fpclass x, ~fcInf
4658       //   fcmp une x, -inf -> is_fpclass x, ~fcNegInf
4659       //   fcmp une fabs(x), -inf -> is_fpclass x, fcAllFlags -> true
4660       if (IsNegativeRHS) {
4661         Mask = fcNegInf;
4662         if (IsFabs)
4663           Mask = fcNone;
4664       } else {
4665         Mask = fcPosInf;
4666         if (IsFabs)
4667           Mask |= fcNegInf;
4668       }
4669       break;
4670     }
4671     case FCmpInst::FCMP_ONE:
4672     case FCmpInst::FCMP_UEQ: {
4673       // Match __builtin_isinf patterns
4674       //   fcmp one x, -inf -> is_fpclass x, ~fcNegInf & ~fcNan
4675       //   fcmp one fabs(x), -inf -> is_fpclass x, ~fcNan
4676       //   fcmp one x, +inf -> is_fpclass x, ~fcPosInf & ~fcNan
4677       //   fcmp one fabs(x), +inf -> is_fpclass x, ~fcInf & ~fcNan
4678       //
4679       //   fcmp ueq x, +inf -> is_fpclass x, fcPosInf|fcNan
4680       //   fcmp ueq (fabs x), +inf -> is_fpclass x, fcInf|fcNan
4681       //   fcmp ueq x, -inf -> is_fpclass x, fcNegInf|fcNan
4682       //   fcmp ueq fabs(x), -inf -> is_fpclass x, fcNan
4683       if (IsNegativeRHS) {
4684         Mask = ~fcNegInf & ~fcNan;
4685         if (IsFabs)
4686           Mask = ~fcNan;
4687       } else {
4688         Mask = ~fcPosInf & ~fcNan;
4689         if (IsFabs)
4690           Mask &= ~fcNegInf;
4691       }
4692 
4693       break;
4694     }
4695     case FCmpInst::FCMP_OLT:
4696     case FCmpInst::FCMP_UGE: {
4697       if (IsNegativeRHS) {
4698         // No value is ordered and less than negative infinity.
4699         // Every value is either unordered with -inf or >= -inf.
4700         // fcmp olt x, -inf -> false
4701         // fcmp uge x, -inf -> true
4702         Mask = fcNone;
4703         break;
4704       }
4705 
4706       // fcmp olt fabs(x), +inf -> fcFinite
4707       // fcmp uge fabs(x), +inf -> ~fcFinite
4708       // fcmp olt x, +inf -> fcFinite|fcNegInf
4709       // fcmp uge x, +inf -> ~(fcFinite|fcNegInf)
4710       Mask = fcFinite;
4711       if (!IsFabs)
4712         Mask |= fcNegInf;
4713       break;
4714     }
4715     case FCmpInst::FCMP_OGE:
4716     case FCmpInst::FCMP_ULT: {
4717       if (IsNegativeRHS) {
4718         // fcmp oge x, -inf -> ~fcNan
4719         // fcmp oge fabs(x), -inf -> ~fcNan
4720         // fcmp ult x, -inf -> fcNan
4721         // fcmp ult fabs(x), -inf -> fcNan
4722         Mask = ~fcNan;
4723         break;
4724       }
4725 
4726       // fcmp oge fabs(x), +inf -> fcInf
4727       // fcmp oge x, +inf -> fcPosInf
4728       // fcmp ult fabs(x), +inf -> ~fcInf
4729       // fcmp ult x, +inf -> ~fcPosInf
4730       Mask = fcPosInf;
4731       if (IsFabs)
4732         Mask |= fcNegInf;
4733       break;
4734     }
4735     case FCmpInst::FCMP_OGT:
4736     case FCmpInst::FCMP_ULE: {
4737       if (IsNegativeRHS) {
4738         // fcmp ogt x, -inf -> fcmp one x, -inf
4739         // fcmp ogt fabs(x), -inf -> fcmp ord x, x
4740         // fcmp ule x, -inf -> fcmp ueq x, -inf
4741         // fcmp ule fabs(x), -inf -> fcmp uno x, x
4742         Mask = IsFabs ? ~fcNan : ~(fcNegInf | fcNan);
4743         break;
4744       }
4745 
4746       // No value is ordered and greater than infinity.
4747       Mask = fcNone;
4748       break;
4749     }
4750     case FCmpInst::FCMP_OLE:
4751     case FCmpInst::FCMP_UGT: {
4752       if (IsNegativeRHS) {
4753         Mask = IsFabs ? fcNone : fcNegInf;
4754         break;
4755       }
4756 
4757       // fcmp ole x, +inf -> fcmp ord x, x
4758       // fcmp ole fabs(x), +inf -> fcmp ord x, x
4759       // fcmp ole x, -inf -> fcmp oeq x, -inf
4760       // fcmp ole fabs(x), -inf -> false
4761       Mask = ~fcNan;
4762       break;
4763     }
4764     default:
4765       llvm_unreachable("all compare types are handled");
4766     }
4767 
4768     // Invert the comparison for the unordered cases.
4769     if (FCmpInst::isUnordered(Pred))
4770       Mask = ~Mask;
4771 
4772     return exactClass(Src, Mask);
4773   }
4774 
4775   if (Pred == FCmpInst::FCMP_OEQ)
4776     return {Src, RHSClass, fcAllFlags};
4777 
4778   if (Pred == FCmpInst::FCMP_UEQ) {
4779     FPClassTest Class = RHSClass | fcNan;
4780     return {Src, Class, ~fcNan};
4781   }
4782 
4783   if (Pred == FCmpInst::FCMP_ONE)
4784     return {Src, ~fcNan, RHSClass | fcNan};
4785 
4786   if (Pred == FCmpInst::FCMP_UNE)
4787     return {Src, fcAllFlags, RHSClass};
4788 
4789   assert((RHSClass == fcNone || RHSClass == fcPosNormal ||
4790           RHSClass == fcNegNormal || RHSClass == fcNormal ||
4791           RHSClass == fcPosSubnormal || RHSClass == fcNegSubnormal ||
4792           RHSClass == fcSubnormal) &&
4793          "should have been recognized as an exact class test");
4794 
4795   if (IsNegativeRHS) {
4796     // TODO: Handle fneg(fabs)
4797     if (IsFabs) {
4798       // fabs(x) o> -k -> fcmp ord x, x
4799       // fabs(x) u> -k -> true
4800       // fabs(x) o< -k -> false
4801       // fabs(x) u< -k -> fcmp uno x, x
4802       switch (Pred) {
4803       case FCmpInst::FCMP_OGT:
4804       case FCmpInst::FCMP_OGE:
4805         return {Src, ~fcNan, fcNan};
4806       case FCmpInst::FCMP_UGT:
4807       case FCmpInst::FCMP_UGE:
4808         return {Src, fcAllFlags, fcNone};
4809       case FCmpInst::FCMP_OLT:
4810       case FCmpInst::FCMP_OLE:
4811         return {Src, fcNone, fcAllFlags};
4812       case FCmpInst::FCMP_ULT:
4813       case FCmpInst::FCMP_ULE:
4814         return {Src, fcNan, ~fcNan};
4815       default:
4816         break;
4817       }
4818 
4819       return {nullptr, fcAllFlags, fcAllFlags};
4820     }
4821 
4822     FPClassTest ClassesLE = fcNegInf | fcNegNormal;
4823     FPClassTest ClassesGE = fcPositive | fcNegZero | fcNegSubnormal;
4824 
4825     if (IsDenormalRHS)
4826       ClassesLE |= fcNegSubnormal;
4827     else
4828       ClassesGE |= fcNegNormal;
4829 
4830     switch (Pred) {
4831     case FCmpInst::FCMP_OGT:
4832     case FCmpInst::FCMP_OGE:
4833       return {Src, ClassesGE, ~ClassesGE | RHSClass};
4834     case FCmpInst::FCMP_UGT:
4835     case FCmpInst::FCMP_UGE:
4836       return {Src, ClassesGE | fcNan, ~(ClassesGE | fcNan) | RHSClass};
4837     case FCmpInst::FCMP_OLT:
4838     case FCmpInst::FCMP_OLE:
4839       return {Src, ClassesLE, ~ClassesLE | RHSClass};
4840     case FCmpInst::FCMP_ULT:
4841     case FCmpInst::FCMP_ULE:
4842       return {Src, ClassesLE | fcNan, ~(ClassesLE | fcNan) | RHSClass};
4843     default:
4844       break;
4845     }
4846   } else if (IsPositiveRHS) {
4847     FPClassTest ClassesGE = fcPosNormal | fcPosInf;
4848     FPClassTest ClassesLE = fcNegative | fcPosZero | fcPosSubnormal;
4849     if (IsDenormalRHS)
4850       ClassesGE |= fcPosSubnormal;
4851     else
4852       ClassesLE |= fcPosNormal;
4853 
4854     if (IsFabs) {
4855       ClassesGE = llvm::inverse_fabs(ClassesGE);
4856       ClassesLE = llvm::inverse_fabs(ClassesLE);
4857     }
4858 
4859     switch (Pred) {
4860     case FCmpInst::FCMP_OGT:
4861     case FCmpInst::FCMP_OGE:
4862       return {Src, ClassesGE, ~ClassesGE | RHSClass};
4863     case FCmpInst::FCMP_UGT:
4864     case FCmpInst::FCMP_UGE:
4865       return {Src, ClassesGE | fcNan, ~(ClassesGE | fcNan) | RHSClass};
4866     case FCmpInst::FCMP_OLT:
4867     case FCmpInst::FCMP_OLE:
4868       return {Src, ClassesLE, ~ClassesLE | RHSClass};
4869     case FCmpInst::FCMP_ULT:
4870     case FCmpInst::FCMP_ULE:
4871       return {Src, ClassesLE | fcNan, ~(ClassesLE | fcNan) | RHSClass};
4872     default:
4873       break;
4874     }
4875   }
4876 
4877   return {nullptr, fcAllFlags, fcAllFlags};
4878 }
4879 
4880 std::tuple<Value *, FPClassTest, FPClassTest>
4881 llvm::fcmpImpliesClass(CmpInst::Predicate Pred, const Function &F, Value *LHS,
4882                        const APFloat &ConstRHS, bool LookThroughSrc) {
4883   // We can refine checks against smallest normal / largest denormal to an
4884   // exact class test.
4885   if (!ConstRHS.isNegative() && ConstRHS.isSmallestNormalized()) {
4886     Value *Src = LHS;
4887     const bool IsFabs = LookThroughSrc && match(LHS, m_FAbs(m_Value(Src)));
4888 
4889     FPClassTest Mask;
4890     // Match pattern that's used in __builtin_isnormal.
4891     switch (Pred) {
4892     case FCmpInst::FCMP_OLT:
4893     case FCmpInst::FCMP_UGE: {
4894       // fcmp olt x, smallest_normal -> fcNegInf|fcNegNormal|fcSubnormal|fcZero
4895       // fcmp olt fabs(x), smallest_normal -> fcSubnormal|fcZero
4896       // fcmp uge x, smallest_normal -> fcNan|fcPosNormal|fcPosInf
4897       // fcmp uge fabs(x), smallest_normal -> ~(fcSubnormal|fcZero)
4898       Mask = fcZero | fcSubnormal;
4899       if (!IsFabs)
4900         Mask |= fcNegNormal | fcNegInf;
4901 
4902       break;
4903     }
4904     case FCmpInst::FCMP_OGE:
4905     case FCmpInst::FCMP_ULT: {
4906       // fcmp oge x, smallest_normal -> fcPosNormal | fcPosInf
4907       // fcmp oge fabs(x), smallest_normal -> fcInf | fcNormal
4908       // fcmp ult x, smallest_normal -> ~(fcPosNormal | fcPosInf)
4909       // fcmp ult fabs(x), smallest_normal -> ~(fcInf | fcNormal)
4910       Mask = fcPosInf | fcPosNormal;
4911       if (IsFabs)
4912         Mask |= fcNegInf | fcNegNormal;
4913       break;
4914     }
4915     default:
4916       return fcmpImpliesClass(Pred, F, LHS, ConstRHS.classify(),
4917                               LookThroughSrc);
4918     }
4919 
4920     // Invert the comparison for the unordered cases.
4921     if (FCmpInst::isUnordered(Pred))
4922       Mask = ~Mask;
4923 
4924     return exactClass(Src, Mask);
4925   }
4926 
4927   return fcmpImpliesClass(Pred, F, LHS, ConstRHS.classify(), LookThroughSrc);
4928 }
4929 
4930 std::tuple<Value *, FPClassTest, FPClassTest>
4931 llvm::fcmpImpliesClass(CmpInst::Predicate Pred, const Function &F, Value *LHS,
4932                        Value *RHS, bool LookThroughSrc) {
4933   const APFloat *ConstRHS;
4934   if (!match(RHS, m_APFloatAllowPoison(ConstRHS)))
4935     return {nullptr, fcAllFlags, fcAllFlags};
4936 
4937   // TODO: Just call computeKnownFPClass for RHS to handle non-constants.
4938   return fcmpImpliesClass(Pred, F, LHS, *ConstRHS, LookThroughSrc);
4939 }
4940 
4941 static void computeKnownFPClassFromCond(const Value *V, Value *Cond,
4942                                         unsigned Depth, bool CondIsTrue,
4943                                         const Instruction *CxtI,
4944                                         KnownFPClass &KnownFromContext) {
4945   Value *A, *B;
4946   if (Depth < MaxAnalysisRecursionDepth &&
4947       (CondIsTrue ? match(Cond, m_LogicalAnd(m_Value(A), m_Value(B)))
4948                   : match(Cond, m_LogicalOr(m_Value(A), m_Value(B))))) {
4949     computeKnownFPClassFromCond(V, A, Depth + 1, CondIsTrue, CxtI,
4950                                 KnownFromContext);
4951     computeKnownFPClassFromCond(V, B, Depth + 1, CondIsTrue, CxtI,
4952                                 KnownFromContext);
4953     return;
4954   }
4955   CmpPredicate Pred;
4956   Value *LHS;
4957   uint64_t ClassVal = 0;
4958   const APFloat *CRHS;
4959   const APInt *RHS;
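       // A dominating fcmp against a constant implies a class test on its operand,
       // e.g. knowing (fcmp ogt %x, 0.0) holds (with IEEE input denormals) rules
       // out NaN, both zeros and every negative class for %x.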
4960   if (match(Cond, m_FCmp(Pred, m_Value(LHS), m_APFloat(CRHS)))) {
4961     auto [CmpVal, MaskIfTrue, MaskIfFalse] = fcmpImpliesClass(
4962         Pred, *CxtI->getParent()->getParent(), LHS, *CRHS, LHS != V);
4963     if (CmpVal == V)
4964       KnownFromContext.knownNot(~(CondIsTrue ? MaskIfTrue : MaskIfFalse));
4965   } else if (match(Cond, m_Intrinsic<Intrinsic::is_fpclass>(
4966                              m_Specific(V), m_ConstantInt(ClassVal)))) {
4967     FPClassTest Mask = static_cast<FPClassTest>(ClassVal);
4968     KnownFromContext.knownNot(CondIsTrue ? ~Mask : Mask);
4969   } else if (match(Cond, m_ICmp(Pred, m_ElementWiseBitCast(m_Specific(V)),
4970                                 m_APInt(RHS)))) {
4971     bool TrueIfSigned;
4972     if (!isSignBitCheck(Pred, *RHS, TrueIfSigned))
4973       return;
4974     if (TrueIfSigned == CondIsTrue)
4975       KnownFromContext.signBitMustBeOne();
4976     else
4977       KnownFromContext.signBitMustBeZero();
4978   }
4979 }
4980 
4981 static KnownFPClass computeKnownFPClassFromContext(const Value *V,
4982                                                    const SimplifyQuery &Q) {
4983   KnownFPClass KnownFromContext;
4984 
4985   if (!Q.CxtI)
4986     return KnownFromContext;
4987 
4988   if (Q.DC && Q.DT) {
4989     // Handle dominating conditions.
4990     for (BranchInst *BI : Q.DC->conditionsFor(V)) {
4991       Value *Cond = BI->getCondition();
4992 
4993       BasicBlockEdge Edge0(BI->getParent(), BI->getSuccessor(0));
4994       if (Q.DT->dominates(Edge0, Q.CxtI->getParent()))
4995         computeKnownFPClassFromCond(V, Cond, /*Depth=*/0, /*CondIsTrue=*/true,
4996                                     Q.CxtI, KnownFromContext);
4997 
4998       BasicBlockEdge Edge1(BI->getParent(), BI->getSuccessor(1));
4999       if (Q.DT->dominates(Edge1, Q.CxtI->getParent()))
5000         computeKnownFPClassFromCond(V, Cond, /*Depth=*/0, /*CondIsTrue=*/false,
5001                                     Q.CxtI, KnownFromContext);
5002     }
5003   }
5004 
5005   if (!Q.AC)
5006     return KnownFromContext;
5007 
5008   // Try to restrict the floating-point classes based on information from
5009   // assumptions.
5010   for (auto &AssumeVH : Q.AC->assumptionsFor(V)) {
5011     if (!AssumeVH)
5012       continue;
5013     CallInst *I = cast<CallInst>(AssumeVH);
5014 
5015     assert(I->getFunction() == Q.CxtI->getParent()->getParent() &&
5016            "Got assumption for the wrong function!");
5017     assert(I->getIntrinsicID() == Intrinsic::assume &&
5018            "must be an assume intrinsic");
5019 
5020     if (!isValidAssumeForContext(I, Q.CxtI, Q.DT))
5021       continue;
5022 
5023     computeKnownFPClassFromCond(V, I->getArgOperand(0), /*Depth=*/0,
5024                                 /*CondIsTrue=*/true, Q.CxtI, KnownFromContext);
5025   }
5026 
5027   return KnownFromContext;
5028 }
5029 
5030 void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
5031                          FPClassTest InterestedClasses, KnownFPClass &Known,
5032                          unsigned Depth, const SimplifyQuery &Q);
5033 
5034 static void computeKnownFPClass(const Value *V, KnownFPClass &Known,
5035                                 FPClassTest InterestedClasses, unsigned Depth,
5036                                 const SimplifyQuery &Q) {
5037   auto *FVTy = dyn_cast<FixedVectorType>(V->getType());
5038   APInt DemandedElts =
5039       FVTy ? APInt::getAllOnes(FVTy->getNumElements()) : APInt(1, 1);
5040   computeKnownFPClass(V, DemandedElts, InterestedClasses, Known, Depth, Q);
5041 }
5042 
5043 static void computeKnownFPClassForFPTrunc(const Operator *Op,
5044                                           const APInt &DemandedElts,
5045                                           FPClassTest InterestedClasses,
5046                                           KnownFPClass &Known, unsigned Depth,
5047                                           const SimplifyQuery &Q) {
5048   if ((InterestedClasses &
5049        (KnownFPClass::OrderedLessThanZeroMask | fcNan)) == fcNone)
5050     return;
5051 
5052   KnownFPClass KnownSrc;
5053   computeKnownFPClass(Op->getOperand(0), DemandedElts, InterestedClasses,
5054                       KnownSrc, Depth + 1, Q);
5055 
5056   // Sign should be preserved
5057   // TODO: Handle cannot be ordered greater than zero
5058   if (KnownSrc.cannotBeOrderedLessThanZero())
5059     Known.knownNot(KnownFPClass::OrderedLessThanZeroMask);
5060 
5061   Known.propagateNaN(KnownSrc, true);
5062 
5063   // Infinity needs a range check.
5064 }
5065 
5066 void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
5067                          FPClassTest InterestedClasses, KnownFPClass &Known,
5068                          unsigned Depth, const SimplifyQuery &Q) {
5069   assert(Known.isUnknown() && "should not be called with known information");
5070 
5071   if (!DemandedElts) {
5072     // No demanded elts, better to assume we don't know anything.
5073     Known.resetAll();
5074     return;
5075   }
5076 
5077   assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth");
5078 
5079   if (auto *CFP = dyn_cast<ConstantFP>(V)) {
5080     Known.KnownFPClasses = CFP->getValueAPF().classify();
5081     Known.SignBit = CFP->isNegative();
5082     return;
5083   }
5084 
5085   if (isa<ConstantAggregateZero>(V)) {
5086     Known.KnownFPClasses = fcPosZero;
5087     Known.SignBit = false;
5088     return;
5089   }
5090 
5091   if (isa<PoisonValue>(V)) {
5092     Known.KnownFPClasses = fcNone;
5093     Known.SignBit = false;
5094     return;
5095   }
5096 
5097   // Try to handle fixed width vector constants
5098   auto *VFVTy = dyn_cast<FixedVectorType>(V->getType());
5099   const Constant *CV = dyn_cast<Constant>(V);
5100   if (VFVTy && CV) {
5101     Known.KnownFPClasses = fcNone;
5102     bool SignBitAllZero = true;
5103     bool SignBitAllOne = true;
5104 
5105     // Union the classes of the demanded elements; track a shared sign bit.
5106     unsigned NumElts = VFVTy->getNumElements();
5107     for (unsigned i = 0; i != NumElts; ++i) {
5108       if (!DemandedElts[i])
5109         continue;
5110 
5111       Constant *Elt = CV->getAggregateElement(i);
5112       if (!Elt) {
5113         Known = KnownFPClass();
5114         return;
5115       }
5116       if (isa<PoisonValue>(Elt))
5117         continue;
5118       auto *CElt = dyn_cast<ConstantFP>(Elt);
5119       if (!CElt) {
5120         Known = KnownFPClass();
5121         return;
5122       }
5123 
5124       const APFloat &C = CElt->getValueAPF();
5125       Known.KnownFPClasses |= C.classify();
5126       if (C.isNegative())
5127         SignBitAllZero = false;
5128       else
5129         SignBitAllOne = false;
5130     }
5131     if (SignBitAllOne != SignBitAllZero)
5132       Known.SignBit = SignBitAllOne;
5133     return;
5134   }
5135 
5136   FPClassTest KnownNotFromFlags = fcNone;
5137   if (const auto *CB = dyn_cast<CallBase>(V))
5138     KnownNotFromFlags |= CB->getRetNoFPClass();
5139   else if (const auto *Arg = dyn_cast<Argument>(V))
5140     KnownNotFromFlags |= Arg->getNoFPClass();
5141 
5142   const Operator *Op = dyn_cast<Operator>(V);
5143   if (const FPMathOperator *FPOp = dyn_cast_or_null<FPMathOperator>(Op)) {
5144     if (FPOp->hasNoNaNs())
5145       KnownNotFromFlags |= fcNan;
5146     if (FPOp->hasNoInfs())
5147       KnownNotFromFlags |= fcInf;
5148   }
5149 
5150   KnownFPClass AssumedClasses = computeKnownFPClassFromContext(V, Q);
5151   KnownNotFromFlags |= ~AssumedClasses.KnownFPClasses;
5152 
5153   // We no longer need to find out about these bits from inputs if we can
5154   // assume this from flags/attributes.
5155   InterestedClasses &= ~KnownNotFromFlags;
5156 
5157   auto ClearClassesFromFlags = make_scope_exit([=, &Known] {
5158     Known.knownNot(KnownNotFromFlags);
5159     if (!Known.SignBit && AssumedClasses.SignBit) {
5160       if (*AssumedClasses.SignBit)
5161         Known.signBitMustBeOne();
5162       else
5163         Known.signBitMustBeZero();
5164     }
5165   });
5166 
5167   if (!Op)
5168     return;
5169 
5170   // All recursive calls that increase depth must come after this.
5171   if (Depth == MaxAnalysisRecursionDepth)
5172     return;
5173 
5174   const unsigned Opc = Op->getOpcode();
5175   switch (Opc) {
5176   case Instruction::FNeg: {
5177     computeKnownFPClass(Op->getOperand(0), DemandedElts, InterestedClasses,
5178                         Known, Depth + 1, Q);
5179     Known.fneg();
5180     break;
5181   }
5182   case Instruction::Select: {
5183     Value *Cond = Op->getOperand(0);
5184     Value *LHS = Op->getOperand(1);
5185     Value *RHS = Op->getOperand(2);
5186 
5187     FPClassTest FilterLHS = fcAllFlags;
5188     FPClassTest FilterRHS = fcAllFlags;
5189 
5190     Value *TestedValue = nullptr;
5191     FPClassTest MaskIfTrue = fcAllFlags;
5192     FPClassTest MaskIfFalse = fcAllFlags;
5193     uint64_t ClassVal = 0;
5194     const Function *F = cast<Instruction>(Op)->getFunction();
5195     CmpPredicate Pred;
5196     Value *CmpLHS, *CmpRHS;
5197     if (F && match(Cond, m_FCmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) {
5198       // If the select filters out a value based on the class, it no longer
5199       // participates in the class of the result
5200 
5201       // TODO: In some degenerate cases we can infer something if we try again
5202       // without looking through sign operations.
5203       bool LookThroughFAbsFNeg = CmpLHS != LHS && CmpLHS != RHS;
5204       std::tie(TestedValue, MaskIfTrue, MaskIfFalse) =
5205           fcmpImpliesClass(Pred, *F, CmpLHS, CmpRHS, LookThroughFAbsFNeg);
5206     } else if (match(Cond,
5207                      m_Intrinsic<Intrinsic::is_fpclass>(
5208                          m_Value(TestedValue), m_ConstantInt(ClassVal)))) {
5209       FPClassTest TestedMask = static_cast<FPClassTest>(ClassVal);
5210       MaskIfTrue = TestedMask;
5211       MaskIfFalse = ~TestedMask;
5212     }
5213 
5214     if (TestedValue == LHS) {
5215       // match !isnan(x) ? x : y
5216       FilterLHS = MaskIfTrue;
5217     } else if (TestedValue == RHS) { // && IsExactClass
5218       // match !isnan(x) ? y : x
5219       FilterRHS = MaskIfFalse;
5220     }
5221 
5222     KnownFPClass Known2;
5223     computeKnownFPClass(LHS, DemandedElts, InterestedClasses & FilterLHS, Known,
5224                         Depth + 1, Q);
5225     Known.KnownFPClasses &= FilterLHS;
5226 
5227     computeKnownFPClass(RHS, DemandedElts, InterestedClasses & FilterRHS,
5228                         Known2, Depth + 1, Q);
5229     Known2.KnownFPClasses &= FilterRHS;
5230 
5231     Known |= Known2;
5232     break;
5233   }
5234   case Instruction::Call: {
5235     const CallInst *II = cast<CallInst>(Op);
5236     const Intrinsic::ID IID = II->getIntrinsicID();
5237     switch (IID) {
5238     case Intrinsic::fabs: {
5239       if ((InterestedClasses & (fcNan | fcPositive)) != fcNone) {
5240         // If we only care about the sign bit we don't need to inspect the
5241         // operand.
5242         computeKnownFPClass(II->getArgOperand(0), DemandedElts,
5243                             InterestedClasses, Known, Depth + 1, Q);
5244       }
5245 
5246       Known.fabs();
5247       break;
5248     }
5249     case Intrinsic::copysign: {
5250       KnownFPClass KnownSign;
5251 
5252       computeKnownFPClass(II->getArgOperand(0), DemandedElts, InterestedClasses,
5253                           Known, Depth + 1, Q);
5254       computeKnownFPClass(II->getArgOperand(1), DemandedElts, InterestedClasses,
5255                           KnownSign, Depth + 1, Q);
5256       Known.copysign(KnownSign);
5257       break;
5258     }
5259     case Intrinsic::fma:
5260     case Intrinsic::fmuladd: {
5261       if ((InterestedClasses & fcNegative) == fcNone)
5262         break;
5263 
5264       if (II->getArgOperand(0) != II->getArgOperand(1))
5265         break;
5266 
5267       // The multiply cannot be -0 and therefore the add can't be -0
5268       Known.knownNot(fcNegZero);
5269 
5270       // x * x + y is non-negative if y is non-negative.
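           // e.g. fma(%x, %x, 1.0): %x * %x is never -0, and adding a value that
           // cannot be ordered less than zero keeps the result out of every
           // negative class.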
5271       KnownFPClass KnownAddend;
5272       computeKnownFPClass(II->getArgOperand(2), DemandedElts, InterestedClasses,
5273                           KnownAddend, Depth + 1, Q);
5274 
5275       if (KnownAddend.cannotBeOrderedLessThanZero())
5276         Known.knownNot(fcNegative);
5277       break;
5278     }
5279     case Intrinsic::sqrt:
5280     case Intrinsic::experimental_constrained_sqrt: {
5281       KnownFPClass KnownSrc;
5282       FPClassTest InterestedSrcs = InterestedClasses;
5283       if (InterestedClasses & fcNan)
5284         InterestedSrcs |= KnownFPClass::OrderedLessThanZeroMask;
5285 
5286       computeKnownFPClass(II->getArgOperand(0), DemandedElts, InterestedSrcs,
5287                           KnownSrc, Depth + 1, Q);
5288 
5289       if (KnownSrc.isKnownNeverPosInfinity())
5290         Known.knownNot(fcPosInf);
5291       if (KnownSrc.isKnownNever(fcSNan))
5292         Known.knownNot(fcSNan);
5293 
5294       // Any negative value besides -0 returns a nan.
5295       if (KnownSrc.isKnownNeverNaN() && KnownSrc.cannotBeOrderedLessThanZero())
5296         Known.knownNot(fcNan);
5297 
5298       // The only negative value that can be returned is -0 for -0 inputs.
5299       Known.knownNot(fcNegInf | fcNegSubnormal | fcNegNormal);
5300 
5301       // If the input denormal mode could be PreserveSign, a negative
5302       // subnormal input could produce a negative zero output.
5303       const Function *F = II->getFunction();
5304       if (Q.IIQ.hasNoSignedZeros(II) ||
5305           (F && KnownSrc.isKnownNeverLogicalNegZero(*F, II->getType())))
5306         Known.knownNot(fcNegZero);
5307 
5308       break;
5309     }
5310     case Intrinsic::sin:
5311     case Intrinsic::cos: {
5312       // Return NaN on infinite inputs.
5313       KnownFPClass KnownSrc;
5314       computeKnownFPClass(II->getArgOperand(0), DemandedElts, InterestedClasses,
5315                           KnownSrc, Depth + 1, Q);
5316       Known.knownNot(fcInf);
5317       if (KnownSrc.isKnownNeverNaN() && KnownSrc.isKnownNeverInfinity())
5318         Known.knownNot(fcNan);
5319       break;
5320     }
5321     case Intrinsic::maxnum:
5322     case Intrinsic::minnum:
5323     case Intrinsic::minimum:
5324     case Intrinsic::maximum: {
5325       KnownFPClass KnownLHS, KnownRHS;
5326       computeKnownFPClass(II->getArgOperand(0), DemandedElts, InterestedClasses,
5327                           KnownLHS, Depth + 1, Q);
5328       computeKnownFPClass(II->getArgOperand(1), DemandedElts, InterestedClasses,
5329                           KnownRHS, Depth + 1, Q);
5330 
5331       bool NeverNaN = KnownLHS.isKnownNeverNaN() || KnownRHS.isKnownNeverNaN();
5332       Known = KnownLHS | KnownRHS;
5333 
5334       // If either operand is not NaN, the result is not NaN.
5335       if (NeverNaN && (IID == Intrinsic::minnum || IID == Intrinsic::maxnum))
5336         Known.knownNot(fcNan);
5337 
5338       if (IID == Intrinsic::maxnum) {
5339         // If at least one operand is known to be positive, the result must be
5340         // positive.
5341         if ((KnownLHS.cannotBeOrderedLessThanZero() &&
5342              KnownLHS.isKnownNeverNaN()) ||
5343             (KnownRHS.cannotBeOrderedLessThanZero() &&
5344              KnownRHS.isKnownNeverNaN()))
5345           Known.knownNot(KnownFPClass::OrderedLessThanZeroMask);
5346       } else if (IID == Intrinsic::maximum) {
5347         // If at least one operand is known to be positive, the result must be
5348         // positive.
5349         if (KnownLHS.cannotBeOrderedLessThanZero() ||
5350             KnownRHS.cannotBeOrderedLessThanZero())
5351           Known.knownNot(KnownFPClass::OrderedLessThanZeroMask);
5352       } else if (IID == Intrinsic::minnum) {
5353         // If at least one operand is known to be negative, the result must be
5354         // negative.
5355         if ((KnownLHS.cannotBeOrderedGreaterThanZero() &&
5356              KnownLHS.isKnownNeverNaN()) ||
5357             (KnownRHS.cannotBeOrderedGreaterThanZero() &&
5358              KnownRHS.isKnownNeverNaN()))
5359           Known.knownNot(KnownFPClass::OrderedGreaterThanZeroMask);
5360       } else {
5361         // If at least one operand is known to be negative, the result must be
5362         // negative.
5363         if (KnownLHS.cannotBeOrderedGreaterThanZero() ||
5364             KnownRHS.cannotBeOrderedGreaterThanZero())
5365           Known.knownNot(KnownFPClass::OrderedGreaterThanZeroMask);
5366       }
5367 
5368       // Fixup zero handling if denormals could be returned as a zero.
5369       //
5370       // As there's no spec for denormal flushing, be conservative with the
5371       // treatment of denormals that could be flushed to zero. For older
5372       // subtargets on AMDGPU the min/max instructions would not flush the
5373       // output and return the original value.
5374       //
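      // For example, if only +0.0 was considered possible but a negative
      // subnormal result might be flushed under a non-IEEE denormal mode, then
      // -0.0 becomes possible as well, so both zero classes are conservatively
      // added back.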
5375       if ((Known.KnownFPClasses & fcZero) != fcNone &&
5376           !Known.isKnownNeverSubnormal()) {
5377         const Function *Parent = II->getFunction();
5378         if (!Parent)
5379           break;
5380 
5381         DenormalMode Mode = Parent->getDenormalMode(
5382             II->getType()->getScalarType()->getFltSemantics());
5383         if (Mode != DenormalMode::getIEEE())
5384           Known.KnownFPClasses |= fcZero;
5385       }
5386 
5387       if (Known.isKnownNeverNaN()) {
5388         if (KnownLHS.SignBit && KnownRHS.SignBit &&
5389             *KnownLHS.SignBit == *KnownRHS.SignBit) {
5390           if (*KnownLHS.SignBit)
5391             Known.signBitMustBeOne();
5392           else
5393             Known.signBitMustBeZero();
5394         } else if ((IID == Intrinsic::maximum || IID == Intrinsic::minimum) ||
5395                    ((KnownLHS.isKnownNeverNegZero() ||
5396                      KnownRHS.isKnownNeverPosZero()) &&
5397                     (KnownLHS.isKnownNeverPosZero() ||
5398                      KnownRHS.isKnownNeverNegZero()))) {
5399           if ((IID == Intrinsic::maximum || IID == Intrinsic::maxnum) &&
5400               (KnownLHS.SignBit == false || KnownRHS.SignBit == false))
5401             Known.signBitMustBeZero();
5402           else if ((IID == Intrinsic::minimum || IID == Intrinsic::minnum) &&
5403                    (KnownLHS.SignBit == true || KnownRHS.SignBit == true))
5404             Known.signBitMustBeOne();
5405         }
5406       }
5407       break;
5408     }
5409     case Intrinsic::canonicalize: {
5410       KnownFPClass KnownSrc;
5411       computeKnownFPClass(II->getArgOperand(0), DemandedElts, InterestedClasses,
5412                           KnownSrc, Depth + 1, Q);
5413 
5414       // This is essentially a stronger form of
5415       // propagateCanonicalizingSrc. Other "canonicalizing" operations don't
5416       // actually have an IR canonicalization guarantee.
5417 
5418       // Canonicalize may flush denormals to zero, so we have to consider the
5419       // denormal mode to preserve known-not-0 knowledge.
5420       Known.KnownFPClasses = KnownSrc.KnownFPClasses | fcZero | fcQNan;
5421 
5422       // Stronger version of propagateNaN
5423       // Canonicalize is guaranteed to quiet signaling nans.
5424       if (KnownSrc.isKnownNeverNaN())
5425         Known.knownNot(fcNan);
5426       else
5427         Known.knownNot(fcSNan);
5428 
5429       const Function *F = II->getFunction();
5430       if (!F)
5431         break;
5432 
5433       // If the parent function flushes denormals, the canonical output cannot
5434       // be a denormal.
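      // For example, with "denormal-fp-math"="preserve-sign,preserve-sign",
      // canonicalizing a subnormal input yields a zero of the same sign, so
      // the result is never subnormal.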
5435       const fltSemantics &FPType =
5436           II->getType()->getScalarType()->getFltSemantics();
5437       DenormalMode DenormMode = F->getDenormalMode(FPType);
5438       if (DenormMode == DenormalMode::getIEEE()) {
5439         if (KnownSrc.isKnownNever(fcPosZero))
5440           Known.knownNot(fcPosZero);
5441         if (KnownSrc.isKnownNever(fcNegZero))
5442           Known.knownNot(fcNegZero);
5443         break;
5444       }
5445 
5446       if (DenormMode.inputsAreZero() || DenormMode.outputsAreZero())
5447         Known.knownNot(fcSubnormal);
5448 
5449       if (DenormMode.Input == DenormalMode::PositiveZero ||
5450           (DenormMode.Output == DenormalMode::PositiveZero &&
5451            DenormMode.Input == DenormalMode::IEEE))
5452         Known.knownNot(fcNegZero);
5453 
5454       break;
5455     }
5456     case Intrinsic::vector_reduce_fmax:
5457     case Intrinsic::vector_reduce_fmin:
5458     case Intrinsic::vector_reduce_fmaximum:
5459     case Intrinsic::vector_reduce_fminimum: {
5460       // reduce min/max will return one of the vector's elements, so we can
5461       // infer any class information that is common to all elements.
5462       Known = computeKnownFPClass(II->getArgOperand(0), II->getFastMathFlags(),
5463                                   InterestedClasses, Depth + 1, Q);
5464       // Can only propagate sign if output is never NaN.
5465       if (!Known.isKnownNeverNaN())
5466         Known.SignBit.reset();
5467       break;
5468     }
5469       // reverse preserves all characteristics of the input vector's elements.
5470     case Intrinsic::vector_reverse:
5471       Known = computeKnownFPClass(
5472           II->getArgOperand(0), DemandedElts.reverseBits(),
5473           II->getFastMathFlags(), InterestedClasses, Depth + 1, Q);
5474       break;
5475     case Intrinsic::trunc:
5476     case Intrinsic::floor:
5477     case Intrinsic::ceil:
5478     case Intrinsic::rint:
5479     case Intrinsic::nearbyint:
5480     case Intrinsic::round:
5481     case Intrinsic::roundeven: {
5482       KnownFPClass KnownSrc;
5483       FPClassTest InterestedSrcs = InterestedClasses;
5484       if (InterestedSrcs & fcPosFinite)
5485         InterestedSrcs |= fcPosFinite;
5486       if (InterestedSrcs & fcNegFinite)
5487         InterestedSrcs |= fcNegFinite;
5488       computeKnownFPClass(II->getArgOperand(0), DemandedElts, InterestedSrcs,
5489                           KnownSrc, Depth + 1, Q);
5490 
5491       // Integer results cannot be subnormal.
5492       Known.knownNot(fcSubnormal);
5493 
5494       Known.propagateNaN(KnownSrc, true);
5495 
5496       // Pass through infinities, except PPC_FP128 is a special case for
5497       // intrinsics other than trunc.
5498       if (IID == Intrinsic::trunc || !V->getType()->isMultiUnitFPType()) {
5499         if (KnownSrc.isKnownNeverPosInfinity())
5500           Known.knownNot(fcPosInf);
5501         if (KnownSrc.isKnownNeverNegInfinity())
5502           Known.knownNot(fcNegInf);
5503       }
5504 
5505       // Negative values that round up to 0 produce -0.
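      // For example, ceil(-0.5) and trunc(-0.3) both produce -0.0.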
5506       if (KnownSrc.isKnownNever(fcPosFinite))
5507         Known.knownNot(fcPosFinite);
5508       if (KnownSrc.isKnownNever(fcNegFinite))
5509         Known.knownNot(fcNegFinite);
5510 
5511       break;
5512     }
5513     case Intrinsic::exp:
5514     case Intrinsic::exp2:
5515     case Intrinsic::exp10: {
5516       Known.knownNot(fcNegative);
5517       if ((InterestedClasses & fcNan) == fcNone)
5518         break;
5519 
5520       KnownFPClass KnownSrc;
5521       computeKnownFPClass(II->getArgOperand(0), DemandedElts, InterestedClasses,
5522                           KnownSrc, Depth + 1, Q);
5523       if (KnownSrc.isKnownNeverNaN()) {
5524         Known.knownNot(fcNan);
5525         Known.signBitMustBeZero();
5526       }
5527 
5528       break;
5529     }
5530     case Intrinsic::fptrunc_round: {
5531       computeKnownFPClassForFPTrunc(Op, DemandedElts, InterestedClasses, Known,
5532                                     Depth, Q);
5533       break;
5534     }
5535     case Intrinsic::log:
5536     case Intrinsic::log10:
5537     case Intrinsic::log2:
5538     case Intrinsic::experimental_constrained_log:
5539     case Intrinsic::experimental_constrained_log10:
5540     case Intrinsic::experimental_constrained_log2: {
5541       // log(+inf) -> +inf
5542       // log([+-]0.0) -> -inf
5543       // log(-inf) -> nan
5544       // log(-x) -> nan
5545       if ((InterestedClasses & (fcNan | fcInf)) == fcNone)
5546         break;
5547 
5548       FPClassTest InterestedSrcs = InterestedClasses;
5549       if ((InterestedClasses & fcNegInf) != fcNone)
5550         InterestedSrcs |= fcZero | fcSubnormal;
5551       if ((InterestedClasses & fcNan) != fcNone)
5552         InterestedSrcs |= fcNan | (fcNegative & ~fcNan);
5553 
5554       KnownFPClass KnownSrc;
5555       computeKnownFPClass(II->getArgOperand(0), DemandedElts, InterestedSrcs,
5556                           KnownSrc, Depth + 1, Q);
5557 
5558       if (KnownSrc.isKnownNeverPosInfinity())
5559         Known.knownNot(fcPosInf);
5560 
5561       if (KnownSrc.isKnownNeverNaN() && KnownSrc.cannotBeOrderedLessThanZero())
5562         Known.knownNot(fcNan);
5563 
5564       const Function *F = II->getFunction();
5565       if (F && KnownSrc.isKnownNeverLogicalZero(*F, II->getType()))
5566         Known.knownNot(fcNegInf);
5567 
5568       break;
5569     }
5570     case Intrinsic::powi: {
5571       if ((InterestedClasses & fcNegative) == fcNone)
5572         break;
5573 
5574       const Value *Exp = II->getArgOperand(1);
5575       Type *ExpTy = Exp->getType();
5576       unsigned BitWidth = ExpTy->getScalarType()->getIntegerBitWidth();
5577       KnownBits ExponentKnownBits(BitWidth);
5578       computeKnownBits(Exp, isa<VectorType>(ExpTy) ? DemandedElts : APInt(1, 1),
5579                        ExponentKnownBits, Depth + 1, Q);
5580 
5581       if (ExponentKnownBits.Zero[0]) { // Is even
5582         Known.knownNot(fcNegative);
5583         break;
5584       }
5585 
5586       // Given that exp is an integer, here are the
5587       // ways that pow can return a negative value:
5588       //
5589       //   pow(-x, exp)   --> negative if exp is odd and x is negative.
5590       //   pow(-0, exp)   --> -inf if exp is negative odd.
5591       //   pow(-0, exp)   --> -0 if exp is positive odd.
5592       //   pow(-inf, exp) --> -0 if exp is negative odd.
5593       //   pow(-inf, exp) --> -inf if exp is positive odd.
5594       KnownFPClass KnownSrc;
5595       computeKnownFPClass(II->getArgOperand(0), DemandedElts, fcNegative,
5596                           KnownSrc, Depth + 1, Q);
5597       if (KnownSrc.isKnownNever(fcNegative))
5598         Known.knownNot(fcNegative);
5599       break;
5600     }
5601     case Intrinsic::ldexp: {
5602       KnownFPClass KnownSrc;
5603       computeKnownFPClass(II->getArgOperand(0), DemandedElts, InterestedClasses,
5604                           KnownSrc, Depth + 1, Q);
5605       Known.propagateNaN(KnownSrc, /*PropagateSign=*/true);
5606 
5607       // Sign is preserved, but underflows may produce zeroes.
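      // For example, ldexp(float 1.0, -200) underflows to +0.0 and
      // ldexp(float -1.0, -200) underflows to -0.0.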
5608       if (KnownSrc.isKnownNever(fcNegative))
5609         Known.knownNot(fcNegative);
5610       else if (KnownSrc.cannotBeOrderedLessThanZero())
5611         Known.knownNot(KnownFPClass::OrderedLessThanZeroMask);
5612 
5613       if (KnownSrc.isKnownNever(fcPositive))
5614         Known.knownNot(fcPositive);
5615       else if (KnownSrc.cannotBeOrderedGreaterThanZero())
5616         Known.knownNot(KnownFPClass::OrderedGreaterThanZeroMask);
5617 
5618       // Can refine inf/zero handling based on the exponent operand.
5619       const FPClassTest ExpInfoMask = fcZero | fcSubnormal | fcInf;
5620       if ((InterestedClasses & ExpInfoMask) == fcNone)
5621         break;
5622       if ((KnownSrc.KnownFPClasses & ExpInfoMask) == fcNone)
5623         break;
5624 
5625       const fltSemantics &Flt =
5626           II->getType()->getScalarType()->getFltSemantics();
5627       unsigned Precision = APFloat::semanticsPrecision(Flt);
5628       const Value *ExpArg = II->getArgOperand(1);
5629       ConstantRange ExpRange = computeConstantRange(
5630           ExpArg, true, Q.IIQ.UseInstrInfo, Q.AC, Q.CxtI, Q.DT, Depth + 1);
5631 
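      // If the exponent is known to be at least the mantissa width, any
      // nonzero finite input, including the smallest subnormal, is scaled up
      // into the normal range, so the result cannot be subnormal.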
5632       const int MantissaBits = Precision - 1;
5633       if (ExpRange.getSignedMin().sge(static_cast<int64_t>(MantissaBits)))
5634         Known.knownNot(fcSubnormal);
5635 
5636       const Function *F = II->getFunction();
5637       const APInt *ConstVal = ExpRange.getSingleElement();
5638       if (ConstVal && ConstVal->isZero()) {
5639         // ldexp(x, 0) -> x, so propagate everything.
5640         Known.propagateCanonicalizingSrc(KnownSrc, *F, II->getType());
5641       } else if (ExpRange.isAllNegative()) {
5642         // If we know the power is <= 0, can't introduce inf
5643         if (KnownSrc.isKnownNeverPosInfinity())
5644           Known.knownNot(fcPosInf);
5645         if (KnownSrc.isKnownNeverNegInfinity())
5646           Known.knownNot(fcNegInf);
5647       } else if (ExpRange.isAllNonNegative()) {
5648         // If we know the power is >= 0, can't introduce subnormal or zero
5649         if (KnownSrc.isKnownNeverPosSubnormal())
5650           Known.knownNot(fcPosSubnormal);
5651         if (KnownSrc.isKnownNeverNegSubnormal())
5652           Known.knownNot(fcNegSubnormal);
5653         if (F && KnownSrc.isKnownNeverLogicalPosZero(*F, II->getType()))
5654           Known.knownNot(fcPosZero);
5655         if (F && KnownSrc.isKnownNeverLogicalNegZero(*F, II->getType()))
5656           Known.knownNot(fcNegZero);
5657       }
5658 
5659       break;
5660     }
5661     case Intrinsic::arithmetic_fence: {
5662       computeKnownFPClass(II->getArgOperand(0), DemandedElts, InterestedClasses,
5663                           Known, Depth + 1, Q);
5664       break;
5665     }
5666     case Intrinsic::experimental_constrained_sitofp:
5667     case Intrinsic::experimental_constrained_uitofp:
5668       // Cannot produce nan
5669       Known.knownNot(fcNan);
5670 
5671       // sitofp and uitofp turn into +0.0 for zero.
5672       Known.knownNot(fcNegZero);
5673 
5674       // Integers cannot be subnormal
5675       Known.knownNot(fcSubnormal);
5676 
5677       if (IID == Intrinsic::experimental_constrained_uitofp)
5678         Known.signBitMustBeZero();
5679 
5680       // TODO: Copy inf handling from instructions
5681       break;
5682     default:
5683       break;
5684     }
5685 
5686     break;
5687   }
5688   case Instruction::FAdd:
5689   case Instruction::FSub: {
5690     KnownFPClass KnownLHS, KnownRHS;
5691     bool WantNegative =
5692         Op->getOpcode() == Instruction::FAdd &&
5693         (InterestedClasses & KnownFPClass::OrderedLessThanZeroMask) != fcNone;
5694     bool WantNaN = (InterestedClasses & fcNan) != fcNone;
5695     bool WantNegZero = (InterestedClasses & fcNegZero) != fcNone;
5696 
5697     if (!WantNaN && !WantNegative && !WantNegZero)
5698       break;
5699 
5700     FPClassTest InterestedSrcs = InterestedClasses;
5701     if (WantNegative)
5702       InterestedSrcs |= KnownFPClass::OrderedLessThanZeroMask;
5703     if (InterestedClasses & fcNan)
5704       InterestedSrcs |= fcInf;
5705     computeKnownFPClass(Op->getOperand(1), DemandedElts, InterestedSrcs,
5706                         KnownRHS, Depth + 1, Q);
5707 
5708     if ((WantNaN && KnownRHS.isKnownNeverNaN()) ||
5709         (WantNegative && KnownRHS.cannotBeOrderedLessThanZero()) ||
5710         WantNegZero || Opc == Instruction::FSub) {
5711 
5712       // RHS is canonically cheaper to compute. Skip inspecting the LHS if
5713       // there's no point.
5714       computeKnownFPClass(Op->getOperand(0), DemandedElts, InterestedSrcs,
5715                           KnownLHS, Depth + 1, Q);
5716       // Adding positive and negative infinity produces NaN.
5717       // TODO: Check sign of infinities.
5718       if (KnownLHS.isKnownNeverNaN() && KnownRHS.isKnownNeverNaN() &&
5719           (KnownLHS.isKnownNeverInfinity() || KnownRHS.isKnownNeverInfinity()))
5720         Known.knownNot(fcNan);
5721 
5722       // FIXME: Context function should always be passed in separately
5723       const Function *F = cast<Instruction>(Op)->getFunction();
5724 
5725       if (Op->getOpcode() == Instruction::FAdd) {
5726         if (KnownLHS.cannotBeOrderedLessThanZero() &&
5727             KnownRHS.cannotBeOrderedLessThanZero())
5728           Known.knownNot(KnownFPClass::OrderedLessThanZeroMask);
5729         if (!F)
5730           break;
5731 
5732         // (fadd x, 0.0) is guaranteed to return +0.0, not -0.0.
5733         if ((KnownLHS.isKnownNeverLogicalNegZero(*F, Op->getType()) ||
5734              KnownRHS.isKnownNeverLogicalNegZero(*F, Op->getType())) &&
5735             // Make sure output negative denormal can't flush to -0
5736             outputDenormalIsIEEEOrPosZero(*F, Op->getType()))
5737           Known.knownNot(fcNegZero);
5738       } else {
5739         if (!F)
5740           break;
5741 
5742         // Only fsub -0, +0 can return -0
5743         if ((KnownLHS.isKnownNeverLogicalNegZero(*F, Op->getType()) ||
5744              KnownRHS.isKnownNeverLogicalPosZero(*F, Op->getType())) &&
5745             // Make sure output negative denormal can't flush to -0
5746             outputDenormalIsIEEEOrPosZero(*F, Op->getType()))
5747           Known.knownNot(fcNegZero);
5748       }
5749     }
5750 
5751     break;
5752   }
5753   case Instruction::FMul: {
5754     // X * X is always non-negative or a NaN.
5755     if (Op->getOperand(0) == Op->getOperand(1))
5756       Known.knownNot(fcNegative);
5757 
5758     if ((InterestedClasses & fcNan) != fcNan)
5759       break;
5760 
5761     // fcSubnormal is only needed in case of DAZ.
5762     const FPClassTest NeedForNan = fcNan | fcInf | fcZero | fcSubnormal;
5763 
5764     KnownFPClass KnownLHS, KnownRHS;
5765     computeKnownFPClass(Op->getOperand(1), DemandedElts, NeedForNan, KnownRHS,
5766                         Depth + 1, Q);
5767     if (!KnownRHS.isKnownNeverNaN())
5768       break;
5769 
5770     computeKnownFPClass(Op->getOperand(0), DemandedElts, NeedForNan, KnownLHS,
5771                         Depth + 1, Q);
5772     if (!KnownLHS.isKnownNeverNaN())
5773       break;
5774 
5775     if (KnownLHS.SignBit && KnownRHS.SignBit) {
5776       if (*KnownLHS.SignBit == *KnownRHS.SignBit)
5777         Known.signBitMustBeZero();
5778       else
5779         Known.signBitMustBeOne();
5780     }
5781 
5782     // 0 * +/-inf produces NaN.
5783     if (KnownLHS.isKnownNeverInfinity() && KnownRHS.isKnownNeverInfinity()) {
5784       Known.knownNot(fcNan);
5785       break;
5786     }
5787 
5788     const Function *F = cast<Instruction>(Op)->getFunction();
5789     if (!F)
5790       break;
5791 
5792     if ((KnownRHS.isKnownNeverInfinity() ||
5793          KnownLHS.isKnownNeverLogicalZero(*F, Op->getType())) &&
5794         (KnownLHS.isKnownNeverInfinity() ||
5795          KnownRHS.isKnownNeverLogicalZero(*F, Op->getType())))
5796       Known.knownNot(fcNan);
5797 
5798     break;
5799   }
5800   case Instruction::FDiv:
5801   case Instruction::FRem: {
5802     if (Op->getOperand(0) == Op->getOperand(1)) {
5803       // TODO: Could filter out snan if we inspect the operand
5804       if (Op->getOpcode() == Instruction::FDiv) {
5805         // X / X is always exactly 1.0 or a NaN.
5806         Known.KnownFPClasses = fcNan | fcPosNormal;
5807       } else {
5808         // X % X is always exactly [+-]0.0 or a NaN.
5809         Known.KnownFPClasses = fcNan | fcZero;
5810       }
5811 
5812       break;
5813     }
5814 
5815     const bool WantNan = (InterestedClasses & fcNan) != fcNone;
5816     const bool WantNegative = (InterestedClasses & fcNegative) != fcNone;
5817     const bool WantPositive =
5818         Opc == Instruction::FRem && (InterestedClasses & fcPositive) != fcNone;
5819     if (!WantNan && !WantNegative && !WantPositive)
5820       break;
5821 
5822     KnownFPClass KnownLHS, KnownRHS;
5823 
5824     computeKnownFPClass(Op->getOperand(1), DemandedElts,
5825                         fcNan | fcInf | fcZero | fcNegative, KnownRHS,
5826                         Depth + 1, Q);
5827 
5828     bool KnowSomethingUseful =
5829         KnownRHS.isKnownNeverNaN() || KnownRHS.isKnownNever(fcNegative);
5830 
5831     if (KnowSomethingUseful || WantPositive) {
5832       const FPClassTest InterestedLHS =
5833           WantPositive ? fcAllFlags
5834                        : fcNan | fcInf | fcZero | fcSubnormal | fcNegative;
5835 
5836       computeKnownFPClass(Op->getOperand(0), DemandedElts,
5837                           InterestedClasses & InterestedLHS, KnownLHS,
5838                           Depth + 1, Q);
5839     }
5840 
5841     const Function *F = cast<Instruction>(Op)->getFunction();
5842 
5843     if (Op->getOpcode() == Instruction::FDiv) {
5844       // Only 0/0, Inf/Inf produce NaN.
5845       if (KnownLHS.isKnownNeverNaN() && KnownRHS.isKnownNeverNaN() &&
5846           (KnownLHS.isKnownNeverInfinity() ||
5847            KnownRHS.isKnownNeverInfinity()) &&
5848           ((F && KnownLHS.isKnownNeverLogicalZero(*F, Op->getType())) ||
5849            (F && KnownRHS.isKnownNeverLogicalZero(*F, Op->getType())))) {
5850         Known.knownNot(fcNan);
5851       }
5852 
5853       // X / -0.0 is -Inf (or NaN).
5854       // +X / +Y is never negative.
5855       if (KnownLHS.isKnownNever(fcNegative) && KnownRHS.isKnownNever(fcNegative))
5856         Known.knownNot(fcNegative);
5857     } else {
5858       // Inf REM x and x REM 0 produce NaN.
5859       if (KnownLHS.isKnownNeverNaN() && KnownRHS.isKnownNeverNaN() &&
5860           KnownLHS.isKnownNeverInfinity() && F &&
5861           KnownRHS.isKnownNeverLogicalZero(*F, Op->getType())) {
5862         Known.knownNot(fcNan);
5863       }
5864 
5865       // The sign for frem is the same as the first operand.
5866       if (KnownLHS.cannotBeOrderedLessThanZero())
5867         Known.knownNot(KnownFPClass::OrderedLessThanZeroMask);
5868       if (KnownLHS.cannotBeOrderedGreaterThanZero())
5869         Known.knownNot(KnownFPClass::OrderedGreaterThanZeroMask);
5870 
5871       // See if we can be more aggressive about the sign of 0.
5872       if (KnownLHS.isKnownNever(fcNegative))
5873         Known.knownNot(fcNegative);
5874       if (KnownLHS.isKnownNever(fcPositive))
5875         Known.knownNot(fcPositive);
5876     }
5877 
5878     break;
5879   }
5880   case Instruction::FPExt: {
5881     // Infinity, nan and zero propagate from source.
5882     computeKnownFPClass(Op->getOperand(0), DemandedElts, InterestedClasses,
5883                         Known, Depth + 1, Q);
5884 
5885     const fltSemantics &DstTy =
5886         Op->getType()->getScalarType()->getFltSemantics();
5887     const fltSemantics &SrcTy =
5888         Op->getOperand(0)->getType()->getScalarType()->getFltSemantics();
5889 
5890     // All subnormal inputs should be in the normal range in the result type.
5891     if (APFloat::isRepresentableAsNormalIn(SrcTy, DstTy)) {
5892       if (Known.KnownFPClasses & fcPosSubnormal)
5893         Known.KnownFPClasses |= fcPosNormal;
5894       if (Known.KnownFPClasses & fcNegSubnormal)
5895         Known.KnownFPClasses |= fcNegNormal;
5896       Known.knownNot(fcSubnormal);
5897     }
5898 
5899     // Sign bit of a nan isn't guaranteed.
5900     if (!Known.isKnownNeverNaN())
5901       Known.SignBit = std::nullopt;
5902     break;
5903   }
5904   case Instruction::FPTrunc: {
5905     computeKnownFPClassForFPTrunc(Op, DemandedElts, InterestedClasses, Known,
5906                                   Depth, Q);
5907     break;
5908   }
5909   case Instruction::SIToFP:
5910   case Instruction::UIToFP: {
5911     // Cannot produce nan
5912     Known.knownNot(fcNan);
5913 
5914     // Integers cannot be subnormal
5915     Known.knownNot(fcSubnormal);
5916 
5917     // sitofp and uitofp turn into +0.0 for zero.
5918     Known.knownNot(fcNegZero);
5919     if (Op->getOpcode() == Instruction::UIToFP)
5920       Known.signBitMustBeZero();
5921 
5922     if (InterestedClasses & fcInf) {
5923       // Get width of largest magnitude integer (remove a bit if signed).
5924       // This still works for a signed minimum value because the largest FP
5925       // value is scaled by some fraction close to 2.0 (1.0 + 0.xxxx).
5926       int IntSize = Op->getOperand(0)->getType()->getScalarSizeInBits();
5927       if (Op->getOpcode() == Instruction::SIToFP)
5928         --IntSize;
5929 
5930       // If the exponent of the largest finite FP value can hold the largest
5931       // integer, the result of the cast must be finite.
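      // For example, for sitofp i32 to float, IntSize is 31 and
      // ilogb(FLT_MAX) is 127 >= 31, so the result is never an infinity; a
      // cast from i128 to half (ilogb of the largest half is 15) does not
      // qualify.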
5932       Type *FPTy = Op->getType()->getScalarType();
5933       if (ilogb(APFloat::getLargest(FPTy->getFltSemantics())) >= IntSize)
5934         Known.knownNot(fcInf);
5935     }
5936 
5937     break;
5938   }
5939   case Instruction::ExtractElement: {
5940     // Look through extract element. If the index is non-constant or
5941     // out-of-range demand all elements, otherwise just the extracted element.
5942     const Value *Vec = Op->getOperand(0);
5943     const Value *Idx = Op->getOperand(1);
5944     auto *CIdx = dyn_cast<ConstantInt>(Idx);
5945 
5946     if (auto *VecTy = dyn_cast<FixedVectorType>(Vec->getType())) {
5947       unsigned NumElts = VecTy->getNumElements();
5948       APInt DemandedVecElts = APInt::getAllOnes(NumElts);
5949       if (CIdx && CIdx->getValue().ult(NumElts))
5950         DemandedVecElts = APInt::getOneBitSet(NumElts, CIdx->getZExtValue());
5951       return computeKnownFPClass(Vec, DemandedVecElts, InterestedClasses, Known,
5952                                  Depth + 1, Q);
5953     }
5954 
5955     break;
5956   }
5957   case Instruction::InsertElement: {
5958     if (isa<ScalableVectorType>(Op->getType()))
5959       return;
5960 
5961     const Value *Vec = Op->getOperand(0);
5962     const Value *Elt = Op->getOperand(1);
5963     auto *CIdx = dyn_cast<ConstantInt>(Op->getOperand(2));
5964     unsigned NumElts = DemandedElts.getBitWidth();
5965     APInt DemandedVecElts = DemandedElts;
5966     bool NeedsElt = true;
5967     // If we know the index being inserted into, clear it from the Vec check.
5968     if (CIdx && CIdx->getValue().ult(NumElts)) {
5969       DemandedVecElts.clearBit(CIdx->getZExtValue());
5970       NeedsElt = DemandedElts[CIdx->getZExtValue()];
5971     }
5972 
5973     // Do we demand the inserted element?
5974     if (NeedsElt) {
5975       computeKnownFPClass(Elt, Known, InterestedClasses, Depth + 1, Q);
5976       // If we don't know any bits, early out.
5977       if (Known.isUnknown())
5978         break;
5979     } else {
5980       Known.KnownFPClasses = fcNone;
5981     }
5982 
5983     // Do we need any more elements from Vec?
5984     if (!DemandedVecElts.isZero()) {
5985       KnownFPClass Known2;
5986       computeKnownFPClass(Vec, DemandedVecElts, InterestedClasses, Known2,
5987                           Depth + 1, Q);
5988       Known |= Known2;
5989     }
5990 
5991     break;
5992   }
5993   case Instruction::ShuffleVector: {
5994     // For undef elements, we don't know anything about the common state of
5995     // the shuffle result.
5996     APInt DemandedLHS, DemandedRHS;
5997     auto *Shuf = dyn_cast<ShuffleVectorInst>(Op);
5998     if (!Shuf || !getShuffleDemandedElts(Shuf, DemandedElts, DemandedLHS, DemandedRHS))
5999       return;
6000 
6001     if (!!DemandedLHS) {
6002       const Value *LHS = Shuf->getOperand(0);
6003       computeKnownFPClass(LHS, DemandedLHS, InterestedClasses, Known,
6004                           Depth + 1, Q);
6005 
6006       // If we don't know any bits, early out.
6007       if (Known.isUnknown())
6008         break;
6009     } else {
6010       Known.KnownFPClasses = fcNone;
6011     }
6012 
6013     if (!!DemandedRHS) {
6014       KnownFPClass Known2;
6015       const Value *RHS = Shuf->getOperand(1);
6016       computeKnownFPClass(RHS, DemandedRHS, InterestedClasses, Known2,
6017                           Depth + 1, Q);
6018       Known |= Known2;
6019     }
6020 
6021     break;
6022   }
6023   case Instruction::ExtractValue: {
6024     const ExtractValueInst *Extract = cast<ExtractValueInst>(Op);
6025     ArrayRef<unsigned> Indices = Extract->getIndices();
6026     const Value *Src = Extract->getAggregateOperand();
6027     if (isa<StructType>(Src->getType()) && Indices.size() == 1 &&
6028         Indices[0] == 0) {
6029       if (const auto *II = dyn_cast<IntrinsicInst>(Src)) {
6030         switch (II->getIntrinsicID()) {
6031         case Intrinsic::frexp: {
6032           Known.knownNot(fcSubnormal);
6033 
6034           KnownFPClass KnownSrc;
6035           computeKnownFPClass(II->getArgOperand(0), DemandedElts,
6036                               InterestedClasses, KnownSrc, Depth + 1, Q);
6037 
6038           const Function *F = cast<Instruction>(Op)->getFunction();
6039 
6040           if (KnownSrc.isKnownNever(fcNegative))
6041             Known.knownNot(fcNegative);
6042           else {
6043             if (F && KnownSrc.isKnownNeverLogicalNegZero(*F, Op->getType()))
6044               Known.knownNot(fcNegZero);
6045             if (KnownSrc.isKnownNever(fcNegInf))
6046               Known.knownNot(fcNegInf);
6047           }
6048 
6049           if (KnownSrc.isKnownNever(fcPositive))
6050             Known.knownNot(fcPositive);
6051           else {
6052             if (F && KnownSrc.isKnownNeverLogicalPosZero(*F, Op->getType()))
6053               Known.knownNot(fcPosZero);
6054             if (KnownSrc.isKnownNever(fcPosInf))
6055               Known.knownNot(fcPosInf);
6056           }
6057 
6058           Known.propagateNaN(KnownSrc);
6059           return;
6060         }
6061         default:
6062           break;
6063         }
6064       }
6065     }
6066 
6067     computeKnownFPClass(Src, DemandedElts, InterestedClasses, Known, Depth + 1,
6068                         Q);
6069     break;
6070   }
6071   case Instruction::PHI: {
6072     const PHINode *P = cast<PHINode>(Op);
6073     // Unreachable blocks may have zero-operand PHI nodes.
6074     if (P->getNumIncomingValues() == 0)
6075       break;
6076 
6077     // Otherwise take the unions of the known bit sets of the operands,
6078     // taking conservative care to avoid excessive recursion.
6079     const unsigned PhiRecursionLimit = MaxAnalysisRecursionDepth - 2;
6080 
6081     if (Depth < PhiRecursionLimit) {
6082       // Skip if every incoming value refers back to the PHI itself.
6083       if (isa_and_nonnull<UndefValue>(P->hasConstantValue()))
6084         break;
6085 
6086       bool First = true;
6087 
6088       for (const Use &U : P->operands()) {
6089         Value *IncValue;
6090         Instruction *CxtI;
6091         breakSelfRecursivePHI(&U, P, IncValue, CxtI);
6092         // Skip direct self references.
6093         if (IncValue == P)
6094           continue;
6095 
6096         KnownFPClass KnownSrc;
6097         // Recurse, but cap the recursion to two levels, because we don't want
6098         // to waste time spinning around in loops. We need at least depth 2 to
6099         // detect known sign bits.
6100         computeKnownFPClass(IncValue, DemandedElts, InterestedClasses, KnownSrc,
6101                             PhiRecursionLimit,
6102                             Q.getWithoutCondContext().getWithInstruction(CxtI));
6103 
6104         if (First) {
6105           Known = KnownSrc;
6106           First = false;
6107         } else {
6108           Known |= KnownSrc;
6109         }
6110 
6111         if (Known.KnownFPClasses == fcAllFlags)
6112           break;
6113       }
6114     }
6115 
6116     break;
6117   }
6118   case Instruction::BitCast: {
6119     const Value *Src;
6120     if (!match(Op, m_ElementWiseBitCast(m_Value(Src))) ||
6121         !Src->getType()->isIntOrIntVectorTy())
6122       break;
6123 
6124     const Type *Ty = Op->getType()->getScalarType();
6125     KnownBits Bits(Ty->getScalarSizeInBits());
6126     computeKnownBits(Src, DemandedElts, Bits, Depth + 1, Q);
6127 
6128     // Transfer information from the sign bit.
6129     if (Bits.isNonNegative())
6130       Known.signBitMustBeZero();
6131     else if (Bits.isNegative())
6132       Known.signBitMustBeOne();
6133 
6134     if (Ty->isIEEE()) {
6135       // IEEE floats are NaN when all bits of the exponent plus at least one of
6136       // the fraction bits are 1. This means:
6137       //   - If we assume unknown bits are 0 and the value is NaN, it will
6138       //     always be NaN
6139       //   - If we assume unknown bits are 1 and the value is not NaN, it can
6140       //     never be NaN
6141       if (APFloat(Ty->getFltSemantics(), Bits.One).isNaN())
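      // For example, for an f32 bitcast whose known-one bits already form
      // 0x7FC00000, the value must be a NaN; conversely, if setting every
      // unknown bit to one still does not form a NaN pattern, the value can
      // never be a NaN.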
6142         Known.KnownFPClasses = fcNan;
6143       else if (!APFloat(Ty->getFltSemantics(), ~Bits.Zero).isNaN())
6144         Known.knownNot(fcNan);
6145 
6146       // Build KnownBits representing Inf and check if it must be equal or
6147       // unequal to this value.
6148       auto InfKB = KnownBits::makeConstant(
6149           APFloat::getInf(Ty->getFltSemantics()).bitcastToAPInt());
6150       InfKB.Zero.clearSignBit();
6151       if (const auto InfResult = KnownBits::eq(Bits, InfKB)) {
6152         assert(!InfResult.value());
6153         Known.knownNot(fcInf);
6154       } else if (Bits == InfKB) {
6155         Known.KnownFPClasses = fcInf;
6156       }
6157 
6158       // Build KnownBits representing Zero and check if it must be equal or
6159       // unequal to this value.
6160       auto ZeroKB = KnownBits::makeConstant(
6161           APFloat::getZero(Ty->getFltSemantics()).bitcastToAPInt());
6162       ZeroKB.Zero.clearSignBit();
6163       if (const auto ZeroResult = KnownBits::eq(Bits, ZeroKB)) {
6164         assert(!ZeroResult.value());
6165         Known.knownNot(fcZero);
6166       } else if (Bits == ZeroKB) {
6167         Known.KnownFPClasses = fcZero;
6168       }
6169     }
6170 
6171     break;
6172   }
6173   default:
6174     break;
6175   }
6176 }
6177 
6178 KnownFPClass llvm::computeKnownFPClass(const Value *V,
6179                                        const APInt &DemandedElts,
6180                                        FPClassTest InterestedClasses,
6181                                        unsigned Depth,
6182                                        const SimplifyQuery &SQ) {
6183   KnownFPClass KnownClasses;
6184   ::computeKnownFPClass(V, DemandedElts, InterestedClasses, KnownClasses, Depth,
6185                         SQ);
6186   return KnownClasses;
6187 }
6188 
6189 KnownFPClass llvm::computeKnownFPClass(const Value *V,
6190                                        FPClassTest InterestedClasses,
6191                                        unsigned Depth,
6192                                        const SimplifyQuery &SQ) {
6193   KnownFPClass Known;
6194   ::computeKnownFPClass(V, Known, InterestedClasses, Depth, SQ);
6195   return Known;
6196 }
6197 
6198 Value *llvm::isBytewiseValue(Value *V, const DataLayout &DL) {
6199 
6200   // All byte-wide stores are splatable, even of arbitrary variables.
6201   if (V->getType()->isIntegerTy(8))
6202     return V;
6203 
6204   LLVMContext &Ctx = V->getContext();
6205 
6206   // Undef values don't care; any byte works.
6207   auto *UndefInt8 = UndefValue::get(Type::getInt8Ty(Ctx));
6208   if (isa<UndefValue>(V))
6209     return UndefInt8;
6210 
6211   // Return poison for zero-sized type.
6212   if (DL.getTypeStoreSize(V->getType()).isZero())
6213     return PoisonValue::get(Type::getInt8Ty(Ctx));
6214 
6215   Constant *C = dyn_cast<Constant>(V);
6216   if (!C) {
6217     // Conceptually, we could handle things like:
6218     //   %a = zext i8 %X to i16
6219     //   %b = shl i16 %a, 8
6220     //   %c = or i16 %a, %b
6221     // but until there is an example that actually needs this, it doesn't seem
6222     // worth worrying about.
6223     return nullptr;
6224   }
6225 
6226   // Handle 'null' constants such as ConstantAggregateZero.
6227   if (C->isNullValue())
6228     return Constant::getNullValue(Type::getInt8Ty(Ctx));
6229 
6230   // Constant floating-point values can be handled as integer values if the
6231   // corresponding integer value is "byteable".  An important case is 0.0.
6232   if (ConstantFP *CFP = dyn_cast<ConstantFP>(C)) {
6233     Type *Ty = nullptr;
6234     if (CFP->getType()->isHalfTy())
6235       Ty = Type::getInt16Ty(Ctx);
6236     else if (CFP->getType()->isFloatTy())
6237       Ty = Type::getInt32Ty(Ctx);
6238     else if (CFP->getType()->isDoubleTy())
6239       Ty = Type::getInt64Ty(Ctx);
6240     // Don't handle long double formats, which have strange constraints.
6241     return Ty ? isBytewiseValue(ConstantExpr::getBitCast(CFP, Ty), DL)
6242               : nullptr;
6243   }
6244 
6245   // We can handle constant integers whose width is a multiple of 8 bits.
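  // For example, i32 0xABABABAB splats to the byte 0xAB, while i32 0x12345678
  // does not splat and is rejected.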
6246   if (ConstantInt *CI = dyn_cast<ConstantInt>(C)) {
6247     if (CI->getBitWidth() % 8 == 0) {
6248       assert(CI->getBitWidth() > 8 && "8 bits should be handled above!");
6249       if (!CI->getValue().isSplat(8))
6250         return nullptr;
6251       return ConstantInt::get(Ctx, CI->getValue().trunc(8));
6252     }
6253   }
6254 
6255   if (auto *CE = dyn_cast<ConstantExpr>(C)) {
6256     if (CE->getOpcode() == Instruction::IntToPtr) {
6257       if (auto *PtrTy = dyn_cast<PointerType>(CE->getType())) {
6258         unsigned BitWidth = DL.getPointerSizeInBits(PtrTy->getAddressSpace());
6259         if (Constant *Op = ConstantFoldIntegerCast(
6260                 CE->getOperand(0), Type::getIntNTy(Ctx, BitWidth), false, DL))
6261           return isBytewiseValue(Op, DL);
6262       }
6263     }
6264   }
6265 
6266   auto Merge = [&](Value *LHS, Value *RHS) -> Value * {
6267     if (LHS == RHS)
6268       return LHS;
6269     if (!LHS || !RHS)
6270       return nullptr;
6271     if (LHS == UndefInt8)
6272       return RHS;
6273     if (RHS == UndefInt8)
6274       return LHS;
6275     return nullptr;
6276   };
6277 
6278   if (ConstantDataSequential *CA = dyn_cast<ConstantDataSequential>(C)) {
6279     Value *Val = UndefInt8;
6280     for (unsigned I = 0, E = CA->getNumElements(); I != E; ++I)
6281       if (!(Val = Merge(Val, isBytewiseValue(CA->getElementAsConstant(I), DL))))
6282         return nullptr;
6283     return Val;
6284   }
6285 
6286   if (isa<ConstantAggregate>(C)) {
6287     Value *Val = UndefInt8;
6288     for (Value *Op : C->operands())
6289       if (!(Val = Merge(Val, isBytewiseValue(Op, DL))))
6290         return nullptr;
6291     return Val;
6292   }
6293 
6294   // Don't try to handle the handful of other constants.
6295   return nullptr;
6296 }
6297 
6298 // This is the recursive version of BuildSubAggregate. It takes a few different
6299 // arguments. Idxs is the index within the nested struct From that we are
6300 // looking at now (which is of type IndexedType). IdxSkip is the number of
6301 // indices from Idxs that should be left out when inserting into the resulting
6302 // struct. To is the result struct built so far, which new insertvalue
6303 // instructions build on.
6304 static Value *BuildSubAggregate(Value *From, Value *To, Type *IndexedType,
6305                                 SmallVectorImpl<unsigned> &Idxs,
6306                                 unsigned IdxSkip,
6307                                 BasicBlock::iterator InsertBefore) {
6308   StructType *STy = dyn_cast<StructType>(IndexedType);
6309   if (STy) {
6310     // Save the original To argument so we can modify it
6311     Value *OrigTo = To;
6312     // General case, the type indexed by Idxs is a struct
6313     for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
6314       // Process each struct element recursively
6315       Idxs.push_back(i);
6316       Value *PrevTo = To;
6317       To = BuildSubAggregate(From, To, STy->getElementType(i), Idxs, IdxSkip,
6318                              InsertBefore);
6319       Idxs.pop_back();
6320       if (!To) {
6321         // Couldn't find any inserted value for this index? Cleanup
6322         while (PrevTo != OrigTo) {
6323           InsertValueInst* Del = cast<InsertValueInst>(PrevTo);
6324           PrevTo = Del->getAggregateOperand();
6325           Del->eraseFromParent();
6326         }
6327         // Stop processing elements
6328         break;
6329       }
6330     }
6331     // If we successfully found a value for each of our subaggregates
6332     if (To)
6333       return To;
6334   }
6335   // Base case, the type indexed by SourceIdxs is not a struct, or not all of
6336   // the struct's elements had a value that was inserted directly. In the latter
6337   // case, perhaps we can't determine each of the subelements individually, but
6338   // we might be able to find the complete struct somewhere.
6339 
6340   // Find the value that is at that particular spot
6341   Value *V = FindInsertedValue(From, Idxs);
6342 
6343   if (!V)
6344     return nullptr;
6345 
6346   // Insert the value in the new (sub) aggregate
6347   return InsertValueInst::Create(To, V, ArrayRef(Idxs).slice(IdxSkip), "tmp",
6348                                  InsertBefore);
6349 }
6350 
6351 // This helper takes a nested struct and extracts a part of it (which is again a
6352 // struct) into a new value. For example, given the struct:
6353 // { a, { b, { c, d }, e } }
6354 // and the indices "1, 1" this returns
6355 // { c, d }.
6356 //
6357 // It does this by inserting an insertvalue for each element in the resulting
6358 // struct, as opposed to just inserting a single struct. This will only work if
6359 // each of the elements of the substruct is known (i.e., inserted into From by an
6360 // insertvalue instruction somewhere).
6361 //
6362 // All inserted insertvalue instructions are inserted before InsertBefore
6363 static Value *BuildSubAggregate(Value *From, ArrayRef<unsigned> idx_range,
6364                                 BasicBlock::iterator InsertBefore) {
6365   Type *IndexedType = ExtractValueInst::getIndexedType(From->getType(),
6366                                                              idx_range);
6367   Value *To = PoisonValue::get(IndexedType);
6368   SmallVector<unsigned, 10> Idxs(idx_range);
6369   unsigned IdxSkip = Idxs.size();
6370 
6371   return BuildSubAggregate(From, To, IndexedType, Idxs, IdxSkip, InsertBefore);
6372 }
6373 
6374 /// Given an aggregate and a sequence of indices, see if the scalar value
6375 /// indexed is already around as a register, for example if it was inserted
6376 /// directly into the aggregate.
6377 ///
6378 /// If InsertBefore is not null, this function will duplicate (modified)
6379 /// insertvalues when a part of a nested struct is extracted.
6380 Value *
6381 llvm::FindInsertedValue(Value *V, ArrayRef<unsigned> idx_range,
6382                         std::optional<BasicBlock::iterator> InsertBefore) {
6383   // Nothing to index? Just return V then (this is useful at the end of our
6384   // recursion).
6385   if (idx_range.empty())
6386     return V;
6387   // We have indices, so V should have an indexable type.
6388   assert((V->getType()->isStructTy() || V->getType()->isArrayTy()) &&
6389          "Not looking at a struct or array?");
6390   assert(ExtractValueInst::getIndexedType(V->getType(), idx_range) &&
6391          "Invalid indices for type?");
6392 
6393   if (Constant *C = dyn_cast<Constant>(V)) {
6394     C = C->getAggregateElement(idx_range[0]);
6395     if (!C) return nullptr;
6396     return FindInsertedValue(C, idx_range.slice(1), InsertBefore);
6397   }
6398 
6399   if (InsertValueInst *I = dyn_cast<InsertValueInst>(V)) {
6400     // Loop over the indices of the insertvalue instruction in parallel with
6401     // the requested indices.
6402     const unsigned *req_idx = idx_range.begin();
6403     for (const unsigned *i = I->idx_begin(), *e = I->idx_end();
6404          i != e; ++i, ++req_idx) {
6405       if (req_idx == idx_range.end()) {
6406         // We can't handle this without inserting insertvalues
6407         if (!InsertBefore)
6408           return nullptr;
6409 
6410         // The requested index identifies a part of a nested aggregate. Handle
6411         // this specially. For example,
6412         // %A = insertvalue { i32, {i32, i32 } } undef, i32 10, 1, 0
6413         // %B = insertvalue { i32, {i32, i32 } } %A, i32 11, 1, 1
6414         // %C = extractvalue {i32, { i32, i32 } } %B, 1
6415         // This can be changed into
6416         // %A = insertvalue {i32, i32 } undef, i32 10, 0
6417         // %C = insertvalue {i32, i32 } %A, i32 11, 1
6418         // which allows the unused 0,0 element from the nested struct to be
6419         // removed.
6420         return BuildSubAggregate(V, ArrayRef(idx_range.begin(), req_idx),
6421                                  *InsertBefore);
6422       }
6423 
6424       // This insertvalue inserts something other than what we are looking for.
6425       // See if the (aggregate) value inserted into has the value we are
6426       // looking for, then.
6427       if (*req_idx != *i)
6428         return FindInsertedValue(I->getAggregateOperand(), idx_range,
6429                                  InsertBefore);
6430     }
6431     // If we end up here, the indices of the insertvalue match with those
6432     // requested (though possibly only partially). Now we recursively look at
6433     // the inserted value, passing any remaining indices.
6434     return FindInsertedValue(I->getInsertedValueOperand(),
6435                              ArrayRef(req_idx, idx_range.end()), InsertBefore);
6436   }
6437 
6438   if (ExtractValueInst *I = dyn_cast<ExtractValueInst>(V)) {
6439     // If we're extracting a value from an aggregate that was extracted from
6440     // something else, we can extract from that something else directly instead.
6441     // However, we will need to chain I's indices with the requested indices.
6442 
6443     // Calculate the number of indices required
6444     unsigned size = I->getNumIndices() + idx_range.size();
6445     // Allocate some space to put the new indices in
6446     SmallVector<unsigned, 5> Idxs;
6447     Idxs.reserve(size);
6448     // Add indices from the extract value instruction
6449     Idxs.append(I->idx_begin(), I->idx_end());
6450 
6451     // Add requested indices
6452     Idxs.append(idx_range.begin(), idx_range.end());
6453 
6454     assert(Idxs.size() == size
6455            && "Number of indices added not correct?");
6456 
6457     return FindInsertedValue(I->getAggregateOperand(), Idxs, InsertBefore);
6458   }
6459   // Otherwise, we don't know (e.g., extracting from a function return value
6460   // or a load instruction).
6461   return nullptr;
6462 }
6463 
6464 bool llvm::isGEPBasedOnPointerToString(const GEPOperator *GEP,
6465                                        unsigned CharSize) {
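  // A typical match (illustrative) for CharSize == 8 is:
  //   getelementptr inbounds [13 x i8], ptr @.str, i64 0, i64 %idx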
6466   // Make sure the GEP has exactly three arguments.
6467   if (GEP->getNumOperands() != 3)
6468     return false;
6469 
6470   // Make sure the indexed object is an array of \p CharSize-bit integers.
6472   ArrayType *AT = dyn_cast<ArrayType>(GEP->getSourceElementType());
6473   if (!AT || !AT->getElementType()->isIntegerTy(CharSize))
6474     return false;
6475 
6476   // Check to make sure that the first index of the GEP is an integer and
6477   // has value 0 so that we are sure we're indexing into the initializer.
6478   const ConstantInt *FirstIdx = dyn_cast<ConstantInt>(GEP->getOperand(1));
6479   if (!FirstIdx || !FirstIdx->isZero())
6480     return false;
6481 
6482   return true;
6483 }
6484 
6485 // If V refers to an initialized global constant, set Slice either to
6486 // its initializer if the size of its elements equals ElementSize, or,
6487 // for ElementSize == 8, to its representation as an array of unsigned
6488 // char. Return true on success.
6489 // Offset is measured in units of ElementSize-sized elements.
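// For example (illustrative IR), given
//   @s = private constant [6 x i8] c"hello\00"
// a query with ElementSize == 8 and Offset == 0 sets Slice.Array to the
// initializer, Slice.Offset to 0, and Slice.Length to 6.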
6490 bool llvm::getConstantDataArrayInfo(const Value *V,
6491                                     ConstantDataArraySlice &Slice,
6492                                     unsigned ElementSize, uint64_t Offset) {
6493   assert(V && "V should not be null.");
6494   assert((ElementSize % 8) == 0 &&
6495          "ElementSize expected to be a multiple of the size of a byte.");
6496   unsigned ElementSizeInBytes = ElementSize / 8;
6497 
6498   // Drill down into the pointer expression V, ignoring any intervening
6499   // casts, and determine the identity of the object it references along
6500   // with the cumulative byte offset into it.
6501   const GlobalVariable *GV =
6502     dyn_cast<GlobalVariable>(getUnderlyingObject(V));
6503   if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer())
6504     // Fail if V is not based on constant global object.
6505     return false;
6506 
6507   const DataLayout &DL = GV->getDataLayout();
6508   APInt Off(DL.getIndexTypeSizeInBits(V->getType()), 0);
6509 
6510   if (GV != V->stripAndAccumulateConstantOffsets(DL, Off,
6511                                                  /*AllowNonInbounds*/ true))
6512     // Fail if a constant offset could not be determined.
6513     return false;
6514 
6515   uint64_t StartIdx = Off.getLimitedValue();
6516   if (StartIdx == UINT64_MAX)
6517     // Fail if the constant offset is excessive.
6518     return false;
6519 
6520   // Off/StartIdx is in units of bytes, so we need to convert it to a number
6521   // of elements. Simply bail out if that isn't possible.
6522   if ((StartIdx % ElementSizeInBytes) != 0)
6523     return false;
6524 
6525   Offset += StartIdx / ElementSizeInBytes;
6526   ConstantDataArray *Array = nullptr;
6527   ArrayType *ArrayTy = nullptr;
6528 
6529   if (GV->getInitializer()->isNullValue()) {
6530     Type *GVTy = GV->getValueType();
6531     uint64_t SizeInBytes = DL.getTypeStoreSize(GVTy).getFixedValue();
6532     uint64_t Length = SizeInBytes / ElementSizeInBytes;
6533 
6534     Slice.Array = nullptr;
6535     Slice.Offset = 0;
6536     // Return an empty Slice for undersized constants to let callers
6537     // transform even undefined library calls into simpler, well-defined
6538     // expressions.  This is preferable to making the calls although it
6539     // prevents sanitizers from detecting such calls.
6540     Slice.Length = Length < Offset ? 0 : Length - Offset;
6541     return true;
6542   }
6543 
6544   auto *Init = const_cast<Constant *>(GV->getInitializer());
6545   if (auto *ArrayInit = dyn_cast<ConstantDataArray>(Init)) {
6546     Type *InitElTy = ArrayInit->getElementType();
6547     if (InitElTy->isIntegerTy(ElementSize)) {
6548       // If Init is an initializer for an array of the expected type
6549       // and size, use it as is.
6550       Array = ArrayInit;
6551       ArrayTy = ArrayInit->getType();
6552     }
6553   }
6554 
6555   if (!Array) {
6556     if (ElementSize != 8)
6557       // TODO: Handle conversions to larger integral types.
6558       return false;
6559 
6560     // Otherwise extract the portion of the initializer starting
6561     // at Offset as an array of bytes, and reset Offset.
6562     Init = ReadByteArrayFromGlobal(GV, Offset);
6563     if (!Init)
6564       return false;
6565 
6566     Offset = 0;
6567     Array = dyn_cast<ConstantDataArray>(Init);
6568     ArrayTy = dyn_cast<ArrayType>(Init->getType());
6569   }
6570 
6571   uint64_t NumElts = ArrayTy->getArrayNumElements();
6572   if (Offset > NumElts)
6573     return false;
6574 
6575   Slice.Array = Array;
6576   Slice.Offset = Offset;
6577   Slice.Length = NumElts - Offset;
6578   return true;
6579 }
6580 
6581 /// Extract bytes from the initializer of the constant array V, which need
6582 /// not be a nul-terminated string.  On success, store the bytes in Str and
6583 /// return true.  When TrimAtNul is set, Str will contain only the bytes up
6584 /// to but not including the first nul.  Return false on failure.
6585 bool llvm::getConstantStringInfo(const Value *V, StringRef &Str,
6586                                  bool TrimAtNul) {
6587   ConstantDataArraySlice Slice;
6588   if (!getConstantDataArrayInfo(V, Slice, 8))
6589     return false;
6590 
6591   if (Slice.Array == nullptr) {
6592     if (TrimAtNul) {
6593       // Return a nul-terminated string even for an empty Slice.  This is
6594       // safe because all existing SimplifyLibcalls callers require string
6595       // arguments and the behavior of the functions they fold is undefined
6596       // otherwise.  Folding the calls this way is preferable to making
6597       // the undefined library calls, even though it prevents sanitizers
6598       // from reporting such calls.
6599       Str = StringRef();
6600       return true;
6601     }
6602     if (Slice.Length == 1) {
6603       Str = StringRef("", 1);
6604       return true;
6605     }
6606     // We cannot instantiate a StringRef as we do not have an appropriate string
6607     // of 0s at hand.
6608     return false;
6609   }
6610 
6611   // Start out with the entire array in the StringRef.
6612   Str = Slice.Array->getAsString();
6613   // Skip over 'offset' bytes.
6614   Str = Str.substr(Slice.Offset);
6615 
6616   if (TrimAtNul) {
6617     // Trim off the \0 and anything after it.  If the array is not nul
6618     // terminated, we just return the whole end of string.  The client may know
6619     // some other way that the string is length-bound.
6620     Str = Str.substr(0, Str.find('\0'));
6621   }
6622   return true;
6623 }
6624 
6625 // These next two are very similar to the above, but also look through PHI
6626 // nodes.
6627 // TODO: See if we can integrate these two together.
6628 
6629 /// If we can compute the length of the string pointed to by
6630 /// the specified pointer, return 'len+1'.  If we can't, return 0.
6631 static uint64_t GetStringLengthH(const Value *V,
6632                                  SmallPtrSetImpl<const PHINode*> &PHIs,
6633                                  unsigned CharSize) {
6634   // Look through noop bitcast instructions.
6635   V = V->stripPointerCasts();
6636 
6637   // If this is a PHI node, there are two cases: either we have already seen it
6638   // or we haven't.
6639   if (const PHINode *PN = dyn_cast<PHINode>(V)) {
6640     if (!PHIs.insert(PN).second)
6641       return ~0ULL;  // already in the set.
6642 
6643     // If it was new, see if all the input strings are the same length.
6644     uint64_t LenSoFar = ~0ULL;
6645     for (Value *IncValue : PN->incoming_values()) {
6646       uint64_t Len = GetStringLengthH(IncValue, PHIs, CharSize);
6647       if (Len == 0) return 0; // Unknown length -> unknown.
6648 
6649       if (Len == ~0ULL) continue;
6650 
6651       if (Len != LenSoFar && LenSoFar != ~0ULL)
6652         return 0;    // Disagree -> unknown.
6653       LenSoFar = Len;
6654     }
6655 
6656     // Success, all agree.
6657     return LenSoFar;
6658   }
6659 
6660   // strlen(select(c,x,y)) -> strlen(x) ^ strlen(y)
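  // Illustrative example (hypothetical globals @a/@b): for
  //   %p = select i1 %c, ptr @a, ptr @b
  // where @a and @b are both two-character nul-terminated strings, both arms
  // yield 3 (len+1), so 3 is returned; if the arm lengths disagree we return 0.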
6661   if (const SelectInst *SI = dyn_cast<SelectInst>(V)) {
6662     uint64_t Len1 = GetStringLengthH(SI->getTrueValue(), PHIs, CharSize);
6663     if (Len1 == 0) return 0;
6664     uint64_t Len2 = GetStringLengthH(SI->getFalseValue(), PHIs, CharSize);
6665     if (Len2 == 0) return 0;
6666     if (Len1 == ~0ULL) return Len2;
6667     if (Len2 == ~0ULL) return Len1;
6668     if (Len1 != Len2) return 0;
6669     return Len1;
6670   }
6671 
6672   // Otherwise, see if we can read the string.
6673   ConstantDataArraySlice Slice;
6674   if (!getConstantDataArrayInfo(V, Slice, CharSize))
6675     return 0;
6676 
6677   if (Slice.Array == nullptr)
6678     // Zeroinitializer (including an empty one).
6679     return 1;
6680 
6681   // Search for the first nul character.  Return a conservative result even
6682   // when there is no nul.  This is safe since otherwise the string function
6683   // being folded, such as strlen, would have undefined behavior, and folding
6684   // is preferable to making the undefined library call.
6685   unsigned NullIndex = 0;
6686   for (unsigned E = Slice.Length; NullIndex < E; ++NullIndex) {
6687     if (Slice.Array->getElementAsInteger(Slice.Offset + NullIndex) == 0)
6688       break;
6689   }
6690 
6691   return NullIndex + 1;
6692 }
6693 
6694 /// If we can compute the length of the string pointed to by
6695 /// the specified pointer, return 'len+1'.  If we can't, return 0.
6696 uint64_t llvm::GetStringLength(const Value *V, unsigned CharSize) {
6697   if (!V->getType()->isPointerTy())
6698     return 0;
6699 
6700   SmallPtrSet<const PHINode*, 32> PHIs;
6701   uint64_t Len = GetStringLengthH(V, PHIs, CharSize);
6702   // If Len is ~0ULL, we had an infinite phi cycle: this is dead code, so return
6703   // the length of an empty string (i.e. 1).
6704   return Len == ~0ULL ? 1 : Len;
6705 }
6706 
6707 const Value *
6708 llvm::getArgumentAliasingToReturnedPointer(const CallBase *Call,
6709                                            bool MustPreserveNullness) {
6710   assert(Call &&
6711          "getArgumentAliasingToReturnedPointer only works on nonnull calls");
6712   if (const Value *RV = Call->getReturnedArgOperand())
6713     return RV;
6714   // This can be used only as an aliasing property.
6715   if (isIntrinsicReturningPointerAliasingArgumentWithoutCapturing(
6716           Call, MustPreserveNullness))
6717     return Call->getArgOperand(0);
6718   return nullptr;
6719 }
6720 
6721 bool llvm::isIntrinsicReturningPointerAliasingArgumentWithoutCapturing(
6722     const CallBase *Call, bool MustPreserveNullness) {
6723   switch (Call->getIntrinsicID()) {
6724   case Intrinsic::launder_invariant_group:
6725   case Intrinsic::strip_invariant_group:
6726   case Intrinsic::aarch64_irg:
6727   case Intrinsic::aarch64_tagp:
6728   // The amdgcn_make_buffer_rsrc intrinsic does not alter the address of the
6729   // input pointer (and thus preserves null-ness for the purposes of escape
6730   // analysis, which is where the MustPreserveNullness flag comes into play).
6731   // However, it will not necessarily map ptr addrspace(N) null to ptr
6732   // addrspace(8) null, aka the "null descriptor", which has "all loads return
6733   // 0, all stores are dropped" semantics. Given the context of this intrinsic
6734   // list, no one should be relying on such a strict interpretation of
6735   // MustPreserveNullness (and, at the time of writing, they are not), but we
6736   // document this fact out of an abundance of caution.
6737   case Intrinsic::amdgcn_make_buffer_rsrc:
6738     return true;
6739   case Intrinsic::ptrmask:
6740     return !MustPreserveNullness;
6741   case Intrinsic::threadlocal_address:
6742     // The underlying variable changes with thread ID. The Thread ID may change
6743     // at coroutine suspend points.
6744     return !Call->getParent()->getParent()->isPresplitCoroutine();
6745   default:
6746     return false;
6747   }
6748 }
6749 
6750 /// \p PN defines a loop-variant pointer to an object.  Check if the
6751 /// previous iteration of the loop was referring to the same object as \p PN.
6752 static bool isSameUnderlyingObjectInLoop(const PHINode *PN,
6753                                          const LoopInfo *LI) {
6754   // Find the loop-defined value.
6755   Loop *L = LI->getLoopFor(PN->getParent());
6756   if (PN->getNumIncomingValues() != 2)
6757     return true;
6758 
6759   // Find the value from previous iteration.
6760   auto *PrevValue = dyn_cast<Instruction>(PN->getIncomingValue(0));
6761   if (!PrevValue || LI->getLoopFor(PrevValue->getParent()) != L)
6762     PrevValue = dyn_cast<Instruction>(PN->getIncomingValue(1));
6763   if (!PrevValue || LI->getLoopFor(PrevValue->getParent()) != L)
6764     return true;
6765 
6766   // If a new pointer is loaded in the loop, the pointer references a different
6767   // object in every iteration.  E.g.:
6768   //    for (i)
6769   //       int *p = a[i];
6770   //       ...
6771   if (auto *Load = dyn_cast<LoadInst>(PrevValue))
6772     if (!L->isLoopInvariant(Load->getPointerOperand()))
6773       return false;
6774   return true;
6775 }
6776 
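// Illustrative example (hypothetical IR): starting from
//   %p = getelementptr inbounds [4 x i32], ptr %a, i64 0, i64 %i
// where %a is an alloca, the loop below strips the GEP and returns %a as the
// underlying object.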
6777 const Value *llvm::getUnderlyingObject(const Value *V, unsigned MaxLookup) {
6778   for (unsigned Count = 0; MaxLookup == 0 || Count < MaxLookup; ++Count) {
6779     if (auto *GEP = dyn_cast<GEPOperator>(V)) {
6780       const Value *PtrOp = GEP->getPointerOperand();
6781       if (!PtrOp->getType()->isPointerTy()) // Only handle scalar pointer base.
6782         return V;
6783       V = PtrOp;
6784     } else if (Operator::getOpcode(V) == Instruction::BitCast ||
6785                Operator::getOpcode(V) == Instruction::AddrSpaceCast) {
6786       Value *NewV = cast<Operator>(V)->getOperand(0);
6787       if (!NewV->getType()->isPointerTy())
6788         return V;
6789       V = NewV;
6790     } else if (auto *GA = dyn_cast<GlobalAlias>(V)) {
6791       if (GA->isInterposable())
6792         return V;
6793       V = GA->getAliasee();
6794     } else {
6795       if (auto *PHI = dyn_cast<PHINode>(V)) {
6796         // Look through single-arg phi nodes created by LCSSA.
6797         if (PHI->getNumIncomingValues() == 1) {
6798           V = PHI->getIncomingValue(0);
6799           continue;
6800         }
6801       } else if (auto *Call = dyn_cast<CallBase>(V)) {
6802         // CaptureTracking knows about special capturing properties of some
6803         // intrinsics, like launder.invariant.group, that can't be expressed
6804         // with attributes but that return a pointer aliasing their argument.
6805         // Because some analyses may assume that a nocapture pointer is never
6806         // returned from such an intrinsic (the function would otherwise have
6807         // to be marked with the returned attribute), it is crucial to use this
6808         // helper so that we stay in sync with CaptureTracking. Not using it
6809         // may cause subtle miscompilations where two aliasing pointers are
6810         // assumed not to alias.
6811         if (auto *RP = getArgumentAliasingToReturnedPointer(Call, false)) {
6812           V = RP;
6813           continue;
6814         }
6815       }
6816 
6817       return V;
6818     }
6819     assert(V->getType()->isPointerTy() && "Unexpected operand type!");
6820   }
6821   return V;
6822 }
6823 
6824 void llvm::getUnderlyingObjects(const Value *V,
6825                                 SmallVectorImpl<const Value *> &Objects,
6826                                 const LoopInfo *LI, unsigned MaxLookup) {
6827   SmallPtrSet<const Value *, 4> Visited;
6828   SmallVector<const Value *, 4> Worklist;
6829   Worklist.push_back(V);
6830   do {
6831     const Value *P = Worklist.pop_back_val();
6832     P = getUnderlyingObject(P, MaxLookup);
6833 
6834     if (!Visited.insert(P).second)
6835       continue;
6836 
6837     if (auto *SI = dyn_cast<SelectInst>(P)) {
6838       Worklist.push_back(SI->getTrueValue());
6839       Worklist.push_back(SI->getFalseValue());
6840       continue;
6841     }
6842 
6843     if (auto *PN = dyn_cast<PHINode>(P)) {
6844       // If this PHI changes the underlying object in every iteration of the
6845       // loop, don't look through it.  Consider:
6846       //   int **A;
6847       //   for (i) {
6848       //     Prev = Curr;     // Prev = PHI (Prev_0, Curr)
6849       //     Curr = A[i];
6850       //     *Prev, *Curr;
6851       //
6852       // Prev is tracking Curr one iteration behind so they refer to different
6853       // underlying objects.
6854       if (!LI || !LI->isLoopHeader(PN->getParent()) ||
6855           isSameUnderlyingObjectInLoop(PN, LI))
6856         append_range(Worklist, PN->incoming_values());
6857       else
6858         Objects.push_back(P);
6859       continue;
6860     }
6861 
6862     Objects.push_back(P);
6863   } while (!Worklist.empty());
6864 }
6865 
6866 const Value *llvm::getUnderlyingObjectAggressive(const Value *V) {
6867   const unsigned MaxVisited = 8;
6868 
6869   SmallPtrSet<const Value *, 8> Visited;
6870   SmallVector<const Value *, 8> Worklist;
6871   Worklist.push_back(V);
6872   const Value *Object = nullptr;
6873   // Used as fallback if we can't find a common underlying object through
6874   // recursion.
6875   bool First = true;
6876   const Value *FirstObject = getUnderlyingObject(V);
6877   do {
6878     const Value *P = Worklist.pop_back_val();
6879     P = First ? FirstObject : getUnderlyingObject(P);
6880     First = false;
6881 
6882     if (!Visited.insert(P).second)
6883       continue;
6884 
6885     if (Visited.size() == MaxVisited)
6886       return FirstObject;
6887 
6888     if (auto *SI = dyn_cast<SelectInst>(P)) {
6889       Worklist.push_back(SI->getTrueValue());
6890       Worklist.push_back(SI->getFalseValue());
6891       continue;
6892     }
6893 
6894     if (auto *PN = dyn_cast<PHINode>(P)) {
6895       append_range(Worklist, PN->incoming_values());
6896       continue;
6897     }
6898 
6899     if (!Object)
6900       Object = P;
6901     else if (Object != P)
6902       return FirstObject;
6903   } while (!Worklist.empty());
6904 
6905   return Object ? Object : FirstObject;
6906 }
6907 
6908 /// This is the function that does the work of looking through basic
6909 /// ptrtoint+arithmetic+inttoptr sequences.
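/// Illustrative sequence (hypothetical names): given the inttoptr operand %i2
/// in
///   %i  = ptrtoint ptr %obj to i64
///   %i2 = add i64 %i, 16
///   %p  = inttoptr i64 %i2 to ptr
/// we look through the add and the ptrtoint back to %obj.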
6910 static const Value *getUnderlyingObjectFromInt(const Value *V) {
6911   do {
6912     if (const Operator *U = dyn_cast<Operator>(V)) {
6913       // If we find a ptrtoint, we can transfer control back to the
6914       // regular getUnderlyingObjectFromInt.
6915       if (U->getOpcode() == Instruction::PtrToInt)
6916         return U->getOperand(0);
6917       // If we find an add of a constant, a multiplied value, or a phi, it's
6918       // likely that the other operand will lead us to the base
6919       // object. We don't have to worry about the case where the
6920       // object address is somehow being computed by the multiply,
6921       // because our callers only care when the result is an
6922       // identifiable object.
6923       if (U->getOpcode() != Instruction::Add ||
6924           (!isa<ConstantInt>(U->getOperand(1)) &&
6925            Operator::getOpcode(U->getOperand(1)) != Instruction::Mul &&
6926            !isa<PHINode>(U->getOperand(1))))
6927         return V;
6928       V = U->getOperand(0);
6929     } else {
6930       return V;
6931     }
6932     assert(V->getType()->isIntegerTy() && "Unexpected operand type!");
6933   } while (true);
6934 }
6935 
6936 /// This is a wrapper around getUnderlyingObjects and adds support for basic
6937 /// ptrtoint+arithmetic+inttoptr sequences.
6938 /// It returns false if an unidentified object is found by getUnderlyingObjects.
6939 bool llvm::getUnderlyingObjectsForCodeGen(const Value *V,
6940                                           SmallVectorImpl<Value *> &Objects) {
6941   SmallPtrSet<const Value *, 16> Visited;
6942   SmallVector<const Value *, 4> Working(1, V);
6943   do {
6944     V = Working.pop_back_val();
6945 
6946     SmallVector<const Value *, 4> Objs;
6947     getUnderlyingObjects(V, Objs);
6948 
6949     for (const Value *V : Objs) {
6950       if (!Visited.insert(V).second)
6951         continue;
6952       if (Operator::getOpcode(V) == Instruction::IntToPtr) {
6953         const Value *O =
6954           getUnderlyingObjectFromInt(cast<User>(V)->getOperand(0));
6955         if (O->getType()->isPointerTy()) {
6956           Working.push_back(O);
6957           continue;
6958         }
6959       }
6960       // If getUnderlyingObjects fails to find an identifiable object,
6961       // getUnderlyingObjectsForCodeGen also fails for safety.
6962       if (!isIdentifiedObject(V)) {
6963         Objects.clear();
6964         return false;
6965       }
6966       Objects.push_back(const_cast<Value *>(V));
6967     }
6968   } while (!Working.empty());
6969   return true;
6970 }
6971 
6972 AllocaInst *llvm::findAllocaForValue(Value *V, bool OffsetZero) {
6973   AllocaInst *Result = nullptr;
6974   SmallPtrSet<Value *, 4> Visited;
6975   SmallVector<Value *, 4> Worklist;
6976 
6977   auto AddWork = [&](Value *V) {
6978     if (Visited.insert(V).second)
6979       Worklist.push_back(V);
6980   };
6981 
6982   AddWork(V);
6983   do {
6984     V = Worklist.pop_back_val();
6985     assert(Visited.count(V));
6986 
6987     if (AllocaInst *AI = dyn_cast<AllocaInst>(V)) {
6988       if (Result && Result != AI)
6989         return nullptr;
6990       Result = AI;
6991     } else if (CastInst *CI = dyn_cast<CastInst>(V)) {
6992       AddWork(CI->getOperand(0));
6993     } else if (PHINode *PN = dyn_cast<PHINode>(V)) {
6994       for (Value *IncValue : PN->incoming_values())
6995         AddWork(IncValue);
6996     } else if (auto *SI = dyn_cast<SelectInst>(V)) {
6997       AddWork(SI->getTrueValue());
6998       AddWork(SI->getFalseValue());
6999     } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(V)) {
7000       if (OffsetZero && !GEP->hasAllZeroIndices())
7001         return nullptr;
7002       AddWork(GEP->getPointerOperand());
7003     } else if (CallBase *CB = dyn_cast<CallBase>(V)) {
7004       Value *Returned = CB->getReturnedArgOperand();
7005       if (Returned)
7006         AddWork(Returned);
7007       else
7008         return nullptr;
7009     } else {
7010       return nullptr;
7011     }
7012   } while (!Worklist.empty());
7013 
7014   return Result;
7015 }
7016 
7017 static bool onlyUsedByLifetimeMarkersOrDroppableInstsHelper(
7018     const Value *V, bool AllowLifetime, bool AllowDroppable) {
7019   for (const User *U : V->users()) {
7020     const IntrinsicInst *II = dyn_cast<IntrinsicInst>(U);
7021     if (!II)
7022       return false;
7023 
7024     if (AllowLifetime && II->isLifetimeStartOrEnd())
7025       continue;
7026 
7027     if (AllowDroppable && II->isDroppable())
7028       continue;
7029 
7030     return false;
7031   }
7032   return true;
7033 }
7034 
7035 bool llvm::onlyUsedByLifetimeMarkers(const Value *V) {
7036   return onlyUsedByLifetimeMarkersOrDroppableInstsHelper(
7037       V, /* AllowLifetime */ true, /* AllowDroppable */ false);
7038 }
7039 bool llvm::onlyUsedByLifetimeMarkersOrDroppableInsts(const Value *V) {
7040   return onlyUsedByLifetimeMarkersOrDroppableInstsHelper(
7041       V, /* AllowLifetime */ true, /* AllowDroppable */ true);
7042 }
7043 
7044 bool llvm::isNotCrossLaneOperation(const Instruction *I) {
7045   if (auto *II = dyn_cast<IntrinsicInst>(I))
7046     return isTriviallyVectorizable(II->getIntrinsicID());
7047   auto *Shuffle = dyn_cast<ShuffleVectorInst>(I);
7048   return (!Shuffle || Shuffle->isSelect()) &&
7049          !isa<CallBase, BitCastInst, ExtractElementInst>(I);
7050 }
7051 
7052 bool llvm::isSafeToSpeculativelyExecute(const Instruction *Inst,
7053                                         const Instruction *CtxI,
7054                                         AssumptionCache *AC,
7055                                         const DominatorTree *DT,
7056                                         const TargetLibraryInfo *TLI,
7057                                         bool UseVariableInfo) {
7058   return isSafeToSpeculativelyExecuteWithOpcode(Inst->getOpcode(), Inst, CtxI,
7059                                                 AC, DT, TLI, UseVariableInfo);
7060 }
7061 
7062 bool llvm::isSafeToSpeculativelyExecuteWithOpcode(
7063     unsigned Opcode, const Instruction *Inst, const Instruction *CtxI,
7064     AssumptionCache *AC, const DominatorTree *DT, const TargetLibraryInfo *TLI,
7065     bool UseVariableInfo) {
7066 #ifndef NDEBUG
7067   if (Inst->getOpcode() != Opcode) {
7068     // Check that the operands are actually compatible with the Opcode override.
7069     auto hasEqualReturnAndLeadingOperandTypes =
7070         [](const Instruction *Inst, unsigned NumLeadingOperands) {
7071           if (Inst->getNumOperands() < NumLeadingOperands)
7072             return false;
7073           const Type *ExpectedType = Inst->getType();
7074           for (unsigned ItOp = 0; ItOp < NumLeadingOperands; ++ItOp)
7075             if (Inst->getOperand(ItOp)->getType() != ExpectedType)
7076               return false;
7077           return true;
7078         };
7079     assert(!Instruction::isBinaryOp(Opcode) ||
7080            hasEqualReturnAndLeadingOperandTypes(Inst, 2));
7081     assert(!Instruction::isUnaryOp(Opcode) ||
7082            hasEqualReturnAndLeadingOperandTypes(Inst, 1));
7083   }
7084 #endif
7085 
7086   switch (Opcode) {
7087   default:
7088     return true;
7089   case Instruction::UDiv:
7090   case Instruction::URem: {
7091     // x / y is undefined if y == 0.
7092     const APInt *V;
7093     if (match(Inst->getOperand(1), m_APInt(V)))
7094       return *V != 0;
7095     return false;
7096   }
7097   case Instruction::SDiv:
7098   case Instruction::SRem: {
7099     // x / y is undefined if y == 0, or if x == INT_MIN and y == -1.
7100     const APInt *Numerator, *Denominator;
7101     if (!match(Inst->getOperand(1), m_APInt(Denominator)))
7102       return false;
7103     // We cannot hoist this division if the denominator is 0.
7104     if (*Denominator == 0)
7105       return false;
7106     // It's safe to hoist if the denominator is not 0 or -1.
7107     if (!Denominator->isAllOnes())
7108       return true;
7109     // At this point we know that the denominator is -1.  It is safe to hoist as
7110     // long we know that the numerator is not INT_MIN.
7111     if (match(Inst->getOperand(0), m_APInt(Numerator)))
7112       return !Numerator->isMinSignedValue();
7113     // The numerator *might* be MinSignedValue.
7114     return false;
7115   }
7116   case Instruction::Load: {
7117     if (!UseVariableInfo)
7118       return false;
7119 
7120     const LoadInst *LI = dyn_cast<LoadInst>(Inst);
7121     if (!LI)
7122       return false;
7123     if (mustSuppressSpeculation(*LI))
7124       return false;
7125     const DataLayout &DL = LI->getDataLayout();
7126     return isDereferenceableAndAlignedPointer(LI->getPointerOperand(),
7127                                               LI->getType(), LI->getAlign(), DL,
7128                                               CtxI, AC, DT, TLI);
7129   }
7130   case Instruction::Call: {
7131     auto *CI = dyn_cast<const CallInst>(Inst);
7132     if (!CI)
7133       return false;
7134     const Function *Callee = CI->getCalledFunction();
7135 
7136     // The called function could have undefined behavior or side-effects, even
7137     // if marked readnone nounwind.
7138     return Callee && Callee->isSpeculatable();
7139   }
7140   case Instruction::VAArg:
7141   case Instruction::Alloca:
7142   case Instruction::Invoke:
7143   case Instruction::CallBr:
7144   case Instruction::PHI:
7145   case Instruction::Store:
7146   case Instruction::Ret:
7147   case Instruction::Br:
7148   case Instruction::IndirectBr:
7149   case Instruction::Switch:
7150   case Instruction::Unreachable:
7151   case Instruction::Fence:
7152   case Instruction::AtomicRMW:
7153   case Instruction::AtomicCmpXchg:
7154   case Instruction::LandingPad:
7155   case Instruction::Resume:
7156   case Instruction::CatchSwitch:
7157   case Instruction::CatchPad:
7158   case Instruction::CatchRet:
7159   case Instruction::CleanupPad:
7160   case Instruction::CleanupRet:
7161     return false; // Misc instructions which have effects
7162   }
7163 }
7164 
7165 bool llvm::mayHaveNonDefUseDependency(const Instruction &I) {
7166   if (I.mayReadOrWriteMemory())
7167     // Memory dependency possible
7168     return true;
7169   if (!isSafeToSpeculativelyExecute(&I))
7170     // Can't move above a maythrow call or infinite loop.  Or if an
7171     // inalloca alloca, above a stacksave call.
7172     return true;
7173   if (!isGuaranteedToTransferExecutionToSuccessor(&I))
7174     // 1) Can't reorder two inf-loop calls, even if readonly
7175     // 2) Also can't reorder an inf-loop call below an instruction which isn't
7176     //    safe to speculatively execute.  (Inverse of above)
7177     return true;
7178   return false;
7179 }
7180 
7181 /// Convert ConstantRange OverflowResult into ValueTracking OverflowResult.
7182 static OverflowResult mapOverflowResult(ConstantRange::OverflowResult OR) {
7183   switch (OR) {
7184     case ConstantRange::OverflowResult::MayOverflow:
7185       return OverflowResult::MayOverflow;
7186     case ConstantRange::OverflowResult::AlwaysOverflowsLow:
7187       return OverflowResult::AlwaysOverflowsLow;
7188     case ConstantRange::OverflowResult::AlwaysOverflowsHigh:
7189       return OverflowResult::AlwaysOverflowsHigh;
7190     case ConstantRange::OverflowResult::NeverOverflows:
7191       return OverflowResult::NeverOverflows;
7192   }
7193   llvm_unreachable("Unknown OverflowResult");
7194 }
7195 
7196 /// Combine constant ranges from computeConstantRange() and computeKnownBits().
7197 ConstantRange
7198 llvm::computeConstantRangeIncludingKnownBits(const WithCache<const Value *> &V,
7199                                              bool ForSigned,
7200                                              const SimplifyQuery &SQ) {
7201   ConstantRange CR1 =
7202       ConstantRange::fromKnownBits(V.getKnownBits(SQ), ForSigned);
7203   ConstantRange CR2 = computeConstantRange(V, ForSigned, SQ.IIQ.UseInstrInfo);
7204   ConstantRange::PreferredRangeType RangeType =
7205       ForSigned ? ConstantRange::Signed : ConstantRange::Unsigned;
7206   return CR1.intersectWith(CR2, RangeType);
7207 }
7208 
7209 OverflowResult llvm::computeOverflowForUnsignedMul(const Value *LHS,
7210                                                    const Value *RHS,
7211                                                    const SimplifyQuery &SQ,
7212                                                    bool IsNSW) {
7213   KnownBits LHSKnown = computeKnownBits(LHS, /*Depth=*/0, SQ);
7214   KnownBits RHSKnown = computeKnownBits(RHS, /*Depth=*/0, SQ);
7215 
7216   // mul nsw of two non-negative numbers is also nuw.
7217   if (IsNSW && LHSKnown.isNonNegative() && RHSKnown.isNonNegative())
7218     return OverflowResult::NeverOverflows;
7219 
7220   ConstantRange LHSRange = ConstantRange::fromKnownBits(LHSKnown, false);
7221   ConstantRange RHSRange = ConstantRange::fromKnownBits(RHSKnown, false);
7222   return mapOverflowResult(LHSRange.unsignedMulMayOverflow(RHSRange));
7223 }
7224 
7225 OverflowResult llvm::computeOverflowForSignedMul(const Value *LHS,
7226                                                  const Value *RHS,
7227                                                  const SimplifyQuery &SQ) {
7228   // Multiplying n * m significant bits yields a result of n + m significant
7229   // bits. If the total number of significant bits does not exceed the
7230   // result bit width (minus 1), there is no overflow.
7231   // This means if we have enough leading sign bits in the operands
7232   // we can guarantee that the result does not overflow.
7233   // Ref: "Hacker's Delight" by Henry Warren
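  // Worked example (illustrative): for i32 operands each known to fit in 8
  // signed bits, each operand has at least 25 sign bits, so SignBits >= 50 >
  // BitWidth + 1 == 33 and the product provably cannot overflow.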
7234   unsigned BitWidth = LHS->getType()->getScalarSizeInBits();
7235 
7236   // Note that underestimating the number of sign bits gives a more
7237   // conservative answer.
7238   unsigned SignBits =
7239       ::ComputeNumSignBits(LHS, 0, SQ) + ::ComputeNumSignBits(RHS, 0, SQ);
7240 
7241   // First handle the easy case: if we have enough sign bits there's
7242   // definitely no overflow.
7243   if (SignBits > BitWidth + 1)
7244     return OverflowResult::NeverOverflows;
7245 
7246   // There are two ambiguous cases where there can be no overflow:
7247   //   SignBits == BitWidth + 1    and
7248   //   SignBits == BitWidth
7249   // The second case is difficult to check, therefore we only handle the
7250   // first case.
7251   if (SignBits == BitWidth + 1) {
7252     // It overflows only when both arguments are negative and the true
7253     // product is exactly the minimum negative number.
7254     // E.g. mul i16 with 17 sign bits: 0xff00 * 0xff80 = 0x8000
7255     // For simplicity we just check if at least one side is not negative.
7256     KnownBits LHSKnown = computeKnownBits(LHS, /*Depth=*/0, SQ);
7257     KnownBits RHSKnown = computeKnownBits(RHS, /*Depth=*/0, SQ);
7258     if (LHSKnown.isNonNegative() || RHSKnown.isNonNegative())
7259       return OverflowResult::NeverOverflows;
7260   }
7261   return OverflowResult::MayOverflow;
7262 }
7263 
7264 OverflowResult
7265 llvm::computeOverflowForUnsignedAdd(const WithCache<const Value *> &LHS,
7266                                     const WithCache<const Value *> &RHS,
7267                                     const SimplifyQuery &SQ) {
7268   ConstantRange LHSRange =
7269       computeConstantRangeIncludingKnownBits(LHS, /*ForSigned=*/false, SQ);
7270   ConstantRange RHSRange =
7271       computeConstantRangeIncludingKnownBits(RHS, /*ForSigned=*/false, SQ);
7272   return mapOverflowResult(LHSRange.unsignedAddMayOverflow(RHSRange));
7273 }
7274 
7275 static OverflowResult
7276 computeOverflowForSignedAdd(const WithCache<const Value *> &LHS,
7277                             const WithCache<const Value *> &RHS,
7278                             const AddOperator *Add, const SimplifyQuery &SQ) {
7279   if (Add && Add->hasNoSignedWrap()) {
7280     return OverflowResult::NeverOverflows;
7281   }
7282 
7283   // If LHS and RHS each have at least two sign bits, the addition will look
7284   // like
7285   //
7286   // XX..... +
7287   // YY.....
7288   //
7289   // If the carry into the most significant position is 0, X and Y can't both
7290   // be 1 and therefore the carry out of the addition is also 0.
7291   //
7292   // If the carry into the most significant position is 1, X and Y can't both
7293   // be 0 and therefore the carry out of the addition is also 1.
7294   //
7295   // Since the carry into the most significant position is always equal to
7296   // the carry out of the addition, there is no signed overflow.
7297   if (::ComputeNumSignBits(LHS, 0, SQ) > 1 &&
7298       ::ComputeNumSignBits(RHS, 0, SQ) > 1)
7299     return OverflowResult::NeverOverflows;
7300 
7301   ConstantRange LHSRange =
7302       computeConstantRangeIncludingKnownBits(LHS, /*ForSigned=*/true, SQ);
7303   ConstantRange RHSRange =
7304       computeConstantRangeIncludingKnownBits(RHS, /*ForSigned=*/true, SQ);
7305   OverflowResult OR =
7306       mapOverflowResult(LHSRange.signedAddMayOverflow(RHSRange));
7307   if (OR != OverflowResult::MayOverflow)
7308     return OR;
7309 
7310   // The remaining code needs Add to be available. Return early if it is not.
7311   if (!Add)
7312     return OverflowResult::MayOverflow;
7313 
7314   // If the sign of Add is the same as at least one of the operands, this add
7315   // CANNOT overflow. If this can be determined from the known bits of the
7316   // operands the above signedAddMayOverflow() check will have already done so.
7317   // The only other way to improve on the known bits is from an assumption, so
7318   // call computeKnownBitsFromContext() directly.
7319   bool LHSOrRHSKnownNonNegative =
7320       (LHSRange.isAllNonNegative() || RHSRange.isAllNonNegative());
7321   bool LHSOrRHSKnownNegative =
7322       (LHSRange.isAllNegative() || RHSRange.isAllNegative());
7323   if (LHSOrRHSKnownNonNegative || LHSOrRHSKnownNegative) {
7324     KnownBits AddKnown(LHSRange.getBitWidth());
7325     computeKnownBitsFromContext(Add, AddKnown, /*Depth=*/0, SQ);
7326     if ((AddKnown.isNonNegative() && LHSOrRHSKnownNonNegative) ||
7327         (AddKnown.isNegative() && LHSOrRHSKnownNegative))
7328       return OverflowResult::NeverOverflows;
7329   }
7330 
7331   return OverflowResult::MayOverflow;
7332 }
7333 
7334 OverflowResult llvm::computeOverflowForUnsignedSub(const Value *LHS,
7335                                                    const Value *RHS,
7336                                                    const SimplifyQuery &SQ) {
7337   // X - (X % ?)
7338   // The remainder of a value can't have greater magnitude than itself,
7339   // so the subtraction can't overflow.
7340 
7341   // X - (X -nuw ?)
7342   // In the minimal case, this would simplify to "?", so there's no subtract
7343   // at all. But if this analysis is used to peek through casts, for example,
7344   // then determining no-overflow may allow other transforms.
7345 
7346   // TODO: There are other patterns like this.
7347   //       See simplifyICmpWithBinOpOnLHS() for candidates.
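  // Illustrative IR (hypothetical values): with
  //   %r = urem i64 %x, %m
  //   %d = sub i64 %x, %r
  // the sub cannot wrap unsigned, provided %x is not undef.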
7348   if (match(RHS, m_URem(m_Specific(LHS), m_Value())) ||
7349       match(RHS, m_NUWSub(m_Specific(LHS), m_Value())))
7350     if (isGuaranteedNotToBeUndef(LHS, SQ.AC, SQ.CxtI, SQ.DT))
7351       return OverflowResult::NeverOverflows;
7352 
7353   if (auto C = isImpliedByDomCondition(CmpInst::ICMP_UGE, LHS, RHS, SQ.CxtI,
7354                                        SQ.DL)) {
7355     if (*C)
7356       return OverflowResult::NeverOverflows;
7357     return OverflowResult::AlwaysOverflowsLow;
7358   }
7359 
7360   ConstantRange LHSRange =
7361       computeConstantRangeIncludingKnownBits(LHS, /*ForSigned=*/false, SQ);
7362   ConstantRange RHSRange =
7363       computeConstantRangeIncludingKnownBits(RHS, /*ForSigned=*/false, SQ);
7364   return mapOverflowResult(LHSRange.unsignedSubMayOverflow(RHSRange));
7365 }
7366 
7367 OverflowResult llvm::computeOverflowForSignedSub(const Value *LHS,
7368                                                  const Value *RHS,
7369                                                  const SimplifyQuery &SQ) {
7370   // X - (X % ?)
7371   // The remainder of a value can't have greater magnitude than itself,
7372   // so the subtraction can't overflow.
7373 
7374   // X - (X -nsw ?)
7375   // In the minimal case, this would simplify to "?", so there's no subtract
7376   // at all. But if this analysis is used to peek through casts, for example,
7377   // then determining no-overflow may allow other transforms.
7378   if (match(RHS, m_SRem(m_Specific(LHS), m_Value())) ||
7379       match(RHS, m_NSWSub(m_Specific(LHS), m_Value())))
7380     if (isGuaranteedNotToBeUndef(LHS, SQ.AC, SQ.CxtI, SQ.DT))
7381       return OverflowResult::NeverOverflows;
7382 
7383   // If LHS and RHS each have at least two sign bits, the subtraction
7384   // cannot overflow.
7385   if (::ComputeNumSignBits(LHS, 0, SQ) > 1 &&
7386       ::ComputeNumSignBits(RHS, 0, SQ) > 1)
7387     return OverflowResult::NeverOverflows;
7388 
7389   ConstantRange LHSRange =
7390       computeConstantRangeIncludingKnownBits(LHS, /*ForSigned=*/true, SQ);
7391   ConstantRange RHSRange =
7392       computeConstantRangeIncludingKnownBits(RHS, /*ForSigned=*/true, SQ);
7393   return mapOverflowResult(LHSRange.signedSubMayOverflow(RHSRange));
7394 }
7395 
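// Illustrative IR (hypothetical values %a/%b), showing a guarding branch whose
// false ("no-wrap") successor dominates every use of the math result:
//   %s  = call { i32, i1 } @llvm.sadd.with.overflow.i32(i32 %a, i32 %b)
//   %ov = extractvalue { i32, i1 } %s, 1
//   br i1 %ov, label %trap, label %cont
// cont:
//   %v  = extractvalue { i32, i1 } %s, 0   ; provably did not wrap here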
7396 bool llvm::isOverflowIntrinsicNoWrap(const WithOverflowInst *WO,
7397                                      const DominatorTree &DT) {
7398   SmallVector<const BranchInst *, 2> GuardingBranches;
7399   SmallVector<const ExtractValueInst *, 2> Results;
7400 
7401   for (const User *U : WO->users()) {
7402     if (const auto *EVI = dyn_cast<ExtractValueInst>(U)) {
7403       assert(EVI->getNumIndices() == 1 && "Obvious from CI's type");
7404 
7405       if (EVI->getIndices()[0] == 0)
7406         Results.push_back(EVI);
7407       else {
7408         assert(EVI->getIndices()[0] == 1 && "Obvious from CI's type");
7409 
7410         for (const auto *U : EVI->users())
7411           if (const auto *B = dyn_cast<BranchInst>(U)) {
7412             assert(B->isConditional() && "How else is it using an i1?");
7413             GuardingBranches.push_back(B);
7414           }
7415       }
7416     } else {
7417       // We are using the aggregate directly in a way we don't want to analyze
7418       // here (storing it to a global, say).
7419       return false;
7420     }
7421   }
7422 
7423   auto AllUsesGuardedByBranch = [&](const BranchInst *BI) {
7424     BasicBlockEdge NoWrapEdge(BI->getParent(), BI->getSuccessor(1));
7425     if (!NoWrapEdge.isSingleEdge())
7426       return false;
7427 
7428     // Check if all users of the add are provably no-wrap.
7429     for (const auto *Result : Results) {
7430       // If the extractvalue itself is not executed on overflow, then we don't
7431       // need to check each use separately, since domination is transitive.
7432       if (DT.dominates(NoWrapEdge, Result->getParent()))
7433         continue;
7434 
7435       for (const auto &RU : Result->uses())
7436         if (!DT.dominates(NoWrapEdge, RU))
7437           return false;
7438     }
7439 
7440     return true;
7441   };
7442 
7443   return llvm::any_of(GuardingBranches, AllUsesGuardedByBranch);
7444 }
7445 
7446 /// Shifts return poison if shiftwidth is larger than the bitwidth.
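/// E.g. (illustrative): for an i8 shift, an amount of 7 is in range but 8 is
/// not; for <2 x i8> <i8 3, i8 9> we return false because one lane is out of
/// range.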
7447 static bool shiftAmountKnownInRange(const Value *ShiftAmount) {
7448   auto *C = dyn_cast<Constant>(ShiftAmount);
7449   if (!C)
7450     return false;
7451 
7452   // Shifts return poison if shiftwidth is larger than the bitwidth.
7453   SmallVector<const Constant *, 4> ShiftAmounts;
7454   if (auto *FVTy = dyn_cast<FixedVectorType>(C->getType())) {
7455     unsigned NumElts = FVTy->getNumElements();
7456     for (unsigned i = 0; i < NumElts; ++i)
7457       ShiftAmounts.push_back(C->getAggregateElement(i));
7458   } else if (isa<ScalableVectorType>(C->getType()))
7459     return false; // Can't tell, just return false to be safe
7460   else
7461     ShiftAmounts.push_back(C);
7462 
7463   bool Safe = llvm::all_of(ShiftAmounts, [](const Constant *C) {
7464     auto *CI = dyn_cast_or_null<ConstantInt>(C);
7465     return CI && CI->getValue().ult(C->getType()->getIntegerBitWidth());
7466   });
7467 
7468   return Safe;
7469 }
7470 
7471 enum class UndefPoisonKind {
7472   PoisonOnly = (1 << 0),
7473   UndefOnly = (1 << 1),
7474   UndefOrPoison = PoisonOnly | UndefOnly,
7475 };
7476 
7477 static bool includesPoison(UndefPoisonKind Kind) {
7478   return (unsigned(Kind) & unsigned(UndefPoisonKind::PoisonOnly)) != 0;
7479 }
7480 
7481 static bool includesUndef(UndefPoisonKind Kind) {
7482   return (unsigned(Kind) & unsigned(UndefPoisonKind::UndefOnly)) != 0;
7483 }
7484 
7485 static bool canCreateUndefOrPoison(const Operator *Op, UndefPoisonKind Kind,
7486                                    bool ConsiderFlagsAndMetadata) {
7487 
7488   if (ConsiderFlagsAndMetadata && includesPoison(Kind) &&
7489       Op->hasPoisonGeneratingAnnotations())
7490     return true;
7491 
7492   unsigned Opcode = Op->getOpcode();
7493 
7494   // Check whether opcode is a poison/undef-generating operation
7495   switch (Opcode) {
7496   case Instruction::Shl:
7497   case Instruction::AShr:
7498   case Instruction::LShr:
7499     return includesPoison(Kind) && !shiftAmountKnownInRange(Op->getOperand(1));
7500   case Instruction::FPToSI:
7501   case Instruction::FPToUI:
7502     // fptosi/ui yields poison if the resulting value does not fit in the
7503     // destination type.
7504     return true;
7505   case Instruction::Call:
7506     if (auto *II = dyn_cast<IntrinsicInst>(Op)) {
7507       switch (II->getIntrinsicID()) {
7508       // TODO: Add more intrinsics.
7509       case Intrinsic::ctlz:
7510       case Intrinsic::cttz:
7511       case Intrinsic::abs:
7512         if (cast<ConstantInt>(II->getArgOperand(1))->isNullValue())
7513           return false;
7514         break;
7515       case Intrinsic::ctpop:
7516       case Intrinsic::bswap:
7517       case Intrinsic::bitreverse:
7518       case Intrinsic::fshl:
7519       case Intrinsic::fshr:
7520       case Intrinsic::smax:
7521       case Intrinsic::smin:
7522       case Intrinsic::umax:
7523       case Intrinsic::umin:
7524       case Intrinsic::ptrmask:
7525       case Intrinsic::fptoui_sat:
7526       case Intrinsic::fptosi_sat:
7527       case Intrinsic::sadd_with_overflow:
7528       case Intrinsic::ssub_with_overflow:
7529       case Intrinsic::smul_with_overflow:
7530       case Intrinsic::uadd_with_overflow:
7531       case Intrinsic::usub_with_overflow:
7532       case Intrinsic::umul_with_overflow:
7533       case Intrinsic::sadd_sat:
7534       case Intrinsic::uadd_sat:
7535       case Intrinsic::ssub_sat:
7536       case Intrinsic::usub_sat:
7537         return false;
7538       case Intrinsic::sshl_sat:
7539       case Intrinsic::ushl_sat:
7540         return includesPoison(Kind) &&
7541                !shiftAmountKnownInRange(II->getArgOperand(1));
7542       case Intrinsic::fma:
7543       case Intrinsic::fmuladd:
7544       case Intrinsic::sqrt:
7545       case Intrinsic::powi:
7546       case Intrinsic::sin:
7547       case Intrinsic::cos:
7548       case Intrinsic::pow:
7549       case Intrinsic::log:
7550       case Intrinsic::log10:
7551       case Intrinsic::log2:
7552       case Intrinsic::exp:
7553       case Intrinsic::exp2:
7554       case Intrinsic::exp10:
7555       case Intrinsic::fabs:
7556       case Intrinsic::copysign:
7557       case Intrinsic::floor:
7558       case Intrinsic::ceil:
7559       case Intrinsic::trunc:
7560       case Intrinsic::rint:
7561       case Intrinsic::nearbyint:
7562       case Intrinsic::round:
7563       case Intrinsic::roundeven:
7564       case Intrinsic::fptrunc_round:
7565       case Intrinsic::canonicalize:
7566       case Intrinsic::arithmetic_fence:
7567       case Intrinsic::minnum:
7568       case Intrinsic::maxnum:
7569       case Intrinsic::minimum:
7570       case Intrinsic::maximum:
7571       case Intrinsic::is_fpclass:
7572       case Intrinsic::ldexp:
7573       case Intrinsic::frexp:
7574         return false;
7575       case Intrinsic::lround:
7576       case Intrinsic::llround:
7577       case Intrinsic::lrint:
7578       case Intrinsic::llrint:
7579         // If the value doesn't fit, an unspecified value is returned (but this
7580         // is not poison).
7581         return false;
7582       }
7583     }
7584     [[fallthrough]];
7585   case Instruction::CallBr:
7586   case Instruction::Invoke: {
7587     const auto *CB = cast<CallBase>(Op);
7588     return !CB->hasRetAttr(Attribute::NoUndef);
7589   }
7590   case Instruction::InsertElement:
7591   case Instruction::ExtractElement: {
7592     // If index exceeds the length of the vector, it returns poison
7593     auto *VTy = cast<VectorType>(Op->getOperand(0)->getType());
7594     unsigned IdxOp = Op->getOpcode() == Instruction::InsertElement ? 2 : 1;
7595     auto *Idx = dyn_cast<ConstantInt>(Op->getOperand(IdxOp));
7596     if (includesPoison(Kind))
7597       return !Idx ||
7598              Idx->getValue().uge(VTy->getElementCount().getKnownMinValue());
7599     return false;
7600   }
7601   case Instruction::ShuffleVector: {
7602     ArrayRef<int> Mask = isa<ConstantExpr>(Op)
7603                              ? cast<ConstantExpr>(Op)->getShuffleMask()
7604                              : cast<ShuffleVectorInst>(Op)->getShuffleMask();
7605     return includesPoison(Kind) && is_contained(Mask, PoisonMaskElem);
7606   }
7607   case Instruction::FNeg:
7608   case Instruction::PHI:
7609   case Instruction::Select:
7610   case Instruction::URem:
7611   case Instruction::SRem:
7612   case Instruction::ExtractValue:
7613   case Instruction::InsertValue:
7614   case Instruction::Freeze:
7615   case Instruction::ICmp:
7616   case Instruction::FCmp:
7617   case Instruction::FAdd:
7618   case Instruction::FSub:
7619   case Instruction::FMul:
7620   case Instruction::FDiv:
7621   case Instruction::FRem:
7622     return false;
7623   case Instruction::GetElementPtr:
7624     // inbounds is handled above
7625     // TODO: what about inrange on constexpr?
7626     return false;
7627   default: {
7628     const auto *CE = dyn_cast<ConstantExpr>(Op);
7629     if (isa<CastInst>(Op) || (CE && CE->isCast()))
7630       return false;
7631     else if (Instruction::isBinaryOp(Opcode))
7632       return false;
7633     // Be conservative and return true.
7634     return true;
7635   }
7636   }
7637 }
7638 
7639 bool llvm::canCreateUndefOrPoison(const Operator *Op,
7640                                   bool ConsiderFlagsAndMetadata) {
7641   return ::canCreateUndefOrPoison(Op, UndefPoisonKind::UndefOrPoison,
7642                                   ConsiderFlagsAndMetadata);
7643 }
7644 
7645 bool llvm::canCreatePoison(const Operator *Op, bool ConsiderFlagsAndMetadata) {
7646   return ::canCreateUndefOrPoison(Op, UndefPoisonKind::PoisonOnly,
7647                                   ConsiderFlagsAndMetadata);
7648 }
7649 
7650 static bool directlyImpliesPoison(const Value *ValAssumedPoison, const Value *V,
7651                                   unsigned Depth) {
7652   if (ValAssumedPoison == V)
7653     return true;
7654 
7655   const unsigned MaxDepth = 2;
7656   if (Depth >= MaxDepth)
7657     return false;
7658 
7659   if (const auto *I = dyn_cast<Instruction>(V)) {
7660     if (any_of(I->operands(), [=](const Use &Op) {
7661           return propagatesPoison(Op) &&
7662                  directlyImpliesPoison(ValAssumedPoison, Op, Depth + 1);
7663         }))
7664       return true;
7665 
7666     // V  = extractvalue V0, idx
7667     // V2 = extractvalue V0, idx2
7668     // V0's elements are all poison or not. (e.g., add_with_overflow)
7669     const WithOverflowInst *II;
7670     if (match(I, m_ExtractValue(m_WithOverflowInst(II))) &&
7671         (match(ValAssumedPoison, m_ExtractValue(m_Specific(II))) ||
7672          llvm::is_contained(II->args(), ValAssumedPoison)))
7673       return true;
7674   }
7675   return false;
7676 }
7677 
7678 static bool impliesPoison(const Value *ValAssumedPoison, const Value *V,
7679                           unsigned Depth) {
7680   if (isGuaranteedNotToBePoison(ValAssumedPoison))
7681     return true;
7682 
7683   if (directlyImpliesPoison(ValAssumedPoison, V, /* Depth */ 0))
7684     return true;
7685 
7686   const unsigned MaxDepth = 2;
7687   if (Depth >= MaxDepth)
7688     return false;
7689 
7690   const auto *I = dyn_cast<Instruction>(ValAssumedPoison);
7691   if (I && !canCreatePoison(cast<Operator>(I))) {
7692     return all_of(I->operands(), [=](const Value *Op) {
7693       return impliesPoison(Op, V, Depth + 1);
7694     });
7695   }
7696   return false;
7697 }
7698 
7699 bool llvm::impliesPoison(const Value *ValAssumedPoison, const Value *V) {
7700   return ::impliesPoison(ValAssumedPoison, V, /* Depth */ 0);
7701 }
7702 
7703 static bool programUndefinedIfUndefOrPoison(const Value *V, bool PoisonOnly);
7704 
7705 static bool isGuaranteedNotToBeUndefOrPoison(
7706     const Value *V, AssumptionCache *AC, const Instruction *CtxI,
7707     const DominatorTree *DT, unsigned Depth, UndefPoisonKind Kind) {
7708   if (Depth >= MaxAnalysisRecursionDepth)
7709     return false;
7710 
7711   if (isa<MetadataAsValue>(V))
7712     return false;
7713 
7714   if (const auto *A = dyn_cast<Argument>(V)) {
7715     if (A->hasAttribute(Attribute::NoUndef) ||
7716         A->hasAttribute(Attribute::Dereferenceable) ||
7717         A->hasAttribute(Attribute::DereferenceableOrNull))
7718       return true;
7719   }
7720 
7721   if (auto *C = dyn_cast<Constant>(V)) {
7722     if (isa<PoisonValue>(C))
7723       return !includesPoison(Kind);
7724 
7725     if (isa<UndefValue>(C))
7726       return !includesUndef(Kind);
7727 
7728     if (isa<ConstantInt>(C) || isa<GlobalVariable>(C) || isa<ConstantFP>(V) ||
7729         isa<ConstantPointerNull>(C) || isa<Function>(C))
7730       return true;
7731 
7732     if (C->getType()->isVectorTy() && !isa<ConstantExpr>(C)) {
7733       if (includesUndef(Kind) && C->containsUndefElement())
7734         return false;
7735       if (includesPoison(Kind) && C->containsPoisonElement())
7736         return false;
7737       return !C->containsConstantExpression();
7738     }
7739   }
7740 
7741   // Strip cast operations from a pointer value.
7742   // Note that stripPointerCastsSameRepresentation can strip off getelementptr
7743   // inbounds with zero offset. To guarantee that the result isn't poison, the
7744   // stripped pointer is checked: it has to point into an allocated object or
7745   // be null, which ensures that an `inbounds` getelementptr with a zero offset
7746   // could not have produced poison.
7747   // It can also strip off addrspacecasts that do not change the bit
7748   // representation; we believe such an addrspacecast is equivalent to a no-op.
7749   auto *StrippedV = V->stripPointerCastsSameRepresentation();
7750   if (isa<AllocaInst>(StrippedV) || isa<GlobalVariable>(StrippedV) ||
7751       isa<Function>(StrippedV) || isa<ConstantPointerNull>(StrippedV))
7752     return true;
7753 
7754   auto OpCheck = [&](const Value *V) {
7755     return isGuaranteedNotToBeUndefOrPoison(V, AC, CtxI, DT, Depth + 1, Kind);
7756   };
7757 
7758   if (auto *Opr = dyn_cast<Operator>(V)) {
7759     // If the value is a freeze instruction, then it can never
7760     // be undef or poison.
7761     if (isa<FreezeInst>(V))
7762       return true;
7763 
7764     if (const auto *CB = dyn_cast<CallBase>(V)) {
7765       if (CB->hasRetAttr(Attribute::NoUndef) ||
7766           CB->hasRetAttr(Attribute::Dereferenceable) ||
7767           CB->hasRetAttr(Attribute::DereferenceableOrNull))
7768         return true;
7769     }
7770 
7771     if (const auto *PN = dyn_cast<PHINode>(V)) {
7772       unsigned Num = PN->getNumIncomingValues();
7773       bool IsWellDefined = true;
7774       for (unsigned i = 0; i < Num; ++i) {
7775         auto *TI = PN->getIncomingBlock(i)->getTerminator();
7776         if (!isGuaranteedNotToBeUndefOrPoison(PN->getIncomingValue(i), AC, TI,
7777                                               DT, Depth + 1, Kind)) {
7778           IsWellDefined = false;
7779           break;
7780         }
7781       }
7782       if (IsWellDefined)
7783         return true;
7784     } else if (!::canCreateUndefOrPoison(Opr, Kind,
7785                                          /*ConsiderFlagsAndMetadata*/ true) &&
7786                all_of(Opr->operands(), OpCheck))
7787       return true;
7788   }
7789 
7790   if (auto *I = dyn_cast<LoadInst>(V))
7791     if (I->hasMetadata(LLVMContext::MD_noundef) ||
7792         I->hasMetadata(LLVMContext::MD_dereferenceable) ||
7793         I->hasMetadata(LLVMContext::MD_dereferenceable_or_null))
7794       return true;
7795 
7796   if (programUndefinedIfUndefOrPoison(V, !includesUndef(Kind)))
7797     return true;
7798 
7799   // CxtI may be null or a cloned instruction.
7800   if (!CtxI || !CtxI->getParent() || !DT)
7801     return false;
7802 
7803   auto *DNode = DT->getNode(CtxI->getParent());
7804   if (!DNode)
7805     // Unreachable block
7806     return false;
7807 
7808   // If V is used as a branch condition before reaching CtxI, V cannot be
7809   // undef or poison.
7810   //   br V, BB1, BB2
7811   // BB1:
7812   //   CtxI ; V cannot be undef or poison here
7813   auto *Dominator = DNode->getIDom();
7814   // This check is purely for compile time reasons: we can skip the IDom walk
7815   // if what we are checking for includes undef and the value is not an integer.
7816   if (!includesUndef(Kind) || V->getType()->isIntegerTy())
7817     while (Dominator) {
7818       auto *TI = Dominator->getBlock()->getTerminator();
7819 
7820       Value *Cond = nullptr;
7821       if (auto BI = dyn_cast_or_null<BranchInst>(TI)) {
7822         if (BI->isConditional())
7823           Cond = BI->getCondition();
7824       } else if (auto SI = dyn_cast_or_null<SwitchInst>(TI)) {
7825         Cond = SI->getCondition();
7826       }
7827 
7828       if (Cond) {
7829         if (Cond == V)
7830           return true;
7831         else if (!includesUndef(Kind) && isa<Operator>(Cond)) {
7832           // For poison, we can analyze further
7833           auto *Opr = cast<Operator>(Cond);
7834           if (any_of(Opr->operands(), [V](const Use &U) {
7835                 return V == U && propagatesPoison(U);
7836               }))
7837             return true;
7838         }
7839       }
7840 
7841       Dominator = Dominator->getIDom();
7842     }
7843 
7844   if (getKnowledgeValidInContext(V, {Attribute::NoUndef}, CtxI, DT, AC))
7845     return true;
7846 
7847   return false;
7848 }
7849 
7850 bool llvm::isGuaranteedNotToBeUndefOrPoison(const Value *V, AssumptionCache *AC,
7851                                             const Instruction *CtxI,
7852                                             const DominatorTree *DT,
7853                                             unsigned Depth) {
7854   return ::isGuaranteedNotToBeUndefOrPoison(V, AC, CtxI, DT, Depth,
7855                                             UndefPoisonKind::UndefOrPoison);
7856 }
7857 
7858 bool llvm::isGuaranteedNotToBePoison(const Value *V, AssumptionCache *AC,
7859                                      const Instruction *CtxI,
7860                                      const DominatorTree *DT, unsigned Depth) {
7861   return ::isGuaranteedNotToBeUndefOrPoison(V, AC, CtxI, DT, Depth,
7862                                             UndefPoisonKind::PoisonOnly);
7863 }
7864 
7865 bool llvm::isGuaranteedNotToBeUndef(const Value *V, AssumptionCache *AC,
7866                                     const Instruction *CtxI,
7867                                     const DominatorTree *DT, unsigned Depth) {
7868   return ::isGuaranteedNotToBeUndefOrPoison(V, AC, CtxI, DT, Depth,
7869                                             UndefPoisonKind::UndefOnly);
7870 }
7871 
7872 /// Return true if undefined behavior would provably be executed on the path to
7873 /// OnPathTo if Root produced a poison result.  Note that this doesn't say
7874 /// anything about whether OnPathTo is actually executed or whether Root is
7875 /// actually poison.  This can be used to assess whether a new use of Root can
7876 /// be added at a location which is control equivalent with OnPathTo (such as
7877 /// immediately before it) without introducing UB which didn't previously
7878 /// exist.  Note that a false result conveys no information.
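/// Illustrative sketch (hypothetical IR): if Root is %x, poison propagates
/// through
///   %p = getelementptr inbounds i8, ptr %base, i64 %x
///   store i8 0, ptr %p
/// to %p; if the store dominates OnPathTo, it would be UB on a poison pointer,
/// so we return true.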
7879 bool llvm::mustExecuteUBIfPoisonOnPathTo(Instruction *Root,
7880                                          Instruction *OnPathTo,
7881                                          DominatorTree *DT) {
7882   // Basic approach is to assume Root is poison, propagate poison forward
7883   // through all users we can easily track, and then check whether any of those
7884   // users are provably UB and must execute before our exiting block might
7885   // exit.
7886 
7887   // The set of all recursive users we've visited (which are assumed to all be
7888   // poison because of said visit)
7889   SmallSet<const Value *, 16> KnownPoison;
7890   SmallVector<const Instruction*, 16> Worklist;
7891   Worklist.push_back(Root);
7892   while (!Worklist.empty()) {
7893     const Instruction *I = Worklist.pop_back_val();
7894 
7895     // If we know this must trigger UB on a path leading our target.
7896     if (mustTriggerUB(I, KnownPoison) && DT->dominates(I, OnPathTo))
7897       return true;
7898 
7899     // If we can't analyze propagation through this instruction, just skip it
7900     // and transitive users.  Safe as false is a conservative result.
7901     if (I != Root && !any_of(I->operands(), [&KnownPoison](const Use &U) {
7902           return KnownPoison.contains(U) && propagatesPoison(U);
7903         }))
7904       continue;
7905 
7906     if (KnownPoison.insert(I).second)
7907       for (const User *User : I->users())
7908         Worklist.push_back(cast<Instruction>(User));
7909   }
7910 
7911   // Might be non-UB, or might have a path we couldn't prove must execute on
7912   // way to exiting bb.
7913   return false;
7914 }
7915 
7916 OverflowResult llvm::computeOverflowForSignedAdd(const AddOperator *Add,
7917                                                  const SimplifyQuery &SQ) {
7918   return ::computeOverflowForSignedAdd(Add->getOperand(0), Add->getOperand(1),
7919                                        Add, SQ);
7920 }
7921 
7922 OverflowResult
7923 llvm::computeOverflowForSignedAdd(const WithCache<const Value *> &LHS,
7924                                   const WithCache<const Value *> &RHS,
7925                                   const SimplifyQuery &SQ) {
7926   return ::computeOverflowForSignedAdd(LHS, RHS, nullptr, SQ);
7927 }
7928 
7929 bool llvm::isGuaranteedToTransferExecutionToSuccessor(const Instruction *I) {
7930   // Note: An atomic operation isn't guaranteed to return in a reasonable amount
7931   // of time because it's possible for another thread to interfere with it for an
7932   // arbitrary length of time, but programs aren't allowed to rely on that.
7933 
7934   // If there is no successor, then execution can't transfer to it.
7935   if (isa<ReturnInst>(I))
7936     return false;
7937   if (isa<UnreachableInst>(I))
7938     return false;
7939 
7940   // Note: Do not add new checks here; instead, change Instruction::mayThrow or
7941   // Instruction::willReturn.
7942   //
7943   // FIXME: Move this check into Instruction::willReturn.
7944   if (isa<CatchPadInst>(I)) {
7945     switch (classifyEHPersonality(I->getFunction()->getPersonalityFn())) {
7946     default:
7947       // A catchpad may invoke exception object constructors and such, which
7948       // in some languages can be arbitrary code, so be conservative by default.
7949       return false;
7950     case EHPersonality::CoreCLR:
7951       // For CoreCLR, it just involves a type test.
7952       return true;
7953     }
7954   }
7955 
7956   // An instruction that returns without throwing must transfer control flow
7957   // to a successor.
7958   return !I->mayThrow() && I->willReturn();
7959 }
7960 
7961 bool llvm::isGuaranteedToTransferExecutionToSuccessor(const BasicBlock *BB) {
7962   // TODO: This is slightly conservative for invoke instructions since exiting
7963   // via an exception *is* normal control flow for them.
7964   for (const Instruction &I : *BB)
7965     if (!isGuaranteedToTransferExecutionToSuccessor(&I))
7966       return false;
7967   return true;
7968 }
7969 
7970 bool llvm::isGuaranteedToTransferExecutionToSuccessor(
7971    BasicBlock::const_iterator Begin, BasicBlock::const_iterator End,
7972    unsigned ScanLimit) {
7973   return isGuaranteedToTransferExecutionToSuccessor(make_range(Begin, End),
7974                                                     ScanLimit);
7975 }
7976 
7977 bool llvm::isGuaranteedToTransferExecutionToSuccessor(
7978    iterator_range<BasicBlock::const_iterator> Range, unsigned ScanLimit) {
7979   assert(ScanLimit && "scan limit must be non-zero");
7980   for (const Instruction &I : Range) {
7981     if (isa<DbgInfoIntrinsic>(I))
7982         continue;
7983     if (--ScanLimit == 0)
7984       return false;
7985     if (!isGuaranteedToTransferExecutionToSuccessor(&I))
7986       return false;
7987   }
7988   return true;
7989 }
7990 
7991 bool llvm::isGuaranteedToExecuteForEveryIteration(const Instruction *I,
7992                                                   const Loop *L) {
7993   // The loop header is guaranteed to be executed for every iteration.
7994   //
7995   // FIXME: Relax this constraint to cover all basic blocks that are
7996   // guaranteed to be executed at every iteration.
7997   if (I->getParent() != L->getHeader()) return false;
7998 
7999   for (const Instruction &LI : *L->getHeader()) {
8000     if (&LI == I) return true;
8001     if (!isGuaranteedToTransferExecutionToSuccessor(&LI)) return false;
8002   }
8003   llvm_unreachable("Instruction not contained in its own parent basic block.");
8004 }
8005 
8006 bool llvm::propagatesPoison(const Use &PoisonOp) {
8007   const Operator *I = cast<Operator>(PoisonOp.getUser());
8008   switch (I->getOpcode()) {
8009   case Instruction::Freeze:
8010   case Instruction::PHI:
8011   case Instruction::Invoke:
8012     return false;
8013   case Instruction::Select:
8014     return PoisonOp.getOperandNo() == 0;
8015   case Instruction::Call:
8016     if (auto *II = dyn_cast<IntrinsicInst>(I)) {
8017       switch (II->getIntrinsicID()) {
8018       // TODO: Add more intrinsics.
8019       case Intrinsic::sadd_with_overflow:
8020       case Intrinsic::ssub_with_overflow:
8021       case Intrinsic::smul_with_overflow:
8022       case Intrinsic::uadd_with_overflow:
8023       case Intrinsic::usub_with_overflow:
8024       case Intrinsic::umul_with_overflow:
8025         // If an input is a vector containing a poison element, the
8026         // corresponding lanes of both output vectors (calculated results and
8027         // overflow bits) are poison.
8028         return true;
8029       case Intrinsic::ctpop:
8030       case Intrinsic::ctlz:
8031       case Intrinsic::cttz:
8032       case Intrinsic::abs:
8033       case Intrinsic::smax:
8034       case Intrinsic::smin:
8035       case Intrinsic::umax:
8036       case Intrinsic::umin:
8037       case Intrinsic::bitreverse:
8038       case Intrinsic::bswap:
8039       case Intrinsic::sadd_sat:
8040       case Intrinsic::ssub_sat:
8041       case Intrinsic::sshl_sat:
8042       case Intrinsic::uadd_sat:
8043       case Intrinsic::usub_sat:
8044       case Intrinsic::ushl_sat:
8045         return true;
8046       }
8047     }
8048     return false;
8049   case Instruction::ICmp:
8050   case Instruction::FCmp:
8051   case Instruction::GetElementPtr:
8052     return true;
8053   default:
8054     if (isa<BinaryOperator>(I) || isa<UnaryOperator>(I) || isa<CastInst>(I))
8055       return true;
8056 
8057     // Be conservative and return false.
8058     return false;
8059   }
8060 }
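
// Illustrative sketch (hypothetical IR): if %p is poison in
//   %c = icmp eq i32 %p, 0
//   %s = select i1 %cond, i32 %p, i32 1
// then the icmp's use of %p propagates poison (compares are listed above),
// while the select's use of %p as an arm does not: only a select's condition
// operand (operand 0) is treated as propagating here.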
8061 
8062 /// Enumerates all operands of \p I that are guaranteed to not be undef or
8063 /// poison. If the callback \p Handle returns true, stop processing and return
8064 /// true. Otherwise, return false.
8065 template <typename CallableT>
8066 static bool handleGuaranteedWellDefinedOps(const Instruction *I,
8067                                            const CallableT &Handle) {
8068   switch (I->getOpcode()) {
8069     case Instruction::Store:
8070       if (Handle(cast<StoreInst>(I)->getPointerOperand()))
8071         return true;
8072       break;
8073 
8074     case Instruction::Load:
8075       if (Handle(cast<LoadInst>(I)->getPointerOperand()))
8076         return true;
8077       break;
8078 
8079     // Since the dereferenceable attribute implies noundef, atomic operations
8080     // also implicitly have noundef pointers.
8081     case Instruction::AtomicCmpXchg:
8082       if (Handle(cast<AtomicCmpXchgInst>(I)->getPointerOperand()))
8083         return true;
8084       break;
8085 
8086     case Instruction::AtomicRMW:
8087       if (Handle(cast<AtomicRMWInst>(I)->getPointerOperand()))
8088         return true;
8089       break;
8090 
8091     case Instruction::Call:
8092     case Instruction::Invoke: {
8093       const CallBase *CB = cast<CallBase>(I);
8094       if (CB->isIndirectCall() && Handle(CB->getCalledOperand()))
8095         return true;
8096       for (unsigned i = 0; i < CB->arg_size(); ++i)
8097         if ((CB->paramHasAttr(i, Attribute::NoUndef) ||
8098              CB->paramHasAttr(i, Attribute::Dereferenceable) ||
8099              CB->paramHasAttr(i, Attribute::DereferenceableOrNull)) &&
8100             Handle(CB->getArgOperand(i)))
8101           return true;
8102       break;
8103     }
8104     case Instruction::Ret:
8105       if (I->getFunction()->hasRetAttribute(Attribute::NoUndef) &&
8106           Handle(I->getOperand(0)))
8107         return true;
8108       break;
8109     case Instruction::Switch:
8110       if (Handle(cast<SwitchInst>(I)->getCondition()))
8111         return true;
8112       break;
8113     case Instruction::Br: {
8114       auto *BR = cast<BranchInst>(I);
8115       if (BR->isConditional() && Handle(BR->getCondition()))
8116         return true;
8117       break;
8118     }
8119     default:
8120       break;
8121   }
8122 
8123   return false;
8124 }
8125 
8126 void llvm::getGuaranteedWellDefinedOps(
8127     const Instruction *I, SmallVectorImpl<const Value *> &Operands) {
8128   handleGuaranteedWellDefinedOps(I, [&](const Value *V) {
8129     Operands.push_back(V);
8130     return false;
8131   });
8132 }
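
// As an illustrative example (hypothetical IR): for
//   store i32 %v, ptr %dst
// only %dst is reported as guaranteed well-defined (the stored value %v may be
// undef or poison without immediate UB), and for a call such as
//   call void @f(i32 noundef %a, i32 %b)
// only %a is reported, because of its noundef attribute.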
8133 
8134 /// Enumerates all operands of \p I that are guaranteed to not be poison.
8135 template <typename CallableT>
8136 static bool handleGuaranteedNonPoisonOps(const Instruction *I,
8137                                          const CallableT &Handle) {
8138   if (handleGuaranteedWellDefinedOps(I, Handle))
8139     return true;
8140   switch (I->getOpcode()) {
8141   // Divisors of these operations are allowed to be partially undef.
8142   case Instruction::UDiv:
8143   case Instruction::SDiv:
8144   case Instruction::URem:
8145   case Instruction::SRem:
8146     return Handle(I->getOperand(1));
8147   default:
8148     return false;
8149   }
8150 }
8151 
8152 void llvm::getGuaranteedNonPoisonOps(const Instruction *I,
8153                                      SmallVectorImpl<const Value *> &Operands) {
8154   handleGuaranteedNonPoisonOps(I, [&](const Value *V) {
8155     Operands.push_back(V);
8156     return false;
8157   });
8158 }
8159 
8160 bool llvm::mustTriggerUB(const Instruction *I,
8161                          const SmallPtrSetImpl<const Value *> &KnownPoison) {
8162   return handleGuaranteedNonPoisonOps(
8163       I, [&](const Value *V) { return KnownPoison.count(V); });
8164 }
8165 
8166 static bool programUndefinedIfUndefOrPoison(const Value *V,
8167                                             bool PoisonOnly) {
8168   // We currently only look for uses of values within the same basic
8169   // block, as that makes it easier to guarantee that the uses will be
8170   // executed given that Inst is executed.
8171   //
8172   // FIXME: Expand this to consider uses beyond the same basic block. To do
8173   // this, look out for the distinction between post-dominance and strong
8174   // post-dominance.
8175   const BasicBlock *BB = nullptr;
8176   BasicBlock::const_iterator Begin;
8177   if (const auto *Inst = dyn_cast<Instruction>(V)) {
8178     BB = Inst->getParent();
8179     Begin = Inst->getIterator();
8180     Begin++;
8181   } else if (const auto *Arg = dyn_cast<Argument>(V)) {
8182     if (Arg->getParent()->isDeclaration())
8183       return false;
8184     BB = &Arg->getParent()->getEntryBlock();
8185     Begin = BB->begin();
8186   } else {
8187     return false;
8188   }
8189 
8190   // Limit number of instructions we look at, to avoid scanning through large
8191   // blocks. The current limit is chosen arbitrarily.
8192   unsigned ScanLimit = 32;
8193   BasicBlock::const_iterator End = BB->end();
8194 
8195   if (!PoisonOnly) {
8196     // Since undef does not propagate eagerly, be conservative & just check
8197     // whether a value is directly passed to an instruction that must take
8198     // well-defined operands.
8199 
8200     for (const auto &I : make_range(Begin, End)) {
8201       if (isa<DbgInfoIntrinsic>(I))
8202         continue;
8203       if (--ScanLimit == 0)
8204         break;
8205 
8206       if (handleGuaranteedWellDefinedOps(&I, [V](const Value *WellDefinedOp) {
8207             return WellDefinedOp == V;
8208           }))
8209         return true;
8210 
8211       if (!isGuaranteedToTransferExecutionToSuccessor(&I))
8212         break;
8213     }
8214     return false;
8215   }
8216 
8217   // Set of instructions that we have proved will yield poison if Inst
8218   // does.
8219   SmallSet<const Value *, 16> YieldsPoison;
8220   SmallSet<const BasicBlock *, 4> Visited;
8221 
8222   YieldsPoison.insert(V);
8223   Visited.insert(BB);
8224 
8225   while (true) {
8226     for (const auto &I : make_range(Begin, End)) {
8227       if (isa<DbgInfoIntrinsic>(I))
8228         continue;
8229       if (--ScanLimit == 0)
8230         return false;
8231       if (mustTriggerUB(&I, YieldsPoison))
8232         return true;
8233       if (!isGuaranteedToTransferExecutionToSuccessor(&I))
8234         return false;
8235 
8236       // If an operand is poison and propagates it, mark I as yielding poison.
8237       for (const Use &Op : I.operands()) {
8238         if (YieldsPoison.count(Op) && propagatesPoison(Op)) {
8239           YieldsPoison.insert(&I);
8240           break;
8241         }
8242       }
8243 
8244       // Special handling for select, which returns poison if its operand 0 is
8245       // poison (handled in the loop above) *or* if both its true/false operands
8246       // are poison (handled here).
8247       if (I.getOpcode() == Instruction::Select &&
8248           YieldsPoison.count(I.getOperand(1)) &&
8249           YieldsPoison.count(I.getOperand(2))) {
8250         YieldsPoison.insert(&I);
8251       }
8252     }
8253 
8254     BB = BB->getSingleSuccessor();
8255     if (!BB || !Visited.insert(BB).second)
8256       break;
8257 
8258     Begin = BB->getFirstNonPHIIt();
8259     End = BB->end();
8260   }
8261   return false;
8262 }
8263 
8264 bool llvm::programUndefinedIfUndefOrPoison(const Instruction *Inst) {
8265   return ::programUndefinedIfUndefOrPoison(Inst, false);
8266 }
8267 
8268 bool llvm::programUndefinedIfPoison(const Instruction *Inst) {
8269   return ::programUndefinedIfUndefOrPoison(Inst, true);
8270 }
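
// A small illustration (hypothetical IR): if the instruction defining %x is
// followed in the same block by
//   %q = udiv i32 %a, %x
// then programUndefinedIfPoison(%x) returns true, because a poison divisor is
// immediate UB, so mustTriggerUB fires once the scan above reaches the udiv.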
8271 
8272 static bool isKnownNonNaN(const Value *V, FastMathFlags FMF) {
8273   if (FMF.noNaNs())
8274     return true;
8275 
8276   if (auto *C = dyn_cast<ConstantFP>(V))
8277     return !C->isNaN();
8278 
8279   if (auto *C = dyn_cast<ConstantDataVector>(V)) {
8280     if (!C->getElementType()->isFloatingPointTy())
8281       return false;
8282     for (unsigned I = 0, E = C->getNumElements(); I < E; ++I) {
8283       if (C->getElementAsAPFloat(I).isNaN())
8284         return false;
8285     }
8286     return true;
8287   }
8288 
8289   if (isa<ConstantAggregateZero>(V))
8290     return true;
8291 
8292   return false;
8293 }
8294 
8295 static bool isKnownNonZero(const Value *V) {
8296   if (auto *C = dyn_cast<ConstantFP>(V))
8297     return !C->isZero();
8298 
8299   if (auto *C = dyn_cast<ConstantDataVector>(V)) {
8300     if (!C->getElementType()->isFloatingPointTy())
8301       return false;
8302     for (unsigned I = 0, E = C->getNumElements(); I < E; ++I) {
8303       if (C->getElementAsAPFloat(I).isZero())
8304         return false;
8305     }
8306     return true;
8307   }
8308 
8309   return false;
8310 }
8311 
8312 /// Match a clamp pattern for float types without caring about NaNs or signed
8313 /// zeros. Given a non-min/max outer cmp/select from the clamp pattern, this
8314 /// function recognizes whether it can be substituted by a "canonical" min/max
8315 /// pattern.
8316 static SelectPatternResult matchFastFloatClamp(CmpInst::Predicate Pred,
8317                                                Value *CmpLHS, Value *CmpRHS,
8318                                                Value *TrueVal, Value *FalseVal,
8319                                                Value *&LHS, Value *&RHS) {
8320   // Try to match
8321   //   X < C1 ? C1 : Min(X, C2) --> Max(C1, Min(X, C2))
8322   //   X > C1 ? C1 : Max(X, C2) --> Min(C1, Max(X, C2))
8323   // and return description of the outer Max/Min.
8324 
8325   // First, check if select has inverse order:
8326   if (CmpRHS == FalseVal) {
8327     std::swap(TrueVal, FalseVal);
8328     Pred = CmpInst::getInversePredicate(Pred);
8329   }
8330 
8331   // Assume success now. If there's no match, callers should not use these anyway.
8332   LHS = TrueVal;
8333   RHS = FalseVal;
8334 
8335   const APFloat *FC1;
8336   if (CmpRHS != TrueVal || !match(CmpRHS, m_APFloat(FC1)) || !FC1->isFinite())
8337     return {SPF_UNKNOWN, SPNB_NA, false};
8338 
8339   const APFloat *FC2;
8340   switch (Pred) {
8341   case CmpInst::FCMP_OLT:
8342   case CmpInst::FCMP_OLE:
8343   case CmpInst::FCMP_ULT:
8344   case CmpInst::FCMP_ULE:
8345     if (match(FalseVal, m_OrdOrUnordFMin(m_Specific(CmpLHS), m_APFloat(FC2))) &&
8346         *FC1 < *FC2)
8347       return {SPF_FMAXNUM, SPNB_RETURNS_ANY, false};
8348     break;
8349   case CmpInst::FCMP_OGT:
8350   case CmpInst::FCMP_OGE:
8351   case CmpInst::FCMP_UGT:
8352   case CmpInst::FCMP_UGE:
8353     if (match(FalseVal, m_OrdOrUnordFMax(m_Specific(CmpLHS), m_APFloat(FC2))) &&
8354         *FC1 > *FC2)
8355       return {SPF_FMINNUM, SPNB_RETURNS_ANY, false};
8356     break;
8357   default:
8358     break;
8359   }
8360 
8361   return {SPF_UNKNOWN, SPNB_NA, false};
8362 }
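
// A concrete instance of the pattern above (hypothetical constants):
//   x < 1.0 ? 1.0 : fmin(x, 4.0)
// is recognized as fmax(1.0, fmin(x, 4.0)), i.e. SPF_FMAXNUM, because the
// outer constant 1.0 is finite and smaller than the inner constant 4.0.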
8363 
8364 /// Recognize variations of:
8365 ///   CLAMP(v,l,h) ==> ((v) < (l) ? (l) : ((v) > (h) ? (h) : (v)))
8366 static SelectPatternResult matchClamp(CmpInst::Predicate Pred,
8367                                       Value *CmpLHS, Value *CmpRHS,
8368                                       Value *TrueVal, Value *FalseVal) {
8369   // Swap the select operands and predicate to match the patterns below.
8370   if (CmpRHS != TrueVal) {
8371     Pred = ICmpInst::getSwappedPredicate(Pred);
8372     std::swap(TrueVal, FalseVal);
8373   }
8374   const APInt *C1;
8375   if (CmpRHS == TrueVal && match(CmpRHS, m_APInt(C1))) {
8376     const APInt *C2;
8377     // (X <s C1) ? C1 : SMIN(X, C2) ==> SMAX(SMIN(X, C2), C1)
8378     if (match(FalseVal, m_SMin(m_Specific(CmpLHS), m_APInt(C2))) &&
8379         C1->slt(*C2) && Pred == CmpInst::ICMP_SLT)
8380       return {SPF_SMAX, SPNB_NA, false};
8381 
8382     // (X >s C1) ? C1 : SMAX(X, C2) ==> SMIN(SMAX(X, C2), C1)
8383     if (match(FalseVal, m_SMax(m_Specific(CmpLHS), m_APInt(C2))) &&
8384         C1->sgt(*C2) && Pred == CmpInst::ICMP_SGT)
8385       return {SPF_SMIN, SPNB_NA, false};
8386 
8387     // (X <u C1) ? C1 : UMIN(X, C2) ==> UMAX(UMIN(X, C2), C1)
8388     if (match(FalseVal, m_UMin(m_Specific(CmpLHS), m_APInt(C2))) &&
8389         C1->ult(*C2) && Pred == CmpInst::ICMP_ULT)
8390       return {SPF_UMAX, SPNB_NA, false};
8391 
8392     // (X >u C1) ? C1 : UMAX(X, C2) ==> UMIN(UMAX(X, C2), C1)
8393     if (match(FalseVal, m_UMax(m_Specific(CmpLHS), m_APInt(C2))) &&
8394         C1->ugt(*C2) && Pred == CmpInst::ICMP_UGT)
8395       return {SPF_UMIN, SPNB_NA, false};
8396   }
8397   return {SPF_UNKNOWN, SPNB_NA, false};
8398 }
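
// Worked example for the signed case above (hypothetical constants):
//   (x <s 5) ? 5 : smin(x, 10)
// clamps x to [5, 10]; since 5 <s 10, it is reported as SPF_SMAX, i.e. the
// canonical form smax(smin(x, 10), 5).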
8399 
8400 /// Recognize variations of:
8401 ///   a < c ? min(a,b) : min(b,c) ==> min(min(a,b),min(b,c))
8402 static SelectPatternResult matchMinMaxOfMinMax(CmpInst::Predicate Pred,
8403                                                Value *CmpLHS, Value *CmpRHS,
8404                                                Value *TVal, Value *FVal,
8405                                                unsigned Depth) {
8406   // TODO: Allow FP min/max with nnan/nsz.
8407   assert(CmpInst::isIntPredicate(Pred) && "Expected integer comparison");
8408 
8409   Value *A = nullptr, *B = nullptr;
8410   SelectPatternResult L = matchSelectPattern(TVal, A, B, nullptr, Depth + 1);
8411   if (!SelectPatternResult::isMinOrMax(L.Flavor))
8412     return {SPF_UNKNOWN, SPNB_NA, false};
8413 
8414   Value *C = nullptr, *D = nullptr;
8415   SelectPatternResult R = matchSelectPattern(FVal, C, D, nullptr, Depth + 1);
8416   if (L.Flavor != R.Flavor)
8417     return {SPF_UNKNOWN, SPNB_NA, false};
8418 
8419   // We have something like: x Pred y ? min(a, b) : min(c, d).
8420   // Try to match the compare to the min/max operations of the select operands.
8421   // First, make sure we have the right compare predicate.
8422   switch (L.Flavor) {
8423   case SPF_SMIN:
8424     if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE) {
8425       Pred = ICmpInst::getSwappedPredicate(Pred);
8426       std::swap(CmpLHS, CmpRHS);
8427     }
8428     if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
8429       break;
8430     return {SPF_UNKNOWN, SPNB_NA, false};
8431   case SPF_SMAX:
8432     if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) {
8433       Pred = ICmpInst::getSwappedPredicate(Pred);
8434       std::swap(CmpLHS, CmpRHS);
8435     }
8436     if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE)
8437       break;
8438     return {SPF_UNKNOWN, SPNB_NA, false};
8439   case SPF_UMIN:
8440     if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE) {
8441       Pred = ICmpInst::getSwappedPredicate(Pred);
8442       std::swap(CmpLHS, CmpRHS);
8443     }
8444     if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE)
8445       break;
8446     return {SPF_UNKNOWN, SPNB_NA, false};
8447   case SPF_UMAX:
8448     if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE) {
8449       Pred = ICmpInst::getSwappedPredicate(Pred);
8450       std::swap(CmpLHS, CmpRHS);
8451     }
8452     if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE)
8453       break;
8454     return {SPF_UNKNOWN, SPNB_NA, false};
8455   default:
8456     return {SPF_UNKNOWN, SPNB_NA, false};
8457   }
8458 
8459   // If there is a common operand in the already matched min/max and the other
8460   // min/max operands match the compare operands (either directly or inverted),
8461   // then this is min/max of the same flavor.
8462 
8463   // a pred c ? m(a, b) : m(c, b) --> m(m(a, b), m(c, b))
8464   // ~c pred ~a ? m(a, b) : m(c, b) --> m(m(a, b), m(c, b))
8465   if (D == B) {
8466     if ((CmpLHS == A && CmpRHS == C) || (match(C, m_Not(m_Specific(CmpLHS))) &&
8467                                          match(A, m_Not(m_Specific(CmpRHS)))))
8468       return {L.Flavor, SPNB_NA, false};
8469   }
8470   // a pred d ? m(a, b) : m(b, d) --> m(m(a, b), m(b, d))
8471   // ~d pred ~a ? m(a, b) : m(b, d) --> m(m(a, b), m(b, d))
8472   if (C == B) {
8473     if ((CmpLHS == A && CmpRHS == D) || (match(D, m_Not(m_Specific(CmpLHS))) &&
8474                                          match(A, m_Not(m_Specific(CmpRHS)))))
8475       return {L.Flavor, SPNB_NA, false};
8476   }
8477   // b pred c ? m(a, b) : m(c, a) --> m(m(a, b), m(c, a))
8478   // ~c pred ~b ? m(a, b) : m(c, a) --> m(m(a, b), m(c, a))
8479   if (D == A) {
8480     if ((CmpLHS == B && CmpRHS == C) || (match(C, m_Not(m_Specific(CmpLHS))) &&
8481                                          match(B, m_Not(m_Specific(CmpRHS)))))
8482       return {L.Flavor, SPNB_NA, false};
8483   }
8484   // b pred d ? m(a, b) : m(a, d) --> m(m(a, b), m(a, d))
8485   // ~d pred ~b ? m(a, b) : m(a, d) --> m(m(a, b), m(a, d))
8486   if (C == A) {
8487     if ((CmpLHS == B && CmpRHS == D) || (match(D, m_Not(m_Specific(CmpLHS))) &&
8488                                          match(B, m_Not(m_Specific(CmpRHS)))))
8489       return {L.Flavor, SPNB_NA, false};
8490   }
8491 
8492   return {SPF_UNKNOWN, SPNB_NA, false};
8493 }
8494 
8495 /// If the input value is the result of a 'not' op, constant integer, or vector
8496 /// splat of a constant integer, return the bitwise-not source value.
8497 /// TODO: This could be extended to handle non-splat vector integer constants.
8498 static Value *getNotValue(Value *V) {
8499   Value *NotV;
8500   if (match(V, m_Not(m_Value(NotV))))
8501     return NotV;
8502 
8503   const APInt *C;
8504   if (match(V, m_APInt(C)))
8505     return ConstantInt::get(V->getType(), ~(*C));
8506 
8507   return nullptr;
8508 }
8509 
8510 /// Match non-obvious integer minimum and maximum sequences.
8511 static SelectPatternResult matchMinMax(CmpInst::Predicate Pred,
8512                                        Value *CmpLHS, Value *CmpRHS,
8513                                        Value *TrueVal, Value *FalseVal,
8514                                        Value *&LHS, Value *&RHS,
8515                                        unsigned Depth) {
8516   // Assume success. If there's no match, callers should not use these anyway.
8517   LHS = TrueVal;
8518   RHS = FalseVal;
8519 
8520   SelectPatternResult SPR = matchClamp(Pred, CmpLHS, CmpRHS, TrueVal, FalseVal);
8521   if (SPR.Flavor != SelectPatternFlavor::SPF_UNKNOWN)
8522     return SPR;
8523 
8524   SPR = matchMinMaxOfMinMax(Pred, CmpLHS, CmpRHS, TrueVal, FalseVal, Depth);
8525   if (SPR.Flavor != SelectPatternFlavor::SPF_UNKNOWN)
8526     return SPR;
8527 
8528   // Look through 'not' ops to find disguised min/max.
8529   // (X > Y) ? ~X : ~Y ==> (~X < ~Y) ? ~X : ~Y ==> MIN(~X, ~Y)
8530   // (X < Y) ? ~X : ~Y ==> (~X > ~Y) ? ~X : ~Y ==> MAX(~X, ~Y)
8531   if (CmpLHS == getNotValue(TrueVal) && CmpRHS == getNotValue(FalseVal)) {
8532     switch (Pred) {
8533     case CmpInst::ICMP_SGT: return {SPF_SMIN, SPNB_NA, false};
8534     case CmpInst::ICMP_SLT: return {SPF_SMAX, SPNB_NA, false};
8535     case CmpInst::ICMP_UGT: return {SPF_UMIN, SPNB_NA, false};
8536     case CmpInst::ICMP_ULT: return {SPF_UMAX, SPNB_NA, false};
8537     default: break;
8538     }
8539   }
8540 
8541   // (X > Y) ? ~Y : ~X ==> (~X < ~Y) ? ~Y : ~X ==> MAX(~Y, ~X)
8542   // (X < Y) ? ~Y : ~X ==> (~X > ~Y) ? ~Y : ~X ==> MIN(~Y, ~X)
8543   if (CmpLHS == getNotValue(FalseVal) && CmpRHS == getNotValue(TrueVal)) {
8544     switch (Pred) {
8545     case CmpInst::ICMP_SGT: return {SPF_SMAX, SPNB_NA, false};
8546     case CmpInst::ICMP_SLT: return {SPF_SMIN, SPNB_NA, false};
8547     case CmpInst::ICMP_UGT: return {SPF_UMAX, SPNB_NA, false};
8548     case CmpInst::ICMP_ULT: return {SPF_UMIN, SPNB_NA, false};
8549     default: break;
8550     }
8551   }
8552 
8553   if (Pred != CmpInst::ICMP_SGT && Pred != CmpInst::ICMP_SLT)
8554     return {SPF_UNKNOWN, SPNB_NA, false};
8555 
8556   const APInt *C1;
8557   if (!match(CmpRHS, m_APInt(C1)))
8558     return {SPF_UNKNOWN, SPNB_NA, false};
8559 
8560   // An unsigned min/max can be written with a signed compare.
8561   const APInt *C2;
8562   if ((CmpLHS == TrueVal && match(FalseVal, m_APInt(C2))) ||
8563       (CmpLHS == FalseVal && match(TrueVal, m_APInt(C2)))) {
8564     // Is the sign bit set?
8565     // (X <s 0) ? X : MAXVAL ==> (X >u MAXVAL) ? X : MAXVAL ==> UMAX
8566     // (X <s 0) ? MAXVAL : X ==> (X >u MAXVAL) ? MAXVAL : X ==> UMIN
8567     if (Pred == CmpInst::ICMP_SLT && C1->isZero() && C2->isMaxSignedValue())
8568       return {CmpLHS == TrueVal ? SPF_UMAX : SPF_UMIN, SPNB_NA, false};
8569 
8570     // Is the sign bit clear?
8571     // (X >s -1) ? MINVAL : X ==> (X <u MINVAL) ? MINVAL : X ==> UMAX
8572     // (X >s -1) ? X : MINVAL ==> (X <u MINVAL) ? X : MINVAL ==> UMIN
8573     if (Pred == CmpInst::ICMP_SGT && C1->isAllOnes() && C2->isMinSignedValue())
8574       return {CmpLHS == FalseVal ? SPF_UMAX : SPF_UMIN, SPNB_NA, false};
8575   }
8576 
8577   return {SPF_UNKNOWN, SPNB_NA, false};
8578 }
8579 
8580 bool llvm::isKnownNegation(const Value *X, const Value *Y, bool NeedNSW,
8581                            bool AllowPoison) {
8582   assert(X && Y && "Invalid operand");
8583 
8584   auto IsNegationOf = [&](const Value *X, const Value *Y) {
8585     if (!match(X, m_Neg(m_Specific(Y))))
8586       return false;
8587 
8588     auto *BO = cast<BinaryOperator>(X);
8589     if (NeedNSW && !BO->hasNoSignedWrap())
8590       return false;
8591 
8592     auto *Zero = cast<Constant>(BO->getOperand(0));
8593     if (!AllowPoison && !Zero->isNullValue())
8594       return false;
8595 
8596     return true;
8597   };
8598 
8599   // X = -Y or Y = -X
8600   if (IsNegationOf(X, Y) || IsNegationOf(Y, X))
8601     return true;
8602 
8603   // X = sub (A, B), Y = sub (B, A) || X = sub nsw (A, B), Y = sub nsw (B, A)
8604   Value *A, *B;
8605   return (!NeedNSW && (match(X, m_Sub(m_Value(A), m_Value(B))) &&
8606                         match(Y, m_Sub(m_Specific(B), m_Specific(A))))) ||
8607          (NeedNSW && (match(X, m_NSWSub(m_Value(A), m_Value(B))) &&
8608                        match(Y, m_NSWSub(m_Specific(B), m_Specific(A)))));
8609 }
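
// For illustration (hypothetical IR):
//   %neg = sub nsw i32 0, %y
// makes isKnownNegation(%neg, %y) return true, also with NeedNSW=true since
// the sub carries the nsw flag. Likewise
//   %a = sub i32 %p, %q   and   %b = sub i32 %q, %p
// are known negations of each other (when NeedNSW is false).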
8610 
8611 bool llvm::isKnownInversion(const Value *X, const Value *Y) {
8612   // Handle X = icmp pred A, B, Y = icmp pred A, C.
8613   Value *A, *B, *C;
8614   CmpPredicate Pred1, Pred2;
8615   if (!match(X, m_ICmp(Pred1, m_Value(A), m_Value(B))) ||
8616       !match(Y, m_c_ICmp(Pred2, m_Specific(A), m_Value(C))))
8617     return false;
8618 
8619   // They must both have samesign flag or not.
8620   if (cast<ICmpInst>(X)->hasSameSign() != cast<ICmpInst>(Y)->hasSameSign())
8621     return false;
8622 
8623   if (B == C)
8624     return Pred1 == ICmpInst::getInversePredicate(Pred2);
8625 
8626   // Try to infer the relationship from constant ranges.
8627   const APInt *RHSC1, *RHSC2;
8628   if (!match(B, m_APInt(RHSC1)) || !match(C, m_APInt(RHSC2)))
8629     return false;
8630 
8631   // Sign bits of two RHSCs should match.
8632   if (cast<ICmpInst>(X)->hasSameSign() &&
8633       RHSC1->isNonNegative() != RHSC2->isNonNegative())
8634     return false;
8635 
8636   const auto CR1 = ConstantRange::makeExactICmpRegion(Pred1, *RHSC1);
8637   const auto CR2 = ConstantRange::makeExactICmpRegion(Pred2, *RHSC2);
8638 
8639   return CR1.inverse() == CR2;
8640 }
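
// A small worked example for the constant-range case (hypothetical values):
//   X: icmp ult i32 %a, 10     ; holds for %a in [0, 10)
//   Y: icmp ugt i32 %a, 9      ; holds for %a in [10, 2^32)
// The two exact regions are complementary, so CR1.inverse() == CR2 and the
// compares are known inversions of each other.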
8641 
8642 SelectPatternResult llvm::getSelectPattern(CmpInst::Predicate Pred,
8643                                            SelectPatternNaNBehavior NaNBehavior,
8644                                            bool Ordered) {
8645   switch (Pred) {
8646   default:
8647     return {SPF_UNKNOWN, SPNB_NA, false}; // Equality.
8648   case ICmpInst::ICMP_UGT:
8649   case ICmpInst::ICMP_UGE:
8650     return {SPF_UMAX, SPNB_NA, false};
8651   case ICmpInst::ICMP_SGT:
8652   case ICmpInst::ICMP_SGE:
8653     return {SPF_SMAX, SPNB_NA, false};
8654   case ICmpInst::ICMP_ULT:
8655   case ICmpInst::ICMP_ULE:
8656     return {SPF_UMIN, SPNB_NA, false};
8657   case ICmpInst::ICMP_SLT:
8658   case ICmpInst::ICMP_SLE:
8659     return {SPF_SMIN, SPNB_NA, false};
8660   case FCmpInst::FCMP_UGT:
8661   case FCmpInst::FCMP_UGE:
8662   case FCmpInst::FCMP_OGT:
8663   case FCmpInst::FCMP_OGE:
8664     return {SPF_FMAXNUM, NaNBehavior, Ordered};
8665   case FCmpInst::FCMP_ULT:
8666   case FCmpInst::FCMP_ULE:
8667   case FCmpInst::FCMP_OLT:
8668   case FCmpInst::FCMP_OLE:
8669     return {SPF_FMINNUM, NaNBehavior, Ordered};
8670   }
8671 }
8672 
8673 std::optional<std::pair<CmpPredicate, Constant *>>
8674 llvm::getFlippedStrictnessPredicateAndConstant(CmpPredicate Pred, Constant *C) {
8675   assert(ICmpInst::isRelational(Pred) && ICmpInst::isIntPredicate(Pred) &&
8676          "Only for relational integer predicates.");
8677   if (isa<UndefValue>(C))
8678     return std::nullopt;
8679 
8680   Type *Type = C->getType();
8681   bool IsSigned = ICmpInst::isSigned(Pred);
8682 
8683   CmpInst::Predicate UnsignedPred = ICmpInst::getUnsignedPredicate(Pred);
8684   bool WillIncrement =
8685       UnsignedPred == ICmpInst::ICMP_ULE || UnsignedPred == ICmpInst::ICMP_UGT;
8686 
8687   // Check if the constant operand can be safely incremented/decremented
8688   // without overflowing/underflowing.
8689   auto ConstantIsOk = [WillIncrement, IsSigned](ConstantInt *C) {
8690     return WillIncrement ? !C->isMaxValue(IsSigned) : !C->isMinValue(IsSigned);
8691   };
8692 
8693   Constant *SafeReplacementConstant = nullptr;
8694   if (auto *CI = dyn_cast<ConstantInt>(C)) {
8695     // Bail out if the constant can't be safely incremented/decremented.
8696     if (!ConstantIsOk(CI))
8697       return std::nullopt;
8698   } else if (auto *FVTy = dyn_cast<FixedVectorType>(Type)) {
8699     unsigned NumElts = FVTy->getNumElements();
8700     for (unsigned i = 0; i != NumElts; ++i) {
8701       Constant *Elt = C->getAggregateElement(i);
8702       if (!Elt)
8703         return std::nullopt;
8704 
8705       if (isa<UndefValue>(Elt))
8706         continue;
8707 
8708       // Bail out if we can't determine if this constant is min/max or if we
8709       // know that this constant is min/max.
8710       auto *CI = dyn_cast<ConstantInt>(Elt);
8711       if (!CI || !ConstantIsOk(CI))
8712         return std::nullopt;
8713 
8714       if (!SafeReplacementConstant)
8715         SafeReplacementConstant = CI;
8716     }
8717   } else if (isa<VectorType>(C->getType())) {
8718     // Handle scalable splat
8719     Value *SplatC = C->getSplatValue();
8720     auto *CI = dyn_cast_or_null<ConstantInt>(SplatC);
8721     // Bail out if the constant can't be safely incremented/decremented.
8722     if (!CI || !ConstantIsOk(CI))
8723       return std::nullopt;
8724   } else {
8725     // ConstantExpr?
8726     return std::nullopt;
8727   }
8728 
8729   // It may not be safe to change a compare predicate in the presence of
8730   // undefined elements, so replace those elements with the first safe constant
8731   // that we found.
8732   // TODO: in case of poison, it is safe; let's replace undefs only.
8733   if (C->containsUndefOrPoisonElement()) {
8734     assert(SafeReplacementConstant && "Replacement constant not set");
8735     C = Constant::replaceUndefsWith(C, SafeReplacementConstant);
8736   }
8737 
8738   CmpInst::Predicate NewPred = CmpInst::getFlippedStrictnessPredicate(Pred);
8739 
8740   // Increment or decrement the constant.
8741   Constant *OneOrNegOne = ConstantInt::get(Type, WillIncrement ? 1 : -1, true);
8742   Constant *NewC = ConstantExpr::getAdd(C, OneOrNegOne);
8743 
8744   return std::make_pair(NewPred, NewC);
8745 }
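
// Worked example (hypothetical values): for Pred = ICMP_SLT and C = 5, the
// unsigned form of the predicate is ULT, so WillIncrement is false and the
// constant is decremented, yielding {ICMP_SLE, 4}: "x <s 5" becomes
// "x <=s 4". If C were already the signed minimum, the decrement would wrap
// and the function returns std::nullopt instead.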
8746 
8747 static SelectPatternResult matchSelectPattern(CmpInst::Predicate Pred,
8748                                               FastMathFlags FMF,
8749                                               Value *CmpLHS, Value *CmpRHS,
8750                                               Value *TrueVal, Value *FalseVal,
8751                                               Value *&LHS, Value *&RHS,
8752                                               unsigned Depth) {
8753   bool HasMismatchedZeros = false;
8754   if (CmpInst::isFPPredicate(Pred)) {
8755     // IEEE-754 ignores the sign of 0.0 in comparisons. So if the select has one
8756     // 0.0 operand, set the compare's 0.0 operands to that same value for the
8757     // purpose of identifying min/max. Disregard vector constants with undefined
8758     // elements because those cannot be back-propagated for analysis.
8759     Value *OutputZeroVal = nullptr;
8760     if (match(TrueVal, m_AnyZeroFP()) && !match(FalseVal, m_AnyZeroFP()) &&
8761         !cast<Constant>(TrueVal)->containsUndefOrPoisonElement())
8762       OutputZeroVal = TrueVal;
8763     else if (match(FalseVal, m_AnyZeroFP()) && !match(TrueVal, m_AnyZeroFP()) &&
8764              !cast<Constant>(FalseVal)->containsUndefOrPoisonElement())
8765       OutputZeroVal = FalseVal;
8766 
8767     if (OutputZeroVal) {
8768       if (match(CmpLHS, m_AnyZeroFP()) && CmpLHS != OutputZeroVal) {
8769         HasMismatchedZeros = true;
8770         CmpLHS = OutputZeroVal;
8771       }
8772       if (match(CmpRHS, m_AnyZeroFP()) && CmpRHS != OutputZeroVal) {
8773         HasMismatchedZeros = true;
8774         CmpRHS = OutputZeroVal;
8775       }
8776     }
8777   }
8778 
8779   LHS = CmpLHS;
8780   RHS = CmpRHS;
8781 
8782   // Operations on signed zeros may give inconsistent results between implementations.
8783   //  (0.0 <= -0.0) ? 0.0 : -0.0 // Returns 0.0
8784   //  minNum(0.0, -0.0)          // May return -0.0 or 0.0 (IEEE 754-2008 5.3.1)
8785   // Therefore, we behave conservatively and only proceed if at least one of the
8786   // operands is known to not be zero or if we don't care about signed zero.
8787   switch (Pred) {
8788   default: break;
8789   case CmpInst::FCMP_OGT: case CmpInst::FCMP_OLT:
8790   case CmpInst::FCMP_UGT: case CmpInst::FCMP_ULT:
8791     if (!HasMismatchedZeros)
8792       break;
8793     [[fallthrough]];
8794   case CmpInst::FCMP_OGE: case CmpInst::FCMP_OLE:
8795   case CmpInst::FCMP_UGE: case CmpInst::FCMP_ULE:
8796     if (!FMF.noSignedZeros() && !isKnownNonZero(CmpLHS) &&
8797         !isKnownNonZero(CmpRHS))
8798       return {SPF_UNKNOWN, SPNB_NA, false};
8799   }
8800 
8801   SelectPatternNaNBehavior NaNBehavior = SPNB_NA;
8802   bool Ordered = false;
8803 
8804   // When given one NaN and one non-NaN input:
8805   //   - maxnum/minnum (C99 fmaxf()/fminf()) return the non-NaN input.
8806   //   - A simple C99 (a < b ? a : b) construction will return 'b' (as the
8807   //     ordered comparison fails), which could be NaN or non-NaN.
8808   // so here we discover exactly what NaN behavior is required/accepted.
8809   if (CmpInst::isFPPredicate(Pred)) {
8810     bool LHSSafe = isKnownNonNaN(CmpLHS, FMF);
8811     bool RHSSafe = isKnownNonNaN(CmpRHS, FMF);
8812 
8813     if (LHSSafe && RHSSafe) {
8814       // Both operands are known non-NaN.
8815       NaNBehavior = SPNB_RETURNS_ANY;
8816     } else if (CmpInst::isOrdered(Pred)) {
8817       // An ordered comparison will return false when given a NaN, so it
8818       // returns the RHS.
8819       Ordered = true;
8820       if (LHSSafe)
8821         // LHS is non-NaN, so if RHS is NaN then NaN will be returned.
8822         NaNBehavior = SPNB_RETURNS_NAN;
8823       else if (RHSSafe)
8824         NaNBehavior = SPNB_RETURNS_OTHER;
8825       else
8826         // Completely unsafe.
8827         return {SPF_UNKNOWN, SPNB_NA, false};
8828     } else {
8829       Ordered = false;
8830       // An unordered comparison will return true when given a NaN, so it
8831       // returns the LHS.
8832       if (LHSSafe)
8833         // LHS is non-NaN, so if RHS is NaN then non-NaN will be returned.
8834         NaNBehavior = SPNB_RETURNS_OTHER;
8835       else if (RHSSafe)
8836         NaNBehavior = SPNB_RETURNS_NAN;
8837       else
8838         // Completely unsafe.
8839         return {SPF_UNKNOWN, SPNB_NA, false};
8840     }
8841   }
8842 
8843   if (TrueVal == CmpRHS && FalseVal == CmpLHS) {
8844     std::swap(CmpLHS, CmpRHS);
8845     Pred = CmpInst::getSwappedPredicate(Pred);
8846     if (NaNBehavior == SPNB_RETURNS_NAN)
8847       NaNBehavior = SPNB_RETURNS_OTHER;
8848     else if (NaNBehavior == SPNB_RETURNS_OTHER)
8849       NaNBehavior = SPNB_RETURNS_NAN;
8850     Ordered = !Ordered;
8851   }
8852 
8853   // ([if]cmp X, Y) ? X : Y
8854   if (TrueVal == CmpLHS && FalseVal == CmpRHS)
8855     return getSelectPattern(Pred, NaNBehavior, Ordered);
8856 
8857   if (isKnownNegation(TrueVal, FalseVal)) {
8858     // Sign-extending LHS does not change its sign, so TrueVal/FalseVal can
8859     // match against either LHS or sext(LHS).
8860     auto MaybeSExtCmpLHS =
8861         m_CombineOr(m_Specific(CmpLHS), m_SExt(m_Specific(CmpLHS)));
8862     auto ZeroOrAllOnes = m_CombineOr(m_ZeroInt(), m_AllOnes());
8863     auto ZeroOrOne = m_CombineOr(m_ZeroInt(), m_One());
8864     if (match(TrueVal, MaybeSExtCmpLHS)) {
8865       // Set the return values. If the compare uses the negated value (-X >s 0),
8866       // swap the return values because the negated value is always 'RHS'.
8867       LHS = TrueVal;
8868       RHS = FalseVal;
8869       if (match(CmpLHS, m_Neg(m_Specific(FalseVal))))
8870         std::swap(LHS, RHS);
8871 
8872       // (X >s 0) ? X : -X or (X >s -1) ? X : -X --> ABS(X)
8873       // (-X >s 0) ? -X : X or (-X >s -1) ? -X : X --> ABS(X)
8874       if (Pred == ICmpInst::ICMP_SGT && match(CmpRHS, ZeroOrAllOnes))
8875         return {SPF_ABS, SPNB_NA, false};
8876 
8877       // (X >=s 0) ? X : -X or (X >=s 1) ? X : -X --> ABS(X)
8878       if (Pred == ICmpInst::ICMP_SGE && match(CmpRHS, ZeroOrOne))
8879         return {SPF_ABS, SPNB_NA, false};
8880 
8881       // (X <s 0) ? X : -X or (X <s 1) ? X : -X --> NABS(X)
8882       // (-X <s 0) ? -X : X or (-X <s 1) ? -X : X --> NABS(X)
8883       if (Pred == ICmpInst::ICMP_SLT && match(CmpRHS, ZeroOrOne))
8884         return {SPF_NABS, SPNB_NA, false};
8885     }
8886     else if (match(FalseVal, MaybeSExtCmpLHS)) {
8887       // Set the return values. If the compare uses the negated value (-X >s 0),
8888       // swap the return values because the negated value is always 'RHS'.
8889       LHS = FalseVal;
8890       RHS = TrueVal;
8891       if (match(CmpLHS, m_Neg(m_Specific(TrueVal))))
8892         std::swap(LHS, RHS);
8893 
8894       // (X >s 0) ? -X : X or (X >s -1) ? -X : X --> NABS(X)
8895       // (-X >s 0) ? X : -X or (-X >s -1) ? X : -X --> NABS(X)
8896       if (Pred == ICmpInst::ICMP_SGT && match(CmpRHS, ZeroOrAllOnes))
8897         return {SPF_NABS, SPNB_NA, false};
8898 
8899       // (X <s 0) ? -X : X or (X <s 1) ? -X : X --> ABS(X)
8900       // (-X <s 0) ? X : -X or (-X <s 1) ? X : -X --> ABS(X)
8901       if (Pred == ICmpInst::ICMP_SLT && match(CmpRHS, ZeroOrOne))
8902         return {SPF_ABS, SPNB_NA, false};
8903     }
8904   }
8905 
8906   if (CmpInst::isIntPredicate(Pred))
8907     return matchMinMax(Pred, CmpLHS, CmpRHS, TrueVal, FalseVal, LHS, RHS, Depth);
8908 
8909   // According to (IEEE 754-2008 5.3.1), minNum(0.0, -0.0) and similar
8910   // may return either -0.0 or 0.0, so an fcmp/select pair has stricter
8911   // semantics than minNum. Be conservative in such cases.
8912   if (NaNBehavior != SPNB_RETURNS_ANY ||
8913       (!FMF.noSignedZeros() && !isKnownNonZero(CmpLHS) &&
8914        !isKnownNonZero(CmpRHS)))
8915     return {SPF_UNKNOWN, SPNB_NA, false};
8916 
8917   return matchFastFloatClamp(Pred, CmpLHS, CmpRHS, TrueVal, FalseVal, LHS, RHS);
8918 }
8919 
8920 static Value *lookThroughCastConst(CmpInst *CmpI, Type *SrcTy, Constant *C,
8921                                    Instruction::CastOps *CastOp) {
8922   const DataLayout &DL = CmpI->getDataLayout();
8923 
8924   Constant *CastedTo = nullptr;
8925   switch (*CastOp) {
8926   case Instruction::ZExt:
8927     if (CmpI->isUnsigned())
8928       CastedTo = ConstantExpr::getTrunc(C, SrcTy);
8929     break;
8930   case Instruction::SExt:
8931     if (CmpI->isSigned())
8932       CastedTo = ConstantExpr::getTrunc(C, SrcTy, true);
8933     break;
8934   case Instruction::Trunc:
8935     Constant *CmpConst;
8936     if (match(CmpI->getOperand(1), m_Constant(CmpConst)) &&
8937         CmpConst->getType() == SrcTy) {
8938       // Here we have the following case:
8939       //
8940       //   %cond = cmp iN %x, CmpConst
8941       //   %tr = trunc iN %x to iK
8942       //   %narrowsel = select i1 %cond, iK %t, iK C
8943       //
8944       // We can always move trunc after select operation:
8945       //
8946       //   %cond = cmp iN %x, CmpConst
8947       //   %widesel = select i1 %cond, iN %x, iN CmpConst
8948       //   %tr = trunc iN %widesel to iK
8949       //
8950       // Note that C could be extended in any way because we don't care about
8951       // the upper bits after truncation. It can't be an abs pattern, because
8952       // that would look like:
8953       //
8954       //   select i1 %cond, x, -x.
8955       //
8956       // So only a min/max pattern can be matched. Such a match requires the
8957       // widened C == CmpConst. That is why we set the widened C = CmpConst; the
8958       // condition trunc(CmpConst) == C is checked below.
8959       CastedTo = CmpConst;
8960     } else {
8961       unsigned ExtOp = CmpI->isSigned() ? Instruction::SExt : Instruction::ZExt;
8962       CastedTo = ConstantFoldCastOperand(ExtOp, C, SrcTy, DL);
8963     }
8964     break;
8965   case Instruction::FPTrunc:
8966     CastedTo = ConstantFoldCastOperand(Instruction::FPExt, C, SrcTy, DL);
8967     break;
8968   case Instruction::FPExt:
8969     CastedTo = ConstantFoldCastOperand(Instruction::FPTrunc, C, SrcTy, DL);
8970     break;
8971   case Instruction::FPToUI:
8972     CastedTo = ConstantFoldCastOperand(Instruction::UIToFP, C, SrcTy, DL);
8973     break;
8974   case Instruction::FPToSI:
8975     CastedTo = ConstantFoldCastOperand(Instruction::SIToFP, C, SrcTy, DL);
8976     break;
8977   case Instruction::UIToFP:
8978     CastedTo = ConstantFoldCastOperand(Instruction::FPToUI, C, SrcTy, DL);
8979     break;
8980   case Instruction::SIToFP:
8981     CastedTo = ConstantFoldCastOperand(Instruction::FPToSI, C, SrcTy, DL);
8982     break;
8983   default:
8984     break;
8985   }
8986 
8987   if (!CastedTo)
8988     return nullptr;
8989 
8990   // Make sure the cast doesn't lose any information.
8991   Constant *CastedBack =
8992       ConstantFoldCastOperand(*CastOp, CastedTo, C->getType(), DL);
8993   if (CastedBack && CastedBack != C)
8994     return nullptr;
8995 
8996   return CastedTo;
8997 }
8998 
8999 /// Helps to match a select pattern in case of a type mismatch.
9000 ///
9001 /// The function processes the case when type of true and false values of a
9002 /// select instruction differs from type of the cmp instruction operands because
9003 /// of a cast instruction. The function checks if it is legal to move the cast
9004 /// operation after "select". If yes, it returns the new second value of
9005 /// "select" (with the assumption that cast is moved):
9006 /// 1. As operand of cast instruction when both values of "select" are same cast
9007 /// instructions.
9008 /// 2. As restored constant (by applying reverse cast operation) when the first
9009 /// value of the "select" is a cast operation and the second value is a
9010 /// constant. It is implemented in lookThroughCastConst().
9011 /// 3. When one operand is a cast instruction and the other is not; the operands
9012 /// of sel(cmp) are then integers of different types.
9013 /// NOTE: We return only the new second value because the first value could be
9014 /// accessed as operand of cast instruction.
9015 static Value *lookThroughCast(CmpInst *CmpI, Value *V1, Value *V2,
9016                               Instruction::CastOps *CastOp) {
9017   auto *Cast1 = dyn_cast<CastInst>(V1);
9018   if (!Cast1)
9019     return nullptr;
9020 
9021   *CastOp = Cast1->getOpcode();
9022   Type *SrcTy = Cast1->getSrcTy();
9023   if (auto *Cast2 = dyn_cast<CastInst>(V2)) {
9024     // If V1 and V2 are both the same cast from the same type, look through V1.
9025     if (*CastOp == Cast2->getOpcode() && SrcTy == Cast2->getSrcTy())
9026       return Cast2->getOperand(0);
9027     return nullptr;
9028   }
9029 
9030   auto *C = dyn_cast<Constant>(V2);
9031   if (C)
9032     return lookThroughCastConst(CmpI, SrcTy, C, CastOp);
9033 
9034   Value *CastedTo = nullptr;
9035   if (*CastOp == Instruction::Trunc) {
9036     if (match(CmpI->getOperand(1), m_ZExtOrSExt(m_Specific(V2)))) {
9037       // Here we have the following case:
9038       //   %y_ext = sext iK %y to iN
9039       //   %cond = cmp iN %x, %y_ext
9040       //   %tr = trunc iN %x to iK
9041       //   %narrowsel = select i1 %cond, iK %tr, iK %y
9042       //
9043       // We can always move trunc after select operation:
9044       //   %y_ext = sext iK %y to iN
9045       //   %cond = cmp iN %x, %y_ext
9046       //   %widesel = select i1 %cond, iN %x, iN %y_ext
9047       //   %tr = trunc iN %widesel to iK
9048       assert(V2->getType() == Cast1->getType() &&
9049              "V2 and Cast1 should be the same type.");
9050       CastedTo = CmpI->getOperand(1);
9051     }
9052   }
9053 
9054   return CastedTo;
9055 }
9056 SelectPatternResult llvm::matchSelectPattern(Value *V, Value *&LHS, Value *&RHS,
9057                                              Instruction::CastOps *CastOp,
9058                                              unsigned Depth) {
9059   if (Depth >= MaxAnalysisRecursionDepth)
9060     return {SPF_UNKNOWN, SPNB_NA, false};
9061 
9062   SelectInst *SI = dyn_cast<SelectInst>(V);
9063   if (!SI) return {SPF_UNKNOWN, SPNB_NA, false};
9064 
9065   CmpInst *CmpI = dyn_cast<CmpInst>(SI->getCondition());
9066   if (!CmpI) return {SPF_UNKNOWN, SPNB_NA, false};
9067 
9068   Value *TrueVal = SI->getTrueValue();
9069   Value *FalseVal = SI->getFalseValue();
9070 
9071   return llvm::matchDecomposedSelectPattern(CmpI, TrueVal, FalseVal, LHS, RHS,
9072                                             CastOp, Depth);
9073 }
9074 
9075 SelectPatternResult llvm::matchDecomposedSelectPattern(
9076     CmpInst *CmpI, Value *TrueVal, Value *FalseVal, Value *&LHS, Value *&RHS,
9077     Instruction::CastOps *CastOp, unsigned Depth) {
9078   CmpInst::Predicate Pred = CmpI->getPredicate();
9079   Value *CmpLHS = CmpI->getOperand(0);
9080   Value *CmpRHS = CmpI->getOperand(1);
9081   FastMathFlags FMF;
9082   if (isa<FPMathOperator>(CmpI))
9083     FMF = CmpI->getFastMathFlags();
9084 
9085   // Bail out early.
9086   if (CmpI->isEquality())
9087     return {SPF_UNKNOWN, SPNB_NA, false};
9088 
9089   // Deal with type mismatches.
9090   if (CastOp && CmpLHS->getType() != TrueVal->getType()) {
9091     if (Value *C = lookThroughCast(CmpI, TrueVal, FalseVal, CastOp)) {
9092       // If this is a potential fmin/fmax with a cast to integer, then ignore
9093       // -0.0 because there is no corresponding integer value.
9094       if (*CastOp == Instruction::FPToSI || *CastOp == Instruction::FPToUI)
9095         FMF.setNoSignedZeros();
9096       return ::matchSelectPattern(Pred, FMF, CmpLHS, CmpRHS,
9097                                   cast<CastInst>(TrueVal)->getOperand(0), C,
9098                                   LHS, RHS, Depth);
9099     }
9100     if (Value *C = lookThroughCast(CmpI, FalseVal, TrueVal, CastOp)) {
9101       // If this is a potential fmin/fmax with a cast to integer, then ignore
9102       // -0.0 because there is no corresponding integer value.
9103       if (*CastOp == Instruction::FPToSI || *CastOp == Instruction::FPToUI)
9104         FMF.setNoSignedZeros();
9105       return ::matchSelectPattern(Pred, FMF, CmpLHS, CmpRHS,
9106                                   C, cast<CastInst>(FalseVal)->getOperand(0),
9107                                   LHS, RHS, Depth);
9108     }
9109   }
9110   return ::matchSelectPattern(Pred, FMF, CmpLHS, CmpRHS, TrueVal, FalseVal,
9111                               LHS, RHS, Depth);
9112 }
9113 
9114 CmpInst::Predicate llvm::getMinMaxPred(SelectPatternFlavor SPF, bool Ordered) {
9115   if (SPF == SPF_SMIN) return ICmpInst::ICMP_SLT;
9116   if (SPF == SPF_UMIN) return ICmpInst::ICMP_ULT;
9117   if (SPF == SPF_SMAX) return ICmpInst::ICMP_SGT;
9118   if (SPF == SPF_UMAX) return ICmpInst::ICMP_UGT;
9119   if (SPF == SPF_FMINNUM)
9120     return Ordered ? FCmpInst::FCMP_OLT : FCmpInst::FCMP_ULT;
9121   if (SPF == SPF_FMAXNUM)
9122     return Ordered ? FCmpInst::FCMP_OGT : FCmpInst::FCMP_UGT;
9123   llvm_unreachable("unhandled!");
9124 }
9125 
9126 Intrinsic::ID llvm::getMinMaxIntrinsic(SelectPatternFlavor SPF) {
9127   switch (SPF) {
9128   case SelectPatternFlavor::SPF_UMIN:
9129     return Intrinsic::umin;
9130   case SelectPatternFlavor::SPF_UMAX:
9131     return Intrinsic::umax;
9132   case SelectPatternFlavor::SPF_SMIN:
9133     return Intrinsic::smin;
9134   case SelectPatternFlavor::SPF_SMAX:
9135     return Intrinsic::smax;
9136   default:
9137     llvm_unreachable("Unexpected SPF");
9138   }
9139 }
9140 
9141 SelectPatternFlavor llvm::getInverseMinMaxFlavor(SelectPatternFlavor SPF) {
9142   if (SPF == SPF_SMIN) return SPF_SMAX;
9143   if (SPF == SPF_UMIN) return SPF_UMAX;
9144   if (SPF == SPF_SMAX) return SPF_SMIN;
9145   if (SPF == SPF_UMAX) return SPF_UMIN;
9146   llvm_unreachable("unhandled!");
9147 }
9148 
9149 Intrinsic::ID llvm::getInverseMinMaxIntrinsic(Intrinsic::ID MinMaxID) {
9150   switch (MinMaxID) {
9151   case Intrinsic::smax: return Intrinsic::smin;
9152   case Intrinsic::smin: return Intrinsic::smax;
9153   case Intrinsic::umax: return Intrinsic::umin;
9154   case Intrinsic::umin: return Intrinsic::umax;
9155   // Note that the next four intrinsics may produce the same result for the
9156   // original and the inverted case even if X != Y, because NaN is handled specially.
9157   case Intrinsic::maximum: return Intrinsic::minimum;
9158   case Intrinsic::minimum: return Intrinsic::maximum;
9159   case Intrinsic::maxnum: return Intrinsic::minnum;
9160   case Intrinsic::minnum: return Intrinsic::maxnum;
9161   default: llvm_unreachable("Unexpected intrinsic");
9162   }
9163 }
9164 
9165 APInt llvm::getMinMaxLimit(SelectPatternFlavor SPF, unsigned BitWidth) {
9166   switch (SPF) {
9167   case SPF_SMAX: return APInt::getSignedMaxValue(BitWidth);
9168   case SPF_SMIN: return APInt::getSignedMinValue(BitWidth);
9169   case SPF_UMAX: return APInt::getMaxValue(BitWidth);
9170   case SPF_UMIN: return APInt::getMinValue(BitWidth);
9171   default: llvm_unreachable("Unexpected flavor");
9172   }
9173 }
9174 
9175 std::pair<Intrinsic::ID, bool>
9176 llvm::canConvertToMinOrMaxIntrinsic(ArrayRef<Value *> VL) {
9177   // Check if VL contains select instructions that can be folded into a min/max
9178   // vector intrinsic and return the intrinsic if it is possible.
9179   // TODO: Support floating point min/max.
9180   bool AllCmpSingleUse = true;
9181   SelectPatternResult SelectPattern;
9182   SelectPattern.Flavor = SPF_UNKNOWN;
9183   if (all_of(VL, [&SelectPattern, &AllCmpSingleUse](Value *I) {
9184         Value *LHS, *RHS;
9185         auto CurrentPattern = matchSelectPattern(I, LHS, RHS);
9186         if (!SelectPatternResult::isMinOrMax(CurrentPattern.Flavor))
9187           return false;
9188         if (SelectPattern.Flavor != SPF_UNKNOWN &&
9189             SelectPattern.Flavor != CurrentPattern.Flavor)
9190           return false;
9191         SelectPattern = CurrentPattern;
9192         AllCmpSingleUse &=
9193             match(I, m_Select(m_OneUse(m_Value()), m_Value(), m_Value()));
9194         return true;
9195       })) {
9196     switch (SelectPattern.Flavor) {
9197     case SPF_SMIN:
9198       return {Intrinsic::smin, AllCmpSingleUse};
9199     case SPF_UMIN:
9200       return {Intrinsic::umin, AllCmpSingleUse};
9201     case SPF_SMAX:
9202       return {Intrinsic::smax, AllCmpSingleUse};
9203     case SPF_UMAX:
9204       return {Intrinsic::umax, AllCmpSingleUse};
9205     case SPF_FMAXNUM:
9206       return {Intrinsic::maxnum, AllCmpSingleUse};
9207     case SPF_FMINNUM:
9208       return {Intrinsic::minnum, AllCmpSingleUse};
9209     default:
9210       llvm_unreachable("unexpected select pattern flavor");
9211     }
9212   }
9213   return {Intrinsic::not_intrinsic, false};
9214 }
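
// Usage sketch (hypothetical IR): if every value in VL is a select such as
//   %c = icmp slt i32 %a, %b
//   %s = select i1 %c, i32 %a, i32 %b
// i.e. all of them match the same SPF_SMIN flavor, the function returns
// {Intrinsic::smin, AllCmpSingleUse}, where the bool records whether every
// compare feeding the selects had a single use.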
9215 
9216 bool llvm::matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO,
9217                                  Value *&Start, Value *&Step) {
9218   // Handle the case of a simple two-predecessor recurrence PHI.
9219   // There's a lot more that could theoretically be done here, but
9220   // this is sufficient to catch some interesting cases.
9221   if (P->getNumIncomingValues() != 2)
9222     return false;
9223 
9224   for (unsigned i = 0; i != 2; ++i) {
9225     Value *L = P->getIncomingValue(i);
9226     Value *R = P->getIncomingValue(!i);
9227     auto *LU = dyn_cast<BinaryOperator>(L);
9228     if (!LU)
9229       continue;
9230     unsigned Opcode = LU->getOpcode();
9231 
9232     switch (Opcode) {
9233     default:
9234       continue;
9235     // TODO: Expand list -- xor, gep, uadd.sat etc.
9236     case Instruction::LShr:
9237     case Instruction::AShr:
9238     case Instruction::Shl:
9239     case Instruction::Add:
9240     case Instruction::Sub:
9241     case Instruction::UDiv:
9242     case Instruction::URem:
9243     case Instruction::And:
9244     case Instruction::Or:
9245     case Instruction::Mul:
9246     case Instruction::FMul: {
9247       Value *LL = LU->getOperand(0);
9248       Value *LR = LU->getOperand(1);
9249       // Find a recurrence.
9250       if (LL == P)
9251         L = LR;
9252       else if (LR == P)
9253         L = LL;
9254       else
9255         continue; // Check for recurrence with L and R flipped.
9256 
9257       break; // Match!
9258     }
9259     };
9260 
9261     // We have matched a recurrence of the form:
9262     //   %iv = [R, %entry], [%iv.next, %backedge]
9263     //   %iv.next = binop %iv, L
9264     // OR
9265     //   %iv = [R, %entry], [%iv.next, %backedge]
9266     //   %iv.next = binop L, %iv
9267     BO = LU;
9268     Start = R;
9269     Step = L;
9270     return true;
9271   }
9272   return false;
9273 }
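
// For illustration (hypothetical IR), the classic induction variable
//   loop:
//     %iv = phi i32 [ 0, %entry ], [ %iv.next, %loop ]
//     %iv.next = add i32 %iv, %step
// is matched with BO = %iv.next, Start = 0 and Step = %step.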
9274 
9275 bool llvm::matchSimpleRecurrence(const BinaryOperator *I, PHINode *&P,
9276                                  Value *&Start, Value *&Step) {
9277   BinaryOperator *BO = nullptr;
9278   P = dyn_cast<PHINode>(I->getOperand(0));
9279   if (!P)
9280     P = dyn_cast<PHINode>(I->getOperand(1));
9281   return P && matchSimpleRecurrence(P, BO, Start, Step) && BO == I;
9282 }
9283 
9284 /// Return true if "icmp Pred LHS RHS" is always true.
9285 static bool isTruePredicate(CmpInst::Predicate Pred, const Value *LHS,
9286                             const Value *RHS) {
9287   if (ICmpInst::isTrueWhenEqual(Pred) && LHS == RHS)
9288     return true;
9289 
9290   switch (Pred) {
9291   default:
9292     return false;
9293 
9294   case CmpInst::ICMP_SLE: {
9295     const APInt *C;
9296 
9297     // LHS s<= LHS +_{nsw} C   if C >= 0
9298     // LHS s<= LHS | C         if C >= 0
9299     if (match(RHS, m_NSWAdd(m_Specific(LHS), m_APInt(C))) ||
9300         match(RHS, m_Or(m_Specific(LHS), m_APInt(C))))
9301       return !C->isNegative();
9302 
9303     // LHS s<= smax(LHS, V) for any V
9304     if (match(RHS, m_c_SMax(m_Specific(LHS), m_Value())))
9305       return true;
9306 
9307     // smin(RHS, V) s<= RHS for any V
9308     if (match(LHS, m_c_SMin(m_Specific(RHS), m_Value())))
9309       return true;
9310 
9311     // Match A to (X +_{nsw} CA) and B to (X +_{nsw} CB)
9312     const Value *X;
9313     const APInt *CLHS, *CRHS;
9314     if (match(LHS, m_NSWAddLike(m_Value(X), m_APInt(CLHS))) &&
9315         match(RHS, m_NSWAddLike(m_Specific(X), m_APInt(CRHS))))
9316       return CLHS->sle(*CRHS);
9317 
9318     return false;
9319   }
9320 
9321   case CmpInst::ICMP_ULE: {
9322     // LHS u<= LHS +_{nuw} V for any V
9323     if (match(RHS, m_c_Add(m_Specific(LHS), m_Value())) &&
9324         cast<OverflowingBinaryOperator>(RHS)->hasNoUnsignedWrap())
9325       return true;
9326 
9327     // LHS u<= LHS | V for any V
9328     if (match(RHS, m_c_Or(m_Specific(LHS), m_Value())))
9329       return true;
9330 
9331     // LHS u<= umax(LHS, V) for any V
9332     if (match(RHS, m_c_UMax(m_Specific(LHS), m_Value())))
9333       return true;
9334 
9335     // RHS >> V u<= RHS for any V
9336     if (match(LHS, m_LShr(m_Specific(RHS), m_Value())))
9337       return true;
9338 
9339     // RHS u/ C_ugt_1 u<= RHS
9340     const APInt *C;
9341     if (match(LHS, m_UDiv(m_Specific(RHS), m_APInt(C))) && C->ugt(1))
9342       return true;
9343 
9344     // RHS & V u<= RHS for any V
9345     if (match(LHS, m_c_And(m_Specific(RHS), m_Value())))
9346       return true;
9347 
9348     // umin(RHS, V) u<= RHS for any V
9349     if (match(LHS, m_c_UMin(m_Specific(RHS), m_Value())))
9350       return true;
9351 
9352     // Match A to (X +_{nuw} CA) and B to (X +_{nuw} CB)
9353     const Value *X;
9354     const APInt *CLHS, *CRHS;
9355     if (match(LHS, m_NUWAddLike(m_Value(X), m_APInt(CLHS))) &&
9356         match(RHS, m_NUWAddLike(m_Specific(X), m_APInt(CRHS))))
9357       return CLHS->ule(*CRHS);
9358 
9359     return false;
9360   }
9361   }
9362 }
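
// Two quick instances of the rules above (hypothetical values):
//   (x & 7) u<= x      always holds (the "RHS & V u<= RHS" case), and
//   x s<= (x | 3)      always holds because the or'd constant is non-negative.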
9363 
9364 /// Return true if "icmp Pred BLHS BRHS" is true whenever "icmp Pred
9365 /// ALHS ARHS" is true.  Otherwise, return std::nullopt.
9366 static std::optional<bool>
9367 isImpliedCondOperands(CmpInst::Predicate Pred, const Value *ALHS,
9368                       const Value *ARHS, const Value *BLHS, const Value *BRHS) {
9369   switch (Pred) {
9370   default:
9371     return std::nullopt;
9372 
9373   case CmpInst::ICMP_SLT:
9374   case CmpInst::ICMP_SLE:
9375     if (isTruePredicate(CmpInst::ICMP_SLE, BLHS, ALHS) &&
9376         isTruePredicate(CmpInst::ICMP_SLE, ARHS, BRHS))
9377       return true;
9378     return std::nullopt;
9379 
9380   case CmpInst::ICMP_SGT:
9381   case CmpInst::ICMP_SGE:
9382     if (isTruePredicate(CmpInst::ICMP_SLE, ALHS, BLHS) &&
9383         isTruePredicate(CmpInst::ICMP_SLE, BRHS, ARHS))
9384       return true;
9385     return std::nullopt;
9386 
9387   case CmpInst::ICMP_ULT:
9388   case CmpInst::ICMP_ULE:
9389     if (isTruePredicate(CmpInst::ICMP_ULE, BLHS, ALHS) &&
9390         isTruePredicate(CmpInst::ICMP_ULE, ARHS, BRHS))
9391       return true;
9392     return std::nullopt;
9393 
9394   case CmpInst::ICMP_UGT:
9395   case CmpInst::ICMP_UGE:
9396     if (isTruePredicate(CmpInst::ICMP_ULE, ALHS, BLHS) &&
9397         isTruePredicate(CmpInst::ICMP_ULE, BRHS, ARHS))
9398       return true;
9399     return std::nullopt;
9400   }
9401 }
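
// Worked example for the unsigned case (hypothetical values): "x u<= y"
// implies "(x & m) u<= (y | n)", since isTruePredicate proves both
// (x & m) u<= x and y u<= (y | n), which is exactly the ICMP_ULT/ICMP_ULE
// pattern handled above.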
9402 
9403 /// Return true if "icmp LPred X, LCR" implies "icmp RPred X, RCR" is true.
9404 /// Return false if "icmp LPred X, LCR" implies "icmp RPred X, RCR" is false.
9405 /// Otherwise, return std::nullopt if we can't infer anything.
9406 static std::optional<bool>
9407 isImpliedCondCommonOperandWithCR(CmpPredicate LPred, const ConstantRange &LCR,
9408                                  CmpPredicate RPred, const ConstantRange &RCR) {
9409   auto CRImpliesPred = [&](ConstantRange CR,
9410                            CmpInst::Predicate Pred) -> std::optional<bool> {
9411     // If every value that makes LHS true also makes RHS true, LHS implies RHS.
9412     if (CR.icmp(Pred, RCR))
9413       return true;
9414 
9415     // If there is no overlap, lhs implies not rhs
9416     if (CR.icmp(CmpInst::getInversePredicate(Pred), RCR))
9417       return false;
9418 
9419     return std::nullopt;
9420   };
9421   if (auto Res = CRImpliesPred(ConstantRange::makeAllowedICmpRegion(LPred, LCR),
9422                                RPred))
9423     return Res;
9424   if (LPred.hasSameSign() ^ RPred.hasSameSign()) {
9425     LPred = LPred.hasSameSign() ? ICmpInst::getFlippedSignednessPredicate(LPred)
9426                                 : static_cast<CmpInst::Predicate>(LPred);
9427     RPred = RPred.hasSameSign() ? ICmpInst::getFlippedSignednessPredicate(RPred)
9428                                 : static_cast<CmpInst::Predicate>(RPred);
9429     return CRImpliesPred(ConstantRange::makeAllowedICmpRegion(LPred, LCR),
9430                          RPred);
9431   }
9432   return std::nullopt;
9433 }
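// For example (sample constants): "X u> 5" yields the allowed region
// [6, UINT_MAX]; every value in it satisfies "X != 0", so "X u> 5" implies
// "X != 0" and the first query above returns true.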
9434 
9435 /// Return true if LHS implies RHS (expanded to its components as "R0 RPred R1")
9436 /// is true.  Return false if LHS implies RHS is false. Otherwise, return
9437 /// std::nullopt if we can't infer anything.
9438 static std::optional<bool>
9439 isImpliedCondICmps(const ICmpInst *LHS, CmpPredicate RPred, const Value *R0,
9440                    const Value *R1, const DataLayout &DL, bool LHSIsTrue) {
9441   Value *L0 = LHS->getOperand(0);
9442   Value *L1 = LHS->getOperand(1);
9443 
9444   // The rest of the logic assumes the LHS condition is true.  If that's not the
9445   // case, invert the predicate to make it so.
9446   CmpPredicate LPred =
9447       LHSIsTrue ? LHS->getCmpPredicate() : LHS->getInverseCmpPredicate();
9448 
9449   // We can have non-canonical operands, so try to normalize any common operand
9450   // to L0/R0.
9451   if (L0 == R1) {
9452     std::swap(R0, R1);
9453     RPred = ICmpInst::getSwappedCmpPredicate(RPred);
9454   }
9455   if (R0 == L1) {
9456     std::swap(L0, L1);
9457     LPred = ICmpInst::getSwappedCmpPredicate(LPred);
9458   }
9459   if (L1 == R1) {
9460     // If we have L0 == R0 and L1 == R1, then make L1/R1 the constants.
9461     if (L0 != R0 || match(L0, m_ImmConstant())) {
9462       std::swap(L0, L1);
9463       LPred = ICmpInst::getSwappedCmpPredicate(LPred);
9464       std::swap(R0, R1);
9465       RPred = ICmpInst::getSwappedCmpPredicate(RPred);
9466     }
9467   }
9468 
9469   // See if we can infer anything if operand-0 matches and we have at least one
9470   // constant.
9471   const APInt *Unused;
9472   if (L0 == R0 && (match(L1, m_APInt(Unused)) || match(R1, m_APInt(Unused)))) {
9473     // Potential TODO: We could also use the constant range of L0/R0 to further
9474     // constrain the constant ranges. At the moment this leads to several
9475     // regressions related to not transforming `multi_use(A + C0) eq/ne C1`
9476     // (see discussion: D58633).
9477     ConstantRange LCR = computeConstantRange(
9478         L1, ICmpInst::isSigned(LPred), /* UseInstrInfo=*/true, /*AC=*/nullptr,
9479         /*CxtI=*/nullptr, /*DT=*/nullptr, MaxAnalysisRecursionDepth - 1);
9480     ConstantRange RCR = computeConstantRange(
9481         R1, ICmpInst::isSigned(RPred), /* UseInstrInfo=*/true, /*AC=*/nullptr,
9482         /*CxtI=*/nullptr, /*DT=*/nullptr, MaxAnalysisRecursionDepth - 1);
9483     // Even if L1/R1 are not both constant, we can still sometimes deduce a
9484     // relationship from a single constant. For example, X u> Y implies X != 0.
9485     if (auto R = isImpliedCondCommonOperandWithCR(LPred, LCR, RPred, RCR))
9486       return R;
9487     // If both L1/R1 were exact constant ranges and we didn't get anything
9488     // here, we won't be able to deduce this.
9489     if (match(L1, m_APInt(Unused)) && match(R1, m_APInt(Unused)))
9490       return std::nullopt;
9491   }
9492 
9493   // Can we infer anything when the two compares have matching operands?
9494   if (L0 == R0 && L1 == R1)
9495     return ICmpInst::isImpliedByMatchingCmp(LPred, RPred);
9496 
9497   // This only really makes sense for signed comparisons: "X - Y must be
9498   // non-negative if X >= Y and there is no signed overflow".
9499   // Take SGT as an example:  L0:x > L1:y and C >= 0
9500   //                      ==> R0:(x -nsw y) < R1:(-C) is false
9501   CmpInst::Predicate SignedLPred = LPred.getPreferredSignedPredicate();
9502   if ((SignedLPred == ICmpInst::ICMP_SGT ||
9503        SignedLPred == ICmpInst::ICMP_SGE) &&
9504       match(R0, m_NSWSub(m_Specific(L0), m_Specific(L1)))) {
9505     if (match(R1, m_NonPositive()) &&
9506         ICmpInst::isImpliedByMatchingCmp(SignedLPred, RPred) == false)
9507       return false;
9508   }
9509 
9510   // Take SLT as an example:  L0:x < L1:y and C <= 0
9511   //                      ==> R0:(x -nsw y) < R1:(-C) is true
9512   if ((SignedLPred == ICmpInst::ICMP_SLT ||
9513        SignedLPred == ICmpInst::ICMP_SLE) &&
9514       match(R0, m_NSWSub(m_Specific(L0), m_Specific(L1)))) {
9515     if (match(R1, m_NonNegative()) &&
9516         ICmpInst::isImpliedByMatchingCmp(SignedLPred, RPred) == true)
9517       return true;
9518   }
9519 
9520   // L0 = R0 = L1 + R1, L0 >=u L1 implies R0 >=u R1, L0 <u L1 implies R0 <u R1
9521   if (L0 == R0 &&
9522       (LPred == ICmpInst::ICMP_ULT || LPred == ICmpInst::ICMP_UGE) &&
9523       (RPred == ICmpInst::ICMP_ULT || RPred == ICmpInst::ICMP_UGE) &&
9524       match(L0, m_c_Add(m_Specific(L1), m_Specific(R1))))
9525     return CmpPredicate::getMatching(LPred, RPred).has_value();
9526 
9527   if (auto P = CmpPredicate::getMatching(LPred, RPred))
9528     return isImpliedCondOperands(*P, L0, L1, R0, R1);
9529 
9530   return std::nullopt;
9531 }
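// For example, when both compares share operands, "x s< y" being true answers
// "x s<= y" as true via isImpliedByMatchingCmp; and with Z = X + Y, "Z u>= X"
// implies "Z u>= Y" by the add-based rule above (the addition either wrapped
// for both comparisons or for neither).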
9532 
9533 /// Return true if LHS implies RHS is true.  Return false if LHS implies RHS is
9534 /// false.  Otherwise, return std::nullopt if we can't infer anything.  We
9535 /// expect the RHS to be an icmp and the LHS to be an 'and', 'or', or a 'select'
9536 /// instruction.
9537 static std::optional<bool>
9538 isImpliedCondAndOr(const Instruction *LHS, CmpPredicate RHSPred,
9539                    const Value *RHSOp0, const Value *RHSOp1,
9540                    const DataLayout &DL, bool LHSIsTrue, unsigned Depth) {
9541   // The LHS must be an 'or', 'and', or a 'select' instruction.
9542   assert((LHS->getOpcode() == Instruction::And ||
9543           LHS->getOpcode() == Instruction::Or ||
9544           LHS->getOpcode() == Instruction::Select) &&
9545          "Expected LHS to be 'and', 'or', or 'select'.");
9546 
9547   assert(Depth <= MaxAnalysisRecursionDepth && "Hit recursion limit");
9548 
9549   // If the result of an 'or' is false, then we know both legs of the 'or' are
9550   // false.  Similarly, if the result of an 'and' is true, then we know both
9551   // legs of the 'and' are true.
9552   const Value *ALHS, *ARHS;
9553   if ((!LHSIsTrue && match(LHS, m_LogicalOr(m_Value(ALHS), m_Value(ARHS)))) ||
9554       (LHSIsTrue && match(LHS, m_LogicalAnd(m_Value(ALHS), m_Value(ARHS))))) {
9555     // FIXME: Make this non-recursive.
9556     if (std::optional<bool> Implication = isImpliedCondition(
9557             ALHS, RHSPred, RHSOp0, RHSOp1, DL, LHSIsTrue, Depth + 1))
9558       return Implication;
9559     if (std::optional<bool> Implication = isImpliedCondition(
9560             ARHS, RHSPred, RHSOp0, RHSOp1, DL, LHSIsTrue, Depth + 1))
9561       return Implication;
9562     return std::nullopt;
9563   }
9564   return std::nullopt;
9565 }
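// For example, if the LHS "(x u< 8) & b" is known true, both legs are true,
// so a query such as "x u< 16" can be answered true through the first leg.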
9566 
9567 std::optional<bool>
9568 llvm::isImpliedCondition(const Value *LHS, CmpPredicate RHSPred,
9569                          const Value *RHSOp0, const Value *RHSOp1,
9570                          const DataLayout &DL, bool LHSIsTrue, unsigned Depth) {
9571   // Bail out when we hit the limit.
9572   if (Depth == MaxAnalysisRecursionDepth)
9573     return std::nullopt;
9574 
9575   // A mismatch occurs when we compare a scalar cmp to a vector cmp, for
9576   // example.
9577   if (RHSOp0->getType()->isVectorTy() != LHS->getType()->isVectorTy())
9578     return std::nullopt;
9579 
9580   assert(LHS->getType()->isIntOrIntVectorTy(1) &&
9581          "Expected integer type only!");
9582 
9583   // Match not
9584   if (match(LHS, m_Not(m_Value(LHS))))
9585     LHSIsTrue = !LHSIsTrue;
9586 
9587   // Both LHS and RHS are icmps.
9588   const ICmpInst *LHSCmp = dyn_cast<ICmpInst>(LHS);
9589   if (LHSCmp)
9590     return isImpliedCondICmps(LHSCmp, RHSPred, RHSOp0, RHSOp1, DL, LHSIsTrue);
9591 
9592   /// The LHS should be an 'or', 'and', or a 'select' instruction.  We expect
9593   /// the RHS to be an icmp.
9594   /// FIXME: Add support for and/or/select on the RHS.
9595   if (const Instruction *LHSI = dyn_cast<Instruction>(LHS)) {
9596     if ((LHSI->getOpcode() == Instruction::And ||
9597          LHSI->getOpcode() == Instruction::Or ||
9598          LHSI->getOpcode() == Instruction::Select))
9599       return isImpliedCondAndOr(LHSI, RHSPred, RHSOp0, RHSOp1, DL, LHSIsTrue,
9600                                 Depth);
9601   }
9602   return std::nullopt;
9603 }
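// A hypothetical caller sketch (names are illustrative): given a condition
// `Cond` known to hold, ask whether "icmp ult %x, %n" must also hold:
//   if (std::optional<bool> Imp = isImpliedCondition(
//           Cond, ICmpInst::ICMP_ULT, X, N, DL, /*LHSIsTrue=*/true))
//     ...; // *Imp is the answer; std::nullopt means nothing could be inferred.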
9604 
9605 std::optional<bool> llvm::isImpliedCondition(const Value *LHS, const Value *RHS,
9606                                              const DataLayout &DL,
9607                                              bool LHSIsTrue, unsigned Depth) {
9608   // LHS ==> RHS by definition
9609   if (LHS == RHS)
9610     return LHSIsTrue;
9611 
9612   // Match not
9613   bool InvertRHS = false;
9614   if (match(RHS, m_Not(m_Value(RHS)))) {
9615     if (LHS == RHS)
9616       return !LHSIsTrue;
9617     InvertRHS = true;
9618   }
9619 
9620   if (const ICmpInst *RHSCmp = dyn_cast<ICmpInst>(RHS)) {
9621     if (auto Implied = isImpliedCondition(
9622             LHS, RHSCmp->getCmpPredicate(), RHSCmp->getOperand(0),
9623             RHSCmp->getOperand(1), DL, LHSIsTrue, Depth))
9624       return InvertRHS ? !*Implied : *Implied;
9625     return std::nullopt;
9626   }
9627 
9628   if (Depth == MaxAnalysisRecursionDepth)
9629     return std::nullopt;
9630 
9631   // LHS ==> (RHS1 || RHS2) if LHS ==> RHS1 or LHS ==> RHS2
9632   // LHS ==> !(RHS1 && RHS2) if LHS ==> !RHS1 or LHS ==> !RHS2
9633   const Value *RHS1, *RHS2;
9634   if (match(RHS, m_LogicalOr(m_Value(RHS1), m_Value(RHS2)))) {
9635     if (std::optional<bool> Imp =
9636             isImpliedCondition(LHS, RHS1, DL, LHSIsTrue, Depth + 1))
9637       if (*Imp == true)
9638         return !InvertRHS;
9639     if (std::optional<bool> Imp =
9640             isImpliedCondition(LHS, RHS2, DL, LHSIsTrue, Depth + 1))
9641       if (*Imp == true)
9642         return !InvertRHS;
9643   }
9644   if (match(RHS, m_LogicalAnd(m_Value(RHS1), m_Value(RHS2)))) {
9645     if (std::optional<bool> Imp =
9646             isImpliedCondition(LHS, RHS1, DL, LHSIsTrue, Depth + 1))
9647       if (*Imp == false)
9648         return InvertRHS;
9649     if (std::optional<bool> Imp =
9650             isImpliedCondition(LHS, RHS2, DL, LHSIsTrue, Depth + 1))
9651       if (*Imp == false)
9652         return InvertRHS;
9653   }
9654 
9655   return std::nullopt;
9656 }
9657 
9658 // Returns a pair (Condition, ConditionIsTrue), where Condition is a branch
9659 // condition dominating ContextI or nullptr, if no condition is found.
9660 static std::pair<Value *, bool>
9661 getDomPredecessorCondition(const Instruction *ContextI) {
9662   if (!ContextI || !ContextI->getParent())
9663     return {nullptr, false};
9664 
9665   // TODO: This is a poor/cheap way to determine dominance. Should we use a
9666   // dominator tree (e.g., from a SimplifyQuery) instead?
9667   const BasicBlock *ContextBB = ContextI->getParent();
9668   const BasicBlock *PredBB = ContextBB->getSinglePredecessor();
9669   if (!PredBB)
9670     return {nullptr, false};
9671 
9672   // We need a conditional branch in the predecessor.
9673   Value *PredCond;
9674   BasicBlock *TrueBB, *FalseBB;
9675   if (!match(PredBB->getTerminator(), m_Br(m_Value(PredCond), TrueBB, FalseBB)))
9676     return {nullptr, false};
9677 
9678   // The branch should get simplified. Don't bother simplifying this condition.
9679   if (TrueBB == FalseBB)
9680     return {nullptr, false};
9681 
9682   assert((TrueBB == ContextBB || FalseBB == ContextBB) &&
9683          "Predecessor block does not point to successor?");
9684 
9685   // Is this condition implied by the predecessor condition?
9686   return {PredCond, TrueBB == ContextBB};
9687 }
9688 
9689 std::optional<bool> llvm::isImpliedByDomCondition(const Value *Cond,
9690                                                   const Instruction *ContextI,
9691                                                   const DataLayout &DL) {
9692   assert(Cond->getType()->isIntOrIntVectorTy(1) && "Condition must be bool");
9693   auto PredCond = getDomPredecessorCondition(ContextI);
9694   if (PredCond.first)
9695     return isImpliedCondition(PredCond.first, Cond, DL, PredCond.second);
9696   return std::nullopt;
9697 }
9698 
9699 std::optional<bool> llvm::isImpliedByDomCondition(CmpPredicate Pred,
9700                                                   const Value *LHS,
9701                                                   const Value *RHS,
9702                                                   const Instruction *ContextI,
9703                                                   const DataLayout &DL) {
9704   auto PredCond = getDomPredecessorCondition(ContextI);
9705   if (PredCond.first)
9706     return isImpliedCondition(PredCond.first, Pred, LHS, RHS, DL,
9707                               PredCond.second);
9708   return std::nullopt;
9709 }
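// A hypothetical usage sketch: inside a block reached when its single
// predecessor's branch condition "%a u< %b" is true, the query below would
// fold to true because "a u< b" implies "a u<= b":
//   isImpliedByDomCondition(ICmpInst::ICMP_ULE, A, B, InsertPt, DL);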
9710 
9711 static void setLimitsForBinOp(const BinaryOperator &BO, APInt &Lower,
9712                               APInt &Upper, const InstrInfoQuery &IIQ,
9713                               bool PreferSignedRange) {
9714   unsigned Width = Lower.getBitWidth();
9715   const APInt *C;
9716   switch (BO.getOpcode()) {
9717   case Instruction::Add:
9718     if (match(BO.getOperand(1), m_APInt(C)) && !C->isZero()) {
9719       bool HasNSW = IIQ.hasNoSignedWrap(&BO);
9720       bool HasNUW = IIQ.hasNoUnsignedWrap(&BO);
9721 
9722       // If the caller expects a signed compare, then try to use a signed range.
9723       // Otherwise if both no-wraps are set, use the unsigned range because it
9724       // is never larger than the signed range. Example:
9725       // "add nuw nsw i8 X, -2" is unsigned [254,255] vs. signed [-128, 125].
9726       if (PreferSignedRange && HasNSW && HasNUW)
9727         HasNUW = false;
9728 
9729       if (HasNUW) {
9730         // 'add nuw x, C' produces [C, UINT_MAX].
9731         Lower = *C;
9732       } else if (HasNSW) {
9733         if (C->isNegative()) {
9734           // 'add nsw x, -C' produces [SINT_MIN, SINT_MAX - C].
9735           Lower = APInt::getSignedMinValue(Width);
9736           Upper = APInt::getSignedMaxValue(Width) + *C + 1;
9737         } else {
9738           // 'add nsw x, +C' produces [SINT_MIN + C, SINT_MAX].
9739           Lower = APInt::getSignedMinValue(Width) + *C;
9740           Upper = APInt::getSignedMaxValue(Width) + 1;
9741         }
9742       }
9743     }
9744     break;
9745 
9746   case Instruction::And:
9747     if (match(BO.getOperand(1), m_APInt(C)))
9748       // 'and x, C' produces [0, C].
9749       Upper = *C + 1;
9750     // X & -X is a power of two or zero. So we can cap the value at max power of
9751     // two.
9752     if (match(BO.getOperand(0), m_Neg(m_Specific(BO.getOperand(1)))) ||
9753         match(BO.getOperand(1), m_Neg(m_Specific(BO.getOperand(0)))))
9754       Upper = APInt::getSignedMinValue(Width) + 1;
9755     break;
9756 
9757   case Instruction::Or:
9758     if (match(BO.getOperand(1), m_APInt(C)))
9759       // 'or x, C' produces [C, UINT_MAX].
9760       Lower = *C;
9761     break;
9762 
9763   case Instruction::AShr:
9764     if (match(BO.getOperand(1), m_APInt(C)) && C->ult(Width)) {
9765       // 'ashr x, C' produces [INT_MIN >> C, INT_MAX >> C].
9766       Lower = APInt::getSignedMinValue(Width).ashr(*C);
9767       Upper = APInt::getSignedMaxValue(Width).ashr(*C) + 1;
9768     } else if (match(BO.getOperand(0), m_APInt(C))) {
9769       unsigned ShiftAmount = Width - 1;
9770       if (!C->isZero() && IIQ.isExact(&BO))
9771         ShiftAmount = C->countr_zero();
9772       if (C->isNegative()) {
9773         // 'ashr C, x' produces [C, C >> (Width-1)]
9774         Lower = *C;
9775         Upper = C->ashr(ShiftAmount) + 1;
9776       } else {
9777         // 'ashr C, x' produces [C >> (Width-1), C]
9778         Lower = C->ashr(ShiftAmount);
9779         Upper = *C + 1;
9780       }
9781     }
9782     break;
9783 
9784   case Instruction::LShr:
9785     if (match(BO.getOperand(1), m_APInt(C)) && C->ult(Width)) {
9786       // 'lshr x, C' produces [0, UINT_MAX >> C].
9787       Upper = APInt::getAllOnes(Width).lshr(*C) + 1;
9788     } else if (match(BO.getOperand(0), m_APInt(C))) {
9789       // 'lshr C, x' produces [C >> (Width-1), C].
9790       unsigned ShiftAmount = Width - 1;
9791       if (!C->isZero() && IIQ.isExact(&BO))
9792         ShiftAmount = C->countr_zero();
9793       Lower = C->lshr(ShiftAmount);
9794       Upper = *C + 1;
9795     }
9796     break;
9797 
9798   case Instruction::Shl:
9799     if (match(BO.getOperand(0), m_APInt(C))) {
9800       if (IIQ.hasNoUnsignedWrap(&BO)) {
9801         // 'shl nuw C, x' produces [C, C << CLZ(C)]
9802         Lower = *C;
9803         Upper = Lower.shl(Lower.countl_zero()) + 1;
9804       } else if (BO.hasNoSignedWrap()) { // TODO: What if both nuw+nsw?
9805         if (C->isNegative()) {
9806           // 'shl nsw C, x' produces [C << CLO(C)-1, C]
9807           unsigned ShiftAmount = C->countl_one() - 1;
9808           Lower = C->shl(ShiftAmount);
9809           Upper = *C + 1;
9810         } else {
9811           // 'shl nsw C, x' produces [C, C << CLZ(C)-1]
9812           unsigned ShiftAmount = C->countl_zero() - 1;
9813           Lower = *C;
9814           Upper = C->shl(ShiftAmount) + 1;
9815         }
9816       } else {
9817         // If lowbit is set, value can never be zero.
9818         if ((*C)[0])
9819           Lower = APInt::getOneBitSet(Width, 0);
9820         // If we are shifting a constant, the largest it can be is when the
9821         // longest sequence of consecutive ones is shifted into the high bits
9822         // (breaking ties toward the higher sequence). At the moment we take a
9823         // liberal upper bound on this by just popcounting the constant.
9824         // TODO: There may be a bitwise trick for finding the longest/highest
9825         // consecutive sequence of ones (the naive method is an O(Width) loop).
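        // For example (sample value), 'shl i8 3, x' with no flags: the low bit
        // of 3 is set, so Lower = 1, and popcount(3) = 2 gives the bound
        // Upper - 1 = 0b11000000, which 3 << 6 actually attains.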
9826         Upper = APInt::getHighBitsSet(Width, C->popcount()) + 1;
9827       }
9828     } else if (match(BO.getOperand(1), m_APInt(C)) && C->ult(Width)) {
9829       Upper = APInt::getBitsSetFrom(Width, C->getZExtValue()) + 1;
9830     }
9831     break;
9832 
9833   case Instruction::SDiv:
9834     if (match(BO.getOperand(1), m_APInt(C))) {
9835       APInt IntMin = APInt::getSignedMinValue(Width);
9836       APInt IntMax = APInt::getSignedMaxValue(Width);
9837       if (C->isAllOnes()) {
9838         // 'sdiv x, -1' produces [INT_MIN + 1, INT_MAX].
9839         // (The general "C != -1, 0, 1" divisor case is handled below.)
9840         Lower = IntMin + 1;
9841         Upper = IntMax + 1;
9842       } else if (C->countl_zero() < Width - 1) {
9843         // 'sdiv x, C' produces [INT_MIN / C, INT_MAX / C]
9844         //    where C != -1 and C != 0 and C != 1
9845         Lower = IntMin.sdiv(*C);
9846         Upper = IntMax.sdiv(*C);
9847         if (Lower.sgt(Upper))
9848           std::swap(Lower, Upper);
9849         Upper = Upper + 1;
9850         assert(Upper != Lower && "Upper part of range has wrapped!");
9851       }
9852     } else if (match(BO.getOperand(0), m_APInt(C))) {
9853       if (C->isMinSignedValue()) {
9854         // 'sdiv INT_MIN, x' produces [INT_MIN, INT_MIN / -2].
9855         Lower = *C;
9856         Upper = Lower.lshr(1) + 1;
9857       } else {
9858         // 'sdiv C, x' produces [-|C|, |C|].
9859         Upper = C->abs() + 1;
9860         Lower = (-Upper) + 1;
9861       }
9862     }
9863     break;
9864 
9865   case Instruction::UDiv:
9866     if (match(BO.getOperand(1), m_APInt(C)) && !C->isZero()) {
9867       // 'udiv x, C' produces [0, UINT_MAX / C].
9868       Upper = APInt::getMaxValue(Width).udiv(*C) + 1;
9869     } else if (match(BO.getOperand(0), m_APInt(C))) {
9870       // 'udiv C, x' produces [0, C].
9871       Upper = *C + 1;
9872     }
9873     break;
9874 
9875   case Instruction::SRem:
9876     if (match(BO.getOperand(1), m_APInt(C))) {
9877       // 'srem x, C' produces (-|C|, |C|).
9878       Upper = C->abs();
9879       Lower = (-Upper) + 1;
9880     } else if (match(BO.getOperand(0), m_APInt(C))) {
9881       if (C->isNegative()) {
9882         // 'srem -|C|, x' produces [-|C|, 0].
9883         Upper = 1;
9884         Lower = *C;
9885       } else {
9886         // 'srem |C|, x' produces [0, |C|].
9887         Upper = *C + 1;
9888       }
9889     }
9890     break;
9891 
9892   case Instruction::URem:
9893     if (match(BO.getOperand(1), m_APInt(C)))
9894       // 'urem x, C' produces [0, C).
9895       Upper = *C;
9896     else if (match(BO.getOperand(0), m_APInt(C)))
9897       // 'urem C, x' produces [0, C].
9898       Upper = *C + 1;
9899     break;
9900 
9901   default:
9902     break;
9903   }
9904 }
9905 
9906 static ConstantRange getRangeForIntrinsic(const IntrinsicInst &II,
9907                                           bool UseInstrInfo) {
9908   unsigned Width = II.getType()->getScalarSizeInBits();
9909   const APInt *C;
9910   switch (II.getIntrinsicID()) {
9911   case Intrinsic::ctlz:
9912   case Intrinsic::cttz: {
9913     APInt Upper(Width, Width);
9914     if (!UseInstrInfo || !match(II.getArgOperand(1), m_One()))
9915       Upper += 1;
9916     // Maximum of set/clear bits is the bit width.
9917     return ConstantRange::getNonEmpty(APInt::getZero(Width), Upper);
9918   }
9919   case Intrinsic::ctpop:
9920     // Maximum of set/clear bits is the bit width.
9921     return ConstantRange::getNonEmpty(APInt::getZero(Width),
9922                                       APInt(Width, Width) + 1);
9923   case Intrinsic::uadd_sat:
9924     // uadd.sat(x, C) produces [C, UINT_MAX].
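    // For example (sample constant), uadd.sat(i8 %x, i8 200) produces the
    // range [200, 255]; the wrapped upper bound of 0 stands for UINT_MAX + 1.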
9925     if (match(II.getOperand(0), m_APInt(C)) ||
9926         match(II.getOperand(1), m_APInt(C)))
9927       return ConstantRange::getNonEmpty(*C, APInt::getZero(Width));
9928     break;
9929   case Intrinsic::sadd_sat:
9930     if (match(II.getOperand(0), m_APInt(C)) ||
9931         match(II.getOperand(1), m_APInt(C))) {
9932       if (C->isNegative())
9933         // sadd.sat(x, -C) produces [SINT_MIN, SINT_MAX + (-C)].
9934         return ConstantRange::getNonEmpty(APInt::getSignedMinValue(Width),
9935                                           APInt::getSignedMaxValue(Width) + *C +
9936                                               1);
9937 
9938       // sadd.sat(x, +C) produces [SINT_MIN + C, SINT_MAX].
9939       return ConstantRange::getNonEmpty(APInt::getSignedMinValue(Width) + *C,
9940                                         APInt::getSignedMaxValue(Width) + 1);
9941     }
9942     break;
9943   case Intrinsic::usub_sat:
9944     // usub.sat(C, x) produces [0, C].
9945     if (match(II.getOperand(0), m_APInt(C)))
9946       return ConstantRange::getNonEmpty(APInt::getZero(Width), *C + 1);
9947 
9948     // usub.sat(x, C) produces [0, UINT_MAX - C].
9949     if (match(II.getOperand(1), m_APInt(C)))
9950       return ConstantRange::getNonEmpty(APInt::getZero(Width),
9951                                         APInt::getMaxValue(Width) - *C + 1);
9952     break;
9953   case Intrinsic::ssub_sat:
9954     if (match(II.getOperand(0), m_APInt(C))) {
9955       if (C->isNegative())
9956         // ssub.sat(-C, x) produces [SINT_MIN, -SINT_MIN + (-C)].
9957         return ConstantRange::getNonEmpty(APInt::getSignedMinValue(Width),
9958                                           *C - APInt::getSignedMinValue(Width) +
9959                                               1);
9960 
9961       // ssub.sat(+C, x) produces [-SINT_MAX + C, SINT_MAX].
9962       return ConstantRange::getNonEmpty(*C - APInt::getSignedMaxValue(Width),
9963                                         APInt::getSignedMaxValue(Width) + 1);
9964     } else if (match(II.getOperand(1), m_APInt(C))) {
9965       if (C->isNegative())
9966         // ssub.sat(x, -C) produces [SINT_MIN - (-C), SINT_MAX]:
9967         return ConstantRange::getNonEmpty(APInt::getSignedMinValue(Width) - *C,
9968                                           APInt::getSignedMaxValue(Width) + 1);
9969 
9970       // ssub.sat(x, +C) produces [SINT_MIN, SINT_MAX - C].
9971       return ConstantRange::getNonEmpty(APInt::getSignedMinValue(Width),
9972                                         APInt::getSignedMaxValue(Width) - *C +
9973                                             1);
9974     }
9975     break;
9976   case Intrinsic::umin:
9977   case Intrinsic::umax:
9978   case Intrinsic::smin:
9979   case Intrinsic::smax:
9980     if (!match(II.getOperand(0), m_APInt(C)) &&
9981         !match(II.getOperand(1), m_APInt(C)))
9982       break;
9983 
9984     switch (II.getIntrinsicID()) {
9985     case Intrinsic::umin:
9986       return ConstantRange::getNonEmpty(APInt::getZero(Width), *C + 1);
9987     case Intrinsic::umax:
9988       return ConstantRange::getNonEmpty(*C, APInt::getZero(Width));
9989     case Intrinsic::smin:
9990       return ConstantRange::getNonEmpty(APInt::getSignedMinValue(Width),
9991                                         *C + 1);
9992     case Intrinsic::smax:
9993       return ConstantRange::getNonEmpty(*C,
9994                                         APInt::getSignedMaxValue(Width) + 1);
9995     default:
9996       llvm_unreachable("Must be min/max intrinsic");
9997     }
9998     break;
9999   case Intrinsic::abs:
10000     // If abs of SIGNED_MIN is poison, then the result is [0..SIGNED_MAX],
10001     // otherwise it is [0..SIGNED_MIN], as -SIGNED_MIN == SIGNED_MIN.
10002     if (match(II.getOperand(1), m_One()))
10003       return ConstantRange::getNonEmpty(APInt::getZero(Width),
10004                                         APInt::getSignedMaxValue(Width) + 1);
10005 
10006     return ConstantRange::getNonEmpty(APInt::getZero(Width),
10007                                       APInt::getSignedMinValue(Width) + 1);
10008   case Intrinsic::vscale:
10009     if (!II.getParent() || !II.getFunction())
10010       break;
10011     return getVScaleRange(II.getFunction(), Width);
10012   case Intrinsic::scmp:
10013   case Intrinsic::ucmp:
10014     return ConstantRange::getNonEmpty(APInt::getAllOnes(Width),
10015                                       APInt(Width, 2));
10016   default:
10017     break;
10018   }
10019 
10020   return ConstantRange::getFull(Width);
10021 }
10022 
10023 static ConstantRange getRangeForSelectPattern(const SelectInst &SI,
10024                                               const InstrInfoQuery &IIQ) {
10025   unsigned BitWidth = SI.getType()->getScalarSizeInBits();
10026   const Value *LHS = nullptr, *RHS = nullptr;
10027   SelectPatternResult R = matchSelectPattern(&SI, LHS, RHS);
10028   if (R.Flavor == SPF_UNKNOWN)
10029     return ConstantRange::getFull(BitWidth);
10030 
10031   if (R.Flavor == SelectPatternFlavor::SPF_ABS) {
10032     // If the negation part of the abs (in RHS) has the NSW flag,
10033     // then the result of abs(X) is [0..SIGNED_MAX],
10034     // otherwise it is [0..SIGNED_MIN], as -SIGNED_MIN == SIGNED_MIN.
10035     if (match(RHS, m_Neg(m_Specific(LHS))) &&
10036         IIQ.hasNoSignedWrap(cast<Instruction>(RHS)))
10037       return ConstantRange::getNonEmpty(APInt::getZero(BitWidth),
10038                                         APInt::getSignedMaxValue(BitWidth) + 1);
10039 
10040     return ConstantRange::getNonEmpty(APInt::getZero(BitWidth),
10041                                       APInt::getSignedMinValue(BitWidth) + 1);
10042   }
10043 
10044   if (R.Flavor == SelectPatternFlavor::SPF_NABS) {
10045     // The result of -abs(X) is <= 0.
10046     return ConstantRange::getNonEmpty(APInt::getSignedMinValue(BitWidth),
10047                                       APInt(BitWidth, 1));
10048   }
10049 
10050   const APInt *C;
10051   if (!match(LHS, m_APInt(C)) && !match(RHS, m_APInt(C)))
10052     return ConstantRange::getFull(BitWidth);
10053 
10054   switch (R.Flavor) {
10055   case SPF_UMIN:
10056     return ConstantRange::getNonEmpty(APInt::getZero(BitWidth), *C + 1);
10057   case SPF_UMAX:
10058     return ConstantRange::getNonEmpty(*C, APInt::getZero(BitWidth));
10059   case SPF_SMIN:
10060     return ConstantRange::getNonEmpty(APInt::getSignedMinValue(BitWidth),
10061                                       *C + 1);
10062   case SPF_SMAX:
10063     return ConstantRange::getNonEmpty(*C,
10064                                       APInt::getSignedMaxValue(BitWidth) + 1);
10065   default:
10066     return ConstantRange::getFull(BitWidth);
10067   }
10068 }
10069 
10070 static void setLimitForFPToI(const Instruction *I, APInt &Lower, APInt &Upper) {
10071   // The maximum representable value of a half is 65504. For floats, the
10072   // maximum value is 3.4e38, which requires roughly 129 bits.
10073   unsigned BitWidth = I->getType()->getScalarSizeInBits();
10074   if (!I->getOperand(0)->getType()->getScalarType()->isHalfTy())
10075     return;
10076   if (isa<FPToSIInst>(I) && BitWidth >= 17) {
10077     Lower = APInt(BitWidth, -65504, true);
10078     Upper = APInt(BitWidth, 65505);
10079   }
10080 
10081   if (isa<FPToUIInst>(I) && BitWidth >= 16) {
10082     // For a fptoui the lower limit is left as 0.
10083     Upper = APInt(BitWidth, 65505);
10084   }
10085 }
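// For example, 'fptosi half %h to i32' yields the range [-65504, 65504] here,
// since no half value converts to an integer outside that interval.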
10086 
10087 ConstantRange llvm::computeConstantRange(const Value *V, bool ForSigned,
10088                                          bool UseInstrInfo, AssumptionCache *AC,
10089                                          const Instruction *CtxI,
10090                                          const DominatorTree *DT,
10091                                          unsigned Depth) {
10092   assert(V->getType()->isIntOrIntVectorTy() && "Expected integer instruction");
10093 
10094   if (Depth == MaxAnalysisRecursionDepth)
10095     return ConstantRange::getFull(V->getType()->getScalarSizeInBits());
10096 
10097   if (auto *C = dyn_cast<Constant>(V))
10098     return C->toConstantRange();
10099 
10100   unsigned BitWidth = V->getType()->getScalarSizeInBits();
10101   InstrInfoQuery IIQ(UseInstrInfo);
10102   ConstantRange CR = ConstantRange::getFull(BitWidth);
10103   if (auto *BO = dyn_cast<BinaryOperator>(V)) {
10104     APInt Lower = APInt(BitWidth, 0);
10105     APInt Upper = APInt(BitWidth, 0);
10106     // TODO: Return ConstantRange.
10107     setLimitsForBinOp(*BO, Lower, Upper, IIQ, ForSigned);
10108     CR = ConstantRange::getNonEmpty(Lower, Upper);
10109   } else if (auto *II = dyn_cast<IntrinsicInst>(V))
10110     CR = getRangeForIntrinsic(*II, UseInstrInfo);
10111   else if (auto *SI = dyn_cast<SelectInst>(V)) {
10112     ConstantRange CRTrue = computeConstantRange(
10113         SI->getTrueValue(), ForSigned, UseInstrInfo, AC, CtxI, DT, Depth + 1);
10114     ConstantRange CRFalse = computeConstantRange(
10115         SI->getFalseValue(), ForSigned, UseInstrInfo, AC, CtxI, DT, Depth + 1);
10116     CR = CRTrue.unionWith(CRFalse);
10117     CR = CR.intersectWith(getRangeForSelectPattern(*SI, IIQ));
10118   } else if (isa<FPToUIInst>(V) || isa<FPToSIInst>(V)) {
10119     APInt Lower = APInt(BitWidth, 0);
10120     APInt Upper = APInt(BitWidth, 0);
10121     // TODO: Return ConstantRange.
10122     setLimitForFPToI(cast<Instruction>(V), Lower, Upper);
10123     CR = ConstantRange::getNonEmpty(Lower, Upper);
10124   } else if (const auto *A = dyn_cast<Argument>(V))
10125     if (std::optional<ConstantRange> Range = A->getRange())
10126       CR = *Range;
10127 
10128   if (auto *I = dyn_cast<Instruction>(V)) {
10129     if (auto *Range = IIQ.getMetadata(I, LLVMContext::MD_range))
10130       CR = CR.intersectWith(getConstantRangeFromMetadata(*Range));
10131 
10132     if (const auto *CB = dyn_cast<CallBase>(V))
10133       if (std::optional<ConstantRange> Range = CB->getRange())
10134         CR = CR.intersectWith(*Range);
10135   }
10136 
10137   if (CtxI && AC) {
10138     // Try to restrict the range based on information from assumptions.
10139     for (auto &AssumeVH : AC->assumptionsFor(V)) {
10140       if (!AssumeVH)
10141         continue;
10142       CallInst *I = cast<CallInst>(AssumeVH);
10143       assert(I->getParent()->getParent() == CtxI->getParent()->getParent() &&
10144              "Got assumption for the wrong function!");
10145       assert(I->getIntrinsicID() == Intrinsic::assume &&
10146              "must be an assume intrinsic");
10147 
10148       if (!isValidAssumeForContext(I, CtxI, DT))
10149         continue;
10150       Value *Arg = I->getArgOperand(0);
10151       ICmpInst *Cmp = dyn_cast<ICmpInst>(Arg);
10152       // Currently we just use information from comparisons.
10153       if (!Cmp || Cmp->getOperand(0) != V)
10154         continue;
10155       // TODO: Set "ForSigned" parameter via Cmp->isSigned()?
10156       ConstantRange RHS =
10157           computeConstantRange(Cmp->getOperand(1), /* ForSigned */ false,
10158                                UseInstrInfo, AC, I, DT, Depth + 1);
10159       CR = CR.intersectWith(
10160           ConstantRange::makeAllowedICmpRegion(Cmp->getPredicate(), RHS));
10161     }
10162   }
10163 
10164   return CR;
10165 }
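// A hypothetical caller sketch (assumes the default arguments declared in
// ValueTracking.h): to conservatively check that a value can never be zero,
// a caller might do:
//   ConstantRange CR =
//       computeConstantRange(V, /*ForSigned=*/false, /*UseInstrInfo=*/true,
//                            /*AC=*/nullptr, /*CtxI=*/nullptr, /*DT=*/nullptr);
//   bool KnownNonZero = !CR.contains(APInt::getZero(CR.getBitWidth()));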
10166 
10167 static void
10168 addValueAffectedByCondition(Value *V,
10169                             function_ref<void(Value *)> InsertAffected) {
10170   assert(V != nullptr);
10171   if (isa<Argument>(V) || isa<GlobalValue>(V)) {
10172     InsertAffected(V);
10173   } else if (auto *I = dyn_cast<Instruction>(V)) {
10174     InsertAffected(V);
10175 
10176     // Peek through unary operators to find the source of the condition.
10177     Value *Op;
10178     if (match(I, m_CombineOr(m_PtrToInt(m_Value(Op)), m_Trunc(m_Value(Op))))) {
10179       if (isa<Instruction>(Op) || isa<Argument>(Op))
10180         InsertAffected(Op);
10181     }
10182   }
10183 }
10184 
10185 void llvm::findValuesAffectedByCondition(
10186     Value *Cond, bool IsAssume, function_ref<void(Value *)> InsertAffected) {
10187   auto AddAffected = [&InsertAffected](Value *V) {
10188     addValueAffectedByCondition(V, InsertAffected);
10189   };
10190 
10191   auto AddCmpOperands = [&AddAffected, IsAssume](Value *LHS, Value *RHS) {
10192     if (IsAssume) {
10193       AddAffected(LHS);
10194       AddAffected(RHS);
10195     } else if (match(RHS, m_Constant()))
10196       AddAffected(LHS);
10197   };
10198 
10199   SmallVector<Value *, 8> Worklist;
10200   SmallPtrSet<Value *, 8> Visited;
10201   Worklist.push_back(Cond);
10202   while (!Worklist.empty()) {
10203     Value *V = Worklist.pop_back_val();
10204     if (!Visited.insert(V).second)
10205       continue;
10206 
10207     CmpPredicate Pred;
10208     Value *A, *B, *X;
10209 
10210     if (IsAssume) {
10211       AddAffected(V);
10212       if (match(V, m_Not(m_Value(X))))
10213         AddAffected(X);
10214     }
10215 
10216     if (match(V, m_LogicalOp(m_Value(A), m_Value(B)))) {
10217       // assume(A && B) is split to -> assume(A); assume(B);
10218       // assume(!(A || B)) is split to -> assume(!A); assume(!B);
10219       // Finally, assume(A || B) / assume(!(A && B)) generally don't provide
10220       // enough information to be worth handling (intersection of information as
10221       // opposed to union).
10222       if (!IsAssume) {
10223         Worklist.push_back(A);
10224         Worklist.push_back(B);
10225       }
10226     } else if (match(V, m_ICmp(Pred, m_Value(A), m_Value(B)))) {
10227       AddCmpOperands(A, B);
10228 
10229       bool HasRHSC = match(B, m_ConstantInt());
10230       if (ICmpInst::isEquality(Pred)) {
10231         if (HasRHSC) {
10232           Value *Y;
10233           // (X & C) or (X | C) or (X ^ C).
10234           // (X << C) or (X >>_s C) or (X >>_u C).
10235           if (match(A, m_BitwiseLogic(m_Value(X), m_ConstantInt())) ||
10236               match(A, m_Shift(m_Value(X), m_ConstantInt())))
10237             AddAffected(X);
10238           else if (match(A, m_And(m_Value(X), m_Value(Y))) ||
10239                    match(A, m_Or(m_Value(X), m_Value(Y)))) {
10240             AddAffected(X);
10241             AddAffected(Y);
10242           }
10243         }
10244       } else {
10245         if (HasRHSC) {
10246           // Handle (A + C1) u< C2, which is the canonical form of
10247           // A > C3 && A < C4.
10248           if (match(A, m_AddLike(m_Value(X), m_ConstantInt())))
10249             AddAffected(X);
10250 
10251           if (ICmpInst::isUnsigned(Pred)) {
10252             Value *Y;
10253             // X & Y u> C    -> X >u C && Y >u C
10254             // X | Y u< C    -> X u< C && Y u< C
10255             // X nuw+ Y u< C -> X u< C && Y u< C
10256             if (match(A, m_And(m_Value(X), m_Value(Y))) ||
10257                 match(A, m_Or(m_Value(X), m_Value(Y))) ||
10258                 match(A, m_NUWAdd(m_Value(X), m_Value(Y)))) {
10259               AddAffected(X);
10260               AddAffected(Y);
10261             }
10262             // X nuw- Y u> C -> X u> C
10263             if (match(A, m_NUWSub(m_Value(X), m_Value())))
10264               AddAffected(X);
10265           }
10266         }
10267 
10268         // Handle icmp slt/sgt (bitcast X to int), 0/-1, which is supported
10269         // by computeKnownFPClass().
10270         if (match(A, m_ElementWiseBitCast(m_Value(X)))) {
10271           if (Pred == ICmpInst::ICMP_SLT && match(B, m_Zero()))
10272             InsertAffected(X);
10273           else if (Pred == ICmpInst::ICMP_SGT && match(B, m_AllOnes()))
10274             InsertAffected(X);
10275         }
10276       }
10277 
10278       if (HasRHSC && match(A, m_Intrinsic<Intrinsic::ctpop>(m_Value(X))))
10279         AddAffected(X);
10280     } else if (match(V, m_FCmp(Pred, m_Value(A), m_Value(B)))) {
10281       AddCmpOperands(A, B);
10282 
10283       // fcmp fneg(x), y
10284       // fcmp fabs(x), y
10285       // fcmp fneg(fabs(x)), y
10286       if (match(A, m_FNeg(m_Value(A))))
10287         AddAffected(A);
10288       if (match(A, m_FAbs(m_Value(A))))
10289         AddAffected(A);
10290 
10291     } else if (match(V, m_Intrinsic<Intrinsic::is_fpclass>(m_Value(A),
10292                                                            m_Value()))) {
10293       // Handle patterns that computeKnownFPClass() supports.
10294       AddAffected(A);
10295     }
10296   }
10297 }
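// An illustrative example: for the condition "icmp ult (add i32 %x, 5), 10"
// with IsAssume == false, InsertAffected is invoked for the add instruction
// and, via the add-with-constant peek above, for %x as well (assuming %x is
// itself an instruction or argument).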
10298