xref: /llvm-project/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp (revision 206fad0e218e83799e49ca15545d997c6c5e8a03)
1 //===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
10 
11 #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
12 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
13 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
14 #include "mlir/Conversion/LLVMCommon/Pattern.h"
15 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17 #include "mlir/Dialect/Math/IR/Math.h"
18 #include "mlir/IR/TypeUtilities.h"
19 #include "mlir/Pass/Pass.h"
20 
21 namespace mlir {
22 #define GEN_PASS_DEF_CONVERTMATHTOLLVMPASS
23 #include "mlir/Conversion/Passes.h.inc"
24 } // namespace mlir
25 
26 using namespace mlir;
27 
28 namespace {
29 
30 template <typename SourceOp, typename TargetOp>
31 using ConvertFastMath = arith::AttrConvertFastMathToLLVM<SourceOp, TargetOp>;
32 
33 template <typename SourceOp, typename TargetOp>
34 using ConvertFMFMathToLLVMPattern =
35     VectorConvertToLLVMPattern<SourceOp, TargetOp, ConvertFastMath>;
36 
37 using AbsFOpLowering = ConvertFMFMathToLLVMPattern<math::AbsFOp, LLVM::FAbsOp>;
38 using CeilOpLowering = ConvertFMFMathToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
39 using CopySignOpLowering =
40     ConvertFMFMathToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
41 using CosOpLowering = ConvertFMFMathToLLVMPattern<math::CosOp, LLVM::CosOp>;
42 using CtPopFOpLowering =
43     VectorConvertToLLVMPattern<math::CtPopOp, LLVM::CtPopOp>;
44 using Exp2OpLowering = ConvertFMFMathToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
45 using ExpOpLowering = ConvertFMFMathToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
46 using FloorOpLowering =
47     ConvertFMFMathToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
48 using FmaOpLowering = ConvertFMFMathToLLVMPattern<math::FmaOp, LLVM::FMAOp>;
49 using Log10OpLowering =
50     ConvertFMFMathToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
51 using Log2OpLowering = ConvertFMFMathToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
52 using LogOpLowering = ConvertFMFMathToLLVMPattern<math::LogOp, LLVM::LogOp>;
53 using PowFOpLowering = ConvertFMFMathToLLVMPattern<math::PowFOp, LLVM::PowOp>;
54 using FPowIOpLowering =
55     ConvertFMFMathToLLVMPattern<math::FPowIOp, LLVM::PowIOp>;
56 using RoundEvenOpLowering =
57     ConvertFMFMathToLLVMPattern<math::RoundEvenOp, LLVM::RoundEvenOp>;
58 using RoundOpLowering =
59     ConvertFMFMathToLLVMPattern<math::RoundOp, LLVM::RoundOp>;
60 using SinOpLowering = ConvertFMFMathToLLVMPattern<math::SinOp, LLVM::SinOp>;
61 using SqrtOpLowering = ConvertFMFMathToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
62 using FTruncOpLowering =
63     ConvertFMFMathToLLVMPattern<math::TruncOp, LLVM::FTruncOp>;
64 
65 // A `CtLz/CtTz/absi(a)` is converted into `CtLz/CtTz/absi(a, false)`.
66 template <typename MathOp, typename LLVMOp>
67 struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
68   using ConvertOpToLLVMPattern<MathOp>::ConvertOpToLLVMPattern;
69   using Super = IntOpWithFlagLowering<MathOp, LLVMOp>;
70 
71   LogicalResult
72   matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor,
73                   ConversionPatternRewriter &rewriter) const override {
74     auto operandType = adaptor.getOperand().getType();
75 
76     if (!operandType || !LLVM::isCompatibleType(operandType))
77       return failure();
78 
79     auto loc = op.getLoc();
80     auto resultType = op.getResult().getType();
81 
82     if (!isa<LLVM::LLVMArrayType>(operandType)) {
83       rewriter.replaceOpWithNewOp<LLVMOp>(op, resultType, adaptor.getOperand(),
84                                           false);
85       return success();
86     }
87 
88     auto vectorType = dyn_cast<VectorType>(resultType);
89     if (!vectorType)
90       return failure();
91 
92     return LLVM::detail::handleMultidimensionalVectors(
93         op.getOperation(), adaptor.getOperands(), *this->getTypeConverter(),
94         [&](Type llvm1DVectorTy, ValueRange operands) {
95           return rewriter.create<LLVMOp>(loc, llvm1DVectorTy, operands[0],
96                                          false);
97         },
98         rewriter);
99   }
100 };
101 
102 using CountLeadingZerosOpLowering =
103     IntOpWithFlagLowering<math::CountLeadingZerosOp, LLVM::CountLeadingZerosOp>;
104 using CountTrailingZerosOpLowering =
105     IntOpWithFlagLowering<math::CountTrailingZerosOp,
106                           LLVM::CountTrailingZerosOp>;
107 using AbsIOpLowering = IntOpWithFlagLowering<math::AbsIOp, LLVM::AbsOp>;
108 
109 // A `expm1` is converted into `exp - 1`.
110 struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
111   using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
112 
113   LogicalResult
114   matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
115                   ConversionPatternRewriter &rewriter) const override {
116     auto operandType = adaptor.getOperand().getType();
117 
118     if (!operandType || !LLVM::isCompatibleType(operandType))
119       return failure();
120 
121     auto loc = op.getLoc();
122     auto resultType = op.getResult().getType();
123     auto floatType = cast<FloatType>(getElementTypeOrSelf(resultType));
124     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
125     ConvertFastMath<math::ExpM1Op, LLVM::ExpOp> expAttrs(op);
126     ConvertFastMath<math::ExpM1Op, LLVM::FSubOp> subAttrs(op);
127 
128     if (!isa<LLVM::LLVMArrayType>(operandType)) {
129       LLVM::ConstantOp one;
130       if (LLVM::isCompatibleVectorType(operandType)) {
131         one = rewriter.create<LLVM::ConstantOp>(
132             loc, operandType,
133             SplatElementsAttr::get(cast<ShapedType>(resultType), floatOne));
134       } else {
135         one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
136       }
137       auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.getOperand(),
138                                               expAttrs.getAttrs());
139       rewriter.replaceOpWithNewOp<LLVM::FSubOp>(
140           op, operandType, ValueRange{exp, one}, subAttrs.getAttrs());
141       return success();
142     }
143 
144     auto vectorType = dyn_cast<VectorType>(resultType);
145     if (!vectorType)
146       return rewriter.notifyMatchFailure(op, "expected vector result type");
147 
148     return LLVM::detail::handleMultidimensionalVectors(
149         op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
150         [&](Type llvm1DVectorTy, ValueRange operands) {
151           auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
152           auto splatAttr = SplatElementsAttr::get(
153               mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
154                                     {numElements.isScalable()}),
155               floatOne);
156           auto one =
157               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
158           auto exp = rewriter.create<LLVM::ExpOp>(
159               loc, llvm1DVectorTy, operands[0], expAttrs.getAttrs());
160           return rewriter.create<LLVM::FSubOp>(
161               loc, llvm1DVectorTy, ValueRange{exp, one}, subAttrs.getAttrs());
162         },
163         rewriter);
164   }
165 };
166 
167 // A `log1p` is converted into `log(1 + ...)`.
168 struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
169   using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern;
170 
171   LogicalResult
172   matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
173                   ConversionPatternRewriter &rewriter) const override {
174     auto operandType = adaptor.getOperand().getType();
175 
176     if (!operandType || !LLVM::isCompatibleType(operandType))
177       return rewriter.notifyMatchFailure(op, "unsupported operand type");
178 
179     auto loc = op.getLoc();
180     auto resultType = op.getResult().getType();
181     auto floatType = cast<FloatType>(getElementTypeOrSelf(resultType));
182     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
183     ConvertFastMath<math::Log1pOp, LLVM::FAddOp> addAttrs(op);
184     ConvertFastMath<math::Log1pOp, LLVM::LogOp> logAttrs(op);
185 
186     if (!isa<LLVM::LLVMArrayType>(operandType)) {
187       LLVM::ConstantOp one =
188           LLVM::isCompatibleVectorType(operandType)
189               ? rewriter.create<LLVM::ConstantOp>(
190                     loc, operandType,
191                     SplatElementsAttr::get(cast<ShapedType>(resultType),
192                                            floatOne))
193               : rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
194 
195       auto add = rewriter.create<LLVM::FAddOp>(
196           loc, operandType, ValueRange{one, adaptor.getOperand()},
197           addAttrs.getAttrs());
198       rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, ValueRange{add},
199                                                logAttrs.getAttrs());
200       return success();
201     }
202 
203     auto vectorType = dyn_cast<VectorType>(resultType);
204     if (!vectorType)
205       return rewriter.notifyMatchFailure(op, "expected vector result type");
206 
207     return LLVM::detail::handleMultidimensionalVectors(
208         op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
209         [&](Type llvm1DVectorTy, ValueRange operands) {
210           auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
211           auto splatAttr = SplatElementsAttr::get(
212               mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
213                                     {numElements.isScalable()}),
214               floatOne);
215           auto one =
216               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
217           auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy,
218                                                    ValueRange{one, operands[0]},
219                                                    addAttrs.getAttrs());
220           return rewriter.create<LLVM::LogOp>(
221               loc, llvm1DVectorTy, ValueRange{add}, logAttrs.getAttrs());
222         },
223         rewriter);
224   }
225 };
226 
227 // A `rsqrt` is converted into `1 / sqrt`.
228 struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
229   using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern;
230 
231   LogicalResult
232   matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
233                   ConversionPatternRewriter &rewriter) const override {
234     auto operandType = adaptor.getOperand().getType();
235 
236     if (!operandType || !LLVM::isCompatibleType(operandType))
237       return failure();
238 
239     auto loc = op.getLoc();
240     auto resultType = op.getResult().getType();
241     auto floatType = cast<FloatType>(getElementTypeOrSelf(resultType));
242     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
243     ConvertFastMath<math::RsqrtOp, LLVM::SqrtOp> sqrtAttrs(op);
244     ConvertFastMath<math::RsqrtOp, LLVM::FDivOp> divAttrs(op);
245 
246     if (!isa<LLVM::LLVMArrayType>(operandType)) {
247       LLVM::ConstantOp one;
248       if (LLVM::isCompatibleVectorType(operandType)) {
249         one = rewriter.create<LLVM::ConstantOp>(
250             loc, operandType,
251             SplatElementsAttr::get(cast<ShapedType>(resultType), floatOne));
252       } else {
253         one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
254       }
255       auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.getOperand(),
256                                                 sqrtAttrs.getAttrs());
257       rewriter.replaceOpWithNewOp<LLVM::FDivOp>(
258           op, operandType, ValueRange{one, sqrt}, divAttrs.getAttrs());
259       return success();
260     }
261 
262     auto vectorType = dyn_cast<VectorType>(resultType);
263     if (!vectorType)
264       return failure();
265 
266     return LLVM::detail::handleMultidimensionalVectors(
267         op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
268         [&](Type llvm1DVectorTy, ValueRange operands) {
269           auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
270           auto splatAttr = SplatElementsAttr::get(
271               mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
272                                     {numElements.isScalable()}),
273               floatOne);
274           auto one =
275               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
276           auto sqrt = rewriter.create<LLVM::SqrtOp>(
277               loc, llvm1DVectorTy, operands[0], sqrtAttrs.getAttrs());
278           return rewriter.create<LLVM::FDivOp>(
279               loc, llvm1DVectorTy, ValueRange{one, sqrt}, divAttrs.getAttrs());
280         },
281         rewriter);
282   }
283 };
284 
285 struct ConvertMathToLLVMPass
286     : public impl::ConvertMathToLLVMPassBase<ConvertMathToLLVMPass> {
287   using Base::Base;
288 
289   void runOnOperation() override {
290     RewritePatternSet patterns(&getContext());
291     LLVMTypeConverter converter(&getContext());
292     populateMathToLLVMConversionPatterns(converter, patterns, approximateLog1p);
293     LLVMConversionTarget target(getContext());
294     if (failed(applyPartialConversion(getOperation(), target,
295                                       std::move(patterns))))
296       signalPassFailure();
297   }
298 };
299 } // namespace
300 
301 void mlir::populateMathToLLVMConversionPatterns(
302     const LLVMTypeConverter &converter, RewritePatternSet &patterns,
303     bool approximateLog1p) {
304   if (approximateLog1p)
305     patterns.add<Log1pOpLowering>(converter);
306   // clang-format off
307   patterns.add<
308     AbsFOpLowering,
309     AbsIOpLowering,
310     CeilOpLowering,
311     CopySignOpLowering,
312     CosOpLowering,
313     CountLeadingZerosOpLowering,
314     CountTrailingZerosOpLowering,
315     CtPopFOpLowering,
316     Exp2OpLowering,
317     ExpM1OpLowering,
318     ExpOpLowering,
319     FPowIOpLowering,
320     FloorOpLowering,
321     FmaOpLowering,
322     Log10OpLowering,
323     Log2OpLowering,
324     LogOpLowering,
325     PowFOpLowering,
326     RoundEvenOpLowering,
327     RoundOpLowering,
328     RsqrtOpLowering,
329     SinOpLowering,
330     SqrtOpLowering,
331     FTruncOpLowering
332   >(converter);
333   // clang-format on
334 }
335 
336 //===----------------------------------------------------------------------===//
337 // ConvertToLLVMPatternInterface implementation
338 //===----------------------------------------------------------------------===//
339 
340 namespace {
341 /// Implement the interface to convert Math to LLVM.
342 struct MathToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
343   using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
344   void loadDependentDialects(MLIRContext *context) const final {
345     context->loadDialect<LLVM::LLVMDialect>();
346   }
347 
348   /// Hook for derived dialect interface to provide conversion patterns
349   /// and mark dialect legal for the conversion target.
350   void populateConvertToLLVMConversionPatterns(
351       ConversionTarget &target, LLVMTypeConverter &typeConverter,
352       RewritePatternSet &patterns) const final {
353     populateMathToLLVMConversionPatterns(typeConverter, patterns);
354   }
355 };
356 } // namespace
357 
358 void mlir::registerConvertMathToLLVMInterface(DialectRegistry &registry) {
359   registry.addExtension(+[](MLIRContext *ctx, math::MathDialect *dialect) {
360     dialect->addInterfaces<MathToLLVMDialectInterface>();
361   });
362 }
363