xref: /llvm-project/flang/lib/Optimizer/Transforms/FunctionAttr.cpp (revision f3cf24fcc46ab1b9612d7dcb55ec5f18ea2dc62f)
1 //===- FunctionAttr.cpp ---------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 //===----------------------------------------------------------------------===//
10 /// \file
11 /// This is a generic pass for adding attributes to functions.
12 //===----------------------------------------------------------------------===//
13 #include "flang/Optimizer/Dialect/FIROpsSupport.h"
14 #include "flang/Optimizer/Support/InternalNames.h"
15 #include "flang/Optimizer/Transforms/Passes.h"
16 #include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
17 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
18 
19 namespace fir {
20 #define GEN_PASS_DEF_FUNCTIONATTR
21 #include "flang/Optimizer/Transforms/Passes.h.inc"
22 } // namespace fir
23 
24 #define DEBUG_TYPE "func-attr"
25 
26 namespace {
27 
28 class FunctionAttrPass : public fir::impl::FunctionAttrBase<FunctionAttrPass> {
29 public:
30   FunctionAttrPass(const fir::FunctionAttrOptions &options) {
31     framePointerKind = options.framePointerKind;
32     noInfsFPMath = options.noInfsFPMath;
33     noNaNsFPMath = options.noNaNsFPMath;
34     approxFuncFPMath = options.approxFuncFPMath;
35     noSignedZerosFPMath = options.noSignedZerosFPMath;
36     unsafeFPMath = options.unsafeFPMath;
37   }
38   FunctionAttrPass() {}
39   void runOnOperation() override;
40 };
41 
42 } // namespace
43 
44 void FunctionAttrPass::runOnOperation() {
45   LLVM_DEBUG(llvm::dbgs() << "=== Begin " DEBUG_TYPE " ===\n");
46   mlir::func::FuncOp func = getOperation();
47 
48   LLVM_DEBUG(llvm::dbgs() << "Func-name:" << func.getSymName() << "\n");
49 
50   llvm::StringRef name = func.getSymName();
51   auto deconstructed = fir::NameUniquer::deconstruct(name);
52   bool isFromModule = !deconstructed.second.modules.empty();
53 
54   if ((isFromModule || !func.isDeclaration()) &&
55       !fir::hasBindcAttr(func.getOperation())) {
56     llvm::StringRef nocapture = mlir::LLVM::LLVMDialect::getNoCaptureAttrName();
57     mlir::UnitAttr unitAttr = mlir::UnitAttr::get(func.getContext());
58 
59     for (auto [index, argType] : llvm::enumerate(func.getArgumentTypes())) {
60       if (mlir::isa<fir::ReferenceType>(argType) &&
61           !func.getArgAttr(index, fir::getTargetAttrName()) &&
62           !func.getArgAttr(index, fir::getAsynchronousAttrName()) &&
63           !func.getArgAttr(index, fir::getVolatileAttrName()))
64         func.setArgAttr(index, nocapture, unitAttr);
65     }
66   }
67 
68   mlir::MLIRContext *context = &getContext();
69   if (framePointerKind != mlir::LLVM::framePointerKind::FramePointerKind::None)
70     func->setAttr("frame_pointer", mlir::LLVM::FramePointerKindAttr::get(
71                                        context, framePointerKind));
72 
73   auto llvmFuncOpName =
74       mlir::OperationName(mlir::LLVM::LLVMFuncOp::getOperationName(), context);
75   if (noInfsFPMath)
76     func->setAttr(
77         mlir::LLVM::LLVMFuncOp::getNoInfsFpMathAttrName(llvmFuncOpName),
78         mlir::BoolAttr::get(context, true));
79   if (noNaNsFPMath)
80     func->setAttr(
81         mlir::LLVM::LLVMFuncOp::getNoNansFpMathAttrName(llvmFuncOpName),
82         mlir::BoolAttr::get(context, true));
83   if (approxFuncFPMath)
84     func->setAttr(
85         mlir::LLVM::LLVMFuncOp::getApproxFuncFpMathAttrName(llvmFuncOpName),
86         mlir::BoolAttr::get(context, true));
87   if (noSignedZerosFPMath)
88     func->setAttr(
89         mlir::LLVM::LLVMFuncOp::getNoSignedZerosFpMathAttrName(llvmFuncOpName),
90         mlir::BoolAttr::get(context, true));
91   if (unsafeFPMath)
92     func->setAttr(
93         mlir::LLVM::LLVMFuncOp::getUnsafeFpMathAttrName(llvmFuncOpName),
94         mlir::BoolAttr::get(context, true));
95 
96   LLVM_DEBUG(llvm::dbgs() << "=== End " DEBUG_TYPE " ===\n");
97 }
98