//===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include #include namespace mlir { #define GEN_PASS_DEF_CONVERTCOMPLEXTOSTANDARD #include "mlir/Conversion/Passes.h.inc" } // namespace mlir using namespace mlir; namespace { enum class AbsFn { abs, sqrt, rsqrt }; // Returns the absolute value, its square root or its reciprocal square root. Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf, ImplicitLocOpBuilder &b, AbsFn fn = AbsFn::abs) { Value one = b.create(real.getType(), b.getFloatAttr(real.getType(), 1.0)); Value absReal = b.create(real, fmf); Value absImag = b.create(imag, fmf); Value max = b.create(absReal, absImag, fmf); Value min = b.create(absReal, absImag, fmf); // The lowering below requires NaNs and infinities to work correctly. 
arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear( fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf); Value ratio = b.create(min, max, fmfWithNaNInf); Value ratioSq = b.create(ratio, ratio, fmfWithNaNInf); Value ratioSqPlusOne = b.create(ratioSq, one, fmfWithNaNInf); Value result; if (fn == AbsFn::rsqrt) { ratioSqPlusOne = b.create(ratioSqPlusOne, fmfWithNaNInf); min = b.create(min, fmfWithNaNInf); max = b.create(max, fmfWithNaNInf); } if (fn == AbsFn::sqrt) { Value quarter = b.create( real.getType(), b.getFloatAttr(real.getType(), 0.25)); // sqrt(sqrt(a*b)) would avoid the pow, but will overflow more easily. Value sqrt = b.create(max, fmfWithNaNInf); Value p025 = b.create(ratioSqPlusOne, quarter, fmfWithNaNInf); result = b.create(sqrt, p025, fmfWithNaNInf); } else { Value sqrt = b.create(ratioSqPlusOne, fmfWithNaNInf); result = b.create(max, sqrt, fmfWithNaNInf); } Value isNaN = b.create(arith::CmpFPredicate::UNO, result, result, fmfWithNaNInf); return b.create(isNaN, min, result); } struct AbsOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { ImplicitLocOpBuilder b(op.getLoc(), rewriter); arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue(); Value real = b.create(adaptor.getComplex()); Value imag = b.create(adaptor.getComplex()); rewriter.replaceOp(op, computeAbs(real, imag, fmf, b)); return success(); } }; // atan2(y,x) = -i * log((x + i * y)/sqrt(x**2+y**2)) struct Atan2OpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); auto type = cast(op.getType()); Type elementType = type.getElementType(); arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); Value 
lhs = adaptor.getLhs(); Value rhs = adaptor.getRhs(); Value rhsSquared = b.create(type, rhs, rhs, fmf); Value lhsSquared = b.create(type, lhs, lhs, fmf); Value rhsSquaredPlusLhsSquared = b.create(type, rhsSquared, lhsSquared, fmf); Value sqrtOfRhsSquaredPlusLhsSquared = b.create(type, rhsSquaredPlusLhsSquared, fmf); Value zero = b.create(elementType, b.getZeroAttr(elementType)); Value one = b.create(elementType, b.getFloatAttr(elementType, 1)); Value i = b.create(type, zero, one); Value iTimesLhs = b.create(i, lhs, fmf); Value rhsPlusILhs = b.create(rhs, iTimesLhs, fmf); Value divResult = b.create( rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared, fmf); Value logResult = b.create(divResult, fmf); Value negativeOne = b.create( elementType, b.getFloatAttr(elementType, -1)); Value negativeI = b.create(type, zero, negativeOne); rewriter.replaceOpWithNewOp(op, negativeI, logResult, fmf); return success(); } }; template struct ComparisonOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; using ResultCombiner = std::conditional_t::value, arith::AndIOp, arith::OrIOp>; LogicalResult matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto type = cast(adaptor.getLhs().getType()).getElementType(); Value realLhs = rewriter.create(loc, type, adaptor.getLhs()); Value imagLhs = rewriter.create(loc, type, adaptor.getLhs()); Value realRhs = rewriter.create(loc, type, adaptor.getRhs()); Value imagRhs = rewriter.create(loc, type, adaptor.getRhs()); Value realComparison = rewriter.create(loc, p, realLhs, realRhs); Value imagComparison = rewriter.create(loc, p, imagLhs, imagRhs); rewriter.replaceOpWithNewOp(op, realComparison, imagComparison); return success(); } }; // Default conversion which applies the BinaryStandardOp separately on the real // and imaginary parts. Can for example be used for complex::AddOp and // complex::SubOp. 
template struct BinaryComplexOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto type = cast(adaptor.getLhs().getType()); auto elementType = cast(type.getElementType()); mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); Value realLhs = b.create(elementType, adaptor.getLhs()); Value realRhs = b.create(elementType, adaptor.getRhs()); Value resultReal = b.create(elementType, realLhs, realRhs, fmf.getValue()); Value imagLhs = b.create(elementType, adaptor.getLhs()); Value imagRhs = b.create(elementType, adaptor.getRhs()); Value resultImag = b.create(elementType, imagLhs, imagRhs, fmf.getValue()); rewriter.replaceOpWithNewOp(op, type, resultReal, resultImag); return success(); } }; template struct TrigonometricOpConversion : public OpConversionPattern { using OpAdaptor = typename OpConversionPattern::OpAdaptor; using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto type = cast(adaptor.getComplex().getType()); auto elementType = cast(type.getElementType()); arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); Value real = rewriter.create(loc, elementType, adaptor.getComplex()); Value imag = rewriter.create(loc, elementType, adaptor.getComplex()); // Trigonometric ops use a set of common building blocks to convert to real // ops. Here we create these building blocks and call into an op-specific // implementation in the subclass to combine them. 
Value half = rewriter.create( loc, elementType, rewriter.getFloatAttr(elementType, 0.5)); Value exp = rewriter.create(loc, imag, fmf); Value scaledExp = rewriter.create(loc, half, exp, fmf); Value reciprocalExp = rewriter.create(loc, half, exp, fmf); Value sin = rewriter.create(loc, real, fmf); Value cos = rewriter.create(loc, real, fmf); auto resultPair = combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter, fmf); rewriter.replaceOpWithNewOp(op, type, resultPair.first, resultPair.second); return success(); } virtual std::pair combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin, Value cos, ConversionPatternRewriter &rewriter, arith::FastMathFlagsAttr fmf) const = 0; }; struct CosOpConversion : public TrigonometricOpConversion { using TrigonometricOpConversion::TrigonometricOpConversion; std::pair combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin, Value cos, ConversionPatternRewriter &rewriter, arith::FastMathFlagsAttr fmf) const override { // Complex cosine is defined as; // cos(x + iy) = 0.5 * (exp(i(x + iy)) + exp(-i(x + iy))) // Plugging in: // exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x)) // exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x))) // and defining t := exp(y) // We get: // Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x // Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x Value sum = rewriter.create(loc, reciprocalExp, scaledExp, fmf); Value resultReal = rewriter.create(loc, sum, cos, fmf); Value diff = rewriter.create(loc, reciprocalExp, scaledExp, fmf); Value resultImag = rewriter.create(loc, diff, sin, fmf); return {resultReal, resultImag}; } }; struct DivOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(complex::DivOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto type = cast(adaptor.getLhs().getType()); auto elementType = cast(type.getElementType()); 
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); Value lhsReal = rewriter.create(loc, elementType, adaptor.getLhs()); Value lhsImag = rewriter.create(loc, elementType, adaptor.getLhs()); Value rhsReal = rewriter.create(loc, elementType, adaptor.getRhs()); Value rhsImag = rewriter.create(loc, elementType, adaptor.getRhs()); // Smith's algorithm to divide complex numbers. It is just a bit smarter // way to compute the following formula: // (lhsReal + lhsImag * i) / (rhsReal + rhsImag * i) // = (lhsReal + lhsImag * i) (rhsReal - rhsImag * i) / // ((rhsReal + rhsImag * i)(rhsReal - rhsImag * i)) // = ((lhsReal * rhsReal + lhsImag * rhsImag) + // (lhsImag * rhsReal - lhsReal * rhsImag) * i) / ||rhs||^2 // // Depending on whether |rhsReal| < |rhsImag| we compute either // rhsRealImagRatio = rhsReal / rhsImag // rhsRealImagDenom = rhsImag + rhsReal * rhsRealImagRatio // resultReal = (lhsReal * rhsRealImagRatio + lhsImag) / rhsRealImagDenom // resultImag = (lhsImag * rhsRealImagRatio - lhsReal) / rhsRealImagDenom // // or // // rhsImagRealRatio = rhsImag / rhsReal // rhsImagRealDenom = rhsReal + rhsImag * rhsImagRealRatio // resultReal = (lhsReal + lhsImag * rhsImagRealRatio) / rhsImagRealDenom // resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom // // See https://dl.acm.org/citation.cfm?id=368661 for more details. 
Value rhsRealImagRatio = rewriter.create(loc, rhsReal, rhsImag, fmf); Value rhsRealImagDenom = rewriter.create( loc, rhsImag, rewriter.create(loc, rhsRealImagRatio, rhsReal, fmf), fmf); Value realNumerator1 = rewriter.create( loc, rewriter.create(loc, lhsReal, rhsRealImagRatio, fmf), lhsImag, fmf); Value resultReal1 = rewriter.create(loc, realNumerator1, rhsRealImagDenom, fmf); Value imagNumerator1 = rewriter.create( loc, rewriter.create(loc, lhsImag, rhsRealImagRatio, fmf), lhsReal, fmf); Value resultImag1 = rewriter.create(loc, imagNumerator1, rhsRealImagDenom, fmf); Value rhsImagRealRatio = rewriter.create(loc, rhsImag, rhsReal, fmf); Value rhsImagRealDenom = rewriter.create( loc, rhsReal, rewriter.create(loc, rhsImagRealRatio, rhsImag, fmf), fmf); Value realNumerator2 = rewriter.create( loc, lhsReal, rewriter.create(loc, lhsImag, rhsImagRealRatio, fmf), fmf); Value resultReal2 = rewriter.create(loc, realNumerator2, rhsImagRealDenom, fmf); Value imagNumerator2 = rewriter.create( loc, lhsImag, rewriter.create(loc, lhsReal, rhsImagRealRatio, fmf), fmf); Value resultImag2 = rewriter.create(loc, imagNumerator2, rhsImagRealDenom, fmf); // Consider corner cases. // Case 1. Zero denominator, numerator contains at most one NaN value. 
Value zero = rewriter.create( loc, elementType, rewriter.getZeroAttr(elementType)); Value rhsRealAbs = rewriter.create(loc, rhsReal, fmf); Value rhsRealIsZero = rewriter.create( loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero); Value rhsImagAbs = rewriter.create(loc, rhsImag, fmf); Value rhsImagIsZero = rewriter.create( loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero); Value lhsRealIsNotNaN = rewriter.create( loc, arith::CmpFPredicate::ORD, lhsReal, zero); Value lhsImagIsNotNaN = rewriter.create( loc, arith::CmpFPredicate::ORD, lhsImag, zero); Value lhsContainsNotNaNValue = rewriter.create(loc, lhsRealIsNotNaN, lhsImagIsNotNaN); Value resultIsInfinity = rewriter.create( loc, lhsContainsNotNaNValue, rewriter.create(loc, rhsRealIsZero, rhsImagIsZero)); Value inf = rewriter.create( loc, elementType, rewriter.getFloatAttr( elementType, APFloat::getInf(elementType.getFloatSemantics()))); Value infWithSignOfRhsReal = rewriter.create(loc, inf, rhsReal); Value infinityResultReal = rewriter.create(loc, infWithSignOfRhsReal, lhsReal, fmf); Value infinityResultImag = rewriter.create(loc, infWithSignOfRhsReal, lhsImag, fmf); // Case 2. Infinite numerator, finite denominator. 
Value rhsRealFinite = rewriter.create( loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf); Value rhsImagFinite = rewriter.create( loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf); Value rhsFinite = rewriter.create(loc, rhsRealFinite, rhsImagFinite); Value lhsRealAbs = rewriter.create(loc, lhsReal, fmf); Value lhsRealInfinite = rewriter.create( loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf); Value lhsImagAbs = rewriter.create(loc, lhsImag, fmf); Value lhsImagInfinite = rewriter.create( loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf); Value lhsInfinite = rewriter.create(loc, lhsRealInfinite, lhsImagInfinite); Value infNumFiniteDenom = rewriter.create(loc, lhsInfinite, rhsFinite); Value one = rewriter.create( loc, elementType, rewriter.getFloatAttr(elementType, 1)); Value lhsRealIsInfWithSign = rewriter.create( loc, rewriter.create(loc, lhsRealInfinite, one, zero), lhsReal); Value lhsImagIsInfWithSign = rewriter.create( loc, rewriter.create(loc, lhsImagInfinite, one, zero), lhsImag); Value lhsRealIsInfWithSignTimesRhsReal = rewriter.create(loc, lhsRealIsInfWithSign, rhsReal, fmf); Value lhsImagIsInfWithSignTimesRhsImag = rewriter.create(loc, lhsImagIsInfWithSign, rhsImag, fmf); Value resultReal3 = rewriter.create( loc, inf, rewriter.create(loc, lhsRealIsInfWithSignTimesRhsReal, lhsImagIsInfWithSignTimesRhsImag, fmf), fmf); Value lhsRealIsInfWithSignTimesRhsImag = rewriter.create(loc, lhsRealIsInfWithSign, rhsImag, fmf); Value lhsImagIsInfWithSignTimesRhsReal = rewriter.create(loc, lhsImagIsInfWithSign, rhsReal, fmf); Value resultImag3 = rewriter.create( loc, inf, rewriter.create(loc, lhsImagIsInfWithSignTimesRhsReal, lhsRealIsInfWithSignTimesRhsImag, fmf), fmf); // Case 3: Finite numerator, infinite denominator. 
Value lhsRealFinite = rewriter.create( loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf); Value lhsImagFinite = rewriter.create( loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf); Value lhsFinite = rewriter.create(loc, lhsRealFinite, lhsImagFinite); Value rhsRealInfinite = rewriter.create( loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf); Value rhsImagInfinite = rewriter.create( loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf); Value rhsInfinite = rewriter.create(loc, rhsRealInfinite, rhsImagInfinite); Value finiteNumInfiniteDenom = rewriter.create(loc, lhsFinite, rhsInfinite); Value rhsRealIsInfWithSign = rewriter.create( loc, rewriter.create(loc, rhsRealInfinite, one, zero), rhsReal); Value rhsImagIsInfWithSign = rewriter.create( loc, rewriter.create(loc, rhsImagInfinite, one, zero), rhsImag); Value rhsRealIsInfWithSignTimesLhsReal = rewriter.create(loc, lhsReal, rhsRealIsInfWithSign, fmf); Value rhsImagIsInfWithSignTimesLhsImag = rewriter.create(loc, lhsImag, rhsImagIsInfWithSign, fmf); Value resultReal4 = rewriter.create( loc, zero, rewriter.create(loc, rhsRealIsInfWithSignTimesLhsReal, rhsImagIsInfWithSignTimesLhsImag, fmf), fmf); Value rhsRealIsInfWithSignTimesLhsImag = rewriter.create(loc, lhsImag, rhsRealIsInfWithSign, fmf); Value rhsImagIsInfWithSignTimesLhsReal = rewriter.create(loc, lhsReal, rhsImagIsInfWithSign, fmf); Value resultImag4 = rewriter.create( loc, zero, rewriter.create(loc, rhsRealIsInfWithSignTimesLhsImag, rhsImagIsInfWithSignTimesLhsReal, fmf), fmf); Value realAbsSmallerThanImagAbs = rewriter.create( loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs); Value resultReal = rewriter.create( loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2); Value resultImag = rewriter.create( loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2); Value resultRealSpecialCase3 = rewriter.create( loc, finiteNumInfiniteDenom, resultReal4, resultReal); Value resultImagSpecialCase3 = rewriter.create( loc, finiteNumInfiniteDenom, resultImag4, 
resultImag); Value resultRealSpecialCase2 = rewriter.create( loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3); Value resultImagSpecialCase2 = rewriter.create( loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3); Value resultRealSpecialCase1 = rewriter.create( loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2); Value resultImagSpecialCase1 = rewriter.create( loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2); Value resultRealIsNaN = rewriter.create( loc, arith::CmpFPredicate::UNO, resultReal, zero); Value resultImagIsNaN = rewriter.create( loc, arith::CmpFPredicate::UNO, resultImag, zero); Value resultIsNaN = rewriter.create(loc, resultRealIsNaN, resultImagIsNaN); Value resultRealWithSpecialCases = rewriter.create( loc, resultIsNaN, resultRealSpecialCase1, resultReal); Value resultImagWithSpecialCases = rewriter.create( loc, resultIsNaN, resultImagSpecialCase1, resultImag); rewriter.replaceOpWithNewOp( op, type, resultRealWithSpecialCases, resultImagWithSpecialCases); return success(); } }; struct ExpOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto type = cast(adaptor.getComplex().getType()); auto elementType = cast(type.getElementType()); arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); Value real = rewriter.create(loc, elementType, adaptor.getComplex()); Value imag = rewriter.create(loc, elementType, adaptor.getComplex()); Value expReal = rewriter.create(loc, real, fmf.getValue()); Value cosImag = rewriter.create(loc, imag, fmf.getValue()); Value resultReal = rewriter.create(loc, expReal, cosImag, fmf.getValue()); Value sinImag = rewriter.create(loc, imag, fmf.getValue()); Value resultImag = rewriter.create(loc, expReal, sinImag, fmf.getValue()); rewriter.replaceOpWithNewOp(op, type, resultReal, resultImag); 
return success(); } }; Value evaluatePolynomial(ImplicitLocOpBuilder &b, Value arg, ArrayRef coefficients, arith::FastMathFlagsAttr fmf) { auto argType = mlir::cast(arg.getType()); Value poly = b.create(b.getFloatAttr(argType, coefficients[0])); for (unsigned i = 1; i < coefficients.size(); ++i) { poly = b.create( poly, arg, b.create(b.getFloatAttr(argType, coefficients[i])), fmf); } return poly; } struct Expm1OpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i // [handle inaccuracies when a and/or b are small] // = ((e^a - 1) * cos(b) + cos(b) - 1) + e^a*sin(b)i // = (expm1(a) * cos(b) + cosm1(b)) + e^a*sin(b)i LogicalResult matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto type = op.getType(); auto elemType = mlir::cast(type.getElementType()); arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); ImplicitLocOpBuilder b(op.getLoc(), rewriter); Value real = b.create(adaptor.getComplex()); Value imag = b.create(adaptor.getComplex()); Value zero = b.create(b.getFloatAttr(elemType, 0.0)); Value one = b.create(b.getFloatAttr(elemType, 1.0)); Value expm1Real = b.create(real, fmf); Value expReal = b.create(expm1Real, one, fmf); Value sinImag = b.create(imag, fmf); Value cosm1Imag = emitCosm1(imag, fmf, b); Value cosImag = b.create(cosm1Imag, one, fmf); Value realResult = b.create( b.create(expm1Real, cosImag, fmf), cosm1Imag, fmf); Value imagIsZero = b.create(arith::CmpFPredicate::OEQ, imag, zero, fmf.getValue()); Value imagResult = b.create( imagIsZero, zero, b.create(expReal, sinImag, fmf)); rewriter.replaceOpWithNewOp(op, type, realResult, imagResult); return success(); } private: Value emitCosm1(Value arg, arith::FastMathFlagsAttr fmf, ImplicitLocOpBuilder &b) const { auto argType = mlir::cast(arg.getType()); auto negHalf = b.create(b.getFloatAttr(argType, -0.5)); auto negOne = b.create(b.getFloatAttr(argType, -1.0)); 
// Algorithm copied from cephes cosm1. SmallVector kCoeffs{ 4.7377507964246204691685E-14, -1.1470284843425359765671E-11, 2.0876754287081521758361E-9, -2.7557319214999787979814E-7, 2.4801587301570552304991E-5, -1.3888888888888872993737E-3, 4.1666666666666666609054E-2, }; Value cos = b.create(arg, fmf); Value forLargeArg = b.create(cos, negOne, fmf); Value argPow2 = b.create(arg, arg, fmf); Value argPow4 = b.create(argPow2, argPow2, fmf); Value poly = evaluatePolynomial(b, argPow2, kCoeffs, fmf); auto forSmallArg = b.create(b.create(argPow4, poly, fmf), b.create(negHalf, argPow2, fmf)); // (pi/4)^2 is approximately 0.61685 Value piOver4Pow2 = b.create(b.getFloatAttr(argType, 0.61685)); Value cond = b.create(arith::CmpFPredicate::OGE, argPow2, piOver4Pow2, fmf.getValue()); return b.create(cond, forLargeArg, forSmallArg); } }; struct LogOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(complex::LogOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto type = cast(adaptor.getComplex().getType()); auto elementType = cast(type.getElementType()); arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); Value abs = b.create(elementType, adaptor.getComplex(), fmf.getValue()); Value resultReal = b.create(elementType, abs, fmf.getValue()); Value real = b.create(elementType, adaptor.getComplex()); Value imag = b.create(elementType, adaptor.getComplex()); Value resultImag = b.create(elementType, imag, real, fmf.getValue()); rewriter.replaceOpWithNewOp(op, type, resultReal, resultImag); return success(); } }; struct Log1pOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto type = cast(adaptor.getComplex().getType()); auto elementType = 
cast(type.getElementType());
    // NOTE(review): in this copy of the file every `<...>` template argument
    // (on `b.create`, `cast`, `replaceOpWithNewOp`, `patterns.add` entries,
    // ...) appears to be missing, and line breaks were collapsed. Where the
    // comments below name a specific op kind, it is inferred from variable
    // names and the math; confirm against the original source.
    arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    Value real = b.create(adaptor.getComplex());
    Value imag = b.create(adaptor.getComplex());

    Value half = b.create(elementType, b.getFloatAttr(elementType, 0.5));
    Value one = b.create(elementType, b.getFloatAttr(elementType, 1));

    // Magnitudes of the two components of z + 1, and their max / min.
    Value realPlusOne = b.create(real, one, fmf);
    Value absRealPlusOne = b.create(realPlusOne, fmf);
    Value absImag = b.create(imag, fmf);

    Value maxAbs = b.create(absRealPlusOne, absImag, fmf);
    Value minAbs = b.create(absRealPlusOne, absImag, fmf);

    // When realPlusOne > |imag|, the maximum is realPlusOne itself, so
    // "max - 1" is mathematically equal to `real`; selecting `real` avoids the
    // rounding error of computing (real + 1) - 1.
    Value useReal =
        b.create(arith::CmpFPredicate::OGT, realPlusOne, absImag, fmf);
    Value maxMinusOne = b.create(maxAbs, one, fmf);
    Value maxAbsOfRealPlusOneAndImagMinusOne =
        b.create(useReal, real, maxMinusOne);
    // The ratio/log sequence below relies on IEEE NaN and infinity behaviour,
    // so the nnan/ninf assumptions are cleared for it.
    arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
        fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
    Value minMaxRatio = b.create(minAbs, maxAbs, fmfWithNaNInf);
    // Real part: log|z + 1| = log(max) + 0.5 * log(1 + (min/max)^2), where
    // log(max) is taken from the "max - 1" value above (presumably log1p).
    Value logOfMaxAbsOfRealPlusOneAndImag =
        b.create(maxAbsOfRealPlusOneAndImagMinusOne, fmf);
    Value logOfSqrtPart = b.create(
        b.create(minMaxRatio, minMaxRatio, fmfWithNaNInf), fmfWithNaNInf);
    Value r = b.create(
        b.create(half, logOfSqrtPart, fmfWithNaNInf),
        logOfMaxAbsOfRealPlusOneAndImag, fmfWithNaNInf);
    // If r is NaN (unordered self-comparison), fall back to minAbs.
    Value resultReal = b.create(
        b.create(arith::CmpFPredicate::UNO, r, r, fmfWithNaNInf), minAbs, r);
    // Imaginary part: the argument of z + 1, built from (imag, realPlusOne).
    Value resultImag = b.create(imag, realPlusOne, fmf);
    rewriter.replaceOpWithNewOp(op, type, resultReal, resultImag);
    return success();
  }
};

