xref: /llvm-project/llvm/lib/Support/KnownBits.cpp (revision 877cb9a2edc9057d70321e436e5ea8ccc1a87140)
1 //===-- KnownBits.cpp - Stores known zeros/ones ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains a class for representing known zeros and ones used by
10 // computeKnownBits.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/Support/KnownBits.h"
15 #include "llvm/Support/Debug.h"
16 #include "llvm/Support/raw_ostream.h"
17 #include <cassert>
18 
19 using namespace llvm;
20 
21 KnownBits KnownBits::flipSignBit(const KnownBits &Val) {
22   unsigned SignBitPosition = Val.getBitWidth() - 1;
23   APInt Zero = Val.Zero;
24   APInt One = Val.One;
25   Zero.setBitVal(SignBitPosition, Val.One[SignBitPosition]);
26   One.setBitVal(SignBitPosition, Val.Zero[SignBitPosition]);
27   return KnownBits(Zero, One);
28 }
29 
30 static KnownBits computeForAddCarry(const KnownBits &LHS, const KnownBits &RHS,
31                                     bool CarryZero, bool CarryOne) {
32 
33   APInt PossibleSumZero = LHS.getMaxValue() + RHS.getMaxValue() + !CarryZero;
34   APInt PossibleSumOne = LHS.getMinValue() + RHS.getMinValue() + CarryOne;
35 
36   // Compute known bits of the carry.
37   APInt CarryKnownZero = ~(PossibleSumZero ^ LHS.Zero ^ RHS.Zero);
38   APInt CarryKnownOne = PossibleSumOne ^ LHS.One ^ RHS.One;
39 
40   // Compute set of known bits (where all three relevant bits are known).
41   APInt LHSKnownUnion = LHS.Zero | LHS.One;
42   APInt RHSKnownUnion = RHS.Zero | RHS.One;
43   APInt CarryKnownUnion = std::move(CarryKnownZero) | CarryKnownOne;
44   APInt Known = std::move(LHSKnownUnion) & RHSKnownUnion & CarryKnownUnion;
45 
46   // Compute known bits of the result.
47   KnownBits KnownOut;
48   KnownOut.Zero = ~std::move(PossibleSumZero) & Known;
49   KnownOut.One = std::move(PossibleSumOne) & Known;
50   return KnownOut;
51 }
52 
53 KnownBits KnownBits::computeForAddCarry(
54     const KnownBits &LHS, const KnownBits &RHS, const KnownBits &Carry) {
55   assert(Carry.getBitWidth() == 1 && "Carry must be 1-bit");
56   return ::computeForAddCarry(
57       LHS, RHS, Carry.Zero.getBoolValue(), Carry.One.getBoolValue());
58 }
59 
/// Compute the known bits of LHS + RHS (when \p Add is set) or LHS - RHS
/// (otherwise). \p NSW / \p NUW state that the operation does not wrap in the
/// signed / unsigned sense; the facts implied by those flags are merged into
/// the result. If the combined facts conflict, the flags must have been
/// violated (poison), and all bits are reported as known zero.
KnownBits KnownBits::computeForAddSub(bool Add, bool NSW, bool NUW,
                                      const KnownBits &LHS,
                                      const KnownBits &RHS) {
  unsigned BitWidth = LHS.getBitWidth();
  KnownBits KnownOut(BitWidth);
  // This can be a relatively expensive helper, so optimistically save some
  // work.
  if (LHS.isUnknown() && RHS.isUnknown())
    return KnownOut;

  // The bit-level carry analysis is only run when both operands have at least
  // one known bit. Otherwise KnownOut stays fully unknown here and only the
  // flag-based facts below can refine it.
  if (!LHS.isUnknown() && !RHS.isUnknown()) {
    if (Add) {
      // Sum = LHS + RHS + 0
      KnownOut = ::computeForAddCarry(LHS, RHS, /*CarryZero=*/true,
                                      /*CarryOne=*/false);
    } else {
      // Sum = LHS + ~RHS + 1
      KnownBits NotRHS = RHS;
      std::swap(NotRHS.Zero, NotRHS.One);
      KnownOut = ::computeForAddCarry(LHS, NotRHS, /*CarryZero=*/false,
                                      /*CarryOne=*/true);
    }
  }

  // Handle add/sub given nsw and/or nuw.
  if (NUW) {
    if (Add) {
      // (add nuw X, Y)
      APInt MinVal = LHS.getMinValue().uadd_sat(RHS.getMinValue());
      // None of the adds can end up overflowing, so min consecutive highbits
      // in minimum possible of X + Y must all remain set.
      if (NSW) {
        unsigned NumBits = MinVal.trunc(BitWidth - 1).countl_one();
        // If we have NSW as well, we also know we can't overflow the signbit so
        // can start counting from 1 bit back.
        KnownOut.One.setBits(BitWidth - 1 - NumBits, BitWidth - 1);
      }
      KnownOut.One.setHighBits(MinVal.countl_one());
    } else {
      // (sub nuw X, Y)
      APInt MaxVal = LHS.getMaxValue().usub_sat(RHS.getMinValue());
      // None of the subs can overflow at any point, so any common high bits
      // will subtract away and result in zeros.
      if (NSW) {
        // If we have NSW as well, we also know we can't overflow the signbit so
        // can start counting from 1 bit back.
        unsigned NumBits = MaxVal.trunc(BitWidth - 1).countl_zero();
        KnownOut.Zero.setBits(BitWidth - 1 - NumBits, BitWidth - 1);
      }
      KnownOut.Zero.setHighBits(MaxVal.countl_zero());
    }
  }

  if (NSW) {
    // Signed bounds of the (non-wrapping) result, saturated so the bounds
    // themselves cannot wrap.
    APInt MinVal;
    APInt MaxVal;
    if (Add) {
      // (add nsw X, Y)
      MinVal = LHS.getSignedMinValue().sadd_sat(RHS.getSignedMinValue());
      MaxVal = LHS.getSignedMaxValue().sadd_sat(RHS.getSignedMaxValue());
    } else {
      // (sub nsw X, Y)
      MinVal = LHS.getSignedMinValue().ssub_sat(RHS.getSignedMaxValue());
      MaxVal = LHS.getSignedMaxValue().ssub_sat(RHS.getSignedMinValue());
    }
    if (MinVal.isNonNegative()) {
      // If min is non-negative, result will always be non-neg (can't overflow
      // around). Every result then lies in [MinVal, signed max], so it keeps
      // MinVal's leading ones below the sign bit.
      unsigned NumBits = MinVal.trunc(BitWidth - 1).countl_one();
      KnownOut.One.setBits(BitWidth - 1 - NumBits, BitWidth - 1);
      KnownOut.Zero.setSignBit();
    }
    if (MaxVal.isNegative()) {
      // If max is negative, result will always be neg (can't overflow around).
      // Every result then lies in [signed min, MaxVal], so it keeps MaxVal's
      // leading zeros below the sign bit.
      unsigned NumBits = MaxVal.trunc(BitWidth - 1).countl_zero();
      KnownOut.Zero.setBits(BitWidth - 1 - NumBits, BitWidth - 1);
      KnownOut.One.setSignBit();
    }
  }

  // Just return 0 if the nsw/nuw is violated and we have poison.
  if (KnownOut.hasConflict())
    KnownOut.setAllZero();
  return KnownOut;
}
145 
146 KnownBits KnownBits::computeForSubBorrow(const KnownBits &LHS, KnownBits RHS,
147                                          const KnownBits &Borrow) {
148   assert(Borrow.getBitWidth() == 1 && "Borrow must be 1-bit");
149 
150   // LHS - RHS = LHS + ~RHS + 1
151   // Carry 1 - Borrow in ::computeForAddCarry
152   std::swap(RHS.Zero, RHS.One);
153   return ::computeForAddCarry(LHS, RHS,
154                               /*CarryZero=*/Borrow.One.getBoolValue(),
155                               /*CarryOne=*/Borrow.Zero.getBoolValue());
156 }
157 
158 KnownBits KnownBits::sextInReg(unsigned SrcBitWidth) const {
159   unsigned BitWidth = getBitWidth();
160   assert(0 < SrcBitWidth && SrcBitWidth <= BitWidth &&
161          "Illegal sext-in-register");
162 
163   if (SrcBitWidth == BitWidth)
164     return *this;
165 
166   unsigned ExtBits = BitWidth - SrcBitWidth;
167   KnownBits Result;
168   Result.One = One << ExtBits;
169   Result.Zero = Zero << ExtBits;
170   Result.One.ashrInPlace(ExtBits);
171   Result.Zero.ashrInPlace(ExtBits);
172   return Result;
173 }
174 
175 KnownBits KnownBits::makeGE(const APInt &Val) const {
176   // Count the number of leading bit positions where our underlying value is
177   // known to be less than or equal to Val.
178   unsigned N = (Zero | Val).countl_one();
179 
180   // For each of those bit positions, if Val has a 1 in that bit then our
181   // underlying value must also have a 1.
182   APInt MaskedVal(Val);
183   MaskedVal.clearLowBits(getBitWidth() - N);
184   return KnownBits(Zero, One | MaskedVal);
185 }
186 
187 KnownBits KnownBits::umax(const KnownBits &LHS, const KnownBits &RHS) {
188   // If we can prove that LHS >= RHS then use LHS as the result. Likewise for
189   // RHS. Ideally our caller would already have spotted these cases and
190   // optimized away the umax operation, but we handle them here for
191   // completeness.
192   if (LHS.getMinValue().uge(RHS.getMaxValue()))
193     return LHS;
194   if (RHS.getMinValue().uge(LHS.getMaxValue()))
195     return RHS;
196 
197   // If the result of the umax is LHS then it must be greater than or equal to
198   // the minimum possible value of RHS. Likewise for RHS. Any known bits that
199   // are common to these two values are also known in the result.
200   KnownBits L = LHS.makeGE(RHS.getMinValue());
201   KnownBits R = RHS.makeGE(LHS.getMinValue());
202   return L.intersectWith(R);
203 }
204 
205 KnownBits KnownBits::umin(const KnownBits &LHS, const KnownBits &RHS) {
206   // Flip the range of values: [0, 0xFFFFFFFF] <-> [0xFFFFFFFF, 0]
207   auto Flip = [](const KnownBits &Val) { return KnownBits(Val.One, Val.Zero); };
208   return Flip(umax(Flip(LHS), Flip(RHS)));
209 }
210 
211 KnownBits KnownBits::smax(const KnownBits &LHS, const KnownBits &RHS) {
212   return flipSignBit(umax(flipSignBit(LHS), flipSignBit(RHS)));
213 }
214 
215 KnownBits KnownBits::smin(const KnownBits &LHS, const KnownBits &RHS) {
216   // Flip the range of values: [-0x80000000, 0x7FFFFFFF] <-> [0xFFFFFFFF, 0]
217   auto Flip = [](const KnownBits &Val) {
218     unsigned SignBitPosition = Val.getBitWidth() - 1;
219     APInt Zero = Val.One;
220     APInt One = Val.Zero;
221     Zero.setBitVal(SignBitPosition, Val.Zero[SignBitPosition]);
222     One.setBitVal(SignBitPosition, Val.One[SignBitPosition]);
223     return KnownBits(Zero, One);
224   };
225   return Flip(umax(Flip(LHS), Flip(RHS)));
226 }
227 
228 KnownBits KnownBits::abdu(const KnownBits &LHS, const KnownBits &RHS) {
229   // If we know which argument is larger, return (sub LHS, RHS) or
230   // (sub RHS, LHS) directly.
231   if (LHS.getMinValue().uge(RHS.getMaxValue()))
232     return computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, LHS,
233                             RHS);
234   if (RHS.getMinValue().uge(LHS.getMaxValue()))
235     return computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, RHS,
236                             LHS);
237 
238   // By construction, the subtraction in abdu never has unsigned overflow.
239   // Find the common bits between (sub nuw LHS, RHS) and (sub nuw RHS, LHS).
240   KnownBits Diff0 =
241       computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/true, LHS, RHS);
242   KnownBits Diff1 =
243       computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/true, RHS, LHS);
244   return Diff0.intersectWith(Diff1);
245 }
246 
247 KnownBits KnownBits::abds(KnownBits LHS, KnownBits RHS) {
248   // If we know which argument is larger, return (sub LHS, RHS) or
249   // (sub RHS, LHS) directly.
250   if (LHS.getSignedMinValue().sge(RHS.getSignedMaxValue()))
251     return computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, LHS,
252                             RHS);
253   if (RHS.getSignedMinValue().sge(LHS.getSignedMaxValue()))
254     return computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, RHS,
255                             LHS);
256 
257   // Shift both arguments from the signed range to the unsigned range, e.g. from
258   // [-0x80, 0x7F] to [0, 0xFF]. This allows us to use "sub nuw" below just like
259   // abdu does.
260   // Note that we can't just use "sub nsw" instead because abds has signed
261   // inputs but an unsigned result, which makes the overflow conditions
262   // different.
263   unsigned SignBitPosition = LHS.getBitWidth() - 1;
264   for (auto Arg : {&LHS, &RHS}) {
265     bool Tmp = Arg->Zero[SignBitPosition];
266     Arg->Zero.setBitVal(SignBitPosition, Arg->One[SignBitPosition]);
267     Arg->One.setBitVal(SignBitPosition, Tmp);
268   }
269 
270   // Find the common bits between (sub nuw LHS, RHS) and (sub nuw RHS, LHS).
271   KnownBits Diff0 =
272       computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/true, LHS, RHS);
273   KnownBits Diff1 =
274       computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/true, RHS, LHS);
275   return Diff0.intersectWith(Diff1);
276 }
277 
278 static unsigned getMaxShiftAmount(const APInt &MaxValue, unsigned BitWidth) {
279   if (isPowerOf2_32(BitWidth))
280     return MaxValue.extractBitsAsZExtValue(Log2_32(BitWidth), 0);
281   // This is only an approximate upper bound.
282   return MaxValue.getLimitedValue(BitWidth - 1);
283 }
284 
285 KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW,
286                          bool NSW, bool ShAmtNonZero) {
287   unsigned BitWidth = LHS.getBitWidth();
288   auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
289     KnownBits Known;
290     bool ShiftedOutZero, ShiftedOutOne;
291     Known.Zero = LHS.Zero.ushl_ov(ShiftAmt, ShiftedOutZero);
292     Known.Zero.setLowBits(ShiftAmt);
293     Known.One = LHS.One.ushl_ov(ShiftAmt, ShiftedOutOne);
294 
295     // All cases returning poison have been handled by MaxShiftAmount already.
296     if (NSW) {
297       if (NUW && ShiftAmt != 0)
298         // NUW means we can assume anything shifted out was a zero.
299         ShiftedOutZero = true;
300 
301       if (ShiftedOutZero)
302         Known.makeNonNegative();
303       else if (ShiftedOutOne)
304         Known.makeNegative();
305     }
306     return Known;
307   };
308 
309   // Fast path for a common case when LHS is completely unknown.
310   KnownBits Known(BitWidth);
311   unsigned MinShiftAmount = RHS.getMinValue().getLimitedValue(BitWidth);
312   if (MinShiftAmount == 0 && ShAmtNonZero)
313     MinShiftAmount = 1;
314   if (LHS.isUnknown()) {
315     Known.Zero.setLowBits(MinShiftAmount);
316     if (NUW && NSW && MinShiftAmount != 0)
317       Known.makeNonNegative();
318     return Known;
319   }
320 
321   // Determine maximum shift amount, taking NUW/NSW flags into account.
322   APInt MaxValue = RHS.getMaxValue();
323   unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);
324   if (NUW && NSW)
325     MaxShiftAmount = std::min(MaxShiftAmount, LHS.countMaxLeadingZeros() - 1);
326   if (NUW)
327     MaxShiftAmount = std::min(MaxShiftAmount, LHS.countMaxLeadingZeros());
328   if (NSW)
329     MaxShiftAmount = std::min(
330         MaxShiftAmount,
331         std::max(LHS.countMaxLeadingZeros(), LHS.countMaxLeadingOnes()) - 1);
332 
333   // Fast path for common case where the shift amount is unknown.
334   if (MinShiftAmount == 0 && MaxShiftAmount == BitWidth - 1 &&
335       isPowerOf2_32(BitWidth)) {
336     Known.Zero.setLowBits(LHS.countMinTrailingZeros());
337     if (LHS.isAllOnes())
338       Known.One.setSignBit();
339     if (NSW) {
340       if (LHS.isNonNegative())
341         Known.makeNonNegative();
342       if (LHS.isNegative())
343         Known.makeNegative();
344     }
345     return Known;
346   }
347 
348   // Find the common bits from all possible shifts.
349   unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue();
350   unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue();
351   Known.Zero.setAllBits();
352   Known.One.setAllBits();
353   for (unsigned ShiftAmt = MinShiftAmount; ShiftAmt <= MaxShiftAmount;
354        ++ShiftAmt) {
355     // Skip if the shift amount is impossible.
356     if ((ShiftAmtZeroMask & ShiftAmt) != 0 ||
357         (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
358       continue;
359     Known = Known.intersectWith(ShiftByConst(LHS, ShiftAmt));
360     if (Known.isUnknown())
361       break;
362   }
363 
364   // All shift amounts may result in poison.
365   if (Known.hasConflict())
366     Known.setAllZero();
367   return Known;
368 }
369 
370 KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
371                           bool ShAmtNonZero, bool Exact) {
372   unsigned BitWidth = LHS.getBitWidth();
373   auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
374     KnownBits Known = LHS;
375     Known.Zero.lshrInPlace(ShiftAmt);
376     Known.One.lshrInPlace(ShiftAmt);
377     // High bits are known zero.
378     Known.Zero.setHighBits(ShiftAmt);
379     return Known;
380   };
381 
382   // Fast path for a common case when LHS is completely unknown.
383   KnownBits Known(BitWidth);
384   unsigned MinShiftAmount = RHS.getMinValue().getLimitedValue(BitWidth);
385   if (MinShiftAmount == 0 && ShAmtNonZero)
386     MinShiftAmount = 1;
387   if (LHS.isUnknown()) {
388     Known.Zero.setHighBits(MinShiftAmount);
389     return Known;
390   }
391 
392   // Find the common bits from all possible shifts.
393   APInt MaxValue = RHS.getMaxValue();
394   unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);
395 
396   // If exact, bound MaxShiftAmount to first known 1 in LHS.
397   if (Exact) {
398     unsigned FirstOne = LHS.countMaxTrailingZeros();
399     if (FirstOne < MinShiftAmount) {
400       // Always poison. Return zero because we don't like returning conflict.
401       Known.setAllZero();
402       return Known;
403     }
404     MaxShiftAmount = std::min(MaxShiftAmount, FirstOne);
405   }
406 
407   unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue();
408   unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue();
409   Known.Zero.setAllBits();
410   Known.One.setAllBits();
411   for (unsigned ShiftAmt = MinShiftAmount; ShiftAmt <= MaxShiftAmount;
412        ++ShiftAmt) {
413     // Skip if the shift amount is impossible.
414     if ((ShiftAmtZeroMask & ShiftAmt) != 0 ||
415         (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
416       continue;
417     Known = Known.intersectWith(ShiftByConst(LHS, ShiftAmt));
418     if (Known.isUnknown())
419       break;
420   }
421 
422   // All shift amounts may result in poison.
423   if (Known.hasConflict())
424     Known.setAllZero();
425   return Known;
426 }
427 
428 KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS,
429                           bool ShAmtNonZero, bool Exact) {
430   unsigned BitWidth = LHS.getBitWidth();
431   auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
432     KnownBits Known = LHS;
433     Known.Zero.ashrInPlace(ShiftAmt);
434     Known.One.ashrInPlace(ShiftAmt);
435     return Known;
436   };
437 
438   // Fast path for a common case when LHS is completely unknown.
439   KnownBits Known(BitWidth);
440   unsigned MinShiftAmount = RHS.getMinValue().getLimitedValue(BitWidth);
441   if (MinShiftAmount == 0 && ShAmtNonZero)
442     MinShiftAmount = 1;
443   if (LHS.isUnknown()) {
444     if (MinShiftAmount == BitWidth) {
445       // Always poison. Return zero because we don't like returning conflict.
446       Known.setAllZero();
447       return Known;
448     }
449     return Known;
450   }
451 
452   // Find the common bits from all possible shifts.
453   APInt MaxValue = RHS.getMaxValue();
454   unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);
455 
456   // If exact, bound MaxShiftAmount to first known 1 in LHS.
457   if (Exact) {
458     unsigned FirstOne = LHS.countMaxTrailingZeros();
459     if (FirstOne < MinShiftAmount) {
460       // Always poison. Return zero because we don't like returning conflict.
461       Known.setAllZero();
462       return Known;
463     }
464     MaxShiftAmount = std::min(MaxShiftAmount, FirstOne);
465   }
466 
467   unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue();
468   unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue();
469   Known.Zero.setAllBits();
470   Known.One.setAllBits();
471   for (unsigned ShiftAmt = MinShiftAmount; ShiftAmt <= MaxShiftAmount;
472       ++ShiftAmt) {
473     // Skip if the shift amount is impossible.
474     if ((ShiftAmtZeroMask & ShiftAmt) != 0 ||
475         (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
476       continue;
477     Known = Known.intersectWith(ShiftByConst(LHS, ShiftAmt));
478     if (Known.isUnknown())
479       break;
480   }
481 
482   // All shift amounts may result in poison.
483   if (Known.hasConflict())
484     Known.setAllZero();
485   return Known;
486 }
487 
488 std::optional<bool> KnownBits::eq(const KnownBits &LHS, const KnownBits &RHS) {
489   if (LHS.isConstant() && RHS.isConstant())
490     return std::optional<bool>(LHS.getConstant() == RHS.getConstant());
491   if (LHS.One.intersects(RHS.Zero) || RHS.One.intersects(LHS.Zero))
492     return std::optional<bool>(false);
493   return std::nullopt;
494 }
495 
496 std::optional<bool> KnownBits::ne(const KnownBits &LHS, const KnownBits &RHS) {
497   if (std::optional<bool> KnownEQ = eq(LHS, RHS))
498     return std::optional<bool>(!*KnownEQ);
499   return std::nullopt;
500 }
501 
502 std::optional<bool> KnownBits::ugt(const KnownBits &LHS, const KnownBits &RHS) {
503   // LHS >u RHS -> false if umax(LHS) <= umax(RHS)
504   if (LHS.getMaxValue().ule(RHS.getMinValue()))
505     return std::optional<bool>(false);
506   // LHS >u RHS -> true if umin(LHS) > umax(RHS)
507   if (LHS.getMinValue().ugt(RHS.getMaxValue()))
508     return std::optional<bool>(true);
509   return std::nullopt;
510 }
511 
512 std::optional<bool> KnownBits::uge(const KnownBits &LHS, const KnownBits &RHS) {
513   if (std::optional<bool> IsUGT = ugt(RHS, LHS))
514     return std::optional<bool>(!*IsUGT);
515   return std::nullopt;
516 }
517 
518 std::optional<bool> KnownBits::ult(const KnownBits &LHS, const KnownBits &RHS) {
519   return ugt(RHS, LHS);
520 }
521 
522 std::optional<bool> KnownBits::ule(const KnownBits &LHS, const KnownBits &RHS) {
523   return uge(RHS, LHS);
524 }
525 
526 std::optional<bool> KnownBits::sgt(const KnownBits &LHS, const KnownBits &RHS) {
527   // LHS >s RHS -> false if smax(LHS) <= smax(RHS)
528   if (LHS.getSignedMaxValue().sle(RHS.getSignedMinValue()))
529     return std::optional<bool>(false);
530   // LHS >s RHS -> true if smin(LHS) > smax(RHS)
531   if (LHS.getSignedMinValue().sgt(RHS.getSignedMaxValue()))
532     return std::optional<bool>(true);
533   return std::nullopt;
534 }
535 
536 std::optional<bool> KnownBits::sge(const KnownBits &LHS, const KnownBits &RHS) {
537   if (std::optional<bool> KnownSGT = sgt(RHS, LHS))
538     return std::optional<bool>(!*KnownSGT);
539   return std::nullopt;
540 }
541 
542 std::optional<bool> KnownBits::slt(const KnownBits &LHS, const KnownBits &RHS) {
543   return sgt(RHS, LHS);
544 }
545 
546 std::optional<bool> KnownBits::sle(const KnownBits &LHS, const KnownBits &RHS) {
547   return sge(RHS, LHS);
548 }
549 
550 KnownBits KnownBits::abs(bool IntMinIsPoison) const {
551   // If the source's MSB is zero then we know the rest of the bits already.
552   if (isNonNegative())
553     return *this;
554 
555   // Absolute value preserves trailing zero count.
556   KnownBits KnownAbs(getBitWidth());
557 
558   // If the input is negative, then abs(x) == -x.
559   if (isNegative()) {
560     KnownBits Tmp = *this;
561     // Special case for IntMinIsPoison. We know the sign bit is set and we know
562     // all the rest of the bits except one to be zero. Since we have
563     // IntMinIsPoison, that final bit MUST be a one, as otherwise the input is
564     // INT_MIN.
565     if (IntMinIsPoison && (Zero.popcount() + 2) == getBitWidth())
566       Tmp.One.setBit(countMinTrailingZeros());
567 
568     KnownAbs = computeForAddSub(
569         /*Add*/ false, IntMinIsPoison, /*NUW=*/false,
570         KnownBits::makeConstant(APInt(getBitWidth(), 0)), Tmp);
571 
572     // One more special case for IntMinIsPoison. If we don't know any ones other
573     // than the signbit, we know for certain that all the unknowns can't be
574     // zero. So if we know high zero bits, but have unknown low bits, we know
575     // for certain those high-zero bits will end up as one. This is because,
576     // the low bits can't be all zeros, so the +1 in (~x + 1) cannot carry up
577     // to the high bits. If we know a known INT_MIN input skip this. The result
578     // is poison anyways.
579     if (IntMinIsPoison && Tmp.countMinPopulation() == 1 &&
580         Tmp.countMaxPopulation() != 1) {
581       Tmp.One.clearSignBit();
582       Tmp.Zero.setSignBit();
583       KnownAbs.One.setBits(getBitWidth() - Tmp.countMinLeadingZeros(),
584                            getBitWidth() - 1);
585     }
586 
587   } else {
588     unsigned MaxTZ = countMaxTrailingZeros();
589     unsigned MinTZ = countMinTrailingZeros();
590 
591     KnownAbs.Zero.setLowBits(MinTZ);
592     // If we know the lowest set 1, then preserve it.
593     if (MaxTZ == MinTZ && MaxTZ < getBitWidth())
594       KnownAbs.One.setBit(MaxTZ);
595 
596     // We only know that the absolute values's MSB will be zero if INT_MIN is
597     // poison, or there is a set bit that isn't the sign bit (otherwise it could
598     // be INT_MIN).
599     if (IntMinIsPoison || (!One.isZero() && !One.isMinSignedValue())) {
600       KnownAbs.One.clearSignBit();
601       KnownAbs.Zero.setSignBit();
602     }
603   }
604 
605   return KnownAbs;
606 }
607 
608 static KnownBits computeForSatAddSub(bool Add, bool Signed,
609                                      const KnownBits &LHS,
610                                      const KnownBits &RHS) {
611   // We don't see NSW even for sadd/ssub as we want to check if the result has
612   // signed overflow.
613   unsigned BitWidth = LHS.getBitWidth();
614 
615   std::optional<bool> Overflow;
616   // Even if we can't entirely rule out overflow, we may be able to rule out
617   // overflow in one direction. This allows us to potentially keep some of the
618   // add/sub bits. I.e if we can't overflow in the positive direction we won't
619   // clamp to INT_MAX so we can keep low 0s from the add/sub result.
620   bool MayNegClamp = true;
621   bool MayPosClamp = true;
622   if (Signed) {
623     // Easy cases we can rule out any overflow.
624     if (Add && ((LHS.isNegative() && RHS.isNonNegative()) ||
625                 (LHS.isNonNegative() && RHS.isNegative())))
626       Overflow = false;
627     else if (!Add && (((LHS.isNegative() && RHS.isNegative()) ||
628                        (LHS.isNonNegative() && RHS.isNonNegative()))))
629       Overflow = false;
630     else {
631       // Check if we may overflow. If we can't rule out overflow then check if
632       // we can rule out a direction at least.
633       KnownBits UnsignedLHS = LHS;
634       KnownBits UnsignedRHS = RHS;
635       // Get version of LHS/RHS with clearer signbit. This allows us to detect
636       // how the addition/subtraction might overflow into the signbit. Then
637       // using the actual known signbits of LHS/RHS, we can figure out which
638       // overflows are/aren't possible.
639       UnsignedLHS.One.clearSignBit();
640       UnsignedLHS.Zero.setSignBit();
641       UnsignedRHS.One.clearSignBit();
642       UnsignedRHS.Zero.setSignBit();
643       KnownBits Res =
644           KnownBits::computeForAddSub(Add, /*NSW=*/false,
645                                       /*NUW=*/false, UnsignedLHS, UnsignedRHS);
646       if (Add) {
647         if (Res.isNegative()) {
648           // Only overflow scenario is Pos + Pos.
649           MayNegClamp = false;
650           // Pos + Pos will overflow with extra signbit.
651           if (LHS.isNonNegative() && RHS.isNonNegative())
652             Overflow = true;
653         } else if (Res.isNonNegative()) {
654           // Only overflow scenario is Neg + Neg
655           MayPosClamp = false;
656           // Neg + Neg will overflow without extra signbit.
657           if (LHS.isNegative() && RHS.isNegative())
658             Overflow = true;
659         }
660         // We will never clamp to the opposite sign of N-bit result.
661         if (LHS.isNegative() || RHS.isNegative())
662           MayPosClamp = false;
663         if (LHS.isNonNegative() || RHS.isNonNegative())
664           MayNegClamp = false;
665       } else {
666         if (Res.isNegative()) {
667           // Only overflow scenario is Neg - Pos.
668           MayPosClamp = false;
669           // Neg - Pos will overflow with extra signbit.
670           if (LHS.isNegative() && RHS.isNonNegative())
671             Overflow = true;
672         } else if (Res.isNonNegative()) {
673           // Only overflow scenario is Pos - Neg.
674           MayNegClamp = false;
675           // Pos - Neg will overflow without extra signbit.
676           if (LHS.isNonNegative() && RHS.isNegative())
677             Overflow = true;
678         }
679         // We will never clamp to the opposite sign of N-bit result.
680         if (LHS.isNegative() || RHS.isNonNegative())
681           MayPosClamp = false;
682         if (LHS.isNonNegative() || RHS.isNegative())
683           MayNegClamp = false;
684       }
685     }
686     // If we have ruled out all clamping, we will never overflow.
687     if (!MayNegClamp && !MayPosClamp)
688       Overflow = false;
689   } else if (Add) {
690     // uadd.sat
691     bool Of;
692     (void)LHS.getMaxValue().uadd_ov(RHS.getMaxValue(), Of);
693     if (!Of) {
694       Overflow = false;
695     } else {
696       (void)LHS.getMinValue().uadd_ov(RHS.getMinValue(), Of);
697       if (Of)
698         Overflow = true;
699     }
700   } else {
701     // usub.sat
702     bool Of;
703     (void)LHS.getMinValue().usub_ov(RHS.getMaxValue(), Of);
704     if (!Of) {
705       Overflow = false;
706     } else {
707       (void)LHS.getMaxValue().usub_ov(RHS.getMinValue(), Of);
708       if (Of)
709         Overflow = true;
710     }
711   }
712 
713   KnownBits Res = KnownBits::computeForAddSub(Add, /*NSW=*/Signed,
714                                               /*NUW=*/!Signed, LHS, RHS);
715 
716   if (Overflow) {
717     // We know whether or not we overflowed.
718     if (!(*Overflow)) {
719       // No overflow.
720       return Res;
721     }
722 
723     // We overflowed
724     APInt C;
725     if (Signed) {
726       // sadd.sat / ssub.sat
727       assert(!LHS.isSignUnknown() &&
728              "We somehow know overflow without knowing input sign");
729       C = LHS.isNegative() ? APInt::getSignedMinValue(BitWidth)
730                            : APInt::getSignedMaxValue(BitWidth);
731     } else if (Add) {
732       // uadd.sat
733       C = APInt::getMaxValue(BitWidth);
734     } else {
735       // uadd.sat
736       C = APInt::getMinValue(BitWidth);
737     }
738 
739     Res.One = C;
740     Res.Zero = ~C;
741     return Res;
742   }
743 
744   // We don't know if we overflowed.
745   if (Signed) {
746     // sadd.sat/ssub.sat
747     // We can keep our information about the sign bits.
748     if (MayPosClamp)
749       Res.Zero.clearLowBits(BitWidth - 1);
750     if (MayNegClamp)
751       Res.One.clearLowBits(BitWidth - 1);
752   } else if (Add) {
753     // uadd.sat
754     // We need to clear all the known zeros as we can only use the leading ones.
755     Res.Zero.clearAllBits();
756   } else {
757     // usub.sat
758     // We need to clear all the known ones as we can only use the leading zero.
759     Res.One.clearAllBits();
760   }
761 
762   return Res;
763 }
764 
765 KnownBits KnownBits::sadd_sat(const KnownBits &LHS, const KnownBits &RHS) {
766   return computeForSatAddSub(/*Add*/ true, /*Signed*/ true, LHS, RHS);
767 }
768 KnownBits KnownBits::ssub_sat(const KnownBits &LHS, const KnownBits &RHS) {
769   return computeForSatAddSub(/*Add*/ false, /*Signed*/ true, LHS, RHS);
770 }
771 KnownBits KnownBits::uadd_sat(const KnownBits &LHS, const KnownBits &RHS) {
772   return computeForSatAddSub(/*Add*/ true, /*Signed*/ false, LHS, RHS);
773 }
774 KnownBits KnownBits::usub_sat(const KnownBits &LHS, const KnownBits &RHS) {
775   return computeForSatAddSub(/*Add*/ false, /*Signed*/ false, LHS, RHS);
776 }
777 
778 static KnownBits avgComputeU(KnownBits LHS, KnownBits RHS, bool IsCeil) {
779   unsigned BitWidth = LHS.getBitWidth();
780   LHS = LHS.zext(BitWidth + 1);
781   RHS = RHS.zext(BitWidth + 1);
782   LHS =
783       computeForAddCarry(LHS, RHS, /*CarryZero*/ !IsCeil, /*CarryOne*/ IsCeil);
784   LHS = LHS.extractBits(BitWidth, 1);
785   return LHS;
786 }
787 
788 KnownBits KnownBits::avgFloorS(const KnownBits &LHS, const KnownBits &RHS) {
789   return flipSignBit(avgFloorU(flipSignBit(LHS), flipSignBit(RHS)));
790 }
791 
792 KnownBits KnownBits::avgFloorU(const KnownBits &LHS, const KnownBits &RHS) {
793   return avgComputeU(LHS, RHS, /*IsCeil=*/false);
794 }
795 
796 KnownBits KnownBits::avgCeilS(const KnownBits &LHS, const KnownBits &RHS) {
797   return flipSignBit(avgCeilU(flipSignBit(LHS), flipSignBit(RHS)));
798 }
799 
800 KnownBits KnownBits::avgCeilU(const KnownBits &LHS, const KnownBits &RHS) {
801   return avgComputeU(LHS, RHS, /*IsCeil=*/true);
802 }
803 
804 KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS,
805                          bool NoUndefSelfMultiply) {
806   unsigned BitWidth = LHS.getBitWidth();
807   assert(BitWidth == RHS.getBitWidth() && "Operand mismatch");
808   assert((!NoUndefSelfMultiply || LHS == RHS) &&
809          "Self multiplication knownbits mismatch");
810 
811   // Compute the high known-0 bits by multiplying the unsigned max of each side.
812   // Conservatively, M active bits * N active bits results in M + N bits in the
813   // result. But if we know a value is a power-of-2 for example, then this
814   // computes one more leading zero.
815   // TODO: This could be generalized to number of sign bits (negative numbers).
816   APInt UMaxLHS = LHS.getMaxValue();
817   APInt UMaxRHS = RHS.getMaxValue();
818 
819   // For leading zeros in the result to be valid, the unsigned max product must
820   // fit in the bitwidth (it must not overflow).
821   bool HasOverflow;
822   APInt UMaxResult = UMaxLHS.umul_ov(UMaxRHS, HasOverflow);
823   unsigned LeadZ = HasOverflow ? 0 : UMaxResult.countl_zero();
824 
825   // The result of the bottom bits of an integer multiply can be
826   // inferred by looking at the bottom bits of both operands and
827   // multiplying them together.
828   // We can infer at least the minimum number of known trailing bits
829   // of both operands. Depending on number of trailing zeros, we can
830   // infer more bits, because (a*b) <=> ((a/m) * (b/n)) * (m*n) assuming
831   // a and b are divisible by m and n respectively.
832   // We then calculate how many of those bits are inferrable and set
833   // the output. For example, the i8 mul:
834   //  a = XXXX1100 (12)
835   //  b = XXXX1110 (14)
836   // We know the bottom 3 bits are zero since the first can be divided by
837   // 4 and the second by 2, thus having ((12/4) * (14/2)) * (2*4).
838   // Applying the multiplication to the trimmed arguments gets:
839   //    XX11 (3)
840   //    X111 (7)
841   // -------
842   //    XX11
843   //   XX11
844   //  XX11
845   // XX11
846   // -------
847   // XXXXX01
848   // Which allows us to infer the 2 LSBs. Since we're multiplying the result
849   // by 8, the bottom 3 bits will be 0, so we can infer a total of 5 bits.
850   // The proof for this can be described as:
851   // Pre: (C1 >= 0) && (C1 < (1 << C5)) && (C2 >= 0) && (C2 < (1 << C6)) &&
852   //      (C7 == (1 << (umin(countTrailingZeros(C1), C5) +
853   //                    umin(countTrailingZeros(C2), C6) +
854   //                    umin(C5 - umin(countTrailingZeros(C1), C5),
855   //                         C6 - umin(countTrailingZeros(C2), C6)))) - 1)
856   // %aa = shl i8 %a, C5
857   // %bb = shl i8 %b, C6
858   // %aaa = or i8 %aa, C1
859   // %bbb = or i8 %bb, C2
860   // %mul = mul i8 %aaa, %bbb
861   // %mask = and i8 %mul, C7
862   //   =>
863   // %mask = i8 ((C1*C2)&C7)
864   // Where C5, C6 describe the known bits of %a, %b
865   // C1, C2 describe the known bottom bits of %a, %b.
866   // C7 describes the mask of the known bits of the result.
867   const APInt &Bottom0 = LHS.One;
868   const APInt &Bottom1 = RHS.One;
869 
870   // How many times we'd be able to divide each argument by 2 (shr by 1).
871   // This gives us the number of trailing zeros on the multiplication result.
872   unsigned TrailBitsKnown0 = (LHS.Zero | LHS.One).countr_one();
873   unsigned TrailBitsKnown1 = (RHS.Zero | RHS.One).countr_one();
874   unsigned TrailZero0 = LHS.countMinTrailingZeros();
875   unsigned TrailZero1 = RHS.countMinTrailingZeros();
876   unsigned TrailZ = TrailZero0 + TrailZero1;
877 
878   // Figure out the fewest known-bits operand.
879   unsigned SmallestOperand =
880       std::min(TrailBitsKnown0 - TrailZero0, TrailBitsKnown1 - TrailZero1);
881   unsigned ResultBitsKnown = std::min(SmallestOperand + TrailZ, BitWidth);
882 
883   APInt BottomKnown =
884       Bottom0.getLoBits(TrailBitsKnown0) * Bottom1.getLoBits(TrailBitsKnown1);
885 
886   KnownBits Res(BitWidth);
887   Res.Zero.setHighBits(LeadZ);
888   Res.Zero |= (~BottomKnown).getLoBits(ResultBitsKnown);
889   Res.One = BottomKnown.getLoBits(ResultBitsKnown);
890 
891   // If we're self-multiplying then bit[1] is guaranteed to be zero.
892   if (NoUndefSelfMultiply && BitWidth > 1) {
893     assert(Res.One[1] == 0 &&
894            "Self-multiplication failed Quadratic Reciprocity!");
895     Res.Zero.setBit(1);
896   }
897 
898   return Res;
899 }
900 
901 KnownBits KnownBits::mulhs(const KnownBits &LHS, const KnownBits &RHS) {
902   unsigned BitWidth = LHS.getBitWidth();
903   assert(BitWidth == RHS.getBitWidth() && "Operand mismatch");
904   KnownBits WideLHS = LHS.sext(2 * BitWidth);
905   KnownBits WideRHS = RHS.sext(2 * BitWidth);
906   return mul(WideLHS, WideRHS).extractBits(BitWidth, BitWidth);
907 }
908 
909 KnownBits KnownBits::mulhu(const KnownBits &LHS, const KnownBits &RHS) {
910   unsigned BitWidth = LHS.getBitWidth();
911   assert(BitWidth == RHS.getBitWidth() && "Operand mismatch");
912   KnownBits WideLHS = LHS.zext(2 * BitWidth);
913   KnownBits WideRHS = RHS.zext(2 * BitWidth);
914   return mul(WideLHS, WideRHS).extractBits(BitWidth, BitWidth);
915 }
916 
917 static KnownBits divComputeLowBit(KnownBits Known, const KnownBits &LHS,
918                                   const KnownBits &RHS, bool Exact) {
919 
920   if (!Exact)
921     return Known;
922 
923   // If LHS is Odd, the result is Odd no matter what.
924   // Odd / Odd -> Odd
925   // Odd / Even -> Impossible (because its exact division)
926   if (LHS.One[0])
927     Known.One.setBit(0);
928 
929   int MinTZ =
930       (int)LHS.countMinTrailingZeros() - (int)RHS.countMaxTrailingZeros();
931   int MaxTZ =
932       (int)LHS.countMaxTrailingZeros() - (int)RHS.countMinTrailingZeros();
933   if (MinTZ >= 0) {
934     // Result has at least MinTZ trailing zeros.
935     Known.Zero.setLowBits(MinTZ);
936     if (MinTZ == MaxTZ) {
937       // Result has exactly MinTZ trailing zeros.
938       Known.One.setBit(MinTZ);
939     }
940   } else if (MaxTZ < 0) {
941     // Poison Result
942     Known.setAllZero();
943   }
944 
945   // In the KnownBits exhaustive tests, we have poison inputs for exact values
946   // a LOT. If we have a conflict, just return all zeros.
947   if (Known.hasConflict())
948     Known.setAllZero();
949 
950   return Known;
951 }
952 
// Known bits of the signed quotient LHS / RHS. If \p Exact is set, the
// division is assumed to leave no remainder, which lets divComputeLowBit
// derive low bits. This function derives the high (sign-extension) bits from
// the extreme quotient that the operands' known bits allow.
KnownBits KnownBits::sdiv(const KnownBits &LHS, const KnownBits &RHS,
                          bool Exact) {
  // Equivalent of `udiv`. We must have caught this before it was folded.
  if (LHS.isNonNegative() && RHS.isNonNegative())
    return udiv(LHS, RHS, Exact);

  unsigned BitWidth = LHS.getBitWidth();
  KnownBits Known(BitWidth);

  if (LHS.isZero() || RHS.isZero()) {
    // Result is either known Zero or UB. Return Zero either way.
    // Checking this earlier saves us a lot of special cases later on.
    Known.setAllZero();
    return Known;
  }

  // Res, when set, is the extreme quotient: the largest one if the result is
  // non-negative, the most negative one if the result is negative. Its
  // leading zeros/ones are shared by every possible quotient.
  std::optional<APInt> Res;
  if (LHS.isNegative() && RHS.isNegative()) {
    // Result non-negative.
    // Largest quotient: most negative numerator divided by the negative
    // denominator closest to zero.
    APInt Denom = RHS.getSignedMaxValue();
    APInt Num = LHS.getSignedMinValue();
    // INT_MIN/-1 would be a poison result (impossible). Estimate the division
    // as signed max (we will only set sign bit in the result).
    Res = (Num.isMinSignedValue() && Denom.isAllOnes())
              ? APInt::getSignedMaxValue(BitWidth)
              : Num.sdiv(Denom);
  } else if (LHS.isNegative() && RHS.isNonNegative()) {
    // Result is negative if Exact OR -LHS u>= RHS.
    // (-LHS.getSignedMaxValue() is the smallest possible |LHS|. If it is at
    // least the largest RHS, then |quotient| >= 1.)
    if (Exact || (-LHS.getSignedMaxValue()).uge(RHS.getSignedMaxValue())) {
      // Most negative quotient: most negative numerator divided by the
      // smallest denominator. A possibly-zero denominator is treated like 1,
      // because division by zero is UB.
      APInt Denom = RHS.getSignedMinValue();
      APInt Num = LHS.getSignedMinValue();
      Res = Denom.isZero() ? Num : Num.sdiv(Denom);
    }
  } else if (LHS.isStrictlyPositive() && RHS.isNegative()) {
    // Result is negative if Exact OR LHS u>= -RHS.
    // (-RHS.getSignedMinValue() is the largest possible |RHS|.)
    if (Exact || LHS.getSignedMinValue().uge(-RHS.getSignedMinValue())) {
      // Most negative quotient: largest numerator divided by the negative
      // denominator closest to zero.
      APInt Denom = RHS.getSignedMaxValue();
      APInt Num = LHS.getSignedMaxValue();
      Res = Num.sdiv(Denom);
    }
  }

  // Transfer the extreme quotient's sign-extension bits to the result.
  if (Res) {
    if (Res->isNonNegative()) {
      unsigned LeadZ = Res->countLeadingZeros();
      Known.Zero.setHighBits(LeadZ);
    } else {
      unsigned LeadO = Res->countLeadingOnes();
      Known.One.setHighBits(LeadO);
    }
  }

  Known = divComputeLowBit(Known, LHS, RHS, Exact);
  return Known;
}
1008 
1009 KnownBits KnownBits::udiv(const KnownBits &LHS, const KnownBits &RHS,
1010                           bool Exact) {
1011   unsigned BitWidth = LHS.getBitWidth();
1012   KnownBits Known(BitWidth);
1013 
1014   if (LHS.isZero() || RHS.isZero()) {
1015     // Result is either known Zero or UB. Return Zero either way.
1016     // Checking this earlier saves us a lot of special cases later on.
1017     Known.setAllZero();
1018     return Known;
1019   }
1020 
1021   // We can figure out the minimum number of upper zero bits by doing
1022   // MaxNumerator / MinDenominator. If the Numerator gets smaller or Denominator
1023   // gets larger, the number of upper zero bits increases.
1024   APInt MinDenom = RHS.getMinValue();
1025   APInt MaxNum = LHS.getMaxValue();
1026   APInt MaxRes = MinDenom.isZero() ? MaxNum : MaxNum.udiv(MinDenom);
1027 
1028   unsigned LeadZ = MaxRes.countLeadingZeros();
1029 
1030   Known.Zero.setHighBits(LeadZ);
1031   Known = divComputeLowBit(Known, LHS, RHS, Exact);
1032 
1033   return Known;
1034 }
1035 
1036 KnownBits KnownBits::remGetLowBits(const KnownBits &LHS, const KnownBits &RHS) {
1037   unsigned BitWidth = LHS.getBitWidth();
1038   if (!RHS.isZero() && RHS.Zero[0]) {
1039     // rem X, Y where Y[0:N] is zero will preserve X[0:N] in the result.
1040     unsigned RHSZeros = RHS.countMinTrailingZeros();
1041     APInt Mask = APInt::getLowBitsSet(BitWidth, RHSZeros);
1042     APInt OnesMask = LHS.One & Mask;
1043     APInt ZerosMask = LHS.Zero & Mask;
1044     return KnownBits(ZerosMask, OnesMask);
1045   }
1046   return KnownBits(BitWidth);
1047 }
1048 
1049 KnownBits KnownBits::urem(const KnownBits &LHS, const KnownBits &RHS) {
1050   KnownBits Known = remGetLowBits(LHS, RHS);
1051   if (RHS.isConstant() && RHS.getConstant().isPowerOf2()) {
1052     // NB: Low bits set in `remGetLowBits`.
1053     APInt HighBits = ~(RHS.getConstant() - 1);
1054     Known.Zero |= HighBits;
1055     return Known;
1056   }
1057 
1058   // Since the result is less than or equal to either operand, any leading
1059   // zero bits in either operand must also exist in the result.
1060   uint32_t Leaders =
1061       std::max(LHS.countMinLeadingZeros(), RHS.countMinLeadingZeros());
1062   Known.Zero.setHighBits(Leaders);
1063   return Known;
1064 }
1065 
1066 KnownBits KnownBits::srem(const KnownBits &LHS, const KnownBits &RHS) {
1067   KnownBits Known = remGetLowBits(LHS, RHS);
1068   if (RHS.isConstant() && RHS.getConstant().isPowerOf2()) {
1069     // NB: Low bits are set in `remGetLowBits`.
1070     APInt LowBits = RHS.getConstant() - 1;
1071     // If the first operand is non-negative or has all low bits zero, then
1072     // the upper bits are all zero.
1073     if (LHS.isNonNegative() || LowBits.isSubsetOf(LHS.Zero))
1074       Known.Zero |= ~LowBits;
1075 
1076     // If the first operand is negative and not all low bits are zero, then
1077     // the upper bits are all one.
1078     if (LHS.isNegative() && LowBits.intersects(LHS.One))
1079       Known.One |= ~LowBits;
1080     return Known;
1081   }
1082 
1083   // The sign bit is the LHS's sign bit, except when the result of the
1084   // remainder is zero. The magnitude of the result should be less than or
1085   // equal to the magnitude of either operand.
1086   if (LHS.isNegative() && Known.isNonZero())
1087     Known.One.setHighBits(
1088         std::max(LHS.countMinLeadingOnes(), RHS.countMinSignBits()));
1089   else if (LHS.isNonNegative())
1090     Known.Zero.setHighBits(
1091         std::max(LHS.countMinLeadingZeros(), RHS.countMinSignBits()));
1092   return Known;
1093 }
1094 
1095 KnownBits &KnownBits::operator&=(const KnownBits &RHS) {
1096   // Result bit is 0 if either operand bit is 0.
1097   Zero |= RHS.Zero;
1098   // Result bit is 1 if both operand bits are 1.
1099   One &= RHS.One;
1100   return *this;
1101 }
1102 
1103 KnownBits &KnownBits::operator|=(const KnownBits &RHS) {
1104   // Result bit is 0 if both operand bits are 0.
1105   Zero &= RHS.Zero;
1106   // Result bit is 1 if either operand bit is 1.
1107   One |= RHS.One;
1108   return *this;
1109 }
1110 
1111 KnownBits &KnownBits::operator^=(const KnownBits &RHS) {
1112   // Result bit is 0 if both operand bits are 0 or both are 1.
1113   APInt Z = (Zero & RHS.Zero) | (One & RHS.One);
1114   // Result bit is 1 if one operand bit is 0 and the other is 1.
1115   One = (Zero & RHS.One) | (One & RHS.Zero);
1116   Zero = std::move(Z);
1117   return *this;
1118 }
1119 
1120 KnownBits KnownBits::blsi() const {
1121   unsigned BitWidth = getBitWidth();
1122   KnownBits Known(Zero, APInt(BitWidth, 0));
1123   unsigned Max = countMaxTrailingZeros();
1124   Known.Zero.setBitsFrom(std::min(Max + 1, BitWidth));
1125   unsigned Min = countMinTrailingZeros();
1126   if (Max == Min && Max < BitWidth)
1127     Known.One.setBit(Max);
1128   return Known;
1129 }
1130 
1131 KnownBits KnownBits::blsmsk() const {
1132   unsigned BitWidth = getBitWidth();
1133   KnownBits Known(BitWidth);
1134   unsigned Max = countMaxTrailingZeros();
1135   Known.Zero.setBitsFrom(std::min(Max + 1, BitWidth));
1136   unsigned Min = countMinTrailingZeros();
1137   Known.One.setLowBits(std::min(Min + 1, BitWidth));
1138   return Known;
1139 }
1140 
1141 void KnownBits::print(raw_ostream &OS) const {
1142   unsigned BitWidth = getBitWidth();
1143   for (unsigned I = 0; I < BitWidth; ++I) {
1144     unsigned N = BitWidth - I - 1;
1145     if (Zero[N] && One[N])
1146       OS << "!";
1147     else if (Zero[N])
1148       OS << "0";
1149     else if (One[N])
1150       OS << "1";
1151     else
1152       OS << "?";
1153   }
1154 }
1155 void KnownBits::dump() const {
1156   print(dbgs());
1157   dbgs() << "\n";
1158 }
1159