//===- ArithOps.cpp - MLIR Arith dialect ops implementation -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <cassert>
#include <cstdint>
#include <functional>
#include <utility>

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LogicalResult.h"

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/FloatingPointMode.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::arith;

//===----------------------------------------------------------------------===//
// Pattern helpers
//===----------------------------------------------------------------------===//

/// Apply `binFn` to the APInt payloads of the integer attributes `lhs` and
/// `rhs`, and wrap the result in an IntegerAttr whose type is the type of
/// `res`. Both attributes must be IntegerAttrs (enforced by the casts below).
static IntegerAttr
applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs,
                    Attribute rhs,
                    function_ref<APInt(const APInt &, const APInt &)> binFn) {
  APInt lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
  APInt rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
  APInt value = binFn(lhsVal, rhsVal);
  return IntegerAttr::get(res.getType(), value);
}

/// Constant-fold `lhs + rhs` (two's-complement wrapping addition).
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
                                   Attribute lhs, Attribute rhs) {
  return applyToIntegerAttrs(builder, res, lhs, rhs, std::plus<APInt>());
}

/// Constant-fold `lhs - rhs` (two's-complement wrapping subtraction).
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
                                   Attribute lhs, Attribute rhs) {
  return applyToIntegerAttrs(builder, res, lhs, rhs, std::minus<APInt>());
}

/// Constant-fold `lhs * rhs` (two's-complement wrapping multiplication).
static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
                                   Attribute lhs, Attribute rhs) {
  return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
}

// Merge overflow flags from 2 ops, selecting the most conservative combination.
// The bitwise AND keeps only the flags (nsw/nuw) that both ops guarantee.
static IntegerOverflowFlagsAttr
mergeOverflowFlags(IntegerOverflowFlagsAttr val1,
                   IntegerOverflowFlagsAttr val2) {
  return IntegerOverflowFlagsAttr::get(val1.getContext(),
                                       val1.getValue() & val2.getValue());
}

/// Invert an integer comparison predicate.
arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
  switch (pred) {
  case arith::CmpIPredicate::eq:
    return arith::CmpIPredicate::ne;
  case arith::CmpIPredicate::ne:
    return arith::CmpIPredicate::eq;
  case arith::CmpIPredicate::slt:
    return arith::CmpIPredicate::sge;
  case arith::CmpIPredicate::sle:
    return arith::CmpIPredicate::sgt;
  case arith::CmpIPredicate::sgt:
    return arith::CmpIPredicate::sle;
  case arith::CmpIPredicate::sge:
    return arith::CmpIPredicate::slt;
  case arith::CmpIPredicate::ult:
    return arith::CmpIPredicate::uge;
  case arith::CmpIPredicate::ule:
    return arith::CmpIPredicate::ugt;
  case arith::CmpIPredicate::ugt:
    return arith::CmpIPredicate::ule;
  case arith::CmpIPredicate::uge:
    return arith::CmpIPredicate::ult;
  }
  llvm_unreachable("unknown cmpi predicate kind");
}

/// Equivalent to
/// convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
104 /// 105 /// Not possible to implement as chain of calls as this would introduce a 106 /// circular dependency with MLIRArithAttrToLLVMConversion and make arith depend 107 /// on the LLVM dialect and on translation to LLVM. 108 static llvm::RoundingMode 109 convertArithRoundingModeToLLVMIR(RoundingMode roundingMode) { 110 switch (roundingMode) { 111 case RoundingMode::downward: 112 return llvm::RoundingMode::TowardNegative; 113 case RoundingMode::to_nearest_away: 114 return llvm::RoundingMode::NearestTiesToAway; 115 case RoundingMode::to_nearest_even: 116 return llvm::RoundingMode::NearestTiesToEven; 117 case RoundingMode::toward_zero: 118 return llvm::RoundingMode::TowardZero; 119 case RoundingMode::upward: 120 return llvm::RoundingMode::TowardPositive; 121 } 122 llvm_unreachable("Unhandled rounding mode"); 123 } 124 125 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) { 126 return arith::CmpIPredicateAttr::get(pred.getContext(), 127 invertPredicate(pred.getValue())); 128 } 129 130 static int64_t getScalarOrElementWidth(Type type) { 131 Type elemTy = getElementTypeOrSelf(type); 132 if (elemTy.isIntOrFloat()) 133 return elemTy.getIntOrFloatBitWidth(); 134 135 return -1; 136 } 137 138 static int64_t getScalarOrElementWidth(Value value) { 139 return getScalarOrElementWidth(value.getType()); 140 } 141 142 static FailureOr<APInt> getIntOrSplatIntValue(Attribute attr) { 143 APInt value; 144 if (matchPattern(attr, m_ConstantInt(&value))) 145 return value; 146 147 return failure(); 148 } 149 150 static Attribute getBoolAttribute(Type type, bool value) { 151 auto boolAttr = BoolAttr::get(type.getContext(), value); 152 ShapedType shapedType = llvm::dyn_cast_or_null<ShapedType>(type); 153 if (!shapedType) 154 return boolAttr; 155 return DenseElementsAttr::get(shapedType, boolAttr); 156 } 157 158 //===----------------------------------------------------------------------===// 159 // TableGen'd canonicalization patterns 160 
//===----------------------------------------------------------------------===// 161 162 namespace { 163 #include "ArithCanonicalization.inc" 164 } // namespace 165 166 //===----------------------------------------------------------------------===// 167 // Common helpers 168 //===----------------------------------------------------------------------===// 169 170 /// Return the type of the same shape (scalar, vector or tensor) containing i1. 171 static Type getI1SameShape(Type type) { 172 auto i1Type = IntegerType::get(type.getContext(), 1); 173 if (auto shapedType = llvm::dyn_cast<ShapedType>(type)) 174 return shapedType.cloneWith(std::nullopt, i1Type); 175 if (llvm::isa<UnrankedTensorType>(type)) 176 return UnrankedTensorType::get(i1Type); 177 return i1Type; 178 } 179 180 //===----------------------------------------------------------------------===// 181 // ConstantOp 182 //===----------------------------------------------------------------------===// 183 184 void arith::ConstantOp::getAsmResultNames( 185 function_ref<void(Value, StringRef)> setNameFn) { 186 auto type = getType(); 187 if (auto intCst = llvm::dyn_cast<IntegerAttr>(getValue())) { 188 auto intType = llvm::dyn_cast<IntegerType>(type); 189 190 // Sugar i1 constants with 'true' and 'false'. 191 if (intType && intType.getWidth() == 1) 192 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false")); 193 194 // Otherwise, build a complex name with the value and type. 195 SmallString<32> specialNameBuffer; 196 llvm::raw_svector_ostream specialName(specialNameBuffer); 197 specialName << 'c' << intCst.getValue(); 198 if (intType) 199 specialName << '_' << type; 200 setNameFn(getResult(), specialName.str()); 201 } else { 202 setNameFn(getResult(), "cst"); 203 } 204 } 205 206 /// TODO: disallow arith.constant to return anything other than signless integer 207 /// or float like. 208 LogicalResult arith::ConstantOp::verify() { 209 auto type = getType(); 210 // The value's type must match the return type. 
211 if (getValue().getType() != type) { 212 return emitOpError() << "value type " << getValue().getType() 213 << " must match return type: " << type; 214 } 215 // Integer values must be signless. 216 if (llvm::isa<IntegerType>(type) && 217 !llvm::cast<IntegerType>(type).isSignless()) 218 return emitOpError("integer return type must be signless"); 219 // Any float or elements attribute are acceptable. 220 if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) { 221 return emitOpError( 222 "value must be an integer, float, or elements attribute"); 223 } 224 225 // Note, we could relax this for vectors with 1 scalable dim, e.g.: 226 // * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32> 227 // However, this would most likely require updating the lowerings to LLVM. 228 if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue())) 229 return emitOpError( 230 "intializing scalable vectors with elements attribute is not supported" 231 " unless it's a vector splat"); 232 return success(); 233 } 234 235 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) { 236 // The value's type must be the same as the provided type. 237 auto typedAttr = llvm::dyn_cast<TypedAttr>(value); 238 if (!typedAttr || typedAttr.getType() != type) 239 return false; 240 // Integer values must be signless. 241 if (llvm::isa<IntegerType>(type) && 242 !llvm::cast<IntegerType>(type).isSignless()) 243 return false; 244 // Integer, float, and element attributes are buildable. 
245 return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value); 246 } 247 248 ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value, 249 Type type, Location loc) { 250 if (isBuildableWith(value, type)) 251 return builder.create<arith::ConstantOp>(loc, cast<TypedAttr>(value)); 252 return nullptr; 253 } 254 255 OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); } 256 257 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, 258 int64_t value, unsigned width) { 259 auto type = builder.getIntegerType(width); 260 arith::ConstantOp::build(builder, result, type, 261 builder.getIntegerAttr(type, value)); 262 } 263 264 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, 265 int64_t value, Type type) { 266 assert(type.isSignlessInteger() && 267 "ConstantIntOp can only have signless integer type values"); 268 arith::ConstantOp::build(builder, result, type, 269 builder.getIntegerAttr(type, value)); 270 } 271 272 bool arith::ConstantIntOp::classof(Operation *op) { 273 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 274 return constOp.getType().isSignlessInteger(); 275 return false; 276 } 277 278 void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result, 279 const APFloat &value, FloatType type) { 280 arith::ConstantOp::build(builder, result, type, 281 builder.getFloatAttr(type, value)); 282 } 283 284 bool arith::ConstantFloatOp::classof(Operation *op) { 285 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 286 return llvm::isa<FloatType>(constOp.getType()); 287 return false; 288 } 289 290 void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result, 291 int64_t value) { 292 arith::ConstantOp::build(builder, result, builder.getIndexType(), 293 builder.getIndexAttr(value)); 294 } 295 296 bool arith::ConstantIndexOp::classof(Operation *op) { 297 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 298 return 
constOp.getType().isIndex(); 299 return false; 300 } 301 302 //===----------------------------------------------------------------------===// 303 // AddIOp 304 //===----------------------------------------------------------------------===// 305 306 OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) { 307 // addi(x, 0) -> x 308 if (matchPattern(adaptor.getRhs(), m_Zero())) 309 return getLhs(); 310 311 // addi(subi(a, b), b) -> a 312 if (auto sub = getLhs().getDefiningOp<SubIOp>()) 313 if (getRhs() == sub.getRhs()) 314 return sub.getLhs(); 315 316 // addi(b, subi(a, b)) -> a 317 if (auto sub = getRhs().getDefiningOp<SubIOp>()) 318 if (getLhs() == sub.getRhs()) 319 return sub.getLhs(); 320 321 return constFoldBinaryOp<IntegerAttr>( 322 adaptor.getOperands(), 323 [](APInt a, const APInt &b) { return std::move(a) + b; }); 324 } 325 326 void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 327 MLIRContext *context) { 328 patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS, 329 AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context); 330 } 331 332 //===----------------------------------------------------------------------===// 333 // AddUIExtendedOp 334 //===----------------------------------------------------------------------===// 335 336 std::optional<SmallVector<int64_t, 4>> 337 arith::AddUIExtendedOp::getShapeForUnroll() { 338 if (auto vt = llvm::dyn_cast<VectorType>(getType(0))) 339 return llvm::to_vector<4>(vt.getShape()); 340 return std::nullopt; 341 } 342 343 // Returns the overflow bit, assuming that `sum` is the result of unsigned 344 // addition of `operand` and another number. 345 static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand) { 346 return sum.ult(operand) ? 
             APInt::getAllOnes(1) : APInt::getZero(1);
}

