xref: /llvm-project/flang/lib/Lower/OpenMP/ClauseProcessor.h (revision e532241b021cd48bad303721757c1194bc844775)
1 //===-- Lower/OpenMP/ClauseProcessor.h --------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 #ifndef FORTRAN_LOWER_CLAUSEPROCESSOR_H
13 #define FORTRAN_LOWER_CLAUSEPROCESSOR_H
14 
15 #include "Clauses.h"
16 #include "ReductionProcessor.h"
17 #include "Utils.h"
18 #include "flang/Lower/AbstractConverter.h"
19 #include "flang/Lower/Bridge.h"
20 #include "flang/Lower/DirectivesCommon.h"
21 #include "flang/Optimizer/Builder/Todo.h"
22 #include "flang/Parser/dump-parse-tree.h"
23 #include "flang/Parser/parse-tree.h"
24 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
25 
namespace fir {
// Forward declaration of fir::FirOpBuilder.
// NOTE(review): markClauseOccurrence() further down calls a member function
// on the object returned by converter.getFirOpBuilder(), which needs the
// complete type; presumably that arrives through one of the includes above.
// Confirm whether this forward declaration is still required.
class FirOpBuilder;
} // namespace fir
29 
30 namespace Fortran {
31 namespace lower {
32 namespace omp {
33 
34 /// Class that handles the processing of OpenMP clauses.
35 ///
36 /// Its `process<ClauseName>()` methods perform MLIR code generation for their
37 /// corresponding clause if it is present in the clause list. Otherwise, they
38 /// will return `false` to signal that the clause was not found.
39 ///
40 /// The intended use of this class is to move clause processing outside of
41 /// construct processing, since the same clauses can appear attached to
42 /// different constructs and constructs can be combined, so that code
43 /// duplication is minimized.
44 ///
45 /// Each construct-lowering function only calls the `process<ClauseName>()`
46 /// methods that relate to clauses that can impact the lowering of that
47 /// construct.
48 class ClauseProcessor {
49 public:
50   ClauseProcessor(lower::AbstractConverter &converter,
51                   semantics::SemanticsContext &semaCtx,
52                   const List<Clause> &clauses)
53       : converter(converter), semaCtx(semaCtx), clauses(clauses) {}
54 
55   // 'Unique' clauses: They can appear at most once in the clause list.
56   bool processBare(mlir::omp::BareClauseOps &result) const;
57   bool processBind(mlir::omp::BindClauseOps &result) const;
58   bool
59   processCollapse(mlir::Location currentLocation, lower::pft::Evaluation &eval,
60                   mlir::omp::LoopRelatedClauseOps &result,
61                   llvm::SmallVectorImpl<const semantics::Symbol *> &iv) const;
62   bool processDevice(lower::StatementContext &stmtCtx,
63                      mlir::omp::DeviceClauseOps &result) const;
64   bool processDeviceType(mlir::omp::DeviceTypeClauseOps &result) const;
65   bool processDistSchedule(lower::StatementContext &stmtCtx,
66                            mlir::omp::DistScheduleClauseOps &result) const;
67   bool processFilter(lower::StatementContext &stmtCtx,
68                      mlir::omp::FilterClauseOps &result) const;
69   bool processFinal(lower::StatementContext &stmtCtx,
70                     mlir::omp::FinalClauseOps &result) const;
71   bool processHasDeviceAddr(
72       mlir::omp::HasDeviceAddrClauseOps &result,
73       llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
74   bool processHint(mlir::omp::HintClauseOps &result) const;
75   bool processMergeable(mlir::omp::MergeableClauseOps &result) const;
76   bool processNowait(mlir::omp::NowaitClauseOps &result) const;
77   bool processNumTeams(lower::StatementContext &stmtCtx,
78                        mlir::omp::NumTeamsClauseOps &result) const;
79   bool processNumThreads(lower::StatementContext &stmtCtx,
80                          mlir::omp::NumThreadsClauseOps &result) const;
81   bool processOrder(mlir::omp::OrderClauseOps &result) const;
82   bool processOrdered(mlir::omp::OrderedClauseOps &result) const;
83   bool processPriority(lower::StatementContext &stmtCtx,
84                        mlir::omp::PriorityClauseOps &result) const;
85   bool processProcBind(mlir::omp::ProcBindClauseOps &result) const;
86   bool processSafelen(mlir::omp::SafelenClauseOps &result) const;
87   bool processSchedule(lower::StatementContext &stmtCtx,
88                        mlir::omp::ScheduleClauseOps &result) const;
89   bool processSimdlen(mlir::omp::SimdlenClauseOps &result) const;
90   bool processThreadLimit(lower::StatementContext &stmtCtx,
91                           mlir::omp::ThreadLimitClauseOps &result) const;
92   bool processUntied(mlir::omp::UntiedClauseOps &result) const;
93 
94   bool processDetach(mlir::omp::DetachClauseOps &result) const;
95   // 'Repeatable' clauses: They can appear multiple times in the clause list.
96   bool processAligned(mlir::omp::AlignedClauseOps &result) const;
97   bool processAllocate(mlir::omp::AllocateClauseOps &result) const;
98   bool processCopyin() const;
99   bool processCopyprivate(mlir::Location currentLocation,
100                           mlir::omp::CopyprivateClauseOps &result) const;
101   bool processDepend(mlir::omp::DependClauseOps &result) const;
102   bool
103   processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
104   bool processIf(omp::clause::If::DirectiveNameModifier directiveName,
105                  mlir::omp::IfClauseOps &result) const;
106   bool processIsDevicePtr(
107       mlir::omp::IsDevicePtrClauseOps &result,
108       llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
109   bool
110   processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
111 
112   // This method is used to process a map clause.
113   // The optional parameter mapSyms is used to store the original Fortran symbol
114   // for the map operands. It may be used later on to create the block_arguments
115   // for some of the directives that require it.
116   bool processMap(mlir::Location currentLocation,
117                   lower::StatementContext &stmtCtx,
118                   mlir::omp::MapClauseOps &result,
119                   llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms =
120                       nullptr) const;
121   bool processMotionClauses(lower::StatementContext &stmtCtx,
122                             mlir::omp::MapClauseOps &result);
123   bool processNontemporal(mlir::omp::NontemporalClauseOps &result) const;
124   bool processReduction(
125       mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result,
126       llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) const;
127   bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
128   bool processUseDeviceAddr(
129       lower::StatementContext &stmtCtx,
130       mlir::omp::UseDeviceAddrClauseOps &result,
131       llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const;
132   bool processUseDevicePtr(
133       lower::StatementContext &stmtCtx,
134       mlir::omp::UseDevicePtrClauseOps &result,
135       llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const;
136 
137   // Call this method for these clauses that should be supported but are not
138   // implemented yet. It triggers a compilation error if any of the given
139   // clauses is found.
140   template <typename... Ts>
141   void processTODO(mlir::Location currentLocation,
142                    llvm::omp::Directive directive) const;
143 
144 private:
145   using ClauseIterator = List<Clause>::const_iterator;
146 
147   /// Utility to find a clause within a range in the clause list.
148   template <typename T>
149   static ClauseIterator findClause(ClauseIterator begin, ClauseIterator end);
150 
151   /// Return the first instance of the given clause found in the clause list or
152   /// `nullptr` if not present. If more than one instance is expected, use
153   /// `findRepeatableClause` instead.
154   template <typename T>
155   const T *findUniqueClause(const parser::CharBlock **source = nullptr) const;
156 
157   /// Call `callbackFn` for each occurrence of the given clause. Return `true`
158   /// if at least one instance was found.
159   template <typename T>
160   bool findRepeatableClause(
161       std::function<void(const T &, const parser::CharBlock &source)>
162           callbackFn) const;
163 
164   /// Set the `result` to a new `mlir::UnitAttr` if the clause is present.
165   template <typename T>
166   bool markClauseOccurrence(mlir::UnitAttr &result) const;
167 
168   void processMapObjects(
169       lower::StatementContext &stmtCtx, mlir::Location clauseLocation,
170       const omp::ObjectList &objects,
171       llvm::omp::OpenMPOffloadMappingFlags mapTypeBits,
172       std::map<Object, OmpMapParentAndMemberData> &parentMemberIndices,
173       llvm::SmallVectorImpl<mlir::Value> &mapVars,
174       llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms) const;
175 
176   lower::AbstractConverter &converter;
177   semantics::SemanticsContext &semaCtx;
178   List<Clause> clauses;
179 };
180 
181 template <typename... Ts>
182 void ClauseProcessor::processTODO(mlir::Location currentLocation,
183                                   llvm::omp::Directive directive) const {
184   auto checkUnhandledClause = [&](llvm::omp::Clause id, const auto *x) {
185     if (!x)
186       return;
187     TODO(currentLocation,
188          "Unhandled clause " + llvm::omp::getOpenMPClauseName(id).upper() +
189              " in " + llvm::omp::getOpenMPDirectiveName(directive).upper() +
190              " construct");
191   };
192 
193   for (ClauseIterator it = clauses.begin(); it != clauses.end(); ++it)
194     (checkUnhandledClause(it->id, std::get_if<Ts>(&it->u)), ...);
195 }
196 
197 template <typename T>
198 ClauseProcessor::ClauseIterator
199 ClauseProcessor::findClause(ClauseIterator begin, ClauseIterator end) {
200   for (ClauseIterator it = begin; it != end; ++it) {
201     if (std::get_if<T>(&it->u))
202       return it;
203   }
204 
205   return end;
206 }
207 
208 template <typename T>
209 const T *
210 ClauseProcessor::findUniqueClause(const parser::CharBlock **source) const {
211   ClauseIterator it = findClause<T>(clauses.begin(), clauses.end());
212   if (it != clauses.end()) {
213     if (source)
214       *source = &it->source;
215     return &std::get<T>(it->u);
216   }
217   return nullptr;
218 }
219 
220 template <typename T>
221 bool ClauseProcessor::findRepeatableClause(
222     std::function<void(const T &, const parser::CharBlock &source)> callbackFn)
223     const {
224   bool found = false;
225   ClauseIterator nextIt, endIt = clauses.end();
226   for (ClauseIterator it = clauses.begin(); it != endIt; it = nextIt) {
227     nextIt = findClause<T>(it, endIt);
228 
229     if (nextIt != endIt) {
230       callbackFn(std::get<T>(nextIt->u), nextIt->source);
231       found = true;
232       ++nextIt;
233     }
234   }
235   return found;
236 }
237 
238 template <typename T>
239 bool ClauseProcessor::markClauseOccurrence(mlir::UnitAttr &result) const {
240   if (findUniqueClause<T>()) {
241     result = converter.getFirOpBuilder().getUnitAttr();
242     return true;
243   }
244   return false;
245 }
246 
247 } // namespace omp
248 } // namespace lower
249 } // namespace Fortran
250 
251 #endif // FORTRAN_LOWER_CLAUSEPROCESSOR_H
252