//===-- ArrayValueCopy.cpp ------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "flang/Lower/Todo.h"
#include "flang/Optimizer/Builder/Array.h"
#include "flang/Optimizer/Builder/BoxValue.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/Factory.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
#include "flang/Optimizer/Support/FIRContext.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "flang-array-value-copy"

using namespace fir;
using namespace mlir;

/// Maps an array operation back to the operation that is the source of the
/// array value it operates on (see ArrayCopyAnalysis::getUseMap).
using OperationUseMapT = llvm::DenseMap<mlir::Operation *, mlir::Operation *>;

namespace {

/// Array copy analysis.
/// Perform an interference analysis between array values.
///
/// Lowering will generate a sequence of the following form.
/// ```mlir
/// %a_1 = fir.array_load %array_1(%shape) : ...
/// ...
/// %a_j = fir.array_load %array_j(%shape) : ...
/// ...
/// %a_n = fir.array_load %array_n(%shape) : ...
/// ...
/// %v_i = fir.array_fetch %a_i, ...
/// %a_j1 = fir.array_update %a_j, ...
/// ...
/// fir.array_merge_store %a_j, %a_jn to %array_j : ...
/// ```
///
/// The analysis is to determine if there are any conflicts. A conflict is when
/// one of the following cases occurs.
///
/// 1. There is an `array_update` to an array value, a_j, such that a_j was
/// loaded from the same array memory reference (array_j) but with a different
/// shape as the other array values a_i, where i != j. [Possible overlapping
/// arrays.]
///
/// 2. There is either an array_fetch or array_update of a_j with a different
/// set of index values. [Possible loop-carried dependence.]
///
/// If none of the array values overlap in storage and the accesses are not
/// loop-carried, then the arrays are conflict-free and no copies are required.
class ArrayCopyAnalysis {
public:
  using ConflictSetT = llvm::SmallPtrSet<mlir::Operation *, 16>;
  using UseSetT = llvm::SmallPtrSet<mlir::OpOperand *, 8>;
  using LoadMapSetsT = llvm::DenseMap<mlir::Operation *, UseSetT>;
  using AmendAccessSetT = llvm::SmallPtrSet<mlir::Operation *, 4>;

  /// The analysis runs eagerly at construction time over `op`.
  ArrayCopyAnalysis(mlir::Operation *op) : operation{op} { construct(op); }

  mlir::Operation *getOperation() const { return operation; }

  /// Return true iff the `array_merge_store` has potential conflicts.
  bool hasPotentialConflict(mlir::Operation *op) const {
    LLVM_DEBUG(llvm::dbgs()
               << "looking for a conflict on " << *op
               << " and the set has a total of " << conflicts.size() << '\n');
    return conflicts.contains(op);
  }

  /// Return the use map.
  /// The use map maps array access, amend, fetch and update operations back to
  /// the array load that is the original source of the array value.
  /// It maps an array_load to an array_merge_store, if and only if the loaded
  /// array value has pending modifications to be merged.
  const OperationUseMapT &getUseMap() const { return useMap; }

  /// Return the set of array_access ops directly associated with array_amend
  /// ops.
  bool inAmendAccessSet(mlir::Operation *op) const {
    return amendAccesses.count(op);
  }

  /// For ArrayLoad `load`, return the transitive set of all OpOperands.
  UseSetT getLoadUseSet(mlir::Operation *load) const {
    assert(loadMapSets.count(load) && "analysis missed an array load?");
    return loadMapSets.lookup(load);
  }

  /// Fill `mentions` with the access/amend/fetch/update/modify operations that
  /// mention the array value produced by `load`. Results are cached per load.
  void arrayMentions(llvm::SmallVectorImpl<mlir::Operation *> &mentions,
                     ArrayLoadOp load);

private:
  /// Walk `topLevelOp` and populate the conflict set, use map and caches.
  void construct(mlir::Operation *topLevelOp);

  mlir::Operation *operation; // operation that analysis ran upon
  ConflictSetT conflicts;     // set of conflicts (loads and merge stores)
  OperationUseMapT useMap;
  LoadMapSetsT loadMapSets;
  // Set of array_access ops associated with array_amend ops.
  AmendAccessSetT amendAccesses;
};
} // namespace

namespace {
/// Helper class to collect all array operations that produced an array value.
class ReachCollector {
public:
  ReachCollector(llvm::SmallVectorImpl<mlir::Operation *> &reach,
                 mlir::Region *loopRegion)
      : reach{reach}, loopRegion{loopRegion} {}

  /// Collect array mentions from every value in `range`. An empty range is
  /// forwarded as a null value so the defining op itself is still processed.
  void collectArrayMentionFrom(mlir::Operation *op, mlir::ValueRange range) {
    if (range.empty()) {
      collectArrayMentionFrom(op, mlir::Value{});
      return;
    }
    for (mlir::Value v : range)
      collectArrayMentionFrom(v);
  }

  // Collect all the array_access ops in `block`. This recursively looks into
  // blocks in ops with regions.
  // FIXME: This is temporarily relying on the array_amend appearing in a
  // do_loop Region. This phase ordering assumption can be eliminated by using
  // dominance information to find the array_access ops or by scanning the
  // transitive closure of the amending array_access's users and the defs that
  // reach them.
  /// Accumulate into `result` every fir.array_access found in `block`,
  /// recursing into any regions held by the ops in the block.
  void collectAccesses(llvm::SmallVector<ArrayAccessOp> &result,
                       mlir::Block *block) {
    for (auto &op : *block) {
      if (auto access = mlir::dyn_cast<ArrayAccessOp>(op)) {
        LLVM_DEBUG(llvm::dbgs() << "adding access: " << access << '\n');
        result.push_back(access);
        continue;
      }
      for (auto &region : op.getRegions())
        for (auto &bb : region.getBlocks())
          collectAccesses(result, &bb);
    }
  }

  /// Trace the defining op of `val` and record every array operation that
  /// can reach it; region-holding ops are drilled into.
  void collectArrayMentionFrom(mlir::Operation *op, mlir::Value val) {
    // `val` is defined by an Op, process the defining Op.
    // If `val` is defined by a region containing Op, we want to drill down
    // and through that Op's region(s).
    LLVM_DEBUG(llvm::dbgs() << "popset: " << *op << '\n');
    // For region ops with a result-to-source mapping, chase the values that
    // feed the result `val` corresponds to.
    auto popFn = [&](auto rop) {
      assert(val && "op must have a result value");
      auto resNum = val.cast<mlir::OpResult>().getResultNumber();
      llvm::SmallVector<mlir::Value> results;
      rop.resultToSourceOps(results, resNum);
      for (auto u : results)
        collectArrayMentionFrom(u);
    };
    if (auto rop = mlir::dyn_cast<DoLoopOp>(op)) {
      popFn(rop);
      return;
    }
    if (auto rop = mlir::dyn_cast<IterWhileOp>(op)) {
      popFn(rop);
      return;
    }
    if (auto rop = mlir::dyn_cast<fir::IfOp>(op)) {
      popFn(rop);
      return;
    }
    if (auto box = mlir::dyn_cast<EmboxOp>(op)) {
      // Follow the other users of the emboxed memref.
      for (auto *user : box.getMemref().getUsers())
        if (user != op)
          collectArrayMentionFrom(user, user->getResults());
      return;
    }
    if (auto mergeStore = mlir::dyn_cast<ArrayMergeStoreOp>(op)) {
      if (opIsInsideLoops(mergeStore))
        collectArrayMentionFrom(mergeStore.getSequence());
      return;
    }

    if (mlir::isa<AllocaOp, AllocMemOp>(op)) {
      // Look for any stores inside the loops, and collect an array operation
      // that produced the value being stored to it.
      for (auto *user : op->getUsers())
        if (auto store = mlir::dyn_cast<fir::StoreOp>(user))
          if (opIsInsideLoops(store))
            collectArrayMentionFrom(store.getValue());
      return;
    }

    // Scan the uses of amend's memref
    if (auto amend = mlir::dyn_cast<ArrayAmendOp>(op)) {
      reach.push_back(op);
      llvm::SmallVector<ArrayAccessOp> accesses;
      collectAccesses(accesses, op->getBlock());
      for (auto access : accesses)
        collectArrayMentionFrom(access.getResult());
    }

    // Otherwise, Op does not contain a region so just chase its operands.
    if (mlir::isa<ArrayAccessOp, ArrayLoadOp, ArrayUpdateOp, ArrayModifyOp,
                  ArrayFetchOp>(op)) {
      LLVM_DEBUG(llvm::dbgs() << "add " << *op << " to reachable set\n");
      reach.push_back(op);
    }

    // Include all array_access ops using an array_load.
    if (auto arrLd = mlir::dyn_cast<ArrayLoadOp>(op))
      for (auto *user : arrLd.getResult().getUsers())
        if (mlir::isa<ArrayAccessOp>(user)) {
          LLVM_DEBUG(llvm::dbgs() << "add " << *user << " to reachable set\n");
          reach.push_back(user);
        }

    // Array modify assignment is performed on the result. So the analysis must
    // look at the what is done with the result.
    if (mlir::isa<ArrayModifyOp>(op))
      for (auto *user : op->getResult(0).getUsers())
        followUsers(user);

    if (mlir::isa<fir::CallOp>(op)) {
      LLVM_DEBUG(llvm::dbgs() << "add " << *op << " to reachable set\n");
      reach.push_back(op);
    }

    for (auto u : op->getOperands())
      collectArrayMentionFrom(u);
  }

