xref: /llvm-project/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp (revision 2034f2fc8729bd4645ef7caa3c5c6efa284d2d3f)
1 //===- InferIntRangeCommon.cpp - Inference for common ops ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains implementations of range inference for operations that are
10 // common to both the `arith` and `index` dialects to facilitate reuse.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
15 
16 #include "mlir/Interfaces/InferIntRangeInterface.h"
17 
18 #include "llvm/ADT/ArrayRef.h"
19 #include "llvm/ADT/STLExtras.h"
20 
21 #include "llvm/Support/Debug.h"
22 
23 #include <iterator>
24 #include <optional>
25 
26 using namespace mlir;
27 
28 #define DEBUG_TYPE "int-range-analysis"
29 
30 //===----------------------------------------------------------------------===//
31 // General utilities
32 //===----------------------------------------------------------------------===//
33 
34 /// Function that evaluates the result of doing something on arithmetic
35 /// constants and returns std::nullopt on overflow.
36 using ConstArithFn =
37     function_ref<std::optional<APInt>(const APInt &, const APInt &)>;
38 
39 /// Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible,
40 /// If either computation overflows, make the result unbounded.
41 static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft,
42                                          const APInt &minRight,
43                                          const APInt &maxLeft,
44                                          const APInt &maxRight, bool isSigned) {
45   std::optional<APInt> maybeMin = op(minLeft, minRight);
46   std::optional<APInt> maybeMax = op(maxLeft, maxRight);
47   if (maybeMin && maybeMax)
48     return ConstantIntRanges::range(*maybeMin, *maybeMax, isSigned);
49   return ConstantIntRanges::maxRange(minLeft.getBitWidth());
50 }
51 
52 /// Compute the minimum and maximum of `(op(l, r) for l in lhs for r in rhs)`,
53 /// ignoring unbounded values. Returns the maximal range if `op` overflows.
54 static ConstantIntRanges minMaxBy(ConstArithFn op, ArrayRef<APInt> lhs,
55                                   ArrayRef<APInt> rhs, bool isSigned) {
56   unsigned width = lhs[0].getBitWidth();
57   APInt min =
58       isSigned ? APInt::getSignedMaxValue(width) : APInt::getMaxValue(width);
59   APInt max =
60       isSigned ? APInt::getSignedMinValue(width) : APInt::getZero(width);
61   for (const APInt &left : lhs) {
62     for (const APInt &right : rhs) {
63       std::optional<APInt> maybeThisResult = op(left, right);
64       if (!maybeThisResult)
65         return ConstantIntRanges::maxRange(width);
66       APInt result = std::move(*maybeThisResult);
67       min = (isSigned ? result.slt(min) : result.ult(min)) ? result : min;
68       max = (isSigned ? result.sgt(max) : result.ugt(max)) ? result : max;
69     }
70   }
71   return ConstantIntRanges::range(min, max, isSigned);
72 }
73 
74 //===----------------------------------------------------------------------===//
75 // Ext, trunc, index op handling
76 //===----------------------------------------------------------------------===//
77 
78 ConstantIntRanges
79 mlir::intrange::inferIndexOp(InferRangeFn inferFn,
80                              ArrayRef<ConstantIntRanges> argRanges,
81                              intrange::CmpMode mode) {
82   ConstantIntRanges sixtyFour = inferFn(argRanges);
83   SmallVector<ConstantIntRanges, 2> truncated;
84   llvm::transform(argRanges, std::back_inserter(truncated),
85                   [](const ConstantIntRanges &range) {
86                     return truncRange(range, /*destWidth=*/indexMinWidth);
87                   });
88   ConstantIntRanges thirtyTwo = inferFn(truncated);
89   ConstantIntRanges thirtyTwoAsSixtyFour =
90       extRange(thirtyTwo, /*destWidth=*/indexMaxWidth);
91   ConstantIntRanges sixtyFourAsThirtyTwo =
92       truncRange(sixtyFour, /*destWidth=*/indexMinWidth);
93 
94   LLVM_DEBUG(llvm::dbgs() << "Index handling: 64-bit result = " << sixtyFour
95                           << " 32-bit = " << thirtyTwo << "\n");
96   bool truncEqual = false;
97   switch (mode) {
98   case intrange::CmpMode::Both:
99     truncEqual = (thirtyTwo == sixtyFourAsThirtyTwo);
100     break;
101   case intrange::CmpMode::Signed:
102     truncEqual = (thirtyTwo.smin() == sixtyFourAsThirtyTwo.smin() &&
103                   thirtyTwo.smax() == sixtyFourAsThirtyTwo.smax());
104     break;
105   case intrange::CmpMode::Unsigned:
106     truncEqual = (thirtyTwo.umin() == sixtyFourAsThirtyTwo.umin() &&
107                   thirtyTwo.umax() == sixtyFourAsThirtyTwo.umax());
108     break;
109   }
110   if (truncEqual)
111     // Returing the 64-bit result preserves more information.
112     return sixtyFour;
113   ConstantIntRanges merged = sixtyFour.rangeUnion(thirtyTwoAsSixtyFour);
114   return merged;
115 }
116 
117 ConstantIntRanges mlir::intrange::extRange(const ConstantIntRanges &range,
118                                            unsigned int destWidth) {
119   APInt umin = range.umin().zext(destWidth);
120   APInt umax = range.umax().zext(destWidth);
121   APInt smin = range.smin().sext(destWidth);
122   APInt smax = range.smax().sext(destWidth);
123   return {umin, umax, smin, smax};
124 }
125 
126 ConstantIntRanges mlir::intrange::extUIRange(const ConstantIntRanges &range,
127                                              unsigned destWidth) {
128   APInt umin = range.umin().zext(destWidth);
129   APInt umax = range.umax().zext(destWidth);
130   return ConstantIntRanges::fromUnsigned(umin, umax);
131 }
132 
133 ConstantIntRanges mlir::intrange::extSIRange(const ConstantIntRanges &range,
134                                              unsigned destWidth) {
135   APInt smin = range.smin().sext(destWidth);
136   APInt smax = range.smax().sext(destWidth);
137   return ConstantIntRanges::fromSigned(smin, smax);
138 }
139 
140 ConstantIntRanges mlir::intrange::truncRange(const ConstantIntRanges &range,
141                                              unsigned int destWidth) {
142   // If you truncate the first four bytes in [0xaaaabbbb, 0xccccbbbb],
143   // the range of the resulting value is not contiguous ind includes 0.
144   // Ex. If you truncate [256, 258] from i16 to i8, you validly get [0, 2],
145   // but you can't truncate [255, 257] similarly.
146   bool hasUnsignedRollover =
147       range.umin().lshr(destWidth) != range.umax().lshr(destWidth);
148   APInt umin = hasUnsignedRollover ? APInt::getZero(destWidth)
149                                    : range.umin().trunc(destWidth);
150   APInt umax = hasUnsignedRollover ? APInt::getMaxValue(destWidth)
151                                    : range.umax().trunc(destWidth);
152 
153   // Signed post-truncation rollover will not occur when either:
154   // - The high parts of the min and max, plus the sign bit, are the same
155   // - The high halves + sign bit of the min and max are either all 1s or all 0s
156   //  and you won't create a [positive, negative] range by truncating.
157   // For example, you can truncate the ranges [256, 258]_i16 to [0, 2]_i8
158   // but not [255, 257]_i16 to a range of i8s. You can also truncate
159   // [-256, -256]_i16 to [-2, 0]_i8, but not [-257, -255]_i16.
160   // You can also truncate [-130, 0]_i16 to i8 because -130_i16 (0xff7e)
161   // will truncate to 0x7e, which is greater than 0
162   APInt sminHighPart = range.smin().ashr(destWidth - 1);
163   APInt smaxHighPart = range.smax().ashr(destWidth - 1);
164   bool hasSignedOverflow =
165       (sminHighPart != smaxHighPart) &&
166       !(sminHighPart.isAllOnes() &&
167         (smaxHighPart.isAllOnes() || smaxHighPart.isZero())) &&
168       !(sminHighPart.isZero() && smaxHighPart.isZero());
169   APInt smin = hasSignedOverflow ? APInt::getSignedMinValue(destWidth)
170                                  : range.smin().trunc(destWidth);
171   APInt smax = hasSignedOverflow ? APInt::getSignedMaxValue(destWidth)
172                                  : range.smax().trunc(destWidth);
173   return {umin, umax, smin, smax};
174 }
175 
176 //===----------------------------------------------------------------------===//
177 // Addition
178 //===----------------------------------------------------------------------===//
179 
180 ConstantIntRanges
181 mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges,
182                          OverflowFlags ovfFlags) {
183   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
184 
185   std::function uadd = [=](const APInt &a,
186                            const APInt &b) -> std::optional<APInt> {
187     bool overflowed = false;
188     APInt result = any(ovfFlags & OverflowFlags::Nuw)
189                        ? a.uadd_sat(b)
190                        : a.uadd_ov(b, overflowed);
191     return overflowed ? std::optional<APInt>() : result;
192   };
193   std::function sadd = [=](const APInt &a,
194                            const APInt &b) -> std::optional<APInt> {
195     bool overflowed = false;
196     APInt result = any(ovfFlags & OverflowFlags::Nsw)
197                        ? a.sadd_sat(b)
198                        : a.sadd_ov(b, overflowed);
199     return overflowed ? std::optional<APInt>() : result;
200   };
201 
202   ConstantIntRanges urange = computeBoundsBy(
203       uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax(), /*isSigned=*/false);
204   ConstantIntRanges srange = computeBoundsBy(
205       sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax(), /*isSigned=*/true);
206   return urange.intersection(srange);
207 }
208 
209 //===----------------------------------------------------------------------===//
210 // Subtraction
211 //===----------------------------------------------------------------------===//
212 
213 ConstantIntRanges
214 mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges,
215                          OverflowFlags ovfFlags) {
216   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
217 
218   std::function usub = [=](const APInt &a,
219                            const APInt &b) -> std::optional<APInt> {
220     bool overflowed = false;
221     APInt result = any(ovfFlags & OverflowFlags::Nuw)
222                        ? a.usub_sat(b)
223                        : a.usub_ov(b, overflowed);
224     return overflowed ? std::optional<APInt>() : result;
225   };
226   std::function ssub = [=](const APInt &a,
227                            const APInt &b) -> std::optional<APInt> {
228     bool overflowed = false;
229     APInt result = any(ovfFlags & OverflowFlags::Nsw)
230                        ? a.ssub_sat(b)
231                        : a.ssub_ov(b, overflowed);
232     return overflowed ? std::optional<APInt>() : result;
233   };
234   ConstantIntRanges urange = computeBoundsBy(
235       usub, lhs.umin(), rhs.umax(), lhs.umax(), rhs.umin(), /*isSigned=*/false);
236   ConstantIntRanges srange = computeBoundsBy(
237       ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin(), /*isSigned=*/true);
238   return urange.intersection(srange);
239 }
240 
241 //===----------------------------------------------------------------------===//
242 // Multiplication
243 //===----------------------------------------------------------------------===//
244 
245 ConstantIntRanges
246 mlir::intrange::inferMul(ArrayRef<ConstantIntRanges> argRanges,
247                          OverflowFlags ovfFlags) {
248   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
249 
250   std::function umul = [=](const APInt &a,
251                            const APInt &b) -> std::optional<APInt> {
252     bool overflowed = false;
253     APInt result = any(ovfFlags & OverflowFlags::Nuw)
254                        ? a.umul_sat(b)
255                        : a.umul_ov(b, overflowed);
256     return overflowed ? std::optional<APInt>() : result;
257   };
258   std::function smul = [=](const APInt &a,
259                            const APInt &b) -> std::optional<APInt> {
260     bool overflowed = false;
261     APInt result = any(ovfFlags & OverflowFlags::Nsw)
262                        ? a.smul_sat(b)
263                        : a.smul_ov(b, overflowed);
264     return overflowed ? std::optional<APInt>() : result;
265   };
266 
267   ConstantIntRanges urange =
268       minMaxBy(umul, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
269                /*isSigned=*/false);
270   ConstantIntRanges srange =
271       minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()},
272                /*isSigned=*/true);
273   return urange.intersection(srange);
274 }
275 
276 //===----------------------------------------------------------------------===//
277 // DivU, CeilDivU (Unsigned division)
278 //===----------------------------------------------------------------------===//
279 
280 /// Fix up division results (ex. for ceiling and floor), returning an APInt
281 /// if there has been no overflow
282 using DivisionFixupFn = function_ref<std::optional<APInt>(
283     const APInt &lhs, const APInt &rhs, const APInt &result)>;
284 
285 static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs,
286                                         const ConstantIntRanges &rhs,
287                                         DivisionFixupFn fixup) {
288   const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(), &rhsMin = rhs.umin(),
289               &rhsMax = rhs.umax();
290 
291   if (!rhsMin.isZero()) {
292     auto udiv = [&fixup](const APInt &a,
293                          const APInt &b) -> std::optional<APInt> {
294       return fixup(a, b, a.udiv(b));
295     };
296     return minMaxBy(udiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
297                     /*isSigned=*/false);
298   }
299   // Otherwise, it's possible we might divide by 0.
300   return ConstantIntRanges::maxRange(rhsMin.getBitWidth());
301 }
302 
303 ConstantIntRanges
304 mlir::intrange::inferDivU(ArrayRef<ConstantIntRanges> argRanges) {
305   return inferDivURange(argRanges[0], argRanges[1],
306                         [](const APInt &lhs, const APInt &rhs,
307                            const APInt &result) { return result; });
308 }
309 
310 ConstantIntRanges
311 mlir::intrange::inferCeilDivU(ArrayRef<ConstantIntRanges> argRanges) {
312   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
313 
314   DivisionFixupFn ceilDivUIFix =
315       [](const APInt &lhs, const APInt &rhs,
316          const APInt &result) -> std::optional<APInt> {
317     if (!lhs.urem(rhs).isZero()) {
318       bool overflowed = false;
319       APInt corrected =
320           result.uadd_ov(APInt(result.getBitWidth(), 1), overflowed);
321       return overflowed ? std::optional<APInt>() : corrected;
322     }
323     return result;
324   };
325   return inferDivURange(lhs, rhs, ceilDivUIFix);
326 }
327 
328 //===----------------------------------------------------------------------===//
329 // DivS, CeilDivS, FloorDivS (Signed division)
330 //===----------------------------------------------------------------------===//
331 
332 static ConstantIntRanges inferDivSRange(const ConstantIntRanges &lhs,
333                                         const ConstantIntRanges &rhs,
334                                         DivisionFixupFn fixup) {
335   const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
336               &rhsMax = rhs.smax();
337   bool canDivide = rhsMin.isStrictlyPositive() || rhsMax.isNegative();
338 
339   if (canDivide) {
340     auto sdiv = [&fixup](const APInt &a,
341                          const APInt &b) -> std::optional<APInt> {
342       bool overflowed = false;
343       APInt result = a.sdiv_ov(b, overflowed);
344       return overflowed ? std::optional<APInt>() : fixup(a, b, result);
345     };
346     return minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
347                     /*isSigned=*/true);
348   }
349   return ConstantIntRanges::maxRange(rhsMin.getBitWidth());
350 }
351 
352 ConstantIntRanges
353 mlir::intrange::inferDivS(ArrayRef<ConstantIntRanges> argRanges) {
354   return inferDivSRange(argRanges[0], argRanges[1],
355                         [](const APInt &lhs, const APInt &rhs,
356                            const APInt &result) { return result; });
357 }
358 
359 ConstantIntRanges
360 mlir::intrange::inferCeilDivS(ArrayRef<ConstantIntRanges> argRanges) {
361   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
362 
363   DivisionFixupFn ceilDivSIFix =
364       [](const APInt &lhs, const APInt &rhs,
365          const APInt &result) -> std::optional<APInt> {
366     if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() == rhs.isNonNegative()) {
367       bool overflowed = false;
368       APInt corrected =
369           result.sadd_ov(APInt(result.getBitWidth(), 1), overflowed);
370       return overflowed ? std::optional<APInt>() : corrected;
371     }
372     return result;
373   };
374   return inferDivSRange(lhs, rhs, ceilDivSIFix);
375 }
376 
377 ConstantIntRanges
378 mlir::intrange::inferFloorDivS(ArrayRef<ConstantIntRanges> argRanges) {
379   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
380 
381   DivisionFixupFn floorDivSIFix =
382       [](const APInt &lhs, const APInt &rhs,
383          const APInt &result) -> std::optional<APInt> {
384     if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() != rhs.isNonNegative()) {
385       bool overflowed = false;
386       APInt corrected =
387           result.ssub_ov(APInt(result.getBitWidth(), 1), overflowed);
388       return overflowed ? std::optional<APInt>() : corrected;
389     }
390     return result;
391   };
392   return inferDivSRange(lhs, rhs, floorDivSIFix);
393 }
394 
395 //===----------------------------------------------------------------------===//
396 // Signed remainder (RemS)
397 //===----------------------------------------------------------------------===//
398 
399 ConstantIntRanges
400 mlir::intrange::inferRemS(ArrayRef<ConstantIntRanges> argRanges) {
401   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
402   const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
403               &rhsMax = rhs.smax();
404 
405   unsigned width = rhsMax.getBitWidth();
406   APInt smin = APInt::getSignedMinValue(width);
407   APInt smax = APInt::getSignedMaxValue(width);
408   // No bounds if zero could be a divisor.
409   bool canBound = (rhsMin.isStrictlyPositive() || rhsMax.isNegative());
410   if (canBound) {
411     APInt maxDivisor = rhsMin.isStrictlyPositive() ? rhsMax : rhsMin.abs();
412     bool canNegativeDividend = lhsMin.isNegative();
413     bool canPositiveDividend = lhsMax.isStrictlyPositive();
414     APInt zero = APInt::getZero(maxDivisor.getBitWidth());
415     APInt maxPositiveResult = maxDivisor - 1;
416     APInt minNegativeResult = -maxPositiveResult;
417     smin = canNegativeDividend ? minNegativeResult : zero;
418     smax = canPositiveDividend ? maxPositiveResult : zero;
419     // Special case: sweeping out a contiguous range in N/[modulus].
420     if (rhsMin == rhsMax) {
421       if ((lhsMax - lhsMin).ult(maxDivisor)) {
422         APInt minRem = lhsMin.srem(maxDivisor);
423         APInt maxRem = lhsMax.srem(maxDivisor);
424         if (minRem.sle(maxRem)) {
425           smin = minRem;
426           smax = maxRem;
427         }
428       }
429     }
430   }
431   return ConstantIntRanges::fromSigned(smin, smax);
432 }
433 
434 //===----------------------------------------------------------------------===//
435 // Unsigned remainder (RemU)
436 //===----------------------------------------------------------------------===//
437 
438 ConstantIntRanges
439 mlir::intrange::inferRemU(ArrayRef<ConstantIntRanges> argRanges) {
440   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
441   const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax();
442 
443   unsigned width = rhsMin.getBitWidth();
444   APInt umin = APInt::getZero(width);
445   APInt umax = APInt::getMaxValue(width);
446 
447   if (!rhsMin.isZero()) {
448     umax = rhsMax - 1;
449     // Special case: sweeping out a contiguous range in N/[modulus]
450     if (rhsMin == rhsMax) {
451       const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax();
452       if ((lhsMax - lhsMin).ult(rhsMax)) {
453         APInt minRem = lhsMin.urem(rhsMax);
454         APInt maxRem = lhsMax.urem(rhsMax);
455         if (minRem.ule(maxRem)) {
456           umin = minRem;
457           umax = maxRem;
458         }
459       }
460     }
461   }
462   return ConstantIntRanges::fromUnsigned(umin, umax);
463 }
464 
465 //===----------------------------------------------------------------------===//
466 // Max and min (MaxS, MaxU, MinS, MinU)
467 //===----------------------------------------------------------------------===//
468 
469 ConstantIntRanges
470 mlir::intrange::inferMaxS(ArrayRef<ConstantIntRanges> argRanges) {
471   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
472 
473   const APInt &smin = lhs.smin().sgt(rhs.smin()) ? lhs.smin() : rhs.smin();
474   const APInt &smax = lhs.smax().sgt(rhs.smax()) ? lhs.smax() : rhs.smax();
475   return ConstantIntRanges::fromSigned(smin, smax);
476 }
477 
478 ConstantIntRanges
479 mlir::intrange::inferMaxU(ArrayRef<ConstantIntRanges> argRanges) {
480   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
481 
482   const APInt &umin = lhs.umin().ugt(rhs.umin()) ? lhs.umin() : rhs.umin();
483   const APInt &umax = lhs.umax().ugt(rhs.umax()) ? lhs.umax() : rhs.umax();
484   return ConstantIntRanges::fromUnsigned(umin, umax);
485 }
486 
487 ConstantIntRanges
488 mlir::intrange::inferMinS(ArrayRef<ConstantIntRanges> argRanges) {
489   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
490 
491   const APInt &smin = lhs.smin().slt(rhs.smin()) ? lhs.smin() : rhs.smin();
492   const APInt &smax = lhs.smax().slt(rhs.smax()) ? lhs.smax() : rhs.smax();
493   return ConstantIntRanges::fromSigned(smin, smax);
494 }
495 
496 ConstantIntRanges
497 mlir::intrange::inferMinU(ArrayRef<ConstantIntRanges> argRanges) {
498   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
499 
500   const APInt &umin = lhs.umin().ult(rhs.umin()) ? lhs.umin() : rhs.umin();
501   const APInt &umax = lhs.umax().ult(rhs.umax()) ? lhs.umax() : rhs.umax();
502   return ConstantIntRanges::fromUnsigned(umin, umax);
503 }
504 
505 //===----------------------------------------------------------------------===//
506 // Bitwise operators (And, Or, Xor)
507 //===----------------------------------------------------------------------===//
508 
509 /// "Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???,
510 /// relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111, where vvvvv are the bits
511 /// that both bonuds have in common. This gives us a consertive approximation
512 /// for what values can be passed to bitwise operations.
513 static std::tuple<APInt, APInt>
514 widenBitwiseBounds(const ConstantIntRanges &bound) {
515   APInt leftVal = bound.umin(), rightVal = bound.umax();
516   unsigned bitwidth = leftVal.getBitWidth();
517   unsigned differingBits = bitwidth - (leftVal ^ rightVal).countl_zero();
518   leftVal.clearLowBits(differingBits);
519   rightVal.setLowBits(differingBits);
520   return std::make_tuple(std::move(leftVal), std::move(rightVal));
521 }
522 
523 ConstantIntRanges
524 mlir::intrange::inferAnd(ArrayRef<ConstantIntRanges> argRanges) {
525   auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]);
526   auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]);
527   auto andi = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
528     return a & b;
529   };
530   return minMaxBy(andi, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
531                   /*isSigned=*/false);
532 }
533 
534 ConstantIntRanges
535 mlir::intrange::inferOr(ArrayRef<ConstantIntRanges> argRanges) {
536   auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]);
537   auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]);
538   auto ori = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
539     return a | b;
540   };
541   return minMaxBy(ori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
542                   /*isSigned=*/false);
543 }
544 
545 ConstantIntRanges
546 mlir::intrange::inferXor(ArrayRef<ConstantIntRanges> argRanges) {
547   auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]);
548   auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]);
549   auto xori = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
550     return a ^ b;
551   };
552   return minMaxBy(xori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
553                   /*isSigned=*/false);
554 }
555 
556 //===----------------------------------------------------------------------===//
557 // Shifts (Shl, ShrS, ShrU)
558 //===----------------------------------------------------------------------===//
559 
560 ConstantIntRanges
561 mlir::intrange::inferShl(ArrayRef<ConstantIntRanges> argRanges,
562                          OverflowFlags ovfFlags) {
563   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
564   const APInt &rhsUMin = rhs.umin(), &rhsUMax = rhs.umax();
565 
566   // The signed/unsigned overflow behavior of shl by `rhs` matches a mul with
567   // 2^rhs.
568   std::function ushl = [=](const APInt &l,
569                            const APInt &r) -> std::optional<APInt> {
570     bool overflowed = false;
571     APInt result = any(ovfFlags & OverflowFlags::Nuw)
572                        ? l.ushl_sat(r)
573                        : l.ushl_ov(r, overflowed);
574     return overflowed ? std::optional<APInt>() : result;
575   };
576   std::function sshl = [=](const APInt &l,
577                            const APInt &r) -> std::optional<APInt> {
578     bool overflowed = false;
579     APInt result = any(ovfFlags & OverflowFlags::Nsw)
580                        ? l.sshl_sat(r)
581                        : l.sshl_ov(r, overflowed);
582     return overflowed ? std::optional<APInt>() : result;
583   };
584 
585   ConstantIntRanges urange =
586       minMaxBy(ushl, {lhs.umin(), lhs.umax()}, {rhsUMin, rhsUMax},
587                /*isSigned=*/false);
588   ConstantIntRanges srange =
589       minMaxBy(sshl, {lhs.smin(), lhs.smax()}, {rhsUMin, rhsUMax},
590                /*isSigned=*/true);
591   return urange.intersection(srange);
592 }
593 
594 ConstantIntRanges
595 mlir::intrange::inferShrS(ArrayRef<ConstantIntRanges> argRanges) {
596   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
597 
598   ConstArithFn ashr = [](const APInt &l,
599                          const APInt &r) -> std::optional<APInt> {
600     return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.ashr(r);
601   };
602 
603   return minMaxBy(ashr, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()},
604                   /*isSigned=*/true);
605 }
606 
607 ConstantIntRanges
608 mlir::intrange::inferShrU(ArrayRef<ConstantIntRanges> argRanges) {
609   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
610 
611   ConstArithFn lshr = [](const APInt &l,
612                          const APInt &r) -> std::optional<APInt> {
613     return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.lshr(r);
614   };
615   return minMaxBy(lshr, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
616                   /*isSigned=*/false);
617 }
618 
619 //===----------------------------------------------------------------------===//
620 // Comparisons (Cmp)
621 //===----------------------------------------------------------------------===//
622 
623 static intrange::CmpPredicate invertPredicate(intrange::CmpPredicate pred) {
624   switch (pred) {
625   case intrange::CmpPredicate::eq:
626     return intrange::CmpPredicate::ne;
627   case intrange::CmpPredicate::ne:
628     return intrange::CmpPredicate::eq;
629   case intrange::CmpPredicate::slt:
630     return intrange::CmpPredicate::sge;
631   case intrange::CmpPredicate::sle:
632     return intrange::CmpPredicate::sgt;
633   case intrange::CmpPredicate::sgt:
634     return intrange::CmpPredicate::sle;
635   case intrange::CmpPredicate::sge:
636     return intrange::CmpPredicate::slt;
637   case intrange::CmpPredicate::ult:
638     return intrange::CmpPredicate::uge;
639   case intrange::CmpPredicate::ule:
640     return intrange::CmpPredicate::ugt;
641   case intrange::CmpPredicate::ugt:
642     return intrange::CmpPredicate::ule;
643   case intrange::CmpPredicate::uge:
644     return intrange::CmpPredicate::ult;
645   }
646   llvm_unreachable("unknown cmp predicate value");
647 }
648 
649 static bool isStaticallyTrue(intrange::CmpPredicate pred,
650                              const ConstantIntRanges &lhs,
651                              const ConstantIntRanges &rhs) {
652   switch (pred) {
653   case intrange::CmpPredicate::sle:
654     return lhs.smax().sle(rhs.smin());
655   case intrange::CmpPredicate::slt:
656     return lhs.smax().slt(rhs.smin());
657   case intrange::CmpPredicate::ule:
658     return lhs.umax().ule(rhs.umin());
659   case intrange::CmpPredicate::ult:
660     return lhs.umax().ult(rhs.umin());
661   case intrange::CmpPredicate::sge:
662     return lhs.smin().sge(rhs.smax());
663   case intrange::CmpPredicate::sgt:
664     return lhs.smin().sgt(rhs.smax());
665   case intrange::CmpPredicate::uge:
666     return lhs.umin().uge(rhs.umax());
667   case intrange::CmpPredicate::ugt:
668     return lhs.umin().ugt(rhs.umax());
669   case intrange::CmpPredicate::eq: {
670     std::optional<APInt> lhsConst = lhs.getConstantValue();
671     std::optional<APInt> rhsConst = rhs.getConstantValue();
672     return lhsConst && rhsConst && lhsConst == rhsConst;
673   }
674   case intrange::CmpPredicate::ne: {
675     // While equality requires that there is an interpration of the preceeding
676     // computations that produces equal constants, whether that be signed or
677     // unsigned, statically determining inequality requires that neither
678     // interpretation produce potentially overlapping ranges.
679     bool sne = isStaticallyTrue(intrange::CmpPredicate::slt, lhs, rhs) ||
680                isStaticallyTrue(intrange::CmpPredicate::sgt, lhs, rhs);
681     bool une = isStaticallyTrue(intrange::CmpPredicate::ult, lhs, rhs) ||
682                isStaticallyTrue(intrange::CmpPredicate::ugt, lhs, rhs);
683     return sne && une;
684   }
685   }
686   return false;
687 }
688 
689 std::optional<bool> mlir::intrange::evaluatePred(CmpPredicate pred,
690                                                  const ConstantIntRanges &lhs,
691                                                  const ConstantIntRanges &rhs) {
692   if (isStaticallyTrue(pred, lhs, rhs))
693     return true;
694   if (isStaticallyTrue(invertPredicate(pred), lhs, rhs))
695     return false;
696   return std::nullopt;
697 }
698