//===- PolynomialApproximation.cpp - Approximate math operations ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements expansion of math operations to fast approximations
// that do not rely on any of the library functions.
//
//===----------------------------------------------------------------------===//

#include <climits>
#include <cmath>
#include <cstddef>

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Approximation.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/MathExtras.h"

using namespace mlir;
using namespace mlir::math;
using namespace mlir::vector;

// Helper to encapsulate a vector's shape (including scalable dims).
//
// Both members are `ArrayRef`s, i.e. views that do not own their storage. In
// this file they are filled from `VectorType::getShape()` and
// `VectorType::getScalableDims()`. The member order matters: the struct is
// aggregate-initialized positionally as `VectorShape{sizes, scalableFlags}`.
struct VectorShape {
  // Size of each vector dimension.
  ArrayRef<int64_t> sizes;
  // One flag per dimension, as reported by `VectorType::getScalableDims()`.
  ArrayRef<bool> scalableFlags;
};

// Returns vector shape if the type is a vector, otherwise return nullopt.
49 static std::optional<VectorShape> vectorShape(Type type) { 50 if (auto vectorType = dyn_cast<VectorType>(type)) { 51 return VectorShape{vectorType.getShape(), vectorType.getScalableDims()}; 52 } 53 return std::nullopt; 54 } 55 56 static std::optional<VectorShape> vectorShape(Value value) { 57 return vectorShape(value.getType()); 58 } 59 60 //----------------------------------------------------------------------------// 61 // Broadcast scalar types and values into vector types and values. 62 //----------------------------------------------------------------------------// 63 64 // Broadcasts scalar type into vector type (iff shape is non-scalar). 65 static Type broadcast(Type type, std::optional<VectorShape> shape) { 66 assert(!isa<VectorType>(type) && "must be scalar type"); 67 return shape ? VectorType::get(shape->sizes, type, shape->scalableFlags) 68 : type; 69 } 70 71 // Broadcasts scalar value into vector (iff shape is non-scalar). 72 static Value broadcast(ImplicitLocOpBuilder &builder, Value value, 73 std::optional<VectorShape> shape) { 74 assert(!isa<VectorType>(value.getType()) && "must be scalar value"); 75 auto type = broadcast(value.getType(), shape); 76 return shape ? builder.create<BroadcastOp>(type, value) : value; 77 } 78 79 //----------------------------------------------------------------------------// 80 // Helper function to handle n-D vectors with 1-D operations. 81 //----------------------------------------------------------------------------// 82 83 // Expands and unrolls n-D vector operands into multiple fixed size 1-D vectors 84 // and calls the compute function with 1-D vector operands. Stitches back all 85 // results into the original n-D vector result. 
86 // 87 // Examples: vectorWidth = 8 88 // - vector<4x8xf32> unrolled 4 times 89 // - vector<16xf32> expanded to vector<2x8xf32> and unrolled 2 times 90 // - vector<4x16xf32> expanded to vector<4x2x8xf32> and unrolled 4*2 times 91 // 92 // Some math approximations rely on ISA-specific operations that only accept 93 // fixed size 1-D vectors (e.g. AVX expects vectors of width 8). 94 // 95 // It is the caller's responsibility to verify that the inner dimension is 96 // divisible by the vectorWidth, and that all operands have the same vector 97 // shape. 98 static Value 99 handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, 100 ValueRange operands, int64_t vectorWidth, 101 llvm::function_ref<Value(ValueRange)> compute) { 102 assert(!operands.empty() && "operands must be not empty"); 103 assert(vectorWidth > 0 && "vector width must be larger than 0"); 104 105 VectorType inputType = cast<VectorType>(operands[0].getType()); 106 ArrayRef<int64_t> inputShape = inputType.getShape(); 107 108 // If input shape matches target vector width, we can just call the 109 // user-provided compute function with the operands. 110 if (inputShape == llvm::ArrayRef(vectorWidth)) 111 return compute(operands); 112 113 // Check if the inner dimension has to be expanded, or we can directly iterate 114 // over the outer dimensions of the vector. 115 int64_t innerDim = inputShape.back(); 116 int64_t expansionDim = innerDim / vectorWidth; 117 assert((innerDim % vectorWidth == 0) && "invalid inner dimension size"); 118 119 // Maybe expand operands to the higher rank vector shape that we'll use to 120 // iterate over and extract one dimensional vectors. 121 SmallVector<int64_t> expandedShape(inputShape); 122 SmallVector<Value> expandedOperands(operands); 123 124 if (expansionDim > 1) { 125 // Expand shape from [..., innerDim] to [..., expansionDim, vectorWidth]. 
126 expandedShape.insert(expandedShape.end() - 1, expansionDim); 127 expandedShape.back() = vectorWidth; 128 129 for (unsigned i = 0; i < operands.size(); ++i) { 130 auto operand = operands[i]; 131 auto eltType = cast<VectorType>(operand.getType()).getElementType(); 132 auto expandedType = VectorType::get(expandedShape, eltType); 133 expandedOperands[i] = 134 builder.create<vector::ShapeCastOp>(expandedType, operand); 135 } 136 } 137 138 // Iterate over all outer dimensions of the compute shape vector type. 139 auto iterationDims = ArrayRef<int64_t>(expandedShape).drop_back(); 140 int64_t maxIndex = computeMaxLinearIndex(iterationDims); 141 auto strides = computeStrides(iterationDims); 142 143 // Compute results for each one dimensional vector. 144 SmallVector<Value> results(maxIndex); 145 146 for (int64_t i = 0; i < maxIndex; ++i) { 147 auto offsets = delinearize(i, strides); 148 149 SmallVector<Value> extracted(expandedOperands.size()); 150 for (const auto &tuple : llvm::enumerate(expandedOperands)) 151 extracted[tuple.index()] = 152 builder.create<vector::ExtractOp>(tuple.value(), offsets); 153 154 results[i] = compute(extracted); 155 } 156 157 // Stitch results together into one large vector. 158 Type resultEltType = cast<VectorType>(results[0].getType()).getElementType(); 159 Type resultExpandedType = VectorType::get(expandedShape, resultEltType); 160 Value result = builder.create<arith::ConstantOp>( 161 resultExpandedType, builder.getZeroAttr(resultExpandedType)); 162 163 for (int64_t i = 0; i < maxIndex; ++i) 164 result = builder.create<vector::InsertOp>(results[i], result, 165 delinearize(i, strides)); 166 167 // Reshape back to the original vector shape. 168 return builder.create<vector::ShapeCastOp>( 169 VectorType::get(inputShape, resultEltType), result); 170 } 171 172 //----------------------------------------------------------------------------// 173 // Helper functions to create constants. 
174 //----------------------------------------------------------------------------// 175 176 static Value floatCst(ImplicitLocOpBuilder &builder, float value, 177 Type elementType) { 178 assert((elementType.isF16() || elementType.isF32()) && 179 "x must be f16 or f32 type."); 180 return builder.create<arith::ConstantOp>( 181 builder.getFloatAttr(elementType, value)); 182 } 183 184 static Value f32Cst(ImplicitLocOpBuilder &builder, double value) { 185 return builder.create<arith::ConstantOp>(builder.getF32FloatAttr(value)); 186 } 187 188 static Value i32Cst(ImplicitLocOpBuilder &builder, int32_t value) { 189 return builder.create<arith::ConstantOp>(builder.getI32IntegerAttr(value)); 190 } 191 192 static Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits) { 193 Value i32Value = i32Cst(builder, static_cast<int32_t>(bits)); 194 return builder.create<arith::BitcastOp>(builder.getF32Type(), i32Value); 195 } 196 197 //----------------------------------------------------------------------------// 198 // Helper functions to build math functions approximations. 
199 //----------------------------------------------------------------------------// 200 201 // Return the minimum of the two values or NaN if value is NaN 202 static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound) { 203 return builder.create<arith::SelectOp>( 204 builder.create<arith::CmpFOp>(arith::CmpFPredicate::ULT, value, bound), 205 value, bound); 206 } 207 208 // Return the maximum of the two values or NaN if value is NaN 209 static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound) { 210 return builder.create<arith::SelectOp>( 211 builder.create<arith::CmpFOp>(arith::CmpFPredicate::UGT, value, bound), 212 value, bound); 213 } 214 215 // Return the clamped value or NaN if value is NaN 216 static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, 217 Value upperBound) { 218 return max(builder, min(builder, value, upperBound), lowerBound); 219 } 220 221 // Decomposes given floating point value `arg` into a normalized fraction and 222 // an integral power of two (see std::frexp). Returned values have float type. 223 static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg, 224 bool isPositive = false) { 225 assert(getElementTypeOrSelf(arg).isF32() && "arg must be f32 type"); 226 std::optional<VectorShape> shape = vectorShape(arg); 227 228 auto bcast = [&](Value value) -> Value { 229 return broadcast(builder, value, shape); 230 }; 231 232 auto i32 = builder.getIntegerType(32); 233 auto i32Vec = broadcast(i32, shape); 234 auto f32Vec = broadcast(builder.getF32Type(), shape); 235 236 Value cst126f = f32Cst(builder, 126.0f); 237 Value cstHalf = f32Cst(builder, 0.5f); 238 Value cstInvMantMask = f32FromBits(builder, ~0x7f800000u); 239 240 // Bitcast to i32 for bitwise operations. 
241 Value i32Half = builder.create<arith::BitcastOp>(i32, cstHalf); 242 Value i32InvMantMask = builder.create<arith::BitcastOp>(i32, cstInvMantMask); 243 Value i32Arg = builder.create<arith::BitcastOp>(i32Vec, arg); 244 245 // Compute normalized fraction. 246 Value tmp0 = builder.create<arith::AndIOp>(i32Arg, bcast(i32InvMantMask)); 247 Value tmp1 = builder.create<arith::OrIOp>(tmp0, bcast(i32Half)); 248 Value normalizedFraction = builder.create<arith::BitcastOp>(f32Vec, tmp1); 249 250 // Compute exponent. 251 Value arg0 = isPositive ? arg : builder.create<math::AbsFOp>(arg); 252 Value biasedExponentBits = builder.create<arith::ShRUIOp>( 253 builder.create<arith::BitcastOp>(i32Vec, arg0), 254 bcast(i32Cst(builder, 23))); 255 Value biasedExponent = 256 builder.create<arith::SIToFPOp>(f32Vec, biasedExponentBits); 257 Value exponent = 258 builder.create<arith::SubFOp>(biasedExponent, bcast(cst126f)); 259 260 return {normalizedFraction, exponent}; 261 } 262 263 // Computes exp2 for an i32 argument. 264 static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) { 265 assert(getElementTypeOrSelf(arg).isInteger(32) && "arg must be i32 type"); 266 std::optional<VectorShape> shape = vectorShape(arg); 267 268 auto bcast = [&](Value value) -> Value { 269 return broadcast(builder, value, shape); 270 }; 271 272 auto f32Vec = broadcast(builder.getF32Type(), shape); 273 // The exponent of f32 located at 23-bit. 274 auto exponetBitLocation = bcast(i32Cst(builder, 23)); 275 // Set the exponent bias to zero. 
276 auto bias = bcast(i32Cst(builder, 127)); 277 278 Value biasedArg = builder.create<arith::AddIOp>(arg, bias); 279 Value exp2ValueInt = 280 builder.create<arith::ShLIOp>(biasedArg, exponetBitLocation); 281 Value exp2ValueF32 = builder.create<arith::BitcastOp>(f32Vec, exp2ValueInt); 282 283 return exp2ValueF32; 284 } 285 286 namespace { 287 Value makePolynomialCalculation(ImplicitLocOpBuilder &builder, 288 llvm::ArrayRef<Value> coeffs, Value x) { 289 Type elementType = getElementTypeOrSelf(x); 290 assert((elementType.isF32() || elementType.isF16()) && 291 "x must be f32 or f16 type"); 292 std::optional<VectorShape> shape = vectorShape(x); 293 294 if (coeffs.empty()) 295 return broadcast(builder, floatCst(builder, 0.0f, elementType), shape); 296 297 if (coeffs.size() == 1) 298 return coeffs[0]; 299 300 Value res = builder.create<math::FmaOp>(x, coeffs[coeffs.size() - 1], 301 coeffs[coeffs.size() - 2]); 302 for (auto i = ptrdiff_t(coeffs.size()) - 3; i >= 0; --i) { 303 res = builder.create<math::FmaOp>(x, res, coeffs[i]); 304 } 305 return res; 306 } 307 } // namespace 308 309 //----------------------------------------------------------------------------// 310 // Helper function/pattern to insert casts for reusing F32 bit expansion. 311 //----------------------------------------------------------------------------// 312 313 template <typename T> 314 LogicalResult insertCasts(Operation *op, PatternRewriter &rewriter) { 315 // Conservatively only allow where the operand and result types are exactly 1. 316 Type origType = op->getResultTypes().front(); 317 for (Type t : llvm::drop_begin(op->getResultTypes())) 318 if (origType != t) 319 return rewriter.notifyMatchFailure(op, "required all types to match"); 320 for (Type t : op->getOperandTypes()) 321 if (origType != t) 322 return rewriter.notifyMatchFailure(op, "required all types to match"); 323 324 // Skip if already F32 or larger than 32 bits. 
325 if (getElementTypeOrSelf(origType).isF32() || 326 getElementTypeOrSelf(origType).getIntOrFloatBitWidth() > 32) 327 return failure(); 328 329 // Create F32 equivalent type. 330 Type newType; 331 if (auto shaped = dyn_cast<ShapedType>(origType)) { 332 newType = shaped.clone(rewriter.getF32Type()); 333 } else if (isa<FloatType>(origType)) { 334 newType = rewriter.getF32Type(); 335 } else { 336 return rewriter.notifyMatchFailure(op, 337 "unable to find F32 equivalent type"); 338 } 339 340 Location loc = op->getLoc(); 341 SmallVector<Value> operands; 342 for (auto operand : op->getOperands()) 343 operands.push_back(rewriter.create<arith::ExtFOp>(loc, newType, operand)); 344 auto result = 345 rewriter.create<T>(loc, TypeRange{newType}, operands, op->getAttrs()); 346 rewriter.replaceOpWithNewOp<arith::TruncFOp>(op, origType, result); 347 return success(); 348 } 349 350 namespace { 351 // Pattern to cast to F32 to reuse F32 expansion as fallback for single-result 352 // op. 353 // TODO: Consider revising to avoid adding multiple casts for a subgraph that is 354 // all in lower precision. Currently this is only fallback support and performs 355 // simplistic casting. 356 template <typename T> 357 struct ReuseF32Expansion : public OpRewritePattern<T> { 358 public: 359 using OpRewritePattern<T>::OpRewritePattern; 360 LogicalResult matchAndRewrite(T op, PatternRewriter &rewriter) const final { 361 static_assert( 362 T::template hasTrait<mlir::OpTrait::SameOperandsAndResultType>(), 363 "requires same operands and result types"); 364 return insertCasts<T>(op, rewriter); 365 } 366 }; 367 } // namespace 368 369 //----------------------------------------------------------------------------// 370 // AtanOp approximation. 
//----------------------------------------------------------------------------//

