xref: /llvm-project/mlir/lib/Conversion/MathToEmitC/MathToEmitC.cpp (revision c0055ec434cbb132d7776f8b4c39e99b69fa97ea)
1 //===- MathToEmitC.cpp - Math to EmitC Patterns -----------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/MathToEmitC/MathToEmitC.h"
10 
11 #include "mlir/Dialect/EmitC/IR/EmitC.h"
12 #include "mlir/Dialect/Math/IR/Math.h"
13 #include "mlir/Transforms/DialectConversion.h"
14 
15 using namespace mlir;
16 
17 namespace {
18 template <typename OpType>
19 class LowerToEmitCCallOpaque : public OpRewritePattern<OpType> {
20   std::string calleeStr;
21   emitc::LanguageTarget languageTarget;
22 
23 public:
24   LowerToEmitCCallOpaque(MLIRContext *context, std::string calleeStr,
25                          emitc::LanguageTarget languageTarget)
26       : OpRewritePattern<OpType>(context), calleeStr(std::move(calleeStr)),
27         languageTarget(languageTarget) {}
28 
29   LogicalResult matchAndRewrite(OpType op,
30                                 PatternRewriter &rewriter) const override;
31 };
32 
33 template <typename OpType>
34 LogicalResult LowerToEmitCCallOpaque<OpType>::matchAndRewrite(
35     OpType op, PatternRewriter &rewriter) const {
36   if (!llvm::all_of(op->getOperandTypes(),
37                     llvm::IsaPred<Float32Type, Float64Type>) ||
38       !llvm::all_of(op->getResultTypes(),
39                     llvm::IsaPred<Float32Type, Float64Type>))
40     return rewriter.notifyMatchFailure(
41         op.getLoc(),
42         "expected all operands and results to be of type f32 or f64");
43   std::string modifiedCalleeStr = calleeStr;
44   if (languageTarget == emitc::LanguageTarget::cpp11) {
45     modifiedCalleeStr = "std::" + calleeStr;
46   } else if (languageTarget == emitc::LanguageTarget::c99) {
47     auto operandType = op->getOperandTypes()[0];
48     if (operandType.isF32())
49       modifiedCalleeStr = calleeStr + "f";
50   }
51   rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(
52       op, op.getType(), modifiedCalleeStr, op->getOperands());
53   return success();
54 }
55 
56 } // namespace
57 
58 // Populates patterns to replace `math` operations with `emitc.call_opaque`,
59 // using function names consistent with those in <math.h>.
60 void mlir::populateConvertMathToEmitCPatterns(
61     RewritePatternSet &patterns, emitc::LanguageTarget languageTarget) {
62   auto *context = patterns.getContext();
63   patterns.insert<LowerToEmitCCallOpaque<math::FloorOp>>(context, "floor",
64                                                          languageTarget);
65   patterns.insert<LowerToEmitCCallOpaque<math::RoundOp>>(context, "round",
66                                                          languageTarget);
67   patterns.insert<LowerToEmitCCallOpaque<math::ExpOp>>(context, "exp",
68                                                        languageTarget);
69   patterns.insert<LowerToEmitCCallOpaque<math::CosOp>>(context, "cos",
70                                                        languageTarget);
71   patterns.insert<LowerToEmitCCallOpaque<math::SinOp>>(context, "sin",
72                                                        languageTarget);
73   patterns.insert<LowerToEmitCCallOpaque<math::AcosOp>>(context, "acos",
74                                                         languageTarget);
75   patterns.insert<LowerToEmitCCallOpaque<math::AsinOp>>(context, "asin",
76                                                         languageTarget);
77   patterns.insert<LowerToEmitCCallOpaque<math::Atan2Op>>(context, "atan2",
78                                                          languageTarget);
79   patterns.insert<LowerToEmitCCallOpaque<math::CeilOp>>(context, "ceil",
80                                                         languageTarget);
81   patterns.insert<LowerToEmitCCallOpaque<math::AbsFOp>>(context, "fabs",
82                                                         languageTarget);
83   patterns.insert<LowerToEmitCCallOpaque<math::PowFOp>>(context, "pow",
84                                                         languageTarget);
85 }
86