xref: /llvm-project/flang/lib/Lower/IterationSpace.cpp (revision 6f8ef5ad2f35321257adbe353f86027bf5209023)
1 //===-- IterationSpace.cpp ------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "flang/Lower/IterationSpace.h"
14 #include "flang/Evaluate/expression.h"
15 #include "flang/Lower/AbstractConverter.h"
16 #include "flang/Lower/Support/Utils.h"
17 #include "llvm/Support/Debug.h"
18 #include <optional>
19 
20 #define DEBUG_TYPE "flang-lower-iteration-space"
21 
22 unsigned Fortran::lower::getHashValue(
23     const Fortran::lower::ExplicitIterSpace::ArrayBases &x) {
24   return Fortran::common::visit(
25       [&](const auto *p) { return HashEvaluateExpr::getHashValue(*p); }, x);
26 }
27 
28 bool Fortran::lower::isEqual(
29     const Fortran::lower::ExplicitIterSpace::ArrayBases &x,
30     const Fortran::lower::ExplicitIterSpace::ArrayBases &y) {
31   return Fortran::common::visit(
32       Fortran::common::visitors{
33           // Fortran::semantics::Symbol * are the exception here. These pointers
34           // have identity; if two Symbol * values are the same (different) then
35           // they are the same (different) logical symbol.
36           [&](Fortran::lower::FrontEndSymbol p,
37               Fortran::lower::FrontEndSymbol q) { return p == q; },
38           [&](const auto *p, const auto *q) {
39             if constexpr (std::is_same_v<decltype(p), decltype(q)>) {
40               LLVM_DEBUG(llvm::dbgs()
41                          << "is equal: " << p << ' ' << q << ' '
42                          << IsEqualEvaluateExpr::isEqual(*p, *q) << '\n');
43               return IsEqualEvaluateExpr::isEqual(*p, *q);
44             } else {
45               // Different subtree types are never equal.
46               return false;
47             }
48           }},
49       x, y);
50 }
51 
52 namespace {
53 
54 /// This class can recover the base array in an expression that contains
55 /// explicit iteration space symbols. Most of the class can be ignored as it is
56 /// boilerplate Fortran::evaluate::Expr traversal.
57 class ArrayBaseFinder {
58 public:
59   using RT = bool;
60 
61   ArrayBaseFinder(llvm::ArrayRef<Fortran::lower::FrontEndSymbol> syms)
62       : controlVars(syms) {}
63 
64   template <typename T>
65   void operator()(const T &x) {
66     (void)find(x);
67   }
68 
69   /// Get the list of bases.
70   llvm::ArrayRef<Fortran::lower::ExplicitIterSpace::ArrayBases>
71   getBases() const {
72     LLVM_DEBUG(llvm::dbgs()
73                << "number of array bases found: " << bases.size() << '\n');
74     return bases;
75   }
76 
77 private:
78   // First, the cases that are of interest.
79   RT find(const Fortran::semantics::Symbol &symbol) {
80     if (symbol.Rank() > 0) {
81       bases.push_back(&symbol);
82       return true;
83     }
84     return {};
85   }
86   RT find(const Fortran::evaluate::Component &x) {
87     auto found = find(x.base());
88     if (!found && x.base().Rank() == 0 && x.Rank() > 0) {
89       bases.push_back(&x);
90       return true;
91     }
92     return found;
93   }
94   RT find(const Fortran::evaluate::ArrayRef &x) {
95     for (const auto &sub : x.subscript())
96       (void)find(sub);
97     if (x.base().IsSymbol()) {
98       if (x.Rank() > 0 || intersection(x.subscript())) {
99         bases.push_back(&x);
100         return true;
101       }
102       return {};
103     }
104     auto found = find(x.base());
105     if (!found && ((x.base().Rank() == 0 && x.Rank() > 0) ||
106                    intersection(x.subscript()))) {
107       bases.push_back(&x);
108       return true;
109     }
110     return found;
111   }
112   RT find(const Fortran::evaluate::Triplet &x) {
113     if (const auto *lower = x.GetLower())
114       (void)find(*lower);
115     if (const auto *upper = x.GetUpper())
116       (void)find(*upper);
117     return find(x.GetStride());
118   }
119   RT find(const Fortran::evaluate::IndirectSubscriptIntegerExpr &x) {
120     return find(x.value());
121   }
122   RT find(const Fortran::evaluate::Subscript &x) { return find(x.u); }
123   RT find(const Fortran::evaluate::DataRef &x) { return find(x.u); }
124   RT find(const Fortran::evaluate::CoarrayRef &x) {
125     assert(false && "coarray reference");
126     return {};
127   }
128 
129   template <typename A>
130   bool intersection(const A &subscripts) {
131     return Fortran::lower::symbolsIntersectSubscripts(controlVars, subscripts);
132   }
133 
134   // The rest is traversal boilerplate and can be ignored.
135   RT find(const Fortran::evaluate::Substring &x) { return find(x.parent()); }
136   template <typename A>
137   RT find(const Fortran::semantics::SymbolRef x) {
138     return find(*x);
139   }
140   RT find(const Fortran::evaluate::NamedEntity &x) {
141     if (x.IsSymbol())
142       return find(x.GetFirstSymbol());
143     return find(x.GetComponent());
144   }
145 
146   template <typename A, bool C>
147   RT find(const Fortran::common::Indirection<A, C> &x) {
148     return find(x.value());
149   }
150   template <typename A>
151   RT find(const std::unique_ptr<A> &x) {
152     return find(x.get());
153   }
154   template <typename A>
155   RT find(const std::shared_ptr<A> &x) {
156     return find(x.get());
157   }
158   template <typename A>
159   RT find(const A *x) {
160     if (x)
161       return find(*x);
162     return {};
163   }
164   template <typename A>
165   RT find(const std::optional<A> &x) {
166     if (x)
167       return find(*x);
168     return {};
169   }
170   template <typename... A>
171   RT find(const std::variant<A...> &u) {
172     return Fortran::common::visit([&](const auto &v) { return find(v); }, u);
173   }
174   template <typename A>
175   RT find(const std::vector<A> &x) {
176     for (auto &v : x)
177       (void)find(v);
178     return {};
179   }
180   RT find(const Fortran::evaluate::BOZLiteralConstant &) { return {}; }
181   RT find(const Fortran::evaluate::NullPointer &) { return {}; }
182   template <typename T>
183   RT find(const Fortran::evaluate::Constant<T> &x) {
184     return {};
185   }
186   RT find(const Fortran::evaluate::StaticDataObject &) { return {}; }
187   RT find(const Fortran::evaluate::ImpliedDoIndex &) { return {}; }
188   RT find(const Fortran::evaluate::BaseObject &x) {
189     (void)find(x.u);
190     return {};
191   }
192   RT find(const Fortran::evaluate::TypeParamInquiry &) { return {}; }
193   RT find(const Fortran::evaluate::ComplexPart &x) { return {}; }
194   template <typename T>
195   RT find(const Fortran::evaluate::Designator<T> &x) {
196     return find(x.u);
197   }
198   template <typename T>
199   RT find(const Fortran::evaluate::Variable<T> &x) {
200     return find(x.u);
201   }
202   RT find(const Fortran::evaluate::DescriptorInquiry &) { return {}; }
203   RT find(const Fortran::evaluate::SpecificIntrinsic &) { return {}; }
204   RT find(const Fortran::evaluate::ProcedureDesignator &x) { return {}; }
205   RT find(const Fortran::evaluate::ProcedureRef &x) {
206     (void)find(x.proc());
207     if (x.IsElemental())
208       (void)find(x.arguments());
209     return {};
210   }
211   RT find(const Fortran::evaluate::ActualArgument &x) {
212     if (const auto *sym = x.GetAssumedTypeDummy())
213       (void)find(*sym);
214     else
215       (void)find(x.UnwrapExpr());
216     return {};
217   }
218   template <typename T>
219   RT find(const Fortran::evaluate::FunctionRef<T> &x) {
220     (void)find(static_cast<const Fortran::evaluate::ProcedureRef &>(x));
221     return {};
222   }
223   template <typename T>
224   RT find(const Fortran::evaluate::ArrayConstructorValue<T> &) {
225     return {};
226   }
227   template <typename T>
228   RT find(const Fortran::evaluate::ArrayConstructorValues<T> &) {
229     return {};
230   }
231   template <typename T>
232   RT find(const Fortran::evaluate::ImpliedDo<T> &) {
233     return {};
234   }
235   RT find(const Fortran::semantics::ParamValue &) { return {}; }
236   RT find(const Fortran::semantics::DerivedTypeSpec &) { return {}; }
237   RT find(const Fortran::evaluate::StructureConstructor &) { return {}; }
238   template <typename D, typename R, typename O>
239   RT find(const Fortran::evaluate::Operation<D, R, O> &op) {
240     (void)find(op.left());
241     return false;
242   }
243   template <typename D, typename R, typename LO, typename RO>
244   RT find(const Fortran::evaluate::Operation<D, R, LO, RO> &op) {
245     (void)find(op.left());
246     (void)find(op.right());
247     return false;
248   }
249   RT find(const Fortran::evaluate::Relational<Fortran::evaluate::SomeType> &x) {
250     (void)find(x.u);
251     return {};
252   }
253   template <typename T>
254   RT find(const Fortran::evaluate::Expr<T> &x) {
255     (void)find(x.u);
256     return {};
257   }
258 
259   llvm::SmallVector<Fortran::lower::ExplicitIterSpace::ArrayBases> bases;
260   llvm::SmallVector<Fortran::lower::FrontEndSymbol> controlVars;
261 };
262 
263 } // namespace
264 
265 void Fortran::lower::ExplicitIterSpace::leave() {
266   ccLoopNest.pop_back();
267   --forallContextOpen;
268   conditionalCleanup();
269 }
270 
271 void Fortran::lower::ExplicitIterSpace::addSymbol(
272     Fortran::lower::FrontEndSymbol sym) {
273   assert(!symbolStack.empty());
274   symbolStack.back().push_back(sym);
275 }
276 
277 void Fortran::lower::ExplicitIterSpace::exprBase(Fortran::lower::FrontEndExpr x,
278                                                  bool lhs) {
279   ArrayBaseFinder finder(collectAllSymbols());
280   finder(*x);
281   llvm::ArrayRef<Fortran::lower::ExplicitIterSpace::ArrayBases> bases =
282       finder.getBases();
283   if (rhsBases.empty())
284     endAssign();
285   if (lhs) {
286     if (bases.empty()) {
287       lhsBases.push_back(std::nullopt);
288       return;
289     }
290     assert(bases.size() >= 1 && "must detect an array reference on lhs");
291     if (bases.size() > 1)
292       rhsBases.back().append(bases.begin(), bases.end() - 1);
293     lhsBases.push_back(bases.back());
294     return;
295   }
296   rhsBases.back().append(bases.begin(), bases.end());
297 }
298 
299 void Fortran::lower::ExplicitIterSpace::endAssign() { rhsBases.emplace_back(); }
300 
301 void Fortran::lower::ExplicitIterSpace::pushLevel() {
302   symbolStack.push_back(llvm::SmallVector<Fortran::lower::FrontEndSymbol>{});
303 }
304 
305 void Fortran::lower::ExplicitIterSpace::popLevel() { symbolStack.pop_back(); }
306 
307 void Fortran::lower::ExplicitIterSpace::conditionalCleanup() {
308   if (forallContextOpen == 0) {
309     // Exiting the outermost FORALL context.
310     // Cleanup any residual mask buffers.
311     outermostContext().finalizeAndReset();
312     // Clear and reset all the cached information.
313     symbolStack.clear();
314     lhsBases.clear();
315     rhsBases.clear();
316     loadBindings.clear();
317     ccLoopNest.clear();
318     innerArgs.clear();
319     outerLoop = std::nullopt;
320     clearLoops();
321     counter = 0;
322   }
323 }
324 
325 std::optional<size_t>
326 Fortran::lower::ExplicitIterSpace::findArgPosition(fir::ArrayLoadOp load) {
327   if (lhsBases[counter]) {
328     auto ld = loadBindings.find(*lhsBases[counter]);
329     std::optional<size_t> optPos;
330     if (ld != loadBindings.end() && ld->second == load)
331       optPos = static_cast<size_t>(0u);
332     assert(optPos.has_value() && "load does not correspond to lhs");
333     return optPos;
334   }
335   return std::nullopt;
336 }
337 
338 llvm::SmallVector<Fortran::lower::FrontEndSymbol>
339 Fortran::lower::ExplicitIterSpace::collectAllSymbols() {
340   llvm::SmallVector<Fortran::lower::FrontEndSymbol> result;
341   for (llvm::SmallVector<FrontEndSymbol> vec : symbolStack)
342     result.append(vec.begin(), vec.end());
343   return result;
344 }
345 
346 llvm::raw_ostream &
347 Fortran::lower::operator<<(llvm::raw_ostream &s,
348                            const Fortran::lower::ImplicitIterSpace &e) {
349   for (const llvm::SmallVector<
350            Fortran::lower::ImplicitIterSpace::FrontEndMaskExpr> &xs :
351        e.getMasks()) {
352     s << "{ ";
353     for (const Fortran::lower::ImplicitIterSpace::FrontEndMaskExpr &x : xs)
354       x->AsFortran(s << '(') << "), ";
355     s << "}\n";
356   }
357   return s;
358 }
359 
360 llvm::raw_ostream &
361 Fortran::lower::operator<<(llvm::raw_ostream &s,
362                            const Fortran::lower::ExplicitIterSpace &e) {
363   auto dump = [&](const auto &u) {
364     Fortran::common::visit(
365         Fortran::common::visitors{
366             [&](const Fortran::semantics::Symbol *y) {
367               s << "  " << *y << '\n';
368             },
369             [&](const Fortran::evaluate::ArrayRef *y) {
370               s << "  ";
371               if (y->base().IsSymbol())
372                 s << y->base().GetFirstSymbol();
373               else
374                 s << y->base().GetComponent().GetLastSymbol();
375               s << '\n';
376             },
377             [&](const Fortran::evaluate::Component *y) {
378               s << "  " << y->GetLastSymbol() << '\n';
379             }},
380         u);
381   };
382   s << "LHS bases:\n";
383   for (const std::optional<Fortran::lower::ExplicitIterSpace::ArrayBases> &u :
384        e.lhsBases)
385     if (u)
386       dump(*u);
387   s << "RHS bases:\n";
388   for (const llvm::SmallVector<Fortran::lower::ExplicitIterSpace::ArrayBases>
389            &bases : e.rhsBases) {
390     for (const Fortran::lower::ExplicitIterSpace::ArrayBases &u : bases)
391       dump(u);
392     s << '\n';
393   }
394   return s;
395 }
396 
397 void Fortran::lower::ImplicitIterSpace::dump() const {
398   llvm::errs() << *this << '\n';
399 }
400 
401 void Fortran::lower::ExplicitIterSpace::dump() const {
402   llvm::errs() << *this << '\n';
403 }
404