1 //===- LowerHLFIROrderedAssignments.cpp - Lower HLFIR ordered assignments -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // This file defines a pass to lower HLFIR ordered assignments. 9 // Ordered assignments are all the operations with the 10 // OrderedAssignmentTreeOpInterface that implements user defined assignments, 11 // assignment to vector subscripted entities, and assignments inside forall and 12 // where. 13 // The pass lowers these operations to regular hlfir.assign, loops and, if 14 // needed, introduces temporary storage to fulfill Fortran semantics. 15 // 16 // For each rewrite, an analysis builds an evaluation schedule, and then the 17 // new code is generated by following the evaluation schedule. 18 //===----------------------------------------------------------------------===// 19 20 #include "ScheduleOrderedAssignments.h" 21 #include "flang/Optimizer/Builder/FIRBuilder.h" 22 #include "flang/Optimizer/Builder/HLFIRTools.h" 23 #include "flang/Optimizer/Builder/TemporaryStorage.h" 24 #include "flang/Optimizer/Builder/Todo.h" 25 #include "flang/Optimizer/Dialect/Support/FIRContext.h" 26 #include "flang/Optimizer/HLFIR/Passes.h" 27 #include "mlir/IR/Dominance.h" 28 #include "mlir/IR/IRMapping.h" 29 #include "mlir/Transforms/DialectConversion.h" 30 #include "llvm/ADT/SmallSet.h" 31 #include "llvm/ADT/TypeSwitch.h" 32 #include "llvm/Support/Debug.h" 33 34 namespace hlfir { 35 #define GEN_PASS_DEF_LOWERHLFIRORDEREDASSIGNMENTS 36 #include "flang/Optimizer/HLFIR/Passes.h.inc" 37 } // namespace hlfir 38 39 #define DEBUG_TYPE "flang-ordered-assignment" 40 41 // Test option only to test the scheduling part only (operations are erased 42 // without codegen). 
The only goal is to allow printing and testing the debug 43 // info. 44 static llvm::cl::opt<bool> dbgScheduleOnly( 45 "flang-dbg-order-assignment-schedule-only", 46 llvm::cl::desc("Only run ordered assignment scheduling with no codegen"), 47 llvm::cl::init(false)); 48 49 namespace { 50 51 /// Structure that represents a masked expression being lowered. Masked 52 /// expressions are any expressions inside an hlfir.where. As described in 53 /// Fortran 2018 section 10.2.3.2, the evaluation of the elemental parts of such 54 /// expressions must be masked, while the evaluation of none elemental parts 55 /// must not be masked. This structure analyzes the region evaluating the 56 /// expression and allows splitting the generation of the none elemental part 57 /// from the elemental part. 58 struct MaskedArrayExpr { 59 MaskedArrayExpr(mlir::Location loc, mlir::Region ®ion, 60 bool isOuterMaskExpr); 61 62 /// Generate the none elemental part. Must be called outside of the 63 /// loops created for the WHERE construct. 64 void generateNoneElementalPart(fir::FirOpBuilder &builder, 65 mlir::IRMapping &mapper); 66 67 /// Methods below can only be called once generateNoneElementalPart has been 68 /// called. 69 70 /// Return the shape of the expression. 71 mlir::Value generateShape(fir::FirOpBuilder &builder, 72 mlir::IRMapping &mapper); 73 /// Return the value of an element value for this expression given the current 74 /// where loop indices. 75 mlir::Value generateElementalParts(fir::FirOpBuilder &builder, 76 mlir::ValueRange oneBasedIndices, 77 mlir::IRMapping &mapper); 78 /// Generate the cleanup for the none elemental parts, if any. This must be 79 /// called after the loops created for the WHERE construct. 80 void generateNoneElementalCleanupIfAny(fir::FirOpBuilder &builder, 81 mlir::IRMapping &mapper); 82 83 /// Helper to clone the clean-ups of the masked expr region terminator. 
84 /// This is called outside of the loops for the initial mask, and inside 85 /// the loops for the other masked expressions. 86 mlir::Operation *generateMaskedExprCleanUps(fir::FirOpBuilder &builder, 87 mlir::IRMapping &mapper); 88 89 mlir::Location loc; 90 mlir::Region ®ion; 91 /// Set of operations that form the elemental parts of the 92 /// expression evaluation. These are the hlfir.elemental and 93 /// hlfir.elemental_addr that form the elemental tree producing 94 /// the expression value. hlfir.elemental that produce values 95 /// used inside transformational operations are not part of this set. 96 llvm::SmallSet<mlir::Operation *, 4> elementalParts{}; 97 /// Was generateNoneElementalPart called? 98 bool noneElementalPartWasGenerated = false; 99 /// Is this expression the mask expression of the outer where statement? 100 /// It is special because its evaluation is not masked by anything yet. 101 bool isOuterMaskExpr = false; 102 }; 103 } // namespace 104 105 namespace { 106 /// Structure that visits an ordered assignment tree and generates code for 107 /// it according to a schedule. 108 class OrderedAssignmentRewriter { 109 public: 110 OrderedAssignmentRewriter(fir::FirOpBuilder &builder, 111 hlfir::OrderedAssignmentTreeOpInterface root) 112 : builder{builder}, root{root} {} 113 114 /// Generate code for the current run of the schedule. 115 void lowerRun(hlfir::Run &run) { 116 currentRun = &run; 117 walk(root); 118 currentRun = nullptr; 119 assert(constructStack.empty() && "must exit constructs after a run"); 120 mapper.clear(); 121 savedInCurrentRunBeforeUse.clear(); 122 } 123 124 /// After all run have been lowered, clean-up all the temporary 125 /// storage that were created (do not call final routines). 
126 void cleanupSavedEntities() { 127 for (auto &temp : savedEntities) 128 temp.second.destroy(root.getLoc(), builder); 129 } 130 131 /// Lowered value for an expression, and the original hlfir.yield if any 132 /// clean-up needs to be cloned after usage. 133 using ValueAndCleanUp = std::pair<mlir::Value, std::optional<hlfir::YieldOp>>; 134 135 private: 136 /// Walk the part of an order assignment tree node that needs 137 /// to be evaluated in the current run. 138 void walk(hlfir::OrderedAssignmentTreeOpInterface node); 139 140 /// Generate code when entering a given ordered assignment node. 141 void pre(hlfir::ForallOp forallOp); 142 void pre(hlfir::ForallIndexOp); 143 void pre(hlfir::ForallMaskOp); 144 void pre(hlfir::WhereOp whereOp); 145 void pre(hlfir::ElseWhereOp elseWhereOp); 146 void pre(hlfir::RegionAssignOp); 147 148 /// Generate code when leaving a given ordered assignment node. 149 void post(hlfir::ForallOp); 150 void post(hlfir::ForallMaskOp); 151 void post(hlfir::WhereOp); 152 void post(hlfir::ElseWhereOp); 153 /// Enter (and maybe create) the fir.if else block of an ElseWhereOp, 154 /// but do not generate the elswhere mask or the new fir.if. 155 void enterElsewhere(hlfir::ElseWhereOp); 156 157 /// Are there any leaf region in the node that must be saved in the current 158 /// run? 159 bool mustSaveRegionIn( 160 hlfir::OrderedAssignmentTreeOpInterface node, 161 llvm::SmallVectorImpl<hlfir::SaveEntity> &saveEntities) const; 162 /// Should this node be evaluated in the current run? Saving a region in a 163 /// node does not imply the node needs to be evaluated. 164 bool 165 isRequiredInCurrentRun(hlfir::OrderedAssignmentTreeOpInterface node) const; 166 167 /// Generate a scalar value yielded by an ordered assignment tree region. 168 /// If the value was not saved in a previous run, this clone the region 169 /// code, except the final yield, at the current execution point. 
170 /// If the value was saved in a previous run, this fetches the saved value 171 /// from the temporary storage and returns the value. 172 /// Inside Forall, the value will be hoisted outside of the forall loops if 173 /// it does not depend on the forall indices. 174 /// An optional type can be provided to get a value from a specific type 175 /// (the cast will be hoisted if the computation is hoisted). 176 mlir::Value generateYieldedScalarValue( 177 mlir::Region ®ion, 178 std::optional<mlir::Type> castToType = std::nullopt); 179 180 /// Generate an entity yielded by an ordered assignment tree region, and 181 /// optionally return the (uncloned) yield if there is any clean-up that 182 /// should be done after using the entity. Like, generateYieldedScalarValue, 183 /// this will return the saved value if the region was saved in a previous 184 /// run. 185 ValueAndCleanUp 186 generateYieldedEntity(mlir::Region ®ion, 187 std::optional<mlir::Type> castToType = std::nullopt); 188 189 struct LhsValueAndCleanUp { 190 mlir::Value lhs; 191 std::optional<hlfir::YieldOp> elementalCleanup; 192 mlir::Region *nonElementalCleanup = nullptr; 193 std::optional<hlfir::LoopNest> vectorSubscriptLoopNest; 194 std::optional<mlir::Value> vectorSubscriptShape; 195 }; 196 197 /// Generate the left-hand side. If the left-hand side is vector 198 /// subscripted (hlfir.elemental_addr), this will create a loop nest 199 /// (unless it was already created by a WHERE mask) and return the 200 /// element address. 201 LhsValueAndCleanUp 202 generateYieldedLHS(mlir::Location loc, mlir::Region &lhsRegion, 203 std::optional<hlfir::Entity> loweredRhs = std::nullopt); 204 205 /// If \p maybeYield is present and has a clean-up, generate the clean-up 206 /// at the current insertion point (by cloning). 207 void generateCleanupIfAny(std::optional<hlfir::YieldOp> maybeYield); 208 void generateCleanupIfAny(mlir::Region *cleanupRegion); 209 210 /// Generate a masked entity. 
This can only be called when whereLoopNest was 211 /// set (When an hlfir.where is being visited). 212 /// This method returns the scalar element (that may have been previously 213 /// saved) for the current indices inside the where loop. 214 mlir::Value generateMaskedEntity(mlir::Location loc, mlir::Region ®ion) { 215 MaskedArrayExpr maskedExpr(loc, region, /*isOuterMaskExpr=*/!whereLoopNest); 216 return generateMaskedEntity(maskedExpr); 217 } 218 mlir::Value generateMaskedEntity(MaskedArrayExpr &maskedExpr); 219 220 /// Create a fir.if at the current position inside the where loop nest 221 /// given the element value of a mask. 222 void generateMaskIfOp(mlir::Value cdt); 223 224 /// Save a value for subsequent runs. 225 void generateSaveEntity(hlfir::SaveEntity savedEntity, 226 bool willUseSavedEntityInSameRun); 227 void saveLeftHandSide(hlfir::SaveEntity savedEntity, 228 hlfir::RegionAssignOp regionAssignOp); 229 230 /// Get a value if it was saved in this run or a previous run. Returns 231 /// nullopt if it has not been saved. 232 std::optional<ValueAndCleanUp> getIfSaved(mlir::Region ®ion); 233 234 /// Generate code before the loop nest for the current run, if any. 235 void doBeforeLoopNest(const std::function<void()> &callback) { 236 if (constructStack.empty()) { 237 callback(); 238 return; 239 } 240 auto insertionPoint = builder.saveInsertionPoint(); 241 builder.setInsertionPoint(constructStack[0]); 242 callback(); 243 builder.restoreInsertionPoint(insertionPoint); 244 } 245 246 /// Can the current loop nest iteration number be computed? For simplicity, 247 /// this is true if and only if all the bounds and steps of the fir.do_loop 248 /// nest dominates the outer loop. The argument is filled with the current 249 /// loop nest on success. 
250 bool currentLoopNestIterationNumberCanBeComputed( 251 llvm::SmallVectorImpl<fir::DoLoopOp> &loopNest); 252 253 template <typename T> 254 fir::factory::TemporaryStorage *insertSavedEntity(mlir::Region ®ion, 255 T &&temp) { 256 auto inserted = 257 savedEntities.insert(std::make_pair(®ion, std::forward<T>(temp))); 258 assert(inserted.second && "temp must have been emplaced"); 259 return &inserted.first->second; 260 } 261 262 fir::FirOpBuilder &builder; 263 264 /// Map containing the mapping between the original order assignment tree 265 /// operations and the operations that have been cloned in the current run. 266 /// It is reset between two runs. 267 mlir::IRMapping mapper; 268 /// Dominance info is used to determine if inner loop bounds are all computed 269 /// before outer loop for the current loop. It does not need to be reset 270 /// between runs. 271 mlir::DominanceInfo dominanceInfo; 272 /// Construct stack in the current run. This allows setting back the insertion 273 /// point correctly when leaving a node that requires a fir.do_loop or fir.if 274 /// operation. 275 llvm::SmallVector<mlir::Operation *> constructStack; 276 /// Current where loop nest, if any. 277 std::optional<hlfir::LoopNest> whereLoopNest; 278 279 /// Map of temporary storage to keep track of saved entity once the run 280 /// that saves them has been lowered. It is kept in-between runs. 281 /// llvm::MapVector is used to guarantee deterministic order 282 /// of iterating through savedEntities (e.g. for generating 283 /// destruction code for the temporary storages). 284 llvm::MapVector<mlir::Region *, fir::factory::TemporaryStorage> savedEntities; 285 /// Map holding the values that were saved in the current run and that also 286 /// need to be used (because their construct will be visited). It is reset 287 /// after each run. 
It avoids having to store and fetch in the temporary 288 /// during the same run, which would require the temporary to have different 289 /// fetching and storing counters. 290 llvm::DenseMap<mlir::Region *, ValueAndCleanUp> savedInCurrentRunBeforeUse; 291 292 /// Root of the order assignment tree being lowered. 293 hlfir::OrderedAssignmentTreeOpInterface root; 294 /// Pointer to the current run of the schedule being lowered. 295 hlfir::Run *currentRun = nullptr; 296 297 /// When allocating temporary storage inlined, indicate if the storage should 298 /// be heap or stack allocated. Temporary allocated with the runtime are heap 299 /// allocated by the runtime. 300 bool allocateOnHeap = true; 301 }; 302 } // namespace 303 304 void OrderedAssignmentRewriter::walk( 305 hlfir::OrderedAssignmentTreeOpInterface node) { 306 bool mustVisit = 307 isRequiredInCurrentRun(node) || mlir::isa<hlfir::ForallIndexOp>(node); 308 llvm::SmallVector<hlfir::SaveEntity> saveEntities; 309 mlir::Operation *nodeOp = node.getOperation(); 310 if (mustSaveRegionIn(node, saveEntities)) { 311 mlir::IRRewriter::InsertPoint insertionPoint; 312 if (auto elseWhereOp = mlir::dyn_cast<hlfir::ElseWhereOp>(nodeOp)) { 313 // ElseWhere mask to save must be evaluated inside the fir.if else 314 // for the previous where/elsewehere (its evaluation must be 315 // masked by the "pending control mask"). 
316 insertionPoint = builder.saveInsertionPoint(); 317 enterElsewhere(elseWhereOp); 318 } 319 for (hlfir::SaveEntity saveEntity : saveEntities) 320 generateSaveEntity(saveEntity, mustVisit); 321 if (insertionPoint.isSet()) 322 builder.restoreInsertionPoint(insertionPoint); 323 } 324 if (mustVisit) { 325 llvm::TypeSwitch<mlir::Operation *, void>(nodeOp) 326 .Case<hlfir::ForallOp, hlfir::ForallIndexOp, hlfir::ForallMaskOp, 327 hlfir::RegionAssignOp, hlfir::WhereOp, hlfir::ElseWhereOp>( 328 [&](auto concreteOp) { pre(concreteOp); }) 329 .Default([](auto) {}); 330 if (auto *body = node.getSubTreeRegion()) { 331 for (mlir::Operation &op : body->getOps()) 332 if (auto subNode = 333 mlir::dyn_cast<hlfir::OrderedAssignmentTreeOpInterface>(op)) 334 walk(subNode); 335 llvm::TypeSwitch<mlir::Operation *, void>(nodeOp) 336 .Case<hlfir::ForallOp, hlfir::ForallMaskOp, hlfir::WhereOp, 337 hlfir::ElseWhereOp>([&](auto concreteOp) { post(concreteOp); }) 338 .Default([](auto) {}); 339 } 340 } 341 } 342 343 void OrderedAssignmentRewriter::pre(hlfir::ForallOp forallOp) { 344 /// Create a fir.do_loop given the hlfir.forall control values. 
345 mlir::Type idxTy = builder.getIndexType(); 346 mlir::Location loc = forallOp.getLoc(); 347 mlir::Value lb = generateYieldedScalarValue(forallOp.getLbRegion(), idxTy); 348 mlir::Value ub = generateYieldedScalarValue(forallOp.getUbRegion(), idxTy); 349 mlir::Value step; 350 if (forallOp.getStepRegion().empty()) { 351 auto insertionPoint = builder.saveInsertionPoint(); 352 if (!constructStack.empty()) 353 builder.setInsertionPoint(constructStack[0]); 354 step = builder.createIntegerConstant(loc, idxTy, 1); 355 if (!constructStack.empty()) 356 builder.restoreInsertionPoint(insertionPoint); 357 } else { 358 step = generateYieldedScalarValue(forallOp.getStepRegion(), idxTy); 359 } 360 auto doLoop = builder.create<fir::DoLoopOp>(loc, lb, ub, step); 361 builder.setInsertionPointToStart(doLoop.getBody()); 362 mlir::Value oldIndex = forallOp.getForallIndexValue(); 363 mlir::Value newIndex = 364 builder.createConvert(loc, oldIndex.getType(), doLoop.getInductionVar()); 365 mapper.map(oldIndex, newIndex); 366 constructStack.push_back(doLoop); 367 } 368 369 void OrderedAssignmentRewriter::post(hlfir::ForallOp) { 370 assert(!constructStack.empty() && "must contain a loop"); 371 builder.setInsertionPointAfter(constructStack.pop_back_val()); 372 } 373 374 void OrderedAssignmentRewriter::pre(hlfir::ForallIndexOp forallIndexOp) { 375 mlir::Location loc = forallIndexOp.getLoc(); 376 mlir::Type intTy = fir::unwrapRefType(forallIndexOp.getType()); 377 mlir::Value indexVar = 378 builder.createTemporary(loc, intTy, forallIndexOp.getName()); 379 mlir::Value newVal = mapper.lookupOrDefault(forallIndexOp.getIndex()); 380 builder.createStoreWithConvert(loc, newVal, indexVar); 381 mapper.map(forallIndexOp, indexVar); 382 } 383 384 void OrderedAssignmentRewriter::pre(hlfir::ForallMaskOp forallMaskOp) { 385 mlir::Location loc = forallMaskOp.getLoc(); 386 mlir::Value mask = generateYieldedScalarValue(forallMaskOp.getMaskRegion(), 387 builder.getI1Type()); 388 auto ifOp = 
builder.create<fir::IfOp>(loc, std::nullopt, mask, false); 389 builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); 390 constructStack.push_back(ifOp); 391 } 392 393 void OrderedAssignmentRewriter::post(hlfir::ForallMaskOp forallMaskOp) { 394 assert(!constructStack.empty() && "must contain an ifop"); 395 builder.setInsertionPointAfter(constructStack.pop_back_val()); 396 } 397 398 /// Convert an entity to the type of a given mold. 399 /// This is intended to help with cases where hlfir entity is a value while 400 /// it must be used as a variable or vice-versa. These mismatches may occur 401 /// between the type of user defined assignment block arguments and the actual 402 /// argument that was lowered for them. The actual may be an in-memory copy 403 /// while the block argument expects an hlfir.expr. 404 static hlfir::Entity 405 convertToMoldType(mlir::Location loc, fir::FirOpBuilder &builder, 406 hlfir::Entity input, hlfir::Entity mold, 407 llvm::SmallVectorImpl<hlfir::CleanupFunction> &cleanups) { 408 if (input.getType() == mold.getType()) 409 return input; 410 fir::FirOpBuilder *b = &builder; 411 if (input.isVariable() && mold.isValue()) { 412 if (fir::isa_trivial(mold.getType())) { 413 // fir.ref<T> to T. 414 mlir::Value load = builder.create<fir::LoadOp>(loc, input); 415 return hlfir::Entity{builder.createConvert(loc, mold.getType(), load)}; 416 } 417 // fir.ref<T> to hlfir.expr<T>. 418 mlir::Value asExpr = builder.create<hlfir::AsExprOp>(loc, input); 419 if (asExpr.getType() != mold.getType()) 420 TODO(loc, "hlfir.expr conversion"); 421 cleanups.emplace_back([=]() { b->create<hlfir::DestroyOp>(loc, asExpr); }); 422 return hlfir::Entity{asExpr}; 423 } 424 if (input.isValue() && mold.isVariable()) { 425 // T to fir.ref<T>, or hlfir.expr<T> to fir.ref<T>. 
426 hlfir::AssociateOp associate = hlfir::genAssociateExpr( 427 loc, builder, input, mold.getFortranElementType(), ".tmp.val2ref"); 428 cleanups.emplace_back( 429 [=]() { b->create<hlfir::EndAssociateOp>(loc, associate); }); 430 return hlfir::Entity{associate.getBase()}; 431 } 432 // Variable to Variable mismatch (e.g., fir.heap<T> vs fir.ref<T>), or value 433 // to Value mismatch (e.g. i1 vs fir.logical<4>). 434 if (mlir::isa<fir::BaseBoxType>(mold.getType()) && 435 !mlir::isa<fir::BaseBoxType>(input.getType())) { 436 // An entity may have have been saved without descriptor while the original 437 // value had a descriptor (e.g., it was not contiguous). 438 auto emboxed = hlfir::convertToBox(loc, builder, input, mold.getType()); 439 assert(!emboxed.second && "temp should already be in memory"); 440 input = hlfir::Entity{fir::getBase(emboxed.first)}; 441 } 442 return hlfir::Entity{builder.createConvert(loc, mold.getType(), input)}; 443 } 444 445 void OrderedAssignmentRewriter::pre(hlfir::RegionAssignOp regionAssignOp) { 446 mlir::Location loc = regionAssignOp.getLoc(); 447 std::optional<hlfir::LoopNest> elementalLoopNest; 448 auto [rhsValue, oldRhsYield] = 449 generateYieldedEntity(regionAssignOp.getRhsRegion()); 450 hlfir::Entity rhsEntity{rhsValue}; 451 LhsValueAndCleanUp loweredLhs = 452 generateYieldedLHS(loc, regionAssignOp.getLhsRegion(), rhsEntity); 453 hlfir::Entity lhsEntity{loweredLhs.lhs}; 454 if (loweredLhs.vectorSubscriptLoopNest) 455 rhsEntity = hlfir::getElementAt( 456 loc, builder, rhsEntity, 457 loweredLhs.vectorSubscriptLoopNest->oneBasedIndices); 458 if (!regionAssignOp.getUserDefinedAssignment().empty()) { 459 hlfir::Entity userAssignLhs{regionAssignOp.getUserAssignmentLhs()}; 460 hlfir::Entity userAssignRhs{regionAssignOp.getUserAssignmentRhs()}; 461 std::optional<hlfir::LoopNest> elementalLoopNest; 462 if (lhsEntity.isArray() && userAssignLhs.isScalar()) { 463 // Elemental assignment with array argument (the RHS cannot be an array 464 // if the 
LHS is not). 465 mlir::Value shape = hlfir::genShape(loc, builder, lhsEntity); 466 elementalLoopNest = hlfir::genLoopNest(loc, builder, shape); 467 builder.setInsertionPointToStart(elementalLoopNest->body); 468 lhsEntity = hlfir::getElementAt(loc, builder, lhsEntity, 469 elementalLoopNest->oneBasedIndices); 470 rhsEntity = hlfir::getElementAt(loc, builder, rhsEntity, 471 elementalLoopNest->oneBasedIndices); 472 } 473 474 llvm::SmallVector<hlfir::CleanupFunction, 2> argConversionCleanups; 475 lhsEntity = convertToMoldType(loc, builder, lhsEntity, userAssignLhs, 476 argConversionCleanups); 477 rhsEntity = convertToMoldType(loc, builder, rhsEntity, userAssignRhs, 478 argConversionCleanups); 479 mapper.map(userAssignLhs, lhsEntity); 480 mapper.map(userAssignRhs, rhsEntity); 481 for (auto &op : 482 regionAssignOp.getUserDefinedAssignment().front().without_terminator()) 483 (void)builder.clone(op, mapper); 484 for (auto &cleanupConversion : argConversionCleanups) 485 cleanupConversion(); 486 if (elementalLoopNest) 487 builder.setInsertionPointAfter(elementalLoopNest->outerOp); 488 } else { 489 // TODO: preserve allocatable assignment aspects for forall once 490 // they are conveyed in hlfir.region_assign. 
491 builder.create<hlfir::AssignOp>(loc, rhsEntity, lhsEntity); 492 } 493 generateCleanupIfAny(loweredLhs.elementalCleanup); 494 if (loweredLhs.vectorSubscriptLoopNest) 495 builder.setInsertionPointAfter(loweredLhs.vectorSubscriptLoopNest->outerOp); 496 generateCleanupIfAny(oldRhsYield); 497 generateCleanupIfAny(loweredLhs.nonElementalCleanup); 498 } 499 500 void OrderedAssignmentRewriter::generateMaskIfOp(mlir::Value cdt) { 501 mlir::Location loc = cdt.getLoc(); 502 cdt = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{cdt}); 503 cdt = builder.createConvert(loc, builder.getI1Type(), cdt); 504 auto ifOp = builder.create<fir::IfOp>(cdt.getLoc(), std::nullopt, cdt, 505 /*withElseRegion=*/false); 506 constructStack.push_back(ifOp.getOperation()); 507 builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); 508 } 509 510 void OrderedAssignmentRewriter::pre(hlfir::WhereOp whereOp) { 511 mlir::Location loc = whereOp.getLoc(); 512 if (!whereLoopNest) { 513 // This is the top-level WHERE. Start a loop nest iterating on the shape of 514 // the where mask. 515 if (auto maybeSaved = getIfSaved(whereOp.getMaskRegion())) { 516 // Use the saved value to get the shape and condition element. 517 hlfir::Entity savedMask{maybeSaved->first}; 518 mlir::Value shape = hlfir::genShape(loc, builder, savedMask); 519 whereLoopNest = hlfir::genLoopNest(loc, builder, shape); 520 constructStack.push_back(whereLoopNest->outerOp); 521 builder.setInsertionPointToStart(whereLoopNest->body); 522 mlir::Value cdt = hlfir::getElementAt(loc, builder, savedMask, 523 whereLoopNest->oneBasedIndices); 524 generateMaskIfOp(cdt); 525 if (maybeSaved->second) { 526 // If this is the same run as the one that saved the value, the clean-up 527 // was left-over to be done now. 
528 auto insertionPoint = builder.saveInsertionPoint(); 529 builder.setInsertionPointAfter(whereLoopNest->outerOp); 530 generateCleanupIfAny(maybeSaved->second); 531 builder.restoreInsertionPoint(insertionPoint); 532 } 533 return; 534 } 535 // The mask was not evaluated yet or can be safely re-evaluated. 536 MaskedArrayExpr mask(loc, whereOp.getMaskRegion(), 537 /*isOuterMaskExpr=*/true); 538 mask.generateNoneElementalPart(builder, mapper); 539 mlir::Value shape = mask.generateShape(builder, mapper); 540 whereLoopNest = hlfir::genLoopNest(loc, builder, shape); 541 constructStack.push_back(whereLoopNest->outerOp); 542 builder.setInsertionPointToStart(whereLoopNest->body); 543 mlir::Value cdt = generateMaskedEntity(mask); 544 generateMaskIfOp(cdt); 545 return; 546 } 547 // Where Loops have been already created by a parent WHERE. 548 // Generate a fir.if with the value of the current element of the mask 549 // inside the loops. The case where the mask was saved is handled in the 550 // generateYieldedScalarValue call. 551 mlir::Value cdt = generateYieldedScalarValue(whereOp.getMaskRegion()); 552 generateMaskIfOp(cdt); 553 } 554 555 void OrderedAssignmentRewriter::post(hlfir::WhereOp whereOp) { 556 assert(!constructStack.empty() && "must contain a fir.if"); 557 builder.setInsertionPointAfter(constructStack.pop_back_val()); 558 // If all where/elsewhere fir.if have been popped, this is the outer whereOp, 559 // and the where loop must be exited. 560 assert(!constructStack.empty() && "must contain a fir.do_loop or fir.if"); 561 if (mlir::isa<fir::DoLoopOp>(constructStack.back())) { 562 builder.setInsertionPointAfter(constructStack.pop_back_val()); 563 whereLoopNest.reset(); 564 } 565 } 566 567 void OrderedAssignmentRewriter::enterElsewhere(hlfir::ElseWhereOp elseWhereOp) { 568 // Create an "else" region for the current where/elsewhere fir.if. 
569 auto ifOp = mlir::dyn_cast<fir::IfOp>(constructStack.back()); 570 assert(ifOp && "must be an if"); 571 if (ifOp.getElseRegion().empty()) { 572 mlir::Location loc = elseWhereOp.getLoc(); 573 builder.createBlock(&ifOp.getElseRegion()); 574 auto end = builder.create<fir::ResultOp>(loc); 575 builder.setInsertionPoint(end); 576 } else { 577 builder.setInsertionPoint(&ifOp.getElseRegion().back().back()); 578 } 579 } 580 581 void OrderedAssignmentRewriter::pre(hlfir::ElseWhereOp elseWhereOp) { 582 enterElsewhere(elseWhereOp); 583 if (elseWhereOp.getMaskRegion().empty()) 584 return; 585 // Create new nested fir.if with elsewhere mask if any. 586 mlir::Value cdt = generateYieldedScalarValue(elseWhereOp.getMaskRegion()); 587 generateMaskIfOp(cdt); 588 } 589 590 void OrderedAssignmentRewriter::post(hlfir::ElseWhereOp elseWhereOp) { 591 // Exit ifOp that was created for the elseWhereOp mask, if any. 592 if (elseWhereOp.getMaskRegion().empty()) 593 return; 594 assert(!constructStack.empty() && "must contain a fir.if"); 595 builder.setInsertionPointAfter(constructStack.pop_back_val()); 596 } 597 598 /// Is this value a Forall index? 599 /// Forall index are block arguments of hlfir.forall body, or the result 600 /// of hlfir.forall_index. 
601 static bool isForallIndex(mlir::Value value) { 602 if (auto blockArg = mlir::dyn_cast<mlir::BlockArgument>(value)) { 603 if (mlir::Block *block = blockArg.getOwner()) 604 return block->isEntryBlock() && 605 mlir::isa_and_nonnull<hlfir::ForallOp>(block->getParentOp()); 606 return false; 607 } 608 return value.getDefiningOp<hlfir::ForallIndexOp>(); 609 } 610 611 static OrderedAssignmentRewriter::ValueAndCleanUp 612 castIfNeeded(mlir::Location loc, fir::FirOpBuilder &builder, 613 OrderedAssignmentRewriter::ValueAndCleanUp valueAndCleanUp, 614 std::optional<mlir::Type> castToType) { 615 if (!castToType.has_value()) 616 return valueAndCleanUp; 617 mlir::Value cast = 618 builder.createConvert(loc, *castToType, valueAndCleanUp.first); 619 return {cast, valueAndCleanUp.second}; 620 } 621 622 std::optional<OrderedAssignmentRewriter::ValueAndCleanUp> 623 OrderedAssignmentRewriter::getIfSaved(mlir::Region ®ion) { 624 mlir::Location loc = region.getParentOp()->getLoc(); 625 // If the region was saved in the same run, use the value that was evaluated 626 // instead of fetching the temp, and do clean-up, if any, that were delayed. 627 // This is done to avoid requiring the temporary stack to have different 628 // fetching and storing counters, and also because it produces slightly better 629 // code. 630 if (auto savedInSameRun = savedInCurrentRunBeforeUse.find(®ion); 631 savedInSameRun != savedInCurrentRunBeforeUse.end()) 632 return savedInSameRun->second; 633 // If the region was saved in a previous run, fetch the saved value. 
  // The entity was saved in a previous run: fetch it from its temporary
  // storage instead of re-evaluating the region. The fetch position is reset
  // before the current loop nest so values are fetched back in the order in
  // which they were saved.
  if (auto temp = savedEntities.find(&region); temp != savedEntities.end()) {
    doBeforeLoopNest([&]() { temp->second.resetFetchPosition(loc, builder); });
    // Fetched values carry no clean-up of their own.
    return ValueAndCleanUp{temp->second.fetch(loc, builder), std::nullopt};
  }
  return std::nullopt;
}

/// Return the hlfir.yield terminating \p region. The regions computing
/// entities in ordered assignments always end with an hlfir.yield (asserted
/// here).
static hlfir::YieldOp getYield(mlir::Region &region) {
  auto yield = mlir::dyn_cast_or_null<hlfir::YieldOp>(
      region.back().getOperations().back());
  assert(yield && "region computing entities must end with a YieldOp");
  return yield;
}

/// Lower the evaluation of \p region and return the yielded entity (converted
/// to \p castToType if provided) together with the hlfir.yield holding its
/// clean-up region, if any, so the caller can emit the clean-up after the
/// entity use.
OrderedAssignmentRewriter::ValueAndCleanUp
OrderedAssignmentRewriter::generateYieldedEntity(
    mlir::Region &region, std::optional<mlir::Type> castToType) {
  mlir::Location loc = region.getParentOp()->getLoc();
  // If the entity was saved in a previous run, fetch it instead of
  // re-evaluating the region.
  if (auto maybeValueAndCleanUp = getIfSaved(region))
    return castIfNeeded(loc, builder, *maybeValueAndCleanUp, castToType);
  // Otherwise, evaluate the region now.

  // Masked expression must not evaluate the elemental parts that are masked,
  // they have custom code generation.
  if (whereLoopNest.has_value()) {
    mlir::Value maskedValue = generateMaskedEntity(loc, region);
    return castIfNeeded(loc, builder, {maskedValue, std::nullopt}, castToType);
  }

  auto oldYield = getYield(region);
  // Inside Forall, scalars that do not depend on forall indices can be hoisted
  // here because their evaluation is required to only call pure procedures, and
  // if they depend on a variable previously assigned to in a forall assignment,
  // this assignment must have been scheduled in a previous run. Hoisting of
  // scalars is done here to help creating simple temporary storage if needed.
  // Inner forall bounds can often be hoisted, and this allows computing the
  // total number of iterations to create temporary storages.
  bool hoistComputation = false;
  if (fir::isa_trivial(oldYield.getEntity().getType()) &&
      !constructStack.empty()) {
    // Only hoist when no operation of the region uses a forall index.
    mlir::WalkResult walkResult =
        region.walk([&](mlir::Operation *op) -> mlir::WalkResult {
          if (llvm::any_of(op->getOperands(), [](mlir::Value value) {
                return isForallIndex(value);
              }))
            return mlir::WalkResult::interrupt();
          return mlir::WalkResult::advance();
        });
    hoistComputation = !walkResult.wasInterrupted();
  }
  auto insertionPoint = builder.saveInsertionPoint();
  if (hoistComputation)
    // Hoisted computations are emitted just before the outermost construct.
    builder.setInsertionPoint(constructStack[0]);

  // Clone all operations except the final hlfir.yield.
  assert(region.hasOneBlock() && "region must contain one block");
  for (auto &op : region.back().without_terminator())
    (void)builder.clone(op, mapper);
  // Get the value for the yielded entity, it may be the result of an operation
  // that was cloned, or it may be the same as the previous value if the yield
  // operand was created before the ordered assignment tree.
  mlir::Value newEntity = mapper.lookupOrDefault(oldYield.getEntity());
  if (castToType.has_value())
    newEntity =
        builder.createConvert(newEntity.getLoc(), *castToType, newEntity);

  if (hoistComputation) {
    // Hoisted trivial scalars clean-up can be done right away, the value is
    // in registers.
    generateCleanupIfAny(oldYield);
    builder.restoreInsertionPoint(insertionPoint);
    return {newEntity, std::nullopt};
  }
  if (oldYield.getCleanup().empty())
    return {newEntity, std::nullopt};
  // Return the yield so the caller emits the clean-up after using the entity.
  return {newEntity, oldYield};
}

/// Lower \p region and load the yielded entity as a trivial scalar value,
/// emitting any clean-up immediately (the value lives in registers).
mlir::Value OrderedAssignmentRewriter::generateYieldedScalarValue(
    mlir::Region &region, std::optional<mlir::Type> castToType) {
  mlir::Location loc = region.getParentOp()->getLoc();
  auto [value, maybeYield] = generateYieldedEntity(region, castToType);
  value = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{value});
  assert(fir::isa_trivial(value.getType()) && "not a trivial scalar value");
  generateCleanupIfAny(maybeYield);
  return value;
}