namespace {
// Rewrites `math.atan` on f32 (scalar or vector) into arith/math ops that
// evaluate a range-reduced rational approximation.
struct AtanApproximation : public OpRewritePattern<math::AtanOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::AtanOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace

// Only f32 element types are handled; anything else is a match failure.
// The computation works on |operand| and restores the sign at the very end.
LogicalResult
AtanApproximation::matchAndRewrite(math::AtanOp op,
                                   PatternRewriter &rewriter) const {
  auto operand = op.getOperand();
  if (!getElementTypeOrSelf(operand).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  std::optional<VectorShape> shape = vectorShape(op.getOperand());

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  Value abs = builder.create<math::AbsFOp>(operand);

  auto one = broadcast(builder, f32Cst(builder, 1.0), shape);

  // Range reduction, step 1: when |x| > 0.66 use (|x| - 1) / (|x| + 1),
  // otherwise |x| / 1. (The > 2.41 case is overridden in step 2 below.)
  // Note that the threshold constant is 0.66, not exactly 2/3.
  auto twoThirds = broadcast(builder, f32Cst(builder, 0.66), shape);
  Value cmp2 =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, abs, twoThirds);
  Value addone = builder.create<arith::AddFOp>(abs, one);
  Value subone = builder.create<arith::SubFOp>(abs, one);
  Value xnum = builder.create<arith::SelectOp>(cmp2, subone, abs);
  Value xden = builder.create<arith::SelectOp>(cmp2, addone, one);

  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  // Range reduction, step 2: when |x| > 2.41 use 1 / |x| instead.
  auto tan3pio8 = bcast(f32Cst(builder, 2.41421356237309504880));
  Value cmp1 =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, abs, tan3pio8);
  xnum = builder.create<arith::SelectOp>(cmp1, one, xnum);
  xden = builder.create<arith::SelectOp>(cmp1, abs, xden);

  // `x` is the reduced argument, `xx` its square.
  Value x = builder.create<arith::DivFOp>(xnum, xden);
  Value xx = builder.create<arith::MulFOp>(x, x);

  // Approximate atan of the reduced argument with a rational function of
  // xx: atan(x) ~= x + x * n(xx) / d(xx), where n and d are the polynomials
  // with the p* and q* coefficients below.
  auto p0 = bcast(f32Cst(builder, -8.750608600031904122785e-01));
  auto p1 = bcast(f32Cst(builder, -1.615753718733365076637e+01));
  auto p2 = bcast(f32Cst(builder, -7.500855792314704667340e+01));
  auto p3 = bcast(f32Cst(builder, -1.228866684490136173410e+02));
  auto p4 = bcast(f32Cst(builder, -6.485021904942025371773e+01));
  auto q0 = bcast(f32Cst(builder, +2.485846490142306297962e+01));
  auto q1 = bcast(f32Cst(builder, +1.650270098316988542046e+02));
  auto q2 = bcast(f32Cst(builder, +4.328810604912902668951e+02));
  auto q3 = bcast(f32Cst(builder, +4.853903996359136964868e+02));
  auto q4 = bcast(f32Cst(builder, +1.945506571482613964425e+02));

  // Apply the polynomial approximation for the numerator (with a final extra
  // multiplication by xx):
  Value n = p0;
  n = builder.create<math::FmaOp>(xx, n, p1);
  n = builder.create<math::FmaOp>(xx, n, p2);
  n = builder.create<math::FmaOp>(xx, n, p3);
  n = builder.create<math::FmaOp>(xx, n, p4);
  n = builder.create<arith::MulFOp>(n, xx);

  // Apply the polynomial approximation for the denominator:
  Value d = q0;
  d = builder.create<math::FmaOp>(xx, d, q1);
  d = builder.create<math::FmaOp>(xx, d, q2);
  d = builder.create<math::FmaOp>(xx, d, q3);
  d = builder.create<math::FmaOp>(xx, d, q4);

  // Compute approximation of theta: (n / d) * x + x.
  Value ans0 = builder.create<arith::DivFOp>(n, d);
  ans0 = builder.create<math::FmaOp>(ans0, x, x);

