1 //===-- ComplexToLibm.cpp - conversion from Complex to libm calls ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ComplexToLibm/ComplexToLibm.h" 10 11 #include "mlir/Dialect/Complex/IR/Complex.h" 12 #include "mlir/Dialect/Func/IR/FuncOps.h" 13 #include "mlir/IR/PatternMatch.h" 14 #include "mlir/Pass/Pass.h" 15 #include <optional> 16 17 namespace mlir { 18 #define GEN_PASS_DEF_CONVERTCOMPLEXTOLIBM 19 #include "mlir/Conversion/Passes.h.inc" 20 } // namespace mlir 21 22 using namespace mlir; 23 24 namespace { 25 // Functor to resolve the function name corresponding to the given complex 26 // result type. 27 struct ComplexTypeResolver { 28 std::optional<bool> operator()(Type type) const { 29 auto complexType = type.cast<ComplexType>(); 30 auto elementType = complexType.getElementType(); 31 if (!elementType.isa<Float32Type, Float64Type>()) 32 return {}; 33 34 return elementType.getIntOrFloatBitWidth() == 64; 35 } 36 }; 37 38 // Functor to resolve the function name corresponding to the given float result 39 // type. 40 struct FloatTypeResolver { 41 std::optional<bool> operator()(Type type) const { 42 auto elementType = type.cast<FloatType>(); 43 if (!elementType.isa<Float32Type, Float64Type>()) 44 return {}; 45 46 return elementType.getIntOrFloatBitWidth() == 64; 47 } 48 }; 49 50 // Pattern to convert scalar complex operations to calls to libm functions. 51 // Additionally the libm function signatures are declared. 52 // TypeResolver is a functor returning the libm function name according to the 53 // expected type double or float. 54 template <typename Op, typename TypeResolver = ComplexTypeResolver> 55 struct ScalarOpToLibmCall : public OpRewritePattern<Op> { 56 public: 57 using OpRewritePattern<Op>::OpRewritePattern; 58 ScalarOpToLibmCall<Op, TypeResolver>(MLIRContext *context, 59 StringRef floatFunc, 60 StringRef doubleFunc, 61 PatternBenefit benefit) 62 : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc), 63 doubleFunc(doubleFunc){}; 64 65 LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; 66 67 private: 68 std::string floatFunc, doubleFunc; 69 }; 70 } // namespace 71 72 template <typename Op, typename TypeResolver> 73 LogicalResult ScalarOpToLibmCall<Op, TypeResolver>::matchAndRewrite( 74 Op op, PatternRewriter &rewriter) const { 75 auto module = SymbolTable::getNearestSymbolTable(op); 76 auto isDouble = TypeResolver()(op.getType()); 77 if (!isDouble.has_value()) 78 return failure(); 79 80 auto name = isDouble.value() ? doubleFunc : floatFunc; 81 82 auto opFunc = dyn_cast_or_null<SymbolOpInterface>( 83 SymbolTable::lookupSymbolIn(module, name)); 84 // Forward declare function if it hasn't already been 85 if (!opFunc) { 86 OpBuilder::InsertionGuard guard(rewriter); 87 rewriter.setInsertionPointToStart(&module->getRegion(0).front()); 88 auto opFunctionTy = FunctionType::get( 89 rewriter.getContext(), op->getOperandTypes(), op->getResultTypes()); 90 opFunc = rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), name, 91 opFunctionTy); 92 opFunc.setPrivate(); 93 } 94 assert(isa<FunctionOpInterface>(SymbolTable::lookupSymbolIn(module, name))); 95 96 rewriter.replaceOpWithNewOp<func::CallOp>(op, name, op.getType(), 97 op->getOperands()); 98 99 return success(); 100 } 101 102 void mlir::populateComplexToLibmConversionPatterns(RewritePatternSet &patterns, 103 PatternBenefit benefit) { 104 patterns.add<ScalarOpToLibmCall<complex::PowOp>>(patterns.getContext(), 105 "cpowf", "cpow", benefit); 106 patterns.add<ScalarOpToLibmCall<complex::SqrtOp>>(patterns.getContext(), 107 "csqrtf", "csqrt", benefit); 108 patterns.add<ScalarOpToLibmCall<complex::TanhOp>>(patterns.getContext(), 109 "ctanhf", "ctanh", benefit); 110 patterns.add<ScalarOpToLibmCall<complex::CosOp>>(patterns.getContext(), 111 "ccosf", "ccos", benefit); 112 patterns.add<ScalarOpToLibmCall<complex::SinOp>>(patterns.getContext(), 113 "csinf", "csin", benefit); 114 patterns.add<ScalarOpToLibmCall<complex::ConjOp>>(patterns.getContext(), 115 "conjf", "conj", benefit); 116 patterns.add<ScalarOpToLibmCall<complex::LogOp>>(patterns.getContext(), 117 "clogf", "clog", benefit); 118 patterns.add<ScalarOpToLibmCall<complex::AbsOp, FloatTypeResolver>>( 119 patterns.getContext(), "cabsf", "cabs", benefit); 120 patterns.add<ScalarOpToLibmCall<complex::AngleOp, FloatTypeResolver>>( 121 patterns.getContext(), "cargf", "carg", benefit); 122 } 123 124 namespace { 125 struct ConvertComplexToLibmPass 126 : public impl::ConvertComplexToLibmBase<ConvertComplexToLibmPass> { 127 void runOnOperation() override; 128 }; 129 } // namespace 130 131 void ConvertComplexToLibmPass::runOnOperation() { 132 auto module = getOperation(); 133 134 RewritePatternSet patterns(&getContext()); 135 populateComplexToLibmConversionPatterns(patterns, /*benefit=*/1); 136 137 ConversionTarget target(getContext()); 138 target.addLegalDialect<func::FuncDialect>(); 139 target.addIllegalOp<complex::PowOp, complex::SqrtOp, complex::TanhOp, 140 complex::CosOp, complex::SinOp, complex::ConjOp, 141 complex::LogOp, complex::AbsOp, complex::AngleOp>(); 142 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 143 signalPassFailure(); 144 } 145 146 std::unique_ptr<OperationPass<ModuleOp>> 147 mlir::createConvertComplexToLibmPass() { 148 return std::make_unique<ConvertComplexToLibmPass>(); 149 } 150