xref: /llvm-project/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp (revision 206fad0e218e83799e49ca15545d997c6c5e8a03)
1 //===- ComplexToLLVM.cpp - conversion from Complex to LLVM dialect --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
10 
11 #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
12 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
13 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
14 #include "mlir/Conversion/LLVMCommon/Pattern.h"
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/Complex/IR/Complex.h"
17 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
18 #include "mlir/Pass/Pass.h"
19 
20 namespace mlir {
21 #define GEN_PASS_DEF_CONVERTCOMPLEXTOLLVMPASS
22 #include "mlir/Conversion/Passes.h.inc"
23 } // namespace mlir
24 
25 using namespace mlir;
26 using namespace mlir::LLVM;
27 using namespace mlir::arith;
28 
29 //===----------------------------------------------------------------------===//
30 // ComplexStructBuilder implementation.
31 //===----------------------------------------------------------------------===//
32 
// Field indices used by ComplexStructBuilder when reading/writing the real and
// imaginary parts of the struct value it wraps.
namespace {
constexpr unsigned kRealPosInComplexNumberStruct = 0;
constexpr unsigned kImaginaryPosInComplexNumberStruct = 1;
} // namespace
35 
36 ComplexStructBuilder ComplexStructBuilder::undef(OpBuilder &builder,
37                                                  Location loc, Type type) {
38   Value val = builder.create<LLVM::UndefOp>(loc, type);
39   return ComplexStructBuilder(val);
40 }
41 
42 void ComplexStructBuilder::setReal(OpBuilder &builder, Location loc,
43                                    Value real) {
44   setPtr(builder, loc, kRealPosInComplexNumberStruct, real);
45 }
46 
47 Value ComplexStructBuilder::real(OpBuilder &builder, Location loc) {
48   return extractPtr(builder, loc, kRealPosInComplexNumberStruct);
49 }
50 
51 void ComplexStructBuilder::setImaginary(OpBuilder &builder, Location loc,
52                                         Value imaginary) {
53   setPtr(builder, loc, kImaginaryPosInComplexNumberStruct, imaginary);
54 }
55 
56 Value ComplexStructBuilder::imaginary(OpBuilder &builder, Location loc) {
57   return extractPtr(builder, loc, kImaginaryPosInComplexNumberStruct);
58 }
59 
60 //===----------------------------------------------------------------------===//
61 // Conversion patterns.
62 //===----------------------------------------------------------------------===//
63 
64 namespace {
65 
66 struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> {
67   using ConvertOpToLLVMPattern<complex::AbsOp>::ConvertOpToLLVMPattern;
68 
69   LogicalResult
70   matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
71                   ConversionPatternRewriter &rewriter) const override {
72     auto loc = op.getLoc();
73 
74     ComplexStructBuilder complexStruct(adaptor.getComplex());
75     Value real = complexStruct.real(rewriter, op.getLoc());
76     Value imag = complexStruct.imaginary(rewriter, op.getLoc());
77 
78     arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
79     LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
80         op.getContext(),
81         convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
82     Value sqNorm = rewriter.create<LLVM::FAddOp>(
83         loc, rewriter.create<LLVM::FMulOp>(loc, real, real, fmf),
84         rewriter.create<LLVM::FMulOp>(loc, imag, imag, fmf), fmf);
85 
86     rewriter.replaceOpWithNewOp<LLVM::SqrtOp>(op, sqNorm);
87     return success();
88   }
89 };
90 
91 struct ConstantOpLowering : public ConvertOpToLLVMPattern<complex::ConstantOp> {
92   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
93 
94   LogicalResult
95   matchAndRewrite(complex::ConstantOp op, OpAdaptor adaptor,
96                   ConversionPatternRewriter &rewriter) const override {
97     return LLVM::detail::oneToOneRewrite(
98         op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(),
99         op->getAttrs(), *getTypeConverter(), rewriter);
100   }
101 };
102 
103 struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> {
104   using ConvertOpToLLVMPattern<complex::CreateOp>::ConvertOpToLLVMPattern;
105 
106   LogicalResult
107   matchAndRewrite(complex::CreateOp complexOp, OpAdaptor adaptor,
108                   ConversionPatternRewriter &rewriter) const override {
109     // Pack real and imaginary part in a complex number struct.
110     auto loc = complexOp.getLoc();
111     auto structType = typeConverter->convertType(complexOp.getType());
112     auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType);
113     complexStruct.setReal(rewriter, loc, adaptor.getReal());
114     complexStruct.setImaginary(rewriter, loc, adaptor.getImaginary());
115 
116     rewriter.replaceOp(complexOp, {complexStruct});
117     return success();
118   }
119 };
120 
121 struct ReOpConversion : public ConvertOpToLLVMPattern<complex::ReOp> {
122   using ConvertOpToLLVMPattern<complex::ReOp>::ConvertOpToLLVMPattern;
123 
124   LogicalResult
125   matchAndRewrite(complex::ReOp op, OpAdaptor adaptor,
126                   ConversionPatternRewriter &rewriter) const override {
127     // Extract real part from the complex number struct.
128     ComplexStructBuilder complexStruct(adaptor.getComplex());
129     Value real = complexStruct.real(rewriter, op.getLoc());
130     rewriter.replaceOp(op, real);
131 
132     return success();
133   }
134 };
135 
136 struct ImOpConversion : public ConvertOpToLLVMPattern<complex::ImOp> {
137   using ConvertOpToLLVMPattern<complex::ImOp>::ConvertOpToLLVMPattern;
138 
139   LogicalResult
140   matchAndRewrite(complex::ImOp op, OpAdaptor adaptor,
141                   ConversionPatternRewriter &rewriter) const override {
142     // Extract imaginary part from the complex number struct.
143     ComplexStructBuilder complexStruct(adaptor.getComplex());
144     Value imaginary = complexStruct.imaginary(rewriter, op.getLoc());
145     rewriter.replaceOp(op, imaginary);
146 
147     return success();
148   }
149 };
150 
151 struct BinaryComplexOperands {
152   std::complex<Value> lhs;
153   std::complex<Value> rhs;
154 };
155 
156 template <typename OpTy>
157 BinaryComplexOperands
158 unpackBinaryComplexOperands(OpTy op, typename OpTy::Adaptor adaptor,
159                             ConversionPatternRewriter &rewriter) {
160   auto loc = op.getLoc();
161 
162   // Extract real and imaginary values from operands.
163   BinaryComplexOperands unpacked;
164   ComplexStructBuilder lhs(adaptor.getLhs());
165   unpacked.lhs.real(lhs.real(rewriter, loc));
166   unpacked.lhs.imag(lhs.imaginary(rewriter, loc));
167   ComplexStructBuilder rhs(adaptor.getRhs());
168   unpacked.rhs.real(rhs.real(rewriter, loc));
169   unpacked.rhs.imag(rhs.imaginary(rewriter, loc));
170 
171   return unpacked;
172 }
173 
174 struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> {
175   using ConvertOpToLLVMPattern<complex::AddOp>::ConvertOpToLLVMPattern;
176 
177   LogicalResult
178   matchAndRewrite(complex::AddOp op, OpAdaptor adaptor,
179                   ConversionPatternRewriter &rewriter) const override {
180     auto loc = op.getLoc();
181     BinaryComplexOperands arg =
182         unpackBinaryComplexOperands<complex::AddOp>(op, adaptor, rewriter);
183 
184     // Initialize complex number struct for result.
185     auto structType = typeConverter->convertType(op.getType());
186     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
187 
188     // Emit IR to add complex numbers.
189     arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
190     LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
191         op.getContext(),
192         convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
193     Value real =
194         rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
195     Value imag =
196         rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
197     result.setReal(rewriter, loc, real);
198     result.setImaginary(rewriter, loc, imag);
199 
200     rewriter.replaceOp(op, {result});
201     return success();
202   }
203 };
204 
205 struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> {
206   using ConvertOpToLLVMPattern<complex::DivOp>::ConvertOpToLLVMPattern;
207 
208   LogicalResult
209   matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
210                   ConversionPatternRewriter &rewriter) const override {
211     auto loc = op.getLoc();
212     BinaryComplexOperands arg =
213         unpackBinaryComplexOperands<complex::DivOp>(op, adaptor, rewriter);
214 
215     // Initialize complex number struct for result.
216     auto structType = typeConverter->convertType(op.getType());
217     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
218 
219     // Emit IR to add complex numbers.
220     arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
221     LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
222         op.getContext(),
223         convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
224     Value rhsRe = arg.rhs.real();
225     Value rhsIm = arg.rhs.imag();
226     Value lhsRe = arg.lhs.real();
227     Value lhsIm = arg.lhs.imag();
228 
229     Value rhsSqNorm = rewriter.create<LLVM::FAddOp>(
230         loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, rhsRe, fmf),
231         rewriter.create<LLVM::FMulOp>(loc, rhsIm, rhsIm, fmf), fmf);
232 
233     Value resultReal = rewriter.create<LLVM::FAddOp>(
234         loc, rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsRe, fmf),
235         rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsIm, fmf), fmf);
236 
237     Value resultImag = rewriter.create<LLVM::FSubOp>(
238         loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
239         rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
240 
241     result.setReal(
242         rewriter, loc,
243         rewriter.create<LLVM::FDivOp>(loc, resultReal, rhsSqNorm, fmf));
244     result.setImaginary(
245         rewriter, loc,
246         rewriter.create<LLVM::FDivOp>(loc, resultImag, rhsSqNorm, fmf));
247 
248     rewriter.replaceOp(op, {result});
249     return success();
250   }
251 };
252 
253 struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> {
254   using ConvertOpToLLVMPattern<complex::MulOp>::ConvertOpToLLVMPattern;
255 
256   LogicalResult
257   matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
258                   ConversionPatternRewriter &rewriter) const override {
259     auto loc = op.getLoc();
260     BinaryComplexOperands arg =
261         unpackBinaryComplexOperands<complex::MulOp>(op, adaptor, rewriter);
262 
263     // Initialize complex number struct for result.
264     auto structType = typeConverter->convertType(op.getType());
265     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
266 
267     // Emit IR to add complex numbers.
268     arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
269     LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
270         op.getContext(),
271         convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
272     Value rhsRe = arg.rhs.real();
273     Value rhsIm = arg.rhs.imag();
274     Value lhsRe = arg.lhs.real();
275     Value lhsIm = arg.lhs.imag();
276 
277     Value real = rewriter.create<LLVM::FSubOp>(
278         loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, lhsRe, fmf),
279         rewriter.create<LLVM::FMulOp>(loc, rhsIm, lhsIm, fmf), fmf);
280 
281     Value imag = rewriter.create<LLVM::FAddOp>(
282         loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
283         rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
284 
285     result.setReal(rewriter, loc, real);
286     result.setImaginary(rewriter, loc, imag);
287 
288     rewriter.replaceOp(op, {result});
289     return success();
290   }
291 };
292 
293 struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> {
294   using ConvertOpToLLVMPattern<complex::SubOp>::ConvertOpToLLVMPattern;
295 
296   LogicalResult
297   matchAndRewrite(complex::SubOp op, OpAdaptor adaptor,
298                   ConversionPatternRewriter &rewriter) const override {
299     auto loc = op.getLoc();
300     BinaryComplexOperands arg =
301         unpackBinaryComplexOperands<complex::SubOp>(op, adaptor, rewriter);
302 
303     // Initialize complex number struct for result.
304     auto structType = typeConverter->convertType(op.getType());
305     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
306 
307     // Emit IR to substract complex numbers.
308     arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
309     LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
310         op.getContext(),
311         convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
312     Value real =
313         rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
314     Value imag =
315         rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
316     result.setReal(rewriter, loc, real);
317     result.setImaginary(rewriter, loc, imag);
318 
319     rewriter.replaceOp(op, {result});
320     return success();
321   }
322 };
323 } // namespace
324 
325 void mlir::populateComplexToLLVMConversionPatterns(
326     const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
327   // clang-format off
328   patterns.add<
329       AbsOpConversion,
330       AddOpConversion,
331       ConstantOpLowering,
332       CreateOpConversion,
333       DivOpConversion,
334       ImOpConversion,
335       MulOpConversion,
336       ReOpConversion,
337       SubOpConversion
338     >(converter);
339   // clang-format on
340 }
341 
342 namespace {
343 struct ConvertComplexToLLVMPass
344     : public impl::ConvertComplexToLLVMPassBase<ConvertComplexToLLVMPass> {
345   using Base::Base;
346 
347   void runOnOperation() override;
348 };
349 } // namespace
350 
351 void ConvertComplexToLLVMPass::runOnOperation() {
352   // Convert to the LLVM IR dialect using the converter defined above.
353   RewritePatternSet patterns(&getContext());
354   LLVMTypeConverter converter(&getContext());
355   populateComplexToLLVMConversionPatterns(converter, patterns);
356 
357   LLVMConversionTarget target(getContext());
358   target.addIllegalDialect<complex::ComplexDialect>();
359   if (failed(
360           applyPartialConversion(getOperation(), target, std::move(patterns))))
361     signalPassFailure();
362 }
363 
364 //===----------------------------------------------------------------------===//
365 // ConvertToLLVMPatternInterface implementation
366 //===----------------------------------------------------------------------===//
367 
368 namespace {
369 /// Implement the interface to convert MemRef to LLVM.
370 struct ComplexToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
371   using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
372   void loadDependentDialects(MLIRContext *context) const final {
373     context->loadDialect<LLVM::LLVMDialect>();
374   }
375 
376   /// Hook for derived dialect interface to provide conversion patterns
377   /// and mark dialect legal for the conversion target.
378   void populateConvertToLLVMConversionPatterns(
379       ConversionTarget &target, LLVMTypeConverter &typeConverter,
380       RewritePatternSet &patterns) const final {
381     populateComplexToLLVMConversionPatterns(typeConverter, patterns);
382   }
383 };
384 } // namespace
385 
386 void mlir::registerConvertComplexToLLVMInterface(DialectRegistry &registry) {
387   registry.addExtension(
388       +[](MLIRContext *ctx, complex::ComplexDialect *dialect) {
389         dialect->addInterfaces<ComplexToLLVMDialectInterface>();
390       });
391 }
392