  /// Trace a block argument back to the value(s) that feed it: either through
  /// the containing loop op's iter-arg mapping, or the predecessor blocks'
  /// terminator operands.
  void collectArrayMentionFrom(mlir::BlockArgument ba) {
    auto *parent = ba.getOwner()->getParentOp();
    // If inside an Op holding a region, the block argument corresponds to an
    // argument passed to the containing Op.
    auto popFn = [&](auto rop) {
      collectArrayMentionFrom(rop.blockArgToSourceOp(ba.getArgNumber()));
    };
    if (auto rop = mlir::dyn_cast<DoLoopOp>(parent)) {
      popFn(rop);
      return;
    }
    if (auto rop = mlir::dyn_cast<IterWhileOp>(parent)) {
      popFn(rop);
      return;
    }
    // Otherwise, a block argument is provided via the pred blocks.
    for (auto *pred : ba.getOwner()->getPredecessors()) {
      auto u = pred->getTerminator()->getOperand(ba.getArgNumber());
      collectArrayMentionFrom(u);
    }
  }

  // Recursively trace operands to find all array operations relating to the
  // values merged.
  void collectArrayMentionFrom(mlir::Value val) {
    // `visited` guards against revisiting values (and thus against cycles).
    if (!val || visited.contains(val))
      return;
    visited.insert(val);

    // Process a block argument.
    if (auto ba = val.dyn_cast<mlir::BlockArgument>()) {
      collectArrayMentionFrom(ba);
      return;
    }

    // Process an Op.
    if (auto *op = val.getDefiningOp()) {
      collectArrayMentionFrom(op, val);
      return;
    }

    emitFatalError(val.getLoc(), "unhandled value");
  }

  /// Return all ops that produce the array value that is stored into the
  /// `array_merge_store`.
  static void reachingValues(llvm::SmallVectorImpl<mlir::Operation *> &reach,
                             mlir::Value seq) {
    reach.clear();
    mlir::Region *loopRegion = nullptr;
    if (auto doLoop = mlir::dyn_cast_or_null<DoLoopOp>(seq.getDefiningOp()))
      loopRegion = &doLoop->getRegion(0);
    ReachCollector collector(reach, loopRegion);
    collector.collectArrayMentionFrom(seq);
  }

private:
  /// Is \op inside the loop nest region ?
  /// FIXME: replace this structural dependence with graph properties.
  bool opIsInsideLoops(mlir::Operation *op) const {
    // Walk up the region tree looking for `loopRegion`.
    auto *region = op->getParentRegion();
    while (region) {
      if (region == loopRegion)
        return true;
      region = region->getParentRegion();
    }
    return false;
  }

  /// Recursively trace the use of an operation results, calling
  /// collectArrayMentionFrom on the direct and indirect user operands.
  void followUsers(mlir::Operation *op) {
    for (auto userOperand : op->getOperands())
      collectArrayMentionFrom(userOperand);
    // Go through potential converts/coordinate_op.
    for (auto indirectUser : op->getUsers())
      followUsers(indirectUser);
  }

  // Output set of reaching array operations (owned by the caller).
  llvm::SmallVectorImpl<mlir::Operation *> &reach;
  // Values already traced; prevents infinite recursion.
  llvm::SmallPtrSet<mlir::Value, 16> visited;
  /// Region of the loops nest that produced the array value.
  mlir::Region *loopRegion;
};
} // namespace

/// Find all the array operations that access the array value that is loaded by
/// the array load operation, `load`.
void ArrayCopyAnalysis::arrayMentions(
    llvm::SmallVectorImpl<mlir::Operation *> &mentions, ArrayLoadOp load) {
  mentions.clear();
  // Serve from the cache when this load has already been analyzed.
  auto lmIter = loadMapSets.find(load);
  if (lmIter != loadMapSets.end()) {
    for (auto *opnd : lmIter->second) {
      auto *owner = opnd->getOwner();
      if (mlir::isa<ArrayAccessOp, ArrayAmendOp, ArrayFetchOp, ArrayUpdateOp,
                    ArrayModifyOp>(owner))
        mentions.push_back(owner);
    }
    return;
  }

  UseSetT visited;
  llvm::SmallVector<mlir::OpOperand *> queue; // uses of ArrayLoad[orig]

  // Push every not-yet-visited use of `val` onto the worklist.
  auto appendToQueue = [&](mlir::Value val) {
    for (auto &use : val.getUses())
      if (!visited.count(&use)) {
        visited.insert(&use);
        queue.push_back(&use);
      }
  };

  // Build the set of uses of `original`.
  // let USES = { uses of original fir.load }
  appendToQueue(load);

  // Process the worklist until done.
  while (!queue.empty()) {
    mlir::OpOperand *operand = queue.pop_back_val();
    mlir::Operation *owner = operand->getOwner();
    if (!owner)
      continue;
    // For a structured loop, thread the use through the matching block
    // argument and the corresponding loop result.
    auto structuredLoop = [&](auto ro) {
      if (auto blockArg = ro.iterArgToBlockArg(operand->get())) {
        int64_t arg = blockArg.getArgNumber();
        mlir::Value output = ro.getResult(ro.getFinalValue() ? arg : arg - 1);
        appendToQueue(output);
        appendToQueue(blockArg);
      }
    };
    // TODO: this need to be updated to use the control-flow interface.
    auto branchOp = [&](mlir::Block *dest, OperandRange operands) {
      if (operands.empty())
        return;

      // Check if this operand is within the range.
      unsigned operandIndex = operand->getOperandNumber();
      unsigned operandsStart = operands.getBeginOperandIndex();
      if (operandIndex < operandsStart ||
          operandIndex >= (operandsStart + operands.size()))
        return;

      // Index the successor.
      unsigned argIndex = operandIndex - operandsStart;
      appendToQueue(dest->getArgument(argIndex));
    };
    // Thread uses into structured loop bodies and return value uses.
    if (auto ro = mlir::dyn_cast<DoLoopOp>(owner)) {
      structuredLoop(ro);
    } else if (auto ro = mlir::dyn_cast<IterWhileOp>(owner)) {
      structuredLoop(ro);
    } else if (auto rs = mlir::dyn_cast<ResultOp>(owner)) {
      // Thread any uses of fir.if that return the marked array value.
      mlir::Operation *parent = rs->getParentRegion()->getParentOp();
      if (auto ifOp = mlir::dyn_cast<fir::IfOp>(parent))
        appendToQueue(ifOp.getResult(operand->getOperandNumber()));
    } else if (mlir::isa<ArrayFetchOp>(owner)) {
      // Keep track of array value fetches.
      LLVM_DEBUG(llvm::dbgs()
                 << "add fetch {" << *owner << "} to array value set\n");
      mentions.push_back(owner);
    } else if (auto update = mlir::dyn_cast<ArrayUpdateOp>(owner)) {
      // Keep track of array value updates and thread the return value uses.
      LLVM_DEBUG(llvm::dbgs()
                 << "add update {" << *owner << "} to array value set\n");
      mentions.push_back(owner);
      appendToQueue(update.getResult());
    } else if (auto update = mlir::dyn_cast<ArrayModifyOp>(owner)) {
      // Keep track of array value modification and thread the return value
      // uses.
      LLVM_DEBUG(llvm::dbgs()
                 << "add modify {" << *owner << "} to array value set\n");
      mentions.push_back(owner);
      appendToQueue(update.getResult(1));
    } else if (auto mention = mlir::dyn_cast<ArrayAccessOp>(owner)) {
      mentions.push_back(owner);
    } else if (auto amend = mlir::dyn_cast<ArrayAmendOp>(owner)) {
      mentions.push_back(owner);
      appendToQueue(amend.getResult());
    } else if (auto br = mlir::dyn_cast<mlir::cf::BranchOp>(owner)) {
      branchOp(br.getDest(), br.getDestOperands());
    } else if (auto br = mlir::dyn_cast<mlir::cf::CondBranchOp>(owner)) {
      branchOp(br.getTrueDest(), br.getTrueOperands());
      branchOp(br.getFalseDest(), br.getFalseOperands());
    } else if (mlir::isa<ArrayMergeStoreOp>(owner)) {
      // do nothing
    } else {
      llvm::report_fatal_error("array value reached unexpected op");
    }
  }
  // Cache the full use set for subsequent queries on the same load.
  loadMapSets.insert({load, visited});
}

