xref: /llvm-project/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp (revision bdd365825d0766b6991c8f5443f8a9f76e75011a)
1 //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
10 
11 #include "mlir/Dialect/Arith/IR/Arith.h"
12 #include "mlir/Dialect/Complex/IR/Complex.h"
13 #include "mlir/Dialect/Math/IR/Math.h"
14 #include "mlir/IR/ImplicitLocOpBuilder.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/Pass/Pass.h"
17 #include "mlir/Transforms/DialectConversion.h"
18 #include <memory>
19 #include <type_traits>
20 
21 namespace mlir {
22 #define GEN_PASS_DEF_CONVERTCOMPLEXTOSTANDARD
23 #include "mlir/Conversion/Passes.h.inc"
24 } // namespace mlir
25 
26 using namespace mlir;
27 
28 namespace {
29 
// Selects which function of the complex modulus |z| computeAbs emits.
enum class AbsFn {
  abs,   // |z|
  sqrt,  // sqrt(|z|)
  rsqrt, // 1 / sqrt(|z|)
};
31 
32 // Returns the absolute value, its square root or its reciprocal square root.
33 Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf,
34                  ImplicitLocOpBuilder &b, AbsFn fn = AbsFn::abs) {
35   Value one = b.create<arith::ConstantOp>(real.getType(),
36                                           b.getFloatAttr(real.getType(), 1.0));
37 
38   Value absReal = b.create<math::AbsFOp>(real, fmf);
39   Value absImag = b.create<math::AbsFOp>(imag, fmf);
40 
41   Value max = b.create<arith::MaximumFOp>(absReal, absImag, fmf);
42   Value min = b.create<arith::MinimumFOp>(absReal, absImag, fmf);
43 
44   // The lowering below requires NaNs and infinities to work correctly.
45   arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
46       fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
47   Value ratio = b.create<arith::DivFOp>(min, max, fmfWithNaNInf);
48   Value ratioSq = b.create<arith::MulFOp>(ratio, ratio, fmfWithNaNInf);
49   Value ratioSqPlusOne = b.create<arith::AddFOp>(ratioSq, one, fmfWithNaNInf);
50   Value result;
51 
52   if (fn == AbsFn::rsqrt) {
53     ratioSqPlusOne = b.create<math::RsqrtOp>(ratioSqPlusOne, fmfWithNaNInf);
54     min = b.create<math::RsqrtOp>(min, fmfWithNaNInf);
55     max = b.create<math::RsqrtOp>(max, fmfWithNaNInf);
56   }
57 
58   if (fn == AbsFn::sqrt) {
59     Value quarter = b.create<arith::ConstantOp>(
60         real.getType(), b.getFloatAttr(real.getType(), 0.25));
61     // sqrt(sqrt(a*b)) would avoid the pow, but will overflow more easily.
62     Value sqrt = b.create<math::SqrtOp>(max, fmfWithNaNInf);
63     Value p025 = b.create<math::PowFOp>(ratioSqPlusOne, quarter, fmfWithNaNInf);
64     result = b.create<arith::MulFOp>(sqrt, p025, fmfWithNaNInf);
65   } else {
66     Value sqrt = b.create<math::SqrtOp>(ratioSqPlusOne, fmfWithNaNInf);
67     result = b.create<arith::MulFOp>(max, sqrt, fmfWithNaNInf);
68   }
69 
70   Value isNaN = b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, result,
71                                         result, fmfWithNaNInf);
72   return b.create<arith::SelectOp>(isNaN, min, result);
73 }
74 
75 struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
76   using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
77 
78   LogicalResult
79   matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
80                   ConversionPatternRewriter &rewriter) const override {
81     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
82 
83     arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
84 
85     Value real = b.create<complex::ReOp>(adaptor.getComplex());
86     Value imag = b.create<complex::ImOp>(adaptor.getComplex());
87     rewriter.replaceOp(op, computeAbs(real, imag, fmf, b));
88 
89     return success();
90   }
91 };
92 
93 // atan2(y,x) = -i * log((x + i * y)/sqrt(x**2+y**2))
94 struct Atan2OpConversion : public OpConversionPattern<complex::Atan2Op> {
95   using OpConversionPattern<complex::Atan2Op>::OpConversionPattern;
96 
97   LogicalResult
98   matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor,
99                   ConversionPatternRewriter &rewriter) const override {
100     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
101 
102     auto type = cast<ComplexType>(op.getType());
103     Type elementType = type.getElementType();
104     arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
105 
106     Value lhs = adaptor.getLhs();
107     Value rhs = adaptor.getRhs();
108 
109     Value rhsSquared = b.create<complex::MulOp>(type, rhs, rhs, fmf);
110     Value lhsSquared = b.create<complex::MulOp>(type, lhs, lhs, fmf);
111     Value rhsSquaredPlusLhsSquared =
112         b.create<complex::AddOp>(type, rhsSquared, lhsSquared, fmf);
113     Value sqrtOfRhsSquaredPlusLhsSquared =
114         b.create<complex::SqrtOp>(type, rhsSquaredPlusLhsSquared, fmf);
115 
116     Value zero =
117         b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
118     Value one = b.create<arith::ConstantOp>(elementType,
119                                             b.getFloatAttr(elementType, 1));
120     Value i = b.create<complex::CreateOp>(type, zero, one);
121     Value iTimesLhs = b.create<complex::MulOp>(i, lhs, fmf);
122     Value rhsPlusILhs = b.create<complex::AddOp>(rhs, iTimesLhs, fmf);
123 
124     Value divResult = b.create<complex::DivOp>(
125         rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared, fmf);
126     Value logResult = b.create<complex::LogOp>(divResult, fmf);
127 
128     Value negativeOne = b.create<arith::ConstantOp>(
129         elementType, b.getFloatAttr(elementType, -1));
130     Value negativeI = b.create<complex::CreateOp>(type, zero, negativeOne);
131 
132     rewriter.replaceOpWithNewOp<complex::MulOp>(op, negativeI, logResult, fmf);
133     return success();
134   }
135 };
136 
137 template <typename ComparisonOp, arith::CmpFPredicate p>
138 struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
139   using OpConversionPattern<ComparisonOp>::OpConversionPattern;
140   using ResultCombiner =
141       std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value,
142                          arith::AndIOp, arith::OrIOp>;
143 
144   LogicalResult
145   matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor,
146                   ConversionPatternRewriter &rewriter) const override {
147     auto loc = op.getLoc();
148     auto type = cast<ComplexType>(adaptor.getLhs().getType()).getElementType();
149 
150     Value realLhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getLhs());
151     Value imagLhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getLhs());
152     Value realRhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getRhs());
153     Value imagRhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getRhs());
154     Value realComparison =
155         rewriter.create<arith::CmpFOp>(loc, p, realLhs, realRhs);
156     Value imagComparison =
157         rewriter.create<arith::CmpFOp>(loc, p, imagLhs, imagRhs);
158 
159     rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison,
160                                                 imagComparison);
161     return success();
162   }
163 };
164 
165 // Default conversion which applies the BinaryStandardOp separately on the real
166 // and imaginary parts. Can for example be used for complex::AddOp and
167 // complex::SubOp.
168 template <typename BinaryComplexOp, typename BinaryStandardOp>
169 struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
170   using OpConversionPattern<BinaryComplexOp>::OpConversionPattern;
171 
172   LogicalResult
173   matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor,
174                   ConversionPatternRewriter &rewriter) const override {
175     auto type = cast<ComplexType>(adaptor.getLhs().getType());
176     auto elementType = cast<FloatType>(type.getElementType());
177     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
178     arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
179 
180     Value realLhs = b.create<complex::ReOp>(elementType, adaptor.getLhs());
181     Value realRhs = b.create<complex::ReOp>(elementType, adaptor.getRhs());
182     Value resultReal = b.create<BinaryStandardOp>(elementType, realLhs, realRhs,
183                                                   fmf.getValue());
184     Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.getLhs());
185     Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.getRhs());
186     Value resultImag = b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs,
187                                                   fmf.getValue());
188     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
189                                                    resultImag);
190     return success();
191   }
192 };
193 
194 template <typename TrigonometricOp>
195 struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
196   using OpAdaptor = typename OpConversionPattern<TrigonometricOp>::OpAdaptor;
197 
198   using OpConversionPattern<TrigonometricOp>::OpConversionPattern;
199 
200   LogicalResult
201   matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor,
202                   ConversionPatternRewriter &rewriter) const override {
203     auto loc = op.getLoc();
204     auto type = cast<ComplexType>(adaptor.getComplex().getType());
205     auto elementType = cast<FloatType>(type.getElementType());
206     arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
207 
208     Value real =
209         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
210     Value imag =
211         rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
212 
213     // Trigonometric ops use a set of common building blocks to convert to real
214     // ops. Here we create these building blocks and call into an op-specific
215     // implementation in the subclass to combine them.
216     Value half = rewriter.create<arith::ConstantOp>(
217         loc, elementType, rewriter.getFloatAttr(elementType, 0.5));
218     Value exp = rewriter.create<math::ExpOp>(loc, imag, fmf);
219     Value scaledExp = rewriter.create<arith::MulFOp>(loc, half, exp, fmf);
220     Value reciprocalExp = rewriter.create<arith::DivFOp>(loc, half, exp, fmf);
221     Value sin = rewriter.create<math::SinOp>(loc, real, fmf);
222     Value cos = rewriter.create<math::CosOp>(loc, real, fmf);
223 
224     auto resultPair =
225         combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter, fmf);
226 
227     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultPair.first,
228                                                    resultPair.second);
229     return success();
230   }
231 
232   virtual std::pair<Value, Value>
233   combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
234           Value cos, ConversionPatternRewriter &rewriter,
235           arith::FastMathFlagsAttr fmf) const = 0;
236 };
237 
238 struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> {
239   using TrigonometricOpConversion<complex::CosOp>::TrigonometricOpConversion;
240 
241   std::pair<Value, Value> combine(Location loc, Value scaledExp,
242                                   Value reciprocalExp, Value sin, Value cos,
243                                   ConversionPatternRewriter &rewriter,
244                                   arith::FastMathFlagsAttr fmf) const override {
245     // Complex cosine is defined as;
246     //   cos(x + iy) = 0.5 * (exp(i(x + iy)) + exp(-i(x + iy)))
247     // Plugging in:
248     //   exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
249     //   exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
250     // and defining t := exp(y)
251     // We get:
252     //   Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x
253     //   Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x
254     Value sum =
255         rewriter.create<arith::AddFOp>(loc, reciprocalExp, scaledExp, fmf);
256     Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, cos, fmf);
257     Value diff =
258         rewriter.create<arith::SubFOp>(loc, reciprocalExp, scaledExp, fmf);
259     Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, sin, fmf);
260     return {resultReal, resultImag};
261   }
262 };
263 
// Lowers complex.div to real (arith/math) arithmetic.
//
// The ordinary quotient is computed with Smith's algorithm: both scaling
// variants are emitted and one is chosen by comparing |rhsReal| with
// |rhsImag|. Only when BOTH components of that quotient are NaN is the
// result replaced by the special-case ladder below (zero divisor, infinite
// numerator over finite divisor, finite numerator over infinite divisor).
// NOTE(review): the ladder appears to mirror the recovery logic of the
// C99 Annex G (G.5.1) reference division — confirm before relying on that.
struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
  using OpConversionPattern<complex::DivOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto type = cast<ComplexType>(adaptor.getLhs().getType());
    auto elementType = cast<FloatType>(type.getElementType());
    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();

    Value lhsReal =
        rewriter.create<complex::ReOp>(loc, elementType, adaptor.getLhs());
    Value lhsImag =
        rewriter.create<complex::ImOp>(loc, elementType, adaptor.getLhs());
    Value rhsReal =
        rewriter.create<complex::ReOp>(loc, elementType, adaptor.getRhs());
    Value rhsImag =
        rewriter.create<complex::ImOp>(loc, elementType, adaptor.getRhs());

    // Smith's algorithm to divide complex numbers. It is just a bit smarter
    // way to compute the following formula:
    //  (lhsReal + lhsImag * i) / (rhsReal + rhsImag * i)
    //    = (lhsReal + lhsImag * i) (rhsReal - rhsImag * i) /
    //          ((rhsReal + rhsImag * i)(rhsReal - rhsImag * i))
    //    = ((lhsReal * rhsReal + lhsImag * rhsImag) +
    //          (lhsImag * rhsReal - lhsReal * rhsImag) * i) / ||rhs||^2
    //
    // Dividing numerator and denominator by the larger-magnitude component
    // of rhs avoids forming ||rhs||^2, which can overflow or underflow.
    // Both variants are emitted below; the one matching |rhsReal| < |rhsImag|
    // is selected afterwards:
    //   rhsRealImagRatio = rhsReal / rhsImag
    //   rhsRealImagDenom = rhsImag + rhsReal * rhsRealImagRatio
    //   resultReal = (lhsReal * rhsRealImagRatio + lhsImag) / rhsRealImagDenom
    //   resultImag = (lhsImag * rhsRealImagRatio - lhsReal) / rhsRealImagDenom
    //
    // or
    //
    //   rhsImagRealRatio = rhsImag / rhsReal
    //   rhsImagRealDenom = rhsReal + rhsImag * rhsImagRealRatio
    //   resultReal = (lhsReal + lhsImag * rhsImagRealRatio) / rhsImagRealDenom
    //   resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom
    //
    // See https://dl.acm.org/citation.cfm?id=368661 for more details.

    // Variant 1: scale by rhsImag (used when |rhsReal| < |rhsImag|).
    Value rhsRealImagRatio =
        rewriter.create<arith::DivFOp>(loc, rhsReal, rhsImag, fmf);
    Value rhsRealImagDenom = rewriter.create<arith::AddFOp>(
        loc, rhsImag,
        rewriter.create<arith::MulFOp>(loc, rhsRealImagRatio, rhsReal, fmf),
        fmf);
    Value realNumerator1 = rewriter.create<arith::AddFOp>(
        loc,
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealImagRatio, fmf),
        lhsImag, fmf);
    Value resultReal1 = rewriter.create<arith::DivFOp>(loc, realNumerator1,
                                                       rhsRealImagDenom, fmf);
    Value imagNumerator1 = rewriter.create<arith::SubFOp>(
        loc,
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealImagRatio, fmf),
        lhsReal, fmf);
    Value resultImag1 = rewriter.create<arith::DivFOp>(loc, imagNumerator1,
                                                       rhsRealImagDenom, fmf);

    // Variant 2: scale by rhsReal (used otherwise).
    Value rhsImagRealRatio =
        rewriter.create<arith::DivFOp>(loc, rhsImag, rhsReal, fmf);
    Value rhsImagRealDenom = rewriter.create<arith::AddFOp>(
        loc, rhsReal,
        rewriter.create<arith::MulFOp>(loc, rhsImagRealRatio, rhsImag, fmf),
        fmf);
    Value realNumerator2 = rewriter.create<arith::AddFOp>(
        loc, lhsReal,
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagRealRatio, fmf),
        fmf);
    Value resultReal2 = rewriter.create<arith::DivFOp>(loc, realNumerator2,
                                                       rhsImagRealDenom, fmf);
    Value imagNumerator2 = rewriter.create<arith::SubFOp>(
        loc, lhsImag,
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagRealRatio, fmf),
        fmf);
    Value resultImag2 = rewriter.create<arith::DivFOp>(loc, imagNumerator2,
                                                       rhsImagRealDenom, fmf);

    // Consider corner cases. They only take effect when the Smith quotient
    // has NaN in both components (see resultIsNaN at the end).
    // Case 1. Zero denominator, numerator contains at most one NaN value.
    // Result: copysign(inf, rhsReal) * lhs, component-wise.
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, elementType, rewriter.getZeroAttr(elementType));
    Value rhsRealAbs = rewriter.create<math::AbsFOp>(loc, rhsReal, fmf);
    Value rhsRealIsZero = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero);
    Value rhsImagAbs = rewriter.create<math::AbsFOp>(loc, rhsImag, fmf);
    Value rhsImagIsZero = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero);
    // ORD against a non-NaN constant is true exactly when the operand is not
    // NaN.
    Value lhsRealIsNotNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ORD, lhsReal, zero);
    Value lhsImagIsNotNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ORD, lhsImag, zero);
    Value lhsContainsNotNaNValue =
        rewriter.create<arith::OrIOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
    Value resultIsInfinity = rewriter.create<arith::AndIOp>(
        loc, lhsContainsNotNaNValue,
        rewriter.create<arith::AndIOp>(loc, rhsRealIsZero, rhsImagIsZero));
    Value inf = rewriter.create<arith::ConstantOp>(
        loc, elementType,
        rewriter.getFloatAttr(
            elementType, APFloat::getInf(elementType.getFloatSemantics())));
    Value infWithSignOfRhsReal =
        rewriter.create<math::CopySignOp>(loc, inf, rhsReal);
    Value infinityResultReal =
        rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsReal, fmf);
    Value infinityResultImag =
        rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsImag, fmf);

    // Case 2. Infinite numerator, finite denominator.
    // Each infinite lhs component is replaced by +-1 (0 if finite), keeping
    // its sign, and the "boxed" numerator is multiplied by conj(rhs) and
    // scaled by inf.
    Value rhsRealFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf);
    Value rhsImagFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf);
    Value rhsFinite =
        rewriter.create<arith::AndIOp>(loc, rhsRealFinite, rhsImagFinite);
    Value lhsRealAbs = rewriter.create<math::AbsFOp>(loc, lhsReal, fmf);
    Value lhsRealInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
    Value lhsImagAbs = rewriter.create<math::AbsFOp>(loc, lhsImag, fmf);
    Value lhsImagInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
    Value lhsInfinite =
        rewriter.create<arith::OrIOp>(loc, lhsRealInfinite, lhsImagInfinite);
    Value infNumFiniteDenom =
        rewriter.create<arith::AndIOp>(loc, lhsInfinite, rhsFinite);
    Value one = rewriter.create<arith::ConstantOp>(
        loc, elementType, rewriter.getFloatAttr(elementType, 1));
    Value lhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, lhsRealInfinite, one, zero),
        lhsReal);
    Value lhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, lhsImagInfinite, one, zero),
        lhsImag);
    Value lhsRealIsInfWithSignTimesRhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsReal, fmf);
    Value lhsImagIsInfWithSignTimesRhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsImag, fmf);
    // resultReal3 = inf * (a'*c + b'*d)
    Value resultReal3 = rewriter.create<arith::MulFOp>(
        loc, inf,
        rewriter.create<arith::AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal,
                                       lhsImagIsInfWithSignTimesRhsImag, fmf),
        fmf);
    Value lhsRealIsInfWithSignTimesRhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsImag, fmf);
    Value lhsImagIsInfWithSignTimesRhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsReal, fmf);
    // resultImag3 = inf * (b'*c - a'*d)
    Value resultImag3 = rewriter.create<arith::MulFOp>(
        loc, inf,
        rewriter.create<arith::SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal,
                                       lhsRealIsInfWithSignTimesRhsImag, fmf),
        fmf);

    // Case 3: Finite numerator, infinite denominator.
    // Mirror of case 2: rhs components are boxed to +-1/0 and the product with
    // lhs is scaled by 0, producing correctly signed zeros.
    Value lhsRealFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf);
    Value lhsImagFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf);
    Value lhsFinite =
        rewriter.create<arith::AndIOp>(loc, lhsRealFinite, lhsImagFinite);
    Value rhsRealInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
    Value rhsImagInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
    Value rhsInfinite =
        rewriter.create<arith::OrIOp>(loc, rhsRealInfinite, rhsImagInfinite);
    Value finiteNumInfiniteDenom =
        rewriter.create<arith::AndIOp>(loc, lhsFinite, rhsInfinite);
    Value rhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, rhsRealInfinite, one, zero),
        rhsReal);
    Value rhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, rhsImagInfinite, one, zero),
        rhsImag);
    Value rhsRealIsInfWithSignTimesLhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealIsInfWithSign, fmf);
    Value rhsImagIsInfWithSignTimesLhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagIsInfWithSign, fmf);
    // resultReal4 = 0 * (a*c' + b*d')
    Value resultReal4 = rewriter.create<arith::MulFOp>(
        loc, zero,
        rewriter.create<arith::AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal,
                                       rhsImagIsInfWithSignTimesLhsImag, fmf),
        fmf);
    Value rhsRealIsInfWithSignTimesLhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealIsInfWithSign, fmf);
    Value rhsImagIsInfWithSignTimesLhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagIsInfWithSign, fmf);
    // resultImag4 = 0 * (b*c' - a*d')
    Value resultImag4 = rewriter.create<arith::MulFOp>(
        loc, zero,
        rewriter.create<arith::SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag,
                                       rhsImagIsInfWithSignTimesLhsReal, fmf),
        fmf);

    // Pick the Smith variant that divides by the larger-magnitude component.
    Value realAbsSmallerThanImagAbs = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);
    Value resultReal = rewriter.create<arith::SelectOp>(
        loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2);
    Value resultImag = rewriter.create<arith::SelectOp>(
        loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2);
    // Build the special-case ladder innermost-first, so that precedence is
    // case 1 > case 2 > case 3 > Smith result.
    Value resultRealSpecialCase3 = rewriter.create<arith::SelectOp>(
        loc, finiteNumInfiniteDenom, resultReal4, resultReal);
    Value resultImagSpecialCase3 = rewriter.create<arith::SelectOp>(
        loc, finiteNumInfiniteDenom, resultImag4, resultImag);
    Value resultRealSpecialCase2 = rewriter.create<arith::SelectOp>(
        loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
    Value resultImagSpecialCase2 = rewriter.create<arith::SelectOp>(
        loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
    Value resultRealSpecialCase1 = rewriter.create<arith::SelectOp>(
        loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2);
    Value resultImagSpecialCase1 = rewriter.create<arith::SelectOp>(
        loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2);

    // The ladder replaces the Smith result only if both of its components
    // are NaN.
    Value resultRealIsNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::UNO, resultReal, zero);
    Value resultImagIsNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::UNO, resultImag, zero);
    Value resultIsNaN =
        rewriter.create<arith::AndIOp>(loc, resultRealIsNaN, resultImagIsNaN);
    Value resultRealWithSpecialCases = rewriter.create<arith::SelectOp>(
        loc, resultIsNaN, resultRealSpecialCase1, resultReal);
    Value resultImagWithSpecialCases = rewriter.create<arith::SelectOp>(
        loc, resultIsNaN, resultImagSpecialCase1, resultImag);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(
        op, type, resultRealWithSpecialCases, resultImagWithSpecialCases);
    return success();
  }
};
493 
494 struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
495   using OpConversionPattern<complex::ExpOp>::OpConversionPattern;
496 
497   LogicalResult
498   matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
499                   ConversionPatternRewriter &rewriter) const override {
500     auto loc = op.getLoc();
501     auto type = cast<ComplexType>(adaptor.getComplex().getType());
502     auto elementType = cast<FloatType>(type.getElementType());
503     arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
504 
505     Value real =
506         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
507     Value imag =
508         rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
509     Value expReal = rewriter.create<math::ExpOp>(loc, real, fmf.getValue());
510     Value cosImag = rewriter.create<math::CosOp>(loc, imag, fmf.getValue());
511     Value resultReal =
512         rewriter.create<arith::MulFOp>(loc, expReal, cosImag, fmf.getValue());
513     Value sinImag = rewriter.create<math::SinOp>(loc, imag, fmf.getValue());
514     Value resultImag =
515         rewriter.create<arith::MulFOp>(loc, expReal, sinImag, fmf.getValue());
516 
517     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
518                                                    resultImag);
519     return success();
520   }
521 };
522 
523 Value evaluatePolynomial(ImplicitLocOpBuilder &b, Value arg,
524                          ArrayRef<double> coefficients,
525                          arith::FastMathFlagsAttr fmf) {
526   auto argType = mlir::cast<FloatType>(arg.getType());
527   Value poly =
528       b.create<arith::ConstantOp>(b.getFloatAttr(argType, coefficients[0]));
529   for (unsigned i = 1; i < coefficients.size(); ++i) {
530     poly = b.create<math::FmaOp>(
531         poly, arg,
532         b.create<arith::ConstantOp>(b.getFloatAttr(argType, coefficients[i])),
533         fmf);
534   }
535   return poly;
536 }
537 
538 struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> {
539   using OpConversionPattern<complex::Expm1Op>::OpConversionPattern;
540 
541   // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i
542   //            [handle inaccuracies when a and/or b are small]
543   //            = ((e^a - 1) * cos(b) + cos(b) - 1) + e^a*sin(b)i
544   //            = (expm1(a) * cos(b) + cosm1(b)) + e^a*sin(b)i
545   LogicalResult
546   matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor,
547                   ConversionPatternRewriter &rewriter) const override {
548     auto type = op.getType();
549     auto elemType = mlir::cast<FloatType>(type.getElementType());
550 
551     arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
552     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
553     Value real = b.create<complex::ReOp>(adaptor.getComplex());
554     Value imag = b.create<complex::ImOp>(adaptor.getComplex());
555 
556     Value zero = b.create<arith::ConstantOp>(b.getFloatAttr(elemType, 0.0));
557     Value one = b.create<arith::ConstantOp>(b.getFloatAttr(elemType, 1.0));
558 
559     Value expm1Real = b.create<math::ExpM1Op>(real, fmf);
560     Value expReal = b.create<arith::AddFOp>(expm1Real, one, fmf);
561 
562     Value sinImag = b.create<math::SinOp>(imag, fmf);
563     Value cosm1Imag = emitCosm1(imag, fmf, b);
564     Value cosImag = b.create<arith::AddFOp>(cosm1Imag, one, fmf);
565 
566     Value realResult = b.create<arith::AddFOp>(
567         b.create<arith::MulFOp>(expm1Real, cosImag, fmf), cosm1Imag, fmf);
568 
569     Value imagIsZero = b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag,
570                                                zero, fmf.getValue());
571     Value imagResult = b.create<arith::SelectOp>(
572         imagIsZero, zero, b.create<arith::MulFOp>(expReal, sinImag, fmf));
573 
574     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realResult,
575                                                    imagResult);
576     return success();
577   }
578 
579 private:
580   Value emitCosm1(Value arg, arith::FastMathFlagsAttr fmf,
581                   ImplicitLocOpBuilder &b) const {
582     auto argType = mlir::cast<FloatType>(arg.getType());
583     auto negHalf = b.create<arith::ConstantOp>(b.getFloatAttr(argType, -0.5));
584     auto negOne = b.create<arith::ConstantOp>(b.getFloatAttr(argType, -1.0));
585 
586     // Algorithm copied from cephes cosm1.
587     SmallVector<double, 7> kCoeffs{
588         4.7377507964246204691685E-14, -1.1470284843425359765671E-11,
589         2.0876754287081521758361E-9,  -2.7557319214999787979814E-7,
590         2.4801587301570552304991E-5,  -1.3888888888888872993737E-3,
591         4.1666666666666666609054E-2,
592     };
593     Value cos = b.create<math::CosOp>(arg, fmf);
594     Value forLargeArg = b.create<arith::AddFOp>(cos, negOne, fmf);
595 
596     Value argPow2 = b.create<arith::MulFOp>(arg, arg, fmf);
597     Value argPow4 = b.create<arith::MulFOp>(argPow2, argPow2, fmf);
598     Value poly = evaluatePolynomial(b, argPow2, kCoeffs, fmf);
599 
600     auto forSmallArg =
601         b.create<arith::AddFOp>(b.create<arith::MulFOp>(argPow4, poly, fmf),
602                                 b.create<arith::MulFOp>(negHalf, argPow2, fmf));
603 
604     // (pi/4)^2 is approximately 0.61685
605     Value piOver4Pow2 =
606         b.create<arith::ConstantOp>(b.getFloatAttr(argType, 0.61685));
607     Value cond = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, argPow2,
608                                          piOver4Pow2, fmf.getValue());
609     return b.create<arith::SelectOp>(cond, forLargeArg, forSmallArg);
610   }
611 };
612 
613 struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
614   using OpConversionPattern<complex::LogOp>::OpConversionPattern;
615 
616   LogicalResult
617   matchAndRewrite(complex::LogOp op, OpAdaptor adaptor,
618                   ConversionPatternRewriter &rewriter) const override {
619     auto type = cast<ComplexType>(adaptor.getComplex().getType());
620     auto elementType = cast<FloatType>(type.getElementType());
621     arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
622     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
623 
624     Value abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex(),
625                                          fmf.getValue());
626     Value resultReal = b.create<math::LogOp>(elementType, abs, fmf.getValue());
627     Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
628     Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
629     Value resultImag =
630         b.create<math::Atan2Op>(elementType, imag, real, fmf.getValue());
631     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
632                                                    resultImag);
633     return success();
634   }
635 };
636 
637 struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
638   using OpConversionPattern<complex::Log1pOp>::OpConversionPattern;
639 
640   LogicalResult
641   matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor,
642                   ConversionPatternRewriter &rewriter) const override {
643     auto type = cast<ComplexType>(adaptor.getComplex().getType());
644     auto elementType = cast<FloatType>(type.getElementType());
645     arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
646     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
647 
648     Value real = b.create<complex::ReOp>(adaptor.getComplex());
649     Value imag = b.create<complex::ImOp>(adaptor.getComplex());
650 
651     Value half = b.create<arith::ConstantOp>(elementType,
652                                              b.getFloatAttr(elementType, 0.5));
653     Value one = b.create<arith::ConstantOp>(elementType,
654                                             b.getFloatAttr(elementType, 1));
655     Value realPlusOne = b.create<arith::AddFOp>(real, one, fmf);
656     Value absRealPlusOne = b.create<math::AbsFOp>(realPlusOne, fmf);
657     Value absImag = b.create<math::AbsFOp>(imag, fmf);
658 
659     Value maxAbs = b.create<arith::MaximumFOp>(absRealPlusOne, absImag, fmf);
660     Value minAbs = b.create<arith::MinimumFOp>(absRealPlusOne, absImag, fmf);
661 
662     Value useReal = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT,
663                                             realPlusOne, absImag, fmf);
664     Value maxMinusOne = b.create<arith::SubFOp>(maxAbs, one, fmf);
665     Value maxAbsOfRealPlusOneAndImagMinusOne =
666         b.create<arith::SelectOp>(useReal, real, maxMinusOne);
667     arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
668         fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
669     Value minMaxRatio = b.create<arith::DivFOp>(minAbs, maxAbs, fmfWithNaNInf);
670     Value logOfMaxAbsOfRealPlusOneAndImag =
671         b.create<math::Log1pOp>(maxAbsOfRealPlusOneAndImagMinusOne, fmf);
672     Value logOfSqrtPart = b.create<math::Log1pOp>(
673         b.create<arith::MulFOp>(minMaxRatio, minMaxRatio, fmfWithNaNInf),
674         fmfWithNaNInf);
675     Value r = b.create<arith::AddFOp>(
676         b.create<arith::MulFOp>(half, logOfSqrtPart, fmfWithNaNInf),
677         logOfMaxAbsOfRealPlusOneAndImag, fmfWithNaNInf);
678     Value resultReal = b.create<arith::SelectOp>(
679         b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, r, r, fmfWithNaNInf),
680         minAbs, r);
681     Value resultImag = b.create<math::Atan2Op>(imag, realPlusOne, fmf);
682     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
683                                                    resultImag);
684     return success();
685   }
686 };
687 
688 struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
689   using OpConversionPattern<complex::MulOp>::OpConversionPattern;
690 
691   LogicalResult
692   matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
693                   ConversionPatternRewriter &rewriter) const override {
694     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
695     auto type = cast<ComplexType>(adaptor.getLhs().getType());
696     auto elementType = cast<FloatType>(type.getElementType());
697     arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
698     auto fmfValue = fmf.getValue();
699     Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.getLhs());
700     Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.getLhs());
701     Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.getRhs());
702     Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.getRhs());
703     Value lhsRealTimesRhsReal =
704         b.create<arith::MulFOp>(lhsReal, rhsReal, fmfValue);
705     Value lhsImagTimesRhsImag =
706         b.create<arith::MulFOp>(lhsImag, rhsImag, fmfValue);
707     Value real = b.create<arith::SubFOp>(lhsRealTimesRhsReal,
708                                          lhsImagTimesRhsImag, fmfValue);
709     Value lhsImagTimesRhsReal =
710         b.create<arith::MulFOp>(lhsImag, rhsReal, fmfValue);
711     Value lhsRealTimesRhsImag =
712         b.create<arith::MulFOp>(lhsReal, rhsImag, fmfValue);
713     Value imag = b.create<arith::AddFOp>(lhsImagTimesRhsReal,
714                                          lhsRealTimesRhsImag, fmfValue);
715     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
716     return success();
717   }
718 };
719 
720 struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
721   using OpConversionPattern<complex::NegOp>::OpConversionPattern;
722 
723   LogicalResult
724   matchAndRewrite(complex::NegOp op, OpAdaptor adaptor,
725                   ConversionPatternRewriter &rewriter) const override {
726     auto loc = op.getLoc();
727     auto type = cast<ComplexType>(adaptor.getComplex().getType());
728     auto elementType = cast<FloatType>(type.getElementType());
729 
730     Value real =
731         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
732     Value imag =
733         rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
734     Value negReal = rewriter.create<arith::NegFOp>(loc, real);
735     Value negImag = rewriter.create<arith::NegFOp>(loc, imag);
736     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag);
737     return success();
738   }
739 };
740 
741 struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
742   using TrigonometricOpConversion<complex::SinOp>::TrigonometricOpConversion;
743 
744   std::pair<Value, Value> combine(Location loc, Value scaledExp,
745                                   Value reciprocalExp, Value sin, Value cos,
746                                   ConversionPatternRewriter &rewriter,
747                                   arith::FastMathFlagsAttr fmf) const override {
748     // Complex sine is defined as;
749     //   sin(x + iy) = -0.5i * (exp(i(x + iy)) - exp(-i(x + iy)))
750     // Plugging in:
751     //   exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
752     //   exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
753     // and defining t := exp(y)
754     // We get:
755     //   Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x
756     //   Im(cos(x + iy)) = (0.5*t - 0.5/t) * cos x
757     Value sum =
758         rewriter.create<arith::AddFOp>(loc, scaledExp, reciprocalExp, fmf);
759     Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, sin, fmf);
760     Value diff =
761         rewriter.create<arith::SubFOp>(loc, scaledExp, reciprocalExp, fmf);
762     Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, cos, fmf);
763     return {resultReal, resultImag};
764   }
765 };
766 
// The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
/// Lowers `complex.sqrt` via the polar form
///   sqrt(z) = sqrt(|z|) * (cos(arg(z) / 2) + i * sin(arg(z) / 2)),
/// where arg(z) = atan2(Im(z), Re(z)) and sqrt(|z|) comes from computeAbs with
/// AbsFn::sqrt. Unless both `nnan` and `ninf` are set, extra selects patch the
/// infinite/NaN inputs. Finally, whenever sqrt(|z|) == 0 both result parts are
/// forced to +0.
struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
  using OpConversionPattern<complex::SqrtOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    auto type = cast<ComplexType>(op.getType());
    auto elementType = cast<FloatType>(type.getElementType());
    arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();

    // Materializes an element-typed float constant.
    auto cst = [&](APFloat v) {
      return b.create<arith::ConstantOp>(elementType,
                                         b.getFloatAttr(elementType, v));
    };
    const auto &floatSemantics = elementType.getFloatSemantics();
    Value zero = cst(APFloat::getZero(floatSemantics));
    Value half = b.create<arith::ConstantOp>(elementType,
                                             b.getFloatAttr(elementType, 0.5));

    Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
    Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
    Value absSqrt = computeAbs(real, imag, fmf, b, AbsFn::sqrt);
    // Half of the argument of z.
    Value argArg = b.create<math::Atan2Op>(imag, real, fmf);
    Value sqrtArg = b.create<arith::MulFOp>(argArg, half, fmf);
    Value cos = b.create<math::CosOp>(sqrtArg, fmf);
    Value sin = b.create<math::SinOp>(sqrtArg, fmf);
    // sin(atan2(0, inf)) = 0, sqrt(abs(inf)) = inf, but we can't multiply
    // 0 * inf.
    // NOTE(review): selecting +0 here drops the sign of a -0 sine, so e.g.
    // 4 - 0i yields an imaginary part of +0 rather than -0 (C Annex G has
    // csqrt(conj(z)) == conj(csqrt(z))). Confirm whether this is intended.
    Value sinIsZero =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, sin, zero, fmf);

    Value resultReal = b.create<arith::MulFOp>(absSqrt, cos, fmf);
    Value resultImag = b.create<arith::SelectOp>(
        sinIsZero, zero, b.create<arith::MulFOp>(absSqrt, sin, fmf));
    // Special values, only emitted when NaNs or infinities may occur.
    if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
                                            arith::FastMathFlags::ninf)) {
      Value inf = cst(APFloat::getInf(floatSemantics));
      Value negInf = cst(APFloat::getInf(floatSemantics, true));
      Value nan = cst(APFloat::getNaN(floatSemantics));
      Value absImag = b.create<math::AbsFOp>(elementType, imag, fmf);

      Value absImagIsInf =
          b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf);
      // Ordered: false for a NaN imaginary part as well as for infinity.
      Value absImagIsNotInf =
          b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, absImag, inf, fmf);
      Value realIsInf =
          b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, inf, fmf);
      Value realIsNegInf =
          b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, negInf, fmf);

      // Re(z) == -inf with a non-infinite, non-NaN Im(z): real part is 0.
      resultReal = b.create<arith::SelectOp>(
          b.create<arith::AndIOp>(realIsNegInf, absImagIsNotInf), zero,
          resultReal);
      // |Im(z)| == inf or Re(z) == +inf: real part is +inf.
      resultReal = b.create<arith::SelectOp>(
          b.create<arith::OrIOp>(absImagIsInf, realIsInf), inf, resultReal);

      Value imagSignInf = b.create<math::CopySignOp>(inf, imag, fmf);
      // NaN modulus: imaginary part is NaN.
      resultImag = b.create<arith::SelectOp>(
          b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, absSqrt, absSqrt),
          nan, resultImag);
      // |Im(z)| == inf or Re(z) == -inf: imaginary part is inf with the sign
      // of Im(z).
      resultImag = b.create<arith::SelectOp>(
          b.create<arith::OrIOp>(absImagIsInf, realIsNegInf), imagSignInf,
          resultImag);
    }

    // Zero modulus: return +0 + 0i regardless of the trigonometric terms.
    Value resultIsZero =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absSqrt, zero, fmf);
    resultReal = b.create<arith::SelectOp>(resultIsZero, zero, resultReal);
    resultImag = b.create<arith::SelectOp>(resultIsZero, zero, resultImag);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
                                                   resultImag);
    return success();
  }
};
845 
846 struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
847   using OpConversionPattern<complex::SignOp>::OpConversionPattern;
848 
849   LogicalResult
850   matchAndRewrite(complex::SignOp op, OpAdaptor adaptor,
851                   ConversionPatternRewriter &rewriter) const override {
852     auto type = cast<ComplexType>(adaptor.getComplex().getType());
853     auto elementType = cast<FloatType>(type.getElementType());
854     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
855     arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
856 
857     Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
858     Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
859     Value zero =
860         b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
861     Value realIsZero =
862         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
863     Value imagIsZero =
864         b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
865     Value isZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
866     auto abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex(), fmf);
867     Value realSign = b.create<arith::DivFOp>(real, abs, fmf);
868     Value imagSign = b.create<arith::DivFOp>(imag, abs, fmf);
869     Value sign = b.create<complex::CreateOp>(type, realSign, imagSign);
870     rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isZero,
871                                                  adaptor.getComplex(), sign);
872     return success();
873   }
874 };
875 
/// Lowers `complex.tanh`, and `complex.tan` via tan(z) = -i * tanh(i * z).
///
/// For w = x + iy:
///   tanh(w) = (sinh(2x) + i sin(2y)) / (cosh(2x) + cos(2y))
/// Numerator and denominator are both scaled by 2 and rewritten with expm1,
/// which keeps precision for small |x|:
///   2 sinh(2x)             = expm1(2x) - expm1(-2x)
///   2 sin(2y)              = 4 cos(y) sin(y)
///   2 cosh(2x) + 2 cos(2y) = expm1(2x) + expm1(-2x) + 4 cos(y)^2
template <typename Op>
struct TanTanhOpConversion : public OpConversionPattern<Op> {
  using OpConversionPattern<Op>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    auto loc = op.getLoc();
    auto type = cast<ComplexType>(adaptor.getComplex().getType());
    auto elementType = cast<FloatType>(type.getElementType());
    arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
    const auto &floatSemantics = elementType.getFloatSemantics();