/// Lowers complex.mul using the schoolbook formula
///   (a + bi) * (c + di) = (ac - bd) + (bc + ad)i
/// with the op's fastmath flags passed to every multiply and add/sub.
/// No additional handling of infinities or NaNs is emitted.
struct MulOpConversion : public OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    auto type = cast(adaptor.getLhs().getType());
    auto elementType = cast(type.getElementType());
    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
    auto fmfValue = fmf.getValue();

    Value lhsReal = b.create(elementType, adaptor.getLhs());
    Value lhsImag = b.create(elementType, adaptor.getLhs());
    Value rhsReal = b.create(elementType, adaptor.getRhs());
    Value rhsImag = b.create(elementType, adaptor.getRhs());

    // Real part: ac - bd.
    Value lhsRealTimesRhsReal = b.create(lhsReal, rhsReal, fmfValue);
    Value lhsImagTimesRhsImag = b.create(lhsImag, rhsImag, fmfValue);
    Value real =
        b.create(lhsRealTimesRhsReal, lhsImagTimesRhsImag, fmfValue);

    // Imaginary part: bc + ad.
    Value lhsImagTimesRhsReal = b.create(lhsImag, rhsReal, fmfValue);
    Value lhsRealTimesRhsImag = b.create(lhsReal, rhsImag, fmfValue);
    Value imag =
        b.create(lhsImagTimesRhsReal, lhsRealTimesRhsImag, fmfValue);

    rewriter.replaceOpWithNewOp(op, type, real, imag);
    return success();
  }
};

