xref: /llvm-project/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp (revision cd4ca2d7f991177b64db9df3c17986dc8e250f1d)
1 //===- ComplexToLLVM.cpp - conversion from Complex to LLVM dialect --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
10 
11 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
12 #include "mlir/Conversion/LLVMCommon/Pattern.h"
13 #include "mlir/Dialect/Arith/IR/Arith.h"
14 #include "mlir/Dialect/Complex/IR/Complex.h"
15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16 #include "mlir/Pass/Pass.h"
17 
18 namespace mlir {
19 #define GEN_PASS_DEF_CONVERTCOMPLEXTOLLVMPASS
20 #include "mlir/Conversion/Passes.h.inc"
21 } // namespace mlir
22 
23 using namespace mlir;
24 using namespace mlir::LLVM;
25 
26 //===----------------------------------------------------------------------===//
27 // ComplexStructBuilder implementation.
28 //===----------------------------------------------------------------------===//
29 
// Field indices of the two members of the LLVM struct that carries a lowered
// complex number: slot 0 is the real part, slot 1 the imaginary part.
static constexpr unsigned kRealPosInComplexNumberStruct{0};
static constexpr unsigned kImaginaryPosInComplexNumberStruct{1};
32 
33 ComplexStructBuilder ComplexStructBuilder::undef(OpBuilder &builder,
34                                                  Location loc, Type type) {
35   Value val = builder.create<LLVM::UndefOp>(loc, type);
36   return ComplexStructBuilder(val);
37 }
38 
39 void ComplexStructBuilder::setReal(OpBuilder &builder, Location loc,
40                                    Value real) {
41   setPtr(builder, loc, kRealPosInComplexNumberStruct, real);
42 }
43 
44 Value ComplexStructBuilder::real(OpBuilder &builder, Location loc) {
45   return extractPtr(builder, loc, kRealPosInComplexNumberStruct);
46 }
47 
48 void ComplexStructBuilder::setImaginary(OpBuilder &builder, Location loc,
49                                         Value imaginary) {
50   setPtr(builder, loc, kImaginaryPosInComplexNumberStruct, imaginary);
51 }
52 
53 Value ComplexStructBuilder::imaginary(OpBuilder &builder, Location loc) {
54   return extractPtr(builder, loc, kImaginaryPosInComplexNumberStruct);
55 }
56 
57 //===----------------------------------------------------------------------===//
58 // Conversion patterns.
59 //===----------------------------------------------------------------------===//
60 
61 namespace {
62 
63 struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> {
64   using ConvertOpToLLVMPattern<complex::AbsOp>::ConvertOpToLLVMPattern;
65 
66   LogicalResult
67   matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
68                   ConversionPatternRewriter &rewriter) const override {
69     auto loc = op.getLoc();
70 
71     ComplexStructBuilder complexStruct(adaptor.getComplex());
72     Value real = complexStruct.real(rewriter, op.getLoc());
73     Value imag = complexStruct.imaginary(rewriter, op.getLoc());
74 
75     auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {});
76     Value sqNorm = rewriter.create<LLVM::FAddOp>(
77         loc, rewriter.create<LLVM::FMulOp>(loc, real, real, fmf),
78         rewriter.create<LLVM::FMulOp>(loc, imag, imag, fmf), fmf);
79 
80     rewriter.replaceOpWithNewOp<LLVM::SqrtOp>(op, sqNorm);
81     return success();
82   }
83 };
84 
85 struct ConstantOpLowering : public ConvertOpToLLVMPattern<complex::ConstantOp> {
86   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
87 
88   LogicalResult
89   matchAndRewrite(complex::ConstantOp op, OpAdaptor adaptor,
90                   ConversionPatternRewriter &rewriter) const override {
91     return LLVM::detail::oneToOneRewrite(
92         op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(),
93         op->getAttrs(), *getTypeConverter(), rewriter);
94   }
95 };
96 
97 struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> {
98   using ConvertOpToLLVMPattern<complex::CreateOp>::ConvertOpToLLVMPattern;
99 
100   LogicalResult
101   matchAndRewrite(complex::CreateOp complexOp, OpAdaptor adaptor,
102                   ConversionPatternRewriter &rewriter) const override {
103     // Pack real and imaginary part in a complex number struct.
104     auto loc = complexOp.getLoc();
105     auto structType = typeConverter->convertType(complexOp.getType());
106     auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType);
107     complexStruct.setReal(rewriter, loc, adaptor.getReal());
108     complexStruct.setImaginary(rewriter, loc, adaptor.getImaginary());
109 
110     rewriter.replaceOp(complexOp, {complexStruct});
111     return success();
112   }
113 };
114 
115 struct ReOpConversion : public ConvertOpToLLVMPattern<complex::ReOp> {
116   using ConvertOpToLLVMPattern<complex::ReOp>::ConvertOpToLLVMPattern;
117 
118   LogicalResult
119   matchAndRewrite(complex::ReOp op, OpAdaptor adaptor,
120                   ConversionPatternRewriter &rewriter) const override {
121     // Extract real part from the complex number struct.
122     ComplexStructBuilder complexStruct(adaptor.getComplex());
123     Value real = complexStruct.real(rewriter, op.getLoc());
124     rewriter.replaceOp(op, real);
125 
126     return success();
127   }
128 };
129 
130 struct ImOpConversion : public ConvertOpToLLVMPattern<complex::ImOp> {
131   using ConvertOpToLLVMPattern<complex::ImOp>::ConvertOpToLLVMPattern;
132 
133   LogicalResult
134   matchAndRewrite(complex::ImOp op, OpAdaptor adaptor,
135                   ConversionPatternRewriter &rewriter) const override {
136     // Extract imaginary part from the complex number struct.
137     ComplexStructBuilder complexStruct(adaptor.getComplex());
138     Value imaginary = complexStruct.imaginary(rewriter, op.getLoc());
139     rewriter.replaceOp(op, imaginary);
140 
141     return success();
142   }
143 };
144 
145 struct BinaryComplexOperands {
146   std::complex<Value> lhs;
147   std::complex<Value> rhs;
148 };
149 
150 template <typename OpTy>
151 BinaryComplexOperands
152 unpackBinaryComplexOperands(OpTy op, typename OpTy::Adaptor adaptor,
153                             ConversionPatternRewriter &rewriter) {
154   auto loc = op.getLoc();
155 
156   // Extract real and imaginary values from operands.
157   BinaryComplexOperands unpacked;
158   ComplexStructBuilder lhs(adaptor.getLhs());
159   unpacked.lhs.real(lhs.real(rewriter, loc));
160   unpacked.lhs.imag(lhs.imaginary(rewriter, loc));
161   ComplexStructBuilder rhs(adaptor.getRhs());
162   unpacked.rhs.real(rhs.real(rewriter, loc));
163   unpacked.rhs.imag(rhs.imaginary(rewriter, loc));
164 
165   return unpacked;
166 }
167 
168 struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> {
169   using ConvertOpToLLVMPattern<complex::AddOp>::ConvertOpToLLVMPattern;
170 
171   LogicalResult
172   matchAndRewrite(complex::AddOp op, OpAdaptor adaptor,
173                   ConversionPatternRewriter &rewriter) const override {
174     auto loc = op.getLoc();
175     BinaryComplexOperands arg =
176         unpackBinaryComplexOperands<complex::AddOp>(op, adaptor, rewriter);
177 
178     // Initialize complex number struct for result.
179     auto structType = typeConverter->convertType(op.getType());
180     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
181 
182     // Emit IR to add complex numbers.
183     auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {});
184     Value real =
185         rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
186     Value imag =
187         rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
188     result.setReal(rewriter, loc, real);
189     result.setImaginary(rewriter, loc, imag);
190 
191     rewriter.replaceOp(op, {result});
192     return success();
193   }
194 };
195 
196 struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> {
197   using ConvertOpToLLVMPattern<complex::DivOp>::ConvertOpToLLVMPattern;
198 
199   LogicalResult
200   matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
201                   ConversionPatternRewriter &rewriter) const override {
202     auto loc = op.getLoc();
203     BinaryComplexOperands arg =
204         unpackBinaryComplexOperands<complex::DivOp>(op, adaptor, rewriter);
205 
206     // Initialize complex number struct for result.
207     auto structType = typeConverter->convertType(op.getType());
208     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
209 
210     // Emit IR to add complex numbers.
211     auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {});
212     Value rhsRe = arg.rhs.real();
213     Value rhsIm = arg.rhs.imag();
214     Value lhsRe = arg.lhs.real();
215     Value lhsIm = arg.lhs.imag();
216 
217     Value rhsSqNorm = rewriter.create<LLVM::FAddOp>(
218         loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, rhsRe, fmf),
219         rewriter.create<LLVM::FMulOp>(loc, rhsIm, rhsIm, fmf), fmf);
220 
221     Value resultReal = rewriter.create<LLVM::FAddOp>(
222         loc, rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsRe, fmf),
223         rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsIm, fmf), fmf);
224 
225     Value resultImag = rewriter.create<LLVM::FSubOp>(
226         loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
227         rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
228 
229     result.setReal(
230         rewriter, loc,
231         rewriter.create<LLVM::FDivOp>(loc, resultReal, rhsSqNorm, fmf));
232     result.setImaginary(
233         rewriter, loc,
234         rewriter.create<LLVM::FDivOp>(loc, resultImag, rhsSqNorm, fmf));
235 
236     rewriter.replaceOp(op, {result});
237     return success();
238   }
239 };
240 
241 struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> {
242   using ConvertOpToLLVMPattern<complex::MulOp>::ConvertOpToLLVMPattern;
243 
244   LogicalResult
245   matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
246                   ConversionPatternRewriter &rewriter) const override {
247     auto loc = op.getLoc();
248     BinaryComplexOperands arg =
249         unpackBinaryComplexOperands<complex::MulOp>(op, adaptor, rewriter);
250 
251     // Initialize complex number struct for result.
252     auto structType = typeConverter->convertType(op.getType());
253     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
254 
255     // Emit IR to add complex numbers.
256     auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {});
257     Value rhsRe = arg.rhs.real();
258     Value rhsIm = arg.rhs.imag();
259     Value lhsRe = arg.lhs.real();
260     Value lhsIm = arg.lhs.imag();
261 
262     Value real = rewriter.create<LLVM::FSubOp>(
263         loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, lhsRe, fmf),
264         rewriter.create<LLVM::FMulOp>(loc, rhsIm, lhsIm, fmf), fmf);
265 
266     Value imag = rewriter.create<LLVM::FAddOp>(
267         loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
268         rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
269 
270     result.setReal(rewriter, loc, real);
271     result.setImaginary(rewriter, loc, imag);
272 
273     rewriter.replaceOp(op, {result});
274     return success();
275   }
276 };
277 
278 struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> {
279   using ConvertOpToLLVMPattern<complex::SubOp>::ConvertOpToLLVMPattern;
280 
281   LogicalResult
282   matchAndRewrite(complex::SubOp op, OpAdaptor adaptor,
283                   ConversionPatternRewriter &rewriter) const override {
284     auto loc = op.getLoc();
285     BinaryComplexOperands arg =
286         unpackBinaryComplexOperands<complex::SubOp>(op, adaptor, rewriter);
287 
288     // Initialize complex number struct for result.
289     auto structType = typeConverter->convertType(op.getType());
290     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
291 
292     // Emit IR to substract complex numbers.
293     auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {});
294     Value real =
295         rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
296     Value imag =
297         rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
298     result.setReal(rewriter, loc, real);
299     result.setImaginary(rewriter, loc, imag);
300 
301     rewriter.replaceOp(op, {result});
302     return success();
303   }
304 };
305 } // namespace
306 
307 void mlir::populateComplexToLLVMConversionPatterns(
308     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
309   // clang-format off
310   patterns.add<
311       AbsOpConversion,
312       AddOpConversion,
313       ConstantOpLowering,
314       CreateOpConversion,
315       DivOpConversion,
316       ImOpConversion,
317       MulOpConversion,
318       ReOpConversion,
319       SubOpConversion
320     >(converter);
321   // clang-format on
322 }
323 
324 namespace {
325 struct ConvertComplexToLLVMPass
326     : public impl::ConvertComplexToLLVMPassBase<ConvertComplexToLLVMPass> {
327   using Base::Base;
328 
329   void runOnOperation() override;
330 };
331 } // namespace
332 
333 void ConvertComplexToLLVMPass::runOnOperation() {
334   // Convert to the LLVM IR dialect using the converter defined above.
335   RewritePatternSet patterns(&getContext());
336   LLVMTypeConverter converter(&getContext());
337   populateComplexToLLVMConversionPatterns(converter, patterns);
338 
339   LLVMConversionTarget target(getContext());
340   target.addIllegalDialect<complex::ComplexDialect>();
341   if (failed(
342           applyPartialConversion(getOperation(), target, std::move(patterns))))
343     signalPassFailure();
344 }
345