//===-- ArrayValueCopy.cpp ------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "flang/Optimizer/Builder/BoxValue.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/Factory.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Support/FIRContext.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "flang-array-value-copy"

using namespace fir;
using namespace mlir;

/// Map from one operation to another. In this file it is used both to map an
/// array access (fetch/update/modify) to its originating array load and to map
/// an array load to its array merge store.
using OperationUseMapT = llvm::DenseMap<mlir::Operation *, mlir::Operation *>;

namespace {

/// Array copy analysis.
/// Perform an interference analysis between array values.
///
/// Lowering will generate a sequence of the following form.
/// ```mlir
///   %a_1 = fir.array_load %array_1(%shape) : ...
///   ...
///   %a_j = fir.array_load %array_j(%shape) : ...
///   ...
///   %a_n = fir.array_load %array_n(%shape) : ...
///     ...
///     %v_i = fir.array_fetch %a_i, ...
///     %a_j1 = fir.array_update %a_j, ...
///     ...
///   fir.array_merge_store %a_j, %a_jn to %array_j : ...
/// ```
///
/// The analysis is to determine if there are any conflicts. A conflict is when
/// one of the following cases occurs.
///
/// 1. There is an `array_update` to an array value, a_j, such that a_j was
/// loaded from the same array memory reference (array_j) but with a different
/// shape as the other array values a_i, where i != j. [Possible overlapping
/// arrays.]
///
/// 2. There is either an array_fetch or array_update of a_j with a different
/// set of index values. [Possible loop-carried dependence.]
///
/// If none of the array values overlap in storage and the accesses are not
/// loop-carried, then the arrays are conflict-free and no copies are required.
class ArrayCopyAnalysis {
public:
  using ConflictSetT = llvm::SmallPtrSet<mlir::Operation *, 16>;
  using UseSetT = llvm::SmallPtrSet<mlir::OpOperand *, 8>;
  using LoadMapSetsT =
      llvm::DenseMap<mlir::Operation *, SmallVector<Operation *>>;

  /// Run the analysis on `op`; all results are computed eagerly here.
  ArrayCopyAnalysis(mlir::Operation *op) : operation{op} { construct(op); }

  /// Return the operation the analysis was constructed on.
  mlir::Operation *getOperation() const { return operation; }

  /// Return true iff `op` was recorded in the conflict set. The set holds both
  /// the `array_merge_store` ops found to conflict and the `array_load` ops
  /// that define their original array value.
  bool hasPotentialConflict(mlir::Operation *op) const {
    LLVM_DEBUG(llvm::dbgs()
               << "looking for a conflict on " << *op
               << " and the set has a total of " << conflicts.size() << '\n');
    return conflicts.contains(op);
  }

  /// Return the use map. The use map maps array fetch and update operations
  /// back to the array load that is the original source of the array value.
  /// It also maps each such array load to its array merge store.
  const OperationUseMapT &getUseMap() const { return useMap; }

  /// Find all the array operations that access the array value that is loaded
  /// by the array load operation, `load`.
  const llvm::SmallVector<mlir::Operation *> &arrayAccesses(ArrayLoadOp load);

private:
  /// Walk `topLevelOp` and populate `conflicts`, `useMap` and `loadMapSets`.
  void construct(mlir::Operation *topLevelOp);

  mlir::Operation *operation; // operation that analysis ran upon
  ConflictSetT conflicts;     // set of conflicts (loads and merge stores)
  OperationUseMapT useMap;    // access -> load, and load -> merge store
  LoadMapSetsT loadMapSets;   // cache of arrayAccesses() results, per load
};
} // namespace