/// Return true iff `type` is a fir.ptr, possibly wrapped in a fir.box.
static bool hasPointerType(mlir::Type type) {
  if (auto boxTy = type.dyn_cast<BoxType>())
    type = boxTy.getEleTy();
  return type.isa<fir::PointerType>();
}

// This is a NF performance hack. It makes a simple test that the slices of the
// load, \p ld, and the merge store, \p st, are trivially mutually exclusive.
static bool mutuallyExclusiveSliceRange(ArrayLoadOp ld, ArrayMergeStoreOp st) {
  // If the same array_load, then no further testing is warranted.
  if (ld.getResult() == st.getOriginal())
    return false;

  auto getSliceOp = [](mlir::Value val) -> SliceOp {
    if (!val)
      return {};
    auto sliceOp = mlir::dyn_cast_or_null<SliceOp>(val.getDefiningOp());
    if (!sliceOp)
      return {};
    return sliceOp;
  };

  auto ldSlice = getSliceOp(ld.getSlice());
  auto stSlice = getSliceOp(st.getSlice());
  if (!ldSlice || !stSlice)
    return false;

  // Resign on subobject slices.
  if (!ldSlice.getFields().empty() || !stSlice.getFields().empty() ||
      !ldSlice.getSubstr().empty() || !stSlice.getSubstr().empty())
    return false;

  // Crudely test that the two slices do not overlap by looking for the
  // following general condition. If the slices look like (i:j) and (j+1:k) then
  // these ranges do not overlap. The addend must be a constant.
  auto ldTriples = ldSlice.getTriples();
  auto stTriples = stSlice.getTriples();
  const auto size = ldTriples.size();
  if (size != stTriples.size())
    return false;

  // Does v2 equal v1 plus a positive constant (directly or through a
  // fir.convert chain)?
  auto displacedByConstant = [](mlir::Value v1, mlir::Value v2) {
    auto removeConvert = [](mlir::Value v) -> mlir::Operation * {
      auto *op = v.getDefiningOp();
      while (auto conv = mlir::dyn_cast_or_null<ConvertOp>(op))
        op = conv.getValue().getDefiningOp();
      return op;
    };

    auto isPositiveConstant = [](mlir::Value v) -> bool {
      if (auto conOp =
              mlir::dyn_cast<mlir::arith::ConstantOp>(v.getDefiningOp()))
        if (auto iattr = conOp.getValue().dyn_cast<mlir::IntegerAttr>())
          return iattr.getInt() > 0;
      return false;
    };

    auto *op1 = removeConvert(v1);
    auto *op2 = removeConvert(v2);
    if (!op1 || !op2)
      return false;
    // v2 = v1 + c, c > 0.
    if (auto addi = mlir::dyn_cast<mlir::arith::AddIOp>(op2))
      if ((addi.getLhs().getDefiningOp() == op1 &&
           isPositiveConstant(addi.getRhs())) ||
          (addi.getRhs().getDefiningOp() == op1 &&
           isPositiveConstant(addi.getLhs())))
        return true;
    // v1 = v2 - c, c > 0.
    if (auto subi = mlir::dyn_cast<mlir::arith::SubIOp>(op1))
      if (subi.getLhs().getDefiningOp() == op2 &&
          isPositiveConstant(subi.getRhs()))
        return true;
    return false;
  };

  // Triples are stored flat as (lb, ub, step) per dimension; step by 3.
  for (std::remove_const_t<decltype(size)> i = 0; i < size; i += 3) {
    // If both are loop invariant, skip to the next triple.
    if (mlir::isa_and_nonnull<fir::UndefOp>(ldTriples[i + 1].getDefiningOp()) &&
        mlir::isa_and_nonnull<fir::UndefOp>(stTriples[i + 1].getDefiningOp())) {
      // Unless either is a vector index, then be conservative.
      if (mlir::isa_and_nonnull<fir::UndefOp>(ldTriples[i].getDefiningOp()) ||
          mlir::isa_and_nonnull<fir::UndefOp>(stTriples[i].getDefiningOp()))
        return false;
      continue;
    }
    // If identical, skip to the next triple.
    if (ldTriples[i] == stTriples[i] && ldTriples[i + 1] == stTriples[i + 1] &&
        ldTriples[i + 2] == stTriples[i + 2])
      continue;
    // If ubound and lbound are the same with a constant offset, skip to the
    // next triple.
    if (displacedByConstant(ldTriples[i + 1], stTriples[i]) ||
        displacedByConstant(stTriples[i + 1], ldTriples[i]))
      continue;
    return false;
  }
  LLVM_DEBUG(llvm::dbgs() << "detected non-overlapping slice ranges on " << ld
                          << " and " << st << ", which is not a conflict\n");
  return true;
}