/// Lowers complex.neg by negating the real and imaginary parts separately.
/// Unlike most patterns in this file, no fastmath flags are passed on.
struct NegOpConversion : public OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::NegOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto type = cast(adaptor.getComplex().getType());
    auto elementType = cast(type.getElementType());

    Value real =
        rewriter.create(loc, elementType, adaptor.getComplex());
    Value imag =
        rewriter.create(loc, elementType, adaptor.getComplex());
    Value negReal = rewriter.create(loc, real);
    Value negImag = rewriter.create(loc, imag);
    rewriter.replaceOpWithNewOp(op, type, negReal, negImag);
    return success();
  }
};

/// Lowers complex.sin on top of the shared TrigonometricOpConversion base,
/// which supplies sin(x), cos(x) and the two exponential terms.
/// NOTE(review): per the derivation below, scaledExp and reciprocalExp are
/// presumably 0.5*exp(y) and 0.5/exp(y); confirm in the base class.
struct SinOpConversion : public TrigonometricOpConversion {
  using TrigonometricOpConversion::TrigonometricOpConversion;

  std::pair combine(Location loc, Value scaledExp,
                    Value reciprocalExp, Value sin, Value cos,
                    ConversionPatternRewriter &rewriter,
                    arith::FastMathFlagsAttr fmf) const override {
    // Complex sine is defined as:
    //   sin(x + iy) = -0.5i * (exp(i(x + iy)) - exp(-i(x + iy)))
    // Plugging in:
    //   exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
    //   exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
    // and defining t := exp(y)
    // We get:
    //   Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x
    //   Im(sin(x + iy)) = (0.5*t - 0.5/t) * cos x
    Value sum = rewriter.create(loc, scaledExp, reciprocalExp, fmf);
    Value resultReal = rewriter.create(loc, sum, sin, fmf);
    Value diff = rewriter.create(loc, scaledExp, reciprocalExp, fmf);
    Value resultImag = rewriter.create(loc, diff, cos, fmf);
    return {resultReal, resultImag};
  }
};