LogicalResult
arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<OpFoldResult> &results) {
  Type overflowTy = getOverflow().getType();
  // addui_extended(x, 0) -> x, false
  if (matchPattern(getRhs(), m_Zero())) {
    Builder builder(getContext());
    auto falseValue = builder.getZeroAttr(overflowTy);

    results.push_back(getLhs());
    results.push_back(falseValue);
    return success();
  }

  // addui_extended(constant_a, constant_b) -> constant_sum, constant_carry
  // Let the `constFoldBinaryOp` utility attempt to fold the sum of both
  // operands. If that succeeds, calculate the overflow bit based on the sum
  // and the first (constant) operand, `lhs`.
  if (Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
          adaptor.getOperands(),
          [](APInt a, const APInt &b) { return std::move(a) + b; })) {
    Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>(
        ArrayRef({sumAttr, adaptor.getLhs()}),
        getI1SameShape(llvm::cast<TypedAttr>(sumAttr).getType()),
        calculateUnsignedOverflow);
    if (!overflowAttr)
      return failure();

    results.push_back(sumAttr);
    results.push_back(overflowAttr);
    return success();
  }

  return failure();
}

void arith::AddUIExtendedOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<AddUIExtendedToAddI>(context);
}

//===----------------------------------------------------------------------===//
// SubIOp
//===----------------------------------------------------------------------===//

OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
  // subi(x,x) -> 0
  if (getOperand(0) == getOperand(1)) {
    auto shapedType = dyn_cast<ShapedType>(getType());
    // We can't generate a constant with a dynamic shaped tensor.
    if (!shapedType || shapedType.hasStaticShape())
      return Builder(getContext()).getZeroAttr(getType());
  }
  // subi(x,0) -> x
  if (matchPattern(adaptor.getRhs(), m_Zero()))
    return getLhs();

  if (auto add = getLhs().getDefiningOp<AddIOp>()) {
    // subi(addi(a, b), b) -> a
    if (getRhs() == add.getRhs())
      return add.getLhs();
    // subi(addi(a, b), a) -> b
    if (getRhs() == add.getLhs())
      return add.getRhs();
  }

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](APInt a, const APInt &b) { return std::move(a) - b; });
}

void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
               SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
               SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
}

//===----------------------------------------------------------------------===//
// MulIOp
//===----------------------------------------------------------------------===//

OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
  // muli(x, 0) -> 0
  if (matchPattern(adaptor.getRhs(), m_Zero()))
    return getRhs();
  // muli(x, 1) -> x
  if (matchPattern(adaptor.getRhs(), m_One()))
    return getLhs();
  // TODO: Handle the overflow case.

  // default folder
  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](const APInt &a, const APInt &b) { return a * b; });
}

/// Name `base * vscale` products of index type as `c<base>_vscale` so
/// scalable-size computations read naturally in IR dumps.
void arith::MulIOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  if (!isa<IndexType>(getType()))
    return;

  // Match vector.vscale by name to avoid depending on the vector dialect (which
  // is a circular dependency).
  auto isVscale = [](Operation *op) {
    return op && op->getName().getStringRef() == "vector.vscale";
  };

  IntegerAttr baseValue;
  auto isVscaleExpr = [&](Value a, Value b) {
    return matchPattern(a, m_Constant(&baseValue)) &&
           isVscale(b.getDefiningOp());
  };

  if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
    return;

  // Name `base * vscale` or `vscale * base` as `c<base_value>_vscale`.
  SmallString<32> specialNameBuffer;
  llvm::raw_svector_ostream specialName(specialNameBuffer);
  specialName << 'c' << baseValue.getInt() << "_vscale";
  setNameFn(getResult(), specialName.str());
}

void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<MulIMulIConstant>(context);
}

//===----------------------------------------------------------------------===//
// MulSIExtendedOp
//===----------------------------------------------------------------------===//

std::optional<SmallVector<int64_t, 4>>
arith::MulSIExtendedOp::getShapeForUnroll() {
  if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
    return llvm::to_vector<4>(vt.getShape());
  return std::nullopt;
}

LogicalResult
arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<OpFoldResult> &results) {
  // mulsi_extended(x, 0) -> 0, 0
  if (matchPattern(adaptor.getRhs(), m_Zero())) {
    Attribute zero = adaptor.getRhs();
    results.push_back(zero);
    results.push_back(zero);
    return success();
  }

  // mulsi_extended(cst_a, cst_b) -> cst_low, cst_high
  if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
          adaptor.getOperands(),
          [](const APInt &a, const APInt &b) { return a * b; })) {
    // Invoke the constant fold helper again to calculate the 'high' result.
    Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
        adaptor.getOperands(), [](const APInt &a, const APInt &b) {
          return llvm::APIntOps::mulhs(a, b);
        });
    assert(highAttr && "Unexpected constant-folding failure");

    results.push_back(lowAttr);
    results.push_back(highAttr);
    return success();
  }

  return failure();
}

void arith::MulSIExtendedOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
}

//===----------------------------------------------------------------------===//
// MulUIExtendedOp
//===----------------------------------------------------------------------===//

std::optional<SmallVector<int64_t, 4>>
arith::MulUIExtendedOp::getShapeForUnroll() {
  if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
    return llvm::to_vector<4>(vt.getShape());
  return std::nullopt;
}

LogicalResult
arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<OpFoldResult> &results) {
  // mului_extended(x, 0) -> 0, 0
  if (matchPattern(adaptor.getRhs(), m_Zero())) {
    Attribute zero = adaptor.getRhs();
    results.push_back(zero);
    results.push_back(zero);
    return success();
  }

  // mului_extended(x, 1) -> x, 0
  if (matchPattern(adaptor.getRhs(), m_One())) {
    Builder builder(getContext());
    Attribute zero = builder.getZeroAttr(getLhs().getType());
    results.push_back(getLhs());
    results.push_back(zero);
    return success();
  }

  // mului_extended(cst_a, cst_b) -> cst_low, cst_high
  if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
          adaptor.getOperands(),
          [](const APInt &a, const APInt &b) { return a * b; })) {
    // Invoke the constant fold helper again to calculate the 'high' result.
    Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
        adaptor.getOperands(), [](const APInt &a, const APInt &b) {
          return llvm::APIntOps::mulhu(a, b);
        });
    assert(highAttr && "Unexpected constant-folding failure");

    results.push_back(lowAttr);
    results.push_back(highAttr);
    return success();
  }

  return failure();
}

void arith::MulUIExtendedOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<MulUIExtendedToMulI>(context);
}

//===----------------------------------------------------------------------===//
// DivUIOp
//===----------------------------------------------------------------------===//

/// Fold `(a * b) / b -> a`
/// Only valid when the multiply carries the matching no-overflow flags
/// (nuw for unsigned division, nsw for signed), otherwise the product may
/// have wrapped and the cancellation would be incorrect.
static Value foldDivMul(Value lhs, Value rhs,
                        arith::IntegerOverflowFlags ovfFlags) {
  auto mul = lhs.getDefiningOp<mlir::arith::MulIOp>();
  if (!mul || !bitEnumContainsAll(mul.getOverflowFlags(), ovfFlags))
    return {};

  if (mul.getLhs() == rhs)
    return mul.getRhs();

  if (mul.getRhs() == rhs)
    return mul.getLhs();

  return {};
}

OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
  // divui (x, 1) -> x.
  if (matchPattern(adaptor.getRhs(), m_One()))
    return getLhs();

  // (a * b) / b -> a
  if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw))
    return val;

  // Don't fold if it would require a division by zero.
  bool div0 = false;
  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
                                               [&](APInt a, const APInt &b) {
                                                 if (div0 || !b) {
                                                   div0 = true;
                                                   return a;
                                                 }
                                                 return a.udiv(b);
                                               });

  return div0 ? Attribute() : result;
}

/// Returns whether an unsigned division by `divisor` is speculatable.
static Speculation::Speculatability getDivUISpeculatability(Value divisor) {
  // X / 0 => UB
  // Speculatable only when the divisor's value range provably excludes zero.
  if (matchPattern(divisor, m_IntRangeWithoutZeroU()))
    return Speculation::Speculatable;

  return Speculation::NotSpeculatable;
}

Speculation::Speculatability arith::DivUIOp::getSpeculatability() {
  return getDivUISpeculatability(getRhs());
}

//===----------------------------------------------------------------------===//
// DivSIOp
//===----------------------------------------------------------------------===//

OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
  // divsi (x, 1) -> x.
  if (matchPattern(adaptor.getRhs(), m_One()))
    return getLhs();

  // (a * b) / b -> a
  if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw))
    return val;

  // Don't fold if it would overflow or if it requires a division by zero.
  bool overflowOrDiv0 = false;
  auto result = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](APInt a, const APInt &b) {
        if (overflowOrDiv0 || !b) {
          overflowOrDiv0 = true;
          return a;
        }
        return a.sdiv_ov(b, overflowOrDiv0);
      });

  return overflowOrDiv0 ? Attribute() : result;
}

/// Returns whether a signed division by `divisor` is speculatable. This
/// function conservatively assumes that all signed division by -1 are not
/// speculatable.
static Speculation::Speculatability getDivSISpeculatability(Value divisor) {
  // X / 0 => UB
  // INT_MIN / -1 => UB
  if (matchPattern(divisor, m_IntRangeWithoutZeroS()) &&
      matchPattern(divisor, m_IntRangeWithoutNegOneS()))
    return Speculation::Speculatable;

  return Speculation::NotSpeculatable;
}

Speculation::Speculatability arith::DivSIOp::getSpeculatability() {
  return getDivSISpeculatability(getRhs());
}

//===----------------------------------------------------------------------===//
// Ceil and floor division folding helpers
//===----------------------------------------------------------------------===//