    Value real =
        b.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
    Value imag =
        b.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
    Value negOne = b.create<arith::ConstantOp>(
        elementType, b.getFloatAttr(elementType, -1.0));

    if constexpr (std::is_same_v<Op, complex::TanOp>) {
      // tan(x+yi) = -i*tanh(-y + xi)
      std::swap(real, imag);
      real = b.create<arith::MulFOp>(real, negOne, fmf);
    }

    // Materializes an element-typed float constant.
    auto cst = [&](APFloat v) {
      return b.create<arith::ConstantOp>(elementType,
                                         b.getFloatAttr(elementType, v));
    };
    Value inf = cst(APFloat::getInf(floatSemantics));
    Value four = b.create<arith::ConstantOp>(elementType,
                                             b.getFloatAttr(elementType, 4.0));
    Value twoReal = b.create<arith::AddFOp>(real, real, fmf);
    Value negTwoReal = b.create<arith::MulFOp>(negOne, twoReal, fmf);

    // 2 * sinh(2x) = expm1(2x) - expm1(-2x).
    Value expTwoRealMinusOne = b.create<math::ExpM1Op>(twoReal, fmf);
    Value expNegTwoRealMinusOne = b.create<math::ExpM1Op>(negTwoReal, fmf);
    Value realNum =
        b.create<arith::SubFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf);

    // 4 * cos(y)^2 == 2 * (cos(2y) + 1).
    Value cosImag = b.create<math::CosOp>(imag, fmf);
    Value cosImagSq = b.create<arith::MulFOp>(cosImag, cosImag, fmf);
    Value twoCosTwoImagPlusOne = b.create<arith::MulFOp>(cosImagSq, four, fmf);
    Value sinImag = b.create<math::SinOp>(imag, fmf);

    // 2 * sin(2y) = 4 * cos(y) * sin(y).
    Value imagNum = b.create<arith::MulFOp>(
        four, b.create<arith::MulFOp>(cosImag, sinImag, fmf), fmf);

    // 2 * cosh(2x) - 2 = expm1(2x) + expm1(-2x); adding 4 * cos(y)^2 gives the
    // scaled denominator 2 * (cosh(2x) + cos(2y)).
    Value expSumMinusTwo =
        b.create<arith::AddFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf);
    Value denom =
        b.create<arith::AddFOp>(expSumMinusTwo, twoCosTwoImagPlusOne, fmf);

    // For large |x| the exponential sum overflows; tanh then saturates and the
    // real part is 1 with the sign of x (copysign takes only the magnitude of
    // -1).
    Value isInf = b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
                                          expSumMinusTwo, inf, fmf);
    Value realLimit = b.create<math::CopySignOp>(negOne, real, fmf);

    Value resultReal = b.create<arith::SelectOp>(
        isInf, realLimit, b.create<arith::DivFOp>(realNum, denom, fmf));
    Value resultImag = b.create<arith::DivFOp>(imagNum, denom, fmf);

    // Special values, only emitted when NaNs or infinities may occur.
    if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
                                            arith::FastMathFlags::ninf)) {
      Value absReal = b.create<math::AbsFOp>(real, fmf);
      Value zero = b.create<arith::ConstantOp>(
          elementType, b.getFloatAttr(elementType, 0.0));
      Value nan = cst(APFloat::getNaN(floatSemantics));

      Value absRealIsInf =
          b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf);
      Value imagIsZero =
          b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf);
      Value absRealIsNotInf = b.create<arith::XOrIOp>(
          absRealIsInf, b.create<arith::ConstantIntOp>(true, /*width=*/1));

      // A NaN imaginary numerator means sin/cos of y were NaN (y infinite or
      // NaN). The real part is then NaN unless |x| is infinite.
      Value imagNumIsNaN = b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO,
                                                   imagNum, imagNum, fmf);
      Value resultRealIsNaN =
          b.create<arith::AndIOp>(imagNumIsNaN, absRealIsNotInf);
      // The imaginary part is exactly 0 for y == 0, and also when |x| is
      // infinite while the imaginary numerator is NaN.
      Value resultImagIsZero = b.create<arith::OrIOp>(
          imagIsZero, b.create<arith::AndIOp>(absRealIsInf, imagNumIsNaN));

      resultReal = b.create<arith::SelectOp>(resultRealIsNaN, nan, resultReal);
      resultImag =
          b.create<arith::SelectOp>(resultImagIsZero, zero, resultImag);
    }

    if constexpr (std::is_same_v<Op, complex::TanOp>) {
      // tan(x+yi) = -i*tanh(-y + xi)
      std::swap(resultReal, resultImag);
      resultImag = b.create<arith::MulFOp>(resultImag, negOne, fmf);
    }

    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
                                                   resultImag);
    return success();
  }
};
976 
977 struct ConjOpConversion : public OpConversionPattern<complex::ConjOp> {
978   using OpConversionPattern<complex::ConjOp>::OpConversionPattern;
979 
980   LogicalResult
981   matchAndRewrite(complex::ConjOp op, OpAdaptor adaptor,
982                   ConversionPatternRewriter &rewriter) const override {
983     auto loc = op.getLoc();
984     auto type = cast<ComplexType>(adaptor.getComplex().getType());
985     auto elementType = cast<FloatType>(type.getElementType());
986     Value real =
987         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
988     Value imag =
989         rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
990     Value negImag = rewriter.create<arith::NegFOp>(loc, elementType, imag);
991 
992     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, negImag);
993 
994     return success();
995   }
996 };
997 
/// Emits lhs^rhs for lhs = a + bi and rhs = c + di as
///    (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
///    where q = c*atan2(b,a)+0.5d*ln(a*a+b*b)
/// i.e. |lhs|^c * exp(-d * arg(lhs)) * (cos(q) + i*sin(q)) with
/// q = c * arg(lhs) + d * ln|lhs|. A chain of selects then overrides that
/// generic result for the special inputs listed as Case 0..4 below. A later
/// case wins over an earlier one.
///
/// `builder` inserts the ops, `type` is the complex result type, `lhs` is the
/// complex base, `c`/`d` are the real/imaginary parts of the exponent, and
/// `fmf` is attached to the emitted float ops.
static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
                                 ComplexType type, Value lhs, Value c, Value d,
                                 arith::FastMathFlags fmf) {
  auto elementType = cast<FloatType>(type.getElementType());

  Value a = builder.create<complex::ReOp>(lhs);
  Value b = builder.create<complex::ImOp>(lhs);

  // |lhs|^c
  Value abs = builder.create<complex::AbsOp>(lhs, fmf);
  Value absToC = builder.create<math::PowFOp>(abs, c, fmf);

  // exp(-d * arg(lhs))
  Value negD = builder.create<arith::NegFOp>(d, fmf);
  Value argLhs = builder.create<math::Atan2Op>(b, a, fmf);
  Value negDArgLhs = builder.create<arith::MulFOp>(negD, argLhs, fmf);
  Value expNegDArgLhs = builder.create<math::ExpOp>(negDArgLhs, fmf);

  // Modulus of the result, and its angle q = c * arg(lhs) + d * ln|lhs|.
  Value coeff = builder.create<arith::MulFOp>(absToC, expNegDArgLhs, fmf);
  Value lnAbs = builder.create<math::LogOp>(abs, fmf);
  Value cArgLhs = builder.create<arith::MulFOp>(c, argLhs, fmf);
  Value dLnAbs = builder.create<arith::MulFOp>(d, lnAbs, fmf);
  Value q = builder.create<arith::AddFOp>(cArgLhs, dLnAbs, fmf);
  Value cosQ = builder.create<math::CosOp>(q, fmf);
  Value sinQ = builder.create<math::SinOp>(q, fmf);

  Value inf = builder.create<arith::ConstantOp>(
      elementType,
      builder.getFloatAttr(elementType,
                           APFloat::getInf(elementType.getFloatSemantics())));
  Value zero = builder.create<arith::ConstantOp>(
      elementType, builder.getFloatAttr(elementType, 0.0));
  Value one = builder.create<arith::ConstantOp>(
      elementType, builder.getFloatAttr(elementType, 1.0));
  Value complexOne = builder.create<complex::CreateOp>(type, one, zero);
  Value complexZero = builder.create<complex::CreateOp>(type, zero, zero);
  Value complexInf = builder.create<complex::CreateOp>(type, inf, zero);

  // Case 0:
  // 0^(c + 0*i) is 0 if c > 0. 0^0 is defined to be 1.0, see
  // Branch Cuts for Complex Elementary Functions or Much Ado About
  // Nothing's Sign Bit, W. Kahan, Section 10.
  Value absEqZero =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, abs, zero, fmf);
  Value dEqZero =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, d, zero, fmf);
  Value cEqZero =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, c, zero, fmf);
  Value bEqZero =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, b, zero, fmf);

  Value zeroLeC =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLE, zero, c, fmf);
  // Generic result: coeff * (cos(q) + i*sin(q)).
  Value coeffCosQ = builder.create<arith::MulFOp>(coeff, cosQ, fmf);
  Value coeffSinQ = builder.create<arith::MulFOp>(coeff, sinQ, fmf);
  Value complexOneOrZero =
      builder.create<arith::SelectOp>(cEqZero, complexOne, complexZero);
  Value coeffCosSin =
      builder.create<complex::CreateOp>(type, coeffCosQ, coeffSinQ);
  Value cutoff0 = builder.create<arith::SelectOp>(
      builder.create<arith::AndIOp>(
          builder.create<arith::AndIOp>(absEqZero, dEqZero), zeroLeC),
      complexOneOrZero, coeffCosSin);

  // Case 1:
  // x^0 is defined to be 1 for any x, see
  // Branch Cuts for Complex Elementary Functions or Much Ado About
  // Nothing's Sign Bit, W. Kahan, Section 10.
  Value rhsEqZero = builder.create<arith::AndIOp>(cEqZero, dEqZero);
  Value cutoff1 =
      builder.create<arith::SelectOp>(rhsEqZero, complexOne, cutoff0);

  // Case 2:
  // 1^(c + d*i) = 1 + 0*i
  Value lhsEqOne = builder.create<arith::AndIOp>(
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, one, fmf),
      bEqZero);
  Value cutoff2 =
      builder.create<arith::SelectOp>(lhsEqOne, complexOne, cutoff1);

  // Case 3:
  // inf^(c + 0*i) = inf + 0*i, c > 0
  // (only lhs == +inf + 0i is matched here)
  Value lhsEqInf = builder.create<arith::AndIOp>(
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, inf, fmf),
      bEqZero);
  Value rhsGt0 = builder.create<arith::AndIOp>(
      dEqZero,
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, c, zero, fmf));
  Value cutoff3 = builder.create<arith::SelectOp>(
      builder.create<arith::AndIOp>(lhsEqInf, rhsGt0), complexInf, cutoff2);

  // Case 4:
  // inf^(c + 0*i) = 0 + 0*i, c < 0
  Value rhsLt0 = builder.create<arith::AndIOp>(
      dEqZero,
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, c, zero, fmf));
  Value cutoff4 = builder.create<arith::SelectOp>(
      builder.create<arith::AndIOp>(lhsEqInf, rhsLt0), complexZero, cutoff3);

  return cutoff4;
}
1100 
1101 struct PowOpConversion : public OpConversionPattern<complex::PowOp> {
1102   using OpConversionPattern<complex::PowOp>::OpConversionPattern;
1103 
1104   LogicalResult
1105   matchAndRewrite(complex::PowOp op, OpAdaptor adaptor,
1106                   ConversionPatternRewriter &rewriter) const override {
1107     mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
1108     auto type = cast<ComplexType>(adaptor.getLhs().getType());
1109     auto elementType = cast<FloatType>(type.getElementType());
1110 
1111     Value c = builder.create<complex::ReOp>(elementType, adaptor.getRhs());
1112     Value d = builder.create<complex::ImOp>(elementType, adaptor.getRhs());
1113 
1114     rewriter.replaceOp(op, {powOpConversionImpl(builder, type, adaptor.getLhs(),
1115                                                 c, d, op.getFastmath())});
1116     return success();
1117   }
1118 };
1119 
/// Lowers `complex.rsqrt` (1 / sqrt(z)) via the polar form
///   rsqrt(z) = |z|^(-1/2) * (cos(-arg(z) / 2) + i * sin(-arg(z) / 2)),
/// where arg(z) = atan2(Im(z), Re(z)) and |z|^(-1/2) comes from computeAbs with
/// AbsFn::rsqrt. Unless both `nnan` and `ninf` are set, infinite inputs are
/// mapped to signed zeros. z == 0 always yields inf + NaN*i.
struct RsqrtOpConversion : public OpConversionPattern<complex::RsqrtOp> {
  using OpConversionPattern<complex::RsqrtOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    auto type = cast<ComplexType>(adaptor.getComplex().getType());
    auto elementType = cast<FloatType>(type.getElementType());

    arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();

    // Materializes an element-typed float constant.
    auto cst = [&](APFloat v) {
      return b.create<arith::ConstantOp>(elementType,
                                         b.getFloatAttr(elementType, v));
    };
    const auto &floatSemantics = elementType.getFloatSemantics();
    Value zero = cst(APFloat::getZero(floatSemantics));
    Value inf = cst(APFloat::getInf(floatSemantics));
    Value negHalf = b.create<arith::ConstantOp>(
        elementType, b.getFloatAttr(elementType, -0.5));
    Value nan = cst(APFloat::getNaN(floatSemantics));

    Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
    Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
    Value absRsqrt = computeAbs(real, imag, fmf, b, AbsFn::rsqrt);
    // Result angle is -arg(z) / 2.
    Value argArg = b.create<math::Atan2Op>(imag, real, fmf);
    Value rsqrtArg = b.create<arith::MulFOp>(argArg, negHalf, fmf);
    Value cos = b.create<math::CosOp>(rsqrtArg, fmf);
    Value sin = b.create<math::SinOp>(rsqrtArg, fmf);

    Value resultReal = b.create<arith::MulFOp>(absRsqrt, cos, fmf);
    Value resultImag = b.create<arith::MulFOp>(absRsqrt, sin, fmf);

    // Special values, only emitted when NaNs or infinities may occur.
    if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
                                            arith::FastMathFlags::ninf)) {
      Value negOne = b.create<arith::ConstantOp>(
          elementType, b.getFloatAttr(elementType, -1));

      // Zeros carrying the sign of Re(z) and the opposite sign of Im(z).
      Value realSignedZero = b.create<math::CopySignOp>(zero, real, fmf);
      Value imagSignedZero = b.create<math::CopySignOp>(zero, imag, fmf);
      Value negImagSignedZero =
          b.create<arith::MulFOp>(negOne, imagSignedZero, fmf);

      Value absReal = b.create<math::AbsFOp>(real, fmf);
      Value absImag = b.create<math::AbsFOp>(imag, fmf);

      Value absImagIsInf =
          b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf);
      Value realIsNan =
          b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real, fmf);
      Value realIsInf =
          b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf);
      Value inIsNanInf = b.create<arith::AndIOp>(absImagIsInf, realIsNan);

      // Re(z) == +-inf, or NaN + inf*i: the result is a signed zero.
      Value resultIsZero = b.create<arith::OrIOp>(inIsNanInf, realIsInf);

      resultReal =
          b.create<arith::SelectOp>(resultIsZero, realSignedZero, resultReal);
      resultImag = b.create<arith::SelectOp>(resultIsZero, negImagSignedZero,
                                             resultImag);
    }

    // z == 0 (both parts ordered-equal to zero): result is inf + NaN*i.
    Value isRealZero =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero, fmf);
    Value isImagZero =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf);
    Value isZero = b.create<arith::AndIOp>(isRealZero, isImagZero);

    resultReal = b.create<arith::SelectOp>(isZero, inf, resultReal);
    resultImag = b.create<arith::SelectOp>(isZero, nan, resultImag);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
                                                   resultImag);
    return success();
  }
};
1197 
1198 struct AngleOpConversion : public OpConversionPattern<complex::AngleOp> {
1199   using OpConversionPattern<complex::AngleOp>::OpConversionPattern;
1200 
1201   LogicalResult
1202   matchAndRewrite(complex::AngleOp op, OpAdaptor adaptor,
1203                   ConversionPatternRewriter &rewriter) const override {
1204     auto loc = op.getLoc();
1205     auto type = op.getType();
1206     arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
1207 
1208     Value real =
1209         rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex());
1210     Value imag =
1211         rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex());
1212 
1213     rewriter.replaceOpWithNewOp<math::Atan2Op>(op, imag, real, fmf);
1214 
1215     return success();
1216   }
1217 };
1218 
1219 } // namespace
1220 
1221 void mlir::populateComplexToStandardConversionPatterns(
1222     RewritePatternSet &patterns) {
1223   // clang-format off
1224   patterns.add<
1225       AbsOpConversion,
1226       AngleOpConversion,
1227       Atan2OpConversion,
1228       BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
1229       BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>,
1230       ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
1231       ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
1232       ConjOpConversion,
1233       CosOpConversion,
1234       DivOpConversion,
1235       ExpOpConversion,
1236       Expm1OpConversion,
1237       Log1pOpConversion,
1238       LogOpConversion,
1239       MulOpConversion,
1240       NegOpConversion,
1241       SignOpConversion,
1242       SinOpConversion,
1243       SqrtOpConversion,
1244       TanTanhOpConversion<complex::TanOp>,
1245       TanTanhOpConversion<complex::TanhOp>,
1246       PowOpConversion,
1247       RsqrtOpConversion
1248   >(patterns.getContext());
1249   // clang-format on
1250 }
1251 
1252 namespace {
1253 struct ConvertComplexToStandardPass
1254     : public impl::ConvertComplexToStandardBase<ConvertComplexToStandardPass> {
1255   void runOnOperation() override;
1256 };
1257 
1258 void ConvertComplexToStandardPass::runOnOperation() {
1259   // Convert to the Standard dialect using the converter defined above.
1260   RewritePatternSet patterns(&getContext());
1261   populateComplexToStandardConversionPatterns(patterns);
1262 
1263   ConversionTarget target(getContext());
1264   target.addLegalDialect<arith::ArithDialect, math::MathDialect>();
1265   target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
1266   if (failed(
1267           applyPartialConversion(getOperation(), target, std::move(patterns))))
1268     signalPassFailure();
1269 }
1270 } // namespace
1271 
1272 std::unique_ptr<Pass> mlir::createConvertComplexToStandardPass() {
1273   return std::make_unique<ConvertComplexToStandardPass>();
1274 }
1275