  // Correct for the input mapping's angles:
  //   - (|x| - 1) / (|x| + 1) mapping: add pi/4.
  Value mpi4 = bcast(f32Cst(builder, llvm::numbers::pi / 4));
  Value ans2 = builder.create<arith::AddFOp>(mpi4, ans0);
  Value ans = builder.create<arith::SelectOp>(cmp2, ans2, ans0);

  //   - 1 / |x| mapping: subtract from pi/2. This takes precedence over the
  //     pi/4 correction because it is selected last.
  Value mpi2 = bcast(f32Cst(builder, llvm::numbers::pi / 2));
  Value ans1 = builder.create<arith::SubFOp>(mpi2, ans0);
  ans = builder.create<arith::SelectOp>(cmp1, ans1, ans);

  // Correct for signing of the input.
  rewriter.replaceOpWithNewOp<math::CopySignOp>(op, ans, operand);
  return success();
}

//----------------------------------------------------------------------------//
// Atan2Op approximation.
//----------------------------------------------------------------------------//

namespace {
// Rewrites `math.atan2` on f32 (scalar or vector) in terms of `math.atan` of
// y / x plus quadrant and x == 0 fix-ups.
struct Atan2Approximation : public OpRewritePattern<math::Atan2Op> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::Atan2Op op,
                                PatternRewriter &rewriter) const final;
};
} // namespace

// Operand 0 is `y`, operand 1 is `x`. Only the element type of `x` is checked
// for f32 here.
LogicalResult
Atan2Approximation::matchAndRewrite(math::Atan2Op op,
                                    PatternRewriter &rewriter) const {
  auto y = op.getOperand(0);
  auto x = op.getOperand(1);
  if (!getElementTypeOrSelf(x).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  std::optional<VectorShape> shape = vectorShape(op.getResult());

  // Compute atan in the valid range.
  // The `math.atan` created here is a new op and is not expanded by this
  // pattern itself.
  auto div = builder.create<arith::DivFOp>(y, x);
  auto atan = builder.create<math::AtanOp>(div);

  // Determine what the atan would be for a 180 degree rotation: subtract pi
  // when atan is positive, add pi otherwise.
  auto zero = broadcast(builder, f32Cst(builder, 0.0f), shape);
  auto pi = broadcast(builder, f32Cst(builder, 3.14159265359f), shape);
  auto addPi = builder.create<arith::AddFOp>(atan, pi);
  auto subPi = builder.create<arith::SubFOp>(atan, pi);
  auto atanGt =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, atan, zero);
  auto flippedAtan = builder.create<arith::SelectOp>(atanGt, subPi, addPi);

  // Determine whether to directly use atan (x > 0) or use the 180 degree flip
  auto xGt = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, x, zero);
  Value result = builder.create<arith::SelectOp>(xGt, atan, flippedAtan);

  // Handle x = 0, y > 0: the result is pi/2.
  Value xZero =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, x, zero);
  Value yGt = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, y, zero);
  Value isHalfPi = builder.create<arith::AndIOp>(xZero, yGt);
  auto halfPi = broadcast(builder, f32Cst(builder, 1.57079632679f), shape);
  result = builder.create<arith::SelectOp>(isHalfPi, halfPi, result);

  // Handle x = 0, y < 0: the result is -pi/2.
  Value yLt = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, y, zero);
  Value isNegativeHalfPiPi = builder.create<arith::AndIOp>(xZero, yLt);
  auto negativeHalfPiPi =
      broadcast(builder, f32Cst(builder, -1.57079632679f), shape);
  result = builder.create<arith::SelectOp>(isNegativeHalfPiPi, negativeHalfPiPi,
                                           result);

  // Handle x = 0, y = 0: the result is NaN (bit pattern 0x7fc00000).
  Value yZero =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, y, zero);
  Value isNan = builder.create<arith::AndIOp>(xZero, yZero);
  Value cstNan = broadcast(builder, f32FromBits(builder, 0x7fc00000), shape);
  result = builder.create<arith::SelectOp>(isNan, cstNan, result);

  rewriter.replaceOp(op, result);
  return success();
}

//----------------------------------------------------------------------------//
// TanhOp approximation.
//----------------------------------------------------------------------------//

namespace {
// Rewrites `math.tanh` on f32 (scalar or vector) into a clamped rational
// polynomial approximation.
struct TanhApproximation : public OpRewritePattern<math::TanhOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::TanhOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace

// Only f32 element types are handled; anything else is a match failure.
LogicalResult
TanhApproximation::matchAndRewrite(math::TanhOp op,
                                   PatternRewriter &rewriter) const {
  if (!getElementTypeOrSelf(op.getOperand()).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  std::optional<VectorShape> shape = vectorShape(op.getOperand());

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  // Clamp operand into [minusClamp, plusClamp] range.
  Value minusClamp = bcast(f32Cst(builder, -7.99881172180175781f));
  Value plusClamp = bcast(f32Cst(builder, 7.99881172180175781f));
  Value x = clamp(builder, op.getOperand(), minusClamp, plusClamp);

  // Mask for tiny values (|operand| < 0.0004) that are approximated with the
  // (clamped) operand itself.
  Value tiny = bcast(f32Cst(builder, 0.0004f));
  Value tinyMask = builder.create<arith::CmpFOp>(
      arith::CmpFPredicate::OLT, builder.create<math::AbsFOp>(op.getOperand()),
      tiny);

  // The monomial coefficients of the numerator polynomial (odd).
  Value alpha1 = bcast(f32Cst(builder, 4.89352455891786e-03f));
  Value alpha3 = bcast(f32Cst(builder, 6.37261928875436e-04f));
  Value alpha5 = bcast(f32Cst(builder, 1.48572235717979e-05f));
  Value alpha7 = bcast(f32Cst(builder, 5.12229709037114e-08f));
  Value alpha9 = bcast(f32Cst(builder, -8.60467152213735e-11f));
  Value alpha11 = bcast(f32Cst(builder, 2.00018790482477e-13f));
  Value alpha13 = bcast(f32Cst(builder, -2.76076847742355e-16f));

  // The monomial coefficients of the denominator polynomial (even).
  Value beta0 = bcast(f32Cst(builder, 4.89352518554385e-03f));
  Value beta2 = bcast(f32Cst(builder, 2.26843463243900e-03f));
  Value beta4 = bcast(f32Cst(builder, 1.18534705686654e-04f));
  Value beta6 = bcast(f32Cst(builder, 1.19825839466702e-06f));

  // Since the polynomials are odd/even, we need x^2.
  Value x2 = builder.create<arith::MulFOp>(x, x);

  // Evaluate the numerator polynomial p: a polynomial in x^2 evaluated from
  // the highest coefficient down, then multiplied by x to make it odd.
  Value p = builder.create<math::FmaOp>(x2, alpha13, alpha11);
  p = builder.create<math::FmaOp>(x2, p, alpha9);
  p = builder.create<math::FmaOp>(x2, p, alpha7);
  p = builder.create<math::FmaOp>(x2, p, alpha5);
  p = builder.create<math::FmaOp>(x2, p, alpha3);
  p = builder.create<math::FmaOp>(x2, p, alpha1);
  p = builder.create<arith::MulFOp>(x, p);

  // Evaluate the denominator polynomial q.
  Value q = builder.create<math::FmaOp>(x2, beta6, beta4);
  q = builder.create<math::FmaOp>(x2, q, beta2);
  q = builder.create<math::FmaOp>(x2, q, beta0);

  // Divide the numerator by the denominator, except for tiny inputs where x
  // is used directly.
  Value res = builder.create<arith::SelectOp>(
      tinyMask, x, builder.create<arith::DivFOp>(p, q));

  rewriter.replaceOp(op, res);

  return success();
}

// ln(2) and log2(e) as long double literals; they are narrowed to float where
// they are used in the logarithm approximation below.
#define LN2_VALUE                                                              \
  0.693147180559945309417232121458176568075500134360255254120680009493393621L
#define LOG2E_VALUE                                                            \
  1.442695040888963407359924681001892137426645954152985934135449406931109219L

//----------------------------------------------------------------------------//
// LogOp and Log2Op approximation.
//----------------------------------------------------------------------------//

namespace {
// Shared implementation for the natural and base-2 logarithm patterns; `Op`
// is the op being rewritten.
template <typename Op>
struct LogApproximationBase : public OpRewritePattern<Op> {
  using OpRewritePattern<Op>::OpRewritePattern;

  /// Base 2 if 'base2' is set; natural logarithm (base e) otherwise.
  LogicalResult logMatchAndRewrite(Op op, PatternRewriter &rewriter,
                                   bool base2) const;
};
} // namespace

// This approximation comes from Julien Pommier's SSE math library.
// Link: http://gruntthepeon.free.fr/ssemath
//
// Only f32 element types are handled; anything else is a match failure.
template <typename Op>
LogicalResult
LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter,
                                             bool base2) const {
  if (!getElementTypeOrSelf(op.getOperand()).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  std::optional<VectorShape> shape = vectorShape(op.getOperand());

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  Value cstZero = bcast(f32Cst(builder, 0.0f));
  Value cstOne = bcast(f32Cst(builder, 1.0f));
  Value cstNegHalf = bcast(f32Cst(builder, -0.5f));

  // The smallest non denormalized float number, followed by -inf, +inf and a
  // NaN, all built from their f32 bit patterns.
  Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u));
  Value cstMinusInf = bcast(f32FromBits(builder, 0xff800000u));
  Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u));
  Value cstNan = bcast(f32FromBits(builder, 0x7fc00000));

  // Polynomial coefficients.
  Value cstCephesSQRTHF = bcast(f32Cst(builder, 0.707106781186547524f));
  Value cstCephesLogP0 = bcast(f32Cst(builder, 7.0376836292E-2f));
  Value cstCephesLogP1 = bcast(f32Cst(builder, -1.1514610310E-1f));
  Value cstCephesLogP2 = bcast(f32Cst(builder, 1.1676998740E-1f));
  Value cstCephesLogP3 = bcast(f32Cst(builder, -1.2420140846E-1f));
  Value cstCephesLogP4 = bcast(f32Cst(builder, +1.4249322787E-1f));
  Value cstCephesLogP5 = bcast(f32Cst(builder, -1.6668057665E-1f));
  Value cstCephesLogP6 = bcast(f32Cst(builder, +2.0000714765E-1f));
  Value cstCephesLogP7 = bcast(f32Cst(builder, -2.4999993993E-1f));
  Value cstCephesLogP8 = bcast(f32Cst(builder, +3.3333331174E-1f));

  Value x = op.getOperand();

  // Clamp input values from below to the minimum positive normal. Zero and
  // negative inputs are patched up by the masks at the end.
  x = max(builder, x, cstMinNormPos);

  // Extract significand in the range [0.5,1) and exponent.
  std::pair<Value, Value> pair = frexp(builder, x, /*isPositive=*/true);
  x = pair.first;
  Value e = pair.second;

  // Shift the inputs from the range [0.5,1) to [sqrt(1/2), sqrt(2)) and shift
  // by -1.0. The values are then centered around 0, which improves the
  // stability of the polynomial evaluation:
  //
  //   if( x < SQRTHF ) {
  //     e -= 1;
  //     x = x + x - 1.0;
  //   } else { x = x - 1.0; }
  //
  // The branch is expressed with selects: `tmp` is x where the mask holds and
  // 0 elsewhere, and is added back after subtracting 1.0.
  Value mask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x,
                                             cstCephesSQRTHF);
  Value tmp = builder.create<arith::SelectOp>(mask, x, cstZero);

  x = builder.create<arith::SubFOp>(x, cstOne);
  e = builder.create<arith::SubFOp>(
      e, builder.create<arith::SelectOp>(mask, cstOne, cstZero));
  x = builder.create<arith::AddFOp>(x, tmp);

  Value x2 = builder.create<arith::MulFOp>(x, x);
  Value x3 = builder.create<arith::MulFOp>(x2, x);

  // Evaluate the polynomial approximant of degree 8 in three parts.
  Value y0, y1, y2;
  y0 = builder.create<math::FmaOp>(cstCephesLogP0, x, cstCephesLogP1);
  y1 = builder.create<math::FmaOp>(cstCephesLogP3, x, cstCephesLogP4);
  y2 = builder.create<math::FmaOp>(cstCephesLogP6, x, cstCephesLogP7);
  y0 = builder.create<math::FmaOp>(y0, x, cstCephesLogP2);
  y1 = builder.create<math::FmaOp>(y1, x, cstCephesLogP5);
  y2 = builder.create<math::FmaOp>(y2, x, cstCephesLogP8);
  y0 = builder.create<math::FmaOp>(y0, x3, y1);
  y0 = builder.create<math::FmaOp>(y0, x3, y2);
  y0 = builder.create<arith::MulFOp>(y0, x3);

  // Add the -0.5 * x^2 term and then x itself.
  y0 = builder.create<math::FmaOp>(cstNegHalf, x2, y0);
  x = builder.create<arith::AddFOp>(x, y0);

  // Fold the exponent back in: x * log2(e) + e for base 2, e * ln(2) + x for
  // the natural logarithm.
  if (base2) {
    Value cstLog2e = bcast(f32Cst(builder, static_cast<float>(LOG2E_VALUE)));
    x = builder.create<math::FmaOp>(x, cstLog2e, e);
  } else {
    Value cstLn2 = bcast(f32Cst(builder, static_cast<float>(LN2_VALUE)));
    x = builder.create<math::FmaOp>(e, cstLn2, x);
  }

  // The masks are computed on the original operand. `invalidMask` uses an
  // unordered comparison, so it also holds for a NaN operand.
  Value invalidMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::ULT,
                                                    op.getOperand(), cstZero);
  Value zeroMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
                                                 op.getOperand(), cstZero);
  Value posInfMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
                                                   op.getOperand(), cstPosInf);

  // Filter out invalid values:
  //  • x == 0     -> -INF
  //  • x < 0      ->  NAN
  //  • x == +INF  -> +INF
  Value aproximation = builder.create<arith::SelectOp>(
      zeroMask, cstMinusInf,
      builder.create<arith::SelectOp>(
          invalidMask, cstNan,
          builder.create<arith::SelectOp>(posInfMask, cstPosInf, x)));

  rewriter.replaceOp(op, aproximation);

  return success();
}

