1 //===-- ArrayValueCopy.cpp ------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "flang/Optimizer/Builder/Array.h" 10 #include "flang/Optimizer/Builder/BoxValue.h" 11 #include "flang/Optimizer/Builder/FIRBuilder.h" 12 #include "flang/Optimizer/Builder/Factory.h" 13 #include "flang/Optimizer/Builder/Runtime/Derived.h" 14 #include "flang/Optimizer/Builder/Todo.h" 15 #include "flang/Optimizer/Dialect/FIRDialect.h" 16 #include "flang/Optimizer/Dialect/FIROpsSupport.h" 17 #include "flang/Optimizer/Support/FIRContext.h" 18 #include "flang/Optimizer/Transforms/Passes.h" 19 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 20 #include "mlir/Dialect/SCF/IR/SCF.h" 21 #include "mlir/Transforms/DialectConversion.h" 22 #include "llvm/Support/Debug.h" 23 24 namespace fir { 25 #define GEN_PASS_DEF_ARRAYVALUECOPY 26 #include "flang/Optimizer/Transforms/Passes.h.inc" 27 } // namespace fir 28 29 #define DEBUG_TYPE "flang-array-value-copy" 30 31 using namespace fir; 32 using namespace mlir; 33 34 using OperationUseMapT = llvm::DenseMap<mlir::Operation *, mlir::Operation *>; 35 36 namespace { 37 38 /// Array copy analysis. 39 /// Perform an interference analysis between array values. 40 /// 41 /// Lowering will generate a sequence of the following form. 42 /// ```mlir 43 /// %a_1 = fir.array_load %array_1(%shape) : ... 44 /// ... 45 /// %a_j = fir.array_load %array_j(%shape) : ... 46 /// ... 47 /// %a_n = fir.array_load %array_n(%shape) : ... 48 /// ... 49 /// %v_i = fir.array_fetch %a_i, ... 50 /// %a_j1 = fir.array_update %a_j, ... 51 /// ... 52 /// fir.array_merge_store %a_j, %a_jn to %array_j : ... 
/// ```
///
/// The analysis is to determine if there are any conflicts. A conflict is when
/// one of the following cases occurs.
///
/// 1. There is an `array_update` to an array value, a_j, such that a_j was
/// loaded from the same array memory reference (array_j) but with a different
/// shape as the other array values a_i, where i != j. [Possible overlapping
/// arrays.]
///
/// 2. There is either an array_fetch or array_update of a_j with a different
/// set of index values. [Possible loop-carried dependence.]
///
/// If none of the array values overlap in storage and the accesses are not
/// loop-carried, then the arrays are conflict-free and no copies are required.
class ArrayCopyAnalysis {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ArrayCopyAnalysis)

  using ConflictSetT = llvm::SmallPtrSet<mlir::Operation *, 16>;
  using UseSetT = llvm::SmallPtrSet<mlir::OpOperand *, 8>;
  using LoadMapSetsT = llvm::DenseMap<mlir::Operation *, UseSetT>;
  using AmendAccessSetT = llvm::SmallPtrSet<mlir::Operation *, 4>;

  // Constructing the analysis runs it immediately over `op`.
  ArrayCopyAnalysis(mlir::Operation *op) : operation{op} { construct(op); }

  mlir::Operation *getOperation() const { return operation; }

  /// Return true iff the `array_merge_store` has potential conflicts.
  bool hasPotentialConflict(mlir::Operation *op) const {
    LLVM_DEBUG(llvm::dbgs()
               << "looking for a conflict on " << *op
               << " and the set has a total of " << conflicts.size() << '\n');
    return conflicts.contains(op);
  }

  /// Return the use map.
  /// The use map maps array access, amend, fetch and update operations back to
  /// the array load that is the original source of the array value.
  /// It maps an array_load to an array_merge_store, if and only if the loaded
  /// array value has pending modifications to be merged.
  const OperationUseMapT &getUseMap() const { return useMap; }

  /// Return true iff `op` is an array_access op that is directly associated
  /// with an array_amend op.
  bool inAmendAccessSet(mlir::Operation *op) const {
    return amendAccesses.count(op);
  }

  /// For ArrayLoad `load`, return the transitive set of all OpOperands.
  UseSetT getLoadUseSet(mlir::Operation *load) const {
    assert(loadMapSets.count(load) && "analysis missed an array load?");
    return loadMapSets.lookup(load);
  }

  /// Collect (into `mentions`) the array operations that mention the array
  /// value produced by `load`. The transitive OpOperand set is computed on
  /// first use and cached per load in `loadMapSets`.
  void arrayMentions(llvm::SmallVectorImpl<mlir::Operation *> &mentions,
                     ArrayLoadOp load);

private:
  // Walk `topLevelOp` and populate conflicts, useMap, loadMapSets and
  // amendAccesses.
  void construct(mlir::Operation *topLevelOp);

  mlir::Operation *operation; // operation that analysis ran upon
  ConflictSetT conflicts;     // set of conflicts (loads and merge stores)
  OperationUseMapT useMap;
  LoadMapSetsT loadMapSets;
  // Set of array_access ops associated with array_amend ops.
  AmendAccessSetT amendAccesses;
};
} // namespace

namespace {
/// Helper class to collect all array operations that produced an array value.
class ReachCollector {
public:
  ReachCollector(llvm::SmallVectorImpl<mlir::Operation *> &reach,
                 mlir::Region *loopRegion)
      : reach{reach}, loopRegion{loopRegion} {}

  /// Collect array mentions reaching each value in `range`. An empty range
  /// processes `op` itself (with a null value).
  void collectArrayMentionFrom(mlir::Operation *op, mlir::ValueRange range) {
    if (range.empty()) {
      collectArrayMentionFrom(op, mlir::Value{});
      return;
    }
    for (mlir::Value v : range)
      collectArrayMentionFrom(v);
  }

  // Collect all the array_access ops in `block`. This recursively looks into
  // blocks in ops with regions.
  // FIXME: This is temporarily relying on the array_amend appearing in a
  // do_loop Region. This phase ordering assumption can be eliminated by using
  // dominance information to find the array_access ops or by scanning the
  // transitive closure of the amending array_access's users and the defs that
  // reach them.
  void collectAccesses(llvm::SmallVector<ArrayAccessOp> &result,
                       mlir::Block *block) {
    for (auto &op : *block) {
      if (auto access = mlir::dyn_cast<ArrayAccessOp>(op)) {
        LLVM_DEBUG(llvm::dbgs() << "adding access: " << access << '\n');
        result.push_back(access);
        continue;
      }
      // Recurse into any nested regions/blocks of `op`.
      for (auto &region : op.getRegions())
        for (auto &bb : region.getBlocks())
          collectAccesses(result, &bb);
    }
  }