/// Signed ceiling division for nonnegative `a` and positive `b`; sets
/// `overflow` if any intermediate step overflows.
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
                                    bool &overflow) {
  // Returns (a-1)/b + 1
  APInt one(a.getBitWidth(), 1, true); // Signed value 1.
  APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
  return val.sadd_ov(one, overflow);
}

//===----------------------------------------------------------------------===//
// CeilDivUIOp
//===----------------------------------------------------------------------===//

OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
  // ceildivui (x, 1) -> x.
  if (matchPattern(adaptor.getRhs(), m_One()))
    return getLhs();

  bool overflowOrDiv0 = false;
  auto result = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](APInt a, const APInt &b) {
        if (overflowOrDiv0 || !b) {
          overflowOrDiv0 = true;
          return a;
        }
        // Round the quotient up when the division has a remainder.
        APInt quotient = a.udiv(b);
        if (!a.urem(b))
          return quotient;
        APInt one(a.getBitWidth(), 1, true);
        return quotient.uadd_ov(one, overflowOrDiv0);
      });

  return overflowOrDiv0 ? Attribute() : result;
}

Speculation::Speculatability arith::CeilDivUIOp::getSpeculatability() {
  return getDivUISpeculatability(getRhs());
}

//===----------------------------------------------------------------------===//
// CeilDivSIOp
//===----------------------------------------------------------------------===//

OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
  // ceildivsi (x, 1) -> x.
  if (matchPattern(adaptor.getRhs(), m_One()))
    return getLhs();

  // Don't fold if it would overflow or if it requires a division by zero.
  // TODO: This hook won't fold operations where a = MININT, because
  // negating MININT overflows. This can be improved.
  bool overflowOrDiv0 = false;
  auto result = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](APInt a, const APInt &b) {
        if (overflowOrDiv0 || !b) {
          overflowOrDiv0 = true;
          return a;
        }
        if (!a)
          return a;
        // After this point we know that neither a or b are zero.
        unsigned bits = a.getBitWidth();
        APInt zero = APInt::getZero(bits);
        bool aGtZero = a.sgt(zero);
        bool bGtZero = b.sgt(zero);
        if (aGtZero && bGtZero) {
          // Both positive, return ceil(a, b).
          return signedCeilNonnegInputs(a, b, overflowOrDiv0);
        }

        // No folding happens if any of the intermediate arithmetic operations
        // overflows.
        bool overflowNegA = false;
        bool overflowNegB = false;
        bool overflowDiv = false;
        bool overflowNegRes = false;
        if (!aGtZero && !bGtZero) {
          // Both negative, return ceil(-a, -b).
          APInt posA = zero.ssub_ov(a, overflowNegA);
          APInt posB = zero.ssub_ov(b, overflowNegB);
          APInt res = signedCeilNonnegInputs(posA, posB, overflowDiv);
          overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
          return res;
        }
        if (!aGtZero && bGtZero) {
          // A is negative, b is positive, return - ( -a / b).
          APInt posA = zero.ssub_ov(a, overflowNegA);
          APInt div = posA.sdiv_ov(b, overflowDiv);
          APInt res = zero.ssub_ov(div, overflowNegRes);
          overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
          return res;
        }
        // A is positive, b is negative, return - (a / -b).
        APInt posB = zero.ssub_ov(b, overflowNegB);
        APInt div = a.sdiv_ov(posB, overflowDiv);
        APInt res = zero.ssub_ov(div, overflowNegRes);

        overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);
        return res;
      });

  return overflowOrDiv0 ? Attribute() : result;
}

Speculation::Speculatability arith::CeilDivSIOp::getSpeculatability() {
  return getDivSISpeculatability(getRhs());
}

//===----------------------------------------------------------------------===//
// FloorDivSIOp
//===----------------------------------------------------------------------===//

OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
  // floordivsi (x, 1) -> x.
  if (matchPattern(adaptor.getRhs(), m_One()))
    return getLhs();

  // Don't fold if it would overflow or if it requires a division by zero.
  bool overflowOrDiv = false;
  auto result = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](APInt a, const APInt &b) {
        if (b.isZero()) {
          overflowOrDiv = true;
          return a;
        }
        return a.sfloordiv_ov(b, overflowOrDiv);
      });

  return overflowOrDiv ? Attribute() : result;
}

//===----------------------------------------------------------------------===//
// RemUIOp
//===----------------------------------------------------------------------===//

OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
  // remui (x, 1) -> 0.
  if (matchPattern(adaptor.getRhs(), m_One()))
    return Builder(getContext()).getZeroAttr(getType());

  // Don't fold if it would require a division by zero.
  bool div0 = false;
  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
                                               [&](APInt a, const APInt &b) {
                                                 if (div0 || b.isZero()) {
                                                   div0 = true;
                                                   return a;
                                                 }
                                                 return a.urem(b);
                                               });

  return div0 ? Attribute() : result;
}

//===----------------------------------------------------------------------===//
// RemSIOp
//===----------------------------------------------------------------------===//

OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
  // remsi (x, 1) -> 0.
  if (matchPattern(adaptor.getRhs(), m_One()))
    return Builder(getContext()).getZeroAttr(getType());

  // Don't fold if it would require a division by zero.
  bool div0 = false;
  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
                                               [&](APInt a, const APInt &b) {
                                                 if (div0 || b.isZero()) {
                                                   div0 = true;
                                                   return a;
                                                 }
                                                 return a.srem(b);
                                               });

  return div0 ? Attribute() : result;
}

//===----------------------------------------------------------------------===//
// AndIOp
//===----------------------------------------------------------------------===//

/// Fold `and(a, and(a, b))` to `and(a, b)`
/// Checks both operand orders of the outer op and both operands of the inner
/// one, so all four commutations are covered.
static Value foldAndIofAndI(arith::AndIOp op) {
  for (bool reversePrev : {false, true}) {
    auto prev = (reversePrev ? op.getRhs() : op.getLhs())
                    .getDefiningOp<arith::AndIOp>();
    if (!prev)
      continue;

    Value other = (reversePrev ? op.getLhs() : op.getRhs());
    if (other != prev.getLhs() && other != prev.getRhs())
      continue;

    return prev.getResult();
  }
  return {};
}

OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
  /// and(x, 0) -> 0
  if (matchPattern(adaptor.getRhs(), m_Zero()))
    return getRhs();
  /// and(x, allOnes) -> x
  APInt intValue;
  if (matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue)) &&
      intValue.isAllOnes())
    return getLhs();
  /// and(x, not(x)) -> 0
  /// (not(x) is matched as xor(x, allOnes))
  if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
                                          m_ConstantInt(&intValue))) &&
      intValue.isAllOnes())
    return Builder(getContext()).getZeroAttr(getType());
  /// and(not(x), x) -> 0
  if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
                                          m_ConstantInt(&intValue))) &&
      intValue.isAllOnes())
    return Builder(getContext()).getZeroAttr(getType());

  /// and(a, and(a, b)) -> and(a, b)
  if (Value result = foldAndIofAndI(*this))
    return result;

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](APInt a, const APInt &b) { return std::move(a) & b; });
}

//===----------------------------------------------------------------------===//
// OrIOp
//===----------------------------------------------------------------------===//

OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
  if (APInt rhsVal; matchPattern(adaptor.getRhs(), m_ConstantInt(&rhsVal))) {
    /// or(x, 0) -> x
    if (rhsVal.isZero())
      return getLhs();
    /// or(x, <all ones>) -> <all ones>
    if (rhsVal.isAllOnes())
      return adaptor.getRhs();
  }

  APInt intValue;
  /// or(x, xor(x, 1)) -> 1
  if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
                                          m_ConstantInt(&intValue))) &&
      intValue.isAllOnes())
    return getRhs().getDefiningOp<XOrIOp>().getRhs();
  /// or(xor(x, 1), x) -> 1
  if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
                                          m_ConstantInt(&intValue))) &&
      intValue.isAllOnes())
    return getLhs().getDefiningOp<XOrIOp>().getRhs();

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](APInt a, const APInt &b) { return std::move(a) | b; });
}

//===----------------------------------------------------------------------===//
// XOrIOp
//===----------------------------------------------------------------------===//

OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
  /// xor(x, 0) -> x
  if (matchPattern(adaptor.getRhs(), m_Zero()))
    return getLhs();
  /// xor(x, x) -> 0
  if (getLhs() == getRhs())
    return Builder(getContext()).getZeroAttr(getType());
  /// xor(xor(x, a), a) -> x
  /// xor(xor(a, x), a) -> x
  if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
    if (prev.getRhs() == getRhs())
      return prev.getLhs();
    if (prev.getLhs() == getRhs())
      return prev.getRhs();
  }
  /// xor(a, xor(x, a)) -> x
  /// xor(a, xor(a, x)) -> x
  if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
    if (prev.getRhs() == getLhs())
      return prev.getLhs();
    if (prev.getLhs() == getLhs())
      return prev.getRhs();
  }

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](APInt a, const APInt &b) { return std::move(a) ^ b; });
}

void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
}

//===----------------------------------------------------------------------===//
// NegFOp
//===----------------------------------------------------------------------===//

OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
  /// negf(negf(x)) -> x
  if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
    return op.getOperand();
  return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
                                     [](const APFloat &a) { return -a; });
}

//===----------------------------------------------------------------------===//
// AddFOp
//===----------------------------------------------------------------------===//

OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
  // addf(x, -0) -> x
  // (Adding +0 is not an identity: -0 + +0 == +0.)
  if (matchPattern(adaptor.getRhs(), m_NegZeroFloat()))
    return getLhs();

  return constFoldBinaryOp<FloatAttr>(
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return a + b; });
}

//===----------------------------------------------------------------------===//
// SubFOp
//===----------------------------------------------------------------------===//

OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
  // subf(x, +0) -> x
  // (Subtracting -0 is not an identity: -0 - -0 == +0.)
  if (matchPattern(adaptor.getRhs(), m_PosZeroFloat()))
    return getLhs();

  return constFoldBinaryOp<FloatAttr>(
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return a - b; });
}

//===----------------------------------------------------------------------===//
// MaximumFOp
//===----------------------------------------------------------------------===//

OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
  // maximumf(x,x) -> x
  if (getLhs() == getRhs())
    return getRhs();

  // maximumf(x, -inf) -> x
  if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
    return getLhs();

  return constFoldBinaryOp<FloatAttr>(
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
}

//===----------------------------------------------------------------------===//
// MaxNumFOp
//===----------------------------------------------------------------------===//

OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
// maxnumf(x,x) -> x 1042 if (getLhs() == getRhs()) 1043 return getRhs(); 1044 1045 // maxnumf(x, NaN) -> x 1046 if (matchPattern(adaptor.getRhs(), m_NaNFloat())) 1047 return getLhs(); 1048 1049 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), llvm::maxnum); 1050 } 1051 1052 //===----------------------------------------------------------------------===// 1053 // MaxSIOp 1054 //===----------------------------------------------------------------------===// 1055 1056 OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) { 1057 // maxsi(x,x) -> x 1058 if (getLhs() == getRhs()) 1059 return getRhs(); 1060 1061 if (APInt intValue; 1062 matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) { 1063 // maxsi(x,MAX_INT) -> MAX_INT 1064 if (intValue.isMaxSignedValue()) 1065 return getRhs(); 1066 // maxsi(x, MIN_INT) -> x 1067 if (intValue.isMinSignedValue()) 1068 return getLhs(); 1069 } 1070 1071 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), 1072 [](const APInt &a, const APInt &b) { 1073 return llvm::APIntOps::smax(a, b); 1074 }); 1075 } 1076 1077 //===----------------------------------------------------------------------===// 1078 // MaxUIOp 1079 //===----------------------------------------------------------------------===// 1080 1081 OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) { 1082 // maxui(x,x) -> x 1083 if (getLhs() == getRhs()) 1084 return getRhs(); 1085 1086 if (APInt intValue; 1087 matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) { 1088 // maxui(x,MAX_INT) -> MAX_INT 1089 if (intValue.isMaxValue()) 1090 return getRhs(); 1091 // maxui(x, MIN_INT) -> x 1092 if (intValue.isMinValue()) 1093 return getLhs(); 1094 } 1095 1096 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), 1097 [](const APInt &a, const APInt &b) { 1098 return llvm::APIntOps::umax(a, b); 1099 }); 1100 } 1101 1102 //===----------------------------------------------------------------------===// 1103 // MinimumFOp 1104 
//===----------------------------------------------------------------------===// 1105 1106 OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) { 1107 // minimumf(x,x) -> x 1108 if (getLhs() == getRhs()) 1109 return getRhs(); 1110 1111 // minimumf(x, +inf) -> x 1112 if (matchPattern(adaptor.getRhs(), m_PosInfFloat())) 1113 return getLhs(); 1114 1115 return constFoldBinaryOp<FloatAttr>( 1116 adaptor.getOperands(), 1117 [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); }); 1118 } 1119 1120 //===----------------------------------------------------------------------===// 1121 // MinNumFOp 1122 //===----------------------------------------------------------------------===// 1123 1124 OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) { 1125 // minnumf(x,x) -> x 1126 if (getLhs() == getRhs()) 1127 return getRhs(); 1128 1129 // minnumf(x, NaN) -> x 1130 if (matchPattern(adaptor.getRhs(), m_NaNFloat())) 1131 return getLhs(); 1132 1133 return constFoldBinaryOp<FloatAttr>( 1134 adaptor.getOperands(), 1135 [](const APFloat &a, const APFloat &b) { return llvm::minnum(a, b); }); 1136 } 1137 1138 //===----------------------------------------------------------------------===// 1139 // MinSIOp 1140 //===----------------------------------------------------------------------===// 1141 1142 OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) { 1143 // minsi(x,x) -> x 1144 if (getLhs() == getRhs()) 1145 return getRhs(); 1146 1147 if (APInt intValue; 1148 matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) { 1149 // minsi(x,MIN_INT) -> MIN_INT 1150 if (intValue.isMinSignedValue()) 1151 return getRhs(); 1152 // minsi(x, MAX_INT) -> x 1153 if (intValue.isMaxSignedValue()) 1154 return getLhs(); 1155 } 1156 1157 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), 1158 [](const APInt &a, const APInt &b) { 1159 return llvm::APIntOps::smin(a, b); 1160 }); 1161 } 1162 1163 //===----------------------------------------------------------------------===// 
// MinUIOp
//===----------------------------------------------------------------------===//

OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
  // minui(x,x) -> x
  if (getLhs() == getRhs())
    return getRhs();

  if (APInt intValue;
      matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
    // minui(x,MIN_INT) -> MIN_INT
    if (intValue.isMinValue())
      return getRhs();
    // minui(x, MAX_INT) -> x
    if (intValue.isMaxValue())
      return getLhs();
  }

  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
                                        [](const APInt &a, const APInt &b) {
                                          return llvm::APIntOps::umin(a, b);
                                        });
}

//===----------------------------------------------------------------------===//
// MulFOp
//===----------------------------------------------------------------------===//

OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
  // mulf(x, 1) -> x
  // Only the exact multiplicative identity is folded; mulf(x, 0) is not an
  // identity for IEEE floats (NaN/inf operands).
  if (matchPattern(adaptor.getRhs(), m_OneFloat()))
    return getLhs();

  return constFoldBinaryOp<FloatAttr>(
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return a * b; });
}

void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<MulFOfNegF>(context);
}

//===----------------------------------------------------------------------===//
// DivFOp
//===----------------------------------------------------------------------===//

OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
  // divf(x, 1) -> x
  if (matchPattern(adaptor.getRhs(), m_OneFloat()))
    return getLhs();

  return constFoldBinaryOp<FloatAttr>(
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return a / b; });
}

void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<DivFOfNegF>(context);
}

//===----------------------------------------------------------------------===//
// RemFOp
//===----------------------------------------------------------------------===//

OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
                                      [](const APFloat &a, const APFloat &b) {
                                        APFloat result(a);
                                        // APFloat::mod() offers the remainder
                                        // behavior we want, i.e. the result has
                                        // the sign of LHS operand.
                                        (void)result.mod(b);
                                        return result;
                                      });
}

//===----------------------------------------------------------------------===//
// Utility functions for verifying cast ops
//===----------------------------------------------------------------------===//

/// A zero-cost tag used to pass a pack of types by value: callers write
/// `type_list<Ts...>()`, which is just a null `std::tuple<Ts...> *` that is
/// never dereferenced.
template <typename... Types>
using type_list = std::tuple<Types...> *;

/// Returns a non-null type only if the provided type is one of the allowed
/// types or one of the allowed shaped types of the allowed types. Returns the
/// element type if a valid shaped type is provided.
template <typename... ShapedTypes, typename... ElementTypes>
static Type getUnderlyingType(Type type, type_list<ShapedTypes...>,
                              type_list<ElementTypes...>) {
  // Reject shaped types that are not in the allowed ShapedTypes pack.
  if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
    return {};

  // For shaped types this extracts the element type; scalars pass through.
  auto underlyingType = getElementTypeOrSelf(type);
  if (!llvm::isa<ElementTypes...>(underlyingType))
    return {};

  return underlyingType;
}

/// Get allowed underlying types for vectors and tensors.
template <typename... ElementTypes>
static Type getTypeIfLike(Type type) {
  return getUnderlyingType(type, type_list<VectorType, TensorType>(),
                           type_list<ElementTypes...>());
}

/// Get allowed underlying types for vectors, tensors, and memrefs.
template <typename... ElementTypes>
static Type getTypeIfLikeOrMemRef(Type type) {
  return getUnderlyingType(type,
                           type_list<VectorType, TensorType, MemRefType>(),
                           type_list<ElementTypes...>());
}

/// Return false if both types are ranked tensor with mismatching encoding.
/// Any other combination (unranked, non-tensor) is considered compatible.
static bool hasSameEncoding(Type typeA, Type typeB) {
  auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
  auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
  if (!rankedTensorA || !rankedTensorB)
    return true;
  return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
}

/// Common precondition for all arith cast ops: exactly one input, exactly one
/// output, matching tensor encodings, and compatible shapes.
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  if (!hasSameEncoding(inputs.front(), outputs.front()))
    return false;
  return succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
}

//===----------------------------------------------------------------------===//
// Verifiers for integer and floating point extension/truncation ops
//===----------------------------------------------------------------------===//

// Extend ops can only extend to a wider type.
// Emits a diagnostic (and fails) unless the result element type is strictly
// wider than the operand element type.
template <typename ValType, typename Op>
static LogicalResult verifyExtOp(Op op) {
  Type srcType = getElementTypeOrSelf(op.getIn().getType());
  Type dstType = getElementTypeOrSelf(op.getType());

  if (llvm::cast<ValType>(srcType).getWidth() >=
      llvm::cast<ValType>(dstType).getWidth())
    return op.emitError("result type ")
           << dstType << " must be wider than operand type " << srcType;

  return success();
}

// Truncate ops can only truncate to a shorter type.
1316 template <typename ValType, typename Op> 1317 static LogicalResult verifyTruncateOp(Op op) { 1318 Type srcType = getElementTypeOrSelf(op.getIn().getType()); 1319 Type dstType = getElementTypeOrSelf(op.getType()); 1320 1321 if (llvm::cast<ValType>(srcType).getWidth() <= 1322 llvm::cast<ValType>(dstType).getWidth()) 1323 return op.emitError("result type ") 1324 << dstType << " must be shorter than operand type " << srcType; 1325 1326 return success(); 1327 } 1328 1329 /// Validate a cast that changes the width of a type. 1330 template <template <typename> class WidthComparator, typename... ElementTypes> 1331 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) { 1332 if (!areValidCastInputsAndOutputs(inputs, outputs)) 1333 return false; 1334 1335 auto srcType = getTypeIfLike<ElementTypes...>(inputs.front()); 1336 auto dstType = getTypeIfLike<ElementTypes...>(outputs.front()); 1337 if (!srcType || !dstType) 1338 return false; 1339 1340 return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(), 1341 srcType.getIntOrFloatBitWidth()); 1342 } 1343 1344 /// Attempts to convert `sourceValue` to an APFloat value with 1345 /// `targetSemantics` and `roundingMode`, without any information loss. 
static FailureOr<APFloat> convertFloatValue(
    APFloat sourceValue, const llvm::fltSemantics &targetSemantics,
    llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
  bool losesInfo = false;
  auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
  // Fail on any information loss or non-OK conversion status.
  if (losesInfo || status != APFloat::opOK)
    return failure();

  return sourceValue;
}

//===----------------------------------------------------------------------===//
// ExtUIOp
//===----------------------------------------------------------------------===//

OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
  // extui(extui(x)) -> extui(x): rebind this op's operand to the inner
  // source in place and return our own result to signal the update.
  if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
    getInMutable().assign(lhs.getIn());
    return getResult();
  }

  // Constant-fold by zero-extending to the result element width.
  Type resType = getElementTypeOrSelf(getType());
  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
  return constFoldCastOp<IntegerAttr, IntegerAttr>(
      adaptor.getOperands(), getType(),
      [bitWidth](const APInt &a, bool &castStatus) {
        // castStatus is never set to false: integer extension cannot fail.
        return a.zext(bitWidth);
      });
}

bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
}

LogicalResult arith::ExtUIOp::verify() {
  return verifyExtOp<IntegerType>(*this);
}

