xref: /llvm-project/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp (revision 72f362170fbe7834febfe08d8bd1f7ba935ddbc9)
1f99ccf65SEugene Zhulenev //===- PolynomialApproximation.cpp - Approximate math operations ----------===//
2f99ccf65SEugene Zhulenev //
3f99ccf65SEugene Zhulenev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4f99ccf65SEugene Zhulenev // See https://llvm.org/LICENSE.txt for license information.
5f99ccf65SEugene Zhulenev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6f99ccf65SEugene Zhulenev //
7f99ccf65SEugene Zhulenev //===----------------------------------------------------------------------===//
8f99ccf65SEugene Zhulenev //
9f99ccf65SEugene Zhulenev // This file implements expansion of math operations to fast approximations
10f99ccf65SEugene Zhulenev // that do not rely on any of the library functions.
11f99ccf65SEugene Zhulenev //
12f99ccf65SEugene Zhulenev //===----------------------------------------------------------------------===//
133a506b31SChris Lattner 
14ec32d540SEugene Zhulenev #include <climits>
150bedb667SRobert Suderman #include <cmath>
16ec32d540SEugene Zhulenev #include <cstddef>
17ec32d540SEugene Zhulenev 
18abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
19f99ccf65SEugene Zhulenev #include "mlir/Dialect/Math/IR/Math.h"
20f1b92218SBoian Petkantchin #include "mlir/Dialect/Math/Transforms/Approximation.h"
21f99ccf65SEugene Zhulenev #include "mlir/Dialect/Math/Transforms/Passes.h"
2299ef9eebSMatthias Springer #include "mlir/Dialect/Utils/IndexingUtils.h"
2399ef9eebSMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h"
2499ef9eebSMatthias Springer #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
2535553d45SEmilio Cota #include "mlir/Dialect/X86Vector/X86VectorDialect.h"
26f99ccf65SEugene Zhulenev #include "mlir/IR/Builders.h"
27bbddd19eSJacques Pienaar #include "mlir/IR/BuiltinTypes.h"
28ce976d2dSEugene Zhulenev #include "mlir/IR/ImplicitLocOpBuilder.h"
29bbddd19eSJacques Pienaar #include "mlir/IR/OpDefinition.h"
30bbddd19eSJacques Pienaar #include "mlir/IR/PatternMatch.h"
31ec32d540SEugene Zhulenev #include "mlir/IR/TypeUtilities.h"
32f99ccf65SEugene Zhulenev #include "mlir/Transforms/DialectConversion.h"
33f99ccf65SEugene Zhulenev #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
34f1b92218SBoian Petkantchin #include "llvm/ADT/ArrayRef.h"
35bbddd19eSJacques Pienaar #include "llvm/ADT/STLExtras.h"
3675bd41f6SJacques Pienaar #include "llvm/Support/MathExtras.h"
37f99ccf65SEugene Zhulenev 
38f99ccf65SEugene Zhulenev using namespace mlir;
39f1b92218SBoian Petkantchin using namespace mlir::math;
40f99ccf65SEugene Zhulenev using namespace mlir::vector;
41f99ccf65SEugene Zhulenev 
42e74bcecdSBenjamin Maxwell // Helper to encapsulate a vector's shape (including scalable dims).
43e74bcecdSBenjamin Maxwell struct VectorShape {
44e74bcecdSBenjamin Maxwell   ArrayRef<int64_t> sizes;
45e74bcecdSBenjamin Maxwell   ArrayRef<bool> scalableFlags;
46e74bcecdSBenjamin Maxwell };
47e74bcecdSBenjamin Maxwell 
48*72f36217SKunwar Grover // Returns vector shape if the type is a vector, otherwise return nullopt.
49*72f36217SKunwar Grover static std::optional<VectorShape> vectorShape(Type type) {
50*72f36217SKunwar Grover   if (auto vectorType = dyn_cast<VectorType>(type)) {
51*72f36217SKunwar Grover     return VectorShape{vectorType.getShape(), vectorType.getScalableDims()};
52*72f36217SKunwar Grover   }
53*72f36217SKunwar Grover   return std::nullopt;
54ce976d2dSEugene Zhulenev }
55ce976d2dSEugene Zhulenev 
56*72f36217SKunwar Grover static std::optional<VectorShape> vectorShape(Value value) {
57ec32d540SEugene Zhulenev   return vectorShape(value.getType());
5839b2cd40SEugene Zhulenev }
5939b2cd40SEugene Zhulenev 
60f99ccf65SEugene Zhulenev //----------------------------------------------------------------------------//
61ce976d2dSEugene Zhulenev // Broadcast scalar types and values into vector types and values.
62f99ccf65SEugene Zhulenev //----------------------------------------------------------------------------//
63f99ccf65SEugene Zhulenev 
6496cee297SAlexander Belyaev // Broadcasts scalar type into vector type (iff shape is non-scalar).
65*72f36217SKunwar Grover static Type broadcast(Type type, std::optional<VectorShape> shape) {
665550c821STres Popp   assert(!isa<VectorType>(type) && "must be scalar type");
67*72f36217SKunwar Grover   return shape ? VectorType::get(shape->sizes, type, shape->scalableFlags)
68e74bcecdSBenjamin Maxwell                : type;
6996cee297SAlexander Belyaev }
7096cee297SAlexander Belyaev 
7196cee297SAlexander Belyaev // Broadcasts scalar value into vector (iff shape is non-scalar).
7296cee297SAlexander Belyaev static Value broadcast(ImplicitLocOpBuilder &builder, Value value,
73*72f36217SKunwar Grover                        std::optional<VectorShape> shape) {
745550c821STres Popp   assert(!isa<VectorType>(value.getType()) && "must be scalar value");
7596cee297SAlexander Belyaev   auto type = broadcast(value.getType(), shape);
76*72f36217SKunwar Grover   return shape ? builder.create<BroadcastOp>(type, value) : value;
77ce976d2dSEugene Zhulenev }
78f99ccf65SEugene Zhulenev 
79ce976d2dSEugene Zhulenev //----------------------------------------------------------------------------//
80627fa0b9SEugene Zhulenev // Helper function to handle n-D vectors with 1-D operations.
81627fa0b9SEugene Zhulenev //----------------------------------------------------------------------------//
82627fa0b9SEugene Zhulenev 
83627fa0b9SEugene Zhulenev // Expands and unrolls n-D vector operands into multiple fixed size 1-D vectors
84627fa0b9SEugene Zhulenev // and calls the compute function with 1-D vector operands. Stitches back all
85627fa0b9SEugene Zhulenev // results into the original n-D vector result.
86627fa0b9SEugene Zhulenev //
87627fa0b9SEugene Zhulenev // Examples: vectorWidth = 8
88627fa0b9SEugene Zhulenev //   - vector<4x8xf32> unrolled 4 times
89627fa0b9SEugene Zhulenev //   - vector<16xf32> expanded to vector<2x8xf32> and unrolled 2 times
90627fa0b9SEugene Zhulenev //   - vector<4x16xf32> expanded to vector<4x2x8xf32> and unrolled 4*2 times
91627fa0b9SEugene Zhulenev //
92627fa0b9SEugene Zhulenev // Some math approximations rely on ISA-specific operations that only accept
93627fa0b9SEugene Zhulenev // fixed size 1-D vectors (e.g. AVX expects vectors of width 8).
94627fa0b9SEugene Zhulenev //
95627fa0b9SEugene Zhulenev // It is the caller's responsibility to verify that the inner dimension is
96627fa0b9SEugene Zhulenev // divisible by the vectorWidth, and that all operands have the same vector
97627fa0b9SEugene Zhulenev // shape.
98627fa0b9SEugene Zhulenev static Value
99627fa0b9SEugene Zhulenev handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
100627fa0b9SEugene Zhulenev                               ValueRange operands, int64_t vectorWidth,
1011fc096afSMehdi Amini                               llvm::function_ref<Value(ValueRange)> compute) {
102627fa0b9SEugene Zhulenev   assert(!operands.empty() && "operands must be not empty");
103627fa0b9SEugene Zhulenev   assert(vectorWidth > 0 && "vector width must be larger than 0");
104627fa0b9SEugene Zhulenev 
1055550c821STres Popp   VectorType inputType = cast<VectorType>(operands[0].getType());
106627fa0b9SEugene Zhulenev   ArrayRef<int64_t> inputShape = inputType.getShape();
107627fa0b9SEugene Zhulenev 
108627fa0b9SEugene Zhulenev   // If input shape matches target vector width, we can just call the
109627fa0b9SEugene Zhulenev   // user-provided compute function with the operands.
110984b800aSserge-sans-paille   if (inputShape == llvm::ArrayRef(vectorWidth))
111627fa0b9SEugene Zhulenev     return compute(operands);
112627fa0b9SEugene Zhulenev 
113627fa0b9SEugene Zhulenev   // Check if the inner dimension has to be expanded, or we can directly iterate
114627fa0b9SEugene Zhulenev   // over the outer dimensions of the vector.
115627fa0b9SEugene Zhulenev   int64_t innerDim = inputShape.back();
116627fa0b9SEugene Zhulenev   int64_t expansionDim = innerDim / vectorWidth;
117627fa0b9SEugene Zhulenev   assert((innerDim % vectorWidth == 0) && "invalid inner dimension size");
118627fa0b9SEugene Zhulenev 
119627fa0b9SEugene Zhulenev   // Maybe expand operands to the higher rank vector shape that we'll use to
120627fa0b9SEugene Zhulenev   // iterate over and extract one dimensional vectors.
1215262865aSKazu Hirata   SmallVector<int64_t> expandedShape(inputShape);
122627fa0b9SEugene Zhulenev   SmallVector<Value> expandedOperands(operands);
123627fa0b9SEugene Zhulenev 
124627fa0b9SEugene Zhulenev   if (expansionDim > 1) {
125627fa0b9SEugene Zhulenev     // Expand shape from [..., innerDim] to [..., expansionDim, vectorWidth].
126627fa0b9SEugene Zhulenev     expandedShape.insert(expandedShape.end() - 1, expansionDim);
127627fa0b9SEugene Zhulenev     expandedShape.back() = vectorWidth;
128627fa0b9SEugene Zhulenev 
129627fa0b9SEugene Zhulenev     for (unsigned i = 0; i < operands.size(); ++i) {
130627fa0b9SEugene Zhulenev       auto operand = operands[i];
1315550c821STres Popp       auto eltType = cast<VectorType>(operand.getType()).getElementType();
132627fa0b9SEugene Zhulenev       auto expandedType = VectorType::get(expandedShape, eltType);
133627fa0b9SEugene Zhulenev       expandedOperands[i] =
134627fa0b9SEugene Zhulenev           builder.create<vector::ShapeCastOp>(expandedType, operand);
135627fa0b9SEugene Zhulenev     }
136627fa0b9SEugene Zhulenev   }
137627fa0b9SEugene Zhulenev 
138627fa0b9SEugene Zhulenev   // Iterate over all outer dimensions of the compute shape vector type.
139627fa0b9SEugene Zhulenev   auto iterationDims = ArrayRef<int64_t>(expandedShape).drop_back();
1407a69a9d7SNicolas Vasilache   int64_t maxIndex = computeMaxLinearIndex(iterationDims);
1417a69a9d7SNicolas Vasilache   auto strides = computeStrides(iterationDims);
142627fa0b9SEugene Zhulenev 
143627fa0b9SEugene Zhulenev   // Compute results for each one dimensional vector.
1447a69a9d7SNicolas Vasilache   SmallVector<Value> results(maxIndex);
145627fa0b9SEugene Zhulenev 
1467a69a9d7SNicolas Vasilache   for (int64_t i = 0; i < maxIndex; ++i) {
147203fad47SNicolas Vasilache     auto offsets = delinearize(i, strides);
148627fa0b9SEugene Zhulenev 
149627fa0b9SEugene Zhulenev     SmallVector<Value> extracted(expandedOperands.size());
15089de9cc8SMehdi Amini     for (const auto &tuple : llvm::enumerate(expandedOperands))
151627fa0b9SEugene Zhulenev       extracted[tuple.index()] =
152627fa0b9SEugene Zhulenev           builder.create<vector::ExtractOp>(tuple.value(), offsets);
153627fa0b9SEugene Zhulenev 
154627fa0b9SEugene Zhulenev     results[i] = compute(extracted);
155627fa0b9SEugene Zhulenev   }
156627fa0b9SEugene Zhulenev 
157627fa0b9SEugene Zhulenev   // Stitch results together into one large vector.
1585550c821STres Popp   Type resultEltType = cast<VectorType>(results[0].getType()).getElementType();
159627fa0b9SEugene Zhulenev   Type resultExpandedType = VectorType::get(expandedShape, resultEltType);
1608e123ca6SRiver Riddle   Value result = builder.create<arith::ConstantOp>(
161627fa0b9SEugene Zhulenev       resultExpandedType, builder.getZeroAttr(resultExpandedType));
162627fa0b9SEugene Zhulenev 
1637a69a9d7SNicolas Vasilache   for (int64_t i = 0; i < maxIndex; ++i)
164627fa0b9SEugene Zhulenev     result = builder.create<vector::InsertOp>(results[i], result,
165203fad47SNicolas Vasilache                                               delinearize(i, strides));
166627fa0b9SEugene Zhulenev 
167627fa0b9SEugene Zhulenev   // Reshape back to the original vector shape.
168627fa0b9SEugene Zhulenev   return builder.create<vector::ShapeCastOp>(
169627fa0b9SEugene Zhulenev       VectorType::get(inputShape, resultEltType), result);
170627fa0b9SEugene Zhulenev }
171627fa0b9SEugene Zhulenev 
172627fa0b9SEugene Zhulenev //----------------------------------------------------------------------------//
173ce976d2dSEugene Zhulenev // Helper functions to create constants.
174ce976d2dSEugene Zhulenev //----------------------------------------------------------------------------//
175f99ccf65SEugene Zhulenev 
176880b8f4eSPrashant Kumar static Value floatCst(ImplicitLocOpBuilder &builder, float value,
177880b8f4eSPrashant Kumar                       Type elementType) {
17875076f8dSKazu Hirata   assert((elementType.isF16() || elementType.isF32()) &&
17975076f8dSKazu Hirata          "x must be f16 or f32 type.");
180880b8f4eSPrashant Kumar   return builder.create<arith::ConstantOp>(
181880b8f4eSPrashant Kumar       builder.getFloatAttr(elementType, value));
182880b8f4eSPrashant Kumar }
183880b8f4eSPrashant Kumar 
1840bedb667SRobert Suderman static Value f32Cst(ImplicitLocOpBuilder &builder, double value) {
185a54f4eaeSMogball   return builder.create<arith::ConstantOp>(builder.getF32FloatAttr(value));
186ce976d2dSEugene Zhulenev }
187f99ccf65SEugene Zhulenev 
188ce976d2dSEugene Zhulenev static Value i32Cst(ImplicitLocOpBuilder &builder, int32_t value) {
189a54f4eaeSMogball   return builder.create<arith::ConstantOp>(builder.getI32IntegerAttr(value));
190ce976d2dSEugene Zhulenev }
191f99ccf65SEugene Zhulenev 
192ce976d2dSEugene Zhulenev static Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits) {
193ce976d2dSEugene Zhulenev   Value i32Value = i32Cst(builder, static_cast<int32_t>(bits));
194a54f4eaeSMogball   return builder.create<arith::BitcastOp>(builder.getF32Type(), i32Value);
195ce976d2dSEugene Zhulenev }
196f99ccf65SEugene Zhulenev 
197ce976d2dSEugene Zhulenev //----------------------------------------------------------------------------//
198ce976d2dSEugene Zhulenev // Helper functions to build math functions approximations.
199ce976d2dSEugene Zhulenev //----------------------------------------------------------------------------//
200ce976d2dSEugene Zhulenev 
201f5efe280STres Popp // Return the minimum of the two values or NaN if value is NaN
202f5efe280STres Popp static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound) {
203dec8af70SRiver Riddle   return builder.create<arith::SelectOp>(
204f5efe280STres Popp       builder.create<arith::CmpFOp>(arith::CmpFPredicate::ULT, value, bound),
205f5efe280STres Popp       value, bound);
206ce976d2dSEugene Zhulenev }
207ce976d2dSEugene Zhulenev 
208f5efe280STres Popp // Return the maximum of the two values or NaN if value is NaN
209f5efe280STres Popp static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound) {
210dec8af70SRiver Riddle   return builder.create<arith::SelectOp>(
211f5efe280STres Popp       builder.create<arith::CmpFOp>(arith::CmpFPredicate::UGT, value, bound),
212f5efe280STres Popp       value, bound);
213ce976d2dSEugene Zhulenev }
214ce976d2dSEugene Zhulenev 
215f5efe280STres Popp // Return the clamped value or NaN if value is NaN
216ce976d2dSEugene Zhulenev static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound,
217ce976d2dSEugene Zhulenev                    Value upperBound) {
218ce976d2dSEugene Zhulenev   return max(builder, min(builder, value, upperBound), lowerBound);
219ce976d2dSEugene Zhulenev }
220ce976d2dSEugene Zhulenev 
221ce976d2dSEugene Zhulenev // Decomposes given floating point value `arg` into a normalized fraction and
222ce976d2dSEugene Zhulenev // an integral power of two (see std::frexp). Returned values have float type.
223ce976d2dSEugene Zhulenev static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
22402b6fb21SMehdi Amini                                      bool isPositive = false) {
225ec32d540SEugene Zhulenev   assert(getElementTypeOrSelf(arg).isF32() && "arg must be f32 type");
226*72f36217SKunwar Grover   std::optional<VectorShape> shape = vectorShape(arg);
227ce976d2dSEugene Zhulenev 
228ce976d2dSEugene Zhulenev   auto bcast = [&](Value value) -> Value {
22996cee297SAlexander Belyaev     return broadcast(builder, value, shape);
230f99ccf65SEugene Zhulenev   };
231f99ccf65SEugene Zhulenev 
232ce976d2dSEugene Zhulenev   auto i32 = builder.getIntegerType(32);
23396cee297SAlexander Belyaev   auto i32Vec = broadcast(i32, shape);
23496cee297SAlexander Belyaev   auto f32Vec = broadcast(builder.getF32Type(), shape);
235f99ccf65SEugene Zhulenev 
236ce976d2dSEugene Zhulenev   Value cst126f = f32Cst(builder, 126.0f);
237ce976d2dSEugene Zhulenev   Value cstHalf = f32Cst(builder, 0.5f);
238ce976d2dSEugene Zhulenev   Value cstInvMantMask = f32FromBits(builder, ~0x7f800000u);
239f99ccf65SEugene Zhulenev 
240ce976d2dSEugene Zhulenev   // Bitcast to i32 for bitwise operations.
241a54f4eaeSMogball   Value i32Half = builder.create<arith::BitcastOp>(i32, cstHalf);
242a54f4eaeSMogball   Value i32InvMantMask = builder.create<arith::BitcastOp>(i32, cstInvMantMask);
243a54f4eaeSMogball   Value i32Arg = builder.create<arith::BitcastOp>(i32Vec, arg);
244f99ccf65SEugene Zhulenev 
245ce976d2dSEugene Zhulenev   // Compute normalized fraction.
246a54f4eaeSMogball   Value tmp0 = builder.create<arith::AndIOp>(i32Arg, bcast(i32InvMantMask));
247a54f4eaeSMogball   Value tmp1 = builder.create<arith::OrIOp>(tmp0, bcast(i32Half));
248a54f4eaeSMogball   Value normalizedFraction = builder.create<arith::BitcastOp>(f32Vec, tmp1);
249f99ccf65SEugene Zhulenev 
250ce976d2dSEugene Zhulenev   // Compute exponent.
25100f7096dSJeff Niu   Value arg0 = isPositive ? arg : builder.create<math::AbsFOp>(arg);
252a54f4eaeSMogball   Value biasedExponentBits = builder.create<arith::ShRUIOp>(
253a54f4eaeSMogball       builder.create<arith::BitcastOp>(i32Vec, arg0),
254a54f4eaeSMogball       bcast(i32Cst(builder, 23)));
255a54f4eaeSMogball   Value biasedExponent =
256a54f4eaeSMogball       builder.create<arith::SIToFPOp>(f32Vec, biasedExponentBits);
257a54f4eaeSMogball   Value exponent =
258a54f4eaeSMogball       builder.create<arith::SubFOp>(biasedExponent, bcast(cst126f));
259f99ccf65SEugene Zhulenev 
260ce976d2dSEugene Zhulenev   return {normalizedFraction, exponent};
261f99ccf65SEugene Zhulenev }
262f99ccf65SEugene Zhulenev 
263ea7f211bSAhmed Taei // Computes exp2 for an i32 argument.
264ea7f211bSAhmed Taei static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) {
265ec32d540SEugene Zhulenev   assert(getElementTypeOrSelf(arg).isInteger(32) && "arg must be i32 type");
266*72f36217SKunwar Grover   std::optional<VectorShape> shape = vectorShape(arg);
267ea7f211bSAhmed Taei 
268ea7f211bSAhmed Taei   auto bcast = [&](Value value) -> Value {
26996cee297SAlexander Belyaev     return broadcast(builder, value, shape);
270ea7f211bSAhmed Taei   };
271ea7f211bSAhmed Taei 
27296cee297SAlexander Belyaev   auto f32Vec = broadcast(builder.getF32Type(), shape);
273ea7f211bSAhmed Taei   // The exponent of f32 located at 23-bit.
274ea7f211bSAhmed Taei   auto exponetBitLocation = bcast(i32Cst(builder, 23));
275ea7f211bSAhmed Taei   // Set the exponent bias to zero.
276ea7f211bSAhmed Taei   auto bias = bcast(i32Cst(builder, 127));
277ea7f211bSAhmed Taei 
278a54f4eaeSMogball   Value biasedArg = builder.create<arith::AddIOp>(arg, bias);
279ea7f211bSAhmed Taei   Value exp2ValueInt =
280a54f4eaeSMogball       builder.create<arith::ShLIOp>(biasedArg, exponetBitLocation);
281a54f4eaeSMogball   Value exp2ValueF32 = builder.create<arith::BitcastOp>(f32Vec, exp2ValueInt);
282ea7f211bSAhmed Taei 
283ea7f211bSAhmed Taei   return exp2ValueF32;
284ea7f211bSAhmed Taei }
285ea7f211bSAhmed Taei 
286f1b92218SBoian Petkantchin namespace {
287f1b92218SBoian Petkantchin Value makePolynomialCalculation(ImplicitLocOpBuilder &builder,
288f1b92218SBoian Petkantchin                                 llvm::ArrayRef<Value> coeffs, Value x) {
289880b8f4eSPrashant Kumar   Type elementType = getElementTypeOrSelf(x);
29075076f8dSKazu Hirata   assert((elementType.isF32() || elementType.isF16()) &&
29175076f8dSKazu Hirata          "x must be f32 or f16 type");
292*72f36217SKunwar Grover   std::optional<VectorShape> shape = vectorShape(x);
293ec32d540SEugene Zhulenev 
294ec32d540SEugene Zhulenev   if (coeffs.empty())
295880b8f4eSPrashant Kumar     return broadcast(builder, floatCst(builder, 0.0f, elementType), shape);
296ec32d540SEugene Zhulenev 
297ec32d540SEugene Zhulenev   if (coeffs.size() == 1)
298f1b92218SBoian Petkantchin     return coeffs[0];
299ec32d540SEugene Zhulenev 
300f1b92218SBoian Petkantchin   Value res = builder.create<math::FmaOp>(x, coeffs[coeffs.size() - 1],
301f1b92218SBoian Petkantchin                                           coeffs[coeffs.size() - 2]);
302f1b92218SBoian Petkantchin   for (auto i = ptrdiff_t(coeffs.size()) - 3; i >= 0; --i) {
303f1b92218SBoian Petkantchin     res = builder.create<math::FmaOp>(x, res, coeffs[i]);
304f1b92218SBoian Petkantchin   }
305f1b92218SBoian Petkantchin   return res;
306f1b92218SBoian Petkantchin }
307f1b92218SBoian Petkantchin } // namespace
308f1b92218SBoian Petkantchin 
309f99ccf65SEugene Zhulenev //----------------------------------------------------------------------------//
310bbddd19eSJacques Pienaar // Helper function/pattern to insert casts for reusing F32 bit expansion.
311bbddd19eSJacques Pienaar //----------------------------------------------------------------------------//
312bbddd19eSJacques Pienaar 
313bbddd19eSJacques Pienaar template <typename T>
314bbddd19eSJacques Pienaar LogicalResult insertCasts(Operation *op, PatternRewriter &rewriter) {
315bbddd19eSJacques Pienaar   // Conservatively only allow where the operand and result types are exactly 1.
316bbddd19eSJacques Pienaar   Type origType = op->getResultTypes().front();
317bbddd19eSJacques Pienaar   for (Type t : llvm::drop_begin(op->getResultTypes()))
318bbddd19eSJacques Pienaar     if (origType != t)
319bbddd19eSJacques Pienaar       return rewriter.notifyMatchFailure(op, "required all types to match");
320bbddd19eSJacques Pienaar   for (Type t : op->getOperandTypes())
321bbddd19eSJacques Pienaar     if (origType != t)
322bbddd19eSJacques Pienaar       return rewriter.notifyMatchFailure(op, "required all types to match");
323bbddd19eSJacques Pienaar 
324bbddd19eSJacques Pienaar   // Skip if already F32  or larger than 32 bits.
325bbddd19eSJacques Pienaar   if (getElementTypeOrSelf(origType).isF32() ||
326bbddd19eSJacques Pienaar       getElementTypeOrSelf(origType).getIntOrFloatBitWidth() > 32)
327bbddd19eSJacques Pienaar     return failure();
328bbddd19eSJacques Pienaar 
329bbddd19eSJacques Pienaar   // Create F32 equivalent type.
330bbddd19eSJacques Pienaar   Type newType;
3315550c821STres Popp   if (auto shaped = dyn_cast<ShapedType>(origType)) {
332bbddd19eSJacques Pienaar     newType = shaped.clone(rewriter.getF32Type());
3335550c821STres Popp   } else if (isa<FloatType>(origType)) {
334bbddd19eSJacques Pienaar     newType = rewriter.getF32Type();
335bbddd19eSJacques Pienaar   } else {
336bbddd19eSJacques Pienaar     return rewriter.notifyMatchFailure(op,
337bbddd19eSJacques Pienaar                                        "unable to find F32 equivalent type");
338bbddd19eSJacques Pienaar   }
339bbddd19eSJacques Pienaar 
340bbddd19eSJacques Pienaar   Location loc = op->getLoc();
341bbddd19eSJacques Pienaar   SmallVector<Value> operands;
342bbddd19eSJacques Pienaar   for (auto operand : op->getOperands())
343bbddd19eSJacques Pienaar     operands.push_back(rewriter.create<arith::ExtFOp>(loc, newType, operand));
34457e1943eSRobert Suderman   auto result =
34557e1943eSRobert Suderman       rewriter.create<T>(loc, TypeRange{newType}, operands, op->getAttrs());
346bbddd19eSJacques Pienaar   rewriter.replaceOpWithNewOp<arith::TruncFOp>(op, origType, result);
347bbddd19eSJacques Pienaar   return success();
348bbddd19eSJacques Pienaar }
349bbddd19eSJacques Pienaar 
350bbddd19eSJacques Pienaar namespace {
351bbddd19eSJacques Pienaar // Pattern to cast to F32 to reuse F32 expansion as fallback for single-result
352bbddd19eSJacques Pienaar // op.
353bbddd19eSJacques Pienaar // TODO: Consider revising to avoid adding multiple casts for a subgraph that is
354bbddd19eSJacques Pienaar // all in lower precision. Currently this is only fallback support and performs
355bbddd19eSJacques Pienaar // simplistic casting.
356bbddd19eSJacques Pienaar template <typename T>
357bbddd19eSJacques Pienaar struct ReuseF32Expansion : public OpRewritePattern<T> {
358bbddd19eSJacques Pienaar public:
359bbddd19eSJacques Pienaar   using OpRewritePattern<T>::OpRewritePattern;
360bbddd19eSJacques Pienaar   LogicalResult matchAndRewrite(T op, PatternRewriter &rewriter) const final {
361bbddd19eSJacques Pienaar     static_assert(
362bbddd19eSJacques Pienaar         T::template hasTrait<mlir::OpTrait::SameOperandsAndResultType>(),
363bbddd19eSJacques Pienaar         "requires same operands and result types");
364bbddd19eSJacques Pienaar     return insertCasts<T>(op, rewriter);
365bbddd19eSJacques Pienaar   }
366bbddd19eSJacques Pienaar };
367bbddd19eSJacques Pienaar } // namespace
368bbddd19eSJacques Pienaar 
369bbddd19eSJacques Pienaar //----------------------------------------------------------------------------//
3702f9f9afaSRob Suderman // AtanOp approximation.
3712f9f9afaSRob Suderman //----------------------------------------------------------------------------//
3722f9f9afaSRob Suderman 
3732f9f9afaSRob Suderman namespace {
3742f9f9afaSRob Suderman struct AtanApproximation : public OpRewritePattern<math::AtanOp> {
3752f9f9afaSRob Suderman public:
3762f9f9afaSRob Suderman   using OpRewritePattern::OpRewritePattern;
3772f9f9afaSRob Suderman 
3782f9f9afaSRob Suderman   LogicalResult matchAndRewrite(math::AtanOp op,
3792f9f9afaSRob Suderman                                 PatternRewriter &rewriter) const final;
3802f9f9afaSRob Suderman };
3812f9f9afaSRob Suderman } // namespace
3822f9f9afaSRob Suderman 
3832f9f9afaSRob Suderman LogicalResult
3842f9f9afaSRob Suderman AtanApproximation::matchAndRewrite(math::AtanOp op,
3852f9f9afaSRob Suderman                                    PatternRewriter &rewriter) const {
3862f9f9afaSRob Suderman   auto operand = op.getOperand();
3872f9f9afaSRob Suderman   if (!getElementTypeOrSelf(operand).isF32())
3882f9f9afaSRob Suderman     return rewriter.notifyMatchFailure(op, "unsupported operand type");
3892f9f9afaSRob Suderman 
390*72f36217SKunwar Grover   std::optional<VectorShape> shape = vectorShape(op.getOperand());
3912f9f9afaSRob Suderman 
3922f9f9afaSRob Suderman   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
39300f7096dSJeff Niu   Value abs = builder.create<math::AbsFOp>(operand);
3940bedb667SRobert Suderman 
3950bedb667SRobert Suderman   auto one = broadcast(builder, f32Cst(builder, 1.0), shape);
3960bedb667SRobert Suderman 
3970bedb667SRobert Suderman   // When 0.66 < x <= 2.41 we do (x-1) / (x+1):
3980bedb667SRobert Suderman   auto twoThirds = broadcast(builder, f32Cst(builder, 0.66), shape);
3990bedb667SRobert Suderman   Value cmp2 =
4000bedb667SRobert Suderman       builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, abs, twoThirds);
4010bedb667SRobert Suderman   Value addone = builder.create<arith::AddFOp>(abs, one);
4020bedb667SRobert Suderman   Value subone = builder.create<arith::SubFOp>(abs, one);
4030bedb667SRobert Suderman   Value xnum = builder.create<arith::SelectOp>(cmp2, subone, abs);
4040bedb667SRobert Suderman   Value xden = builder.create<arith::SelectOp>(cmp2, addone, one);
4050bedb667SRobert Suderman 
4060bedb667SRobert Suderman   auto bcast = [&](Value value) -> Value {
4070bedb667SRobert Suderman     return broadcast(builder, value, shape);
4080bedb667SRobert Suderman   };
4090bedb667SRobert Suderman 
4100bedb667SRobert Suderman   // Break into the <= 0.66 or > 2.41 we do x or 1/x:
4110bedb667SRobert Suderman   auto tan3pio8 = bcast(f32Cst(builder, 2.41421356237309504880));
4120bedb667SRobert Suderman   Value cmp1 =
4130bedb667SRobert Suderman       builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, abs, tan3pio8);
4140bedb667SRobert Suderman   xnum = builder.create<arith::SelectOp>(cmp1, one, xnum);
4150bedb667SRobert Suderman   xden = builder.create<arith::SelectOp>(cmp1, abs, xden);
4160bedb667SRobert Suderman 
4170bedb667SRobert Suderman   Value x = builder.create<arith::DivFOp>(xnum, xden);
4180bedb667SRobert Suderman   Value xx = builder.create<arith::MulFOp>(x, x);
4192f9f9afaSRob Suderman 
4202f9f9afaSRob Suderman   // Perform the Taylor series approximation for atan over the range
4210bedb667SRobert Suderman   // [0.0, 0.66].
4220bedb667SRobert Suderman   auto p0 = bcast(f32Cst(builder, -8.750608600031904122785e-01));
4230bedb667SRobert Suderman   auto p1 = bcast(f32Cst(builder, -1.615753718733365076637e+01));
4240bedb667SRobert Suderman   auto p2 = bcast(f32Cst(builder, -7.500855792314704667340e+01));
4250bedb667SRobert Suderman   auto p3 = bcast(f32Cst(builder, -1.228866684490136173410e+02));
4260bedb667SRobert Suderman   auto p4 = bcast(f32Cst(builder, -6.485021904942025371773e+01));
4270bedb667SRobert Suderman   auto q0 = bcast(f32Cst(builder, +2.485846490142306297962e+01));
4280bedb667SRobert Suderman   auto q1 = bcast(f32Cst(builder, +1.650270098316988542046e+02));
4290bedb667SRobert Suderman   auto q2 = bcast(f32Cst(builder, +4.328810604912902668951e+02));
4300bedb667SRobert Suderman   auto q3 = bcast(f32Cst(builder, +4.853903996359136964868e+02));
4310bedb667SRobert Suderman   auto q4 = bcast(f32Cst(builder, +1.945506571482613964425e+02));
4322f9f9afaSRob Suderman 
4330bedb667SRobert Suderman   // Apply the polynomial approximation for the numerator:
4340bedb667SRobert Suderman   Value n = p0;
4350bedb667SRobert Suderman   n = builder.create<math::FmaOp>(xx, n, p1);
4360bedb667SRobert Suderman   n = builder.create<math::FmaOp>(xx, n, p2);
4370bedb667SRobert Suderman   n = builder.create<math::FmaOp>(xx, n, p3);
4380bedb667SRobert Suderman   n = builder.create<math::FmaOp>(xx, n, p4);
4390bedb667SRobert Suderman   n = builder.create<arith::MulFOp>(n, xx);
4402f9f9afaSRob Suderman 
4410bedb667SRobert Suderman   // Apply the polynomial approximation for the denominator:
4420bedb667SRobert Suderman   Value d = q0;
4430bedb667SRobert Suderman   d = builder.create<math::FmaOp>(xx, d, q1);
4440bedb667SRobert Suderman   d = builder.create<math::FmaOp>(xx, d, q2);
4450bedb667SRobert Suderman   d = builder.create<math::FmaOp>(xx, d, q3);
4460bedb667SRobert Suderman   d = builder.create<math::FmaOp>(xx, d, q4);
4470bedb667SRobert Suderman 
4480bedb667SRobert Suderman   // Compute approximation of theta:
4490bedb667SRobert Suderman   Value ans0 = builder.create<arith::DivFOp>(n, d);
4500bedb667SRobert Suderman   ans0 = builder.create<math::FmaOp>(ans0, x, x);
4510bedb667SRobert Suderman 
4520bedb667SRobert Suderman   // Correct for the input mapping's angles:
45375bd41f6SJacques Pienaar   Value mpi4 = bcast(f32Cst(builder, llvm::numbers::pi / 4));
4540bedb667SRobert Suderman   Value ans2 = builder.create<arith::AddFOp>(mpi4, ans0);
4550bedb667SRobert Suderman   Value ans = builder.create<arith::SelectOp>(cmp2, ans2, ans0);
4560bedb667SRobert Suderman 
45775bd41f6SJacques Pienaar   Value mpi2 = bcast(f32Cst(builder, llvm::numbers::pi / 2));
4580bedb667SRobert Suderman   Value ans1 = builder.create<arith::SubFOp>(mpi2, ans0);
4590bedb667SRobert Suderman   ans = builder.create<arith::SelectOp>(cmp1, ans1, ans);
4602f9f9afaSRob Suderman 
4612f9f9afaSRob Suderman   // Correct for signing of the input.
4620bedb667SRobert Suderman   rewriter.replaceOpWithNewOp<math::CopySignOp>(op, ans, operand);
4632f9f9afaSRob Suderman   return success();
4642f9f9afaSRob Suderman }
4652f9f9afaSRob Suderman 
4662f9f9afaSRob Suderman //----------------------------------------------------------------------------//
// Atan2Op approximation.
4682f9f9afaSRob Suderman //----------------------------------------------------------------------------//
4692f9f9afaSRob Suderman 
4702f9f9afaSRob Suderman namespace {
4712f9f9afaSRob Suderman struct Atan2Approximation : public OpRewritePattern<math::Atan2Op> {
4722f9f9afaSRob Suderman public:
4732f9f9afaSRob Suderman   using OpRewritePattern::OpRewritePattern;
4742f9f9afaSRob Suderman 
4752f9f9afaSRob Suderman   LogicalResult matchAndRewrite(math::Atan2Op op,
4762f9f9afaSRob Suderman                                 PatternRewriter &rewriter) const final;
4772f9f9afaSRob Suderman };
4782f9f9afaSRob Suderman } // namespace
4792f9f9afaSRob Suderman 
4802f9f9afaSRob Suderman LogicalResult
4812f9f9afaSRob Suderman Atan2Approximation::matchAndRewrite(math::Atan2Op op,
4822f9f9afaSRob Suderman                                     PatternRewriter &rewriter) const {
4832f9f9afaSRob Suderman   auto y = op.getOperand(0);
4842f9f9afaSRob Suderman   auto x = op.getOperand(1);
4852f9f9afaSRob Suderman   if (!getElementTypeOrSelf(x).isF32())
4862f9f9afaSRob Suderman     return rewriter.notifyMatchFailure(op, "unsupported operand type");
4872f9f9afaSRob Suderman 
4882f9f9afaSRob Suderman   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
489*72f36217SKunwar Grover   std::optional<VectorShape> shape = vectorShape(op.getResult());
4902f9f9afaSRob Suderman 
4912f9f9afaSRob Suderman   // Compute atan in the valid range.
4922f9f9afaSRob Suderman   auto div = builder.create<arith::DivFOp>(y, x);
4932f9f9afaSRob Suderman   auto atan = builder.create<math::AtanOp>(div);
4942f9f9afaSRob Suderman 
4952f9f9afaSRob Suderman   // Determine what the atan would be for a 180 degree rotation.
4962f9f9afaSRob Suderman   auto zero = broadcast(builder, f32Cst(builder, 0.0f), shape);
4972f9f9afaSRob Suderman   auto pi = broadcast(builder, f32Cst(builder, 3.14159265359f), shape);
49870ed93ecSMehdi Amini   auto addPi = builder.create<arith::AddFOp>(atan, pi);
49970ed93ecSMehdi Amini   auto subPi = builder.create<arith::SubFOp>(atan, pi);
50070ed93ecSMehdi Amini   auto atanGt =
5012f9f9afaSRob Suderman       builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, atan, zero);
502dec8af70SRiver Riddle   auto flippedAtan = builder.create<arith::SelectOp>(atanGt, subPi, addPi);
5032f9f9afaSRob Suderman 
5042f9f9afaSRob Suderman   // Determine whether to directly use atan or use the 180 degree flip
50570ed93ecSMehdi Amini   auto xGt = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, x, zero);
506dec8af70SRiver Riddle   Value result = builder.create<arith::SelectOp>(xGt, atan, flippedAtan);
5072f9f9afaSRob Suderman 
5082f9f9afaSRob Suderman   // Handle x = 0, y > 0
50970ed93ecSMehdi Amini   Value xZero =
5102f9f9afaSRob Suderman       builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, x, zero);
51170ed93ecSMehdi Amini   Value yGt = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, y, zero);
51270ed93ecSMehdi Amini   Value isHalfPi = builder.create<arith::AndIOp>(xZero, yGt);
51370ed93ecSMehdi Amini   auto halfPi = broadcast(builder, f32Cst(builder, 1.57079632679f), shape);
514dec8af70SRiver Riddle   result = builder.create<arith::SelectOp>(isHalfPi, halfPi, result);
5152f9f9afaSRob Suderman 
5162f9f9afaSRob Suderman   // Handle x = 0, y < 0
51770ed93ecSMehdi Amini   Value yLt = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, y, zero);
51870ed93ecSMehdi Amini   Value isNegativeHalfPiPi = builder.create<arith::AndIOp>(xZero, yLt);
51970ed93ecSMehdi Amini   auto negativeHalfPiPi =
520dc3b9365SAlexandre Ganea       broadcast(builder, f32Cst(builder, -1.57079632679f), shape);
521dec8af70SRiver Riddle   result = builder.create<arith::SelectOp>(isNegativeHalfPiPi, negativeHalfPiPi,
522dec8af70SRiver Riddle                                            result);
5232f9f9afaSRob Suderman 
5242f9f9afaSRob Suderman   // Handle x = 0, y = 0;
52570ed93ecSMehdi Amini   Value yZero =
5262f9f9afaSRob Suderman       builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, y, zero);
52770ed93ecSMehdi Amini   Value isNan = builder.create<arith::AndIOp>(xZero, yZero);
52870ed93ecSMehdi Amini   Value cstNan = broadcast(builder, f32FromBits(builder, 0x7fc00000), shape);
529dec8af70SRiver Riddle   result = builder.create<arith::SelectOp>(isNan, cstNan, result);
5302f9f9afaSRob Suderman 
5312f9f9afaSRob Suderman   rewriter.replaceOp(op, result);
5322f9f9afaSRob Suderman   return success();
5332f9f9afaSRob Suderman }
5342f9f9afaSRob Suderman 
5352f9f9afaSRob Suderman //----------------------------------------------------------------------------//
536f99ccf65SEugene Zhulenev // TanhOp approximation.
537f99ccf65SEugene Zhulenev //----------------------------------------------------------------------------//
538f99ccf65SEugene Zhulenev 
539f99ccf65SEugene Zhulenev namespace {
540f99ccf65SEugene Zhulenev struct TanhApproximation : public OpRewritePattern<math::TanhOp> {
541f99ccf65SEugene Zhulenev public:
542f99ccf65SEugene Zhulenev   using OpRewritePattern::OpRewritePattern;
543f99ccf65SEugene Zhulenev 
544f99ccf65SEugene Zhulenev   LogicalResult matchAndRewrite(math::TanhOp op,
545f99ccf65SEugene Zhulenev                                 PatternRewriter &rewriter) const final;
546f99ccf65SEugene Zhulenev };
547f99ccf65SEugene Zhulenev } // namespace
548f99ccf65SEugene Zhulenev 
549f99ccf65SEugene Zhulenev LogicalResult
550f99ccf65SEugene Zhulenev TanhApproximation::matchAndRewrite(math::TanhOp op,
551f99ccf65SEugene Zhulenev                                    PatternRewriter &rewriter) const {
55262fea88bSJacques Pienaar   if (!getElementTypeOrSelf(op.getOperand()).isF32())
553f99ccf65SEugene Zhulenev     return rewriter.notifyMatchFailure(op, "unsupported operand type");
554f99ccf65SEugene Zhulenev 
555*72f36217SKunwar Grover   std::optional<VectorShape> shape = vectorShape(op.getOperand());
556ec32d540SEugene Zhulenev 
557ce976d2dSEugene Zhulenev   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
558ce976d2dSEugene Zhulenev   auto bcast = [&](Value value) -> Value {
559ec32d540SEugene Zhulenev     return broadcast(builder, value, shape);
560ce976d2dSEugene Zhulenev   };
561f99ccf65SEugene Zhulenev 
562f99ccf65SEugene Zhulenev   // Clamp operand into [plusClamp, minusClamp] range.
563bf32bb7eSEugene Zhulenev   Value minusClamp = bcast(f32Cst(builder, -7.99881172180175781f));
564bf32bb7eSEugene Zhulenev   Value plusClamp = bcast(f32Cst(builder, 7.99881172180175781f));
56562fea88bSJacques Pienaar   Value x = clamp(builder, op.getOperand(), minusClamp, plusClamp);
566f99ccf65SEugene Zhulenev 
567f99ccf65SEugene Zhulenev   // Mask for tiny values that are approximated with `operand`.
568ce976d2dSEugene Zhulenev   Value tiny = bcast(f32Cst(builder, 0.0004f));
569a54f4eaeSMogball   Value tinyMask = builder.create<arith::CmpFOp>(
57000f7096dSJeff Niu       arith::CmpFPredicate::OLT, builder.create<math::AbsFOp>(op.getOperand()),
571a54f4eaeSMogball       tiny);
572f99ccf65SEugene Zhulenev 
573f99ccf65SEugene Zhulenev   // The monomial coefficients of the numerator polynomial (odd).
574ce976d2dSEugene Zhulenev   Value alpha1 = bcast(f32Cst(builder, 4.89352455891786e-03f));
575ce976d2dSEugene Zhulenev   Value alpha3 = bcast(f32Cst(builder, 6.37261928875436e-04f));
576ce976d2dSEugene Zhulenev   Value alpha5 = bcast(f32Cst(builder, 1.48572235717979e-05f));
577ce976d2dSEugene Zhulenev   Value alpha7 = bcast(f32Cst(builder, 5.12229709037114e-08f));
578ce976d2dSEugene Zhulenev   Value alpha9 = bcast(f32Cst(builder, -8.60467152213735e-11f));
579ce976d2dSEugene Zhulenev   Value alpha11 = bcast(f32Cst(builder, 2.00018790482477e-13f));
580ce976d2dSEugene Zhulenev   Value alpha13 = bcast(f32Cst(builder, -2.76076847742355e-16f));
581f99ccf65SEugene Zhulenev 
582f99ccf65SEugene Zhulenev   // The monomial coefficients of the denominator polynomial (even).
583ce976d2dSEugene Zhulenev   Value beta0 = bcast(f32Cst(builder, 4.89352518554385e-03f));
584ce976d2dSEugene Zhulenev   Value beta2 = bcast(f32Cst(builder, 2.26843463243900e-03f));
585ce976d2dSEugene Zhulenev   Value beta4 = bcast(f32Cst(builder, 1.18534705686654e-04f));
586ce976d2dSEugene Zhulenev   Value beta6 = bcast(f32Cst(builder, 1.19825839466702e-06f));
587f99ccf65SEugene Zhulenev 
588f99ccf65SEugene Zhulenev   // Since the polynomials are odd/even, we need x^2.
589a54f4eaeSMogball   Value x2 = builder.create<arith::MulFOp>(x, x);
590f99ccf65SEugene Zhulenev 
591f99ccf65SEugene Zhulenev   // Evaluate the numerator polynomial p.
592a54f4eaeSMogball   Value p = builder.create<math::FmaOp>(x2, alpha13, alpha11);
593a54f4eaeSMogball   p = builder.create<math::FmaOp>(x2, p, alpha9);
594a54f4eaeSMogball   p = builder.create<math::FmaOp>(x2, p, alpha7);
595a54f4eaeSMogball   p = builder.create<math::FmaOp>(x2, p, alpha5);
596a54f4eaeSMogball   p = builder.create<math::FmaOp>(x2, p, alpha3);
597a54f4eaeSMogball   p = builder.create<math::FmaOp>(x2, p, alpha1);
598a54f4eaeSMogball   p = builder.create<arith::MulFOp>(x, p);
599f99ccf65SEugene Zhulenev 
600f99ccf65SEugene Zhulenev   // Evaluate the denominator polynomial q.
601a54f4eaeSMogball   Value q = builder.create<math::FmaOp>(x2, beta6, beta4);
602a54f4eaeSMogball   q = builder.create<math::FmaOp>(x2, q, beta2);
603a54f4eaeSMogball   q = builder.create<math::FmaOp>(x2, q, beta0);
604f99ccf65SEugene Zhulenev 
605f99ccf65SEugene Zhulenev   // Divide the numerator by the denominator.
606dec8af70SRiver Riddle   Value res = builder.create<arith::SelectOp>(
607dec8af70SRiver Riddle       tinyMask, x, builder.create<arith::DivFOp>(p, q));
608f99ccf65SEugene Zhulenev 
609f99ccf65SEugene Zhulenev   rewriter.replaceOp(op, res);
610f99ccf65SEugene Zhulenev 
611f99ccf65SEugene Zhulenev   return success();
612f99ccf65SEugene Zhulenev }
613f99ccf65SEugene Zhulenev 
// ln(2), spelled as a long double literal. Uses visible in this file narrow
// it with static_cast<float>(...) before materializing an f32 constant.
#define LN2_VALUE                                                              \
  0.693147180559945309417232121458176568075500134360255254120680009493393621L
