1 //===-- ComplexToLibm.cpp - conversion from Complex to libm calls ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ComplexToLibm/ComplexToLibm.h" 10 11 #include "mlir/Dialect/Complex/IR/Complex.h" 12 #include "mlir/Dialect/Func/IR/FuncOps.h" 13 #include "mlir/IR/PatternMatch.h" 14 #include "mlir/Pass/Pass.h" 15 16 namespace mlir { 17 #define GEN_PASS_DEF_CONVERTCOMPLEXTOLIBMPASS 18 #include "mlir/Conversion/Passes.h.inc" 19 } // namespace mlir 20 21 using namespace mlir; 22 23 namespace { 24 // Functor to resolve the function name corresponding to the given complex 25 // result type. 26 struct ComplexTypeResolver { 27 llvm::Optional<bool> operator()(Type type) const { 28 auto complexType = type.cast<ComplexType>(); 29 auto elementType = complexType.getElementType(); 30 if (!elementType.isa<Float32Type, Float64Type>()) 31 return {}; 32 33 return elementType.getIntOrFloatBitWidth() == 64; 34 } 35 }; 36 37 // Functor to resolve the function name corresponding to the given float result 38 // type. 39 struct FloatTypeResolver { 40 llvm::Optional<bool> operator()(Type type) const { 41 auto elementType = type.cast<FloatType>(); 42 if (!elementType.isa<Float32Type, Float64Type>()) 43 return {}; 44 45 return elementType.getIntOrFloatBitWidth() == 64; 46 } 47 }; 48 49 // Pattern to convert scalar complex operations to calls to libm functions. 50 // Additionally the libm function signatures are declared. 51 // TypeResolver is a functor returning the libm function name according to the 52 // expected type double or float. 53 template <typename Op, typename TypeResolver = ComplexTypeResolver> 54 struct ScalarOpToLibmCall : public OpRewritePattern<Op> { 55 public: 56 using OpRewritePattern<Op>::OpRewritePattern; 57 ScalarOpToLibmCall<Op, TypeResolver>(MLIRContext *context, 58 StringRef floatFunc, 59 StringRef doubleFunc, 60 PatternBenefit benefit) 61 : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc), 62 doubleFunc(doubleFunc){}; 63 64 LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; 65 66 private: 67 std::string floatFunc, doubleFunc; 68 }; 69 } // namespace 70 71 template <typename Op, typename TypeResolver> 72 LogicalResult ScalarOpToLibmCall<Op, TypeResolver>::matchAndRewrite( 73 Op op, PatternRewriter &rewriter) const { 74 auto module = SymbolTable::getNearestSymbolTable(op); 75 auto isDouble = TypeResolver()(op.getType()); 76 if (!isDouble.has_value()) 77 return failure(); 78 79 auto name = isDouble.value() ? doubleFunc : floatFunc; 80 81 auto opFunc = dyn_cast_or_null<SymbolOpInterface>( 82 SymbolTable::lookupSymbolIn(module, name)); 83 // Forward declare function if it hasn't already been 84 if (!opFunc) { 85 OpBuilder::InsertionGuard guard(rewriter); 86 rewriter.setInsertionPointToStart(&module->getRegion(0).front()); 87 auto opFunctionTy = FunctionType::get( 88 rewriter.getContext(), op->getOperandTypes(), op->getResultTypes()); 89 opFunc = rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), name, 90 opFunctionTy); 91 opFunc.setPrivate(); 92 } 93 assert(isa<FunctionOpInterface>(SymbolTable::lookupSymbolIn(module, name))); 94 95 rewriter.replaceOpWithNewOp<func::CallOp>(op, name, op.getType(), 96 op->getOperands()); 97 98 return success(); 99 } 100 101 void mlir::populateComplexToLibmConversionPatterns(RewritePatternSet &patterns, 102 PatternBenefit benefit) { 103 patterns.add<ScalarOpToLibmCall<complex::PowOp>>(patterns.getContext(), 104 "cpowf", "cpow", benefit); 105 patterns.add<ScalarOpToLibmCall<complex::SqrtOp>>(patterns.getContext(), 106 "csqrtf", "csqrt", benefit); 107 patterns.add<ScalarOpToLibmCall<complex::TanhOp>>(patterns.getContext(), 108 "ctanhf", "ctanh", benefit); 109 patterns.add<ScalarOpToLibmCall<complex::CosOp>>(patterns.getContext(), 110 "ccosf", "ccos", benefit); 111 patterns.add<ScalarOpToLibmCall<complex::SinOp>>(patterns.getContext(), 112 "csinf", "csin", benefit); 113 patterns.add<ScalarOpToLibmCall<complex::ConjOp>>(patterns.getContext(), 114 "conjf", "conj", benefit); 115 patterns.add<ScalarOpToLibmCall<complex::LogOp>>(patterns.getContext(), 116 "clogf", "clog", benefit); 117 patterns.add<ScalarOpToLibmCall<complex::AbsOp, FloatTypeResolver>>( 118 patterns.getContext(), "cabsf", "cabs", benefit); 119 patterns.add<ScalarOpToLibmCall<complex::AngleOp, FloatTypeResolver>>( 120 patterns.getContext(), "cargf", "carg", benefit); 121 } 122 123 namespace { 124 struct ConvertComplexToLibmPass 125 : public impl::ConvertComplexToLibmPassBase<ConvertComplexToLibmPass> { 126 using ConvertComplexToLibmPassBase::ConvertComplexToLibmPassBase; 127 128 void runOnOperation() override; 129 }; 130 } // namespace 131 132 void ConvertComplexToLibmPass::runOnOperation() { 133 auto module = getOperation(); 134 135 RewritePatternSet patterns(&getContext()); 136 populateComplexToLibmConversionPatterns(patterns, /*benefit=*/1); 137 138 ConversionTarget target(getContext()); 139 target.addLegalDialect<func::FuncDialect>(); 140 target.addIllegalOp<complex::PowOp, complex::SqrtOp, complex::TanhOp, 141 complex::CosOp, complex::SinOp, complex::ConjOp, 142 complex::LogOp, complex::AbsOp, complex::AngleOp>(); 143 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 144 signalPassFailure(); 145 } 146 147 std::unique_ptr<OperationPass<ModuleOp>> 148 mlir::createConvertComplexToLibmPass() { 149 return std::make_unique<ConvertComplexToLibmPass>(); 150 } 151