1 //===- LoopFusion.cpp - Code to perform loop fusion -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements affine fusion. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/Passes.h" 14 15 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h" 16 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h" 17 #include "mlir/Dialect/Affine/Analysis/Utils.h" 18 #include "mlir/Dialect/Affine/IR/AffineOps.h" 19 #include "mlir/Dialect/Affine/LoopFusionUtils.h" 20 #include "mlir/Dialect/Affine/LoopUtils.h" 21 #include "mlir/Dialect/Affine/Utils.h" 22 #include "mlir/Dialect/MemRef/IR/MemRef.h" 23 #include "mlir/IR/AffineExpr.h" 24 #include "mlir/IR/AffineMap.h" 25 #include "mlir/IR/Builders.h" 26 #include "mlir/Transforms/Passes.h" 27 #include "llvm/ADT/DenseMap.h" 28 #include "llvm/ADT/DenseSet.h" 29 #include "llvm/ADT/STLExtras.h" 30 #include "llvm/ADT/SetVector.h" 31 #include "llvm/Support/CommandLine.h" 32 #include "llvm/Support/Debug.h" 33 #include "llvm/Support/raw_ostream.h" 34 #include <iomanip> 35 #include <optional> 36 #include <sstream> 37 38 namespace mlir { 39 #define GEN_PASS_DEF_AFFINELOOPFUSION 40 #include "mlir/Dialect/Affine/Passes.h.inc" 41 } // namespace mlir 42 43 #define DEBUG_TYPE "affine-loop-fusion" 44 45 using namespace mlir; 46 47 namespace { 48 /// Loop fusion pass. This pass currently supports a greedy fusion policy, 49 /// which fuses loop nests with single-writer/single-reader memref dependences 50 /// with the goal of improving locality. 
// TODO: Support fusion of source loop nests which write to multiple
// memrefs, where each memref can have multiple users (if profitable).
// TODO: Extend this pass to check for fusion preventing dependences,
// and add support for more general loop fusion algorithms.

struct LoopFusion : public impl::AffineLoopFusionBase<LoopFusion> {
  LoopFusion() = default;
  LoopFusion(unsigned fastMemorySpace, uint64_t localBufSizeThresholdBytes,
             bool maximalFusion, enum FusionMode affineFusionMode) {
    this->fastMemorySpace = fastMemorySpace;
    // The pass option is expressed in KiB; the constructor takes bytes.
    this->localBufSizeThreshold = localBufSizeThresholdBytes / 1024;
    this->maximalFusion = maximalFusion;
    this->affineFusionMode = affineFusionMode;
  }

  void runOnBlock(Block *block);
  void runOnOperation() override;
};

} // namespace

/// Returns true if node 'srcId' can be removed after fusing it with node
/// 'dstId'. The node can be removed if any of the following conditions are met:
///   1. 'srcId' has no output dependences after fusion and no escaping memrefs.
///   2. 'srcId' has no output dependences after fusion, has escaping memrefs
///      and the fusion slice is maximal.
///   3. 'srcId' has output dependences after fusion, the fusion slice is
///      maximal and the fusion insertion point dominates all the dependences.
static bool canRemoveSrcNodeAfterFusion(
    unsigned srcId, unsigned dstId, const ComputationSliceState &fusionSlice,
    Operation *fusedLoopInsPoint, const DenseSet<Value> &escapingMemRefs,
    MemRefDependenceGraph *mdg) {

  Operation *dstNodeOp = mdg->getNode(dstId)->op;
  bool hasOutDepsAfterFusion = false;

  // NOTE: operator[] default-constructs an (empty) edge list for 'srcId' if
  // none exists yet.
  for (auto &outEdge : mdg->outEdges[srcId]) {
    Operation *depNodeOp = mdg->getNode(outEdge.id)->op;
    // Skip dependence with dstOp since it will be removed after fusion.
    if (depNodeOp == dstNodeOp)
      continue;

    // Only fusion within the same block is supported. Use domination analysis
    // when needed.
    if (depNodeOp->getBlock() != dstNodeOp->getBlock())
      return false;

    // Check if the insertion point of the fused loop dominates the dependence.
    // Otherwise, the src loop can't be removed.
    if (fusedLoopInsPoint != depNodeOp &&
        !fusedLoopInsPoint->isBeforeInBlock(depNodeOp)) {
      LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: dst loop doesn't "
                                 "dominate dependence\n");
      return false;
    }

    hasOutDepsAfterFusion = true;
  }

  // If src loop has dependences after fusion or it writes to a live-out or
  // escaping memref, we can only remove it if the fusion slice is maximal so
  // that all the dependences are preserved.
  if (hasOutDepsAfterFusion || !escapingMemRefs.empty()) {
    std::optional<bool> isMaximal = fusionSlice.isMaximal();
    if (!isMaximal) {
      LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: can't determine "
                                 "if fusion is maximal\n");
      return false;
    }

    if (!*isMaximal) {
      LLVM_DEBUG(llvm::dbgs()
                 << "Src loop can't be removed: fusion is not maximal\n");
      return false;
    }
  }

  return true;
}

/// Returns in 'srcIdCandidates' the producer fusion candidates for consumer
/// 'dstId'. Candidates are sorted by node id order. This order corresponds to
/// the program order when the 'mdg' is created. However, program order is not
/// guaranteed and must not be required by the client. Program order won't be
/// held if the 'mdg' is reused from a previous fusion step or if the node
/// creation order changes in the future to support more advanced cases.
// TODO: Move this to a loop fusion utility once 'mdg' is also moved.
static void getProducerCandidates(unsigned dstId, MemRefDependenceGraph *mdg,
                                  SmallVectorImpl<unsigned> &srcIdCandidates) {
  // Skip if no input edges along which to fuse.
  if (mdg->inEdges.count(dstId) == 0)
    return;

  // Gather memrefs from loads in 'dstId'.
  auto *dstNode = mdg->getNode(dstId);
  DenseSet<Value> consumedMemrefs;
  for (Operation *load : dstNode->loads)
    consumedMemrefs.insert(cast<AffineReadOpInterface>(load).getMemRef());

  // Traverse 'dstId' incoming edges and gather the nodes that contain a store
  // to one of the consumed memrefs.
  for (auto &srcEdge : mdg->inEdges[dstId]) {
    auto *srcNode = mdg->getNode(srcEdge.id);
    // Skip if 'srcNode' is not a loop nest.
    if (!isa<AffineForOp>(srcNode->op))
      continue;

    if (any_of(srcNode->stores, [&](Operation *op) {
          auto storeOp = cast<AffineWriteOpInterface>(op);
          return consumedMemrefs.count(storeOp.getMemRef()) > 0;
        }))
      srcIdCandidates.push_back(srcNode->id);
  }

  // A node may reach 'dstId' through multiple edges; sort + unique to return
  // each candidate id exactly once.
  llvm::sort(srcIdCandidates);
  srcIdCandidates.erase(
      std::unique(srcIdCandidates.begin(), srcIdCandidates.end()),
      srcIdCandidates.end());
}

/// Returns in 'producerConsumerMemrefs' the memrefs involved in a
/// producer-consumer dependence between 'srcId' and 'dstId'.
static void
gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId,
                              MemRefDependenceGraph *mdg,
                              DenseSet<Value> &producerConsumerMemrefs) {
  auto *dstNode = mdg->getNode(dstId);
  auto *srcNode = mdg->getNode(srcId);
  // Delegates to the overload taking the store/load op lists directly.
  gatherProducerConsumerMemrefs(srcNode->stores, dstNode->loads,
                                producerConsumerMemrefs);
}

/// A memref escapes in the context of the fusion pass if either:
///   1. it (or its alias) is a block argument, or
///   2. created by an op not known to guarantee alias freedom,
///   3. it (or its alias) are used by ops other than affine dereferencing ops
///      (e.g., by call op, memref load/store ops, alias creating ops, unknown
///      ops, terminator ops, etc.); such ops do not dereference the memref in
///      an affine way.
static bool isEscapingMemref(Value memref, Block *block) {
  Operation *defOp = memref.getDefiningOp();
  // Check if 'memref' is a block argument.
  if (!defOp)
    return true;

  // Check if this is defined to be an alias of another memref. Recurses on
  // the view source so an escaping aliased memref marks this one escaping too.
  if (auto viewOp = dyn_cast<mlir::ViewLikeOpInterface>(defOp))
    if (isEscapingMemref(viewOp.getViewSource(), block))
      return true;

  // Any op besides allocating ops wouldn't guarantee alias freedom.
  if (!hasSingleEffect<mlir::MemoryEffects::Allocate>(defOp, memref))
    return true;

  // Check if 'memref' is used by a non-dereferencing op (including unknown
  // ones) (e.g., call ops, alias creating ops, etc.).
  return llvm::any_of(memref.getUsers(), [&](Operation *user) {
    // Ignore users outside of `block`.
    if (block->getParent()->findAncestorOpInRegion(*user)->getBlock() != block)
      return false;
    return !isa<AffineMapAccessInterface>(*user);
  });
}

/// Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id'
/// that escape the block or are accessed in a non-affine way.
static void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg,
                                  DenseSet<Value> &escapingMemRefs) {
  auto *node = mdg->getNode(id);
  for (Operation *storeOp : node->stores) {
    auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
    // Skip memrefs already classified as escaping.
    if (escapingMemRefs.count(memref))
      continue;
    if (isEscapingMemref(memref, &mdg->block))
      escapingMemRefs.insert(memref);
  }
}

// Initializes the data dependence graph by walking operations in `block`.
// Assigns each node in the graph a node id based on program order in 'f'.
// Returns false when the block contains constructs the pass does not support
// (non-affine region-holding ops), leaving the graph unusable.
bool MemRefDependenceGraph::init() {
  LLVM_DEBUG(llvm::dbgs() << "--- Initializing MDG ---\n");
  // Map from a memref to the set of ids of the nodes that have ops accessing
  // the memref.
  DenseMap<Value, SetVector<unsigned>> memrefAccesses;

  // Maps each top-level affine.for op to its node id, for SSA edge creation.
  DenseMap<Operation *, unsigned> forToNodeMap;
  for (Operation &op : block) {
    if (auto forOp = dyn_cast<AffineForOp>(op)) {
      // Create graph node 'id' to represent top-level 'forOp' and record
      // all loads and store accesses it contains.
      LoopNestStateCollector collector;
      collector.collect(&op);
      // Return false if a region holding op other than 'affine.for' and
      // 'affine.if' was found (not currently supported).
      if (collector.hasNonAffineRegionOp)
        return false;
      Node node(nextNodeId++, &op);
      for (auto *opInst : collector.loadOpInsts) {
        node.loads.push_back(opInst);
        auto memref = cast<AffineReadOpInterface>(opInst).getMemRef();
        memrefAccesses[memref].insert(node.id);
      }
      for (auto *opInst : collector.storeOpInsts) {
        node.stores.push_back(opInst);
        auto memref = cast<AffineWriteOpInterface>(opInst).getMemRef();
        memrefAccesses[memref].insert(node.id);
      }
      forToNodeMap[&op] = node.id;
      nodes.insert({node.id, node});
    } else if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) {
      // Create graph node for top-level load op.
      Node node(nextNodeId++, &op);
      node.loads.push_back(&op);
      auto memref = cast<AffineReadOpInterface>(op).getMemRef();
      memrefAccesses[memref].insert(node.id);
      nodes.insert({node.id, node});
    } else if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) {
      // Create graph node for top-level store op.
      Node node(nextNodeId++, &op);
      node.stores.push_back(&op);
      auto memref = cast<AffineWriteOpInterface>(op).getMemRef();
      memrefAccesses[memref].insert(node.id);
      nodes.insert({node.id, node});
    } else if (op.getNumRegions() != 0) {
      // Return false if another region is found (not currently supported).
      return false;
    } else if (op.getNumResults() > 0 && !op.use_empty()) {
      // Create graph node for top-level producer of SSA values, which
      // could be used by loop nest nodes.
      Node node(nextNodeId++, &op);
      nodes.insert({node.id, node});
    } else if (isa<CallOpInterface>(op)) {
      // Create graph node for top-level Call Op that takes any argument of
      // memref type. Call Op that returns one or more memref type results
      // is already taken care of, by the previous conditions.
      if (llvm::any_of(op.getOperandTypes(),
                       [&](Type t) { return t.isa<MemRefType>(); })) {
        Node node(nextNodeId++, &op);
        nodes.insert({node.id, node});
      }
    } else if (hasEffect<MemoryEffects::Write, MemoryEffects::Free>(&op)) {
      // Create graph node for top-level op, which could have a memory write
      // side effect.
      Node node(nextNodeId++, &op);
      nodes.insert({node.id, node});
    }
  }

  // Debug-only dump of the created nodes; (void) silences the unused-variable
  // warning in non-assert builds where LLVM_DEBUG compiles away.
  for (auto &idAndNode : nodes) {
    LLVM_DEBUG(llvm::dbgs() << "Create node " << idAndNode.first << " for:\n"
                            << *(idAndNode.second.op) << "\n");
    (void)idAndNode;
  }

  // Add dependence edges between nodes which produce SSA values and their
  // users. Load ops can be considered as the ones producing SSA values.
  for (auto &idAndNode : nodes) {
    const Node &node = idAndNode.second;
    // Stores don't define SSA values, skip them.
    if (!node.stores.empty())
      continue;
    Operation *opInst = node.op;
    for (Value value : opInst->getResults()) {
      for (Operation *user : value.getUsers()) {
        // Ignore users outside of the block.
        if (block.getParent()->findAncestorOpInRegion(*user)->getBlock() !=
            &block)
          continue;
        SmallVector<AffineForOp, 4> loops;
        getAffineForIVs(*user, &loops);
        // Only add an edge when the user sits inside a top-level loop nest.
        if (loops.empty())
          continue;
        assert(forToNodeMap.count(loops[0]) > 0 && "missing mapping");
        unsigned userLoopNestId = forToNodeMap[loops[0]];
        addEdge(node.id, userLoopNestId, value);
      }
    }
  }

  // Walk memref access lists and add graph edges between dependent nodes.
  // An edge is needed only when at least one endpoint writes the memref;
  // read-read pairs carry no dependence.
  for (auto &memrefAndList : memrefAccesses) {
    unsigned n = memrefAndList.second.size();
    for (unsigned i = 0; i < n; ++i) {
      unsigned srcId = memrefAndList.second[i];
      bool srcHasStore =
          getNode(srcId)->getStoreOpCount(memrefAndList.first) > 0;
      for (unsigned j = i + 1; j < n; ++j) {
        unsigned dstId = memrefAndList.second[j];
        bool dstHasStore =
            getNode(dstId)->getStoreOpCount(memrefAndList.first) > 0;
        if (srcHasStore || dstHasStore)
          addEdge(srcId, dstId, memrefAndList.first);
      }
    }
  }
  return true;
}

// Sinks all sequential loops to the innermost levels (while preserving
// relative order among them) and moves all parallel loops to the
// outermost (while again preserving relative order among them).
// This can increase the loop depth at which we can fuse a slice, since we are
// pushing loop carried dependence to a greater depth in the loop nest.
static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) {
  assert(isa<AffineForOp>(node->op));
  // Delegates to the AffineForOp overload; the node must track the (possibly
  // new) root loop afterwards.
  AffineForOp newRootForOp = sinkSequentialLoops(cast<AffineForOp>(node->op));
  node->op = newRootForOp;
}