//===----------------------------------------------------------------------===//
// ExtSIOp
//===----------------------------------------------------------------------===//

OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
  // extsi(extsi(x)) -> extsi(x): in-place operand update, as in ExtUIOp.
  if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
    getInMutable().assign(lhs.getIn());
    return getResult();
  }

  // Constant-fold by sign-extending to the result element width.
  Type resType = getElementTypeOrSelf(getType());
  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
  return constFoldCastOp<IntegerAttr, IntegerAttr>(
      adaptor.getOperands(), getType(),
      [bitWidth](const APInt &a, bool &castStatus) {
        // castStatus is never set to false: integer extension cannot fail.
        return a.sext(bitWidth);
      });
}

bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
}

void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                 MLIRContext *context) {
  patterns.add<ExtSIOfExtUI>(context);
}

LogicalResult arith::ExtSIOp::verify() {
  return verifyExtOp<IntegerType>(*this);
}

//===----------------------------------------------------------------------===//
// ExtFOp
//===----------------------------------------------------------------------===//

/// Fold extension of float constants when there is no information loss due the
/// difference in fp semantics.
OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
  // extf(truncf(x)) -> x, but only when the types round-trip and BOTH ops
  // carry the 'contract' fastmath flag (the truncation may otherwise be
  // value-changing).
  if (auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
    if (truncFOp.getOperand().getType() == getType()) {
      arith::FastMathFlags truncFMF =
          truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
      bool isTruncContract =
          bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract);
      arith::FastMathFlags extFMF =
          getFastmath().value_or(arith::FastMathFlags::none);
      bool isExtContract =
          bitEnumContainsAll(extFMF, arith::FastMathFlags::contract);
      if (isTruncContract && isExtContract) {
        return truncFOp.getOperand();
      }
    }
  }

  // Constant-fold only when the conversion is exact (see convertFloatValue).
  auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
  return constFoldCastOp<FloatAttr, FloatAttr>(
      adaptor.getOperands(), getType(),
      [&targetSemantics](const APFloat &a, bool &castStatus) {
        FailureOr<APFloat> result = convertFloatValue(a, targetSemantics);
        if (failed(result)) {
          castStatus = false;
          return a;
        }
        return *result;
      });
}

bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
}

LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }

//===----------------------------------------------------------------------===//
// TruncIOp
//===----------------------------------------------------------------------===//

OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
  if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
      matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) {
    Value src = getOperand().getDefiningOp()->getOperand(0);
    Type srcType = getElementTypeOrSelf(src.getType());
    Type dstType = getElementTypeOrSelf(getType());
    // trunci(zexti(a)) -> trunci(a)
    // trunci(sexti(a)) -> trunci(a)
    // Only valid when we still truncate (src wider than dst); otherwise the
    // extension's bits would matter.
    if (llvm::cast<IntegerType>(srcType).getWidth() >
        llvm::cast<IntegerType>(dstType).getWidth()) {
      setOperand(src);
      return getResult();
    }

    // trunci(zexti(a)) -> a
    // trunci(sexti(a)) -> a
    if (srcType == dstType)
      return src;
  }

  // trunci(trunci(a)) -> trunci(a)
  if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
    setOperand(getOperand().getDefiningOp()->getOperand(0));
    return getResult();
  }

  // Constant-fold by truncating to the result element width.
  Type resType = getElementTypeOrSelf(getType());
  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
  return constFoldCastOp<IntegerAttr, IntegerAttr>(
      adaptor.getOperands(), getType(),
      [bitWidth](const APInt &a, bool &castStatus) {
        return a.trunc(bitWidth);
      });
}

bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
}

void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI,
               TruncIShrUIMulIToMulSIExtended, TruncIShrUIMulIToMulUIExtended>(
      context);
}

LogicalResult arith::TruncIOp::verify() {
  return verifyTruncateOp<IntegerType>(*this);
}

//===----------------------------------------------------------------------===//
// TruncFOp
//===----------------------------------------------------------------------===//

/// Perform safe const propagation for truncf, i.e., only propagate if FP value
/// can be represented without precision loss.
OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
  auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
  return constFoldCastOp<FloatAttr, FloatAttr>(
      adaptor.getOperands(), getType(),
      [this, &targetSemantics](const APFloat &a, bool &castStatus) {
        // Honor the op's optional rounding-mode attribute.
        RoundingMode roundingMode =
            getRoundingmode().value_or(RoundingMode::to_nearest_even);
        llvm::RoundingMode llvmRoundingMode =
            convertArithRoundingModeToLLVMIR(roundingMode);
        FailureOr<APFloat> result =
            convertFloatValue(a, targetSemantics, llvmRoundingMode);
        if (failed(result)) {
          castStatus = false;
          return a;
        }
        return *result;
      });
}

bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
}

LogicalResult arith::TruncFOp::verify() {
  return verifyTruncateOp<FloatType>(*this);
}

//===----------------------------------------------------------------------===//
// AndIOp
//===----------------------------------------------------------------------===//

void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<AndOfExtUI, AndOfExtSI>(context);
}

//===----------------------------------------------------------------------===//
// OrIOp
//===----------------------------------------------------------------------===//

void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
  patterns.add<OrOfExtUI, OrOfExtSI>(context);
}

//===----------------------------------------------------------------------===//
// Verifiers for casts between integers and floats.
//===----------------------------------------------------------------------===//

/// True when the (single) input is From-like and the (single) output is
/// To-like, per getTypeIfLike.
template <typename From, typename To>
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
  if (!areValidCastInputsAndOutputs(inputs, outputs))
    return false;

  auto srcType = getTypeIfLike<From>(inputs.front());
  // NOTE(review): outputs.back() is equivalent to outputs.front() here since
  // areValidCastInputsAndOutputs guarantees outputs.size() == 1, but front()
  // would match the sibling helpers — consider unifying.
  auto dstType = getTypeIfLike<To>(outputs.back());

  return srcType && dstType;
}

//===----------------------------------------------------------------------===//
// UIToFPOp
//===----------------------------------------------------------------------===//

bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
}

OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
  Type resEleType = getElementTypeOrSelf(getType());
  return constFoldCastOp<IntegerAttr, FloatAttr>(
      adaptor.getOperands(), getType(),
      [&resEleType](const APInt &a, bool &castStatus) {
        FloatType floatTy = llvm::cast<FloatType>(resEleType);
        // Start from zero in the target semantics, then convert the unsigned
        // integer with round-to-nearest-even.
        APFloat apf(floatTy.getFloatSemantics(),
                    APInt::getZero(floatTy.getWidth()));
        apf.convertFromAPInt(a, /*IsSigned=*/false,
                             APFloat::rmNearestTiesToEven);
        return apf;
      });
}

//===----------------------------------------------------------------------===//
// SIToFPOp
//===----------------------------------------------------------------------===//

bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
}

OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
  Type resEleType = getElementTypeOrSelf(getType());
  return constFoldCastOp<IntegerAttr, FloatAttr>(
      adaptor.getOperands(), getType(),
      [&resEleType](const APInt &a, bool &castStatus) {
        FloatType floatTy = llvm::cast<FloatType>(resEleType);
        // Same as UIToFPOp above, but the integer is interpreted as signed.
        APFloat apf(floatTy.getFloatSemantics(),
                    APInt::getZero(floatTy.getWidth()));
        apf.convertFromAPInt(a, /*IsSigned=*/true,
                             APFloat::rmNearestTiesToEven);
        return apf;
      });
}

//===----------------------------------------------------------------------===//
// FPToUIOp
//===----------------------------------------------------------------------===//

bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
}

OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
  Type resType = getElementTypeOrSelf(getType());
  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
  return constFoldCastOp<FloatAttr, IntegerAttr>(
      adaptor.getOperands(), getType(),
      [&bitWidth](const APFloat &a, bool &castStatus) {
        bool ignored;
        APSInt api(bitWidth, /*isUnsigned=*/true);
        // opInvalidOp means the value is out of range (or NaN); abort the
        // fold in that case rather than produce a bogus constant.
        castStatus = APFloat::opInvalidOp !=
                     a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
        return api;
      });
}

//===----------------------------------------------------------------------===//
// FPToSIOp
//===----------------------------------------------------------------------===//

bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
}

OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
  Type resType = getElementTypeOrSelf(getType());
  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
  return constFoldCastOp<FloatAttr, IntegerAttr>(
      adaptor.getOperands(), getType(),
      [&bitWidth](const APFloat &a, bool &castStatus) {
        bool ignored;
        APSInt api(bitWidth, /*isUnsigned=*/false);
        // Mirrors FPToUIOp: abort the fold on out-of-range/NaN input.
        castStatus = APFloat::opInvalidOp !=
                     a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
        return api;
      });
}

//===----------------------------------------------------------------------===//
// IndexCastOp
//===----------------------------------------------------------------------===//

/// Shared compatibility check for index_cast and index_castui: exactly one
/// side must be `index` and the other a signless integer.
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (!areValidCastInputsAndOutputs(inputs, outputs))
    return false;

  auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
  auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
  if (!srcType || !dstType)
    return false;

  return (srcType.isIndex() && dstType.isSignlessInteger()) ||
         (srcType.isSignlessInteger() && dstType.isIndex());
}

bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
                                           TypeRange outputs) {
  return areIndexCastCompatible(inputs, outputs);
}

OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
  // index_cast(constant) -> constant
  unsigned resultBitwidth = 64; // Default for index integer attributes.
  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
    resultBitwidth = intTy.getWidth();

  return constFoldCastOp<IntegerAttr, IntegerAttr>(
      adaptor.getOperands(), getType(),
      [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
        // index_cast is the sign-extending variant (cf. index_castui below).
        return a.sextOrTrunc(resultBitwidth);
      });
}

void arith::IndexCastOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
}

//===----------------------------------------------------------------------===//
// IndexCastUIOp
//===----------------------------------------------------------------------===//

bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs,
                                             TypeRange outputs) {
  return areIndexCastCompatible(inputs, outputs);
}

OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
  // index_castui(constant) -> constant
  unsigned resultBitwidth = 64; // Default for index integer attributes.
  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
    resultBitwidth = intTy.getWidth();

  return constFoldCastOp<IntegerAttr, IntegerAttr>(
      adaptor.getOperands(), getType(),
      [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
        // Zero-extending variant of index_cast above.
        return a.zextOrTrunc(resultBitwidth);
      });
}

void arith::IndexCastUIOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
}

//===----------------------------------------------------------------------===//
// BitcastOp
//===----------------------------------------------------------------------===//

bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (!areValidCastInputsAndOutputs(inputs, outputs))
    return false;

  auto srcType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(inputs.front());
  auto dstType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(outputs.front());
  if (!srcType || !dstType)
    return false;

  // A bitcast never changes the bit width.
  return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
}

OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
  auto resType = getType();
  auto operand = adaptor.getIn();
  // Non-constant input: nothing to fold.
  if (!operand)
    return {};

  /// Bitcast dense elements.
  if (auto denseAttr = llvm::dyn_cast_or_null<DenseElementsAttr>(operand))
    return denseAttr.bitcast(llvm::cast<ShapedType>(resType).getElementType());
  /// Other shaped types unhandled.
  if (llvm::isa<ShapedType>(resType))
    return {};

  /// Bitcast integer or float to integer or float: reinterpret the scalar
  /// constant's raw bits in the result type.
  APInt bits = llvm::isa<FloatAttr>(operand)
                   ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
                   : llvm::cast<IntegerAttr>(operand).getValue();
  assert(resType.getIntOrFloatBitWidth() == bits.getBitWidth() &&
         "trying to fold on broken IR: operands have incompatible types");

  if (auto resFloatType = llvm::dyn_cast<FloatType>(resType))
    return FloatAttr::get(resType,
                          APFloat(resFloatType.getFloatSemantics(), bits));
  return IntegerAttr::get(resType, bits);
}

void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                   MLIRContext *context) {
  patterns.add<BitcastOfBitcast>(context);
}

//===----------------------------------------------------------------------===//
// CmpIOp
//===----------------------------------------------------------------------===//

/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
/// comparison predicates.
bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
                                    const APInt &lhs, const APInt &rhs) {
  switch (predicate) {
  case arith::CmpIPredicate::eq:
    return lhs.eq(rhs);
  case arith::CmpIPredicate::ne:
    return lhs.ne(rhs);
  case arith::CmpIPredicate::slt:
    return lhs.slt(rhs);
  case arith::CmpIPredicate::sle:
    return lhs.sle(rhs);
  case arith::CmpIPredicate::sgt:
    return lhs.sgt(rhs);
  case arith::CmpIPredicate::sge:
    return lhs.sge(rhs);
  case arith::CmpIPredicate::ult:
    return lhs.ult(rhs);
  case arith::CmpIPredicate::ule:
    return lhs.ule(rhs);
  case arith::CmpIPredicate::ugt:
    return lhs.ugt(rhs);
  case arith::CmpIPredicate::uge:
    return lhs.uge(rhs);
  }
  llvm_unreachable("unknown cmpi predicate kind");
}

/// Returns true if the predicate is true for two equal operands.
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
  switch (predicate) {
  // Reflexive predicates: x <pred> x is always true.
  case arith::CmpIPredicate::eq:
  case arith::CmpIPredicate::sle:
  case arith::CmpIPredicate::sge:
  case arith::CmpIPredicate::ule:
  case arith::CmpIPredicate::uge:
    return true;
  // Irreflexive predicates: x <pred> x is always false.
  case arith::CmpIPredicate::ne:
  case arith::CmpIPredicate::slt:
  case arith::CmpIPredicate::sgt:
  case arith::CmpIPredicate::ult:
  case arith::CmpIPredicate::ugt:
    return false;
  }
  llvm_unreachable("unknown cmpi predicate kind");
}

/// Returns the width of `t` if it is an integer type or a vector of integers,
/// std::nullopt otherwise.
static std::optional<int64_t> getIntegerWidth(Type t) {
  if (auto intType = llvm::dyn_cast<IntegerType>(t)) {
    return intType.getWidth();
  }
  if (auto vectorIntType = llvm::dyn_cast<VectorType>(t)) {
    return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
  }
  return std::nullopt;
}

/// Folds cmpi in several stages: identical operands, comparisons against
/// zero/one that reduce to the (possibly extended) i1 operand, canonical
/// constant placement on the right, and finally full constant folding.
OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
  // cmpi(pred, x, x)
  if (getLhs() == getRhs()) {
    auto val = applyCmpPredicateToEqualOperands(getPredicate());
    return getBoolAttribute(getType(), val);
  }

  if (matchPattern(adaptor.getRhs(), m_Zero())) {
    if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
      // extsi(%x : i1 -> iN) != 0  ->  %x
      std::optional<int64_t> integerWidth =
          getIntegerWidth(extOp.getOperand().getType());
      if (integerWidth && integerWidth.value() == 1 &&
          getPredicate() == arith::CmpIPredicate::ne)
        return extOp.getOperand();
    }
    if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
      // extui(%x : i1 -> iN) != 0  ->  %x
      std::optional<int64_t> integerWidth =
          getIntegerWidth(extOp.getOperand().getType());
      if (integerWidth && integerWidth.value() == 1 &&
          getPredicate() == arith::CmpIPredicate::ne)
        return extOp.getOperand();
    }

    // arith.cmpi ne, %val, %zero : i1 -> %val
    if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
        getPredicate() == arith::CmpIPredicate::ne)
      return getLhs();
  }

  if (matchPattern(adaptor.getRhs(), m_One())) {
    // arith.cmpi eq, %val, %one : i1 -> %val
    if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
        getPredicate() == arith::CmpIPredicate::eq)
      return getLhs();
  }

  // Move constant to the right side.
  if (adaptor.getLhs() && !adaptor.getRhs()) {
    // Do not use invertPredicate, as it will change eq to ne and vice versa.
    // Swapping operands requires mirroring the ordering predicates instead.
    using Pred = CmpIPredicate;
    const std::pair<Pred, Pred> invPreds[] = {
        {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
        {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
        {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
        {Pred::ne, Pred::ne},
    };
    Pred origPred = getPredicate();
    for (auto pred : invPreds) {
      if (origPred == pred.first) {
        // In-place fold: mutate this op (predicate + swapped operands) and
        // return its own result to signal the change.
        setPredicate(pred.second);
        Value lhs = getLhs();
        Value rhs = getRhs();
        getLhsMutable().assign(rhs);
        getRhsMutable().assign(lhs);
        return getResult();
      }
    }
    llvm_unreachable("unknown cmpi predicate kind");
  }

  // We are moving constants to the right side; So if lhs is constant rhs is
  // guaranteed to be a constant.
  if (auto lhs = llvm::dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
    return constFoldBinaryOp<IntegerAttr>(
        adaptor.getOperands(), getI1SameShape(lhs.getType()),
        [pred = getPredicate()](const APInt &lhs, const APInt &rhs) {
          return APInt(1,
                       static_cast<int64_t>(applyCmpPredicate(pred, lhs, rhs)));
        });
  }

  return {};
}

void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.insert<CmpIExtSI, CmpIExtUI>(context);
}

//===----------------------------------------------------------------------===//
// CmpFOp
//===----------------------------------------------------------------------===//

/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
/// comparison predicates. Ordered predicates (O*) are false on unordered
/// (NaN) results; unordered predicates (U*) are true on them.
bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
                                    const APFloat &lhs, const APFloat &rhs) {
  auto cmpResult = lhs.compare(rhs);
  switch (predicate) {
  case arith::CmpFPredicate::AlwaysFalse:
    return false;
  case arith::CmpFPredicate::OEQ:
    return cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::OGT:
    return cmpResult == APFloat::cmpGreaterThan;
  case arith::CmpFPredicate::OGE:
    return cmpResult == APFloat::cmpGreaterThan ||
           cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::OLT:
    return cmpResult == APFloat::cmpLessThan;
  case arith::CmpFPredicate::OLE:
    return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::ONE:
    return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
  case arith::CmpFPredicate::ORD:
    return cmpResult != APFloat::cmpUnordered;
  case arith::CmpFPredicate::UEQ:
    return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::UGT:
    return cmpResult == APFloat::cmpUnordered ||
           cmpResult == APFloat::cmpGreaterThan;
  case arith::CmpFPredicate::UGE:
    return cmpResult == APFloat::cmpUnordered ||
           cmpResult == APFloat::cmpGreaterThan ||
           cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::ULT:
    return cmpResult == APFloat::cmpUnordered ||
           cmpResult == APFloat::cmpLessThan;
  case arith::CmpFPredicate::ULE:
    return cmpResult == APFloat::cmpUnordered ||
           cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::UNE:
    return cmpResult != APFloat::cmpEqual;
  case arith::CmpFPredicate::UNO:
    return cmpResult == APFloat::cmpUnordered;
  case arith::CmpFPredicate::AlwaysTrue:
    return true;
  }
  llvm_unreachable("unknown cmpf predicate kind");
}

/// Folds cmpf when both operands are float constants.
OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
  auto lhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
  auto rhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getRhs());

  // If one operand is NaN, making them both NaN does not change the result.
  // This lets a single constant NaN operand fold even when the other operand
  // is non-constant.
  if (lhs && lhs.getValue().isNaN())
    rhs = lhs;
  if (rhs && rhs.getValue().isNaN())
    lhs = rhs;

  if (!lhs || !rhs)
    return {};

  auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
  return BoolAttr::get(getContext(), val);
}

