1 //===- LoopFusion.cpp - Code to perform loop fusion -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements affine fusion. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/Passes.h" 14 15 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h" 16 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h" 17 #include "mlir/Dialect/Affine/Analysis/Utils.h" 18 #include "mlir/Dialect/Affine/IR/AffineOps.h" 19 #include "mlir/Dialect/Affine/LoopFusionUtils.h" 20 #include "mlir/Dialect/Affine/LoopUtils.h" 21 #include "mlir/Dialect/Affine/Utils.h" 22 #include "mlir/Dialect/MemRef/IR/MemRef.h" 23 #include "mlir/IR/AffineExpr.h" 24 #include "mlir/IR/AffineMap.h" 25 #include "mlir/IR/Builders.h" 26 #include "mlir/Transforms/Passes.h" 27 #include "llvm/ADT/DenseMap.h" 28 #include "llvm/ADT/DenseSet.h" 29 #include "llvm/ADT/STLExtras.h" 30 #include "llvm/ADT/SetVector.h" 31 #include "llvm/Support/CommandLine.h" 32 #include "llvm/Support/Debug.h" 33 #include "llvm/Support/raw_ostream.h" 34 #include <iomanip> 35 #include <optional> 36 #include <sstream> 37 38 namespace mlir { 39 namespace affine { 40 #define GEN_PASS_DEF_AFFINELOOPFUSION 41 #include "mlir/Dialect/Affine/Passes.h.inc" 42 } // namespace affine 43 } // namespace mlir 44 45 #define DEBUG_TYPE "affine-loop-fusion" 46 47 using namespace mlir; 48 using namespace mlir::affine; 49 50 namespace { 51 /// Loop fusion pass. This pass currently supports a greedy fusion policy, 52 /// which fuses loop nests with single-writer/single-reader memref dependences 53 /// with the goal of improving locality. 
54 55 // TODO: Support fusion of source loop nests which write to multiple 56 // memrefs, where each memref can have multiple users (if profitable). 57 // TODO: Extend this pass to check for fusion preventing dependences, 58 // and add support for more general loop fusion algorithms. 59 60 struct LoopFusion : public affine::impl::AffineLoopFusionBase<LoopFusion> { 61 LoopFusion() = default; 62 LoopFusion(unsigned fastMemorySpace, uint64_t localBufSizeThresholdBytes, 63 bool maximalFusion, enum FusionMode affineFusionMode) { 64 this->fastMemorySpace = fastMemorySpace; 65 this->localBufSizeThreshold = localBufSizeThresholdBytes / 1024; 66 this->maximalFusion = maximalFusion; 67 this->affineFusionMode = affineFusionMode; 68 } 69 70 void runOnBlock(Block *block); 71 void runOnOperation() override; 72 }; 73 74 } // namespace 75 76 /// Returns true if node 'srcId' can be removed after fusing it with node 77 /// 'dstId'. The node can be removed if any of the following conditions are met: 78 /// 1. 'srcId' has no output dependences after fusion and no escaping memrefs. 79 /// 2. 'srcId' has no output dependences after fusion, has escaping memrefs 80 /// and the fusion slice is maximal. 81 /// 3. 'srcId' has output dependences after fusion, the fusion slice is 82 /// maximal and the fusion insertion point dominates all the dependences. 83 static bool canRemoveSrcNodeAfterFusion( 84 unsigned srcId, unsigned dstId, const ComputationSliceState &fusionSlice, 85 Operation *fusedLoopInsPoint, const DenseSet<Value> &escapingMemRefs, 86 MemRefDependenceGraph *mdg) { 87 88 Operation *dstNodeOp = mdg->getNode(dstId)->op; 89 bool hasOutDepsAfterFusion = false; 90 91 for (auto &outEdge : mdg->outEdges[srcId]) { 92 Operation *depNodeOp = mdg->getNode(outEdge.id)->op; 93 // Skip dependence with dstOp since it will be removed after fusion. 94 if (depNodeOp == dstNodeOp) 95 continue; 96 97 // Only fusion within the same block is supported. Use domination analysis 98 // when needed. 
99 if (depNodeOp->getBlock() != dstNodeOp->getBlock()) 100 return false; 101 102 // Check if the insertion point of the fused loop dominates the dependence. 103 // Otherwise, the src loop can't be removed. 104 if (fusedLoopInsPoint != depNodeOp && 105 !fusedLoopInsPoint->isBeforeInBlock(depNodeOp)) { 106 LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: dst loop doesn't " 107 "dominate dependence\n"); 108 return false; 109 } 110 111 hasOutDepsAfterFusion = true; 112 } 113 114 // If src loop has dependences after fusion or it writes to an live-out or 115 // escaping memref, we can only remove it if the fusion slice is maximal so 116 // that all the dependences are preserved. 117 if (hasOutDepsAfterFusion || !escapingMemRefs.empty()) { 118 std::optional<bool> isMaximal = fusionSlice.isMaximal(); 119 if (!isMaximal) { 120 LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: can't determine " 121 "if fusion is maximal\n"); 122 return false; 123 } 124 125 if (!*isMaximal) { 126 LLVM_DEBUG(llvm::dbgs() 127 << "Src loop can't be removed: fusion is not maximal\n"); 128 return false; 129 } 130 } 131 132 return true; 133 } 134 135 /// Returns in 'srcIdCandidates' the producer fusion candidates for consumer 136 /// 'dstId'. Candidates are sorted by node id order. This order corresponds to 137 /// the program order when the 'mdg' is created. However, program order is not 138 /// guaranteed and must not be required by the client. Program order won't be 139 /// held if the 'mdg' is reused from a previous fusion step or if the node 140 /// creation order changes in the future to support more advance cases. 141 // TODO: Move this to a loop fusion utility once 'mdg' is also moved. 142 static void getProducerCandidates(unsigned dstId, MemRefDependenceGraph *mdg, 143 SmallVectorImpl<unsigned> &srcIdCandidates) { 144 // Skip if no input edges along which to fuse. 145 if (mdg->inEdges.count(dstId) == 0) 146 return; 147 148 // Gather memrefs from loads in 'dstId'. 
149 auto *dstNode = mdg->getNode(dstId); 150 DenseSet<Value> consumedMemrefs; 151 for (Operation *load : dstNode->loads) 152 consumedMemrefs.insert(cast<AffineReadOpInterface>(load).getMemRef()); 153 154 // Traverse 'dstId' incoming edges and gather the nodes that contain a store 155 // to one of the consumed memrefs. 156 for (auto &srcEdge : mdg->inEdges[dstId]) { 157 auto *srcNode = mdg->getNode(srcEdge.id); 158 // Skip if 'srcNode' is not a loop nest. 159 if (!isa<AffineForOp>(srcNode->op)) 160 continue; 161 162 if (any_of(srcNode->stores, [&](Operation *op) { 163 auto storeOp = cast<AffineWriteOpInterface>(op); 164 return consumedMemrefs.count(storeOp.getMemRef()) > 0; 165 })) 166 srcIdCandidates.push_back(srcNode->id); 167 } 168 169 llvm::sort(srcIdCandidates); 170 srcIdCandidates.erase( 171 std::unique(srcIdCandidates.begin(), srcIdCandidates.end()), 172 srcIdCandidates.end()); 173 } 174 175 /// Returns in 'producerConsumerMemrefs' the memrefs involved in a 176 /// producer-consumer dependence between 'srcId' and 'dstId'. 177 static void 178 gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId, 179 MemRefDependenceGraph *mdg, 180 DenseSet<Value> &producerConsumerMemrefs) { 181 auto *dstNode = mdg->getNode(dstId); 182 auto *srcNode = mdg->getNode(srcId); 183 gatherProducerConsumerMemrefs(srcNode->stores, dstNode->loads, 184 producerConsumerMemrefs); 185 } 186 187 /// A memref escapes in the context of the fusion pass if either: 188 /// 1. it (or its alias) is a block argument, or 189 /// 2. created by an op not known to guarantee alias freedom, 190 /// 3. it (or its alias) are used by ops other than affine dereferencing ops 191 /// (e.g., by call op, memref load/store ops, alias creating ops, unknown ops, 192 /// terminator ops, etc.); such ops do not deference the memref in an affine 193 /// way. 194 static bool isEscapingMemref(Value memref, Block *block) { 195 Operation *defOp = memref.getDefiningOp(); 196 // Check if 'memref' is a block argument. 
197 if (!defOp) 198 return true; 199 200 // Check if this is defined to be an alias of another memref. 201 if (auto viewOp = dyn_cast<mlir::ViewLikeOpInterface>(defOp)) 202 if (isEscapingMemref(viewOp.getViewSource(), block)) 203 return true; 204 205 // Any op besides allocating ops wouldn't guarantee alias freedom 206 if (!hasSingleEffect<mlir::MemoryEffects::Allocate>(defOp, memref)) 207 return true; 208 209 // Check if 'memref' is used by a non-deferencing op (including unknown ones) 210 // (e.g., call ops, alias creating ops, etc.). 211 return llvm::any_of(memref.getUsers(), [&](Operation *user) { 212 // Ignore users outside of `block`. 213 if (block->getParent()->findAncestorOpInRegion(*user)->getBlock() != block) 214 return false; 215 return !isa<AffineMapAccessInterface>(*user); 216 }); 217 } 218 219 /// Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id' 220 /// that escape the block or are accessed in a non-affine way. 221 static void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg, 222 DenseSet<Value> &escapingMemRefs) { 223 auto *node = mdg->getNode(id); 224 for (Operation *storeOp : node->stores) { 225 auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef(); 226 if (escapingMemRefs.count(memref)) 227 continue; 228 if (isEscapingMemref(memref, &mdg->block)) 229 escapingMemRefs.insert(memref); 230 } 231 } 232 233 // Initializes the data dependence graph by walking operations in `block`. 234 // Assigns each node in the graph a node id based on program order in 'f'. 235 bool MemRefDependenceGraph::init() { 236 LLVM_DEBUG(llvm::dbgs() << "--- Initializing MDG ---\n"); 237 // Map from a memref to the set of ids of the nodes that have ops accessing 238 // the memref. 
239 DenseMap<Value, SetVector<unsigned>> memrefAccesses; 240 241 DenseMap<Operation *, unsigned> forToNodeMap; 242 for (Operation &op : block) { 243 if (auto forOp = dyn_cast<AffineForOp>(op)) { 244 // Create graph node 'id' to represent top-level 'forOp' and record 245 // all loads and store accesses it contains. 246 LoopNestStateCollector collector; 247 collector.collect(&op); 248 // Return false if a region holding op other than 'affine.for' and 249 // 'affine.if' was found (not currently supported). 250 if (collector.hasNonAffineRegionOp) 251 return false; 252 Node node(nextNodeId++, &op); 253 for (auto *opInst : collector.loadOpInsts) { 254 node.loads.push_back(opInst); 255 auto memref = cast<AffineReadOpInterface>(opInst).getMemRef(); 256 memrefAccesses[memref].insert(node.id); 257 } 258 for (auto *opInst : collector.storeOpInsts) { 259 node.stores.push_back(opInst); 260 auto memref = cast<AffineWriteOpInterface>(opInst).getMemRef(); 261 memrefAccesses[memref].insert(node.id); 262 } 263 forToNodeMap[&op] = node.id; 264 nodes.insert({node.id, node}); 265 } else if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) { 266 // Create graph node for top-level load op. 267 Node node(nextNodeId++, &op); 268 node.loads.push_back(&op); 269 auto memref = cast<AffineReadOpInterface>(op).getMemRef(); 270 memrefAccesses[memref].insert(node.id); 271 nodes.insert({node.id, node}); 272 } else if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) { 273 // Create graph node for top-level store op. 274 Node node(nextNodeId++, &op); 275 node.stores.push_back(&op); 276 auto memref = cast<AffineWriteOpInterface>(op).getMemRef(); 277 memrefAccesses[memref].insert(node.id); 278 nodes.insert({node.id, node}); 279 } else if (op.getNumRegions() != 0) { 280 // Return false if another region is found (not currently supported). 
281 return false; 282 } else if (op.getNumResults() > 0 && !op.use_empty()) { 283 // Create graph node for top-level producer of SSA values, which 284 // could be used by loop nest nodes. 285 Node node(nextNodeId++, &op); 286 nodes.insert({node.id, node}); 287 } else if (isa<CallOpInterface>(op)) { 288 // Create graph node for top-level Call Op that takes any argument of 289 // memref type. Call Op that returns one or more memref type results 290 // is already taken care of, by the previous conditions. 291 if (llvm::any_of(op.getOperandTypes(), 292 [&](Type t) { return isa<MemRefType>(t); })) { 293 Node node(nextNodeId++, &op); 294 nodes.insert({node.id, node}); 295 } 296 } else if (hasEffect<MemoryEffects::Write, MemoryEffects::Free>(&op)) { 297 // Create graph node for top-level op, which could have a memory write 298 // side effect. 299 Node node(nextNodeId++, &op); 300 nodes.insert({node.id, node}); 301 } 302 } 303 304 for (auto &idAndNode : nodes) { 305 LLVM_DEBUG(llvm::dbgs() << "Create node " << idAndNode.first << " for:\n" 306 << *(idAndNode.second.op) << "\n"); 307 (void)idAndNode; 308 } 309 310 // Add dependence edges between nodes which produce SSA values and their 311 // users. Load ops can be considered as the ones producing SSA values. 312 for (auto &idAndNode : nodes) { 313 const Node &node = idAndNode.second; 314 // Stores don't define SSA values, skip them. 315 if (!node.stores.empty()) 316 continue; 317 Operation *opInst = node.op; 318 for (Value value : opInst->getResults()) { 319 for (Operation *user : value.getUsers()) { 320 // Ignore users outside of the block. 
321 if (block.getParent()->findAncestorOpInRegion(*user)->getBlock() != 322 &block) 323 continue; 324 SmallVector<AffineForOp, 4> loops; 325 getAffineForIVs(*user, &loops); 326 if (loops.empty()) 327 continue; 328 assert(forToNodeMap.count(loops[0]) > 0 && "missing mapping"); 329 unsigned userLoopNestId = forToNodeMap[loops[0]]; 330 addEdge(node.id, userLoopNestId, value); 331 } 332 } 333 } 334 335 // Walk memref access lists and add graph edges between dependent nodes. 336 for (auto &memrefAndList : memrefAccesses) { 337 unsigned n = memrefAndList.second.size(); 338 for (unsigned i = 0; i < n; ++i) { 339 unsigned srcId = memrefAndList.second[i]; 340 bool srcHasStore = 341 getNode(srcId)->getStoreOpCount(memrefAndList.first) > 0; 342 for (unsigned j = i + 1; j < n; ++j) { 343 unsigned dstId = memrefAndList.second[j]; 344 bool dstHasStore = 345 getNode(dstId)->getStoreOpCount(memrefAndList.first) > 0; 346 if (srcHasStore || dstHasStore) 347 addEdge(srcId, dstId, memrefAndList.first); 348 } 349 } 350 } 351 return true; 352 } 353 354 // Sinks all sequential loops to the innermost levels (while preserving 355 // relative order among them) and moves all parallel loops to the 356 // outermost (while again preserving relative order among them). 357 // This can increase the loop depth at which we can fuse a slice, since we are 358 // pushing loop carried dependence to a greater depth in the loop nest. 359 static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) { 360 assert(isa<AffineForOp>(node->op)); 361 AffineForOp newRootForOp = sinkSequentialLoops(cast<AffineForOp>(node->op)); 362 node->op = newRootForOp; 363 } 364 365 // Creates and returns a private (single-user) memref for fused loop rooted 366 // at 'forOp', with (potentially reduced) memref size based on the 367 // MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'. 368 // TODO: consider refactoring the common code from generateDma and 369 // this one. 
370 static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst, 371 unsigned dstLoopDepth, 372 std::optional<unsigned> fastMemorySpace, 373 uint64_t localBufSizeThreshold) { 374 Operation *forInst = forOp.getOperation(); 375 376 // Create builder to insert alloc op just before 'forOp'. 377 OpBuilder b(forInst); 378 // Builder to create constants at the top level. 379 OpBuilder top(forInst->getParentRegion()); 380 // Create new memref type based on slice bounds. 381 auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOpInst).getMemRef(); 382 auto oldMemRefType = cast<MemRefType>(oldMemRef.getType()); 383 unsigned rank = oldMemRefType.getRank(); 384 385 // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'. 386 MemRefRegion region(srcStoreOpInst->getLoc()); 387 bool validRegion = succeeded(region.compute(srcStoreOpInst, dstLoopDepth)); 388 (void)validRegion; 389 assert(validRegion && "unexpected memref region failure"); 390 SmallVector<int64_t, 4> newShape; 391 std::vector<SmallVector<int64_t, 4>> lbs; 392 SmallVector<int64_t, 8> lbDivisors; 393 lbs.reserve(rank); 394 // Query 'region' for 'newShape' and lower bounds of MemRefRegion accessed 395 // by 'srcStoreOpInst' at depth 'dstLoopDepth'. 396 std::optional<int64_t> numElements = 397 region.getConstantBoundingSizeAndShape(&newShape, &lbs, &lbDivisors); 398 assert(numElements && "non-constant number of elts in local buffer"); 399 400 const FlatAffineValueConstraints *cst = region.getConstraints(); 401 // 'outerIVs' holds the values that this memory region is symbolic/parametric 402 // on; this would correspond to loop IVs surrounding the level at which the 403 // slice is being materialized. 
404 SmallVector<Value, 8> outerIVs; 405 cst->getValues(rank, cst->getNumVars(), &outerIVs); 406 407 // Build 'rank' AffineExprs from MemRefRegion 'lbs' 408 SmallVector<AffineExpr, 4> offsets; 409 offsets.reserve(rank); 410 for (unsigned d = 0; d < rank; ++d) { 411 assert(lbs[d].size() == cst->getNumCols() - rank && "incorrect bound size"); 412 413 AffineExpr offset = top.getAffineConstantExpr(0); 414 for (unsigned j = 0, e = cst->getNumCols() - rank - 1; j < e; j++) { 415 offset = offset + lbs[d][j] * top.getAffineDimExpr(j); 416 } 417 assert(lbDivisors[d] > 0); 418 offset = 419 (offset + lbs[d][cst->getNumCols() - 1 - rank]).floorDiv(lbDivisors[d]); 420 offsets.push_back(offset); 421 } 422 423 // Create 'newMemRefType' using 'newShape' from MemRefRegion accessed 424 // by 'srcStoreOpInst'. 425 auto eltSize = getMemRefIntOrFloatEltSizeInBytes(oldMemRefType); 426 assert(eltSize && "memrefs with size elt types expected"); 427 uint64_t bufSize = *eltSize * *numElements; 428 unsigned newMemSpace; 429 if (bufSize <= localBufSizeThreshold && fastMemorySpace.has_value()) { 430 newMemSpace = *fastMemorySpace; 431 } else { 432 newMemSpace = oldMemRefType.getMemorySpaceAsInt(); 433 } 434 auto newMemRefType = MemRefType::get(newShape, oldMemRefType.getElementType(), 435 {}, newMemSpace); 436 437 // Create new private memref for fused loop 'forOp'. 'newShape' is always 438 // a constant shape. 439 // TODO: Create/move alloc ops for private memrefs closer to their 440 // consumer loop nests to reduce their live range. Currently they are added 441 // at the beginning of the block, because loop nests can be reordered 442 // during the fusion pass. 443 Value newMemRef = top.create<memref::AllocOp>(forOp.getLoc(), newMemRefType); 444 445 // Build an AffineMap to remap access functions based on lower bound offsets. 
446 SmallVector<AffineExpr, 4> remapExprs; 447 remapExprs.reserve(rank); 448 for (unsigned i = 0; i < rank; i++) { 449 auto dimExpr = b.getAffineDimExpr(outerIVs.size() + i); 450 451 auto remapExpr = 452 simplifyAffineExpr(dimExpr - offsets[i], outerIVs.size() + rank, 0); 453 remapExprs.push_back(remapExpr); 454 } 455 456 auto indexRemap = 457 AffineMap::get(outerIVs.size() + rank, 0, remapExprs, forOp.getContext()); 458 459 // Replace all users of 'oldMemRef' with 'newMemRef'. 460 LogicalResult res = 461 replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap, 462 /*extraOperands=*/outerIVs, 463 /*symbolOperands=*/{}, 464 /*domOpFilter=*/&*forOp.getBody()->begin()); 465 assert(succeeded(res) && 466 "replaceAllMemrefUsesWith should always succeed here"); 467 (void)res; 468 return newMemRef; 469 } 470 471 /// Walking from node 'srcId' to node 'dstId' (exclusive of 'srcId' and 472 /// 'dstId'), if there is any non-affine operation accessing 'memref', return 473 /// true. Otherwise, return false. 474 static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId, 475 Value memref, 476 MemRefDependenceGraph *mdg) { 477 auto *srcNode = mdg->getNode(srcId); 478 auto *dstNode = mdg->getNode(dstId); 479 Value::user_range users = memref.getUsers(); 480 // For each MemRefDependenceGraph's node that is between 'srcNode' and 481 // 'dstNode' (exclusive of 'srcNodes' and 'dstNode'), check whether any 482 // non-affine operation in the node accesses the 'memref'. 483 for (auto &idAndNode : mdg->nodes) { 484 Operation *op = idAndNode.second.op; 485 // Take care of operations between 'srcNode' and 'dstNode'. 486 if (srcNode->op->isBeforeInBlock(op) && op->isBeforeInBlock(dstNode->op)) { 487 // Walk inside the operation to find any use of the memref. 488 // Interrupt the walk if found. 489 auto walkResult = op->walk([&](Operation *user) { 490 // Skip affine ops. 
491 if (isa<AffineMapAccessInterface>(*user)) 492 return WalkResult::advance(); 493 // Find a non-affine op that uses the memref. 494 if (llvm::is_contained(users, user)) 495 return WalkResult::interrupt(); 496 return WalkResult::advance(); 497 }); 498 if (walkResult.wasInterrupted()) 499 return true; 500 } 501 } 502 return false; 503 } 504 505 /// Check whether a memref value in node 'srcId' has a non-affine that 506 /// is between node 'srcId' and node 'dstId' (exclusive of 'srcNode' and 507 /// 'dstNode'). 508 static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId, 509 MemRefDependenceGraph *mdg) { 510 // Collect memref values in node 'srcId'. 511 auto *srcNode = mdg->getNode(srcId); 512 llvm::SmallDenseSet<Value, 2> memRefValues; 513 srcNode->op->walk([&](Operation *op) { 514 // Skip affine ops. 515 if (isa<AffineForOp>(op)) 516 return WalkResult::advance(); 517 for (Value v : op->getOperands()) 518 // Collect memref values only. 519 if (isa<MemRefType>(v.getType())) 520 memRefValues.insert(v); 521 return WalkResult::advance(); 522 }); 523 // Looking for users between node 'srcId' and node 'dstId'. 524 return llvm::any_of(memRefValues, [&](Value memref) { 525 return hasNonAffineUsersOnThePath(srcId, dstId, memref, mdg); 526 }); 527 } 528 529 // Checks the profitability of fusing a backwards slice of the loop nest 530 // surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'. 531 // The argument 'srcStoreOpInst' is used to calculate the storage reduction on 532 // the memref being produced and consumed, which is an input to the cost model. 533 // For producer-consumer fusion, 'srcStoreOpInst' will be the same as 534 // 'srcOpInst', as we are slicing w.r.t to that producer. 
For input-reuse 535 // fusion, 'srcOpInst' will be the src loop nest LoadOp which reads from the 536 // same memref as dst loop nest load ops, and 'srcStoreOpInst' will be the 537 // unique store op in the src node, which will be used to check that the write 538 // region is the same after input-reuse fusion. Computation slices are provided 539 // in 'depthSliceUnions' for each legal fusion depth. The maximal depth at which 540 // fusion is legal is provided in 'maxLegalFusionDepth'. Returns true if it is 541 // profitable to fuse the candidate loop nests. Returns false otherwise. 542 // `dstLoopDepth` is set to the most profitable depth at which to materialize 543 // the source loop nest slice. 544 // The profitability model executes the following steps: 545 // *) Computes the backward computation slice at 'srcOpInst'. This 546 // computation slice of the loop nest surrounding 'srcOpInst' is 547 // represented by modified src loop bounds in 'sliceState', which are 548 // functions of loop IVs in the loop nest surrounding 'srcOpInst'. 549 // *) Computes the cost of unfused src/dst loop nests (currently the cost of a 550 // loop nest is the total number of dynamic operation instances in the loop 551 // nest). 552 // *) Computes the cost of fusing a slice of the src loop nest into the dst 553 // loop nest at various values of dst loop depth, attempting to fuse 554 // the largest computation slice at the maximal dst loop depth (closest to 555 // the load) to minimize reuse distance and potentially enable subsequent 556 // load/store forwarding. 557 // NOTE: 'dstLoopDepth' refers to the loop depth within the destination loop 558 // nest, at which the src computation slice is inserted/fused. 
559 // NOTE: We attempt to maximize the dst loop depth, but there are cases 560 // where a particular setting for 'dstLoopNest' might fuse an unsliced 561 // loop (within the src computation slice) at a depth which results in 562 // excessive recomputation (see unit tests for examples). 563 // *) Compares the total cost of the unfused loop nests to the min cost fused 564 // loop nest computed in the previous step, and returns true if the latter 565 // is lower. 566 // TODO: Extend profitability analysis to support scenarios with multiple 567 // stores. 568 static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, 569 AffineForOp dstForOp, 570 ArrayRef<ComputationSliceState> depthSliceUnions, 571 unsigned maxLegalFusionDepth, 572 unsigned *dstLoopDepth, 573 double computeToleranceThreshold) { 574 LLVM_DEBUG({ 575 llvm::dbgs() << "Checking whether fusion is profitable between src op:\n"; 576 llvm::dbgs() << ' ' << *srcOpInst << " and destination loop:\n"; 577 llvm::dbgs() << dstForOp << "\n"; 578 }); 579 580 if (maxLegalFusionDepth == 0) { 581 LLVM_DEBUG(llvm::dbgs() << "Can't fuse: maxLegalFusionDepth is 0\n"); 582 return false; 583 } 584 585 // Compute cost of sliced and unsliced src loop nest. 586 SmallVector<AffineForOp, 4> srcLoopIVs; 587 getAffineForIVs(*srcOpInst, &srcLoopIVs); 588 589 // Walk src loop nest and collect stats. 590 LoopNestStats srcLoopNestStats; 591 if (!getLoopNestStats(srcLoopIVs[0], &srcLoopNestStats)) 592 return false; 593 594 // Compute cost of dst loop nest. 595 LoopNestStats dstLoopNestStats; 596 if (!getLoopNestStats(dstForOp, &dstLoopNestStats)) 597 return false; 598 599 // Search for min cost value for 'dstLoopDepth'. At each value of 600 // 'dstLoopDepth' from 'maxLegalLoopDepth' to '1', compute computation slice 601 // bounds between 'srcOpInst' and each op in 'dstOpinsts' (taking the union 602 // of these bounds). 
Next the union slice bounds are used to calculate 603 // the cost of the slice and the cost of the slice inserted into the dst 604 // loop nest at 'dstLoopDepth'. 605 uint64_t minFusedLoopNestComputeCost = std::numeric_limits<uint64_t>::max(); 606 double maxStorageReduction = 0.0; 607 std::optional<uint64_t> sliceMemEstimate; 608 609 // The best loop depth at which to materialize the slice. 610 std::optional<unsigned> bestDstLoopDepth; 611 612 // Compute op instance count for the src loop nest without iteration slicing. 613 uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], srcLoopNestStats); 614 615 // Compute src loop nest write region size. 616 MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc()); 617 if (failed(srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0))) { 618 LLVM_DEBUG(llvm::dbgs() 619 << "Unable to compute MemRefRegion for source operation\n"); 620 return false; 621 } 622 623 std::optional<int64_t> maybeSrcWriteRegionSizeBytes = 624 srcWriteRegion.getRegionSize(); 625 if (!maybeSrcWriteRegionSizeBytes.has_value()) 626 return false; 627 int64_t srcWriteRegionSizeBytes = *maybeSrcWriteRegionSizeBytes; 628 629 // Compute op instance count for the src loop nest. 630 uint64_t dstLoopNestCost = getComputeCost(dstForOp, dstLoopNestStats); 631 632 // Evaluate all depth choices for materializing the slice in the destination 633 // loop nest. 634 for (unsigned i = maxLegalFusionDepth; i >= 1; --i) { 635 const ComputationSliceState &slice = depthSliceUnions[i - 1]; 636 // Skip slice union if it wasn't computed for this depth. 
637 if (slice.isEmpty()) 638 continue; 639 640 int64_t fusedLoopNestComputeCost; 641 if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstForOp, 642 dstLoopNestStats, slice, 643 &fusedLoopNestComputeCost)) { 644 LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost\n"); 645 continue; 646 } 647 648 double additionalComputeFraction = 649 fusedLoopNestComputeCost / 650 (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) - 651 1; 652 653 // Determine what the slice write MemRefRegion would be, if the src loop 654 // nest slice 'slice' were to be inserted into the dst loop nest at loop 655 // depth 'i'. 656 MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc()); 657 if (failed(sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0, 658 &slice))) { 659 LLVM_DEBUG(llvm::dbgs() 660 << "Failed to compute slice write region at loopDepth: " << i 661 << "\n"); 662 continue; 663 } 664 665 std::optional<int64_t> maybeSliceWriteRegionSizeBytes = 666 sliceWriteRegion.getRegionSize(); 667 if (!maybeSliceWriteRegionSizeBytes.has_value() || 668 *maybeSliceWriteRegionSizeBytes == 0) { 669 LLVM_DEBUG(llvm::dbgs() 670 << "Failed to get slice write region size at loopDepth: " << i 671 << "\n"); 672 continue; 673 } 674 int64_t sliceWriteRegionSizeBytes = *maybeSliceWriteRegionSizeBytes; 675 676 // If we are fusing for reuse, check that write regions remain the same. 677 // TODO: Write region check should check sizes and offsets in 678 // each dimension, so that we are sure they are covering the same memref 679 // region. Also, move this out to a isMemRefRegionSuperSet helper function. 
680 if (srcOpInst != srcStoreOpInst && 681 sliceWriteRegionSizeBytes != srcWriteRegionSizeBytes) 682 continue; 683 684 double storageReduction = static_cast<double>(srcWriteRegionSizeBytes) / 685 static_cast<double>(sliceWriteRegionSizeBytes); 686 687 LLVM_DEBUG({ 688 std::stringstream msg; 689 msg << " evaluating fusion profitability at depth : " << i << "\n" 690 << std::fixed << std::setprecision(2) 691 << " additional compute fraction: " 692 << 100.0 * additionalComputeFraction << "%\n" 693 << " storage reduction factor: " << storageReduction << "x\n" 694 << " fused nest cost: " << fusedLoopNestComputeCost << "\n" 695 << " src write region size: " << srcWriteRegionSizeBytes << "\n" 696 << " slice write region size: " << sliceWriteRegionSizeBytes 697 << "\n"; 698 llvm::dbgs() << msg.str(); 699 }); 700 701 // TODO: This is a placeholder cost model. 702 // Among all choices that add an acceptable amount of redundant computation 703 // (as per computeToleranceThreshold), we will simply pick the one that 704 // reduces the intermediary size the most. 705 if ((storageReduction > maxStorageReduction) && 706 (additionalComputeFraction < computeToleranceThreshold)) { 707 maxStorageReduction = storageReduction; 708 bestDstLoopDepth = i; 709 minFusedLoopNestComputeCost = fusedLoopNestComputeCost; 710 sliceMemEstimate = sliceWriteRegionSizeBytes; 711 } 712 } 713 714 // A simple cost model: fuse if it reduces the memory footprint. 715 716 if (!bestDstLoopDepth) { 717 LLVM_DEBUG( 718 llvm::dbgs() 719 << "All fusion choices involve more than the threshold amount of " 720 "redundant computation; NOT fusing.\n"); 721 return false; 722 } 723 724 if (!bestDstLoopDepth) { 725 LLVM_DEBUG(llvm::dbgs() << "no fusion depth could be evaluated.\n"); 726 return false; 727 } 728 729 // Set dstLoopDepth based on best values from search. 
  *dstLoopDepth = *bestDstLoopDepth;

  LLVM_DEBUG(
      llvm::dbgs() << " LoopFusion fusion stats:"
                   << "\n best loop depth: " << bestDstLoopDepth
                   << "\n src loop nest compute cost: " << srcLoopNestCost
                   << "\n dst loop nest compute cost: " << dstLoopNestCost
                   << "\n fused loop nest compute cost: "
                   << minFusedLoopNestComputeCost << "\n");

  // Compute the memory footprints of both nests to evaluate whether fusion
  // actually shrinks total memory usage.
  auto dstMemSize = getMemoryFootprintBytes(dstForOp);
  auto srcMemSize = getMemoryFootprintBytes(srcLoopIVs[0]);

  std::optional<double> storageReduction;

  // Bail out if either footprint could not be computed.
  if (!dstMemSize || !srcMemSize) {
    LLVM_DEBUG(llvm::dbgs()
               << " fusion memory benefit cannot be evaluated; NOT fusing.\n");
    return false;
  }

  auto srcMemSizeVal = *srcMemSize;
  auto dstMemSizeVal = *dstMemSize;

  // 'sliceMemEstimate' was recorded together with 'bestDstLoopDepth' in the
  // depth-search loop above, so it must hold a value here.
  assert(sliceMemEstimate && "expected value");
  auto fusedMem = dstMemSizeVal + *sliceMemEstimate;

  LLVM_DEBUG(llvm::dbgs() << " src mem: " << srcMemSizeVal << "\n"
                          << " dst mem: " << dstMemSizeVal << "\n"
                          << " fused mem: " << fusedMem << "\n"
                          << " slice mem: " << sliceMemEstimate << "\n");

  // Reject fusion when the fused footprint exceeds the sum of the unfused
  // footprints, i.e. fusion would not reduce memory usage.
  if (static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) {
    LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n");
    return false;
  }
  storageReduction =
      100.0 *
      (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal));

  // Fraction (as a percentage) of extra compute the fused nest performs
  // relative to the two unfused nests; used only for debug reporting here.
  double additionalComputeFraction =
      100.0 * (minFusedLoopNestComputeCost /
                   (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
               1);
  (void)additionalComputeFraction;
  LLVM_DEBUG({
    std::stringstream msg;
    msg << " fusion is most profitable at depth " << *dstLoopDepth << " with "
        << std::setprecision(2) << additionalComputeFraction
        << "% redundant computation and a ";
    msg << (storageReduction ? std::to_string(*storageReduction) : "<unknown>");
    msg << "% storage reduction.\n";
    llvm::dbgs() << msg.str();
  });

  return true;
}