// Creates and returns a private (single-user) memref for fused loop rooted
// at 'forOp', with (potentially reduced) memref size based on the
// MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
// TODO: consider refactoring the common code from generateDma and
// this one.
static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
                                 unsigned dstLoopDepth,
                                 std::optional<unsigned> fastMemorySpace,
                                 uint64_t localBufSizeThreshold) {
  Operation *forInst = forOp.getOperation();

  // Create builder to insert alloc op just before 'forOp'.
  OpBuilder b(forInst);
  // Builder to create constants at the top level.
  OpBuilder top(forInst->getParentRegion());
  // Create new memref type based on slice bounds.
  auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOpInst).getMemRef();
  auto oldMemRefType = oldMemRef.getType().cast<MemRefType>();
  unsigned rank = oldMemRefType.getRank();

  // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'.
  MemRefRegion region(srcStoreOpInst->getLoc());
  bool validRegion = succeeded(region.compute(srcStoreOpInst, dstLoopDepth));
  // (void) silences the unused-variable warning in non-assert builds.
  (void)validRegion;
  assert(validRegion && "unexpected memref region failure");
  SmallVector<int64_t, 4> newShape;
  std::vector<SmallVector<int64_t, 4>> lbs;
  SmallVector<int64_t, 8> lbDivisors;
  lbs.reserve(rank);
  // Query 'region' for 'newShape' and lower bounds of MemRefRegion accessed
  // by 'srcStoreOpInst' at depth 'dstLoopDepth'.
  std::optional<int64_t> numElements =
      region.getConstantBoundingSizeAndShape(&newShape, &lbs, &lbDivisors);
  assert(numElements && "non-constant number of elts in local buffer");

  const FlatAffineValueConstraints *cst = region.getConstraints();
  // 'outerIVs' holds the values that this memory region is symbolic/parametric
  // on; this would correspond to loop IVs surrounding the level at which the
  // slice is being materialized.
  SmallVector<Value, 8> outerIVs;
  cst->getValues(rank, cst->getNumVars(), &outerIVs);

  // Build 'rank' AffineExprs from MemRefRegion 'lbs'.
  SmallVector<AffineExpr, 4> offsets;
  offsets.reserve(rank);
  for (unsigned d = 0; d < rank; ++d) {
    assert(lbs[d].size() == cst->getNumCols() - rank && "incorrect bound size");

    // Accumulate the lower bound as a linear expression of the outer IVs;
    // the last coefficient of lbs[d] is the constant term, and the whole
    // bound is divided by lbDivisors[d] (floor division).
    AffineExpr offset = top.getAffineConstantExpr(0);
    for (unsigned j = 0, e = cst->getNumCols() - rank - 1; j < e; j++) {
      offset = offset + lbs[d][j] * top.getAffineDimExpr(j);
    }
    assert(lbDivisors[d] > 0);
    offset =
        (offset + lbs[d][cst->getNumCols() - 1 - rank]).floorDiv(lbDivisors[d]);
    offsets.push_back(offset);
  }

  // Create 'newMemRefType' using 'newShape' from MemRefRegion accessed
  // by 'srcStoreOpInst'. The buffer is placed in 'fastMemorySpace' only when
  // it fits under 'localBufSizeThreshold'.
  auto eltSize = getMemRefIntOrFloatEltSizeInBytes(oldMemRefType);
  assert(eltSize && "memrefs with size elt types expected");
  uint64_t bufSize = *eltSize * *numElements;
  unsigned newMemSpace;
  if (bufSize <= localBufSizeThreshold && fastMemorySpace.has_value()) {
    newMemSpace = *fastMemorySpace;
  } else {
    newMemSpace = oldMemRefType.getMemorySpaceAsInt();
  }
  auto newMemRefType = MemRefType::get(newShape, oldMemRefType.getElementType(),
                                       {}, newMemSpace);

  // Create new private memref for fused loop 'forOp'. 'newShape' is always
  // a constant shape.
  // TODO: Create/move alloc ops for private memrefs closer to their
  // consumer loop nests to reduce their live range. Currently they are added
  // at the beginning of the block, because loop nests can be reordered
  // during the fusion pass.
  Value newMemRef = top.create<memref::AllocOp>(forOp.getLoc(), newMemRefType);

  // Build an AffineMap to remap access functions based on lower bound offsets.
  SmallVector<AffineExpr, 4> remapExprs;
  remapExprs.reserve(rank);
  for (unsigned i = 0; i < rank; i++) {
    // Dims [0, outerIVs.size()) are the outer IVs; the memref indices follow.
    auto dimExpr = b.getAffineDimExpr(outerIVs.size() + i);

    auto remapExpr =
        simplifyAffineExpr(dimExpr - offsets[i], outerIVs.size() + rank, 0);
    remapExprs.push_back(remapExpr);
  }

  auto indexRemap =
      AffineMap::get(outerIVs.size() + rank, 0, remapExprs, forOp.getContext());

  // Replace all users of 'oldMemRef' with 'newMemRef'. The domOpFilter limits
  // the rewrite to ops dominated by the start of 'forOp's body.
  LogicalResult res =
      replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,
                               /*extraOperands=*/outerIVs,
                               /*symbolOperands=*/{},
                               /*domOpFilter=*/&*forOp.getBody()->begin());
  assert(succeeded(res) &&
         "replaceAllMemrefUsesWith should always succeed here");
  (void)res;
  return newMemRef;
}

/// Walking from node 'srcId' to node 'dstId' (exclusive of 'srcId' and
/// 'dstId'), if there is any non-affine operation accessing 'memref', return
/// true. Otherwise, return false.
static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
                                       Value memref,
                                       MemRefDependenceGraph *mdg) {
  auto *srcNode = mdg->getNode(srcId);
  auto *dstNode = mdg->getNode(dstId);
  Value::user_range users = memref.getUsers();
  // For each MemRefDependenceGraph's node that is between 'srcNode' and
  // 'dstNode' (exclusive of 'srcNodes' and 'dstNode'), check whether any
  // non-affine operation in the node accesses the 'memref'.
  for (auto &idAndNode : mdg->nodes) {
    Operation *op = idAndNode.second.op;
    // Take care of operations between 'srcNode' and 'dstNode'.
    if (srcNode->op->isBeforeInBlock(op) && op->isBeforeInBlock(dstNode->op)) {
      // Walk inside the operation to find any use of the memref.
      // Interrupt the walk if found.
      auto walkResult = op->walk([&](Operation *user) {
        // Skip affine ops.
        if (isa<AffineMapAccessInterface>(*user))
          return WalkResult::advance();
        // Find a non-affine op that uses the memref.
        if (llvm::is_contained(users, user))
          return WalkResult::interrupt();
        return WalkResult::advance();
      });
      if (walkResult.wasInterrupted())
        return true;
    }
  }
  return false;
}

/// Check whether a memref value in node 'srcId' has a non-affine user that
/// is between node 'srcId' and node 'dstId' (exclusive of 'srcNode' and
/// 'dstNode').
static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
                                       MemRefDependenceGraph *mdg) {
  // Collect memref values in node 'srcId'.
  auto *srcNode = mdg->getNode(srcId);
  llvm::SmallDenseSet<Value, 2> memRefValues;
  srcNode->op->walk([&](Operation *op) {
    // Skip affine ops.
    if (isa<AffineForOp>(op))
      return WalkResult::advance();
    for (Value v : op->getOperands())
      // Collect memref values only.
      if (v.getType().isa<MemRefType>())
        memRefValues.insert(v);
    return WalkResult::advance();
  });
  // Looking for users between node 'srcId' and node 'dstId'.
  return llvm::any_of(memRefValues, [&](Value memref) {
    return hasNonAffineUsersOnThePath(srcId, dstId, memref, mdg);
  });
}