/// Lower the left-hand side region \p lhsRegion of an assignment. Deals with
/// vector subscripted designators (which require generating "elemental"
/// loops around the element address computation) and with LHS addresses that
/// were saved in a previous run. \p loweredRhs, when provided and an array,
/// can supply the shape for the elemental loop nest.
OrderedAssignmentRewriter::LhsValueAndCleanUp
OrderedAssignmentRewriter::generateYieldedLHS(
    mlir::Location loc, mlir::Region &lhsRegion,
    std::optional<hlfir::Entity> loweredRhs) {
  LhsValueAndCleanUp loweredLhs;
  hlfir::ElementalAddrOp elementalAddrLhs =
      mlir::dyn_cast<hlfir::ElementalAddrOp>(lhsRegion.back().back());
  if (auto temp = savedEntities.find(&lhsRegion); temp != savedEntities.end()) {
    // The LHS address was computed and saved in a previous run. Fetch it.
    doBeforeLoopNest([&]() { temp->second.resetFetchPosition(loc, builder); });
    if (elementalAddrLhs && !whereLoopNest) {
      // Vector subscripted designator address are saved element by element.
      // If no "elemental" loops have been created yet, the shape of the
      // RHS, if it is an array can be used, or the shape of the vector
      // subscripted designator must be retrieved to generate the "elemental"
      // loop nest.
      if (loweredRhs && loweredRhs->isArray()) {
        // The RHS shape can be used to create the elemental loops and avoid
        // saving the LHS shape.
        loweredLhs.vectorSubscriptShape =
            hlfir::genShape(loc, builder, *loweredRhs);
      } else {
        // If the shape cannot be retrieved from the RHS, it must have been
        // saved. Get it from the temporary.
        auto &vectorTmp =
            temp->second.cast<fir::factory::AnyVectorSubscriptStack>();
        loweredLhs.vectorSubscriptShape = vectorTmp.fetchShape(loc, builder);
      }
      loweredLhs.vectorSubscriptLoopNest = hlfir::genLoopNest(
          loc, builder, loweredLhs.vectorSubscriptShape.value());
      builder.setInsertionPointToStart(
          loweredLhs.vectorSubscriptLoopNest->body);
    }
    loweredLhs.lhs = temp->second.fetch(loc, builder);
    return loweredLhs;
  }
  // The LHS has not yet been evaluated and saved. Evaluate it now.
  if (elementalAddrLhs && !whereLoopNest) {
    // This is a vector subscripted entity. The address of elements must
    // be returned. If no "elemental" loops have been created for a WHERE,
    // create them now based on the vector subscripted designator shape.
    for (auto &op : lhsRegion.front().without_terminator())
      (void)builder.clone(op, mapper);
    loweredLhs.vectorSubscriptShape =
        mapper.lookupOrDefault(elementalAddrLhs.getShape());
    loweredLhs.vectorSubscriptLoopNest =
        hlfir::genLoopNest(loc, builder, *loweredLhs.vectorSubscriptShape,
                           !elementalAddrLhs.isOrdered());
    builder.setInsertionPointToStart(loweredLhs.vectorSubscriptLoopNest->body);
    // Map the elemental_addr block indices to the new loop indices before
    // cloning the elemental body inside the loops.
    mapper.map(elementalAddrLhs.getIndices(),
               loweredLhs.vectorSubscriptLoopNest->oneBasedIndices);
    for (auto &op : elementalAddrLhs.getBody().front().without_terminator())
      (void)builder.clone(op, mapper);
    loweredLhs.elementalCleanup = elementalAddrLhs.getYieldOp();
    loweredLhs.lhs =
        mapper.lookupOrDefault(loweredLhs.elementalCleanup->getEntity());
  } else {
    // This is a designator without vector subscripts. Generate it as
    // it is done for other entities.
    auto [lhs, yield] = generateYieldedEntity(lhsRegion);
    loweredLhs.lhs = lhs;
    if (yield && !yield->getCleanup().empty())
      loweredLhs.nonElementalCleanup = &yield->getCleanup();
  }
  return loweredLhs;
}

