1 //===- ExpandTanh.cpp - Code to perform expanding tanh op -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements expansion of tanh op. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Arith/IR/Arith.h" 14 #include "mlir/Dialect/Math/IR/Math.h" 15 #include "mlir/Dialect/Math/Transforms/Passes.h" 16 #include "mlir/Dialect/SCF/IR/SCF.h" 17 #include "mlir/Dialect/Vector/IR/VectorOps.h" 18 #include "mlir/IR/Builders.h" 19 #include "mlir/IR/ImplicitLocOpBuilder.h" 20 #include "mlir/IR/TypeUtilities.h" 21 #include "mlir/Transforms/DialectConversion.h" 22 23 using namespace mlir; 24 25 /// Create a float constant. 26 static Value createFloatConst(Location loc, Type type, double value, 27 OpBuilder &b) { 28 auto attr = b.getFloatAttr(getElementTypeOrSelf(type), value); 29 if (auto shapedTy = dyn_cast<ShapedType>(type)) { 30 return b.create<arith::ConstantOp>(loc, 31 DenseElementsAttr::get(shapedTy, attr)); 32 } 33 34 return b.create<arith::ConstantOp>(loc, attr); 35 } 36 37 /// Create a float constant. 38 static Value createIntConst(Location loc, Type type, int64_t value, 39 OpBuilder &b) { 40 auto attr = b.getIntegerAttr(getElementTypeOrSelf(type), value); 41 if (auto shapedTy = dyn_cast<ShapedType>(type)) { 42 return b.create<arith::ConstantOp>(loc, 43 DenseElementsAttr::get(shapedTy, attr)); 44 } 45 46 return b.create<arith::ConstantOp>(loc, attr); 47 } 48 49 static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b) { 50 Type opType = operand.getType(); 51 Type i64Ty = b.getI64Type(); 52 if (auto shapedTy = dyn_cast<ShapedType>(opType)) 53 i64Ty = shapedTy.clone(i64Ty); 54 Value fixedConvert = b.create<arith::FPToSIOp>(i64Ty, operand); 55 Value fpFixedConvert = b.create<arith::SIToFPOp>(opType, fixedConvert); 56 // The truncation does not preserve the sign when the truncated 57 // value is -0. So here the sign is copied again. 58 return b.create<math::CopySignOp>(fpFixedConvert, operand); 59 } 60 61 // sinhf(float x) -> (exp(x) - exp(-x)) / 2 62 static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter) { 63 ImplicitLocOpBuilder b(op->getLoc(), rewriter); 64 Value operand = op.getOperand(); 65 Type opType = operand.getType(); 66 Value exp = b.create<math::ExpOp>(operand); 67 68 Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter); 69 Value nexp = b.create<arith::DivFOp>(one, exp); 70 Value sub = b.create<arith::SubFOp>(exp, nexp); 71 Value two = createFloatConst(op->getLoc(), opType, 2.0, rewriter); 72 Value div = b.create<arith::DivFOp>(sub, two); 73 rewriter.replaceOp(op, div); 74 return success(); 75 } 76 77 // coshf(float x) -> (exp(x) + exp(-x)) / 2 78 static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) { 79 ImplicitLocOpBuilder b(op->getLoc(), rewriter); 80 Value operand = op.getOperand(); 81 Type opType = operand.getType(); 82 Value exp = b.create<math::ExpOp>(operand); 83 84 Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter); 85 Value nexp = b.create<arith::DivFOp>(one, exp); 86 Value add = b.create<arith::AddFOp>(exp, nexp); 87 Value two = createFloatConst(op->getLoc(), opType, 2.0, rewriter); 88 Value div = b.create<arith::DivFOp>(add, two); 89 rewriter.replaceOp(op, div); 90 return success(); 91 } 92 93 /// Expands tanh op into 94 /// 1) 1-exp^{-2x} / 1+exp^{-2x}, if x => 0 95 /// 2) exp^{2x}-1 / exp^{2x}+1 , if x < 0 96 static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) { 97 auto floatType = op.getOperand().getType(); 98 Location loc = op.getLoc(); 99 Value one = createFloatConst(loc, floatType, 1.0, rewriter); 100 Value two = createFloatConst(loc, floatType, 2.0, rewriter); 101 Value doubledX = rewriter.create<arith::MulFOp>(loc, op.getOperand(), two); 102 103 // Case 1: tanh(x) = 1-exp^{-2x} / 1+exp^{-2x} 104 Value negDoubledX = rewriter.create<arith::NegFOp>(loc, doubledX); 105 Value exp2x = rewriter.create<math::ExpOp>(loc, negDoubledX); 106 Value dividend = rewriter.create<arith::SubFOp>(loc, one, exp2x); 107 Value divisor = rewriter.create<arith::AddFOp>(loc, one, exp2x); 108 Value positiveRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor); 109 110 // Case 2: tanh(x) = exp^{2x}-1 / exp^{2x}+1 111 exp2x = rewriter.create<math::ExpOp>(loc, doubledX); 112 dividend = rewriter.create<arith::SubFOp>(loc, exp2x, one); 113 divisor = rewriter.create<arith::AddFOp>(loc, exp2x, one); 114 Value negativeRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor); 115 116 // tanh(x) = x >= 0 ? positiveRes : negativeRes 117 Value zero = createFloatConst(loc, floatType, 0.0, rewriter); 118 Value cmpRes = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE, 119 op.getOperand(), zero); 120 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmpRes, positiveRes, 121 negativeRes); 122 return success(); 123 } 124 125 // Converts math.tan to math.sin, math.cos, and arith.divf. 126 static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) { 127 ImplicitLocOpBuilder b(op->getLoc(), rewriter); 128 Value operand = op.getOperand(); 129 Type type = operand.getType(); 130 Value sin = b.create<math::SinOp>(type, operand); 131 Value cos = b.create<math::CosOp>(type, operand); 132 Value div = b.create<arith::DivFOp>(type, sin, cos); 133 rewriter.replaceOp(op, div); 134 return success(); 135 } 136 137 static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) { 138 ImplicitLocOpBuilder b(op->getLoc(), rewriter); 139 Value operandA = op.getOperand(0); 140 Value operandB = op.getOperand(1); 141 Value operandC = op.getOperand(2); 142 Type type = op.getType(); 143 Value mult = b.create<arith::MulFOp>(type, operandA, operandB); 144 Value add = b.create<arith::AddFOp>(type, mult, operandC); 145 rewriter.replaceOp(op, add); 146 return success(); 147 } 148 149 // Converts a floorf() function to the following: 150 // floorf(float x) -> 151 // y = (float)(int) x 152 // if (x < 0) then incr = -1 else incr = 0 153 // y = y + incr <= replace this op with the floorf op. 154 static LogicalResult convertFloorOp(math::FloorOp op, 155 PatternRewriter &rewriter) { 156 ImplicitLocOpBuilder b(op->getLoc(), rewriter); 157 Value operand = op.getOperand(); 158 Type opType = operand.getType(); 159 Value fpFixedConvert = createTruncatedFPValue(operand, b); 160 161 // Creating constants for later use. 162 Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter); 163 Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter); 164 165 Value negCheck = 166 b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero); 167 Value incrValue = 168 b.create<arith::SelectOp>(op->getLoc(), negCheck, negOne, zero); 169 Value ret = b.create<arith::AddFOp>(opType, fpFixedConvert, incrValue); 170 rewriter.replaceOp(op, ret); 171 return success(); 172 } 173 174 // Converts a ceilf() function to the following: 175 // ceilf(float x) -> 176 // y = (float)(int) x 177 // if (x > y) then incr = 1 else incr = 0 178 // y = y + incr <= replace this op with the ceilf op. 179 static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) { 180 ImplicitLocOpBuilder b(op->getLoc(), rewriter); 181 Value operand = op.getOperand(); 182 Type opType = operand.getType(); 183 Value fpFixedConvert = createTruncatedFPValue(operand, b); 184 185 // Creating constants for later use. 186 Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter); 187 Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter); 188 189 Value gtCheck = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, operand, 190 fpFixedConvert); 191 Value incrValue = b.create<arith::SelectOp>(op->getLoc(), gtCheck, one, zero); 192 193 Value ret = b.create<arith::AddFOp>(opType, fpFixedConvert, incrValue); 194 rewriter.replaceOp(op, ret); 195 return success(); 196 } 197 // Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a)) 198 static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) { 199 ImplicitLocOpBuilder b(op->getLoc(), rewriter); 200 Value operandA = op.getOperand(0); 201 Value operandB = op.getOperand(1); 202 Type opType = operandA.getType(); 203 Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter); 204 Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter); 205 Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter); 206 Value opASquared = b.create<arith::MulFOp>(opType, operandA, operandA); 207 Value opBHalf = b.create<arith::DivFOp>(opType, operandB, two); 208 209 Value logA = b.create<math::LogOp>(opType, opASquared); 210 Value mult = b.create<arith::MulFOp>(opType, opBHalf, logA); 211 Value expResult = b.create<math::ExpOp>(opType, mult); 212 Value negExpResult = b.create<arith::MulFOp>(opType, expResult, negOne); 213 Value remainder = b.create<arith::RemFOp>(opType, operandB, two); 214 Value negCheck = 215 b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero); 216 Value oddPower = 217 b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero); 218 Value oddAndNeg = b.create<arith::AndIOp>(op->getLoc(), oddPower, negCheck); 219 220 Value res = b.create<arith::SelectOp>(op->getLoc(), oddAndNeg, negExpResult, 221 expResult); 222 rewriter.replaceOp(op, res); 223 return success(); 224 } 225 226 // exp2f(float x) -> exp(x * ln(2)) 227 // Proof: Let's say 2^x = y 228 // ln(2^x) = ln(y) 229 // x * ln(2) = ln(y) => e ^(x*ln(2)) = y 230 static LogicalResult convertExp2fOp(math::Exp2Op op, 231 PatternRewriter &rewriter) { 232 ImplicitLocOpBuilder b(op->getLoc(), rewriter); 233 Value operand = op.getOperand(); 234 Type opType = operand.getType(); 235 Value ln2 = createFloatConst(op->getLoc(), opType, llvm::numbers::ln2, b); 236 Value mult = b.create<arith::MulFOp>(opType, operand, ln2); 237 Value exp = b.create<math::ExpOp>(op->getLoc(), mult); 238 rewriter.replaceOp(op, exp); 239 return success(); 240 } 241 242 static LogicalResult convertRoundOp(math::RoundOp op, 243 PatternRewriter &rewriter) { 244 Location loc = op.getLoc(); 245 ImplicitLocOpBuilder b(loc, rewriter); 246 Value operand = op.getOperand(); 247 Type opType = operand.getType(); 248 Type opEType = getElementTypeOrSelf(opType); 249 250 if (!opEType.isF32()) { 251 return rewriter.notifyMatchFailure(op, "not a round of f32."); 252 } 253 254 Type i32Ty = b.getI32Type(); 255 if (auto shapedTy = dyn_cast<ShapedType>(opType)) 256 i32Ty = shapedTy.clone(i32Ty); 257 258 Value half = createFloatConst(loc, opType, 0.5, b); 259 Value c23 = createIntConst(loc, i32Ty, 23, b); 260 Value c127 = createIntConst(loc, i32Ty, 127, b); 261 Value expMask = createIntConst(loc, i32Ty, (1 << 8) - 1, b); 262 263 Value incrValue = b.create<math::CopySignOp>(half, operand); 264 Value add = b.create<arith::AddFOp>(opType, operand, incrValue); 265 Value fpFixedConvert = createTruncatedFPValue(add, b); 266 267 // There are three cases where adding 0.5 to the value and truncating by 268 // converting to an i64 does not result in the correct behavior: 269 // 270 // 1. Special values: +-inf and +-nan 271 // Casting these special values to i64 has undefined behavior. To identify 272 // these values, we use the fact that these values are the only float 273 // values with the maximum possible biased exponent. 274 // 275 // 2. Large values: 2^23 <= |x| <= INT_64_MAX 276 // Adding 0.5 to a float larger than or equal to 2^23 results in precision 277 // errors that sometimes round the value up and sometimes round the value 278 // down. For example: 279 // 8388608.0 + 0.5 = 8388608.0 280 // 8388609.0 + 0.5 = 8388610.0 281 // 282 // 3. Very large values: |x| > INT_64_MAX 283 // Casting to i64 a value greater than the max i64 value will overflow the 284 // i64 leading to wrong outputs. 285 // 286 // All three cases satisfy the property `biasedExp >= 23`. 287 Value operandBitcast = b.create<arith::BitcastOp>(i32Ty, operand); 288 Value operandExp = b.create<arith::AndIOp>( 289 b.create<arith::ShRUIOp>(operandBitcast, c23), expMask); 290 Value operandBiasedExp = b.create<arith::SubIOp>(operandExp, c127); 291 Value isSpecialValOrLargeVal = 292 b.create<arith::CmpIOp>(arith::CmpIPredicate::sge, operandBiasedExp, c23); 293 294 Value result = b.create<arith::SelectOp>(isSpecialValOrLargeVal, operand, 295 fpFixedConvert); 296 rewriter.replaceOp(op, result); 297 return success(); 298 } 299 300 // Converts math.ctlz to scf and arith operations. This is done 301 // by performing a binary search on the bits. 302 static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op, 303 PatternRewriter &rewriter) { 304 auto operand = op.getOperand(); 305 auto operandTy = operand.getType(); 306 auto eTy = getElementTypeOrSelf(operandTy); 307 Location loc = op.getLoc(); 308 309 int32_t bitwidth = eTy.getIntOrFloatBitWidth(); 310 if (bitwidth > 64) 311 return failure(); 312 313 uint64_t allbits = -1; 314 if (bitwidth < 64) { 315 allbits = allbits >> (64 - bitwidth); 316 } 317 318 Value x = operand; 319 Value count = createIntConst(loc, operandTy, 0, rewriter); 320 for (int32_t bw = bitwidth; bw > 1; bw = bw / 2) { 321 auto half = bw / 2; 322 auto bits = createIntConst(loc, operandTy, half, rewriter); 323 auto mask = createIntConst(loc, operandTy, allbits >> half, rewriter); 324 325 Value pred = 326 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule, x, mask); 327 Value add = rewriter.create<arith::AddIOp>(loc, count, bits); 328 Value shift = rewriter.create<arith::ShLIOp>(loc, x, bits); 329 330 x = rewriter.create<arith::SelectOp>(loc, pred, shift, x); 331 count = rewriter.create<arith::SelectOp>(loc, pred, add, count); 332 } 333 334 Value zero = createIntConst(loc, operandTy, 0, rewriter); 335 Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, 336 operand, zero); 337 338 Value bwval = createIntConst(loc, operandTy, bitwidth, rewriter); 339 Value sel = rewriter.create<arith::SelectOp>(loc, pred, bwval, count); 340 rewriter.replaceOp(op, sel); 341 return success(); 342 } 343 344 // Convert `math.roundeven` into `math.round` + arith ops 345 static LogicalResult convertRoundEvenOp(math::RoundEvenOp op, 346 PatternRewriter &rewriter) { 347 Location loc = op.getLoc(); 348 ImplicitLocOpBuilder b(loc, rewriter); 349 auto operand = op.getOperand(); 350 Type operandTy = operand.getType(); 351 Type resultTy = op.getType(); 352 Type operandETy = getElementTypeOrSelf(operandTy); 353 Type resultETy = getElementTypeOrSelf(resultTy); 354 355 if (!isa<FloatType>(operandETy) || !isa<FloatType>(resultETy)) { 356 return rewriter.notifyMatchFailure(op, "not a roundeven of f16 or f32."); 357 } 358 359 Type fTy = operandTy; 360 Type iTy = rewriter.getIntegerType(operandETy.getIntOrFloatBitWidth()); 361 if (auto shapedTy = dyn_cast<ShapedType>(fTy)) { 362 iTy = shapedTy.clone(iTy); 363 } 364 365 unsigned bitWidth = operandETy.getIntOrFloatBitWidth(); 366 // The width returned by getFPMantissaWidth includes the integer bit. 367 unsigned mantissaWidth = 368 llvm::cast<FloatType>(operandETy).getFPMantissaWidth() - 1; 369 unsigned exponentWidth = bitWidth - mantissaWidth - 1; 370 371 // The names of the variables correspond to f32. 372 // f64: 1 bit sign | 11 bits exponent | 52 bits mantissa. 373 // f32: 1 bit sign | 8 bits exponent | 23 bits mantissa. 374 // f16: 1 bit sign | 5 bits exponent | 10 bits mantissa. 375 Value c1Float = createFloatConst(loc, fTy, 1.0, b); 376 Value c0 = createIntConst(loc, iTy, 0, b); 377 Value c1 = createIntConst(loc, iTy, 1, b); 378 Value cNeg1 = createIntConst(loc, iTy, -1, b); 379 Value c23 = createIntConst(loc, iTy, mantissaWidth, b); 380 Value c31 = createIntConst(loc, iTy, bitWidth - 1, b); 381 Value c127 = createIntConst(loc, iTy, (1ull << (exponentWidth - 1)) - 1, b); 382 Value c2To22 = createIntConst(loc, iTy, 1ull << (mantissaWidth - 1), b); 383 Value c23Mask = createIntConst(loc, iTy, (1ull << mantissaWidth) - 1, b); 384 Value expMask = createIntConst(loc, iTy, (1ull << exponentWidth) - 1, b); 385 386 Value operandBitcast = b.create<arith::BitcastOp>(iTy, operand); 387 Value round = b.create<math::RoundOp>(operand); 388 Value roundBitcast = b.create<arith::BitcastOp>(iTy, round); 389 390 // Get biased exponents for operand and round(operand) 391 Value operandExp = b.create<arith::AndIOp>( 392 b.create<arith::ShRUIOp>(operandBitcast, c23), expMask); 393 Value operandBiasedExp = b.create<arith::SubIOp>(operandExp, c127); 394 Value roundExp = b.create<arith::AndIOp>( 395 b.create<arith::ShRUIOp>(roundBitcast, c23), expMask); 396 Value roundBiasedExp = b.create<arith::SubIOp>(roundExp, c127); 397 398 auto safeShiftRight = [&](Value x, Value shift) -> Value { 399 // Clamp shift to valid range [0, bitwidth - 1] to avoid undefined behavior 400 Value clampedShift = b.create<arith::MaxSIOp>(shift, c0); 401 clampedShift = b.create<arith::MinSIOp>(clampedShift, c31); 402 return b.create<arith::ShRUIOp>(x, clampedShift); 403 }; 404 405 auto maskMantissa = [&](Value mantissa, 406 Value mantissaMaskRightShift) -> Value { 407 Value shiftedMantissaMask = safeShiftRight(c23Mask, mantissaMaskRightShift); 408 return b.create<arith::AndIOp>(mantissa, shiftedMantissaMask); 409 }; 410 411 // A whole number `x`, such that `|x| != 1`, is even if the mantissa, ignoring 412 // the leftmost `clamp(biasedExp - 1, 0, 23)` bits, is zero. Large numbers 413 // with `biasedExp > 23` (numbers where there is not enough precision to store 414 // decimals) are always even, and they satisfy the even condition trivially 415 // since the mantissa without all its bits is zero. The even condition 416 // is also true for +-0, since they have `biasedExp = -127` and the entire 417 // mantissa is zero. The case of +-1 has to be handled separately. Here 418 // we identify these values by noting that +-1 are the only whole numbers with 419 // `biasedExp == 0`. 420 // 421 // The special values +-inf and +-nan also satisfy the same property that 422 // whole non-unit even numbers satisfy. In particular, the special values have 423 // `biasedExp > 23`, so they get treated as large numbers with no room for 424 // decimals, which are always even. 425 Value roundBiasedExpEq0 = 426 b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, roundBiasedExp, c0); 427 Value roundBiasedExpMinus1 = b.create<arith::SubIOp>(roundBiasedExp, c1); 428 Value roundMaskedMantissa = maskMantissa(roundBitcast, roundBiasedExpMinus1); 429 Value roundIsNotEvenOrSpecialVal = b.create<arith::CmpIOp>( 430 arith::CmpIPredicate::ne, roundMaskedMantissa, c0); 431 roundIsNotEvenOrSpecialVal = 432 b.create<arith::OrIOp>(roundIsNotEvenOrSpecialVal, roundBiasedExpEq0); 433 434 // A value `x` with `0 <= biasedExp < 23`, is halfway between two consecutive 435 // integers if the bit at index `biasedExp` starting from the left in the 436 // mantissa is 1 and all the bits to the right are zero. Values with 437 // `biasedExp >= 23` don't have decimals, so they are never halfway. The 438 // values +-0.5 are the only halfway values that have `biasedExp == -1 < 0`, 439 // so these are handled separately. In particular, if `biasedExp == -1`, the 440 // value is halfway if the entire mantissa is zero. 441 Value operandBiasedExpEqNeg1 = b.create<arith::CmpIOp>( 442 arith::CmpIPredicate::eq, operandBiasedExp, cNeg1); 443 Value expectedOperandMaskedMantissa = b.create<arith::SelectOp>( 444 operandBiasedExpEqNeg1, c0, safeShiftRight(c2To22, operandBiasedExp)); 445 Value operandMaskedMantissa = maskMantissa(operandBitcast, operandBiasedExp); 446 Value operandIsHalfway = 447 b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, operandMaskedMantissa, 448 expectedOperandMaskedMantissa); 449 // Ensure `biasedExp` is in the valid range for half values. 450 Value operandBiasedExpGeNeg1 = b.create<arith::CmpIOp>( 451 arith::CmpIPredicate::sge, operandBiasedExp, cNeg1); 452 Value operandBiasedExpLt23 = 453 b.create<arith::CmpIOp>(arith::CmpIPredicate::slt, operandBiasedExp, c23); 454 operandIsHalfway = 455 b.create<arith::AndIOp>(operandIsHalfway, operandBiasedExpLt23); 456 operandIsHalfway = 457 b.create<arith::AndIOp>(operandIsHalfway, operandBiasedExpGeNeg1); 458 459 // Adjust rounded operand with `round(operand) - sign(operand)` to correct the 460 // case where `round` rounded in the opposite direction of `roundeven`. 461 Value sign = b.create<math::CopySignOp>(c1Float, operand); 462 Value roundShifted = b.create<arith::SubFOp>(round, sign); 463 // If the rounded value is even or a special value, we default to the behavior 464 // of `math.round`. 465 Value needsShift = 466 b.create<arith::AndIOp>(roundIsNotEvenOrSpecialVal, operandIsHalfway); 467 Value result = b.create<arith::SelectOp>(needsShift, roundShifted, round); 468 // The `x - sign` adjustment does not preserve the sign when we are adjusting 469 // the value -1 to -0. So here the sign is copied again to ensure that -0.5 is 470 // rounded to -0.0. 471 result = b.create<math::CopySignOp>(result, operand); 472 rewriter.replaceOp(op, result); 473 return success(); 474 } 475 476 void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) { 477 patterns.add(convertCtlzOp); 478 } 479 480 void mlir::populateExpandSinhPattern(RewritePatternSet &patterns) { 481 patterns.add(convertSinhOp); 482 } 483 484 void mlir::populateExpandCoshPattern(RewritePatternSet &patterns) { 485 patterns.add(convertCoshOp); 486 } 487 488 void mlir::populateExpandTanPattern(RewritePatternSet &patterns) { 489 patterns.add(convertTanOp); 490 } 491 492 void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) { 493 patterns.add(convertTanhOp); 494 } 495 496 void mlir::populateExpandFmaFPattern(RewritePatternSet &patterns) { 497 patterns.add(convertFmaFOp); 498 } 499 500 void mlir::populateExpandCeilFPattern(RewritePatternSet &patterns) { 501 patterns.add(convertCeilOp); 502 } 503 504 void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) { 505 patterns.add(convertExp2fOp); 506 } 507 508 void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) { 509 patterns.add(convertPowfOp); 510 } 511 512 void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) { 513 patterns.add(convertRoundOp); 514 } 515 516 void mlir::populateExpandFloorFPattern(RewritePatternSet &patterns) { 517 patterns.add(convertFloorOp); 518 } 519 520 void mlir::populateExpandRoundEvenPattern(RewritePatternSet &patterns) { 521 patterns.add(convertRoundEvenOp); 522 } 523