// Checks the profitability of fusing a backwards slice of the loop nest
// surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'.
// The argument 'srcStoreOpInst' is used to calculate the storage reduction on
// the memref being produced and consumed, which is an input to the cost model.
// For producer-consumer fusion, 'srcStoreOpInst' will be the same as
// 'srcOpInst', as we are slicing w.r.t. that producer.
For input-reuse 532 // fusion, 'srcOpInst' will be the src loop nest LoadOp which reads from the 533 // same memref as dst loop nest load ops, and 'srcStoreOpInst' will be the 534 // unique store op in the src node, which will be used to check that the write 535 // region is the same after input-reuse fusion. Computation slices are provided 536 // in 'depthSliceUnions' for each legal fusion depth. The maximal depth at which 537 // fusion is legal is provided in 'maxLegalFusionDepth'. Returns true if it is 538 // profitable to fuse the candidate loop nests. Returns false otherwise. 539 // `dstLoopDepth` is set to the most profitable depth at which to materialize 540 // the source loop nest slice. 541 // The profitability model executes the following steps: 542 // *) Computes the backward computation slice at 'srcOpInst'. This 543 // computation slice of the loop nest surrounding 'srcOpInst' is 544 // represented by modified src loop bounds in 'sliceState', which are 545 // functions of loop IVs in the loop nest surrounding 'srcOpInst'. 546 // *) Computes the cost of unfused src/dst loop nests (currently the cost of a 547 // loop nest is the total number of dynamic operation instances in the loop 548 // nest). 549 // *) Computes the cost of fusing a slice of the src loop nest into the dst 550 // loop nest at various values of dst loop depth, attempting to fuse 551 // the largest computation slice at the maximal dst loop depth (closest to 552 // the load) to minimize reuse distance and potentially enable subsequent 553 // load/store forwarding. 554 // NOTE: 'dstLoopDepth' refers to the loop depth within the destination loop 555 // nest, at which the src computation slice is inserted/fused. 
556 // NOTE: We attempt to maximize the dst loop depth, but there are cases 557 // where a particular setting for 'dstLoopNest' might fuse an unsliced 558 // loop (within the src computation slice) at a depth which results in 559 // excessive recomputation (see unit tests for examples). 560 // *) Compares the total cost of the unfused loop nests to the min cost fused 561 // loop nest computed in the previous step, and returns true if the latter 562 // is lower. 563 // TODO: Extend profitability analysis to support scenarios with multiple 564 // stores. 565 static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, 566 AffineForOp dstForOp, 567 ArrayRef<ComputationSliceState> depthSliceUnions, 568 unsigned maxLegalFusionDepth, 569 unsigned *dstLoopDepth, 570 double computeToleranceThreshold) { 571 LLVM_DEBUG({ 572 llvm::dbgs() << "Checking whether fusion is profitable between src op:\n"; 573 llvm::dbgs() << ' ' << *srcOpInst << " and destination loop:\n"; 574 llvm::dbgs() << dstForOp << "\n"; 575 }); 576 577 if (maxLegalFusionDepth == 0) { 578 LLVM_DEBUG(llvm::dbgs() << "Can't fuse: maxLegalFusionDepth is 0\n"); 579 return false; 580 } 581 582 // Compute cost of sliced and unsliced src loop nest. 583 SmallVector<AffineForOp, 4> srcLoopIVs; 584 getAffineForIVs(*srcOpInst, &srcLoopIVs); 585 586 // Walk src loop nest and collect stats. 587 LoopNestStats srcLoopNestStats; 588 if (!getLoopNestStats(srcLoopIVs[0], &srcLoopNestStats)) 589 return false; 590 591 // Compute cost of dst loop nest. 592 LoopNestStats dstLoopNestStats; 593 if (!getLoopNestStats(dstForOp, &dstLoopNestStats)) 594 return false; 595 596 // Search for min cost value for 'dstLoopDepth'. At each value of 597 // 'dstLoopDepth' from 'maxLegalLoopDepth' to '1', compute computation slice 598 // bounds between 'srcOpInst' and each op in 'dstOpinsts' (taking the union 599 // of these bounds). 
Next the union slice bounds are used to calculate 600 // the cost of the slice and the cost of the slice inserted into the dst 601 // loop nest at 'dstLoopDepth'. 602 uint64_t minFusedLoopNestComputeCost = std::numeric_limits<uint64_t>::max(); 603 double maxStorageReduction = 0.0; 604 std::optional<uint64_t> sliceMemEstimate; 605 606 // The best loop depth at which to materialize the slice. 607 std::optional<unsigned> bestDstLoopDepth; 608 609 // Compute op instance count for the src loop nest without iteration slicing. 610 uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], srcLoopNestStats); 611 612 // Compute src loop nest write region size. 613 MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc()); 614 if (failed(srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0))) { 615 LLVM_DEBUG(llvm::dbgs() 616 << "Unable to compute MemRefRegion for source operation\n"); 617 return false; 618 } 619 620 std::optional<int64_t> maybeSrcWriteRegionSizeBytes = 621 srcWriteRegion.getRegionSize(); 622 if (!maybeSrcWriteRegionSizeBytes.has_value()) 623 return false; 624 int64_t srcWriteRegionSizeBytes = *maybeSrcWriteRegionSizeBytes; 625 626 // Compute op instance count for the src loop nest. 627 uint64_t dstLoopNestCost = getComputeCost(dstForOp, dstLoopNestStats); 628 629 // Evaluate all depth choices for materializing the slice in the destination 630 // loop nest. 631 for (unsigned i = maxLegalFusionDepth; i >= 1; --i) { 632 const ComputationSliceState &slice = depthSliceUnions[i - 1]; 633 // Skip slice union if it wasn't computed for this depth. 
634 if (slice.isEmpty()) 635 continue; 636 637 int64_t fusedLoopNestComputeCost; 638 if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstForOp, 639 dstLoopNestStats, slice, 640 &fusedLoopNestComputeCost)) { 641 LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost\n"); 642 continue; 643 } 644 645 double additionalComputeFraction = 646 fusedLoopNestComputeCost / 647 (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) - 648 1; 649 650 // Determine what the slice write MemRefRegion would be, if the src loop 651 // nest slice 'slice' were to be inserted into the dst loop nest at loop 652 // depth 'i'. 653 MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc()); 654 if (failed(sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0, 655 &slice))) { 656 LLVM_DEBUG(llvm::dbgs() 657 << "Failed to compute slice write region at loopDepth: " << i 658 << "\n"); 659 continue; 660 } 661 662 std::optional<int64_t> maybeSliceWriteRegionSizeBytes = 663 sliceWriteRegion.getRegionSize(); 664 if (!maybeSliceWriteRegionSizeBytes.has_value() || 665 *maybeSliceWriteRegionSizeBytes == 0) { 666 LLVM_DEBUG(llvm::dbgs() 667 << "Failed to get slice write region size at loopDepth: " << i 668 << "\n"); 669 continue; 670 } 671 int64_t sliceWriteRegionSizeBytes = *maybeSliceWriteRegionSizeBytes; 672 673 // If we are fusing for reuse, check that write regions remain the same. 674 // TODO: Write region check should check sizes and offsets in 675 // each dimension, so that we are sure they are covering the same memref 676 // region. Also, move this out to a isMemRefRegionSuperSet helper function. 
677 if (srcOpInst != srcStoreOpInst && 678 sliceWriteRegionSizeBytes != srcWriteRegionSizeBytes) 679 continue; 680 681 double storageReduction = static_cast<double>(srcWriteRegionSizeBytes) / 682 static_cast<double>(sliceWriteRegionSizeBytes); 683 684 LLVM_DEBUG({ 685 std::stringstream msg; 686 msg << " evaluating fusion profitability at depth : " << i << "\n" 687 << std::fixed << std::setprecision(2) 688 << " additional compute fraction: " 689 << 100.0 * additionalComputeFraction << "%\n" 690 << " storage reduction factor: " << storageReduction << "x\n" 691 << " fused nest cost: " << fusedLoopNestComputeCost << "\n" 692 << " src write region size: " << srcWriteRegionSizeBytes << "\n" 693 << " slice write region size: " << sliceWriteRegionSizeBytes 694 << "\n"; 695 llvm::dbgs() << msg.str(); 696 }); 697 698 // TODO: This is a placeholder cost model. 699 // Among all choices that add an acceptable amount of redundant computation 700 // (as per computeToleranceThreshold), we will simply pick the one that 701 // reduces the intermediary size the most. 702 if ((storageReduction > maxStorageReduction) && 703 (additionalComputeFraction < computeToleranceThreshold)) { 704 maxStorageReduction = storageReduction; 705 bestDstLoopDepth = i; 706 minFusedLoopNestComputeCost = fusedLoopNestComputeCost; 707 sliceMemEstimate = sliceWriteRegionSizeBytes; 708 } 709 } 710 711 // A simple cost model: fuse if it reduces the memory footprint. 712 713 if (!bestDstLoopDepth) { 714 LLVM_DEBUG( 715 llvm::dbgs() 716 << "All fusion choices involve more than the threshold amount of " 717 "redundant computation; NOT fusing.\n"); 718 return false; 719 } 720 721 if (!bestDstLoopDepth) { 722 LLVM_DEBUG(llvm::dbgs() << "no fusion depth could be evaluated.\n"); 723 return false; 724 } 725 726 // Set dstLoopDepth based on best values from search. 
727 *dstLoopDepth = *bestDstLoopDepth; 728 729 LLVM_DEBUG( 730 llvm::dbgs() << " LoopFusion fusion stats:" 731 << "\n best loop depth: " << bestDstLoopDepth 732 << "\n src loop nest compute cost: " << srcLoopNestCost 733 << "\n dst loop nest compute cost: " << dstLoopNestCost 734 << "\n fused loop nest compute cost: " 735 << minFusedLoopNestComputeCost << "\n"); 736 737 auto dstMemSize = getMemoryFootprintBytes(dstForOp); 738 auto srcMemSize = getMemoryFootprintBytes(srcLoopIVs[0]); 739 740 std::optional<double> storageReduction; 741 742 if (!dstMemSize || !srcMemSize) { 743 LLVM_DEBUG(llvm::dbgs() 744 << " fusion memory benefit cannot be evaluated; NOT fusing.\n"); 745 return false; 746 } 747 748 auto srcMemSizeVal = *srcMemSize; 749 auto dstMemSizeVal = *dstMemSize; 750 751 assert(sliceMemEstimate && "expected value"); 752 auto fusedMem = dstMemSizeVal + *sliceMemEstimate; 753 754 LLVM_DEBUG(llvm::dbgs() << " src mem: " << srcMemSizeVal << "\n" 755 << " dst mem: " << dstMemSizeVal << "\n" 756 << " fused mem: " << fusedMem << "\n" 757 << " slice mem: " << sliceMemEstimate << "\n"); 758 759 if (static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) { 760 LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n"); 761 return false; 762 } 763 storageReduction = 764 100.0 * 765 (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal)); 766 767 double additionalComputeFraction = 768 100.0 * (minFusedLoopNestComputeCost / 769 (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) - 770 1); 771 (void)additionalComputeFraction; 772 LLVM_DEBUG({ 773 std::stringstream msg; 774 msg << " fusion is most profitable at depth " << *dstLoopDepth << " with " 775 << std::setprecision(2) << additionalComputeFraction 776 << "% redundant computation and a "; 777 msg << (storageReduction ? 
std::to_string(*storageReduction) : "<unknown>"); 778 msg << "% storage reduction.\n"; 779 llvm::dbgs() << msg.str(); 780 }); 781 782 return true; 783 } 784 785 namespace { 786 787 // GreedyFusion greedily fuses loop nests which have a producer/consumer or 788 // input-reuse relationship on a memref, with the goal of improving locality. 789 // 790 // The steps of the producer-consumer fusion algorithm are as follows: 791 // 792 // *) A worklist is initialized with node ids from the dependence graph. 793 // *) For each node id in the worklist: 794 // *) Pop an AffineForOp of the worklist. This 'dstAffineForOp' will be a 795 // candidate destination AffineForOp into which fusion will be attempted. 796 // *) Add each LoadOp currently in 'dstAffineForOp' into list 'dstLoadOps'. 797 // *) For each LoadOp in 'dstLoadOps' do: 798 // *) Look up dependent loop nests which have a single store op to the same 799 // memref. 800 // *) Check if dependences would be violated by the fusion. 801 // *) Get a computation slice of 'srcLoopNest', which adjusts its loop 802 // bounds to be functions of 'dstLoopNest' IVs and symbols. 803 // *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest', 804 // at a loop depth determined by the cost model in 'isFusionProfitable'. 805 // *) Add the newly fused load/store operations to the state, 806 // and also add newly fused load ops to 'dstLoopOps' to be considered 807 // as fusion dst load ops in another iteration. 808 // *) Remove old src loop nest and its associated state. 809 // 810 // The steps of the input-reuse fusion algorithm are as follows: 811 // 812 // *) Initialize 'worklist' with node ids from the dependence graph. 813 // *) For each 'dstNode' in the worklist: 814 // *) Find a candidate sibling node 'sibNode' to fuse with 'dstNode' which 815 // loads from the same memref, but which has no dependence paths to/from. 
816 // *) Get a computation slice of 'sibLoopNest', which adjusts its loop 817 // bounds to be functions of 'dstLoopNest' IVs and symbols. 818 // *) Fuse the 'sibLoopNest' computation slice into the 'dstLoopNest', 819 // at a loop depth determined by the cost model in 'isFusionProfitable'. 820 // This function also checks that the memref write region of 'sibLoopNest', 821 // is preserved in the fused loop nest. 822 // *) Update graph state to reflect the fusion of 'sibNode' into 'dstNode'. 823 // 824 // Given a graph where top-level operations are vertices in the set 'V' and 825 // edges in the set 'E' are dependences between vertices, this algorithm 826 // takes O(V) time for initialization, and has runtime O(V + E). 827 // 828 // This greedy algorithm is not 'maximal' due to the current restriction of 829 // fusing along single producer consumer edges, but there is a TODO: to fix 830 // this. 831 // 832 // TODO: Experiment with other fusion policies. 833 struct GreedyFusion { 834 public: 835 // The data dependence graph to traverse during fusion. 836 MemRefDependenceGraph *mdg; 837 // Worklist of graph nodes visited during the fusion pass. 838 SmallVector<unsigned, 8> worklist; 839 // Parameter for local buffer size threshold. 840 unsigned localBufSizeThreshold; 841 // Parameter for fast memory space. 842 std::optional<unsigned> fastMemorySpace; 843 // If true, ignore any additional (redundant) computation tolerance threshold 844 // that would have prevented fusion. 845 bool maximalFusion; 846 // The amount of additional computation that is tolerated while fusing 847 // pair-wise as a fraction of the total computation. 
848 double computeToleranceThreshold; 849 850 using Node = MemRefDependenceGraph::Node; 851 852 GreedyFusion(MemRefDependenceGraph *mdg, unsigned localBufSizeThreshold, 853 std::optional<unsigned> fastMemorySpace, bool maximalFusion, 854 double computeToleranceThreshold) 855 : mdg(mdg), localBufSizeThreshold(localBufSizeThreshold), 856 fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion), 857 computeToleranceThreshold(computeToleranceThreshold) {} 858 859 /// Initializes 'worklist' with nodes from 'mdg'. 860 void init() { 861 // TODO: Add a priority queue for prioritizing nodes by different 862 // metrics (e.g. arithmetic intensity/flops-to-bytes ratio). 863 worklist.clear(); 864 for (auto &idAndNode : mdg->nodes) { 865 const Node &node = idAndNode.second; 866 worklist.push_back(node.id); 867 } 868 } 869 /// Run only sibling fusion on the `mdg`. 870 void runSiblingFusionOnly() { 871 fuseSiblingNodes(); 872 eraseUnusedMemRefAllocations(); 873 } 874 875 /// Run only producer/consumer fusion on the `mdg`. 876 void runProducerConsumerFusionOnly() { 877 fuseProducerConsumerNodes( 878 /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max()); 879 eraseUnusedMemRefAllocations(); 880 } 881 882 // Run the GreedyFusion pass. 883 // *) First pass through the nodes fuses single-use producer nodes into their 884 // unique consumer. 885 // *) Second pass fuses sibling nodes which share no dependence edges. 886 // *) Third pass fuses any remaining producer nodes into their users. 887 void runGreedyFusion() { 888 // TODO: Run this repeatedly until a fixed-point is reached. 889 fuseProducerConsumerNodes(/*maxSrcUserCount=*/1); 890 fuseSiblingNodes(); 891 fuseProducerConsumerNodes( 892 /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max()); 893 eraseUnusedMemRefAllocations(); 894 } 895 896 /// Returns true if a private memref can be created for `memref` given 897 /// the fusion scenario reflected by the other arguments. 
898 bool canCreatePrivateMemRef(Value memref, 899 const DenseSet<Value> &srcEscapingMemRefs, 900 unsigned producerId, unsigned consumerId, 901 bool removeSrcNode) { 902 const Node *consumerNode = mdg->getNode(consumerId); 903 // If `memref` is an escaping one, do not create a private memref 904 // for the below scenarios, since doing so will leave the escaping 905 // memref unmodified as all the writes originally meant for the 906 // escaping memref would be performed on the private memref: 907 // 1. The source is to be removed after fusion, 908 // OR 909 // 2. The destination writes to `memref`. 910 if (srcEscapingMemRefs.count(memref) > 0 && 911 (removeSrcNode || consumerNode->getStoreOpCount(memref) > 0)) 912 return false; 913 914 // Don't create a private memref if 'srcNode' has in edges on 915 // 'memref' or 'dstNode' has out edges on 'memref'. 916 if (mdg->getIncomingMemRefAccesses(producerId, memref) > 0 || 917 mdg->getOutEdgeCount(consumerId, memref) > 0) 918 return false; 919 920 // If 'srcNode' will be removed but it has out edges on 'memref' to 921 // nodes other than 'dstNode', we have to preserve dependences and 922 // cannot create a private memref. 923 if (removeSrcNode && 924 any_of(mdg->outEdges[producerId], [&](const auto &edge) { 925 return edge.value == memref && edge.id != consumerId; 926 })) 927 return false; 928 929 return true; 930 } 931 932 /// Perform fusions with node `dstId` as the destination of fusion, with 933 /// No fusion is performed when producers with a user count greater than 934 /// `maxSrcUserCount` for any of the memrefs involved. 935 void performFusionsIntoDest(unsigned dstId, unsigned maxSrcUserCount) { 936 LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n"); 937 // Skip if this node was removed (fused into another node). 938 if (mdg->nodes.count(dstId) == 0) 939 return; 940 // Get 'dstNode' into which to attempt fusion. 941 auto *dstNode = mdg->getNode(dstId); 942 // Skip if 'dstNode' is not a loop nest. 
943 if (!isa<AffineForOp>(dstNode->op)) 944 return; 945 // Skip if 'dstNode' is a loop nest returning values. 946 // TODO: support loop nests that return values. 947 if (dstNode->op->getNumResults() > 0) 948 return; 949 950 LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n"); 951 952 // Sink sequential loops in 'dstNode' (and thus raise parallel loops) 953 // while preserving relative order. This can increase the maximum loop 954 // depth at which we can fuse a slice of a producer loop nest into a 955 // consumer loop nest. 956 sinkSequentialLoops(dstNode); 957 auto dstAffineForOp = cast<AffineForOp>(dstNode->op); 958 959 // Try to fuse 'dstNode' with candidate producer loops until a fixed point 960 // is reached. Fusing two loops may expose new fusion opportunities. 961 bool dstNodeChanged; 962 do { 963 // Gather src loop candidates for 'dstNode' and visit them in "quasi" 964 // reverse program order to minimize the number of iterations needed to 965 // reach the fixed point. Note that this is a best effort approach since 966 // 'getProducerCandidates' does not always guarantee that program order 967 // in 'srcIdCandidates'. 968 dstNodeChanged = false; 969 SmallVector<unsigned, 16> srcIdCandidates; 970 getProducerCandidates(dstId, mdg, srcIdCandidates); 971 972 for (unsigned srcId : llvm::reverse(srcIdCandidates)) { 973 // Get 'srcNode' from which to attempt fusion into 'dstNode'. 974 auto *srcNode = mdg->getNode(srcId); 975 auto srcAffineForOp = cast<AffineForOp>(srcNode->op); 976 LLVM_DEBUG(llvm::dbgs() << "Evaluating src loop " << srcId 977 << " for dst loop " << dstId << "\n"); 978 979 // Skip if 'srcNode' is a loop nest returning values. 980 // TODO: support loop nests that return values. 
981 if (isa<AffineForOp>(srcNode->op) && srcNode->op->getNumResults() > 0) 982 continue; 983 984 DenseSet<Value> producerConsumerMemrefs; 985 gatherProducerConsumerMemrefs(srcId, dstId, mdg, 986 producerConsumerMemrefs); 987 988 // Skip if 'srcNode' out edge count on any memref is greater than 989 // 'maxSrcUserCount'. 990 if (any_of(producerConsumerMemrefs, [&](Value memref) { 991 return mdg->getOutEdgeCount(srcNode->id, memref) > 992 maxSrcUserCount; 993 })) 994 continue; 995 996 // Gather memrefs in 'srcNode' that are written and escape out of the 997 // block (e.g., memref block arguments, returned memrefs, 998 // memrefs passed to function calls, etc.). 999 DenseSet<Value> srcEscapingMemRefs; 1000 gatherEscapingMemrefs(srcNode->id, mdg, srcEscapingMemRefs); 1001 1002 // Skip if there are non-affine operations in between the 'srcNode' 1003 // and 'dstNode' using their memrefs. If so, we wouldn't be able to 1004 // compute a legal insertion point for now. 'srcNode' and 'dstNode' 1005 // memrefs with non-affine operation users would be considered 1006 // escaping memrefs so we can limit this check to only scenarios with 1007 // escaping memrefs. 1008 if (!srcEscapingMemRefs.empty() && 1009 hasNonAffineUsersOnThePath(srcId, dstId, mdg)) { 1010 LLVM_DEBUG(llvm::dbgs() 1011 << "Can't fuse: non-affine users in between the loops\n"); 1012 continue; 1013 } 1014 1015 // Compute an operation list insertion point for the fused loop 1016 // nest which preserves dependences. 1017 Operation *fusedLoopInsPoint = 1018 mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id); 1019 if (fusedLoopInsPoint == nullptr) 1020 continue; 1021 1022 // Compute the innermost common loop depth for dstNode 1023 // producer-consumer loads/stores. 
1024 SmallVector<Operation *, 2> dstMemrefOps; 1025 for (Operation *op : dstNode->loads) 1026 if (producerConsumerMemrefs.count( 1027 cast<AffineReadOpInterface>(op).getMemRef()) > 0) 1028 dstMemrefOps.push_back(op); 1029 for (Operation *op : dstNode->stores) 1030 if (producerConsumerMemrefs.count( 1031 cast<AffineWriteOpInterface>(op).getMemRef())) 1032 dstMemrefOps.push_back(op); 1033 unsigned dstLoopDepthTest = getInnermostCommonLoopDepth(dstMemrefOps); 1034 1035 // Check the feasibility of fusing src loop nest into dst loop nest 1036 // at loop depths in range [1, dstLoopDepthTest]. 1037 unsigned maxLegalFusionDepth = 0; 1038 SmallVector<ComputationSliceState, 8> depthSliceUnions; 1039 depthSliceUnions.resize(dstLoopDepthTest); 1040 FusionStrategy strategy(FusionStrategy::ProducerConsumer); 1041 for (unsigned i = 1; i <= dstLoopDepthTest; ++i) { 1042 FusionResult result = mlir::canFuseLoops( 1043 srcAffineForOp, dstAffineForOp, 1044 /*dstLoopDepth=*/i, &depthSliceUnions[i - 1], strategy); 1045 1046 if (result.value == FusionResult::Success) 1047 maxLegalFusionDepth = i; 1048 } 1049 1050 if (maxLegalFusionDepth == 0) { 1051 LLVM_DEBUG(llvm::dbgs() 1052 << "Can't fuse: fusion is not legal at any depth\n"); 1053 continue; 1054 } 1055 1056 // Check if fusion would be profitable. We skip profitability analysis 1057 // for maximal fusion since we already know the maximal legal depth to 1058 // fuse. 1059 unsigned bestDstLoopDepth = maxLegalFusionDepth; 1060 if (!maximalFusion) { 1061 // Retrieve producer stores from the src loop. 1062 SmallVector<Operation *, 2> producerStores; 1063 for (Operation *op : srcNode->stores) 1064 if (producerConsumerMemrefs.count( 1065 cast<AffineWriteOpInterface>(op).getMemRef())) 1066 producerStores.push_back(op); 1067 1068 // TODO: Suppport multiple producer stores in profitability 1069 // analysis. We limit profitability analysis to only scenarios with 1070 // a single producer store for now. 
Note that some multi-store 1071 // producer scenarios will still go through profitability analysis 1072 // if only one of the stores is involved the producer-consumer 1073 // relationship of the candidate loops. 1074 assert(!producerStores.empty() && "Expected producer store"); 1075 if (producerStores.size() > 1) 1076 LLVM_DEBUG(llvm::dbgs() << "Skipping profitability analysis. Not " 1077 "supported for this case\n"); 1078 else if (!isFusionProfitable(producerStores[0], producerStores[0], 1079 dstAffineForOp, depthSliceUnions, 1080 maxLegalFusionDepth, &bestDstLoopDepth, 1081 computeToleranceThreshold)) 1082 continue; 1083 } 1084 1085 assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth"); 1086 ComputationSliceState &bestSlice = 1087 depthSliceUnions[bestDstLoopDepth - 1]; 1088 assert(!bestSlice.isEmpty() && "Missing slice union for depth"); 1089 1090 // Determine if 'srcId' can be removed after fusion, taking into 1091 // account remaining dependences, escaping memrefs and the fusion 1092 // insertion point. 1093 bool removeSrcNode = canRemoveSrcNodeAfterFusion( 1094 srcId, dstId, bestSlice, fusedLoopInsPoint, srcEscapingMemRefs, 1095 mdg); 1096 1097 DenseSet<Value> privateMemrefs; 1098 for (Value memref : producerConsumerMemrefs) { 1099 if (canCreatePrivateMemRef(memref, srcEscapingMemRefs, srcId, dstId, 1100 removeSrcNode)) { 1101 // Create a private version of this memref. 1102 LLVM_DEBUG(llvm::dbgs() 1103 << "Creating private memref for " << memref << '\n'); 1104 // Create a private version of this memref. 1105 privateMemrefs.insert(memref); 1106 } 1107 } 1108 1109 // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'. 
1110 fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice); 1111 dstNodeChanged = true; 1112 1113 LLVM_DEBUG(llvm::dbgs() 1114 << "Fused src loop " << srcId << " into dst loop " << dstId 1115 << " at depth " << bestDstLoopDepth << ":\n" 1116 << dstAffineForOp << "\n"); 1117 1118 // Move 'dstAffineForOp' before 'insertPointInst' if needed. 1119 if (fusedLoopInsPoint != dstAffineForOp) 1120 dstAffineForOp->moveBefore(fusedLoopInsPoint); 1121 1122 // Update edges between 'srcNode' and 'dstNode'. 1123 mdg->updateEdges(srcNode->id, dstNode->id, privateMemrefs, 1124 removeSrcNode); 1125 1126 // Create private memrefs. 1127 if (!privateMemrefs.empty()) { 1128 // Gather stores for all the private-to-be memrefs. 1129 DenseMap<Value, SmallVector<Operation *, 4>> privateMemRefToStores; 1130 dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) { 1131 Value storeMemRef = storeOp.getMemRef(); 1132 if (privateMemrefs.count(storeMemRef) > 0) 1133 privateMemRefToStores[storeMemRef].push_back(storeOp); 1134 }); 1135 1136 // Replace original memrefs with private memrefs. Note that all the 1137 // loads and stores on these memrefs will be replaced with a new 1138 // loads and stores. Any reference to the original ones becomes 1139 // invalid after this point. 1140 for (auto &memrefToStoresPair : privateMemRefToStores) { 1141 // TODO: Use union of memref write regions to compute 1142 // private memref footprint. 1143 SmallVector<Operation *, 4> &storesForMemref = 1144 memrefToStoresPair.second; 1145 Value newMemRef = createPrivateMemRef( 1146 dstAffineForOp, storesForMemref[0], bestDstLoopDepth, 1147 fastMemorySpace, localBufSizeThreshold); 1148 // Create new node in dependence graph for 'newMemRef' alloc op. 1149 unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp()); 1150 // Add edge from 'newMemRef' node to dstNode. 
1151 mdg->addEdge(newMemRefNodeId, dstId, newMemRef); 1152 } 1153 // One or more entries for 'newMemRef' alloc op are inserted into 1154 // the DenseMap mdg->nodes. Since an insertion may cause DenseMap to 1155 // reallocate, update dstNode. 1156 dstNode = mdg->getNode(dstId); 1157 } 1158 1159 // Collect dst loop stats after memref privatization transformation. 1160 LoopNestStateCollector dstLoopCollector; 1161 dstLoopCollector.collect(dstAffineForOp); 1162 1163 // Clear and add back loads and stores. 1164 mdg->clearNodeLoadAndStores(dstNode->id); 1165 mdg->addToNode(dstId, dstLoopCollector.loadOpInsts, 1166 dstLoopCollector.storeOpInsts); 1167 1168 if (removeSrcNode) { 1169 LLVM_DEBUG(llvm::dbgs() 1170 << "Removing src loop " << srcId << " after fusion\n"); 1171 // srcNode is no longer valid after it is removed from mdg. 1172 srcAffineForOp.erase(); 1173 mdg->removeNode(srcId); 1174 srcNode = nullptr; 1175 } 1176 } 1177 } while (dstNodeChanged); 1178 } 1179 1180 /// Visit each node in the graph, and for each node, attempt to fuse it with 1181 /// producer-consumer candidates. No fusion is performed when producers with a 1182 /// user count greater than `maxSrcUserCount` for any of the memrefs involved 1183 /// are encountered. 1184 void fuseProducerConsumerNodes(unsigned maxSrcUserCount) { 1185 LLVM_DEBUG(llvm::dbgs() << "--- Producer/Consumer Fusion ---\n"); 1186 init(); 1187 while (!worklist.empty()) { 1188 unsigned dstId = worklist.back(); 1189 worklist.pop_back(); 1190 performFusionsIntoDest(dstId, maxSrcUserCount); 1191 } 1192 } 1193 1194 // Visits each node in the graph, and for each node, attempts to fuse it with 1195 // its sibling nodes (nodes which share a parent, but no dependence edges). 
  void fuseSiblingNodes() {
    LLVM_DEBUG(llvm::dbgs() << "--- Sibling Fusion ---\n");
    init();
    while (!worklist.empty()) {
      unsigned dstId = worklist.back();
      worklist.pop_back();

      // Skip if this node was removed (fused into another node).
      if (mdg->nodes.count(dstId) == 0)
        continue;
      // Get 'dstNode' into which to attempt fusion.
      auto *dstNode = mdg->getNode(dstId);
      // Skip if 'dstNode' is not a loop nest.
      if (!isa<AffineForOp>(dstNode->op))
        continue;
      // Attempt to fuse 'dstNode' with its sibling nodes in the graph.
      fuseWithSiblingNodes(dstNode);
    }
  }

  // Attempt to fuse 'dstNode' with sibling nodes in the graph. Candidates are
  // found one at a time by 'findSiblingNodeToFuse' until no more are returned.
  void fuseWithSiblingNodes(Node *dstNode) {
    // Sibling node ids already considered; each candidate is visited at most
    // once.
    DenseSet<unsigned> visitedSibNodeIds;
    std::pair<unsigned, Value> idAndMemref;
    auto dstAffineForOp = cast<AffineForOp>(dstNode->op);

    while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) {
      unsigned sibId = idAndMemref.first;
      Value memref = idAndMemref.second;
      // TODO: Check that 'sibStoreOpInst' post-dominates all other
      // stores to the same memref in 'sibNode' loop nest.
      auto *sibNode = mdg->getNode(sibId);
      // Compute an operation list insertion point for the fused loop
      // nest which preserves dependences. The argument order depends on which
      // of the two nests comes first in the block.
      assert(sibNode->op->getBlock() == dstNode->op->getBlock());
      Operation *insertPointInst =
          sibNode->op->isBeforeInBlock(dstNode->op)
              ? mdg->getFusedLoopNestInsertionPoint(sibNode->id, dstNode->id)
              : mdg->getFusedLoopNestInsertionPoint(dstNode->id, sibNode->id);
      if (insertPointInst == nullptr)
        continue;

      // Check if fusion would be profitable and at what depth.

      // Get unique 'sibNode' load op to 'memref'.
      SmallVector<Operation *, 2> sibLoadOpInsts;
      sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts);
      // Currently findSiblingNodeToFuse searches for siblings with one load.
      assert(sibLoadOpInsts.size() == 1);
      Operation *sibLoadOpInst = sibLoadOpInsts[0];

      // Gather 'dstNode' load ops to 'memref'.
      SmallVector<Operation *, 2> dstLoadOpInsts;
      dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts);

      // The fusion depth search is bounded by the number of loops surrounding
      // the first dst load of 'memref'.
      SmallVector<AffineForOp, 4> dstLoopIVs;
      getAffineForIVs(*dstLoadOpInsts[0], &dstLoopIVs);
      unsigned dstLoopDepthTest = dstLoopIVs.size();
      auto sibAffineForOp = cast<AffineForOp>(sibNode->op);

      // Compute loop depth and slice union for fusion.
      SmallVector<ComputationSliceState, 8> depthSliceUnions;
      depthSliceUnions.resize(dstLoopDepthTest);
      unsigned maxLegalFusionDepth = 0;
      FusionStrategy strategy(memref);
      for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
        FusionResult result = mlir::canFuseLoops(
            sibAffineForOp, dstAffineForOp,
            /*dstLoopDepth=*/i, &depthSliceUnions[i - 1], strategy);

        if (result.value == FusionResult::Success)
          maxLegalFusionDepth = i;
      }

      // Skip if fusion is not feasible at any loop depths.
      if (maxLegalFusionDepth == 0)
        continue;

      unsigned bestDstLoopDepth = maxLegalFusionDepth;
      if (!maximalFusion) {
        // Check if fusion would be profitable. For sibling fusion, the sibling
        // load op is treated as the src "store" op for fusion profitability
        // purposes. The footprint of the load in the slice relative to the
        // unfused source's determines reuse.
        if (!isFusionProfitable(sibLoadOpInst, sibLoadOpInst, dstAffineForOp,
                                depthSliceUnions, maxLegalFusionDepth,
                                &bestDstLoopDepth, computeToleranceThreshold))
          continue;
      }

      assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
      assert(!depthSliceUnions[bestDstLoopDepth - 1].isEmpty() &&
             "Fusion depth has no computed slice union");
      // Check if source loop is being inserted in the innermost
      // destination loop. Based on this, the fused loop may be optimized
      // further inside `fuseLoops`.
      bool isInnermostInsertion = (bestDstLoopDepth == dstLoopDepthTest);
      // Fuse computation slice of 'sibLoopNest' into 'dstLoopNest'.
      mlir::fuseLoops(sibAffineForOp, dstAffineForOp,
                      depthSliceUnions[bestDstLoopDepth - 1],
                      isInnermostInsertion);

      auto dstForInst = cast<AffineForOp>(dstNode->op);
      // Update operation position of fused loop nest (if needed).
      if (insertPointInst != dstForInst) {
        dstForInst->moveBefore(insertPointInst);
      }
      // Update data dependence graph state post fusion.
      updateStateAfterSiblingFusion(sibNode, dstNode);
    }
  }

  // Searches block argument uses and the graph from 'dstNode' looking for a
  // fusion candidate sibling node which shares no dependences with 'dstNode'
  // but which loads from the same memref. Returns true and sets
  // 'idAndMemrefToFuse' on success. Returns false otherwise.
  bool findSiblingNodeToFuse(Node *dstNode,
                             DenseSet<unsigned> *visitedSibNodeIds,
                             std::pair<unsigned, Value> *idAndMemrefToFuse) {
    // Returns true if 'sibNode' can be fused with 'dstNode' for input reuse
    // on 'memref'.
    auto canFuseWithSibNode = [&](Node *sibNode, Value memref) {
      // Skip if 'outEdge' is not a read-after-write dependence.
      // TODO: Remove restrict to single load op restriction.
      if (sibNode->getLoadOpCount(memref) != 1)
        return false;
      // Skip if there exists a path of dependent edges between
      // 'sibNode' and 'dstNode'.
      if (mdg->hasDependencePath(sibNode->id, dstNode->id) ||
          mdg->hasDependencePath(dstNode->id, sibNode->id))
        return false;
      // Skip sib node if it loads to (and stores from) the same memref on
      // which it also has an input dependence edge.
      DenseSet<Value> loadAndStoreMemrefSet;
      sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet);
      if (llvm::any_of(loadAndStoreMemrefSet, [=](Value memref) {
            return mdg->getIncomingMemRefAccesses(sibNode->id, memref) > 0;
          }))
        return false;

      // Check that all stores are to the same memref if any.
      DenseSet<Value> storeMemrefs;
      for (auto *storeOpInst : sibNode->stores) {
        storeMemrefs.insert(
            cast<AffineWriteOpInterface>(storeOpInst).getMemRef());
      }
      if (storeMemrefs.size() > 1)
        return false;

      // Skip if a memref value in one node is used by a non-affine memref
      // access that lies between 'dstNode' and 'sibNode'.
      if (hasNonAffineUsersOnThePath(dstNode->id, sibNode->id, mdg) ||
          hasNonAffineUsersOnThePath(sibNode->id, dstNode->id, mdg))
        return false;
      return true;
    };

    // Search for siblings which load the same memref block argument.
    Block *block = dstNode->op->getBlock();
    for (unsigned i = 0, e = block->getNumArguments(); i != e; ++i) {
      for (Operation *user : block->getArgument(i).getUsers()) {
        auto loadOp = dyn_cast<AffineReadOpInterface>(user);
        if (!loadOp)
          continue;
        // Gather loops surrounding 'use'.
        SmallVector<AffineForOp, 4> loops;
        getAffineForIVs(*user, &loops);
        // Skip 'use' if it is not within a loop nest.
        if (loops.empty())
          continue;
        // The candidate sibling is the outermost loop enclosing the load.
        Node *sibNode = mdg->getForOpNode(loops[0]);
        assert(sibNode != nullptr);
        // Skip 'use' if it not a sibling to 'dstNode'.
        if (sibNode->id == dstNode->id)
          continue;
        // Skip 'use' if it has been visited.
        if (visitedSibNodeIds->count(sibNode->id) > 0)
          continue;
        // Skip 'use' if it does not load from the same memref as 'dstNode'.
        auto memref = loadOp.getMemRef();
        if (dstNode->getLoadOpCount(memref) == 0)
          continue;
        // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
        if (canFuseWithSibNode(sibNode, memref)) {
          visitedSibNodeIds->insert(sibNode->id);
          idAndMemrefToFuse->first = sibNode->id;
          idAndMemrefToFuse->second = memref;
          return true;
        }
      }
    }

    // Search for siblings by following edges through an intermediate src node.
    // Collect candidate 'dstNode' input edges in 'inEdges'.
    SmallVector<MemRefDependenceGraph::Edge, 2> inEdges;
    mdg->forEachMemRefInputEdge(
        dstNode->id, [&](MemRefDependenceGraph::Edge inEdge) {
          // Add 'inEdge' if it is a read-after-write dependence.
          if (dstNode->getLoadOpCount(inEdge.value) > 0 &&
              mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0)
            inEdges.push_back(inEdge);
        });

    // Search for sibling nodes to fuse by visiting output edges from each input
    // edge in 'inEdges'.
    for (auto &inEdge : inEdges) {
      // Collect candidate output edges from each node 'inEdge.id' in 'inEdges'.
      SmallVector<MemRefDependenceGraph::Edge, 2> outEdges;
      mdg->forEachMemRefOutputEdge(
          inEdge.id, [&](MemRefDependenceGraph::Edge outEdge) {
            unsigned sibNodeId = outEdge.id;
            if (visitedSibNodeIds->count(sibNodeId) > 0)
              return;
            // Skip output edge if not a sibling using the same memref.
            if (outEdge.id == dstNode->id || outEdge.value != inEdge.value)
              return;
            auto *sibNode = mdg->getNode(sibNodeId);
            if (!isa<AffineForOp>(sibNode->op))
              return;
            // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
            if (canFuseWithSibNode(sibNode, outEdge.value)) {
              // Add candidate 'outEdge' to sibling node.
              outEdges.push_back(outEdge);
            }
          });

      // Add first candidate if any were returned.
      if (!outEdges.empty()) {
        visitedSibNodeIds->insert(outEdges[0].id);
        idAndMemrefToFuse->first = outEdges[0].id;
        idAndMemrefToFuse->second = outEdges[0].value;
        return true;
      }
    }
    return false;
  }

  /// Update data dependence graph state to reflect sibling fusion of 'sibNode'
  /// into 'dstNode'.
  void updateStateAfterSiblingFusion(Node *sibNode, Node *dstNode) {
    // Update 'sibNode' and 'dstNode' input/output edges to reflect fusion.
    mdg->updateEdges(sibNode->id, dstNode->id);

    // Collect dst loop stats after memref privatization transformation.
    auto dstForInst = cast<AffineForOp>(dstNode->op);
    LoopNestStateCollector dstLoopCollector;
    dstLoopCollector.collect(dstForInst);
    // Clear and add back loads and stores
    mdg->clearNodeLoadAndStores(dstNode->id);
    mdg->addToNode(dstNode->id, dstLoopCollector.loadOpInsts,
                   dstLoopCollector.storeOpInsts);
    // Remove old sibling loop nest if it no longer has outgoing dependence
    // edges, and it does not write to a memref which escapes the block.
    if (mdg->getOutEdgeCount(sibNode->id) == 0) {
      Operation *op = sibNode->op;
      mdg->removeNode(sibNode->id);
      op->erase();
    }
  }

  // Clean up any allocs with no users.
  void eraseUnusedMemRefAllocations() {
    for (auto &pair : mdg->memrefEdgeCount) {
      if (pair.second > 0)
        continue;
      auto memref = pair.first;
      // Skip if there exist other uses (return operation or function calls).
      if (!memref.use_empty())
        continue;
      // Use list expected to match the dep graph info.
      auto *op = memref.getDefiningOp();
      if (isa_and_nonnull<memref::AllocOp>(op))
        op->erase();
    }
  }
};

} // namespace