/// Generate the scalar element value of \p maskedExpr at the current position
/// inside the WHERE loop nest. The non elemental parts (and their clean-ups)
/// are emitted outside of the loops, once, if not already done.
mlir::Value
OrderedAssignmentRewriter::generateMaskedEntity(MaskedArrayExpr &maskedExpr) {
  assert(whereLoopNest.has_value() && "must be inside WHERE loop nest");
  auto insertionPoint = builder.saveInsertionPoint();
  if (!maskedExpr.noneElementalPartWasGenerated) {
    // Generate none elemental part before the where loops (but inside the
    // current forall loops if any).
    builder.setInsertionPoint(whereLoopNest->outerOp);
    maskedExpr.generateNoneElementalPart(builder, mapper);
  }
  // Generate the none elemental part cleanup after the where loops.
  builder.setInsertionPointAfter(whereLoopNest->outerOp);
  maskedExpr.generateNoneElementalCleanupIfAny(builder, mapper);
  // Generate the value of the current element for the masked expression
  // at the current insertion point (inside the where loops, and any fir.if
  // generated for previous masks).
  builder.restoreInsertionPoint(insertionPoint);
  mlir::Value scalar = maskedExpr.generateElementalParts(
      builder, whereLoopNest->oneBasedIndices, mapper);
  /// Generate cleanups for the elemental parts inside the loops (setting the
  /// location so that the assignment will be generated before the cleanups).
  if (!maskedExpr.isOuterMaskExpr)
    if (mlir::Operation *firstCleanup =
            maskedExpr.generateMaskedExprCleanUps(builder, mapper))
      builder.setInsertionPoint(firstCleanup);
  return scalar;
}

