xref: /llvm-project/flang/lib/Lower/OpenMP/ClauseProcessor.h (revision a43d2f686a46fd9d971aa65fde4563375e16f3de)
1 //===-- Lower/OpenMP/ClauseProcessor.h --------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 #ifndef FORTRAN_LOWER_CLAUSEPROCESSOR_H
13 #define FORTRAN_LOWER_CLAUSEPROCESSOR_H
14 
15 #include "Clauses.h"
16 #include "DirectivesCommon.h"
17 #include "ReductionProcessor.h"
18 #include "Utils.h"
19 #include "flang/Lower/AbstractConverter.h"
20 #include "flang/Lower/Bridge.h"
21 #include "flang/Optimizer/Builder/Todo.h"
22 #include "flang/Parser/dump-parse-tree.h"
23 #include "flang/Parser/parse-tree.h"
24 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
25 
// Forward declaration of the FIR op builder type referenced by the template
// member definitions later in this header.
namespace fir { class FirOpBuilder; } // namespace fir
29 
30 namespace Fortran {
31 namespace lower {
32 namespace omp {
33 
34 /// Class that handles the processing of OpenMP clauses.
35 ///
36 /// Its `process<ClauseName>()` methods perform MLIR code generation for their
37 /// corresponding clause if it is present in the clause list. Otherwise, they
38 /// will return `false` to signal that the clause was not found.
39 ///
40 /// The intended use of this class is to move clause processing outside of
41 /// construct processing, since the same clauses can appear attached to
42 /// different constructs and constructs can be combined, so that code
43 /// duplication is minimized.
44 ///
45 /// Each construct-lowering function only calls the `process<ClauseName>()`
46 /// methods that relate to clauses that can impact the lowering of that
47 /// construct.
48 class ClauseProcessor {
49 public:
50   ClauseProcessor(lower::AbstractConverter &converter,
51                   semantics::SemanticsContext &semaCtx,
52                   const List<Clause> &clauses)
53       : converter(converter), semaCtx(semaCtx), clauses(clauses) {}
54 
55   // 'Unique' clauses: They can appear at most once in the clause list.
56   bool
57   processCollapse(mlir::Location currentLocation, lower::pft::Evaluation &eval,
58                   mlir::omp::CollapseClauseOps &result,
59                   llvm::SmallVectorImpl<const semantics::Symbol *> &iv) const;
60   bool processDefault() const;
61   bool processDevice(lower::StatementContext &stmtCtx,
62                      mlir::omp::DeviceClauseOps &result) const;
63   bool processDeviceType(mlir::omp::DeviceTypeClauseOps &result) const;
64   bool processFinal(lower::StatementContext &stmtCtx,
65                     mlir::omp::FinalClauseOps &result) const;
66   bool processHasDeviceAddr(
67       mlir::omp::HasDeviceAddrClauseOps &result,
68       llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
69       llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
70       llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSymbols) const;
71   bool processHint(mlir::omp::HintClauseOps &result) const;
72   bool processMergeable(mlir::omp::MergeableClauseOps &result) const;
73   bool processNowait(mlir::omp::NowaitClauseOps &result) const;
74   bool processNumTeams(lower::StatementContext &stmtCtx,
75                        mlir::omp::NumTeamsClauseOps &result) const;
76   bool processNumThreads(lower::StatementContext &stmtCtx,
77                          mlir::omp::NumThreadsClauseOps &result) const;
78   bool processOrdered(mlir::omp::OrderedClauseOps &result) const;
79   bool processPriority(lower::StatementContext &stmtCtx,
80                        mlir::omp::PriorityClauseOps &result) const;
81   bool processProcBind(mlir::omp::ProcBindClauseOps &result) const;
82   bool processSafelen(mlir::omp::SafelenClauseOps &result) const;
83   bool processSchedule(lower::StatementContext &stmtCtx,
84                        mlir::omp::ScheduleClauseOps &result) const;
85   bool processSimdlen(mlir::omp::SimdlenClauseOps &result) const;
86   bool processThreadLimit(lower::StatementContext &stmtCtx,
87                           mlir::omp::ThreadLimitClauseOps &result) const;
88   bool processUntied(mlir::omp::UntiedClauseOps &result) const;
89 
90   // 'Repeatable' clauses: They can appear multiple times in the clause list.
91   bool processAllocate(mlir::omp::AllocateClauseOps &result) const;
92   bool processCopyin() const;
93   bool processCopyprivate(mlir::Location currentLocation,
94                           mlir::omp::CopyprivateClauseOps &result) const;
95   bool processDepend(mlir::omp::DependClauseOps &result) const;
96   bool
97   processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
98   bool processIf(omp::clause::If::DirectiveNameModifier directiveName,
99                  mlir::omp::IfClauseOps &result) const;
100   bool processIsDevicePtr(
101       mlir::omp::IsDevicePtrClauseOps &result,
102       llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
103       llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
104       llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSymbols) const;
105   bool
106   processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
107 
108   // This method is used to process a map clause.
109   // The optional parameters - mapSymTypes, mapSymLocs & mapSyms are used to
110   // store the original type, location and Fortran symbol for the map operands.
111   // They may be used later on to create the block_arguments for some of the
112   // target directives that require it.
113   bool processMap(
114       mlir::Location currentLocation, lower::StatementContext &stmtCtx,
115       mlir::omp::MapClauseOps &result,
116       llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms = nullptr,
117       llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr,
118       llvm::SmallVectorImpl<mlir::Type> *mapSymTypes = nullptr) const;
119   bool processReduction(
120       mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result,
121       llvm::SmallVectorImpl<mlir::Type> *reductionTypes = nullptr,
122       llvm::SmallVectorImpl<const semantics::Symbol *> *reductionSyms =
123           nullptr) const;
124   bool processSectionsReduction(mlir::Location currentLocation,
125                                 mlir::omp::ReductionClauseOps &result) const;
126   bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
127   bool processUseDeviceAddr(
128       mlir::omp::UseDeviceClauseOps &result,
129       llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
130       llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
131       llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const;
132   bool processUseDevicePtr(
133       mlir::omp::UseDeviceClauseOps &result,
134       llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
135       llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
136       llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const;
137 
138   template <typename T>
139   bool processMotionClauses(lower::StatementContext &stmtCtx,
140                             mlir::omp::MapClauseOps &result);
141 
142   // Call this method for these clauses that should be supported but are not
143   // implemented yet. It triggers a compilation error if any of the given
144   // clauses is found.
145   template <typename... Ts>
146   void processTODO(mlir::Location currentLocation,
147                    llvm::omp::Directive directive) const;
148 
149 private:
150   using ClauseIterator = List<Clause>::const_iterator;
151 
152   /// Utility to find a clause within a range in the clause list.
153   template <typename T>
154   static ClauseIterator findClause(ClauseIterator begin, ClauseIterator end);
155 
156   /// Return the first instance of the given clause found in the clause list or
157   /// `nullptr` if not present. If more than one instance is expected, use
158   /// `findRepeatableClause` instead.
159   template <typename T>
160   const T *findUniqueClause(const parser::CharBlock **source = nullptr) const;
161 
162   /// Call `callbackFn` for each occurrence of the given clause. Return `true`
163   /// if at least one instance was found.
164   template <typename T>
165   bool findRepeatableClause(
166       std::function<void(const T &, const parser::CharBlock &source)>
167           callbackFn) const;
168 
169   /// Set the `result` to a new `mlir::UnitAttr` if the clause is present.
170   template <typename T>
171   bool markClauseOccurrence(mlir::UnitAttr &result) const;
172 
173   lower::AbstractConverter &converter;
174   semantics::SemanticsContext &semaCtx;
175   List<Clause> clauses;
176 };
177 
178 template <typename T>
179 bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx,
180                                            mlir::omp::MapClauseOps &result) {
181   std::map<const semantics::Symbol *,
182            llvm::SmallVector<OmpMapMemberIndicesData>>
183       parentMemberIndices;
184   llvm::SmallVector<const semantics::Symbol *> mapSymbols;
185 
186   bool clauseFound = findRepeatableClause<T>(
187       [&](const T &clause, const parser::CharBlock &source) {
188         mlir::Location clauseLocation = converter.genLocation(source);
189         fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
190 
191         static_assert(std::is_same_v<T, omp::clause::To> ||
192                       std::is_same_v<T, omp::clause::From>);
193 
194         // TODO Support motion modifiers: present, mapper, iterator.
195         constexpr llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
196             std::is_same_v<T, omp::clause::To>
197                 ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO
198                 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
199 
200         auto &objects = std::get<ObjectList>(clause.t);
201         for (const omp::Object &object : objects) {
202           llvm::SmallVector<mlir::Value> bounds;
203           std::stringstream asFortran;
204 
205           lower::AddrAndBoundsInfo info =
206               lower::gatherDataOperandAddrAndBounds<mlir::omp::MapBoundsOp,
207                                                     mlir::omp::MapBoundsType>(
208                   converter, firOpBuilder, semaCtx, stmtCtx, *object.id(),
209                   object.ref(), clauseLocation, asFortran, bounds,
210                   treatIndexAsSection);
211 
212           auto origSymbol = converter.getSymbolAddress(*object.id());
213           mlir::Value symAddr = info.addr;
214           if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType()))
215             symAddr = origSymbol;
216 
217           // Explicit map captures are captured ByRef by default,
218           // optimisation passes may alter this to ByCopy or other capture
219           // types to optimise
220           mlir::omp::MapInfoOp mapOp = createMapInfoOp(
221               firOpBuilder, clauseLocation, symAddr,
222               /*varPtrPtr=*/mlir::Value{}, asFortran.str(), bounds,
223               /*members=*/{}, /*membersIndex=*/mlir::DenseIntElementsAttr{},
224               static_cast<
225                   std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
226                   mapTypeBits),
227               mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());
228 
229           if (object.id()->owner().IsDerivedType()) {
230             addChildIndexAndMapToParent(object, parentMemberIndices, mapOp,
231                                         semaCtx);
232           } else {
233             result.mapVars.push_back(mapOp);
234             mapSymbols.push_back(object.id());
235           }
236         }
237       });
238 
239   insertChildMapInfoIntoParent(converter, parentMemberIndices, result.mapVars,
240                                mapSymbols,
241                                /*mapSymTypes=*/nullptr, /*mapSymLocs=*/nullptr);
242   return clauseFound;
243 }
244 
245 template <typename... Ts>
246 void ClauseProcessor::processTODO(mlir::Location currentLocation,
247                                   llvm::omp::Directive directive) const {
248   auto checkUnhandledClause = [&](llvm::omp::Clause id, const auto *x) {
249     if (!x)
250       return;
251     TODO(currentLocation,
252          "Unhandled clause " + llvm::omp::getOpenMPClauseName(id).upper() +
253              " in " + llvm::omp::getOpenMPDirectiveName(directive).upper() +
254              " construct");
255   };
256 
257   for (ClauseIterator it = clauses.begin(); it != clauses.end(); ++it)
258     (checkUnhandledClause(it->id, std::get_if<Ts>(&it->u)), ...);
259 }
260 
261 template <typename T>
262 ClauseProcessor::ClauseIterator
263 ClauseProcessor::findClause(ClauseIterator begin, ClauseIterator end) {
264   for (ClauseIterator it = begin; it != end; ++it) {
265     if (std::get_if<T>(&it->u))
266       return it;
267   }
268 
269   return end;
270 }
271 
272 template <typename T>
273 const T *
274 ClauseProcessor::findUniqueClause(const parser::CharBlock **source) const {
275   ClauseIterator it = findClause<T>(clauses.begin(), clauses.end());
276   if (it != clauses.end()) {
277     if (source)
278       *source = &it->source;
279     return &std::get<T>(it->u);
280   }
281   return nullptr;
282 }
283 
284 template <typename T>
285 bool ClauseProcessor::findRepeatableClause(
286     std::function<void(const T &, const parser::CharBlock &source)> callbackFn)
287     const {
288   bool found = false;
289   ClauseIterator nextIt, endIt = clauses.end();
290   for (ClauseIterator it = clauses.begin(); it != endIt; it = nextIt) {
291     nextIt = findClause<T>(it, endIt);
292 
293     if (nextIt != endIt) {
294       callbackFn(std::get<T>(nextIt->u), nextIt->source);
295       found = true;
296       ++nextIt;
297     }
298   }
299   return found;
300 }
301 
302 template <typename T>
303 bool ClauseProcessor::markClauseOccurrence(mlir::UnitAttr &result) const {
304   if (findUniqueClause<T>()) {
305     result = converter.getFirOpBuilder().getUnitAttr();
306     return true;
307   }
308   return false;
309 }
310 
311 } // namespace omp
312 } // namespace lower
313 } // namespace Fortran
314 
315 #endif // FORTRAN_LOWER_CLAUSEPROCESSOR_H
316