  /// Add to the reachable set all array operations that can produce `val`,
  /// where `val` is a result of `op` (or null when `op` has no results).
  void collectArrayMentionFrom(mlir::Operation *op, mlir::Value val) {
    // `val` is defined by an Op, process the defining Op.
    // If `val` is defined by a region containing Op, we want to drill down
    // and through that Op's region(s).
    LLVM_DEBUG(llvm::dbgs() << "popset: " << *op << '\n');
    auto popFn = [&](auto rop) {
      assert(val && "op must have a result value");
      auto resNum = val.cast<mlir::OpResult>().getResultNumber();
      llvm::SmallVector<mlir::Value> results;
      rop.resultToSourceOps(results, resNum);
      for (auto u : results)
        collectArrayMentionFrom(u);
    };
    if (auto rop = mlir::dyn_cast<DoLoopOp>(op)) {
      popFn(rop);
      return;
    }
    if (auto rop = mlir::dyn_cast<IterWhileOp>(op)) {
      popFn(rop);
      return;
    }
    if (auto rop = mlir::dyn_cast<fir::IfOp>(op)) {
      popFn(rop);
      return;
    }
    if (auto box = mlir::dyn_cast<EmboxOp>(op)) {
      // Chase every other user of the emboxed memory reference.
      for (auto *user : box.getMemref().getUsers())
        if (user != op)
          collectArrayMentionFrom(user, user->getResults());
      return;
    }
    if (auto mergeStore = mlir::dyn_cast<ArrayMergeStoreOp>(op)) {
      if (opIsInsideLoops(mergeStore))
        collectArrayMentionFrom(mergeStore.getSequence());
      return;
    }

    if (mlir::isa<AllocaOp, AllocMemOp>(op)) {
      // Look for any stores inside the loops, and collect an array operation
      // that produced the value being stored to it.
      for (auto *user : op->getUsers())
        if (auto store = mlir::dyn_cast<fir::StoreOp>(user))
          if (opIsInsideLoops(store))
            collectArrayMentionFrom(store.getValue());
      return;
    }

    // Scan the uses of amend's memref
    if (auto amend = mlir::dyn_cast<ArrayAmendOp>(op)) {
      reach.push_back(op);
      llvm::SmallVector<ArrayAccessOp> accesses;
      collectAccesses(accesses, op->getBlock());
      for (auto access : accesses)
        collectArrayMentionFrom(access.getResult());
    }

    // Otherwise, Op does not contain a region so just chase its operands.
    if (mlir::isa<ArrayAccessOp, ArrayLoadOp, ArrayUpdateOp, ArrayModifyOp,
                  ArrayFetchOp>(op)) {
      LLVM_DEBUG(llvm::dbgs() << "add " << *op << " to reachable set\n");
      reach.push_back(op);
    }

    // Include all array_access ops using an array_load.
    if (auto arrLd = mlir::dyn_cast<ArrayLoadOp>(op))
      for (auto *user : arrLd.getResult().getUsers())
        if (mlir::isa<ArrayAccessOp>(user)) {
          LLVM_DEBUG(llvm::dbgs() << "add " << *user << " to reachable set\n");
          reach.push_back(user);
        }

    // Array modify assignment is performed on the result. So the analysis must
    // look at what is done with the result.
    if (mlir::isa<ArrayModifyOp>(op))
      for (auto *user : op->getResult(0).getUsers())
        followUsers(user);

    // Calls are recorded so the conservative call-conflict check can see them.
    if (mlir::isa<fir::CallOp>(op)) {
      LLVM_DEBUG(llvm::dbgs() << "add " << *op << " to reachable set\n");
      reach.push_back(op);
    }

    for (auto u : op->getOperands())
      collectArrayMentionFrom(u);
  }

  void collectArrayMentionFrom(mlir::BlockArgument ba) {
    auto *parent = ba.getOwner()->getParentOp();
    // If inside an Op holding a region, the block argument corresponds to an
    // argument passed to the containing Op.
    auto popFn = [&](auto rop) {
      collectArrayMentionFrom(rop.blockArgToSourceOp(ba.getArgNumber()));
    };
    if (auto rop = mlir::dyn_cast<DoLoopOp>(parent)) {
      popFn(rop);
      return;
    }
    if (auto rop = mlir::dyn_cast<IterWhileOp>(parent)) {
      popFn(rop);
      return;
    }
    // Otherwise, a block argument is provided via the pred blocks.
    for (auto *pred : ba.getOwner()->getPredecessors()) {
      auto u = pred->getTerminator()->getOperand(ba.getArgNumber());
      collectArrayMentionFrom(u);
    }
  }

  // Recursively trace operands to find all array operations relating to the
  // values merged.
  void collectArrayMentionFrom(mlir::Value val) {
    // `visited` guards against cycles in the use-def traversal.
    if (!val || visited.contains(val))
      return;
    visited.insert(val);

    // Process a block argument.
    if (auto ba = val.dyn_cast<mlir::BlockArgument>()) {
      collectArrayMentionFrom(ba);
      return;
    }

    // Process an Op.
    if (auto *op = val.getDefiningOp()) {
      collectArrayMentionFrom(op, val);
      return;
    }

    emitFatalError(val.getLoc(), "unhandled value");
  }

  /// Return all ops that produce the array value that is stored into the
  /// `array_merge_store`.
  static void reachingValues(llvm::SmallVectorImpl<mlir::Operation *> &reach,
                             mlir::Value seq) {
    reach.clear();
    mlir::Region *loopRegion = nullptr;
    if (auto doLoop = mlir::dyn_cast_or_null<DoLoopOp>(seq.getDefiningOp()))
      loopRegion = &doLoop->getRegion(0);
    ReachCollector collector(reach, loopRegion);
    collector.collectArrayMentionFrom(seq);
  }

private:
  /// Is \op inside the loop nest region ?
  /// FIXME: replace this structural dependence with graph properties.
  bool opIsInsideLoops(mlir::Operation *op) const {
    auto *region = op->getParentRegion();
    while (region) {
      if (region == loopRegion)
        return true;
      region = region->getParentRegion();
    }
    return false;
  }

  /// Recursively trace the use of an operation results, calling
  /// collectArrayMentionFrom on the direct and indirect user operands.
  void followUsers(mlir::Operation *op) {
    for (auto userOperand : op->getOperands())
      collectArrayMentionFrom(userOperand);
    // Go through potential converts/coordinate_op.
    for (auto indirectUser : op->getUsers())
      followUsers(indirectUser);
  }

  llvm::SmallVectorImpl<mlir::Operation *> &reach;
  llvm::SmallPtrSet<mlir::Value, 16> visited;
  /// Region of the loops nest that produced the array value.
  mlir::Region *loopRegion;
};
} // namespace

/// Find all the array operations that access the array value that is loaded by
/// the array load operation, `load`.
void ArrayCopyAnalysis::arrayMentions(
    llvm::SmallVectorImpl<mlir::Operation *> &mentions, ArrayLoadOp load) {
  mentions.clear();
  auto lmIter = loadMapSets.find(load);
  if (lmIter != loadMapSets.end()) {
    // Cache hit: rebuild the mention list from the saved OpOperand set.
    for (auto *opnd : lmIter->second) {
      auto *owner = opnd->getOwner();
      if (mlir::isa<ArrayAccessOp, ArrayAmendOp, ArrayFetchOp, ArrayUpdateOp,
                    ArrayModifyOp>(owner))
        mentions.push_back(owner);
    }
    return;
  }

  UseSetT visited;
  llvm::SmallVector<mlir::OpOperand *> queue; // uses of ArrayLoad[orig]

  auto appendToQueue = [&](mlir::Value val) {
    for (auto &use : val.getUses())
      if (!visited.count(&use)) {
        visited.insert(&use);
        queue.push_back(&use);
      }
  };

  // Build the set of uses of `original`.
  // let USES = { uses of original fir.load }
  appendToQueue(load);

  // Process the worklist until done.
  while (!queue.empty()) {
    mlir::OpOperand *operand = queue.pop_back_val();
    mlir::Operation *owner = operand->getOwner();
    if (!owner)
      continue;
    // Thread a use that flows into a structured loop (fir.do_loop /
    // fir.iterate_while) through to both the corresponding loop result and
    // the region's block argument.
    auto structuredLoop = [&](auto ro) {
      if (auto blockArg = ro.iterArgToBlockArg(operand->get())) {
        int64_t arg = blockArg.getArgNumber();
        mlir::Value output = ro.getResult(ro.getFinalValue() ? arg : arg - 1);
        appendToQueue(output);
        appendToQueue(blockArg);
      }
    };
    // TODO: this needs to be updated to use the control-flow interface.
    auto branchOp = [&](mlir::Block *dest, OperandRange operands) {
      if (operands.empty())
        return;

      // Check if this operand is within the range.
      unsigned operandIndex = operand->getOperandNumber();
      unsigned operandsStart = operands.getBeginOperandIndex();
      if (operandIndex < operandsStart ||
          operandIndex >= (operandsStart + operands.size()))
        return;

      // Index the successor.
      unsigned argIndex = operandIndex - operandsStart;
      appendToQueue(dest->getArgument(argIndex));
    };
    // Thread uses into structured loop bodies and return value uses.
    if (auto ro = mlir::dyn_cast<DoLoopOp>(owner)) {
      structuredLoop(ro);
    } else if (auto ro = mlir::dyn_cast<IterWhileOp>(owner)) {
      structuredLoop(ro);
    } else if (auto rs = mlir::dyn_cast<ResultOp>(owner)) {
      // Thread any uses of fir.if that return the marked array value.
      mlir::Operation *parent = rs->getParentRegion()->getParentOp();
      if (auto ifOp = mlir::dyn_cast<fir::IfOp>(parent))
        appendToQueue(ifOp.getResult(operand->getOperandNumber()));
    } else if (mlir::isa<ArrayFetchOp>(owner)) {
      // Keep track of array value fetches.
      LLVM_DEBUG(llvm::dbgs()
                 << "add fetch {" << *owner << "} to array value set\n");
      mentions.push_back(owner);
    } else if (auto update = mlir::dyn_cast<ArrayUpdateOp>(owner)) {
      // Keep track of array value updates and thread the return value uses.
      LLVM_DEBUG(llvm::dbgs()
                 << "add update {" << *owner << "} to array value set\n");
      mentions.push_back(owner);
      appendToQueue(update.getResult());
    } else if (auto update = mlir::dyn_cast<ArrayModifyOp>(owner)) {
      // Keep track of array value modification and thread the return value
      // uses.
      LLVM_DEBUG(llvm::dbgs()
                 << "add modify {" << *owner << "} to array value set\n");
      mentions.push_back(owner);
      appendToQueue(update.getResult(1));
    } else if (auto mention = mlir::dyn_cast<ArrayAccessOp>(owner)) {
      mentions.push_back(owner);
    } else if (auto amend = mlir::dyn_cast<ArrayAmendOp>(owner)) {
      mentions.push_back(owner);
      appendToQueue(amend.getResult());
    } else if (auto br = mlir::dyn_cast<mlir::cf::BranchOp>(owner)) {
      branchOp(br.getDest(), br.getDestOperands());
    } else if (auto br = mlir::dyn_cast<mlir::cf::CondBranchOp>(owner)) {
      branchOp(br.getTrueDest(), br.getTrueOperands());
      branchOp(br.getFalseDest(), br.getFalseOperands());
    } else if (mlir::isa<ArrayMergeStoreOp>(owner)) {
      // do nothing
    } else {
      llvm::report_fatal_error("array value reached unexpected op");
    }
  }
  // Cache the transitive OpOperand set for subsequent queries.
  loadMapSets.insert({load, visited});
}

