xref: /llvm-project/flang/lib/Lower/OpenMP/Utils.cpp (revision f98244392b4e3d4075c03528dcec0b268ba13ab7)
//===-- Utils.cpp -----------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Utils.h"
14 
15 #include "Clauses.h"
16 #include <flang/Lower/AbstractConverter.h>
17 #include <flang/Lower/ConvertType.h>
18 #include <flang/Lower/PFTBuilder.h>
19 #include <flang/Optimizer/Builder/FIRBuilder.h>
20 #include <flang/Optimizer/Builder/Todo.h>
21 #include <flang/Parser/parse-tree.h>
22 #include <flang/Parser/tools.h>
23 #include <flang/Semantics/tools.h>
24 #include <llvm/Support/CommandLine.h>
25 
26 #include <algorithm>
27 #include <numeric>
28 
29 llvm::cl::opt<bool> treatIndexAsSection(
30     "openmp-treat-index-as-section",
31     llvm::cl::desc("In the OpenMP data clauses treat `a(N)` as `a(N:N)`."),
32     llvm::cl::init(true));
33 
34 llvm::cl::opt<bool> enableDelayedPrivatization(
35     "openmp-enable-delayed-privatization",
36     llvm::cl::desc(
37         "Emit `[first]private` variables as clauses on the MLIR ops."),
38     llvm::cl::init(true));
39 
40 llvm::cl::opt<bool> enableDelayedPrivatizationStaging(
41     "openmp-enable-delayed-privatization-staging",
42     llvm::cl::desc("For partially supported constructs, emit `[first]private` "
43                    "variables as clauses on the MLIR ops."),
44     llvm::cl::init(false));
45 
46 namespace Fortran {
47 namespace lower {
48 namespace omp {
49 
50 int64_t getCollapseValue(const List<Clause> &clauses) {
51   auto iter = llvm::find_if(clauses, [](const Clause &clause) {
52     return clause.id == llvm::omp::Clause::OMPC_collapse;
53   });
54   if (iter != clauses.end()) {
55     const auto &collapse = std::get<clause::Collapse>(iter->u);
56     return evaluate::ToInt64(collapse.v).value();
57   }
58   return 1;
59 }
60 
61 void genObjectList(const ObjectList &objects,
62                    lower::AbstractConverter &converter,
63                    llvm::SmallVectorImpl<mlir::Value> &operands) {
64   for (const Object &object : objects) {
65     const semantics::Symbol *sym = object.sym();
66     assert(sym && "Expected Symbol");
67     if (mlir::Value variable = converter.getSymbolAddress(*sym)) {
68       operands.push_back(variable);
69     } else if (const auto *details =
70                    sym->detailsIf<semantics::HostAssocDetails>()) {
71       operands.push_back(converter.getSymbolAddress(details->symbol()));
72       converter.copySymbolBinding(details->symbol(), *sym);
73     }
74   }
75 }
76 
77 mlir::Type getLoopVarType(lower::AbstractConverter &converter,
78                           std::size_t loopVarTypeSize) {
79   // OpenMP runtime requires 32-bit or 64-bit loop variables.
80   loopVarTypeSize = loopVarTypeSize * 8;
81   if (loopVarTypeSize < 32) {
82     loopVarTypeSize = 32;
83   } else if (loopVarTypeSize > 64) {
84     loopVarTypeSize = 64;
85     mlir::emitWarning(converter.getCurrentLocation(),
86                       "OpenMP loop iteration variable cannot have more than 64 "
87                       "bits size and will be narrowed into 64 bits.");
88   }
89   assert((loopVarTypeSize == 32 || loopVarTypeSize == 64) &&
90          "OpenMP loop iteration variable size must be transformed into 32-bit "
91          "or 64-bit");
92   return converter.getFirOpBuilder().getIntegerType(loopVarTypeSize);
93 }
94 
95 semantics::Symbol *
96 getIterationVariableSymbol(const lower::pft::Evaluation &eval) {
97   return eval.visit(common::visitors{
98       [&](const parser::DoConstruct &doLoop) {
99         if (const auto &maybeCtrl = doLoop.GetLoopControl()) {
100           using LoopControl = parser::LoopControl;
101           if (auto *bounds = std::get_if<LoopControl::Bounds>(&maybeCtrl->u)) {
102             static_assert(std::is_same_v<decltype(bounds->name),
103                                          parser::Scalar<parser::Name>>);
104             return bounds->name.thing.symbol;
105           }
106         }
107         return static_cast<semantics::Symbol *>(nullptr);
108       },
109       [](auto &&) { return static_cast<semantics::Symbol *>(nullptr); },
110   });
111 }
112 
113 void gatherFuncAndVarSyms(
114     const ObjectList &objects, mlir::omp::DeclareTargetCaptureClause clause,
115     llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) {
116   for (const Object &object : objects)
117     symbolAndClause.emplace_back(clause, *object.sym());
118 }
119 
120 mlir::omp::MapInfoOp
121 createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
122                 mlir::Value baseAddr, mlir::Value varPtrPtr, std::string name,
123                 llvm::ArrayRef<mlir::Value> bounds,
124                 llvm::ArrayRef<mlir::Value> members,
125                 mlir::DenseIntElementsAttr membersIndex, uint64_t mapType,
126                 mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy,
127                 bool partialMap) {
128   if (auto boxTy = llvm::dyn_cast<fir::BaseBoxType>(baseAddr.getType())) {
129     baseAddr = builder.create<fir::BoxAddrOp>(loc, baseAddr);
130     retTy = baseAddr.getType();
131   }
132 
133   mlir::TypeAttr varType = mlir::TypeAttr::get(
134       llvm::cast<mlir::omp::PointerLikeType>(retTy).getElementType());
135 
136   // For types with unknown extents such as <2x?xi32> we discard the incomplete
137   // type info and only retain the base type. The correct dimensions are later
138   // recovered through the bounds info.
139   if (auto seqType = llvm::dyn_cast<fir::SequenceType>(varType.getValue()))
140     if (seqType.hasDynamicExtents())
141       varType = mlir::TypeAttr::get(seqType.getEleTy());
142 
143   mlir::omp::MapInfoOp op = builder.create<mlir::omp::MapInfoOp>(
144       loc, retTy, baseAddr, varType, varPtrPtr, members, membersIndex, bounds,
145       builder.getIntegerAttr(builder.getIntegerType(64, false), mapType),
146       builder.getAttr<mlir::omp::VariableCaptureKindAttr>(mapCaptureType),
147       builder.getStringAttr(name), builder.getBoolAttr(partialMap));
148 
149   return op;
150 }
151 
152 static int
153 getComponentPlacementInParent(const semantics::Symbol *componentSym) {
154   const auto *derived = componentSym->owner()
155                             .derivedTypeSpec()
156                             ->typeSymbol()
157                             .detailsIf<semantics::DerivedTypeDetails>();
158   assert(derived &&
159          "expected derived type details when processing component symbol");
160   for (auto [placement, name] : llvm::enumerate(derived->componentNames()))
161     if (name == componentSym->name())
162       return placement;
163   return -1;
164 }
165 
166 static std::optional<Object>
167 getComponentObject(std::optional<Object> object,
168                    semantics::SemanticsContext &semaCtx) {
169   if (!object)
170     return std::nullopt;
171 
172   auto ref = evaluate::ExtractDataRef(*object.value().ref());
173   if (!ref)
174     return std::nullopt;
175 
176   if (std::holds_alternative<evaluate::Component>(ref->u))
177     return object;
178 
179   auto baseObj = getBaseObject(object.value(), semaCtx);
180   if (!baseObj)
181     return std::nullopt;
182 
183   return getComponentObject(baseObj.value(), semaCtx);
184 }
185 
186 static void
187 generateMemberPlacementIndices(const Object &object,
188                                llvm::SmallVectorImpl<int> &indices,
189                                semantics::SemanticsContext &semaCtx) {
190   auto compObj = getComponentObject(object, semaCtx);
191   while (compObj) {
192     indices.push_back(getComponentPlacementInParent(compObj->sym()));
193     compObj =
194         getComponentObject(getBaseObject(compObj.value(), semaCtx), semaCtx);
195   }
196 
197   indices = llvm::SmallVector<int>{llvm::reverse(indices)};
198 }
199 
200 void addChildIndexAndMapToParent(
201     const omp::Object &object,
202     std::map<const semantics::Symbol *,
203              llvm::SmallVector<OmpMapMemberIndicesData>> &parentMemberIndices,
204     mlir::omp::MapInfoOp &mapOp, semantics::SemanticsContext &semaCtx) {
205   std::optional<evaluate::DataRef> dataRef = ExtractDataRef(object.ref());
206   assert(dataRef.has_value() &&
207          "DataRef could not be extracted during mapping of derived type "
208          "cannot proceed");
209   const semantics::Symbol *parentSym = &dataRef->GetFirstSymbol();
210   assert(parentSym && "Could not find parent symbol during lower of "
211                       "a component member in OpenMP map clause");
212   llvm::SmallVector<int> indices;
213   generateMemberPlacementIndices(object, indices, semaCtx);
214   parentMemberIndices[parentSym].push_back({indices, mapOp});
215 }
216 
217 static void calculateShapeAndFillIndices(
218     llvm::SmallVectorImpl<int64_t> &shape,
219     llvm::SmallVectorImpl<OmpMapMemberIndicesData> &memberPlacementData) {
220   shape.push_back(memberPlacementData.size());
221   size_t largestIndicesSize =
222       std::max_element(memberPlacementData.begin(), memberPlacementData.end(),
223                        [](auto a, auto b) {
224                          return a.memberPlacementIndices.size() <
225                                 b.memberPlacementIndices.size();
226                        })
227           ->memberPlacementIndices.size();
228   shape.push_back(largestIndicesSize);
229 
230   // DenseElementsAttr expects a rectangular shape for the data, so all
231   // index lists have to be of the same length, this emplaces -1 as filler.
232   for (auto &v : memberPlacementData) {
233     if (v.memberPlacementIndices.size() < largestIndicesSize) {
234       auto *prevEnd = v.memberPlacementIndices.end();
235       v.memberPlacementIndices.resize(largestIndicesSize);
236       std::fill(prevEnd, v.memberPlacementIndices.end(), -1);
237     }
238   }
239 }
240 
241 static mlir::DenseIntElementsAttr createDenseElementsAttrFromIndices(
242     llvm::SmallVectorImpl<OmpMapMemberIndicesData> &memberPlacementData,
243     fir::FirOpBuilder &builder) {
244   llvm::SmallVector<int64_t> shape;
245   calculateShapeAndFillIndices(shape, memberPlacementData);
246 
247   llvm::SmallVector<int> indicesFlattened =
248       std::accumulate(memberPlacementData.begin(), memberPlacementData.end(),
249                       llvm::SmallVector<int>(),
250                       [](llvm::SmallVector<int> &x, OmpMapMemberIndicesData y) {
251                         x.insert(x.end(), y.memberPlacementIndices.begin(),
252                                  y.memberPlacementIndices.end());
253                         return x;
254                       });
255 
256   return mlir::DenseIntElementsAttr::get(
257       mlir::VectorType::get(shape,
258                             mlir::IntegerType::get(builder.getContext(), 32)),
259       indicesFlattened);
260 }
261 
262 void insertChildMapInfoIntoParent(
263     lower::AbstractConverter &converter,
264     std::map<const semantics::Symbol *,
265              llvm::SmallVector<OmpMapMemberIndicesData>> &parentMemberIndices,
266     llvm::SmallVectorImpl<mlir::Value> &mapOperands,
267     llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms,
268     llvm::SmallVectorImpl<mlir::Type> *mapSymTypes,
269     llvm::SmallVectorImpl<mlir::Location> *mapSymLocs) {
270   for (auto indices : parentMemberIndices) {
271     bool parentExists = false;
272     size_t parentIdx;
273     for (parentIdx = 0; parentIdx < mapSyms.size(); ++parentIdx) {
274       if (mapSyms[parentIdx] == indices.first) {
275         parentExists = true;
276         break;
277       }
278     }
279 
280     if (parentExists) {
281       auto mapOp = llvm::cast<mlir::omp::MapInfoOp>(
282           mapOperands[parentIdx].getDefiningOp());
283 
284       // NOTE: To maintain appropriate SSA ordering, we move the parent map
285       // which will now have references to its children after the last
286       // of its members to be generated. This is necessary when a user
287       // has defined a series of parent and children maps where the parent
288       // precedes the children. An alternative, may be to do
289       // delayed generation of map info operations from the clauses and
290       // organize them first before generation.
291       mapOp->moveAfter(indices.second.back().memberMap);
292 
293       for (auto memberIndicesData : indices.second)
294         mapOp.getMembersMutable().append(
295             memberIndicesData.memberMap.getResult());
296 
297       mapOp.setMembersIndexAttr(createDenseElementsAttrFromIndices(
298           indices.second, converter.getFirOpBuilder()));
299     } else {
300       // NOTE: We take the map type of the first child, this may not
301       // be the correct thing to do, however, we shall see. For the moment
302       // it allows this to work with enter and exit without causing MLIR
303       // verification issues. The more appropriate thing may be to take
304       // the "main" map type clause from the directive being used.
305       uint64_t mapType = indices.second[0].memberMap.getMapType().value_or(0);
306 
307       // create parent to emplace and bind members
308       mlir::Value origSymbol = converter.getSymbolAddress(*indices.first);
309 
310       llvm::SmallVector<mlir::Value> members;
311       for (OmpMapMemberIndicesData memberIndicesData : indices.second)
312         members.push_back((mlir::Value)memberIndicesData.memberMap);
313 
314       mlir::Value mapOp = createMapInfoOp(
315           converter.getFirOpBuilder(), origSymbol.getLoc(), origSymbol,
316           /*varPtrPtr=*/mlir::Value(), indices.first->name().ToString(),
317           /*bounds=*/{}, members,
318           createDenseElementsAttrFromIndices(indices.second,
319                                              converter.getFirOpBuilder()),
320           mapType, mlir::omp::VariableCaptureKind::ByRef, origSymbol.getType(),
321           /*partialMap=*/true);
322 
323       mapOperands.push_back(mapOp);
324       mapSyms.push_back(indices.first);
325 
326       if (mapSymTypes)
327         mapSymTypes->push_back(mapOp.getType());
328       if (mapSymLocs)
329         mapSymLocs->push_back(mapOp.getLoc());
330     }
331   }
332 }
333 
334 semantics::Symbol *getOmpObjectSymbol(const parser::OmpObject &ompObject) {
335   semantics::Symbol *sym = nullptr;
336   Fortran::common::visit(
337       common::visitors{
338           [&](const parser::Designator &designator) {
339             if (auto *arrayEle =
340                     parser::Unwrap<parser::ArrayElement>(designator)) {
341               // Use getLastName to retrieve the arrays symbol, this will
342               // provide the farthest right symbol (the last) in a designator,
343               // i.e. providing something like the following:
344               // "dtype1%dtype2%array[2:10]", will result in "array"
345               sym = GetLastName(arrayEle->base).symbol;
346             } else if (auto *structComp =
347                            parser::Unwrap<parser::StructureComponent>(
348                                designator)) {
349               sym = structComp->component.symbol;
350             } else if (const parser::Name *name =
351                            semantics::getDesignatorNameIfDataRef(designator)) {
352               sym = name->symbol;
353             }
354           },
355           [&](const parser::Name &name) { sym = name.symbol; }},
356       ompObject.u);
357   return sym;
358 }
359 
360 void lastprivateModifierNotSupported(const omp::clause::Lastprivate &lastp,
361                                      mlir::Location loc) {
362   using Lastprivate = omp::clause::Lastprivate;
363   auto &maybeMod =
364       std::get<std::optional<Lastprivate::LastprivateModifier>>(lastp.t);
365   if (maybeMod) {
366     assert(*maybeMod == Lastprivate::LastprivateModifier::Conditional &&
367            "Unexpected lastprivate modifier");
368     TODO(loc, "lastprivate clause with CONDITIONAL modifier");
369   }
370 }
371 
372 } // namespace omp
373 } // namespace lower
374 } // namespace Fortran
375