1 //===- ExpandPatterns.cpp - Code to expand various math operations. -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements expansion of various math operations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Arith/IR/Arith.h" 14 #include "mlir/Dialect/Math/IR/Math.h" 15 #include "mlir/Dialect/Math/Transforms/Passes.h" 16 #include "mlir/Dialect/SCF/IR/SCF.h" 17 #include "mlir/Dialect/Vector/IR/VectorOps.h" 18 #include "mlir/IR/Builders.h" 19 #include "mlir/IR/ImplicitLocOpBuilder.h" 20 #include "mlir/IR/TypeUtilities.h" 21 #include "mlir/Transforms/DialectConversion.h" 22 23 using namespace mlir; 24 25 /// Create a float constant. 26 static Value createFloatConst(Location loc, Type type, APFloat value, 27 OpBuilder &b) { 28 bool losesInfo = false; 29 auto eltType = getElementTypeOrSelf(type); 30 // Convert double to the given `FloatType` with round-to-nearest-ties-to-even. 31 value.convert(cast<FloatType>(eltType).getFloatSemantics(), 32 APFloat::rmNearestTiesToEven, &losesInfo); 33 auto attr = b.getFloatAttr(eltType, value); 34 if (auto shapedTy = dyn_cast<ShapedType>(type)) { 35 return b.create<arith::ConstantOp>(loc, 36 DenseElementsAttr::get(shapedTy, attr)); 37 } 38 39 return b.create<arith::ConstantOp>(loc, attr); 40 } 41 42 static Value createFloatConst(Location loc, Type type, double value, 43 OpBuilder &b) { 44 return createFloatConst(loc, type, APFloat(value), b); 45 } 46 47 /// Create an integer constant. 48 static Value createIntConst(Location loc, Type type, int64_t value, 49 OpBuilder &b) { 50 auto attr = b.getIntegerAttr(getElementTypeOrSelf(type), value); 51 if (auto shapedTy = dyn_cast<ShapedType>(type)) { 52 return b.create<arith::ConstantOp>(loc, 53 DenseElementsAttr::get(shapedTy, attr)); 54 } 55 56 return b.create<arith::ConstantOp>(loc, attr); 57 } 58 59 static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b) { 60 Type opType = operand.getType(); 61 Type i64Ty = b.getI64Type(); 62 if (auto shapedTy = dyn_cast<ShapedType>(opType)) 63 i64Ty = shapedTy.clone(i64Ty); 64 Value fixedConvert = b.create<arith::FPToSIOp>(i64Ty, operand); 65 Value fpFixedConvert = b.create<arith::SIToFPOp>(opType, fixedConvert); 66 // The truncation does not preserve the sign when the truncated 67 // value is -0. So here the sign is copied again. 68 return b.create<math::CopySignOp>(fpFixedConvert, operand); 69 } 70 71 // sinhf(float x) -> (exp(x) - exp(-x)) / 2 72 static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter) { 73 ImplicitLocOpBuilder b(op->getLoc(), rewriter); 74 Value operand = op.getOperand(); 75 Type opType = operand.getType(); 76 77 Value exp = b.create<math::ExpOp>(operand); 78 Value neg = b.create<arith::NegFOp>(operand); 79 Value nexp = b.create<math::ExpOp>(neg); 80 Value sub = b.create<arith::SubFOp>(exp, nexp); 81 Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter); 82 Value res = b.create<arith::MulFOp>(sub, half); 83 rewriter.replaceOp(op, res); 84 return success(); 85 } 86 87 // coshf(float x) -> (exp(x) + exp(-x)) / 2 88 static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) { 89 ImplicitLocOpBuilder b(op->getLoc(), rewriter); 90 Value operand = op.getOperand(); 91 Type opType = operand.getType(); 92 93 Value exp = b.create<math::ExpOp>(operand); 94 Value neg = b.create<arith::NegFOp>(operand); 95 Value nexp = b.create<math::ExpOp>(neg); 96 Value add = b.create<arith::AddFOp>(exp, nexp); 97 Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter); 98 Value res = b.create<arith::MulFOp>(add, half); 99 rewriter.replaceOp(op, res); 100 return success(); 101 } 102 103 /// Expands tanh op into 104 /// 1-exp^{-2x} / 1+exp^{-2x} 105 /// To avoid overflow we exploit the reflection symmetry `tanh(-x) = -tanh(x)`. 106 /// We compute a "signs" value which is -1 if input is negative and +1 if input 107 /// is positive. Then multiply the input by this value, guaranteeing that the 108 /// result is positive, which also guarantees `exp^{-2x * sign(x)}` is in (0, 109 /// 1]. Expand the computation on the input `x * sign(x)`, then multiply the 110 /// result by `sign(x)` to retain sign of the real result. 111 static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) { 112 auto floatType = op.getOperand().getType(); 113 Location loc = op.getLoc(); 114 Value zero = createFloatConst(loc, floatType, 0.0, rewriter); 115 Value one = createFloatConst(loc, floatType, 1.0, rewriter); 116 Value negTwo = createFloatConst(loc, floatType, -2.0, rewriter); 117 118 // Compute sign(x) = cast<float_type>(x < 0) * (-2) + 1 119 Value isNegative = rewriter.create<arith::CmpFOp>( 120 loc, arith::CmpFPredicate::OLT, op.getOperand(), zero); 121 Value isNegativeFloat = 122 rewriter.create<arith::UIToFPOp>(loc, floatType, isNegative); 123 Value isNegativeTimesNegTwo = 124 rewriter.create<arith::MulFOp>(loc, isNegativeFloat, negTwo); 125 Value sign = rewriter.create<arith::AddFOp>(loc, isNegativeTimesNegTwo, one); 126 127 // Normalize input to positive value: y = sign(x) * x 128 Value positiveX = rewriter.create<arith::MulFOp>(loc, sign, op.getOperand()); 129 130 // Decompose on normalized input 131 Value negDoubledX = rewriter.create<arith::MulFOp>(loc, negTwo, positiveX); 132 Value exp2x = rewriter.create<math::ExpOp>(loc, negDoubledX); 133 Value dividend = rewriter.create<arith::SubFOp>(loc, one, exp2x); 134 Value divisor = rewriter.create<arith::AddFOp>(loc, one, exp2x); 135 Value positiveRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor); 136 137 // Multiply result by sign(x) to retain signs from negative inputs 138 rewriter.replaceOpWithNewOp<arith::MulFOp>(op, sign, positiveRes); 139 140 return success(); 141 } 142 143 // Converts math.tan to math.sin, math.cos, and arith.divf. 144 static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) { 145 ImplicitLocOpBuilder b(op->getLoc(), rewriter); 146 Value operand = op.getOperand(); 147 Type type = operand.getType(); 148 Value sin = b.create<math::SinOp>(type, operand); 149 Value cos = b.create<math::CosOp>(type, operand); 150 Value div = b.create<arith::DivFOp>(type, sin, cos); 151 rewriter.replaceOp(op, div); 152 return success(); 153 } 154 155 // asinh(float x) -> log(x + sqrt(x**2 + 1)) 156 static LogicalResult convertAsinhOp(math::AsinhOp op, 157 PatternRewriter &rewriter) { 158 ImplicitLocOpBuilder b(op->getLoc(), rewriter); 159 Value operand = op.getOperand(); 160 Type opType = operand.getType(); 161 162 Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter); 163 Value fma = b.create<math::FmaOp>(operand, operand, one); 164 Value sqrt = b.create<math::SqrtOp>(fma); 165 Value add = b.create<arith::AddFOp>(operand, sqrt); 166 Value res = b.create<math::LogOp>(add); 167 rewriter.replaceOp(op, res); 168 return success(); 169 } 170 171 // acosh(float x) -> log(x + sqrt(x**2 - 1)) 172 static LogicalResult convertAcoshOp(math::AcoshOp op, 173 PatternRewriter &rewriter) { 174 ImplicitLocOpBuilder b(op->getLoc(), rewriter); 175 Value operand = op.getOperand(); 176 Type opType = operand.getType(); 177 178 Value negOne = createFloatConst(op->getLoc(), opType, -1.0, rewriter); 179 Value fma = b.create<math::FmaOp>(operand, operand, negOne); 180 Value sqrt = b.create<math::SqrtOp>(fma); 181 Value add = b.create<arith::AddFOp>(operand, sqrt); 182 Value res = b.create<math::LogOp>(add); 183 rewriter.replaceOp(op, res); 184 return success(); 185 } 186 187 // atanh(float x) -> log((1 + x) / (1 - x)) / 2 188 static LogicalResult convertAtanhOp(math::AtanhOp op, 189 PatternRewriter &rewriter) { 190 ImplicitLocOpBuilder b(op->getLoc(), rewriter); 191 Value operand = op.getOperand(); 192 Type opType = operand.getType(); 193 194 Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter); 195 Value add = b.create<arith::AddFOp>(operand, one); 196 Value neg = b.create<arith::NegFOp>(operand); 197 Value sub = b.create<arith::AddFOp>(neg, one); 198 Value div = b.create<arith::DivFOp>(add, sub); 199 Value log = b.create<math::LogOp>(div); 200 Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter); 201 Value res = b.create<arith::MulFOp>(log, half); 202 rewriter.replaceOp(op, res); 203 return success(); 204 } 205 206 static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) { 207 ImplicitLocOpBuilder b(op->getLoc(), rewriter); 208 Value operandA = op.getOperand(0); 209 Value operandB = op.getOperand(1); 210 Value operandC = op.getOperand(2); 211 Type type = op.getType(); 212 Value mult = b.create<arith::MulFOp>(type, operandA, operandB); 213 Value add = b.create<arith::AddFOp>(type, mult, operandC); 214 rewriter.replaceOp(op, add); 215 return success(); 216 } 217 218 // Converts a floorf() function to the following: 219 // floorf(float x) -> 220 // y = (float)(int) x 221 // if (x < 0) then incr = -1 else incr = 0 222 // y = y + incr <= replace this op with the floorf op. 223 static LogicalResult convertFloorOp(math::FloorOp op, 224 PatternRewriter &rewriter) { 225 ImplicitLocOpBuilder b(op->getLoc(), rewriter); 226 Value operand = op.getOperand(); 227 Type opType = operand.getType(); 228 Value fpFixedConvert = createTruncatedFPValue(operand, b); 229 230 // Creating constants for later use. 231 Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter); 232 Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter); 233 234 Value negCheck = 235 b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero); 236 Value incrValue = 237 b.create<arith::SelectOp>(op->getLoc(), negCheck, negOne, zero); 238 Value ret = b.create<arith::AddFOp>(opType, fpFixedConvert, incrValue); 239 rewriter.replaceOp(op, ret); 240 return success(); 241 } 242 243 // Converts a ceilf() function to the following: 244 // ceilf(float x) -> 245 // y = (float)(int) x 246 // if (x > y) then incr = 1 else incr = 0 247 // y = y + incr <= replace this op with the ceilf op. 248 static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) { 249 ImplicitLocOpBuilder b(op->getLoc(), rewriter); 250 Value operand = op.getOperand(); 251 Type opType = operand.getType(); 252 Value fpFixedConvert = createTruncatedFPValue(operand, b); 253 254 // Creating constants for later use. 255 Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter); 256 Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter); 257 258 Value gtCheck = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, operand, 259 fpFixedConvert); 260 Value incrValue = b.create<arith::SelectOp>(op->getLoc(), gtCheck, one, zero); 261 262 Value ret = b.create<arith::AddFOp>(opType, fpFixedConvert, incrValue); 263 rewriter.replaceOp(op, ret); 264 return success(); 265 } 266 267 // Convert `math.fpowi` to a series of `arith.mulf` operations. 268 // If the power is negative, we divide one by the result. 269 // If both the base and power are zero, the result is 1. 270 // In the case of non constant power, we convert the operation to `math.powf`. 271 static LogicalResult convertFPowIOp(math::FPowIOp op, 272 PatternRewriter &rewriter) { 273 ImplicitLocOpBuilder b(op->getLoc(), rewriter); 274 Value base = op.getOperand(0); 275 Value power = op.getOperand(1); 276 Type baseType = base.getType(); 277 278 auto convertFPowItoPowf = [&]() -> LogicalResult { 279 Value castPowerToFp = 280 rewriter.create<arith::SIToFPOp>(op.getLoc(), baseType, power); 281 Value res = rewriter.create<math::PowFOp>(op.getLoc(), baseType, base, 282 castPowerToFp); 283 rewriter.replaceOp(op, res); 284 return success(); 285 }; 286 287 Attribute cstAttr; 288 if (!matchPattern(power, m_Constant(&cstAttr))) 289 return convertFPowItoPowf(); 290 291 APInt value; 292 if (!matchPattern(cstAttr, m_ConstantInt(&value))) 293 return convertFPowItoPowf(); 294 295 int64_t powerInt = value.getSExtValue(); 296 bool isNegative = powerInt < 0; 297 int64_t absPower = std::abs(powerInt); 298 Value one = createFloatConst(op->getLoc(), baseType, 1.00, rewriter); 299 Value res = createFloatConst(op->getLoc(), baseType, 1.00, rewriter); 300 301 while (absPower > 0) { 302 if (absPower & 1) 303 res = b.create<arith::MulFOp>(baseType, base, res); 304 absPower >>= 1; 305 base = b.create<arith::MulFOp>(baseType, base, base); 306 } 307 308 // Make sure not to introduce UB in case of negative power. 309 if (isNegative) { 310 auto &sem = dyn_cast<mlir::FloatType>(getElementTypeOrSelf(baseType)) 311 .getFloatSemantics(); 312 Value zero = 313 createFloatConst(op->getLoc(), baseType, 314 APFloat::getZero(sem, /*Negative=*/false), rewriter); 315 Value negZero = 316 createFloatConst(op->getLoc(), baseType, 317 APFloat::getZero(sem, /*Negative=*/true), rewriter); 318 Value posInfinity = 319 createFloatConst(op->getLoc(), baseType, 320 APFloat::getInf(sem, /*Negative=*/false), rewriter); 321 Value negInfinity = 322 createFloatConst(op->getLoc(), baseType, 323 APFloat::getInf(sem, /*Negative=*/true), rewriter); 324 Value zeroEqCheck = 325 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, zero); 326 Value negZeroEqCheck = 327 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, negZero); 328 res = b.create<arith::DivFOp>(baseType, one, res); 329 res = 330 b.create<arith::SelectOp>(op->getLoc(), zeroEqCheck, posInfinity, res); 331 res = b.create<arith::SelectOp>(op->getLoc(), negZeroEqCheck, negInfinity, 332 res); 333 } 334 335 rewriter.replaceOp(op, res); 336 return success(); 337 } 338 339 // Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a)) 340 static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) { 341 ImplicitLocOpBuilder b(op->getLoc(), rewriter); 342 Value operandA = op.getOperand(0); 343 Value operandB = op.getOperand(1); 344 Type opType = operandA.getType(); 345 Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter); 346 Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter); 347 Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter); 348 Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter); 349 Value opASquared = b.create<arith::MulFOp>(opType, operandA, operandA); 350 Value opBHalf = b.create<arith::DivFOp>(opType, operandB, two); 351 352 Value logA = b.create<math::LogOp>(opType, opASquared); 353 Value mult = b.create<arith::MulFOp>(opType, opBHalf, logA); 354 Value expResult = b.create<math::ExpOp>(opType, mult); 355 Value negExpResult = b.create<arith::MulFOp>(opType, expResult, negOne); 356 Value remainder = b.create<arith::RemFOp>(opType, operandB, two); 357 Value negCheck = 358 b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero); 359 Value oddPower = 360 b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero); 361 Value oddAndNeg = b.create<arith::AndIOp>(op->getLoc(), oddPower, negCheck); 362 363 // First, we select between the exp value and the adjusted value for odd 364 // powers of negatives. Then, we ensure that one is produced if `b` is zero. 365 // This corresponds to `libm` behavior, even for `0^0`. Without this check, 366 // `exp(0 * ln(0)) = exp(0 *-inf) = exp(-nan) = -nan`. 367 Value zeroCheck = 368 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero); 369 Value res = b.create<arith::SelectOp>(op->getLoc(), oddAndNeg, negExpResult, 370 expResult); 371 res = b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, res); 372 rewriter.replaceOp(op, res); 373 return success(); 374 } 375 376 // exp2f(float x) -> exp(x * ln(2)) 377 // Proof: Let's say 2^x = y 378 // ln(2^x) = ln(y) 379 // x * ln(2) = ln(y) => e ^(x*ln(2)) = y 380 static LogicalResult convertExp2fOp(math::Exp2Op op, 381 PatternRewriter &rewriter) { 382 ImplicitLocOpBuilder b(op->getLoc(), rewriter); 383 Value operand = op.getOperand(); 384 Type opType = operand.getType(); 385 Value ln2 = createFloatConst(op->getLoc(), opType, llvm::numbers::ln2, b); 386 Value mult = b.create<arith::MulFOp>(opType, operand, ln2); 387 Value exp = b.create<math::ExpOp>(op->getLoc(), mult); 388 rewriter.replaceOp(op, exp); 389 return success(); 390 } 391 392 static LogicalResult convertRoundOp(math::RoundOp op, 393 PatternRewriter &rewriter) { 394 Location loc = op.getLoc(); 395 ImplicitLocOpBuilder b(loc, rewriter); 396 Value operand = op.getOperand(); 397 Type opType = operand.getType(); 398 Type opEType = getElementTypeOrSelf(opType); 399 400 if (!opEType.isF32()) { 401 return rewriter.notifyMatchFailure(op, "not a round of f32."); 402 } 403 404 Type i32Ty = b.getI32Type(); 405 if (auto shapedTy = dyn_cast<ShapedType>(opType)) 406 i32Ty = shapedTy.clone(i32Ty); 407 408 Value half = createFloatConst(loc, opType, 0.5, b); 409 Value c23 = createIntConst(loc, i32Ty, 23, b); 410 Value c127 = createIntConst(loc, i32Ty, 127, b); 411 Value expMask = createIntConst(loc, i32Ty, (1 << 8) - 1, b); 412 413 Value incrValue = b.create<math::CopySignOp>(half, operand); 414 Value add = b.create<arith::AddFOp>(opType, operand, incrValue); 415 Value fpFixedConvert = createTruncatedFPValue(add, b); 416 417 // There are three cases where adding 0.5 to the value and truncating by 418 // converting to an i64 does not result in the correct behavior: 419 // 420 // 1. Special values: +-inf and +-nan 421 // Casting these special values to i64 has undefined behavior. To identify 422 // these values, we use the fact that these values are the only float 423 // values with the maximum possible biased exponent. 424 // 425 // 2. Large values: 2^23 <= |x| <= INT_64_MAX 426 // Adding 0.5 to a float larger than or equal to 2^23 results in precision 427 // errors that sometimes round the value up and sometimes round the value 428 // down. For example: 429 // 8388608.0 + 0.5 = 8388608.0 430 // 8388609.0 + 0.5 = 8388610.0 431 // 432 // 3. Very large values: |x| > INT_64_MAX 433 // Casting to i64 a value greater than the max i64 value will overflow the 434 // i64 leading to wrong outputs. 435 // 436 // All three cases satisfy the property `biasedExp >= 23`. 437 Value operandBitcast = b.create<arith::BitcastOp>(i32Ty, operand); 438 Value operandExp = b.create<arith::AndIOp>( 439 b.create<arith::ShRUIOp>(operandBitcast, c23), expMask); 440 Value operandBiasedExp = b.create<arith::SubIOp>(operandExp, c127); 441 Value isSpecialValOrLargeVal = 442 b.create<arith::CmpIOp>(arith::CmpIPredicate::sge, operandBiasedExp, c23); 443 444 Value result = b.create<arith::SelectOp>(isSpecialValOrLargeVal, operand, 445 fpFixedConvert); 446 rewriter.replaceOp(op, result); 447 return success(); 448 } 449 450 // Converts math.ctlz to scf and arith operations. This is done 451 // by performing a binary search on the bits. 452 static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op, 453 PatternRewriter &rewriter) { 454 auto operand = op.getOperand(); 455 auto operandTy = operand.getType(); 456 auto eTy = getElementTypeOrSelf(operandTy); 457 Location loc = op.getLoc(); 458 459 int32_t bitwidth = eTy.getIntOrFloatBitWidth(); 460 if (bitwidth > 64) 461 return failure(); 462 463 uint64_t allbits = -1; 464 if (bitwidth < 64) { 465 allbits = allbits >> (64 - bitwidth); 466 } 467 468 Value x = operand; 469 Value count = createIntConst(loc, operandTy, 0, rewriter); 470 for (int32_t bw = bitwidth; bw > 1; bw = bw / 2) { 471 auto half = bw / 2; 472 auto bits = createIntConst(loc, operandTy, half, rewriter); 473 auto mask = createIntConst(loc, operandTy, allbits >> half, rewriter); 474 475 Value pred = 476 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule, x, mask); 477 Value add = rewriter.create<arith::AddIOp>(loc, count, bits); 478 Value shift = rewriter.create<arith::ShLIOp>(loc, x, bits); 479 480 x = rewriter.create<arith::SelectOp>(loc, pred, shift, x); 481 count = rewriter.create<arith::SelectOp>(loc, pred, add, count); 482 } 483 484 Value zero = createIntConst(loc, operandTy, 0, rewriter); 485 Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, 486 operand, zero); 487 488 Value bwval = createIntConst(loc, operandTy, bitwidth, rewriter); 489 Value sel = rewriter.create<arith::SelectOp>(loc, pred, bwval, count); 490 rewriter.replaceOp(op, sel); 491 return success(); 492 } 493 494 // Convert `math.roundeven` into `math.round` + arith ops 495 static LogicalResult convertRoundEvenOp(math::RoundEvenOp op, 496 PatternRewriter &rewriter) { 497 Location loc = op.getLoc(); 498 ImplicitLocOpBuilder b(loc, rewriter); 499 auto operand = op.getOperand(); 500 Type operandTy = operand.getType(); 501 Type resultTy = op.getType(); 502 Type operandETy = getElementTypeOrSelf(operandTy); 503 Type resultETy = getElementTypeOrSelf(resultTy); 504 505 if (!isa<FloatType>(operandETy) || !isa<FloatType>(resultETy)) { 506 return rewriter.notifyMatchFailure(op, "not a roundeven of f16 or f32."); 507 } 508 509 Type fTy = operandTy; 510 Type iTy = rewriter.getIntegerType(operandETy.getIntOrFloatBitWidth()); 511 if (auto shapedTy = dyn_cast<ShapedType>(fTy)) { 512 iTy = shapedTy.clone(iTy); 513 } 514 515 unsigned bitWidth = operandETy.getIntOrFloatBitWidth(); 516 // The width returned by getFPMantissaWidth includes the integer bit. 517 unsigned mantissaWidth = 518 llvm::cast<FloatType>(operandETy).getFPMantissaWidth() - 1; 519 unsigned exponentWidth = bitWidth - mantissaWidth - 1; 520 521 // The names of the variables correspond to f32. 522 // f64: 1 bit sign | 11 bits exponent | 52 bits mantissa. 523 // f32: 1 bit sign | 8 bits exponent | 23 bits mantissa. 524 // f16: 1 bit sign | 5 bits exponent | 10 bits mantissa. 525 Value c1Float = createFloatConst(loc, fTy, 1.0, b); 526 Value c0 = createIntConst(loc, iTy, 0, b); 527 Value c1 = createIntConst(loc, iTy, 1, b); 528 Value cNeg1 = createIntConst(loc, iTy, -1, b); 529 Value c23 = createIntConst(loc, iTy, mantissaWidth, b); 530 Value c31 = createIntConst(loc, iTy, bitWidth - 1, b); 531 Value c127 = createIntConst(loc, iTy, (1ull << (exponentWidth - 1)) - 1, b); 532 Value c2To22 = createIntConst(loc, iTy, 1ull << (mantissaWidth - 1), b); 533 Value c23Mask = createIntConst(loc, iTy, (1ull << mantissaWidth) - 1, b); 534 Value expMask = createIntConst(loc, iTy, (1ull << exponentWidth) - 1, b); 535 536 Value operandBitcast = b.create<arith::BitcastOp>(iTy, operand); 537 Value round = b.create<math::RoundOp>(operand); 538 Value roundBitcast = b.create<arith::BitcastOp>(iTy, round); 539 540 // Get biased exponents for operand and round(operand) 541 Value operandExp = b.create<arith::AndIOp>( 542 b.create<arith::ShRUIOp>(operandBitcast, c23), expMask); 543 Value operandBiasedExp = b.create<arith::SubIOp>(operandExp, c127); 544 Value roundExp = b.create<arith::AndIOp>( 545 b.create<arith::ShRUIOp>(roundBitcast, c23), expMask); 546 Value roundBiasedExp = b.create<arith::SubIOp>(roundExp, c127); 547 548 auto safeShiftRight = [&](Value x, Value shift) -> Value { 549 // Clamp shift to valid range [0, bitwidth - 1] to avoid undefined behavior 550 Value clampedShift = b.create<arith::MaxSIOp>(shift, c0); 551 clampedShift = b.create<arith::MinSIOp>(clampedShift, c31); 552 return b.create<arith::ShRUIOp>(x, clampedShift); 553 }; 554 555 auto maskMantissa = [&](Value mantissa, 556 Value mantissaMaskRightShift) -> Value { 557 Value shiftedMantissaMask = safeShiftRight(c23Mask, mantissaMaskRightShift); 558 return b.create<arith::AndIOp>(mantissa, shiftedMantissaMask); 559 }; 560 561 // A whole number `x`, such that `|x| != 1`, is even if the mantissa, ignoring 562 // the leftmost `clamp(biasedExp - 1, 0, 23)` bits, is zero. Large numbers 563 // with `biasedExp > 23` (numbers where there is not enough precision to store 564 // decimals) are always even, and they satisfy the even condition trivially 565 // since the mantissa without all its bits is zero. The even condition 566 // is also true for +-0, since they have `biasedExp = -127` and the entire 567 // mantissa is zero. The case of +-1 has to be handled separately. Here 568 // we identify these values by noting that +-1 are the only whole numbers with 569 // `biasedExp == 0`. 570 // 571 // The special values +-inf and +-nan also satisfy the same property that 572 // whole non-unit even numbers satisfy. In particular, the special values have 573 // `biasedExp > 23`, so they get treated as large numbers with no room for 574 // decimals, which are always even. 575 Value roundBiasedExpEq0 = 576 b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, roundBiasedExp, c0); 577 Value roundBiasedExpMinus1 = b.create<arith::SubIOp>(roundBiasedExp, c1); 578 Value roundMaskedMantissa = maskMantissa(roundBitcast, roundBiasedExpMinus1); 579 Value roundIsNotEvenOrSpecialVal = b.create<arith::CmpIOp>( 580 arith::CmpIPredicate::ne, roundMaskedMantissa, c0); 581 roundIsNotEvenOrSpecialVal = 582 b.create<arith::OrIOp>(roundIsNotEvenOrSpecialVal, roundBiasedExpEq0); 583 584 // A value `x` with `0 <= biasedExp < 23`, is halfway between two consecutive 585 // integers if the bit at index `biasedExp` starting from the left in the 586 // mantissa is 1 and all the bits to the right are zero. Values with 587 // `biasedExp >= 23` don't have decimals, so they are never halfway. The 588 // values +-0.5 are the only halfway values that have `biasedExp == -1 < 0`, 589 // so these are handled separately. In particular, if `biasedExp == -1`, the 590 // value is halfway if the entire mantissa is zero. 591 Value operandBiasedExpEqNeg1 = b.create<arith::CmpIOp>( 592 arith::CmpIPredicate::eq, operandBiasedExp, cNeg1); 593 Value expectedOperandMaskedMantissa = b.create<arith::SelectOp>( 594 operandBiasedExpEqNeg1, c0, safeShiftRight(c2To22, operandBiasedExp)); 595 Value operandMaskedMantissa = maskMantissa(operandBitcast, operandBiasedExp); 596 Value operandIsHalfway = 597 b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, operandMaskedMantissa, 598 expectedOperandMaskedMantissa); 599 // Ensure `biasedExp` is in the valid range for half values. 600 Value operandBiasedExpGeNeg1 = b.create<arith::CmpIOp>( 601 arith::CmpIPredicate::sge, operandBiasedExp, cNeg1); 602 Value operandBiasedExpLt23 = 603 b.create<arith::CmpIOp>(arith::CmpIPredicate::slt, operandBiasedExp, c23); 604 operandIsHalfway = 605 b.create<arith::AndIOp>(operandIsHalfway, operandBiasedExpLt23); 606 operandIsHalfway = 607 b.create<arith::AndIOp>(operandIsHalfway, operandBiasedExpGeNeg1); 608 609 // Adjust rounded operand with `round(operand) - sign(operand)` to correct the 610 // case where `round` rounded in the opposite direction of `roundeven`. 611 Value sign = b.create<math::CopySignOp>(c1Float, operand); 612 Value roundShifted = b.create<arith::SubFOp>(round, sign); 613 // If the rounded value is even or a special value, we default to the behavior 614 // of `math.round`. 615 Value needsShift = 616 b.create<arith::AndIOp>(roundIsNotEvenOrSpecialVal, operandIsHalfway); 617 Value result = b.create<arith::SelectOp>(needsShift, roundShifted, round); 618 // The `x - sign` adjustment does not preserve the sign when we are adjusting 619 // the value -1 to -0. So here the sign is copied again to ensure that -0.5 is 620 // rounded to -0.0. 621 result = b.create<math::CopySignOp>(result, operand); 622 rewriter.replaceOp(op, result); 623 return success(); 624 } 625 626 // Convert `math.rsqrt` into `arith.divf` + `math.sqrt` 627 static LogicalResult convertRsqrtOp(math::RsqrtOp op, 628 PatternRewriter &rewriter) { 629 630 auto operand = op.getOperand(); 631 auto operandTy = operand.getType(); 632 auto eTy = getElementTypeOrSelf(operandTy); 633 if (!isa<FloatType>(eTy)) 634 return failure(); 635 636 Location loc = op->getLoc(); 637 auto constOneFloat = createFloatConst(loc, operandTy, 1.0, rewriter); 638 auto sqrtOp = rewriter.create<math::SqrtOp>(loc, operand); 639 rewriter.replaceOpWithNewOp<arith::DivFOp>(op, constOneFloat, sqrtOp); 640 return success(); 641 } 642 643 void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) { 644 patterns.add(convertCtlzOp); 645 } 646 647 void mlir::populateExpandSinhPattern(RewritePatternSet &patterns) { 648 patterns.add(convertSinhOp); 649 } 650 651 void mlir::populateExpandCoshPattern(RewritePatternSet &patterns) { 652 patterns.add(convertCoshOp); 653 } 654 655 void mlir::populateExpandTanPattern(RewritePatternSet &patterns) { 656 patterns.add(convertTanOp); 657 } 658 659 void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) { 660 patterns.add(convertTanhOp); 661 } 662 663 void mlir::populateExpandAsinhPattern(RewritePatternSet &patterns) { 664 patterns.add(convertAsinhOp); 665 } 666 667 void mlir::populateExpandAcoshPattern(RewritePatternSet &patterns) { 668 patterns.add(convertAcoshOp); 669 } 670 671 void mlir::populateExpandAtanhPattern(RewritePatternSet &patterns) { 672 patterns.add(convertAtanhOp); 673 } 674 675 void mlir::populateExpandFmaFPattern(RewritePatternSet &patterns) { 676 patterns.add(convertFmaFOp); 677 } 678 679 void mlir::populateExpandCeilFPattern(RewritePatternSet &patterns) { 680 patterns.add(convertCeilOp); 681 } 682 683 void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) { 684 patterns.add(convertExp2fOp); 685 } 686 687 void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) { 688 patterns.add(convertPowfOp); 689 } 690 691 void mlir::populateExpandFPowIPattern(RewritePatternSet &patterns) { 692 patterns.add(convertFPowIOp); 693 } 694 695 void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) { 696 patterns.add(convertRoundOp); 697 } 698 699 void mlir::populateExpandFloorFPattern(RewritePatternSet &patterns) { 700 patterns.add(convertFloorOp); 701 } 702 703 void mlir::populateExpandRoundEvenPattern(RewritePatternSet &patterns) { 704 patterns.add(convertRoundEvenOp); 705 } 706 707 void mlir::populateExpandRsqrtPattern(RewritePatternSet &patterns) { 708 patterns.add(convertRsqrtOp); 709 } 710