// log2(e) == 1 / ln(2), spelled as a long double literal; narrowed to float
// at its visible use sites.
#define LOG2E_VALUE                                                            \
  1.442695040888963407359924681001892137426645954152985934135449406931109219L
618ea7f211bSAhmed Taei 
619f99ccf65SEugene Zhulenev //----------------------------------------------------------------------------//
620c0891706SEmilio Cota // LogOp and Log2Op approximation.
621ce976d2dSEugene Zhulenev //----------------------------------------------------------------------------//
622ce976d2dSEugene Zhulenev 
623ce976d2dSEugene Zhulenev namespace {
624c0891706SEmilio Cota template <typename Op>
625c0891706SEmilio Cota struct LogApproximationBase : public OpRewritePattern<Op> {
626c0891706SEmilio Cota   using OpRewritePattern<Op>::OpRewritePattern;
627ce976d2dSEugene Zhulenev 
628c0891706SEmilio Cota   /// Base 2 if 'base2' is set; natural logarithm (base e) otherwise.
629c0891706SEmilio Cota   LogicalResult logMatchAndRewrite(Op op, PatternRewriter &rewriter,
630c0891706SEmilio Cota                                    bool base2) const;
631ce976d2dSEugene Zhulenev };
632ce976d2dSEugene Zhulenev } // namespace
633ce976d2dSEugene Zhulenev 
634c0891706SEmilio Cota // This approximation comes from Julien Pommier's SSE math library.
635c0891706SEmilio Cota // Link: http://gruntthepeon.free.fr/ssemath
636c0891706SEmilio Cota template <typename Op>
637ce976d2dSEugene Zhulenev LogicalResult
638c0891706SEmilio Cota LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter,
639c0891706SEmilio Cota                                              bool base2) const {
64062fea88bSJacques Pienaar   if (!getElementTypeOrSelf(op.getOperand()).isF32())
641ce976d2dSEugene Zhulenev     return rewriter.notifyMatchFailure(op, "unsupported operand type");
642ce976d2dSEugene Zhulenev 
643*72f36217SKunwar Grover   std::optional<VectorShape> shape = vectorShape(op.getOperand());
644ec32d540SEugene Zhulenev 
645ce976d2dSEugene Zhulenev   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
646ce976d2dSEugene Zhulenev   auto bcast = [&](Value value) -> Value {
647ec32d540SEugene Zhulenev     return broadcast(builder, value, shape);
648ce976d2dSEugene Zhulenev   };
649ce976d2dSEugene Zhulenev 
650ce976d2dSEugene Zhulenev   Value cstZero = bcast(f32Cst(builder, 0.0f));
651ce976d2dSEugene Zhulenev   Value cstOne = bcast(f32Cst(builder, 1.0f));
652ce976d2dSEugene Zhulenev   Value cstNegHalf = bcast(f32Cst(builder, -0.5f));
653ce976d2dSEugene Zhulenev 
654ce976d2dSEugene Zhulenev   // The smallest non denormalized float number.
655ce976d2dSEugene Zhulenev   Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u));
656ce976d2dSEugene Zhulenev   Value cstMinusInf = bcast(f32FromBits(builder, 0xff800000u));
657ce976d2dSEugene Zhulenev   Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u));
658ce976d2dSEugene Zhulenev   Value cstNan = bcast(f32FromBits(builder, 0x7fc00000));
659ce976d2dSEugene Zhulenev 
660ce976d2dSEugene Zhulenev   // Polynomial coefficients.
661ce976d2dSEugene Zhulenev   Value cstCephesSQRTHF = bcast(f32Cst(builder, 0.707106781186547524f));
662ce976d2dSEugene Zhulenev   Value cstCephesLogP0 = bcast(f32Cst(builder, 7.0376836292E-2f));
663ce976d2dSEugene Zhulenev   Value cstCephesLogP1 = bcast(f32Cst(builder, -1.1514610310E-1f));
664ce976d2dSEugene Zhulenev   Value cstCephesLogP2 = bcast(f32Cst(builder, 1.1676998740E-1f));
665ce976d2dSEugene Zhulenev   Value cstCephesLogP3 = bcast(f32Cst(builder, -1.2420140846E-1f));
666ce976d2dSEugene Zhulenev   Value cstCephesLogP4 = bcast(f32Cst(builder, +1.4249322787E-1f));
667ce976d2dSEugene Zhulenev   Value cstCephesLogP5 = bcast(f32Cst(builder, -1.6668057665E-1f));
668ce976d2dSEugene Zhulenev   Value cstCephesLogP6 = bcast(f32Cst(builder, +2.0000714765E-1f));
669ce976d2dSEugene Zhulenev   Value cstCephesLogP7 = bcast(f32Cst(builder, -2.4999993993E-1f));
670ce976d2dSEugene Zhulenev   Value cstCephesLogP8 = bcast(f32Cst(builder, +3.3333331174E-1f));
671ce976d2dSEugene Zhulenev 
67262fea88bSJacques Pienaar   Value x = op.getOperand();
673ce976d2dSEugene Zhulenev 
674ce976d2dSEugene Zhulenev   // Truncate input values to the minimum positive normal.
675ce976d2dSEugene Zhulenev   x = max(builder, x, cstMinNormPos);
676ce976d2dSEugene Zhulenev 
677ce976d2dSEugene Zhulenev   // Extract significant in the range [0.5,1) and exponent.
678ced8690dSMehdi Amini   std::pair<Value, Value> pair = frexp(builder, x, /*isPositive=*/true);
679ce976d2dSEugene Zhulenev   x = pair.first;
680ce976d2dSEugene Zhulenev   Value e = pair.second;
681ce976d2dSEugene Zhulenev 
682ce976d2dSEugene Zhulenev   // Shift the inputs from the range [0.5,1) to [sqrt(1/2), sqrt(2)) and shift
683ce976d2dSEugene Zhulenev   // by -1.0. The values are then centered around 0, which improves the
684ce976d2dSEugene Zhulenev   // stability of the polynomial evaluation:
685ce976d2dSEugene Zhulenev   //
686ce976d2dSEugene Zhulenev   //   if( x < SQRTHF ) {
687ce976d2dSEugene Zhulenev   //     e -= 1;
688ce976d2dSEugene Zhulenev   //     x = x + x - 1.0;
689ce976d2dSEugene Zhulenev   //   } else { x = x - 1.0; }
690a54f4eaeSMogball   Value mask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x,
691a54f4eaeSMogball                                              cstCephesSQRTHF);
692dec8af70SRiver Riddle   Value tmp = builder.create<arith::SelectOp>(mask, x, cstZero);
693ce976d2dSEugene Zhulenev 
694a54f4eaeSMogball   x = builder.create<arith::SubFOp>(x, cstOne);
695a54f4eaeSMogball   e = builder.create<arith::SubFOp>(
696dec8af70SRiver Riddle       e, builder.create<arith::SelectOp>(mask, cstOne, cstZero));
697a54f4eaeSMogball   x = builder.create<arith::AddFOp>(x, tmp);
698ce976d2dSEugene Zhulenev 
699a54f4eaeSMogball   Value x2 = builder.create<arith::MulFOp>(x, x);
700a54f4eaeSMogball   Value x3 = builder.create<arith::MulFOp>(x2, x);
701ce976d2dSEugene Zhulenev 
702ce976d2dSEugene Zhulenev   // Evaluate the polynomial approximant of degree 8 in three parts.
703ce976d2dSEugene Zhulenev   Value y0, y1, y2;
704a54f4eaeSMogball   y0 = builder.create<math::FmaOp>(cstCephesLogP0, x, cstCephesLogP1);
705a54f4eaeSMogball   y1 = builder.create<math::FmaOp>(cstCephesLogP3, x, cstCephesLogP4);
706a54f4eaeSMogball   y2 = builder.create<math::FmaOp>(cstCephesLogP6, x, cstCephesLogP7);
707a54f4eaeSMogball   y0 = builder.create<math::FmaOp>(y0, x, cstCephesLogP2);
708a54f4eaeSMogball   y1 = builder.create<math::FmaOp>(y1, x, cstCephesLogP5);
709a54f4eaeSMogball   y2 = builder.create<math::FmaOp>(y2, x, cstCephesLogP8);
710a54f4eaeSMogball   y0 = builder.create<math::FmaOp>(y0, x3, y1);
711a54f4eaeSMogball   y0 = builder.create<math::FmaOp>(y0, x3, y2);
712a54f4eaeSMogball   y0 = builder.create<arith::MulFOp>(y0, x3);
713ce976d2dSEugene Zhulenev 
714a54f4eaeSMogball   y0 = builder.create<math::FmaOp>(cstNegHalf, x2, y0);
715a54f4eaeSMogball   x = builder.create<arith::AddFOp>(x, y0);
716ce976d2dSEugene Zhulenev 
717c0891706SEmilio Cota   if (base2) {
718c0891706SEmilio Cota     Value cstLog2e = bcast(f32Cst(builder, static_cast<float>(LOG2E_VALUE)));
719a54f4eaeSMogball     x = builder.create<math::FmaOp>(x, cstLog2e, e);
720c0891706SEmilio Cota   } else {
721ce976d2dSEugene Zhulenev     Value cstLn2 = bcast(f32Cst(builder, static_cast<float>(LN2_VALUE)));
722a54f4eaeSMogball     x = builder.create<math::FmaOp>(e, cstLn2, x);
723c0891706SEmilio Cota   }
724ce976d2dSEugene Zhulenev 
725a54f4eaeSMogball   Value invalidMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::ULT,
72662fea88bSJacques Pienaar                                                     op.getOperand(), cstZero);
727a54f4eaeSMogball   Value zeroMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
72862fea88bSJacques Pienaar                                                  op.getOperand(), cstZero);
729a54f4eaeSMogball   Value posInfMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
73062fea88bSJacques Pienaar                                                    op.getOperand(), cstPosInf);
731ce976d2dSEugene Zhulenev 
732ce976d2dSEugene Zhulenev   // Filter out invalid values:
733ce976d2dSEugene Zhulenev   //  • x == 0     -> -INF
734ce976d2dSEugene Zhulenev   //  • x < 0      ->  NAN
735ce976d2dSEugene Zhulenev   //  • x == +INF  -> +INF
736dec8af70SRiver Riddle   Value aproximation = builder.create<arith::SelectOp>(
737ce976d2dSEugene Zhulenev       zeroMask, cstMinusInf,
738dec8af70SRiver Riddle       builder.create<arith::SelectOp>(
739ce976d2dSEugene Zhulenev           invalidMask, cstNan,
740dec8af70SRiver Riddle           builder.create<arith::SelectOp>(posInfMask, cstPosInf, x)));
741ce976d2dSEugene Zhulenev 
742ce976d2dSEugene Zhulenev   rewriter.replaceOp(op, aproximation);
743ce976d2dSEugene Zhulenev 
744ce976d2dSEugene Zhulenev   return success();
745ce976d2dSEugene Zhulenev }
746ce976d2dSEugene Zhulenev 
747c0891706SEmilio Cota namespace {
748c0891706SEmilio Cota struct LogApproximation : public LogApproximationBase<math::LogOp> {
749c0891706SEmilio Cota   using LogApproximationBase::LogApproximationBase;
750c0891706SEmilio Cota 
751c0891706SEmilio Cota   LogicalResult matchAndRewrite(math::LogOp op,
752c0891706SEmilio Cota                                 PatternRewriter &rewriter) const final {
753c0891706SEmilio Cota     return logMatchAndRewrite(op, rewriter, /*base2=*/false);
754c0891706SEmilio Cota   }
755c0891706SEmilio Cota };
756c0891706SEmilio Cota } // namespace
757c0891706SEmilio Cota 
758c0891706SEmilio Cota namespace {
759c0891706SEmilio Cota struct Log2Approximation : public LogApproximationBase<math::Log2Op> {
760c0891706SEmilio Cota   using LogApproximationBase::LogApproximationBase;
761c0891706SEmilio Cota 
762c0891706SEmilio Cota   LogicalResult matchAndRewrite(math::Log2Op op,
763c0891706SEmilio Cota                                 PatternRewriter &rewriter) const final {
764c0891706SEmilio Cota     return logMatchAndRewrite(op, rewriter, /*base2=*/true);
765c0891706SEmilio Cota   }
766c0891706SEmilio Cota };
767c0891706SEmilio Cota } // namespace
768c0891706SEmilio Cota 
769ce976d2dSEugene Zhulenev //----------------------------------------------------------------------------//
7701c0374e7SEmilio Cota // Log1p approximation.
7711c0374e7SEmilio Cota //----------------------------------------------------------------------------//
7721c0374e7SEmilio Cota 
7731c0374e7SEmilio Cota namespace {
7741c0374e7SEmilio Cota struct Log1pApproximation : public OpRewritePattern<math::Log1pOp> {
7751c0374e7SEmilio Cota public:
7761c0374e7SEmilio Cota   using OpRewritePattern::OpRewritePattern;
7771c0374e7SEmilio Cota 
7781c0374e7SEmilio Cota   LogicalResult matchAndRewrite(math::Log1pOp op,
7791c0374e7SEmilio Cota                                 PatternRewriter &rewriter) const final;
7801c0374e7SEmilio Cota };
7811c0374e7SEmilio Cota } // namespace
7821c0374e7SEmilio Cota 
7831c0374e7SEmilio Cota // Approximate log(1+x).
7841c0374e7SEmilio Cota LogicalResult
7851c0374e7SEmilio Cota Log1pApproximation::matchAndRewrite(math::Log1pOp op,
7861c0374e7SEmilio Cota                                     PatternRewriter &rewriter) const {
78762fea88bSJacques Pienaar   if (!getElementTypeOrSelf(op.getOperand()).isF32())
7881c0374e7SEmilio Cota     return rewriter.notifyMatchFailure(op, "unsupported operand type");
7891c0374e7SEmilio Cota 
790*72f36217SKunwar Grover   std::optional<VectorShape> shape = vectorShape(op.getOperand());
791ec32d540SEugene Zhulenev 
7921c0374e7SEmilio Cota   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
7931c0374e7SEmilio Cota   auto bcast = [&](Value value) -> Value {
794ec32d540SEugene Zhulenev     return broadcast(builder, value, shape);
7951c0374e7SEmilio Cota   };
7961c0374e7SEmilio Cota 
7971c0374e7SEmilio Cota   // Approximate log(1+x) using the following, due to W. Kahan:
7981c0374e7SEmilio Cota   //   u = x + 1.0;
7991c0374e7SEmilio Cota   //   if (u == 1.0 || u == inf) return x;
8001c0374e7SEmilio Cota   //   return x * log(u) / (u - 1.0);
8011c0374e7SEmilio Cota   //          ^^^^^^^^^^^^^^^^^^^^^^
8021c0374e7SEmilio Cota   //             "logLarge" below.
8031c0374e7SEmilio Cota   Value cstOne = bcast(f32Cst(builder, 1.0f));
80462fea88bSJacques Pienaar   Value x = op.getOperand();
805a54f4eaeSMogball   Value u = builder.create<arith::AddFOp>(x, cstOne);
806a54f4eaeSMogball   Value uSmall =
807a54f4eaeSMogball       builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, cstOne);
8081c0374e7SEmilio Cota   Value logU = builder.create<math::LogOp>(u);
809a54f4eaeSMogball   Value uInf =
810a54f4eaeSMogball       builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, logU);
811a54f4eaeSMogball   Value logLarge = builder.create<arith::MulFOp>(
812a54f4eaeSMogball       x, builder.create<arith::DivFOp>(
813a54f4eaeSMogball              logU, builder.create<arith::SubFOp>(u, cstOne)));
814dec8af70SRiver Riddle   Value approximation = builder.create<arith::SelectOp>(
815a54f4eaeSMogball       builder.create<arith::OrIOp>(uSmall, uInf), x, logLarge);
8161c0374e7SEmilio Cota   rewriter.replaceOp(op, approximation);
8171c0374e7SEmilio Cota   return success();
8181c0374e7SEmilio Cota }
8191c0374e7SEmilio Cota 
8201c0374e7SEmilio Cota //----------------------------------------------------------------------------//
82172085698SPrashant Kumar // Asin approximation.
82272085698SPrashant Kumar //----------------------------------------------------------------------------//
82372085698SPrashant Kumar 
82472085698SPrashant Kumar // Approximates asin(x).
82572085698SPrashant Kumar // This approximation is based on the following stackoverflow post:
82672085698SPrashant Kumar // https://stackoverflow.com/a/42683455
82772085698SPrashant Kumar namespace {
82872085698SPrashant Kumar struct AsinPolynomialApproximation : public OpRewritePattern<math::AsinOp> {
82972085698SPrashant Kumar public:
83072085698SPrashant Kumar   using OpRewritePattern::OpRewritePattern;
83172085698SPrashant Kumar 
83272085698SPrashant Kumar   LogicalResult matchAndRewrite(math::AsinOp op,
83372085698SPrashant Kumar                                 PatternRewriter &rewriter) const final;
83472085698SPrashant Kumar };
83572085698SPrashant Kumar } // namespace
83672085698SPrashant Kumar LogicalResult
83772085698SPrashant Kumar AsinPolynomialApproximation::matchAndRewrite(math::AsinOp op,
83872085698SPrashant Kumar                                              PatternRewriter &rewriter) const {
83972085698SPrashant Kumar   Value operand = op.getOperand();
84072085698SPrashant Kumar   Type elementType = getElementTypeOrSelf(operand);
84172085698SPrashant Kumar 
84272085698SPrashant Kumar   if (!(elementType.isF32() || elementType.isF16()))
84372085698SPrashant Kumar     return rewriter.notifyMatchFailure(op,
84472085698SPrashant Kumar                                        "only f32 and f16 type is supported.");
845*72f36217SKunwar Grover   std::optional<VectorShape> shape = vectorShape(operand);
84672085698SPrashant Kumar 
84772085698SPrashant Kumar   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
84872085698SPrashant Kumar   auto bcast = [&](Value value) -> Value {
84972085698SPrashant Kumar     return broadcast(builder, value, shape);
85072085698SPrashant Kumar   };
85172085698SPrashant Kumar 
85272085698SPrashant Kumar   auto fma = [&](Value a, Value b, Value c) -> Value {
85372085698SPrashant Kumar     return builder.create<math::FmaOp>(a, b, c);
85472085698SPrashant Kumar   };
85572085698SPrashant Kumar 
85672085698SPrashant Kumar   auto mul = [&](Value a, Value b) -> Value {
85772085698SPrashant Kumar     return builder.create<arith::MulFOp>(a, b);
85872085698SPrashant Kumar   };
85972085698SPrashant Kumar 
860a3fb3014SRob Suderman   auto sub = [&](Value a, Value b) -> Value {
861a3fb3014SRob Suderman     return builder.create<arith::SubFOp>(a, b);
862a3fb3014SRob Suderman   };
863a3fb3014SRob Suderman 
864a3fb3014SRob Suderman   auto abs = [&](Value a) -> Value { return builder.create<math::AbsFOp>(a); };
865a3fb3014SRob Suderman 
866a3fb3014SRob Suderman   auto sqrt = [&](Value a) -> Value { return builder.create<math::SqrtOp>(a); };
867a3fb3014SRob Suderman 
868a3fb3014SRob Suderman   auto scopy = [&](Value a, Value b) -> Value {
869a3fb3014SRob Suderman     return builder.create<math::CopySignOp>(a, b);
870a3fb3014SRob Suderman   };
871a3fb3014SRob Suderman 
872a3fb3014SRob Suderman   auto sel = [&](Value a, Value b, Value c) -> Value {
873a3fb3014SRob Suderman     return builder.create<arith::SelectOp>(a, b, c);
874a3fb3014SRob Suderman   };
875a3fb3014SRob Suderman 
876a3fb3014SRob Suderman   Value abso = abs(operand);
877a3fb3014SRob Suderman   Value aa = mul(operand, operand);
878a3fb3014SRob Suderman   Value opp = sqrt(sub(bcast(floatCst(builder, 1.0, elementType)), aa));
879a3fb3014SRob Suderman 
880a3fb3014SRob Suderman   Value gt =
881a3fb3014SRob Suderman       builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, aa,
882a3fb3014SRob Suderman                                     bcast(floatCst(builder, 0.5, elementType)));
883a3fb3014SRob Suderman 
884a3fb3014SRob Suderman   Value x = sel(gt, opp, abso);
885a3fb3014SRob Suderman 
886a3fb3014SRob Suderman   // Asin(x) approximation for x = [-9/16, 9/16]:
887a3fb3014SRob Suderman   Value s = mul(x, x);
88872085698SPrashant Kumar   Value q = mul(s, s);
88972085698SPrashant Kumar   Value r = bcast(floatCst(builder, 5.5579749017470502e-2, elementType));
89072085698SPrashant Kumar   Value t = bcast(floatCst(builder, -6.2027913464120114e-2, elementType));
89172085698SPrashant Kumar 
89272085698SPrashant Kumar   r = fma(r, q, bcast(floatCst(builder, 5.4224464349245036e-2, elementType)));
89372085698SPrashant Kumar   t = fma(t, q, bcast(floatCst(builder, -1.1326992890324464e-2, elementType)));
89472085698SPrashant Kumar   r = fma(r, q, bcast(floatCst(builder, 1.5268872539397656e-2, elementType)));
89572085698SPrashant Kumar   t = fma(t, q, bcast(floatCst(builder, 1.0493798473372081e-2, elementType)));
89672085698SPrashant Kumar   r = fma(r, q, bcast(floatCst(builder, 1.4106045900607047e-2, elementType)));
89772085698SPrashant Kumar   t = fma(t, q, bcast(floatCst(builder, 1.7339776384962050e-2, elementType)));
89872085698SPrashant Kumar   r = fma(r, q, bcast(floatCst(builder, 2.2372961589651054e-2, elementType)));
89972085698SPrashant Kumar   t = fma(t, q, bcast(floatCst(builder, 3.0381912707941005e-2, elementType)));
90072085698SPrashant Kumar   r = fma(r, q, bcast(floatCst(builder, 4.4642857881094775e-2, elementType)));
90172085698SPrashant Kumar   t = fma(t, q, bcast(floatCst(builder, 7.4999999991367292e-2, elementType)));
90272085698SPrashant Kumar   r = fma(r, s, t);
90372085698SPrashant Kumar   r = fma(r, s, bcast(floatCst(builder, 1.6666666666670193e-1, elementType)));
904a3fb3014SRob Suderman   t = mul(x, s);
905a3fb3014SRob Suderman   r = fma(r, t, x);
906a3fb3014SRob Suderman 
907a3fb3014SRob Suderman   Value rsub = sub(bcast(floatCst(builder, 1.57079632679, elementType)), r);
908a3fb3014SRob Suderman   r = sel(gt, rsub, r);
909a3fb3014SRob Suderman   r = scopy(r, operand);
91072085698SPrashant Kumar 
91172085698SPrashant Kumar   rewriter.replaceOp(op, r);
91272085698SPrashant Kumar   return success();
91372085698SPrashant Kumar }
91472085698SPrashant Kumar 
91572085698SPrashant Kumar //----------------------------------------------------------------------------//
91672085698SPrashant Kumar // Acos approximation.
91772085698SPrashant Kumar //----------------------------------------------------------------------------//
91872085698SPrashant Kumar 
91972085698SPrashant Kumar // Approximates acos(x).
92072085698SPrashant Kumar // This approximation is based on the following stackoverflow post:
92172085698SPrashant Kumar // https://stackoverflow.com/a/42683455
92272085698SPrashant Kumar namespace {
92372085698SPrashant Kumar struct AcosPolynomialApproximation : public OpRewritePattern<math::AcosOp> {
92472085698SPrashant Kumar public:
92572085698SPrashant Kumar   using OpRewritePattern::OpRewritePattern;
92672085698SPrashant Kumar 
92772085698SPrashant Kumar   LogicalResult matchAndRewrite(math::AcosOp op,
92872085698SPrashant Kumar                                 PatternRewriter &rewriter) const final;
92972085698SPrashant Kumar };
93072085698SPrashant Kumar } // namespace
93172085698SPrashant Kumar LogicalResult
93272085698SPrashant Kumar AcosPolynomialApproximation::matchAndRewrite(math::AcosOp op,
93372085698SPrashant Kumar                                              PatternRewriter &rewriter) const {
93472085698SPrashant Kumar   Value operand = op.getOperand();
93572085698SPrashant Kumar   Type elementType = getElementTypeOrSelf(operand);
93672085698SPrashant Kumar 
93772085698SPrashant Kumar   if (!(elementType.isF32() || elementType.isF16()))
93872085698SPrashant Kumar     return rewriter.notifyMatchFailure(op,
93972085698SPrashant Kumar                                        "only f32 and f16 type is supported.");
940*72f36217SKunwar Grover   std::optional<VectorShape> shape = vectorShape(operand);
94172085698SPrashant Kumar 
94272085698SPrashant Kumar   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
94372085698SPrashant Kumar   auto bcast = [&](Value value) -> Value {
94472085698SPrashant Kumar     return broadcast(builder, value, shape);
94572085698SPrashant Kumar   };
94672085698SPrashant Kumar 
94772085698SPrashant Kumar   auto fma = [&](Value a, Value b, Value c) -> Value {
94872085698SPrashant Kumar     return builder.create<math::FmaOp>(a, b, c);
94972085698SPrashant Kumar   };
95072085698SPrashant Kumar 
95172085698SPrashant Kumar   auto mul = [&](Value a, Value b) -> Value {
95272085698SPrashant Kumar     return builder.create<arith::MulFOp>(a, b);
95372085698SPrashant Kumar   };
95472085698SPrashant Kumar 
95572085698SPrashant Kumar   Value negOperand = builder.create<arith::NegFOp>(operand);
95672085698SPrashant Kumar   Value zero = bcast(floatCst(builder, 0.0, elementType));
95772085698SPrashant Kumar   Value half = bcast(floatCst(builder, 0.5, elementType));
95872085698SPrashant Kumar   Value negOne = bcast(floatCst(builder, -1.0, elementType));
95972085698SPrashant Kumar   Value selR =
96072085698SPrashant Kumar       builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, operand, zero);
96172085698SPrashant Kumar   Value r = builder.create<arith::SelectOp>(selR, negOperand, operand);
96272085698SPrashant Kumar   Value chkConst = bcast(floatCst(builder, -0.5625, elementType));
96372085698SPrashant Kumar   Value firstPred =
96472085698SPrashant Kumar       builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, r, chkConst);
96572085698SPrashant Kumar 
96672085698SPrashant Kumar   Value trueVal =
96772085698SPrashant Kumar       fma(bcast(floatCst(builder, 9.3282184640716537e-1, elementType)),
96872085698SPrashant Kumar           bcast(floatCst(builder, 1.6839188885261840e+0, elementType)),
96972085698SPrashant Kumar           builder.create<math::AsinOp>(r));
97072085698SPrashant Kumar 
97172085698SPrashant Kumar   Value falseVal = builder.create<math::SqrtOp>(fma(half, r, half));
97272085698SPrashant Kumar   falseVal = builder.create<math::AsinOp>(falseVal);
97372085698SPrashant Kumar   falseVal = mul(bcast(floatCst(builder, 2.0, elementType)), falseVal);
97472085698SPrashant Kumar 
97572085698SPrashant Kumar   r = builder.create<arith::SelectOp>(firstPred, trueVal, falseVal);
97672085698SPrashant Kumar 
97772085698SPrashant Kumar   // Check whether the operand lies in between [-1.0, 0.0).
97872085698SPrashant Kumar   Value greaterThanNegOne =
97972085698SPrashant Kumar       builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, operand, negOne);
98072085698SPrashant Kumar 
98172085698SPrashant Kumar   Value lessThanZero =
98272085698SPrashant Kumar       builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero);
98372085698SPrashant Kumar 
98472085698SPrashant Kumar   Value betweenNegOneZero =
98572085698SPrashant Kumar       builder.create<arith::AndIOp>(greaterThanNegOne, lessThanZero);
98672085698SPrashant Kumar 
98772085698SPrashant Kumar   trueVal = fma(bcast(floatCst(builder, 1.8656436928143307e+0, elementType)),
98872085698SPrashant Kumar                 bcast(floatCst(builder, 1.6839188885261840e+0, elementType)),
98972085698SPrashant Kumar                 builder.create<arith::NegFOp>(r));
99072085698SPrashant Kumar 
99172085698SPrashant Kumar   Value finalVal =
99272085698SPrashant Kumar       builder.create<arith::SelectOp>(betweenNegOneZero, trueVal, r);
99372085698SPrashant Kumar 
99472085698SPrashant Kumar   rewriter.replaceOp(op, finalVal);
99572085698SPrashant Kumar   return success();
99672085698SPrashant Kumar }
99772085698SPrashant Kumar 
99872085698SPrashant Kumar //----------------------------------------------------------------------------//
999f1b92218SBoian Petkantchin // Erf approximation.
1000f1b92218SBoian Petkantchin //----------------------------------------------------------------------------//
1001f1b92218SBoian Petkantchin 
// Approximates erf(x) with
// a + P(x)/Q(x)
// where P and Q are polynomials of degree 4, and the offset a and the
// coefficients of P and Q are chosen based on which interval |x| falls in.
1006f1b92218SBoian Petkantchin // The approximation error is ~2.5e-07.
1007f1b92218SBoian Petkantchin // Boost's minimax tool that utilizes the Remez method was used to find the
1008f1b92218SBoian Petkantchin // coefficients.
1009f1b92218SBoian Petkantchin LogicalResult
1010f1b92218SBoian Petkantchin ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
1011f1b92218SBoian Petkantchin                                             PatternRewriter &rewriter) const {
1012880b8f4eSPrashant Kumar   Value operand = op.getOperand();
1013880b8f4eSPrashant Kumar   Type elementType = getElementTypeOrSelf(operand);
1014f1b92218SBoian Petkantchin 
1015880b8f4eSPrashant Kumar   if (!(elementType.isF32() || elementType.isF16()))
1016880b8f4eSPrashant Kumar     return rewriter.notifyMatchFailure(op,
1017880b8f4eSPrashant Kumar                                        "only f32 and f16 type is supported.");
1018*72f36217SKunwar Grover   std::optional<VectorShape> shape = vectorShape(operand);
1019ec32d540SEugene Zhulenev 
1020f1b92218SBoian Petkantchin   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
1021f1b92218SBoian Petkantchin   auto bcast = [&](Value value) -> Value {
1022ec32d540SEugene Zhulenev     return broadcast(builder, value, shape);
1023f1b92218SBoian Petkantchin   };
1024f1b92218SBoian Petkantchin 
1025f1b92218SBoian Petkantchin   const int intervalsCount = 3;
1026f1b92218SBoian Petkantchin   const int polyDegree = 4;
1027f1b92218SBoian Petkantchin 
1028880b8f4eSPrashant Kumar   Value zero = bcast(floatCst(builder, 0, elementType));
1029880b8f4eSPrashant Kumar   Value one = bcast(floatCst(builder, 1, elementType));
1030f1b92218SBoian Petkantchin   Value pp[intervalsCount][polyDegree + 1];
1031880b8f4eSPrashant Kumar   pp[0][0] = bcast(floatCst(builder, +0.00000000000000000e+00f, elementType));
1032880b8f4eSPrashant Kumar   pp[0][1] = bcast(floatCst(builder, +1.12837916222975858e+00f, elementType));
1033880b8f4eSPrashant Kumar   pp[0][2] = bcast(floatCst(builder, -5.23018562988006470e-01f, elementType));
1034880b8f4eSPrashant Kumar   pp[0][3] = bcast(floatCst(builder, +2.09741709609267072e-01f, elementType));
1035880b8f4eSPrashant Kumar   pp[0][4] = bcast(floatCst(builder, +2.58146801602987875e-02f, elementType));
1036880b8f4eSPrashant Kumar   pp[1][0] = bcast(floatCst(builder, +0.00000000000000000e+00f, elementType));
1037880b8f4eSPrashant Kumar   pp[1][1] = bcast(floatCst(builder, +1.12750687816789140e+00f, elementType));
1038880b8f4eSPrashant Kumar   pp[1][2] = bcast(floatCst(builder, -3.64721408487825775e-01f, elementType));
1039880b8f4eSPrashant Kumar   pp[1][3] = bcast(floatCst(builder, +1.18407396425136952e-01f, elementType));
1040880b8f4eSPrashant Kumar   pp[1][4] = bcast(floatCst(builder, +3.70645533056476558e-02f, elementType));
1041880b8f4eSPrashant Kumar   pp[2][0] = bcast(floatCst(builder, -3.30093071049483172e-03f, elementType));
1042880b8f4eSPrashant Kumar   pp[2][1] = bcast(floatCst(builder, +3.51961938357697011e-03f, elementType));
1043880b8f4eSPrashant Kumar   pp[2][2] = bcast(floatCst(builder, -1.41373622814988039e-03f, elementType));
1044880b8f4eSPrashant Kumar   pp[2][3] = bcast(floatCst(builder, +2.53447094961941348e-04f, elementType));
1045880b8f4eSPrashant Kumar   pp[2][4] = bcast(floatCst(builder, -1.71048029455037401e-05f, elementType));
1046f1b92218SBoian Petkantchin 
1047f1b92218SBoian Petkantchin   Value qq[intervalsCount][polyDegree + 1];
1048880b8f4eSPrashant Kumar   qq[0][0] = bcast(floatCst(builder, +1.000000000000000000e+00f, elementType));
1049880b8f4eSPrashant Kumar   qq[0][1] = bcast(floatCst(builder, -4.635138185962547255e-01f, elementType));
1050880b8f4eSPrashant Kumar   qq[0][2] = bcast(floatCst(builder, +5.192301327279782447e-01f, elementType));
1051880b8f4eSPrashant Kumar   qq[0][3] = bcast(floatCst(builder, -1.318089722204810087e-01f, elementType));
1052880b8f4eSPrashant Kumar   qq[0][4] = bcast(floatCst(builder, +7.397964654672315005e-02f, elementType));
1053880b8f4eSPrashant Kumar   qq[1][0] = bcast(floatCst(builder, +1.00000000000000000e+00f, elementType));
1054880b8f4eSPrashant Kumar   qq[1][1] = bcast(floatCst(builder, -3.27607011824493086e-01f, elementType));
1055880b8f4eSPrashant Kumar   qq[1][2] = bcast(floatCst(builder, +4.48369090658821977e-01f, elementType));
1056880b8f4eSPrashant Kumar   qq[1][3] = bcast(floatCst(builder, -8.83462621207857930e-02f, elementType));
1057880b8f4eSPrashant Kumar   qq[1][4] = bcast(floatCst(builder, +5.72442770283176093e-02f, elementType));
1058880b8f4eSPrashant Kumar   qq[2][0] = bcast(floatCst(builder, +1.00000000000000000e+00f, elementType));
1059880b8f4eSPrashant Kumar   qq[2][1] = bcast(floatCst(builder, -2.06069165953913769e+00f, elementType));
1060880b8f4eSPrashant Kumar   qq[2][2] = bcast(floatCst(builder, +1.62705939945477759e+00f, elementType));
1061880b8f4eSPrashant Kumar   qq[2][3] = bcast(floatCst(builder, -5.83389859211130017e-01f, elementType));
1062880b8f4eSPrashant Kumar   qq[2][4] = bcast(floatCst(builder, +8.21908939856640930e-02f, elementType));
1063f1b92218SBoian Petkantchin 
1064f1b92218SBoian Petkantchin   Value offsets[intervalsCount];
1065880b8f4eSPrashant Kumar   offsets[0] = bcast(floatCst(builder, 0.0f, elementType));
1066880b8f4eSPrashant Kumar   offsets[1] = bcast(floatCst(builder, 0.0f, elementType));
1067880b8f4eSPrashant Kumar   offsets[2] = bcast(floatCst(builder, 1.0f, elementType));
1068f1b92218SBoian Petkantchin 
1069f1b92218SBoian Petkantchin   Value bounds[intervalsCount];
1070880b8f4eSPrashant Kumar   bounds[0] = bcast(floatCst(builder, 0.8f, elementType));
1071880b8f4eSPrashant Kumar   bounds[1] = bcast(floatCst(builder, 2.0f, elementType));
1072880b8f4eSPrashant Kumar   bounds[2] = bcast(floatCst(builder, 3.75f, elementType));
1073f1b92218SBoian Petkantchin 
1074880b8f4eSPrashant Kumar   Value isNegativeArg =
1075880b8f4eSPrashant Kumar       builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero);
1076880b8f4eSPrashant Kumar   Value negArg = builder.create<arith::NegFOp>(operand);
1077880b8f4eSPrashant Kumar   Value x = builder.create<arith::SelectOp>(isNegativeArg, negArg, operand);
1078f1b92218SBoian Petkantchin 
1079f1b92218SBoian Petkantchin   Value offset = offsets[0];
1080f1b92218SBoian Petkantchin   Value p[polyDegree + 1];
1081f1b92218SBoian Petkantchin   Value q[polyDegree + 1];
1082f1b92218SBoian Petkantchin   for (int i = 0; i <= polyDegree; ++i) {
1083f1b92218SBoian Petkantchin     p[i] = pp[0][i];
1084f1b92218SBoian Petkantchin     q[i] = qq[0][i];
1085f1b92218SBoian Petkantchin   }
1086f1b92218SBoian Petkantchin 
1087f1b92218SBoian Petkantchin   // TODO: maybe use vector stacking to reduce the number of selects.
1088f1b92218SBoian Petkantchin   Value isLessThanBound[intervalsCount];
1089f1b92218SBoian Petkantchin   for (int j = 0; j < intervalsCount - 1; ++j) {
1090f1b92218SBoian Petkantchin     isLessThanBound[j] =
1091f1b92218SBoian Petkantchin         builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, bounds[j]);
1092f1b92218SBoian Petkantchin     for (int i = 0; i <= polyDegree; ++i) {
1093dec8af70SRiver Riddle       p[i] = builder.create<arith::SelectOp>(isLessThanBound[j], p[i],
1094dec8af70SRiver Riddle                                              pp[j + 1][i]);
1095dec8af70SRiver Riddle       q[i] = builder.create<arith::SelectOp>(isLessThanBound[j], q[i],
1096dec8af70SRiver Riddle                                              qq[j + 1][i]);
1097f1b92218SBoian Petkantchin     }
1098dec8af70SRiver Riddle     offset = builder.create<arith::SelectOp>(isLessThanBound[j], offset,
1099dec8af70SRiver Riddle                                              offsets[j + 1]);
1100f1b92218SBoian Petkantchin   }
1101f1b92218SBoian Petkantchin   isLessThanBound[intervalsCount - 1] = builder.create<arith::CmpFOp>(
1102f1b92218SBoian Petkantchin       arith::CmpFPredicate::ULT, x, bounds[intervalsCount - 1]);
1103f1b92218SBoian Petkantchin 
1104f1b92218SBoian Petkantchin   Value pPoly = makePolynomialCalculation(builder, p, x);
1105f1b92218SBoian Petkantchin   Value qPoly = makePolynomialCalculation(builder, q, x);
1106f1b92218SBoian Petkantchin   Value rationalPoly = builder.create<arith::DivFOp>(pPoly, qPoly);
1107f1b92218SBoian Petkantchin   Value formula = builder.create<arith::AddFOp>(offset, rationalPoly);
1108dec8af70SRiver Riddle   formula = builder.create<arith::SelectOp>(isLessThanBound[intervalsCount - 1],
1109f1b92218SBoian Petkantchin                                             formula, one);
1110f1b92218SBoian Petkantchin 
1111f1b92218SBoian Petkantchin   // erf is odd function: erf(x) = -erf(-x).
1112f1b92218SBoian Petkantchin   Value negFormula = builder.create<arith::NegFOp>(formula);
1113dec8af70SRiver Riddle   Value res =
1114dec8af70SRiver Riddle       builder.create<arith::SelectOp>(isNegativeArg, negFormula, formula);
1115f1b92218SBoian Petkantchin 
1116f1b92218SBoian Petkantchin   rewriter.replaceOp(op, res);
1117f1b92218SBoian Petkantchin 
1118f1b92218SBoian Petkantchin   return success();
1119f1b92218SBoian Petkantchin }
1120f1b92218SBoian Petkantchin 
1121f1b92218SBoian Petkantchin //----------------------------------------------------------------------------//
1122ea7f211bSAhmed Taei // Exp approximation.
1123ea7f211bSAhmed Taei //----------------------------------------------------------------------------//
1124ea7f211bSAhmed Taei 
1125ea7f211bSAhmed Taei namespace {
1126ea7f211bSAhmed Taei 
1127*72f36217SKunwar Grover Value clampWithNormals(ImplicitLocOpBuilder &builder,
1128*72f36217SKunwar Grover                        const std::optional<VectorShape> shape, Value value,
1129*72f36217SKunwar Grover                        float lowerBound, float upperBound) {
1130710dc728SRobert Suderman   assert(!std::isnan(lowerBound));
1131710dc728SRobert Suderman   assert(!std::isnan(upperBound));
1132710dc728SRobert Suderman 
1133710dc728SRobert Suderman   auto bcast = [&](Value value) -> Value {
1134710dc728SRobert Suderman     return broadcast(builder, value, shape);
1135710dc728SRobert Suderman   };
1136710dc728SRobert Suderman 
1137710dc728SRobert Suderman   auto selectCmp = [&builder](auto pred, Value value, Value bound) {
1138710dc728SRobert Suderman     return builder.create<arith::SelectOp>(
1139710dc728SRobert Suderman         builder.create<arith::CmpFOp>(pred, value, bound), value, bound);
1140710dc728SRobert Suderman   };
1141710dc728SRobert Suderman 
1142710dc728SRobert Suderman   // Note: prefer UGE/ULE vs. UGT/ULT, since they generate vmaxps/vminps vs.
1143710dc728SRobert Suderman   // vcmpleps+vmovaps on x86_64. The latter outcome is also obtained with
1144710dc728SRobert Suderman   // arith::{Max,Min}FOp.
1145710dc728SRobert Suderman   value = selectCmp(arith::CmpFPredicate::UGE, value,
1146710dc728SRobert Suderman                     bcast(f32Cst(builder, lowerBound)));
1147710dc728SRobert Suderman   value = selectCmp(arith::CmpFPredicate::ULE, value,
1148710dc728SRobert Suderman                     bcast(f32Cst(builder, upperBound)));
1149710dc728SRobert Suderman   return value;
1150710dc728SRobert Suderman }
1151710dc728SRobert Suderman 
1152ea7f211bSAhmed Taei struct ExpApproximation : public OpRewritePattern<math::ExpOp> {
1153ea7f211bSAhmed Taei public:
1154ea7f211bSAhmed Taei   using OpRewritePattern::OpRewritePattern;
1155ea7f211bSAhmed Taei 
1156ea7f211bSAhmed Taei   LogicalResult matchAndRewrite(math::ExpOp op,
1157ea7f211bSAhmed Taei                                 PatternRewriter &rewriter) const final;
1158ea7f211bSAhmed Taei };
1159ea7f211bSAhmed Taei 
1160ea7f211bSAhmed Taei LogicalResult
1161ea7f211bSAhmed Taei ExpApproximation::matchAndRewrite(math::ExpOp op,
1162ea7f211bSAhmed Taei                                   PatternRewriter &rewriter) const {
1163710dc728SRobert Suderman   auto shape = vectorShape(op.getOperand().getType());
1164710dc728SRobert Suderman   auto elementTy = getElementTypeOrSelf(op.getType());
1165710dc728SRobert Suderman   if (!elementTy.isF32())
1166ea7f211bSAhmed Taei     return rewriter.notifyMatchFailure(op, "unsupported operand type");
1167ec32d540SEugene Zhulenev 
1168ea7f211bSAhmed Taei   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
1169ea7f211bSAhmed Taei 
1170710dc728SRobert Suderman   auto add = [&](Value a, Value b) -> Value {
1171710dc728SRobert Suderman     return builder.create<arith::AddFOp>(a, b);
1172710dc728SRobert Suderman   };
1173ea7f211bSAhmed Taei   auto bcast = [&](Value value) -> Value {
1174ec32d540SEugene Zhulenev     return broadcast(builder, value, shape);
1175ea7f211bSAhmed Taei   };
1176710dc728SRobert Suderman   auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); };
1177ea7f211bSAhmed Taei   auto fmla = [&](Value a, Value b, Value c) {
1178a54f4eaeSMogball     return builder.create<math::FmaOp>(a, b, c);
1179ea7f211bSAhmed Taei   };
1180ea7f211bSAhmed Taei   auto mul = [&](Value a, Value b) -> Value {
1181a54f4eaeSMogball     return builder.create<arith::MulFOp>(a, b);
1182ea7f211bSAhmed Taei   };
1183ea7f211bSAhmed Taei 
1184710dc728SRobert Suderman   // Polynomial approximation from Cephes.
1185710dc728SRobert Suderman   //
1186710dc728SRobert Suderman   // To compute e^x, we re-express it as
1187710dc728SRobert Suderman   //
1188710dc728SRobert Suderman   //   e^x = e^(a + b)
1189710dc728SRobert Suderman   //       = e^(a + n log(2))
1190710dc728SRobert Suderman   //       = e^a * 2^n.
1191710dc728SRobert Suderman   //
1192710dc728SRobert Suderman   // We choose n = round(x / log(2)), restricting the value of `a` to
1193710dc728SRobert Suderman   // (-log(2)/2, log(2)/2).  We then use a polynomial to compute e^a. The
1194710dc728SRobert Suderman   // relative error between our approximation and the true value of e^a is less
1195710dc728SRobert Suderman   // than 2^-22.5 for all values of `a` within this range.
1196ea7f211bSAhmed Taei 
1197710dc728SRobert Suderman   // Restrict input to a small range, including some values that evaluate to
1198710dc728SRobert Suderman   // +/- inf.  Note that for our lower bound, we choose log(2^-126) instead of
1199710dc728SRobert Suderman   // log(F32_EPSILON). We do so because this routine always flushes denormal
1200710dc728SRobert Suderman   // floating points to 0. Therefore, we only need to worry about exponentiating
1201710dc728SRobert Suderman   // up to the smallest representable non-denormal floating point, which is
1202710dc728SRobert Suderman   // 2^-126.
1203ea7f211bSAhmed Taei 
1204710dc728SRobert Suderman   // Constants.
1205d4933b32SMehdi Amini   Value cstHalf = bcast(f32Cst(builder, 0.5f));
1206d4933b32SMehdi Amini   Value cstOne = bcast(f32Cst(builder, 1.0f));
1207710dc728SRobert Suderman 
1208710dc728SRobert Suderman   // 1/log(2)
1209d4933b32SMehdi Amini   Value cstLog2ef = bcast(f32Cst(builder, 1.44269504088896341f));
1210710dc728SRobert Suderman 
1211d4933b32SMehdi Amini   Value cstExpC1 = bcast(f32Cst(builder, -0.693359375f));
1212d4933b32SMehdi Amini   Value cstExpC2 = bcast(f32Cst(builder, 2.12194440e-4f));
1213d4933b32SMehdi Amini   Value cstExpP0 = bcast(f32Cst(builder, 1.9875691500E-4f));
1214d4933b32SMehdi Amini   Value cstExpP1 = bcast(f32Cst(builder, 1.3981999507E-3f));
1215d4933b32SMehdi Amini   Value cstExpP2 = bcast(f32Cst(builder, 8.3334519073E-3f));
1216d4933b32SMehdi Amini   Value cstExpP3 = bcast(f32Cst(builder, 4.1665795894E-2f));
1217d4933b32SMehdi Amini   Value cstExpP4 = bcast(f32Cst(builder, 1.6666665459E-1f));
1218d4933b32SMehdi Amini   Value cstExpP5 = bcast(f32Cst(builder, 5.0000001201E-1f));
1219710dc728SRobert Suderman 
1220710dc728SRobert Suderman   // Our computations below aren't particularly sensitive to the exact choices
1221710dc728SRobert Suderman   // here, so we choose values a bit larger/smaller than
1222710dc728SRobert Suderman   //
1223710dc728SRobert Suderman   //   log(F32_MAX) = 88.723...
1224710dc728SRobert Suderman   //   log(2^-126) = -87.337...
122562fea88bSJacques Pienaar   Value x = op.getOperand();
1226710dc728SRobert Suderman   x = clampWithNormals(builder, shape, x, -87.8f, 88.8f);
1227d4933b32SMehdi Amini   Value n = floor(fmla(x, cstLog2ef, cstHalf));
1228ea7f211bSAhmed Taei 
1229710dc728SRobert Suderman   // When we eventually do the multiplication in e^a * 2^n, we need to handle
1230710dc728SRobert Suderman   // the case when n > 127, the max fp32 exponent (so 2^n == inf) but e^a < 1
1231710dc728SRobert Suderman   // (so e^a * 2^n != inf).  There's a similar problem for n < -126, the
1232710dc728SRobert Suderman   // smallest fp32 exponent.
1233710dc728SRobert Suderman   //
1234710dc728SRobert Suderman   // A straightforward solution would be to detect n out of range and split it
1235710dc728SRobert Suderman   // up, doing
1236710dc728SRobert Suderman   //
1237710dc728SRobert Suderman   //   e^a * 2^n = e^a * 2^(n1 + n2)
1238710dc728SRobert Suderman   //             = (2^n1 * e^a) * 2^n2.
1239710dc728SRobert Suderman   //
1240710dc728SRobert Suderman   // But it turns out this approach is quite slow, probably because it
1241710dc728SRobert Suderman   // manipulates subnormal values.
1242710dc728SRobert Suderman   //
1243710dc728SRobert Suderman   // The approach we use instead is to clamp n to [-127, 127]. Let n' be the
1244710dc728SRobert Suderman   // value of n clamped to [-127, 127]. In the case where n' = 127, `a` can grow
1245710dc728SRobert Suderman   // up to as large as 88.8 - 127 * log(2) which is about 0.7703. Even though
1246710dc728SRobert Suderman   // this value of `a` is outside our previously specified range, e^a will still
1247710dc728SRobert Suderman   // only have a relative error of approximately 2^-16 at worse. In practice
1248710dc728SRobert Suderman   // this seems to work well enough; it passes our exhaustive tests, breaking
1249710dc728SRobert Suderman   // only one result, and by one ulp (we return exp(88.7228394) = max-float but
1250710dc728SRobert Suderman   // we should return inf).
1251710dc728SRobert Suderman   //
1252710dc728SRobert Suderman   // In the case where n' = -127, the original input value of x is so small that
1253710dc728SRobert Suderman   // e^x, our final answer, is less than 2^-126. Since 2^-126 is the smallest
1254710dc728SRobert Suderman   // normal floating point, and since we flush denormals, we simply return 0. We
1255710dc728SRobert Suderman   // do this in a branchless way by observing that our code for constructing 2^n
1256710dc728SRobert Suderman   // produces 0 if n = -127.
1257710dc728SRobert Suderman   //
1258710dc728SRobert Suderman   // The proof that n' = -127 implies e^x < 2^-126 is as follows:
1259710dc728SRobert Suderman   //
1260710dc728SRobert Suderman   //    n' = -127 implies n <= -127
1261710dc728SRobert Suderman   //              implies round(x / log(2)) <= -127
1262710dc728SRobert Suderman   //              implies x/log(2) < -126.5
1263710dc728SRobert Suderman   //              implies x < -126.5 * log(2)
1264710dc728SRobert Suderman   //              implies e^x < e^(-126.5 * log(2))
1265710dc728SRobert Suderman   //              implies e^x < 2^-126.5 < 2^-126
1266710dc728SRobert Suderman   //
1267710dc728SRobert Suderman   //    This proves that n' = -127 implies e^x < 2^-126.
1268710dc728SRobert Suderman   n = clampWithNormals(builder, shape, n, -127.0f, 127.0f);
1269b122cbebSAdrian Kuegel 
1270710dc728SRobert Suderman   // Computes x = x - n' * log(2), the value for `a`
1271d4933b32SMehdi Amini   x = fmla(cstExpC1, n, x);
1272d4933b32SMehdi Amini   x = fmla(cstExpC2, n, x);
1273ea7f211bSAhmed Taei 
1274710dc728SRobert Suderman   // Polynomial to compute z = e^a, accurate for a in (-0.5, 0.5).
1275d4933b32SMehdi Amini   Value z = fmla(x, cstExpP0, cstExpP1);
1276d4933b32SMehdi Amini   z = fmla(z, x, cstExpP2);
1277d4933b32SMehdi Amini   z = fmla(z, x, cstExpP3);
1278d4933b32SMehdi Amini   z = fmla(z, x, cstExpP4);
1279d4933b32SMehdi Amini   z = fmla(z, x, cstExpP5);
1280710dc728SRobert Suderman   z = fmla(z, mul(x, x), x);
1281d4933b32SMehdi Amini   z = add(cstOne, z);
1282ea7f211bSAhmed Taei 
1283710dc728SRobert Suderman   // Convert n' to an i32.  This is safe because we clamped it above.
1284d4933b32SMehdi Amini   auto i32Vec = broadcast(builder.getI32Type(), shape);
1285d4933b32SMehdi Amini   Value nI32 = builder.create<arith::FPToSIOp>(i32Vec, n);
1286ea7f211bSAhmed Taei 
1287710dc728SRobert Suderman   // Creates the value 2^n' if -126 <= n' <= 127 and 0 if n' = -127.
1288d4933b32SMehdi Amini   Value pow2 = exp2I32(builder, nI32);
1289ea7f211bSAhmed Taei 
1290710dc728SRobert Suderman   // Return z * 2^n' if -126 <= n' <= 127 and 0 if n = -127.
1291710dc728SRobert Suderman   Value ret = mul(z, pow2);
1292ea7f211bSAhmed Taei 
1293710dc728SRobert Suderman   rewriter.replaceOp(op, ret);
1294710dc728SRobert Suderman   return mlir::success();
1295ea7f211bSAhmed Taei }
1296ea7f211bSAhmed Taei 
1297710dc728SRobert Suderman } // namespace
1298710dc728SRobert Suderman 
1299ea7f211bSAhmed Taei //----------------------------------------------------------------------------//
13000edc4bc8SEmilio Cota // ExpM1 approximation.
13010edc4bc8SEmilio Cota //----------------------------------------------------------------------------//
13020edc4bc8SEmilio Cota 
13030edc4bc8SEmilio Cota namespace {
13040edc4bc8SEmilio Cota 
13050edc4bc8SEmilio Cota struct ExpM1Approximation : public OpRewritePattern<math::ExpM1Op> {
13060edc4bc8SEmilio Cota public:
13070edc4bc8SEmilio Cota   using OpRewritePattern::OpRewritePattern;
13080edc4bc8SEmilio Cota 
13090edc4bc8SEmilio Cota   LogicalResult matchAndRewrite(math::ExpM1Op op,
13100edc4bc8SEmilio Cota                                 PatternRewriter &rewriter) const final;
13110edc4bc8SEmilio Cota };
13120edc4bc8SEmilio Cota } // namespace
13130edc4bc8SEmilio Cota 
13140edc4bc8SEmilio Cota LogicalResult
13150edc4bc8SEmilio Cota ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
13160edc4bc8SEmilio Cota                                     PatternRewriter &rewriter) const {
131762fea88bSJacques Pienaar   if (!getElementTypeOrSelf(op.getOperand()).isF32())
13180edc4bc8SEmilio Cota     return rewriter.notifyMatchFailure(op, "unsupported operand type");
13190edc4bc8SEmilio Cota 
1320*72f36217SKunwar Grover   std::optional<VectorShape> shape = vectorShape(op.getOperand());
1321ec32d540SEugene Zhulenev 
13220edc4bc8SEmilio Cota   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
13230edc4bc8SEmilio Cota   auto bcast = [&](Value value) -> Value {
1324ec32d540SEugene Zhulenev     return broadcast(builder, value, shape);
13250edc4bc8SEmilio Cota   };
13260edc4bc8SEmilio Cota 
13270edc4bc8SEmilio Cota   // expm1(x) = exp(x) - 1 = u - 1.
13280edc4bc8SEmilio Cota   // We have to handle it carefully when x is near 0, i.e. u ~= 1,
13290edc4bc8SEmilio Cota   // and when the input is ~= -inf, i.e. u - 1 ~= -1.
13300edc4bc8SEmilio Cota   Value cstOne = bcast(f32Cst(builder, 1.0f));
13310edc4bc8SEmilio Cota   Value cstNegOne = bcast(f32Cst(builder, -1.0f));
133262fea88bSJacques Pienaar   Value x = op.getOperand();
13330edc4bc8SEmilio Cota   Value u = builder.create<math::ExpOp>(x);
133487de451bSAdrian Kuegel   Value uEqOneOrNaN =
133587de451bSAdrian Kuegel       builder.create<arith::CmpFOp>(arith::CmpFPredicate::UEQ, u, cstOne);
1336a54f4eaeSMogball   Value uMinusOne = builder.create<arith::SubFOp>(u, cstOne);
1337a54f4eaeSMogball   Value uMinusOneEqNegOne = builder.create<arith::CmpFOp>(
1338a54f4eaeSMogball       arith::CmpFPredicate::OEQ, uMinusOne, cstNegOne);
13390edc4bc8SEmilio Cota   // logU = log(u) ~= x
13400edc4bc8SEmilio Cota   Value logU = builder.create<math::LogOp>(u);
13410edc4bc8SEmilio Cota 
13420edc4bc8SEmilio Cota   // Detect exp(x) = +inf; written this way to avoid having to form +inf.
1343a54f4eaeSMogball   Value isInf =
1344a54f4eaeSMogball       builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, logU, u);
13450edc4bc8SEmilio Cota 
13460edc4bc8SEmilio Cota   // (u - 1) * (x / ~x)
1347a54f4eaeSMogball   Value expm1 = builder.create<arith::MulFOp>(
1348a54f4eaeSMogball       uMinusOne, builder.create<arith::DivFOp>(x, logU));
1349dec8af70SRiver Riddle   expm1 = builder.create<arith::SelectOp>(isInf, u, expm1);
1350dec8af70SRiver Riddle   Value approximation = builder.create<arith::SelectOp>(
135187de451bSAdrian Kuegel       uEqOneOrNaN, x,
1352dec8af70SRiver Riddle       builder.create<arith::SelectOp>(uMinusOneEqNegOne, cstNegOne, expm1));
13530edc4bc8SEmilio Cota   rewriter.replaceOp(op, approximation);
13540edc4bc8SEmilio Cota   return success();
13550edc4bc8SEmilio Cota }
13560edc4bc8SEmilio Cota 
13570edc4bc8SEmilio Cota //----------------------------------------------------------------------------//
13587e2d672aSAhmed S. Taei // Sin and Cos approximation.
13597e2d672aSAhmed S. Taei //----------------------------------------------------------------------------//
13607e2d672aSAhmed S. Taei 
13617e2d672aSAhmed S. Taei namespace {
13627e2d672aSAhmed S. Taei 
13637e2d672aSAhmed S. Taei template <bool isSine, typename OpTy>
13647e2d672aSAhmed S. Taei struct SinAndCosApproximation : public OpRewritePattern<OpTy> {
13657e2d672aSAhmed S. Taei public:
13667e2d672aSAhmed S. Taei   using OpRewritePattern<OpTy>::OpRewritePattern;
13677e2d672aSAhmed S. Taei 
13687e2d672aSAhmed S. Taei   LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const final;
13697e2d672aSAhmed S. Taei };
13707e2d672aSAhmed S. Taei } // namespace
13717e2d672aSAhmed S. Taei 
// 2/pi and pi/2 in long double precision; call sites narrow them to float.
// Typed constants instead of object-like macros so they obey scoping rules and
// do not leak textual substitutions into the rest of the translation unit.
constexpr long double TWO_OVER_PI =
    0.6366197723675813430755350534900574481378385829618257949906693762L;
