1 //===-- ComplexToLibm.cpp - conversion from Complex to libm calls ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ComplexToLibm/ComplexToLibm.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Dialect/Complex/IR/Complex.h" 13 #include "mlir/Dialect/Func/IR/FuncOps.h" 14 #include "mlir/IR/PatternMatch.h" 15 16 using namespace mlir; 17 18 namespace { 19 // Functor to resolve the function name corresponding to the given complex 20 // result type. 21 struct ComplexTypeResolver { 22 llvm::Optional<bool> operator()(Type type) const { 23 auto complexType = type.cast<ComplexType>(); 24 auto elementType = complexType.getElementType(); 25 if (!elementType.isa<Float32Type, Float64Type>()) 26 return {}; 27 28 return elementType.getIntOrFloatBitWidth() == 64; 29 } 30 }; 31 32 // Functor to resolve the function name corresponding to the given float result 33 // type. 34 struct FloatTypeResolver { 35 llvm::Optional<bool> operator()(Type type) const { 36 auto elementType = type.cast<FloatType>(); 37 if (!elementType.isa<Float32Type, Float64Type>()) 38 return {}; 39 40 return elementType.getIntOrFloatBitWidth() == 64; 41 } 42 }; 43 44 // Pattern to convert scalar complex operations to calls to libm functions. 45 // Additionally the libm function signatures are declared. 46 // TypeResolver is a functor returning the libm function name according to the 47 // expected type double or float. 48 template <typename Op, typename TypeResolver = ComplexTypeResolver> 49 struct ScalarOpToLibmCall : public OpRewritePattern<Op> { 50 public: 51 using OpRewritePattern<Op>::OpRewritePattern; 52 ScalarOpToLibmCall<Op, TypeResolver>(MLIRContext *context, 53 StringRef floatFunc, 54 StringRef doubleFunc, 55 PatternBenefit benefit) 56 : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc), 57 doubleFunc(doubleFunc){}; 58 59 LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; 60 61 private: 62 std::string floatFunc, doubleFunc; 63 }; 64 } // namespace 65 66 template <typename Op, typename TypeResolver> 67 LogicalResult ScalarOpToLibmCall<Op, TypeResolver>::matchAndRewrite( 68 Op op, PatternRewriter &rewriter) const { 69 auto module = SymbolTable::getNearestSymbolTable(op); 70 auto isDouble = TypeResolver()(op.getType()); 71 if (!isDouble.has_value()) 72 return failure(); 73 74 auto name = isDouble.value() ? doubleFunc : floatFunc; 75 76 auto opFunc = dyn_cast_or_null<SymbolOpInterface>( 77 SymbolTable::lookupSymbolIn(module, name)); 78 // Forward declare function if it hasn't already been 79 if (!opFunc) { 80 OpBuilder::InsertionGuard guard(rewriter); 81 rewriter.setInsertionPointToStart(&module->getRegion(0).front()); 82 auto opFunctionTy = FunctionType::get( 83 rewriter.getContext(), op->getOperandTypes(), op->getResultTypes()); 84 opFunc = rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), name, 85 opFunctionTy); 86 opFunc.setPrivate(); 87 } 88 assert(isa<FunctionOpInterface>(SymbolTable::lookupSymbolIn(module, name))); 89 90 rewriter.replaceOpWithNewOp<func::CallOp>(op, name, op.getType(), 91 op->getOperands()); 92 93 return success(); 94 } 95 96 void mlir::populateComplexToLibmConversionPatterns(RewritePatternSet &patterns, 97 PatternBenefit benefit) { 98 patterns.add<ScalarOpToLibmCall<complex::PowOp>>(patterns.getContext(), 99 "cpowf", "cpow", benefit); 100 patterns.add<ScalarOpToLibmCall<complex::SqrtOp>>(patterns.getContext(), 101 "csqrtf", "csqrt", benefit); 102 patterns.add<ScalarOpToLibmCall<complex::TanhOp>>(patterns.getContext(), 103 "ctanhf", "ctanh", benefit); 104 patterns.add<ScalarOpToLibmCall<complex::CosOp>>(patterns.getContext(), 105 "ccosf", "ccos", benefit); 106 patterns.add<ScalarOpToLibmCall<complex::SinOp>>(patterns.getContext(), 107 "csinf", "csin", benefit); 108 patterns.add<ScalarOpToLibmCall<complex::ConjOp>>(patterns.getContext(), 109 "conjf", "conj", benefit); 110 patterns.add<ScalarOpToLibmCall<complex::LogOp>>(patterns.getContext(), 111 "clogf", "clog", benefit); 112 patterns.add<ScalarOpToLibmCall<complex::AbsOp, FloatTypeResolver>>( 113 patterns.getContext(), "cabsf", "cabs", benefit); 114 patterns.add<ScalarOpToLibmCall<complex::AngleOp, FloatTypeResolver>>( 115 patterns.getContext(), "cargf", "carg", benefit); 116 } 117 118 namespace { 119 struct ConvertComplexToLibmPass 120 : public ConvertComplexToLibmBase<ConvertComplexToLibmPass> { 121 void runOnOperation() override; 122 }; 123 } // namespace 124 125 void ConvertComplexToLibmPass::runOnOperation() { 126 auto module = getOperation(); 127 128 RewritePatternSet patterns(&getContext()); 129 populateComplexToLibmConversionPatterns(patterns, /*benefit=*/1); 130 131 ConversionTarget target(getContext()); 132 target.addLegalDialect<func::FuncDialect>(); 133 target.addIllegalOp<complex::PowOp, complex::SqrtOp, complex::TanhOp, 134 complex::CosOp, complex::SinOp, complex::ConjOp, 135 complex::LogOp, complex::AbsOp, complex::AngleOp>(); 136 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 137 signalPassFailure(); 138 } 139 140 std::unique_ptr<OperationPass<ModuleOp>> 141 mlir::createConvertComplexToLibmPass() { 142 return std::make_unique<ConvertComplexToLibmPass>(); 143 } 144