/// Clone the clean-up region of \p maybeYield, if any, at the current
/// insertion point.
void OrderedAssignmentRewriter::generateCleanupIfAny(
    std::optional<hlfir::YieldOp> maybeYield) {
  if (maybeYield.has_value())
    generateCleanupIfAny(&maybeYield->getCleanup());
}
/// Clone the operations of \p cleanupRegion (minus its terminator), if any,
/// at the current insertion point.
void OrderedAssignmentRewriter::generateCleanupIfAny(
    mlir::Region *cleanupRegion) {
  if (cleanupRegion && !cleanupRegion->empty()) {
    assert(cleanupRegion->hasOneBlock() && "region must contain one block");
    for (auto &op : cleanupRegion->back().without_terminator())
      builder.clone(op, mapper);
  }
}

/// Collect in \p saveEntities the SaveEntity actions of the current run whose
/// yield region directly belongs to \p node. Returns true if any were found.
bool OrderedAssignmentRewriter::mustSaveRegionIn(
    hlfir::OrderedAssignmentTreeOpInterface node,
    llvm::SmallVectorImpl<hlfir::SaveEntity> &saveEntities) const {
  for (auto &action : currentRun->actions)
    if (hlfir::SaveEntity *savedEntity =
            std::get_if<hlfir::SaveEntity>(&action))
      if (node.getOperation() == savedEntity->yieldRegion->getParentOp())
        saveEntities.push_back(*savedEntity);
  return !saveEntities.empty();
}