constexpr long double PI_OVER_2 =
    1.5707963267948966192313216916397514420985846996875529104874722961L;
13767e2d672aSAhmed S. Taei 
13777e2d672aSAhmed S. Taei // Approximates sin(x) or cos(x) by finding the best approximation polynomial in
13787e2d672aSAhmed S. Taei // the reduced range [0, pi/2] for both sin(x) and cos(x). Then given y in the
13797e2d672aSAhmed S. Taei // reduced range sin(x) will be computed as sin(y), -sin(y), cos(y) or -cos(y).
13807e2d672aSAhmed S. Taei template <bool isSine, typename OpTy>
13817e2d672aSAhmed S. Taei LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
13827e2d672aSAhmed S. Taei     OpTy op, PatternRewriter &rewriter) const {
13837e2d672aSAhmed S. Taei   static_assert(
13847e2d672aSAhmed S. Taei       llvm::is_one_of<OpTy, math::SinOp, math::CosOp>::value,
13857e2d672aSAhmed S. Taei       "SinAndCosApproximation pattern expects math::SinOp or math::CosOp");
1386ec32d540SEugene Zhulenev 
138762fea88bSJacques Pienaar   if (!getElementTypeOrSelf(op.getOperand()).isF32())
13887e2d672aSAhmed S. Taei     return rewriter.notifyMatchFailure(op, "unsupported operand type");
13897e2d672aSAhmed S. Taei 
1390*72f36217SKunwar Grover   std::optional<VectorShape> shape = vectorShape(op.getOperand());
1391ec32d540SEugene Zhulenev 
13927e2d672aSAhmed S. Taei   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
13937e2d672aSAhmed S. Taei   auto bcast = [&](Value value) -> Value {
1394ec32d540SEugene Zhulenev     return broadcast(builder, value, shape);
13957e2d672aSAhmed S. Taei   };
13967e2d672aSAhmed S. Taei   auto mul = [&](Value a, Value b) -> Value {
1397a54f4eaeSMogball     return builder.create<arith::MulFOp>(a, b);
13987e2d672aSAhmed S. Taei   };
13997e2d672aSAhmed S. Taei   auto sub = [&](Value a, Value b) -> Value {
1400a54f4eaeSMogball     return builder.create<arith::SubFOp>(a, b);
14017e2d672aSAhmed S. Taei   };
1402a54f4eaeSMogball   auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); };
14037e2d672aSAhmed S. Taei 
1404ec32d540SEugene Zhulenev   auto i32Vec = broadcast(builder.getI32Type(), shape);
14057e2d672aSAhmed S. Taei   auto fPToSingedInteger = [&](Value a) -> Value {
14063c69bc4dSRiver Riddle     return builder.create<arith::FPToSIOp>(i32Vec, a);
14077e2d672aSAhmed S. Taei   };
14087e2d672aSAhmed S. Taei 
14097e2d672aSAhmed S. Taei   auto modulo4 = [&](Value a) -> Value {
1410a54f4eaeSMogball     return builder.create<arith::AndIOp>(a, bcast(i32Cst(builder, 3)));
14117e2d672aSAhmed S. Taei   };
14127e2d672aSAhmed S. Taei 
14137e2d672aSAhmed S. Taei   auto isEqualTo = [&](Value a, Value b) -> Value {
1414a54f4eaeSMogball     return builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, a, b);
14157e2d672aSAhmed S. Taei   };
14167e2d672aSAhmed S. Taei 
14177e2d672aSAhmed S. Taei   auto isGreaterThan = [&](Value a, Value b) -> Value {
1418a54f4eaeSMogball     return builder.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, a, b);
14197e2d672aSAhmed S. Taei   };
14207e2d672aSAhmed S. Taei 
14217e2d672aSAhmed S. Taei   auto select = [&](Value cond, Value t, Value f) -> Value {
1422dec8af70SRiver Riddle     return builder.create<arith::SelectOp>(cond, t, f);
14237e2d672aSAhmed S. Taei   };
14247e2d672aSAhmed S. Taei 
14257e2d672aSAhmed S. Taei   auto fmla = [&](Value a, Value b, Value c) {
1426a54f4eaeSMogball     return builder.create<math::FmaOp>(a, b, c);
14277e2d672aSAhmed S. Taei   };
14287e2d672aSAhmed S. Taei 
1429a54f4eaeSMogball   auto bitwiseOr = [&](Value a, Value b) {
1430a54f4eaeSMogball     return builder.create<arith::OrIOp>(a, b);
1431a54f4eaeSMogball   };
14327e2d672aSAhmed S. Taei 
1433dc3b9365SAlexandre Ganea   Value twoOverPi = bcast(f32Cst(builder, (float)TWO_OVER_PI));
1434dc3b9365SAlexandre Ganea   Value piOverTwo = bcast(f32Cst(builder, (float)PI_OVER_2));
14357e2d672aSAhmed S. Taei 
143662fea88bSJacques Pienaar   Value x = op.getOperand();
14377e2d672aSAhmed S. Taei 
14387e2d672aSAhmed S. Taei   Value k = floor(mul(x, twoOverPi));
14397e2d672aSAhmed S. Taei 
14407e2d672aSAhmed S. Taei   Value y = sub(x, mul(k, piOverTwo));
14417e2d672aSAhmed S. Taei 
14427e2d672aSAhmed S. Taei   Value cstOne = bcast(f32Cst(builder, 1.0));
14437e2d672aSAhmed S. Taei   Value cstNegativeOne = bcast(f32Cst(builder, -1.0));
14447e2d672aSAhmed S. Taei 
14457e2d672aSAhmed S. Taei   Value cstSC2 = bcast(f32Cst(builder, -0.16666667163372039794921875f));
14467e2d672aSAhmed S. Taei   Value cstSC4 = bcast(f32Cst(builder, 8.333347737789154052734375e-3f));
14477e2d672aSAhmed S. Taei   Value cstSC6 = bcast(f32Cst(builder, -1.9842604524455964565277099609375e-4f));
14487e2d672aSAhmed S. Taei   Value cstSC8 =
14497e2d672aSAhmed S. Taei       bcast(f32Cst(builder, 2.760012648650445044040679931640625e-6f));
14507e2d672aSAhmed S. Taei   Value cstSC10 =
14517e2d672aSAhmed S. Taei       bcast(f32Cst(builder, -2.50293279435709337121807038784027099609375e-8f));
14527e2d672aSAhmed S. Taei 
14537e2d672aSAhmed S. Taei   Value cstCC2 = bcast(f32Cst(builder, -0.5f));
14547e2d672aSAhmed S. Taei   Value cstCC4 = bcast(f32Cst(builder, 4.166664183139801025390625e-2f));
14557e2d672aSAhmed S. Taei   Value cstCC6 = bcast(f32Cst(builder, -1.388833043165504932403564453125e-3f));
14567e2d672aSAhmed S. Taei   Value cstCC8 = bcast(f32Cst(builder, 2.47562347794882953166961669921875e-5f));
14577e2d672aSAhmed S. Taei   Value cstCC10 =
14587e2d672aSAhmed S. Taei       bcast(f32Cst(builder, -2.59630184018533327616751194000244140625e-7f));
14597e2d672aSAhmed S. Taei 
14607e2d672aSAhmed S. Taei   Value kMod4 = modulo4(fPToSingedInteger(k));
14617e2d672aSAhmed S. Taei 
14627e2d672aSAhmed S. Taei   Value kR0 = isEqualTo(kMod4, bcast(i32Cst(builder, 0)));
14637e2d672aSAhmed S. Taei   Value kR1 = isEqualTo(kMod4, bcast(i32Cst(builder, 1)));
14647e2d672aSAhmed S. Taei   Value kR2 = isEqualTo(kMod4, bcast(i32Cst(builder, 2)));
14657e2d672aSAhmed S. Taei   Value kR3 = isEqualTo(kMod4, bcast(i32Cst(builder, 3)));
14667e2d672aSAhmed S. Taei 
14677e2d672aSAhmed S. Taei   Value sinuseCos = isSine ? bitwiseOr(kR1, kR3) : bitwiseOr(kR0, kR2);
14687e2d672aSAhmed S. Taei   Value negativeRange = isSine ? isGreaterThan(kMod4, bcast(i32Cst(builder, 1)))
14697e2d672aSAhmed S. Taei                                : bitwiseOr(kR1, kR2);
14707e2d672aSAhmed S. Taei 
14717e2d672aSAhmed S. Taei   Value y2 = mul(y, y);
14727e2d672aSAhmed S. Taei 
14737e2d672aSAhmed S. Taei   Value base = select(sinuseCos, cstOne, y);
14747e2d672aSAhmed S. Taei   Value cstC2 = select(sinuseCos, cstCC2, cstSC2);
14757e2d672aSAhmed S. Taei   Value cstC4 = select(sinuseCos, cstCC4, cstSC4);
14767e2d672aSAhmed S. Taei   Value cstC6 = select(sinuseCos, cstCC6, cstSC6);
14777e2d672aSAhmed S. Taei   Value cstC8 = select(sinuseCos, cstCC8, cstSC8);
14787e2d672aSAhmed S. Taei   Value cstC10 = select(sinuseCos, cstCC10, cstSC10);
14797e2d672aSAhmed S. Taei 
14807e2d672aSAhmed S. Taei   Value v1 = fmla(y2, cstC10, cstC8);
14817e2d672aSAhmed S. Taei   Value v2 = fmla(y2, v1, cstC6);
14827e2d672aSAhmed S. Taei   Value v3 = fmla(y2, v2, cstC4);
14837e2d672aSAhmed S. Taei   Value v4 = fmla(y2, v3, cstC2);
14847e2d672aSAhmed S. Taei   Value v5 = fmla(y2, v4, cstOne);
14857e2d672aSAhmed S. Taei   Value v6 = mul(base, v5);
14867e2d672aSAhmed S. Taei 
14877e2d672aSAhmed S. Taei   Value approximation = select(negativeRange, mul(cstNegativeOne, v6), v6);
14887e2d672aSAhmed S. Taei 
14897e2d672aSAhmed S. Taei   rewriter.replaceOp(op, approximation);
14907e2d672aSAhmed S. Taei 
14917e2d672aSAhmed S. Taei   return success();
14927e2d672aSAhmed S. Taei }
14937e2d672aSAhmed S. Taei 
14947e2d672aSAhmed S. Taei //----------------------------------------------------------------------------//
14956b538810SRobert Suderman // Cbrt approximation.
14966b538810SRobert Suderman //----------------------------------------------------------------------------//
14976b538810SRobert Suderman 
14986b538810SRobert Suderman namespace {
14996b538810SRobert Suderman struct CbrtApproximation : public OpRewritePattern<math::CbrtOp> {
15006b538810SRobert Suderman   using OpRewritePattern::OpRewritePattern;
15016b538810SRobert Suderman 
15026b538810SRobert Suderman   LogicalResult matchAndRewrite(math::CbrtOp op,
15036b538810SRobert Suderman                                 PatternRewriter &rewriter) const final;
15046b538810SRobert Suderman };
15056b538810SRobert Suderman } // namespace
15066b538810SRobert Suderman 
15076b538810SRobert Suderman // Estimation of cube-root using an algorithm defined in
15086b538810SRobert Suderman // Hacker's Delight 2nd Edition.
15096b538810SRobert Suderman LogicalResult
15106b538810SRobert Suderman CbrtApproximation::matchAndRewrite(math::CbrtOp op,
15116b538810SRobert Suderman                                    PatternRewriter &rewriter) const {
15126b538810SRobert Suderman   auto operand = op.getOperand();
15136b538810SRobert Suderman   if (!getElementTypeOrSelf(operand).isF32())
15146b538810SRobert Suderman     return rewriter.notifyMatchFailure(op, "unsupported operand type");
15156b538810SRobert Suderman 
15166b538810SRobert Suderman   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1517*72f36217SKunwar Grover   std::optional<VectorShape> shape = vectorShape(operand);
15186b538810SRobert Suderman 
15196b538810SRobert Suderman   Type floatTy = getElementTypeOrSelf(operand.getType());
15206b538810SRobert Suderman   Type intTy = b.getIntegerType(floatTy.getIntOrFloatBitWidth());
15216b538810SRobert Suderman 
15226b538810SRobert Suderman   // Convert to vector types if necessary.
15236b538810SRobert Suderman   floatTy = broadcast(floatTy, shape);
15246b538810SRobert Suderman   intTy = broadcast(intTy, shape);
15256b538810SRobert Suderman 
15266089d612SRahul Kayaith   auto bconst = [&](TypedAttr attr) -> Value {
15276b538810SRobert Suderman     Value value = b.create<arith::ConstantOp>(attr);
15286b538810SRobert Suderman     return broadcast(b, value, shape);
15296b538810SRobert Suderman   };
15306b538810SRobert Suderman 
15316b538810SRobert Suderman   // Declare the initial values:
15326b538810SRobert Suderman   Value intTwo = bconst(b.getI32IntegerAttr(2));
15336b538810SRobert Suderman   Value intFour = bconst(b.getI32IntegerAttr(4));
15346b538810SRobert Suderman   Value intEight = bconst(b.getI32IntegerAttr(8));
15356b538810SRobert Suderman   Value intMagic = bconst(b.getI32IntegerAttr(0x2a5137a0));
15366b538810SRobert Suderman   Value fpThird = bconst(b.getF32FloatAttr(0.33333333f));
15376b538810SRobert Suderman   Value fpTwo = bconst(b.getF32FloatAttr(2.0f));
15386b538810SRobert Suderman   Value fpZero = bconst(b.getF32FloatAttr(0.0f));
15396b538810SRobert Suderman 
15406b538810SRobert Suderman   // Compute an approximation of one third:
15416b538810SRobert Suderman   // union {int ix; float x;};
15426b538810SRobert Suderman   // x = x0;
15436b538810SRobert Suderman   // ix = ix/4 + ix/16;
15446b538810SRobert Suderman   Value absValue = b.create<math::AbsFOp>(operand);
15456b538810SRobert Suderman   Value intValue = b.create<arith::BitcastOp>(intTy, absValue);
15466b538810SRobert Suderman   Value divideBy4 = b.create<arith::ShRSIOp>(intValue, intTwo);
15476b538810SRobert Suderman   Value divideBy16 = b.create<arith::ShRSIOp>(intValue, intFour);
15486b538810SRobert Suderman   intValue = b.create<arith::AddIOp>(divideBy4, divideBy16);
15496b538810SRobert Suderman 
15506b538810SRobert Suderman   // ix = ix + ix/16;
15516b538810SRobert Suderman   divideBy16 = b.create<arith::ShRSIOp>(intValue, intFour);
15526b538810SRobert Suderman   intValue = b.create<arith::AddIOp>(intValue, divideBy16);
15536b538810SRobert Suderman 
15546b538810SRobert Suderman   // ix = ix + ix/256;
15556b538810SRobert Suderman   Value divideBy256 = b.create<arith::ShRSIOp>(intValue, intEight);
15566b538810SRobert Suderman   intValue = b.create<arith::AddIOp>(intValue, divideBy256);
15576b538810SRobert Suderman 
15586b538810SRobert Suderman   // ix = 0x2a5137a0 + ix;
15596b538810SRobert Suderman   intValue = b.create<arith::AddIOp>(intValue, intMagic);
15606b538810SRobert Suderman 
15616b538810SRobert Suderman   // Perform one newtons step:
15626b538810SRobert Suderman   // x = 0.33333333f*(2.0f*x + x0/(x*x));
15636b538810SRobert Suderman   Value floatValue = b.create<arith::BitcastOp>(floatTy, intValue);
15646b538810SRobert Suderman   Value squared = b.create<arith::MulFOp>(floatValue, floatValue);
15656b538810SRobert Suderman   Value mulTwo = b.create<arith::MulFOp>(floatValue, fpTwo);
15666b538810SRobert Suderman   Value divSquared = b.create<arith::DivFOp>(absValue, squared);
15676b538810SRobert Suderman   floatValue = b.create<arith::AddFOp>(mulTwo, divSquared);
15686b538810SRobert Suderman   floatValue = b.create<arith::MulFOp>(floatValue, fpThird);
15696b538810SRobert Suderman 
15706b538810SRobert Suderman   // x = 0.33333333f*(2.0f*x + x0/(x*x));
15716b538810SRobert Suderman   squared = b.create<arith::MulFOp>(floatValue, floatValue);
15726b538810SRobert Suderman   mulTwo = b.create<arith::MulFOp>(floatValue, fpTwo);
15736b538810SRobert Suderman   divSquared = b.create<arith::DivFOp>(absValue, squared);
15746b538810SRobert Suderman   floatValue = b.create<arith::AddFOp>(mulTwo, divSquared);
15756b538810SRobert Suderman   floatValue = b.create<arith::MulFOp>(floatValue, fpThird);
15766b538810SRobert Suderman 
15776b538810SRobert Suderman   // Check for zero and restore sign.
15786b538810SRobert Suderman   Value isZero =
15796b538810SRobert Suderman       b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absValue, fpZero);
15806b538810SRobert Suderman   floatValue = b.create<arith::SelectOp>(isZero, fpZero, floatValue);
15816b538810SRobert Suderman   floatValue = b.create<math::CopySignOp>(floatValue, operand);
15826b538810SRobert Suderman 
15836b538810SRobert Suderman   rewriter.replaceOp(op, floatValue);
15846b538810SRobert Suderman   return success();
15856b538810SRobert Suderman }
15866b538810SRobert Suderman 
15876b538810SRobert Suderman //----------------------------------------------------------------------------//
158835553d45SEmilio Cota // Rsqrt approximation.
158935553d45SEmilio Cota //----------------------------------------------------------------------------//
159035553d45SEmilio Cota 
159135553d45SEmilio Cota namespace {
159235553d45SEmilio Cota struct RsqrtApproximation : public OpRewritePattern<math::RsqrtOp> {
159335553d45SEmilio Cota   using OpRewritePattern::OpRewritePattern;
159435553d45SEmilio Cota 
159535553d45SEmilio Cota   LogicalResult matchAndRewrite(math::RsqrtOp op,
159635553d45SEmilio Cota                                 PatternRewriter &rewriter) const final;
159735553d45SEmilio Cota };
159835553d45SEmilio Cota } // namespace
159935553d45SEmilio Cota 
160035553d45SEmilio Cota LogicalResult
160135553d45SEmilio Cota RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
160235553d45SEmilio Cota                                     PatternRewriter &rewriter) const {
160362fea88bSJacques Pienaar   if (!getElementTypeOrSelf(op.getOperand()).isF32())
1604ec32d540SEugene Zhulenev     return rewriter.notifyMatchFailure(op, "unsupported operand type");
1605ec32d540SEugene Zhulenev 
1606*72f36217SKunwar Grover   std::optional<VectorShape> shape = vectorShape(op.getOperand());
1607ec32d540SEugene Zhulenev 
160835553d45SEmilio Cota   // Only support already-vectorized rsqrt's.
1609*72f36217SKunwar Grover   if (!shape || shape->sizes.empty() || shape->sizes.back() % 8 != 0)
161035553d45SEmilio Cota     return rewriter.notifyMatchFailure(op, "unsupported operand type");
161135553d45SEmilio Cota 
161235553d45SEmilio Cota   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
161335553d45SEmilio Cota   auto bcast = [&](Value value) -> Value {
1614ec32d540SEugene Zhulenev     return broadcast(builder, value, shape);
161535553d45SEmilio Cota   };
161635553d45SEmilio Cota 
161735553d45SEmilio Cota   Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u));
161835553d45SEmilio Cota   Value cstOnePointFive = bcast(f32Cst(builder, 1.5f));
161935553d45SEmilio Cota   Value cstNegHalf = bcast(f32Cst(builder, -0.5f));
162035553d45SEmilio Cota   Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u));
162135553d45SEmilio Cota 
162262fea88bSJacques Pienaar   Value negHalf = builder.create<arith::MulFOp>(op.getOperand(), cstNegHalf);
162335553d45SEmilio Cota 
162435553d45SEmilio Cota   // Select only the inverse sqrt of positive normals (denormals are
162535553d45SEmilio Cota   // flushed to zero).
162662fea88bSJacques Pienaar   Value ltMinMask = builder.create<arith::CmpFOp>(
162762fea88bSJacques Pienaar       arith::CmpFPredicate::OLT, op.getOperand(), cstMinNormPos);
162835553d45SEmilio Cota   Value infMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
162962fea88bSJacques Pienaar                                                 op.getOperand(), cstPosInf);
163035553d45SEmilio Cota   Value notNormalFiniteMask = builder.create<arith::OrIOp>(ltMinMask, infMask);
163135553d45SEmilio Cota 
163235553d45SEmilio Cota   // Compute an approximate result.
1633627fa0b9SEugene Zhulenev   Value yApprox = handleMultidimensionalVectors(
1634627fa0b9SEugene Zhulenev       builder, op->getOperands(), 8, [&builder](ValueRange operands) -> Value {
1635627fa0b9SEugene Zhulenev         return builder.create<x86vector::RsqrtOp>(operands);
1636627fa0b9SEugene Zhulenev       });
163735553d45SEmilio Cota 
163835553d45SEmilio Cota   // Do a single step of Newton-Raphson iteration to improve the approximation.
163935553d45SEmilio Cota   // This uses the formula y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n).
164035553d45SEmilio Cota   // It is essential to evaluate the inner term like this because forming
164135553d45SEmilio Cota   // y_n^2 may over- or underflow.
164235553d45SEmilio Cota   Value inner = builder.create<arith::MulFOp>(negHalf, yApprox);
164335553d45SEmilio Cota   Value fma = builder.create<math::FmaOp>(yApprox, inner, cstOnePointFive);
164435553d45SEmilio Cota   Value yNewton = builder.create<arith::MulFOp>(yApprox, fma);
164535553d45SEmilio Cota 
164635553d45SEmilio Cota   // Select the result of the Newton-Raphson step for positive normal arguments.
164735553d45SEmilio Cota   // For other arguments, choose the output of the intrinsic. This will
164835553d45SEmilio Cota   // return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(x) = +inf if
164935553d45SEmilio Cota   // x is zero or a positive denormalized float (equivalent to flushing positive
165035553d45SEmilio Cota   // denormalized inputs to zero).
1651dec8af70SRiver Riddle   Value res =
1652dec8af70SRiver Riddle       builder.create<arith::SelectOp>(notNormalFiniteMask, yApprox, yNewton);
165335553d45SEmilio Cota   rewriter.replaceOp(op, res);
165435553d45SEmilio Cota 
165535553d45SEmilio Cota   return success();
165635553d45SEmilio Cota }
165735553d45SEmilio Cota 
165835553d45SEmilio Cota //----------------------------------------------------------------------------//
1659f99ccf65SEugene Zhulenev 
1660bcf9826aSJohannes Reifferscheid void mlir::populatePolynomialApproximateTanhPattern(
1661bcf9826aSJohannes Reifferscheid     RewritePatternSet &patterns) {
1662bcf9826aSJohannes Reifferscheid   patterns.add<TanhApproximation>(patterns.getContext());
1663bcf9826aSJohannes Reifferscheid }
1664bcf9826aSJohannes Reifferscheid 
1665bcf9826aSJohannes Reifferscheid void mlir::populatePolynomialApproximateErfPattern(
1666bcf9826aSJohannes Reifferscheid     RewritePatternSet &patterns) {
1667bcf9826aSJohannes Reifferscheid   patterns.add<ErfPolynomialApproximation>(patterns.getContext());
1668bcf9826aSJohannes Reifferscheid }
1669bcf9826aSJohannes Reifferscheid 
1670f99ccf65SEugene Zhulenev void mlir::populateMathPolynomialApproximationPatterns(
167135553d45SEmilio Cota     RewritePatternSet &patterns,
167235553d45SEmilio Cota     const MathPolynomialApproximationOptions &options) {
167357e1943eSRobert Suderman   // Patterns for leveraging existing f32 lowerings on other data types.
167457e1943eSRobert Suderman   patterns
167557e1943eSRobert Suderman       .add<ReuseF32Expansion<math::AtanOp>, ReuseF32Expansion<math::Atan2Op>,
167657e1943eSRobert Suderman            ReuseF32Expansion<math::TanhOp>, ReuseF32Expansion<math::LogOp>,
167757e1943eSRobert Suderman            ReuseF32Expansion<math::Log2Op>, ReuseF32Expansion<math::Log1pOp>,
167857e1943eSRobert Suderman            ReuseF32Expansion<math::ErfOp>, ReuseF32Expansion<math::ExpOp>,
167957e1943eSRobert Suderman            ReuseF32Expansion<math::ExpM1Op>, ReuseF32Expansion<math::CbrtOp>,
168057e1943eSRobert Suderman            ReuseF32Expansion<math::SinOp>, ReuseF32Expansion<math::CosOp>>(
168157e1943eSRobert Suderman           patterns.getContext());
168257e1943eSRobert Suderman 
168372085698SPrashant Kumar   patterns
168472085698SPrashant Kumar       .add<AtanApproximation, Atan2Approximation, TanhApproximation,
16852f9f9afaSRob Suderman            LogApproximation, Log2Approximation, Log1pApproximation,
168672085698SPrashant Kumar            ErfPolynomialApproximation, AsinPolynomialApproximation,
168772085698SPrashant Kumar            AcosPolynomialApproximation, ExpApproximation, ExpM1Approximation,
168857e1943eSRobert Suderman            CbrtApproximation, SinAndCosApproximation<true, math::SinOp>,
168972085698SPrashant Kumar            SinAndCosApproximation<false, math::CosOp>>(patterns.getContext());
169057e1943eSRobert Suderman   if (options.enableAvx2) {
169157e1943eSRobert Suderman     patterns.add<RsqrtApproximation, ReuseF32Expansion<math::RsqrtOp>>(
169257e1943eSRobert Suderman         patterns.getContext());
169357e1943eSRobert Suderman   }
1694f99ccf65SEugene Zhulenev }
1695