/// Is there a conflict between the array value that was updated and to be
/// stored to `st` and the set of arrays loaded (`reach`) and used to compute
/// the updated value?
537 static bool conflictOnLoad(llvm::ArrayRef<mlir::Operation *> reach, 538 ArrayMergeStoreOp st) { 539 mlir::Value load; 540 mlir::Value addr = st.getMemref(); 541 const bool storeHasPointerType = hasPointerType(addr.getType()); 542 for (auto *op : reach) 543 if (auto ld = mlir::dyn_cast<ArrayLoadOp>(op)) { 544 mlir::Type ldTy = ld.getMemref().getType(); 545 if (ld.getMemref() == addr) { 546 if (mutuallyExclusiveSliceRange(ld, st)) 547 continue; 548 if (ld.getResult() != st.getOriginal()) 549 return true; 550 if (load) { 551 // TODO: extend this to allow checking if the first `load` and this 552 // `ld` are mutually exclusive accesses but not identical. 553 return true; 554 } 555 load = ld; 556 } else if ((hasPointerType(ldTy) || storeHasPointerType)) { 557 // TODO: Use target attribute to restrict this case further. 558 // TODO: Check if types can also allow ruling out some cases. For now, 559 // the fact that equivalences is using pointer attribute to enforce 560 // aliasing is preventing any attempt to do so, and in general, it may 561 // be wrong to use this if any of the types is a complex or a derived 562 // for which it is possible to create a pointer to a part with a 563 // different type than the whole, although this deserve some more 564 // investigation because existing compiler behavior seem to diverge 565 // here. 566 return true; 567 } 568 } 569 return false; 570 } 571 572 /// Is there an access vector conflict on the array being merged into? If the 573 /// access vectors diverge, then assume that there are potentially overlapping 574 /// loop-carried references. 
575 static bool conflictOnMerge(llvm::ArrayRef<mlir::Operation *> mentions) { 576 if (mentions.size() < 2) 577 return false; 578 llvm::SmallVector<mlir::Value> indices; 579 LLVM_DEBUG(llvm::dbgs() << "check merge conflict on with " << mentions.size() 580 << " mentions on the list\n"); 581 bool valSeen = false; 582 bool refSeen = false; 583 for (auto *op : mentions) { 584 llvm::SmallVector<mlir::Value> compareVector; 585 if (auto u = mlir::dyn_cast<ArrayUpdateOp>(op)) { 586 valSeen = true; 587 if (indices.empty()) { 588 indices = u.getIndices(); 589 continue; 590 } 591 compareVector = u.getIndices(); 592 } else if (auto f = mlir::dyn_cast<ArrayModifyOp>(op)) { 593 valSeen = true; 594 if (indices.empty()) { 595 indices = f.getIndices(); 596 continue; 597 } 598 compareVector = f.getIndices(); 599 } else if (auto f = mlir::dyn_cast<ArrayFetchOp>(op)) { 600 valSeen = true; 601 if (indices.empty()) { 602 indices = f.getIndices(); 603 continue; 604 } 605 compareVector = f.getIndices(); 606 } else if (auto f = mlir::dyn_cast<ArrayAccessOp>(op)) { 607 refSeen = true; 608 if (indices.empty()) { 609 indices = f.getIndices(); 610 continue; 611 } 612 compareVector = f.getIndices(); 613 } else if (mlir::isa<ArrayAmendOp>(op)) { 614 refSeen = true; 615 continue; 616 } else { 617 mlir::emitError(op->getLoc(), "unexpected operation in analysis"); 618 } 619 if (compareVector.size() != indices.size() || 620 llvm::any_of(llvm::zip(compareVector, indices), [&](auto pair) { 621 return std::get<0>(pair) != std::get<1>(pair); 622 })) 623 return true; 624 LLVM_DEBUG(llvm::dbgs() << "vectors compare equal\n"); 625 } 626 return valSeen && refSeen; 627 } 628 629 /// With element-by-reference semantics, an amended array with more than once 630 /// access to the same loaded array are conservatively considered a conflict. 631 /// Note: the array copy can still be eliminated in subsequent optimizations. 
632 static bool conflictOnReference(llvm::ArrayRef<mlir::Operation *> mentions) { 633 LLVM_DEBUG(llvm::dbgs() << "checking reference semantics " << mentions.size() 634 << '\n'); 635 if (mentions.size() < 3) 636 return false; 637 unsigned amendCount = 0; 638 unsigned accessCount = 0; 639 for (auto *op : mentions) { 640 if (mlir::isa<ArrayAmendOp>(op) && ++amendCount > 1) { 641 LLVM_DEBUG(llvm::dbgs() << "conflict: multiple amends of array value\n"); 642 return true; 643 } 644 if (mlir::isa<ArrayAccessOp>(op) && ++accessCount > 1) { 645 LLVM_DEBUG(llvm::dbgs() 646 << "conflict: multiple accesses of array value\n"); 647 return true; 648 } 649 if (mlir::isa<ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp>(op)) { 650 LLVM_DEBUG(llvm::dbgs() 651 << "conflict: array value has both uses by-value and uses " 652 "by-reference. conservative assumption.\n"); 653 return true; 654 } 655 } 656 return false; 657 } 658 659 static mlir::Operation * 660 amendingAccess(llvm::ArrayRef<mlir::Operation *> mentions) { 661 for (auto *op : mentions) 662 if (auto amend = mlir::dyn_cast<ArrayAmendOp>(op)) 663 return amend.getMemref().getDefiningOp(); 664 return {}; 665 } 666 667 // Are either of types of conflicts present? 668 inline bool conflictDetected(llvm::ArrayRef<mlir::Operation *> reach, 669 llvm::ArrayRef<mlir::Operation *> accesses, 670 ArrayMergeStoreOp st) { 671 return conflictOnLoad(reach, st) || conflictOnMerge(accesses); 672 } 673 674 // Assume that any call to a function that uses host-associations will be 675 // modifying the output array. 
static bool
conservativeCallConflict(llvm::ArrayRef<mlir::Operation *> reaches) {
  return llvm::any_of(reaches, [](mlir::Operation *op) {
    if (auto call = mlir::dyn_cast<fir::CallOp>(op))
      if (auto callee =
              call.getCallableForCallee().dyn_cast<mlir::SymbolRefAttr>()) {
        auto module = op->getParentOfType<mlir::ModuleOp>();
        // Conflict if the callee takes a host-association argument.
        return hasHostAssociationArgument(
            module.lookupSymbol<mlir::FuncOp>(callee));
      }
    return false;
  });
}

/// Constructor of the array copy analysis.
/// This performs the analysis and saves the intermediate results.
void ArrayCopyAnalysis::construct(mlir::Operation *topLevelOp) {
  topLevelOp->walk([&](Operation *op) {
    if (auto st = mlir::dyn_cast<fir::ArrayMergeStoreOp>(op)) {
      // Gather every op producing the stored array value, then test the three
      // conflict conditions against them.
      llvm::SmallVector<mlir::Operation *> values;
      ReachCollector::reachingValues(values, st.getSequence());
      bool callConflict = conservativeCallConflict(values);
      llvm::SmallVector<mlir::Operation *> mentions;
      arrayMentions(mentions,
                    mlir::cast<ArrayLoadOp>(st.getOriginal().getDefiningOp()));
      bool conflict = conflictDetected(values, mentions, st);
      bool refConflict = conflictOnReference(mentions);
      if (callConflict || conflict || refConflict) {
        LLVM_DEBUG(llvm::dbgs()
                   << "CONFLICT: copies required for " << st << '\n'
                   << " adding conflicts on: " << op << " and "
                   << st.getOriginal() << '\n');
        // Both the merge_store and the original array_load are marked.
        conflicts.insert(op);
        conflicts.insert(st.getOriginal().getDefiningOp());
        if (auto *access = amendingAccess(mentions))
          amendAccesses.insert(access);
      }
      // The load has pending modifications: map it to its merge_store.
      auto *ld = st.getOriginal().getDefiningOp();
      LLVM_DEBUG(llvm::dbgs()
                 << "map: adding {" << *ld << " -> " << st << "}\n");
      useMap.insert({ld, op});
    } else if (auto load = mlir::dyn_cast<ArrayLoadOp>(op)) {
      llvm::SmallVector<mlir::Operation *> mentions;
      arrayMentions(mentions, load);
      LLVM_DEBUG(llvm::dbgs() << "process load: " << load
                              << ", mentions: " << mentions.size() << '\n');
      // Map each array op mentioning this load back to the load itself.
      for (auto *acc : mentions) {
        LLVM_DEBUG(llvm::dbgs() << " mention: " << *acc << '\n');
        if (mlir::isa<ArrayAccessOp, ArrayAmendOp, ArrayFetchOp, ArrayUpdateOp,
                      ArrayModifyOp>(acc)) {
          if (useMap.count(acc)) {
            mlir::emitError(
                load.getLoc(),
                "The parallel semantics of multiple array_merge_stores per "
                "array_load are not supported.");
            continue;
          }
          LLVM_DEBUG(llvm::dbgs()
                     << "map: adding {" << *acc << "} -> {" << load << "}\n");
          useMap.insert({acc, op});
        }
      }
    }
  });
}

//===----------------------------------------------------------------------===//
// Conversions for converting out of array value form.
//===----------------------------------------------------------------------===//

namespace {
/// Rewrite an array_load into an undefined value of the same type; by this
/// point the loaded value itself is no longer needed.
class ArrayLoadConversion : public mlir::OpRewritePattern<ArrayLoadOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(ArrayLoadOp load,
                  mlir::PatternRewriter &rewriter) const override {
    LLVM_DEBUG(llvm::dbgs() << "replace load " << load << " with undef.\n");
    rewriter.replaceOpWithNewOp<UndefOp>(load, load.getType());
    return mlir::success();
  }
};

/// Erase array_merge_store operations outright.
class ArrayMergeStoreConversion
    : public mlir::OpRewritePattern<ArrayMergeStoreOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(ArrayMergeStoreOp store,
                  mlir::PatternRewriter &rewriter) const override {
    LLVM_DEBUG(llvm::dbgs() << "marking store " << store << " as dead.\n");
    rewriter.eraseOp(store);
    return mlir::success();
  }
};
} // namespace

/// Strip reference/sequence wrappers from `ty` and return a reference to the
/// element type.
static mlir::Type getEleTy(mlir::Type ty) {
  auto eleTy = unwrapSequenceType(unwrapPassByRefType(ty));
  // FIXME: keep ptr/heap/ref information.
  return ReferenceType::get(eleTy);
}

// Extract extents from the ShapeOp/ShapeShiftOp into the result vector.
static bool getAdjustedExtents(mlir::Location loc,
                               mlir::PatternRewriter &rewriter,
                               ArrayLoadOp arrLoad,
                               llvm::SmallVectorImpl<mlir::Value> &result,
                               mlir::Value shape) {
  // Returns true iff slicing must be applied during the copy itself.
  bool copyUsingSlice = false;
  auto *shapeOp = shape.getDefiningOp();
  if (auto s = mlir::dyn_cast_or_null<ShapeOp>(shapeOp)) {
    auto e = s.getExtents();
    result.insert(result.end(), e.begin(), e.end());
  } else if (auto s = mlir::dyn_cast_or_null<ShapeShiftOp>(shapeOp)) {
    auto e = s.getExtents();
    result.insert(result.end(), e.begin(), e.end());
  } else {
    emitFatalError(loc, "not a fir.shape/fir.shape_shift op");
  }
  auto idxTy = rewriter.getIndexType();
  if (factory::isAssumedSize(result)) {
    // Use slice information to compute the extent of the column.
    auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1);
    mlir::Value size = one;
    if (mlir::Value sliceArg = arrLoad.getSlice()) {
      if (auto sliceOp =
              mlir::dyn_cast_or_null<SliceOp>(sliceArg.getDefiningOp())) {
        auto triples = sliceOp.getTriples();
        const std::size_t tripleSize = triples.size();
        auto module = arrLoad->getParentOfType<mlir::ModuleOp>();
        FirOpBuilder builder(rewriter, getKindMapping(module));
        // Compute the last dimension's extent from its (lb, ub, step) triple.
        size = builder.genExtentFromTriplet(loc, triples[tripleSize - 3],
                                            triples[tripleSize - 2],
                                            triples[tripleSize - 1], idxTy);
        copyUsingSlice = true;
      }
    }
    // Replace the assumed-size placeholder in the last position.
    result[result.size() - 1] = size;
  }
  return copyUsingSlice;
}

/// Place the extents of the array load, \p arrLoad, into \p result and
/// return a ShapeOp or ShapeShiftOp with the same extents. If \p arrLoad is
/// loading a `!fir.box`, code will be generated to read the extents from the
/// boxed value, and the returned shape Op will be built with the extents read
/// from the box. Otherwise, the extents will be extracted from the ShapeOp (or
/// ShapeShiftOp) argument of \p arrLoad. \p copyUsingSlice will be set to true
/// if slicing of the output array is to be done in the copy-in/copy-out rather
/// than in the elemental computation step.
static mlir::Value getOrReadExtentsAndShapeOp(
    mlir::Location loc, mlir::PatternRewriter &rewriter, ArrayLoadOp arrLoad,
    llvm::SmallVectorImpl<mlir::Value> &result, bool &copyUsingSlice) {
  assert(result.empty());
  // Shapes of OPTIONAL array loads are rejected outright.
  if (arrLoad->hasAttr(fir::getOptionalAttrName()))
    fir::emitFatalError(
        loc, "shapes from array load of OPTIONAL arrays must not be used");
  if (auto boxTy = arrLoad.getMemref().getType().dyn_cast<BoxType>()) {
    auto rank =
        dyn_cast_ptrOrBoxEleTy(boxTy).cast<SequenceType>().getDimension();
    auto idxTy = rewriter.getIndexType();
    // Read each dimension's extent out of the box descriptor.
    for (decltype(rank) dim = 0; dim < rank; ++dim) {
      auto dimVal = rewriter.create<mlir::arith::ConstantIndexOp>(loc, dim);
      auto dimInfo = rewriter.create<BoxDimsOp>(loc, idxTy, idxTy, idxTy,
                                                arrLoad.getMemref(), dimVal);
      result.emplace_back(dimInfo.getResult(1));
    }
    if (!arrLoad.getShape()) {
      auto shapeType = ShapeType::get(rewriter.getContext(), rank);
      return rewriter.create<ShapeOp>(loc, shapeType, result);
    }
    // Interleave the shift's lower bounds with the extents read from the box.
    auto shiftOp = arrLoad.getShape().getDefiningOp<ShiftOp>();
    auto shapeShiftType = ShapeShiftType::get(rewriter.getContext(), rank);
    llvm::SmallVector<mlir::Value> shapeShiftOperands;
    for (auto [lb, extent] : llvm::zip(shiftOp.getOrigins(), result)) {
      shapeShiftOperands.push_back(lb);
      shapeShiftOperands.push_back(extent);
    }
    return rewriter.create<ShapeShiftOp>(loc, shapeShiftType,
                                         shapeShiftOperands);
  }
  copyUsingSlice =
      getAdjustedExtents(loc, rewriter, arrLoad, result, arrLoad.getShape());
  return arrLoad.getShape();
}

/// Return `ty` itself when it is already a reference type, otherwise wrap it
/// in a fir.ref.
static mlir::Type toRefType(mlir::Type ty) {
  if (fir::isa_ref_type(ty))
    return ty;
  return fir::ReferenceType::get(ty);
}

/// Generate an array_coor (and, for trailing non-array coordinates, a
/// coordinate_of) addressing `alloc` at `indices`.
static mlir::Value
genCoorOp(mlir::PatternRewriter &rewriter, mlir::Location loc, mlir::Type eleTy,
          mlir::Type resTy, mlir::Value alloc, mlir::Value shape,
          mlir::Value slice, mlir::ValueRange indices,
          mlir::ValueRange typeparams, bool skipOrig = false) {
  llvm::SmallVector<mlir::Value> originated;
  if (skipOrig)
    originated.assign(indices.begin(), indices.end());
  else
    originated = fir::factory::originateIndices(loc, rewriter, alloc.getType(),
                                                shape, indices);
  auto seqTy = fir::dyn_cast_ptrOrBoxEleTy(alloc.getType());
  assert(seqTy && seqTy.isa<fir::SequenceType>());
  const auto dimension = seqTy.cast<fir::SequenceType>().getDimension();
  // The first `dimension` indices address into the array itself.
  mlir::Value result = rewriter.create<fir::ArrayCoorOp>(
      loc, eleTy, alloc, shape, slice,
      llvm::ArrayRef<mlir::Value>{originated}.take_front(dimension),
      typeparams);
  // Any remaining indices address into the element.
  if (dimension < originated.size())
    result = rewriter.create<fir::CoordinateOp>(
        loc, resTy, result,
        llvm::ArrayRef<mlir::Value>{originated}.drop_front(dimension));
  return result;
}

/// Compute the CHARACTER length of the elements loaded by `load`, either from
/// the box descriptor, the load's type parameters, or the static type.
static mlir::Value getCharacterLen(mlir::Location loc, FirOpBuilder &builder,
                                   ArrayLoadOp load, CharacterType charTy) {
  auto charLenTy = builder.getCharacterLengthType();
  if (charTy.hasDynamicLen()) {
    if (load.getMemref().getType().isa<BoxType>()) {
      // The loaded array is an emboxed value. Get the CHARACTER length from
      // the box value.
      auto eleSzInBytes =
          builder.create<BoxEleSizeOp>(loc, charLenTy, load.getMemref());
      auto kindSize =
          builder.getKindMap().getCharacterBitsize(charTy.getFKind());
      auto kindByteSize =
          builder.createIntegerConstant(loc, charLenTy, kindSize / 8);
      // length = element size in bytes / bytes per character of this kind.
      return builder.create<mlir::arith::DivSIOp>(loc, eleSzInBytes,
                                                  kindByteSize);
    }
    // The loaded array is a (set of) unboxed values. If the CHARACTER's
    // length is not a constant, it must be provided as a type parameter to
    // the array_load.
915 auto typeparams = load.getTypeparams(); 916 assert(typeparams.size() > 0 && "expected type parameters on array_load"); 917 return typeparams.back(); 918 } 919 // The typical case: the length of the CHARACTER is a compile-time 920 // constant that is encoded in the type information. 921 return builder.createIntegerConstant(loc, charLenTy, charTy.getLen()); 922 } 923 /// Generate a shallow array copy. This is used for both copy-in and copy-out. 924 template <bool CopyIn> 925 void genArrayCopy(mlir::Location loc, mlir::PatternRewriter &rewriter, 926 mlir::Value dst, mlir::Value src, mlir::Value shapeOp, 927 mlir::Value sliceOp, ArrayLoadOp arrLoad) { 928 auto insPt = rewriter.saveInsertionPoint(); 929 llvm::SmallVector<mlir::Value> indices; 930 llvm::SmallVector<mlir::Value> extents; 931 bool copyUsingSlice = 932 getAdjustedExtents(loc, rewriter, arrLoad, extents, shapeOp); 933 auto idxTy = rewriter.getIndexType(); 934 // Build loop nest from column to row. 935 for (auto sh : llvm::reverse(extents)) { 936 auto ubi = rewriter.create<ConvertOp>(loc, idxTy, sh); 937 auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0); 938 auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1); 939 auto ub = rewriter.create<mlir::arith::SubIOp>(loc, idxTy, ubi, one); 940 auto loop = rewriter.create<DoLoopOp>(loc, zero, ub, one); 941 rewriter.setInsertionPointToStart(loop.getBody()); 942 indices.push_back(loop.getInductionVar()); 943 } 944 // Reverse the indices so they are in column-major order. 945 std::reverse(indices.begin(), indices.end()); 946 auto typeparams = arrLoad.getTypeparams(); 947 auto fromAddr = rewriter.create<ArrayCoorOp>( 948 loc, getEleTy(src.getType()), src, shapeOp, 949 CopyIn && copyUsingSlice ? sliceOp : mlir::Value{}, 950 factory::originateIndices(loc, rewriter, src.getType(), shapeOp, indices), 951 typeparams); 952 auto toAddr = rewriter.create<ArrayCoorOp>( 953 loc, getEleTy(dst.getType()), dst, shapeOp, 954 !CopyIn && copyUsingSlice ? 
sliceOp : mlir::Value{}, 955 factory::originateIndices(loc, rewriter, dst.getType(), shapeOp, indices), 956 typeparams); 957 auto eleTy = unwrapSequenceType(unwrapPassByRefType(dst.getType())); 958 auto module = toAddr->getParentOfType<mlir::ModuleOp>(); 959 FirOpBuilder builder(rewriter, getKindMapping(module)); 960 // Copy from (to) object to (from) temp copy of same object. 961 if (auto charTy = eleTy.dyn_cast<CharacterType>()) { 962 auto len = getCharacterLen(loc, builder, arrLoad, charTy); 963 CharBoxValue toChar(toAddr, len); 964 CharBoxValue fromChar(fromAddr, len); 965 factory::genScalarAssignment(builder, loc, toChar, fromChar); 966 } else { 967 if (hasDynamicSize(eleTy)) 968 TODO(loc, "copy element of dynamic size"); 969 factory::genScalarAssignment(builder, loc, toAddr, fromAddr); 970 } 971 rewriter.restoreInsertionPoint(insPt); 972 } 973 974 /// The array load may be either a boxed or unboxed value. If the value is 975 /// boxed, we read the type parameters from the boxed value. 
static llvm::SmallVector<mlir::Value>
genArrayLoadTypeParameters(mlir::Location loc, mlir::PatternRewriter &rewriter,
                           ArrayLoadOp load) {
  if (load.getTypeparams().empty()) {
    // No explicit type parameters on the array_load: derive them from the
    // element type when it has dynamic size.
    auto eleTy =
        unwrapSequenceType(unwrapPassByRefType(load.getMemref().getType()));
    if (hasDynamicSize(eleTy)) {
      if (auto charTy = eleTy.dyn_cast<CharacterType>()) {
        // Dynamic-length CHARACTER without type parameters must be boxed so
        // the length can be read from the descriptor.
        assert(load.getMemref().getType().isa<BoxType>());
        auto module = load->getParentOfType<mlir::ModuleOp>();
        FirOpBuilder builder(rewriter, getKindMapping(module));
        return {getCharacterLen(loc, builder, load, charTy)};
      }
      TODO(loc, "unhandled dynamic type parameters");
    }
    // Statically-sized element type: no type parameters needed.
    return {};
  }
  return load.getTypeparams();
}

/// Filter \p extents down to the ones corresponding to dimensions whose
/// static extent is unknown in the sequence type of \p memrefTy. These are
/// exactly the extents an allocmem of that type requires as operands. Any
/// extra trailing extents (beyond the type's rank) are kept as well.
static llvm::SmallVector<mlir::Value>
findNonconstantExtents(mlir::Type memrefTy,
                       llvm::ArrayRef<mlir::Value> extents) {
  llvm::SmallVector<mlir::Value> nce;
  auto arrTy = unwrapPassByRefType(memrefTy);
  auto seqTy = arrTy.cast<SequenceType>();
  for (auto [s, x] : llvm::zip(seqTy.getShape(), extents))
    if (s == SequenceType::getUnknownExtent())
      nce.emplace_back(x);
  if (extents.size() > seqTy.getShape().size())
    for (auto x : extents.drop_front(seqTy.getShape().size()))
      nce.emplace_back(x);
  return nce;
}

namespace {
/// Conversion of fir.array_update and fir.array_modify Ops.
/// If there is a conflict for the update, then we need to perform a
/// copy-in/copy-out to preserve the original values of the array. If there is
/// no conflict, then it is safe to eschew making any copies.
template <typename ArrayOp>
class ArrayUpdateConversionBase : public mlir::OpRewritePattern<ArrayOp> {
public:
  // TODO: Implement copy/swap semantics?
  explicit ArrayUpdateConversionBase(mlir::MLIRContext *ctx,
                                     const ArrayCopyAnalysis &a,
                                     const OperationUseMapT &m)
      : mlir::OpRewritePattern<ArrayOp>{ctx}, analysis{a}, useMap{m} {}

  /// The array_access, \p access, is to be to a cloned copy due to a potential
  /// conflict. Uses copy-in/copy-out semantics and not copy/swap.
  /// Allocates a temporary sized like the loaded array, copies the original
  /// array into it before the array_load, addresses the element in the
  /// temporary at the access point, and copies the temporary back to the
  /// original before the array_merge_store. Returns the element address in
  /// the temporary.
  mlir::Value referenceToClone(mlir::Location loc,
                               mlir::PatternRewriter &rewriter,
                               ArrayOp access) const {
    LLVM_DEBUG(llvm::dbgs()
               << "generating copy-in/copy-out loops for " << access << '\n');
    auto *op = access.getOperation();
    auto *loadOp = useMap.lookup(op);
    auto load = mlir::cast<ArrayLoadOp>(loadOp);
    auto eleTy = access.getType();
    // Insert the copy-in just before the array_load.
    rewriter.setInsertionPoint(loadOp);
    // Copy in.
    llvm::SmallVector<mlir::Value> extents;
    bool copyUsingSlice = false;
    auto shapeOp = getOrReadExtentsAndShapeOp(loc, rewriter, load, extents,
                                              copyUsingSlice);
    llvm::SmallVector<mlir::Value> nonconstantExtents =
        findNonconstantExtents(load.getMemref().getType(), extents);
    auto allocmem = rewriter.create<AllocMemOp>(
        loc, dyn_cast_ptrOrBoxEleTy(load.getMemref().getType()),
        genArrayLoadTypeParameters(loc, rewriter, load), nonconstantExtents);
    genArrayCopy</*copyIn=*/true>(load.getLoc(), rewriter, allocmem,
                                  load.getMemref(), shapeOp, load.getSlice(),
                                  load);
    // Generate the reference for the access.
    rewriter.setInsertionPoint(op);
    auto coor =
        genCoorOp(rewriter, loc, getEleTy(load.getType()), eleTy, allocmem,
                  shapeOp, copyUsingSlice ? mlir::Value{} : load.getSlice(),
                  access.getIndices(), load.getTypeparams(),
                  access->hasAttr(factory::attrFortranArrayOffsets()));
    // Locate the array_merge_store paired with this load; the copy-out and
    // the deallocation of the temporary go just before it.
    auto *storeOp = useMap.lookup(loadOp);
    auto store = mlir::cast<ArrayMergeStoreOp>(storeOp);
    rewriter.setInsertionPoint(storeOp);
    // Copy out.
    genArrayCopy</*copyIn=*/false>(store.getLoc(), rewriter, store.getMemref(),
                                   allocmem, shapeOp, store.getSlice(), load);
    rewriter.create<FreeMemOp>(loc, allocmem);
    return coor;
  }

  /// Copy the RHS element into the LHS and insert copy-in/copy-out between a
  /// temp and the LHS if the analysis found potential overlaps between the RHS
  /// and LHS arrays. The element copy generator must be provided in \p
  /// assignElement. \p update must be the ArrayUpdateOp or the ArrayModifyOp.
  /// Returns the address of the LHS element inside the loop and the LHS
  /// ArrayLoad result.
  std::pair<mlir::Value, mlir::Value>
  materializeAssignment(mlir::Location loc, mlir::PatternRewriter &rewriter,
                        ArrayOp update,
                        const std::function<void(mlir::Value)> &assignElement,
                        mlir::Type lhsEltRefType) const {
    auto *op = update.getOperation();
    auto *loadOp = useMap.lookup(op);
    auto load = mlir::cast<ArrayLoadOp>(loadOp);
    // NOTE(review): these debug traces go to llvm::outs() rather than
    // llvm::dbgs() — consider unifying with the rest of the pass.
    LLVM_DEBUG(llvm::outs() << "does " << load << " have a conflict?\n");
    if (analysis.hasPotentialConflict(loadOp)) {
      // If there is a conflict between the arrays, then we copy the lhs array
      // to a temporary, update the temporary, and copy the temporary back to
      // the lhs array. This yields Fortran's copy-in copy-out array semantics.
      LLVM_DEBUG(llvm::outs() << "Yes, conflict was found\n");
      rewriter.setInsertionPoint(loadOp);
      // Copy in.
      llvm::SmallVector<mlir::Value> extents;
      bool copyUsingSlice = false;
      auto shapeOp = getOrReadExtentsAndShapeOp(loc, rewriter, load, extents,
                                                copyUsingSlice);
      llvm::SmallVector<mlir::Value> nonconstantExtents =
          findNonconstantExtents(load.getMemref().getType(), extents);
      auto allocmem = rewriter.create<AllocMemOp>(
          loc, dyn_cast_ptrOrBoxEleTy(load.getMemref().getType()),
          genArrayLoadTypeParameters(loc, rewriter, load), nonconstantExtents);
      genArrayCopy</*copyIn=*/true>(load.getLoc(), rewriter, allocmem,
                                    load.getMemref(), shapeOp, load.getSlice(),
                                    load);
      // Address the element in the temporary and run the element assignment
      // at the update point.
      rewriter.setInsertionPoint(op);
      auto coor = genCoorOp(
          rewriter, loc, getEleTy(load.getType()), lhsEltRefType, allocmem,
          shapeOp, copyUsingSlice ? mlir::Value{} : load.getSlice(),
          update.getIndices(), load.getTypeparams(),
          update->hasAttr(factory::attrFortranArrayOffsets()));
      assignElement(coor);
      auto *storeOp = useMap.lookup(loadOp);
      auto store = mlir::cast<ArrayMergeStoreOp>(storeOp);
      rewriter.setInsertionPoint(storeOp);
      // Copy out.
      genArrayCopy</*copyIn=*/false>(store.getLoc(), rewriter,
                                     store.getMemref(), allocmem, shapeOp,
                                     store.getSlice(), load);
      rewriter.create<FreeMemOp>(loc, allocmem);
      return {coor, load.getResult()};
    }
    // Otherwise, when there is no conflict (a possible loop-carried
    // dependence), the lhs array can be updated in place.
    LLVM_DEBUG(llvm::outs() << "No, conflict wasn't found\n");
    rewriter.setInsertionPoint(op);
    auto coorTy = getEleTy(load.getType());
    auto coor = genCoorOp(rewriter, loc, coorTy, lhsEltRefType,
                          load.getMemref(), load.getShape(), load.getSlice(),
                          update.getIndices(), load.getTypeparams(),
                          update->hasAttr(factory::attrFortranArrayOffsets()));
    assignElement(coor);
    return {coor, load.getResult()};
  }

protected:
  const ArrayCopyAnalysis &analysis; // conflict analysis results (not owned)
  const OperationUseMapT &useMap;    // access->load and load->store map (not owned)
};

/// Rewrite fir.array_update: store the merged element value into the LHS
/// element address produced by materializeAssignment, then replace the op
/// with the LHS array_load result.
class ArrayUpdateConversion : public ArrayUpdateConversionBase<ArrayUpdateOp> {
public:
  explicit ArrayUpdateConversion(mlir::MLIRContext *ctx,
                                 const ArrayCopyAnalysis &a,
                                 const OperationUseMapT &m)
      : ArrayUpdateConversionBase{ctx, a, m} {}

  mlir::LogicalResult
  matchAndRewrite(ArrayUpdateOp update,
                  mlir::PatternRewriter &rewriter) const override {
    auto loc = update.getLoc();
    auto assignElement = [&](mlir::Value coor) {
      auto input = update.getMerge();
      if (auto inEleTy = dyn_cast_ptrEleTy(input.getType())) {
        // The merged value is a reference; storing a reference is not
        // supported by array_update.
        emitFatalError(loc, "array_update on references not supported");
      } else {
        rewriter.create<fir::StoreOp>(loc, input, coor);
      }
    };
    auto lhsEltRefType = toRefType(update.getMerge().getType());
    auto [_, lhsLoadResult] = materializeAssignment(
        loc, rewriter, update, assignElement, lhsEltRefType);
    update.replaceAllUsesWith(lhsLoadResult);
    rewriter.replaceOp(update, lhsLoadResult);
    return mlir::success();
  }
};

/// Rewrite fir.array_modify: the element assignment was already materialized
/// by lowering through the element address, so only the two results (element
/// address, LHS array value) need to be replaced.
class ArrayModifyConversion : public ArrayUpdateConversionBase<ArrayModifyOp> {
public:
  explicit ArrayModifyConversion(mlir::MLIRContext *ctx,
                                 const ArrayCopyAnalysis &a,
                                 const OperationUseMapT &m)
      : ArrayUpdateConversionBase{ctx, a, m} {}

  mlir::LogicalResult
  matchAndRewrite(ArrayModifyOp modify,
                  mlir::PatternRewriter &rewriter) const override {
    auto loc = modify.getLoc();
    auto assignElement = [](mlir::Value) {
      // Assignment already materialized by lowering using lhs element address.
    };
    auto lhsEltRefType = modify.getResult(0).getType();
    auto [lhsEltCoor, lhsLoadResult] = materializeAssignment(
        loc, rewriter, modify, assignElement, lhsEltRefType);
    modify.replaceAllUsesWith(mlir::ValueRange{lhsEltCoor, lhsLoadResult});
    rewriter.replaceOp(modify, mlir::ValueRange{lhsEltCoor, lhsLoadResult});
    return mlir::success();
  }
};

/// Rewrite fir.array_fetch to an element address (array_coor) plus, when the
/// fetch produces a value rather than a reference, a load of that address.
class ArrayFetchConversion : public mlir::OpRewritePattern<ArrayFetchOp> {
public:
  explicit ArrayFetchConversion(mlir::MLIRContext *ctx,
                                const OperationUseMapT &m)
      : OpRewritePattern{ctx}, useMap{m} {}

  mlir::LogicalResult
  matchAndRewrite(ArrayFetchOp fetch,
                  mlir::PatternRewriter &rewriter) const override {
    auto *op = fetch.getOperation();
    rewriter.setInsertionPoint(op);
    auto load = mlir::cast<ArrayLoadOp>(useMap.lookup(op));
    auto loc = fetch.getLoc();
    auto coor =
        genCoorOp(rewriter, loc, getEleTy(load.getType()),
                  toRefType(fetch.getType()), load.getMemref(), load.getShape(),
                  load.getSlice(), fetch.getIndices(), load.getTypeparams(),
                  fetch->hasAttr(factory::attrFortranArrayOffsets()));
    if (isa_ref_type(fetch.getType()))
      rewriter.replaceOp(fetch, coor);
    else
      rewriter.replaceOpWithNewOp<fir::LoadOp>(fetch, coor);
    return mlir::success();
  }

private:
  const OperationUseMapT &useMap; // maps each fetch to its array_load
};

/// An array_access op is like an array_fetch op, except that it does not imply
/// a load op. (It operates in the reference domain.)
/// Rewrite fir.array_access to the address of the accessed element. When the
/// access feeds an array_amend under a conflict (per the analysis' amend
/// access set), the address points into a freshly cloned temporary with
/// copy-in/copy-out; otherwise it addresses the original array directly.
class ArrayAccessConversion : public ArrayUpdateConversionBase<ArrayAccessOp> {
public:
  explicit ArrayAccessConversion(mlir::MLIRContext *ctx,
                                 const ArrayCopyAnalysis &a,
                                 const OperationUseMapT &m)
      : ArrayUpdateConversionBase{ctx, a, m} {}

  mlir::LogicalResult
  matchAndRewrite(ArrayAccessOp access,
                  mlir::PatternRewriter &rewriter) const override {
    auto *op = access.getOperation();
    auto loc = access.getLoc();
    if (analysis.inAmendAccessSet(op)) {
      // This array_access is associated with an array_amend and there is a
      // conflict. Make a copy to store into.
      auto result = referenceToClone(loc, rewriter, access);
      access.replaceAllUsesWith(result);
      rewriter.replaceOp(access, result);
      return mlir::success();
    }
    // No conflict: address the element in the original array in place.
    rewriter.setInsertionPoint(op);
    auto load = mlir::cast<ArrayLoadOp>(useMap.lookup(op));
    auto coor = genCoorOp(rewriter, loc, getEleTy(load.getType()),
                          toRefType(access.getType()), load.getMemref(),
                          load.getShape(), load.getSlice(), access.getIndices(),
                          load.getTypeparams(),
                          access->hasAttr(factory::attrFortranArrayOffsets()));
    rewriter.replaceOp(access, coor);
    return mlir::success();
  }
};

/// An array_amend op is a marker to record which array access is being used to
/// update an array value. After this pass runs, an array_amend has no
/// semantics. We rewrite these to undefined values here to remove them while
/// preserving SSA form.
/// Replace each fir.array_amend with fir.undefined of the same type; by this
/// point the amend carries no semantics (see comment above).
class ArrayAmendConversion : public mlir::OpRewritePattern<ArrayAmendOp> {
public:
  explicit ArrayAmendConversion(mlir::MLIRContext *ctx)
      : OpRewritePattern{ctx} {}

  mlir::LogicalResult
  matchAndRewrite(ArrayAmendOp amend,
                  mlir::PatternRewriter &rewriter) const override {
    auto *op = amend.getOperation();
    rewriter.setInsertionPoint(op);
    auto loc = amend.getLoc();
    auto undef = rewriter.create<UndefOp>(loc, amend.getType());
    rewriter.replaceOp(amend, undef.getResult());
    return mlir::success();
  }
};

/// Driver for the array-value-copy pass. Runs the conflict analysis, then
/// lowers the array ops in two phases: first the element-level ops
/// (access/amend/fetch/update/modify), then the bracketing
/// array_load/array_merge_store ops.
class ArrayValueCopyConverter
    : public ArrayValueCopyBase<ArrayValueCopyConverter> {
public:
  void runOnOperation() override {
    auto func = getOperation();
    LLVM_DEBUG(llvm::dbgs() << "\n\narray-value-copy pass on function '"
                            << func.getName() << "'\n");
    auto *context = &getContext();

    // Perform the conflict analysis.
    const auto &analysis = getAnalysis<ArrayCopyAnalysis>();
    const auto &useMap = analysis.getUseMap();

    mlir::RewritePatternSet patterns1(context);
    patterns1.insert<ArrayFetchConversion>(context, useMap);
    patterns1.insert<ArrayUpdateConversion>(context, analysis, useMap);
    patterns1.insert<ArrayModifyConversion>(context, analysis, useMap);
    patterns1.insert<ArrayAccessConversion>(context, analysis, useMap);
    patterns1.insert<ArrayAmendConversion>(context);
    mlir::ConversionTarget target(*context);
    target.addLegalDialect<FIROpsDialect, mlir::scf::SCFDialect,
                           mlir::arith::ArithmeticDialect,
                           mlir::func::FuncDialect>();
    target.addIllegalOp<ArrayAccessOp, ArrayAmendOp, ArrayFetchOp,
                        ArrayUpdateOp, ArrayModifyOp>();
    // Rewrite the array fetch and array update ops.
    // NOTE(review): on phase-1 failure we signal failure but still fall
    // through to phase 2 — confirm whether an early return is intended here.
    if (mlir::failed(
            mlir::applyPartialConversion(func, target, std::move(patterns1)))) {
      mlir::emitError(mlir::UnknownLoc::get(context),
                      "failure in array-value-copy pass, phase 1");
      signalPassFailure();
    }

    // Phase 2: the loads and stores themselves become illegal and are folded
    // away by their dedicated patterns.
    mlir::RewritePatternSet patterns2(context);
    patterns2.insert<ArrayLoadConversion>(context);
    patterns2.insert<ArrayMergeStoreConversion>(context);
    target.addIllegalOp<ArrayLoadOp, ArrayMergeStoreOp>();
    if (mlir::failed(
            mlir::applyPartialConversion(func, target, std::move(patterns2)))) {
      mlir::emitError(mlir::UnknownLoc::get(context),
                      "failure in array-value-copy pass, phase 2");
      signalPassFailure();
    }
  }
};
} // namespace

/// Factory for the array-value-copy pass.
std::unique_ptr<mlir::Pass> fir::createArrayValueCopyPass() {
  return std::make_unique<ArrayValueCopyConverter>();
}