1 //===-- IterationSpace.h ----------------------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef FORTRAN_LOWER_ITERATIONSPACE_H 14 #define FORTRAN_LOWER_ITERATIONSPACE_H 15 16 #include "flang/Evaluate/tools.h" 17 #include "flang/Lower/StatementContext.h" 18 #include "flang/Lower/SymbolMap.h" 19 #include "flang/Optimizer/Builder/FIRBuilder.h" 20 #include <optional> 21 22 namespace llvm { 23 class raw_ostream; 24 } 25 26 namespace Fortran { 27 namespace evaluate { 28 struct SomeType; 29 template <typename> 30 class Expr; 31 } // namespace evaluate 32 33 namespace lower { 34 35 using FrontEndExpr = const evaluate::Expr<evaluate::SomeType> *; 36 using FrontEndSymbol = const semantics::Symbol *; 37 38 class AbstractConverter; 39 40 } // namespace lower 41 } // namespace Fortran 42 43 namespace Fortran::lower { 44 45 /// Abstraction of the iteration space for building the elemental compute loop 46 /// of an array(-like) statement. 47 class IterationSpace { 48 public: 49 IterationSpace() = default; 50 51 template <typename A> 52 explicit IterationSpace(mlir::Value inArg, mlir::Value outRes, 53 llvm::iterator_range<A> range) 54 : inArg{inArg}, outRes{outRes}, indices{range.begin(), range.end()} {} 55 56 explicit IterationSpace(const IterationSpace &from, 57 llvm::ArrayRef<mlir::Value> idxs) 58 : inArg(from.inArg), outRes(from.outRes), element(from.element), 59 indices(idxs) {} 60 61 /// Create a copy of the \p from IterationSpace and prepend the \p prefix 62 /// values and append the \p suffix values, respectively. 63 explicit IterationSpace(const IterationSpace &from, 64 llvm::ArrayRef<mlir::Value> prefix, 65 llvm::ArrayRef<mlir::Value> suffix) 66 : inArg(from.inArg), outRes(from.outRes), element(from.element) { 67 indices.assign(prefix.begin(), prefix.end()); 68 indices.append(from.indices.begin(), from.indices.end()); 69 indices.append(suffix.begin(), suffix.end()); 70 } 71 72 bool empty() const { return indices.empty(); } 73 74 /// This is the output value as it appears as an argument in the innermost 75 /// loop in the nest. The output value is threaded through the loop (and 76 /// conditionals) to maintain proper SSA form. 77 mlir::Value innerArgument() const { return inArg; } 78 79 /// This is the output value as it appears as an output value from the 80 /// outermost loop in the loop nest. The output value is threaded through the 81 /// loop (and conditionals) to maintain proper SSA form. 82 mlir::Value outerResult() const { return outRes; } 83 84 /// Returns a vector for the iteration space. This vector is used to access 85 /// elements of arrays in the compute loop. 86 llvm::SmallVector<mlir::Value> iterVec() const { return indices; } 87 88 mlir::Value iterValue(std::size_t i) const { 89 assert(i < indices.size()); 90 return indices[i]; 91 } 92 93 /// Set (rewrite) the Value at a given index. 94 void setIndexValue(std::size_t i, mlir::Value v) { 95 assert(i < indices.size()); 96 indices[i] = v; 97 } 98 99 void setIndexValues(llvm::ArrayRef<mlir::Value> vals) { 100 indices.assign(vals.begin(), vals.end()); 101 } 102 103 void insertIndexValue(std::size_t i, mlir::Value av) { 104 assert(i <= indices.size()); 105 indices.insert(indices.begin() + i, av); 106 } 107 108 /// Set the `element` value. This is the SSA value that corresponds to an 109 /// element of the resultant array value. 110 void setElement(fir::ExtendedValue &&ele) { 111 assert(!fir::getBase(element) && "result element already set"); 112 element = ele; 113 } 114 115 /// Get the value that will be merged into the resultant array. This is the 116 /// computed value that will be stored to the lhs of the assignment. 117 mlir::Value getElement() const { 118 assert(fir::getBase(element) && "element must be set"); 119 return fir::getBase(element); 120 } 121 122 /// Get the element as an extended value. 123 fir::ExtendedValue elementExv() const { return element; } 124 125 void clearIndices() { indices.clear(); } 126 127 private: 128 mlir::Value inArg; 129 mlir::Value outRes; 130 fir::ExtendedValue element; 131 llvm::SmallVector<mlir::Value> indices; 132 }; 133 134 using GenerateElementalArrayFunc = 135 std::function<fir::ExtendedValue(const IterationSpace &)>; 136 137 template <typename A> 138 class StackableConstructExpr { 139 public: 140 bool empty() const { return stack.empty(); } 141 142 void growStack() { stack.push_back(A{}); } 143 144 /// Bind a front-end expression to a closure. 145 void bind(FrontEndExpr e, GenerateElementalArrayFunc &&fun) { 146 vmap.insert({e, std::move(fun)}); 147 } 148 149 /// Replace the binding of front-end expression `e` with a new closure. 150 void rebind(FrontEndExpr e, GenerateElementalArrayFunc &&fun) { 151 vmap.erase(e); 152 bind(e, std::move(fun)); 153 } 154 155 /// Get the closure bound to the front-end expression, `e`. 156 GenerateElementalArrayFunc getBoundClosure(FrontEndExpr e) const { 157 if (!vmap.count(e)) 158 llvm::report_fatal_error( 159 "evaluate::Expr is not in the map of lowered mask expressions"); 160 return vmap.lookup(e); 161 } 162 163 /// Has the front-end expression, `e`, been lowered and bound? 164 bool isLowered(FrontEndExpr e) const { return vmap.count(e); } 165 166 StatementContext &stmtContext() { return stmtCtx; } 167 168 protected: 169 void shrinkStack() { 170 assert(!empty()); 171 stack.pop_back(); 172 if (empty()) { 173 stmtCtx.finalizeAndReset(); 174 vmap.clear(); 175 } 176 } 177 178 // The stack for the construct information. 179 llvm::SmallVector<A> stack; 180 181 // Map each mask expression back to the temporary holding the initial 182 // evaluation results. 183 llvm::DenseMap<FrontEndExpr, GenerateElementalArrayFunc> vmap; 184 185 // Inflate the statement context for the entire construct. We have to cache 186 // the mask expression results, which are always evaluated first, across the 187 // entire construct. 188 StatementContext stmtCtx; 189 }; 190 191 class ImplicitIterSpace; 192 llvm::raw_ostream &operator<<(llvm::raw_ostream &, const ImplicitIterSpace &); 193 194 /// All array expressions have an implicit iteration space, which is isomorphic 195 /// to the shape of the base array that facilitates the expression having a 196 /// non-zero rank. This implied iteration space may be conditionalized 197 /// (disjunctively) with an if-elseif-else like structure, specifically 198 /// Fortran's WHERE construct. 199 /// 200 /// This class is used in the bridge to collect the expressions from the 201 /// front end (the WHERE construct mask expressions), forward them for lowering 202 /// as array expressions in an "evaluate once" (copy-in, copy-out) semantics. 203 /// See 10.2.3.2p3, 10.2.3.2p13, etc. 204 class ImplicitIterSpace 205 : public StackableConstructExpr<llvm::SmallVector<FrontEndExpr>> { 206 public: 207 using Base = StackableConstructExpr<llvm::SmallVector<FrontEndExpr>>; 208 using FrontEndMaskExpr = FrontEndExpr; 209 210 friend llvm::raw_ostream &operator<<(llvm::raw_ostream &, 211 const ImplicitIterSpace &); 212 213 LLVM_DUMP_METHOD void dump() const; 214 215 void append(FrontEndMaskExpr e) { 216 assert(!empty()); 217 getMasks().back().push_back(e); 218 } 219 220 llvm::SmallVector<FrontEndMaskExpr> getExprs() const { 221 llvm::SmallVector<FrontEndMaskExpr> maskList = getMasks()[0]; 222 for (size_t i = 1, d = getMasks().size(); i < d; ++i) 223 maskList.append(getMasks()[i].begin(), getMasks()[i].end()); 224 return maskList; 225 } 226 227 /// Add a variable binding, `var`, along with its shape for the mask 228 /// expression `exp`. 229 void addMaskVariable(FrontEndExpr exp, mlir::Value var, mlir::Value shape, 230 mlir::Value header) { 231 maskVarMap.try_emplace(exp, std::make_tuple(var, shape, header)); 232 } 233 234 /// Lookup the variable corresponding to the temporary buffer that contains 235 /// the mask array expression results. 236 mlir::Value lookupMaskVariable(FrontEndExpr exp) { 237 return std::get<0>(maskVarMap.lookup(exp)); 238 } 239 240 /// Lookup the variable containing the shape vector for the mask array 241 /// expression results. 242 mlir::Value lookupMaskShapeBuffer(FrontEndExpr exp) { 243 return std::get<1>(maskVarMap.lookup(exp)); 244 } 245 246 mlir::Value lookupMaskHeader(FrontEndExpr exp) { 247 return std::get<2>(maskVarMap.lookup(exp)); 248 } 249 250 // Stack of WHERE constructs, each building a list of mask expressions. 251 llvm::SmallVector<llvm::SmallVector<FrontEndMaskExpr>> &getMasks() { 252 return stack; 253 } 254 const llvm::SmallVector<llvm::SmallVector<FrontEndMaskExpr>> & 255 getMasks() const { 256 return stack; 257 } 258 259 // Cleanup at the end of a WHERE statement or construct. 260 void shrinkStack() { 261 Base::shrinkStack(); 262 if (stack.empty()) 263 maskVarMap.clear(); 264 } 265 266 private: 267 llvm::DenseMap<FrontEndExpr, 268 std::tuple<mlir::Value, mlir::Value, mlir::Value>> 269 maskVarMap; 270 }; 271 272 class ExplicitIterSpace; 273 llvm::raw_ostream &operator<<(llvm::raw_ostream &, const ExplicitIterSpace &); 274 275 /// Create all the array_load ops for the explicit iteration space context. The 276 /// nest of FORALLs must have been analyzed a priori. 277 void createArrayLoads(AbstractConverter &converter, ExplicitIterSpace &esp, 278 SymMap &symMap); 279 280 /// Create the array_merge_store ops after the explicit iteration space context 281 /// is conmpleted. 282 void createArrayMergeStores(AbstractConverter &converter, 283 ExplicitIterSpace &esp); 284 using ExplicitSpaceArrayBases = 285 std::variant<FrontEndSymbol, const evaluate::Component *, 286 const evaluate::ArrayRef *>; 287 288 unsigned getHashValue(const ExplicitSpaceArrayBases &x); 289 bool isEqual(const ExplicitSpaceArrayBases &x, 290 const ExplicitSpaceArrayBases &y); 291 292 } // namespace Fortran::lower 293 294 namespace llvm { 295 template <> 296 struct DenseMapInfo<Fortran::lower::ExplicitSpaceArrayBases> { 297 static inline Fortran::lower::ExplicitSpaceArrayBases getEmptyKey() { 298 return reinterpret_cast<Fortran::lower::FrontEndSymbol>(~0); 299 } 300 static inline Fortran::lower::ExplicitSpaceArrayBases getTombstoneKey() { 301 return reinterpret_cast<Fortran::lower::FrontEndSymbol>(~0 - 1); 302 } 303 static unsigned 304 getHashValue(const Fortran::lower::ExplicitSpaceArrayBases &v) { 305 return Fortran::lower::getHashValue(v); 306 } 307 static bool isEqual(const Fortran::lower::ExplicitSpaceArrayBases &lhs, 308 const Fortran::lower::ExplicitSpaceArrayBases &rhs) { 309 return Fortran::lower::isEqual(lhs, rhs); 310 } 311 }; 312 } // namespace llvm 313 314 namespace Fortran::lower { 315 /// Fortran also allows arrays to be evaluated under constructs which allow the 316 /// user to explicitly specify the iteration space using concurrent-control 317 /// expressions. These constructs allow the user to define both an iteration 318 /// space and explicit access vectors on arrays. These need not be isomorphic. 319 /// The explicit iteration spaces may be conditionalized (conjunctively) with an 320 /// "and" structure and may be found in FORALL (and DO CONCURRENT) constructs. 321 /// 322 /// This class is used in the bridge to collect a stack of lists of 323 /// concurrent-control expressions to be used to generate the iteration space 324 /// and associated masks (if any) for a set of nested FORALL constructs around 325 /// assignment and WHERE constructs. 326 class ExplicitIterSpace { 327 public: 328 using IterSpaceDim = 329 std::tuple<FrontEndSymbol, FrontEndExpr, FrontEndExpr, FrontEndExpr>; 330 using ConcurrentSpec = 331 std::pair<llvm::SmallVector<IterSpaceDim>, FrontEndExpr>; 332 using ArrayBases = ExplicitSpaceArrayBases; 333 334 friend void createArrayLoads(AbstractConverter &converter, 335 ExplicitIterSpace &esp, SymMap &symMap); 336 friend void createArrayMergeStores(AbstractConverter &converter, 337 ExplicitIterSpace &esp); 338 339 /// Is a FORALL context presently active? 340 /// If we are lowering constructs/statements nested within a FORALL, then a 341 /// FORALL context is active. 342 bool isActive() const { return forallContextOpen != 0; } 343 344 /// Get the statement context. 345 StatementContext &stmtContext() { return stmtCtx; } 346 347 //===--------------------------------------------------------------------===// 348 // Analysis support 349 //===--------------------------------------------------------------------===// 350 351 /// Open a new construct. The analysis phase starts here. 352 void pushLevel(); 353 354 /// Close the construct. 355 void popLevel(); 356 357 /// Add new concurrent header control variable symbol. 358 void addSymbol(FrontEndSymbol sym); 359 360 /// Collect array bases from the expression, `x`. 361 void exprBase(FrontEndExpr x, bool lhs); 362 363 /// Called at the end of a assignment statement. 364 void endAssign(); 365 366 /// Return all the active control variables on the stack. 367 llvm::SmallVector<FrontEndSymbol> collectAllSymbols(); 368 369 //===--------------------------------------------------------------------===// 370 // Code gen support 371 //===--------------------------------------------------------------------===// 372 373 /// Enter a FORALL context. 374 void enter() { forallContextOpen++; } 375 376 /// Leave a FORALL context. 377 void leave(); 378 379 void pushLoopNest(std::function<void()> lambda) { 380 ccLoopNest.push_back(lambda); 381 } 382 383 /// Get the inner arguments that correspond to the output arrays. 384 mlir::ValueRange getInnerArgs() const { return innerArgs; } 385 386 /// Set the inner arguments for the next loop level. 387 void setInnerArgs(llvm::ArrayRef<mlir::BlockArgument> args) { 388 innerArgs.clear(); 389 for (auto &arg : args) 390 innerArgs.push_back(arg); 391 } 392 393 /// Reset the outermost `array_load` arguments to the loop nest. 394 void resetInnerArgs() { innerArgs = initialArgs; } 395 396 /// Capture the current outermost loop. 397 void setOuterLoop(fir::DoLoopOp loop) { 398 clearLoops(); 399 outerLoop = loop; 400 } 401 402 /// Sets the inner loop argument at position \p offset to \p val. 403 void setInnerArg(size_t offset, mlir::Value val) { 404 assert(offset < innerArgs.size()); 405 innerArgs[offset] = val; 406 } 407 408 /// Get the types of the output arrays. 409 llvm::SmallVector<mlir::Type> innerArgTypes() const { 410 llvm::SmallVector<mlir::Type> result; 411 for (auto &arg : innerArgs) 412 result.push_back(arg.getType()); 413 return result; 414 } 415 416 /// Create a binding between an Ev::Expr node pointer and a fir::array_load 417 /// op. This bindings will be used when generating the IR. 418 void bindLoad(ArrayBases base, fir::ArrayLoadOp load) { 419 loadBindings.try_emplace(std::move(base), load); 420 } 421 422 fir::ArrayLoadOp findBinding(const ArrayBases &base) { 423 return loadBindings.lookup(base); 424 } 425 426 /// `load` must be a LHS array_load. Returns `std::nullopt` on error. 427 std::optional<size_t> findArgPosition(fir::ArrayLoadOp load); 428 429 bool isLHS(fir::ArrayLoadOp load) { 430 return findArgPosition(load).has_value(); 431 } 432 433 /// `load` must be a LHS array_load. Determine the threaded inner argument 434 /// corresponding to this load. 435 mlir::Value findArgumentOfLoad(fir::ArrayLoadOp load) { 436 if (auto opt = findArgPosition(load)) 437 return innerArgs[*opt]; 438 llvm_unreachable("array load argument not found"); 439 } 440 441 size_t argPosition(mlir::Value arg) { 442 for (auto i : llvm::enumerate(innerArgs)) 443 if (arg == i.value()) 444 return i.index(); 445 llvm_unreachable("inner argument value was not found"); 446 } 447 448 std::optional<fir::ArrayLoadOp> getLhsLoad(size_t i) { 449 assert(i < lhsBases.size()); 450 if (lhsBases[counter]) 451 return findBinding(*lhsBases[counter]); 452 return std::nullopt; 453 } 454 455 /// Return the outermost loop in this FORALL nest. 456 fir::DoLoopOp getOuterLoop() { 457 assert(outerLoop.has_value()); 458 return *outerLoop; 459 } 460 461 /// Return the statement context for the entire, outermost FORALL construct. 462 StatementContext &outermostContext() { return outerContext; } 463 464 /// Generate the explicit loop nest. 465 void genLoopNest() { 466 for (auto &lambda : ccLoopNest) 467 lambda(); 468 } 469 470 /// Clear the array_load bindings. 471 void resetBindings() { loadBindings.clear(); } 472 473 /// Get the current counter value. 474 std::size_t getCounter() const { return counter; } 475 476 /// Increment the counter value to the next assignment statement. 477 void incrementCounter() { counter++; } 478 479 bool isOutermostForall() const { 480 assert(forallContextOpen); 481 return forallContextOpen == 1; 482 } 483 484 void attachLoopCleanup(std::function<void(fir::FirOpBuilder &builder)> fn) { 485 if (!loopCleanup) { 486 loopCleanup = fn; 487 return; 488 } 489 std::function<void(fir::FirOpBuilder &)> oldFn = *loopCleanup; 490 loopCleanup = [=](fir::FirOpBuilder &builder) { 491 oldFn(builder); 492 fn(builder); 493 }; 494 } 495 496 // LLVM standard dump method. 497 LLVM_DUMP_METHOD void dump() const; 498 499 // Pretty-print. 500 friend llvm::raw_ostream &operator<<(llvm::raw_ostream &, 501 const ExplicitIterSpace &); 502 503 /// Finalize the current body statement context. 504 void finalizeContext() { stmtCtx.finalizeAndReset(); } 505 506 void appendLoops(const llvm::SmallVector<fir::DoLoopOp> &loops) { 507 loopStack.push_back(loops); 508 } 509 510 void clearLoops() { loopStack.clear(); } 511 512 llvm::SmallVector<llvm::SmallVector<fir::DoLoopOp>> getLoopStack() const { 513 return loopStack; 514 } 515 516 private: 517 /// Cleanup the analysis results. 518 void conditionalCleanup(); 519 520 StatementContext outerContext; 521 522 // A stack of lists of front-end symbols. 523 llvm::SmallVector<llvm::SmallVector<FrontEndSymbol>> symbolStack; 524 llvm::SmallVector<std::optional<ArrayBases>> lhsBases; 525 llvm::SmallVector<llvm::SmallVector<ArrayBases>> rhsBases; 526 llvm::DenseMap<ArrayBases, fir::ArrayLoadOp> loadBindings; 527 528 // Stack of lambdas to create the loop nest. 529 llvm::SmallVector<std::function<void()>> ccLoopNest; 530 531 // Assignment statement context (inside the loop nest). 532 StatementContext stmtCtx; 533 llvm::SmallVector<mlir::Value> innerArgs; 534 llvm::SmallVector<mlir::Value> initialArgs; 535 std::optional<fir::DoLoopOp> outerLoop; 536 llvm::SmallVector<llvm::SmallVector<fir::DoLoopOp>> loopStack; 537 std::optional<std::function<void(fir::FirOpBuilder &)>> loopCleanup; 538 std::size_t forallContextOpen = 0; 539 std::size_t counter = 0; 540 }; 541 542 /// Is there a Symbol in common between the concurrent header set and the set 543 /// of symbols in the expression? 544 template <typename A> 545 bool symbolSetsIntersect(llvm::ArrayRef<FrontEndSymbol> ctrlSet, 546 const A &exprSyms) { 547 for (const auto &sym : exprSyms) 548 if (llvm::is_contained(ctrlSet, &sym.get())) 549 return true; 550 return false; 551 } 552 553 /// Determine if the subscript expression symbols from an Ev::ArrayRef 554 /// intersects with the set of concurrent control symbols, `ctrlSet`. 555 template <typename A> 556 bool symbolsIntersectSubscripts(llvm::ArrayRef<FrontEndSymbol> ctrlSet, 557 const A &subscripts) { 558 for (auto &sub : subscripts) { 559 if (const auto *expr = 560 std::get_if<evaluate::IndirectSubscriptIntegerExpr>(&sub.u)) 561 if (symbolSetsIntersect(ctrlSet, evaluate::CollectSymbols(expr->value()))) 562 return true; 563 } 564 return false; 565 } 566 567 } // namespace Fortran::lower 568 569 #endif // FORTRAN_LOWER_ITERATIONSPACE_H 570