namespace {
// Natural logarithm: forwards to the shared implementation with base2 off.
struct LogApproximation : public LogApproximationBase<math::LogOp> {
  using LogApproximationBase::LogApproximationBase;

  LogicalResult matchAndRewrite(math::LogOp op,
                                PatternRewriter &rewriter) const final {
    return logMatchAndRewrite(op, rewriter, /*base2=*/false);
  }
};
} // namespace

namespace {
// Base-2 logarithm: forwards to the shared implementation with base2 on.
struct Log2Approximation : public LogApproximationBase<math::Log2Op> {
  using LogApproximationBase::LogApproximationBase;

  LogicalResult matchAndRewrite(math::Log2Op op,
                                PatternRewriter &rewriter) const final {
    return logMatchAndRewrite(op, rewriter, /*base2=*/true);
  }
};
} // namespace

//----------------------------------------------------------------------------//
// Log1p approximation.
//----------------------------------------------------------------------------//

namespace {
// Rewrites `math.log1p` on f32 (scalar or vector) in terms of `math.log`.
struct Log1pApproximation : public OpRewritePattern<math::Log1pOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::Log1pOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace

// Approximate log(1+x).
// Only f32 element types are handled; anything else is a match failure.
LogicalResult
Log1pApproximation::matchAndRewrite(math::Log1pOp op,
                                    PatternRewriter &rewriter) const {
  if (!getElementTypeOrSelf(op.getOperand()).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  std::optional<VectorShape> shape = vectorShape(op.getOperand());

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  // Approximate log(1+x) using the following, due to W. Kahan:
  //   u = x + 1.0;
  //   if (u == 1.0 || u == inf) return x;
  //   return x * log(u) / (u - 1.0);
  //          ^^^^^^^^^^^^^^^^^^^^^^
  //             "logLarge" below.
  Value cstOne = bcast(f32Cst(builder, 1.0f));
  Value x = op.getOperand();
  Value u = builder.create<arith::AddFOp>(x, cstOne);
  Value uSmall =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, cstOne);
  Value logU = builder.create<math::LogOp>(u);
  // The `u == inf` test is expressed as `u == log(u)`.
  // NOTE(review): presumably relies on log(+inf) == +inf being the only case
  // where the two are equal - confirm.
  Value uInf =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, logU);
  Value logLarge = builder.create<arith::MulFOp>(
      x, builder.create<arith::DivFOp>(
             logU, builder.create<arith::SubFOp>(u, cstOne)));
  Value approximation = builder.create<arith::SelectOp>(
      builder.create<arith::OrIOp>(uSmall, uInf), x, logLarge);
  rewriter.replaceOp(op, approximation);
  return success();
}

//----------------------------------------------------------------------------//
// Asin approximation.
//----------------------------------------------------------------------------//

