1 //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" 10 11 #include "mlir/Dialect/Arith/IR/Arith.h" 12 #include "mlir/Dialect/Complex/IR/Complex.h" 13 #include "mlir/Dialect/Math/IR/Math.h" 14 #include "mlir/IR/ImplicitLocOpBuilder.h" 15 #include "mlir/IR/PatternMatch.h" 16 #include "mlir/Pass/Pass.h" 17 #include "mlir/Transforms/DialectConversion.h" 18 #include <memory> 19 #include <type_traits> 20 21 namespace mlir { 22 #define GEN_PASS_DEF_CONVERTCOMPLEXTOSTANDARD 23 #include "mlir/Conversion/Passes.h.inc" 24 } // namespace mlir 25 26 using namespace mlir; 27 28 namespace { 29 30 enum class AbsFn { abs, sqrt, rsqrt }; 31 32 // Returns the absolute value, its square root or its reciprocal square root. 33 Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf, 34 ImplicitLocOpBuilder &b, AbsFn fn = AbsFn::abs) { 35 Value one = b.create<arith::ConstantOp>(real.getType(), 36 b.getFloatAttr(real.getType(), 1.0)); 37 38 Value absReal = b.create<math::AbsFOp>(real, fmf); 39 Value absImag = b.create<math::AbsFOp>(imag, fmf); 40 41 Value max = b.create<arith::MaximumFOp>(absReal, absImag, fmf); 42 Value min = b.create<arith::MinimumFOp>(absReal, absImag, fmf); 43 44 // The lowering below requires NaNs and infinities to work correctly. 45 arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear( 46 fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf); 47 Value ratio = b.create<arith::DivFOp>(min, max, fmfWithNaNInf); 48 Value ratioSq = b.create<arith::MulFOp>(ratio, ratio, fmfWithNaNInf); 49 Value ratioSqPlusOne = b.create<arith::AddFOp>(ratioSq, one, fmfWithNaNInf); 50 Value result; 51 52 if (fn == AbsFn::rsqrt) { 53 ratioSqPlusOne = b.create<math::RsqrtOp>(ratioSqPlusOne, fmfWithNaNInf); 54 min = b.create<math::RsqrtOp>(min, fmfWithNaNInf); 55 max = b.create<math::RsqrtOp>(max, fmfWithNaNInf); 56 } 57 58 if (fn == AbsFn::sqrt) { 59 Value quarter = b.create<arith::ConstantOp>( 60 real.getType(), b.getFloatAttr(real.getType(), 0.25)); 61 // sqrt(sqrt(a*b)) would avoid the pow, but will overflow more easily. 62 Value sqrt = b.create<math::SqrtOp>(max, fmfWithNaNInf); 63 Value p025 = b.create<math::PowFOp>(ratioSqPlusOne, quarter, fmfWithNaNInf); 64 result = b.create<arith::MulFOp>(sqrt, p025, fmfWithNaNInf); 65 } else { 66 Value sqrt = b.create<math::SqrtOp>(ratioSqPlusOne, fmfWithNaNInf); 67 result = b.create<arith::MulFOp>(max, sqrt, fmfWithNaNInf); 68 } 69 70 Value isNaN = b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, result, 71 result, fmfWithNaNInf); 72 return b.create<arith::SelectOp>(isNaN, min, result); 73 } 74 75 struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> { 76 using OpConversionPattern<complex::AbsOp>::OpConversionPattern; 77 78 LogicalResult 79 matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor, 80 ConversionPatternRewriter &rewriter) const override { 81 ImplicitLocOpBuilder b(op.getLoc(), rewriter); 82 83 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue(); 84 85 Value real = b.create<complex::ReOp>(adaptor.getComplex()); 86 Value imag = b.create<complex::ImOp>(adaptor.getComplex()); 87 rewriter.replaceOp(op, computeAbs(real, imag, fmf, b)); 88 89 return success(); 90 } 91 }; 92 93 // atan2(y,x) = -i * log((x + i * y)/sqrt(x**2+y**2)) 94 struct Atan2OpConversion : public OpConversionPattern<complex::Atan2Op> { 95 using OpConversionPattern<complex::Atan2Op>::OpConversionPattern; 96 97 LogicalResult 98 matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor, 99 ConversionPatternRewriter &rewriter) const override { 100 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 101 102 auto type = cast<ComplexType>(op.getType()); 103 Type elementType = type.getElementType(); 104 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); 105 106 Value lhs = adaptor.getLhs(); 107 Value rhs = adaptor.getRhs(); 108 109 Value rhsSquared = b.create<complex::MulOp>(type, rhs, rhs, fmf); 110 Value lhsSquared = b.create<complex::MulOp>(type, lhs, lhs, fmf); 111 Value rhsSquaredPlusLhsSquared = 112 b.create<complex::AddOp>(type, rhsSquared, lhsSquared, fmf); 113 Value sqrtOfRhsSquaredPlusLhsSquared = 114 b.create<complex::SqrtOp>(type, rhsSquaredPlusLhsSquared, fmf); 115 116 Value zero = 117 b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType)); 118 Value one = b.create<arith::ConstantOp>(elementType, 119 b.getFloatAttr(elementType, 1)); 120 Value i = b.create<complex::CreateOp>(type, zero, one); 121 Value iTimesLhs = b.create<complex::MulOp>(i, lhs, fmf); 122 Value rhsPlusILhs = b.create<complex::AddOp>(rhs, iTimesLhs, fmf); 123 124 Value divResult = b.create<complex::DivOp>( 125 rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared, fmf); 126 Value logResult = b.create<complex::LogOp>(divResult, fmf); 127 128 Value negativeOne = b.create<arith::ConstantOp>( 129 elementType, b.getFloatAttr(elementType, -1)); 130 Value negativeI = b.create<complex::CreateOp>(type, zero, negativeOne); 131 132 rewriter.replaceOpWithNewOp<complex::MulOp>(op, negativeI, logResult, fmf); 133 return success(); 134 } 135 }; 136 137 template <typename ComparisonOp, arith::CmpFPredicate p> 138 struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> { 139 using OpConversionPattern<ComparisonOp>::OpConversionPattern; 140 using ResultCombiner = 141 std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value, 142 arith::AndIOp, arith::OrIOp>; 143 144 LogicalResult 145 matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor, 146 ConversionPatternRewriter &rewriter) const override { 147 auto loc = op.getLoc(); 148 auto type = cast<ComplexType>(adaptor.getLhs().getType()).getElementType(); 149 150 Value realLhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getLhs()); 151 Value imagLhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getLhs()); 152 Value realRhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getRhs()); 153 Value imagRhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getRhs()); 154 Value realComparison = 155 rewriter.create<arith::CmpFOp>(loc, p, realLhs, realRhs); 156 Value imagComparison = 157 rewriter.create<arith::CmpFOp>(loc, p, imagLhs, imagRhs); 158 159 rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison, 160 imagComparison); 161 return success(); 162 } 163 }; 164 165 // Default conversion which applies the BinaryStandardOp separately on the real 166 // and imaginary parts. Can for example be used for complex::AddOp and 167 // complex::SubOp. 168 template <typename BinaryComplexOp, typename BinaryStandardOp> 169 struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> { 170 using OpConversionPattern<BinaryComplexOp>::OpConversionPattern; 171 172 LogicalResult 173 matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor, 174 ConversionPatternRewriter &rewriter) const override { 175 auto type = cast<ComplexType>(adaptor.getLhs().getType()); 176 auto elementType = cast<FloatType>(type.getElementType()); 177 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 178 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); 179 180 Value realLhs = b.create<complex::ReOp>(elementType, adaptor.getLhs()); 181 Value realRhs = b.create<complex::ReOp>(elementType, adaptor.getRhs()); 182 Value resultReal = b.create<BinaryStandardOp>(elementType, realLhs, realRhs, 183 fmf.getValue()); 184 Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.getLhs()); 185 Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.getRhs()); 186 Value resultImag = b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs, 187 fmf.getValue()); 188 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, 189 resultImag); 190 return success(); 191 } 192 }; 193 194 template <typename TrigonometricOp> 195 struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> { 196 using OpAdaptor = typename OpConversionPattern<TrigonometricOp>::OpAdaptor; 197 198 using OpConversionPattern<TrigonometricOp>::OpConversionPattern; 199 200 LogicalResult 201 matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor, 202 ConversionPatternRewriter &rewriter) const override { 203 auto loc = op.getLoc(); 204 auto type = cast<ComplexType>(adaptor.getComplex().getType()); 205 auto elementType = cast<FloatType>(type.getElementType()); 206 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); 207 208 Value real = 209 rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex()); 210 Value imag = 211 rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex()); 212 213 // Trigonometric ops use a set of common building blocks to convert to real 214 // ops. Here we create these building blocks and call into an op-specific 215 // implementation in the subclass to combine them. 216 Value half = rewriter.create<arith::ConstantOp>( 217 loc, elementType, rewriter.getFloatAttr(elementType, 0.5)); 218 Value exp = rewriter.create<math::ExpOp>(loc, imag, fmf); 219 Value scaledExp = rewriter.create<arith::MulFOp>(loc, half, exp, fmf); 220 Value reciprocalExp = rewriter.create<arith::DivFOp>(loc, half, exp, fmf); 221 Value sin = rewriter.create<math::SinOp>(loc, real, fmf); 222 Value cos = rewriter.create<math::CosOp>(loc, real, fmf); 223 224 auto resultPair = 225 combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter, fmf); 226 227 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultPair.first, 228 resultPair.second); 229 return success(); 230 } 231 232 virtual std::pair<Value, Value> 233 combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin, 234 Value cos, ConversionPatternRewriter &rewriter, 235 arith::FastMathFlagsAttr fmf) const = 0; 236 }; 237 238 struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> { 239 using TrigonometricOpConversion<complex::CosOp>::TrigonometricOpConversion; 240 241 std::pair<Value, Value> combine(Location loc, Value scaledExp, 242 Value reciprocalExp, Value sin, Value cos, 243 ConversionPatternRewriter &rewriter, 244 arith::FastMathFlagsAttr fmf) const override { 245 // Complex cosine is defined as; 246 // cos(x + iy) = 0.5 * (exp(i(x + iy)) + exp(-i(x + iy))) 247 // Plugging in: 248 // exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x)) 249 // exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x))) 250 // and defining t := exp(y) 251 // We get: 252 // Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x 253 // Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x 254 Value sum = 255 rewriter.create<arith::AddFOp>(loc, reciprocalExp, scaledExp, fmf); 256 Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, cos, fmf); 257 Value diff = 258 rewriter.create<arith::SubFOp>(loc, reciprocalExp, scaledExp, fmf); 259 Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, sin, fmf); 260 return {resultReal, resultImag}; 261 } 262 }; 263 264 struct DivOpConversion : public OpConversionPattern<complex::DivOp> { 265 using OpConversionPattern<complex::DivOp>::OpConversionPattern; 266 267 LogicalResult 268 matchAndRewrite(complex::DivOp op, OpAdaptor adaptor, 269 ConversionPatternRewriter &rewriter) const override { 270 auto loc = op.getLoc(); 271 auto type = cast<ComplexType>(adaptor.getLhs().getType()); 272 auto elementType = cast<FloatType>(type.getElementType()); 273 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); 274 275 Value lhsReal = 276 rewriter.create<complex::ReOp>(loc, elementType, adaptor.getLhs()); 277 Value lhsImag = 278 rewriter.create<complex::ImOp>(loc, elementType, adaptor.getLhs()); 279 Value rhsReal = 280 rewriter.create<complex::ReOp>(loc, elementType, adaptor.getRhs()); 281 Value rhsImag = 282 rewriter.create<complex::ImOp>(loc, elementType, adaptor.getRhs()); 283 284 // Smith's algorithm to divide complex numbers. It is just a bit smarter 285 // way to compute the following formula: 286 // (lhsReal + lhsImag * i) / (rhsReal + rhsImag * i) 287 // = (lhsReal + lhsImag * i) (rhsReal - rhsImag * i) / 288 // ((rhsReal + rhsImag * i)(rhsReal - rhsImag * i)) 289 // = ((lhsReal * rhsReal + lhsImag * rhsImag) + 290 // (lhsImag * rhsReal - lhsReal * rhsImag) * i) / ||rhs||^2 291 // 292 // Depending on whether |rhsReal| < |rhsImag| we compute either 293 // rhsRealImagRatio = rhsReal / rhsImag 294 // rhsRealImagDenom = rhsImag + rhsReal * rhsRealImagRatio 295 // resultReal = (lhsReal * rhsRealImagRatio + lhsImag) / rhsRealImagDenom 296 // resultImag = (lhsImag * rhsRealImagRatio - lhsReal) / rhsRealImagDenom 297 // 298 // or 299 // 300 // rhsImagRealRatio = rhsImag / rhsReal 301 // rhsImagRealDenom = rhsReal + rhsImag * rhsImagRealRatio 302 // resultReal = (lhsReal + lhsImag * rhsImagRealRatio) / rhsImagRealDenom 303 // resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom 304 // 305 // See https://dl.acm.org/citation.cfm?id=368661 for more details. 306 Value rhsRealImagRatio = 307 rewriter.create<arith::DivFOp>(loc, rhsReal, rhsImag, fmf); 308 Value rhsRealImagDenom = rewriter.create<arith::AddFOp>( 309 loc, rhsImag, 310 rewriter.create<arith::MulFOp>(loc, rhsRealImagRatio, rhsReal, fmf), 311 fmf); 312 Value realNumerator1 = rewriter.create<arith::AddFOp>( 313 loc, 314 rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealImagRatio, fmf), 315 lhsImag, fmf); 316 Value resultReal1 = rewriter.create<arith::DivFOp>(loc, realNumerator1, 317 rhsRealImagDenom, fmf); 318 Value imagNumerator1 = rewriter.create<arith::SubFOp>( 319 loc, 320 rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealImagRatio, fmf), 321 lhsReal, fmf); 322 Value resultImag1 = rewriter.create<arith::DivFOp>(loc, imagNumerator1, 323 rhsRealImagDenom, fmf); 324 325 Value rhsImagRealRatio = 326 rewriter.create<arith::DivFOp>(loc, rhsImag, rhsReal, fmf); 327 Value rhsImagRealDenom = rewriter.create<arith::AddFOp>( 328 loc, rhsReal, 329 rewriter.create<arith::MulFOp>(loc, rhsImagRealRatio, rhsImag, fmf), 330 fmf); 331 Value realNumerator2 = rewriter.create<arith::AddFOp>( 332 loc, lhsReal, 333 rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagRealRatio, fmf), 334 fmf); 335 Value resultReal2 = rewriter.create<arith::DivFOp>(loc, realNumerator2, 336 rhsImagRealDenom, fmf); 337 Value imagNumerator2 = rewriter.create<arith::SubFOp>( 338 loc, lhsImag, 339 rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagRealRatio, fmf), 340 fmf); 341 Value resultImag2 = rewriter.create<arith::DivFOp>(loc, imagNumerator2, 342 rhsImagRealDenom, fmf); 343 344 // Consider corner cases. 345 // Case 1. Zero denominator, numerator contains at most one NaN value. 346 Value zero = rewriter.create<arith::ConstantOp>( 347 loc, elementType, rewriter.getZeroAttr(elementType)); 348 Value rhsRealAbs = rewriter.create<math::AbsFOp>(loc, rhsReal, fmf); 349 Value rhsRealIsZero = rewriter.create<arith::CmpFOp>( 350 loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero); 351 Value rhsImagAbs = rewriter.create<math::AbsFOp>(loc, rhsImag, fmf); 352 Value rhsImagIsZero = rewriter.create<arith::CmpFOp>( 353 loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero); 354 Value lhsRealIsNotNaN = rewriter.create<arith::CmpFOp>( 355 loc, arith::CmpFPredicate::ORD, lhsReal, zero); 356 Value lhsImagIsNotNaN = rewriter.create<arith::CmpFOp>( 357 loc, arith::CmpFPredicate::ORD, lhsImag, zero); 358 Value lhsContainsNotNaNValue = 359 rewriter.create<arith::OrIOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN); 360 Value resultIsInfinity = rewriter.create<arith::AndIOp>( 361 loc, lhsContainsNotNaNValue, 362 rewriter.create<arith::AndIOp>(loc, rhsRealIsZero, rhsImagIsZero)); 363 Value inf = rewriter.create<arith::ConstantOp>( 364 loc, elementType, 365 rewriter.getFloatAttr( 366 elementType, APFloat::getInf(elementType.getFloatSemantics()))); 367 Value infWithSignOfRhsReal = 368 rewriter.create<math::CopySignOp>(loc, inf, rhsReal); 369 Value infinityResultReal = 370 rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsReal, fmf); 371 Value infinityResultImag = 372 rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsImag, fmf); 373 374 // Case 2. Infinite numerator, finite denominator. 375 Value rhsRealFinite = rewriter.create<arith::CmpFOp>( 376 loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf); 377 Value rhsImagFinite = rewriter.create<arith::CmpFOp>( 378 loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf); 379 Value rhsFinite = 380 rewriter.create<arith::AndIOp>(loc, rhsRealFinite, rhsImagFinite); 381 Value lhsRealAbs = rewriter.create<math::AbsFOp>(loc, lhsReal, fmf); 382 Value lhsRealInfinite = rewriter.create<arith::CmpFOp>( 383 loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf); 384 Value lhsImagAbs = rewriter.create<math::AbsFOp>(loc, lhsImag, fmf); 385 Value lhsImagInfinite = rewriter.create<arith::CmpFOp>( 386 loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf); 387 Value lhsInfinite = 388 rewriter.create<arith::OrIOp>(loc, lhsRealInfinite, lhsImagInfinite); 389 Value infNumFiniteDenom = 390 rewriter.create<arith::AndIOp>(loc, lhsInfinite, rhsFinite); 391 Value one = rewriter.create<arith::ConstantOp>( 392 loc, elementType, rewriter.getFloatAttr(elementType, 1)); 393 Value lhsRealIsInfWithSign = rewriter.create<math::CopySignOp>( 394 loc, rewriter.create<arith::SelectOp>(loc, lhsRealInfinite, one, zero), 395 lhsReal); 396 Value lhsImagIsInfWithSign = rewriter.create<math::CopySignOp>( 397 loc, rewriter.create<arith::SelectOp>(loc, lhsImagInfinite, one, zero), 398 lhsImag); 399 Value lhsRealIsInfWithSignTimesRhsReal = 400 rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsReal, fmf); 401 Value lhsImagIsInfWithSignTimesRhsImag = 402 rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsImag, fmf); 403 Value resultReal3 = rewriter.create<arith::MulFOp>( 404 loc, inf, 405 rewriter.create<arith::AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal, 406 lhsImagIsInfWithSignTimesRhsImag, fmf), 407 fmf); 408 Value lhsRealIsInfWithSignTimesRhsImag = 409 rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsImag, fmf); 410 Value lhsImagIsInfWithSignTimesRhsReal = 411 rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsReal, fmf); 412 Value resultImag3 = rewriter.create<arith::MulFOp>( 413 loc, inf, 414 rewriter.create<arith::SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal, 415 lhsRealIsInfWithSignTimesRhsImag, fmf), 416 fmf); 417 418 // Case 3: Finite numerator, infinite denominator. 419 Value lhsRealFinite = rewriter.create<arith::CmpFOp>( 420 loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf); 421 Value lhsImagFinite = rewriter.create<arith::CmpFOp>( 422 loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf); 423 Value lhsFinite = 424 rewriter.create<arith::AndIOp>(loc, lhsRealFinite, lhsImagFinite); 425 Value rhsRealInfinite = rewriter.create<arith::CmpFOp>( 426 loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf); 427 Value rhsImagInfinite = rewriter.create<arith::CmpFOp>( 428 loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf); 429 Value rhsInfinite = 430 rewriter.create<arith::OrIOp>(loc, rhsRealInfinite, rhsImagInfinite); 431 Value finiteNumInfiniteDenom = 432 rewriter.create<arith::AndIOp>(loc, lhsFinite, rhsInfinite); 433 Value rhsRealIsInfWithSign = rewriter.create<math::CopySignOp>( 434 loc, rewriter.create<arith::SelectOp>(loc, rhsRealInfinite, one, zero), 435 rhsReal); 436 Value rhsImagIsInfWithSign = rewriter.create<math::CopySignOp>( 437 loc, rewriter.create<arith::SelectOp>(loc, rhsImagInfinite, one, zero), 438 rhsImag); 439 Value rhsRealIsInfWithSignTimesLhsReal = 440 rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealIsInfWithSign, fmf); 441 Value rhsImagIsInfWithSignTimesLhsImag = 442 rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagIsInfWithSign, fmf); 443 Value resultReal4 = rewriter.create<arith::MulFOp>( 444 loc, zero, 445 rewriter.create<arith::AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal, 446 rhsImagIsInfWithSignTimesLhsImag, fmf), 447 fmf); 448 Value rhsRealIsInfWithSignTimesLhsImag = 449 rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealIsInfWithSign, fmf); 450 Value rhsImagIsInfWithSignTimesLhsReal = 451 rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagIsInfWithSign, fmf); 452 Value resultImag4 = rewriter.create<arith::MulFOp>( 453 loc, zero, 454 rewriter.create<arith::SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag, 455 rhsImagIsInfWithSignTimesLhsReal, fmf), 456 fmf); 457 458 Value realAbsSmallerThanImagAbs = rewriter.create<arith::CmpFOp>( 459 loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs); 460 Value resultReal = rewriter.create<arith::SelectOp>( 461 loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2); 462 Value resultImag = rewriter.create<arith::SelectOp>( 463 loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2); 464 Value resultRealSpecialCase3 = rewriter.create<arith::SelectOp>( 465 loc, finiteNumInfiniteDenom, resultReal4, resultReal); 466 Value resultImagSpecialCase3 = rewriter.create<arith::SelectOp>( 467 loc, finiteNumInfiniteDenom, resultImag4, resultImag); 468 Value resultRealSpecialCase2 = rewriter.create<arith::SelectOp>( 469 loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3); 470 Value resultImagSpecialCase2 = rewriter.create<arith::SelectOp>( 471 loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3); 472 Value resultRealSpecialCase1 = rewriter.create<arith::SelectOp>( 473 loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2); 474 Value resultImagSpecialCase1 = rewriter.create<arith::SelectOp>( 475 loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2); 476 477 Value resultRealIsNaN = rewriter.create<arith::CmpFOp>( 478 loc, arith::CmpFPredicate::UNO, resultReal, zero); 479 Value resultImagIsNaN = rewriter.create<arith::CmpFOp>( 480 loc, arith::CmpFPredicate::UNO, resultImag, zero); 481 Value resultIsNaN = 482 rewriter.create<arith::AndIOp>(loc, resultRealIsNaN, resultImagIsNaN); 483 Value resultRealWithSpecialCases = rewriter.create<arith::SelectOp>( 484 loc, resultIsNaN, resultRealSpecialCase1, resultReal); 485 Value resultImagWithSpecialCases = rewriter.create<arith::SelectOp>( 486 loc, resultIsNaN, resultImagSpecialCase1, resultImag); 487 488 rewriter.replaceOpWithNewOp<complex::CreateOp>( 489 op, type, resultRealWithSpecialCases, resultImagWithSpecialCases); 490 return success(); 491 } 492 }; 493 494 struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> { 495 using OpConversionPattern<complex::ExpOp>::OpConversionPattern; 496 497 LogicalResult 498 matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor, 499 ConversionPatternRewriter &rewriter) const override { 500 auto loc = op.getLoc(); 501 auto type = cast<ComplexType>(adaptor.getComplex().getType()); 502 auto elementType = cast<FloatType>(type.getElementType()); 503 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); 504 505 Value real = 506 rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex()); 507 Value imag = 508 rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex()); 509 Value expReal = rewriter.create<math::ExpOp>(loc, real, fmf.getValue()); 510 Value cosImag = rewriter.create<math::CosOp>(loc, imag, fmf.getValue()); 511 Value resultReal = 512 rewriter.create<arith::MulFOp>(loc, expReal, cosImag, fmf.getValue()); 513 Value sinImag = rewriter.create<math::SinOp>(loc, imag, fmf.getValue()); 514 Value resultImag = 515 rewriter.create<arith::MulFOp>(loc, expReal, sinImag, fmf.getValue()); 516 517 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, 518 resultImag); 519 return success(); 520 } 521 }; 522 523 Value evaluatePolynomial(ImplicitLocOpBuilder &b, Value arg, 524 ArrayRef<double> coefficients, 525 arith::FastMathFlagsAttr fmf) { 526 auto argType = mlir::cast<FloatType>(arg.getType()); 527 Value poly = 528 b.create<arith::ConstantOp>(b.getFloatAttr(argType, coefficients[0])); 529 for (unsigned i = 1; i < coefficients.size(); ++i) { 530 poly = b.create<math::FmaOp>( 531 poly, arg, 532 b.create<arith::ConstantOp>(b.getFloatAttr(argType, coefficients[i])), 533 fmf); 534 } 535 return poly; 536 } 537 538 struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> { 539 using OpConversionPattern<complex::Expm1Op>::OpConversionPattern; 540 541 // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i 542 // [handle inaccuracies when a and/or b are small] 543 // = ((e^a - 1) * cos(b) + cos(b) - 1) + e^a*sin(b)i 544 // = (expm1(a) * cos(b) + cosm1(b)) + e^a*sin(b)i 545 LogicalResult 546 matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor, 547 ConversionPatternRewriter &rewriter) const override { 548 auto type = op.getType(); 549 auto elemType = mlir::cast<FloatType>(type.getElementType()); 550 551 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); 552 ImplicitLocOpBuilder b(op.getLoc(), rewriter); 553 Value real = b.create<complex::ReOp>(adaptor.getComplex()); 554 Value imag = b.create<complex::ImOp>(adaptor.getComplex()); 555 556 Value zero = b.create<arith::ConstantOp>(b.getFloatAttr(elemType, 0.0)); 557 Value one = b.create<arith::ConstantOp>(b.getFloatAttr(elemType, 1.0)); 558 559 Value expm1Real = b.create<math::ExpM1Op>(real, fmf); 560 Value expReal = b.create<arith::AddFOp>(expm1Real, one, fmf); 561 562 Value sinImag = b.create<math::SinOp>(imag, fmf); 563 Value cosm1Imag = emitCosm1(imag, fmf, b); 564 Value cosImag = b.create<arith::AddFOp>(cosm1Imag, one, fmf); 565 566 Value realResult = b.create<arith::AddFOp>( 567 b.create<arith::MulFOp>(expm1Real, cosImag, fmf), cosm1Imag, fmf); 568 569 Value imagIsZero = b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, 570 zero, fmf.getValue()); 571 Value imagResult = b.create<arith::SelectOp>( 572 imagIsZero, zero, b.create<arith::MulFOp>(expReal, sinImag, fmf)); 573 574 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realResult, 575 imagResult); 576 return success(); 577 } 578 579 private: 580 Value emitCosm1(Value arg, arith::FastMathFlagsAttr fmf, 581 ImplicitLocOpBuilder &b) const { 582 auto argType = mlir::cast<FloatType>(arg.getType()); 583 auto negHalf = b.create<arith::ConstantOp>(b.getFloatAttr(argType, -0.5)); 584 auto negOne = b.create<arith::ConstantOp>(b.getFloatAttr(argType, -1.0)); 585 586 // Algorithm copied from cephes cosm1. 587 SmallVector<double, 7> kCoeffs{ 588 4.7377507964246204691685E-14, -1.1470284843425359765671E-11, 589 2.0876754287081521758361E-9, -2.7557319214999787979814E-7, 590 2.4801587301570552304991E-5, -1.3888888888888872993737E-3, 591 4.1666666666666666609054E-2, 592 }; 593 Value cos = b.create<math::CosOp>(arg, fmf); 594 Value forLargeArg = b.create<arith::AddFOp>(cos, negOne, fmf); 595 596 Value argPow2 = b.create<arith::MulFOp>(arg, arg, fmf); 597 Value argPow4 = b.create<arith::MulFOp>(argPow2, argPow2, fmf); 598 Value poly = evaluatePolynomial(b, argPow2, kCoeffs, fmf); 599 600 auto forSmallArg = 601 b.create<arith::AddFOp>(b.create<arith::MulFOp>(argPow4, poly, fmf), 602 b.create<arith::MulFOp>(negHalf, argPow2, fmf)); 603 604 // (pi/4)^2 is approximately 0.61685 605 Value piOver4Pow2 = 606 b.create<arith::ConstantOp>(b.getFloatAttr(argType, 0.61685)); 607 Value cond = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, argPow2, 608 piOver4Pow2, fmf.getValue()); 609 return b.create<arith::SelectOp>(cond, forLargeArg, forSmallArg); 610 } 611 }; 612 613 struct LogOpConversion : public OpConversionPattern<complex::LogOp> { 614 using OpConversionPattern<complex::LogOp>::OpConversionPattern; 615 616 LogicalResult 617 matchAndRewrite(complex::LogOp op, OpAdaptor adaptor, 618 ConversionPatternRewriter &rewriter) const override { 619 auto type = cast<ComplexType>(adaptor.getComplex().getType()); 620 auto elementType = cast<FloatType>(type.getElementType()); 621 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); 622 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 623 624 Value abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex(), 625 fmf.getValue()); 626 Value resultReal = b.create<math::LogOp>(elementType, abs, fmf.getValue()); 627 Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex()); 628 Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex()); 629 Value resultImag = 630 b.create<math::Atan2Op>(elementType, imag, real, fmf.getValue()); 631 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, 632 resultImag); 633 return success(); 634 } 635 }; 636 637 struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> { 638 using OpConversionPattern<complex::Log1pOp>::OpConversionPattern; 639 640 LogicalResult 641 matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor, 642 ConversionPatternRewriter &rewriter) const override { 643 auto type = cast<ComplexType>(adaptor.getComplex().getType()); 644 auto elementType = cast<FloatType>(type.getElementType()); 645 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue(); 646 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 647 648 Value real = b.create<complex::ReOp>(adaptor.getComplex()); 649 Value imag = b.create<complex::ImOp>(adaptor.getComplex()); 650 651 Value half = b.create<arith::ConstantOp>(elementType, 652 b.getFloatAttr(elementType, 0.5)); 653 Value one = b.create<arith::ConstantOp>(elementType, 654 b.getFloatAttr(elementType, 1)); 655 Value realPlusOne = b.create<arith::AddFOp>(real, one, fmf); 656 Value absRealPlusOne = b.create<math::AbsFOp>(realPlusOne, fmf); 657 Value absImag = b.create<math::AbsFOp>(imag, fmf); 658 659 Value maxAbs = b.create<arith::MaximumFOp>(absRealPlusOne, absImag, fmf); 660 Value minAbs = b.create<arith::MinimumFOp>(absRealPlusOne, absImag, fmf); 661 662 Value useReal = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, 663 realPlusOne, absImag, fmf); 664 Value maxMinusOne = b.create<arith::SubFOp>(maxAbs, one, fmf); 665 Value maxAbsOfRealPlusOneAndImagMinusOne = 666 b.create<arith::SelectOp>(useReal, real, maxMinusOne); 667 arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear( 668 fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf); 669 Value minMaxRatio = b.create<arith::DivFOp>(minAbs, maxAbs, fmfWithNaNInf); 670 Value logOfMaxAbsOfRealPlusOneAndImag = 671 b.create<math::Log1pOp>(maxAbsOfRealPlusOneAndImagMinusOne, fmf); 672 Value logOfSqrtPart = b.create<math::Log1pOp>( 673 b.create<arith::MulFOp>(minMaxRatio, minMaxRatio, fmfWithNaNInf), 674 fmfWithNaNInf); 675 Value r = b.create<arith::AddFOp>( 676 b.create<arith::MulFOp>(half, logOfSqrtPart, fmfWithNaNInf), 677 logOfMaxAbsOfRealPlusOneAndImag, fmfWithNaNInf); 678 Value resultReal = b.create<arith::SelectOp>( 679 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, r, r, fmfWithNaNInf), 680 minAbs, r); 681 Value resultImag = b.create<math::Atan2Op>(imag, realPlusOne, fmf); 682 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, 683 resultImag); 684 return success(); 685 } 686 }; 687 688 struct MulOpConversion : public OpConversionPattern<complex::MulOp> { 689 using OpConversionPattern<complex::MulOp>::OpConversionPattern; 690 691 LogicalResult 692 matchAndRewrite(complex::MulOp op, OpAdaptor adaptor, 693 ConversionPatternRewriter &rewriter) const override { 694 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 695 auto type = cast<ComplexType>(adaptor.getLhs().getType()); 696 auto elementType = cast<FloatType>(type.getElementType()); 697 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); 698 auto fmfValue = fmf.getValue(); 699 Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.getLhs()); 700 Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.getLhs()); 701 Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.getRhs()); 702 Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.getRhs()); 703 Value lhsRealTimesRhsReal = 704 b.create<arith::MulFOp>(lhsReal, rhsReal, fmfValue); 705 Value lhsImagTimesRhsImag = 706 b.create<arith::MulFOp>(lhsImag, rhsImag, fmfValue); 707 Value real = b.create<arith::SubFOp>(lhsRealTimesRhsReal, 708 lhsImagTimesRhsImag, fmfValue); 709 Value lhsImagTimesRhsReal = 710 b.create<arith::MulFOp>(lhsImag, rhsReal, fmfValue); 711 Value lhsRealTimesRhsImag = 712 b.create<arith::MulFOp>(lhsReal, rhsImag, fmfValue); 713 Value imag = b.create<arith::AddFOp>(lhsImagTimesRhsReal, 714 lhsRealTimesRhsImag, fmfValue); 715 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag); 716 return success(); 717 } 718 }; 719 720 struct NegOpConversion : public OpConversionPattern<complex::NegOp> { 721 using OpConversionPattern<complex::NegOp>::OpConversionPattern; 722 723 LogicalResult 724 matchAndRewrite(complex::NegOp op, OpAdaptor adaptor, 725 ConversionPatternRewriter &rewriter) const override { 726 auto loc = op.getLoc(); 727 auto type = cast<ComplexType>(adaptor.getComplex().getType()); 728 auto elementType = cast<FloatType>(type.getElementType()); 729 730 Value real = 731 rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex()); 732 Value imag = 733 rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex()); 734 Value negReal = rewriter.create<arith::NegFOp>(loc, real); 735 Value negImag = rewriter.create<arith::NegFOp>(loc, imag); 736 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag); 737 return success(); 738 } 739 }; 740 741 struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> { 742 using TrigonometricOpConversion<complex::SinOp>::TrigonometricOpConversion; 743 744 std::pair<Value, Value> combine(Location loc, Value scaledExp, 745 Value reciprocalExp, Value sin, Value cos, 746 ConversionPatternRewriter &rewriter, 747 arith::FastMathFlagsAttr fmf) const override { 748 // Complex sine is defined as; 749 // sin(x + iy) = -0.5i * (exp(i(x + iy)) - exp(-i(x + iy))) 750 // Plugging in: 751 // exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x)) 752 // exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x))) 753 // and defining t := exp(y) 754 // We get: 755 // Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x 756 // Im(cos(x + iy)) = (0.5*t - 0.5/t) * cos x 757 Value sum = 758 rewriter.create<arith::AddFOp>(loc, scaledExp, reciprocalExp, fmf); 759 Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, sin, fmf); 760 Value diff = 761 rewriter.create<arith::SubFOp>(loc, scaledExp, reciprocalExp, fmf); 762 Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, cos, fmf); 763 return {resultReal, resultImag}; 764 } 765 }; 766 767 // The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780. 768 struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> { 769 using OpConversionPattern<complex::SqrtOp>::OpConversionPattern; 770 771 LogicalResult 772 matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor, 773 ConversionPatternRewriter &rewriter) const override { 774 ImplicitLocOpBuilder b(op.getLoc(), rewriter); 775 776 auto type = cast<ComplexType>(op.getType()); 777 auto elementType = cast<FloatType>(type.getElementType()); 778 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue(); 779 780 auto cst = [&](APFloat v) { 781 return b.create<arith::ConstantOp>(elementType, 782 b.getFloatAttr(elementType, v)); 783 }; 784 const auto &floatSemantics = elementType.getFloatSemantics(); 785 Value zero = cst(APFloat::getZero(floatSemantics)); 786 Value half = b.create<arith::ConstantOp>(elementType, 787 b.getFloatAttr(elementType, 0.5)); 788 789 Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex()); 790 Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex()); 791 Value absSqrt = computeAbs(real, imag, fmf, b, AbsFn::sqrt); 792 Value argArg = b.create<math::Atan2Op>(imag, real, fmf); 793 Value sqrtArg = b.create<arith::MulFOp>(argArg, half, fmf); 794 Value cos = b.create<math::CosOp>(sqrtArg, fmf); 795 Value sin = b.create<math::SinOp>(sqrtArg, fmf); 796 // sin(atan2(0, inf)) = 0, sqrt(abs(inf)) = inf, but we can't multiply 797 // 0 * inf. 798 Value sinIsZero = 799 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, sin, zero, fmf); 800 801 Value resultReal = b.create<arith::MulFOp>(absSqrt, cos, fmf); 802 Value resultImag = b.create<arith::SelectOp>( 803 sinIsZero, zero, b.create<arith::MulFOp>(absSqrt, sin, fmf)); 804 if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan | 805 arith::FastMathFlags::ninf)) { 806 Value inf = cst(APFloat::getInf(floatSemantics)); 807 Value negInf = cst(APFloat::getInf(floatSemantics, true)); 808 Value nan = cst(APFloat::getNaN(floatSemantics)); 809 Value absImag = b.create<math::AbsFOp>(elementType, imag, fmf); 810 811 Value absImagIsInf = 812 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf); 813 Value absImagIsNotInf = 814 b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, absImag, inf, fmf); 815 Value realIsInf = 816 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, inf, fmf); 817 Value realIsNegInf = 818 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, negInf, fmf); 819 820 resultReal = b.create<arith::SelectOp>( 821 b.create<arith::AndIOp>(realIsNegInf, absImagIsNotInf), zero, 822 resultReal); 823 resultReal = b.create<arith::SelectOp>( 824 b.create<arith::OrIOp>(absImagIsInf, realIsInf), inf, resultReal); 825 826 Value imagSignInf = b.create<math::CopySignOp>(inf, imag, fmf); 827 resultImag = b.create<arith::SelectOp>( 828 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, absSqrt, absSqrt), 829 nan, resultImag); 830 resultImag = b.create<arith::SelectOp>( 831 b.create<arith::OrIOp>(absImagIsInf, realIsNegInf), imagSignInf, 832 resultImag); 833 } 834 835 Value resultIsZero = 836 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absSqrt, zero, fmf); 837 resultReal = b.create<arith::SelectOp>(resultIsZero, zero, resultReal); 838 resultImag = b.create<arith::SelectOp>(resultIsZero, zero, resultImag); 839 840 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, 841 resultImag); 842 return success(); 843 } 844 }; 845 846 struct SignOpConversion : public OpConversionPattern<complex::SignOp> { 847 using OpConversionPattern<complex::SignOp>::OpConversionPattern; 848 849 LogicalResult 850 matchAndRewrite(complex::SignOp op, OpAdaptor adaptor, 851 ConversionPatternRewriter &rewriter) const override { 852 auto type = cast<ComplexType>(adaptor.getComplex().getType()); 853 auto elementType = cast<FloatType>(type.getElementType()); 854 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 855 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); 856 857 Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex()); 858 Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex()); 859 Value zero = 860 b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType)); 861 Value realIsZero = 862 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero); 863 Value imagIsZero = 864 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero); 865 Value isZero = b.create<arith::AndIOp>(realIsZero, imagIsZero); 866 auto abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex(), fmf); 867 Value realSign = b.create<arith::DivFOp>(real, abs, fmf); 868 Value imagSign = b.create<arith::DivFOp>(imag, abs, fmf); 869 Value sign = b.create<complex::CreateOp>(type, realSign, imagSign); 870 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isZero, 871 adaptor.getComplex(), sign); 872 return success(); 873 } 874 }; 875 876 template <typename Op> 877 struct TanTanhOpConversion : public OpConversionPattern<Op> { 878 using OpConversionPattern<Op>::OpConversionPattern; 879 880 LogicalResult 881 matchAndRewrite(Op op, typename Op::Adaptor adaptor, 882 ConversionPatternRewriter &rewriter) const override { 883 ImplicitLocOpBuilder b(op.getLoc(), rewriter); 884 auto loc = op.getLoc(); 885 auto type = cast<ComplexType>(adaptor.getComplex().getType()); 886 auto elementType = cast<FloatType>(type.getElementType()); 887 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue(); 888 const auto &floatSemantics = elementType.getFloatSemantics(); 889 890 Value real = 891 b.create<complex::ReOp>(loc, elementType, adaptor.getComplex()); 892 Value imag = 893 b.create<complex::ImOp>(loc, elementType, adaptor.getComplex()); 894 Value negOne = b.create<arith::ConstantOp>( 895 elementType, b.getFloatAttr(elementType, -1.0)); 896 897 if constexpr (std::is_same_v<Op, complex::TanOp>) { 898 // tan(x+yi) = -i*tanh(-y + xi) 899 std::swap(real, imag); 900 real = b.create<arith::MulFOp>(real, negOne, fmf); 901 } 902 903 auto cst = [&](APFloat v) { 904 return b.create<arith::ConstantOp>(elementType, 905 b.getFloatAttr(elementType, v)); 906 }; 907 Value inf = cst(APFloat::getInf(floatSemantics)); 908 Value four = b.create<arith::ConstantOp>(elementType, 909 b.getFloatAttr(elementType, 4.0)); 910 Value twoReal = b.create<arith::AddFOp>(real, real, fmf); 911 Value negTwoReal = b.create<arith::MulFOp>(negOne, twoReal, fmf); 912 913 Value expTwoRealMinusOne = b.create<math::ExpM1Op>(twoReal, fmf); 914 Value expNegTwoRealMinusOne = b.create<math::ExpM1Op>(negTwoReal, fmf); 915 Value realNum = 916 b.create<arith::SubFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf); 917 918 Value cosImag = b.create<math::CosOp>(imag, fmf); 919 Value cosImagSq = b.create<arith::MulFOp>(cosImag, cosImag, fmf); 920 Value twoCosTwoImagPlusOne = b.create<arith::MulFOp>(cosImagSq, four, fmf); 921 Value sinImag = b.create<math::SinOp>(imag, fmf); 922 923 Value imagNum = b.create<arith::MulFOp>( 924 four, b.create<arith::MulFOp>(cosImag, sinImag, fmf), fmf); 925 926 Value expSumMinusTwo = 927 b.create<arith::AddFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf); 928 Value denom = 929 b.create<arith::AddFOp>(expSumMinusTwo, twoCosTwoImagPlusOne, fmf); 930 931 Value isInf = b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, 932 expSumMinusTwo, inf, fmf); 933 Value realLimit = b.create<math::CopySignOp>(negOne, real, fmf); 934 935 Value resultReal = b.create<arith::SelectOp>( 936 isInf, realLimit, b.create<arith::DivFOp>(realNum, denom, fmf)); 937 Value resultImag = b.create<arith::DivFOp>(imagNum, denom, fmf); 938 939 if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan | 940 arith::FastMathFlags::ninf)) { 941 Value absReal = b.create<math::AbsFOp>(real, fmf); 942 Value zero = b.create<arith::ConstantOp>( 943 elementType, b.getFloatAttr(elementType, 0.0)); 944 Value nan = cst(APFloat::getNaN(floatSemantics)); 945 946 Value absRealIsInf = 947 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf); 948 Value imagIsZero = 949 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf); 950 Value absRealIsNotInf = b.create<arith::XOrIOp>( 951 absRealIsInf, b.create<arith::ConstantIntOp>(true, /*width=*/1)); 952 953 Value imagNumIsNaN = b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, 954 imagNum, imagNum, fmf); 955 Value resultRealIsNaN = 956 b.create<arith::AndIOp>(imagNumIsNaN, absRealIsNotInf); 957 Value resultImagIsZero = b.create<arith::OrIOp>( 958 imagIsZero, b.create<arith::AndIOp>(absRealIsInf, imagNumIsNaN)); 959 960 resultReal = b.create<arith::SelectOp>(resultRealIsNaN, nan, resultReal); 961 resultImag = 962 b.create<arith::SelectOp>(resultImagIsZero, zero, resultImag); 963 } 964 965 if constexpr (std::is_same_v<Op, complex::TanOp>) { 966 // tan(x+yi) = -i*tanh(-y + xi) 967 std::swap(resultReal, resultImag); 968 resultImag = b.create<arith::MulFOp>(resultImag, negOne, fmf); 969 } 970 971 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, 972 resultImag); 973 return success(); 974 } 975 }; 976 977 struct ConjOpConversion : public OpConversionPattern<complex::ConjOp> { 978 using OpConversionPattern<complex::ConjOp>::OpConversionPattern; 979 980 LogicalResult 981 matchAndRewrite(complex::ConjOp op, OpAdaptor adaptor, 982 ConversionPatternRewriter &rewriter) const override { 983 auto loc = op.getLoc(); 984 auto type = cast<ComplexType>(adaptor.getComplex().getType()); 985 auto elementType = cast<FloatType>(type.getElementType()); 986 Value real = 987 rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex()); 988 Value imag = 989 rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex()); 990 Value negImag = rewriter.create<arith::NegFOp>(loc, elementType, imag); 991 992 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, negImag); 993 994 return success(); 995 } 996 }; 997 998 /// Converts lhs^y = (a+bi)^(c+di) to 999 /// (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)), 1000 /// where q = c*atan2(b,a)+0.5d*ln(a*a+b*b) 1001 static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder, 1002 ComplexType type, Value lhs, Value c, Value d, 1003 arith::FastMathFlags fmf) { 1004 auto elementType = cast<FloatType>(type.getElementType()); 1005 1006 Value a = builder.create<complex::ReOp>(lhs); 1007 Value b = builder.create<complex::ImOp>(lhs); 1008 1009 Value abs = builder.create<complex::AbsOp>(lhs, fmf); 1010 Value absToC = builder.create<math::PowFOp>(abs, c, fmf); 1011 1012 Value negD = builder.create<arith::NegFOp>(d, fmf); 1013 Value argLhs = builder.create<math::Atan2Op>(b, a, fmf); 1014 Value negDArgLhs = builder.create<arith::MulFOp>(negD, argLhs, fmf); 1015 Value expNegDArgLhs = builder.create<math::ExpOp>(negDArgLhs, fmf); 1016 1017 Value coeff = builder.create<arith::MulFOp>(absToC, expNegDArgLhs, fmf); 1018 Value lnAbs = builder.create<math::LogOp>(abs, fmf); 1019 Value cArgLhs = builder.create<arith::MulFOp>(c, argLhs, fmf); 1020 Value dLnAbs = builder.create<arith::MulFOp>(d, lnAbs, fmf); 1021 Value q = builder.create<arith::AddFOp>(cArgLhs, dLnAbs, fmf); 1022 Value cosQ = builder.create<math::CosOp>(q, fmf); 1023 Value sinQ = builder.create<math::SinOp>(q, fmf); 1024 1025 Value inf = builder.create<arith::ConstantOp>( 1026 elementType, 1027 builder.getFloatAttr(elementType, 1028 APFloat::getInf(elementType.getFloatSemantics()))); 1029 Value zero = builder.create<arith::ConstantOp>( 1030 elementType, builder.getFloatAttr(elementType, 0.0)); 1031 Value one = builder.create<arith::ConstantOp>( 1032 elementType, builder.getFloatAttr(elementType, 1.0)); 1033 Value complexOne = builder.create<complex::CreateOp>(type, one, zero); 1034 Value complexZero = builder.create<complex::CreateOp>(type, zero, zero); 1035 Value complexInf = builder.create<complex::CreateOp>(type, inf, zero); 1036 1037 // Case 0: 1038 // d^c is 0 if d is 0 and c > 0. 0^0 is defined to be 1.0, see 1039 // Branch Cuts for Complex Elementary Functions or Much Ado About 1040 // Nothing's Sign Bit, W. Kahan, Section 10. 1041 Value absEqZero = 1042 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, abs, zero, fmf); 1043 Value dEqZero = 1044 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, d, zero, fmf); 1045 Value cEqZero = 1046 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, c, zero, fmf); 1047 Value bEqZero = 1048 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, b, zero, fmf); 1049 1050 Value zeroLeC = 1051 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLE, zero, c, fmf); 1052 Value coeffCosQ = builder.create<arith::MulFOp>(coeff, cosQ, fmf); 1053 Value coeffSinQ = builder.create<arith::MulFOp>(coeff, sinQ, fmf); 1054 Value complexOneOrZero = 1055 builder.create<arith::SelectOp>(cEqZero, complexOne, complexZero); 1056 Value coeffCosSin = 1057 builder.create<complex::CreateOp>(type, coeffCosQ, coeffSinQ); 1058 Value cutoff0 = builder.create<arith::SelectOp>( 1059 builder.create<arith::AndIOp>( 1060 builder.create<arith::AndIOp>(absEqZero, dEqZero), zeroLeC), 1061 complexOneOrZero, coeffCosSin); 1062 1063 // Case 1: 1064 // x^0 is defined to be 1 for any x, see 1065 // Branch Cuts for Complex Elementary Functions or Much Ado About 1066 // Nothing's Sign Bit, W. Kahan, Section 10. 1067 Value rhsEqZero = builder.create<arith::AndIOp>(cEqZero, dEqZero); 1068 Value cutoff1 = 1069 builder.create<arith::SelectOp>(rhsEqZero, complexOne, cutoff0); 1070 1071 // Case 2: 1072 // 1^(c + d*i) = 1 + 0*i 1073 Value lhsEqOne = builder.create<arith::AndIOp>( 1074 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, one, fmf), 1075 bEqZero); 1076 Value cutoff2 = 1077 builder.create<arith::SelectOp>(lhsEqOne, complexOne, cutoff1); 1078 1079 // Case 3: 1080 // inf^(c + 0*i) = inf + 0*i, c > 0 1081 Value lhsEqInf = builder.create<arith::AndIOp>( 1082 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, inf, fmf), 1083 bEqZero); 1084 Value rhsGt0 = builder.create<arith::AndIOp>( 1085 dEqZero, 1086 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, c, zero, fmf)); 1087 Value cutoff3 = builder.create<arith::SelectOp>( 1088 builder.create<arith::AndIOp>(lhsEqInf, rhsGt0), complexInf, cutoff2); 1089 1090 // Case 4: 1091 // inf^(c + 0*i) = 0 + 0*i, c < 0 1092 Value rhsLt0 = builder.create<arith::AndIOp>( 1093 dEqZero, 1094 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, c, zero, fmf)); 1095 Value cutoff4 = builder.create<arith::SelectOp>( 1096 builder.create<arith::AndIOp>(lhsEqInf, rhsLt0), complexZero, cutoff3); 1097 1098 return cutoff4; 1099 } 1100 1101 struct PowOpConversion : public OpConversionPattern<complex::PowOp> { 1102 using OpConversionPattern<complex::PowOp>::OpConversionPattern; 1103 1104 LogicalResult 1105 matchAndRewrite(complex::PowOp op, OpAdaptor adaptor, 1106 ConversionPatternRewriter &rewriter) const override { 1107 mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter); 1108 auto type = cast<ComplexType>(adaptor.getLhs().getType()); 1109 auto elementType = cast<FloatType>(type.getElementType()); 1110 1111 Value c = builder.create<complex::ReOp>(elementType, adaptor.getRhs()); 1112 Value d = builder.create<complex::ImOp>(elementType, adaptor.getRhs()); 1113 1114 rewriter.replaceOp(op, {powOpConversionImpl(builder, type, adaptor.getLhs(), 1115 c, d, op.getFastmath())}); 1116 return success(); 1117 } 1118 }; 1119 1120 struct RsqrtOpConversion : public OpConversionPattern<complex::RsqrtOp> { 1121 using OpConversionPattern<complex::RsqrtOp>::OpConversionPattern; 1122 1123 LogicalResult 1124 matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor, 1125 ConversionPatternRewriter &rewriter) const override { 1126 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 1127 auto type = cast<ComplexType>(adaptor.getComplex().getType()); 1128 auto elementType = cast<FloatType>(type.getElementType()); 1129 1130 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue(); 1131 1132 auto cst = [&](APFloat v) { 1133 return b.create<arith::ConstantOp>(elementType, 1134 b.getFloatAttr(elementType, v)); 1135 }; 1136 const auto &floatSemantics = elementType.getFloatSemantics(); 1137 Value zero = cst(APFloat::getZero(floatSemantics)); 1138 Value inf = cst(APFloat::getInf(floatSemantics)); 1139 Value negHalf = b.create<arith::ConstantOp>( 1140 elementType, b.getFloatAttr(elementType, -0.5)); 1141 Value nan = cst(APFloat::getNaN(floatSemantics)); 1142 1143 Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex()); 1144 Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex()); 1145 Value absRsqrt = computeAbs(real, imag, fmf, b, AbsFn::rsqrt); 1146 Value argArg = b.create<math::Atan2Op>(imag, real, fmf); 1147 Value rsqrtArg = b.create<arith::MulFOp>(argArg, negHalf, fmf); 1148 Value cos = b.create<math::CosOp>(rsqrtArg, fmf); 1149 Value sin = b.create<math::SinOp>(rsqrtArg, fmf); 1150 1151 Value resultReal = b.create<arith::MulFOp>(absRsqrt, cos, fmf); 1152 Value resultImag = b.create<arith::MulFOp>(absRsqrt, sin, fmf); 1153 1154 if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan | 1155 arith::FastMathFlags::ninf)) { 1156 Value negOne = b.create<arith::ConstantOp>( 1157 elementType, b.getFloatAttr(elementType, -1)); 1158 1159 Value realSignedZero = b.create<math::CopySignOp>(zero, real, fmf); 1160 Value imagSignedZero = b.create<math::CopySignOp>(zero, imag, fmf); 1161 Value negImagSignedZero = 1162 b.create<arith::MulFOp>(negOne, imagSignedZero, fmf); 1163 1164 Value absReal = b.create<math::AbsFOp>(real, fmf); 1165 Value absImag = b.create<math::AbsFOp>(imag, fmf); 1166 1167 Value absImagIsInf = 1168 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf); 1169 Value realIsNan = 1170 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real, fmf); 1171 Value realIsInf = 1172 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf); 1173 Value inIsNanInf = b.create<arith::AndIOp>(absImagIsInf, realIsNan); 1174 1175 Value resultIsZero = b.create<arith::OrIOp>(inIsNanInf, realIsInf); 1176 1177 resultReal = 1178 b.create<arith::SelectOp>(resultIsZero, realSignedZero, resultReal); 1179 resultImag = b.create<arith::SelectOp>(resultIsZero, negImagSignedZero, 1180 resultImag); 1181 } 1182 1183 Value isRealZero = 1184 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero, fmf); 1185 Value isImagZero = 1186 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf); 1187 Value isZero = b.create<arith::AndIOp>(isRealZero, isImagZero); 1188 1189 resultReal = b.create<arith::SelectOp>(isZero, inf, resultReal); 1190 resultImag = b.create<arith::SelectOp>(isZero, nan, resultImag); 1191 1192 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, 1193 resultImag); 1194 return success(); 1195 } 1196 }; 1197 1198 struct AngleOpConversion : public OpConversionPattern<complex::AngleOp> { 1199 using OpConversionPattern<complex::AngleOp>::OpConversionPattern; 1200 1201 LogicalResult 1202 matchAndRewrite(complex::AngleOp op, OpAdaptor adaptor, 1203 ConversionPatternRewriter &rewriter) const override { 1204 auto loc = op.getLoc(); 1205 auto type = op.getType(); 1206 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); 1207 1208 Value real = 1209 rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex()); 1210 Value imag = 1211 rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex()); 1212 1213 rewriter.replaceOpWithNewOp<math::Atan2Op>(op, imag, real, fmf); 1214 1215 return success(); 1216 } 1217 }; 1218 1219 } // namespace 1220 1221 void mlir::populateComplexToStandardConversionPatterns( 1222 RewritePatternSet &patterns) { 1223 // clang-format off 1224 patterns.add< 1225 AbsOpConversion, 1226 AngleOpConversion, 1227 Atan2OpConversion, 1228 BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>, 1229 BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>, 1230 ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>, 1231 ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>, 1232 ConjOpConversion, 1233 CosOpConversion, 1234 DivOpConversion, 1235 ExpOpConversion, 1236 Expm1OpConversion, 1237 Log1pOpConversion, 1238 LogOpConversion, 1239 MulOpConversion, 1240 NegOpConversion, 1241 SignOpConversion, 1242 SinOpConversion, 1243 SqrtOpConversion, 1244 TanTanhOpConversion<complex::TanOp>, 1245 TanTanhOpConversion<complex::TanhOp>, 1246 PowOpConversion, 1247 RsqrtOpConversion 1248 >(patterns.getContext()); 1249 // clang-format on 1250 } 1251 1252 namespace { 1253 struct ConvertComplexToStandardPass 1254 : public impl::ConvertComplexToStandardBase<ConvertComplexToStandardPass> { 1255 void runOnOperation() override; 1256 }; 1257 1258 void ConvertComplexToStandardPass::runOnOperation() { 1259 // Convert to the Standard dialect using the converter defined above. 1260 RewritePatternSet patterns(&getContext()); 1261 populateComplexToStandardConversionPatterns(patterns); 1262 1263 ConversionTarget target(getContext()); 1264 target.addLegalDialect<arith::ArithDialect, math::MathDialect>(); 1265 target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>(); 1266 if (failed( 1267 applyPartialConversion(getOperation(), target, std::move(patterns)))) 1268 signalPassFailure(); 1269 } 1270 } // namespace 1271 1272 std::unique_ptr<Pass> mlir::createConvertComplexToStandardPass() { 1273 return std::make_unique<ConvertComplexToStandardPass>(); 1274 } 1275