1 //===-- ComplexToLibm.cpp - conversion from Complex to libm calls ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ComplexToLibm/ComplexToLibm.h" 10 11 #include "mlir/Dialect/Complex/IR/Complex.h" 12 #include "mlir/Dialect/Func/IR/FuncOps.h" 13 #include "mlir/IR/PatternMatch.h" 14 #include "mlir/Pass/Pass.h" 15 #include <optional> 16 17 namespace mlir { 18 #define GEN_PASS_DEF_CONVERTCOMPLEXTOLIBM 19 #include "mlir/Conversion/Passes.h.inc" 20 } // namespace mlir 21 22 using namespace mlir; 23 24 namespace { 25 // Functor to resolve the function name corresponding to the given complex 26 // result type. 27 struct ComplexTypeResolver { 28 std::optional<bool> operator()(Type type) const { 29 auto complexType = cast<ComplexType>(type); 30 auto elementType = complexType.getElementType(); 31 if (!isa<Float32Type, Float64Type>(elementType)) 32 return {}; 33 34 return elementType.getIntOrFloatBitWidth() == 64; 35 } 36 }; 37 38 // Functor to resolve the function name corresponding to the given float result 39 // type. 40 struct FloatTypeResolver { 41 std::optional<bool> operator()(Type type) const { 42 auto elementType = cast<FloatType>(type); 43 if (!isa<Float32Type, Float64Type>(elementType)) 44 return {}; 45 46 return elementType.getIntOrFloatBitWidth() == 64; 47 } 48 }; 49 50 // Pattern to convert scalar complex operations to calls to libm functions. 51 // Additionally the libm function signatures are declared. 52 // TypeResolver is a functor returning the libm function name according to the 53 // expected type double or float. 54 template <typename Op, typename TypeResolver = ComplexTypeResolver> 55 struct ScalarOpToLibmCall : public OpRewritePattern<Op> { 56 public: 57 using OpRewritePattern<Op>::OpRewritePattern; 58 ScalarOpToLibmCall(MLIRContext *context, StringRef floatFunc, 59 StringRef doubleFunc, PatternBenefit benefit) 60 : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc), 61 doubleFunc(doubleFunc){}; 62 63 LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; 64 65 private: 66 std::string floatFunc, doubleFunc; 67 }; 68 } // namespace 69 70 template <typename Op, typename TypeResolver> 71 LogicalResult ScalarOpToLibmCall<Op, TypeResolver>::matchAndRewrite( 72 Op op, PatternRewriter &rewriter) const { 73 auto module = SymbolTable::getNearestSymbolTable(op); 74 auto isDouble = TypeResolver()(op.getType()); 75 if (!isDouble.has_value()) 76 return failure(); 77 78 auto name = *isDouble ? doubleFunc : floatFunc; 79 80 auto opFunc = dyn_cast_or_null<SymbolOpInterface>( 81 SymbolTable::lookupSymbolIn(module, name)); 82 // Forward declare function if it hasn't already been 83 if (!opFunc) { 84 OpBuilder::InsertionGuard guard(rewriter); 85 rewriter.setInsertionPointToStart(&module->getRegion(0).front()); 86 auto opFunctionTy = FunctionType::get( 87 rewriter.getContext(), op->getOperandTypes(), op->getResultTypes()); 88 opFunc = rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), name, 89 opFunctionTy); 90 opFunc.setPrivate(); 91 } 92 assert(isa<FunctionOpInterface>(SymbolTable::lookupSymbolIn(module, name))); 93 94 rewriter.replaceOpWithNewOp<func::CallOp>(op, name, op.getType(), 95 op->getOperands()); 96 97 return success(); 98 } 99 100 void mlir::populateComplexToLibmConversionPatterns(RewritePatternSet &patterns, 101 PatternBenefit benefit) { 102 patterns.add<ScalarOpToLibmCall<complex::PowOp>>(patterns.getContext(), 103 "cpowf", "cpow", benefit); 104 patterns.add<ScalarOpToLibmCall<complex::SqrtOp>>(patterns.getContext(), 105 "csqrtf", "csqrt", benefit); 106 patterns.add<ScalarOpToLibmCall<complex::TanhOp>>(patterns.getContext(), 107 "ctanhf", "ctanh", benefit); 108 patterns.add<ScalarOpToLibmCall<complex::CosOp>>(patterns.getContext(), 109 "ccosf", "ccos", benefit); 110 patterns.add<ScalarOpToLibmCall<complex::SinOp>>(patterns.getContext(), 111 "csinf", "csin", benefit); 112 patterns.add<ScalarOpToLibmCall<complex::ConjOp>>(patterns.getContext(), 113 "conjf", "conj", benefit); 114 patterns.add<ScalarOpToLibmCall<complex::LogOp>>(patterns.getContext(), 115 "clogf", "clog", benefit); 116 patterns.add<ScalarOpToLibmCall<complex::AbsOp, FloatTypeResolver>>( 117 patterns.getContext(), "cabsf", "cabs", benefit); 118 patterns.add<ScalarOpToLibmCall<complex::AngleOp, FloatTypeResolver>>( 119 patterns.getContext(), "cargf", "carg", benefit); 120 } 121 122 namespace { 123 struct ConvertComplexToLibmPass 124 : public impl::ConvertComplexToLibmBase<ConvertComplexToLibmPass> { 125 void runOnOperation() override; 126 }; 127 } // namespace 128 129 void ConvertComplexToLibmPass::runOnOperation() { 130 auto module = getOperation(); 131 132 RewritePatternSet patterns(&getContext()); 133 populateComplexToLibmConversionPatterns(patterns, /*benefit=*/1); 134 135 ConversionTarget target(getContext()); 136 target.addLegalDialect<func::FuncDialect>(); 137 target.addIllegalOp<complex::PowOp, complex::SqrtOp, complex::TanhOp, 138 complex::CosOp, complex::SinOp, complex::ConjOp, 139 complex::LogOp, complex::AbsOp, complex::AngleOp>(); 140 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 141 signalPassFailure(); 142 } 143 144 std::unique_ptr<OperationPass<ModuleOp>> 145 mlir::createConvertComplexToLibmPass() { 146 return std::make_unique<ConvertComplexToLibmPass>(); 147 } 148