1 //===- InferIntRangeCommon.cpp - Inference for common ops ------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains implementations of range inference for operations that are 10 // common to both the `arith` and `index` dialects to facilitate reuse. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Interfaces/Utils/InferIntRangeCommon.h" 15 16 #include "mlir/Interfaces/InferIntRangeInterface.h" 17 18 #include "llvm/ADT/ArrayRef.h" 19 #include "llvm/ADT/STLExtras.h" 20 21 #include "llvm/Support/Debug.h" 22 23 #include <iterator> 24 #include <optional> 25 26 using namespace mlir; 27 28 #define DEBUG_TYPE "int-range-analysis" 29 30 //===----------------------------------------------------------------------===// 31 // General utilities 32 //===----------------------------------------------------------------------===// 33 34 /// Function that evaluates the result of doing something on arithmetic 35 /// constants and returns std::nullopt on overflow. 36 using ConstArithFn = 37 function_ref<std::optional<APInt>(const APInt &, const APInt &)>; 38 using ConstArithStdFn = 39 std::function<std::optional<APInt>(const APInt &, const APInt &)>; 40 41 /// Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible, 42 /// If either computation overflows, make the result unbounded. 43 static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft, 44 const APInt &minRight, 45 const APInt &maxLeft, 46 const APInt &maxRight, bool isSigned) { 47 std::optional<APInt> maybeMin = op(minLeft, minRight); 48 std::optional<APInt> maybeMax = op(maxLeft, maxRight); 49 if (maybeMin && maybeMax) 50 return ConstantIntRanges::range(*maybeMin, *maybeMax, isSigned); 51 return ConstantIntRanges::maxRange(minLeft.getBitWidth()); 52 } 53 54 /// Compute the minimum and maximum of `(op(l, r) for l in lhs for r in rhs)`, 55 /// ignoring unbounded values. Returns the maximal range if `op` overflows. 56 static ConstantIntRanges minMaxBy(ConstArithFn op, ArrayRef<APInt> lhs, 57 ArrayRef<APInt> rhs, bool isSigned) { 58 unsigned width = lhs[0].getBitWidth(); 59 APInt min = 60 isSigned ? APInt::getSignedMaxValue(width) : APInt::getMaxValue(width); 61 APInt max = 62 isSigned ? APInt::getSignedMinValue(width) : APInt::getZero(width); 63 for (const APInt &left : lhs) { 64 for (const APInt &right : rhs) { 65 std::optional<APInt> maybeThisResult = op(left, right); 66 if (!maybeThisResult) 67 return ConstantIntRanges::maxRange(width); 68 APInt result = std::move(*maybeThisResult); 69 min = (isSigned ? result.slt(min) : result.ult(min)) ? result : min; 70 max = (isSigned ? result.sgt(max) : result.ugt(max)) ? result : max; 71 } 72 } 73 return ConstantIntRanges::range(min, max, isSigned); 74 } 75 76 //===----------------------------------------------------------------------===// 77 // Ext, trunc, index op handling 78 //===----------------------------------------------------------------------===// 79 80 ConstantIntRanges 81 mlir::intrange::inferIndexOp(const InferRangeFn &inferFn, 82 ArrayRef<ConstantIntRanges> argRanges, 83 intrange::CmpMode mode) { 84 ConstantIntRanges sixtyFour = inferFn(argRanges); 85 SmallVector<ConstantIntRanges, 2> truncated; 86 llvm::transform(argRanges, std::back_inserter(truncated), 87 [](const ConstantIntRanges &range) { 88 return truncRange(range, /*destWidth=*/indexMinWidth); 89 }); 90 ConstantIntRanges thirtyTwo = inferFn(truncated); 91 ConstantIntRanges thirtyTwoAsSixtyFour = 92 extRange(thirtyTwo, /*destWidth=*/indexMaxWidth); 93 ConstantIntRanges sixtyFourAsThirtyTwo = 94 truncRange(sixtyFour, /*destWidth=*/indexMinWidth); 95 96 LLVM_DEBUG(llvm::dbgs() << "Index handling: 64-bit result = " << sixtyFour 97 << " 32-bit = " << thirtyTwo << "\n"); 98 bool truncEqual = false; 99 switch (mode) { 100 case intrange::CmpMode::Both: 101 truncEqual = (thirtyTwo == sixtyFourAsThirtyTwo); 102 break; 103 case intrange::CmpMode::Signed: 104 truncEqual = (thirtyTwo.smin() == sixtyFourAsThirtyTwo.smin() && 105 thirtyTwo.smax() == sixtyFourAsThirtyTwo.smax()); 106 break; 107 case intrange::CmpMode::Unsigned: 108 truncEqual = (thirtyTwo.umin() == sixtyFourAsThirtyTwo.umin() && 109 thirtyTwo.umax() == sixtyFourAsThirtyTwo.umax()); 110 break; 111 } 112 if (truncEqual) 113 // Returing the 64-bit result preserves more information. 114 return sixtyFour; 115 ConstantIntRanges merged = sixtyFour.rangeUnion(thirtyTwoAsSixtyFour); 116 return merged; 117 } 118 119 ConstantIntRanges mlir::intrange::extRange(const ConstantIntRanges &range, 120 unsigned int destWidth) { 121 APInt umin = range.umin().zext(destWidth); 122 APInt umax = range.umax().zext(destWidth); 123 APInt smin = range.smin().sext(destWidth); 124 APInt smax = range.smax().sext(destWidth); 125 return {umin, umax, smin, smax}; 126 } 127 128 ConstantIntRanges mlir::intrange::extUIRange(const ConstantIntRanges &range, 129 unsigned destWidth) { 130 APInt umin = range.umin().zext(destWidth); 131 APInt umax = range.umax().zext(destWidth); 132 return ConstantIntRanges::fromUnsigned(umin, umax); 133 } 134 135 ConstantIntRanges mlir::intrange::extSIRange(const ConstantIntRanges &range, 136 unsigned destWidth) { 137 APInt smin = range.smin().sext(destWidth); 138 APInt smax = range.smax().sext(destWidth); 139 return ConstantIntRanges::fromSigned(smin, smax); 140 } 141 142 ConstantIntRanges mlir::intrange::truncRange(const ConstantIntRanges &range, 143 unsigned int destWidth) { 144 // If you truncate the first four bytes in [0xaaaabbbb, 0xccccbbbb], 145 // the range of the resulting value is not contiguous ind includes 0. 146 // Ex. If you truncate [256, 258] from i16 to i8, you validly get [0, 2], 147 // but you can't truncate [255, 257] similarly. 148 bool hasUnsignedRollover = 149 range.umin().lshr(destWidth) != range.umax().lshr(destWidth); 150 APInt umin = hasUnsignedRollover ? APInt::getZero(destWidth) 151 : range.umin().trunc(destWidth); 152 APInt umax = hasUnsignedRollover ? APInt::getMaxValue(destWidth) 153 : range.umax().trunc(destWidth); 154 155 // Signed post-truncation rollover will not occur when either: 156 // - The high parts of the min and max, plus the sign bit, are the same 157 // - The high halves + sign bit of the min and max are either all 1s or all 0s 158 // and you won't create a [positive, negative] range by truncating. 159 // For example, you can truncate the ranges [256, 258]_i16 to [0, 2]_i8 160 // but not [255, 257]_i16 to a range of i8s. You can also truncate 161 // [-256, -256]_i16 to [-2, 0]_i8, but not [-257, -255]_i16. 162 // You can also truncate [-130, 0]_i16 to i8 because -130_i16 (0xff7e) 163 // will truncate to 0x7e, which is greater than 0 164 APInt sminHighPart = range.smin().ashr(destWidth - 1); 165 APInt smaxHighPart = range.smax().ashr(destWidth - 1); 166 bool hasSignedOverflow = 167 (sminHighPart != smaxHighPart) && 168 !(sminHighPart.isAllOnes() && 169 (smaxHighPart.isAllOnes() || smaxHighPart.isZero())) && 170 !(sminHighPart.isZero() && smaxHighPart.isZero()); 171 APInt smin = hasSignedOverflow ? APInt::getSignedMinValue(destWidth) 172 : range.smin().trunc(destWidth); 173 APInt smax = hasSignedOverflow ? APInt::getSignedMaxValue(destWidth) 174 : range.smax().trunc(destWidth); 175 return {umin, umax, smin, smax}; 176 } 177 178 //===----------------------------------------------------------------------===// 179 // Addition 180 //===----------------------------------------------------------------------===// 181 182 ConstantIntRanges 183 mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges, 184 OverflowFlags ovfFlags) { 185 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 186 187 ConstArithStdFn uadd = [=](const APInt &a, 188 const APInt &b) -> std::optional<APInt> { 189 bool overflowed = false; 190 APInt result = any(ovfFlags & OverflowFlags::Nuw) 191 ? a.uadd_sat(b) 192 : a.uadd_ov(b, overflowed); 193 return overflowed ? std::optional<APInt>() : result; 194 }; 195 ConstArithStdFn sadd = [=](const APInt &a, 196 const APInt &b) -> std::optional<APInt> { 197 bool overflowed = false; 198 APInt result = any(ovfFlags & OverflowFlags::Nsw) 199 ? a.sadd_sat(b) 200 : a.sadd_ov(b, overflowed); 201 return overflowed ? std::optional<APInt>() : result; 202 }; 203 204 ConstantIntRanges urange = computeBoundsBy( 205 uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax(), /*isSigned=*/false); 206 ConstantIntRanges srange = computeBoundsBy( 207 sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax(), /*isSigned=*/true); 208 return urange.intersection(srange); 209 } 210 211 //===----------------------------------------------------------------------===// 212 // Subtraction 213 //===----------------------------------------------------------------------===// 214 215 ConstantIntRanges 216 mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges, 217 OverflowFlags ovfFlags) { 218 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 219 220 ConstArithStdFn usub = [=](const APInt &a, 221 const APInt &b) -> std::optional<APInt> { 222 bool overflowed = false; 223 APInt result = any(ovfFlags & OverflowFlags::Nuw) 224 ? a.usub_sat(b) 225 : a.usub_ov(b, overflowed); 226 return overflowed ? std::optional<APInt>() : result; 227 }; 228 ConstArithStdFn ssub = [=](const APInt &a, 229 const APInt &b) -> std::optional<APInt> { 230 bool overflowed = false; 231 APInt result = any(ovfFlags & OverflowFlags::Nsw) 232 ? a.ssub_sat(b) 233 : a.ssub_ov(b, overflowed); 234 return overflowed ? std::optional<APInt>() : result; 235 }; 236 ConstantIntRanges urange = computeBoundsBy( 237 usub, lhs.umin(), rhs.umax(), lhs.umax(), rhs.umin(), /*isSigned=*/false); 238 ConstantIntRanges srange = computeBoundsBy( 239 ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin(), /*isSigned=*/true); 240 return urange.intersection(srange); 241 } 242 243 //===----------------------------------------------------------------------===// 244 // Multiplication 245 //===----------------------------------------------------------------------===// 246 247 ConstantIntRanges 248 mlir::intrange::inferMul(ArrayRef<ConstantIntRanges> argRanges, 249 OverflowFlags ovfFlags) { 250 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 251 252 ConstArithStdFn umul = [=](const APInt &a, 253 const APInt &b) -> std::optional<APInt> { 254 bool overflowed = false; 255 APInt result = any(ovfFlags & OverflowFlags::Nuw) 256 ? a.umul_sat(b) 257 : a.umul_ov(b, overflowed); 258 return overflowed ? std::optional<APInt>() : result; 259 }; 260 ConstArithStdFn smul = [=](const APInt &a, 261 const APInt &b) -> std::optional<APInt> { 262 bool overflowed = false; 263 APInt result = any(ovfFlags & OverflowFlags::Nsw) 264 ? a.smul_sat(b) 265 : a.smul_ov(b, overflowed); 266 return overflowed ? std::optional<APInt>() : result; 267 }; 268 269 ConstantIntRanges urange = 270 minMaxBy(umul, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()}, 271 /*isSigned=*/false); 272 ConstantIntRanges srange = 273 minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()}, 274 /*isSigned=*/true); 275 return urange.intersection(srange); 276 } 277 278 //===----------------------------------------------------------------------===// 279 // DivU, CeilDivU (Unsigned division) 280 //===----------------------------------------------------------------------===// 281 282 /// Fix up division results (ex. for ceiling and floor), returning an APInt 283 /// if there has been no overflow 284 using DivisionFixupFn = function_ref<std::optional<APInt>( 285 const APInt &lhs, const APInt &rhs, const APInt &result)>; 286 287 static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs, 288 const ConstantIntRanges &rhs, 289 DivisionFixupFn fixup) { 290 const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(), &rhsMin = rhs.umin(), 291 &rhsMax = rhs.umax(); 292 293 if (!rhsMin.isZero()) { 294 auto udiv = [&fixup](const APInt &a, 295 const APInt &b) -> std::optional<APInt> { 296 return fixup(a, b, a.udiv(b)); 297 }; 298 return minMaxBy(udiv, {lhsMin, lhsMax}, {rhsMin, rhsMax}, 299 /*isSigned=*/false); 300 } 301 302 APInt umin = APInt::getZero(rhsMin.getBitWidth()); 303 if (lhsMin.uge(rhsMax) && !rhsMax.isZero()) 304 umin = lhsMin.udiv(rhsMax); 305 306 // X u/ Y u<= X. 307 APInt umax = lhsMax; 308 return ConstantIntRanges::fromUnsigned(umin, umax); 309 } 310 311 ConstantIntRanges 312 mlir::intrange::inferDivU(ArrayRef<ConstantIntRanges> argRanges) { 313 return inferDivURange(argRanges[0], argRanges[1], 314 [](const APInt &lhs, const APInt &rhs, 315 const APInt &result) { return result; }); 316 } 317 318 ConstantIntRanges 319 mlir::intrange::inferCeilDivU(ArrayRef<ConstantIntRanges> argRanges) { 320 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 321 322 auto ceilDivUIFix = [](const APInt &lhs, const APInt &rhs, 323 const APInt &result) -> std::optional<APInt> { 324 if (!lhs.urem(rhs).isZero()) { 325 bool overflowed = false; 326 APInt corrected = 327 result.uadd_ov(APInt(result.getBitWidth(), 1), overflowed); 328 return overflowed ? std::optional<APInt>() : corrected; 329 } 330 return result; 331 }; 332 return inferDivURange(lhs, rhs, ceilDivUIFix); 333 } 334 335 //===----------------------------------------------------------------------===// 336 // DivS, CeilDivS, FloorDivS (Signed division) 337 //===----------------------------------------------------------------------===// 338 339 static ConstantIntRanges inferDivSRange(const ConstantIntRanges &lhs, 340 const ConstantIntRanges &rhs, 341 DivisionFixupFn fixup) { 342 const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(), 343 &rhsMax = rhs.smax(); 344 bool canDivide = rhsMin.isStrictlyPositive() || rhsMax.isNegative(); 345 346 if (canDivide) { 347 auto sdiv = [&fixup](const APInt &a, 348 const APInt &b) -> std::optional<APInt> { 349 bool overflowed = false; 350 APInt result = a.sdiv_ov(b, overflowed); 351 return overflowed ? std::optional<APInt>() : fixup(a, b, result); 352 }; 353 return minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax}, 354 /*isSigned=*/true); 355 } 356 return ConstantIntRanges::maxRange(rhsMin.getBitWidth()); 357 } 358 359 ConstantIntRanges 360 mlir::intrange::inferDivS(ArrayRef<ConstantIntRanges> argRanges) { 361 return inferDivSRange(argRanges[0], argRanges[1], 362 [](const APInt &lhs, const APInt &rhs, 363 const APInt &result) { return result; }); 364 } 365 366 ConstantIntRanges 367 mlir::intrange::inferCeilDivS(ArrayRef<ConstantIntRanges> argRanges) { 368 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 369 370 auto ceilDivSIFix = [](const APInt &lhs, const APInt &rhs, 371 const APInt &result) -> std::optional<APInt> { 372 if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() == rhs.isNonNegative()) { 373 bool overflowed = false; 374 APInt corrected = 375 result.sadd_ov(APInt(result.getBitWidth(), 1), overflowed); 376 return overflowed ? std::optional<APInt>() : corrected; 377 } 378 // Special case where the usual implementation of ceilDiv causes 379 // INT_MIN / [positive number] to be positive. This doesn't match the 380 // definition of signed ceiling division mathematically, but it prevents 381 // inconsistent constant-folding results. This arises because (-int_min) is 382 // still negative, so -(-int_min / b) is -(int_min / b), which is 383 // positive See #115293. 384 if (lhs.isMinSignedValue() && rhs.sgt(1)) { 385 return -result; 386 } 387 return result; 388 }; 389 ConstantIntRanges result = inferDivSRange(lhs, rhs, ceilDivSIFix); 390 if (lhs.smin().isMinSignedValue() && lhs.smax().sgt(lhs.smin())) { 391 // If lhs range includes INT_MIN and lhs is not a single value, we can 392 // suddenly wrap to positive val, skipping entire negative range, add 393 // [INT_MIN + 1, smax()] range to the result to handle this. 394 auto newLhs = ConstantIntRanges::fromSigned(lhs.smin() + 1, lhs.smax()); 395 result = result.rangeUnion(inferDivSRange(newLhs, rhs, ceilDivSIFix)); 396 } 397 return result; 398 } 399 400 ConstantIntRanges 401 mlir::intrange::inferFloorDivS(ArrayRef<ConstantIntRanges> argRanges) { 402 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 403 404 auto floorDivSIFix = [](const APInt &lhs, const APInt &rhs, 405 const APInt &result) -> std::optional<APInt> { 406 if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() != rhs.isNonNegative()) { 407 bool overflowed = false; 408 APInt corrected = 409 result.ssub_ov(APInt(result.getBitWidth(), 1), overflowed); 410 return overflowed ? std::optional<APInt>() : corrected; 411 } 412 return result; 413 }; 414 return inferDivSRange(lhs, rhs, floorDivSIFix); 415 } 416 417 //===----------------------------------------------------------------------===// 418 // Signed remainder (RemS) 419 //===----------------------------------------------------------------------===// 420 421 ConstantIntRanges 422 mlir::intrange::inferRemS(ArrayRef<ConstantIntRanges> argRanges) { 423 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 424 const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(), 425 &rhsMax = rhs.smax(); 426 427 unsigned width = rhsMax.getBitWidth(); 428 APInt smin = APInt::getSignedMinValue(width); 429 APInt smax = APInt::getSignedMaxValue(width); 430 // No bounds if zero could be a divisor. 431 bool canBound = (rhsMin.isStrictlyPositive() || rhsMax.isNegative()); 432 if (canBound) { 433 APInt maxDivisor = rhsMin.isStrictlyPositive() ? rhsMax : rhsMin.abs(); 434 bool canNegativeDividend = lhsMin.isNegative(); 435 bool canPositiveDividend = lhsMax.isStrictlyPositive(); 436 APInt zero = APInt::getZero(maxDivisor.getBitWidth()); 437 APInt maxPositiveResult = maxDivisor - 1; 438 APInt minNegativeResult = -maxPositiveResult; 439 smin = canNegativeDividend ? minNegativeResult : zero; 440 smax = canPositiveDividend ? maxPositiveResult : zero; 441 // Special case: sweeping out a contiguous range in N/[modulus]. 442 if (rhsMin == rhsMax) { 443 if ((lhsMax - lhsMin).ult(maxDivisor)) { 444 APInt minRem = lhsMin.srem(maxDivisor); 445 APInt maxRem = lhsMax.srem(maxDivisor); 446 if (minRem.sle(maxRem)) { 447 smin = minRem; 448 smax = maxRem; 449 } 450 } 451 } 452 } 453 return ConstantIntRanges::fromSigned(smin, smax); 454 } 455 456 //===----------------------------------------------------------------------===// 457 // Unsigned remainder (RemU) 458 //===----------------------------------------------------------------------===// 459 460 ConstantIntRanges 461 mlir::intrange::inferRemU(ArrayRef<ConstantIntRanges> argRanges) { 462 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 463 const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax(); 464 465 unsigned width = rhsMin.getBitWidth(); 466 APInt umin = APInt::getZero(width); 467 // Remainder can't be larger than either of its arguments. 468 APInt umax = llvm::APIntOps::umin((rhsMax - 1), lhs.umax()); 469 470 if (!rhsMin.isZero()) { 471 // Special case: sweeping out a contiguous range in N/[modulus] 472 if (rhsMin == rhsMax) { 473 const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(); 474 if ((lhsMax - lhsMin).ult(rhsMax)) { 475 APInt minRem = lhsMin.urem(rhsMax); 476 APInt maxRem = lhsMax.urem(rhsMax); 477 if (minRem.ule(maxRem)) { 478 umin = minRem; 479 umax = maxRem; 480 } 481 } 482 } 483 } 484 return ConstantIntRanges::fromUnsigned(umin, umax); 485 } 486 487 //===----------------------------------------------------------------------===// 488 // Max and min (MaxS, MaxU, MinS, MinU) 489 //===----------------------------------------------------------------------===// 490 491 ConstantIntRanges 492 mlir::intrange::inferMaxS(ArrayRef<ConstantIntRanges> argRanges) { 493 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 494 495 const APInt &smin = lhs.smin().sgt(rhs.smin()) ? lhs.smin() : rhs.smin(); 496 const APInt &smax = lhs.smax().sgt(rhs.smax()) ? lhs.smax() : rhs.smax(); 497 return ConstantIntRanges::fromSigned(smin, smax); 498 } 499 500 ConstantIntRanges 501 mlir::intrange::inferMaxU(ArrayRef<ConstantIntRanges> argRanges) { 502 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 503 504 const APInt &umin = lhs.umin().ugt(rhs.umin()) ? lhs.umin() : rhs.umin(); 505 const APInt &umax = lhs.umax().ugt(rhs.umax()) ? lhs.umax() : rhs.umax(); 506 return ConstantIntRanges::fromUnsigned(umin, umax); 507 } 508 509 ConstantIntRanges 510 mlir::intrange::inferMinS(ArrayRef<ConstantIntRanges> argRanges) { 511 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 512 513 const APInt &smin = lhs.smin().slt(rhs.smin()) ? lhs.smin() : rhs.smin(); 514 const APInt &smax = lhs.smax().slt(rhs.smax()) ? lhs.smax() : rhs.smax(); 515 return ConstantIntRanges::fromSigned(smin, smax); 516 } 517 518 ConstantIntRanges 519 mlir::intrange::inferMinU(ArrayRef<ConstantIntRanges> argRanges) { 520 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 521 522 const APInt &umin = lhs.umin().ult(rhs.umin()) ? lhs.umin() : rhs.umin(); 523 const APInt &umax = lhs.umax().ult(rhs.umax()) ? lhs.umax() : rhs.umax(); 524 return ConstantIntRanges::fromUnsigned(umin, umax); 525 } 526 527 //===----------------------------------------------------------------------===// 528 // Bitwise operators (And, Or, Xor) 529 //===----------------------------------------------------------------------===// 530 531 /// "Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???, 532 /// relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111, where vvvvv are the bits 533 /// that both bonuds have in common. This gives us a consertive approximation 534 /// for what values can be passed to bitwise operations. 535 static std::tuple<APInt, APInt> 536 widenBitwiseBounds(const ConstantIntRanges &bound) { 537 APInt leftVal = bound.umin(), rightVal = bound.umax(); 538 unsigned bitwidth = leftVal.getBitWidth(); 539 unsigned differingBits = bitwidth - (leftVal ^ rightVal).countl_zero(); 540 leftVal.clearLowBits(differingBits); 541 rightVal.setLowBits(differingBits); 542 return std::make_tuple(std::move(leftVal), std::move(rightVal)); 543 } 544 545 ConstantIntRanges 546 mlir::intrange::inferAnd(ArrayRef<ConstantIntRanges> argRanges) { 547 auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]); 548 auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]); 549 auto andi = [](const APInt &a, const APInt &b) -> std::optional<APInt> { 550 return a & b; 551 }; 552 return minMaxBy(andi, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, 553 /*isSigned=*/false); 554 } 555 556 ConstantIntRanges 557 mlir::intrange::inferOr(ArrayRef<ConstantIntRanges> argRanges) { 558 auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]); 559 auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]); 560 auto ori = [](const APInt &a, const APInt &b) -> std::optional<APInt> { 561 return a | b; 562 }; 563 return minMaxBy(ori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, 564 /*isSigned=*/false); 565 } 566 567 /// Get bitmask of all bits which can change while iterating in 568 /// [bound.umin(), bound.umax()]. 569 static APInt getVaryingBitsMask(const ConstantIntRanges &bound) { 570 APInt leftVal = bound.umin(), rightVal = bound.umax(); 571 unsigned bitwidth = leftVal.getBitWidth(); 572 unsigned differingBits = bitwidth - (leftVal ^ rightVal).countl_zero(); 573 return APInt::getLowBitsSet(bitwidth, differingBits); 574 } 575 576 ConstantIntRanges 577 mlir::intrange::inferXor(ArrayRef<ConstantIntRanges> argRanges) { 578 // Construct mask of varying bits for both ranges, xor values and then replace 579 // masked bits with 0s and 1s to get min and max values respectively. 580 ConstantIntRanges lhs = argRanges[0], rhs = argRanges[1]; 581 APInt mask = getVaryingBitsMask(lhs) | getVaryingBitsMask(rhs); 582 APInt res = lhs.umin() ^ rhs.umin(); 583 APInt min = res & ~mask; 584 APInt max = res | mask; 585 return ConstantIntRanges::fromUnsigned(min, max); 586 } 587 588 //===----------------------------------------------------------------------===// 589 // Shifts (Shl, ShrS, ShrU) 590 //===----------------------------------------------------------------------===// 591 592 ConstantIntRanges 593 mlir::intrange::inferShl(ArrayRef<ConstantIntRanges> argRanges, 594 OverflowFlags ovfFlags) { 595 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 596 const APInt &rhsUMin = rhs.umin(), &rhsUMax = rhs.umax(); 597 598 // The signed/unsigned overflow behavior of shl by `rhs` matches a mul with 599 // 2^rhs. 600 ConstArithStdFn ushl = [=](const APInt &l, 601 const APInt &r) -> std::optional<APInt> { 602 bool overflowed = false; 603 APInt result = any(ovfFlags & OverflowFlags::Nuw) 604 ? l.ushl_sat(r) 605 : l.ushl_ov(r, overflowed); 606 return overflowed ? std::optional<APInt>() : result; 607 }; 608 ConstArithStdFn sshl = [=](const APInt &l, 609 const APInt &r) -> std::optional<APInt> { 610 bool overflowed = false; 611 APInt result = any(ovfFlags & OverflowFlags::Nsw) 612 ? l.sshl_sat(r) 613 : l.sshl_ov(r, overflowed); 614 return overflowed ? std::optional<APInt>() : result; 615 }; 616 617 ConstantIntRanges urange = 618 minMaxBy(ushl, {lhs.umin(), lhs.umax()}, {rhsUMin, rhsUMax}, 619 /*isSigned=*/false); 620 ConstantIntRanges srange = 621 minMaxBy(sshl, {lhs.smin(), lhs.smax()}, {rhsUMin, rhsUMax}, 622 /*isSigned=*/true); 623 return urange.intersection(srange); 624 } 625 626 ConstantIntRanges 627 mlir::intrange::inferShrS(ArrayRef<ConstantIntRanges> argRanges) { 628 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 629 630 auto ashr = [](const APInt &l, const APInt &r) -> std::optional<APInt> { 631 return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.ashr(r); 632 }; 633 634 return minMaxBy(ashr, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()}, 635 /*isSigned=*/true); 636 } 637 638 ConstantIntRanges 639 mlir::intrange::inferShrU(ArrayRef<ConstantIntRanges> argRanges) { 640 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 641 642 auto lshr = [](const APInt &l, const APInt &r) -> std::optional<APInt> { 643 return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.lshr(r); 644 }; 645 return minMaxBy(lshr, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()}, 646 /*isSigned=*/false); 647 } 648 649 //===----------------------------------------------------------------------===// 650 // Comparisons (Cmp) 651 //===----------------------------------------------------------------------===// 652 653 static intrange::CmpPredicate invertPredicate(intrange::CmpPredicate pred) { 654 switch (pred) { 655 case intrange::CmpPredicate::eq: 656 return intrange::CmpPredicate::ne; 657 case intrange::CmpPredicate::ne: 658 return intrange::CmpPredicate::eq; 659 case intrange::CmpPredicate::slt: 660 return intrange::CmpPredicate::sge; 661 case intrange::CmpPredicate::sle: 662 return intrange::CmpPredicate::sgt; 663 case intrange::CmpPredicate::sgt: 664 return intrange::CmpPredicate::sle; 665 case intrange::CmpPredicate::sge: 666 return intrange::CmpPredicate::slt; 667 case intrange::CmpPredicate::ult: 668 return intrange::CmpPredicate::uge; 669 case intrange::CmpPredicate::ule: 670 return intrange::CmpPredicate::ugt; 671 case intrange::CmpPredicate::ugt: 672 return intrange::CmpPredicate::ule; 673 case intrange::CmpPredicate::uge: 674 return intrange::CmpPredicate::ult; 675 } 676 llvm_unreachable("unknown cmp predicate value"); 677 } 678 679 static bool isStaticallyTrue(intrange::CmpPredicate pred, 680 const ConstantIntRanges &lhs, 681 const ConstantIntRanges &rhs) { 682 switch (pred) { 683 case intrange::CmpPredicate::sle: 684 return lhs.smax().sle(rhs.smin()); 685 case intrange::CmpPredicate::slt: 686 return lhs.smax().slt(rhs.smin()); 687 case intrange::CmpPredicate::ule: 688 return lhs.umax().ule(rhs.umin()); 689 case intrange::CmpPredicate::ult: 690 return lhs.umax().ult(rhs.umin()); 691 case intrange::CmpPredicate::sge: 692 return lhs.smin().sge(rhs.smax()); 693 case intrange::CmpPredicate::sgt: 694 return lhs.smin().sgt(rhs.smax()); 695 case intrange::CmpPredicate::uge: 696 return lhs.umin().uge(rhs.umax()); 697 case intrange::CmpPredicate::ugt: 698 return lhs.umin().ugt(rhs.umax()); 699 case intrange::CmpPredicate::eq: { 700 std::optional<APInt> lhsConst = lhs.getConstantValue(); 701 std::optional<APInt> rhsConst = rhs.getConstantValue(); 702 return lhsConst && rhsConst && lhsConst == rhsConst; 703 } 704 case intrange::CmpPredicate::ne: { 705 // While equality requires that there is an interpration of the preceeding 706 // computations that produces equal constants, whether that be signed or 707 // unsigned, statically determining inequality requires that neither 708 // interpretation produce potentially overlapping ranges. 709 bool sne = isStaticallyTrue(intrange::CmpPredicate::slt, lhs, rhs) || 710 isStaticallyTrue(intrange::CmpPredicate::sgt, lhs, rhs); 711 bool une = isStaticallyTrue(intrange::CmpPredicate::ult, lhs, rhs) || 712 isStaticallyTrue(intrange::CmpPredicate::ugt, lhs, rhs); 713 return sne && une; 714 } 715 } 716 return false; 717 } 718 719 std::optional<bool> mlir::intrange::evaluatePred(CmpPredicate pred, 720 const ConstantIntRanges &lhs, 721 const ConstantIntRanges &rhs) { 722 if (isStaticallyTrue(pred, lhs, rhs)) 723 return true; 724 if (isStaticallyTrue(invertPredicate(pred), lhs, rhs)) 725 return false; 726 return std::nullopt; 727 } 728