// The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
/// Lowers complex.sqrt in polar form:
///   sqrt(z) = sqrt(|z|) * (cos(arg(z)/2) + i*sin(arg(z)/2)),
/// where sqrt(|z|) comes from computeAbs(..., AbsFn::sqrt) and arg(z) is
/// computed from (imag, real) (presumably atan2). Extra selects patch up
/// infinite/NaN inputs unless the fastmath flags assume both nnan and ninf,
/// and a zero magnitude always produces (0, 0).
struct SqrtOpConversion : public OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    auto type = cast(op.getType());
    auto elementType = cast(type.getElementType());
    arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();

    auto cst = [&](APFloat v) {
      return b.create(elementType, b.getFloatAttr(elementType, v));
    };
    const auto &floatSemantics = elementType.getFloatSemantics();
    Value zero = cst(APFloat::getZero(floatSemantics));
    Value half = b.create(elementType, b.getFloatAttr(elementType, 0.5));

    Value real = b.create(elementType, adaptor.getComplex());
    Value imag = b.create(elementType, adaptor.getComplex());
    Value absSqrt = computeAbs(real, imag, fmf, b, AbsFn::sqrt);
    // Half of the argument of z.
    Value argArg = b.create(imag, real, fmf);
    Value sqrtArg = b.create(argArg, half, fmf);
    Value cos = b.create(sqrtArg, fmf);
    Value sin = b.create(sqrtArg, fmf);
    // sin(atan2(0, inf)) = 0, sqrt(abs(inf)) = inf, but we can't multiply
    // 0 * inf.
    Value sinIsZero =
        b.create(arith::CmpFPredicate::OEQ, sin, zero, fmf);

    Value resultReal = b.create(absSqrt, cos, fmf);
    Value resultImag = b.create(
        sinIsZero, zero, b.create(absSqrt, sin, fmf));
    // The special-value fix-ups are only needed when the flags do not already
    // promise that NaNs and infinities never occur.
    if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
                                            arith::FastMathFlags::ninf)) {
      Value inf = cst(APFloat::getInf(floatSemantics));
      Value negInf = cst(APFloat::getInf(floatSemantics, true));
      Value nan = cst(APFloat::getNaN(floatSemantics));
      Value absImag = b.create(elementType, imag, fmf);

      Value absImagIsInf =
          b.create(arith::CmpFPredicate::OEQ, absImag, inf, fmf);
      Value absImagIsNotInf =
          b.create(arith::CmpFPredicate::ONE, absImag, inf, fmf);
      Value realIsInf =
          b.create(arith::CmpFPredicate::OEQ, real, inf, fmf);
      Value realIsNegInf =
          b.create(arith::CmpFPredicate::OEQ, real, negInf, fmf);

      // Real-part overrides: 0 for real == -inf with non-infinite imag;
      // +inf for the infinite-imag / real == +inf cases.
      resultReal = b.create(
          b.create(realIsNegInf, absImagIsNotInf), zero, resultReal);
      resultReal = b.create(
          b.create(absImagIsInf, realIsInf), inf, resultReal);

      // Imaginary-part overrides: NaN when the magnitude is NaN, and an
      // infinity carrying imag's sign (names suggest copysign) for the
      // infinite-imag / real == -inf cases.
      Value imagSignInf = b.create(inf, imag, fmf);
      resultImag = b.create(
          b.create(arith::CmpFPredicate::UNO, absSqrt, absSqrt),
          nan, resultImag);
      resultImag = b.create(
          b.create(absImagIsInf, realIsNegInf), imagSignInf,
          resultImag);
    }

    // A zero magnitude always yields (0, 0), regardless of the argument.
    Value resultIsZero =
        b.create(arith::CmpFPredicate::OEQ, absSqrt, zero, fmf);
    resultReal = b.create(resultIsZero, zero, resultReal);
    resultImag = b.create(resultIsZero, zero, resultImag);

    rewriter.replaceOpWithNewOp(op, type, resultReal, resultImag);
    return success();
  }
};

