1 //===- InferIntRangeCommon.cpp - Inference for common ops ------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains implementations of range inference for operations that are 10 // common to both the `arith` and `index` dialects to facilitate reuse. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Interfaces/Utils/InferIntRangeCommon.h" 15 16 #include "mlir/Interfaces/InferIntRangeInterface.h" 17 18 #include "llvm/ADT/ArrayRef.h" 19 #include "llvm/ADT/STLExtras.h" 20 21 #include "llvm/Support/Debug.h" 22 23 #include <iterator> 24 #include <optional> 25 26 using namespace mlir; 27 28 #define DEBUG_TYPE "int-range-analysis" 29 30 //===----------------------------------------------------------------------===// 31 // General utilities 32 //===----------------------------------------------------------------------===// 33 34 /// Function that evaluates the result of doing something on arithmetic 35 /// constants and returns std::nullopt on overflow. 36 using ConstArithFn = 37 function_ref<std::optional<APInt>(const APInt &, const APInt &)>; 38 39 /// Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible, 40 /// If either computation overflows, make the result unbounded. 41 static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft, 42 const APInt &minRight, 43 const APInt &maxLeft, 44 const APInt &maxRight, bool isSigned) { 45 std::optional<APInt> maybeMin = op(minLeft, minRight); 46 std::optional<APInt> maybeMax = op(maxLeft, maxRight); 47 if (maybeMin && maybeMax) 48 return ConstantIntRanges::range(*maybeMin, *maybeMax, isSigned); 49 return ConstantIntRanges::maxRange(minLeft.getBitWidth()); 50 } 51 52 /// Compute the minimum and maximum of `(op(l, r) for l in lhs for r in rhs)`, 53 /// ignoring unbounded values. Returns the maximal range if `op` overflows. 54 static ConstantIntRanges minMaxBy(ConstArithFn op, ArrayRef<APInt> lhs, 55 ArrayRef<APInt> rhs, bool isSigned) { 56 unsigned width = lhs[0].getBitWidth(); 57 APInt min = 58 isSigned ? APInt::getSignedMaxValue(width) : APInt::getMaxValue(width); 59 APInt max = 60 isSigned ? APInt::getSignedMinValue(width) : APInt::getZero(width); 61 for (const APInt &left : lhs) { 62 for (const APInt &right : rhs) { 63 std::optional<APInt> maybeThisResult = op(left, right); 64 if (!maybeThisResult) 65 return ConstantIntRanges::maxRange(width); 66 APInt result = std::move(*maybeThisResult); 67 min = (isSigned ? result.slt(min) : result.ult(min)) ? result : min; 68 max = (isSigned ? result.sgt(max) : result.ugt(max)) ? result : max; 69 } 70 } 71 return ConstantIntRanges::range(min, max, isSigned); 72 } 73 74 //===----------------------------------------------------------------------===// 75 // Ext, trunc, index op handling 76 //===----------------------------------------------------------------------===// 77 78 ConstantIntRanges 79 mlir::intrange::inferIndexOp(InferRangeFn inferFn, 80 ArrayRef<ConstantIntRanges> argRanges, 81 intrange::CmpMode mode) { 82 ConstantIntRanges sixtyFour = inferFn(argRanges); 83 SmallVector<ConstantIntRanges, 2> truncated; 84 llvm::transform(argRanges, std::back_inserter(truncated), 85 [](const ConstantIntRanges &range) { 86 return truncRange(range, /*destWidth=*/indexMinWidth); 87 }); 88 ConstantIntRanges thirtyTwo = inferFn(truncated); 89 ConstantIntRanges thirtyTwoAsSixtyFour = 90 extRange(thirtyTwo, /*destWidth=*/indexMaxWidth); 91 ConstantIntRanges sixtyFourAsThirtyTwo = 92 truncRange(sixtyFour, /*destWidth=*/indexMinWidth); 93 94 LLVM_DEBUG(llvm::dbgs() << "Index handling: 64-bit result = " << sixtyFour 95 << " 32-bit = " << thirtyTwo << "\n"); 96 bool truncEqual = false; 97 switch (mode) { 98 case intrange::CmpMode::Both: 99 truncEqual = (thirtyTwo == sixtyFourAsThirtyTwo); 100 break; 101 case intrange::CmpMode::Signed: 102 truncEqual = (thirtyTwo.smin() == sixtyFourAsThirtyTwo.smin() && 103 thirtyTwo.smax() == sixtyFourAsThirtyTwo.smax()); 104 break; 105 case intrange::CmpMode::Unsigned: 106 truncEqual = (thirtyTwo.umin() == sixtyFourAsThirtyTwo.umin() && 107 thirtyTwo.umax() == sixtyFourAsThirtyTwo.umax()); 108 break; 109 } 110 if (truncEqual) 111 // Returing the 64-bit result preserves more information. 112 return sixtyFour; 113 ConstantIntRanges merged = sixtyFour.rangeUnion(thirtyTwoAsSixtyFour); 114 return merged; 115 } 116 117 ConstantIntRanges mlir::intrange::extRange(const ConstantIntRanges &range, 118 unsigned int destWidth) { 119 APInt umin = range.umin().zext(destWidth); 120 APInt umax = range.umax().zext(destWidth); 121 APInt smin = range.smin().sext(destWidth); 122 APInt smax = range.smax().sext(destWidth); 123 return {umin, umax, smin, smax}; 124 } 125 126 ConstantIntRanges mlir::intrange::extUIRange(const ConstantIntRanges &range, 127 unsigned destWidth) { 128 APInt umin = range.umin().zext(destWidth); 129 APInt umax = range.umax().zext(destWidth); 130 return ConstantIntRanges::fromUnsigned(umin, umax); 131 } 132 133 ConstantIntRanges mlir::intrange::extSIRange(const ConstantIntRanges &range, 134 unsigned destWidth) { 135 APInt smin = range.smin().sext(destWidth); 136 APInt smax = range.smax().sext(destWidth); 137 return ConstantIntRanges::fromSigned(smin, smax); 138 } 139 140 ConstantIntRanges mlir::intrange::truncRange(const ConstantIntRanges &range, 141 unsigned int destWidth) { 142 // If you truncate the first four bytes in [0xaaaabbbb, 0xccccbbbb], 143 // the range of the resulting value is not contiguous ind includes 0. 144 // Ex. If you truncate [256, 258] from i16 to i8, you validly get [0, 2], 145 // but you can't truncate [255, 257] similarly. 146 bool hasUnsignedRollover = 147 range.umin().lshr(destWidth) != range.umax().lshr(destWidth); 148 APInt umin = hasUnsignedRollover ? APInt::getZero(destWidth) 149 : range.umin().trunc(destWidth); 150 APInt umax = hasUnsignedRollover ? APInt::getMaxValue(destWidth) 151 : range.umax().trunc(destWidth); 152 153 // Signed post-truncation rollover will not occur when either: 154 // - The high parts of the min and max, plus the sign bit, are the same 155 // - The high halves + sign bit of the min and max are either all 1s or all 0s 156 // and you won't create a [positive, negative] range by truncating. 157 // For example, you can truncate the ranges [256, 258]_i16 to [0, 2]_i8 158 // but not [255, 257]_i16 to a range of i8s. You can also truncate 159 // [-256, -256]_i16 to [-2, 0]_i8, but not [-257, -255]_i16. 160 // You can also truncate [-130, 0]_i16 to i8 because -130_i16 (0xff7e) 161 // will truncate to 0x7e, which is greater than 0 162 APInt sminHighPart = range.smin().ashr(destWidth - 1); 163 APInt smaxHighPart = range.smax().ashr(destWidth - 1); 164 bool hasSignedOverflow = 165 (sminHighPart != smaxHighPart) && 166 !(sminHighPart.isAllOnes() && 167 (smaxHighPart.isAllOnes() || smaxHighPart.isZero())) && 168 !(sminHighPart.isZero() && smaxHighPart.isZero()); 169 APInt smin = hasSignedOverflow ? APInt::getSignedMinValue(destWidth) 170 : range.smin().trunc(destWidth); 171 APInt smax = hasSignedOverflow ? APInt::getSignedMaxValue(destWidth) 172 : range.smax().trunc(destWidth); 173 return {umin, umax, smin, smax}; 174 } 175 176 //===----------------------------------------------------------------------===// 177 // Addition 178 //===----------------------------------------------------------------------===// 179 180 ConstantIntRanges 181 mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges) { 182 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 183 ConstArithFn uadd = [](const APInt &a, 184 const APInt &b) -> std::optional<APInt> { 185 bool overflowed = false; 186 APInt result = a.uadd_ov(b, overflowed); 187 return overflowed ? std::optional<APInt>() : result; 188 }; 189 ConstArithFn sadd = [](const APInt &a, 190 const APInt &b) -> std::optional<APInt> { 191 bool overflowed = false; 192 APInt result = a.sadd_ov(b, overflowed); 193 return overflowed ? std::optional<APInt>() : result; 194 }; 195 196 ConstantIntRanges urange = computeBoundsBy( 197 uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax(), /*isSigned=*/false); 198 ConstantIntRanges srange = computeBoundsBy( 199 sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax(), /*isSigned=*/true); 200 return urange.intersection(srange); 201 } 202 203 //===----------------------------------------------------------------------===// 204 // Subtraction 205 //===----------------------------------------------------------------------===// 206 207 ConstantIntRanges 208 mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges) { 209 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 210 211 ConstArithFn usub = [](const APInt &a, 212 const APInt &b) -> std::optional<APInt> { 213 bool overflowed = false; 214 APInt result = a.usub_ov(b, overflowed); 215 return overflowed ? std::optional<APInt>() : result; 216 }; 217 ConstArithFn ssub = [](const APInt &a, 218 const APInt &b) -> std::optional<APInt> { 219 bool overflowed = false; 220 APInt result = a.ssub_ov(b, overflowed); 221 return overflowed ? std::optional<APInt>() : result; 222 }; 223 ConstantIntRanges urange = computeBoundsBy( 224 usub, lhs.umin(), rhs.umax(), lhs.umax(), rhs.umin(), /*isSigned=*/false); 225 ConstantIntRanges srange = computeBoundsBy( 226 ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin(), /*isSigned=*/true); 227 return urange.intersection(srange); 228 } 229 230 //===----------------------------------------------------------------------===// 231 // Multiplication 232 //===----------------------------------------------------------------------===// 233 234 ConstantIntRanges 235 mlir::intrange::inferMul(ArrayRef<ConstantIntRanges> argRanges) { 236 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 237 238 ConstArithFn umul = [](const APInt &a, 239 const APInt &b) -> std::optional<APInt> { 240 bool overflowed = false; 241 APInt result = a.umul_ov(b, overflowed); 242 return overflowed ? std::optional<APInt>() : result; 243 }; 244 ConstArithFn smul = [](const APInt &a, 245 const APInt &b) -> std::optional<APInt> { 246 bool overflowed = false; 247 APInt result = a.smul_ov(b, overflowed); 248 return overflowed ? std::optional<APInt>() : result; 249 }; 250 251 ConstantIntRanges urange = 252 minMaxBy(umul, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()}, 253 /*isSigned=*/false); 254 ConstantIntRanges srange = 255 minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()}, 256 /*isSigned=*/true); 257 return urange.intersection(srange); 258 } 259 260 //===----------------------------------------------------------------------===// 261 // DivU, CeilDivU (Unsigned division) 262 //===----------------------------------------------------------------------===// 263 264 /// Fix up division results (ex. for ceiling and floor), returning an APInt 265 /// if there has been no overflow 266 using DivisionFixupFn = function_ref<std::optional<APInt>( 267 const APInt &lhs, const APInt &rhs, const APInt &result)>; 268 269 static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs, 270 const ConstantIntRanges &rhs, 271 DivisionFixupFn fixup) { 272 const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(), &rhsMin = rhs.umin(), 273 &rhsMax = rhs.umax(); 274 275 if (!rhsMin.isZero()) { 276 auto udiv = [&fixup](const APInt &a, 277 const APInt &b) -> std::optional<APInt> { 278 return fixup(a, b, a.udiv(b)); 279 }; 280 return minMaxBy(udiv, {lhsMin, lhsMax}, {rhsMin, rhsMax}, 281 /*isSigned=*/false); 282 } 283 // Otherwise, it's possible we might divide by 0. 284 return ConstantIntRanges::maxRange(rhsMin.getBitWidth()); 285 } 286 287 ConstantIntRanges 288 mlir::intrange::inferDivU(ArrayRef<ConstantIntRanges> argRanges) { 289 return inferDivURange(argRanges[0], argRanges[1], 290 [](const APInt &lhs, const APInt &rhs, 291 const APInt &result) { return result; }); 292 } 293 294 ConstantIntRanges 295 mlir::intrange::inferCeilDivU(ArrayRef<ConstantIntRanges> argRanges) { 296 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 297 298 DivisionFixupFn ceilDivUIFix = 299 [](const APInt &lhs, const APInt &rhs, 300 const APInt &result) -> std::optional<APInt> { 301 if (!lhs.urem(rhs).isZero()) { 302 bool overflowed = false; 303 APInt corrected = 304 result.uadd_ov(APInt(result.getBitWidth(), 1), overflowed); 305 return overflowed ? std::optional<APInt>() : corrected; 306 } 307 return result; 308 }; 309 return inferDivURange(lhs, rhs, ceilDivUIFix); 310 } 311 312 //===----------------------------------------------------------------------===// 313 // DivS, CeilDivS, FloorDivS (Signed division) 314 //===----------------------------------------------------------------------===// 315 316 static ConstantIntRanges inferDivSRange(const ConstantIntRanges &lhs, 317 const ConstantIntRanges &rhs, 318 DivisionFixupFn fixup) { 319 const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(), 320 &rhsMax = rhs.smax(); 321 bool canDivide = rhsMin.isStrictlyPositive() || rhsMax.isNegative(); 322 323 if (canDivide) { 324 auto sdiv = [&fixup](const APInt &a, 325 const APInt &b) -> std::optional<APInt> { 326 bool overflowed = false; 327 APInt result = a.sdiv_ov(b, overflowed); 328 return overflowed ? std::optional<APInt>() : fixup(a, b, result); 329 }; 330 return minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax}, 331 /*isSigned=*/true); 332 } 333 return ConstantIntRanges::maxRange(rhsMin.getBitWidth()); 334 } 335 336 ConstantIntRanges 337 mlir::intrange::inferDivS(ArrayRef<ConstantIntRanges> argRanges) { 338 return inferDivSRange(argRanges[0], argRanges[1], 339 [](const APInt &lhs, const APInt &rhs, 340 const APInt &result) { return result; }); 341 } 342 343 ConstantIntRanges 344 mlir::intrange::inferCeilDivS(ArrayRef<ConstantIntRanges> argRanges) { 345 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 346 347 DivisionFixupFn ceilDivSIFix = 348 [](const APInt &lhs, const APInt &rhs, 349 const APInt &result) -> std::optional<APInt> { 350 if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() == rhs.isNonNegative()) { 351 bool overflowed = false; 352 APInt corrected = 353 result.sadd_ov(APInt(result.getBitWidth(), 1), overflowed); 354 return overflowed ? std::optional<APInt>() : corrected; 355 } 356 return result; 357 }; 358 return inferDivSRange(lhs, rhs, ceilDivSIFix); 359 } 360 361 ConstantIntRanges 362 mlir::intrange::inferFloorDivS(ArrayRef<ConstantIntRanges> argRanges) { 363 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 364 365 DivisionFixupFn floorDivSIFix = 366 [](const APInt &lhs, const APInt &rhs, 367 const APInt &result) -> std::optional<APInt> { 368 if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() != rhs.isNonNegative()) { 369 bool overflowed = false; 370 APInt corrected = 371 result.ssub_ov(APInt(result.getBitWidth(), 1), overflowed); 372 return overflowed ? std::optional<APInt>() : corrected; 373 } 374 return result; 375 }; 376 return inferDivSRange(lhs, rhs, floorDivSIFix); 377 } 378 379 //===----------------------------------------------------------------------===// 380 // Signed remainder (RemS) 381 //===----------------------------------------------------------------------===// 382 383 ConstantIntRanges 384 mlir::intrange::inferRemS(ArrayRef<ConstantIntRanges> argRanges) { 385 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 386 const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(), 387 &rhsMax = rhs.smax(); 388 389 unsigned width = rhsMax.getBitWidth(); 390 APInt smin = APInt::getSignedMinValue(width); 391 APInt smax = APInt::getSignedMaxValue(width); 392 // No bounds if zero could be a divisor. 393 bool canBound = (rhsMin.isStrictlyPositive() || rhsMax.isNegative()); 394 if (canBound) { 395 APInt maxDivisor = rhsMin.isStrictlyPositive() ? rhsMax : rhsMin.abs(); 396 bool canNegativeDividend = lhsMin.isNegative(); 397 bool canPositiveDividend = lhsMax.isStrictlyPositive(); 398 APInt zero = APInt::getZero(maxDivisor.getBitWidth()); 399 APInt maxPositiveResult = maxDivisor - 1; 400 APInt minNegativeResult = -maxPositiveResult; 401 smin = canNegativeDividend ? minNegativeResult : zero; 402 smax = canPositiveDividend ? maxPositiveResult : zero; 403 // Special case: sweeping out a contiguous range in N/[modulus]. 404 if (rhsMin == rhsMax) { 405 if ((lhsMax - lhsMin).ult(maxDivisor)) { 406 APInt minRem = lhsMin.srem(maxDivisor); 407 APInt maxRem = lhsMax.srem(maxDivisor); 408 if (minRem.sle(maxRem)) { 409 smin = minRem; 410 smax = maxRem; 411 } 412 } 413 } 414 } 415 return ConstantIntRanges::fromSigned(smin, smax); 416 } 417 418 //===----------------------------------------------------------------------===// 419 // Unsigned remainder (RemU) 420 //===----------------------------------------------------------------------===// 421 422 ConstantIntRanges 423 mlir::intrange::inferRemU(ArrayRef<ConstantIntRanges> argRanges) { 424 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 425 const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax(); 426 427 unsigned width = rhsMin.getBitWidth(); 428 APInt umin = APInt::getZero(width); 429 APInt umax = APInt::getMaxValue(width); 430 431 if (!rhsMin.isZero()) { 432 umax = rhsMax - 1; 433 // Special case: sweeping out a contiguous range in N/[modulus] 434 if (rhsMin == rhsMax) { 435 const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(); 436 if ((lhsMax - lhsMin).ult(rhsMax)) { 437 APInt minRem = lhsMin.urem(rhsMax); 438 APInt maxRem = lhsMax.urem(rhsMax); 439 if (minRem.ule(maxRem)) { 440 umin = minRem; 441 umax = maxRem; 442 } 443 } 444 } 445 } 446 return ConstantIntRanges::fromUnsigned(umin, umax); 447 } 448 449 //===----------------------------------------------------------------------===// 450 // Max and min (MaxS, MaxU, MinS, MinU) 451 //===----------------------------------------------------------------------===// 452 453 ConstantIntRanges 454 mlir::intrange::inferMaxS(ArrayRef<ConstantIntRanges> argRanges) { 455 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 456 457 const APInt &smin = lhs.smin().sgt(rhs.smin()) ? lhs.smin() : rhs.smin(); 458 const APInt &smax = lhs.smax().sgt(rhs.smax()) ? lhs.smax() : rhs.smax(); 459 return ConstantIntRanges::fromSigned(smin, smax); 460 } 461 462 ConstantIntRanges 463 mlir::intrange::inferMaxU(ArrayRef<ConstantIntRanges> argRanges) { 464 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 465 466 const APInt &umin = lhs.umin().ugt(rhs.umin()) ? lhs.umin() : rhs.umin(); 467 const APInt &umax = lhs.umax().ugt(rhs.umax()) ? lhs.umax() : rhs.umax(); 468 return ConstantIntRanges::fromUnsigned(umin, umax); 469 } 470 471 ConstantIntRanges 472 mlir::intrange::inferMinS(ArrayRef<ConstantIntRanges> argRanges) { 473 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 474 475 const APInt &smin = lhs.smin().slt(rhs.smin()) ? lhs.smin() : rhs.smin(); 476 const APInt &smax = lhs.smax().slt(rhs.smax()) ? lhs.smax() : rhs.smax(); 477 return ConstantIntRanges::fromSigned(smin, smax); 478 } 479 480 ConstantIntRanges 481 mlir::intrange::inferMinU(ArrayRef<ConstantIntRanges> argRanges) { 482 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 483 484 const APInt &umin = lhs.umin().ult(rhs.umin()) ? lhs.umin() : rhs.umin(); 485 const APInt &umax = lhs.umax().ult(rhs.umax()) ? lhs.umax() : rhs.umax(); 486 return ConstantIntRanges::fromUnsigned(umin, umax); 487 } 488 489 //===----------------------------------------------------------------------===// 490 // Bitwise operators (And, Or, Xor) 491 //===----------------------------------------------------------------------===// 492 493 /// "Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???, 494 /// relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111, where vvvvv are the bits 495 /// that both bonuds have in common. This gives us a consertive approximation 496 /// for what values can be passed to bitwise operations. 497 static std::tuple<APInt, APInt> 498 widenBitwiseBounds(const ConstantIntRanges &bound) { 499 APInt leftVal = bound.umin(), rightVal = bound.umax(); 500 unsigned bitwidth = leftVal.getBitWidth(); 501 unsigned differingBits = bitwidth - (leftVal ^ rightVal).countLeadingZeros(); 502 leftVal.clearLowBits(differingBits); 503 rightVal.setLowBits(differingBits); 504 return std::make_tuple(std::move(leftVal), std::move(rightVal)); 505 } 506 507 ConstantIntRanges 508 mlir::intrange::inferAnd(ArrayRef<ConstantIntRanges> argRanges) { 509 auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]); 510 auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]); 511 auto andi = [](const APInt &a, const APInt &b) -> std::optional<APInt> { 512 return a & b; 513 }; 514 return minMaxBy(andi, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, 515 /*isSigned=*/false); 516 } 517 518 ConstantIntRanges 519 mlir::intrange::inferOr(ArrayRef<ConstantIntRanges> argRanges) { 520 auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]); 521 auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]); 522 auto ori = [](const APInt &a, const APInt &b) -> std::optional<APInt> { 523 return a | b; 524 }; 525 return minMaxBy(ori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, 526 /*isSigned=*/false); 527 } 528 529 ConstantIntRanges 530 mlir::intrange::inferXor(ArrayRef<ConstantIntRanges> argRanges) { 531 auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]); 532 auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]); 533 auto xori = [](const APInt &a, const APInt &b) -> std::optional<APInt> { 534 return a ^ b; 535 }; 536 return minMaxBy(xori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, 537 /*isSigned=*/false); 538 } 539 540 //===----------------------------------------------------------------------===// 541 // Shifts (Shl, ShrS, ShrU) 542 //===----------------------------------------------------------------------===// 543 544 ConstantIntRanges 545 mlir::intrange::inferShl(ArrayRef<ConstantIntRanges> argRanges) { 546 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 547 ConstArithFn shl = [](const APInt &l, 548 const APInt &r) -> std::optional<APInt> { 549 return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.shl(r); 550 }; 551 ConstantIntRanges urange = 552 minMaxBy(shl, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()}, 553 /*isSigned=*/false); 554 ConstantIntRanges srange = 555 minMaxBy(shl, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()}, 556 /*isSigned=*/true); 557 return urange.intersection(srange); 558 } 559 560 ConstantIntRanges 561 mlir::intrange::inferShrS(ArrayRef<ConstantIntRanges> argRanges) { 562 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 563 564 ConstArithFn ashr = [](const APInt &l, 565 const APInt &r) -> std::optional<APInt> { 566 return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.ashr(r); 567 }; 568 569 return minMaxBy(ashr, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()}, 570 /*isSigned=*/true); 571 } 572 573 ConstantIntRanges 574 mlir::intrange::inferShrU(ArrayRef<ConstantIntRanges> argRanges) { 575 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 576 577 ConstArithFn lshr = [](const APInt &l, 578 const APInt &r) -> std::optional<APInt> { 579 return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.lshr(r); 580 }; 581 return minMaxBy(lshr, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()}, 582 /*isSigned=*/false); 583 } 584 585 //===----------------------------------------------------------------------===// 586 // Comparisons (Cmp) 587 //===----------------------------------------------------------------------===// 588 589 static intrange::CmpPredicate invertPredicate(intrange::CmpPredicate pred) { 590 switch (pred) { 591 case intrange::CmpPredicate::eq: 592 return intrange::CmpPredicate::ne; 593 case intrange::CmpPredicate::ne: 594 return intrange::CmpPredicate::eq; 595 case intrange::CmpPredicate::slt: 596 return intrange::CmpPredicate::sge; 597 case intrange::CmpPredicate::sle: 598 return intrange::CmpPredicate::sgt; 599 case intrange::CmpPredicate::sgt: 600 return intrange::CmpPredicate::sle; 601 case intrange::CmpPredicate::sge: 602 return intrange::CmpPredicate::slt; 603 case intrange::CmpPredicate::ult: 604 return intrange::CmpPredicate::uge; 605 case intrange::CmpPredicate::ule: 606 return intrange::CmpPredicate::ugt; 607 case intrange::CmpPredicate::ugt: 608 return intrange::CmpPredicate::ule; 609 case intrange::CmpPredicate::uge: 610 return intrange::CmpPredicate::ult; 611 } 612 llvm_unreachable("unknown cmp predicate value"); 613 } 614 615 static bool isStaticallyTrue(intrange::CmpPredicate pred, 616 const ConstantIntRanges &lhs, 617 const ConstantIntRanges &rhs) { 618 switch (pred) { 619 case intrange::CmpPredicate::sle: 620 return lhs.smax().sle(rhs.smin()); 621 case intrange::CmpPredicate::slt: 622 return lhs.smax().slt(rhs.smin()); 623 case intrange::CmpPredicate::ule: 624 return lhs.umax().ule(rhs.umin()); 625 case intrange::CmpPredicate::ult: 626 return lhs.umax().ult(rhs.umin()); 627 case intrange::CmpPredicate::sge: 628 return lhs.smin().sge(rhs.smax()); 629 case intrange::CmpPredicate::sgt: 630 return lhs.smin().sgt(rhs.smax()); 631 case intrange::CmpPredicate::uge: 632 return lhs.umin().uge(rhs.umax()); 633 case intrange::CmpPredicate::ugt: 634 return lhs.umin().ugt(rhs.umax()); 635 case intrange::CmpPredicate::eq: { 636 std::optional<APInt> lhsConst = lhs.getConstantValue(); 637 std::optional<APInt> rhsConst = rhs.getConstantValue(); 638 return lhsConst && rhsConst && lhsConst == rhsConst; 639 } 640 case intrange::CmpPredicate::ne: { 641 // While equality requires that there is an interpration of the preceeding 642 // computations that produces equal constants, whether that be signed or 643 // unsigned, statically determining inequality requires that neither 644 // interpretation produce potentially overlapping ranges. 645 bool sne = isStaticallyTrue(intrange::CmpPredicate::slt, lhs, rhs) || 646 isStaticallyTrue(intrange::CmpPredicate::sgt, lhs, rhs); 647 bool une = isStaticallyTrue(intrange::CmpPredicate::ult, lhs, rhs) || 648 isStaticallyTrue(intrange::CmpPredicate::ugt, lhs, rhs); 649 return sne && une; 650 } 651 } 652 return false; 653 } 654 655 std::optional<bool> mlir::intrange::evaluatePred(CmpPredicate pred, 656 const ConstantIntRanges &lhs, 657 const ConstantIntRanges &rhs) { 658 if (isStaticallyTrue(pred, lhs, rhs)) 659 return true; 660 if (isStaticallyTrue(invertPredicate(pred), lhs, rhs)) 661 return false; 662 return std::nullopt; 663 } 664