namespace {

// GreedyFusion greedily fuses loop nests which have a producer/consumer or
// input-reuse relationship on a memref, with the goal of improving locality.
//
// The steps of the producer-consumer fusion algorithm are as follows:
//
// *) A worklist is initialized with node ids from the dependence graph.
// *) For each node id in the worklist:
//   *) Pop an AffineForOp of the worklist. This 'dstAffineForOp' will be a
//      candidate destination AffineForOp into which fusion will be attempted.
//   *) Add each LoadOp currently in 'dstAffineForOp' into list 'dstLoadOps'.
//   *) For each LoadOp in 'dstLoadOps' do:
//      *) Look up dependent loop nests which have a single store op to the
//         same memref.
//      *) Check if dependences would be violated by the fusion.
//      *) Get a computation slice of 'srcLoopNest', which adjusts its loop
//         bounds to be functions of 'dstLoopNest' IVs and symbols.
//      *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest',
//         at a loop depth determined by the cost model in 'isFusionProfitable'.
//      *) Add the newly fused load/store operations to the state,
//         and also add newly fused load ops to 'dstLoopOps' to be considered
//         as fusion dst load ops in another iteration.
//      *) Remove old src loop nest and its associated state.
//
// The steps of the input-reuse fusion algorithm are as follows:
//
// *) Initialize 'worklist' with node ids from the dependence graph.
// *) For each 'dstNode' in the worklist:
//   *) Find a candidate sibling node 'sibNode' to fuse with 'dstNode' which
//      loads from the same memref, but which has no dependence paths to/from.
//   *) Get a computation slice of 'sibLoopNest', which adjusts its loop
//      bounds to be functions of 'dstLoopNest' IVs and symbols.
//   *) Fuse the 'sibLoopNest' computation slice into the 'dstLoopNest',
//      at a loop depth determined by the cost model in 'isFusionProfitable'.
//      This function also checks that the memref write region of
//      'sibLoopNest', is preserved in the fused loop nest.
//   *) Update graph state to reflect the fusion of 'sibNode' into 'dstNode'.
//
// Given a graph where top-level operations are vertices in the set 'V' and
// edges in the set 'E' are dependences between vertices, this algorithm
// takes O(V) time for initialization, and has runtime O(V + E).
//
// This greedy algorithm is not 'maximal' due to the current restriction of
// fusing along single producer consumer edges, but there is a TODO: to fix
// this.
//
// TODO: Experiment with other fusion policies.
struct GreedyFusion {
public:
  // The data dependence graph to traverse during fusion.
  MemRefDependenceGraph *mdg;
  // Worklist of graph nodes visited during the fusion pass.
  SmallVector<unsigned, 8> worklist;
  // Parameter for local buffer size threshold (in bytes; the pass driver
  // converts from the KiB pass option before constructing this object).
  unsigned localBufSizeThreshold;
  // Parameter for fast memory space.
  std::optional<unsigned> fastMemorySpace;
  // If true, ignore any additional (redundant) computation tolerance threshold
  // that would have prevented fusion.
  bool maximalFusion;
  // The amount of additional computation that is tolerated while fusing
  // pair-wise as a fraction of the total computation.
  double computeToleranceThreshold;