/// Lowers complex.sign: z / |z| component-wise, or z itself when both
/// components compare equal to zero (so 0 maps to 0 rather than 0/0).
struct SignOpConversion : public OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::SignOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = cast(adaptor.getComplex().getType());
    auto elementType = cast(type.getElementType());
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();

    Value real = b.create(elementType, adaptor.getComplex());
    Value imag = b.create(elementType, adaptor.getComplex());
    Value zero =
        b.create(elementType, b.getZeroAttr(elementType));
    Value realIsZero =
        b.create(arith::CmpFPredicate::OEQ, real, zero);
    Value imagIsZero =
        b.create(arith::CmpFPredicate::OEQ, imag, zero);
    Value isZero = b.create(realIsZero, imagIsZero);
    auto abs = b.create(elementType, adaptor.getComplex(), fmf);
    Value realSign = b.create(real, abs, fmf);
    Value imagSign = b.create(imag, abs, fmf);
    Value sign = b.create(type, realSign, imagSign);
    rewriter.replaceOpWithNewOp(op, isZero, adaptor.getComplex(),
                                                 sign);
    return success();
  }
};

/// Lowers complex.tanh, and complex.tan through tan(z) = -i * tanh(i*z),
/// using
///   tanh(x + iy) = (sinh(2x) + i*sin(2y)) / (cosh(2x) + cos(2y))
/// with numerator and denominator both scaled by 2 and written in terms of
/// the "exp(...) - 1" values named below (presumably expm1):
///   real numerator = (e^{2x} - 1) - (e^{-2x} - 1)
///   imag numerator = 4 * cos(y) * sin(y)
///   denominator    = (e^{2x} - 1) + (e^{-2x} - 1) + 4 * cos(y)^2
/// If the exponential sum overflows to +inf, the real part is replaced by the
/// limit +-1 with the sign of x (names suggest copysign).
template
struct TanTanhOpConversion : public OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    auto loc = op.getLoc();
    auto type = cast(adaptor.getComplex().getType());
    auto elementType = cast(type.getElementType());
    arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
    const auto &floatSemantics = elementType.getFloatSemantics();

    Value real =
        b.create(loc, elementType, adaptor.getComplex());
    Value imag =
        b.create(loc, elementType, adaptor.getComplex());
    Value negOne = b.create(
        elementType, b.getFloatAttr(elementType, -1.0));

    if constexpr (std::is_same_v) {
      // tan(x+yi) = -i*tanh(-y + xi)
      std::swap(real, imag);
      real = b.create(real, negOne, fmf);
    }

    auto cst = [&](APFloat v) {
      return b.create(elementType, b.getFloatAttr(elementType, v));
    };
    Value inf = cst(APFloat::getInf(floatSemantics));
    Value four = b.create(elementType,
                          b.getFloatAttr(elementType, 4.0));
    Value twoReal = b.create(real, real, fmf);
    Value negTwoReal = b.create(negOne, twoReal, fmf);

    Value expTwoRealMinusOne = b.create(twoReal, fmf);
    Value expNegTwoRealMinusOne = b.create(negTwoReal, fmf);
    Value realNum =
        b.create(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf);

    // 4*cos(y)^2 == 2*cos(2y) + 2, which supplies the "+2" that cancels the
    // two "-1" terms in the denominator.
    Value cosImag = b.create(imag, fmf);
    Value cosImagSq = b.create(cosImag, cosImag, fmf);
    Value twoCosTwoImagPlusOne = b.create(cosImagSq, four, fmf);
    Value sinImag = b.create(imag, fmf);

    Value imagNum = b.create(
        four, b.create(cosImag, sinImag, fmf), fmf);

    Value expSumMinusTwo =
        b.create(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf);
    Value denom =
        b.create(expSumMinusTwo, twoCosTwoImagPlusOne, fmf);

    Value isInf = b.create(arith::CmpFPredicate::OEQ,
                           expSumMinusTwo, inf, fmf);
    Value realLimit = b.create(negOne, real, fmf);

    Value resultReal = b.create(
        isInf, realLimit, b.create(realNum, denom, fmf));
    Value resultImag = b.create(imagNum, denom, fmf);

    // Special-value fix-ups, skipped when NaNs and infinities are assumed
    // absent. The boolean combinators' op kinds are not visible here;
    // presumably: real part -> NaN when imagNum is NaN and |x| is finite;
    // imag part -> 0 when y == 0, or when |x| is infinite and imagNum is NaN.
    if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
                                            arith::FastMathFlags::ninf)) {
      Value absReal = b.create(real, fmf);
      Value zero = b.create(
          elementType, b.getFloatAttr(elementType, 0.0));
      Value nan = cst(APFloat::getNaN(floatSemantics));

      Value absRealIsInf =
          b.create(arith::CmpFPredicate::OEQ, absReal, inf, fmf);
      Value imagIsZero =
          b.create(arith::CmpFPredicate::OEQ, imag, zero, fmf);
      // Logical negation of absRealIsInf (combined with an i1 `true`).
      Value absRealIsNotInf = b.create(
          absRealIsInf, b.create(true, /*width=*/1));

      Value imagNumIsNaN = b.create(arith::CmpFPredicate::UNO,
                                    imagNum, imagNum, fmf);
      Value resultRealIsNaN =
          b.create(imagNumIsNaN, absRealIsNotInf);
      Value resultImagIsZero = b.create(
          imagIsZero, b.create(absRealIsInf, imagNumIsNaN));

      resultReal = b.create(resultRealIsNaN, nan, resultReal);
      resultImag =
          b.create(resultImagIsZero, zero, resultImag);
    }

    if constexpr (std::is_same_v) {
      // tan(x+yi) = -i*tanh(-y + xi)
      std::swap(resultReal, resultImag);
      resultImag = b.create(resultImag, negOne, fmf);
    }

    rewriter.replaceOpWithNewOp(op, type, resultReal,
                                                   resultImag);
    return success();
  }
};