// Approximates asin(x).
// This approximation is based on the following stackoverflow post:
// https://stackoverflow.com/a/42683455
namespace {
struct AsinPolynomialApproximation : public OpRewritePattern<math::AsinOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::AsinOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace
LogicalResult
AsinPolynomialApproximation::matchAndRewrite(math::AsinOp op,
                                             PatternRewriter &rewriter) const {
  Value operand = op.getOperand();
  Type elementType = getElementTypeOrSelf(operand);

  if (!(elementType.isF32() || elementType.isF16()))
    return rewriter.notifyMatchFailure(op,
                                       "only f32 and f16 type is supported.");
  std::optional<VectorShape> shape = vectorShape(operand);

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  auto fma = [&](Value a, Value b, Value c) -> Value {
    return
builder.create<math::FmaOp>(a, b, c); 854 }; 855 856 auto mul = [&](Value a, Value b) -> Value { 857 return builder.create<arith::MulFOp>(a, b); 858 }; 859 860 auto sub = [&](Value a, Value b) -> Value { 861 return builder.create<arith::SubFOp>(a, b); 862 }; 863 864 auto abs = [&](Value a) -> Value { return builder.create<math::AbsFOp>(a); }; 865 866 auto sqrt = [&](Value a) -> Value { return builder.create<math::SqrtOp>(a); }; 867 868 auto scopy = [&](Value a, Value b) -> Value { 869 return builder.create<math::CopySignOp>(a, b); 870 }; 871 872 auto sel = [&](Value a, Value b, Value c) -> Value { 873 return builder.create<arith::SelectOp>(a, b, c); 874 }; 875 876 Value abso = abs(operand); 877 Value aa = mul(operand, operand); 878 Value opp = sqrt(sub(bcast(floatCst(builder, 1.0, elementType)), aa)); 879 880 Value gt = 881 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, aa, 882 bcast(floatCst(builder, 0.5, elementType))); 883 884 Value x = sel(gt, opp, abso); 885 886 // Asin(x) approximation for x = [-9/16, 9/16]: 887 Value s = mul(x, x); 888 Value q = mul(s, s); 889 Value r = bcast(floatCst(builder, 5.5579749017470502e-2, elementType)); 890 Value t = bcast(floatCst(builder, -6.2027913464120114e-2, elementType)); 891 892 r = fma(r, q, bcast(floatCst(builder, 5.4224464349245036e-2, elementType))); 893 t = fma(t, q, bcast(floatCst(builder, -1.1326992890324464e-2, elementType))); 894 r = fma(r, q, bcast(floatCst(builder, 1.5268872539397656e-2, elementType))); 895 t = fma(t, q, bcast(floatCst(builder, 1.0493798473372081e-2, elementType))); 896 r = fma(r, q, bcast(floatCst(builder, 1.4106045900607047e-2, elementType))); 897 t = fma(t, q, bcast(floatCst(builder, 1.7339776384962050e-2, elementType))); 898 r = fma(r, q, bcast(floatCst(builder, 2.2372961589651054e-2, elementType))); 899 t = fma(t, q, bcast(floatCst(builder, 3.0381912707941005e-2, elementType))); 900 r = fma(r, q, bcast(floatCst(builder, 4.4642857881094775e-2, elementType))); 901 t = fma(t, q, 
bcast(floatCst(builder, 7.4999999991367292e-2, elementType))); 902 r = fma(r, s, t); 903 r = fma(r, s, bcast(floatCst(builder, 1.6666666666670193e-1, elementType))); 904 t = mul(x, s); 905 r = fma(r, t, x); 906 907 Value rsub = sub(bcast(floatCst(builder, 1.57079632679, elementType)), r); 908 r = sel(gt, rsub, r); 909 r = scopy(r, operand); 910 911 rewriter.replaceOp(op, r); 912 return success(); 913 } 914 915 //----------------------------------------------------------------------------// 916 // Acos approximation. 917 //----------------------------------------------------------------------------// 918 919 // Approximates acos(x). 920 // This approximation is based on the following stackoverflow post: 921 // https://stackoverflow.com/a/42683455 922 namespace { 923 struct AcosPolynomialApproximation : public OpRewritePattern<math::AcosOp> { 924 public: 925 using OpRewritePattern::OpRewritePattern; 926 927 LogicalResult matchAndRewrite(math::AcosOp op, 928 PatternRewriter &rewriter) const final; 929 }; 930 } // namespace 931 LogicalResult 932 AcosPolynomialApproximation::matchAndRewrite(math::AcosOp op, 933 PatternRewriter &rewriter) const { 934 Value operand = op.getOperand(); 935 Type elementType = getElementTypeOrSelf(operand); 936 937 if (!(elementType.isF32() || elementType.isF16())) 938 return rewriter.notifyMatchFailure(op, 939 "only f32 and f16 type is supported."); 940 std::optional<VectorShape> shape = vectorShape(operand); 941 942 ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 943 auto bcast = [&](Value value) -> Value { 944 return broadcast(builder, value, shape); 945 }; 946 947 auto fma = [&](Value a, Value b, Value c) -> Value { 948 return builder.create<math::FmaOp>(a, b, c); 949 }; 950 951 auto mul = [&](Value a, Value b) -> Value { 952 return builder.create<arith::MulFOp>(a, b); 953 }; 954 955 Value negOperand = builder.create<arith::NegFOp>(operand); 956 Value zero = bcast(floatCst(builder, 0.0, elementType)); 957 Value half = 
bcast(floatCst(builder, 0.5, elementType)); 958 Value negOne = bcast(floatCst(builder, -1.0, elementType)); 959 Value selR = 960 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, operand, zero); 961 Value r = builder.create<arith::SelectOp>(selR, negOperand, operand); 962 Value chkConst = bcast(floatCst(builder, -0.5625, elementType)); 963 Value firstPred = 964 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, r, chkConst); 965 966 Value trueVal = 967 fma(bcast(floatCst(builder, 9.3282184640716537e-1, elementType)), 968 bcast(floatCst(builder, 1.6839188885261840e+0, elementType)), 969 builder.create<math::AsinOp>(r)); 970 971 Value falseVal = builder.create<math::SqrtOp>(fma(half, r, half)); 972 falseVal = builder.create<math::AsinOp>(falseVal); 973 falseVal = mul(bcast(floatCst(builder, 2.0, elementType)), falseVal); 974 975 r = builder.create<arith::SelectOp>(firstPred, trueVal, falseVal); 976 977 // Check whether the operand lies in between [-1.0, 0.0). 978 Value greaterThanNegOne = 979 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, operand, negOne); 980 981 Value lessThanZero = 982 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero); 983 984 Value betweenNegOneZero = 985 builder.create<arith::AndIOp>(greaterThanNegOne, lessThanZero); 986 987 trueVal = fma(bcast(floatCst(builder, 1.8656436928143307e+0, elementType)), 988 bcast(floatCst(builder, 1.6839188885261840e+0, elementType)), 989 builder.create<arith::NegFOp>(r)); 990 991 Value finalVal = 992 builder.create<arith::SelectOp>(betweenNegOneZero, trueVal, r); 993 994 rewriter.replaceOp(op, finalVal); 995 return success(); 996 } 997 998 //----------------------------------------------------------------------------// 999 // Erf approximation. 1000 //----------------------------------------------------------------------------// 1001 1002 // Approximates erf(x) with 1003 // a - P(x)/Q(x) 1004 // where P and Q are polynomials of degree 4. 
// Different coefficients are chosen based on the value of x.
// The approximation error is ~2.5e-07.
// Boost's minimax tool that utilizes the Remez method was used to find the
// coefficients.
LogicalResult
ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
                                            PatternRewriter &rewriter) const {
  Value operand = op.getOperand();
  Type elementType = getElementTypeOrSelf(operand);

  if (!(elementType.isF32() || elementType.isF16()))
    return rewriter.notifyMatchFailure(op,
                                       "only f32 and f16 type is supported.");
  std::optional<VectorShape> shape = vectorShape(operand);

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  const int intervalsCount = 3;
  const int polyDegree = 4;

  Value zero = bcast(floatCst(builder, 0, elementType));
  Value one = bcast(floatCst(builder, 1, elementType));
  // pp[j][i]: coefficient i of the numerator polynomial P for interval j.
  Value pp[intervalsCount][polyDegree + 1];
  pp[0][0] = bcast(floatCst(builder, +0.00000000000000000e+00f, elementType));
  pp[0][1] = bcast(floatCst(builder, +1.12837916222975858e+00f, elementType));
  pp[0][2] = bcast(floatCst(builder, -5.23018562988006470e-01f, elementType));
  pp[0][3] = bcast(floatCst(builder, +2.09741709609267072e-01f, elementType));
  pp[0][4] = bcast(floatCst(builder, +2.58146801602987875e-02f, elementType));
  pp[1][0] = bcast(floatCst(builder, +0.00000000000000000e+00f, elementType));
  pp[1][1] = bcast(floatCst(builder, +1.12750687816789140e+00f, elementType));
  pp[1][2] = bcast(floatCst(builder, -3.64721408487825775e-01f, elementType));
  pp[1][3] = bcast(floatCst(builder, +1.18407396425136952e-01f, elementType));
  pp[1][4] = bcast(floatCst(builder, +3.70645533056476558e-02f, elementType));
  pp[2][0] = bcast(floatCst(builder, -3.30093071049483172e-03f, elementType));
  pp[2][1] = bcast(floatCst(builder, +3.51961938357697011e-03f, elementType));
  pp[2][2] = bcast(floatCst(builder, -1.41373622814988039e-03f, elementType));
  pp[2][3] = bcast(floatCst(builder, +2.53447094961941348e-04f, elementType));
  pp[2][4] = bcast(floatCst(builder, -1.71048029455037401e-05f, elementType));

  // qq[j][i]: coefficient i of the denominator polynomial Q for interval j.
  Value qq[intervalsCount][polyDegree + 1];
  qq[0][0] = bcast(floatCst(builder, +1.000000000000000000e+00f, elementType));
  qq[0][1] = bcast(floatCst(builder, -4.635138185962547255e-01f, elementType));
  qq[0][2] = bcast(floatCst(builder, +5.192301327279782447e-01f, elementType));
  qq[0][3] = bcast(floatCst(builder, -1.318089722204810087e-01f, elementType));
  qq[0][4] = bcast(floatCst(builder, +7.397964654672315005e-02f, elementType));
  qq[1][0] = bcast(floatCst(builder, +1.00000000000000000e+00f, elementType));
  qq[1][1] = bcast(floatCst(builder, -3.27607011824493086e-01f, elementType));
  qq[1][2] = bcast(floatCst(builder, +4.48369090658821977e-01f, elementType));
  qq[1][3] = bcast(floatCst(builder, -8.83462621207857930e-02f, elementType));
  qq[1][4] = bcast(floatCst(builder, +5.72442770283176093e-02f, elementType));
  qq[2][0] = bcast(floatCst(builder, +1.00000000000000000e+00f, elementType));
  qq[2][1] = bcast(floatCst(builder, -2.06069165953913769e+00f, elementType));
  qq[2][2] = bcast(floatCst(builder, +1.62705939945477759e+00f, elementType));
  qq[2][3] = bcast(floatCst(builder, -5.83389859211130017e-01f, elementType));
  qq[2][4] = bcast(floatCst(builder, +8.21908939856640930e-02f, elementType));

  // Constant term added to P/Q for each interval.
  Value offsets[intervalsCount];
  offsets[0] = bcast(floatCst(builder, 0.0f, elementType));
  offsets[1] = bcast(floatCst(builder, 0.0f, elementType));
  offsets[2] = bcast(floatCst(builder, 1.0f, elementType));

  // Upper bound of each interval, compared against |operand|.
  Value bounds[intervalsCount];
  bounds[0] = bcast(floatCst(builder, 0.8f, elementType));
  bounds[1] = bcast(floatCst(builder, 2.0f, elementType));
  bounds[2] = bcast(floatCst(builder, 3.75f, elementType));

  // Work on x = |operand|; the sign is re-applied at the end.
  Value isNegativeArg =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero);
  Value negArg = builder.create<arith::NegFOp>(operand);
  Value x = builder.create<arith::SelectOp>(isNegativeArg, negArg, operand);

  // Start from the coefficients of interval 0 and switch to the next
  // interval's coefficients whenever x is not below the current bound.
  Value offset = offsets[0];
  Value p[polyDegree + 1];
  Value q[polyDegree + 1];
  for (int i = 0; i <= polyDegree; ++i) {
    p[i] = pp[0][i];
    q[i] = qq[0][i];
  }

  // TODO: maybe use vector stacking to reduce the number of selects.
  Value isLessThanBound[intervalsCount];
  for (int j = 0; j < intervalsCount - 1; ++j) {
    isLessThanBound[j] =
        builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, bounds[j]);
    for (int i = 0; i <= polyDegree; ++i) {
      p[i] = builder.create<arith::SelectOp>(isLessThanBound[j], p[i],
                                             pp[j + 1][i]);
      q[i] = builder.create<arith::SelectOp>(isLessThanBound[j], q[i],
                                             qq[j + 1][i]);
    }
    offset = builder.create<arith::SelectOp>(isLessThanBound[j], offset,
                                             offsets[j + 1]);
  }
  // The last bound uses the unordered ULT predicate, which is also true for
  // NaN, so a NaN input selects `formula` below rather than the constant 1.
  isLessThanBound[intervalsCount - 1] = builder.create<arith::CmpFOp>(
      arith::CmpFPredicate::ULT, x, bounds[intervalsCount - 1]);

  // Note: the value is formed as offset + P(x)/Q(x).
  Value pPoly = makePolynomialCalculation(builder, p, x);
  Value qPoly = makePolynomialCalculation(builder, q, x);
  Value rationalPoly = builder.create<arith::DivFOp>(pPoly, qPoly);
  Value formula = builder.create<arith::AddFOp>(offset, rationalPoly);
  // At or beyond the last bound the result is the constant 1.
  formula = builder.create<arith::SelectOp>(isLessThanBound[intervalsCount - 1],
                                            formula, one);

  // erf is odd function: erf(x) = -erf(-x).
  Value negFormula = builder.create<arith::NegFOp>(formula);
  Value res =
      builder.create<arith::SelectOp>(isNegativeArg, negFormula, formula);

  rewriter.replaceOp(op, res);

  return success();
}