/// Rewrites `cmpf(sitofp/uitofp(%int), %fpConst)` into an integer comparison
/// `cmpi(%int, %intConst)` when the conversion and the constant allow it
/// without changing the result (range checks, fractional-constant predicate
/// adjustments, and NaN handling below).
class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
public:
  using OpRewritePattern<CmpFOp>::OpRewritePattern;

  /// Maps a float predicate to the integer predicate with the same meaning,
  /// choosing the signed/unsigned flavor from the original conversion.
  static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
                                                 bool isUnsigned) {
    using namespace arith;
    switch (pred) {
    case CmpFPredicate::UEQ:
    case CmpFPredicate::OEQ:
      return CmpIPredicate::eq;
    case CmpFPredicate::UGT:
    case CmpFPredicate::OGT:
      return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
    case CmpFPredicate::UGE:
    case CmpFPredicate::OGE:
      return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
    case CmpFPredicate::ULT:
    case CmpFPredicate::OLT:
      return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
    case CmpFPredicate::ULE:
    case CmpFPredicate::OLE:
      return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
    case CmpFPredicate::UNE:
    case CmpFPredicate::ONE:
      return CmpIPredicate::ne;
    default:
      llvm_unreachable("Unexpected predicate!");
    }
  }

  LogicalResult matchAndRewrite(CmpFOp op,
                                PatternRewriter &rewriter) const override {
    FloatAttr flt;
    if (!matchPattern(op.getRhs(), m_Constant(&flt)))
      return failure();

    const APFloat &rhs = flt.getValue();

    // Don't attempt to fold a nan.
    if (rhs.isNaN())
      return failure();

    // Get the width of the mantissa. We don't want to hack on conversions that
    // might lose information from the integer, e.g. "i64 -> float"
    FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
    int mantissaWidth = floatTy.getFPMantissaWidth();
    if (mantissaWidth <= 0)
      return failure();

    bool isUnsigned;
    Value intVal;

    // The lhs must come from an int-to-float conversion; remember which
    // signedness so predicates can be translated accordingly.
    if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
      isUnsigned = false;
      intVal = si.getIn();
    } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
      isUnsigned = true;
      intVal = ui.getIn();
    } else {
      return failure();
    }

    // Check to see that the input is converted from an integer type that is
    // small enough that preserves all bits.
    auto intTy = llvm::cast<IntegerType>(intVal.getType());
    auto intWidth = intTy.getWidth();

    // Number of bits representing values, as opposed to the sign
    auto valueBits = isUnsigned ? intWidth : (intWidth - 1);

    // Following test does NOT adjust intWidth downwards for signed inputs,
    // because the most negative value still requires all the mantissa bits
    // to distinguish it from one less than that value.
    if ((int)intWidth > mantissaWidth) {
      // Conversion would lose accuracy. Check if loss can impact comparison.
      int exponent = ilogb(rhs);
      if (exponent == APFloat::IEK_Inf) {
        int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
        if (maxExponent < (int)valueBits) {
          // Conversion could create infinity.
          return failure();
        }
      } else {
        // Note that if rhs is zero or NaN, then Exp is negative
        // and first condition is trivially false.
        if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
          // Conversion could affect comparison.
          return failure();
        }
      }
    }

    // Convert to equivalent cmpi predicate
    CmpIPredicate pred;
    switch (op.getPredicate()) {
    case CmpFPredicate::ORD:
      // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
      rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
                                                 /*width=*/1);
      return success();
    case CmpFPredicate::UNO:
      // Int to fp conversion doesn't create a nan (uno checks either is a nan)
      rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
                                                 /*width=*/1);
      return success();
    default:
      pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
      break;
    }

    if (!isUnsigned) {
      // If the rhs value is > SignedMax, fold the comparison. This handles
      // +INF and large values.
      APFloat signedMax(rhs.getSemantics());
      signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
                                 APFloat::rmNearestTiesToEven);
      if (signedMax < rhs) { // smax < 13123.0
        if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
            pred == CmpIPredicate::sle)
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
                                                     /*width=*/1);
        else
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
                                                     /*width=*/1);
        return success();
      }
    } else {
      // If the rhs value is > UnsignedMax, fold the comparison. This handles
      // +INF and large values.
      APFloat unsignedMax(rhs.getSemantics());
      unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
                                   APFloat::rmNearestTiesToEven);
      if (unsignedMax < rhs) { // umax < 13123.0
        if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
            pred == CmpIPredicate::ule)
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
                                                     /*width=*/1);
        else
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
                                                     /*width=*/1);
        return success();
      }
    }

    if (!isUnsigned) {
      // See if the rhs value is < SignedMin.
      APFloat signedMin(rhs.getSemantics());
      signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
                                 APFloat::rmNearestTiesToEven);
      if (signedMin > rhs) { // smin > 12312.0
        if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
            pred == CmpIPredicate::sge)
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
                                                     /*width=*/1);
        else
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
                                                     /*width=*/1);
        return success();
      }
    } else {
      // See if the rhs value is < UnsignedMin.
      APFloat unsignedMin(rhs.getSemantics());
      unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
                                   APFloat::rmNearestTiesToEven);
      if (unsignedMin > rhs) { // umin > 12312.0
        if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
            pred == CmpIPredicate::uge)
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
                                                     /*width=*/1);
        else
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
                                                     /*width=*/1);
        return success();
      }
    }

    // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
    // [0, UMAX], but it may still be fractional. See if it is fractional by
    // casting the FP value to the integer value and back, checking for
    // equality. Don't do this for zero, because -0.0 is not fractional.
    bool ignored;
    APSInt rhsInt(intWidth, isUnsigned);
    if (APFloat::opInvalidOp ==
        rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
      // Undefined behavior invoked - the destination type can't represent
      // the input constant.
      return failure();
    }

    if (!rhs.isZero()) {
      APFloat apf(floatTy.getFloatSemantics(),
                  APInt::getZero(floatTy.getWidth()));
      apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);

      bool equal = apf == rhs;
      if (!equal) {
        // If we had a comparison against a fractional value, we have to adjust
        // the compare predicate and sometimes the value.  rhsInt is rounded
        // towards zero at this point.
        switch (pred) {
        case CmpIPredicate::ne: // (float)int != 4.4 --> true
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
                                                     /*width=*/1);
          return success();
        case CmpIPredicate::eq: // (float)int == 4.4 --> false
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
                                                     /*width=*/1);
          return success();
        case CmpIPredicate::ule:
          // (float)int <= 4.4 --> int <= 4
          // (float)int <= -4.4 --> false
          if (rhs.isNegative()) {
            rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
                                                       /*width=*/1);
            return success();
          }
          break;
        case CmpIPredicate::sle:
          // (float)int <= 4.4 --> int <= 4
          // (float)int <= -4.4 --> int < -4
          if (rhs.isNegative())
            pred = CmpIPredicate::slt;
          break;
        case CmpIPredicate::ult:
          // (float)int < -4.4 --> false
          // (float)int < 4.4 --> int <= 4
          if (rhs.isNegative()) {
            rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
                                                       /*width=*/1);
            return success();
          }
          pred = CmpIPredicate::ule;
          break;
        case CmpIPredicate::slt:
          // (float)int < -4.4 --> int < -4
          // (float)int < 4.4 --> int <= 4
          if (!rhs.isNegative())
            pred = CmpIPredicate::sle;
          break;
        case CmpIPredicate::ugt:
          // (float)int > 4.4 --> int > 4
          // (float)int > -4.4 --> true
          if (rhs.isNegative()) {
            rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
                                                       /*width=*/1);
            return success();
          }
          break;
        case CmpIPredicate::sgt:
          // (float)int > 4.4 --> int > 4
          // (float)int > -4.4 --> int >= -4
          if (rhs.isNegative())
            pred = CmpIPredicate::sge;
          break;
        case CmpIPredicate::uge:
          // (float)int >= -4.4 --> true
          // (float)int >= 4.4 --> int > 4
          if (rhs.isNegative()) {
            rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
                                                       /*width=*/1);
            return success();
          }
          pred = CmpIPredicate::ugt;
          break;
        case CmpIPredicate::sge:
          // (float)int >= -4.4 --> int >= -4
          // (float)int >= 4.4 --> int > 4
          if (!rhs.isNegative())
            pred = CmpIPredicate::sgt;
          break;
        }
      }
    }

    // Lower this FP comparison into an appropriate integer version of the
    // comparison.
    rewriter.replaceOpWithNewOp<CmpIOp>(
        op, pred, intVal,
        rewriter.create<ConstantOp>(
            op.getLoc(), intVal.getType(),
            rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
    return success();
  }
};

void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.insert<CmpFIntToFPConst>(context);
}

//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//

// select %arg, %c1, %c0 => extui %arg
struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
  using OpRewritePattern<arith::SelectOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::SelectOp op,
                                PatternRewriter &rewriter) const override {
    // Cannot extui i1 to i1, or i1 to f32
    if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
      return failure();

    // select %x, c1, %c0 => extui %arg
    if (matchPattern(op.getTrueValue(), m_One()) &&
        matchPattern(op.getFalseValue(), m_Zero())) {
      rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
                                                  op.getCondition());
      return success();
    }

    // select %x, c0, %c1 => extui (xor %arg, true)
    if (matchPattern(op.getTrueValue(), m_Zero()) &&
        matchPattern(op.getFalseValue(), m_One())) {
      rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
          op, op.getType(),
          rewriter.create<arith::XOrIOp>(
              op.getLoc(), op.getCondition(),
              rewriter.create<arith::ConstantIntOp>(
                  op.getLoc(), 1, op.getCondition().getType())));
      return success();
    }

    return failure();
  }
};

void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
              SelectI1ToNot, SelectToExtUI>(context);
}

/// Folds select: identical branches, constant condition, poison operands,
/// i1 identity selects, selects fed by eq/ne comparisons of the same
/// operands, and element-wise folding over non-splat constant conditions.
OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
  Value trueVal = getTrueValue();
  Value falseVal = getFalseValue();
  if (trueVal == falseVal)
    return trueVal;

  Value condition = getCondition();

  // select true, %0, %1 => %0
  if (matchPattern(adaptor.getCondition(), m_One()))
    return trueVal;

  // select false, %0, %1 => %1
  if (matchPattern(adaptor.getCondition(), m_Zero()))
    return falseVal;

  // If either operand is fully poisoned, return the other.
  if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getTrueValue()))
    return falseVal;

  if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getFalseValue()))
    return trueVal;

  // select %x, true, false => %x
  if (getType().isSignlessInteger(1) &&
      matchPattern(adaptor.getTrueValue(), m_One()) &&
      matchPattern(adaptor.getFalseValue(), m_Zero()))
    return condition;

  if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
    auto pred = cmp.getPredicate();
    if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
      auto cmpLhs = cmp.getLhs();
      auto cmpRhs = cmp.getRhs();

      // %0 = arith.cmpi eq, %arg0, %arg1
      // %1 = arith.select %0, %arg0, %arg1 => %arg1

      // %0 = arith.cmpi ne, %arg0, %arg1
      // %1 = arith.select %0, %arg0, %arg1 => %arg0

      if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
          (cmpRhs == trueVal && cmpLhs == falseVal))
        return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
    }
  }

  // Constant-fold constant operands over non-splat constant condition.
  // select %cst_vec, %cst0, %cst1 => %cst2
  if (auto cond =
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
    if (auto lhs = llvm::dyn_cast_if_present<DenseElementsAttr>(
            adaptor.getTrueValue())) {
      if (auto rhs = llvm::dyn_cast_if_present<DenseElementsAttr>(
              adaptor.getFalseValue())) {
        SmallVector<Attribute> results;
        results.reserve(static_cast<size_t>(cond.getNumElements()));
        auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
                                         cond.value_end<BoolAttr>());
        auto lhsVals = llvm::make_range(lhs.value_begin<Attribute>(),
                                        lhs.value_end<Attribute>());
        auto rhsVals = llvm::make_range(rhs.value_begin<Attribute>(),
                                        rhs.value_end<Attribute>());

        // Pick per-element from the true/false constants based on the mask.
        for (auto [condVal, lhsVal, rhsVal] :
             llvm::zip_equal(condVals, lhsVals, rhsVals))
          results.push_back(condVal.getValue() ? lhsVal : rhsVal);

        return DenseElementsAttr::get(lhs.getType(), results);
      }
    }
  }

  return nullptr;
}

ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
  Type conditionType, resultType;
  SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
  if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(resultType))
    return failure();

  // Check for the explicit condition type if this is a masked tensor or vector.
  if (succeeded(parser.parseOptionalComma())) {
    conditionType = resultType;
    if (parser.parseType(resultType))
      return failure();
  } else {
    conditionType = parser.getBuilder().getI1Type();
  }

  result.addTypes(resultType);
  return parser.resolveOperands(operands,
                                {conditionType, resultType, resultType},
                                parser.getNameLoc(), result.operands);
}