/// Lowers complex.conj: keeps the real part and negates the imaginary part.
struct ConjOpConversion : public OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::ConjOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto type = cast(adaptor.getComplex().getType());
    auto elementType = cast(type.getElementType());
    Value real =
        rewriter.create(loc, elementType, adaptor.getComplex());
    Value imag =
        rewriter.create(loc, elementType, adaptor.getComplex());
    Value negImag = rewriter.create(loc, elementType, imag);

    rewriter.replaceOpWithNewOp(op, type, real, negImag);

    return success();
  }
};

/// Lowers lhs^rhs for lhs = a + bi and rhs = c + di (the caller passes the
/// real and imaginary parts of rhs as `c` and `d`) using
///   |lhs|^c * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
///   q = c*atan2(b,a) + d*ln|lhs|,
/// which equals (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q))
/// with q = c*atan2(b,a) + 0.5d*ln(a*a+b*b).
/// The generic result is then overridden by a chain of selects for the
/// special cases below; each later case takes precedence over earlier ones.
static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
                                 ComplexType type, Value lhs, Value c, Value d,
                                 arith::FastMathFlags fmf) {
  auto elementType = cast(type.getElementType());

  Value a = builder.create(lhs);
  Value b = builder.create(lhs);

  Value abs = builder.create(lhs, fmf);
  Value absToC = builder.create(abs, c, fmf);

  Value negD = builder.create(d, fmf);
  Value argLhs = builder.create(b, a, fmf);
  Value negDArgLhs = builder.create(negD, argLhs, fmf);
  Value expNegDArgLhs = builder.create(negDArgLhs, fmf);

  // coeff = |lhs|^c * exp(-d * arg(lhs)).
  Value coeff = builder.create(absToC, expNegDArgLhs, fmf);
  Value lnAbs = builder.create(abs, fmf);
  Value cArgLhs = builder.create(c, argLhs, fmf);
  Value dLnAbs = builder.create(d, lnAbs, fmf);
  Value q = builder.create(cArgLhs, dLnAbs, fmf);
  Value cosQ = builder.create(q, fmf);
  Value sinQ = builder.create(q, fmf);

  Value inf = builder.create(
      elementType,
      builder.getFloatAttr(elementType,
                           APFloat::getInf(elementType.getFloatSemantics())));
  Value zero = builder.create(
      elementType, builder.getFloatAttr(elementType, 0.0));
  Value one = builder.create(
      elementType, builder.getFloatAttr(elementType, 1.0));
  Value complexOne = builder.create(type, one, zero);
  Value complexZero = builder.create(type, zero, zero);
  Value complexInf = builder.create(type, inf, zero);

  // Case 0:
  // lhs == 0 and rhs == c + 0i with c >= 0: the result is 0 for c > 0, and
  // 0^0 is defined to be 1.0, see Branch Cuts for Complex Elementary
  // Functions or Much Ado About Nothing's Sign Bit, W. Kahan, Section 10.
  Value absEqZero = builder.create(arith::CmpFPredicate::OEQ,
                                   abs, zero, fmf);
  Value dEqZero =
      builder.create(arith::CmpFPredicate::OEQ, d, zero, fmf);
  Value cEqZero =
      builder.create(arith::CmpFPredicate::OEQ, c, zero, fmf);
  Value bEqZero =
      builder.create(arith::CmpFPredicate::OEQ, b, zero, fmf);

  Value zeroLeC =
      builder.create(arith::CmpFPredicate::OLE, zero, c, fmf);
  Value coeffCosQ = builder.create(coeff, cosQ, fmf);
  Value coeffSinQ = builder.create(coeff, sinQ, fmf);
  Value complexOneOrZero =
      builder.create(cEqZero, complexOne, complexZero);
  Value coeffCosSin =
      builder.create(type, coeffCosQ, coeffSinQ);
  Value cutoff0 = builder.create(
      builder.create(
          builder.create(absEqZero, dEqZero), zeroLeC),
      complexOneOrZero, coeffCosSin);

  // Case 1:
  // x^0 is defined to be 1 for any x, see
  // Branch Cuts for Complex Elementary Functions or Much Ado About
  // Nothing's Sign Bit, W. Kahan, Section 10.
  Value rhsEqZero = builder.create(cEqZero, dEqZero);
  Value cutoff1 =
      builder.create(rhsEqZero, complexOne, cutoff0);

  // Case 2:
  // 1^(c + d*i) = 1 + 0*i
  Value lhsEqOne = builder.create(
      builder.create(arith::CmpFPredicate::OEQ, a, one, fmf),
      bEqZero);
  Value cutoff2 =
      builder.create(lhsEqOne, complexOne, cutoff1);

  // Case 3:
  // inf^(c + 0*i) = inf + 0*i, c > 0
  Value lhsEqInf = builder.create(
      builder.create(arith::CmpFPredicate::OEQ, a, inf, fmf),
      bEqZero);
  Value rhsGt0 = builder.create(
      dEqZero,
      builder.create(arith::CmpFPredicate::OGT, c, zero, fmf));
  Value cutoff3 = builder.create(
      builder.create(lhsEqInf, rhsGt0), complexInf, cutoff2);

  // Case 4:
  // inf^(c + 0*i) = 0 + 0*i, c < 0
  Value rhsLt0 = builder.create(
      dEqZero,
      builder.create(arith::CmpFPredicate::OLT, c, zero, fmf));
  Value cutoff4 = builder.create(
      builder.create(lhsEqInf, rhsLt0), complexZero, cutoff3);

  return cutoff4;
}

