xref: /llvm-project/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp (revision a62a7024164c2977cd0e77f77807f957802d204a)
1 //===- ExpandPatterns.cpp - Code to expand various math operations. -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements expansion of various math operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Arith/IR/Arith.h"
14 #include "mlir/Dialect/Math/IR/Math.h"
15 #include "mlir/Dialect/Math/Transforms/Passes.h"
16 #include "mlir/Dialect/SCF/IR/SCF.h"
17 #include "mlir/Dialect/Vector/IR/VectorOps.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/ImplicitLocOpBuilder.h"
20 #include "mlir/IR/TypeUtilities.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 
23 using namespace mlir;
24 
25 /// Create a float constant.
26 static Value createFloatConst(Location loc, Type type, APFloat value,
27                               OpBuilder &b) {
28   bool losesInfo = false;
29   auto eltType = getElementTypeOrSelf(type);
30   // Convert double to the given `FloatType` with round-to-nearest-ties-to-even.
31   value.convert(cast<FloatType>(eltType).getFloatSemantics(),
32                 APFloat::rmNearestTiesToEven, &losesInfo);
33   auto attr = b.getFloatAttr(eltType, value);
34   if (auto shapedTy = dyn_cast<ShapedType>(type)) {
35     return b.create<arith::ConstantOp>(loc,
36                                        DenseElementsAttr::get(shapedTy, attr));
37   }
38 
39   return b.create<arith::ConstantOp>(loc, attr);
40 }
41 
42 static Value createFloatConst(Location loc, Type type, double value,
43                               OpBuilder &b) {
44   return createFloatConst(loc, type, APFloat(value), b);
45 }
46 
47 /// Create an integer constant.
48 static Value createIntConst(Location loc, Type type, int64_t value,
49                             OpBuilder &b) {
50   auto attr = b.getIntegerAttr(getElementTypeOrSelf(type), value);
51   if (auto shapedTy = dyn_cast<ShapedType>(type)) {
52     return b.create<arith::ConstantOp>(loc,
53                                        DenseElementsAttr::get(shapedTy, attr));
54   }
55 
56   return b.create<arith::ConstantOp>(loc, attr);
57 }
58 
59 static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b) {
60   Type opType = operand.getType();
61   Type i64Ty = b.getI64Type();
62   if (auto shapedTy = dyn_cast<ShapedType>(opType))
63     i64Ty = shapedTy.clone(i64Ty);
64   Value fixedConvert = b.create<arith::FPToSIOp>(i64Ty, operand);
65   Value fpFixedConvert = b.create<arith::SIToFPOp>(opType, fixedConvert);
66   // The truncation does not preserve the sign when the truncated
67   // value is -0. So here the sign is copied again.
68   return b.create<math::CopySignOp>(fpFixedConvert, operand);
69 }
70 
71 // sinhf(float x) -> (exp(x) - exp(-x)) / 2
72 static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter) {
73   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
74   Value operand = op.getOperand();
75   Type opType = operand.getType();
76 
77   Value exp = b.create<math::ExpOp>(operand);
78   Value neg = b.create<arith::NegFOp>(operand);
79   Value nexp = b.create<math::ExpOp>(neg);
80   Value sub = b.create<arith::SubFOp>(exp, nexp);
81   Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
82   Value res = b.create<arith::MulFOp>(sub, half);
83   rewriter.replaceOp(op, res);
84   return success();
85 }
86 
87 // coshf(float x) -> (exp(x) + exp(-x)) / 2
88 static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) {
89   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
90   Value operand = op.getOperand();
91   Type opType = operand.getType();
92 
93   Value exp = b.create<math::ExpOp>(operand);
94   Value neg = b.create<arith::NegFOp>(operand);
95   Value nexp = b.create<math::ExpOp>(neg);
96   Value add = b.create<arith::AddFOp>(exp, nexp);
97   Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
98   Value res = b.create<arith::MulFOp>(add, half);
99   rewriter.replaceOp(op, res);
100   return success();
101 }
102 
103 /// Expands tanh op into
104 /// 1-exp^{-2x} / 1+exp^{-2x}
105 /// To avoid overflow we exploit the reflection symmetry `tanh(-x) = -tanh(x)`.
106 /// We compute a "signs" value which is -1 if input is negative and +1 if input
107 /// is positive.  Then multiply the input by this value, guaranteeing that the
108 /// result is positive, which also guarantees `exp^{-2x * sign(x)}` is in (0,
109 /// 1]. Expand the computation on the input `x * sign(x)`, then multiply the
110 /// result by `sign(x)` to retain sign of the real result.
111 static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
112   auto floatType = op.getOperand().getType();
113   Location loc = op.getLoc();
114   Value zero = createFloatConst(loc, floatType, 0.0, rewriter);
115   Value one = createFloatConst(loc, floatType, 1.0, rewriter);
116   Value negTwo = createFloatConst(loc, floatType, -2.0, rewriter);
117 
118   // Compute sign(x) = cast<float_type>(x < 0) * (-2) + 1
119   Value isNegative = rewriter.create<arith::CmpFOp>(
120       loc, arith::CmpFPredicate::OLT, op.getOperand(), zero);
121   Value isNegativeFloat =
122       rewriter.create<arith::UIToFPOp>(loc, floatType, isNegative);
123   Value isNegativeTimesNegTwo =
124       rewriter.create<arith::MulFOp>(loc, isNegativeFloat, negTwo);
125   Value sign = rewriter.create<arith::AddFOp>(loc, isNegativeTimesNegTwo, one);
126 
127   // Normalize input to positive value: y = sign(x) * x
128   Value positiveX = rewriter.create<arith::MulFOp>(loc, sign, op.getOperand());
129 
130   // Decompose on normalized input
131   Value negDoubledX = rewriter.create<arith::MulFOp>(loc, negTwo, positiveX);
132   Value exp2x = rewriter.create<math::ExpOp>(loc, negDoubledX);
133   Value dividend = rewriter.create<arith::SubFOp>(loc, one, exp2x);
134   Value divisor = rewriter.create<arith::AddFOp>(loc, one, exp2x);
135   Value positiveRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor);
136 
137   // Multiply result by sign(x) to retain signs from negative inputs
138   rewriter.replaceOpWithNewOp<arith::MulFOp>(op, sign, positiveRes);
139 
140   return success();
141 }
142 
143 // Converts math.tan to math.sin, math.cos, and arith.divf.
144 static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) {
145   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
146   Value operand = op.getOperand();
147   Type type = operand.getType();
148   Value sin = b.create<math::SinOp>(type, operand);
149   Value cos = b.create<math::CosOp>(type, operand);
150   Value div = b.create<arith::DivFOp>(type, sin, cos);
151   rewriter.replaceOp(op, div);
152   return success();
153 }
154 
155 // asinh(float x) -> log(x + sqrt(x**2 + 1))
156 static LogicalResult convertAsinhOp(math::AsinhOp op,
157                                     PatternRewriter &rewriter) {
158   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
159   Value operand = op.getOperand();
160   Type opType = operand.getType();
161 
162   Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
163   Value fma = b.create<math::FmaOp>(operand, operand, one);
164   Value sqrt = b.create<math::SqrtOp>(fma);
165   Value add = b.create<arith::AddFOp>(operand, sqrt);
166   Value res = b.create<math::LogOp>(add);
167   rewriter.replaceOp(op, res);
168   return success();
169 }
170 
171 // acosh(float x) -> log(x + sqrt(x**2 - 1))
172 static LogicalResult convertAcoshOp(math::AcoshOp op,
173                                     PatternRewriter &rewriter) {
174   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
175   Value operand = op.getOperand();
176   Type opType = operand.getType();
177 
178   Value negOne = createFloatConst(op->getLoc(), opType, -1.0, rewriter);
179   Value fma = b.create<math::FmaOp>(operand, operand, negOne);
180   Value sqrt = b.create<math::SqrtOp>(fma);
181   Value add = b.create<arith::AddFOp>(operand, sqrt);
182   Value res = b.create<math::LogOp>(add);
183   rewriter.replaceOp(op, res);
184   return success();
185 }
186 
187 // atanh(float x) -> log((1 + x) / (1 - x)) / 2
188 static LogicalResult convertAtanhOp(math::AtanhOp op,
189                                     PatternRewriter &rewriter) {
190   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
191   Value operand = op.getOperand();
192   Type opType = operand.getType();
193 
194   Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
195   Value add = b.create<arith::AddFOp>(operand, one);
196   Value neg = b.create<arith::NegFOp>(operand);
197   Value sub = b.create<arith::AddFOp>(neg, one);
198   Value div = b.create<arith::DivFOp>(add, sub);
199   Value log = b.create<math::LogOp>(div);
200   Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
201   Value res = b.create<arith::MulFOp>(log, half);
202   rewriter.replaceOp(op, res);
203   return success();
204 }
205 
206 static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
207   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
208   Value operandA = op.getOperand(0);
209   Value operandB = op.getOperand(1);
210   Value operandC = op.getOperand(2);
211   Type type = op.getType();
212   Value mult = b.create<arith::MulFOp>(type, operandA, operandB);
213   Value add = b.create<arith::AddFOp>(type, mult, operandC);
214   rewriter.replaceOp(op, add);
215   return success();
216 }
217 
218 // Converts a floorf() function to the following:
219 // floorf(float x) ->
220 //     y = (float)(int) x
221 //     if (x < 0) then incr = -1 else incr = 0
222 //     y = y + incr    <= replace this op with the floorf op.
223 static LogicalResult convertFloorOp(math::FloorOp op,
224                                     PatternRewriter &rewriter) {
225   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
226   Value operand = op.getOperand();
227   Type opType = operand.getType();
228   Value fpFixedConvert = createTruncatedFPValue(operand, b);
229 
230   // Creating constants for later use.
231   Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
232   Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
233 
234   Value negCheck =
235       b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero);
236   Value incrValue =
237       b.create<arith::SelectOp>(op->getLoc(), negCheck, negOne, zero);
238   Value ret = b.create<arith::AddFOp>(opType, fpFixedConvert, incrValue);
239   rewriter.replaceOp(op, ret);
240   return success();
241 }
242 
243 // Converts a ceilf() function to the following:
244 // ceilf(float x) ->
245 //      y = (float)(int) x
246 //      if (x > y) then incr = 1 else incr = 0
247 //      y = y + incr   <= replace this op with the ceilf op.
248 static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
249   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
250   Value operand = op.getOperand();
251   Type opType = operand.getType();
252   Value fpFixedConvert = createTruncatedFPValue(operand, b);
253 
254   // Creating constants for later use.
255   Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
256   Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
257 
258   Value gtCheck = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, operand,
259                                           fpFixedConvert);
260   Value incrValue = b.create<arith::SelectOp>(op->getLoc(), gtCheck, one, zero);
261 
262   Value ret = b.create<arith::AddFOp>(opType, fpFixedConvert, incrValue);
263   rewriter.replaceOp(op, ret);
264   return success();
265 }
266 
267 // Convert `math.fpowi` to a series of `arith.mulf` operations.
268 // If the power is negative, we divide one by the result.
269 // If both the base and power are zero, the result is 1.
270 // In the case of non constant power, we convert the operation to `math.powf`.
271 static LogicalResult convertFPowIOp(math::FPowIOp op,
272                                     PatternRewriter &rewriter) {
273   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
274   Value base = op.getOperand(0);
275   Value power = op.getOperand(1);
276   Type baseType = base.getType();
277 
278   auto convertFPowItoPowf = [&]() -> LogicalResult {
279     Value castPowerToFp =
280         rewriter.create<arith::SIToFPOp>(op.getLoc(), baseType, power);
281     Value res = rewriter.create<math::PowFOp>(op.getLoc(), baseType, base,
282                                               castPowerToFp);
283     rewriter.replaceOp(op, res);
284     return success();
285   };
286 
287   Attribute cstAttr;
288   if (!matchPattern(power, m_Constant(&cstAttr)))
289     return convertFPowItoPowf();
290 
291   APInt value;
292   if (!matchPattern(cstAttr, m_ConstantInt(&value)))
293     return convertFPowItoPowf();
294 
295   int64_t powerInt = value.getSExtValue();
296   bool isNegative = powerInt < 0;
297   int64_t absPower = std::abs(powerInt);
298   Value one = createFloatConst(op->getLoc(), baseType, 1.00, rewriter);
299   Value res = createFloatConst(op->getLoc(), baseType, 1.00, rewriter);
300 
301   while (absPower > 0) {
302     if (absPower & 1)
303       res = b.create<arith::MulFOp>(baseType, base, res);
304     absPower >>= 1;
305     base = b.create<arith::MulFOp>(baseType, base, base);
306   }
307 
308   // Make sure not to introduce UB in case of negative power.
309   if (isNegative) {
310     auto &sem = dyn_cast<mlir::FloatType>(getElementTypeOrSelf(baseType))
311                     .getFloatSemantics();
312     Value zero =
313         createFloatConst(op->getLoc(), baseType,
314                          APFloat::getZero(sem, /*Negative=*/false), rewriter);
315     Value negZero =
316         createFloatConst(op->getLoc(), baseType,
317                          APFloat::getZero(sem, /*Negative=*/true), rewriter);
318     Value posInfinity =
319         createFloatConst(op->getLoc(), baseType,
320                          APFloat::getInf(sem, /*Negative=*/false), rewriter);
321     Value negInfinity =
322         createFloatConst(op->getLoc(), baseType,
323                          APFloat::getInf(sem, /*Negative=*/true), rewriter);
324     Value zeroEqCheck =
325         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, zero);
326     Value negZeroEqCheck =
327         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, negZero);
328     res = b.create<arith::DivFOp>(baseType, one, res);
329     res =
330         b.create<arith::SelectOp>(op->getLoc(), zeroEqCheck, posInfinity, res);
331     res = b.create<arith::SelectOp>(op->getLoc(), negZeroEqCheck, negInfinity,
332                                     res);
333   }
334 
335   rewriter.replaceOp(op, res);
336   return success();
337 }
338 
339 // Converts  Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
340 static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
341   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
342   Value operandA = op.getOperand(0);
343   Value operandB = op.getOperand(1);
344   Type opType = operandA.getType();
345   Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
346   Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
347   Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
348   Value opASquared = b.create<arith::MulFOp>(opType, operandA, operandA);
349   Value opBHalf = b.create<arith::DivFOp>(opType, operandB, two);
350 
351   Value logA = b.create<math::LogOp>(opType, opASquared);
352   Value mult = b.create<arith::MulFOp>(opType, opBHalf, logA);
353   Value expResult = b.create<math::ExpOp>(opType, mult);
354   Value negExpResult = b.create<arith::MulFOp>(opType, expResult, negOne);
355   Value remainder = b.create<arith::RemFOp>(opType, operandB, two);
356   Value negCheck =
357       b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
358   Value oddPower =
359       b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero);
360   Value oddAndNeg = b.create<arith::AndIOp>(op->getLoc(), oddPower, negCheck);
361 
362   Value res = b.create<arith::SelectOp>(op->getLoc(), oddAndNeg, negExpResult,
363                                         expResult);
364   rewriter.replaceOp(op, res);
365   return success();
366 }
367 
368 // exp2f(float x) -> exp(x * ln(2))
369 //   Proof: Let's say 2^x = y
370 //   ln(2^x) = ln(y)
371 //   x * ln(2) = ln(y) => e ^(x*ln(2)) = y
372 static LogicalResult convertExp2fOp(math::Exp2Op op,
373                                     PatternRewriter &rewriter) {
374   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
375   Value operand = op.getOperand();
376   Type opType = operand.getType();
377   Value ln2 = createFloatConst(op->getLoc(), opType, llvm::numbers::ln2, b);
378   Value mult = b.create<arith::MulFOp>(opType, operand, ln2);
379   Value exp = b.create<math::ExpOp>(op->getLoc(), mult);
380   rewriter.replaceOp(op, exp);
381   return success();
382 }
383 
384 static LogicalResult convertRoundOp(math::RoundOp op,
385                                     PatternRewriter &rewriter) {
386   Location loc = op.getLoc();
387   ImplicitLocOpBuilder b(loc, rewriter);
388   Value operand = op.getOperand();
389   Type opType = operand.getType();
390   Type opEType = getElementTypeOrSelf(opType);
391 
392   if (!opEType.isF32()) {
393     return rewriter.notifyMatchFailure(op, "not a round of f32.");
394   }
395 
396   Type i32Ty = b.getI32Type();
397   if (auto shapedTy = dyn_cast<ShapedType>(opType))
398     i32Ty = shapedTy.clone(i32Ty);
399 
400   Value half = createFloatConst(loc, opType, 0.5, b);
401   Value c23 = createIntConst(loc, i32Ty, 23, b);
402   Value c127 = createIntConst(loc, i32Ty, 127, b);
403   Value expMask = createIntConst(loc, i32Ty, (1 << 8) - 1, b);
404 
405   Value incrValue = b.create<math::CopySignOp>(half, operand);
406   Value add = b.create<arith::AddFOp>(opType, operand, incrValue);
407   Value fpFixedConvert = createTruncatedFPValue(add, b);
408 
409   // There are three cases where adding 0.5 to the value and truncating by
410   // converting to an i64 does not result in the correct behavior:
411   //
412   // 1. Special values: +-inf and +-nan
413   //     Casting these special values to i64 has undefined behavior. To identify
414   //     these values, we use the fact that these values are the only float
415   //     values with the maximum possible biased exponent.
416   //
417   // 2. Large values: 2^23 <= |x| <= INT_64_MAX
418   //     Adding 0.5 to a float larger than or equal to 2^23 results in precision
419   //     errors that sometimes round the value up and sometimes round the value
420   //     down. For example:
421   //         8388608.0 + 0.5 = 8388608.0
422   //         8388609.0 + 0.5 = 8388610.0
423   //
424   // 3. Very large values: |x| > INT_64_MAX
425   //     Casting to i64 a value greater than the max i64 value will overflow the
426   //     i64 leading to wrong outputs.
427   //
428   // All three cases satisfy the property `biasedExp >= 23`.
429   Value operandBitcast = b.create<arith::BitcastOp>(i32Ty, operand);
430   Value operandExp = b.create<arith::AndIOp>(
431       b.create<arith::ShRUIOp>(operandBitcast, c23), expMask);
432   Value operandBiasedExp = b.create<arith::SubIOp>(operandExp, c127);
433   Value isSpecialValOrLargeVal =
434       b.create<arith::CmpIOp>(arith::CmpIPredicate::sge, operandBiasedExp, c23);
435 
436   Value result = b.create<arith::SelectOp>(isSpecialValOrLargeVal, operand,
437                                            fpFixedConvert);
438   rewriter.replaceOp(op, result);
439   return success();
440 }
441 
442 // Converts math.ctlz to scf and arith operations. This is done
443 // by performing a binary search on the bits.
444 static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
445                                    PatternRewriter &rewriter) {
446   auto operand = op.getOperand();
447   auto operandTy = operand.getType();
448   auto eTy = getElementTypeOrSelf(operandTy);
449   Location loc = op.getLoc();
450 
451   int32_t bitwidth = eTy.getIntOrFloatBitWidth();
452   if (bitwidth > 64)
453     return failure();
454 
455   uint64_t allbits = -1;
456   if (bitwidth < 64) {
457     allbits = allbits >> (64 - bitwidth);
458   }
459 
460   Value x = operand;
461   Value count = createIntConst(loc, operandTy, 0, rewriter);
462   for (int32_t bw = bitwidth; bw > 1; bw = bw / 2) {
463     auto half = bw / 2;
464     auto bits = createIntConst(loc, operandTy, half, rewriter);
465     auto mask = createIntConst(loc, operandTy, allbits >> half, rewriter);
466 
467     Value pred =
468         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule, x, mask);
469     Value add = rewriter.create<arith::AddIOp>(loc, count, bits);
470     Value shift = rewriter.create<arith::ShLIOp>(loc, x, bits);
471 
472     x = rewriter.create<arith::SelectOp>(loc, pred, shift, x);
473     count = rewriter.create<arith::SelectOp>(loc, pred, add, count);
474   }
475 
476   Value zero = createIntConst(loc, operandTy, 0, rewriter);
477   Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
478                                               operand, zero);
479 
480   Value bwval = createIntConst(loc, operandTy, bitwidth, rewriter);
481   Value sel = rewriter.create<arith::SelectOp>(loc, pred, bwval, count);
482   rewriter.replaceOp(op, sel);
483   return success();
484 }
485 
486 // Convert `math.roundeven` into `math.round` + arith ops
487 static LogicalResult convertRoundEvenOp(math::RoundEvenOp op,
488                                         PatternRewriter &rewriter) {
489   Location loc = op.getLoc();
490   ImplicitLocOpBuilder b(loc, rewriter);
491   auto operand = op.getOperand();
492   Type operandTy = operand.getType();
493   Type resultTy = op.getType();
494   Type operandETy = getElementTypeOrSelf(operandTy);
495   Type resultETy = getElementTypeOrSelf(resultTy);
496 
497   if (!isa<FloatType>(operandETy) || !isa<FloatType>(resultETy)) {
498     return rewriter.notifyMatchFailure(op, "not a roundeven of f16 or f32.");
499   }
500 
501   Type fTy = operandTy;
502   Type iTy = rewriter.getIntegerType(operandETy.getIntOrFloatBitWidth());
503   if (auto shapedTy = dyn_cast<ShapedType>(fTy)) {
504     iTy = shapedTy.clone(iTy);
505   }
506 
507   unsigned bitWidth = operandETy.getIntOrFloatBitWidth();
508   // The width returned by getFPMantissaWidth includes the integer bit.
509   unsigned mantissaWidth =
510       llvm::cast<FloatType>(operandETy).getFPMantissaWidth() - 1;
511   unsigned exponentWidth = bitWidth - mantissaWidth - 1;
512 
513   // The names of the variables correspond to f32.
514   // f64: 1 bit sign | 11 bits exponent | 52 bits mantissa.
515   // f32: 1 bit sign | 8 bits exponent  | 23 bits mantissa.
516   // f16: 1 bit sign | 5 bits exponent  | 10 bits mantissa.
517   Value c1Float = createFloatConst(loc, fTy, 1.0, b);
518   Value c0 = createIntConst(loc, iTy, 0, b);
519   Value c1 = createIntConst(loc, iTy, 1, b);
520   Value cNeg1 = createIntConst(loc, iTy, -1, b);
521   Value c23 = createIntConst(loc, iTy, mantissaWidth, b);
522   Value c31 = createIntConst(loc, iTy, bitWidth - 1, b);
523   Value c127 = createIntConst(loc, iTy, (1ull << (exponentWidth - 1)) - 1, b);
524   Value c2To22 = createIntConst(loc, iTy, 1ull << (mantissaWidth - 1), b);
525   Value c23Mask = createIntConst(loc, iTy, (1ull << mantissaWidth) - 1, b);
526   Value expMask = createIntConst(loc, iTy, (1ull << exponentWidth) - 1, b);
527 
528   Value operandBitcast = b.create<arith::BitcastOp>(iTy, operand);
529   Value round = b.create<math::RoundOp>(operand);
530   Value roundBitcast = b.create<arith::BitcastOp>(iTy, round);
531 
532   // Get biased exponents for operand and round(operand)
533   Value operandExp = b.create<arith::AndIOp>(
534       b.create<arith::ShRUIOp>(operandBitcast, c23), expMask);
535   Value operandBiasedExp = b.create<arith::SubIOp>(operandExp, c127);
536   Value roundExp = b.create<arith::AndIOp>(
537       b.create<arith::ShRUIOp>(roundBitcast, c23), expMask);
538   Value roundBiasedExp = b.create<arith::SubIOp>(roundExp, c127);
539 
540   auto safeShiftRight = [&](Value x, Value shift) -> Value {
541     // Clamp shift to valid range [0, bitwidth - 1] to avoid undefined behavior
542     Value clampedShift = b.create<arith::MaxSIOp>(shift, c0);
543     clampedShift = b.create<arith::MinSIOp>(clampedShift, c31);
544     return b.create<arith::ShRUIOp>(x, clampedShift);
545   };
546 
547   auto maskMantissa = [&](Value mantissa,
548                           Value mantissaMaskRightShift) -> Value {
549     Value shiftedMantissaMask = safeShiftRight(c23Mask, mantissaMaskRightShift);
550     return b.create<arith::AndIOp>(mantissa, shiftedMantissaMask);
551   };
552 
553   // A whole number `x`, such that `|x| != 1`, is even if the mantissa, ignoring
554   // the leftmost `clamp(biasedExp - 1, 0, 23)` bits, is zero. Large numbers
555   // with `biasedExp > 23` (numbers where there is not enough precision to store
556   // decimals) are always even, and they satisfy the even condition trivially
557   // since the mantissa without all its bits is zero. The even condition
558   // is also true for +-0, since they have `biasedExp = -127` and the entire
559   // mantissa is zero. The case of +-1 has to be handled separately. Here
560   // we identify these values by noting that +-1 are the only whole numbers with
561   // `biasedExp == 0`.
562   //
563   // The special values +-inf and +-nan also satisfy the same property that
564   // whole non-unit even numbers satisfy. In particular, the special values have
565   // `biasedExp > 23`, so they get treated as large numbers with no room for
566   // decimals, which are always even.
567   Value roundBiasedExpEq0 =
568       b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, roundBiasedExp, c0);
569   Value roundBiasedExpMinus1 = b.create<arith::SubIOp>(roundBiasedExp, c1);
570   Value roundMaskedMantissa = maskMantissa(roundBitcast, roundBiasedExpMinus1);
571   Value roundIsNotEvenOrSpecialVal = b.create<arith::CmpIOp>(
572       arith::CmpIPredicate::ne, roundMaskedMantissa, c0);
573   roundIsNotEvenOrSpecialVal =
574       b.create<arith::OrIOp>(roundIsNotEvenOrSpecialVal, roundBiasedExpEq0);
575 
576   // A value `x` with `0 <= biasedExp < 23`, is halfway between two consecutive
577   // integers if the bit at index `biasedExp` starting from the left in the
578   // mantissa is 1 and all the bits to the right are zero. Values with
579   // `biasedExp >= 23` don't have decimals, so they are never halfway. The
580   // values +-0.5 are the only halfway values that have `biasedExp == -1 < 0`,
581   // so these are handled separately. In particular, if `biasedExp == -1`, the
582   // value is halfway if the entire mantissa is zero.
583   Value operandBiasedExpEqNeg1 = b.create<arith::CmpIOp>(
584       arith::CmpIPredicate::eq, operandBiasedExp, cNeg1);
585   Value expectedOperandMaskedMantissa = b.create<arith::SelectOp>(
586       operandBiasedExpEqNeg1, c0, safeShiftRight(c2To22, operandBiasedExp));
587   Value operandMaskedMantissa = maskMantissa(operandBitcast, operandBiasedExp);
588   Value operandIsHalfway =
589       b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, operandMaskedMantissa,
590                               expectedOperandMaskedMantissa);
591   // Ensure `biasedExp` is in the valid range for half values.
592   Value operandBiasedExpGeNeg1 = b.create<arith::CmpIOp>(
593       arith::CmpIPredicate::sge, operandBiasedExp, cNeg1);
594   Value operandBiasedExpLt23 =
595       b.create<arith::CmpIOp>(arith::CmpIPredicate::slt, operandBiasedExp, c23);
596   operandIsHalfway =
597       b.create<arith::AndIOp>(operandIsHalfway, operandBiasedExpLt23);
598   operandIsHalfway =
599       b.create<arith::AndIOp>(operandIsHalfway, operandBiasedExpGeNeg1);
600 
601   // Adjust rounded operand with `round(operand) - sign(operand)` to correct the
602   // case where `round` rounded in the opposite direction of `roundeven`.
603   Value sign = b.create<math::CopySignOp>(c1Float, operand);
604   Value roundShifted = b.create<arith::SubFOp>(round, sign);
605   // If the rounded value is even or a special value, we default to the behavior
606   // of `math.round`.
607   Value needsShift =
608       b.create<arith::AndIOp>(roundIsNotEvenOrSpecialVal, operandIsHalfway);
609   Value result = b.create<arith::SelectOp>(needsShift, roundShifted, round);
610   // The `x - sign` adjustment does not preserve the sign when we are adjusting
611   // the value -1 to -0. So here the sign is copied again to ensure that -0.5 is
612   // rounded to -0.0.
613   result = b.create<math::CopySignOp>(result, operand);
614   rewriter.replaceOp(op, result);
615   return success();
616 }
617 
618 void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) {
619   patterns.add(convertCtlzOp);
620 }
621 
622 void mlir::populateExpandSinhPattern(RewritePatternSet &patterns) {
623   patterns.add(convertSinhOp);
624 }
625 
626 void mlir::populateExpandCoshPattern(RewritePatternSet &patterns) {
627   patterns.add(convertCoshOp);
628 }
629 
630 void mlir::populateExpandTanPattern(RewritePatternSet &patterns) {
631   patterns.add(convertTanOp);
632 }
633 
634 void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) {
635   patterns.add(convertTanhOp);
636 }
637 
638 void mlir::populateExpandAsinhPattern(RewritePatternSet &patterns) {
639   patterns.add(convertAsinhOp);
640 }
641 
642 void mlir::populateExpandAcoshPattern(RewritePatternSet &patterns) {
643   patterns.add(convertAcoshOp);
644 }
645 
646 void mlir::populateExpandAtanhPattern(RewritePatternSet &patterns) {
647   patterns.add(convertAtanhOp);
648 }
649 
650 void mlir::populateExpandFmaFPattern(RewritePatternSet &patterns) {
651   patterns.add(convertFmaFOp);
652 }
653 
654 void mlir::populateExpandCeilFPattern(RewritePatternSet &patterns) {
655   patterns.add(convertCeilOp);
656 }
657 
658 void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) {
659   patterns.add(convertExp2fOp);
660 }
661 
662 void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
663   patterns.add(convertPowfOp);
664 }
665 
666 void mlir::populateExpandFPowIPattern(RewritePatternSet &patterns) {
667   patterns.add(convertFPowIOp);
668 }
669 
670 void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) {
671   patterns.add(convertRoundOp);
672 }
673 
674 void mlir::populateExpandFloorFPattern(RewritePatternSet &patterns) {
675   patterns.add(convertFloorOp);
676 }
677 
678 void mlir::populateExpandRoundEvenPattern(RewritePatternSet &patterns) {
679   patterns.add(convertRoundEvenOp);
680 }
681