1 //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" 10 11 #include <memory> 12 #include <type_traits> 13 14 #include "../PassDetail.h" 15 #include "mlir/Dialect/Complex/IR/Complex.h" 16 #include "mlir/Dialect/Math/IR/Math.h" 17 #include "mlir/Dialect/StandardOps/IR/Ops.h" 18 #include "mlir/IR/ImplicitLocOpBuilder.h" 19 #include "mlir/IR/PatternMatch.h" 20 #include "mlir/Transforms/DialectConversion.h" 21 22 using namespace mlir; 23 24 namespace { 25 struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> { 26 using OpConversionPattern<complex::AbsOp>::OpConversionPattern; 27 28 LogicalResult 29 matchAndRewrite(complex::AbsOp op, ArrayRef<Value> operands, 30 ConversionPatternRewriter &rewriter) const override { 31 complex::AbsOp::Adaptor transformed(operands); 32 auto loc = op.getLoc(); 33 auto type = op.getType(); 34 35 Value real = 36 rewriter.create<complex::ReOp>(loc, type, transformed.complex()); 37 Value imag = 38 rewriter.create<complex::ImOp>(loc, type, transformed.complex()); 39 Value realSqr = rewriter.create<MulFOp>(loc, real, real); 40 Value imagSqr = rewriter.create<MulFOp>(loc, imag, imag); 41 Value sqNorm = rewriter.create<AddFOp>(loc, realSqr, imagSqr); 42 43 rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm); 44 return success(); 45 } 46 }; 47 48 template <typename ComparisonOp, CmpFPredicate p> 49 struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> { 50 using OpConversionPattern<ComparisonOp>::OpConversionPattern; 51 using ResultCombiner = 52 std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value, 53 AndOp, OrOp>; 54 55 LogicalResult 56 matchAndRewrite(ComparisonOp op, ArrayRef<Value> operands, 57 ConversionPatternRewriter &rewriter) const override { 58 typename ComparisonOp::Adaptor transformed(operands); 59 auto loc = op.getLoc(); 60 auto type = transformed.lhs() 61 .getType() 62 .template cast<ComplexType>() 63 .getElementType(); 64 65 Value realLhs = 66 rewriter.create<complex::ReOp>(loc, type, transformed.lhs()); 67 Value imagLhs = 68 rewriter.create<complex::ImOp>(loc, type, transformed.lhs()); 69 Value realRhs = 70 rewriter.create<complex::ReOp>(loc, type, transformed.rhs()); 71 Value imagRhs = 72 rewriter.create<complex::ImOp>(loc, type, transformed.rhs()); 73 Value realComparison = rewriter.create<CmpFOp>(loc, p, realLhs, realRhs); 74 Value imagComparison = rewriter.create<CmpFOp>(loc, p, imagLhs, imagRhs); 75 76 rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison, 77 imagComparison); 78 return success(); 79 } 80 }; 81 82 struct DivOpConversion : public OpConversionPattern<complex::DivOp> { 83 using OpConversionPattern<complex::DivOp>::OpConversionPattern; 84 85 LogicalResult 86 matchAndRewrite(complex::DivOp op, ArrayRef<Value> operands, 87 ConversionPatternRewriter &rewriter) const override { 88 complex::DivOp::Adaptor transformed(operands); 89 auto loc = op.getLoc(); 90 auto type = transformed.lhs().getType().cast<ComplexType>(); 91 auto elementType = type.getElementType().cast<FloatType>(); 92 93 Value lhsReal = 94 rewriter.create<complex::ReOp>(loc, elementType, transformed.lhs()); 95 Value lhsImag = 96 rewriter.create<complex::ImOp>(loc, elementType, transformed.lhs()); 97 Value rhsReal = 98 rewriter.create<complex::ReOp>(loc, elementType, transformed.rhs()); 99 Value rhsImag = 100 rewriter.create<complex::ImOp>(loc, elementType, transformed.rhs()); 101 102 // Smith's algorithm to divide complex numbers. It is just a bit smarter 103 // way to compute the following formula: 104 // (lhsReal + lhsImag * i) / (rhsReal + rhsImag * i) 105 // = (lhsReal + lhsImag * i) (rhsReal - rhsImag * i) / 106 // ((rhsReal + rhsImag * i)(rhsReal - rhsImag * i)) 107 // = ((lhsReal * rhsReal + lhsImag * rhsImag) + 108 // (lhsImag * rhsReal - lhsReal * rhsImag) * i) / ||rhs||^2 109 // 110 // Depending on whether |rhsReal| < |rhsImag| we compute either 111 // rhsRealImagRatio = rhsReal / rhsImag 112 // rhsRealImagDenom = rhsImag + rhsReal * rhsRealImagRatio 113 // resultReal = (lhsReal * rhsRealImagRatio + lhsImag) / rhsRealImagDenom 114 // resultImag = (lhsImag * rhsRealImagRatio - lhsReal) / rhsRealImagDenom 115 // 116 // or 117 // 118 // rhsImagRealRatio = rhsImag / rhsReal 119 // rhsImagRealDenom = rhsReal + rhsImag * rhsImagRealRatio 120 // resultReal = (lhsReal + lhsImag * rhsImagRealRatio) / rhsImagRealDenom 121 // resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom 122 // 123 // See https://dl.acm.org/citation.cfm?id=368661 for more details. 124 Value rhsRealImagRatio = rewriter.create<DivFOp>(loc, rhsReal, rhsImag); 125 Value rhsRealImagDenom = rewriter.create<AddFOp>( 126 loc, rhsImag, rewriter.create<MulFOp>(loc, rhsRealImagRatio, rhsReal)); 127 Value realNumerator1 = rewriter.create<AddFOp>( 128 loc, rewriter.create<MulFOp>(loc, lhsReal, rhsRealImagRatio), lhsImag); 129 Value resultReal1 = 130 rewriter.create<DivFOp>(loc, realNumerator1, rhsRealImagDenom); 131 Value imagNumerator1 = rewriter.create<SubFOp>( 132 loc, rewriter.create<MulFOp>(loc, lhsImag, rhsRealImagRatio), lhsReal); 133 Value resultImag1 = 134 rewriter.create<DivFOp>(loc, imagNumerator1, rhsRealImagDenom); 135 136 Value rhsImagRealRatio = rewriter.create<DivFOp>(loc, rhsImag, rhsReal); 137 Value rhsImagRealDenom = rewriter.create<AddFOp>( 138 loc, rhsReal, rewriter.create<MulFOp>(loc, rhsImagRealRatio, rhsImag)); 139 Value realNumerator2 = rewriter.create<AddFOp>( 140 loc, lhsReal, rewriter.create<MulFOp>(loc, lhsImag, rhsImagRealRatio)); 141 Value resultReal2 = 142 rewriter.create<DivFOp>(loc, realNumerator2, rhsImagRealDenom); 143 Value imagNumerator2 = rewriter.create<SubFOp>( 144 loc, lhsImag, rewriter.create<MulFOp>(loc, lhsReal, rhsImagRealRatio)); 145 Value resultImag2 = 146 rewriter.create<DivFOp>(loc, imagNumerator2, rhsImagRealDenom); 147 148 // Consider corner cases. 149 // Case 1. Zero denominator, numerator contains at most one NaN value. 150 Value zero = rewriter.create<ConstantOp>(loc, elementType, 151 rewriter.getZeroAttr(elementType)); 152 Value rhsRealAbs = rewriter.create<AbsFOp>(loc, rhsReal); 153 Value rhsRealIsZero = 154 rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, rhsRealAbs, zero); 155 Value rhsImagAbs = rewriter.create<AbsFOp>(loc, rhsImag); 156 Value rhsImagIsZero = 157 rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, rhsImagAbs, zero); 158 Value lhsRealIsNotNaN = 159 rewriter.create<CmpFOp>(loc, CmpFPredicate::ORD, lhsReal, zero); 160 Value lhsImagIsNotNaN = 161 rewriter.create<CmpFOp>(loc, CmpFPredicate::ORD, lhsImag, zero); 162 Value lhsContainsNotNaNValue = 163 rewriter.create<OrOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN); 164 Value resultIsInfinity = rewriter.create<AndOp>( 165 loc, lhsContainsNotNaNValue, 166 rewriter.create<AndOp>(loc, rhsRealIsZero, rhsImagIsZero)); 167 Value inf = rewriter.create<ConstantOp>( 168 loc, elementType, 169 rewriter.getFloatAttr( 170 elementType, APFloat::getInf(elementType.getFloatSemantics()))); 171 Value infWithSignOfRhsReal = rewriter.create<CopySignOp>(loc, inf, rhsReal); 172 Value infinityResultReal = 173 rewriter.create<MulFOp>(loc, infWithSignOfRhsReal, lhsReal); 174 Value infinityResultImag = 175 rewriter.create<MulFOp>(loc, infWithSignOfRhsReal, lhsImag); 176 177 // Case 2. Infinite numerator, finite denominator. 178 Value rhsRealFinite = 179 rewriter.create<CmpFOp>(loc, CmpFPredicate::ONE, rhsRealAbs, inf); 180 Value rhsImagFinite = 181 rewriter.create<CmpFOp>(loc, CmpFPredicate::ONE, rhsImagAbs, inf); 182 Value rhsFinite = rewriter.create<AndOp>(loc, rhsRealFinite, rhsImagFinite); 183 Value lhsRealAbs = rewriter.create<AbsFOp>(loc, lhsReal); 184 Value lhsRealInfinite = 185 rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, lhsRealAbs, inf); 186 Value lhsImagAbs = rewriter.create<AbsFOp>(loc, lhsImag); 187 Value lhsImagInfinite = 188 rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, lhsImagAbs, inf); 189 Value lhsInfinite = 190 rewriter.create<OrOp>(loc, lhsRealInfinite, lhsImagInfinite); 191 Value infNumFiniteDenom = 192 rewriter.create<AndOp>(loc, lhsInfinite, rhsFinite); 193 Value one = rewriter.create<ConstantOp>( 194 loc, elementType, rewriter.getFloatAttr(elementType, 1)); 195 Value lhsRealIsInfWithSign = rewriter.create<CopySignOp>( 196 loc, rewriter.create<SelectOp>(loc, lhsRealInfinite, one, zero), 197 lhsReal); 198 Value lhsImagIsInfWithSign = rewriter.create<CopySignOp>( 199 loc, rewriter.create<SelectOp>(loc, lhsImagInfinite, one, zero), 200 lhsImag); 201 Value lhsRealIsInfWithSignTimesRhsReal = 202 rewriter.create<MulFOp>(loc, lhsRealIsInfWithSign, rhsReal); 203 Value lhsImagIsInfWithSignTimesRhsImag = 204 rewriter.create<MulFOp>(loc, lhsImagIsInfWithSign, rhsImag); 205 Value resultReal3 = rewriter.create<MulFOp>( 206 loc, inf, 207 rewriter.create<AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal, 208 lhsImagIsInfWithSignTimesRhsImag)); 209 Value lhsRealIsInfWithSignTimesRhsImag = 210 rewriter.create<MulFOp>(loc, lhsRealIsInfWithSign, rhsImag); 211 Value lhsImagIsInfWithSignTimesRhsReal = 212 rewriter.create<MulFOp>(loc, lhsImagIsInfWithSign, rhsReal); 213 Value resultImag3 = rewriter.create<MulFOp>( 214 loc, inf, 215 rewriter.create<SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal, 216 lhsRealIsInfWithSignTimesRhsImag)); 217 218 // Case 3: Finite numerator, infinite denominator. 219 Value lhsRealFinite = 220 rewriter.create<CmpFOp>(loc, CmpFPredicate::ONE, lhsRealAbs, inf); 221 Value lhsImagFinite = 222 rewriter.create<CmpFOp>(loc, CmpFPredicate::ONE, lhsImagAbs, inf); 223 Value lhsFinite = rewriter.create<AndOp>(loc, lhsRealFinite, lhsImagFinite); 224 Value rhsRealInfinite = 225 rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, rhsRealAbs, inf); 226 Value rhsImagInfinite = 227 rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, rhsImagAbs, inf); 228 Value rhsInfinite = 229 rewriter.create<OrOp>(loc, rhsRealInfinite, rhsImagInfinite); 230 Value finiteNumInfiniteDenom = 231 rewriter.create<AndOp>(loc, lhsFinite, rhsInfinite); 232 Value rhsRealIsInfWithSign = rewriter.create<CopySignOp>( 233 loc, rewriter.create<SelectOp>(loc, rhsRealInfinite, one, zero), 234 rhsReal); 235 Value rhsImagIsInfWithSign = rewriter.create<CopySignOp>( 236 loc, rewriter.create<SelectOp>(loc, rhsImagInfinite, one, zero), 237 rhsImag); 238 Value rhsRealIsInfWithSignTimesLhsReal = 239 rewriter.create<MulFOp>(loc, lhsReal, rhsRealIsInfWithSign); 240 Value rhsImagIsInfWithSignTimesLhsImag = 241 rewriter.create<MulFOp>(loc, lhsImag, rhsImagIsInfWithSign); 242 Value resultReal4 = rewriter.create<MulFOp>( 243 loc, zero, 244 rewriter.create<AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal, 245 rhsImagIsInfWithSignTimesLhsImag)); 246 Value rhsRealIsInfWithSignTimesLhsImag = 247 rewriter.create<MulFOp>(loc, lhsImag, rhsRealIsInfWithSign); 248 Value rhsImagIsInfWithSignTimesLhsReal = 249 rewriter.create<MulFOp>(loc, lhsReal, rhsImagIsInfWithSign); 250 Value resultImag4 = rewriter.create<MulFOp>( 251 loc, zero, 252 rewriter.create<SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag, 253 rhsImagIsInfWithSignTimesLhsReal)); 254 255 Value realAbsSmallerThanImagAbs = rewriter.create<CmpFOp>( 256 loc, CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs); 257 Value resultReal = rewriter.create<SelectOp>(loc, realAbsSmallerThanImagAbs, 258 resultReal1, resultReal2); 259 Value resultImag = rewriter.create<SelectOp>(loc, realAbsSmallerThanImagAbs, 260 resultImag1, resultImag2); 261 Value resultRealSpecialCase3 = rewriter.create<SelectOp>( 262 loc, finiteNumInfiniteDenom, resultReal4, resultReal); 263 Value resultImagSpecialCase3 = rewriter.create<SelectOp>( 264 loc, finiteNumInfiniteDenom, resultImag4, resultImag); 265 Value resultRealSpecialCase2 = rewriter.create<SelectOp>( 266 loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3); 267 Value resultImagSpecialCase2 = rewriter.create<SelectOp>( 268 loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3); 269 Value resultRealSpecialCase1 = rewriter.create<SelectOp>( 270 loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2); 271 Value resultImagSpecialCase1 = rewriter.create<SelectOp>( 272 loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2); 273 274 Value resultRealIsNaN = 275 rewriter.create<CmpFOp>(loc, CmpFPredicate::UNO, resultReal, zero); 276 Value resultImagIsNaN = 277 rewriter.create<CmpFOp>(loc, CmpFPredicate::UNO, resultImag, zero); 278 Value resultIsNaN = 279 rewriter.create<AndOp>(loc, resultRealIsNaN, resultImagIsNaN); 280 Value resultRealWithSpecialCases = rewriter.create<SelectOp>( 281 loc, resultIsNaN, resultRealSpecialCase1, resultReal); 282 Value resultImagWithSpecialCases = rewriter.create<SelectOp>( 283 loc, resultIsNaN, resultImagSpecialCase1, resultImag); 284 285 rewriter.replaceOpWithNewOp<complex::CreateOp>( 286 op, type, resultRealWithSpecialCases, resultImagWithSpecialCases); 287 return success(); 288 } 289 }; 290 291 struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> { 292 using OpConversionPattern<complex::ExpOp>::OpConversionPattern; 293 294 LogicalResult 295 matchAndRewrite(complex::ExpOp op, ArrayRef<Value> operands, 296 ConversionPatternRewriter &rewriter) const override { 297 complex::ExpOp::Adaptor transformed(operands); 298 auto loc = op.getLoc(); 299 auto type = transformed.complex().getType().cast<ComplexType>(); 300 auto elementType = type.getElementType().cast<FloatType>(); 301 302 Value real = 303 rewriter.create<complex::ReOp>(loc, elementType, transformed.complex()); 304 Value imag = 305 rewriter.create<complex::ImOp>(loc, elementType, transformed.complex()); 306 Value expReal = rewriter.create<math::ExpOp>(loc, real); 307 Value cosImag = rewriter.create<math::CosOp>(loc, imag); 308 Value resultReal = rewriter.create<MulFOp>(loc, expReal, cosImag); 309 Value sinImag = rewriter.create<math::SinOp>(loc, imag); 310 Value resultImag = rewriter.create<MulFOp>(loc, expReal, sinImag); 311 312 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, 313 resultImag); 314 return success(); 315 } 316 }; 317 318 struct LogOpConversion : public OpConversionPattern<complex::LogOp> { 319 using OpConversionPattern<complex::LogOp>::OpConversionPattern; 320 321 LogicalResult 322 matchAndRewrite(complex::LogOp op, ArrayRef<Value> operands, 323 ConversionPatternRewriter &rewriter) const override { 324 complex::LogOp::Adaptor transformed(operands); 325 auto type = transformed.complex().getType().cast<ComplexType>(); 326 auto elementType = type.getElementType().cast<FloatType>(); 327 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 328 329 Value abs = b.create<complex::AbsOp>(elementType, transformed.complex()); 330 Value resultReal = b.create<math::LogOp>(elementType, abs); 331 Value real = b.create<complex::ReOp>(elementType, transformed.complex()); 332 Value imag = b.create<complex::ImOp>(elementType, transformed.complex()); 333 Value resultImag = b.create<math::Atan2Op>(elementType, imag, real); 334 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, 335 resultImag); 336 return success(); 337 } 338 }; 339 340 struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> { 341 using OpConversionPattern<complex::Log1pOp>::OpConversionPattern; 342 343 LogicalResult 344 matchAndRewrite(complex::Log1pOp op, ArrayRef<Value> operands, 345 ConversionPatternRewriter &rewriter) const override { 346 complex::Log1pOp::Adaptor transformed(operands); 347 auto type = transformed.complex().getType().cast<ComplexType>(); 348 auto elementType = type.getElementType().cast<FloatType>(); 349 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 350 351 Value real = b.create<complex::ReOp>(elementType, transformed.complex()); 352 Value imag = b.create<complex::ImOp>(elementType, transformed.complex()); 353 Value one = 354 b.create<ConstantOp>(elementType, b.getFloatAttr(elementType, 1)); 355 Value realPlusOne = b.create<AddFOp>(real, one); 356 Value newComplex = b.create<complex::CreateOp>(type, realPlusOne, imag); 357 rewriter.replaceOpWithNewOp<complex::LogOp>(op, type, newComplex); 358 return success(); 359 } 360 }; 361 362 struct MulOpConversion : public OpConversionPattern<complex::MulOp> { 363 using OpConversionPattern<complex::MulOp>::OpConversionPattern; 364 365 LogicalResult 366 matchAndRewrite(complex::MulOp op, ArrayRef<Value> operands, 367 ConversionPatternRewriter &rewriter) const override { 368 complex::MulOp::Adaptor transformed(operands); 369 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 370 auto type = transformed.lhs().getType().cast<ComplexType>(); 371 auto elementType = type.getElementType().cast<FloatType>(); 372 373 Value lhsReal = b.create<complex::ReOp>(elementType, transformed.lhs()); 374 Value lhsRealAbs = b.create<AbsFOp>(lhsReal); 375 Value lhsImag = b.create<complex::ImOp>(elementType, transformed.lhs()); 376 Value lhsImagAbs = b.create<AbsFOp>(lhsImag); 377 Value rhsReal = b.create<complex::ReOp>(elementType, transformed.rhs()); 378 Value rhsRealAbs = b.create<AbsFOp>(rhsReal); 379 Value rhsImag = b.create<complex::ImOp>(elementType, transformed.rhs()); 380 Value rhsImagAbs = b.create<AbsFOp>(rhsImag); 381 382 Value lhsRealTimesRhsReal = b.create<MulFOp>(lhsReal, rhsReal); 383 Value lhsRealTimesRhsRealAbs = b.create<AbsFOp>(lhsRealTimesRhsReal); 384 Value lhsImagTimesRhsImag = b.create<MulFOp>(lhsImag, rhsImag); 385 Value lhsImagTimesRhsImagAbs = b.create<AbsFOp>(lhsImagTimesRhsImag); 386 Value real = b.create<SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag); 387 388 Value lhsImagTimesRhsReal = b.create<MulFOp>(lhsImag, rhsReal); 389 Value lhsImagTimesRhsRealAbs = b.create<AbsFOp>(lhsImagTimesRhsReal); 390 Value lhsRealTimesRhsImag = b.create<MulFOp>(lhsReal, rhsImag); 391 Value lhsRealTimesRhsImagAbs = b.create<AbsFOp>(lhsRealTimesRhsImag); 392 Value imag = b.create<AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag); 393 394 // Handle cases where the "naive" calculation results in NaN values. 395 Value realIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, real, real); 396 Value imagIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, imag, imag); 397 Value isNan = b.create<AndOp>(realIsNan, imagIsNan); 398 399 Value inf = b.create<ConstantOp>( 400 elementType, 401 b.getFloatAttr(elementType, 402 APFloat::getInf(elementType.getFloatSemantics()))); 403 404 // Case 1. `lhsReal` or `lhsImag` are infinite. 405 Value lhsRealIsInf = b.create<CmpFOp>(CmpFPredicate::OEQ, lhsRealAbs, inf); 406 Value lhsImagIsInf = b.create<CmpFOp>(CmpFPredicate::OEQ, lhsImagAbs, inf); 407 Value lhsIsInf = b.create<OrOp>(lhsRealIsInf, lhsImagIsInf); 408 Value rhsRealIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, rhsReal, rhsReal); 409 Value rhsImagIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, rhsImag, rhsImag); 410 Value zero = b.create<ConstantOp>(elementType, b.getZeroAttr(elementType)); 411 Value one = 412 b.create<ConstantOp>(elementType, b.getFloatAttr(elementType, 1)); 413 Value lhsRealIsInfFloat = b.create<SelectOp>(lhsRealIsInf, one, zero); 414 lhsReal = b.create<SelectOp>( 415 lhsIsInf, b.create<CopySignOp>(lhsRealIsInfFloat, lhsReal), lhsReal); 416 Value lhsImagIsInfFloat = b.create<SelectOp>(lhsImagIsInf, one, zero); 417 lhsImag = b.create<SelectOp>( 418 lhsIsInf, b.create<CopySignOp>(lhsImagIsInfFloat, lhsImag), lhsImag); 419 Value lhsIsInfAndRhsRealIsNan = b.create<AndOp>(lhsIsInf, rhsRealIsNan); 420 rhsReal = b.create<SelectOp>(lhsIsInfAndRhsRealIsNan, 421 b.create<CopySignOp>(zero, rhsReal), rhsReal); 422 Value lhsIsInfAndRhsImagIsNan = b.create<AndOp>(lhsIsInf, rhsImagIsNan); 423 rhsImag = b.create<SelectOp>(lhsIsInfAndRhsImagIsNan, 424 b.create<CopySignOp>(zero, rhsImag), rhsImag); 425 426 // Case 2. `rhsReal` or `rhsImag` are infinite. 427 Value rhsRealIsInf = b.create<CmpFOp>(CmpFPredicate::OEQ, rhsRealAbs, inf); 428 Value rhsImagIsInf = b.create<CmpFOp>(CmpFPredicate::OEQ, rhsImagAbs, inf); 429 Value rhsIsInf = b.create<OrOp>(rhsRealIsInf, rhsImagIsInf); 430 Value lhsRealIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, lhsReal, lhsReal); 431 Value lhsImagIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, lhsImag, lhsImag); 432 Value rhsRealIsInfFloat = b.create<SelectOp>(rhsRealIsInf, one, zero); 433 rhsReal = b.create<SelectOp>( 434 rhsIsInf, b.create<CopySignOp>(rhsRealIsInfFloat, rhsReal), rhsReal); 435 Value rhsImagIsInfFloat = b.create<SelectOp>(rhsImagIsInf, one, zero); 436 rhsImag = b.create<SelectOp>( 437 rhsIsInf, b.create<CopySignOp>(rhsImagIsInfFloat, rhsImag), rhsImag); 438 Value rhsIsInfAndLhsRealIsNan = b.create<AndOp>(rhsIsInf, lhsRealIsNan); 439 lhsReal = b.create<SelectOp>(rhsIsInfAndLhsRealIsNan, 440 b.create<CopySignOp>(zero, lhsReal), lhsReal); 441 Value rhsIsInfAndLhsImagIsNan = b.create<AndOp>(rhsIsInf, lhsImagIsNan); 442 lhsImag = b.create<SelectOp>(rhsIsInfAndLhsImagIsNan, 443 b.create<CopySignOp>(zero, lhsImag), lhsImag); 444 Value recalc = b.create<OrOp>(lhsIsInf, rhsIsInf); 445 446 // Case 3. One of the pairwise products of left hand side with right hand 447 // side is infinite. 448 Value lhsRealTimesRhsRealIsInf = 449 b.create<CmpFOp>(CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf); 450 Value lhsImagTimesRhsImagIsInf = 451 b.create<CmpFOp>(CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf); 452 Value isSpecialCase = 453 b.create<OrOp>(lhsRealTimesRhsRealIsInf, lhsImagTimesRhsImagIsInf); 454 Value lhsRealTimesRhsImagIsInf = 455 b.create<CmpFOp>(CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf); 456 isSpecialCase = b.create<OrOp>(isSpecialCase, lhsRealTimesRhsImagIsInf); 457 Value lhsImagTimesRhsRealIsInf = 458 b.create<CmpFOp>(CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf); 459 isSpecialCase = b.create<OrOp>(isSpecialCase, lhsImagTimesRhsRealIsInf); 460 Type i1Type = b.getI1Type(); 461 Value notRecalc = b.create<XOrOp>( 462 recalc, b.create<ConstantOp>(i1Type, b.getIntegerAttr(i1Type, 1))); 463 isSpecialCase = b.create<AndOp>(isSpecialCase, notRecalc); 464 Value isSpecialCaseAndLhsRealIsNan = 465 b.create<AndOp>(isSpecialCase, lhsRealIsNan); 466 lhsReal = b.create<SelectOp>(isSpecialCaseAndLhsRealIsNan, 467 b.create<CopySignOp>(zero, lhsReal), lhsReal); 468 Value isSpecialCaseAndLhsImagIsNan = 469 b.create<AndOp>(isSpecialCase, lhsImagIsNan); 470 lhsImag = b.create<SelectOp>(isSpecialCaseAndLhsImagIsNan, 471 b.create<CopySignOp>(zero, lhsImag), lhsImag); 472 Value isSpecialCaseAndRhsRealIsNan = 473 b.create<AndOp>(isSpecialCase, rhsRealIsNan); 474 rhsReal = b.create<SelectOp>(isSpecialCaseAndRhsRealIsNan, 475 b.create<CopySignOp>(zero, rhsReal), rhsReal); 476 Value isSpecialCaseAndRhsImagIsNan = 477 b.create<AndOp>(isSpecialCase, rhsImagIsNan); 478 rhsImag = b.create<SelectOp>(isSpecialCaseAndRhsImagIsNan, 479 b.create<CopySignOp>(zero, rhsImag), rhsImag); 480 recalc = b.create<OrOp>(recalc, isSpecialCase); 481 recalc = b.create<AndOp>(isNan, recalc); 482 483 // Recalculate real part. 484 lhsRealTimesRhsReal = b.create<MulFOp>(lhsReal, rhsReal); 485 lhsImagTimesRhsImag = b.create<MulFOp>(lhsImag, rhsImag); 486 Value newReal = b.create<SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag); 487 real = b.create<SelectOp>(recalc, b.create<MulFOp>(inf, newReal), real); 488 489 // Recalculate imag part. 490 lhsImagTimesRhsReal = b.create<MulFOp>(lhsImag, rhsReal); 491 lhsRealTimesRhsImag = b.create<MulFOp>(lhsReal, rhsImag); 492 Value newImag = b.create<AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag); 493 imag = b.create<SelectOp>(recalc, b.create<MulFOp>(inf, newImag), imag); 494 495 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag); 496 return success(); 497 } 498 }; 499 500 struct NegOpConversion : public OpConversionPattern<complex::NegOp> { 501 using OpConversionPattern<complex::NegOp>::OpConversionPattern; 502 503 LogicalResult 504 matchAndRewrite(complex::NegOp op, ArrayRef<Value> operands, 505 ConversionPatternRewriter &rewriter) const override { 506 complex::NegOp::Adaptor transformed(operands); 507 auto loc = op.getLoc(); 508 auto type = transformed.complex().getType().cast<ComplexType>(); 509 auto elementType = type.getElementType().cast<FloatType>(); 510 511 Value real = 512 rewriter.create<complex::ReOp>(loc, elementType, transformed.complex()); 513 Value imag = 514 rewriter.create<complex::ImOp>(loc, elementType, transformed.complex()); 515 Value negReal = rewriter.create<NegFOp>(loc, real); 516 Value negImag = rewriter.create<NegFOp>(loc, imag); 517 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag); 518 return success(); 519 } 520 }; 521 522 struct SignOpConversion : public OpConversionPattern<complex::SignOp> { 523 using OpConversionPattern<complex::SignOp>::OpConversionPattern; 524 525 LogicalResult 526 matchAndRewrite(complex::SignOp op, ArrayRef<Value> operands, 527 ConversionPatternRewriter &rewriter) const override { 528 complex::SignOp::Adaptor transformed(operands); 529 auto type = transformed.complex().getType().cast<ComplexType>(); 530 auto elementType = type.getElementType().cast<FloatType>(); 531 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 532 533 Value real = b.create<complex::ReOp>(elementType, transformed.complex()); 534 Value imag = b.create<complex::ImOp>(elementType, transformed.complex()); 535 Value zero = b.create<ConstantOp>(elementType, b.getZeroAttr(elementType)); 536 Value realIsZero = b.create<CmpFOp>(CmpFPredicate::OEQ, real, zero); 537 Value imagIsZero = b.create<CmpFOp>(CmpFPredicate::OEQ, imag, zero); 538 Value isZero = b.create<AndOp>(realIsZero, imagIsZero); 539 auto abs = b.create<complex::AbsOp>(elementType, transformed.complex()); 540 Value realSign = b.create<DivFOp>(real, abs); 541 Value imagSign = b.create<DivFOp>(imag, abs); 542 Value sign = b.create<complex::CreateOp>(type, realSign, imagSign); 543 rewriter.replaceOpWithNewOp<SelectOp>(op, isZero, transformed.complex(), 544 sign); 545 return success(); 546 } 547 }; 548 } // namespace 549 550 void mlir::populateComplexToStandardConversionPatterns( 551 RewritePatternSet &patterns) { 552 // clang-format off 553 patterns.add< 554 AbsOpConversion, 555 ComparisonOpConversion<complex::EqualOp, CmpFPredicate::OEQ>, 556 ComparisonOpConversion<complex::NotEqualOp, CmpFPredicate::UNE>, 557 DivOpConversion, 558 ExpOpConversion, 559 LogOpConversion, 560 Log1pOpConversion, 561 MulOpConversion, 562 NegOpConversion, 563 SignOpConversion>(patterns.getContext()); 564 // clang-format on 565 } 566 567 namespace { 568 struct ConvertComplexToStandardPass 569 : public ConvertComplexToStandardBase<ConvertComplexToStandardPass> { 570 void runOnFunction() override; 571 }; 572 573 void ConvertComplexToStandardPass::runOnFunction() { 574 auto function = getFunction(); 575 576 // Convert to the Standard dialect using the converter defined above. 577 RewritePatternSet patterns(&getContext()); 578 populateComplexToStandardConversionPatterns(patterns); 579 580 ConversionTarget target(getContext()); 581 target.addLegalDialect<StandardOpsDialect, math::MathDialect, 582 complex::ComplexDialect>(); 583 target.addIllegalOp<complex::AbsOp, complex::DivOp, complex::EqualOp, 584 complex::ExpOp, complex::LogOp, complex::Log1pOp, 585 complex::MulOp, complex::NegOp, complex::NotEqualOp, 586 complex::SignOp>(); 587 if (failed(applyPartialConversion(function, target, std::move(patterns)))) 588 signalPassFailure(); 589 } 590 } // namespace 591 592 std::unique_ptr<OperationPass<FuncOp>> 593 mlir::createConvertComplexToStandardPass() { 594 return std::make_unique<ConvertComplexToStandardPass>(); 595 } 596