1 //===-- MathToROCDL.cpp - conversion from Math to rocdl calls -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/MathToROCDL/MathToROCDL.h" 10 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" 11 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 12 #include "mlir/Dialect/Func/IR/FuncOps.h" 13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 14 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" 15 #include "mlir/Dialect/Math/IR/Math.h" 16 #include "mlir/Dialect/Utils/IndexingUtils.h" 17 #include "mlir/Dialect/Vector/IR/VectorOps.h" 18 #include "mlir/IR/BuiltinDialect.h" 19 #include "mlir/IR/PatternMatch.h" 20 #include "mlir/Pass/Pass.h" 21 #include "mlir/Transforms/DialectConversion.h" 22 23 #include "../GPUCommon/GPUOpsLowering.h" 24 #include "../GPUCommon/IndexIntrinsicsOpLowering.h" 25 #include "../GPUCommon/OpToFuncCallLowering.h" 26 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" 27 28 namespace mlir { 29 #define GEN_PASS_DEF_CONVERTMATHTOROCDL 30 #include "mlir/Conversion/Passes.h.inc" 31 } // namespace mlir 32 33 using namespace mlir; 34 35 #define DEBUG_TYPE "math-to-rocdl" 36 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") 37 38 template <typename OpTy> 39 static void populateOpPatterns(const LLVMTypeConverter &converter, 40 RewritePatternSet &patterns, StringRef f32Func, 41 StringRef f64Func, StringRef f16Func, 42 StringRef f32ApproxFunc = "") { 43 patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter); 44 patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func, 45 f32ApproxFunc, f16Func); 46 } 47 48 void mlir::populateMathToROCDLConversionPatterns( 49 const LLVMTypeConverter &converter, RewritePatternSet &patterns) { 50 // Handled by mathToLLVM: math::AbsIOp 51 // Handled by mathToLLVM: math::AbsFOp 52 // Handled by mathToLLVM: math::CopySignOp 53 // Handled by mathToLLVM: math::CountLeadingZerosOp 54 // Handled by mathToLLVM: math::CountTrailingZerosOp 55 // Handled by mathToLLVM: math::CgPopOp 56 // Handled by mathToLLVM: math::ExpOp (32-bit only) 57 // Handled by mathToLLVM: math::FmaOp 58 // Handled by mathToLLVM: math::LogOp (32-bit only) 59 // FIXME: math::IPowIOp 60 // Handled by mathToLLVM: math::RoundEvenOp 61 // Handled by mathToLLVM: math::RoundOp 62 // Handled by mathToLLVM: math::SqrtOp 63 // Handled by mathToLLVM: math::TruncOp 64 populateOpPatterns<math::AcosOp>(converter, patterns, "__ocml_acos_f32", 65 "__ocml_acos_f64", "__ocml_acos_f16"); 66 populateOpPatterns<math::AcoshOp>(converter, patterns, "__ocml_acosh_f32", 67 "__ocml_acosh_f64", "__ocml_acosh_f16"); 68 populateOpPatterns<math::AsinOp>(converter, patterns, "__ocml_asin_f32", 69 "__ocml_asin_f64", "__ocml_asin_f16"); 70 populateOpPatterns<math::AsinhOp>(converter, patterns, "__ocml_asinh_f32", 71 "__ocml_asinh_f64", "__ocml_asinh_f16"); 72 populateOpPatterns<math::AtanOp>(converter, patterns, "__ocml_atan_f32", 73 "__ocml_atan_f64", "__ocml_atan_f16"); 74 populateOpPatterns<math::AtanhOp>(converter, patterns, "__ocml_atanh_f32", 75 "__ocml_atanh_f64", "__ocml_atanh_f16"); 76 populateOpPatterns<math::Atan2Op>(converter, patterns, "__ocml_atan2_f32", 77 "__ocml_atan2_f64", "__ocml_atan2_f16"); 78 populateOpPatterns<math::CbrtOp>(converter, patterns, "__ocml_cbrt_f32", 79 "__ocml_cbrt_f64", "__ocml_cbrt_f16"); 80 populateOpPatterns<math::CeilOp>(converter, patterns, "__ocml_ceil_f32", 81 "__ocml_ceil_f64", "__ocml_ceil_f16"); 82 populateOpPatterns<math::CosOp>(converter, patterns, "__ocml_cos_f32", 83 "__ocml_cos_f64", "__ocml_cos_f16"); 84 populateOpPatterns<math::CoshOp>(converter, patterns, "__ocml_cosh_f32", 85 "__ocml_cosh_f64", "__ocml_cosh_f16"); 86 populateOpPatterns<math::SinhOp>(converter, patterns, "__ocml_sinh_f32", 87 "__ocml_sinh_f64", "__ocml_sinh_f16"); 88 populateOpPatterns<math::ExpOp>(converter, patterns, "", "__ocml_exp_f64", 89 "__ocml_exp_f16"); 90 populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32", 91 "__ocml_exp2_f64", "__ocml_exp2_f16"); 92 populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32", 93 "__ocml_expm1_f64", "__ocml_expm1_f16"); 94 populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32", 95 "__ocml_floor_f64", "__ocml_floor_f16"); 96 populateOpPatterns<math::LogOp>(converter, patterns, "", "__ocml_log_f64", 97 "__ocml_log_f16"); 98 populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32", 99 "__ocml_log10_f64", "__ocml_log10_f16"); 100 populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32", 101 "__ocml_log1p_f64", "__ocml_log1p_f16"); 102 populateOpPatterns<math::Log2Op>(converter, patterns, "__ocml_log2_f32", 103 "__ocml_log2_f64", "__ocml_log2_f16"); 104 populateOpPatterns<math::PowFOp>(converter, patterns, "__ocml_pow_f32", 105 "__ocml_pow_f64", "__ocml_pow_f16"); 106 populateOpPatterns<math::RsqrtOp>(converter, patterns, "__ocml_rsqrt_f32", 107 "__ocml_rsqrt_f64", "__ocml_rsqrt_f16"); 108 populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32", 109 "__ocml_sin_f64", "__ocml_sin_f16"); 110 populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32", 111 "__ocml_tanh_f64", "__ocml_tanh_f16"); 112 populateOpPatterns<math::TanOp>(converter, patterns, "__ocml_tan_f32", 113 "__ocml_tan_f64", "__ocml_tan_f16"); 114 populateOpPatterns<math::ErfOp>(converter, patterns, "__ocml_erf_f32", 115 "__ocml_erf_f64", "__ocml_erf_f16"); 116 populateOpPatterns<math::FPowIOp>(converter, patterns, "__ocml_pown_f32", 117 "__ocml_pown_f64", "__ocml_pown_f16"); 118 // Single arith pattern that needs a ROCDL call, probably not 119 // worth creating a separate pass for it. 120 populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32", 121 "__ocml_fmod_f64", "__ocml_fmod_f16"); 122 } 123 124 namespace { 125 struct ConvertMathToROCDLPass 126 : public impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> { 127 ConvertMathToROCDLPass() = default; 128 void runOnOperation() override; 129 }; 130 } // namespace 131 132 void ConvertMathToROCDLPass::runOnOperation() { 133 auto m = getOperation(); 134 MLIRContext *ctx = m.getContext(); 135 136 RewritePatternSet patterns(&getContext()); 137 LowerToLLVMOptions options(ctx, DataLayout(m)); 138 LLVMTypeConverter converter(ctx, options); 139 populateMathToROCDLConversionPatterns(converter, patterns); 140 ConversionTarget target(getContext()); 141 target.addLegalDialect<BuiltinDialect, func::FuncDialect, 142 vector::VectorDialect, LLVM::LLVMDialect>(); 143 target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp, 144 LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp, 145 LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp, 146 LLVM::SqrtOp>(); 147 if (failed(applyPartialConversion(m, target, std::move(patterns)))) 148 signalPassFailure(); 149 } 150