xref: /llvm-project/flang/lib/Lower/OpenMP/Utils.cpp (revision 25a3ba33153e99c4614d404ba18b761d652e24de)
1 //===-- Utils.cpp -----------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Utils.h"
14 
15 #include "Clauses.h"
16 #include <flang/Lower/AbstractConverter.h>
17 #include <flang/Lower/ConvertType.h>
18 #include <flang/Lower/PFTBuilder.h>
19 #include <flang/Optimizer/Builder/FIRBuilder.h>
20 #include <flang/Parser/parse-tree.h>
21 #include <flang/Parser/tools.h>
22 #include <flang/Semantics/tools.h>
23 #include <llvm/Support/CommandLine.h>
24 
25 #include <algorithm>
26 #include <numeric>
27 
28 llvm::cl::opt<bool> treatIndexAsSection(
29     "openmp-treat-index-as-section",
30     llvm::cl::desc("In the OpenMP data clauses treat `a(N)` as `a(N:N)`."),
31     llvm::cl::init(true));
32 
33 llvm::cl::opt<bool> enableDelayedPrivatization(
34     "openmp-enable-delayed-privatization",
35     llvm::cl::desc(
36         "Emit `[first]private` variables as clauses on the MLIR ops."),
37     llvm::cl::init(false));
38 
39 namespace Fortran {
40 namespace lower {
41 namespace omp {
42 
43 int64_t getCollapseValue(const List<Clause> &clauses) {
44   auto iter = llvm::find_if(clauses, [](const Clause &clause) {
45     return clause.id == llvm::omp::Clause::OMPC_collapse;
46   });
47   if (iter != clauses.end()) {
48     const auto &collapse = std::get<clause::Collapse>(iter->u);
49     return evaluate::ToInt64(collapse.v).value();
50   }
51   return 1;
52 }
53 
54 uint32_t getOpenMPVersion(mlir::ModuleOp mod) {
55   if (mlir::Attribute verAttr = mod->getAttr("omp.version"))
56     return llvm::cast<mlir::omp::VersionAttr>(verAttr).getVersion();
57   llvm_unreachable("Expecting OpenMP version attribute in module");
58 }
59 
60 void genObjectList(const ObjectList &objects,
61                    Fortran::lower::AbstractConverter &converter,
62                    llvm::SmallVectorImpl<mlir::Value> &operands) {
63   for (const Object &object : objects) {
64     const Fortran::semantics::Symbol *sym = object.id();
65     assert(sym && "Expected Symbol");
66     if (mlir::Value variable = converter.getSymbolAddress(*sym)) {
67       operands.push_back(variable);
68     } else if (const auto *details =
69                    sym->detailsIf<Fortran::semantics::HostAssocDetails>()) {
70       operands.push_back(converter.getSymbolAddress(details->symbol()));
71       converter.copySymbolBinding(details->symbol(), *sym);
72     }
73   }
74 }
75 
76 mlir::Type getLoopVarType(Fortran::lower::AbstractConverter &converter,
77                           std::size_t loopVarTypeSize) {
78   // OpenMP runtime requires 32-bit or 64-bit loop variables.
79   loopVarTypeSize = loopVarTypeSize * 8;
80   if (loopVarTypeSize < 32) {
81     loopVarTypeSize = 32;
82   } else if (loopVarTypeSize > 64) {
83     loopVarTypeSize = 64;
84     mlir::emitWarning(converter.getCurrentLocation(),
85                       "OpenMP loop iteration variable cannot have more than 64 "
86                       "bits size and will be narrowed into 64 bits.");
87   }
88   assert((loopVarTypeSize == 32 || loopVarTypeSize == 64) &&
89          "OpenMP loop iteration variable size must be transformed into 32-bit "
90          "or 64-bit");
91   return converter.getFirOpBuilder().getIntegerType(loopVarTypeSize);
92 }
93 
94 Fortran::semantics::Symbol *
95 getIterationVariableSymbol(const Fortran::lower::pft::Evaluation &eval) {
96   return eval.visit(Fortran::common::visitors{
97       [&](const Fortran::parser::DoConstruct &doLoop) {
98         if (const auto &maybeCtrl = doLoop.GetLoopControl()) {
99           using LoopControl = Fortran::parser::LoopControl;
100           if (auto *bounds = std::get_if<LoopControl::Bounds>(&maybeCtrl->u)) {
101             static_assert(
102                 std::is_same_v<decltype(bounds->name),
103                                Fortran::parser::Scalar<Fortran::parser::Name>>);
104             return bounds->name.thing.symbol;
105           }
106         }
107         return static_cast<Fortran::semantics::Symbol *>(nullptr);
108       },
109       [](auto &&) {
110         return static_cast<Fortran::semantics::Symbol *>(nullptr);
111       },
112   });
113 }
114 
115 void gatherFuncAndVarSyms(
116     const ObjectList &objects, mlir::omp::DeclareTargetCaptureClause clause,
117     llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) {
118   for (const Object &object : objects)
119     symbolAndClause.emplace_back(clause, *object.id());
120 }
121 
122 mlir::omp::MapInfoOp
123 createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
124                 mlir::Value baseAddr, mlir::Value varPtrPtr, std::string name,
125                 llvm::ArrayRef<mlir::Value> bounds,
126                 llvm::ArrayRef<mlir::Value> members,
127                 mlir::DenseIntElementsAttr membersIndex, uint64_t mapType,
128                 mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy,
129                 bool partialMap) {
130   if (auto boxTy = llvm::dyn_cast<fir::BaseBoxType>(baseAddr.getType())) {
131     baseAddr = builder.create<fir::BoxAddrOp>(loc, baseAddr);
132     retTy = baseAddr.getType();
133   }
134 
135   mlir::TypeAttr varType = mlir::TypeAttr::get(
136       llvm::cast<mlir::omp::PointerLikeType>(retTy).getElementType());
137 
138   mlir::omp::MapInfoOp op = builder.create<mlir::omp::MapInfoOp>(
139       loc, retTy, baseAddr, varType, varPtrPtr, members, membersIndex, bounds,
140       builder.getIntegerAttr(builder.getIntegerType(64, false), mapType),
141       builder.getAttr<mlir::omp::VariableCaptureKindAttr>(mapCaptureType),
142       builder.getStringAttr(name), builder.getBoolAttr(partialMap));
143 
144   return op;
145 }
146 
147 static int
148 getComponentPlacementInParent(const Fortran::semantics::Symbol *componentSym) {
149   const auto *derived =
150       componentSym->owner()
151           .derivedTypeSpec()
152           ->typeSymbol()
153           .detailsIf<Fortran::semantics::DerivedTypeDetails>();
154   assert(derived &&
155          "expected derived type details when processing component symbol");
156   for (auto [placement, name] : llvm::enumerate(derived->componentNames()))
157     if (name == componentSym->name())
158       return placement;
159   return -1;
160 }
161 
162 static std::optional<Object>
163 getComponentObject(std::optional<Object> object,
164                    Fortran::semantics::SemanticsContext &semaCtx) {
165   if (!object)
166     return std::nullopt;
167 
168   auto ref = evaluate::ExtractDataRef(*object.value().ref());
169   if (!ref)
170     return std::nullopt;
171 
172   if (std::holds_alternative<evaluate::Component>(ref->u))
173     return object;
174 
175   auto baseObj = getBaseObject(object.value(), semaCtx);
176   if (!baseObj)
177     return std::nullopt;
178 
179   return getComponentObject(baseObj.value(), semaCtx);
180 }
181 
182 static void
183 generateMemberPlacementIndices(const Object &object,
184                                llvm::SmallVectorImpl<int> &indices,
185                                Fortran::semantics::SemanticsContext &semaCtx) {
186   auto compObj = getComponentObject(object, semaCtx);
187   while (compObj) {
188     indices.push_back(getComponentPlacementInParent(compObj->id()));
189     compObj =
190         getComponentObject(getBaseObject(compObj.value(), semaCtx), semaCtx);
191   }
192 
193   indices = llvm::SmallVector<int>{llvm::reverse(indices)};
194 }
195 
196 void addChildIndexAndMapToParent(
197     const omp::Object &object,
198     std::map<const Fortran::semantics::Symbol *,
199              llvm::SmallVector<OmpMapMemberIndicesData>> &parentMemberIndices,
200     mlir::omp::MapInfoOp &mapOp,
201     Fortran::semantics::SemanticsContext &semaCtx) {
202   std::optional<Fortran::evaluate::DataRef> dataRef =
203       ExtractDataRef(object.designator);
204   assert(dataRef.has_value() &&
205          "DataRef could not be extracted during mapping of derived type "
206          "cannot proceed");
207   const Fortran::semantics::Symbol *parentSym = &dataRef->GetFirstSymbol();
208   assert(parentSym && "Could not find parent symbol during lower of "
209                       "a component member in OpenMP map clause");
210   llvm::SmallVector<int> indices;
211   generateMemberPlacementIndices(object, indices, semaCtx);
212   parentMemberIndices[parentSym].push_back({indices, mapOp});
213 }
214 
215 static void calculateShapeAndFillIndices(
216     llvm::SmallVectorImpl<int64_t> &shape,
217     llvm::SmallVectorImpl<OmpMapMemberIndicesData> &memberPlacementData) {
218   shape.push_back(memberPlacementData.size());
219   size_t largestIndicesSize =
220       std::max_element(memberPlacementData.begin(), memberPlacementData.end(),
221                        [](auto a, auto b) {
222                          return a.memberPlacementIndices.size() <
223                                 b.memberPlacementIndices.size();
224                        })
225           ->memberPlacementIndices.size();
226   shape.push_back(largestIndicesSize);
227 
228   // DenseElementsAttr expects a rectangular shape for the data, so all
229   // index lists have to be of the same length, this emplaces -1 as filler.
230   for (auto &v : memberPlacementData) {
231     if (v.memberPlacementIndices.size() < largestIndicesSize) {
232       auto *prevEnd = v.memberPlacementIndices.end();
233       v.memberPlacementIndices.resize(largestIndicesSize);
234       std::fill(prevEnd, v.memberPlacementIndices.end(), -1);
235     }
236   }
237 }
238 
239 static mlir::DenseIntElementsAttr createDenseElementsAttrFromIndices(
240     llvm::SmallVectorImpl<OmpMapMemberIndicesData> &memberPlacementData,
241     fir::FirOpBuilder &builder) {
242   llvm::SmallVector<int64_t> shape;
243   calculateShapeAndFillIndices(shape, memberPlacementData);
244 
245   llvm::SmallVector<int> indicesFlattened = std::accumulate(
246       memberPlacementData.begin(), memberPlacementData.end(),
247       llvm::SmallVector<int>(),
248       [](llvm::SmallVector<int> &x, OmpMapMemberIndicesData y) {
249         x.insert(x.end(), y.memberPlacementIndices.begin(),
250                  y.memberPlacementIndices.end());
251         return x;
252       });
253 
254   return mlir::DenseIntElementsAttr::get(
255       mlir::VectorType::get(shape,
256                             mlir::IntegerType::get(builder.getContext(), 32)),
257       indicesFlattened);
258 }
259 
260 void insertChildMapInfoIntoParent(
261     Fortran::lower::AbstractConverter &converter,
262     std::map<const Fortran::semantics::Symbol *,
263              llvm::SmallVector<OmpMapMemberIndicesData>> &parentMemberIndices,
264     llvm::SmallVectorImpl<mlir::Value> &mapOperands,
265     llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &mapSyms,
266     llvm::SmallVectorImpl<mlir::Type> *mapSymTypes,
267     llvm::SmallVectorImpl<mlir::Location> *mapSymLocs) {
268   for (auto indices : parentMemberIndices) {
269     bool parentExists = false;
270     size_t parentIdx;
271     for (parentIdx = 0; parentIdx < mapSyms.size(); ++parentIdx) {
272       if (mapSyms[parentIdx] == indices.first) {
273         parentExists = true;
274         break;
275       }
276     }
277 
278     if (parentExists) {
279       auto mapOp = llvm::cast<mlir::omp::MapInfoOp>(
280           mapOperands[parentIdx].getDefiningOp());
281 
282       // NOTE: To maintain appropriate SSA ordering, we move the parent map
283       // which will now have references to its children after the last
284       // of its members to be generated. This is necessary when a user
285       // has defined a series of parent and children maps where the parent
286       // precedes the children. An alternative, may be to do
287       // delayed generation of map info operations from the clauses and
288       // organize them first before generation.
289       mapOp->moveAfter(indices.second.back().memberMap);
290 
291       for (auto memberIndicesData : indices.second)
292         mapOp.getMembersMutable().append(
293             memberIndicesData.memberMap.getResult());
294 
295       mapOp.setMembersIndexAttr(createDenseElementsAttrFromIndices(
296           indices.second, converter.getFirOpBuilder()));
297     } else {
298       // NOTE: We take the map type of the first child, this may not
299       // be the correct thing to do, however, we shall see. For the moment
300       // it allows this to work with enter and exit without causing MLIR
301       // verification issues. The more appropriate thing may be to take
302       // the "main" map type clause from the directive being used.
303       uint64_t mapType = indices.second[0].memberMap.getMapType().value_or(0);
304 
305       // create parent to emplace and bind members
306       mlir::Value origSymbol = converter.getSymbolAddress(*indices.first);
307 
308       llvm::SmallVector<mlir::Value> members;
309       for (OmpMapMemberIndicesData memberIndicesData : indices.second)
310         members.push_back((mlir::Value)memberIndicesData.memberMap);
311 
312       mlir::Value mapOp = createMapInfoOp(
313           converter.getFirOpBuilder(), origSymbol.getLoc(), origSymbol,
314           /*varPtrPtr=*/mlir::Value(), indices.first->name().ToString(),
315           /*bounds=*/{}, members,
316           createDenseElementsAttrFromIndices(indices.second,
317                                              converter.getFirOpBuilder()),
318           mapType, mlir::omp::VariableCaptureKind::ByRef, origSymbol.getType(),
319           /*partialMap=*/true);
320 
321       mapOperands.push_back(mapOp);
322       mapSyms.push_back(indices.first);
323 
324       if (mapSymTypes)
325         mapSymTypes->push_back(mapOp.getType());
326       if (mapSymLocs)
327         mapSymLocs->push_back(mapOp.getLoc());
328     }
329   }
330 }
331 
332 Fortran::semantics::Symbol *
333 getOmpObjectSymbol(const Fortran::parser::OmpObject &ompObject) {
334   Fortran::semantics::Symbol *sym = nullptr;
335   std::visit(
336       Fortran::common::visitors{
337           [&](const Fortran::parser::Designator &designator) {
338             if (auto *arrayEle =
339                     Fortran::parser::Unwrap<Fortran::parser::ArrayElement>(
340                         designator)) {
341               // Use getLastName to retrieve the arrays symbol, this will
342               // provide the farthest right symbol (the last) in a designator,
343               // i.e. providing something like the following:
344               // "dtype1%dtype2%array[2:10]", will result in "array"
345               sym = GetLastName(arrayEle->base).symbol;
346             } else if (auto *structComp = Fortran::parser::Unwrap<
347                            Fortran::parser::StructureComponent>(designator)) {
348               sym = structComp->component.symbol;
349             } else if (const Fortran::parser::Name *name =
350                            Fortran::semantics::getDesignatorNameIfDataRef(
351                                designator)) {
352               sym = name->symbol;
353             }
354           },
355           [&](const Fortran::parser::Name &name) { sym = name.symbol; }},
356       ompObject.u);
357   return sym;
358 }
359 
360 } // namespace omp
361 } // namespace lower
362 } // namespace Fortran
363