/// Must \p node be lowered (loops/masks created) for the actions of the
/// current run?
bool OrderedAssignmentRewriter::isRequiredInCurrentRun(
    hlfir::OrderedAssignmentTreeOpInterface node) const {
  // hlfir.forall_index do not contain saved regions/assignments,
  // but if their hlfir.forall parent was required, they are
  // required (the forall indices needs to be mapped).
  if (mlir::isa<hlfir::ForallIndexOp>(node))
    return true;
  for (auto &action : currentRun->actions)
    if (hlfir::SaveEntity *savedEntity =
            std::get_if<hlfir::SaveEntity>(&action)) {
      // A SaveEntity action does not require evaluating the node that contains
      // it, but it requires evaluating all the parents of the node that
      // contains it.
      // For instance, saving a bound in hlfir.forall B does not
      // require creating the loops for B, but it requires creating the loops
      // for any forall parent A of the forall B.
      if (node->isProperAncestor(savedEntity->yieldRegion->getParentOp()))
        return true;
    } else {
      // An assignment action requires evaluating every node enclosing it
      // (including the assignment itself).
      auto assign = std::get<hlfir::RegionAssignOp>(action);
      if (node->isAncestor(assign.getOperation()))
        return true;
    }
  return false;
}

/// Is the apply using all the elemental indices in order?
static bool isInOrderApply(hlfir::ApplyOp apply,
                           hlfir::ElementalOpInterface elemental) {
  mlir::Region::BlockArgListType elementalIndices = elemental.getIndices();
  if (elementalIndices.size() != apply.getIndices().size())
    return false;
  for (auto [elementalIdx, applyIdx] :
       llvm::zip(elementalIndices, apply.getIndices()))
    if (elementalIdx != applyIdx)
      return false;
  return true;
}

