xref: /llvm-project/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp (revision 3a3377579f137a0a6e14b60d891a9736707e7e8d)
110a57f3aSPrashant Kumar //===- ExpandPatterns.cpp - Code to expand various math operations. -------===//
2f3bdb56dSRob Suderman //
3f3bdb56dSRob Suderman // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4f3bdb56dSRob Suderman // See https://llvm.org/LICENSE.txt for license information.
5f3bdb56dSRob Suderman // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6f3bdb56dSRob Suderman //
7f3bdb56dSRob Suderman //===----------------------------------------------------------------------===//
8f3bdb56dSRob Suderman //
910a57f3aSPrashant Kumar // This file implements expansion of various math operations.
10f3bdb56dSRob Suderman //
11f3bdb56dSRob Suderman //===----------------------------------------------------------------------===//
12f3bdb56dSRob Suderman 
13abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
14f3bdb56dSRob Suderman #include "mlir/Dialect/Math/IR/Math.h"
15f3bdb56dSRob Suderman #include "mlir/Dialect/Math/Transforms/Passes.h"
168b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
17711c5893SRobert Suderman #include "mlir/Dialect/Vector/IR/VectorOps.h"
18f3bdb56dSRob Suderman #include "mlir/IR/Builders.h"
19740e2e90SRobert Suderman #include "mlir/IR/ImplicitLocOpBuilder.h"
20711c5893SRobert Suderman #include "mlir/IR/TypeUtilities.h"
21f3bdb56dSRob Suderman #include "mlir/Transforms/DialectConversion.h"
22f3bdb56dSRob Suderman 
23f3bdb56dSRob Suderman using namespace mlir;
24f3bdb56dSRob Suderman 
25711c5893SRobert Suderman /// Create a float constant.
2610a57f3aSPrashant Kumar static Value createFloatConst(Location loc, Type type, APFloat value,
27711c5893SRobert Suderman                               OpBuilder &b) {
2810a57f3aSPrashant Kumar   bool losesInfo = false;
2910a57f3aSPrashant Kumar   auto eltType = getElementTypeOrSelf(type);
3010a57f3aSPrashant Kumar   // Convert double to the given `FloatType` with round-to-nearest-ties-to-even.
3110a57f3aSPrashant Kumar   value.convert(cast<FloatType>(eltType).getFloatSemantics(),
3210a57f3aSPrashant Kumar                 APFloat::rmNearestTiesToEven, &losesInfo);
3310a57f3aSPrashant Kumar   auto attr = b.getFloatAttr(eltType, value);
34711c5893SRobert Suderman   if (auto shapedTy = dyn_cast<ShapedType>(type)) {
35711c5893SRobert Suderman     return b.create<arith::ConstantOp>(loc,
36711c5893SRobert Suderman                                        DenseElementsAttr::get(shapedTy, attr));
37711c5893SRobert Suderman   }
38711c5893SRobert Suderman 
39711c5893SRobert Suderman   return b.create<arith::ConstantOp>(loc, attr);
40711c5893SRobert Suderman }
41711c5893SRobert Suderman 
4210a57f3aSPrashant Kumar static Value createFloatConst(Location loc, Type type, double value,
4310a57f3aSPrashant Kumar                               OpBuilder &b) {
4410a57f3aSPrashant Kumar   return createFloatConst(loc, type, APFloat(value), b);
4510a57f3aSPrashant Kumar }
4610a57f3aSPrashant Kumar 
4710a57f3aSPrashant Kumar /// Create an integer constant.
48711c5893SRobert Suderman static Value createIntConst(Location loc, Type type, int64_t value,
49711c5893SRobert Suderman                             OpBuilder &b) {
50711c5893SRobert Suderman   auto attr = b.getIntegerAttr(getElementTypeOrSelf(type), value);
51711c5893SRobert Suderman   if (auto shapedTy = dyn_cast<ShapedType>(type)) {
52711c5893SRobert Suderman     return b.create<arith::ConstantOp>(loc,
53711c5893SRobert Suderman                                        DenseElementsAttr::get(shapedTy, attr));
54711c5893SRobert Suderman   }
55711c5893SRobert Suderman 
56711c5893SRobert Suderman   return b.create<arith::ConstantOp>(loc, attr);
57711c5893SRobert Suderman }
58711c5893SRobert Suderman 
592217888dSBalaji V. Iyer static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b) {
602217888dSBalaji V. Iyer   Type opType = operand.getType();
6144baa655SRamiro Leal-Cavazos   Type i64Ty = b.getI64Type();
6244baa655SRamiro Leal-Cavazos   if (auto shapedTy = dyn_cast<ShapedType>(opType))
6344baa655SRamiro Leal-Cavazos     i64Ty = shapedTy.clone(i64Ty);
6444baa655SRamiro Leal-Cavazos   Value fixedConvert = b.create<arith::FPToSIOp>(i64Ty, operand);
652217888dSBalaji V. Iyer   Value fpFixedConvert = b.create<arith::SIToFPOp>(opType, fixedConvert);
6644baa655SRamiro Leal-Cavazos   // The truncation does not preserve the sign when the truncated
6744baa655SRamiro Leal-Cavazos   // value is -0. So here the sign is copied again.
6844baa655SRamiro Leal-Cavazos   return b.create<math::CopySignOp>(fpFixedConvert, operand);
692217888dSBalaji V. Iyer }
702217888dSBalaji V. Iyer 
71aa165edcSRob Suderman // sinhf(float x) -> (exp(x) - exp(-x)) / 2
72aa165edcSRob Suderman static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter) {
73aa165edcSRob Suderman   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
74aa165edcSRob Suderman   Value operand = op.getOperand();
75aa165edcSRob Suderman   Type opType = operand.getType();
76aa165edcSRob Suderman 
77a62a7024Sjinchen   Value exp = b.create<math::ExpOp>(operand);
78a62a7024Sjinchen   Value neg = b.create<arith::NegFOp>(operand);
79a62a7024Sjinchen   Value nexp = b.create<math::ExpOp>(neg);
80aa165edcSRob Suderman   Value sub = b.create<arith::SubFOp>(exp, nexp);
81a62a7024Sjinchen   Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
82a62a7024Sjinchen   Value res = b.create<arith::MulFOp>(sub, half);
83a62a7024Sjinchen   rewriter.replaceOp(op, res);
84aa165edcSRob Suderman   return success();
85aa165edcSRob Suderman }
86aa165edcSRob Suderman 
87aa165edcSRob Suderman // coshf(float x) -> (exp(x) + exp(-x)) / 2
88aa165edcSRob Suderman static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) {
89aa165edcSRob Suderman   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
90aa165edcSRob Suderman   Value operand = op.getOperand();
91aa165edcSRob Suderman   Type opType = operand.getType();
92aa165edcSRob Suderman 
93a62a7024Sjinchen   Value exp = b.create<math::ExpOp>(operand);
94a62a7024Sjinchen   Value neg = b.create<arith::NegFOp>(operand);
95a62a7024Sjinchen   Value nexp = b.create<math::ExpOp>(neg);
96aa165edcSRob Suderman   Value add = b.create<arith::AddFOp>(exp, nexp);
97a62a7024Sjinchen   Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
98a62a7024Sjinchen   Value res = b.create<arith::MulFOp>(add, half);
99a62a7024Sjinchen   rewriter.replaceOp(op, res);
100aa165edcSRob Suderman   return success();
101aa165edcSRob Suderman }
102aa165edcSRob Suderman 
103f3bdb56dSRob Suderman /// Expands tanh op into
104d39ac3a8Ssrcarroll /// 1-exp^{-2x} / 1+exp^{-2x}
105d39ac3a8Ssrcarroll /// To avoid overflow we exploit the reflection symmetry `tanh(-x) = -tanh(x)`.
106d39ac3a8Ssrcarroll /// We compute a "signs" value which is -1 if input is negative and +1 if input
107d39ac3a8Ssrcarroll /// is positive.  Then multiply the input by this value, guaranteeing that the
108d39ac3a8Ssrcarroll /// result is positive, which also guarantees `exp^{-2x * sign(x)}` is in (0,
109d39ac3a8Ssrcarroll /// 1]. Expand the computation on the input `x * sign(x)`, then multiply the
110d39ac3a8Ssrcarroll /// result by `sign(x)` to retain sign of the real result.
111f3bdb56dSRob Suderman static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
112f3bdb56dSRob Suderman   auto floatType = op.getOperand().getType();
113f3bdb56dSRob Suderman   Location loc = op.getLoc();
114d39ac3a8Ssrcarroll   Value zero = createFloatConst(loc, floatType, 0.0, rewriter);
115711c5893SRobert Suderman   Value one = createFloatConst(loc, floatType, 1.0, rewriter);
116d39ac3a8Ssrcarroll   Value negTwo = createFloatConst(loc, floatType, -2.0, rewriter);
117f3bdb56dSRob Suderman 
118d39ac3a8Ssrcarroll   // Compute sign(x) = cast<float_type>(x < 0) * (-2) + 1
119d39ac3a8Ssrcarroll   Value isNegative = rewriter.create<arith::CmpFOp>(
120d39ac3a8Ssrcarroll       loc, arith::CmpFPredicate::OLT, op.getOperand(), zero);
121d39ac3a8Ssrcarroll   Value isNegativeFloat =
122d39ac3a8Ssrcarroll       rewriter.create<arith::UIToFPOp>(loc, floatType, isNegative);
123d39ac3a8Ssrcarroll   Value isNegativeTimesNegTwo =
124d39ac3a8Ssrcarroll       rewriter.create<arith::MulFOp>(loc, isNegativeFloat, negTwo);
125d39ac3a8Ssrcarroll   Value sign = rewriter.create<arith::AddFOp>(loc, isNegativeTimesNegTwo, one);
126d39ac3a8Ssrcarroll 
127d39ac3a8Ssrcarroll   // Normalize input to positive value: y = sign(x) * x
128d39ac3a8Ssrcarroll   Value positiveX = rewriter.create<arith::MulFOp>(loc, sign, op.getOperand());
129d39ac3a8Ssrcarroll 
130d39ac3a8Ssrcarroll   // Decompose on normalized input
131d39ac3a8Ssrcarroll   Value negDoubledX = rewriter.create<arith::MulFOp>(loc, negTwo, positiveX);
132f3bdb56dSRob Suderman   Value exp2x = rewriter.create<math::ExpOp>(loc, negDoubledX);
133f3bdb56dSRob Suderman   Value dividend = rewriter.create<arith::SubFOp>(loc, one, exp2x);
134f3bdb56dSRob Suderman   Value divisor = rewriter.create<arith::AddFOp>(loc, one, exp2x);
135f3bdb56dSRob Suderman   Value positiveRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor);
136f3bdb56dSRob Suderman 
137d39ac3a8Ssrcarroll   // Multiply result by sign(x) to retain signs from negative inputs
138d39ac3a8Ssrcarroll   rewriter.replaceOpWithNewOp<arith::MulFOp>(op, sign, positiveRes);
139f3bdb56dSRob Suderman 
140f3bdb56dSRob Suderman   return success();
141f3bdb56dSRob Suderman }
142f3bdb56dSRob Suderman 
143711c5893SRobert Suderman // Converts math.tan to math.sin, math.cos, and arith.divf.
144740e2e90SRobert Suderman static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) {
145740e2e90SRobert Suderman   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
146740e2e90SRobert Suderman   Value operand = op.getOperand();
147740e2e90SRobert Suderman   Type type = operand.getType();
148740e2e90SRobert Suderman   Value sin = b.create<math::SinOp>(type, operand);
149740e2e90SRobert Suderman   Value cos = b.create<math::CosOp>(type, operand);
150740e2e90SRobert Suderman   Value div = b.create<arith::DivFOp>(type, sin, cos);
151740e2e90SRobert Suderman   rewriter.replaceOp(op, div);
152740e2e90SRobert Suderman   return success();
153740e2e90SRobert Suderman }
154740e2e90SRobert Suderman 
155a62a7024Sjinchen // asinh(float x) -> log(x + sqrt(x**2 + 1))
156a62a7024Sjinchen static LogicalResult convertAsinhOp(math::AsinhOp op,
157a62a7024Sjinchen                                     PatternRewriter &rewriter) {
158a62a7024Sjinchen   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
159a62a7024Sjinchen   Value operand = op.getOperand();
160a62a7024Sjinchen   Type opType = operand.getType();
161a62a7024Sjinchen 
162a62a7024Sjinchen   Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
163a62a7024Sjinchen   Value fma = b.create<math::FmaOp>(operand, operand, one);
164a62a7024Sjinchen   Value sqrt = b.create<math::SqrtOp>(fma);
165a62a7024Sjinchen   Value add = b.create<arith::AddFOp>(operand, sqrt);
166a62a7024Sjinchen   Value res = b.create<math::LogOp>(add);
167a62a7024Sjinchen   rewriter.replaceOp(op, res);
168a62a7024Sjinchen   return success();
169a62a7024Sjinchen }
170a62a7024Sjinchen 
171a62a7024Sjinchen // acosh(float x) -> log(x + sqrt(x**2 - 1))
172a62a7024Sjinchen static LogicalResult convertAcoshOp(math::AcoshOp op,
173a62a7024Sjinchen                                     PatternRewriter &rewriter) {
174a62a7024Sjinchen   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
175a62a7024Sjinchen   Value operand = op.getOperand();
176a62a7024Sjinchen   Type opType = operand.getType();
177a62a7024Sjinchen 
178a62a7024Sjinchen   Value negOne = createFloatConst(op->getLoc(), opType, -1.0, rewriter);
179a62a7024Sjinchen   Value fma = b.create<math::FmaOp>(operand, operand, negOne);
180a62a7024Sjinchen   Value sqrt = b.create<math::SqrtOp>(fma);
181a62a7024Sjinchen   Value add = b.create<arith::AddFOp>(operand, sqrt);
182a62a7024Sjinchen   Value res = b.create<math::LogOp>(add);
183a62a7024Sjinchen   rewriter.replaceOp(op, res);
184a62a7024Sjinchen   return success();
185a62a7024Sjinchen }
186a62a7024Sjinchen 
187a62a7024Sjinchen // atanh(float x) -> log((1 + x) / (1 - x)) / 2
188a62a7024Sjinchen static LogicalResult convertAtanhOp(math::AtanhOp op,
189a62a7024Sjinchen                                     PatternRewriter &rewriter) {
190a62a7024Sjinchen   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
191a62a7024Sjinchen   Value operand = op.getOperand();
192a62a7024Sjinchen   Type opType = operand.getType();
193a62a7024Sjinchen 
194a62a7024Sjinchen   Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
195a62a7024Sjinchen   Value add = b.create<arith::AddFOp>(operand, one);
196a62a7024Sjinchen   Value neg = b.create<arith::NegFOp>(operand);
197a62a7024Sjinchen   Value sub = b.create<arith::AddFOp>(neg, one);
198a62a7024Sjinchen   Value div = b.create<arith::DivFOp>(add, sub);
199a62a7024Sjinchen   Value log = b.create<math::LogOp>(div);
200a62a7024Sjinchen   Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
201a62a7024Sjinchen   Value res = b.create<arith::MulFOp>(log, half);
202a62a7024Sjinchen   rewriter.replaceOp(op, res);
203a62a7024Sjinchen   return success();
204a62a7024Sjinchen }
205a62a7024Sjinchen 
206a7c2102dSBalaji V. Iyer static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
207a7c2102dSBalaji V. Iyer   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
208a7c2102dSBalaji V. Iyer   Value operandA = op.getOperand(0);
209a7c2102dSBalaji V. Iyer   Value operandB = op.getOperand(1);
210a7c2102dSBalaji V. Iyer   Value operandC = op.getOperand(2);
211a7c2102dSBalaji V. Iyer   Type type = op.getType();
212a7c2102dSBalaji V. Iyer   Value mult = b.create<arith::MulFOp>(type, operandA, operandB);
213a7c2102dSBalaji V. Iyer   Value add = b.create<arith::AddFOp>(type, mult, operandC);
214a7c2102dSBalaji V. Iyer   rewriter.replaceOp(op, add);
215a7c2102dSBalaji V. Iyer   return success();
216a7c2102dSBalaji V. Iyer }
217a7c2102dSBalaji V. Iyer 
2182217888dSBalaji V. Iyer // Converts a ceilf() function to the following:
2192217888dSBalaji V. Iyer // ceilf(float x) ->
2202217888dSBalaji V. Iyer //      y = (float)(int) x
2212217888dSBalaji V. Iyer //      if (x > y) then incr = 1 else incr = 0
2222217888dSBalaji V. Iyer //      y = y + incr   <= replace this op with the ceilf op.
2232217888dSBalaji V. Iyer static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
2242217888dSBalaji V. Iyer   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
2252217888dSBalaji V. Iyer   Value operand = op.getOperand();
2262217888dSBalaji V. Iyer   Type opType = operand.getType();
2272217888dSBalaji V. Iyer   Value fpFixedConvert = createTruncatedFPValue(operand, b);
2282217888dSBalaji V. Iyer 
2292217888dSBalaji V. Iyer   // Creating constants for later use.
2302217888dSBalaji V. Iyer   Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
2312217888dSBalaji V. Iyer   Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
2322217888dSBalaji V. Iyer 
2332217888dSBalaji V. Iyer   Value gtCheck = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, operand,
2342217888dSBalaji V. Iyer                                           fpFixedConvert);
2352217888dSBalaji V. Iyer   Value incrValue = b.create<arith::SelectOp>(op->getLoc(), gtCheck, one, zero);
2362217888dSBalaji V. Iyer 
2372217888dSBalaji V. Iyer   Value ret = b.create<arith::AddFOp>(opType, fpFixedConvert, incrValue);
2382217888dSBalaji V. Iyer   rewriter.replaceOp(op, ret);
2392217888dSBalaji V. Iyer   return success();
2402217888dSBalaji V. Iyer }
24110a57f3aSPrashant Kumar 
24210a57f3aSPrashant Kumar // Convert `math.fpowi` to a series of `arith.mulf` operations.
24310a57f3aSPrashant Kumar // If the power is negative, we divide one by the result.
24410a57f3aSPrashant Kumar // If both the base and power are zero, the result is 1.
2455b702be1SPrashant Kumar // In the case of non constant power, we convert the operation to `math.powf`.
2465b702be1SPrashant Kumar static LogicalResult convertFPowIOp(math::FPowIOp op,
24710a57f3aSPrashant Kumar                                     PatternRewriter &rewriter) {
24810a57f3aSPrashant Kumar   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
24910a57f3aSPrashant Kumar   Value base = op.getOperand(0);
25010a57f3aSPrashant Kumar   Value power = op.getOperand(1);
25110a57f3aSPrashant Kumar   Type baseType = base.getType();
25210a57f3aSPrashant Kumar 
2535b702be1SPrashant Kumar   auto convertFPowItoPowf = [&]() -> LogicalResult {
2545b702be1SPrashant Kumar     Value castPowerToFp =
2555b702be1SPrashant Kumar         rewriter.create<arith::SIToFPOp>(op.getLoc(), baseType, power);
2565b702be1SPrashant Kumar     Value res = rewriter.create<math::PowFOp>(op.getLoc(), baseType, base,
2575b702be1SPrashant Kumar                                               castPowerToFp);
2585b702be1SPrashant Kumar     rewriter.replaceOp(op, res);
2595b702be1SPrashant Kumar     return success();
2605b702be1SPrashant Kumar   };
2615b702be1SPrashant Kumar 
26210a57f3aSPrashant Kumar   Attribute cstAttr;
26310a57f3aSPrashant Kumar   if (!matchPattern(power, m_Constant(&cstAttr)))
2645b702be1SPrashant Kumar     return convertFPowItoPowf();
26510a57f3aSPrashant Kumar 
26610a57f3aSPrashant Kumar   APInt value;
26710a57f3aSPrashant Kumar   if (!matchPattern(cstAttr, m_ConstantInt(&value)))
2685b702be1SPrashant Kumar     return convertFPowItoPowf();
26910a57f3aSPrashant Kumar 
27010a57f3aSPrashant Kumar   int64_t powerInt = value.getSExtValue();
27110a57f3aSPrashant Kumar   bool isNegative = powerInt < 0;
27210a57f3aSPrashant Kumar   int64_t absPower = std::abs(powerInt);
27310a57f3aSPrashant Kumar   Value one = createFloatConst(op->getLoc(), baseType, 1.00, rewriter);
27410a57f3aSPrashant Kumar   Value res = createFloatConst(op->getLoc(), baseType, 1.00, rewriter);
27510a57f3aSPrashant Kumar 
27610a57f3aSPrashant Kumar   while (absPower > 0) {
27710a57f3aSPrashant Kumar     if (absPower & 1)
27810a57f3aSPrashant Kumar       res = b.create<arith::MulFOp>(baseType, base, res);
27910a57f3aSPrashant Kumar     absPower >>= 1;
28010a57f3aSPrashant Kumar     base = b.create<arith::MulFOp>(baseType, base, base);
28110a57f3aSPrashant Kumar   }
28210a57f3aSPrashant Kumar 
28310a57f3aSPrashant Kumar   // Make sure not to introduce UB in case of negative power.
28410a57f3aSPrashant Kumar   if (isNegative) {
28510a57f3aSPrashant Kumar     auto &sem = dyn_cast<mlir::FloatType>(getElementTypeOrSelf(baseType))
28610a57f3aSPrashant Kumar                     .getFloatSemantics();
28710a57f3aSPrashant Kumar     Value zero =
28810a57f3aSPrashant Kumar         createFloatConst(op->getLoc(), baseType,
28910a57f3aSPrashant Kumar                          APFloat::getZero(sem, /*Negative=*/false), rewriter);
29010a57f3aSPrashant Kumar     Value negZero =
29110a57f3aSPrashant Kumar         createFloatConst(op->getLoc(), baseType,
29210a57f3aSPrashant Kumar                          APFloat::getZero(sem, /*Negative=*/true), rewriter);
29310a57f3aSPrashant Kumar     Value posInfinity =
29410a57f3aSPrashant Kumar         createFloatConst(op->getLoc(), baseType,
29510a57f3aSPrashant Kumar                          APFloat::getInf(sem, /*Negative=*/false), rewriter);
29610a57f3aSPrashant Kumar     Value negInfinity =
29710a57f3aSPrashant Kumar         createFloatConst(op->getLoc(), baseType,
29810a57f3aSPrashant Kumar                          APFloat::getInf(sem, /*Negative=*/true), rewriter);
29910a57f3aSPrashant Kumar     Value zeroEqCheck =
30010a57f3aSPrashant Kumar         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, zero);
30110a57f3aSPrashant Kumar     Value negZeroEqCheck =
30210a57f3aSPrashant Kumar         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, negZero);
30310a57f3aSPrashant Kumar     res = b.create<arith::DivFOp>(baseType, one, res);
30410a57f3aSPrashant Kumar     res =
30510a57f3aSPrashant Kumar         b.create<arith::SelectOp>(op->getLoc(), zeroEqCheck, posInfinity, res);
30610a57f3aSPrashant Kumar     res = b.create<arith::SelectOp>(op->getLoc(), negZeroEqCheck, negInfinity,
30710a57f3aSPrashant Kumar                                     res);
30810a57f3aSPrashant Kumar   }
30910a57f3aSPrashant Kumar 
31010a57f3aSPrashant Kumar   rewriter.replaceOp(op, res);
31110a57f3aSPrashant Kumar   return success();
31210a57f3aSPrashant Kumar }
31310a57f3aSPrashant Kumar 
3142d4e8567SBalaji V. Iyer // Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
315*3a337757SHyunsung Lee // Restricting a >= 0
3162d4e8567SBalaji V. Iyer static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
3172d4e8567SBalaji V. Iyer   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
3182d4e8567SBalaji V. Iyer   Value operandA = op.getOperand(0);
3192d4e8567SBalaji V. Iyer   Value operandB = op.getOperand(1);
3202d4e8567SBalaji V. Iyer   Type opType = operandA.getType();
321f66e4bd6SBalaji V. Iyer   Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
322a92e3df3SChristopher Bate   Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
3232d4e8567SBalaji V. Iyer 
324*3a337757SHyunsung Lee   Value logA = b.create<math::LogOp>(opType, operandA);
325*3a337757SHyunsung Lee   Value mult = b.create<arith::MulFOp>(opType, operandB, logA);
3262d4e8567SBalaji V. Iyer   Value expResult = b.create<math::ExpOp>(opType, mult);
327f66e4bd6SBalaji V. Iyer 
328a92e3df3SChristopher Bate   // First, we select between the exp value and the adjusted value for odd
329a92e3df3SChristopher Bate   // powers of negatives. Then, we ensure that one is produced if `b` is zero.
330a92e3df3SChristopher Bate   // This corresponds to `libm` behavior, even for `0^0`. Without this check,
331a92e3df3SChristopher Bate   // `exp(0 * ln(0)) = exp(0 *-inf) = exp(-nan) = -nan`.
332a92e3df3SChristopher Bate   Value zeroCheck =
333a92e3df3SChristopher Bate       b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero);
334*3a337757SHyunsung Lee   Value finalResult =
335*3a337757SHyunsung Lee       b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, expResult);
336*3a337757SHyunsung Lee   rewriter.replaceOp(op, finalResult);
3372d4e8567SBalaji V. Iyer   return success();
3382d4e8567SBalaji V. Iyer }
3392217888dSBalaji V. Iyer 
3404da96515SBalaji V. Iyer // exp2f(float x) -> exp(x * ln(2))
3414da96515SBalaji V. Iyer //   Proof: Let's say 2^x = y
3424da96515SBalaji V. Iyer //   ln(2^x) = ln(y)
3434da96515SBalaji V. Iyer //   x * ln(2) = ln(y) => e ^(x*ln(2)) = y
3444da96515SBalaji V. Iyer static LogicalResult convertExp2fOp(math::Exp2Op op,
3454da96515SBalaji V. Iyer                                     PatternRewriter &rewriter) {
3464da96515SBalaji V. Iyer   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
3474da96515SBalaji V. Iyer   Value operand = op.getOperand();
3484da96515SBalaji V. Iyer   Type opType = operand.getType();
3494da96515SBalaji V. Iyer   Value ln2 = createFloatConst(op->getLoc(), opType, llvm::numbers::ln2, b);
3504da96515SBalaji V. Iyer   Value mult = b.create<arith::MulFOp>(opType, operand, ln2);
3514da96515SBalaji V. Iyer   Value exp = b.create<math::ExpOp>(op->getLoc(), mult);
3524da96515SBalaji V. Iyer   rewriter.replaceOp(op, exp);
3534da96515SBalaji V. Iyer   return success();
3544da96515SBalaji V. Iyer }
3554da96515SBalaji V. Iyer 
356be911578SBalaji V. Iyer static LogicalResult convertRoundOp(math::RoundOp op,
357be911578SBalaji V. Iyer                                     PatternRewriter &rewriter) {
35844baa655SRamiro Leal-Cavazos   Location loc = op.getLoc();
35944baa655SRamiro Leal-Cavazos   ImplicitLocOpBuilder b(loc, rewriter);
360be911578SBalaji V. Iyer   Value operand = op.getOperand();
361be911578SBalaji V. Iyer   Type opType = operand.getType();
36244baa655SRamiro Leal-Cavazos   Type opEType = getElementTypeOrSelf(opType);
363be911578SBalaji V. Iyer 
36444baa655SRamiro Leal-Cavazos   if (!opEType.isF32()) {
36544baa655SRamiro Leal-Cavazos     return rewriter.notifyMatchFailure(op, "not a round of f32.");
36644baa655SRamiro Leal-Cavazos   }
367be911578SBalaji V. Iyer 
36844baa655SRamiro Leal-Cavazos   Type i32Ty = b.getI32Type();
36944baa655SRamiro Leal-Cavazos   if (auto shapedTy = dyn_cast<ShapedType>(opType))
37044baa655SRamiro Leal-Cavazos     i32Ty = shapedTy.clone(i32Ty);
37144baa655SRamiro Leal-Cavazos 
37244baa655SRamiro Leal-Cavazos   Value half = createFloatConst(loc, opType, 0.5, b);
37344baa655SRamiro Leal-Cavazos   Value c23 = createIntConst(loc, i32Ty, 23, b);
37444baa655SRamiro Leal-Cavazos   Value c127 = createIntConst(loc, i32Ty, 127, b);
37544baa655SRamiro Leal-Cavazos   Value expMask = createIntConst(loc, i32Ty, (1 << 8) - 1, b);
37644baa655SRamiro Leal-Cavazos 
37744baa655SRamiro Leal-Cavazos   Value incrValue = b.create<math::CopySignOp>(half, operand);
378be911578SBalaji V. Iyer   Value add = b.create<arith::AddFOp>(opType, operand, incrValue);
379be911578SBalaji V. Iyer   Value fpFixedConvert = createTruncatedFPValue(add, b);
38044baa655SRamiro Leal-Cavazos 
38144baa655SRamiro Leal-Cavazos   // There are three cases where adding 0.5 to the value and truncating by
38244baa655SRamiro Leal-Cavazos   // converting to an i64 does not result in the correct behavior:
38344baa655SRamiro Leal-Cavazos   //
38444baa655SRamiro Leal-Cavazos   // 1. Special values: +-inf and +-nan
38544baa655SRamiro Leal-Cavazos   //     Casting these special values to i64 has undefined behavior. To identify
38644baa655SRamiro Leal-Cavazos   //     these values, we use the fact that these values are the only float
38744baa655SRamiro Leal-Cavazos   //     values with the maximum possible biased exponent.
38844baa655SRamiro Leal-Cavazos   //
38944baa655SRamiro Leal-Cavazos   // 2. Large values: 2^23 <= |x| <= INT_64_MAX
39044baa655SRamiro Leal-Cavazos   //     Adding 0.5 to a float larger than or equal to 2^23 results in precision
39144baa655SRamiro Leal-Cavazos   //     errors that sometimes round the value up and sometimes round the value
39244baa655SRamiro Leal-Cavazos   //     down. For example:
39344baa655SRamiro Leal-Cavazos   //         8388608.0 + 0.5 = 8388608.0
39444baa655SRamiro Leal-Cavazos   //         8388609.0 + 0.5 = 8388610.0
39544baa655SRamiro Leal-Cavazos   //
39644baa655SRamiro Leal-Cavazos   // 3. Very large values: |x| > INT_64_MAX
39744baa655SRamiro Leal-Cavazos   //     Casting to i64 a value greater than the max i64 value will overflow the
39844baa655SRamiro Leal-Cavazos   //     i64 leading to wrong outputs.
39944baa655SRamiro Leal-Cavazos   //
40044baa655SRamiro Leal-Cavazos   // All three cases satisfy the property `biasedExp >= 23`.
40144baa655SRamiro Leal-Cavazos   Value operandBitcast = b.create<arith::BitcastOp>(i32Ty, operand);
40244baa655SRamiro Leal-Cavazos   Value operandExp = b.create<arith::AndIOp>(
40344baa655SRamiro Leal-Cavazos       b.create<arith::ShRUIOp>(operandBitcast, c23), expMask);
40444baa655SRamiro Leal-Cavazos   Value operandBiasedExp = b.create<arith::SubIOp>(operandExp, c127);
40544baa655SRamiro Leal-Cavazos   Value isSpecialValOrLargeVal =
40644baa655SRamiro Leal-Cavazos       b.create<arith::CmpIOp>(arith::CmpIPredicate::sge, operandBiasedExp, c23);
40744baa655SRamiro Leal-Cavazos 
40844baa655SRamiro Leal-Cavazos   Value result = b.create<arith::SelectOp>(isSpecialValOrLargeVal, operand,
40944baa655SRamiro Leal-Cavazos                                            fpFixedConvert);
41044baa655SRamiro Leal-Cavazos   rewriter.replaceOp(op, result);
411be911578SBalaji V. Iyer   return success();
412be911578SBalaji V. Iyer }
413be911578SBalaji V. Iyer 
414711c5893SRobert Suderman // Converts math.ctlz to scf and arith operations. This is done
415711c5893SRobert Suderman // by performing a binary search on the bits.
416f3bdb56dSRob Suderman static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
417f3bdb56dSRob Suderman                                    PatternRewriter &rewriter) {
418f3bdb56dSRob Suderman   auto operand = op.getOperand();
419711c5893SRobert Suderman   auto operandTy = operand.getType();
420711c5893SRobert Suderman   auto eTy = getElementTypeOrSelf(operandTy);
421f3bdb56dSRob Suderman   Location loc = op.getLoc();
422f3bdb56dSRob Suderman 
423711c5893SRobert Suderman   int32_t bitwidth = eTy.getIntOrFloatBitWidth();
424711c5893SRobert Suderman   if (bitwidth > 64)
425711c5893SRobert Suderman     return failure();
426f3bdb56dSRob Suderman 
427711c5893SRobert Suderman   uint64_t allbits = -1;
428711c5893SRobert Suderman   if (bitwidth < 64) {
429711c5893SRobert Suderman     allbits = allbits >> (64 - bitwidth);
430711c5893SRobert Suderman   }
431f3bdb56dSRob Suderman 
432711c5893SRobert Suderman   Value x = operand;
433711c5893SRobert Suderman   Value count = createIntConst(loc, operandTy, 0, rewriter);
434711c5893SRobert Suderman   for (int32_t bw = bitwidth; bw > 1; bw = bw / 2) {
435711c5893SRobert Suderman     auto half = bw / 2;
436711c5893SRobert Suderman     auto bits = createIntConst(loc, operandTy, half, rewriter);
437711c5893SRobert Suderman     auto mask = createIntConst(loc, operandTy, allbits >> half, rewriter);
438f3bdb56dSRob Suderman 
439711c5893SRobert Suderman     Value pred =
440711c5893SRobert Suderman         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule, x, mask);
441711c5893SRobert Suderman     Value add = rewriter.create<arith::AddIOp>(loc, count, bits);
442711c5893SRobert Suderman     Value shift = rewriter.create<arith::ShLIOp>(loc, x, bits);
443f3bdb56dSRob Suderman 
444711c5893SRobert Suderman     x = rewriter.create<arith::SelectOp>(loc, pred, shift, x);
445711c5893SRobert Suderman     count = rewriter.create<arith::SelectOp>(loc, pred, add, count);
446711c5893SRobert Suderman   }
447f3bdb56dSRob Suderman 
448711c5893SRobert Suderman   Value zero = createIntConst(loc, operandTy, 0, rewriter);
449711c5893SRobert Suderman   Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
450711c5893SRobert Suderman                                               operand, zero);
451f3bdb56dSRob Suderman 
452711c5893SRobert Suderman   Value bwval = createIntConst(loc, operandTy, bitwidth, rewriter);
453711c5893SRobert Suderman   Value sel = rewriter.create<arith::SelectOp>(loc, pred, bwval, count);
454711c5893SRobert Suderman   rewriter.replaceOp(op, sel);
455f3bdb56dSRob Suderman   return success();
456f3bdb56dSRob Suderman }
457f3bdb56dSRob Suderman 
45844baa655SRamiro Leal-Cavazos // Convert `math.roundeven` into `math.round` + arith ops
45944baa655SRamiro Leal-Cavazos static LogicalResult convertRoundEvenOp(math::RoundEvenOp op,
46044baa655SRamiro Leal-Cavazos                                         PatternRewriter &rewriter) {
46144baa655SRamiro Leal-Cavazos   Location loc = op.getLoc();
46244baa655SRamiro Leal-Cavazos   ImplicitLocOpBuilder b(loc, rewriter);
46344baa655SRamiro Leal-Cavazos   auto operand = op.getOperand();
46444baa655SRamiro Leal-Cavazos   Type operandTy = operand.getType();
46544baa655SRamiro Leal-Cavazos   Type resultTy = op.getType();
46644baa655SRamiro Leal-Cavazos   Type operandETy = getElementTypeOrSelf(operandTy);
46744baa655SRamiro Leal-Cavazos   Type resultETy = getElementTypeOrSelf(resultTy);
46844baa655SRamiro Leal-Cavazos 
469fe355a44SAlexander Shaposhnikov   if (!isa<FloatType>(operandETy) || !isa<FloatType>(resultETy)) {
470fe355a44SAlexander Shaposhnikov     return rewriter.notifyMatchFailure(op, "not a roundeven of f16 or f32.");
47144baa655SRamiro Leal-Cavazos   }
47244baa655SRamiro Leal-Cavazos 
473fe355a44SAlexander Shaposhnikov   Type fTy = operandTy;
474fe355a44SAlexander Shaposhnikov   Type iTy = rewriter.getIntegerType(operandETy.getIntOrFloatBitWidth());
475fe355a44SAlexander Shaposhnikov   if (auto shapedTy = dyn_cast<ShapedType>(fTy)) {
476fe355a44SAlexander Shaposhnikov     iTy = shapedTy.clone(iTy);
47744baa655SRamiro Leal-Cavazos   }
47844baa655SRamiro Leal-Cavazos 
479fe355a44SAlexander Shaposhnikov   unsigned bitWidth = operandETy.getIntOrFloatBitWidth();
480fe355a44SAlexander Shaposhnikov   // The width returned by getFPMantissaWidth includes the integer bit.
481fe355a44SAlexander Shaposhnikov   unsigned mantissaWidth =
482fe355a44SAlexander Shaposhnikov       llvm::cast<FloatType>(operandETy).getFPMantissaWidth() - 1;
483fe355a44SAlexander Shaposhnikov   unsigned exponentWidth = bitWidth - mantissaWidth - 1;
48444baa655SRamiro Leal-Cavazos 
485fe355a44SAlexander Shaposhnikov   // The names of the variables correspond to f32.
486fe355a44SAlexander Shaposhnikov   // f64: 1 bit sign | 11 bits exponent | 52 bits mantissa.
487fe355a44SAlexander Shaposhnikov   // f32: 1 bit sign | 8 bits exponent  | 23 bits mantissa.
488fe355a44SAlexander Shaposhnikov   // f16: 1 bit sign | 5 bits exponent  | 10 bits mantissa.
489fe355a44SAlexander Shaposhnikov   Value c1Float = createFloatConst(loc, fTy, 1.0, b);
490fe355a44SAlexander Shaposhnikov   Value c0 = createIntConst(loc, iTy, 0, b);
491fe355a44SAlexander Shaposhnikov   Value c1 = createIntConst(loc, iTy, 1, b);
492fe355a44SAlexander Shaposhnikov   Value cNeg1 = createIntConst(loc, iTy, -1, b);
493fe355a44SAlexander Shaposhnikov   Value c23 = createIntConst(loc, iTy, mantissaWidth, b);
494fe355a44SAlexander Shaposhnikov   Value c31 = createIntConst(loc, iTy, bitWidth - 1, b);
495fe355a44SAlexander Shaposhnikov   Value c127 = createIntConst(loc, iTy, (1ull << (exponentWidth - 1)) - 1, b);
496fe355a44SAlexander Shaposhnikov   Value c2To22 = createIntConst(loc, iTy, 1ull << (mantissaWidth - 1), b);
497fe355a44SAlexander Shaposhnikov   Value c23Mask = createIntConst(loc, iTy, (1ull << mantissaWidth) - 1, b);
498fe355a44SAlexander Shaposhnikov   Value expMask = createIntConst(loc, iTy, (1ull << exponentWidth) - 1, b);
499fe355a44SAlexander Shaposhnikov 
500fe355a44SAlexander Shaposhnikov   Value operandBitcast = b.create<arith::BitcastOp>(iTy, operand);
50144baa655SRamiro Leal-Cavazos   Value round = b.create<math::RoundOp>(operand);
502fe355a44SAlexander Shaposhnikov   Value roundBitcast = b.create<arith::BitcastOp>(iTy, round);
50344baa655SRamiro Leal-Cavazos 
50444baa655SRamiro Leal-Cavazos   // Get biased exponents for operand and round(operand)
50544baa655SRamiro Leal-Cavazos   Value operandExp = b.create<arith::AndIOp>(
50644baa655SRamiro Leal-Cavazos       b.create<arith::ShRUIOp>(operandBitcast, c23), expMask);
50744baa655SRamiro Leal-Cavazos   Value operandBiasedExp = b.create<arith::SubIOp>(operandExp, c127);
50844baa655SRamiro Leal-Cavazos   Value roundExp = b.create<arith::AndIOp>(
50944baa655SRamiro Leal-Cavazos       b.create<arith::ShRUIOp>(roundBitcast, c23), expMask);
51044baa655SRamiro Leal-Cavazos   Value roundBiasedExp = b.create<arith::SubIOp>(roundExp, c127);
51144baa655SRamiro Leal-Cavazos 
51244baa655SRamiro Leal-Cavazos   auto safeShiftRight = [&](Value x, Value shift) -> Value {
513fe355a44SAlexander Shaposhnikov     // Clamp shift to valid range [0, bitwidth - 1] to avoid undefined behavior
51444baa655SRamiro Leal-Cavazos     Value clampedShift = b.create<arith::MaxSIOp>(shift, c0);
51544baa655SRamiro Leal-Cavazos     clampedShift = b.create<arith::MinSIOp>(clampedShift, c31);
51644baa655SRamiro Leal-Cavazos     return b.create<arith::ShRUIOp>(x, clampedShift);
51744baa655SRamiro Leal-Cavazos   };
51844baa655SRamiro Leal-Cavazos 
51944baa655SRamiro Leal-Cavazos   auto maskMantissa = [&](Value mantissa,
52044baa655SRamiro Leal-Cavazos                           Value mantissaMaskRightShift) -> Value {
52144baa655SRamiro Leal-Cavazos     Value shiftedMantissaMask = safeShiftRight(c23Mask, mantissaMaskRightShift);
52244baa655SRamiro Leal-Cavazos     return b.create<arith::AndIOp>(mantissa, shiftedMantissaMask);
52344baa655SRamiro Leal-Cavazos   };
52444baa655SRamiro Leal-Cavazos 
52544baa655SRamiro Leal-Cavazos   // A whole number `x`, such that `|x| != 1`, is even if the mantissa, ignoring
52644baa655SRamiro Leal-Cavazos   // the leftmost `clamp(biasedExp - 1, 0, 23)` bits, is zero. Large numbers
52744baa655SRamiro Leal-Cavazos   // with `biasedExp > 23` (numbers where there is not enough precision to store
52844baa655SRamiro Leal-Cavazos   // decimals) are always even, and they satisfy the even condition trivially
52944baa655SRamiro Leal-Cavazos   // since the mantissa without all its bits is zero. The even condition
53044baa655SRamiro Leal-Cavazos   // is also true for +-0, since they have `biasedExp = -127` and the entire
53144baa655SRamiro Leal-Cavazos   // mantissa is zero. The case of +-1 has to be handled separately. Here
53244baa655SRamiro Leal-Cavazos   // we identify these values by noting that +-1 are the only whole numbers with
53344baa655SRamiro Leal-Cavazos   // `biasedExp == 0`.
53444baa655SRamiro Leal-Cavazos   //
53544baa655SRamiro Leal-Cavazos   // The special values +-inf and +-nan also satisfy the same property that
53644baa655SRamiro Leal-Cavazos   // whole non-unit even numbers satisfy. In particular, the special values have
53744baa655SRamiro Leal-Cavazos   // `biasedExp > 23`, so they get treated as large numbers with no room for
53844baa655SRamiro Leal-Cavazos   // decimals, which are always even.
53944baa655SRamiro Leal-Cavazos   Value roundBiasedExpEq0 =
54044baa655SRamiro Leal-Cavazos       b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, roundBiasedExp, c0);
54144baa655SRamiro Leal-Cavazos   Value roundBiasedExpMinus1 = b.create<arith::SubIOp>(roundBiasedExp, c1);
54244baa655SRamiro Leal-Cavazos   Value roundMaskedMantissa = maskMantissa(roundBitcast, roundBiasedExpMinus1);
54344baa655SRamiro Leal-Cavazos   Value roundIsNotEvenOrSpecialVal = b.create<arith::CmpIOp>(
54444baa655SRamiro Leal-Cavazos       arith::CmpIPredicate::ne, roundMaskedMantissa, c0);
54544baa655SRamiro Leal-Cavazos   roundIsNotEvenOrSpecialVal =
54644baa655SRamiro Leal-Cavazos       b.create<arith::OrIOp>(roundIsNotEvenOrSpecialVal, roundBiasedExpEq0);
54744baa655SRamiro Leal-Cavazos 
54844baa655SRamiro Leal-Cavazos   // A value `x` with `0 <= biasedExp < 23`, is halfway between two consecutive
54944baa655SRamiro Leal-Cavazos   // integers if the bit at index `biasedExp` starting from the left in the
55044baa655SRamiro Leal-Cavazos   // mantissa is 1 and all the bits to the right are zero. Values with
55144baa655SRamiro Leal-Cavazos   // `biasedExp >= 23` don't have decimals, so they are never halfway. The
55244baa655SRamiro Leal-Cavazos   // values +-0.5 are the only halfway values that have `biasedExp == -1 < 0`,
55344baa655SRamiro Leal-Cavazos   // so these are handled separately. In particular, if `biasedExp == -1`, the
55444baa655SRamiro Leal-Cavazos   // value is halfway if the entire mantissa is zero.
55544baa655SRamiro Leal-Cavazos   Value operandBiasedExpEqNeg1 = b.create<arith::CmpIOp>(
55644baa655SRamiro Leal-Cavazos       arith::CmpIPredicate::eq, operandBiasedExp, cNeg1);
55744baa655SRamiro Leal-Cavazos   Value expectedOperandMaskedMantissa = b.create<arith::SelectOp>(
55844baa655SRamiro Leal-Cavazos       operandBiasedExpEqNeg1, c0, safeShiftRight(c2To22, operandBiasedExp));
55944baa655SRamiro Leal-Cavazos   Value operandMaskedMantissa = maskMantissa(operandBitcast, operandBiasedExp);
56044baa655SRamiro Leal-Cavazos   Value operandIsHalfway =
56144baa655SRamiro Leal-Cavazos       b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, operandMaskedMantissa,
56244baa655SRamiro Leal-Cavazos                               expectedOperandMaskedMantissa);
56344baa655SRamiro Leal-Cavazos   // Ensure `biasedExp` is in the valid range for half values.
56444baa655SRamiro Leal-Cavazos   Value operandBiasedExpGeNeg1 = b.create<arith::CmpIOp>(
56544baa655SRamiro Leal-Cavazos       arith::CmpIPredicate::sge, operandBiasedExp, cNeg1);
56644baa655SRamiro Leal-Cavazos   Value operandBiasedExpLt23 =
56744baa655SRamiro Leal-Cavazos       b.create<arith::CmpIOp>(arith::CmpIPredicate::slt, operandBiasedExp, c23);
56844baa655SRamiro Leal-Cavazos   operandIsHalfway =
56944baa655SRamiro Leal-Cavazos       b.create<arith::AndIOp>(operandIsHalfway, operandBiasedExpLt23);
57044baa655SRamiro Leal-Cavazos   operandIsHalfway =
57144baa655SRamiro Leal-Cavazos       b.create<arith::AndIOp>(operandIsHalfway, operandBiasedExpGeNeg1);
57244baa655SRamiro Leal-Cavazos 
57344baa655SRamiro Leal-Cavazos   // Adjust rounded operand with `round(operand) - sign(operand)` to correct the
57444baa655SRamiro Leal-Cavazos   // case where `round` rounded in the opposite direction of `roundeven`.
57544baa655SRamiro Leal-Cavazos   Value sign = b.create<math::CopySignOp>(c1Float, operand);
57644baa655SRamiro Leal-Cavazos   Value roundShifted = b.create<arith::SubFOp>(round, sign);
57744baa655SRamiro Leal-Cavazos   // If the rounded value is even or a special value, we default to the behavior
57844baa655SRamiro Leal-Cavazos   // of `math.round`.
57944baa655SRamiro Leal-Cavazos   Value needsShift =
58044baa655SRamiro Leal-Cavazos       b.create<arith::AndIOp>(roundIsNotEvenOrSpecialVal, operandIsHalfway);
58144baa655SRamiro Leal-Cavazos   Value result = b.create<arith::SelectOp>(needsShift, roundShifted, round);
58244baa655SRamiro Leal-Cavazos   // The `x - sign` adjustment does not preserve the sign when we are adjusting
58344baa655SRamiro Leal-Cavazos   // the value -1 to -0. So here the sign is copied again to ensure that -0.5 is
58444baa655SRamiro Leal-Cavazos   // rounded to -0.0.
58544baa655SRamiro Leal-Cavazos   result = b.create<math::CopySignOp>(result, operand);
58644baa655SRamiro Leal-Cavazos   rewriter.replaceOp(op, result);
58744baa655SRamiro Leal-Cavazos   return success();
58844baa655SRamiro Leal-Cavazos }
58944baa655SRamiro Leal-Cavazos 
590279a659eSCorentin Ferry // Convert `math.rsqrt` into `arith.divf` + `math.sqrt`
591279a659eSCorentin Ferry static LogicalResult convertRsqrtOp(math::RsqrtOp op,
592279a659eSCorentin Ferry                                     PatternRewriter &rewriter) {
593279a659eSCorentin Ferry 
594279a659eSCorentin Ferry   auto operand = op.getOperand();
595279a659eSCorentin Ferry   auto operandTy = operand.getType();
596279a659eSCorentin Ferry   auto eTy = getElementTypeOrSelf(operandTy);
597279a659eSCorentin Ferry   if (!isa<FloatType>(eTy))
598279a659eSCorentin Ferry     return failure();
599279a659eSCorentin Ferry 
600279a659eSCorentin Ferry   Location loc = op->getLoc();
601279a659eSCorentin Ferry   auto constOneFloat = createFloatConst(loc, operandTy, 1.0, rewriter);
602279a659eSCorentin Ferry   auto sqrtOp = rewriter.create<math::SqrtOp>(loc, operand);
603279a659eSCorentin Ferry   rewriter.replaceOpWithNewOp<arith::DivFOp>(op, constOneFloat, sqrtOp);
604279a659eSCorentin Ferry   return success();
605279a659eSCorentin Ferry }
606279a659eSCorentin Ferry 
607f3bdb56dSRob Suderman void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) {
608f3bdb56dSRob Suderman   patterns.add(convertCtlzOp);
609f3bdb56dSRob Suderman }
610f3bdb56dSRob Suderman 
611aa165edcSRob Suderman void mlir::populateExpandSinhPattern(RewritePatternSet &patterns) {
612aa165edcSRob Suderman   patterns.add(convertSinhOp);
613aa165edcSRob Suderman }
614aa165edcSRob Suderman 
615aa165edcSRob Suderman void mlir::populateExpandCoshPattern(RewritePatternSet &patterns) {
616aa165edcSRob Suderman   patterns.add(convertCoshOp);
617aa165edcSRob Suderman }
618aa165edcSRob Suderman 
619740e2e90SRobert Suderman void mlir::populateExpandTanPattern(RewritePatternSet &patterns) {
620740e2e90SRobert Suderman   patterns.add(convertTanOp);
621740e2e90SRobert Suderman }
622740e2e90SRobert Suderman 
623f3bdb56dSRob Suderman void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) {
624f3bdb56dSRob Suderman   patterns.add(convertTanhOp);
625f3bdb56dSRob Suderman }
626a7c2102dSBalaji V. Iyer 
627a62a7024Sjinchen void mlir::populateExpandAsinhPattern(RewritePatternSet &patterns) {
628a62a7024Sjinchen   patterns.add(convertAsinhOp);
629a62a7024Sjinchen }
630a62a7024Sjinchen 
631a62a7024Sjinchen void mlir::populateExpandAcoshPattern(RewritePatternSet &patterns) {
632a62a7024Sjinchen   patterns.add(convertAcoshOp);
633a62a7024Sjinchen }
634a62a7024Sjinchen 
635a62a7024Sjinchen void mlir::populateExpandAtanhPattern(RewritePatternSet &patterns) {
636a62a7024Sjinchen   patterns.add(convertAtanhOp);
637a62a7024Sjinchen }
638a62a7024Sjinchen 
639a7c2102dSBalaji V. Iyer void mlir::populateExpandFmaFPattern(RewritePatternSet &patterns) {
640a7c2102dSBalaji V. Iyer   patterns.add(convertFmaFOp);
641a7c2102dSBalaji V. Iyer }
6422217888dSBalaji V. Iyer 
6432217888dSBalaji V. Iyer void mlir::populateExpandCeilFPattern(RewritePatternSet &patterns) {
6442217888dSBalaji V. Iyer   patterns.add(convertCeilOp);
6452217888dSBalaji V. Iyer }
6462217888dSBalaji V. Iyer 
6474da96515SBalaji V. Iyer void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) {
6484da96515SBalaji V. Iyer   patterns.add(convertExp2fOp);
6494da96515SBalaji V. Iyer }
6504da96515SBalaji V. Iyer 
6512d4e8567SBalaji V. Iyer void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
6522d4e8567SBalaji V. Iyer   patterns.add(convertPowfOp);
6532d4e8567SBalaji V. Iyer }
6542d4e8567SBalaji V. Iyer 
65510a57f3aSPrashant Kumar void mlir::populateExpandFPowIPattern(RewritePatternSet &patterns) {
6565b702be1SPrashant Kumar   patterns.add(convertFPowIOp);
65710a57f3aSPrashant Kumar }
65810a57f3aSPrashant Kumar 
659be911578SBalaji V. Iyer void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) {
660be911578SBalaji V. Iyer   patterns.add(convertRoundOp);
661be911578SBalaji V. Iyer }
662be911578SBalaji V. Iyer 
66344baa655SRamiro Leal-Cavazos void mlir::populateExpandRoundEvenPattern(RewritePatternSet &patterns) {
66444baa655SRamiro Leal-Cavazos   patterns.add(convertRoundEvenOp);
66544baa655SRamiro Leal-Cavazos }
666279a659eSCorentin Ferry 
667279a659eSCorentin Ferry void mlir::populateExpandRsqrtPattern(RewritePatternSet &patterns) {
668279a659eSCorentin Ferry   patterns.add(convertRsqrtOp);
669279a659eSCorentin Ferry }
670