134810e1bSTres Popp //===-- MathToLibm.cpp - conversion from Math to libm calls ---------------===// 234810e1bSTres Popp // 334810e1bSTres Popp // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 434810e1bSTres Popp // See https://llvm.org/LICENSE.txt for license information. 534810e1bSTres Popp // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 634810e1bSTres Popp // 734810e1bSTres Popp //===----------------------------------------------------------------------===// 834810e1bSTres Popp 934810e1bSTres Popp #include "mlir/Conversion/MathToLibm/MathToLibm.h" 1034810e1bSTres Popp 11abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h" 1223aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h" 139c442c7dSSlava Zakharin #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 1434810e1bSTres Popp #include "mlir/Dialect/Math/IR/Math.h" 1599ef9eebSMatthias Springer #include "mlir/Dialect/Utils/IndexingUtils.h" 1699ef9eebSMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h" 1734810e1bSTres Popp #include "mlir/IR/BuiltinDialect.h" 1834810e1bSTres Popp #include "mlir/IR/PatternMatch.h" 1967d0d7acSMichele Scuttari #include "mlir/Pass/Pass.h" 208a9d4895SAlexander Belyaev #include "mlir/Transforms/DialectConversion.h" 2167d0d7acSMichele Scuttari 2267d0d7acSMichele Scuttari namespace mlir { 2367d0d7acSMichele Scuttari #define GEN_PASS_DEF_CONVERTMATHTOLIBM 2467d0d7acSMichele Scuttari #include "mlir/Conversion/Passes.h.inc" 2567d0d7acSMichele Scuttari } // namespace mlir 2634810e1bSTres Popp 2734810e1bSTres Popp using namespace mlir; 2834810e1bSTres Popp 2934810e1bSTres Popp namespace { 3034810e1bSTres Popp // Pattern to convert vector operations to scalar operations. This is needed as 3134810e1bSTres Popp // libm calls require scalars. 3234810e1bSTres Popp template <typename Op> 3334810e1bSTres Popp struct VecOpToScalarOp : public OpRewritePattern<Op> { 3434810e1bSTres Popp public: 3534810e1bSTres Popp using OpRewritePattern<Op>::OpRewritePattern; 3634810e1bSTres Popp 3734810e1bSTres Popp LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; 3834810e1bSTres Popp }; 39a48adc56SBenjamin Kramer // Pattern to promote an op of a smaller floating point type to F32. 40a48adc56SBenjamin Kramer template <typename Op> 41a48adc56SBenjamin Kramer struct PromoteOpToF32 : public OpRewritePattern<Op> { 42a48adc56SBenjamin Kramer public: 43a48adc56SBenjamin Kramer using OpRewritePattern<Op>::OpRewritePattern; 44a48adc56SBenjamin Kramer 45a48adc56SBenjamin Kramer LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; 46a48adc56SBenjamin Kramer }; 4734810e1bSTres Popp // Pattern to convert scalar math operations to calls to libm functions. 4834810e1bSTres Popp // Additionally the libm function signatures are declared. 4934810e1bSTres Popp template <typename Op> 5034810e1bSTres Popp struct ScalarOpToLibmCall : public OpRewritePattern<Op> { 5134810e1bSTres Popp public: 5234810e1bSTres Popp using OpRewritePattern<Op>::OpRewritePattern; 53a4ee55feSAlexander Batashev ScalarOpToLibmCall(MLIRContext *context, StringRef floatFunc, 548a9d4895SAlexander Belyaev StringRef doubleFunc) 558a9d4895SAlexander Belyaev : OpRewritePattern<Op>(context), floatFunc(floatFunc), 5634810e1bSTres Popp doubleFunc(doubleFunc){}; 5734810e1bSTres Popp 5834810e1bSTres Popp LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; 5934810e1bSTres Popp 6034810e1bSTres Popp private: 6134810e1bSTres Popp std::string floatFunc, doubleFunc; 6234810e1bSTres Popp }; 633bf1f0e7SAlexander Belyaev 643bf1f0e7SAlexander Belyaev template <typename OpTy> 653bf1f0e7SAlexander Belyaev void populatePatternsForOp(RewritePatternSet &patterns, MLIRContext *ctx, 663bf1f0e7SAlexander Belyaev StringRef floatFunc, StringRef doubleFunc) { 673bf1f0e7SAlexander Belyaev patterns.add<VecOpToScalarOp<OpTy>, PromoteOpToF32<OpTy>>(ctx); 683bf1f0e7SAlexander Belyaev patterns.add<ScalarOpToLibmCall<OpTy>>(ctx, floatFunc, doubleFunc); 693bf1f0e7SAlexander Belyaev } 703bf1f0e7SAlexander Belyaev 7134810e1bSTres Popp } // namespace 7234810e1bSTres Popp 7334810e1bSTres Popp template <typename Op> 7434810e1bSTres Popp LogicalResult 7534810e1bSTres Popp VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const { 7634810e1bSTres Popp auto opType = op.getType(); 7734810e1bSTres Popp auto loc = op.getLoc(); 785550c821STres Popp auto vecType = dyn_cast<VectorType>(opType); 7934810e1bSTres Popp 8034810e1bSTres Popp if (!vecType) 8134810e1bSTres Popp return failure(); 8234810e1bSTres Popp if (!vecType.hasRank()) 8334810e1bSTres Popp return failure(); 8434810e1bSTres Popp auto shape = vecType.getShape(); 85921d91f3SAdrian Kuegel int64_t numElements = vecType.getNumElements(); 8634810e1bSTres Popp 87a54f4eaeSMogball Value result = rewriter.create<arith::ConstantOp>( 8834810e1bSTres Popp loc, DenseElementsAttr::get( 8934810e1bSTres Popp vecType, FloatAttr::get(vecType.getElementType(), 0.0))); 907a69a9d7SNicolas Vasilache SmallVector<int64_t> strides = computeStrides(shape); 91921d91f3SAdrian Kuegel for (auto linearIndex = 0; linearIndex < numElements; ++linearIndex) { 92203fad47SNicolas Vasilache SmallVector<int64_t> positions = delinearize(linearIndex, strides); 9334810e1bSTres Popp SmallVector<Value> operands; 9434810e1bSTres Popp for (auto input : op->getOperands()) 9534810e1bSTres Popp operands.push_back( 96921d91f3SAdrian Kuegel rewriter.create<vector::ExtractOp>(loc, input, positions)); 9734810e1bSTres Popp Value scalarOp = 9834810e1bSTres Popp rewriter.create<Op>(loc, vecType.getElementType(), operands); 99921d91f3SAdrian Kuegel result = 100921d91f3SAdrian Kuegel rewriter.create<vector::InsertOp>(loc, scalarOp, result, positions); 10134810e1bSTres Popp } 10234810e1bSTres Popp rewriter.replaceOp(op, {result}); 10334810e1bSTres Popp return success(); 10434810e1bSTres Popp } 10534810e1bSTres Popp 10634810e1bSTres Popp template <typename Op> 10734810e1bSTres Popp LogicalResult 108a48adc56SBenjamin Kramer PromoteOpToF32<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const { 109a48adc56SBenjamin Kramer auto opType = op.getType(); 1105550c821STres Popp if (!isa<Float16Type, BFloat16Type>(opType)) 111a48adc56SBenjamin Kramer return failure(); 112a48adc56SBenjamin Kramer 113a48adc56SBenjamin Kramer auto loc = op.getLoc(); 114a48adc56SBenjamin Kramer auto f32 = rewriter.getF32Type(); 115a48adc56SBenjamin Kramer auto extendedOperands = llvm::to_vector( 116a48adc56SBenjamin Kramer llvm::map_range(op->getOperands(), [&](Value operand) -> Value { 117a48adc56SBenjamin Kramer return rewriter.create<arith::ExtFOp>(loc, f32, operand); 118a48adc56SBenjamin Kramer })); 119a48adc56SBenjamin Kramer auto newOp = rewriter.create<Op>(loc, f32, extendedOperands); 120a48adc56SBenjamin Kramer rewriter.replaceOpWithNewOp<arith::TruncFOp>(op, opType, newOp); 121a48adc56SBenjamin Kramer return success(); 122a48adc56SBenjamin Kramer } 123a48adc56SBenjamin Kramer 124a48adc56SBenjamin Kramer template <typename Op> 125a48adc56SBenjamin Kramer LogicalResult 12634810e1bSTres Popp ScalarOpToLibmCall<Op>::matchAndRewrite(Op op, 12734810e1bSTres Popp PatternRewriter &rewriter) const { 1281ebf7ce9STres Popp auto module = SymbolTable::getNearestSymbolTable(op); 12934810e1bSTres Popp auto type = op.getType(); 1305550c821STres Popp if (!isa<Float32Type, Float64Type>(type)) 13134810e1bSTres Popp return failure(); 13234810e1bSTres Popp 13334810e1bSTres Popp auto name = type.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc; 1341ebf7ce9STres Popp auto opFunc = dyn_cast_or_null<SymbolOpInterface>( 1351ebf7ce9STres Popp SymbolTable::lookupSymbolIn(module, name)); 13634810e1bSTres Popp // Forward declare function if it hasn't already been 13734810e1bSTres Popp if (!opFunc) { 13834810e1bSTres Popp OpBuilder::InsertionGuard guard(rewriter); 1391ebf7ce9STres Popp rewriter.setInsertionPointToStart(&module->getRegion(0).front()); 14034810e1bSTres Popp auto opFunctionTy = FunctionType::get( 14134810e1bSTres Popp rewriter.getContext(), op->getOperandTypes(), op->getResultTypes()); 14258ceae95SRiver Riddle opFunc = rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), name, 14358ceae95SRiver Riddle opFunctionTy); 14434810e1bSTres Popp opFunc.setPrivate(); 1459c442c7dSSlava Zakharin 1469c442c7dSSlava Zakharin // By definition Math dialect operations imply LLVM's "readnone" 1479c442c7dSSlava Zakharin // function attribute, so we can set it here to provide more 1489c442c7dSSlava Zakharin // optimization opportunities (e.g. LICM) for backends targeting LLVM IR. 1499c442c7dSSlava Zakharin // This will have to be changed, when strict FP behavior is supported 1509c442c7dSSlava Zakharin // by Math dialect. 1519c442c7dSSlava Zakharin opFunc->setAttr(LLVM::LLVMDialect::getReadnoneAttrName(), 1529c442c7dSSlava Zakharin UnitAttr::get(rewriter.getContext())); 15334810e1bSTres Popp } 1547ceffae1SRiver Riddle assert(isa<FunctionOpInterface>(SymbolTable::lookupSymbolIn(module, name))); 15534810e1bSTres Popp 15623aa5a74SRiver Riddle rewriter.replaceOpWithNewOp<func::CallOp>(op, name, op.getType(), 1571ebf7ce9STres Popp op->getOperands()); 15834810e1bSTres Popp 15934810e1bSTres Popp return success(); 16034810e1bSTres Popp } 16134810e1bSTres Popp 1628a9d4895SAlexander Belyaev void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns) { 1638a9d4895SAlexander Belyaev MLIRContext *ctx = patterns.getContext(); 1643bf1f0e7SAlexander Belyaev 16550b93733SCorentin Ferry populatePatternsForOp<math::AbsFOp>(patterns, ctx, "fabsf", "fabs"); 166f7250179SFrederik Harwath populatePatternsForOp<math::AcosOp>(patterns, ctx, "acosf", "acos"); 167b8dca4faSVivek Khandelwal populatePatternsForOp<math::AcoshOp>(patterns, ctx, "acoshf", "acosh"); 168b8dca4faSVivek Khandelwal populatePatternsForOp<math::AsinOp>(patterns, ctx, "asinf", "asin"); 169b8dca4faSVivek Khandelwal populatePatternsForOp<math::AsinhOp>(patterns, ctx, "asinhf", "asinh"); 1703bf1f0e7SAlexander Belyaev populatePatternsForOp<math::Atan2Op>(patterns, ctx, "atan2f", "atan2"); 1713bf1f0e7SAlexander Belyaev populatePatternsForOp<math::AtanOp>(patterns, ctx, "atanf", "atan"); 172b8dca4faSVivek Khandelwal populatePatternsForOp<math::AtanhOp>(patterns, ctx, "atanhf", "atanh"); 1733bf1f0e7SAlexander Belyaev populatePatternsForOp<math::CbrtOp>(patterns, ctx, "cbrtf", "cbrt"); 1743bf1f0e7SAlexander Belyaev populatePatternsForOp<math::CeilOp>(patterns, ctx, "ceilf", "ceil"); 1753bf1f0e7SAlexander Belyaev populatePatternsForOp<math::CosOp>(patterns, ctx, "cosf", "cos"); 176762964e9SSungsoon Cho populatePatternsForOp<math::CoshOp>(patterns, ctx, "coshf", "cosh"); 1773bf1f0e7SAlexander Belyaev populatePatternsForOp<math::ErfOp>(patterns, ctx, "erff", "erf"); 17850b93733SCorentin Ferry populatePatternsForOp<math::ExpOp>(patterns, ctx, "expf", "exp"); 17950b93733SCorentin Ferry populatePatternsForOp<math::Exp2Op>(patterns, ctx, "exp2f", "exp2"); 1803bf1f0e7SAlexander Belyaev populatePatternsForOp<math::ExpM1Op>(patterns, ctx, "expm1f", "expm1"); 1813bf1f0e7SAlexander Belyaev populatePatternsForOp<math::FloorOp>(patterns, ctx, "floorf", "floor"); 18250b93733SCorentin Ferry populatePatternsForOp<math::FmaOp>(patterns, ctx, "fmaf", "fma"); 18350b93733SCorentin Ferry populatePatternsForOp<math::LogOp>(patterns, ctx, "logf", "log"); 18450b93733SCorentin Ferry populatePatternsForOp<math::Log2Op>(patterns, ctx, "log2f", "log2"); 18550b93733SCorentin Ferry populatePatternsForOp<math::Log10Op>(patterns, ctx, "log10f", "log10"); 1863bf1f0e7SAlexander Belyaev populatePatternsForOp<math::Log1pOp>(patterns, ctx, "log1pf", "log1p"); 18750b93733SCorentin Ferry populatePatternsForOp<math::PowFOp>(patterns, ctx, "powf", "pow"); 1883bf1f0e7SAlexander Belyaev populatePatternsForOp<math::RoundEvenOp>(patterns, ctx, "roundevenf", 1898a9d4895SAlexander Belyaev "roundeven"); 1903bf1f0e7SAlexander Belyaev populatePatternsForOp<math::RoundOp>(patterns, ctx, "roundf", "round"); 1913bf1f0e7SAlexander Belyaev populatePatternsForOp<math::SinOp>(patterns, ctx, "sinf", "sin"); 192aa165edcSRob Suderman populatePatternsForOp<math::SinhOp>(patterns, ctx, "sinhf", "sinh"); 19350b93733SCorentin Ferry populatePatternsForOp<math::SqrtOp>(patterns, ctx, "sqrtf", "sqrt"); 194*b96f18b2SIvy Zhang populatePatternsForOp<math::RsqrtOp>(patterns, ctx, "rsqrtf", "rsqrt"); 1953bf1f0e7SAlexander Belyaev populatePatternsForOp<math::TanOp>(patterns, ctx, "tanf", "tan"); 1963bf1f0e7SAlexander Belyaev populatePatternsForOp<math::TanhOp>(patterns, ctx, "tanhf", "tanh"); 1973bf1f0e7SAlexander Belyaev populatePatternsForOp<math::TruncOp>(patterns, ctx, "truncf", "trunc"); 19834810e1bSTres Popp } 19934810e1bSTres Popp 20034810e1bSTres Popp namespace { 20134810e1bSTres Popp struct ConvertMathToLibmPass 20267d0d7acSMichele Scuttari : public impl::ConvertMathToLibmBase<ConvertMathToLibmPass> { 20334810e1bSTres Popp void runOnOperation() override; 20434810e1bSTres Popp }; 20534810e1bSTres Popp } // namespace 20634810e1bSTres Popp 20734810e1bSTres Popp void ConvertMathToLibmPass::runOnOperation() { 20834810e1bSTres Popp auto module = getOperation(); 20934810e1bSTres Popp 21034810e1bSTres Popp RewritePatternSet patterns(&getContext()); 2118a9d4895SAlexander Belyaev populateMathToLibmConversionPatterns(patterns); 21234810e1bSTres Popp 21334810e1bSTres Popp ConversionTarget target(getContext()); 214abc362a1SJakub Kuderski target.addLegalDialect<arith::ArithDialect, BuiltinDialect, func::FuncDialect, 215abc362a1SJakub Kuderski vector::VectorDialect>(); 21634810e1bSTres Popp target.addIllegalDialect<math::MathDialect>(); 21734810e1bSTres Popp if (failed(applyPartialConversion(module, target, std::move(patterns)))) 21834810e1bSTres Popp signalPassFailure(); 21934810e1bSTres Popp } 22034810e1bSTres Popp 22134810e1bSTres Popp std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertMathToLibmPass() { 22234810e1bSTres Popp return std::make_unique<ConvertMathToLibmPass>(); 22334810e1bSTres Popp } 224