/// Lowers complex.pow by splitting rhs into its real (c) and imaginary (d)
/// parts and delegating to powOpConversionImpl with the op's fastmath flags.
struct PowOpConversion : public OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::PowOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
    auto type = cast(adaptor.getLhs().getType());
    auto elementType = cast(type.getElementType());

    Value c = builder.create(elementType, adaptor.getRhs());
    Value d = builder.create(elementType, adaptor.getRhs());

    rewriter.replaceOp(op, {powOpConversionImpl(builder, type, adaptor.getLhs(),
                                                c, d, op.getFastmath())});
    return success();
  }
};

/// Lowers complex.rsqrt in polar form:
///   rsqrt(z) = |z|^(-1/2) * (cos(-arg(z)/2) + i*sin(-arg(z)/2)),
/// with |z|^(-1/2) from computeAbs(..., AbsFn::rsqrt). Unless the fastmath
/// flags assume both nnan and ninf, an infinite real part (or, presumably, a
/// NaN real part together with an infinite imaginary part) yields a signed
/// zero result. A zero input always yields (inf, NaN).
struct RsqrtOpConversion : public OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    auto type = cast(adaptor.getComplex().getType());
    auto elementType = cast(type.getElementType());

    arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();

    auto cst = [&](APFloat v) {
      return b.create(elementType, b.getFloatAttr(elementType, v));
    };
    const auto &floatSemantics = elementType.getFloatSemantics();
    Value zero = cst(APFloat::getZero(floatSemantics));
    Value inf = cst(APFloat::getInf(floatSemantics));
    Value negHalf = b.create(
        elementType, b.getFloatAttr(elementType, -0.5));
    Value nan = cst(APFloat::getNaN(floatSemantics));

    Value real = b.create(elementType, adaptor.getComplex());
    Value imag = b.create(elementType, adaptor.getComplex());
    Value absRsqrt = computeAbs(real, imag, fmf, b, AbsFn::rsqrt);
    // -arg(z)/2.
    Value argArg = b.create(imag, real, fmf);
    Value rsqrtArg = b.create(argArg, negHalf, fmf);
    Value cos = b.create(rsqrtArg, fmf);
    Value sin = b.create(rsqrtArg, fmf);

    Value resultReal = b.create(absRsqrt, cos, fmf);
    Value resultImag = b.create(absRsqrt, sin, fmf);

    if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
                                            arith::FastMathFlags::ninf)) {
      Value negOne = b.create(
          elementType, b.getFloatAttr(elementType, -1));

      // Zeros carrying the signs of real and -imag (names suggest copysign).
      Value realSignedZero = b.create(zero, real, fmf);
      Value imagSignedZero = b.create(zero, imag, fmf);
      Value negImagSignedZero =
          b.create(negOne, imagSignedZero, fmf);

      Value absReal = b.create(real, fmf);
      Value absImag = b.create(imag, fmf);

      Value absImagIsInf =
          b.create(arith::CmpFPredicate::OEQ, absImag, inf, fmf);
      Value realIsNan =
          b.create(arith::CmpFPredicate::UNO, real, real, fmf);
      Value realIsInf =
          b.create(arith::CmpFPredicate::OEQ, absReal, inf, fmf);
      Value inIsNanInf = b.create(absImagIsInf, realIsNan);

      Value resultIsZero = b.create(inIsNanInf, realIsInf);

      resultReal =
          b.create(resultIsZero, realSignedZero, resultReal);
      resultImag =
          b.create(resultIsZero, negImagSignedZero, resultImag);
    }

    // rsqrt(0) is overridden to (inf, NaN).
    Value isRealZero =
        b.create(arith::CmpFPredicate::OEQ, real, zero, fmf);
    Value isImagZero =
        b.create(arith::CmpFPredicate::OEQ, imag, zero, fmf);
    Value isZero = b.create(isRealZero, isImagZero);

    resultReal = b.create(isZero, inf, resultReal);
    resultImag = b.create(isZero, nan, resultImag);

    rewriter.replaceOpWithNewOp(op, type, resultReal,
                                                   resultImag);
    return success();
  }
};

