12ea7fb7bSAdrian Kuegel //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===// 22ea7fb7bSAdrian Kuegel // 32ea7fb7bSAdrian Kuegel // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 42ea7fb7bSAdrian Kuegel // See https://llvm.org/LICENSE.txt for license information. 52ea7fb7bSAdrian Kuegel // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 62ea7fb7bSAdrian Kuegel // 72ea7fb7bSAdrian Kuegel //===----------------------------------------------------------------------===// 82ea7fb7bSAdrian Kuegel 92ea7fb7bSAdrian Kuegel #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" 102ea7fb7bSAdrian Kuegel 11abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h" 122ea7fb7bSAdrian Kuegel #include "mlir/Dialect/Complex/IR/Complex.h" 132ea7fb7bSAdrian Kuegel #include "mlir/Dialect/Math/IR/Math.h" 14f112bd61SAdrian Kuegel #include "mlir/IR/ImplicitLocOpBuilder.h" 152ea7fb7bSAdrian Kuegel #include "mlir/IR/PatternMatch.h" 1667d0d7acSMichele Scuttari #include "mlir/Pass/Pass.h" 172ea7fb7bSAdrian Kuegel #include "mlir/Transforms/DialectConversion.h" 1867d0d7acSMichele Scuttari #include <memory> 1967d0d7acSMichele Scuttari #include <type_traits> 2067d0d7acSMichele Scuttari 2167d0d7acSMichele Scuttari namespace mlir { 2267d0d7acSMichele Scuttari #define GEN_PASS_DEF_CONVERTCOMPLEXTOSTANDARD 2367d0d7acSMichele Scuttari #include "mlir/Conversion/Passes.h.inc" 2467d0d7acSMichele Scuttari } // namespace mlir 252ea7fb7bSAdrian Kuegel 262ea7fb7bSAdrian Kuegel using namespace mlir; 272ea7fb7bSAdrian Kuegel 282ea7fb7bSAdrian Kuegel namespace { 299d9bb7b1SJohannes Reifferscheid 3033e60f35SJohannes Reifferscheid enum class AbsFn { abs, sqrt, rsqrt }; 3133e60f35SJohannes Reifferscheid 3233e60f35SJohannes Reifferscheid // Returns the absolute value, its square root or its reciprocal square root. 33ff9bc3a0SJohannes Reifferscheid Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf, 3433e60f35SJohannes Reifferscheid ImplicitLocOpBuilder &b, AbsFn fn = AbsFn::abs) { 35ff9bc3a0SJohannes Reifferscheid Value one = b.create<arith::ConstantOp>(real.getType(), 36ff9bc3a0SJohannes Reifferscheid b.getFloatAttr(real.getType(), 1.0)); 372ea7fb7bSAdrian Kuegel 389d9bb7b1SJohannes Reifferscheid Value absReal = b.create<math::AbsFOp>(real, fmf); 399d9bb7b1SJohannes Reifferscheid Value absImag = b.create<math::AbsFOp>(imag, fmf); 40b17348c3SKai Sasaki 419d9bb7b1SJohannes Reifferscheid Value max = b.create<arith::MaximumFOp>(absReal, absImag, fmf); 429d9bb7b1SJohannes Reifferscheid Value min = b.create<arith::MinimumFOp>(absReal, absImag, fmf); 439b225d01SJohannes Reifferscheid 449b225d01SJohannes Reifferscheid // The lowering below requires NaNs and infinities to work correctly. 459b225d01SJohannes Reifferscheid arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear( 469b225d01SJohannes Reifferscheid fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf); 479b225d01SJohannes Reifferscheid Value ratio = b.create<arith::DivFOp>(min, max, fmfWithNaNInf); 489b225d01SJohannes Reifferscheid Value ratioSq = b.create<arith::MulFOp>(ratio, ratio, fmfWithNaNInf); 499b225d01SJohannes Reifferscheid Value ratioSqPlusOne = b.create<arith::AddFOp>(ratioSq, one, fmfWithNaNInf); 50ff9bc3a0SJohannes Reifferscheid Value result; 51ff9bc3a0SJohannes Reifferscheid 5233e60f35SJohannes Reifferscheid if (fn == AbsFn::rsqrt) { 539b225d01SJohannes Reifferscheid ratioSqPlusOne = b.create<math::RsqrtOp>(ratioSqPlusOne, fmfWithNaNInf); 549b225d01SJohannes Reifferscheid min = b.create<math::RsqrtOp>(min, fmfWithNaNInf); 559b225d01SJohannes Reifferscheid max = b.create<math::RsqrtOp>(max, fmfWithNaNInf); 5633e60f35SJohannes Reifferscheid } 5733e60f35SJohannes Reifferscheid 5833e60f35SJohannes Reifferscheid if (fn == AbsFn::sqrt) { 59ff9bc3a0SJohannes Reifferscheid Value quarter = b.create<arith::ConstantOp>( 60ff9bc3a0SJohannes Reifferscheid real.getType(), b.getFloatAttr(real.getType(), 0.25)); 61ff9bc3a0SJohannes Reifferscheid // sqrt(sqrt(a*b)) would avoid the pow, but will overflow more easily. 629b225d01SJohannes Reifferscheid Value sqrt = b.create<math::SqrtOp>(max, fmfWithNaNInf); 639b225d01SJohannes Reifferscheid Value p025 = b.create<math::PowFOp>(ratioSqPlusOne, quarter, fmfWithNaNInf); 649b225d01SJohannes Reifferscheid result = b.create<arith::MulFOp>(sqrt, p025, fmfWithNaNInf); 65ff9bc3a0SJohannes Reifferscheid } else { 669b225d01SJohannes Reifferscheid Value sqrt = b.create<math::SqrtOp>(ratioSqPlusOne, fmfWithNaNInf); 679b225d01SJohannes Reifferscheid result = b.create<arith::MulFOp>(max, sqrt, fmfWithNaNInf); 68ff9bc3a0SJohannes Reifferscheid } 69ff9bc3a0SJohannes Reifferscheid 709b225d01SJohannes Reifferscheid Value isNaN = b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, result, 719b225d01SJohannes Reifferscheid result, fmfWithNaNInf); 72ff9bc3a0SJohannes Reifferscheid return b.create<arith::SelectOp>(isNaN, min, result); 73ff9bc3a0SJohannes Reifferscheid } 74ff9bc3a0SJohannes Reifferscheid 75ff9bc3a0SJohannes Reifferscheid struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> { 76ff9bc3a0SJohannes Reifferscheid using OpConversionPattern<complex::AbsOp>::OpConversionPattern; 77ff9bc3a0SJohannes Reifferscheid 78ff9bc3a0SJohannes Reifferscheid LogicalResult 79ff9bc3a0SJohannes Reifferscheid matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor, 80ff9bc3a0SJohannes Reifferscheid ConversionPatternRewriter &rewriter) const override { 81ff9bc3a0SJohannes Reifferscheid ImplicitLocOpBuilder b(op.getLoc(), rewriter); 82ff9bc3a0SJohannes Reifferscheid 83ff9bc3a0SJohannes Reifferscheid arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue(); 84ff9bc3a0SJohannes Reifferscheid 85ff9bc3a0SJohannes Reifferscheid Value real = b.create<complex::ReOp>(adaptor.getComplex()); 86ff9bc3a0SJohannes Reifferscheid Value imag = b.create<complex::ImOp>(adaptor.getComplex()); 87ff9bc3a0SJohannes Reifferscheid rewriter.replaceOp(op, computeAbs(real, imag, fmf, b)); 88b17348c3SKai Sasaki 892ea7fb7bSAdrian Kuegel return success(); 902ea7fb7bSAdrian Kuegel } 912ea7fb7bSAdrian Kuegel }; 92ac00cb0dSAdrian Kuegel 93f711785eSAlexander Belyaev // atan2(y,x) = -i * log((x + i * y)/sqrt(x**2+y**2)) 94f711785eSAlexander Belyaev struct Atan2OpConversion : public OpConversionPattern<complex::Atan2Op> { 95f711785eSAlexander Belyaev using OpConversionPattern<complex::Atan2Op>::OpConversionPattern; 96f711785eSAlexander Belyaev 97f711785eSAlexander Belyaev LogicalResult 98f711785eSAlexander Belyaev matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor, 99f711785eSAlexander Belyaev ConversionPatternRewriter &rewriter) const override { 100f711785eSAlexander Belyaev mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 101f711785eSAlexander Belyaev 1025550c821STres Popp auto type = cast<ComplexType>(op.getType()); 103f711785eSAlexander Belyaev Type elementType = type.getElementType(); 104b930b14dSKai Sasaki arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); 105f711785eSAlexander Belyaev 106f711785eSAlexander Belyaev Value lhs = adaptor.getLhs(); 107f711785eSAlexander Belyaev Value rhs = adaptor.getRhs(); 108f711785eSAlexander Belyaev 109b930b14dSKai Sasaki Value rhsSquared = b.create<complex::MulOp>(type, rhs, rhs, fmf); 110b930b14dSKai Sasaki Value lhsSquared = b.create<complex::MulOp>(type, lhs, lhs, fmf); 111f711785eSAlexander Belyaev Value rhsSquaredPlusLhsSquared = 112b930b14dSKai Sasaki b.create<complex::AddOp>(type, rhsSquared, lhsSquared, fmf); 113f711785eSAlexander Belyaev Value sqrtOfRhsSquaredPlusLhsSquared = 114b930b14dSKai Sasaki b.create<complex::SqrtOp>(type, rhsSquaredPlusLhsSquared, fmf); 115f711785eSAlexander Belyaev 116f711785eSAlexander Belyaev Value zero = 117f711785eSAlexander Belyaev b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType)); 118f711785eSAlexander Belyaev Value one = b.create<arith::ConstantOp>(elementType, 119f711785eSAlexander Belyaev b.getFloatAttr(elementType, 1)); 120f711785eSAlexander Belyaev Value i = b.create<complex::CreateOp>(type, zero, one); 121b930b14dSKai Sasaki Value iTimesLhs = b.create<complex::MulOp>(i, lhs, fmf); 122b930b14dSKai Sasaki Value rhsPlusILhs = b.create<complex::AddOp>(rhs, iTimesLhs, fmf); 123f711785eSAlexander Belyaev 124b930b14dSKai Sasaki Value divResult = b.create<complex::DivOp>( 125b930b14dSKai Sasaki rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared, fmf); 126b930b14dSKai Sasaki Value logResult = b.create<complex::LogOp>(divResult, fmf); 127f711785eSAlexander Belyaev 128f711785eSAlexander Belyaev Value negativeOne = b.create<arith::ConstantOp>( 129f711785eSAlexander Belyaev elementType, b.getFloatAttr(elementType, -1)); 130f711785eSAlexander Belyaev Value negativeI = b.create<complex::CreateOp>(type, zero, negativeOne); 131f711785eSAlexander Belyaev 132b930b14dSKai Sasaki rewriter.replaceOpWithNewOp<complex::MulOp>(op, negativeI, logResult, fmf); 133f711785eSAlexander Belyaev return success(); 134f711785eSAlexander Belyaev } 135f711785eSAlexander Belyaev }; 136f711785eSAlexander Belyaev 137a54f4eaeSMogball template <typename ComparisonOp, arith::CmpFPredicate p> 138fb8b2b86SAdrian Kuegel struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> { 139fb8b2b86SAdrian Kuegel using OpConversionPattern<ComparisonOp>::OpConversionPattern; 140fb8b2b86SAdrian Kuegel using ResultCombiner = 141fb8b2b86SAdrian Kuegel std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value, 142a54f4eaeSMogball arith::AndIOp, arith::OrIOp>; 143ac00cb0dSAdrian Kuegel 144ac00cb0dSAdrian Kuegel LogicalResult 145b54c724bSRiver Riddle matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor, 146ac00cb0dSAdrian Kuegel ConversionPatternRewriter &rewriter) const override { 147ac00cb0dSAdrian Kuegel auto loc = op.getLoc(); 1485550c821STres Popp auto type = cast<ComplexType>(adaptor.getLhs().getType()).getElementType(); 149ac00cb0dSAdrian Kuegel 150c0342a2dSJacques Pienaar Value realLhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getLhs()); 151c0342a2dSJacques Pienaar Value imagLhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getLhs()); 152c0342a2dSJacques Pienaar Value realRhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getRhs()); 153c0342a2dSJacques Pienaar Value imagRhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getRhs()); 154a54f4eaeSMogball Value realComparison = 155a54f4eaeSMogball rewriter.create<arith::CmpFOp>(loc, p, realLhs, realRhs); 156a54f4eaeSMogball Value imagComparison = 157a54f4eaeSMogball rewriter.create<arith::CmpFOp>(loc, p, imagLhs, imagRhs); 158ac00cb0dSAdrian Kuegel 159fb8b2b86SAdrian Kuegel rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison, 160fb8b2b86SAdrian Kuegel imagComparison); 161ac00cb0dSAdrian Kuegel return success(); 162ac00cb0dSAdrian Kuegel } 163ac00cb0dSAdrian Kuegel }; 164942be7cbSAdrian Kuegel 165fb978f09SAdrian Kuegel // Default conversion which applies the BinaryStandardOp separately on the real 166fb978f09SAdrian Kuegel // and imaginary parts. Can for example be used for complex::AddOp and 167fb978f09SAdrian Kuegel // complex::SubOp. 168fb978f09SAdrian Kuegel template <typename BinaryComplexOp, typename BinaryStandardOp> 169fb978f09SAdrian Kuegel struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> { 170fb978f09SAdrian Kuegel using OpConversionPattern<BinaryComplexOp>::OpConversionPattern; 171fb978f09SAdrian Kuegel 172fb978f09SAdrian Kuegel LogicalResult 173b54c724bSRiver Riddle matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor, 174fb978f09SAdrian Kuegel ConversionPatternRewriter &rewriter) const override { 1755550c821STres Popp auto type = cast<ComplexType>(adaptor.getLhs().getType()); 1765550c821STres Popp auto elementType = cast<FloatType>(type.getElementType()); 177fb978f09SAdrian Kuegel mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 178be5b6667SKai Sasaki arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); 179fb978f09SAdrian Kuegel 180c0342a2dSJacques Pienaar Value realLhs = b.create<complex::ReOp>(elementType, adaptor.getLhs()); 181c0342a2dSJacques Pienaar Value realRhs = b.create<complex::ReOp>(elementType, adaptor.getRhs()); 182be5b6667SKai Sasaki Value resultReal = b.create<BinaryStandardOp>(elementType, realLhs, realRhs, 183be5b6667SKai Sasaki fmf.getValue()); 184c0342a2dSJacques Pienaar Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.getLhs()); 185c0342a2dSJacques Pienaar Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.getRhs()); 186be5b6667SKai Sasaki Value resultImag = b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs, 187be5b6667SKai Sasaki fmf.getValue()); 188fb978f09SAdrian Kuegel rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, 189fb978f09SAdrian Kuegel resultImag); 190fb978f09SAdrian Kuegel return success(); 191fb978f09SAdrian Kuegel } 192fb978f09SAdrian Kuegel }; 193fb978f09SAdrian Kuegel 194672b908bSGoran Flegar template <typename TrigonometricOp> 195672b908bSGoran Flegar struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> { 196672b908bSGoran Flegar using OpAdaptor = typename OpConversionPattern<TrigonometricOp>::OpAdaptor; 197672b908bSGoran Flegar 198672b908bSGoran Flegar using OpConversionPattern<TrigonometricOp>::OpConversionPattern; 199672b908bSGoran Flegar 200672b908bSGoran Flegar LogicalResult 201672b908bSGoran Flegar matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor, 202672b908bSGoran Flegar ConversionPatternRewriter &rewriter) const override { 203672b908bSGoran Flegar auto loc = op.getLoc(); 2045550c821STres Popp auto type = cast<ComplexType>(adaptor.getComplex().getType()); 2055550c821STres Popp auto elementType = cast<FloatType>(type.getElementType()); 2067d2d8e2aSKai Sasaki arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); 207672b908bSGoran Flegar 208672b908bSGoran Flegar Value real = 209672b908bSGoran Flegar rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex()); 210672b908bSGoran Flegar Value imag = 211672b908bSGoran Flegar rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex()); 212672b908bSGoran Flegar 213672b908bSGoran Flegar // Trigonometric ops use a set of common building blocks to convert to real 214672b908bSGoran Flegar // ops. Here we create these building blocks and call into an op-specific 215672b908bSGoran Flegar // implementation in the subclass to combine them. 216672b908bSGoran Flegar Value half = rewriter.create<arith::ConstantOp>( 217672b908bSGoran Flegar loc, elementType, rewriter.getFloatAttr(elementType, 0.5)); 2187d2d8e2aSKai Sasaki Value exp = rewriter.create<math::ExpOp>(loc, imag, fmf); 2197d2d8e2aSKai Sasaki Value scaledExp = rewriter.create<arith::MulFOp>(loc, half, exp, fmf); 2207d2d8e2aSKai Sasaki Value reciprocalExp = rewriter.create<arith::DivFOp>(loc, half, exp, fmf); 2217d2d8e2aSKai Sasaki Value sin = rewriter.create<math::SinOp>(loc, real, fmf); 2227d2d8e2aSKai Sasaki Value cos = rewriter.create<math::CosOp>(loc, real, fmf); 223672b908bSGoran Flegar 224672b908bSGoran Flegar auto resultPair = 2257d2d8e2aSKai Sasaki combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter, fmf); 226672b908bSGoran Flegar 227672b908bSGoran Flegar rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultPair.first, 228672b908bSGoran Flegar resultPair.second); 229672b908bSGoran Flegar return success(); 230672b908bSGoran Flegar } 231672b908bSGoran Flegar 232672b908bSGoran Flegar virtual std::pair<Value, Value> 233672b908bSGoran Flegar combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin, 2347d2d8e2aSKai Sasaki Value cos, ConversionPatternRewriter &rewriter, 2357d2d8e2aSKai Sasaki arith::FastMathFlagsAttr fmf) const = 0; 236672b908bSGoran Flegar }; 237672b908bSGoran Flegar 238672b908bSGoran Flegar struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> { 239672b908bSGoran Flegar using TrigonometricOpConversion<complex::CosOp>::TrigonometricOpConversion; 240672b908bSGoran Flegar 2417d2d8e2aSKai Sasaki std::pair<Value, Value> combine(Location loc, Value scaledExp, 2427d2d8e2aSKai Sasaki Value reciprocalExp, Value sin, Value cos, 2437d2d8e2aSKai Sasaki ConversionPatternRewriter &rewriter, 2447d2d8e2aSKai Sasaki arith::FastMathFlagsAttr fmf) const override { 245672b908bSGoran Flegar // Complex cosine is defined as; 246672b908bSGoran Flegar // cos(x + iy) = 0.5 * (exp(i(x + iy)) + exp(-i(x + iy))) 247672b908bSGoran Flegar // Plugging in: 248672b908bSGoran Flegar // exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x)) 249672b908bSGoran Flegar // exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x))) 250672b908bSGoran Flegar // and defining t := exp(y) 251672b908bSGoran Flegar // We get: 252672b908bSGoran Flegar // Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x 253672b908bSGoran Flegar // Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x 2547d2d8e2aSKai Sasaki Value sum = 2557d2d8e2aSKai Sasaki rewriter.create<arith::AddFOp>(loc, reciprocalExp, scaledExp, fmf); 2567d2d8e2aSKai Sasaki Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, cos, fmf); 2577d2d8e2aSKai Sasaki Value diff = 2587d2d8e2aSKai Sasaki rewriter.create<arith::SubFOp>(loc, reciprocalExp, scaledExp, fmf); 2597d2d8e2aSKai Sasaki Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, sin, fmf); 260672b908bSGoran Flegar return {resultReal, resultImag}; 261672b908bSGoran Flegar } 262672b908bSGoran Flegar }; 263672b908bSGoran Flegar 264942be7cbSAdrian Kuegel struct DivOpConversion : public OpConversionPattern<complex::DivOp> { 265942be7cbSAdrian Kuegel using OpConversionPattern<complex::DivOp>::OpConversionPattern; 266942be7cbSAdrian Kuegel 267942be7cbSAdrian Kuegel LogicalResult 268b54c724bSRiver Riddle matchAndRewrite(complex::DivOp op, OpAdaptor adaptor, 269942be7cbSAdrian Kuegel ConversionPatternRewriter &rewriter) const override { 270942be7cbSAdrian Kuegel auto loc = op.getLoc(); 2715550c821STres Popp auto type = cast<ComplexType>(adaptor.getLhs().getType()); 2725550c821STres Popp auto elementType = cast<FloatType>(type.getElementType()); 273288d317fSKai Sasaki arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); 274942be7cbSAdrian Kuegel 275942be7cbSAdrian Kuegel Value lhsReal = 276c0342a2dSJacques Pienaar rewriter.create<complex::ReOp>(loc, elementType, adaptor.getLhs()); 277942be7cbSAdrian Kuegel Value lhsImag = 278c0342a2dSJacques Pienaar rewriter.create<complex::ImOp>(loc, elementType, adaptor.getLhs()); 279942be7cbSAdrian Kuegel Value rhsReal = 280c0342a2dSJacques Pienaar rewriter.create<complex::ReOp>(loc, elementType, adaptor.getRhs()); 281942be7cbSAdrian Kuegel Value rhsImag = 282c0342a2dSJacques Pienaar rewriter.create<complex::ImOp>(loc, elementType, adaptor.getRhs()); 283942be7cbSAdrian Kuegel 284942be7cbSAdrian Kuegel // Smith's algorithm to divide complex numbers. It is just a bit smarter 285942be7cbSAdrian Kuegel // way to compute the following formula: 286942be7cbSAdrian Kuegel // (lhsReal + lhsImag * i) / (rhsReal + rhsImag * i) 287942be7cbSAdrian Kuegel // = (lhsReal + lhsImag * i) (rhsReal - rhsImag * i) / 288942be7cbSAdrian Kuegel // ((rhsReal + rhsImag * i)(rhsReal - rhsImag * i)) 289942be7cbSAdrian Kuegel // = ((lhsReal * rhsReal + lhsImag * rhsImag) + 290942be7cbSAdrian Kuegel // (lhsImag * rhsReal - lhsReal * rhsImag) * i) / ||rhs||^2 291942be7cbSAdrian Kuegel // 292942be7cbSAdrian Kuegel // Depending on whether |rhsReal| < |rhsImag| we compute either 293942be7cbSAdrian Kuegel // rhsRealImagRatio = rhsReal / rhsImag 294942be7cbSAdrian Kuegel // rhsRealImagDenom = rhsImag + rhsReal * rhsRealImagRatio 295942be7cbSAdrian Kuegel // resultReal = (lhsReal * rhsRealImagRatio + lhsImag) / rhsRealImagDenom 296942be7cbSAdrian Kuegel // resultImag = (lhsImag * rhsRealImagRatio - lhsReal) / rhsRealImagDenom 297942be7cbSAdrian Kuegel // 298942be7cbSAdrian Kuegel // or 299942be7cbSAdrian Kuegel // 300942be7cbSAdrian Kuegel // rhsImagRealRatio = rhsImag / rhsReal 301942be7cbSAdrian Kuegel // rhsImagRealDenom = rhsReal + rhsImag * rhsImagRealRatio 302942be7cbSAdrian Kuegel // resultReal = (lhsReal + lhsImag * rhsImagRealRatio) / rhsImagRealDenom 303942be7cbSAdrian Kuegel // resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom 304942be7cbSAdrian Kuegel // 305942be7cbSAdrian Kuegel // See https://dl.acm.org/citation.cfm?id=368661 for more details. 306a54f4eaeSMogball Value rhsRealImagRatio = 307288d317fSKai Sasaki rewriter.create<arith::DivFOp>(loc, rhsReal, rhsImag, fmf); 308a54f4eaeSMogball Value rhsRealImagDenom = rewriter.create<arith::AddFOp>( 309a54f4eaeSMogball loc, rhsImag, 310288d317fSKai Sasaki rewriter.create<arith::MulFOp>(loc, rhsRealImagRatio, rhsReal, fmf), 311288d317fSKai Sasaki fmf); 312a54f4eaeSMogball Value realNumerator1 = rewriter.create<arith::AddFOp>( 313288d317fSKai Sasaki loc, 314288d317fSKai Sasaki rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealImagRatio, fmf), 315288d317fSKai Sasaki lhsImag, fmf); 316288d317fSKai Sasaki Value resultReal1 = rewriter.create<arith::DivFOp>(loc, realNumerator1, 317288d317fSKai Sasaki rhsRealImagDenom, fmf); 318a54f4eaeSMogball Value imagNumerator1 = rewriter.create<arith::SubFOp>( 319288d317fSKai Sasaki loc, 320288d317fSKai Sasaki rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealImagRatio, fmf), 321288d317fSKai Sasaki lhsReal, fmf); 322288d317fSKai Sasaki Value resultImag1 = rewriter.create<arith::DivFOp>(loc, imagNumerator1, 323288d317fSKai Sasaki rhsRealImagDenom, fmf); 324942be7cbSAdrian Kuegel 325a54f4eaeSMogball Value rhsImagRealRatio = 326288d317fSKai Sasaki rewriter.create<arith::DivFOp>(loc, rhsImag, rhsReal, fmf); 327a54f4eaeSMogball Value rhsImagRealDenom = rewriter.create<arith::AddFOp>( 328a54f4eaeSMogball loc, rhsReal, 329288d317fSKai Sasaki rewriter.create<arith::MulFOp>(loc, rhsImagRealRatio, rhsImag, fmf), 330288d317fSKai Sasaki fmf); 331a54f4eaeSMogball Value realNumerator2 = rewriter.create<arith::AddFOp>( 332a54f4eaeSMogball loc, lhsReal, 333288d317fSKai Sasaki rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagRealRatio, fmf), 334288d317fSKai Sasaki fmf); 335288d317fSKai Sasaki Value resultReal2 = rewriter.create<arith::DivFOp>(loc, realNumerator2, 336288d317fSKai Sasaki rhsImagRealDenom, fmf); 337a54f4eaeSMogball Value imagNumerator2 = rewriter.create<arith::SubFOp>( 338a54f4eaeSMogball loc, lhsImag, 339288d317fSKai Sasaki rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagRealRatio, fmf), 340288d317fSKai Sasaki fmf); 341288d317fSKai Sasaki Value resultImag2 = rewriter.create<arith::DivFOp>(loc, imagNumerator2, 342288d317fSKai Sasaki rhsImagRealDenom, fmf); 343942be7cbSAdrian Kuegel 344942be7cbSAdrian Kuegel // Consider corner cases. 345942be7cbSAdrian Kuegel // Case 1. Zero denominator, numerator contains at most one NaN value. 346a54f4eaeSMogball Value zero = rewriter.create<arith::ConstantOp>( 347a54f4eaeSMogball loc, elementType, rewriter.getZeroAttr(elementType)); 348288d317fSKai Sasaki Value rhsRealAbs = rewriter.create<math::AbsFOp>(loc, rhsReal, fmf); 349a54f4eaeSMogball Value rhsRealIsZero = rewriter.create<arith::CmpFOp>( 350a54f4eaeSMogball loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero); 351288d317fSKai Sasaki Value rhsImagAbs = rewriter.create<math::AbsFOp>(loc, rhsImag, fmf); 352a54f4eaeSMogball Value rhsImagIsZero = rewriter.create<arith::CmpFOp>( 353a54f4eaeSMogball loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero); 354a54f4eaeSMogball Value lhsRealIsNotNaN = rewriter.create<arith::CmpFOp>( 355a54f4eaeSMogball loc, arith::CmpFPredicate::ORD, lhsReal, zero); 356a54f4eaeSMogball Value lhsImagIsNotNaN = rewriter.create<arith::CmpFOp>( 357a54f4eaeSMogball loc, arith::CmpFPredicate::ORD, lhsImag, zero); 358942be7cbSAdrian Kuegel Value lhsContainsNotNaNValue = 359a54f4eaeSMogball rewriter.create<arith::OrIOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN); 360a54f4eaeSMogball Value resultIsInfinity = rewriter.create<arith::AndIOp>( 361942be7cbSAdrian Kuegel loc, lhsContainsNotNaNValue, 362a54f4eaeSMogball rewriter.create<arith::AndIOp>(loc, rhsRealIsZero, rhsImagIsZero)); 363a54f4eaeSMogball Value inf = rewriter.create<arith::ConstantOp>( 364942be7cbSAdrian Kuegel loc, elementType, 365942be7cbSAdrian Kuegel rewriter.getFloatAttr( 366942be7cbSAdrian Kuegel elementType, APFloat::getInf(elementType.getFloatSemantics()))); 367a54f4eaeSMogball Value infWithSignOfRhsReal = 368a54f4eaeSMogball rewriter.create<math::CopySignOp>(loc, inf, rhsReal); 369942be7cbSAdrian Kuegel Value infinityResultReal = 370288d317fSKai Sasaki rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsReal, fmf); 371942be7cbSAdrian Kuegel Value infinityResultImag = 372288d317fSKai Sasaki rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsImag, fmf); 373942be7cbSAdrian Kuegel 374942be7cbSAdrian Kuegel // Case 2. Infinite numerator, finite denominator. 375a54f4eaeSMogball Value rhsRealFinite = rewriter.create<arith::CmpFOp>( 376a54f4eaeSMogball loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf); 377a54f4eaeSMogball Value rhsImagFinite = rewriter.create<arith::CmpFOp>( 378a54f4eaeSMogball loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf); 379a54f4eaeSMogball Value rhsFinite = 380a54f4eaeSMogball rewriter.create<arith::AndIOp>(loc, rhsRealFinite, rhsImagFinite); 381288d317fSKai Sasaki Value lhsRealAbs = rewriter.create<math::AbsFOp>(loc, lhsReal, fmf); 382a54f4eaeSMogball Value lhsRealInfinite = rewriter.create<arith::CmpFOp>( 383a54f4eaeSMogball loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf); 384288d317fSKai Sasaki Value lhsImagAbs = rewriter.create<math::AbsFOp>(loc, lhsImag, fmf); 385a54f4eaeSMogball Value lhsImagInfinite = rewriter.create<arith::CmpFOp>( 386a54f4eaeSMogball loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf); 387942be7cbSAdrian Kuegel Value lhsInfinite = 388a54f4eaeSMogball rewriter.create<arith::OrIOp>(loc, lhsRealInfinite, lhsImagInfinite); 389942be7cbSAdrian Kuegel Value infNumFiniteDenom = 390a54f4eaeSMogball rewriter.create<arith::AndIOp>(loc, lhsInfinite, rhsFinite); 391a54f4eaeSMogball Value one = rewriter.create<arith::ConstantOp>( 392942be7cbSAdrian Kuegel loc, elementType, rewriter.getFloatAttr(elementType, 1)); 393a54f4eaeSMogball Value lhsRealIsInfWithSign = rewriter.create<math::CopySignOp>( 394dec8af70SRiver Riddle loc, rewriter.create<arith::SelectOp>(loc, lhsRealInfinite, one, zero), 395942be7cbSAdrian Kuegel lhsReal); 396a54f4eaeSMogball Value lhsImagIsInfWithSign = rewriter.create<math::CopySignOp>( 397dec8af70SRiver Riddle loc, rewriter.create<arith::SelectOp>(loc, lhsImagInfinite, one, zero), 398942be7cbSAdrian Kuegel lhsImag); 399942be7cbSAdrian Kuegel Value lhsRealIsInfWithSignTimesRhsReal = 400288d317fSKai Sasaki rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsReal, fmf); 401942be7cbSAdrian Kuegel Value lhsImagIsInfWithSignTimesRhsImag = 402288d317fSKai Sasaki rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsImag, fmf); 403a54f4eaeSMogball Value resultReal3 = rewriter.create<arith::MulFOp>( 404942be7cbSAdrian Kuegel loc, inf, 405a54f4eaeSMogball rewriter.create<arith::AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal, 406288d317fSKai Sasaki lhsImagIsInfWithSignTimesRhsImag, fmf), 407288d317fSKai Sasaki fmf); 408942be7cbSAdrian Kuegel Value lhsRealIsInfWithSignTimesRhsImag = 409288d317fSKai Sasaki rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsImag, fmf); 410942be7cbSAdrian Kuegel Value lhsImagIsInfWithSignTimesRhsReal = 411288d317fSKai Sasaki rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsReal, fmf); 412a54f4eaeSMogball Value resultImag3 = rewriter.create<arith::MulFOp>( 413942be7cbSAdrian Kuegel loc, inf, 414a54f4eaeSMogball rewriter.create<arith::SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal, 415288d317fSKai Sasaki lhsRealIsInfWithSignTimesRhsImag, fmf), 416288d317fSKai Sasaki fmf); 417942be7cbSAdrian Kuegel 418942be7cbSAdrian Kuegel // Case 3: Finite numerator, infinite denominator. 419a54f4eaeSMogball Value lhsRealFinite = rewriter.create<arith::CmpFOp>( 420a54f4eaeSMogball loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf); 421a54f4eaeSMogball Value lhsImagFinite = rewriter.create<arith::CmpFOp>( 422a54f4eaeSMogball loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf); 423a54f4eaeSMogball Value lhsFinite = 424a54f4eaeSMogball rewriter.create<arith::AndIOp>(loc, lhsRealFinite, lhsImagFinite); 425a54f4eaeSMogball Value rhsRealInfinite = rewriter.create<arith::CmpFOp>( 426a54f4eaeSMogball loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf); 427a54f4eaeSMogball Value rhsImagInfinite = rewriter.create<arith::CmpFOp>( 428a54f4eaeSMogball loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf); 429942be7cbSAdrian Kuegel Value rhsInfinite = 430a54f4eaeSMogball rewriter.create<arith::OrIOp>(loc, rhsRealInfinite, rhsImagInfinite); 431942be7cbSAdrian Kuegel Value finiteNumInfiniteDenom = 432a54f4eaeSMogball rewriter.create<arith::AndIOp>(loc, lhsFinite, rhsInfinite); 433a54f4eaeSMogball Value rhsRealIsInfWithSign = rewriter.create<math::CopySignOp>( 434dec8af70SRiver Riddle loc, rewriter.create<arith::SelectOp>(loc, rhsRealInfinite, one, zero), 435942be7cbSAdrian Kuegel rhsReal); 436a54f4eaeSMogball Value rhsImagIsInfWithSign = rewriter.create<math::CopySignOp>( 437dec8af70SRiver Riddle loc, rewriter.create<arith::SelectOp>(loc, rhsImagInfinite, one, zero), 438942be7cbSAdrian Kuegel rhsImag); 439942be7cbSAdrian Kuegel Value rhsRealIsInfWithSignTimesLhsReal = 440288d317fSKai Sasaki rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealIsInfWithSign, fmf); 441942be7cbSAdrian Kuegel Value rhsImagIsInfWithSignTimesLhsImag = 442288d317fSKai Sasaki rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagIsInfWithSign, fmf); 443a54f4eaeSMogball Value resultReal4 = rewriter.create<arith::MulFOp>( 444942be7cbSAdrian Kuegel loc, zero, 445a54f4eaeSMogball rewriter.create<arith::AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal, 446288d317fSKai Sasaki rhsImagIsInfWithSignTimesLhsImag, fmf), 447288d317fSKai Sasaki fmf); 448942be7cbSAdrian Kuegel Value rhsRealIsInfWithSignTimesLhsImag = 449288d317fSKai Sasaki rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealIsInfWithSign, fmf); 450942be7cbSAdrian Kuegel Value rhsImagIsInfWithSignTimesLhsReal = 451288d317fSKai Sasaki rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagIsInfWithSign, fmf); 452a54f4eaeSMogball Value resultImag4 = rewriter.create<arith::MulFOp>( 453942be7cbSAdrian Kuegel loc, zero, 454a54f4eaeSMogball rewriter.create<arith::SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag, 455288d317fSKai Sasaki rhsImagIsInfWithSignTimesLhsReal, fmf), 456288d317fSKai Sasaki fmf); 457942be7cbSAdrian Kuegel 458a54f4eaeSMogball Value realAbsSmallerThanImagAbs = rewriter.create<arith::CmpFOp>( 459a54f4eaeSMogball loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs); 460dec8af70SRiver Riddle Value resultReal = rewriter.create<arith::SelectOp>( 461dec8af70SRiver Riddle loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2); 462dec8af70SRiver Riddle Value resultImag = rewriter.create<arith::SelectOp>( 463dec8af70SRiver Riddle loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2); 464dec8af70SRiver Riddle Value resultRealSpecialCase3 = rewriter.create<arith::SelectOp>( 465942be7cbSAdrian Kuegel loc, finiteNumInfiniteDenom, resultReal4, resultReal); 466dec8af70SRiver Riddle Value resultImagSpecialCase3 = rewriter.create<arith::SelectOp>( 467942be7cbSAdrian Kuegel loc, finiteNumInfiniteDenom, resultImag4, resultImag); 468dec8af70SRiver Riddle Value resultRealSpecialCase2 = rewriter.create<arith::SelectOp>( 469942be7cbSAdrian Kuegel loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3); 470dec8af70SRiver Riddle Value resultImagSpecialCase2 = rewriter.create<arith::SelectOp>( 471942be7cbSAdrian Kuegel loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3); 472dec8af70SRiver Riddle Value resultRealSpecialCase1 = rewriter.create<arith::SelectOp>( 473942be7cbSAdrian Kuegel loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2); 474dec8af70SRiver Riddle Value resultImagSpecialCase1 = rewriter.create<arith::SelectOp>( 475942be7cbSAdrian Kuegel loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2); 476942be7cbSAdrian Kuegel 477a54f4eaeSMogball Value resultRealIsNaN = rewriter.create<arith::CmpFOp>( 478a54f4eaeSMogball loc, arith::CmpFPredicate::UNO, resultReal, zero); 479a54f4eaeSMogball Value resultImagIsNaN = rewriter.create<arith::CmpFOp>( 480a54f4eaeSMogball loc, arith::CmpFPredicate::UNO, resultImag, zero); 481942be7cbSAdrian Kuegel Value resultIsNaN = 482a54f4eaeSMogball rewriter.create<arith::AndIOp>(loc, resultRealIsNaN, resultImagIsNaN); 483dec8af70SRiver Riddle Value resultRealWithSpecialCases = rewriter.create<arith::SelectOp>( 484942be7cbSAdrian Kuegel loc, resultIsNaN, resultRealSpecialCase1, resultReal); 485dec8af70SRiver Riddle Value resultImagWithSpecialCases = rewriter.create<arith::SelectOp>( 486942be7cbSAdrian Kuegel loc, resultIsNaN, resultImagSpecialCase1, resultImag); 487942be7cbSAdrian Kuegel 488942be7cbSAdrian Kuegel rewriter.replaceOpWithNewOp<complex::CreateOp>( 489942be7cbSAdrian Kuegel op, type, resultRealWithSpecialCases, resultImagWithSpecialCases); 490942be7cbSAdrian Kuegel return success(); 491942be7cbSAdrian Kuegel } 492942be7cbSAdrian Kuegel }; 49373cbc91cSAdrian Kuegel 49473cbc91cSAdrian Kuegel struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> { 49573cbc91cSAdrian Kuegel using OpConversionPattern<complex::ExpOp>::OpConversionPattern; 49673cbc91cSAdrian Kuegel 49773cbc91cSAdrian Kuegel LogicalResult 498b54c724bSRiver Riddle matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor, 49973cbc91cSAdrian Kuegel ConversionPatternRewriter &rewriter) const override { 50073cbc91cSAdrian Kuegel auto loc = op.getLoc(); 5015550c821STres Popp auto type = cast<ComplexType>(adaptor.getComplex().getType()); 5025550c821STres Popp auto elementType = cast<FloatType>(type.getElementType()); 503d230bf3fSKai Sasaki arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); 50473cbc91cSAdrian Kuegel 50573cbc91cSAdrian Kuegel Value real = 506c0342a2dSJacques Pienaar rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex()); 50773cbc91cSAdrian Kuegel Value imag = 508c0342a2dSJacques Pienaar rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex()); 509d230bf3fSKai Sasaki Value expReal = rewriter.create<math::ExpOp>(loc, real, fmf.getValue()); 510d230bf3fSKai Sasaki Value cosImag = rewriter.create<math::CosOp>(loc, imag, fmf.getValue()); 511d230bf3fSKai Sasaki Value resultReal = 512d230bf3fSKai Sasaki rewriter.create<arith::MulFOp>(loc, expReal, cosImag, fmf.getValue()); 513d230bf3fSKai Sasaki Value sinImag = rewriter.create<math::SinOp>(loc, imag, fmf.getValue()); 514d230bf3fSKai Sasaki Value resultImag = 515d230bf3fSKai Sasaki rewriter.create<arith::MulFOp>(loc, expReal, sinImag, fmf.getValue()); 51673cbc91cSAdrian Kuegel 51773cbc91cSAdrian Kuegel rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, 51873cbc91cSAdrian Kuegel resultImag); 51973cbc91cSAdrian Kuegel return success(); 52073cbc91cSAdrian Kuegel } 52173cbc91cSAdrian Kuegel }; 522662e074dSAdrian Kuegel 52318ee0032SAlexander Belyaev Value evaluatePolynomial(ImplicitLocOpBuilder &b, Value arg, 52418ee0032SAlexander Belyaev ArrayRef<double> coefficients, 52518ee0032SAlexander Belyaev arith::FastMathFlagsAttr fmf) { 52618ee0032SAlexander Belyaev auto argType = mlir::cast<FloatType>(arg.getType()); 52718ee0032SAlexander Belyaev Value poly = 52818ee0032SAlexander Belyaev b.create<arith::ConstantOp>(b.getFloatAttr(argType, coefficients[0])); 529*06011feeSJie Fu for (unsigned i = 1; i < coefficients.size(); ++i) { 53018ee0032SAlexander Belyaev poly = b.create<math::FmaOp>( 53118ee0032SAlexander Belyaev poly, arg, 53218ee0032SAlexander Belyaev b.create<arith::ConstantOp>(b.getFloatAttr(argType, coefficients[i])), 53318ee0032SAlexander Belyaev fmf); 53418ee0032SAlexander Belyaev } 53518ee0032SAlexander Belyaev return poly; 53618ee0032SAlexander Belyaev } 53718ee0032SAlexander Belyaev 538338e76f8Sbixia1 struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> { 539338e76f8Sbixia1 using OpConversionPattern<complex::Expm1Op>::OpConversionPattern; 540338e76f8Sbixia1 54118ee0032SAlexander Belyaev // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i 54218ee0032SAlexander Belyaev // [handle inaccuracies when a and/or b are small] 54318ee0032SAlexander Belyaev // = ((e^a - 1) * cos(b) + cos(b) - 1) + e^a*sin(b)i 54418ee0032SAlexander Belyaev // = (expm1(a) * cos(b) + cosm1(b)) + e^a*sin(b)i 545338e76f8Sbixia1 LogicalResult 546338e76f8Sbixia1 matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor, 547338e76f8Sbixia1 ConversionPatternRewriter &rewriter) const override { 54818ee0032SAlexander Belyaev auto type = op.getType(); 54918ee0032SAlexander Belyaev auto elemType = mlir::cast<FloatType>(type.getElementType()); 55018ee0032SAlexander Belyaev 551d230bf3fSKai Sasaki arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); 55218ee0032SAlexander Belyaev ImplicitLocOpBuilder b(op.getLoc(), rewriter); 55318ee0032SAlexander Belyaev Value real = b.create<complex::ReOp>(adaptor.getComplex()); 55418ee0032SAlexander Belyaev Value imag = b.create<complex::ImOp>(adaptor.getComplex()); 555338e76f8Sbixia1 55618ee0032SAlexander Belyaev Value zero = b.create<arith::ConstantOp>(b.getFloatAttr(elemType, 0.0)); 55718ee0032SAlexander Belyaev Value one = b.create<arith::ConstantOp>(b.getFloatAttr(elemType, 1.0)); 558338e76f8Sbixia1 55918ee0032SAlexander Belyaev Value expm1Real = b.create<math::ExpM1Op>(real, fmf); 56018ee0032SAlexander Belyaev Value expReal = b.create<arith::AddFOp>(expm1Real, one, fmf); 561338e76f8Sbixia1 56218ee0032SAlexander Belyaev Value sinImag = b.create<math::SinOp>(imag, fmf); 56318ee0032SAlexander Belyaev Value cosm1Imag = emitCosm1(imag, fmf, b); 56418ee0032SAlexander Belyaev Value cosImag = b.create<arith::AddFOp>(cosm1Imag, one, fmf); 56518ee0032SAlexander Belyaev 56618ee0032SAlexander Belyaev Value realResult = b.create<arith::AddFOp>( 56718ee0032SAlexander Belyaev b.create<arith::MulFOp>(expm1Real, cosImag, fmf), cosm1Imag, fmf); 56818ee0032SAlexander Belyaev 56918ee0032SAlexander Belyaev Value imagIsZero = b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, 57018ee0032SAlexander Belyaev zero, fmf.getValue()); 57118ee0032SAlexander Belyaev Value imagResult = b.create<arith::SelectOp>( 57218ee0032SAlexander Belyaev imagIsZero, zero, b.create<arith::MulFOp>(expReal, sinImag, fmf)); 57318ee0032SAlexander Belyaev 57418ee0032SAlexander Belyaev rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realResult, 57518ee0032SAlexander Belyaev imagResult); 576338e76f8Sbixia1 return success(); 577338e76f8Sbixia1 } 57818ee0032SAlexander Belyaev 57918ee0032SAlexander Belyaev private: 58018ee0032SAlexander Belyaev Value emitCosm1(Value arg, arith::FastMathFlagsAttr fmf, 58118ee0032SAlexander Belyaev ImplicitLocOpBuilder &b) const { 58218ee0032SAlexander Belyaev auto argType = mlir::cast<FloatType>(arg.getType()); 58318ee0032SAlexander Belyaev auto negHalf = b.create<arith::ConstantOp>(b.getFloatAttr(argType, -0.5)); 58418ee0032SAlexander Belyaev auto negOne = b.create<arith::ConstantOp>(b.getFloatAttr(argType, -1.0)); 58518ee0032SAlexander Belyaev 58618ee0032SAlexander Belyaev // Algorithm copied from cephes cosm1. 58718ee0032SAlexander Belyaev SmallVector<double, 7> kCoeffs{ 58818ee0032SAlexander Belyaev 4.7377507964246204691685E-14, -1.1470284843425359765671E-11, 58918ee0032SAlexander Belyaev 2.0876754287081521758361E-9, -2.7557319214999787979814E-7, 59018ee0032SAlexander Belyaev 2.4801587301570552304991E-5, -1.3888888888888872993737E-3, 59118ee0032SAlexander Belyaev 4.1666666666666666609054E-2, 59218ee0032SAlexander Belyaev }; 59318ee0032SAlexander Belyaev Value cos = b.create<math::CosOp>(arg, fmf); 59418ee0032SAlexander Belyaev Value forLargeArg = b.create<arith::AddFOp>(cos, negOne, fmf); 59518ee0032SAlexander Belyaev 59618ee0032SAlexander Belyaev Value argPow2 = b.create<arith::MulFOp>(arg, arg, fmf); 59718ee0032SAlexander Belyaev Value argPow4 = b.create<arith::MulFOp>(argPow2, argPow2, fmf); 59818ee0032SAlexander Belyaev Value poly = evaluatePolynomial(b, argPow2, kCoeffs, fmf); 59918ee0032SAlexander Belyaev 60018ee0032SAlexander Belyaev auto forSmallArg = 60118ee0032SAlexander Belyaev b.create<arith::AddFOp>(b.create<arith::MulFOp>(argPow4, poly, fmf), 60218ee0032SAlexander Belyaev b.create<arith::MulFOp>(negHalf, argPow2, fmf)); 60318ee0032SAlexander Belyaev 60418ee0032SAlexander Belyaev // (pi/4)^2 is approximately 0.61685 60518ee0032SAlexander Belyaev Value piOver4Pow2 = 60618ee0032SAlexander Belyaev b.create<arith::ConstantOp>(b.getFloatAttr(argType, 0.61685)); 60718ee0032SAlexander Belyaev Value cond = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, argPow2, 60818ee0032SAlexander Belyaev piOver4Pow2, fmf.getValue()); 60918ee0032SAlexander Belyaev return b.create<arith::SelectOp>(cond, forLargeArg, forSmallArg); 61018ee0032SAlexander Belyaev } 611338e76f8Sbixia1 }; 612338e76f8Sbixia1 613380fa71fSAdrian Kuegel struct LogOpConversion : public OpConversionPattern<complex::LogOp> { 614380fa71fSAdrian Kuegel using OpConversionPattern<complex::LogOp>::OpConversionPattern; 615380fa71fSAdrian Kuegel 616380fa71fSAdrian Kuegel LogicalResult 617b54c724bSRiver Riddle matchAndRewrite(complex::LogOp op, OpAdaptor adaptor, 618380fa71fSAdrian Kuegel ConversionPatternRewriter &rewriter) const override { 6195550c821STres Popp auto type = cast<ComplexType>(adaptor.getComplex().getType()); 6205550c821STres Popp auto elementType = cast<FloatType>(type.getElementType()); 6218aaa2cb8SKai Sasaki arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); 622380fa71fSAdrian Kuegel mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 623380fa71fSAdrian Kuegel 6248aaa2cb8SKai Sasaki Value abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex(), 6258aaa2cb8SKai Sasaki fmf.getValue()); 6268aaa2cb8SKai Sasaki Value resultReal = b.create<math::LogOp>(elementType, abs, fmf.getValue()); 627c0342a2dSJacques Pienaar Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex()); 628c0342a2dSJacques Pienaar Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex()); 6298aaa2cb8SKai Sasaki Value resultImag = 6308aaa2cb8SKai Sasaki b.create<math::Atan2Op>(elementType, imag, real, fmf.getValue()); 631380fa71fSAdrian Kuegel rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, 632380fa71fSAdrian Kuegel resultImag); 633380fa71fSAdrian Kuegel return success(); 634380fa71fSAdrian Kuegel } 635380fa71fSAdrian Kuegel }; 636380fa71fSAdrian Kuegel 6376e80e3bdSAdrian Kuegel struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> { 6386e80e3bdSAdrian Kuegel using OpConversionPattern<complex::Log1pOp>::OpConversionPattern; 6396e80e3bdSAdrian Kuegel 6406e80e3bdSAdrian Kuegel LogicalResult 641b54c724bSRiver Riddle matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor, 6426e80e3bdSAdrian Kuegel ConversionPatternRewriter &rewriter) const override { 6435550c821STres Popp auto type = cast<ComplexType>(adaptor.getComplex().getType()); 6445550c821STres Popp auto elementType = cast<FloatType>(type.getElementType()); 6455c9315f5SJohannes Reifferscheid arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue(); 6466e80e3bdSAdrian Kuegel mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 6476e80e3bdSAdrian Kuegel 6485c9315f5SJohannes Reifferscheid Value real = b.create<complex::ReOp>(adaptor.getComplex()); 6495c9315f5SJohannes Reifferscheid Value imag = b.create<complex::ImOp>(adaptor.getComplex()); 650375a5cb6SJohannes Reifferscheid 651375a5cb6SJohannes Reifferscheid Value half = b.create<arith::ConstantOp>(elementType, 652375a5cb6SJohannes Reifferscheid b.getFloatAttr(elementType, 0.5)); 653a54f4eaeSMogball Value one = b.create<arith::ConstantOp>(elementType, 654a54f4eaeSMogball b.getFloatAttr(elementType, 1)); 6555c9315f5SJohannes Reifferscheid Value realPlusOne = b.create<arith::AddFOp>(real, one, fmf); 6565c9315f5SJohannes Reifferscheid Value absRealPlusOne = b.create<math::AbsFOp>(realPlusOne, fmf); 6575c9315f5SJohannes Reifferscheid Value absImag = b.create<math::AbsFOp>(imag, fmf); 658375a5cb6SJohannes Reifferscheid 6595c9315f5SJohannes Reifferscheid Value maxAbs = b.create<arith::MaximumFOp>(absRealPlusOne, absImag, fmf); 6605c9315f5SJohannes Reifferscheid Value minAbs = b.create<arith::MinimumFOp>(absRealPlusOne, absImag, fmf); 661375a5cb6SJohannes Reifferscheid 6625c9315f5SJohannes Reifferscheid Value useReal = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, 6635c9315f5SJohannes Reifferscheid realPlusOne, absImag, fmf); 6645c9315f5SJohannes Reifferscheid Value maxMinusOne = b.create<arith::SubFOp>(maxAbs, one, fmf); 6655c9315f5SJohannes Reifferscheid Value maxAbsOfRealPlusOneAndImagMinusOne = 6665c9315f5SJohannes Reifferscheid b.create<arith::SelectOp>(useReal, real, maxMinusOne); 6679b225d01SJohannes Reifferscheid arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear( 6689b225d01SJohannes Reifferscheid fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf); 6699b225d01SJohannes Reifferscheid Value minMaxRatio = b.create<arith::DivFOp>(minAbs, maxAbs, fmfWithNaNInf); 6705c9315f5SJohannes Reifferscheid Value logOfMaxAbsOfRealPlusOneAndImag = 6715c9315f5SJohannes Reifferscheid b.create<math::Log1pOp>(maxAbsOfRealPlusOneAndImagMinusOne, fmf); 6725c9315f5SJohannes Reifferscheid Value logOfSqrtPart = b.create<math::Log1pOp>( 6739b225d01SJohannes Reifferscheid b.create<arith::MulFOp>(minMaxRatio, minMaxRatio, fmfWithNaNInf), 6749b225d01SJohannes Reifferscheid fmfWithNaNInf); 6755c9315f5SJohannes Reifferscheid Value r = b.create<arith::AddFOp>( 6769b225d01SJohannes Reifferscheid b.create<arith::MulFOp>(half, logOfSqrtPart, fmfWithNaNInf), 6779b225d01SJohannes Reifferscheid logOfMaxAbsOfRealPlusOneAndImag, fmfWithNaNInf); 6785c9315f5SJohannes Reifferscheid Value resultReal = b.create<arith::SelectOp>( 6799b225d01SJohannes Reifferscheid b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, r, r, fmfWithNaNInf), 6809b225d01SJohannes Reifferscheid minAbs, r); 6815c9315f5SJohannes Reifferscheid Value resultImag = b.create<math::Atan2Op>(imag, realPlusOne, fmf); 682375a5cb6SJohannes Reifferscheid rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, 683375a5cb6SJohannes Reifferscheid resultImag); 6846e80e3bdSAdrian Kuegel return success(); 6856e80e3bdSAdrian Kuegel } 6866e80e3bdSAdrian Kuegel }; 6876e80e3bdSAdrian Kuegel 688bf17ee19SAdrian Kuegel struct MulOpConversion : public OpConversionPattern<complex::MulOp> { 689bf17ee19SAdrian Kuegel using OpConversionPattern<complex::MulOp>::OpConversionPattern; 690bf17ee19SAdrian Kuegel 691bf17ee19SAdrian Kuegel LogicalResult 692b54c724bSRiver Riddle matchAndRewrite(complex::MulOp op, OpAdaptor adaptor, 693bf17ee19SAdrian Kuegel ConversionPatternRewriter &rewriter) const override { 694bf17ee19SAdrian Kuegel mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 6955550c821STres Popp auto type = cast<ComplexType>(adaptor.getLhs().getType()); 6965550c821STres Popp auto elementType = cast<FloatType>(type.getElementType()); 697eee71ed3SKai Sasaki arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); 698eee71ed3SKai Sasaki auto fmfValue = fmf.getValue(); 699c0342a2dSJacques Pienaar Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.getLhs()); 700c0342a2dSJacques Pienaar Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.getLhs()); 701c0342a2dSJacques Pienaar Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.getRhs()); 702c0342a2dSJacques Pienaar Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.getRhs()); 703eee71ed3SKai Sasaki Value lhsRealTimesRhsReal = 704eee71ed3SKai Sasaki b.create<arith::MulFOp>(lhsReal, rhsReal, fmfValue); 705eee71ed3SKai Sasaki Value lhsImagTimesRhsImag = 706eee71ed3SKai Sasaki b.create<arith::MulFOp>(lhsImag, rhsImag, fmfValue); 707eee71ed3SKai Sasaki Value real = b.create<arith::SubFOp>(lhsRealTimesRhsReal, 708eee71ed3SKai Sasaki lhsImagTimesRhsImag, fmfValue); 709eee71ed3SKai Sasaki Value lhsImagTimesRhsReal = 710eee71ed3SKai Sasaki b.create<arith::MulFOp>(lhsImag, rhsReal, fmfValue); 711eee71ed3SKai Sasaki Value lhsRealTimesRhsImag = 712eee71ed3SKai Sasaki b.create<arith::MulFOp>(lhsReal, rhsImag, fmfValue); 713eee71ed3SKai Sasaki Value imag = b.create<arith::AddFOp>(lhsImagTimesRhsReal, 714eee71ed3SKai Sasaki lhsRealTimesRhsImag, fmfValue); 715bf17ee19SAdrian Kuegel rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag); 716bf17ee19SAdrian Kuegel return success(); 717bf17ee19SAdrian Kuegel } 718bf17ee19SAdrian Kuegel }; 719bf17ee19SAdrian Kuegel 720662e074dSAdrian Kuegel struct NegOpConversion : public OpConversionPattern<complex::NegOp> { 721662e074dSAdrian Kuegel using OpConversionPattern<complex::NegOp>::OpConversionPattern; 722662e074dSAdrian Kuegel 723662e074dSAdrian Kuegel LogicalResult 724b54c724bSRiver Riddle matchAndRewrite(complex::NegOp op, OpAdaptor adaptor, 725662e074dSAdrian Kuegel ConversionPatternRewriter &rewriter) const override { 726662e074dSAdrian Kuegel auto loc = op.getLoc(); 7275550c821STres Popp auto type = cast<ComplexType>(adaptor.getComplex().getType()); 7285550c821STres Popp auto elementType = cast<FloatType>(type.getElementType()); 729662e074dSAdrian Kuegel 730662e074dSAdrian Kuegel Value real = 731c0342a2dSJacques Pienaar rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex()); 732662e074dSAdrian Kuegel Value imag = 733c0342a2dSJacques Pienaar rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex()); 734a54f4eaeSMogball Value negReal = rewriter.create<arith::NegFOp>(loc, real); 735a54f4eaeSMogball Value negImag = rewriter.create<arith::NegFOp>(loc, imag); 736662e074dSAdrian Kuegel rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag); 737662e074dSAdrian Kuegel return success(); 738662e074dSAdrian Kuegel } 739662e074dSAdrian Kuegel }; 740f112bd61SAdrian Kuegel 741672b908bSGoran Flegar struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> { 742672b908bSGoran Flegar using TrigonometricOpConversion<complex::SinOp>::TrigonometricOpConversion; 743672b908bSGoran Flegar 7447d2d8e2aSKai Sasaki std::pair<Value, Value> combine(Location loc, Value scaledExp, 7457d2d8e2aSKai Sasaki Value reciprocalExp, Value sin, Value cos, 7467d2d8e2aSKai Sasaki ConversionPatternRewriter &rewriter, 7477d2d8e2aSKai Sasaki arith::FastMathFlagsAttr fmf) const override { 748672b908bSGoran Flegar // Complex sine is defined as; 749672b908bSGoran Flegar // sin(x + iy) = -0.5i * (exp(i(x + iy)) - exp(-i(x + iy))) 750672b908bSGoran Flegar // Plugging in: 751672b908bSGoran Flegar // exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x)) 752672b908bSGoran Flegar // exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x))) 753672b908bSGoran Flegar // and defining t := exp(y) 754672b908bSGoran Flegar // We get: 755672b908bSGoran Flegar // Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x 756672b908bSGoran Flegar // Im(cos(x + iy)) = (0.5*t - 0.5/t) * cos x 7577d2d8e2aSKai Sasaki Value sum = 7587d2d8e2aSKai Sasaki rewriter.create<arith::AddFOp>(loc, scaledExp, reciprocalExp, fmf); 7597d2d8e2aSKai Sasaki Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, sin, fmf); 7607d2d8e2aSKai Sasaki Value diff = 7617d2d8e2aSKai Sasaki rewriter.create<arith::SubFOp>(loc, scaledExp, reciprocalExp, fmf); 7627d2d8e2aSKai Sasaki Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, cos, fmf); 763672b908bSGoran Flegar return {resultReal, resultImag}; 764672b908bSGoran Flegar } 765672b908bSGoran Flegar }; 766672b908bSGoran Flegar 767f711785eSAlexander Belyaev // The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780. 768f711785eSAlexander Belyaev struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> { 769f711785eSAlexander Belyaev using OpConversionPattern<complex::SqrtOp>::OpConversionPattern; 770f711785eSAlexander Belyaev 771f711785eSAlexander Belyaev LogicalResult 772f711785eSAlexander Belyaev matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor, 773f711785eSAlexander Belyaev ConversionPatternRewriter &rewriter) const override { 774ff9bc3a0SJohannes Reifferscheid ImplicitLocOpBuilder b(op.getLoc(), rewriter); 775f711785eSAlexander Belyaev 7765550c821STres Popp auto type = cast<ComplexType>(op.getType()); 777fac349a1SChristian Sigg auto elementType = cast<FloatType>(type.getElementType()); 778ff9bc3a0SJohannes Reifferscheid arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue(); 779f711785eSAlexander Belyaev 780ff9bc3a0SJohannes Reifferscheid auto cst = [&](APFloat v) { 781ff9bc3a0SJohannes Reifferscheid return b.create<arith::ConstantOp>(elementType, 782ff9bc3a0SJohannes Reifferscheid b.getFloatAttr(elementType, v)); 783ff9bc3a0SJohannes Reifferscheid }; 784ff9bc3a0SJohannes Reifferscheid const auto &floatSemantics = elementType.getFloatSemantics(); 785ff9bc3a0SJohannes Reifferscheid Value zero = cst(APFloat::getZero(floatSemantics)); 786ff9bc3a0SJohannes Reifferscheid Value half = b.create<arith::ConstantOp>(elementType, 787ff9bc3a0SJohannes Reifferscheid b.getFloatAttr(elementType, 0.5)); 788f711785eSAlexander Belyaev 789f711785eSAlexander Belyaev Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex()); 790f711785eSAlexander Belyaev Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex()); 79133e60f35SJohannes Reifferscheid Value absSqrt = computeAbs(real, imag, fmf, b, AbsFn::sqrt); 792ff9bc3a0SJohannes Reifferscheid Value argArg = b.create<math::Atan2Op>(imag, real, fmf); 793ff9bc3a0SJohannes Reifferscheid Value sqrtArg = b.create<arith::MulFOp>(argArg, half, fmf); 794ff9bc3a0SJohannes Reifferscheid Value cos = b.create<math::CosOp>(sqrtArg, fmf); 795ff9bc3a0SJohannes Reifferscheid Value sin = b.create<math::SinOp>(sqrtArg, fmf); 796ff9bc3a0SJohannes Reifferscheid // sin(atan2(0, inf)) = 0, sqrt(abs(inf)) = inf, but we can't multiply 797ff9bc3a0SJohannes Reifferscheid // 0 * inf. 798ff9bc3a0SJohannes Reifferscheid Value sinIsZero = 799ff9bc3a0SJohannes Reifferscheid b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, sin, zero, fmf); 800f711785eSAlexander Belyaev 801ff9bc3a0SJohannes Reifferscheid Value resultReal = b.create<arith::MulFOp>(absSqrt, cos, fmf); 802f711785eSAlexander Belyaev Value resultImag = b.create<arith::SelectOp>( 803ff9bc3a0SJohannes Reifferscheid sinIsZero, zero, b.create<arith::MulFOp>(absSqrt, sin, fmf)); 804ff9bc3a0SJohannes Reifferscheid if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan | 805ff9bc3a0SJohannes Reifferscheid arith::FastMathFlags::ninf)) { 806ff9bc3a0SJohannes Reifferscheid Value inf = cst(APFloat::getInf(floatSemantics)); 807ff9bc3a0SJohannes Reifferscheid Value negInf = cst(APFloat::getInf(floatSemantics, true)); 808ff9bc3a0SJohannes Reifferscheid Value nan = cst(APFloat::getNaN(floatSemantics)); 809ff9bc3a0SJohannes Reifferscheid Value absImag = b.create<math::AbsFOp>(elementType, imag, fmf); 810ff9bc3a0SJohannes Reifferscheid 811ff9bc3a0SJohannes Reifferscheid Value absImagIsInf = 812ff9bc3a0SJohannes Reifferscheid b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf); 813ff9bc3a0SJohannes Reifferscheid Value absImagIsNotInf = 814ff9bc3a0SJohannes Reifferscheid b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, absImag, inf, fmf); 815ff9bc3a0SJohannes Reifferscheid Value realIsInf = 816ff9bc3a0SJohannes Reifferscheid b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, inf, fmf); 817ff9bc3a0SJohannes Reifferscheid Value realIsNegInf = 818ff9bc3a0SJohannes Reifferscheid b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, negInf, fmf); 819f711785eSAlexander Belyaev 820f711785eSAlexander Belyaev resultReal = b.create<arith::SelectOp>( 821ff9bc3a0SJohannes Reifferscheid b.create<arith::AndIOp>(realIsNegInf, absImagIsNotInf), zero, 822f711785eSAlexander Belyaev resultReal); 823ff9bc3a0SJohannes Reifferscheid resultReal = b.create<arith::SelectOp>( 824ff9bc3a0SJohannes Reifferscheid b.create<arith::OrIOp>(absImagIsInf, realIsInf), inf, resultReal); 825f711785eSAlexander Belyaev 826ff9bc3a0SJohannes Reifferscheid Value imagSignInf = b.create<math::CopySignOp>(inf, imag, fmf); 827ff9bc3a0SJohannes Reifferscheid resultImag = b.create<arith::SelectOp>( 828ff9bc3a0SJohannes Reifferscheid b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, absSqrt, absSqrt), 829ff9bc3a0SJohannes Reifferscheid nan, resultImag); 830ff9bc3a0SJohannes Reifferscheid resultImag = b.create<arith::SelectOp>( 831ff9bc3a0SJohannes Reifferscheid b.create<arith::OrIOp>(absImagIsInf, realIsNegInf), imagSignInf, 832ff9bc3a0SJohannes Reifferscheid resultImag); 833ff9bc3a0SJohannes Reifferscheid } 834f711785eSAlexander Belyaev 835ff9bc3a0SJohannes Reifferscheid Value resultIsZero = 836ff9bc3a0SJohannes Reifferscheid b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absSqrt, zero, fmf); 837ff9bc3a0SJohannes Reifferscheid resultReal = b.create<arith::SelectOp>(resultIsZero, zero, resultReal); 838ff9bc3a0SJohannes Reifferscheid resultImag = b.create<arith::SelectOp>(resultIsZero, zero, resultImag); 839f711785eSAlexander Belyaev 840f711785eSAlexander Belyaev rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, 841f711785eSAlexander Belyaev resultImag); 842f711785eSAlexander Belyaev return success(); 843f711785eSAlexander Belyaev } 844f711785eSAlexander Belyaev }; 845f711785eSAlexander Belyaev 846f112bd61SAdrian Kuegel struct SignOpConversion : public OpConversionPattern<complex::SignOp> { 847f112bd61SAdrian Kuegel using OpConversionPattern<complex::SignOp>::OpConversionPattern; 848f112bd61SAdrian Kuegel 849f112bd61SAdrian Kuegel LogicalResult 850b54c724bSRiver Riddle matchAndRewrite(complex::SignOp op, OpAdaptor adaptor, 851f112bd61SAdrian Kuegel ConversionPatternRewriter &rewriter) const override { 8525550c821STres Popp auto type = cast<ComplexType>(adaptor.getComplex().getType()); 8535550c821STres Popp auto elementType = cast<FloatType>(type.getElementType()); 854f112bd61SAdrian Kuegel mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 855a522dbbdSKai Sasaki arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); 856f112bd61SAdrian Kuegel 857c0342a2dSJacques Pienaar Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex()); 858c0342a2dSJacques Pienaar Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex()); 859a54f4eaeSMogball Value zero = 860a54f4eaeSMogball b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType)); 861a54f4eaeSMogball Value realIsZero = 862a54f4eaeSMogball b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero); 863a54f4eaeSMogball Value imagIsZero = 864a54f4eaeSMogball b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero); 865a54f4eaeSMogball Value isZero = b.create<arith::AndIOp>(realIsZero, imagIsZero); 866a522dbbdSKai Sasaki auto abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex(), fmf); 867a522dbbdSKai Sasaki Value realSign = b.create<arith::DivFOp>(real, abs, fmf); 868a522dbbdSKai Sasaki Value imagSign = b.create<arith::DivFOp>(imag, abs, fmf); 869f112bd61SAdrian Kuegel Value sign = b.create<complex::CreateOp>(type, realSign, imagSign); 870dec8af70SRiver Riddle rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isZero, 871dec8af70SRiver Riddle adaptor.getComplex(), sign); 872f112bd61SAdrian Kuegel return success(); 873f112bd61SAdrian Kuegel } 874f112bd61SAdrian Kuegel }; 8756d75c897Slewuathe 876f43deca2SJohannes Reifferscheid template <typename Op> 877f43deca2SJohannes Reifferscheid struct TanTanhOpConversion : public OpConversionPattern<Op> { 878f43deca2SJohannes Reifferscheid using OpConversionPattern<Op>::OpConversionPattern; 8796d75c897Slewuathe 8806d75c897Slewuathe LogicalResult 881e0a293d1SKazu Hirata matchAndRewrite(Op op, typename Op::Adaptor adaptor, 882ffb8eecdSlewuathe ConversionPatternRewriter &rewriter) const override { 8839da0ef16SJohannes Reifferscheid ImplicitLocOpBuilder b(op.getLoc(), rewriter); 884ffb8eecdSlewuathe auto loc = op.getLoc(); 8855550c821STres Popp auto type = cast<ComplexType>(adaptor.getComplex().getType()); 8865550c821STres Popp auto elementType = cast<FloatType>(type.getElementType()); 8879da0ef16SJohannes Reifferscheid arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue(); 8889da0ef16SJohannes Reifferscheid const auto &floatSemantics = elementType.getFloatSemantics(); 889ffb8eecdSlewuathe 890ffb8eecdSlewuathe Value real = 8919da0ef16SJohannes Reifferscheid b.create<complex::ReOp>(loc, elementType, adaptor.getComplex()); 892ffb8eecdSlewuathe Value imag = 8939da0ef16SJohannes Reifferscheid b.create<complex::ImOp>(loc, elementType, adaptor.getComplex()); 894f43deca2SJohannes Reifferscheid Value negOne = b.create<arith::ConstantOp>( 895f43deca2SJohannes Reifferscheid elementType, b.getFloatAttr(elementType, -1.0)); 896f43deca2SJohannes Reifferscheid 897f43deca2SJohannes Reifferscheid if constexpr (std::is_same_v<Op, complex::TanOp>) { 898f43deca2SJohannes Reifferscheid // tan(x+yi) = -i*tanh(-y + xi) 899f43deca2SJohannes Reifferscheid std::swap(real, imag); 900f43deca2SJohannes Reifferscheid real = b.create<arith::MulFOp>(real, negOne, fmf); 901f43deca2SJohannes Reifferscheid } 9029da0ef16SJohannes Reifferscheid 9039da0ef16SJohannes Reifferscheid auto cst = [&](APFloat v) { 9049da0ef16SJohannes Reifferscheid return b.create<arith::ConstantOp>(elementType, 9059da0ef16SJohannes Reifferscheid b.getFloatAttr(elementType, v)); 9069da0ef16SJohannes Reifferscheid }; 9079da0ef16SJohannes Reifferscheid Value inf = cst(APFloat::getInf(floatSemantics)); 9089da0ef16SJohannes Reifferscheid Value four = b.create<arith::ConstantOp>(elementType, 9099da0ef16SJohannes Reifferscheid b.getFloatAttr(elementType, 4.0)); 9109da0ef16SJohannes Reifferscheid Value twoReal = b.create<arith::AddFOp>(real, real, fmf); 9119da0ef16SJohannes Reifferscheid Value negTwoReal = b.create<arith::MulFOp>(negOne, twoReal, fmf); 9129da0ef16SJohannes Reifferscheid 9139da0ef16SJohannes Reifferscheid Value expTwoRealMinusOne = b.create<math::ExpM1Op>(twoReal, fmf); 9149da0ef16SJohannes Reifferscheid Value expNegTwoRealMinusOne = b.create<math::ExpM1Op>(negTwoReal, fmf); 9159da0ef16SJohannes Reifferscheid Value realNum = 9169da0ef16SJohannes Reifferscheid b.create<arith::SubFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf); 9179da0ef16SJohannes Reifferscheid 9189da0ef16SJohannes Reifferscheid Value cosImag = b.create<math::CosOp>(imag, fmf); 9199da0ef16SJohannes Reifferscheid Value cosImagSq = b.create<arith::MulFOp>(cosImag, cosImag, fmf); 9209da0ef16SJohannes Reifferscheid Value twoCosTwoImagPlusOne = b.create<arith::MulFOp>(cosImagSq, four, fmf); 9219da0ef16SJohannes Reifferscheid Value sinImag = b.create<math::SinOp>(imag, fmf); 9229da0ef16SJohannes Reifferscheid 9239da0ef16SJohannes Reifferscheid Value imagNum = b.create<arith::MulFOp>( 9249da0ef16SJohannes Reifferscheid four, b.create<arith::MulFOp>(cosImag, sinImag, fmf), fmf); 9259da0ef16SJohannes Reifferscheid 9269da0ef16SJohannes Reifferscheid Value expSumMinusTwo = 9279da0ef16SJohannes Reifferscheid b.create<arith::AddFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf); 9289da0ef16SJohannes Reifferscheid Value denom = 9299da0ef16SJohannes Reifferscheid b.create<arith::AddFOp>(expSumMinusTwo, twoCosTwoImagPlusOne, fmf); 9309da0ef16SJohannes Reifferscheid 9319da0ef16SJohannes Reifferscheid Value isInf = b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, 9329da0ef16SJohannes Reifferscheid expSumMinusTwo, inf, fmf); 9339da0ef16SJohannes Reifferscheid Value realLimit = b.create<math::CopySignOp>(negOne, real, fmf); 9349da0ef16SJohannes Reifferscheid 9359da0ef16SJohannes Reifferscheid Value resultReal = b.create<arith::SelectOp>( 9369da0ef16SJohannes Reifferscheid isInf, realLimit, b.create<arith::DivFOp>(realNum, denom, fmf)); 9379da0ef16SJohannes Reifferscheid Value resultImag = b.create<arith::DivFOp>(imagNum, denom, fmf); 9389da0ef16SJohannes Reifferscheid 9399da0ef16SJohannes Reifferscheid if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan | 9409da0ef16SJohannes Reifferscheid arith::FastMathFlags::ninf)) { 9419da0ef16SJohannes Reifferscheid Value absReal = b.create<math::AbsFOp>(real, fmf); 9429da0ef16SJohannes Reifferscheid Value zero = b.create<arith::ConstantOp>( 9439da0ef16SJohannes Reifferscheid elementType, b.getFloatAttr(elementType, 0.0)); 9449da0ef16SJohannes Reifferscheid Value nan = cst(APFloat::getNaN(floatSemantics)); 9459da0ef16SJohannes Reifferscheid 9469da0ef16SJohannes Reifferscheid Value absRealIsInf = 9479da0ef16SJohannes Reifferscheid b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf); 9489da0ef16SJohannes Reifferscheid Value imagIsZero = 9499da0ef16SJohannes Reifferscheid b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf); 9509da0ef16SJohannes Reifferscheid Value absRealIsNotInf = b.create<arith::XOrIOp>( 9519da0ef16SJohannes Reifferscheid absRealIsInf, b.create<arith::ConstantIntOp>(true, /*width=*/1)); 9529da0ef16SJohannes Reifferscheid 9539da0ef16SJohannes Reifferscheid Value imagNumIsNaN = b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, 9549da0ef16SJohannes Reifferscheid imagNum, imagNum, fmf); 9559da0ef16SJohannes Reifferscheid Value resultRealIsNaN = 9569da0ef16SJohannes Reifferscheid b.create<arith::AndIOp>(imagNumIsNaN, absRealIsNotInf); 9579da0ef16SJohannes Reifferscheid Value resultImagIsZero = b.create<arith::OrIOp>( 9589da0ef16SJohannes Reifferscheid imagIsZero, b.create<arith::AndIOp>(absRealIsInf, imagNumIsNaN)); 9599da0ef16SJohannes Reifferscheid 9609da0ef16SJohannes Reifferscheid resultReal = b.create<arith::SelectOp>(resultRealIsNaN, nan, resultReal); 9619da0ef16SJohannes Reifferscheid resultImag = 9629da0ef16SJohannes Reifferscheid b.create<arith::SelectOp>(resultImagIsZero, zero, resultImag); 9639da0ef16SJohannes Reifferscheid } 9649da0ef16SJohannes Reifferscheid 965f43deca2SJohannes Reifferscheid if constexpr (std::is_same_v<Op, complex::TanOp>) { 966f43deca2SJohannes Reifferscheid // tan(x+yi) = -i*tanh(-y + xi) 967f43deca2SJohannes Reifferscheid std::swap(resultReal, resultImag); 968f43deca2SJohannes Reifferscheid resultImag = b.create<arith::MulFOp>(resultImag, negOne, fmf); 969f43deca2SJohannes Reifferscheid } 970f43deca2SJohannes Reifferscheid 9719da0ef16SJohannes Reifferscheid rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, 9729da0ef16SJohannes Reifferscheid resultImag); 973ffb8eecdSlewuathe return success(); 974ffb8eecdSlewuathe } 975ffb8eecdSlewuathe }; 976ffb8eecdSlewuathe 97762a34f6aSlewuathe struct ConjOpConversion : public OpConversionPattern<complex::ConjOp> { 97862a34f6aSlewuathe using OpConversionPattern<complex::ConjOp>::OpConversionPattern; 97962a34f6aSlewuathe 98062a34f6aSlewuathe LogicalResult 98162a34f6aSlewuathe matchAndRewrite(complex::ConjOp op, OpAdaptor adaptor, 98262a34f6aSlewuathe ConversionPatternRewriter &rewriter) const override { 98362a34f6aSlewuathe auto loc = op.getLoc(); 9845550c821STres Popp auto type = cast<ComplexType>(adaptor.getComplex().getType()); 9855550c821STres Popp auto elementType = cast<FloatType>(type.getElementType()); 98662a34f6aSlewuathe Value real = 98762a34f6aSlewuathe rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex()); 98862a34f6aSlewuathe Value imag = 98962a34f6aSlewuathe rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex()); 99062a34f6aSlewuathe Value negImag = rewriter.create<arith::NegFOp>(loc, elementType, imag); 99162a34f6aSlewuathe 99262a34f6aSlewuathe rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, negImag); 99362a34f6aSlewuathe 99462a34f6aSlewuathe return success(); 99562a34f6aSlewuathe } 99662a34f6aSlewuathe }; 99762a34f6aSlewuathe 99877dd4357SJohannes Reifferscheid /// Converts lhs^y = (a+bi)^(c+di) to 9996c6eddb6Sbixia1 /// (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)), 10006c6eddb6Sbixia1 /// where q = c*atan2(b,a)+0.5d*ln(a*a+b*b) 10016c6eddb6Sbixia1 static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder, 100277dd4357SJohannes Reifferscheid ComplexType type, Value lhs, Value c, Value d, 100377dd4357SJohannes Reifferscheid arith::FastMathFlags fmf) { 10045550c821STres Popp auto elementType = cast<FloatType>(type.getElementType()); 10056c6eddb6Sbixia1 100677dd4357SJohannes Reifferscheid Value a = builder.create<complex::ReOp>(lhs); 100777dd4357SJohannes Reifferscheid Value b = builder.create<complex::ImOp>(lhs); 10086c6eddb6Sbixia1 100977dd4357SJohannes Reifferscheid Value abs = builder.create<complex::AbsOp>(lhs, fmf); 101077dd4357SJohannes Reifferscheid Value absToC = builder.create<math::PowFOp>(abs, c, fmf); 10116c6eddb6Sbixia1 101277dd4357SJohannes Reifferscheid Value negD = builder.create<arith::NegFOp>(d, fmf); 101377dd4357SJohannes Reifferscheid Value argLhs = builder.create<math::Atan2Op>(b, a, fmf); 101477dd4357SJohannes Reifferscheid Value negDArgLhs = builder.create<arith::MulFOp>(negD, argLhs, fmf); 101577dd4357SJohannes Reifferscheid Value expNegDArgLhs = builder.create<math::ExpOp>(negDArgLhs, fmf); 10166c6eddb6Sbixia1 101777dd4357SJohannes Reifferscheid Value coeff = builder.create<arith::MulFOp>(absToC, expNegDArgLhs, fmf); 101877dd4357SJohannes Reifferscheid Value lnAbs = builder.create<math::LogOp>(abs, fmf); 101977dd4357SJohannes Reifferscheid Value cArgLhs = builder.create<arith::MulFOp>(c, argLhs, fmf); 102077dd4357SJohannes Reifferscheid Value dLnAbs = builder.create<arith::MulFOp>(d, lnAbs, fmf); 102177dd4357SJohannes Reifferscheid Value q = builder.create<arith::AddFOp>(cArgLhs, dLnAbs, fmf); 102277dd4357SJohannes Reifferscheid Value cosQ = builder.create<math::CosOp>(q, fmf); 102377dd4357SJohannes Reifferscheid Value sinQ = builder.create<math::SinOp>(q, fmf); 10246c6eddb6Sbixia1 102577dd4357SJohannes Reifferscheid Value inf = builder.create<arith::ConstantOp>( 102677dd4357SJohannes Reifferscheid elementType, 102777dd4357SJohannes Reifferscheid builder.getFloatAttr(elementType, 102877dd4357SJohannes Reifferscheid APFloat::getInf(elementType.getFloatSemantics()))); 10296c6eddb6Sbixia1 Value zero = builder.create<arith::ConstantOp>( 103077dd4357SJohannes Reifferscheid elementType, builder.getFloatAttr(elementType, 0.0)); 10316c6eddb6Sbixia1 Value one = builder.create<arith::ConstantOp>( 103277dd4357SJohannes Reifferscheid elementType, builder.getFloatAttr(elementType, 1.0)); 10336c6eddb6Sbixia1 Value complexOne = builder.create<complex::CreateOp>(type, one, zero); 103477dd4357SJohannes Reifferscheid Value complexZero = builder.create<complex::CreateOp>(type, zero, zero); 103577dd4357SJohannes Reifferscheid Value complexInf = builder.create<complex::CreateOp>(type, inf, zero); 10366c6eddb6Sbixia1 103777dd4357SJohannes Reifferscheid // Case 0: 103877dd4357SJohannes Reifferscheid // d^c is 0 if d is 0 and c > 0. 0^0 is defined to be 1.0, see 10396c6eddb6Sbixia1 // Branch Cuts for Complex Elementary Functions or Much Ado About 10406c6eddb6Sbixia1 // Nothing's Sign Bit, W. Kahan, Section 10. 104177dd4357SJohannes Reifferscheid Value absEqZero = 104277dd4357SJohannes Reifferscheid builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, abs, zero, fmf); 104377dd4357SJohannes Reifferscheid Value dEqZero = 104477dd4357SJohannes Reifferscheid builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, d, zero, fmf); 104577dd4357SJohannes Reifferscheid Value cEqZero = 104677dd4357SJohannes Reifferscheid builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, c, zero, fmf); 104777dd4357SJohannes Reifferscheid Value bEqZero = 104877dd4357SJohannes Reifferscheid builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, b, zero, fmf); 104977dd4357SJohannes Reifferscheid 105077dd4357SJohannes Reifferscheid Value zeroLeC = 105177dd4357SJohannes Reifferscheid builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLE, zero, c, fmf); 105277dd4357SJohannes Reifferscheid Value coeffCosQ = builder.create<arith::MulFOp>(coeff, cosQ, fmf); 105377dd4357SJohannes Reifferscheid Value coeffSinQ = builder.create<arith::MulFOp>(coeff, sinQ, fmf); 105477dd4357SJohannes Reifferscheid Value complexOneOrZero = 105577dd4357SJohannes Reifferscheid builder.create<arith::SelectOp>(cEqZero, complexOne, complexZero); 105677dd4357SJohannes Reifferscheid Value coeffCosSin = 105777dd4357SJohannes Reifferscheid builder.create<complex::CreateOp>(type, coeffCosQ, coeffSinQ); 105877dd4357SJohannes Reifferscheid Value cutoff0 = builder.create<arith::SelectOp>( 105977dd4357SJohannes Reifferscheid builder.create<arith::AndIOp>( 106077dd4357SJohannes Reifferscheid builder.create<arith::AndIOp>(absEqZero, dEqZero), zeroLeC), 106177dd4357SJohannes Reifferscheid complexOneOrZero, coeffCosSin); 106277dd4357SJohannes Reifferscheid 106377dd4357SJohannes Reifferscheid // Case 1: 106477dd4357SJohannes Reifferscheid // x^0 is defined to be 1 for any x, see 106577dd4357SJohannes Reifferscheid // Branch Cuts for Complex Elementary Functions or Much Ado About 106677dd4357SJohannes Reifferscheid // Nothing's Sign Bit, W. Kahan, Section 10. 106777dd4357SJohannes Reifferscheid Value rhsEqZero = builder.create<arith::AndIOp>(cEqZero, dEqZero); 106877dd4357SJohannes Reifferscheid Value cutoff1 = 106977dd4357SJohannes Reifferscheid builder.create<arith::SelectOp>(rhsEqZero, complexOne, cutoff0); 107077dd4357SJohannes Reifferscheid 107177dd4357SJohannes Reifferscheid // Case 2: 107277dd4357SJohannes Reifferscheid // 1^(c + d*i) = 1 + 0*i 107377dd4357SJohannes Reifferscheid Value lhsEqOne = builder.create<arith::AndIOp>( 1074ff9bc3a0SJohannes Reifferscheid builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, one, fmf), 107577dd4357SJohannes Reifferscheid bEqZero); 107677dd4357SJohannes Reifferscheid Value cutoff2 = 107777dd4357SJohannes Reifferscheid builder.create<arith::SelectOp>(lhsEqOne, complexOne, cutoff1); 107877dd4357SJohannes Reifferscheid 107977dd4357SJohannes Reifferscheid // Case 3: 108077dd4357SJohannes Reifferscheid // inf^(c + 0*i) = inf + 0*i, c > 0 108177dd4357SJohannes Reifferscheid Value lhsEqInf = builder.create<arith::AndIOp>( 1082ff9bc3a0SJohannes Reifferscheid builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, inf, fmf), 108377dd4357SJohannes Reifferscheid bEqZero); 108477dd4357SJohannes Reifferscheid Value rhsGt0 = builder.create<arith::AndIOp>( 108577dd4357SJohannes Reifferscheid dEqZero, 1086ff9bc3a0SJohannes Reifferscheid builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, c, zero, fmf)); 108777dd4357SJohannes Reifferscheid Value cutoff3 = builder.create<arith::SelectOp>( 108877dd4357SJohannes Reifferscheid builder.create<arith::AndIOp>(lhsEqInf, rhsGt0), complexInf, cutoff2); 108977dd4357SJohannes Reifferscheid 109077dd4357SJohannes Reifferscheid // Case 4: 109177dd4357SJohannes Reifferscheid // inf^(c + 0*i) = 0 + 0*i, c < 0 109277dd4357SJohannes Reifferscheid Value rhsLt0 = builder.create<arith::AndIOp>( 109377dd4357SJohannes Reifferscheid dEqZero, 1094ff9bc3a0SJohannes Reifferscheid builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, c, zero, fmf)); 109577dd4357SJohannes Reifferscheid Value cutoff4 = builder.create<arith::SelectOp>( 109677dd4357SJohannes Reifferscheid builder.create<arith::AndIOp>(lhsEqInf, rhsLt0), complexZero, cutoff3); 109777dd4357SJohannes Reifferscheid 109877dd4357SJohannes Reifferscheid return cutoff4; 10996c6eddb6Sbixia1 } 11006c6eddb6Sbixia1 11016c6eddb6Sbixia1 struct PowOpConversion : public OpConversionPattern<complex::PowOp> { 11026c6eddb6Sbixia1 using OpConversionPattern<complex::PowOp>::OpConversionPattern; 11036c6eddb6Sbixia1 11046c6eddb6Sbixia1 LogicalResult 11056c6eddb6Sbixia1 matchAndRewrite(complex::PowOp op, OpAdaptor adaptor, 11066c6eddb6Sbixia1 ConversionPatternRewriter &rewriter) const override { 11076c6eddb6Sbixia1 mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter); 11085550c821STres Popp auto type = cast<ComplexType>(adaptor.getLhs().getType()); 11095550c821STres Popp auto elementType = cast<FloatType>(type.getElementType()); 11106c6eddb6Sbixia1 11116c6eddb6Sbixia1 Value c = builder.create<complex::ReOp>(elementType, adaptor.getRhs()); 11126c6eddb6Sbixia1 Value d = builder.create<complex::ImOp>(elementType, adaptor.getRhs()); 11136c6eddb6Sbixia1 111477dd4357SJohannes Reifferscheid rewriter.replaceOp(op, {powOpConversionImpl(builder, type, adaptor.getLhs(), 111577dd4357SJohannes Reifferscheid c, d, op.getFastmath())}); 11166c6eddb6Sbixia1 return success(); 11176c6eddb6Sbixia1 } 11186c6eddb6Sbixia1 }; 11196c6eddb6Sbixia1 11206c6eddb6Sbixia1 struct RsqrtOpConversion : public OpConversionPattern<complex::RsqrtOp> { 11216c6eddb6Sbixia1 using OpConversionPattern<complex::RsqrtOp>::OpConversionPattern; 11226c6eddb6Sbixia1 11236c6eddb6Sbixia1 LogicalResult 11246c6eddb6Sbixia1 matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor, 11256c6eddb6Sbixia1 ConversionPatternRewriter &rewriter) const override { 112633e60f35SJohannes Reifferscheid mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 11275550c821STres Popp auto type = cast<ComplexType>(adaptor.getComplex().getType()); 11285550c821STres Popp auto elementType = cast<FloatType>(type.getElementType()); 11296c6eddb6Sbixia1 113033e60f35SJohannes Reifferscheid arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue(); 11316c6eddb6Sbixia1 113233e60f35SJohannes Reifferscheid auto cst = [&](APFloat v) { 113333e60f35SJohannes Reifferscheid return b.create<arith::ConstantOp>(elementType, 113433e60f35SJohannes Reifferscheid b.getFloatAttr(elementType, v)); 113533e60f35SJohannes Reifferscheid }; 113633e60f35SJohannes Reifferscheid const auto &floatSemantics = elementType.getFloatSemantics(); 113733e60f35SJohannes Reifferscheid Value zero = cst(APFloat::getZero(floatSemantics)); 113833e60f35SJohannes Reifferscheid Value inf = cst(APFloat::getInf(floatSemantics)); 113933e60f35SJohannes Reifferscheid Value negHalf = b.create<arith::ConstantOp>( 114033e60f35SJohannes Reifferscheid elementType, b.getFloatAttr(elementType, -0.5)); 114133e60f35SJohannes Reifferscheid Value nan = cst(APFloat::getNaN(floatSemantics)); 114233e60f35SJohannes Reifferscheid 114333e60f35SJohannes Reifferscheid Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex()); 114433e60f35SJohannes Reifferscheid Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex()); 114533e60f35SJohannes Reifferscheid Value absRsqrt = computeAbs(real, imag, fmf, b, AbsFn::rsqrt); 114633e60f35SJohannes Reifferscheid Value argArg = b.create<math::Atan2Op>(imag, real, fmf); 114733e60f35SJohannes Reifferscheid Value rsqrtArg = b.create<arith::MulFOp>(argArg, negHalf, fmf); 114833e60f35SJohannes Reifferscheid Value cos = b.create<math::CosOp>(rsqrtArg, fmf); 114933e60f35SJohannes Reifferscheid Value sin = b.create<math::SinOp>(rsqrtArg, fmf); 115033e60f35SJohannes Reifferscheid 115133e60f35SJohannes Reifferscheid Value resultReal = b.create<arith::MulFOp>(absRsqrt, cos, fmf); 115233e60f35SJohannes Reifferscheid Value resultImag = b.create<arith::MulFOp>(absRsqrt, sin, fmf); 115333e60f35SJohannes Reifferscheid 115433e60f35SJohannes Reifferscheid if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan | 115533e60f35SJohannes Reifferscheid arith::FastMathFlags::ninf)) { 115633e60f35SJohannes Reifferscheid Value negOne = b.create<arith::ConstantOp>( 115733e60f35SJohannes Reifferscheid elementType, b.getFloatAttr(elementType, -1)); 115833e60f35SJohannes Reifferscheid 115933e60f35SJohannes Reifferscheid Value realSignedZero = b.create<math::CopySignOp>(zero, real, fmf); 116033e60f35SJohannes Reifferscheid Value imagSignedZero = b.create<math::CopySignOp>(zero, imag, fmf); 116133e60f35SJohannes Reifferscheid Value negImagSignedZero = 116233e60f35SJohannes Reifferscheid b.create<arith::MulFOp>(negOne, imagSignedZero, fmf); 116333e60f35SJohannes Reifferscheid 116433e60f35SJohannes Reifferscheid Value absReal = b.create<math::AbsFOp>(real, fmf); 116533e60f35SJohannes Reifferscheid Value absImag = b.create<math::AbsFOp>(imag, fmf); 116633e60f35SJohannes Reifferscheid 116733e60f35SJohannes Reifferscheid Value absImagIsInf = 116833e60f35SJohannes Reifferscheid b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf); 116933e60f35SJohannes Reifferscheid Value realIsNan = 117033e60f35SJohannes Reifferscheid b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real, fmf); 117133e60f35SJohannes Reifferscheid Value realIsInf = 117233e60f35SJohannes Reifferscheid b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf); 117333e60f35SJohannes Reifferscheid Value inIsNanInf = b.create<arith::AndIOp>(absImagIsInf, realIsNan); 117433e60f35SJohannes Reifferscheid 117533e60f35SJohannes Reifferscheid Value resultIsZero = b.create<arith::OrIOp>(inIsNanInf, realIsInf); 117633e60f35SJohannes Reifferscheid 117733e60f35SJohannes Reifferscheid resultReal = 117833e60f35SJohannes Reifferscheid b.create<arith::SelectOp>(resultIsZero, realSignedZero, resultReal); 117933e60f35SJohannes Reifferscheid resultImag = b.create<arith::SelectOp>(resultIsZero, negImagSignedZero, 118033e60f35SJohannes Reifferscheid resultImag); 118133e60f35SJohannes Reifferscheid } 118233e60f35SJohannes Reifferscheid 118333e60f35SJohannes Reifferscheid Value isRealZero = 118433e60f35SJohannes Reifferscheid b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero, fmf); 118533e60f35SJohannes Reifferscheid Value isImagZero = 118633e60f35SJohannes Reifferscheid b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf); 118733e60f35SJohannes Reifferscheid Value isZero = b.create<arith::AndIOp>(isRealZero, isImagZero); 118833e60f35SJohannes Reifferscheid 118933e60f35SJohannes Reifferscheid resultReal = b.create<arith::SelectOp>(isZero, inf, resultReal); 119033e60f35SJohannes Reifferscheid resultImag = b.create<arith::SelectOp>(isZero, nan, resultImag); 119133e60f35SJohannes Reifferscheid 119233e60f35SJohannes Reifferscheid rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, 119333e60f35SJohannes Reifferscheid resultImag); 11946c6eddb6Sbixia1 return success(); 11956c6eddb6Sbixia1 } 11966c6eddb6Sbixia1 }; 11976c6eddb6Sbixia1 11988fa2e679SLewuathe struct AngleOpConversion : public OpConversionPattern<complex::AngleOp> { 11998fa2e679SLewuathe using OpConversionPattern<complex::AngleOp>::OpConversionPattern; 12008fa2e679SLewuathe 12018fa2e679SLewuathe LogicalResult 12028fa2e679SLewuathe matchAndRewrite(complex::AngleOp op, OpAdaptor adaptor, 12038fa2e679SLewuathe ConversionPatternRewriter &rewriter) const override { 12048fa2e679SLewuathe auto loc = op.getLoc(); 12058fa2e679SLewuathe auto type = op.getType(); 12068c9d814bSKai Sasaki arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); 12078fa2e679SLewuathe 12088fa2e679SLewuathe Value real = 12098fa2e679SLewuathe rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex()); 12108fa2e679SLewuathe Value imag = 12118fa2e679SLewuathe rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex()); 12128fa2e679SLewuathe 12138c9d814bSKai Sasaki rewriter.replaceOpWithNewOp<math::Atan2Op>(op, imag, real, fmf); 12148fa2e679SLewuathe 12158fa2e679SLewuathe return success(); 12168fa2e679SLewuathe } 12178fa2e679SLewuathe }; 12188fa2e679SLewuathe 12192ea7fb7bSAdrian Kuegel } // namespace 12202ea7fb7bSAdrian Kuegel 12212ea7fb7bSAdrian Kuegel void mlir::populateComplexToStandardConversionPatterns( 12222ea7fb7bSAdrian Kuegel RewritePatternSet &patterns) { 1223f112bd61SAdrian Kuegel // clang-format off 1224f112bd61SAdrian Kuegel patterns.add< 1225f112bd61SAdrian Kuegel AbsOpConversion, 12268fa2e679SLewuathe AngleOpConversion, 1227f711785eSAlexander Belyaev Atan2OpConversion, 1228a54f4eaeSMogball BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>, 1229a54f4eaeSMogball BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>, 123062a34f6aSlewuathe ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>, 123162a34f6aSlewuathe ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>, 123262a34f6aSlewuathe ConjOpConversion, 1233672b908bSGoran Flegar CosOpConversion, 1234f112bd61SAdrian Kuegel DivOpConversion, 1235f112bd61SAdrian Kuegel ExpOpConversion, 1236338e76f8Sbixia1 Expm1OpConversion, 12376e80e3bdSAdrian Kuegel Log1pOpConversion, 123862a34f6aSlewuathe LogOpConversion, 1239bf17ee19SAdrian Kuegel MulOpConversion, 1240f112bd61SAdrian Kuegel NegOpConversion, 1241672b908bSGoran Flegar SignOpConversion, 12426d75c897Slewuathe SinOpConversion, 1243f711785eSAlexander Belyaev SqrtOpConversion, 1244f43deca2SJohannes Reifferscheid TanTanhOpConversion<complex::TanOp>, 1245f43deca2SJohannes Reifferscheid TanTanhOpConversion<complex::TanhOp>, 12466c6eddb6Sbixia1 PowOpConversion, 12476c6eddb6Sbixia1 RsqrtOpConversion 124862a34f6aSlewuathe >(patterns.getContext()); 1249f112bd61SAdrian Kuegel // clang-format on 12502ea7fb7bSAdrian Kuegel } 12512ea7fb7bSAdrian Kuegel 12522ea7fb7bSAdrian Kuegel namespace { 12532ea7fb7bSAdrian Kuegel struct ConvertComplexToStandardPass 125467d0d7acSMichele Scuttari : public impl::ConvertComplexToStandardBase<ConvertComplexToStandardPass> { 125541574554SRiver Riddle void runOnOperation() override; 12562ea7fb7bSAdrian Kuegel }; 12572ea7fb7bSAdrian Kuegel 125841574554SRiver Riddle void ConvertComplexToStandardPass::runOnOperation() { 12592ea7fb7bSAdrian Kuegel // Convert to the Standard dialect using the converter defined above. 12602ea7fb7bSAdrian Kuegel RewritePatternSet patterns(&getContext()); 12612ea7fb7bSAdrian Kuegel populateComplexToStandardConversionPatterns(patterns); 12622ea7fb7bSAdrian Kuegel 12632ea7fb7bSAdrian Kuegel ConversionTarget target(getContext()); 1264abc362a1SJakub Kuderski target.addLegalDialect<arith::ArithDialect, math::MathDialect>(); 1265fb978f09SAdrian Kuegel target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>(); 126647f175b0SRiver Riddle if (failed( 126747f175b0SRiver Riddle applyPartialConversion(getOperation(), target, std::move(patterns)))) 12682ea7fb7bSAdrian Kuegel signalPassFailure(); 12692ea7fb7bSAdrian Kuegel } 12702ea7fb7bSAdrian Kuegel } // namespace 1271039b969bSMichele Scuttari 1272039b969bSMichele Scuttari std::unique_ptr<Pass> mlir::createConvertComplexToStandardPass() { 1273039b969bSMichele Scuttari return std::make_unique<ConvertComplexToStandardPass>(); 1274039b969bSMichele Scuttari } 1275