xref: /llvm-project/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp (revision 206fad0e218e83799e49ca15545d997c6c5e8a03)
126e59cc1SAlex Zinenko //===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===//
226e59cc1SAlex Zinenko //
326e59cc1SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
426e59cc1SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
526e59cc1SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
626e59cc1SAlex Zinenko //
726e59cc1SAlex Zinenko //===----------------------------------------------------------------------===//
826e59cc1SAlex Zinenko 
926e59cc1SAlex Zinenko #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
1067d0d7acSMichele Scuttari 
11589764a3SSlava Zakharin #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
12bfdc4723SMatthias Springer #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
1326e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
1426e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/Pattern.h"
1526e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
1626e59cc1SAlex Zinenko #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1726e59cc1SAlex Zinenko #include "mlir/Dialect/Math/IR/Math.h"
1826e59cc1SAlex Zinenko #include "mlir/IR/TypeUtilities.h"
1967d0d7acSMichele Scuttari #include "mlir/Pass/Pass.h"
2067d0d7acSMichele Scuttari 
2167d0d7acSMichele Scuttari namespace mlir {
22cd4ca2d7SMarkus Böck #define GEN_PASS_DEF_CONVERTMATHTOLLVMPASS
2367d0d7acSMichele Scuttari #include "mlir/Conversion/Passes.h.inc"
2467d0d7acSMichele Scuttari } // namespace mlir
2526e59cc1SAlex Zinenko 
2626e59cc1SAlex Zinenko using namespace mlir;
2726e59cc1SAlex Zinenko 
2826e59cc1SAlex Zinenko namespace {
29589764a3SSlava Zakharin 
30589764a3SSlava Zakharin template <typename SourceOp, typename TargetOp>
31589764a3SSlava Zakharin using ConvertFastMath = arith::AttrConvertFastMathToLLVM<SourceOp, TargetOp>;
32589764a3SSlava Zakharin 
33589764a3SSlava Zakharin template <typename SourceOp, typename TargetOp>
34589764a3SSlava Zakharin using ConvertFMFMathToLLVMPattern =
35589764a3SSlava Zakharin     VectorConvertToLLVMPattern<SourceOp, TargetOp, ConvertFastMath>;
36589764a3SSlava Zakharin 
37589764a3SSlava Zakharin using AbsFOpLowering = ConvertFMFMathToLLVMPattern<math::AbsFOp, LLVM::FAbsOp>;
38589764a3SSlava Zakharin using CeilOpLowering = ConvertFMFMathToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
39a54f4eaeSMogball using CopySignOpLowering =
40589764a3SSlava Zakharin     ConvertFMFMathToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
41589764a3SSlava Zakharin using CosOpLowering = ConvertFMFMathToLLVMPattern<math::CosOp, LLVM::CosOp>;
42c5fef77bSRob Suderman using CtPopFOpLowering =
43c5fef77bSRob Suderman     VectorConvertToLLVMPattern<math::CtPopOp, LLVM::CtPopOp>;
44589764a3SSlava Zakharin using Exp2OpLowering = ConvertFMFMathToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
45589764a3SSlava Zakharin using ExpOpLowering = ConvertFMFMathToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
46a54f4eaeSMogball using FloorOpLowering =
47589764a3SSlava Zakharin     ConvertFMFMathToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
48589764a3SSlava Zakharin using FmaOpLowering = ConvertFMFMathToLLVMPattern<math::FmaOp, LLVM::FMAOp>;
4926e59cc1SAlex Zinenko using Log10OpLowering =
50589764a3SSlava Zakharin     ConvertFMFMathToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
51589764a3SSlava Zakharin using Log2OpLowering = ConvertFMFMathToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
52589764a3SSlava Zakharin using LogOpLowering = ConvertFMFMathToLLVMPattern<math::LogOp, LLVM::LogOp>;
53589764a3SSlava Zakharin using PowFOpLowering = ConvertFMFMathToLLVMPattern<math::PowFOp, LLVM::PowOp>;
5470174b80SSlava Zakharin using FPowIOpLowering =
5570174b80SSlava Zakharin     ConvertFMFMathToLLVMPattern<math::FPowIOp, LLVM::PowIOp>;
564639a85fSTres Popp using RoundEvenOpLowering =
57589764a3SSlava Zakharin     ConvertFMFMathToLLVMPattern<math::RoundEvenOp, LLVM::RoundEvenOp>;
58a0fc94abSlorenzo chelini using RoundOpLowering =
59589764a3SSlava Zakharin     ConvertFMFMathToLLVMPattern<math::RoundOp, LLVM::RoundOp>;
60589764a3SSlava Zakharin using SinOpLowering = ConvertFMFMathToLLVMPattern<math::SinOp, LLVM::SinOp>;
61589764a3SSlava Zakharin using SqrtOpLowering = ConvertFMFMathToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
629d0b90e9Sjacquesguan using FTruncOpLowering =
63589764a3SSlava Zakharin     ConvertFMFMathToLLVMPattern<math::TruncOp, LLVM::FTruncOp>;
6426e59cc1SAlex Zinenko 
65bc8d9664SJeff Niu // A `CtLz/CtTz/absi(a)` is converted into `CtLz/CtTz/absi(a, false)`.
6623149d52SRob Suderman template <typename MathOp, typename LLVMOp>
67bc8d9664SJeff Niu struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
6823149d52SRob Suderman   using ConvertOpToLLVMPattern<MathOp>::ConvertOpToLLVMPattern;
69bc8d9664SJeff Niu   using Super = IntOpWithFlagLowering<MathOp, LLVMOp>;
7023149d52SRob Suderman 
7123149d52SRob Suderman   LogicalResult
7223149d52SRob Suderman   matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor,
7323149d52SRob Suderman                   ConversionPatternRewriter &rewriter) const override {
7423149d52SRob Suderman     auto operandType = adaptor.getOperand().getType();
7523149d52SRob Suderman 
7623149d52SRob Suderman     if (!operandType || !LLVM::isCompatibleType(operandType))
7723149d52SRob Suderman       return failure();
7823149d52SRob Suderman 
7923149d52SRob Suderman     auto loc = op.getLoc();
8023149d52SRob Suderman     auto resultType = op.getResult().getType();
8123149d52SRob Suderman 
825550c821STres Popp     if (!isa<LLVM::LLVMArrayType>(operandType)) {
8323149d52SRob Suderman       rewriter.replaceOpWithNewOp<LLVMOp>(op, resultType, adaptor.getOperand(),
8448b126e3SChristian Ulmann                                           false);
8523149d52SRob Suderman       return success();
8623149d52SRob Suderman     }
8723149d52SRob Suderman 
885550c821STres Popp     auto vectorType = dyn_cast<VectorType>(resultType);
8923149d52SRob Suderman     if (!vectorType)
9023149d52SRob Suderman       return failure();
9123149d52SRob Suderman 
9223149d52SRob Suderman     return LLVM::detail::handleMultidimensionalVectors(
9323149d52SRob Suderman         op.getOperation(), adaptor.getOperands(), *this->getTypeConverter(),
9423149d52SRob Suderman         [&](Type llvm1DVectorTy, ValueRange operands) {
95cb4a5eaeSRobert Suderman           return rewriter.create<LLVMOp>(loc, llvm1DVectorTy, operands[0],
9648b126e3SChristian Ulmann                                          false);
9723149d52SRob Suderman         },
9823149d52SRob Suderman         rewriter);
9923149d52SRob Suderman   }
10023149d52SRob Suderman };
10123149d52SRob Suderman 
10223149d52SRob Suderman using CountLeadingZerosOpLowering =
103bc8d9664SJeff Niu     IntOpWithFlagLowering<math::CountLeadingZerosOp, LLVM::CountLeadingZerosOp>;
10423149d52SRob Suderman using CountTrailingZerosOpLowering =
10548b126e3SChristian Ulmann     IntOpWithFlagLowering<math::CountTrailingZerosOp,
10648b126e3SChristian Ulmann                           LLVM::CountTrailingZerosOp>;
107bc8d9664SJeff Niu using AbsIOpLowering = IntOpWithFlagLowering<math::AbsIOp, LLVM::AbsOp>;
10823149d52SRob Suderman 
10926e59cc1SAlex Zinenko // A `expm1` is converted into `exp - 1`.
11026e59cc1SAlex Zinenko struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
11126e59cc1SAlex Zinenko   using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
11226e59cc1SAlex Zinenko 
11326e59cc1SAlex Zinenko   LogicalResult
114ef976337SRiver Riddle   matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
11526e59cc1SAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
11662fea88bSJacques Pienaar     auto operandType = adaptor.getOperand().getType();
11726e59cc1SAlex Zinenko 
11826e59cc1SAlex Zinenko     if (!operandType || !LLVM::isCompatibleType(operandType))
11926e59cc1SAlex Zinenko       return failure();
12026e59cc1SAlex Zinenko 
12126e59cc1SAlex Zinenko     auto loc = op.getLoc();
12226e59cc1SAlex Zinenko     auto resultType = op.getResult().getType();
1235550c821STres Popp     auto floatType = cast<FloatType>(getElementTypeOrSelf(resultType));
12426e59cc1SAlex Zinenko     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
125589764a3SSlava Zakharin     ConvertFastMath<math::ExpM1Op, LLVM::ExpOp> expAttrs(op);
126589764a3SSlava Zakharin     ConvertFastMath<math::ExpM1Op, LLVM::FSubOp> subAttrs(op);
12726e59cc1SAlex Zinenko 
1285550c821STres Popp     if (!isa<LLVM::LLVMArrayType>(operandType)) {
12926e59cc1SAlex Zinenko       LLVM::ConstantOp one;
13026e59cc1SAlex Zinenko       if (LLVM::isCompatibleVectorType(operandType)) {
13126e59cc1SAlex Zinenko         one = rewriter.create<LLVM::ConstantOp>(
13226e59cc1SAlex Zinenko             loc, operandType,
1335550c821STres Popp             SplatElementsAttr::get(cast<ShapedType>(resultType), floatOne));
13426e59cc1SAlex Zinenko       } else {
13526e59cc1SAlex Zinenko         one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
13626e59cc1SAlex Zinenko       }
137589764a3SSlava Zakharin       auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.getOperand(),
138589764a3SSlava Zakharin                                               expAttrs.getAttrs());
139589764a3SSlava Zakharin       rewriter.replaceOpWithNewOp<LLVM::FSubOp>(
140589764a3SSlava Zakharin           op, operandType, ValueRange{exp, one}, subAttrs.getAttrs());
14126e59cc1SAlex Zinenko       return success();
14226e59cc1SAlex Zinenko     }
14326e59cc1SAlex Zinenko 
1445550c821STres Popp     auto vectorType = dyn_cast<VectorType>(resultType);
14526e59cc1SAlex Zinenko     if (!vectorType)
14626e59cc1SAlex Zinenko       return rewriter.notifyMatchFailure(op, "expected vector result type");
14726e59cc1SAlex Zinenko 
14826e59cc1SAlex Zinenko     return LLVM::detail::handleMultidimensionalVectors(
149ef976337SRiver Riddle         op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
15026e59cc1SAlex Zinenko         [&](Type llvm1DVectorTy, ValueRange operands) {
15178890904SBenjamin Maxwell           auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
15226e59cc1SAlex Zinenko           auto splatAttr = SplatElementsAttr::get(
15378890904SBenjamin Maxwell               mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
15478890904SBenjamin Maxwell                                     {numElements.isScalable()}),
15526e59cc1SAlex Zinenko               floatOne);
15626e59cc1SAlex Zinenko           auto one =
15726e59cc1SAlex Zinenko               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
158589764a3SSlava Zakharin           auto exp = rewriter.create<LLVM::ExpOp>(
159589764a3SSlava Zakharin               loc, llvm1DVectorTy, operands[0], expAttrs.getAttrs());
160589764a3SSlava Zakharin           return rewriter.create<LLVM::FSubOp>(
161589764a3SSlava Zakharin               loc, llvm1DVectorTy, ValueRange{exp, one}, subAttrs.getAttrs());
16226e59cc1SAlex Zinenko         },
16326e59cc1SAlex Zinenko         rewriter);
16426e59cc1SAlex Zinenko   }
16526e59cc1SAlex Zinenko };
16626e59cc1SAlex Zinenko 
16726e59cc1SAlex Zinenko // A `log1p` is converted into `log(1 + ...)`.
16826e59cc1SAlex Zinenko struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
16926e59cc1SAlex Zinenko   using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern;
17026e59cc1SAlex Zinenko 
17126e59cc1SAlex Zinenko   LogicalResult
172ef976337SRiver Riddle   matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
17326e59cc1SAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
17462fea88bSJacques Pienaar     auto operandType = adaptor.getOperand().getType();
17526e59cc1SAlex Zinenko 
17626e59cc1SAlex Zinenko     if (!operandType || !LLVM::isCompatibleType(operandType))
17726e59cc1SAlex Zinenko       return rewriter.notifyMatchFailure(op, "unsupported operand type");
17826e59cc1SAlex Zinenko 
17926e59cc1SAlex Zinenko     auto loc = op.getLoc();
18026e59cc1SAlex Zinenko     auto resultType = op.getResult().getType();
1815550c821STres Popp     auto floatType = cast<FloatType>(getElementTypeOrSelf(resultType));
18226e59cc1SAlex Zinenko     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
183589764a3SSlava Zakharin     ConvertFastMath<math::Log1pOp, LLVM::FAddOp> addAttrs(op);
184589764a3SSlava Zakharin     ConvertFastMath<math::Log1pOp, LLVM::LogOp> logAttrs(op);
18526e59cc1SAlex Zinenko 
1865550c821STres Popp     if (!isa<LLVM::LLVMArrayType>(operandType)) {
18726e59cc1SAlex Zinenko       LLVM::ConstantOp one =
18826e59cc1SAlex Zinenko           LLVM::isCompatibleVectorType(operandType)
18926e59cc1SAlex Zinenko               ? rewriter.create<LLVM::ConstantOp>(
19026e59cc1SAlex Zinenko                     loc, operandType,
1915550c821STres Popp                     SplatElementsAttr::get(cast<ShapedType>(resultType),
19226e59cc1SAlex Zinenko                                            floatOne))
19326e59cc1SAlex Zinenko               : rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
19426e59cc1SAlex Zinenko 
195589764a3SSlava Zakharin       auto add = rewriter.create<LLVM::FAddOp>(
196589764a3SSlava Zakharin           loc, operandType, ValueRange{one, adaptor.getOperand()},
197589764a3SSlava Zakharin           addAttrs.getAttrs());
198589764a3SSlava Zakharin       rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, ValueRange{add},
199589764a3SSlava Zakharin                                                logAttrs.getAttrs());
20026e59cc1SAlex Zinenko       return success();
20126e59cc1SAlex Zinenko     }
20226e59cc1SAlex Zinenko 
2035550c821STres Popp     auto vectorType = dyn_cast<VectorType>(resultType);
20426e59cc1SAlex Zinenko     if (!vectorType)
20526e59cc1SAlex Zinenko       return rewriter.notifyMatchFailure(op, "expected vector result type");
20626e59cc1SAlex Zinenko 
20726e59cc1SAlex Zinenko     return LLVM::detail::handleMultidimensionalVectors(
208ef976337SRiver Riddle         op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
20926e59cc1SAlex Zinenko         [&](Type llvm1DVectorTy, ValueRange operands) {
21078890904SBenjamin Maxwell           auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
21126e59cc1SAlex Zinenko           auto splatAttr = SplatElementsAttr::get(
21278890904SBenjamin Maxwell               mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
21378890904SBenjamin Maxwell                                     {numElements.isScalable()}),
21426e59cc1SAlex Zinenko               floatOne);
21526e59cc1SAlex Zinenko           auto one =
21626e59cc1SAlex Zinenko               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
217589764a3SSlava Zakharin           auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy,
218589764a3SSlava Zakharin                                                    ValueRange{one, operands[0]},
219589764a3SSlava Zakharin                                                    addAttrs.getAttrs());
220589764a3SSlava Zakharin           return rewriter.create<LLVM::LogOp>(
221589764a3SSlava Zakharin               loc, llvm1DVectorTy, ValueRange{add}, logAttrs.getAttrs());
22226e59cc1SAlex Zinenko         },
22326e59cc1SAlex Zinenko         rewriter);
22426e59cc1SAlex Zinenko   }
22526e59cc1SAlex Zinenko };
22626e59cc1SAlex Zinenko 
22726e59cc1SAlex Zinenko // A `rsqrt` is converted into `1 / sqrt`.
22826e59cc1SAlex Zinenko struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
22926e59cc1SAlex Zinenko   using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern;
23026e59cc1SAlex Zinenko 
23126e59cc1SAlex Zinenko   LogicalResult
232ef976337SRiver Riddle   matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
23326e59cc1SAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
23462fea88bSJacques Pienaar     auto operandType = adaptor.getOperand().getType();
23526e59cc1SAlex Zinenko 
23626e59cc1SAlex Zinenko     if (!operandType || !LLVM::isCompatibleType(operandType))
23726e59cc1SAlex Zinenko       return failure();
23826e59cc1SAlex Zinenko 
23926e59cc1SAlex Zinenko     auto loc = op.getLoc();
24026e59cc1SAlex Zinenko     auto resultType = op.getResult().getType();
2415550c821STres Popp     auto floatType = cast<FloatType>(getElementTypeOrSelf(resultType));
24226e59cc1SAlex Zinenko     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
243589764a3SSlava Zakharin     ConvertFastMath<math::RsqrtOp, LLVM::SqrtOp> sqrtAttrs(op);
244589764a3SSlava Zakharin     ConvertFastMath<math::RsqrtOp, LLVM::FDivOp> divAttrs(op);
24526e59cc1SAlex Zinenko 
2465550c821STres Popp     if (!isa<LLVM::LLVMArrayType>(operandType)) {
24726e59cc1SAlex Zinenko       LLVM::ConstantOp one;
24826e59cc1SAlex Zinenko       if (LLVM::isCompatibleVectorType(operandType)) {
24926e59cc1SAlex Zinenko         one = rewriter.create<LLVM::ConstantOp>(
25026e59cc1SAlex Zinenko             loc, operandType,
2515550c821STres Popp             SplatElementsAttr::get(cast<ShapedType>(resultType), floatOne));
25226e59cc1SAlex Zinenko       } else {
25326e59cc1SAlex Zinenko         one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
25426e59cc1SAlex Zinenko       }
255589764a3SSlava Zakharin       auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.getOperand(),
256589764a3SSlava Zakharin                                                 sqrtAttrs.getAttrs());
257589764a3SSlava Zakharin       rewriter.replaceOpWithNewOp<LLVM::FDivOp>(
258589764a3SSlava Zakharin           op, operandType, ValueRange{one, sqrt}, divAttrs.getAttrs());
25926e59cc1SAlex Zinenko       return success();
26026e59cc1SAlex Zinenko     }
26126e59cc1SAlex Zinenko 
2625550c821STres Popp     auto vectorType = dyn_cast<VectorType>(resultType);
26326e59cc1SAlex Zinenko     if (!vectorType)
26426e59cc1SAlex Zinenko       return failure();
26526e59cc1SAlex Zinenko 
26626e59cc1SAlex Zinenko     return LLVM::detail::handleMultidimensionalVectors(
267ef976337SRiver Riddle         op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
26826e59cc1SAlex Zinenko         [&](Type llvm1DVectorTy, ValueRange operands) {
26978890904SBenjamin Maxwell           auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
27026e59cc1SAlex Zinenko           auto splatAttr = SplatElementsAttr::get(
27178890904SBenjamin Maxwell               mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
27278890904SBenjamin Maxwell                                     {numElements.isScalable()}),
27326e59cc1SAlex Zinenko               floatOne);
27426e59cc1SAlex Zinenko           auto one =
27526e59cc1SAlex Zinenko               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
276589764a3SSlava Zakharin           auto sqrt = rewriter.create<LLVM::SqrtOp>(
277589764a3SSlava Zakharin               loc, llvm1DVectorTy, operands[0], sqrtAttrs.getAttrs());
278589764a3SSlava Zakharin           return rewriter.create<LLVM::FDivOp>(
279589764a3SSlava Zakharin               loc, llvm1DVectorTy, ValueRange{one, sqrt}, divAttrs.getAttrs());
28026e59cc1SAlex Zinenko         },
28126e59cc1SAlex Zinenko         rewriter);
28226e59cc1SAlex Zinenko   }
28326e59cc1SAlex Zinenko };
28426e59cc1SAlex Zinenko 
28526e59cc1SAlex Zinenko struct ConvertMathToLLVMPass
286cd4ca2d7SMarkus Böck     : public impl::ConvertMathToLLVMPassBase<ConvertMathToLLVMPass> {
287cd4ca2d7SMarkus Böck   using Base::Base;
28826e59cc1SAlex Zinenko 
28941574554SRiver Riddle   void runOnOperation() override {
29026e59cc1SAlex Zinenko     RewritePatternSet patterns(&getContext());
29126e59cc1SAlex Zinenko     LLVMTypeConverter converter(&getContext());
2928a9d4895SAlexander Belyaev     populateMathToLLVMConversionPatterns(converter, patterns, approximateLog1p);
29326e59cc1SAlex Zinenko     LLVMConversionTarget target(getContext());
29441574554SRiver Riddle     if (failed(applyPartialConversion(getOperation(), target,
29541574554SRiver Riddle                                       std::move(patterns))))
29626e59cc1SAlex Zinenko       signalPassFailure();
29726e59cc1SAlex Zinenko   }
29826e59cc1SAlex Zinenko };
29926e59cc1SAlex Zinenko } // namespace
30026e59cc1SAlex Zinenko 
301*206fad0eSMatthias Springer void mlir::populateMathToLLVMConversionPatterns(
302*206fad0eSMatthias Springer     const LLVMTypeConverter &converter, RewritePatternSet &patterns,
3038a9d4895SAlexander Belyaev     bool approximateLog1p) {
3048a9d4895SAlexander Belyaev   if (approximateLog1p)
3058a9d4895SAlexander Belyaev     patterns.add<Log1pOpLowering>(converter);
30626e59cc1SAlex Zinenko   // clang-format off
30726e59cc1SAlex Zinenko   patterns.add<
30800f7096dSJeff Niu     AbsFOpLowering,
3097d9fc95bSJeff Niu     AbsIOpLowering,
310a54f4eaeSMogball     CeilOpLowering,
311a54f4eaeSMogball     CopySignOpLowering,
31226e59cc1SAlex Zinenko     CosOpLowering,
31323149d52SRob Suderman     CountLeadingZerosOpLowering,
31423149d52SRob Suderman     CountTrailingZerosOpLowering,
315c5fef77bSRob Suderman     CtPopFOpLowering,
31626e59cc1SAlex Zinenko     Exp2OpLowering,
31726e59cc1SAlex Zinenko     ExpM1OpLowering,
3182a3c07f8Slorenzo chelini     ExpOpLowering,
31970174b80SSlava Zakharin     FPowIOpLowering,
320a54f4eaeSMogball     FloorOpLowering,
321a54f4eaeSMogball     FmaOpLowering,
32226e59cc1SAlex Zinenko     Log10OpLowering,
32326e59cc1SAlex Zinenko     Log2OpLowering,
32426e59cc1SAlex Zinenko     LogOpLowering,
32526e59cc1SAlex Zinenko     PowFOpLowering,
3264639a85fSTres Popp     RoundEvenOpLowering,
3272a3c07f8Slorenzo chelini     RoundOpLowering,
32826e59cc1SAlex Zinenko     RsqrtOpLowering,
32926e59cc1SAlex Zinenko     SinOpLowering,
3309d0b90e9Sjacquesguan     SqrtOpLowering,
3319d0b90e9Sjacquesguan     FTruncOpLowering
33226e59cc1SAlex Zinenko   >(converter);
33326e59cc1SAlex Zinenko   // clang-format on
33426e59cc1SAlex Zinenko }
335bfdc4723SMatthias Springer 
336bfdc4723SMatthias Springer //===----------------------------------------------------------------------===//
337bfdc4723SMatthias Springer // ConvertToLLVMPatternInterface implementation
338bfdc4723SMatthias Springer //===----------------------------------------------------------------------===//
339bfdc4723SMatthias Springer 
340bfdc4723SMatthias Springer namespace {
341bfdc4723SMatthias Springer /// Implement the interface to convert Math to LLVM.
342bfdc4723SMatthias Springer struct MathToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
343bfdc4723SMatthias Springer   using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
344bfdc4723SMatthias Springer   void loadDependentDialects(MLIRContext *context) const final {
345bfdc4723SMatthias Springer     context->loadDialect<LLVM::LLVMDialect>();
346bfdc4723SMatthias Springer   }
347bfdc4723SMatthias Springer 
348bfdc4723SMatthias Springer   /// Hook for derived dialect interface to provide conversion patterns
349bfdc4723SMatthias Springer   /// and mark dialect legal for the conversion target.
350bfdc4723SMatthias Springer   void populateConvertToLLVMConversionPatterns(
351bfdc4723SMatthias Springer       ConversionTarget &target, LLVMTypeConverter &typeConverter,
352bfdc4723SMatthias Springer       RewritePatternSet &patterns) const final {
353bfdc4723SMatthias Springer     populateMathToLLVMConversionPatterns(typeConverter, patterns);
354bfdc4723SMatthias Springer   }
355bfdc4723SMatthias Springer };
356bfdc4723SMatthias Springer } // namespace
357bfdc4723SMatthias Springer 
358bfdc4723SMatthias Springer void mlir::registerConvertMathToLLVMInterface(DialectRegistry &registry) {
359bfdc4723SMatthias Springer   registry.addExtension(+[](MLIRContext *ctx, math::MathDialect *dialect) {
360bfdc4723SMatthias Springer     dialect->addInterfaces<MathToLLVMDialectInterface>();
361bfdc4723SMatthias Springer   });
362bfdc4723SMatthias Springer }
363