xref: /llvm-project/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp (revision 6e80e3bd1bef3e7408b29a6d7eda0efbb829a65f)
1 //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
10 
11 #include <memory>
12 #include <type_traits>
13 
14 #include "../PassDetail.h"
15 #include "mlir/Dialect/Complex/IR/Complex.h"
16 #include "mlir/Dialect/Math/IR/Math.h"
17 #include "mlir/Dialect/StandardOps/IR/Ops.h"
18 #include "mlir/IR/ImplicitLocOpBuilder.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 
22 using namespace mlir;
23 
24 namespace {
25 struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
26   using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
27 
28   LogicalResult
29   matchAndRewrite(complex::AbsOp op, ArrayRef<Value> operands,
30                   ConversionPatternRewriter &rewriter) const override {
31     complex::AbsOp::Adaptor transformed(operands);
32     auto loc = op.getLoc();
33     auto type = op.getType();
34 
35     Value real =
36         rewriter.create<complex::ReOp>(loc, type, transformed.complex());
37     Value imag =
38         rewriter.create<complex::ImOp>(loc, type, transformed.complex());
39     Value realSqr = rewriter.create<MulFOp>(loc, real, real);
40     Value imagSqr = rewriter.create<MulFOp>(loc, imag, imag);
41     Value sqNorm = rewriter.create<AddFOp>(loc, realSqr, imagSqr);
42 
43     rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm);
44     return success();
45   }
46 };
47 
48 template <typename ComparisonOp, CmpFPredicate p>
49 struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
50   using OpConversionPattern<ComparisonOp>::OpConversionPattern;
51   using ResultCombiner =
52       std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value,
53                          AndOp, OrOp>;
54 
55   LogicalResult
56   matchAndRewrite(ComparisonOp op, ArrayRef<Value> operands,
57                   ConversionPatternRewriter &rewriter) const override {
58     typename ComparisonOp::Adaptor transformed(operands);
59     auto loc = op.getLoc();
60     auto type = transformed.lhs()
61                     .getType()
62                     .template cast<ComplexType>()
63                     .getElementType();
64 
65     Value realLhs =
66         rewriter.create<complex::ReOp>(loc, type, transformed.lhs());
67     Value imagLhs =
68         rewriter.create<complex::ImOp>(loc, type, transformed.lhs());
69     Value realRhs =
70         rewriter.create<complex::ReOp>(loc, type, transformed.rhs());
71     Value imagRhs =
72         rewriter.create<complex::ImOp>(loc, type, transformed.rhs());
73     Value realComparison = rewriter.create<CmpFOp>(loc, p, realLhs, realRhs);
74     Value imagComparison = rewriter.create<CmpFOp>(loc, p, imagLhs, imagRhs);
75 
76     rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison,
77                                                 imagComparison);
78     return success();
79   }
80 };
81 
/// Lowers `complex.div` to scalar arithmetic on the real and imaginary parts.
///
/// The regular quotient is computed with Smith's algorithm. Both of its
/// branches are materialized unconditionally, and one is chosen with `select`
/// depending on whether |rhsReal| < |rhsImag|. If that quotient is NaN in BOTH
/// parts, three special cases are consulted, in priority order 1 > 2 > 3:
///   1. zero denominator, numerator not entirely NaN -> signed infinities;
///   2. infinite numerator, finite denominator       -> infinities;
///   3. finite numerator, infinite denominator       -> signed zeros.
/// NOTE(review): the special cases appear to mirror the recovery steps of the
/// C99 Annex G reference division (`_Cdivd`); confirm against the standard.
struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
  using OpConversionPattern<complex::DivOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::DivOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    complex::DivOp::Adaptor transformed(operands);
    auto loc = op.getLoc();
    auto type = transformed.lhs().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();

    Value lhsReal =
        rewriter.create<complex::ReOp>(loc, elementType, transformed.lhs());
    Value lhsImag =
        rewriter.create<complex::ImOp>(loc, elementType, transformed.lhs());
    Value rhsReal =
        rewriter.create<complex::ReOp>(loc, elementType, transformed.rhs());
    Value rhsImag =
        rewriter.create<complex::ImOp>(loc, elementType, transformed.rhs());

    // Smith's algorithm for dividing complex numbers: a numerically more
    // careful way to evaluate the following formula:
    //  (lhsReal + lhsImag * i) / (rhsReal + rhsImag * i)
    //    = (lhsReal + lhsImag * i) (rhsReal - rhsImag * i) /
    //          ((rhsReal + rhsImag * i)(rhsReal - rhsImag * i))
    //    = ((lhsReal * rhsReal + lhsImag * rhsImag) +
    //          (lhsImag * rhsReal - lhsReal * rhsImag) * i) / ||rhs||^2
    //
    // Depending on whether |rhsReal| < |rhsImag| we compute either
    //   rhsRealImagRatio = rhsReal / rhsImag
    //   rhsRealImagDenom = rhsImag + rhsReal * rhsRealImagRatio
    //   resultReal = (lhsReal * rhsRealImagRatio + lhsImag) / rhsRealImagDenom
    //   resultImag = (lhsImag * rhsRealImagRatio - lhsReal) / rhsRealImagDenom
    //
    // or
    //
    //   rhsImagRealRatio = rhsImag / rhsReal
    //   rhsImagRealDenom = rhsReal + rhsImag * rhsImagRealRatio
    //   resultReal = (lhsReal + lhsImag * rhsImagRealRatio) / rhsImagRealDenom
    //   resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom
    //
    // See https://dl.acm.org/citation.cfm?id=368661 for more details.
    //
    // Both variants are emitted; the appropriate one is selected further
    // below.
    Value rhsRealImagRatio = rewriter.create<DivFOp>(loc, rhsReal, rhsImag);
    Value rhsRealImagDenom = rewriter.create<AddFOp>(
        loc, rhsImag, rewriter.create<MulFOp>(loc, rhsRealImagRatio, rhsReal));
    Value realNumerator1 = rewriter.create<AddFOp>(
        loc, rewriter.create<MulFOp>(loc, lhsReal, rhsRealImagRatio), lhsImag);
    Value resultReal1 =
        rewriter.create<DivFOp>(loc, realNumerator1, rhsRealImagDenom);
    Value imagNumerator1 = rewriter.create<SubFOp>(
        loc, rewriter.create<MulFOp>(loc, lhsImag, rhsRealImagRatio), lhsReal);
    Value resultImag1 =
        rewriter.create<DivFOp>(loc, imagNumerator1, rhsRealImagDenom);

    Value rhsImagRealRatio = rewriter.create<DivFOp>(loc, rhsImag, rhsReal);
    Value rhsImagRealDenom = rewriter.create<AddFOp>(
        loc, rhsReal, rewriter.create<MulFOp>(loc, rhsImagRealRatio, rhsImag));
    Value realNumerator2 = rewriter.create<AddFOp>(
        loc, lhsReal, rewriter.create<MulFOp>(loc, lhsImag, rhsImagRealRatio));
    Value resultReal2 =
        rewriter.create<DivFOp>(loc, realNumerator2, rhsImagRealDenom);
    Value imagNumerator2 = rewriter.create<SubFOp>(
        loc, lhsImag, rewriter.create<MulFOp>(loc, lhsReal, rhsImagRealRatio));
    Value resultImag2 =
        rewriter.create<DivFOp>(loc, imagNumerator2, rhsImagRealDenom);

    // Consider corner cases.
    // Case 1. Zero denominator, numerator has at least one non-NaN part.
    Value zero = rewriter.create<ConstantOp>(loc, elementType,
                                             rewriter.getZeroAttr(elementType));
    Value rhsRealAbs = rewriter.create<AbsFOp>(loc, rhsReal);
    Value rhsRealIsZero =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, rhsRealAbs, zero);
    Value rhsImagAbs = rewriter.create<AbsFOp>(loc, rhsImag);
    Value rhsImagIsZero =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, rhsImagAbs, zero);
    // ORD is true iff neither operand is NaN. `zero` is never NaN, so these
    // compares test whether the numerator part is not NaN.
    Value lhsRealIsNotNaN =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::ORD, lhsReal, zero);
    Value lhsImagIsNotNaN =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::ORD, lhsImag, zero);
    Value lhsContainsNotNaNValue =
        rewriter.create<OrOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
    Value resultIsInfinity = rewriter.create<AndOp>(
        loc, lhsContainsNotNaNValue,
        rewriter.create<AndOp>(loc, rhsRealIsZero, rhsImagIsZero));
    Value inf = rewriter.create<ConstantOp>(
        loc, elementType,
        rewriter.getFloatAttr(
            elementType, APFloat::getInf(elementType.getFloatSemantics())));
    // Result parts are copysign(inf, rhsReal) * lhs{Real,Imag}.
    Value infWithSignOfRhsReal = rewriter.create<CopySignOp>(loc, inf, rhsReal);
    Value infinityResultReal =
        rewriter.create<MulFOp>(loc, infWithSignOfRhsReal, lhsReal);
    Value infinityResultImag =
        rewriter.create<MulFOp>(loc, infWithSignOfRhsReal, lhsImag);

    // Case 2. Infinite numerator, finite denominator.
    // ONE (ordered and not equal) is false for NaN, so "finite" here means
    // neither NaN nor +/-inf.
    Value rhsRealFinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::ONE, rhsRealAbs, inf);
    Value rhsImagFinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::ONE, rhsImagAbs, inf);
    Value rhsFinite = rewriter.create<AndOp>(loc, rhsRealFinite, rhsImagFinite);
    Value lhsRealAbs = rewriter.create<AbsFOp>(loc, lhsReal);
    Value lhsRealInfinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, lhsRealAbs, inf);
    Value lhsImagAbs = rewriter.create<AbsFOp>(loc, lhsImag);
    Value lhsImagInfinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, lhsImagAbs, inf);
    Value lhsInfinite =
        rewriter.create<OrOp>(loc, lhsRealInfinite, lhsImagInfinite);
    Value infNumFiniteDenom =
        rewriter.create<AndOp>(loc, lhsInfinite, rhsFinite);
    Value one = rewriter.create<ConstantOp>(
        loc, elementType, rewriter.getFloatAttr(elementType, 1));
    // "Box" the numerator: each part becomes +/-1 if it is infinite and +/-0
    // otherwise, keeping the original sign.
    Value lhsRealIsInfWithSign = rewriter.create<CopySignOp>(
        loc, rewriter.create<SelectOp>(loc, lhsRealInfinite, one, zero),
        lhsReal);
    Value lhsImagIsInfWithSign = rewriter.create<CopySignOp>(
        loc, rewriter.create<SelectOp>(loc, lhsImagInfinite, one, zero),
        lhsImag);
    Value lhsRealIsInfWithSignTimesRhsReal =
        rewriter.create<MulFOp>(loc, lhsRealIsInfWithSign, rhsReal);
    Value lhsImagIsInfWithSignTimesRhsImag =
        rewriter.create<MulFOp>(loc, lhsImagIsInfWithSign, rhsImag);
    // resultReal3 = inf * (a' * c + b' * d)
    Value resultReal3 = rewriter.create<MulFOp>(
        loc, inf,
        rewriter.create<AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal,
                                lhsImagIsInfWithSignTimesRhsImag));
    Value lhsRealIsInfWithSignTimesRhsImag =
        rewriter.create<MulFOp>(loc, lhsRealIsInfWithSign, rhsImag);
    Value lhsImagIsInfWithSignTimesRhsReal =
        rewriter.create<MulFOp>(loc, lhsImagIsInfWithSign, rhsReal);
    // resultImag3 = inf * (b' * c - a' * d)
    Value resultImag3 = rewriter.create<MulFOp>(
        loc, inf,
        rewriter.create<SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal,
                                lhsRealIsInfWithSignTimesRhsImag));

    // Case 3: Finite numerator, infinite denominator.
    Value lhsRealFinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::ONE, lhsRealAbs, inf);
    Value lhsImagFinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::ONE, lhsImagAbs, inf);
    Value lhsFinite = rewriter.create<AndOp>(loc, lhsRealFinite, lhsImagFinite);
    Value rhsRealInfinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, rhsRealAbs, inf);
    Value rhsImagInfinite =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, rhsImagAbs, inf);
    Value rhsInfinite =
        rewriter.create<OrOp>(loc, rhsRealInfinite, rhsImagInfinite);
    Value finiteNumInfiniteDenom =
        rewriter.create<AndOp>(loc, lhsFinite, rhsInfinite);
    // Box the denominator the same way as the numerator in case 2.
    Value rhsRealIsInfWithSign = rewriter.create<CopySignOp>(
        loc, rewriter.create<SelectOp>(loc, rhsRealInfinite, one, zero),
        rhsReal);
    Value rhsImagIsInfWithSign = rewriter.create<CopySignOp>(
        loc, rewriter.create<SelectOp>(loc, rhsImagInfinite, one, zero),
        rhsImag);
    Value rhsRealIsInfWithSignTimesLhsReal =
        rewriter.create<MulFOp>(loc, lhsReal, rhsRealIsInfWithSign);
    Value rhsImagIsInfWithSignTimesLhsImag =
        rewriter.create<MulFOp>(loc, lhsImag, rhsImagIsInfWithSign);
    // resultReal4 = 0 * (a * c' + b * d'), a signed zero for finite inputs.
    Value resultReal4 = rewriter.create<MulFOp>(
        loc, zero,
        rewriter.create<AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal,
                                rhsImagIsInfWithSignTimesLhsImag));
    Value rhsRealIsInfWithSignTimesLhsImag =
        rewriter.create<MulFOp>(loc, lhsImag, rhsRealIsInfWithSign);
    Value rhsImagIsInfWithSignTimesLhsReal =
        rewriter.create<MulFOp>(loc, lhsReal, rhsImagIsInfWithSign);
    // resultImag4 = 0 * (b * c' - a * d')
    Value resultImag4 = rewriter.create<MulFOp>(
        loc, zero,
        rewriter.create<SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag,
                                rhsImagIsInfWithSignTimesLhsReal));

    // Pick the Smith's-algorithm branch that divides by the larger of
    // |rhsReal| and |rhsImag|.
    Value realAbsSmallerThanImagAbs = rewriter.create<CmpFOp>(
        loc, CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);
    Value resultReal = rewriter.create<SelectOp>(loc, realAbsSmallerThanImagAbs,
                                                 resultReal1, resultReal2);
    Value resultImag = rewriter.create<SelectOp>(loc, realAbsSmallerThanImagAbs,
                                                 resultImag1, resultImag2);
    // Nest the special cases so that case 1 overrides case 2, which overrides
    // case 3, which overrides the Smith's result.
    Value resultRealSpecialCase3 = rewriter.create<SelectOp>(
        loc, finiteNumInfiniteDenom, resultReal4, resultReal);
    Value resultImagSpecialCase3 = rewriter.create<SelectOp>(
        loc, finiteNumInfiniteDenom, resultImag4, resultImag);
    Value resultRealSpecialCase2 = rewriter.create<SelectOp>(
        loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
    Value resultImagSpecialCase2 = rewriter.create<SelectOp>(
        loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
    Value resultRealSpecialCase1 = rewriter.create<SelectOp>(
        loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2);
    Value resultImagSpecialCase1 = rewriter.create<SelectOp>(
        loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2);

    // Special-case results are only used when the Smith's quotient is NaN in
    // both parts; otherwise that quotient is kept as is.
    Value resultRealIsNaN =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::UNO, resultReal, zero);
    Value resultImagIsNaN =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::UNO, resultImag, zero);
    Value resultIsNaN =
        rewriter.create<AndOp>(loc, resultRealIsNaN, resultImagIsNaN);
    Value resultRealWithSpecialCases = rewriter.create<SelectOp>(
        loc, resultIsNaN, resultRealSpecialCase1, resultReal);
    Value resultImagWithSpecialCases = rewriter.create<SelectOp>(
        loc, resultIsNaN, resultImagSpecialCase1, resultImag);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(
        op, type, resultRealWithSpecialCases, resultImagWithSpecialCases);
    return success();
  }
};
290 
291 struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
292   using OpConversionPattern<complex::ExpOp>::OpConversionPattern;
293 
294   LogicalResult
295   matchAndRewrite(complex::ExpOp op, ArrayRef<Value> operands,
296                   ConversionPatternRewriter &rewriter) const override {
297     complex::ExpOp::Adaptor transformed(operands);
298     auto loc = op.getLoc();
299     auto type = transformed.complex().getType().cast<ComplexType>();
300     auto elementType = type.getElementType().cast<FloatType>();
301 
302     Value real =
303         rewriter.create<complex::ReOp>(loc, elementType, transformed.complex());
304     Value imag =
305         rewriter.create<complex::ImOp>(loc, elementType, transformed.complex());
306     Value expReal = rewriter.create<math::ExpOp>(loc, real);
307     Value cosImag = rewriter.create<math::CosOp>(loc, imag);
308     Value resultReal = rewriter.create<MulFOp>(loc, expReal, cosImag);
309     Value sinImag = rewriter.create<math::SinOp>(loc, imag);
310     Value resultImag = rewriter.create<MulFOp>(loc, expReal, sinImag);
311 
312     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
313                                                    resultImag);
314     return success();
315   }
316 };
317 
318 struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
319   using OpConversionPattern<complex::LogOp>::OpConversionPattern;
320 
321   LogicalResult
322   matchAndRewrite(complex::LogOp op, ArrayRef<Value> operands,
323                   ConversionPatternRewriter &rewriter) const override {
324     complex::LogOp::Adaptor transformed(operands);
325     auto type = transformed.complex().getType().cast<ComplexType>();
326     auto elementType = type.getElementType().cast<FloatType>();
327     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
328 
329     Value abs = b.create<complex::AbsOp>(elementType, transformed.complex());
330     Value resultReal = b.create<math::LogOp>(elementType, abs);
331     Value real = b.create<complex::ReOp>(elementType, transformed.complex());
332     Value imag = b.create<complex::ImOp>(elementType, transformed.complex());
333     Value resultImag = b.create<math::Atan2Op>(elementType, imag, real);
334     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
335                                                    resultImag);
336     return success();
337   }
338 };
339 
340 struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
341   using OpConversionPattern<complex::Log1pOp>::OpConversionPattern;
342 
343   LogicalResult
344   matchAndRewrite(complex::Log1pOp op, ArrayRef<Value> operands,
345                   ConversionPatternRewriter &rewriter) const override {
346     complex::Log1pOp::Adaptor transformed(operands);
347     auto type = transformed.complex().getType().cast<ComplexType>();
348     auto elementType = type.getElementType().cast<FloatType>();
349     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
350 
351     Value real = b.create<complex::ReOp>(elementType, transformed.complex());
352     Value imag = b.create<complex::ImOp>(elementType, transformed.complex());
353     Value one =
354         b.create<ConstantOp>(elementType, b.getFloatAttr(elementType, 1));
355     Value realPlusOne = b.create<AddFOp>(real, one);
356     Value newComplex = b.create<complex::CreateOp>(type, realPlusOne, imag);
357     rewriter.replaceOpWithNewOp<complex::LogOp>(op, type, newComplex);
358     return success();
359   }
360 };
361 
/// Lowers `complex.mul` to scalar arithmetic on the real and imaginary parts.
///
/// The result is first computed naively as
///   (a + bi)(c + di) = (ac - bd) + (ad + bc)i.
/// If BOTH parts of that result are NaN, the lowering attempts to recover
/// infinities. Depending on which case applies, it "boxes" infinite operands
/// to +/-1 (finite parts to +/-0) and replaces NaN parts with signed zeros,
/// then recomputes each part as inf * (naive formula).
/// NOTE(review): the recovery steps appear to follow the C99 Annex G reference
/// multiplication (`_Cmultd`); confirm against the standard.
///
/// The cases below reassign `lhsReal`, `lhsImag`, `rhsReal` and `rhsImag` in
/// sequence, so each case observes the values produced by the previous one.
struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
  using OpConversionPattern<complex::MulOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::MulOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    complex::MulOp::Adaptor transformed(operands);
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    auto type = transformed.lhs().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();

    // Absolute values are taken from the ORIGINAL operand parts; the
    // infinity tests below use them even after the parts are reassigned.
    Value lhsReal = b.create<complex::ReOp>(elementType, transformed.lhs());
    Value lhsRealAbs = b.create<AbsFOp>(lhsReal);
    Value lhsImag = b.create<complex::ImOp>(elementType, transformed.lhs());
    Value lhsImagAbs = b.create<AbsFOp>(lhsImag);
    Value rhsReal = b.create<complex::ReOp>(elementType, transformed.rhs());
    Value rhsRealAbs = b.create<AbsFOp>(rhsReal);
    Value rhsImag = b.create<complex::ImOp>(elementType, transformed.rhs());
    Value rhsImagAbs = b.create<AbsFOp>(rhsImag);

    // Naive real part: ac - bd.
    Value lhsRealTimesRhsReal = b.create<MulFOp>(lhsReal, rhsReal);
    Value lhsRealTimesRhsRealAbs = b.create<AbsFOp>(lhsRealTimesRhsReal);
    Value lhsImagTimesRhsImag = b.create<MulFOp>(lhsImag, rhsImag);
    Value lhsImagTimesRhsImagAbs = b.create<AbsFOp>(lhsImagTimesRhsImag);
    Value real = b.create<SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);

    // Naive imaginary part: bc + ad.
    Value lhsImagTimesRhsReal = b.create<MulFOp>(lhsImag, rhsReal);
    Value lhsImagTimesRhsRealAbs = b.create<AbsFOp>(lhsImagTimesRhsReal);
    Value lhsRealTimesRhsImag = b.create<MulFOp>(lhsReal, rhsImag);
    Value lhsRealTimesRhsImagAbs = b.create<AbsFOp>(lhsRealTimesRhsImag);
    Value imag = b.create<AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);

    // Handle cases where the "naive" calculation results in NaN values.
    // UNO(x, x) is true iff x is NaN.
    Value realIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, real, real);
    Value imagIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, imag, imag);
    Value isNan = b.create<AndOp>(realIsNan, imagIsNan);

    Value inf = b.create<ConstantOp>(
        elementType,
        b.getFloatAttr(elementType,
                       APFloat::getInf(elementType.getFloatSemantics())));

    // Case 1. `lhsReal` or `lhsImag` are infinite.
    // Box the lhs parts to copysign(isInf ? 1 : 0) and turn NaN rhs parts
    // into signed zeros.
    Value lhsRealIsInf = b.create<CmpFOp>(CmpFPredicate::OEQ, lhsRealAbs, inf);
    Value lhsImagIsInf = b.create<CmpFOp>(CmpFPredicate::OEQ, lhsImagAbs, inf);
    Value lhsIsInf = b.create<OrOp>(lhsRealIsInf, lhsImagIsInf);
    Value rhsRealIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, rhsReal, rhsReal);
    Value rhsImagIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, rhsImag, rhsImag);
    Value zero = b.create<ConstantOp>(elementType, b.getZeroAttr(elementType));
    Value one =
        b.create<ConstantOp>(elementType, b.getFloatAttr(elementType, 1));
    Value lhsRealIsInfFloat = b.create<SelectOp>(lhsRealIsInf, one, zero);
    lhsReal = b.create<SelectOp>(
        lhsIsInf, b.create<CopySignOp>(lhsRealIsInfFloat, lhsReal), lhsReal);
    Value lhsImagIsInfFloat = b.create<SelectOp>(lhsImagIsInf, one, zero);
    lhsImag = b.create<SelectOp>(
        lhsIsInf, b.create<CopySignOp>(lhsImagIsInfFloat, lhsImag), lhsImag);
    Value lhsIsInfAndRhsRealIsNan = b.create<AndOp>(lhsIsInf, rhsRealIsNan);
    rhsReal = b.create<SelectOp>(lhsIsInfAndRhsRealIsNan,
                                 b.create<CopySignOp>(zero, rhsReal), rhsReal);
    Value lhsIsInfAndRhsImagIsNan = b.create<AndOp>(lhsIsInf, rhsImagIsNan);
    rhsImag = b.create<SelectOp>(lhsIsInfAndRhsImagIsNan,
                                 b.create<CopySignOp>(zero, rhsImag), rhsImag);

    // Case 2. `rhsReal` or `rhsImag` are infinite.
    // Symmetric to case 1. The lhs NaN tests use the lhs values as updated by
    // case 1 (boxed values are never NaN).
    Value rhsRealIsInf = b.create<CmpFOp>(CmpFPredicate::OEQ, rhsRealAbs, inf);
    Value rhsImagIsInf = b.create<CmpFOp>(CmpFPredicate::OEQ, rhsImagAbs, inf);
    Value rhsIsInf = b.create<OrOp>(rhsRealIsInf, rhsImagIsInf);
    Value lhsRealIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, lhsReal, lhsReal);
    Value lhsImagIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, lhsImag, lhsImag);
    Value rhsRealIsInfFloat = b.create<SelectOp>(rhsRealIsInf, one, zero);
    rhsReal = b.create<SelectOp>(
        rhsIsInf, b.create<CopySignOp>(rhsRealIsInfFloat, rhsReal), rhsReal);
    Value rhsImagIsInfFloat = b.create<SelectOp>(rhsImagIsInf, one, zero);
    rhsImag = b.create<SelectOp>(
        rhsIsInf, b.create<CopySignOp>(rhsImagIsInfFloat, rhsImag), rhsImag);
    Value rhsIsInfAndLhsRealIsNan = b.create<AndOp>(rhsIsInf, lhsRealIsNan);
    lhsReal = b.create<SelectOp>(rhsIsInfAndLhsRealIsNan,
                                 b.create<CopySignOp>(zero, lhsReal), lhsReal);
    Value rhsIsInfAndLhsImagIsNan = b.create<AndOp>(rhsIsInf, lhsImagIsNan);
    lhsImag = b.create<SelectOp>(rhsIsInfAndLhsImagIsNan,
                                 b.create<CopySignOp>(zero, lhsImag), lhsImag);
    Value recalc = b.create<OrOp>(lhsIsInf, rhsIsInf);

    // Case 3. One of the pairwise products of left hand side with right hand
    // side is infinite (overflow), and neither case 1 nor case 2 applied.
    // All NaN parts are then replaced by signed zeros. Because neither earlier
    // case applied, the NaN tests computed above still reflect the original
    // operands.
    Value lhsRealTimesRhsRealIsInf =
        b.create<CmpFOp>(CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf);
    Value lhsImagTimesRhsImagIsInf =
        b.create<CmpFOp>(CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf);
    Value isSpecialCase =
        b.create<OrOp>(lhsRealTimesRhsRealIsInf, lhsImagTimesRhsImagIsInf);
    Value lhsRealTimesRhsImagIsInf =
        b.create<CmpFOp>(CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf);
    isSpecialCase = b.create<OrOp>(isSpecialCase, lhsRealTimesRhsImagIsInf);
    Value lhsImagTimesRhsRealIsInf =
        b.create<CmpFOp>(CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf);
    isSpecialCase = b.create<OrOp>(isSpecialCase, lhsImagTimesRhsRealIsInf);
    Type i1Type = b.getI1Type();
    // Logical NOT of `recalc`, expressed as XOR with an i1 `true` constant.
    Value notRecalc = b.create<XOrOp>(
        recalc, b.create<ConstantOp>(i1Type, b.getIntegerAttr(i1Type, 1)));
    isSpecialCase = b.create<AndOp>(isSpecialCase, notRecalc);
    Value isSpecialCaseAndLhsRealIsNan =
        b.create<AndOp>(isSpecialCase, lhsRealIsNan);
    lhsReal = b.create<SelectOp>(isSpecialCaseAndLhsRealIsNan,
                                 b.create<CopySignOp>(zero, lhsReal), lhsReal);
    Value isSpecialCaseAndLhsImagIsNan =
        b.create<AndOp>(isSpecialCase, lhsImagIsNan);
    lhsImag = b.create<SelectOp>(isSpecialCaseAndLhsImagIsNan,
                                 b.create<CopySignOp>(zero, lhsImag), lhsImag);
    Value isSpecialCaseAndRhsRealIsNan =
        b.create<AndOp>(isSpecialCase, rhsRealIsNan);
    rhsReal = b.create<SelectOp>(isSpecialCaseAndRhsRealIsNan,
                                 b.create<CopySignOp>(zero, rhsReal), rhsReal);
    Value isSpecialCaseAndRhsImagIsNan =
        b.create<AndOp>(isSpecialCase, rhsImagIsNan);
    rhsImag = b.create<SelectOp>(isSpecialCaseAndRhsImagIsNan,
                                 b.create<CopySignOp>(zero, rhsImag), rhsImag);
    // Only recompute when the naive result was NaN in both parts AND one of
    // the three cases applied.
    recalc = b.create<OrOp>(recalc, isSpecialCase);
    recalc = b.create<AndOp>(isNan, recalc);

    // Recalculate real part.
    lhsRealTimesRhsReal = b.create<MulFOp>(lhsReal, rhsReal);
    lhsImagTimesRhsImag = b.create<MulFOp>(lhsImag, rhsImag);
    Value newReal = b.create<SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);
    real = b.create<SelectOp>(recalc, b.create<MulFOp>(inf, newReal), real);

    // Recalculate imag part.
    lhsImagTimesRhsReal = b.create<MulFOp>(lhsImag, rhsReal);
    lhsRealTimesRhsImag = b.create<MulFOp>(lhsReal, rhsImag);
    Value newImag = b.create<AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);
    imag = b.create<SelectOp>(recalc, b.create<MulFOp>(inf, newImag), imag);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
    return success();
  }
};
499 
500 struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
501   using OpConversionPattern<complex::NegOp>::OpConversionPattern;
502 
503   LogicalResult
504   matchAndRewrite(complex::NegOp op, ArrayRef<Value> operands,
505                   ConversionPatternRewriter &rewriter) const override {
506     complex::NegOp::Adaptor transformed(operands);
507     auto loc = op.getLoc();
508     auto type = transformed.complex().getType().cast<ComplexType>();
509     auto elementType = type.getElementType().cast<FloatType>();
510 
511     Value real =
512         rewriter.create<complex::ReOp>(loc, elementType, transformed.complex());
513     Value imag =
514         rewriter.create<complex::ImOp>(loc, elementType, transformed.complex());
515     Value negReal = rewriter.create<NegFOp>(loc, real);
516     Value negImag = rewriter.create<NegFOp>(loc, imag);
517     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag);
518     return success();
519   }
520 };
521 
522 struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
523   using OpConversionPattern<complex::SignOp>::OpConversionPattern;
524 
525   LogicalResult
526   matchAndRewrite(complex::SignOp op, ArrayRef<Value> operands,
527                   ConversionPatternRewriter &rewriter) const override {
528     complex::SignOp::Adaptor transformed(operands);
529     auto type = transformed.complex().getType().cast<ComplexType>();
530     auto elementType = type.getElementType().cast<FloatType>();
531     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
532 
533     Value real = b.create<complex::ReOp>(elementType, transformed.complex());
534     Value imag = b.create<complex::ImOp>(elementType, transformed.complex());
535     Value zero = b.create<ConstantOp>(elementType, b.getZeroAttr(elementType));
536     Value realIsZero = b.create<CmpFOp>(CmpFPredicate::OEQ, real, zero);
537     Value imagIsZero = b.create<CmpFOp>(CmpFPredicate::OEQ, imag, zero);
538     Value isZero = b.create<AndOp>(realIsZero, imagIsZero);
539     auto abs = b.create<complex::AbsOp>(elementType, transformed.complex());
540     Value realSign = b.create<DivFOp>(real, abs);
541     Value imagSign = b.create<DivFOp>(imag, abs);
542     Value sign = b.create<complex::CreateOp>(type, realSign, imagSign);
543     rewriter.replaceOpWithNewOp<SelectOp>(op, isZero, transformed.complex(),
544                                           sign);
545     return success();
546   }
547 };
548 } // namespace
549 
550 void mlir::populateComplexToStandardConversionPatterns(
551     RewritePatternSet &patterns) {
552   // clang-format off
553   patterns.add<
554       AbsOpConversion,
555       ComparisonOpConversion<complex::EqualOp, CmpFPredicate::OEQ>,
556       ComparisonOpConversion<complex::NotEqualOp, CmpFPredicate::UNE>,
557       DivOpConversion,
558       ExpOpConversion,
559       LogOpConversion,
560       Log1pOpConversion,
561       MulOpConversion,
562       NegOpConversion,
563       SignOpConversion>(patterns.getContext());
564   // clang-format on
565 }
566 
567 namespace {
568 struct ConvertComplexToStandardPass
569     : public ConvertComplexToStandardBase<ConvertComplexToStandardPass> {
570   void runOnFunction() override;
571 };
572 
573 void ConvertComplexToStandardPass::runOnFunction() {
574   auto function = getFunction();
575 
576   // Convert to the Standard dialect using the converter defined above.
577   RewritePatternSet patterns(&getContext());
578   populateComplexToStandardConversionPatterns(patterns);
579 
580   ConversionTarget target(getContext());
581   target.addLegalDialect<StandardOpsDialect, math::MathDialect,
582                          complex::ComplexDialect>();
583   target.addIllegalOp<complex::AbsOp, complex::DivOp, complex::EqualOp,
584                       complex::ExpOp, complex::LogOp, complex::Log1pOp,
585                       complex::MulOp, complex::NegOp, complex::NotEqualOp,
586                       complex::SignOp>();
587   if (failed(applyPartialConversion(function, target, std::move(patterns))))
588     signalPassFailure();
589 }
590 } // namespace
591 
592 std::unique_ptr<OperationPass<FuncOp>>
593 mlir::createConvertComplexToStandardPass() {
594   return std::make_unique<ConvertComplexToStandardPass>();
595 }
596