xref: /llvm-project/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h (revision 4df28af7134518981d40cb3242b2a90af867fdae)
1 //===- OpToFuncCallLowering.h - GPU ops lowering to custom calls *- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 #ifndef MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
9 #define MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
10 
11 #include "mlir/Conversion/LLVMCommon/Pattern.h"
12 #include "mlir/Dialect/Arith/IR/Arith.h"
13 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/IR/Builders.h"
16 
17 namespace mlir {
18 
/// Detection trait for the `getFastmath` instance method.
///
/// Names the type returned by calling `getFastmath()` on a value of type `T`.
/// The alias is ill-formed when `T` has no such method, which makes it usable
/// with detection utilities such as `llvm::is_detected`.
///
/// Deliberately declared in the enclosing namespace instead of an unnamed one:
/// an unnamed namespace in a header produces a distinct entity in every
/// translation unit, so inline template code referring to the trait would not
/// name the same entity across translation units.
template <typename T>
using has_get_fastmath_t = decltype(std::declval<T>().getFastmath());
24 
25 /// Rewriting that replaces SourceOp with a CallOp to `f32Func` or `f64Func` or
26 /// `f32ApproxFunc` or `f16Func` or `i32Type` depending on the element type and
27 /// the fastMathFlag of that Op, if present. The function declaration is added
28 /// in case it was not added before.
29 ///
30 /// If the input values are of bf16 type (or f16 type if f16Func is empty), the
31 /// value is first casted to f32, the function called and then the result casted
32 /// back.
33 ///
34 /// Example with NVVM:
35 ///   %exp_f32 = math.exp %arg_f32 : f32
36 ///
37 /// will be transformed into
38 ///   llvm.call @__nv_expf(%arg_f32) : (f32) -> f32
39 ///
40 /// If the fastMathFlag attribute of SourceOp is `afn` or `fast`, this Op lowers
41 /// to the approximate calculation function.
42 ///
43 /// Also example with NVVM:
44 ///   %exp_f32 = math.exp %arg_f32 fastmath<afn> : f32
45 ///
46 /// will be transformed into
47 ///   llvm.call @__nv_fast_expf(%arg_f32) : (f32) -> f32
48 ///
49 /// Final example with NVVM:
50 ///   %pow_f32 = math.fpowi %arg_f32, %arg_i32
51 ///
52 /// will be transformed into
53 ///   llvm.call @__nv_powif(%arg_f32, %arg_i32) : (f32, i32) -> f32
54 template <typename SourceOp>
55 struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
56 public:
57   explicit OpToFuncCallLowering(const LLVMTypeConverter &lowering,
58                                 StringRef f32Func, StringRef f64Func,
59                                 StringRef f32ApproxFunc, StringRef f16Func,
60                                 StringRef i32Func = "")
61       : ConvertOpToLLVMPattern<SourceOp>(lowering), f32Func(f32Func),
62         f64Func(f64Func), f32ApproxFunc(f32ApproxFunc), f16Func(f16Func),
63         i32Func(i32Func) {}
64 
65   LogicalResult
66   matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
67                   ConversionPatternRewriter &rewriter) const override {
68     using LLVM::LLVMFuncOp;
69 
70     static_assert(
71         std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
72         "expected single result op");
73 
74     if constexpr (!std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
75                                    SourceOp>::value) {
76       assert(op->getNumOperands() > 0 &&
77              "expected op to take at least one operand");
78       assert(op->getResultTypes().front() == op->getOperand(0).getType() &&
79              "expected op with same operand and result types");
80     }
81 
82     if (!op->template getParentOfType<FunctionOpInterface>()) {
83       return rewriter.notifyMatchFailure(
84           op, "expected op to be within a function region");
85     }
86 
87     SmallVector<Value, 1> castedOperands;
88     for (Value operand : adaptor.getOperands())
89       castedOperands.push_back(maybeCast(operand, rewriter));
90 
91     Type resultType = castedOperands.front().getType();
92     Type funcType = getFunctionType(resultType, castedOperands);
93     StringRef funcName = getFunctionName(
94         cast<LLVM::LLVMFunctionType>(funcType).getReturnType(), op);
95     if (funcName.empty())
96       return failure();
97 
98     LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op);
99     auto callOp =
100         rewriter.create<LLVM::CallOp>(op->getLoc(), funcOp, castedOperands);
101 
102     if (resultType == adaptor.getOperands().front().getType()) {
103       rewriter.replaceOp(op, {callOp.getResult()});
104       return success();
105     }
106 
107     assert(callOp.getResult().getType().isF32() &&
108            "only f32 types are supposed to be truncated back");
109     Value truncated = rewriter.create<LLVM::FPTruncOp>(
110         op->getLoc(), adaptor.getOperands().front().getType(),
111         callOp.getResult());
112     rewriter.replaceOp(op, {truncated});
113     return success();
114   }
115 
116   Value maybeCast(Value operand, PatternRewriter &rewriter) const {
117     Type type = operand.getType();
118     if (!isa<Float16Type, BFloat16Type>(type))
119       return operand;
120 
121     // if there's a f16 function, no need to cast f16 values
122     if (!f16Func.empty() && isa<Float16Type>(type))
123       return operand;
124 
125     return rewriter.create<LLVM::FPExtOp>(
126         operand.getLoc(), Float32Type::get(rewriter.getContext()), operand);
127   }
128 
129   Type getFunctionType(Type resultType, ValueRange operands) const {
130     SmallVector<Type> operandTypes(operands.getTypes());
131     return LLVM::LLVMFunctionType::get(resultType, operandTypes);
132   }
133 
134   LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName, Type funcType,
135                                      Operation *op) const {
136     using LLVM::LLVMFuncOp;
137 
138     auto funcAttr = StringAttr::get(op->getContext(), funcName);
139     auto funcOp =
140         SymbolTable::lookupNearestSymbolFrom<LLVMFuncOp>(op, funcAttr);
141     if (funcOp)
142       return funcOp;
143 
144     auto parentFunc = op->getParentOfType<FunctionOpInterface>();
145     assert(parentFunc && "expected there to be a parent function");
146     OpBuilder b(parentFunc);
147     return b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType);
148   }
149 
150   StringRef getFunctionName(Type type, SourceOp op) const {
151     bool useApprox = false;
152     if constexpr (llvm::is_detected<has_get_fastmath_t, SourceOp>::value) {
153       arith::FastMathFlags flag = op.getFastmath();
154       useApprox = ((uint32_t)arith::FastMathFlags::afn & (uint32_t)flag) &&
155                   !f32ApproxFunc.empty();
156     }
157 
158     if (isa<Float16Type>(type))
159       return f16Func;
160     if (isa<Float32Type>(type)) {
161       if (useApprox)
162         return f32ApproxFunc;
163       return f32Func;
164     }
165     if (isa<Float64Type>(type))
166       return f64Func;
167 
168     if (type.isInteger(32))
169       return i32Func;
170     return "";
171   }
172 
173   const std::string f32Func;
174   const std::string f64Func;
175   const std::string f32ApproxFunc;
176   const std::string f16Func;
177   const std::string i32Func;
178 };
179 
180 } // namespace mlir
181 
182 #endif // MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
183