xref: /llvm-project/mlir/lib/Dialect/Index/IR/IndexOps.cpp (revision 795b4efad0259cbf03fc98e3045621916328ce57)
1 //===- IndexOps.cpp - Index operation definitions --------------------------==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Index/IR/IndexOps.h"
10 #include "mlir/Dialect/Index/IR/IndexAttrs.h"
11 #include "mlir/Dialect/Index/IR/IndexDialect.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/Matchers.h"
14 #include "mlir/IR/OpImplementation.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
17 #include "llvm/ADT/SmallString.h"
18 #include "llvm/ADT/TypeSwitch.h"
19 
20 using namespace mlir;
21 using namespace mlir::index;
22 
23 //===----------------------------------------------------------------------===//
24 // IndexDialect
25 //===----------------------------------------------------------------------===//
26 
27 void IndexDialect::registerOperations() {
28   addOperations<
29 #define GET_OP_LIST
30 #include "mlir/Dialect/Index/IR/IndexOps.cpp.inc"
31       >();
32 }
33 
34 Operation *IndexDialect::materializeConstant(OpBuilder &b, Attribute value,
35                                              Type type, Location loc) {
36   // Materialize bool constants as `i1`.
37   if (auto boolValue = dyn_cast<BoolAttr>(value)) {
38     if (!type.isSignlessInteger(1))
39       return nullptr;
40     return b.create<BoolConstantOp>(loc, type, boolValue);
41   }
42 
43   // Materialize integer attributes as `index`.
44   if (auto indexValue = dyn_cast<IntegerAttr>(value)) {
45     if (!llvm::isa<IndexType>(indexValue.getType()) ||
46         !llvm::isa<IndexType>(type))
47       return nullptr;
48     assert(indexValue.getValue().getBitWidth() ==
49            IndexType::kInternalStorageBitWidth);
50     return b.create<ConstantOp>(loc, indexValue);
51   }
52 
53   return nullptr;
54 }
55 
56 //===----------------------------------------------------------------------===//
57 // Fold Utilities
58 //===----------------------------------------------------------------------===//
59 
60 /// Fold an index operation irrespective of the target bitwidth. The
61 /// operation must satisfy the property:
62 ///
63 /// ```
64 /// trunc(f(a, b)) = f(trunc(a), trunc(b))
65 /// ```
66 ///
67 /// For all values of `a` and `b`. The function accepts a lambda that computes
68 /// the integer result, which in turn must satisfy the above property.
69 static OpFoldResult foldBinaryOpUnchecked(
70     ArrayRef<Attribute> operands,
71     function_ref<std::optional<APInt>(const APInt &, const APInt &)>
72         calculate) {
73   assert(operands.size() == 2 && "binary operation expected 2 operands");
74   auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]);
75   auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]);
76   if (!lhs || !rhs)
77     return {};
78 
79   std::optional<APInt> result = calculate(lhs.getValue(), rhs.getValue());
80   if (!result)
81     return {};
82   assert(result->trunc(32) ==
83          calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32)));
84   return IntegerAttr::get(IndexType::get(lhs.getContext()), *result);
85 }
86 
87 /// Fold an index operation only if the truncated 64-bit result matches the
88 /// 32-bit result for operations that don't satisfy the above property. These
89 /// are operations where the upper bits of the operands can affect the lower
90 /// bits of the results.
91 ///
92 /// The function accepts a lambda that computes the integer result in both
93 /// 64-bit and 32-bit. If either call returns `std::nullopt`, the operation is
94 /// not folded.
95 static OpFoldResult foldBinaryOpChecked(
96     ArrayRef<Attribute> operands,
97     function_ref<std::optional<APInt>(const APInt &, const APInt &lhs)>
98         calculate) {
99   assert(operands.size() == 2 && "binary operation expected 2 operands");
100   auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]);
101   auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]);
102   // Only fold index operands.
103   if (!lhs || !rhs)
104     return {};
105 
106   // Compute the 64-bit result and the 32-bit result.
107   std::optional<APInt> result64 = calculate(lhs.getValue(), rhs.getValue());
108   if (!result64)
109     return {};
110   std::optional<APInt> result32 =
111       calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32));
112   if (!result32)
113     return {};
114   // Compare the truncated 64-bit result to the 32-bit result.
115   if (result64->trunc(32) != *result32)
116     return {};
117   // The operation can be folded for these particular operands.
118   return IntegerAttr::get(IndexType::get(lhs.getContext()), *result64);
119 }
120 
121 /// Helper for associative and commutative binary ops that can be transformed:
122 /// `x = op(v, c1); y = op(x, c2)` -> `tmp = op(c1, c2); y = op(v, tmp)`
123 /// where c1 and c2 are constants. It is expected that `tmp` will be folded.
124 template <typename BinaryOp>
125 LogicalResult
126 canonicalizeAssociativeCommutativeBinaryOp(BinaryOp op,
127                                            PatternRewriter &rewriter) {
128   if (!mlir::matchPattern(op.getRhs(), mlir::m_Constant()))
129     return rewriter.notifyMatchFailure(op.getLoc(), "RHS is not a constant");
130 
131   auto lhsOp = op.getLhs().template getDefiningOp<BinaryOp>();
132   if (!lhsOp)
133     return rewriter.notifyMatchFailure(op.getLoc(), "LHS is not the same BinaryOp");
134 
135   if (!mlir::matchPattern(lhsOp.getRhs(), mlir::m_Constant()))
136     return rewriter.notifyMatchFailure(op.getLoc(), "RHS of LHS op is not a constant");
137 
138   Value c = rewriter.createOrFold<BinaryOp>(op->getLoc(), op.getRhs(),
139                                            lhsOp.getRhs());
140   if (c.getDefiningOp<BinaryOp>())
141     return rewriter.notifyMatchFailure(op.getLoc(), "new BinaryOp was not folded");
142 
143   rewriter.replaceOpWithNewOp<BinaryOp>(op, lhsOp.getLhs(), c);
144   return success();
145 }
146 
147 //===----------------------------------------------------------------------===//
148 // AddOp
149 //===----------------------------------------------------------------------===//
150 
151 OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
152   if (OpFoldResult result = foldBinaryOpUnchecked(
153           adaptor.getOperands(),
154           [](const APInt &lhs, const APInt &rhs) { return lhs + rhs; }))
155     return result;
156 
157   if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
158     // Fold `add(x, 0) -> x`.
159     if (rhs.getValue().isZero())
160       return getLhs();
161   }
162 
163   return {};
164 }
165 
166 LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) {
167   return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
168 }
169 
170 //===----------------------------------------------------------------------===//
171 // SubOp
172 //===----------------------------------------------------------------------===//
173 
174 OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
175   if (OpFoldResult result = foldBinaryOpUnchecked(
176           adaptor.getOperands(),
177           [](const APInt &lhs, const APInt &rhs) { return lhs - rhs; }))
178     return result;
179 
180   if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
181     // Fold `sub(x, 0) -> x`.
182     if (rhs.getValue().isZero())
183       return getLhs();
184   }
185 
186   return {};
187 }
188 
189 //===----------------------------------------------------------------------===//
190 // MulOp
191 //===----------------------------------------------------------------------===//
192 
193 OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
194   if (OpFoldResult result = foldBinaryOpUnchecked(
195           adaptor.getOperands(),
196           [](const APInt &lhs, const APInt &rhs) { return lhs * rhs; }))
197     return result;
198 
199   if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
200     // Fold `mul(x, 1) -> x`.
201     if (rhs.getValue().isOne())
202       return getLhs();
203     // Fold `mul(x, 0) -> 0`.
204     if (rhs.getValue().isZero())
205       return rhs;
206   }
207 
208   return {};
209 }
210 
211 LogicalResult MulOp::canonicalize(MulOp op, PatternRewriter &rewriter) {
212   return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
213 }
214 
215 //===----------------------------------------------------------------------===//
216 // DivSOp
217 //===----------------------------------------------------------------------===//
218 
219 OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {
220   return foldBinaryOpChecked(
221       adaptor.getOperands(),
222       [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
223         // Don't fold division by zero.
224         if (rhs.isZero())
225           return std::nullopt;
226         return lhs.sdiv(rhs);
227       });
228 }
229 
230 //===----------------------------------------------------------------------===//
231 // DivUOp
232 //===----------------------------------------------------------------------===//
233 
234 OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {
235   return foldBinaryOpChecked(
236       adaptor.getOperands(),
237       [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
238         // Don't fold division by zero.
239         if (rhs.isZero())
240           return std::nullopt;
241         return lhs.udiv(rhs);
242       });
243 }
244 
245 //===----------------------------------------------------------------------===//
246 // CeilDivSOp
247 //===----------------------------------------------------------------------===//
248 
249 /// Compute `ceildivs(n, m)` as `x = m > 0 ? -1 : 1` and then
250 /// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`.
251 static std::optional<APInt> calculateCeilDivS(const APInt &n, const APInt &m) {
252   // Don't fold division by zero.
253   if (m.isZero())
254     return std::nullopt;
255   // Short-circuit the zero case.
256   if (n.isZero())
257     return n;
258 
259   bool mGtZ = m.sgt(0);
260   if (n.sgt(0) != mGtZ) {
261     // If the operands have different signs, compute the negative result. Signed
262     // division overflow is not possible, since if `m == -1`, `n` can be at most
263     // `INT_MAX`, and `-INT_MAX != INT_MIN` in two's complement.
264     return -(-n).sdiv(m);
265   }
266   // Otherwise, compute the positive result. Signed division overflow is not
267   // possible since if `m == -1`, `x` will be `1`.
268   int64_t x = mGtZ ? -1 : 1;
269   return (n + x).sdiv(m) + 1;
270 }
271 
272 OpFoldResult CeilDivSOp::fold(FoldAdaptor adaptor) {
273   return foldBinaryOpChecked(adaptor.getOperands(), calculateCeilDivS);
274 }
275 
276 //===----------------------------------------------------------------------===//
277 // CeilDivUOp
278 //===----------------------------------------------------------------------===//
279 
280 OpFoldResult CeilDivUOp::fold(FoldAdaptor adaptor) {
281   // Compute `ceildivu(n, m)` as `n == 0 ? 0 : (n-1)/m + 1`.
282   return foldBinaryOpChecked(
283       adaptor.getOperands(),
284       [](const APInt &n, const APInt &m) -> std::optional<APInt> {
285         // Don't fold division by zero.
286         if (m.isZero())
287           return std::nullopt;
288         // Short-circuit the zero case.
289         if (n.isZero())
290           return n;
291 
292         return (n - 1).udiv(m) + 1;
293       });
294 }
295 
296 //===----------------------------------------------------------------------===//
297 // FloorDivSOp
298 //===----------------------------------------------------------------------===//
299 
300 /// Compute `floordivs(n, m)` as `x = m < 0 ? 1 : -1` and then
301 /// `n*m < 0 ? -1 - (x-n)/m : n/m`.
302 static std::optional<APInt> calculateFloorDivS(const APInt &n, const APInt &m) {
303   // Don't fold division by zero.
304   if (m.isZero())
305     return std::nullopt;
306   // Short-circuit the zero case.
307   if (n.isZero())
308     return n;
309 
310   bool mLtZ = m.slt(0);
311   if (n.slt(0) == mLtZ) {
312     // If the operands have the same sign, compute the positive result.
313     return n.sdiv(m);
314   }
315   // If the operands have different signs, compute the negative result. Signed
316   // division overflow is not possible since if `m == -1`, `x` will be 1 and
317   // `n` can be at most `INT_MAX`.
318   int64_t x = mLtZ ? 1 : -1;
319   return -1 - (x - n).sdiv(m);
320 }
321 
322 OpFoldResult FloorDivSOp::fold(FoldAdaptor adaptor) {
323   return foldBinaryOpChecked(adaptor.getOperands(), calculateFloorDivS);
324 }
325 
326 //===----------------------------------------------------------------------===//
327 // RemSOp
328 //===----------------------------------------------------------------------===//
329 
330 OpFoldResult RemSOp::fold(FoldAdaptor adaptor) {
331   return foldBinaryOpChecked(
332       adaptor.getOperands(),
333       [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
334         // Don't fold division by zero.
335         if (rhs.isZero())
336           return std::nullopt;
337         return lhs.srem(rhs);
338       });
339 }
340 
341 //===----------------------------------------------------------------------===//
342 // RemUOp
343 //===----------------------------------------------------------------------===//
344 
345 OpFoldResult RemUOp::fold(FoldAdaptor adaptor) {
346   return foldBinaryOpChecked(
347       adaptor.getOperands(),
348       [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
349         // Don't fold division by zero.
350         if (rhs.isZero())
351           return std::nullopt;
352         return lhs.urem(rhs);
353       });
354 }
355 
356 //===----------------------------------------------------------------------===//
357 // MaxSOp
358 //===----------------------------------------------------------------------===//
359 
360 OpFoldResult MaxSOp::fold(FoldAdaptor adaptor) {
361   return foldBinaryOpChecked(adaptor.getOperands(),
362                              [](const APInt &lhs, const APInt &rhs) {
363                                return lhs.sgt(rhs) ? lhs : rhs;
364                              });
365 }
366 
367 LogicalResult MaxSOp::canonicalize(MaxSOp op, PatternRewriter &rewriter) {
368   return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
369 }
370 
371 //===----------------------------------------------------------------------===//
372 // MaxUOp
373 //===----------------------------------------------------------------------===//
374 
375 OpFoldResult MaxUOp::fold(FoldAdaptor adaptor) {
376   return foldBinaryOpChecked(adaptor.getOperands(),
377                              [](const APInt &lhs, const APInt &rhs) {
378                                return lhs.ugt(rhs) ? lhs : rhs;
379                              });
380 }
381 
382 LogicalResult MaxUOp::canonicalize(MaxUOp op, PatternRewriter &rewriter) {
383   return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
384 }
385 
386 //===----------------------------------------------------------------------===//
387 // MinSOp
388 //===----------------------------------------------------------------------===//
389 
390 OpFoldResult MinSOp::fold(FoldAdaptor adaptor) {
391   return foldBinaryOpChecked(adaptor.getOperands(),
392                              [](const APInt &lhs, const APInt &rhs) {
393                                return lhs.slt(rhs) ? lhs : rhs;
394                              });
395 }
396 
397 LogicalResult MinSOp::canonicalize(MinSOp op, PatternRewriter &rewriter) {
398   return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
399 }
400 
401 //===----------------------------------------------------------------------===//
402 // MinUOp
403 //===----------------------------------------------------------------------===//
404 
405 OpFoldResult MinUOp::fold(FoldAdaptor adaptor) {
406   return foldBinaryOpChecked(adaptor.getOperands(),
407                              [](const APInt &lhs, const APInt &rhs) {
408                                return lhs.ult(rhs) ? lhs : rhs;
409                              });
410 }
411 
412 LogicalResult MinUOp::canonicalize(MinUOp op, PatternRewriter &rewriter) {
413   return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
414 }
415 
416 //===----------------------------------------------------------------------===//
417 // ShlOp
418 //===----------------------------------------------------------------------===//
419 
420 OpFoldResult ShlOp::fold(FoldAdaptor adaptor) {
421   return foldBinaryOpUnchecked(
422       adaptor.getOperands(),
423       [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
424         // We cannot fold if the RHS is greater than or equal to 32 because
425         // this would be UB in 32-bit systems but not on 64-bit systems. RHS is
426         // already treated as unsigned.
427         if (rhs.uge(32))
428           return {};
429         return lhs << rhs;
430       });
431 }
432 
433 //===----------------------------------------------------------------------===//
434 // ShrSOp
435 //===----------------------------------------------------------------------===//
436 
437 OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) {
438   return foldBinaryOpChecked(
439       adaptor.getOperands(),
440       [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
441         // Don't fold if RHS is greater than or equal to 32.
442         if (rhs.uge(32))
443           return {};
444         return lhs.ashr(rhs);
445       });
446 }
447 
448 //===----------------------------------------------------------------------===//
449 // ShrUOp
450 //===----------------------------------------------------------------------===//
451 
452 OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) {
453   return foldBinaryOpChecked(
454       adaptor.getOperands(),
455       [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
456         // Don't fold if RHS is greater than or equal to 32.
457         if (rhs.uge(32))
458           return {};
459         return lhs.lshr(rhs);
460       });
461 }
462 
463 //===----------------------------------------------------------------------===//
464 // AndOp
465 //===----------------------------------------------------------------------===//
466 
467 OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
468   return foldBinaryOpUnchecked(
469       adaptor.getOperands(),
470       [](const APInt &lhs, const APInt &rhs) { return lhs & rhs; });
471 }
472 
473 LogicalResult AndOp::canonicalize(AndOp op, PatternRewriter &rewriter) {
474   return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
475 }
476 
477 //===----------------------------------------------------------------------===//
478 // OrOp
479 //===----------------------------------------------------------------------===//
480 
481 OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
482   return foldBinaryOpUnchecked(
483       adaptor.getOperands(),
484       [](const APInt &lhs, const APInt &rhs) { return lhs | rhs; });
485 }
486 
487 LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) {
488   return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
489 }
490 
491 //===----------------------------------------------------------------------===//
492 // XOrOp
493 //===----------------------------------------------------------------------===//
494 
495 OpFoldResult XOrOp::fold(FoldAdaptor adaptor) {
496   return foldBinaryOpUnchecked(
497       adaptor.getOperands(),
498       [](const APInt &lhs, const APInt &rhs) { return lhs ^ rhs; });
499 }
500 
501 LogicalResult XOrOp::canonicalize(XOrOp op, PatternRewriter &rewriter) {
502   return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
503 }
504 
505 //===----------------------------------------------------------------------===//
506 // CastSOp
507 //===----------------------------------------------------------------------===//
508 
509 static OpFoldResult
510 foldCastOp(Attribute input, Type type,
511            function_ref<APInt(const APInt &, unsigned)> extFn,
512            function_ref<APInt(const APInt &, unsigned)> extOrTruncFn) {
513   auto attr = dyn_cast_if_present<IntegerAttr>(input);
514   if (!attr)
515     return {};
516   const APInt &value = attr.getValue();
517 
518   if (isa<IndexType>(type)) {
519     // When casting to an index type, perform the cast assuming a 64-bit target.
520     // The result can be truncated to 32 bits as needed and always be correct.
521     // This is because `cast32(cast64(value)) == cast32(value)`.
522     APInt result = extOrTruncFn(value, 64);
523     return IntegerAttr::get(type, result);
524   }
525 
526   // When casting from an index type, we must ensure the results respect
527   // `cast_t(value) == cast_t(trunc32(value))`.
528   auto intType = cast<IntegerType>(type);
529   unsigned width = intType.getWidth();
530 
531   // If the result type is at most 32 bits, then the cast can always be folded
532   // because it is always a truncation.
533   if (width <= 32) {
534     APInt result = value.trunc(width);
535     return IntegerAttr::get(type, result);
536   }
537 
538   // If the result type is at least 64 bits, then the cast is always a
539   // extension. The results will differ if `trunc32(value) != value)`.
540   if (width >= 64) {
541     if (extFn(value.trunc(32), 64) != value)
542       return {};
543     APInt result = extFn(value, width);
544     return IntegerAttr::get(type, result);
545   }
546 
547   // Otherwise, we just have to check the property directly.
548   APInt result = value.trunc(width);
549   if (result != extFn(value.trunc(32), width))
550     return {};
551   return IntegerAttr::get(type, result);
552 }
553 
554 bool CastSOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) {
555   return llvm::isa<IndexType>(lhsTypes.front()) !=
556          llvm::isa<IndexType>(rhsTypes.front());
557 }
558 
559 OpFoldResult CastSOp::fold(FoldAdaptor adaptor) {
560   return foldCastOp(
561       adaptor.getInput(), getType(),
562       [](const APInt &x, unsigned width) { return x.sext(width); },
563       [](const APInt &x, unsigned width) { return x.sextOrTrunc(width); });
564 }
565 
566 //===----------------------------------------------------------------------===//
567 // CastUOp
568 //===----------------------------------------------------------------------===//
569 
570 bool CastUOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) {
571   return llvm::isa<IndexType>(lhsTypes.front()) !=
572          llvm::isa<IndexType>(rhsTypes.front());
573 }
574 
575 OpFoldResult CastUOp::fold(FoldAdaptor adaptor) {
576   return foldCastOp(
577       adaptor.getInput(), getType(),
578       [](const APInt &x, unsigned width) { return x.zext(width); },
579       [](const APInt &x, unsigned width) { return x.zextOrTrunc(width); });
580 }
581 
582 //===----------------------------------------------------------------------===//
583 // CmpOp
584 //===----------------------------------------------------------------------===//
585 
586 /// Compare two integers according to the comparison predicate.
587 bool compareIndices(const APInt &lhs, const APInt &rhs,
588                     IndexCmpPredicate pred) {
589   switch (pred) {
590   case IndexCmpPredicate::EQ:
591     return lhs.eq(rhs);
592   case IndexCmpPredicate::NE:
593     return lhs.ne(rhs);
594   case IndexCmpPredicate::SGE:
595     return lhs.sge(rhs);
596   case IndexCmpPredicate::SGT:
597     return lhs.sgt(rhs);
598   case IndexCmpPredicate::SLE:
599     return lhs.sle(rhs);
600   case IndexCmpPredicate::SLT:
601     return lhs.slt(rhs);
602   case IndexCmpPredicate::UGE:
603     return lhs.uge(rhs);
604   case IndexCmpPredicate::UGT:
605     return lhs.ugt(rhs);
606   case IndexCmpPredicate::ULE:
607     return lhs.ule(rhs);
608   case IndexCmpPredicate::ULT:
609     return lhs.ult(rhs);
610   }
611   llvm_unreachable("unhandled IndexCmpPredicate predicate");
612 }
613 
614 /// `cmp(max/min(x, cstA), cstB)` can be folded to a constant depending on the
615 /// values of `cstA` and `cstB`, the max or min operation, and the comparison
616 /// predicate. Check whether the value folds in both 32-bit and 64-bit
617 /// arithmetic and to the same value.
618 static std::optional<bool> foldCmpOfMaxOrMin(Operation *lhsOp,
619                                              const APInt &cstA,
620                                              const APInt &cstB, unsigned width,
621                                              IndexCmpPredicate pred) {
622   ConstantIntRanges lhsRange = TypeSwitch<Operation *, ConstantIntRanges>(lhsOp)
623                                    .Case([&](MinSOp op) {
624                                      return ConstantIntRanges::fromSigned(
625                                          APInt::getSignedMinValue(width), cstA);
626                                    })
627                                    .Case([&](MinUOp op) {
628                                      return ConstantIntRanges::fromUnsigned(
629                                          APInt::getMinValue(width), cstA);
630                                    })
631                                    .Case([&](MaxSOp op) {
632                                      return ConstantIntRanges::fromSigned(
633                                          cstA, APInt::getSignedMaxValue(width));
634                                    })
635                                    .Case([&](MaxUOp op) {
636                                      return ConstantIntRanges::fromUnsigned(
637                                          cstA, APInt::getMaxValue(width));
638                                    });
639   return intrange::evaluatePred(static_cast<intrange::CmpPredicate>(pred),
640                                 lhsRange, ConstantIntRanges::constant(cstB));
641 }
642 
643 /// Return the result of `cmp(pred, x, x)`
644 static bool compareSameArgs(IndexCmpPredicate pred) {
645   switch (pred) {
646   case IndexCmpPredicate::EQ:
647   case IndexCmpPredicate::SGE:
648   case IndexCmpPredicate::SLE:
649   case IndexCmpPredicate::UGE:
650   case IndexCmpPredicate::ULE:
651     return true;
652   case IndexCmpPredicate::NE:
653   case IndexCmpPredicate::SGT:
654   case IndexCmpPredicate::SLT:
655   case IndexCmpPredicate::UGT:
656   case IndexCmpPredicate::ULT:
657     return false;
658   }
659   llvm_unreachable("unknown predicate in compareSameArgs");
660 }
661 
662 OpFoldResult CmpOp::fold(FoldAdaptor adaptor) {
663   // Attempt to fold if both inputs are constant.
664   auto lhs = dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
665   auto rhs = dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
666   if (lhs && rhs) {
667     // Perform the comparison in 64-bit and 32-bit.
668     bool result64 = compareIndices(lhs.getValue(), rhs.getValue(), getPred());
669     bool result32 = compareIndices(lhs.getValue().trunc(32),
670                                    rhs.getValue().trunc(32), getPred());
671     if (result64 == result32)
672       return BoolAttr::get(getContext(), result64);
673   }
674 
675   // Fold `cmp(max/min(x, cstA), cstB)`.
676   Operation *lhsOp = getLhs().getDefiningOp();
677   IntegerAttr cstA;
678   if (isa_and_nonnull<MinSOp, MinUOp, MaxSOp, MaxUOp>(lhsOp) &&
679       matchPattern(lhsOp->getOperand(1), m_Constant(&cstA)) && rhs) {
680     std::optional<bool> result64 = foldCmpOfMaxOrMin(
681         lhsOp, cstA.getValue(), rhs.getValue(), 64, getPred());
682     std::optional<bool> result32 =
683         foldCmpOfMaxOrMin(lhsOp, cstA.getValue().trunc(32),
684                           rhs.getValue().trunc(32), 32, getPred());
685     // Fold if the 32-bit and 64-bit results are the same.
686     if (result64 && result32 && *result64 == *result32)
687       return BoolAttr::get(getContext(), *result64);
688   }
689 
690   // Fold `cmp(x, x)`
691   if (getLhs() == getRhs())
692     return BoolAttr::get(getContext(), compareSameArgs(getPred()));
693 
694   return {};
695 }
696 
697 /// Canonicalize
698 /// `x - y cmp 0` to `x cmp y`. or `x - y cmp 0` to `x cmp y`.
699 /// `0 cmp x - y` to `y cmp x`. or `0 cmp x - y` to `y cmp x`.
700 LogicalResult CmpOp::canonicalize(CmpOp op, PatternRewriter &rewriter) {
701   IntegerAttr cmpRhs;
702   IntegerAttr cmpLhs;
703 
704   bool rhsIsZero = matchPattern(op.getRhs(), m_Constant(&cmpRhs)) &&
705                    cmpRhs.getValue().isZero();
706   bool lhsIsZero = matchPattern(op.getLhs(), m_Constant(&cmpLhs)) &&
707                    cmpLhs.getValue().isZero();
708   if (!rhsIsZero && !lhsIsZero)
709     return rewriter.notifyMatchFailure(op.getLoc(),
710                                        "cmp is not comparing something with 0");
711   SubOp subOp = rhsIsZero ? op.getLhs().getDefiningOp<index::SubOp>()
712                           : op.getRhs().getDefiningOp<index::SubOp>();
713   if (!subOp)
714     return rewriter.notifyMatchFailure(
715         op.getLoc(), "non-zero operand is not a result of subtraction");
716 
717   index::CmpOp newCmp;
718   if (rhsIsZero)
719     newCmp = rewriter.create<index::CmpOp>(op.getLoc(), op.getPred(),
720                                            subOp.getLhs(), subOp.getRhs());
721   else
722     newCmp = rewriter.create<index::CmpOp>(op.getLoc(), op.getPred(),
723                                            subOp.getRhs(), subOp.getLhs());
724   rewriter.replaceOp(op, newCmp);
725   return success();
726 }
727 
728 //===----------------------------------------------------------------------===//
729 // ConstantOp
730 //===----------------------------------------------------------------------===//
731 
732 void ConstantOp::getAsmResultNames(
733     function_ref<void(Value, StringRef)> setNameFn) {
734   SmallString<32> specialNameBuffer;
735   llvm::raw_svector_ostream specialName(specialNameBuffer);
736   specialName << "idx" << getValueAttr().getValue();
737   setNameFn(getResult(), specialName.str());
738 }
739 
740 OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
741 
742 void ConstantOp::build(OpBuilder &b, OperationState &state, int64_t value) {
743   build(b, state, b.getIndexType(), b.getIndexAttr(value));
744 }
745 
746 //===----------------------------------------------------------------------===//
747 // BoolConstantOp
748 //===----------------------------------------------------------------------===//
749 
750 OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
751   return getValueAttr();
752 }
753 
754 void BoolConstantOp::getAsmResultNames(
755     function_ref<void(Value, StringRef)> setNameFn) {
756   setNameFn(getResult(), getValue() ? "true" : "false");
757 }
758 
759 //===----------------------------------------------------------------------===//
760 // ODS-Generated Definitions
761 //===----------------------------------------------------------------------===//
762 
763 #define GET_OP_CLASSES
764 #include "mlir/Dialect/Index/IR/IndexOps.cpp.inc"
765