/// Return true iff `type` is (or is a box whose element is) a fir.ptr type.
static bool hasPointerType(mlir::Type type) {
  if (auto boxTy = type.dyn_cast<BoxType>())
    type = boxTy.getEleTy();
  return type.isa<fir::PointerType>();
}

// This is a NF performance hack. It makes a simple test that the slices of the
// load, \p ld, and the merge store, \p st, are trivially mutually exclusive.
static bool mutuallyExclusiveSliceRange(ArrayLoadOp ld, ArrayMergeStoreOp st) {
  // If the same array_load, then no further testing is warranted.
  if (ld.getResult() == st.getOriginal())
    return false;

  auto getSliceOp = [](mlir::Value val) -> SliceOp {
    if (!val)
      return {};
    auto sliceOp = mlir::dyn_cast_or_null<SliceOp>(val.getDefiningOp());
    if (!sliceOp)
      return {};
    return sliceOp;
  };

  auto ldSlice = getSliceOp(ld.getSlice());
  auto stSlice = getSliceOp(st.getSlice());
  if (!ldSlice || !stSlice)
    return false;

  // Resign on subobject slices.
  if (!ldSlice.getFields().empty() || !stSlice.getFields().empty() ||
      !ldSlice.getSubstr().empty() || !stSlice.getSubstr().empty())
    return false;

  // Crudely test that the two slices do not overlap by looking for the
  // following general condition. If the slices look like (i:j) and (j+1:k) then
  // these ranges do not overlap. The addend must be a constant.
  auto ldTriples = ldSlice.getTriples();
  auto stTriples = stSlice.getTriples();
  const auto size = ldTriples.size();
  if (size != stTriples.size())
    return false;

  auto displacedByConstant = [](mlir::Value v1, mlir::Value v2) {
    // Look through any fir.convert ops to reach the defining operation.
    auto removeConvert = [](mlir::Value v) -> mlir::Operation * {
      auto *op = v.getDefiningOp();
      while (auto conv = mlir::dyn_cast_or_null<ConvertOp>(op))
        op = conv.getValue().getDefiningOp();
      return op;
    };

    auto isPositiveConstant = [](mlir::Value v) -> bool {
      if (auto conOp =
              mlir::dyn_cast<mlir::arith::ConstantOp>(v.getDefiningOp()))
        if (auto iattr = conOp.getValue().dyn_cast<mlir::IntegerAttr>())
          return iattr.getInt() > 0;
      return false;
    };

    auto *op1 = removeConvert(v1);
    auto *op2 = removeConvert(v2);
    if (!op1 || !op2)
      return false;
    // Match v2 == v1 + positive-constant (either operand order).
    if (auto addi = mlir::dyn_cast<mlir::arith::AddIOp>(op2))
      if ((addi.getLhs().getDefiningOp() == op1 &&
           isPositiveConstant(addi.getRhs())) ||
          (addi.getRhs().getDefiningOp() == op1 &&
           isPositiveConstant(addi.getLhs())))
        return true;
    // Match v1 == v2 - positive-constant.
    if (auto subi = mlir::dyn_cast<mlir::arith::SubIOp>(op1))
      if (subi.getLhs().getDefiningOp() == op2 &&
          isPositiveConstant(subi.getRhs()))
        return true;
    return false;
  };

  // Triples are stored flat as (lb, ub, stride) groups; walk them in threes.
  for (std::remove_const_t<decltype(size)> i = 0; i < size; i += 3) {
    // If both are loop invariant, skip to the next triple.
    if (mlir::isa_and_nonnull<fir::UndefOp>(ldTriples[i + 1].getDefiningOp()) &&
        mlir::isa_and_nonnull<fir::UndefOp>(stTriples[i + 1].getDefiningOp())) {
      // Unless either is a vector index, then be conservative.
      if (mlir::isa_and_nonnull<fir::UndefOp>(ldTriples[i].getDefiningOp()) ||
          mlir::isa_and_nonnull<fir::UndefOp>(stTriples[i].getDefiningOp()))
        return false;
      continue;
    }
    // If identical, skip to the next triple.
    if (ldTriples[i] == stTriples[i] && ldTriples[i + 1] == stTriples[i + 1] &&
        ldTriples[i + 2] == stTriples[i + 2])
      continue;
    // If ubound and lbound are the same with a constant offset, skip to the
    // next triple.
    if (displacedByConstant(ldTriples[i + 1], stTriples[i]) ||
        displacedByConstant(stTriples[i + 1], ldTriples[i]))
      continue;
    return false;
  }
  LLVM_DEBUG(llvm::dbgs() << "detected non-overlapping slice ranges on " << ld
                          << " and " << st << ", which is not a conflict\n");
  return true;
}

