//===- LoopFusion.cpp - Code to perform loop fusion -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements affine fusion.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/Passes.h"

#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopFusionUtils.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <iomanip>
#include <optional>
#include <sstream>

namespace mlir {
namespace affine {
#define GEN_PASS_DEF_AFFINELOOPFUSION
#include "mlir/Dialect/Affine/Passes.h.inc"
} // namespace affine
} // namespace mlir

#define DEBUG_TYPE "affine-loop-fusion"

using namespace mlir;
using namespace mlir::affine;

namespace {
/// Loop fusion pass. This pass currently supports a greedy fusion policy,
/// which fuses loop nests with single-writer/single-reader memref dependences
/// with the goal of improving locality.
// TODO: Support fusion of source loop nests which write to multiple
// memrefs, where each memref can have multiple users (if profitable).
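//
// As a simple illustration of the producer-consumer case this pass targets
// (a sketch; names and shapes are made up), fusing
//
//   affine.for %i = 0 to 100 {
//     affine.store %v, %buf[%i] : memref<100xf32>
//   }
//   affine.for %i = 0 to 100 {
//     %x = affine.load %buf[%i] : memref<100xf32>
//   }
//
// can yield, when a private memref can also be created, something like:
//
//   %priv = memref.alloc() : memref<1xf32>
//   affine.for %i = 0 to 100 {
//     affine.store %v, %priv[0] : memref<1xf32>
//     %x = affine.load %priv[0] : memref<1xf32>
//   }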
struct LoopFusion : public affine::impl::AffineLoopFusionBase<LoopFusion> {
  LoopFusion() = default;
  LoopFusion(unsigned fastMemorySpace, uint64_t localBufSizeThresholdBytes,
             bool maximalFusion, enum FusionMode affineFusionMode) {
    this->fastMemorySpace = fastMemorySpace;
    this->localBufSizeThreshold = localBufSizeThresholdBytes / 1024;
    this->maximalFusion = maximalFusion;
    this->affineFusionMode = affineFusionMode;
  }

  void runOnBlock(Block *block);
  void runOnOperation() override;
};

} // namespace

/// Returns true if node 'srcId' can be removed after fusing it with node
/// 'dstId'. The node can be removed if any of the following conditions are
/// met:
///   1. 'srcId' has no output dependences after fusion and no escaping
///      memrefs.
///   2. 'srcId' has no output dependences after fusion, has escaping memrefs
///      and the fusion slice is maximal.
///   3. 'srcId' has output dependences after fusion, the fusion slice is
///      maximal and the fusion insertion point dominates all the dependences.
static bool canRemoveSrcNodeAfterFusion(
    unsigned srcId, unsigned dstId, const ComputationSliceState &fusionSlice,
    Operation *fusedLoopInsPoint, const DenseSet<Value> &escapingMemRefs,
    MemRefDependenceGraph *mdg) {

  Operation *dstNodeOp = mdg->getNode(dstId)->op;
  bool hasOutDepsAfterFusion = false;

  for (auto &outEdge : mdg->outEdges[srcId]) {
    Operation *depNodeOp = mdg->getNode(outEdge.id)->op;
    // Skip dependence with dstOp since it will be removed after fusion.
    if (depNodeOp == dstNodeOp)
      continue;

    // Only fusion within the same block is supported. Use domination analysis
    // when needed.
    if (depNodeOp->getBlock() != dstNodeOp->getBlock())
      return false;

    // Check if the insertion point of the fused loop dominates the dependence.
    // Otherwise, the src loop can't be removed.
    if (fusedLoopInsPoint != depNodeOp &&
        !fusedLoopInsPoint->isBeforeInBlock(depNodeOp)) {
      LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: dst loop doesn't "
                                 "dominate dependence\n");
      return false;
    }

    hasOutDepsAfterFusion = true;
  }

  // If the src loop has dependences after fusion or it writes to a live-out
  // or escaping memref, we can only remove it if the fusion slice is maximal
  // so that all the dependences are preserved.
  if (hasOutDepsAfterFusion || !escapingMemRefs.empty()) {
    std::optional<bool> isMaximal = fusionSlice.isMaximal();
    if (!isMaximal) {
      LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: can't determine "
                                 "if fusion is maximal\n");
      return false;
    }

    if (!*isMaximal) {
      LLVM_DEBUG(llvm::dbgs()
                 << "Src loop can't be removed: fusion is not maximal\n");
      return false;
    }
  }

  return true;
}

/// Returns in 'srcIdCandidates' the producer fusion candidates for consumer
/// 'dstId'. Candidates are sorted by node id order. This order corresponds to
/// the program order when the 'mdg' is created. However, program order is not
/// guaranteed and must not be required by the client. Program order may not
/// hold if the 'mdg' is reused from a previous fusion step or if the node
/// creation order changes in the future to support more advanced cases.
// TODO: Move this to a loop fusion utility once 'mdg' is also moved.
static void getProducerCandidates(unsigned dstId, MemRefDependenceGraph *mdg,
                                  SmallVectorImpl<unsigned> &srcIdCandidates) {
  // Skip if no input edges along which to fuse.
  if (mdg->inEdges.count(dstId) == 0)
    return;

  // Gather memrefs from loads in 'dstId'.
  auto *dstNode = mdg->getNode(dstId);
  DenseSet<Value> consumedMemrefs;
  for (Operation *load : dstNode->loads)
    consumedMemrefs.insert(cast<AffineReadOpInterface>(load).getMemRef());

  // Traverse 'dstId' incoming edges and gather the nodes that contain a store
  // to one of the consumed memrefs.
  for (auto &srcEdge : mdg->inEdges[dstId]) {
    auto *srcNode = mdg->getNode(srcEdge.id);
    // Skip if 'srcNode' is not a loop nest.
    if (!isa<AffineForOp>(srcNode->op))
      continue;

    if (any_of(srcNode->stores, [&](Operation *op) {
          auto storeOp = cast<AffineWriteOpInterface>(op);
          return consumedMemrefs.count(storeOp.getMemRef()) > 0;
        }))
      srcIdCandidates.push_back(srcNode->id);
  }

  llvm::sort(srcIdCandidates);
  srcIdCandidates.erase(llvm::unique(srcIdCandidates), srcIdCandidates.end());
}

/// Returns in 'producerConsumerMemrefs' the memrefs involved in a
/// producer-consumer dependence between 'srcId' and 'dstId'.
static void
gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId,
                              MemRefDependenceGraph *mdg,
                              DenseSet<Value> &producerConsumerMemrefs) {
  auto *dstNode = mdg->getNode(dstId);
  auto *srcNode = mdg->getNode(srcId);
  gatherProducerConsumerMemrefs(srcNode->stores, dstNode->loads,
                                producerConsumerMemrefs);
}

/// A memref escapes in the context of the fusion pass if either:
///   1. it (or its alias) is a block argument, or
///   2. it is created by an op not known to guarantee alias freedom, or
///   3. it (or its alias) is used by ops other than affine dereferencing ops
///      (e.g., by call ops, memref load/store ops, alias creating ops,
///      unknown ops, terminator ops, etc.); such ops do not dereference the
///      memref in an affine way.
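///
/// For instance (an illustrative sketch; names are made up), in:
///
///   func.func @f(%arg: memref<10xf32>) {
///     %a = memref.alloc() : memref<10xf32>
///     ...
///     call @g(%a) : (memref<10xf32>) -> ()
///   }
///
/// %arg escapes because it is a block argument, and %a escapes because it is
/// used by a call op, which may dereference it in a non-affine way.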
static bool isEscapingMemref(Value memref, Block *block) {
  Operation *defOp = memref.getDefiningOp();
  // Check if 'memref' is a block argument.
  if (!defOp)
    return true;

  // Check if this is defined to be an alias of another memref.
  if (auto viewOp = dyn_cast<mlir::ViewLikeOpInterface>(defOp))
    if (isEscapingMemref(viewOp.getViewSource(), block))
      return true;

  // Any op other than an allocating op doesn't guarantee alias freedom.
  if (!hasSingleEffect<mlir::MemoryEffects::Allocate>(defOp, memref))
    return true;

  // Check if 'memref' is used by a non-dereferencing op (including unknown
  // ones) (e.g., call ops, alias creating ops, etc.).
  return llvm::any_of(memref.getUsers(), [&](Operation *user) {
    // Ignore users outside of `block`.
    Operation *ancestorOp = block->getParent()->findAncestorOpInRegion(*user);
    if (!ancestorOp)
      return true;
    if (ancestorOp->getBlock() != block)
      return false;
    return !isa<AffineMapAccessInterface>(*user);
  });
}

/// Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id'
/// that escape the block or are accessed in a non-affine way.
static void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg,
                                  DenseSet<Value> &escapingMemRefs) {
  auto *node = mdg->getNode(id);
  for (Operation *storeOp : node->stores) {
    auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
    if (escapingMemRefs.count(memref))
      continue;
    if (isEscapingMemref(memref, &mdg->block))
      escapingMemRefs.insert(memref);
  }
}

// Sinks all sequential loops to the innermost levels (while preserving
// relative order among them) and moves all parallel loops to the
// outermost (while again preserving relative order among them).
// This can increase the loop depth at which we can fuse a slice, since we are
// pushing loop-carried dependences to a greater depth in the loop nest.
static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) {
  assert(isa<AffineForOp>(node->op));
  AffineForOp newRootForOp = sinkSequentialLoops(cast<AffineForOp>(node->op));
  node->op = newRootForOp;
}

// Creates and returns a private (single-user) memref for fused loop rooted
// at 'forOp', with (potentially reduced) memref size based on the
// MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
// TODO: consider refactoring the common code from generateDma and
// this one.
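// For example (a sketch with made-up shapes), if the source stores to
// %buf : memref<100xf32> but the slice materialized at 'dstLoopDepth' only
// ever writes one element per destination iteration, the private memref can
// be shrunk to memref<1xf32>, with all accesses remapped accordingly.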
static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
                                 unsigned dstLoopDepth,
                                 std::optional<unsigned> fastMemorySpace,
                                 uint64_t localBufSizeThreshold) {
  Operation *forInst = forOp.getOperation();

  // Create builder to insert alloc op just before 'forOp'.
  OpBuilder b(forInst);
  // Builder to create constants at the top level.
  OpBuilder top(forInst->getParentRegion());
  // Create new memref type based on slice bounds.
  auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOpInst).getMemRef();
  auto oldMemRefType = cast<MemRefType>(oldMemRef.getType());
  unsigned rank = oldMemRefType.getRank();

  // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'.
  MemRefRegion region(srcStoreOpInst->getLoc());
  bool validRegion = succeeded(region.compute(srcStoreOpInst, dstLoopDepth));
  (void)validRegion;
  assert(validRegion && "unexpected memref region failure");
  SmallVector<int64_t, 4> newShape;
  std::vector<SmallVector<int64_t, 4>> lbs;
  SmallVector<int64_t, 8> lbDivisors;
  lbs.reserve(rank);
  // Query 'region' for 'newShape' and lower bounds of MemRefRegion accessed
  // by 'srcStoreOpInst' at depth 'dstLoopDepth'.
  std::optional<int64_t> numElements =
      region.getConstantBoundingSizeAndShape(&newShape, &lbs, &lbDivisors);
  assert(numElements && "non-constant number of elts in local buffer");

  const FlatAffineValueConstraints *cst = region.getConstraints();
  // 'outerIVs' holds the values that this memory region is symbolic/parametric
  // on; this would correspond to loop IVs surrounding the level at which the
  // slice is being materialized.
  SmallVector<Value, 8> outerIVs;
  cst->getValues(rank, cst->getNumVars(), &outerIVs);

  // Build 'rank' AffineExprs from MemRefRegion 'lbs'.
  SmallVector<AffineExpr, 4> offsets;
  offsets.reserve(rank);
  for (unsigned d = 0; d < rank; ++d) {
    assert(lbs[d].size() == cst->getNumCols() - rank && "incorrect bound size");

    AffineExpr offset = top.getAffineConstantExpr(0);
    for (unsigned j = 0, e = cst->getNumCols() - rank - 1; j < e; j++) {
      offset = offset + lbs[d][j] * top.getAffineDimExpr(j);
    }
    assert(lbDivisors[d] > 0);
    offset =
        (offset + lbs[d][cst->getNumCols() - 1 - rank]).floorDiv(lbDivisors[d]);
    offsets.push_back(offset);
  }
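
  // As an illustration (hypothetical values): for a rank-1 region whose lower
  // bound at depth 'dstLoopDepth' is an outer IV %i, the offset computed above
  // is d0 (bound to %i as an extra operand), and the remap built below is
  // (d0, d1) -> (d1 - d0), i.e., the old access index minus the region's
  // lower bound.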
  // Create 'newMemRefType' using 'newShape' from MemRefRegion accessed
  // by 'srcStoreOpInst'.
  auto eltSize = getMemRefIntOrFloatEltSizeInBytes(oldMemRefType);
  assert(eltSize && "memrefs with sized elt types expected");
  uint64_t bufSize = *eltSize * *numElements;
  unsigned newMemSpace;
  if (bufSize <= localBufSizeThreshold && fastMemorySpace.has_value()) {
    newMemSpace = *fastMemorySpace;
  } else {
    newMemSpace = oldMemRefType.getMemorySpaceAsInt();
  }
  auto newMemRefType = MemRefType::get(newShape, oldMemRefType.getElementType(),
                                       {}, newMemSpace);

  // Create new private memref for fused loop 'forOp'. 'newShape' is always
  // a constant shape.
  // TODO: Create/move alloc ops for private memrefs closer to their
  // consumer loop nests to reduce their live range. Currently they are added
  // at the beginning of the block, because loop nests can be reordered
  // during the fusion pass.
  Value newMemRef = top.create<memref::AllocOp>(forOp.getLoc(), newMemRefType);

  // Build an AffineMap to remap access functions based on lower bound offsets.
  SmallVector<AffineExpr, 4> remapExprs;
  remapExprs.reserve(rank);
  for (unsigned i = 0; i < rank; i++) {
    auto dimExpr = b.getAffineDimExpr(outerIVs.size() + i);

    auto remapExpr =
        simplifyAffineExpr(dimExpr - offsets[i], outerIVs.size() + rank, 0);
    remapExprs.push_back(remapExpr);
  }

  auto indexRemap =
      AffineMap::get(outerIVs.size() + rank, 0, remapExprs, forOp.getContext());

  // Replace all users of 'oldMemRef' with 'newMemRef'.
  LogicalResult res =
      replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,
                               /*extraOperands=*/outerIVs,
                               /*symbolOperands=*/{},
                               /*domOpFilter=*/&*forOp.getBody()->begin());
  assert(succeeded(res) &&
         "replaceAllMemrefUsesWith should always succeed here");
  (void)res;
  return newMemRef;
}

/// Returns true if there are any non-affine uses of `memref` in any of
/// the operations between `start` and `end` (both exclusive). Any use other
/// than an affine read/write is treated as a non-affine use of `memref`.
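/// For instance (an illustrative sketch), a `call @g(%m)` op sitting between
/// two affine.for nests that both access %m is a non-affine user on the path,
/// and fusing around it could change the values @g observes.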
static bool hasNonAffineUsersOnPath(Operation *start, Operation *end,
                                    Value memref) {
  assert(start->getBlock() == end->getBlock());
  assert(start->isBeforeInBlock(end) && "start expected to be before end");
  Block *block = start->getBlock();
  // Check if there is a non-affine memref user in any op between `start` and
  // `end`.
  return llvm::any_of(memref.getUsers(), [&](Operation *user) {
    if (isa<AffineReadOpInterface, AffineWriteOpInterface>(user))
      return false;
    Operation *ancestor = block->findAncestorOpInBlock(*user);
    return ancestor && start->isBeforeInBlock(ancestor) &&
           ancestor->isBeforeInBlock(end);
  });
}

/// Check whether any memref value used in an operation of 'src' has a
/// non-affine user between `src` and `end` (exclusive of `src` and `end`),
/// where `src` and `end` are expected to be in the same Block. Any use other
/// than an affine read/write is treated as a non-affine use of the memref.
static bool hasNonAffineUsersOnPath(Operation *src, Operation *end) {
  assert(src->getBlock() == end->getBlock() && "same block expected");

  // Trivial case: the range between `src` and `end` is empty.
  if (src == end || end->isBeforeInBlock(src))
    return false;

  // Collect relevant memref values.
  llvm::SmallDenseSet<Value, 2> memRefValues;
  src->walk([&](Operation *op) {
    for (Value v : op->getOperands())
      // Collect memref values only.
      if (isa<MemRefType>(v.getType()))
        memRefValues.insert(v);
    return WalkResult::advance();
  });
  // Look for non-affine users between `src` and `end`.
  return llvm::any_of(memRefValues, [&](Value memref) {
    return hasNonAffineUsersOnPath(src, end, memref);
  });
}

// Checks the profitability of fusing a backwards slice of the loop nest
// surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'.
// The argument 'srcStoreOpInst' is used to calculate the storage reduction on
// the memref being produced and consumed, which is an input to the cost model.
// For producer-consumer fusion, 'srcStoreOpInst' will be the same as
// 'srcOpInst', as we are slicing w.r.t. that producer. For input-reuse
// fusion, 'srcOpInst' will be the src loop nest LoadOp which reads from the
// same memref as dst loop nest load ops, and 'srcStoreOpInst' will be the
// unique store op in the src node, which will be used to check that the write
// region is the same after input-reuse fusion. Computation slices are provided
// in 'depthSliceUnions' for each legal fusion depth. The maximal depth at
// which fusion is legal is provided in 'maxLegalFusionDepth'. Returns true if
// it is profitable to fuse the candidate loop nests. Returns false otherwise.
// `dstLoopDepth` is set to the most profitable depth at which to materialize
// the source loop nest slice.
// The profitability model executes the following steps:
// *) Computes the backward computation slice at 'srcOpInst'. This
//    computation slice of the loop nest surrounding 'srcOpInst' is
//    represented by modified src loop bounds in 'sliceState', which are
//    functions of loop IVs in the loop nest surrounding 'srcOpInst'.
// *) Computes the cost of unfused src/dst loop nests (currently the cost of a
//    loop nest is the total number of dynamic operation instances in the loop
//    nest).
// *) Computes the cost of fusing a slice of the src loop nest into the dst
//    loop nest at various values of dst loop depth, attempting to fuse
//    the largest computation slice at the maximal dst loop depth (closest to
//    the load) to minimize reuse distance and potentially enable subsequent
//    load/store forwarding.
//    NOTE: 'dstLoopDepth' refers to the loop depth within the destination
//    loop nest, at which the src computation slice is inserted/fused.
//    NOTE: We attempt to maximize the dst loop depth, but there are cases
//    where a particular setting for 'dstLoopDepth' might fuse an unsliced
//    loop (within the src computation slice) at a depth which results in
//    excessive recomputation (see unit tests for examples).
// *) Compares the total cost of the unfused loop nests to the min cost fused
//    loop nest computed in the previous step, and returns true if the latter
//    is lower.
// TODO: Extend profitability analysis to support scenarios with multiple
// stores.
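//
// As a rough illustration (hypothetical trip counts): if a 100-iteration
// producer slice is re-executed for each of the 100 iterations of the
// destination's outer loop, the fused nest performs ~100x the producer work;
// the additional compute fraction computed below,
//   fusedLoopNestComputeCost / (srcLoopNestCost + dstLoopNestCost) - 1,
// captures that redundancy and is compared against
// 'computeToleranceThreshold'.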
static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
                               AffineForOp dstForOp,
                               ArrayRef<ComputationSliceState> depthSliceUnions,
                               unsigned maxLegalFusionDepth,
                               unsigned *dstLoopDepth,
                               double computeToleranceThreshold) {
  LLVM_DEBUG({
    llvm::dbgs() << "Checking whether fusion is profitable between src op:\n";
    llvm::dbgs() << ' ' << *srcOpInst << " and destination loop:\n";
    llvm::dbgs() << dstForOp << "\n";
  });

  if (maxLegalFusionDepth == 0) {
    LLVM_DEBUG(llvm::dbgs() << "Can't fuse: maxLegalFusionDepth is 0\n");
    return false;
  }

  // Compute cost of sliced and unsliced src loop nest.
  SmallVector<AffineForOp, 4> srcLoopIVs;
  getAffineForIVs(*srcOpInst, &srcLoopIVs);

  // Walk src loop nest and collect stats.
  LoopNestStats srcLoopNestStats;
  if (!getLoopNestStats(srcLoopIVs[0], &srcLoopNestStats))
    return false;

  // Compute cost of dst loop nest.
  LoopNestStats dstLoopNestStats;
  if (!getLoopNestStats(dstForOp, &dstLoopNestStats))
    return false;

  // Search for min cost value for 'dstLoopDepth'. At each value of
  // 'dstLoopDepth' from 'maxLegalFusionDepth' to '1', compute computation
  // slice bounds between 'srcOpInst' and each op in 'dstOpInsts' (taking the
  // union of these bounds). Next the union slice bounds are used to calculate
  // the cost of the slice and the cost of the slice inserted into the dst
  // loop nest at 'dstLoopDepth'.
  uint64_t minFusedLoopNestComputeCost = std::numeric_limits<uint64_t>::max();
  double maxStorageReduction = 0.0;
  std::optional<uint64_t> sliceMemEstimate;

  // The best loop depth at which to materialize the slice.
  std::optional<unsigned> bestDstLoopDepth;

  // Compute op instance count for the src loop nest without iteration slicing.
  uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], srcLoopNestStats);

  // Compute src loop nest write region size.
  MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc());
  if (failed(srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0))) {
    LLVM_DEBUG(llvm::dbgs()
               << "Unable to compute MemRefRegion for source operation\n");
    return false;
  }

  std::optional<int64_t> maybeSrcWriteRegionSizeBytes =
      srcWriteRegion.getRegionSize();
  if (!maybeSrcWriteRegionSizeBytes.has_value())
    return false;
  int64_t srcWriteRegionSizeBytes = *maybeSrcWriteRegionSizeBytes;

  // Compute op instance count for the dst loop nest.
  uint64_t dstLoopNestCost = getComputeCost(dstForOp, dstLoopNestStats);

  // Evaluate all depth choices for materializing the slice in the destination
  // loop nest.
  for (unsigned i = maxLegalFusionDepth; i >= 1; --i) {
    const ComputationSliceState &slice = depthSliceUnions[i - 1];
    // Skip slice union if it wasn't computed for this depth.
    if (slice.isEmpty())
      continue;

    int64_t fusedLoopNestComputeCost;
    if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstForOp,
                              dstLoopNestStats, slice,
                              &fusedLoopNestComputeCost)) {
      LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost\n");
      continue;
    }

    double additionalComputeFraction =
        fusedLoopNestComputeCost /
            (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
        1;

    // Determine what the slice write MemRefRegion would be, if the src loop
    // nest slice 'slice' were to be inserted into the dst loop nest at loop
    // depth 'i'.
    MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc());
    if (failed(sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0,
                                        &slice))) {
      LLVM_DEBUG(llvm::dbgs()
                 << "Failed to compute slice write region at loopDepth: " << i
                 << "\n");
      continue;
    }

    std::optional<int64_t> maybeSliceWriteRegionSizeBytes =
        sliceWriteRegion.getRegionSize();
    if (!maybeSliceWriteRegionSizeBytes.has_value() ||
        *maybeSliceWriteRegionSizeBytes == 0) {
      LLVM_DEBUG(llvm::dbgs()
                 << "Failed to get slice write region size at loopDepth: " << i
                 << "\n");
      continue;
    }
    int64_t sliceWriteRegionSizeBytes = *maybeSliceWriteRegionSizeBytes;

    // If we are fusing for reuse, check that write regions remain the same.
    // TODO: Write region check should check sizes and offsets in
    // each dimension, so that we are sure they are covering the same memref
    // region. Also, move this out to a isMemRefRegionSuperSet helper function.
    if (srcOpInst != srcStoreOpInst &&
        sliceWriteRegionSizeBytes != srcWriteRegionSizeBytes)
      continue;

    double storageReduction = static_cast<double>(srcWriteRegionSizeBytes) /
                              static_cast<double>(sliceWriteRegionSizeBytes);

    LLVM_DEBUG({
      std::stringstream msg;
      msg << " evaluating fusion profitability at depth : " << i << "\n"
          << std::fixed << std::setprecision(2)
          << " additional compute fraction: "
          << 100.0 * additionalComputeFraction << "%\n"
          << " storage reduction factor: " << storageReduction << "x\n"
          << " fused nest cost: " << fusedLoopNestComputeCost << "\n"
          << " src write region size: " << srcWriteRegionSizeBytes << "\n"
          << " slice write region size: " << sliceWriteRegionSizeBytes
          << "\n";
      llvm::dbgs() << msg.str();
    });

    // TODO: This is a placeholder cost model.
    // Among all choices that add an acceptable amount of redundant computation
    // (as per computeToleranceThreshold), we will simply pick the one that
    // reduces the intermediary size the most.
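    // For example (hypothetical numbers): with a 10% tolerance, a depth that
    // adds 8% redundant compute but shrinks the intermediary 100x is chosen
    // over one that adds 1% but only shrinks it 2x, while a depth adding 20%
    // is rejected outright.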
    if ((storageReduction > maxStorageReduction) &&
        (additionalComputeFraction < computeToleranceThreshold)) {
      maxStorageReduction = storageReduction;
      bestDstLoopDepth = i;
      minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
      sliceMemEstimate = sliceWriteRegionSizeBytes;
    }
  }

  if (!bestDstLoopDepth) {
    LLVM_DEBUG(
        llvm::dbgs()
        << "All fusion choices involve more than the threshold amount of "
           "redundant computation; NOT fusing.\n");
    return false;
  }

  // Set dstLoopDepth based on best values from search.
  *dstLoopDepth = *bestDstLoopDepth;

  LLVM_DEBUG(
      llvm::dbgs() << " LoopFusion fusion stats:"
                   << "\n best loop depth: " << *bestDstLoopDepth
                   << "\n src loop nest compute cost: " << srcLoopNestCost
                   << "\n dst loop nest compute cost: " << dstLoopNestCost
                   << "\n fused loop nest compute cost: "
                   << minFusedLoopNestComputeCost << "\n");

  // A simple cost model: fuse if it reduces the memory footprint.
  auto dstMemSize = getMemoryFootprintBytes(dstForOp);
  auto srcMemSize = getMemoryFootprintBytes(srcLoopIVs[0]);

  std::optional<double> storageReduction;

  if (!dstMemSize || !srcMemSize) {
    LLVM_DEBUG(llvm::dbgs()
               << " fusion memory benefit cannot be evaluated; NOT fusing.\n");
    return false;
  }

  auto srcMemSizeVal = *srcMemSize;
  auto dstMemSizeVal = *dstMemSize;

  assert(sliceMemEstimate && "expected value");
  auto fusedMem = dstMemSizeVal + *sliceMemEstimate;

  LLVM_DEBUG(llvm::dbgs() << " src mem: " << srcMemSizeVal << "\n"
                          << " dst mem: " << dstMemSizeVal << "\n"
                          << " fused mem: " << fusedMem << "\n"
                          << " slice mem: " << *sliceMemEstimate << "\n");

  if (static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) {
    LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n");
    return false;
  }
  storageReduction =
      100.0 *
      (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal));

  double additionalComputeFraction =
      100.0 * (minFusedLoopNestComputeCost /
                   (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
               1);
  (void)additionalComputeFraction;
  LLVM_DEBUG({
    std::stringstream msg;
    msg << " fusion is most profitable at depth " << *dstLoopDepth << " with "
        << std::setprecision(2) << additionalComputeFraction
        << "% redundant computation and a ";
    msg << (storageReduction ? std::to_string(*storageReduction) : "<unknown>");
    msg << "% storage reduction.\n";
    llvm::dbgs() << msg.str();
  });

  return true;
}

namespace {

// GreedyFusion greedily fuses loop nests which have a producer/consumer or
// input-reuse relationship on a memref, with the goal of improving locality.
//
// The steps of the producer-consumer fusion algorithm are as follows:
//
// *) A worklist is initialized with node ids from the dependence graph.
// *) For each node id in the worklist:
//   *) Pop an AffineForOp off the worklist. This 'dstAffineForOp' will be a
//      candidate destination AffineForOp into which fusion will be attempted.
//   *) Add each LoadOp currently in 'dstAffineForOp' into list 'dstLoadOps'.
//   *) For each LoadOp in 'dstLoadOps' do:
//      *) Look up dependent loop nests which have a single store op to the
//         same memref.
//      *) Check if dependences would be violated by the fusion.
//      *) Get a computation slice of 'srcLoopNest', which adjusts its loop
//         bounds to be functions of 'dstLoopNest' IVs and symbols.
//      *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest', at
//         a loop depth determined by the cost model in 'isFusionProfitable'.
//      *) Add the newly fused load/store operations to the state, and also
//         add newly fused load ops to 'dstLoopOps' to be considered as fusion
//         dst load ops in another iteration.
//      *) Remove old src loop nest and its associated state.
//
// The steps of the input-reuse fusion algorithm are as follows:
//
// *) Initialize 'worklist' with node ids from the dependence graph.
// *) For each 'dstNode' in the worklist:
//   *) Find a candidate sibling node 'sibNode' to fuse with 'dstNode' which
//      loads from the same memref, but which has no dependence paths to/from.
//   *) Get a computation slice of 'sibLoopNest', which adjusts its loop
//      bounds to be functions of 'dstLoopNest' IVs and symbols.
//   *) Fuse the 'sibLoopNest' computation slice into the 'dstLoopNest', at a
//      loop depth determined by the cost model in 'isFusionProfitable'. This
//      function also checks that the memref write region of 'sibLoopNest' is
//      preserved in the fused loop nest.
//   *) Update graph state to reflect the fusion of 'sibNode' into 'dstNode'.
//
// Given a graph where top-level operations are vertices in the set 'V' and
// edges in the set 'E' are dependences between vertices, this algorithm
// takes O(V) time for initialization, and has runtime O(V + E).
//
// This greedy algorithm is not 'maximal' due to the current restriction of
// fusing along single producer-consumer edges, but there is a TODO: to fix
// this.
//
// TODO: Experiment with other fusion policies.
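//
// As a small input-reuse illustration (a sketch; shapes are made up), two
// sibling loops that each load %in but have no dependence path between them,
//
//   affine.for %i = 0 to 100 {
//     %x = affine.load %in[%i] : memref<100xf32>
//     affine.store %x, %out0[%i] : memref<100xf32>
//   }
//   affine.for %i = 0 to 100 {
//     %y = affine.load %in[%i] : memref<100xf32>
//     affine.store %y, %out1[%i] : memref<100xf32>
//   }
//
// can be fused into one loop so %in[%i] is re-loaded while still hot in cache.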
struct GreedyFusion {
public:
  // The data dependence graph to traverse during fusion.
  MemRefDependenceGraph *mdg;
  // Worklist of graph nodes visited during the fusion pass.
  SmallVector<unsigned, 8> worklist;
  // Parameter for local buffer size threshold.
  unsigned localBufSizeThreshold;
  // Parameter for fast memory space.
  std::optional<unsigned> fastMemorySpace;
  // If true, ignore any additional (redundant) computation tolerance threshold
  // that would have prevented fusion.
  bool maximalFusion;
  // The amount of additional computation that is tolerated while fusing
  // pair-wise as a fraction of the total computation.
  double computeToleranceThreshold;

  using Node = MemRefDependenceGraph::Node;

  GreedyFusion(MemRefDependenceGraph *mdg, unsigned localBufSizeThreshold,
               std::optional<unsigned> fastMemorySpace, bool maximalFusion,
               double computeToleranceThreshold)
      : mdg(mdg), localBufSizeThreshold(localBufSizeThreshold),
        fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion),
        computeToleranceThreshold(computeToleranceThreshold) {}

  /// Initializes 'worklist' with nodes from 'mdg'.
  void init() {
    // TODO: Add a priority queue for prioritizing nodes by different
    // metrics (e.g. arithmetic intensity/flops-to-bytes ratio).
    worklist.clear();
    for (auto &idAndNode : mdg->nodes) {
      const Node &node = idAndNode.second;
      worklist.push_back(node.id);
    }
  }

  /// Run only sibling fusion on the `mdg`.
  void runSiblingFusionOnly() {
    fuseSiblingNodes();
    eraseUnusedMemRefAllocations();
  }

  /// Run only producer/consumer fusion on the `mdg`.
  void runProducerConsumerFusionOnly() {
    fuseProducerConsumerNodes(
        /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max());
    eraseUnusedMemRefAllocations();
  }

  // Run the GreedyFusion pass.
  // *) First pass through the nodes fuses single-use producer nodes into their
  //    unique consumer.
  // *) Second pass fuses sibling nodes which share no dependence edges.
  // *) Third pass fuses any remaining producer nodes into their users.
  void runGreedyFusion() {
    // TODO: Run this repeatedly until a fixed-point is reached.
    fuseProducerConsumerNodes(/*maxSrcUserCount=*/1);
    fuseSiblingNodes();
    fuseProducerConsumerNodes(
        /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max());
    eraseUnusedMemRefAllocations();
  }

  /// Returns true if a private memref can be created for `memref` given
  /// the fusion scenario reflected by the other arguments.
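  /// For instance (an illustrative case): if the producer writes an escaping
  /// memref and is removed after fusion, privatizing that memref would
  /// redirect all writes to the private copy and leave the escaping memref
  /// with stale contents, so privatization must be rejected.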
  bool canCreatePrivateMemRef(Value memref,
                              const DenseSet<Value> &srcEscapingMemRefs,
                              unsigned producerId, unsigned consumerId,
                              bool removeSrcNode) {
    const Node *consumerNode = mdg->getNode(consumerId);
    // If `memref` is an escaping one, do not create a private memref
    // for the below scenarios, since doing so will leave the escaping
    // memref unmodified as all the writes originally meant for the
    // escaping memref would be performed on the private memref:
    // 1. The source is to be removed after fusion,
    // OR
    // 2. The destination writes to `memref`.
    if (srcEscapingMemRefs.count(memref) > 0 &&
        (removeSrcNode || consumerNode->getStoreOpCount(memref) > 0))
      return false;

    // Don't create a private memref if 'srcNode' has in edges on
    // 'memref' or 'dstNode' has out edges on 'memref'.
    if (mdg->getIncomingMemRefAccesses(producerId, memref) > 0 ||
        mdg->getOutEdgeCount(consumerId, memref) > 0)
      return false;

    // If 'srcNode' will be removed but it has out edges on 'memref' to
    // nodes other than 'dstNode', we have to preserve dependences and
    // cannot create a private memref.
    if (removeSrcNode &&
        any_of(mdg->outEdges[producerId], [&](const auto &edge) {
          return edge.value == memref && edge.id != consumerId;
        }))
      return false;

    return true;
  }

  /// Perform fusions with node `dstId` as the destination of fusion. No
  /// fusion is performed when a producer has a user count greater than
  /// `maxSrcUserCount` for any of the memrefs involved.
  void performFusionsIntoDest(unsigned dstId, unsigned maxSrcUserCount) {
    LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n");
    // Skip if this node was removed (fused into another node).
    if (mdg->nodes.count(dstId) == 0)
      return;
    // Get 'dstNode' into which to attempt fusion.
    auto *dstNode = mdg->getNode(dstId);
    // Skip if 'dstNode' is not a loop nest.
    if (!isa<AffineForOp>(dstNode->op))
      return;
    // Skip if 'dstNode' is a loop nest returning values.
    // TODO: support loop nests that return values.
    if (dstNode->op->getNumResults() > 0)
      return;

    // Sink sequential loops in 'dstNode' (and thus raise parallel loops)
    // while preserving relative order. This can increase the maximum loop
    // depth at which we can fuse a slice of a producer loop nest into a
    // consumer loop nest.
    sinkSequentialLoops(dstNode);
    auto dstAffineForOp = cast<AffineForOp>(dstNode->op);

    // Try to fuse 'dstNode' with candidate producer loops until a fixed point
    // is reached. Fusing two loops may expose new fusion opportunities.
    bool dstNodeChanged;
    do {
      // Gather src loop candidates for 'dstNode' and visit them in "quasi"
      // reverse program order to minimize the number of iterations needed to
      // reach the fixed point. Note that this is a best effort approach since
      // 'getProducerCandidates' does not always guarantee program order in
      // 'srcIdCandidates'.
      dstNodeChanged = false;
      SmallVector<unsigned, 16> srcIdCandidates;
      getProducerCandidates(dstId, mdg, srcIdCandidates);

      for (unsigned srcId : llvm::reverse(srcIdCandidates)) {
        // Get 'srcNode' from which to attempt fusion into 'dstNode'.
        auto *srcNode = mdg->getNode(srcId);
        auto srcAffineForOp = cast<AffineForOp>(srcNode->op);
        LLVM_DEBUG(llvm::dbgs() << "Evaluating src loop " << srcId
                                << " for dst loop " << dstId << "\n");

        // Skip if 'srcNode' is a loop nest returning values.
        // TODO: support loop nests that return values.
        if (isa<AffineForOp>(srcNode->op) && srcNode->op->getNumResults() > 0)
          continue;

        DenseSet<Value> producerConsumerMemrefs;
        gatherProducerConsumerMemrefs(srcId, dstId, mdg,
                                      producerConsumerMemrefs);

        // Skip if 'srcNode' out edge count on any memref is greater than
        // 'maxSrcUserCount'.
        if (any_of(producerConsumerMemrefs, [&](Value memref) {
              return mdg->getOutEdgeCount(srcNode->id, memref) >
                     maxSrcUserCount;
            }))
          continue;

        // Gather memrefs in 'srcNode' that are written and escape out of the
        // block (e.g., memref block arguments, returned memrefs,
        // memrefs passed to function calls, etc.).
        DenseSet<Value> srcEscapingMemRefs;
        gatherEscapingMemrefs(srcNode->id, mdg, srcEscapingMemRefs);

        // Skip if there are non-affine operations in between the 'srcNode'
        // and 'dstNode' using their memrefs. If so, we wouldn't be able to
        // compute a legal insertion point for now. 'srcNode' and 'dstNode'
        // memrefs with non-affine operation users would be considered
        // escaping memrefs, so we can limit this check to only scenarios with
        // escaping memrefs.
        if (!srcEscapingMemRefs.empty() &&
            hasNonAffineUsersOnPath(srcNode->op, dstNode->op)) {
          LLVM_DEBUG(llvm::dbgs()
                     << "Can't fuse: non-affine users in between the loops\n");
          continue;
        }

        // Compute an operation list insertion point for the fused loop
        // nest which preserves dependences.
        Operation *fusedLoopInsPoint =
            mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id);
        if (fusedLoopInsPoint == nullptr)
          continue;

        // It's possible this fusion is at an inner depth (i.e., there are
        // common surrounding affine loops for the source and destination
        // 'for' ops). We need to get this number because the call to
        // canFuseLoops needs to be passed the absolute depth. The max legal
        // depth and the depths we try below are however *relative* and as
        // such don't include the common depth.
        SmallVector<AffineForOp, 4> surroundingLoops;
        getAffineForIVs(*dstAffineForOp, &surroundingLoops);
        unsigned numSurroundingLoops = surroundingLoops.size();

        // Compute the innermost common loop depth for dstNode
        // producer-consumer loads/stores.
        SmallVector<Operation *, 2> dstMemrefOps;
        for (Operation *op : dstNode->loads)
          if (producerConsumerMemrefs.count(
                  cast<AffineReadOpInterface>(op).getMemRef()) > 0)
            dstMemrefOps.push_back(op);
        for (Operation *op : dstNode->stores)
          if (producerConsumerMemrefs.count(
                  cast<AffineWriteOpInterface>(op).getMemRef()))
            dstMemrefOps.push_back(op);
        unsigned dstLoopDepthTest =
            getInnermostCommonLoopDepth(dstMemrefOps) - numSurroundingLoops;

        // Check the feasibility of fusing src loop nest into dst loop nest
        // at loop depths in range [1, dstLoopDepthTest].
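        // For example (hypothetical nesting): with one common surrounding
        // loop and dstLoopDepthTest == 2, the relative depths 1 and 2 tried
        // below correspond to absolute depths 2 and 3 passed to canFuseLoops.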
        unsigned maxLegalFusionDepth = 0;
        SmallVector<ComputationSliceState, 8> depthSliceUnions;
        depthSliceUnions.resize(dstLoopDepthTest);
        FusionStrategy strategy(FusionStrategy::ProducerConsumer);
        for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
          FusionResult result =
              affine::canFuseLoops(srcAffineForOp, dstAffineForOp,
                                   /*dstLoopDepth=*/i + numSurroundingLoops,
                                   &depthSliceUnions[i - 1], strategy);

          if (result.value == FusionResult::Success)
            maxLegalFusionDepth = i;
        }

        if (maxLegalFusionDepth == 0) {
          LLVM_DEBUG(llvm::dbgs()
                     << "Can't fuse: fusion is not legal at any depth\n");
          continue;
        }

        // Check if fusion would be profitable. We skip profitability analysis
        // for maximal fusion since we already know the maximal legal depth to
        // fuse.
        unsigned bestDstLoopDepth = maxLegalFusionDepth;
        if (!maximalFusion) {
          // Retrieve producer stores from the src loop.
          SmallVector<Operation *, 2> producerStores;
          for (Operation *op : srcNode->stores)
            if (producerConsumerMemrefs.count(
                    cast<AffineWriteOpInterface>(op).getMemRef()))
              producerStores.push_back(op);

          // TODO: Support multiple producer stores in profitability
          // analysis. We limit profitability analysis to only scenarios with
          // a single producer store for now. Note that some multi-store
          // producer scenarios will still go through profitability analysis
          // if only one of the stores is involved in the producer-consumer
          // relationship of the candidate loops.
          assert(!producerStores.empty() && "Expected producer store");
          if (producerStores.size() > 1)
            LLVM_DEBUG(llvm::dbgs() << "Skipping profitability analysis. Not "
                                       "supported for this case\n");
          else if (!isFusionProfitable(producerStores[0], producerStores[0],
                                       dstAffineForOp, depthSliceUnions,
                                       maxLegalFusionDepth, &bestDstLoopDepth,
                                       computeToleranceThreshold))
            continue;
        }

        assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
        ComputationSliceState &bestSlice =
            depthSliceUnions[bestDstLoopDepth - 1];
        assert(!bestSlice.isEmpty() && "Missing slice union for depth");

        // Determine if 'srcId' can be removed after fusion, taking into
        // account remaining dependences, escaping memrefs and the fusion
        // insertion point.
        bool removeSrcNode = canRemoveSrcNodeAfterFusion(
            srcId, dstId, bestSlice, fusedLoopInsPoint, srcEscapingMemRefs,
            mdg);

        DenseSet<Value> privateMemrefs;
        for (Value memref : producerConsumerMemrefs) {
          if (canCreatePrivateMemRef(memref, srcEscapingMemRefs, srcId, dstId,
                                     removeSrcNode)) {
            // Create a private version of this memref.
            LLVM_DEBUG(llvm::dbgs()
                       << "Creating private memref for " << memref << '\n');
            privateMemrefs.insert(memref);
          }
        }

        // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
        fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice);
        dstNodeChanged = true;

        LLVM_DEBUG(llvm::dbgs()
                   << "Fused src loop " << srcId << " into dst loop " << dstId
                   << " at depth " << bestDstLoopDepth << ":\n"
                   << dstAffineForOp << "\n");

        // Move 'dstAffineForOp' before 'insertPointInst' if needed.
        if (fusedLoopInsPoint != dstAffineForOp)
          dstAffineForOp->moveBefore(fusedLoopInsPoint);

        // Update edges between 'srcNode' and 'dstNode'.
        mdg->updateEdges(srcNode->id, dstNode->id, privateMemrefs,
                         removeSrcNode);

        // Create private memrefs.
        if (!privateMemrefs.empty()) {
          // Gather stores for all the private-to-be memrefs.
          DenseMap<Value, SmallVector<Operation *, 4>> privateMemRefToStores;
          dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) {
            Value storeMemRef = storeOp.getMemRef();
            if (privateMemrefs.count(storeMemRef) > 0)
              privateMemRefToStores[storeMemRef].push_back(storeOp);
          });

          // Replace original memrefs with private memrefs. Note that all the
          // loads and stores on these memrefs will be replaced with new
          // loads and stores. Any reference to the original ones becomes
          // invalid after this point.
          for (auto &memrefToStoresPair : privateMemRefToStores) {
            // TODO: Use union of memref write regions to compute
            // private memref footprint.
            SmallVector<Operation *, 4> &storesForMemref =
                memrefToStoresPair.second;
            Value newMemRef = createPrivateMemRef(
                dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
                fastMemorySpace, localBufSizeThreshold);
            // Create new node in dependence graph for 'newMemRef' alloc op.
            unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp());
            // Add edge from 'newMemRef' node to dstNode.
            mdg->addEdge(newMemRefNodeId, dstId, newMemRef);
          }
          // One or more entries for 'newMemRef' alloc op are inserted into
          // the DenseMap mdg->nodes. Since an insertion may cause DenseMap to
          // reallocate, update dstNode.
          dstNode = mdg->getNode(dstId);
        }

        // Collect dst loop stats after memref privatization transformation.
        LoopNestStateCollector dstLoopCollector;
        dstLoopCollector.collect(dstAffineForOp);

        // Clear and add back loads and stores.
        mdg->clearNodeLoadAndStores(dstNode->id);
        mdg->addToNode(dstId, dstLoopCollector.loadOpInsts,
                       dstLoopCollector.storeOpInsts);

        if (removeSrcNode) {
          LLVM_DEBUG(llvm::dbgs()
                     << "Removing src loop " << srcId << " after fusion\n");
          // srcNode is no longer valid after it is removed from mdg.
          srcAffineForOp.erase();
          mdg->removeNode(srcId);
          srcNode = nullptr;
        }
      }
    } while (dstNodeChanged);
  }

  /// Visit each node in the graph, and for each node, attempt to fuse it with
  /// producer-consumer candidates. No fusion is performed when producers with
  /// a user count greater than `maxSrcUserCount` for any of the memrefs
  /// involved are encountered.
  void fuseProducerConsumerNodes(unsigned maxSrcUserCount) {
    LLVM_DEBUG(llvm::dbgs() << "--- Producer/Consumer Fusion ---\n");
    init();
    while (!worklist.empty()) {
      unsigned dstId = worklist.back();
      worklist.pop_back();
      performFusionsIntoDest(dstId, maxSrcUserCount);
    }
  }

  // Visits each node in the graph, and for each node, attempts to fuse it
  // with its sibling nodes (nodes which share a parent, but no dependence
  // edges).
  void fuseSiblingNodes() {
    LLVM_DEBUG(llvm::dbgs() << "--- Sibling Fusion ---\n");
    init();
    while (!worklist.empty()) {
      unsigned dstId = worklist.back();
      worklist.pop_back();

      // Skip if this node was removed (fused into another node).
      if (mdg->nodes.count(dstId) == 0)
        continue;
      // Get 'dstNode' into which to attempt fusion.
      auto *dstNode = mdg->getNode(dstId);
      // Skip if 'dstNode' is not a loop nest.
      if (!isa<AffineForOp>(dstNode->op))
        continue;
      // Attempt to fuse 'dstNode' with its sibling nodes in the graph.
      fuseWithSiblingNodes(dstNode);
    }
  }

  // Attempt to fuse 'dstNode' with sibling nodes in the graph.
  void fuseWithSiblingNodes(Node *dstNode) {
    DenseSet<unsigned> visitedSibNodeIds;
    std::pair<unsigned, Value> idAndMemref;
    auto dstAffineForOp = cast<AffineForOp>(dstNode->op);

    while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) {
      unsigned sibId = idAndMemref.first;
      Value memref = idAndMemref.second;
      // TODO: Check that 'sibStoreOpInst' post-dominates all other
      // stores to the same memref in 'sibNode' loop nest.
      auto *sibNode = mdg->getNode(sibId);
      // Compute an operation list insertion point for the fused loop
      // nest which preserves dependences.
      assert(sibNode->op->getBlock() == dstNode->op->getBlock());
      Operation *insertPointInst =
          sibNode->op->isBeforeInBlock(dstNode->op)
              ? mdg->getFusedLoopNestInsertionPoint(sibNode->id, dstNode->id)
              : mdg->getFusedLoopNestInsertionPoint(dstNode->id, sibNode->id);
      if (insertPointInst == nullptr)
        continue;

      // Get unique 'sibNode' load op to 'memref'.
      SmallVector<Operation *, 2> sibLoadOpInsts;
      sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts);
      // Currently findSiblingNodeToFuse searches for siblings with one load.
      assert(sibLoadOpInsts.size() == 1);
      Operation *sibLoadOpInst = sibLoadOpInsts[0];

      // Gather 'dstNode' load ops to 'memref'.
      SmallVector<Operation *, 2> dstLoadOpInsts;
      dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts);

      // It's possible this fusion is at an inner depth (i.e., there are
      // common surrounding affine loops for the source and destination 'for'
      // ops). We need to get this number because the call to canFuseLoops
      // needs to be passed the absolute depth. The max legal depth and the
      // depths we try below are however *relative* and as such don't include
      // the common depth.
      SmallVector<AffineForOp, 4> surroundingLoops;
      getAffineForIVs(*dstAffineForOp, &surroundingLoops);
      unsigned numSurroundingLoops = surroundingLoops.size();
      SmallVector<AffineForOp, 4> dstLoopIVs;
      getAffineForIVs(*dstLoadOpInsts[0], &dstLoopIVs);
      unsigned dstLoopDepthTest = dstLoopIVs.size() - numSurroundingLoops;
      auto sibAffineForOp = cast<AffineForOp>(sibNode->op);

      // Compute loop depth and slice union for fusion.
      SmallVector<ComputationSliceState, 8> depthSliceUnions;
      depthSliceUnions.resize(dstLoopDepthTest);
      unsigned maxLegalFusionDepth = 0;
      FusionStrategy strategy(memref);
      for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
        FusionResult result =
            affine::canFuseLoops(sibAffineForOp, dstAffineForOp,
                                 /*dstLoopDepth=*/i + numSurroundingLoops,
                                 &depthSliceUnions[i - 1], strategy);

        if (result.value == FusionResult::Success)
          maxLegalFusionDepth = i;
      }

      LLVM_DEBUG(llvm::dbgs() << "Max legal depth for fusion: "
                              << maxLegalFusionDepth << '\n');

      // Skip if fusion is not feasible at any loop depths.
      if (maxLegalFusionDepth == 0)
        continue;

      unsigned bestDstLoopDepth = maxLegalFusionDepth;
      if (!maximalFusion) {
        // Check if fusion would be profitable. For sibling fusion, the sibling
        // load op is treated as the src "store" op for fusion profitability
        // purposes. The footprint of the load in the slice, relative to the
        // footprint of the unfused source, determines the reuse.
        if (!isFusionProfitable(sibLoadOpInst, sibLoadOpInst, dstAffineForOp,
                                depthSliceUnions, maxLegalFusionDepth,
                                &bestDstLoopDepth, computeToleranceThreshold))
          continue;
      }

      assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
      assert(!depthSliceUnions[bestDstLoopDepth - 1].isEmpty() &&
             "Fusion depth has no computed slice union");
      // Check if source loop is being inserted in the innermost
      // destination loop. Based on this, the fused loop may be optimized
      // further inside `fuseLoops`.
      bool isInnermostInsertion = (bestDstLoopDepth == dstLoopDepthTest);
      // Fuse computation slice of 'sibLoopNest' into 'dstLoopNest'.
      affine::fuseLoops(sibAffineForOp, dstAffineForOp,
                        depthSliceUnions[bestDstLoopDepth - 1],
                        isInnermostInsertion);

      auto dstForInst = cast<AffineForOp>(dstNode->op);
      // Update operation position of fused loop nest (if needed).
      if (insertPointInst != dstForInst) {
        dstForInst->moveBefore(insertPointInst);
      }
      // Update data dependence graph state post fusion.
      updateStateAfterSiblingFusion(sibNode, dstNode);
    }
  }

  // Searches block argument uses and the graph from 'dstNode' looking for a
  // fusion candidate sibling node which shares no dependences with 'dstNode'
  // but which loads from the same memref. Returns true and sets
  // 'idAndMemrefToFuse' on success. Returns false otherwise.
  bool findSiblingNodeToFuse(Node *dstNode,
                             DenseSet<unsigned> *visitedSibNodeIds,
                             std::pair<unsigned, Value> *idAndMemrefToFuse) {
    // Returns true if 'sibNode' can be fused with 'dstNode' for input reuse
    // on 'memref'.
    auto canFuseWithSibNode = [&](Node *sibNode, Value memref) {
      // Skip if 'sibNode' doesn't have a single load op to 'memref'.
      // TODO: Remove the single load op restriction.
      if (sibNode->getLoadOpCount(memref) != 1)
        return false;
      // Skip if there exists a path of dependent edges between
      // 'sibNode' and 'dstNode'.
      if (mdg->hasDependencePath(sibNode->id, dstNode->id) ||
          mdg->hasDependencePath(dstNode->id, sibNode->id))
        return false;
      // Skip sib node if it loads from (and stores to) the same memref on
      // which it also has an input dependence edge.
      DenseSet<Value> loadAndStoreMemrefSet;
      sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet);
      if (llvm::any_of(loadAndStoreMemrefSet, [=](Value memref) {
            return mdg->getIncomingMemRefAccesses(sibNode->id, memref) > 0;
          }))
        return false;

      // Check that all stores are to the same memref, if any.
      DenseSet<Value> storeMemrefs;
      for (auto *storeOpInst : sibNode->stores) {
        storeMemrefs.insert(
            cast<AffineWriteOpInterface>(storeOpInst).getMemRef());
      }
      if (storeMemrefs.size() > 1)
        return false;

      // Skip if a memref value in one node is used by a non-affine memref
      // access that lies between 'dstNode' and 'sibNode'.
      if (hasNonAffineUsersOnPath(dstNode->op, sibNode->op) ||
          hasNonAffineUsersOnPath(sibNode->op, dstNode->op))
        return false;
      return true;
    };

    // Search for siblings which load the same memref block argument.
    Block *block = dstNode->op->getBlock();
    for (unsigned i = 0, e = block->getNumArguments(); i != e; ++i) {
      for (Operation *user : block->getArgument(i).getUsers()) {
        auto loadOp = dyn_cast<AffineReadOpInterface>(user);
        if (!loadOp)
          continue;
        // Gather loops surrounding 'use'.
        SmallVector<AffineForOp, 4> loops;
        getAffineForIVs(*user, &loops);
        // Find the surrounding affine.for nested immediately within the
        // block.
        auto *it = llvm::find_if(loops, [&](AffineForOp loop) {
          return loop->getBlock() == &mdg->block;
        });
        // Skip 'use' if it is not within a loop nest in `block`.
        if (it == loops.end())
          continue;
        Node *sibNode = mdg->getForOpNode(*it);
        assert(sibNode != nullptr);
        // Skip 'use' if it is not a sibling to 'dstNode'.
        if (sibNode->id == dstNode->id)
          continue;
        // Skip 'use' if it has been visited.
        if (visitedSibNodeIds->count(sibNode->id) > 0)
          continue;
        // Skip 'use' if it does not load from the same memref as 'dstNode'.
        auto memref = loadOp.getMemRef();
        if (dstNode->getLoadOpCount(memref) == 0)
          continue;
        // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
        if (canFuseWithSibNode(sibNode, memref)) {
          visitedSibNodeIds->insert(sibNode->id);
          idAndMemrefToFuse->first = sibNode->id;
          idAndMemrefToFuse->second = memref;
          return true;
        }
      }
    }

    // Search for siblings by following edges through an intermediate src
    // node. Collect candidate 'dstNode' input edges in 'inEdges'.
    SmallVector<MemRefDependenceGraph::Edge, 2> inEdges;
    mdg->forEachMemRefInputEdge(
        dstNode->id, [&](MemRefDependenceGraph::Edge inEdge) {
          // Add 'inEdge' if it is a read-after-write dependence.
          if (dstNode->getLoadOpCount(inEdge.value) > 0 &&
              mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0)
            inEdges.push_back(inEdge);
        });

    // Search for sibling nodes to fuse by visiting output edges from each
    // input edge in 'inEdges'.
    for (auto &inEdge : inEdges) {
      // Collect candidate output edges from each node 'inEdge.id' in
      // 'inEdges'.
      SmallVector<MemRefDependenceGraph::Edge, 2> outEdges;
      mdg->forEachMemRefOutputEdge(
          inEdge.id, [&](MemRefDependenceGraph::Edge outEdge) {
            unsigned sibNodeId = outEdge.id;
            if (visitedSibNodeIds->count(sibNodeId) > 0)
              return;
            // Skip output edge if not a sibling using the same memref.
            if (outEdge.id == dstNode->id || outEdge.value != inEdge.value)
              return;
            auto *sibNode = mdg->getNode(sibNodeId);
            if (!isa<AffineForOp>(sibNode->op))
              return;
            // Check if 'sibNode/dstNode' can be input-reuse fused on
            // 'memref'.
            if (canFuseWithSibNode(sibNode, outEdge.value)) {
              // Add candidate 'outEdge' to sibling node.
              outEdges.push_back(outEdge);
            }
          });

      // Add first candidate if any were returned.
      if (!outEdges.empty()) {
        visitedSibNodeIds->insert(outEdges[0].id);
        idAndMemrefToFuse->first = outEdges[0].id;
        idAndMemrefToFuse->second = outEdges[0].value;
        return true;
      }
    }
    return false;
  }

  /// Update data dependence graph state to reflect sibling fusion of
  /// 'sibNode' into 'dstNode'.
  void updateStateAfterSiblingFusion(Node *sibNode, Node *dstNode) {
    // Update 'sibNode' and 'dstNode' input/output edges to reflect fusion.
    mdg->updateEdges(sibNode->id, dstNode->id);

    // Collect dst loop stats after memref privatization transformation.
    auto dstForInst = cast<AffineForOp>(dstNode->op);
    LoopNestStateCollector dstLoopCollector;
    dstLoopCollector.collect(dstForInst);
    // Clear and add back loads and stores.
    mdg->clearNodeLoadAndStores(dstNode->id);
    mdg->addToNode(dstNode->id, dstLoopCollector.loadOpInsts,
                   dstLoopCollector.storeOpInsts);
    // Remove old sibling loop nest if it no longer has outgoing dependence
    // edges, and it does not write to a memref which escapes the block.
    if (mdg->getOutEdgeCount(sibNode->id) == 0) {
      Operation *op = sibNode->op;
      mdg->removeNode(sibNode->id);
      op->erase();
    }
  }

  // Clean up any allocs with no users.
  void eraseUnusedMemRefAllocations() {
    for (auto &pair : mdg->memrefEdgeCount) {
      if (pair.second > 0)
        continue;
      auto memref = pair.first;
      // Skip if there exist other uses (return operation or function calls).
      if (!memref.use_empty())
        continue;
      // Use list expected to match the dep graph info.
      auto *op = memref.getDefiningOp();
      if (isa_and_nonnull<memref::AllocOp>(op))
        op->erase();
    }
  }
};

} // namespace

/// Run fusion on `block`.
void LoopFusion::runOnBlock(Block *block) {
  MemRefDependenceGraph g(*block);
  if (!g.init()) {
    LLVM_DEBUG(llvm::dbgs() << "MDG init failed\n");
    return;
  }

  std::optional<unsigned> fastMemorySpaceOpt;
  if (fastMemorySpace.hasValue())
    fastMemorySpaceOpt = fastMemorySpace;
  unsigned localBufSizeThresholdBytes = localBufSizeThreshold * 1024;
  GreedyFusion fusion(&g, localBufSizeThresholdBytes, fastMemorySpaceOpt,
                      maximalFusion, computeToleranceThreshold);

  if (affineFusionMode == FusionMode::ProducerConsumer)
    fusion.runProducerConsumerFusionOnly();
  else if (affineFusionMode == FusionMode::Sibling)
    fusion.runSiblingFusionOnly();
  else
    fusion.runGreedyFusion();
}

void LoopFusion::runOnOperation() {
  // Call fusion on every block that has at least two affine.for nests (in
  // post order).
  getOperation()->walk([&](Operation *op) {
    for (Region &region : op->getRegions()) {
      for (Block &block : region.getBlocks()) {
        auto affineFors = block.getOps<AffineForOp>();
        if (!affineFors.empty() && !llvm::hasSingleElement(affineFors))
          runOnBlock(&block);
      }
    }
  });
}

std::unique_ptr<Pass> mlir::affine::createLoopFusionPass(
    unsigned fastMemorySpace, uint64_t localBufSizeThreshold,
    bool maximalFusion, enum FusionMode affineFusionMode) {
  return std::make_unique<LoopFusion>(fastMemorySpace, localBufSizeThreshold,
                                      maximalFusion, affineFusionMode);
}