  using Node = MemRefDependenceGraph::Node;

  GreedyFusion(MemRefDependenceGraph *mdg, unsigned localBufSizeThreshold,
               std::optional<unsigned> fastMemorySpace, bool maximalFusion,
               double computeToleranceThreshold)
      : mdg(mdg), localBufSizeThreshold(localBufSizeThreshold),
        fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion),
        computeToleranceThreshold(computeToleranceThreshold) {}

  /// Initializes 'worklist' with nodes from 'mdg'.
  void init() {
    // TODO: Add a priority queue for prioritizing nodes by different
    // metrics (e.g. arithmetic intensity/flops-to-bytes ratio).
    worklist.clear();
    for (auto &idAndNode : mdg->nodes) {
      const Node &node = idAndNode.second;
      worklist.push_back(node.id);
    }
  }

  /// Run only sibling fusion on the `mdg`.
  void runSiblingFusionOnly() {
    fuseSiblingNodes();
    eraseUnusedMemRefAllocations();
  }

  /// Run only producer/consumer fusion on the `mdg`.
  void runProducerConsumerFusionOnly() {
    fuseProducerConsumerNodes(
        /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max());
    eraseUnusedMemRefAllocations();
  }

  // Run the GreedyFusion pass.
  // *) First pass through the nodes fuses single-use producer nodes into their
  //    unique consumer.
  // *) Second pass fuses sibling nodes which share no dependence edges.
  // *) Third pass fuses any remaining producer nodes into their users.
  void runGreedyFusion() {
    // TODO: Run this repeatedly until a fixed-point is reached.
    fuseProducerConsumerNodes(/*maxSrcUserCount=*/1);
    fuseSiblingNodes();
    fuseProducerConsumerNodes(
        /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max());
    eraseUnusedMemRefAllocations();
  }

