1f99ccf65SEugene Zhulenev //===- PolynomialApproximation.cpp - Approximate math operations ----------===// 2f99ccf65SEugene Zhulenev // 3f99ccf65SEugene Zhulenev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4f99ccf65SEugene Zhulenev // See https://llvm.org/LICENSE.txt for license information. 5f99ccf65SEugene Zhulenev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6f99ccf65SEugene Zhulenev // 7f99ccf65SEugene Zhulenev //===----------------------------------------------------------------------===// 8f99ccf65SEugene Zhulenev // 9f99ccf65SEugene Zhulenev // This file implements expansion of math operations to fast approximations 10f99ccf65SEugene Zhulenev // that do not rely on any of the library functions. 11f99ccf65SEugene Zhulenev // 12f99ccf65SEugene Zhulenev //===----------------------------------------------------------------------===// 133a506b31SChris Lattner 14ec32d540SEugene Zhulenev #include <climits> 150bedb667SRobert Suderman #include <cmath> 16ec32d540SEugene Zhulenev #include <cstddef> 17ec32d540SEugene Zhulenev 18abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h" 19f99ccf65SEugene Zhulenev #include "mlir/Dialect/Math/IR/Math.h" 20f1b92218SBoian Petkantchin #include "mlir/Dialect/Math/Transforms/Approximation.h" 21f99ccf65SEugene Zhulenev #include "mlir/Dialect/Math/Transforms/Passes.h" 2299ef9eebSMatthias Springer #include "mlir/Dialect/Utils/IndexingUtils.h" 2399ef9eebSMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h" 2499ef9eebSMatthias Springer #include "mlir/Dialect/Vector/Utils/VectorUtils.h" 2535553d45SEmilio Cota #include "mlir/Dialect/X86Vector/X86VectorDialect.h" 26f99ccf65SEugene Zhulenev #include "mlir/IR/Builders.h" 27bbddd19eSJacques Pienaar #include "mlir/IR/BuiltinTypes.h" 28ce976d2dSEugene Zhulenev #include "mlir/IR/ImplicitLocOpBuilder.h" 29bbddd19eSJacques Pienaar #include "mlir/IR/OpDefinition.h" 30bbddd19eSJacques Pienaar #include "mlir/IR/PatternMatch.h" 
31ec32d540SEugene Zhulenev #include "mlir/IR/TypeUtilities.h" 32f99ccf65SEugene Zhulenev #include "mlir/Transforms/DialectConversion.h" 33f99ccf65SEugene Zhulenev #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 34f1b92218SBoian Petkantchin #include "llvm/ADT/ArrayRef.h" 35bbddd19eSJacques Pienaar #include "llvm/ADT/STLExtras.h" 3675bd41f6SJacques Pienaar #include "llvm/Support/MathExtras.h" 37f99ccf65SEugene Zhulenev 38f99ccf65SEugene Zhulenev using namespace mlir; 39f1b92218SBoian Petkantchin using namespace mlir::math; 40f99ccf65SEugene Zhulenev using namespace mlir::vector; 41f99ccf65SEugene Zhulenev 42e74bcecdSBenjamin Maxwell // Helper to encapsulate a vector's shape (including scalable dims). 43e74bcecdSBenjamin Maxwell struct VectorShape { 44e74bcecdSBenjamin Maxwell ArrayRef<int64_t> sizes; 45e74bcecdSBenjamin Maxwell ArrayRef<bool> scalableFlags; 46e74bcecdSBenjamin Maxwell }; 47e74bcecdSBenjamin Maxwell 48*72f36217SKunwar Grover // Returns vector shape if the type is a vector, otherwise return nullopt. 49*72f36217SKunwar Grover static std::optional<VectorShape> vectorShape(Type type) { 50*72f36217SKunwar Grover if (auto vectorType = dyn_cast<VectorType>(type)) { 51*72f36217SKunwar Grover return VectorShape{vectorType.getShape(), vectorType.getScalableDims()}; 52*72f36217SKunwar Grover } 53*72f36217SKunwar Grover return std::nullopt; 54ce976d2dSEugene Zhulenev } 55ce976d2dSEugene Zhulenev 56*72f36217SKunwar Grover static std::optional<VectorShape> vectorShape(Value value) { 57ec32d540SEugene Zhulenev return vectorShape(value.getType()); 5839b2cd40SEugene Zhulenev } 5939b2cd40SEugene Zhulenev 60f99ccf65SEugene Zhulenev //----------------------------------------------------------------------------// 61ce976d2dSEugene Zhulenev // Broadcast scalar types and values into vector types and values. 
62f99ccf65SEugene Zhulenev //----------------------------------------------------------------------------// 63f99ccf65SEugene Zhulenev 6496cee297SAlexander Belyaev // Broadcasts scalar type into vector type (iff shape is non-scalar). 65*72f36217SKunwar Grover static Type broadcast(Type type, std::optional<VectorShape> shape) { 665550c821STres Popp assert(!isa<VectorType>(type) && "must be scalar type"); 67*72f36217SKunwar Grover return shape ? VectorType::get(shape->sizes, type, shape->scalableFlags) 68e74bcecdSBenjamin Maxwell : type; 6996cee297SAlexander Belyaev } 7096cee297SAlexander Belyaev 7196cee297SAlexander Belyaev // Broadcasts scalar value into vector (iff shape is non-scalar). 7296cee297SAlexander Belyaev static Value broadcast(ImplicitLocOpBuilder &builder, Value value, 73*72f36217SKunwar Grover std::optional<VectorShape> shape) { 745550c821STres Popp assert(!isa<VectorType>(value.getType()) && "must be scalar value"); 7596cee297SAlexander Belyaev auto type = broadcast(value.getType(), shape); 76*72f36217SKunwar Grover return shape ? builder.create<BroadcastOp>(type, value) : value; 77ce976d2dSEugene Zhulenev } 78f99ccf65SEugene Zhulenev 79ce976d2dSEugene Zhulenev //----------------------------------------------------------------------------// 80627fa0b9SEugene Zhulenev // Helper function to handle n-D vectors with 1-D operations. 81627fa0b9SEugene Zhulenev //----------------------------------------------------------------------------// 82627fa0b9SEugene Zhulenev 83627fa0b9SEugene Zhulenev // Expands and unrolls n-D vector operands into multiple fixed size 1-D vectors 84627fa0b9SEugene Zhulenev // and calls the compute function with 1-D vector operands. Stitches back all 85627fa0b9SEugene Zhulenev // results into the original n-D vector result. 
86627fa0b9SEugene Zhulenev // 87627fa0b9SEugene Zhulenev // Examples: vectorWidth = 8 88627fa0b9SEugene Zhulenev // - vector<4x8xf32> unrolled 4 times 89627fa0b9SEugene Zhulenev // - vector<16xf32> expanded to vector<2x8xf32> and unrolled 2 times 90627fa0b9SEugene Zhulenev // - vector<4x16xf32> expanded to vector<4x2x8xf32> and unrolled 4*2 times 91627fa0b9SEugene Zhulenev // 92627fa0b9SEugene Zhulenev // Some math approximations rely on ISA-specific operations that only accept 93627fa0b9SEugene Zhulenev // fixed size 1-D vectors (e.g. AVX expects vectors of width 8). 94627fa0b9SEugene Zhulenev // 95627fa0b9SEugene Zhulenev // It is the caller's responsibility to verify that the inner dimension is 96627fa0b9SEugene Zhulenev // divisible by the vectorWidth, and that all operands have the same vector 97627fa0b9SEugene Zhulenev // shape. 98627fa0b9SEugene Zhulenev static Value 99627fa0b9SEugene Zhulenev handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, 100627fa0b9SEugene Zhulenev ValueRange operands, int64_t vectorWidth, 1011fc096afSMehdi Amini llvm::function_ref<Value(ValueRange)> compute) { 102627fa0b9SEugene Zhulenev assert(!operands.empty() && "operands must be not empty"); 103627fa0b9SEugene Zhulenev assert(vectorWidth > 0 && "vector width must be larger than 0"); 104627fa0b9SEugene Zhulenev 1055550c821STres Popp VectorType inputType = cast<VectorType>(operands[0].getType()); 106627fa0b9SEugene Zhulenev ArrayRef<int64_t> inputShape = inputType.getShape(); 107627fa0b9SEugene Zhulenev 108627fa0b9SEugene Zhulenev // If input shape matches target vector width, we can just call the 109627fa0b9SEugene Zhulenev // user-provided compute function with the operands. 
110984b800aSserge-sans-paille if (inputShape == llvm::ArrayRef(vectorWidth)) 111627fa0b9SEugene Zhulenev return compute(operands); 112627fa0b9SEugene Zhulenev 113627fa0b9SEugene Zhulenev // Check if the inner dimension has to be expanded, or we can directly iterate 114627fa0b9SEugene Zhulenev // over the outer dimensions of the vector. 115627fa0b9SEugene Zhulenev int64_t innerDim = inputShape.back(); 116627fa0b9SEugene Zhulenev int64_t expansionDim = innerDim / vectorWidth; 117627fa0b9SEugene Zhulenev assert((innerDim % vectorWidth == 0) && "invalid inner dimension size"); 118627fa0b9SEugene Zhulenev 119627fa0b9SEugene Zhulenev // Maybe expand operands to the higher rank vector shape that we'll use to 120627fa0b9SEugene Zhulenev // iterate over and extract one dimensional vectors. 1215262865aSKazu Hirata SmallVector<int64_t> expandedShape(inputShape); 122627fa0b9SEugene Zhulenev SmallVector<Value> expandedOperands(operands); 123627fa0b9SEugene Zhulenev 124627fa0b9SEugene Zhulenev if (expansionDim > 1) { 125627fa0b9SEugene Zhulenev // Expand shape from [..., innerDim] to [..., expansionDim, vectorWidth]. 126627fa0b9SEugene Zhulenev expandedShape.insert(expandedShape.end() - 1, expansionDim); 127627fa0b9SEugene Zhulenev expandedShape.back() = vectorWidth; 128627fa0b9SEugene Zhulenev 129627fa0b9SEugene Zhulenev for (unsigned i = 0; i < operands.size(); ++i) { 130627fa0b9SEugene Zhulenev auto operand = operands[i]; 1315550c821STres Popp auto eltType = cast<VectorType>(operand.getType()).getElementType(); 132627fa0b9SEugene Zhulenev auto expandedType = VectorType::get(expandedShape, eltType); 133627fa0b9SEugene Zhulenev expandedOperands[i] = 134627fa0b9SEugene Zhulenev builder.create<vector::ShapeCastOp>(expandedType, operand); 135627fa0b9SEugene Zhulenev } 136627fa0b9SEugene Zhulenev } 137627fa0b9SEugene Zhulenev 138627fa0b9SEugene Zhulenev // Iterate over all outer dimensions of the compute shape vector type. 
139627fa0b9SEugene Zhulenev auto iterationDims = ArrayRef<int64_t>(expandedShape).drop_back(); 1407a69a9d7SNicolas Vasilache int64_t maxIndex = computeMaxLinearIndex(iterationDims); 1417a69a9d7SNicolas Vasilache auto strides = computeStrides(iterationDims); 142627fa0b9SEugene Zhulenev 143627fa0b9SEugene Zhulenev // Compute results for each one dimensional vector. 1447a69a9d7SNicolas Vasilache SmallVector<Value> results(maxIndex); 145627fa0b9SEugene Zhulenev 1467a69a9d7SNicolas Vasilache for (int64_t i = 0; i < maxIndex; ++i) { 147203fad47SNicolas Vasilache auto offsets = delinearize(i, strides); 148627fa0b9SEugene Zhulenev 149627fa0b9SEugene Zhulenev SmallVector<Value> extracted(expandedOperands.size()); 15089de9cc8SMehdi Amini for (const auto &tuple : llvm::enumerate(expandedOperands)) 151627fa0b9SEugene Zhulenev extracted[tuple.index()] = 152627fa0b9SEugene Zhulenev builder.create<vector::ExtractOp>(tuple.value(), offsets); 153627fa0b9SEugene Zhulenev 154627fa0b9SEugene Zhulenev results[i] = compute(extracted); 155627fa0b9SEugene Zhulenev } 156627fa0b9SEugene Zhulenev 157627fa0b9SEugene Zhulenev // Stitch results together into one large vector. 1585550c821STres Popp Type resultEltType = cast<VectorType>(results[0].getType()).getElementType(); 159627fa0b9SEugene Zhulenev Type resultExpandedType = VectorType::get(expandedShape, resultEltType); 1608e123ca6SRiver Riddle Value result = builder.create<arith::ConstantOp>( 161627fa0b9SEugene Zhulenev resultExpandedType, builder.getZeroAttr(resultExpandedType)); 162627fa0b9SEugene Zhulenev 1637a69a9d7SNicolas Vasilache for (int64_t i = 0; i < maxIndex; ++i) 164627fa0b9SEugene Zhulenev result = builder.create<vector::InsertOp>(results[i], result, 165203fad47SNicolas Vasilache delinearize(i, strides)); 166627fa0b9SEugene Zhulenev 167627fa0b9SEugene Zhulenev // Reshape back to the original vector shape. 
168627fa0b9SEugene Zhulenev return builder.create<vector::ShapeCastOp>( 169627fa0b9SEugene Zhulenev VectorType::get(inputShape, resultEltType), result); 170627fa0b9SEugene Zhulenev } 171627fa0b9SEugene Zhulenev 172627fa0b9SEugene Zhulenev //----------------------------------------------------------------------------// 173ce976d2dSEugene Zhulenev // Helper functions to create constants. 174ce976d2dSEugene Zhulenev //----------------------------------------------------------------------------// 175f99ccf65SEugene Zhulenev 176880b8f4eSPrashant Kumar static Value floatCst(ImplicitLocOpBuilder &builder, float value, 177880b8f4eSPrashant Kumar Type elementType) { 17875076f8dSKazu Hirata assert((elementType.isF16() || elementType.isF32()) && 17975076f8dSKazu Hirata "x must be f16 or f32 type."); 180880b8f4eSPrashant Kumar return builder.create<arith::ConstantOp>( 181880b8f4eSPrashant Kumar builder.getFloatAttr(elementType, value)); 182880b8f4eSPrashant Kumar } 183880b8f4eSPrashant Kumar 1840bedb667SRobert Suderman static Value f32Cst(ImplicitLocOpBuilder &builder, double value) { 185a54f4eaeSMogball return builder.create<arith::ConstantOp>(builder.getF32FloatAttr(value)); 186ce976d2dSEugene Zhulenev } 187f99ccf65SEugene Zhulenev 188ce976d2dSEugene Zhulenev static Value i32Cst(ImplicitLocOpBuilder &builder, int32_t value) { 189a54f4eaeSMogball return builder.create<arith::ConstantOp>(builder.getI32IntegerAttr(value)); 190ce976d2dSEugene Zhulenev } 191f99ccf65SEugene Zhulenev 192ce976d2dSEugene Zhulenev static Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits) { 193ce976d2dSEugene Zhulenev Value i32Value = i32Cst(builder, static_cast<int32_t>(bits)); 194a54f4eaeSMogball return builder.create<arith::BitcastOp>(builder.getF32Type(), i32Value); 195ce976d2dSEugene Zhulenev } 196f99ccf65SEugene Zhulenev 197ce976d2dSEugene Zhulenev //----------------------------------------------------------------------------// 198ce976d2dSEugene Zhulenev // Helper functions to build 
math functions approximations. 199ce976d2dSEugene Zhulenev //----------------------------------------------------------------------------// 200ce976d2dSEugene Zhulenev 201f5efe280STres Popp // Return the minimum of the two values or NaN if value is NaN 202f5efe280STres Popp static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound) { 203dec8af70SRiver Riddle return builder.create<arith::SelectOp>( 204f5efe280STres Popp builder.create<arith::CmpFOp>(arith::CmpFPredicate::ULT, value, bound), 205f5efe280STres Popp value, bound); 206ce976d2dSEugene Zhulenev } 207ce976d2dSEugene Zhulenev 208f5efe280STres Popp // Return the maximum of the two values or NaN if value is NaN 209f5efe280STres Popp static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound) { 210dec8af70SRiver Riddle return builder.create<arith::SelectOp>( 211f5efe280STres Popp builder.create<arith::CmpFOp>(arith::CmpFPredicate::UGT, value, bound), 212f5efe280STres Popp value, bound); 213ce976d2dSEugene Zhulenev } 214ce976d2dSEugene Zhulenev 215f5efe280STres Popp // Return the clamped value or NaN if value is NaN 216ce976d2dSEugene Zhulenev static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, 217ce976d2dSEugene Zhulenev Value upperBound) { 218ce976d2dSEugene Zhulenev return max(builder, min(builder, value, upperBound), lowerBound); 219ce976d2dSEugene Zhulenev } 220ce976d2dSEugene Zhulenev 221ce976d2dSEugene Zhulenev // Decomposes given floating point value `arg` into a normalized fraction and 222ce976d2dSEugene Zhulenev // an integral power of two (see std::frexp). Returned values have float type. 
223ce976d2dSEugene Zhulenev static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg, 22402b6fb21SMehdi Amini bool isPositive = false) { 225ec32d540SEugene Zhulenev assert(getElementTypeOrSelf(arg).isF32() && "arg must be f32 type"); 226*72f36217SKunwar Grover std::optional<VectorShape> shape = vectorShape(arg); 227ce976d2dSEugene Zhulenev 228ce976d2dSEugene Zhulenev auto bcast = [&](Value value) -> Value { 22996cee297SAlexander Belyaev return broadcast(builder, value, shape); 230f99ccf65SEugene Zhulenev }; 231f99ccf65SEugene Zhulenev 232ce976d2dSEugene Zhulenev auto i32 = builder.getIntegerType(32); 23396cee297SAlexander Belyaev auto i32Vec = broadcast(i32, shape); 23496cee297SAlexander Belyaev auto f32Vec = broadcast(builder.getF32Type(), shape); 235f99ccf65SEugene Zhulenev 236ce976d2dSEugene Zhulenev Value cst126f = f32Cst(builder, 126.0f); 237ce976d2dSEugene Zhulenev Value cstHalf = f32Cst(builder, 0.5f); 238ce976d2dSEugene Zhulenev Value cstInvMantMask = f32FromBits(builder, ~0x7f800000u); 239f99ccf65SEugene Zhulenev 240ce976d2dSEugene Zhulenev // Bitcast to i32 for bitwise operations. 241a54f4eaeSMogball Value i32Half = builder.create<arith::BitcastOp>(i32, cstHalf); 242a54f4eaeSMogball Value i32InvMantMask = builder.create<arith::BitcastOp>(i32, cstInvMantMask); 243a54f4eaeSMogball Value i32Arg = builder.create<arith::BitcastOp>(i32Vec, arg); 244f99ccf65SEugene Zhulenev 245ce976d2dSEugene Zhulenev // Compute normalized fraction. 246a54f4eaeSMogball Value tmp0 = builder.create<arith::AndIOp>(i32Arg, bcast(i32InvMantMask)); 247a54f4eaeSMogball Value tmp1 = builder.create<arith::OrIOp>(tmp0, bcast(i32Half)); 248a54f4eaeSMogball Value normalizedFraction = builder.create<arith::BitcastOp>(f32Vec, tmp1); 249f99ccf65SEugene Zhulenev 250ce976d2dSEugene Zhulenev // Compute exponent. 25100f7096dSJeff Niu Value arg0 = isPositive ? 
arg : builder.create<math::AbsFOp>(arg); 252a54f4eaeSMogball Value biasedExponentBits = builder.create<arith::ShRUIOp>( 253a54f4eaeSMogball builder.create<arith::BitcastOp>(i32Vec, arg0), 254a54f4eaeSMogball bcast(i32Cst(builder, 23))); 255a54f4eaeSMogball Value biasedExponent = 256a54f4eaeSMogball builder.create<arith::SIToFPOp>(f32Vec, biasedExponentBits); 257a54f4eaeSMogball Value exponent = 258a54f4eaeSMogball builder.create<arith::SubFOp>(biasedExponent, bcast(cst126f)); 259f99ccf65SEugene Zhulenev 260ce976d2dSEugene Zhulenev return {normalizedFraction, exponent}; 261f99ccf65SEugene Zhulenev } 262f99ccf65SEugene Zhulenev 263ea7f211bSAhmed Taei // Computes exp2 for an i32 argument. 264ea7f211bSAhmed Taei static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) { 265ec32d540SEugene Zhulenev assert(getElementTypeOrSelf(arg).isInteger(32) && "arg must be i32 type"); 266*72f36217SKunwar Grover std::optional<VectorShape> shape = vectorShape(arg); 267ea7f211bSAhmed Taei 268ea7f211bSAhmed Taei auto bcast = [&](Value value) -> Value { 26996cee297SAlexander Belyaev return broadcast(builder, value, shape); 270ea7f211bSAhmed Taei }; 271ea7f211bSAhmed Taei 27296cee297SAlexander Belyaev auto f32Vec = broadcast(builder.getF32Type(), shape); 273ea7f211bSAhmed Taei // The exponent of f32 located at 23-bit. 274ea7f211bSAhmed Taei auto exponetBitLocation = bcast(i32Cst(builder, 23)); 275ea7f211bSAhmed Taei // Set the exponent bias to zero. 
276ea7f211bSAhmed Taei auto bias = bcast(i32Cst(builder, 127)); 277ea7f211bSAhmed Taei 278a54f4eaeSMogball Value biasedArg = builder.create<arith::AddIOp>(arg, bias); 279ea7f211bSAhmed Taei Value exp2ValueInt = 280a54f4eaeSMogball builder.create<arith::ShLIOp>(biasedArg, exponetBitLocation); 281a54f4eaeSMogball Value exp2ValueF32 = builder.create<arith::BitcastOp>(f32Vec, exp2ValueInt); 282ea7f211bSAhmed Taei 283ea7f211bSAhmed Taei return exp2ValueF32; 284ea7f211bSAhmed Taei } 285ea7f211bSAhmed Taei 286f1b92218SBoian Petkantchin namespace { 287f1b92218SBoian Petkantchin Value makePolynomialCalculation(ImplicitLocOpBuilder &builder, 288f1b92218SBoian Petkantchin llvm::ArrayRef<Value> coeffs, Value x) { 289880b8f4eSPrashant Kumar Type elementType = getElementTypeOrSelf(x); 29075076f8dSKazu Hirata assert((elementType.isF32() || elementType.isF16()) && 29175076f8dSKazu Hirata "x must be f32 or f16 type"); 292*72f36217SKunwar Grover std::optional<VectorShape> shape = vectorShape(x); 293ec32d540SEugene Zhulenev 294ec32d540SEugene Zhulenev if (coeffs.empty()) 295880b8f4eSPrashant Kumar return broadcast(builder, floatCst(builder, 0.0f, elementType), shape); 296ec32d540SEugene Zhulenev 297ec32d540SEugene Zhulenev if (coeffs.size() == 1) 298f1b92218SBoian Petkantchin return coeffs[0]; 299ec32d540SEugene Zhulenev 300f1b92218SBoian Petkantchin Value res = builder.create<math::FmaOp>(x, coeffs[coeffs.size() - 1], 301f1b92218SBoian Petkantchin coeffs[coeffs.size() - 2]); 302f1b92218SBoian Petkantchin for (auto i = ptrdiff_t(coeffs.size()) - 3; i >= 0; --i) { 303f1b92218SBoian Petkantchin res = builder.create<math::FmaOp>(x, res, coeffs[i]); 304f1b92218SBoian Petkantchin } 305f1b92218SBoian Petkantchin return res; 306f1b92218SBoian Petkantchin } 307f1b92218SBoian Petkantchin } // namespace 308f1b92218SBoian Petkantchin 309f99ccf65SEugene Zhulenev //----------------------------------------------------------------------------// 310bbddd19eSJacques Pienaar // Helper function/pattern 
to insert casts for reusing F32 bit expansion. 311bbddd19eSJacques Pienaar //----------------------------------------------------------------------------// 312bbddd19eSJacques Pienaar 313bbddd19eSJacques Pienaar template <typename T> 314bbddd19eSJacques Pienaar LogicalResult insertCasts(Operation *op, PatternRewriter &rewriter) { 315bbddd19eSJacques Pienaar // Conservatively only allow where the operand and result types are exactly 1. 316bbddd19eSJacques Pienaar Type origType = op->getResultTypes().front(); 317bbddd19eSJacques Pienaar for (Type t : llvm::drop_begin(op->getResultTypes())) 318bbddd19eSJacques Pienaar if (origType != t) 319bbddd19eSJacques Pienaar return rewriter.notifyMatchFailure(op, "required all types to match"); 320bbddd19eSJacques Pienaar for (Type t : op->getOperandTypes()) 321bbddd19eSJacques Pienaar if (origType != t) 322bbddd19eSJacques Pienaar return rewriter.notifyMatchFailure(op, "required all types to match"); 323bbddd19eSJacques Pienaar 324bbddd19eSJacques Pienaar // Skip if already F32 or larger than 32 bits. 325bbddd19eSJacques Pienaar if (getElementTypeOrSelf(origType).isF32() || 326bbddd19eSJacques Pienaar getElementTypeOrSelf(origType).getIntOrFloatBitWidth() > 32) 327bbddd19eSJacques Pienaar return failure(); 328bbddd19eSJacques Pienaar 329bbddd19eSJacques Pienaar // Create F32 equivalent type. 
330bbddd19eSJacques Pienaar Type newType; 3315550c821STres Popp if (auto shaped = dyn_cast<ShapedType>(origType)) { 332bbddd19eSJacques Pienaar newType = shaped.clone(rewriter.getF32Type()); 3335550c821STres Popp } else if (isa<FloatType>(origType)) { 334bbddd19eSJacques Pienaar newType = rewriter.getF32Type(); 335bbddd19eSJacques Pienaar } else { 336bbddd19eSJacques Pienaar return rewriter.notifyMatchFailure(op, 337bbddd19eSJacques Pienaar "unable to find F32 equivalent type"); 338bbddd19eSJacques Pienaar } 339bbddd19eSJacques Pienaar 340bbddd19eSJacques Pienaar Location loc = op->getLoc(); 341bbddd19eSJacques Pienaar SmallVector<Value> operands; 342bbddd19eSJacques Pienaar for (auto operand : op->getOperands()) 343bbddd19eSJacques Pienaar operands.push_back(rewriter.create<arith::ExtFOp>(loc, newType, operand)); 34457e1943eSRobert Suderman auto result = 34557e1943eSRobert Suderman rewriter.create<T>(loc, TypeRange{newType}, operands, op->getAttrs()); 346bbddd19eSJacques Pienaar rewriter.replaceOpWithNewOp<arith::TruncFOp>(op, origType, result); 347bbddd19eSJacques Pienaar return success(); 348bbddd19eSJacques Pienaar } 349bbddd19eSJacques Pienaar 350bbddd19eSJacques Pienaar namespace { 351bbddd19eSJacques Pienaar // Pattern to cast to F32 to reuse F32 expansion as fallback for single-result 352bbddd19eSJacques Pienaar // op. 353bbddd19eSJacques Pienaar // TODO: Consider revising to avoid adding multiple casts for a subgraph that is 354bbddd19eSJacques Pienaar // all in lower precision. Currently this is only fallback support and performs 355bbddd19eSJacques Pienaar // simplistic casting. 
356bbddd19eSJacques Pienaar template <typename T> 357bbddd19eSJacques Pienaar struct ReuseF32Expansion : public OpRewritePattern<T> { 358bbddd19eSJacques Pienaar public: 359bbddd19eSJacques Pienaar using OpRewritePattern<T>::OpRewritePattern; 360bbddd19eSJacques Pienaar LogicalResult matchAndRewrite(T op, PatternRewriter &rewriter) const final { 361bbddd19eSJacques Pienaar static_assert( 362bbddd19eSJacques Pienaar T::template hasTrait<mlir::OpTrait::SameOperandsAndResultType>(), 363bbddd19eSJacques Pienaar "requires same operands and result types"); 364bbddd19eSJacques Pienaar return insertCasts<T>(op, rewriter); 365bbddd19eSJacques Pienaar } 366bbddd19eSJacques Pienaar }; 367bbddd19eSJacques Pienaar } // namespace 368bbddd19eSJacques Pienaar 369bbddd19eSJacques Pienaar //----------------------------------------------------------------------------// 3702f9f9afaSRob Suderman // AtanOp approximation. 3712f9f9afaSRob Suderman //----------------------------------------------------------------------------// 3722f9f9afaSRob Suderman 3732f9f9afaSRob Suderman namespace { 3742f9f9afaSRob Suderman struct AtanApproximation : public OpRewritePattern<math::AtanOp> { 3752f9f9afaSRob Suderman public: 3762f9f9afaSRob Suderman using OpRewritePattern::OpRewritePattern; 3772f9f9afaSRob Suderman 3782f9f9afaSRob Suderman LogicalResult matchAndRewrite(math::AtanOp op, 3792f9f9afaSRob Suderman PatternRewriter &rewriter) const final; 3802f9f9afaSRob Suderman }; 3812f9f9afaSRob Suderman } // namespace 3822f9f9afaSRob Suderman 3832f9f9afaSRob Suderman LogicalResult 3842f9f9afaSRob Suderman AtanApproximation::matchAndRewrite(math::AtanOp op, 3852f9f9afaSRob Suderman PatternRewriter &rewriter) const { 3862f9f9afaSRob Suderman auto operand = op.getOperand(); 3872f9f9afaSRob Suderman if (!getElementTypeOrSelf(operand).isF32()) 3882f9f9afaSRob Suderman return rewriter.notifyMatchFailure(op, "unsupported operand type"); 3892f9f9afaSRob Suderman 390*72f36217SKunwar Grover 
std::optional<VectorShape> shape = vectorShape(op.getOperand()); 3912f9f9afaSRob Suderman 3922f9f9afaSRob Suderman ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 39300f7096dSJeff Niu Value abs = builder.create<math::AbsFOp>(operand); 3940bedb667SRobert Suderman 3950bedb667SRobert Suderman auto one = broadcast(builder, f32Cst(builder, 1.0), shape); 3960bedb667SRobert Suderman 3970bedb667SRobert Suderman // When 0.66 < x <= 2.41 we do (x-1) / (x+1): 3980bedb667SRobert Suderman auto twoThirds = broadcast(builder, f32Cst(builder, 0.66), shape); 3990bedb667SRobert Suderman Value cmp2 = 4000bedb667SRobert Suderman builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, abs, twoThirds); 4010bedb667SRobert Suderman Value addone = builder.create<arith::AddFOp>(abs, one); 4020bedb667SRobert Suderman Value subone = builder.create<arith::SubFOp>(abs, one); 4030bedb667SRobert Suderman Value xnum = builder.create<arith::SelectOp>(cmp2, subone, abs); 4040bedb667SRobert Suderman Value xden = builder.create<arith::SelectOp>(cmp2, addone, one); 4050bedb667SRobert Suderman 4060bedb667SRobert Suderman auto bcast = [&](Value value) -> Value { 4070bedb667SRobert Suderman return broadcast(builder, value, shape); 4080bedb667SRobert Suderman }; 4090bedb667SRobert Suderman 4100bedb667SRobert Suderman // Break into the <= 0.66 or > 2.41 we do x or 1/x: 4110bedb667SRobert Suderman auto tan3pio8 = bcast(f32Cst(builder, 2.41421356237309504880)); 4120bedb667SRobert Suderman Value cmp1 = 4130bedb667SRobert Suderman builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, abs, tan3pio8); 4140bedb667SRobert Suderman xnum = builder.create<arith::SelectOp>(cmp1, one, xnum); 4150bedb667SRobert Suderman xden = builder.create<arith::SelectOp>(cmp1, abs, xden); 4160bedb667SRobert Suderman 4170bedb667SRobert Suderman Value x = builder.create<arith::DivFOp>(xnum, xden); 4180bedb667SRobert Suderman Value xx = builder.create<arith::MulFOp>(x, x); 4192f9f9afaSRob Suderman 4202f9f9afaSRob Suderman // 
Perform the Taylor series approximation for atan over the range 4210bedb667SRobert Suderman // [0.0, 0.66]. 4220bedb667SRobert Suderman auto p0 = bcast(f32Cst(builder, -8.750608600031904122785e-01)); 4230bedb667SRobert Suderman auto p1 = bcast(f32Cst(builder, -1.615753718733365076637e+01)); 4240bedb667SRobert Suderman auto p2 = bcast(f32Cst(builder, -7.500855792314704667340e+01)); 4250bedb667SRobert Suderman auto p3 = bcast(f32Cst(builder, -1.228866684490136173410e+02)); 4260bedb667SRobert Suderman auto p4 = bcast(f32Cst(builder, -6.485021904942025371773e+01)); 4270bedb667SRobert Suderman auto q0 = bcast(f32Cst(builder, +2.485846490142306297962e+01)); 4280bedb667SRobert Suderman auto q1 = bcast(f32Cst(builder, +1.650270098316988542046e+02)); 4290bedb667SRobert Suderman auto q2 = bcast(f32Cst(builder, +4.328810604912902668951e+02)); 4300bedb667SRobert Suderman auto q3 = bcast(f32Cst(builder, +4.853903996359136964868e+02)); 4310bedb667SRobert Suderman auto q4 = bcast(f32Cst(builder, +1.945506571482613964425e+02)); 4322f9f9afaSRob Suderman 4330bedb667SRobert Suderman // Apply the polynomial approximation for the numerator: 4340bedb667SRobert Suderman Value n = p0; 4350bedb667SRobert Suderman n = builder.create<math::FmaOp>(xx, n, p1); 4360bedb667SRobert Suderman n = builder.create<math::FmaOp>(xx, n, p2); 4370bedb667SRobert Suderman n = builder.create<math::FmaOp>(xx, n, p3); 4380bedb667SRobert Suderman n = builder.create<math::FmaOp>(xx, n, p4); 4390bedb667SRobert Suderman n = builder.create<arith::MulFOp>(n, xx); 4402f9f9afaSRob Suderman 4410bedb667SRobert Suderman // Apply the polynomial approximation for the denominator: 4420bedb667SRobert Suderman Value d = q0; 4430bedb667SRobert Suderman d = builder.create<math::FmaOp>(xx, d, q1); 4440bedb667SRobert Suderman d = builder.create<math::FmaOp>(xx, d, q2); 4450bedb667SRobert Suderman d = builder.create<math::FmaOp>(xx, d, q3); 4460bedb667SRobert Suderman d = builder.create<math::FmaOp>(xx, d, q4); 4470bedb667SRobert 
Suderman 4480bedb667SRobert Suderman // Compute approximation of theta: 4490bedb667SRobert Suderman Value ans0 = builder.create<arith::DivFOp>(n, d); 4500bedb667SRobert Suderman ans0 = builder.create<math::FmaOp>(ans0, x, x); 4510bedb667SRobert Suderman 4520bedb667SRobert Suderman // Correct for the input mapping's angles: 45375bd41f6SJacques Pienaar Value mpi4 = bcast(f32Cst(builder, llvm::numbers::pi / 4)); 4540bedb667SRobert Suderman Value ans2 = builder.create<arith::AddFOp>(mpi4, ans0); 4550bedb667SRobert Suderman Value ans = builder.create<arith::SelectOp>(cmp2, ans2, ans0); 4560bedb667SRobert Suderman 45775bd41f6SJacques Pienaar Value mpi2 = bcast(f32Cst(builder, llvm::numbers::pi / 2)); 4580bedb667SRobert Suderman Value ans1 = builder.create<arith::SubFOp>(mpi2, ans0); 4590bedb667SRobert Suderman ans = builder.create<arith::SelectOp>(cmp1, ans1, ans); 4602f9f9afaSRob Suderman 4612f9f9afaSRob Suderman // Correct for signing of the input. 4620bedb667SRobert Suderman rewriter.replaceOpWithNewOp<math::CopySignOp>(op, ans, operand); 4632f9f9afaSRob Suderman return success(); 4642f9f9afaSRob Suderman } 4652f9f9afaSRob Suderman 4662f9f9afaSRob Suderman //----------------------------------------------------------------------------// 4672f9f9afaSRob Suderman // AtanOp approximation. 
4682f9f9afaSRob Suderman //----------------------------------------------------------------------------// 4692f9f9afaSRob Suderman 4702f9f9afaSRob Suderman namespace { 4712f9f9afaSRob Suderman struct Atan2Approximation : public OpRewritePattern<math::Atan2Op> { 4722f9f9afaSRob Suderman public: 4732f9f9afaSRob Suderman using OpRewritePattern::OpRewritePattern; 4742f9f9afaSRob Suderman 4752f9f9afaSRob Suderman LogicalResult matchAndRewrite(math::Atan2Op op, 4762f9f9afaSRob Suderman PatternRewriter &rewriter) const final; 4772f9f9afaSRob Suderman }; 4782f9f9afaSRob Suderman } // namespace 4792f9f9afaSRob Suderman 4802f9f9afaSRob Suderman LogicalResult 4812f9f9afaSRob Suderman Atan2Approximation::matchAndRewrite(math::Atan2Op op, 4822f9f9afaSRob Suderman PatternRewriter &rewriter) const { 4832f9f9afaSRob Suderman auto y = op.getOperand(0); 4842f9f9afaSRob Suderman auto x = op.getOperand(1); 4852f9f9afaSRob Suderman if (!getElementTypeOrSelf(x).isF32()) 4862f9f9afaSRob Suderman return rewriter.notifyMatchFailure(op, "unsupported operand type"); 4872f9f9afaSRob Suderman 4882f9f9afaSRob Suderman ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 489*72f36217SKunwar Grover std::optional<VectorShape> shape = vectorShape(op.getResult()); 4902f9f9afaSRob Suderman 4912f9f9afaSRob Suderman // Compute atan in the valid range. 4922f9f9afaSRob Suderman auto div = builder.create<arith::DivFOp>(y, x); 4932f9f9afaSRob Suderman auto atan = builder.create<math::AtanOp>(div); 4942f9f9afaSRob Suderman 4952f9f9afaSRob Suderman // Determine what the atan would be for a 180 degree rotation. 
4962f9f9afaSRob Suderman auto zero = broadcast(builder, f32Cst(builder, 0.0f), shape); 4972f9f9afaSRob Suderman auto pi = broadcast(builder, f32Cst(builder, 3.14159265359f), shape); 49870ed93ecSMehdi Amini auto addPi = builder.create<arith::AddFOp>(atan, pi); 49970ed93ecSMehdi Amini auto subPi = builder.create<arith::SubFOp>(atan, pi); 50070ed93ecSMehdi Amini auto atanGt = 5012f9f9afaSRob Suderman builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, atan, zero); 502dec8af70SRiver Riddle auto flippedAtan = builder.create<arith::SelectOp>(atanGt, subPi, addPi); 5032f9f9afaSRob Suderman 5042f9f9afaSRob Suderman // Determine whether to directly use atan or use the 180 degree flip 50570ed93ecSMehdi Amini auto xGt = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, x, zero); 506dec8af70SRiver Riddle Value result = builder.create<arith::SelectOp>(xGt, atan, flippedAtan); 5072f9f9afaSRob Suderman 5082f9f9afaSRob Suderman // Handle x = 0, y > 0 50970ed93ecSMehdi Amini Value xZero = 5102f9f9afaSRob Suderman builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, x, zero); 51170ed93ecSMehdi Amini Value yGt = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, y, zero); 51270ed93ecSMehdi Amini Value isHalfPi = builder.create<arith::AndIOp>(xZero, yGt); 51370ed93ecSMehdi Amini auto halfPi = broadcast(builder, f32Cst(builder, 1.57079632679f), shape); 514dec8af70SRiver Riddle result = builder.create<arith::SelectOp>(isHalfPi, halfPi, result); 5152f9f9afaSRob Suderman 5162f9f9afaSRob Suderman // Handle x = 0, y < 0 51770ed93ecSMehdi Amini Value yLt = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, y, zero); 51870ed93ecSMehdi Amini Value isNegativeHalfPiPi = builder.create<arith::AndIOp>(xZero, yLt); 51970ed93ecSMehdi Amini auto negativeHalfPiPi = 520dc3b9365SAlexandre Ganea broadcast(builder, f32Cst(builder, -1.57079632679f), shape); 521dec8af70SRiver Riddle result = builder.create<arith::SelectOp>(isNegativeHalfPiPi, negativeHalfPiPi, 522dec8af70SRiver 
Riddle result); 5232f9f9afaSRob Suderman 5242f9f9afaSRob Suderman // Handle x = 0, y = 0; 52570ed93ecSMehdi Amini Value yZero = 5262f9f9afaSRob Suderman builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, y, zero); 52770ed93ecSMehdi Amini Value isNan = builder.create<arith::AndIOp>(xZero, yZero); 52870ed93ecSMehdi Amini Value cstNan = broadcast(builder, f32FromBits(builder, 0x7fc00000), shape); 529dec8af70SRiver Riddle result = builder.create<arith::SelectOp>(isNan, cstNan, result); 5302f9f9afaSRob Suderman 5312f9f9afaSRob Suderman rewriter.replaceOp(op, result); 5322f9f9afaSRob Suderman return success(); 5332f9f9afaSRob Suderman } 5342f9f9afaSRob Suderman 5352f9f9afaSRob Suderman //----------------------------------------------------------------------------// 536f99ccf65SEugene Zhulenev // TanhOp approximation. 537f99ccf65SEugene Zhulenev //----------------------------------------------------------------------------// 538f99ccf65SEugene Zhulenev 539f99ccf65SEugene Zhulenev namespace { 540f99ccf65SEugene Zhulenev struct TanhApproximation : public OpRewritePattern<math::TanhOp> { 541f99ccf65SEugene Zhulenev public: 542f99ccf65SEugene Zhulenev using OpRewritePattern::OpRewritePattern; 543f99ccf65SEugene Zhulenev 544f99ccf65SEugene Zhulenev LogicalResult matchAndRewrite(math::TanhOp op, 545f99ccf65SEugene Zhulenev PatternRewriter &rewriter) const final; 546f99ccf65SEugene Zhulenev }; 547f99ccf65SEugene Zhulenev } // namespace 548f99ccf65SEugene Zhulenev 549f99ccf65SEugene Zhulenev LogicalResult 550f99ccf65SEugene Zhulenev TanhApproximation::matchAndRewrite(math::TanhOp op, 551f99ccf65SEugene Zhulenev PatternRewriter &rewriter) const { 55262fea88bSJacques Pienaar if (!getElementTypeOrSelf(op.getOperand()).isF32()) 553f99ccf65SEugene Zhulenev return rewriter.notifyMatchFailure(op, "unsupported operand type"); 554f99ccf65SEugene Zhulenev 555*72f36217SKunwar Grover std::optional<VectorShape> shape = vectorShape(op.getOperand()); 556ec32d540SEugene Zhulenev 
557ce976d2dSEugene Zhulenev ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 558ce976d2dSEugene Zhulenev auto bcast = [&](Value value) -> Value { 559ec32d540SEugene Zhulenev return broadcast(builder, value, shape); 560ce976d2dSEugene Zhulenev }; 561f99ccf65SEugene Zhulenev 562f99ccf65SEugene Zhulenev // Clamp operand into [plusClamp, minusClamp] range. 563bf32bb7eSEugene Zhulenev Value minusClamp = bcast(f32Cst(builder, -7.99881172180175781f)); 564bf32bb7eSEugene Zhulenev Value plusClamp = bcast(f32Cst(builder, 7.99881172180175781f)); 56562fea88bSJacques Pienaar Value x = clamp(builder, op.getOperand(), minusClamp, plusClamp); 566f99ccf65SEugene Zhulenev 567f99ccf65SEugene Zhulenev // Mask for tiny values that are approximated with `operand`. 568ce976d2dSEugene Zhulenev Value tiny = bcast(f32Cst(builder, 0.0004f)); 569a54f4eaeSMogball Value tinyMask = builder.create<arith::CmpFOp>( 57000f7096dSJeff Niu arith::CmpFPredicate::OLT, builder.create<math::AbsFOp>(op.getOperand()), 571a54f4eaeSMogball tiny); 572f99ccf65SEugene Zhulenev 573f99ccf65SEugene Zhulenev // The monomial coefficients of the numerator polynomial (odd). 574ce976d2dSEugene Zhulenev Value alpha1 = bcast(f32Cst(builder, 4.89352455891786e-03f)); 575ce976d2dSEugene Zhulenev Value alpha3 = bcast(f32Cst(builder, 6.37261928875436e-04f)); 576ce976d2dSEugene Zhulenev Value alpha5 = bcast(f32Cst(builder, 1.48572235717979e-05f)); 577ce976d2dSEugene Zhulenev Value alpha7 = bcast(f32Cst(builder, 5.12229709037114e-08f)); 578ce976d2dSEugene Zhulenev Value alpha9 = bcast(f32Cst(builder, -8.60467152213735e-11f)); 579ce976d2dSEugene Zhulenev Value alpha11 = bcast(f32Cst(builder, 2.00018790482477e-13f)); 580ce976d2dSEugene Zhulenev Value alpha13 = bcast(f32Cst(builder, -2.76076847742355e-16f)); 581f99ccf65SEugene Zhulenev 582f99ccf65SEugene Zhulenev // The monomial coefficients of the denominator polynomial (even). 
583ce976d2dSEugene Zhulenev Value beta0 = bcast(f32Cst(builder, 4.89352518554385e-03f)); 584ce976d2dSEugene Zhulenev Value beta2 = bcast(f32Cst(builder, 2.26843463243900e-03f)); 585ce976d2dSEugene Zhulenev Value beta4 = bcast(f32Cst(builder, 1.18534705686654e-04f)); 586ce976d2dSEugene Zhulenev Value beta6 = bcast(f32Cst(builder, 1.19825839466702e-06f)); 587f99ccf65SEugene Zhulenev 588f99ccf65SEugene Zhulenev // Since the polynomials are odd/even, we need x^2. 589a54f4eaeSMogball Value x2 = builder.create<arith::MulFOp>(x, x); 590f99ccf65SEugene Zhulenev 591f99ccf65SEugene Zhulenev // Evaluate the numerator polynomial p. 592a54f4eaeSMogball Value p = builder.create<math::FmaOp>(x2, alpha13, alpha11); 593a54f4eaeSMogball p = builder.create<math::FmaOp>(x2, p, alpha9); 594a54f4eaeSMogball p = builder.create<math::FmaOp>(x2, p, alpha7); 595a54f4eaeSMogball p = builder.create<math::FmaOp>(x2, p, alpha5); 596a54f4eaeSMogball p = builder.create<math::FmaOp>(x2, p, alpha3); 597a54f4eaeSMogball p = builder.create<math::FmaOp>(x2, p, alpha1); 598a54f4eaeSMogball p = builder.create<arith::MulFOp>(x, p); 599f99ccf65SEugene Zhulenev 600f99ccf65SEugene Zhulenev // Evaluate the denominator polynomial q. 601a54f4eaeSMogball Value q = builder.create<math::FmaOp>(x2, beta6, beta4); 602a54f4eaeSMogball q = builder.create<math::FmaOp>(x2, q, beta2); 603a54f4eaeSMogball q = builder.create<math::FmaOp>(x2, q, beta0); 604f99ccf65SEugene Zhulenev 605f99ccf65SEugene Zhulenev // Divide the numerator by the denominator. 
606dec8af70SRiver Riddle Value res = builder.create<arith::SelectOp>( 607dec8af70SRiver Riddle tinyMask, x, builder.create<arith::DivFOp>(p, q)); 608f99ccf65SEugene Zhulenev 609f99ccf65SEugene Zhulenev rewriter.replaceOp(op, res); 610f99ccf65SEugene Zhulenev 611f99ccf65SEugene Zhulenev return success(); 612f99ccf65SEugene Zhulenev } 613f99ccf65SEugene Zhulenev 614ea7f211bSAhmed Taei #define LN2_VALUE \ 615ea7f211bSAhmed Taei 0.693147180559945309417232121458176568075500134360255254120680009493393621L 616c0891706SEmilio Cota #define LOG2E_VALUE \ 617ea7f211bSAhmed Taei 1.442695040888963407359924681001892137426645954152985934135449406931109219L 618ea7f211bSAhmed Taei 619f99ccf65SEugene Zhulenev //----------------------------------------------------------------------------// 620c0891706SEmilio Cota // LogOp and Log2Op approximation. 621ce976d2dSEugene Zhulenev //----------------------------------------------------------------------------// 622ce976d2dSEugene Zhulenev 623ce976d2dSEugene Zhulenev namespace { 624c0891706SEmilio Cota template <typename Op> 625c0891706SEmilio Cota struct LogApproximationBase : public OpRewritePattern<Op> { 626c0891706SEmilio Cota using OpRewritePattern<Op>::OpRewritePattern; 627ce976d2dSEugene Zhulenev 628c0891706SEmilio Cota /// Base 2 if 'base2' is set; natural logarithm (base e) otherwise. 629c0891706SEmilio Cota LogicalResult logMatchAndRewrite(Op op, PatternRewriter &rewriter, 630c0891706SEmilio Cota bool base2) const; 631ce976d2dSEugene Zhulenev }; 632ce976d2dSEugene Zhulenev } // namespace 633ce976d2dSEugene Zhulenev 634c0891706SEmilio Cota // This approximation comes from Julien Pommier's SSE math library. 
635c0891706SEmilio Cota // Link: http://gruntthepeon.free.fr/ssemath 636c0891706SEmilio Cota template <typename Op> 637ce976d2dSEugene Zhulenev LogicalResult 638c0891706SEmilio Cota LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter, 639c0891706SEmilio Cota bool base2) const { 64062fea88bSJacques Pienaar if (!getElementTypeOrSelf(op.getOperand()).isF32()) 641ce976d2dSEugene Zhulenev return rewriter.notifyMatchFailure(op, "unsupported operand type"); 642ce976d2dSEugene Zhulenev 643*72f36217SKunwar Grover std::optional<VectorShape> shape = vectorShape(op.getOperand()); 644ec32d540SEugene Zhulenev 645ce976d2dSEugene Zhulenev ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 646ce976d2dSEugene Zhulenev auto bcast = [&](Value value) -> Value { 647ec32d540SEugene Zhulenev return broadcast(builder, value, shape); 648ce976d2dSEugene Zhulenev }; 649ce976d2dSEugene Zhulenev 650ce976d2dSEugene Zhulenev Value cstZero = bcast(f32Cst(builder, 0.0f)); 651ce976d2dSEugene Zhulenev Value cstOne = bcast(f32Cst(builder, 1.0f)); 652ce976d2dSEugene Zhulenev Value cstNegHalf = bcast(f32Cst(builder, -0.5f)); 653ce976d2dSEugene Zhulenev 654ce976d2dSEugene Zhulenev // The smallest non denormalized float number. 655ce976d2dSEugene Zhulenev Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u)); 656ce976d2dSEugene Zhulenev Value cstMinusInf = bcast(f32FromBits(builder, 0xff800000u)); 657ce976d2dSEugene Zhulenev Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u)); 658ce976d2dSEugene Zhulenev Value cstNan = bcast(f32FromBits(builder, 0x7fc00000)); 659ce976d2dSEugene Zhulenev 660ce976d2dSEugene Zhulenev // Polynomial coefficients. 
661ce976d2dSEugene Zhulenev Value cstCephesSQRTHF = bcast(f32Cst(builder, 0.707106781186547524f)); 662ce976d2dSEugene Zhulenev Value cstCephesLogP0 = bcast(f32Cst(builder, 7.0376836292E-2f)); 663ce976d2dSEugene Zhulenev Value cstCephesLogP1 = bcast(f32Cst(builder, -1.1514610310E-1f)); 664ce976d2dSEugene Zhulenev Value cstCephesLogP2 = bcast(f32Cst(builder, 1.1676998740E-1f)); 665ce976d2dSEugene Zhulenev Value cstCephesLogP3 = bcast(f32Cst(builder, -1.2420140846E-1f)); 666ce976d2dSEugene Zhulenev Value cstCephesLogP4 = bcast(f32Cst(builder, +1.4249322787E-1f)); 667ce976d2dSEugene Zhulenev Value cstCephesLogP5 = bcast(f32Cst(builder, -1.6668057665E-1f)); 668ce976d2dSEugene Zhulenev Value cstCephesLogP6 = bcast(f32Cst(builder, +2.0000714765E-1f)); 669ce976d2dSEugene Zhulenev Value cstCephesLogP7 = bcast(f32Cst(builder, -2.4999993993E-1f)); 670ce976d2dSEugene Zhulenev Value cstCephesLogP8 = bcast(f32Cst(builder, +3.3333331174E-1f)); 671ce976d2dSEugene Zhulenev 67262fea88bSJacques Pienaar Value x = op.getOperand(); 673ce976d2dSEugene Zhulenev 674ce976d2dSEugene Zhulenev // Truncate input values to the minimum positive normal. 675ce976d2dSEugene Zhulenev x = max(builder, x, cstMinNormPos); 676ce976d2dSEugene Zhulenev 677ce976d2dSEugene Zhulenev // Extract significant in the range [0.5,1) and exponent. 678ced8690dSMehdi Amini std::pair<Value, Value> pair = frexp(builder, x, /*isPositive=*/true); 679ce976d2dSEugene Zhulenev x = pair.first; 680ce976d2dSEugene Zhulenev Value e = pair.second; 681ce976d2dSEugene Zhulenev 682ce976d2dSEugene Zhulenev // Shift the inputs from the range [0.5,1) to [sqrt(1/2), sqrt(2)) and shift 683ce976d2dSEugene Zhulenev // by -1.0. 
The values are then centered around 0, which improves the 684ce976d2dSEugene Zhulenev // stability of the polynomial evaluation: 685ce976d2dSEugene Zhulenev // 686ce976d2dSEugene Zhulenev // if( x < SQRTHF ) { 687ce976d2dSEugene Zhulenev // e -= 1; 688ce976d2dSEugene Zhulenev // x = x + x - 1.0; 689ce976d2dSEugene Zhulenev // } else { x = x - 1.0; } 690a54f4eaeSMogball Value mask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, 691a54f4eaeSMogball cstCephesSQRTHF); 692dec8af70SRiver Riddle Value tmp = builder.create<arith::SelectOp>(mask, x, cstZero); 693ce976d2dSEugene Zhulenev 694a54f4eaeSMogball x = builder.create<arith::SubFOp>(x, cstOne); 695a54f4eaeSMogball e = builder.create<arith::SubFOp>( 696dec8af70SRiver Riddle e, builder.create<arith::SelectOp>(mask, cstOne, cstZero)); 697a54f4eaeSMogball x = builder.create<arith::AddFOp>(x, tmp); 698ce976d2dSEugene Zhulenev 699a54f4eaeSMogball Value x2 = builder.create<arith::MulFOp>(x, x); 700a54f4eaeSMogball Value x3 = builder.create<arith::MulFOp>(x2, x); 701ce976d2dSEugene Zhulenev 702ce976d2dSEugene Zhulenev // Evaluate the polynomial approximant of degree 8 in three parts. 
703ce976d2dSEugene Zhulenev Value y0, y1, y2; 704a54f4eaeSMogball y0 = builder.create<math::FmaOp>(cstCephesLogP0, x, cstCephesLogP1); 705a54f4eaeSMogball y1 = builder.create<math::FmaOp>(cstCephesLogP3, x, cstCephesLogP4); 706a54f4eaeSMogball y2 = builder.create<math::FmaOp>(cstCephesLogP6, x, cstCephesLogP7); 707a54f4eaeSMogball y0 = builder.create<math::FmaOp>(y0, x, cstCephesLogP2); 708a54f4eaeSMogball y1 = builder.create<math::FmaOp>(y1, x, cstCephesLogP5); 709a54f4eaeSMogball y2 = builder.create<math::FmaOp>(y2, x, cstCephesLogP8); 710a54f4eaeSMogball y0 = builder.create<math::FmaOp>(y0, x3, y1); 711a54f4eaeSMogball y0 = builder.create<math::FmaOp>(y0, x3, y2); 712a54f4eaeSMogball y0 = builder.create<arith::MulFOp>(y0, x3); 713ce976d2dSEugene Zhulenev 714a54f4eaeSMogball y0 = builder.create<math::FmaOp>(cstNegHalf, x2, y0); 715a54f4eaeSMogball x = builder.create<arith::AddFOp>(x, y0); 716ce976d2dSEugene Zhulenev 717c0891706SEmilio Cota if (base2) { 718c0891706SEmilio Cota Value cstLog2e = bcast(f32Cst(builder, static_cast<float>(LOG2E_VALUE))); 719a54f4eaeSMogball x = builder.create<math::FmaOp>(x, cstLog2e, e); 720c0891706SEmilio Cota } else { 721ce976d2dSEugene Zhulenev Value cstLn2 = bcast(f32Cst(builder, static_cast<float>(LN2_VALUE))); 722a54f4eaeSMogball x = builder.create<math::FmaOp>(e, cstLn2, x); 723c0891706SEmilio Cota } 724ce976d2dSEugene Zhulenev 725a54f4eaeSMogball Value invalidMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::ULT, 72662fea88bSJacques Pienaar op.getOperand(), cstZero); 727a54f4eaeSMogball Value zeroMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, 72862fea88bSJacques Pienaar op.getOperand(), cstZero); 729a54f4eaeSMogball Value posInfMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, 73062fea88bSJacques Pienaar op.getOperand(), cstPosInf); 731ce976d2dSEugene Zhulenev 732ce976d2dSEugene Zhulenev // Filter out invalid values: 733ce976d2dSEugene Zhulenev // • x == 0 -> -INF 734ce976d2dSEugene 
Zhulenev // • x < 0 -> NAN 735ce976d2dSEugene Zhulenev // • x == +INF -> +INF 736dec8af70SRiver Riddle Value aproximation = builder.create<arith::SelectOp>( 737ce976d2dSEugene Zhulenev zeroMask, cstMinusInf, 738dec8af70SRiver Riddle builder.create<arith::SelectOp>( 739ce976d2dSEugene Zhulenev invalidMask, cstNan, 740dec8af70SRiver Riddle builder.create<arith::SelectOp>(posInfMask, cstPosInf, x))); 741ce976d2dSEugene Zhulenev 742ce976d2dSEugene Zhulenev rewriter.replaceOp(op, aproximation); 743ce976d2dSEugene Zhulenev 744ce976d2dSEugene Zhulenev return success(); 745ce976d2dSEugene Zhulenev } 746ce976d2dSEugene Zhulenev 747c0891706SEmilio Cota namespace { 748c0891706SEmilio Cota struct LogApproximation : public LogApproximationBase<math::LogOp> { 749c0891706SEmilio Cota using LogApproximationBase::LogApproximationBase; 750c0891706SEmilio Cota 751c0891706SEmilio Cota LogicalResult matchAndRewrite(math::LogOp op, 752c0891706SEmilio Cota PatternRewriter &rewriter) const final { 753c0891706SEmilio Cota return logMatchAndRewrite(op, rewriter, /*base2=*/false); 754c0891706SEmilio Cota } 755c0891706SEmilio Cota }; 756c0891706SEmilio Cota } // namespace 757c0891706SEmilio Cota 758c0891706SEmilio Cota namespace { 759c0891706SEmilio Cota struct Log2Approximation : public LogApproximationBase<math::Log2Op> { 760c0891706SEmilio Cota using LogApproximationBase::LogApproximationBase; 761c0891706SEmilio Cota 762c0891706SEmilio Cota LogicalResult matchAndRewrite(math::Log2Op op, 763c0891706SEmilio Cota PatternRewriter &rewriter) const final { 764c0891706SEmilio Cota return logMatchAndRewrite(op, rewriter, /*base2=*/true); 765c0891706SEmilio Cota } 766c0891706SEmilio Cota }; 767c0891706SEmilio Cota } // namespace 768c0891706SEmilio Cota 769ce976d2dSEugene Zhulenev //----------------------------------------------------------------------------// 7701c0374e7SEmilio Cota // Log1p approximation. 
7711c0374e7SEmilio Cota //----------------------------------------------------------------------------// 7721c0374e7SEmilio Cota 7731c0374e7SEmilio Cota namespace { 7741c0374e7SEmilio Cota struct Log1pApproximation : public OpRewritePattern<math::Log1pOp> { 7751c0374e7SEmilio Cota public: 7761c0374e7SEmilio Cota using OpRewritePattern::OpRewritePattern; 7771c0374e7SEmilio Cota 7781c0374e7SEmilio Cota LogicalResult matchAndRewrite(math::Log1pOp op, 7791c0374e7SEmilio Cota PatternRewriter &rewriter) const final; 7801c0374e7SEmilio Cota }; 7811c0374e7SEmilio Cota } // namespace 7821c0374e7SEmilio Cota 7831c0374e7SEmilio Cota // Approximate log(1+x). 7841c0374e7SEmilio Cota LogicalResult 7851c0374e7SEmilio Cota Log1pApproximation::matchAndRewrite(math::Log1pOp op, 7861c0374e7SEmilio Cota PatternRewriter &rewriter) const { 78762fea88bSJacques Pienaar if (!getElementTypeOrSelf(op.getOperand()).isF32()) 7881c0374e7SEmilio Cota return rewriter.notifyMatchFailure(op, "unsupported operand type"); 7891c0374e7SEmilio Cota 790*72f36217SKunwar Grover std::optional<VectorShape> shape = vectorShape(op.getOperand()); 791ec32d540SEugene Zhulenev 7921c0374e7SEmilio Cota ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 7931c0374e7SEmilio Cota auto bcast = [&](Value value) -> Value { 794ec32d540SEugene Zhulenev return broadcast(builder, value, shape); 7951c0374e7SEmilio Cota }; 7961c0374e7SEmilio Cota 7971c0374e7SEmilio Cota // Approximate log(1+x) using the following, due to W. Kahan: 7981c0374e7SEmilio Cota // u = x + 1.0; 7991c0374e7SEmilio Cota // if (u == 1.0 || u == inf) return x; 8001c0374e7SEmilio Cota // return x * log(u) / (u - 1.0); 8011c0374e7SEmilio Cota // ^^^^^^^^^^^^^^^^^^^^^^ 8021c0374e7SEmilio Cota // "logLarge" below. 
8031c0374e7SEmilio Cota Value cstOne = bcast(f32Cst(builder, 1.0f)); 80462fea88bSJacques Pienaar Value x = op.getOperand(); 805a54f4eaeSMogball Value u = builder.create<arith::AddFOp>(x, cstOne); 806a54f4eaeSMogball Value uSmall = 807a54f4eaeSMogball builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, cstOne); 8081c0374e7SEmilio Cota Value logU = builder.create<math::LogOp>(u); 809a54f4eaeSMogball Value uInf = 810a54f4eaeSMogball builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, logU); 811a54f4eaeSMogball Value logLarge = builder.create<arith::MulFOp>( 812a54f4eaeSMogball x, builder.create<arith::DivFOp>( 813a54f4eaeSMogball logU, builder.create<arith::SubFOp>(u, cstOne))); 814dec8af70SRiver Riddle Value approximation = builder.create<arith::SelectOp>( 815a54f4eaeSMogball builder.create<arith::OrIOp>(uSmall, uInf), x, logLarge); 8161c0374e7SEmilio Cota rewriter.replaceOp(op, approximation); 8171c0374e7SEmilio Cota return success(); 8181c0374e7SEmilio Cota } 8191c0374e7SEmilio Cota 8201c0374e7SEmilio Cota //----------------------------------------------------------------------------// 82172085698SPrashant Kumar // Asin approximation. 82272085698SPrashant Kumar //----------------------------------------------------------------------------// 82372085698SPrashant Kumar 82472085698SPrashant Kumar // Approximates asin(x). 
82572085698SPrashant Kumar // This approximation is based on the following stackoverflow post: 82672085698SPrashant Kumar // https://stackoverflow.com/a/42683455 82772085698SPrashant Kumar namespace { 82872085698SPrashant Kumar struct AsinPolynomialApproximation : public OpRewritePattern<math::AsinOp> { 82972085698SPrashant Kumar public: 83072085698SPrashant Kumar using OpRewritePattern::OpRewritePattern; 83172085698SPrashant Kumar 83272085698SPrashant Kumar LogicalResult matchAndRewrite(math::AsinOp op, 83372085698SPrashant Kumar PatternRewriter &rewriter) const final; 83472085698SPrashant Kumar }; 83572085698SPrashant Kumar } // namespace 83672085698SPrashant Kumar LogicalResult 83772085698SPrashant Kumar AsinPolynomialApproximation::matchAndRewrite(math::AsinOp op, 83872085698SPrashant Kumar PatternRewriter &rewriter) const { 83972085698SPrashant Kumar Value operand = op.getOperand(); 84072085698SPrashant Kumar Type elementType = getElementTypeOrSelf(operand); 84172085698SPrashant Kumar 84272085698SPrashant Kumar if (!(elementType.isF32() || elementType.isF16())) 84372085698SPrashant Kumar return rewriter.notifyMatchFailure(op, 84472085698SPrashant Kumar "only f32 and f16 type is supported."); 845*72f36217SKunwar Grover std::optional<VectorShape> shape = vectorShape(operand); 84672085698SPrashant Kumar 84772085698SPrashant Kumar ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 84872085698SPrashant Kumar auto bcast = [&](Value value) -> Value { 84972085698SPrashant Kumar return broadcast(builder, value, shape); 85072085698SPrashant Kumar }; 85172085698SPrashant Kumar 85272085698SPrashant Kumar auto fma = [&](Value a, Value b, Value c) -> Value { 85372085698SPrashant Kumar return builder.create<math::FmaOp>(a, b, c); 85472085698SPrashant Kumar }; 85572085698SPrashant Kumar 85672085698SPrashant Kumar auto mul = [&](Value a, Value b) -> Value { 85772085698SPrashant Kumar return builder.create<arith::MulFOp>(a, b); 85872085698SPrashant Kumar }; 
85972085698SPrashant Kumar 860a3fb3014SRob Suderman auto sub = [&](Value a, Value b) -> Value { 861a3fb3014SRob Suderman return builder.create<arith::SubFOp>(a, b); 862a3fb3014SRob Suderman }; 863a3fb3014SRob Suderman 864a3fb3014SRob Suderman auto abs = [&](Value a) -> Value { return builder.create<math::AbsFOp>(a); }; 865a3fb3014SRob Suderman 866a3fb3014SRob Suderman auto sqrt = [&](Value a) -> Value { return builder.create<math::SqrtOp>(a); }; 867a3fb3014SRob Suderman 868a3fb3014SRob Suderman auto scopy = [&](Value a, Value b) -> Value { 869a3fb3014SRob Suderman return builder.create<math::CopySignOp>(a, b); 870a3fb3014SRob Suderman }; 871a3fb3014SRob Suderman 872a3fb3014SRob Suderman auto sel = [&](Value a, Value b, Value c) -> Value { 873a3fb3014SRob Suderman return builder.create<arith::SelectOp>(a, b, c); 874a3fb3014SRob Suderman }; 875a3fb3014SRob Suderman 876a3fb3014SRob Suderman Value abso = abs(operand); 877a3fb3014SRob Suderman Value aa = mul(operand, operand); 878a3fb3014SRob Suderman Value opp = sqrt(sub(bcast(floatCst(builder, 1.0, elementType)), aa)); 879a3fb3014SRob Suderman 880a3fb3014SRob Suderman Value gt = 881a3fb3014SRob Suderman builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, aa, 882a3fb3014SRob Suderman bcast(floatCst(builder, 0.5, elementType))); 883a3fb3014SRob Suderman 884a3fb3014SRob Suderman Value x = sel(gt, opp, abso); 885a3fb3014SRob Suderman 886a3fb3014SRob Suderman // Asin(x) approximation for x = [-9/16, 9/16]: 887a3fb3014SRob Suderman Value s = mul(x, x); 88872085698SPrashant Kumar Value q = mul(s, s); 88972085698SPrashant Kumar Value r = bcast(floatCst(builder, 5.5579749017470502e-2, elementType)); 89072085698SPrashant Kumar Value t = bcast(floatCst(builder, -6.2027913464120114e-2, elementType)); 89172085698SPrashant Kumar 89272085698SPrashant Kumar r = fma(r, q, bcast(floatCst(builder, 5.4224464349245036e-2, elementType))); 89372085698SPrashant Kumar t = fma(t, q, bcast(floatCst(builder, -1.1326992890324464e-2, 
elementType))); 89472085698SPrashant Kumar r = fma(r, q, bcast(floatCst(builder, 1.5268872539397656e-2, elementType))); 89572085698SPrashant Kumar t = fma(t, q, bcast(floatCst(builder, 1.0493798473372081e-2, elementType))); 89672085698SPrashant Kumar r = fma(r, q, bcast(floatCst(builder, 1.4106045900607047e-2, elementType))); 89772085698SPrashant Kumar t = fma(t, q, bcast(floatCst(builder, 1.7339776384962050e-2, elementType))); 89872085698SPrashant Kumar r = fma(r, q, bcast(floatCst(builder, 2.2372961589651054e-2, elementType))); 89972085698SPrashant Kumar t = fma(t, q, bcast(floatCst(builder, 3.0381912707941005e-2, elementType))); 90072085698SPrashant Kumar r = fma(r, q, bcast(floatCst(builder, 4.4642857881094775e-2, elementType))); 90172085698SPrashant Kumar t = fma(t, q, bcast(floatCst(builder, 7.4999999991367292e-2, elementType))); 90272085698SPrashant Kumar r = fma(r, s, t); 90372085698SPrashant Kumar r = fma(r, s, bcast(floatCst(builder, 1.6666666666670193e-1, elementType))); 904a3fb3014SRob Suderman t = mul(x, s); 905a3fb3014SRob Suderman r = fma(r, t, x); 906a3fb3014SRob Suderman 907a3fb3014SRob Suderman Value rsub = sub(bcast(floatCst(builder, 1.57079632679, elementType)), r); 908a3fb3014SRob Suderman r = sel(gt, rsub, r); 909a3fb3014SRob Suderman r = scopy(r, operand); 91072085698SPrashant Kumar 91172085698SPrashant Kumar rewriter.replaceOp(op, r); 91272085698SPrashant Kumar return success(); 91372085698SPrashant Kumar } 91472085698SPrashant Kumar 91572085698SPrashant Kumar //----------------------------------------------------------------------------// 91672085698SPrashant Kumar // Acos approximation. 91772085698SPrashant Kumar //----------------------------------------------------------------------------// 91872085698SPrashant Kumar 91972085698SPrashant Kumar // Approximates acos(x). 
92072085698SPrashant Kumar // This approximation is based on the following stackoverflow post: 92172085698SPrashant Kumar // https://stackoverflow.com/a/42683455 92272085698SPrashant Kumar namespace { 92372085698SPrashant Kumar struct AcosPolynomialApproximation : public OpRewritePattern<math::AcosOp> { 92472085698SPrashant Kumar public: 92572085698SPrashant Kumar using OpRewritePattern::OpRewritePattern; 92672085698SPrashant Kumar 92772085698SPrashant Kumar LogicalResult matchAndRewrite(math::AcosOp op, 92872085698SPrashant Kumar PatternRewriter &rewriter) const final; 92972085698SPrashant Kumar }; 93072085698SPrashant Kumar } // namespace 93172085698SPrashant Kumar LogicalResult 93272085698SPrashant Kumar AcosPolynomialApproximation::matchAndRewrite(math::AcosOp op, 93372085698SPrashant Kumar PatternRewriter &rewriter) const { 93472085698SPrashant Kumar Value operand = op.getOperand(); 93572085698SPrashant Kumar Type elementType = getElementTypeOrSelf(operand); 93672085698SPrashant Kumar 93772085698SPrashant Kumar if (!(elementType.isF32() || elementType.isF16())) 93872085698SPrashant Kumar return rewriter.notifyMatchFailure(op, 93972085698SPrashant Kumar "only f32 and f16 type is supported."); 940*72f36217SKunwar Grover std::optional<VectorShape> shape = vectorShape(operand); 94172085698SPrashant Kumar 94272085698SPrashant Kumar ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 94372085698SPrashant Kumar auto bcast = [&](Value value) -> Value { 94472085698SPrashant Kumar return broadcast(builder, value, shape); 94572085698SPrashant Kumar }; 94672085698SPrashant Kumar 94772085698SPrashant Kumar auto fma = [&](Value a, Value b, Value c) -> Value { 94872085698SPrashant Kumar return builder.create<math::FmaOp>(a, b, c); 94972085698SPrashant Kumar }; 95072085698SPrashant Kumar 95172085698SPrashant Kumar auto mul = [&](Value a, Value b) -> Value { 95272085698SPrashant Kumar return builder.create<arith::MulFOp>(a, b); 95372085698SPrashant Kumar }; 
95472085698SPrashant Kumar 95572085698SPrashant Kumar Value negOperand = builder.create<arith::NegFOp>(operand); 95672085698SPrashant Kumar Value zero = bcast(floatCst(builder, 0.0, elementType)); 95772085698SPrashant Kumar Value half = bcast(floatCst(builder, 0.5, elementType)); 95872085698SPrashant Kumar Value negOne = bcast(floatCst(builder, -1.0, elementType)); 95972085698SPrashant Kumar Value selR = 96072085698SPrashant Kumar builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, operand, zero); 96172085698SPrashant Kumar Value r = builder.create<arith::SelectOp>(selR, negOperand, operand); 96272085698SPrashant Kumar Value chkConst = bcast(floatCst(builder, -0.5625, elementType)); 96372085698SPrashant Kumar Value firstPred = 96472085698SPrashant Kumar builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, r, chkConst); 96572085698SPrashant Kumar 96672085698SPrashant Kumar Value trueVal = 96772085698SPrashant Kumar fma(bcast(floatCst(builder, 9.3282184640716537e-1, elementType)), 96872085698SPrashant Kumar bcast(floatCst(builder, 1.6839188885261840e+0, elementType)), 96972085698SPrashant Kumar builder.create<math::AsinOp>(r)); 97072085698SPrashant Kumar 97172085698SPrashant Kumar Value falseVal = builder.create<math::SqrtOp>(fma(half, r, half)); 97272085698SPrashant Kumar falseVal = builder.create<math::AsinOp>(falseVal); 97372085698SPrashant Kumar falseVal = mul(bcast(floatCst(builder, 2.0, elementType)), falseVal); 97472085698SPrashant Kumar 97572085698SPrashant Kumar r = builder.create<arith::SelectOp>(firstPred, trueVal, falseVal); 97672085698SPrashant Kumar 97772085698SPrashant Kumar // Check whether the operand lies in between [-1.0, 0.0). 
97872085698SPrashant Kumar Value greaterThanNegOne = 97972085698SPrashant Kumar builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, operand, negOne); 98072085698SPrashant Kumar 98172085698SPrashant Kumar Value lessThanZero = 98272085698SPrashant Kumar builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero); 98372085698SPrashant Kumar 98472085698SPrashant Kumar Value betweenNegOneZero = 98572085698SPrashant Kumar builder.create<arith::AndIOp>(greaterThanNegOne, lessThanZero); 98672085698SPrashant Kumar 98772085698SPrashant Kumar trueVal = fma(bcast(floatCst(builder, 1.8656436928143307e+0, elementType)), 98872085698SPrashant Kumar bcast(floatCst(builder, 1.6839188885261840e+0, elementType)), 98972085698SPrashant Kumar builder.create<arith::NegFOp>(r)); 99072085698SPrashant Kumar 99172085698SPrashant Kumar Value finalVal = 99272085698SPrashant Kumar builder.create<arith::SelectOp>(betweenNegOneZero, trueVal, r); 99372085698SPrashant Kumar 99472085698SPrashant Kumar rewriter.replaceOp(op, finalVal); 99572085698SPrashant Kumar return success(); 99672085698SPrashant Kumar } 99772085698SPrashant Kumar 99872085698SPrashant Kumar //----------------------------------------------------------------------------// 999f1b92218SBoian Petkantchin // Erf approximation. 1000f1b92218SBoian Petkantchin //----------------------------------------------------------------------------// 1001f1b92218SBoian Petkantchin 1002f1b92218SBoian Petkantchin // Approximates erf(x) with 1003f1b92218SBoian Petkantchin // a - P(x)/Q(x) 1004f1b92218SBoian Petkantchin // where P and Q are polynomials of degree 4. 1005f1b92218SBoian Petkantchin // Different coefficients are chosen based on the value of x. 1006f1b92218SBoian Petkantchin // The approximation error is ~2.5e-07. 1007f1b92218SBoian Petkantchin // Boost's minimax tool that utilizes the Remez method was used to find the 1008f1b92218SBoian Petkantchin // coefficients. 
1009f1b92218SBoian Petkantchin LogicalResult 1010f1b92218SBoian Petkantchin ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op, 1011f1b92218SBoian Petkantchin PatternRewriter &rewriter) const { 1012880b8f4eSPrashant Kumar Value operand = op.getOperand(); 1013880b8f4eSPrashant Kumar Type elementType = getElementTypeOrSelf(operand); 1014f1b92218SBoian Petkantchin 1015880b8f4eSPrashant Kumar if (!(elementType.isF32() || elementType.isF16())) 1016880b8f4eSPrashant Kumar return rewriter.notifyMatchFailure(op, 1017880b8f4eSPrashant Kumar "only f32 and f16 type is supported."); 1018*72f36217SKunwar Grover std::optional<VectorShape> shape = vectorShape(operand); 1019ec32d540SEugene Zhulenev 1020f1b92218SBoian Petkantchin ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 1021f1b92218SBoian Petkantchin auto bcast = [&](Value value) -> Value { 1022ec32d540SEugene Zhulenev return broadcast(builder, value, shape); 1023f1b92218SBoian Petkantchin }; 1024f1b92218SBoian Petkantchin 1025f1b92218SBoian Petkantchin const int intervalsCount = 3; 1026f1b92218SBoian Petkantchin const int polyDegree = 4; 1027f1b92218SBoian Petkantchin 1028880b8f4eSPrashant Kumar Value zero = bcast(floatCst(builder, 0, elementType)); 1029880b8f4eSPrashant Kumar Value one = bcast(floatCst(builder, 1, elementType)); 1030f1b92218SBoian Petkantchin Value pp[intervalsCount][polyDegree + 1]; 1031880b8f4eSPrashant Kumar pp[0][0] = bcast(floatCst(builder, +0.00000000000000000e+00f, elementType)); 1032880b8f4eSPrashant Kumar pp[0][1] = bcast(floatCst(builder, +1.12837916222975858e+00f, elementType)); 1033880b8f4eSPrashant Kumar pp[0][2] = bcast(floatCst(builder, -5.23018562988006470e-01f, elementType)); 1034880b8f4eSPrashant Kumar pp[0][3] = bcast(floatCst(builder, +2.09741709609267072e-01f, elementType)); 1035880b8f4eSPrashant Kumar pp[0][4] = bcast(floatCst(builder, +2.58146801602987875e-02f, elementType)); 1036880b8f4eSPrashant Kumar pp[1][0] = bcast(floatCst(builder, +0.00000000000000000e+00f, 
elementType)); 1037880b8f4eSPrashant Kumar pp[1][1] = bcast(floatCst(builder, +1.12750687816789140e+00f, elementType)); 1038880b8f4eSPrashant Kumar pp[1][2] = bcast(floatCst(builder, -3.64721408487825775e-01f, elementType)); 1039880b8f4eSPrashant Kumar pp[1][3] = bcast(floatCst(builder, +1.18407396425136952e-01f, elementType)); 1040880b8f4eSPrashant Kumar pp[1][4] = bcast(floatCst(builder, +3.70645533056476558e-02f, elementType)); 1041880b8f4eSPrashant Kumar pp[2][0] = bcast(floatCst(builder, -3.30093071049483172e-03f, elementType)); 1042880b8f4eSPrashant Kumar pp[2][1] = bcast(floatCst(builder, +3.51961938357697011e-03f, elementType)); 1043880b8f4eSPrashant Kumar pp[2][2] = bcast(floatCst(builder, -1.41373622814988039e-03f, elementType)); 1044880b8f4eSPrashant Kumar pp[2][3] = bcast(floatCst(builder, +2.53447094961941348e-04f, elementType)); 1045880b8f4eSPrashant Kumar pp[2][4] = bcast(floatCst(builder, -1.71048029455037401e-05f, elementType)); 1046f1b92218SBoian Petkantchin 1047f1b92218SBoian Petkantchin Value qq[intervalsCount][polyDegree + 1]; 1048880b8f4eSPrashant Kumar qq[0][0] = bcast(floatCst(builder, +1.000000000000000000e+00f, elementType)); 1049880b8f4eSPrashant Kumar qq[0][1] = bcast(floatCst(builder, -4.635138185962547255e-01f, elementType)); 1050880b8f4eSPrashant Kumar qq[0][2] = bcast(floatCst(builder, +5.192301327279782447e-01f, elementType)); 1051880b8f4eSPrashant Kumar qq[0][3] = bcast(floatCst(builder, -1.318089722204810087e-01f, elementType)); 1052880b8f4eSPrashant Kumar qq[0][4] = bcast(floatCst(builder, +7.397964654672315005e-02f, elementType)); 1053880b8f4eSPrashant Kumar qq[1][0] = bcast(floatCst(builder, +1.00000000000000000e+00f, elementType)); 1054880b8f4eSPrashant Kumar qq[1][1] = bcast(floatCst(builder, -3.27607011824493086e-01f, elementType)); 1055880b8f4eSPrashant Kumar qq[1][2] = bcast(floatCst(builder, +4.48369090658821977e-01f, elementType)); 1056880b8f4eSPrashant Kumar qq[1][3] = bcast(floatCst(builder, -8.83462621207857930e-02f, 
elementType)); 1057880b8f4eSPrashant Kumar qq[1][4] = bcast(floatCst(builder, +5.72442770283176093e-02f, elementType)); 1058880b8f4eSPrashant Kumar qq[2][0] = bcast(floatCst(builder, +1.00000000000000000e+00f, elementType)); 1059880b8f4eSPrashant Kumar qq[2][1] = bcast(floatCst(builder, -2.06069165953913769e+00f, elementType)); 1060880b8f4eSPrashant Kumar qq[2][2] = bcast(floatCst(builder, +1.62705939945477759e+00f, elementType)); 1061880b8f4eSPrashant Kumar qq[2][3] = bcast(floatCst(builder, -5.83389859211130017e-01f, elementType)); 1062880b8f4eSPrashant Kumar qq[2][4] = bcast(floatCst(builder, +8.21908939856640930e-02f, elementType)); 1063f1b92218SBoian Petkantchin 1064f1b92218SBoian Petkantchin Value offsets[intervalsCount]; 1065880b8f4eSPrashant Kumar offsets[0] = bcast(floatCst(builder, 0.0f, elementType)); 1066880b8f4eSPrashant Kumar offsets[1] = bcast(floatCst(builder, 0.0f, elementType)); 1067880b8f4eSPrashant Kumar offsets[2] = bcast(floatCst(builder, 1.0f, elementType)); 1068f1b92218SBoian Petkantchin 1069f1b92218SBoian Petkantchin Value bounds[intervalsCount]; 1070880b8f4eSPrashant Kumar bounds[0] = bcast(floatCst(builder, 0.8f, elementType)); 1071880b8f4eSPrashant Kumar bounds[1] = bcast(floatCst(builder, 2.0f, elementType)); 1072880b8f4eSPrashant Kumar bounds[2] = bcast(floatCst(builder, 3.75f, elementType)); 1073f1b92218SBoian Petkantchin 1074880b8f4eSPrashant Kumar Value isNegativeArg = 1075880b8f4eSPrashant Kumar builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero); 1076880b8f4eSPrashant Kumar Value negArg = builder.create<arith::NegFOp>(operand); 1077880b8f4eSPrashant Kumar Value x = builder.create<arith::SelectOp>(isNegativeArg, negArg, operand); 1078f1b92218SBoian Petkantchin 1079f1b92218SBoian Petkantchin Value offset = offsets[0]; 1080f1b92218SBoian Petkantchin Value p[polyDegree + 1]; 1081f1b92218SBoian Petkantchin Value q[polyDegree + 1]; 1082f1b92218SBoian Petkantchin for (int i = 0; i <= polyDegree; ++i) { 
1083f1b92218SBoian Petkantchin p[i] = pp[0][i]; 1084f1b92218SBoian Petkantchin q[i] = qq[0][i]; 1085f1b92218SBoian Petkantchin } 1086f1b92218SBoian Petkantchin 1087f1b92218SBoian Petkantchin // TODO: maybe use vector stacking to reduce the number of selects. 1088f1b92218SBoian Petkantchin Value isLessThanBound[intervalsCount]; 1089f1b92218SBoian Petkantchin for (int j = 0; j < intervalsCount - 1; ++j) { 1090f1b92218SBoian Petkantchin isLessThanBound[j] = 1091f1b92218SBoian Petkantchin builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, bounds[j]); 1092f1b92218SBoian Petkantchin for (int i = 0; i <= polyDegree; ++i) { 1093dec8af70SRiver Riddle p[i] = builder.create<arith::SelectOp>(isLessThanBound[j], p[i], 1094dec8af70SRiver Riddle pp[j + 1][i]); 1095dec8af70SRiver Riddle q[i] = builder.create<arith::SelectOp>(isLessThanBound[j], q[i], 1096dec8af70SRiver Riddle qq[j + 1][i]); 1097f1b92218SBoian Petkantchin } 1098dec8af70SRiver Riddle offset = builder.create<arith::SelectOp>(isLessThanBound[j], offset, 1099dec8af70SRiver Riddle offsets[j + 1]); 1100f1b92218SBoian Petkantchin } 1101f1b92218SBoian Petkantchin isLessThanBound[intervalsCount - 1] = builder.create<arith::CmpFOp>( 1102f1b92218SBoian Petkantchin arith::CmpFPredicate::ULT, x, bounds[intervalsCount - 1]); 1103f1b92218SBoian Petkantchin 1104f1b92218SBoian Petkantchin Value pPoly = makePolynomialCalculation(builder, p, x); 1105f1b92218SBoian Petkantchin Value qPoly = makePolynomialCalculation(builder, q, x); 1106f1b92218SBoian Petkantchin Value rationalPoly = builder.create<arith::DivFOp>(pPoly, qPoly); 1107f1b92218SBoian Petkantchin Value formula = builder.create<arith::AddFOp>(offset, rationalPoly); 1108dec8af70SRiver Riddle formula = builder.create<arith::SelectOp>(isLessThanBound[intervalsCount - 1], 1109f1b92218SBoian Petkantchin formula, one); 1110f1b92218SBoian Petkantchin 1111f1b92218SBoian Petkantchin // erf is odd function: erf(x) = -erf(-x). 
1112f1b92218SBoian Petkantchin Value negFormula = builder.create<arith::NegFOp>(formula); 1113dec8af70SRiver Riddle Value res = 1114dec8af70SRiver Riddle builder.create<arith::SelectOp>(isNegativeArg, negFormula, formula); 1115f1b92218SBoian Petkantchin 1116f1b92218SBoian Petkantchin rewriter.replaceOp(op, res); 1117f1b92218SBoian Petkantchin 1118f1b92218SBoian Petkantchin return success(); 1119f1b92218SBoian Petkantchin } 1120f1b92218SBoian Petkantchin 1121f1b92218SBoian Petkantchin //----------------------------------------------------------------------------// 1122ea7f211bSAhmed Taei // Exp approximation. 1123ea7f211bSAhmed Taei //----------------------------------------------------------------------------// 1124ea7f211bSAhmed Taei 1125ea7f211bSAhmed Taei namespace { 1126ea7f211bSAhmed Taei 1127*72f36217SKunwar Grover Value clampWithNormals(ImplicitLocOpBuilder &builder, 1128*72f36217SKunwar Grover const std::optional<VectorShape> shape, Value value, 1129*72f36217SKunwar Grover float lowerBound, float upperBound) { 1130710dc728SRobert Suderman assert(!std::isnan(lowerBound)); 1131710dc728SRobert Suderman assert(!std::isnan(upperBound)); 1132710dc728SRobert Suderman 1133710dc728SRobert Suderman auto bcast = [&](Value value) -> Value { 1134710dc728SRobert Suderman return broadcast(builder, value, shape); 1135710dc728SRobert Suderman }; 1136710dc728SRobert Suderman 1137710dc728SRobert Suderman auto selectCmp = [&builder](auto pred, Value value, Value bound) { 1138710dc728SRobert Suderman return builder.create<arith::SelectOp>( 1139710dc728SRobert Suderman builder.create<arith::CmpFOp>(pred, value, bound), value, bound); 1140710dc728SRobert Suderman }; 1141710dc728SRobert Suderman 1142710dc728SRobert Suderman // Note: prefer UGE/ULE vs. UGT/ULT, since they generate vmaxps/vminps vs. 1143710dc728SRobert Suderman // vcmpleps+vmovaps on x86_64. The latter outcome is also obtained with 1144710dc728SRobert Suderman // arith::{Max,Min}FOp. 
1145710dc728SRobert Suderman value = selectCmp(arith::CmpFPredicate::UGE, value, 1146710dc728SRobert Suderman bcast(f32Cst(builder, lowerBound))); 1147710dc728SRobert Suderman value = selectCmp(arith::CmpFPredicate::ULE, value, 1148710dc728SRobert Suderman bcast(f32Cst(builder, upperBound))); 1149710dc728SRobert Suderman return value; 1150710dc728SRobert Suderman } 1151710dc728SRobert Suderman 1152ea7f211bSAhmed Taei struct ExpApproximation : public OpRewritePattern<math::ExpOp> { 1153ea7f211bSAhmed Taei public: 1154ea7f211bSAhmed Taei using OpRewritePattern::OpRewritePattern; 1155ea7f211bSAhmed Taei 1156ea7f211bSAhmed Taei LogicalResult matchAndRewrite(math::ExpOp op, 1157ea7f211bSAhmed Taei PatternRewriter &rewriter) const final; 1158ea7f211bSAhmed Taei }; 1159ea7f211bSAhmed Taei 1160ea7f211bSAhmed Taei LogicalResult 1161ea7f211bSAhmed Taei ExpApproximation::matchAndRewrite(math::ExpOp op, 1162ea7f211bSAhmed Taei PatternRewriter &rewriter) const { 1163710dc728SRobert Suderman auto shape = vectorShape(op.getOperand().getType()); 1164710dc728SRobert Suderman auto elementTy = getElementTypeOrSelf(op.getType()); 1165710dc728SRobert Suderman if (!elementTy.isF32()) 1166ea7f211bSAhmed Taei return rewriter.notifyMatchFailure(op, "unsupported operand type"); 1167ec32d540SEugene Zhulenev 1168ea7f211bSAhmed Taei ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 1169ea7f211bSAhmed Taei 1170710dc728SRobert Suderman auto add = [&](Value a, Value b) -> Value { 1171710dc728SRobert Suderman return builder.create<arith::AddFOp>(a, b); 1172710dc728SRobert Suderman }; 1173ea7f211bSAhmed Taei auto bcast = [&](Value value) -> Value { 1174ec32d540SEugene Zhulenev return broadcast(builder, value, shape); 1175ea7f211bSAhmed Taei }; 1176710dc728SRobert Suderman auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); }; 1177ea7f211bSAhmed Taei auto fmla = [&](Value a, Value b, Value c) { 1178a54f4eaeSMogball return builder.create<math::FmaOp>(a, b, c); 1179ea7f211bSAhmed 
Taei }; 1180ea7f211bSAhmed Taei auto mul = [&](Value a, Value b) -> Value { 1181a54f4eaeSMogball return builder.create<arith::MulFOp>(a, b); 1182ea7f211bSAhmed Taei }; 1183ea7f211bSAhmed Taei 1184710dc728SRobert Suderman // Polynomial approximation from Cephes. 1185710dc728SRobert Suderman // 1186710dc728SRobert Suderman // To compute e^x, we re-express it as 1187710dc728SRobert Suderman // 1188710dc728SRobert Suderman // e^x = e^(a + b) 1189710dc728SRobert Suderman // = e^(a + n log(2)) 1190710dc728SRobert Suderman // = e^a * 2^n. 1191710dc728SRobert Suderman // 1192710dc728SRobert Suderman // We choose n = round(x / log(2)), restricting the value of `a` to 1193710dc728SRobert Suderman // (-log(2)/2, log(2)/2). We then use a polynomial to compute e^a. The 1194710dc728SRobert Suderman // relative error between our approximation and the true value of e^a is less 1195710dc728SRobert Suderman // than 2^-22.5 for all values of `a` within this range. 1196ea7f211bSAhmed Taei 1197710dc728SRobert Suderman // Restrict input to a small range, including some values that evaluate to 1198710dc728SRobert Suderman // +/- inf. Note that for our lower bound, we choose log(2^-126) instead of 1199710dc728SRobert Suderman // log(F32_EPSILON). We do so because this routine always flushes denormal 1200710dc728SRobert Suderman // floating points to 0. Therefore, we only need to worry about exponentiating 1201710dc728SRobert Suderman // up to the smallest representable non-denormal floating point, which is 1202710dc728SRobert Suderman // 2^-126. 1203ea7f211bSAhmed Taei 1204710dc728SRobert Suderman // Constants. 
1205d4933b32SMehdi Amini Value cstHalf = bcast(f32Cst(builder, 0.5f)); 1206d4933b32SMehdi Amini Value cstOne = bcast(f32Cst(builder, 1.0f)); 1207710dc728SRobert Suderman 1208710dc728SRobert Suderman // 1/log(2) 1209d4933b32SMehdi Amini Value cstLog2ef = bcast(f32Cst(builder, 1.44269504088896341f)); 1210710dc728SRobert Suderman 1211d4933b32SMehdi Amini Value cstExpC1 = bcast(f32Cst(builder, -0.693359375f)); 1212d4933b32SMehdi Amini Value cstExpC2 = bcast(f32Cst(builder, 2.12194440e-4f)); 1213d4933b32SMehdi Amini Value cstExpP0 = bcast(f32Cst(builder, 1.9875691500E-4f)); 1214d4933b32SMehdi Amini Value cstExpP1 = bcast(f32Cst(builder, 1.3981999507E-3f)); 1215d4933b32SMehdi Amini Value cstExpP2 = bcast(f32Cst(builder, 8.3334519073E-3f)); 1216d4933b32SMehdi Amini Value cstExpP3 = bcast(f32Cst(builder, 4.1665795894E-2f)); 1217d4933b32SMehdi Amini Value cstExpP4 = bcast(f32Cst(builder, 1.6666665459E-1f)); 1218d4933b32SMehdi Amini Value cstExpP5 = bcast(f32Cst(builder, 5.0000001201E-1f)); 1219710dc728SRobert Suderman 1220710dc728SRobert Suderman // Our computations below aren't particularly sensitive to the exact choices 1221710dc728SRobert Suderman // here, so we choose values a bit larger/smaller than 1222710dc728SRobert Suderman // 1223710dc728SRobert Suderman // log(F32_MAX) = 88.723... 1224710dc728SRobert Suderman // log(2^-126) = -87.337... 122562fea88bSJacques Pienaar Value x = op.getOperand(); 1226710dc728SRobert Suderman x = clampWithNormals(builder, shape, x, -87.8f, 88.8f); 1227d4933b32SMehdi Amini Value n = floor(fmla(x, cstLog2ef, cstHalf)); 1228ea7f211bSAhmed Taei 1229710dc728SRobert Suderman // When we eventually do the multiplication in e^a * 2^n, we need to handle 1230710dc728SRobert Suderman // the case when n > 127, the max fp32 exponent (so 2^n == inf) but e^a < 1 1231710dc728SRobert Suderman // (so e^a * 2^n != inf). There's a similar problem for n < -126, the 1232710dc728SRobert Suderman // smallest fp32 exponent. 
1233710dc728SRobert Suderman // 1234710dc728SRobert Suderman // A straightforward solution would be to detect n out of range and split it 1235710dc728SRobert Suderman // up, doing 1236710dc728SRobert Suderman // 1237710dc728SRobert Suderman // e^a * 2^n = e^a * 2^(n1 + n2) 1238710dc728SRobert Suderman // = (2^n1 * e^a) * 2^n2. 1239710dc728SRobert Suderman // 1240710dc728SRobert Suderman // But it turns out this approach is quite slow, probably because it 1241710dc728SRobert Suderman // manipulates subnormal values. 1242710dc728SRobert Suderman // 1243710dc728SRobert Suderman // The approach we use instead is to clamp n to [-127, 127]. Let n' be the 1244710dc728SRobert Suderman // value of n clamped to [-127, 127]. In the case where n' = 127, `a` can grow 1245710dc728SRobert Suderman // up to as large as 88.8 - 127 * log(2) which is about 0.7703. Even though 1246710dc728SRobert Suderman // this value of `a` is outside our previously specified range, e^a will still 1247710dc728SRobert Suderman // only have a relative error of approximately 2^-16 at worst. In practice 1248710dc728SRobert Suderman // this seems to work well enough; it passes our exhaustive tests, breaking 1249710dc728SRobert Suderman // only one result, and by one ulp (we return exp(88.7228394) = max-float but 1250710dc728SRobert Suderman // we should return inf). 1251710dc728SRobert Suderman // 1252710dc728SRobert Suderman // In the case where n' = -127, the original input value of x is so small that 1253710dc728SRobert Suderman // e^x, our final answer, is less than 2^-126. Since 2^-126 is the smallest 1254710dc728SRobert Suderman // normal floating point, and since we flush denormals, we simply return 0. We 1255710dc728SRobert Suderman // do this in a branchless way by observing that our code for constructing 2^n 1256710dc728SRobert Suderman // produces 0 if n = -127. 
1257710dc728SRobert Suderman // 1258710dc728SRobert Suderman // The proof that n' = -127 implies e^x < 2^-126 is as follows: 1259710dc728SRobert Suderman // 1260710dc728SRobert Suderman // n' = -127 implies n <= -127 1261710dc728SRobert Suderman // implies round(x / log(2)) <= -127 1262710dc728SRobert Suderman // implies x/log(2) < -126.5 1263710dc728SRobert Suderman // implies x < -126.5 * log(2) 1264710dc728SRobert Suderman // implies e^x < e^(-126.5 * log(2)) 1265710dc728SRobert Suderman // implies e^x < 2^-126.5 < 2^-126 1266710dc728SRobert Suderman // 1267710dc728SRobert Suderman // This proves that n' = -127 implies e^x < 2^-126. 1268710dc728SRobert Suderman n = clampWithNormals(builder, shape, n, -127.0f, 127.0f); 1269b122cbebSAdrian Kuegel 1270710dc728SRobert Suderman // Computes x = x - n' * log(2), the value for `a` 1271d4933b32SMehdi Amini x = fmla(cstExpC1, n, x); 1272d4933b32SMehdi Amini x = fmla(cstExpC2, n, x); 1273ea7f211bSAhmed Taei 1274710dc728SRobert Suderman // Polynomial to compute z = e^a, accurate for a in (-0.5, 0.5). 1275d4933b32SMehdi Amini Value z = fmla(x, cstExpP0, cstExpP1); 1276d4933b32SMehdi Amini z = fmla(z, x, cstExpP2); 1277d4933b32SMehdi Amini z = fmla(z, x, cstExpP3); 1278d4933b32SMehdi Amini z = fmla(z, x, cstExpP4); 1279d4933b32SMehdi Amini z = fmla(z, x, cstExpP5); 1280710dc728SRobert Suderman z = fmla(z, mul(x, x), x); 1281d4933b32SMehdi Amini z = add(cstOne, z); 1282ea7f211bSAhmed Taei 1283710dc728SRobert Suderman // Convert n' to an i32. This is safe because we clamped it above. 1284d4933b32SMehdi Amini auto i32Vec = broadcast(builder.getI32Type(), shape); 1285d4933b32SMehdi Amini Value nI32 = builder.create<arith::FPToSIOp>(i32Vec, n); 1286ea7f211bSAhmed Taei 1287710dc728SRobert Suderman // Creates the value 2^n' if -126 <= n' <= 127 and 0 if n' = -127. 
1288d4933b32SMehdi Amini Value pow2 = exp2I32(builder, nI32); 1289ea7f211bSAhmed Taei 1290710dc728SRobert Suderman // Return z * 2^n' if -126 <= n' <= 127 and 0 if n' = -127. 1291710dc728SRobert Suderman Value ret = mul(z, pow2); 1292ea7f211bSAhmed Taei 1293710dc728SRobert Suderman rewriter.replaceOp(op, ret); 1294710dc728SRobert Suderman return mlir::success(); 1295ea7f211bSAhmed Taei } 1296ea7f211bSAhmed Taei 1297710dc728SRobert Suderman } // namespace 1298710dc728SRobert Suderman 1299ea7f211bSAhmed Taei //----------------------------------------------------------------------------// 13000edc4bc8SEmilio Cota // ExpM1 approximation. 13010edc4bc8SEmilio Cota //----------------------------------------------------------------------------// 13020edc4bc8SEmilio Cota 13030edc4bc8SEmilio Cota namespace { 13040edc4bc8SEmilio Cota 13050edc4bc8SEmilio Cota struct ExpM1Approximation : public OpRewritePattern<math::ExpM1Op> { 13060edc4bc8SEmilio Cota public: 13070edc4bc8SEmilio Cota using OpRewritePattern::OpRewritePattern; 13080edc4bc8SEmilio Cota 13090edc4bc8SEmilio Cota LogicalResult matchAndRewrite(math::ExpM1Op op, 13100edc4bc8SEmilio Cota PatternRewriter &rewriter) const final; 13110edc4bc8SEmilio Cota }; 13120edc4bc8SEmilio Cota } // namespace 13130edc4bc8SEmilio Cota 13140edc4bc8SEmilio Cota LogicalResult 13150edc4bc8SEmilio Cota ExpM1Approximation::matchAndRewrite(math::ExpM1Op op, 13160edc4bc8SEmilio Cota PatternRewriter &rewriter) const { 131762fea88bSJacques Pienaar if (!getElementTypeOrSelf(op.getOperand()).isF32()) 13180edc4bc8SEmilio Cota return rewriter.notifyMatchFailure(op, "unsupported operand type"); 13190edc4bc8SEmilio Cota 1320*72f36217SKunwar Grover std::optional<VectorShape> shape = vectorShape(op.getOperand()); 1321ec32d540SEugene Zhulenev 13220edc4bc8SEmilio Cota ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 13230edc4bc8SEmilio Cota auto bcast = [&](Value value) -> Value { 1324ec32d540SEugene Zhulenev return broadcast(builder, value, shape); 
13250edc4bc8SEmilio Cota }; 13260edc4bc8SEmilio Cota 13270edc4bc8SEmilio Cota // expm1(x) = exp(x) - 1 = u - 1. 13280edc4bc8SEmilio Cota // We have to handle it carefully when x is near 0, i.e. u ~= 1, 13290edc4bc8SEmilio Cota // and when the input is ~= -inf, i.e. u - 1 ~= -1. 13300edc4bc8SEmilio Cota Value cstOne = bcast(f32Cst(builder, 1.0f)); 13310edc4bc8SEmilio Cota Value cstNegOne = bcast(f32Cst(builder, -1.0f)); 133262fea88bSJacques Pienaar Value x = op.getOperand(); 13330edc4bc8SEmilio Cota Value u = builder.create<math::ExpOp>(x); 133487de451bSAdrian Kuegel Value uEqOneOrNaN = 133587de451bSAdrian Kuegel builder.create<arith::CmpFOp>(arith::CmpFPredicate::UEQ, u, cstOne); 1336a54f4eaeSMogball Value uMinusOne = builder.create<arith::SubFOp>(u, cstOne); 1337a54f4eaeSMogball Value uMinusOneEqNegOne = builder.create<arith::CmpFOp>( 1338a54f4eaeSMogball arith::CmpFPredicate::OEQ, uMinusOne, cstNegOne); 13390edc4bc8SEmilio Cota // logU = log(u) ~= x 13400edc4bc8SEmilio Cota Value logU = builder.create<math::LogOp>(u); 13410edc4bc8SEmilio Cota 13420edc4bc8SEmilio Cota // Detect exp(x) = +inf; written this way to avoid having to form +inf. 
1343a54f4eaeSMogball Value isInf = 1344a54f4eaeSMogball builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, logU, u); 13450edc4bc8SEmilio Cota 13460edc4bc8SEmilio Cota // (u - 1) * (x / ~x) 1347a54f4eaeSMogball Value expm1 = builder.create<arith::MulFOp>( 1348a54f4eaeSMogball uMinusOne, builder.create<arith::DivFOp>(x, logU)); 1349dec8af70SRiver Riddle expm1 = builder.create<arith::SelectOp>(isInf, u, expm1); 1350dec8af70SRiver Riddle Value approximation = builder.create<arith::SelectOp>( 135187de451bSAdrian Kuegel uEqOneOrNaN, x, 1352dec8af70SRiver Riddle builder.create<arith::SelectOp>(uMinusOneEqNegOne, cstNegOne, expm1)); 13530edc4bc8SEmilio Cota rewriter.replaceOp(op, approximation); 13540edc4bc8SEmilio Cota return success(); 13550edc4bc8SEmilio Cota } 13560edc4bc8SEmilio Cota 13570edc4bc8SEmilio Cota //----------------------------------------------------------------------------// 13587e2d672aSAhmed S. Taei // Sin and Cos approximation. 13597e2d672aSAhmed S. Taei //----------------------------------------------------------------------------// 13607e2d672aSAhmed S. Taei 13617e2d672aSAhmed S. Taei namespace { 13627e2d672aSAhmed S. Taei 13637e2d672aSAhmed S. Taei template <bool isSine, typename OpTy> 13647e2d672aSAhmed S. Taei struct SinAndCosApproximation : public OpRewritePattern<OpTy> { 13657e2d672aSAhmed S. Taei public: 13667e2d672aSAhmed S. Taei using OpRewritePattern<OpTy>::OpRewritePattern; 13677e2d672aSAhmed S. Taei 13687e2d672aSAhmed S. Taei LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const final; 13697e2d672aSAhmed S. Taei }; 13707e2d672aSAhmed S. Taei } // namespace 13717e2d672aSAhmed S. Taei 13727e2d672aSAhmed S. Taei #define TWO_OVER_PI \ 13737e2d672aSAhmed S. Taei 0.6366197723675813430755350534900574481378385829618257949906693762L 13747e2d672aSAhmed S. Taei #define PI_OVER_2 \ 13757e2d672aSAhmed S. Taei 1.5707963267948966192313216916397514420985846996875529104874722961L 13767e2d672aSAhmed S. Taei 13777e2d672aSAhmed S. 
Taei // Approximates sin(x) or cos(x) by finding the best approximation polynomial in 13787e2d672aSAhmed S. Taei // the reduced range [0, pi/2] for both sin(x) and cos(x). Then given y in the 13797e2d672aSAhmed S. Taei // reduced range sin(x) will be computed as sin(y), -sin(y), cos(y) or -cos(y). 13807e2d672aSAhmed S. Taei template <bool isSine, typename OpTy> 13817e2d672aSAhmed S. Taei LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite( 13827e2d672aSAhmed S. Taei OpTy op, PatternRewriter &rewriter) const { 13837e2d672aSAhmed S. Taei static_assert( 13847e2d672aSAhmed S. Taei llvm::is_one_of<OpTy, math::SinOp, math::CosOp>::value, 13857e2d672aSAhmed S. Taei "SinAndCosApproximation pattern expects math::SinOp or math::CosOp"); 1386ec32d540SEugene Zhulenev 138762fea88bSJacques Pienaar if (!getElementTypeOrSelf(op.getOperand()).isF32()) 13887e2d672aSAhmed S. Taei return rewriter.notifyMatchFailure(op, "unsupported operand type"); 13897e2d672aSAhmed S. Taei 1390*72f36217SKunwar Grover std::optional<VectorShape> shape = vectorShape(op.getOperand()); 1391ec32d540SEugene Zhulenev 13927e2d672aSAhmed S. Taei ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 13937e2d672aSAhmed S. Taei auto bcast = [&](Value value) -> Value { 1394ec32d540SEugene Zhulenev return broadcast(builder, value, shape); 13957e2d672aSAhmed S. Taei }; 13967e2d672aSAhmed S. Taei auto mul = [&](Value a, Value b) -> Value { 1397a54f4eaeSMogball return builder.create<arith::MulFOp>(a, b); 13987e2d672aSAhmed S. Taei }; 13997e2d672aSAhmed S. Taei auto sub = [&](Value a, Value b) -> Value { 1400a54f4eaeSMogball return builder.create<arith::SubFOp>(a, b); 14017e2d672aSAhmed S. Taei }; 1402a54f4eaeSMogball auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); }; 14037e2d672aSAhmed S. Taei 1404ec32d540SEugene Zhulenev auto i32Vec = broadcast(builder.getI32Type(), shape); 14057e2d672aSAhmed S. 
Taei auto fPToSingedInteger = [&](Value a) -> Value { 14063c69bc4dSRiver Riddle return builder.create<arith::FPToSIOp>(i32Vec, a); 14077e2d672aSAhmed S. Taei }; 14087e2d672aSAhmed S. Taei 14097e2d672aSAhmed S. Taei auto modulo4 = [&](Value a) -> Value { 1410a54f4eaeSMogball return builder.create<arith::AndIOp>(a, bcast(i32Cst(builder, 3))); 14117e2d672aSAhmed S. Taei }; 14127e2d672aSAhmed S. Taei 14137e2d672aSAhmed S. Taei auto isEqualTo = [&](Value a, Value b) -> Value { 1414a54f4eaeSMogball return builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, a, b); 14157e2d672aSAhmed S. Taei }; 14167e2d672aSAhmed S. Taei 14177e2d672aSAhmed S. Taei auto isGreaterThan = [&](Value a, Value b) -> Value { 1418a54f4eaeSMogball return builder.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, a, b); 14197e2d672aSAhmed S. Taei }; 14207e2d672aSAhmed S. Taei 14217e2d672aSAhmed S. Taei auto select = [&](Value cond, Value t, Value f) -> Value { 1422dec8af70SRiver Riddle return builder.create<arith::SelectOp>(cond, t, f); 14237e2d672aSAhmed S. Taei }; 14247e2d672aSAhmed S. Taei 14257e2d672aSAhmed S. Taei auto fmla = [&](Value a, Value b, Value c) { 1426a54f4eaeSMogball return builder.create<math::FmaOp>(a, b, c); 14277e2d672aSAhmed S. Taei }; 14287e2d672aSAhmed S. Taei 1429a54f4eaeSMogball auto bitwiseOr = [&](Value a, Value b) { 1430a54f4eaeSMogball return builder.create<arith::OrIOp>(a, b); 1431a54f4eaeSMogball }; 14327e2d672aSAhmed S. Taei 1433dc3b9365SAlexandre Ganea Value twoOverPi = bcast(f32Cst(builder, (float)TWO_OVER_PI)); 1434dc3b9365SAlexandre Ganea Value piOverTwo = bcast(f32Cst(builder, (float)PI_OVER_2)); 14357e2d672aSAhmed S. Taei 143662fea88bSJacques Pienaar Value x = op.getOperand(); 14377e2d672aSAhmed S. Taei 14387e2d672aSAhmed S. Taei Value k = floor(mul(x, twoOverPi)); 14397e2d672aSAhmed S. Taei 14407e2d672aSAhmed S. Taei Value y = sub(x, mul(k, piOverTwo)); 14417e2d672aSAhmed S. Taei 14427e2d672aSAhmed S. 
Taei Value cstOne = bcast(f32Cst(builder, 1.0)); 14437e2d672aSAhmed S. Taei Value cstNegativeOne = bcast(f32Cst(builder, -1.0)); 14447e2d672aSAhmed S. Taei 14457e2d672aSAhmed S. Taei Value cstSC2 = bcast(f32Cst(builder, -0.16666667163372039794921875f)); 14467e2d672aSAhmed S. Taei Value cstSC4 = bcast(f32Cst(builder, 8.333347737789154052734375e-3f)); 14477e2d672aSAhmed S. Taei Value cstSC6 = bcast(f32Cst(builder, -1.9842604524455964565277099609375e-4f)); 14487e2d672aSAhmed S. Taei Value cstSC8 = 14497e2d672aSAhmed S. Taei bcast(f32Cst(builder, 2.760012648650445044040679931640625e-6f)); 14507e2d672aSAhmed S. Taei Value cstSC10 = 14517e2d672aSAhmed S. Taei bcast(f32Cst(builder, -2.50293279435709337121807038784027099609375e-8f)); 14527e2d672aSAhmed S. Taei 14537e2d672aSAhmed S. Taei Value cstCC2 = bcast(f32Cst(builder, -0.5f)); 14547e2d672aSAhmed S. Taei Value cstCC4 = bcast(f32Cst(builder, 4.166664183139801025390625e-2f)); 14557e2d672aSAhmed S. Taei Value cstCC6 = bcast(f32Cst(builder, -1.388833043165504932403564453125e-3f)); 14567e2d672aSAhmed S. Taei Value cstCC8 = bcast(f32Cst(builder, 2.47562347794882953166961669921875e-5f)); 14577e2d672aSAhmed S. Taei Value cstCC10 = 14587e2d672aSAhmed S. Taei bcast(f32Cst(builder, -2.59630184018533327616751194000244140625e-7f)); 14597e2d672aSAhmed S. Taei 14607e2d672aSAhmed S. Taei Value kMod4 = modulo4(fPToSingedInteger(k)); 14617e2d672aSAhmed S. Taei 14627e2d672aSAhmed S. Taei Value kR0 = isEqualTo(kMod4, bcast(i32Cst(builder, 0))); 14637e2d672aSAhmed S. Taei Value kR1 = isEqualTo(kMod4, bcast(i32Cst(builder, 1))); 14647e2d672aSAhmed S. Taei Value kR2 = isEqualTo(kMod4, bcast(i32Cst(builder, 2))); 14657e2d672aSAhmed S. Taei Value kR3 = isEqualTo(kMod4, bcast(i32Cst(builder, 3))); 14667e2d672aSAhmed S. Taei 14677e2d672aSAhmed S. Taei Value sinuseCos = isSine ? bitwiseOr(kR1, kR3) : bitwiseOr(kR0, kR2); 14687e2d672aSAhmed S. Taei Value negativeRange = isSine ? isGreaterThan(kMod4, bcast(i32Cst(builder, 1))) 14697e2d672aSAhmed S. 
Taei : bitwiseOr(kR1, kR2); 14707e2d672aSAhmed S. Taei 14717e2d672aSAhmed S. Taei Value y2 = mul(y, y); 14727e2d672aSAhmed S. Taei 14737e2d672aSAhmed S. Taei Value base = select(sinuseCos, cstOne, y); 14747e2d672aSAhmed S. Taei Value cstC2 = select(sinuseCos, cstCC2, cstSC2); 14757e2d672aSAhmed S. Taei Value cstC4 = select(sinuseCos, cstCC4, cstSC4); 14767e2d672aSAhmed S. Taei Value cstC6 = select(sinuseCos, cstCC6, cstSC6); 14777e2d672aSAhmed S. Taei Value cstC8 = select(sinuseCos, cstCC8, cstSC8); 14787e2d672aSAhmed S. Taei Value cstC10 = select(sinuseCos, cstCC10, cstSC10); 14797e2d672aSAhmed S. Taei 14807e2d672aSAhmed S. Taei Value v1 = fmla(y2, cstC10, cstC8); 14817e2d672aSAhmed S. Taei Value v2 = fmla(y2, v1, cstC6); 14827e2d672aSAhmed S. Taei Value v3 = fmla(y2, v2, cstC4); 14837e2d672aSAhmed S. Taei Value v4 = fmla(y2, v3, cstC2); 14847e2d672aSAhmed S. Taei Value v5 = fmla(y2, v4, cstOne); 14857e2d672aSAhmed S. Taei Value v6 = mul(base, v5); 14867e2d672aSAhmed S. Taei 14877e2d672aSAhmed S. Taei Value approximation = select(negativeRange, mul(cstNegativeOne, v6), v6); 14887e2d672aSAhmed S. Taei 14897e2d672aSAhmed S. Taei rewriter.replaceOp(op, approximation); 14907e2d672aSAhmed S. Taei 14917e2d672aSAhmed S. Taei return success(); 14927e2d672aSAhmed S. Taei } 14937e2d672aSAhmed S. Taei 14947e2d672aSAhmed S. Taei //----------------------------------------------------------------------------// 14956b538810SRobert Suderman // Cbrt approximation. 
14966b538810SRobert Suderman //----------------------------------------------------------------------------// 14976b538810SRobert Suderman 14986b538810SRobert Suderman namespace { 14996b538810SRobert Suderman struct CbrtApproximation : public OpRewritePattern<math::CbrtOp> { 15006b538810SRobert Suderman using OpRewritePattern::OpRewritePattern; 15016b538810SRobert Suderman 15026b538810SRobert Suderman LogicalResult matchAndRewrite(math::CbrtOp op, 15036b538810SRobert Suderman PatternRewriter &rewriter) const final; 15046b538810SRobert Suderman }; 15056b538810SRobert Suderman } // namespace 15066b538810SRobert Suderman 15076b538810SRobert Suderman // Estimation of cube-root using an algorithm defined in 15086b538810SRobert Suderman // Hacker's Delight 2nd Edition. 15096b538810SRobert Suderman LogicalResult 15106b538810SRobert Suderman CbrtApproximation::matchAndRewrite(math::CbrtOp op, 15116b538810SRobert Suderman PatternRewriter &rewriter) const { 15126b538810SRobert Suderman auto operand = op.getOperand(); 15136b538810SRobert Suderman if (!getElementTypeOrSelf(operand).isF32()) 15146b538810SRobert Suderman return rewriter.notifyMatchFailure(op, "unsupported operand type"); 15156b538810SRobert Suderman 15166b538810SRobert Suderman ImplicitLocOpBuilder b(op->getLoc(), rewriter); 1517*72f36217SKunwar Grover std::optional<VectorShape> shape = vectorShape(operand); 15186b538810SRobert Suderman 15196b538810SRobert Suderman Type floatTy = getElementTypeOrSelf(operand.getType()); 15206b538810SRobert Suderman Type intTy = b.getIntegerType(floatTy.getIntOrFloatBitWidth()); 15216b538810SRobert Suderman 15226b538810SRobert Suderman // Convert to vector types if necessary. 
15236b538810SRobert Suderman floatTy = broadcast(floatTy, shape); 15246b538810SRobert Suderman intTy = broadcast(intTy, shape); 15256b538810SRobert Suderman 15266089d612SRahul Kayaith auto bconst = [&](TypedAttr attr) -> Value { 15276b538810SRobert Suderman Value value = b.create<arith::ConstantOp>(attr); 15286b538810SRobert Suderman return broadcast(b, value, shape); 15296b538810SRobert Suderman }; 15306b538810SRobert Suderman 15316b538810SRobert Suderman // Declare the initial values: 15326b538810SRobert Suderman Value intTwo = bconst(b.getI32IntegerAttr(2)); 15336b538810SRobert Suderman Value intFour = bconst(b.getI32IntegerAttr(4)); 15346b538810SRobert Suderman Value intEight = bconst(b.getI32IntegerAttr(8)); 15356b538810SRobert Suderman Value intMagic = bconst(b.getI32IntegerAttr(0x2a5137a0)); 15366b538810SRobert Suderman Value fpThird = bconst(b.getF32FloatAttr(0.33333333f)); 15376b538810SRobert Suderman Value fpTwo = bconst(b.getF32FloatAttr(2.0f)); 15386b538810SRobert Suderman Value fpZero = bconst(b.getF32FloatAttr(0.0f)); 15396b538810SRobert Suderman 15406b538810SRobert Suderman // Compute an approximation of one third: 15416b538810SRobert Suderman // union {int ix; float x;}; 15426b538810SRobert Suderman // x = x0; 15436b538810SRobert Suderman // ix = ix/4 + ix/16; 15446b538810SRobert Suderman Value absValue = b.create<math::AbsFOp>(operand); 15456b538810SRobert Suderman Value intValue = b.create<arith::BitcastOp>(intTy, absValue); 15466b538810SRobert Suderman Value divideBy4 = b.create<arith::ShRSIOp>(intValue, intTwo); 15476b538810SRobert Suderman Value divideBy16 = b.create<arith::ShRSIOp>(intValue, intFour); 15486b538810SRobert Suderman intValue = b.create<arith::AddIOp>(divideBy4, divideBy16); 15496b538810SRobert Suderman 15506b538810SRobert Suderman // ix = ix + ix/16; 15516b538810SRobert Suderman divideBy16 = b.create<arith::ShRSIOp>(intValue, intFour); 15526b538810SRobert Suderman intValue = b.create<arith::AddIOp>(intValue, divideBy16); 
15536b538810SRobert Suderman 15546b538810SRobert Suderman // ix = ix + ix/256; 15556b538810SRobert Suderman Value divideBy256 = b.create<arith::ShRSIOp>(intValue, intEight); 15566b538810SRobert Suderman intValue = b.create<arith::AddIOp>(intValue, divideBy256); 15576b538810SRobert Suderman 15586b538810SRobert Suderman // ix = 0x2a5137a0 + ix; 15596b538810SRobert Suderman intValue = b.create<arith::AddIOp>(intValue, intMagic); 15606b538810SRobert Suderman 15616b538810SRobert Suderman // Perform one newtons step: 15626b538810SRobert Suderman // x = 0.33333333f*(2.0f*x + x0/(x*x)); 15636b538810SRobert Suderman Value floatValue = b.create<arith::BitcastOp>(floatTy, intValue); 15646b538810SRobert Suderman Value squared = b.create<arith::MulFOp>(floatValue, floatValue); 15656b538810SRobert Suderman Value mulTwo = b.create<arith::MulFOp>(floatValue, fpTwo); 15666b538810SRobert Suderman Value divSquared = b.create<arith::DivFOp>(absValue, squared); 15676b538810SRobert Suderman floatValue = b.create<arith::AddFOp>(mulTwo, divSquared); 15686b538810SRobert Suderman floatValue = b.create<arith::MulFOp>(floatValue, fpThird); 15696b538810SRobert Suderman 15706b538810SRobert Suderman // x = 0.33333333f*(2.0f*x + x0/(x*x)); 15716b538810SRobert Suderman squared = b.create<arith::MulFOp>(floatValue, floatValue); 15726b538810SRobert Suderman mulTwo = b.create<arith::MulFOp>(floatValue, fpTwo); 15736b538810SRobert Suderman divSquared = b.create<arith::DivFOp>(absValue, squared); 15746b538810SRobert Suderman floatValue = b.create<arith::AddFOp>(mulTwo, divSquared); 15756b538810SRobert Suderman floatValue = b.create<arith::MulFOp>(floatValue, fpThird); 15766b538810SRobert Suderman 15776b538810SRobert Suderman // Check for zero and restore sign. 
15786b538810SRobert Suderman Value isZero = 15796b538810SRobert Suderman b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absValue, fpZero); 15806b538810SRobert Suderman floatValue = b.create<arith::SelectOp>(isZero, fpZero, floatValue); 15816b538810SRobert Suderman floatValue = b.create<math::CopySignOp>(floatValue, operand); 15826b538810SRobert Suderman 15836b538810SRobert Suderman rewriter.replaceOp(op, floatValue); 15846b538810SRobert Suderman return success(); 15856b538810SRobert Suderman } 15866b538810SRobert Suderman 15876b538810SRobert Suderman //----------------------------------------------------------------------------// 158835553d45SEmilio Cota // Rsqrt approximation. 158935553d45SEmilio Cota //----------------------------------------------------------------------------// 159035553d45SEmilio Cota 159135553d45SEmilio Cota namespace { 159235553d45SEmilio Cota struct RsqrtApproximation : public OpRewritePattern<math::RsqrtOp> { 159335553d45SEmilio Cota using OpRewritePattern::OpRewritePattern; 159435553d45SEmilio Cota 159535553d45SEmilio Cota LogicalResult matchAndRewrite(math::RsqrtOp op, 159635553d45SEmilio Cota PatternRewriter &rewriter) const final; 159735553d45SEmilio Cota }; 159835553d45SEmilio Cota } // namespace 159935553d45SEmilio Cota 160035553d45SEmilio Cota LogicalResult 160135553d45SEmilio Cota RsqrtApproximation::matchAndRewrite(math::RsqrtOp op, 160235553d45SEmilio Cota PatternRewriter &rewriter) const { 160362fea88bSJacques Pienaar if (!getElementTypeOrSelf(op.getOperand()).isF32()) 1604ec32d540SEugene Zhulenev return rewriter.notifyMatchFailure(op, "unsupported operand type"); 1605ec32d540SEugene Zhulenev 1606*72f36217SKunwar Grover std::optional<VectorShape> shape = vectorShape(op.getOperand()); 1607ec32d540SEugene Zhulenev 160835553d45SEmilio Cota // Only support already-vectorized rsqrt's. 
1609*72f36217SKunwar Grover if (!shape || shape->sizes.empty() || shape->sizes.back() % 8 != 0) 161035553d45SEmilio Cota return rewriter.notifyMatchFailure(op, "unsupported operand type"); 161135553d45SEmilio Cota 161235553d45SEmilio Cota ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 161335553d45SEmilio Cota auto bcast = [&](Value value) -> Value { 1614ec32d540SEugene Zhulenev return broadcast(builder, value, shape); 161535553d45SEmilio Cota }; 161635553d45SEmilio Cota 161735553d45SEmilio Cota Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u)); 161835553d45SEmilio Cota Value cstOnePointFive = bcast(f32Cst(builder, 1.5f)); 161935553d45SEmilio Cota Value cstNegHalf = bcast(f32Cst(builder, -0.5f)); 162035553d45SEmilio Cota Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u)); 162135553d45SEmilio Cota 162262fea88bSJacques Pienaar Value negHalf = builder.create<arith::MulFOp>(op.getOperand(), cstNegHalf); 162335553d45SEmilio Cota 162435553d45SEmilio Cota // Select only the inverse sqrt of positive normals (denormals are 162535553d45SEmilio Cota // flushed to zero). 162662fea88bSJacques Pienaar Value ltMinMask = builder.create<arith::CmpFOp>( 162762fea88bSJacques Pienaar arith::CmpFPredicate::OLT, op.getOperand(), cstMinNormPos); 162835553d45SEmilio Cota Value infMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, 162962fea88bSJacques Pienaar op.getOperand(), cstPosInf); 163035553d45SEmilio Cota Value notNormalFiniteMask = builder.create<arith::OrIOp>(ltMinMask, infMask); 163135553d45SEmilio Cota 163235553d45SEmilio Cota // Compute an approximate result. 
1633627fa0b9SEugene Zhulenev Value yApprox = handleMultidimensionalVectors( 1634627fa0b9SEugene Zhulenev builder, op->getOperands(), 8, [&builder](ValueRange operands) -> Value { 1635627fa0b9SEugene Zhulenev return builder.create<x86vector::RsqrtOp>(operands); 1636627fa0b9SEugene Zhulenev }); 163735553d45SEmilio Cota 163835553d45SEmilio Cota // Do a single step of Newton-Raphson iteration to improve the approximation. 163935553d45SEmilio Cota // This uses the formula y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n). 164035553d45SEmilio Cota // It is essential to evaluate the inner term like this because forming 164135553d45SEmilio Cota // y_n^2 may over- or underflow. 164235553d45SEmilio Cota Value inner = builder.create<arith::MulFOp>(negHalf, yApprox); 164335553d45SEmilio Cota Value fma = builder.create<math::FmaOp>(yApprox, inner, cstOnePointFive); 164435553d45SEmilio Cota Value yNewton = builder.create<arith::MulFOp>(yApprox, fma); 164535553d45SEmilio Cota 164635553d45SEmilio Cota // Select the result of the Newton-Raphson step for positive normal arguments. 164735553d45SEmilio Cota // For other arguments, choose the output of the intrinsic. This will 164835553d45SEmilio Cota // return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(x) = +inf if 164935553d45SEmilio Cota // x is zero or a positive denormalized float (equivalent to flushing positive 165035553d45SEmilio Cota // denormalized inputs to zero). 
1651dec8af70SRiver Riddle Value res = 1652dec8af70SRiver Riddle builder.create<arith::SelectOp>(notNormalFiniteMask, yApprox, yNewton); 165335553d45SEmilio Cota rewriter.replaceOp(op, res); 165435553d45SEmilio Cota 165535553d45SEmilio Cota return success(); 165635553d45SEmilio Cota } 165735553d45SEmilio Cota 165835553d45SEmilio Cota //----------------------------------------------------------------------------// 1659f99ccf65SEugene Zhulenev 1660bcf9826aSJohannes Reifferscheid void mlir::populatePolynomialApproximateTanhPattern( 1661bcf9826aSJohannes Reifferscheid RewritePatternSet &patterns) { 1662bcf9826aSJohannes Reifferscheid patterns.add<TanhApproximation>(patterns.getContext()); 1663bcf9826aSJohannes Reifferscheid } 1664bcf9826aSJohannes Reifferscheid 1665bcf9826aSJohannes Reifferscheid void mlir::populatePolynomialApproximateErfPattern( 1666bcf9826aSJohannes Reifferscheid RewritePatternSet &patterns) { 1667bcf9826aSJohannes Reifferscheid patterns.add<ErfPolynomialApproximation>(patterns.getContext()); 1668bcf9826aSJohannes Reifferscheid } 1669bcf9826aSJohannes Reifferscheid 1670f99ccf65SEugene Zhulenev void mlir::populateMathPolynomialApproximationPatterns( 167135553d45SEmilio Cota RewritePatternSet &patterns, 167235553d45SEmilio Cota const MathPolynomialApproximationOptions &options) { 167357e1943eSRobert Suderman // Patterns for leveraging existing f32 lowerings on other data types. 
167457e1943eSRobert Suderman patterns 167557e1943eSRobert Suderman .add<ReuseF32Expansion<math::AtanOp>, ReuseF32Expansion<math::Atan2Op>, 167657e1943eSRobert Suderman ReuseF32Expansion<math::TanhOp>, ReuseF32Expansion<math::LogOp>, 167757e1943eSRobert Suderman ReuseF32Expansion<math::Log2Op>, ReuseF32Expansion<math::Log1pOp>, 167857e1943eSRobert Suderman ReuseF32Expansion<math::ErfOp>, ReuseF32Expansion<math::ExpOp>, 167957e1943eSRobert Suderman ReuseF32Expansion<math::ExpM1Op>, ReuseF32Expansion<math::CbrtOp>, 168057e1943eSRobert Suderman ReuseF32Expansion<math::SinOp>, ReuseF32Expansion<math::CosOp>>( 168157e1943eSRobert Suderman patterns.getContext()); 168257e1943eSRobert Suderman 168372085698SPrashant Kumar patterns 168472085698SPrashant Kumar .add<AtanApproximation, Atan2Approximation, TanhApproximation, 16852f9f9afaSRob Suderman LogApproximation, Log2Approximation, Log1pApproximation, 168672085698SPrashant Kumar ErfPolynomialApproximation, AsinPolynomialApproximation, 168772085698SPrashant Kumar AcosPolynomialApproximation, ExpApproximation, ExpM1Approximation, 168857e1943eSRobert Suderman CbrtApproximation, SinAndCosApproximation<true, math::SinOp>, 168972085698SPrashant Kumar SinAndCosApproximation<false, math::CosOp>>(patterns.getContext()); 169057e1943eSRobert Suderman if (options.enableAvx2) { 169157e1943eSRobert Suderman patterns.add<RsqrtApproximation, ReuseF32Expansion<math::RsqrtOp>>( 169257e1943eSRobert Suderman patterns.getContext()); 169357e1943eSRobert Suderman } 1694f99ccf65SEugene Zhulenev } 1695