1 //===- ComplexToLLVM.cpp - conversion from Complex to LLVM dialect --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" 10 11 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 12 #include "mlir/Conversion/LLVMCommon/Pattern.h" 13 #include "mlir/Dialect/Arith/IR/Arith.h" 14 #include "mlir/Dialect/Complex/IR/Complex.h" 15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 16 #include "mlir/Pass/Pass.h" 17 18 namespace mlir { 19 #define GEN_PASS_DEF_CONVERTCOMPLEXTOLLVMPASS 20 #include "mlir/Conversion/Passes.h.inc" 21 } // namespace mlir 22 23 using namespace mlir; 24 using namespace mlir::LLVM; 25 26 //===----------------------------------------------------------------------===// 27 // ComplexStructBuilder implementation. 28 //===----------------------------------------------------------------------===// 29 30 static constexpr unsigned kRealPosInComplexNumberStruct = 0; 31 static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1; 32 33 ComplexStructBuilder ComplexStructBuilder::undef(OpBuilder &builder, 34 Location loc, Type type) { 35 Value val = builder.create<LLVM::UndefOp>(loc, type); 36 return ComplexStructBuilder(val); 37 } 38 39 void ComplexStructBuilder::setReal(OpBuilder &builder, Location loc, 40 Value real) { 41 setPtr(builder, loc, kRealPosInComplexNumberStruct, real); 42 } 43 44 Value ComplexStructBuilder::real(OpBuilder &builder, Location loc) { 45 return extractPtr(builder, loc, kRealPosInComplexNumberStruct); 46 } 47 48 void ComplexStructBuilder::setImaginary(OpBuilder &builder, Location loc, 49 Value imaginary) { 50 setPtr(builder, loc, kImaginaryPosInComplexNumberStruct, imaginary); 51 } 52 53 Value ComplexStructBuilder::imaginary(OpBuilder &builder, Location loc) { 54 return extractPtr(builder, loc, kImaginaryPosInComplexNumberStruct); 55 } 56 57 //===----------------------------------------------------------------------===// 58 // Conversion patterns. 59 //===----------------------------------------------------------------------===// 60 61 namespace { 62 63 struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> { 64 using ConvertOpToLLVMPattern<complex::AbsOp>::ConvertOpToLLVMPattern; 65 66 LogicalResult 67 matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor, 68 ConversionPatternRewriter &rewriter) const override { 69 auto loc = op.getLoc(); 70 71 ComplexStructBuilder complexStruct(adaptor.getComplex()); 72 Value real = complexStruct.real(rewriter, op.getLoc()); 73 Value imag = complexStruct.imaginary(rewriter, op.getLoc()); 74 75 auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {}); 76 Value sqNorm = rewriter.create<LLVM::FAddOp>( 77 loc, rewriter.create<LLVM::FMulOp>(loc, real, real, fmf), 78 rewriter.create<LLVM::FMulOp>(loc, imag, imag, fmf), fmf); 79 80 rewriter.replaceOpWithNewOp<LLVM::SqrtOp>(op, sqNorm); 81 return success(); 82 } 83 }; 84 85 struct ConstantOpLowering : public ConvertOpToLLVMPattern<complex::ConstantOp> { 86 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; 87 88 LogicalResult 89 matchAndRewrite(complex::ConstantOp op, OpAdaptor adaptor, 90 ConversionPatternRewriter &rewriter) const override { 91 return LLVM::detail::oneToOneRewrite( 92 op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(), 93 op->getAttrs(), *getTypeConverter(), rewriter); 94 } 95 }; 96 97 struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> { 98 using ConvertOpToLLVMPattern<complex::CreateOp>::ConvertOpToLLVMPattern; 99 100 LogicalResult 101 matchAndRewrite(complex::CreateOp complexOp, OpAdaptor adaptor, 102 ConversionPatternRewriter &rewriter) const override { 103 // Pack real and imaginary part in a complex number struct. 104 auto loc = complexOp.getLoc(); 105 auto structType = typeConverter->convertType(complexOp.getType()); 106 auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType); 107 complexStruct.setReal(rewriter, loc, adaptor.getReal()); 108 complexStruct.setImaginary(rewriter, loc, adaptor.getImaginary()); 109 110 rewriter.replaceOp(complexOp, {complexStruct}); 111 return success(); 112 } 113 }; 114 115 struct ReOpConversion : public ConvertOpToLLVMPattern<complex::ReOp> { 116 using ConvertOpToLLVMPattern<complex::ReOp>::ConvertOpToLLVMPattern; 117 118 LogicalResult 119 matchAndRewrite(complex::ReOp op, OpAdaptor adaptor, 120 ConversionPatternRewriter &rewriter) const override { 121 // Extract real part from the complex number struct. 122 ComplexStructBuilder complexStruct(adaptor.getComplex()); 123 Value real = complexStruct.real(rewriter, op.getLoc()); 124 rewriter.replaceOp(op, real); 125 126 return success(); 127 } 128 }; 129 130 struct ImOpConversion : public ConvertOpToLLVMPattern<complex::ImOp> { 131 using ConvertOpToLLVMPattern<complex::ImOp>::ConvertOpToLLVMPattern; 132 133 LogicalResult 134 matchAndRewrite(complex::ImOp op, OpAdaptor adaptor, 135 ConversionPatternRewriter &rewriter) const override { 136 // Extract imaginary part from the complex number struct. 137 ComplexStructBuilder complexStruct(adaptor.getComplex()); 138 Value imaginary = complexStruct.imaginary(rewriter, op.getLoc()); 139 rewriter.replaceOp(op, imaginary); 140 141 return success(); 142 } 143 }; 144 145 struct BinaryComplexOperands { 146 std::complex<Value> lhs; 147 std::complex<Value> rhs; 148 }; 149 150 template <typename OpTy> 151 BinaryComplexOperands 152 unpackBinaryComplexOperands(OpTy op, typename OpTy::Adaptor adaptor, 153 ConversionPatternRewriter &rewriter) { 154 auto loc = op.getLoc(); 155 156 // Extract real and imaginary values from operands. 157 BinaryComplexOperands unpacked; 158 ComplexStructBuilder lhs(adaptor.getLhs()); 159 unpacked.lhs.real(lhs.real(rewriter, loc)); 160 unpacked.lhs.imag(lhs.imaginary(rewriter, loc)); 161 ComplexStructBuilder rhs(adaptor.getRhs()); 162 unpacked.rhs.real(rhs.real(rewriter, loc)); 163 unpacked.rhs.imag(rhs.imaginary(rewriter, loc)); 164 165 return unpacked; 166 } 167 168 struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> { 169 using ConvertOpToLLVMPattern<complex::AddOp>::ConvertOpToLLVMPattern; 170 171 LogicalResult 172 matchAndRewrite(complex::AddOp op, OpAdaptor adaptor, 173 ConversionPatternRewriter &rewriter) const override { 174 auto loc = op.getLoc(); 175 BinaryComplexOperands arg = 176 unpackBinaryComplexOperands<complex::AddOp>(op, adaptor, rewriter); 177 178 // Initialize complex number struct for result. 179 auto structType = typeConverter->convertType(op.getType()); 180 auto result = ComplexStructBuilder::undef(rewriter, loc, structType); 181 182 // Emit IR to add complex numbers. 183 auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {}); 184 Value real = 185 rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf); 186 Value imag = 187 rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf); 188 result.setReal(rewriter, loc, real); 189 result.setImaginary(rewriter, loc, imag); 190 191 rewriter.replaceOp(op, {result}); 192 return success(); 193 } 194 }; 195 196 struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> { 197 using ConvertOpToLLVMPattern<complex::DivOp>::ConvertOpToLLVMPattern; 198 199 LogicalResult 200 matchAndRewrite(complex::DivOp op, OpAdaptor adaptor, 201 ConversionPatternRewriter &rewriter) const override { 202 auto loc = op.getLoc(); 203 BinaryComplexOperands arg = 204 unpackBinaryComplexOperands<complex::DivOp>(op, adaptor, rewriter); 205 206 // Initialize complex number struct for result. 207 auto structType = typeConverter->convertType(op.getType()); 208 auto result = ComplexStructBuilder::undef(rewriter, loc, structType); 209 210 // Emit IR to add complex numbers. 211 auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {}); 212 Value rhsRe = arg.rhs.real(); 213 Value rhsIm = arg.rhs.imag(); 214 Value lhsRe = arg.lhs.real(); 215 Value lhsIm = arg.lhs.imag(); 216 217 Value rhsSqNorm = rewriter.create<LLVM::FAddOp>( 218 loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, rhsRe, fmf), 219 rewriter.create<LLVM::FMulOp>(loc, rhsIm, rhsIm, fmf), fmf); 220 221 Value resultReal = rewriter.create<LLVM::FAddOp>( 222 loc, rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsRe, fmf), 223 rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsIm, fmf), fmf); 224 225 Value resultImag = rewriter.create<LLVM::FSubOp>( 226 loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf), 227 rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf); 228 229 result.setReal( 230 rewriter, loc, 231 rewriter.create<LLVM::FDivOp>(loc, resultReal, rhsSqNorm, fmf)); 232 result.setImaginary( 233 rewriter, loc, 234 rewriter.create<LLVM::FDivOp>(loc, resultImag, rhsSqNorm, fmf)); 235 236 rewriter.replaceOp(op, {result}); 237 return success(); 238 } 239 }; 240 241 struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> { 242 using ConvertOpToLLVMPattern<complex::MulOp>::ConvertOpToLLVMPattern; 243 244 LogicalResult 245 matchAndRewrite(complex::MulOp op, OpAdaptor adaptor, 246 ConversionPatternRewriter &rewriter) const override { 247 auto loc = op.getLoc(); 248 BinaryComplexOperands arg = 249 unpackBinaryComplexOperands<complex::MulOp>(op, adaptor, rewriter); 250 251 // Initialize complex number struct for result. 252 auto structType = typeConverter->convertType(op.getType()); 253 auto result = ComplexStructBuilder::undef(rewriter, loc, structType); 254 255 // Emit IR to add complex numbers. 256 auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {}); 257 Value rhsRe = arg.rhs.real(); 258 Value rhsIm = arg.rhs.imag(); 259 Value lhsRe = arg.lhs.real(); 260 Value lhsIm = arg.lhs.imag(); 261 262 Value real = rewriter.create<LLVM::FSubOp>( 263 loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, lhsRe, fmf), 264 rewriter.create<LLVM::FMulOp>(loc, rhsIm, lhsIm, fmf), fmf); 265 266 Value imag = rewriter.create<LLVM::FAddOp>( 267 loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf), 268 rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf); 269 270 result.setReal(rewriter, loc, real); 271 result.setImaginary(rewriter, loc, imag); 272 273 rewriter.replaceOp(op, {result}); 274 return success(); 275 } 276 }; 277 278 struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> { 279 using ConvertOpToLLVMPattern<complex::SubOp>::ConvertOpToLLVMPattern; 280 281 LogicalResult 282 matchAndRewrite(complex::SubOp op, OpAdaptor adaptor, 283 ConversionPatternRewriter &rewriter) const override { 284 auto loc = op.getLoc(); 285 BinaryComplexOperands arg = 286 unpackBinaryComplexOperands<complex::SubOp>(op, adaptor, rewriter); 287 288 // Initialize complex number struct for result. 289 auto structType = typeConverter->convertType(op.getType()); 290 auto result = ComplexStructBuilder::undef(rewriter, loc, structType); 291 292 // Emit IR to substract complex numbers. 293 auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {}); 294 Value real = 295 rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf); 296 Value imag = 297 rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf); 298 result.setReal(rewriter, loc, real); 299 result.setImaginary(rewriter, loc, imag); 300 301 rewriter.replaceOp(op, {result}); 302 return success(); 303 } 304 }; 305 } // namespace 306 307 void mlir::populateComplexToLLVMConversionPatterns( 308 LLVMTypeConverter &converter, RewritePatternSet &patterns) { 309 // clang-format off 310 patterns.add< 311 AbsOpConversion, 312 AddOpConversion, 313 ConstantOpLowering, 314 CreateOpConversion, 315 DivOpConversion, 316 ImOpConversion, 317 MulOpConversion, 318 ReOpConversion, 319 SubOpConversion 320 >(converter); 321 // clang-format on 322 } 323 324 namespace { 325 struct ConvertComplexToLLVMPass 326 : public impl::ConvertComplexToLLVMPassBase<ConvertComplexToLLVMPass> { 327 using Base::Base; 328 329 void runOnOperation() override; 330 }; 331 } // namespace 332 333 void ConvertComplexToLLVMPass::runOnOperation() { 334 // Convert to the LLVM IR dialect using the converter defined above. 335 RewritePatternSet patterns(&getContext()); 336 LLVMTypeConverter converter(&getContext()); 337 populateComplexToLLVMConversionPatterns(converter, patterns); 338 339 LLVMConversionTarget target(getContext()); 340 target.addIllegalDialect<complex::ComplexDialect>(); 341 if (failed( 342 applyPartialConversion(getOperation(), target, std::move(patterns)))) 343 signalPassFailure(); 344 } 345