//===- InstCombineSimplifyDemanded.cpp ------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains logic for simplifying instructions based on information
// about how they are used.
//
//===----------------------------------------------------------------------===//

#include "InstCombineInternal.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Transforms/InstCombine/InstCombiner.h"

using namespace llvm;
using namespace llvm::PatternMatch;

#define DEBUG_TYPE "instcombine"

static cl::opt<bool>
    VerifyKnownBits("instcombine-verify-known-bits",
                    cl::desc("Verify that computeKnownBits() and "
                             "SimplifyDemandedBits() are consistent"),
                    cl::Hidden, cl::init(false));

static cl::opt<unsigned> SimplifyDemandedVectorEltsDepthLimit(
    "instcombine-simplify-vector-elts-depth",
    cl::desc(
        "Depth limit when simplifying vector instructions and their operands"),
    cl::Hidden, cl::init(10));

/// Check to see if the specified operand of the specified instruction is a
/// constant integer. If so, check to see if there are any bits set in the
/// constant that are not demanded. If so, shrink the constant and return true.
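/// For example (illustrative, not from a specific test): for `and i8 %x, -9`
/// (0xF7) with Demanded = 0x0F, the constant's high bits are not demanded, so
/// it is shrunk to 0x07.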
static bool ShrinkDemandedConstant(Instruction *I, unsigned OpNo,
                                   const APInt &Demanded) {
  assert(I && "No instruction?");
  assert(OpNo < I->getNumOperands() && "Operand index too large");

  // The operand must be a constant integer or splat integer.
  Value *Op = I->getOperand(OpNo);
  const APInt *C;
  if (!match(Op, m_APInt(C)))
    return false;

  // If there are no bits set that aren't demanded, nothing to do.
  if (C->isSubsetOf(Demanded))
    return false;

  // This instruction is producing bits that are not demanded. Shrink the RHS.
  I->setOperand(OpNo, ConstantInt::get(Op->getType(), *C & Demanded));

  return true;
}

/// Returns the bitwidth of the given scalar or pointer type. For vector types,
/// returns the element type's bitwidth.
static unsigned getBitWidth(Type *Ty, const DataLayout &DL) {
  if (unsigned BitWidth = Ty->getScalarSizeInBits())
    return BitWidth;

  return DL.getPointerTypeSizeInBits(Ty);
}

/// Inst is an integer instruction that SimplifyDemandedBits knows about. See if
/// the instruction has any properties that allow us to simplify its operands.
bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst,
                                                       KnownBits &Known) {
  APInt DemandedMask(APInt::getAllOnes(Known.getBitWidth()));
  Value *V = SimplifyDemandedUseBits(&Inst, DemandedMask, Known,
                                     0, SQ.getWithInstruction(&Inst));
  if (!V) return false;
  if (V == &Inst) return true;
  replaceInstUsesWith(Inst, V);
  return true;
}

/// Inst is an integer instruction that SimplifyDemandedBits knows about. See if
/// the instruction has any properties that allow us to simplify its operands.
bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst) {
  KnownBits Known(getBitWidth(Inst.getType(), DL));
  return SimplifyDemandedInstructionBits(Inst, Known);
}

/// This form of SimplifyDemandedBits simplifies the specified instruction
/// operand if possible, updating it in place. It returns true if it made any
/// change and false otherwise.
bool InstCombinerImpl::SimplifyDemandedBits(Instruction *I, unsigned OpNo,
                                            const APInt &DemandedMask,
                                            KnownBits &Known, unsigned Depth,
                                            const SimplifyQuery &Q) {
  Use &U = I->getOperandUse(OpNo);
  Value *V = U.get();
  if (isa<Constant>(V)) {
    llvm::computeKnownBits(V, Known, Depth, Q);
    return false;
  }

  Known.resetAll();
  if (DemandedMask.isZero()) {
    // Not demanding any bits from V.
    replaceUse(U, UndefValue::get(V->getType()));
    return true;
  }

  Instruction *VInst = dyn_cast<Instruction>(V);
  if (!VInst) {
    llvm::computeKnownBits(V, Known, Depth, Q);
    return false;
  }

  if (Depth == MaxAnalysisRecursionDepth)
    return false;

  Value *NewVal;
  if (VInst->hasOneUse()) {
    // If the instruction has one use, we can directly simplify it.
    NewVal = SimplifyDemandedUseBits(VInst, DemandedMask, Known, Depth, Q);
  } else {
    // If there are multiple uses of this instruction, then we can simplify
    // VInst to some other value, but not modify the instruction.
    NewVal =
        SimplifyMultipleUseDemandedBits(VInst, DemandedMask, Known, Depth, Q);
  }
  if (!NewVal) return false;
  if (Instruction *OpInst = dyn_cast<Instruction>(U))
    salvageDebugInfo(*OpInst);

  replaceUse(U, NewVal);
  return true;
}

/// This function attempts to replace V with a simpler value based on the
/// demanded bits. When this function is called, it is known that only the bits
/// set in DemandedMask of the result of V are ever used downstream.
/// Consequently, depending on the mask and V, it may be possible to replace V
/// with a constant or one of its operands (see the return-value conventions
/// below). When no replacement is possible, this function analyzes the
/// expression: Known.One is set to the bits that are known to be one in the
/// expression, and Known.Zero to the bits that are known to be zero. These are
/// provided to potentially allow the caller (which might recursively be
/// SimplifyDemandedBits itself) to simplify the expression.
/// Known.One and Known.Zero always follow the invariant that:
///   Known.One & Known.Zero == 0.
/// That is, a bit can't be both 1 and 0. The bits in Known.One and Known.Zero
/// are accurate even for bits not in DemandedMask. Note also that the bitwidth
/// of V, DemandedMask, Known.Zero and Known.One must all be the same.
///
/// This returns null if it did not change anything and it permits no
/// simplification. This returns V itself if it did some simplification of V's
/// operands based on the information about what bits are demanded. This
/// returns some other non-null value if it found out that V is equal to
/// another value in the context where the specified bits are demanded, but not
/// for all users.
Value *InstCombinerImpl::SimplifyDemandedUseBits(Instruction *I,
                                                 const APInt &DemandedMask,
                                                 KnownBits &Known,
                                                 unsigned Depth,
                                                 const SimplifyQuery &Q) {
  assert(I != nullptr && "Null pointer of Value???");
  assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth");
  uint32_t BitWidth = DemandedMask.getBitWidth();
  Type *VTy = I->getType();
  assert(
      (!VTy->isIntOrIntVectorTy() || VTy->getScalarSizeInBits() == BitWidth) &&
      Known.getBitWidth() == BitWidth &&
      "Value *V, DemandedMask and Known must have same BitWidth");

  KnownBits LHSKnown(BitWidth), RHSKnown(BitWidth);

  // Update flags after simplifying an operand based on the fact that some high
  // order bits are not demanded.
  auto disableWrapFlagsBasedOnUnusedHighBits = [](Instruction *I,
                                                  unsigned NLZ) {
    if (NLZ > 0) {
      // Disable the nsw and nuw flags here: We can no longer guarantee that
      // we won't wrap after simplification. Removing the nsw/nuw flags is
      // legal here because the top bit is not demanded.
      I->setHasNoSignedWrap(false);
      I->setHasNoUnsignedWrap(false);
    }
    return I;
  };

  // If the high bits of an ADD/SUB/MUL are not demanded, then we do not care
  // about the high bits of the operands.
  auto simplifyOperandsBasedOnUnusedHighBits = [&](APInt &DemandedFromOps) {
    unsigned NLZ = DemandedMask.countl_zero();
    // Right fill the mask of bits for the operands to demand the most
    // significant bit and all those below it.
    DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ);
    if (ShrinkDemandedConstant(I, 0, DemandedFromOps) ||
        SimplifyDemandedBits(I, 0, DemandedFromOps, LHSKnown, Depth + 1, Q) ||
        ShrinkDemandedConstant(I, 1, DemandedFromOps) ||
        SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1, Q)) {
      disableWrapFlagsBasedOnUnusedHighBits(I, NLZ);
      return true;
    }
    return false;
  };

  switch (I->getOpcode()) {
  default:
    llvm::computeKnownBits(I, Known, Depth, Q);
    break;
  case Instruction::And: {
    // If either the LHS or the RHS are Zero, the result is zero.
    if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnown, Depth + 1, Q) ||
        SimplifyDemandedBits(I, 0, DemandedMask & ~RHSKnown.Zero, LHSKnown,
                             Depth + 1, Q))
      return I;

    Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown,
                                         Depth, Q);

    // If the client is only demanding bits that we know, return the known
    // constant.
    if (DemandedMask.isSubsetOf(Known.Zero | Known.One))
      return Constant::getIntegerValue(VTy, Known.One);

    // If all of the demanded bits are known 1 on one side, return the other.
    // These bits cannot contribute to the result of the 'and'.
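    // E.g. (illustrative): for `and i8 %x, 31` with DemandedMask = 0x0F, all
    // demanded bits are known one in the mask (RHSKnown.One = 0x1F), so the
    // 'and' is bypassed and %x is returned.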
    if (DemandedMask.isSubsetOf(LHSKnown.Zero | RHSKnown.One))
      return I->getOperand(0);
    if (DemandedMask.isSubsetOf(RHSKnown.Zero | LHSKnown.One))
      return I->getOperand(1);

    // If the RHS is a constant, see if we can simplify it.
    if (ShrinkDemandedConstant(I, 1, DemandedMask & ~LHSKnown.Zero))
      return I;

    break;
  }
  case Instruction::Or: {
    // If either the LHS or the RHS are One, the result is One.
    if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnown, Depth + 1, Q) ||
        SimplifyDemandedBits(I, 0, DemandedMask & ~RHSKnown.One, LHSKnown,
                             Depth + 1, Q)) {
      // The disjoint flag may no longer hold.
      I->dropPoisonGeneratingFlags();
      return I;
    }

    Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown,
                                         Depth, Q);

    // If the client is only demanding bits that we know, return the known
    // constant.
    if (DemandedMask.isSubsetOf(Known.Zero | Known.One))
      return Constant::getIntegerValue(VTy, Known.One);

    // If all of the demanded bits are known zero on one side, return the other.
    // These bits cannot contribute to the result of the 'or'.
    if (DemandedMask.isSubsetOf(LHSKnown.One | RHSKnown.Zero))
      return I->getOperand(0);
    if (DemandedMask.isSubsetOf(RHSKnown.One | LHSKnown.Zero))
      return I->getOperand(1);

    // If the RHS is a constant, see if we can simplify it.
    if (ShrinkDemandedConstant(I, 1, DemandedMask))
      return I;

    // Infer the disjoint flag if no common bits are set.
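    // E.g. (illustrative): for `or i8 %x, 1` where %x is known even
    // (LHSKnown.Zero has bit 0 set), the operands can share no set bits, so
    // the 'or' is marked disjoint.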
    if (!cast<PossiblyDisjointInst>(I)->isDisjoint()) {
      WithCache<const Value *> LHSCache(I->getOperand(0), LHSKnown),
          RHSCache(I->getOperand(1), RHSKnown);
      if (haveNoCommonBitsSet(LHSCache, RHSCache, Q)) {
        cast<PossiblyDisjointInst>(I)->setIsDisjoint(true);
        return I;
      }
    }

    break;
  }
  case Instruction::Xor: {
    if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnown, Depth + 1, Q) ||
        SimplifyDemandedBits(I, 0, DemandedMask, LHSKnown, Depth + 1, Q))
      return I;
    Value *LHS, *RHS;
    if (DemandedMask == 1 &&
        match(I->getOperand(0), m_Intrinsic<Intrinsic::ctpop>(m_Value(LHS))) &&
        match(I->getOperand(1), m_Intrinsic<Intrinsic::ctpop>(m_Value(RHS)))) {
      // (ctpop(X) ^ ctpop(Y)) & 1 --> ctpop(X^Y) & 1
      IRBuilderBase::InsertPointGuard Guard(Builder);
      Builder.SetInsertPoint(I);
      auto *Xor = Builder.CreateXor(LHS, RHS);
      return Builder.CreateUnaryIntrinsic(Intrinsic::ctpop, Xor);
    }

    Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown,
                                         Depth, Q);

    // If the client is only demanding bits that we know, return the known
    // constant.
    if (DemandedMask.isSubsetOf(Known.Zero | Known.One))
      return Constant::getIntegerValue(VTy, Known.One);

    // If all of the demanded bits are known zero on one side, return the other.
    // These bits cannot contribute to the result of the 'xor'.
    if (DemandedMask.isSubsetOf(RHSKnown.Zero))
      return I->getOperand(0);
    if (DemandedMask.isSubsetOf(LHSKnown.Zero))
      return I->getOperand(1);

    // If all of the demanded bits are known to be zero on one side or the
    // other, turn this into an *inclusive* or.
    //    e.g. (A & C1)^(B & C2) -> (A & C1)|(B & C2) iff C1&C2 == 0
    if (DemandedMask.isSubsetOf(RHSKnown.Zero | LHSKnown.Zero)) {
      Instruction *Or =
          BinaryOperator::CreateOr(I->getOperand(0), I->getOperand(1));
      if (DemandedMask.isAllOnes())
        cast<PossiblyDisjointInst>(Or)->setIsDisjoint(true);
      Or->takeName(I);
      return InsertNewInstWith(Or, I->getIterator());
    }

    // If all of the demanded bits on one side are known, and all of the set
    // bits on that side are also known to be set on the other side, turn this
    // into an AND, as we know the bits will be cleared.
    //    e.g. (X | C1) ^ C2 --> (X | C1) & ~C2 iff (C1&C2) == C2
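    // E.g. (illustrative): ((X | 12) ^ 12) --> (X | 12) & ~12; the 'or'
    // forces bits 2 and 3 to one, so the xor is known to clear exactly those
    // bits.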
    if (DemandedMask.isSubsetOf(RHSKnown.Zero | RHSKnown.One) &&
        RHSKnown.One.isSubsetOf(LHSKnown.One)) {
      Constant *AndC = Constant::getIntegerValue(VTy,
                                                 ~RHSKnown.One & DemandedMask);
      Instruction *And = BinaryOperator::CreateAnd(I->getOperand(0), AndC);
      return InsertNewInstWith(And, I->getIterator());
    }

    // If the RHS is a constant, see if we can change it. Don't alter a -1
    // constant because that's a canonical 'not' op, and that is better for
    // combining, SCEV, and codegen.
    const APInt *C;
    if (match(I->getOperand(1), m_APInt(C)) && !C->isAllOnes()) {
      if ((*C | ~DemandedMask).isAllOnes()) {
        // Force bits to 1 to create a 'not' op.
        I->setOperand(1, ConstantInt::getAllOnesValue(VTy));
        return I;
      }
      // If we can't turn this into a 'not', try to shrink the constant.
      if (ShrinkDemandedConstant(I, 1, DemandedMask))
        return I;
    }

    // If our LHS is an 'and' and if it has one use, and if any of the bits we
    // are flipping are known to be set, then the xor is just resetting those
    // bits to zero. We can just knock out bits from the 'and' and the 'xor',
    // simplifying both of them.
    if (Instruction *LHSInst = dyn_cast<Instruction>(I->getOperand(0))) {
      ConstantInt *AndRHS, *XorRHS;
      if (LHSInst->getOpcode() == Instruction::And && LHSInst->hasOneUse() &&
          match(I->getOperand(1), m_ConstantInt(XorRHS)) &&
          match(LHSInst->getOperand(1), m_ConstantInt(AndRHS)) &&
          (LHSKnown.One & RHSKnown.One & DemandedMask) != 0) {
        APInt NewMask = ~(LHSKnown.One & RHSKnown.One & DemandedMask);

        Constant *AndC = ConstantInt::get(VTy, NewMask & AndRHS->getValue());
        Instruction *NewAnd = BinaryOperator::CreateAnd(I->getOperand(0), AndC);
        InsertNewInstWith(NewAnd, I->getIterator());

        Constant *XorC = ConstantInt::get(VTy, NewMask & XorRHS->getValue());
        Instruction *NewXor = BinaryOperator::CreateXor(NewAnd, XorC);
        return InsertNewInstWith(NewXor, I->getIterator());
      }
    }
    break;
  }
  case Instruction::Select: {
    if (SimplifyDemandedBits(I, 2, DemandedMask, RHSKnown, Depth + 1, Q) ||
        SimplifyDemandedBits(I, 1, DemandedMask, LHSKnown, Depth + 1, Q))
      return I;

    // If the operands are constants, see if we can simplify them.
    // This is similar to ShrinkDemandedConstant, but for a select we want to
    // try to keep the selected constants the same as icmp value constants, if
    // we can. This helps not break apart (or helps put back together)
    // canonical patterns like min and max.
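    // E.g. (illustrative): with DemandedMask = 0x0F, a select arm constant of
    // 24 next to `icmp ult %x, 8` agrees with 8 on every demanded bit
    // (24 & 0x0F == 8 & 0x0F), so it is rewritten to 8 rather than shrunk,
    // keeping the min/max-style pattern intact.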
    auto CanonicalizeSelectConstant = [](Instruction *I, unsigned OpNo,
                                         const APInt &DemandedMask) {
      const APInt *SelC;
      if (!match(I->getOperand(OpNo), m_APInt(SelC)))
        return false;

      // Get the constant out of the ICmp, if there is one.
      // Only try this when exactly 1 operand is a constant (if both operands
      // are constant, the icmp should eventually simplify). Otherwise, we may
      // invert the transform that reduces set bits and infinite-loop.
      Value *X;
      const APInt *CmpC;
      if (!match(I->getOperand(0), m_ICmp(m_Value(X), m_APInt(CmpC))) ||
          isa<Constant>(X) || CmpC->getBitWidth() != SelC->getBitWidth())
        return ShrinkDemandedConstant(I, OpNo, DemandedMask);

      // If the constant is already the same as the ICmp, leave it as-is.
      if (*CmpC == *SelC)
        return false;
      // If the constants are not already the same, but can be with the demand
      // mask, use the constant value from the ICmp.
      if ((*CmpC & DemandedMask) == (*SelC & DemandedMask)) {
        I->setOperand(OpNo, ConstantInt::get(I->getType(), *CmpC));
        return true;
      }
      return ShrinkDemandedConstant(I, OpNo, DemandedMask);
    };
    if (CanonicalizeSelectConstant(I, 1, DemandedMask) ||
        CanonicalizeSelectConstant(I, 2, DemandedMask))
      return I;

    // A bit is known only if it is known in both select arms.
    adjustKnownBitsForSelectArm(LHSKnown, I->getOperand(0), I->getOperand(1),
                                /*Invert=*/false, Depth, Q);
    adjustKnownBitsForSelectArm(RHSKnown, I->getOperand(0), I->getOperand(2),
                                /*Invert=*/true, Depth, Q);
    Known = LHSKnown.intersectWith(RHSKnown);
    break;
  }
  case Instruction::Trunc: {
    // If we do not demand the high bits of a right-shifted and truncated value,
    // then we may be able to truncate it before the shift.
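    // E.g. (illustrative): `trunc i64 (lshr i64 %x, 8) to i32` with only the
    // low 16 bits demanded becomes `lshr i32 (trunc i64 %x to i32), 8`,
    // narrowing the shift.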
    Value *X;
    const APInt *C;
    if (match(I->getOperand(0), m_OneUse(m_LShr(m_Value(X), m_APInt(C))))) {
      // The shift amount must be valid (not poison) in the narrow type, and
      // it must not be greater than the high bits demanded of the result.
      if (C->ult(VTy->getScalarSizeInBits()) &&
          C->ule(DemandedMask.countl_zero())) {
        // trunc (lshr X, C) --> lshr (trunc X), C
        IRBuilderBase::InsertPointGuard Guard(Builder);
        Builder.SetInsertPoint(I);
        Value *Trunc = Builder.CreateTrunc(X, VTy);
        return Builder.CreateLShr(Trunc, C->getZExtValue());
      }
    }
  }
    [[fallthrough]];
  case Instruction::ZExt: {
    unsigned SrcBitWidth = I->getOperand(0)->getType()->getScalarSizeInBits();

    APInt InputDemandedMask = DemandedMask.zextOrTrunc(SrcBitWidth);
    KnownBits InputKnown(SrcBitWidth);
    if (SimplifyDemandedBits(I, 0, InputDemandedMask, InputKnown, Depth + 1,
                             Q)) {
      // For zext nneg, we may have dropped the instruction that made the
      // input non-negative.
      I->dropPoisonGeneratingFlags();
      return I;
    }
    assert(InputKnown.getBitWidth() == SrcBitWidth && "Src width changed?");
    if (I->getOpcode() == Instruction::ZExt && I->hasNonNeg() &&
        !InputKnown.isNegative())
      InputKnown.makeNonNegative();
    Known = InputKnown.zextOrTrunc(BitWidth);

    break;
  }
  case Instruction::SExt: {
    // Compute the bits in the result that are not present in the input.
    unsigned SrcBitWidth = I->getOperand(0)->getType()->getScalarSizeInBits();

    APInt InputDemandedBits = DemandedMask.trunc(SrcBitWidth);

    // If any of the sign-extended bits are demanded, we know that the sign
    // bit is demanded.
    if (DemandedMask.getActiveBits() > SrcBitWidth)
      InputDemandedBits.setBit(SrcBitWidth - 1);

    KnownBits InputKnown(SrcBitWidth);
    if (SimplifyDemandedBits(I, 0, InputDemandedBits, InputKnown, Depth + 1, Q))
      return I;

    // If the input sign bit is known zero, or if the NewBits are not demanded,
    // convert this into a zero extension.
    if (InputKnown.isNonNegative() ||
        DemandedMask.getActiveBits() <= SrcBitWidth) {
      // Convert to ZExt cast.
      CastInst *NewCast = new ZExtInst(I->getOperand(0), VTy);
      NewCast->takeName(I);
      return InsertNewInstWith(NewCast, I->getIterator());
    }

    // If the sign bit of the input is known set or clear, then we know the
    // top bits of the result.
    Known = InputKnown.sext(BitWidth);
    break;
  }
  case Instruction::Add: {
    if ((DemandedMask & 1) == 0) {
      // If we do not need the low bit, try to convert bool math to logic:
      // add iN (zext i1 X), (sext i1 Y) --> sext (~X & Y) to iN
      Value *X, *Y;
      if (match(I, m_c_Add(m_OneUse(m_ZExt(m_Value(X))),
                           m_OneUse(m_SExt(m_Value(Y))))) &&
          X->getType()->isIntOrIntVectorTy(1) && X->getType() == Y->getType()) {
        // Truth table for inputs and output signbits:
        //       X:0 | X:1
        //      ----------
        // Y:0  |  0 | 0 |
        // Y:1  | -1 | 0 |
        //      ----------
        IRBuilderBase::InsertPointGuard Guard(Builder);
        Builder.SetInsertPoint(I);
        Value *AndNot = Builder.CreateAnd(Builder.CreateNot(X), Y);
        return Builder.CreateSExt(AndNot, VTy);
      }

      // add iN (sext i1 X), (sext i1 Y) --> sext (X | Y) to iN
      if (match(I, m_Add(m_SExt(m_Value(X)), m_SExt(m_Value(Y)))) &&
          X->getType()->isIntOrIntVectorTy(1) && X->getType() == Y->getType() &&
          (I->getOperand(0)->hasOneUse() || I->getOperand(1)->hasOneUse())) {

        // Truth table for inputs and output signbits:
        //       X:0 | X:1
        //      -----------
        // Y:0  | -1 | -1 |
        // Y:1  | -1 |  0 |
        //      -----------
        IRBuilderBase::InsertPointGuard Guard(Builder);
        Builder.SetInsertPoint(I);
        Value *Or = Builder.CreateOr(X, Y);
        return Builder.CreateSExt(Or, VTy);
      }
    }

    // Right fill the mask of bits for the operands to demand the most
    // significant bit and all those below it.
    unsigned NLZ = DemandedMask.countl_zero();
    APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ);
    if (ShrinkDemandedConstant(I, 1, DemandedFromOps) ||
        SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1, Q))
      return disableWrapFlagsBasedOnUnusedHighBits(I, NLZ);

    // If low order bits are not demanded and known to be zero in one operand,
    // then we don't need to demand them from the other operand, since they
    // can't cause overflow into any bits that are demanded in the result.
    unsigned NTZ = (~DemandedMask & RHSKnown.Zero).countr_one();
    APInt DemandedFromLHS = DemandedFromOps;
    DemandedFromLHS.clearLowBits(NTZ);
    if (ShrinkDemandedConstant(I, 0, DemandedFromLHS) ||
        SimplifyDemandedBits(I, 0, DemandedFromLHS, LHSKnown, Depth + 1, Q))
      return disableWrapFlagsBasedOnUnusedHighBits(I, NLZ);

    // If we are known to be adding zeros to every bit below
    // the highest demanded bit, we just return the other side.
    if (DemandedFromOps.isSubsetOf(RHSKnown.Zero))
      return I->getOperand(0);
    if (DemandedFromOps.isSubsetOf(LHSKnown.Zero))
      return I->getOperand(1);

    // (add X, C) --> (xor X, C) IFF C is equal to the top bit of the
    // DemandedMask.
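    // E.g. (illustrative): with DemandedMask = 0x0F, `add i8 %x, 8` becomes
    // `xor i8 %x, 8`: adding 8 flips bit 3, and any carry out of bit 3 only
    // reaches bits that are not demanded.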
    {
      const APInt *C;
      if (match(I->getOperand(1), m_APInt(C)) &&
          C->isOneBitSet(DemandedMask.getActiveBits() - 1)) {
        IRBuilderBase::InsertPointGuard Guard(Builder);
        Builder.SetInsertPoint(I);
        return Builder.CreateXor(I->getOperand(0), ConstantInt::get(VTy, *C));
      }
    }

    // Otherwise just compute the known bits of the result.
    bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap();
    bool NUW = cast<OverflowingBinaryOperator>(I)->hasNoUnsignedWrap();
    Known = KnownBits::add(LHSKnown, RHSKnown, NSW, NUW);
    break;
  }
  case Instruction::Sub: {
    // Right fill the mask of bits for the operands to demand the most
    // significant bit and all those below it.
    unsigned NLZ = DemandedMask.countl_zero();
    APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ);
    if (ShrinkDemandedConstant(I, 1, DemandedFromOps) ||
        SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1, Q))
      return disableWrapFlagsBasedOnUnusedHighBits(I, NLZ);

    // If low order bits are not demanded and are known to be zero in RHS,
    // then we don't need to demand them from LHS, since they can't cause a
    // borrow from any bits that are demanded in the result.
    unsigned NTZ = (~DemandedMask & RHSKnown.Zero).countr_one();
    APInt DemandedFromLHS = DemandedFromOps;
    DemandedFromLHS.clearLowBits(NTZ);
    if (ShrinkDemandedConstant(I, 0, DemandedFromLHS) ||
        SimplifyDemandedBits(I, 0, DemandedFromLHS, LHSKnown, Depth + 1, Q))
      return disableWrapFlagsBasedOnUnusedHighBits(I, NLZ);

    // If we are known to be subtracting zeros from every bit below
    // the highest demanded bit, we just return the other side.
    if (DemandedFromOps.isSubsetOf(RHSKnown.Zero))
      return I->getOperand(0);
    // We can't do this with the LHS for subtraction, unless we are only
    // demanding the LSB.
    if (DemandedFromOps.isOne() && DemandedFromOps.isSubsetOf(LHSKnown.Zero))
      return I->getOperand(1);

    // Canonicalize sub mask, X -> ~X
    const APInt *LHSC;
    if (match(I->getOperand(0), m_LowBitMask(LHSC)) &&
        DemandedFromOps.isSubsetOf(*LHSC)) {
      IRBuilderBase::InsertPointGuard Guard(Builder);
      Builder.SetInsertPoint(I);
      return Builder.CreateNot(I->getOperand(1));
    }

    // Otherwise just compute the known bits of the result.
    bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap();
    bool NUW = cast<OverflowingBinaryOperator>(I)->hasNoUnsignedWrap();
    Known = KnownBits::sub(LHSKnown, RHSKnown, NSW, NUW);
    break;
  }
  case Instruction::Mul: {
    APInt DemandedFromOps;
    if (simplifyOperandsBasedOnUnusedHighBits(DemandedFromOps))
      return I;

    if (DemandedMask.isPowerOf2()) {
      // The LSB of X*Y is set only if (X & 1) == 1 and (Y & 1) == 1.
      // If we demand exactly one bit N and we have "X * (C' << N)" where C' is
      // odd (has LSB set), then the left-shifted low bit of X is the answer.
      unsigned CTZ = DemandedMask.countr_zero();
      const APInt *C;
      if (match(I->getOperand(1), m_APInt(C)) && C->countr_zero() == CTZ) {
        Constant *ShiftC = ConstantInt::get(VTy, CTZ);
        Instruction *Shl = BinaryOperator::CreateShl(I->getOperand(0), ShiftC);
        return InsertNewInstWith(Shl, I->getIterator());
      }
    }
    // For a squared value "X * X", the bottom 2 bits are 0 and X[0] because:
    // X * X is odd iff X is odd.
    // 'Quadratic Reciprocity': X * X -> 0 for bit[1]
    if (I->getOperand(0) == I->getOperand(1) && DemandedMask.ult(4)) {
      Constant *One = ConstantInt::get(VTy, 1);
      Instruction *And1 = BinaryOperator::CreateAnd(I->getOperand(0), One);
      return InsertNewInstWith(And1, I->getIterator());
    }

    llvm::computeKnownBits(I, Known, Depth, Q);
    break;
  }
  case Instruction::Shl: {
    const APInt *SA;
    if (match(I->getOperand(1), m_APInt(SA))) {
      const APInt *ShrAmt;
      if (match(I->getOperand(0), m_Shr(m_Value(), m_APInt(ShrAmt))))
        if (Instruction *Shr = dyn_cast<Instruction>(I->getOperand(0)))
          if (Value *R = simplifyShrShlDemandedBits(Shr, *ShrAmt, I, *SA,
                                                    DemandedMask, Known))
            return R;

      // Do not simplify if the shl is part of a funnel-shift pattern.
      if (I->hasOneUse()) {
        auto *Inst = dyn_cast<Instruction>(I->user_back());
        if (Inst && Inst->getOpcode() == BinaryOperator::Or) {
          if (auto Opt = convertOrOfShiftsToFunnelShift(*Inst)) {
            auto [IID, FShiftArgs] = *Opt;
            if ((IID == Intrinsic::fshl || IID == Intrinsic::fshr) &&
                FShiftArgs[0] == FShiftArgs[1]) {
              llvm::computeKnownBits(I, Known, Depth, Q);
              break;
            }
          }
        }
      }

      // If we only want bits that already match the sign bit, then we don't
      // need to shift.
      uint64_t ShiftAmt = SA->getLimitedValue(BitWidth - 1);
      if (DemandedMask.countr_zero() >= ShiftAmt) {
        if (I->hasNoSignedWrap()) {
          unsigned NumHiDemandedBits = BitWidth - DemandedMask.countr_zero();
          unsigned SignBits =
              ComputeNumSignBits(I->getOperand(0), Depth + 1, Q.CxtI);
          if (SignBits > ShiftAmt && SignBits - ShiftAmt >= NumHiDemandedBits)
            return I->getOperand(0);
        }

        // If we can pre-shift a right-shifted constant to the left without
        // losing any high bits and we don't demand the low bits, then eliminate
        // the left-shift:
        // (C >> X) << LeftShiftAmtC --> (C << LeftShiftAmtC) >> X
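        // E.g. (illustrative): with the low 4 bits undemanded,
        // ((15 >> X) << 4) --> 240 >> X, since (240 >> 4) round-trips to 15
        // and both forms agree on all demanded bits.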
        Value *X;
        Constant *C;
        if (match(I->getOperand(0), m_LShr(m_ImmConstant(C), m_Value(X)))) {
          Constant *LeftShiftAmtC = ConstantInt::get(VTy, ShiftAmt);
          Constant *NewC = ConstantFoldBinaryOpOperands(Instruction::Shl, C,
                                                        LeftShiftAmtC, DL);
          if (ConstantFoldBinaryOpOperands(Instruction::LShr, NewC,
                                           LeftShiftAmtC, DL) == C) {
            Instruction *Lshr = BinaryOperator::CreateLShr(NewC, X);
            return InsertNewInstWith(Lshr, I->getIterator());
          }
        }
      }

      APInt DemandedMaskIn(DemandedMask.lshr(ShiftAmt));

      // If the shift is NUW/NSW, then it does demand the high bits.
      ShlOperator *IOp = cast<ShlOperator>(I);
      if (IOp->hasNoSignedWrap())
        DemandedMaskIn.setHighBits(ShiftAmt + 1);
      else if (IOp->hasNoUnsignedWrap())
        DemandedMaskIn.setHighBits(ShiftAmt);

      if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1, Q))
        return I;

      Known = KnownBits::shl(Known,
                             KnownBits::makeConstant(APInt(BitWidth, ShiftAmt)),
                             /* NUW */ IOp->hasNoUnsignedWrap(),
                             /* NSW */ IOp->hasNoSignedWrap());
    } else {
      // This is a variable shift, so we can't shift the demand mask by a known
      // amount. But if we are not demanding high bits, then we are not
      // demanding those bits from the pre-shifted operand either.
      if (unsigned CTLZ = DemandedMask.countl_zero()) {
        APInt DemandedFromOp(APInt::getLowBitsSet(BitWidth, BitWidth - CTLZ));
        if (SimplifyDemandedBits(I, 0, DemandedFromOp, Known, Depth + 1, Q)) {
          // We can't guarantee that nsw/nuw hold after simplifying the operand.
          I->dropPoisonGeneratingFlags();
          return I;
        }
      }
      llvm::computeKnownBits(I, Known, Depth, Q);
    }
    break;
  }
  case Instruction::LShr: {
    const APInt *SA;
    if (match(I->getOperand(1), m_APInt(SA))) {
      uint64_t ShiftAmt = SA->getLimitedValue(BitWidth - 1);

      // Do not simplify if the lshr is part of a funnel-shift pattern.
      if (I->hasOneUse()) {
        auto *Inst = dyn_cast<Instruction>(I->user_back());
        if (Inst && Inst->getOpcode() == BinaryOperator::Or) {
          if (auto Opt = convertOrOfShiftsToFunnelShift(*Inst)) {
            auto [IID, FShiftArgs] = *Opt;
            if ((IID == Intrinsic::fshl || IID == Intrinsic::fshr) &&
                FShiftArgs[0] == FShiftArgs[1]) {
              llvm::computeKnownBits(I, Known, Depth, Q);
              break;
            }
          }
        }
      }

      // If we are just demanding the shifted sign bit and below, then this can
      // be treated as an ASHR in disguise.
      if (DemandedMask.countl_zero() >= ShiftAmt) {
        // If we only want bits that already match the sign bit, then we don't
        // need to shift.
        unsigned NumHiDemandedBits = BitWidth - DemandedMask.countr_zero();
        unsigned SignBits =
            ComputeNumSignBits(I->getOperand(0), Depth + 1, Q.CxtI);
        if (SignBits >= NumHiDemandedBits)
          return I->getOperand(0);

        // If we can pre-shift a left-shifted constant to the right without
        // losing any low bits (we already know we don't demand the high bits),
        // then eliminate the right-shift:
        // (C << X) >> RightShiftAmtC --> (C >> RightShiftAmtC) << X
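        // E.g. (illustrative): with the high 4 bits undemanded,
        // ((240 << X) >> 4) --> 15 << X, since (15 << 4) round-trips to 240
        // and both forms agree on all demanded bits.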
        Value *X;
        Constant *C;
        if (match(I->getOperand(0), m_Shl(m_ImmConstant(C), m_Value(X)))) {
          Constant *RightShiftAmtC = ConstantInt::get(VTy, ShiftAmt);
          Constant *NewC = ConstantFoldBinaryOpOperands(Instruction::LShr, C,
                                                        RightShiftAmtC, DL);
          if (ConstantFoldBinaryOpOperands(Instruction::Shl, NewC,
                                           RightShiftAmtC, DL) == C) {
            Instruction *Shl = BinaryOperator::CreateShl(NewC, X);
            return InsertNewInstWith(Shl, I->getIterator());
          }
        }

        const APInt *Factor;
        if (match(I->getOperand(0),
                  m_OneUse(m_Mul(m_Value(X), m_APInt(Factor)))) &&
            Factor->countr_zero() >= ShiftAmt) {
          BinaryOperator *Mul = BinaryOperator::CreateMul(
              X, ConstantInt::get(X->getType(), Factor->lshr(ShiftAmt)));
          return InsertNewInstWith(Mul, I->getIterator());
        }
      }

      // Unsigned shift right.
      APInt DemandedMaskIn(DemandedMask.shl(ShiftAmt));
      if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1, Q)) {
        // The exact flag may no longer hold.
        I->dropPoisonGeneratingFlags();
        return I;
      }
      Known.Zero.lshrInPlace(ShiftAmt);
      Known.One.lshrInPlace(ShiftAmt);
      if (ShiftAmt)
        Known.Zero.setHighBits(ShiftAmt); // High bits are known zero.
    } else {
      llvm::computeKnownBits(I, Known, Depth, Q);
    }
    break;
  }
  case Instruction::AShr: {
    unsigned SignBits = ComputeNumSignBits(I->getOperand(0), Depth + 1, Q.CxtI);

    // If we only want bits that already match the sign bit, then we don't need
    // to shift.
    unsigned NumHiDemandedBits = BitWidth - DemandedMask.countr_zero();
    if (SignBits >= NumHiDemandedBits)
      return I->getOperand(0);

    // If this is an arithmetic shift right and only the low bit is demanded,
    // we can always convert this into a logical shr, even if the shift amount
    // is variable. The low bit of the shift cannot be an input sign bit unless
    // the shift amount is >= the size of the datatype, which is undefined.
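    // E.g. (illustrative): if only bit 0 of `ashr i8 %x, %n` is demanded, it
    // is rewritten to `lshr i8 %x, %n`; the two differ only in the copies of
    // the sign bit shifted in at the high end.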
    if (DemandedMask.isOne()) {
      // Perform the logical shift right.
      Instruction *NewVal = BinaryOperator::CreateLShr(
          I->getOperand(0), I->getOperand(1), I->getName());
      return InsertNewInstWith(NewVal, I->getIterator());
    }

    const APInt *SA;
    if (match(I->getOperand(1), m_APInt(SA))) {
      uint32_t ShiftAmt = SA->getLimitedValue(BitWidth - 1);

      // Signed shift right.
      APInt DemandedMaskIn(DemandedMask.shl(ShiftAmt));
      // If any of the bits being shifted in are demanded, then we should set
      // the sign bit as demanded.
      bool ShiftedInBitsDemanded = DemandedMask.countl_zero() < ShiftAmt;
      if (ShiftedInBitsDemanded)
        DemandedMaskIn.setSignBit();
      if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1, Q)) {
        // The exact flag may no longer hold.
        I->dropPoisonGeneratingFlags();
        return I;
      }

      // If the input sign bit is known to be zero, or if none of the shifted-in
      // bits are demanded, turn this into an unsigned shift right.
      if (Known.Zero[BitWidth - 1] || !ShiftedInBitsDemanded) {
        BinaryOperator *LShr = BinaryOperator::CreateLShr(I->getOperand(0),
                                                          I->getOperand(1));
        LShr->setIsExact(cast<BinaryOperator>(I)->isExact());
        LShr->takeName(I);
        return InsertNewInstWith(LShr, I->getIterator());
      }

      Known = KnownBits::ashr(
          Known, KnownBits::makeConstant(APInt(BitWidth, ShiftAmt)),
          ShiftAmt != 0, I->isExact());
    } else {
      llvm::computeKnownBits(I, Known, Depth, Q);
    }
    break;
  }
  case Instruction::UDiv: {
    // UDiv doesn't demand low bits that are zero in the divisor.
    const APInt *SA;
    if (match(I->getOperand(1), m_APInt(SA))) {
      // TODO: Take the demanded mask of the result into account.
      unsigned RHSTrailingZeros = SA->countr_zero();
      APInt DemandedMaskIn =
          APInt::getHighBitsSet(BitWidth, BitWidth - RHSTrailingZeros);
      if (SimplifyDemandedBits(I, 0, DemandedMaskIn, LHSKnown, Depth + 1, Q)) {
        // We can't guarantee that "exact" is still true after changing the
        // dividend.
        I->dropPoisonGeneratingFlags();
        return I;
      }

      Known = KnownBits::udiv(LHSKnown, KnownBits::makeConstant(*SA),
                              cast<BinaryOperator>(I)->isExact());
    } else {
      llvm::computeKnownBits(I, Known, Depth, Q);
    }
    break;
  }
  case Instruction::SRem: {
    const APInt *Rem;
    if (match(I->getOperand(1), m_APInt(Rem)) && Rem->isPowerOf2()) {
      if (DemandedMask.ult(*Rem)) // srem won't affect demanded bits
        return I->getOperand(0);

      APInt LowBits = *Rem - 1;
      APInt Mask2 = LowBits | APInt::getSignMask(BitWidth);
      if (SimplifyDemandedBits(I, 0, Mask2, LHSKnown, Depth + 1, Q))
        return I;
      Known = KnownBits::srem(LHSKnown, KnownBits::makeConstant(*Rem));
      break;
    }

    llvm::computeKnownBits(I, Known, Depth, Q);
    break;
  }
  case Instruction::Call: {
    bool KnownBitsComputed = false;
    if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
      switch (II->getIntrinsicID()) {
      case Intrinsic::abs: {
        if (DemandedMask == 1)
          return II->getArgOperand(0);
        break;
      }
      case Intrinsic::ctpop: {
        // Checking if the number of clear bits is odd (parity)? If the type has
        // an even number of bits, that's the same as checking if the number of
        // set bits is odd, so we can eliminate the 'not' op.
        Value *X;
        if (DemandedMask == 1 && VTy->getScalarSizeInBits() % 2 == 0 &&
            match(II->getArgOperand(0), m_Not(m_Value(X)))) {
          Function *Ctpop = Intrinsic::getOrInsertDeclaration(
              II->getModule(), Intrinsic::ctpop, VTy);
          return InsertNewInstWith(CallInst::Create(Ctpop, {X}),
                                   I->getIterator());
        }
        break;
      }
      case Intrinsic::bswap: {
        // If the only bits demanded come from one byte of the bswap result,
        // just shift the input byte into position to eliminate the bswap.
        unsigned NLZ = DemandedMask.countl_zero();
        unsigned NTZ = DemandedMask.countr_zero();

        // Round NTZ down to the next byte. If we have 11 trailing zeros, then
        // we need all the bits down to bit 8. Likewise, round NLZ. If we
        // have 14 leading zeros, round to 8.
        NLZ = alignDown(NLZ, 8);
        NTZ = alignDown(NTZ, 8);
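        // E.g. (illustrative): for `bswap i32 %x` with DemandedMask =
        // 0x0000FF00 (NLZ = 16, NTZ = 8), the demanded byte comes from bits
        // 16..23 of %x, so the bswap becomes `lshr i32 %x, 8`.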
        // If we need exactly one byte, we can do this transformation.
        if (BitWidth - NLZ - NTZ == 8) {
          // Replace this with either a left or right shift to get the byte into
          // the right place.
          Instruction *NewVal;
          if (NLZ > NTZ)
            NewVal = BinaryOperator::CreateLShr(
                II->getArgOperand(0), ConstantInt::get(VTy, NLZ - NTZ));
          else
            NewVal = BinaryOperator::CreateShl(
                II->getArgOperand(0), ConstantInt::get(VTy, NTZ - NLZ));
          NewVal->takeName(I);
          return InsertNewInstWith(NewVal, I->getIterator());
        }
        break;
      }
      case Intrinsic::ptrmask: {
        unsigned MaskWidth = I->getOperand(1)->getType()->getScalarSizeInBits();
        RHSKnown = KnownBits(MaskWidth);
        // If either the LHS or the RHS are Zero, the result is zero.
        if (SimplifyDemandedBits(I, 0, DemandedMask, LHSKnown, Depth + 1, Q) ||
            SimplifyDemandedBits(
                I, 1, (DemandedMask & ~LHSKnown.Zero).zextOrTrunc(MaskWidth),
                RHSKnown, Depth + 1, Q))
          return I;

        // TODO: Should be 1-extend
        RHSKnown = RHSKnown.anyextOrTrunc(BitWidth);

        Known = LHSKnown & RHSKnown;
        KnownBitsComputed = true;

        // If the client is only demanding bits we know to be zero, return
        // `llvm.ptrmask(p, 0)`. We can't return `null` here due to pointer
        // provenance, but making the mask zero will be easily optimizable in
        // the backend.
        if (DemandedMask.isSubsetOf(Known.Zero) &&
            !match(I->getOperand(1), m_Zero()))
          return replaceOperand(
              *I, 1, Constant::getNullValue(I->getOperand(1)->getType()));

        // If the mask is all-ones over the demanded bits (excluding bits
        // already known zero in the pointer), the ptrmask does nothing.
        // NOTE: We may have attributes associated with the return value of the
        // llvm.ptrmask intrinsic that will be lost when we just return the
        // operand. We should try to preserve them.
        if (DemandedMask.isSubsetOf(RHSKnown.One | LHSKnown.Zero))
          return I->getOperand(0);

        // If the RHS is a constant, see if we can simplify it.
        if (ShrinkDemandedConstant(
                I, 1, (DemandedMask & ~LHSKnown.Zero).zextOrTrunc(MaskWidth)))
          return I;

        // Combine:
        // (ptrmask (getelementptr i8, ptr p, imm i), imm mask)
        //   -> (ptrmask (getelementptr i8, ptr p, imm (i & mask)), imm mask)
        // where only the low bits known to be zero in the pointer are changed.
        Value *InnerPtr;
        uint64_t GEPIndex;
        uint64_t PtrMaskImmediate;
        if (match(I, m_Intrinsic<Intrinsic::ptrmask>(
                         m_PtrAdd(m_Value(InnerPtr), m_ConstantInt(GEPIndex)),
                         m_ConstantInt(PtrMaskImmediate)))) {
          LHSKnown = computeKnownBits(InnerPtr, Depth + 1, I);
          if (!LHSKnown.isZero()) {
            const unsigned trailingZeros = LHSKnown.countMinTrailingZeros();
            uint64_t PointerAlignBits = (uint64_t(1) << trailingZeros) - 1;

            uint64_t HighBitsGEPIndex = GEPIndex & ~PointerAlignBits;
            uint64_t MaskedLowBitsGEPIndex =
                GEPIndex & PointerAlignBits & PtrMaskImmediate;

            uint64_t MaskedGEPIndex = HighBitsGEPIndex | MaskedLowBitsGEPIndex;

            if (MaskedGEPIndex != GEPIndex) {
              auto *GEP = cast<GEPOperator>(II->getArgOperand(0));
              Builder.SetInsertPoint(I);
              Type *GEPIndexType =
                  DL.getIndexType(GEP->getPointerOperand()->getType());
              Value *MaskedGEP = Builder.CreateGEP(
                  GEP->getSourceElementType(), InnerPtr,
                  ConstantInt::get(GEPIndexType, MaskedGEPIndex),
                  GEP->getName(), GEP->isInBounds());

              replaceOperand(*I, 0, MaskedGEP);
              return I;
            }
          }
        }

        break;
      }

      case Intrinsic::fshr:
      case Intrinsic::fshl: {
        const APInt *SA;
        if (!match(I->getOperand(2), m_APInt(SA)))
          break;

        // Normalize to funnel shift left. APInt shifts of BitWidth are well-
        // defined, so no need to special-case zero shifts here.
        uint64_t ShiftAmt = SA->urem(BitWidth);
        if (II->getIntrinsicID() == Intrinsic::fshr)
          ShiftAmt = BitWidth - ShiftAmt;
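        // E.g. (illustrative): on i32, fshr(%a, %b, 8) selects the same bits
        // of the %a:%b concatenation as fshl(%a, %b, 24), so only the
        // left-shift form needs to be handled below.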
        APInt DemandedMaskLHS(DemandedMask.lshr(ShiftAmt));
        APInt DemandedMaskRHS(DemandedMask.shl(BitWidth - ShiftAmt));
        if (I->getOperand(0) != I->getOperand(1)) {
          if (SimplifyDemandedBits(I, 0, DemandedMaskLHS, LHSKnown, Depth + 1,
                                   Q) ||
              SimplifyDemandedBits(I, 1, DemandedMaskRHS, RHSKnown, Depth + 1,
                                   Q)) {
            // The range attribute may no longer hold.
            I->dropPoisonGeneratingReturnAttributes();
            return I;
          }
        } else { // fshl is a rotate
          // Avoid converting rotate into funnel shift.
          // Only simplify if one operand is constant.
          LHSKnown = computeKnownBits(I->getOperand(0), Depth + 1, I);
          if (DemandedMaskLHS.isSubsetOf(LHSKnown.Zero | LHSKnown.One) &&
              !match(I->getOperand(0), m_SpecificInt(LHSKnown.One))) {
            replaceOperand(*I, 0, Constant::getIntegerValue(VTy, LHSKnown.One));
            return I;
          }

          RHSKnown = computeKnownBits(I->getOperand(1), Depth + 1, I);
          if (DemandedMaskRHS.isSubsetOf(RHSKnown.Zero | RHSKnown.One) &&
              !match(I->getOperand(1), m_SpecificInt(RHSKnown.One))) {
            replaceOperand(*I, 1, Constant::getIntegerValue(VTy, RHSKnown.One));
            return I;
          }
        }

        Known.Zero = LHSKnown.Zero.shl(ShiftAmt) |
                     RHSKnown.Zero.lshr(BitWidth - ShiftAmt);
        Known.One = LHSKnown.One.shl(ShiftAmt) |
                    RHSKnown.One.lshr(BitWidth - ShiftAmt);
        KnownBitsComputed = true;
        break;
      }
      case Intrinsic::umax: {
        // UMax(A, C) == A if ...
        // The lowest non-zero bit of DemandedMask is higher than the highest
        // non-zero bit of C.
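        // E.g. (illustrative): `umax(%x, 15)` with DemandedMask = 0xF0
        // returns %x: when %x <= 15 the demanded high bits are zero for both
        // values, and otherwise the umax already returns %x.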
        const APInt *C;
        unsigned CTZ = DemandedMask.countr_zero();
        if (match(II->getArgOperand(1), m_APInt(C)) &&
            CTZ >= C->getActiveBits())
          return II->getArgOperand(0);
        break;
      }
      case Intrinsic::umin: {
        // UMin(A, C) == A if ...
        // The lowest non-zero bit of DemandedMask is higher than the highest
        // non-one bit of C.
        // This comes from applying De Morgan to the umax case above.
        const APInt *C;
        unsigned CTZ = DemandedMask.countr_zero();
        if (match(II->getArgOperand(1), m_APInt(C)) &&
            CTZ >= C->getBitWidth() - C->countl_one())
          return II->getArgOperand(0);
        break;
      }
      default: {
        // Handle target-specific intrinsics.
        std::optional<Value *> V = targetSimplifyDemandedUseBitsIntrinsic(
            *II, DemandedMask, Known, KnownBitsComputed);
        if (V)
          return *V;
        break;
      }
      }
    }

    if (!KnownBitsComputed)
      llvm::computeKnownBits(I, Known, Depth, Q);
    break;
  }
  }

  if (I->getType()->isPointerTy()) {
    Align Alignment = I->getPointerAlignment(DL);
    Known.Zero.setLowBits(Log2(Alignment));
  }

  // If the client is only demanding bits that we know, return the known
  // constant. We can't directly simplify pointers as a constant because of
  // pointer provenance.
  // TODO: We could return `(inttoptr const)` for pointers.
  if (!I->getType()->isPointerTy() &&
      DemandedMask.isSubsetOf(Known.Zero | Known.One))
    return Constant::getIntegerValue(VTy, Known.One);

  if (VerifyKnownBits) {
    KnownBits ReferenceKnown = llvm::computeKnownBits(I, Depth, Q);
    if (Known != ReferenceKnown) {
      errs() << "Mismatched known bits for " << *I << " in "
             << I->getFunction()->getName() << "\n";
      errs() << "computeKnownBits(): " << ReferenceKnown << "\n";
      errs() << "SimplifyDemandedBits(): " << Known << "\n";
      std::abort();
    }
  }

  return nullptr;
}

1142 /// Helper routine of SimplifyDemandedUseBits. It computes Known
1143 /// bits. It also tries to handle simplifications that can be done based on
1144 /// DemandedMask, but without modifying the Instruction.
1145 Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits(
1146     Instruction *I, const APInt &DemandedMask, KnownBits &Known, unsigned Depth,
1147     const SimplifyQuery &Q) {
1148   unsigned BitWidth = DemandedMask.getBitWidth();
1149   Type *ITy = I->getType();
1150 
1151   KnownBits LHSKnown(BitWidth);
1152   KnownBits RHSKnown(BitWidth);
1153 
1154   // Despite the fact that we can't simplify this instruction in all User's
1155   // context, we can at least compute the known bits, and we can
1156   // do simplifications that apply to *just* the one user if we know that
1157   // this instruction has a simpler value in that context.
1158   switch (I->getOpcode()) {
1159   case Instruction::And: {
1160     llvm::computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, Q);
1161     llvm::computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, Q);
1162     Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown,
1163                                          Depth, Q);
1164     computeKnownBitsFromContext(I, Known, Depth, Q);
1165 
1166     // If the client is only demanding bits that we know, return the known
1167     // constant.
1168     if (DemandedMask.isSubsetOf(Known.Zero | Known.One))
1169       return Constant::getIntegerValue(ITy, Known.One);
1170 
1171     // If all of the demanded bits are known 1 on one side, return the other.
1172     // These bits cannot contribute to the result of the 'and' in this context.
1173     if (DemandedMask.isSubsetOf(LHSKnown.Zero | RHSKnown.One))
1174       return I->getOperand(0);
1175     if (DemandedMask.isSubsetOf(RHSKnown.Zero | LHSKnown.One))
1176       return I->getOperand(1);
1177 
1178     break;
1179   }
1180   case Instruction::Or: {
1181     llvm::computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, Q);
1182     llvm::computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, Q);
1183     Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown,
1184                                          Depth, Q);
1185     computeKnownBitsFromContext(I, Known, Depth, Q);
1186 
1187     // If the client is only demanding bits that we know, return the known
1188     // constant.
1189     if (DemandedMask.isSubsetOf(Known.Zero | Known.One))
1190       return Constant::getIntegerValue(ITy, Known.One);
1191 
1192     // We can simplify (X|Y) -> X or Y in the user's context if we know that
1193     // only bits from X or Y are demanded.
1194     // If all of the demanded bits are known zero on one side, return the other.
1195     // These bits cannot contribute to the result of the 'or' in this context.
1196     if (DemandedMask.isSubsetOf(LHSKnown.One | RHSKnown.Zero))
1197       return I->getOperand(0);
1198     if (DemandedMask.isSubsetOf(RHSKnown.One | LHSKnown.Zero))
1199       return I->getOperand(1);
1200 
1201     break;
1202   }
1203   case Instruction::Xor: {
1204     llvm::computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, Q);
1205     llvm::computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, Q);
1206     Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown,
1207                                          Depth, Q);
1208     computeKnownBitsFromContext(I, Known, Depth, Q);
1209 
1210     // If the client is only demanding bits that we know, return the known
1211     // constant.
1212     if (DemandedMask.isSubsetOf(Known.Zero | Known.One))
1213       return Constant::getIntegerValue(ITy, Known.One);
1214 
1215     // We can simplify (X^Y) -> X or Y in the user's context if we know that
1216     // only bits from X or Y are demanded.
1217     // If all of the demanded bits are known zero on one side, return the other.
1218     if (DemandedMask.isSubsetOf(RHSKnown.Zero))
1219       return I->getOperand(0);
1220     if (DemandedMask.isSubsetOf(LHSKnown.Zero))
1221       return I->getOperand(1);
1222 
1223     break;
1224   }
1225   case Instruction::Add: {
1226     unsigned NLZ = DemandedMask.countl_zero();
1227     APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ);
1228 
1229     // If an operand adds zeros to every bit below the highest demanded bit,
1230     // that operand doesn't change the result. Return the other side.
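         // For example, for an i32 expression (X + 0x100) with DemandedMask =
         // 0xff, the constant is known to add zeros to all eight demanded low
         // bits, so X (operand 0) is returned.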
1231     llvm::computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, Q);
1232     if (DemandedFromOps.isSubsetOf(RHSKnown.Zero))
1233       return I->getOperand(0);
1234 
1235     llvm::computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, Q);
1236     if (DemandedFromOps.isSubsetOf(LHSKnown.Zero))
1237       return I->getOperand(1);
1238 
1239     bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap();
1240     bool NUW = cast<OverflowingBinaryOperator>(I)->hasNoUnsignedWrap();
1241     Known = KnownBits::add(LHSKnown, RHSKnown, NSW, NUW);
1242     computeKnownBitsFromContext(I, Known, Depth, Q);
1243     break;
1244   }
1245   case Instruction::Sub: {
1246     unsigned NLZ = DemandedMask.countl_zero();
1247     APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ);
1248 
1249     // If an operand subtracts zeros from every bit below the highest demanded
1250     // bit, that operand doesn't change the result. Return the other side.
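         // For example, for an i32 expression (X - 0x100) with DemandedMask =
         // 0xff, the low eight bits of the subtrahend are known zero, so the
         // demanded bits of X are unchanged and X (operand 0) is returned.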
1251     llvm::computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, Q);
1252     if (DemandedFromOps.isSubsetOf(RHSKnown.Zero))
1253       return I->getOperand(0);
1254 
1255     bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap();
1256     bool NUW = cast<OverflowingBinaryOperator>(I)->hasNoUnsignedWrap();
1257     llvm::computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, Q);
1258     Known = KnownBits::sub(LHSKnown, RHSKnown, NSW, NUW);
1259     computeKnownBitsFromContext(I, Known, Depth, Q);
1260     break;
1261   }
1262   case Instruction::AShr: {
1263     // Compute the Known bits to simplify things downstream.
1264     llvm::computeKnownBits(I, Known, Depth, Q);
1265 
1266     // If this user is only demanding bits that we know, return the known
1267     // constant.
1268     if (DemandedMask.isSubsetOf(Known.Zero | Known.One))
1269       return Constant::getIntegerValue(ITy, Known.One);
1270 
1271     // If operand 0 of the right shift is the result of a left shift by the
1272     // same amount, this is probably a zero/sign extension, which is
1273     // unnecessary if we do not demand any of the new sign bits. In that
1274     // case, return the original operand instead.
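         // For example, i32 (ashr (shl X, 8), 8) sign-extends bit 23 of X into
         // the top byte; if DemandedMask only covers bits 0-23, X itself can
         // be returned.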
1275     const APInt *ShiftRC;
1276     const APInt *ShiftLC;
1277     Value *X;
1278     unsigned BitWidth = DemandedMask.getBitWidth();
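         // Comparing the APInt pointers below relies on constant uniquing:
         // m_APInt binds to the APInt stored in a uniqued ConstantInt, so
         // equal shift amounts yield equal pointers.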
1279     if (match(I,
1280               m_AShr(m_Shl(m_Value(X), m_APInt(ShiftLC)), m_APInt(ShiftRC))) &&
1281         ShiftLC == ShiftRC && ShiftLC->ult(BitWidth) &&
1282         DemandedMask.isSubsetOf(APInt::getLowBitsSet(
1283             BitWidth, BitWidth - ShiftRC->getZExtValue()))) {
1284       return X;
1285     }
1286 
1287     break;
1288   }
1289   default:
1290     // Compute the Known bits to simplify things downstream.
1291     llvm::computeKnownBits(I, Known, Depth, Q);
1292 
1293     // If this user is only demanding bits that we know, return the known
1294     // constant.
1295     if (DemandedMask.isSubsetOf(Known.Zero | Known.One))
1296       return Constant::getIntegerValue(ITy, Known.One);
1297 
1298     break;
1299   }
1300 
1301   return nullptr;
1302 }
1303 
1304 /// Helper routine of SimplifyDemandedUseBits. It tries to simplify
1305 /// "E1 = (X lsr C1) << C2", where C1 and C2 are constants, into
1306 /// "E2 = X << (C2 - C1)" or "E2 = X >> (C1 - C2)", depending on the sign
1307 /// of "C2-C1".
1308 ///
1309 /// Suppose E1 and E2 may differ in the set of bits S = {bm, bm+1,
1310 /// ..., bn}, regardless of the specific value that X holds.
1311 /// This transformation is legal iff one of the following conditions holds:
1312 ///  1) All the bits in S are 0; in this case E1 == E2.
1313 ///  2) We don't care about the bits in S, per the input DemandedMask.
1314 ///  3) A combination of 1) and 2): some bits in S are 0, and we don't care
1315 ///     about the rest.
1316 ///
1317 /// Currently we only test condition 2).
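     ///
     /// For example, with C1 = 2 and C2 = 4, E1 = (X lsr 2) << 4 and
     /// E2 = X << 2 agree in every bit except bits {2, 3}: those bits are zero
     /// in E1 but hold bits {0, 1} of X in E2, so the rewrite is legal
     /// whenever DemandedMask has bits 2 and 3 clear.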
1318 ///
1319 /// As with SimplifyDemandedUseBits, it returns NULL if the simplification was
1320 /// not successful.
1321 Value *InstCombinerImpl::simplifyShrShlDemandedBits(
1322     Instruction *Shr, const APInt &ShrOp1, Instruction *Shl,
1323     const APInt &ShlOp1, const APInt &DemandedMask, KnownBits &Known) {
1324   if (!ShlOp1 || !ShrOp1)
1325     return nullptr; // No-op.
1326 
1327   Value *VarX = Shr->getOperand(0);
1328   Type *Ty = VarX->getType();
1329   unsigned BitWidth = Ty->getScalarSizeInBits();
1330   if (ShlOp1.uge(BitWidth) || ShrOp1.uge(BitWidth))
1331     return nullptr; // Undef.
1332 
1333   unsigned ShlAmt = ShlOp1.getZExtValue();
1334   unsigned ShrAmt = ShrOp1.getZExtValue();
1335 
1336   Known.One.clearAllBits();
1337   Known.Zero.setLowBits(ShlAmt - 1);
1338   Known.Zero &= DemandedMask;
1339 
1340   APInt BitMask1(APInt::getAllOnes(BitWidth));
1341   APInt BitMask2(APInt::getAllOnes(BitWidth));
1342 
1343   bool isLshr = (Shr->getOpcode() == Instruction::LShr);
1344   BitMask1 = isLshr ? (BitMask1.lshr(ShrAmt) << ShlAmt) :
1345                       (BitMask1.ashr(ShrAmt) << ShlAmt);
1346 
1347   if (ShrAmt <= ShlAmt) {
1348     BitMask2 <<= (ShlAmt - ShrAmt);
1349   } else {
1350     BitMask2 = isLshr ? BitMask2.lshr(ShrAmt - ShlAmt):
1351                         BitMask2.ashr(ShrAmt - ShlAmt);
1352   }
1353 
1354   // Check if condition 2 (see the comment on this function) is satisfied.
1355   if ((BitMask1 & DemandedMask) == (BitMask2 & DemandedMask)) {
1356     if (ShrAmt == ShlAmt)
1357       return VarX;
1358 
1359     if (!Shr->hasOneUse())
1360       return nullptr;
1361 
1362     BinaryOperator *New;
1363     if (ShrAmt < ShlAmt) {
1364       Constant *Amt = ConstantInt::get(VarX->getType(), ShlAmt - ShrAmt);
1365       New = BinaryOperator::CreateShl(VarX, Amt);
1366       BinaryOperator *Orig = cast<BinaryOperator>(Shl);
1367       New->setHasNoSignedWrap(Orig->hasNoSignedWrap());
1368       New->setHasNoUnsignedWrap(Orig->hasNoUnsignedWrap());
1369     } else {
1370       Constant *Amt = ConstantInt::get(VarX->getType(), ShrAmt - ShlAmt);
1371       New = isLshr ? BinaryOperator::CreateLShr(VarX, Amt) :
1372                      BinaryOperator::CreateAShr(VarX, Amt);
1373       if (cast<BinaryOperator>(Shr)->isExact())
1374         New->setIsExact(true);
1375     }
1376 
1377     return InsertNewInstWith(New, Shl->getIterator());
1378   }
1379 
1380   return nullptr;
1381 }
1382 
1383 /// The specified value produces a vector with any number of elements.
1384 /// This method analyzes which elements of the value are poison and
1385 /// returns that information in PoisonElts.
1386 ///
1387 /// DemandedElts contains the set of elements that are actually used by the
1388 /// caller, and by default (AllowMultipleUsers equals false) the value is
1389 /// simplified only if it has a single user. If AllowMultipleUsers is set
1390 /// to true, DemandedElts refers to the union of the sets of elements that
1391 /// are used by all users.
1392 ///
1393 /// If the information about the demanded elements can be used to simplify
1394 /// the operation, the operation is simplified and the resulting value is
1395 /// returned. This returns null if no change was made.
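     ///
     /// For example, if only element 0 of a constant vector is demanded, the
     /// other elements can be replaced with poison and reported back in
     /// PoisonElts.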
1396 Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V,
1397                                                     APInt DemandedElts,
1398                                                     APInt &PoisonElts,
1399                                                     unsigned Depth,
1400                                                     bool AllowMultipleUsers) {
1401   // Cannot analyze scalable type. The number of vector elements is not a
1402   // compile-time constant.
1403   if (isa<ScalableVectorType>(V->getType()))
1404     return nullptr;
1405 
1406   unsigned VWidth = cast<FixedVectorType>(V->getType())->getNumElements();
1407   APInt EltMask(APInt::getAllOnes(VWidth));
1408   assert((DemandedElts & ~EltMask) == 0 && "Invalid DemandedElts!");
1409 
1410   if (match(V, m_Poison())) {
1411     // If the entire vector is poison, just return this info.
1412     PoisonElts = EltMask;
1413     return nullptr;
1414   }
1415 
1416   if (DemandedElts.isZero()) { // If nothing is demanded, provide poison.
1417     PoisonElts = EltMask;
1418     return PoisonValue::get(V->getType());
1419   }
1420 
1421   PoisonElts = 0;
1422 
1423   if (auto *C = dyn_cast<Constant>(V)) {
1424     // Check if this demands the identity (all elements). If so, return null
1425     // since we are not simplifying anything.
1426     if (DemandedElts.isAllOnes())
1427       return nullptr;
1428 
1429     Type *EltTy = cast<VectorType>(V->getType())->getElementType();
1430     Constant *Poison = PoisonValue::get(EltTy);
1431     SmallVector<Constant*, 16> Elts;
1432     for (unsigned i = 0; i != VWidth; ++i) {
1433       if (!DemandedElts[i]) {   // If not demanded, set to poison.
1434         Elts.push_back(Poison);
1435         PoisonElts.setBit(i);
1436         continue;
1437       }
1438 
1439       Constant *Elt = C->getAggregateElement(i);
1440       if (!Elt) return nullptr;
1441 
1442       Elts.push_back(Elt);
1443       if (isa<PoisonValue>(Elt)) // Already poison.
1444         PoisonElts.setBit(i);
1445     }
1446 
1447     // If we changed the constant, return it.
1448     Constant *NewCV = ConstantVector::get(Elts);
1449     return NewCV != C ? NewCV : nullptr;
1450   }
1451 
1452   // Limit search depth.
1453   if (Depth == SimplifyDemandedVectorEltsDepthLimit)
1454     return nullptr;
1455 
1456   if (!AllowMultipleUsers) {
1457     // If multiple users are using the root value, proceed with
1458     // simplification conservatively assuming that all elements
1459     // are needed.
1460     if (!V->hasOneUse()) {
1461       // Quit if we find multiple users of a non-root value though.
1462       // They'll be handled when it's their turn to be visited by
1463       // the main instcombine process.
1464       if (Depth != 0)
1465         // TODO: Just compute the PoisonElts information recursively.
1466         return nullptr;
1467 
1468       // Conservatively assume that all elements are needed.
1469       DemandedElts = EltMask;
1470     }
1471   }
1472 
1473   Instruction *I = dyn_cast<Instruction>(V);
1474   if (!I) return nullptr;        // Only analyze instructions.
1475 
1476   bool MadeChange = false;
1477   auto simplifyAndSetOp = [&](Instruction *Inst, unsigned OpNum,
1478                               APInt Demanded, APInt &Poison) {
1479     auto *II = dyn_cast<IntrinsicInst>(Inst);
1480     Value *Op = II ? II->getArgOperand(OpNum) : Inst->getOperand(OpNum);
1481     if (Value *V = SimplifyDemandedVectorElts(Op, Demanded, Poison, Depth + 1)) {
1482       replaceOperand(*Inst, OpNum, V);
1483       MadeChange = true;
1484     }
1485   };
1486 
1487   APInt PoisonElts2(VWidth, 0);
1488   APInt PoisonElts3(VWidth, 0);
1489   switch (I->getOpcode()) {
1490   default: break;
1491 
1492   case Instruction::GetElementPtr: {
1493     // The LangRef requires that struct geps have all constant indices.  As
1494     // such, we can't convert any operand to partial undef.
1495     auto mayIndexStructType = [](GetElementPtrInst &GEP) {
1496       for (auto I = gep_type_begin(GEP), E = gep_type_end(GEP);
1497            I != E; I++)
1498         if (I.isStruct())
1499           return true;
1500       return false;
1501     };
1502     if (mayIndexStructType(cast<GetElementPtrInst>(*I)))
1503       break;
1504 
1505     // Conservatively track the demanded elements back through any vector
1506     // operands we may have.  We know there must be at least one, or we
1507     // wouldn't have a vector result to get here. Note that we intentionally
1508     // merge the undef bits here since gepping with either a poison base or
1509     // index results in poison.
1510     for (unsigned i = 0; i < I->getNumOperands(); i++) {
1511       if (i == 0 ? match(I->getOperand(i), m_Undef())
1512                  : match(I->getOperand(i), m_Poison())) {
1513         // If the entire vector is undefined, just return this info.
1514         PoisonElts = EltMask;
1515         return nullptr;
1516       }
1517       if (I->getOperand(i)->getType()->isVectorTy()) {
1518         APInt PoisonEltsOp(VWidth, 0);
1519         simplifyAndSetOp(I, i, DemandedElts, PoisonEltsOp);
1520         // gep(x, undef) is not undef, so skip considering index ops here.
1521         // Note that we could propagate poison, but we can't distinguish
1522         // between undef & poison bits at the moment.
1523         if (i == 0)
1524           PoisonElts |= PoisonEltsOp;
1525       }
1526     }
1527 
1528     break;
1529   }
1530   case Instruction::InsertElement: {
1531     // If this is a variable index, we don't know which element it
1532     // overwrites, so demand exactly the same input as we produce.
1533     ConstantInt *Idx = dyn_cast<ConstantInt>(I->getOperand(2));
1534     if (!Idx) {
1535       // Note that we can't propagate undef elt info, because we don't know
1536       // which elt is getting updated.
1537       simplifyAndSetOp(I, 0, DemandedElts, PoisonElts2);
1538       break;
1539     }
1540 
1541     // The element inserted overwrites whatever was there, so the input demanded
1542     // set is simpler than the output set.
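         // For example, for (insertelement V, X, 2), lane 2 of V is
         // overwritten, so it need not be demanded from V (operand 0).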
1543     unsigned IdxNo = Idx->getZExtValue();
1544     APInt PreInsertDemandedElts = DemandedElts;
1545     if (IdxNo < VWidth)
1546       PreInsertDemandedElts.clearBit(IdxNo);
1547 
1548     // If we only demand the element that is being inserted and that element
1549     // was extracted from the same index in another vector with the same type,
1550     // replace this insert with that other vector.
1551     // Note: This is attempted before the call to simplifyAndSetOp because that
1552     //       may change PoisonElts to a value that does not match with Vec.
1553     Value *Vec;
1554     if (PreInsertDemandedElts == 0 &&
1555         match(I->getOperand(1),
1556               m_ExtractElt(m_Value(Vec), m_SpecificInt(IdxNo))) &&
1557         Vec->getType() == I->getType()) {
1558       return Vec;
1559     }
1560 
1561     simplifyAndSetOp(I, 0, PreInsertDemandedElts, PoisonElts);
1562 
1563     // If this is inserting an element that isn't demanded, remove this
1564     // insertelement.
1565     if (IdxNo >= VWidth || !DemandedElts[IdxNo]) {
1566       Worklist.push(I);
1567       return I->getOperand(0);
1568     }
1569 
1570     // The inserted element is defined.
1571     PoisonElts.clearBit(IdxNo);
1572     break;
1573   }
1574   case Instruction::ShuffleVector: {
1575     auto *Shuffle = cast<ShuffleVectorInst>(I);
1576     assert(Shuffle->getOperand(0)->getType() ==
1577            Shuffle->getOperand(1)->getType() &&
1578            "Expected shuffle operands to have same type");
1579     unsigned OpWidth = cast<FixedVectorType>(Shuffle->getOperand(0)->getType())
1580                            ->getNumElements();
1581     // Handle the trivial case of a splat. Only check the first element of
1582     // the LHS operand.
1583     if (all_of(Shuffle->getShuffleMask(), [](int Elt) { return Elt == 0; }) &&
1584         DemandedElts.isAllOnes()) {
1585       if (!isa<PoisonValue>(I->getOperand(1))) {
1586         I->setOperand(1, PoisonValue::get(I->getOperand(1)->getType()));
1587         MadeChange = true;
1588       }
1589       APInt LeftDemanded(OpWidth, 1);
1590       APInt LHSPoisonElts(OpWidth, 0);
1591       simplifyAndSetOp(I, 0, LeftDemanded, LHSPoisonElts);
1592       if (LHSPoisonElts[0])
1593         PoisonElts = EltMask;
1594       else
1595         PoisonElts.clearAllBits();
1596       break;
1597     }
1598 
1599     APInt LeftDemanded(OpWidth, 0), RightDemanded(OpWidth, 0);
1600     for (unsigned i = 0; i < VWidth; i++) {
1601       if (DemandedElts[i]) {
1602         unsigned MaskVal = Shuffle->getMaskValue(i);
1603         if (MaskVal != -1u) {
1604           assert(MaskVal < OpWidth * 2 &&
1605                  "shufflevector mask index out of range!");
1606           if (MaskVal < OpWidth)
1607             LeftDemanded.setBit(MaskVal);
1608           else
1609             RightDemanded.setBit(MaskVal - OpWidth);
1610         }
1611       }
1612     }
1613 
1614     APInt LHSPoisonElts(OpWidth, 0);
1615     simplifyAndSetOp(I, 0, LeftDemanded, LHSPoisonElts);
1616 
1617     APInt RHSPoisonElts(OpWidth, 0);
1618     simplifyAndSetOp(I, 1, RightDemanded, RHSPoisonElts);
1619 
1620     // If this shuffle does not change the vector length and the elements
1621     // demanded by this shuffle are an identity mask, then this shuffle is
1622     // unnecessary.
1623     //
1624     // We are assuming canonical form for the mask, so the source vector is
1625     // operand 0 and operand 1 is not used.
1626     //
1627     // Note that if an element is demanded and this shuffle mask is undefined
1628     // for that element, then the shuffle is not considered an identity
1629     // operation. The shuffle prevents poison in the operand vector from
1630     // leaking into the result by replacing it with an undefined value.
1631     if (VWidth == OpWidth) {
1632       bool IsIdentityShuffle = true;
1633       for (unsigned i = 0; i < VWidth; i++) {
1634         unsigned MaskVal = Shuffle->getMaskValue(i);
1635         if (DemandedElts[i] && i != MaskVal) {
1636           IsIdentityShuffle = false;
1637           break;
1638         }
1639       }
1640       if (IsIdentityShuffle)
1641         return Shuffle->getOperand(0);
1642     }
1643 
1644     bool NewPoisonElts = false;
1645     unsigned LHSIdx = -1u, LHSValIdx = -1u;
1646     unsigned RHSIdx = -1u, RHSValIdx = -1u;
1647     bool LHSUniform = true;
1648     bool RHSUniform = true;
1649     for (unsigned i = 0; i < VWidth; i++) {
1650       unsigned MaskVal = Shuffle->getMaskValue(i);
1651       if (MaskVal == -1u) {
1652         PoisonElts.setBit(i);
1653       } else if (!DemandedElts[i]) {
1654         NewPoisonElts = true;
1655         PoisonElts.setBit(i);
1656       } else if (MaskVal < OpWidth) {
1657         if (LHSPoisonElts[MaskVal]) {
1658           NewPoisonElts = true;
1659           PoisonElts.setBit(i);
1660         } else {
1661           LHSIdx = LHSIdx == -1u ? i : OpWidth;
1662           LHSValIdx = LHSValIdx == -1u ? MaskVal : OpWidth;
1663           LHSUniform = LHSUniform && (MaskVal == i);
1664         }
1665       } else {
1666         if (RHSPoisonElts[MaskVal - OpWidth]) {
1667           NewPoisonElts = true;
1668           PoisonElts.setBit(i);
1669         } else {
1670           RHSIdx = RHSIdx == -1u ? i : OpWidth;
1671           RHSValIdx = RHSValIdx == -1u ? MaskVal - OpWidth : OpWidth;
1672           RHSUniform = RHSUniform && (MaskVal - OpWidth == i);
1673         }
1674       }
1675     }
1676 
1677     // Try to transform a shuffle that uses a single element of a constant
1678     // vector operand into a single insertelement instruction.
1679     // shufflevector V, C, <v1, v2, .., ci, .., vm> ->
1680     // insertelement V, C[ci], ci-n
1681     if (OpWidth ==
1682         cast<FixedVectorType>(Shuffle->getType())->getNumElements()) {
1683       Value *Op = nullptr;
1684       Constant *Value = nullptr;
1685       unsigned Idx = -1u;
1686 
1687       // Find constant vector with the single element in shuffle (LHS or RHS).
1688       if (LHSIdx < OpWidth && RHSUniform) {
1689         if (auto *CV = dyn_cast<ConstantVector>(Shuffle->getOperand(0))) {
1690           Op = Shuffle->getOperand(1);
1691           Value = CV->getOperand(LHSValIdx);
1692           Idx = LHSIdx;
1693         }
1694       }
1695       if (RHSIdx < OpWidth && LHSUniform) {
1696         if (auto *CV = dyn_cast<ConstantVector>(Shuffle->getOperand(1))) {
1697           Op = Shuffle->getOperand(0);
1698           Value = CV->getOperand(RHSValIdx);
1699           Idx = RHSIdx;
1700         }
1701       }
1702       // Found constant vector with single element - convert to insertelement.
1703       // Found a constant vector with a single element - convert to insertelement.
1704         Instruction *New = InsertElementInst::Create(
1705             Op, Value, ConstantInt::get(Type::getInt64Ty(I->getContext()), Idx),
1706             Shuffle->getName());
1707         InsertNewInstWith(New, Shuffle->getIterator());
1708         return New;
1709       }
1710     }
1711     if (NewPoisonElts) {
1712       // Add the additional discovered poison elements to the mask.
1713       SmallVector<int, 16> Elts;
1714       for (unsigned i = 0; i < VWidth; ++i) {
1715         if (PoisonElts[i])
1716           Elts.push_back(PoisonMaskElem);
1717         else
1718           Elts.push_back(Shuffle->getMaskValue(i));
1719       }
1720       Shuffle->setShuffleMask(Elts);
1721       MadeChange = true;
1722     }
1723     break;
1724   }
1725   case Instruction::Select: {
1726     // If this is a vector select, try to transform the select condition based
1727     // on the current demanded elements.
1728     SelectInst *Sel = cast<SelectInst>(I);
1729     if (Sel->getCondition()->getType()->isVectorTy()) {
1730       // TODO: We are not doing anything with PoisonElts based on this call.
1731       // It is overwritten below based on the other select operands. If an
1732       // element of the select condition is known undef, then we are free to
1733       // choose the output value from either arm of the select. If we know that
1734       // one of those values is undef, then the output can be undef.
1735       simplifyAndSetOp(I, 0, DemandedElts, PoisonElts);
1736     }
1737 
1738     // Next, see if we can transform the arms of the select.
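         // For example, with a constant condition <i1 1, i1 0>, lane 0 is
         // taken from operand 1 and lane 1 from operand 2, so each arm only
         // needs to supply the lanes its condition bits select.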
1739     APInt DemandedLHS(DemandedElts), DemandedRHS(DemandedElts);
1740     if (auto *CV = dyn_cast<ConstantVector>(Sel->getCondition())) {
1741       for (unsigned i = 0; i < VWidth; i++) {
1742         Constant *CElt = CV->getAggregateElement(i);
1743 
1744         // isNullValue() always returns false when called on a ConstantExpr.
1745         if (CElt->isNullValue())
1746           DemandedLHS.clearBit(i);
1747         else if (CElt->isOneValue())
1748           DemandedRHS.clearBit(i);
1749       }
1750     }
1751 
1752     simplifyAndSetOp(I, 1, DemandedLHS, PoisonElts2);
1753     simplifyAndSetOp(I, 2, DemandedRHS, PoisonElts3);
1754 
1755     // Output elements are undefined if the element from each arm is undefined.
1756     // TODO: This can be improved. See comment in select condition handling.
1757     PoisonElts = PoisonElts2 & PoisonElts3;
1758     break;
1759   }
1760   case Instruction::BitCast: {
1761     // Vector->vector casts only.
1762     VectorType *VTy = dyn_cast<VectorType>(I->getOperand(0)->getType());
1763     if (!VTy) break;
1764     unsigned InVWidth = cast<FixedVectorType>(VTy)->getNumElements();
1765     APInt InputDemandedElts(InVWidth, 0);
1766     PoisonElts2 = APInt(InVWidth, 0);
1767     unsigned Ratio;
1768 
1769     if (VWidth == InVWidth) {
1770       // If we are converting from <4 x i32> -> <4 x f32>, we demand the same
1771       // elements as are demanded of us.
1772       Ratio = 1;
1773       InputDemandedElts = DemandedElts;
1774     } else if ((VWidth % InVWidth) == 0) {
1775       // If the number of elements in the output is a multiple of the number of
1776       // elements in the input then an input element is live if any of the
1777       // corresponding output elements are live.
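           // For example, bitcasting <2 x i64> to <4 x i32> gives Ratio = 2;
           // demanding output elements {2, 3} demands input element 1.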
1778       Ratio = VWidth / InVWidth;
1779       for (unsigned OutIdx = 0; OutIdx != VWidth; ++OutIdx)
1780         if (DemandedElts[OutIdx])
1781           InputDemandedElts.setBit(OutIdx / Ratio);
1782     } else if ((InVWidth % VWidth) == 0) {
1783       // If the number of elements in the input is a multiple of the number of
1784       // elements in the output then an input element is live if the
1785       // corresponding output element is live.
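           // For example, bitcasting <4 x i32> to <2 x i64> gives Ratio = 2;
           // demanding output element 0 demands input elements 0 and 1.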
1786       Ratio = InVWidth / VWidth;
1787       for (unsigned InIdx = 0; InIdx != InVWidth; ++InIdx)
1788         if (DemandedElts[InIdx / Ratio])
1789           InputDemandedElts.setBit(InIdx);
1790     } else {
1791       // Unsupported so far.
1792       break;
1793     }
1794 
1795     simplifyAndSetOp(I, 0, InputDemandedElts, PoisonElts2);
1796 
1797     if (VWidth == InVWidth) {
1798       PoisonElts = PoisonElts2;
1799     } else if ((VWidth % InVWidth) == 0) {
1800       // If the number of elements in the output is a multiple of the number of
1801       // elements in the input then an output element is undef if the
1802       // corresponding input element is undef.
1803       for (unsigned OutIdx = 0; OutIdx != VWidth; ++OutIdx)
1804         if (PoisonElts2[OutIdx / Ratio])
1805           PoisonElts.setBit(OutIdx);
1806     } else if ((InVWidth % VWidth) == 0) {
1807       // If the number of elements in the input is a multiple of the number of
1808       // elements in the output then an output element is undef if all of the
1809       // corresponding input elements are undef.
1810       for (unsigned OutIdx = 0; OutIdx != VWidth; ++OutIdx) {
1811         APInt SubUndef = PoisonElts2.lshr(OutIdx * Ratio).zextOrTrunc(Ratio);
1812         if (SubUndef.popcount() == Ratio)
1813           PoisonElts.setBit(OutIdx);
1814       }
1815     } else {
1816       llvm_unreachable("Unimp");
1817     }
1818     break;
1819   }
1820   case Instruction::FPTrunc:
1821   case Instruction::FPExt:
1822     simplifyAndSetOp(I, 0, DemandedElts, PoisonElts);
1823     break;
1824 
1825   case Instruction::Call: {
1826     IntrinsicInst *II = dyn_cast<IntrinsicInst>(I);
1827     if (!II) break;
1828     switch (II->getIntrinsicID()) {
1829     case Intrinsic::masked_gather: // fallthrough
1830     case Intrinsic::masked_load: {
1831       // Subtlety: If we load from a pointer, the pointer must be valid
1832       // regardless of whether the element is demanded.  Doing otherwise risks
1833       // segfaults which didn't exist in the original program.
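           // For a constant mask, a lane with a zero mask bit takes its
           // pass-through value (so that lane's pointer is not demanded), and
           // a lane with a set mask bit loads from memory (so that lane's
           // pass-through is not demanded).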
1834       APInt DemandedPtrs(APInt::getAllOnes(VWidth)),
1835           DemandedPassThrough(DemandedElts);
1836       if (auto *CV = dyn_cast<ConstantVector>(II->getOperand(2)))
1837         for (unsigned i = 0; i < VWidth; i++) {
1838           Constant *CElt = CV->getAggregateElement(i);
1839           if (CElt->isNullValue())
1840             DemandedPtrs.clearBit(i);
1841           else if (CElt->isAllOnesValue())
1842             DemandedPassThrough.clearBit(i);
1843         }
1844       if (II->getIntrinsicID() == Intrinsic::masked_gather)
1845         simplifyAndSetOp(II, 0, DemandedPtrs, PoisonElts2);
1846       simplifyAndSetOp(II, 3, DemandedPassThrough, PoisonElts3);
1847 
1848       // Output elements are undefined if the elements from both sources are.
1849       // TODO: can strengthen via mask as well.
1850       PoisonElts = PoisonElts2 & PoisonElts3;
1851       break;
1852     }
1853     default: {
1854       // Handle target specific intrinsics
1855       std::optional<Value *> V = targetSimplifyDemandedVectorEltsIntrinsic(
1856           *II, DemandedElts, PoisonElts, PoisonElts2, PoisonElts3,
1857           simplifyAndSetOp);
1858       if (V)
1859         return *V;
1860       break;
1861     }
1862     } // switch on IntrinsicID
1863     break;
1864   } // case Call
1865   } // switch on Opcode
1866 
1867   // TODO: We bail completely on integer div/rem and shifts because they have
1868   // UB/poison potential, but that should be refined.
1869   BinaryOperator *BO;
1870   if (match(I, m_BinOp(BO)) && !BO->isIntDivRem() && !BO->isShift()) {
1871     Value *X = BO->getOperand(0);
1872     Value *Y = BO->getOperand(1);
1873 
1874     // Look for an equivalent binop except that one operand has been shuffled.
1875     // If the demand for this binop only includes elements that are the same as
1876     // the other binop, then we may be able to replace this binop with a use of
1877     // the earlier one.
1878     //
1879     // Example:
1880     // %other_bo = bo (shuf X, {0}), Y
1881     // %this_extracted_bo = extelt (bo X, Y), 0
1882     // -->
1883     // %other_bo = bo (shuf X, {0}), Y
1884     // %this_extracted_bo = extelt %other_bo, 0
1885     //
1886     // TODO: Handle demand of an arbitrary single element or more than one
1887     //       element instead of just element 0.
1888     // TODO: Unlike general demanded elements transforms, this should be safe
1889     //       for any (div/rem/shift) opcode too.
1890     if (DemandedElts == 1 && !X->hasOneUse() && !Y->hasOneUse() &&
1891         BO->hasOneUse()) {
1892 
1893       auto findShufBO = [&](bool MatchShufAsOp0) -> User * {
1894         // Try to use shuffle-of-operand in place of an operand:
1895         // bo X, Y --> bo (shuf X), Y
1896         // bo X, Y --> bo X, (shuf Y)
1897         BinaryOperator::BinaryOps Opcode = BO->getOpcode();
1898         Value *ShufOp = MatchShufAsOp0 ? X : Y;
1899         Value *OtherOp = MatchShufAsOp0 ? Y : X;
1900         for (User *U : OtherOp->users()) {
1901           ArrayRef<int> Mask;
1902           auto Shuf = m_Shuffle(m_Specific(ShufOp), m_Value(), m_Mask(Mask));
1903           if (BO->isCommutative()
1904                   ? match(U, m_c_BinOp(Opcode, Shuf, m_Specific(OtherOp)))
1905                   : MatchShufAsOp0
1906                         ? match(U, m_BinOp(Opcode, Shuf, m_Specific(OtherOp)))
1907                         : match(U, m_BinOp(Opcode, m_Specific(OtherOp), Shuf)))
1908             if (match(Mask, m_ZeroMask()) && Mask[0] != PoisonMaskElem)
1909               if (DT.dominates(U, I))
1910                 return U;
1911         }
1912         return nullptr;
1913       };
1914 
1915       if (User *ShufBO = findShufBO(/* MatchShufAsOp0 */ true))
1916         return ShufBO;
1917       if (User *ShufBO = findShufBO(/* MatchShufAsOp0 */ false))
1918         return ShufBO;
1919     }
1920 
1921     simplifyAndSetOp(I, 0, DemandedElts, PoisonElts);
1922     simplifyAndSetOp(I, 1, DemandedElts, PoisonElts2);
1923 
1924     // Output elements are undefined if both are undefined. Consider things
1925     // like undef & 0. The result is known zero, not undef.
1926     PoisonElts &= PoisonElts2;
1927   }
1928 
1929   // If we've proven all of the lanes poison, return a poison value.
1930   // TODO: Intersect w/demanded lanes
1931   if (PoisonElts.isAllOnes())
1932     return PoisonValue::get(I->getType());
1933 
1934   return MadeChange ? I : nullptr;
1935 }
1936 
1937 /// For floating-point classes that resolve to a single bit pattern, return that
1938 /// value.
1939 static Constant *getFPClassConstant(Type *Ty, FPClassTest Mask) {
1940   if (Mask == fcNone)
1941     return PoisonValue::get(Ty);
1942 
1943   if (Mask == fcPosZero)
1944     return Constant::getNullValue(Ty);
1945 
1946   // TODO: Support aggregate types that are allowed by FPMathOperator.
1947   if (Ty->isAggregateType())
1948     return nullptr;
1949 
1950   switch (Mask) {
1951   case fcNegZero:
1952     return ConstantFP::getZero(Ty, true);
1953   case fcPosInf:
1954     return ConstantFP::getInfinity(Ty);
1955   case fcNegInf:
1956     return ConstantFP::getInfinity(Ty, true);
1957   default:
1958     return nullptr;
1959   }
1960 }
1961 
1962 Value *InstCombinerImpl::SimplifyDemandedUseFPClass(
1963     Value *V, const FPClassTest DemandedMask, KnownFPClass &Known,
1964     unsigned Depth, Instruction *CxtI) {
1965   assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth");
1966   Type *VTy = V->getType();
1967 
1968   assert(Known == KnownFPClass() && "expected uninitialized state");
1969 
1970   if (DemandedMask == fcNone)
1971     return isa<UndefValue>(V) ? nullptr : PoisonValue::get(VTy);
1972 
1973   if (Depth == MaxAnalysisRecursionDepth)
1974     return nullptr;
1975 
1976   Instruction *I = dyn_cast<Instruction>(V);
1977   if (!I) {
1978     // Handle constants and arguments
1979     Known = computeKnownFPClass(V, fcAllFlags, CxtI, Depth + 1);
1980     Value *FoldedToConst =
1981         getFPClassConstant(VTy, DemandedMask & Known.KnownFPClasses);
1982     return FoldedToConst == V ? nullptr : FoldedToConst;
1983   }
1984 
1985   if (!I->hasOneUse())
1986     return nullptr;
1987 
1988   // TODO: Should account for nofpclass/FastMathFlags on current instruction
1989   switch (I->getOpcode()) {
1990   case Instruction::FNeg: {
1991     if (SimplifyDemandedFPClass(I, 0, llvm::fneg(DemandedMask), Known,
1992                                 Depth + 1))
1993       return I;
1994     Known.fneg();
1995     break;
1996   }
1997   case Instruction::Call: {
1998     CallInst *CI = cast<CallInst>(I);
1999     switch (CI->getIntrinsicID()) {
2000     case Intrinsic::fabs:
2001       if (SimplifyDemandedFPClass(I, 0, llvm::inverse_fabs(DemandedMask), Known,
2002                                   Depth + 1))
2003         return I;
2004       Known.fabs();
2005       break;
2006     case Intrinsic::arithmetic_fence:
2007       if (SimplifyDemandedFPClass(I, 0, DemandedMask, Known, Depth + 1))
2008         return I;
2009       break;
2010     case Intrinsic::copysign: {
2011       // Flip on more potentially demanded classes
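           // For example, if only fcPosInf is demanded of the result, the
           // magnitude operand may still be -inf, since the sign comes from
           // operand 1; llvm::unknown_sign() widens the mask accordingly.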
2012       const FPClassTest DemandedMaskAnySign = llvm::unknown_sign(DemandedMask);
2013       if (SimplifyDemandedFPClass(I, 0, DemandedMaskAnySign, Known, Depth + 1))
2014         return I;
2015 
2016       if ((DemandedMask & fcPositive) == fcNone) {
2017         // Roundabout way of replacing with fneg(fabs)
2018         I->setOperand(1, ConstantFP::get(VTy, -1.0));
2019         return I;
2020       }
2021 
2022       if ((DemandedMask & fcNegative) == fcNone) {
2023         // Roundabout way of replacing with fabs
2024         I->setOperand(1, ConstantFP::getZero(VTy));
2025         return I;
2026       }
2027 
2028       KnownFPClass KnownSign =
2029           computeKnownFPClass(I->getOperand(1), fcAllFlags, CxtI, Depth + 1);
2030       Known.copysign(KnownSign);
2031       break;
2032     }
2033     default:
2034       Known = computeKnownFPClass(I, ~DemandedMask, CxtI, Depth + 1);
2035       break;
2036     }
2037 
2038     break;
2039   }
2040   case Instruction::Select: {
2041     KnownFPClass KnownLHS, KnownRHS;
2042     if (SimplifyDemandedFPClass(I, 2, DemandedMask, KnownRHS, Depth + 1) ||
2043         SimplifyDemandedFPClass(I, 1, DemandedMask, KnownLHS, Depth + 1))
2044       return I;
2045 
2046     if (KnownLHS.isKnownNever(DemandedMask))
2047       return I->getOperand(2);
2048     if (KnownRHS.isKnownNever(DemandedMask))
2049       return I->getOperand(1);
2050 
2051     // TODO: Recognize clamping patterns
2052     Known = KnownLHS | KnownRHS;
2053     break;
2054   }
2055   default:
2056     Known = computeKnownFPClass(I, ~DemandedMask, CxtI, Depth + 1);
2057     break;
2058   }
2059 
2060   return getFPClassConstant(VTy, DemandedMask & Known.KnownFPClasses);
2061 }
2062 
2063 bool InstCombinerImpl::SimplifyDemandedFPClass(Instruction *I, unsigned OpNo,
2064                                                FPClassTest DemandedMask,
2065                                                KnownFPClass &Known,
2066                                                unsigned Depth) {
2067   Use &U = I->getOperandUse(OpNo);
2068   Value *NewVal =
2069       SimplifyDemandedUseFPClass(U.get(), DemandedMask, Known, Depth, I);
2070   if (!NewVal)
2071     return false;
2072   if (Instruction *OpInst = dyn_cast<Instruction>(U))
2073     salvageDebugInfo(*OpInst);
2074 
2075   replaceUse(U, NewVal);
2076   return true;
2077 }
2078