/// Is there a conflict between the array value that was updated and to be
/// stored to `st` and the set of arrays loaded (`reach`) and used to compute
/// the updated value?
544 static bool conflictOnLoad(llvm::ArrayRef<mlir::Operation *> reach, 545 ArrayMergeStoreOp st) { 546 mlir::Value load; 547 mlir::Value addr = st.getMemref(); 548 const bool storeHasPointerType = hasPointerType(addr.getType()); 549 for (auto *op : reach) 550 if (auto ld = mlir::dyn_cast<ArrayLoadOp>(op)) { 551 mlir::Type ldTy = ld.getMemref().getType(); 552 if (ld.getMemref() == addr) { 553 if (mutuallyExclusiveSliceRange(ld, st)) 554 continue; 555 if (ld.getResult() != st.getOriginal()) 556 return true; 557 if (load) { 558 // TODO: extend this to allow checking if the first `load` and this 559 // `ld` are mutually exclusive accesses but not identical. 560 return true; 561 } 562 load = ld; 563 } else if ((hasPointerType(ldTy) || storeHasPointerType)) { 564 // TODO: Use target attribute to restrict this case further. 565 // TODO: Check if types can also allow ruling out some cases. For now, 566 // the fact that equivalences is using pointer attribute to enforce 567 // aliasing is preventing any attempt to do so, and in general, it may 568 // be wrong to use this if any of the types is a complex or a derived 569 // for which it is possible to create a pointer to a part with a 570 // different type than the whole, although this deserve some more 571 // investigation because existing compiler behavior seem to diverge 572 // here. 573 return true; 574 } 575 } 576 return false; 577 } 578 579 /// Is there an access vector conflict on the array being merged into? If the 580 /// access vectors diverge, then assume that there are potentially overlapping 581 /// loop-carried references. 
582 static bool conflictOnMerge(llvm::ArrayRef<mlir::Operation *> mentions) { 583 if (mentions.size() < 2) 584 return false; 585 llvm::SmallVector<mlir::Value> indices; 586 LLVM_DEBUG(llvm::dbgs() << "check merge conflict on with " << mentions.size() 587 << " mentions on the list\n"); 588 bool valSeen = false; 589 bool refSeen = false; 590 for (auto *op : mentions) { 591 llvm::SmallVector<mlir::Value> compareVector; 592 if (auto u = mlir::dyn_cast<ArrayUpdateOp>(op)) { 593 valSeen = true; 594 if (indices.empty()) { 595 indices = u.getIndices(); 596 continue; 597 } 598 compareVector = u.getIndices(); 599 } else if (auto f = mlir::dyn_cast<ArrayModifyOp>(op)) { 600 valSeen = true; 601 if (indices.empty()) { 602 indices = f.getIndices(); 603 continue; 604 } 605 compareVector = f.getIndices(); 606 } else if (auto f = mlir::dyn_cast<ArrayFetchOp>(op)) { 607 valSeen = true; 608 if (indices.empty()) { 609 indices = f.getIndices(); 610 continue; 611 } 612 compareVector = f.getIndices(); 613 } else if (auto f = mlir::dyn_cast<ArrayAccessOp>(op)) { 614 refSeen = true; 615 if (indices.empty()) { 616 indices = f.getIndices(); 617 continue; 618 } 619 compareVector = f.getIndices(); 620 } else if (mlir::isa<ArrayAmendOp>(op)) { 621 refSeen = true; 622 continue; 623 } else { 624 mlir::emitError(op->getLoc(), "unexpected operation in analysis"); 625 } 626 if (compareVector.size() != indices.size() || 627 llvm::any_of(llvm::zip(compareVector, indices), [&](auto pair) { 628 return std::get<0>(pair) != std::get<1>(pair); 629 })) 630 return true; 631 LLVM_DEBUG(llvm::dbgs() << "vectors compare equal\n"); 632 } 633 return valSeen && refSeen; 634 } 635 636 /// With element-by-reference semantics, an amended array with more than once 637 /// access to the same loaded array are conservatively considered a conflict. 638 /// Note: the array copy can still be eliminated in subsequent optimizations. 
639 static bool conflictOnReference(llvm::ArrayRef<mlir::Operation *> mentions) { 640 LLVM_DEBUG(llvm::dbgs() << "checking reference semantics " << mentions.size() 641 << '\n'); 642 if (mentions.size() < 3) 643 return false; 644 unsigned amendCount = 0; 645 unsigned accessCount = 0; 646 for (auto *op : mentions) { 647 if (mlir::isa<ArrayAmendOp>(op) && ++amendCount > 1) { 648 LLVM_DEBUG(llvm::dbgs() << "conflict: multiple amends of array value\n"); 649 return true; 650 } 651 if (mlir::isa<ArrayAccessOp>(op) && ++accessCount > 1) { 652 LLVM_DEBUG(llvm::dbgs() 653 << "conflict: multiple accesses of array value\n"); 654 return true; 655 } 656 if (mlir::isa<ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp>(op)) { 657 LLVM_DEBUG(llvm::dbgs() 658 << "conflict: array value has both uses by-value and uses " 659 "by-reference. conservative assumption.\n"); 660 return true; 661 } 662 } 663 return false; 664 } 665 666 static mlir::Operation * 667 amendingAccess(llvm::ArrayRef<mlir::Operation *> mentions) { 668 for (auto *op : mentions) 669 if (auto amend = mlir::dyn_cast<ArrayAmendOp>(op)) 670 return amend.getMemref().getDefiningOp(); 671 return {}; 672 } 673 674 // Are any conflicts present? The conflicts detected here are described above. 675 static bool conflictDetected(llvm::ArrayRef<mlir::Operation *> reach, 676 llvm::ArrayRef<mlir::Operation *> mentions, 677 ArrayMergeStoreOp st) { 678 return conflictOnLoad(reach, st) || conflictOnMerge(mentions); 679 } 680 681 // Assume that any call to a function that uses host-associations will be 682 // modifying the output array. 
static bool
conservativeCallConflict(llvm::ArrayRef<mlir::Operation *> reaches) {
  return llvm::any_of(reaches, [](mlir::Operation *op) {
    if (auto call = mlir::dyn_cast<fir::CallOp>(op))
      if (auto callee =
              call.getCallableForCallee().dyn_cast<mlir::SymbolRefAttr>()) {
        auto module = op->getParentOfType<mlir::ModuleOp>();
        // Conflict iff the callee takes a host-association argument.
        return hasHostAssociationArgument(
            module.lookupSymbol<mlir::func::FuncOp>(callee));
      }
    return false;
  });
}

/// Constructor of the array copy analysis.
/// This performs the analysis and saves the intermediate results.
void ArrayCopyAnalysis::construct(mlir::Operation *topLevelOp) {
  topLevelOp->walk([&](Operation *op) {
    if (auto st = mlir::dyn_cast<fir::ArrayMergeStoreOp>(op)) {
      // Gather every op that can produce the stored sequence, then check the
      // three conflict conditions (call, load/merge, reference semantics).
      llvm::SmallVector<mlir::Operation *> values;
      ReachCollector::reachingValues(values, st.getSequence());
      bool callConflict = conservativeCallConflict(values);
      llvm::SmallVector<mlir::Operation *> mentions;
      arrayMentions(mentions,
                    mlir::cast<ArrayLoadOp>(st.getOriginal().getDefiningOp()));
      bool conflict = conflictDetected(values, mentions, st);
      bool refConflict = conflictOnReference(mentions);
      if (callConflict || conflict || refConflict) {
        LLVM_DEBUG(llvm::dbgs()
                   << "CONFLICT: copies required for " << st << '\n'
                   << " adding conflicts on: " << *op << " and "
                   << st.getOriginal() << '\n');
        conflicts.insert(op);
        conflicts.insert(st.getOriginal().getDefiningOp());
        if (auto *access = amendingAccess(mentions))
          amendAccesses.insert(access);
      }
      // Map the original array_load to this array_merge_store.
      auto *ld = st.getOriginal().getDefiningOp();
      LLVM_DEBUG(llvm::dbgs()
                 << "map: adding {" << *ld << " -> " << st << "}\n");
      useMap.insert({ld, op});
    } else if (auto load = mlir::dyn_cast<ArrayLoadOp>(op)) {
      // Map each array operation mentioning this load back to the load.
      llvm::SmallVector<mlir::Operation *> mentions;
      arrayMentions(mentions, load);
      LLVM_DEBUG(llvm::dbgs() << "process load: " << load
                              << ", mentions: " << mentions.size() << '\n');
      for (auto *acc : mentions) {
        LLVM_DEBUG(llvm::dbgs() << " mention: " << *acc << '\n');
        if (mlir::isa<ArrayAccessOp, ArrayAmendOp, ArrayFetchOp, ArrayUpdateOp,
                      ArrayModifyOp>(acc)) {
          if (useMap.count(acc)) {
            mlir::emitError(
                load.getLoc(),
                "The parallel semantics of multiple array_merge_stores per "
                "array_load are not supported.");
            continue;
          }
          LLVM_DEBUG(llvm::dbgs()
                     << "map: adding {" << *acc << "} -> {" << load << "}\n");
          useMap.insert({acc, op});
        }
      }
    }
  });
}

//===----------------------------------------------------------------------===//
// Conversions for converting out of array value form.
//===----------------------------------------------------------------------===//

namespace {
/// Rewrite an array_load to an undef value of the same type.
class ArrayLoadConversion : public mlir::OpRewritePattern<ArrayLoadOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(ArrayLoadOp load,
                  mlir::PatternRewriter &rewriter) const override {
    LLVM_DEBUG(llvm::dbgs() << "replace load " << load << " with undef.\n");
    rewriter.replaceOpWithNewOp<UndefOp>(load, load.getType());
    return mlir::success();
  }
};

/// Erase an array_merge_store.
class ArrayMergeStoreConversion
    : public mlir::OpRewritePattern<ArrayMergeStoreOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(ArrayMergeStoreOp store,
                  mlir::PatternRewriter &rewriter) const override {
    LLVM_DEBUG(llvm::dbgs() << "marking store " << store << " as dead.\n");
    rewriter.eraseOp(store);
    return mlir::success();
  }
};
} // namespace

/// Return a reference type to the element type of `ty`.
static mlir::Type getEleTy(mlir::Type ty) {
  auto eleTy = unwrapSequenceType(unwrapPassByRefType(ty));
  // FIXME: keep ptr/heap/ref information.
  return ReferenceType::get(eleTy);
}

// Extract extents from the ShapeOp/ShapeShiftOp into the result vector.
static bool getAdjustedExtents(mlir::Location loc,
                               mlir::PatternRewriter &rewriter,
                               ArrayLoadOp arrLoad,
                               llvm::SmallVectorImpl<mlir::Value> &result,
                               mlir::Value shape) {
  // Returns true iff slicing of the output array must be done in the
  // copy-in/copy-out (assumed-size case with a slice).
  bool copyUsingSlice = false;
  auto *shapeOp = shape.getDefiningOp();
  if (auto s = mlir::dyn_cast_or_null<ShapeOp>(shapeOp)) {
    auto e = s.getExtents();
    result.insert(result.end(), e.begin(), e.end());
  } else if (auto s = mlir::dyn_cast_or_null<ShapeShiftOp>(shapeOp)) {
    auto e = s.getExtents();
    result.insert(result.end(), e.begin(), e.end());
  } else {
    emitFatalError(loc, "not a fir.shape/fir.shape_shift op");
  }
  auto idxTy = rewriter.getIndexType();
  if (factory::isAssumedSize(result)) {
    // Use slice information to compute the extent of the column.
    auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1);
    mlir::Value size = one;
    if (mlir::Value sliceArg = arrLoad.getSlice()) {
      if (auto sliceOp =
              mlir::dyn_cast_or_null<SliceOp>(sliceArg.getDefiningOp())) {
        auto triples = sliceOp.getTriples();
        const std::size_t tripleSize = triples.size();
        auto module = arrLoad->getParentOfType<mlir::ModuleOp>();
        fir::KindMapping kindMap = getKindMapping(module);
        FirOpBuilder builder(rewriter, kindMap);
        // The last triple (lb, ub, stride) gives the extent of the
        // assumed-size column.
        size = builder.genExtentFromTriplet(loc, triples[tripleSize - 3],
                                            triples[tripleSize - 2],
                                            triples[tripleSize - 1], idxTy);
        copyUsingSlice = true;
      }
    }
    result[result.size() - 1] = size;
  }
  return copyUsingSlice;
}

/// Place the extents of the array load, \p arrLoad, into \p result and
/// return a ShapeOp or ShapeShiftOp with the same extents. If \p arrLoad is
/// loading a `!fir.box`, code will be generated to read the extents from the
/// boxed value, and the returned shape Op will be built with the extents read
/// from the box. Otherwise, the extents will be extracted from the ShapeOp (or
/// ShapeShiftOp) argument of \p arrLoad. \p copyUsingSlice will be set to true
/// if slicing of the output array is to be done in the copy-in/copy-out rather
/// than in the elemental computation step.
static mlir::Value getOrReadExtentsAndShapeOp(
    mlir::Location loc, mlir::PatternRewriter &rewriter, ArrayLoadOp arrLoad,
    llvm::SmallVectorImpl<mlir::Value> &result, bool &copyUsingSlice) {
  assert(result.empty());
  if (arrLoad->hasAttr(fir::getOptionalAttrName()))
    fir::emitFatalError(
        loc, "shapes from array load of OPTIONAL arrays must not be used");
  if (auto boxTy = arrLoad.getMemref().getType().dyn_cast<BoxType>()) {
    // Boxed case: read each extent from the box with fir.box_dims.
    auto rank =
        dyn_cast_ptrOrBoxEleTy(boxTy).cast<SequenceType>().getDimension();
    auto idxTy = rewriter.getIndexType();
    for (decltype(rank) dim = 0; dim < rank; ++dim) {
      auto dimVal = rewriter.create<mlir::arith::ConstantIndexOp>(loc, dim);
      auto dimInfo = rewriter.create<BoxDimsOp>(loc, idxTy, idxTy, idxTy,
                                                arrLoad.getMemref(), dimVal);
      // Result #1 of box_dims is the extent.
      result.emplace_back(dimInfo.getResult(1));
    }
    if (!arrLoad.getShape()) {
      auto shapeType = ShapeType::get(rewriter.getContext(), rank);
      return rewriter.create<ShapeOp>(loc, shapeType, result);
    }
    // A fir.shift on the load supplies lower bounds: interleave them with the
    // extents read from the box to build a fir.shape_shift.
    auto shiftOp = arrLoad.getShape().getDefiningOp<ShiftOp>();
    auto shapeShiftType = ShapeShiftType::get(rewriter.getContext(), rank);
    llvm::SmallVector<mlir::Value> shapeShiftOperands;
    for (auto [lb, extent] : llvm::zip(shiftOp.getOrigins(), result)) {
      shapeShiftOperands.push_back(lb);
      shapeShiftOperands.push_back(extent);
    }
    return rewriter.create<ShapeShiftOp>(loc, shapeShiftType,
                                         shapeShiftOperands);
  }
  copyUsingSlice =
      getAdjustedExtents(loc, rewriter, arrLoad, result, arrLoad.getShape());
  return arrLoad.getShape();
}

/// Wrap `ty` in a fir.ref type unless it is already a reference type.
static mlir::Type toRefType(mlir::Type ty) {
  if (fir::isa_ref_type(ty))
    return ty;
  return fir::ReferenceType::get(ty);
}

/// Return the type parameters of `arrLoad` only when the data is raw
/// (not boxed); for a box type an empty vector is returned.
static llvm::SmallVector<mlir::Value>
getTypeParamsIfRawData(mlir::Location loc, FirOpBuilder &builder,
                       ArrayLoadOp arrLoad, mlir::Type ty) {
  if (ty.isa<BoxType>())
    return {};
  return fir::factory::getTypeParams(loc, builder, arrLoad);
}

/// Generate an array_coor addressing an element of `alloc`, followed by a
/// coordinate_of for any trailing indices beyond the array's dimension.
static mlir::Value genCoorOp(mlir::PatternRewriter &rewriter,
                             mlir::Location loc, mlir::Type eleTy,
                             mlir::Type resTy, mlir::Value alloc,
                             mlir::Value shape, mlir::Value slice,
                             mlir::ValueRange indices, ArrayLoadOp load,
                             bool skipOrig = false) {
  llvm::SmallVector<mlir::Value> originated;
  if (skipOrig)
    originated.assign(indices.begin(), indices.end());
  else
    originated = factory::originateIndices(loc, rewriter, alloc.getType(),
                                           shape, indices);
  auto seqTy = dyn_cast_ptrOrBoxEleTy(alloc.getType());
  assert(seqTy && seqTy.isa<SequenceType>());
  const auto dimension = seqTy.cast<SequenceType>().getDimension();
  auto module = load->getParentOfType<mlir::ModuleOp>();
  fir::KindMapping kindMap = getKindMapping(module);
  FirOpBuilder builder(rewriter, kindMap);
  auto typeparams = getTypeParamsIfRawData(loc, builder, load, alloc.getType());
  mlir::Value result = rewriter.create<ArrayCoorOp>(
      loc, eleTy, alloc, shape, slice,
      llvm::ArrayRef<mlir::Value>{originated}.take_front(dimension),
      typeparams);
  if (dimension < originated.size())
    result = rewriter.create<fir::CoordinateOp>(
        loc, resTy, result,
        llvm::ArrayRef<mlir::Value>{originated}.drop_front(dimension));
  return result;
}

/// Compute the CHARACTER length of the elements loaded by `load`.
static mlir::Value getCharacterLen(mlir::Location loc, FirOpBuilder &builder,
                                   ArrayLoadOp load, CharacterType charTy) {
  auto charLenTy = builder.getCharacterLengthType();
  if (charTy.hasDynamicLen()) {
    if (load.getMemref().getType().isa<BoxType>()) {
      // The loaded array is an emboxed value. Get the CHARACTER length from
      // the box value.
924 auto eleSzInBytes = 925 builder.create<BoxEleSizeOp>(loc, charLenTy, load.getMemref()); 926 auto kindSize = 927 builder.getKindMap().getCharacterBitsize(charTy.getFKind()); 928 auto kindByteSize = 929 builder.createIntegerConstant(loc, charLenTy, kindSize / 8); 930 return builder.create<mlir::arith::DivSIOp>(loc, eleSzInBytes, 931 kindByteSize); 932 } 933 // The loaded array is a (set of) unboxed values. If the CHARACTER's 934 // length is not a constant, it must be provided as a type parameter to 935 // the array_load. 936 auto typeparams = load.getTypeparams(); 937 assert(typeparams.size() > 0 && "expected type parameters on array_load"); 938 return typeparams.back(); 939 } 940 // The typical case: the length of the CHARACTER is a compile-time 941 // constant that is encoded in the type information. 942 return builder.createIntegerConstant(loc, charLenTy, charTy.getLen()); 943 } 944 /// Generate a shallow array copy. This is used for both copy-in and copy-out. 945 template <bool CopyIn> 946 void genArrayCopy(mlir::Location loc, mlir::PatternRewriter &rewriter, 947 mlir::Value dst, mlir::Value src, mlir::Value shapeOp, 948 mlir::Value sliceOp, ArrayLoadOp arrLoad) { 949 auto insPt = rewriter.saveInsertionPoint(); 950 llvm::SmallVector<mlir::Value> indices; 951 llvm::SmallVector<mlir::Value> extents; 952 bool copyUsingSlice = 953 getAdjustedExtents(loc, rewriter, arrLoad, extents, shapeOp); 954 auto idxTy = rewriter.getIndexType(); 955 // Build loop nest from column to row. 
956 for (auto sh : llvm::reverse(extents)) { 957 auto ubi = rewriter.create<ConvertOp>(loc, idxTy, sh); 958 auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0); 959 auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1); 960 auto ub = rewriter.create<mlir::arith::SubIOp>(loc, idxTy, ubi, one); 961 auto loop = rewriter.create<DoLoopOp>(loc, zero, ub, one); 962 rewriter.setInsertionPointToStart(loop.getBody()); 963 indices.push_back(loop.getInductionVar()); 964 } 965 // Reverse the indices so they are in column-major order. 966 std::reverse(indices.begin(), indices.end()); 967 auto module = arrLoad->getParentOfType<mlir::ModuleOp>(); 968 fir::KindMapping kindMap = getKindMapping(module); 969 FirOpBuilder builder(rewriter, kindMap); 970 auto fromAddr = rewriter.create<ArrayCoorOp>( 971 loc, getEleTy(src.getType()), src, shapeOp, 972 CopyIn && copyUsingSlice ? sliceOp : mlir::Value{}, 973 factory::originateIndices(loc, rewriter, src.getType(), shapeOp, indices), 974 getTypeParamsIfRawData(loc, builder, arrLoad, src.getType())); 975 auto toAddr = rewriter.create<ArrayCoorOp>( 976 loc, getEleTy(dst.getType()), dst, shapeOp, 977 !CopyIn && copyUsingSlice ? sliceOp : mlir::Value{}, 978 factory::originateIndices(loc, rewriter, dst.getType(), shapeOp, indices), 979 getTypeParamsIfRawData(loc, builder, arrLoad, dst.getType())); 980 auto eleTy = unwrapSequenceType(unwrapPassByRefType(dst.getType())); 981 // Copy from (to) object to (from) temp copy of same object. 
982 if (auto charTy = eleTy.dyn_cast<CharacterType>()) { 983 auto len = getCharacterLen(loc, builder, arrLoad, charTy); 984 CharBoxValue toChar(toAddr, len); 985 CharBoxValue fromChar(fromAddr, len); 986 factory::genScalarAssignment(builder, loc, toChar, fromChar); 987 } else { 988 if (hasDynamicSize(eleTy)) 989 TODO(loc, "copy element of dynamic size"); 990 factory::genScalarAssignment(builder, loc, toAddr, fromAddr); 991 } 992 rewriter.restoreInsertionPoint(insPt); 993 } 994 995 /// The array load may be either a boxed or unboxed value. If the value is 996 /// boxed, we read the type parameters from the boxed value. 997 static llvm::SmallVector<mlir::Value> 998 genArrayLoadTypeParameters(mlir::Location loc, mlir::PatternRewriter &rewriter, 999 ArrayLoadOp load) { 1000 if (load.getTypeparams().empty()) { 1001 auto eleTy = 1002 unwrapSequenceType(unwrapPassByRefType(load.getMemref().getType())); 1003 if (hasDynamicSize(eleTy)) { 1004 if (auto charTy = eleTy.dyn_cast<CharacterType>()) { 1005 assert(load.getMemref().getType().isa<BoxType>()); 1006 auto module = load->getParentOfType<mlir::ModuleOp>(); 1007 fir::KindMapping kindMap = getKindMapping(module); 1008 FirOpBuilder builder(rewriter, kindMap); 1009 return {getCharacterLen(loc, builder, load, charTy)}; 1010 } 1011 TODO(loc, "unhandled dynamic type parameters"); 1012 } 1013 return {}; 1014 } 1015 return load.getTypeparams(); 1016 } 1017 1018 static llvm::SmallVector<mlir::Value> 1019 findNonconstantExtents(mlir::Type memrefTy, 1020 llvm::ArrayRef<mlir::Value> extents) { 1021 llvm::SmallVector<mlir::Value> nce; 1022 auto arrTy = unwrapPassByRefType(memrefTy); 1023 auto seqTy = arrTy.cast<SequenceType>(); 1024 for (auto [s, x] : llvm::zip(seqTy.getShape(), extents)) 1025 if (s == SequenceType::getUnknownExtent()) 1026 nce.emplace_back(x); 1027 if (extents.size() > seqTy.getShape().size()) 1028 for (auto x : extents.drop_front(seqTy.getShape().size())) 1029 nce.emplace_back(x); 1030 return nce; 1031 } 1032 1033 
/// Allocate temporary storage for an ArrayLoadOp \p load and initialize any
/// allocatable direct components of the array elements with an unallocated
/// status. Returns the temporary address as well as a callback to generate the
/// temporary clean-up once it has been used. The clean-up will take care of
/// deallocating all the element allocatable components that may have been
/// allocated while using the temporary.
static std::pair<mlir::Value,
                 std::function<void(mlir::PatternRewriter &rewriter)>>
