1 //===- IndexOps.cpp - Index operation definitions --------------------------==// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Index/IR/IndexOps.h" 10 #include "mlir/Dialect/Index/IR/IndexAttrs.h" 11 #include "mlir/Dialect/Index/IR/IndexDialect.h" 12 #include "mlir/IR/Builders.h" 13 #include "mlir/IR/Matchers.h" 14 #include "mlir/IR/OpImplementation.h" 15 #include "mlir/IR/PatternMatch.h" 16 #include "mlir/Interfaces/Utils/InferIntRangeCommon.h" 17 #include "llvm/ADT/SmallString.h" 18 #include "llvm/ADT/TypeSwitch.h" 19 20 using namespace mlir; 21 using namespace mlir::index; 22 23 //===----------------------------------------------------------------------===// 24 // IndexDialect 25 //===----------------------------------------------------------------------===// 26 27 void IndexDialect::registerOperations() { 28 addOperations< 29 #define GET_OP_LIST 30 #include "mlir/Dialect/Index/IR/IndexOps.cpp.inc" 31 >(); 32 } 33 34 Operation *IndexDialect::materializeConstant(OpBuilder &b, Attribute value, 35 Type type, Location loc) { 36 // Materialize bool constants as `i1`. 37 if (auto boolValue = dyn_cast<BoolAttr>(value)) { 38 if (!type.isSignlessInteger(1)) 39 return nullptr; 40 return b.create<BoolConstantOp>(loc, type, boolValue); 41 } 42 43 // Materialize integer attributes as `index`. 44 if (auto indexValue = dyn_cast<IntegerAttr>(value)) { 45 if (!llvm::isa<IndexType>(indexValue.getType()) || 46 !llvm::isa<IndexType>(type)) 47 return nullptr; 48 assert(indexValue.getValue().getBitWidth() == 49 IndexType::kInternalStorageBitWidth); 50 return b.create<ConstantOp>(loc, indexValue); 51 } 52 53 return nullptr; 54 } 55 56 //===----------------------------------------------------------------------===// 57 // Fold Utilities 58 //===----------------------------------------------------------------------===// 59 60 /// Fold an index operation irrespective of the target bitwidth. The 61 /// operation must satisfy the property: 62 /// 63 /// ``` 64 /// trunc(f(a, b)) = f(trunc(a), trunc(b)) 65 /// ``` 66 /// 67 /// For all values of `a` and `b`. The function accepts a lambda that computes 68 /// the integer result, which in turn must satisfy the above property. 69 static OpFoldResult foldBinaryOpUnchecked( 70 ArrayRef<Attribute> operands, 71 function_ref<std::optional<APInt>(const APInt &, const APInt &)> 72 calculate) { 73 assert(operands.size() == 2 && "binary operation expected 2 operands"); 74 auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]); 75 auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]); 76 if (!lhs || !rhs) 77 return {}; 78 79 std::optional<APInt> result = calculate(lhs.getValue(), rhs.getValue()); 80 if (!result) 81 return {}; 82 assert(result->trunc(32) == 83 calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32))); 84 return IntegerAttr::get(IndexType::get(lhs.getContext()), *result); 85 } 86 87 /// Fold an index operation only if the truncated 64-bit result matches the 88 /// 32-bit result for operations that don't satisfy the above property. These 89 /// are operations where the upper bits of the operands can affect the lower 90 /// bits of the results. 91 /// 92 /// The function accepts a lambda that computes the integer result in both 93 /// 64-bit and 32-bit. If either call returns `std::nullopt`, the operation is 94 /// not folded. 95 static OpFoldResult foldBinaryOpChecked( 96 ArrayRef<Attribute> operands, 97 function_ref<std::optional<APInt>(const APInt &, const APInt &lhs)> 98 calculate) { 99 assert(operands.size() == 2 && "binary operation expected 2 operands"); 100 auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]); 101 auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]); 102 // Only fold index operands. 103 if (!lhs || !rhs) 104 return {}; 105 106 // Compute the 64-bit result and the 32-bit result. 107 std::optional<APInt> result64 = calculate(lhs.getValue(), rhs.getValue()); 108 if (!result64) 109 return {}; 110 std::optional<APInt> result32 = 111 calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32)); 112 if (!result32) 113 return {}; 114 // Compare the truncated 64-bit result to the 32-bit result. 115 if (result64->trunc(32) != *result32) 116 return {}; 117 // The operation can be folded for these particular operands. 118 return IntegerAttr::get(IndexType::get(lhs.getContext()), *result64); 119 } 120 121 /// Helper for associative and commutative binary ops that can be transformed: 122 /// `x = op(v, c1); y = op(x, c2)` -> `tmp = op(c1, c2); y = op(v, tmp)` 123 /// where c1 and c2 are constants. It is expected that `tmp` will be folded. 124 template <typename BinaryOp> 125 LogicalResult 126 canonicalizeAssociativeCommutativeBinaryOp(BinaryOp op, 127 PatternRewriter &rewriter) { 128 if (!mlir::matchPattern(op.getRhs(), mlir::m_Constant())) 129 return rewriter.notifyMatchFailure(op.getLoc(), "RHS is not a constant"); 130 131 auto lhsOp = op.getLhs().template getDefiningOp<BinaryOp>(); 132 if (!lhsOp) 133 return rewriter.notifyMatchFailure(op.getLoc(), "LHS is not the same BinaryOp"); 134 135 if (!mlir::matchPattern(lhsOp.getRhs(), mlir::m_Constant())) 136 return rewriter.notifyMatchFailure(op.getLoc(), "RHS of LHS op is not a constant"); 137 138 Value c = rewriter.createOrFold<BinaryOp>(op->getLoc(), op.getRhs(), 139 lhsOp.getRhs()); 140 if (c.getDefiningOp<BinaryOp>()) 141 return rewriter.notifyMatchFailure(op.getLoc(), "new BinaryOp was not folded"); 142 143 rewriter.replaceOpWithNewOp<BinaryOp>(op, lhsOp.getLhs(), c); 144 return success(); 145 } 146 147 //===----------------------------------------------------------------------===// 148 // AddOp 149 //===----------------------------------------------------------------------===// 150 151 OpFoldResult AddOp::fold(FoldAdaptor adaptor) { 152 if (OpFoldResult result = foldBinaryOpUnchecked( 153 adaptor.getOperands(), 154 [](const APInt &lhs, const APInt &rhs) { return lhs + rhs; })) 155 return result; 156 157 if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) { 158 // Fold `add(x, 0) -> x`. 159 if (rhs.getValue().isZero()) 160 return getLhs(); 161 } 162 163 return {}; 164 } 165 166 LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) { 167 return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter); 168 } 169 170 //===----------------------------------------------------------------------===// 171 // SubOp 172 //===----------------------------------------------------------------------===// 173 174 OpFoldResult SubOp::fold(FoldAdaptor adaptor) { 175 if (OpFoldResult result = foldBinaryOpUnchecked( 176 adaptor.getOperands(), 177 [](const APInt &lhs, const APInt &rhs) { return lhs - rhs; })) 178 return result; 179 180 if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) { 181 // Fold `sub(x, 0) -> x`. 182 if (rhs.getValue().isZero()) 183 return getLhs(); 184 } 185 186 return {}; 187 } 188 189 //===----------------------------------------------------------------------===// 190 // MulOp 191 //===----------------------------------------------------------------------===// 192 193 OpFoldResult MulOp::fold(FoldAdaptor adaptor) { 194 if (OpFoldResult result = foldBinaryOpUnchecked( 195 adaptor.getOperands(), 196 [](const APInt &lhs, const APInt &rhs) { return lhs * rhs; })) 197 return result; 198 199 if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) { 200 // Fold `mul(x, 1) -> x`. 201 if (rhs.getValue().isOne()) 202 return getLhs(); 203 // Fold `mul(x, 0) -> 0`. 204 if (rhs.getValue().isZero()) 205 return rhs; 206 } 207 208 return {}; 209 } 210 211 LogicalResult MulOp::canonicalize(MulOp op, PatternRewriter &rewriter) { 212 return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter); 213 } 214 215 //===----------------------------------------------------------------------===// 216 // DivSOp 217 //===----------------------------------------------------------------------===// 218 219 OpFoldResult DivSOp::fold(FoldAdaptor adaptor) { 220 return foldBinaryOpChecked( 221 adaptor.getOperands(), 222 [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> { 223 // Don't fold division by zero. 224 if (rhs.isZero()) 225 return std::nullopt; 226 return lhs.sdiv(rhs); 227 }); 228 } 229 230 //===----------------------------------------------------------------------===// 231 // DivUOp 232 //===----------------------------------------------------------------------===// 233 234 OpFoldResult DivUOp::fold(FoldAdaptor adaptor) { 235 return foldBinaryOpChecked( 236 adaptor.getOperands(), 237 [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> { 238 // Don't fold division by zero. 239 if (rhs.isZero()) 240 return std::nullopt; 241 return lhs.udiv(rhs); 242 }); 243 } 244 245 //===----------------------------------------------------------------------===// 246 // CeilDivSOp 247 //===----------------------------------------------------------------------===// 248 249 /// Compute `ceildivs(n, m)` as `x = m > 0 ? -1 : 1` and then 250 /// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`. 251 static std::optional<APInt> calculateCeilDivS(const APInt &n, const APInt &m) { 252 // Don't fold division by zero. 253 if (m.isZero()) 254 return std::nullopt; 255 // Short-circuit the zero case. 256 if (n.isZero()) 257 return n; 258 259 bool mGtZ = m.sgt(0); 260 if (n.sgt(0) != mGtZ) { 261 // If the operands have different signs, compute the negative result. Signed 262 // division overflow is not possible, since if `m == -1`, `n` can be at most 263 // `INT_MAX`, and `-INT_MAX != INT_MIN` in two's complement. 264 return -(-n).sdiv(m); 265 } 266 // Otherwise, compute the positive result. Signed division overflow is not 267 // possible since if `m == -1`, `x` will be `1`. 268 int64_t x = mGtZ ? -1 : 1; 269 return (n + x).sdiv(m) + 1; 270 } 271 272 OpFoldResult CeilDivSOp::fold(FoldAdaptor adaptor) { 273 return foldBinaryOpChecked(adaptor.getOperands(), calculateCeilDivS); 274 } 275 276 //===----------------------------------------------------------------------===// 277 // CeilDivUOp 278 //===----------------------------------------------------------------------===// 279 280 OpFoldResult CeilDivUOp::fold(FoldAdaptor adaptor) { 281 // Compute `ceildivu(n, m)` as `n == 0 ? 0 : (n-1)/m + 1`. 282 return foldBinaryOpChecked( 283 adaptor.getOperands(), 284 [](const APInt &n, const APInt &m) -> std::optional<APInt> { 285 // Don't fold division by zero. 286 if (m.isZero()) 287 return std::nullopt; 288 // Short-circuit the zero case. 289 if (n.isZero()) 290 return n; 291 292 return (n - 1).udiv(m) + 1; 293 }); 294 } 295 296 //===----------------------------------------------------------------------===// 297 // FloorDivSOp 298 //===----------------------------------------------------------------------===// 299 300 /// Compute `floordivs(n, m)` as `x = m < 0 ? 1 : -1` and then 301 /// `n*m < 0 ? -1 - (x-n)/m : n/m`. 302 static std::optional<APInt> calculateFloorDivS(const APInt &n, const APInt &m) { 303 // Don't fold division by zero. 304 if (m.isZero()) 305 return std::nullopt; 306 // Short-circuit the zero case. 307 if (n.isZero()) 308 return n; 309 310 bool mLtZ = m.slt(0); 311 if (n.slt(0) == mLtZ) { 312 // If the operands have the same sign, compute the positive result. 313 return n.sdiv(m); 314 } 315 // If the operands have different signs, compute the negative result. Signed 316 // division overflow is not possible since if `m == -1`, `x` will be 1 and 317 // `n` can be at most `INT_MAX`. 318 int64_t x = mLtZ ? 1 : -1; 319 return -1 - (x - n).sdiv(m); 320 } 321 322 OpFoldResult FloorDivSOp::fold(FoldAdaptor adaptor) { 323 return foldBinaryOpChecked(adaptor.getOperands(), calculateFloorDivS); 324 } 325 326 //===----------------------------------------------------------------------===// 327 // RemSOp 328 //===----------------------------------------------------------------------===// 329 330 OpFoldResult RemSOp::fold(FoldAdaptor adaptor) { 331 return foldBinaryOpChecked( 332 adaptor.getOperands(), 333 [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> { 334 // Don't fold division by zero. 335 if (rhs.isZero()) 336 return std::nullopt; 337 return lhs.srem(rhs); 338 }); 339 } 340 341 //===----------------------------------------------------------------------===// 342 // RemUOp 343 //===----------------------------------------------------------------------===// 344 345 OpFoldResult RemUOp::fold(FoldAdaptor adaptor) { 346 return foldBinaryOpChecked( 347 adaptor.getOperands(), 348 [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> { 349 // Don't fold division by zero. 350 if (rhs.isZero()) 351 return std::nullopt; 352 return lhs.urem(rhs); 353 }); 354 } 355 356 //===----------------------------------------------------------------------===// 357 // MaxSOp 358 //===----------------------------------------------------------------------===// 359 360 OpFoldResult MaxSOp::fold(FoldAdaptor adaptor) { 361 return foldBinaryOpChecked(adaptor.getOperands(), 362 [](const APInt &lhs, const APInt &rhs) { 363 return lhs.sgt(rhs) ? lhs : rhs; 364 }); 365 } 366 367 LogicalResult MaxSOp::canonicalize(MaxSOp op, PatternRewriter &rewriter) { 368 return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter); 369 } 370 371 //===----------------------------------------------------------------------===// 372 // MaxUOp 373 //===----------------------------------------------------------------------===// 374 375 OpFoldResult MaxUOp::fold(FoldAdaptor adaptor) { 376 return foldBinaryOpChecked(adaptor.getOperands(), 377 [](const APInt &lhs, const APInt &rhs) { 378 return lhs.ugt(rhs) ? lhs : rhs; 379 }); 380 } 381 382 LogicalResult MaxUOp::canonicalize(MaxUOp op, PatternRewriter &rewriter) { 383 return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter); 384 } 385 386 //===----------------------------------------------------------------------===// 387 // MinSOp 388 //===----------------------------------------------------------------------===// 389 390 OpFoldResult MinSOp::fold(FoldAdaptor adaptor) { 391 return foldBinaryOpChecked(adaptor.getOperands(), 392 [](const APInt &lhs, const APInt &rhs) { 393 return lhs.slt(rhs) ? lhs : rhs; 394 }); 395 } 396 397 LogicalResult MinSOp::canonicalize(MinSOp op, PatternRewriter &rewriter) { 398 return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter); 399 } 400 401 //===----------------------------------------------------------------------===// 402 // MinUOp 403 //===----------------------------------------------------------------------===// 404 405 OpFoldResult MinUOp::fold(FoldAdaptor adaptor) { 406 return foldBinaryOpChecked(adaptor.getOperands(), 407 [](const APInt &lhs, const APInt &rhs) { 408 return lhs.ult(rhs) ? lhs : rhs; 409 }); 410 } 411 412 LogicalResult MinUOp::canonicalize(MinUOp op, PatternRewriter &rewriter) { 413 return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter); 414 } 415 416 //===----------------------------------------------------------------------===// 417 // ShlOp 418 //===----------------------------------------------------------------------===// 419 420 OpFoldResult ShlOp::fold(FoldAdaptor adaptor) { 421 return foldBinaryOpUnchecked( 422 adaptor.getOperands(), 423 [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> { 424 // We cannot fold if the RHS is greater than or equal to 32 because 425 // this would be UB in 32-bit systems but not on 64-bit systems. RHS is 426 // already treated as unsigned. 427 if (rhs.uge(32)) 428 return {}; 429 return lhs << rhs; 430 }); 431 } 432 433 //===----------------------------------------------------------------------===// 434 // ShrSOp 435 //===----------------------------------------------------------------------===// 436 437 OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) { 438 return foldBinaryOpChecked( 439 adaptor.getOperands(), 440 [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> { 441 // Don't fold if RHS is greater than or equal to 32. 442 if (rhs.uge(32)) 443 return {}; 444 return lhs.ashr(rhs); 445 }); 446 } 447 448 //===----------------------------------------------------------------------===// 449 // ShrUOp 450 //===----------------------------------------------------------------------===// 451 452 OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) { 453 return foldBinaryOpChecked( 454 adaptor.getOperands(), 455 [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> { 456 // Don't fold if RHS is greater than or equal to 32. 457 if (rhs.uge(32)) 458 return {}; 459 return lhs.lshr(rhs); 460 }); 461 } 462 463 //===----------------------------------------------------------------------===// 464 // AndOp 465 //===----------------------------------------------------------------------===// 466 467 OpFoldResult AndOp::fold(FoldAdaptor adaptor) { 468 return foldBinaryOpUnchecked( 469 adaptor.getOperands(), 470 [](const APInt &lhs, const APInt &rhs) { return lhs & rhs; }); 471 } 472 473 LogicalResult AndOp::canonicalize(AndOp op, PatternRewriter &rewriter) { 474 return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter); 475 } 476 477 //===----------------------------------------------------------------------===// 478 // OrOp 479 //===----------------------------------------------------------------------===// 480 481 OpFoldResult OrOp::fold(FoldAdaptor adaptor) { 482 return foldBinaryOpUnchecked( 483 adaptor.getOperands(), 484 [](const APInt &lhs, const APInt &rhs) { return lhs | rhs; }); 485 } 486 487 LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) { 488 return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter); 489 } 490 491 //===----------------------------------------------------------------------===// 492 // XOrOp 493 //===----------------------------------------------------------------------===// 494 495 OpFoldResult XOrOp::fold(FoldAdaptor adaptor) { 496 return foldBinaryOpUnchecked( 497 adaptor.getOperands(), 498 [](const APInt &lhs, const APInt &rhs) { return lhs ^ rhs; }); 499 } 500 501 LogicalResult XOrOp::canonicalize(XOrOp op, PatternRewriter &rewriter) { 502 return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter); 503 } 504 505 //===----------------------------------------------------------------------===// 506 // CastSOp 507 //===----------------------------------------------------------------------===// 508 509 static OpFoldResult 510 foldCastOp(Attribute input, Type type, 511 function_ref<APInt(const APInt &, unsigned)> extFn, 512 function_ref<APInt(const APInt &, unsigned)> extOrTruncFn) { 513 auto attr = dyn_cast_if_present<IntegerAttr>(input); 514 if (!attr) 515 return {}; 516 const APInt &value = attr.getValue(); 517 518 if (isa<IndexType>(type)) { 519 // When casting to an index type, perform the cast assuming a 64-bit target. 520 // The result can be truncated to 32 bits as needed and always be correct. 521 // This is because `cast32(cast64(value)) == cast32(value)`. 522 APInt result = extOrTruncFn(value, 64); 523 return IntegerAttr::get(type, result); 524 } 525 526 // When casting from an index type, we must ensure the results respect 527 // `cast_t(value) == cast_t(trunc32(value))`. 528 auto intType = cast<IntegerType>(type); 529 unsigned width = intType.getWidth(); 530 531 // If the result type is at most 32 bits, then the cast can always be folded 532 // because it is always a truncation. 533 if (width <= 32) { 534 APInt result = value.trunc(width); 535 return IntegerAttr::get(type, result); 536 } 537 538 // If the result type is at least 64 bits, then the cast is always a 539 // extension. The results will differ if `trunc32(value) != value)`. 540 if (width >= 64) { 541 if (extFn(value.trunc(32), 64) != value) 542 return {}; 543 APInt result = extFn(value, width); 544 return IntegerAttr::get(type, result); 545 } 546 547 // Otherwise, we just have to check the property directly. 548 APInt result = value.trunc(width); 549 if (result != extFn(value.trunc(32), width)) 550 return {}; 551 return IntegerAttr::get(type, result); 552 } 553 554 bool CastSOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) { 555 return llvm::isa<IndexType>(lhsTypes.front()) != 556 llvm::isa<IndexType>(rhsTypes.front()); 557 } 558 559 OpFoldResult CastSOp::fold(FoldAdaptor adaptor) { 560 return foldCastOp( 561 adaptor.getInput(), getType(), 562 [](const APInt &x, unsigned width) { return x.sext(width); }, 563 [](const APInt &x, unsigned width) { return x.sextOrTrunc(width); }); 564 } 565 566 //===----------------------------------------------------------------------===// 567 // CastUOp 568 //===----------------------------------------------------------------------===// 569 570 bool CastUOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) { 571 return llvm::isa<IndexType>(lhsTypes.front()) != 572 llvm::isa<IndexType>(rhsTypes.front()); 573 } 574 575 OpFoldResult CastUOp::fold(FoldAdaptor adaptor) { 576 return foldCastOp( 577 adaptor.getInput(), getType(), 578 [](const APInt &x, unsigned width) { return x.zext(width); }, 579 [](const APInt &x, unsigned width) { return x.zextOrTrunc(width); }); 580 } 581 582 //===----------------------------------------------------------------------===// 583 // CmpOp 584 //===----------------------------------------------------------------------===// 585 586 /// Compare two integers according to the comparison predicate. 587 bool compareIndices(const APInt &lhs, const APInt &rhs, 588 IndexCmpPredicate pred) { 589 switch (pred) { 590 case IndexCmpPredicate::EQ: 591 return lhs.eq(rhs); 592 case IndexCmpPredicate::NE: 593 return lhs.ne(rhs); 594 case IndexCmpPredicate::SGE: 595 return lhs.sge(rhs); 596 case IndexCmpPredicate::SGT: 597 return lhs.sgt(rhs); 598 case IndexCmpPredicate::SLE: 599 return lhs.sle(rhs); 600 case IndexCmpPredicate::SLT: 601 return lhs.slt(rhs); 602 case IndexCmpPredicate::UGE: 603 return lhs.uge(rhs); 604 case IndexCmpPredicate::UGT: 605 return lhs.ugt(rhs); 606 case IndexCmpPredicate::ULE: 607 return lhs.ule(rhs); 608 case IndexCmpPredicate::ULT: 609 return lhs.ult(rhs); 610 } 611 llvm_unreachable("unhandled IndexCmpPredicate predicate"); 612 } 613 614 /// `cmp(max/min(x, cstA), cstB)` can be folded to a constant depending on the 615 /// values of `cstA` and `cstB`, the max or min operation, and the comparison 616 /// predicate. Check whether the value folds in both 32-bit and 64-bit 617 /// arithmetic and to the same value. 618 static std::optional<bool> foldCmpOfMaxOrMin(Operation *lhsOp, 619 const APInt &cstA, 620 const APInt &cstB, unsigned width, 621 IndexCmpPredicate pred) { 622 ConstantIntRanges lhsRange = TypeSwitch<Operation *, ConstantIntRanges>(lhsOp) 623 .Case([&](MinSOp op) { 624 return ConstantIntRanges::fromSigned( 625 APInt::getSignedMinValue(width), cstA); 626 }) 627 .Case([&](MinUOp op) { 628 return ConstantIntRanges::fromUnsigned( 629 APInt::getMinValue(width), cstA); 630 }) 631 .Case([&](MaxSOp op) { 632 return ConstantIntRanges::fromSigned( 633 cstA, APInt::getSignedMaxValue(width)); 634 }) 635 .Case([&](MaxUOp op) { 636 return ConstantIntRanges::fromUnsigned( 637 cstA, APInt::getMaxValue(width)); 638 }); 639 return intrange::evaluatePred(static_cast<intrange::CmpPredicate>(pred), 640 lhsRange, ConstantIntRanges::constant(cstB)); 641 } 642 643 /// Return the result of `cmp(pred, x, x)` 644 static bool compareSameArgs(IndexCmpPredicate pred) { 645 switch (pred) { 646 case IndexCmpPredicate::EQ: 647 case IndexCmpPredicate::SGE: 648 case IndexCmpPredicate::SLE: 649 case IndexCmpPredicate::UGE: 650 case IndexCmpPredicate::ULE: 651 return true; 652 case IndexCmpPredicate::NE: 653 case IndexCmpPredicate::SGT: 654 case IndexCmpPredicate::SLT: 655 case IndexCmpPredicate::UGT: 656 case IndexCmpPredicate::ULT: 657 return false; 658 } 659 llvm_unreachable("unknown predicate in compareSameArgs"); 660 } 661 662 OpFoldResult CmpOp::fold(FoldAdaptor adaptor) { 663 // Attempt to fold if both inputs are constant. 664 auto lhs = dyn_cast_if_present<IntegerAttr>(adaptor.getLhs()); 665 auto rhs = dyn_cast_if_present<IntegerAttr>(adaptor.getRhs()); 666 if (lhs && rhs) { 667 // Perform the comparison in 64-bit and 32-bit. 668 bool result64 = compareIndices(lhs.getValue(), rhs.getValue(), getPred()); 669 bool result32 = compareIndices(lhs.getValue().trunc(32), 670 rhs.getValue().trunc(32), getPred()); 671 if (result64 == result32) 672 return BoolAttr::get(getContext(), result64); 673 } 674 675 // Fold `cmp(max/min(x, cstA), cstB)`. 676 Operation *lhsOp = getLhs().getDefiningOp(); 677 IntegerAttr cstA; 678 if (isa_and_nonnull<MinSOp, MinUOp, MaxSOp, MaxUOp>(lhsOp) && 679 matchPattern(lhsOp->getOperand(1), m_Constant(&cstA)) && rhs) { 680 std::optional<bool> result64 = foldCmpOfMaxOrMin( 681 lhsOp, cstA.getValue(), rhs.getValue(), 64, getPred()); 682 std::optional<bool> result32 = 683 foldCmpOfMaxOrMin(lhsOp, cstA.getValue().trunc(32), 684 rhs.getValue().trunc(32), 32, getPred()); 685 // Fold if the 32-bit and 64-bit results are the same. 686 if (result64 && result32 && *result64 == *result32) 687 return BoolAttr::get(getContext(), *result64); 688 } 689 690 // Fold `cmp(x, x)` 691 if (getLhs() == getRhs()) 692 return BoolAttr::get(getContext(), compareSameArgs(getPred())); 693 694 return {}; 695 } 696 697 /// Canonicalize 698 /// `x - y cmp 0` to `x cmp y`. or `x - y cmp 0` to `x cmp y`. 699 /// `0 cmp x - y` to `y cmp x`. or `0 cmp x - y` to `y cmp x`. 700 LogicalResult CmpOp::canonicalize(CmpOp op, PatternRewriter &rewriter) { 701 IntegerAttr cmpRhs; 702 IntegerAttr cmpLhs; 703 704 bool rhsIsZero = matchPattern(op.getRhs(), m_Constant(&cmpRhs)) && 705 cmpRhs.getValue().isZero(); 706 bool lhsIsZero = matchPattern(op.getLhs(), m_Constant(&cmpLhs)) && 707 cmpLhs.getValue().isZero(); 708 if (!rhsIsZero && !lhsIsZero) 709 return rewriter.notifyMatchFailure(op.getLoc(), 710 "cmp is not comparing something with 0"); 711 SubOp subOp = rhsIsZero ? op.getLhs().getDefiningOp<index::SubOp>() 712 : op.getRhs().getDefiningOp<index::SubOp>(); 713 if (!subOp) 714 return rewriter.notifyMatchFailure( 715 op.getLoc(), "non-zero operand is not a result of subtraction"); 716 717 index::CmpOp newCmp; 718 if (rhsIsZero) 719 newCmp = rewriter.create<index::CmpOp>(op.getLoc(), op.getPred(), 720 subOp.getLhs(), subOp.getRhs()); 721 else 722 newCmp = rewriter.create<index::CmpOp>(op.getLoc(), op.getPred(), 723 subOp.getRhs(), subOp.getLhs()); 724 rewriter.replaceOp(op, newCmp); 725 return success(); 726 } 727 728 //===----------------------------------------------------------------------===// 729 // ConstantOp 730 //===----------------------------------------------------------------------===// 731 732 void ConstantOp::getAsmResultNames( 733 function_ref<void(Value, StringRef)> setNameFn) { 734 SmallString<32> specialNameBuffer; 735 llvm::raw_svector_ostream specialName(specialNameBuffer); 736 specialName << "idx" << getValueAttr().getValue(); 737 setNameFn(getResult(), specialName.str()); 738 } 739 740 OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValueAttr(); } 741 742 void ConstantOp::build(OpBuilder &b, OperationState &state, int64_t value) { 743 build(b, state, b.getIndexType(), b.getIndexAttr(value)); 744 } 745 746 //===----------------------------------------------------------------------===// 747 // BoolConstantOp 748 //===----------------------------------------------------------------------===// 749 750 OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) { 751 return getValueAttr(); 752 } 753 754 void BoolConstantOp::getAsmResultNames( 755 function_ref<void(Value, StringRef)> setNameFn) { 756 setNameFn(getResult(), getValue() ? "true" : "false"); 757 } 758 759 //===----------------------------------------------------------------------===// 760 // ODS-Generated Definitions 761 //===----------------------------------------------------------------------===// 762 763 #define GET_OP_CLASSES 764 #include "mlir/Dialect/Index/IR/IndexOps.cpp.inc" 765