xref: /llvm-project/llvm/include/llvm/Support/KnownBits.h (revision 5cabf1505d2180083ae3cf34abd79924cbcdbdbf)
1 //===- llvm/Support/KnownBits.h - Stores known zeros/ones -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains a class for representing known zeros and ones used by
10 // computeKnownBits.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef LLVM_SUPPORT_KNOWNBITS_H
15 #define LLVM_SUPPORT_KNOWNBITS_H
16 
17 #include "llvm/ADT/APInt.h"
18 #include <optional>
19 
20 namespace llvm {
21 
22 // Struct for tracking the known zeros and ones of a value.
23 struct KnownBits {
24   APInt Zero;
25   APInt One;
26 
27 private:
28   // Internal constructor for creating a KnownBits from two APInts.
29   KnownBits(APInt Zero, APInt One)
30       : Zero(std::move(Zero)), One(std::move(One)) {}
31 
32   // Flip the range of values: [-0x80000000, 0x7FFFFFFF] <-> [0, 0xFFFFFFFF]
33   static KnownBits flipSignBit(const KnownBits &Val);
34 
35 public:
36   // Default construct Zero and One.
37   KnownBits() = default;
38 
39   /// Create a known bits object of BitWidth bits initialized to unknown.
40   KnownBits(unsigned BitWidth) : Zero(BitWidth, 0), One(BitWidth, 0) {}
41 
42   /// Get the bit width of this value.
43   unsigned getBitWidth() const {
44     assert(Zero.getBitWidth() == One.getBitWidth() &&
45            "Zero and One should have the same width!");
46     return Zero.getBitWidth();
47   }
48 
49   /// Returns true if there is conflicting information.
50   bool hasConflict() const { return Zero.intersects(One); }
51 
52   /// Returns true if we know the value of all bits.
53   bool isConstant() const {
54     return Zero.popcount() + One.popcount() == getBitWidth();
55   }
56 
57   /// Returns the value when all bits have a known value. This just returns One
58   /// with a protective assertion.
59   const APInt &getConstant() const {
60     assert(isConstant() && "Can only get value when all bits are known");
61     return One;
62   }
63 
64   /// Returns true if we don't know any bits.
65   bool isUnknown() const { return Zero.isZero() && One.isZero(); }
66 
67   /// Returns true if we don't know the sign bit.
68   bool isSignUnknown() const {
69     return !Zero.isSignBitSet() && !One.isSignBitSet();
70   }
71 
72   /// Resets the known state of all bits.
73   void resetAll() {
74     Zero.clearAllBits();
75     One.clearAllBits();
76   }
77 
78   /// Returns true if value is all zero.
79   bool isZero() const { return Zero.isAllOnes(); }
80 
81   /// Returns true if value is all one bits.
82   bool isAllOnes() const { return One.isAllOnes(); }
83 
84   /// Make all bits known to be zero and discard any previous information.
85   void setAllZero() {
86     Zero.setAllBits();
87     One.clearAllBits();
88   }
89 
90   /// Make all bits known to be one and discard any previous information.
91   void setAllOnes() {
92     Zero.clearAllBits();
93     One.setAllBits();
94   }
95 
96   /// Returns true if this value is known to be negative.
97   bool isNegative() const { return One.isSignBitSet(); }
98 
99   /// Returns true if this value is known to be non-negative.
100   bool isNonNegative() const { return Zero.isSignBitSet(); }
101 
102   /// Returns true if this value is known to be non-zero.
103   bool isNonZero() const { return !One.isZero(); }
104 
105   /// Returns true if this value is known to be positive.
106   bool isStrictlyPositive() const {
107     return Zero.isSignBitSet() && !One.isZero();
108   }
109 
110   /// Make this value negative.
111   void makeNegative() {
112     One.setSignBit();
113   }
114 
115   /// Make this value non-negative.
116   void makeNonNegative() {
117     Zero.setSignBit();
118   }
119 
120   /// Return the minimal unsigned value possible given these KnownBits.
121   APInt getMinValue() const {
122     // Assume that all bits that aren't known-ones are zeros.
123     return One;
124   }
125 
126   /// Return the minimal signed value possible given these KnownBits.
127   APInt getSignedMinValue() const {
128     // Assume that all bits that aren't known-ones are zeros.
129     APInt Min = One;
130     // Sign bit is unknown.
131     if (Zero.isSignBitClear())
132       Min.setSignBit();
133     return Min;
134   }
135 
136   /// Return the maximal unsigned value possible given these KnownBits.
137   APInt getMaxValue() const {
138     // Assume that all bits that aren't known-zeros are ones.
139     return ~Zero;
140   }
141 
142   /// Return the maximal signed value possible given these KnownBits.
143   APInt getSignedMaxValue() const {
144     // Assume that all bits that aren't known-zeros are ones.
145     APInt Max = ~Zero;
146     // Sign bit is unknown.
147     if (One.isSignBitClear())
148       Max.clearSignBit();
149     return Max;
150   }
151 
152   /// Return known bits for a truncation of the value we're tracking.
153   KnownBits trunc(unsigned BitWidth) const {
154     return KnownBits(Zero.trunc(BitWidth), One.trunc(BitWidth));
155   }
156 
157   /// Return known bits for an "any" extension of the value we're tracking,
158   /// where we don't know anything about the extended bits.
159   KnownBits anyext(unsigned BitWidth) const {
160     return KnownBits(Zero.zext(BitWidth), One.zext(BitWidth));
161   }
162 
163   /// Return known bits for a zero extension of the value we're tracking.
164   KnownBits zext(unsigned BitWidth) const {
165     unsigned OldBitWidth = getBitWidth();
166     APInt NewZero = Zero.zext(BitWidth);
167     NewZero.setBitsFrom(OldBitWidth);
168     return KnownBits(NewZero, One.zext(BitWidth));
169   }
170 
171   /// Return known bits for a sign extension of the value we're tracking.
172   KnownBits sext(unsigned BitWidth) const {
173     return KnownBits(Zero.sext(BitWidth), One.sext(BitWidth));
174   }
175 
176   /// Return known bits for an "any" extension or truncation of the value we're
177   /// tracking.
178   KnownBits anyextOrTrunc(unsigned BitWidth) const {
179     if (BitWidth > getBitWidth())
180       return anyext(BitWidth);
181     if (BitWidth < getBitWidth())
182       return trunc(BitWidth);
183     return *this;
184   }
185 
186   /// Return known bits for a zero extension or truncation of the value we're
187   /// tracking.
188   KnownBits zextOrTrunc(unsigned BitWidth) const {
189     if (BitWidth > getBitWidth())
190       return zext(BitWidth);
191     if (BitWidth < getBitWidth())
192       return trunc(BitWidth);
193     return *this;
194   }
195 
196   /// Return known bits for a sign extension or truncation of the value we're
197   /// tracking.
198   KnownBits sextOrTrunc(unsigned BitWidth) const {
199     if (BitWidth > getBitWidth())
200       return sext(BitWidth);
201     if (BitWidth < getBitWidth())
202       return trunc(BitWidth);
203     return *this;
204   }
205 
206   /// Return known bits for a in-register sign extension of the value we're
207   /// tracking.
208   KnownBits sextInReg(unsigned SrcBitWidth) const;
209 
210   /// Insert the bits from a smaller known bits starting at bitPosition.
211   void insertBits(const KnownBits &SubBits, unsigned BitPosition) {
212     Zero.insertBits(SubBits.Zero, BitPosition);
213     One.insertBits(SubBits.One, BitPosition);
214   }
215 
216   /// Return a subset of the known bits from [bitPosition,bitPosition+numBits).
217   KnownBits extractBits(unsigned NumBits, unsigned BitPosition) const {
218     return KnownBits(Zero.extractBits(NumBits, BitPosition),
219                      One.extractBits(NumBits, BitPosition));
220   }
221 
222   /// Concatenate the bits from \p Lo onto the bottom of *this.  This is
223   /// equivalent to:
224   ///   (this->zext(NewWidth) << Lo.getBitWidth()) | Lo.zext(NewWidth)
225   KnownBits concat(const KnownBits &Lo) const {
226     return KnownBits(Zero.concat(Lo.Zero), One.concat(Lo.One));
227   }
228 
229   /// Return KnownBits based on this, but updated given that the underlying
230   /// value is known to be greater than or equal to Val.
231   KnownBits makeGE(const APInt &Val) const;
232 
233   /// Returns the minimum number of trailing zero bits.
234   unsigned countMinTrailingZeros() const { return Zero.countr_one(); }
235 
236   /// Returns the minimum number of trailing one bits.
237   unsigned countMinTrailingOnes() const { return One.countr_one(); }
238 
239   /// Returns the minimum number of leading zero bits.
240   unsigned countMinLeadingZeros() const { return Zero.countl_one(); }
241 
242   /// Returns the minimum number of leading one bits.
243   unsigned countMinLeadingOnes() const { return One.countl_one(); }
244 
245   /// Returns the number of times the sign bit is replicated into the other
246   /// bits.
247   unsigned countMinSignBits() const {
248     if (isNonNegative())
249       return countMinLeadingZeros();
250     if (isNegative())
251       return countMinLeadingOnes();
252     // Every value has at least 1 sign bit.
253     return 1;
254   }
255 
256   /// Returns the maximum number of bits needed to represent all possible
257   /// signed values with these known bits. This is the inverse of the minimum
258   /// number of known sign bits. Examples for bitwidth 5:
259   /// 110?? --> 4
260   /// 0000? --> 2
261   unsigned countMaxSignificantBits() const {
262     return getBitWidth() - countMinSignBits() + 1;
263   }
264 
265   /// Returns the maximum number of trailing zero bits possible.
266   unsigned countMaxTrailingZeros() const { return One.countr_zero(); }
267 
268   /// Returns the maximum number of trailing one bits possible.
269   unsigned countMaxTrailingOnes() const { return Zero.countr_zero(); }
270 
271   /// Returns the maximum number of leading zero bits possible.
272   unsigned countMaxLeadingZeros() const { return One.countl_zero(); }
273 
274   /// Returns the maximum number of leading one bits possible.
275   unsigned countMaxLeadingOnes() const { return Zero.countl_zero(); }
276 
277   /// Returns the number of bits known to be one.
278   unsigned countMinPopulation() const { return One.popcount(); }
279 
280   /// Returns the maximum number of bits that could be one.
281   unsigned countMaxPopulation() const {
282     return getBitWidth() - Zero.popcount();
283   }
284 
285   /// Returns the maximum number of bits needed to represent all possible
286   /// unsigned values with these known bits. This is the inverse of the
287   /// minimum number of leading zeros.
288   unsigned countMaxActiveBits() const {
289     return getBitWidth() - countMinLeadingZeros();
290   }
291 
292   /// Create known bits from a known constant.
293   static KnownBits makeConstant(const APInt &C) {
294     return KnownBits(~C, C);
295   }
296 
297   /// Returns KnownBits information that is known to be true for both this and
298   /// RHS.
299   ///
300   /// When an operation is known to return one of its operands, this can be used
301   /// to combine information about the known bits of the operands to get the
302   /// information that must be true about the result.
303   KnownBits intersectWith(const KnownBits &RHS) const {
304     return KnownBits(Zero & RHS.Zero, One & RHS.One);
305   }
306 
307   /// Returns KnownBits information that is known to be true for either this or
308   /// RHS or both.
309   ///
310   /// This can be used to combine different sources of information about the
311   /// known bits of a single value, e.g. information about the low bits and the
312   /// high bits of the result of a multiplication.
313   KnownBits unionWith(const KnownBits &RHS) const {
314     return KnownBits(Zero | RHS.Zero, One | RHS.One);
315   }
316 
317   /// Return true if LHS and RHS have no common bits set.
318   static bool haveNoCommonBitsSet(const KnownBits &LHS, const KnownBits &RHS) {
319     return (LHS.Zero | RHS.Zero).isAllOnes();
320   }
321 
322   /// Compute known bits resulting from adding LHS, RHS and a 1-bit Carry.
323   static KnownBits computeForAddCarry(
324       const KnownBits &LHS, const KnownBits &RHS, const KnownBits &Carry);
325 
326   /// Compute known bits resulting from adding LHS and RHS.
327   static KnownBits computeForAddSub(bool Add, bool NSW, bool NUW,
328                                     const KnownBits &LHS, const KnownBits &RHS);
329 
330   /// Compute known bits results from subtracting RHS from LHS with 1-bit
331   /// Borrow.
332   static KnownBits computeForSubBorrow(const KnownBits &LHS, KnownBits RHS,
333                                        const KnownBits &Borrow);
334 
335   /// Compute knownbits resulting from addition of LHS and RHS.
336   static KnownBits add(const KnownBits &LHS, const KnownBits &RHS,
337                        bool NSW = false, bool NUW = false) {
338     return computeForAddSub(/*Add=*/true, NSW, NUW, LHS, RHS);
339   }
340 
341   /// Compute knownbits resulting from subtraction of LHS and RHS.
342   static KnownBits sub(const KnownBits &LHS, const KnownBits &RHS,
343                        bool NSW = false, bool NUW = false) {
344     return computeForAddSub(/*Add=*/false, NSW, NUW, LHS, RHS);
345   }
346 
347   /// Compute knownbits resulting from llvm.sadd.sat(LHS, RHS)
348   static KnownBits sadd_sat(const KnownBits &LHS, const KnownBits &RHS);
349 
350   /// Compute knownbits resulting from llvm.uadd.sat(LHS, RHS)
351   static KnownBits uadd_sat(const KnownBits &LHS, const KnownBits &RHS);
352 
353   /// Compute knownbits resulting from llvm.ssub.sat(LHS, RHS)
354   static KnownBits ssub_sat(const KnownBits &LHS, const KnownBits &RHS);
355 
356   /// Compute knownbits resulting from llvm.usub.sat(LHS, RHS)
357   static KnownBits usub_sat(const KnownBits &LHS, const KnownBits &RHS);
358 
359   /// Compute knownbits resulting from APIntOps::avgFloorS
360   static KnownBits avgFloorS(const KnownBits &LHS, const KnownBits &RHS);
361 
362   /// Compute knownbits resulting from APIntOps::avgFloorU
363   static KnownBits avgFloorU(const KnownBits &LHS, const KnownBits &RHS);
364 
365   /// Compute knownbits resulting from APIntOps::avgCeilS
366   static KnownBits avgCeilS(const KnownBits &LHS, const KnownBits &RHS);
367 
368   /// Compute knownbits resulting from APIntOps::avgCeilU
369   static KnownBits avgCeilU(const KnownBits &LHS, const KnownBits &RHS);
370 
371   /// Compute known bits resulting from multiplying LHS and RHS.
372   static KnownBits mul(const KnownBits &LHS, const KnownBits &RHS,
373                        bool NoUndefSelfMultiply = false);
374 
375   /// Compute known bits from sign-extended multiply-hi.
376   static KnownBits mulhs(const KnownBits &LHS, const KnownBits &RHS);
377 
378   /// Compute known bits from zero-extended multiply-hi.
379   static KnownBits mulhu(const KnownBits &LHS, const KnownBits &RHS);
380 
381   /// Compute known bits for sdiv(LHS, RHS).
382   static KnownBits sdiv(const KnownBits &LHS, const KnownBits &RHS,
383                         bool Exact = false);
384 
385   /// Compute known bits for udiv(LHS, RHS).
386   static KnownBits udiv(const KnownBits &LHS, const KnownBits &RHS,
387                         bool Exact = false);
388 
389   /// Compute known bits for urem(LHS, RHS).
390   static KnownBits urem(const KnownBits &LHS, const KnownBits &RHS);
391 
392   /// Compute known bits for srem(LHS, RHS).
393   static KnownBits srem(const KnownBits &LHS, const KnownBits &RHS);
394 
395   /// Compute known bits for umax(LHS, RHS).
396   static KnownBits umax(const KnownBits &LHS, const KnownBits &RHS);
397 
398   /// Compute known bits for umin(LHS, RHS).
399   static KnownBits umin(const KnownBits &LHS, const KnownBits &RHS);
400 
401   /// Compute known bits for smax(LHS, RHS).
402   static KnownBits smax(const KnownBits &LHS, const KnownBits &RHS);
403 
404   /// Compute known bits for smin(LHS, RHS).
405   static KnownBits smin(const KnownBits &LHS, const KnownBits &RHS);
406 
407   /// Compute known bits for abdu(LHS, RHS).
408   static KnownBits abdu(const KnownBits &LHS, const KnownBits &RHS);
409 
410   /// Compute known bits for abds(LHS, RHS).
411   static KnownBits abds(KnownBits LHS, KnownBits RHS);
412 
413   /// Compute known bits for shl(LHS, RHS).
414   /// NOTE: RHS (shift amount) bitwidth doesn't need to be the same as LHS.
415   static KnownBits shl(const KnownBits &LHS, const KnownBits &RHS,
416                        bool NUW = false, bool NSW = false,
417                        bool ShAmtNonZero = false);
418 
419   /// Compute known bits for lshr(LHS, RHS).
420   /// NOTE: RHS (shift amount) bitwidth doesn't need to be the same as LHS.
421   static KnownBits lshr(const KnownBits &LHS, const KnownBits &RHS,
422                         bool ShAmtNonZero = false, bool Exact = false);
423 
424   /// Compute known bits for ashr(LHS, RHS).
425   /// NOTE: RHS (shift amount) bitwidth doesn't need to be the same as LHS.
426   static KnownBits ashr(const KnownBits &LHS, const KnownBits &RHS,
427                         bool ShAmtNonZero = false, bool Exact = false);
428 
429   /// Determine if these known bits always give the same ICMP_EQ result.
430   static std::optional<bool> eq(const KnownBits &LHS, const KnownBits &RHS);
431 
432   /// Determine if these known bits always give the same ICMP_NE result.
433   static std::optional<bool> ne(const KnownBits &LHS, const KnownBits &RHS);
434 
435   /// Determine if these known bits always give the same ICMP_UGT result.
436   static std::optional<bool> ugt(const KnownBits &LHS, const KnownBits &RHS);
437 
438   /// Determine if these known bits always give the same ICMP_UGE result.
439   static std::optional<bool> uge(const KnownBits &LHS, const KnownBits &RHS);
440 
441   /// Determine if these known bits always give the same ICMP_ULT result.
442   static std::optional<bool> ult(const KnownBits &LHS, const KnownBits &RHS);
443 
444   /// Determine if these known bits always give the same ICMP_ULE result.
445   static std::optional<bool> ule(const KnownBits &LHS, const KnownBits &RHS);
446 
447   /// Determine if these known bits always give the same ICMP_SGT result.
448   static std::optional<bool> sgt(const KnownBits &LHS, const KnownBits &RHS);
449 
450   /// Determine if these known bits always give the same ICMP_SGE result.
451   static std::optional<bool> sge(const KnownBits &LHS, const KnownBits &RHS);
452 
453   /// Determine if these known bits always give the same ICMP_SLT result.
454   static std::optional<bool> slt(const KnownBits &LHS, const KnownBits &RHS);
455 
456   /// Determine if these known bits always give the same ICMP_SLE result.
457   static std::optional<bool> sle(const KnownBits &LHS, const KnownBits &RHS);
458 
459   /// Update known bits based on ANDing with RHS.
460   KnownBits &operator&=(const KnownBits &RHS);
461 
462   /// Update known bits based on ORing with RHS.
463   KnownBits &operator|=(const KnownBits &RHS);
464 
465   /// Update known bits based on XORing with RHS.
466   KnownBits &operator^=(const KnownBits &RHS);
467 
468   /// Compute known bits for the absolute value.
469   KnownBits abs(bool IntMinIsPoison = false) const;
470 
471   KnownBits byteSwap() const {
472     return KnownBits(Zero.byteSwap(), One.byteSwap());
473   }
474 
475   KnownBits reverseBits() const {
476     return KnownBits(Zero.reverseBits(), One.reverseBits());
477   }
478 
479   /// Compute known bits for X & -X, which has only the lowest bit set of X set.
480   /// The name comes from the X86 BMI instruction
481   KnownBits blsi() const;
482 
483   /// Compute known bits for X ^ (X - 1), which has all bits up to and including
484   /// the lowest set bit of X set. The name comes from the X86 BMI instruction.
485   KnownBits blsmsk() const;
486 
487   bool operator==(const KnownBits &Other) const {
488     return Zero == Other.Zero && One == Other.One;
489   }
490 
491   bool operator!=(const KnownBits &Other) const { return !(*this == Other); }
492 
493   void print(raw_ostream &OS) const;
494   void dump() const;
495 
496 private:
497   // Internal helper for getting the initial KnownBits for an `srem` or `urem`
498   // operation with the low-bits set.
499   static KnownBits remGetLowBits(const KnownBits &LHS, const KnownBits &RHS);
500 };
501 
502 inline KnownBits operator&(KnownBits LHS, const KnownBits &RHS) {
503   LHS &= RHS;
504   return LHS;
505 }
506 
507 inline KnownBits operator&(const KnownBits &LHS, KnownBits &&RHS) {
508   RHS &= LHS;
509   return std::move(RHS);
510 }
511 
512 inline KnownBits operator|(KnownBits LHS, const KnownBits &RHS) {
513   LHS |= RHS;
514   return LHS;
515 }
516 
517 inline KnownBits operator|(const KnownBits &LHS, KnownBits &&RHS) {
518   RHS |= LHS;
519   return std::move(RHS);
520 }
521 
522 inline KnownBits operator^(KnownBits LHS, const KnownBits &RHS) {
523   LHS ^= RHS;
524   return LHS;
525 }
526 
527 inline KnownBits operator^(const KnownBits &LHS, KnownBits &&RHS) {
528   RHS ^= LHS;
529   return std::move(RHS);
530 }
531 
532 inline raw_ostream &operator<<(raw_ostream &OS, const KnownBits &Known) {
533   Known.print(OS);
534   return OS;
535 }
536 
537 } // end namespace llvm
538 
539 #endif
540