allocateArrayTemp(mlir::Location loc, mlir::PatternRewriter &rewriter,
                  ArrayLoadOp load, llvm::ArrayRef<mlir::Value> extents,
                  mlir::Value shape) {
  mlir::Type baseType = load.getMemref().getType();
  // AllocMemOp only takes the extents that are not compile-time constants.
  llvm::SmallVector<mlir::Value> nonconstantExtents =
      findNonconstantExtents(baseType, extents);
  llvm::SmallVector<mlir::Value> typeParams =
      genArrayLoadTypeParameters(loc, rewriter, load);
  mlir::Value allocmem = rewriter.create<AllocMemOp>(
      loc, dyn_cast_ptrOrBoxEleTy(baseType), typeParams, nonconstantExtents);
  mlir::Type eleType =
      fir::unwrapSequenceType(fir::unwrapPassByRefType(baseType));
  if (fir::isRecordWithAllocatableMember(eleType)) {
    // The allocatable component descriptors need to be set to a clean
    // deallocated status before anything is done with them.
    mlir::Value box = rewriter.create<fir::EmboxOp>(
        loc, fir::BoxType::get(baseType), allocmem, shape,
        /*slice=*/mlir::Value{}, typeParams);
    auto module = load->getParentOfType<mlir::ModuleOp>();
    fir::KindMapping kindMap = getKindMapping(module);
    FirOpBuilder builder(rewriter, kindMap);
    runtime::genDerivedTypeInitialize(builder, loc, box);
    // Any allocatable component that may have been allocated must be
    // deallocated during the clean-up.
    auto cleanup = [=](mlir::PatternRewriter &r) {
      fir::KindMapping kindMap = getKindMapping(module);
      FirOpBuilder builder(r, kindMap);
      runtime::genDerivedTypeDestroy(builder, loc, box);
      r.create<FreeMemOp>(loc, allocmem);
    };
    return {allocmem, cleanup};
  }
  // Simple case: the clean-up only needs to free the temporary storage.
  auto cleanup = [=](mlir::PatternRewriter &r) {
    r.create<FreeMemOp>(loc, allocmem);
  };
  return {allocmem, cleanup};
}