//----------------------------------------------------------------------------//
// Exp approximation.
//----------------------------------------------------------------------------//

namespace {

// Clamps `value` to [lowerBound, upperBound] with compare+select pairs.
// The bounds must not be NaN (asserted). Both comparisons use unordered
// predicates, so a NaN `value` passes through unchanged.
Value clampWithNormals(ImplicitLocOpBuilder &builder,
                       const std::optional<VectorShape> shape, Value value,
                       float lowerBound, float upperBound) {
  assert(!std::isnan(lowerBound));
  assert(!std::isnan(upperBound));

  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  // select(cmp(pred, value, bound), value, bound)
  auto selectCmp = [&builder](auto pred, Value value, Value bound) {
    return builder.create<arith::SelectOp>(
        builder.create<arith::CmpFOp>(pred, value, bound), value, bound);
  };

  // Note: prefer UGE/ULE vs. UGT/ULT, since they generate vmaxps/vminps vs.
  // vcmpleps+vmovaps on x86_64. The latter outcome is also obtained with
  // arith::{Max,Min}FOp.
  value = selectCmp(arith::CmpFPredicate::UGE, value,
                    bcast(f32Cst(builder, lowerBound)));
  value = selectCmp(arith::CmpFPredicate::ULE, value,
                    bcast(f32Cst(builder, upperBound)));
  return value;
}

// Rewrites math.exp into arith/math ops (f32 only).
struct ExpApproximation : public OpRewritePattern<math::ExpOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::ExpOp op,
                                PatternRewriter &rewriter) const final;
};

LogicalResult
ExpApproximation::matchAndRewrite(math::ExpOp op,
                                  PatternRewriter &rewriter) const {
  auto shape = vectorShape(op.getOperand().getType());
  auto elementTy = getElementTypeOrSelf(op.getType());
  if (!elementTy.isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);

  auto add = [&](Value a, Value b) -> Value {
    return builder.create<arith::AddFOp>(a, b);
  };
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };
  auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); };
  auto fmla = [&](Value a, Value b, Value c) {
    return builder.create<math::FmaOp>(a, b, c);
  };
  auto mul = [&](Value a, Value b) -> Value {
    return builder.create<arith::MulFOp>(a, b);
  };

  // Polynomial approximation from Cephes.
  //
  // To compute e^x, we re-express it as
  //
  //   e^x = e^(a + b)
  //       = e^(a + n log(2))
  //       = e^a * 2^n.
  //
  // We choose n = round(x / log(2)), restricting the value of `a` to
  // (-log(2)/2, log(2)/2).  We then use a polynomial to compute e^a. The
  // relative error between our approximation and the true value of e^a is less
  // than 2^-22.5 for all values of `a` within this range.

  // Restrict input to a small range, including some values that evaluate to
  // +/- inf.  Note that for our lower bound, we choose log(2^-126) instead of
  // log(F32_EPSILON). We do so because this routine always flushes denormal
  // floating points to 0. Therefore, we only need to worry about exponentiating
  // up to the smallest representable non-denormal floating point, which is
  // 2^-126.

  // Constants.
  Value cstHalf = bcast(f32Cst(builder, 0.5f));
  Value cstOne = bcast(f32Cst(builder, 1.0f));

  // 1/log(2)
  Value cstLog2ef = bcast(f32Cst(builder, 1.44269504088896341f));

  // cstExpC1 + cstExpC2 ~= -log(2); the constant is applied in two parts
  // when computing `a` below.
  Value cstExpC1 = bcast(f32Cst(builder, -0.693359375f));
  Value cstExpC2 = bcast(f32Cst(builder, 2.12194440e-4f));
  Value cstExpP0 = bcast(f32Cst(builder, 1.9875691500E-4f));
  Value cstExpP1 = bcast(f32Cst(builder, 1.3981999507E-3f));
  Value cstExpP2 = bcast(f32Cst(builder, 8.3334519073E-3f));
  Value cstExpP3 = bcast(f32Cst(builder, 4.1665795894E-2f));
  Value cstExpP4 = bcast(f32Cst(builder, 1.6666665459E-1f));
  Value cstExpP5 = bcast(f32Cst(builder, 5.0000001201E-1f));

  // Our computations below aren't particularly sensitive to the exact choices
  // here, so we choose values a bit larger/smaller than
  //
  //   log(F32_MAX) = 88.723...
  //   log(2^-126) = -87.337...
  Value x = op.getOperand();
  x = clampWithNormals(builder, shape, x, -87.8f, 88.8f);
  // n = floor(x / log(2) + 0.5), i.e. round-to-nearest.
  Value n = floor(fmla(x, cstLog2ef, cstHalf));

  // When we eventually do the multiplication in e^a * 2^n, we need to handle
  // the case when n > 127, the max fp32 exponent (so 2^n == inf) but e^a < 1
  // (so e^a * 2^n != inf).  There's a similar problem for n < -126, the
  // smallest fp32 exponent.
  //
  // A straightforward solution would be to detect n out of range and split it
  // up, doing
  //
  //   e^a * 2^n = e^a * 2^(n1 + n2)
  //             = (2^n1 * e^a) * 2^n2.
  //
  // But it turns out this approach is quite slow, probably because it
  // manipulates subnormal values.
  //
  // The approach we use instead is to clamp n to [-127, 127]. Let n' be the
  // value of n clamped to [-127, 127]. In the case where n' = 127, `a` can grow
  // up to as large as 88.8 - 127 * log(2) which is about 0.7703. Even though
  // this value of `a` is outside our previously specified range, e^a will still
  // only have a relative error of approximately 2^-16 at worst. In practice
  // this seems to work well enough; it passes our exhaustive tests, breaking
  // only one result, and by one ulp (we return exp(88.7228394) = max-float but
  // we should return inf).
  //
  // In the case where n' = -127, the original input value of x is so small that
  // e^x, our final answer, is less than 2^-126. Since 2^-126 is the smallest
  // normal floating point, and since we flush denormals, we simply return 0. We
  // do this in a branchless way by observing that our code for constructing 2^n
  // produces 0 if n = -127.
  //
  // The proof that n' = -127 implies e^x < 2^-126 is as follows:
  //
  //    n' = -127 implies n <= -127
  //              implies round(x / log(2)) <= -127
  //              implies x/log(2) < -126.5
  //              implies x < -126.5 * log(2)
  //              implies e^x < e^(-126.5 * log(2))
  //              implies e^x < 2^-126.5 < 2^-126
  //
  //    This proves that n' = -127 implies e^x < 2^-126.
  n = clampWithNormals(builder, shape, n, -127.0f, 127.0f);

  // Computes x = x - n' * log(2), the value for `a`
  x = fmla(cstExpC1, n, x);
  x = fmla(cstExpC2, n, x);

  // Polynomial to compute z = e^a, accurate for a in (-0.5, 0.5).
  // Evaluated as 1 + a + a^2 * P(a), with P built by the FMA chain below.
  Value z = fmla(x, cstExpP0, cstExpP1);
  z = fmla(z, x, cstExpP2);
  z = fmla(z, x, cstExpP3);
  z = fmla(z, x, cstExpP4);
  z = fmla(z, x, cstExpP5);
  z = fmla(z, mul(x, x), x);
  z = add(cstOne, z);

  // Convert n' to an i32.  This is safe because we clamped it above.
  auto i32Vec = broadcast(builder.getI32Type(), shape);
  Value nI32 = builder.create<arith::FPToSIOp>(i32Vec, n);

  // Creates the value 2^n' if -126 <= n' <= 127 and 0 if n' = -127.
  Value pow2 = exp2I32(builder, nI32);

  // Return z * 2^n' if -126 <= n' <= 127 and 0 if n = -127.
  Value ret = mul(z, pow2);

  rewriter.replaceOp(op, ret);
  return mlir::success();
}

} // namespace

