//===-- ArrayValueCopy.cpp ------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "flang/Optimizer/Builder/Array.h"
#include "flang/Optimizer/Builder/BoxValue.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/Factory.h"
#include "flang/Optimizer/Builder/Runtime/Derived.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
#include "flang/Optimizer/Support/FIRContext.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "flang-array-value-copy"

using namespace fir;
using namespace mlir;

using OperationUseMapT = llvm::DenseMap<mlir::Operation *, mlir::Operation *>;

namespace {

/// Array copy analysis.
/// Perform an interference analysis between array values.
///
/// Lowering will generate a sequence of the following form.
/// ```mlir
/// %a_1 = fir.array_load %array_1(%shape) : ...
/// ...
/// %a_j = fir.array_load %array_j(%shape) : ...
/// ...
/// %a_n = fir.array_load %array_n(%shape) : ...
/// ...
/// %v_i = fir.array_fetch %a_i, ...
/// %a_j1 = fir.array_update %a_j, ...
/// ...
/// fir.array_merge_store %a_j, %a_jn to %array_j : ...
/// ```
///
/// The analysis is to determine if there are any conflicts. A conflict is when
/// one of the following cases occurs.
///
/// 1. There is an `array_update` to an array value, a_j, such that a_j was
/// loaded from the same array memory reference (array_j) but with a different
/// shape than the other array values a_i, where i != j. [Possible overlapping
/// arrays.]
///
/// 2. There is either an array_fetch or array_update of a_j with a different
/// set of index values. [Possible loop-carried dependence.]
///
/// If none of the array values overlap in storage and the accesses are not
/// loop-carried, then the arrays are conflict-free and no copies are required.
class ArrayCopyAnalysis {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ArrayCopyAnalysis)

  using ConflictSetT = llvm::SmallPtrSet<mlir::Operation *, 16>;
  using UseSetT = llvm::SmallPtrSet<mlir::OpOperand *, 8>;
  using LoadMapSetsT = llvm::DenseMap<mlir::Operation *, UseSetT>;
  using AmendAccessSetT = llvm::SmallPtrSet<mlir::Operation *, 4>;

  ArrayCopyAnalysis(mlir::Operation *op) : operation{op} { construct(op); }

  mlir::Operation *getOperation() const { return operation; }

  /// Return true iff the `array_merge_store` has potential conflicts.
  bool hasPotentialConflict(mlir::Operation *op) const {
    LLVM_DEBUG(llvm::dbgs()
               << "looking for a conflict on " << *op
               << " and the set has a total of " << conflicts.size() << '\n');
    return conflicts.contains(op);
  }

  /// Return the use map.
  /// The use map maps array access, amend, fetch and update operations back to
  /// the array load that is the original source of the array value.
  /// It maps an array_load to an array_merge_store, if and only if the loaded
  /// array value has pending modifications to be merged.
  const OperationUseMapT &getUseMap() const { return useMap; }

  /// Return the set of array_access ops directly associated with array_amend
  /// ops.
  bool inAmendAccessSet(mlir::Operation *op) const {
    return amendAccesses.count(op);
  }

  /// For ArrayLoad `load`, return the transitive set of all OpOperands.
  UseSetT getLoadUseSet(mlir::Operation *load) const {
    assert(loadMapSets.count(load) && "analysis missed an array load?");
    return loadMapSets.lookup(load);
  }

  void arrayMentions(llvm::SmallVectorImpl<mlir::Operation *> &mentions,
                     ArrayLoadOp load);

private:
  void construct(mlir::Operation *topLevelOp);

  mlir::Operation *operation; // operation that analysis ran upon
  ConflictSetT conflicts;     // set of conflicts (loads and merge stores)
  OperationUseMapT useMap;
  LoadMapSetsT loadMapSets;
  // Set of array_access ops associated with array_amend ops.
  AmendAccessSetT amendAccesses;
};
} // namespace

namespace {
/// Helper class to collect all array operations that produced an array value.
class ReachCollector {
public:
  ReachCollector(llvm::SmallVectorImpl<mlir::Operation *> &reach,
                 mlir::Region *loopRegion)
      : reach{reach}, loopRegion{loopRegion} {}

  void collectArrayMentionFrom(mlir::Operation *op, mlir::ValueRange range) {
    if (range.empty()) {
      collectArrayMentionFrom(op, mlir::Value{});
      return;
    }
    for (mlir::Value v : range)
      collectArrayMentionFrom(v);
  }

  // Collect all the array_access ops in `block`. This recursively looks into
  // blocks in ops with regions.
  // FIXME: This is temporarily relying on the array_amend appearing in a
  // do_loop Region. This phase ordering assumption can be eliminated by using
  // dominance information to find the array_access ops or by scanning the
  // transitive closure of the amending array_access's users and the defs that
  // reach them.
  void collectAccesses(llvm::SmallVector<ArrayAccessOp> &result,
                       mlir::Block *block) {
    for (auto &op : *block) {
      if (auto access = mlir::dyn_cast<ArrayAccessOp>(op)) {
        LLVM_DEBUG(llvm::dbgs() << "adding access: " << access << '\n');
        result.push_back(access);
        continue;
      }
      for (auto &region : op.getRegions())
        for (auto &bb : region.getBlocks())
          collectAccesses(result, &bb);
    }
  }

  void collectArrayMentionFrom(mlir::Operation *op, mlir::Value val) {
    // `val` is defined by an Op, process the defining Op.
    // If `val` is defined by a region containing Op, we want to drill down
    // and through that Op's region(s).
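    // For example, a value defined by a fir.do_loop result is traced through
    // the value(s) the loop forwards to that result, since that is where the
    // array value was actually produced.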
    LLVM_DEBUG(llvm::dbgs() << "popset: " << *op << '\n');
    auto popFn = [&](auto rop) {
      assert(val && "op must have a result value");
      auto resNum = val.cast<mlir::OpResult>().getResultNumber();
      llvm::SmallVector<mlir::Value> results;
      rop.resultToSourceOps(results, resNum);
      for (auto u : results)
        collectArrayMentionFrom(u);
    };
    if (auto rop = mlir::dyn_cast<DoLoopOp>(op)) {
      popFn(rop);
      return;
    }
    if (auto rop = mlir::dyn_cast<IterWhileOp>(op)) {
      popFn(rop);
      return;
    }
    if (auto rop = mlir::dyn_cast<fir::IfOp>(op)) {
      popFn(rop);
      return;
    }
    if (auto box = mlir::dyn_cast<EmboxOp>(op)) {
      for (auto *user : box.getMemref().getUsers())
        if (user != op)
          collectArrayMentionFrom(user, user->getResults());
      return;
    }
    if (auto mergeStore = mlir::dyn_cast<ArrayMergeStoreOp>(op)) {
      if (opIsInsideLoops(mergeStore))
        collectArrayMentionFrom(mergeStore.getSequence());
      return;
    }

    if (mlir::isa<AllocaOp, AllocMemOp>(op)) {
      // Look for any stores inside the loops, and collect an array operation
      // that produced the value being stored to it.
      for (auto *user : op->getUsers())
        if (auto store = mlir::dyn_cast<fir::StoreOp>(user))
          if (opIsInsideLoops(store))
            collectArrayMentionFrom(store.getValue());
      return;
    }

    // Scan the uses of amend's memref.
    if (auto amend = mlir::dyn_cast<ArrayAmendOp>(op)) {
      reach.push_back(op);
      llvm::SmallVector<ArrayAccessOp> accesses;
      collectAccesses(accesses, op->getBlock());
      for (auto access : accesses)
        collectArrayMentionFrom(access.getResult());
    }

    // Otherwise, Op does not contain a region so just chase its operands.
    if (mlir::isa<ArrayAccessOp, ArrayLoadOp, ArrayUpdateOp, ArrayModifyOp,
                  ArrayFetchOp>(op)) {
      LLVM_DEBUG(llvm::dbgs() << "add " << *op << " to reachable set\n");
      reach.push_back(op);
    }

    // Include all array_access ops using an array_load.
    if (auto arrLd = mlir::dyn_cast<ArrayLoadOp>(op))
      for (auto *user : arrLd.getResult().getUsers())
        if (mlir::isa<ArrayAccessOp>(user)) {
          LLVM_DEBUG(llvm::dbgs() << "add " << *user << " to reachable set\n");
          reach.push_back(user);
        }

    // Array modify assignment is performed on the result. So the analysis
    // must look at what is done with the result.
    if (mlir::isa<ArrayModifyOp>(op))
      for (auto *user : op->getResult(0).getUsers())
        followUsers(user);

    if (mlir::isa<fir::CallOp>(op)) {
      LLVM_DEBUG(llvm::dbgs() << "add " << *op << " to reachable set\n");
      reach.push_back(op);
    }

    for (auto u : op->getOperands())
      collectArrayMentionFrom(u);
  }

  void collectArrayMentionFrom(mlir::BlockArgument ba) {
    auto *parent = ba.getOwner()->getParentOp();
    // If inside an Op holding a region, the block argument corresponds to an
    // argument passed to the containing Op.
    auto popFn = [&](auto rop) {
      collectArrayMentionFrom(rop.blockArgToSourceOp(ba.getArgNumber()));
    };
    if (auto rop = mlir::dyn_cast<DoLoopOp>(parent)) {
      popFn(rop);
      return;
    }
    if (auto rop = mlir::dyn_cast<IterWhileOp>(parent)) {
      popFn(rop);
      return;
    }
    // Otherwise, a block argument is provided via the pred blocks.
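    // Walk each predecessor's terminator and chase the operand it forwards
    // for this block argument's position.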
    for (auto *pred : ba.getOwner()->getPredecessors()) {
      auto u = pred->getTerminator()->getOperand(ba.getArgNumber());
      collectArrayMentionFrom(u);
    }
  }

  // Recursively trace operands to find all array operations relating to the
  // values merged.
  void collectArrayMentionFrom(mlir::Value val) {
    if (!val || visited.contains(val))
      return;
    visited.insert(val);

    // Process a block argument.
    if (auto ba = val.dyn_cast<mlir::BlockArgument>()) {
      collectArrayMentionFrom(ba);
      return;
    }

    // Process an Op.
    if (auto *op = val.getDefiningOp()) {
      collectArrayMentionFrom(op, val);
      return;
    }

    emitFatalError(val.getLoc(), "unhandled value");
  }

  /// Return all ops that produce the array value that is stored into the
  /// `array_merge_store`.
  static void reachingValues(llvm::SmallVectorImpl<mlir::Operation *> &reach,
                             mlir::Value seq) {
    reach.clear();
    mlir::Region *loopRegion = nullptr;
    if (auto doLoop = mlir::dyn_cast_or_null<DoLoopOp>(seq.getDefiningOp()))
      loopRegion = &doLoop->getRegion(0);
    ReachCollector collector(reach, loopRegion);
    collector.collectArrayMentionFrom(seq);
  }

private:
  /// Is \p op inside the loop nest region?
  /// FIXME: replace this structural dependence with graph properties.
  bool opIsInsideLoops(mlir::Operation *op) const {
    auto *region = op->getParentRegion();
    while (region) {
      if (region == loopRegion)
        return true;
      region = region->getParentRegion();
    }
    return false;
  }

  /// Recursively trace the uses of an operation's results, calling
  /// collectArrayMentionFrom on the direct and indirect user operands.
  void followUsers(mlir::Operation *op) {
    for (auto userOperand : op->getOperands())
      collectArrayMentionFrom(userOperand);
    // Go through potential converts/coordinate_op.
    for (auto indirectUser : op->getUsers())
      followUsers(indirectUser);
  }

  llvm::SmallVectorImpl<mlir::Operation *> &reach;
  llvm::SmallPtrSet<mlir::Value, 16> visited;
  /// Region of the loop nest that produced the array value.
  mlir::Region *loopRegion;
};
} // namespace

/// Find all the array operations that access the array value that is loaded by
/// the array load operation, `load`.
void ArrayCopyAnalysis::arrayMentions(
    llvm::SmallVectorImpl<mlir::Operation *> &mentions, ArrayLoadOp load) {
  mentions.clear();
  auto lmIter = loadMapSets.find(load);
  if (lmIter != loadMapSets.end()) {
    for (auto *opnd : lmIter->second) {
      auto *owner = opnd->getOwner();
      if (mlir::isa<ArrayAccessOp, ArrayAmendOp, ArrayFetchOp, ArrayUpdateOp,
                    ArrayModifyOp>(owner))
        mentions.push_back(owner);
    }
    return;
  }

  UseSetT visited;
  llvm::SmallVector<mlir::OpOperand *> queue; // uses of ArrayLoad[orig]

  auto appendToQueue = [&](mlir::Value val) {
    for (auto &use : val.getUses())
      if (!visited.count(&use)) {
        visited.insert(&use);
        queue.push_back(&use);
      }
  };

  // Build the set of uses of `original`.
  // let USES = { uses of original fir.load }
  appendToQueue(load);

  // Process the worklist until done.
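  // Each queue entry is an OpOperand that (directly or transitively) uses the
  // loaded array value; classify its owner and enqueue any derived uses.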
  while (!queue.empty()) {
    mlir::OpOperand *operand = queue.pop_back_val();
    mlir::Operation *owner = operand->getOwner();
    if (!owner)
      continue;
    auto structuredLoop = [&](auto ro) {
      if (auto blockArg = ro.iterArgToBlockArg(operand->get())) {
        int64_t arg = blockArg.getArgNumber();
        mlir::Value output = ro.getResult(ro.getFinalValue() ? arg : arg - 1);
        appendToQueue(output);
        appendToQueue(blockArg);
      }
    };
    // TODO: this needs to be updated to use the control-flow interface.
    auto branchOp = [&](mlir::Block *dest, OperandRange operands) {
      if (operands.empty())
        return;

      // Check if this operand is within the range.
      unsigned operandIndex = operand->getOperandNumber();
      unsigned operandsStart = operands.getBeginOperandIndex();
      if (operandIndex < operandsStart ||
          operandIndex >= (operandsStart + operands.size()))
        return;

      // Index the successor.
      unsigned argIndex = operandIndex - operandsStart;
      appendToQueue(dest->getArgument(argIndex));
    };
    // Thread uses into structured loop bodies and return value uses.
    if (auto ro = mlir::dyn_cast<DoLoopOp>(owner)) {
      structuredLoop(ro);
    } else if (auto ro = mlir::dyn_cast<IterWhileOp>(owner)) {
      structuredLoop(ro);
    } else if (auto rs = mlir::dyn_cast<ResultOp>(owner)) {
      // Thread any uses of fir.if that return the marked array value.
      mlir::Operation *parent = rs->getParentRegion()->getParentOp();
      if (auto ifOp = mlir::dyn_cast<fir::IfOp>(parent))
        appendToQueue(ifOp.getResult(operand->getOperandNumber()));
    } else if (mlir::isa<ArrayFetchOp>(owner)) {
      // Keep track of array value fetches.
      LLVM_DEBUG(llvm::dbgs()
                 << "add fetch {" << *owner << "} to array value set\n");
      mentions.push_back(owner);
    } else if (auto update = mlir::dyn_cast<ArrayUpdateOp>(owner)) {
      // Keep track of array value updates and thread the return value uses.
      LLVM_DEBUG(llvm::dbgs()
                 << "add update {" << *owner << "} to array value set\n");
      mentions.push_back(owner);
      appendToQueue(update.getResult());
    } else if (auto update = mlir::dyn_cast<ArrayModifyOp>(owner)) {
      // Keep track of array value modifications and thread the return value
      // uses.
      LLVM_DEBUG(llvm::dbgs()
                 << "add modify {" << *owner << "} to array value set\n");
      mentions.push_back(owner);
      appendToQueue(update.getResult(1));
    } else if (auto mention = mlir::dyn_cast<ArrayAccessOp>(owner)) {
      mentions.push_back(owner);
    } else if (auto amend = mlir::dyn_cast<ArrayAmendOp>(owner)) {
      mentions.push_back(owner);
      appendToQueue(amend.getResult());
    } else if (auto br = mlir::dyn_cast<mlir::cf::BranchOp>(owner)) {
      branchOp(br.getDest(), br.getDestOperands());
    } else if (auto br = mlir::dyn_cast<mlir::cf::CondBranchOp>(owner)) {
      branchOp(br.getTrueDest(), br.getTrueOperands());
      branchOp(br.getFalseDest(), br.getFalseOperands());
    } else if (mlir::isa<ArrayMergeStoreOp>(owner)) {
      // Do nothing.
    } else {
      llvm::report_fatal_error("array value reached unexpected op");
    }
  }
  loadMapSets.insert({load, visited});
}

static bool hasPointerType(mlir::Type type) {
  if (auto boxTy = type.dyn_cast<BoxType>())
    type = boxTy.getEleTy();
  return type.isa<fir::PointerType>();
}

// This is a performance hack. It makes a simple test that the slices of the
// load, \p ld, and the merge store, \p st, are trivially mutually exclusive.
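// For example (illustrative Fortran), in `a(n+1:m) = a(1:n)` the store slice
// lower bound is the SSA value `n` plus the constant 1, while the load slice
// upper bound is `n` itself, so the two ranges provably do not overlap and no
// temporary copy is required.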
static bool mutuallyExclusiveSliceRange(ArrayLoadOp ld, ArrayMergeStoreOp st) {
  // If the same array_load, then no further testing is warranted.
  if (ld.getResult() == st.getOriginal())
    return false;

  auto getSliceOp = [](mlir::Value val) -> SliceOp {
    if (!val)
      return {};
    auto sliceOp = mlir::dyn_cast_or_null<SliceOp>(val.getDefiningOp());
    if (!sliceOp)
      return {};
    return sliceOp;
  };

  auto ldSlice = getSliceOp(ld.getSlice());
  auto stSlice = getSliceOp(st.getSlice());
  if (!ldSlice || !stSlice)
    return false;

  // Resign on subobject slices.
  if (!ldSlice.getFields().empty() || !stSlice.getFields().empty() ||
      !ldSlice.getSubstr().empty() || !stSlice.getSubstr().empty())
    return false;

  // Crudely test that the two slices do not overlap by looking for the
  // following general condition. If the slices look like (i:j) and (j+1:k),
  // then these ranges do not overlap. The addend must be a constant.
  auto ldTriples = ldSlice.getTriples();
  auto stTriples = stSlice.getTriples();
  const auto size = ldTriples.size();
  if (size != stTriples.size())
    return false;

  auto displacedByConstant = [](mlir::Value v1, mlir::Value v2) {
    auto removeConvert = [](mlir::Value v) -> mlir::Operation * {
      auto *op = v.getDefiningOp();
      while (auto conv = mlir::dyn_cast_or_null<ConvertOp>(op))
        op = conv.getValue().getDefiningOp();
      return op;
    };

    auto isPositiveConstant = [](mlir::Value v) -> bool {
      if (auto conOp =
              mlir::dyn_cast<mlir::arith::ConstantOp>(v.getDefiningOp()))
        if (auto iattr = conOp.getValue().dyn_cast<mlir::IntegerAttr>())
          return iattr.getInt() > 0;
      return false;
    };

    auto *op1 = removeConvert(v1);
    auto *op2 = removeConvert(v2);
    if (!op1 || !op2)
      return false;
    if (auto addi = mlir::dyn_cast<mlir::arith::AddIOp>(op2))
      if ((addi.getLhs().getDefiningOp() == op1 &&
           isPositiveConstant(addi.getRhs())) ||
          (addi.getRhs().getDefiningOp() == op1 &&
           isPositiveConstant(addi.getLhs())))
        return true;
    if (auto subi = mlir::dyn_cast<mlir::arith::SubIOp>(op1))
      if (subi.getLhs().getDefiningOp() == op2 &&
          isPositiveConstant(subi.getRhs()))
        return true;
    return false;
  };

  for (std::remove_const_t<decltype(size)> i = 0; i < size; i += 3) {
    // If both are loop invariant, skip to the next triple.
    if (mlir::isa_and_nonnull<fir::UndefOp>(ldTriples[i + 1].getDefiningOp()) &&
        mlir::isa_and_nonnull<fir::UndefOp>(stTriples[i + 1].getDefiningOp())) {
      // Unless either is a vector index, then be conservative.
      if (mlir::isa_and_nonnull<fir::UndefOp>(ldTriples[i].getDefiningOp()) ||
          mlir::isa_and_nonnull<fir::UndefOp>(stTriples[i].getDefiningOp()))
        return false;
      continue;
    }
    // If identical, skip to the next triple.
    if (ldTriples[i] == stTriples[i] && ldTriples[i + 1] == stTriples[i + 1] &&
        ldTriples[i + 2] == stTriples[i + 2])
      continue;
    // If ubound and lbound are the same with a constant offset, skip to the
    // next triple.
    if (displacedByConstant(ldTriples[i + 1], stTriples[i]) ||
        displacedByConstant(stTriples[i + 1], ldTriples[i]))
      continue;
    return false;
  }
  LLVM_DEBUG(llvm::dbgs() << "detected non-overlapping slice ranges on " << ld
                          << " and " << st << ", which is not a conflict\n");
  return true;
}

/// Is there a conflict between the array value that was updated and is to be
/// stored to `st` and the set of arrays loaded (`reach`) and used to compute
/// the updated value?
static bool conflictOnLoad(llvm::ArrayRef<mlir::Operation *> reach,
                           ArrayMergeStoreOp st) {
  mlir::Value load;
  mlir::Value addr = st.getMemref();
  const bool storeHasPointerType = hasPointerType(addr.getType());
  for (auto *op : reach)
    if (auto ld = mlir::dyn_cast<ArrayLoadOp>(op)) {
      mlir::Type ldTy = ld.getMemref().getType();
      if (ld.getMemref() == addr) {
        if (mutuallyExclusiveSliceRange(ld, st))
          continue;
        if (ld.getResult() != st.getOriginal())
          return true;
        if (load) {
          // TODO: extend this to allow checking if the first `load` and this
          // `ld` are mutually exclusive accesses but not identical.
          return true;
        }
        load = ld;
      } else if ((hasPointerType(ldTy) || storeHasPointerType)) {
        // TODO: Use the target attribute to restrict this case further.
        // TODO: Check if types can also allow ruling out some cases. For now,
        // the fact that equivalences are using the pointer attribute to
        // enforce aliasing is preventing any attempt to do so. In general, it
        // may be wrong to use this if any of the types is a complex or a
        // derived type for which it is possible to create a pointer to a part
        // with a different type than the whole, although this deserves more
        // investigation because existing compiler behavior seems to diverge
        // here.
        return true;
      }
    }
  return false;
}

/// Is there an access vector conflict on the array being merged into? If the
/// access vectors diverge, then assume that there are potentially overlapping
/// loop-carried references.
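/// For example (illustrative Fortran), in `forall (i=1:n) a(i) = a(i+1)` the
/// fetch uses index vector (i+1) while the update uses (i); the divergent
/// access vectors are treated as a potential loop-carried dependence.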
static bool conflictOnMerge(llvm::ArrayRef<mlir::Operation *> mentions) {
  if (mentions.size() < 2)
    return false;
  llvm::SmallVector<mlir::Value> indices;
  LLVM_DEBUG(llvm::dbgs() << "check merge conflict with " << mentions.size()
                          << " mentions on the list\n");
  bool valSeen = false;
  bool refSeen = false;
  for (auto *op : mentions) {
    llvm::SmallVector<mlir::Value> compareVector;
    if (auto u = mlir::dyn_cast<ArrayUpdateOp>(op)) {
      valSeen = true;
      if (indices.empty()) {
        indices = u.getIndices();
        continue;
      }
      compareVector = u.getIndices();
    } else if (auto f = mlir::dyn_cast<ArrayModifyOp>(op)) {
      valSeen = true;
      if (indices.empty()) {
        indices = f.getIndices();
        continue;
      }
      compareVector = f.getIndices();
    } else if (auto f = mlir::dyn_cast<ArrayFetchOp>(op)) {
      valSeen = true;
      if (indices.empty()) {
        indices = f.getIndices();
        continue;
      }
      compareVector = f.getIndices();
    } else if (auto f = mlir::dyn_cast<ArrayAccessOp>(op)) {
      refSeen = true;
      if (indices.empty()) {
        indices = f.getIndices();
        continue;
      }
      compareVector = f.getIndices();
    } else if (mlir::isa<ArrayAmendOp>(op)) {
      refSeen = true;
      continue;
    } else {
      mlir::emitError(op->getLoc(), "unexpected operation in analysis");
    }
    if (compareVector.size() != indices.size() ||
        llvm::any_of(llvm::zip(compareVector, indices), [&](auto pair) {
          return std::get<0>(pair) != std::get<1>(pair);
        }))
      return true;
    LLVM_DEBUG(llvm::dbgs() << "vectors compare equal\n");
  }
  return valSeen && refSeen;
}

/// With element-by-reference semantics, an amended array value with more than
/// one access to the same loaded array is conservatively considered a
/// conflict.
/// Note: the array copy can still be eliminated in subsequent optimizations.
static bool conflictOnReference(llvm::ArrayRef<mlir::Operation *> mentions) {
  LLVM_DEBUG(llvm::dbgs() << "checking reference semantics " << mentions.size()
                          << '\n');
  if (mentions.size() < 3)
    return false;
  unsigned amendCount = 0;
  unsigned accessCount = 0;
  for (auto *op : mentions) {
    if (mlir::isa<ArrayAmendOp>(op) && ++amendCount > 1) {
      LLVM_DEBUG(llvm::dbgs() << "conflict: multiple amends of array value\n");
      return true;
    }
    if (mlir::isa<ArrayAccessOp>(op) && ++accessCount > 1) {
      LLVM_DEBUG(llvm::dbgs()
                 << "conflict: multiple accesses of array value\n");
      return true;
    }
    if (mlir::isa<ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp>(op)) {
      LLVM_DEBUG(llvm::dbgs()
                 << "conflict: array value has both uses by-value and uses "
                    "by-reference. conservative assumption.\n");
      return true;
    }
  }
  return false;
}

static mlir::Operation *
amendingAccess(llvm::ArrayRef<mlir::Operation *> mentions) {
  for (auto *op : mentions)
    if (auto amend = mlir::dyn_cast<ArrayAmendOp>(op))
      return amend.getMemref().getDefiningOp();
  return {};
}

// Are any conflicts present? The conflicts detected here are described above.
static bool conflictDetected(llvm::ArrayRef<mlir::Operation *> reach,
                             llvm::ArrayRef<mlir::Operation *> mentions,
                             ArrayMergeStoreOp st) {
  return conflictOnLoad(reach, st) || conflictOnMerge(mentions);
}

// Assume that any call to a function that uses host-associations will be
// modifying the output array.
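// For example, an internal procedure can reference the host's arrays directly
// through host association, so a call to such a procedure while an array is in
// value form may modify the underlying storage.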
static bool
conservativeCallConflict(llvm::ArrayRef<mlir::Operation *> reaches) {
  return llvm::any_of(reaches, [](mlir::Operation *op) {
    if (auto call = mlir::dyn_cast<fir::CallOp>(op))
      if (auto callee =
              call.getCallableForCallee().dyn_cast<mlir::SymbolRefAttr>()) {
        auto module = op->getParentOfType<mlir::ModuleOp>();
        return hasHostAssociationArgument(
            module.lookupSymbol<mlir::func::FuncOp>(callee));
      }
    return false;
  });
}

/// Constructor of the array copy analysis.
/// This performs the analysis and saves the intermediate results.
void ArrayCopyAnalysis::construct(mlir::Operation *topLevelOp) {
  topLevelOp->walk([&](Operation *op) {
    if (auto st = mlir::dyn_cast<fir::ArrayMergeStoreOp>(op)) {
      llvm::SmallVector<mlir::Operation *> values;
      ReachCollector::reachingValues(values, st.getSequence());
      bool callConflict = conservativeCallConflict(values);
      llvm::SmallVector<mlir::Operation *> mentions;
      arrayMentions(mentions,
                    mlir::cast<ArrayLoadOp>(st.getOriginal().getDefiningOp()));
      bool conflict = conflictDetected(values, mentions, st);
      bool refConflict = conflictOnReference(mentions);
      if (callConflict || conflict || refConflict) {
        LLVM_DEBUG(llvm::dbgs()
                   << "CONFLICT: copies required for " << st << '\n'
                   << " adding conflicts on: " << *op << " and "
                   << st.getOriginal() << '\n');
        conflicts.insert(op);
        conflicts.insert(st.getOriginal().getDefiningOp());
        if (auto *access = amendingAccess(mentions))
          amendAccesses.insert(access);
      }
      auto *ld = st.getOriginal().getDefiningOp();
      LLVM_DEBUG(llvm::dbgs()
                 << "map: adding {" << *ld << " -> " << st << "}\n");
      useMap.insert({ld, op});
    } else if (auto load = mlir::dyn_cast<ArrayLoadOp>(op)) {
      llvm::SmallVector<mlir::Operation *> mentions;
      arrayMentions(mentions, load);
      LLVM_DEBUG(llvm::dbgs() << "process load: " << load
                              << ", mentions: " << mentions.size() << '\n');
      for (auto *acc : mentions) {
        LLVM_DEBUG(llvm::dbgs() << " mention: " << *acc << '\n');
        if (mlir::isa<ArrayAccessOp, ArrayAmendOp, ArrayFetchOp, ArrayUpdateOp,
                      ArrayModifyOp>(acc)) {
          if (useMap.count(acc)) {
            mlir::emitError(
                load.getLoc(),
                "The parallel semantics of multiple array_merge_stores per "
                "array_load are not supported.");
            continue;
          }
          LLVM_DEBUG(llvm::dbgs()
                     << "map: adding {" << *acc << "} -> {" << load << "}\n");
          useMap.insert({acc, op});
        }
      }
    }
  });
}

//===----------------------------------------------------------------------===//
// Conversions for converting out of array value form.
//===----------------------------------------------------------------------===//

namespace {
class ArrayLoadConversion : public mlir::OpRewritePattern<ArrayLoadOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(ArrayLoadOp load,
                  mlir::PatternRewriter &rewriter) const override {
    LLVM_DEBUG(llvm::dbgs() << "replace load " << load << " with undef.\n");
    rewriter.replaceOpWithNewOp<UndefOp>(load, load.getType());
    return mlir::success();
  }
};

class ArrayMergeStoreConversion
    : public mlir::OpRewritePattern<ArrayMergeStoreOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(ArrayMergeStoreOp store,
                  mlir::PatternRewriter &rewriter) const override {
    LLVM_DEBUG(llvm::dbgs() << "marking store " << store << " as dead.\n");
    rewriter.eraseOp(store);
    return mlir::success();
  }
};
} // namespace

static mlir::Type getEleTy(mlir::Type ty) {
  auto eleTy = unwrapSequenceType(unwrapPassByRefType(ty));
  // FIXME: keep ptr/heap/ref information.
  return ReferenceType::get(eleTy);
}

// Extract extents from the ShapeOp/ShapeShiftOp into the result vector.
static bool getAdjustedExtents(mlir::Location loc,
                               mlir::PatternRewriter &rewriter,
                               ArrayLoadOp arrLoad,
                               llvm::SmallVectorImpl<mlir::Value> &result,
                               mlir::Value shape) {
  bool copyUsingSlice = false;
  auto *shapeOp = shape.getDefiningOp();
  if (auto s = mlir::dyn_cast_or_null<ShapeOp>(shapeOp)) {
    auto e = s.getExtents();
    result.insert(result.end(), e.begin(), e.end());
  } else if (auto s = mlir::dyn_cast_or_null<ShapeShiftOp>(shapeOp)) {
    auto e = s.getExtents();
    result.insert(result.end(), e.begin(), e.end());
  } else {
    emitFatalError(loc, "not a fir.shape/fir.shape_shift op");
  }
  auto idxTy = rewriter.getIndexType();
  if (factory::isAssumedSize(result)) {
    // Use slice information to compute the extent of the column.
    auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1);
    mlir::Value size = one;
    if (mlir::Value sliceArg = arrLoad.getSlice()) {
      if (auto sliceOp =
              mlir::dyn_cast_or_null<SliceOp>(sliceArg.getDefiningOp())) {
        auto triples = sliceOp.getTriples();
        const std::size_t tripleSize = triples.size();
        auto module = arrLoad->getParentOfType<mlir::ModuleOp>();
        fir::KindMapping kindMap = getKindMapping(module);
        FirOpBuilder builder(rewriter, kindMap);
        size = builder.genExtentFromTriplet(loc, triples[tripleSize - 3],
                                            triples[tripleSize - 2],
                                            triples[tripleSize - 1], idxTy);
        copyUsingSlice = true;
      }
    }
    result[result.size() - 1] = size;
  }
  return copyUsingSlice;
}

/// Place the extents of the array load, \p arrLoad, into \p result and
/// return a ShapeOp or ShapeShiftOp with the same extents. If \p arrLoad is
/// loading a `!fir.box`, code will be generated to read the extents from the
/// boxed value, and the returned shape Op will be built with the extents read
/// from the box. Otherwise, the extents will be extracted from the ShapeOp (or
/// ShapeShiftOp) argument of \p arrLoad. \p copyUsingSlice will be set to true
/// if slicing of the output array is to be done in the copy-in/copy-out rather
/// than in the elemental computation step.
static mlir::Value getOrReadExtentsAndShapeOp(
    mlir::Location loc, mlir::PatternRewriter &rewriter, ArrayLoadOp arrLoad,
    llvm::SmallVectorImpl<mlir::Value> &result, bool &copyUsingSlice) {
  assert(result.empty());
  if (arrLoad->hasAttr(fir::getOptionalAttrName()))
    fir::emitFatalError(
        loc, "shapes from array load of OPTIONAL arrays must not be used");
  if (auto boxTy = arrLoad.getMemref().getType().dyn_cast<BoxType>()) {
    auto rank =
        dyn_cast_ptrOrBoxEleTy(boxTy).cast<SequenceType>().getDimension();
    auto idxTy = rewriter.getIndexType();
    for (decltype(rank) dim = 0; dim < rank; ++dim) {
      auto dimVal = rewriter.create<mlir::arith::ConstantIndexOp>(loc, dim);
      auto dimInfo = rewriter.create<BoxDimsOp>(loc, idxTy, idxTy, idxTy,
                                                arrLoad.getMemref(), dimVal);
      result.emplace_back(dimInfo.getResult(1));
    }
    if (!arrLoad.getShape()) {
      auto shapeType = ShapeType::get(rewriter.getContext(), rank);
      return rewriter.create<ShapeOp>(loc, shapeType, result);
    }
    auto shiftOp = arrLoad.getShape().getDefiningOp<ShiftOp>();
    auto shapeShiftType = ShapeShiftType::get(rewriter.getContext(), rank);
    llvm::SmallVector<mlir::Value> shapeShiftOperands;
    for (auto [lb, extent] : llvm::zip(shiftOp.getOrigins(), result)) {
      shapeShiftOperands.push_back(lb);
      shapeShiftOperands.push_back(extent);
    }
    return rewriter.create<ShapeShiftOp>(loc, shapeShiftType,
                                         shapeShiftOperands);
  }
  copyUsingSlice =
      getAdjustedExtents(loc, rewriter, arrLoad, result, arrLoad.getShape());
  return arrLoad.getShape();
}

static mlir::Type toRefType(mlir::Type ty) {
  if (fir::isa_ref_type(ty))
    return ty;
  return fir::ReferenceType::get(ty);
}

static llvm::SmallVector<mlir::Value>
getTypeParamsIfRawData(mlir::Location loc, FirOpBuilder &builder,
                       ArrayLoadOp arrLoad, mlir::Type ty) {
  if (ty.isa<BoxType>())
    return {};
  return fir::factory::getTypeParams(loc, builder, arrLoad);
}

static mlir::Value genCoorOp(mlir::PatternRewriter &rewriter,
                             mlir::Location loc, mlir::Type eleTy,
                             mlir::Type resTy, mlir::Value alloc,
                             mlir::Value shape, mlir::Value slice,
                             mlir::ValueRange indices, ArrayLoadOp load,
                             bool skipOrig = false) {
  llvm::SmallVector<mlir::Value> originated;
  if (skipOrig)
    originated.assign(indices.begin(), indices.end());
  else
    originated = factory::originateIndices(loc, rewriter, alloc.getType(),
                                           shape, indices);
  auto seqTy = dyn_cast_ptrOrBoxEleTy(alloc.getType());
  assert(seqTy && seqTy.isa<SequenceType>());
  const auto dimension = seqTy.cast<SequenceType>().getDimension();
  auto module = load->getParentOfType<mlir::ModuleOp>();
  fir::KindMapping kindMap = getKindMapping(module);
  FirOpBuilder builder(rewriter, kindMap);
  auto typeparams = getTypeParamsIfRawData(loc, builder, load, alloc.getType());
  mlir::Value result = rewriter.create<ArrayCoorOp>(
      loc, eleTy, alloc, shape, slice,
      llvm::ArrayRef<mlir::Value>{originated}.take_front(dimension),
      typeparams);
  if (dimension < originated.size())
    result = rewriter.create<fir::CoordinateOp>(
        loc, resTy, result,
        llvm::ArrayRef<mlir::Value>{originated}.drop_front(dimension));
  return result;
}

static mlir::Value getCharacterLen(mlir::Location loc, FirOpBuilder &builder,
                                   ArrayLoadOp load, CharacterType charTy) {
  auto charLenTy = builder.getCharacterLengthType();
  if (charTy.hasDynamicLen()) {
    if (load.getMemref().getType().isa<BoxType>()) {
      // The loaded array is an emboxed value. Get the CHARACTER length from
      // the box value.
      auto eleSzInBytes =
          builder.create<BoxEleSizeOp>(loc, charLenTy, load.getMemref());
      auto kindSize =
          builder.getKindMap().getCharacterBitsize(charTy.getFKind());
      auto kindByteSize =
          builder.createIntegerConstant(loc, charLenTy, kindSize / 8);
      return builder.create<mlir::arith::DivSIOp>(loc, eleSzInBytes,
                                                  kindByteSize);
    }
    // The loaded array is a (set of) unboxed values. If the CHARACTER's
    // length is not a constant, it must be provided as a type parameter to
    // the array_load.
    auto typeparams = load.getTypeparams();
    assert(typeparams.size() > 0 && "expected type parameters on array_load");
    return typeparams.back();
  }
  // The typical case: the length of the CHARACTER is a compile-time
  // constant that is encoded in the type information.
  return builder.createIntegerConstant(loc, charLenTy, charTy.getLen());
}

/// Generate a shallow array copy. This is used for both copy-in and copy-out.
template <bool CopyIn>
void genArrayCopy(mlir::Location loc, mlir::PatternRewriter &rewriter,
                  mlir::Value dst, mlir::Value src, mlir::Value shapeOp,
                  mlir::Value sliceOp, ArrayLoadOp arrLoad) {
  auto insPt = rewriter.saveInsertionPoint();
  llvm::SmallVector<mlir::Value> indices;
  llvm::SmallVector<mlir::Value> extents;
  bool copyUsingSlice =
      getAdjustedExtents(loc, rewriter, arrLoad, extents, shapeOp);
  auto idxTy = rewriter.getIndexType();
  // Build loop nest from column to row.
  for (auto sh : llvm::reverse(extents)) {
    auto ubi = rewriter.create<ConvertOp>(loc, idxTy, sh);
    auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
    auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1);
    auto ub = rewriter.create<mlir::arith::SubIOp>(loc, idxTy, ubi, one);
    auto loop = rewriter.create<DoLoopOp>(loc, zero, ub, one);
    rewriter.setInsertionPointToStart(loop.getBody());
    indices.push_back(loop.getInductionVar());
  }
  // Reverse the indices so they are in column-major order.
  std::reverse(indices.begin(), indices.end());
  auto module = arrLoad->getParentOfType<mlir::ModuleOp>();
  fir::KindMapping kindMap = getKindMapping(module);
  FirOpBuilder builder(rewriter, kindMap);
  auto fromAddr = rewriter.create<ArrayCoorOp>(
      loc, getEleTy(src.getType()), src, shapeOp,
      CopyIn && copyUsingSlice ? sliceOp : mlir::Value{},
      factory::originateIndices(loc, rewriter, src.getType(), shapeOp, indices),
      getTypeParamsIfRawData(loc, builder, arrLoad, src.getType()));
  auto toAddr = rewriter.create<ArrayCoorOp>(
      loc, getEleTy(dst.getType()), dst, shapeOp,
      !CopyIn && copyUsingSlice ? sliceOp : mlir::Value{},
      factory::originateIndices(loc, rewriter, dst.getType(), shapeOp, indices),
      getTypeParamsIfRawData(loc, builder, arrLoad, dst.getType()));
  auto eleTy = unwrapSequenceType(unwrapPassByRefType(dst.getType()));
  // Copy from (to) object to (from) temp copy of same object.
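  // CHARACTER elements carry a length, so build character boxes and let the
  // scalar assignment helper copy `len` characters per element.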
  if (auto charTy = eleTy.dyn_cast<CharacterType>()) {
    auto len = getCharacterLen(loc, builder, arrLoad, charTy);
    CharBoxValue toChar(toAddr, len);
    CharBoxValue fromChar(fromAddr, len);
    factory::genScalarAssignment(builder, loc, toChar, fromChar);
  } else {
    if (hasDynamicSize(eleTy))
      TODO(loc, "copy element of dynamic size");
    factory::genScalarAssignment(builder, loc, toAddr, fromAddr);
  }
  rewriter.restoreInsertionPoint(insPt);
}

/// The array load may be either a boxed or unboxed value. If the value is
/// boxed, we read the type parameters from the boxed value.
static llvm::SmallVector<mlir::Value>
genArrayLoadTypeParameters(mlir::Location loc, mlir::PatternRewriter &rewriter,
                           ArrayLoadOp load) {
  if (load.getTypeparams().empty()) {
    auto eleTy =
        unwrapSequenceType(unwrapPassByRefType(load.getMemref().getType()));
    if (hasDynamicSize(eleTy)) {
      if (auto charTy = eleTy.dyn_cast<CharacterType>()) {
        assert(load.getMemref().getType().isa<BoxType>());
        auto module = load->getParentOfType<mlir::ModuleOp>();
        fir::KindMapping kindMap = getKindMapping(module);
        FirOpBuilder builder(rewriter, kindMap);
        return {getCharacterLen(loc, builder, load, charTy)};
      }
      TODO(loc, "unhandled dynamic type parameters");
    }
    return {};
  }
  return load.getTypeparams();
}

static llvm::SmallVector<mlir::Value>
findNonconstantExtents(mlir::Type memrefTy,
                       llvm::ArrayRef<mlir::Value> extents) {
  llvm::SmallVector<mlir::Value> nce;
  auto arrTy = unwrapPassByRefType(memrefTy);
  auto seqTy = arrTy.cast<SequenceType>();
  for (auto [s, x] : llvm::zip(seqTy.getShape(), extents))
    if (s == SequenceType::getUnknownExtent())
      nce.emplace_back(x);
  if (extents.size() > seqTy.getShape().size())
    for (auto x : extents.drop_front(seqTy.getShape().size()))
      nce.emplace_back(x);
  return nce;
}

/// Allocate temporary storage for an ArrayLoadOp \p load and initialize any
/// allocatable direct components of the array elements with an unallocated
/// status. Returns the temporary address as well as a callback to generate the
/// temporary clean-up once it has been used. The clean-up will take care of
/// deallocating all the element allocatable components that may have been
/// allocated while using the temporary.
static std::pair<mlir::Value,
                 std::function<void(mlir::PatternRewriter &rewriter)>>
allocateArrayTemp(mlir::Location loc, mlir::PatternRewriter &rewriter,
                  ArrayLoadOp load, llvm::ArrayRef<mlir::Value> extents,
                  mlir::Value shape) {
  mlir::Type baseType = load.getMemref().getType();
  llvm::SmallVector<mlir::Value> nonconstantExtents =
      findNonconstantExtents(baseType, extents);
  llvm::SmallVector<mlir::Value> typeParams =
      genArrayLoadTypeParameters(loc, rewriter, load);
  mlir::Value allocmem = rewriter.create<AllocMemOp>(
      loc, dyn_cast_ptrOrBoxEleTy(baseType), typeParams, nonconstantExtents);
  mlir::Type eleType =
      fir::unwrapSequenceType(fir::unwrapPassByRefType(baseType));
  if (fir::isRecordWithAllocatableMember(eleType)) {
    // The allocatable component descriptors need to be set to a clean
    // deallocated status before anything is done with them.
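    // Embox the temporary and let the runtime default-initialize it so that
    // every allocatable component starts in an unallocated state.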
    mlir::Value box = rewriter.create<fir::EmboxOp>(
        loc, fir::BoxType::get(baseType), allocmem, shape,
        /*slice=*/mlir::Value{}, typeParams);
    auto module = load->getParentOfType<mlir::ModuleOp>();
    fir::KindMapping kindMap = getKindMapping(module);
    FirOpBuilder builder(rewriter, kindMap);
    runtime::genDerivedTypeInitialize(builder, loc, box);
    // Any allocatable component that may have been allocated must be
    // deallocated during the clean-up.
    auto cleanup = [=](mlir::PatternRewriter &r) {
      fir::KindMapping kindMap = getKindMapping(module);
      FirOpBuilder builder(r, kindMap);
      runtime::genDerivedTypeDestroy(builder, loc, box);
      r.create<FreeMemOp>(loc, allocmem);
    };
    return {allocmem, cleanup};
  }
  auto cleanup = [=](mlir::PatternRewriter &r) {
    r.create<FreeMemOp>(loc, allocmem);
  };
  return {allocmem, cleanup};
}

namespace {
/// Conversion of fir.array_update and fir.array_modify Ops.
/// If there is a conflict for the update, then we need to perform a
/// copy-in/copy-out to preserve the original values of the array. If there is
/// no conflict, then it is safe to eschew making any copies.
template <typename ArrayOp>
class ArrayUpdateConversionBase : public mlir::OpRewritePattern<ArrayOp> {
public:
  // TODO: Implement copy/swap semantics?
  explicit ArrayUpdateConversionBase(mlir::MLIRContext *ctx,
                                     const ArrayCopyAnalysis &a,
                                     const OperationUseMapT &m)
      : mlir::OpRewritePattern<ArrayOp>{ctx}, analysis{a}, useMap{m} {}

  /// The array_access, \p access, is to be made to a cloned copy due to a
  /// potential conflict. Uses copy-in/copy-out semantics and not copy/swap.
  mlir::Value referenceToClone(mlir::Location loc,
                               mlir::PatternRewriter &rewriter,
                               ArrayOp access) const {
    LLVM_DEBUG(llvm::dbgs()
               << "generating copy-in/copy-out loops for " << access << '\n');
    auto *op = access.getOperation();
    auto *loadOp = useMap.lookup(op);
    auto load = mlir::cast<ArrayLoadOp>(loadOp);
    auto eleTy = access.getType();
    rewriter.setInsertionPoint(loadOp);
    // Copy in.
    llvm::SmallVector<mlir::Value> extents;
    bool copyUsingSlice = false;
    auto shapeOp = getOrReadExtentsAndShapeOp(loc, rewriter, load, extents,
                                              copyUsingSlice);
    auto [allocmem, genTempCleanUp] =
        allocateArrayTemp(loc, rewriter, load, extents, shapeOp);
    genArrayCopy</*copyIn=*/true>(load.getLoc(), rewriter, allocmem,
                                  load.getMemref(), shapeOp, load.getSlice(),
                                  load);
    // Generate the reference for the access.
    rewriter.setInsertionPoint(op);
    auto coor = genCoorOp(
        rewriter, loc, getEleTy(load.getType()), eleTy, allocmem, shapeOp,
        copyUsingSlice ? mlir::Value{} : load.getSlice(), access.getIndices(),
        load, access->hasAttr(factory::attrFortranArrayOffsets()));
    auto *storeOp = useMap.lookup(loadOp);
    auto store = mlir::cast<ArrayMergeStoreOp>(storeOp);
    rewriter.setInsertionPoint(storeOp);
    // Copy out.
    genArrayCopy</*copyIn=*/false>(store.getLoc(), rewriter, store.getMemref(),
                                   allocmem, shapeOp, store.getSlice(), load);
    genTempCleanUp(rewriter);
    return coor;
  }

  /// Copy the RHS element into the LHS and insert copy-in/copy-out between a
  /// temp and the LHS if the analysis found potential overlaps between the RHS
  /// and LHS arrays. The element copy generator must be provided in \p
  /// assignElement. \p update must be the ArrayUpdateOp or the ArrayModifyOp.
  /// Returns the address of the LHS element inside the loop and the LHS
  /// ArrayLoad result.
  std::pair<mlir::Value, mlir::Value>
  materializeAssignment(mlir::Location loc, mlir::PatternRewriter &rewriter,
                        ArrayOp update,
                        const std::function<void(mlir::Value)> &assignElement,
                        mlir::Type lhsEltRefType) const {
    auto *op = update.getOperation();
    auto *loadOp = useMap.lookup(op);
    auto load = mlir::cast<ArrayLoadOp>(loadOp);
    LLVM_DEBUG(llvm::outs() << "does " << load << " have a conflict?\n");
    if (analysis.hasPotentialConflict(loadOp)) {
      // If there is a conflict between the arrays, then we copy the lhs array
      // to a temporary, update the temporary, and copy the temporary back to
      // the lhs array. This yields Fortran's copy-in copy-out array semantics.
      LLVM_DEBUG(llvm::outs() << "Yes, conflict was found\n");
      rewriter.setInsertionPoint(loadOp);
      // Copy in.
      llvm::SmallVector<mlir::Value> extents;
      bool copyUsingSlice = false;
      auto shapeOp = getOrReadExtentsAndShapeOp(loc, rewriter, load, extents,
                                                copyUsingSlice);
      auto [allocmem, genTempCleanUp] =
          allocateArrayTemp(loc, rewriter, load, extents, shapeOp);

      genArrayCopy</*copyIn=*/true>(load.getLoc(), rewriter, allocmem,
                                    load.getMemref(), shapeOp, load.getSlice(),
                                    load);
      rewriter.setInsertionPoint(op);
      auto coor = genCoorOp(
          rewriter, loc, getEleTy(load.getType()), lhsEltRefType, allocmem,
          shapeOp, copyUsingSlice ? mlir::Value{} : load.getSlice(),
          update.getIndices(), load,
          update->hasAttr(factory::attrFortranArrayOffsets()));
      assignElement(coor);
      auto *storeOp = useMap.lookup(loadOp);
      auto store = mlir::cast<ArrayMergeStoreOp>(storeOp);
      rewriter.setInsertionPoint(storeOp);
      // Copy out.
      genArrayCopy</*copyIn=*/false>(store.getLoc(), rewriter,
                                     store.getMemref(), allocmem, shapeOp,
                                     store.getSlice(), load);
      genTempCleanUp(rewriter);
      return {coor, load.getResult()};
    }
    // Otherwise, when there is no conflict (no possible loop-carried
    // dependence), the lhs array can be updated in place.
    LLVM_DEBUG(llvm::outs() << "No, conflict wasn't found\n");
    rewriter.setInsertionPoint(op);
    auto coorTy = getEleTy(load.getType());
    auto coor =
        genCoorOp(rewriter, loc, coorTy, lhsEltRefType, load.getMemref(),
                  load.getShape(), load.getSlice(), update.getIndices(), load,
                  update->hasAttr(factory::attrFortranArrayOffsets()));
    assignElement(coor);
    return {coor, load.getResult()};
  }

protected:
  const ArrayCopyAnalysis &analysis;
  const OperationUseMapT &useMap;
};

class ArrayUpdateConversion : public ArrayUpdateConversionBase<ArrayUpdateOp> {
public:
  explicit ArrayUpdateConversion(mlir::MLIRContext *ctx,
                                 const ArrayCopyAnalysis &a,
                                 const OperationUseMapT &m)
      : ArrayUpdateConversionBase{ctx, a, m} {}

  mlir::LogicalResult
  matchAndRewrite(ArrayUpdateOp update,
                  mlir::PatternRewriter &rewriter) const override {
    auto loc = update.getLoc();
    auto assignElement = [&](mlir::Value coor) {
      auto input = update.getMerge();
      if (auto inEleTy = dyn_cast_ptrEleTy(input.getType())) {
        emitFatalError(loc, "array_update on references not supported");
      } else {
        rewriter.create<fir::StoreOp>(loc, input, coor);
      }
    };
    auto lhsEltRefType = toRefType(update.getMerge().getType());
    auto [_, lhsLoadResult] = materializeAssignment(
        loc, rewriter, update, assignElement, lhsEltRefType);
    update.replaceAllUsesWith(lhsLoadResult);
    rewriter.replaceOp(update, lhsLoadResult);
    return mlir::success();
  }
};

class ArrayModifyConversion : public ArrayUpdateConversionBase<ArrayModifyOp> {
public:
  explicit ArrayModifyConversion(mlir::MLIRContext *ctx,
                                 const ArrayCopyAnalysis &a,
                                 const OperationUseMapT &m)
      : ArrayUpdateConversionBase{ctx, a, m} {}

  mlir::LogicalResult
  matchAndRewrite(ArrayModifyOp modify,
                  mlir::PatternRewriter &rewriter) const override {
    auto loc = modify.getLoc();
    auto assignElement = [](mlir::Value) {
      // Assignment already materialized by lowering using lhs element address.
    };
    auto lhsEltRefType = modify.getResult(0).getType();
    auto [lhsEltCoor, lhsLoadResult] = materializeAssignment(
        loc, rewriter, modify, assignElement, lhsEltRefType);
    modify.replaceAllUsesWith(mlir::ValueRange{lhsEltCoor, lhsLoadResult});
    rewriter.replaceOp(modify, mlir::ValueRange{lhsEltCoor, lhsLoadResult});
    return mlir::success();
  }
};

class ArrayFetchConversion : public mlir::OpRewritePattern<ArrayFetchOp> {
public:
  explicit ArrayFetchConversion(mlir::MLIRContext *ctx,
                                const OperationUseMapT &m)
      : OpRewritePattern{ctx}, useMap{m} {}

  mlir::LogicalResult
  matchAndRewrite(ArrayFetchOp fetch,
                  mlir::PatternRewriter &rewriter) const override {
    auto *op = fetch.getOperation();
    rewriter.setInsertionPoint(op);
    auto load = mlir::cast<ArrayLoadOp>(useMap.lookup(op));
    auto loc = fetch.getLoc();
    auto coor = genCoorOp(
        rewriter, loc, getEleTy(load.getType()), toRefType(fetch.getType()),
        load.getMemref(), load.getShape(), load.getSlice(), fetch.getIndices(),
        load, fetch->hasAttr(factory::attrFortranArrayOffsets()));
    if (isa_ref_type(fetch.getType()))
      rewriter.replaceOp(fetch, coor);
    else
      rewriter.replaceOpWithNewOp<fir::LoadOp>(fetch, coor);
    return mlir::success();
  }

private:
  const OperationUseMapT &useMap;
};

/// An array_access op is like an array_fetch op, except that it does not imply
/// a load op. (It operates in the reference domain.)
class ArrayAccessConversion : public ArrayUpdateConversionBase<ArrayAccessOp> {
public:
  explicit ArrayAccessConversion(mlir::MLIRContext *ctx,
                                 const ArrayCopyAnalysis &a,
                                 const OperationUseMapT &m)
      : ArrayUpdateConversionBase{ctx, a, m} {}

  mlir::LogicalResult
  matchAndRewrite(ArrayAccessOp access,
                  mlir::PatternRewriter &rewriter) const override {
    auto *op = access.getOperation();
    auto loc = access.getLoc();
    if (analysis.inAmendAccessSet(op)) {
      // This array_access is associated with an array_amend and there is a
      // conflict. Make a copy to store into.
      auto result = referenceToClone(loc, rewriter, access);
      access.replaceAllUsesWith(result);
      rewriter.replaceOp(access, result);
      return mlir::success();
    }
    rewriter.setInsertionPoint(op);
    auto load = mlir::cast<ArrayLoadOp>(useMap.lookup(op));
    auto coor = genCoorOp(
        rewriter, loc, getEleTy(load.getType()), toRefType(access.getType()),
        load.getMemref(), load.getShape(), load.getSlice(), access.getIndices(),
        load, access->hasAttr(factory::attrFortranArrayOffsets()));
    rewriter.replaceOp(access, coor);
    return mlir::success();
  }
};

/// An array_amend op is a marker to record which array access is being used to
/// update an array value. After this pass runs, an array_amend has no
/// semantics. We rewrite these to undefined values here to remove them while
/// preserving SSA form.
class ArrayAmendConversion : public mlir::OpRewritePattern<ArrayAmendOp> {
public:
  explicit ArrayAmendConversion(mlir::MLIRContext *ctx)
      : OpRewritePattern{ctx} {}

  mlir::LogicalResult
  matchAndRewrite(ArrayAmendOp amend,
                  mlir::PatternRewriter &rewriter) const override {
    auto *op = amend.getOperation();
    rewriter.setInsertionPoint(op);
    auto loc = amend.getLoc();
    auto undef = rewriter.create<UndefOp>(loc, amend.getType());
    rewriter.replaceOp(amend, undef.getResult());
    return mlir::success();
  }
};

class ArrayValueCopyConverter
    : public ArrayValueCopyBase<ArrayValueCopyConverter> {
public:
  void runOnOperation() override {
    auto func = getOperation();
    LLVM_DEBUG(llvm::dbgs() << "\n\narray-value-copy pass on function '"
                            << func.getName() << "'\n");
    auto *context = &getContext();

    // Perform the conflict analysis.
    const auto &analysis = getAnalysis<ArrayCopyAnalysis>();
    const auto &useMap = analysis.getUseMap();

    mlir::RewritePatternSet patterns1(context);
    patterns1.insert<ArrayFetchConversion>(context, useMap);
    patterns1.insert<ArrayUpdateConversion>(context, analysis, useMap);
    patterns1.insert<ArrayModifyConversion>(context, analysis, useMap);
    patterns1.insert<ArrayAccessConversion>(context, analysis, useMap);
    patterns1.insert<ArrayAmendConversion>(context);
    mlir::ConversionTarget target(*context);
    target.addLegalDialect<FIROpsDialect, mlir::scf::SCFDialect,
                           mlir::arith::ArithmeticDialect,
                           mlir::func::FuncDialect>();
    target.addIllegalOp<ArrayAccessOp, ArrayAmendOp, ArrayFetchOp,
                        ArrayUpdateOp, ArrayModifyOp>();
    // Rewrite the array fetch and array update ops.
    if (mlir::failed(
            mlir::applyPartialConversion(func, target, std::move(patterns1)))) {
      mlir::emitError(mlir::UnknownLoc::get(context),
                      "failure in array-value-copy pass, phase 1");
      signalPassFailure();
    }

    mlir::RewritePatternSet patterns2(context);
    patterns2.insert<ArrayLoadConversion>(context);
    patterns2.insert<ArrayMergeStoreConversion>(context);
    target.addIllegalOp<ArrayLoadOp, ArrayMergeStoreOp>();
    if (mlir::failed(
            mlir::applyPartialConversion(func, target, std::move(patterns2)))) {
      mlir::emitError(mlir::UnknownLoc::get(context),
                      "failure in array-value-copy pass, phase 2");
      signalPassFailure();
    }
  }
};
} // namespace

std::unique_ptr<mlir::Pass> fir::createArrayValueCopyPass() {
  return std::make_unique<ArrayValueCopyConverter>();
}