1 //===- InferIntRangeCommon.cpp - Inference for common ops ------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains implementations of range inference for operations that are 10 // common to both the `arith` and `index` dialects to facilitate reuse. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Interfaces/Utils/InferIntRangeCommon.h" 15 16 #include "mlir/Interfaces/InferIntRangeInterface.h" 17 18 #include "llvm/ADT/ArrayRef.h" 19 #include "llvm/ADT/STLExtras.h" 20 21 #include "llvm/Support/Debug.h" 22 23 #include <iterator> 24 #include <optional> 25 26 using namespace mlir; 27 28 #define DEBUG_TYPE "int-range-analysis" 29 30 //===----------------------------------------------------------------------===// 31 // General utilities 32 //===----------------------------------------------------------------------===// 33 34 /// Function that evaluates the result of doing something on arithmetic 35 /// constants and returns std::nullopt on overflow. 36 using ConstArithFn = 37 function_ref<std::optional<APInt>(const APInt &, const APInt &)>; 38 39 /// Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible, 40 /// If either computation overflows, make the result unbounded. 41 static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft, 42 const APInt &minRight, 43 const APInt &maxLeft, 44 const APInt &maxRight, bool isSigned) { 45 std::optional<APInt> maybeMin = op(minLeft, minRight); 46 std::optional<APInt> maybeMax = op(maxLeft, maxRight); 47 if (maybeMin && maybeMax) 48 return ConstantIntRanges::range(*maybeMin, *maybeMax, isSigned); 49 return ConstantIntRanges::maxRange(minLeft.getBitWidth()); 50 } 51 52 /// Compute the minimum and maximum of `(op(l, r) for l in lhs for r in rhs)`, 53 /// ignoring unbounded values. Returns the maximal range if `op` overflows. 54 static ConstantIntRanges minMaxBy(ConstArithFn op, ArrayRef<APInt> lhs, 55 ArrayRef<APInt> rhs, bool isSigned) { 56 unsigned width = lhs[0].getBitWidth(); 57 APInt min = 58 isSigned ? APInt::getSignedMaxValue(width) : APInt::getMaxValue(width); 59 APInt max = 60 isSigned ? APInt::getSignedMinValue(width) : APInt::getZero(width); 61 for (const APInt &left : lhs) { 62 for (const APInt &right : rhs) { 63 std::optional<APInt> maybeThisResult = op(left, right); 64 if (!maybeThisResult) 65 return ConstantIntRanges::maxRange(width); 66 APInt result = std::move(*maybeThisResult); 67 min = (isSigned ? result.slt(min) : result.ult(min)) ? result : min; 68 max = (isSigned ? result.sgt(max) : result.ugt(max)) ? result : max; 69 } 70 } 71 return ConstantIntRanges::range(min, max, isSigned); 72 } 73 74 //===----------------------------------------------------------------------===// 75 // Ext, trunc, index op handling 76 //===----------------------------------------------------------------------===// 77 78 ConstantIntRanges 79 mlir::intrange::inferIndexOp(const InferRangeFn &inferFn, 80 ArrayRef<ConstantIntRanges> argRanges, 81 intrange::CmpMode mode) { 82 ConstantIntRanges sixtyFour = inferFn(argRanges); 83 SmallVector<ConstantIntRanges, 2> truncated; 84 llvm::transform(argRanges, std::back_inserter(truncated), 85 [](const ConstantIntRanges &range) { 86 return truncRange(range, /*destWidth=*/indexMinWidth); 87 }); 88 ConstantIntRanges thirtyTwo = inferFn(truncated); 89 ConstantIntRanges thirtyTwoAsSixtyFour = 90 extRange(thirtyTwo, /*destWidth=*/indexMaxWidth); 91 ConstantIntRanges sixtyFourAsThirtyTwo = 92 truncRange(sixtyFour, /*destWidth=*/indexMinWidth); 93 94 LLVM_DEBUG(llvm::dbgs() << "Index handling: 64-bit result = " << sixtyFour 95 << " 32-bit = " << thirtyTwo << "\n"); 96 bool truncEqual = false; 97 switch (mode) { 98 case intrange::CmpMode::Both: 99 truncEqual = (thirtyTwo == sixtyFourAsThirtyTwo); 100 break; 101 case intrange::CmpMode::Signed: 102 truncEqual = (thirtyTwo.smin() == sixtyFourAsThirtyTwo.smin() && 103 thirtyTwo.smax() == sixtyFourAsThirtyTwo.smax()); 104 break; 105 case intrange::CmpMode::Unsigned: 106 truncEqual = (thirtyTwo.umin() == sixtyFourAsThirtyTwo.umin() && 107 thirtyTwo.umax() == sixtyFourAsThirtyTwo.umax()); 108 break; 109 } 110 if (truncEqual) 111 // Returing the 64-bit result preserves more information. 112 return sixtyFour; 113 ConstantIntRanges merged = sixtyFour.rangeUnion(thirtyTwoAsSixtyFour); 114 return merged; 115 } 116 117 ConstantIntRanges mlir::intrange::extRange(const ConstantIntRanges &range, 118 unsigned int destWidth) { 119 APInt umin = range.umin().zext(destWidth); 120 APInt umax = range.umax().zext(destWidth); 121 APInt smin = range.smin().sext(destWidth); 122 APInt smax = range.smax().sext(destWidth); 123 return {umin, umax, smin, smax}; 124 } 125 126 ConstantIntRanges mlir::intrange::extUIRange(const ConstantIntRanges &range, 127 unsigned destWidth) { 128 APInt umin = range.umin().zext(destWidth); 129 APInt umax = range.umax().zext(destWidth); 130 return ConstantIntRanges::fromUnsigned(umin, umax); 131 } 132 133 ConstantIntRanges mlir::intrange::extSIRange(const ConstantIntRanges &range, 134 unsigned destWidth) { 135 APInt smin = range.smin().sext(destWidth); 136 APInt smax = range.smax().sext(destWidth); 137 return ConstantIntRanges::fromSigned(smin, smax); 138 } 139 140 ConstantIntRanges mlir::intrange::truncRange(const ConstantIntRanges &range, 141 unsigned int destWidth) { 142 // If you truncate the first four bytes in [0xaaaabbbb, 0xccccbbbb], 143 // the range of the resulting value is not contiguous ind includes 0. 144 // Ex. If you truncate [256, 258] from i16 to i8, you validly get [0, 2], 145 // but you can't truncate [255, 257] similarly. 146 bool hasUnsignedRollover = 147 range.umin().lshr(destWidth) != range.umax().lshr(destWidth); 148 APInt umin = hasUnsignedRollover ? APInt::getZero(destWidth) 149 : range.umin().trunc(destWidth); 150 APInt umax = hasUnsignedRollover ? APInt::getMaxValue(destWidth) 151 : range.umax().trunc(destWidth); 152 153 // Signed post-truncation rollover will not occur when either: 154 // - The high parts of the min and max, plus the sign bit, are the same 155 // - The high halves + sign bit of the min and max are either all 1s or all 0s 156 // and you won't create a [positive, negative] range by truncating. 157 // For example, you can truncate the ranges [256, 258]_i16 to [0, 2]_i8 158 // but not [255, 257]_i16 to a range of i8s. You can also truncate 159 // [-256, -256]_i16 to [-2, 0]_i8, but not [-257, -255]_i16. 160 // You can also truncate [-130, 0]_i16 to i8 because -130_i16 (0xff7e) 161 // will truncate to 0x7e, which is greater than 0 162 APInt sminHighPart = range.smin().ashr(destWidth - 1); 163 APInt smaxHighPart = range.smax().ashr(destWidth - 1); 164 bool hasSignedOverflow = 165 (sminHighPart != smaxHighPart) && 166 !(sminHighPart.isAllOnes() && 167 (smaxHighPart.isAllOnes() || smaxHighPart.isZero())) && 168 !(sminHighPart.isZero() && smaxHighPart.isZero()); 169 APInt smin = hasSignedOverflow ? APInt::getSignedMinValue(destWidth) 170 : range.smin().trunc(destWidth); 171 APInt smax = hasSignedOverflow ? APInt::getSignedMaxValue(destWidth) 172 : range.smax().trunc(destWidth); 173 return {umin, umax, smin, smax}; 174 } 175 176 //===----------------------------------------------------------------------===// 177 // Addition 178 //===----------------------------------------------------------------------===// 179 180 ConstantIntRanges 181 mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges, 182 OverflowFlags ovfFlags) { 183 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 184 185 std::function uadd = [=](const APInt &a, 186 const APInt &b) -> std::optional<APInt> { 187 bool overflowed = false; 188 APInt result = any(ovfFlags & OverflowFlags::Nuw) 189 ? a.uadd_sat(b) 190 : a.uadd_ov(b, overflowed); 191 return overflowed ? std::optional<APInt>() : result; 192 }; 193 std::function sadd = [=](const APInt &a, 194 const APInt &b) -> std::optional<APInt> { 195 bool overflowed = false; 196 APInt result = any(ovfFlags & OverflowFlags::Nsw) 197 ? a.sadd_sat(b) 198 : a.sadd_ov(b, overflowed); 199 return overflowed ? std::optional<APInt>() : result; 200 }; 201 202 ConstantIntRanges urange = computeBoundsBy( 203 uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax(), /*isSigned=*/false); 204 ConstantIntRanges srange = computeBoundsBy( 205 sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax(), /*isSigned=*/true); 206 return urange.intersection(srange); 207 } 208 209 //===----------------------------------------------------------------------===// 210 // Subtraction 211 //===----------------------------------------------------------------------===// 212 213 ConstantIntRanges 214 mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges, 215 OverflowFlags ovfFlags) { 216 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 217 218 std::function usub = [=](const APInt &a, 219 const APInt &b) -> std::optional<APInt> { 220 bool overflowed = false; 221 APInt result = any(ovfFlags & OverflowFlags::Nuw) 222 ? a.usub_sat(b) 223 : a.usub_ov(b, overflowed); 224 return overflowed ? std::optional<APInt>() : result; 225 }; 226 std::function ssub = [=](const APInt &a, 227 const APInt &b) -> std::optional<APInt> { 228 bool overflowed = false; 229 APInt result = any(ovfFlags & OverflowFlags::Nsw) 230 ? a.ssub_sat(b) 231 : a.ssub_ov(b, overflowed); 232 return overflowed ? std::optional<APInt>() : result; 233 }; 234 ConstantIntRanges urange = computeBoundsBy( 235 usub, lhs.umin(), rhs.umax(), lhs.umax(), rhs.umin(), /*isSigned=*/false); 236 ConstantIntRanges srange = computeBoundsBy( 237 ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin(), /*isSigned=*/true); 238 return urange.intersection(srange); 239 } 240 241 //===----------------------------------------------------------------------===// 242 // Multiplication 243 //===----------------------------------------------------------------------===// 244 245 ConstantIntRanges 246 mlir::intrange::inferMul(ArrayRef<ConstantIntRanges> argRanges, 247 OverflowFlags ovfFlags) { 248 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 249 250 std::function umul = [=](const APInt &a, 251 const APInt &b) -> std::optional<APInt> { 252 bool overflowed = false; 253 APInt result = any(ovfFlags & OverflowFlags::Nuw) 254 ? a.umul_sat(b) 255 : a.umul_ov(b, overflowed); 256 return overflowed ? std::optional<APInt>() : result; 257 }; 258 std::function smul = [=](const APInt &a, 259 const APInt &b) -> std::optional<APInt> { 260 bool overflowed = false; 261 APInt result = any(ovfFlags & OverflowFlags::Nsw) 262 ? a.smul_sat(b) 263 : a.smul_ov(b, overflowed); 264 return overflowed ? std::optional<APInt>() : result; 265 }; 266 267 ConstantIntRanges urange = 268 minMaxBy(umul, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()}, 269 /*isSigned=*/false); 270 ConstantIntRanges srange = 271 minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()}, 272 /*isSigned=*/true); 273 return urange.intersection(srange); 274 } 275 276 //===----------------------------------------------------------------------===// 277 // DivU, CeilDivU (Unsigned division) 278 //===----------------------------------------------------------------------===// 279 280 /// Fix up division results (ex. for ceiling and floor), returning an APInt 281 /// if there has been no overflow 282 using DivisionFixupFn = function_ref<std::optional<APInt>( 283 const APInt &lhs, const APInt &rhs, const APInt &result)>; 284 285 static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs, 286 const ConstantIntRanges &rhs, 287 DivisionFixupFn fixup) { 288 const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(), &rhsMin = rhs.umin(), 289 &rhsMax = rhs.umax(); 290 291 if (!rhsMin.isZero()) { 292 auto udiv = [&fixup](const APInt &a, 293 const APInt &b) -> std::optional<APInt> { 294 return fixup(a, b, a.udiv(b)); 295 }; 296 return minMaxBy(udiv, {lhsMin, lhsMax}, {rhsMin, rhsMax}, 297 /*isSigned=*/false); 298 } 299 // Otherwise, it's possible we might divide by 0. 300 return ConstantIntRanges::maxRange(rhsMin.getBitWidth()); 301 } 302 303 ConstantIntRanges 304 mlir::intrange::inferDivU(ArrayRef<ConstantIntRanges> argRanges) { 305 return inferDivURange(argRanges[0], argRanges[1], 306 [](const APInt &lhs, const APInt &rhs, 307 const APInt &result) { return result; }); 308 } 309 310 ConstantIntRanges 311 mlir::intrange::inferCeilDivU(ArrayRef<ConstantIntRanges> argRanges) { 312 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 313 314 DivisionFixupFn ceilDivUIFix = 315 [](const APInt &lhs, const APInt &rhs, 316 const APInt &result) -> std::optional<APInt> { 317 if (!lhs.urem(rhs).isZero()) { 318 bool overflowed = false; 319 APInt corrected = 320 result.uadd_ov(APInt(result.getBitWidth(), 1), overflowed); 321 return overflowed ? std::optional<APInt>() : corrected; 322 } 323 return result; 324 }; 325 return inferDivURange(lhs, rhs, ceilDivUIFix); 326 } 327 328 //===----------------------------------------------------------------------===// 329 // DivS, CeilDivS, FloorDivS (Signed division) 330 //===----------------------------------------------------------------------===// 331 332 static ConstantIntRanges inferDivSRange(const ConstantIntRanges &lhs, 333 const ConstantIntRanges &rhs, 334 DivisionFixupFn fixup) { 335 const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(), 336 &rhsMax = rhs.smax(); 337 bool canDivide = rhsMin.isStrictlyPositive() || rhsMax.isNegative(); 338 339 if (canDivide) { 340 auto sdiv = [&fixup](const APInt &a, 341 const APInt &b) -> std::optional<APInt> { 342 bool overflowed = false; 343 APInt result = a.sdiv_ov(b, overflowed); 344 return overflowed ? std::optional<APInt>() : fixup(a, b, result); 345 }; 346 return minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax}, 347 /*isSigned=*/true); 348 } 349 return ConstantIntRanges::maxRange(rhsMin.getBitWidth()); 350 } 351 352 ConstantIntRanges 353 mlir::intrange::inferDivS(ArrayRef<ConstantIntRanges> argRanges) { 354 return inferDivSRange(argRanges[0], argRanges[1], 355 [](const APInt &lhs, const APInt &rhs, 356 const APInt &result) { return result; }); 357 } 358 359 ConstantIntRanges 360 mlir::intrange::inferCeilDivS(ArrayRef<ConstantIntRanges> argRanges) { 361 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 362 363 DivisionFixupFn ceilDivSIFix = 364 [](const APInt &lhs, const APInt &rhs, 365 const APInt &result) -> std::optional<APInt> { 366 if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() == rhs.isNonNegative()) { 367 bool overflowed = false; 368 APInt corrected = 369 result.sadd_ov(APInt(result.getBitWidth(), 1), overflowed); 370 return overflowed ? std::optional<APInt>() : corrected; 371 } 372 return result; 373 }; 374 return inferDivSRange(lhs, rhs, ceilDivSIFix); 375 } 376 377 ConstantIntRanges 378 mlir::intrange::inferFloorDivS(ArrayRef<ConstantIntRanges> argRanges) { 379 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 380 381 DivisionFixupFn floorDivSIFix = 382 [](const APInt &lhs, const APInt &rhs, 383 const APInt &result) -> std::optional<APInt> { 384 if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() != rhs.isNonNegative()) { 385 bool overflowed = false; 386 APInt corrected = 387 result.ssub_ov(APInt(result.getBitWidth(), 1), overflowed); 388 return overflowed ? std::optional<APInt>() : corrected; 389 } 390 return result; 391 }; 392 return inferDivSRange(lhs, rhs, floorDivSIFix); 393 } 394 395 //===----------------------------------------------------------------------===// 396 // Signed remainder (RemS) 397 //===----------------------------------------------------------------------===// 398 399 ConstantIntRanges 400 mlir::intrange::inferRemS(ArrayRef<ConstantIntRanges> argRanges) { 401 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 402 const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(), 403 &rhsMax = rhs.smax(); 404 405 unsigned width = rhsMax.getBitWidth(); 406 APInt smin = APInt::getSignedMinValue(width); 407 APInt smax = APInt::getSignedMaxValue(width); 408 // No bounds if zero could be a divisor. 409 bool canBound = (rhsMin.isStrictlyPositive() || rhsMax.isNegative()); 410 if (canBound) { 411 APInt maxDivisor = rhsMin.isStrictlyPositive() ? rhsMax : rhsMin.abs(); 412 bool canNegativeDividend = lhsMin.isNegative(); 413 bool canPositiveDividend = lhsMax.isStrictlyPositive(); 414 APInt zero = APInt::getZero(maxDivisor.getBitWidth()); 415 APInt maxPositiveResult = maxDivisor - 1; 416 APInt minNegativeResult = -maxPositiveResult; 417 smin = canNegativeDividend ? minNegativeResult : zero; 418 smax = canPositiveDividend ? maxPositiveResult : zero; 419 // Special case: sweeping out a contiguous range in N/[modulus]. 420 if (rhsMin == rhsMax) { 421 if ((lhsMax - lhsMin).ult(maxDivisor)) { 422 APInt minRem = lhsMin.srem(maxDivisor); 423 APInt maxRem = lhsMax.srem(maxDivisor); 424 if (minRem.sle(maxRem)) { 425 smin = minRem; 426 smax = maxRem; 427 } 428 } 429 } 430 } 431 return ConstantIntRanges::fromSigned(smin, smax); 432 } 433 434 //===----------------------------------------------------------------------===// 435 // Unsigned remainder (RemU) 436 //===----------------------------------------------------------------------===// 437 438 ConstantIntRanges 439 mlir::intrange::inferRemU(ArrayRef<ConstantIntRanges> argRanges) { 440 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 441 const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax(); 442 443 unsigned width = rhsMin.getBitWidth(); 444 APInt umin = APInt::getZero(width); 445 APInt umax = APInt::getMaxValue(width); 446 447 if (!rhsMin.isZero()) { 448 umax = rhsMax - 1; 449 // Special case: sweeping out a contiguous range in N/[modulus] 450 if (rhsMin == rhsMax) { 451 const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(); 452 if ((lhsMax - lhsMin).ult(rhsMax)) { 453 APInt minRem = lhsMin.urem(rhsMax); 454 APInt maxRem = lhsMax.urem(rhsMax); 455 if (minRem.ule(maxRem)) { 456 umin = minRem; 457 umax = maxRem; 458 } 459 } 460 } 461 } 462 return ConstantIntRanges::fromUnsigned(umin, umax); 463 } 464 465 //===----------------------------------------------------------------------===// 466 // Max and min (MaxS, MaxU, MinS, MinU) 467 //===----------------------------------------------------------------------===// 468 469 ConstantIntRanges 470 mlir::intrange::inferMaxS(ArrayRef<ConstantIntRanges> argRanges) { 471 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 472 473 const APInt &smin = lhs.smin().sgt(rhs.smin()) ? lhs.smin() : rhs.smin(); 474 const APInt &smax = lhs.smax().sgt(rhs.smax()) ? lhs.smax() : rhs.smax(); 475 return ConstantIntRanges::fromSigned(smin, smax); 476 } 477 478 ConstantIntRanges 479 mlir::intrange::inferMaxU(ArrayRef<ConstantIntRanges> argRanges) { 480 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 481 482 const APInt &umin = lhs.umin().ugt(rhs.umin()) ? lhs.umin() : rhs.umin(); 483 const APInt &umax = lhs.umax().ugt(rhs.umax()) ? lhs.umax() : rhs.umax(); 484 return ConstantIntRanges::fromUnsigned(umin, umax); 485 } 486 487 ConstantIntRanges 488 mlir::intrange::inferMinS(ArrayRef<ConstantIntRanges> argRanges) { 489 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 490 491 const APInt &smin = lhs.smin().slt(rhs.smin()) ? lhs.smin() : rhs.smin(); 492 const APInt &smax = lhs.smax().slt(rhs.smax()) ? lhs.smax() : rhs.smax(); 493 return ConstantIntRanges::fromSigned(smin, smax); 494 } 495 496 ConstantIntRanges 497 mlir::intrange::inferMinU(ArrayRef<ConstantIntRanges> argRanges) { 498 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 499 500 const APInt &umin = lhs.umin().ult(rhs.umin()) ? lhs.umin() : rhs.umin(); 501 const APInt &umax = lhs.umax().ult(rhs.umax()) ? lhs.umax() : rhs.umax(); 502 return ConstantIntRanges::fromUnsigned(umin, umax); 503 } 504 505 //===----------------------------------------------------------------------===// 506 // Bitwise operators (And, Or, Xor) 507 //===----------------------------------------------------------------------===// 508 509 /// "Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???, 510 /// relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111, where vvvvv are the bits 511 /// that both bonuds have in common. This gives us a consertive approximation 512 /// for what values can be passed to bitwise operations. 513 static std::tuple<APInt, APInt> 514 widenBitwiseBounds(const ConstantIntRanges &bound) { 515 APInt leftVal = bound.umin(), rightVal = bound.umax(); 516 unsigned bitwidth = leftVal.getBitWidth(); 517 unsigned differingBits = bitwidth - (leftVal ^ rightVal).countl_zero(); 518 leftVal.clearLowBits(differingBits); 519 rightVal.setLowBits(differingBits); 520 return std::make_tuple(std::move(leftVal), std::move(rightVal)); 521 } 522 523 ConstantIntRanges 524 mlir::intrange::inferAnd(ArrayRef<ConstantIntRanges> argRanges) { 525 auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]); 526 auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]); 527 auto andi = [](const APInt &a, const APInt &b) -> std::optional<APInt> { 528 return a & b; 529 }; 530 return minMaxBy(andi, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, 531 /*isSigned=*/false); 532 } 533 534 ConstantIntRanges 535 mlir::intrange::inferOr(ArrayRef<ConstantIntRanges> argRanges) { 536 auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]); 537 auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]); 538 auto ori = [](const APInt &a, const APInt &b) -> std::optional<APInt> { 539 return a | b; 540 }; 541 return minMaxBy(ori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, 542 /*isSigned=*/false); 543 } 544 545 ConstantIntRanges 546 mlir::intrange::inferXor(ArrayRef<ConstantIntRanges> argRanges) { 547 auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]); 548 auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]); 549 auto xori = [](const APInt &a, const APInt &b) -> std::optional<APInt> { 550 return a ^ b; 551 }; 552 return minMaxBy(xori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, 553 /*isSigned=*/false); 554 } 555 556 //===----------------------------------------------------------------------===// 557 // Shifts (Shl, ShrS, ShrU) 558 //===----------------------------------------------------------------------===// 559 560 ConstantIntRanges 561 mlir::intrange::inferShl(ArrayRef<ConstantIntRanges> argRanges, 562 OverflowFlags ovfFlags) { 563 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 564 const APInt &rhsUMin = rhs.umin(), &rhsUMax = rhs.umax(); 565 566 // The signed/unsigned overflow behavior of shl by `rhs` matches a mul with 567 // 2^rhs. 568 std::function ushl = [=](const APInt &l, 569 const APInt &r) -> std::optional<APInt> { 570 bool overflowed = false; 571 APInt result = any(ovfFlags & OverflowFlags::Nuw) 572 ? l.ushl_sat(r) 573 : l.ushl_ov(r, overflowed); 574 return overflowed ? std::optional<APInt>() : result; 575 }; 576 std::function sshl = [=](const APInt &l, 577 const APInt &r) -> std::optional<APInt> { 578 bool overflowed = false; 579 APInt result = any(ovfFlags & OverflowFlags::Nsw) 580 ? l.sshl_sat(r) 581 : l.sshl_ov(r, overflowed); 582 return overflowed ? std::optional<APInt>() : result; 583 }; 584 585 ConstantIntRanges urange = 586 minMaxBy(ushl, {lhs.umin(), lhs.umax()}, {rhsUMin, rhsUMax}, 587 /*isSigned=*/false); 588 ConstantIntRanges srange = 589 minMaxBy(sshl, {lhs.smin(), lhs.smax()}, {rhsUMin, rhsUMax}, 590 /*isSigned=*/true); 591 return urange.intersection(srange); 592 } 593 594 ConstantIntRanges 595 mlir::intrange::inferShrS(ArrayRef<ConstantIntRanges> argRanges) { 596 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 597 598 ConstArithFn ashr = [](const APInt &l, 599 const APInt &r) -> std::optional<APInt> { 600 return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.ashr(r); 601 }; 602 603 return minMaxBy(ashr, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()}, 604 /*isSigned=*/true); 605 } 606 607 ConstantIntRanges 608 mlir::intrange::inferShrU(ArrayRef<ConstantIntRanges> argRanges) { 609 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 610 611 ConstArithFn lshr = [](const APInt &l, 612 const APInt &r) -> std::optional<APInt> { 613 return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.lshr(r); 614 }; 615 return minMaxBy(lshr, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()}, 616 /*isSigned=*/false); 617 } 618 619 //===----------------------------------------------------------------------===// 620 // Comparisons (Cmp) 621 //===----------------------------------------------------------------------===// 622 623 static intrange::CmpPredicate invertPredicate(intrange::CmpPredicate pred) { 624 switch (pred) { 625 case intrange::CmpPredicate::eq: 626 return intrange::CmpPredicate::ne; 627 case intrange::CmpPredicate::ne: 628 return intrange::CmpPredicate::eq; 629 case intrange::CmpPredicate::slt: 630 return intrange::CmpPredicate::sge; 631 case intrange::CmpPredicate::sle: 632 return intrange::CmpPredicate::sgt; 633 case intrange::CmpPredicate::sgt: 634 return intrange::CmpPredicate::sle; 635 case intrange::CmpPredicate::sge: 636 return intrange::CmpPredicate::slt; 637 case intrange::CmpPredicate::ult: 638 return intrange::CmpPredicate::uge; 639 case intrange::CmpPredicate::ule: 640 return intrange::CmpPredicate::ugt; 641 case intrange::CmpPredicate::ugt: 642 return intrange::CmpPredicate::ule; 643 case intrange::CmpPredicate::uge: 644 return intrange::CmpPredicate::ult; 645 } 646 llvm_unreachable("unknown cmp predicate value"); 647 } 648 649 static bool isStaticallyTrue(intrange::CmpPredicate pred, 650 const ConstantIntRanges &lhs, 651 const ConstantIntRanges &rhs) { 652 switch (pred) { 653 case intrange::CmpPredicate::sle: 654 return lhs.smax().sle(rhs.smin()); 655 case intrange::CmpPredicate::slt: 656 return lhs.smax().slt(rhs.smin()); 657 case intrange::CmpPredicate::ule: 658 return lhs.umax().ule(rhs.umin()); 659 case intrange::CmpPredicate::ult: 660 return lhs.umax().ult(rhs.umin()); 661 case intrange::CmpPredicate::sge: 662 return lhs.smin().sge(rhs.smax()); 663 case intrange::CmpPredicate::sgt: 664 return lhs.smin().sgt(rhs.smax()); 665 case intrange::CmpPredicate::uge: 666 return lhs.umin().uge(rhs.umax()); 667 case intrange::CmpPredicate::ugt: 668 return lhs.umin().ugt(rhs.umax()); 669 case intrange::CmpPredicate::eq: { 670 std::optional<APInt> lhsConst = lhs.getConstantValue(); 671 std::optional<APInt> rhsConst = rhs.getConstantValue(); 672 return lhsConst && rhsConst && lhsConst == rhsConst; 673 } 674 case intrange::CmpPredicate::ne: { 675 // While equality requires that there is an interpration of the preceeding 676 // computations that produces equal constants, whether that be signed or 677 // unsigned, statically determining inequality requires that neither 678 // interpretation produce potentially overlapping ranges. 679 bool sne = isStaticallyTrue(intrange::CmpPredicate::slt, lhs, rhs) || 680 isStaticallyTrue(intrange::CmpPredicate::sgt, lhs, rhs); 681 bool une = isStaticallyTrue(intrange::CmpPredicate::ult, lhs, rhs) || 682 isStaticallyTrue(intrange::CmpPredicate::ugt, lhs, rhs); 683 return sne && une; 684 } 685 } 686 return false; 687 } 688 689 std::optional<bool> mlir::intrange::evaluatePred(CmpPredicate pred, 690 const ConstantIntRanges &lhs, 691 const ConstantIntRanges &rhs) { 692 if (isStaticallyTrue(pred, lhs, rhs)) 693 return true; 694 if (isStaticallyTrue(invertPredicate(pred), lhs, rhs)) 695 return false; 696 return std::nullopt; 697 } 698