1 //===- ArithOps.cpp - MLIR Arith dialect ops implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <cassert>
10 #include <cstdint>
11 #include <functional>
12 #include <utility>
13 
14 #include "mlir/Dialect/Arith/IR/Arith.h"
15 #include "mlir/Dialect/CommonFolders.h"
16 #include "mlir/Dialect/UB/IR/UBOps.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/BuiltinAttributeInterfaces.h"
19 #include "mlir/IR/BuiltinAttributes.h"
20 #include "mlir/IR/Matchers.h"
21 #include "mlir/IR/OpImplementation.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/IR/TypeUtilities.h"
24 #include "mlir/Support/LogicalResult.h"
25 
26 #include "llvm/ADT/APFloat.h"
27 #include "llvm/ADT/APInt.h"
28 #include "llvm/ADT/APSInt.h"
29 #include "llvm/ADT/FloatingPointMode.h"
30 #include "llvm/ADT/STLExtras.h"
31 #include "llvm/ADT/SmallString.h"
32 #include "llvm/ADT/SmallVector.h"
33 #include "llvm/ADT/TypeSwitch.h"
34 
35 using namespace mlir;
36 using namespace mlir::arith;
37 
38 //===----------------------------------------------------------------------===//
39 // Pattern helpers
40 //===----------------------------------------------------------------------===//
41 
42 static IntegerAttr
43 applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs,
44                     Attribute rhs,
45                     function_ref<APInt(const APInt &, const APInt &)> binFn) {
46   APInt lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
47   APInt rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
48   APInt value = binFn(lhsVal, rhsVal);
49   return IntegerAttr::get(res.getType(), value);
50 }
51 
52 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
53                                    Attribute lhs, Attribute rhs) {
54   return applyToIntegerAttrs(builder, res, lhs, rhs, std::plus<APInt>());
55 }
56 
57 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
58                                    Attribute lhs, Attribute rhs) {
59   return applyToIntegerAttrs(builder, res, lhs, rhs, std::minus<APInt>());
60 }
61 
62 static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
63                                    Attribute lhs, Attribute rhs) {
64   return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
65 }
66 
67 // Merge overflow flags from 2 ops, selecting the most conservative combination.
68 static IntegerOverflowFlagsAttr
69 mergeOverflowFlags(IntegerOverflowFlagsAttr val1,
70                    IntegerOverflowFlagsAttr val2) {
71   return IntegerOverflowFlagsAttr::get(val1.getContext(),
72                                        val1.getValue() & val2.getValue());
73 }
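// For example, merging an op carrying (nsw, nuw) with an op carrying only
// (nsw) yields just (nsw): a flag survives only if both operations carry it,
// which is the conservative choice when combining two ops into one.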
74 
75 /// Invert an integer comparison predicate.
76 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
77   switch (pred) {
78   case arith::CmpIPredicate::eq:
79     return arith::CmpIPredicate::ne;
80   case arith::CmpIPredicate::ne:
81     return arith::CmpIPredicate::eq;
82   case arith::CmpIPredicate::slt:
83     return arith::CmpIPredicate::sge;
84   case arith::CmpIPredicate::sle:
85     return arith::CmpIPredicate::sgt;
86   case arith::CmpIPredicate::sgt:
87     return arith::CmpIPredicate::sle;
88   case arith::CmpIPredicate::sge:
89     return arith::CmpIPredicate::slt;
90   case arith::CmpIPredicate::ult:
91     return arith::CmpIPredicate::uge;
92   case arith::CmpIPredicate::ule:
93     return arith::CmpIPredicate::ugt;
94   case arith::CmpIPredicate::ugt:
95     return arith::CmpIPredicate::ule;
96   case arith::CmpIPredicate::uge:
97     return arith::CmpIPredicate::ult;
98   }
99   llvm_unreachable("unknown cmpi predicate kind");
100 }
101 
102 /// Equivalent to
103 /// convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
104 ///
105 /// Not possible to implement as a chain of calls, as this would introduce a
106 /// circular dependency with MLIRArithAttrToLLVMConversion and make arith depend
107 /// on the LLVM dialect and on translation to LLVM.
108 static llvm::RoundingMode
109 convertArithRoundingModeToLLVMIR(RoundingMode roundingMode) {
110   switch (roundingMode) {
111   case RoundingMode::downward:
112     return llvm::RoundingMode::TowardNegative;
113   case RoundingMode::to_nearest_away:
114     return llvm::RoundingMode::NearestTiesToAway;
115   case RoundingMode::to_nearest_even:
116     return llvm::RoundingMode::NearestTiesToEven;
117   case RoundingMode::toward_zero:
118     return llvm::RoundingMode::TowardZero;
119   case RoundingMode::upward:
120     return llvm::RoundingMode::TowardPositive;
121   }
122   llvm_unreachable("Unhandled rounding mode");
123 }
124 
125 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
126   return arith::CmpIPredicateAttr::get(pred.getContext(),
127                                        invertPredicate(pred.getValue()));
128 }
129 
130 static int64_t getScalarOrElementWidth(Type type) {
131   Type elemTy = getElementTypeOrSelf(type);
132   if (elemTy.isIntOrFloat())
133     return elemTy.getIntOrFloatBitWidth();
134 
135   return -1;
136 }
137 
138 static int64_t getScalarOrElementWidth(Value value) {
139   return getScalarOrElementWidth(value.getType());
140 }
141 
142 static FailureOr<APInt> getIntOrSplatIntValue(Attribute attr) {
143   APInt value;
144   if (matchPattern(attr, m_ConstantInt(&value)))
145     return value;
146 
147   return failure();
148 }
149 
150 static Attribute getBoolAttribute(Type type, bool value) {
151   auto boolAttr = BoolAttr::get(type.getContext(), value);
152   ShapedType shapedType = llvm::dyn_cast_or_null<ShapedType>(type);
153   if (!shapedType)
154     return boolAttr;
155   return DenseElementsAttr::get(shapedType, boolAttr);
156 }
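// For example, getBoolAttribute(i1, true) yields a plain BoolAttr, while
// getBoolAttribute(vector<4xi1>, true) yields a dense splat of `true` over
// the vector shape.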
157 
158 //===----------------------------------------------------------------------===//
159 // TableGen'd canonicalization patterns
160 //===----------------------------------------------------------------------===//
161 
162 namespace {
163 #include "ArithCanonicalization.inc"
164 } // namespace
165 
166 //===----------------------------------------------------------------------===//
167 // Common helpers
168 //===----------------------------------------------------------------------===//
169 
170 /// Return the type of the same shape (scalar, vector or tensor) containing i1.
171 static Type getI1SameShape(Type type) {
172   auto i1Type = IntegerType::get(type.getContext(), 1);
173   if (auto shapedType = llvm::dyn_cast<ShapedType>(type))
174     return shapedType.cloneWith(std::nullopt, i1Type);
175   if (llvm::isa<UnrankedTensorType>(type))
176     return UnrankedTensorType::get(i1Type);
177   return i1Type;
178 }
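// For example: i32 -> i1, vector<4xf32> -> vector<4xi1>,
// tensor<?x8xf32> -> tensor<?x8xi1>, tensor<*xf32> -> tensor<*xi1>.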
179 
180 //===----------------------------------------------------------------------===//
181 // ConstantOp
182 //===----------------------------------------------------------------------===//
183 
184 void arith::ConstantOp::getAsmResultNames(
185     function_ref<void(Value, StringRef)> setNameFn) {
186   auto type = getType();
187   if (auto intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
188     auto intType = llvm::dyn_cast<IntegerType>(type);
189 
190     // Sugar i1 constants with 'true' and 'false'.
191     if (intType && intType.getWidth() == 1)
192       return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
193 
194     // Otherwise, build a complex name with the value and type.
195     SmallString<32> specialNameBuffer;
196     llvm::raw_svector_ostream specialName(specialNameBuffer);
197     specialName << 'c' << intCst.getValue();
198     if (intType)
199       specialName << '_' << type;
200     setNameFn(getResult(), specialName.str());
201   } else {
202     setNameFn(getResult(), "cst");
203   }
204 }
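// Example result names produced above:
//   %true    = arith.constant true
//   %c42_i32 = arith.constant 42 : i32
//   %c10     = arith.constant 10 : index
//   %cst     = arith.constant 1.0 : f32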
205 
206 /// TODO: disallow arith.constant to return anything other than signless integer
207 /// or float like.
208 LogicalResult arith::ConstantOp::verify() {
209   auto type = getType();
210   // The value's type must match the return type.
211   if (getValue().getType() != type) {
212     return emitOpError() << "value type " << getValue().getType()
213                          << " must match return type: " << type;
214   }
215   // Integer values must be signless.
216   if (llvm::isa<IntegerType>(type) &&
217       !llvm::cast<IntegerType>(type).isSignless())
218     return emitOpError("integer return type must be signless");
219   // Integer, float, and elements attributes are acceptable.
220   if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
221     return emitOpError(
222         "value must be an integer, float, or elements attribute");
223   }
224 
225   // Note, we could relax this for vectors with 1 scalable dim, e.g.:
226   //  * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32>
227   // However, this would most likely require updating the lowerings to LLVM.
228   if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
229     return emitOpError(
230         "initializing scalable vectors with elements attribute is not supported"
231         " unless it's a vector splat");
232   return success();
233 }
234 
235 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
236   // The value's type must be the same as the provided type.
237   auto typedAttr = llvm::dyn_cast<TypedAttr>(value);
238   if (!typedAttr || typedAttr.getType() != type)
239     return false;
240   // Integer values must be signless.
241   if (llvm::isa<IntegerType>(type) &&
242       !llvm::cast<IntegerType>(type).isSignless())
243     return false;
244   // Integer, float, and element attributes are buildable.
245   return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
246 }
247 
248 ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
249                                           Type type, Location loc) {
250   if (isBuildableWith(value, type))
251     return builder.create<arith::ConstantOp>(loc, cast<TypedAttr>(value));
252   return nullptr;
253 }
254 
255 OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
256 
257 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
258                                  int64_t value, unsigned width) {
259   auto type = builder.getIntegerType(width);
260   arith::ConstantOp::build(builder, result, type,
261                            builder.getIntegerAttr(type, value));
262 }
263 
264 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
265                                  int64_t value, Type type) {
266   assert(type.isSignlessInteger() &&
267          "ConstantIntOp can only have signless integer type values");
268   arith::ConstantOp::build(builder, result, type,
269                            builder.getIntegerAttr(type, value));
270 }
271 
272 bool arith::ConstantIntOp::classof(Operation *op) {
273   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
274     return constOp.getType().isSignlessInteger();
275   return false;
276 }
277 
278 void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
279                                    const APFloat &value, FloatType type) {
280   arith::ConstantOp::build(builder, result, type,
281                            builder.getFloatAttr(type, value));
282 }
283 
284 bool arith::ConstantFloatOp::classof(Operation *op) {
285   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
286     return llvm::isa<FloatType>(constOp.getType());
287   return false;
288 }
289 
290 void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
291                                    int64_t value) {
292   arith::ConstantOp::build(builder, result, builder.getIndexType(),
293                            builder.getIndexAttr(value));
294 }
295 
296 bool arith::ConstantIndexOp::classof(Operation *op) {
297   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
298     return constOp.getType().isIndex();
299   return false;
300 }
301 
302 //===----------------------------------------------------------------------===//
303 // AddIOp
304 //===----------------------------------------------------------------------===//
305 
306 OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) {
307   // addi(x, 0) -> x
308   if (matchPattern(adaptor.getRhs(), m_Zero()))
309     return getLhs();
310 
311   // addi(subi(a, b), b) -> a
312   if (auto sub = getLhs().getDefiningOp<SubIOp>())
313     if (getRhs() == sub.getRhs())
314       return sub.getLhs();
315 
316   // addi(b, subi(a, b)) -> a
317   if (auto sub = getRhs().getDefiningOp<SubIOp>())
318     if (getLhs() == sub.getRhs())
319       return sub.getLhs();
320 
321   return constFoldBinaryOp<IntegerAttr>(
322       adaptor.getOperands(),
323       [](APInt a, const APInt &b) { return std::move(a) + b; });
324 }
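// Example folds performed above:
//   %0 = arith.subi %a, %b : i32
//   %1 = arith.addi %0, %b : i32   // folds to %a
//   %2 = arith.addi %a, %c0 : i32  // folds to %a when %c0 is the constant 0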
325 
326 void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
327                                                 MLIRContext *context) {
328   patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
329                AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
330 }
331 
332 //===----------------------------------------------------------------------===//
333 // AddUIExtendedOp
334 //===----------------------------------------------------------------------===//
335 
336 std::optional<SmallVector<int64_t, 4>>
337 arith::AddUIExtendedOp::getShapeForUnroll() {
338   if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
339     return llvm::to_vector<4>(vt.getShape());
340   return std::nullopt;
341 }
342 
343 // Returns the overflow bit, assuming that `sum` is the result of unsigned
344 // addition of `operand` and another number.
345 static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand) {
346   return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
347 }
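// For example, with i8 operands 200 + 100 wraps to 44; since 44 < 200 the
// overflow bit is 1. An unsigned sum that does not wrap is never smaller than
// either operand, so the comparison is an exact carry-out test.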
348 
349 LogicalResult
350 arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
351                              SmallVectorImpl<OpFoldResult> &results) {
352   Type overflowTy = getOverflow().getType();
353   // addui_extended(x, 0) -> x, false
354   if (matchPattern(getRhs(), m_Zero())) {
355     Builder builder(getContext());
356     auto falseValue = builder.getZeroAttr(overflowTy);
357 
358     results.push_back(getLhs());
359     results.push_back(falseValue);
360     return success();
361   }
362 
363   // addui_extended(constant_a, constant_b) -> constant_sum, constant_carry
364   // Let the `constFoldBinaryOp` utility attempt to fold the sum of both
365   // operands. If that succeeds, calculate the overflow bit based on the sum
366   // and the first (constant) operand, `lhs`.
367   if (Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
368           adaptor.getOperands(),
369           [](APInt a, const APInt &b) { return std::move(a) + b; })) {
370     Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>(
371         ArrayRef({sumAttr, adaptor.getLhs()}),
372         getI1SameShape(llvm::cast<TypedAttr>(sumAttr).getType()),
373         calculateUnsignedOverflow);
374     if (!overflowAttr)
375       return failure();
376 
377     results.push_back(sumAttr);
378     results.push_back(overflowAttr);
379     return success();
380   }
381 
382   return failure();
383 }
384 
385 void arith::AddUIExtendedOp::getCanonicalizationPatterns(
386     RewritePatternSet &patterns, MLIRContext *context) {
387   patterns.add<AddUIExtendedToAddI>(context);
388 }
389 
390 //===----------------------------------------------------------------------===//
391 // SubIOp
392 //===----------------------------------------------------------------------===//
393 
394 OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
395   // subi(x,x) -> 0
396   if (getOperand(0) == getOperand(1)) {
397     auto shapedType = dyn_cast<ShapedType>(getType());
398     // We can't generate a constant with a dynamically shaped tensor.
399     if (!shapedType || shapedType.hasStaticShape())
400       return Builder(getContext()).getZeroAttr(getType());
401   }
402   // subi(x,0) -> x
403   if (matchPattern(adaptor.getRhs(), m_Zero()))
404     return getLhs();
405 
406   if (auto add = getLhs().getDefiningOp<AddIOp>()) {
407     // subi(addi(a, b), b) -> a
408     if (getRhs() == add.getRhs())
409       return add.getLhs();
410     // subi(addi(a, b), a) -> b
411     if (getRhs() == add.getLhs())
412       return add.getRhs();
413   }
414 
415   return constFoldBinaryOp<IntegerAttr>(
416       adaptor.getOperands(),
417       [](APInt a, const APInt &b) { return std::move(a) - b; });
418 }
419 
420 void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
421                                                 MLIRContext *context) {
422   patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
423                SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
424                SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
425 }
426 
427 //===----------------------------------------------------------------------===//
428 // MulIOp
429 //===----------------------------------------------------------------------===//
430 
431 OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
432   // muli(x, 0) -> 0
433   if (matchPattern(adaptor.getRhs(), m_Zero()))
434     return getRhs();
435   // muli(x, 1) -> x
436   if (matchPattern(adaptor.getRhs(), m_One()))
437     return getLhs();
438   // TODO: Handle the overflow case.
439 
440   // default folder
441   return constFoldBinaryOp<IntegerAttr>(
442       adaptor.getOperands(),
443       [](const APInt &a, const APInt &b) { return a * b; });
444 }
445 
446 void arith::MulIOp::getAsmResultNames(
447     function_ref<void(Value, StringRef)> setNameFn) {
448   if (!isa<IndexType>(getType()))
449     return;
450 
451   // Match vector.vscale by name to avoid depending on the vector dialect
452   // (which would create a circular dependency).
453   auto isVscale = [](Operation *op) {
454     return op && op->getName().getStringRef() == "vector.vscale";
455   };
456 
457   IntegerAttr baseValue;
458   auto isVscaleExpr = [&](Value a, Value b) {
459     return matchPattern(a, m_Constant(&baseValue)) &&
460            isVscale(b.getDefiningOp());
461   };
462 
463   if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
464     return;
465 
466   // Name `base * vscale` or `vscale * base` as `c<base_value>_vscale`.
467   SmallString<32> specialNameBuffer;
468   llvm::raw_svector_ostream specialName(specialNameBuffer);
469   specialName << 'c' << baseValue.getInt() << "_vscale";
470   setNameFn(getResult(), specialName.str());
471 }
472 
473 void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
474                                                 MLIRContext *context) {
475   patterns.add<MulIMulIConstant>(context);
476 }
477 
478 //===----------------------------------------------------------------------===//
479 // MulSIExtendedOp
480 //===----------------------------------------------------------------------===//
481 
482 std::optional<SmallVector<int64_t, 4>>
483 arith::MulSIExtendedOp::getShapeForUnroll() {
484   if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
485     return llvm::to_vector<4>(vt.getShape());
486   return std::nullopt;
487 }
488 
489 LogicalResult
490 arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
491                              SmallVectorImpl<OpFoldResult> &results) {
492   // mulsi_extended(x, 0) -> 0, 0
493   if (matchPattern(adaptor.getRhs(), m_Zero())) {
494     Attribute zero = adaptor.getRhs();
495     results.push_back(zero);
496     results.push_back(zero);
497     return success();
498   }
499 
500   // mulsi_extended(cst_a, cst_b) -> cst_low, cst_high
501   if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
502           adaptor.getOperands(),
503           [](const APInt &a, const APInt &b) { return a * b; })) {
504     // Invoke the constant fold helper again to calculate the 'high' result.
505     Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
506         adaptor.getOperands(), [](const APInt &a, const APInt &b) {
507           return llvm::APIntOps::mulhs(a, b);
508         });
509     assert(highAttr && "Unexpected constant-folding failure");
510 
511     results.push_back(lowAttr);
512     results.push_back(highAttr);
513     return success();
514   }
515 
516   return failure();
517 }
518 
519 void arith::MulSIExtendedOp::getCanonicalizationPatterns(
520     RewritePatternSet &patterns, MLIRContext *context) {
521   patterns.add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
522 }
523 
524 //===----------------------------------------------------------------------===//
525 // MulUIExtendedOp
526 //===----------------------------------------------------------------------===//
527 
528 std::optional<SmallVector<int64_t, 4>>
529 arith::MulUIExtendedOp::getShapeForUnroll() {
530   if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
531     return llvm::to_vector<4>(vt.getShape());
532   return std::nullopt;
533 }
534 
535 LogicalResult
536 arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
537                              SmallVectorImpl<OpFoldResult> &results) {
538   // mului_extended(x, 0) -> 0, 0
539   if (matchPattern(adaptor.getRhs(), m_Zero())) {
540     Attribute zero = adaptor.getRhs();
541     results.push_back(zero);
542     results.push_back(zero);
543     return success();
544   }
545 
546   // mului_extended(x, 1) -> x, 0
547   if (matchPattern(adaptor.getRhs(), m_One())) {
548     Builder builder(getContext());
549     Attribute zero = builder.getZeroAttr(getLhs().getType());
550     results.push_back(getLhs());
551     results.push_back(zero);
552     return success();
553   }
554 
555   // mului_extended(cst_a, cst_b) -> cst_low, cst_high
556   if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
557           adaptor.getOperands(),
558           [](const APInt &a, const APInt &b) { return a * b; })) {
559     // Invoke the constant fold helper again to calculate the 'high' result.
560     Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
561         adaptor.getOperands(), [](const APInt &a, const APInt &b) {
562           return llvm::APIntOps::mulhu(a, b);
563         });
564     assert(highAttr && "Unexpected constant-folding failure");
565 
566     results.push_back(lowAttr);
567     results.push_back(highAttr);
568     return success();
569   }
570 
571   return failure();
572 }
573 
574 void arith::MulUIExtendedOp::getCanonicalizationPatterns(
575     RewritePatternSet &patterns, MLIRContext *context) {
576   patterns.add<MulUIExtendedToMulI>(context);
577 }
578 
579 //===----------------------------------------------------------------------===//
580 // DivUIOp
581 //===----------------------------------------------------------------------===//
582 
583 /// Fold `(a * b) / b -> a`
584 static Value foldDivMul(Value lhs, Value rhs,
585                         arith::IntegerOverflowFlags ovfFlags) {
586   auto mul = lhs.getDefiningOp<mlir::arith::MulIOp>();
587   if (!mul || !bitEnumContainsAll(mul.getOverflowFlags(), ovfFlags))
588     return {};
589 
590   if (mul.getLhs() == rhs)
591     return mul.getRhs();
592 
593   if (mul.getRhs() == rhs)
594     return mul.getLhs();
595 
596   return {};
597 }
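// The overflow-flag requirement is what makes this sound. E.g. on i8 with
// nuw, `%p = arith.muli %a, %b overflow<nuw> : i8` guarantees the product did
// not wrap, so `arith.divui %p, %b : i8` recovers %a exactly. Without the
// flag, 32 * 16 wraps to 0 and 0 / 16 != 32.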
598 
599 OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
600   // divui (x, 1) -> x.
601   if (matchPattern(adaptor.getRhs(), m_One()))
602     return getLhs();
603 
604   // (a * b) / b -> a
605   if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw))
606     return val;
607 
608   // Don't fold if it would require a division by zero.
609   bool div0 = false;
610   auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
611                                                [&](APInt a, const APInt &b) {
612                                                  if (div0 || !b) {
613                                                    div0 = true;
614                                                    return a;
615                                                  }
616                                                  return a.udiv(b);
617                                                });
618 
619   return div0 ? Attribute() : result;
620 }
621 
622 /// Returns whether an unsigned division by `divisor` is speculatable.
623 static Speculation::Speculatability getDivUISpeculatability(Value divisor) {
624   // X / 0 => UB
625   if (matchPattern(divisor, m_IntRangeWithoutZeroU()))
626     return Speculation::Speculatable;
627 
628   return Speculation::NotSpeculatable;
629 }
630 
631 Speculation::Speculatability arith::DivUIOp::getSpeculatability() {
632   return getDivUISpeculatability(getRhs());
633 }
634 
635 //===----------------------------------------------------------------------===//
636 // DivSIOp
637 //===----------------------------------------------------------------------===//
638 
639 OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
640   // divsi (x, 1) -> x.
641   if (matchPattern(adaptor.getRhs(), m_One()))
642     return getLhs();
643 
644   // (a * b) / b -> a
645   if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw))
646     return val;
647 
648   // Don't fold if it would overflow or if it requires a division by zero.
649   bool overflowOrDiv0 = false;
650   auto result = constFoldBinaryOp<IntegerAttr>(
651       adaptor.getOperands(), [&](APInt a, const APInt &b) {
652         if (overflowOrDiv0 || !b) {
653           overflowOrDiv0 = true;
654           return a;
655         }
656         return a.sdiv_ov(b, overflowOrDiv0);
657       });
658 
659   return overflowOrDiv0 ? Attribute() : result;
660 }
661 
662 /// Returns whether a signed division by `divisor` is speculatable. This
663 /// function conservatively assumes that any signed division by -1 is not
664 /// speculatable.
665 static Speculation::Speculatability getDivSISpeculatability(Value divisor) {
666   // X / 0 => UB
667   // INT_MIN / -1 => UB
668   if (matchPattern(divisor, m_IntRangeWithoutZeroS()) &&
669       matchPattern(divisor, m_IntRangeWithoutNegOneS()))
670     return Speculation::Speculatable;
671 
672   return Speculation::NotSpeculatable;
673 }
674 
675 Speculation::Speculatability arith::DivSIOp::getSpeculatability() {
676   return getDivSISpeculatability(getRhs());
677 }
678 
679 //===----------------------------------------------------------------------===//
680 // Ceil and floor division folding helpers
681 //===----------------------------------------------------------------------===//
682 
683 static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
684                                     bool &overflow) {
685   // Returns (a-1)/b + 1
686   APInt one(a.getBitWidth(), 1, true); // Signed value 1.
687   APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
688   return val.sadd_ov(one, overflow);
689 }
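// For example, a = 7, b = 3: (7 - 1) / 3 + 1 = 2 + 1 = 3 = ceil(7 / 3). The
// identity only holds for a >= 1 and b >= 1, which the callers ensure.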
690 
691 //===----------------------------------------------------------------------===//
692 // CeilDivUIOp
693 //===----------------------------------------------------------------------===//
694 
695 OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
696   // ceildivui (x, 1) -> x.
697   if (matchPattern(adaptor.getRhs(), m_One()))
698     return getLhs();
699 
700   bool overflowOrDiv0 = false;
701   auto result = constFoldBinaryOp<IntegerAttr>(
702       adaptor.getOperands(), [&](APInt a, const APInt &b) {
703         if (overflowOrDiv0 || !b) {
704           overflowOrDiv0 = true;
705           return a;
706         }
707         APInt quotient = a.udiv(b);
708         if (!a.urem(b))
709           return quotient;
710         APInt one(a.getBitWidth(), 1, true);
711         return quotient.uadd_ov(one, overflowOrDiv0);
712       });
713 
714   return overflowOrDiv0 ? Attribute() : result;
715 }
716 
717 Speculation::Speculatability arith::CeilDivUIOp::getSpeculatability() {
718   return getDivUISpeculatability(getRhs());
719 }
720 
721 //===----------------------------------------------------------------------===//
722 // CeilDivSIOp
723 //===----------------------------------------------------------------------===//
724 
725 OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
726   // ceildivsi (x, 1) -> x.
727   if (matchPattern(adaptor.getRhs(), m_One()))
728     return getLhs();
729 
730   // Don't fold if it would overflow or if it requires a division by zero.
731   // TODO: This hook won't fold operations where a = MININT, because
732   // negating MININT overflows. This can be improved.
733   bool overflowOrDiv0 = false;
734   auto result = constFoldBinaryOp<IntegerAttr>(
735       adaptor.getOperands(), [&](APInt a, const APInt &b) {
736         if (overflowOrDiv0 || !b) {
737           overflowOrDiv0 = true;
738           return a;
739         }
740         if (!a)
741           return a;
742         // After this point we know that neither a nor b is zero.
743         unsigned bits = a.getBitWidth();
744         APInt zero = APInt::getZero(bits);
745         bool aGtZero = a.sgt(zero);
746         bool bGtZero = b.sgt(zero);
747         if (aGtZero && bGtZero) {
748           // Both positive, return ceil(a, b).
749           return signedCeilNonnegInputs(a, b, overflowOrDiv0);
750         }
751 
752         // No folding happens if any of the intermediate arithmetic operations
753         // overflows.
754         bool overflowNegA = false;
755         bool overflowNegB = false;
756         bool overflowDiv = false;
757         bool overflowNegRes = false;
758         if (!aGtZero && !bGtZero) {
759           // Both negative, return ceil(-a, -b).
760           APInt posA = zero.ssub_ov(a, overflowNegA);
761           APInt posB = zero.ssub_ov(b, overflowNegB);
762           APInt res = signedCeilNonnegInputs(posA, posB, overflowDiv);
763           overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
764           return res;
765         }
766         if (!aGtZero && bGtZero) {
767           // a is negative, b is positive: return -(-a / b).
768           APInt posA = zero.ssub_ov(a, overflowNegA);
769           APInt div = posA.sdiv_ov(b, overflowDiv);
770           APInt res = zero.ssub_ov(div, overflowNegRes);
771           overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
772           return res;
773         }
774         // a is positive, b is negative: return -(a / -b).
775         APInt posB = zero.ssub_ov(b, overflowNegB);
776         APInt div = a.sdiv_ov(posB, overflowDiv);
777         APInt res = zero.ssub_ov(div, overflowNegRes);
778 
779         overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);
780         return res;
781       });
782 
783   return overflowOrDiv0 ? Attribute() : result;
784 }
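// Example constant folds from the sign analysis above:
//   ceildivsi(7, -2)  = -(7 / 2)    = -3
//   ceildivsi(-7, -2) = ceil(7 / 2) = (7 - 1) / 2 + 1 = 4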
785 
786 Speculation::Speculatability arith::CeilDivSIOp::getSpeculatability() {
787   return getDivSISpeculatability(getRhs());
788 }
789 
790 //===----------------------------------------------------------------------===//
791 // FloorDivSIOp
792 //===----------------------------------------------------------------------===//
793 
794 OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
795   // floordivsi (x, 1) -> x.
796   if (matchPattern(adaptor.getRhs(), m_One()))
797     return getLhs();
798 
799   // Don't fold if it would overflow or if it requires a division by zero.
800   bool overflowOrDiv = false;
801   auto result = constFoldBinaryOp<IntegerAttr>(
802       adaptor.getOperands(), [&](APInt a, const APInt &b) {
803         if (b.isZero()) {
804           overflowOrDiv = true;
805           return a;
806         }
807         return a.sfloordiv_ov(b, overflowOrDiv);
808       });
809 
810   return overflowOrDiv ? Attribute() : result;
811 }
812 
813 //===----------------------------------------------------------------------===//
814 // RemUIOp
815 //===----------------------------------------------------------------------===//
816 
817 OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
818   // remui (x, 1) -> 0.
819   if (matchPattern(adaptor.getRhs(), m_One()))
820     return Builder(getContext()).getZeroAttr(getType());
821 
822   // Don't fold if it would require a division by zero.
823   bool div0 = false;
824   auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
825                                                [&](APInt a, const APInt &b) {
826                                                  if (div0 || b.isZero()) {
827                                                    div0 = true;
828                                                    return a;
829                                                  }
830                                                  return a.urem(b);
831                                                });
832 
833   return div0 ? Attribute() : result;
834 }
835 
836 //===----------------------------------------------------------------------===//
837 // RemSIOp
838 //===----------------------------------------------------------------------===//
839 
840 OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
841   // remsi (x, 1) -> 0.
842   if (matchPattern(adaptor.getRhs(), m_One()))
843     return Builder(getContext()).getZeroAttr(getType());
844 
845   // Don't fold if it would require a division by zero.
846   bool div0 = false;
847   auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
848                                                [&](APInt a, const APInt &b) {
849                                                  if (div0 || b.isZero()) {
850                                                    div0 = true;
851                                                    return a;
852                                                  }
853                                                  return a.srem(b);
854                                                });
855 
856   return div0 ? Attribute() : result;
857 }
858 
859 //===----------------------------------------------------------------------===//
860 // AndIOp
861 //===----------------------------------------------------------------------===//
862 
863 /// Fold `and(a, and(a, b))` to `and(a, b)`
864 static Value foldAndIofAndI(arith::AndIOp op) {
865   for (bool reversePrev : {false, true}) {
866     auto prev = (reversePrev ? op.getRhs() : op.getLhs())
867                     .getDefiningOp<arith::AndIOp>();
868     if (!prev)
869       continue;
870 
871     Value other = (reversePrev ? op.getLhs() : op.getRhs());
872     if (other != prev.getLhs() && other != prev.getRhs())
873       continue;
874 
875     return prev.getResult();
876   }
877   return {};
878 }
879 
880 OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
881   /// and(x, 0) -> 0
882   if (matchPattern(adaptor.getRhs(), m_Zero()))
883     return getRhs();
884   /// and(x, allOnes) -> x
885   APInt intValue;
886   if (matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue)) &&
887       intValue.isAllOnes())
888     return getLhs();
889   /// and(x, not(x)) -> 0
890   if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
891                                           m_ConstantInt(&intValue))) &&
892       intValue.isAllOnes())
893     return Builder(getContext()).getZeroAttr(getType());
894   /// and(not(x), x) -> 0
895   if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
896                                           m_ConstantInt(&intValue))) &&
897       intValue.isAllOnes())
898     return Builder(getContext()).getZeroAttr(getType());
899 
900   /// and(a, and(a, b)) -> and(a, b)
901   if (Value result = foldAndIofAndI(*this))
902     return result;
903 
904   return constFoldBinaryOp<IntegerAttr>(
905       adaptor.getOperands(),
906       [](APInt a, const APInt &b) { return std::move(a) & b; });
907 }
908 
909 //===----------------------------------------------------------------------===//
910 // OrIOp
911 //===----------------------------------------------------------------------===//
912 
913 OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
914   if (APInt rhsVal; matchPattern(adaptor.getRhs(), m_ConstantInt(&rhsVal))) {
915     /// or(x, 0) -> x
916     if (rhsVal.isZero())
917       return getLhs();
918     /// or(x, <all ones>) -> <all ones>
919     if (rhsVal.isAllOnes())
920       return adaptor.getRhs();
921   }
922 
923   APInt intValue;
924   /// or(x, xor(x, <all ones>)) -> <all ones>
925   if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
926                                           m_ConstantInt(&intValue))) &&
927       intValue.isAllOnes())
928     return getRhs().getDefiningOp<XOrIOp>().getRhs();
929   /// or(xor(x, <all ones>), x) -> <all ones>
930   if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
931                                           m_ConstantInt(&intValue))) &&
932       intValue.isAllOnes())
933     return getLhs().getDefiningOp<XOrIOp>().getRhs();
934 
935   return constFoldBinaryOp<IntegerAttr>(
936       adaptor.getOperands(),
937       [](APInt a, const APInt &b) { return std::move(a) | b; });
938 }
939 
940 //===----------------------------------------------------------------------===//
941 // XOrIOp
942 //===----------------------------------------------------------------------===//
943 
944 OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
945   /// xor(x, 0) -> x
946   if (matchPattern(adaptor.getRhs(), m_Zero()))
947     return getLhs();
948   /// xor(x, x) -> 0
949   if (getLhs() == getRhs())
950     return Builder(getContext()).getZeroAttr(getType());
951   /// xor(xor(x, a), a) -> x
952   /// xor(xor(a, x), a) -> x
953   if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
954     if (prev.getRhs() == getRhs())
955       return prev.getLhs();
956     if (prev.getLhs() == getRhs())
957       return prev.getRhs();
958   }
959   /// xor(a, xor(x, a)) -> x
960   /// xor(a, xor(a, x)) -> x
961   if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
962     if (prev.getRhs() == getLhs())
963       return prev.getLhs();
964     if (prev.getLhs() == getLhs())
965       return prev.getRhs();
966   }
967 
968   return constFoldBinaryOp<IntegerAttr>(
969       adaptor.getOperands(),
970       [](APInt a, const APInt &b) { return std::move(a) ^ b; });
971 }
972 
973 void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
974                                                 MLIRContext *context) {
975   patterns.add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
976 }
977 
978 //===----------------------------------------------------------------------===//
979 // NegFOp
980 //===----------------------------------------------------------------------===//
981 
982 OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
983   /// negf(negf(x)) -> x
984   if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
985     return op.getOperand();
986   return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
987                                      [](const APFloat &a) { return -a; });
988 }
989 
990 //===----------------------------------------------------------------------===//
991 // AddFOp
992 //===----------------------------------------------------------------------===//
993 
994 OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
995   // addf(x, -0) -> x
996   if (matchPattern(adaptor.getRhs(), m_NegZeroFloat()))
997     return getLhs();
998 
999   return constFoldBinaryOp<FloatAttr>(
1000       adaptor.getOperands(),
1001       [](const APFloat &a, const APFloat &b) { return a + b; });
1002 }
1003 
1004 //===----------------------------------------------------------------------===//
1005 // SubFOp
1006 //===----------------------------------------------------------------------===//
1007 
1008 OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
1009   // subf(x, +0) -> x
1010   if (matchPattern(adaptor.getRhs(), m_PosZeroFloat()))
1011     return getLhs();
1012 
1013   return constFoldBinaryOp<FloatAttr>(
1014       adaptor.getOperands(),
1015       [](const APFloat &a, const APFloat &b) { return a - b; });
1016 }
1017 
1018 //===----------------------------------------------------------------------===//
1019 // MaximumFOp
1020 //===----------------------------------------------------------------------===//
1021 
1022 OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
1023   // maximumf(x,x) -> x
1024   if (getLhs() == getRhs())
1025     return getRhs();
1026 
1027   // maximumf(x, -inf) -> x
1028   if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
1029     return getLhs();
1030 
1031   return constFoldBinaryOp<FloatAttr>(
1032       adaptor.getOperands(),
1033       [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
1034 }
1035 
1036 //===----------------------------------------------------------------------===//
1037 // MaxNumFOp
1038 //===----------------------------------------------------------------------===//
1039 
1040 OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
1041   // maxnumf(x,x) -> x
1042   if (getLhs() == getRhs())
1043     return getRhs();
1044 
1045   // maxnumf(x, NaN) -> x
1046   if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
1047     return getLhs();
1048 
1049   return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), llvm::maxnum);
1050 }
1051 
1052 //===----------------------------------------------------------------------===//
1053 // MaxSIOp
1054 //===----------------------------------------------------------------------===//
1055 
1056 OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
1057   // maxsi(x,x) -> x
1058   if (getLhs() == getRhs())
1059     return getRhs();
1060 
1061   if (APInt intValue;
1062       matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1063     // maxsi(x,MAX_INT) -> MAX_INT
1064     if (intValue.isMaxSignedValue())
1065       return getRhs();
1066     // maxsi(x, MIN_INT) -> x
1067     if (intValue.isMinSignedValue())
1068       return getLhs();
1069   }
1070 
1071   return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1072                                         [](const APInt &a, const APInt &b) {
1073                                           return llvm::APIntOps::smax(a, b);
1074                                         });
1075 }
1076 
1077 //===----------------------------------------------------------------------===//
1078 // MaxUIOp
1079 //===----------------------------------------------------------------------===//
1080 
1081 OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
1082   // maxui(x,x) -> x
1083   if (getLhs() == getRhs())
1084     return getRhs();
1085 
1086   if (APInt intValue;
1087       matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1088     // maxui(x,MAX_INT) -> MAX_INT
1089     if (intValue.isMaxValue())
1090       return getRhs();
1091     // maxui(x, MIN_INT) -> x
1092     if (intValue.isMinValue())
1093       return getLhs();
1094   }
1095 
1096   return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1097                                         [](const APInt &a, const APInt &b) {
1098                                           return llvm::APIntOps::umax(a, b);
1099                                         });
1100 }
1101 
1102 //===----------------------------------------------------------------------===//
1103 // MinimumFOp
1104 //===----------------------------------------------------------------------===//
1105 
1106 OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
1107   // minimumf(x,x) -> x
1108   if (getLhs() == getRhs())
1109     return getRhs();
1110 
1111   // minimumf(x, +inf) -> x
1112   if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
1113     return getLhs();
1114 
1115   return constFoldBinaryOp<FloatAttr>(
1116       adaptor.getOperands(),
1117       [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
1118 }
1119 
1120 //===----------------------------------------------------------------------===//
1121 // MinNumFOp
1122 //===----------------------------------------------------------------------===//
1123 
1124 OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
1125   // minnumf(x,x) -> x
1126   if (getLhs() == getRhs())
1127     return getRhs();
1128 
1129   // minnumf(x, NaN) -> x
1130   if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
1131     return getLhs();
1132 
1133   return constFoldBinaryOp<FloatAttr>(
1134       adaptor.getOperands(),
1135       [](const APFloat &a, const APFloat &b) { return llvm::minnum(a, b); });
1136 }
1137 
1138 //===----------------------------------------------------------------------===//
1139 // MinSIOp
1140 //===----------------------------------------------------------------------===//
1141 
1142 OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
1143   // minsi(x,x) -> x
1144   if (getLhs() == getRhs())
1145     return getRhs();
1146 
1147   if (APInt intValue;
1148       matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1149     // minsi(x,MIN_INT) -> MIN_INT
1150     if (intValue.isMinSignedValue())
1151       return getRhs();
1152     // minsi(x, MAX_INT) -> x
1153     if (intValue.isMaxSignedValue())
1154       return getLhs();
1155   }
1156 
1157   return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1158                                         [](const APInt &a, const APInt &b) {
1159                                           return llvm::APIntOps::smin(a, b);
1160                                         });
1161 }
1162 
1163 //===----------------------------------------------------------------------===//
1164 // MinUIOp
1165 //===----------------------------------------------------------------------===//
1166 
1167 OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
1168   // minui(x,x) -> x
1169   if (getLhs() == getRhs())
1170     return getRhs();
1171 
1172   if (APInt intValue;
1173       matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1174     // minui(x,MIN_INT) -> MIN_INT
1175     if (intValue.isMinValue())
1176       return getRhs();
1177     // minui(x, MAX_INT) -> x
1178     if (intValue.isMaxValue())
1179       return getLhs();
1180   }
1181 
1182   return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1183                                         [](const APInt &a, const APInt &b) {
1184                                           return llvm::APIntOps::umin(a, b);
1185                                         });
1186 }
1187 
1188 //===----------------------------------------------------------------------===//
1189 // MulFOp
1190 //===----------------------------------------------------------------------===//
1191 
1192 OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
1193   // mulf(x, 1) -> x
1194   if (matchPattern(adaptor.getRhs(), m_OneFloat()))
1195     return getLhs();
1196 
1197   return constFoldBinaryOp<FloatAttr>(
1198       adaptor.getOperands(),
1199       [](const APFloat &a, const APFloat &b) { return a * b; });
1200 }
1201 
1202 void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1203                                                 MLIRContext *context) {
1204   patterns.add<MulFOfNegF>(context);
1205 }
1206 
1207 //===----------------------------------------------------------------------===//
1208 // DivFOp
1209 //===----------------------------------------------------------------------===//
1210 
1211 OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
1212   // divf(x, 1) -> x
1213   if (matchPattern(adaptor.getRhs(), m_OneFloat()))
1214     return getLhs();
1215 
1216   return constFoldBinaryOp<FloatAttr>(
1217       adaptor.getOperands(),
1218       [](const APFloat &a, const APFloat &b) { return a / b; });
1219 }
1220 
1221 void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1222                                                 MLIRContext *context) {
1223   patterns.add<DivFOfNegF>(context);
1224 }
1225 
1226 //===----------------------------------------------------------------------===//
1227 // RemFOp
1228 //===----------------------------------------------------------------------===//
1229 
1230 OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
1231   return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
1232                                       [](const APFloat &a, const APFloat &b) {
1233                                         APFloat result(a);
1234                                         // APFloat::mod() offers the remainder
1235                                         // behavior we want, i.e. the result has
1236                                         // the sign of the LHS operand.
1237                                         (void)result.mod(b);
1238                                         return result;
1239                                       });
1240 }
1241 
1242 //===----------------------------------------------------------------------===//
1243 // Utility functions for verifying cast ops
1244 //===----------------------------------------------------------------------===//
1245 
1246 template <typename... Types>
1247 using type_list = std::tuple<Types...> *;
1248 
1249 /// Returns a non-null type only if the provided type is one of the allowed
1250 /// element types, or an allowed shaped type whose element type is allowed.
1251 /// Returns the element type if a valid shaped type is provided.
1252 template <typename... ShapedTypes, typename... ElementTypes>
1253 static Type getUnderlyingType(Type type, type_list<ShapedTypes...>,
1254                               type_list<ElementTypes...>) {
1255   if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
1256     return {};
1257 
1258   auto underlyingType = getElementTypeOrSelf(type);
1259   if (!llvm::isa<ElementTypes...>(underlyingType))
1260     return {};
1261 
1262   return underlyingType;
1263 }
1264 
1265 /// Get allowed underlying types for vectors and tensors.
1266 template <typename... ElementTypes>
1267 static Type getTypeIfLike(Type type) {
1268   return getUnderlyingType(type, type_list<VectorType, TensorType>(),
1269                            type_list<ElementTypes...>());
1270 }
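// For example, getTypeIfLike<IntegerType>(vector<4xi32>) returns i32, whereas
// getTypeIfLike<IntegerType>(memref<4xi32>) returns a null Type because
// memrefs are not among the allowed shaped types here.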
1271 
1272 /// Get allowed underlying types for vectors, tensors, and memrefs.
1273 template <typename... ElementTypes>
1274 static Type getTypeIfLikeOrMemRef(Type type) {
1275   return getUnderlyingType(type,
1276                            type_list<VectorType, TensorType, MemRefType>(),
1277                            type_list<ElementTypes...>());
1278 }
1279 
1280 /// Return false if both types are ranked tensor with mismatching encoding.
1281 static bool hasSameEncoding(Type typeA, Type typeB) {
1282   auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
1283   auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
1284   if (!rankedTensorA || !rankedTensorB)
1285     return true;
1286   return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
1287 }
1288 
1289 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
1290   if (inputs.size() != 1 || outputs.size() != 1)
1291     return false;
1292   if (!hasSameEncoding(inputs.front(), outputs.front()))
1293     return false;
1294   return succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
1295 }
1296 
1297 //===----------------------------------------------------------------------===//
1298 // Verifiers for integer and floating point extension/truncation ops
1299 //===----------------------------------------------------------------------===//
1300 
1301 // Extend ops can only extend to a wider type.
1302 template <typename ValType, typename Op>
1303 static LogicalResult verifyExtOp(Op op) {
1304   Type srcType = getElementTypeOrSelf(op.getIn().getType());
1305   Type dstType = getElementTypeOrSelf(op.getType());
1306 
1307   if (llvm::cast<ValType>(srcType).getWidth() >=
1308       llvm::cast<ValType>(dstType).getWidth())
1309     return op.emitError("result type ")
1310            << dstType << " must be wider than operand type " << srcType;
1311 
1312   return success();
1313 }
1314 
1315 // Truncate ops can only truncate to a shorter type.
1316 template <typename ValType, typename Op>
1317 static LogicalResult verifyTruncateOp(Op op) {
1318   Type srcType = getElementTypeOrSelf(op.getIn().getType());
1319   Type dstType = getElementTypeOrSelf(op.getType());
1320 
1321   if (llvm::cast<ValType>(srcType).getWidth() <=
1322       llvm::cast<ValType>(dstType).getWidth())
1323     return op.emitError("result type ")
1324            << dstType << " must be shorter than operand type " << srcType;
1325 
1326   return success();
1327 }
1328 
1329 /// Validate a cast that changes the width of a type.
1330 template <template <typename> class WidthComparator, typename... ElementTypes>
1331 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
1332   if (!areValidCastInputsAndOutputs(inputs, outputs))
1333     return false;
1334 
1335   auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
1336   auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
1337   if (!srcType || !dstType)
1338     return false;
1339 
1340   return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
1341                                      srcType.getIntOrFloatBitWidth());
1342 }
1343 
1344 /// Attempts to convert `sourceValue` to an APFloat value with
1345 /// `targetSemantics` and `roundingMode`, without any information loss.
1346 static FailureOr<APFloat> convertFloatValue(
1347     APFloat sourceValue, const llvm::fltSemantics &targetSemantics,
1348     llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
1349   bool losesInfo = false;
1350   auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
1351   if (losesInfo || status != APFloat::opOK)
1352     return failure();
1353 
1354   return sourceValue;
1355 }
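// For example, converting the f64 value 0.5 to f16 is exact and succeeds,
// while converting 1.0e300 to f16 overflows, reports an information loss, and
// the conversion is rejected.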
1356 
1357 //===----------------------------------------------------------------------===//
1358 // ExtUIOp
1359 //===----------------------------------------------------------------------===//
1360 
1361 OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
1362   if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
1363     getInMutable().assign(lhs.getIn());
1364     return getResult();
1365   }
1366 
1367   Type resType = getElementTypeOrSelf(getType());
1368   unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1369   return constFoldCastOp<IntegerAttr, IntegerAttr>(
1370       adaptor.getOperands(), getType(),
1371       [bitWidth](const APInt &a, bool &castStatus) {
1372         return a.zext(bitWidth);
1373       });
1374 }
1375 
1376 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1377   return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1378 }
1379 
1380 LogicalResult arith::ExtUIOp::verify() {
1381   return verifyExtOp<IntegerType>(*this);
1382 }
1383 
1384 //===----------------------------------------------------------------------===//
1385 // ExtSIOp
1386 //===----------------------------------------------------------------------===//
1387 
1388 OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
1389   if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
1390     getInMutable().assign(lhs.getIn());
1391     return getResult();
1392   }
1393 
1394   Type resType = getElementTypeOrSelf(getType());
1395   unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1396   return constFoldCastOp<IntegerAttr, IntegerAttr>(
1397       adaptor.getOperands(), getType(),
1398       [bitWidth](const APInt &a, bool &castStatus) {
1399         return a.sext(bitWidth);
1400       });
1401 }
1402 
1403 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1404   return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1405 }
1406 
1407 void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1408                                                  MLIRContext *context) {
1409   patterns.add<ExtSIOfExtUI>(context);
1410 }
1411 
1412 LogicalResult arith::ExtSIOp::verify() {
1413   return verifyExtOp<IntegerType>(*this);
1414 }
1415 
1416 //===----------------------------------------------------------------------===//
1417 // ExtFOp
1418 //===----------------------------------------------------------------------===//
1419 
1420 /// Fold extension of float constants when there is no information loss due to
1421 /// the difference in fp semantics.
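/// Additionally, extf(truncf(%x : f64 to f32) : f32 to f64) folds to %x, but
/// only when both casts carry the `contract` fast-math flag.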
1422 OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
1423   if (auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
1424     if (truncFOp.getOperand().getType() == getType()) {
1425       arith::FastMathFlags truncFMF =
1426           truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
1427       bool isTruncContract =
1428           bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract);
1429       arith::FastMathFlags extFMF =
1430           getFastmath().value_or(arith::FastMathFlags::none);
1431       bool isExtContract =
1432           bitEnumContainsAll(extFMF, arith::FastMathFlags::contract);
1433       if (isTruncContract && isExtContract) {
1434         return truncFOp.getOperand();
1435       }
1436     }
1437   }
1438 
1439   auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1440   const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1441   return constFoldCastOp<FloatAttr, FloatAttr>(
1442       adaptor.getOperands(), getType(),
1443       [&targetSemantics](const APFloat &a, bool &castStatus) {
1444         FailureOr<APFloat> result = convertFloatValue(a, targetSemantics);
1445         if (failed(result)) {
1446           castStatus = false;
1447           return a;
1448         }
1449         return *result;
1450       });
1451 }
1452 
1453 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1454   return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
1455 }
1456 
1457 LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
1458 
1459 //===----------------------------------------------------------------------===//
1460 // TruncIOp
1461 //===----------------------------------------------------------------------===//
1462 
1463 OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
1464   if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
1465       matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) {
1466     Value src = getOperand().getDefiningOp()->getOperand(0);
1467     Type srcType = getElementTypeOrSelf(src.getType());
1468     Type dstType = getElementTypeOrSelf(getType());
1469     // trunci(zexti(a)) -> trunci(a)
1470     // trunci(sexti(a)) -> trunci(a)
1471     if (llvm::cast<IntegerType>(srcType).getWidth() >
1472         llvm::cast<IntegerType>(dstType).getWidth()) {
1473       setOperand(src);
1474       return getResult();
1475     }
1476 
1477     // trunci(zexti(a)) -> a
1478     // trunci(sexti(a)) -> a
1479     if (srcType == dstType)
1480       return src;
1481   }
1482 
1483   // trunci(trunci(a)) -> trunci(a)
1484   if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
1485     setOperand(getOperand().getDefiningOp()->getOperand(0));
1486     return getResult();
1487   }
1488 
1489   Type resType = getElementTypeOrSelf(getType());
1490   unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1491   return constFoldCastOp<IntegerAttr, IntegerAttr>(
1492       adaptor.getOperands(), getType(),
1493       [bitWidth](const APInt &a, bool &castStatus) {
1494         return a.trunc(bitWidth);
1495       });
1496 }
1497 
1498 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1499   return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
1500 }
1501 
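/// Register patterns such as TruncIExtSIToExtSI, which rewrites
/// trunci(extsi(%x : i8 to i64) : i64 to i32) into extsi(%x : i8 to i32) when
/// the truncated result is still at least as wide as the original value
/// (types illustrative).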
1502 void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1503                                                   MLIRContext *context) {
1504   patterns.add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI,
1505                TruncIShrUIMulIToMulSIExtended, TruncIShrUIMulIToMulUIExtended>(
1506       context);
1507 }
1508 
1509 LogicalResult arith::TruncIOp::verify() {
1510   return verifyTruncateOp<IntegerType>(*this);
1511 }
1512 
1513 //===----------------------------------------------------------------------===//
1514 // TruncFOp
1515 //===----------------------------------------------------------------------===//
1516 
1517 /// Perform safe const propagation for truncf, i.e., only propagate if the FP
1518 /// value can be represented without precision loss.
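/// For example, truncating the f64 constant 2.5 to f32 is exact and folds,
/// while 0.1 would round and is left untouched.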
1519 OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
1520   auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1521   const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1522   return constFoldCastOp<FloatAttr, FloatAttr>(
1523       adaptor.getOperands(), getType(),
1524       [this, &targetSemantics](const APFloat &a, bool &castStatus) {
1525         RoundingMode roundingMode =
1526             getRoundingmode().value_or(RoundingMode::to_nearest_even);
1527         llvm::RoundingMode llvmRoundingMode =
1528             convertArithRoundingModeToLLVMIR(roundingMode);
1529         FailureOr<APFloat> result =
1530             convertFloatValue(a, targetSemantics, llvmRoundingMode);
1531         if (failed(result)) {
1532           castStatus = false;
1533           return a;
1534         }
1535         return *result;
1536       });
1537 }
1538 
1539 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1540   return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
1541 }
1542 
1543 LogicalResult arith::TruncFOp::verify() {
1544   return verifyTruncateOp<FloatType>(*this);
1545 }
1546 
1547 //===----------------------------------------------------------------------===//
1548 // AndIOp
1549 //===----------------------------------------------------------------------===//
1550 
1551 void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1552                                                 MLIRContext *context) {
1553   patterns.add<AndOfExtUI, AndOfExtSI>(context);
1554 }
1555 
1556 //===----------------------------------------------------------------------===//
1557 // OrIOp
1558 //===----------------------------------------------------------------------===//
1559 
1560 void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1561                                                MLIRContext *context) {
1562   patterns.add<OrOfExtUI, OrOfExtSI>(context);
1563 }
1564 
1565 //===----------------------------------------------------------------------===//
1566 // Verifiers for casts between integers and floats.
1567 //===----------------------------------------------------------------------===//
1568 
1569 template <typename From, typename To>
1570 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
1571   if (!areValidCastInputsAndOutputs(inputs, outputs))
1572     return false;
1573 
1574   auto srcType = getTypeIfLike<From>(inputs.front());
1575   auto dstType = getTypeIfLike<To>(outputs.front());
1576 
1577   return srcType && dstType;
1578 }
1579 
1580 //===----------------------------------------------------------------------===//
1581 // UIToFPOp
1582 //===----------------------------------------------------------------------===//
1583 
1584 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1585   return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1586 }
1587 
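/// Fold unsigned integer constants to floating point, e.g. a constant
/// 255 : i8 cast to f32 folds to 255.0 (values illustrative).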
1588 OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
1589   Type resEleType = getElementTypeOrSelf(getType());
1590   return constFoldCastOp<IntegerAttr, FloatAttr>(
1591       adaptor.getOperands(), getType(),
1592       [&resEleType](const APInt &a, bool &castStatus) {
1593         FloatType floatTy = llvm::cast<FloatType>(resEleType);
1594         APFloat apf(floatTy.getFloatSemantics(),
1595                     APInt::getZero(floatTy.getWidth()));
1596         apf.convertFromAPInt(a, /*IsSigned=*/false,
1597                              APFloat::rmNearestTiesToEven);
1598         return apf;
1599       });
1600 }
1601 
1602 //===----------------------------------------------------------------------===//
1603 // SIToFPOp
1604 //===----------------------------------------------------------------------===//
1605 
1606 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1607   return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1608 }
1609 
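/// Fold signed integer constants to floating point, interpreting the bits as
/// a signed value: a constant -1 : i8 cast to f32 folds to -1.0, not 255.0
/// (values illustrative).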
1610 OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
1611   Type resEleType = getElementTypeOrSelf(getType());
1612   return constFoldCastOp<IntegerAttr, FloatAttr>(
1613       adaptor.getOperands(), getType(),
1614       [&resEleType](const APInt &a, bool &castStatus) {
1615         FloatType floatTy = llvm::cast<FloatType>(resEleType);
1616         APFloat apf(floatTy.getFloatSemantics(),
1617                     APInt::getZero(floatTy.getWidth()));
1618         apf.convertFromAPInt(a, /*IsSigned=*/true,
1619                              APFloat::rmNearestTiesToEven);
1620         return apf;
1621       });
1622 }
1623 
1624 //===----------------------------------------------------------------------===//
1625 // FPToUIOp
1626 //===----------------------------------------------------------------------===//
1627 
1628 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1629   return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1630 }
1631 
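/// Fold floating-point constants by rounding toward zero, e.g. 3.9 : f32 cast
/// to i32 folds to 3; values that cannot be represented, such as -1.0 or NaN,
/// report opInvalidOp and are not folded (values illustrative).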
1632 OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
1633   Type resType = getElementTypeOrSelf(getType());
1634   unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1635   return constFoldCastOp<FloatAttr, IntegerAttr>(
1636       adaptor.getOperands(), getType(),
1637       [&bitWidth](const APFloat &a, bool &castStatus) {
1638         bool ignored;
1639         APSInt api(bitWidth, /*isUnsigned=*/true);
1640         castStatus = APFloat::opInvalidOp !=
1641                      a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1642         return api;
1643       });
1644 }
1645 
1646 //===----------------------------------------------------------------------===//
1647 // FPToSIOp
1648 //===----------------------------------------------------------------------===//
1649 
1650 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1651   return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1652 }
1653 
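/// Fold like FPToUIOp, but convert to a signed integer: -3.9 : f32 cast to
/// i32 folds to -3 (values illustrative).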
1654 OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
1655   Type resType = getElementTypeOrSelf(getType());
1656   unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1657   return constFoldCastOp<FloatAttr, IntegerAttr>(
1658       adaptor.getOperands(), getType(),
1659       [&bitWidth](const APFloat &a, bool &castStatus) {
1660         bool ignored;
1661         APSInt api(bitWidth, /*isUnsigned=*/false);
1662         castStatus = APFloat::opInvalidOp !=
1663                      a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1664         return api;
1665       });
1666 }
1667 
1668 //===----------------------------------------------------------------------===//
1669 // IndexCastOp
1670 //===----------------------------------------------------------------------===//
1671 
1672 static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs) {
1673   if (!areValidCastInputsAndOutputs(inputs, outputs))
1674     return false;
1675 
1676   auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1677   auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1678   if (!srcType || !dstType)
1679     return false;
1680 
1681   return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1682          (srcType.isSignlessInteger() && dstType.isIndex());
1683 }
1684 
1685 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
1686                                            TypeRange outputs) {
1687   return areIndexCastCompatible(inputs, outputs);
1688 }
1689 
1690 OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
1691   // index_cast(constant) -> constant
1692   unsigned resultBitwidth = 64; // Default for index integer attributes.
1693   if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
1694     resultBitwidth = intTy.getWidth();
1695 
1696   return constFoldCastOp<IntegerAttr, IntegerAttr>(
1697       adaptor.getOperands(), getType(),
1698       [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
1699         return a.sextOrTrunc(resultBitwidth);
1700       });
1701 }
1702 
1703 void arith::IndexCastOp::getCanonicalizationPatterns(
1704     RewritePatternSet &patterns, MLIRContext *context) {
1705   patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1706 }
1707 
1708 //===----------------------------------------------------------------------===//
1709 // IndexCastUIOp
1710 //===----------------------------------------------------------------------===//
1711 
1712 bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs,
1713                                              TypeRange outputs) {
1714   return areIndexCastCompatible(inputs, outputs);
1715 }
1716 
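/// Fold like IndexCastOp, but zero-extend rather than sign-extend when
/// widening: a constant -1 : i32 cast to index folds to 4294967295, not -1
/// (values illustrative).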
1717 OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
1718   // index_castui(constant) -> constant
1719   unsigned resultBitwidth = 64; // Default for index integer attributes.
1720   if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
1721     resultBitwidth = intTy.getWidth();
1722 
1723   return constFoldCastOp<IntegerAttr, IntegerAttr>(
1724       adaptor.getOperands(), getType(),
1725       [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
1726         return a.zextOrTrunc(resultBitwidth);
1727       });
1728 }
1729 
1730 void arith::IndexCastUIOp::getCanonicalizationPatterns(
1731     RewritePatternSet &patterns, MLIRContext *context) {
1732   patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
1733 }
1734 
1735 //===----------------------------------------------------------------------===//
1736 // BitcastOp
1737 //===----------------------------------------------------------------------===//
1738 
1739 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1740   if (!areValidCastInputsAndOutputs(inputs, outputs))
1741     return false;
1742 
1743   auto srcType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(inputs.front());
1744   auto dstType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(outputs.front());
1745   if (!srcType || !dstType)
1746     return false;
1747 
1748   return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1749 }
1750 
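/// Fold bitcasts of constants by reinterpreting the underlying bits, e.g.
/// bitcasting the f32 constant 1.0 yields the i32 constant 0x3F800000
/// (1065353216).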
1751 OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
1752   auto resType = getType();
1753   auto operand = adaptor.getIn();
1754   if (!operand)
1755     return {};
1756 
1757   /// Bitcast dense elements.
1758   if (auto denseAttr = llvm::dyn_cast_or_null<DenseElementsAttr>(operand))
1759     return denseAttr.bitcast(llvm::cast<ShapedType>(resType).getElementType());
1760   /// Other shaped types unhandled.
1761   if (llvm::isa<ShapedType>(resType))
1762     return {};
1763 
1764   /// Bitcast integer or float to integer or float.
1765   APInt bits = llvm::isa<FloatAttr>(operand)
1766                    ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
1767                    : llvm::cast<IntegerAttr>(operand).getValue();
1768   assert(resType.getIntOrFloatBitWidth() == bits.getBitWidth() &&
1769          "trying to fold on broken IR: operands have incompatible types");
1770 
1771   if (auto resFloatType = llvm::dyn_cast<FloatType>(resType))
1772     return FloatAttr::get(resType,
1773                           APFloat(resFloatType.getFloatSemantics(), bits));
1774   return IntegerAttr::get(resType, bits);
1775 }
1776 
1777 void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1778                                                    MLIRContext *context) {
1779   patterns.add<BitcastOfBitcast>(context);
1780 }
1781 
1782 //===----------------------------------------------------------------------===//
1783 // CmpIOp
1784 //===----------------------------------------------------------------------===//
1785 
1786 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
1787 /// comparison predicates.
1788 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
1789                                     const APInt &lhs, const APInt &rhs) {
1790   switch (predicate) {
1791   case arith::CmpIPredicate::eq:
1792     return lhs.eq(rhs);
1793   case arith::CmpIPredicate::ne:
1794     return lhs.ne(rhs);
1795   case arith::CmpIPredicate::slt:
1796     return lhs.slt(rhs);
1797   case arith::CmpIPredicate::sle:
1798     return lhs.sle(rhs);
1799   case arith::CmpIPredicate::sgt:
1800     return lhs.sgt(rhs);
1801   case arith::CmpIPredicate::sge:
1802     return lhs.sge(rhs);
1803   case arith::CmpIPredicate::ult:
1804     return lhs.ult(rhs);
1805   case arith::CmpIPredicate::ule:
1806     return lhs.ule(rhs);
1807   case arith::CmpIPredicate::ugt:
1808     return lhs.ugt(rhs);
1809   case arith::CmpIPredicate::uge:
1810     return lhs.uge(rhs);
1811   }
1812   llvm_unreachable("unknown cmpi predicate kind");
1813 }
1814 
1815 /// Returns true if the predicate is true for two equal operands.
1816 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
1817   switch (predicate) {
1818   case arith::CmpIPredicate::eq:
1819   case arith::CmpIPredicate::sle:
1820   case arith::CmpIPredicate::sge:
1821   case arith::CmpIPredicate::ule:
1822   case arith::CmpIPredicate::uge:
1823     return true;
1824   case arith::CmpIPredicate::ne:
1825   case arith::CmpIPredicate::slt:
1826   case arith::CmpIPredicate::sgt:
1827   case arith::CmpIPredicate::ult:
1828   case arith::CmpIPredicate::ugt:
1829     return false;
1830   }
1831   llvm_unreachable("unknown cmpi predicate kind");
1832 }
1833 
1834 static std::optional<int64_t> getIntegerWidth(Type t) {
1835   if (auto intType = llvm::dyn_cast<IntegerType>(t)) {
1836     return intType.getWidth();
1837   }
1838   if (auto vectorIntType = llvm::dyn_cast<VectorType>(t)) {
1839     return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
1840   }
1841   return std::nullopt;
1842 }
1843 
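/// Fold integer comparisons: identical operands fold to a constant determined
/// by the predicate, i1 comparisons against 0/1 simplify, a constant lhs is
/// moved to the rhs with the predicate mirrored (e.g. cmpi slt, %c5, %x
/// becomes cmpi sgt, %x, %c5), and fully constant comparisons are evaluated.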
1844 OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
1845   // cmpi(pred, x, x)
1846   if (getLhs() == getRhs()) {
1847     auto val = applyCmpPredicateToEqualOperands(getPredicate());
1848     return getBoolAttribute(getType(), val);
1849   }
1850 
1851   if (matchPattern(adaptor.getRhs(), m_Zero())) {
1852     if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
1853       // extsi(%x : i1 -> iN) != 0  ->  %x
1854       std::optional<int64_t> integerWidth =
1855           getIntegerWidth(extOp.getOperand().getType());
1856       if (integerWidth && integerWidth.value() == 1 &&
1857           getPredicate() == arith::CmpIPredicate::ne)
1858         return extOp.getOperand();
1859     }
1860     if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
1861       // extui(%x : i1 -> iN) != 0  ->  %x
1862       std::optional<int64_t> integerWidth =
1863           getIntegerWidth(extOp.getOperand().getType());
1864       if (integerWidth && integerWidth.value() == 1 &&
1865           getPredicate() == arith::CmpIPredicate::ne)
1866         return extOp.getOperand();
1867     }
1868 
1869     // arith.cmpi ne, %val, %zero : i1 -> %val
1870     if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
1871         getPredicate() == arith::CmpIPredicate::ne)
1872       return getLhs();
1873   }
1874 
1875   if (matchPattern(adaptor.getRhs(), m_One())) {
1876     // arith.cmpi eq, %val, %one : i1 -> %val
1877     if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
1878         getPredicate() == arith::CmpIPredicate::eq)
1879       return getLhs();
1880   }
1881 
1882   // Move constant to the right side.
1883   if (adaptor.getLhs() && !adaptor.getRhs()) {
1884     // Do not use invertPredicate, as it will change eq to ne and vice versa.
1885     using Pred = CmpIPredicate;
1886     const std::pair<Pred, Pred> invPreds[] = {
1887         {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
1888         {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
1889         {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
1890         {Pred::ne, Pred::ne},
1891     };
1892     Pred origPred = getPredicate();
1893     for (auto pred : invPreds) {
1894       if (origPred == pred.first) {
1895         setPredicate(pred.second);
1896         Value lhs = getLhs();
1897         Value rhs = getRhs();
1898         getLhsMutable().assign(rhs);
1899         getRhsMutable().assign(lhs);
1900         return getResult();
1901       }
1902     }
1903     llvm_unreachable("unknown cmpi predicate kind");
1904   }
1905 
1906   // Constants are moved to the right side above, so if the lhs is a constant,
1907   // the rhs is guaranteed to be a constant as well.
1908   if (auto lhs = llvm::dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
1909     return constFoldBinaryOp<IntegerAttr>(
1910         adaptor.getOperands(), getI1SameShape(lhs.getType()),
1911         [pred = getPredicate()](const APInt &lhs, const APInt &rhs) {
1912           return APInt(1,
1913                        static_cast<int64_t>(applyCmpPredicate(pred, lhs, rhs)));
1914         });
1915   }
1916 
1917   return {};
1918 }
1919 
1920 void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1921                                                 MLIRContext *context) {
1922   patterns.insert<CmpIExtSI, CmpIExtUI>(context);
1923 }
1924 
1925 //===----------------------------------------------------------------------===//
1926 // CmpFOp
1927 //===----------------------------------------------------------------------===//
1928 
1929 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
1930 /// comparison predicates.
1931 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
1932                                     const APFloat &lhs, const APFloat &rhs) {
1933   auto cmpResult = lhs.compare(rhs);
1934   switch (predicate) {
1935   case arith::CmpFPredicate::AlwaysFalse:
1936     return false;
1937   case arith::CmpFPredicate::OEQ:
1938     return cmpResult == APFloat::cmpEqual;
1939   case arith::CmpFPredicate::OGT:
1940     return cmpResult == APFloat::cmpGreaterThan;
1941   case arith::CmpFPredicate::OGE:
1942     return cmpResult == APFloat::cmpGreaterThan ||
1943            cmpResult == APFloat::cmpEqual;
1944   case arith::CmpFPredicate::OLT:
1945     return cmpResult == APFloat::cmpLessThan;
1946   case arith::CmpFPredicate::OLE:
1947     return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1948   case arith::CmpFPredicate::ONE:
1949     return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
1950   case arith::CmpFPredicate::ORD:
1951     return cmpResult != APFloat::cmpUnordered;
1952   case arith::CmpFPredicate::UEQ:
1953     return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
1954   case arith::CmpFPredicate::UGT:
1955     return cmpResult == APFloat::cmpUnordered ||
1956            cmpResult == APFloat::cmpGreaterThan;
1957   case arith::CmpFPredicate::UGE:
1958     return cmpResult == APFloat::cmpUnordered ||
1959            cmpResult == APFloat::cmpGreaterThan ||
1960            cmpResult == APFloat::cmpEqual;
1961   case arith::CmpFPredicate::ULT:
1962     return cmpResult == APFloat::cmpUnordered ||
1963            cmpResult == APFloat::cmpLessThan;
1964   case arith::CmpFPredicate::ULE:
1965     return cmpResult == APFloat::cmpUnordered ||
1966            cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1967   case arith::CmpFPredicate::UNE:
1968     return cmpResult != APFloat::cmpEqual;
1969   case arith::CmpFPredicate::UNO:
1970     return cmpResult == APFloat::cmpUnordered;
1971   case arith::CmpFPredicate::AlwaysTrue:
1972     return true;
1973   }
1974   llvm_unreachable("unknown cmpf predicate kind");
1975 }
1976 
1977 OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
1978   auto lhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
1979   auto rhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
1980 
1981   // If one operand is NaN, making them both NaN does not change the result.
1982   if (lhs && lhs.getValue().isNaN())
1983     rhs = lhs;
1984   if (rhs && rhs.getValue().isNaN())
1985     lhs = rhs;
1986 
1987   if (!lhs || !rhs)
1988     return {};
1989 
1990   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1991   return BoolAttr::get(getContext(), val);
1992 }
1993 
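/// Rewrite a cmpf whose lhs is an integer-to-fp conversion and whose rhs is a
/// floating-point constant into an equivalent cmpi on the original integer,
/// e.g. (types and values illustrative):
///   cmpf ogt, (sitofp %x : i8 to f32), 4.4  ->  cmpi sgt, %x, 4
/// The rewrite only fires when the integer width and the constant allow the
/// comparison to be answered exactly on integers.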
1994 class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
1995 public:
1996   using OpRewritePattern<CmpFOp>::OpRewritePattern;
1997 
1998   static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
1999                                                  bool isUnsigned) {
2000     using namespace arith;
2001     switch (pred) {
2002     case CmpFPredicate::UEQ:
2003     case CmpFPredicate::OEQ:
2004       return CmpIPredicate::eq;
2005     case CmpFPredicate::UGT:
2006     case CmpFPredicate::OGT:
2007       return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
2008     case CmpFPredicate::UGE:
2009     case CmpFPredicate::OGE:
2010       return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
2011     case CmpFPredicate::ULT:
2012     case CmpFPredicate::OLT:
2013       return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
2014     case CmpFPredicate::ULE:
2015     case CmpFPredicate::OLE:
2016       return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
2017     case CmpFPredicate::UNE:
2018     case CmpFPredicate::ONE:
2019       return CmpIPredicate::ne;
2020     default:
2021       llvm_unreachable("Unexpected predicate!");
2022     }
2023   }
2024 
2025   LogicalResult matchAndRewrite(CmpFOp op,
2026                                 PatternRewriter &rewriter) const override {
2027     FloatAttr flt;
2028     if (!matchPattern(op.getRhs(), m_Constant(&flt)))
2029       return failure();
2030 
2031     const APFloat &rhs = flt.getValue();
2032 
2033     // Don't attempt to fold a NaN.
2034     if (rhs.isNaN())
2035       return failure();
2036 
2037     // Get the width of the mantissa. We don't want to touch conversions that
2038     // might lose information from the integer, e.g. "i64 -> float".
2039     FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
2040     int mantissaWidth = floatTy.getFPMantissaWidth();
2041     if (mantissaWidth <= 0)
2042       return failure();
2043 
2044     bool isUnsigned;
2045     Value intVal;
2046 
2047     if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
2048       isUnsigned = false;
2049       intVal = si.getIn();
2050     } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
2051       isUnsigned = true;
2052       intVal = ui.getIn();
2053     } else {
2054       return failure();
2055     }
2056 
2057     // Check that the input is converted from an integer type that is small
2058     // enough that the conversion preserves all bits.
2059     auto intTy = llvm::cast<IntegerType>(intVal.getType());
2060     auto intWidth = intTy.getWidth();
2061 
2062     // Number of bits representing values, as opposed to the sign bit.
2063     auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
2064 
2065     // Following test does NOT adjust intWidth downwards for signed inputs,
2066     // because the most negative value still requires all the mantissa bits
2067     // to distinguish it from one less than that value.
2068     if ((int)intWidth > mantissaWidth) {
2069       // Conversion would lose accuracy. Check if loss can impact comparison.
2070       int exponent = ilogb(rhs);
2071       if (exponent == APFloat::IEK_Inf) {
2072         int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
2073         if (maxExponent < (int)valueBits) {
2074           // Conversion could create infinity.
2075           return failure();
2076         }
2077       } else {
2078         // Note that if rhs is zero or NaN, then the exponent is negative
2079         // and the first condition is trivially false.
2080         if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
2081           // Conversion could affect comparison.
2082           return failure();
2083         }
2084       }
2085     }
2086 
2087     // Convert to equivalent cmpi predicate
2088     CmpIPredicate pred;
2089     switch (op.getPredicate()) {
2090     case CmpFPredicate::ORD:
2091       // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
2092       rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2093                                                  /*width=*/1);
2094       return success();
2095     case CmpFPredicate::UNO:
2096       // Int to fp conversion doesn't create a nan (uno checks either is a nan)
2097       rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2098                                                  /*width=*/1);
2099       return success();
2100     default:
2101       pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
2102       break;
2103     }
2104 
2105     if (!isUnsigned) {
2106       // If the rhs value is > SignedMax, fold the comparison.  This handles
2107       // +INF and large values.
2108       APFloat signedMax(rhs.getSemantics());
2109       signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
2110                                  APFloat::rmNearestTiesToEven);
2111       if (signedMax < rhs) { // smax < 13123.0
2112         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
2113             pred == CmpIPredicate::sle)
2114           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2115                                                      /*width=*/1);
2116         else
2117           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2118                                                      /*width=*/1);
2119         return success();
2120       }
2121     } else {
2122       // If the rhs value is > UnsignedMax, fold the comparison. This handles
2123       // +INF and large values.
2124       APFloat unsignedMax(rhs.getSemantics());
2125       unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
2126                                    APFloat::rmNearestTiesToEven);
2127       if (unsignedMax < rhs) { // umax < 13123.0
2128         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
2129             pred == CmpIPredicate::ule)
2130           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2131                                                      /*width=*/1);
2132         else
2133           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2134                                                      /*width=*/1);
2135         return success();
2136       }
2137     }
2138 
2139     if (!isUnsigned) {
2140       // See if the rhs value is < SignedMin.
2141       APFloat signedMin(rhs.getSemantics());
2142       signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
2143                                  APFloat::rmNearestTiesToEven);
2144       if (signedMin > rhs) { // smin > 12312.0
2145         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
2146             pred == CmpIPredicate::sge)
2147           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2148                                                      /*width=*/1);
2149         else
2150           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2151                                                      /*width=*/1);
2152         return success();
2153       }
2154     } else {
2155       // See if the rhs value is < UnsignedMin.
2156       APFloat unsignedMin(rhs.getSemantics());
2157       unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
2158                                    APFloat::rmNearestTiesToEven);
2159       if (unsignedMin > rhs) { // umin > 12312.0
2160         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
2161             pred == CmpIPredicate::uge)
2162           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2163                                                      /*width=*/1);
2164         else
2165           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2166                                                      /*width=*/1);
2167         return success();
2168       }
2169     }
2170 
2171     // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
2172     // [0, UMAX], but it may still be fractional.  See if it is fractional by
2173     // casting the FP value to the integer value and back, checking for
2174     // equality. Don't do this for zero, because -0.0 is not fractional.
2175     bool ignored;
2176     APSInt rhsInt(intWidth, isUnsigned);
2177     if (APFloat::opInvalidOp ==
2178         rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
2179       // Undefined behavior invoked - the destination type can't represent
2180       // the input constant.
2181       return failure();
2182     }
2183 
2184     if (!rhs.isZero()) {
2185       APFloat apf(floatTy.getFloatSemantics(),
2186                   APInt::getZero(floatTy.getWidth()));
2187       apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
2188 
2189       bool equal = apf == rhs;
2190       if (!equal) {
2191         // If we had a comparison against a fractional value, we have to adjust
2192         // the compare predicate and sometimes the value.  rhsInt is rounded
2193         // towards zero at this point.
2194         switch (pred) {
2195         case CmpIPredicate::ne: // (float)int != 4.4   --> true
2196           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2197                                                      /*width=*/1);
2198           return success();
2199         case CmpIPredicate::eq: // (float)int == 4.4   --> false
2200           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2201                                                      /*width=*/1);
2202           return success();
2203         case CmpIPredicate::ule:
2204           // (float)int <= 4.4   --> int <= 4
2205           // (float)int <= -4.4  --> false
2206           if (rhs.isNegative()) {
2207             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2208                                                        /*width=*/1);
2209             return success();
2210           }
2211           break;
2212         case CmpIPredicate::sle:
2213           // (float)int <= 4.4   --> int <= 4
2214           // (float)int <= -4.4  --> int < -4
2215           if (rhs.isNegative())
2216             pred = CmpIPredicate::slt;
2217           break;
2218         case CmpIPredicate::ult:
2219           // (float)int < -4.4   --> false
2220           // (float)int < 4.4    --> int <= 4
2221           if (rhs.isNegative()) {
2222             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2223                                                        /*width=*/1);
2224             return success();
2225           }
2226           pred = CmpIPredicate::ule;
2227           break;
2228         case CmpIPredicate::slt:
2229           // (float)int < -4.4   --> int < -4
2230           // (float)int < 4.4    --> int <= 4
2231           if (!rhs.isNegative())
2232             pred = CmpIPredicate::sle;
2233           break;
2234         case CmpIPredicate::ugt:
2235           // (float)int > 4.4    --> int > 4
2236           // (float)int > -4.4   --> true
2237           if (rhs.isNegative()) {
2238             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2239                                                        /*width=*/1);
2240             return success();
2241           }
2242           break;
2243         case CmpIPredicate::sgt:
2244           // (float)int > 4.4    --> int > 4
2245           // (float)int > -4.4   --> int >= -4
2246           if (rhs.isNegative())
2247             pred = CmpIPredicate::sge;
2248           break;
2249         case CmpIPredicate::uge:
2250           // (float)int >= -4.4   --> true
2251           // (float)int >= 4.4    --> int > 4
2252           if (rhs.isNegative()) {
2253             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2254                                                        /*width=*/1);
2255             return success();
2256           }
2257           pred = CmpIPredicate::ugt;
2258           break;
2259         case CmpIPredicate::sge:
2260           // (float)int >= -4.4   --> int >= -4
2261           // (float)int >= 4.4    --> int > 4
2262           if (!rhs.isNegative())
2263             pred = CmpIPredicate::sgt;
2264           break;
2265         }
2266       }
2267     }
2268 
2269     // Lower this FP comparison into an appropriate integer version of the
2270     // comparison.
2271     rewriter.replaceOpWithNewOp<CmpIOp>(
2272         op, pred, intVal,
2273         rewriter.create<ConstantOp>(
2274             op.getLoc(), intVal.getType(),
2275             rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
2276     return success();
2277   }
2278 };
2279 
2280 void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2281                                                 MLIRContext *context) {
2282   patterns.insert<CmpFIntToFPConst>(context);
2283 }
2284 
2285 //===----------------------------------------------------------------------===//
2286 // SelectOp
2287 //===----------------------------------------------------------------------===//
2288 
2289 //  select %arg, %c1, %c0 => extui %arg
2290 struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
2291   using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
2292 
2293   LogicalResult matchAndRewrite(arith::SelectOp op,
2294                                 PatternRewriter &rewriter) const override {
2295     // Cannot extui i1 to i1, or i1 to f32
2296     if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
2297       return failure();
2298 
2299     // select %x, %c1, %c0 => extui %x
2300     if (matchPattern(op.getTrueValue(), m_One()) &&
2301         matchPattern(op.getFalseValue(), m_Zero())) {
2302       rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
2303                                                   op.getCondition());
2304       return success();
2305     }
2306 
2307     // select %x, %c0, %c1 => extui (xor %x, true)
2308     if (matchPattern(op.getTrueValue(), m_Zero()) &&
2309         matchPattern(op.getFalseValue(), m_One())) {
2310       rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
2311           op, op.getType(),
2312           rewriter.create<arith::XOrIOp>(
2313               op.getLoc(), op.getCondition(),
2314               rewriter.create<arith::ConstantIntOp>(
2315                   op.getLoc(), 1, op.getCondition().getType())));
2316       return success();
2317     }
2318 
2319     return failure();
2320   }
2321 };
2322 
2323 void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
2324                                                   MLIRContext *context) {
2325   results.add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
2326               SelectI1ToNot, SelectToExtUI>(context);
2327 }
2328 
2329 OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
2330   Value trueVal = getTrueValue();
2331   Value falseVal = getFalseValue();
2332   if (trueVal == falseVal)
2333     return trueVal;
2334 
2335   Value condition = getCondition();
2336 
2337   // select true, %0, %1 => %0
2338   if (matchPattern(adaptor.getCondition(), m_One()))
2339     return trueVal;
2340 
2341   // select false, %0, %1 => %1
2342   if (matchPattern(adaptor.getCondition(), m_Zero()))
2343     return falseVal;
2344 
2345   // If either operand is fully poisoned, return the other.
2346   if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getTrueValue()))
2347     return falseVal;
2348 
2349   if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getFalseValue()))
2350     return trueVal;
2351 
2352   // select %x, true, false => %x
2353   if (getType().isSignlessInteger(1) &&
2354       matchPattern(adaptor.getTrueValue(), m_One()) &&
2355       matchPattern(adaptor.getFalseValue(), m_Zero()))
2356     return condition;
2357 
2358   if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
2359     auto pred = cmp.getPredicate();
2360     if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
2361       auto cmpLhs = cmp.getLhs();
2362       auto cmpRhs = cmp.getRhs();
2363 
2364       // %0 = arith.cmpi eq, %arg0, %arg1
2365       // %1 = arith.select %0, %arg0, %arg1 => %arg1
2366 
2367       // %0 = arith.cmpi ne, %arg0, %arg1
2368       // %1 = arith.select %0, %arg0, %arg1 => %arg0
2369 
2370       if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
2371           (cmpRhs == trueVal && cmpLhs == falseVal))
2372         return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
2373     }
2374   }
2375 
2376   // Constant-fold constant operands over non-splat constant condition.
2377   // select %cst_vec, %cst0, %cst1 => %cst2
2378   if (auto cond =
2379           llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
2380     if (auto lhs =
2381             llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
2382       if (auto rhs =
2383               llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
2384         SmallVector<Attribute> results;
2385         results.reserve(static_cast<size_t>(cond.getNumElements()));
2386         auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
2387                                          cond.value_end<BoolAttr>());
2388         auto lhsVals = llvm::make_range(lhs.value_begin<Attribute>(),
2389                                         lhs.value_end<Attribute>());
2390         auto rhsVals = llvm::make_range(rhs.value_begin<Attribute>(),
2391                                         rhs.value_end<Attribute>());
2392 
2393         for (auto [condVal, lhsVal, rhsVal] :
2394              llvm::zip_equal(condVals, lhsVals, rhsVals))
2395           results.push_back(condVal.getValue() ? lhsVal : rhsVal);
2396 
2397         return DenseElementsAttr::get(lhs.getType(), results);
2398       }
2399     }
2400   }
2401 
2402   return nullptr;
2403 }
2404 
2405 ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
2406   Type conditionType, resultType;
2407   SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
2408   if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
2409       parser.parseOptionalAttrDict(result.attributes) ||
2410       parser.parseColonType(resultType))
2411     return failure();
2412 
2413   // Check for the explicit condition type if this is a masked tensor or vector.
2414   if (succeeded(parser.parseOptionalComma())) {
2415     conditionType = resultType;
2416     if (parser.parseType(resultType))
2417       return failure();
2418   } else {
2419     conditionType = parser.getBuilder().getI1Type();
2420   }
2421 
2422   result.addTypes(resultType);
2423   return parser.resolveOperands(operands,
2424                                 {conditionType, resultType, resultType},
2425                                 parser.getNameLoc(), result.operands);
2426 }
2427 
2428 void arith::SelectOp::print(OpAsmPrinter &p) {
2429   p << " " << getOperands();
2430   p.printOptionalAttrDict((*this)->getAttrs());
2431   p << " : ";
2432   if (ShapedType condType =
2433           llvm::dyn_cast<ShapedType>(getCondition().getType()))
2434     p << condType << ", ";
2435   p << getType();
2436 }
2437 
2438 LogicalResult arith::SelectOp::verify() {
2439   Type conditionType = getCondition().getType();
2440   if (conditionType.isSignlessInteger(1))
2441     return success();
2442 
2443   // If the result type is a vector or tensor, the type can be a mask with the
2444   // same elements.
2445   Type resultType = getType();
2446   if (!llvm::isa<TensorType, VectorType>(resultType))
2447     return emitOpError() << "expected condition to be a signless i1, but got "
2448                          << conditionType;
2449   Type shapedConditionType = getI1SameShape(resultType);
2450   if (conditionType != shapedConditionType) {
2451     return emitOpError() << "expected condition type to have the same shape "
2452                             "as the result type, expected "
2453                          << shapedConditionType << ", but got "
2454                          << conditionType;
2455   }
2456   return success();
2457 }
2458 //===----------------------------------------------------------------------===//
2459 // ShLIOp
2460 //===----------------------------------------------------------------------===//
2461 
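/// Fold shifts by zero and fully constant shifts, e.g. shli of the i8
/// constants 1 and 3 folds to 8; shift amounts greater than or equal to the
/// bit width are left unfolded.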
2462 OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
2463   // shli(x, 0) -> x
2464   if (matchPattern(adaptor.getRhs(), m_Zero()))
2465     return getLhs();
2466   // Don't fold if shifting by the bit width or more.
2467   bool bounded = false;
2468   auto result = constFoldBinaryOp<IntegerAttr>(
2469       adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2470         bounded = b.ult(b.getBitWidth());
2471         return a.shl(b);
2472       });
2473   return bounded ? result : Attribute();
2474 }
2475 
2476 //===----------------------------------------------------------------------===//
2477 // ShRUIOp
2478 //===----------------------------------------------------------------------===//
2479 
2480 OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
2481   // shrui(x, 0) -> x
2482   if (matchPattern(adaptor.getRhs(), m_Zero()))
2483     return getLhs();
2484   // Don't fold if shifting by the bit width or more.
2485   bool bounded = false;
2486   auto result = constFoldBinaryOp<IntegerAttr>(
2487       adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2488         bounded = b.ult(b.getBitWidth());
2489         return a.lshr(b);
2490       });
2491   return bounded ? result : Attribute();
2492 }
2493 
2494 //===----------------------------------------------------------------------===//
2495 // ShRSIOp
2496 //===----------------------------------------------------------------------===//
2497 
2498 OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
2499   // shrsi(x, 0) -> x
2500   if (matchPattern(adaptor.getRhs(), m_Zero()))
2501     return getLhs();
2502   // Don't fold if shifting by the bit width or more.
2503   bool bounded = false;
2504   auto result = constFoldBinaryOp<IntegerAttr>(
2505       adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2506         bounded = b.ult(b.getBitWidth());
2507         return a.ashr(b);
2508       });
2509   return bounded ? result : Attribute();
2510 }
2511 
2512 //===----------------------------------------------------------------------===//
2513 // Atomic Enum
2514 //===----------------------------------------------------------------------===//
2515 
2516 /// Returns the identity value attribute associated with an AtomicRMWKind op.
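/// For example, the identity for maximumf is -inf (or the most negative
/// finite value when `useOnlyFiniteValue` is set), the identity for addi and
/// addf is 0, and the identity for andi is all-ones.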
2517 TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
2518                                             OpBuilder &builder, Location loc,
2519                                             bool useOnlyFiniteValue) {
2520   switch (kind) {
2521   case AtomicRMWKind::maximumf: {
2522     const llvm::fltSemantics &semantic =
2523         llvm::cast<FloatType>(resultType).getFloatSemantics();
2524     APFloat identity = useOnlyFiniteValue
2525                            ? APFloat::getLargest(semantic, /*Negative=*/true)
2526                            : APFloat::getInf(semantic, /*Negative=*/true);
2527     return builder.getFloatAttr(resultType, identity);
2528   }
2529   case AtomicRMWKind::maxnumf: {
2530     const llvm::fltSemantics &semantic =
2531         llvm::cast<FloatType>(resultType).getFloatSemantics();
2532     APFloat identity = APFloat::getNaN(semantic, /*Negative=*/true);
2533     return builder.getFloatAttr(resultType, identity);
2534   }
2535   case AtomicRMWKind::addf:
2536   case AtomicRMWKind::addi:
2537   case AtomicRMWKind::maxu:
2538   case AtomicRMWKind::ori:
2539     return builder.getZeroAttr(resultType);
2540   case AtomicRMWKind::andi:
2541     return builder.getIntegerAttr(
2542         resultType,
2543         APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
2544   case AtomicRMWKind::maxs:
2545     return builder.getIntegerAttr(
2546         resultType, APInt::getSignedMinValue(
2547                         llvm::cast<IntegerType>(resultType).getWidth()));
2548   case AtomicRMWKind::minimumf: {
2549     const llvm::fltSemantics &semantic =
2550         llvm::cast<FloatType>(resultType).getFloatSemantics();
2551     APFloat identity = useOnlyFiniteValue
2552                            ? APFloat::getLargest(semantic, /*Negative=*/false)
2553                            : APFloat::getInf(semantic, /*Negative=*/false);
2554 
2555     return builder.getFloatAttr(resultType, identity);
2556   }
2557   case AtomicRMWKind::minnumf: {
2558     const llvm::fltSemantics &semantic =
2559         llvm::cast<FloatType>(resultType).getFloatSemantics();
2560     APFloat identity = APFloat::getNaN(semantic, /*Negative=*/false);
2561     return builder.getFloatAttr(resultType, identity);
2562   }
2563   case AtomicRMWKind::mins:
2564     return builder.getIntegerAttr(
2565         resultType, APInt::getSignedMaxValue(
2566                         llvm::cast<IntegerType>(resultType).getWidth()));
2567   case AtomicRMWKind::minu:
2568     return builder.getIntegerAttr(
2569         resultType,
2570         APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
2571   case AtomicRMWKind::muli:
2572     return builder.getIntegerAttr(resultType, 1);
2573   case AtomicRMWKind::mulf:
2574     return builder.getFloatAttr(resultType, 1);
2575   // TODO: Add remaining reduction operations.
2576   default:
2577     (void)emitOptionalError(loc, "Reduction operation type not supported");
2578     break;
2579   }
2580   return nullptr;
2581 }
2582 
2583 /// Return the identity numeric value associated with the given op.
2584 std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
2585   std::optional<AtomicRMWKind> maybeKind =
2586       llvm::TypeSwitch<Operation *, std::optional<AtomicRMWKind>>(op)
2587           // Floating-point operations.
2588           .Case([](arith::AddFOp op) { return AtomicRMWKind::addf; })
2589           .Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
2590           .Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })
2591           .Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })
2592           .Case([](arith::MaxNumFOp op) { return AtomicRMWKind::maxnumf; })
2593           .Case([](arith::MinNumFOp op) { return AtomicRMWKind::minnumf; })
2594           // Integer operations.
2595           .Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
2596           .Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
2597           .Case([](arith::XOrIOp op) { return AtomicRMWKind::ori; })
2598           .Case([](arith::AndIOp op) { return AtomicRMWKind::andi; })
2599           .Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; })
2600           .Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; })
2601           .Case([](arith::MaxSIOp op) { return AtomicRMWKind::maxs; })
2602           .Case([](arith::MinSIOp op) { return AtomicRMWKind::mins; })
2603           .Case([](arith::MulIOp op) { return AtomicRMWKind::muli; })
2604           .Default([](Operation *op) { return std::nullopt; });
2605   if (!maybeKind) {
2606     return std::nullopt;
2607   }
2608 
2609   bool useOnlyFiniteValue = false;
2610   auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
2611   if (fmfOpInterface) {
2612     arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
2613     useOnlyFiniteValue =
2614         bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
2615   }
2616 
2617   // Builder only used as helper for attribute creation.
2618   OpBuilder b(op->getContext());
2619   Type resultType = op->getResult(0).getType();
2620 
2621   return getIdentityValueAttr(*maybeKind, resultType, b, op->getLoc(),
2622                               useOnlyFiniteValue);
2623 }
2624 
2625 /// Returns the identity value associated with an AtomicRMWKind op.
2626 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
2627                                     OpBuilder &builder, Location loc,
2628                                     bool useOnlyFiniteValue) {
2629   auto attr =
2630       getIdentityValueAttr(op, resultType, builder, loc, useOnlyFiniteValue);
2631   return builder.create<arith::ConstantOp>(loc, attr);
2632 }
2633 
2634 /// Return the value obtained by applying the reduction operation kind
2635 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
2636 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
2637                                   Location loc, Value lhs, Value rhs) {
2638   switch (op) {
2639   case AtomicRMWKind::addf:
2640     return builder.create<arith::AddFOp>(loc, lhs, rhs);
2641   case AtomicRMWKind::addi:
2642     return builder.create<arith::AddIOp>(loc, lhs, rhs);
2643   case AtomicRMWKind::mulf:
2644     return builder.create<arith::MulFOp>(loc, lhs, rhs);
2645   case AtomicRMWKind::muli:
2646     return builder.create<arith::MulIOp>(loc, lhs, rhs);
2647   case AtomicRMWKind::maximumf:
2648     return builder.create<arith::MaximumFOp>(loc, lhs, rhs);
2649   case AtomicRMWKind::minimumf:
2650     return builder.create<arith::MinimumFOp>(loc, lhs, rhs);
2651   case AtomicRMWKind::maxnumf:
2652     return builder.create<arith::MaxNumFOp>(loc, lhs, rhs);
2653   case AtomicRMWKind::minnumf:
2654     return builder.create<arith::MinNumFOp>(loc, lhs, rhs);
2655   case AtomicRMWKind::maxs:
2656     return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
2657   case AtomicRMWKind::mins:
2658     return builder.create<arith::MinSIOp>(loc, lhs, rhs);
2659   case AtomicRMWKind::maxu:
2660     return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
2661   case AtomicRMWKind::minu:
2662     return builder.create<arith::MinUIOp>(loc, lhs, rhs);
2663   case AtomicRMWKind::ori:
2664     return builder.create<arith::OrIOp>(loc, lhs, rhs);
2665   case AtomicRMWKind::andi:
2666     return builder.create<arith::AndIOp>(loc, lhs, rhs);
2667   // TODO: Add remaining reduction operations.
2668   default:
2669     (void)emitOptionalError(loc, "Reduction operation type not supported");
2670     break;
2671   }
2672   return nullptr;
2673 }
2674 
2675 //===----------------------------------------------------------------------===//
2676 // TableGen'd op method definitions
2677 //===----------------------------------------------------------------------===//
2678 
2679 #define GET_OP_CLASSES
2680 #include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
2681 
2682 //===----------------------------------------------------------------------===//
2683 // TableGen'd enum attribute definitions
2684 //===----------------------------------------------------------------------===//
2685 
2686 #include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
2687