//===-- include/flang/Evaluate/traverse.h -----------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef FORTRAN_EVALUATE_TRAVERSE_H_
#define FORTRAN_EVALUATE_TRAVERSE_H_

// A utility for scanning all of the constituent objects in an Expr<>
// expression representation using a collection of mutually recursive
// functions to compose a function object.
//
// The class template Traverse<> below implements a function object that
// can handle every type that can appear in or around an Expr<>.
// Each of its overloads for operator() should be viewed as a *default*
// handler; some of these must be replaced by the client's own overloads
// to accomplish its particular task.
//
// The client (Visitor) of Traverse<Visitor,Result> must define:
// - a member function "Result Default();"
// - a member function "Result Combine(Result &&, Result &&);"
// - overloads of "Result operator()" for the node types it cares about
//   (these are overloads, not virtual overrides)
//
// Boilerplate classes also appear below to ease construction of visitors.
// See CheckSpecificationExpr() in check-expression.cpp for an example client.
//
// How this works:
// - The operator() overloads in Traverse<> invoke the visitor's Default() for
//   expression leaf nodes. They invoke the visitor's operator() for the
//   subtrees of interior nodes, and the visitor's Combine() to merge their
//   results together.
// - Overloads of operator() in each visitor handle the cases of interest.
//
// When the TraverseAssocEntityDetails template argument is true (the
// default), the default handler for semantics::Symbol descends into the
// associated expression of an ASSOCIATE (or related) construct entity;
// otherwise it simply returns Default().
39 40 #include "expression.h" 41 #include "flang/Common/indirection.h" 42 #include "flang/Semantics/symbol.h" 43 #include "flang/Semantics/type.h" 44 #include <set> 45 #include <type_traits> 46 47 namespace Fortran::evaluate { 48 template <typename Visitor, typename Result, 49 bool TraverseAssocEntityDetails = true> 50 class Traverse { 51 public: 52 explicit Traverse(Visitor &v) : visitor_{v} {} 53 54 // Packaging 55 template <typename A, bool C> 56 Result operator()(const common::Indirection<A, C> &x) const { 57 return visitor_(x.value()); 58 } 59 template <typename A> 60 Result operator()(const common::ForwardOwningPointer<A> &p) const { 61 return visitor_(p.get()); 62 } 63 template <typename _> Result operator()(const SymbolRef x) const { 64 return visitor_(*x); 65 } 66 template <typename A> Result operator()(const std::unique_ptr<A> &x) const { 67 return visitor_(x.get()); 68 } 69 template <typename A> Result operator()(const std::shared_ptr<A> &x) const { 70 return visitor_(x.get()); 71 } 72 template <typename A> Result operator()(const A *x) const { 73 if (x) { 74 return visitor_(*x); 75 } else { 76 return visitor_.Default(); 77 } 78 } 79 template <typename A> Result operator()(const std::optional<A> &x) const { 80 if (x) { 81 return visitor_(*x); 82 } else { 83 return visitor_.Default(); 84 } 85 } 86 template <typename... 
As> 87 Result operator()(const std::variant<As...> &u) const { 88 return common::visit([=](const auto &y) { return visitor_(y); }, u); 89 } 90 template <typename A> Result operator()(const std::vector<A> &x) const { 91 return CombineContents(x); 92 } 93 template <typename A, typename B> 94 Result operator()(const std::pair<A, B> &x) const { 95 return Combine(x.first, x.second); 96 } 97 98 // Leaves 99 Result operator()(const BOZLiteralConstant &) const { 100 return visitor_.Default(); 101 } 102 Result operator()(const NullPointer &) const { return visitor_.Default(); } 103 template <typename T> Result operator()(const Constant<T> &x) const { 104 if constexpr (T::category == TypeCategory::Derived) { 105 return visitor_.Combine( 106 visitor_(x.result().derivedTypeSpec()), CombineContents(x.values())); 107 } else { 108 return visitor_.Default(); 109 } 110 } 111 Result operator()(const Symbol &symbol) const { 112 const Symbol &ultimate{symbol.GetUltimate()}; 113 if constexpr (TraverseAssocEntityDetails) { 114 if (const auto *assoc{ 115 ultimate.detailsIf<semantics::AssocEntityDetails>()}) { 116 return visitor_(assoc->expr()); 117 } 118 } 119 return visitor_.Default(); 120 } 121 Result operator()(const StaticDataObject &) const { 122 return visitor_.Default(); 123 } 124 Result operator()(const ImpliedDoIndex &) const { return visitor_.Default(); } 125 126 // Variables 127 Result operator()(const BaseObject &x) const { return visitor_(x.u); } 128 Result operator()(const Component &x) const { 129 return Combine(x.base(), x.symbol()); 130 } 131 Result operator()(const NamedEntity &x) const { 132 if (const Component * component{x.UnwrapComponent()}) { 133 return visitor_(*component); 134 } else { 135 return visitor_(DEREF(x.UnwrapSymbolRef())); 136 } 137 } 138 Result operator()(const TypeParamInquiry &x) const { 139 return visitor_(x.base()); 140 } 141 Result operator()(const Triplet &x) const { 142 return Combine(x.GetLower(), x.GetUpper(), x.GetStride()); 143 } 144 Result 
operator()(const Subscript &x) const { return visitor_(x.u); } 145 Result operator()(const ArrayRef &x) const { 146 return Combine(x.base(), x.subscript()); 147 } 148 Result operator()(const CoarrayRef &x) const { 149 return Combine( 150 x.base(), x.subscript(), x.cosubscript(), x.stat(), x.team()); 151 } 152 Result operator()(const DataRef &x) const { return visitor_(x.u); } 153 Result operator()(const Substring &x) const { 154 return Combine(x.parent(), x.GetLower(), x.GetUpper()); 155 } 156 Result operator()(const ComplexPart &x) const { 157 return visitor_(x.complex()); 158 } 159 template <typename T> Result operator()(const Designator<T> &x) const { 160 return visitor_(x.u); 161 } 162 template <typename T> Result operator()(const Variable<T> &x) const { 163 return visitor_(x.u); 164 } 165 Result operator()(const DescriptorInquiry &x) const { 166 return visitor_(x.base()); 167 } 168 169 // Calls 170 Result operator()(const SpecificIntrinsic &) const { 171 return visitor_.Default(); 172 } 173 Result operator()(const ProcedureDesignator &x) const { 174 if (const Component * component{x.GetComponent()}) { 175 return visitor_(*component); 176 } else if (const Symbol * symbol{x.GetSymbol()}) { 177 return visitor_(*symbol); 178 } else { 179 return visitor_(DEREF(x.GetSpecificIntrinsic())); 180 } 181 } 182 Result operator()(const ActualArgument &x) const { 183 if (const auto *symbol{x.GetAssumedTypeDummy()}) { 184 return visitor_(*symbol); 185 } else { 186 return visitor_(x.UnwrapExpr()); 187 } 188 } 189 Result operator()(const ProcedureRef &x) const { 190 return Combine(x.proc(), x.arguments()); 191 } 192 template <typename T> Result operator()(const FunctionRef<T> &x) const { 193 return visitor_(static_cast<const ProcedureRef &>(x)); 194 } 195 196 // Other primaries 197 template <typename T> 198 Result operator()(const ArrayConstructorValue<T> &x) const { 199 return visitor_(x.u); 200 } 201 template <typename T> 202 Result operator()(const ArrayConstructorValues<T> 
&x) const { 203 return CombineContents(x); 204 } 205 template <typename T> Result operator()(const ImpliedDo<T> &x) const { 206 return Combine(x.lower(), x.upper(), x.stride(), x.values()); 207 } 208 Result operator()(const semantics::ParamValue &x) const { 209 return visitor_(x.GetExplicit()); 210 } 211 Result operator()( 212 const semantics::DerivedTypeSpec::ParameterMapType::value_type &x) const { 213 return visitor_(x.second); 214 } 215 Result operator()( 216 const semantics::DerivedTypeSpec::ParameterMapType &x) const { 217 return CombineContents(x); 218 } 219 Result operator()(const semantics::DerivedTypeSpec &x) const { 220 return Combine(x.originalTypeSymbol(), x.parameters()); 221 } 222 Result operator()(const StructureConstructorValues::value_type &x) const { 223 return visitor_(x.second); 224 } 225 Result operator()(const StructureConstructorValues &x) const { 226 return CombineContents(x); 227 } 228 Result operator()(const StructureConstructor &x) const { 229 return visitor_.Combine(visitor_(x.derivedTypeSpec()), CombineContents(x)); 230 } 231 232 // Operations and wrappers 233 template <typename D, typename R, typename O> 234 Result operator()(const Operation<D, R, O> &op) const { 235 return visitor_(op.left()); 236 } 237 template <typename D, typename R, typename LO, typename RO> 238 Result operator()(const Operation<D, R, LO, RO> &op) const { 239 return Combine(op.left(), op.right()); 240 } 241 Result operator()(const Relational<SomeType> &x) const { 242 return visitor_(x.u); 243 } 244 template <typename T> Result operator()(const Expr<T> &x) const { 245 return visitor_(x.u); 246 } 247 Result operator()(const Assignment &x) const { 248 return Combine(x.lhs, x.rhs, x.u); 249 } 250 Result operator()(const Assignment::Intrinsic &) const { 251 return visitor_.Default(); 252 } 253 Result operator()(const GenericExprWrapper &x) const { return visitor_(x.v); } 254 Result operator()(const GenericAssignmentWrapper &x) const { 255 return visitor_(x.v); 256 } 
257 258 private: 259 template <typename ITER> Result CombineRange(ITER iter, ITER end) const { 260 if (iter == end) { 261 return visitor_.Default(); 262 } else { 263 Result result{visitor_(*iter)}; 264 for (++iter; iter != end; ++iter) { 265 result = visitor_.Combine(std::move(result), visitor_(*iter)); 266 } 267 return result; 268 } 269 } 270 271 template <typename A> Result CombineContents(const A &x) const { 272 return CombineRange(x.begin(), x.end()); 273 } 274 275 template <typename A, typename... Bs> 276 Result Combine(const A &x, const Bs &...ys) const { 277 if constexpr (sizeof...(Bs) == 0) { 278 return visitor_(x); 279 } else { 280 return visitor_.Combine(visitor_(x), Combine(ys...)); 281 } 282 } 283 284 Visitor &visitor_; 285 }; 286 287 // For validity checks across an expression: if any operator() result is 288 // false, so is the overall result. 289 template <typename Visitor, bool DefaultValue, 290 bool TraverseAssocEntityDetails = true, 291 typename Base = Traverse<Visitor, bool, TraverseAssocEntityDetails>> 292 struct AllTraverse : public Base { 293 explicit AllTraverse(Visitor &v) : Base{v} {} 294 using Base::operator(); 295 static bool Default() { return DefaultValue; } 296 static bool Combine(bool x, bool y) { return x && y; } 297 }; 298 299 // For searches over an expression: the first operator() result that 300 // is truthful is the final result. Works for Booleans, pointers, 301 // and std::optional<>. 
302 template <typename Visitor, typename Result = bool, 303 bool TraverseAssocEntityDetails = true, 304 typename Base = Traverse<Visitor, Result, TraverseAssocEntityDetails>> 305 class AnyTraverse : public Base { 306 public: 307 explicit AnyTraverse(Visitor &v) : Base{v} {} 308 using Base::operator(); 309 Result Default() const { return default_; } 310 static Result Combine(Result &&x, Result &&y) { 311 if (x) { 312 return std::move(x); 313 } else { 314 return std::move(y); 315 } 316 } 317 318 private: 319 Result default_{}; 320 }; 321 322 template <typename Visitor, typename Set, 323 bool TraverseAssocEntityDetails = true, 324 typename Base = Traverse<Visitor, Set, TraverseAssocEntityDetails>> 325 struct SetTraverse : public Base { 326 explicit SetTraverse(Visitor &v) : Base{v} {} 327 using Base::operator(); 328 static Set Default() { return {}; } 329 static Set Combine(Set &&x, Set &&y) { 330 #if defined __GNUC__ && !defined __APPLE__ && !(CLANG_LIBRARIES) 331 x.merge(y); 332 #else 333 // std::set::merge() not available (yet) 334 for (auto &value : y) { 335 x.insert(std::move(value)); 336 } 337 #endif 338 return std::move(x); 339 } 340 }; 341 342 } // namespace Fortran::evaluate 343 #endif 344