  /// Returns true if a private memref can be created for `memref` given
  /// the fusion scenario reflected by the other arguments.
  bool canCreatePrivateMemRef(Value memref,
                              const DenseSet<Value> &srcEscapingMemRefs,
                              unsigned producerId, unsigned consumerId,
                              bool removeSrcNode) {
    const Node *consumerNode = mdg->getNode(consumerId);
    // If `memref` is an escaping one, do not create a private memref
    // for the below scenarios, since doing so will leave the escaping
    // memref unmodified as all the writes originally meant for the
    // escaping memref would be performed on the private memref:
    // 1. The source is to be removed after fusion,
    // OR
    // 2. The destination writes to `memref`.
    if (srcEscapingMemRefs.count(memref) > 0 &&
        (removeSrcNode || consumerNode->getStoreOpCount(memref) > 0))
      return false;

    // Don't create a private memref if 'srcNode' has in edges on
    // 'memref' or 'dstNode' has out edges on 'memref'.
    if (mdg->getIncomingMemRefAccesses(producerId, memref) > 0 ||
        mdg->getOutEdgeCount(consumerId, memref) > 0)
      return false;

    // If 'srcNode' will be removed but it has out edges on 'memref' to
    // nodes other than 'dstNode', we have to preserve dependences and
    // cannot create a private memref.
    if (removeSrcNode &&
        any_of(mdg->outEdges[producerId], [&](const auto &edge) {
          return edge.value == memref && edge.id != consumerId;
        }))
      return false;

    return true;
  }

