xref: /llvm-project/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp (revision bdd365825d0766b6991c8f5443f8a9f76e75011a)
12ea7fb7bSAdrian Kuegel //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===//
22ea7fb7bSAdrian Kuegel //
32ea7fb7bSAdrian Kuegel // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
42ea7fb7bSAdrian Kuegel // See https://llvm.org/LICENSE.txt for license information.
52ea7fb7bSAdrian Kuegel // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
62ea7fb7bSAdrian Kuegel //
72ea7fb7bSAdrian Kuegel //===----------------------------------------------------------------------===//
82ea7fb7bSAdrian Kuegel 
92ea7fb7bSAdrian Kuegel #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
102ea7fb7bSAdrian Kuegel 
11abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
122ea7fb7bSAdrian Kuegel #include "mlir/Dialect/Complex/IR/Complex.h"
132ea7fb7bSAdrian Kuegel #include "mlir/Dialect/Math/IR/Math.h"
14f112bd61SAdrian Kuegel #include "mlir/IR/ImplicitLocOpBuilder.h"
152ea7fb7bSAdrian Kuegel #include "mlir/IR/PatternMatch.h"
1667d0d7acSMichele Scuttari #include "mlir/Pass/Pass.h"
172ea7fb7bSAdrian Kuegel #include "mlir/Transforms/DialectConversion.h"
1867d0d7acSMichele Scuttari #include <memory>
1967d0d7acSMichele Scuttari #include <type_traits>
2067d0d7acSMichele Scuttari 
2167d0d7acSMichele Scuttari namespace mlir {
2267d0d7acSMichele Scuttari #define GEN_PASS_DEF_CONVERTCOMPLEXTOSTANDARD
2367d0d7acSMichele Scuttari #include "mlir/Conversion/Passes.h.inc"
2467d0d7acSMichele Scuttari } // namespace mlir
252ea7fb7bSAdrian Kuegel 
262ea7fb7bSAdrian Kuegel using namespace mlir;
272ea7fb7bSAdrian Kuegel 
282ea7fb7bSAdrian Kuegel namespace {
299d9bb7b1SJohannes Reifferscheid 
3033e60f35SJohannes Reifferscheid enum class AbsFn { abs, sqrt, rsqrt };
3133e60f35SJohannes Reifferscheid 
3233e60f35SJohannes Reifferscheid // Returns the absolute value, its square root or its reciprocal square root.
33ff9bc3a0SJohannes Reifferscheid Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf,
3433e60f35SJohannes Reifferscheid                  ImplicitLocOpBuilder &b, AbsFn fn = AbsFn::abs) {
35ff9bc3a0SJohannes Reifferscheid   Value one = b.create<arith::ConstantOp>(real.getType(),
36ff9bc3a0SJohannes Reifferscheid                                           b.getFloatAttr(real.getType(), 1.0));
372ea7fb7bSAdrian Kuegel 
389d9bb7b1SJohannes Reifferscheid   Value absReal = b.create<math::AbsFOp>(real, fmf);
399d9bb7b1SJohannes Reifferscheid   Value absImag = b.create<math::AbsFOp>(imag, fmf);
40b17348c3SKai Sasaki 
419d9bb7b1SJohannes Reifferscheid   Value max = b.create<arith::MaximumFOp>(absReal, absImag, fmf);
429d9bb7b1SJohannes Reifferscheid   Value min = b.create<arith::MinimumFOp>(absReal, absImag, fmf);
439b225d01SJohannes Reifferscheid 
449b225d01SJohannes Reifferscheid   // The lowering below requires NaNs and infinities to work correctly.
459b225d01SJohannes Reifferscheid   arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
469b225d01SJohannes Reifferscheid       fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
479b225d01SJohannes Reifferscheid   Value ratio = b.create<arith::DivFOp>(min, max, fmfWithNaNInf);
489b225d01SJohannes Reifferscheid   Value ratioSq = b.create<arith::MulFOp>(ratio, ratio, fmfWithNaNInf);
499b225d01SJohannes Reifferscheid   Value ratioSqPlusOne = b.create<arith::AddFOp>(ratioSq, one, fmfWithNaNInf);
50ff9bc3a0SJohannes Reifferscheid   Value result;
51ff9bc3a0SJohannes Reifferscheid 
5233e60f35SJohannes Reifferscheid   if (fn == AbsFn::rsqrt) {
539b225d01SJohannes Reifferscheid     ratioSqPlusOne = b.create<math::RsqrtOp>(ratioSqPlusOne, fmfWithNaNInf);
549b225d01SJohannes Reifferscheid     min = b.create<math::RsqrtOp>(min, fmfWithNaNInf);
559b225d01SJohannes Reifferscheid     max = b.create<math::RsqrtOp>(max, fmfWithNaNInf);
5633e60f35SJohannes Reifferscheid   }
5733e60f35SJohannes Reifferscheid 
5833e60f35SJohannes Reifferscheid   if (fn == AbsFn::sqrt) {
59ff9bc3a0SJohannes Reifferscheid     Value quarter = b.create<arith::ConstantOp>(
60ff9bc3a0SJohannes Reifferscheid         real.getType(), b.getFloatAttr(real.getType(), 0.25));
61ff9bc3a0SJohannes Reifferscheid     // sqrt(sqrt(a*b)) would avoid the pow, but will overflow more easily.
629b225d01SJohannes Reifferscheid     Value sqrt = b.create<math::SqrtOp>(max, fmfWithNaNInf);
639b225d01SJohannes Reifferscheid     Value p025 = b.create<math::PowFOp>(ratioSqPlusOne, quarter, fmfWithNaNInf);
649b225d01SJohannes Reifferscheid     result = b.create<arith::MulFOp>(sqrt, p025, fmfWithNaNInf);
65ff9bc3a0SJohannes Reifferscheid   } else {
669b225d01SJohannes Reifferscheid     Value sqrt = b.create<math::SqrtOp>(ratioSqPlusOne, fmfWithNaNInf);
679b225d01SJohannes Reifferscheid     result = b.create<arith::MulFOp>(max, sqrt, fmfWithNaNInf);
68ff9bc3a0SJohannes Reifferscheid   }
69ff9bc3a0SJohannes Reifferscheid 
709b225d01SJohannes Reifferscheid   Value isNaN = b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, result,
719b225d01SJohannes Reifferscheid                                         result, fmfWithNaNInf);
72ff9bc3a0SJohannes Reifferscheid   return b.create<arith::SelectOp>(isNaN, min, result);
73ff9bc3a0SJohannes Reifferscheid }
74ff9bc3a0SJohannes Reifferscheid 
75ff9bc3a0SJohannes Reifferscheid struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
76ff9bc3a0SJohannes Reifferscheid   using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
77ff9bc3a0SJohannes Reifferscheid 
78ff9bc3a0SJohannes Reifferscheid   LogicalResult
79ff9bc3a0SJohannes Reifferscheid   matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
80ff9bc3a0SJohannes Reifferscheid                   ConversionPatternRewriter &rewriter) const override {
81ff9bc3a0SJohannes Reifferscheid     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
82ff9bc3a0SJohannes Reifferscheid 
83ff9bc3a0SJohannes Reifferscheid     arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
84ff9bc3a0SJohannes Reifferscheid 
85ff9bc3a0SJohannes Reifferscheid     Value real = b.create<complex::ReOp>(adaptor.getComplex());
86ff9bc3a0SJohannes Reifferscheid     Value imag = b.create<complex::ImOp>(adaptor.getComplex());
87ff9bc3a0SJohannes Reifferscheid     rewriter.replaceOp(op, computeAbs(real, imag, fmf, b));
88b17348c3SKai Sasaki 
892ea7fb7bSAdrian Kuegel     return success();
902ea7fb7bSAdrian Kuegel   }
912ea7fb7bSAdrian Kuegel };
92ac00cb0dSAdrian Kuegel 
93f711785eSAlexander Belyaev // atan2(y,x) = -i * log((x + i * y)/sqrt(x**2+y**2))
94f711785eSAlexander Belyaev struct Atan2OpConversion : public OpConversionPattern<complex::Atan2Op> {
95f711785eSAlexander Belyaev   using OpConversionPattern<complex::Atan2Op>::OpConversionPattern;
96f711785eSAlexander Belyaev 
97f711785eSAlexander Belyaev   LogicalResult
98f711785eSAlexander Belyaev   matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor,
99f711785eSAlexander Belyaev                   ConversionPatternRewriter &rewriter) const override {
100f711785eSAlexander Belyaev     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
101f711785eSAlexander Belyaev 
1025550c821STres Popp     auto type = cast<ComplexType>(op.getType());
103f711785eSAlexander Belyaev     Type elementType = type.getElementType();
104b930b14dSKai Sasaki     arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
105f711785eSAlexander Belyaev 
106f711785eSAlexander Belyaev     Value lhs = adaptor.getLhs();
107f711785eSAlexander Belyaev     Value rhs = adaptor.getRhs();
108f711785eSAlexander Belyaev 
109b930b14dSKai Sasaki     Value rhsSquared = b.create<complex::MulOp>(type, rhs, rhs, fmf);
110b930b14dSKai Sasaki     Value lhsSquared = b.create<complex::MulOp>(type, lhs, lhs, fmf);
111f711785eSAlexander Belyaev     Value rhsSquaredPlusLhsSquared =
112b930b14dSKai Sasaki         b.create<complex::AddOp>(type, rhsSquared, lhsSquared, fmf);
113f711785eSAlexander Belyaev     Value sqrtOfRhsSquaredPlusLhsSquared =
114b930b14dSKai Sasaki         b.create<complex::SqrtOp>(type, rhsSquaredPlusLhsSquared, fmf);
115f711785eSAlexander Belyaev 
116f711785eSAlexander Belyaev     Value zero =
117f711785eSAlexander Belyaev         b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
118f711785eSAlexander Belyaev     Value one = b.create<arith::ConstantOp>(elementType,
119f711785eSAlexander Belyaev                                             b.getFloatAttr(elementType, 1));
120f711785eSAlexander Belyaev     Value i = b.create<complex::CreateOp>(type, zero, one);
121b930b14dSKai Sasaki     Value iTimesLhs = b.create<complex::MulOp>(i, lhs, fmf);
122b930b14dSKai Sasaki     Value rhsPlusILhs = b.create<complex::AddOp>(rhs, iTimesLhs, fmf);
123f711785eSAlexander Belyaev 
124b930b14dSKai Sasaki     Value divResult = b.create<complex::DivOp>(
125b930b14dSKai Sasaki         rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared, fmf);
126b930b14dSKai Sasaki     Value logResult = b.create<complex::LogOp>(divResult, fmf);
127f711785eSAlexander Belyaev 
128f711785eSAlexander Belyaev     Value negativeOne = b.create<arith::ConstantOp>(
129f711785eSAlexander Belyaev         elementType, b.getFloatAttr(elementType, -1));
130f711785eSAlexander Belyaev     Value negativeI = b.create<complex::CreateOp>(type, zero, negativeOne);
131f711785eSAlexander Belyaev 
132b930b14dSKai Sasaki     rewriter.replaceOpWithNewOp<complex::MulOp>(op, negativeI, logResult, fmf);
133f711785eSAlexander Belyaev     return success();
134f711785eSAlexander Belyaev   }
135f711785eSAlexander Belyaev };
136f711785eSAlexander Belyaev 
137a54f4eaeSMogball template <typename ComparisonOp, arith::CmpFPredicate p>
138fb8b2b86SAdrian Kuegel struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
139fb8b2b86SAdrian Kuegel   using OpConversionPattern<ComparisonOp>::OpConversionPattern;
140fb8b2b86SAdrian Kuegel   using ResultCombiner =
141fb8b2b86SAdrian Kuegel       std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value,
142a54f4eaeSMogball                          arith::AndIOp, arith::OrIOp>;
143ac00cb0dSAdrian Kuegel 
144ac00cb0dSAdrian Kuegel   LogicalResult
145b54c724bSRiver Riddle   matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor,
146ac00cb0dSAdrian Kuegel                   ConversionPatternRewriter &rewriter) const override {
147ac00cb0dSAdrian Kuegel     auto loc = op.getLoc();
1485550c821STres Popp     auto type = cast<ComplexType>(adaptor.getLhs().getType()).getElementType();
149ac00cb0dSAdrian Kuegel 
150c0342a2dSJacques Pienaar     Value realLhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getLhs());
151c0342a2dSJacques Pienaar     Value imagLhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getLhs());
152c0342a2dSJacques Pienaar     Value realRhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getRhs());
153c0342a2dSJacques Pienaar     Value imagRhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getRhs());
154a54f4eaeSMogball     Value realComparison =
155a54f4eaeSMogball         rewriter.create<arith::CmpFOp>(loc, p, realLhs, realRhs);
156a54f4eaeSMogball     Value imagComparison =
157a54f4eaeSMogball         rewriter.create<arith::CmpFOp>(loc, p, imagLhs, imagRhs);
158ac00cb0dSAdrian Kuegel 
159fb8b2b86SAdrian Kuegel     rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison,
160fb8b2b86SAdrian Kuegel                                                 imagComparison);
161ac00cb0dSAdrian Kuegel     return success();
162ac00cb0dSAdrian Kuegel   }
163ac00cb0dSAdrian Kuegel };
164942be7cbSAdrian Kuegel 
165fb978f09SAdrian Kuegel // Default conversion which applies the BinaryStandardOp separately on the real
166fb978f09SAdrian Kuegel // and imaginary parts. Can for example be used for complex::AddOp and
167fb978f09SAdrian Kuegel // complex::SubOp.
168fb978f09SAdrian Kuegel template <typename BinaryComplexOp, typename BinaryStandardOp>
169fb978f09SAdrian Kuegel struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
170fb978f09SAdrian Kuegel   using OpConversionPattern<BinaryComplexOp>::OpConversionPattern;
171fb978f09SAdrian Kuegel 
172fb978f09SAdrian Kuegel   LogicalResult
173b54c724bSRiver Riddle   matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor,
174fb978f09SAdrian Kuegel                   ConversionPatternRewriter &rewriter) const override {
1755550c821STres Popp     auto type = cast<ComplexType>(adaptor.getLhs().getType());
1765550c821STres Popp     auto elementType = cast<FloatType>(type.getElementType());
177fb978f09SAdrian Kuegel     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
178be5b6667SKai Sasaki     arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
179fb978f09SAdrian Kuegel 
180c0342a2dSJacques Pienaar     Value realLhs = b.create<complex::ReOp>(elementType, adaptor.getLhs());
181c0342a2dSJacques Pienaar     Value realRhs = b.create<complex::ReOp>(elementType, adaptor.getRhs());
182be5b6667SKai Sasaki     Value resultReal = b.create<BinaryStandardOp>(elementType, realLhs, realRhs,
183be5b6667SKai Sasaki                                                   fmf.getValue());
184c0342a2dSJacques Pienaar     Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.getLhs());
185c0342a2dSJacques Pienaar     Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.getRhs());
186be5b6667SKai Sasaki     Value resultImag = b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs,
187be5b6667SKai Sasaki                                                   fmf.getValue());
188fb978f09SAdrian Kuegel     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
189fb978f09SAdrian Kuegel                                                    resultImag);
190fb978f09SAdrian Kuegel     return success();
191fb978f09SAdrian Kuegel   }
192fb978f09SAdrian Kuegel };
193fb978f09SAdrian Kuegel 
194672b908bSGoran Flegar template <typename TrigonometricOp>
195672b908bSGoran Flegar struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
196672b908bSGoran Flegar   using OpAdaptor = typename OpConversionPattern<TrigonometricOp>::OpAdaptor;
197672b908bSGoran Flegar 
198672b908bSGoran Flegar   using OpConversionPattern<TrigonometricOp>::OpConversionPattern;
199672b908bSGoran Flegar 
200672b908bSGoran Flegar   LogicalResult
201672b908bSGoran Flegar   matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor,
202672b908bSGoran Flegar                   ConversionPatternRewriter &rewriter) const override {
203672b908bSGoran Flegar     auto loc = op.getLoc();
2045550c821STres Popp     auto type = cast<ComplexType>(adaptor.getComplex().getType());
2055550c821STres Popp     auto elementType = cast<FloatType>(type.getElementType());
2067d2d8e2aSKai Sasaki     arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
207672b908bSGoran Flegar 
208672b908bSGoran Flegar     Value real =
209672b908bSGoran Flegar         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
210672b908bSGoran Flegar     Value imag =
211672b908bSGoran Flegar         rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
212672b908bSGoran Flegar 
213672b908bSGoran Flegar     // Trigonometric ops use a set of common building blocks to convert to real
214672b908bSGoran Flegar     // ops. Here we create these building blocks and call into an op-specific
215672b908bSGoran Flegar     // implementation in the subclass to combine them.
216672b908bSGoran Flegar     Value half = rewriter.create<arith::ConstantOp>(
217672b908bSGoran Flegar         loc, elementType, rewriter.getFloatAttr(elementType, 0.5));
2187d2d8e2aSKai Sasaki     Value exp = rewriter.create<math::ExpOp>(loc, imag, fmf);
2197d2d8e2aSKai Sasaki     Value scaledExp = rewriter.create<arith::MulFOp>(loc, half, exp, fmf);
2207d2d8e2aSKai Sasaki     Value reciprocalExp = rewriter.create<arith::DivFOp>(loc, half, exp, fmf);
2217d2d8e2aSKai Sasaki     Value sin = rewriter.create<math::SinOp>(loc, real, fmf);
2227d2d8e2aSKai Sasaki     Value cos = rewriter.create<math::CosOp>(loc, real, fmf);
223672b908bSGoran Flegar 
224672b908bSGoran Flegar     auto resultPair =
2257d2d8e2aSKai Sasaki         combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter, fmf);
226672b908bSGoran Flegar 
227672b908bSGoran Flegar     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultPair.first,
228672b908bSGoran Flegar                                                    resultPair.second);
229672b908bSGoran Flegar     return success();
230672b908bSGoran Flegar   }
231672b908bSGoran Flegar 
232672b908bSGoran Flegar   virtual std::pair<Value, Value>
233672b908bSGoran Flegar   combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
2347d2d8e2aSKai Sasaki           Value cos, ConversionPatternRewriter &rewriter,
2357d2d8e2aSKai Sasaki           arith::FastMathFlagsAttr fmf) const = 0;
236672b908bSGoran Flegar };
237672b908bSGoran Flegar 
238672b908bSGoran Flegar struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> {
239672b908bSGoran Flegar   using TrigonometricOpConversion<complex::CosOp>::TrigonometricOpConversion;
240672b908bSGoran Flegar 
2417d2d8e2aSKai Sasaki   std::pair<Value, Value> combine(Location loc, Value scaledExp,
2427d2d8e2aSKai Sasaki                                   Value reciprocalExp, Value sin, Value cos,
2437d2d8e2aSKai Sasaki                                   ConversionPatternRewriter &rewriter,
2447d2d8e2aSKai Sasaki                                   arith::FastMathFlagsAttr fmf) const override {
245672b908bSGoran Flegar     // Complex cosine is defined as;
246672b908bSGoran Flegar     //   cos(x + iy) = 0.5 * (exp(i(x + iy)) + exp(-i(x + iy)))
247672b908bSGoran Flegar     // Plugging in:
248672b908bSGoran Flegar     //   exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
249672b908bSGoran Flegar     //   exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
250672b908bSGoran Flegar     // and defining t := exp(y)
251672b908bSGoran Flegar     // We get:
252672b908bSGoran Flegar     //   Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x
253672b908bSGoran Flegar     //   Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x
2547d2d8e2aSKai Sasaki     Value sum =
2557d2d8e2aSKai Sasaki         rewriter.create<arith::AddFOp>(loc, reciprocalExp, scaledExp, fmf);
2567d2d8e2aSKai Sasaki     Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, cos, fmf);
2577d2d8e2aSKai Sasaki     Value diff =
2587d2d8e2aSKai Sasaki         rewriter.create<arith::SubFOp>(loc, reciprocalExp, scaledExp, fmf);
2597d2d8e2aSKai Sasaki     Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, sin, fmf);
260672b908bSGoran Flegar     return {resultReal, resultImag};
261672b908bSGoran Flegar   }
262672b908bSGoran Flegar };
263672b908bSGoran Flegar 
264942be7cbSAdrian Kuegel struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
265942be7cbSAdrian Kuegel   using OpConversionPattern<complex::DivOp>::OpConversionPattern;
266942be7cbSAdrian Kuegel 
267942be7cbSAdrian Kuegel   LogicalResult
268b54c724bSRiver Riddle   matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
269942be7cbSAdrian Kuegel                   ConversionPatternRewriter &rewriter) const override {
270942be7cbSAdrian Kuegel     auto loc = op.getLoc();
2715550c821STres Popp     auto type = cast<ComplexType>(adaptor.getLhs().getType());
2725550c821STres Popp     auto elementType = cast<FloatType>(type.getElementType());
273288d317fSKai Sasaki     arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
274942be7cbSAdrian Kuegel 
275942be7cbSAdrian Kuegel     Value lhsReal =
276c0342a2dSJacques Pienaar         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getLhs());
277942be7cbSAdrian Kuegel     Value lhsImag =
278c0342a2dSJacques Pienaar         rewriter.create<complex::ImOp>(loc, elementType, adaptor.getLhs());
279942be7cbSAdrian Kuegel     Value rhsReal =
280c0342a2dSJacques Pienaar         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getRhs());
281942be7cbSAdrian Kuegel     Value rhsImag =
282c0342a2dSJacques Pienaar         rewriter.create<complex::ImOp>(loc, elementType, adaptor.getRhs());
283942be7cbSAdrian Kuegel 
284942be7cbSAdrian Kuegel     // Smith's algorithm to divide complex numbers. It is just a bit smarter
285942be7cbSAdrian Kuegel     // way to compute the following formula:
286942be7cbSAdrian Kuegel     //  (lhsReal + lhsImag * i) / (rhsReal + rhsImag * i)
287942be7cbSAdrian Kuegel     //    = (lhsReal + lhsImag * i) (rhsReal - rhsImag * i) /
288942be7cbSAdrian Kuegel     //          ((rhsReal + rhsImag * i)(rhsReal - rhsImag * i))
289942be7cbSAdrian Kuegel     //    = ((lhsReal * rhsReal + lhsImag * rhsImag) +
290942be7cbSAdrian Kuegel     //          (lhsImag * rhsReal - lhsReal * rhsImag) * i) / ||rhs||^2
291942be7cbSAdrian Kuegel     //
292942be7cbSAdrian Kuegel     // Depending on whether |rhsReal| < |rhsImag| we compute either
293942be7cbSAdrian Kuegel     //   rhsRealImagRatio = rhsReal / rhsImag
294942be7cbSAdrian Kuegel     //   rhsRealImagDenom = rhsImag + rhsReal * rhsRealImagRatio
295942be7cbSAdrian Kuegel     //   resultReal = (lhsReal * rhsRealImagRatio + lhsImag) / rhsRealImagDenom
296942be7cbSAdrian Kuegel     //   resultImag = (lhsImag * rhsRealImagRatio - lhsReal) / rhsRealImagDenom
297942be7cbSAdrian Kuegel     //
298942be7cbSAdrian Kuegel     // or
299942be7cbSAdrian Kuegel     //
300942be7cbSAdrian Kuegel     //   rhsImagRealRatio = rhsImag / rhsReal
301942be7cbSAdrian Kuegel     //   rhsImagRealDenom = rhsReal + rhsImag * rhsImagRealRatio
302942be7cbSAdrian Kuegel     //   resultReal = (lhsReal + lhsImag * rhsImagRealRatio) / rhsImagRealDenom
303942be7cbSAdrian Kuegel     //   resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom
304942be7cbSAdrian Kuegel     //
305942be7cbSAdrian Kuegel     // See https://dl.acm.org/citation.cfm?id=368661 for more details.
306a54f4eaeSMogball     Value rhsRealImagRatio =
307288d317fSKai Sasaki         rewriter.create<arith::DivFOp>(loc, rhsReal, rhsImag, fmf);
308a54f4eaeSMogball     Value rhsRealImagDenom = rewriter.create<arith::AddFOp>(
309a54f4eaeSMogball         loc, rhsImag,
310288d317fSKai Sasaki         rewriter.create<arith::MulFOp>(loc, rhsRealImagRatio, rhsReal, fmf),
311288d317fSKai Sasaki         fmf);
312a54f4eaeSMogball     Value realNumerator1 = rewriter.create<arith::AddFOp>(
313288d317fSKai Sasaki         loc,
314288d317fSKai Sasaki         rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealImagRatio, fmf),
315288d317fSKai Sasaki         lhsImag, fmf);
316288d317fSKai Sasaki     Value resultReal1 = rewriter.create<arith::DivFOp>(loc, realNumerator1,
317288d317fSKai Sasaki                                                        rhsRealImagDenom, fmf);
318a54f4eaeSMogball     Value imagNumerator1 = rewriter.create<arith::SubFOp>(
319288d317fSKai Sasaki         loc,
320288d317fSKai Sasaki         rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealImagRatio, fmf),
321288d317fSKai Sasaki         lhsReal, fmf);
322288d317fSKai Sasaki     Value resultImag1 = rewriter.create<arith::DivFOp>(loc, imagNumerator1,
323288d317fSKai Sasaki                                                        rhsRealImagDenom, fmf);
324942be7cbSAdrian Kuegel 
325a54f4eaeSMogball     Value rhsImagRealRatio =
326288d317fSKai Sasaki         rewriter.create<arith::DivFOp>(loc, rhsImag, rhsReal, fmf);
327a54f4eaeSMogball     Value rhsImagRealDenom = rewriter.create<arith::AddFOp>(
328a54f4eaeSMogball         loc, rhsReal,
329288d317fSKai Sasaki         rewriter.create<arith::MulFOp>(loc, rhsImagRealRatio, rhsImag, fmf),
330288d317fSKai Sasaki         fmf);
331a54f4eaeSMogball     Value realNumerator2 = rewriter.create<arith::AddFOp>(
332a54f4eaeSMogball         loc, lhsReal,
333288d317fSKai Sasaki         rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagRealRatio, fmf),
334288d317fSKai Sasaki         fmf);
335288d317fSKai Sasaki     Value resultReal2 = rewriter.create<arith::DivFOp>(loc, realNumerator2,
336288d317fSKai Sasaki                                                        rhsImagRealDenom, fmf);
337a54f4eaeSMogball     Value imagNumerator2 = rewriter.create<arith::SubFOp>(
338a54f4eaeSMogball         loc, lhsImag,
339288d317fSKai Sasaki         rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagRealRatio, fmf),
340288d317fSKai Sasaki         fmf);
341288d317fSKai Sasaki     Value resultImag2 = rewriter.create<arith::DivFOp>(loc, imagNumerator2,
342288d317fSKai Sasaki                                                        rhsImagRealDenom, fmf);
343942be7cbSAdrian Kuegel 
344942be7cbSAdrian Kuegel     // Consider corner cases.
345942be7cbSAdrian Kuegel     // Case 1. Zero denominator, numerator contains at most one NaN value.
346a54f4eaeSMogball     Value zero = rewriter.create<arith::ConstantOp>(
347a54f4eaeSMogball         loc, elementType, rewriter.getZeroAttr(elementType));
348288d317fSKai Sasaki     Value rhsRealAbs = rewriter.create<math::AbsFOp>(loc, rhsReal, fmf);
349a54f4eaeSMogball     Value rhsRealIsZero = rewriter.create<arith::CmpFOp>(
350a54f4eaeSMogball         loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero);
351288d317fSKai Sasaki     Value rhsImagAbs = rewriter.create<math::AbsFOp>(loc, rhsImag, fmf);
352a54f4eaeSMogball     Value rhsImagIsZero = rewriter.create<arith::CmpFOp>(
353a54f4eaeSMogball         loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero);
354a54f4eaeSMogball     Value lhsRealIsNotNaN = rewriter.create<arith::CmpFOp>(
355a54f4eaeSMogball         loc, arith::CmpFPredicate::ORD, lhsReal, zero);
356a54f4eaeSMogball     Value lhsImagIsNotNaN = rewriter.create<arith::CmpFOp>(
357a54f4eaeSMogball         loc, arith::CmpFPredicate::ORD, lhsImag, zero);
358942be7cbSAdrian Kuegel     Value lhsContainsNotNaNValue =
359a54f4eaeSMogball         rewriter.create<arith::OrIOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
360a54f4eaeSMogball     Value resultIsInfinity = rewriter.create<arith::AndIOp>(
361942be7cbSAdrian Kuegel         loc, lhsContainsNotNaNValue,
362a54f4eaeSMogball         rewriter.create<arith::AndIOp>(loc, rhsRealIsZero, rhsImagIsZero));
363a54f4eaeSMogball     Value inf = rewriter.create<arith::ConstantOp>(
364942be7cbSAdrian Kuegel         loc, elementType,
365942be7cbSAdrian Kuegel         rewriter.getFloatAttr(
366942be7cbSAdrian Kuegel             elementType, APFloat::getInf(elementType.getFloatSemantics())));
367a54f4eaeSMogball     Value infWithSignOfRhsReal =
368a54f4eaeSMogball         rewriter.create<math::CopySignOp>(loc, inf, rhsReal);
369942be7cbSAdrian Kuegel     Value infinityResultReal =
370288d317fSKai Sasaki         rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsReal, fmf);
371942be7cbSAdrian Kuegel     Value infinityResultImag =
372288d317fSKai Sasaki         rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsImag, fmf);
373942be7cbSAdrian Kuegel 
374942be7cbSAdrian Kuegel     // Case 2. Infinite numerator, finite denominator.
375a54f4eaeSMogball     Value rhsRealFinite = rewriter.create<arith::CmpFOp>(
376a54f4eaeSMogball         loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf);
377a54f4eaeSMogball     Value rhsImagFinite = rewriter.create<arith::CmpFOp>(
378a54f4eaeSMogball         loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf);
379a54f4eaeSMogball     Value rhsFinite =
380a54f4eaeSMogball         rewriter.create<arith::AndIOp>(loc, rhsRealFinite, rhsImagFinite);
381288d317fSKai Sasaki     Value lhsRealAbs = rewriter.create<math::AbsFOp>(loc, lhsReal, fmf);
382a54f4eaeSMogball     Value lhsRealInfinite = rewriter.create<arith::CmpFOp>(
383a54f4eaeSMogball         loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
384288d317fSKai Sasaki     Value lhsImagAbs = rewriter.create<math::AbsFOp>(loc, lhsImag, fmf);
385a54f4eaeSMogball     Value lhsImagInfinite = rewriter.create<arith::CmpFOp>(
386a54f4eaeSMogball         loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
387942be7cbSAdrian Kuegel     Value lhsInfinite =
388a54f4eaeSMogball         rewriter.create<arith::OrIOp>(loc, lhsRealInfinite, lhsImagInfinite);
389942be7cbSAdrian Kuegel     Value infNumFiniteDenom =
390a54f4eaeSMogball         rewriter.create<arith::AndIOp>(loc, lhsInfinite, rhsFinite);
391a54f4eaeSMogball     Value one = rewriter.create<arith::ConstantOp>(
392942be7cbSAdrian Kuegel         loc, elementType, rewriter.getFloatAttr(elementType, 1));
393a54f4eaeSMogball     Value lhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
394dec8af70SRiver Riddle         loc, rewriter.create<arith::SelectOp>(loc, lhsRealInfinite, one, zero),
395942be7cbSAdrian Kuegel         lhsReal);
396a54f4eaeSMogball     Value lhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
397dec8af70SRiver Riddle         loc, rewriter.create<arith::SelectOp>(loc, lhsImagInfinite, one, zero),
398942be7cbSAdrian Kuegel         lhsImag);
399942be7cbSAdrian Kuegel     Value lhsRealIsInfWithSignTimesRhsReal =
400288d317fSKai Sasaki         rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsReal, fmf);
401942be7cbSAdrian Kuegel     Value lhsImagIsInfWithSignTimesRhsImag =
402288d317fSKai Sasaki         rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsImag, fmf);
403a54f4eaeSMogball     Value resultReal3 = rewriter.create<arith::MulFOp>(
404942be7cbSAdrian Kuegel         loc, inf,
405a54f4eaeSMogball         rewriter.create<arith::AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal,
406288d317fSKai Sasaki                                        lhsImagIsInfWithSignTimesRhsImag, fmf),
407288d317fSKai Sasaki         fmf);
408942be7cbSAdrian Kuegel     Value lhsRealIsInfWithSignTimesRhsImag =
409288d317fSKai Sasaki         rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsImag, fmf);
410942be7cbSAdrian Kuegel     Value lhsImagIsInfWithSignTimesRhsReal =
411288d317fSKai Sasaki         rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsReal, fmf);
412a54f4eaeSMogball     Value resultImag3 = rewriter.create<arith::MulFOp>(
413942be7cbSAdrian Kuegel         loc, inf,
414a54f4eaeSMogball         rewriter.create<arith::SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal,
415288d317fSKai Sasaki                                        lhsRealIsInfWithSignTimesRhsImag, fmf),
416288d317fSKai Sasaki         fmf);
417942be7cbSAdrian Kuegel 
418942be7cbSAdrian Kuegel     // Case 3: Finite numerator, infinite denominator.
419a54f4eaeSMogball     Value lhsRealFinite = rewriter.create<arith::CmpFOp>(
420a54f4eaeSMogball         loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf);
421a54f4eaeSMogball     Value lhsImagFinite = rewriter.create<arith::CmpFOp>(
422a54f4eaeSMogball         loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf);
423a54f4eaeSMogball     Value lhsFinite =
424a54f4eaeSMogball         rewriter.create<arith::AndIOp>(loc, lhsRealFinite, lhsImagFinite);
425a54f4eaeSMogball     Value rhsRealInfinite = rewriter.create<arith::CmpFOp>(
426a54f4eaeSMogball         loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
427a54f4eaeSMogball     Value rhsImagInfinite = rewriter.create<arith::CmpFOp>(
428a54f4eaeSMogball         loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
429942be7cbSAdrian Kuegel     Value rhsInfinite =
430a54f4eaeSMogball         rewriter.create<arith::OrIOp>(loc, rhsRealInfinite, rhsImagInfinite);
431942be7cbSAdrian Kuegel     Value finiteNumInfiniteDenom =
432a54f4eaeSMogball         rewriter.create<arith::AndIOp>(loc, lhsFinite, rhsInfinite);
433a54f4eaeSMogball     Value rhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
434dec8af70SRiver Riddle         loc, rewriter.create<arith::SelectOp>(loc, rhsRealInfinite, one, zero),
435942be7cbSAdrian Kuegel         rhsReal);
436a54f4eaeSMogball     Value rhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
437dec8af70SRiver Riddle         loc, rewriter.create<arith::SelectOp>(loc, rhsImagInfinite, one, zero),
438942be7cbSAdrian Kuegel         rhsImag);
439942be7cbSAdrian Kuegel     Value rhsRealIsInfWithSignTimesLhsReal =
440288d317fSKai Sasaki         rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealIsInfWithSign, fmf);
441942be7cbSAdrian Kuegel     Value rhsImagIsInfWithSignTimesLhsImag =
442288d317fSKai Sasaki         rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagIsInfWithSign, fmf);
443a54f4eaeSMogball     Value resultReal4 = rewriter.create<arith::MulFOp>(
444942be7cbSAdrian Kuegel         loc, zero,
445a54f4eaeSMogball         rewriter.create<arith::AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal,
446288d317fSKai Sasaki                                        rhsImagIsInfWithSignTimesLhsImag, fmf),
447288d317fSKai Sasaki         fmf);
448942be7cbSAdrian Kuegel     Value rhsRealIsInfWithSignTimesLhsImag =
449288d317fSKai Sasaki         rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealIsInfWithSign, fmf);
450942be7cbSAdrian Kuegel     Value rhsImagIsInfWithSignTimesLhsReal =
451288d317fSKai Sasaki         rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagIsInfWithSign, fmf);
452a54f4eaeSMogball     Value resultImag4 = rewriter.create<arith::MulFOp>(
453942be7cbSAdrian Kuegel         loc, zero,
454a54f4eaeSMogball         rewriter.create<arith::SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag,
455288d317fSKai Sasaki                                        rhsImagIsInfWithSignTimesLhsReal, fmf),
456288d317fSKai Sasaki         fmf);
457942be7cbSAdrian Kuegel 
458a54f4eaeSMogball     Value realAbsSmallerThanImagAbs = rewriter.create<arith::CmpFOp>(
459a54f4eaeSMogball         loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);
460dec8af70SRiver Riddle     Value resultReal = rewriter.create<arith::SelectOp>(
461dec8af70SRiver Riddle         loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2);
462dec8af70SRiver Riddle     Value resultImag = rewriter.create<arith::SelectOp>(
463dec8af70SRiver Riddle         loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2);
464dec8af70SRiver Riddle     Value resultRealSpecialCase3 = rewriter.create<arith::SelectOp>(
465942be7cbSAdrian Kuegel         loc, finiteNumInfiniteDenom, resultReal4, resultReal);
466dec8af70SRiver Riddle     Value resultImagSpecialCase3 = rewriter.create<arith::SelectOp>(
467942be7cbSAdrian Kuegel         loc, finiteNumInfiniteDenom, resultImag4, resultImag);
468dec8af70SRiver Riddle     Value resultRealSpecialCase2 = rewriter.create<arith::SelectOp>(
469942be7cbSAdrian Kuegel         loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
470dec8af70SRiver Riddle     Value resultImagSpecialCase2 = rewriter.create<arith::SelectOp>(
471942be7cbSAdrian Kuegel         loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
472dec8af70SRiver Riddle     Value resultRealSpecialCase1 = rewriter.create<arith::SelectOp>(
473942be7cbSAdrian Kuegel         loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2);
474dec8af70SRiver Riddle     Value resultImagSpecialCase1 = rewriter.create<arith::SelectOp>(
475942be7cbSAdrian Kuegel         loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2);
476942be7cbSAdrian Kuegel 
477a54f4eaeSMogball     Value resultRealIsNaN = rewriter.create<arith::CmpFOp>(
478a54f4eaeSMogball         loc, arith::CmpFPredicate::UNO, resultReal, zero);
479a54f4eaeSMogball     Value resultImagIsNaN = rewriter.create<arith::CmpFOp>(
480a54f4eaeSMogball         loc, arith::CmpFPredicate::UNO, resultImag, zero);
481942be7cbSAdrian Kuegel     Value resultIsNaN =
482a54f4eaeSMogball         rewriter.create<arith::AndIOp>(loc, resultRealIsNaN, resultImagIsNaN);
483dec8af70SRiver Riddle     Value resultRealWithSpecialCases = rewriter.create<arith::SelectOp>(
484942be7cbSAdrian Kuegel         loc, resultIsNaN, resultRealSpecialCase1, resultReal);
485dec8af70SRiver Riddle     Value resultImagWithSpecialCases = rewriter.create<arith::SelectOp>(
486942be7cbSAdrian Kuegel         loc, resultIsNaN, resultImagSpecialCase1, resultImag);
487942be7cbSAdrian Kuegel 
488942be7cbSAdrian Kuegel     rewriter.replaceOpWithNewOp<complex::CreateOp>(
489942be7cbSAdrian Kuegel         op, type, resultRealWithSpecialCases, resultImagWithSpecialCases);
490942be7cbSAdrian Kuegel     return success();
491942be7cbSAdrian Kuegel   }
492942be7cbSAdrian Kuegel };
49373cbc91cSAdrian Kuegel 
49473cbc91cSAdrian Kuegel struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
49573cbc91cSAdrian Kuegel   using OpConversionPattern<complex::ExpOp>::OpConversionPattern;
49673cbc91cSAdrian Kuegel 
49773cbc91cSAdrian Kuegel   LogicalResult
498b54c724bSRiver Riddle   matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
49973cbc91cSAdrian Kuegel                   ConversionPatternRewriter &rewriter) const override {
50073cbc91cSAdrian Kuegel     auto loc = op.getLoc();
5015550c821STres Popp     auto type = cast<ComplexType>(adaptor.getComplex().getType());
5025550c821STres Popp     auto elementType = cast<FloatType>(type.getElementType());
503d230bf3fSKai Sasaki     arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
50473cbc91cSAdrian Kuegel 
50573cbc91cSAdrian Kuegel     Value real =
506c0342a2dSJacques Pienaar         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
50773cbc91cSAdrian Kuegel     Value imag =
508c0342a2dSJacques Pienaar         rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
509d230bf3fSKai Sasaki     Value expReal = rewriter.create<math::ExpOp>(loc, real, fmf.getValue());
510d230bf3fSKai Sasaki     Value cosImag = rewriter.create<math::CosOp>(loc, imag, fmf.getValue());
511d230bf3fSKai Sasaki     Value resultReal =
512d230bf3fSKai Sasaki         rewriter.create<arith::MulFOp>(loc, expReal, cosImag, fmf.getValue());
513d230bf3fSKai Sasaki     Value sinImag = rewriter.create<math::SinOp>(loc, imag, fmf.getValue());
514d230bf3fSKai Sasaki     Value resultImag =
515d230bf3fSKai Sasaki         rewriter.create<arith::MulFOp>(loc, expReal, sinImag, fmf.getValue());
51673cbc91cSAdrian Kuegel 
51773cbc91cSAdrian Kuegel     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
51873cbc91cSAdrian Kuegel                                                    resultImag);
51973cbc91cSAdrian Kuegel     return success();
52073cbc91cSAdrian Kuegel   }
52173cbc91cSAdrian Kuegel };
522662e074dSAdrian Kuegel 
52318ee0032SAlexander Belyaev Value evaluatePolynomial(ImplicitLocOpBuilder &b, Value arg,
52418ee0032SAlexander Belyaev                          ArrayRef<double> coefficients,
52518ee0032SAlexander Belyaev                          arith::FastMathFlagsAttr fmf) {
52618ee0032SAlexander Belyaev   auto argType = mlir::cast<FloatType>(arg.getType());
52718ee0032SAlexander Belyaev   Value poly =
52818ee0032SAlexander Belyaev       b.create<arith::ConstantOp>(b.getFloatAttr(argType, coefficients[0]));
529*06011feeSJie Fu   for (unsigned i = 1; i < coefficients.size(); ++i) {
53018ee0032SAlexander Belyaev     poly = b.create<math::FmaOp>(
53118ee0032SAlexander Belyaev         poly, arg,
53218ee0032SAlexander Belyaev         b.create<arith::ConstantOp>(b.getFloatAttr(argType, coefficients[i])),
53318ee0032SAlexander Belyaev         fmf);
53418ee0032SAlexander Belyaev   }
53518ee0032SAlexander Belyaev   return poly;
53618ee0032SAlexander Belyaev }
53718ee0032SAlexander Belyaev 
538338e76f8Sbixia1 struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> {
539338e76f8Sbixia1   using OpConversionPattern<complex::Expm1Op>::OpConversionPattern;
540338e76f8Sbixia1 
54118ee0032SAlexander Belyaev   // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i
54218ee0032SAlexander Belyaev   //            [handle inaccuracies when a and/or b are small]
54318ee0032SAlexander Belyaev   //            = ((e^a - 1) * cos(b) + cos(b) - 1) + e^a*sin(b)i
54418ee0032SAlexander Belyaev   //            = (expm1(a) * cos(b) + cosm1(b)) + e^a*sin(b)i
545338e76f8Sbixia1   LogicalResult
546338e76f8Sbixia1   matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor,
547338e76f8Sbixia1                   ConversionPatternRewriter &rewriter) const override {
54818ee0032SAlexander Belyaev     auto type = op.getType();
54918ee0032SAlexander Belyaev     auto elemType = mlir::cast<FloatType>(type.getElementType());
55018ee0032SAlexander Belyaev 
551d230bf3fSKai Sasaki     arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
55218ee0032SAlexander Belyaev     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
55318ee0032SAlexander Belyaev     Value real = b.create<complex::ReOp>(adaptor.getComplex());
55418ee0032SAlexander Belyaev     Value imag = b.create<complex::ImOp>(adaptor.getComplex());
555338e76f8Sbixia1 
55618ee0032SAlexander Belyaev     Value zero = b.create<arith::ConstantOp>(b.getFloatAttr(elemType, 0.0));
55718ee0032SAlexander Belyaev     Value one = b.create<arith::ConstantOp>(b.getFloatAttr(elemType, 1.0));
558338e76f8Sbixia1 
55918ee0032SAlexander Belyaev     Value expm1Real = b.create<math::ExpM1Op>(real, fmf);
56018ee0032SAlexander Belyaev     Value expReal = b.create<arith::AddFOp>(expm1Real, one, fmf);
561338e76f8Sbixia1 
56218ee0032SAlexander Belyaev     Value sinImag = b.create<math::SinOp>(imag, fmf);
56318ee0032SAlexander Belyaev     Value cosm1Imag = emitCosm1(imag, fmf, b);
56418ee0032SAlexander Belyaev     Value cosImag = b.create<arith::AddFOp>(cosm1Imag, one, fmf);
56518ee0032SAlexander Belyaev 
56618ee0032SAlexander Belyaev     Value realResult = b.create<arith::AddFOp>(
56718ee0032SAlexander Belyaev         b.create<arith::MulFOp>(expm1Real, cosImag, fmf), cosm1Imag, fmf);
56818ee0032SAlexander Belyaev 
56918ee0032SAlexander Belyaev     Value imagIsZero = b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag,
57018ee0032SAlexander Belyaev                                                zero, fmf.getValue());
57118ee0032SAlexander Belyaev     Value imagResult = b.create<arith::SelectOp>(
57218ee0032SAlexander Belyaev         imagIsZero, zero, b.create<arith::MulFOp>(expReal, sinImag, fmf));
57318ee0032SAlexander Belyaev 
57418ee0032SAlexander Belyaev     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realResult,
57518ee0032SAlexander Belyaev                                                    imagResult);
576338e76f8Sbixia1     return success();
577338e76f8Sbixia1   }
57818ee0032SAlexander Belyaev 
57918ee0032SAlexander Belyaev private:
58018ee0032SAlexander Belyaev   Value emitCosm1(Value arg, arith::FastMathFlagsAttr fmf,
58118ee0032SAlexander Belyaev                   ImplicitLocOpBuilder &b) const {
58218ee0032SAlexander Belyaev     auto argType = mlir::cast<FloatType>(arg.getType());
58318ee0032SAlexander Belyaev     auto negHalf = b.create<arith::ConstantOp>(b.getFloatAttr(argType, -0.5));
58418ee0032SAlexander Belyaev     auto negOne = b.create<arith::ConstantOp>(b.getFloatAttr(argType, -1.0));
58518ee0032SAlexander Belyaev 
58618ee0032SAlexander Belyaev     // Algorithm copied from cephes cosm1.
58718ee0032SAlexander Belyaev     SmallVector<double, 7> kCoeffs{
58818ee0032SAlexander Belyaev         4.7377507964246204691685E-14, -1.1470284843425359765671E-11,
58918ee0032SAlexander Belyaev         2.0876754287081521758361E-9,  -2.7557319214999787979814E-7,
59018ee0032SAlexander Belyaev         2.4801587301570552304991E-5,  -1.3888888888888872993737E-3,
59118ee0032SAlexander Belyaev         4.1666666666666666609054E-2,
59218ee0032SAlexander Belyaev     };
59318ee0032SAlexander Belyaev     Value cos = b.create<math::CosOp>(arg, fmf);
59418ee0032SAlexander Belyaev     Value forLargeArg = b.create<arith::AddFOp>(cos, negOne, fmf);
59518ee0032SAlexander Belyaev 
59618ee0032SAlexander Belyaev     Value argPow2 = b.create<arith::MulFOp>(arg, arg, fmf);
59718ee0032SAlexander Belyaev     Value argPow4 = b.create<arith::MulFOp>(argPow2, argPow2, fmf);
59818ee0032SAlexander Belyaev     Value poly = evaluatePolynomial(b, argPow2, kCoeffs, fmf);
59918ee0032SAlexander Belyaev 
60018ee0032SAlexander Belyaev     auto forSmallArg =
60118ee0032SAlexander Belyaev         b.create<arith::AddFOp>(b.create<arith::MulFOp>(argPow4, poly, fmf),
60218ee0032SAlexander Belyaev                                 b.create<arith::MulFOp>(negHalf, argPow2, fmf));
60318ee0032SAlexander Belyaev 
60418ee0032SAlexander Belyaev     // (pi/4)^2 is approximately 0.61685
60518ee0032SAlexander Belyaev     Value piOver4Pow2 =
60618ee0032SAlexander Belyaev         b.create<arith::ConstantOp>(b.getFloatAttr(argType, 0.61685));
60718ee0032SAlexander Belyaev     Value cond = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, argPow2,
60818ee0032SAlexander Belyaev                                          piOver4Pow2, fmf.getValue());
60918ee0032SAlexander Belyaev     return b.create<arith::SelectOp>(cond, forLargeArg, forSmallArg);
61018ee0032SAlexander Belyaev   }
611338e76f8Sbixia1 };
612338e76f8Sbixia1 
613380fa71fSAdrian Kuegel struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
614380fa71fSAdrian Kuegel   using OpConversionPattern<complex::LogOp>::OpConversionPattern;
615380fa71fSAdrian Kuegel 
616380fa71fSAdrian Kuegel   LogicalResult
617b54c724bSRiver Riddle   matchAndRewrite(complex::LogOp op, OpAdaptor adaptor,
618380fa71fSAdrian Kuegel                   ConversionPatternRewriter &rewriter) const override {
6195550c821STres Popp     auto type = cast<ComplexType>(adaptor.getComplex().getType());
6205550c821STres Popp     auto elementType = cast<FloatType>(type.getElementType());
6218aaa2cb8SKai Sasaki     arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
622380fa71fSAdrian Kuegel     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
623380fa71fSAdrian Kuegel 
6248aaa2cb8SKai Sasaki     Value abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex(),
6258aaa2cb8SKai Sasaki                                          fmf.getValue());
6268aaa2cb8SKai Sasaki     Value resultReal = b.create<math::LogOp>(elementType, abs, fmf.getValue());
627c0342a2dSJacques Pienaar     Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
628c0342a2dSJacques Pienaar     Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
6298aaa2cb8SKai Sasaki     Value resultImag =
6308aaa2cb8SKai Sasaki         b.create<math::Atan2Op>(elementType, imag, real, fmf.getValue());
631380fa71fSAdrian Kuegel     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
632380fa71fSAdrian Kuegel                                                    resultImag);
633380fa71fSAdrian Kuegel     return success();
634380fa71fSAdrian Kuegel   }
635380fa71fSAdrian Kuegel };
636380fa71fSAdrian Kuegel 
6376e80e3bdSAdrian Kuegel struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
6386e80e3bdSAdrian Kuegel   using OpConversionPattern<complex::Log1pOp>::OpConversionPattern;
6396e80e3bdSAdrian Kuegel 
6406e80e3bdSAdrian Kuegel   LogicalResult
641b54c724bSRiver Riddle   matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor,
6426e80e3bdSAdrian Kuegel                   ConversionPatternRewriter &rewriter) const override {
6435550c821STres Popp     auto type = cast<ComplexType>(adaptor.getComplex().getType());
6445550c821STres Popp     auto elementType = cast<FloatType>(type.getElementType());
6455c9315f5SJohannes Reifferscheid     arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
6466e80e3bdSAdrian Kuegel     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
6476e80e3bdSAdrian Kuegel 
6485c9315f5SJohannes Reifferscheid     Value real = b.create<complex::ReOp>(adaptor.getComplex());
6495c9315f5SJohannes Reifferscheid     Value imag = b.create<complex::ImOp>(adaptor.getComplex());
650375a5cb6SJohannes Reifferscheid 
651375a5cb6SJohannes Reifferscheid     Value half = b.create<arith::ConstantOp>(elementType,
652375a5cb6SJohannes Reifferscheid                                              b.getFloatAttr(elementType, 0.5));
653a54f4eaeSMogball     Value one = b.create<arith::ConstantOp>(elementType,
654a54f4eaeSMogball                                             b.getFloatAttr(elementType, 1));
6555c9315f5SJohannes Reifferscheid     Value realPlusOne = b.create<arith::AddFOp>(real, one, fmf);
6565c9315f5SJohannes Reifferscheid     Value absRealPlusOne = b.create<math::AbsFOp>(realPlusOne, fmf);
6575c9315f5SJohannes Reifferscheid     Value absImag = b.create<math::AbsFOp>(imag, fmf);
658375a5cb6SJohannes Reifferscheid 
6595c9315f5SJohannes Reifferscheid     Value maxAbs = b.create<arith::MaximumFOp>(absRealPlusOne, absImag, fmf);
6605c9315f5SJohannes Reifferscheid     Value minAbs = b.create<arith::MinimumFOp>(absRealPlusOne, absImag, fmf);
661375a5cb6SJohannes Reifferscheid 
6625c9315f5SJohannes Reifferscheid     Value useReal = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT,
6635c9315f5SJohannes Reifferscheid                                             realPlusOne, absImag, fmf);
6645c9315f5SJohannes Reifferscheid     Value maxMinusOne = b.create<arith::SubFOp>(maxAbs, one, fmf);
6655c9315f5SJohannes Reifferscheid     Value maxAbsOfRealPlusOneAndImagMinusOne =
6665c9315f5SJohannes Reifferscheid         b.create<arith::SelectOp>(useReal, real, maxMinusOne);
6679b225d01SJohannes Reifferscheid     arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
6689b225d01SJohannes Reifferscheid         fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
6699b225d01SJohannes Reifferscheid     Value minMaxRatio = b.create<arith::DivFOp>(minAbs, maxAbs, fmfWithNaNInf);
6705c9315f5SJohannes Reifferscheid     Value logOfMaxAbsOfRealPlusOneAndImag =
6715c9315f5SJohannes Reifferscheid         b.create<math::Log1pOp>(maxAbsOfRealPlusOneAndImagMinusOne, fmf);
6725c9315f5SJohannes Reifferscheid     Value logOfSqrtPart = b.create<math::Log1pOp>(
6739b225d01SJohannes Reifferscheid         b.create<arith::MulFOp>(minMaxRatio, minMaxRatio, fmfWithNaNInf),
6749b225d01SJohannes Reifferscheid         fmfWithNaNInf);
6755c9315f5SJohannes Reifferscheid     Value r = b.create<arith::AddFOp>(
6769b225d01SJohannes Reifferscheid         b.create<arith::MulFOp>(half, logOfSqrtPart, fmfWithNaNInf),
6779b225d01SJohannes Reifferscheid         logOfMaxAbsOfRealPlusOneAndImag, fmfWithNaNInf);
6785c9315f5SJohannes Reifferscheid     Value resultReal = b.create<arith::SelectOp>(
6799b225d01SJohannes Reifferscheid         b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, r, r, fmfWithNaNInf),
6809b225d01SJohannes Reifferscheid         minAbs, r);
6815c9315f5SJohannes Reifferscheid     Value resultImag = b.create<math::Atan2Op>(imag, realPlusOne, fmf);
682375a5cb6SJohannes Reifferscheid     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
683375a5cb6SJohannes Reifferscheid                                                    resultImag);
6846e80e3bdSAdrian Kuegel     return success();
6856e80e3bdSAdrian Kuegel   }
6866e80e3bdSAdrian Kuegel };
6876e80e3bdSAdrian Kuegel 
688bf17ee19SAdrian Kuegel struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
689bf17ee19SAdrian Kuegel   using OpConversionPattern<complex::MulOp>::OpConversionPattern;
690bf17ee19SAdrian Kuegel 
691bf17ee19SAdrian Kuegel   LogicalResult
692b54c724bSRiver Riddle   matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
693bf17ee19SAdrian Kuegel                   ConversionPatternRewriter &rewriter) const override {
694bf17ee19SAdrian Kuegel     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
6955550c821STres Popp     auto type = cast<ComplexType>(adaptor.getLhs().getType());
6965550c821STres Popp     auto elementType = cast<FloatType>(type.getElementType());
697eee71ed3SKai Sasaki     arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
698eee71ed3SKai Sasaki     auto fmfValue = fmf.getValue();
699c0342a2dSJacques Pienaar     Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.getLhs());
700c0342a2dSJacques Pienaar     Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.getLhs());
701c0342a2dSJacques Pienaar     Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.getRhs());
702c0342a2dSJacques Pienaar     Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.getRhs());
703eee71ed3SKai Sasaki     Value lhsRealTimesRhsReal =
704eee71ed3SKai Sasaki         b.create<arith::MulFOp>(lhsReal, rhsReal, fmfValue);
705eee71ed3SKai Sasaki     Value lhsImagTimesRhsImag =
706eee71ed3SKai Sasaki         b.create<arith::MulFOp>(lhsImag, rhsImag, fmfValue);
707eee71ed3SKai Sasaki     Value real = b.create<arith::SubFOp>(lhsRealTimesRhsReal,
708eee71ed3SKai Sasaki                                          lhsImagTimesRhsImag, fmfValue);
709eee71ed3SKai Sasaki     Value lhsImagTimesRhsReal =
710eee71ed3SKai Sasaki         b.create<arith::MulFOp>(lhsImag, rhsReal, fmfValue);
711eee71ed3SKai Sasaki     Value lhsRealTimesRhsImag =
712eee71ed3SKai Sasaki         b.create<arith::MulFOp>(lhsReal, rhsImag, fmfValue);
713eee71ed3SKai Sasaki     Value imag = b.create<arith::AddFOp>(lhsImagTimesRhsReal,
714eee71ed3SKai Sasaki                                          lhsRealTimesRhsImag, fmfValue);
715bf17ee19SAdrian Kuegel     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
716bf17ee19SAdrian Kuegel     return success();
717bf17ee19SAdrian Kuegel   }
718bf17ee19SAdrian Kuegel };
719bf17ee19SAdrian Kuegel 
720662e074dSAdrian Kuegel struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
721662e074dSAdrian Kuegel   using OpConversionPattern<complex::NegOp>::OpConversionPattern;
722662e074dSAdrian Kuegel 
723662e074dSAdrian Kuegel   LogicalResult
724b54c724bSRiver Riddle   matchAndRewrite(complex::NegOp op, OpAdaptor adaptor,
725662e074dSAdrian Kuegel                   ConversionPatternRewriter &rewriter) const override {
726662e074dSAdrian Kuegel     auto loc = op.getLoc();
7275550c821STres Popp     auto type = cast<ComplexType>(adaptor.getComplex().getType());
7285550c821STres Popp     auto elementType = cast<FloatType>(type.getElementType());
729662e074dSAdrian Kuegel 
730662e074dSAdrian Kuegel     Value real =
731c0342a2dSJacques Pienaar         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
732662e074dSAdrian Kuegel     Value imag =
733c0342a2dSJacques Pienaar         rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
734a54f4eaeSMogball     Value negReal = rewriter.create<arith::NegFOp>(loc, real);
735a54f4eaeSMogball     Value negImag = rewriter.create<arith::NegFOp>(loc, imag);
736662e074dSAdrian Kuegel     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag);
737662e074dSAdrian Kuegel     return success();
738662e074dSAdrian Kuegel   }
739662e074dSAdrian Kuegel };
740f112bd61SAdrian Kuegel 
741672b908bSGoran Flegar struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
742672b908bSGoran Flegar   using TrigonometricOpConversion<complex::SinOp>::TrigonometricOpConversion;
743672b908bSGoran Flegar 
7447d2d8e2aSKai Sasaki   std::pair<Value, Value> combine(Location loc, Value scaledExp,
7457d2d8e2aSKai Sasaki                                   Value reciprocalExp, Value sin, Value cos,
7467d2d8e2aSKai Sasaki                                   ConversionPatternRewriter &rewriter,
7477d2d8e2aSKai Sasaki                                   arith::FastMathFlagsAttr fmf) const override {
748672b908bSGoran Flegar     // Complex sine is defined as;
749672b908bSGoran Flegar     //   sin(x + iy) = -0.5i * (exp(i(x + iy)) - exp(-i(x + iy)))
750672b908bSGoran Flegar     // Plugging in:
751672b908bSGoran Flegar     //   exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
752672b908bSGoran Flegar     //   exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
753672b908bSGoran Flegar     // and defining t := exp(y)
754672b908bSGoran Flegar     // We get:
755672b908bSGoran Flegar     //   Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x
756672b908bSGoran Flegar     //   Im(cos(x + iy)) = (0.5*t - 0.5/t) * cos x
7577d2d8e2aSKai Sasaki     Value sum =
7587d2d8e2aSKai Sasaki         rewriter.create<arith::AddFOp>(loc, scaledExp, reciprocalExp, fmf);
7597d2d8e2aSKai Sasaki     Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, sin, fmf);
7607d2d8e2aSKai Sasaki     Value diff =
7617d2d8e2aSKai Sasaki         rewriter.create<arith::SubFOp>(loc, scaledExp, reciprocalExp, fmf);
7627d2d8e2aSKai Sasaki     Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, cos, fmf);
763672b908bSGoran Flegar     return {resultReal, resultImag};
764672b908bSGoran Flegar   }
765672b908bSGoran Flegar };
766672b908bSGoran Flegar 
767f711785eSAlexander Belyaev // The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
768f711785eSAlexander Belyaev struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
769f711785eSAlexander Belyaev   using OpConversionPattern<complex::SqrtOp>::OpConversionPattern;
770f711785eSAlexander Belyaev 
771f711785eSAlexander Belyaev   LogicalResult
772f711785eSAlexander Belyaev   matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor,
773f711785eSAlexander Belyaev                   ConversionPatternRewriter &rewriter) const override {
774ff9bc3a0SJohannes Reifferscheid     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
775f711785eSAlexander Belyaev 
7765550c821STres Popp     auto type = cast<ComplexType>(op.getType());
777fac349a1SChristian Sigg     auto elementType = cast<FloatType>(type.getElementType());
778ff9bc3a0SJohannes Reifferscheid     arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
779f711785eSAlexander Belyaev 
780ff9bc3a0SJohannes Reifferscheid     auto cst = [&](APFloat v) {
781ff9bc3a0SJohannes Reifferscheid       return b.create<arith::ConstantOp>(elementType,
782ff9bc3a0SJohannes Reifferscheid                                          b.getFloatAttr(elementType, v));
783ff9bc3a0SJohannes Reifferscheid     };
784ff9bc3a0SJohannes Reifferscheid     const auto &floatSemantics = elementType.getFloatSemantics();
785ff9bc3a0SJohannes Reifferscheid     Value zero = cst(APFloat::getZero(floatSemantics));
786ff9bc3a0SJohannes Reifferscheid     Value half = b.create<arith::ConstantOp>(elementType,
787ff9bc3a0SJohannes Reifferscheid                                              b.getFloatAttr(elementType, 0.5));
788f711785eSAlexander Belyaev 
789f711785eSAlexander Belyaev     Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
790f711785eSAlexander Belyaev     Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
79133e60f35SJohannes Reifferscheid     Value absSqrt = computeAbs(real, imag, fmf, b, AbsFn::sqrt);
792ff9bc3a0SJohannes Reifferscheid     Value argArg = b.create<math::Atan2Op>(imag, real, fmf);
793ff9bc3a0SJohannes Reifferscheid     Value sqrtArg = b.create<arith::MulFOp>(argArg, half, fmf);
794ff9bc3a0SJohannes Reifferscheid     Value cos = b.create<math::CosOp>(sqrtArg, fmf);
795ff9bc3a0SJohannes Reifferscheid     Value sin = b.create<math::SinOp>(sqrtArg, fmf);
796ff9bc3a0SJohannes Reifferscheid     // sin(atan2(0, inf)) = 0, sqrt(abs(inf)) = inf, but we can't multiply
797ff9bc3a0SJohannes Reifferscheid     // 0 * inf.
798ff9bc3a0SJohannes Reifferscheid     Value sinIsZero =
799ff9bc3a0SJohannes Reifferscheid         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, sin, zero, fmf);
800f711785eSAlexander Belyaev 
801ff9bc3a0SJohannes Reifferscheid     Value resultReal = b.create<arith::MulFOp>(absSqrt, cos, fmf);
802f711785eSAlexander Belyaev     Value resultImag = b.create<arith::SelectOp>(
803ff9bc3a0SJohannes Reifferscheid         sinIsZero, zero, b.create<arith::MulFOp>(absSqrt, sin, fmf));
804ff9bc3a0SJohannes Reifferscheid     if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
805ff9bc3a0SJohannes Reifferscheid                                             arith::FastMathFlags::ninf)) {
806ff9bc3a0SJohannes Reifferscheid       Value inf = cst(APFloat::getInf(floatSemantics));
807ff9bc3a0SJohannes Reifferscheid       Value negInf = cst(APFloat::getInf(floatSemantics, true));
808ff9bc3a0SJohannes Reifferscheid       Value nan = cst(APFloat::getNaN(floatSemantics));
809ff9bc3a0SJohannes Reifferscheid       Value absImag = b.create<math::AbsFOp>(elementType, imag, fmf);
810ff9bc3a0SJohannes Reifferscheid 
811ff9bc3a0SJohannes Reifferscheid       Value absImagIsInf =
812ff9bc3a0SJohannes Reifferscheid           b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf);
813ff9bc3a0SJohannes Reifferscheid       Value absImagIsNotInf =
814ff9bc3a0SJohannes Reifferscheid           b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, absImag, inf, fmf);
815ff9bc3a0SJohannes Reifferscheid       Value realIsInf =
816ff9bc3a0SJohannes Reifferscheid           b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, inf, fmf);
817ff9bc3a0SJohannes Reifferscheid       Value realIsNegInf =
818ff9bc3a0SJohannes Reifferscheid           b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, negInf, fmf);
819f711785eSAlexander Belyaev 
820f711785eSAlexander Belyaev       resultReal = b.create<arith::SelectOp>(
821ff9bc3a0SJohannes Reifferscheid           b.create<arith::AndIOp>(realIsNegInf, absImagIsNotInf), zero,
822f711785eSAlexander Belyaev           resultReal);
823ff9bc3a0SJohannes Reifferscheid       resultReal = b.create<arith::SelectOp>(
824ff9bc3a0SJohannes Reifferscheid           b.create<arith::OrIOp>(absImagIsInf, realIsInf), inf, resultReal);
825f711785eSAlexander Belyaev 
826ff9bc3a0SJohannes Reifferscheid       Value imagSignInf = b.create<math::CopySignOp>(inf, imag, fmf);
827ff9bc3a0SJohannes Reifferscheid       resultImag = b.create<arith::SelectOp>(
828ff9bc3a0SJohannes Reifferscheid           b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, absSqrt, absSqrt),
829ff9bc3a0SJohannes Reifferscheid           nan, resultImag);
830ff9bc3a0SJohannes Reifferscheid       resultImag = b.create<arith::SelectOp>(
831ff9bc3a0SJohannes Reifferscheid           b.create<arith::OrIOp>(absImagIsInf, realIsNegInf), imagSignInf,
832ff9bc3a0SJohannes Reifferscheid           resultImag);
833ff9bc3a0SJohannes Reifferscheid     }
834f711785eSAlexander Belyaev 
835ff9bc3a0SJohannes Reifferscheid     Value resultIsZero =
836ff9bc3a0SJohannes Reifferscheid         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absSqrt, zero, fmf);
837ff9bc3a0SJohannes Reifferscheid     resultReal = b.create<arith::SelectOp>(resultIsZero, zero, resultReal);
838ff9bc3a0SJohannes Reifferscheid     resultImag = b.create<arith::SelectOp>(resultIsZero, zero, resultImag);
839f711785eSAlexander Belyaev 
840f711785eSAlexander Belyaev     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
841f711785eSAlexander Belyaev                                                    resultImag);
842f711785eSAlexander Belyaev     return success();
843f711785eSAlexander Belyaev   }
844f711785eSAlexander Belyaev };
845f711785eSAlexander Belyaev 
846f112bd61SAdrian Kuegel struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
847f112bd61SAdrian Kuegel   using OpConversionPattern<complex::SignOp>::OpConversionPattern;
848f112bd61SAdrian Kuegel 
849f112bd61SAdrian Kuegel   LogicalResult
850b54c724bSRiver Riddle   matchAndRewrite(complex::SignOp op, OpAdaptor adaptor,
851f112bd61SAdrian Kuegel                   ConversionPatternRewriter &rewriter) const override {
8525550c821STres Popp     auto type = cast<ComplexType>(adaptor.getComplex().getType());
8535550c821STres Popp     auto elementType = cast<FloatType>(type.getElementType());
854f112bd61SAdrian Kuegel     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
855a522dbbdSKai Sasaki     arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
856f112bd61SAdrian Kuegel 
857c0342a2dSJacques Pienaar     Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
858c0342a2dSJacques Pienaar     Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
859a54f4eaeSMogball     Value zero =
860a54f4eaeSMogball         b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
861a54f4eaeSMogball     Value realIsZero =
862a54f4eaeSMogball         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
863a54f4eaeSMogball     Value imagIsZero =
864a54f4eaeSMogball         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
865a54f4eaeSMogball     Value isZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
866a522dbbdSKai Sasaki     auto abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex(), fmf);
867a522dbbdSKai Sasaki     Value realSign = b.create<arith::DivFOp>(real, abs, fmf);
868a522dbbdSKai Sasaki     Value imagSign = b.create<arith::DivFOp>(imag, abs, fmf);
869f112bd61SAdrian Kuegel     Value sign = b.create<complex::CreateOp>(type, realSign, imagSign);
870dec8af70SRiver Riddle     rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isZero,
871dec8af70SRiver Riddle                                                  adaptor.getComplex(), sign);
872f112bd61SAdrian Kuegel     return success();
873f112bd61SAdrian Kuegel   }
874f112bd61SAdrian Kuegel };
8756d75c897Slewuathe 
876f43deca2SJohannes Reifferscheid template <typename Op>
877f43deca2SJohannes Reifferscheid struct TanTanhOpConversion : public OpConversionPattern<Op> {
878f43deca2SJohannes Reifferscheid   using OpConversionPattern<Op>::OpConversionPattern;
8796d75c897Slewuathe 
8806d75c897Slewuathe   LogicalResult
881e0a293d1SKazu Hirata   matchAndRewrite(Op op, typename Op::Adaptor adaptor,
882ffb8eecdSlewuathe                   ConversionPatternRewriter &rewriter) const override {
8839da0ef16SJohannes Reifferscheid     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
884ffb8eecdSlewuathe     auto loc = op.getLoc();
8855550c821STres Popp     auto type = cast<ComplexType>(adaptor.getComplex().getType());
8865550c821STres Popp     auto elementType = cast<FloatType>(type.getElementType());
8879da0ef16SJohannes Reifferscheid     arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
8889da0ef16SJohannes Reifferscheid     const auto &floatSemantics = elementType.getFloatSemantics();
889ffb8eecdSlewuathe 
890ffb8eecdSlewuathe     Value real =
8919da0ef16SJohannes Reifferscheid         b.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
892ffb8eecdSlewuathe     Value imag =
8939da0ef16SJohannes Reifferscheid         b.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
894f43deca2SJohannes Reifferscheid     Value negOne = b.create<arith::ConstantOp>(
895f43deca2SJohannes Reifferscheid         elementType, b.getFloatAttr(elementType, -1.0));
896f43deca2SJohannes Reifferscheid 
897f43deca2SJohannes Reifferscheid     if constexpr (std::is_same_v<Op, complex::TanOp>) {
898f43deca2SJohannes Reifferscheid       // tan(x+yi) = -i*tanh(-y + xi)
899f43deca2SJohannes Reifferscheid       std::swap(real, imag);
900f43deca2SJohannes Reifferscheid       real = b.create<arith::MulFOp>(real, negOne, fmf);
901f43deca2SJohannes Reifferscheid     }
9029da0ef16SJohannes Reifferscheid 
9039da0ef16SJohannes Reifferscheid     auto cst = [&](APFloat v) {
9049da0ef16SJohannes Reifferscheid       return b.create<arith::ConstantOp>(elementType,
9059da0ef16SJohannes Reifferscheid                                          b.getFloatAttr(elementType, v));
9069da0ef16SJohannes Reifferscheid     };
9079da0ef16SJohannes Reifferscheid     Value inf = cst(APFloat::getInf(floatSemantics));
9089da0ef16SJohannes Reifferscheid     Value four = b.create<arith::ConstantOp>(elementType,
9099da0ef16SJohannes Reifferscheid                                              b.getFloatAttr(elementType, 4.0));
9109da0ef16SJohannes Reifferscheid     Value twoReal = b.create<arith::AddFOp>(real, real, fmf);
9119da0ef16SJohannes Reifferscheid     Value negTwoReal = b.create<arith::MulFOp>(negOne, twoReal, fmf);
9129da0ef16SJohannes Reifferscheid 
9139da0ef16SJohannes Reifferscheid     Value expTwoRealMinusOne = b.create<math::ExpM1Op>(twoReal, fmf);
9149da0ef16SJohannes Reifferscheid     Value expNegTwoRealMinusOne = b.create<math::ExpM1Op>(negTwoReal, fmf);
9159da0ef16SJohannes Reifferscheid     Value realNum =
9169da0ef16SJohannes Reifferscheid         b.create<arith::SubFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf);
9179da0ef16SJohannes Reifferscheid 
9189da0ef16SJohannes Reifferscheid     Value cosImag = b.create<math::CosOp>(imag, fmf);
9199da0ef16SJohannes Reifferscheid     Value cosImagSq = b.create<arith::MulFOp>(cosImag, cosImag, fmf);
9209da0ef16SJohannes Reifferscheid     Value twoCosTwoImagPlusOne = b.create<arith::MulFOp>(cosImagSq, four, fmf);
9219da0ef16SJohannes Reifferscheid     Value sinImag = b.create<math::SinOp>(imag, fmf);
9229da0ef16SJohannes Reifferscheid 
9239da0ef16SJohannes Reifferscheid     Value imagNum = b.create<arith::MulFOp>(
9249da0ef16SJohannes Reifferscheid         four, b.create<arith::MulFOp>(cosImag, sinImag, fmf), fmf);
9259da0ef16SJohannes Reifferscheid 
9269da0ef16SJohannes Reifferscheid     Value expSumMinusTwo =
9279da0ef16SJohannes Reifferscheid         b.create<arith::AddFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf);
9289da0ef16SJohannes Reifferscheid     Value denom =
9299da0ef16SJohannes Reifferscheid         b.create<arith::AddFOp>(expSumMinusTwo, twoCosTwoImagPlusOne, fmf);
9309da0ef16SJohannes Reifferscheid 
9319da0ef16SJohannes Reifferscheid     Value isInf = b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
9329da0ef16SJohannes Reifferscheid                                           expSumMinusTwo, inf, fmf);
9339da0ef16SJohannes Reifferscheid     Value realLimit = b.create<math::CopySignOp>(negOne, real, fmf);
9349da0ef16SJohannes Reifferscheid 
9359da0ef16SJohannes Reifferscheid     Value resultReal = b.create<arith::SelectOp>(
9369da0ef16SJohannes Reifferscheid         isInf, realLimit, b.create<arith::DivFOp>(realNum, denom, fmf));
9379da0ef16SJohannes Reifferscheid     Value resultImag = b.create<arith::DivFOp>(imagNum, denom, fmf);
9389da0ef16SJohannes Reifferscheid 
9399da0ef16SJohannes Reifferscheid     if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
9409da0ef16SJohannes Reifferscheid                                             arith::FastMathFlags::ninf)) {
9419da0ef16SJohannes Reifferscheid       Value absReal = b.create<math::AbsFOp>(real, fmf);
9429da0ef16SJohannes Reifferscheid       Value zero = b.create<arith::ConstantOp>(
9439da0ef16SJohannes Reifferscheid           elementType, b.getFloatAttr(elementType, 0.0));
9449da0ef16SJohannes Reifferscheid       Value nan = cst(APFloat::getNaN(floatSemantics));
9459da0ef16SJohannes Reifferscheid 
9469da0ef16SJohannes Reifferscheid       Value absRealIsInf =
9479da0ef16SJohannes Reifferscheid           b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf);
9489da0ef16SJohannes Reifferscheid       Value imagIsZero =
9499da0ef16SJohannes Reifferscheid           b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf);
9509da0ef16SJohannes Reifferscheid       Value absRealIsNotInf = b.create<arith::XOrIOp>(
9519da0ef16SJohannes Reifferscheid           absRealIsInf, b.create<arith::ConstantIntOp>(true, /*width=*/1));
9529da0ef16SJohannes Reifferscheid 
9539da0ef16SJohannes Reifferscheid       Value imagNumIsNaN = b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO,
9549da0ef16SJohannes Reifferscheid                                                    imagNum, imagNum, fmf);
9559da0ef16SJohannes Reifferscheid       Value resultRealIsNaN =
9569da0ef16SJohannes Reifferscheid           b.create<arith::AndIOp>(imagNumIsNaN, absRealIsNotInf);
9579da0ef16SJohannes Reifferscheid       Value resultImagIsZero = b.create<arith::OrIOp>(
9589da0ef16SJohannes Reifferscheid           imagIsZero, b.create<arith::AndIOp>(absRealIsInf, imagNumIsNaN));
9599da0ef16SJohannes Reifferscheid 
9609da0ef16SJohannes Reifferscheid       resultReal = b.create<arith::SelectOp>(resultRealIsNaN, nan, resultReal);
9619da0ef16SJohannes Reifferscheid       resultImag =
9629da0ef16SJohannes Reifferscheid           b.create<arith::SelectOp>(resultImagIsZero, zero, resultImag);
9639da0ef16SJohannes Reifferscheid     }
9649da0ef16SJohannes Reifferscheid 
965f43deca2SJohannes Reifferscheid     if constexpr (std::is_same_v<Op, complex::TanOp>) {
966f43deca2SJohannes Reifferscheid       // tan(x+yi) = -i*tanh(-y + xi)
967f43deca2SJohannes Reifferscheid       std::swap(resultReal, resultImag);
968f43deca2SJohannes Reifferscheid       resultImag = b.create<arith::MulFOp>(resultImag, negOne, fmf);
969f43deca2SJohannes Reifferscheid     }
970f43deca2SJohannes Reifferscheid 
9719da0ef16SJohannes Reifferscheid     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
9729da0ef16SJohannes Reifferscheid                                                    resultImag);
973ffb8eecdSlewuathe     return success();
974ffb8eecdSlewuathe   }
975ffb8eecdSlewuathe };
976ffb8eecdSlewuathe 
97762a34f6aSlewuathe struct ConjOpConversion : public OpConversionPattern<complex::ConjOp> {
97862a34f6aSlewuathe   using OpConversionPattern<complex::ConjOp>::OpConversionPattern;
97962a34f6aSlewuathe 
98062a34f6aSlewuathe   LogicalResult
98162a34f6aSlewuathe   matchAndRewrite(complex::ConjOp op, OpAdaptor adaptor,
98262a34f6aSlewuathe                   ConversionPatternRewriter &rewriter) const override {
98362a34f6aSlewuathe     auto loc = op.getLoc();
9845550c821STres Popp     auto type = cast<ComplexType>(adaptor.getComplex().getType());
9855550c821STres Popp     auto elementType = cast<FloatType>(type.getElementType());
98662a34f6aSlewuathe     Value real =
98762a34f6aSlewuathe         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
98862a34f6aSlewuathe     Value imag =
98962a34f6aSlewuathe         rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
99062a34f6aSlewuathe     Value negImag = rewriter.create<arith::NegFOp>(loc, elementType, imag);
99162a34f6aSlewuathe 
99262a34f6aSlewuathe     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, negImag);
99362a34f6aSlewuathe 
99462a34f6aSlewuathe     return success();
99562a34f6aSlewuathe   }
99662a34f6aSlewuathe };
99762a34f6aSlewuathe 
99877dd4357SJohannes Reifferscheid /// Converts lhs^y = (a+bi)^(c+di) to
9996c6eddb6Sbixia1 ///    (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
10006c6eddb6Sbixia1 ///    where q = c*atan2(b,a)+0.5d*ln(a*a+b*b)
10016c6eddb6Sbixia1 static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
100277dd4357SJohannes Reifferscheid                                  ComplexType type, Value lhs, Value c, Value d,
100377dd4357SJohannes Reifferscheid                                  arith::FastMathFlags fmf) {
10045550c821STres Popp   auto elementType = cast<FloatType>(type.getElementType());
10056c6eddb6Sbixia1 
100677dd4357SJohannes Reifferscheid   Value a = builder.create<complex::ReOp>(lhs);
100777dd4357SJohannes Reifferscheid   Value b = builder.create<complex::ImOp>(lhs);
10086c6eddb6Sbixia1 
100977dd4357SJohannes Reifferscheid   Value abs = builder.create<complex::AbsOp>(lhs, fmf);
101077dd4357SJohannes Reifferscheid   Value absToC = builder.create<math::PowFOp>(abs, c, fmf);
10116c6eddb6Sbixia1 
101277dd4357SJohannes Reifferscheid   Value negD = builder.create<arith::NegFOp>(d, fmf);
101377dd4357SJohannes Reifferscheid   Value argLhs = builder.create<math::Atan2Op>(b, a, fmf);
101477dd4357SJohannes Reifferscheid   Value negDArgLhs = builder.create<arith::MulFOp>(negD, argLhs, fmf);
101577dd4357SJohannes Reifferscheid   Value expNegDArgLhs = builder.create<math::ExpOp>(negDArgLhs, fmf);
10166c6eddb6Sbixia1 
101777dd4357SJohannes Reifferscheid   Value coeff = builder.create<arith::MulFOp>(absToC, expNegDArgLhs, fmf);
101877dd4357SJohannes Reifferscheid   Value lnAbs = builder.create<math::LogOp>(abs, fmf);
101977dd4357SJohannes Reifferscheid   Value cArgLhs = builder.create<arith::MulFOp>(c, argLhs, fmf);
102077dd4357SJohannes Reifferscheid   Value dLnAbs = builder.create<arith::MulFOp>(d, lnAbs, fmf);
102177dd4357SJohannes Reifferscheid   Value q = builder.create<arith::AddFOp>(cArgLhs, dLnAbs, fmf);
102277dd4357SJohannes Reifferscheid   Value cosQ = builder.create<math::CosOp>(q, fmf);
102377dd4357SJohannes Reifferscheid   Value sinQ = builder.create<math::SinOp>(q, fmf);
10246c6eddb6Sbixia1 
102577dd4357SJohannes Reifferscheid   Value inf = builder.create<arith::ConstantOp>(
102677dd4357SJohannes Reifferscheid       elementType,
102777dd4357SJohannes Reifferscheid       builder.getFloatAttr(elementType,
102877dd4357SJohannes Reifferscheid                            APFloat::getInf(elementType.getFloatSemantics())));
10296c6eddb6Sbixia1   Value zero = builder.create<arith::ConstantOp>(
103077dd4357SJohannes Reifferscheid       elementType, builder.getFloatAttr(elementType, 0.0));
10316c6eddb6Sbixia1   Value one = builder.create<arith::ConstantOp>(
103277dd4357SJohannes Reifferscheid       elementType, builder.getFloatAttr(elementType, 1.0));
10336c6eddb6Sbixia1   Value complexOne = builder.create<complex::CreateOp>(type, one, zero);
103477dd4357SJohannes Reifferscheid   Value complexZero = builder.create<complex::CreateOp>(type, zero, zero);
103577dd4357SJohannes Reifferscheid   Value complexInf = builder.create<complex::CreateOp>(type, inf, zero);
10366c6eddb6Sbixia1 
103777dd4357SJohannes Reifferscheid   // Case 0:
103877dd4357SJohannes Reifferscheid   // d^c is 0 if d is 0 and c > 0. 0^0 is defined to be 1.0, see
10396c6eddb6Sbixia1   // Branch Cuts for Complex Elementary Functions or Much Ado About
10406c6eddb6Sbixia1   // Nothing's Sign Bit, W. Kahan, Section 10.
104177dd4357SJohannes Reifferscheid   Value absEqZero =
104277dd4357SJohannes Reifferscheid       builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, abs, zero, fmf);
104377dd4357SJohannes Reifferscheid   Value dEqZero =
104477dd4357SJohannes Reifferscheid       builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, d, zero, fmf);
104577dd4357SJohannes Reifferscheid   Value cEqZero =
104677dd4357SJohannes Reifferscheid       builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, c, zero, fmf);
104777dd4357SJohannes Reifferscheid   Value bEqZero =
104877dd4357SJohannes Reifferscheid       builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, b, zero, fmf);
104977dd4357SJohannes Reifferscheid 
105077dd4357SJohannes Reifferscheid   Value zeroLeC =
105177dd4357SJohannes Reifferscheid       builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLE, zero, c, fmf);
105277dd4357SJohannes Reifferscheid   Value coeffCosQ = builder.create<arith::MulFOp>(coeff, cosQ, fmf);
105377dd4357SJohannes Reifferscheid   Value coeffSinQ = builder.create<arith::MulFOp>(coeff, sinQ, fmf);
105477dd4357SJohannes Reifferscheid   Value complexOneOrZero =
105577dd4357SJohannes Reifferscheid       builder.create<arith::SelectOp>(cEqZero, complexOne, complexZero);
105677dd4357SJohannes Reifferscheid   Value coeffCosSin =
105777dd4357SJohannes Reifferscheid       builder.create<complex::CreateOp>(type, coeffCosQ, coeffSinQ);
105877dd4357SJohannes Reifferscheid   Value cutoff0 = builder.create<arith::SelectOp>(
105977dd4357SJohannes Reifferscheid       builder.create<arith::AndIOp>(
106077dd4357SJohannes Reifferscheid           builder.create<arith::AndIOp>(absEqZero, dEqZero), zeroLeC),
106177dd4357SJohannes Reifferscheid       complexOneOrZero, coeffCosSin);
106277dd4357SJohannes Reifferscheid 
106377dd4357SJohannes Reifferscheid   // Case 1:
106477dd4357SJohannes Reifferscheid   // x^0 is defined to be 1 for any x, see
106577dd4357SJohannes Reifferscheid   // Branch Cuts for Complex Elementary Functions or Much Ado About
106677dd4357SJohannes Reifferscheid   // Nothing's Sign Bit, W. Kahan, Section 10.
106777dd4357SJohannes Reifferscheid   Value rhsEqZero = builder.create<arith::AndIOp>(cEqZero, dEqZero);
106877dd4357SJohannes Reifferscheid   Value cutoff1 =
106977dd4357SJohannes Reifferscheid       builder.create<arith::SelectOp>(rhsEqZero, complexOne, cutoff0);
107077dd4357SJohannes Reifferscheid 
107177dd4357SJohannes Reifferscheid   // Case 2:
107277dd4357SJohannes Reifferscheid   // 1^(c + d*i) = 1 + 0*i
107377dd4357SJohannes Reifferscheid   Value lhsEqOne = builder.create<arith::AndIOp>(
1074ff9bc3a0SJohannes Reifferscheid       builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, one, fmf),
107577dd4357SJohannes Reifferscheid       bEqZero);
107677dd4357SJohannes Reifferscheid   Value cutoff2 =
107777dd4357SJohannes Reifferscheid       builder.create<arith::SelectOp>(lhsEqOne, complexOne, cutoff1);
107877dd4357SJohannes Reifferscheid 
107977dd4357SJohannes Reifferscheid   // Case 3:
108077dd4357SJohannes Reifferscheid   // inf^(c + 0*i) = inf + 0*i, c > 0
108177dd4357SJohannes Reifferscheid   Value lhsEqInf = builder.create<arith::AndIOp>(
1082ff9bc3a0SJohannes Reifferscheid       builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, inf, fmf),
108377dd4357SJohannes Reifferscheid       bEqZero);
108477dd4357SJohannes Reifferscheid   Value rhsGt0 = builder.create<arith::AndIOp>(
108577dd4357SJohannes Reifferscheid       dEqZero,
1086ff9bc3a0SJohannes Reifferscheid       builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, c, zero, fmf));
108777dd4357SJohannes Reifferscheid   Value cutoff3 = builder.create<arith::SelectOp>(
108877dd4357SJohannes Reifferscheid       builder.create<arith::AndIOp>(lhsEqInf, rhsGt0), complexInf, cutoff2);
108977dd4357SJohannes Reifferscheid 
109077dd4357SJohannes Reifferscheid   // Case 4:
109177dd4357SJohannes Reifferscheid   // inf^(c + 0*i) = 0 + 0*i, c < 0
109277dd4357SJohannes Reifferscheid   Value rhsLt0 = builder.create<arith::AndIOp>(
109377dd4357SJohannes Reifferscheid       dEqZero,
1094ff9bc3a0SJohannes Reifferscheid       builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, c, zero, fmf));
109577dd4357SJohannes Reifferscheid   Value cutoff4 = builder.create<arith::SelectOp>(
109677dd4357SJohannes Reifferscheid       builder.create<arith::AndIOp>(lhsEqInf, rhsLt0), complexZero, cutoff3);
109777dd4357SJohannes Reifferscheid 
109877dd4357SJohannes Reifferscheid   return cutoff4;
10996c6eddb6Sbixia1 }
11006c6eddb6Sbixia1 
11016c6eddb6Sbixia1 struct PowOpConversion : public OpConversionPattern<complex::PowOp> {
11026c6eddb6Sbixia1   using OpConversionPattern<complex::PowOp>::OpConversionPattern;
11036c6eddb6Sbixia1 
11046c6eddb6Sbixia1   LogicalResult
11056c6eddb6Sbixia1   matchAndRewrite(complex::PowOp op, OpAdaptor adaptor,
11066c6eddb6Sbixia1                   ConversionPatternRewriter &rewriter) const override {
11076c6eddb6Sbixia1     mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
11085550c821STres Popp     auto type = cast<ComplexType>(adaptor.getLhs().getType());
11095550c821STres Popp     auto elementType = cast<FloatType>(type.getElementType());
11106c6eddb6Sbixia1 
11116c6eddb6Sbixia1     Value c = builder.create<complex::ReOp>(elementType, adaptor.getRhs());
11126c6eddb6Sbixia1     Value d = builder.create<complex::ImOp>(elementType, adaptor.getRhs());
11136c6eddb6Sbixia1 
111477dd4357SJohannes Reifferscheid     rewriter.replaceOp(op, {powOpConversionImpl(builder, type, adaptor.getLhs(),
111577dd4357SJohannes Reifferscheid                                                 c, d, op.getFastmath())});
11166c6eddb6Sbixia1     return success();
11176c6eddb6Sbixia1   }
11186c6eddb6Sbixia1 };
11196c6eddb6Sbixia1 
11206c6eddb6Sbixia1 struct RsqrtOpConversion : public OpConversionPattern<complex::RsqrtOp> {
11216c6eddb6Sbixia1   using OpConversionPattern<complex::RsqrtOp>::OpConversionPattern;
11226c6eddb6Sbixia1 
11236c6eddb6Sbixia1   LogicalResult
11246c6eddb6Sbixia1   matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor,
11256c6eddb6Sbixia1                   ConversionPatternRewriter &rewriter) const override {
112633e60f35SJohannes Reifferscheid     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
11275550c821STres Popp     auto type = cast<ComplexType>(adaptor.getComplex().getType());
11285550c821STres Popp     auto elementType = cast<FloatType>(type.getElementType());
11296c6eddb6Sbixia1 
113033e60f35SJohannes Reifferscheid     arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
11316c6eddb6Sbixia1 
113233e60f35SJohannes Reifferscheid     auto cst = [&](APFloat v) {
113333e60f35SJohannes Reifferscheid       return b.create<arith::ConstantOp>(elementType,
113433e60f35SJohannes Reifferscheid                                          b.getFloatAttr(elementType, v));
113533e60f35SJohannes Reifferscheid     };
113633e60f35SJohannes Reifferscheid     const auto &floatSemantics = elementType.getFloatSemantics();
113733e60f35SJohannes Reifferscheid     Value zero = cst(APFloat::getZero(floatSemantics));
113833e60f35SJohannes Reifferscheid     Value inf = cst(APFloat::getInf(floatSemantics));
113933e60f35SJohannes Reifferscheid     Value negHalf = b.create<arith::ConstantOp>(
114033e60f35SJohannes Reifferscheid         elementType, b.getFloatAttr(elementType, -0.5));
114133e60f35SJohannes Reifferscheid     Value nan = cst(APFloat::getNaN(floatSemantics));
114233e60f35SJohannes Reifferscheid 
114333e60f35SJohannes Reifferscheid     Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
114433e60f35SJohannes Reifferscheid     Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
114533e60f35SJohannes Reifferscheid     Value absRsqrt = computeAbs(real, imag, fmf, b, AbsFn::rsqrt);
114633e60f35SJohannes Reifferscheid     Value argArg = b.create<math::Atan2Op>(imag, real, fmf);
114733e60f35SJohannes Reifferscheid     Value rsqrtArg = b.create<arith::MulFOp>(argArg, negHalf, fmf);
114833e60f35SJohannes Reifferscheid     Value cos = b.create<math::CosOp>(rsqrtArg, fmf);
114933e60f35SJohannes Reifferscheid     Value sin = b.create<math::SinOp>(rsqrtArg, fmf);
115033e60f35SJohannes Reifferscheid 
115133e60f35SJohannes Reifferscheid     Value resultReal = b.create<arith::MulFOp>(absRsqrt, cos, fmf);
115233e60f35SJohannes Reifferscheid     Value resultImag = b.create<arith::MulFOp>(absRsqrt, sin, fmf);
115333e60f35SJohannes Reifferscheid 
115433e60f35SJohannes Reifferscheid     if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
115533e60f35SJohannes Reifferscheid                                             arith::FastMathFlags::ninf)) {
115633e60f35SJohannes Reifferscheid       Value negOne = b.create<arith::ConstantOp>(
115733e60f35SJohannes Reifferscheid           elementType, b.getFloatAttr(elementType, -1));
115833e60f35SJohannes Reifferscheid 
115933e60f35SJohannes Reifferscheid       Value realSignedZero = b.create<math::CopySignOp>(zero, real, fmf);
116033e60f35SJohannes Reifferscheid       Value imagSignedZero = b.create<math::CopySignOp>(zero, imag, fmf);
116133e60f35SJohannes Reifferscheid       Value negImagSignedZero =
116233e60f35SJohannes Reifferscheid           b.create<arith::MulFOp>(negOne, imagSignedZero, fmf);
116333e60f35SJohannes Reifferscheid 
116433e60f35SJohannes Reifferscheid       Value absReal = b.create<math::AbsFOp>(real, fmf);
116533e60f35SJohannes Reifferscheid       Value absImag = b.create<math::AbsFOp>(imag, fmf);
116633e60f35SJohannes Reifferscheid 
116733e60f35SJohannes Reifferscheid       Value absImagIsInf =
116833e60f35SJohannes Reifferscheid           b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf);
116933e60f35SJohannes Reifferscheid       Value realIsNan =
117033e60f35SJohannes Reifferscheid           b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real, fmf);
117133e60f35SJohannes Reifferscheid       Value realIsInf =
117233e60f35SJohannes Reifferscheid           b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf);
117333e60f35SJohannes Reifferscheid       Value inIsNanInf = b.create<arith::AndIOp>(absImagIsInf, realIsNan);
117433e60f35SJohannes Reifferscheid 
117533e60f35SJohannes Reifferscheid       Value resultIsZero = b.create<arith::OrIOp>(inIsNanInf, realIsInf);
117633e60f35SJohannes Reifferscheid 
117733e60f35SJohannes Reifferscheid       resultReal =
117833e60f35SJohannes Reifferscheid           b.create<arith::SelectOp>(resultIsZero, realSignedZero, resultReal);
117933e60f35SJohannes Reifferscheid       resultImag = b.create<arith::SelectOp>(resultIsZero, negImagSignedZero,
118033e60f35SJohannes Reifferscheid                                              resultImag);
118133e60f35SJohannes Reifferscheid     }
118233e60f35SJohannes Reifferscheid 
118333e60f35SJohannes Reifferscheid     Value isRealZero =
118433e60f35SJohannes Reifferscheid         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero, fmf);
118533e60f35SJohannes Reifferscheid     Value isImagZero =
118633e60f35SJohannes Reifferscheid         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf);
118733e60f35SJohannes Reifferscheid     Value isZero = b.create<arith::AndIOp>(isRealZero, isImagZero);
118833e60f35SJohannes Reifferscheid 
118933e60f35SJohannes Reifferscheid     resultReal = b.create<arith::SelectOp>(isZero, inf, resultReal);
119033e60f35SJohannes Reifferscheid     resultImag = b.create<arith::SelectOp>(isZero, nan, resultImag);
119133e60f35SJohannes Reifferscheid 
119233e60f35SJohannes Reifferscheid     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
119333e60f35SJohannes Reifferscheid                                                    resultImag);
11946c6eddb6Sbixia1     return success();
11956c6eddb6Sbixia1   }
11966c6eddb6Sbixia1 };
11976c6eddb6Sbixia1 
11988fa2e679SLewuathe struct AngleOpConversion : public OpConversionPattern<complex::AngleOp> {
11998fa2e679SLewuathe   using OpConversionPattern<complex::AngleOp>::OpConversionPattern;
12008fa2e679SLewuathe 
12018fa2e679SLewuathe   LogicalResult
12028fa2e679SLewuathe   matchAndRewrite(complex::AngleOp op, OpAdaptor adaptor,
12038fa2e679SLewuathe                   ConversionPatternRewriter &rewriter) const override {
12048fa2e679SLewuathe     auto loc = op.getLoc();
12058fa2e679SLewuathe     auto type = op.getType();
12068c9d814bSKai Sasaki     arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
12078fa2e679SLewuathe 
12088fa2e679SLewuathe     Value real =
12098fa2e679SLewuathe         rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex());
12108fa2e679SLewuathe     Value imag =
12118fa2e679SLewuathe         rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex());
12128fa2e679SLewuathe 
12138c9d814bSKai Sasaki     rewriter.replaceOpWithNewOp<math::Atan2Op>(op, imag, real, fmf);
12148fa2e679SLewuathe 
12158fa2e679SLewuathe     return success();
12168fa2e679SLewuathe   }
12178fa2e679SLewuathe };
12188fa2e679SLewuathe 
12192ea7fb7bSAdrian Kuegel } // namespace
12202ea7fb7bSAdrian Kuegel 
12212ea7fb7bSAdrian Kuegel void mlir::populateComplexToStandardConversionPatterns(
12222ea7fb7bSAdrian Kuegel     RewritePatternSet &patterns) {
1223f112bd61SAdrian Kuegel   // clang-format off
1224f112bd61SAdrian Kuegel   patterns.add<
1225f112bd61SAdrian Kuegel       AbsOpConversion,
12268fa2e679SLewuathe       AngleOpConversion,
1227f711785eSAlexander Belyaev       Atan2OpConversion,
1228a54f4eaeSMogball       BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
1229a54f4eaeSMogball       BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>,
123062a34f6aSlewuathe       ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
123162a34f6aSlewuathe       ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
123262a34f6aSlewuathe       ConjOpConversion,
1233672b908bSGoran Flegar       CosOpConversion,
1234f112bd61SAdrian Kuegel       DivOpConversion,
1235f112bd61SAdrian Kuegel       ExpOpConversion,
1236338e76f8Sbixia1       Expm1OpConversion,
12376e80e3bdSAdrian Kuegel       Log1pOpConversion,
123862a34f6aSlewuathe       LogOpConversion,
1239bf17ee19SAdrian Kuegel       MulOpConversion,
1240f112bd61SAdrian Kuegel       NegOpConversion,
1241672b908bSGoran Flegar       SignOpConversion,
12426d75c897Slewuathe       SinOpConversion,
1243f711785eSAlexander Belyaev       SqrtOpConversion,
1244f43deca2SJohannes Reifferscheid       TanTanhOpConversion<complex::TanOp>,
1245f43deca2SJohannes Reifferscheid       TanTanhOpConversion<complex::TanhOp>,
12466c6eddb6Sbixia1       PowOpConversion,
12476c6eddb6Sbixia1       RsqrtOpConversion
124862a34f6aSlewuathe   >(patterns.getContext());
1249f112bd61SAdrian Kuegel   // clang-format on
12502ea7fb7bSAdrian Kuegel }
12512ea7fb7bSAdrian Kuegel 
12522ea7fb7bSAdrian Kuegel namespace {
12532ea7fb7bSAdrian Kuegel struct ConvertComplexToStandardPass
125467d0d7acSMichele Scuttari     : public impl::ConvertComplexToStandardBase<ConvertComplexToStandardPass> {
125541574554SRiver Riddle   void runOnOperation() override;
12562ea7fb7bSAdrian Kuegel };
12572ea7fb7bSAdrian Kuegel 
125841574554SRiver Riddle void ConvertComplexToStandardPass::runOnOperation() {
12592ea7fb7bSAdrian Kuegel   // Convert to the Standard dialect using the converter defined above.
12602ea7fb7bSAdrian Kuegel   RewritePatternSet patterns(&getContext());
12612ea7fb7bSAdrian Kuegel   populateComplexToStandardConversionPatterns(patterns);
12622ea7fb7bSAdrian Kuegel 
12632ea7fb7bSAdrian Kuegel   ConversionTarget target(getContext());
1264abc362a1SJakub Kuderski   target.addLegalDialect<arith::ArithDialect, math::MathDialect>();
1265fb978f09SAdrian Kuegel   target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
126647f175b0SRiver Riddle   if (failed(
126747f175b0SRiver Riddle           applyPartialConversion(getOperation(), target, std::move(patterns))))
12682ea7fb7bSAdrian Kuegel     signalPassFailure();
12692ea7fb7bSAdrian Kuegel }
12702ea7fb7bSAdrian Kuegel } // namespace
1271039b969bSMichele Scuttari 
1272039b969bSMichele Scuttari std::unique_ptr<Pass> mlir::createConvertComplexToStandardPass() {
1273039b969bSMichele Scuttari   return std::make_unique<ConvertComplexToStandardPass>();
1274039b969bSMichele Scuttari }
1275