xref: /llvm-project/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp (revision fdcb76f2480d2a0187641cc844e92f1d6c4b2635)
1e4978713SBenjamin Kramer //===-- ComplexToLibm.cpp - conversion from Complex to libm calls ---------===//
2e4978713SBenjamin Kramer //
3e4978713SBenjamin Kramer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4e4978713SBenjamin Kramer // See https://llvm.org/LICENSE.txt for license information.
5e4978713SBenjamin Kramer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6e4978713SBenjamin Kramer //
7e4978713SBenjamin Kramer //===----------------------------------------------------------------------===//
8e4978713SBenjamin Kramer 
9e4978713SBenjamin Kramer #include "mlir/Conversion/ComplexToLibm/ComplexToLibm.h"
10e4978713SBenjamin Kramer 
11e4978713SBenjamin Kramer #include "mlir/Dialect/Complex/IR/Complex.h"
12e4978713SBenjamin Kramer #include "mlir/Dialect/Func/IR/FuncOps.h"
13e4978713SBenjamin Kramer #include "mlir/IR/PatternMatch.h"
1467d0d7acSMichele Scuttari #include "mlir/Pass/Pass.h"
151adfdab3SKazu Hirata #include <optional>
1667d0d7acSMichele Scuttari 
1767d0d7acSMichele Scuttari namespace mlir {
1867d0d7acSMichele Scuttari #define GEN_PASS_DEF_CONVERTCOMPLEXTOLIBM
1967d0d7acSMichele Scuttari #include "mlir/Conversion/Passes.h.inc"
2067d0d7acSMichele Scuttari } // namespace mlir
21e4978713SBenjamin Kramer 
22e4978713SBenjamin Kramer using namespace mlir;
23e4978713SBenjamin Kramer 
24e4978713SBenjamin Kramer namespace {
25eaba6e0bSlewuathe // Functor to resolve the function name corresponding to the given complex
26eaba6e0bSlewuathe // result type.
27eaba6e0bSlewuathe struct ComplexTypeResolver {
operator ()__anonb1d42d3e0111::ComplexTypeResolver281adfdab3SKazu Hirata   std::optional<bool> operator()(Type type) const {
295550c821STres Popp     auto complexType = cast<ComplexType>(type);
30eaba6e0bSlewuathe     auto elementType = complexType.getElementType();
315550c821STres Popp     if (!isa<Float32Type, Float64Type>(elementType))
32eaba6e0bSlewuathe       return {};
33eaba6e0bSlewuathe 
34eaba6e0bSlewuathe     return elementType.getIntOrFloatBitWidth() == 64;
35eaba6e0bSlewuathe   }
36eaba6e0bSlewuathe };
37eaba6e0bSlewuathe 
38eaba6e0bSlewuathe // Functor to resolve the function name corresponding to the given float result
39eaba6e0bSlewuathe // type.
40eaba6e0bSlewuathe struct FloatTypeResolver {
operator ()__anonb1d42d3e0111::FloatTypeResolver411adfdab3SKazu Hirata   std::optional<bool> operator()(Type type) const {
425550c821STres Popp     auto elementType = cast<FloatType>(type);
435550c821STres Popp     if (!isa<Float32Type, Float64Type>(elementType))
44eaba6e0bSlewuathe       return {};
45eaba6e0bSlewuathe 
46eaba6e0bSlewuathe     return elementType.getIntOrFloatBitWidth() == 64;
47eaba6e0bSlewuathe   }
48eaba6e0bSlewuathe };
49eaba6e0bSlewuathe 
50e4978713SBenjamin Kramer // Pattern to convert scalar complex operations to calls to libm functions.
51e4978713SBenjamin Kramer // Additionally the libm function signatures are declared.
52eaba6e0bSlewuathe // TypeResolver is a functor returning the libm function name according to the
53eaba6e0bSlewuathe // expected type double or float.
54eaba6e0bSlewuathe template <typename Op, typename TypeResolver = ComplexTypeResolver>
55e4978713SBenjamin Kramer struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
56e4978713SBenjamin Kramer public:
57e4978713SBenjamin Kramer   using OpRewritePattern<Op>::OpRewritePattern;
ScalarOpToLibmCall__anonb1d42d3e0111::ScalarOpToLibmCall58a4ee55feSAlexander Batashev   ScalarOpToLibmCall(MLIRContext *context, StringRef floatFunc,
59a4ee55feSAlexander Batashev                      StringRef doubleFunc, PatternBenefit benefit)
60e4978713SBenjamin Kramer       : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
61e4978713SBenjamin Kramer         doubleFunc(doubleFunc){};
62e4978713SBenjamin Kramer 
63e4978713SBenjamin Kramer   LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
64e4978713SBenjamin Kramer 
65e4978713SBenjamin Kramer private:
66e4978713SBenjamin Kramer   std::string floatFunc, doubleFunc;
67e4978713SBenjamin Kramer };
68e4978713SBenjamin Kramer } // namespace
69e4978713SBenjamin Kramer 
70eaba6e0bSlewuathe template <typename Op, typename TypeResolver>
matchAndRewrite(Op op,PatternRewriter & rewriter) const71eaba6e0bSlewuathe LogicalResult ScalarOpToLibmCall<Op, TypeResolver>::matchAndRewrite(
72eaba6e0bSlewuathe     Op op, PatternRewriter &rewriter) const {
73e4978713SBenjamin Kramer   auto module = SymbolTable::getNearestSymbolTable(op);
74eaba6e0bSlewuathe   auto isDouble = TypeResolver()(op.getType());
75491d2701SKazu Hirata   if (!isDouble.has_value())
76e4978713SBenjamin Kramer     return failure();
77e4978713SBenjamin Kramer 
784913e5daSFangrui Song   auto name = *isDouble ? doubleFunc : floatFunc;
79eaba6e0bSlewuathe 
80e4978713SBenjamin Kramer   auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
81e4978713SBenjamin Kramer       SymbolTable::lookupSymbolIn(module, name));
82e4978713SBenjamin Kramer   // Forward declare function if it hasn't already been
83e4978713SBenjamin Kramer   if (!opFunc) {
84e4978713SBenjamin Kramer     OpBuilder::InsertionGuard guard(rewriter);
85e4978713SBenjamin Kramer     rewriter.setInsertionPointToStart(&module->getRegion(0).front());
86e4978713SBenjamin Kramer     auto opFunctionTy = FunctionType::get(
87e4978713SBenjamin Kramer         rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
88e4978713SBenjamin Kramer     opFunc = rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), name,
89e4978713SBenjamin Kramer                                            opFunctionTy);
90e4978713SBenjamin Kramer     opFunc.setPrivate();
91e4978713SBenjamin Kramer   }
92e4978713SBenjamin Kramer   assert(isa<FunctionOpInterface>(SymbolTable::lookupSymbolIn(module, name)));
93e4978713SBenjamin Kramer 
94eaba6e0bSlewuathe   rewriter.replaceOpWithNewOp<func::CallOp>(op, name, op.getType(),
95eaba6e0bSlewuathe                                             op->getOperands());
96e4978713SBenjamin Kramer 
97e4978713SBenjamin Kramer   return success();
98e4978713SBenjamin Kramer }
99e4978713SBenjamin Kramer 
populateComplexToLibmConversionPatterns(RewritePatternSet & patterns,PatternBenefit benefit)100e4978713SBenjamin Kramer void mlir::populateComplexToLibmConversionPatterns(RewritePatternSet &patterns,
101e4978713SBenjamin Kramer                                                    PatternBenefit benefit) {
102e4978713SBenjamin Kramer   patterns.add<ScalarOpToLibmCall<complex::PowOp>>(patterns.getContext(),
103e4978713SBenjamin Kramer                                                    "cpowf", "cpow", benefit);
104e4978713SBenjamin Kramer   patterns.add<ScalarOpToLibmCall<complex::SqrtOp>>(patterns.getContext(),
105e4978713SBenjamin Kramer                                                     "csqrtf", "csqrt", benefit);
106e4978713SBenjamin Kramer   patterns.add<ScalarOpToLibmCall<complex::TanhOp>>(patterns.getContext(),
107e4978713SBenjamin Kramer                                                     "ctanhf", "ctanh", benefit);
1089f0869a6Slewuathe   patterns.add<ScalarOpToLibmCall<complex::CosOp>>(patterns.getContext(),
1099f0869a6Slewuathe                                                    "ccosf", "ccos", benefit);
1109f0869a6Slewuathe   patterns.add<ScalarOpToLibmCall<complex::SinOp>>(patterns.getContext(),
1119f0869a6Slewuathe                                                    "csinf", "csin", benefit);
11272ee11a8Slewuathe   patterns.add<ScalarOpToLibmCall<complex::ConjOp>>(patterns.getContext(),
11372ee11a8Slewuathe                                                     "conjf", "conj", benefit);
1147769505aSKai Sasaki   patterns.add<ScalarOpToLibmCall<complex::LogOp>>(patterns.getContext(),
1157769505aSKai Sasaki                                                    "clogf", "clog", benefit);
116eaba6e0bSlewuathe   patterns.add<ScalarOpToLibmCall<complex::AbsOp, FloatTypeResolver>>(
117eaba6e0bSlewuathe       patterns.getContext(), "cabsf", "cabs", benefit);
118f27deeeeSlewuathe   patterns.add<ScalarOpToLibmCall<complex::AngleOp, FloatTypeResolver>>(
119f27deeeeSlewuathe       patterns.getContext(), "cargf", "carg", benefit);
120*fdcb76f2SKai Sasaki   patterns.add<ScalarOpToLibmCall<complex::TanOp>>(patterns.getContext(),
121*fdcb76f2SKai Sasaki                                                    "ctanf", "ctan", benefit);
122e4978713SBenjamin Kramer }
123e4978713SBenjamin Kramer 
124e4978713SBenjamin Kramer namespace {
125e4978713SBenjamin Kramer struct ConvertComplexToLibmPass
12667d0d7acSMichele Scuttari     : public impl::ConvertComplexToLibmBase<ConvertComplexToLibmPass> {
127e4978713SBenjamin Kramer   void runOnOperation() override;
128e4978713SBenjamin Kramer };
129e4978713SBenjamin Kramer } // namespace
130e4978713SBenjamin Kramer 
runOnOperation()131e4978713SBenjamin Kramer void ConvertComplexToLibmPass::runOnOperation() {
132e4978713SBenjamin Kramer   auto module = getOperation();
133e4978713SBenjamin Kramer 
134e4978713SBenjamin Kramer   RewritePatternSet patterns(&getContext());
135e4978713SBenjamin Kramer   populateComplexToLibmConversionPatterns(patterns, /*benefit=*/1);
136e4978713SBenjamin Kramer 
137e4978713SBenjamin Kramer   ConversionTarget target(getContext());
138e4978713SBenjamin Kramer   target.addLegalDialect<func::FuncDialect>();
139eaba6e0bSlewuathe   target.addIllegalOp<complex::PowOp, complex::SqrtOp, complex::TanhOp,
1405605f17aSKai Sasaki                       complex::CosOp, complex::SinOp, complex::ConjOp,
141*fdcb76f2SKai Sasaki                       complex::LogOp, complex::AbsOp, complex::AngleOp,
142*fdcb76f2SKai Sasaki                       complex::TanOp>();
143e4978713SBenjamin Kramer   if (failed(applyPartialConversion(module, target, std::move(patterns))))
144e4978713SBenjamin Kramer     signalPassFailure();
145e4978713SBenjamin Kramer }
146e4978713SBenjamin Kramer 
147e4978713SBenjamin Kramer std::unique_ptr<OperationPass<ModuleOp>>
createConvertComplexToLibmPass()148e4978713SBenjamin Kramer mlir::createConvertComplexToLibmPass() {
149e4978713SBenjamin Kramer   return std::make_unique<ConvertComplexToLibmPass>();
150e4978713SBenjamin Kramer }
151