  /// Perform fusions with node `dstId` as the destination of fusion.
  /// No fusion is performed when producers with a user count greater than
  /// `maxSrcUserCount` for any of the memrefs involved are encountered.
  void performFusionsIntoDest(unsigned dstId, unsigned maxSrcUserCount) {
    LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n");
    // Skip if this node was removed (fused into another node).
    if (mdg->nodes.count(dstId) == 0)
      return;
    // Get 'dstNode' into which to attempt fusion.
    auto *dstNode = mdg->getNode(dstId);
    // Skip if 'dstNode' is not a loop nest.
946 if (!isa<AffineForOp>(dstNode->op)) 947 return; 948 // Skip if 'dstNode' is a loop nest returning values. 949 // TODO: support loop nests that return values. 950 if (dstNode->op->getNumResults() > 0) 951 return; 952 953 LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n"); 954 955 // Sink sequential loops in 'dstNode' (and thus raise parallel loops) 956 // while preserving relative order. This can increase the maximum loop 957 // depth at which we can fuse a slice of a producer loop nest into a 958 // consumer loop nest. 959 sinkSequentialLoops(dstNode); 960 auto dstAffineForOp = cast<AffineForOp>(dstNode->op); 961 962 // Try to fuse 'dstNode' with candidate producer loops until a fixed point 963 // is reached. Fusing two loops may expose new fusion opportunities. 964 bool dstNodeChanged; 965 do { 966 // Gather src loop candidates for 'dstNode' and visit them in "quasi" 967 // reverse program order to minimize the number of iterations needed to 968 // reach the fixed point. Note that this is a best effort approach since 969 // 'getProducerCandidates' does not always guarantee that program order 970 // in 'srcIdCandidates'. 971 dstNodeChanged = false; 972 SmallVector<unsigned, 16> srcIdCandidates; 973 getProducerCandidates(dstId, mdg, srcIdCandidates); 974 975 for (unsigned srcId : llvm::reverse(srcIdCandidates)) { 976 // Get 'srcNode' from which to attempt fusion into 'dstNode'. 977 auto *srcNode = mdg->getNode(srcId); 978 auto srcAffineForOp = cast<AffineForOp>(srcNode->op); 979 LLVM_DEBUG(llvm::dbgs() << "Evaluating src loop " << srcId 980 << " for dst loop " << dstId << "\n"); 981 982 // Skip if 'srcNode' is a loop nest returning values. 983 // TODO: support loop nests that return values. 
        if (isa<AffineForOp>(srcNode->op) && srcNode->op->getNumResults() > 0)
          continue;

        // Memrefs written by 'srcNode' and read by 'dstNode' -- these define
        // the producer-consumer relationship being fused.
        DenseSet<Value> producerConsumerMemrefs;
        gatherProducerConsumerMemrefs(srcId, dstId, mdg,
                                      producerConsumerMemrefs);

        // Skip if 'srcNode' out edge count on any memref is greater than
        // 'maxSrcUserCount'.
        if (any_of(producerConsumerMemrefs, [&](Value memref) {
              return mdg->getOutEdgeCount(srcNode->id, memref) >
                     maxSrcUserCount;
            }))
          continue;

        // Gather memrefs in 'srcNode' that are written and escape out of the
        // block (e.g., memref block arguments, returned memrefs,
        // memrefs passed to function calls, etc.).
        DenseSet<Value> srcEscapingMemRefs;
        gatherEscapingMemrefs(srcNode->id, mdg, srcEscapingMemRefs);

        // Skip if there are non-affine operations in between the 'srcNode'
        // and 'dstNode' using their memrefs. If so, we wouldn't be able to
        // compute a legal insertion point for now. 'srcNode' and 'dstNode'
        // memrefs with non-affine operation users would be considered
        // escaping memrefs so we can limit this check to only scenarios with
        // escaping memrefs.
        if (!srcEscapingMemRefs.empty() &&
            hasNonAffineUsersOnThePath(srcId, dstId, mdg)) {
          LLVM_DEBUG(llvm::dbgs()
                     << "Can't fuse: non-affine users in between the loops\n");
          continue;
        }

        // Compute an operation list insertion point for the fused loop
        // nest which preserves dependences.
        Operation *fusedLoopInsPoint =
            mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id);
        if (fusedLoopInsPoint == nullptr)
          continue;