namespace {
/// Conversion of fir.array_update and fir.array_modify Ops.
/// If there is a conflict for the update, then we need to perform a
/// copy-in/copy-out to preserve the original values of the array. If there is
/// no conflict, then it is safe to eschew making any copies.
template <typename ArrayOp>
class ArrayUpdateConversionBase : public mlir::OpRewritePattern<ArrayOp> {
public:
  // TODO: Implement copy/swap semantics?
  explicit ArrayUpdateConversionBase(mlir::MLIRContext *ctx,
                                     const ArrayCopyAnalysis &a,
                                     const OperationUseMapT &m)
      : mlir::OpRewritePattern<ArrayOp>{ctx}, analysis{a}, useMap{m} {}

  /// The array_access, \p access, is to be to a cloned copy due to a potential
  /// conflict. Uses copy-in/copy-out semantics and not copy/swap.
  mlir::Value referenceToClone(mlir::Location loc,
                               mlir::PatternRewriter &rewriter,
                               ArrayOp access) const {
    LLVM_DEBUG(llvm::dbgs()
               << "generating copy-in/copy-out loops for " << access << '\n');
    auto *op = access.getOperation();
    auto *loadOp = useMap.lookup(op);
    auto load = mlir::cast<ArrayLoadOp>(loadOp);
    auto eleTy = access.getType();
    rewriter.setInsertionPoint(loadOp);
    // Copy in.
    llvm::SmallVector<mlir::Value> extents;
    bool copyUsingSlice = false;
    auto shapeOp = getOrReadExtentsAndShapeOp(loc, rewriter, load, extents,
                                              copyUsingSlice);
    auto [allocmem, genTempCleanUp] =
        allocateArrayTemp(loc, rewriter, load, extents, shapeOp);
    genArrayCopy</*copyIn=*/true>(load.getLoc(), rewriter, allocmem,
                                  load.getMemref(), shapeOp, load.getSlice(),
                                  load);
    // Generate the reference for the access.
    rewriter.setInsertionPoint(op);
    auto coor = genCoorOp(
        rewriter, loc, getEleTy(load.getType()), eleTy, allocmem, shapeOp,
        copyUsingSlice ? mlir::Value{} : load.getSlice(), access.getIndices(),
        load, access->hasAttr(factory::attrFortranArrayOffsets()));
    // Find the array_merge_store that consumes this load.
    auto *storeOp = useMap.lookup(loadOp);
    auto store = mlir::cast<ArrayMergeStoreOp>(storeOp);
    rewriter.setInsertionPoint(storeOp);
    // Copy out.
    genArrayCopy</*copyIn=*/false>(store.getLoc(), rewriter, store.getMemref(),
                                   allocmem, shapeOp, store.getSlice(), load);
    genTempCleanUp(rewriter);
    return coor;
  }

