xref: /llvm-project/llvm/unittests/Support/KnownBitsTest.cpp (revision d81db0e5f5b1404ff4813af3050d671528ad45cc)
1 //===- llvm/unittest/Support/KnownBitsTest.cpp - KnownBits tests ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements unit tests for KnownBits functions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "llvm/ADT/ArrayRef.h"
14 #include "llvm/Support/KnownBits.h"
15 #include "KnownBitsTest.h"
16 #include "gtest/gtest.h"
17 
18 using namespace llvm;
19 
20 using UnaryBitsFn = llvm::function_ref<KnownBits(const KnownBits &)>;
21 using UnaryIntFn = llvm::function_ref<std::optional<APInt>(const APInt &)>;
22 using UnaryCheckFn = llvm::function_ref<bool(const KnownBits &)>;
23 
24 using BinaryBitsFn =
25     llvm::function_ref<KnownBits(const KnownBits &, const KnownBits &)>;
26 using BinaryIntFn =
27     llvm::function_ref<std::optional<APInt>(const APInt &, const APInt &)>;
28 using BinaryCheckFn =
29     llvm::function_ref<bool(const KnownBits &, const KnownBits &)>;
30 
31 static bool checkOptimalityUnary(const KnownBits &) { return true; }
32 static bool checkCorrectnessOnlyUnary(const KnownBits &) { return false; }
33 static bool checkOptimalityBinary(const KnownBits &, const KnownBits &) {
34   return true;
35 }
36 static bool checkCorrectnessOnlyBinary(const KnownBits &, const KnownBits &) {
37   return false;
38 }
39 
40 static testing::AssertionResult isCorrect(const KnownBits &Exact,
41                                           const KnownBits &Computed,
42                                           ArrayRef<KnownBits> Inputs) {
43   if (Computed.Zero.isSubsetOf(Exact.Zero) &&
44       Computed.One.isSubsetOf(Exact.One))
45     return testing::AssertionSuccess();
46 
47   testing::AssertionResult Result = testing::AssertionFailure();
48   Result << "Inputs = ";
49   for (const KnownBits &Input : Inputs)
50     Result << Input << ", ";
51   Result << "Computed = " << Computed << ", Exact = " << Exact;
52   return Result;
53 }
54 
55 static testing::AssertionResult isOptimal(const KnownBits &Exact,
56                                           const KnownBits &Computed,
57                                           ArrayRef<KnownBits> Inputs) {
58   if (Computed == Exact)
59     return testing::AssertionSuccess();
60 
61   testing::AssertionResult Result = testing::AssertionFailure();
62   Result << "Inputs = ";
63   for (const KnownBits &Input : Inputs)
64     Result << Input << ", ";
65   Result << "Computed = " << Computed << ", Exact = " << Exact;
66   return Result;
67 }
68 
69 static void
70 testUnaryOpExhaustive(UnaryBitsFn BitsFn, UnaryIntFn IntFn,
71                       UnaryCheckFn CheckOptimalityFn = checkOptimalityUnary) {
72   for (unsigned Bits : {1, 4}) {
73     ForeachKnownBits(Bits, [&](const KnownBits &Known) {
74       KnownBits Computed = BitsFn(Known);
75       KnownBits Exact(Bits);
76       Exact.Zero.setAllBits();
77       Exact.One.setAllBits();
78 
79       ForeachNumInKnownBits(Known, [&](const APInt &N) {
80         if (std::optional<APInt> Res = IntFn(N)) {
81           Exact.One &= *Res;
82           Exact.Zero &= ~*Res;
83         }
84       });
85 
86       EXPECT_TRUE(!Computed.hasConflict());
87       EXPECT_TRUE(isCorrect(Exact, Computed, Known));
88       // We generally don't want to return conflicting known bits, even if it is
89       // legal for always poison results.
90       if (CheckOptimalityFn(Known) && !Exact.hasConflict()) {
91         EXPECT_TRUE(isOptimal(Exact, Computed, Known));
92       }
93     });
94   }
95 }
96 
97 static void
98 testBinaryOpExhaustive(BinaryBitsFn BitsFn, BinaryIntFn IntFn,
99                        BinaryCheckFn CheckOptimalityFn = checkOptimalityBinary,
100                        bool RefinePoisonToZero = false) {
101   for (unsigned Bits : {1, 4}) {
102     ForeachKnownBits(Bits, [&](const KnownBits &Known1) {
103       ForeachKnownBits(Bits, [&](const KnownBits &Known2) {
104         KnownBits Computed = BitsFn(Known1, Known2);
105         KnownBits Exact(Bits);
106         Exact.Zero.setAllBits();
107         Exact.One.setAllBits();
108 
109         ForeachNumInKnownBits(Known1, [&](const APInt &N1) {
110           ForeachNumInKnownBits(Known2, [&](const APInt &N2) {
111             if (std::optional<APInt> Res = IntFn(N1, N2)) {
112               Exact.One &= *Res;
113               Exact.Zero &= ~*Res;
114             }
115           });
116         });
117 
118         EXPECT_TRUE(!Computed.hasConflict());
119         EXPECT_TRUE(isCorrect(Exact, Computed, {Known1, Known2}));
120         // We generally don't want to return conflicting known bits, even if it
121         // is legal for always poison results.
122         if (CheckOptimalityFn(Known1, Known2) && !Exact.hasConflict()) {
123           EXPECT_TRUE(isOptimal(Exact, Computed, {Known1, Known2}));
124         }
125         // In some cases we choose to return zero if the result is always
126         // poison.
127         if (RefinePoisonToZero && Exact.hasConflict()) {
128           EXPECT_TRUE(Computed.isZero());
129         }
130       });
131     });
132   }
133 }
134 
135 namespace {
136 
137 TEST(KnownBitsTest, AddCarryExhaustive) {
138   unsigned Bits = 4;
139   ForeachKnownBits(Bits, [&](const KnownBits &Known1) {
140     ForeachKnownBits(Bits, [&](const KnownBits &Known2) {
141       ForeachKnownBits(1, [&](const KnownBits &KnownCarry) {
142         // Explicitly compute known bits of the addition by trying all
143         // possibilities.
144         KnownBits Known(Bits);
145         Known.Zero.setAllBits();
146         Known.One.setAllBits();
147         ForeachNumInKnownBits(Known1, [&](const APInt &N1) {
148           ForeachNumInKnownBits(Known2, [&](const APInt &N2) {
149             ForeachNumInKnownBits(KnownCarry, [&](const APInt &Carry) {
150               APInt Add = N1 + N2;
151               if (Carry.getBoolValue())
152                 ++Add;
153 
154               Known.One &= Add;
155               Known.Zero &= ~Add;
156             });
157           });
158         });
159 
160         KnownBits KnownComputed =
161             KnownBits::computeForAddCarry(Known1, Known2, KnownCarry);
162         EXPECT_EQ(Known, KnownComputed);
163       });
164     });
165   });
166 }
167 
168 static void TestAddSubExhaustive(bool IsAdd) {
169   unsigned Bits = 4;
170   ForeachKnownBits(Bits, [&](const KnownBits &Known1) {
171     ForeachKnownBits(Bits, [&](const KnownBits &Known2) {
172       KnownBits Known(Bits), KnownNSW(Bits), KnownNUW(Bits),
173           KnownNSWAndNUW(Bits);
174       Known.Zero.setAllBits();
175       Known.One.setAllBits();
176       KnownNSW.Zero.setAllBits();
177       KnownNSW.One.setAllBits();
178       KnownNUW.Zero.setAllBits();
179       KnownNUW.One.setAllBits();
180       KnownNSWAndNUW.Zero.setAllBits();
181       KnownNSWAndNUW.One.setAllBits();
182 
183       ForeachNumInKnownBits(Known1, [&](const APInt &N1) {
184         ForeachNumInKnownBits(Known2, [&](const APInt &N2) {
185           bool SignedOverflow;
186           bool UnsignedOverflow;
187           APInt Res;
188           if (IsAdd) {
189             Res = N1.uadd_ov(N2, UnsignedOverflow);
190             Res = N1.sadd_ov(N2, SignedOverflow);
191           } else {
192             Res = N1.usub_ov(N2, UnsignedOverflow);
193             Res = N1.ssub_ov(N2, SignedOverflow);
194           }
195 
196           Known.One &= Res;
197           Known.Zero &= ~Res;
198 
199           if (!SignedOverflow) {
200             KnownNSW.One &= Res;
201             KnownNSW.Zero &= ~Res;
202           }
203 
204           if (!UnsignedOverflow) {
205             KnownNUW.One &= Res;
206             KnownNUW.Zero &= ~Res;
207           }
208 
209           if (!UnsignedOverflow && !SignedOverflow) {
210             KnownNSWAndNUW.One &= Res;
211             KnownNSWAndNUW.Zero &= ~Res;
212           }
213         });
214       });
215 
216       KnownBits KnownComputed = KnownBits::computeForAddSub(
217           IsAdd, /*NSW=*/false, /*NUW=*/false, Known1, Known2);
218       EXPECT_TRUE(isOptimal(Known, KnownComputed, {Known1, Known2}));
219 
220       KnownBits KnownNSWComputed = KnownBits::computeForAddSub(
221           IsAdd, /*NSW=*/true, /*NUW=*/false, Known1, Known2);
222       if (!KnownNSW.hasConflict())
223         EXPECT_TRUE(isOptimal(KnownNSW, KnownNSWComputed, {Known1, Known2}));
224 
225       KnownBits KnownNUWComputed = KnownBits::computeForAddSub(
226           IsAdd, /*NSW=*/false, /*NUW=*/true, Known1, Known2);
227       if (!KnownNUW.hasConflict())
228         EXPECT_TRUE(isOptimal(KnownNUW, KnownNUWComputed, {Known1, Known2}));
229 
230       KnownBits KnownNSWAndNUWComputed = KnownBits::computeForAddSub(
231           IsAdd, /*NSW=*/true, /*NUW=*/true, Known1, Known2);
232       if (!KnownNSWAndNUW.hasConflict())
233         EXPECT_TRUE(isOptimal(KnownNSWAndNUW, KnownNSWAndNUWComputed,
234                               {Known1, Known2}));
235     });
236   });
237 }
238 
239 TEST(KnownBitsTest, AddSubExhaustive) {
240   TestAddSubExhaustive(true);
241   TestAddSubExhaustive(false);
242 }
243 
244 TEST(KnownBitsTest, SubBorrowExhaustive) {
245   unsigned Bits = 4;
246   ForeachKnownBits(Bits, [&](const KnownBits &Known1) {
247     ForeachKnownBits(Bits, [&](const KnownBits &Known2) {
248       ForeachKnownBits(1, [&](const KnownBits &KnownBorrow) {
249         // Explicitly compute known bits of the subtraction by trying all
250         // possibilities.
251         KnownBits Known(Bits);
252         Known.Zero.setAllBits();
253         Known.One.setAllBits();
254         ForeachNumInKnownBits(Known1, [&](const APInt &N1) {
255           ForeachNumInKnownBits(Known2, [&](const APInt &N2) {
256             ForeachNumInKnownBits(KnownBorrow, [&](const APInt &Borrow) {
257               APInt Sub = N1 - N2;
258               if (Borrow.getBoolValue())
259                 --Sub;
260 
261               Known.One &= Sub;
262               Known.Zero &= ~Sub;
263             });
264           });
265         });
266 
267         KnownBits KnownComputed =
268             KnownBits::computeForSubBorrow(Known1, Known2, KnownBorrow);
269         EXPECT_EQ(Known, KnownComputed);
270       });
271     });
272   });
273 }
274 
275 TEST(KnownBitsTest, SignBitUnknown) {
276   KnownBits Known(2);
277   EXPECT_TRUE(Known.isSignUnknown());
278   Known.Zero.setBit(0);
279   EXPECT_TRUE(Known.isSignUnknown());
280   Known.Zero.setBit(1);
281   EXPECT_FALSE(Known.isSignUnknown());
282   Known.Zero.clearBit(0);
283   EXPECT_FALSE(Known.isSignUnknown());
284   Known.Zero.clearBit(1);
285   EXPECT_TRUE(Known.isSignUnknown());
286 
287   Known.One.setBit(0);
288   EXPECT_TRUE(Known.isSignUnknown());
289   Known.One.setBit(1);
290   EXPECT_FALSE(Known.isSignUnknown());
291   Known.One.clearBit(0);
292   EXPECT_FALSE(Known.isSignUnknown());
293   Known.One.clearBit(1);
294   EXPECT_TRUE(Known.isSignUnknown());
295 }
296 
297 TEST(KnownBitsTest, AbsDiffSpecialCase) {
298   // There are 2 implementation of absdiff - both are currently needed to cover
299   // extra cases.
300   KnownBits LHS, RHS, Res;
301 
302   // absdiff(LHS,RHS) = sub(umax(LHS,RHS), umin(LHS,RHS)).
303   // Actual: false (Inputs = 1011, 101?, Computed = 000?, Exact = 000?)
304   LHS.One = APInt(4, 0b1011);
305   RHS.One = APInt(4, 0b1010);
306   LHS.Zero = APInt(4, 0b0100);
307   RHS.Zero = APInt(4, 0b0100);
308   Res = KnownBits::absdiff(LHS, RHS);
309   EXPECT_EQ(0b0000ul, Res.One.getZExtValue());
310   EXPECT_EQ(0b1110ul, Res.Zero.getZExtValue());
311 
312   // find the common bits between sub(LHS,RHS) and sub(RHS,LHS).
313   // Actual: false (Inputs = ???1, 1000, Computed = ???1, Exact = 0??1)
314   LHS.One = APInt(4, 0b0001);
315   RHS.One = APInt(4, 0b1000);
316   LHS.Zero = APInt(4, 0b0000);
317   RHS.Zero = APInt(4, 0b0111);
318   Res = KnownBits::absdiff(LHS, RHS);
319   EXPECT_EQ(0b0001ul, Res.One.getZExtValue());
320   EXPECT_EQ(0b0000ul, Res.Zero.getZExtValue());
321 }
322 
323 TEST(KnownBitsTest, BinaryExhaustive) {
324   testBinaryOpExhaustive(
325       [](const KnownBits &Known1, const KnownBits &Known2) {
326         return Known1 & Known2;
327       },
328       [](const APInt &N1, const APInt &N2) { return N1 & N2; });
329   testBinaryOpExhaustive(
330       [](const KnownBits &Known1, const KnownBits &Known2) {
331         return Known1 | Known2;
332       },
333       [](const APInt &N1, const APInt &N2) { return N1 | N2; });
334   testBinaryOpExhaustive(
335       [](const KnownBits &Known1, const KnownBits &Known2) {
336         return Known1 ^ Known2;
337       },
338       [](const APInt &N1, const APInt &N2) { return N1 ^ N2; });
339 
340   testBinaryOpExhaustive(
341       [](const KnownBits &Known1, const KnownBits &Known2) {
342         return KnownBits::umax(Known1, Known2);
343       },
344       [](const APInt &N1, const APInt &N2) { return APIntOps::umax(N1, N2); });
345   testBinaryOpExhaustive(
346       [](const KnownBits &Known1, const KnownBits &Known2) {
347         return KnownBits::umin(Known1, Known2);
348       },
349       [](const APInt &N1, const APInt &N2) { return APIntOps::umin(N1, N2); });
350   testBinaryOpExhaustive(
351       [](const KnownBits &Known1, const KnownBits &Known2) {
352         return KnownBits::smax(Known1, Known2);
353       },
354       [](const APInt &N1, const APInt &N2) { return APIntOps::smax(N1, N2); });
355   testBinaryOpExhaustive(
356       [](const KnownBits &Known1, const KnownBits &Known2) {
357         return KnownBits::smin(Known1, Known2);
358       },
359       [](const APInt &N1, const APInt &N2) { return APIntOps::smin(N1, N2); });
360   testBinaryOpExhaustive(
361       [](const KnownBits &Known1, const KnownBits &Known2) {
362         return KnownBits::absdiff(Known1, Known2);
363       },
364       [](const APInt &N1, const APInt &N2) {
365         return APIntOps::absdiff(N1, N2);
366       },
367       checkCorrectnessOnlyBinary);
368   testBinaryOpExhaustive(
369       [](const KnownBits &Known1, const KnownBits &Known2) {
370         return KnownBits::udiv(Known1, Known2);
371       },
372       [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
373         if (N2.isZero())
374           return std::nullopt;
375         return N1.udiv(N2);
376       },
377       checkCorrectnessOnlyBinary);
378   testBinaryOpExhaustive(
379       [](const KnownBits &Known1, const KnownBits &Known2) {
380         return KnownBits::udiv(Known1, Known2, /*Exact*/ true);
381       },
382       [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
383         if (N2.isZero() || !N1.urem(N2).isZero())
384           return std::nullopt;
385         return N1.udiv(N2);
386       },
387       checkCorrectnessOnlyBinary);
388   testBinaryOpExhaustive(
389       [](const KnownBits &Known1, const KnownBits &Known2) {
390         return KnownBits::sdiv(Known1, Known2);
391       },
392       [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
393         if (N2.isZero() || (N1.isMinSignedValue() && N2.isAllOnes()))
394           return std::nullopt;
395         return N1.sdiv(N2);
396       },
397       checkCorrectnessOnlyBinary);
398   testBinaryOpExhaustive(
399       [](const KnownBits &Known1, const KnownBits &Known2) {
400         return KnownBits::sdiv(Known1, Known2, /*Exact*/ true);
401       },
402       [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
403         if (N2.isZero() || (N1.isMinSignedValue() && N2.isAllOnes()) ||
404             !N1.srem(N2).isZero())
405           return std::nullopt;
406         return N1.sdiv(N2);
407       },
408       checkCorrectnessOnlyBinary);
409   testBinaryOpExhaustive(
410       [](const KnownBits &Known1, const KnownBits &Known2) {
411         return KnownBits::urem(Known1, Known2);
412       },
413       [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
414         if (N2.isZero())
415           return std::nullopt;
416         return N1.urem(N2);
417       },
418       checkCorrectnessOnlyBinary);
419   testBinaryOpExhaustive(
420       [](const KnownBits &Known1, const KnownBits &Known2) {
421         return KnownBits::srem(Known1, Known2);
422       },
423       [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
424         if (N2.isZero())
425           return std::nullopt;
426         return N1.srem(N2);
427       },
428       checkCorrectnessOnlyBinary);
429   testBinaryOpExhaustive(
430       [](const KnownBits &Known1, const KnownBits &Known2) {
431         return KnownBits::sadd_sat(Known1, Known2);
432       },
433       [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
434         return N1.sadd_sat(N2);
435       },
436       checkCorrectnessOnlyBinary);
437   testBinaryOpExhaustive(
438       [](const KnownBits &Known1, const KnownBits &Known2) {
439         return KnownBits::uadd_sat(Known1, Known2);
440       },
441       [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
442         return N1.uadd_sat(N2);
443       },
444       checkCorrectnessOnlyBinary);
445   testBinaryOpExhaustive(
446       [](const KnownBits &Known1, const KnownBits &Known2) {
447         return KnownBits::ssub_sat(Known1, Known2);
448       },
449       [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
450         return N1.ssub_sat(N2);
451       },
452       checkCorrectnessOnlyBinary);
453   testBinaryOpExhaustive(
454       [](const KnownBits &Known1, const KnownBits &Known2) {
455         return KnownBits::usub_sat(Known1, Known2);
456       },
457       [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
458         return N1.usub_sat(N2);
459       },
460       checkCorrectnessOnlyBinary);
461   testBinaryOpExhaustive(
462       [](const KnownBits &Known1, const KnownBits &Known2) {
463         return KnownBits::shl(Known1, Known2);
464       },
465       [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
466         if (N2.uge(N2.getBitWidth()))
467           return std::nullopt;
468         return N1.shl(N2);
469       },
470       checkOptimalityBinary, /* RefinePoisonToZero */ true);
471   testBinaryOpExhaustive(
472       [](const KnownBits &Known1, const KnownBits &Known2) {
473         return KnownBits::shl(Known1, Known2, /* NUW */ true);
474       },
475       [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
476         bool Overflow;
477         APInt Res = N1.ushl_ov(N2, Overflow);
478         if (Overflow)
479           return std::nullopt;
480         return Res;
481       },
482       checkOptimalityBinary, /* RefinePoisonToZero */ true);
483   testBinaryOpExhaustive(
484       [](const KnownBits &Known1, const KnownBits &Known2) {
485         return KnownBits::shl(Known1, Known2, /* NUW */ false, /* NSW */ true);
486       },
487       [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
488         bool Overflow;
489         APInt Res = N1.sshl_ov(N2, Overflow);
490         if (Overflow)
491           return std::nullopt;
492         return Res;
493       },
494       checkOptimalityBinary, /* RefinePoisonToZero */ true);
495   testBinaryOpExhaustive(
496       [](const KnownBits &Known1, const KnownBits &Known2) {
497         return KnownBits::shl(Known1, Known2, /* NUW */ true, /* NSW */ true);
498       },
499       [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
500         bool OverflowUnsigned, OverflowSigned;
501         APInt Res = N1.ushl_ov(N2, OverflowUnsigned);
502         (void)N1.sshl_ov(N2, OverflowSigned);
503         if (OverflowUnsigned || OverflowSigned)
504           return std::nullopt;
505         return Res;
506       },
507       checkOptimalityBinary, /* RefinePoisonToZero */ true);
508 
509   testBinaryOpExhaustive(
510       [](const KnownBits &Known1, const KnownBits &Known2) {
511         return KnownBits::lshr(Known1, Known2);
512       },
513       [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
514         if (N2.uge(N2.getBitWidth()))
515           return std::nullopt;
516         return N1.lshr(N2);
517       },
518       checkOptimalityBinary, /* RefinePoisonToZero */ true);
519   testBinaryOpExhaustive(
520       [](const KnownBits &Known1, const KnownBits &Known2) {
521         return KnownBits::lshr(Known1, Known2, /*ShAmtNonZero=*/false,
522                                /*Exact=*/true);
523       },
524       [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
525         if (N2.uge(N2.getBitWidth()))
526           return std::nullopt;
527         if (!N1.extractBits(N2.getZExtValue(), 0).isZero())
528           return std::nullopt;
529         return N1.lshr(N2);
530       },
531       checkOptimalityBinary, /* RefinePoisonToZero */ true);
532   testBinaryOpExhaustive(
533       [](const KnownBits &Known1, const KnownBits &Known2) {
534         return KnownBits::ashr(Known1, Known2);
535       },
536       [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
537         if (N2.uge(N2.getBitWidth()))
538           return std::nullopt;
539         return N1.ashr(N2);
540       },
541       checkOptimalityBinary, /* RefinePoisonToZero */ true);
542   testBinaryOpExhaustive(
543       [](const KnownBits &Known1, const KnownBits &Known2) {
544         return KnownBits::ashr(Known1, Known2, /*ShAmtNonZero=*/false,
545                                /*Exact=*/true);
546       },
547       [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
548         if (N2.uge(N2.getBitWidth()))
549           return std::nullopt;
550         if (!N1.extractBits(N2.getZExtValue(), 0).isZero())
551           return std::nullopt;
552         return N1.ashr(N2);
553       },
554       checkOptimalityBinary, /* RefinePoisonToZero */ true);
555 
556   testBinaryOpExhaustive(
557       [](const KnownBits &Known1, const KnownBits &Known2) {
558         return KnownBits::mul(Known1, Known2);
559       },
560       [](const APInt &N1, const APInt &N2) { return N1 * N2; },
561       checkCorrectnessOnlyBinary);
562   testBinaryOpExhaustive(
563       [](const KnownBits &Known1, const KnownBits &Known2) {
564         return KnownBits::mulhs(Known1, Known2);
565       },
566       [](const APInt &N1, const APInt &N2) {
567         unsigned Bits = N1.getBitWidth();
568         return (N1.sext(2 * Bits) * N2.sext(2 * Bits)).extractBits(Bits, Bits);
569       },
570       checkCorrectnessOnlyBinary);
571   testBinaryOpExhaustive(
572       [](const KnownBits &Known1, const KnownBits &Known2) {
573         return KnownBits::mulhu(Known1, Known2);
574       },
575       [](const APInt &N1, const APInt &N2) {
576         unsigned Bits = N1.getBitWidth();
577         return (N1.zext(2 * Bits) * N2.zext(2 * Bits)).extractBits(Bits, Bits);
578       },
579       checkCorrectnessOnlyBinary);
580 }
581 
582 TEST(KnownBitsTest, UnaryExhaustive) {
583   testUnaryOpExhaustive([](const KnownBits &Known) { return Known.abs(); },
584                         [](const APInt &N) { return N.abs(); });
585 
586   testUnaryOpExhaustive([](const KnownBits &Known) { return Known.abs(true); },
587                         [](const APInt &N) -> std::optional<APInt> {
588                           if (N.isMinSignedValue())
589                             return std::nullopt;
590                           return N.abs();
591                         });
592 
593   testUnaryOpExhaustive([](const KnownBits &Known) { return Known.blsi(); },
594                         [](const APInt &N) { return N & -N; });
595   testUnaryOpExhaustive([](const KnownBits &Known) { return Known.blsmsk(); },
596                         [](const APInt &N) { return N ^ (N - 1); });
597 
598   testUnaryOpExhaustive(
599       [](const KnownBits &Known) {
600         return KnownBits::mul(Known, Known, /*SelfMultiply*/ true);
601       },
602       [](const APInt &N) { return N * N; }, checkCorrectnessOnlyUnary);
603 }
604 
605 TEST(KnownBitsTest, WideShifts) {
606   unsigned BitWidth = 128;
607   KnownBits Unknown(BitWidth);
608   KnownBits AllOnes = KnownBits::makeConstant(APInt::getAllOnes(BitWidth));
609 
610   KnownBits ShlResult(BitWidth);
611   ShlResult.makeNegative();
612   EXPECT_EQ(KnownBits::shl(AllOnes, Unknown), ShlResult);
613   KnownBits LShrResult(BitWidth);
614   LShrResult.One.setBit(0);
615   EXPECT_EQ(KnownBits::lshr(AllOnes, Unknown), LShrResult);
616   EXPECT_EQ(KnownBits::ashr(AllOnes, Unknown), AllOnes);
617 }
618 
619 TEST(KnownBitsTest, ICmpExhaustive) {
620   unsigned Bits = 4;
621   ForeachKnownBits(Bits, [&](const KnownBits &Known1) {
622     ForeachKnownBits(Bits, [&](const KnownBits &Known2) {
623       bool AllEQ = true, NoneEQ = true;
624       bool AllNE = true, NoneNE = true;
625       bool AllUGT = true, NoneUGT = true;
626       bool AllUGE = true, NoneUGE = true;
627       bool AllULT = true, NoneULT = true;
628       bool AllULE = true, NoneULE = true;
629       bool AllSGT = true, NoneSGT = true;
630       bool AllSGE = true, NoneSGE = true;
631       bool AllSLT = true, NoneSLT = true;
632       bool AllSLE = true, NoneSLE = true;
633 
634       ForeachNumInKnownBits(Known1, [&](const APInt &N1) {
635         ForeachNumInKnownBits(Known2, [&](const APInt &N2) {
636           AllEQ &= N1.eq(N2);
637           AllNE &= N1.ne(N2);
638           AllUGT &= N1.ugt(N2);
639           AllUGE &= N1.uge(N2);
640           AllULT &= N1.ult(N2);
641           AllULE &= N1.ule(N2);
642           AllSGT &= N1.sgt(N2);
643           AllSGE &= N1.sge(N2);
644           AllSLT &= N1.slt(N2);
645           AllSLE &= N1.sle(N2);
646           NoneEQ &= !N1.eq(N2);
647           NoneNE &= !N1.ne(N2);
648           NoneUGT &= !N1.ugt(N2);
649           NoneUGE &= !N1.uge(N2);
650           NoneULT &= !N1.ult(N2);
651           NoneULE &= !N1.ule(N2);
652           NoneSGT &= !N1.sgt(N2);
653           NoneSGE &= !N1.sge(N2);
654           NoneSLT &= !N1.slt(N2);
655           NoneSLE &= !N1.sle(N2);
656         });
657       });
658 
659       std::optional<bool> KnownEQ = KnownBits::eq(Known1, Known2);
660       std::optional<bool> KnownNE = KnownBits::ne(Known1, Known2);
661       std::optional<bool> KnownUGT = KnownBits::ugt(Known1, Known2);
662       std::optional<bool> KnownUGE = KnownBits::uge(Known1, Known2);
663       std::optional<bool> KnownULT = KnownBits::ult(Known1, Known2);
664       std::optional<bool> KnownULE = KnownBits::ule(Known1, Known2);
665       std::optional<bool> KnownSGT = KnownBits::sgt(Known1, Known2);
666       std::optional<bool> KnownSGE = KnownBits::sge(Known1, Known2);
667       std::optional<bool> KnownSLT = KnownBits::slt(Known1, Known2);
668       std::optional<bool> KnownSLE = KnownBits::sle(Known1, Known2);
669 
670       EXPECT_EQ(AllEQ || NoneEQ, KnownEQ.has_value());
671       EXPECT_EQ(AllNE || NoneNE, KnownNE.has_value());
672       EXPECT_EQ(AllUGT || NoneUGT, KnownUGT.has_value());
673       EXPECT_EQ(AllUGE || NoneUGE, KnownUGE.has_value());
674       EXPECT_EQ(AllULT || NoneULT, KnownULT.has_value());
675       EXPECT_EQ(AllULE || NoneULE, KnownULE.has_value());
676       EXPECT_EQ(AllSGT || NoneSGT, KnownSGT.has_value());
677       EXPECT_EQ(AllSGE || NoneSGE, KnownSGE.has_value());
678       EXPECT_EQ(AllSLT || NoneSLT, KnownSLT.has_value());
679       EXPECT_EQ(AllSLE || NoneSLE, KnownSLE.has_value());
680 
681       EXPECT_EQ(AllEQ, KnownEQ.has_value() && *KnownEQ);
682       EXPECT_EQ(AllNE, KnownNE.has_value() && *KnownNE);
683       EXPECT_EQ(AllUGT, KnownUGT.has_value() && *KnownUGT);
684       EXPECT_EQ(AllUGE, KnownUGE.has_value() && *KnownUGE);
685       EXPECT_EQ(AllULT, KnownULT.has_value() && *KnownULT);
686       EXPECT_EQ(AllULE, KnownULE.has_value() && *KnownULE);
687       EXPECT_EQ(AllSGT, KnownSGT.has_value() && *KnownSGT);
688       EXPECT_EQ(AllSGE, KnownSGE.has_value() && *KnownSGE);
689       EXPECT_EQ(AllSLT, KnownSLT.has_value() && *KnownSLT);
690       EXPECT_EQ(AllSLE, KnownSLE.has_value() && *KnownSLE);
691 
692       EXPECT_EQ(NoneEQ, KnownEQ.has_value() && !*KnownEQ);
693       EXPECT_EQ(NoneNE, KnownNE.has_value() && !*KnownNE);
694       EXPECT_EQ(NoneUGT, KnownUGT.has_value() && !*KnownUGT);
695       EXPECT_EQ(NoneUGE, KnownUGE.has_value() && !*KnownUGE);
696       EXPECT_EQ(NoneULT, KnownULT.has_value() && !*KnownULT);
697       EXPECT_EQ(NoneULE, KnownULE.has_value() && !*KnownULE);
698       EXPECT_EQ(NoneSGT, KnownSGT.has_value() && !*KnownSGT);
699       EXPECT_EQ(NoneSGE, KnownSGE.has_value() && !*KnownSGE);
700       EXPECT_EQ(NoneSLT, KnownSLT.has_value() && !*KnownSLT);
701       EXPECT_EQ(NoneSLE, KnownSLE.has_value() && !*KnownSLE);
702     });
703   });
704 }
705 
706 TEST(KnownBitsTest, GetMinMaxVal) {
707   unsigned Bits = 4;
708   ForeachKnownBits(Bits, [&](const KnownBits &Known) {
709     APInt Min = APInt::getMaxValue(Bits);
710     APInt Max = APInt::getMinValue(Bits);
711     ForeachNumInKnownBits(Known, [&](const APInt &N) {
712       Min = APIntOps::umin(Min, N);
713       Max = APIntOps::umax(Max, N);
714     });
715     EXPECT_EQ(Min, Known.getMinValue());
716     EXPECT_EQ(Max, Known.getMaxValue());
717   });
718 }
719 
720 TEST(KnownBitsTest, GetSignedMinMaxVal) {
721   unsigned Bits = 4;
722   ForeachKnownBits(Bits, [&](const KnownBits &Known) {
723     APInt Min = APInt::getSignedMaxValue(Bits);
724     APInt Max = APInt::getSignedMinValue(Bits);
725     ForeachNumInKnownBits(Known, [&](const APInt &N) {
726       Min = APIntOps::smin(Min, N);
727       Max = APIntOps::smax(Max, N);
728     });
729     EXPECT_EQ(Min, Known.getSignedMinValue());
730     EXPECT_EQ(Max, Known.getSignedMaxValue());
731   });
732 }
733 
734 TEST(KnownBitsTest, CountMaxActiveBits) {
735   unsigned Bits = 4;
736   ForeachKnownBits(Bits, [&](const KnownBits &Known) {
737     unsigned Expected = 0;
738     ForeachNumInKnownBits(Known, [&](const APInt &N) {
739       Expected = std::max(Expected, N.getActiveBits());
740     });
741     EXPECT_EQ(Expected, Known.countMaxActiveBits());
742   });
743 }
744 
745 TEST(KnownBitsTest, CountMaxSignificantBits) {
746   unsigned Bits = 4;
747   ForeachKnownBits(Bits, [&](const KnownBits &Known) {
748     unsigned Expected = 0;
749     ForeachNumInKnownBits(Known, [&](const APInt &N) {
750       Expected = std::max(Expected, N.getSignificantBits());
751     });
752     EXPECT_EQ(Expected, Known.countMaxSignificantBits());
753   });
754 }
755 
756 TEST(KnownBitsTest, SExtOrTrunc) {
757   const unsigned NarrowerSize = 4;
758   const unsigned BaseSize = 6;
759   const unsigned WiderSize = 8;
760   APInt NegativeFitsNarrower(BaseSize, -4, /*isSigned*/ true);
761   APInt NegativeDoesntFitNarrower(BaseSize, -28, /*isSigned*/ true);
762   APInt PositiveFitsNarrower(BaseSize, 14);
763   APInt PositiveDoesntFitNarrower(BaseSize, 36);
764   auto InitKnownBits = [&](KnownBits &Res, const APInt &Input) {
765     Res = KnownBits(Input.getBitWidth());
766     Res.One = Input;
767     Res.Zero = ~Input;
768   };
769 
770   for (unsigned Size : {NarrowerSize, BaseSize, WiderSize}) {
771     for (const APInt &Input :
772          {NegativeFitsNarrower, NegativeDoesntFitNarrower, PositiveFitsNarrower,
773           PositiveDoesntFitNarrower}) {
774       KnownBits Test;
775       InitKnownBits(Test, Input);
776       KnownBits Baseline;
777       InitKnownBits(Baseline, Input.sextOrTrunc(Size));
778       Test = Test.sextOrTrunc(Size);
779       EXPECT_EQ(Test, Baseline);
780     }
781   }
782 }
783 
784 TEST(KnownBitsTest, SExtInReg) {
785   unsigned Bits = 4;
786   for (unsigned FromBits = 1; FromBits <= Bits; ++FromBits) {
787     ForeachKnownBits(Bits, [&](const KnownBits &Known) {
788       APInt CommonOne = APInt::getAllOnes(Bits);
789       APInt CommonZero = APInt::getAllOnes(Bits);
790       unsigned ExtBits = Bits - FromBits;
791       ForeachNumInKnownBits(Known, [&](const APInt &N) {
792         APInt Ext = N << ExtBits;
793         Ext.ashrInPlace(ExtBits);
794         CommonOne &= Ext;
795         CommonZero &= ~Ext;
796       });
797       KnownBits KnownSExtInReg = Known.sextInReg(FromBits);
798       EXPECT_EQ(CommonOne, KnownSExtInReg.One);
799       EXPECT_EQ(CommonZero, KnownSExtInReg.Zero);
800     });
801   }
802 }
803 
804 TEST(KnownBitsTest, CommonBitsSet) {
805   unsigned Bits = 4;
806   ForeachKnownBits(Bits, [&](const KnownBits &Known1) {
807     ForeachKnownBits(Bits, [&](const KnownBits &Known2) {
808       bool HasCommonBitsSet = false;
809       ForeachNumInKnownBits(Known1, [&](const APInt &N1) {
810         ForeachNumInKnownBits(Known2, [&](const APInt &N2) {
811           HasCommonBitsSet |= N1.intersects(N2);
812         });
813       });
814       EXPECT_EQ(!HasCommonBitsSet,
815                 KnownBits::haveNoCommonBitsSet(Known1, Known2));
816     });
817   });
818 }
819 
820 TEST(KnownBitsTest, ConcatBits) {
821   unsigned Bits = 4;
822   for (unsigned LoBits = 1; LoBits < Bits; ++LoBits) {
823     unsigned HiBits = Bits - LoBits;
824     ForeachKnownBits(LoBits, [&](const KnownBits &KnownLo) {
825       ForeachKnownBits(HiBits, [&](const KnownBits &KnownHi) {
826         KnownBits KnownAll = KnownHi.concat(KnownLo);
827 
828         EXPECT_EQ(KnownLo.countMinPopulation() + KnownHi.countMinPopulation(),
829                   KnownAll.countMinPopulation());
830         EXPECT_EQ(KnownLo.countMaxPopulation() + KnownHi.countMaxPopulation(),
831                   KnownAll.countMaxPopulation());
832 
833         KnownBits ExtractLo = KnownAll.extractBits(LoBits, 0);
834         KnownBits ExtractHi = KnownAll.extractBits(HiBits, LoBits);
835 
836         EXPECT_EQ(KnownLo.One.getZExtValue(), ExtractLo.One.getZExtValue());
837         EXPECT_EQ(KnownHi.One.getZExtValue(), ExtractHi.One.getZExtValue());
838         EXPECT_EQ(KnownLo.Zero.getZExtValue(), ExtractLo.Zero.getZExtValue());
839         EXPECT_EQ(KnownHi.Zero.getZExtValue(), ExtractHi.Zero.getZExtValue());
840       });
841     });
842   }
843 }
844 
845 } // end anonymous namespace
846