/// Gather the tree of hlfir::ElementalOpInterface use-def, if any, starting
/// from \p elemental, which may be a nullptr.
static void
gatherElementalTree(hlfir::ElementalOpInterface elemental,
                    llvm::SmallPtrSetImpl<mlir::Operation *> &elementalOps,
                    bool isOutOfOrder) {
  if (elemental) {
    // Only inline an applied elemental that must be executed in order if the
    // applying indices are in order. An hlfir::Elemental may have been created
    // for a transformational like transpose, and Fortran 2018 standard
    // section 10.2.3.2, point 10 imply that impure elemental sub-expression
    // evaluations should not be masked if they are the arguments of
    // transformational expressions.
    if (isOutOfOrder && elemental.isOrdered())
      return;
    elementalOps.insert(elemental.getOperation());
    // Recurse through hlfir.apply uses to gather the chained elementals.
    for (mlir::Operation &op : elemental.getElementalRegion().getOps())
      if (auto apply = mlir::dyn_cast<hlfir::ApplyOp>(op)) {
        bool isUnorderedApply =
            isOutOfOrder || !isInOrderApply(apply, elemental);
        auto maybeElemental =
            mlir::dyn_cast_or_null<hlfir::ElementalOpInterface>(
                apply.getExpr().getDefiningOp());
        gatherElementalTree(maybeElemental, elementalOps, isUnorderedApply);
      }
  }
}

/// Analyze \p region: gather the elemental operations producing the yielded
/// entity into elementalParts so their generation can be delayed into the
/// WHERE loops.
MaskedArrayExpr::MaskedArrayExpr(mlir::Location loc, mlir::Region &region,
                                 bool isOuterMaskExpr)
    : loc{loc}, region{region}, isOuterMaskExpr{isOuterMaskExpr} {
  mlir::Operation &terminator = region.back().back();
  if (auto elementalAddr =
          mlir::dyn_cast<hlfir::ElementalOpInterface>(terminator)) {
    // Vector subscripted designator (hlfir.elemental_addr terminator).
    gatherElementalTree(elementalAddr, elementalParts, /*isOutOfOrder=*/false);
    return;
  }
  // Try if elemental expression.
  mlir::Value entity = mlir::cast<hlfir::YieldOp>(terminator).getEntity();
  auto maybeElemental = mlir::dyn_cast_or_null<hlfir::ElementalOpInterface>(
      entity.getDefiningOp());
  gatherElementalTree(maybeElemental, elementalParts, /*isOutOfOrder=*/false);
}

void MaskedArrayExpr::generateNoneElementalPart(fir::FirOpBuilder &builder,
                                                mlir::IRMapping &mapper) {
  assert(!noneElementalPartWasGenerated &&
         "none elemental parts already generated");
  if (isOuterMaskExpr) {
    // The outer mask expression is actually not masked, it is dealt as
    // such so that its elemental part, if any, can be inlined in the WHERE
    // loops. But all of the operations outside of hlfir.elemental/
    // hlfir.elemental_addr must be emitted now because their value may be
    // required to deduce the mask shape and the WHERE loop bounds.
    for (mlir::Operation &op : region.back().without_terminator())
      if (!elementalParts.contains(&op))
        (void)builder.clone(op, mapper);
  } else {
    // For actual masked expressions, Fortran requires elemental expressions,
    // even the scalar ones that are not encoded with hlfir.elemental, to be
    // evaluated only when the mask is true. Blindly hoisting all scalar SSA
    // tree could be wrong if the scalar computation has side effects and
    // would never have been evaluated (e.g. division by zero) if the mask
    // is fully false. See F'2023 10.2.3.2 point 10.
    // Clone only the bodies of all hlfir.exactly_once operations, which contain
    // the evaluation of sub-expression tree whose root was a non elemental
    // function call at the Fortran level (the call itself may have been inlined
    // since). These must be evaluated only once as per F'2023 10.2.3.2 point 9.
    for (mlir::Operation &op : region.back().without_terminator())
      if (auto exactlyOnce = mlir::dyn_cast<hlfir::ExactlyOnceOp>(op)) {
        for (mlir::Operation &subOp :
             exactlyOnce.getBody().back().without_terminator())
          (void)builder.clone(subOp, mapper);
        // Map the hlfir.exactly_once result to the cloned yielded value so
        // that later clones of its users pick up the evaluated value.
        mlir::Value oldYield = getYield(exactlyOnce.getBody()).getEntity();
        auto newYield = mapper.lookupOrDefault(oldYield);
        mapper.map(exactlyOnce.getResult(), newYield);
      }
  }
  noneElementalPartWasGenerated = true;
}

mlir::Value MaskedArrayExpr::generateShape(fir::FirOpBuilder &builder,
                                           mlir::IRMapping &mapper) {
  assert(noneElementalPartWasGenerated &&
         "non elemental part must have been generated");
  mlir::Operation &terminator = region.back().back();
  // If the operation that produced the yielded entity is elemental, it was not
  // cloned, but it holds a shape argument that was cloned. Return the cloned
  // shape.
  if (auto elementalAddrOp = mlir::dyn_cast<hlfir::ElementalAddrOp>(terminator))
    return mapper.lookupOrDefault(elementalAddrOp.getShape());
  mlir::Value entity = mlir::cast<hlfir::YieldOp>(terminator).getEntity();
  if (auto elemental = entity.getDefiningOp<hlfir::ElementalOp>())
    return mapper.lookupOrDefault(elemental.getShape());
  // Otherwise, the whole entity was cloned, and the shape can be generated
  // from it.
  hlfir::Entity clonedEntity{mapper.lookupOrDefault(entity)};
  return hlfir::genShape(loc, builder, hlfir::Entity{clonedEntity});
}

mlir::Value
MaskedArrayExpr::generateElementalParts(fir::FirOpBuilder &builder,
                                        mlir::ValueRange oneBasedIndices,
                                        mlir::IRMapping &mapper) {
  assert(noneElementalPartWasGenerated &&
         "non elemental part must have been generated");
  if (!isOuterMaskExpr) {
    // Clone all operations that are not hlfir.exactly_once and that are not
    // hlfir.elemental/hlfir.elemental_addr.
    for (mlir::Operation &op : region.back().without_terminator())
      if (!mlir::isa<hlfir::ExactlyOnceOp>(op) && !elementalParts.contains(&op))
        (void)builder.clone(op, mapper);
    // For the outer mask, this was already done outside of the loop.
  }
  // Clone and "index" bodies of hlfir.elemental/hlfir.elemental_addr.
  mlir::Operation &terminator = region.back().back();
  hlfir::ElementalOpInterface elemental =
      mlir::dyn_cast<hlfir::ElementalAddrOp>(terminator);
  if (!elemental) {
    // If the terminator is not an hlfir.elemental_addr, try if the yielded
    // entity was produced by an hlfir.elemental.
    mlir::Value entity = mlir::cast<hlfir::YieldOp>(terminator).getEntity();
    elemental = entity.getDefiningOp<hlfir::ElementalOp>();
    if (!elemental) {
      // The yielded entity was not produced by an elemental operation,
      // get its clone in the non elemental part evaluation and address it.
      hlfir::Entity clonedEntity{mapper.lookupOrDefault(entity)};
      return hlfir::getElementAt(loc, builder, clonedEntity, oneBasedIndices);
    }
  }

  // Inline the elemental body at the current indices, recursing into the
  // elementals gathered in elementalParts.
  auto mustRecursivelyInline =
      [&](hlfir::ElementalOp appliedElemental) -> bool {
    return elementalParts.contains(appliedElemental.getOperation());
  };
  return inlineElementalOp(loc, builder, elemental, oneBasedIndices, mapper,
                           mustRecursivelyInline);
}

mlir::Operation *
MaskedArrayExpr::generateMaskedExprCleanUps(fir::FirOpBuilder &builder,
                                            mlir::IRMapping &mapper) {
  // Clone the clean-ups from the region itself, except for the destroy
  // of the hlfir.elemental that have been inlined.
  mlir::Operation &terminator = region.back().back();
  mlir::Region *cleanupRegion = nullptr;
  if (auto elementalAddr = mlir::dyn_cast<hlfir::ElementalAddrOp>(terminator)) {
    cleanupRegion = &elementalAddr.getCleanup();
  } else {
    auto yieldOp = mlir::cast<hlfir::YieldOp>(terminator);
    cleanupRegion = &yieldOp.getCleanup();
  }
  if (cleanupRegion->empty())
    return nullptr;
  mlir::Operation *firstNewCleanup = nullptr;
  for (mlir::Operation &op : cleanupRegion->front().without_terminator()) {
    // Skip hlfir.destroy of inlined elementals: no temporary was created for
    // them, so there is nothing to destroy.
    if (auto destroy = mlir::dyn_cast<hlfir::DestroyOp>(op))
      if (elementalParts.contains(destroy.getExpr().getDefiningOp()))
        continue;
    mlir::Operation *cleanup = builder.clone(op, mapper);
    if (!firstNewCleanup)
      firstNewCleanup = cleanup;
  }
  // Return the first cloned clean-up (if any) so the caller can insert code
  // before it.
  return firstNewCleanup;
}

