//===- PolynomialApproximation.cpp - Approximate math operations ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements expansion of math operations to fast approximations
// that do not rely on any of the library functions.
//
//===----------------------------------------------------------------------===//

#include <climits>
#include <cmath>
#include <cstddef>

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Approximation.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/MathExtras.h"

using namespace mlir;
using namespace mlir::math;
using namespace mlir::vector;

// Helper to encapsulate a vector's shape (including scalable dims).
struct VectorShape {
  ArrayRef<int64_t> sizes;
  ArrayRef<bool> scalableFlags;
};

// Returns the vector shape if the type is a vector; otherwise returns nullopt.
static std::optional<VectorShape> vectorShape(Type type) {
  if (auto vectorType = dyn_cast<VectorType>(type)) {
    return VectorShape{vectorType.getShape(), vectorType.getScalableDims()};
  }
  return std::nullopt;
}

static std::optional<VectorShape> vectorShape(Value value) {
  return vectorShape(value.getType());
}

//----------------------------------------------------------------------------//
// Broadcast scalar types and values into vector types and values.
//----------------------------------------------------------------------------//

// Broadcasts scalar type into vector type (iff shape is non-scalar).
static Type broadcast(Type type, std::optional<VectorShape> shape) {
  assert(!isa<VectorType>(type) && "must be scalar type");
  return shape ? VectorType::get(shape->sizes, type, shape->scalableFlags)
               : type;
}

// Broadcasts scalar value into vector (iff shape is non-scalar).
static Value broadcast(ImplicitLocOpBuilder &builder, Value value,
                       std::optional<VectorShape> shape) {
  assert(!isa<VectorType>(value.getType()) && "must be scalar value");
  auto type = broadcast(value.getType(), shape);
  return shape ? builder.create<BroadcastOp>(type, value) : value;
}

//----------------------------------------------------------------------------//
// Helper function to handle n-D vectors with 1-D operations.
//----------------------------------------------------------------------------//

// Expands and unrolls n-D vector operands into multiple fixed size 1-D vectors
// and calls the compute function with 1-D vector operands. Stitches back all
// results into the original n-D vector result.
//
// Examples: vectorWidth = 8
//   - vector<4x8xf32> unrolled 4 times
//   - vector<16xf32> expanded to vector<2x8xf32> and unrolled 2 times
//   - vector<4x16xf32> expanded to vector<4x2x8xf32> and unrolled 4*2 times
//
// Some math approximations rely on ISA-specific operations that only accept
// fixed size 1-D vectors (e.g. AVX expects vectors of width 8).
//
// It is the caller's responsibility to verify that the inner dimension is
// divisible by the vectorWidth, and that all operands have the same vector
// shape.
static Value
handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
                              ValueRange operands, int64_t vectorWidth,
                              llvm::function_ref<Value(ValueRange)> compute) {
  assert(!operands.empty() && "operands must not be empty");
  assert(vectorWidth > 0 && "vector width must be larger than 0");

  VectorType inputType = cast<VectorType>(operands[0].getType());
  ArrayRef<int64_t> inputShape = inputType.getShape();

  // If input shape matches target vector width, we can just call the
  // user-provided compute function with the operands.
  if (inputShape == llvm::ArrayRef(vectorWidth))
    return compute(operands);

  // Check if the inner dimension has to be expanded, or we can directly iterate
  // over the outer dimensions of the vector.
  int64_t innerDim = inputShape.back();
  int64_t expansionDim = innerDim / vectorWidth;
  assert((innerDim % vectorWidth == 0) && "invalid inner dimension size");

  // Maybe expand operands to the higher rank vector shape that we'll use to
  // iterate over and extract one dimensional vectors.
  SmallVector<int64_t> expandedShape(inputShape);
  SmallVector<Value> expandedOperands(operands);

  if (expansionDim > 1) {
    // Expand shape from [..., innerDim] to [..., expansionDim, vectorWidth].
    expandedShape.insert(expandedShape.end() - 1, expansionDim);
    expandedShape.back() = vectorWidth;

    for (unsigned i = 0; i < operands.size(); ++i) {
      auto operand = operands[i];
      auto eltType = cast<VectorType>(operand.getType()).getElementType();
      auto expandedType = VectorType::get(expandedShape, eltType);
      expandedOperands[i] =
          builder.create<vector::ShapeCastOp>(expandedType, operand);
    }
  }

  // Iterate over all outer dimensions of the compute shape vector type.
  auto iterationDims = ArrayRef<int64_t>(expandedShape).drop_back();
  int64_t maxIndex = computeMaxLinearIndex(iterationDims);
  auto strides = computeStrides(iterationDims);
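  // Illustrative example: with expandedShape = [4, 2, 8] the iteration dims
  // are [4, 2], so maxIndex = 8 and strides = [2, 1]; delinearize(5, strides)
  // yields [2, 1], the position of the 1-D vector extracted below.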

  // Compute results for each one dimensional vector.
  SmallVector<Value> results(maxIndex);

  for (int64_t i = 0; i < maxIndex; ++i) {
    auto offsets = delinearize(i, strides);

    SmallVector<Value> extracted(expandedOperands.size());
    for (const auto &tuple : llvm::enumerate(expandedOperands))
      extracted[tuple.index()] =
          builder.create<vector::ExtractOp>(tuple.value(), offsets);

    results[i] = compute(extracted);
  }

  // Stitch results together into one large vector.
  Type resultEltType = cast<VectorType>(results[0].getType()).getElementType();
  Type resultExpandedType = VectorType::get(expandedShape, resultEltType);
  Value result = builder.create<arith::ConstantOp>(
      resultExpandedType, builder.getZeroAttr(resultExpandedType));

  for (int64_t i = 0; i < maxIndex; ++i)
    result = builder.create<vector::InsertOp>(results[i], result,
                                              delinearize(i, strides));

  // Reshape back to the original vector shape.
  return builder.create<vector::ShapeCastOp>(
      VectorType::get(inputShape, resultEltType), result);
}

//----------------------------------------------------------------------------//
// Helper functions to create constants.
//----------------------------------------------------------------------------//

static Value floatCst(ImplicitLocOpBuilder &builder, float value,
                      Type elementType) {
  assert((elementType.isF16() || elementType.isF32()) &&
         "value must be f16 or f32 type.");
  return builder.create<arith::ConstantOp>(
      builder.getFloatAttr(elementType, value));
}

static Value f32Cst(ImplicitLocOpBuilder &builder, double value) {
  return builder.create<arith::ConstantOp>(builder.getF32FloatAttr(value));
}

static Value i32Cst(ImplicitLocOpBuilder &builder, int32_t value) {
  return builder.create<arith::ConstantOp>(builder.getI32IntegerAttr(value));
}

static Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits) {
  Value i32Value = i32Cst(builder, static_cast<int32_t>(bits));
  return builder.create<arith::BitcastOp>(builder.getF32Type(), i32Value);
}

//----------------------------------------------------------------------------//
// Helper functions to build math function approximations.
//----------------------------------------------------------------------------//

// Returns the minimum of the two values or NaN if value is NaN.
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound) {
  return builder.create<arith::SelectOp>(
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::ULT, value, bound),
      value, bound);
}

// Returns the maximum of the two values or NaN if value is NaN.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound) {
  return builder.create<arith::SelectOp>(
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::UGT, value, bound),
      value, bound);
}

// Returns the clamped value or NaN if value is NaN.
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound,
                   Value upperBound) {
  return max(builder, min(builder, value, upperBound), lowerBound);
}

// Decomposes given floating point value `arg` into a normalized fraction and
// an integral power of two (see std::frexp). Returned values have float type.
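// Illustrative example: for arg = 8.0f this produces (0.5f, 4.0f), since
// 8.0 = 0.5 * 2^4 (matching std::frexp(8.0)).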
static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
                                     bool isPositive = false) {
  assert(getElementTypeOrSelf(arg).isF32() && "arg must be f32 type");
  std::optional<VectorShape> shape = vectorShape(arg);

  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  auto i32 = builder.getIntegerType(32);
  auto i32Vec = broadcast(i32, shape);
  auto f32Vec = broadcast(builder.getF32Type(), shape);

  Value cst126f = f32Cst(builder, 126.0f);
  Value cstHalf = f32Cst(builder, 0.5f);
  Value cstInvMantMask = f32FromBits(builder, ~0x7f800000u);

  // Bitcast to i32 for bitwise operations.
  Value i32Half = builder.create<arith::BitcastOp>(i32, cstHalf);
  Value i32InvMantMask = builder.create<arith::BitcastOp>(i32, cstInvMantMask);
  Value i32Arg = builder.create<arith::BitcastOp>(i32Vec, arg);

  // Compute normalized fraction.
  Value tmp0 = builder.create<arith::AndIOp>(i32Arg, bcast(i32InvMantMask));
  Value tmp1 = builder.create<arith::OrIOp>(tmp0, bcast(i32Half));
  Value normalizedFraction = builder.create<arith::BitcastOp>(f32Vec, tmp1);

  // Compute exponent.
  Value arg0 = isPositive ? arg : builder.create<math::AbsFOp>(arg);
  Value biasedExponentBits = builder.create<arith::ShRUIOp>(
      builder.create<arith::BitcastOp>(i32Vec, arg0),
      bcast(i32Cst(builder, 23)));
  Value biasedExponent =
      builder.create<arith::SIToFPOp>(f32Vec, biasedExponentBits);
  Value exponent =
      builder.create<arith::SubFOp>(biasedExponent, bcast(cst126f));

  return {normalizedFraction, exponent};
}

// Computes exp2 for an i32 argument.
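// Illustrative example: exp2I32(3) constructs the bit pattern
// (3 + 127) << 23, which is the f32 value 8.0 (i.e. 2^3).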
static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) {
  assert(getElementTypeOrSelf(arg).isInteger(32) && "arg must be i32 type");
  std::optional<VectorShape> shape = vectorShape(arg);

  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  auto f32Vec = broadcast(builder.getF32Type(), shape);
  // The exponent field of an f32 starts at bit 23.
  auto exponentBitLocation = bcast(i32Cst(builder, 23));
  // The f32 exponent bias; adding it turns `arg` into a biased exponent.
  auto bias = bcast(i32Cst(builder, 127));

  Value biasedArg = builder.create<arith::AddIOp>(arg, bias);
  Value exp2ValueInt =
      builder.create<arith::ShLIOp>(biasedArg, exponentBitLocation);
  Value exp2ValueF32 = builder.create<arith::BitcastOp>(f32Vec, exp2ValueInt);

  return exp2ValueF32;
}

namespace {
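// Evaluates the polynomial coeffs[0] + coeffs[1] * x + ... +
// coeffs[n-1] * x^(n-1) with Horner's scheme using fused multiply-adds.
// For example, for coeffs = {c0, c1, c2} this builds fma(x, fma(x, c2, c1), c0)
// = c0 + c1 * x + c2 * x^2.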
Value makePolynomialCalculation(ImplicitLocOpBuilder &builder,
                                llvm::ArrayRef<Value> coeffs, Value x) {
  Type elementType = getElementTypeOrSelf(x);
  assert((elementType.isF32() || elementType.isF16()) &&
         "x must be f32 or f16 type");
  std::optional<VectorShape> shape = vectorShape(x);

  if (coeffs.empty())
    return broadcast(builder, floatCst(builder, 0.0f, elementType), shape);

  if (coeffs.size() == 1)
    return coeffs[0];

  Value res = builder.create<math::FmaOp>(x, coeffs[coeffs.size() - 1],
                                          coeffs[coeffs.size() - 2]);
  for (auto i = ptrdiff_t(coeffs.size()) - 3; i >= 0; --i) {
    res = builder.create<math::FmaOp>(x, res, coeffs[i]);
  }
  return res;
}
} // namespace

//----------------------------------------------------------------------------//
// Helper function/pattern to insert casts for reusing F32 bit expansion.
//----------------------------------------------------------------------------//

template <typename T>
LogicalResult insertCasts(Operation *op, PatternRewriter &rewriter) {
  // Conservatively only allow rewriting when all operand and result types
  // match exactly.
  Type origType = op->getResultTypes().front();
  for (Type t : llvm::drop_begin(op->getResultTypes()))
    if (origType != t)
      return rewriter.notifyMatchFailure(op, "required all types to match");
  for (Type t : op->getOperandTypes())
    if (origType != t)
      return rewriter.notifyMatchFailure(op, "required all types to match");

  // Skip if already F32 or larger than 32 bits.
  if (getElementTypeOrSelf(origType).isF32() ||
      getElementTypeOrSelf(origType).getIntOrFloatBitWidth() > 32)
    return failure();

  // Create F32 equivalent type.
  Type newType;
  if (auto shaped = dyn_cast<ShapedType>(origType)) {
    newType = shaped.clone(rewriter.getF32Type());
  } else if (isa<FloatType>(origType)) {
    newType = rewriter.getF32Type();
  } else {
    return rewriter.notifyMatchFailure(op,
                                       "unable to find F32 equivalent type");
  }

  Location loc = op->getLoc();
  SmallVector<Value> operands;
  for (auto operand : op->getOperands())
    operands.push_back(rewriter.create<arith::ExtFOp>(loc, newType, operand));
  auto result =
      rewriter.create<T>(loc, TypeRange{newType}, operands, op->getAttrs());
  rewriter.replaceOpWithNewOp<arith::TruncFOp>(op, origType, result);
  return success();
}

namespace {
// Pattern that casts to F32 to reuse the F32 expansion as a fallback for
// single-result ops.
// TODO: Consider revising to avoid adding multiple casts for a subgraph that is
// all in lower precision. Currently this is only fallback support and performs
// simplistic casting.
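// For example (illustrative), instantiated for a unary math op on f16 operands
// the pattern rewrites
//   %r = "math.op"(%x) : (f16) -> f16
// into
//   %e = arith.extf %x : f16 to f32
//   %y = "math.op"(%e) : (f32) -> f32
//   %r = arith.truncf %y : f32 to f16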
template <typename T>
struct ReuseF32Expansion : public OpRewritePattern<T> {
public:
  using OpRewritePattern<T>::OpRewritePattern;
  LogicalResult matchAndRewrite(T op, PatternRewriter &rewriter) const final {
    static_assert(
        T::template hasTrait<mlir::OpTrait::SameOperandsAndResultType>(),
        "requires same operands and result types");
    return insertCasts<T>(op, rewriter);
  }
};
} // namespace

//----------------------------------------------------------------------------//
// AtanOp approximation.
//----------------------------------------------------------------------------//

namespace {
struct AtanApproximation : public OpRewritePattern<math::AtanOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::AtanOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace

LogicalResult
AtanApproximation::matchAndRewrite(math::AtanOp op,
                                   PatternRewriter &rewriter) const {
  auto operand = op.getOperand();
  if (!getElementTypeOrSelf(operand).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  std::optional<VectorShape> shape = vectorShape(op.getOperand());

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  Value abs = builder.create<math::AbsFOp>(operand);

  auto one = broadcast(builder, f32Cst(builder, 1.0), shape);

  // When 0.66 < x <= 2.41 we do (x-1) / (x+1):
  auto twoThirds = broadcast(builder, f32Cst(builder, 0.66), shape);
  Value cmp2 =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, abs, twoThirds);
  Value addone = builder.create<arith::AddFOp>(abs, one);
  Value subone = builder.create<arith::SubFOp>(abs, one);
  Value xnum = builder.create<arith::SelectOp>(cmp2, subone, abs);
  Value xden = builder.create<arith::SelectOp>(cmp2, addone, one);

  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  // For the remaining ranges we use x directly (x <= 0.66) or 1/x (x > 2.41):
  auto tan3pio8 = bcast(f32Cst(builder, 2.41421356237309504880));
  Value cmp1 =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, abs, tan3pio8);
  xnum = builder.create<arith::SelectOp>(cmp1, one, xnum);
  xden = builder.create<arith::SelectOp>(cmp1, abs, xden);

  Value x = builder.create<arith::DivFOp>(xnum, xden);
  Value xx = builder.create<arith::MulFOp>(x, x);

  // Apply the rational polynomial approximation for atan over the range
  // [0.0, 0.66].
  auto p0 = bcast(f32Cst(builder, -8.750608600031904122785e-01));
  auto p1 = bcast(f32Cst(builder, -1.615753718733365076637e+01));
  auto p2 = bcast(f32Cst(builder, -7.500855792314704667340e+01));
  auto p3 = bcast(f32Cst(builder, -1.228866684490136173410e+02));
  auto p4 = bcast(f32Cst(builder, -6.485021904942025371773e+01));
  auto q0 = bcast(f32Cst(builder, +2.485846490142306297962e+01));
  auto q1 = bcast(f32Cst(builder, +1.650270098316988542046e+02));
  auto q2 = bcast(f32Cst(builder, +4.328810604912902668951e+02));
  auto q3 = bcast(f32Cst(builder, +4.853903996359136964868e+02));
  auto q4 = bcast(f32Cst(builder, +1.945506571482613964425e+02));

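  // For the reduced argument, atan(x) ~= x + x * xx * N(xx) / D(xx), where N
  // and D are the polynomials evaluated below (numerator n and denominator d).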
  // Apply the polynomial approximation for the numerator:
  Value n = p0;
  n = builder.create<math::FmaOp>(xx, n, p1);
  n = builder.create<math::FmaOp>(xx, n, p2);
  n = builder.create<math::FmaOp>(xx, n, p3);
  n = builder.create<math::FmaOp>(xx, n, p4);
  n = builder.create<arith::MulFOp>(n, xx);

  // Apply the polynomial approximation for the denominator:
  Value d = q0;
  d = builder.create<math::FmaOp>(xx, d, q1);
  d = builder.create<math::FmaOp>(xx, d, q2);
  d = builder.create<math::FmaOp>(xx, d, q3);
  d = builder.create<math::FmaOp>(xx, d, q4);

  // Compute approximation of theta:
  Value ans0 = builder.create<arith::DivFOp>(n, d);
  ans0 = builder.create<math::FmaOp>(ans0, x, x);

  // Correct for the input mapping's angles:
  Value mpi4 = bcast(f32Cst(builder, llvm::numbers::pi / 4));
  Value ans2 = builder.create<arith::AddFOp>(mpi4, ans0);
  Value ans = builder.create<arith::SelectOp>(cmp2, ans2, ans0);

  Value mpi2 = bcast(f32Cst(builder, llvm::numbers::pi / 2));
  Value ans1 = builder.create<arith::SubFOp>(mpi2, ans0);
  ans = builder.create<arith::SelectOp>(cmp1, ans1, ans);

  // Correct for the sign of the input.
  rewriter.replaceOpWithNewOp<math::CopySignOp>(op, ans, operand);
  return success();
}

//----------------------------------------------------------------------------//
// Atan2Op approximation.
//----------------------------------------------------------------------------//

namespace {
struct Atan2Approximation : public OpRewritePattern<math::Atan2Op> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::Atan2Op op,
                                PatternRewriter &rewriter) const final;
};
} // namespace

LogicalResult
Atan2Approximation::matchAndRewrite(math::Atan2Op op,
                                    PatternRewriter &rewriter) const {
  auto y = op.getOperand(0);
  auto x = op.getOperand(1);
  if (!getElementTypeOrSelf(x).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  std::optional<VectorShape> shape = vectorShape(op.getResult());

  // Compute atan in the valid range.
  auto div = builder.create<arith::DivFOp>(y, x);
  auto atan = builder.create<math::AtanOp>(div);

  // Determine what the atan would be for a 180 degree rotation.
  auto zero = broadcast(builder, f32Cst(builder, 0.0f), shape);
  auto pi = broadcast(builder, f32Cst(builder, 3.14159265359f), shape);
  auto addPi = builder.create<arith::AddFOp>(atan, pi);
  auto subPi = builder.create<arith::SubFOp>(atan, pi);
  auto atanGt =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, atan, zero);
  auto flippedAtan = builder.create<arith::SelectOp>(atanGt, subPi, addPi);

  // Determine whether to directly use atan or use the 180 degree flip
  auto xGt = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, x, zero);
  Value result = builder.create<arith::SelectOp>(xGt, atan, flippedAtan);

  // Handle x = 0, y > 0
  Value xZero =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, x, zero);
  Value yGt = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, y, zero);
  Value isHalfPi = builder.create<arith::AndIOp>(xZero, yGt);
  auto halfPi = broadcast(builder, f32Cst(builder, 1.57079632679f), shape);
  result = builder.create<arith::SelectOp>(isHalfPi, halfPi, result);

  // Handle x = 0, y < 0
  Value yLt = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, y, zero);
  Value isNegativeHalfPiPi = builder.create<arith::AndIOp>(xZero, yLt);
  auto negativeHalfPiPi =
      broadcast(builder, f32Cst(builder, -1.57079632679f), shape);
  result = builder.create<arith::SelectOp>(isNegativeHalfPiPi, negativeHalfPiPi,
                                           result);

  // Handle x = 0, y = 0;
  Value yZero =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, y, zero);
  Value isNan = builder.create<arith::AndIOp>(xZero, yZero);
  Value cstNan = broadcast(builder, f32FromBits(builder, 0x7fc00000), shape);
  result = builder.create<arith::SelectOp>(isNan, cstNan, result);

  rewriter.replaceOp(op, result);
  return success();
}

//----------------------------------------------------------------------------//
// TanhOp approximation.
//----------------------------------------------------------------------------//

namespace {
struct TanhApproximation : public OpRewritePattern<math::TanhOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::TanhOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace

LogicalResult
TanhApproximation::matchAndRewrite(math::TanhOp op,
                                   PatternRewriter &rewriter) const {
  if (!getElementTypeOrSelf(op.getOperand()).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  std::optional<VectorShape> shape = vectorShape(op.getOperand());

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  // Clamp operand into the [minusClamp, plusClamp] range.
  Value minusClamp = bcast(f32Cst(builder, -7.99881172180175781f));
  Value plusClamp = bcast(f32Cst(builder, 7.99881172180175781f));
  Value x = clamp(builder, op.getOperand(), minusClamp, plusClamp);

  // Mask for tiny values that are approximated with `operand`.
  Value tiny = bcast(f32Cst(builder, 0.0004f));
  Value tinyMask = builder.create<arith::CmpFOp>(
      arith::CmpFPredicate::OLT, builder.create<math::AbsFOp>(op.getOperand()),
      tiny);

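  // The approximation below evaluates the rational function
  // tanh(x) ~= x * P(x^2) / Q(x^2) with the odd/even polynomials P and Q.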
  // The monomial coefficients of the numerator polynomial (odd).
  Value alpha1 = bcast(f32Cst(builder, 4.89352455891786e-03f));
  Value alpha3 = bcast(f32Cst(builder, 6.37261928875436e-04f));
  Value alpha5 = bcast(f32Cst(builder, 1.48572235717979e-05f));
  Value alpha7 = bcast(f32Cst(builder, 5.12229709037114e-08f));
  Value alpha9 = bcast(f32Cst(builder, -8.60467152213735e-11f));
  Value alpha11 = bcast(f32Cst(builder, 2.00018790482477e-13f));
  Value alpha13 = bcast(f32Cst(builder, -2.76076847742355e-16f));

  // The monomial coefficients of the denominator polynomial (even).
  Value beta0 = bcast(f32Cst(builder, 4.89352518554385e-03f));
  Value beta2 = bcast(f32Cst(builder, 2.26843463243900e-03f));
  Value beta4 = bcast(f32Cst(builder, 1.18534705686654e-04f));
  Value beta6 = bcast(f32Cst(builder, 1.19825839466702e-06f));

  // Since the polynomials are odd/even, we need x^2.
  Value x2 = builder.create<arith::MulFOp>(x, x);

  // Evaluate the numerator polynomial p.
  Value p = builder.create<math::FmaOp>(x2, alpha13, alpha11);
  p = builder.create<math::FmaOp>(x2, p, alpha9);
  p = builder.create<math::FmaOp>(x2, p, alpha7);
  p = builder.create<math::FmaOp>(x2, p, alpha5);
  p = builder.create<math::FmaOp>(x2, p, alpha3);
  p = builder.create<math::FmaOp>(x2, p, alpha1);
  p = builder.create<arith::MulFOp>(x, p);

  // Evaluate the denominator polynomial q.
  Value q = builder.create<math::FmaOp>(x2, beta6, beta4);
  q = builder.create<math::FmaOp>(x2, q, beta2);
  q = builder.create<math::FmaOp>(x2, q, beta0);

  // Divide the numerator by the denominator.
  Value res = builder.create<arith::SelectOp>(
      tinyMask, x, builder.create<arith::DivFOp>(p, q));

  rewriter.replaceOp(op, res);

  return success();
}

#define LN2_VALUE                                                              \
  0.693147180559945309417232121458176568075500134360255254120680009493393621L
#define LOG2E_VALUE                                                            \
  1.442695040888963407359924681001892137426645954152985934135449406931109219L

//----------------------------------------------------------------------------//
// LogOp and Log2Op approximation.
//----------------------------------------------------------------------------//

namespace {
template <typename Op>
struct LogApproximationBase : public OpRewritePattern<Op> {
  using OpRewritePattern<Op>::OpRewritePattern;

  /// Base 2 if 'base2' is set; natural logarithm (base e) otherwise.
  LogicalResult logMatchAndRewrite(Op op, PatternRewriter &rewriter,
                                   bool base2) const;
};
} // namespace

// This approximation comes from Julien Pommier's SSE math library.
// Link: http://gruntthepeon.free.fr/ssemath
template <typename Op>
LogicalResult
LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter,
                                             bool base2) const {
  if (!getElementTypeOrSelf(op.getOperand()).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  std::optional<VectorShape> shape = vectorShape(op.getOperand());

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  Value cstZero = bcast(f32Cst(builder, 0.0f));
  Value cstOne = bcast(f32Cst(builder, 1.0f));
  Value cstNegHalf = bcast(f32Cst(builder, -0.5f));

  // The smallest positive normalized (non-denormal) float number.
  Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u));
  Value cstMinusInf = bcast(f32FromBits(builder, 0xff800000u));
  Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u));
  Value cstNan = bcast(f32FromBits(builder, 0x7fc00000));

  // Polynomial coefficients.
  Value cstCephesSQRTHF = bcast(f32Cst(builder, 0.707106781186547524f));
  Value cstCephesLogP0 = bcast(f32Cst(builder, 7.0376836292E-2f));
  Value cstCephesLogP1 = bcast(f32Cst(builder, -1.1514610310E-1f));
  Value cstCephesLogP2 = bcast(f32Cst(builder, 1.1676998740E-1f));
  Value cstCephesLogP3 = bcast(f32Cst(builder, -1.2420140846E-1f));
  Value cstCephesLogP4 = bcast(f32Cst(builder, +1.4249322787E-1f));
  Value cstCephesLogP5 = bcast(f32Cst(builder, -1.6668057665E-1f));
  Value cstCephesLogP6 = bcast(f32Cst(builder, +2.0000714765E-1f));
  Value cstCephesLogP7 = bcast(f32Cst(builder, -2.4999993993E-1f));
  Value cstCephesLogP8 = bcast(f32Cst(builder, +3.3333331174E-1f));

  Value x = op.getOperand();

  // Truncate input values to the minimum positive normal.
  x = max(builder, x, cstMinNormPos);

  // Extract the significand in the range [0.5,1) and the exponent.
  std::pair<Value, Value> pair = frexp(builder, x, /*isPositive=*/true);
  x = pair.first;
  Value e = pair.second;

  // Shift the inputs from the range [0.5,1) to [sqrt(1/2), sqrt(2)) and shift
  // by -1.0. The values are then centered around 0, which improves the
  // stability of the polynomial evaluation:
  //
  //   if( x < SQRTHF ) {
  //     e -= 1;
  //     x = x + x - 1.0;
  //   } else { x = x - 1.0; }
  Value mask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x,
                                             cstCephesSQRTHF);
  Value tmp = builder.create<arith::SelectOp>(mask, x, cstZero);

  x = builder.create<arith::SubFOp>(x, cstOne);
  e = builder.create<arith::SubFOp>(
      e, builder.create<arith::SelectOp>(mask, cstOne, cstZero));
  x = builder.create<arith::AddFOp>(x, tmp);

  Value x2 = builder.create<arith::MulFOp>(x, x);
  Value x3 = builder.create<arith::MulFOp>(x2, x);

  // Evaluate the polynomial approximant of degree 8 in three parts.
  Value y0, y1, y2;
  y0 = builder.create<math::FmaOp>(cstCephesLogP0, x, cstCephesLogP1);
  y1 = builder.create<math::FmaOp>(cstCephesLogP3, x, cstCephesLogP4);
  y2 = builder.create<math::FmaOp>(cstCephesLogP6, x, cstCephesLogP7);
  y0 = builder.create<math::FmaOp>(y0, x, cstCephesLogP2);
  y1 = builder.create<math::FmaOp>(y1, x, cstCephesLogP5);
  y2 = builder.create<math::FmaOp>(y2, x, cstCephesLogP8);
  y0 = builder.create<math::FmaOp>(y0, x3, y1);
  y0 = builder.create<math::FmaOp>(y0, x3, y2);
  y0 = builder.create<arith::MulFOp>(y0, x3);

  y0 = builder.create<math::FmaOp>(cstNegHalf, x2, y0);
  x = builder.create<arith::AddFOp>(x, y0);
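  // At this point x approximates log(m), where m is the adjusted significand
  // in [sqrt(1/2), sqrt(2)). The final result is log(m) + e * log(2) for the
  // natural log, or log(m) * log2(e) + e for base 2.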

  if (base2) {
    Value cstLog2e = bcast(f32Cst(builder, static_cast<float>(LOG2E_VALUE)));
    x = builder.create<math::FmaOp>(x, cstLog2e, e);
  } else {
    Value cstLn2 = bcast(f32Cst(builder, static_cast<float>(LN2_VALUE)));
    x = builder.create<math::FmaOp>(e, cstLn2, x);
  }

  Value invalidMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::ULT,
                                                    op.getOperand(), cstZero);
  Value zeroMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
                                                 op.getOperand(), cstZero);
  Value posInfMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
                                                   op.getOperand(), cstPosInf);

  // Filter out invalid values:
  //  • x == 0     -> -INF
  //  • x < 0      ->  NAN
  //  • x == +INF  -> +INF
  Value approximation = builder.create<arith::SelectOp>(
      zeroMask, cstMinusInf,
      builder.create<arith::SelectOp>(
          invalidMask, cstNan,
          builder.create<arith::SelectOp>(posInfMask, cstPosInf, x)));

  rewriter.replaceOp(op, approximation);

  return success();
}

namespace {
struct LogApproximation : public LogApproximationBase<math::LogOp> {
  using LogApproximationBase::LogApproximationBase;

  LogicalResult matchAndRewrite(math::LogOp op,
                                PatternRewriter &rewriter) const final {
    return logMatchAndRewrite(op, rewriter, /*base2=*/false);
  }
};
} // namespace

namespace {
struct Log2Approximation : public LogApproximationBase<math::Log2Op> {
  using LogApproximationBase::LogApproximationBase;

  LogicalResult matchAndRewrite(math::Log2Op op,
                                PatternRewriter &rewriter) const final {
    return logMatchAndRewrite(op, rewriter, /*base2=*/true);
  }
};
} // namespace

//----------------------------------------------------------------------------//
// Log1p approximation.
//----------------------------------------------------------------------------//

namespace {
struct Log1pApproximation : public OpRewritePattern<math::Log1pOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::Log1pOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace

// Approximate log(1+x).
LogicalResult
Log1pApproximation::matchAndRewrite(math::Log1pOp op,
                                    PatternRewriter &rewriter) const {
  if (!getElementTypeOrSelf(op.getOperand()).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  std::optional<VectorShape> shape = vectorShape(op.getOperand());

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  // Approximate log(1+x) using the following, due to W. Kahan:
  //   u = x + 1.0;
  //   if (u == 1.0 || u == inf) return x;
  //   return x * log(u) / (u - 1.0);
  //          ^^^^^^^^^^^^^^^^^^^^^^
  //             "logLarge" below.
  Value cstOne = bcast(f32Cst(builder, 1.0f));
  Value x = op.getOperand();
  Value u = builder.create<arith::AddFOp>(x, cstOne);
  Value uSmall =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, cstOne);
  Value logU = builder.create<math::LogOp>(u);
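  // Note: u == log(u) only holds for u == +inf (for finite u > 0,
  // log(u) < u), so the comparison below detects the "u == inf" case without
  // materializing an inf constant.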
  Value uInf =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, logU);
  Value logLarge = builder.create<arith::MulFOp>(
      x, builder.create<arith::DivFOp>(
             logU, builder.create<arith::SubFOp>(u, cstOne)));
  Value approximation = builder.create<arith::SelectOp>(
      builder.create<arith::OrIOp>(uSmall, uInf), x, logLarge);
  rewriter.replaceOp(op, approximation);
  return success();
}

//----------------------------------------------------------------------------//
// Asin approximation.
//----------------------------------------------------------------------------//

// Approximates asin(x).
// This approximation is based on the following stackoverflow post:
// https://stackoverflow.com/a/42683455
namespace {
struct AsinPolynomialApproximation : public OpRewritePattern<math::AsinOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::AsinOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace
LogicalResult
AsinPolynomialApproximation::matchAndRewrite(math::AsinOp op,
                                             PatternRewriter &rewriter) const {
  Value operand = op.getOperand();
  Type elementType = getElementTypeOrSelf(operand);

  if (!(elementType.isF32() || elementType.isF16()))
    return rewriter.notifyMatchFailure(op,
                                       "only f32 and f16 types are supported.");
  std::optional<VectorShape> shape = vectorShape(operand);

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  auto fma = [&](Value a, Value b, Value c) -> Value {
    return builder.create<math::FmaOp>(a, b, c);
  };

  auto mul = [&](Value a, Value b) -> Value {
    return builder.create<arith::MulFOp>(a, b);
  };

  auto sub = [&](Value a, Value b) -> Value {
    return builder.create<arith::SubFOp>(a, b);
  };

  auto abs = [&](Value a) -> Value { return builder.create<math::AbsFOp>(a); };

  auto sqrt = [&](Value a) -> Value { return builder.create<math::SqrtOp>(a); };

  auto scopy = [&](Value a, Value b) -> Value {
    return builder.create<math::CopySignOp>(a, b);
  };

  auto sel = [&](Value a, Value b, Value c) -> Value {
    return builder.create<arith::SelectOp>(a, b, c);
  };

  Value abso = abs(operand);
  Value aa = mul(operand, operand);
  Value opp = sqrt(sub(bcast(floatCst(builder, 1.0, elementType)), aa));

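  // For x^2 > 0.5 (i.e. |x| > sqrt(1/2)) the identity
  // asin(|x|) = pi/2 - asin(sqrt(1 - x^2)) is used, so the polynomial is only
  // evaluated on a small argument; the pi/2 correction is applied below.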
  Value gt =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, aa,
                                    bcast(floatCst(builder, 0.5, elementType)));

  Value x = sel(gt, opp, abso);

  // Asin(x) approximation for x = [-9/16, 9/16]:
  Value s = mul(x, x);
  Value q = mul(s, s);
  Value r = bcast(floatCst(builder, 5.5579749017470502e-2, elementType));
  Value t = bcast(floatCst(builder, -6.2027913464120114e-2, elementType));

  r = fma(r, q, bcast(floatCst(builder, 5.4224464349245036e-2, elementType)));
  t = fma(t, q, bcast(floatCst(builder, -1.1326992890324464e-2, elementType)));
  r = fma(r, q, bcast(floatCst(builder, 1.5268872539397656e-2, elementType)));
  t = fma(t, q, bcast(floatCst(builder, 1.0493798473372081e-2, elementType)));
  r = fma(r, q, bcast(floatCst(builder, 1.4106045900607047e-2, elementType)));
  t = fma(t, q, bcast(floatCst(builder, 1.7339776384962050e-2, elementType)));
  r = fma(r, q, bcast(floatCst(builder, 2.2372961589651054e-2, elementType)));
  t = fma(t, q, bcast(floatCst(builder, 3.0381912707941005e-2, elementType)));
  r = fma(r, q, bcast(floatCst(builder, 4.4642857881094775e-2, elementType)));
  t = fma(t, q, bcast(floatCst(builder, 7.4999999991367292e-2, elementType)));
  r = fma(r, s, t);
  r = fma(r, s, bcast(floatCst(builder, 1.6666666666670193e-1, elementType)));
  t = mul(x, s);
  r = fma(r, t, x);

  Value rsub = sub(bcast(floatCst(builder, 1.57079632679, elementType)), r);
  r = sel(gt, rsub, r);
  r = scopy(r, operand);

  rewriter.replaceOp(op, r);
  return success();
}

//----------------------------------------------------------------------------//
// Acos approximation.
//----------------------------------------------------------------------------//

// Approximates acos(x).
// This approximation is based on the following stackoverflow post:
// https://stackoverflow.com/a/42683455
namespace {
struct AcosPolynomialApproximation : public OpRewritePattern<math::AcosOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::AcosOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace
LogicalResult
AcosPolynomialApproximation::matchAndRewrite(math::AcosOp op,
                                             PatternRewriter &rewriter) const {
  Value operand = op.getOperand();
  Type elementType = getElementTypeOrSelf(operand);

  if (!(elementType.isF32() || elementType.isF16()))
    return rewriter.notifyMatchFailure(op,
                                       "only f32 and f16 types are supported.");
  std::optional<VectorShape> shape = vectorShape(operand);

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  auto fma = [&](Value a, Value b, Value c) -> Value {
    return builder.create<math::FmaOp>(a, b, c);
  };

  auto mul = [&](Value a, Value b) -> Value {
    return builder.create<arith::MulFOp>(a, b);
  };

  Value negOperand = builder.create<arith::NegFOp>(operand);
  Value zero = bcast(floatCst(builder, 0.0, elementType));
  Value half = bcast(floatCst(builder, 0.5, elementType));
  Value negOne = bcast(floatCst(builder, -1.0, elementType));
  Value selR =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, operand, zero);
  Value r = builder.create<arith::SelectOp>(selR, negOperand, operand);
  Value chkConst = bcast(floatCst(builder, -0.5625, elementType));
  Value firstPred =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, r, chkConst);

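  // Note: r is -|x| here, so firstPred holds when |x| < 0.5625. The fma of the
  // two constants below evaluates to pi/2, giving acos(|x|) = pi/2 + asin(-|x|);
  // keeping the product inside the fma likely preserves extra precision.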
  Value trueVal =
      fma(bcast(floatCst(builder, 9.3282184640716537e-1, elementType)),
          bcast(floatCst(builder, 1.6839188885261840e+0, elementType)),
          builder.create<math::AsinOp>(r));

  Value falseVal = builder.create<math::SqrtOp>(fma(half, r, half));
  falseVal = builder.create<math::AsinOp>(falseVal);
  falseVal = mul(bcast(floatCst(builder, 2.0, elementType)), falseVal);

  r = builder.create<arith::SelectOp>(firstPred, trueVal, falseVal);

  // Check whether the operand lies in [-1.0, 0.0).
  Value greaterThanNegOne =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, operand, negOne);

  Value lessThanZero =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero);

  Value betweenNegOneZero =
      builder.create<arith::AndIOp>(greaterThanNegOne, lessThanZero);

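  // For operands in [-1.0, 0.0) use acos(x) = pi - acos(-x); the fma of the
  // two constants below evaluates to pi.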
  trueVal = fma(bcast(floatCst(builder, 1.8656436928143307e+0, elementType)),
                bcast(floatCst(builder, 1.6839188885261840e+0, elementType)),
                builder.create<arith::NegFOp>(r));

  Value finalVal =
      builder.create<arith::SelectOp>(betweenNegOneZero, trueVal, r);

  rewriter.replaceOp(op, finalVal);
  return success();
}

//----------------------------------------------------------------------------//
// Erf approximation.
//----------------------------------------------------------------------------//

// Approximates erf(x) with
// a + P(x)/Q(x)
// where P and Q are polynomials of degree 4.
// Different coefficients are chosen based on the value of x.
// The approximation error is ~2.5e-07.
// Boost's minimax tool that utilizes the Remez method was used to find the
// coefficients.
LogicalResult
ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
                                            PatternRewriter &rewriter) const {
  Value operand = op.getOperand();
  Type elementType = getElementTypeOrSelf(operand);

  if (!(elementType.isF32() || elementType.isF16()))
    return rewriter.notifyMatchFailure(op,
                                       "only f32 and f16 types are supported.");
  std::optional<VectorShape> shape = vectorShape(operand);

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  const int intervalsCount = 3;
  const int polyDegree = 4;

  Value zero = bcast(floatCst(builder, 0, elementType));
  Value one = bcast(floatCst(builder, 1, elementType));
  Value pp[intervalsCount][polyDegree + 1];
  pp[0][0] = bcast(floatCst(builder, +0.00000000000000000e+00f, elementType));
  pp[0][1] = bcast(floatCst(builder, +1.12837916222975858e+00f, elementType));
  pp[0][2] = bcast(floatCst(builder, -5.23018562988006470e-01f, elementType));
  pp[0][3] = bcast(floatCst(builder, +2.09741709609267072e-01f, elementType));
  pp[0][4] = bcast(floatCst(builder, +2.58146801602987875e-02f, elementType));
  pp[1][0] = bcast(floatCst(builder, +0.00000000000000000e+00f, elementType));
  pp[1][1] = bcast(floatCst(builder, +1.12750687816789140e+00f, elementType));
  pp[1][2] = bcast(floatCst(builder, -3.64721408487825775e-01f, elementType));
  pp[1][3] = bcast(floatCst(builder, +1.18407396425136952e-01f, elementType));
  pp[1][4] = bcast(floatCst(builder, +3.70645533056476558e-02f, elementType));
  pp[2][0] = bcast(floatCst(builder, -3.30093071049483172e-03f, elementType));
  pp[2][1] = bcast(floatCst(builder, +3.51961938357697011e-03f, elementType));
  pp[2][2] = bcast(floatCst(builder, -1.41373622814988039e-03f, elementType));
  pp[2][3] = bcast(floatCst(builder, +2.53447094961941348e-04f, elementType));
  pp[2][4] = bcast(floatCst(builder, -1.71048029455037401e-05f, elementType));

  Value qq[intervalsCount][polyDegree + 1];
  qq[0][0] = bcast(floatCst(builder, +1.000000000000000000e+00f, elementType));
  qq[0][1] = bcast(floatCst(builder, -4.635138185962547255e-01f, elementType));
  qq[0][2] = bcast(floatCst(builder, +5.192301327279782447e-01f, elementType));
  qq[0][3] = bcast(floatCst(builder, -1.318089722204810087e-01f, elementType));
  qq[0][4] = bcast(floatCst(builder, +7.397964654672315005e-02f, elementType));
  qq[1][0] = bcast(floatCst(builder, +1.00000000000000000e+00f, elementType));
  qq[1][1] = bcast(floatCst(builder, -3.27607011824493086e-01f, elementType));
  qq[1][2] = bcast(floatCst(builder, +4.48369090658821977e-01f, elementType));
  qq[1][3] = bcast(floatCst(builder, -8.83462621207857930e-02f, elementType));
  qq[1][4] = bcast(floatCst(builder, +5.72442770283176093e-02f, elementType));
  qq[2][0] = bcast(floatCst(builder, +1.00000000000000000e+00f, elementType));
  qq[2][1] = bcast(floatCst(builder, -2.06069165953913769e+00f, elementType));
  qq[2][2] = bcast(floatCst(builder, +1.62705939945477759e+00f, elementType));
  qq[2][3] = bcast(floatCst(builder, -5.83389859211130017e-01f, elementType));
  qq[2][4] = bcast(floatCst(builder, +8.21908939856640930e-02f, elementType));

  Value offsets[intervalsCount];
  offsets[0] = bcast(floatCst(builder, 0.0f, elementType));
  offsets[1] = bcast(floatCst(builder, 0.0f, elementType));
  offsets[2] = bcast(floatCst(builder, 1.0f, elementType));

  Value bounds[intervalsCount];
  bounds[0] = bcast(floatCst(builder, 0.8f, elementType));
  bounds[1] = bcast(floatCst(builder, 2.0f, elementType));
  bounds[2] = bcast(floatCst(builder, 3.75f, elementType));
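  // The intervals are [0, 0.8), [0.8, 2.0) and [2.0, 3.75); for |x| >= 3.75
  // the result saturates to 1 (up to the sign fixup below).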

  Value isNegativeArg =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero);
  Value negArg = builder.create<arith::NegFOp>(operand);
  Value x = builder.create<arith::SelectOp>(isNegativeArg, negArg, operand);

  Value offset = offsets[0];
  Value p[polyDegree + 1];
  Value q[polyDegree + 1];
  for (int i = 0; i <= polyDegree; ++i) {
    p[i] = pp[0][i];
    q[i] = qq[0][i];
  }

  // TODO: maybe use vector stacking to reduce the number of selects.
  Value isLessThanBound[intervalsCount];
  for (int j = 0; j < intervalsCount - 1; ++j) {
    isLessThanBound[j] =
        builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, bounds[j]);
    for (int i = 0; i <= polyDegree; ++i) {
      p[i] = builder.create<arith::SelectOp>(isLessThanBound[j], p[i],
                                             pp[j + 1][i]);
      q[i] = builder.create<arith::SelectOp>(isLessThanBound[j], q[i],
                                             qq[j + 1][i]);
    }
    offset = builder.create<arith::SelectOp>(isLessThanBound[j], offset,
                                             offsets[j + 1]);
  }
  isLessThanBound[intervalsCount - 1] = builder.create<arith::CmpFOp>(
      arith::CmpFPredicate::ULT, x, bounds[intervalsCount - 1]);

  Value pPoly = makePolynomialCalculation(builder, p, x);
  Value qPoly = makePolynomialCalculation(builder, q, x);
  Value rationalPoly = builder.create<arith::DivFOp>(pPoly, qPoly);
  Value formula = builder.create<arith::AddFOp>(offset, rationalPoly);
  formula = builder.create<arith::SelectOp>(isLessThanBound[intervalsCount - 1],
                                            formula, one);

  // erf is an odd function: erf(x) = -erf(-x).
  Value negFormula = builder.create<arith::NegFOp>(formula);
  Value res =
      builder.create<arith::SelectOp>(isNegativeArg, negFormula, formula);

  rewriter.replaceOp(op, res);

  return success();
}

//----------------------------------------------------------------------------//
// Exp approximation.
//----------------------------------------------------------------------------//

namespace {

Value clampWithNormals(ImplicitLocOpBuilder &builder,
                       const std::optional<VectorShape> shape, Value value,
                       float lowerBound, float upperBound) {
  assert(!std::isnan(lowerBound));
  assert(!std::isnan(upperBound));

  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  auto selectCmp = [&builder](auto pred, Value value, Value bound) {
    return builder.create<arith::SelectOp>(
        builder.create<arith::CmpFOp>(pred, value, bound), value, bound);
  };

  // Note: prefer UGE/ULE vs. UGT/ULT, since they generate vmaxps/vminps vs.
  // vcmpleps+vmovaps on x86_64. The latter outcome is also obtained with
  // arith::{Max,Min}FOp.
  value = selectCmp(arith::CmpFPredicate::UGE, value,
                    bcast(f32Cst(builder, lowerBound)));
  value = selectCmp(arith::CmpFPredicate::ULE, value,
                    bcast(f32Cst(builder, upperBound)));
  return value;
}

struct ExpApproximation : public OpRewritePattern<math::ExpOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::ExpOp op,
                                PatternRewriter &rewriter) const final;
};

LogicalResult
ExpApproximation::matchAndRewrite(math::ExpOp op,
                                  PatternRewriter &rewriter) const {
  auto shape = vectorShape(op.getOperand().getType());
  auto elementTy = getElementTypeOrSelf(op.getType());
  if (!elementTy.isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);

  auto add = [&](Value a, Value b) -> Value {
    return builder.create<arith::AddFOp>(a, b);
  };
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };
  auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); };
  auto fmla = [&](Value a, Value b, Value c) {
    return builder.create<math::FmaOp>(a, b, c);
  };
  auto mul = [&](Value a, Value b) -> Value {
    return builder.create<arith::MulFOp>(a, b);
  };

  // Polynomial approximation from Cephes.
  //
  // To compute e^x, we re-express it as
  //
  //   e^x = e^(a + b)
  //       = e^(a + n log(2))
  //       = e^a * 2^n.
  //
  // We choose n = round(x / log(2)), restricting the value of `a` to
  // (-log(2)/2, log(2)/2).  We then use a polynomial to compute e^a. The
  // relative error between our approximation and the true value of e^a is less
  // than 2^-22.5 for all values of `a` within this range.
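  //
  // Illustrative example: for x = 10, n = round(10 / log(2)) = 14 and
  // a = 10 - 14 * log(2) ~= 0.2959, so e^10 ~= e^0.2959 * 2^14 ~= 22026.5.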

  // Restrict input to a small range, including some values that evaluate to
  // +/- inf.  Note that for our lower bound, we choose log(2^-126) instead of
  // log(F32_EPSILON). We do so because this routine always flushes denormal
  // floating points to 0. Therefore, we only need to worry about exponentiating
  // up to the smallest representable non-denormal floating point, which is
  // 2^-126.

  // Constants.
  Value cstHalf = bcast(f32Cst(builder, 0.5f));
  Value cstOne = bcast(f32Cst(builder, 1.0f));

  // 1/log(2)
  Value cstLog2ef = bcast(f32Cst(builder, 1.44269504088896341f));

  Value cstExpC1 = bcast(f32Cst(builder, -0.693359375f));
  Value cstExpC2 = bcast(f32Cst(builder, 2.12194440e-4f));
  Value cstExpP0 = bcast(f32Cst(builder, 1.9875691500E-4f));
  Value cstExpP1 = bcast(f32Cst(builder, 1.3981999507E-3f));
  Value cstExpP2 = bcast(f32Cst(builder, 8.3334519073E-3f));
  Value cstExpP3 = bcast(f32Cst(builder, 4.1665795894E-2f));
  Value cstExpP4 = bcast(f32Cst(builder, 1.6666665459E-1f));
  Value cstExpP5 = bcast(f32Cst(builder, 5.0000001201E-1f));

  // Our computations below aren't particularly sensitive to the exact choices
  // here, so we choose values a bit larger/smaller than
  //
  //   log(F32_MAX) = 88.723...
  //   log(2^-126) = -87.337...
  Value x = op.getOperand();
  x = clampWithNormals(builder, shape, x, -87.8f, 88.8f);
  Value n = floor(fmla(x, cstLog2ef, cstHalf));

  // When we eventually do the multiplication in e^a * 2^n, we need to handle
  // the case when n > 127, the max fp32 exponent (so 2^n == inf) but e^a < 1
  // (so e^a * 2^n != inf).  There's a similar problem for n < -126, the
  // smallest fp32 exponent.
  //
  // A straightforward solution would be to detect n out of range and split it
  // up, doing
  //
  //   e^a * 2^n = e^a * 2^(n1 + n2)
  //             = (2^n1 * e^a) * 2^n2.
  //
  // But it turns out this approach is quite slow, probably because it
  // manipulates subnormal values.
  //
  // The approach we use instead is to clamp n to [-127, 127]. Let n' be the
  // value of n clamped to [-127, 127]. In the case where n' = 127, `a` can grow
  // up to as large as 88.8 - 127 * log(2) which is about 0.7703. Even though
  // this value of `a` is outside our previously specified range, e^a will still
  // only have a relative error of approximately 2^-16 at worst. In practice
  // this seems to work well enough; it passes our exhaustive tests, breaking
  // only one result, and by one ulp (we return exp(88.7228394) = max-float but
  // we should return inf).
1251   //
1252   // In the case where n' = -127, the original input value of x is so small that
1253   // e^x, our final answer, is less than 2^-126. Since 2^-126 is the smallest
1254   // normal floating point, and since we flush denormals, we simply return 0. We
1255   // do this in a branchless way by observing that our code for constructing 2^n
1256   // produces 0 if n = -127.
1257   //
1258   // The proof that n' = -127 implies e^x < 2^-126 is as follows:
1259   //
1260   //    n' = -127 implies n <= -127
1261   //              implies round(x / log(2)) <= -127
1262   //              implies x/log(2) < -126.5
1263   //              implies x < -126.5 * log(2)
1264   //              implies e^x < e^(-126.5 * log(2))
1265   //              implies e^x < 2^-126.5 < 2^-126
1266   //
1267   //    This proves that n' = -127 implies e^x < 2^-126.
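  //
  //    For intuition: assuming the 2^n construction below (exp2I32) forms the
  //    float by placing the biased exponent into the exponent bits, i.e.
  //    bitcast<f32>((n + 127) << 23), then n' = -127 gives a biased exponent of
  //    0 and a zero mantissa, which is exactly +0.0f -- the flush-to-zero
  //    behavior relied on above.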
1268   n = clampWithNormals(builder, shape, n, -127.0f, 127.0f);
1269 
1270   // Computes x = x - n' * log(2), the value for `a`.
1271   x = fmla(cstExpC1, n, x);
1272   x = fmla(cstExpC2, n, x);
1273 
1274   // Polynomial to compute z = e^a, accurate for a in (-0.5, 0.5).
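  // Written out, the Horner scheme below evaluates
  //   e^a ~= 1 + a + a^2 * (P5 + a*(P4 + a*(P3 + a*(P2 + a*(P1 + a*P0))))).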
1275   Value z = fmla(x, cstExpP0, cstExpP1);
1276   z = fmla(z, x, cstExpP2);
1277   z = fmla(z, x, cstExpP3);
1278   z = fmla(z, x, cstExpP4);
1279   z = fmla(z, x, cstExpP5);
1280   z = fmla(z, mul(x, x), x);
1281   z = add(cstOne, z);
1282 
1283   // Convert n' to an i32.  This is safe because we clamped it above.
1284   auto i32Vec = broadcast(builder.getI32Type(), shape);
1285   Value nI32 = builder.create<arith::FPToSIOp>(i32Vec, n);
1286 
1287   // Creates the value 2^n' if -126 <= n' <= 127 and 0 if n' = -127.
1288   Value pow2 = exp2I32(builder, nI32);
1289 
1290   // Return z * 2^n' if -126 <= n' <= 127 and 0 if n' = -127.
1291   Value ret = mul(z, pow2);
1292 
1293   rewriter.replaceOp(op, ret);
1294   return mlir::success();
1295 }
1296 
1297 } // namespace
1298 
1299 //----------------------------------------------------------------------------//
1300 // ExpM1 approximation.
1301 //----------------------------------------------------------------------------//
1302 
1303 namespace {
1304 
1305 struct ExpM1Approximation : public OpRewritePattern<math::ExpM1Op> {
1306 public:
1307   using OpRewritePattern::OpRewritePattern;
1308 
1309   LogicalResult matchAndRewrite(math::ExpM1Op op,
1310                                 PatternRewriter &rewriter) const final;
1311 };
1312 } // namespace
1313 
1314 LogicalResult
1315 ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
1316                                     PatternRewriter &rewriter) const {
1317   if (!getElementTypeOrSelf(op.getOperand()).isF32())
1318     return rewriter.notifyMatchFailure(op, "unsupported operand type");
1319 
1320   std::optional<VectorShape> shape = vectorShape(op.getOperand());
1321 
1322   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
1323   auto bcast = [&](Value value) -> Value {
1324     return broadcast(builder, value, shape);
1325   };
1326 
1327   // expm1(x) = exp(x) - 1 = u - 1.
1328   // We have to handle it carefully when x is near 0, i.e. u ~= 1,
1329   // and when the input is ~= -inf, i.e. u - 1 ~= -1.
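  //
  // For intuition, the correction used below in scalar form (an illustrative
  // sketch only; expm1Approx is a hypothetical name):
  //
  //   float expm1Approx(float x) {
  //     float u = std::exp(x);
  //     if (u == 1.0f) return x;          // exp(x) rounded to 1: expm1(x) ~= x
  //     float um1 = u - 1.0f;
  //     if (um1 == -1.0f) return -1.0f;   // exp(x) rounded to 0: expm1(x) ~= -1
  //     float t = std::log(u);            // t ~= x, but consistent with the
  //                                       // rounded u
  //     if (t == u) return u;             // exp(x) overflowed to +inf
  //     return um1 * (x / t);             // rescale (u - 1)/t from t back to x
  //   }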
1330   Value cstOne = bcast(f32Cst(builder, 1.0f));
1331   Value cstNegOne = bcast(f32Cst(builder, -1.0f));
1332   Value x = op.getOperand();
1333   Value u = builder.create<math::ExpOp>(x);
1334   Value uEqOneOrNaN =
1335       builder.create<arith::CmpFOp>(arith::CmpFPredicate::UEQ, u, cstOne);
1336   Value uMinusOne = builder.create<arith::SubFOp>(u, cstOne);
1337   Value uMinusOneEqNegOne = builder.create<arith::CmpFOp>(
1338       arith::CmpFPredicate::OEQ, uMinusOne, cstNegOne);
1339   // logU = log(u) ~= x
1340   Value logU = builder.create<math::LogOp>(u);
1341 
1342   // Detect exp(x) = +inf; written this way to avoid having to form +inf.
1343   Value isInf =
1344       builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, logU, u);
1345 
1346   // expm1(x) ~= (u - 1) * (x / log(u)), where log(u) ~= x.
1347   Value expm1 = builder.create<arith::MulFOp>(
1348       uMinusOne, builder.create<arith::DivFOp>(x, logU));
1349   expm1 = builder.create<arith::SelectOp>(isInf, u, expm1);
1350   Value approximation = builder.create<arith::SelectOp>(
1351       uEqOneOrNaN, x,
1352       builder.create<arith::SelectOp>(uMinusOneEqNegOne, cstNegOne, expm1));
1353   rewriter.replaceOp(op, approximation);
1354   return success();
1355 }
1356 
1357 //----------------------------------------------------------------------------//
1358 // Sin and Cos approximation.
1359 //----------------------------------------------------------------------------//
1360 
1361 namespace {
1362 
1363 template <bool isSine, typename OpTy>
1364 struct SinAndCosApproximation : public OpRewritePattern<OpTy> {
1365 public:
1366   using OpRewritePattern<OpTy>::OpRewritePattern;
1367 
1368   LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const final;
1369 };
1370 } // namespace
1371 
1372 #define TWO_OVER_PI                                                            \
1373   0.6366197723675813430755350534900574481378385829618257949906693762L
1374 #define PI_OVER_2                                                              \
1375   1.5707963267948966192313216916397514420985846996875529104874722961L
1376 
1377 // Approximates sin(x) or cos(x) using polynomial approximations of sin and cos
1378 // on the reduced range [0, pi/2]. Given y in that range, sin(x) (or cos(x)) is
1379 // then computed as sin(y), -sin(y), cos(y) or -cos(y), chosen by quadrant.
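//
// For intuition, the sine case in scalar form (an illustrative sketch;
// sinPoly/cosPoly stand for the odd/even polynomials evaluated below):
//
//   float sinApprox(float x) {
//     float k = std::floor(x * (float)TWO_OVER_PI); // quadrant index
//     float y = x - k * (float)PI_OVER_2;           // y in [0, pi/2)
//     int q = (int)k & 3;
//     float p = (q & 1) ? cosPoly(y) : sinPoly(y);  // pick the polynomial
//     return q > 1 ? -p : p;                        // fix the sign
//   }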
1380 template <bool isSine, typename OpTy>
1381 LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
1382     OpTy op, PatternRewriter &rewriter) const {
1383   static_assert(
1384       llvm::is_one_of<OpTy, math::SinOp, math::CosOp>::value,
1385       "SinAndCosApproximation pattern expects math::SinOp or math::CosOp");
1386 
1387   if (!getElementTypeOrSelf(op.getOperand()).isF32())
1388     return rewriter.notifyMatchFailure(op, "unsupported operand type");
1389 
1390   std::optional<VectorShape> shape = vectorShape(op.getOperand());
1391 
1392   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
1393   auto bcast = [&](Value value) -> Value {
1394     return broadcast(builder, value, shape);
1395   };
1396   auto mul = [&](Value a, Value b) -> Value {
1397     return builder.create<arith::MulFOp>(a, b);
1398   };
1399   auto sub = [&](Value a, Value b) -> Value {
1400     return builder.create<arith::SubFOp>(a, b);
1401   };
1402   auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); };
1403 
1404   auto i32Vec = broadcast(builder.getI32Type(), shape);
1405   auto fPToSignedInteger = [&](Value a) -> Value {
1406     return builder.create<arith::FPToSIOp>(i32Vec, a);
1407   };
1408 
1409   auto modulo4 = [&](Value a) -> Value {
1410     return builder.create<arith::AndIOp>(a, bcast(i32Cst(builder, 3)));
1411   };
1412 
1413   auto isEqualTo = [&](Value a, Value b) -> Value {
1414     return builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, a, b);
1415   };
1416 
1417   auto isGreaterThan = [&](Value a, Value b) -> Value {
1418     return builder.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, a, b);
1419   };
1420 
1421   auto select = [&](Value cond, Value t, Value f) -> Value {
1422     return builder.create<arith::SelectOp>(cond, t, f);
1423   };
1424 
1425   auto fmla = [&](Value a, Value b, Value c) {
1426     return builder.create<math::FmaOp>(a, b, c);
1427   };
1428 
1429   auto bitwiseOr = [&](Value a, Value b) {
1430     return builder.create<arith::OrIOp>(a, b);
1431   };
1432 
1433   Value twoOverPi = bcast(f32Cst(builder, (float)TWO_OVER_PI));
1434   Value piOverTwo = bcast(f32Cst(builder, (float)PI_OVER_2));
1435 
1436   Value x = op.getOperand();
1437 
1438   Value k = floor(mul(x, twoOverPi));
1439 
1440   Value y = sub(x, mul(k, piOverTwo));
1441 
1442   Value cstOne = bcast(f32Cst(builder, 1.0));
1443   Value cstNegativeOne = bcast(f32Cst(builder, -1.0));
1444 
1445   Value cstSC2 = bcast(f32Cst(builder, -0.16666667163372039794921875f));
1446   Value cstSC4 = bcast(f32Cst(builder, 8.333347737789154052734375e-3f));
1447   Value cstSC6 = bcast(f32Cst(builder, -1.9842604524455964565277099609375e-4f));
1448   Value cstSC8 =
1449       bcast(f32Cst(builder, 2.760012648650445044040679931640625e-6f));
1450   Value cstSC10 =
1451       bcast(f32Cst(builder, -2.50293279435709337121807038784027099609375e-8f));
1452 
1453   Value cstCC2 = bcast(f32Cst(builder, -0.5f));
1454   Value cstCC4 = bcast(f32Cst(builder, 4.166664183139801025390625e-2f));
1455   Value cstCC6 = bcast(f32Cst(builder, -1.388833043165504932403564453125e-3f));
1456   Value cstCC8 = bcast(f32Cst(builder, 2.47562347794882953166961669921875e-5f));
1457   Value cstCC10 =
1458       bcast(f32Cst(builder, -2.59630184018533327616751194000244140625e-7f));
1459 
1460   Value kMod4 = modulo4(fPToSignedInteger(k));
1461 
1462   Value kR0 = isEqualTo(kMod4, bcast(i32Cst(builder, 0)));
1463   Value kR1 = isEqualTo(kMod4, bcast(i32Cst(builder, 1)));
1464   Value kR2 = isEqualTo(kMod4, bcast(i32Cst(builder, 2)));
1465   Value kR3 = isEqualTo(kMod4, bcast(i32Cst(builder, 3)));
1466 
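  // With q = k mod 4, the mapping implemented by the two masks below is:
  //   sin: q=0 ->  sin(y), q=1 ->  cos(y), q=2 -> -sin(y), q=3 -> -cos(y)
  //   cos: q=0 ->  cos(y), q=1 -> -sin(y), q=2 -> -cos(y), q=3 ->  sin(y)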
1467   Value sinuseCos = isSine ? bitwiseOr(kR1, kR3) : bitwiseOr(kR0, kR2);
1468   Value negativeRange = isSine ? isGreaterThan(kMod4, bcast(i32Cst(builder, 1)))
1469                                : bitwiseOr(kR1, kR2);
1470 
1471   Value y2 = mul(y, y);
1472 
1473   Value base = select(sinuseCos, cstOne, y);
1474   Value cstC2 = select(sinuseCos, cstCC2, cstSC2);
1475   Value cstC4 = select(sinuseCos, cstCC4, cstSC4);
1476   Value cstC6 = select(sinuseCos, cstCC6, cstSC6);
1477   Value cstC8 = select(sinuseCos, cstCC8, cstSC8);
1478   Value cstC10 = select(sinuseCos, cstCC10, cstSC10);
1479 
1480   Value v1 = fmla(y2, cstC10, cstC8);
1481   Value v2 = fmla(y2, v1, cstC6);
1482   Value v3 = fmla(y2, v2, cstC4);
1483   Value v4 = fmla(y2, v3, cstC2);
1484   Value v5 = fmla(y2, v4, cstOne);
1485   Value v6 = mul(base, v5);
1486 
1487   Value approximation = select(negativeRange, mul(cstNegativeOne, v6), v6);
1488 
1489   rewriter.replaceOp(op, approximation);
1490 
1491   return success();
1492 }
1493 
1494 //----------------------------------------------------------------------------//
1495 // Cbrt approximation.
1496 //----------------------------------------------------------------------------//
1497 
1498 namespace {
1499 struct CbrtApproximation : public OpRewritePattern<math::CbrtOp> {
1500   using OpRewritePattern::OpRewritePattern;
1501 
1502   LogicalResult matchAndRewrite(math::CbrtOp op,
1503                                 PatternRewriter &rewriter) const final;
1504 };
1505 } // namespace
1506 
1507 // Estimation of cube-root using an algorithm defined in
1508 // Hacker's Delight 2nd Edition.
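//
// For intuition, the same algorithm in scalar form (an illustrative sketch;
// cbrtApprox is a hypothetical name):
//
//   float cbrtApprox(float x0) {
//     float x = std::fabs(x0);
//     int32_t ix;
//     std::memcpy(&ix, &x, sizeof(ix));
//     ix = ix / 4 + ix / 16;      // scale the bit pattern (and thus the
//     ix += ix / 16;              // exponent) by roughly one third
//     ix += ix / 256;
//     ix += 0x2a5137a0;           // magic constant fixes the exponent bias
//     std::memcpy(&x, &ix, sizeof(x));
//     x = 0.33333333f * (2.0f * x + std::fabs(x0) / (x * x)); // Newton step 1
//     x = 0.33333333f * (2.0f * x + std::fabs(x0) / (x * x)); // Newton step 2
//     return x0 == 0.0f ? 0.0f : std::copysign(x, x0);
//   }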
1509 LogicalResult
1510 CbrtApproximation::matchAndRewrite(math::CbrtOp op,
1511                                    PatternRewriter &rewriter) const {
1512   auto operand = op.getOperand();
1513   if (!getElementTypeOrSelf(operand).isF32())
1514     return rewriter.notifyMatchFailure(op, "unsupported operand type");
1515 
1516   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1517   std::optional<VectorShape> shape = vectorShape(operand);
1518 
1519   Type floatTy = getElementTypeOrSelf(operand.getType());
1520   Type intTy = b.getIntegerType(floatTy.getIntOrFloatBitWidth());
1521 
1522   // Convert to vector types if necessary.
1523   floatTy = broadcast(floatTy, shape);
1524   intTy = broadcast(intTy, shape);
1525 
1526   auto bconst = [&](TypedAttr attr) -> Value {
1527     Value value = b.create<arith::ConstantOp>(attr);
1528     return broadcast(b, value, shape);
1529   };
1530 
1531   // Declare the initial values:
1532   Value intTwo = bconst(b.getI32IntegerAttr(2));
1533   Value intFour = bconst(b.getI32IntegerAttr(4));
1534   Value intEight = bconst(b.getI32IntegerAttr(8));
1535   Value intMagic = bconst(b.getI32IntegerAttr(0x2a5137a0));
1536   Value fpThird = bconst(b.getF32FloatAttr(0.33333333f));
1537   Value fpTwo = bconst(b.getF32FloatAttr(2.0f));
1538   Value fpZero = bconst(b.getF32FloatAttr(0.0f));
1539   // Compute an initial estimate of x^(1/3) by scaling the bit pattern by ~1/3:
1540   // Compute an approximation of one third:
1541   // union {int ix; float x;};
1542   // x = x0;
1543   // ix = ix/4 + ix/16;
1544   Value absValue = b.create<math::AbsFOp>(operand);
1545   Value intValue = b.create<arith::BitcastOp>(intTy, absValue);
1546   Value divideBy4 = b.create<arith::ShRSIOp>(intValue, intTwo);
1547   Value divideBy16 = b.create<arith::ShRSIOp>(intValue, intFour);
1548   intValue = b.create<arith::AddIOp>(divideBy4, divideBy16);
1549 
1550   // ix = ix + ix/16;
1551   divideBy16 = b.create<arith::ShRSIOp>(intValue, intFour);
1552   intValue = b.create<arith::AddIOp>(intValue, divideBy16);
1553 
1554   // ix = ix + ix/256;
1555   Value divideBy256 = b.create<arith::ShRSIOp>(intValue, intEight);
1556   intValue = b.create<arith::AddIOp>(intValue, divideBy256);
1557 
1558   // ix = 0x2a5137a0 + ix;
1559   intValue = b.create<arith::AddIOp>(intValue, intMagic);
1560 
1561   // Perform the first Newton step:
1562   // x = 0.33333333f*(2.0f*x + x0/(x*x));
1563   Value floatValue = b.create<arith::BitcastOp>(floatTy, intValue);
1564   Value squared = b.create<arith::MulFOp>(floatValue, floatValue);
1565   Value mulTwo = b.create<arith::MulFOp>(floatValue, fpTwo);
1566   Value divSquared = b.create<arith::DivFOp>(absValue, squared);
1567   floatValue = b.create<arith::AddFOp>(mulTwo, divSquared);
1568   floatValue = b.create<arith::MulFOp>(floatValue, fpThird);
1569 
1570   // Second Newton step: x = 0.33333333f*(2.0f*x + x0/(x*x));
1571   squared = b.create<arith::MulFOp>(floatValue, floatValue);
1572   mulTwo = b.create<arith::MulFOp>(floatValue, fpTwo);
1573   divSquared = b.create<arith::DivFOp>(absValue, squared);
1574   floatValue = b.create<arith::AddFOp>(mulTwo, divSquared);
1575   floatValue = b.create<arith::MulFOp>(floatValue, fpThird);
1576 
1577   // Check for zero and restore sign.
1578   Value isZero =
1579       b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absValue, fpZero);
1580   floatValue = b.create<arith::SelectOp>(isZero, fpZero, floatValue);
1581   floatValue = b.create<math::CopySignOp>(floatValue, operand);
1582 
1583   rewriter.replaceOp(op, floatValue);
1584   return success();
1585 }
1586 
1587 //----------------------------------------------------------------------------//
1588 // Rsqrt approximation.
1589 //----------------------------------------------------------------------------//
1590 
1591 namespace {
1592 struct RsqrtApproximation : public OpRewritePattern<math::RsqrtOp> {
1593   using OpRewritePattern::OpRewritePattern;
1594 
1595   LogicalResult matchAndRewrite(math::RsqrtOp op,
1596                                 PatternRewriter &rewriter) const final;
1597 };
1598 } // namespace
1599 
1600 LogicalResult
1601 RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
1602                                     PatternRewriter &rewriter) const {
1603   if (!getElementTypeOrSelf(op.getOperand()).isF32())
1604     return rewriter.notifyMatchFailure(op, "unsupported operand type");
1605 
1606   std::optional<VectorShape> shape = vectorShape(op.getOperand());
1607 
1608   // Only support already-vectorized rsqrt ops (innermost dim a multiple of 8).
1609   if (!shape || shape->sizes.empty() || shape->sizes.back() % 8 != 0)
1610     return rewriter.notifyMatchFailure(op, "unsupported vector shape");
1611 
1612   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
1613   auto bcast = [&](Value value) -> Value {
1614     return broadcast(builder, value, shape);
1615   };
1616 
1617   Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u));
1618   Value cstOnePointFive = bcast(f32Cst(builder, 1.5f));
1619   Value cstNegHalf = bcast(f32Cst(builder, -0.5f));
1620   Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u));
1621 
1622   Value negHalf = builder.create<arith::MulFOp>(op.getOperand(), cstNegHalf);
1623 
1624   // Select only the inverse sqrt of positive normals (denormals are
1625   // flushed to zero).
1626   Value ltMinMask = builder.create<arith::CmpFOp>(
1627       arith::CmpFPredicate::OLT, op.getOperand(), cstMinNormPos);
1628   Value infMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
1629                                                 op.getOperand(), cstPosInf);
1630   Value notNormalFiniteMask = builder.create<arith::OrIOp>(ltMinMask, infMask);
1631 
1632   // Compute an approximate result.
1633   Value yApprox = handleMultidimensionalVectors(
1634       builder, op->getOperands(), 8, [&builder](ValueRange operands) -> Value {
1635         return builder.create<x86vector::RsqrtOp>(operands);
1636       });
1637 
1638   // Do a single step of Newton-Raphson iteration to improve the approximation.
1639   // This uses the formula y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n).
1640   // It is essential to evaluate the inner term like this because forming
1641   // y_n^2 may over- or underflow.
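  //
  // In scalar form, a single step is (illustrative only; y0 is the hardware
  // rsqrt estimate):
  //
  //   float y1 = y0 * std::fma(y0, x * -0.5f * y0, 1.5f); // y0*(1.5-0.5*x*y0^2)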
1642   Value inner = builder.create<arith::MulFOp>(negHalf, yApprox);
1643   Value fma = builder.create<math::FmaOp>(yApprox, inner, cstOnePointFive);
1644   Value yNewton = builder.create<arith::MulFOp>(yApprox, fma);
1645 
1646   // Select the result of the Newton-Raphson step for positive normal arguments.
1647   // For other arguments, choose the output of the intrinsic. This will
1648   // return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(x) = +inf if
1649   // x is zero or a positive denormalized float (equivalent to flushing positive
1650   // denormalized inputs to zero).
1651   Value res =
1652       builder.create<arith::SelectOp>(notNormalFiniteMask, yApprox, yNewton);
1653   rewriter.replaceOp(op, res);
1654 
1655   return success();
1656 }
1657 
1658 //----------------------------------------------------------------------------//
1659 
1660 void mlir::populatePolynomialApproximateTanhPattern(
1661     RewritePatternSet &patterns) {
1662   patterns.add<TanhApproximation>(patterns.getContext());
1663 }
1664 
1665 void mlir::populatePolynomialApproximateErfPattern(
1666     RewritePatternSet &patterns) {
1667   patterns.add<ErfPolynomialApproximation>(patterns.getContext());
1668 }
1669 
1670 void mlir::populateMathPolynomialApproximationPatterns(
1671     RewritePatternSet &patterns,
1672     const MathPolynomialApproximationOptions &options) {
1673   // Patterns for leveraging existing f32 lowerings on other data types.
1674   patterns
1675       .add<ReuseF32Expansion<math::AtanOp>, ReuseF32Expansion<math::Atan2Op>,
1676            ReuseF32Expansion<math::TanhOp>, ReuseF32Expansion<math::LogOp>,
1677            ReuseF32Expansion<math::Log2Op>, ReuseF32Expansion<math::Log1pOp>,
1678            ReuseF32Expansion<math::ErfOp>, ReuseF32Expansion<math::ExpOp>,
1679            ReuseF32Expansion<math::ExpM1Op>, ReuseF32Expansion<math::CbrtOp>,
1680            ReuseF32Expansion<math::SinOp>, ReuseF32Expansion<math::CosOp>>(
1681           patterns.getContext());
1682 
1683   patterns
1684       .add<AtanApproximation, Atan2Approximation, TanhApproximation,
1685            LogApproximation, Log2Approximation, Log1pApproximation,
1686            ErfPolynomialApproximation, AsinPolynomialApproximation,
1687            AcosPolynomialApproximation, ExpApproximation, ExpM1Approximation,
1688            CbrtApproximation, SinAndCosApproximation<true, math::SinOp>,
1689            SinAndCosApproximation<false, math::CosOp>>(patterns.getContext());
1690   if (options.enableAvx2) {
1691     patterns.add<RsqrtApproximation, ReuseF32Expansion<math::RsqrtOp>>(
1692         patterns.getContext());
1693   }
1694 }
1695
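// For reference, a minimal sketch of how these entry points are typically used
// from a pass (hypothetical pass body; assumes applyPatternsAndFoldGreedily
// from GreedyPatternRewriteDriver.h, which is included above):
//
//   void MyApproximationPass::runOnOperation() {
//     RewritePatternSet patterns(&getContext());
//     MathPolynomialApproximationOptions options;
//     options.enableAvx2 = false;
//     populateMathPolynomialApproximationPatterns(patterns, options);
//     if (failed(applyPatternsAndFoldGreedily(getOperation(),
//                                             std::move(patterns))))
//       signalPassFailure();
//   }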