1 //===-- ArrayValueCopy.cpp ------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "flang/Optimizer/Builder/BoxValue.h" 11 #include "flang/Optimizer/Builder/FIRBuilder.h" 12 #include "flang/Optimizer/Builder/Factory.h" 13 #include "flang/Optimizer/Dialect/FIRDialect.h" 14 #include "flang/Optimizer/Support/FIRContext.h" 15 #include "flang/Optimizer/Transforms/Passes.h" 16 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 17 #include "mlir/Dialect/SCF/SCF.h" 18 #include "mlir/Transforms/DialectConversion.h" 19 #include "llvm/Support/Debug.h" 20 21 #define DEBUG_TYPE "flang-array-value-copy" 22 23 using namespace fir; 24 25 using OperationUseMapT = llvm::DenseMap<mlir::Operation *, mlir::Operation *>; 26 27 namespace { 28 29 /// Array copy analysis. 30 /// Perform an interference analysis between array values. 31 /// 32 /// Lowering will generate a sequence of the following form. 33 /// ```mlir 34 /// %a_1 = fir.array_load %array_1(%shape) : ... 35 /// ... 36 /// %a_j = fir.array_load %array_j(%shape) : ... 37 /// ... 38 /// %a_n = fir.array_load %array_n(%shape) : ... 39 /// ... 40 /// %v_i = fir.array_fetch %a_i, ... 41 /// %a_j1 = fir.array_update %a_j, ... 42 /// ... 43 /// fir.array_merge_store %a_j, %a_jn to %array_j : ... 44 /// ``` 45 /// 46 /// The analysis is to determine if there are any conflicts. A conflict is when 47 /// one the following cases occurs. 48 /// 49 /// 1. There is an `array_update` to an array value, a_j, such that a_j was 50 /// loaded from the same array memory reference (array_j) but with a different 51 /// shape as the other array values a_i, where i != j. [Possible overlapping 52 /// arrays.] 53 /// 54 /// 2. There is either an array_fetch or array_update of a_j with a different 55 /// set of index values. [Possible loop-carried dependence.] 56 /// 57 /// If none of the array values overlap in storage and the accesses are not 58 /// loop-carried, then the arrays are conflict-free and no copies are required. 59 class ArrayCopyAnalysis { 60 public: 61 using ConflictSetT = llvm::SmallPtrSet<mlir::Operation *, 16>; 62 using UseSetT = llvm::SmallPtrSet<mlir::OpOperand *, 8>; 63 using LoadMapSetsT = 64 llvm::DenseMap<mlir::Operation *, SmallVector<Operation *>>; 65 66 ArrayCopyAnalysis(mlir::Operation *op) : operation{op} { construct(op); } 67 68 mlir::Operation *getOperation() const { return operation; } 69 70 /// Return true iff the `array_merge_store` has potential conflicts. 71 bool hasPotentialConflict(mlir::Operation *op) const { 72 LLVM_DEBUG(llvm::dbgs() 73 << "looking for a conflict on " << *op 74 << " and the set has a total of " << conflicts.size() << '\n'); 75 return conflicts.contains(op); 76 } 77 78 /// Return the use map. The use map maps array fetch and update operations 79 /// back to the array load that is the original source of the array value. 80 const OperationUseMapT &getUseMap() const { return useMap; } 81 82 /// Find all the array operations that access the array value that is loaded 83 /// by the array load operation, `load`. 84 const llvm::SmallVector<mlir::Operation *> &arrayAccesses(ArrayLoadOp load); 85 86 private: 87 void construct(mlir::Operation *topLevelOp); 88 89 mlir::Operation *operation; // operation that analysis ran upon 90 ConflictSetT conflicts; // set of conflicts (loads and merge stores) 91 OperationUseMapT useMap; 92 LoadMapSetsT loadMapSets; 93 }; 94 } // namespace 95 96 namespace { 97 /// Helper class to collect all array operations that produced an array value. 98 class ReachCollector { 99 private: 100 // If provided, the `loopRegion` is the body of a loop that produces the array 101 // of interest. 102 ReachCollector(llvm::SmallVectorImpl<mlir::Operation *> &reach, 103 mlir::Region *loopRegion) 104 : reach{reach}, loopRegion{loopRegion} {} 105 106 void collectArrayAccessFrom(mlir::Operation *op, mlir::ValueRange range) { 107 llvm::errs() << "COLLECT " << *op << "\n"; 108 if (range.empty()) { 109 collectArrayAccessFrom(op, mlir::Value{}); 110 return; 111 } 112 for (mlir::Value v : range) 113 collectArrayAccessFrom(v); 114 } 115 116 // TODO: Replace recursive algorithm on def-use chain with an iterative one 117 // with an explicit stack. 118 void collectArrayAccessFrom(mlir::Operation *op, mlir::Value val) { 119 // `val` is defined by an Op, process the defining Op. 120 // If `val` is defined by a region containing Op, we want to drill down 121 // and through that Op's region(s). 122 llvm::errs() << "COLLECT " << *op << "\n"; 123 LLVM_DEBUG(llvm::dbgs() << "popset: " << *op << '\n'); 124 auto popFn = [&](auto rop) { 125 assert(val && "op must have a result value"); 126 auto resNum = val.cast<mlir::OpResult>().getResultNumber(); 127 llvm::SmallVector<mlir::Value> results; 128 rop.resultToSourceOps(results, resNum); 129 for (auto u : results) 130 collectArrayAccessFrom(u); 131 }; 132 if (auto rop = mlir::dyn_cast<fir::DoLoopOp>(op)) { 133 popFn(rop); 134 return; 135 } 136 if (auto rop = mlir::dyn_cast<fir::IfOp>(op)) { 137 popFn(rop); 138 return; 139 } 140 if (auto mergeStore = mlir::dyn_cast<ArrayMergeStoreOp>(op)) { 141 if (opIsInsideLoops(mergeStore)) 142 collectArrayAccessFrom(mergeStore.getSequence()); 143 return; 144 } 145 146 if (mlir::isa<AllocaOp, AllocMemOp>(op)) { 147 // Look for any stores inside the loops, and collect an array operation 148 // that produced the value being stored to it. 149 for (mlir::Operation *user : op->getUsers()) 150 if (auto store = mlir::dyn_cast<fir::StoreOp>(user)) 151 if (opIsInsideLoops(store)) 152 collectArrayAccessFrom(store.getValue()); 153 return; 154 } 155 156 // Otherwise, Op does not contain a region so just chase its operands. 157 if (mlir::isa<ArrayLoadOp, ArrayUpdateOp, ArrayModifyOp, ArrayFetchOp>( 158 op)) { 159 LLVM_DEBUG(llvm::dbgs() << "add " << *op << " to reachable set\n"); 160 reach.emplace_back(op); 161 } 162 // Array modify assignment is performed on the result. So the analysis 163 // must look at the what is done with the result. 164 if (mlir::isa<ArrayModifyOp>(op)) 165 for (mlir::Operation *user : op->getResult(0).getUsers()) 166 followUsers(user); 167 168 for (auto u : op->getOperands()) 169 collectArrayAccessFrom(u); 170 } 171 172 void collectArrayAccessFrom(mlir::BlockArgument ba) { 173 auto *parent = ba.getOwner()->getParentOp(); 174 // If inside an Op holding a region, the block argument corresponds to an 175 // argument passed to the containing Op. 176 auto popFn = [&](auto rop) { 177 collectArrayAccessFrom(rop.blockArgToSourceOp(ba.getArgNumber())); 178 }; 179 if (auto rop = mlir::dyn_cast<DoLoopOp>(parent)) { 180 popFn(rop); 181 return; 182 } 183 if (auto rop = mlir::dyn_cast<IterWhileOp>(parent)) { 184 popFn(rop); 185 return; 186 } 187 // Otherwise, a block argument is provided via the pred blocks. 188 for (auto *pred : ba.getOwner()->getPredecessors()) { 189 auto u = pred->getTerminator()->getOperand(ba.getArgNumber()); 190 collectArrayAccessFrom(u); 191 } 192 } 193 194 // Recursively trace operands to find all array operations relating to the 195 // values merged. 196 void collectArrayAccessFrom(mlir::Value val) { 197 if (!val || visited.contains(val)) 198 return; 199 visited.insert(val); 200 201 // Process a block argument. 202 if (auto ba = val.dyn_cast<mlir::BlockArgument>()) { 203 collectArrayAccessFrom(ba); 204 return; 205 } 206 207 // Process an Op. 208 if (auto *op = val.getDefiningOp()) { 209 collectArrayAccessFrom(op, val); 210 return; 211 } 212 213 fir::emitFatalError(val.getLoc(), "unhandled value"); 214 } 215 216 /// Is \op inside the loop nest region ? 217 bool opIsInsideLoops(mlir::Operation *op) const { 218 return loopRegion && loopRegion->isAncestor(op->getParentRegion()); 219 } 220 221 /// Recursively trace the use of an operation results, calling 222 /// collectArrayAccessFrom on the direct and indirect user operands. 223 /// TODO: Replace recursive algorithm on def-use chain with an iterative one 224 /// with an explicit stack. 225 void followUsers(mlir::Operation *op) { 226 for (auto userOperand : op->getOperands()) 227 collectArrayAccessFrom(userOperand); 228 // Go through potential converts/coordinate_op. 229 for (mlir::Operation *indirectUser : op->getUsers()) 230 followUsers(indirectUser); 231 } 232 233 llvm::SmallVectorImpl<mlir::Operation *> &reach; 234 llvm::SmallPtrSet<mlir::Value, 16> visited; 235 /// Region of the loops nest that produced the array value. 236 mlir::Region *loopRegion; 237 238 public: 239 /// Return all ops that produce the array value that is stored into the 240 /// `array_merge_store`. 241 static void reachingValues(llvm::SmallVectorImpl<mlir::Operation *> &reach, 242 mlir::Value seq) { 243 reach.clear(); 244 mlir::Region *loopRegion = nullptr; 245 // Only `DoLoopOp` is tested here since array operations are currently only 246 // associated with this kind of loop. 247 if (auto doLoop = 248 mlir::dyn_cast_or_null<fir::DoLoopOp>(seq.getDefiningOp())) 249 loopRegion = &doLoop->getRegion(0); 250 ReachCollector collector(reach, loopRegion); 251 collector.collectArrayAccessFrom(seq); 252 } 253 }; 254 } // namespace 255 256 /// Find all the array operations that access the array value that is loaded by 257 /// the array load operation, `load`. 258 const llvm::SmallVector<mlir::Operation *> & 259 ArrayCopyAnalysis::arrayAccesses(ArrayLoadOp load) { 260 auto lmIter = loadMapSets.find(load); 261 if (lmIter != loadMapSets.end()) 262 return lmIter->getSecond(); 263 264 llvm::SmallVector<mlir::Operation *> accesses; 265 UseSetT visited; 266 llvm::SmallVector<mlir::OpOperand *> queue; // uses of ArrayLoad[orig] 267 268 auto appendToQueue = [&](mlir::Value val) { 269 for (mlir::OpOperand &use : val.getUses()) 270 if (!visited.count(&use)) { 271 visited.insert(&use); 272 queue.push_back(&use); 273 } 274 }; 275 276 // Build the set of uses of `original`. 277 // let USES = { uses of original fir.load } 278 appendToQueue(load); 279 280 // Process the worklist until done. 281 while (!queue.empty()) { 282 mlir::OpOperand *operand = queue.pop_back_val(); 283 mlir::Operation *owner = operand->getOwner(); 284 285 auto structuredLoop = [&](auto ro) { 286 if (auto blockArg = ro.iterArgToBlockArg(operand->get())) { 287 int64_t arg = blockArg.getArgNumber(); 288 mlir::Value output = ro.getResult(ro.getFinalValue() ? arg : arg - 1); 289 appendToQueue(output); 290 appendToQueue(blockArg); 291 } 292 }; 293 // TODO: this need to be updated to use the control-flow interface. 294 auto branchOp = [&](mlir::Block *dest, OperandRange operands) { 295 if (operands.empty()) 296 return; 297 298 // Check if this operand is within the range. 299 unsigned operandIndex = operand->getOperandNumber(); 300 unsigned operandsStart = operands.getBeginOperandIndex(); 301 if (operandIndex < operandsStart || 302 operandIndex >= (operandsStart + operands.size())) 303 return; 304 305 // Index the successor. 306 unsigned argIndex = operandIndex - operandsStart; 307 appendToQueue(dest->getArgument(argIndex)); 308 }; 309 // Thread uses into structured loop bodies and return value uses. 310 if (auto ro = mlir::dyn_cast<DoLoopOp>(owner)) { 311 structuredLoop(ro); 312 } else if (auto ro = mlir::dyn_cast<IterWhileOp>(owner)) { 313 structuredLoop(ro); 314 } else if (auto rs = mlir::dyn_cast<ResultOp>(owner)) { 315 // Thread any uses of fir.if that return the marked array value. 316 if (auto ifOp = rs->getParentOfType<fir::IfOp>()) 317 appendToQueue(ifOp.getResult(operand->getOperandNumber())); 318 } else if (mlir::isa<ArrayFetchOp>(owner)) { 319 // Keep track of array value fetches. 320 LLVM_DEBUG(llvm::dbgs() 321 << "add fetch {" << *owner << "} to array value set\n"); 322 accesses.push_back(owner); 323 } else if (auto update = mlir::dyn_cast<ArrayUpdateOp>(owner)) { 324 // Keep track of array value updates and thread the return value uses. 325 LLVM_DEBUG(llvm::dbgs() 326 << "add update {" << *owner << "} to array value set\n"); 327 accesses.push_back(owner); 328 appendToQueue(update.getResult()); 329 } else if (auto update = mlir::dyn_cast<ArrayModifyOp>(owner)) { 330 // Keep track of array value modification and thread the return value 331 // uses. 332 LLVM_DEBUG(llvm::dbgs() 333 << "add modify {" << *owner << "} to array value set\n"); 334 accesses.push_back(owner); 335 appendToQueue(update.getResult(1)); 336 } else if (auto br = mlir::dyn_cast<mlir::cf::BranchOp>(owner)) { 337 branchOp(br.getDest(), br.getDestOperands()); 338 } else if (auto br = mlir::dyn_cast<mlir::cf::CondBranchOp>(owner)) { 339 branchOp(br.getTrueDest(), br.getTrueOperands()); 340 branchOp(br.getFalseDest(), br.getFalseOperands()); 341 } else if (mlir::isa<ArrayMergeStoreOp>(owner)) { 342 // do nothing 343 } else { 344 llvm::report_fatal_error("array value reached unexpected op"); 345 } 346 } 347 return loadMapSets.insert({load, accesses}).first->getSecond(); 348 } 349 350 /// Is there a conflict between the array value that was updated and to be 351 /// stored to `st` and the set of arrays loaded (`reach`) and used to compute 352 /// the updated value? 353 static bool conflictOnLoad(llvm::ArrayRef<mlir::Operation *> reach, 354 ArrayMergeStoreOp st) { 355 mlir::Value load; 356 mlir::Value addr = st.getMemref(); 357 auto stEleTy = fir::dyn_cast_ptrOrBoxEleTy(addr.getType()); 358 for (auto *op : reach) { 359 auto ld = mlir::dyn_cast<ArrayLoadOp>(op); 360 if (!ld) 361 continue; 362 mlir::Type ldTy = ld.getMemref().getType(); 363 if (auto boxTy = ldTy.dyn_cast<fir::BoxType>()) 364 ldTy = boxTy.getEleTy(); 365 if (ldTy.isa<fir::PointerType>() && stEleTy == dyn_cast_ptrEleTy(ldTy)) 366 return true; 367 if (ld.getMemref() == addr) { 368 if (ld.getResult() != st.getOriginal()) 369 return true; 370 if (load) 371 return true; 372 load = ld; 373 } 374 } 375 return false; 376 } 377 378 /// Check if there is any potential conflict in the chained update operations 379 /// (ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp) while merging back to the 380 /// array. A potential conflict is detected if two operations work on the same 381 /// indices. 382 static bool conflictOnMerge(llvm::ArrayRef<mlir::Operation *> accesses) { 383 if (accesses.size() < 2) 384 return false; 385 llvm::SmallVector<mlir::Value> indices; 386 LLVM_DEBUG(llvm::dbgs() << "check merge conflict on with " << accesses.size() 387 << " accesses on the list\n"); 388 for (auto *op : accesses) { 389 assert((mlir::isa<ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp>(op)) && 390 "unexpected operation in analysis"); 391 llvm::SmallVector<mlir::Value> compareVector; 392 if (auto u = mlir::dyn_cast<ArrayUpdateOp>(op)) { 393 if (indices.empty()) { 394 indices = u.getIndices(); 395 continue; 396 } 397 compareVector = u.getIndices(); 398 } else if (auto f = mlir::dyn_cast<ArrayModifyOp>(op)) { 399 if (indices.empty()) { 400 indices = f.getIndices(); 401 continue; 402 } 403 compareVector = f.getIndices(); 404 } else if (auto f = mlir::dyn_cast<ArrayFetchOp>(op)) { 405 if (indices.empty()) { 406 indices = f.getIndices(); 407 continue; 408 } 409 compareVector = f.getIndices(); 410 } 411 if (compareVector != indices) 412 return true; 413 LLVM_DEBUG(llvm::dbgs() << "vectors compare equal\n"); 414 } 415 return false; 416 } 417 418 // Are either of types of conflicts present? 419 inline bool conflictDetected(llvm::ArrayRef<mlir::Operation *> reach, 420 llvm::ArrayRef<mlir::Operation *> accesses, 421 ArrayMergeStoreOp st) { 422 return conflictOnLoad(reach, st) || conflictOnMerge(accesses); 423 } 424 425 /// Constructor of the array copy analysis. 426 /// This performs the analysis and saves the intermediate results. 427 void ArrayCopyAnalysis::construct(mlir::Operation *topLevelOp) { 428 topLevelOp->walk([&](Operation *op) { 429 if (auto st = mlir::dyn_cast<fir::ArrayMergeStoreOp>(op)) { 430 llvm::SmallVector<Operation *> values; 431 ReachCollector::reachingValues(values, st.getSequence()); 432 const llvm::SmallVector<Operation *> &accesses = arrayAccesses( 433 mlir::cast<ArrayLoadOp>(st.getOriginal().getDefiningOp())); 434 if (conflictDetected(values, accesses, st)) { 435 LLVM_DEBUG(llvm::dbgs() 436 << "CONFLICT: copies required for " << st << '\n' 437 << " adding conflicts on: " << op << " and " 438 << st.getOriginal() << '\n'); 439 conflicts.insert(op); 440 conflicts.insert(st.getOriginal().getDefiningOp()); 441 } 442 auto *ld = st.getOriginal().getDefiningOp(); 443 LLVM_DEBUG(llvm::dbgs() 444 << "map: adding {" << *ld << " -> " << st << "}\n"); 445 useMap.insert({ld, op}); 446 } else if (auto load = mlir::dyn_cast<ArrayLoadOp>(op)) { 447 const llvm::SmallVector<mlir::Operation *> &accesses = 448 arrayAccesses(load); 449 LLVM_DEBUG(llvm::dbgs() << "process load: " << load 450 << ", accesses: " << accesses.size() << '\n'); 451 for (auto *acc : accesses) { 452 LLVM_DEBUG(llvm::dbgs() << " access: " << *acc << '\n'); 453 assert((mlir::isa<ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp>(acc))); 454 if (!useMap.insert({acc, op}).second) { 455 mlir::emitError( 456 load.getLoc(), 457 "The parallel semantics of multiple array_merge_stores per " 458 "array_load are not supported."); 459 return; 460 } 461 LLVM_DEBUG(llvm::dbgs() 462 << "map: adding {" << *acc << "} -> {" << load << "}\n"); 463 } 464 } 465 }); 466 } 467 468 namespace { 469 class ArrayLoadConversion : public mlir::OpRewritePattern<ArrayLoadOp> { 470 public: 471 using OpRewritePattern::OpRewritePattern; 472 473 mlir::LogicalResult 474 matchAndRewrite(ArrayLoadOp load, 475 mlir::PatternRewriter &rewriter) const override { 476 LLVM_DEBUG(llvm::dbgs() << "replace load " << load << " with undef.\n"); 477 rewriter.replaceOpWithNewOp<UndefOp>(load, load.getType()); 478 return mlir::success(); 479 } 480 }; 481 482 class ArrayMergeStoreConversion 483 : public mlir::OpRewritePattern<ArrayMergeStoreOp> { 484 public: 485 using OpRewritePattern::OpRewritePattern; 486 487 mlir::LogicalResult 488 matchAndRewrite(ArrayMergeStoreOp store, 489 mlir::PatternRewriter &rewriter) const override { 490 LLVM_DEBUG(llvm::dbgs() << "marking store " << store << " as dead.\n"); 491 rewriter.eraseOp(store); 492 return mlir::success(); 493 } 494 }; 495 } // namespace 496 497 static mlir::Type getEleTy(mlir::Type ty) { 498 if (auto t = dyn_cast_ptrEleTy(ty)) 499 ty = t; 500 if (auto t = ty.dyn_cast<SequenceType>()) 501 ty = t.getEleTy(); 502 // FIXME: keep ptr/heap/ref information. 503 return ReferenceType::get(ty); 504 } 505 506 // Extract extents from the ShapeOp/ShapeShiftOp into the result vector. 507 // TODO: getExtents on op should return a ValueRange instead of a vector. 508 static void getExtents(llvm::SmallVectorImpl<mlir::Value> &result, 509 mlir::Value shape) { 510 auto *shapeOp = shape.getDefiningOp(); 511 if (auto s = mlir::dyn_cast<fir::ShapeOp>(shapeOp)) { 512 auto e = s.getExtents(); 513 result.insert(result.end(), e.begin(), e.end()); 514 return; 515 } 516 if (auto s = mlir::dyn_cast<fir::ShapeShiftOp>(shapeOp)) { 517 auto e = s.getExtents(); 518 result.insert(result.end(), e.begin(), e.end()); 519 return; 520 } 521 llvm::report_fatal_error("not a fir.shape/fir.shape_shift op"); 522 } 523 524 // Place the extents of the array loaded by an ArrayLoadOp into the result 525 // vector and return a ShapeOp/ShapeShiftOp with the corresponding extents. If 526 // the ArrayLoadOp is loading a fir.box, code will be generated to read the 527 // extents from the fir.box, and a the retunred ShapeOp is built with the read 528 // extents. 529 // Otherwise, the extents will be extracted from the ShapeOp/ShapeShiftOp 530 // argument of the ArrayLoadOp that is returned. 531 static mlir::Value 532 getOrReadExtentsAndShapeOp(mlir::Location loc, mlir::PatternRewriter &rewriter, 533 fir::ArrayLoadOp loadOp, 534 llvm::SmallVectorImpl<mlir::Value> &result) { 535 assert(result.empty()); 536 if (auto boxTy = loadOp.getMemref().getType().dyn_cast<fir::BoxType>()) { 537 auto rank = fir::dyn_cast_ptrOrBoxEleTy(boxTy) 538 .cast<fir::SequenceType>() 539 .getDimension(); 540 auto idxTy = rewriter.getIndexType(); 541 for (decltype(rank) dim = 0; dim < rank; ++dim) { 542 auto dimVal = rewriter.create<arith::ConstantIndexOp>(loc, dim); 543 auto dimInfo = rewriter.create<fir::BoxDimsOp>( 544 loc, idxTy, idxTy, idxTy, loadOp.getMemref(), dimVal); 545 result.emplace_back(dimInfo.getResult(1)); 546 } 547 auto shapeType = fir::ShapeType::get(rewriter.getContext(), rank); 548 return rewriter.create<fir::ShapeOp>(loc, shapeType, result); 549 } 550 getExtents(result, loadOp.getShape()); 551 return loadOp.getShape(); 552 } 553 554 static mlir::Type toRefType(mlir::Type ty) { 555 if (fir::isa_ref_type(ty)) 556 return ty; 557 return fir::ReferenceType::get(ty); 558 } 559 560 static mlir::Value 561 genCoorOp(mlir::PatternRewriter &rewriter, mlir::Location loc, mlir::Type eleTy, 562 mlir::Type resTy, mlir::Value alloc, mlir::Value shape, 563 mlir::Value slice, mlir::ValueRange indices, 564 mlir::ValueRange typeparams, bool skipOrig = false) { 565 llvm::SmallVector<mlir::Value> originated; 566 if (skipOrig) 567 originated.assign(indices.begin(), indices.end()); 568 else 569 originated = fir::factory::originateIndices(loc, rewriter, alloc.getType(), 570 shape, indices); 571 auto seqTy = fir::dyn_cast_ptrOrBoxEleTy(alloc.getType()); 572 assert(seqTy && seqTy.isa<fir::SequenceType>()); 573 const auto dimension = seqTy.cast<fir::SequenceType>().getDimension(); 574 mlir::Value result = rewriter.create<fir::ArrayCoorOp>( 575 loc, eleTy, alloc, shape, slice, 576 llvm::ArrayRef<mlir::Value>{originated}.take_front(dimension), 577 typeparams); 578 if (dimension < originated.size()) 579 result = rewriter.create<fir::CoordinateOp>( 580 loc, resTy, result, 581 llvm::ArrayRef<mlir::Value>{originated}.drop_front(dimension)); 582 return result; 583 } 584 585 namespace { 586 /// Conversion of fir.array_update and fir.array_modify Ops. 587 /// If there is a conflict for the update, then we need to perform a 588 /// copy-in/copy-out to preserve the original values of the array. If there is 589 /// no conflict, then it is save to eschew making any copies. 590 template <typename ArrayOp> 591 class ArrayUpdateConversionBase : public mlir::OpRewritePattern<ArrayOp> { 592 public: 593 explicit ArrayUpdateConversionBase(mlir::MLIRContext *ctx, 594 const ArrayCopyAnalysis &a, 595 const OperationUseMapT &m) 596 : mlir::OpRewritePattern<ArrayOp>{ctx}, analysis{a}, useMap{m} {} 597 598 void genArrayCopy(mlir::Location loc, mlir::PatternRewriter &rewriter, 599 mlir::Value dst, mlir::Value src, mlir::Value shapeOp, 600 mlir::Type arrTy) const { 601 auto insPt = rewriter.saveInsertionPoint(); 602 llvm::SmallVector<mlir::Value> indices; 603 llvm::SmallVector<mlir::Value> extents; 604 getExtents(extents, shapeOp); 605 // Build loop nest from column to row. 606 for (auto sh : llvm::reverse(extents)) { 607 auto idxTy = rewriter.getIndexType(); 608 auto ubi = rewriter.create<fir::ConvertOp>(loc, idxTy, sh); 609 auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); 610 auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1); 611 auto ub = rewriter.create<arith::SubIOp>(loc, idxTy, ubi, one); 612 auto loop = rewriter.create<fir::DoLoopOp>(loc, zero, ub, one); 613 rewriter.setInsertionPointToStart(loop.getBody()); 614 indices.push_back(loop.getInductionVar()); 615 } 616 // Reverse the indices so they are in column-major order. 617 std::reverse(indices.begin(), indices.end()); 618 auto ty = getEleTy(arrTy); 619 auto fromAddr = rewriter.create<fir::ArrayCoorOp>( 620 loc, ty, src, shapeOp, mlir::Value{}, 621 fir::factory::originateIndices(loc, rewriter, src.getType(), shapeOp, 622 indices), 623 mlir::ValueRange{}); 624 auto load = rewriter.create<fir::LoadOp>(loc, fromAddr); 625 auto toAddr = rewriter.create<fir::ArrayCoorOp>( 626 loc, ty, dst, shapeOp, mlir::Value{}, 627 fir::factory::originateIndices(loc, rewriter, dst.getType(), shapeOp, 628 indices), 629 mlir::ValueRange{}); 630 rewriter.create<fir::StoreOp>(loc, load, toAddr); 631 rewriter.restoreInsertionPoint(insPt); 632 } 633 634 /// Copy the RHS element into the LHS and insert copy-in/copy-out between a 635 /// temp and the LHS if the analysis found potential overlaps between the RHS 636 /// and LHS arrays. The element copy generator must be provided through \p 637 /// assignElement. \p update must be the ArrayUpdateOp or the ArrayModifyOp. 638 /// Returns the address of the LHS element inside the loop and the LHS 639 /// ArrayLoad result. 640 std::pair<mlir::Value, mlir::Value> 641 materializeAssignment(mlir::Location loc, mlir::PatternRewriter &rewriter, 642 ArrayOp update, 643 llvm::function_ref<void(mlir::Value)> assignElement, 644 mlir::Type lhsEltRefType) const { 645 auto *op = update.getOperation(); 646 mlir::Operation *loadOp = useMap.lookup(op); 647 auto load = mlir::cast<ArrayLoadOp>(loadOp); 648 LLVM_DEBUG(llvm::outs() << "does " << load << " have a conflict?\n"); 649 if (analysis.hasPotentialConflict(loadOp)) { 650 // If there is a conflict between the arrays, then we copy the lhs array 651 // to a temporary, update the temporary, and copy the temporary back to 652 // the lhs array. This yields Fortran's copy-in copy-out array semantics. 653 LLVM_DEBUG(llvm::outs() << "Yes, conflict was found\n"); 654 rewriter.setInsertionPoint(loadOp); 655 // Copy in. 656 llvm::SmallVector<mlir::Value> extents; 657 mlir::Value shapeOp = 658 getOrReadExtentsAndShapeOp(loc, rewriter, load, extents); 659 auto allocmem = rewriter.create<AllocMemOp>( 660 loc, dyn_cast_ptrOrBoxEleTy(load.getMemref().getType()), 661 load.getTypeparams(), extents); 662 genArrayCopy(load.getLoc(), rewriter, allocmem, load.getMemref(), shapeOp, 663 load.getType()); 664 rewriter.setInsertionPoint(op); 665 mlir::Value coor = genCoorOp( 666 rewriter, loc, getEleTy(load.getType()), lhsEltRefType, allocmem, 667 shapeOp, load.getSlice(), update.getIndices(), load.getTypeparams(), 668 update->hasAttr(fir::factory::attrFortranArrayOffsets())); 669 assignElement(coor); 670 mlir::Operation *storeOp = useMap.lookup(loadOp); 671 auto store = mlir::cast<ArrayMergeStoreOp>(storeOp); 672 rewriter.setInsertionPoint(storeOp); 673 // Copy out. 674 genArrayCopy(store.getLoc(), rewriter, store.getMemref(), allocmem, 675 shapeOp, load.getType()); 676 rewriter.create<FreeMemOp>(loc, allocmem); 677 return {coor, load.getResult()}; 678 } 679 // Otherwise, when there is no conflict (a possible loop-carried 680 // dependence), the lhs array can be updated in place. 681 LLVM_DEBUG(llvm::outs() << "No, conflict wasn't found\n"); 682 rewriter.setInsertionPoint(op); 683 auto coorTy = getEleTy(load.getType()); 684 mlir::Value coor = genCoorOp( 685 rewriter, loc, coorTy, lhsEltRefType, load.getMemref(), load.getShape(), 686 load.getSlice(), update.getIndices(), load.getTypeparams(), 687 update->hasAttr(fir::factory::attrFortranArrayOffsets())); 688 assignElement(coor); 689 return {coor, load.getResult()}; 690 } 691 692 private: 693 const ArrayCopyAnalysis &analysis; 694 const OperationUseMapT &useMap; 695 }; 696 697 class ArrayUpdateConversion : public ArrayUpdateConversionBase<ArrayUpdateOp> { 698 public: 699 explicit ArrayUpdateConversion(mlir::MLIRContext *ctx, 700 const ArrayCopyAnalysis &a, 701 const OperationUseMapT &m) 702 : ArrayUpdateConversionBase{ctx, a, m} {} 703 704 mlir::LogicalResult 705 matchAndRewrite(ArrayUpdateOp update, 706 mlir::PatternRewriter &rewriter) const override { 707 auto loc = update.getLoc(); 708 auto assignElement = [&](mlir::Value coor) { 709 rewriter.create<fir::StoreOp>(loc, update.getMerge(), coor); 710 }; 711 auto lhsEltRefType = toRefType(update.getMerge().getType()); 712 auto [_, lhsLoadResult] = materializeAssignment( 713 loc, rewriter, update, assignElement, lhsEltRefType); 714 update.replaceAllUsesWith(lhsLoadResult); 715 rewriter.replaceOp(update, lhsLoadResult); 716 return mlir::success(); 717 } 718 }; 719 720 class ArrayModifyConversion : public ArrayUpdateConversionBase<ArrayModifyOp> { 721 public: 722 explicit ArrayModifyConversion(mlir::MLIRContext *ctx, 723 const ArrayCopyAnalysis &a, 724 const OperationUseMapT &m) 725 : ArrayUpdateConversionBase{ctx, a, m} {} 726 727 mlir::LogicalResult 728 matchAndRewrite(ArrayModifyOp modify, 729 mlir::PatternRewriter &rewriter) const override { 730 auto loc = modify.getLoc(); 731 auto assignElement = [](mlir::Value) { 732 // Assignment already materialized by lowering using lhs element address. 733 }; 734 auto lhsEltRefType = modify.getResult(0).getType(); 735 auto [lhsEltCoor, lhsLoadResult] = materializeAssignment( 736 loc, rewriter, modify, assignElement, lhsEltRefType); 737 modify.replaceAllUsesWith(mlir::ValueRange{lhsEltCoor, lhsLoadResult}); 738 rewriter.replaceOp(modify, mlir::ValueRange{lhsEltCoor, lhsLoadResult}); 739 return mlir::success(); 740 } 741 }; 742 743 class ArrayFetchConversion : public mlir::OpRewritePattern<ArrayFetchOp> { 744 public: 745 explicit ArrayFetchConversion(mlir::MLIRContext *ctx, 746 const OperationUseMapT &m) 747 : OpRewritePattern{ctx}, useMap{m} {} 748 749 mlir::LogicalResult 750 matchAndRewrite(ArrayFetchOp fetch, 751 mlir::PatternRewriter &rewriter) const override { 752 auto *op = fetch.getOperation(); 753 rewriter.setInsertionPoint(op); 754 auto load = mlir::cast<ArrayLoadOp>(useMap.lookup(op)); 755 auto loc = fetch.getLoc(); 756 mlir::Value coor = 757 genCoorOp(rewriter, loc, getEleTy(load.getType()), 758 toRefType(fetch.getType()), load.getMemref(), load.getShape(), 759 load.getSlice(), fetch.getIndices(), load.getTypeparams(), 760 fetch->hasAttr(fir::factory::attrFortranArrayOffsets())); 761 rewriter.replaceOpWithNewOp<fir::LoadOp>(fetch, coor); 762 return mlir::success(); 763 } 764 765 private: 766 const OperationUseMapT &useMap; 767 }; 768 } // namespace 769 770 namespace { 771 class ArrayValueCopyConverter 772 : public ArrayValueCopyBase<ArrayValueCopyConverter> { 773 public: 774 void runOnOperation() override { 775 auto func = getOperation(); 776 LLVM_DEBUG(llvm::dbgs() << "\n\narray-value-copy pass on function '" 777 << func.getName() << "'\n"); 778 auto *context = &getContext(); 779 780 // Perform the conflict analysis. 781 auto &analysis = getAnalysis<ArrayCopyAnalysis>(); 782 const auto &useMap = analysis.getUseMap(); 783 784 // Phase 1 is performing a rewrite on the array accesses. Once all the 785 // array accesses are rewritten we can go on phase 2. 786 // Phase 2 gets rid of the useless copy-in/copyout operations. The copy-in 787 // /copy-out refers the Fortran copy-in/copy-out semantics on statements. 788 mlir::RewritePatternSet patterns1(context); 789 patterns1.insert<ArrayFetchConversion>(context, useMap); 790 patterns1.insert<ArrayUpdateConversion>(context, analysis, useMap); 791 patterns1.insert<ArrayModifyConversion>(context, analysis, useMap); 792 mlir::ConversionTarget target(*context); 793 target.addLegalDialect< 794 FIROpsDialect, mlir::scf::SCFDialect, mlir::arith::ArithmeticDialect, 795 mlir::cf::ControlFlowDialect, mlir::StandardOpsDialect>(); 796 target.addIllegalOp<ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp>(); 797 // Rewrite the array fetch and array update ops. 798 if (mlir::failed( 799 mlir::applyPartialConversion(func, target, std::move(patterns1)))) { 800 mlir::emitError(mlir::UnknownLoc::get(context), 801 "failure in array-value-copy pass, phase 1"); 802 signalPassFailure(); 803 } 804 805 mlir::RewritePatternSet patterns2(context); 806 patterns2.insert<ArrayLoadConversion>(context); 807 patterns2.insert<ArrayMergeStoreConversion>(context); 808 target.addIllegalOp<ArrayLoadOp, ArrayMergeStoreOp>(); 809 if (mlir::failed( 810 mlir::applyPartialConversion(func, target, std::move(patterns2)))) { 811 mlir::emitError(mlir::UnknownLoc::get(context), 812 "failure in array-value-copy pass, phase 2"); 813 signalPassFailure(); 814 } 815 } 816 }; 817 } // namespace 818 819 std::unique_ptr<mlir::Pass> fir::createArrayValueCopyPass() { 820 return std::make_unique<ArrayValueCopyConverter>(); 821 } 822