1 //===- MathToEmitC.cpp - Math to EmitC Patterns -----------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/MathToEmitC/MathToEmitC.h" 10 11 #include "mlir/Dialect/EmitC/IR/EmitC.h" 12 #include "mlir/Dialect/Math/IR/Math.h" 13 #include "mlir/Transforms/DialectConversion.h" 14 15 using namespace mlir; 16 17 namespace { 18 template <typename OpType> 19 class LowerToEmitCCallOpaque : public OpRewritePattern<OpType> { 20 std::string calleeStr; 21 emitc::LanguageTarget languageTarget; 22 23 public: 24 LowerToEmitCCallOpaque(MLIRContext *context, std::string calleeStr, 25 emitc::LanguageTarget languageTarget) 26 : OpRewritePattern<OpType>(context), calleeStr(std::move(calleeStr)), 27 languageTarget(languageTarget) {} 28 29 LogicalResult matchAndRewrite(OpType op, 30 PatternRewriter &rewriter) const override; 31 }; 32 33 template <typename OpType> 34 LogicalResult LowerToEmitCCallOpaque<OpType>::matchAndRewrite( 35 OpType op, PatternRewriter &rewriter) const { 36 if (!llvm::all_of(op->getOperandTypes(), 37 llvm::IsaPred<Float32Type, Float64Type>) || 38 !llvm::all_of(op->getResultTypes(), 39 llvm::IsaPred<Float32Type, Float64Type>)) 40 return rewriter.notifyMatchFailure( 41 op.getLoc(), 42 "expected all operands and results to be of type f32 or f64"); 43 std::string modifiedCalleeStr = calleeStr; 44 if (languageTarget == emitc::LanguageTarget::cpp11) { 45 modifiedCalleeStr = "std::" + calleeStr; 46 } else if (languageTarget == emitc::LanguageTarget::c99) { 47 auto operandType = op->getOperandTypes()[0]; 48 if (operandType.isF32()) 49 modifiedCalleeStr = calleeStr + "f"; 50 } 51 rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>( 52 op, op.getType(), modifiedCalleeStr, op->getOperands()); 53 return success(); 54 } 55 56 } // namespace 57 58 // Populates patterns to replace `math` operations with `emitc.call_opaque`, 59 // using function names consistent with those in <math.h>. 60 void mlir::populateConvertMathToEmitCPatterns( 61 RewritePatternSet &patterns, emitc::LanguageTarget languageTarget) { 62 auto *context = patterns.getContext(); 63 patterns.insert<LowerToEmitCCallOpaque<math::FloorOp>>(context, "floor", 64 languageTarget); 65 patterns.insert<LowerToEmitCCallOpaque<math::RoundOp>>(context, "round", 66 languageTarget); 67 patterns.insert<LowerToEmitCCallOpaque<math::ExpOp>>(context, "exp", 68 languageTarget); 69 patterns.insert<LowerToEmitCCallOpaque<math::CosOp>>(context, "cos", 70 languageTarget); 71 patterns.insert<LowerToEmitCCallOpaque<math::SinOp>>(context, "sin", 72 languageTarget); 73 patterns.insert<LowerToEmitCCallOpaque<math::AcosOp>>(context, "acos", 74 languageTarget); 75 patterns.insert<LowerToEmitCCallOpaque<math::AsinOp>>(context, "asin", 76 languageTarget); 77 patterns.insert<LowerToEmitCCallOpaque<math::Atan2Op>>(context, "atan2", 78 languageTarget); 79 patterns.insert<LowerToEmitCCallOpaque<math::CeilOp>>(context, "ceil", 80 languageTarget); 81 patterns.insert<LowerToEmitCCallOpaque<math::AbsFOp>>(context, "fabs", 82 languageTarget); 83 patterns.insert<LowerToEmitCCallOpaque<math::PowFOp>>(context, "pow", 84 languageTarget); 85 } 86