xref: /llvm-project/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp (revision f75d164eeaf407b21ec6b2cff4bcc3ad2003af61)
//===- ExpandPatterns.cpp - Code to expand various math ops ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
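// This file implements expansions of math dialect ops (ctlz, sinh, cosh, tan,
// tanh, fma, floor, ceil, exp2, powf, round and roundeven) into simpler math
// and arith ops.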
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Arith/IR/Arith.h"
14 #include "mlir/Dialect/Math/IR/Math.h"
15 #include "mlir/Dialect/Math/Transforms/Passes.h"
16 #include "mlir/Dialect/SCF/IR/SCF.h"
17 #include "mlir/Dialect/Vector/IR/VectorOps.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/ImplicitLocOpBuilder.h"
20 #include "mlir/IR/TypeUtilities.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 
23 using namespace mlir;
24 
25 /// Create a float constant.
26 static Value createFloatConst(Location loc, Type type, double value,
27                               OpBuilder &b) {
28   auto attr = b.getFloatAttr(getElementTypeOrSelf(type), value);
29   if (auto shapedTy = dyn_cast<ShapedType>(type)) {
30     return b.create<arith::ConstantOp>(loc,
31                                        DenseElementsAttr::get(shapedTy, attr));
32   }
33 
34   return b.create<arith::ConstantOp>(loc, attr);
35 }
36 
37 /// Create a float constant.
38 static Value createIntConst(Location loc, Type type, int64_t value,
39                             OpBuilder &b) {
40   auto attr = b.getIntegerAttr(getElementTypeOrSelf(type), value);
41   if (auto shapedTy = dyn_cast<ShapedType>(type)) {
42     return b.create<arith::ConstantOp>(loc,
43                                        DenseElementsAttr::get(shapedTy, attr));
44   }
45 
46   return b.create<arith::ConstantOp>(loc, attr);
47 }
48 
49 static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b) {
50   Type opType = operand.getType();
51   Type i64Ty = b.getI64Type();
52   if (auto shapedTy = dyn_cast<ShapedType>(opType))
53     i64Ty = shapedTy.clone(i64Ty);
54   Value fixedConvert = b.create<arith::FPToSIOp>(i64Ty, operand);
55   Value fpFixedConvert = b.create<arith::SIToFPOp>(opType, fixedConvert);
56   // The truncation does not preserve the sign when the truncated
57   // value is -0. So here the sign is copied again.
58   return b.create<math::CopySignOp>(fpFixedConvert, operand);
59 }
60 
61 // sinhf(float x) -> (exp(x) - exp(-x)) / 2
62 static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter) {
63   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
64   Value operand = op.getOperand();
65   Type opType = operand.getType();
66   Value exp = b.create<math::ExpOp>(operand);
67 
68   Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
69   Value nexp = b.create<arith::DivFOp>(one, exp);
70   Value sub = b.create<arith::SubFOp>(exp, nexp);
71   Value two = createFloatConst(op->getLoc(), opType, 2.0, rewriter);
72   Value div = b.create<arith::DivFOp>(sub, two);
73   rewriter.replaceOp(op, div);
74   return success();
75 }
76 
77 // coshf(float x) -> (exp(x) + exp(-x)) / 2
78 static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) {
79   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
80   Value operand = op.getOperand();
81   Type opType = operand.getType();
82   Value exp = b.create<math::ExpOp>(operand);
83 
84   Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
85   Value nexp = b.create<arith::DivFOp>(one, exp);
86   Value add = b.create<arith::AddFOp>(exp, nexp);
87   Value two = createFloatConst(op->getLoc(), opType, 2.0, rewriter);
88   Value div = b.create<arith::DivFOp>(add, two);
89   rewriter.replaceOp(op, div);
90   return success();
91 }
92 
93 /// Expands tanh op into
94 ///   1) 1-exp^{-2x} / 1+exp^{-2x}, if x => 0
95 ///   2) exp^{2x}-1 / exp^{2x}+1  , if x < 0
96 static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
97   auto floatType = op.getOperand().getType();
98   Location loc = op.getLoc();
99   Value one = createFloatConst(loc, floatType, 1.0, rewriter);
100   Value two = createFloatConst(loc, floatType, 2.0, rewriter);
101   Value doubledX = rewriter.create<arith::MulFOp>(loc, op.getOperand(), two);
102 
103   // Case 1: tanh(x) = 1-exp^{-2x} / 1+exp^{-2x}
104   Value negDoubledX = rewriter.create<arith::NegFOp>(loc, doubledX);
105   Value exp2x = rewriter.create<math::ExpOp>(loc, negDoubledX);
106   Value dividend = rewriter.create<arith::SubFOp>(loc, one, exp2x);
107   Value divisor = rewriter.create<arith::AddFOp>(loc, one, exp2x);
108   Value positiveRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor);
109 
110   // Case 2: tanh(x) = exp^{2x}-1 / exp^{2x}+1
111   exp2x = rewriter.create<math::ExpOp>(loc, doubledX);
112   dividend = rewriter.create<arith::SubFOp>(loc, exp2x, one);
113   divisor = rewriter.create<arith::AddFOp>(loc, exp2x, one);
114   Value negativeRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor);
115 
116   // tanh(x) = x >= 0 ? positiveRes : negativeRes
117   Value zero = createFloatConst(loc, floatType, 0.0, rewriter);
118   Value cmpRes = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
119                                                 op.getOperand(), zero);
120   rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmpRes, positiveRes,
121                                                negativeRes);
122   return success();
123 }
124 
125 // Converts math.tan to math.sin, math.cos, and arith.divf.
126 static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) {
127   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
128   Value operand = op.getOperand();
129   Type type = operand.getType();
130   Value sin = b.create<math::SinOp>(type, operand);
131   Value cos = b.create<math::CosOp>(type, operand);
132   Value div = b.create<arith::DivFOp>(type, sin, cos);
133   rewriter.replaceOp(op, div);
134   return success();
135 }
136 
137 static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
138   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
139   Value operandA = op.getOperand(0);
140   Value operandB = op.getOperand(1);
141   Value operandC = op.getOperand(2);
142   Type type = op.getType();
143   Value mult = b.create<arith::MulFOp>(type, operandA, operandB);
144   Value add = b.create<arith::AddFOp>(type, mult, operandC);
145   rewriter.replaceOp(op, add);
146   return success();
147 }
148 
149 // Converts a floorf() function to the following:
150 // floorf(float x) ->
151 //     y = (float)(int) x
152 //     if (x < 0) then incr = -1 else incr = 0
153 //     y = y + incr    <= replace this op with the floorf op.
154 static LogicalResult convertFloorOp(math::FloorOp op,
155                                     PatternRewriter &rewriter) {
156   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
157   Value operand = op.getOperand();
158   Type opType = operand.getType();
159   Value fpFixedConvert = createTruncatedFPValue(operand, b);
160 
161   // Creating constants for later use.
162   Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
163   Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
164 
165   Value negCheck =
166       b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero);
167   Value incrValue =
168       b.create<arith::SelectOp>(op->getLoc(), negCheck, negOne, zero);
169   Value ret = b.create<arith::AddFOp>(opType, fpFixedConvert, incrValue);
170   rewriter.replaceOp(op, ret);
171   return success();
172 }
173 
174 // Converts a ceilf() function to the following:
175 // ceilf(float x) ->
176 //      y = (float)(int) x
177 //      if (x > y) then incr = 1 else incr = 0
178 //      y = y + incr   <= replace this op with the ceilf op.
179 static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
180   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
181   Value operand = op.getOperand();
182   Type opType = operand.getType();
183   Value fpFixedConvert = createTruncatedFPValue(operand, b);
184 
185   // Creating constants for later use.
186   Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
187   Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
188 
189   Value gtCheck = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, operand,
190                                           fpFixedConvert);
191   Value incrValue = b.create<arith::SelectOp>(op->getLoc(), gtCheck, one, zero);
192 
193   Value ret = b.create<arith::AddFOp>(opType, fpFixedConvert, incrValue);
194   rewriter.replaceOp(op, ret);
195   return success();
196 }
197 // Converts  Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
198 static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
199   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
200   Value operandA = op.getOperand(0);
201   Value operandB = op.getOperand(1);
202   Type opType = operandA.getType();
203   Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
204   Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
205   Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
206   Value opASquared = b.create<arith::MulFOp>(opType, operandA, operandA);
207   Value opBHalf = b.create<arith::DivFOp>(opType, operandB, two);
208 
209   Value logA = b.create<math::LogOp>(opType, opASquared);
210   Value mult = b.create<arith::MulFOp>(opType, opBHalf, logA);
211   Value expResult = b.create<math::ExpOp>(opType, mult);
212   Value negExpResult = b.create<arith::MulFOp>(opType, expResult, negOne);
213   Value remainder = b.create<arith::RemFOp>(opType, operandB, two);
214   Value negCheck =
215       b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
216   Value oddPower =
217       b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero);
218   Value oddAndNeg = b.create<arith::AndIOp>(op->getLoc(), oddPower, negCheck);
219 
220   Value res = b.create<arith::SelectOp>(op->getLoc(), oddAndNeg, negExpResult,
221                                         expResult);
222   rewriter.replaceOp(op, res);
223   return success();
224 }
225 
226 // exp2f(float x) -> exp(x * ln(2))
227 //   Proof: Let's say 2^x = y
228 //   ln(2^x) = ln(y)
229 //   x * ln(2) = ln(y) => e ^(x*ln(2)) = y
230 static LogicalResult convertExp2fOp(math::Exp2Op op,
231                                     PatternRewriter &rewriter) {
232   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
233   Value operand = op.getOperand();
234   Type opType = operand.getType();
235   Value ln2 = createFloatConst(op->getLoc(), opType, llvm::numbers::ln2, b);
236   Value mult = b.create<arith::MulFOp>(opType, operand, ln2);
237   Value exp = b.create<math::ExpOp>(op->getLoc(), mult);
238   rewriter.replaceOp(op, exp);
239   return success();
240 }
241 
242 static LogicalResult convertRoundOp(math::RoundOp op,
243                                     PatternRewriter &rewriter) {
244   Location loc = op.getLoc();
245   ImplicitLocOpBuilder b(loc, rewriter);
246   Value operand = op.getOperand();
247   Type opType = operand.getType();
248   Type opEType = getElementTypeOrSelf(opType);
249 
250   if (!opEType.isF32()) {
251     return rewriter.notifyMatchFailure(op, "not a round of f32.");
252   }
253 
254   Type i32Ty = b.getI32Type();
255   if (auto shapedTy = dyn_cast<ShapedType>(opType))
256     i32Ty = shapedTy.clone(i32Ty);
257 
258   Value half = createFloatConst(loc, opType, 0.5, b);
259   Value c23 = createIntConst(loc, i32Ty, 23, b);
260   Value c127 = createIntConst(loc, i32Ty, 127, b);
261   Value expMask = createIntConst(loc, i32Ty, (1 << 8) - 1, b);
262 
263   Value incrValue = b.create<math::CopySignOp>(half, operand);
264   Value add = b.create<arith::AddFOp>(opType, operand, incrValue);
265   Value fpFixedConvert = createTruncatedFPValue(add, b);
266 
267   // There are three cases where adding 0.5 to the value and truncating by
268   // converting to an i64 does not result in the correct behavior:
269   //
270   // 1. Special values: +-inf and +-nan
271   //     Casting these special values to i64 has undefined behavior. To identify
272   //     these values, we use the fact that these values are the only float
273   //     values with the maximum possible biased exponent.
274   //
275   // 2. Large values: 2^23 <= |x| <= INT_64_MAX
276   //     Adding 0.5 to a float larger than or equal to 2^23 results in precision
277   //     errors that sometimes round the value up and sometimes round the value
278   //     down. For example:
279   //         8388608.0 + 0.5 = 8388608.0
280   //         8388609.0 + 0.5 = 8388610.0
281   //
282   // 3. Very large values: |x| > INT_64_MAX
283   //     Casting to i64 a value greater than the max i64 value will overflow the
284   //     i64 leading to wrong outputs.
285   //
286   // All three cases satisfy the property `biasedExp >= 23`.
287   Value operandBitcast = b.create<arith::BitcastOp>(i32Ty, operand);
288   Value operandExp = b.create<arith::AndIOp>(
289       b.create<arith::ShRUIOp>(operandBitcast, c23), expMask);
290   Value operandBiasedExp = b.create<arith::SubIOp>(operandExp, c127);
291   Value isSpecialValOrLargeVal =
292       b.create<arith::CmpIOp>(arith::CmpIPredicate::sge, operandBiasedExp, c23);
293 
294   Value result = b.create<arith::SelectOp>(isSpecialValOrLargeVal, operand,
295                                            fpFixedConvert);
296   rewriter.replaceOp(op, result);
297   return success();
298 }
299 
300 // Converts math.ctlz to scf and arith operations. This is done
301 // by performing a binary search on the bits.
302 static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
303                                    PatternRewriter &rewriter) {
304   auto operand = op.getOperand();
305   auto operandTy = operand.getType();
306   auto eTy = getElementTypeOrSelf(operandTy);
307   Location loc = op.getLoc();
308 
309   int32_t bitwidth = eTy.getIntOrFloatBitWidth();
310   if (bitwidth > 64)
311     return failure();
312 
313   uint64_t allbits = -1;
314   if (bitwidth < 64) {
315     allbits = allbits >> (64 - bitwidth);
316   }
317 
318   Value x = operand;
319   Value count = createIntConst(loc, operandTy, 0, rewriter);
320   for (int32_t bw = bitwidth; bw > 1; bw = bw / 2) {
321     auto half = bw / 2;
322     auto bits = createIntConst(loc, operandTy, half, rewriter);
323     auto mask = createIntConst(loc, operandTy, allbits >> half, rewriter);
324 
325     Value pred =
326         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule, x, mask);
327     Value add = rewriter.create<arith::AddIOp>(loc, count, bits);
328     Value shift = rewriter.create<arith::ShLIOp>(loc, x, bits);
329 
330     x = rewriter.create<arith::SelectOp>(loc, pred, shift, x);
331     count = rewriter.create<arith::SelectOp>(loc, pred, add, count);
332   }
333 
334   Value zero = createIntConst(loc, operandTy, 0, rewriter);
335   Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
336                                               operand, zero);
337 
338   Value bwval = createIntConst(loc, operandTy, bitwidth, rewriter);
339   Value sel = rewriter.create<arith::SelectOp>(loc, pred, bwval, count);
340   rewriter.replaceOp(op, sel);
341   return success();
342 }
343 
344 // Convert `math.roundeven` into `math.round` + arith ops
345 static LogicalResult convertRoundEvenOp(math::RoundEvenOp op,
346                                         PatternRewriter &rewriter) {
347   Location loc = op.getLoc();
348   ImplicitLocOpBuilder b(loc, rewriter);
349   auto operand = op.getOperand();
350   Type operandTy = operand.getType();
351   Type resultTy = op.getType();
352   Type operandETy = getElementTypeOrSelf(operandTy);
353   Type resultETy = getElementTypeOrSelf(resultTy);
354 
355   if (!isa<FloatType>(operandETy) || !isa<FloatType>(resultETy)) {
356     return rewriter.notifyMatchFailure(op, "not a roundeven of f16 or f32.");
357   }
358 
359   Type fTy = operandTy;
360   Type iTy = rewriter.getIntegerType(operandETy.getIntOrFloatBitWidth());
361   if (auto shapedTy = dyn_cast<ShapedType>(fTy)) {
362     iTy = shapedTy.clone(iTy);
363   }
364 
365   unsigned bitWidth = operandETy.getIntOrFloatBitWidth();
366   // The width returned by getFPMantissaWidth includes the integer bit.
367   unsigned mantissaWidth =
368       llvm::cast<FloatType>(operandETy).getFPMantissaWidth() - 1;
369   unsigned exponentWidth = bitWidth - mantissaWidth - 1;
370 
371   // The names of the variables correspond to f32.
372   // f64: 1 bit sign | 11 bits exponent | 52 bits mantissa.
373   // f32: 1 bit sign | 8 bits exponent  | 23 bits mantissa.
374   // f16: 1 bit sign | 5 bits exponent  | 10 bits mantissa.
375   Value c1Float = createFloatConst(loc, fTy, 1.0, b);
376   Value c0 = createIntConst(loc, iTy, 0, b);
377   Value c1 = createIntConst(loc, iTy, 1, b);
378   Value cNeg1 = createIntConst(loc, iTy, -1, b);
379   Value c23 = createIntConst(loc, iTy, mantissaWidth, b);
380   Value c31 = createIntConst(loc, iTy, bitWidth - 1, b);
381   Value c127 = createIntConst(loc, iTy, (1ull << (exponentWidth - 1)) - 1, b);
382   Value c2To22 = createIntConst(loc, iTy, 1ull << (mantissaWidth - 1), b);
383   Value c23Mask = createIntConst(loc, iTy, (1ull << mantissaWidth) - 1, b);
384   Value expMask = createIntConst(loc, iTy, (1ull << exponentWidth) - 1, b);
385 
386   Value operandBitcast = b.create<arith::BitcastOp>(iTy, operand);
387   Value round = b.create<math::RoundOp>(operand);
388   Value roundBitcast = b.create<arith::BitcastOp>(iTy, round);
389 
390   // Get biased exponents for operand and round(operand)
391   Value operandExp = b.create<arith::AndIOp>(
392       b.create<arith::ShRUIOp>(operandBitcast, c23), expMask);
393   Value operandBiasedExp = b.create<arith::SubIOp>(operandExp, c127);
394   Value roundExp = b.create<arith::AndIOp>(
395       b.create<arith::ShRUIOp>(roundBitcast, c23), expMask);
396   Value roundBiasedExp = b.create<arith::SubIOp>(roundExp, c127);
397 
398   auto safeShiftRight = [&](Value x, Value shift) -> Value {
399     // Clamp shift to valid range [0, bitwidth - 1] to avoid undefined behavior
400     Value clampedShift = b.create<arith::MaxSIOp>(shift, c0);
401     clampedShift = b.create<arith::MinSIOp>(clampedShift, c31);
402     return b.create<arith::ShRUIOp>(x, clampedShift);
403   };
404 
405   auto maskMantissa = [&](Value mantissa,
406                           Value mantissaMaskRightShift) -> Value {
407     Value shiftedMantissaMask = safeShiftRight(c23Mask, mantissaMaskRightShift);
408     return b.create<arith::AndIOp>(mantissa, shiftedMantissaMask);
409   };
410 
411   // A whole number `x`, such that `|x| != 1`, is even if the mantissa, ignoring
412   // the leftmost `clamp(biasedExp - 1, 0, 23)` bits, is zero. Large numbers
413   // with `biasedExp > 23` (numbers where there is not enough precision to store
414   // decimals) are always even, and they satisfy the even condition trivially
415   // since the mantissa without all its bits is zero. The even condition
416   // is also true for +-0, since they have `biasedExp = -127` and the entire
417   // mantissa is zero. The case of +-1 has to be handled separately. Here
418   // we identify these values by noting that +-1 are the only whole numbers with
419   // `biasedExp == 0`.
420   //
421   // The special values +-inf and +-nan also satisfy the same property that
422   // whole non-unit even numbers satisfy. In particular, the special values have
423   // `biasedExp > 23`, so they get treated as large numbers with no room for
424   // decimals, which are always even.
425   Value roundBiasedExpEq0 =
426       b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, roundBiasedExp, c0);
427   Value roundBiasedExpMinus1 = b.create<arith::SubIOp>(roundBiasedExp, c1);
428   Value roundMaskedMantissa = maskMantissa(roundBitcast, roundBiasedExpMinus1);
429   Value roundIsNotEvenOrSpecialVal = b.create<arith::CmpIOp>(
430       arith::CmpIPredicate::ne, roundMaskedMantissa, c0);
431   roundIsNotEvenOrSpecialVal =
432       b.create<arith::OrIOp>(roundIsNotEvenOrSpecialVal, roundBiasedExpEq0);
433 
434   // A value `x` with `0 <= biasedExp < 23`, is halfway between two consecutive
435   // integers if the bit at index `biasedExp` starting from the left in the
436   // mantissa is 1 and all the bits to the right are zero. Values with
437   // `biasedExp >= 23` don't have decimals, so they are never halfway. The
438   // values +-0.5 are the only halfway values that have `biasedExp == -1 < 0`,
439   // so these are handled separately. In particular, if `biasedExp == -1`, the
440   // value is halfway if the entire mantissa is zero.
441   Value operandBiasedExpEqNeg1 = b.create<arith::CmpIOp>(
442       arith::CmpIPredicate::eq, operandBiasedExp, cNeg1);
443   Value expectedOperandMaskedMantissa = b.create<arith::SelectOp>(
444       operandBiasedExpEqNeg1, c0, safeShiftRight(c2To22, operandBiasedExp));
445   Value operandMaskedMantissa = maskMantissa(operandBitcast, operandBiasedExp);
446   Value operandIsHalfway =
447       b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, operandMaskedMantissa,
448                               expectedOperandMaskedMantissa);
449   // Ensure `biasedExp` is in the valid range for half values.
450   Value operandBiasedExpGeNeg1 = b.create<arith::CmpIOp>(
451       arith::CmpIPredicate::sge, operandBiasedExp, cNeg1);
452   Value operandBiasedExpLt23 =
453       b.create<arith::CmpIOp>(arith::CmpIPredicate::slt, operandBiasedExp, c23);
454   operandIsHalfway =
455       b.create<arith::AndIOp>(operandIsHalfway, operandBiasedExpLt23);
456   operandIsHalfway =
457       b.create<arith::AndIOp>(operandIsHalfway, operandBiasedExpGeNeg1);
458 
459   // Adjust rounded operand with `round(operand) - sign(operand)` to correct the
460   // case where `round` rounded in the opposite direction of `roundeven`.
461   Value sign = b.create<math::CopySignOp>(c1Float, operand);
462   Value roundShifted = b.create<arith::SubFOp>(round, sign);
463   // If the rounded value is even or a special value, we default to the behavior
464   // of `math.round`.
465   Value needsShift =
466       b.create<arith::AndIOp>(roundIsNotEvenOrSpecialVal, operandIsHalfway);
467   Value result = b.create<arith::SelectOp>(needsShift, roundShifted, round);
468   // The `x - sign` adjustment does not preserve the sign when we are adjusting
469   // the value -1 to -0. So here the sign is copied again to ensure that -0.5 is
470   // rounded to -0.0.
471   result = b.create<math::CopySignOp>(result, operand);
472   rewriter.replaceOp(op, result);
473   return success();
474 }
475 
476 void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) {
477   patterns.add(convertCtlzOp);
478 }
479 
480 void mlir::populateExpandSinhPattern(RewritePatternSet &patterns) {
481   patterns.add(convertSinhOp);
482 }
483 
484 void mlir::populateExpandCoshPattern(RewritePatternSet &patterns) {
485   patterns.add(convertCoshOp);
486 }
487 
488 void mlir::populateExpandTanPattern(RewritePatternSet &patterns) {
489   patterns.add(convertTanOp);
490 }
491 
492 void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) {
493   patterns.add(convertTanhOp);
494 }
495 
496 void mlir::populateExpandFmaFPattern(RewritePatternSet &patterns) {
497   patterns.add(convertFmaFOp);
498 }
499 
500 void mlir::populateExpandCeilFPattern(RewritePatternSet &patterns) {
501   patterns.add(convertCeilOp);
502 }
503 
504 void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) {
505   patterns.add(convertExp2fOp);
506 }
507 
508 void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
509   patterns.add(convertPowfOp);
510 }
511 
512 void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) {
513   patterns.add(convertRoundOp);
514 }
515 
516 void mlir::populateExpandFloorFPattern(RewritePatternSet &patterns) {
517   patterns.add(convertFloorOp);
518 }
519 
520 void mlir::populateExpandRoundEvenPattern(RewritePatternSet &patterns) {
521   patterns.add(convertRoundEvenOp);
522 }
523