xref: /llvm-project/flang/include/flang/Evaluate/traverse.h (revision ebec4d6369cbf9bbd64236b02d90e8f3597ad103)
1 //===-- include/flang/Evaluate/traverse.h -----------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef FORTRAN_EVALUATE_TRAVERSE_H_
10 #define FORTRAN_EVALUATE_TRAVERSE_H_
11 
12 // A utility for scanning all of the constituent objects in an Expr<>
13 // expression representation using a collection of mutually recursive
14 // functions to compose a function object.
15 //
16 // The class template Traverse<> below implements a function object that
17 // can handle every type that can appear in or around an Expr<>.
18 // Each of its overloads for operator() should be viewed as a *default*
19 // handler; some of these must be overridden by the client to accomplish
20 // its particular task.
21 //
// The client (Visitor) of Traverse<Visitor,Result> must define:
// - a member function "Result Default();"
// - a member function "Result Combine(Result &&, Result &&);"
// - overrides for "Result operator()" for the node types of interest
26 //
27 // Boilerplate classes also appear below to ease construction of visitors.
28 // See CheckSpecificationExpr() in check-expression.cpp for an example client.
29 //
30 // How this works:
31 // - The operator() overloads in Traverse<> invoke the visitor's Default() for
32 //   expression leaf nodes.  They invoke the visitor's operator() for the
33 //   subtrees of interior nodes, and the visitor's Combine() to merge their
34 //   results together.
35 // - Overloads of operator() in each visitor handle the cases of interest.
36 //
// The default handler for semantics::Symbol will descend into the associated
// expression of an ASSOCIATE (or related) construct entity, unless the
// TraverseAssocEntityDetails template argument is false.
39 
40 #include "expression.h"
41 #include "flang/Common/indirection.h"
42 #include "flang/Semantics/symbol.h"
43 #include "flang/Semantics/type.h"
44 #include <set>
45 #include <type_traits>
46 
47 namespace Fortran::evaluate {
48 template <typename Visitor, typename Result,
49     bool TraverseAssocEntityDetails = true>
50 class Traverse {
51 public:
52   explicit Traverse(Visitor &v) : visitor_{v} {}
53 
54   // Packaging
55   template <typename A, bool C>
56   Result operator()(const common::Indirection<A, C> &x) const {
57     return visitor_(x.value());
58   }
59   template <typename A>
60   Result operator()(const common::ForwardOwningPointer<A> &p) const {
61     return visitor_(p.get());
62   }
63   template <typename _> Result operator()(const SymbolRef x) const {
64     return visitor_(*x);
65   }
66   template <typename A> Result operator()(const std::unique_ptr<A> &x) const {
67     return visitor_(x.get());
68   }
69   template <typename A> Result operator()(const std::shared_ptr<A> &x) const {
70     return visitor_(x.get());
71   }
72   template <typename A> Result operator()(const A *x) const {
73     if (x) {
74       return visitor_(*x);
75     } else {
76       return visitor_.Default();
77     }
78   }
79   template <typename A> Result operator()(const std::optional<A> &x) const {
80     if (x) {
81       return visitor_(*x);
82     } else {
83       return visitor_.Default();
84     }
85   }
86   template <typename... As>
87   Result operator()(const std::variant<As...> &u) const {
88     return common::visit([=](const auto &y) { return visitor_(y); }, u);
89   }
90   template <typename A> Result operator()(const std::vector<A> &x) const {
91     return CombineContents(x);
92   }
93   template <typename A, typename B>
94   Result operator()(const std::pair<A, B> &x) const {
95     return Combine(x.first, x.second);
96   }
97 
98   // Leaves
99   Result operator()(const BOZLiteralConstant &) const {
100     return visitor_.Default();
101   }
102   Result operator()(const NullPointer &) const { return visitor_.Default(); }
103   template <typename T> Result operator()(const Constant<T> &x) const {
104     if constexpr (T::category == TypeCategory::Derived) {
105       return visitor_.Combine(
106           visitor_(x.result().derivedTypeSpec()), CombineContents(x.values()));
107     } else {
108       return visitor_.Default();
109     }
110   }
111   Result operator()(const Symbol &symbol) const {
112     const Symbol &ultimate{symbol.GetUltimate()};
113     if constexpr (TraverseAssocEntityDetails) {
114       if (const auto *assoc{
115               ultimate.detailsIf<semantics::AssocEntityDetails>()}) {
116         return visitor_(assoc->expr());
117       }
118     }
119     return visitor_.Default();
120   }
121   Result operator()(const StaticDataObject &) const {
122     return visitor_.Default();
123   }
124   Result operator()(const ImpliedDoIndex &) const { return visitor_.Default(); }
125 
126   // Variables
127   Result operator()(const BaseObject &x) const { return visitor_(x.u); }
128   Result operator()(const Component &x) const {
129     return Combine(x.base(), x.symbol());
130   }
131   Result operator()(const NamedEntity &x) const {
132     if (const Component * component{x.UnwrapComponent()}) {
133       return visitor_(*component);
134     } else {
135       return visitor_(DEREF(x.UnwrapSymbolRef()));
136     }
137   }
138   Result operator()(const TypeParamInquiry &x) const {
139     return visitor_(x.base());
140   }
141   Result operator()(const Triplet &x) const {
142     return Combine(x.GetLower(), x.GetUpper(), x.GetStride());
143   }
144   Result operator()(const Subscript &x) const { return visitor_(x.u); }
145   Result operator()(const ArrayRef &x) const {
146     return Combine(x.base(), x.subscript());
147   }
148   Result operator()(const CoarrayRef &x) const {
149     return Combine(
150         x.base(), x.subscript(), x.cosubscript(), x.stat(), x.team());
151   }
152   Result operator()(const DataRef &x) const { return visitor_(x.u); }
153   Result operator()(const Substring &x) const {
154     return Combine(x.parent(), x.GetLower(), x.GetUpper());
155   }
156   Result operator()(const ComplexPart &x) const {
157     return visitor_(x.complex());
158   }
159   template <typename T> Result operator()(const Designator<T> &x) const {
160     return visitor_(x.u);
161   }
162   template <typename T> Result operator()(const Variable<T> &x) const {
163     return visitor_(x.u);
164   }
165   Result operator()(const DescriptorInquiry &x) const {
166     return visitor_(x.base());
167   }
168 
169   // Calls
170   Result operator()(const SpecificIntrinsic &) const {
171     return visitor_.Default();
172   }
173   Result operator()(const ProcedureDesignator &x) const {
174     if (const Component * component{x.GetComponent()}) {
175       return visitor_(*component);
176     } else if (const Symbol * symbol{x.GetSymbol()}) {
177       return visitor_(*symbol);
178     } else {
179       return visitor_(DEREF(x.GetSpecificIntrinsic()));
180     }
181   }
182   Result operator()(const ActualArgument &x) const {
183     if (const auto *symbol{x.GetAssumedTypeDummy()}) {
184       return visitor_(*symbol);
185     } else {
186       return visitor_(x.UnwrapExpr());
187     }
188   }
189   Result operator()(const ProcedureRef &x) const {
190     return Combine(x.proc(), x.arguments());
191   }
192   template <typename T> Result operator()(const FunctionRef<T> &x) const {
193     return visitor_(static_cast<const ProcedureRef &>(x));
194   }
195 
196   // Other primaries
197   template <typename T>
198   Result operator()(const ArrayConstructorValue<T> &x) const {
199     return visitor_(x.u);
200   }
201   template <typename T>
202   Result operator()(const ArrayConstructorValues<T> &x) const {
203     return CombineContents(x);
204   }
205   template <typename T> Result operator()(const ImpliedDo<T> &x) const {
206     return Combine(x.lower(), x.upper(), x.stride(), x.values());
207   }
208   Result operator()(const semantics::ParamValue &x) const {
209     return visitor_(x.GetExplicit());
210   }
211   Result operator()(
212       const semantics::DerivedTypeSpec::ParameterMapType::value_type &x) const {
213     return visitor_(x.second);
214   }
215   Result operator()(
216       const semantics::DerivedTypeSpec::ParameterMapType &x) const {
217     return CombineContents(x);
218   }
219   Result operator()(const semantics::DerivedTypeSpec &x) const {
220     return Combine(x.originalTypeSymbol(), x.parameters());
221   }
222   Result operator()(const StructureConstructorValues::value_type &x) const {
223     return visitor_(x.second);
224   }
225   Result operator()(const StructureConstructorValues &x) const {
226     return CombineContents(x);
227   }
228   Result operator()(const StructureConstructor &x) const {
229     return visitor_.Combine(visitor_(x.derivedTypeSpec()), CombineContents(x));
230   }
231 
232   // Operations and wrappers
233   template <typename D, typename R, typename O>
234   Result operator()(const Operation<D, R, O> &op) const {
235     return visitor_(op.left());
236   }
237   template <typename D, typename R, typename LO, typename RO>
238   Result operator()(const Operation<D, R, LO, RO> &op) const {
239     return Combine(op.left(), op.right());
240   }
241   Result operator()(const Relational<SomeType> &x) const {
242     return visitor_(x.u);
243   }
244   template <typename T> Result operator()(const Expr<T> &x) const {
245     return visitor_(x.u);
246   }
247   Result operator()(const Assignment &x) const {
248     return Combine(x.lhs, x.rhs, x.u);
249   }
250   Result operator()(const Assignment::Intrinsic &) const {
251     return visitor_.Default();
252   }
253   Result operator()(const GenericExprWrapper &x) const { return visitor_(x.v); }
254   Result operator()(const GenericAssignmentWrapper &x) const {
255     return visitor_(x.v);
256   }
257 
258 private:
259   template <typename ITER> Result CombineRange(ITER iter, ITER end) const {
260     if (iter == end) {
261       return visitor_.Default();
262     } else {
263       Result result{visitor_(*iter)};
264       for (++iter; iter != end; ++iter) {
265         result = visitor_.Combine(std::move(result), visitor_(*iter));
266       }
267       return result;
268     }
269   }
270 
271   template <typename A> Result CombineContents(const A &x) const {
272     return CombineRange(x.begin(), x.end());
273   }
274 
275   template <typename A, typename... Bs>
276   Result Combine(const A &x, const Bs &...ys) const {
277     if constexpr (sizeof...(Bs) == 0) {
278       return visitor_(x);
279     } else {
280       return visitor_.Combine(visitor_(x), Combine(ys...));
281     }
282   }
283 
284   Visitor &visitor_;
285 };
286 
287 // For validity checks across an expression: if any operator() result is
288 // false, so is the overall result.
289 template <typename Visitor, bool DefaultValue,
290     bool TraverseAssocEntityDetails = true,
291     typename Base = Traverse<Visitor, bool, TraverseAssocEntityDetails>>
292 struct AllTraverse : public Base {
293   explicit AllTraverse(Visitor &v) : Base{v} {}
294   using Base::operator();
295   static bool Default() { return DefaultValue; }
296   static bool Combine(bool x, bool y) { return x && y; }
297 };
298 
299 // For searches over an expression: the first operator() result that
300 // is truthful is the final result.  Works for Booleans, pointers,
301 // and std::optional<>.
302 template <typename Visitor, typename Result = bool,
303     bool TraverseAssocEntityDetails = true,
304     typename Base = Traverse<Visitor, Result, TraverseAssocEntityDetails>>
305 class AnyTraverse : public Base {
306 public:
307   explicit AnyTraverse(Visitor &v) : Base{v} {}
308   using Base::operator();
309   Result Default() const { return default_; }
310   static Result Combine(Result &&x, Result &&y) {
311     if (x) {
312       return std::move(x);
313     } else {
314       return std::move(y);
315     }
316   }
317 
318 private:
319   Result default_{};
320 };
321 
322 template <typename Visitor, typename Set,
323     bool TraverseAssocEntityDetails = true,
324     typename Base = Traverse<Visitor, Set, TraverseAssocEntityDetails>>
325 struct SetTraverse : public Base {
326   explicit SetTraverse(Visitor &v) : Base{v} {}
327   using Base::operator();
328   static Set Default() { return {}; }
329   static Set Combine(Set &&x, Set &&y) {
330 #if defined __GNUC__ && !defined __APPLE__ && !(CLANG_LIBRARIES)
331     x.merge(y);
332 #else
333     // std::set::merge() not available (yet)
334     for (auto &value : y) {
335       x.insert(std::move(value));
336     }
337 #endif
338     return std::move(x);
339   }
340 };
341 
342 } // namespace Fortran::evaluate
343 #endif
344