void MaskedArrayExpr::generateNoneElementalCleanupIfAny(
    fir::FirOpBuilder &builder, mlir::IRMapping &mapper) {
  if (!isOuterMaskExpr) {
    // Clone clean-ups of hlfir.exactly_once operations (in reverse order
    // to properly deal with stack restores).
    for (mlir::Operation &op :
         llvm::reverse(region.back().without_terminator()))
      if (auto exactlyOnce = mlir::dyn_cast<hlfir::ExactlyOnceOp>(op)) {
        mlir::Region &cleanupRegion =
            getYield(exactlyOnce.getBody()).getCleanup();
        if (!cleanupRegion.empty())
          for (mlir::Operation &cleanupOp :
               cleanupRegion.front().without_terminator())
            (void)builder.clone(cleanupOp, mapper);
      }
  } else {
    // For the outer mask, the region clean-ups must be generated
    // outside of the loops since the mask non hlfir.elemental part
    // is generated before the loops.
    generateMaskedExprCleanUps(builder, mapper);
  }
}

/// If \p region is the LHS region of an hlfir.region_assign, return that
/// operation. Return nullptr otherwise.
static hlfir::RegionAssignOp
getAssignIfLeftHandSideRegion(mlir::Region &region) {
  auto assign = mlir::dyn_cast<hlfir::RegionAssignOp>(region.getParentOp());
  if (assign && (&assign.getLhsRegion() == &region))
    return assign;
  return nullptr;
}

/// Can the total iteration count of the current construct loop nest be
/// computed before entering it? If so, fill \p loopNest with the fir.do_loop
/// operations of the nest.
bool OrderedAssignmentRewriter::currentLoopNestIterationNumberCanBeComputed(
    llvm::SmallVectorImpl<fir::DoLoopOp> &loopNest) {
  if (constructStack.empty())
    return true;
  mlir::Operation *outerLoop = constructStack[0];
  mlir::Operation *currentConstruct = constructStack.back();
  // Loop through the loops until the outer construct is met, and test if the
  // loop operands dominate the outer construct.
  while (currentConstruct) {
    if (auto doLoop = mlir::dyn_cast<fir::DoLoopOp>(currentConstruct)) {
      // If any loop bound/step does not dominate the outermost construct, the
      // iteration count cannot be pre-computed before the loop nest.
      if (llvm::any_of(doLoop->getOperands(), [&](mlir::Value value) {
            return !dominanceInfo.properlyDominates(value, outerLoop);
          })) {
        return false;
      }
      loopNest.push_back(doLoop);
    }
    if (currentConstruct == outerLoop)
      currentConstruct = nullptr;
    else
      currentConstruct = currentConstruct->getParentOp();
  }
  return true;
}

/// Compute the total number of iterations of \p loopNest as the product of
/// the extents of each loop triplet.
static mlir::Value
computeLoopNestIterationNumber(mlir::Location loc, fir::FirOpBuilder &builder,
                               llvm::ArrayRef<fir::DoLoopOp> loopNest) {
  mlir::Value loopExtent;
  for (fir::DoLoopOp doLoop : loopNest) {
    mlir::Value extent = builder.genExtentFromTriplet(
        loc, doLoop.getLowerBound(), doLoop.getUpperBound(), doLoop.getStep(),
        builder.getIndexType());
    if (!loopExtent)
      loopExtent = extent;
    else
      loopExtent = builder.create<mlir::arith::MulIOp>(loc, loopExtent, extent);
  }
  assert(loopExtent && "loopNest must not be empty");
  return loopExtent;
}

/// Return a name for temporary storage that indicates in which context
/// the temporary storage was created.
static llvm::StringRef
getTempName(hlfir::OrderedAssignmentTreeOpInterface root) {
  if (mlir::isa<hlfir::ForallOp>(root.getOperation()))
    return ".tmp.forall";
  if (mlir::isa<hlfir::WhereOp>(root.getOperation()))
    return ".tmp.where";
  return ".tmp.assign";
}

/// Evaluate the region of \p savedEntity and save its value (or its address,
/// for an assignment LHS) into temporary storage so that later runs can fetch
/// it instead of re-evaluating it.
void OrderedAssignmentRewriter::generateSaveEntity(
    hlfir::SaveEntity savedEntity, bool willUseSavedEntityInSameRun) {
  mlir::Region &region = *savedEntity.yieldRegion;

  if (hlfir::RegionAssignOp regionAssignOp =
          getAssignIfLeftHandSideRegion(region)) {
    // Need to save the address, not the values.
    assert(!willUseSavedEntityInSameRun &&
           "lhs cannot be used in the loop nest where it is saved");
    return saveLeftHandSide(savedEntity, regionAssignOp);
  }

  mlir::Location loc = region.getParentOp()->getLoc();
  // Evaluate the region inside the loop nest (if any).
  auto [clonedValue, oldYield] = generateYieldedEntity(region);
  hlfir::Entity entity{clonedValue};
  entity = hlfir::loadTrivialScalar(loc, builder, entity);
  mlir::Type entityType = entity.getType();

  llvm::StringRef tempName = getTempName(root);
  fir::factory::TemporaryStorage *temp = nullptr;
  if (constructStack.empty()) {
    // Value evaluated outside of any loops (this may be the first MASK of a
    // WHERE construct, or an LHS/RHS temp of hlfir.region_assign outside of
    // WHERE/FORALL).
    temp = insertSavedEntity(
        region, fir::factory::SimpleCopy(loc, builder, entity, tempName));
  } else {
    // Need to create a temporary for values computed inside loops.
    // Create temporary storage outside of the loop nest given the entity
    // type (and the loop context).
    llvm::SmallVector<fir::DoLoopOp> loopNest;
    bool loopShapeCanBePreComputed =
        currentLoopNestIterationNumberCanBeComputed(loopNest);
    doBeforeLoopNest([&] {
      /// For simple scalars inside loops whose total iteration number can be
      /// pre-computed, create a rank-1 array outside of the loops. It will be
      /// assigned/fetched inside the loops like a normal Fortran array given
      /// the iteration count.
      if (loopShapeCanBePreComputed && fir::isa_trivial(entityType)) {
        mlir::Value loopExtent =
            computeLoopNestIterationNumber(loc, builder, loopNest);
        auto sequenceType =
            mlir::cast<fir::SequenceType>(builder.getVarLenSeqTy(entityType));
        temp = insertSavedEntity(region,
                                 fir::factory::HomogeneousScalarStack{
                                     loc, builder, sequenceType, loopExtent,
                                     /*lenParams=*/{}, allocateOnHeap,
                                     /*stackThroughLoops=*/true, tempName});

      } else {
        // If the number of iteration is not known, or if the values at each
        // iterations are values that may have different shape, type parameters
        // or dynamic type, use the runtime to create and manage a stack-like
        // temporary.
        temp = insertSavedEntity(
            region, fir::factory::AnyValueStack{loc, builder, entityType});
      }
    });
    // Inside the loop nest (and any fir.if if there are active masks), copy
    // the value to the temp and do clean-ups for the value if any.
    temp->pushValue(loc, builder, entity);
  }

  // Delay the clean-up if the entity will be used in the same run (i.e., the
  // parent construct will be visited and needs to be lowered). When possible,
  // this is not done for hlfir.expr because this use would prevent the
  // hlfir.expr storage from being moved when creating the temporary in
  // bufferization, and that would lead to an extra copy.
  if (willUseSavedEntityInSameRun &&
      (!temp->canBeFetchedAfterPush() ||
       !mlir::isa<hlfir::ExprType>(entity.getType()))) {
    auto inserted =
        savedInCurrentRunBeforeUse.try_emplace(&region, entity, oldYield);
    assert(inserted.second && "entity must have been emplaced");
    (void)inserted;
  } else {
    if (constructStack.empty() &&
        mlir::isa<hlfir::RegionAssignOp>(region.getParentOp())) {
      // Here the clean-up code is inserted after the original
      // RegionAssignOp, so that the assignment code happens
      // before the cleanup. We do this only for standalone
      // operations, because the clean-up is handled specially
      // during lowering of the parent constructs if any
      // (e.g. see generateNoneElementalCleanupIfAny for
      // WhereOp).
      auto insertionPoint = builder.saveInsertionPoint();
      builder.setInsertionPointAfter(region.getParentOp());
      generateCleanupIfAny(oldYield);
      builder.restoreInsertionPoint(insertionPoint);
    } else {
      generateCleanupIfAny(oldYield);
    }
  }
}

/// Is the right-hand side of \p regionAssignOp an array entity?
static bool rhsIsArray(hlfir::RegionAssignOp regionAssignOp) {
  auto yieldOp = mlir::dyn_cast<hlfir::YieldOp>(
      regionAssignOp.getRhsRegion().back().back());
  return yieldOp && hlfir::Entity{yieldOp.getEntity()}.isArray();
}

