xref: /llvm-project/flang/lib/Lower/OpenMP/Utils.cpp (revision 77d8cfb3c50e3341d65af1f9e442004bbd77af9b)
//===-- Utils.cpp -----------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Utils.h"
14 
15 #include "Clauses.h"
16 #include <flang/Lower/AbstractConverter.h>
17 #include <flang/Lower/ConvertType.h>
18 #include <flang/Lower/PFTBuilder.h>
19 #include <flang/Optimizer/Builder/FIRBuilder.h>
20 #include <flang/Parser/parse-tree.h>
21 #include <flang/Parser/tools.h>
22 #include <flang/Semantics/tools.h>
23 #include <llvm/Support/CommandLine.h>
24 
25 #include <algorithm>
26 #include <numeric>
27 
28 llvm::cl::opt<bool> treatIndexAsSection(
29     "openmp-treat-index-as-section",
30     llvm::cl::desc("In the OpenMP data clauses treat `a(N)` as `a(N:N)`."),
31     llvm::cl::init(true));
32 
33 llvm::cl::opt<bool> enableDelayedPrivatization(
34     "openmp-enable-delayed-privatization",
35     llvm::cl::desc(
36         "Emit `[first]private` variables as clauses on the MLIR ops."),
37     llvm::cl::init(false));
38 
39 llvm::cl::opt<bool> enableDelayedPrivatizationStaging(
40     "openmp-enable-delayed-privatization-staging",
41     llvm::cl::desc("For partially supported constructs, emit `[first]private` "
42                    "variables as clauses on the MLIR ops."),
43     llvm::cl::init(false));
44 
45 namespace Fortran {
46 namespace lower {
47 namespace omp {
48 
49 int64_t getCollapseValue(const List<Clause> &clauses) {
50   auto iter = llvm::find_if(clauses, [](const Clause &clause) {
51     return clause.id == llvm::omp::Clause::OMPC_collapse;
52   });
53   if (iter != clauses.end()) {
54     const auto &collapse = std::get<clause::Collapse>(iter->u);
55     return evaluate::ToInt64(collapse.v).value();
56   }
57   return 1;
58 }
59 
60 void genObjectList(const ObjectList &objects,
61                    lower::AbstractConverter &converter,
62                    llvm::SmallVectorImpl<mlir::Value> &operands) {
63   for (const Object &object : objects) {
64     const semantics::Symbol *sym = object.sym();
65     assert(sym && "Expected Symbol");
66     if (mlir::Value variable = converter.getSymbolAddress(*sym)) {
67       operands.push_back(variable);
68     } else if (const auto *details =
69                    sym->detailsIf<semantics::HostAssocDetails>()) {
70       operands.push_back(converter.getSymbolAddress(details->symbol()));
71       converter.copySymbolBinding(details->symbol(), *sym);
72     }
73   }
74 }
75 
76 mlir::Type getLoopVarType(lower::AbstractConverter &converter,
77                           std::size_t loopVarTypeSize) {
78   // OpenMP runtime requires 32-bit or 64-bit loop variables.
79   loopVarTypeSize = loopVarTypeSize * 8;
80   if (loopVarTypeSize < 32) {
81     loopVarTypeSize = 32;
82   } else if (loopVarTypeSize > 64) {
83     loopVarTypeSize = 64;
84     mlir::emitWarning(converter.getCurrentLocation(),
85                       "OpenMP loop iteration variable cannot have more than 64 "
86                       "bits size and will be narrowed into 64 bits.");
87   }
88   assert((loopVarTypeSize == 32 || loopVarTypeSize == 64) &&
89          "OpenMP loop iteration variable size must be transformed into 32-bit "
90          "or 64-bit");
91   return converter.getFirOpBuilder().getIntegerType(loopVarTypeSize);
92 }
93 
94 semantics::Symbol *
95 getIterationVariableSymbol(const lower::pft::Evaluation &eval) {
96   return eval.visit(common::visitors{
97       [&](const parser::DoConstruct &doLoop) {
98         if (const auto &maybeCtrl = doLoop.GetLoopControl()) {
99           using LoopControl = parser::LoopControl;
100           if (auto *bounds = std::get_if<LoopControl::Bounds>(&maybeCtrl->u)) {
101             static_assert(std::is_same_v<decltype(bounds->name),
102                                          parser::Scalar<parser::Name>>);
103             return bounds->name.thing.symbol;
104           }
105         }
106         return static_cast<semantics::Symbol *>(nullptr);
107       },
108       [](auto &&) { return static_cast<semantics::Symbol *>(nullptr); },
109   });
110 }
111 
112 void gatherFuncAndVarSyms(
113     const ObjectList &objects, mlir::omp::DeclareTargetCaptureClause clause,
114     llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) {
115   for (const Object &object : objects)
116     symbolAndClause.emplace_back(clause, *object.sym());
117 }
118 
119 mlir::omp::MapInfoOp
120 createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
121                 mlir::Value baseAddr, mlir::Value varPtrPtr, std::string name,
122                 llvm::ArrayRef<mlir::Value> bounds,
123                 llvm::ArrayRef<mlir::Value> members,
124                 mlir::DenseIntElementsAttr membersIndex, uint64_t mapType,
125                 mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy,
126                 bool partialMap) {
127   if (auto boxTy = llvm::dyn_cast<fir::BaseBoxType>(baseAddr.getType())) {
128     baseAddr = builder.create<fir::BoxAddrOp>(loc, baseAddr);
129     retTy = baseAddr.getType();
130   }
131 
132   mlir::TypeAttr varType = mlir::TypeAttr::get(
133       llvm::cast<mlir::omp::PointerLikeType>(retTy).getElementType());
134 
135   mlir::omp::MapInfoOp op = builder.create<mlir::omp::MapInfoOp>(
136       loc, retTy, baseAddr, varType, varPtrPtr, members, membersIndex, bounds,
137       builder.getIntegerAttr(builder.getIntegerType(64, false), mapType),
138       builder.getAttr<mlir::omp::VariableCaptureKindAttr>(mapCaptureType),
139       builder.getStringAttr(name), builder.getBoolAttr(partialMap));
140 
141   return op;
142 }
143 
144 static int
145 getComponentPlacementInParent(const semantics::Symbol *componentSym) {
146   const auto *derived = componentSym->owner()
147                             .derivedTypeSpec()
148                             ->typeSymbol()
149                             .detailsIf<semantics::DerivedTypeDetails>();
150   assert(derived &&
151          "expected derived type details when processing component symbol");
152   for (auto [placement, name] : llvm::enumerate(derived->componentNames()))
153     if (name == componentSym->name())
154       return placement;
155   return -1;
156 }
157 
158 static std::optional<Object>
159 getComponentObject(std::optional<Object> object,
160                    semantics::SemanticsContext &semaCtx) {
161   if (!object)
162     return std::nullopt;
163 
164   auto ref = evaluate::ExtractDataRef(*object.value().ref());
165   if (!ref)
166     return std::nullopt;
167 
168   if (std::holds_alternative<evaluate::Component>(ref->u))
169     return object;
170 
171   auto baseObj = getBaseObject(object.value(), semaCtx);
172   if (!baseObj)
173     return std::nullopt;
174 
175   return getComponentObject(baseObj.value(), semaCtx);
176 }
177 
178 static void
179 generateMemberPlacementIndices(const Object &object,
180                                llvm::SmallVectorImpl<int> &indices,
181                                semantics::SemanticsContext &semaCtx) {
182   auto compObj = getComponentObject(object, semaCtx);
183   while (compObj) {
184     indices.push_back(getComponentPlacementInParent(compObj->sym()));
185     compObj =
186         getComponentObject(getBaseObject(compObj.value(), semaCtx), semaCtx);
187   }
188 
189   indices = llvm::SmallVector<int>{llvm::reverse(indices)};
190 }
191 
192 void addChildIndexAndMapToParent(
193     const omp::Object &object,
194     std::map<const semantics::Symbol *,
195              llvm::SmallVector<OmpMapMemberIndicesData>> &parentMemberIndices,
196     mlir::omp::MapInfoOp &mapOp, semantics::SemanticsContext &semaCtx) {
197   std::optional<evaluate::DataRef> dataRef = ExtractDataRef(object.ref());
198   assert(dataRef.has_value() &&
199          "DataRef could not be extracted during mapping of derived type "
200          "cannot proceed");
201   const semantics::Symbol *parentSym = &dataRef->GetFirstSymbol();
202   assert(parentSym && "Could not find parent symbol during lower of "
203                       "a component member in OpenMP map clause");
204   llvm::SmallVector<int> indices;
205   generateMemberPlacementIndices(object, indices, semaCtx);
206   parentMemberIndices[parentSym].push_back({indices, mapOp});
207 }
208 
209 static void calculateShapeAndFillIndices(
210     llvm::SmallVectorImpl<int64_t> &shape,
211     llvm::SmallVectorImpl<OmpMapMemberIndicesData> &memberPlacementData) {
212   shape.push_back(memberPlacementData.size());
213   size_t largestIndicesSize =
214       std::max_element(memberPlacementData.begin(), memberPlacementData.end(),
215                        [](auto a, auto b) {
216                          return a.memberPlacementIndices.size() <
217                                 b.memberPlacementIndices.size();
218                        })
219           ->memberPlacementIndices.size();
220   shape.push_back(largestIndicesSize);
221 
222   // DenseElementsAttr expects a rectangular shape for the data, so all
223   // index lists have to be of the same length, this emplaces -1 as filler.
224   for (auto &v : memberPlacementData) {
225     if (v.memberPlacementIndices.size() < largestIndicesSize) {
226       auto *prevEnd = v.memberPlacementIndices.end();
227       v.memberPlacementIndices.resize(largestIndicesSize);
228       std::fill(prevEnd, v.memberPlacementIndices.end(), -1);
229     }
230   }
231 }
232 
233 static mlir::DenseIntElementsAttr createDenseElementsAttrFromIndices(
234     llvm::SmallVectorImpl<OmpMapMemberIndicesData> &memberPlacementData,
235     fir::FirOpBuilder &builder) {
236   llvm::SmallVector<int64_t> shape;
237   calculateShapeAndFillIndices(shape, memberPlacementData);
238 
239   llvm::SmallVector<int> indicesFlattened =
240       std::accumulate(memberPlacementData.begin(), memberPlacementData.end(),
241                       llvm::SmallVector<int>(),
242                       [](llvm::SmallVector<int> &x, OmpMapMemberIndicesData y) {
243                         x.insert(x.end(), y.memberPlacementIndices.begin(),
244                                  y.memberPlacementIndices.end());
245                         return x;
246                       });
247 
248   return mlir::DenseIntElementsAttr::get(
249       mlir::VectorType::get(shape,
250                             mlir::IntegerType::get(builder.getContext(), 32)),
251       indicesFlattened);
252 }
253 
254 void insertChildMapInfoIntoParent(
255     lower::AbstractConverter &converter,
256     std::map<const semantics::Symbol *,
257              llvm::SmallVector<OmpMapMemberIndicesData>> &parentMemberIndices,
258     llvm::SmallVectorImpl<mlir::Value> &mapOperands,
259     llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms,
260     llvm::SmallVectorImpl<mlir::Type> *mapSymTypes,
261     llvm::SmallVectorImpl<mlir::Location> *mapSymLocs) {
262   for (auto indices : parentMemberIndices) {
263     bool parentExists = false;
264     size_t parentIdx;
265     for (parentIdx = 0; parentIdx < mapSyms.size(); ++parentIdx) {
266       if (mapSyms[parentIdx] == indices.first) {
267         parentExists = true;
268         break;
269       }
270     }
271 
272     if (parentExists) {
273       auto mapOp = llvm::cast<mlir::omp::MapInfoOp>(
274           mapOperands[parentIdx].getDefiningOp());
275 
276       // NOTE: To maintain appropriate SSA ordering, we move the parent map
277       // which will now have references to its children after the last
278       // of its members to be generated. This is necessary when a user
279       // has defined a series of parent and children maps where the parent
280       // precedes the children. An alternative, may be to do
281       // delayed generation of map info operations from the clauses and
282       // organize them first before generation.
283       mapOp->moveAfter(indices.second.back().memberMap);
284 
285       for (auto memberIndicesData : indices.second)
286         mapOp.getMembersMutable().append(
287             memberIndicesData.memberMap.getResult());
288 
289       mapOp.setMembersIndexAttr(createDenseElementsAttrFromIndices(
290           indices.second, converter.getFirOpBuilder()));
291     } else {
292       // NOTE: We take the map type of the first child, this may not
293       // be the correct thing to do, however, we shall see. For the moment
294       // it allows this to work with enter and exit without causing MLIR
295       // verification issues. The more appropriate thing may be to take
296       // the "main" map type clause from the directive being used.
297       uint64_t mapType = indices.second[0].memberMap.getMapType().value_or(0);
298 
299       // create parent to emplace and bind members
300       mlir::Value origSymbol = converter.getSymbolAddress(*indices.first);
301 
302       llvm::SmallVector<mlir::Value> members;
303       for (OmpMapMemberIndicesData memberIndicesData : indices.second)
304         members.push_back((mlir::Value)memberIndicesData.memberMap);
305 
306       mlir::Value mapOp = createMapInfoOp(
307           converter.getFirOpBuilder(), origSymbol.getLoc(), origSymbol,
308           /*varPtrPtr=*/mlir::Value(), indices.first->name().ToString(),
309           /*bounds=*/{}, members,
310           createDenseElementsAttrFromIndices(indices.second,
311                                              converter.getFirOpBuilder()),
312           mapType, mlir::omp::VariableCaptureKind::ByRef, origSymbol.getType(),
313           /*partialMap=*/true);
314 
315       mapOperands.push_back(mapOp);
316       mapSyms.push_back(indices.first);
317 
318       if (mapSymTypes)
319         mapSymTypes->push_back(mapOp.getType());
320       if (mapSymLocs)
321         mapSymLocs->push_back(mapOp.getLoc());
322     }
323   }
324 }
325 
326 semantics::Symbol *getOmpObjectSymbol(const parser::OmpObject &ompObject) {
327   semantics::Symbol *sym = nullptr;
328   Fortran::common::visit(
329       common::visitors{
330           [&](const parser::Designator &designator) {
331             if (auto *arrayEle =
332                     parser::Unwrap<parser::ArrayElement>(designator)) {
333               // Use getLastName to retrieve the arrays symbol, this will
334               // provide the farthest right symbol (the last) in a designator,
335               // i.e. providing something like the following:
336               // "dtype1%dtype2%array[2:10]", will result in "array"
337               sym = GetLastName(arrayEle->base).symbol;
338             } else if (auto *structComp =
339                            parser::Unwrap<parser::StructureComponent>(
340                                designator)) {
341               sym = structComp->component.symbol;
342             } else if (const parser::Name *name =
343                            semantics::getDesignatorNameIfDataRef(designator)) {
344               sym = name->symbol;
345             }
346           },
347           [&](const parser::Name &name) { sym = name.symbol; }},
348       ompObject.u);
349   return sym;
350 }
351 
352 } // namespace omp
353 } // namespace lower
354 } // namespace Fortran
355