        // Compute the innermost common loop depth for dstNode
        // producer-consumer loads/stores.
        // Collect the dst loads/stores that touch producer-consumer memrefs;
        // their innermost common surrounding loop bounds the fusion depth.
        SmallVector<Operation *, 2> dstMemrefOps;
        for (Operation *op : dstNode->loads)
          if (producerConsumerMemrefs.count(
                  cast<AffineReadOpInterface>(op).getMemRef()) > 0)
            dstMemrefOps.push_back(op);
        for (Operation *op : dstNode->stores)
          if (producerConsumerMemrefs.count(
                  cast<AffineWriteOpInterface>(op).getMemRef()))
            dstMemrefOps.push_back(op);
        unsigned dstLoopDepthTest = getInnermostCommonLoopDepth(dstMemrefOps);

        // Check the feasibility of fusing src loop nest into dst loop nest
        // at loop depths in range [1, dstLoopDepthTest].
        unsigned maxLegalFusionDepth = 0;
        SmallVector<ComputationSliceState, 8> depthSliceUnions;
        depthSliceUnions.resize(dstLoopDepthTest);
        FusionStrategy strategy(FusionStrategy::ProducerConsumer);
        for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
          FusionResult result = affine::canFuseLoops(
              srcAffineForOp, dstAffineForOp,
              /*dstLoopDepth=*/i, &depthSliceUnions[i - 1], strategy);