//----------------------------------------------------------------------------//
// ExpM1 approximation.
//----------------------------------------------------------------------------//

namespace {

// Rewrites math.expm1 in terms of math.exp, math.log and arith ops (f32 only).
struct ExpM1Approximation : public OpRewritePattern<math::ExpM1Op> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::ExpM1Op op,
                                PatternRewriter &rewriter) const final;
};
} // namespace

// Fails to match unless the operand element type is f32. The math.exp and
// math.log ops created here are left in the IR for other patterns to expand.
LogicalResult
ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
                                    PatternRewriter &rewriter) const {
  if (!getElementTypeOrSelf(op.getOperand()).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  std::optional<VectorShape> shape = vectorShape(op.getOperand());

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  // expm1(x) = exp(x) - 1 = u - 1.
  // We have to handle it carefully when x is near 0, i.e. u ~= 1,
  // and when the input is ~= -inf, i.e. u - 1 ~= -1.
  Value cstOne = bcast(f32Cst(builder, 1.0f));
  Value cstNegOne = bcast(f32Cst(builder, -1.0f));
  Value x = op.getOperand();
  Value u = builder.create<math::ExpOp>(x);
  // UEQ is unordered: true when u == 1 or when u is NaN. In both cases the
  // final select returns x itself.
  Value uEqOneOrNaN =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::UEQ, u, cstOne);
  Value uMinusOne = builder.create<arith::SubFOp>(u, cstOne);
  Value uMinusOneEqNegOne = builder.create<arith::CmpFOp>(
      arith::CmpFPredicate::OEQ, uMinusOne, cstNegOne);
  // logU = log(u) ~= x
  Value logU = builder.create<math::LogOp>(u);

  // Detect exp(x) = +inf; written this way to avoid having to form +inf.
  Value isInf =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, logU, u);

  // (u - 1) * (x / ~x)
  Value expm1 = builder.create<arith::MulFOp>(
      uMinusOne, builder.create<arith::DivFOp>(x, logU));
  expm1 = builder.create<arith::SelectOp>(isInf, u, expm1);
  // Priority of the special cases: u == 1 or NaN -> x; u - 1 == -1 -> -1;
  // otherwise the corrected value (or u when it is +inf).
  Value approximation = builder.create<arith::SelectOp>(
      uEqOneOrNaN, x,
      builder.create<arith::SelectOp>(uMinusOneEqNegOne, cstNegOne, expm1));
  rewriter.replaceOp(op, approximation);
  return success();
}

