1 //===- ComplexToLLVM.cpp - conversion from Complex to LLVM dialect --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" 10 11 #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h" 12 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" 13 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 14 #include "mlir/Conversion/LLVMCommon/Pattern.h" 15 #include "mlir/Dialect/Arith/IR/Arith.h" 16 #include "mlir/Dialect/Complex/IR/Complex.h" 17 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 18 #include "mlir/Pass/Pass.h" 19 20 namespace mlir { 21 #define GEN_PASS_DEF_CONVERTCOMPLEXTOLLVMPASS 22 #include "mlir/Conversion/Passes.h.inc" 23 } // namespace mlir 24 25 using namespace mlir; 26 using namespace mlir::LLVM; 27 using namespace mlir::arith; 28 29 //===----------------------------------------------------------------------===// 30 // ComplexStructBuilder implementation. 31 //===----------------------------------------------------------------------===// 32 33 static constexpr unsigned kRealPosInComplexNumberStruct = 0; 34 static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1; 35 36 ComplexStructBuilder ComplexStructBuilder::undef(OpBuilder &builder, 37 Location loc, Type type) { 38 Value val = builder.create<LLVM::UndefOp>(loc, type); 39 return ComplexStructBuilder(val); 40 } 41 42 void ComplexStructBuilder::setReal(OpBuilder &builder, Location loc, 43 Value real) { 44 setPtr(builder, loc, kRealPosInComplexNumberStruct, real); 45 } 46 47 Value ComplexStructBuilder::real(OpBuilder &builder, Location loc) { 48 return extractPtr(builder, loc, kRealPosInComplexNumberStruct); 49 } 50 51 void ComplexStructBuilder::setImaginary(OpBuilder &builder, Location loc, 52 Value imaginary) { 53 setPtr(builder, loc, kImaginaryPosInComplexNumberStruct, imaginary); 54 } 55 56 Value ComplexStructBuilder::imaginary(OpBuilder &builder, Location loc) { 57 return extractPtr(builder, loc, kImaginaryPosInComplexNumberStruct); 58 } 59 60 //===----------------------------------------------------------------------===// 61 // Conversion patterns. 62 //===----------------------------------------------------------------------===// 63 64 namespace { 65 66 struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> { 67 using ConvertOpToLLVMPattern<complex::AbsOp>::ConvertOpToLLVMPattern; 68 69 LogicalResult 70 matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor, 71 ConversionPatternRewriter &rewriter) const override { 72 auto loc = op.getLoc(); 73 74 ComplexStructBuilder complexStruct(adaptor.getComplex()); 75 Value real = complexStruct.real(rewriter, op.getLoc()); 76 Value imag = complexStruct.imaginary(rewriter, op.getLoc()); 77 78 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr(); 79 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get( 80 op.getContext(), 81 convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue())); 82 Value sqNorm = rewriter.create<LLVM::FAddOp>( 83 loc, rewriter.create<LLVM::FMulOp>(loc, real, real, fmf), 84 rewriter.create<LLVM::FMulOp>(loc, imag, imag, fmf), fmf); 85 86 rewriter.replaceOpWithNewOp<LLVM::SqrtOp>(op, sqNorm); 87 return success(); 88 } 89 }; 90 91 struct ConstantOpLowering : public ConvertOpToLLVMPattern<complex::ConstantOp> { 92 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; 93 94 LogicalResult 95 matchAndRewrite(complex::ConstantOp op, OpAdaptor adaptor, 96 ConversionPatternRewriter &rewriter) const override { 97 return LLVM::detail::oneToOneRewrite( 98 op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(), 99 op->getAttrs(), *getTypeConverter(), rewriter); 100 } 101 }; 102 103 struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> { 104 using ConvertOpToLLVMPattern<complex::CreateOp>::ConvertOpToLLVMPattern; 105 106 LogicalResult 107 matchAndRewrite(complex::CreateOp complexOp, OpAdaptor adaptor, 108 ConversionPatternRewriter &rewriter) const override { 109 // Pack real and imaginary part in a complex number struct. 110 auto loc = complexOp.getLoc(); 111 auto structType = typeConverter->convertType(complexOp.getType()); 112 auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType); 113 complexStruct.setReal(rewriter, loc, adaptor.getReal()); 114 complexStruct.setImaginary(rewriter, loc, adaptor.getImaginary()); 115 116 rewriter.replaceOp(complexOp, {complexStruct}); 117 return success(); 118 } 119 }; 120 121 struct ReOpConversion : public ConvertOpToLLVMPattern<complex::ReOp> { 122 using ConvertOpToLLVMPattern<complex::ReOp>::ConvertOpToLLVMPattern; 123 124 LogicalResult 125 matchAndRewrite(complex::ReOp op, OpAdaptor adaptor, 126 ConversionPatternRewriter &rewriter) const override { 127 // Extract real part from the complex number struct. 128 ComplexStructBuilder complexStruct(adaptor.getComplex()); 129 Value real = complexStruct.real(rewriter, op.getLoc()); 130 rewriter.replaceOp(op, real); 131 132 return success(); 133 } 134 }; 135 136 struct ImOpConversion : public ConvertOpToLLVMPattern<complex::ImOp> { 137 using ConvertOpToLLVMPattern<complex::ImOp>::ConvertOpToLLVMPattern; 138 139 LogicalResult 140 matchAndRewrite(complex::ImOp op, OpAdaptor adaptor, 141 ConversionPatternRewriter &rewriter) const override { 142 // Extract imaginary part from the complex number struct. 143 ComplexStructBuilder complexStruct(adaptor.getComplex()); 144 Value imaginary = complexStruct.imaginary(rewriter, op.getLoc()); 145 rewriter.replaceOp(op, imaginary); 146 147 return success(); 148 } 149 }; 150 151 struct BinaryComplexOperands { 152 std::complex<Value> lhs; 153 std::complex<Value> rhs; 154 }; 155 156 template <typename OpTy> 157 BinaryComplexOperands 158 unpackBinaryComplexOperands(OpTy op, typename OpTy::Adaptor adaptor, 159 ConversionPatternRewriter &rewriter) { 160 auto loc = op.getLoc(); 161 162 // Extract real and imaginary values from operands. 163 BinaryComplexOperands unpacked; 164 ComplexStructBuilder lhs(adaptor.getLhs()); 165 unpacked.lhs.real(lhs.real(rewriter, loc)); 166 unpacked.lhs.imag(lhs.imaginary(rewriter, loc)); 167 ComplexStructBuilder rhs(adaptor.getRhs()); 168 unpacked.rhs.real(rhs.real(rewriter, loc)); 169 unpacked.rhs.imag(rhs.imaginary(rewriter, loc)); 170 171 return unpacked; 172 } 173 174 struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> { 175 using ConvertOpToLLVMPattern<complex::AddOp>::ConvertOpToLLVMPattern; 176 177 LogicalResult 178 matchAndRewrite(complex::AddOp op, OpAdaptor adaptor, 179 ConversionPatternRewriter &rewriter) const override { 180 auto loc = op.getLoc(); 181 BinaryComplexOperands arg = 182 unpackBinaryComplexOperands<complex::AddOp>(op, adaptor, rewriter); 183 184 // Initialize complex number struct for result. 185 auto structType = typeConverter->convertType(op.getType()); 186 auto result = ComplexStructBuilder::undef(rewriter, loc, structType); 187 188 // Emit IR to add complex numbers. 189 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr(); 190 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get( 191 op.getContext(), 192 convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue())); 193 Value real = 194 rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf); 195 Value imag = 196 rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf); 197 result.setReal(rewriter, loc, real); 198 result.setImaginary(rewriter, loc, imag); 199 200 rewriter.replaceOp(op, {result}); 201 return success(); 202 } 203 }; 204 205 struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> { 206 using ConvertOpToLLVMPattern<complex::DivOp>::ConvertOpToLLVMPattern; 207 208 LogicalResult 209 matchAndRewrite(complex::DivOp op, OpAdaptor adaptor, 210 ConversionPatternRewriter &rewriter) const override { 211 auto loc = op.getLoc(); 212 BinaryComplexOperands arg = 213 unpackBinaryComplexOperands<complex::DivOp>(op, adaptor, rewriter); 214 215 // Initialize complex number struct for result. 216 auto structType = typeConverter->convertType(op.getType()); 217 auto result = ComplexStructBuilder::undef(rewriter, loc, structType); 218 219 // Emit IR to add complex numbers. 220 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr(); 221 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get( 222 op.getContext(), 223 convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue())); 224 Value rhsRe = arg.rhs.real(); 225 Value rhsIm = arg.rhs.imag(); 226 Value lhsRe = arg.lhs.real(); 227 Value lhsIm = arg.lhs.imag(); 228 229 Value rhsSqNorm = rewriter.create<LLVM::FAddOp>( 230 loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, rhsRe, fmf), 231 rewriter.create<LLVM::FMulOp>(loc, rhsIm, rhsIm, fmf), fmf); 232 233 Value resultReal = rewriter.create<LLVM::FAddOp>( 234 loc, rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsRe, fmf), 235 rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsIm, fmf), fmf); 236 237 Value resultImag = rewriter.create<LLVM::FSubOp>( 238 loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf), 239 rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf); 240 241 result.setReal( 242 rewriter, loc, 243 rewriter.create<LLVM::FDivOp>(loc, resultReal, rhsSqNorm, fmf)); 244 result.setImaginary( 245 rewriter, loc, 246 rewriter.create<LLVM::FDivOp>(loc, resultImag, rhsSqNorm, fmf)); 247 248 rewriter.replaceOp(op, {result}); 249 return success(); 250 } 251 }; 252 253 struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> { 254 using ConvertOpToLLVMPattern<complex::MulOp>::ConvertOpToLLVMPattern; 255 256 LogicalResult 257 matchAndRewrite(complex::MulOp op, OpAdaptor adaptor, 258 ConversionPatternRewriter &rewriter) const override { 259 auto loc = op.getLoc(); 260 BinaryComplexOperands arg = 261 unpackBinaryComplexOperands<complex::MulOp>(op, adaptor, rewriter); 262 263 // Initialize complex number struct for result. 264 auto structType = typeConverter->convertType(op.getType()); 265 auto result = ComplexStructBuilder::undef(rewriter, loc, structType); 266 267 // Emit IR to add complex numbers. 268 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr(); 269 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get( 270 op.getContext(), 271 convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue())); 272 Value rhsRe = arg.rhs.real(); 273 Value rhsIm = arg.rhs.imag(); 274 Value lhsRe = arg.lhs.real(); 275 Value lhsIm = arg.lhs.imag(); 276 277 Value real = rewriter.create<LLVM::FSubOp>( 278 loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, lhsRe, fmf), 279 rewriter.create<LLVM::FMulOp>(loc, rhsIm, lhsIm, fmf), fmf); 280 281 Value imag = rewriter.create<LLVM::FAddOp>( 282 loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf), 283 rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf); 284 285 result.setReal(rewriter, loc, real); 286 result.setImaginary(rewriter, loc, imag); 287 288 rewriter.replaceOp(op, {result}); 289 return success(); 290 } 291 }; 292 293 struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> { 294 using ConvertOpToLLVMPattern<complex::SubOp>::ConvertOpToLLVMPattern; 295 296 LogicalResult 297 matchAndRewrite(complex::SubOp op, OpAdaptor adaptor, 298 ConversionPatternRewriter &rewriter) const override { 299 auto loc = op.getLoc(); 300 BinaryComplexOperands arg = 301 unpackBinaryComplexOperands<complex::SubOp>(op, adaptor, rewriter); 302 303 // Initialize complex number struct for result. 304 auto structType = typeConverter->convertType(op.getType()); 305 auto result = ComplexStructBuilder::undef(rewriter, loc, structType); 306 307 // Emit IR to substract complex numbers. 308 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr(); 309 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get( 310 op.getContext(), 311 convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue())); 312 Value real = 313 rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf); 314 Value imag = 315 rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf); 316 result.setReal(rewriter, loc, real); 317 result.setImaginary(rewriter, loc, imag); 318 319 rewriter.replaceOp(op, {result}); 320 return success(); 321 } 322 }; 323 } // namespace 324 325 void mlir::populateComplexToLLVMConversionPatterns( 326 const LLVMTypeConverter &converter, RewritePatternSet &patterns) { 327 // clang-format off 328 patterns.add< 329 AbsOpConversion, 330 AddOpConversion, 331 ConstantOpLowering, 332 CreateOpConversion, 333 DivOpConversion, 334 ImOpConversion, 335 MulOpConversion, 336 ReOpConversion, 337 SubOpConversion 338 >(converter); 339 // clang-format on 340 } 341 342 namespace { 343 struct ConvertComplexToLLVMPass 344 : public impl::ConvertComplexToLLVMPassBase<ConvertComplexToLLVMPass> { 345 using Base::Base; 346 347 void runOnOperation() override; 348 }; 349 } // namespace 350 351 void ConvertComplexToLLVMPass::runOnOperation() { 352 // Convert to the LLVM IR dialect using the converter defined above. 353 RewritePatternSet patterns(&getContext()); 354 LLVMTypeConverter converter(&getContext()); 355 populateComplexToLLVMConversionPatterns(converter, patterns); 356 357 LLVMConversionTarget target(getContext()); 358 target.addIllegalDialect<complex::ComplexDialect>(); 359 if (failed( 360 applyPartialConversion(getOperation(), target, std::move(patterns)))) 361 signalPassFailure(); 362 } 363 364 //===----------------------------------------------------------------------===// 365 // ConvertToLLVMPatternInterface implementation 366 //===----------------------------------------------------------------------===// 367 368 namespace { 369 /// Implement the interface to convert MemRef to LLVM. 370 struct ComplexToLLVMDialectInterface : public ConvertToLLVMPatternInterface { 371 using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; 372 void loadDependentDialects(MLIRContext *context) const final { 373 context->loadDialect<LLVM::LLVMDialect>(); 374 } 375 376 /// Hook for derived dialect interface to provide conversion patterns 377 /// and mark dialect legal for the conversion target. 378 void populateConvertToLLVMConversionPatterns( 379 ConversionTarget &target, LLVMTypeConverter &typeConverter, 380 RewritePatternSet &patterns) const final { 381 populateComplexToLLVMConversionPatterns(typeConverter, patterns); 382 } 383 }; 384 } // namespace 385 386 void mlir::registerConvertComplexToLLVMInterface(DialectRegistry ®istry) { 387 registry.addExtension( 388 +[](MLIRContext *ctx, complex::ComplexDialect *dialect) { 389 dialect->addInterfaces<ComplexToLLVMDialectInterface>(); 390 }); 391 } 392