1 //===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" 10 11 #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h" 12 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" 13 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 14 #include "mlir/Conversion/LLVMCommon/Pattern.h" 15 #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 17 #include "mlir/Dialect/Math/IR/Math.h" 18 #include "mlir/IR/TypeUtilities.h" 19 #include "mlir/Pass/Pass.h" 20 21 namespace mlir { 22 #define GEN_PASS_DEF_CONVERTMATHTOLLVMPASS 23 #include "mlir/Conversion/Passes.h.inc" 24 } // namespace mlir 25 26 using namespace mlir; 27 28 namespace { 29 30 template <typename SourceOp, typename TargetOp> 31 using ConvertFastMath = arith::AttrConvertFastMathToLLVM<SourceOp, TargetOp>; 32 33 template <typename SourceOp, typename TargetOp> 34 using ConvertFMFMathToLLVMPattern = 35 VectorConvertToLLVMPattern<SourceOp, TargetOp, ConvertFastMath>; 36 37 using AbsFOpLowering = ConvertFMFMathToLLVMPattern<math::AbsFOp, LLVM::FAbsOp>; 38 using CeilOpLowering = ConvertFMFMathToLLVMPattern<math::CeilOp, LLVM::FCeilOp>; 39 using CopySignOpLowering = 40 ConvertFMFMathToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>; 41 using CosOpLowering = ConvertFMFMathToLLVMPattern<math::CosOp, LLVM::CosOp>; 42 using CtPopFOpLowering = 43 VectorConvertToLLVMPattern<math::CtPopOp, LLVM::CtPopOp>; 44 using Exp2OpLowering = ConvertFMFMathToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>; 45 using ExpOpLowering = ConvertFMFMathToLLVMPattern<math::ExpOp, LLVM::ExpOp>; 46 using FloorOpLowering = 47 ConvertFMFMathToLLVMPattern<math::FloorOp, LLVM::FFloorOp>; 48 using FmaOpLowering = ConvertFMFMathToLLVMPattern<math::FmaOp, LLVM::FMAOp>; 49 using Log10OpLowering = 50 ConvertFMFMathToLLVMPattern<math::Log10Op, LLVM::Log10Op>; 51 using Log2OpLowering = ConvertFMFMathToLLVMPattern<math::Log2Op, LLVM::Log2Op>; 52 using LogOpLowering = ConvertFMFMathToLLVMPattern<math::LogOp, LLVM::LogOp>; 53 using PowFOpLowering = ConvertFMFMathToLLVMPattern<math::PowFOp, LLVM::PowOp>; 54 using FPowIOpLowering = 55 ConvertFMFMathToLLVMPattern<math::FPowIOp, LLVM::PowIOp>; 56 using RoundEvenOpLowering = 57 ConvertFMFMathToLLVMPattern<math::RoundEvenOp, LLVM::RoundEvenOp>; 58 using RoundOpLowering = 59 ConvertFMFMathToLLVMPattern<math::RoundOp, LLVM::RoundOp>; 60 using SinOpLowering = ConvertFMFMathToLLVMPattern<math::SinOp, LLVM::SinOp>; 61 using SqrtOpLowering = ConvertFMFMathToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>; 62 using FTruncOpLowering = 63 ConvertFMFMathToLLVMPattern<math::TruncOp, LLVM::FTruncOp>; 64 65 // A `CtLz/CtTz/absi(a)` is converted into `CtLz/CtTz/absi(a, false)`. 66 template <typename MathOp, typename LLVMOp> 67 struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> { 68 using ConvertOpToLLVMPattern<MathOp>::ConvertOpToLLVMPattern; 69 using Super = IntOpWithFlagLowering<MathOp, LLVMOp>; 70 71 LogicalResult 72 matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor, 73 ConversionPatternRewriter &rewriter) const override { 74 auto operandType = adaptor.getOperand().getType(); 75 76 if (!operandType || !LLVM::isCompatibleType(operandType)) 77 return failure(); 78 79 auto loc = op.getLoc(); 80 auto resultType = op.getResult().getType(); 81 82 if (!isa<LLVM::LLVMArrayType>(operandType)) { 83 rewriter.replaceOpWithNewOp<LLVMOp>(op, resultType, adaptor.getOperand(), 84 false); 85 return success(); 86 } 87 88 auto vectorType = dyn_cast<VectorType>(resultType); 89 if (!vectorType) 90 return failure(); 91 92 return LLVM::detail::handleMultidimensionalVectors( 93 op.getOperation(), adaptor.getOperands(), *this->getTypeConverter(), 94 [&](Type llvm1DVectorTy, ValueRange operands) { 95 return rewriter.create<LLVMOp>(loc, llvm1DVectorTy, operands[0], 96 false); 97 }, 98 rewriter); 99 } 100 }; 101 102 using CountLeadingZerosOpLowering = 103 IntOpWithFlagLowering<math::CountLeadingZerosOp, LLVM::CountLeadingZerosOp>; 104 using CountTrailingZerosOpLowering = 105 IntOpWithFlagLowering<math::CountTrailingZerosOp, 106 LLVM::CountTrailingZerosOp>; 107 using AbsIOpLowering = IntOpWithFlagLowering<math::AbsIOp, LLVM::AbsOp>; 108 109 // A `expm1` is converted into `exp - 1`. 110 struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> { 111 using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern; 112 113 LogicalResult 114 matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor, 115 ConversionPatternRewriter &rewriter) const override { 116 auto operandType = adaptor.getOperand().getType(); 117 118 if (!operandType || !LLVM::isCompatibleType(operandType)) 119 return failure(); 120 121 auto loc = op.getLoc(); 122 auto resultType = op.getResult().getType(); 123 auto floatType = cast<FloatType>(getElementTypeOrSelf(resultType)); 124 auto floatOne = rewriter.getFloatAttr(floatType, 1.0); 125 ConvertFastMath<math::ExpM1Op, LLVM::ExpOp> expAttrs(op); 126 ConvertFastMath<math::ExpM1Op, LLVM::FSubOp> subAttrs(op); 127 128 if (!isa<LLVM::LLVMArrayType>(operandType)) { 129 LLVM::ConstantOp one; 130 if (LLVM::isCompatibleVectorType(operandType)) { 131 one = rewriter.create<LLVM::ConstantOp>( 132 loc, operandType, 133 SplatElementsAttr::get(cast<ShapedType>(resultType), floatOne)); 134 } else { 135 one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne); 136 } 137 auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.getOperand(), 138 expAttrs.getAttrs()); 139 rewriter.replaceOpWithNewOp<LLVM::FSubOp>( 140 op, operandType, ValueRange{exp, one}, subAttrs.getAttrs()); 141 return success(); 142 } 143 144 auto vectorType = dyn_cast<VectorType>(resultType); 145 if (!vectorType) 146 return rewriter.notifyMatchFailure(op, "expected vector result type"); 147 148 return LLVM::detail::handleMultidimensionalVectors( 149 op.getOperation(), adaptor.getOperands(), *getTypeConverter(), 150 [&](Type llvm1DVectorTy, ValueRange operands) { 151 auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy); 152 auto splatAttr = SplatElementsAttr::get( 153 mlir::VectorType::get({numElements.getKnownMinValue()}, floatType, 154 {numElements.isScalable()}), 155 floatOne); 156 auto one = 157 rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); 158 auto exp = rewriter.create<LLVM::ExpOp>( 159 loc, llvm1DVectorTy, operands[0], expAttrs.getAttrs()); 160 return rewriter.create<LLVM::FSubOp>( 161 loc, llvm1DVectorTy, ValueRange{exp, one}, subAttrs.getAttrs()); 162 }, 163 rewriter); 164 } 165 }; 166 167 // A `log1p` is converted into `log(1 + ...)`. 168 struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> { 169 using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern; 170 171 LogicalResult 172 matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor, 173 ConversionPatternRewriter &rewriter) const override { 174 auto operandType = adaptor.getOperand().getType(); 175 176 if (!operandType || !LLVM::isCompatibleType(operandType)) 177 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 178 179 auto loc = op.getLoc(); 180 auto resultType = op.getResult().getType(); 181 auto floatType = cast<FloatType>(getElementTypeOrSelf(resultType)); 182 auto floatOne = rewriter.getFloatAttr(floatType, 1.0); 183 ConvertFastMath<math::Log1pOp, LLVM::FAddOp> addAttrs(op); 184 ConvertFastMath<math::Log1pOp, LLVM::LogOp> logAttrs(op); 185 186 if (!isa<LLVM::LLVMArrayType>(operandType)) { 187 LLVM::ConstantOp one = 188 LLVM::isCompatibleVectorType(operandType) 189 ? rewriter.create<LLVM::ConstantOp>( 190 loc, operandType, 191 SplatElementsAttr::get(cast<ShapedType>(resultType), 192 floatOne)) 193 : rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne); 194 195 auto add = rewriter.create<LLVM::FAddOp>( 196 loc, operandType, ValueRange{one, adaptor.getOperand()}, 197 addAttrs.getAttrs()); 198 rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, ValueRange{add}, 199 logAttrs.getAttrs()); 200 return success(); 201 } 202 203 auto vectorType = dyn_cast<VectorType>(resultType); 204 if (!vectorType) 205 return rewriter.notifyMatchFailure(op, "expected vector result type"); 206 207 return LLVM::detail::handleMultidimensionalVectors( 208 op.getOperation(), adaptor.getOperands(), *getTypeConverter(), 209 [&](Type llvm1DVectorTy, ValueRange operands) { 210 auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy); 211 auto splatAttr = SplatElementsAttr::get( 212 mlir::VectorType::get({numElements.getKnownMinValue()}, floatType, 213 {numElements.isScalable()}), 214 floatOne); 215 auto one = 216 rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); 217 auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy, 218 ValueRange{one, operands[0]}, 219 addAttrs.getAttrs()); 220 return rewriter.create<LLVM::LogOp>( 221 loc, llvm1DVectorTy, ValueRange{add}, logAttrs.getAttrs()); 222 }, 223 rewriter); 224 } 225 }; 226 227 // A `rsqrt` is converted into `1 / sqrt`. 228 struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> { 229 using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern; 230 231 LogicalResult 232 matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor, 233 ConversionPatternRewriter &rewriter) const override { 234 auto operandType = adaptor.getOperand().getType(); 235 236 if (!operandType || !LLVM::isCompatibleType(operandType)) 237 return failure(); 238 239 auto loc = op.getLoc(); 240 auto resultType = op.getResult().getType(); 241 auto floatType = cast<FloatType>(getElementTypeOrSelf(resultType)); 242 auto floatOne = rewriter.getFloatAttr(floatType, 1.0); 243 ConvertFastMath<math::RsqrtOp, LLVM::SqrtOp> sqrtAttrs(op); 244 ConvertFastMath<math::RsqrtOp, LLVM::FDivOp> divAttrs(op); 245 246 if (!isa<LLVM::LLVMArrayType>(operandType)) { 247 LLVM::ConstantOp one; 248 if (LLVM::isCompatibleVectorType(operandType)) { 249 one = rewriter.create<LLVM::ConstantOp>( 250 loc, operandType, 251 SplatElementsAttr::get(cast<ShapedType>(resultType), floatOne)); 252 } else { 253 one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne); 254 } 255 auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.getOperand(), 256 sqrtAttrs.getAttrs()); 257 rewriter.replaceOpWithNewOp<LLVM::FDivOp>( 258 op, operandType, ValueRange{one, sqrt}, divAttrs.getAttrs()); 259 return success(); 260 } 261 262 auto vectorType = dyn_cast<VectorType>(resultType); 263 if (!vectorType) 264 return failure(); 265 266 return LLVM::detail::handleMultidimensionalVectors( 267 op.getOperation(), adaptor.getOperands(), *getTypeConverter(), 268 [&](Type llvm1DVectorTy, ValueRange operands) { 269 auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy); 270 auto splatAttr = SplatElementsAttr::get( 271 mlir::VectorType::get({numElements.getKnownMinValue()}, floatType, 272 {numElements.isScalable()}), 273 floatOne); 274 auto one = 275 rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); 276 auto sqrt = rewriter.create<LLVM::SqrtOp>( 277 loc, llvm1DVectorTy, operands[0], sqrtAttrs.getAttrs()); 278 return rewriter.create<LLVM::FDivOp>( 279 loc, llvm1DVectorTy, ValueRange{one, sqrt}, divAttrs.getAttrs()); 280 }, 281 rewriter); 282 } 283 }; 284 285 struct ConvertMathToLLVMPass 286 : public impl::ConvertMathToLLVMPassBase<ConvertMathToLLVMPass> { 287 using Base::Base; 288 289 void runOnOperation() override { 290 RewritePatternSet patterns(&getContext()); 291 LLVMTypeConverter converter(&getContext()); 292 populateMathToLLVMConversionPatterns(converter, patterns, approximateLog1p); 293 LLVMConversionTarget target(getContext()); 294 if (failed(applyPartialConversion(getOperation(), target, 295 std::move(patterns)))) 296 signalPassFailure(); 297 } 298 }; 299 } // namespace 300 301 void mlir::populateMathToLLVMConversionPatterns( 302 const LLVMTypeConverter &converter, RewritePatternSet &patterns, 303 bool approximateLog1p) { 304 if (approximateLog1p) 305 patterns.add<Log1pOpLowering>(converter); 306 // clang-format off 307 patterns.add< 308 AbsFOpLowering, 309 AbsIOpLowering, 310 CeilOpLowering, 311 CopySignOpLowering, 312 CosOpLowering, 313 CountLeadingZerosOpLowering, 314 CountTrailingZerosOpLowering, 315 CtPopFOpLowering, 316 Exp2OpLowering, 317 ExpM1OpLowering, 318 ExpOpLowering, 319 FPowIOpLowering, 320 FloorOpLowering, 321 FmaOpLowering, 322 Log10OpLowering, 323 Log2OpLowering, 324 LogOpLowering, 325 PowFOpLowering, 326 RoundEvenOpLowering, 327 RoundOpLowering, 328 RsqrtOpLowering, 329 SinOpLowering, 330 SqrtOpLowering, 331 FTruncOpLowering 332 >(converter); 333 // clang-format on 334 } 335 336 //===----------------------------------------------------------------------===// 337 // ConvertToLLVMPatternInterface implementation 338 //===----------------------------------------------------------------------===// 339 340 namespace { 341 /// Implement the interface to convert Math to LLVM. 342 struct MathToLLVMDialectInterface : public ConvertToLLVMPatternInterface { 343 using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; 344 void loadDependentDialects(MLIRContext *context) const final { 345 context->loadDialect<LLVM::LLVMDialect>(); 346 } 347 348 /// Hook for derived dialect interface to provide conversion patterns 349 /// and mark dialect legal for the conversion target. 350 void populateConvertToLLVMConversionPatterns( 351 ConversionTarget &target, LLVMTypeConverter &typeConverter, 352 RewritePatternSet &patterns) const final { 353 populateMathToLLVMConversionPatterns(typeConverter, patterns); 354 } 355 }; 356 } // namespace 357 358 void mlir::registerConvertMathToLLVMInterface(DialectRegistry ®istry) { 359 registry.addExtension(+[](MLIRContext *ctx, math::MathDialect *dialect) { 360 dialect->addInterfaces<MathToLLVMDialectInterface>(); 361 }); 362 } 363