void arith::SelectOp::print(OpAsmPrinter &p) {
  p << " " << getOperands();
  p.printOptionalAttrDict((*this)->getAttrs());
  p << " : ";
  // Shaped (mask) conditions are printed explicitly before the result type.
  if (ShapedType condType =
          llvm::dyn_cast<ShapedType>(getCondition().getType()))
    p << condType << ", ";
  p << getType();
}

LogicalResult arith::SelectOp::verify() {
  Type conditionType = getCondition().getType();
  if (conditionType.isSignlessInteger(1))
    return success();

  // If the result type is a vector or tensor, the type can be a mask with the
  // same elements.
  Type resultType = getType();
  if (!llvm::isa<TensorType, VectorType>(resultType))
    return emitOpError() << "expected condition to be a signless i1, but got "
                         << conditionType;
  Type shapedConditionType = getI1SameShape(resultType);
  if (conditionType != shapedConditionType) {
    return emitOpError() << "expected condition type to have the same shape "
                            "as the result type, expected "
                         << shapedConditionType << ", but got "
                         << conditionType;
  }
  return success();
}
//===----------------------------------------------------------------------===//
// ShLIOp
//===----------------------------------------------------------------------===//

OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
  // shli(x, 0) -> x
  if (matchPattern(adaptor.getRhs(), m_Zero()))
    return getLhs();
  // Don't fold if shifting more or equal than the bit width.
  bool bounded = false;
  auto result = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
        // Record whether every shift amount stays within the bit width.
        bounded = b.ult(b.getBitWidth());
        return a.shl(b);
      });
  return bounded ? result : Attribute();
}

//===----------------------------------------------------------------------===//
// ShRUIOp
//===----------------------------------------------------------------------===//

OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
  // shrui(x, 0) -> x
  if (matchPattern(adaptor.getRhs(), m_Zero()))
    return getLhs();
  // Don't fold if shifting more or equal than the bit width.
  bool bounded = false;
  auto result = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
        bounded = b.ult(b.getBitWidth());
        return a.lshr(b);
      });
  return bounded ? result : Attribute();
}

//===----------------------------------------------------------------------===//
// ShRSIOp
//===----------------------------------------------------------------------===//

OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
  // shrsi(x, 0) -> x
  if (matchPattern(adaptor.getRhs(), m_Zero()))
    return getLhs();
  // Don't fold if shifting more or equal than the bit width.
  bool bounded = false;
  auto result = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
        bounded = b.ult(b.getBitWidth());
        return a.ashr(b);
      });
  return bounded ? result : Attribute();
}

//===----------------------------------------------------------------------===//
// Atomic Enum
//===----------------------------------------------------------------------===//

/// Returns the identity value attribute associated with an AtomicRMWKind op.
TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
                                            OpBuilder &builder, Location loc,
                                            bool useOnlyFiniteValue) {
  switch (kind) {
  case AtomicRMWKind::maximumf: {
    const llvm::fltSemantics &semantic =
        llvm::cast<FloatType>(resultType).getFloatSemantics();
    // With `useOnlyFiniteValue`, use the most negative finite value instead
    // of -inf as the identity.
    APFloat identity = useOnlyFiniteValue
                           ? APFloat::getLargest(semantic, /*Negative=*/true)
                           : APFloat::getInf(semantic, /*Negative=*/true);
    return builder.getFloatAttr(resultType, identity);
  }
  case AtomicRMWKind::maxnumf: {
    const llvm::fltSemantics &semantic =
        llvm::cast<FloatType>(resultType).getFloatSemantics();
    // NOTE(review): NaN as identity presumably relies on maxnumf returning
    // the non-NaN operand — confirm against the op's semantics.
    APFloat identity = APFloat::getNaN(semantic, /*Negative=*/true);
    return builder.getFloatAttr(resultType, identity);
  }
  // Kinds whose identity is zero.
  case AtomicRMWKind::addf:
  case AtomicRMWKind::addi:
  case AtomicRMWKind::maxu:
  case AtomicRMWKind::ori:
    return builder.getZeroAttr(resultType);
  case AtomicRMWKind::andi:
    // Identity of bitwise-and is all ones.
    return builder.getIntegerAttr(
        resultType,
        APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
  case AtomicRMWKind::maxs:
    return builder.getIntegerAttr(
        resultType, APInt::getSignedMinValue(
                        llvm::cast<IntegerType>(resultType).getWidth()));
  case AtomicRMWKind::minimumf: {
    const llvm::fltSemantics &semantic =
        llvm::cast<FloatType>(resultType).getFloatSemantics();
    APFloat identity = useOnlyFiniteValue
                           ? APFloat::getLargest(semantic, /*Negative=*/false)
                           : APFloat::getInf(semantic, /*Negative=*/false);

    return builder.getFloatAttr(resultType, identity);
  }
  case AtomicRMWKind::minnumf: {
    const llvm::fltSemantics &semantic =
        llvm::cast<FloatType>(resultType).getFloatSemantics();
    APFloat identity = APFloat::getNaN(semantic, /*Negative=*/false);
    return builder.getFloatAttr(resultType, identity);
  }
  case AtomicRMWKind::mins:
    return builder.getIntegerAttr(
        resultType, APInt::getSignedMaxValue(
                        llvm::cast<IntegerType>(resultType).getWidth()));
  case AtomicRMWKind::minu:
    return builder.getIntegerAttr(
        resultType,
        APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
  case AtomicRMWKind::muli:
    return builder.getIntegerAttr(resultType, 1);
  case AtomicRMWKind::mulf:
    return builder.getFloatAttr(resultType, 1);
  // TODO: Add remaining reduction operations.
  default:
    (void)emitOptionalError(loc, "Reduction operation type not supported");
    break;
  }
  return nullptr;
}

/// Return the identity numeric value associated to the given op.
std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
  std::optional<AtomicRMWKind> maybeKind =
      llvm::TypeSwitch<Operation *, std::optional<AtomicRMWKind>>(op)
          // Floating-point operations.
          .Case([](arith::AddFOp op) { return AtomicRMWKind::addf; })
          .Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
          .Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })
          .Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })
          .Case([](arith::MaxNumFOp op) { return AtomicRMWKind::maxnumf; })
          .Case([](arith::MinNumFOp op) { return AtomicRMWKind::minnumf; })
          // Integer operations.
          .Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
          .Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
          // xor reuses ori's identity: both have neutral element 0.
          .Case([](arith::XOrIOp op) { return AtomicRMWKind::ori; })
          .Case([](arith::AndIOp op) { return AtomicRMWKind::andi; })
          .Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; })
          .Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; })
          .Case([](arith::MaxSIOp op) { return AtomicRMWKind::maxs; })
          .Case([](arith::MinSIOp op) { return AtomicRMWKind::mins; })
          .Case([](arith::MulIOp op) { return AtomicRMWKind::muli; })
          .Default([](Operation *op) { return std::nullopt; })
  ;
  if (!maybeKind) {
    return std::nullopt;
  }

  // With the `ninf` fast-math flag, restrict float identities to finite
  // values (largest finite instead of infinity).
  bool useOnlyFiniteValue = false;
  auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
  if (fmfOpInterface) {
    arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
    useOnlyFiniteValue =
        bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
  }

  // Builder only used as helper for attribute creation.
  OpBuilder b(op->getContext());
  Type resultType = op->getResult(0).getType();

  return getIdentityValueAttr(*maybeKind, resultType, b, op->getLoc(),
                              useOnlyFiniteValue);
}

/// Returns the identity value associated with an AtomicRMWKind op.
Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
                                    OpBuilder &builder, Location loc,
                                    bool useOnlyFiniteValue) {
  auto attr =
      getIdentityValueAttr(op, resultType, builder, loc, useOnlyFiniteValue);
  return builder.create<arith::ConstantOp>(loc, attr);
}

/// Return the value obtained by applying the reduction operation kind
/// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
2636 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder, 2637 Location loc, Value lhs, Value rhs) { 2638 switch (op) { 2639 case AtomicRMWKind::addf: 2640 return builder.create<arith::AddFOp>(loc, lhs, rhs); 2641 case AtomicRMWKind::addi: 2642 return builder.create<arith::AddIOp>(loc, lhs, rhs); 2643 case AtomicRMWKind::mulf: 2644 return builder.create<arith::MulFOp>(loc, lhs, rhs); 2645 case AtomicRMWKind::muli: 2646 return builder.create<arith::MulIOp>(loc, lhs, rhs); 2647 case AtomicRMWKind::maximumf: 2648 return builder.create<arith::MaximumFOp>(loc, lhs, rhs); 2649 case AtomicRMWKind::minimumf: 2650 return builder.create<arith::MinimumFOp>(loc, lhs, rhs); 2651 case AtomicRMWKind::maxnumf: 2652 return builder.create<arith::MaxNumFOp>(loc, lhs, rhs); 2653 case AtomicRMWKind::minnumf: 2654 return builder.create<arith::MinNumFOp>(loc, lhs, rhs); 2655 case AtomicRMWKind::maxs: 2656 return builder.create<arith::MaxSIOp>(loc, lhs, rhs); 2657 case AtomicRMWKind::mins: 2658 return builder.create<arith::MinSIOp>(loc, lhs, rhs); 2659 case AtomicRMWKind::maxu: 2660 return builder.create<arith::MaxUIOp>(loc, lhs, rhs); 2661 case AtomicRMWKind::minu: 2662 return builder.create<arith::MinUIOp>(loc, lhs, rhs); 2663 case AtomicRMWKind::ori: 2664 return builder.create<arith::OrIOp>(loc, lhs, rhs); 2665 case AtomicRMWKind::andi: 2666 return builder.create<arith::AndIOp>(loc, lhs, rhs); 2667 // TODO: Add remaining reduction operations. 
2668 default: 2669 (void)emitOptionalError(loc, "Reduction operation type not supported"); 2670 break; 2671 } 2672 return nullptr; 2673 } 2674 2675 //===----------------------------------------------------------------------===// 2676 // TableGen'd op method definitions 2677 //===----------------------------------------------------------------------===// 2678 2679 #define GET_OP_CLASSES 2680 #include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc" 2681 2682 //===----------------------------------------------------------------------===// 2683 // TableGen'd enum attribute definitions 2684 //===----------------------------------------------------------------------===// 2685 2686 #include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc" 2687