  /// Copy the RHS element into the LHS and insert copy-in/copy-out between a
  /// temp and the LHS if the analysis found potential overlaps between the RHS
  /// and LHS arrays. The element copy generator must be provided in \p
  /// assignElement. \p update must be the ArrayUpdateOp or the ArrayModifyOp.
  /// Returns the address of the LHS element inside the loop and the LHS
  /// ArrayLoad result.
  std::pair<mlir::Value, mlir::Value>
  materializeAssignment(mlir::Location loc, mlir::PatternRewriter &rewriter,
                        ArrayOp update,
                        const std::function<void(mlir::Value)> &assignElement,
                        mlir::Type lhsEltRefType) const {
    auto *op = update.getOperation();
    auto *loadOp = useMap.lookup(op);
    auto load = mlir::cast<ArrayLoadOp>(loadOp);
    LLVM_DEBUG(llvm::outs() << "does " << load << " have a conflict?\n");
    if (analysis.hasPotentialConflict(loadOp)) {
      // If there is a conflict between the arrays, then we copy the lhs array
      // to a temporary, update the temporary, and copy the temporary back to
      // the lhs array. This yields Fortran's copy-in copy-out array semantics.
      LLVM_DEBUG(llvm::outs() << "Yes, conflict was found\n");
      rewriter.setInsertionPoint(loadOp);
      // Copy in.
      llvm::SmallVector<mlir::Value> extents;
      bool copyUsingSlice = false;
      auto shapeOp = getOrReadExtentsAndShapeOp(loc, rewriter, load, extents,
                                                copyUsingSlice);
      auto [allocmem, genTempCleanUp] =
          allocateArrayTemp(loc, rewriter, load, extents, shapeOp);

      genArrayCopy</*copyIn=*/true>(load.getLoc(), rewriter, allocmem,
                                    load.getMemref(), shapeOp, load.getSlice(),
                                    load);
      // Address the updated element inside the temporary.
      rewriter.setInsertionPoint(op);
      auto coor = genCoorOp(
          rewriter, loc, getEleTy(load.getType()), lhsEltRefType, allocmem,
          shapeOp, copyUsingSlice ? mlir::Value{} : load.getSlice(),
          update.getIndices(), load,
          update->hasAttr(factory::attrFortranArrayOffsets()));
      assignElement(coor);
      auto *storeOp = useMap.lookup(loadOp);
      auto store = mlir::cast<ArrayMergeStoreOp>(storeOp);
      rewriter.setInsertionPoint(storeOp);
      // Copy out.
      genArrayCopy</*copyIn=*/false>(store.getLoc(), rewriter,
                                     store.getMemref(), allocmem, shapeOp,
                                     store.getSlice(), load);
      genTempCleanUp(rewriter);
      return {coor, load.getResult()};
    }
    // Otherwise, when there is no conflict (a possible loop-carried
    // dependence), the lhs array can be updated in place.
    LLVM_DEBUG(llvm::outs() << "No, conflict wasn't found\n");
    rewriter.setInsertionPoint(op);
    auto coorTy = getEleTy(load.getType());
    auto coor =
        genCoorOp(rewriter, loc, coorTy, lhsEltRefType, load.getMemref(),
                  load.getShape(), load.getSlice(), update.getIndices(), load,
                  update->hasAttr(factory::attrFortranArrayOffsets()));
    assignElement(coor);
    return {coor, load.getResult()};
  }

protected:
  const ArrayCopyAnalysis &analysis;
  const OperationUseMapT &useMap;
};

/// Rewrite fir.array_update: store the merged element through the address
/// produced by materializeAssignment.
class ArrayUpdateConversion : public ArrayUpdateConversionBase<ArrayUpdateOp> {
public:
  explicit ArrayUpdateConversion(mlir::MLIRContext *ctx,
                                 const ArrayCopyAnalysis &a,
                                 const OperationUseMapT &m)
      : ArrayUpdateConversionBase{ctx, a, m} {}

