xref: /llvm-project/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp (revision 9ab16d49c99966f33900d68ed5370f19927ca52c)
1 //===- InferIntRangeCommon.cpp - Inference for common ops ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains implementations of range inference for operations that are
10 // common to both the `arith` and `index` dialects to facilitate reuse.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
15 
16 #include "mlir/Interfaces/InferIntRangeInterface.h"
17 
18 #include "llvm/ADT/ArrayRef.h"
19 #include "llvm/ADT/STLExtras.h"
20 
21 #include "llvm/Support/Debug.h"
22 
23 #include <iterator>
24 #include <optional>
25 
26 using namespace mlir;
27 
28 #define DEBUG_TYPE "int-range-analysis"
29 
30 //===----------------------------------------------------------------------===//
31 // General utilities
32 //===----------------------------------------------------------------------===//
33 
34 /// Function that evaluates the result of doing something on arithmetic
35 /// constants and returns std::nullopt on overflow.
36 using ConstArithFn =
37     function_ref<std::optional<APInt>(const APInt &, const APInt &)>;
38 using ConstArithStdFn =
39     std::function<std::optional<APInt>(const APInt &, const APInt &)>;
40 
41 /// Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible,
42 /// If either computation overflows, make the result unbounded.
43 static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft,
44                                          const APInt &minRight,
45                                          const APInt &maxLeft,
46                                          const APInt &maxRight, bool isSigned) {
47   std::optional<APInt> maybeMin = op(minLeft, minRight);
48   std::optional<APInt> maybeMax = op(maxLeft, maxRight);
49   if (maybeMin && maybeMax)
50     return ConstantIntRanges::range(*maybeMin, *maybeMax, isSigned);
51   return ConstantIntRanges::maxRange(minLeft.getBitWidth());
52 }
53 
54 /// Compute the minimum and maximum of `(op(l, r) for l in lhs for r in rhs)`,
55 /// ignoring unbounded values. Returns the maximal range if `op` overflows.
56 static ConstantIntRanges minMaxBy(ConstArithFn op, ArrayRef<APInt> lhs,
57                                   ArrayRef<APInt> rhs, bool isSigned) {
58   unsigned width = lhs[0].getBitWidth();
59   APInt min =
60       isSigned ? APInt::getSignedMaxValue(width) : APInt::getMaxValue(width);
61   APInt max =
62       isSigned ? APInt::getSignedMinValue(width) : APInt::getZero(width);
63   for (const APInt &left : lhs) {
64     for (const APInt &right : rhs) {
65       std::optional<APInt> maybeThisResult = op(left, right);
66       if (!maybeThisResult)
67         return ConstantIntRanges::maxRange(width);
68       APInt result = std::move(*maybeThisResult);
69       min = (isSigned ? result.slt(min) : result.ult(min)) ? result : min;
70       max = (isSigned ? result.sgt(max) : result.ugt(max)) ? result : max;
71     }
72   }
73   return ConstantIntRanges::range(min, max, isSigned);
74 }
75 
76 //===----------------------------------------------------------------------===//
77 // Ext, trunc, index op handling
78 //===----------------------------------------------------------------------===//
79 
80 ConstantIntRanges
81 mlir::intrange::inferIndexOp(const InferRangeFn &inferFn,
82                              ArrayRef<ConstantIntRanges> argRanges,
83                              intrange::CmpMode mode) {
84   ConstantIntRanges sixtyFour = inferFn(argRanges);
85   SmallVector<ConstantIntRanges, 2> truncated;
86   llvm::transform(argRanges, std::back_inserter(truncated),
87                   [](const ConstantIntRanges &range) {
88                     return truncRange(range, /*destWidth=*/indexMinWidth);
89                   });
90   ConstantIntRanges thirtyTwo = inferFn(truncated);
91   ConstantIntRanges thirtyTwoAsSixtyFour =
92       extRange(thirtyTwo, /*destWidth=*/indexMaxWidth);
93   ConstantIntRanges sixtyFourAsThirtyTwo =
94       truncRange(sixtyFour, /*destWidth=*/indexMinWidth);
95 
96   LLVM_DEBUG(llvm::dbgs() << "Index handling: 64-bit result = " << sixtyFour
97                           << " 32-bit = " << thirtyTwo << "\n");
98   bool truncEqual = false;
99   switch (mode) {
100   case intrange::CmpMode::Both:
101     truncEqual = (thirtyTwo == sixtyFourAsThirtyTwo);
102     break;
103   case intrange::CmpMode::Signed:
104     truncEqual = (thirtyTwo.smin() == sixtyFourAsThirtyTwo.smin() &&
105                   thirtyTwo.smax() == sixtyFourAsThirtyTwo.smax());
106     break;
107   case intrange::CmpMode::Unsigned:
108     truncEqual = (thirtyTwo.umin() == sixtyFourAsThirtyTwo.umin() &&
109                   thirtyTwo.umax() == sixtyFourAsThirtyTwo.umax());
110     break;
111   }
112   if (truncEqual)
113     // Returing the 64-bit result preserves more information.
114     return sixtyFour;
115   ConstantIntRanges merged = sixtyFour.rangeUnion(thirtyTwoAsSixtyFour);
116   return merged;
117 }
118 
119 ConstantIntRanges mlir::intrange::extRange(const ConstantIntRanges &range,
120                                            unsigned int destWidth) {
121   APInt umin = range.umin().zext(destWidth);
122   APInt umax = range.umax().zext(destWidth);
123   APInt smin = range.smin().sext(destWidth);
124   APInt smax = range.smax().sext(destWidth);
125   return {umin, umax, smin, smax};
126 }
127 
128 ConstantIntRanges mlir::intrange::extUIRange(const ConstantIntRanges &range,
129                                              unsigned destWidth) {
130   APInt umin = range.umin().zext(destWidth);
131   APInt umax = range.umax().zext(destWidth);
132   return ConstantIntRanges::fromUnsigned(umin, umax);
133 }
134 
135 ConstantIntRanges mlir::intrange::extSIRange(const ConstantIntRanges &range,
136                                              unsigned destWidth) {
137   APInt smin = range.smin().sext(destWidth);
138   APInt smax = range.smax().sext(destWidth);
139   return ConstantIntRanges::fromSigned(smin, smax);
140 }
141 
142 ConstantIntRanges mlir::intrange::truncRange(const ConstantIntRanges &range,
143                                              unsigned int destWidth) {
144   // If you truncate the first four bytes in [0xaaaabbbb, 0xccccbbbb],
145   // the range of the resulting value is not contiguous ind includes 0.
146   // Ex. If you truncate [256, 258] from i16 to i8, you validly get [0, 2],
147   // but you can't truncate [255, 257] similarly.
148   bool hasUnsignedRollover =
149       range.umin().lshr(destWidth) != range.umax().lshr(destWidth);
150   APInt umin = hasUnsignedRollover ? APInt::getZero(destWidth)
151                                    : range.umin().trunc(destWidth);
152   APInt umax = hasUnsignedRollover ? APInt::getMaxValue(destWidth)
153                                    : range.umax().trunc(destWidth);
154 
155   // Signed post-truncation rollover will not occur when either:
156   // - The high parts of the min and max, plus the sign bit, are the same
157   // - The high halves + sign bit of the min and max are either all 1s or all 0s
158   //  and you won't create a [positive, negative] range by truncating.
159   // For example, you can truncate the ranges [256, 258]_i16 to [0, 2]_i8
160   // but not [255, 257]_i16 to a range of i8s. You can also truncate
161   // [-256, -256]_i16 to [-2, 0]_i8, but not [-257, -255]_i16.
162   // You can also truncate [-130, 0]_i16 to i8 because -130_i16 (0xff7e)
163   // will truncate to 0x7e, which is greater than 0
164   APInt sminHighPart = range.smin().ashr(destWidth - 1);
165   APInt smaxHighPart = range.smax().ashr(destWidth - 1);
166   bool hasSignedOverflow =
167       (sminHighPart != smaxHighPart) &&
168       !(sminHighPart.isAllOnes() &&
169         (smaxHighPart.isAllOnes() || smaxHighPart.isZero())) &&
170       !(sminHighPart.isZero() && smaxHighPart.isZero());
171   APInt smin = hasSignedOverflow ? APInt::getSignedMinValue(destWidth)
172                                  : range.smin().trunc(destWidth);
173   APInt smax = hasSignedOverflow ? APInt::getSignedMaxValue(destWidth)
174                                  : range.smax().trunc(destWidth);
175   return {umin, umax, smin, smax};
176 }
177 
178 //===----------------------------------------------------------------------===//
179 // Addition
180 //===----------------------------------------------------------------------===//
181 
182 ConstantIntRanges
183 mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges,
184                          OverflowFlags ovfFlags) {
185   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
186 
187   ConstArithStdFn uadd = [=](const APInt &a,
188                              const APInt &b) -> std::optional<APInt> {
189     bool overflowed = false;
190     APInt result = any(ovfFlags & OverflowFlags::Nuw)
191                        ? a.uadd_sat(b)
192                        : a.uadd_ov(b, overflowed);
193     return overflowed ? std::optional<APInt>() : result;
194   };
195   ConstArithStdFn sadd = [=](const APInt &a,
196                              const APInt &b) -> std::optional<APInt> {
197     bool overflowed = false;
198     APInt result = any(ovfFlags & OverflowFlags::Nsw)
199                        ? a.sadd_sat(b)
200                        : a.sadd_ov(b, overflowed);
201     return overflowed ? std::optional<APInt>() : result;
202   };
203 
204   ConstantIntRanges urange = computeBoundsBy(
205       uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax(), /*isSigned=*/false);
206   ConstantIntRanges srange = computeBoundsBy(
207       sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax(), /*isSigned=*/true);
208   return urange.intersection(srange);
209 }
210 
211 //===----------------------------------------------------------------------===//
212 // Subtraction
213 //===----------------------------------------------------------------------===//
214 
215 ConstantIntRanges
216 mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges,
217                          OverflowFlags ovfFlags) {
218   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
219 
220   ConstArithStdFn usub = [=](const APInt &a,
221                              const APInt &b) -> std::optional<APInt> {
222     bool overflowed = false;
223     APInt result = any(ovfFlags & OverflowFlags::Nuw)
224                        ? a.usub_sat(b)
225                        : a.usub_ov(b, overflowed);
226     return overflowed ? std::optional<APInt>() : result;
227   };
228   ConstArithStdFn ssub = [=](const APInt &a,
229                              const APInt &b) -> std::optional<APInt> {
230     bool overflowed = false;
231     APInt result = any(ovfFlags & OverflowFlags::Nsw)
232                        ? a.ssub_sat(b)
233                        : a.ssub_ov(b, overflowed);
234     return overflowed ? std::optional<APInt>() : result;
235   };
236   ConstantIntRanges urange = computeBoundsBy(
237       usub, lhs.umin(), rhs.umax(), lhs.umax(), rhs.umin(), /*isSigned=*/false);
238   ConstantIntRanges srange = computeBoundsBy(
239       ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin(), /*isSigned=*/true);
240   return urange.intersection(srange);
241 }
242 
243 //===----------------------------------------------------------------------===//
244 // Multiplication
245 //===----------------------------------------------------------------------===//
246 
247 ConstantIntRanges
248 mlir::intrange::inferMul(ArrayRef<ConstantIntRanges> argRanges,
249                          OverflowFlags ovfFlags) {
250   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
251 
252   ConstArithStdFn umul = [=](const APInt &a,
253                              const APInt &b) -> std::optional<APInt> {
254     bool overflowed = false;
255     APInt result = any(ovfFlags & OverflowFlags::Nuw)
256                        ? a.umul_sat(b)
257                        : a.umul_ov(b, overflowed);
258     return overflowed ? std::optional<APInt>() : result;
259   };
260   ConstArithStdFn smul = [=](const APInt &a,
261                              const APInt &b) -> std::optional<APInt> {
262     bool overflowed = false;
263     APInt result = any(ovfFlags & OverflowFlags::Nsw)
264                        ? a.smul_sat(b)
265                        : a.smul_ov(b, overflowed);
266     return overflowed ? std::optional<APInt>() : result;
267   };
268 
269   ConstantIntRanges urange =
270       minMaxBy(umul, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
271                /*isSigned=*/false);
272   ConstantIntRanges srange =
273       minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()},
274                /*isSigned=*/true);
275   return urange.intersection(srange);
276 }
277 
278 //===----------------------------------------------------------------------===//
279 // DivU, CeilDivU (Unsigned division)
280 //===----------------------------------------------------------------------===//
281 
282 /// Fix up division results (ex. for ceiling and floor), returning an APInt
283 /// if there has been no overflow
284 using DivisionFixupFn = function_ref<std::optional<APInt>(
285     const APInt &lhs, const APInt &rhs, const APInt &result)>;
286 
287 static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs,
288                                         const ConstantIntRanges &rhs,
289                                         DivisionFixupFn fixup) {
290   const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(), &rhsMin = rhs.umin(),
291               &rhsMax = rhs.umax();
292 
293   if (!rhsMin.isZero()) {
294     auto udiv = [&fixup](const APInt &a,
295                          const APInt &b) -> std::optional<APInt> {
296       return fixup(a, b, a.udiv(b));
297     };
298     return minMaxBy(udiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
299                     /*isSigned=*/false);
300   }
301 
302   APInt umin = APInt::getZero(rhsMin.getBitWidth());
303   if (lhsMin.uge(rhsMax) && !rhsMax.isZero())
304     umin = lhsMin.udiv(rhsMax);
305 
306   // X u/ Y u<= X.
307   APInt umax = lhsMax;
308   return ConstantIntRanges::fromUnsigned(umin, umax);
309 }
310 
311 ConstantIntRanges
312 mlir::intrange::inferDivU(ArrayRef<ConstantIntRanges> argRanges) {
313   return inferDivURange(argRanges[0], argRanges[1],
314                         [](const APInt &lhs, const APInt &rhs,
315                            const APInt &result) { return result; });
316 }
317 
318 ConstantIntRanges
319 mlir::intrange::inferCeilDivU(ArrayRef<ConstantIntRanges> argRanges) {
320   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
321 
322   auto ceilDivUIFix = [](const APInt &lhs, const APInt &rhs,
323                          const APInt &result) -> std::optional<APInt> {
324     if (!lhs.urem(rhs).isZero()) {
325       bool overflowed = false;
326       APInt corrected =
327           result.uadd_ov(APInt(result.getBitWidth(), 1), overflowed);
328       return overflowed ? std::optional<APInt>() : corrected;
329     }
330     return result;
331   };
332   return inferDivURange(lhs, rhs, ceilDivUIFix);
333 }
334 
335 //===----------------------------------------------------------------------===//
336 // DivS, CeilDivS, FloorDivS (Signed division)
337 //===----------------------------------------------------------------------===//
338 
339 static ConstantIntRanges inferDivSRange(const ConstantIntRanges &lhs,
340                                         const ConstantIntRanges &rhs,
341                                         DivisionFixupFn fixup) {
342   const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
343               &rhsMax = rhs.smax();
344   bool canDivide = rhsMin.isStrictlyPositive() || rhsMax.isNegative();
345 
346   if (canDivide) {
347     auto sdiv = [&fixup](const APInt &a,
348                          const APInt &b) -> std::optional<APInt> {
349       bool overflowed = false;
350       APInt result = a.sdiv_ov(b, overflowed);
351       return overflowed ? std::optional<APInt>() : fixup(a, b, result);
352     };
353     return minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
354                     /*isSigned=*/true);
355   }
356   return ConstantIntRanges::maxRange(rhsMin.getBitWidth());
357 }
358 
359 ConstantIntRanges
360 mlir::intrange::inferDivS(ArrayRef<ConstantIntRanges> argRanges) {
361   return inferDivSRange(argRanges[0], argRanges[1],
362                         [](const APInt &lhs, const APInt &rhs,
363                            const APInt &result) { return result; });
364 }
365 
366 ConstantIntRanges
367 mlir::intrange::inferCeilDivS(ArrayRef<ConstantIntRanges> argRanges) {
368   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
369 
370   auto ceilDivSIFix = [](const APInt &lhs, const APInt &rhs,
371                          const APInt &result) -> std::optional<APInt> {
372     if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() == rhs.isNonNegative()) {
373       bool overflowed = false;
374       APInt corrected =
375           result.sadd_ov(APInt(result.getBitWidth(), 1), overflowed);
376       return overflowed ? std::optional<APInt>() : corrected;
377     }
378     // Special case where the usual implementation of ceilDiv causes
379     // INT_MIN / [positive number] to be positive. This doesn't match the
380     // definition of signed ceiling division mathematically, but it prevents
381     // inconsistent constant-folding results. This arises because (-int_min) is
382     // still negative, so -(-int_min / b) is -(int_min / b), which is
383     // positive See #115293.
384     if (lhs.isMinSignedValue() && rhs.sgt(1)) {
385       return -result;
386     }
387     return result;
388   };
389   ConstantIntRanges result = inferDivSRange(lhs, rhs, ceilDivSIFix);
390   if (lhs.smin().isMinSignedValue() && lhs.smax().sgt(lhs.smin())) {
391     // If lhs range includes INT_MIN and lhs is not a single value, we can
392     // suddenly wrap to positive val, skipping entire negative range, add
393     // [INT_MIN + 1, smax()] range to the result to handle this.
394     auto newLhs = ConstantIntRanges::fromSigned(lhs.smin() + 1, lhs.smax());
395     result = result.rangeUnion(inferDivSRange(newLhs, rhs, ceilDivSIFix));
396   }
397   return result;
398 }
399 
400 ConstantIntRanges
401 mlir::intrange::inferFloorDivS(ArrayRef<ConstantIntRanges> argRanges) {
402   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
403 
404   auto floorDivSIFix = [](const APInt &lhs, const APInt &rhs,
405                           const APInt &result) -> std::optional<APInt> {
406     if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() != rhs.isNonNegative()) {
407       bool overflowed = false;
408       APInt corrected =
409           result.ssub_ov(APInt(result.getBitWidth(), 1), overflowed);
410       return overflowed ? std::optional<APInt>() : corrected;
411     }
412     return result;
413   };
414   return inferDivSRange(lhs, rhs, floorDivSIFix);
415 }
416 
417 //===----------------------------------------------------------------------===//
418 // Signed remainder (RemS)
419 //===----------------------------------------------------------------------===//
420 
421 ConstantIntRanges
422 mlir::intrange::inferRemS(ArrayRef<ConstantIntRanges> argRanges) {
423   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
424   const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
425               &rhsMax = rhs.smax();
426 
427   unsigned width = rhsMax.getBitWidth();
428   APInt smin = APInt::getSignedMinValue(width);
429   APInt smax = APInt::getSignedMaxValue(width);
430   // No bounds if zero could be a divisor.
431   bool canBound = (rhsMin.isStrictlyPositive() || rhsMax.isNegative());
432   if (canBound) {
433     APInt maxDivisor = rhsMin.isStrictlyPositive() ? rhsMax : rhsMin.abs();
434     bool canNegativeDividend = lhsMin.isNegative();
435     bool canPositiveDividend = lhsMax.isStrictlyPositive();
436     APInt zero = APInt::getZero(maxDivisor.getBitWidth());
437     APInt maxPositiveResult = maxDivisor - 1;
438     APInt minNegativeResult = -maxPositiveResult;
439     smin = canNegativeDividend ? minNegativeResult : zero;
440     smax = canPositiveDividend ? maxPositiveResult : zero;
441     // Special case: sweeping out a contiguous range in N/[modulus].
442     if (rhsMin == rhsMax) {
443       if ((lhsMax - lhsMin).ult(maxDivisor)) {
444         APInt minRem = lhsMin.srem(maxDivisor);
445         APInt maxRem = lhsMax.srem(maxDivisor);
446         if (minRem.sle(maxRem)) {
447           smin = minRem;
448           smax = maxRem;
449         }
450       }
451     }
452   }
453   return ConstantIntRanges::fromSigned(smin, smax);
454 }
455 
456 //===----------------------------------------------------------------------===//
457 // Unsigned remainder (RemU)
458 //===----------------------------------------------------------------------===//
459 
460 ConstantIntRanges
461 mlir::intrange::inferRemU(ArrayRef<ConstantIntRanges> argRanges) {
462   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
463   const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax();
464 
465   unsigned width = rhsMin.getBitWidth();
466   APInt umin = APInt::getZero(width);
467   // Remainder can't be larger than either of its arguments.
468   APInt umax = llvm::APIntOps::umin((rhsMax - 1), lhs.umax());
469 
470   if (!rhsMin.isZero()) {
471     // Special case: sweeping out a contiguous range in N/[modulus]
472     if (rhsMin == rhsMax) {
473       const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax();
474       if ((lhsMax - lhsMin).ult(rhsMax)) {
475         APInt minRem = lhsMin.urem(rhsMax);
476         APInt maxRem = lhsMax.urem(rhsMax);
477         if (minRem.ule(maxRem)) {
478           umin = minRem;
479           umax = maxRem;
480         }
481       }
482     }
483   }
484   return ConstantIntRanges::fromUnsigned(umin, umax);
485 }
486 
487 //===----------------------------------------------------------------------===//
488 // Max and min (MaxS, MaxU, MinS, MinU)
489 //===----------------------------------------------------------------------===//
490 
491 ConstantIntRanges
492 mlir::intrange::inferMaxS(ArrayRef<ConstantIntRanges> argRanges) {
493   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
494 
495   const APInt &smin = lhs.smin().sgt(rhs.smin()) ? lhs.smin() : rhs.smin();
496   const APInt &smax = lhs.smax().sgt(rhs.smax()) ? lhs.smax() : rhs.smax();
497   return ConstantIntRanges::fromSigned(smin, smax);
498 }
499 
500 ConstantIntRanges
501 mlir::intrange::inferMaxU(ArrayRef<ConstantIntRanges> argRanges) {
502   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
503 
504   const APInt &umin = lhs.umin().ugt(rhs.umin()) ? lhs.umin() : rhs.umin();
505   const APInt &umax = lhs.umax().ugt(rhs.umax()) ? lhs.umax() : rhs.umax();
506   return ConstantIntRanges::fromUnsigned(umin, umax);
507 }
508 
509 ConstantIntRanges
510 mlir::intrange::inferMinS(ArrayRef<ConstantIntRanges> argRanges) {
511   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
512 
513   const APInt &smin = lhs.smin().slt(rhs.smin()) ? lhs.smin() : rhs.smin();
514   const APInt &smax = lhs.smax().slt(rhs.smax()) ? lhs.smax() : rhs.smax();
515   return ConstantIntRanges::fromSigned(smin, smax);
516 }
517 
518 ConstantIntRanges
519 mlir::intrange::inferMinU(ArrayRef<ConstantIntRanges> argRanges) {
520   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
521 
522   const APInt &umin = lhs.umin().ult(rhs.umin()) ? lhs.umin() : rhs.umin();
523   const APInt &umax = lhs.umax().ult(rhs.umax()) ? lhs.umax() : rhs.umax();
524   return ConstantIntRanges::fromUnsigned(umin, umax);
525 }
526 
527 //===----------------------------------------------------------------------===//
528 // Bitwise operators (And, Or, Xor)
529 //===----------------------------------------------------------------------===//
530 
531 /// "Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???,
532 /// relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111, where vvvvv are the bits
533 /// that both bonuds have in common. This gives us a consertive approximation
534 /// for what values can be passed to bitwise operations.
535 static std::tuple<APInt, APInt>
536 widenBitwiseBounds(const ConstantIntRanges &bound) {
537   APInt leftVal = bound.umin(), rightVal = bound.umax();
538   unsigned bitwidth = leftVal.getBitWidth();
539   unsigned differingBits = bitwidth - (leftVal ^ rightVal).countl_zero();
540   leftVal.clearLowBits(differingBits);
541   rightVal.setLowBits(differingBits);
542   return std::make_tuple(std::move(leftVal), std::move(rightVal));
543 }
544 
545 ConstantIntRanges
546 mlir::intrange::inferAnd(ArrayRef<ConstantIntRanges> argRanges) {
547   auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]);
548   auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]);
549   auto andi = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
550     return a & b;
551   };
552   return minMaxBy(andi, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
553                   /*isSigned=*/false);
554 }
555 
556 ConstantIntRanges
557 mlir::intrange::inferOr(ArrayRef<ConstantIntRanges> argRanges) {
558   auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]);
559   auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]);
560   auto ori = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
561     return a | b;
562   };
563   return minMaxBy(ori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
564                   /*isSigned=*/false);
565 }
566 
567 /// Get bitmask of all bits which can change while iterating in
568 /// [bound.umin(), bound.umax()].
569 static APInt getVaryingBitsMask(const ConstantIntRanges &bound) {
570   APInt leftVal = bound.umin(), rightVal = bound.umax();
571   unsigned bitwidth = leftVal.getBitWidth();
572   unsigned differingBits = bitwidth - (leftVal ^ rightVal).countl_zero();
573   return APInt::getLowBitsSet(bitwidth, differingBits);
574 }
575 
576 ConstantIntRanges
577 mlir::intrange::inferXor(ArrayRef<ConstantIntRanges> argRanges) {
578   // Construct mask of varying bits for both ranges, xor values and then replace
579   // masked bits with 0s and 1s to get min and max values respectively.
580   ConstantIntRanges lhs = argRanges[0], rhs = argRanges[1];
581   APInt mask = getVaryingBitsMask(lhs) | getVaryingBitsMask(rhs);
582   APInt res = lhs.umin() ^ rhs.umin();
583   APInt min = res & ~mask;
584   APInt max = res | mask;
585   return ConstantIntRanges::fromUnsigned(min, max);
586 }
587 
588 //===----------------------------------------------------------------------===//
589 // Shifts (Shl, ShrS, ShrU)
590 //===----------------------------------------------------------------------===//
591 
592 ConstantIntRanges
593 mlir::intrange::inferShl(ArrayRef<ConstantIntRanges> argRanges,
594                          OverflowFlags ovfFlags) {
595   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
596   const APInt &rhsUMin = rhs.umin(), &rhsUMax = rhs.umax();
597 
598   // The signed/unsigned overflow behavior of shl by `rhs` matches a mul with
599   // 2^rhs.
600   ConstArithStdFn ushl = [=](const APInt &l,
601                              const APInt &r) -> std::optional<APInt> {
602     bool overflowed = false;
603     APInt result = any(ovfFlags & OverflowFlags::Nuw)
604                        ? l.ushl_sat(r)
605                        : l.ushl_ov(r, overflowed);
606     return overflowed ? std::optional<APInt>() : result;
607   };
608   ConstArithStdFn sshl = [=](const APInt &l,
609                              const APInt &r) -> std::optional<APInt> {
610     bool overflowed = false;
611     APInt result = any(ovfFlags & OverflowFlags::Nsw)
612                        ? l.sshl_sat(r)
613                        : l.sshl_ov(r, overflowed);
614     return overflowed ? std::optional<APInt>() : result;
615   };
616 
617   ConstantIntRanges urange =
618       minMaxBy(ushl, {lhs.umin(), lhs.umax()}, {rhsUMin, rhsUMax},
619                /*isSigned=*/false);
620   ConstantIntRanges srange =
621       minMaxBy(sshl, {lhs.smin(), lhs.smax()}, {rhsUMin, rhsUMax},
622                /*isSigned=*/true);
623   return urange.intersection(srange);
624 }
625 
626 ConstantIntRanges
627 mlir::intrange::inferShrS(ArrayRef<ConstantIntRanges> argRanges) {
628   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
629 
630   auto ashr = [](const APInt &l, const APInt &r) -> std::optional<APInt> {
631     return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.ashr(r);
632   };
633 
634   return minMaxBy(ashr, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()},
635                   /*isSigned=*/true);
636 }
637 
638 ConstantIntRanges
639 mlir::intrange::inferShrU(ArrayRef<ConstantIntRanges> argRanges) {
640   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
641 
642   auto lshr = [](const APInt &l, const APInt &r) -> std::optional<APInt> {
643     return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.lshr(r);
644   };
645   return minMaxBy(lshr, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
646                   /*isSigned=*/false);
647 }
648 
649 //===----------------------------------------------------------------------===//
650 // Comparisons (Cmp)
651 //===----------------------------------------------------------------------===//
652 
653 static intrange::CmpPredicate invertPredicate(intrange::CmpPredicate pred) {
654   switch (pred) {
655   case intrange::CmpPredicate::eq:
656     return intrange::CmpPredicate::ne;
657   case intrange::CmpPredicate::ne:
658     return intrange::CmpPredicate::eq;
659   case intrange::CmpPredicate::slt:
660     return intrange::CmpPredicate::sge;
661   case intrange::CmpPredicate::sle:
662     return intrange::CmpPredicate::sgt;
663   case intrange::CmpPredicate::sgt:
664     return intrange::CmpPredicate::sle;
665   case intrange::CmpPredicate::sge:
666     return intrange::CmpPredicate::slt;
667   case intrange::CmpPredicate::ult:
668     return intrange::CmpPredicate::uge;
669   case intrange::CmpPredicate::ule:
670     return intrange::CmpPredicate::ugt;
671   case intrange::CmpPredicate::ugt:
672     return intrange::CmpPredicate::ule;
673   case intrange::CmpPredicate::uge:
674     return intrange::CmpPredicate::ult;
675   }
676   llvm_unreachable("unknown cmp predicate value");
677 }
678 
679 static bool isStaticallyTrue(intrange::CmpPredicate pred,
680                              const ConstantIntRanges &lhs,
681                              const ConstantIntRanges &rhs) {
682   switch (pred) {
683   case intrange::CmpPredicate::sle:
684     return lhs.smax().sle(rhs.smin());
685   case intrange::CmpPredicate::slt:
686     return lhs.smax().slt(rhs.smin());
687   case intrange::CmpPredicate::ule:
688     return lhs.umax().ule(rhs.umin());
689   case intrange::CmpPredicate::ult:
690     return lhs.umax().ult(rhs.umin());
691   case intrange::CmpPredicate::sge:
692     return lhs.smin().sge(rhs.smax());
693   case intrange::CmpPredicate::sgt:
694     return lhs.smin().sgt(rhs.smax());
695   case intrange::CmpPredicate::uge:
696     return lhs.umin().uge(rhs.umax());
697   case intrange::CmpPredicate::ugt:
698     return lhs.umin().ugt(rhs.umax());
699   case intrange::CmpPredicate::eq: {
700     std::optional<APInt> lhsConst = lhs.getConstantValue();
701     std::optional<APInt> rhsConst = rhs.getConstantValue();
702     return lhsConst && rhsConst && lhsConst == rhsConst;
703   }
704   case intrange::CmpPredicate::ne: {
705     // While equality requires that there is an interpration of the preceeding
706     // computations that produces equal constants, whether that be signed or
707     // unsigned, statically determining inequality requires that neither
708     // interpretation produce potentially overlapping ranges.
709     bool sne = isStaticallyTrue(intrange::CmpPredicate::slt, lhs, rhs) ||
710                isStaticallyTrue(intrange::CmpPredicate::sgt, lhs, rhs);
711     bool une = isStaticallyTrue(intrange::CmpPredicate::ult, lhs, rhs) ||
712                isStaticallyTrue(intrange::CmpPredicate::ugt, lhs, rhs);
713     return sne && une;
714   }
715   }
716   return false;
717 }
718 
719 std::optional<bool> mlir::intrange::evaluatePred(CmpPredicate pred,
720                                                  const ConstantIntRanges &lhs,
721                                                  const ConstantIntRanges &rhs) {
722   if (isStaticallyTrue(pred, lhs, rhs))
723     return true;
724   if (isStaticallyTrue(invertPredicate(pred), lhs, rhs))
725     return false;
726   return std::nullopt;
727 }
728