//----------------------------------------------------------------------------//
// Sin and Cos approximation.
1359 //----------------------------------------------------------------------------// 1360 1361 namespace { 1362 1363 template <bool isSine, typename OpTy> 1364 struct SinAndCosApproximation : public OpRewritePattern<OpTy> { 1365 public: 1366 using OpRewritePattern<OpTy>::OpRewritePattern; 1367 1368 LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const final; 1369 }; 1370 } // namespace 1371 1372 #define TWO_OVER_PI \ 1373 0.6366197723675813430755350534900574481378385829618257949906693762L 1374 #define PI_OVER_2 \ 1375 1.5707963267948966192313216916397514420985846996875529104874722961L 1376 1377 // Approximates sin(x) or cos(x) by finding the best approximation polynomial in 1378 // the reduced range [0, pi/2] for both sin(x) and cos(x). Then given y in the 1379 // reduced range sin(x) will be computed as sin(y), -sin(y), cos(y) or -cos(y). 1380 template <bool isSine, typename OpTy> 1381 LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite( 1382 OpTy op, PatternRewriter &rewriter) const { 1383 static_assert( 1384 llvm::is_one_of<OpTy, math::SinOp, math::CosOp>::value, 1385 "SinAndCosApproximation pattern expects math::SinOp or math::CosOp"); 1386 1387 if (!getElementTypeOrSelf(op.getOperand()).isF32()) 1388 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 1389 1390 std::optional<VectorShape> shape = vectorShape(op.getOperand()); 1391 1392 ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 1393 auto bcast = [&](Value value) -> Value { 1394 return broadcast(builder, value, shape); 1395 }; 1396 auto mul = [&](Value a, Value b) -> Value { 1397 return builder.create<arith::MulFOp>(a, b); 1398 }; 1399 auto sub = [&](Value a, Value b) -> Value { 1400 return builder.create<arith::SubFOp>(a, b); 1401 }; 1402 auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); }; 1403 1404 auto i32Vec = broadcast(builder.getI32Type(), shape); 1405 auto fPToSingedInteger = [&](Value a) -> Value { 1406 return 
builder.create<arith::FPToSIOp>(i32Vec, a); 1407 }; 1408 1409 auto modulo4 = [&](Value a) -> Value { 1410 return builder.create<arith::AndIOp>(a, bcast(i32Cst(builder, 3))); 1411 }; 1412 1413 auto isEqualTo = [&](Value a, Value b) -> Value { 1414 return builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, a, b); 1415 }; 1416 1417 auto isGreaterThan = [&](Value a, Value b) -> Value { 1418 return builder.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, a, b); 1419 }; 1420 1421 auto select = [&](Value cond, Value t, Value f) -> Value { 1422 return builder.create<arith::SelectOp>(cond, t, f); 1423 }; 1424 1425 auto fmla = [&](Value a, Value b, Value c) { 1426 return builder.create<math::FmaOp>(a, b, c); 1427 }; 1428 1429 auto bitwiseOr = [&](Value a, Value b) { 1430 return builder.create<arith::OrIOp>(a, b); 1431 }; 1432 1433 Value twoOverPi = bcast(f32Cst(builder, (float)TWO_OVER_PI)); 1434 Value piOverTwo = bcast(f32Cst(builder, (float)PI_OVER_2)); 1435 1436 Value x = op.getOperand(); 1437 1438 Value k = floor(mul(x, twoOverPi)); 1439 1440 Value y = sub(x, mul(k, piOverTwo)); 1441 1442 Value cstOne = bcast(f32Cst(builder, 1.0)); 1443 Value cstNegativeOne = bcast(f32Cst(builder, -1.0)); 1444 1445 Value cstSC2 = bcast(f32Cst(builder, -0.16666667163372039794921875f)); 1446 Value cstSC4 = bcast(f32Cst(builder, 8.333347737789154052734375e-3f)); 1447 Value cstSC6 = bcast(f32Cst(builder, -1.9842604524455964565277099609375e-4f)); 1448 Value cstSC8 = 1449 bcast(f32Cst(builder, 2.760012648650445044040679931640625e-6f)); 1450 Value cstSC10 = 1451 bcast(f32Cst(builder, -2.50293279435709337121807038784027099609375e-8f)); 1452 1453 Value cstCC2 = bcast(f32Cst(builder, -0.5f)); 1454 Value cstCC4 = bcast(f32Cst(builder, 4.166664183139801025390625e-2f)); 1455 Value cstCC6 = bcast(f32Cst(builder, -1.388833043165504932403564453125e-3f)); 1456 Value cstCC8 = bcast(f32Cst(builder, 2.47562347794882953166961669921875e-5f)); 1457 Value cstCC10 = 1458 bcast(f32Cst(builder, 
-2.59630184018533327616751194000244140625e-7f)); 1459 1460 Value kMod4 = modulo4(fPToSingedInteger(k)); 1461 1462 Value kR0 = isEqualTo(kMod4, bcast(i32Cst(builder, 0))); 1463 Value kR1 = isEqualTo(kMod4, bcast(i32Cst(builder, 1))); 1464 Value kR2 = isEqualTo(kMod4, bcast(i32Cst(builder, 2))); 1465 Value kR3 = isEqualTo(kMod4, bcast(i32Cst(builder, 3))); 1466 1467 Value sinuseCos = isSine ? bitwiseOr(kR1, kR3) : bitwiseOr(kR0, kR2); 1468 Value negativeRange = isSine ? isGreaterThan(kMod4, bcast(i32Cst(builder, 1))) 1469 : bitwiseOr(kR1, kR2); 1470 1471 Value y2 = mul(y, y); 1472 1473 Value base = select(sinuseCos, cstOne, y); 1474 Value cstC2 = select(sinuseCos, cstCC2, cstSC2); 1475 Value cstC4 = select(sinuseCos, cstCC4, cstSC4); 1476 Value cstC6 = select(sinuseCos, cstCC6, cstSC6); 1477 Value cstC8 = select(sinuseCos, cstCC8, cstSC8); 1478 Value cstC10 = select(sinuseCos, cstCC10, cstSC10); 1479 1480 Value v1 = fmla(y2, cstC10, cstC8); 1481 Value v2 = fmla(y2, v1, cstC6); 1482 Value v3 = fmla(y2, v2, cstC4); 1483 Value v4 = fmla(y2, v3, cstC2); 1484 Value v5 = fmla(y2, v4, cstOne); 1485 Value v6 = mul(base, v5); 1486 1487 Value approximation = select(negativeRange, mul(cstNegativeOne, v6), v6); 1488 1489 rewriter.replaceOp(op, approximation); 1490 1491 return success(); 1492 } 1493 1494 //----------------------------------------------------------------------------// 1495 // Cbrt approximation. 1496 //----------------------------------------------------------------------------// 1497 1498 namespace { 1499 struct CbrtApproximation : public OpRewritePattern<math::CbrtOp> { 1500 using OpRewritePattern::OpRewritePattern; 1501 1502 LogicalResult matchAndRewrite(math::CbrtOp op, 1503 PatternRewriter &rewriter) const final; 1504 }; 1505 } // namespace 1506 1507 // Estimation of cube-root using an algorithm defined in 1508 // Hacker's Delight 2nd Edition. 
1509 LogicalResult 1510 CbrtApproximation::matchAndRewrite(math::CbrtOp op, 1511 PatternRewriter &rewriter) const { 1512 auto operand = op.getOperand(); 1513 if (!getElementTypeOrSelf(operand).isF32()) 1514 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 1515 1516 ImplicitLocOpBuilder b(op->getLoc(), rewriter); 1517 std::optional<VectorShape> shape = vectorShape(operand); 1518 1519 Type floatTy = getElementTypeOrSelf(operand.getType()); 1520 Type intTy = b.getIntegerType(floatTy.getIntOrFloatBitWidth()); 1521 1522 // Convert to vector types if necessary. 1523 floatTy = broadcast(floatTy, shape); 1524 intTy = broadcast(intTy, shape); 1525 1526 auto bconst = [&](TypedAttr attr) -> Value { 1527 Value value = b.create<arith::ConstantOp>(attr); 1528 return broadcast(b, value, shape); 1529 }; 1530 1531 // Declare the initial values: 1532 Value intTwo = bconst(b.getI32IntegerAttr(2)); 1533 Value intFour = bconst(b.getI32IntegerAttr(4)); 1534 Value intEight = bconst(b.getI32IntegerAttr(8)); 1535 Value intMagic = bconst(b.getI32IntegerAttr(0x2a5137a0)); 1536 Value fpThird = bconst(b.getF32FloatAttr(0.33333333f)); 1537 Value fpTwo = bconst(b.getF32FloatAttr(2.0f)); 1538 Value fpZero = bconst(b.getF32FloatAttr(0.0f)); 1539 1540 // Compute an approximation of one third: 1541 // union {int ix; float x;}; 1542 // x = x0; 1543 // ix = ix/4 + ix/16; 1544 Value absValue = b.create<math::AbsFOp>(operand); 1545 Value intValue = b.create<arith::BitcastOp>(intTy, absValue); 1546 Value divideBy4 = b.create<arith::ShRSIOp>(intValue, intTwo); 1547 Value divideBy16 = b.create<arith::ShRSIOp>(intValue, intFour); 1548 intValue = b.create<arith::AddIOp>(divideBy4, divideBy16); 1549 1550 // ix = ix + ix/16; 1551 divideBy16 = b.create<arith::ShRSIOp>(intValue, intFour); 1552 intValue = b.create<arith::AddIOp>(intValue, divideBy16); 1553 1554 // ix = ix + ix/256; 1555 Value divideBy256 = b.create<arith::ShRSIOp>(intValue, intEight); 1556 intValue = 
b.create<arith::AddIOp>(intValue, divideBy256); 1557 1558 // ix = 0x2a5137a0 + ix; 1559 intValue = b.create<arith::AddIOp>(intValue, intMagic); 1560 1561 // Perform one newtons step: 1562 // x = 0.33333333f*(2.0f*x + x0/(x*x)); 1563 Value floatValue = b.create<arith::BitcastOp>(floatTy, intValue); 1564 Value squared = b.create<arith::MulFOp>(floatValue, floatValue); 1565 Value mulTwo = b.create<arith::MulFOp>(floatValue, fpTwo); 1566 Value divSquared = b.create<arith::DivFOp>(absValue, squared); 1567 floatValue = b.create<arith::AddFOp>(mulTwo, divSquared); 1568 floatValue = b.create<arith::MulFOp>(floatValue, fpThird); 1569 1570 // x = 0.33333333f*(2.0f*x + x0/(x*x)); 1571 squared = b.create<arith::MulFOp>(floatValue, floatValue); 1572 mulTwo = b.create<arith::MulFOp>(floatValue, fpTwo); 1573 divSquared = b.create<arith::DivFOp>(absValue, squared); 1574 floatValue = b.create<arith::AddFOp>(mulTwo, divSquared); 1575 floatValue = b.create<arith::MulFOp>(floatValue, fpThird); 1576 1577 // Check for zero and restore sign. 1578 Value isZero = 1579 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absValue, fpZero); 1580 floatValue = b.create<arith::SelectOp>(isZero, fpZero, floatValue); 1581 floatValue = b.create<math::CopySignOp>(floatValue, operand); 1582 1583 rewriter.replaceOp(op, floatValue); 1584 return success(); 1585 } 1586 1587 //----------------------------------------------------------------------------// 1588 // Rsqrt approximation. 
1589 //----------------------------------------------------------------------------// 1590 1591 namespace { 1592 struct RsqrtApproximation : public OpRewritePattern<math::RsqrtOp> { 1593 using OpRewritePattern::OpRewritePattern; 1594 1595 LogicalResult matchAndRewrite(math::RsqrtOp op, 1596 PatternRewriter &rewriter) const final; 1597 }; 1598 } // namespace 1599 1600 LogicalResult 1601 RsqrtApproximation::matchAndRewrite(math::RsqrtOp op, 1602 PatternRewriter &rewriter) const { 1603 if (!getElementTypeOrSelf(op.getOperand()).isF32()) 1604 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 1605 1606 std::optional<VectorShape> shape = vectorShape(op.getOperand()); 1607 1608 // Only support already-vectorized rsqrt's. 1609 if (!shape || shape->sizes.empty() || shape->sizes.back() % 8 != 0) 1610 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 1611 1612 ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 1613 auto bcast = [&](Value value) -> Value { 1614 return broadcast(builder, value, shape); 1615 }; 1616 1617 Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u)); 1618 Value cstOnePointFive = bcast(f32Cst(builder, 1.5f)); 1619 Value cstNegHalf = bcast(f32Cst(builder, -0.5f)); 1620 Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u)); 1621 1622 Value negHalf = builder.create<arith::MulFOp>(op.getOperand(), cstNegHalf); 1623 1624 // Select only the inverse sqrt of positive normals (denormals are 1625 // flushed to zero). 1626 Value ltMinMask = builder.create<arith::CmpFOp>( 1627 arith::CmpFPredicate::OLT, op.getOperand(), cstMinNormPos); 1628 Value infMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, 1629 op.getOperand(), cstPosInf); 1630 Value notNormalFiniteMask = builder.create<arith::OrIOp>(ltMinMask, infMask); 1631 1632 // Compute an approximate result. 
1633 Value yApprox = handleMultidimensionalVectors( 1634 builder, op->getOperands(), 8, [&builder](ValueRange operands) -> Value { 1635 return builder.create<x86vector::RsqrtOp>(operands); 1636 }); 1637 1638 // Do a single step of Newton-Raphson iteration to improve the approximation. 1639 // This uses the formula y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n). 1640 // It is essential to evaluate the inner term like this because forming 1641 // y_n^2 may over- or underflow. 1642 Value inner = builder.create<arith::MulFOp>(negHalf, yApprox); 1643 Value fma = builder.create<math::FmaOp>(yApprox, inner, cstOnePointFive); 1644 Value yNewton = builder.create<arith::MulFOp>(yApprox, fma); 1645 1646 // Select the result of the Newton-Raphson step for positive normal arguments. 1647 // For other arguments, choose the output of the intrinsic. This will 1648 // return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(x) = +inf if 1649 // x is zero or a positive denormalized float (equivalent to flushing positive 1650 // denormalized inputs to zero). 1651 Value res = 1652 builder.create<arith::SelectOp>(notNormalFiniteMask, yApprox, yNewton); 1653 rewriter.replaceOp(op, res); 1654 1655 return success(); 1656 } 1657 1658 //----------------------------------------------------------------------------// 1659 1660 void mlir::populatePolynomialApproximateTanhPattern( 1661 RewritePatternSet &patterns) { 1662 patterns.add<TanhApproximation>(patterns.getContext()); 1663 } 1664 1665 void mlir::populatePolynomialApproximateErfPattern( 1666 RewritePatternSet &patterns) { 1667 patterns.add<ErfPolynomialApproximation>(patterns.getContext()); 1668 } 1669 1670 void mlir::populateMathPolynomialApproximationPatterns( 1671 RewritePatternSet &patterns, 1672 const MathPolynomialApproximationOptions &options) { 1673 // Patterns for leveraging existing f32 lowerings on other data types. 
1674 patterns 1675 .add<ReuseF32Expansion<math::AtanOp>, ReuseF32Expansion<math::Atan2Op>, 1676 ReuseF32Expansion<math::TanhOp>, ReuseF32Expansion<math::LogOp>, 1677 ReuseF32Expansion<math::Log2Op>, ReuseF32Expansion<math::Log1pOp>, 1678 ReuseF32Expansion<math::ErfOp>, ReuseF32Expansion<math::ExpOp>, 1679 ReuseF32Expansion<math::ExpM1Op>, ReuseF32Expansion<math::CbrtOp>, 1680 ReuseF32Expansion<math::SinOp>, ReuseF32Expansion<math::CosOp>>( 1681 patterns.getContext()); 1682 1683 patterns 1684 .add<AtanApproximation, Atan2Approximation, TanhApproximation, 1685 LogApproximation, Log2Approximation, Log1pApproximation, 1686 ErfPolynomialApproximation, AsinPolynomialApproximation, 1687 AcosPolynomialApproximation, ExpApproximation, ExpM1Approximation, 1688 CbrtApproximation, SinAndCosApproximation<true, math::SinOp>, 1689 SinAndCosApproximation<false, math::CosOp>>(patterns.getContext()); 1690 if (options.enableAvx2) { 1691 patterns.add<RsqrtApproximation, ReuseF32Expansion<math::RsqrtOp>>( 1692 patterns.getContext()); 1693 } 1694 } 1695