1 //===-- ArrayValueCopy.cpp ------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "flang/Optimizer/Builder/BoxValue.h" 11 #include "flang/Optimizer/Builder/FIRBuilder.h" 12 #include "flang/Optimizer/Builder/Factory.h" 13 #include "flang/Optimizer/Dialect/FIRDialect.h" 14 #include "flang/Optimizer/Support/FIRContext.h" 15 #include "flang/Optimizer/Transforms/Passes.h" 16 #include "mlir/Dialect/SCF/SCF.h" 17 #include "mlir/Transforms/DialectConversion.h" 18 #include "llvm/Support/Debug.h" 19 20 #define DEBUG_TYPE "flang-array-value-copy" 21 22 using namespace fir; 23 24 using OperationUseMapT = llvm::DenseMap<mlir::Operation *, mlir::Operation *>; 25 26 namespace { 27 28 /// Array copy analysis. 29 /// Perform an interference analysis between array values. 30 /// 31 /// Lowering will generate a sequence of the following form. 32 /// ```mlir 33 /// %a_1 = fir.array_load %array_1(%shape) : ... 34 /// ... 35 /// %a_j = fir.array_load %array_j(%shape) : ... 36 /// ... 37 /// %a_n = fir.array_load %array_n(%shape) : ... 38 /// ... 39 /// %v_i = fir.array_fetch %a_i, ... 40 /// %a_j1 = fir.array_update %a_j, ... 41 /// ... 42 /// fir.array_merge_store %a_j, %a_jn to %array_j : ... 43 /// ``` 44 /// 45 /// The analysis is to determine if there are any conflicts. A conflict is when 46 /// one the following cases occurs. 47 /// 48 /// 1. There is an `array_update` to an array value, a_j, such that a_j was 49 /// loaded from the same array memory reference (array_j) but with a different 50 /// shape as the other array values a_i, where i != j. [Possible overlapping 51 /// arrays.] 52 /// 53 /// 2. There is either an array_fetch or array_update of a_j with a different 54 /// set of index values. [Possible loop-carried dependence.] 55 /// 56 /// If none of the array values overlap in storage and the accesses are not 57 /// loop-carried, then the arrays are conflict-free and no copies are required. 58 class ArrayCopyAnalysis { 59 public: 60 using ConflictSetT = llvm::SmallPtrSet<mlir::Operation *, 16>; 61 using UseSetT = llvm::SmallPtrSet<mlir::OpOperand *, 8>; 62 using LoadMapSetsT = 63 llvm::DenseMap<mlir::Operation *, SmallVector<Operation *>>; 64 65 ArrayCopyAnalysis(mlir::Operation *op) : operation{op} { construct(op); } 66 67 mlir::Operation *getOperation() const { return operation; } 68 69 /// Return true iff the `array_merge_store` has potential conflicts. 70 bool hasPotentialConflict(mlir::Operation *op) const { 71 LLVM_DEBUG(llvm::dbgs() 72 << "looking for a conflict on " << *op 73 << " and the set has a total of " << conflicts.size() << '\n'); 74 return conflicts.contains(op); 75 } 76 77 /// Return the use map. The use map maps array fetch and update operations 78 /// back to the array load that is the original source of the array value. 79 const OperationUseMapT &getUseMap() const { return useMap; } 80 81 /// Find all the array operations that access the array value that is loaded 82 /// by the array load operation, `load`. 83 const llvm::SmallVector<mlir::Operation *> &arrayAccesses(ArrayLoadOp load); 84 85 private: 86 void construct(mlir::Operation *topLevelOp); 87 88 mlir::Operation *operation; // operation that analysis ran upon 89 ConflictSetT conflicts; // set of conflicts (loads and merge stores) 90 OperationUseMapT useMap; 91 LoadMapSetsT loadMapSets; 92 }; 93 } // namespace 94 95 namespace { 96 /// Helper class to collect all array operations that produced an array value. 97 class ReachCollector { 98 private: 99 // If provided, the `loopRegion` is the body of a loop that produces the array 100 // of interest. 101 ReachCollector(llvm::SmallVectorImpl<mlir::Operation *> &reach, 102 mlir::Region *loopRegion) 103 : reach{reach}, loopRegion{loopRegion} {} 104 105 void collectArrayAccessFrom(mlir::Operation *op, mlir::ValueRange range) { 106 llvm::errs() << "COLLECT " << *op << "\n"; 107 if (range.empty()) { 108 collectArrayAccessFrom(op, mlir::Value{}); 109 return; 110 } 111 for (mlir::Value v : range) 112 collectArrayAccessFrom(v); 113 } 114 115 // TODO: Replace recursive algorithm on def-use chain with an iterative one 116 // with an explicit stack. 117 void collectArrayAccessFrom(mlir::Operation *op, mlir::Value val) { 118 // `val` is defined by an Op, process the defining Op. 119 // If `val` is defined by a region containing Op, we want to drill down 120 // and through that Op's region(s). 121 llvm::errs() << "COLLECT " << *op << "\n"; 122 LLVM_DEBUG(llvm::dbgs() << "popset: " << *op << '\n'); 123 auto popFn = [&](auto rop) { 124 assert(val && "op must have a result value"); 125 auto resNum = val.cast<mlir::OpResult>().getResultNumber(); 126 llvm::SmallVector<mlir::Value> results; 127 rop.resultToSourceOps(results, resNum); 128 for (auto u : results) 129 collectArrayAccessFrom(u); 130 }; 131 if (auto rop = mlir::dyn_cast<fir::DoLoopOp>(op)) { 132 popFn(rop); 133 return; 134 } 135 if (auto rop = mlir::dyn_cast<fir::IfOp>(op)) { 136 popFn(rop); 137 return; 138 } 139 if (auto mergeStore = mlir::dyn_cast<ArrayMergeStoreOp>(op)) { 140 if (opIsInsideLoops(mergeStore)) 141 collectArrayAccessFrom(mergeStore.sequence()); 142 return; 143 } 144 145 if (mlir::isa<AllocaOp, AllocMemOp>(op)) { 146 // Look for any stores inside the loops, and collect an array operation 147 // that produced the value being stored to it. 148 for (mlir::Operation *user : op->getUsers()) 149 if (auto store = mlir::dyn_cast<fir::StoreOp>(user)) 150 if (opIsInsideLoops(store)) 151 collectArrayAccessFrom(store.value()); 152 return; 153 } 154 155 // Otherwise, Op does not contain a region so just chase its operands. 156 if (mlir::isa<ArrayLoadOp, ArrayUpdateOp, ArrayModifyOp, ArrayFetchOp>( 157 op)) { 158 LLVM_DEBUG(llvm::dbgs() << "add " << *op << " to reachable set\n"); 159 reach.emplace_back(op); 160 } 161 // Array modify assignment is performed on the result. So the analysis 162 // must look at the what is done with the result. 163 if (mlir::isa<ArrayModifyOp>(op)) 164 for (mlir::Operation *user : op->getResult(0).getUsers()) 165 followUsers(user); 166 167 for (auto u : op->getOperands()) 168 collectArrayAccessFrom(u); 169 } 170 171 void collectArrayAccessFrom(mlir::BlockArgument ba) { 172 auto *parent = ba.getOwner()->getParentOp(); 173 // If inside an Op holding a region, the block argument corresponds to an 174 // argument passed to the containing Op. 175 auto popFn = [&](auto rop) { 176 collectArrayAccessFrom(rop.blockArgToSourceOp(ba.getArgNumber())); 177 }; 178 if (auto rop = mlir::dyn_cast<DoLoopOp>(parent)) { 179 popFn(rop); 180 return; 181 } 182 if (auto rop = mlir::dyn_cast<IterWhileOp>(parent)) { 183 popFn(rop); 184 return; 185 } 186 // Otherwise, a block argument is provided via the pred blocks. 187 for (auto *pred : ba.getOwner()->getPredecessors()) { 188 auto u = pred->getTerminator()->getOperand(ba.getArgNumber()); 189 collectArrayAccessFrom(u); 190 } 191 } 192 193 // Recursively trace operands to find all array operations relating to the 194 // values merged. 195 void collectArrayAccessFrom(mlir::Value val) { 196 if (!val || visited.contains(val)) 197 return; 198 visited.insert(val); 199 200 // Process a block argument. 201 if (auto ba = val.dyn_cast<mlir::BlockArgument>()) { 202 collectArrayAccessFrom(ba); 203 return; 204 } 205 206 // Process an Op. 207 if (auto *op = val.getDefiningOp()) { 208 collectArrayAccessFrom(op, val); 209 return; 210 } 211 212 fir::emitFatalError(val.getLoc(), "unhandled value"); 213 } 214 215 /// Is \op inside the loop nest region ? 216 bool opIsInsideLoops(mlir::Operation *op) const { 217 return loopRegion && loopRegion->isAncestor(op->getParentRegion()); 218 } 219 220 /// Recursively trace the use of an operation results, calling 221 /// collectArrayAccessFrom on the direct and indirect user operands. 222 /// TODO: Replace recursive algorithm on def-use chain with an iterative one 223 /// with an explicit stack. 224 void followUsers(mlir::Operation *op) { 225 for (auto userOperand : op->getOperands()) 226 collectArrayAccessFrom(userOperand); 227 // Go through potential converts/coordinate_op. 228 for (mlir::Operation *indirectUser : op->getUsers()) 229 followUsers(indirectUser); 230 } 231 232 llvm::SmallVectorImpl<mlir::Operation *> &reach; 233 llvm::SmallPtrSet<mlir::Value, 16> visited; 234 /// Region of the loops nest that produced the array value. 235 mlir::Region *loopRegion; 236 237 public: 238 /// Return all ops that produce the array value that is stored into the 239 /// `array_merge_store`. 240 static void reachingValues(llvm::SmallVectorImpl<mlir::Operation *> &reach, 241 mlir::Value seq) { 242 reach.clear(); 243 mlir::Region *loopRegion = nullptr; 244 // Only `DoLoopOp` is tested here since array operations are currently only 245 // associated with this kind of loop. 246 if (auto doLoop = 247 mlir::dyn_cast_or_null<fir::DoLoopOp>(seq.getDefiningOp())) 248 loopRegion = &doLoop->getRegion(0); 249 ReachCollector collector(reach, loopRegion); 250 collector.collectArrayAccessFrom(seq); 251 } 252 }; 253 } // namespace 254 255 /// Find all the array operations that access the array value that is loaded by 256 /// the array load operation, `load`. 257 const llvm::SmallVector<mlir::Operation *> & 258 ArrayCopyAnalysis::arrayAccesses(ArrayLoadOp load) { 259 auto lmIter = loadMapSets.find(load); 260 if (lmIter != loadMapSets.end()) 261 return lmIter->getSecond(); 262 263 llvm::SmallVector<mlir::Operation *> accesses; 264 UseSetT visited; 265 llvm::SmallVector<mlir::OpOperand *> queue; // uses of ArrayLoad[orig] 266 267 auto appendToQueue = [&](mlir::Value val) { 268 for (mlir::OpOperand &use : val.getUses()) 269 if (!visited.count(&use)) { 270 visited.insert(&use); 271 queue.push_back(&use); 272 } 273 }; 274 275 // Build the set of uses of `original`. 276 // let USES = { uses of original fir.load } 277 appendToQueue(load); 278 279 // Process the worklist until done. 280 while (!queue.empty()) { 281 mlir::OpOperand *operand = queue.pop_back_val(); 282 mlir::Operation *owner = operand->getOwner(); 283 284 auto structuredLoop = [&](auto ro) { 285 if (auto blockArg = ro.iterArgToBlockArg(operand->get())) { 286 int64_t arg = blockArg.getArgNumber(); 287 mlir::Value output = ro.getResult(ro.finalValue() ? arg : arg - 1); 288 appendToQueue(output); 289 appendToQueue(blockArg); 290 } 291 }; 292 // TODO: this need to be updated to use the control-flow interface. 293 auto branchOp = [&](mlir::Block *dest, OperandRange operands) { 294 if (operands.empty()) 295 return; 296 297 // Check if this operand is within the range. 298 unsigned operandIndex = operand->getOperandNumber(); 299 unsigned operandsStart = operands.getBeginOperandIndex(); 300 if (operandIndex < operandsStart || 301 operandIndex >= (operandsStart + operands.size())) 302 return; 303 304 // Index the successor. 305 unsigned argIndex = operandIndex - operandsStart; 306 appendToQueue(dest->getArgument(argIndex)); 307 }; 308 // Thread uses into structured loop bodies and return value uses. 309 if (auto ro = mlir::dyn_cast<DoLoopOp>(owner)) { 310 structuredLoop(ro); 311 } else if (auto ro = mlir::dyn_cast<IterWhileOp>(owner)) { 312 structuredLoop(ro); 313 } else if (auto rs = mlir::dyn_cast<ResultOp>(owner)) { 314 // Thread any uses of fir.if that return the marked array value. 315 if (auto ifOp = rs->getParentOfType<fir::IfOp>()) 316 appendToQueue(ifOp.getResult(operand->getOperandNumber())); 317 } else if (mlir::isa<ArrayFetchOp>(owner)) { 318 // Keep track of array value fetches. 319 LLVM_DEBUG(llvm::dbgs() 320 << "add fetch {" << *owner << "} to array value set\n"); 321 accesses.push_back(owner); 322 } else if (auto update = mlir::dyn_cast<ArrayUpdateOp>(owner)) { 323 // Keep track of array value updates and thread the return value uses. 324 LLVM_DEBUG(llvm::dbgs() 325 << "add update {" << *owner << "} to array value set\n"); 326 accesses.push_back(owner); 327 appendToQueue(update.getResult()); 328 } else if (auto update = mlir::dyn_cast<ArrayModifyOp>(owner)) { 329 // Keep track of array value modification and thread the return value 330 // uses. 331 LLVM_DEBUG(llvm::dbgs() 332 << "add modify {" << *owner << "} to array value set\n"); 333 accesses.push_back(owner); 334 appendToQueue(update.getResult(1)); 335 } else if (auto br = mlir::dyn_cast<mlir::BranchOp>(owner)) { 336 branchOp(br.getDest(), br.getDestOperands()); 337 } else if (auto br = mlir::dyn_cast<mlir::CondBranchOp>(owner)) { 338 branchOp(br.getTrueDest(), br.getTrueOperands()); 339 branchOp(br.getFalseDest(), br.getFalseOperands()); 340 } else if (mlir::isa<ArrayMergeStoreOp>(owner)) { 341 // do nothing 342 } else { 343 llvm::report_fatal_error("array value reached unexpected op"); 344 } 345 } 346 return loadMapSets.insert({load, accesses}).first->getSecond(); 347 } 348 349 /// Is there a conflict between the array value that was updated and to be 350 /// stored to `st` and the set of arrays loaded (`reach`) and used to compute 351 /// the updated value? 352 static bool conflictOnLoad(llvm::ArrayRef<mlir::Operation *> reach, 353 ArrayMergeStoreOp st) { 354 mlir::Value load; 355 mlir::Value addr = st.memref(); 356 auto stEleTy = fir::dyn_cast_ptrOrBoxEleTy(addr.getType()); 357 for (auto *op : reach) { 358 auto ld = mlir::dyn_cast<ArrayLoadOp>(op); 359 if (!ld) 360 continue; 361 mlir::Type ldTy = ld.memref().getType(); 362 if (auto boxTy = ldTy.dyn_cast<fir::BoxType>()) 363 ldTy = boxTy.getEleTy(); 364 if (ldTy.isa<fir::PointerType>() && stEleTy == dyn_cast_ptrEleTy(ldTy)) 365 return true; 366 if (ld.memref() == addr) { 367 if (ld.getResult() != st.original()) 368 return true; 369 if (load) 370 return true; 371 load = ld; 372 } 373 } 374 return false; 375 } 376 377 /// Check if there is any potential conflict in the chained update operations 378 /// (ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp) while merging back to the 379 /// array. A potential conflict is detected if two operations work on the same 380 /// indices. 381 static bool conflictOnMerge(llvm::ArrayRef<mlir::Operation *> accesses) { 382 if (accesses.size() < 2) 383 return false; 384 llvm::SmallVector<mlir::Value> indices; 385 LLVM_DEBUG(llvm::dbgs() << "check merge conflict on with " << accesses.size() 386 << " accesses on the list\n"); 387 for (auto *op : accesses) { 388 assert((mlir::isa<ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp>(op)) && 389 "unexpected operation in analysis"); 390 llvm::SmallVector<mlir::Value> compareVector; 391 if (auto u = mlir::dyn_cast<ArrayUpdateOp>(op)) { 392 if (indices.empty()) { 393 indices = u.indices(); 394 continue; 395 } 396 compareVector = u.indices(); 397 } else if (auto f = mlir::dyn_cast<ArrayModifyOp>(op)) { 398 if (indices.empty()) { 399 indices = f.indices(); 400 continue; 401 } 402 compareVector = f.indices(); 403 } else if (auto f = mlir::dyn_cast<ArrayFetchOp>(op)) { 404 if (indices.empty()) { 405 indices = f.indices(); 406 continue; 407 } 408 compareVector = f.indices(); 409 } 410 if (compareVector != indices) 411 return true; 412 LLVM_DEBUG(llvm::dbgs() << "vectors compare equal\n"); 413 } 414 return false; 415 } 416 417 // Are either of types of conflicts present? 418 inline bool conflictDetected(llvm::ArrayRef<mlir::Operation *> reach, 419 llvm::ArrayRef<mlir::Operation *> accesses, 420 ArrayMergeStoreOp st) { 421 return conflictOnLoad(reach, st) || conflictOnMerge(accesses); 422 } 423 424 /// Constructor of the array copy analysis. 425 /// This performs the analysis and saves the intermediate results. 426 void ArrayCopyAnalysis::construct(mlir::Operation *topLevelOp) { 427 topLevelOp->walk([&](Operation *op) { 428 if (auto st = mlir::dyn_cast<fir::ArrayMergeStoreOp>(op)) { 429 llvm::SmallVector<Operation *> values; 430 ReachCollector::reachingValues(values, st.sequence()); 431 const llvm::SmallVector<Operation *> &accesses = 432 arrayAccesses(mlir::cast<ArrayLoadOp>(st.original().getDefiningOp())); 433 if (conflictDetected(values, accesses, st)) { 434 LLVM_DEBUG(llvm::dbgs() 435 << "CONFLICT: copies required for " << st << '\n' 436 << " adding conflicts on: " << op << " and " 437 << st.original() << '\n'); 438 conflicts.insert(op); 439 conflicts.insert(st.original().getDefiningOp()); 440 } 441 auto *ld = st.original().getDefiningOp(); 442 LLVM_DEBUG(llvm::dbgs() 443 << "map: adding {" << *ld << " -> " << st << "}\n"); 444 useMap.insert({ld, op}); 445 } else if (auto load = mlir::dyn_cast<ArrayLoadOp>(op)) { 446 const llvm::SmallVector<mlir::Operation *> &accesses = 447 arrayAccesses(load); 448 LLVM_DEBUG(llvm::dbgs() << "process load: " << load 449 << ", accesses: " << accesses.size() << '\n'); 450 for (auto *acc : accesses) { 451 LLVM_DEBUG(llvm::dbgs() << " access: " << *acc << '\n'); 452 assert((mlir::isa<ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp>(acc))); 453 if (!useMap.insert({acc, op}).second) { 454 mlir::emitError( 455 load.getLoc(), 456 "The parallel semantics of multiple array_merge_stores per " 457 "array_load are not supported."); 458 return; 459 } 460 LLVM_DEBUG(llvm::dbgs() 461 << "map: adding {" << *acc << "} -> {" << load << "}\n"); 462 } 463 } 464 }); 465 } 466 467 namespace { 468 class ArrayLoadConversion : public mlir::OpRewritePattern<ArrayLoadOp> { 469 public: 470 using OpRewritePattern::OpRewritePattern; 471 472 mlir::LogicalResult 473 matchAndRewrite(ArrayLoadOp load, 474 mlir::PatternRewriter &rewriter) const override { 475 LLVM_DEBUG(llvm::dbgs() << "replace load " << load << " with undef.\n"); 476 rewriter.replaceOpWithNewOp<UndefOp>(load, load.getType()); 477 return mlir::success(); 478 } 479 }; 480 481 class ArrayMergeStoreConversion 482 : public mlir::OpRewritePattern<ArrayMergeStoreOp> { 483 public: 484 using OpRewritePattern::OpRewritePattern; 485 486 mlir::LogicalResult 487 matchAndRewrite(ArrayMergeStoreOp store, 488 mlir::PatternRewriter &rewriter) const override { 489 LLVM_DEBUG(llvm::dbgs() << "marking store " << store << " as dead.\n"); 490 rewriter.eraseOp(store); 491 return mlir::success(); 492 } 493 }; 494 } // namespace 495 496 static mlir::Type getEleTy(mlir::Type ty) { 497 if (auto t = dyn_cast_ptrEleTy(ty)) 498 ty = t; 499 if (auto t = ty.dyn_cast<SequenceType>()) 500 ty = t.getEleTy(); 501 // FIXME: keep ptr/heap/ref information. 502 return ReferenceType::get(ty); 503 } 504 505 // Extract extents from the ShapeOp/ShapeShiftOp into the result vector. 506 // TODO: getExtents on op should return a ValueRange instead of a vector. 507 static void getExtents(llvm::SmallVectorImpl<mlir::Value> &result, 508 mlir::Value shape) { 509 auto *shapeOp = shape.getDefiningOp(); 510 if (auto s = mlir::dyn_cast<fir::ShapeOp>(shapeOp)) { 511 auto e = s.getExtents(); 512 result.insert(result.end(), e.begin(), e.end()); 513 return; 514 } 515 if (auto s = mlir::dyn_cast<fir::ShapeShiftOp>(shapeOp)) { 516 auto e = s.getExtents(); 517 result.insert(result.end(), e.begin(), e.end()); 518 return; 519 } 520 llvm::report_fatal_error("not a fir.shape/fir.shape_shift op"); 521 } 522 523 // Place the extents of the array loaded by an ArrayLoadOp into the result 524 // vector and return a ShapeOp/ShapeShiftOp with the corresponding extents. If 525 // the ArrayLoadOp is loading a fir.box, code will be generated to read the 526 // extents from the fir.box, and a the retunred ShapeOp is built with the read 527 // extents. 528 // Otherwise, the extents will be extracted from the ShapeOp/ShapeShiftOp 529 // argument of the ArrayLoadOp that is returned. 530 static mlir::Value 531 getOrReadExtentsAndShapeOp(mlir::Location loc, mlir::PatternRewriter &rewriter, 532 fir::ArrayLoadOp loadOp, 533 llvm::SmallVectorImpl<mlir::Value> &result) { 534 assert(result.empty()); 535 if (auto boxTy = loadOp.memref().getType().dyn_cast<fir::BoxType>()) { 536 auto rank = fir::dyn_cast_ptrOrBoxEleTy(boxTy) 537 .cast<fir::SequenceType>() 538 .getDimension(); 539 auto idxTy = rewriter.getIndexType(); 540 for (decltype(rank) dim = 0; dim < rank; ++dim) { 541 auto dimVal = rewriter.create<arith::ConstantIndexOp>(loc, dim); 542 auto dimInfo = rewriter.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, 543 loadOp.memref(), dimVal); 544 result.emplace_back(dimInfo.getResult(1)); 545 } 546 auto shapeType = fir::ShapeType::get(rewriter.getContext(), rank); 547 return rewriter.create<fir::ShapeOp>(loc, shapeType, result); 548 } 549 getExtents(result, loadOp.shape()); 550 return loadOp.shape(); 551 } 552 553 static mlir::Type toRefType(mlir::Type ty) { 554 if (fir::isa_ref_type(ty)) 555 return ty; 556 return fir::ReferenceType::get(ty); 557 } 558 559 static mlir::Value 560 genCoorOp(mlir::PatternRewriter &rewriter, mlir::Location loc, mlir::Type eleTy, 561 mlir::Type resTy, mlir::Value alloc, mlir::Value shape, 562 mlir::Value slice, mlir::ValueRange indices, 563 mlir::ValueRange typeparams, bool skipOrig = false) { 564 llvm::SmallVector<mlir::Value> originated; 565 if (skipOrig) 566 originated.assign(indices.begin(), indices.end()); 567 else 568 originated = fir::factory::originateIndices(loc, rewriter, alloc.getType(), 569 shape, indices); 570 auto seqTy = fir::dyn_cast_ptrOrBoxEleTy(alloc.getType()); 571 assert(seqTy && seqTy.isa<fir::SequenceType>()); 572 const auto dimension = seqTy.cast<fir::SequenceType>().getDimension(); 573 mlir::Value result = rewriter.create<fir::ArrayCoorOp>( 574 loc, eleTy, alloc, shape, slice, 575 llvm::ArrayRef<mlir::Value>{originated}.take_front(dimension), 576 typeparams); 577 if (dimension < originated.size()) 578 result = rewriter.create<fir::CoordinateOp>( 579 loc, resTy, result, 580 llvm::ArrayRef<mlir::Value>{originated}.drop_front(dimension)); 581 return result; 582 } 583 584 namespace { 585 /// Conversion of fir.array_update and fir.array_modify Ops. 586 /// If there is a conflict for the update, then we need to perform a 587 /// copy-in/copy-out to preserve the original values of the array. If there is 588 /// no conflict, then it is save to eschew making any copies. 589 template <typename ArrayOp> 590 class ArrayUpdateConversionBase : public mlir::OpRewritePattern<ArrayOp> { 591 public: 592 explicit ArrayUpdateConversionBase(mlir::MLIRContext *ctx, 593 const ArrayCopyAnalysis &a, 594 const OperationUseMapT &m) 595 : mlir::OpRewritePattern<ArrayOp>{ctx}, analysis{a}, useMap{m} {} 596 597 void genArrayCopy(mlir::Location loc, mlir::PatternRewriter &rewriter, 598 mlir::Value dst, mlir::Value src, mlir::Value shapeOp, 599 mlir::Type arrTy) const { 600 auto insPt = rewriter.saveInsertionPoint(); 601 llvm::SmallVector<mlir::Value> indices; 602 llvm::SmallVector<mlir::Value> extents; 603 getExtents(extents, shapeOp); 604 // Build loop nest from column to row. 605 for (auto sh : llvm::reverse(extents)) { 606 auto idxTy = rewriter.getIndexType(); 607 auto ubi = rewriter.create<fir::ConvertOp>(loc, idxTy, sh); 608 auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); 609 auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1); 610 auto ub = rewriter.create<arith::SubIOp>(loc, idxTy, ubi, one); 611 auto loop = rewriter.create<fir::DoLoopOp>(loc, zero, ub, one); 612 rewriter.setInsertionPointToStart(loop.getBody()); 613 indices.push_back(loop.getInductionVar()); 614 } 615 // Reverse the indices so they are in column-major order. 616 std::reverse(indices.begin(), indices.end()); 617 auto ty = getEleTy(arrTy); 618 auto fromAddr = rewriter.create<fir::ArrayCoorOp>( 619 loc, ty, src, shapeOp, mlir::Value{}, 620 fir::factory::originateIndices(loc, rewriter, src.getType(), shapeOp, 621 indices), 622 mlir::ValueRange{}); 623 auto load = rewriter.create<fir::LoadOp>(loc, fromAddr); 624 auto toAddr = rewriter.create<fir::ArrayCoorOp>( 625 loc, ty, dst, shapeOp, mlir::Value{}, 626 fir::factory::originateIndices(loc, rewriter, dst.getType(), shapeOp, 627 indices), 628 mlir::ValueRange{}); 629 rewriter.create<fir::StoreOp>(loc, load, toAddr); 630 rewriter.restoreInsertionPoint(insPt); 631 } 632 633 /// Copy the RHS element into the LHS and insert copy-in/copy-out between a 634 /// temp and the LHS if the analysis found potential overlaps between the RHS 635 /// and LHS arrays. The element copy generator must be provided through \p 636 /// assignElement. \p update must be the ArrayUpdateOp or the ArrayModifyOp. 637 /// Returns the address of the LHS element inside the loop and the LHS 638 /// ArrayLoad result. 639 std::pair<mlir::Value, mlir::Value> 640 materializeAssignment(mlir::Location loc, mlir::PatternRewriter &rewriter, 641 ArrayOp update, 642 llvm::function_ref<void(mlir::Value)> assignElement, 643 mlir::Type lhsEltRefType) const { 644 auto *op = update.getOperation(); 645 mlir::Operation *loadOp = useMap.lookup(op); 646 auto load = mlir::cast<ArrayLoadOp>(loadOp); 647 LLVM_DEBUG(llvm::outs() << "does " << load << " have a conflict?\n"); 648 if (analysis.hasPotentialConflict(loadOp)) { 649 // If there is a conflict between the arrays, then we copy the lhs array 650 // to a temporary, update the temporary, and copy the temporary back to 651 // the lhs array. This yields Fortran's copy-in copy-out array semantics. 652 LLVM_DEBUG(llvm::outs() << "Yes, conflict was found\n"); 653 rewriter.setInsertionPoint(loadOp); 654 // Copy in. 655 llvm::SmallVector<mlir::Value> extents; 656 mlir::Value shapeOp = 657 getOrReadExtentsAndShapeOp(loc, rewriter, load, extents); 658 auto allocmem = rewriter.create<AllocMemOp>( 659 loc, dyn_cast_ptrOrBoxEleTy(load.memref().getType()), 660 load.typeparams(), extents); 661 genArrayCopy(load.getLoc(), rewriter, allocmem, load.memref(), shapeOp, 662 load.getType()); 663 rewriter.setInsertionPoint(op); 664 mlir::Value coor = genCoorOp( 665 rewriter, loc, getEleTy(load.getType()), lhsEltRefType, allocmem, 666 shapeOp, load.slice(), update.indices(), load.typeparams(), 667 update->hasAttr(fir::factory::attrFortranArrayOffsets())); 668 assignElement(coor); 669 mlir::Operation *storeOp = useMap.lookup(loadOp); 670 auto store = mlir::cast<ArrayMergeStoreOp>(storeOp); 671 rewriter.setInsertionPoint(storeOp); 672 // Copy out. 673 genArrayCopy(store.getLoc(), rewriter, store.memref(), allocmem, shapeOp, 674 load.getType()); 675 rewriter.create<FreeMemOp>(loc, allocmem); 676 return {coor, load.getResult()}; 677 } 678 // Otherwise, when there is no conflict (a possible loop-carried 679 // dependence), the lhs array can be updated in place. 680 LLVM_DEBUG(llvm::outs() << "No, conflict wasn't found\n"); 681 rewriter.setInsertionPoint(op); 682 auto coorTy = getEleTy(load.getType()); 683 mlir::Value coor = genCoorOp( 684 rewriter, loc, coorTy, lhsEltRefType, load.memref(), load.shape(), 685 load.slice(), update.indices(), load.typeparams(), 686 update->hasAttr(fir::factory::attrFortranArrayOffsets())); 687 assignElement(coor); 688 return {coor, load.getResult()}; 689 } 690 691 private: 692 const ArrayCopyAnalysis &analysis; 693 const OperationUseMapT &useMap; 694 }; 695 696 class ArrayUpdateConversion : public ArrayUpdateConversionBase<ArrayUpdateOp> { 697 public: 698 explicit ArrayUpdateConversion(mlir::MLIRContext *ctx, 699 const ArrayCopyAnalysis &a, 700 const OperationUseMapT &m) 701 : ArrayUpdateConversionBase{ctx, a, m} {} 702 703 mlir::LogicalResult 704 matchAndRewrite(ArrayUpdateOp update, 705 mlir::PatternRewriter &rewriter) const override { 706 auto loc = update.getLoc(); 707 auto assignElement = [&](mlir::Value coor) { 708 rewriter.create<fir::StoreOp>(loc, update.merge(), coor); 709 }; 710 auto lhsEltRefType = toRefType(update.merge().getType()); 711 auto [_, lhsLoadResult] = materializeAssignment( 712 loc, rewriter, update, assignElement, lhsEltRefType); 713 update.replaceAllUsesWith(lhsLoadResult); 714 rewriter.replaceOp(update, lhsLoadResult); 715 return mlir::success(); 716 } 717 }; 718 719 class ArrayModifyConversion : public ArrayUpdateConversionBase<ArrayModifyOp> { 720 public: 721 explicit ArrayModifyConversion(mlir::MLIRContext *ctx, 722 const ArrayCopyAnalysis &a, 723 const OperationUseMapT &m) 724 : ArrayUpdateConversionBase{ctx, a, m} {} 725 726 mlir::LogicalResult 727 matchAndRewrite(ArrayModifyOp modify, 728 mlir::PatternRewriter &rewriter) const override { 729 auto loc = modify.getLoc(); 730 auto assignElement = [](mlir::Value) { 731 // Assignment already materialized by lowering using lhs element address. 732 }; 733 auto lhsEltRefType = modify.getResult(0).getType(); 734 auto [lhsEltCoor, lhsLoadResult] = materializeAssignment( 735 loc, rewriter, modify, assignElement, lhsEltRefType); 736 modify.replaceAllUsesWith(mlir::ValueRange{lhsEltCoor, lhsLoadResult}); 737 rewriter.replaceOp(modify, mlir::ValueRange{lhsEltCoor, lhsLoadResult}); 738 return mlir::success(); 739 } 740 }; 741 742 class ArrayFetchConversion : public mlir::OpRewritePattern<ArrayFetchOp> { 743 public: 744 explicit ArrayFetchConversion(mlir::MLIRContext *ctx, 745 const OperationUseMapT &m) 746 : OpRewritePattern{ctx}, useMap{m} {} 747 748 mlir::LogicalResult 749 matchAndRewrite(ArrayFetchOp fetch, 750 mlir::PatternRewriter &rewriter) const override { 751 auto *op = fetch.getOperation(); 752 rewriter.setInsertionPoint(op); 753 auto load = mlir::cast<ArrayLoadOp>(useMap.lookup(op)); 754 auto loc = fetch.getLoc(); 755 mlir::Value coor = 756 genCoorOp(rewriter, loc, getEleTy(load.getType()), 757 toRefType(fetch.getType()), load.memref(), load.shape(), 758 load.slice(), fetch.indices(), load.typeparams(), 759 fetch->hasAttr(fir::factory::attrFortranArrayOffsets())); 760 rewriter.replaceOpWithNewOp<fir::LoadOp>(fetch, coor); 761 return mlir::success(); 762 } 763 764 private: 765 const OperationUseMapT &useMap; 766 }; 767 } // namespace 768 769 namespace { 770 class ArrayValueCopyConverter 771 : public ArrayValueCopyBase<ArrayValueCopyConverter> { 772 public: 773 void runOnOperation() override { 774 auto func = getOperation(); 775 LLVM_DEBUG(llvm::dbgs() << "\n\narray-value-copy pass on function '" 776 << func.getName() << "'\n"); 777 auto *context = &getContext(); 778 779 // Perform the conflict analysis. 780 auto &analysis = getAnalysis<ArrayCopyAnalysis>(); 781 const auto &useMap = analysis.getUseMap(); 782 783 // Phase 1 is performing a rewrite on the array accesses. Once all the 784 // array accesses are rewritten we can go on phase 2. 785 // Phase 2 gets rid of the useless copy-in/copyout operations. The copy-in 786 // /copy-out refers the Fortran copy-in/copy-out semantics on statements. 787 mlir::OwningRewritePatternList patterns1(context); 788 patterns1.insert<ArrayFetchConversion>(context, useMap); 789 patterns1.insert<ArrayUpdateConversion>(context, analysis, useMap); 790 patterns1.insert<ArrayModifyConversion>(context, analysis, useMap); 791 mlir::ConversionTarget target(*context); 792 target.addLegalDialect<FIROpsDialect, mlir::scf::SCFDialect, 793 mlir::arith::ArithmeticDialect, 794 mlir::StandardOpsDialect>(); 795 target.addIllegalOp<ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp>(); 796 // Rewrite the array fetch and array update ops. 797 if (mlir::failed( 798 mlir::applyPartialConversion(func, target, std::move(patterns1)))) { 799 mlir::emitError(mlir::UnknownLoc::get(context), 800 "failure in array-value-copy pass, phase 1"); 801 signalPassFailure(); 802 } 803 804 mlir::OwningRewritePatternList patterns2(context); 805 patterns2.insert<ArrayLoadConversion>(context); 806 patterns2.insert<ArrayMergeStoreConversion>(context); 807 target.addIllegalOp<ArrayLoadOp, ArrayMergeStoreOp>(); 808 if (mlir::failed( 809 mlir::applyPartialConversion(func, target, std::move(patterns2)))) { 810 mlir::emitError(mlir::UnknownLoc::get(context), 811 "failure in array-value-copy pass, phase 2"); 812 signalPassFailure(); 813 } 814 } 815 }; 816 } // namespace 817 818 std::unique_ptr<mlir::Pass> fir::createArrayValueCopyPass() { 819 return std::make_unique<ArrayValueCopyConverter>(); 820 } 821