/// Lowers complex.angle to a two-argument arctangent of (imag, real),
/// forwarding the op's fastmath attribute.
struct AngleOpConversion : public OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::AngleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto type = op.getType();
    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();

    Value real =
        rewriter.create(loc, type, adaptor.getComplex());
    Value imag =
        rewriter.create(loc, type, adaptor.getComplex());

    rewriter.replaceOpWithNewOp(op, imag, real, fmf);

    return success();
  }
};
} // namespace

/// Adds the lowering pattern for each supported complex op to `patterns`.
void mlir::populateComplexToStandardConversionPatterns(
    RewritePatternSet &patterns) {
  // NOTE(review): repeated names in this list (BinaryComplexOpConversion,
  // ComparisonOpConversion, TanTanhOpConversion) are presumably distinct
  // template instantiations whose `<...>` arguments are missing from this
  // copy of the file.
  // clang-format off
  patterns.add<
      AbsOpConversion,
      AngleOpConversion,
      Atan2OpConversion,
      BinaryComplexOpConversion,
      BinaryComplexOpConversion,
      ComparisonOpConversion,
      ComparisonOpConversion,
      ConjOpConversion,
      CosOpConversion,
      DivOpConversion,
      ExpOpConversion,
      Expm1OpConversion,
      Log1pOpConversion,
      LogOpConversion,
      MulOpConversion,
      NegOpConversion,
      SignOpConversion,
      SinOpConversion,
      SqrtOpConversion,
      TanTanhOpConversion,
      TanTanhOpConversion,
      PowOpConversion,
      RsqrtOpConversion
  >(patterns.getContext());
  // clang-format on
}

namespace {
/// Pass that applies the Complex-to-Standard patterns as a partial conversion.
struct ConvertComplexToStandardPass
    : public impl::ConvertComplexToStandardBase {
  void runOnOperation() override;
};

void ConvertComplexToStandardPass::runOnOperation() {
  // Lower complex ops with the patterns registered by
  // populateComplexToStandardConversionPatterns. As a partial conversion,
  // ops that are not explicitly marked illegal may remain unconverted.
  RewritePatternSet patterns(&getContext());
  populateComplexToStandardConversionPatterns(patterns);

  ConversionTarget target(getContext());
  target.addLegalDialect();
  target.addLegalOp();
  if (failed(
          applyPartialConversion(getOperation(), target, std::move(patterns))))
    signalPassFailure();
}
} // namespace

/// Creates a new instance of the Complex-to-Standard conversion pass.
std::unique_ptr mlir::createConvertComplexToStandardPass() {
  return std::make_unique();
}