xref: /llvm-project/flang/lib/Semantics/rewrite-directives.cpp (revision 8470cdd499904093ba4faeff870fee12a3e80ff3)
1 //===-- lib/Semantics/rewrite-directives.cpp ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "rewrite-directives.h"
10 #include "flang/Parser/parse-tree-visitor.h"
11 #include "flang/Parser/parse-tree.h"
12 #include "flang/Semantics/semantics.h"
13 #include "flang/Semantics/symbol.h"
14 #include "llvm/Frontend/OpenMP/OMP.h"
15 #include <list>
16 
17 namespace Fortran::semantics {
18 
19 using namespace parser::literals;
20 
21 class DirectiveRewriteMutator {
22 public:
23   explicit DirectiveRewriteMutator(SemanticsContext &context)
24       : context_{context} {}
25 
26   // Default action for a parse tree node is to visit children.
27   template <typename T> bool Pre(T &) { return true; }
28   template <typename T> void Post(T &) {}
29 
30 protected:
31   SemanticsContext &context_;
32 };
33 
34 // Rewrite atomic constructs to add an explicit memory ordering to all that do
35 // not specify it, honoring in this way the `atomic_default_mem_order` clause of
36 // the REQUIRES directive.
37 class OmpRewriteMutator : public DirectiveRewriteMutator {
38 public:
39   explicit OmpRewriteMutator(SemanticsContext &context)
40       : DirectiveRewriteMutator(context) {}
41 
42   template <typename T> bool Pre(T &) { return true; }
43   template <typename T> void Post(T &) {}
44 
45   bool Pre(parser::OpenMPAtomicConstruct &);
46   bool Pre(parser::OpenMPRequiresConstruct &);
47 
48 private:
49   bool atomicDirectiveDefaultOrderFound_{false};
50 };
51 
52 bool OmpRewriteMutator::Pre(parser::OpenMPAtomicConstruct &x) {
53   // Find top-level parent of the operation.
54   Symbol *topLevelParent{common::visit(
55       [&](auto &atomic) {
56         Symbol *symbol{nullptr};
57         Scope *scope{
58             &context_.FindScope(std::get<parser::Verbatim>(atomic.t).source)};
59         do {
60           if (Symbol * parent{scope->symbol()}) {
61             symbol = parent;
62           }
63           scope = &scope->parent();
64         } while (!scope->IsGlobal());
65 
66         assert(symbol &&
67             "Atomic construct must be within a scope associated with a symbol");
68         return symbol;
69       },
70       x.u)};
71 
72   // Get the `atomic_default_mem_order` clause from the top-level parent.
73   std::optional<common::OmpAtomicDefaultMemOrderType> defaultMemOrder;
74   common::visit(
75       [&](auto &details) {
76         if constexpr (std::is_convertible_v<decltype(&details),
77                           WithOmpDeclarative *>) {
78           if (details.has_ompAtomicDefaultMemOrder()) {
79             defaultMemOrder = *details.ompAtomicDefaultMemOrder();
80           }
81         }
82       },
83       topLevelParent->details());
84 
85   if (!defaultMemOrder) {
86     return false;
87   }
88 
89   auto findMemOrderClause =
90       [](const std::list<parser::OmpAtomicClause> &clauses) {
91         return llvm::any_of(clauses, [](const auto &clause) {
92           return std::get_if<parser::OmpMemoryOrderClause>(&clause.u);
93         });
94       };
95 
96   // Get the clause list to which the new memory order clause must be added,
97   // only if there are no other memory order clauses present for this atomic
98   // directive.
99   std::list<parser::OmpAtomicClause> *clauseList = common::visit(
100       common::visitors{[&](parser::OmpAtomic &atomicConstruct) {
101                          // OmpAtomic only has a single list of clauses.
102                          auto &clauses{std::get<parser::OmpAtomicClauseList>(
103                              atomicConstruct.t)};
104                          return !findMemOrderClause(clauses.v) ? &clauses.v
105                                                                : nullptr;
106                        },
107           [&](auto &atomicConstruct) {
108             // All other atomic constructs have two lists of clauses.
109             auto &clausesLhs{std::get<0>(atomicConstruct.t)};
110             auto &clausesRhs{std::get<2>(atomicConstruct.t)};
111             return !findMemOrderClause(clausesLhs.v) &&
112                     !findMemOrderClause(clausesRhs.v)
113                 ? &clausesRhs.v
114                 : nullptr;
115           }},
116       x.u);
117 
118   // Add a memory order clause to the atomic directive.
119   if (clauseList) {
120     atomicDirectiveDefaultOrderFound_ = true;
121     switch (*defaultMemOrder) {
122     case common::OmpAtomicDefaultMemOrderType::AcqRel:
123       clauseList->emplace_back<parser::OmpMemoryOrderClause>(common::visit(
124           common::visitors{[](parser::OmpAtomicRead &) -> parser::OmpClause {
125                              return parser::OmpClause::Acquire{};
126                            },
127               [](parser::OmpAtomicCapture &) -> parser::OmpClause {
128                 return parser::OmpClause::AcqRel{};
129               },
130               [](auto &) -> parser::OmpClause {
131                 // parser::{OmpAtomic, OmpAtomicUpdate, OmpAtomicWrite}
132                 return parser::OmpClause::Release{};
133               }},
134           x.u));
135       break;
136     case common::OmpAtomicDefaultMemOrderType::Relaxed:
137       clauseList->emplace_back<parser::OmpMemoryOrderClause>(
138           parser::OmpClause{parser::OmpClause::Relaxed{}});
139       break;
140     case common::OmpAtomicDefaultMemOrderType::SeqCst:
141       clauseList->emplace_back<parser::OmpMemoryOrderClause>(
142           parser::OmpClause{parser::OmpClause::SeqCst{}});
143       break;
144     }
145   }
146 
147   return false;
148 }
149 
150 bool OmpRewriteMutator::Pre(parser::OpenMPRequiresConstruct &x) {
151   for (parser::OmpClause &clause : std::get<parser::OmpClauseList>(x.t).v) {
152     if (std::holds_alternative<parser::OmpClause::AtomicDefaultMemOrder>(
153             clause.u) &&
154         atomicDirectiveDefaultOrderFound_) {
155       context_.Say(clause.source,
156           "REQUIRES directive with '%s' clause found lexically after atomic "
157           "operation without a memory order clause"_err_en_US,
158           parser::ToUpperCaseLetters(llvm::omp::getOpenMPClauseName(
159               llvm::omp::OMPC_atomic_default_mem_order)
160                                          .str()));
161     }
162   }
163   return false;
164 }
165 
166 bool RewriteOmpParts(SemanticsContext &context, parser::Program &program) {
167   if (!context.IsEnabled(common::LanguageFeature::OpenMP)) {
168     return true;
169   }
170   OmpRewriteMutator ompMutator{context};
171   parser::Walk(program, ompMutator);
172   return !context.AnyFatalError();
173 }
174 
175 } // namespace Fortran::semantics
176