  mlir::LogicalResult
  matchAndRewrite(ArrayUpdateOp update,
                  mlir::PatternRewriter &rewriter) const override {
    auto loc = update.getLoc();
    auto assignElement = [&](mlir::Value coor) {
      auto input = update.getMerge();
      if (auto inEleTy = dyn_cast_ptrEleTy(input.getType())) {
        emitFatalError(loc, "array_update on references not supported");
      } else {
        rewriter.create<fir::StoreOp>(loc, input, coor);
      }
    };
    auto lhsEltRefType = toRefType(update.getMerge().getType());
    auto [_, lhsLoadResult] = materializeAssignment(
        loc, rewriter, update, assignElement, lhsEltRefType);
    update.replaceAllUsesWith(lhsLoadResult);
    rewriter.replaceOp(update, lhsLoadResult);
    return mlir::success();
  }
};

/// Rewrite fir.array_modify: the element assignment was already materialized
/// by lowering, so only the element address needs to be produced here.
class ArrayModifyConversion : public ArrayUpdateConversionBase<ArrayModifyOp> {
public:
  explicit ArrayModifyConversion(mlir::MLIRContext *ctx,
                                 const ArrayCopyAnalysis &a,
                                 const OperationUseMapT &m)
      : ArrayUpdateConversionBase{ctx, a, m} {}

  mlir::LogicalResult
  matchAndRewrite(ArrayModifyOp modify,
                  mlir::PatternRewriter &rewriter) const override {
    auto loc = modify.getLoc();
    auto assignElement = [](mlir::Value) {
      // Assignment already materialized by lowering using lhs element address.
    };
    auto lhsEltRefType = modify.getResult(0).getType();
    auto [lhsEltCoor, lhsLoadResult] = materializeAssignment(
        loc, rewriter, modify, assignElement, lhsEltRefType);
    modify.replaceAllUsesWith(mlir::ValueRange{lhsEltCoor, lhsLoadResult});
    rewriter.replaceOp(modify, mlir::ValueRange{lhsEltCoor, lhsLoadResult});
    return mlir::success();
  }
};

/// Rewrite fir.array_fetch to an array_coor address plus a fir.load, or to
/// just the address when the fetch result already has reference type.
class ArrayFetchConversion : public mlir::OpRewritePattern<ArrayFetchOp> {
public:
  explicit ArrayFetchConversion(mlir::MLIRContext *ctx,
                                const OperationUseMapT &m)
      : OpRewritePattern{ctx}, useMap{m} {}

  mlir::LogicalResult
  matchAndRewrite(ArrayFetchOp fetch,
                  mlir::PatternRewriter &rewriter) const override {
    auto *op = fetch.getOperation();
    rewriter.setInsertionPoint(op);
    auto load = mlir::cast<ArrayLoadOp>(useMap.lookup(op));
    auto loc = fetch.getLoc();
    auto coor = genCoorOp(
        rewriter, loc, getEleTy(load.getType()), toRefType(fetch.getType()),
        load.getMemref(), load.getShape(), load.getSlice(), fetch.getIndices(),
        load, fetch->hasAttr(factory::attrFortranArrayOffsets()));
    if (isa_ref_type(fetch.getType()))
      rewriter.replaceOp(fetch, coor);
    else
      rewriter.replaceOpWithNewOp<fir::LoadOp>(fetch, coor);
    return mlir::success();
  }

private:
  const OperationUseMapT &useMap;
};

/// An array_access op is like an array_fetch op, except that it does not imply
/// a load op. (It operates in the reference domain.)
class ArrayAccessConversion : public ArrayUpdateConversionBase<ArrayAccessOp> {
public:
  explicit ArrayAccessConversion(mlir::MLIRContext *ctx,
                                 const ArrayCopyAnalysis &a,
                                 const OperationUseMapT &m)
      : ArrayUpdateConversionBase{ctx, a, m} {}

  mlir::LogicalResult
  matchAndRewrite(ArrayAccessOp access,
                  mlir::PatternRewriter &rewriter) const override {
    auto *op = access.getOperation();
    auto loc = access.getLoc();
    if (analysis.inAmendAccessSet(op)) {
      // This array_access is associated with an array_amend and there is a
      // conflict. Make a copy to store into.
      auto result = referenceToClone(loc, rewriter, access);
      access.replaceAllUsesWith(result);
      rewriter.replaceOp(access, result);
      return mlir::success();
    }
    rewriter.setInsertionPoint(op);
    auto load = mlir::cast<ArrayLoadOp>(useMap.lookup(op));
    auto coor = genCoorOp(
        rewriter, loc, getEleTy(load.getType()), toRefType(access.getType()),
        load.getMemref(), load.getShape(), load.getSlice(), access.getIndices(),
        load, access->hasAttr(factory::attrFortranArrayOffsets()));
    rewriter.replaceOp(access, coor);
    return mlir::success();
  }
};

/// An array_amend op is a marker to record which array access is being used to
/// update an array value. After this pass runs, an array_amend has no
/// semantics. We rewrite these to undefined values here to remove them while
/// preserving SSA form.
class ArrayAmendConversion : public mlir::OpRewritePattern<ArrayAmendOp> {
public:
  explicit ArrayAmendConversion(mlir::MLIRContext *ctx)
      : OpRewritePattern{ctx} {}

  mlir::LogicalResult
  matchAndRewrite(ArrayAmendOp amend,
                  mlir::PatternRewriter &rewriter) const override {
    auto *op = amend.getOperation();
    rewriter.setInsertionPoint(op);
    auto loc = amend.getLoc();
    // Replace the amend with an undef of the same type to keep SSA form.
    auto undef = rewriter.create<UndefOp>(loc, amend.getType());
    rewriter.replaceOp(amend, undef.getResult());
    return mlir::success();
  }
};

/// Driver for the array-value-copy pass: run the conflict analysis, then
/// rewrite the array ops in two conversion phases.
class ArrayValueCopyConverter
    : public fir::impl::ArrayValueCopyBase<ArrayValueCopyConverter> {
public:
  void runOnOperation() override {
    auto func = getOperation();
    LLVM_DEBUG(llvm::dbgs() << "\n\narray-value-copy pass on function '"
                            << func.getName() << "'\n");
    auto *context = &getContext();

    // Perform the conflict analysis.
    const auto &analysis = getAnalysis<ArrayCopyAnalysis>();
    const auto &useMap = analysis.getUseMap();

    // Phase 1: convert the array access/fetch/update/modify/amend ops.
    mlir::RewritePatternSet patterns1(context);
    patterns1.insert<ArrayFetchConversion>(context, useMap);
    patterns1.insert<ArrayUpdateConversion>(context, analysis, useMap);
    patterns1.insert<ArrayModifyConversion>(context, analysis, useMap);
    patterns1.insert<ArrayAccessConversion>(context, analysis, useMap);
    patterns1.insert<ArrayAmendConversion>(context);
    mlir::ConversionTarget target(*context);
    target
        .addLegalDialect<FIROpsDialect, mlir::scf::SCFDialect,
                         mlir::arith::ArithDialect, mlir::func::FuncDialect>();
    target.addIllegalOp<ArrayAccessOp, ArrayAmendOp, ArrayFetchOp,
                        ArrayUpdateOp, ArrayModifyOp>();
    // Rewrite the array fetch and array update ops.
    if (mlir::failed(
            mlir::applyPartialConversion(func, target, std::move(patterns1)))) {
      mlir::emitError(mlir::UnknownLoc::get(context),
                      "failure in array-value-copy pass, phase 1");
      signalPassFailure();
    }

    // Phase 2: convert the remaining array_load/array_merge_store ops now
    // that all of their uses have been rewritten.
    mlir::RewritePatternSet patterns2(context);
    patterns2.insert<ArrayLoadConversion>(context);
    patterns2.insert<ArrayMergeStoreConversion>(context);
    target.addIllegalOp<ArrayLoadOp, ArrayMergeStoreOp>();
    if (mlir::failed(
            mlir::applyPartialConversion(func, target, std::move(patterns2)))) {
      mlir::emitError(mlir::UnknownLoc::get(context),
                      "failure in array-value-copy pass, phase 2");
      signalPassFailure();
    }
  }
};
} // namespace

std::unique_ptr<mlir::Pass> fir::createArrayValueCopyPass() {
  return std::make_unique<ArrayValueCopyConverter>();
}