/// Save the addresses (and, when needed, the shape) of the LHS of
/// \p regionAssignOp into temporary storage so that the assignment can be
/// performed in a later run.
void OrderedAssignmentRewriter::saveLeftHandSide(
    hlfir::SaveEntity savedEntity, hlfir::RegionAssignOp regionAssignOp) {
  mlir::Region &region = *savedEntity.yieldRegion;
  mlir::Location loc = region.getParentOp()->getLoc();
  LhsValueAndCleanUp loweredLhs = generateYieldedLHS(loc, region);
  fir::factory::TemporaryStorage *temp = nullptr;
  if (loweredLhs.vectorSubscriptLoopNest)
    constructStack.push_back(loweredLhs.vectorSubscriptLoopNest->outerOp);
  if (loweredLhs.vectorSubscriptLoopNest && !rhsIsArray(regionAssignOp)) {
    // Vector subscripted entity for which the shape must also be saved on top
    // of the element addresses (e.g. the shape may change in each forall
    // iteration and is needed to create the elemental loops).
    mlir::Value shape = loweredLhs.vectorSubscriptShape.value();
    int rank = mlir::cast<fir::ShapeType>(shape.getType()).getRank();
    // The shape only needs to be saved once if it dominates the outermost
    // construct (it is then loop invariant).
    const bool shapeIsInvariant =
        constructStack.empty() ||
        dominanceInfo.properlyDominates(shape, constructStack[0]);
    doBeforeLoopNest([&] {
      // Outside of any forall/where/elemental loops, create a temporary that
      // will both be able to save the vector subscripted designator shape(s)
      // and element addresses.
      temp =
          insertSavedEntity(region, fir::factory::AnyVectorSubscriptStack{
                                        loc, builder, loweredLhs.lhs.getType(),
                                        shapeIsInvariant, rank});
    });
    // Save shape before the elemental loop nest created by the vector
    // subscripted LHS.
    auto &vectorTmp = temp->cast<fir::factory::AnyVectorSubscriptStack>();
    auto insertionPoint = builder.saveInsertionPoint();
    builder.setInsertionPoint(loweredLhs.vectorSubscriptLoopNest->outerOp);
    vectorTmp.pushShape(loc, builder, shape);
    builder.restoreInsertionPoint(insertionPoint);
  } else {
    // Otherwise, only save the LHS address.
    // If the LHS address dominates the constructs, its SSA value can
    // simply be tracked and there is no need to save the address in memory.
    // Otherwise, the addresses are stored at each iteration in memory with
    // a descriptor stack.
    if (constructStack.empty() ||
        dominanceInfo.properlyDominates(loweredLhs.lhs, constructStack[0]))
      doBeforeLoopNest([&] {
        temp = insertSavedEntity(region, fir::factory::SSARegister{});
      });
    else
      doBeforeLoopNest([&] {
        temp = insertSavedEntity(
            region, fir::factory::AnyVariableStack{loc, builder,
                                                   loweredLhs.lhs.getType()});
      });
  }
  temp->pushValue(loc, builder, loweredLhs.lhs);
  generateCleanupIfAny(loweredLhs.elementalCleanup);
  if (loweredLhs.vectorSubscriptLoopNest) {
    // Pop the elemental loop nest created for the vector subscripted LHS and
    // resume code generation after it.
    constructStack.pop_back();
    builder.setInsertionPointAfter(loweredLhs.vectorSubscriptLoopNest->outerOp);
  }
}

/// Lower an ordered assignment tree to fir.do_loop and hlfir.assign given
/// a schedule.
static void lower(hlfir::OrderedAssignmentTreeOpInterface root,
                  mlir::PatternRewriter &rewriter, hlfir::Schedule &schedule) {
  auto module = root->getParentOfType<mlir::ModuleOp>();
  fir::FirOpBuilder builder(rewriter, module);
  OrderedAssignmentRewriter assignmentRewriter(builder, root);
  for (auto &run : schedule)
    assignmentRewriter.lowerRun(run);
  assignmentRewriter.cleanupSavedEntities();
}

/// Shared rewrite entry point for all the ordered assignment tree root
/// operations. It calls the scheduler and then applies the schedule.
static llvm::LogicalResult rewrite(hlfir::OrderedAssignmentTreeOpInterface root,
                                   bool tryFusingAssignments,
                                   mlir::PatternRewriter &rewriter) {
  hlfir::Schedule schedule =
      hlfir::buildEvaluationSchedule(root, tryFusingAssignments);

  LLVM_DEBUG(
      /// Debug option to print the scheduling debug info without doing
      /// any code generation. The operations are simply erased to avoid
      /// failing and calling the rewrite patterns on nested operations.
      /// The only purpose of this is to help testing scheduling without
      /// having to test generated code.
      if (dbgScheduleOnly) {
        rewriter.eraseOp(root);
        return mlir::success();
      });
  lower(root, rewriter, schedule);
  rewriter.eraseOp(root);
  return mlir::success();
}

namespace {

/// Rewrite pattern lowering hlfir.forall roots via the shared ::rewrite
/// entry point.
class ForallOpConversion : public mlir::OpRewritePattern<hlfir::ForallOp> {
public:
  explicit ForallOpConversion(mlir::MLIRContext *ctx, bool tryFusingAssignments)
      : OpRewritePattern{ctx}, tryFusingAssignments{tryFusingAssignments} {}

  llvm::LogicalResult
  matchAndRewrite(hlfir::ForallOp forallOp,
                  mlir::PatternRewriter &rewriter) const override {
    auto root = mlir::cast<hlfir::OrderedAssignmentTreeOpInterface>(
        forallOp.getOperation());
    if (mlir::failed(::rewrite(root, tryFusingAssignments, rewriter)))
      TODO(forallOp.getLoc(), "FORALL construct or statement in HLFIR");
    return mlir::success();
  }
  const bool tryFusingAssignments;
};

/// Rewrite pattern lowering hlfir.where roots via the shared ::rewrite
/// entry point.
class WhereOpConversion : public mlir::OpRewritePattern<hlfir::WhereOp> {
public:
  explicit WhereOpConversion(mlir::MLIRContext *ctx, bool tryFusingAssignments)
      : OpRewritePattern{ctx}, tryFusingAssignments{tryFusingAssignments} {}

  llvm::LogicalResult
  matchAndRewrite(hlfir::WhereOp whereOp,
                  mlir::PatternRewriter &rewriter) const override {
    auto root = mlir::cast<hlfir::OrderedAssignmentTreeOpInterface>(
        whereOp.getOperation());
    return ::rewrite(root, tryFusingAssignments, rewriter);
  }
  const bool tryFusingAssignments;
};

/// Rewrite pattern lowering standalone hlfir.region_assign roots. There is
/// nothing to fuse for a single assignment, hence no tryFusingAssignments
/// option here.
class RegionAssignConversion
    : public mlir::OpRewritePattern<hlfir::RegionAssignOp> {
public:
  explicit RegionAssignConversion(mlir::MLIRContext *ctx)
      : OpRewritePattern{ctx} {}

  llvm::LogicalResult
  matchAndRewrite(hlfir::RegionAssignOp regionAssignOp,
                  mlir::PatternRewriter &rewriter) const override {
    auto root = mlir::cast<hlfir::OrderedAssignmentTreeOpInterface>(
        regionAssignOp.getOperation());
    return ::rewrite(root, /*tryFusingAssignments=*/false, rewriter);
  }
};

/// Pass driver: registers the root-operation patterns and runs a partial
/// conversion that requires all OrderedAssignmentTreeOpInterface operations
/// to be rewritten away.
class LowerHLFIROrderedAssignments
    : public hlfir::impl::LowerHLFIROrderedAssignmentsBase<
          LowerHLFIROrderedAssignments> {
public:
  using LowerHLFIROrderedAssignmentsBase<
      LowerHLFIROrderedAssignments>::LowerHLFIROrderedAssignmentsBase;

  void runOnOperation() override {
    // Running on a ModuleOp because this pass may generate FuncOp declaration
    // for runtime calls. This could be a FuncOp pass otherwise.
    auto module = this->getOperation();
    auto *context = &getContext();
    mlir::RewritePatternSet patterns(context);
    // Patterns are only defined for the OrderedAssignmentTreeOpInterface
    // operations that can be the root of ordered assignments. The other
    // operations will be taken care of while rewriting these trees (they
    // cannot exist outside of these operations given their verifiers/traits).
    patterns.insert<ForallOpConversion, WhereOpConversion>(
        context, this->tryFusingAssignments.getValue());
    patterns.insert<RegionAssignConversion>(context);
    mlir::ConversionTarget target(*context);
    // Any operation that is not an ordered assignment tree node is legal;
    // the tree roots (and their nested nodes) must be lowered.
    target.markUnknownOpDynamicallyLegal([](mlir::Operation *op) {
      return !mlir::isa<hlfir::OrderedAssignmentTreeOpInterface>(op);
    });
    if (mlir::failed(mlir::applyPartialConversion(module, target,
                                                  std::move(patterns)))) {
      mlir::emitError(mlir::UnknownLoc::get(context),
                      "failure in HLFIR ordered assignments lowering pass");
      signalPassFailure();
    }
  }
};
} // namespace