110a57f3aSPrashant Kumar //===- ExpandPatterns.cpp - Code to expand various math operations. -------===// 2f3bdb56dSRob Suderman // 3f3bdb56dSRob Suderman // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4f3bdb56dSRob Suderman // See https://llvm.org/LICENSE.txt for license information. 5f3bdb56dSRob Suderman // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6f3bdb56dSRob Suderman // 7f3bdb56dSRob Suderman //===----------------------------------------------------------------------===// 8f3bdb56dSRob Suderman // 910a57f3aSPrashant Kumar // This file implements expansion of various math operations. 10f3bdb56dSRob Suderman // 11f3bdb56dSRob Suderman //===----------------------------------------------------------------------===// 12f3bdb56dSRob Suderman 13abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h" 14f3bdb56dSRob Suderman #include "mlir/Dialect/Math/IR/Math.h" 15f3bdb56dSRob Suderman #include "mlir/Dialect/Math/Transforms/Passes.h" 168b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h" 17711c5893SRobert Suderman #include "mlir/Dialect/Vector/IR/VectorOps.h" 18f3bdb56dSRob Suderman #include "mlir/IR/Builders.h" 19740e2e90SRobert Suderman #include "mlir/IR/ImplicitLocOpBuilder.h" 20711c5893SRobert Suderman #include "mlir/IR/TypeUtilities.h" 21f3bdb56dSRob Suderman #include "mlir/Transforms/DialectConversion.h" 22f3bdb56dSRob Suderman 23f3bdb56dSRob Suderman using namespace mlir; 24f3bdb56dSRob Suderman 25711c5893SRobert Suderman /// Create a float constant. 2610a57f3aSPrashant Kumar static Value createFloatConst(Location loc, Type type, APFloat value, 27711c5893SRobert Suderman OpBuilder &b) { 2810a57f3aSPrashant Kumar bool losesInfo = false; 2910a57f3aSPrashant Kumar auto eltType = getElementTypeOrSelf(type); 3010a57f3aSPrashant Kumar // Convert double to the given `FloatType` with round-to-nearest-ties-to-even. 3110a57f3aSPrashant Kumar value.convert(cast<FloatType>(eltType).getFloatSemantics(), 3210a57f3aSPrashant Kumar APFloat::rmNearestTiesToEven, &losesInfo); 3310a57f3aSPrashant Kumar auto attr = b.getFloatAttr(eltType, value); 34711c5893SRobert Suderman if (auto shapedTy = dyn_cast<ShapedType>(type)) { 35711c5893SRobert Suderman return b.create<arith::ConstantOp>(loc, 36711c5893SRobert Suderman DenseElementsAttr::get(shapedTy, attr)); 37711c5893SRobert Suderman } 38711c5893SRobert Suderman 39711c5893SRobert Suderman return b.create<arith::ConstantOp>(loc, attr); 40711c5893SRobert Suderman } 41711c5893SRobert Suderman 4210a57f3aSPrashant Kumar static Value createFloatConst(Location loc, Type type, double value, 4310a57f3aSPrashant Kumar OpBuilder &b) { 4410a57f3aSPrashant Kumar return createFloatConst(loc, type, APFloat(value), b); 4510a57f3aSPrashant Kumar } 4610a57f3aSPrashant Kumar 4710a57f3aSPrashant Kumar /// Create an integer constant. 48711c5893SRobert Suderman static Value createIntConst(Location loc, Type type, int64_t value, 49711c5893SRobert Suderman OpBuilder &b) { 50711c5893SRobert Suderman auto attr = b.getIntegerAttr(getElementTypeOrSelf(type), value); 51711c5893SRobert Suderman if (auto shapedTy = dyn_cast<ShapedType>(type)) { 52711c5893SRobert Suderman return b.create<arith::ConstantOp>(loc, 53711c5893SRobert Suderman DenseElementsAttr::get(shapedTy, attr)); 54711c5893SRobert Suderman } 55711c5893SRobert Suderman 56711c5893SRobert Suderman return b.create<arith::ConstantOp>(loc, attr); 57711c5893SRobert Suderman } 58711c5893SRobert Suderman 592217888dSBalaji V. Iyer static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b) { 602217888dSBalaji V. Iyer Type opType = operand.getType(); 6144baa655SRamiro Leal-Cavazos Type i64Ty = b.getI64Type(); 6244baa655SRamiro Leal-Cavazos if (auto shapedTy = dyn_cast<ShapedType>(opType)) 6344baa655SRamiro Leal-Cavazos i64Ty = shapedTy.clone(i64Ty); 6444baa655SRamiro Leal-Cavazos Value fixedConvert = b.create<arith::FPToSIOp>(i64Ty, operand); 652217888dSBalaji V. Iyer Value fpFixedConvert = b.create<arith::SIToFPOp>(opType, fixedConvert); 6644baa655SRamiro Leal-Cavazos // The truncation does not preserve the sign when the truncated 6744baa655SRamiro Leal-Cavazos // value is -0. So here the sign is copied again. 6844baa655SRamiro Leal-Cavazos return b.create<math::CopySignOp>(fpFixedConvert, operand); 692217888dSBalaji V. Iyer } 702217888dSBalaji V. Iyer 71aa165edcSRob Suderman // sinhf(float x) -> (exp(x) - exp(-x)) / 2 72aa165edcSRob Suderman static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter) { 73aa165edcSRob Suderman ImplicitLocOpBuilder b(op->getLoc(), rewriter); 74aa165edcSRob Suderman Value operand = op.getOperand(); 75aa165edcSRob Suderman Type opType = operand.getType(); 76aa165edcSRob Suderman 77a62a7024Sjinchen Value exp = b.create<math::ExpOp>(operand); 78a62a7024Sjinchen Value neg = b.create<arith::NegFOp>(operand); 79a62a7024Sjinchen Value nexp = b.create<math::ExpOp>(neg); 80aa165edcSRob Suderman Value sub = b.create<arith::SubFOp>(exp, nexp); 81a62a7024Sjinchen Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter); 82a62a7024Sjinchen Value res = b.create<arith::MulFOp>(sub, half); 83a62a7024Sjinchen rewriter.replaceOp(op, res); 84aa165edcSRob Suderman return success(); 85aa165edcSRob Suderman } 86aa165edcSRob Suderman 87aa165edcSRob Suderman // coshf(float x) -> (exp(x) + exp(-x)) / 2 88aa165edcSRob Suderman static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) { 89aa165edcSRob Suderman ImplicitLocOpBuilder b(op->getLoc(), rewriter); 90aa165edcSRob Suderman Value operand = op.getOperand(); 91aa165edcSRob Suderman Type opType = operand.getType(); 92aa165edcSRob Suderman 93a62a7024Sjinchen Value exp = b.create<math::ExpOp>(operand); 94a62a7024Sjinchen Value neg = b.create<arith::NegFOp>(operand); 95a62a7024Sjinchen Value nexp = b.create<math::ExpOp>(neg); 96aa165edcSRob Suderman Value add = b.create<arith::AddFOp>(exp, nexp); 97a62a7024Sjinchen Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter); 98a62a7024Sjinchen Value res = b.create<arith::MulFOp>(add, half); 99a62a7024Sjinchen rewriter.replaceOp(op, res); 100aa165edcSRob Suderman return success(); 101aa165edcSRob Suderman } 102aa165edcSRob Suderman 103f3bdb56dSRob Suderman /// Expands tanh op into 104d39ac3a8Ssrcarroll /// 1-exp^{-2x} / 1+exp^{-2x} 105d39ac3a8Ssrcarroll /// To avoid overflow we exploit the reflection symmetry `tanh(-x) = -tanh(x)`. 106d39ac3a8Ssrcarroll /// We compute a "signs" value which is -1 if input is negative and +1 if input 107d39ac3a8Ssrcarroll /// is positive. Then multiply the input by this value, guaranteeing that the 108d39ac3a8Ssrcarroll /// result is positive, which also guarantees `exp^{-2x * sign(x)}` is in (0, 109d39ac3a8Ssrcarroll /// 1]. Expand the computation on the input `x * sign(x)`, then multiply the 110d39ac3a8Ssrcarroll /// result by `sign(x)` to retain sign of the real result. 111f3bdb56dSRob Suderman static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) { 112f3bdb56dSRob Suderman auto floatType = op.getOperand().getType(); 113f3bdb56dSRob Suderman Location loc = op.getLoc(); 114d39ac3a8Ssrcarroll Value zero = createFloatConst(loc, floatType, 0.0, rewriter); 115711c5893SRobert Suderman Value one = createFloatConst(loc, floatType, 1.0, rewriter); 116d39ac3a8Ssrcarroll Value negTwo = createFloatConst(loc, floatType, -2.0, rewriter); 117f3bdb56dSRob Suderman 118d39ac3a8Ssrcarroll // Compute sign(x) = cast<float_type>(x < 0) * (-2) + 1 119d39ac3a8Ssrcarroll Value isNegative = rewriter.create<arith::CmpFOp>( 120d39ac3a8Ssrcarroll loc, arith::CmpFPredicate::OLT, op.getOperand(), zero); 121d39ac3a8Ssrcarroll Value isNegativeFloat = 122d39ac3a8Ssrcarroll rewriter.create<arith::UIToFPOp>(loc, floatType, isNegative); 123d39ac3a8Ssrcarroll Value isNegativeTimesNegTwo = 124d39ac3a8Ssrcarroll rewriter.create<arith::MulFOp>(loc, isNegativeFloat, negTwo); 125d39ac3a8Ssrcarroll Value sign = rewriter.create<arith::AddFOp>(loc, isNegativeTimesNegTwo, one); 126d39ac3a8Ssrcarroll 127d39ac3a8Ssrcarroll // Normalize input to positive value: y = sign(x) * x 128d39ac3a8Ssrcarroll Value positiveX = rewriter.create<arith::MulFOp>(loc, sign, op.getOperand()); 129d39ac3a8Ssrcarroll 130d39ac3a8Ssrcarroll // Decompose on normalized input 131d39ac3a8Ssrcarroll Value negDoubledX = rewriter.create<arith::MulFOp>(loc, negTwo, positiveX); 132f3bdb56dSRob Suderman Value exp2x = rewriter.create<math::ExpOp>(loc, negDoubledX); 133f3bdb56dSRob Suderman Value dividend = rewriter.create<arith::SubFOp>(loc, one, exp2x); 134f3bdb56dSRob Suderman Value divisor = rewriter.create<arith::AddFOp>(loc, one, exp2x); 135f3bdb56dSRob Suderman Value positiveRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor); 136f3bdb56dSRob Suderman 137d39ac3a8Ssrcarroll // Multiply result by sign(x) to retain signs from negative inputs 138d39ac3a8Ssrcarroll rewriter.replaceOpWithNewOp<arith::MulFOp>(op, sign, positiveRes); 139f3bdb56dSRob Suderman 140f3bdb56dSRob Suderman return success(); 141f3bdb56dSRob Suderman } 142f3bdb56dSRob Suderman 143711c5893SRobert Suderman // Converts math.tan to math.sin, math.cos, and arith.divf. 144740e2e90SRobert Suderman static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) { 145740e2e90SRobert Suderman ImplicitLocOpBuilder b(op->getLoc(), rewriter); 146740e2e90SRobert Suderman Value operand = op.getOperand(); 147740e2e90SRobert Suderman Type type = operand.getType(); 148740e2e90SRobert Suderman Value sin = b.create<math::SinOp>(type, operand); 149740e2e90SRobert Suderman Value cos = b.create<math::CosOp>(type, operand); 150740e2e90SRobert Suderman Value div = b.create<arith::DivFOp>(type, sin, cos); 151740e2e90SRobert Suderman rewriter.replaceOp(op, div); 152740e2e90SRobert Suderman return success(); 153740e2e90SRobert Suderman } 154740e2e90SRobert Suderman 155a62a7024Sjinchen // asinh(float x) -> log(x + sqrt(x**2 + 1)) 156a62a7024Sjinchen static LogicalResult convertAsinhOp(math::AsinhOp op, 157a62a7024Sjinchen PatternRewriter &rewriter) { 158a62a7024Sjinchen ImplicitLocOpBuilder b(op->getLoc(), rewriter); 159a62a7024Sjinchen Value operand = op.getOperand(); 160a62a7024Sjinchen Type opType = operand.getType(); 161a62a7024Sjinchen 162a62a7024Sjinchen Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter); 163a62a7024Sjinchen Value fma = b.create<math::FmaOp>(operand, operand, one); 164a62a7024Sjinchen Value sqrt = b.create<math::SqrtOp>(fma); 165a62a7024Sjinchen Value add = b.create<arith::AddFOp>(operand, sqrt); 166a62a7024Sjinchen Value res = b.create<math::LogOp>(add); 167a62a7024Sjinchen rewriter.replaceOp(op, res); 168a62a7024Sjinchen return success(); 169a62a7024Sjinchen } 170a62a7024Sjinchen 171a62a7024Sjinchen // acosh(float x) -> log(x + sqrt(x**2 - 1)) 172a62a7024Sjinchen static LogicalResult convertAcoshOp(math::AcoshOp op, 173a62a7024Sjinchen PatternRewriter &rewriter) { 174a62a7024Sjinchen ImplicitLocOpBuilder b(op->getLoc(), rewriter); 175a62a7024Sjinchen Value operand = op.getOperand(); 176a62a7024Sjinchen Type opType = operand.getType(); 177a62a7024Sjinchen 178a62a7024Sjinchen Value negOne = createFloatConst(op->getLoc(), opType, -1.0, rewriter); 179a62a7024Sjinchen Value fma = b.create<math::FmaOp>(operand, operand, negOne); 180a62a7024Sjinchen Value sqrt = b.create<math::SqrtOp>(fma); 181a62a7024Sjinchen Value add = b.create<arith::AddFOp>(operand, sqrt); 182a62a7024Sjinchen Value res = b.create<math::LogOp>(add); 183a62a7024Sjinchen rewriter.replaceOp(op, res); 184a62a7024Sjinchen return success(); 185a62a7024Sjinchen } 186a62a7024Sjinchen 187a62a7024Sjinchen // atanh(float x) -> log((1 + x) / (1 - x)) / 2 188a62a7024Sjinchen static LogicalResult convertAtanhOp(math::AtanhOp op, 189a62a7024Sjinchen PatternRewriter &rewriter) { 190a62a7024Sjinchen ImplicitLocOpBuilder b(op->getLoc(), rewriter); 191a62a7024Sjinchen Value operand = op.getOperand(); 192a62a7024Sjinchen Type opType = operand.getType(); 193a62a7024Sjinchen 194a62a7024Sjinchen Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter); 195a62a7024Sjinchen Value add = b.create<arith::AddFOp>(operand, one); 196a62a7024Sjinchen Value neg = b.create<arith::NegFOp>(operand); 197a62a7024Sjinchen Value sub = b.create<arith::AddFOp>(neg, one); 198a62a7024Sjinchen Value div = b.create<arith::DivFOp>(add, sub); 199a62a7024Sjinchen Value log = b.create<math::LogOp>(div); 200a62a7024Sjinchen Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter); 201a62a7024Sjinchen Value res = b.create<arith::MulFOp>(log, half); 202a62a7024Sjinchen rewriter.replaceOp(op, res); 203a62a7024Sjinchen return success(); 204a62a7024Sjinchen } 205a62a7024Sjinchen 206a7c2102dSBalaji V. Iyer static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) { 207a7c2102dSBalaji V. Iyer ImplicitLocOpBuilder b(op->getLoc(), rewriter); 208a7c2102dSBalaji V. Iyer Value operandA = op.getOperand(0); 209a7c2102dSBalaji V. Iyer Value operandB = op.getOperand(1); 210a7c2102dSBalaji V. Iyer Value operandC = op.getOperand(2); 211a7c2102dSBalaji V. Iyer Type type = op.getType(); 212a7c2102dSBalaji V. Iyer Value mult = b.create<arith::MulFOp>(type, operandA, operandB); 213a7c2102dSBalaji V. Iyer Value add = b.create<arith::AddFOp>(type, mult, operandC); 214a7c2102dSBalaji V. Iyer rewriter.replaceOp(op, add); 215a7c2102dSBalaji V. Iyer return success(); 216a7c2102dSBalaji V. Iyer } 217a7c2102dSBalaji V. Iyer 2182217888dSBalaji V. Iyer // Converts a ceilf() function to the following: 2192217888dSBalaji V. Iyer // ceilf(float x) -> 2202217888dSBalaji V. Iyer // y = (float)(int) x 2212217888dSBalaji V. Iyer // if (x > y) then incr = 1 else incr = 0 2222217888dSBalaji V. Iyer // y = y + incr <= replace this op with the ceilf op. 2232217888dSBalaji V. Iyer static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) { 2242217888dSBalaji V. Iyer ImplicitLocOpBuilder b(op->getLoc(), rewriter); 2252217888dSBalaji V. Iyer Value operand = op.getOperand(); 2262217888dSBalaji V. Iyer Type opType = operand.getType(); 2272217888dSBalaji V. Iyer Value fpFixedConvert = createTruncatedFPValue(operand, b); 2282217888dSBalaji V. Iyer 2292217888dSBalaji V. Iyer // Creating constants for later use. 2302217888dSBalaji V. Iyer Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter); 2312217888dSBalaji V. Iyer Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter); 2322217888dSBalaji V. Iyer 2332217888dSBalaji V. Iyer Value gtCheck = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, operand, 2342217888dSBalaji V. Iyer fpFixedConvert); 2352217888dSBalaji V. Iyer Value incrValue = b.create<arith::SelectOp>(op->getLoc(), gtCheck, one, zero); 2362217888dSBalaji V. Iyer 2372217888dSBalaji V. Iyer Value ret = b.create<arith::AddFOp>(opType, fpFixedConvert, incrValue); 2382217888dSBalaji V. Iyer rewriter.replaceOp(op, ret); 2392217888dSBalaji V. Iyer return success(); 2402217888dSBalaji V. Iyer } 24110a57f3aSPrashant Kumar 24210a57f3aSPrashant Kumar // Convert `math.fpowi` to a series of `arith.mulf` operations. 24310a57f3aSPrashant Kumar // If the power is negative, we divide one by the result. 24410a57f3aSPrashant Kumar // If both the base and power are zero, the result is 1. 2455b702be1SPrashant Kumar // In the case of non constant power, we convert the operation to `math.powf`. 2465b702be1SPrashant Kumar static LogicalResult convertFPowIOp(math::FPowIOp op, 24710a57f3aSPrashant Kumar PatternRewriter &rewriter) { 24810a57f3aSPrashant Kumar ImplicitLocOpBuilder b(op->getLoc(), rewriter); 24910a57f3aSPrashant Kumar Value base = op.getOperand(0); 25010a57f3aSPrashant Kumar Value power = op.getOperand(1); 25110a57f3aSPrashant Kumar Type baseType = base.getType(); 25210a57f3aSPrashant Kumar 2535b702be1SPrashant Kumar auto convertFPowItoPowf = [&]() -> LogicalResult { 2545b702be1SPrashant Kumar Value castPowerToFp = 2555b702be1SPrashant Kumar rewriter.create<arith::SIToFPOp>(op.getLoc(), baseType, power); 2565b702be1SPrashant Kumar Value res = rewriter.create<math::PowFOp>(op.getLoc(), baseType, base, 2575b702be1SPrashant Kumar castPowerToFp); 2585b702be1SPrashant Kumar rewriter.replaceOp(op, res); 2595b702be1SPrashant Kumar return success(); 2605b702be1SPrashant Kumar }; 2615b702be1SPrashant Kumar 26210a57f3aSPrashant Kumar Attribute cstAttr; 26310a57f3aSPrashant Kumar if (!matchPattern(power, m_Constant(&cstAttr))) 2645b702be1SPrashant Kumar return convertFPowItoPowf(); 26510a57f3aSPrashant Kumar 26610a57f3aSPrashant Kumar APInt value; 26710a57f3aSPrashant Kumar if (!matchPattern(cstAttr, m_ConstantInt(&value))) 2685b702be1SPrashant Kumar return convertFPowItoPowf(); 26910a57f3aSPrashant Kumar 27010a57f3aSPrashant Kumar int64_t powerInt = value.getSExtValue(); 27110a57f3aSPrashant Kumar bool isNegative = powerInt < 0; 27210a57f3aSPrashant Kumar int64_t absPower = std::abs(powerInt); 27310a57f3aSPrashant Kumar Value one = createFloatConst(op->getLoc(), baseType, 1.00, rewriter); 27410a57f3aSPrashant Kumar Value res = createFloatConst(op->getLoc(), baseType, 1.00, rewriter); 27510a57f3aSPrashant Kumar 27610a57f3aSPrashant Kumar while (absPower > 0) { 27710a57f3aSPrashant Kumar if (absPower & 1) 27810a57f3aSPrashant Kumar res = b.create<arith::MulFOp>(baseType, base, res); 27910a57f3aSPrashant Kumar absPower >>= 1; 28010a57f3aSPrashant Kumar base = b.create<arith::MulFOp>(baseType, base, base); 28110a57f3aSPrashant Kumar } 28210a57f3aSPrashant Kumar 28310a57f3aSPrashant Kumar // Make sure not to introduce UB in case of negative power. 28410a57f3aSPrashant Kumar if (isNegative) { 28510a57f3aSPrashant Kumar auto &sem = dyn_cast<mlir::FloatType>(getElementTypeOrSelf(baseType)) 28610a57f3aSPrashant Kumar .getFloatSemantics(); 28710a57f3aSPrashant Kumar Value zero = 28810a57f3aSPrashant Kumar createFloatConst(op->getLoc(), baseType, 28910a57f3aSPrashant Kumar APFloat::getZero(sem, /*Negative=*/false), rewriter); 29010a57f3aSPrashant Kumar Value negZero = 29110a57f3aSPrashant Kumar createFloatConst(op->getLoc(), baseType, 29210a57f3aSPrashant Kumar APFloat::getZero(sem, /*Negative=*/true), rewriter); 29310a57f3aSPrashant Kumar Value posInfinity = 29410a57f3aSPrashant Kumar createFloatConst(op->getLoc(), baseType, 29510a57f3aSPrashant Kumar APFloat::getInf(sem, /*Negative=*/false), rewriter); 29610a57f3aSPrashant Kumar Value negInfinity = 29710a57f3aSPrashant Kumar createFloatConst(op->getLoc(), baseType, 29810a57f3aSPrashant Kumar APFloat::getInf(sem, /*Negative=*/true), rewriter); 29910a57f3aSPrashant Kumar Value zeroEqCheck = 30010a57f3aSPrashant Kumar b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, zero); 30110a57f3aSPrashant Kumar Value negZeroEqCheck = 30210a57f3aSPrashant Kumar b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, negZero); 30310a57f3aSPrashant Kumar res = b.create<arith::DivFOp>(baseType, one, res); 30410a57f3aSPrashant Kumar res = 30510a57f3aSPrashant Kumar b.create<arith::SelectOp>(op->getLoc(), zeroEqCheck, posInfinity, res); 30610a57f3aSPrashant Kumar res = b.create<arith::SelectOp>(op->getLoc(), negZeroEqCheck, negInfinity, 30710a57f3aSPrashant Kumar res); 30810a57f3aSPrashant Kumar } 30910a57f3aSPrashant Kumar 31010a57f3aSPrashant Kumar rewriter.replaceOp(op, res); 31110a57f3aSPrashant Kumar return success(); 31210a57f3aSPrashant Kumar } 31310a57f3aSPrashant Kumar 3142d4e8567SBalaji V. Iyer // Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a)) 315*3a337757SHyunsung Lee // Restricting a >= 0 3162d4e8567SBalaji V. Iyer static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) { 3172d4e8567SBalaji V. Iyer ImplicitLocOpBuilder b(op->getLoc(), rewriter); 3182d4e8567SBalaji V. Iyer Value operandA = op.getOperand(0); 3192d4e8567SBalaji V. Iyer Value operandB = op.getOperand(1); 3202d4e8567SBalaji V. Iyer Type opType = operandA.getType(); 321f66e4bd6SBalaji V. Iyer Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter); 322a92e3df3SChristopher Bate Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter); 3232d4e8567SBalaji V. Iyer 324*3a337757SHyunsung Lee Value logA = b.create<math::LogOp>(opType, operandA); 325*3a337757SHyunsung Lee Value mult = b.create<arith::MulFOp>(opType, operandB, logA); 3262d4e8567SBalaji V. Iyer Value expResult = b.create<math::ExpOp>(opType, mult); 327f66e4bd6SBalaji V. Iyer 328a92e3df3SChristopher Bate // First, we select between the exp value and the adjusted value for odd 329a92e3df3SChristopher Bate // powers of negatives. Then, we ensure that one is produced if `b` is zero. 330a92e3df3SChristopher Bate // This corresponds to `libm` behavior, even for `0^0`. Without this check, 331a92e3df3SChristopher Bate // `exp(0 * ln(0)) = exp(0 *-inf) = exp(-nan) = -nan`. 332a92e3df3SChristopher Bate Value zeroCheck = 333a92e3df3SChristopher Bate b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero); 334*3a337757SHyunsung Lee Value finalResult = 335*3a337757SHyunsung Lee b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, expResult); 336*3a337757SHyunsung Lee rewriter.replaceOp(op, finalResult); 3372d4e8567SBalaji V. Iyer return success(); 3382d4e8567SBalaji V. Iyer } 3392217888dSBalaji V. Iyer 3404da96515SBalaji V. Iyer // exp2f(float x) -> exp(x * ln(2)) 3414da96515SBalaji V. Iyer // Proof: Let's say 2^x = y 3424da96515SBalaji V. Iyer // ln(2^x) = ln(y) 3434da96515SBalaji V. Iyer // x * ln(2) = ln(y) => e ^(x*ln(2)) = y 3444da96515SBalaji V. Iyer static LogicalResult convertExp2fOp(math::Exp2Op op, 3454da96515SBalaji V. Iyer PatternRewriter &rewriter) { 3464da96515SBalaji V. Iyer ImplicitLocOpBuilder b(op->getLoc(), rewriter); 3474da96515SBalaji V. Iyer Value operand = op.getOperand(); 3484da96515SBalaji V. Iyer Type opType = operand.getType(); 3494da96515SBalaji V. Iyer Value ln2 = createFloatConst(op->getLoc(), opType, llvm::numbers::ln2, b); 3504da96515SBalaji V. Iyer Value mult = b.create<arith::MulFOp>(opType, operand, ln2); 3514da96515SBalaji V. Iyer Value exp = b.create<math::ExpOp>(op->getLoc(), mult); 3524da96515SBalaji V. Iyer rewriter.replaceOp(op, exp); 3534da96515SBalaji V. Iyer return success(); 3544da96515SBalaji V. Iyer } 3554da96515SBalaji V. Iyer 356be911578SBalaji V. Iyer static LogicalResult convertRoundOp(math::RoundOp op, 357be911578SBalaji V. Iyer PatternRewriter &rewriter) { 35844baa655SRamiro Leal-Cavazos Location loc = op.getLoc(); 35944baa655SRamiro Leal-Cavazos ImplicitLocOpBuilder b(loc, rewriter); 360be911578SBalaji V. Iyer Value operand = op.getOperand(); 361be911578SBalaji V. Iyer Type opType = operand.getType(); 36244baa655SRamiro Leal-Cavazos Type opEType = getElementTypeOrSelf(opType); 363be911578SBalaji V. Iyer 36444baa655SRamiro Leal-Cavazos if (!opEType.isF32()) { 36544baa655SRamiro Leal-Cavazos return rewriter.notifyMatchFailure(op, "not a round of f32."); 36644baa655SRamiro Leal-Cavazos } 367be911578SBalaji V. Iyer 36844baa655SRamiro Leal-Cavazos Type i32Ty = b.getI32Type(); 36944baa655SRamiro Leal-Cavazos if (auto shapedTy = dyn_cast<ShapedType>(opType)) 37044baa655SRamiro Leal-Cavazos i32Ty = shapedTy.clone(i32Ty); 37144baa655SRamiro Leal-Cavazos 37244baa655SRamiro Leal-Cavazos Value half = createFloatConst(loc, opType, 0.5, b); 37344baa655SRamiro Leal-Cavazos Value c23 = createIntConst(loc, i32Ty, 23, b); 37444baa655SRamiro Leal-Cavazos Value c127 = createIntConst(loc, i32Ty, 127, b); 37544baa655SRamiro Leal-Cavazos Value expMask = createIntConst(loc, i32Ty, (1 << 8) - 1, b); 37644baa655SRamiro Leal-Cavazos 37744baa655SRamiro Leal-Cavazos Value incrValue = b.create<math::CopySignOp>(half, operand); 378be911578SBalaji V. Iyer Value add = b.create<arith::AddFOp>(opType, operand, incrValue); 379be911578SBalaji V. Iyer Value fpFixedConvert = createTruncatedFPValue(add, b); 38044baa655SRamiro Leal-Cavazos 38144baa655SRamiro Leal-Cavazos // There are three cases where adding 0.5 to the value and truncating by 38244baa655SRamiro Leal-Cavazos // converting to an i64 does not result in the correct behavior: 38344baa655SRamiro Leal-Cavazos // 38444baa655SRamiro Leal-Cavazos // 1. Special values: +-inf and +-nan 38544baa655SRamiro Leal-Cavazos // Casting these special values to i64 has undefined behavior. To identify 38644baa655SRamiro Leal-Cavazos // these values, we use the fact that these values are the only float 38744baa655SRamiro Leal-Cavazos // values with the maximum possible biased exponent. 38844baa655SRamiro Leal-Cavazos // 38944baa655SRamiro Leal-Cavazos // 2. Large values: 2^23 <= |x| <= INT_64_MAX 39044baa655SRamiro Leal-Cavazos // Adding 0.5 to a float larger than or equal to 2^23 results in precision 39144baa655SRamiro Leal-Cavazos // errors that sometimes round the value up and sometimes round the value 39244baa655SRamiro Leal-Cavazos // down. For example: 39344baa655SRamiro Leal-Cavazos // 8388608.0 + 0.5 = 8388608.0 39444baa655SRamiro Leal-Cavazos // 8388609.0 + 0.5 = 8388610.0 39544baa655SRamiro Leal-Cavazos // 39644baa655SRamiro Leal-Cavazos // 3. Very large values: |x| > INT_64_MAX 39744baa655SRamiro Leal-Cavazos // Casting to i64 a value greater than the max i64 value will overflow the 39844baa655SRamiro Leal-Cavazos // i64 leading to wrong outputs. 39944baa655SRamiro Leal-Cavazos // 40044baa655SRamiro Leal-Cavazos // All three cases satisfy the property `biasedExp >= 23`. 40144baa655SRamiro Leal-Cavazos Value operandBitcast = b.create<arith::BitcastOp>(i32Ty, operand); 40244baa655SRamiro Leal-Cavazos Value operandExp = b.create<arith::AndIOp>( 40344baa655SRamiro Leal-Cavazos b.create<arith::ShRUIOp>(operandBitcast, c23), expMask); 40444baa655SRamiro Leal-Cavazos Value operandBiasedExp = b.create<arith::SubIOp>(operandExp, c127); 40544baa655SRamiro Leal-Cavazos Value isSpecialValOrLargeVal = 40644baa655SRamiro Leal-Cavazos b.create<arith::CmpIOp>(arith::CmpIPredicate::sge, operandBiasedExp, c23); 40744baa655SRamiro Leal-Cavazos 40844baa655SRamiro Leal-Cavazos Value result = b.create<arith::SelectOp>(isSpecialValOrLargeVal, operand, 40944baa655SRamiro Leal-Cavazos fpFixedConvert); 41044baa655SRamiro Leal-Cavazos rewriter.replaceOp(op, result); 411be911578SBalaji V. Iyer return success(); 412be911578SBalaji V. Iyer } 413be911578SBalaji V. Iyer 414711c5893SRobert Suderman // Converts math.ctlz to scf and arith operations. This is done 415711c5893SRobert Suderman // by performing a binary search on the bits. 416f3bdb56dSRob Suderman static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op, 417f3bdb56dSRob Suderman PatternRewriter &rewriter) { 418f3bdb56dSRob Suderman auto operand = op.getOperand(); 419711c5893SRobert Suderman auto operandTy = operand.getType(); 420711c5893SRobert Suderman auto eTy = getElementTypeOrSelf(operandTy); 421f3bdb56dSRob Suderman Location loc = op.getLoc(); 422f3bdb56dSRob Suderman 423711c5893SRobert Suderman int32_t bitwidth = eTy.getIntOrFloatBitWidth(); 424711c5893SRobert Suderman if (bitwidth > 64) 425711c5893SRobert Suderman return failure(); 426f3bdb56dSRob Suderman 427711c5893SRobert Suderman uint64_t allbits = -1; 428711c5893SRobert Suderman if (bitwidth < 64) { 429711c5893SRobert Suderman allbits = allbits >> (64 - bitwidth); 430711c5893SRobert Suderman } 431f3bdb56dSRob Suderman 432711c5893SRobert Suderman Value x = operand; 433711c5893SRobert Suderman Value count = createIntConst(loc, operandTy, 0, rewriter); 434711c5893SRobert Suderman for (int32_t bw = bitwidth; bw > 1; bw = bw / 2) { 435711c5893SRobert Suderman auto half = bw / 2; 436711c5893SRobert Suderman auto bits = createIntConst(loc, operandTy, half, rewriter); 437711c5893SRobert Suderman auto mask = createIntConst(loc, operandTy, allbits >> half, rewriter); 438f3bdb56dSRob Suderman 439711c5893SRobert Suderman Value pred = 440711c5893SRobert Suderman rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule, x, mask); 441711c5893SRobert Suderman Value add = rewriter.create<arith::AddIOp>(loc, count, bits); 442711c5893SRobert Suderman Value shift = rewriter.create<arith::ShLIOp>(loc, x, bits); 443f3bdb56dSRob Suderman 444711c5893SRobert Suderman x = rewriter.create<arith::SelectOp>(loc, pred, shift, x); 445711c5893SRobert Suderman count = rewriter.create<arith::SelectOp>(loc, pred, add, count); 446711c5893SRobert Suderman } 447f3bdb56dSRob Suderman 448711c5893SRobert Suderman Value zero = createIntConst(loc, operandTy, 0, rewriter); 449711c5893SRobert Suderman Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, 450711c5893SRobert Suderman operand, zero); 451f3bdb56dSRob Suderman 452711c5893SRobert Suderman Value bwval = createIntConst(loc, operandTy, bitwidth, rewriter); 453711c5893SRobert Suderman Value sel = rewriter.create<arith::SelectOp>(loc, pred, bwval, count); 454711c5893SRobert Suderman rewriter.replaceOp(op, sel); 455f3bdb56dSRob Suderman return success(); 456f3bdb56dSRob Suderman } 457f3bdb56dSRob Suderman 45844baa655SRamiro Leal-Cavazos // Convert `math.roundeven` into `math.round` + arith ops 45944baa655SRamiro Leal-Cavazos static LogicalResult convertRoundEvenOp(math::RoundEvenOp op, 46044baa655SRamiro Leal-Cavazos PatternRewriter &rewriter) { 46144baa655SRamiro Leal-Cavazos Location loc = op.getLoc(); 46244baa655SRamiro Leal-Cavazos ImplicitLocOpBuilder b(loc, rewriter); 46344baa655SRamiro Leal-Cavazos auto operand = op.getOperand(); 46444baa655SRamiro Leal-Cavazos Type operandTy = operand.getType(); 46544baa655SRamiro Leal-Cavazos Type resultTy = op.getType(); 46644baa655SRamiro Leal-Cavazos Type operandETy = getElementTypeOrSelf(operandTy); 46744baa655SRamiro Leal-Cavazos Type resultETy = getElementTypeOrSelf(resultTy); 46844baa655SRamiro Leal-Cavazos 469fe355a44SAlexander Shaposhnikov if (!isa<FloatType>(operandETy) || !isa<FloatType>(resultETy)) { 470fe355a44SAlexander Shaposhnikov return rewriter.notifyMatchFailure(op, "not a roundeven of f16 or f32."); 47144baa655SRamiro Leal-Cavazos } 47244baa655SRamiro Leal-Cavazos 473fe355a44SAlexander Shaposhnikov Type fTy = operandTy; 474fe355a44SAlexander Shaposhnikov Type iTy = rewriter.getIntegerType(operandETy.getIntOrFloatBitWidth()); 475fe355a44SAlexander Shaposhnikov if (auto shapedTy = dyn_cast<ShapedType>(fTy)) { 476fe355a44SAlexander Shaposhnikov iTy = shapedTy.clone(iTy); 47744baa655SRamiro Leal-Cavazos } 47844baa655SRamiro Leal-Cavazos 479fe355a44SAlexander Shaposhnikov unsigned bitWidth = operandETy.getIntOrFloatBitWidth(); 480fe355a44SAlexander Shaposhnikov // The width returned by getFPMantissaWidth includes the integer bit. 481fe355a44SAlexander Shaposhnikov unsigned mantissaWidth = 482fe355a44SAlexander Shaposhnikov llvm::cast<FloatType>(operandETy).getFPMantissaWidth() - 1; 483fe355a44SAlexander Shaposhnikov unsigned exponentWidth = bitWidth - mantissaWidth - 1; 48444baa655SRamiro Leal-Cavazos 485fe355a44SAlexander Shaposhnikov // The names of the variables correspond to f32. 486fe355a44SAlexander Shaposhnikov // f64: 1 bit sign | 11 bits exponent | 52 bits mantissa. 487fe355a44SAlexander Shaposhnikov // f32: 1 bit sign | 8 bits exponent | 23 bits mantissa. 488fe355a44SAlexander Shaposhnikov // f16: 1 bit sign | 5 bits exponent | 10 bits mantissa. 489fe355a44SAlexander Shaposhnikov Value c1Float = createFloatConst(loc, fTy, 1.0, b); 490fe355a44SAlexander Shaposhnikov Value c0 = createIntConst(loc, iTy, 0, b); 491fe355a44SAlexander Shaposhnikov Value c1 = createIntConst(loc, iTy, 1, b); 492fe355a44SAlexander Shaposhnikov Value cNeg1 = createIntConst(loc, iTy, -1, b); 493fe355a44SAlexander Shaposhnikov Value c23 = createIntConst(loc, iTy, mantissaWidth, b); 494fe355a44SAlexander Shaposhnikov Value c31 = createIntConst(loc, iTy, bitWidth - 1, b); 495fe355a44SAlexander Shaposhnikov Value c127 = createIntConst(loc, iTy, (1ull << (exponentWidth - 1)) - 1, b); 496fe355a44SAlexander Shaposhnikov Value c2To22 = createIntConst(loc, iTy, 1ull << (mantissaWidth - 1), b); 497fe355a44SAlexander Shaposhnikov Value c23Mask = createIntConst(loc, iTy, (1ull << mantissaWidth) - 1, b); 498fe355a44SAlexander Shaposhnikov Value expMask = createIntConst(loc, iTy, (1ull << exponentWidth) - 1, b); 499fe355a44SAlexander Shaposhnikov 500fe355a44SAlexander Shaposhnikov Value operandBitcast = b.create<arith::BitcastOp>(iTy, operand); 50144baa655SRamiro Leal-Cavazos Value round = b.create<math::RoundOp>(operand); 502fe355a44SAlexander Shaposhnikov Value roundBitcast = b.create<arith::BitcastOp>(iTy, round); 50344baa655SRamiro Leal-Cavazos 50444baa655SRamiro Leal-Cavazos // Get biased exponents for operand and round(operand) 50544baa655SRamiro Leal-Cavazos Value operandExp = b.create<arith::AndIOp>( 50644baa655SRamiro Leal-Cavazos b.create<arith::ShRUIOp>(operandBitcast, c23), expMask); 50744baa655SRamiro Leal-Cavazos Value operandBiasedExp = b.create<arith::SubIOp>(operandExp, c127); 50844baa655SRamiro Leal-Cavazos Value roundExp = b.create<arith::AndIOp>( 50944baa655SRamiro Leal-Cavazos b.create<arith::ShRUIOp>(roundBitcast, c23), expMask); 51044baa655SRamiro Leal-Cavazos Value roundBiasedExp = b.create<arith::SubIOp>(roundExp, c127); 51144baa655SRamiro Leal-Cavazos 51244baa655SRamiro Leal-Cavazos auto safeShiftRight = [&](Value x, Value shift) -> Value { 513fe355a44SAlexander Shaposhnikov // Clamp shift to valid range [0, bitwidth - 1] to avoid undefined behavior 51444baa655SRamiro Leal-Cavazos Value clampedShift = b.create<arith::MaxSIOp>(shift, c0); 51544baa655SRamiro Leal-Cavazos clampedShift = b.create<arith::MinSIOp>(clampedShift, c31); 51644baa655SRamiro Leal-Cavazos return b.create<arith::ShRUIOp>(x, clampedShift); 51744baa655SRamiro Leal-Cavazos }; 51844baa655SRamiro Leal-Cavazos 51944baa655SRamiro Leal-Cavazos auto maskMantissa = [&](Value mantissa, 52044baa655SRamiro Leal-Cavazos Value mantissaMaskRightShift) -> Value { 52144baa655SRamiro Leal-Cavazos Value shiftedMantissaMask = safeShiftRight(c23Mask, mantissaMaskRightShift); 52244baa655SRamiro Leal-Cavazos return b.create<arith::AndIOp>(mantissa, shiftedMantissaMask); 52344baa655SRamiro Leal-Cavazos }; 52444baa655SRamiro Leal-Cavazos 52544baa655SRamiro Leal-Cavazos // A whole number `x`, such that `|x| != 1`, is even if the mantissa, ignoring 52644baa655SRamiro Leal-Cavazos // the leftmost `clamp(biasedExp - 1, 0, 23)` bits, is zero. Large numbers 52744baa655SRamiro Leal-Cavazos // with `biasedExp > 23` (numbers where there is not enough precision to store 52844baa655SRamiro Leal-Cavazos // decimals) are always even, and they satisfy the even condition trivially 52944baa655SRamiro Leal-Cavazos // since the mantissa without all its bits is zero. The even condition 53044baa655SRamiro Leal-Cavazos // is also true for +-0, since they have `biasedExp = -127` and the entire 53144baa655SRamiro Leal-Cavazos // mantissa is zero. The case of +-1 has to be handled separately. Here 53244baa655SRamiro Leal-Cavazos // we identify these values by noting that +-1 are the only whole numbers with 53344baa655SRamiro Leal-Cavazos // `biasedExp == 0`. 53444baa655SRamiro Leal-Cavazos // 53544baa655SRamiro Leal-Cavazos // The special values +-inf and +-nan also satisfy the same property that 53644baa655SRamiro Leal-Cavazos // whole non-unit even numbers satisfy. In particular, the special values have 53744baa655SRamiro Leal-Cavazos // `biasedExp > 23`, so they get treated as large numbers with no room for 53844baa655SRamiro Leal-Cavazos // decimals, which are always even. 53944baa655SRamiro Leal-Cavazos Value roundBiasedExpEq0 = 54044baa655SRamiro Leal-Cavazos b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, roundBiasedExp, c0); 54144baa655SRamiro Leal-Cavazos Value roundBiasedExpMinus1 = b.create<arith::SubIOp>(roundBiasedExp, c1); 54244baa655SRamiro Leal-Cavazos Value roundMaskedMantissa = maskMantissa(roundBitcast, roundBiasedExpMinus1); 54344baa655SRamiro Leal-Cavazos Value roundIsNotEvenOrSpecialVal = b.create<arith::CmpIOp>( 54444baa655SRamiro Leal-Cavazos arith::CmpIPredicate::ne, roundMaskedMantissa, c0); 54544baa655SRamiro Leal-Cavazos roundIsNotEvenOrSpecialVal = 54644baa655SRamiro Leal-Cavazos b.create<arith::OrIOp>(roundIsNotEvenOrSpecialVal, roundBiasedExpEq0); 54744baa655SRamiro Leal-Cavazos 54844baa655SRamiro Leal-Cavazos // A value `x` with `0 <= biasedExp < 23`, is halfway between two consecutive 54944baa655SRamiro Leal-Cavazos // integers if the bit at index `biasedExp` starting from the left in the 55044baa655SRamiro Leal-Cavazos // mantissa is 1 and all the bits to the right are zero. Values with 55144baa655SRamiro Leal-Cavazos // `biasedExp >= 23` don't have decimals, so they are never halfway. The 55244baa655SRamiro Leal-Cavazos // values +-0.5 are the only halfway values that have `biasedExp == -1 < 0`, 55344baa655SRamiro Leal-Cavazos // so these are handled separately. In particular, if `biasedExp == -1`, the 55444baa655SRamiro Leal-Cavazos // value is halfway if the entire mantissa is zero. 55544baa655SRamiro Leal-Cavazos Value operandBiasedExpEqNeg1 = b.create<arith::CmpIOp>( 55644baa655SRamiro Leal-Cavazos arith::CmpIPredicate::eq, operandBiasedExp, cNeg1); 55744baa655SRamiro Leal-Cavazos Value expectedOperandMaskedMantissa = b.create<arith::SelectOp>( 55844baa655SRamiro Leal-Cavazos operandBiasedExpEqNeg1, c0, safeShiftRight(c2To22, operandBiasedExp)); 55944baa655SRamiro Leal-Cavazos Value operandMaskedMantissa = maskMantissa(operandBitcast, operandBiasedExp); 56044baa655SRamiro Leal-Cavazos Value operandIsHalfway = 56144baa655SRamiro Leal-Cavazos b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, operandMaskedMantissa, 56244baa655SRamiro Leal-Cavazos expectedOperandMaskedMantissa); 56344baa655SRamiro Leal-Cavazos // Ensure `biasedExp` is in the valid range for half values. 56444baa655SRamiro Leal-Cavazos Value operandBiasedExpGeNeg1 = b.create<arith::CmpIOp>( 56544baa655SRamiro Leal-Cavazos arith::CmpIPredicate::sge, operandBiasedExp, cNeg1); 56644baa655SRamiro Leal-Cavazos Value operandBiasedExpLt23 = 56744baa655SRamiro Leal-Cavazos b.create<arith::CmpIOp>(arith::CmpIPredicate::slt, operandBiasedExp, c23); 56844baa655SRamiro Leal-Cavazos operandIsHalfway = 56944baa655SRamiro Leal-Cavazos b.create<arith::AndIOp>(operandIsHalfway, operandBiasedExpLt23); 57044baa655SRamiro Leal-Cavazos operandIsHalfway = 57144baa655SRamiro Leal-Cavazos b.create<arith::AndIOp>(operandIsHalfway, operandBiasedExpGeNeg1); 57244baa655SRamiro Leal-Cavazos 57344baa655SRamiro Leal-Cavazos // Adjust rounded operand with `round(operand) - sign(operand)` to correct the 57444baa655SRamiro Leal-Cavazos // case where `round` rounded in the opposite direction of `roundeven`. 57544baa655SRamiro Leal-Cavazos Value sign = b.create<math::CopySignOp>(c1Float, operand); 57644baa655SRamiro Leal-Cavazos Value roundShifted = b.create<arith::SubFOp>(round, sign); 57744baa655SRamiro Leal-Cavazos // If the rounded value is even or a special value, we default to the behavior 57844baa655SRamiro Leal-Cavazos // of `math.round`. 57944baa655SRamiro Leal-Cavazos Value needsShift = 58044baa655SRamiro Leal-Cavazos b.create<arith::AndIOp>(roundIsNotEvenOrSpecialVal, operandIsHalfway); 58144baa655SRamiro Leal-Cavazos Value result = b.create<arith::SelectOp>(needsShift, roundShifted, round); 58244baa655SRamiro Leal-Cavazos // The `x - sign` adjustment does not preserve the sign when we are adjusting 58344baa655SRamiro Leal-Cavazos // the value -1 to -0. So here the sign is copied again to ensure that -0.5 is 58444baa655SRamiro Leal-Cavazos // rounded to -0.0. 58544baa655SRamiro Leal-Cavazos result = b.create<math::CopySignOp>(result, operand); 58644baa655SRamiro Leal-Cavazos rewriter.replaceOp(op, result); 58744baa655SRamiro Leal-Cavazos return success(); 58844baa655SRamiro Leal-Cavazos } 58944baa655SRamiro Leal-Cavazos 590279a659eSCorentin Ferry // Convert `math.rsqrt` into `arith.divf` + `math.sqrt` 591279a659eSCorentin Ferry static LogicalResult convertRsqrtOp(math::RsqrtOp op, 592279a659eSCorentin Ferry PatternRewriter &rewriter) { 593279a659eSCorentin Ferry 594279a659eSCorentin Ferry auto operand = op.getOperand(); 595279a659eSCorentin Ferry auto operandTy = operand.getType(); 596279a659eSCorentin Ferry auto eTy = getElementTypeOrSelf(operandTy); 597279a659eSCorentin Ferry if (!isa<FloatType>(eTy)) 598279a659eSCorentin Ferry return failure(); 599279a659eSCorentin Ferry 600279a659eSCorentin Ferry Location loc = op->getLoc(); 601279a659eSCorentin Ferry auto constOneFloat = createFloatConst(loc, operandTy, 1.0, rewriter); 602279a659eSCorentin Ferry auto sqrtOp = rewriter.create<math::SqrtOp>(loc, operand); 603279a659eSCorentin Ferry rewriter.replaceOpWithNewOp<arith::DivFOp>(op, constOneFloat, sqrtOp); 604279a659eSCorentin Ferry return success(); 605279a659eSCorentin Ferry } 606279a659eSCorentin Ferry 607f3bdb56dSRob Suderman void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) { 608f3bdb56dSRob Suderman patterns.add(convertCtlzOp); 609f3bdb56dSRob Suderman } 610f3bdb56dSRob Suderman 611aa165edcSRob Suderman void mlir::populateExpandSinhPattern(RewritePatternSet &patterns) { 612aa165edcSRob Suderman patterns.add(convertSinhOp); 613aa165edcSRob Suderman } 614aa165edcSRob Suderman 615aa165edcSRob Suderman void mlir::populateExpandCoshPattern(RewritePatternSet &patterns) { 616aa165edcSRob Suderman patterns.add(convertCoshOp); 617aa165edcSRob Suderman } 618aa165edcSRob Suderman 619740e2e90SRobert Suderman void mlir::populateExpandTanPattern(RewritePatternSet &patterns) { 620740e2e90SRobert Suderman patterns.add(convertTanOp); 621740e2e90SRobert Suderman } 622740e2e90SRobert Suderman 623f3bdb56dSRob Suderman void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) { 624f3bdb56dSRob Suderman patterns.add(convertTanhOp); 625f3bdb56dSRob Suderman } 626a7c2102dSBalaji V. Iyer 627a62a7024Sjinchen void mlir::populateExpandAsinhPattern(RewritePatternSet &patterns) { 628a62a7024Sjinchen patterns.add(convertAsinhOp); 629a62a7024Sjinchen } 630a62a7024Sjinchen 631a62a7024Sjinchen void mlir::populateExpandAcoshPattern(RewritePatternSet &patterns) { 632a62a7024Sjinchen patterns.add(convertAcoshOp); 633a62a7024Sjinchen } 634a62a7024Sjinchen 635a62a7024Sjinchen void mlir::populateExpandAtanhPattern(RewritePatternSet &patterns) { 636a62a7024Sjinchen patterns.add(convertAtanhOp); 637a62a7024Sjinchen } 638a62a7024Sjinchen 639a7c2102dSBalaji V. Iyer void mlir::populateExpandFmaFPattern(RewritePatternSet &patterns) { 640a7c2102dSBalaji V. Iyer patterns.add(convertFmaFOp); 641a7c2102dSBalaji V. Iyer } 6422217888dSBalaji V. Iyer 6432217888dSBalaji V. Iyer void mlir::populateExpandCeilFPattern(RewritePatternSet &patterns) { 6442217888dSBalaji V. Iyer patterns.add(convertCeilOp); 6452217888dSBalaji V. Iyer } 6462217888dSBalaji V. Iyer 6474da96515SBalaji V. Iyer void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) { 6484da96515SBalaji V. Iyer patterns.add(convertExp2fOp); 6494da96515SBalaji V. Iyer } 6504da96515SBalaji V. Iyer 6512d4e8567SBalaji V. Iyer void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) { 6522d4e8567SBalaji V. Iyer patterns.add(convertPowfOp); 6532d4e8567SBalaji V. Iyer } 6542d4e8567SBalaji V. Iyer 65510a57f3aSPrashant Kumar void mlir::populateExpandFPowIPattern(RewritePatternSet &patterns) { 6565b702be1SPrashant Kumar patterns.add(convertFPowIOp); 65710a57f3aSPrashant Kumar } 65810a57f3aSPrashant Kumar 659be911578SBalaji V. Iyer void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) { 660be911578SBalaji V. Iyer patterns.add(convertRoundOp); 661be911578SBalaji V. Iyer } 662be911578SBalaji V. Iyer 66344baa655SRamiro Leal-Cavazos void mlir::populateExpandRoundEvenPattern(RewritePatternSet &patterns) { 66444baa655SRamiro Leal-Cavazos patterns.add(convertRoundEvenOp); 66544baa655SRamiro Leal-Cavazos } 666279a659eSCorentin Ferry 667279a659eSCorentin Ferry void mlir::populateExpandRsqrtPattern(RewritePatternSet &patterns) { 668279a659eSCorentin Ferry patterns.add(convertRsqrtOp); 669279a659eSCorentin Ferry } 670