/// Run fusion on `block`.
void LoopFusion::runOnBlock(Block *block) {
  MemRefDependenceGraph g(*block);
  if (!g.init()) {
    LLVM_DEBUG(llvm::dbgs() << "MDG init failed\n");
    return;
  }

  std::optional<unsigned> fastMemorySpaceOpt;
  if (fastMemorySpace.hasValue())
    fastMemorySpaceOpt = fastMemorySpace;
  // The pass option stores the threshold in KiB (see the LoopFusion
  // constructor, which divides by 1024); GreedyFusion expects bytes.
  unsigned localBufSizeThresholdBytes = localBufSizeThreshold * 1024;
  GreedyFusion fusion(&g, localBufSizeThresholdBytes, fastMemorySpaceOpt,
                      maximalFusion, computeToleranceThreshold);

  if (affineFusionMode == FusionMode::ProducerConsumer)
    fusion.runProducerConsumerFusionOnly();
  else if (affineFusionMode == FusionMode::Sibling)
    fusion.runSiblingFusionOnly();
  else
    fusion.runGreedyFusion();
}

void LoopFusion::runOnOperation() {
  // Fusion is run independently on every block of the operation; loop nests
  // in different blocks are never fused together.
  for (Region &region : getOperation()->getRegions())
    for (Block &block : region.getBlocks())
      runOnBlock(&block);
}

/// Public factory for the affine loop fusion pass; forwards all tuning
/// parameters to the LoopFusion pass constructor.
std::unique_ptr<Pass>
mlir::createLoopFusionPass(unsigned fastMemorySpace,
                           uint64_t localBufSizeThreshold, bool maximalFusion,
                           enum FusionMode affineFusionMode) {
  return std::make_unique<LoopFusion>(fastMemorySpace, localBufSizeThreshold,
                                      maximalFusion, affineFusionMode);
}