          if (result.value == FusionResult::Success)
            maxLegalFusionDepth = i;
        }

        if (maxLegalFusionDepth == 0) {
          LLVM_DEBUG(llvm::dbgs()
                     << "Can't fuse: fusion is not legal at any depth\n");
          continue;
        }

        // Check if fusion would be profitable. We skip profitability analysis
        // for maximal fusion since we already know the maximal legal depth to
        // fuse.
        unsigned bestDstLoopDepth = maxLegalFusionDepth;
        if (!maximalFusion) {
          // Retrieve producer stores from the src loop.
          SmallVector<Operation *, 2> producerStores;
          for (Operation *op : srcNode->stores)
            if (producerConsumerMemrefs.count(
                    cast<AffineWriteOpInterface>(op).getMemRef()))
              producerStores.push_back(op);

          // TODO: Support multiple producer stores in profitability
          // analysis. We limit profitability analysis to only scenarios with
          // a single producer store for now. Note that some multi-store
          // producer scenarios will still go through profitability analysis
          // if only one of the stores is involved in the producer-consumer
          // relationship of the candidate loops.
          assert(!producerStores.empty() && "Expected producer store");
          if (producerStores.size() > 1)
            LLVM_DEBUG(llvm::dbgs() << "Skipping profitability analysis. Not "
                                       "supported for this case\n");
          else if (!isFusionProfitable(producerStores[0], producerStores[0],
                                       dstAffineForOp, depthSliceUnions,
                                       maxLegalFusionDepth, &bestDstLoopDepth,
                                       computeToleranceThreshold))
            continue;
        }

        assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
        ComputationSliceState &bestSlice =
            depthSliceUnions[bestDstLoopDepth - 1];
        assert(!bestSlice.isEmpty() && "Missing slice union for depth");

        // Determine if 'srcId' can be removed after fusion, taking into
        // account remaining dependences, escaping memrefs and the fusion
        // insertion point.
        bool removeSrcNode = canRemoveSrcNodeAfterFusion(
            srcId, dstId, bestSlice, fusedLoopInsPoint, srcEscapingMemRefs,
            mdg);

        DenseSet<Value> privateMemrefs;
        for (Value memref : producerConsumerMemrefs) {
          if (canCreatePrivateMemRef(memref, srcEscapingMemRefs, srcId, dstId,
                                     removeSrcNode)) {
            // Create a private version of this memref.
            LLVM_DEBUG(llvm::dbgs()
                       << "Creating private memref for " << memref << '\n');
            privateMemrefs.insert(memref);
          }
        }

        // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
        fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice);
        dstNodeChanged = true;

        LLVM_DEBUG(llvm::dbgs()
                   << "Fused src loop " << srcId << " into dst loop " << dstId
                   << " at depth " << bestDstLoopDepth << ":\n"
                   << dstAffineForOp << "\n");

        // Move 'dstAffineForOp' before 'insertPointInst' if needed.
        if (fusedLoopInsPoint != dstAffineForOp)
          dstAffineForOp->moveBefore(fusedLoopInsPoint);

        // Update edges between 'srcNode' and 'dstNode'.
        mdg->updateEdges(srcNode->id, dstNode->id, privateMemrefs,
                         removeSrcNode);

        // Create private memrefs.
        if (!privateMemrefs.empty()) {
          // Gather stores for all the private-to-be memrefs.
          DenseMap<Value, SmallVector<Operation *, 4>> privateMemRefToStores;
          dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) {
            Value storeMemRef = storeOp.getMemRef();
            if (privateMemrefs.count(storeMemRef) > 0)
              privateMemRefToStores[storeMemRef].push_back(storeOp);
          });

          // Replace original memrefs with private memrefs. Note that all the
          // loads and stores on these memrefs will be replaced with a new
          // loads and stores. Any reference to the original ones becomes
          // invalid after this point.
          for (auto &memrefToStoresPair : privateMemRefToStores) {
            // TODO: Use union of memref write regions to compute
            // private memref footprint.
            SmallVector<Operation *, 4> &storesForMemref =
                memrefToStoresPair.second;
            Value newMemRef = createPrivateMemRef(
                dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
                fastMemorySpace, localBufSizeThreshold);
            // Create new node in dependence graph for 'newMemRef' alloc op.
            unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp());
            // Add edge from 'newMemRef' node to dstNode.
            mdg->addEdge(newMemRefNodeId, dstId, newMemRef);
          }
          // One or more entries for 'newMemRef' alloc op are inserted into
          // the DenseMap mdg->nodes. Since an insertion may cause DenseMap to
          // reallocate, update dstNode.
          dstNode = mdg->getNode(dstId);
        }

        // Collect dst loop stats after memref privatization transformation.
        LoopNestStateCollector dstLoopCollector;
        dstLoopCollector.collect(dstAffineForOp);

        // Clear and add back loads and stores.
        mdg->clearNodeLoadAndStores(dstNode->id);
        mdg->addToNode(dstId, dstLoopCollector.loadOpInsts,
                       dstLoopCollector.storeOpInsts);

        if (removeSrcNode) {
          LLVM_DEBUG(llvm::dbgs()
                     << "Removing src loop " << srcId << " after fusion\n");
          // srcNode is no longer valid after it is removed from mdg.
          srcAffineForOp.erase();
          mdg->removeNode(srcId);
          srcNode = nullptr;
        }
      }
    } while (dstNodeChanged);
  }

  /// Visit each node in the graph, and for each node, attempt to fuse it with
  /// producer-consumer candidates. No fusion is performed when producers with a
  /// user count greater than `maxSrcUserCount` for any of the memrefs involved
  /// are encountered.
  void fuseProducerConsumerNodes(unsigned maxSrcUserCount) {
    LLVM_DEBUG(llvm::dbgs() << "--- Producer/Consumer Fusion ---\n");
    init();
    // Worklist-driven: each popped node is tried as a fusion destination.
    while (!worklist.empty()) {
      unsigned dstId = worklist.back();
      worklist.pop_back();
      performFusionsIntoDest(dstId, maxSrcUserCount);
    }
  }

  // Visits each node in the graph, and for each node, attempts to fuse it with
  // its sibling nodes (nodes which share a parent, but no dependence edges).
  void fuseSiblingNodes() {
    LLVM_DEBUG(llvm::dbgs() << "--- Sibling Fusion ---\n");
    init();
    while (!worklist.empty()) {
      unsigned dstId = worklist.back();
      worklist.pop_back();

      // Skip if this node was removed (fused into another node).
      if (mdg->nodes.count(dstId) == 0)
        continue;
      // Get 'dstNode' into which to attempt fusion.
      auto *dstNode = mdg->getNode(dstId);
      // Skip if 'dstNode' is not a loop nest.
      if (!isa<AffineForOp>(dstNode->op))
        continue;
      // Attempt to fuse 'dstNode' with its sibling nodes in the graph.
      fuseWithSiblingNodes(dstNode);
    }
  }

  // Attempt to fuse 'dstNode' with sibling nodes in the graph.
  void fuseWithSiblingNodes(Node *dstNode) {
    DenseSet<unsigned> visitedSibNodeIds;
    std::pair<unsigned, Value> idAndMemref;
    auto dstAffineForOp = cast<AffineForOp>(dstNode->op);

    // Keep fusing as long as a new sibling candidate is found.
    while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) {
      unsigned sibId = idAndMemref.first;
      Value memref = idAndMemref.second;
      // TODO: Check that 'sibStoreOpInst' post-dominates all other
      // stores to the same memref in 'sibNode' loop nest.
      auto *sibNode = mdg->getNode(sibId);
      // Compute an operation list insertion point for the fused loop
      // nest which preserves dependences.
      assert(sibNode->op->getBlock() == dstNode->op->getBlock());
      // The earlier of the two nests (in block order) is passed as the source
      // to the insertion-point query.
      Operation *insertPointInst =
          sibNode->op->isBeforeInBlock(dstNode->op)
              ? mdg->getFusedLoopNestInsertionPoint(sibNode->id, dstNode->id)
              : mdg->getFusedLoopNestInsertionPoint(dstNode->id, sibNode->id);
      if (insertPointInst == nullptr)
        continue;

      // Check if fusion would be profitable and at what depth.

      // Get unique 'sibNode' load op to 'memref'.
      SmallVector<Operation *, 2> sibLoadOpInsts;
      sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts);
      // Currently findSiblingNodeToFuse searches for siblings with one load.
      assert(sibLoadOpInsts.size() == 1);
      Operation *sibLoadOpInst = sibLoadOpInsts[0];

      // Gather 'dstNode' load ops to 'memref'.
      SmallVector<Operation *, 2> dstLoadOpInsts;
      dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts);

      // The nesting depth of the first dst load bounds the fusion depth.
      SmallVector<AffineForOp, 4> dstLoopIVs;
      getAffineForIVs(*dstLoadOpInsts[0], &dstLoopIVs);
      unsigned dstLoopDepthTest = dstLoopIVs.size();
      auto sibAffineForOp = cast<AffineForOp>(sibNode->op);

      // Compute loop depth and slice union for fusion.
      SmallVector<ComputationSliceState, 8> depthSliceUnions;
      depthSliceUnions.resize(dstLoopDepthTest);
      unsigned maxLegalFusionDepth = 0;
      FusionStrategy strategy(memref);
      for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
        FusionResult result = affine::canFuseLoops(
            sibAffineForOp, dstAffineForOp,
            /*dstLoopDepth=*/i, &depthSliceUnions[i - 1], strategy);

        if (result.value == FusionResult::Success)
          maxLegalFusionDepth = i;
      }

      // Skip if fusion is not feasible at any loop depths.
      if (maxLegalFusionDepth == 0)
        continue;

      unsigned bestDstLoopDepth = maxLegalFusionDepth;
      if (!maximalFusion) {
        // Check if fusion would be profitable. For sibling fusion, the sibling
        // load op is treated as the src "store" op for fusion profitability
        // purposes. The footprint of the load in the slice relative to the
        // unfused source's determines reuse.
        if (!isFusionProfitable(sibLoadOpInst, sibLoadOpInst, dstAffineForOp,
                                depthSliceUnions, maxLegalFusionDepth,
                                &bestDstLoopDepth, computeToleranceThreshold))
          continue;
      }

      assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
      assert(!depthSliceUnions[bestDstLoopDepth - 1].isEmpty() &&
             "Fusion depth has no computed slice union");
      // Check if source loop is being inserted in the innermost
      // destination loop. Based on this, the fused loop may be optimized
      // further inside `fuseLoops`.
      bool isInnermostInsertion = (bestDstLoopDepth == dstLoopDepthTest);
      // Fuse computation slice of 'sibLoopNest' into 'dstLoopNest'.
      affine::fuseLoops(sibAffineForOp, dstAffineForOp,
                        depthSliceUnions[bestDstLoopDepth - 1],
                        isInnermostInsertion);

      auto dstForInst = cast<AffineForOp>(dstNode->op);
      // Update operation position of fused loop nest (if needed).
      if (insertPointInst != dstForInst) {
        dstForInst->moveBefore(insertPointInst);
      }
      // Update data dependence graph state post fusion.
      updateStateAfterSiblingFusion(sibNode, dstNode);
    }
  }

  // Searches block argument uses and the graph from 'dstNode' looking for a
  // fusion candidate sibling node which shares no dependences with 'dstNode'
  // but which loads from the same memref. Returns true and sets
  // 'idAndMemrefToFuse' on success. Returns false otherwise.
  bool findSiblingNodeToFuse(Node *dstNode,
                             DenseSet<unsigned> *visitedSibNodeIds,
                             std::pair<unsigned, Value> *idAndMemrefToFuse) {
    // Returns true if 'sibNode' can be fused with 'dstNode' for input reuse
    // on 'memref'.
    auto canFuseWithSibNode = [&](Node *sibNode, Value memref) {
      // Skip if 'outEdge' is not a read-after-write dependence.
      // TODO: Remove restrict to single load op restriction.
      if (sibNode->getLoadOpCount(memref) != 1)
        return false;
      // Skip if there exists a path of dependent edges between
      // 'sibNode' and 'dstNode'.
      if (mdg->hasDependencePath(sibNode->id, dstNode->id) ||
          mdg->hasDependencePath(dstNode->id, sibNode->id))
        return false;
      // Skip sib node if it loads from (and stores to) the same memref on
      // which it also has an input dependence edge.
      DenseSet<Value> loadAndStoreMemrefSet;
      sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet);
      if (llvm::any_of(loadAndStoreMemrefSet, [=](Value memref) {
            return mdg->getIncomingMemRefAccesses(sibNode->id, memref) > 0;
          }))
        return false;

      // Check that all stores are to the same memref if any.
      DenseSet<Value> storeMemrefs;
      for (auto *storeOpInst : sibNode->stores) {
        storeMemrefs.insert(
            cast<AffineWriteOpInterface>(storeOpInst).getMemRef());
      }
      if (storeMemrefs.size() > 1)
        return false;

      // Skip if a memref value in one node is used by a non-affine memref
      // access that lies between 'dstNode' and 'sibNode'.
      if (hasNonAffineUsersOnThePath(dstNode->id, sibNode->id, mdg) ||
          hasNonAffineUsersOnThePath(sibNode->id, dstNode->id, mdg))
        return false;
      return true;
    };

    // Search for siblings which load the same memref block argument.
    Block *block = dstNode->op->getBlock();
    for (unsigned i = 0, e = block->getNumArguments(); i != e; ++i) {
      for (Operation *user : block->getArgument(i).getUsers()) {
        auto loadOp = dyn_cast<AffineReadOpInterface>(user);
        if (!loadOp)
          continue;
        // Gather loops surrounding 'use'.
        SmallVector<AffineForOp, 4> loops;
        getAffineForIVs(*user, &loops);
        // Skip 'use' if it is not within a loop nest.
        if (loops.empty())
          continue;
        Node *sibNode = mdg->getForOpNode(loops[0]);
        assert(sibNode != nullptr);
        // Skip 'use' if it is not a sibling to 'dstNode'.
        if (sibNode->id == dstNode->id)
          continue;
        // Skip 'use' if it has been visited.
        if (visitedSibNodeIds->count(sibNode->id) > 0)
          continue;
        // Skip 'use' if it does not load from the same memref as 'dstNode'.
        auto memref = loadOp.getMemRef();
        if (dstNode->getLoadOpCount(memref) == 0)
          continue;
        // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
        if (canFuseWithSibNode(sibNode, memref)) {
          visitedSibNodeIds->insert(sibNode->id);
          idAndMemrefToFuse->first = sibNode->id;
          idAndMemrefToFuse->second = memref;
          return true;
        }
      }
    }

    // Search for siblings by following edges through an intermediate src node.
    // Collect candidate 'dstNode' input edges in 'inEdges'.
    SmallVector<MemRefDependenceGraph::Edge, 2> inEdges;
    mdg->forEachMemRefInputEdge(
        dstNode->id, [&](MemRefDependenceGraph::Edge inEdge) {
          // Add 'inEdge' if it is a read-after-write dependence.
          if (dstNode->getLoadOpCount(inEdge.value) > 0 &&
              mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0)
            inEdges.push_back(inEdge);
        });

    // Search for sibling nodes to fuse by visiting output edges from each input
    // edge in 'inEdges'.
    for (auto &inEdge : inEdges) {
      // Collect candidate output edges from each node 'inEdge.id' in 'inEdges'.
      SmallVector<MemRefDependenceGraph::Edge, 2> outEdges;
      mdg->forEachMemRefOutputEdge(
          inEdge.id, [&](MemRefDependenceGraph::Edge outEdge) {
            unsigned sibNodeId = outEdge.id;
            if (visitedSibNodeIds->count(sibNodeId) > 0)
              return;
            // Skip output edge if not a sibling using the same memref.
            if (outEdge.id == dstNode->id || outEdge.value != inEdge.value)
              return;
            auto *sibNode = mdg->getNode(sibNodeId);
            if (!isa<AffineForOp>(sibNode->op))
              return;
            // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
            if (canFuseWithSibNode(sibNode, outEdge.value)) {
              // Add candidate 'outEdge' to sibling node.
              outEdges.push_back(outEdge);
            }
          });

      // Add first candidate if any were returned.
      if (!outEdges.empty()) {
        visitedSibNodeIds->insert(outEdges[0].id);
        idAndMemrefToFuse->first = outEdges[0].id;
        idAndMemrefToFuse->second = outEdges[0].value;
        return true;
      }
    }
    return false;
  }

  /// Update data dependence graph state to reflect sibling fusion of 'sibNode'
  /// into 'dstNode'.
  void updateStateAfterSiblingFusion(Node *sibNode, Node *dstNode) {
    // Update 'sibNode' and 'dstNode' input/output edges to reflect fusion.
    mdg->updateEdges(sibNode->id, dstNode->id);

    // Collect dst loop stats after memref privatization transformation.
    auto dstForInst = cast<AffineForOp>(dstNode->op);
    LoopNestStateCollector dstLoopCollector;
    dstLoopCollector.collect(dstForInst);
    // Clear and add back loads and stores.
    mdg->clearNodeLoadAndStores(dstNode->id);
    mdg->addToNode(dstNode->id, dstLoopCollector.loadOpInsts,
                   dstLoopCollector.storeOpInsts);
    // Remove old sibling loop nest if it no longer has outgoing dependence
    // edges, and it does not write to a memref which escapes the block.
    // NOTE(review): only the out-edge count is checked here; the
    // escaping-write condition is presumably reflected in the graph's out
    // edges -- confirm against MemRefDependenceGraph's edge invariants.
    if (mdg->getOutEdgeCount(sibNode->id) == 0) {
      Operation *op = sibNode->op;
      mdg->removeNode(sibNode->id);
      op->erase();
    }
  }

  // Clean up any allocs with no users.
  void eraseUnusedMemRefAllocations() {
    // 'memrefEdgeCount' tracks, per memref, how many dependence edges remain.
    for (auto &pair : mdg->memrefEdgeCount) {
      if (pair.second > 0)
        continue;
      auto memref = pair.first;
      // Skip if there exist other uses (return operation or function calls).
      if (!memref.use_empty())
        continue;
      // Use list expected to match the dep graph info.
      auto *op = memref.getDefiningOp();
      // Only erase allocs; other defining ops (or block arguments, which have
      // no defining op) are left untouched.
      if (isa_and_nonnull<memref::AllocOp>(op))
        op->erase();
    }
  }
};

} // namespace

