1 //===-- IterationSpace.cpp ------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "flang/Lower/IterationSpace.h" 14 #include "flang/Evaluate/expression.h" 15 #include "flang/Lower/AbstractConverter.h" 16 #include "flang/Lower/Support/Utils.h" 17 #include "llvm/Support/Debug.h" 18 #include <optional> 19 20 #define DEBUG_TYPE "flang-lower-iteration-space" 21 22 unsigned Fortran::lower::getHashValue( 23 const Fortran::lower::ExplicitIterSpace::ArrayBases &x) { 24 return Fortran::common::visit( 25 [&](const auto *p) { return HashEvaluateExpr::getHashValue(*p); }, x); 26 } 27 28 bool Fortran::lower::isEqual( 29 const Fortran::lower::ExplicitIterSpace::ArrayBases &x, 30 const Fortran::lower::ExplicitIterSpace::ArrayBases &y) { 31 return Fortran::common::visit( 32 Fortran::common::visitors{ 33 // Fortran::semantics::Symbol * are the exception here. These pointers 34 // have identity; if two Symbol * values are the same (different) then 35 // they are the same (different) logical symbol. 36 [&](Fortran::lower::FrontEndSymbol p, 37 Fortran::lower::FrontEndSymbol q) { return p == q; }, 38 [&](const auto *p, const auto *q) { 39 if constexpr (std::is_same_v<decltype(p), decltype(q)>) { 40 LLVM_DEBUG(llvm::dbgs() 41 << "is equal: " << p << ' ' << q << ' ' 42 << IsEqualEvaluateExpr::isEqual(*p, *q) << '\n'); 43 return IsEqualEvaluateExpr::isEqual(*p, *q); 44 } else { 45 // Different subtree types are never equal. 46 return false; 47 } 48 }}, 49 x, y); 50 } 51 52 namespace { 53 54 /// This class can recover the base array in an expression that contains 55 /// explicit iteration space symbols. Most of the class can be ignored as it is 56 /// boilerplate Fortran::evaluate::Expr traversal. 57 class ArrayBaseFinder { 58 public: 59 using RT = bool; 60 61 ArrayBaseFinder(llvm::ArrayRef<Fortran::lower::FrontEndSymbol> syms) 62 : controlVars(syms) {} 63 64 template <typename T> 65 void operator()(const T &x) { 66 (void)find(x); 67 } 68 69 /// Get the list of bases. 70 llvm::ArrayRef<Fortran::lower::ExplicitIterSpace::ArrayBases> 71 getBases() const { 72 LLVM_DEBUG(llvm::dbgs() 73 << "number of array bases found: " << bases.size() << '\n'); 74 return bases; 75 } 76 77 private: 78 // First, the cases that are of interest. 79 RT find(const Fortran::semantics::Symbol &symbol) { 80 if (symbol.Rank() > 0) { 81 bases.push_back(&symbol); 82 return true; 83 } 84 return {}; 85 } 86 RT find(const Fortran::evaluate::Component &x) { 87 auto found = find(x.base()); 88 if (!found && x.base().Rank() == 0 && x.Rank() > 0) { 89 bases.push_back(&x); 90 return true; 91 } 92 return found; 93 } 94 RT find(const Fortran::evaluate::ArrayRef &x) { 95 for (const auto &sub : x.subscript()) 96 (void)find(sub); 97 if (x.base().IsSymbol()) { 98 if (x.Rank() > 0 || intersection(x.subscript())) { 99 bases.push_back(&x); 100 return true; 101 } 102 return {}; 103 } 104 auto found = find(x.base()); 105 if (!found && ((x.base().Rank() == 0 && x.Rank() > 0) || 106 intersection(x.subscript()))) { 107 bases.push_back(&x); 108 return true; 109 } 110 return found; 111 } 112 RT find(const Fortran::evaluate::Triplet &x) { 113 if (const auto *lower = x.GetLower()) 114 (void)find(*lower); 115 if (const auto *upper = x.GetUpper()) 116 (void)find(*upper); 117 return find(x.GetStride()); 118 } 119 RT find(const Fortran::evaluate::IndirectSubscriptIntegerExpr &x) { 120 return find(x.value()); 121 } 122 RT find(const Fortran::evaluate::Subscript &x) { return find(x.u); } 123 RT find(const Fortran::evaluate::DataRef &x) { return find(x.u); } 124 RT find(const Fortran::evaluate::CoarrayRef &x) { 125 assert(false && "coarray reference"); 126 return {}; 127 } 128 129 template <typename A> 130 bool intersection(const A &subscripts) { 131 return Fortran::lower::symbolsIntersectSubscripts(controlVars, subscripts); 132 } 133 134 // The rest is traversal boilerplate and can be ignored. 135 RT find(const Fortran::evaluate::Substring &x) { return find(x.parent()); } 136 template <typename A> 137 RT find(const Fortran::semantics::SymbolRef x) { 138 return find(*x); 139 } 140 RT find(const Fortran::evaluate::NamedEntity &x) { 141 if (x.IsSymbol()) 142 return find(x.GetFirstSymbol()); 143 return find(x.GetComponent()); 144 } 145 146 template <typename A, bool C> 147 RT find(const Fortran::common::Indirection<A, C> &x) { 148 return find(x.value()); 149 } 150 template <typename A> 151 RT find(const std::unique_ptr<A> &x) { 152 return find(x.get()); 153 } 154 template <typename A> 155 RT find(const std::shared_ptr<A> &x) { 156 return find(x.get()); 157 } 158 template <typename A> 159 RT find(const A *x) { 160 if (x) 161 return find(*x); 162 return {}; 163 } 164 template <typename A> 165 RT find(const std::optional<A> &x) { 166 if (x) 167 return find(*x); 168 return {}; 169 } 170 template <typename... A> 171 RT find(const std::variant<A...> &u) { 172 return Fortran::common::visit([&](const auto &v) { return find(v); }, u); 173 } 174 template <typename A> 175 RT find(const std::vector<A> &x) { 176 for (auto &v : x) 177 (void)find(v); 178 return {}; 179 } 180 RT find(const Fortran::evaluate::BOZLiteralConstant &) { return {}; } 181 RT find(const Fortran::evaluate::NullPointer &) { return {}; } 182 template <typename T> 183 RT find(const Fortran::evaluate::Constant<T> &x) { 184 return {}; 185 } 186 RT find(const Fortran::evaluate::StaticDataObject &) { return {}; } 187 RT find(const Fortran::evaluate::ImpliedDoIndex &) { return {}; } 188 RT find(const Fortran::evaluate::BaseObject &x) { 189 (void)find(x.u); 190 return {}; 191 } 192 RT find(const Fortran::evaluate::TypeParamInquiry &) { return {}; } 193 RT find(const Fortran::evaluate::ComplexPart &x) { return {}; } 194 template <typename T> 195 RT find(const Fortran::evaluate::Designator<T> &x) { 196 return find(x.u); 197 } 198 template <typename T> 199 RT find(const Fortran::evaluate::Variable<T> &x) { 200 return find(x.u); 201 } 202 RT find(const Fortran::evaluate::DescriptorInquiry &) { return {}; } 203 RT find(const Fortran::evaluate::SpecificIntrinsic &) { return {}; } 204 RT find(const Fortran::evaluate::ProcedureDesignator &x) { return {}; } 205 RT find(const Fortran::evaluate::ProcedureRef &x) { 206 (void)find(x.proc()); 207 if (x.IsElemental()) 208 (void)find(x.arguments()); 209 return {}; 210 } 211 RT find(const Fortran::evaluate::ActualArgument &x) { 212 if (const auto *sym = x.GetAssumedTypeDummy()) 213 (void)find(*sym); 214 else 215 (void)find(x.UnwrapExpr()); 216 return {}; 217 } 218 template <typename T> 219 RT find(const Fortran::evaluate::FunctionRef<T> &x) { 220 (void)find(static_cast<const Fortran::evaluate::ProcedureRef &>(x)); 221 return {}; 222 } 223 template <typename T> 224 RT find(const Fortran::evaluate::ArrayConstructorValue<T> &) { 225 return {}; 226 } 227 template <typename T> 228 RT find(const Fortran::evaluate::ArrayConstructorValues<T> &) { 229 return {}; 230 } 231 template <typename T> 232 RT find(const Fortran::evaluate::ImpliedDo<T> &) { 233 return {}; 234 } 235 RT find(const Fortran::semantics::ParamValue &) { return {}; } 236 RT find(const Fortran::semantics::DerivedTypeSpec &) { return {}; } 237 RT find(const Fortran::evaluate::StructureConstructor &) { return {}; } 238 template <typename D, typename R, typename O> 239 RT find(const Fortran::evaluate::Operation<D, R, O> &op) { 240 (void)find(op.left()); 241 return false; 242 } 243 template <typename D, typename R, typename LO, typename RO> 244 RT find(const Fortran::evaluate::Operation<D, R, LO, RO> &op) { 245 (void)find(op.left()); 246 (void)find(op.right()); 247 return false; 248 } 249 RT find(const Fortran::evaluate::Relational<Fortran::evaluate::SomeType> &x) { 250 (void)find(x.u); 251 return {}; 252 } 253 template <typename T> 254 RT find(const Fortran::evaluate::Expr<T> &x) { 255 (void)find(x.u); 256 return {}; 257 } 258 259 llvm::SmallVector<Fortran::lower::ExplicitIterSpace::ArrayBases> bases; 260 llvm::SmallVector<Fortran::lower::FrontEndSymbol> controlVars; 261 }; 262 263 } // namespace 264 265 void Fortran::lower::ExplicitIterSpace::leave() { 266 ccLoopNest.pop_back(); 267 --forallContextOpen; 268 conditionalCleanup(); 269 } 270 271 void Fortran::lower::ExplicitIterSpace::addSymbol( 272 Fortran::lower::FrontEndSymbol sym) { 273 assert(!symbolStack.empty()); 274 symbolStack.back().push_back(sym); 275 } 276 277 void Fortran::lower::ExplicitIterSpace::exprBase(Fortran::lower::FrontEndExpr x, 278 bool lhs) { 279 ArrayBaseFinder finder(collectAllSymbols()); 280 finder(*x); 281 llvm::ArrayRef<Fortran::lower::ExplicitIterSpace::ArrayBases> bases = 282 finder.getBases(); 283 if (rhsBases.empty()) 284 endAssign(); 285 if (lhs) { 286 if (bases.empty()) { 287 lhsBases.push_back(std::nullopt); 288 return; 289 } 290 assert(bases.size() >= 1 && "must detect an array reference on lhs"); 291 if (bases.size() > 1) 292 rhsBases.back().append(bases.begin(), bases.end() - 1); 293 lhsBases.push_back(bases.back()); 294 return; 295 } 296 rhsBases.back().append(bases.begin(), bases.end()); 297 } 298 299 void Fortran::lower::ExplicitIterSpace::endAssign() { rhsBases.emplace_back(); } 300 301 void Fortran::lower::ExplicitIterSpace::pushLevel() { 302 symbolStack.push_back(llvm::SmallVector<Fortran::lower::FrontEndSymbol>{}); 303 } 304 305 void Fortran::lower::ExplicitIterSpace::popLevel() { symbolStack.pop_back(); } 306 307 void Fortran::lower::ExplicitIterSpace::conditionalCleanup() { 308 if (forallContextOpen == 0) { 309 // Exiting the outermost FORALL context. 310 // Cleanup any residual mask buffers. 311 outermostContext().finalizeAndReset(); 312 // Clear and reset all the cached information. 313 symbolStack.clear(); 314 lhsBases.clear(); 315 rhsBases.clear(); 316 loadBindings.clear(); 317 ccLoopNest.clear(); 318 innerArgs.clear(); 319 outerLoop = std::nullopt; 320 clearLoops(); 321 counter = 0; 322 } 323 } 324 325 std::optional<size_t> 326 Fortran::lower::ExplicitIterSpace::findArgPosition(fir::ArrayLoadOp load) { 327 if (lhsBases[counter]) { 328 auto ld = loadBindings.find(*lhsBases[counter]); 329 std::optional<size_t> optPos; 330 if (ld != loadBindings.end() && ld->second == load) 331 optPos = static_cast<size_t>(0u); 332 assert(optPos.has_value() && "load does not correspond to lhs"); 333 return optPos; 334 } 335 return std::nullopt; 336 } 337 338 llvm::SmallVector<Fortran::lower::FrontEndSymbol> 339 Fortran::lower::ExplicitIterSpace::collectAllSymbols() { 340 llvm::SmallVector<Fortran::lower::FrontEndSymbol> result; 341 for (llvm::SmallVector<FrontEndSymbol> vec : symbolStack) 342 result.append(vec.begin(), vec.end()); 343 return result; 344 } 345 346 llvm::raw_ostream & 347 Fortran::lower::operator<<(llvm::raw_ostream &s, 348 const Fortran::lower::ImplicitIterSpace &e) { 349 for (const llvm::SmallVector< 350 Fortran::lower::ImplicitIterSpace::FrontEndMaskExpr> &xs : 351 e.getMasks()) { 352 s << "{ "; 353 for (const Fortran::lower::ImplicitIterSpace::FrontEndMaskExpr &x : xs) 354 x->AsFortran(s << '(') << "), "; 355 s << "}\n"; 356 } 357 return s; 358 } 359 360 llvm::raw_ostream & 361 Fortran::lower::operator<<(llvm::raw_ostream &s, 362 const Fortran::lower::ExplicitIterSpace &e) { 363 auto dump = [&](const auto &u) { 364 Fortran::common::visit( 365 Fortran::common::visitors{ 366 [&](const Fortran::semantics::Symbol *y) { 367 s << " " << *y << '\n'; 368 }, 369 [&](const Fortran::evaluate::ArrayRef *y) { 370 s << " "; 371 if (y->base().IsSymbol()) 372 s << y->base().GetFirstSymbol(); 373 else 374 s << y->base().GetComponent().GetLastSymbol(); 375 s << '\n'; 376 }, 377 [&](const Fortran::evaluate::Component *y) { 378 s << " " << y->GetLastSymbol() << '\n'; 379 }}, 380 u); 381 }; 382 s << "LHS bases:\n"; 383 for (const std::optional<Fortran::lower::ExplicitIterSpace::ArrayBases> &u : 384 e.lhsBases) 385 if (u) 386 dump(*u); 387 s << "RHS bases:\n"; 388 for (const llvm::SmallVector<Fortran::lower::ExplicitIterSpace::ArrayBases> 389 &bases : e.rhsBases) { 390 for (const Fortran::lower::ExplicitIterSpace::ArrayBases &u : bases) 391 dump(u); 392 s << '\n'; 393 } 394 return s; 395 } 396 397 void Fortran::lower::ImplicitIterSpace::dump() const { 398 llvm::errs() << *this << '\n'; 399 } 400 401 void Fortran::lower::ExplicitIterSpace::dump() const { 402 llvm::errs() << *this << '\n'; 403 } 404