namespace {
/// Helper class to collect all array operations that produced an array value.
99 class ReachCollector { 100 private: 101 // If provided, the `loopRegion` is the body of a loop that produces the array 102 // of interest. 103 ReachCollector(llvm::SmallVectorImpl<mlir::Operation *> &reach, 104 mlir::Region *loopRegion) 105 : reach{reach}, loopRegion{loopRegion} {} 106 107 void collectArrayAccessFrom(mlir::Operation *op, mlir::ValueRange range) { 108 llvm::errs() << "COLLECT " << *op << "\n"; 109 if (range.empty()) { 110 collectArrayAccessFrom(op, mlir::Value{}); 111 return; 112 } 113 for (mlir::Value v : range) 114 collectArrayAccessFrom(v); 115 } 116 117 // TODO: Replace recursive algorithm on def-use chain with an iterative one 118 // with an explicit stack. 119 void collectArrayAccessFrom(mlir::Operation *op, mlir::Value val) { 120 // `val` is defined by an Op, process the defining Op. 121 // If `val` is defined by a region containing Op, we want to drill down 122 // and through that Op's region(s). 123 llvm::errs() << "COLLECT " << *op << "\n"; 124 LLVM_DEBUG(llvm::dbgs() << "popset: " << *op << '\n'); 125 auto popFn = [&](auto rop) { 126 assert(val && "op must have a result value"); 127 auto resNum = val.cast<mlir::OpResult>().getResultNumber(); 128 llvm::SmallVector<mlir::Value> results; 129 rop.resultToSourceOps(results, resNum); 130 for (auto u : results) 131 collectArrayAccessFrom(u); 132 }; 133 if (auto rop = mlir::dyn_cast<fir::DoLoopOp>(op)) { 134 popFn(rop); 135 return; 136 } 137 if (auto rop = mlir::dyn_cast<fir::IfOp>(op)) { 138 popFn(rop); 139 return; 140 } 141 if (auto mergeStore = mlir::dyn_cast<ArrayMergeStoreOp>(op)) { 142 if (opIsInsideLoops(mergeStore)) 143 collectArrayAccessFrom(mergeStore.getSequence()); 144 return; 145 } 146 147 if (mlir::isa<AllocaOp, AllocMemOp>(op)) { 148 // Look for any stores inside the loops, and collect an array operation 149 // that produced the value being stored to it. 
150 for (mlir::Operation *user : op->getUsers()) 151 if (auto store = mlir::dyn_cast<fir::StoreOp>(user)) 152 if (opIsInsideLoops(store)) 153 collectArrayAccessFrom(store.getValue()); 154 return; 155 } 156 157 // Otherwise, Op does not contain a region so just chase its operands. 158 if (mlir::isa<ArrayLoadOp, ArrayUpdateOp, ArrayModifyOp, ArrayFetchOp>( 159 op)) { 160 LLVM_DEBUG(llvm::dbgs() << "add " << *op << " to reachable set\n"); 161 reach.emplace_back(op); 162 } 163 // Array modify assignment is performed on the result. So the analysis 164 // must look at the what is done with the result. 165 if (mlir::isa<ArrayModifyOp>(op)) 166 for (mlir::Operation *user : op->getResult(0).getUsers()) 167 followUsers(user); 168 169 for (auto u : op->getOperands()) 170 collectArrayAccessFrom(u); 171 } 172 173 void collectArrayAccessFrom(mlir::BlockArgument ba) { 174 auto *parent = ba.getOwner()->getParentOp(); 175 // If inside an Op holding a region, the block argument corresponds to an 176 // argument passed to the containing Op. 177 auto popFn = [&](auto rop) { 178 collectArrayAccessFrom(rop.blockArgToSourceOp(ba.getArgNumber())); 179 }; 180 if (auto rop = mlir::dyn_cast<DoLoopOp>(parent)) { 181 popFn(rop); 182 return; 183 } 184 if (auto rop = mlir::dyn_cast<IterWhileOp>(parent)) { 185 popFn(rop); 186 return; 187 } 188 // Otherwise, a block argument is provided via the pred blocks. 189 for (auto *pred : ba.getOwner()->getPredecessors()) { 190 auto u = pred->getTerminator()->getOperand(ba.getArgNumber()); 191 collectArrayAccessFrom(u); 192 } 193 } 194 195 // Recursively trace operands to find all array operations relating to the 196 // values merged. 197 void collectArrayAccessFrom(mlir::Value val) { 198 if (!val || visited.contains(val)) 199 return; 200 visited.insert(val); 201 202 // Process a block argument. 203 if (auto ba = val.dyn_cast<mlir::BlockArgument>()) { 204 collectArrayAccessFrom(ba); 205 return; 206 } 207 208 // Process an Op. 
209 if (auto *op = val.getDefiningOp()) { 210 collectArrayAccessFrom(op, val); 211 return; 212 } 213 214 fir::emitFatalError(val.getLoc(), "unhandled value"); 215 } 216 217 /// Is \op inside the loop nest region ? 218 bool opIsInsideLoops(mlir::Operation *op) const { 219 return loopRegion && loopRegion->isAncestor(op->getParentRegion()); 220 } 221 222 /// Recursively trace the use of an operation results, calling 223 /// collectArrayAccessFrom on the direct and indirect user operands. 224 /// TODO: Replace recursive algorithm on def-use chain with an iterative one 225 /// with an explicit stack. 226 void followUsers(mlir::Operation *op) { 227 for (auto userOperand : op->getOperands()) 228 collectArrayAccessFrom(userOperand); 229 // Go through potential converts/coordinate_op. 230 for (mlir::Operation *indirectUser : op->getUsers()) 231 followUsers(indirectUser); 232 } 233 234 llvm::SmallVectorImpl<mlir::Operation *> &reach; 235 llvm::SmallPtrSet<mlir::Value, 16> visited; 236 /// Region of the loops nest that produced the array value. 237 mlir::Region *loopRegion; 238 239 public: 240 /// Return all ops that produce the array value that is stored into the 241 /// `array_merge_store`. 242 static void reachingValues(llvm::SmallVectorImpl<mlir::Operation *> &reach, 243 mlir::Value seq) { 244 reach.clear(); 245 mlir::Region *loopRegion = nullptr; 246 // Only `DoLoopOp` is tested here since array operations are currently only 247 // associated with this kind of loop. 248 if (auto doLoop = 249 mlir::dyn_cast_or_null<fir::DoLoopOp>(seq.getDefiningOp())) 250 loopRegion = &doLoop->getRegion(0); 251 ReachCollector collector(reach, loopRegion); 252 collector.collectArrayAccessFrom(seq); 253 } 254 }; 255 } // namespace 256 257 /// Find all the array operations that access the array value that is loaded by 258 /// the array load operation, `load`. 
259 const llvm::SmallVector<mlir::Operation *> & 260 ArrayCopyAnalysis::arrayAccesses(ArrayLoadOp load) { 261 auto lmIter = loadMapSets.find(load); 262 if (lmIter != loadMapSets.end()) 263 return lmIter->getSecond(); 264 265 llvm::SmallVector<mlir::Operation *> accesses; 266 UseSetT visited; 267 llvm::SmallVector<mlir::OpOperand *> queue; // uses of ArrayLoad[orig] 268 269 auto appendToQueue = [&](mlir::Value val) { 270 for (mlir::OpOperand &use : val.getUses()) 271 if (!visited.count(&use)) { 272 visited.insert(&use); 273 queue.push_back(&use); 274 } 275 }; 276 277 // Build the set of uses of `original`. 278 // let USES = { uses of original fir.load } 279 appendToQueue(load); 280 281 // Process the worklist until done. 282 while (!queue.empty()) { 283 mlir::OpOperand *operand = queue.pop_back_val(); 284 mlir::Operation *owner = operand->getOwner(); 285 286 auto structuredLoop = [&](auto ro) { 287 if (auto blockArg = ro.iterArgToBlockArg(operand->get())) { 288 int64_t arg = blockArg.getArgNumber(); 289 mlir::Value output = ro.getResult(ro.getFinalValue() ? arg : arg - 1); 290 appendToQueue(output); 291 appendToQueue(blockArg); 292 } 293 }; 294 // TODO: this need to be updated to use the control-flow interface. 295 auto branchOp = [&](mlir::Block *dest, OperandRange operands) { 296 if (operands.empty()) 297 return; 298 299 // Check if this operand is within the range. 300 unsigned operandIndex = operand->getOperandNumber(); 301 unsigned operandsStart = operands.getBeginOperandIndex(); 302 if (operandIndex < operandsStart || 303 operandIndex >= (operandsStart + operands.size())) 304 return; 305 306 // Index the successor. 307 unsigned argIndex = operandIndex - operandsStart; 308 appendToQueue(dest->getArgument(argIndex)); 309 }; 310 // Thread uses into structured loop bodies and return value uses. 
311 if (auto ro = mlir::dyn_cast<DoLoopOp>(owner)) { 312 structuredLoop(ro); 313 } else if (auto ro = mlir::dyn_cast<IterWhileOp>(owner)) { 314 structuredLoop(ro); 315 } else if (auto rs = mlir::dyn_cast<ResultOp>(owner)) { 316 // Thread any uses of fir.if that return the marked array value. 317 if (auto ifOp = rs->getParentOfType<fir::IfOp>()) 318 appendToQueue(ifOp.getResult(operand->getOperandNumber())); 319 } else if (mlir::isa<ArrayFetchOp>(owner)) { 320 // Keep track of array value fetches. 321 LLVM_DEBUG(llvm::dbgs() 322 << "add fetch {" << *owner << "} to array value set\n"); 323 accesses.push_back(owner); 324 } else if (auto update = mlir::dyn_cast<ArrayUpdateOp>(owner)) { 325 // Keep track of array value updates and thread the return value uses. 326 LLVM_DEBUG(llvm::dbgs() 327 << "add update {" << *owner << "} to array value set\n"); 328 accesses.push_back(owner); 329 appendToQueue(update.getResult()); 330 } else if (auto update = mlir::dyn_cast<ArrayModifyOp>(owner)) { 331 // Keep track of array value modification and thread the return value 332 // uses. 333 LLVM_DEBUG(llvm::dbgs() 334 << "add modify {" << *owner << "} to array value set\n"); 335 accesses.push_back(owner); 336 appendToQueue(update.getResult(1)); 337 } else if (auto br = mlir::dyn_cast<mlir::cf::BranchOp>(owner)) { 338 branchOp(br.getDest(), br.getDestOperands()); 339 } else if (auto br = mlir::dyn_cast<mlir::cf::CondBranchOp>(owner)) { 340 branchOp(br.getTrueDest(), br.getTrueOperands()); 341 branchOp(br.getFalseDest(), br.getFalseOperands()); 342 } else if (mlir::isa<ArrayMergeStoreOp>(owner)) { 343 // do nothing 344 } else { 345 llvm::report_fatal_error("array value reached unexpected op"); 346 } 347 } 348 return loadMapSets.insert({load, accesses}).first->getSecond(); 349 } 350 351 /// Is there a conflict between the array value that was updated and to be 352 /// stored to `st` and the set of arrays loaded (`reach`) and used to compute 353 /// the updated value? 
354 static bool conflictOnLoad(llvm::ArrayRef<mlir::Operation *> reach, 355 ArrayMergeStoreOp st) { 356 mlir::Value load; 357 mlir::Value addr = st.getMemref(); 358 auto stEleTy = fir::dyn_cast_ptrOrBoxEleTy(addr.getType()); 359 for (auto *op : reach) { 360 auto ld = mlir::dyn_cast<ArrayLoadOp>(op); 361 if (!ld) 362 continue; 363 mlir::Type ldTy = ld.getMemref().getType(); 364 if (auto boxTy = ldTy.dyn_cast<fir::BoxType>()) 365 ldTy = boxTy.getEleTy(); 366 if (ldTy.isa<fir::PointerType>() && stEleTy == dyn_cast_ptrEleTy(ldTy)) 367 return true; 368 if (ld.getMemref() == addr) { 369 if (ld.getResult() != st.getOriginal()) 370 return true; 371 if (load) 372 return true; 373 load = ld; 374 } 375 } 376 return false; 377 } 378 379 /// Check if there is any potential conflict in the chained update operations 380 /// (ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp) while merging back to the 381 /// array. A potential conflict is detected if two operations work on the same 382 /// indices. 383 static bool conflictOnMerge(llvm::ArrayRef<mlir::Operation *> accesses) { 384 if (accesses.size() < 2) 385 return false; 386 llvm::SmallVector<mlir::Value> indices; 387 LLVM_DEBUG(llvm::dbgs() << "check merge conflict on with " << accesses.size() 388 << " accesses on the list\n"); 389 for (auto *op : accesses) { 390 assert((mlir::isa<ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp>(op)) && 391 "unexpected operation in analysis"); 392 llvm::SmallVector<mlir::Value> compareVector; 393 if (auto u = mlir::dyn_cast<ArrayUpdateOp>(op)) { 394 if (indices.empty()) { 395 indices = u.getIndices(); 396 continue; 397 } 398 compareVector = u.getIndices(); 399 } else if (auto f = mlir::dyn_cast<ArrayModifyOp>(op)) { 400 if (indices.empty()) { 401 indices = f.getIndices(); 402 continue; 403 } 404 compareVector = f.getIndices(); 405 } else if (auto f = mlir::dyn_cast<ArrayFetchOp>(op)) { 406 if (indices.empty()) { 407 indices = f.getIndices(); 408 continue; 409 } 410 compareVector = f.getIndices(); 411 } 412 
if (compareVector != indices) 413 return true; 414 LLVM_DEBUG(llvm::dbgs() << "vectors compare equal\n"); 415 } 416 return false; 417 } 418 419 // Are either of types of conflicts present? 420 inline bool conflictDetected(llvm::ArrayRef<mlir::Operation *> reach, 421 llvm::ArrayRef<mlir::Operation *> accesses, 422 ArrayMergeStoreOp st) { 423 return conflictOnLoad(reach, st) || conflictOnMerge(accesses); 424 } 425 426 /// Constructor of the array copy analysis. 427 /// This performs the analysis and saves the intermediate results. 428 void ArrayCopyAnalysis::construct(mlir::Operation *topLevelOp) { 429 topLevelOp->walk([&](Operation *op) { 430 if (auto st = mlir::dyn_cast<fir::ArrayMergeStoreOp>(op)) { 431 llvm::SmallVector<Operation *> values; 432 ReachCollector::reachingValues(values, st.getSequence()); 433 const llvm::SmallVector<Operation *> &accesses = arrayAccesses( 434 mlir::cast<ArrayLoadOp>(st.getOriginal().getDefiningOp())); 435 if (conflictDetected(values, accesses, st)) { 436 LLVM_DEBUG(llvm::dbgs() 437 << "CONFLICT: copies required for " << st << '\n' 438 << " adding conflicts on: " << op << " and " 439 << st.getOriginal() << '\n'); 440 conflicts.insert(op); 441 conflicts.insert(st.getOriginal().getDefiningOp()); 442 } 443 auto *ld = st.getOriginal().getDefiningOp(); 444 LLVM_DEBUG(llvm::dbgs() 445 << "map: adding {" << *ld << " -> " << st << "}\n"); 446 useMap.insert({ld, op}); 447 } else if (auto load = mlir::dyn_cast<ArrayLoadOp>(op)) { 448 const llvm::SmallVector<mlir::Operation *> &accesses = 449 arrayAccesses(load); 450 LLVM_DEBUG(llvm::dbgs() << "process load: " << load 451 << ", accesses: " << accesses.size() << '\n'); 452 for (auto *acc : accesses) { 453 LLVM_DEBUG(llvm::dbgs() << " access: " << *acc << '\n'); 454 assert((mlir::isa<ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp>(acc))); 455 if (!useMap.insert({acc, op}).second) { 456 mlir::emitError( 457 load.getLoc(), 458 "The parallel semantics of multiple array_merge_stores per " 459 "array_load 
are not supported."); 460 return; 461 } 462 LLVM_DEBUG(llvm::dbgs() 463 << "map: adding {" << *acc << "} -> {" << load << "}\n"); 464 } 465 } 466 }); 467 } 468 469 namespace { 470 class ArrayLoadConversion : public mlir::OpRewritePattern<ArrayLoadOp> { 471 public: 472 using OpRewritePattern::OpRewritePattern; 473 474 mlir::LogicalResult 475 matchAndRewrite(ArrayLoadOp load, 476 mlir::PatternRewriter &rewriter) const override { 477 LLVM_DEBUG(llvm::dbgs() << "replace load " << load << " with undef.\n"); 478 rewriter.replaceOpWithNewOp<UndefOp>(load, load.getType()); 479 return mlir::success(); 480 } 481 }; 482 483 class ArrayMergeStoreConversion 484 : public mlir::OpRewritePattern<ArrayMergeStoreOp> { 485 public: 486 using OpRewritePattern::OpRewritePattern; 487 488 mlir::LogicalResult 489 matchAndRewrite(ArrayMergeStoreOp store, 490 mlir::PatternRewriter &rewriter) const override { 491 LLVM_DEBUG(llvm::dbgs() << "marking store " << store << " as dead.\n"); 492 rewriter.eraseOp(store); 493 return mlir::success(); 494 } 495 }; 496 } // namespace 497 498 static mlir::Type getEleTy(mlir::Type ty) { 499 if (auto t = dyn_cast_ptrEleTy(ty)) 500 ty = t; 501 if (auto t = ty.dyn_cast<SequenceType>()) 502 ty = t.getEleTy(); 503 // FIXME: keep ptr/heap/ref information. 504 return ReferenceType::get(ty); 505 } 506 507 // Extract extents from the ShapeOp/ShapeShiftOp into the result vector. 508 // TODO: getExtents on op should return a ValueRange instead of a vector. 
509 static void getExtents(llvm::SmallVectorImpl<mlir::Value> &result, 510 mlir::Value shape) { 511 auto *shapeOp = shape.getDefiningOp(); 512 if (auto s = mlir::dyn_cast<fir::ShapeOp>(shapeOp)) { 513 auto e = s.getExtents(); 514 result.insert(result.end(), e.begin(), e.end()); 515 return; 516 } 517 if (auto s = mlir::dyn_cast<fir::ShapeShiftOp>(shapeOp)) { 518 auto e = s.getExtents(); 519 result.insert(result.end(), e.begin(), e.end()); 520 return; 521 } 522 llvm::report_fatal_error("not a fir.shape/fir.shape_shift op"); 523 } 524 525 // Place the extents of the array loaded by an ArrayLoadOp into the result 526 // vector and return a ShapeOp/ShapeShiftOp with the corresponding extents. If 527 // the ArrayLoadOp is loading a fir.box, code will be generated to read the 528 // extents from the fir.box, and a the retunred ShapeOp is built with the read 529 // extents. 530 // Otherwise, the extents will be extracted from the ShapeOp/ShapeShiftOp 531 // argument of the ArrayLoadOp that is returned. 
532 static mlir::Value 533 getOrReadExtentsAndShapeOp(mlir::Location loc, mlir::PatternRewriter &rewriter, 534 fir::ArrayLoadOp loadOp, 535 llvm::SmallVectorImpl<mlir::Value> &result) { 536 assert(result.empty()); 537 if (auto boxTy = loadOp.getMemref().getType().dyn_cast<fir::BoxType>()) { 538 auto rank = fir::dyn_cast_ptrOrBoxEleTy(boxTy) 539 .cast<fir::SequenceType>() 540 .getDimension(); 541 auto idxTy = rewriter.getIndexType(); 542 for (decltype(rank) dim = 0; dim < rank; ++dim) { 543 auto dimVal = rewriter.create<arith::ConstantIndexOp>(loc, dim); 544 auto dimInfo = rewriter.create<fir::BoxDimsOp>( 545 loc, idxTy, idxTy, idxTy, loadOp.getMemref(), dimVal); 546 result.emplace_back(dimInfo.getResult(1)); 547 } 548 auto shapeType = fir::ShapeType::get(rewriter.getContext(), rank); 549 return rewriter.create<fir::ShapeOp>(loc, shapeType, result); 550 } 551 getExtents(result, loadOp.getShape()); 552 return loadOp.getShape(); 553 } 554 555 static mlir::Type toRefType(mlir::Type ty) { 556 if (fir::isa_ref_type(ty)) 557 return ty; 558 return fir::ReferenceType::get(ty); 559 } 560 561 static mlir::Value 562 genCoorOp(mlir::PatternRewriter &rewriter, mlir::Location loc, mlir::Type eleTy, 563 mlir::Type resTy, mlir::Value alloc, mlir::Value shape, 564 mlir::Value slice, mlir::ValueRange indices, 565 mlir::ValueRange typeparams, bool skipOrig = false) { 566 llvm::SmallVector<mlir::Value> originated; 567 if (skipOrig) 568 originated.assign(indices.begin(), indices.end()); 569 else 570 originated = fir::factory::originateIndices(loc, rewriter, alloc.getType(), 571 shape, indices); 572 auto seqTy = fir::dyn_cast_ptrOrBoxEleTy(alloc.getType()); 573 assert(seqTy && seqTy.isa<fir::SequenceType>()); 574 const auto dimension = seqTy.cast<fir::SequenceType>().getDimension(); 575 mlir::Value result = rewriter.create<fir::ArrayCoorOp>( 576 loc, eleTy, alloc, shape, slice, 577 llvm::ArrayRef<mlir::Value>{originated}.take_front(dimension), 578 typeparams); 579 if (dimension < 
originated.size()) 580 result = rewriter.create<fir::CoordinateOp>( 581 loc, resTy, result, 582 llvm::ArrayRef<mlir::Value>{originated}.drop_front(dimension)); 583 return result; 584 } 585 586 namespace { 587 /// Conversion of fir.array_update and fir.array_modify Ops. 588 /// If there is a conflict for the update, then we need to perform a 589 /// copy-in/copy-out to preserve the original values of the array. If there is 590 /// no conflict, then it is save to eschew making any copies. 591 template <typename ArrayOp> 592 class ArrayUpdateConversionBase : public mlir::OpRewritePattern<ArrayOp> { 593 public: 594 explicit ArrayUpdateConversionBase(mlir::MLIRContext *ctx, 595 const ArrayCopyAnalysis &a, 596 const OperationUseMapT &m) 597 : mlir::OpRewritePattern<ArrayOp>{ctx}, analysis{a}, useMap{m} {} 598 599 void genArrayCopy(mlir::Location loc, mlir::PatternRewriter &rewriter, 600 mlir::Value dst, mlir::Value src, mlir::Value shapeOp, 601 mlir::Type arrTy) const { 602 auto insPt = rewriter.saveInsertionPoint(); 603 llvm::SmallVector<mlir::Value> indices; 604 llvm::SmallVector<mlir::Value> extents; 605 getExtents(extents, shapeOp); 606 // Build loop nest from column to row. 607 for (auto sh : llvm::reverse(extents)) { 608 auto idxTy = rewriter.getIndexType(); 609 auto ubi = rewriter.create<fir::ConvertOp>(loc, idxTy, sh); 610 auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); 611 auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1); 612 auto ub = rewriter.create<arith::SubIOp>(loc, idxTy, ubi, one); 613 auto loop = rewriter.create<fir::DoLoopOp>(loc, zero, ub, one); 614 rewriter.setInsertionPointToStart(loop.getBody()); 615 indices.push_back(loop.getInductionVar()); 616 } 617 // Reverse the indices so they are in column-major order. 
618 std::reverse(indices.begin(), indices.end()); 619 auto ty = getEleTy(arrTy); 620 auto fromAddr = rewriter.create<fir::ArrayCoorOp>( 621 loc, ty, src, shapeOp, mlir::Value{}, 622 fir::factory::originateIndices(loc, rewriter, src.getType(), shapeOp, 623 indices), 624 mlir::ValueRange{}); 625 auto load = rewriter.create<fir::LoadOp>(loc, fromAddr); 626 auto toAddr = rewriter.create<fir::ArrayCoorOp>( 627 loc, ty, dst, shapeOp, mlir::Value{}, 628 fir::factory::originateIndices(loc, rewriter, dst.getType(), shapeOp, 629 indices), 630 mlir::ValueRange{}); 631 rewriter.create<fir::StoreOp>(loc, load, toAddr); 632 rewriter.restoreInsertionPoint(insPt); 633 } 634 635 /// Copy the RHS element into the LHS and insert copy-in/copy-out between a 636 /// temp and the LHS if the analysis found potential overlaps between the RHS 637 /// and LHS arrays. The element copy generator must be provided through \p 638 /// assignElement. \p update must be the ArrayUpdateOp or the ArrayModifyOp. 639 /// Returns the address of the LHS element inside the loop and the LHS 640 /// ArrayLoad result. 641 std::pair<mlir::Value, mlir::Value> 642 materializeAssignment(mlir::Location loc, mlir::PatternRewriter &rewriter, 643 ArrayOp update, 644 llvm::function_ref<void(mlir::Value)> assignElement, 645 mlir::Type lhsEltRefType) const { 646 auto *op = update.getOperation(); 647 mlir::Operation *loadOp = useMap.lookup(op); 648 auto load = mlir::cast<ArrayLoadOp>(loadOp); 649 LLVM_DEBUG(llvm::outs() << "does " << load << " have a conflict?\n"); 650 if (analysis.hasPotentialConflict(loadOp)) { 651 // If there is a conflict between the arrays, then we copy the lhs array 652 // to a temporary, update the temporary, and copy the temporary back to 653 // the lhs array. This yields Fortran's copy-in copy-out array semantics. 654 LLVM_DEBUG(llvm::outs() << "Yes, conflict was found\n"); 655 rewriter.setInsertionPoint(loadOp); 656 // Copy in. 
657 llvm::SmallVector<mlir::Value> extents; 658 mlir::Value shapeOp = 659 getOrReadExtentsAndShapeOp(loc, rewriter, load, extents); 660 auto allocmem = rewriter.create<AllocMemOp>( 661 loc, dyn_cast_ptrOrBoxEleTy(load.getMemref().getType()), 662 load.getTypeparams(), extents); 663 genArrayCopy(load.getLoc(), rewriter, allocmem, load.getMemref(), shapeOp, 664 load.getType()); 665 rewriter.setInsertionPoint(op); 666 mlir::Value coor = genCoorOp( 667 rewriter, loc, getEleTy(load.getType()), lhsEltRefType, allocmem, 668 shapeOp, load.getSlice(), update.getIndices(), load.getTypeparams(), 669 update->hasAttr(fir::factory::attrFortranArrayOffsets())); 670 assignElement(coor); 671 mlir::Operation *storeOp = useMap.lookup(loadOp); 672 auto store = mlir::cast<ArrayMergeStoreOp>(storeOp); 673 rewriter.setInsertionPoint(storeOp); 674 // Copy out. 675 genArrayCopy(store.getLoc(), rewriter, store.getMemref(), allocmem, 676 shapeOp, load.getType()); 677 rewriter.create<FreeMemOp>(loc, allocmem); 678 return {coor, load.getResult()}; 679 } 680 // Otherwise, when there is no conflict (a possible loop-carried 681 // dependence), the lhs array can be updated in place. 
682 LLVM_DEBUG(llvm::outs() << "No, conflict wasn't found\n"); 683 rewriter.setInsertionPoint(op); 684 auto coorTy = getEleTy(load.getType()); 685 mlir::Value coor = genCoorOp( 686 rewriter, loc, coorTy, lhsEltRefType, load.getMemref(), load.getShape(), 687 load.getSlice(), update.getIndices(), load.getTypeparams(), 688 update->hasAttr(fir::factory::attrFortranArrayOffsets())); 689 assignElement(coor); 690 return {coor, load.getResult()}; 691 } 692 693 private: 694 const ArrayCopyAnalysis &analysis; 695 const OperationUseMapT &useMap; 696 }; 697 698 class ArrayUpdateConversion : public ArrayUpdateConversionBase<ArrayUpdateOp> { 699 public: 700 explicit ArrayUpdateConversion(mlir::MLIRContext *ctx, 701 const ArrayCopyAnalysis &a, 702 const OperationUseMapT &m) 703 : ArrayUpdateConversionBase{ctx, a, m} {} 704 705 mlir::LogicalResult 706 matchAndRewrite(ArrayUpdateOp update, 707 mlir::PatternRewriter &rewriter) const override { 708 auto loc = update.getLoc(); 709 auto assignElement = [&](mlir::Value coor) { 710 rewriter.create<fir::StoreOp>(loc, update.getMerge(), coor); 711 }; 712 auto lhsEltRefType = toRefType(update.getMerge().getType()); 713 auto [_, lhsLoadResult] = materializeAssignment( 714 loc, rewriter, update, assignElement, lhsEltRefType); 715 update.replaceAllUsesWith(lhsLoadResult); 716 rewriter.replaceOp(update, lhsLoadResult); 717 return mlir::success(); 718 } 719 }; 720 721 class ArrayModifyConversion : public ArrayUpdateConversionBase<ArrayModifyOp> { 722 public: 723 explicit ArrayModifyConversion(mlir::MLIRContext *ctx, 724 const ArrayCopyAnalysis &a, 725 const OperationUseMapT &m) 726 : ArrayUpdateConversionBase{ctx, a, m} {} 727 728 mlir::LogicalResult 729 matchAndRewrite(ArrayModifyOp modify, 730 mlir::PatternRewriter &rewriter) const override { 731 auto loc = modify.getLoc(); 732 auto assignElement = [](mlir::Value) { 733 // Assignment already materialized by lowering using lhs element address. 
734 }; 735 auto lhsEltRefType = modify.getResult(0).getType(); 736 auto [lhsEltCoor, lhsLoadResult] = materializeAssignment( 737 loc, rewriter, modify, assignElement, lhsEltRefType); 738 modify.replaceAllUsesWith(mlir::ValueRange{lhsEltCoor, lhsLoadResult}); 739 rewriter.replaceOp(modify, mlir::ValueRange{lhsEltCoor, lhsLoadResult}); 740 return mlir::success(); 741 } 742 }; 743 744 class ArrayFetchConversion : public mlir::OpRewritePattern<ArrayFetchOp> { 745 public: 746 explicit ArrayFetchConversion(mlir::MLIRContext *ctx, 747 const OperationUseMapT &m) 748 : OpRewritePattern{ctx}, useMap{m} {} 749 750 mlir::LogicalResult 751 matchAndRewrite(ArrayFetchOp fetch, 752 mlir::PatternRewriter &rewriter) const override { 753 auto *op = fetch.getOperation(); 754 rewriter.setInsertionPoint(op); 755 auto load = mlir::cast<ArrayLoadOp>(useMap.lookup(op)); 756 auto loc = fetch.getLoc(); 757 mlir::Value coor = 758 genCoorOp(rewriter, loc, getEleTy(load.getType()), 759 toRefType(fetch.getType()), load.getMemref(), load.getShape(), 760 load.getSlice(), fetch.getIndices(), load.getTypeparams(), 761 fetch->hasAttr(fir::factory::attrFortranArrayOffsets())); 762 rewriter.replaceOpWithNewOp<fir::LoadOp>(fetch, coor); 763 return mlir::success(); 764 } 765 766 private: 767 const OperationUseMapT &useMap; 768 }; 769 } // namespace 770 771 namespace { 772 class ArrayValueCopyConverter 773 : public ArrayValueCopyBase<ArrayValueCopyConverter> { 774 public: 775 void runOnOperation() override { 776 auto func = getOperation(); 777 LLVM_DEBUG(llvm::dbgs() << "\n\narray-value-copy pass on function '" 778 << func.getName() << "'\n"); 779 auto *context = &getContext(); 780 781 // Perform the conflict analysis. 782 auto &analysis = getAnalysis<ArrayCopyAnalysis>(); 783 const auto &useMap = analysis.getUseMap(); 784 785 // Phase 1 is performing a rewrite on the array accesses. Once all the 786 // array accesses are rewritten we can go on phase 2. 
787 // Phase 2 gets rid of the useless copy-in/copyout operations. The copy-in 788 // /copy-out refers the Fortran copy-in/copy-out semantics on statements. 789 mlir::RewritePatternSet patterns1(context); 790 patterns1.insert<ArrayFetchConversion>(context, useMap); 791 patterns1.insert<ArrayUpdateConversion>(context, analysis, useMap); 792 patterns1.insert<ArrayModifyConversion>(context, analysis, useMap); 793 mlir::ConversionTarget target(*context); 794 target.addLegalDialect< 795 FIROpsDialect, mlir::scf::SCFDialect, mlir::arith::ArithmeticDialect, 796 mlir::cf::ControlFlowDialect, mlir::func::FuncDialect>(); 797 target.addIllegalOp<ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp>(); 798 // Rewrite the array fetch and array update ops. 799 if (mlir::failed( 800 mlir::applyPartialConversion(func, target, std::move(patterns1)))) { 801 mlir::emitError(mlir::UnknownLoc::get(context), 802 "failure in array-value-copy pass, phase 1"); 803 signalPassFailure(); 804 } 805 806 mlir::RewritePatternSet patterns2(context); 807 patterns2.insert<ArrayLoadConversion>(context); 808 patterns2.insert<ArrayMergeStoreConversion>(context); 809 target.addIllegalOp<ArrayLoadOp, ArrayMergeStoreOp>(); 810 if (mlir::failed( 811 mlir::applyPartialConversion(func, target, std::move(patterns2)))) { 812 mlir::emitError(mlir::UnknownLoc::get(context), 813 "failure in array-value-copy pass, phase 2"); 814 signalPassFailure(); 815 } 816 } 817 }; 818 } // namespace 819 820 std::unique_ptr<mlir::Pass> fir::createArrayValueCopyPass() { 821 return std::make_unique<ArrayValueCopyConverter>(); 822 } 823