xref: /llvm-project/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp (revision 9f114afe092483983a82a73c82704f11bb28bf8c)
1 //===-- MathToROCDL.cpp - conversion from Math to rocdl calls -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
10 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
11 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
12 #include "mlir/Dialect/Func/IR/FuncOps.h"
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
15 #include "mlir/Dialect/Math/IR/Math.h"
16 #include "mlir/Dialect/Utils/IndexingUtils.h"
17 #include "mlir/Dialect/Vector/IR/VectorOps.h"
18 #include "mlir/IR/BuiltinDialect.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Pass/Pass.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 
23 #include "../GPUCommon/GPUOpsLowering.h"
24 #include "../GPUCommon/IndexIntrinsicsOpLowering.h"
25 #include "../GPUCommon/OpToFuncCallLowering.h"
26 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
27 
28 namespace mlir {
29 #define GEN_PASS_DEF_CONVERTMATHTOROCDL
30 #include "mlir/Conversion/Passes.h.inc"
31 } // namespace mlir
32 
33 using namespace mlir;
34 
35 #define DEBUG_TYPE "math-to-rocdl"
36 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
37 
38 template <typename OpTy>
39 static void populateOpPatterns(const LLVMTypeConverter &converter,
40                                RewritePatternSet &patterns, StringRef f32Func,
41                                StringRef f64Func, StringRef f16Func,
42                                StringRef f32ApproxFunc = "") {
43   patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
44   patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
45                                            f32ApproxFunc, f16Func);
46 }
47 
48 void mlir::populateMathToROCDLConversionPatterns(
49     const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
50   // Handled by mathToLLVM: math::AbsIOp
51   // Handled by mathToLLVM: math::AbsFOp
52   // Handled by mathToLLVM: math::CopySignOp
53   // Handled by mathToLLVM: math::CountLeadingZerosOp
54   // Handled by mathToLLVM: math::CountTrailingZerosOp
55   // Handled by mathToLLVM: math::CgPopOp
56   // Handled by mathToLLVM: math::ExpOp (32-bit only)
57   // Handled by mathToLLVM: math::FmaOp
58   // Handled by mathToLLVM: math::LogOp (32-bit only)
59   // FIXME: math::IPowIOp
60   // Handled by mathToLLVM: math::RoundEvenOp
61   // Handled by mathToLLVM: math::RoundOp
62   // Handled by mathToLLVM: math::SqrtOp
63   // Handled by mathToLLVM: math::TruncOp
64   populateOpPatterns<math::AcosOp>(converter, patterns, "__ocml_acos_f32",
65                                    "__ocml_acos_f64", "__ocml_acos_f16");
66   populateOpPatterns<math::AcoshOp>(converter, patterns, "__ocml_acosh_f32",
67                                     "__ocml_acosh_f64", "__ocml_acosh_f16");
68   populateOpPatterns<math::AsinOp>(converter, patterns, "__ocml_asin_f32",
69                                    "__ocml_asin_f64", "__ocml_asin_f16");
70   populateOpPatterns<math::AsinhOp>(converter, patterns, "__ocml_asinh_f32",
71                                     "__ocml_asinh_f64", "__ocml_asinh_f16");
72   populateOpPatterns<math::AtanOp>(converter, patterns, "__ocml_atan_f32",
73                                    "__ocml_atan_f64", "__ocml_atan_f16");
74   populateOpPatterns<math::AtanhOp>(converter, patterns, "__ocml_atanh_f32",
75                                     "__ocml_atanh_f64", "__ocml_atanh_f16");
76   populateOpPatterns<math::Atan2Op>(converter, patterns, "__ocml_atan2_f32",
77                                     "__ocml_atan2_f64", "__ocml_atan2_f16");
78   populateOpPatterns<math::CbrtOp>(converter, patterns, "__ocml_cbrt_f32",
79                                    "__ocml_cbrt_f64", "__ocml_cbrt_f16");
80   populateOpPatterns<math::CeilOp>(converter, patterns, "__ocml_ceil_f32",
81                                    "__ocml_ceil_f64", "__ocml_ceil_f16");
82   populateOpPatterns<math::CosOp>(converter, patterns, "__ocml_cos_f32",
83                                   "__ocml_cos_f64", "__ocml_cos_f16");
84   populateOpPatterns<math::CoshOp>(converter, patterns, "__ocml_cosh_f32",
85                                    "__ocml_cosh_f64", "__ocml_cosh_f16");
86   populateOpPatterns<math::SinhOp>(converter, patterns, "__ocml_sinh_f32",
87                                    "__ocml_sinh_f64", "__ocml_sinh_f16");
88   populateOpPatterns<math::ExpOp>(converter, patterns, "", "__ocml_exp_f64",
89                                   "__ocml_exp_f16");
90   populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
91                                    "__ocml_exp2_f64", "__ocml_exp2_f16");
92   populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
93                                     "__ocml_expm1_f64", "__ocml_expm1_f16");
94   populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
95                                     "__ocml_floor_f64", "__ocml_floor_f16");
96   populateOpPatterns<math::LogOp>(converter, patterns, "", "__ocml_log_f64",
97                                   "__ocml_log_f16");
98   populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32",
99                                     "__ocml_log10_f64", "__ocml_log10_f16");
100   populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32",
101                                     "__ocml_log1p_f64", "__ocml_log1p_f16");
102   populateOpPatterns<math::Log2Op>(converter, patterns, "__ocml_log2_f32",
103                                    "__ocml_log2_f64", "__ocml_log2_f16");
104   populateOpPatterns<math::PowFOp>(converter, patterns, "__ocml_pow_f32",
105                                    "__ocml_pow_f64", "__ocml_pow_f16");
106   populateOpPatterns<math::RsqrtOp>(converter, patterns, "__ocml_rsqrt_f32",
107                                     "__ocml_rsqrt_f64", "__ocml_rsqrt_f16");
108   populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
109                                   "__ocml_sin_f64", "__ocml_sin_f16");
110   populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32",
111                                    "__ocml_tanh_f64", "__ocml_tanh_f16");
112   populateOpPatterns<math::TanOp>(converter, patterns, "__ocml_tan_f32",
113                                   "__ocml_tan_f64", "__ocml_tan_f16");
114   populateOpPatterns<math::ErfOp>(converter, patterns, "__ocml_erf_f32",
115                                   "__ocml_erf_f64", "__ocml_erf_f16");
116   populateOpPatterns<math::FPowIOp>(converter, patterns, "__ocml_pown_f32",
117                                     "__ocml_pown_f64", "__ocml_pown_f16");
118   // Single arith pattern that needs a ROCDL call, probably not
119   // worth creating a separate pass for it.
120   populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
121                                     "__ocml_fmod_f64", "__ocml_fmod_f16");
122 }
123 
124 namespace {
125 struct ConvertMathToROCDLPass
126     : public impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
127   ConvertMathToROCDLPass() = default;
128   void runOnOperation() override;
129 };
130 } // namespace
131 
132 void ConvertMathToROCDLPass::runOnOperation() {
133   auto m = getOperation();
134   MLIRContext *ctx = m.getContext();
135 
136   RewritePatternSet patterns(&getContext());
137   LowerToLLVMOptions options(ctx, DataLayout(m));
138   LLVMTypeConverter converter(ctx, options);
139   populateMathToROCDLConversionPatterns(converter, patterns);
140   ConversionTarget target(getContext());
141   target.addLegalDialect<BuiltinDialect, func::FuncDialect,
142                          vector::VectorDialect, LLVM::LLVMDialect>();
143   target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
144                       LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
145                       LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
146                       LLVM::SqrtOp>();
147   if (failed(applyPartialConversion(m, target, std::move(patterns))))
148     signalPassFailure();
149 }
150