xref: /llvm-project/flang/include/flang/Lower/IterationSpace.h (revision 6f8ef5ad2f35321257adbe353f86027bf5209023)
1 //===-- IterationSpace.h ----------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef FORTRAN_LOWER_ITERATIONSPACE_H
14 #define FORTRAN_LOWER_ITERATIONSPACE_H
15 
#include "flang/Evaluate/tools.h"
#include "flang/Lower/StatementContext.h"
#include "flang/Lower/SymbolMap.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include <functional>
#include <optional>
#include <tuple>
#include <utility>
#include <variant>
21 
22 namespace llvm {
23 class raw_ostream;
24 }
25 
26 namespace Fortran {
27 namespace evaluate {
28 struct SomeType;
29 template <typename>
30 class Expr;
31 } // namespace evaluate
32 
33 namespace lower {
34 
35 using FrontEndExpr = const evaluate::Expr<evaluate::SomeType> *;
36 using FrontEndSymbol = const semantics::Symbol *;
37 
38 class AbstractConverter;
39 
40 } // namespace lower
41 } // namespace Fortran
42 
43 namespace Fortran::lower {
44 
45 /// Abstraction of the iteration space for building the elemental compute loop
46 /// of an array(-like) statement.
47 class IterationSpace {
48 public:
49   IterationSpace() = default;
50 
51   template <typename A>
52   explicit IterationSpace(mlir::Value inArg, mlir::Value outRes,
53                           llvm::iterator_range<A> range)
54       : inArg{inArg}, outRes{outRes}, indices{range.begin(), range.end()} {}
55 
56   explicit IterationSpace(const IterationSpace &from,
57                           llvm::ArrayRef<mlir::Value> idxs)
58       : inArg(from.inArg), outRes(from.outRes), element(from.element),
59         indices(idxs) {}
60 
61   /// Create a copy of the \p from IterationSpace and prepend the \p prefix
62   /// values and append the \p suffix values, respectively.
63   explicit IterationSpace(const IterationSpace &from,
64                           llvm::ArrayRef<mlir::Value> prefix,
65                           llvm::ArrayRef<mlir::Value> suffix)
66       : inArg(from.inArg), outRes(from.outRes), element(from.element) {
67     indices.assign(prefix.begin(), prefix.end());
68     indices.append(from.indices.begin(), from.indices.end());
69     indices.append(suffix.begin(), suffix.end());
70   }
71 
72   bool empty() const { return indices.empty(); }
73 
74   /// This is the output value as it appears as an argument in the innermost
75   /// loop in the nest. The output value is threaded through the loop (and
76   /// conditionals) to maintain proper SSA form.
77   mlir::Value innerArgument() const { return inArg; }
78 
79   /// This is the output value as it appears as an output value from the
80   /// outermost loop in the loop nest. The output value is threaded through the
81   /// loop (and conditionals) to maintain proper SSA form.
82   mlir::Value outerResult() const { return outRes; }
83 
84   /// Returns a vector for the iteration space. This vector is used to access
85   /// elements of arrays in the compute loop.
86   llvm::SmallVector<mlir::Value> iterVec() const { return indices; }
87 
88   mlir::Value iterValue(std::size_t i) const {
89     assert(i < indices.size());
90     return indices[i];
91   }
92 
93   /// Set (rewrite) the Value at a given index.
94   void setIndexValue(std::size_t i, mlir::Value v) {
95     assert(i < indices.size());
96     indices[i] = v;
97   }
98 
99   void setIndexValues(llvm::ArrayRef<mlir::Value> vals) {
100     indices.assign(vals.begin(), vals.end());
101   }
102 
103   void insertIndexValue(std::size_t i, mlir::Value av) {
104     assert(i <= indices.size());
105     indices.insert(indices.begin() + i, av);
106   }
107 
108   /// Set the `element` value. This is the SSA value that corresponds to an
109   /// element of the resultant array value.
110   void setElement(fir::ExtendedValue &&ele) {
111     assert(!fir::getBase(element) && "result element already set");
112     element = ele;
113   }
114 
115   /// Get the value that will be merged into the resultant array. This is the
116   /// computed value that will be stored to the lhs of the assignment.
117   mlir::Value getElement() const {
118     assert(fir::getBase(element) && "element must be set");
119     return fir::getBase(element);
120   }
121 
122   /// Get the element as an extended value.
123   fir::ExtendedValue elementExv() const { return element; }
124 
125   void clearIndices() { indices.clear(); }
126 
127 private:
128   mlir::Value inArg;
129   mlir::Value outRes;
130   fir::ExtendedValue element;
131   llvm::SmallVector<mlir::Value> indices;
132 };
133 
134 using GenerateElementalArrayFunc =
135     std::function<fir::ExtendedValue(const IterationSpace &)>;
136 
137 template <typename A>
138 class StackableConstructExpr {
139 public:
140   bool empty() const { return stack.empty(); }
141 
142   void growStack() { stack.push_back(A{}); }
143 
144   /// Bind a front-end expression to a closure.
145   void bind(FrontEndExpr e, GenerateElementalArrayFunc &&fun) {
146     vmap.insert({e, std::move(fun)});
147   }
148 
149   /// Replace the binding of front-end expression `e` with a new closure.
150   void rebind(FrontEndExpr e, GenerateElementalArrayFunc &&fun) {
151     vmap.erase(e);
152     bind(e, std::move(fun));
153   }
154 
155   /// Get the closure bound to the front-end expression, `e`.
156   GenerateElementalArrayFunc getBoundClosure(FrontEndExpr e) const {
157     if (!vmap.count(e))
158       llvm::report_fatal_error(
159           "evaluate::Expr is not in the map of lowered mask expressions");
160     return vmap.lookup(e);
161   }
162 
163   /// Has the front-end expression, `e`, been lowered and bound?
164   bool isLowered(FrontEndExpr e) const { return vmap.count(e); }
165 
166   StatementContext &stmtContext() { return stmtCtx; }
167 
168 protected:
169   void shrinkStack() {
170     assert(!empty());
171     stack.pop_back();
172     if (empty()) {
173       stmtCtx.finalizeAndReset();
174       vmap.clear();
175     }
176   }
177 
178   // The stack for the construct information.
179   llvm::SmallVector<A> stack;
180 
181   // Map each mask expression back to the temporary holding the initial
182   // evaluation results.
183   llvm::DenseMap<FrontEndExpr, GenerateElementalArrayFunc> vmap;
184 
185   // Inflate the statement context for the entire construct. We have to cache
186   // the mask expression results, which are always evaluated first, across the
187   // entire construct.
188   StatementContext stmtCtx;
189 };
190 
191 class ImplicitIterSpace;
192 llvm::raw_ostream &operator<<(llvm::raw_ostream &, const ImplicitIterSpace &);
193 
194 /// All array expressions have an implicit iteration space, which is isomorphic
195 /// to the shape of the base array that facilitates the expression having a
196 /// non-zero rank. This implied iteration space may be conditionalized
197 /// (disjunctively) with an if-elseif-else like structure, specifically
198 /// Fortran's WHERE construct.
199 ///
200 /// This class is used in the bridge to collect the expressions from the
201 /// front end (the WHERE construct mask expressions), forward them for lowering
202 /// as array expressions in an "evaluate once" (copy-in, copy-out) semantics.
203 /// See 10.2.3.2p3, 10.2.3.2p13, etc.
204 class ImplicitIterSpace
205     : public StackableConstructExpr<llvm::SmallVector<FrontEndExpr>> {
206 public:
207   using Base = StackableConstructExpr<llvm::SmallVector<FrontEndExpr>>;
208   using FrontEndMaskExpr = FrontEndExpr;
209 
210   friend llvm::raw_ostream &operator<<(llvm::raw_ostream &,
211                                        const ImplicitIterSpace &);
212 
213   LLVM_DUMP_METHOD void dump() const;
214 
215   void append(FrontEndMaskExpr e) {
216     assert(!empty());
217     getMasks().back().push_back(e);
218   }
219 
220   llvm::SmallVector<FrontEndMaskExpr> getExprs() const {
221     llvm::SmallVector<FrontEndMaskExpr> maskList = getMasks()[0];
222     for (size_t i = 1, d = getMasks().size(); i < d; ++i)
223       maskList.append(getMasks()[i].begin(), getMasks()[i].end());
224     return maskList;
225   }
226 
227   /// Add a variable binding, `var`, along with its shape for the mask
228   /// expression `exp`.
229   void addMaskVariable(FrontEndExpr exp, mlir::Value var, mlir::Value shape,
230                        mlir::Value header) {
231     maskVarMap.try_emplace(exp, std::make_tuple(var, shape, header));
232   }
233 
234   /// Lookup the variable corresponding to the temporary buffer that contains
235   /// the mask array expression results.
236   mlir::Value lookupMaskVariable(FrontEndExpr exp) {
237     return std::get<0>(maskVarMap.lookup(exp));
238   }
239 
240   /// Lookup the variable containing the shape vector for the mask array
241   /// expression results.
242   mlir::Value lookupMaskShapeBuffer(FrontEndExpr exp) {
243     return std::get<1>(maskVarMap.lookup(exp));
244   }
245 
246   mlir::Value lookupMaskHeader(FrontEndExpr exp) {
247     return std::get<2>(maskVarMap.lookup(exp));
248   }
249 
250   // Stack of WHERE constructs, each building a list of mask expressions.
251   llvm::SmallVector<llvm::SmallVector<FrontEndMaskExpr>> &getMasks() {
252     return stack;
253   }
254   const llvm::SmallVector<llvm::SmallVector<FrontEndMaskExpr>> &
255   getMasks() const {
256     return stack;
257   }
258 
259   // Cleanup at the end of a WHERE statement or construct.
260   void shrinkStack() {
261     Base::shrinkStack();
262     if (stack.empty())
263       maskVarMap.clear();
264   }
265 
266 private:
267   llvm::DenseMap<FrontEndExpr,
268                  std::tuple<mlir::Value, mlir::Value, mlir::Value>>
269       maskVarMap;
270 };
271 
272 class ExplicitIterSpace;
273 llvm::raw_ostream &operator<<(llvm::raw_ostream &, const ExplicitIterSpace &);
274 
275 /// Create all the array_load ops for the explicit iteration space context. The
276 /// nest of FORALLs must have been analyzed a priori.
277 void createArrayLoads(AbstractConverter &converter, ExplicitIterSpace &esp,
278                       SymMap &symMap);
279 
280 /// Create the array_merge_store ops after the explicit iteration space context
281 /// is conmpleted.
282 void createArrayMergeStores(AbstractConverter &converter,
283                             ExplicitIterSpace &esp);
284 using ExplicitSpaceArrayBases =
285     std::variant<FrontEndSymbol, const evaluate::Component *,
286                  const evaluate::ArrayRef *>;
287 
288 unsigned getHashValue(const ExplicitSpaceArrayBases &x);
289 bool isEqual(const ExplicitSpaceArrayBases &x,
290              const ExplicitSpaceArrayBases &y);
291 
292 } // namespace Fortran::lower
293 
294 namespace llvm {
295 template <>
296 struct DenseMapInfo<Fortran::lower::ExplicitSpaceArrayBases> {
297   static inline Fortran::lower::ExplicitSpaceArrayBases getEmptyKey() {
298     return reinterpret_cast<Fortran::lower::FrontEndSymbol>(~0);
299   }
300   static inline Fortran::lower::ExplicitSpaceArrayBases getTombstoneKey() {
301     return reinterpret_cast<Fortran::lower::FrontEndSymbol>(~0 - 1);
302   }
303   static unsigned
304   getHashValue(const Fortran::lower::ExplicitSpaceArrayBases &v) {
305     return Fortran::lower::getHashValue(v);
306   }
307   static bool isEqual(const Fortran::lower::ExplicitSpaceArrayBases &lhs,
308                       const Fortran::lower::ExplicitSpaceArrayBases &rhs) {
309     return Fortran::lower::isEqual(lhs, rhs);
310   }
311 };
312 } // namespace llvm
313 
314 namespace Fortran::lower {
315 /// Fortran also allows arrays to be evaluated under constructs which allow the
316 /// user to explicitly specify the iteration space using concurrent-control
317 /// expressions. These constructs allow the user to define both an iteration
318 /// space and explicit access vectors on arrays. These need not be isomorphic.
319 /// The explicit iteration spaces may be conditionalized (conjunctively) with an
320 /// "and" structure and may be found in FORALL (and DO CONCURRENT) constructs.
321 ///
322 /// This class is used in the bridge to collect a stack of lists of
323 /// concurrent-control expressions to be used to generate the iteration space
324 /// and associated masks (if any) for a set of nested FORALL constructs around
325 /// assignment and WHERE constructs.
326 class ExplicitIterSpace {
327 public:
328   using IterSpaceDim =
329       std::tuple<FrontEndSymbol, FrontEndExpr, FrontEndExpr, FrontEndExpr>;
330   using ConcurrentSpec =
331       std::pair<llvm::SmallVector<IterSpaceDim>, FrontEndExpr>;
332   using ArrayBases = ExplicitSpaceArrayBases;
333 
334   friend void createArrayLoads(AbstractConverter &converter,
335                                ExplicitIterSpace &esp, SymMap &symMap);
336   friend void createArrayMergeStores(AbstractConverter &converter,
337                                      ExplicitIterSpace &esp);
338 
339   /// Is a FORALL context presently active?
340   /// If we are lowering constructs/statements nested within a FORALL, then a
341   /// FORALL context is active.
342   bool isActive() const { return forallContextOpen != 0; }
343 
344   /// Get the statement context.
345   StatementContext &stmtContext() { return stmtCtx; }
346 
347   //===--------------------------------------------------------------------===//
348   // Analysis support
349   //===--------------------------------------------------------------------===//
350 
351   /// Open a new construct. The analysis phase starts here.
352   void pushLevel();
353 
354   /// Close the construct.
355   void popLevel();
356 
357   /// Add new concurrent header control variable symbol.
358   void addSymbol(FrontEndSymbol sym);
359 
360   /// Collect array bases from the expression, `x`.
361   void exprBase(FrontEndExpr x, bool lhs);
362 
363   /// Called at the end of a assignment statement.
364   void endAssign();
365 
366   /// Return all the active control variables on the stack.
367   llvm::SmallVector<FrontEndSymbol> collectAllSymbols();
368 
369   //===--------------------------------------------------------------------===//
370   // Code gen support
371   //===--------------------------------------------------------------------===//
372 
373   /// Enter a FORALL context.
374   void enter() { forallContextOpen++; }
375 
376   /// Leave a FORALL context.
377   void leave();
378 
379   void pushLoopNest(std::function<void()> lambda) {
380     ccLoopNest.push_back(lambda);
381   }
382 
383   /// Get the inner arguments that correspond to the output arrays.
384   mlir::ValueRange getInnerArgs() const { return innerArgs; }
385 
386   /// Set the inner arguments for the next loop level.
387   void setInnerArgs(llvm::ArrayRef<mlir::BlockArgument> args) {
388     innerArgs.clear();
389     for (auto &arg : args)
390       innerArgs.push_back(arg);
391   }
392 
393   /// Reset the outermost `array_load` arguments to the loop nest.
394   void resetInnerArgs() { innerArgs = initialArgs; }
395 
396   /// Capture the current outermost loop.
397   void setOuterLoop(fir::DoLoopOp loop) {
398     clearLoops();
399     outerLoop = loop;
400   }
401 
402   /// Sets the inner loop argument at position \p offset to \p val.
403   void setInnerArg(size_t offset, mlir::Value val) {
404     assert(offset < innerArgs.size());
405     innerArgs[offset] = val;
406   }
407 
408   /// Get the types of the output arrays.
409   llvm::SmallVector<mlir::Type> innerArgTypes() const {
410     llvm::SmallVector<mlir::Type> result;
411     for (auto &arg : innerArgs)
412       result.push_back(arg.getType());
413     return result;
414   }
415 
416   /// Create a binding between an Ev::Expr node pointer and a fir::array_load
417   /// op. This bindings will be used when generating the IR.
418   void bindLoad(ArrayBases base, fir::ArrayLoadOp load) {
419     loadBindings.try_emplace(std::move(base), load);
420   }
421 
422   fir::ArrayLoadOp findBinding(const ArrayBases &base) {
423     return loadBindings.lookup(base);
424   }
425 
426   /// `load` must be a LHS array_load. Returns `std::nullopt` on error.
427   std::optional<size_t> findArgPosition(fir::ArrayLoadOp load);
428 
429   bool isLHS(fir::ArrayLoadOp load) {
430     return findArgPosition(load).has_value();
431   }
432 
433   /// `load` must be a LHS array_load. Determine the threaded inner argument
434   /// corresponding to this load.
435   mlir::Value findArgumentOfLoad(fir::ArrayLoadOp load) {
436     if (auto opt = findArgPosition(load))
437       return innerArgs[*opt];
438     llvm_unreachable("array load argument not found");
439   }
440 
441   size_t argPosition(mlir::Value arg) {
442     for (auto i : llvm::enumerate(innerArgs))
443       if (arg == i.value())
444         return i.index();
445     llvm_unreachable("inner argument value was not found");
446   }
447 
448   std::optional<fir::ArrayLoadOp> getLhsLoad(size_t i) {
449     assert(i < lhsBases.size());
450     if (lhsBases[counter])
451       return findBinding(*lhsBases[counter]);
452     return std::nullopt;
453   }
454 
455   /// Return the outermost loop in this FORALL nest.
456   fir::DoLoopOp getOuterLoop() {
457     assert(outerLoop.has_value());
458     return *outerLoop;
459   }
460 
461   /// Return the statement context for the entire, outermost FORALL construct.
462   StatementContext &outermostContext() { return outerContext; }
463 
464   /// Generate the explicit loop nest.
465   void genLoopNest() {
466     for (auto &lambda : ccLoopNest)
467       lambda();
468   }
469 
470   /// Clear the array_load bindings.
471   void resetBindings() { loadBindings.clear(); }
472 
473   /// Get the current counter value.
474   std::size_t getCounter() const { return counter; }
475 
476   /// Increment the counter value to the next assignment statement.
477   void incrementCounter() { counter++; }
478 
479   bool isOutermostForall() const {
480     assert(forallContextOpen);
481     return forallContextOpen == 1;
482   }
483 
484   void attachLoopCleanup(std::function<void(fir::FirOpBuilder &builder)> fn) {
485     if (!loopCleanup) {
486       loopCleanup = fn;
487       return;
488     }
489     std::function<void(fir::FirOpBuilder &)> oldFn = *loopCleanup;
490     loopCleanup = [=](fir::FirOpBuilder &builder) {
491       oldFn(builder);
492       fn(builder);
493     };
494   }
495 
496   // LLVM standard dump method.
497   LLVM_DUMP_METHOD void dump() const;
498 
499   // Pretty-print.
500   friend llvm::raw_ostream &operator<<(llvm::raw_ostream &,
501                                        const ExplicitIterSpace &);
502 
503   /// Finalize the current body statement context.
504   void finalizeContext() { stmtCtx.finalizeAndReset(); }
505 
506   void appendLoops(const llvm::SmallVector<fir::DoLoopOp> &loops) {
507     loopStack.push_back(loops);
508   }
509 
510   void clearLoops() { loopStack.clear(); }
511 
512   llvm::SmallVector<llvm::SmallVector<fir::DoLoopOp>> getLoopStack() const {
513     return loopStack;
514   }
515 
516 private:
517   /// Cleanup the analysis results.
518   void conditionalCleanup();
519 
520   StatementContext outerContext;
521 
522   // A stack of lists of front-end symbols.
523   llvm::SmallVector<llvm::SmallVector<FrontEndSymbol>> symbolStack;
524   llvm::SmallVector<std::optional<ArrayBases>> lhsBases;
525   llvm::SmallVector<llvm::SmallVector<ArrayBases>> rhsBases;
526   llvm::DenseMap<ArrayBases, fir::ArrayLoadOp> loadBindings;
527 
528   // Stack of lambdas to create the loop nest.
529   llvm::SmallVector<std::function<void()>> ccLoopNest;
530 
531   // Assignment statement context (inside the loop nest).
532   StatementContext stmtCtx;
533   llvm::SmallVector<mlir::Value> innerArgs;
534   llvm::SmallVector<mlir::Value> initialArgs;
535   std::optional<fir::DoLoopOp> outerLoop;
536   llvm::SmallVector<llvm::SmallVector<fir::DoLoopOp>> loopStack;
537   std::optional<std::function<void(fir::FirOpBuilder &)>> loopCleanup;
538   std::size_t forallContextOpen = 0;
539   std::size_t counter = 0;
540 };
541 
542 /// Is there a Symbol in common between the concurrent header set and the set
543 /// of symbols in the expression?
544 template <typename A>
545 bool symbolSetsIntersect(llvm::ArrayRef<FrontEndSymbol> ctrlSet,
546                          const A &exprSyms) {
547   for (const auto &sym : exprSyms)
548     if (llvm::is_contained(ctrlSet, &sym.get()))
549       return true;
550   return false;
551 }
552 
553 /// Determine if the subscript expression symbols from an Ev::ArrayRef
554 /// intersects with the set of concurrent control symbols, `ctrlSet`.
555 template <typename A>
556 bool symbolsIntersectSubscripts(llvm::ArrayRef<FrontEndSymbol> ctrlSet,
557                                 const A &subscripts) {
558   for (auto &sub : subscripts) {
559     if (const auto *expr =
560             std::get_if<evaluate::IndirectSubscriptIntegerExpr>(&sub.u))
561       if (symbolSetsIntersect(ctrlSet, evaluate::CollectSymbols(expr->value())))
562         return true;
563   }
564   return false;
565 }
566 
567 } // namespace Fortran::lower
568 
569 #endif // FORTRAN_LOWER_ITERATIONSPACE_H
570