xref: /llvm-project/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp (revision b96f18b20c31449ef9a6878b5c2725a7cf65c552)
134810e1bSTres Popp //===-- MathToLibm.cpp - conversion from Math to libm calls ---------------===//
234810e1bSTres Popp //
334810e1bSTres Popp // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
434810e1bSTres Popp // See https://llvm.org/LICENSE.txt for license information.
534810e1bSTres Popp // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
634810e1bSTres Popp //
734810e1bSTres Popp //===----------------------------------------------------------------------===//
834810e1bSTres Popp 
934810e1bSTres Popp #include "mlir/Conversion/MathToLibm/MathToLibm.h"
1034810e1bSTres Popp 
11abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
1223aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
139c442c7dSSlava Zakharin #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1434810e1bSTres Popp #include "mlir/Dialect/Math/IR/Math.h"
1599ef9eebSMatthias Springer #include "mlir/Dialect/Utils/IndexingUtils.h"
1699ef9eebSMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h"
1734810e1bSTres Popp #include "mlir/IR/BuiltinDialect.h"
1834810e1bSTres Popp #include "mlir/IR/PatternMatch.h"
1967d0d7acSMichele Scuttari #include "mlir/Pass/Pass.h"
208a9d4895SAlexander Belyaev #include "mlir/Transforms/DialectConversion.h"
2167d0d7acSMichele Scuttari 
2267d0d7acSMichele Scuttari namespace mlir {
2367d0d7acSMichele Scuttari #define GEN_PASS_DEF_CONVERTMATHTOLIBM
2467d0d7acSMichele Scuttari #include "mlir/Conversion/Passes.h.inc"
2567d0d7acSMichele Scuttari } // namespace mlir
2634810e1bSTres Popp 
2734810e1bSTres Popp using namespace mlir;
2834810e1bSTres Popp 
2934810e1bSTres Popp namespace {
3034810e1bSTres Popp // Pattern to convert vector operations to scalar operations. This is needed as
3134810e1bSTres Popp // libm calls require scalars.
3234810e1bSTres Popp template <typename Op>
3334810e1bSTres Popp struct VecOpToScalarOp : public OpRewritePattern<Op> {
3434810e1bSTres Popp public:
3534810e1bSTres Popp   using OpRewritePattern<Op>::OpRewritePattern;
3634810e1bSTres Popp 
3734810e1bSTres Popp   LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
3834810e1bSTres Popp };
39a48adc56SBenjamin Kramer // Pattern to promote an op of a smaller floating point type to F32.
40a48adc56SBenjamin Kramer template <typename Op>
41a48adc56SBenjamin Kramer struct PromoteOpToF32 : public OpRewritePattern<Op> {
42a48adc56SBenjamin Kramer public:
43a48adc56SBenjamin Kramer   using OpRewritePattern<Op>::OpRewritePattern;
44a48adc56SBenjamin Kramer 
45a48adc56SBenjamin Kramer   LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
46a48adc56SBenjamin Kramer };
4734810e1bSTres Popp // Pattern to convert scalar math operations to calls to libm functions.
4834810e1bSTres Popp // Additionally the libm function signatures are declared.
4934810e1bSTres Popp template <typename Op>
5034810e1bSTres Popp struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
5134810e1bSTres Popp public:
5234810e1bSTres Popp   using OpRewritePattern<Op>::OpRewritePattern;
53a4ee55feSAlexander Batashev   ScalarOpToLibmCall(MLIRContext *context, StringRef floatFunc,
548a9d4895SAlexander Belyaev                      StringRef doubleFunc)
558a9d4895SAlexander Belyaev       : OpRewritePattern<Op>(context), floatFunc(floatFunc),
5634810e1bSTres Popp         doubleFunc(doubleFunc){};
5734810e1bSTres Popp 
5834810e1bSTres Popp   LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
5934810e1bSTres Popp 
6034810e1bSTres Popp private:
6134810e1bSTres Popp   std::string floatFunc, doubleFunc;
6234810e1bSTres Popp };
633bf1f0e7SAlexander Belyaev 
643bf1f0e7SAlexander Belyaev template <typename OpTy>
653bf1f0e7SAlexander Belyaev void populatePatternsForOp(RewritePatternSet &patterns, MLIRContext *ctx,
663bf1f0e7SAlexander Belyaev                            StringRef floatFunc, StringRef doubleFunc) {
673bf1f0e7SAlexander Belyaev   patterns.add<VecOpToScalarOp<OpTy>, PromoteOpToF32<OpTy>>(ctx);
683bf1f0e7SAlexander Belyaev   patterns.add<ScalarOpToLibmCall<OpTy>>(ctx, floatFunc, doubleFunc);
693bf1f0e7SAlexander Belyaev }
703bf1f0e7SAlexander Belyaev 
7134810e1bSTres Popp } // namespace
7234810e1bSTres Popp 
7334810e1bSTres Popp template <typename Op>
7434810e1bSTres Popp LogicalResult
7534810e1bSTres Popp VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
7634810e1bSTres Popp   auto opType = op.getType();
7734810e1bSTres Popp   auto loc = op.getLoc();
785550c821STres Popp   auto vecType = dyn_cast<VectorType>(opType);
7934810e1bSTres Popp 
8034810e1bSTres Popp   if (!vecType)
8134810e1bSTres Popp     return failure();
8234810e1bSTres Popp   if (!vecType.hasRank())
8334810e1bSTres Popp     return failure();
8434810e1bSTres Popp   auto shape = vecType.getShape();
85921d91f3SAdrian Kuegel   int64_t numElements = vecType.getNumElements();
8634810e1bSTres Popp 
87a54f4eaeSMogball   Value result = rewriter.create<arith::ConstantOp>(
8834810e1bSTres Popp       loc, DenseElementsAttr::get(
8934810e1bSTres Popp                vecType, FloatAttr::get(vecType.getElementType(), 0.0)));
907a69a9d7SNicolas Vasilache   SmallVector<int64_t> strides = computeStrides(shape);
91921d91f3SAdrian Kuegel   for (auto linearIndex = 0; linearIndex < numElements; ++linearIndex) {
92203fad47SNicolas Vasilache     SmallVector<int64_t> positions = delinearize(linearIndex, strides);
9334810e1bSTres Popp     SmallVector<Value> operands;
9434810e1bSTres Popp     for (auto input : op->getOperands())
9534810e1bSTres Popp       operands.push_back(
96921d91f3SAdrian Kuegel           rewriter.create<vector::ExtractOp>(loc, input, positions));
9734810e1bSTres Popp     Value scalarOp =
9834810e1bSTres Popp         rewriter.create<Op>(loc, vecType.getElementType(), operands);
99921d91f3SAdrian Kuegel     result =
100921d91f3SAdrian Kuegel         rewriter.create<vector::InsertOp>(loc, scalarOp, result, positions);
10134810e1bSTres Popp   }
10234810e1bSTres Popp   rewriter.replaceOp(op, {result});
10334810e1bSTres Popp   return success();
10434810e1bSTres Popp }
10534810e1bSTres Popp 
10634810e1bSTres Popp template <typename Op>
10734810e1bSTres Popp LogicalResult
108a48adc56SBenjamin Kramer PromoteOpToF32<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
109a48adc56SBenjamin Kramer   auto opType = op.getType();
1105550c821STres Popp   if (!isa<Float16Type, BFloat16Type>(opType))
111a48adc56SBenjamin Kramer     return failure();
112a48adc56SBenjamin Kramer 
113a48adc56SBenjamin Kramer   auto loc = op.getLoc();
114a48adc56SBenjamin Kramer   auto f32 = rewriter.getF32Type();
115a48adc56SBenjamin Kramer   auto extendedOperands = llvm::to_vector(
116a48adc56SBenjamin Kramer       llvm::map_range(op->getOperands(), [&](Value operand) -> Value {
117a48adc56SBenjamin Kramer         return rewriter.create<arith::ExtFOp>(loc, f32, operand);
118a48adc56SBenjamin Kramer       }));
119a48adc56SBenjamin Kramer   auto newOp = rewriter.create<Op>(loc, f32, extendedOperands);
120a48adc56SBenjamin Kramer   rewriter.replaceOpWithNewOp<arith::TruncFOp>(op, opType, newOp);
121a48adc56SBenjamin Kramer   return success();
122a48adc56SBenjamin Kramer }
123a48adc56SBenjamin Kramer 
124a48adc56SBenjamin Kramer template <typename Op>
125a48adc56SBenjamin Kramer LogicalResult
12634810e1bSTres Popp ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
12734810e1bSTres Popp                                         PatternRewriter &rewriter) const {
1281ebf7ce9STres Popp   auto module = SymbolTable::getNearestSymbolTable(op);
12934810e1bSTres Popp   auto type = op.getType();
1305550c821STres Popp   if (!isa<Float32Type, Float64Type>(type))
13134810e1bSTres Popp     return failure();
13234810e1bSTres Popp 
13334810e1bSTres Popp   auto name = type.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc;
1341ebf7ce9STres Popp   auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
1351ebf7ce9STres Popp       SymbolTable::lookupSymbolIn(module, name));
13634810e1bSTres Popp   // Forward declare function if it hasn't already been
13734810e1bSTres Popp   if (!opFunc) {
13834810e1bSTres Popp     OpBuilder::InsertionGuard guard(rewriter);
1391ebf7ce9STres Popp     rewriter.setInsertionPointToStart(&module->getRegion(0).front());
14034810e1bSTres Popp     auto opFunctionTy = FunctionType::get(
14134810e1bSTres Popp         rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
14258ceae95SRiver Riddle     opFunc = rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), name,
14358ceae95SRiver Riddle                                            opFunctionTy);
14434810e1bSTres Popp     opFunc.setPrivate();
1459c442c7dSSlava Zakharin 
1469c442c7dSSlava Zakharin     // By definition Math dialect operations imply LLVM's "readnone"
1479c442c7dSSlava Zakharin     // function attribute, so we can set it here to provide more
1489c442c7dSSlava Zakharin     // optimization opportunities (e.g. LICM) for backends targeting LLVM IR.
1499c442c7dSSlava Zakharin     // This will have to be changed, when strict FP behavior is supported
1509c442c7dSSlava Zakharin     // by Math dialect.
1519c442c7dSSlava Zakharin     opFunc->setAttr(LLVM::LLVMDialect::getReadnoneAttrName(),
1529c442c7dSSlava Zakharin                     UnitAttr::get(rewriter.getContext()));
15334810e1bSTres Popp   }
1547ceffae1SRiver Riddle   assert(isa<FunctionOpInterface>(SymbolTable::lookupSymbolIn(module, name)));
15534810e1bSTres Popp 
15623aa5a74SRiver Riddle   rewriter.replaceOpWithNewOp<func::CallOp>(op, name, op.getType(),
1571ebf7ce9STres Popp                                             op->getOperands());
15834810e1bSTres Popp 
15934810e1bSTres Popp   return success();
16034810e1bSTres Popp }
16134810e1bSTres Popp 
1628a9d4895SAlexander Belyaev void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns) {
1638a9d4895SAlexander Belyaev   MLIRContext *ctx = patterns.getContext();
1643bf1f0e7SAlexander Belyaev 
16550b93733SCorentin Ferry   populatePatternsForOp<math::AbsFOp>(patterns, ctx, "fabsf", "fabs");
166f7250179SFrederik Harwath   populatePatternsForOp<math::AcosOp>(patterns, ctx, "acosf", "acos");
167b8dca4faSVivek Khandelwal   populatePatternsForOp<math::AcoshOp>(patterns, ctx, "acoshf", "acosh");
168b8dca4faSVivek Khandelwal   populatePatternsForOp<math::AsinOp>(patterns, ctx, "asinf", "asin");
169b8dca4faSVivek Khandelwal   populatePatternsForOp<math::AsinhOp>(patterns, ctx, "asinhf", "asinh");
1703bf1f0e7SAlexander Belyaev   populatePatternsForOp<math::Atan2Op>(patterns, ctx, "atan2f", "atan2");
1713bf1f0e7SAlexander Belyaev   populatePatternsForOp<math::AtanOp>(patterns, ctx, "atanf", "atan");
172b8dca4faSVivek Khandelwal   populatePatternsForOp<math::AtanhOp>(patterns, ctx, "atanhf", "atanh");
1733bf1f0e7SAlexander Belyaev   populatePatternsForOp<math::CbrtOp>(patterns, ctx, "cbrtf", "cbrt");
1743bf1f0e7SAlexander Belyaev   populatePatternsForOp<math::CeilOp>(patterns, ctx, "ceilf", "ceil");
1753bf1f0e7SAlexander Belyaev   populatePatternsForOp<math::CosOp>(patterns, ctx, "cosf", "cos");
176762964e9SSungsoon Cho   populatePatternsForOp<math::CoshOp>(patterns, ctx, "coshf", "cosh");
1773bf1f0e7SAlexander Belyaev   populatePatternsForOp<math::ErfOp>(patterns, ctx, "erff", "erf");
17850b93733SCorentin Ferry   populatePatternsForOp<math::ExpOp>(patterns, ctx, "expf", "exp");
17950b93733SCorentin Ferry   populatePatternsForOp<math::Exp2Op>(patterns, ctx, "exp2f", "exp2");
1803bf1f0e7SAlexander Belyaev   populatePatternsForOp<math::ExpM1Op>(patterns, ctx, "expm1f", "expm1");
1813bf1f0e7SAlexander Belyaev   populatePatternsForOp<math::FloorOp>(patterns, ctx, "floorf", "floor");
18250b93733SCorentin Ferry   populatePatternsForOp<math::FmaOp>(patterns, ctx, "fmaf", "fma");
18350b93733SCorentin Ferry   populatePatternsForOp<math::LogOp>(patterns, ctx, "logf", "log");
18450b93733SCorentin Ferry   populatePatternsForOp<math::Log2Op>(patterns, ctx, "log2f", "log2");
18550b93733SCorentin Ferry   populatePatternsForOp<math::Log10Op>(patterns, ctx, "log10f", "log10");
1863bf1f0e7SAlexander Belyaev   populatePatternsForOp<math::Log1pOp>(patterns, ctx, "log1pf", "log1p");
18750b93733SCorentin Ferry   populatePatternsForOp<math::PowFOp>(patterns, ctx, "powf", "pow");
1883bf1f0e7SAlexander Belyaev   populatePatternsForOp<math::RoundEvenOp>(patterns, ctx, "roundevenf",
1898a9d4895SAlexander Belyaev                                            "roundeven");
1903bf1f0e7SAlexander Belyaev   populatePatternsForOp<math::RoundOp>(patterns, ctx, "roundf", "round");
1913bf1f0e7SAlexander Belyaev   populatePatternsForOp<math::SinOp>(patterns, ctx, "sinf", "sin");
192aa165edcSRob Suderman   populatePatternsForOp<math::SinhOp>(patterns, ctx, "sinhf", "sinh");
19350b93733SCorentin Ferry   populatePatternsForOp<math::SqrtOp>(patterns, ctx, "sqrtf", "sqrt");
194*b96f18b2SIvy Zhang   populatePatternsForOp<math::RsqrtOp>(patterns, ctx, "rsqrtf", "rsqrt");
1953bf1f0e7SAlexander Belyaev   populatePatternsForOp<math::TanOp>(patterns, ctx, "tanf", "tan");
1963bf1f0e7SAlexander Belyaev   populatePatternsForOp<math::TanhOp>(patterns, ctx, "tanhf", "tanh");
1973bf1f0e7SAlexander Belyaev   populatePatternsForOp<math::TruncOp>(patterns, ctx, "truncf", "trunc");
19834810e1bSTres Popp }
19934810e1bSTres Popp 
20034810e1bSTres Popp namespace {
20134810e1bSTres Popp struct ConvertMathToLibmPass
20267d0d7acSMichele Scuttari     : public impl::ConvertMathToLibmBase<ConvertMathToLibmPass> {
20334810e1bSTres Popp   void runOnOperation() override;
20434810e1bSTres Popp };
20534810e1bSTres Popp } // namespace
20634810e1bSTres Popp 
20734810e1bSTres Popp void ConvertMathToLibmPass::runOnOperation() {
20834810e1bSTres Popp   auto module = getOperation();
20934810e1bSTres Popp 
21034810e1bSTres Popp   RewritePatternSet patterns(&getContext());
2118a9d4895SAlexander Belyaev   populateMathToLibmConversionPatterns(patterns);
21234810e1bSTres Popp 
21334810e1bSTres Popp   ConversionTarget target(getContext());
214abc362a1SJakub Kuderski   target.addLegalDialect<arith::ArithDialect, BuiltinDialect, func::FuncDialect,
215abc362a1SJakub Kuderski                          vector::VectorDialect>();
21634810e1bSTres Popp   target.addIllegalDialect<math::MathDialect>();
21734810e1bSTres Popp   if (failed(applyPartialConversion(module, target, std::move(patterns))))
21834810e1bSTres Popp     signalPassFailure();
21934810e1bSTres Popp }
22034810e1bSTres Popp 
22134810e1bSTres Popp std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertMathToLibmPass() {
22234810e1bSTres Popp   return std::make_unique<ConvertMathToLibmPass>();
22334810e1bSTres Popp }
224