1 //===- OpToFuncCallLowering.h - GPU ops lowering to custom calls *- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 #ifndef MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_ 9 #define MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_ 10 11 #include "mlir/Conversion/LLVMCommon/Pattern.h" 12 #include "mlir/Dialect/Arith/IR/Arith.h" 13 #include "mlir/Dialect/GPU/IR/GPUDialect.h" 14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15 #include "mlir/IR/Builders.h" 16 17 namespace mlir { 18 19 namespace { 20 /// Detection trait tor the `getFastmath` instance method. 21 template <typename T> 22 using has_get_fastmath_t = decltype(std::declval<T>().getFastmath()); 23 } // namespace 24 25 /// Rewriting that replaces SourceOp with a CallOp to `f32Func` or `f64Func` or 26 /// `f32ApproxFunc` or `f16Func` or `i32Type` depending on the element type and 27 /// the fastMathFlag of that Op, if present. The function declaration is added 28 /// in case it was not added before. 29 /// 30 /// If the input values are of bf16 type (or f16 type if f16Func is empty), the 31 /// value is first casted to f32, the function called and then the result casted 32 /// back. 33 /// 34 /// Example with NVVM: 35 /// %exp_f32 = math.exp %arg_f32 : f32 36 /// 37 /// will be transformed into 38 /// llvm.call @__nv_expf(%arg_f32) : (f32) -> f32 39 /// 40 /// If the fastMathFlag attribute of SourceOp is `afn` or `fast`, this Op lowers 41 /// to the approximate calculation function. 42 /// 43 /// Also example with NVVM: 44 /// %exp_f32 = math.exp %arg_f32 fastmath<afn> : f32 45 /// 46 /// will be transformed into 47 /// llvm.call @__nv_fast_expf(%arg_f32) : (f32) -> f32 48 /// 49 /// Final example with NVVM: 50 /// %pow_f32 = math.fpowi %arg_f32, %arg_i32 51 /// 52 /// will be transformed into 53 /// llvm.call @__nv_powif(%arg_f32, %arg_i32) : (f32, i32) -> f32 54 template <typename SourceOp> 55 struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> { 56 public: 57 explicit OpToFuncCallLowering(const LLVMTypeConverter &lowering, 58 StringRef f32Func, StringRef f64Func, 59 StringRef f32ApproxFunc, StringRef f16Func, 60 StringRef i32Func = "") 61 : ConvertOpToLLVMPattern<SourceOp>(lowering), f32Func(f32Func), 62 f64Func(f64Func), f32ApproxFunc(f32ApproxFunc), f16Func(f16Func), 63 i32Func(i32Func) {} 64 65 LogicalResult 66 matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, 67 ConversionPatternRewriter &rewriter) const override { 68 using LLVM::LLVMFuncOp; 69 70 static_assert( 71 std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value, 72 "expected single result op"); 73 74 if constexpr (!std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>, 75 SourceOp>::value) { 76 assert(op->getNumOperands() > 0 && 77 "expected op to take at least one operand"); 78 assert(op->getResultTypes().front() == op->getOperand(0).getType() && 79 "expected op with same operand and result types"); 80 } 81 82 if (!op->template getParentOfType<FunctionOpInterface>()) { 83 return rewriter.notifyMatchFailure( 84 op, "expected op to be within a function region"); 85 } 86 87 SmallVector<Value, 1> castedOperands; 88 for (Value operand : adaptor.getOperands()) 89 castedOperands.push_back(maybeCast(operand, rewriter)); 90 91 Type resultType = castedOperands.front().getType(); 92 Type funcType = getFunctionType(resultType, castedOperands); 93 StringRef funcName = getFunctionName( 94 cast<LLVM::LLVMFunctionType>(funcType).getReturnType(), op); 95 if (funcName.empty()) 96 return failure(); 97 98 LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op); 99 auto callOp = 100 rewriter.create<LLVM::CallOp>(op->getLoc(), funcOp, castedOperands); 101 102 if (resultType == adaptor.getOperands().front().getType()) { 103 rewriter.replaceOp(op, {callOp.getResult()}); 104 return success(); 105 } 106 107 assert(callOp.getResult().getType().isF32() && 108 "only f32 types are supposed to be truncated back"); 109 Value truncated = rewriter.create<LLVM::FPTruncOp>( 110 op->getLoc(), adaptor.getOperands().front().getType(), 111 callOp.getResult()); 112 rewriter.replaceOp(op, {truncated}); 113 return success(); 114 } 115 116 Value maybeCast(Value operand, PatternRewriter &rewriter) const { 117 Type type = operand.getType(); 118 if (!isa<Float16Type, BFloat16Type>(type)) 119 return operand; 120 121 // if there's a f16 function, no need to cast f16 values 122 if (!f16Func.empty() && isa<Float16Type>(type)) 123 return operand; 124 125 return rewriter.create<LLVM::FPExtOp>( 126 operand.getLoc(), Float32Type::get(rewriter.getContext()), operand); 127 } 128 129 Type getFunctionType(Type resultType, ValueRange operands) const { 130 SmallVector<Type> operandTypes(operands.getTypes()); 131 return LLVM::LLVMFunctionType::get(resultType, operandTypes); 132 } 133 134 LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName, Type funcType, 135 Operation *op) const { 136 using LLVM::LLVMFuncOp; 137 138 auto funcAttr = StringAttr::get(op->getContext(), funcName); 139 auto funcOp = 140 SymbolTable::lookupNearestSymbolFrom<LLVMFuncOp>(op, funcAttr); 141 if (funcOp) 142 return funcOp; 143 144 auto parentFunc = op->getParentOfType<FunctionOpInterface>(); 145 assert(parentFunc && "expected there to be a parent function"); 146 OpBuilder b(parentFunc); 147 return b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType); 148 } 149 150 StringRef getFunctionName(Type type, SourceOp op) const { 151 bool useApprox = false; 152 if constexpr (llvm::is_detected<has_get_fastmath_t, SourceOp>::value) { 153 arith::FastMathFlags flag = op.getFastmath(); 154 useApprox = ((uint32_t)arith::FastMathFlags::afn & (uint32_t)flag) && 155 !f32ApproxFunc.empty(); 156 } 157 158 if (isa<Float16Type>(type)) 159 return f16Func; 160 if (isa<Float32Type>(type)) { 161 if (useApprox) 162 return f32ApproxFunc; 163 return f32Func; 164 } 165 if (isa<Float64Type>(type)) 166 return f64Func; 167 168 if (type.isInteger(32)) 169 return i32Func; 170 return ""; 171 } 172 173 const std::string f32Func; 174 const std::string f64Func; 175 const std::string f32ApproxFunc; 176 const std::string f16Func; 177 const std::string i32Func; 178 }; 179 180 } // namespace mlir 181 182 #endif // MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_ 183