/// Run fusion on `block`.
void LoopFusion::runOnBlock(Block *block) {
  MemRefDependenceGraph g(*block);
  if (!g.init()) {
    LLVM_DEBUG(llvm::dbgs() << "MDG init failed\n");
    return;
  }

  std::optional<unsigned> fastMemorySpaceOpt;
  if (fastMemorySpace.hasValue())
    fastMemorySpaceOpt = fastMemorySpace;
  // The pass option stores the threshold in KiB (see the LoopFusion
  // constructor, which divides by 1024); GreedyFusion expects bytes.
  unsigned localBufSizeThresholdBytes = localBufSizeThreshold * 1024;
  GreedyFusion fusion(&g, localBufSizeThresholdBytes, fastMemorySpaceOpt,
                      maximalFusion, computeToleranceThreshold);

  // Dispatch to the fusion strategy selected via the pass option.
  if (affineFusionMode == FusionMode::ProducerConsumer)
    fusion.runProducerConsumerFusionOnly();
  else if (affineFusionMode == FusionMode::Sibling)
    fusion.runSiblingFusionOnly();
  else
    fusion.runGreedyFusion();
}

void LoopFusion::runOnOperation() {
  // Fuse within each block independently; fusion never crosses block
  // boundaries (only same-block fusion is supported).
  for (Region &region : getOperation()->getRegions())
    for (Block &block : region.getBlocks())
      runOnBlock(&block);
}

std::unique_ptr<Pass> mlir::affine::createLoopFusionPass(
    unsigned fastMemorySpace, uint64_t localBufSizeThreshold,
    bool maximalFusion, enum FusionMode affineFusionMode) {
  return std::make_unique<LoopFusion>(fastMemorySpace, localBufSizeThreshold,
                                      maximalFusion, affineFusionMode);
}