xref: /llvm-project/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp (revision 5af9d16dae71f2c2087ba88c5fc06893e6aecfe9)
1 //===- InferIntRangeCommon.cpp - Inference for common ops ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains implementations of range inference for operations that are
10 // common to both the `arith` and `index` dialects to facilitate reuse.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
15 
16 #include "mlir/Interfaces/InferIntRangeInterface.h"
17 
18 #include "llvm/ADT/ArrayRef.h"
19 #include "llvm/ADT/STLExtras.h"
20 
21 #include "llvm/Support/Debug.h"
22 
23 #include <iterator>
24 #include <optional>
25 
26 using namespace mlir;
27 
28 #define DEBUG_TYPE "int-range-analysis"
29 
30 //===----------------------------------------------------------------------===//
31 // General utilities
32 //===----------------------------------------------------------------------===//
33 
34 /// Function that evaluates the result of doing something on arithmetic
35 /// constants and returns std::nullopt on overflow.
36 using ConstArithFn =
37     function_ref<std::optional<APInt>(const APInt &, const APInt &)>;
38 
39 /// Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible,
40 /// If either computation overflows, make the result unbounded.
41 static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft,
42                                          const APInt &minRight,
43                                          const APInt &maxLeft,
44                                          const APInt &maxRight, bool isSigned) {
45   std::optional<APInt> maybeMin = op(minLeft, minRight);
46   std::optional<APInt> maybeMax = op(maxLeft, maxRight);
47   if (maybeMin && maybeMax)
48     return ConstantIntRanges::range(*maybeMin, *maybeMax, isSigned);
49   return ConstantIntRanges::maxRange(minLeft.getBitWidth());
50 }
51 
52 /// Compute the minimum and maximum of `(op(l, r) for l in lhs for r in rhs)`,
53 /// ignoring unbounded values. Returns the maximal range if `op` overflows.
54 static ConstantIntRanges minMaxBy(ConstArithFn op, ArrayRef<APInt> lhs,
55                                   ArrayRef<APInt> rhs, bool isSigned) {
56   unsigned width = lhs[0].getBitWidth();
57   APInt min =
58       isSigned ? APInt::getSignedMaxValue(width) : APInt::getMaxValue(width);
59   APInt max =
60       isSigned ? APInt::getSignedMinValue(width) : APInt::getZero(width);
61   for (const APInt &left : lhs) {
62     for (const APInt &right : rhs) {
63       std::optional<APInt> maybeThisResult = op(left, right);
64       if (!maybeThisResult)
65         return ConstantIntRanges::maxRange(width);
66       APInt result = std::move(*maybeThisResult);
67       min = (isSigned ? result.slt(min) : result.ult(min)) ? result : min;
68       max = (isSigned ? result.sgt(max) : result.ugt(max)) ? result : max;
69     }
70   }
71   return ConstantIntRanges::range(min, max, isSigned);
72 }
73 
74 //===----------------------------------------------------------------------===//
75 // Ext, trunc, index op handling
76 //===----------------------------------------------------------------------===//
77 
78 ConstantIntRanges
79 mlir::intrange::inferIndexOp(InferRangeFn inferFn,
80                              ArrayRef<ConstantIntRanges> argRanges,
81                              intrange::CmpMode mode) {
82   ConstantIntRanges sixtyFour = inferFn(argRanges);
83   SmallVector<ConstantIntRanges, 2> truncated;
84   llvm::transform(argRanges, std::back_inserter(truncated),
85                   [](const ConstantIntRanges &range) {
86                     return truncRange(range, /*destWidth=*/indexMinWidth);
87                   });
88   ConstantIntRanges thirtyTwo = inferFn(truncated);
89   ConstantIntRanges thirtyTwoAsSixtyFour =
90       extRange(thirtyTwo, /*destWidth=*/indexMaxWidth);
91   ConstantIntRanges sixtyFourAsThirtyTwo =
92       truncRange(sixtyFour, /*destWidth=*/indexMinWidth);
93 
94   LLVM_DEBUG(llvm::dbgs() << "Index handling: 64-bit result = " << sixtyFour
95                           << " 32-bit = " << thirtyTwo << "\n");
96   bool truncEqual = false;
97   switch (mode) {
98   case intrange::CmpMode::Both:
99     truncEqual = (thirtyTwo == sixtyFourAsThirtyTwo);
100     break;
101   case intrange::CmpMode::Signed:
102     truncEqual = (thirtyTwo.smin() == sixtyFourAsThirtyTwo.smin() &&
103                   thirtyTwo.smax() == sixtyFourAsThirtyTwo.smax());
104     break;
105   case intrange::CmpMode::Unsigned:
106     truncEqual = (thirtyTwo.umin() == sixtyFourAsThirtyTwo.umin() &&
107                   thirtyTwo.umax() == sixtyFourAsThirtyTwo.umax());
108     break;
109   }
110   if (truncEqual)
111     // Returing the 64-bit result preserves more information.
112     return sixtyFour;
113   ConstantIntRanges merged = sixtyFour.rangeUnion(thirtyTwoAsSixtyFour);
114   return merged;
115 }
116 
117 ConstantIntRanges mlir::intrange::extRange(const ConstantIntRanges &range,
118                                            unsigned int destWidth) {
119   APInt umin = range.umin().zext(destWidth);
120   APInt umax = range.umax().zext(destWidth);
121   APInt smin = range.smin().sext(destWidth);
122   APInt smax = range.smax().sext(destWidth);
123   return {umin, umax, smin, smax};
124 }
125 
126 ConstantIntRanges mlir::intrange::extUIRange(const ConstantIntRanges &range,
127                                              unsigned destWidth) {
128   APInt umin = range.umin().zext(destWidth);
129   APInt umax = range.umax().zext(destWidth);
130   return ConstantIntRanges::fromUnsigned(umin, umax);
131 }
132 
133 ConstantIntRanges mlir::intrange::extSIRange(const ConstantIntRanges &range,
134                                              unsigned destWidth) {
135   APInt smin = range.smin().sext(destWidth);
136   APInt smax = range.smax().sext(destWidth);
137   return ConstantIntRanges::fromSigned(smin, smax);
138 }
139 
140 ConstantIntRanges mlir::intrange::truncRange(const ConstantIntRanges &range,
141                                              unsigned int destWidth) {
142   // If you truncate the first four bytes in [0xaaaabbbb, 0xccccbbbb],
143   // the range of the resulting value is not contiguous ind includes 0.
144   // Ex. If you truncate [256, 258] from i16 to i8, you validly get [0, 2],
145   // but you can't truncate [255, 257] similarly.
146   bool hasUnsignedRollover =
147       range.umin().lshr(destWidth) != range.umax().lshr(destWidth);
148   APInt umin = hasUnsignedRollover ? APInt::getZero(destWidth)
149                                    : range.umin().trunc(destWidth);
150   APInt umax = hasUnsignedRollover ? APInt::getMaxValue(destWidth)
151                                    : range.umax().trunc(destWidth);
152 
153   // Signed post-truncation rollover will not occur when either:
154   // - The high parts of the min and max, plus the sign bit, are the same
155   // - The high halves + sign bit of the min and max are either all 1s or all 0s
156   //  and you won't create a [positive, negative] range by truncating.
157   // For example, you can truncate the ranges [256, 258]_i16 to [0, 2]_i8
158   // but not [255, 257]_i16 to a range of i8s. You can also truncate
159   // [-256, -256]_i16 to [-2, 0]_i8, but not [-257, -255]_i16.
160   // You can also truncate [-130, 0]_i16 to i8 because -130_i16 (0xff7e)
161   // will truncate to 0x7e, which is greater than 0
162   APInt sminHighPart = range.smin().ashr(destWidth - 1);
163   APInt smaxHighPart = range.smax().ashr(destWidth - 1);
164   bool hasSignedOverflow =
165       (sminHighPart != smaxHighPart) &&
166       !(sminHighPart.isAllOnes() &&
167         (smaxHighPart.isAllOnes() || smaxHighPart.isZero())) &&
168       !(sminHighPart.isZero() && smaxHighPart.isZero());
169   APInt smin = hasSignedOverflow ? APInt::getSignedMinValue(destWidth)
170                                  : range.smin().trunc(destWidth);
171   APInt smax = hasSignedOverflow ? APInt::getSignedMaxValue(destWidth)
172                                  : range.smax().trunc(destWidth);
173   return {umin, umax, smin, smax};
174 }
175 
176 //===----------------------------------------------------------------------===//
177 // Addition
178 //===----------------------------------------------------------------------===//
179 
180 ConstantIntRanges
181 mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges) {
182   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
183   ConstArithFn uadd = [](const APInt &a,
184                          const APInt &b) -> std::optional<APInt> {
185     bool overflowed = false;
186     APInt result = a.uadd_ov(b, overflowed);
187     return overflowed ? std::optional<APInt>() : result;
188   };
189   ConstArithFn sadd = [](const APInt &a,
190                          const APInt &b) -> std::optional<APInt> {
191     bool overflowed = false;
192     APInt result = a.sadd_ov(b, overflowed);
193     return overflowed ? std::optional<APInt>() : result;
194   };
195 
196   ConstantIntRanges urange = computeBoundsBy(
197       uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax(), /*isSigned=*/false);
198   ConstantIntRanges srange = computeBoundsBy(
199       sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax(), /*isSigned=*/true);
200   return urange.intersection(srange);
201 }
202 
203 //===----------------------------------------------------------------------===//
204 // Subtraction
205 //===----------------------------------------------------------------------===//
206 
207 ConstantIntRanges
208 mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges) {
209   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
210 
211   ConstArithFn usub = [](const APInt &a,
212                          const APInt &b) -> std::optional<APInt> {
213     bool overflowed = false;
214     APInt result = a.usub_ov(b, overflowed);
215     return overflowed ? std::optional<APInt>() : result;
216   };
217   ConstArithFn ssub = [](const APInt &a,
218                          const APInt &b) -> std::optional<APInt> {
219     bool overflowed = false;
220     APInt result = a.ssub_ov(b, overflowed);
221     return overflowed ? std::optional<APInt>() : result;
222   };
223   ConstantIntRanges urange = computeBoundsBy(
224       usub, lhs.umin(), rhs.umax(), lhs.umax(), rhs.umin(), /*isSigned=*/false);
225   ConstantIntRanges srange = computeBoundsBy(
226       ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin(), /*isSigned=*/true);
227   return urange.intersection(srange);
228 }
229 
230 //===----------------------------------------------------------------------===//
231 // Multiplication
232 //===----------------------------------------------------------------------===//
233 
234 ConstantIntRanges
235 mlir::intrange::inferMul(ArrayRef<ConstantIntRanges> argRanges) {
236   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
237 
238   ConstArithFn umul = [](const APInt &a,
239                          const APInt &b) -> std::optional<APInt> {
240     bool overflowed = false;
241     APInt result = a.umul_ov(b, overflowed);
242     return overflowed ? std::optional<APInt>() : result;
243   };
244   ConstArithFn smul = [](const APInt &a,
245                          const APInt &b) -> std::optional<APInt> {
246     bool overflowed = false;
247     APInt result = a.smul_ov(b, overflowed);
248     return overflowed ? std::optional<APInt>() : result;
249   };
250 
251   ConstantIntRanges urange =
252       minMaxBy(umul, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
253                /*isSigned=*/false);
254   ConstantIntRanges srange =
255       minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()},
256                /*isSigned=*/true);
257   return urange.intersection(srange);
258 }
259 
260 //===----------------------------------------------------------------------===//
261 // DivU, CeilDivU (Unsigned division)
262 //===----------------------------------------------------------------------===//
263 
264 /// Fix up division results (ex. for ceiling and floor), returning an APInt
265 /// if there has been no overflow
266 using DivisionFixupFn = function_ref<std::optional<APInt>(
267     const APInt &lhs, const APInt &rhs, const APInt &result)>;
268 
269 static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs,
270                                         const ConstantIntRanges &rhs,
271                                         DivisionFixupFn fixup) {
272   const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(), &rhsMin = rhs.umin(),
273               &rhsMax = rhs.umax();
274 
275   if (!rhsMin.isZero()) {
276     auto udiv = [&fixup](const APInt &a,
277                          const APInt &b) -> std::optional<APInt> {
278       return fixup(a, b, a.udiv(b));
279     };
280     return minMaxBy(udiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
281                     /*isSigned=*/false);
282   }
283   // Otherwise, it's possible we might divide by 0.
284   return ConstantIntRanges::maxRange(rhsMin.getBitWidth());
285 }
286 
287 ConstantIntRanges
288 mlir::intrange::inferDivU(ArrayRef<ConstantIntRanges> argRanges) {
289   return inferDivURange(argRanges[0], argRanges[1],
290                         [](const APInt &lhs, const APInt &rhs,
291                            const APInt &result) { return result; });
292 }
293 
294 ConstantIntRanges
295 mlir::intrange::inferCeilDivU(ArrayRef<ConstantIntRanges> argRanges) {
296   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
297 
298   DivisionFixupFn ceilDivUIFix =
299       [](const APInt &lhs, const APInt &rhs,
300          const APInt &result) -> std::optional<APInt> {
301     if (!lhs.urem(rhs).isZero()) {
302       bool overflowed = false;
303       APInt corrected =
304           result.uadd_ov(APInt(result.getBitWidth(), 1), overflowed);
305       return overflowed ? std::optional<APInt>() : corrected;
306     }
307     return result;
308   };
309   return inferDivURange(lhs, rhs, ceilDivUIFix);
310 }
311 
312 //===----------------------------------------------------------------------===//
313 // DivS, CeilDivS, FloorDivS (Signed division)
314 //===----------------------------------------------------------------------===//
315 
316 static ConstantIntRanges inferDivSRange(const ConstantIntRanges &lhs,
317                                         const ConstantIntRanges &rhs,
318                                         DivisionFixupFn fixup) {
319   const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
320               &rhsMax = rhs.smax();
321   bool canDivide = rhsMin.isStrictlyPositive() || rhsMax.isNegative();
322 
323   if (canDivide) {
324     auto sdiv = [&fixup](const APInt &a,
325                          const APInt &b) -> std::optional<APInt> {
326       bool overflowed = false;
327       APInt result = a.sdiv_ov(b, overflowed);
328       return overflowed ? std::optional<APInt>() : fixup(a, b, result);
329     };
330     return minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
331                     /*isSigned=*/true);
332   }
333   return ConstantIntRanges::maxRange(rhsMin.getBitWidth());
334 }
335 
336 ConstantIntRanges
337 mlir::intrange::inferDivS(ArrayRef<ConstantIntRanges> argRanges) {
338   return inferDivSRange(argRanges[0], argRanges[1],
339                         [](const APInt &lhs, const APInt &rhs,
340                            const APInt &result) { return result; });
341 }
342 
343 ConstantIntRanges
344 mlir::intrange::inferCeilDivS(ArrayRef<ConstantIntRanges> argRanges) {
345   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
346 
347   DivisionFixupFn ceilDivSIFix =
348       [](const APInt &lhs, const APInt &rhs,
349          const APInt &result) -> std::optional<APInt> {
350     if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() == rhs.isNonNegative()) {
351       bool overflowed = false;
352       APInt corrected =
353           result.sadd_ov(APInt(result.getBitWidth(), 1), overflowed);
354       return overflowed ? std::optional<APInt>() : corrected;
355     }
356     return result;
357   };
358   return inferDivSRange(lhs, rhs, ceilDivSIFix);
359 }
360 
361 ConstantIntRanges
362 mlir::intrange::inferFloorDivS(ArrayRef<ConstantIntRanges> argRanges) {
363   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
364 
365   DivisionFixupFn floorDivSIFix =
366       [](const APInt &lhs, const APInt &rhs,
367          const APInt &result) -> std::optional<APInt> {
368     if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() != rhs.isNonNegative()) {
369       bool overflowed = false;
370       APInt corrected =
371           result.ssub_ov(APInt(result.getBitWidth(), 1), overflowed);
372       return overflowed ? std::optional<APInt>() : corrected;
373     }
374     return result;
375   };
376   return inferDivSRange(lhs, rhs, floorDivSIFix);
377 }
378 
379 //===----------------------------------------------------------------------===//
380 // Signed remainder (RemS)
381 //===----------------------------------------------------------------------===//
382 
383 ConstantIntRanges
384 mlir::intrange::inferRemS(ArrayRef<ConstantIntRanges> argRanges) {
385   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
386   const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
387               &rhsMax = rhs.smax();
388 
389   unsigned width = rhsMax.getBitWidth();
390   APInt smin = APInt::getSignedMinValue(width);
391   APInt smax = APInt::getSignedMaxValue(width);
392   // No bounds if zero could be a divisor.
393   bool canBound = (rhsMin.isStrictlyPositive() || rhsMax.isNegative());
394   if (canBound) {
395     APInt maxDivisor = rhsMin.isStrictlyPositive() ? rhsMax : rhsMin.abs();
396     bool canNegativeDividend = lhsMin.isNegative();
397     bool canPositiveDividend = lhsMax.isStrictlyPositive();
398     APInt zero = APInt::getZero(maxDivisor.getBitWidth());
399     APInt maxPositiveResult = maxDivisor - 1;
400     APInt minNegativeResult = -maxPositiveResult;
401     smin = canNegativeDividend ? minNegativeResult : zero;
402     smax = canPositiveDividend ? maxPositiveResult : zero;
403     // Special case: sweeping out a contiguous range in N/[modulus].
404     if (rhsMin == rhsMax) {
405       if ((lhsMax - lhsMin).ult(maxDivisor)) {
406         APInt minRem = lhsMin.srem(maxDivisor);
407         APInt maxRem = lhsMax.srem(maxDivisor);
408         if (minRem.sle(maxRem)) {
409           smin = minRem;
410           smax = maxRem;
411         }
412       }
413     }
414   }
415   return ConstantIntRanges::fromSigned(smin, smax);
416 }
417 
418 //===----------------------------------------------------------------------===//
419 // Unsigned remainder (RemU)
420 //===----------------------------------------------------------------------===//
421 
422 ConstantIntRanges
423 mlir::intrange::inferRemU(ArrayRef<ConstantIntRanges> argRanges) {
424   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
425   const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax();
426 
427   unsigned width = rhsMin.getBitWidth();
428   APInt umin = APInt::getZero(width);
429   APInt umax = APInt::getMaxValue(width);
430 
431   if (!rhsMin.isZero()) {
432     umax = rhsMax - 1;
433     // Special case: sweeping out a contiguous range in N/[modulus]
434     if (rhsMin == rhsMax) {
435       const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax();
436       if ((lhsMax - lhsMin).ult(rhsMax)) {
437         APInt minRem = lhsMin.urem(rhsMax);
438         APInt maxRem = lhsMax.urem(rhsMax);
439         if (minRem.ule(maxRem)) {
440           umin = minRem;
441           umax = maxRem;
442         }
443       }
444     }
445   }
446   return ConstantIntRanges::fromUnsigned(umin, umax);
447 }
448 
449 //===----------------------------------------------------------------------===//
450 // Max and min (MaxS, MaxU, MinS, MinU)
451 //===----------------------------------------------------------------------===//
452 
453 ConstantIntRanges
454 mlir::intrange::inferMaxS(ArrayRef<ConstantIntRanges> argRanges) {
455   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
456 
457   const APInt &smin = lhs.smin().sgt(rhs.smin()) ? lhs.smin() : rhs.smin();
458   const APInt &smax = lhs.smax().sgt(rhs.smax()) ? lhs.smax() : rhs.smax();
459   return ConstantIntRanges::fromSigned(smin, smax);
460 }
461 
462 ConstantIntRanges
463 mlir::intrange::inferMaxU(ArrayRef<ConstantIntRanges> argRanges) {
464   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
465 
466   const APInt &umin = lhs.umin().ugt(rhs.umin()) ? lhs.umin() : rhs.umin();
467   const APInt &umax = lhs.umax().ugt(rhs.umax()) ? lhs.umax() : rhs.umax();
468   return ConstantIntRanges::fromUnsigned(umin, umax);
469 }
470 
471 ConstantIntRanges
472 mlir::intrange::inferMinS(ArrayRef<ConstantIntRanges> argRanges) {
473   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
474 
475   const APInt &smin = lhs.smin().slt(rhs.smin()) ? lhs.smin() : rhs.smin();
476   const APInt &smax = lhs.smax().slt(rhs.smax()) ? lhs.smax() : rhs.smax();
477   return ConstantIntRanges::fromSigned(smin, smax);
478 }
479 
480 ConstantIntRanges
481 mlir::intrange::inferMinU(ArrayRef<ConstantIntRanges> argRanges) {
482   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
483 
484   const APInt &umin = lhs.umin().ult(rhs.umin()) ? lhs.umin() : rhs.umin();
485   const APInt &umax = lhs.umax().ult(rhs.umax()) ? lhs.umax() : rhs.umax();
486   return ConstantIntRanges::fromUnsigned(umin, umax);
487 }
488 
489 //===----------------------------------------------------------------------===//
490 // Bitwise operators (And, Or, Xor)
491 //===----------------------------------------------------------------------===//
492 
493 /// "Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???,
494 /// relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111, where vvvvv are the bits
495 /// that both bonuds have in common. This gives us a consertive approximation
496 /// for what values can be passed to bitwise operations.
497 static std::tuple<APInt, APInt>
498 widenBitwiseBounds(const ConstantIntRanges &bound) {
499   APInt leftVal = bound.umin(), rightVal = bound.umax();
500   unsigned bitwidth = leftVal.getBitWidth();
501   unsigned differingBits = bitwidth - (leftVal ^ rightVal).countLeadingZeros();
502   leftVal.clearLowBits(differingBits);
503   rightVal.setLowBits(differingBits);
504   return std::make_tuple(std::move(leftVal), std::move(rightVal));
505 }
506 
507 ConstantIntRanges
508 mlir::intrange::inferAnd(ArrayRef<ConstantIntRanges> argRanges) {
509   auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]);
510   auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]);
511   auto andi = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
512     return a & b;
513   };
514   return minMaxBy(andi, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
515                   /*isSigned=*/false);
516 }
517 
518 ConstantIntRanges
519 mlir::intrange::inferOr(ArrayRef<ConstantIntRanges> argRanges) {
520   auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]);
521   auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]);
522   auto ori = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
523     return a | b;
524   };
525   return minMaxBy(ori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
526                   /*isSigned=*/false);
527 }
528 
529 ConstantIntRanges
530 mlir::intrange::inferXor(ArrayRef<ConstantIntRanges> argRanges) {
531   auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]);
532   auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]);
533   auto xori = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
534     return a ^ b;
535   };
536   return minMaxBy(xori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
537                   /*isSigned=*/false);
538 }
539 
540 //===----------------------------------------------------------------------===//
541 // Shifts (Shl, ShrS, ShrU)
542 //===----------------------------------------------------------------------===//
543 
544 ConstantIntRanges
545 mlir::intrange::inferShl(ArrayRef<ConstantIntRanges> argRanges) {
546   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
547   ConstArithFn shl = [](const APInt &l,
548                         const APInt &r) -> std::optional<APInt> {
549     return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.shl(r);
550   };
551   ConstantIntRanges urange =
552       minMaxBy(shl, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
553                /*isSigned=*/false);
554   ConstantIntRanges srange =
555       minMaxBy(shl, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()},
556                /*isSigned=*/true);
557   return urange.intersection(srange);
558 }
559 
560 ConstantIntRanges
561 mlir::intrange::inferShrS(ArrayRef<ConstantIntRanges> argRanges) {
562   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
563 
564   ConstArithFn ashr = [](const APInt &l,
565                          const APInt &r) -> std::optional<APInt> {
566     return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.ashr(r);
567   };
568 
569   return minMaxBy(ashr, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()},
570                   /*isSigned=*/true);
571 }
572 
573 ConstantIntRanges
574 mlir::intrange::inferShrU(ArrayRef<ConstantIntRanges> argRanges) {
575   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
576 
577   ConstArithFn lshr = [](const APInt &l,
578                          const APInt &r) -> std::optional<APInt> {
579     return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.lshr(r);
580   };
581   return minMaxBy(lshr, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
582                   /*isSigned=*/false);
583 }
584 
585 //===----------------------------------------------------------------------===//
586 // Comparisons (Cmp)
587 //===----------------------------------------------------------------------===//
588 
589 static intrange::CmpPredicate invertPredicate(intrange::CmpPredicate pred) {
590   switch (pred) {
591   case intrange::CmpPredicate::eq:
592     return intrange::CmpPredicate::ne;
593   case intrange::CmpPredicate::ne:
594     return intrange::CmpPredicate::eq;
595   case intrange::CmpPredicate::slt:
596     return intrange::CmpPredicate::sge;
597   case intrange::CmpPredicate::sle:
598     return intrange::CmpPredicate::sgt;
599   case intrange::CmpPredicate::sgt:
600     return intrange::CmpPredicate::sle;
601   case intrange::CmpPredicate::sge:
602     return intrange::CmpPredicate::slt;
603   case intrange::CmpPredicate::ult:
604     return intrange::CmpPredicate::uge;
605   case intrange::CmpPredicate::ule:
606     return intrange::CmpPredicate::ugt;
607   case intrange::CmpPredicate::ugt:
608     return intrange::CmpPredicate::ule;
609   case intrange::CmpPredicate::uge:
610     return intrange::CmpPredicate::ult;
611   }
612   llvm_unreachable("unknown cmp predicate value");
613 }
614 
615 static bool isStaticallyTrue(intrange::CmpPredicate pred,
616                              const ConstantIntRanges &lhs,
617                              const ConstantIntRanges &rhs) {
618   switch (pred) {
619   case intrange::CmpPredicate::sle:
620     return lhs.smax().sle(rhs.smin());
621   case intrange::CmpPredicate::slt:
622     return lhs.smax().slt(rhs.smin());
623   case intrange::CmpPredicate::ule:
624     return lhs.umax().ule(rhs.umin());
625   case intrange::CmpPredicate::ult:
626     return lhs.umax().ult(rhs.umin());
627   case intrange::CmpPredicate::sge:
628     return lhs.smin().sge(rhs.smax());
629   case intrange::CmpPredicate::sgt:
630     return lhs.smin().sgt(rhs.smax());
631   case intrange::CmpPredicate::uge:
632     return lhs.umin().uge(rhs.umax());
633   case intrange::CmpPredicate::ugt:
634     return lhs.umin().ugt(rhs.umax());
635   case intrange::CmpPredicate::eq: {
636     std::optional<APInt> lhsConst = lhs.getConstantValue();
637     std::optional<APInt> rhsConst = rhs.getConstantValue();
638     return lhsConst && rhsConst && lhsConst == rhsConst;
639   }
640   case intrange::CmpPredicate::ne: {
641     // While equality requires that there is an interpration of the preceeding
642     // computations that produces equal constants, whether that be signed or
643     // unsigned, statically determining inequality requires that neither
644     // interpretation produce potentially overlapping ranges.
645     bool sne = isStaticallyTrue(intrange::CmpPredicate::slt, lhs, rhs) ||
646                isStaticallyTrue(intrange::CmpPredicate::sgt, lhs, rhs);
647     bool une = isStaticallyTrue(intrange::CmpPredicate::ult, lhs, rhs) ||
648                isStaticallyTrue(intrange::CmpPredicate::ugt, lhs, rhs);
649     return sne && une;
650   }
651   }
652   return false;
653 }
654 
655 std::optional<bool> mlir::intrange::evaluatePred(CmpPredicate pred,
656                                                  const ConstantIntRanges &lhs,
657                                                  const ConstantIntRanges &rhs) {
658   if (isStaticallyTrue(pred, lhs, rhs))
659     return true;
660   if (isStaticallyTrue(invertPredicate(pred), lhs, rhs))
661     return false;
662   return std::nullopt;
663 }
664