//===- LoopFusion.cpp - Code to perform loop fusion -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements affine fusion.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/Passes.h"

#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopFusionUtils.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <iomanip>
#include <optional>
#include <sstream>

namespace mlir {
#define GEN_PASS_DEF_AFFINELOOPFUSION
#include "mlir/Dialect/Affine/Passes.h.inc"
} // namespace mlir

#define DEBUG_TYPE "affine-loop-fusion"

using namespace mlir;

namespace {
/// Loop fusion pass. This pass currently supports a greedy fusion policy,
/// which fuses loop nests with single-writer/single-reader memref dependences
/// with the goal of improving locality.

// TODO: Support fusion of source loop nests which write to multiple
// memrefs, where each memref can have multiple users (if profitable).
// TODO: Extend this pass to check for fusion preventing dependences,
// and add support for more general loop fusion algorithms.

struct LoopFusion : public impl::AffineLoopFusionBase<LoopFusion> {
  LoopFusion() = default;
  LoopFusion(unsigned fastMemorySpace, uint64_t localBufSizeThresholdBytes,
             bool maximalFusion, enum FusionMode affineFusionMode) {
    this->fastMemorySpace = fastMemorySpace;
    // The pass option is expressed in KiB, while the constructor argument is
    // in bytes.
    this->localBufSizeThreshold = localBufSizeThresholdBytes / 1024;
    this->maximalFusion = maximalFusion;
    this->affineFusionMode = affineFusionMode;
  }

  void runOnBlock(Block *block);
  void runOnOperation() override;
};

} // namespace

std::unique_ptr<Pass>
mlir::createLoopFusionPass(unsigned fastMemorySpace,
                           uint64_t localBufSizeThreshold, bool maximalFusion,
                           enum FusionMode affineFusionMode) {
  return std::make_unique<LoopFusion>(fastMemorySpace, localBufSizeThreshold,
                                      maximalFusion, affineFusionMode);
}

namespace {

// LoopNestStateCollector walks loop nests and collects load and store
// operations, and records whether a region-holding op other than AffineForOp
// and AffineIfOp was encountered in the loop nest.
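// For example, an scf.for or any other region-holding op encountered inside
// the walked nest sets 'hasNonAffineRegionOp', whereas nested affine.for and
// affine.if ops are treated as affine constructs (the former are collected
// into 'forOps').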
struct LoopNestStateCollector {
  SmallVector<AffineForOp, 4> forOps;
  SmallVector<Operation *, 4> loadOpInsts;
  SmallVector<Operation *, 4> storeOpInsts;
  bool hasNonAffineRegionOp = false;

  void collect(Operation *opToWalk) {
    opToWalk->walk([&](Operation *op) {
      if (isa<AffineForOp>(op))
        forOps.push_back(cast<AffineForOp>(op));
      else if (op->getNumRegions() != 0 && !isa<AffineIfOp>(op))
        hasNonAffineRegionOp = true;
      else if (isa<AffineReadOpInterface>(op))
        loadOpInsts.push_back(op);
      else if (isa<AffineWriteOpInterface>(op))
        storeOpInsts.push_back(op);
    });
  }
};

// MemRefDependenceGraph is a graph data structure where graph nodes are
// top-level operations in a `Block` which contain load/store ops, and edges
// are memref dependences between the nodes.
// TODO: Add a more flexible dependence graph representation.
// TODO: Add a depth parameter to dependence graph construction.
struct MemRefDependenceGraph {
public:
  // Node represents a node in the graph. A Node is either an entire loop nest
  // rooted at the top level which contains loads/stores, or a top-level
  // load/store.
  struct Node {
    // The unique identifier of this node in the graph.
    unsigned id;
    // The top-level statement which is (or contains) a load/store.
    Operation *op;
    // List of load operations.
    SmallVector<Operation *, 4> loads;
    // List of store operations.
    SmallVector<Operation *, 4> stores;
    Node(unsigned id, Operation *op) : id(id), op(op) {}

    // Returns the load op count for 'memref'.
    unsigned getLoadOpCount(Value memref) const {
      unsigned loadOpCount = 0;
      for (Operation *loadOp : loads) {
        if (memref == cast<AffineReadOpInterface>(loadOp).getMemRef())
          ++loadOpCount;
      }
      return loadOpCount;
    }

    // Returns the store op count for 'memref'.
    unsigned getStoreOpCount(Value memref) const {
      unsigned storeOpCount = 0;
      for (Operation *storeOp : stores) {
        if (memref == cast<AffineWriteOpInterface>(storeOp).getMemRef())
          ++storeOpCount;
      }
      return storeOpCount;
    }

    // Returns in 'storeOps' all store ops in this node which access 'memref'.
    void getStoreOpsForMemref(Value memref,
                              SmallVectorImpl<Operation *> *storeOps) const {
      for (Operation *storeOp : stores) {
        if (memref == cast<AffineWriteOpInterface>(storeOp).getMemRef())
          storeOps->push_back(storeOp);
      }
    }

    // Returns in 'loadOps' all load ops in this node which access 'memref'.
    void getLoadOpsForMemref(Value memref,
                             SmallVectorImpl<Operation *> *loadOps) const {
      for (Operation *loadOp : loads) {
        if (memref == cast<AffineReadOpInterface>(loadOp).getMemRef())
          loadOps->push_back(loadOp);
      }
    }

    // Returns in 'loadAndStoreMemrefSet' all memrefs for which this node
    // has at least one load and one store operation.
    void
    getLoadAndStoreMemrefSet(DenseSet<Value> *loadAndStoreMemrefSet) const {
      llvm::SmallDenseSet<Value, 2> loadMemrefs;
      for (Operation *loadOp : loads) {
        loadMemrefs.insert(cast<AffineReadOpInterface>(loadOp).getMemRef());
      }
      for (Operation *storeOp : stores) {
        auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
        if (loadMemrefs.count(memref) > 0)
          loadAndStoreMemrefSet->insert(memref);
      }
    }
  };

  // Edge represents a data dependence between nodes in the graph.
  struct Edge {
    // The id of the node at the other end of the edge.
    // If this edge is stored in Edge = Node.inEdges[i], then
    // 'Node.inEdges[i].id' is the identifier of the source node of the edge.
    // If this edge is stored in Edge = Node.outEdges[i], then
    // 'Node.outEdges[i].id' is the identifier of the dest node of the edge.
    unsigned id;
    // The SSA value on which this edge represents a dependence.
    // If the value is a memref, then the dependence is between graph nodes
    // which contain accesses to the same memref 'value'. If the value is a
    // non-memref value, then the dependence is between a graph node which
    // defines an SSA value and another graph node which uses the SSA value
    // (e.g. a constant or load operation defining a value which is used inside
    // a loop nest).
    Value value;
  };

  // Map from node id to Node.
  DenseMap<unsigned, Node> nodes;
  // Map from node id to list of input edges.
  DenseMap<unsigned, SmallVector<Edge, 2>> inEdges;
  // Map from node id to list of output edges.
  DenseMap<unsigned, SmallVector<Edge, 2>> outEdges;
  // Map from memref to a count of the dependence edges associated with that
  // memref.
  DenseMap<Value, unsigned> memrefEdgeCount;
  // The next unique identifier to use for newly created graph nodes.
  unsigned nextNodeId = 0;

  MemRefDependenceGraph(Block &block) : block(block) {}

  // Initializes the dependence graph based on operations in `block`.
  // Returns true on success, false otherwise.
  bool init();

  // Returns the graph node for 'id'.
  Node *getNode(unsigned id) {
    auto it = nodes.find(id);
    assert(it != nodes.end());
    return &it->second;
  }

  // Returns the graph node for 'forOp'.
  Node *getForOpNode(AffineForOp forOp) {
    for (auto &idAndNode : nodes)
      if (idAndNode.second.op == forOp)
        return &idAndNode.second;
    return nullptr;
  }

  // Adds a node with 'op' to the graph and returns its unique identifier.
  unsigned addNode(Operation *op) {
    Node node(nextNodeId++, op);
    nodes.insert({node.id, node});
    return node.id;
  }

  // Removes node 'id' (and its associated edges) from the graph.
  void removeNode(unsigned id) {
    // Remove each edge in 'inEdges[id]'.
    if (inEdges.count(id) > 0) {
      SmallVector<Edge, 2> oldInEdges = inEdges[id];
      for (auto &inEdge : oldInEdges) {
        removeEdge(inEdge.id, id, inEdge.value);
      }
    }
    // Remove each edge in 'outEdges[id]'.
    if (outEdges.count(id) > 0) {
      SmallVector<Edge, 2> oldOutEdges = outEdges[id];
      for (auto &outEdge : oldOutEdges) {
        removeEdge(id, outEdge.id, outEdge.value);
      }
    }
    // Erase remaining node state.
    inEdges.erase(id);
    outEdges.erase(id);
    nodes.erase(id);
  }

  // Returns true if node 'id' writes to any memref which escapes (or is an
  // argument to) the block. Returns false otherwise.
  bool writesToLiveInOrEscapingMemrefs(unsigned id) {
    Node *node = getNode(id);
    for (auto *storeOpInst : node->stores) {
      auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
      auto *op = memref.getDefiningOp();
      // Return true if 'memref' is a block argument.
      if (!op)
        return true;
      // Return true if any use of 'memref' does not dereference it in an
      // affine way.
      for (auto *user : memref.getUsers())
        if (!isa<AffineMapAccessInterface>(*user))
          return true;
    }
    return false;
  }

  // Returns true iff there is an edge from node 'srcId' to node 'dstId' which
  // is for 'value' if non-null, or for any value otherwise. Returns false
  // otherwise.
  bool hasEdge(unsigned srcId, unsigned dstId, Value value = nullptr) {
    if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
      return false;
    }
    bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) {
      return edge.id == dstId && (!value || edge.value == value);
    });
    bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) {
      return edge.id == srcId && (!value || edge.value == value);
    });
    return hasOutEdge && hasInEdge;
  }

  // Adds an edge from node 'srcId' to node 'dstId' for 'value'.
  void addEdge(unsigned srcId, unsigned dstId, Value value) {
    if (!hasEdge(srcId, dstId, value)) {
      outEdges[srcId].push_back({dstId, value});
      inEdges[dstId].push_back({srcId, value});
      if (value.getType().isa<MemRefType>())
        memrefEdgeCount[value]++;
    }
  }

  // Removes an edge from node 'srcId' to node 'dstId' for 'value'.
  void removeEdge(unsigned srcId, unsigned dstId, Value value) {
    assert(inEdges.count(dstId) > 0);
    assert(outEdges.count(srcId) > 0);
    if (value.getType().isa<MemRefType>()) {
      assert(memrefEdgeCount.count(value) > 0);
      memrefEdgeCount[value]--;
    }
    // Remove 'srcId' from 'inEdges[dstId]'.
    for (auto *it = inEdges[dstId].begin(); it != inEdges[dstId].end(); ++it) {
      if ((*it).id == srcId && (*it).value == value) {
        inEdges[dstId].erase(it);
        break;
      }
    }
    // Remove 'dstId' from 'outEdges[srcId]'.
    for (auto *it = outEdges[srcId].begin(); it != outEdges[srcId].end();
         ++it) {
      if ((*it).id == dstId && (*it).value == value) {
        outEdges[srcId].erase(it);
        break;
      }
    }
  }

  // Returns true if there is a path in the dependence graph from node 'srcId'
  // to node 'dstId'. Returns false otherwise. 'srcId', 'dstId', and the
  // operations that the edges connect are expected to be in the same block.
  bool hasDependencePath(unsigned srcId, unsigned dstId) {
    // Worklist state is: <node-id, next-output-edge-index-to-visit>
    SmallVector<std::pair<unsigned, unsigned>, 4> worklist;
    worklist.push_back({srcId, 0});
    Operation *dstOp = getNode(dstId)->op;
    // Run DFS traversal to see if 'dstId' is reachable from 'srcId'.
    while (!worklist.empty()) {
      auto &idAndIndex = worklist.back();
      // Return true if we have reached 'dstId'.
      if (idAndIndex.first == dstId)
        return true;
      // Pop and continue if the node has no out edges, or if all out edges
      // have already been visited.
      if (outEdges.count(idAndIndex.first) == 0 ||
          idAndIndex.second == outEdges[idAndIndex.first].size()) {
        worklist.pop_back();
        continue;
      }
      // Get graph edge to traverse.
      Edge edge = outEdges[idAndIndex.first][idAndIndex.second];
      // Increment next output edge index for 'idAndIndex'.
      ++idAndIndex.second;
      // Add node at 'edge.id' to the worklist. We don't need to consider
      // nodes that are "after" dstId in the containing block; one can't have a
      // path to `dstId` from any of those nodes.
      bool afterDst = dstOp->isBeforeInBlock(getNode(edge.id)->op);
      if (!afterDst && edge.id != idAndIndex.first)
        worklist.push_back({edge.id, 0});
    }
    return false;
  }

  // Returns the input edge count for node 'id' and 'memref' from src nodes
  // which access 'memref' with a store operation.
  unsigned getIncomingMemRefAccesses(unsigned id, Value memref) {
    unsigned inEdgeCount = 0;
    if (inEdges.count(id) > 0)
      for (auto &inEdge : inEdges[id])
        if (inEdge.value == memref) {
          Node *srcNode = getNode(inEdge.id);
          // Only count in-edges from 'srcNode' if 'srcNode' stores to
          // 'memref'.
          if (srcNode->getStoreOpCount(memref) > 0)
            ++inEdgeCount;
        }
    return inEdgeCount;
  }

  // Returns the output edge count for node 'id' and 'memref' (if non-null),
  // otherwise returns the total output edge count from node 'id'.
  unsigned getOutEdgeCount(unsigned id, Value memref = nullptr) {
    unsigned outEdgeCount = 0;
    if (outEdges.count(id) > 0)
      for (auto &outEdge : outEdges[id])
        if (!memref || outEdge.value == memref)
          ++outEdgeCount;
    return outEdgeCount;
  }

  /// Returns in 'definingNodes' all nodes which define SSA values used in
  /// node 'id'.
  void gatherDefiningNodes(unsigned id, DenseSet<unsigned> &definingNodes) {
    for (MemRefDependenceGraph::Edge edge : inEdges[id])
      // By definition of edge, if the edge value is a non-memref value,
      // then the dependence is between a graph node which defines an SSA value
      // and another graph node which uses the SSA value.
      if (!edge.value.getType().isa<MemRefType>())
        definingNodes.insert(edge.id);
  }

  // Computes and returns an insertion point operation, before which the
  // fused <srcId, dstId> loop nest can be inserted while preserving
  // dependences. Returns nullptr if no such insertion point is found.
  Operation *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId) {
    if (outEdges.count(srcId) == 0)
      return getNode(dstId)->op;

    // Skip if there is any defining node of 'dstId' that depends on 'srcId'.
    DenseSet<unsigned> definingNodes;
    gatherDefiningNodes(dstId, definingNodes);
    if (llvm::any_of(definingNodes, [&](unsigned id) {
          return hasDependencePath(srcId, id);
        })) {
      LLVM_DEBUG(llvm::dbgs()
                 << "Can't fuse: a defining op with a user in the dst "
                    "loop has dependence from the src loop\n");
      return nullptr;
    }

    // Build set of insts in range (srcId, dstId) which depend on 'srcId'.
    SmallPtrSet<Operation *, 2> srcDepInsts;
    for (auto &outEdge : outEdges[srcId])
      if (outEdge.id != dstId)
        srcDepInsts.insert(getNode(outEdge.id)->op);

    // Build set of insts in range (srcId, dstId) on which 'dstId' depends.
    SmallPtrSet<Operation *, 2> dstDepInsts;
    for (auto &inEdge : inEdges[dstId])
      if (inEdge.id != srcId)
        dstDepInsts.insert(getNode(inEdge.id)->op);

    Operation *srcNodeInst = getNode(srcId)->op;
    Operation *dstNodeInst = getNode(dstId)->op;

    // Computing insertion point:
    // *) Walk all operation positions in the block's operation list in the
    //    range (src, dst). For each operation 'op' visited in this search:
    //   *) Store in 'firstSrcDepPos' the first position where 'op' has a
    //      dependence edge from 'srcNode'.
    //   *) Store in 'lastDstDepPos' the last position where 'op' has a
    //      dependence edge to 'dstNode'.
    // *) Compare 'firstSrcDepPos' and 'lastDstDepPos' to determine the
    //    operation insertion point (or return nullptr if no such insertion
    //    point exists: 'firstSrcDepPos' <= 'lastDstDepPos').
    SmallVector<Operation *, 2> depInsts;
    std::optional<unsigned> firstSrcDepPos;
    std::optional<unsigned> lastDstDepPos;
    unsigned pos = 0;
    for (Block::iterator it = std::next(Block::iterator(srcNodeInst));
         it != Block::iterator(dstNodeInst); ++it) {
      Operation *op = &(*it);
      if (srcDepInsts.count(op) > 0 && firstSrcDepPos == std::nullopt)
        firstSrcDepPos = pos;
      if (dstDepInsts.count(op) > 0)
        lastDstDepPos = pos;
      depInsts.push_back(op);
      ++pos;
    }

    if (firstSrcDepPos.has_value()) {
      if (lastDstDepPos.has_value()) {
        if (*firstSrcDepPos <= *lastDstDepPos) {
          // No valid insertion point exists which preserves dependences.
          return nullptr;
        }
      }
      // Return the insertion point at 'firstSrcDepPos'.
      return depInsts[*firstSrcDepPos];
    }
    // No dependence targets in range (or only dst deps in range), return
    // 'dstNodeInst' insertion point.
    return dstNodeInst;
  }

  // Updates edge mappings from node 'srcId' to node 'dstId' after fusing them,
  // taking into account that:
  // *) if 'removeSrcId' is true, 'srcId' will be removed after fusion,
  // *) memrefs in 'privateMemRefs' have been replaced in the node at 'dstId'
  //    by a private memref.
  void updateEdges(unsigned srcId, unsigned dstId,
                   const DenseSet<Value> &privateMemRefs, bool removeSrcId) {
    // For each edge in 'inEdges[srcId]': add new edge remapping to 'dstId'.
    if (inEdges.count(srcId) > 0) {
      SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
      for (auto &inEdge : oldInEdges) {
        // Add edge from 'inEdge.id' to 'dstId' if it's not a private memref.
        if (privateMemRefs.count(inEdge.value) == 0)
          addEdge(inEdge.id, dstId, inEdge.value);
      }
    }
    // For each edge in 'outEdges[srcId]': remove edge from 'srcId' to 'dstId'.
    // If 'srcId' is going to be removed, remap all the out edges to 'dstId'.
    if (outEdges.count(srcId) > 0) {
      SmallVector<Edge, 2> oldOutEdges = outEdges[srcId];
      for (auto &outEdge : oldOutEdges) {
        // Remove any out edges from 'srcId' to 'dstId' across memrefs.
        if (outEdge.id == dstId)
          removeEdge(srcId, outEdge.id, outEdge.value);
        else if (removeSrcId) {
          addEdge(dstId, outEdge.id, outEdge.value);
          removeEdge(srcId, outEdge.id, outEdge.value);
        }
      }
    }
    // Remove any edges in 'inEdges[dstId]' on a memref that is being replaced
    // by a private memref. These edges could come from nodes other than
    // 'srcId' which were removed in the previous step.
    if (inEdges.count(dstId) > 0 && !privateMemRefs.empty()) {
      SmallVector<Edge, 2> oldInEdges = inEdges[dstId];
      for (auto &inEdge : oldInEdges)
        if (privateMemRefs.count(inEdge.value) > 0)
          removeEdge(inEdge.id, dstId, inEdge.value);
    }
  }

  // Updates edge mappings for nodes 'sibId' and 'dstId' to reflect fusion
  // of sibling node 'sibId' into node 'dstId'.
  void updateEdges(unsigned sibId, unsigned dstId) {
    // For each edge in 'inEdges[sibId]':
    // *) Add new edge from source node 'inEdge.id' to 'dstNode'.
    // *) Remove edge from source node 'inEdge.id' to 'sibNode'.
    if (inEdges.count(sibId) > 0) {
      SmallVector<Edge, 2> oldInEdges = inEdges[sibId];
      for (auto &inEdge : oldInEdges) {
        addEdge(inEdge.id, dstId, inEdge.value);
        removeEdge(inEdge.id, sibId, inEdge.value);
      }
    }

    // For each edge in 'outEdges[sibId]':
    // *) Add new edge from 'dstId' to 'outEdge.id'.
    // *) Remove edge from 'sibId' to 'outEdge.id'.
    if (outEdges.count(sibId) > 0) {
      SmallVector<Edge, 2> oldOutEdges = outEdges[sibId];
      for (auto &outEdge : oldOutEdges) {
        addEdge(dstId, outEdge.id, outEdge.value);
        removeEdge(sibId, outEdge.id, outEdge.value);
      }
    }
  }

  // Adds ops in 'loads' and 'stores' to the node at 'id'.
  void addToNode(unsigned id, const SmallVectorImpl<Operation *> &loads,
                 const SmallVectorImpl<Operation *> &stores) {
    Node *node = getNode(id);
    llvm::append_range(node->loads, loads);
    llvm::append_range(node->stores, stores);
  }

  void clearNodeLoadAndStores(unsigned id) {
    Node *node = getNode(id);
    node->loads.clear();
    node->stores.clear();
  }

  // Calls 'callback' for each input edge incident to node 'id' which carries a
  // memref dependence.
  void forEachMemRefInputEdge(unsigned id,
                              const std::function<void(Edge)> &callback) {
    if (inEdges.count(id) > 0)
      forEachMemRefEdge(inEdges[id], callback);
  }

  // Calls 'callback' for each output edge from node 'id' which carries a
  // memref dependence.
  void forEachMemRefOutputEdge(unsigned id,
                               const std::function<void(Edge)> &callback) {
    if (outEdges.count(id) > 0)
      forEachMemRefEdge(outEdges[id], callback);
  }

  // Calls 'callback' for each edge in 'edges' which carries a memref
  // dependence.
  void forEachMemRefEdge(ArrayRef<Edge> edges,
                         const std::function<void(Edge)> &callback) {
    for (const auto &edge : edges) {
      // Skip if 'edge' is not a memref dependence edge.
      if (!edge.value.getType().isa<MemRefType>())
        continue;
      assert(nodes.count(edge.id) > 0);
      // Skip if 'edge.id' is not a loop nest.
      if (!isa<AffineForOp>(getNode(edge.id)->op))
        continue;
      // Visit current input edge 'edge'.
      callback(edge);
    }
  }

  void print(raw_ostream &os) const {
    os << "\nMemRefDependenceGraph\n";
    os << "\nNodes:\n";
    for (const auto &idAndNode : nodes) {
      os << "Node: " << idAndNode.first << "\n";
      auto it = inEdges.find(idAndNode.first);
      if (it != inEdges.end()) {
        for (const auto &e : it->second)
          os << "  InEdge: " << e.id << " " << e.value << "\n";
      }
      it = outEdges.find(idAndNode.first);
      if (it != outEdges.end()) {
        for (const auto &e : it->second)
          os << "  OutEdge: " << e.id << " " << e.value << "\n";
      }
    }
  }
  void dump() const { print(llvm::errs()); }

  /// The block for which this graph is created to perform fusion.
  Block &block;
};

/// Returns true if node 'srcId' can be removed after fusing it with node
/// 'dstId'. The node can be removed if any of the following conditions are
/// met:
///   1. 'srcId' has no output dependences after fusion and no escaping
///      memrefs.
///   2. 'srcId' has no output dependences after fusion, has escaping memrefs
///      and the fusion slice is maximal.
///   3. 'srcId' has output dependences after fusion, the fusion slice is
///      maximal and the fusion insertion point dominates all the dependences.
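///
/// For example, if the source loop writes to a memref that escapes (say, a
/// function argument) and only a slice of its iteration space is fused into
/// the destination loop, some of the writes to the escaping memref would be
/// lost if the source loop were erased; only a maximal slice, which covers
/// the entire source iteration space, makes the source loop removable.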
static bool canRemoveSrcNodeAfterFusion(
    unsigned srcId, unsigned dstId, const ComputationSliceState &fusionSlice,
    Operation *fusedLoopInsPoint, const DenseSet<Value> &escapingMemRefs,
    MemRefDependenceGraph *mdg) {

  Operation *dstNodeOp = mdg->getNode(dstId)->op;
  bool hasOutDepsAfterFusion = false;

  for (auto &outEdge : mdg->outEdges[srcId]) {
    Operation *depNodeOp = mdg->getNode(outEdge.id)->op;
    // Skip dependence with dstOp since it will be removed after fusion.
    if (depNodeOp == dstNodeOp)
      continue;

    // Only fusion within the same block is supported. Use domination analysis
    // when needed.
    if (depNodeOp->getBlock() != dstNodeOp->getBlock())
      return false;

    // Check if the insertion point of the fused loop dominates the dependence.
    // Otherwise, the src loop can't be removed.
    if (fusedLoopInsPoint != depNodeOp &&
        !fusedLoopInsPoint->isBeforeInBlock(depNodeOp)) {
      LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: dst loop doesn't "
                                 "dominate dependence\n");
      return false;
    }

    hasOutDepsAfterFusion = true;
  }

  // If the src loop has dependences after fusion or it writes to a live-out or
  // escaping memref, we can only remove it if the fusion slice is maximal so
  // that all the dependences are preserved.
  if (hasOutDepsAfterFusion || !escapingMemRefs.empty()) {
    std::optional<bool> isMaximal = fusionSlice.isMaximal();
    if (!isMaximal) {
      LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: can't determine "
                                 "if fusion is maximal\n");
      return false;
    }

    if (!*isMaximal) {
      LLVM_DEBUG(llvm::dbgs()
                 << "Src loop can't be removed: fusion is not maximal\n");
      return false;
    }
  }

  return true;
}

/// Returns in 'srcIdCandidates' the producer fusion candidates for consumer
/// 'dstId'. Candidates are sorted by node id order. This order corresponds to
/// the program order when the 'mdg' is created. However, program order is not
/// guaranteed and must not be relied upon by the client. Program order won't
/// hold if the 'mdg' is reused from a previous fusion step or if the node
/// creation order changes in the future to support more advanced cases.
// TODO: Move this to a loop fusion utility once 'mdg' is also moved.
static void getProducerCandidates(unsigned dstId, MemRefDependenceGraph *mdg,
                                  SmallVectorImpl<unsigned> &srcIdCandidates) {
  // Skip if no input edges along which to fuse.
  if (mdg->inEdges.count(dstId) == 0)
    return;

  // Gather memrefs from loads in 'dstId'.
  auto *dstNode = mdg->getNode(dstId);
  DenseSet<Value> consumedMemrefs;
  for (Operation *load : dstNode->loads)
    consumedMemrefs.insert(cast<AffineReadOpInterface>(load).getMemRef());

  // Traverse 'dstId' incoming edges and gather the nodes that contain a store
  // to one of the consumed memrefs.
  for (auto &srcEdge : mdg->inEdges[dstId]) {
    auto *srcNode = mdg->getNode(srcEdge.id);
    // Skip if 'srcNode' is not a loop nest.
    if (!isa<AffineForOp>(srcNode->op))
      continue;

    if (any_of(srcNode->stores, [&](Operation *op) {
          auto storeOp = cast<AffineWriteOpInterface>(op);
          return consumedMemrefs.count(storeOp.getMemRef()) > 0;
        }))
      srcIdCandidates.push_back(srcNode->id);
  }

  llvm::sort(srcIdCandidates);
  srcIdCandidates.erase(
      std::unique(srcIdCandidates.begin(), srcIdCandidates.end()),
      srcIdCandidates.end());
}

/// Returns in 'producerConsumerMemrefs' the memrefs involved in a
/// producer-consumer dependence between 'srcId' and 'dstId'.
static void
gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId,
                              MemRefDependenceGraph *mdg,
                              DenseSet<Value> &producerConsumerMemrefs) {
  auto *dstNode = mdg->getNode(dstId);
  auto *srcNode = mdg->getNode(srcId);
  gatherProducerConsumerMemrefs(srcNode->stores, dstNode->loads,
                                producerConsumerMemrefs);
}

/// A memref escapes in the context of the fusion pass if any of the following
/// holds:
///   1. it (or its alias) is a block argument, or
///   2. it is created by an op not known to guarantee alias freedom, or
///   3. it (or its alias) is used by ops other than affine dereferencing ops
///      (e.g., by call ops, memref load/store ops, alias creating ops,
///      unknown ops, terminator ops, etc.); such ops do not dereference the
///      memref in an affine way.
static bool isEscapingMemref(Value memref, Block *block) {
  Operation *defOp = memref.getDefiningOp();
  // Check if 'memref' is a block argument.
  if (!defOp)
    return true;

  // Check if this is defined to be an alias of another memref.
  if (auto viewOp = dyn_cast<mlir::ViewLikeOpInterface>(defOp))
    if (isEscapingMemref(viewOp.getViewSource(), block))
      return true;

  // Any op other than an allocating op doesn't guarantee alias freedom.
  if (!hasSingleEffect<mlir::MemoryEffects::Allocate>(defOp, memref))
    return true;

  // Check if 'memref' is used by a non-dereferencing op (including unknown
  // ones, e.g., call ops, alias creating ops, etc.).
  return llvm::any_of(memref.getUsers(), [&](Operation *user) {
    // Ignore users outside of `block`.
    if (block->getParent()->findAncestorOpInRegion(*user)->getBlock() != block)
      return false;
    return !isa<AffineMapAccessInterface>(*user);
  });
}

/// Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id'
/// that escape the block or are accessed in a non-affine way.
void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg,
                           DenseSet<Value> &escapingMemRefs) {
  auto *node = mdg->getNode(id);
  for (Operation *storeOp : node->stores) {
    auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
    if (escapingMemRefs.count(memref))
      continue;
    if (isEscapingMemref(memref, &mdg->block))
      escapingMemRefs.insert(memref);
  }
}

} // namespace

// Initializes the data dependence graph by walking operations in `block`.
// Assigns each node in the graph a node id based on program order in `block`.
bool MemRefDependenceGraph::init() {
  LLVM_DEBUG(llvm::dbgs() << "--- Initializing MDG ---\n");
  // Map from a memref to the set of ids of the nodes that have ops accessing
  // the memref.
  DenseMap<Value, SetVector<unsigned>> memrefAccesses;

  DenseMap<Operation *, unsigned> forToNodeMap;
  for (Operation &op : block) {
    if (auto forOp = dyn_cast<AffineForOp>(op)) {
      // Create graph node 'id' to represent top-level 'forOp' and record
      // all loads and store accesses it contains.
      LoopNestStateCollector collector;
      collector.collect(&op);
      // Return false if a region holding op other than 'affine.for' and
      // 'affine.if' was found (not currently supported).
      if (collector.hasNonAffineRegionOp)
        return false;
      Node node(nextNodeId++, &op);
      for (auto *opInst : collector.loadOpInsts) {
        node.loads.push_back(opInst);
        auto memref = cast<AffineReadOpInterface>(opInst).getMemRef();
        memrefAccesses[memref].insert(node.id);
      }
      for (auto *opInst : collector.storeOpInsts) {
        node.stores.push_back(opInst);
        auto memref = cast<AffineWriteOpInterface>(opInst).getMemRef();
        memrefAccesses[memref].insert(node.id);
      }
      forToNodeMap[&op] = node.id;
      nodes.insert({node.id, node});
    } else if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) {
      // Create graph node for top-level load op.
      Node node(nextNodeId++, &op);
      node.loads.push_back(&op);
      auto memref = cast<AffineReadOpInterface>(op).getMemRef();
      memrefAccesses[memref].insert(node.id);
      nodes.insert({node.id, node});
    } else if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) {
      // Create graph node for top-level store op.
      Node node(nextNodeId++, &op);
      node.stores.push_back(&op);
      auto memref = cast<AffineWriteOpInterface>(op).getMemRef();
      memrefAccesses[memref].insert(node.id);
      nodes.insert({node.id, node});
    } else if (op.getNumRegions() != 0) {
      // Return false if another region is found (not currently supported).
      return false;
    } else if (op.getNumResults() > 0 && !op.use_empty()) {
      // Create graph node for top-level producer of SSA values, which
      // could be used by loop nest nodes.
      Node node(nextNodeId++, &op);
      nodes.insert({node.id, node});
    } else if (isa<CallOpInterface>(op)) {
      // Create graph node for top-level Call Op that takes any argument of
      // memref type. A Call Op that returns one or more memref type results
      // is already taken care of by the previous conditions.
      if (llvm::any_of(op.getOperandTypes(),
                       [&](Type t) { return t.isa<MemRefType>(); })) {
        Node node(nextNodeId++, &op);
        nodes.insert({node.id, node});
      }
    } else if (hasEffect<MemoryEffects::Write, MemoryEffects::Free>(&op)) {
      // Create graph node for top-level op, which could have a memory write
      // side effect.
      Node node(nextNodeId++, &op);
      nodes.insert({node.id, node});
    }
  }

  for (auto &idAndNode : nodes) {
    LLVM_DEBUG(llvm::dbgs() << "Create node " << idAndNode.first << " for:\n"
                            << *(idAndNode.second.op) << "\n");
    (void)idAndNode;
  }

  // Add dependence edges between nodes which produce SSA values and their
  // users. Load ops can be considered as the ones producing SSA values.
  for (auto &idAndNode : nodes) {
    const Node &node = idAndNode.second;
    // Stores don't define SSA values, skip them.
    if (!node.stores.empty())
      continue;
    Operation *opInst = node.op;
    for (Value value : opInst->getResults()) {
      for (Operation *user : value.getUsers()) {
        // Ignore users outside of the block.
        if (block.getParent()->findAncestorOpInRegion(*user)->getBlock() !=
            &block)
          continue;
        SmallVector<AffineForOp, 4> loops;
        getAffineForIVs(*user, &loops);
        if (loops.empty())
          continue;
        assert(forToNodeMap.count(loops[0]) > 0 && "missing mapping");
        unsigned userLoopNestId = forToNodeMap[loops[0]];
        addEdge(node.id, userLoopNestId, value);
      }
    }
  }

  // Walk memref access lists and add graph edges between dependent nodes.
  for (auto &memrefAndList : memrefAccesses) {
    unsigned n = memrefAndList.second.size();
    for (unsigned i = 0; i < n; ++i) {
      unsigned srcId = memrefAndList.second[i];
      bool srcHasStore =
          getNode(srcId)->getStoreOpCount(memrefAndList.first) > 0;
      for (unsigned j = i + 1; j < n; ++j) {
        unsigned dstId = memrefAndList.second[j];
        bool dstHasStore =
            getNode(dstId)->getStoreOpCount(memrefAndList.first) > 0;
        // Only add an edge if at least one of the two nodes stores to the
        // memref; load-load pairs carry no dependence.
        if (srcHasStore || dstHasStore)
          addEdge(srcId, dstId, memrefAndList.first);
      }
    }
  }
  return true;
}

// Sinks all sequential loops to the innermost levels (while preserving
// relative order among them) and moves all parallel loops to the
// outermost (while again preserving relative order among them).
// This can increase the loop depth at which we can fuse a slice, since we are
// pushing loop carried dependence to a greater depth in the loop nest.
static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) {
  assert(isa<AffineForOp>(node->op));
  AffineForOp newRootForOp = sinkSequentialLoops(cast<AffineForOp>(node->op));
  node->op = newRootForOp;
}

// Creates and returns a private (single-user) memref for the fused loop rooted
// at 'forOp', with (potentially reduced) memref size based on the
// MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
// TODO: consider refactoring the common code from generateDma and
// this one.
static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
                                 unsigned dstLoopDepth,
                                 std::optional<unsigned> fastMemorySpace,
                                 uint64_t localBufSizeThreshold) {
  Operation *forInst = forOp.getOperation();

  // Create builder to insert alloc op just before 'forOp'.
  OpBuilder b(forInst);
  // Builder to create constants at the top level.
  OpBuilder top(forInst->getParentRegion());
  // Create new memref type based on slice bounds.
  auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOpInst).getMemRef();
  auto oldMemRefType = oldMemRef.getType().cast<MemRefType>();
  unsigned rank = oldMemRefType.getRank();

  // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'.
  MemRefRegion region(srcStoreOpInst->getLoc());
  bool validRegion = succeeded(region.compute(srcStoreOpInst, dstLoopDepth));
  (void)validRegion;
  assert(validRegion && "unexpected memref region failure");
  SmallVector<int64_t, 4> newShape;
  std::vector<SmallVector<int64_t, 4>> lbs;
  SmallVector<int64_t, 8> lbDivisors;
  lbs.reserve(rank);
  // Query 'region' for 'newShape' and lower bounds of MemRefRegion accessed
  // by 'srcStoreOpInst' at depth 'dstLoopDepth'.
  std::optional<int64_t> numElements =
      region.getConstantBoundingSizeAndShape(&newShape, &lbs, &lbDivisors);
  assert(numElements && "non-constant number of elts in local buffer");

  const FlatAffineValueConstraints *cst = region.getConstraints();
  // 'outerIVs' holds the values that this memory region is symbolic/parametric
  // on; this would correspond to loop IVs surrounding the level at which the
  // slice is being materialized.
  SmallVector<Value, 8> outerIVs;
  cst->getValues(rank, cst->getNumVars(), &outerIVs);

  // Build 'rank' AffineExprs from MemRefRegion 'lbs'.
  SmallVector<AffineExpr, 4> offsets;
  offsets.reserve(rank);
  for (unsigned d = 0; d < rank; ++d) {
    assert(lbs[d].size() == cst->getNumCols() - rank && "incorrect bound size");

    AffineExpr offset = top.getAffineConstantExpr(0);
    for (unsigned j = 0, e = cst->getNumCols() - rank - 1; j < e; j++) {
      offset = offset + lbs[d][j] * top.getAffineDimExpr(j);
    }
    assert(lbDivisors[d] > 0);
    offset =
        (offset + lbs[d][cst->getNumCols() - 1 - rank]).floorDiv(lbDivisors[d]);
    offsets.push_back(offset);
  }

  // Create 'newMemRefType' using 'newShape' from MemRefRegion accessed
  // by 'srcStoreOpInst'.
  auto eltSize = getMemRefIntOrFloatEltSizeInBytes(oldMemRefType);
  assert(eltSize && "memrefs with sized elt types expected");
  uint64_t bufSize = *eltSize * *numElements;
  unsigned newMemSpace;
  if (bufSize <= localBufSizeThreshold && fastMemorySpace.has_value()) {
    newMemSpace = *fastMemorySpace;
  } else {
    newMemSpace = oldMemRefType.getMemorySpaceAsInt();
  }
  auto newMemRefType = MemRefType::get(newShape, oldMemRefType.getElementType(),
                                       {}, newMemSpace);

  // Create new private memref for fused loop 'forOp'. 'newShape' is always
  // a constant shape.
  // TODO: Create/move alloc ops for private memrefs closer to their
  // consumer loop nests to reduce their live range. Currently they are added
  // at the beginning of the block, because loop nests can be reordered
  // during the fusion pass.
  Value newMemRef = top.create<memref::AllocOp>(forOp.getLoc(), newMemRefType);

  // Build an AffineMap to remap access functions based on lower bound offsets.
  SmallVector<AffineExpr, 4> remapExprs;
  remapExprs.reserve(rank);
  for (unsigned i = 0; i < rank; i++) {
    auto dimExpr = b.getAffineDimExpr(outerIVs.size() + i);

    auto remapExpr =
        simplifyAffineExpr(dimExpr - offsets[i], outerIVs.size() + rank, 0);
    remapExprs.push_back(remapExpr);
  }

  auto indexRemap =
      AffineMap::get(outerIVs.size() + rank, 0, remapExprs, forOp.getContext());

  // Replace all users of 'oldMemRef' with 'newMemRef'.
  LogicalResult res =
      replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,
                               /*extraOperands=*/outerIVs,
                               /*symbolOperands=*/{},
                               /*domOpFilter=*/&*forOp.getBody()->begin());
  assert(succeeded(res) &&
         "replaceAllMemrefUsesWith should always succeed here");
  (void)res;
  return newMemRef;
}

/// Walking from node 'srcId' to node 'dstId' (exclusive of 'srcId' and
/// 'dstId'), returns true if there is any non-affine operation accessing
/// 'memref' on the path. Returns false otherwise.
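/// For example, a memref.load of 'memref' or a call op taking 'memref'
/// appearing between the two loop nests counts as such an access: neither
/// implements AffineMapAccessInterface, so the affine dependence analysis
/// cannot reason about it and fusing across it could be illegal.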
static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
                                       Value memref,
                                       MemRefDependenceGraph *mdg) {
  auto *srcNode = mdg->getNode(srcId);
  auto *dstNode = mdg->getNode(dstId);
  Value::user_range users = memref.getUsers();
  // For each MemRefDependenceGraph node that is between 'srcNode' and
  // 'dstNode' (exclusive of 'srcNode' and 'dstNode'), check whether any
  // non-affine operation in the node accesses the 'memref'.
  for (auto &idAndNode : mdg->nodes) {
    Operation *op = idAndNode.second.op;
    // Take care of operations between 'srcNode' and 'dstNode'.
    if (srcNode->op->isBeforeInBlock(op) && op->isBeforeInBlock(dstNode->op)) {
      // Walk inside the operation to find any use of the memref.
      // Interrupt the walk if found.
      auto walkResult = op->walk([&](Operation *user) {
        // Skip affine ops.
        if (isa<AffineMapAccessInterface>(*user))
          return WalkResult::advance();
        // Find a non-affine op that uses the memref.
        if (llvm::is_contained(users, user))
          return WalkResult::interrupt();
        return WalkResult::advance();
      });
      if (walkResult.wasInterrupted())
        return true;
    }
  }
  return false;
}

/// Checks whether any memref value in node 'srcId' has a non-affine user
/// positioned between node 'srcId' and node 'dstId' (exclusive of 'srcNode'
/// and 'dstNode').
static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
                                       MemRefDependenceGraph *mdg) {
  // Collect memref values in node 'srcId'.
  auto *srcNode = mdg->getNode(srcId);
  llvm::SmallDenseSet<Value, 2> memRefValues;
  srcNode->op->walk([&](Operation *op) {
    // Skip affine ops.
    if (isa<AffineForOp>(op))
      return WalkResult::advance();
    for (Value v : op->getOperands())
      // Collect memref values only.
      if (v.getType().isa<MemRefType>())
        memRefValues.insert(v);
    return WalkResult::advance();
  });
  // Look for users between node 'srcId' and node 'dstId'.
  return llvm::any_of(memRefValues, [&](Value memref) {
    return hasNonAffineUsersOnThePath(srcId, dstId, memref, mdg);
  });
}

// Checks the profitability of fusing a backwards slice of the loop nest
// surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'.
// The argument 'srcStoreOpInst' is used to calculate the storage reduction on
// the memref being produced and consumed, which is an input to the cost model.
// For producer-consumer fusion, 'srcStoreOpInst' will be the same as
// 'srcOpInst', as we are slicing w.r.t. that producer. For input-reuse
// fusion, 'srcOpInst' will be the src loop nest LoadOp which reads from the
// same memref as the dst loop nest load ops, and 'srcStoreOpInst' will be the
// unique store op in the src node, which will be used to check that the write
// region is the same after input-reuse fusion. Computation slices are provided
// in 'depthSliceUnions' for each legal fusion depth. The maximal depth at
// which fusion is legal is provided in 'maxLegalFusionDepth'. Returns true if
// it is profitable to fuse the candidate loop nests. Returns false otherwise.
// `dstLoopDepth` is set to the most profitable depth at which to materialize
// the source loop nest slice.
// The profitability model executes the following steps:
// *) Computes the backward computation slice at 'srcOpInst'. This
//    computation slice of the loop nest surrounding 'srcOpInst' is
//    represented by modified src loop bounds in 'sliceState', which are
//    functions of loop IVs in the loop nest surrounding 'srcOpInst'.
// *) Computes the cost of unfused src/dst loop nests (currently the cost of a
//    loop nest is the total number of dynamic operation instances in the loop
//    nest).
// *) Computes the cost of fusing a slice of the src loop nest into the dst
//    loop nest at various values of dst loop depth, attempting to fuse
//    the largest computation slice at the maximal dst loop depth (closest to
//    the load) to minimize reuse distance and potentially enable subsequent
//    load/store forwarding.
//    NOTE: 'dstLoopDepth' refers to the loop depth within the destination loop
//    nest, at which the src computation slice is inserted/fused.
//    NOTE: We attempt to maximize the dst loop depth, but there are cases
//    where a particular setting for 'dstLoopDepth' might fuse an unsliced
//    loop (within the src computation slice) at a depth which results in
//    excessive recomputation (see unit tests for examples).
// *) Compares the total cost of the unfused loop nests to the min cost fused
//    loop nest computed in the previous step, and returns true if the latter
//    is lower.
// TODO: Extend profitability analysis to support scenarios with multiple
// stores.
static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
                               AffineForOp dstForOp,
                               ArrayRef<ComputationSliceState> depthSliceUnions,
                               unsigned maxLegalFusionDepth,
                               unsigned *dstLoopDepth,
                               double computeToleranceThreshold) {
  LLVM_DEBUG({
    llvm::dbgs() << "Checking whether fusion is profitable between src op:\n";
    llvm::dbgs() << ' ' << *srcOpInst << " and destination loop:\n";
    llvm::dbgs() << dstForOp << "\n";
  });

  if (maxLegalFusionDepth == 0) {
    LLVM_DEBUG(llvm::dbgs() << "Can't fuse: maxLegalFusionDepth is 0\n");
    return false;
  }

  // Compute cost of sliced and unsliced src loop nest.
  SmallVector<AffineForOp, 4> srcLoopIVs;
  getAffineForIVs(*srcOpInst, &srcLoopIVs);

  // Walk src loop nest and collect stats.
  LoopNestStats srcLoopNestStats;
  if (!getLoopNestStats(srcLoopIVs[0], &srcLoopNestStats))
    return false;

  // Compute cost of dst loop nest.
  LoopNestStats dstLoopNestStats;
  if (!getLoopNestStats(dstForOp, &dstLoopNestStats))
    return false;

  // Search for the min cost value for 'dstLoopDepth'. At each value of
  // 'dstLoopDepth' from 'maxLegalFusionDepth' to '1', compute the computation
  // slice bounds between 'srcOpInst' and each op in 'dstOpInsts' (taking the
  // union of these bounds). Next, the union slice bounds are used to calculate
  // the cost of the slice and the cost of the slice inserted into the dst
  // loop nest at 'dstLoopDepth'.
  uint64_t minFusedLoopNestComputeCost = std::numeric_limits<uint64_t>::max();
  double maxStorageReduction = 0.0;
  std::optional<uint64_t> sliceMemEstimate;

  // The best loop depth at which to materialize the slice.
  std::optional<unsigned> bestDstLoopDepth;

  // Compute op instance count for the src loop nest without iteration slicing.
  uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], srcLoopNestStats);

  // Compute src loop nest write region size.
  MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc());
  if (failed(srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0))) {
    LLVM_DEBUG(llvm::dbgs()
               << "Unable to compute MemRefRegion for source operation\n");
    return false;
  }

  std::optional<int64_t> maybeSrcWriteRegionSizeBytes =
      srcWriteRegion.getRegionSize();
  if (!maybeSrcWriteRegionSizeBytes.has_value())
    return false;
  int64_t srcWriteRegionSizeBytes = *maybeSrcWriteRegionSizeBytes;

  // Compute op instance count for the dst loop nest.
  uint64_t dstLoopNestCost = getComputeCost(dstForOp, dstLoopNestStats);

  // Evaluate all depth choices for materializing the slice in the destination
  // loop nest.
  for (unsigned i = maxLegalFusionDepth; i >= 1; --i) {
    const ComputationSliceState &slice = depthSliceUnions[i - 1];
    // Skip slice union if it wasn't computed for this depth.
    if (slice.isEmpty())
      continue;

    int64_t fusedLoopNestComputeCost;
    if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstForOp,
                              dstLoopNestStats, slice,
                              &fusedLoopNestComputeCost)) {
      LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost\n");
      continue;
    }

    // The redundant computation the fused nest would perform, as a fraction
    // of the total unfused computation (0 means no recomputation).
    double additionalComputeFraction =
        fusedLoopNestComputeCost /
            (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
        1;

    // Determine what the slice write MemRefRegion would be, if the src loop
    // nest slice 'slice' were to be inserted into the dst loop nest at loop
    // depth 'i'.
    MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc());
    if (failed(sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0,
                                        &slice))) {
      LLVM_DEBUG(llvm::dbgs()
                 << "Failed to compute slice write region at loopDepth: " << i
                 << "\n");
      continue;
    }

    std::optional<int64_t> maybeSliceWriteRegionSizeBytes =
        sliceWriteRegion.getRegionSize();
    if (!maybeSliceWriteRegionSizeBytes.has_value() ||
        *maybeSliceWriteRegionSizeBytes == 0) {
      LLVM_DEBUG(llvm::dbgs()
                 << "Failed to get slice write region size at loopDepth: " << i
                 << "\n");
      continue;
    }
    int64_t sliceWriteRegionSizeBytes = *maybeSliceWriteRegionSizeBytes;

    // If we are fusing for reuse, check that write regions remain the same.
    // TODO: The write region check should check sizes and offsets in
    // each dimension, so that we are sure they are covering the same memref
    // region. Also, move this out to an isMemRefRegionSuperSet helper function.
    if (srcOpInst != srcStoreOpInst &&
        sliceWriteRegionSizeBytes != srcWriteRegionSizeBytes)
      continue;

    double storageReduction = static_cast<double>(srcWriteRegionSizeBytes) /
                              static_cast<double>(sliceWriteRegionSizeBytes);

    LLVM_DEBUG({
      std::stringstream msg;
      msg << "  evaluating fusion profitability at depth : " << i << "\n"
          << std::fixed << std::setprecision(2)
          << "   additional compute fraction: "
          << 100.0 * additionalComputeFraction << "%\n"
          << "   storage reduction factor: " << storageReduction << "x\n"
          << "   fused nest cost: " << fusedLoopNestComputeCost << "\n"
          << "   src write region size: " << srcWriteRegionSizeBytes << "\n"
          << "   slice write region size: " << sliceWriteRegionSizeBytes
          << "\n";
      llvm::dbgs() << msg.str();
    });

    // TODO: This is a placeholder cost model.
    // Among all choices that add an acceptable amount of redundant computation
    // (as per computeToleranceThreshold), we will simply pick the one that
    // reduces the intermediary size the most.
    if ((storageReduction > maxStorageReduction) &&
        (additionalComputeFraction < computeToleranceThreshold)) {
      maxStorageReduction = storageReduction;
      bestDstLoopDepth = i;
      minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
      sliceMemEstimate = sliceWriteRegionSizeBytes;
    }
  }

  // A simple cost model: fuse if it reduces the memory footprint.

  if (!bestDstLoopDepth) {
    LLVM_DEBUG(
        llvm::dbgs()
        << "All fusion choices involve more than the threshold amount of "
           "redundant computation; NOT fusing.\n");
    return false;
  }

  // Set dstLoopDepth based on best values from search.
  *dstLoopDepth = *bestDstLoopDepth;

  LLVM_DEBUG(
      llvm::dbgs() << " LoopFusion fusion stats:"
                   << "\n  best loop depth: " << bestDstLoopDepth
                   << "\n  src loop nest compute cost: " << srcLoopNestCost
                   << "\n  dst loop nest compute cost: " << dstLoopNestCost
                   << "\n  fused loop nest compute cost: "
                   << minFusedLoopNestComputeCost << "\n");

  auto dstMemSize = getMemoryFootprintBytes(dstForOp);
  auto srcMemSize = getMemoryFootprintBytes(srcLoopIVs[0]);

  std::optional<double> storageReduction;

  if (!dstMemSize || !srcMemSize) {
    LLVM_DEBUG(llvm::dbgs()
               << "  fusion memory benefit cannot be evaluated; NOT fusing.\n");
    return false;
  }

  auto srcMemSizeVal = *srcMemSize;
  auto dstMemSizeVal = *dstMemSize;

  assert(sliceMemEstimate && "expected value");
  auto fusedMem = dstMemSizeVal + *sliceMemEstimate;

  LLVM_DEBUG(llvm::dbgs() << "  src mem: " << srcMemSizeVal << "\n"
                          << "  dst mem: " << dstMemSizeVal << "\n"
                          << "  fused mem: " << fusedMem << "\n"
                          << "  slice mem: " << sliceMemEstimate << "\n");

  if (static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) {
    LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n");
    return false;
  }
  storageReduction =
      100.0 *
      (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal));

  double additionalComputeFraction =
      100.0 * (minFusedLoopNestComputeCost /
                   (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
               1);
  (void)additionalComputeFraction;
  LLVM_DEBUG({
    std::stringstream msg;
    msg << " fusion is most profitable at depth " << *dstLoopDepth << " with "
        << std::setprecision(2) << additionalComputeFraction
        << "% redundant computation and a ";
    msg << (storageReduction ? std::to_string(*storageReduction) : "<unknown>");
    msg << "% storage reduction.\n";
    llvm::dbgs() << msg.str();
  });

  return true;
}

namespace {

// GreedyFusion greedily fuses loop nests which have a producer/consumer or
// input-reuse relationship on a memref, with the goal of improving locality.
//
// The steps of the producer-consumer fusion algorithm are as follows:
//
// *) A worklist is initialized with node ids from the dependence graph.
// *) For each node id in the worklist:
//   *) Pop an AffineForOp off the worklist. This 'dstAffineForOp' will be a
//      candidate destination AffineForOp into which fusion will be attempted.
//   *) Add each LoadOp currently in 'dstAffineForOp' into list 'dstLoadOps'.
//   *) For each LoadOp in 'dstLoadOps' do:
//      *) Look up dependent loop nests which have a single store op to the
//         same memref.
//      *) Check if dependences would be violated by the fusion.
//      *) Get a computation slice of 'srcLoopNest', which adjusts its loop
//         bounds to be functions of 'dstLoopNest' IVs and symbols.
//      *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest',
//         at a loop depth determined by the cost model in 'isFusionProfitable'.
//      *) Add the newly fused load/store operations to the state,
//         and also add newly fused load ops to 'dstLoadOps' to be considered
//         as fusion dst load ops in another iteration.
//      *) Remove old src loop nest and its associated state.
//
// The steps of the input-reuse fusion algorithm are as follows:
//
// *) Initialize 'worklist' with node ids from the dependence graph.
// *) For each 'dstNode' in the worklist:
//   *) Find a candidate sibling node 'sibNode' to fuse with 'dstNode' which
//      loads from the same memref, but which has no dependence paths to/from
//      'dstNode'.
//   *) Get a computation slice of 'sibLoopNest', which adjusts its loop
//      bounds to be functions of 'dstLoopNest' IVs and symbols.
//   *) Fuse the 'sibLoopNest' computation slice into the 'dstLoopNest',
//      at a loop depth determined by the cost model in 'isFusionProfitable'.
//      This function also checks that the memref write region of 'sibLoopNest'
//      is preserved in the fused loop nest.
//   *) Update graph state to reflect the fusion of 'sibNode' into 'dstNode'.
//
// Given a graph where top-level operations are vertices in the set 'V' and
// edges in the set 'E' are dependences between vertices, this algorithm
// takes O(V) time for initialization, and has runtime O(V + E).
//
// This greedy algorithm is not 'maximal' due to the current restriction of
// fusing along single producer-consumer edges, but there is a TODO to fix
// this.
//
// TODO: Experiment with other fusion policies.
struct GreedyFusion {
public:
  // The data dependence graph to traverse during fusion.
  MemRefDependenceGraph *mdg;
  // Worklist of graph nodes visited during the fusion pass.
  SmallVector<unsigned, 8> worklist;
  // Parameter for local buffer size threshold.
  unsigned localBufSizeThreshold;
  // Parameter for fast memory space.
  std::optional<unsigned> fastMemorySpace;
  // If true, ignore any additional (redundant) computation tolerance threshold
  // that would have prevented fusion.
  bool maximalFusion;
  // The amount of additional computation that is tolerated while fusing
  // pair-wise as a fraction of the total computation.
  double computeToleranceThreshold;

  using Node = MemRefDependenceGraph::Node;

  GreedyFusion(MemRefDependenceGraph *mdg, unsigned localBufSizeThreshold,
               std::optional<unsigned> fastMemorySpace, bool maximalFusion,
               double computeToleranceThreshold)
      : mdg(mdg), localBufSizeThreshold(localBufSizeThreshold),
        fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion),
        computeToleranceThreshold(computeToleranceThreshold) {}

  /// Initializes 'worklist' with nodes from 'mdg'.
  /// Initializes 'worklist' with nodes from 'mdg'.
  void init() {
    // TODO: Add a priority queue for prioritizing nodes by different
    // metrics (e.g. arithmetic intensity/flops-to-bytes ratio).
    worklist.clear();
    for (auto &idAndNode : mdg->nodes) {
      const Node &node = idAndNode.second;
      worklist.push_back(node.id);
    }
  }

  /// Run only sibling fusion on the `mdg`.
  void runSiblingFusionOnly() {
    fuseSiblingNodes();
    eraseUnusedMemRefAllocations();
  }

  /// Run only producer/consumer fusion on the `mdg`.
  void runProducerConsumerFusionOnly() {
    fuseProducerConsumerNodes(
        /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max());
    eraseUnusedMemRefAllocations();
  }

  // Run the GreedyFusion pass.
  // *) First pass through the nodes fuses single-use producer nodes into their
  //    unique consumer.
  // *) Second pass fuses sibling nodes which share no dependence edges.
  // *) Third pass fuses any remaining producer nodes into their users.
  void runGreedyFusion() {
    // TODO: Run this repeatedly until a fixed-point is reached.
    fuseProducerConsumerNodes(/*maxSrcUserCount=*/1);
    fuseSiblingNodes();
    fuseProducerConsumerNodes(
        /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max());
    eraseUnusedMemRefAllocations();
  }

  /// Returns true if a private memref can be created for `memref` given
  /// the fusion scenario reflected by the other arguments.
  bool canCreatePrivateMemRef(Value memref,
                              const DenseSet<Value> &srcEscapingMemRefs,
                              unsigned producerId, unsigned consumerId,
                              bool removeSrcNode) {
    const Node *consumerNode = mdg->getNode(consumerId);
    // If `memref` is an escaping one, do not create a private memref
    // for the below scenarios, since doing so will leave the escaping
    // memref unmodified as all the writes originally meant for the
    // escaping memref would be performed on the private memref:
    // 1. The source is to be removed after fusion,
    // OR
    // 2. The destination writes to `memref`.
    if (srcEscapingMemRefs.count(memref) > 0 &&
        (removeSrcNode || consumerNode->getStoreOpCount(memref) > 0))
      return false;

    // Don't create a private memref if 'srcNode' has in edges on
    // 'memref' or 'dstNode' has out edges on 'memref'.
    if (mdg->getIncomingMemRefAccesses(producerId, memref) > 0 ||
        mdg->getOutEdgeCount(consumerId, memref) > 0)
      return false;

    // If 'srcNode' will be removed but it has out edges on 'memref' to
    // nodes other than 'dstNode', we have to preserve dependences and
    // cannot create a private memref.
    if (removeSrcNode &&
        any_of(mdg->outEdges[producerId], [&](const auto &edge) {
          return edge.value == memref && edge.id != consumerId;
        }))
      return false;

    return true;
  }

  /// Perform fusions with node `dstId` as the destination of fusion. No
  /// fusion is performed when a producer's user count exceeds
  /// `maxSrcUserCount` for any of the memrefs involved.
  void performFusionsIntoDest(unsigned dstId, unsigned maxSrcUserCount) {
    LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n");
    // Skip if this node was removed (fused into another node).
    if (mdg->nodes.count(dstId) == 0)
      return;
    // Get 'dstNode' into which to attempt fusion.
    auto *dstNode = mdg->getNode(dstId);
    // Skip if 'dstNode' is not a loop nest.
    if (!isa<AffineForOp>(dstNode->op))
      return;
    // Skip if 'dstNode' is a loop nest returning values.
    // TODO: support loop nests that return values.
    if (dstNode->op->getNumResults() > 0)
      return;

    // Sink sequential loops in 'dstNode' (and thus raise parallel loops)
    // while preserving relative order. This can increase the maximum loop
    // depth at which we can fuse a slice of a producer loop nest into a
    // consumer loop nest.
    sinkSequentialLoops(dstNode);
    auto dstAffineForOp = cast<AffineForOp>(dstNode->op);

    // Try to fuse 'dstNode' with candidate producer loops until a fixed point
    // is reached. Fusing two loops may expose new fusion opportunities.
    bool dstNodeChanged;
    do {
      // Gather src loop candidates for 'dstNode' and visit them in "quasi"
      // reverse program order to minimize the number of iterations needed to
      // reach the fixed point. Note that this is a best-effort approach since
      // 'getProducerCandidates' does not always guarantee program order in
      // 'srcIdCandidates'.
      dstNodeChanged = false;
      SmallVector<unsigned, 16> srcIdCandidates;
      getProducerCandidates(dstId, mdg, srcIdCandidates);

      for (unsigned srcId : llvm::reverse(srcIdCandidates)) {
        // Get 'srcNode' from which to attempt fusion into 'dstNode'.
        auto *srcNode = mdg->getNode(srcId);
        auto srcAffineForOp = cast<AffineForOp>(srcNode->op);
        LLVM_DEBUG(llvm::dbgs() << "Evaluating src loop " << srcId
                                << " for dst loop " << dstId << "\n");

        // Skip if 'srcNode' is a loop nest returning values.
        // TODO: support loop nests that return values.
        if (isa<AffineForOp>(srcNode->op) && srcNode->op->getNumResults() > 0)
          continue;

        DenseSet<Value> producerConsumerMemrefs;
        gatherProducerConsumerMemrefs(srcId, dstId, mdg,
                                      producerConsumerMemrefs);

        // Skip if 'srcNode' out edge count on any memref is greater than
        // 'maxSrcUserCount'.
        if (any_of(producerConsumerMemrefs, [&](Value memref) {
              return mdg->getOutEdgeCount(srcNode->id, memref) >
                     maxSrcUserCount;
            }))
          continue;

        // Gather memrefs in 'srcNode' that are written and escape out of the
        // block (e.g., memref block arguments, returned memrefs,
        // memrefs passed to function calls, etc.).
        DenseSet<Value> srcEscapingMemRefs;
        gatherEscapingMemrefs(srcNode->id, mdg, srcEscapingMemRefs);

        // Skip if there are non-affine operations in between the 'srcNode'
        // and 'dstNode' using their memrefs. If so, we currently cannot
        // compute a legal insertion point. 'srcNode' and 'dstNode' memrefs
        // with non-affine operation users would be considered escaping
        // memrefs, so we can limit this check to scenarios with escaping
        // memrefs.
        if (!srcEscapingMemRefs.empty() &&
            hasNonAffineUsersOnThePath(srcId, dstId, mdg)) {
          LLVM_DEBUG(llvm::dbgs()
                     << "Can't fuse: non-affine users in between the loops\n");
          continue;
        }

        // Compute an operation list insertion point for the fused loop
        // nest which preserves dependences.
        Operation *fusedLoopInsPoint =
            mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id);
        if (fusedLoopInsPoint == nullptr)
          continue;
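        // The code below first determines the maximum destination depth at
        // which fusion is legal, then (unless maximal fusion is requested)
        // lets the cost model pick a depth. As a hypothetical example, for a
        // 2-d destination nest where the slice is legal at depths 1 and 2,
        // fusing at depth 2 recomputes the producer slice per (i, j)
        // iteration while depth 1 recomputes it per i iteration only; the
        // cost model weighs that redundant computation against the storage
        // reduction from shrinking the intermediary.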
        // Compute the innermost common loop depth for dstNode
        // producer-consumer loads/stores.
        SmallVector<Operation *, 2> dstMemrefOps;
        for (Operation *op : dstNode->loads)
          if (producerConsumerMemrefs.count(
                  cast<AffineReadOpInterface>(op).getMemRef()) > 0)
            dstMemrefOps.push_back(op);
        for (Operation *op : dstNode->stores)
          if (producerConsumerMemrefs.count(
                  cast<AffineWriteOpInterface>(op).getMemRef()))
            dstMemrefOps.push_back(op);
        unsigned dstLoopDepthTest = getInnermostCommonLoopDepth(dstMemrefOps);

        // Check the feasibility of fusing src loop nest into dst loop nest
        // at loop depths in range [1, dstLoopDepthTest].
        unsigned maxLegalFusionDepth = 0;
        SmallVector<ComputationSliceState, 8> depthSliceUnions;
        depthSliceUnions.resize(dstLoopDepthTest);
        FusionStrategy strategy(FusionStrategy::ProducerConsumer);
        for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
          FusionResult result = mlir::canFuseLoops(
              srcAffineForOp, dstAffineForOp,
              /*dstLoopDepth=*/i, &depthSliceUnions[i - 1], strategy);

          if (result.value == FusionResult::Success)
            maxLegalFusionDepth = i;
        }

        if (maxLegalFusionDepth == 0) {
          LLVM_DEBUG(llvm::dbgs()
                     << "Can't fuse: fusion is not legal at any depth\n");
          continue;
        }

        // Check if fusion would be profitable. We skip profitability analysis
        // for maximal fusion since we already know the maximal legal depth to
        // fuse.
        unsigned bestDstLoopDepth = maxLegalFusionDepth;
        if (!maximalFusion) {
          // Retrieve producer stores from the src loop.
          SmallVector<Operation *, 2> producerStores;
          for (Operation *op : srcNode->stores)
            if (producerConsumerMemrefs.count(
                    cast<AffineWriteOpInterface>(op).getMemRef()))
              producerStores.push_back(op);

          // TODO: Support multiple producer stores in profitability
          // analysis. We limit profitability analysis to only scenarios with
          // a single producer store for now. Note that some multi-store
          // producer scenarios will still go through profitability analysis
          // if only one of the stores is involved in the producer-consumer
          // relationship of the candidate loops.
          assert(!producerStores.empty() && "Expected producer store");
          if (producerStores.size() > 1)
            LLVM_DEBUG(llvm::dbgs() << "Skipping profitability analysis. Not "
                                       "supported for this case\n");
          else if (!isFusionProfitable(producerStores[0], producerStores[0],
                                       dstAffineForOp, depthSliceUnions,
                                       maxLegalFusionDepth, &bestDstLoopDepth,
                                       computeToleranceThreshold))
            continue;
        }

        assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
        ComputationSliceState &bestSlice =
            depthSliceUnions[bestDstLoopDepth - 1];
        assert(!bestSlice.isEmpty() && "Missing slice union for depth");

        // Determine if 'srcId' can be removed after fusion, taking into
        // account remaining dependences, escaping memrefs and the fusion
        // insertion point.
        bool removeSrcNode = canRemoveSrcNodeAfterFusion(
            srcId, dstId, bestSlice, fusedLoopInsPoint, srcEscapingMemRefs,
            mdg);
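        // Decide which producer-consumer memrefs to privatize. As a sketch
        // (simplified; the real buffer shape is derived from the memref
        // write region), privatization rewrites
        //   affine.store %v, %A[%i] ... ; %w = affine.load %A[%i] ...
        // inside the fused nest to use a small scratch buffer such as
        //   %Apriv = memref.alloc() : memref<1xf32>
        // so that the original %A no longer carries the fused dependence.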
        DenseSet<Value> privateMemrefs;
        for (Value memref : producerConsumerMemrefs) {
          if (canCreatePrivateMemRef(memref, srcEscapingMemRefs, srcId, dstId,
                                     removeSrcNode)) {
            // Create a private version of this memref.
            LLVM_DEBUG(llvm::dbgs()
                       << "Creating private memref for " << memref << '\n');
            privateMemrefs.insert(memref);
          }
        }

        // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
        fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice);
        dstNodeChanged = true;

        LLVM_DEBUG(llvm::dbgs()
                   << "Fused src loop " << srcId << " into dst loop " << dstId
                   << " at depth " << bestDstLoopDepth << ":\n"
                   << dstAffineForOp << "\n");

        // Move 'dstAffineForOp' before 'insertPointInst' if needed.
        if (fusedLoopInsPoint != dstAffineForOp)
          dstAffineForOp->moveBefore(fusedLoopInsPoint);

        // Update edges between 'srcNode' and 'dstNode'.
        mdg->updateEdges(srcNode->id, dstNode->id, privateMemrefs,
                         removeSrcNode);

        // Create private memrefs.
        if (!privateMemrefs.empty()) {
          // Gather stores for all the private-to-be memrefs.
          DenseMap<Value, SmallVector<Operation *, 4>> privateMemRefToStores;
          dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) {
            Value storeMemRef = storeOp.getMemRef();
            if (privateMemrefs.count(storeMemRef) > 0)
              privateMemRefToStores[storeMemRef].push_back(storeOp);
          });

          // Replace original memrefs with private memrefs. Note that all the
          // loads and stores on these memrefs will be replaced with new
          // loads and stores. Any reference to the original ones becomes
          // invalid after this point.
          for (auto &memrefToStoresPair : privateMemRefToStores) {
            // TODO: Use union of memref write regions to compute
            // private memref footprint.
            SmallVector<Operation *, 4> &storesForMemref =
                memrefToStoresPair.second;
            Value newMemRef = createPrivateMemRef(
                dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
                fastMemorySpace, localBufSizeThreshold);
            // Create new node in dependence graph for 'newMemRef' alloc op.
            unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp());
            // Add edge from 'newMemRef' node to dstNode.
            mdg->addEdge(newMemRefNodeId, dstId, newMemRef);
          }
          // One or more entries for the 'newMemRef' alloc ops are inserted
          // into the DenseMap 'mdg->nodes'. Since an insertion may cause the
          // DenseMap to reallocate, invalidating 'dstNode', refresh it.
          dstNode = mdg->getNode(dstId);
        }

        // Collect dst loop stats after memref privatization transformation.
        LoopNestStateCollector dstLoopCollector;
        dstLoopCollector.collect(dstAffineForOp);

        // Clear and add back loads and stores.
        mdg->clearNodeLoadAndStores(dstNode->id);
        mdg->addToNode(dstId, dstLoopCollector.loadOpInsts,
                       dstLoopCollector.storeOpInsts);

        if (removeSrcNode) {
          LLVM_DEBUG(llvm::dbgs()
                     << "Removing src loop " << srcId << " after fusion\n");
          // srcNode is no longer valid after it is removed from mdg.
          srcAffineForOp.erase();
          mdg->removeNode(srcId);
          srcNode = nullptr;
        }
      }
    } while (dstNodeChanged);
  }

  /// Visits each node in the graph and, for each node, attempts to fuse it
  /// with producer-consumer candidates. No fusion is performed for producers
  /// whose user count exceeds `maxSrcUserCount` for any of the memrefs
  /// involved.
  void fuseProducerConsumerNodes(unsigned maxSrcUserCount) {
    LLVM_DEBUG(llvm::dbgs() << "--- Producer/Consumer Fusion ---\n");
    init();
    while (!worklist.empty()) {
      unsigned dstId = worklist.back();
      worklist.pop_back();
      performFusionsIntoDest(dstId, maxSrcUserCount);
    }
  }
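  // For illustration, sibling (input-reuse) fusion targets nests like the
  // following, which share only a read of %A (a simplified sketch):
  //
  //   affine.for %i = 0 to 100 {
  //     %v = affine.load %A[%i] : memref<100xf32>
  //     affine.store %v, %B[%i] : memref<100xf32>
  //   }
  //   affine.for %i = 0 to 100 {
  //     %w = affine.load %A[%i] : memref<100xf32>
  //     affine.store %w, %C[%i] : memref<100xf32>
  //   }
  //
  // Fusing them lets both uses of %A[%i] share one access per iteration;
  // unlike the producer-consumer case, no slice of one nest feeds the other,
  // so both write regions must be preserved in the fused nest.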
  // Visits each node in the graph, and for each node, attempts to fuse it with
  // its sibling nodes (nodes which share a parent, but no dependence edges).
  void fuseSiblingNodes() {
    LLVM_DEBUG(llvm::dbgs() << "--- Sibling Fusion ---\n");
    init();
    while (!worklist.empty()) {
      unsigned dstId = worklist.back();
      worklist.pop_back();

      // Skip if this node was removed (fused into another node).
      if (mdg->nodes.count(dstId) == 0)
        continue;
      // Get 'dstNode' into which to attempt fusion.
      auto *dstNode = mdg->getNode(dstId);
      // Skip if 'dstNode' is not a loop nest.
      if (!isa<AffineForOp>(dstNode->op))
        continue;
      // Attempt to fuse 'dstNode' with its sibling nodes in the graph.
      fuseWithSiblingNodes(dstNode);
    }
  }

  // Attempt to fuse 'dstNode' with sibling nodes in the graph.
  void fuseWithSiblingNodes(Node *dstNode) {
    DenseSet<unsigned> visitedSibNodeIds;
    std::pair<unsigned, Value> idAndMemref;
    auto dstAffineForOp = cast<AffineForOp>(dstNode->op);

    while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) {
      unsigned sibId = idAndMemref.first;
      Value memref = idAndMemref.second;
      // TODO: Check that 'sibStoreOpInst' post-dominates all other
      // stores to the same memref in 'sibNode' loop nest.
      auto *sibNode = mdg->getNode(sibId);
      // Compute an operation list insertion point for the fused loop
      // nest which preserves dependences.
      assert(sibNode->op->getBlock() == dstNode->op->getBlock());
      Operation *insertPointInst =
          sibNode->op->isBeforeInBlock(dstNode->op)
              ? mdg->getFusedLoopNestInsertionPoint(sibNode->id, dstNode->id)
              : mdg->getFusedLoopNestInsertionPoint(dstNode->id, sibNode->id);
      if (insertPointInst == nullptr)
        continue;

      // Check if fusion would be profitable and at what depth.

      // Get unique 'sibNode' load op to 'memref'.
      SmallVector<Operation *, 2> sibLoadOpInsts;
      sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts);
      // Currently findSiblingNodeToFuse searches for siblings with one load.
      assert(sibLoadOpInsts.size() == 1);
      Operation *sibLoadOpInst = sibLoadOpInsts[0];
      assert(!sibNode->stores.empty());
      // TODO: Choose the store which postdominates all other stores.
      auto *sibStoreOpInst = sibNode->stores.back();

      // Gather 'dstNode' load ops to 'memref'.
      SmallVector<Operation *, 2> dstLoadOpInsts;
      dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts);

      SmallVector<AffineForOp, 4> dstLoopIVs;
      getAffineForIVs(*dstLoadOpInsts[0], &dstLoopIVs);
      unsigned dstLoopDepthTest = dstLoopIVs.size();
      auto sibAffineForOp = cast<AffineForOp>(sibNode->op);

      // Compute loop depth and slice union for fusion.
      SmallVector<ComputationSliceState, 8> depthSliceUnions;
      depthSliceUnions.resize(dstLoopDepthTest);
      unsigned maxLegalFusionDepth = 0;
      FusionStrategy strategy(memref);
      for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
        FusionResult result = mlir::canFuseLoops(
            sibAffineForOp, dstAffineForOp,
            /*dstLoopDepth=*/i, &depthSliceUnions[i - 1], strategy);

        if (result.value == FusionResult::Success)
          maxLegalFusionDepth = i;
      }

      // Skip if fusion is not feasible at any loop depth.
      if (maxLegalFusionDepth == 0)
        continue;

      unsigned bestDstLoopDepth = maxLegalFusionDepth;
      if (!maximalFusion) {
        // Check if fusion would be profitable.
        if (!isFusionProfitable(sibLoadOpInst, sibStoreOpInst, dstAffineForOp,
                                depthSliceUnions, maxLegalFusionDepth,
                                &bestDstLoopDepth, computeToleranceThreshold))
          continue;
      }

      assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
      assert(!depthSliceUnions[bestDstLoopDepth - 1].isEmpty() &&
             "Fusion depth has no computed slice union");
      // Check if the source loop is being inserted in the innermost
      // destination loop. Based on this, the fused loop may be optimized
      // further inside `fuseLoops`.
      bool isInnermostInsertion = (bestDstLoopDepth == dstLoopDepthTest);
      // Fuse computation slice of 'sibLoopNest' into 'dstLoopNest'.
      mlir::fuseLoops(sibAffineForOp, dstAffineForOp,
                      depthSliceUnions[bestDstLoopDepth - 1],
                      isInnermostInsertion);

      auto dstForInst = cast<AffineForOp>(dstNode->op);
      // Update operation position of fused loop nest (if needed).
      if (insertPointInst != dstForInst) {
        dstForInst->moveBefore(insertPointInst);
      }
      // Update data dependence graph state post fusion.
      updateStateAfterSiblingFusion(sibNode, dstNode);
    }
  }

  // Searches block argument uses and the graph from 'dstNode' looking for a
  // fusion candidate sibling node which shares no dependences with 'dstNode'
  // but which loads from the same memref. Returns true and sets
  // 'idAndMemrefToFuse' on success. Returns false otherwise.
  bool findSiblingNodeToFuse(Node *dstNode,
                             DenseSet<unsigned> *visitedSibNodeIds,
                             std::pair<unsigned, Value> *idAndMemrefToFuse) {
    // Returns true if 'sibNode' can be fused with 'dstNode' for input reuse
    // on 'memref'.
    auto canFuseWithSibNode = [&](Node *sibNode, Value memref) {
      // Skip if 'sibNode' does not have exactly one load to 'memref'.
      // TODO: Remove the single load op restriction.
      if (sibNode->getLoadOpCount(memref) != 1)
        return false;
      // Skip if there exists a path of dependent edges between
      // 'sibNode' and 'dstNode'.
      if (mdg->hasDependencePath(sibNode->id, dstNode->id) ||
          mdg->hasDependencePath(dstNode->id, sibNode->id))
        return false;
      // Skip sib node if it loads from (and stores to) the same memref on
      // which it also has an input dependence edge.
      DenseSet<Value> loadAndStoreMemrefSet;
      sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet);
      if (llvm::any_of(loadAndStoreMemrefSet, [=](Value memref) {
            return mdg->getIncomingMemRefAccesses(sibNode->id, memref) > 0;
          }))
        return false;

      // Check that all stores are to the same memref.
      DenseSet<Value> storeMemrefs;
      for (auto *storeOpInst : sibNode->stores) {
        storeMemrefs.insert(
            cast<AffineWriteOpInterface>(storeOpInst).getMemRef());
      }
      if (storeMemrefs.size() != 1)
        return false;

      // Skip if a memref value in one node is used by a non-affine memref
      // access that lies between 'dstNode' and 'sibNode'.
      if (hasNonAffineUsersOnThePath(dstNode->id, sibNode->id, mdg) ||
          hasNonAffineUsersOnThePath(sibNode->id, dstNode->id, mdg))
        return false;
      return true;
    };

    // Search for siblings which load the same memref block argument.
    Block *block = dstNode->op->getBlock();
    for (unsigned i = 0, e = block->getNumArguments(); i != e; ++i) {
      for (Operation *user : block->getArgument(i).getUsers()) {
        auto loadOp = dyn_cast<AffineReadOpInterface>(user);
        if (!loadOp)
          continue;
        // Gather loops surrounding 'use'.
        SmallVector<AffineForOp, 4> loops;
        getAffineForIVs(*user, &loops);
        // Skip 'use' if it is not within a loop nest.
        if (loops.empty())
          continue;
        Node *sibNode = mdg->getForOpNode(loops[0]);
        assert(sibNode != nullptr);
        // Skip 'use' if it is not a sibling to 'dstNode'.
        if (sibNode->id == dstNode->id)
          continue;
        // Skip 'use' if it has been visited.
        if (visitedSibNodeIds->count(sibNode->id) > 0)
          continue;
        // Skip 'use' if it does not load from the same memref as 'dstNode'.
        auto memref = loadOp.getMemRef();
        if (dstNode->getLoadOpCount(memref) == 0)
          continue;
        // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
        if (canFuseWithSibNode(sibNode, memref)) {
          visitedSibNodeIds->insert(sibNode->id);
          idAndMemrefToFuse->first = sibNode->id;
          idAndMemrefToFuse->second = memref;
          return true;
        }
      }
    }

    // Search for siblings by following edges through an intermediate src node.
    // Collect candidate 'dstNode' input edges in 'inEdges'.
    SmallVector<MemRefDependenceGraph::Edge, 2> inEdges;
    mdg->forEachMemRefInputEdge(
        dstNode->id, [&](MemRefDependenceGraph::Edge inEdge) {
          // Add 'inEdge' if it is a read-after-write dependence.
          if (dstNode->getLoadOpCount(inEdge.value) > 0 &&
              mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0)
            inEdges.push_back(inEdge);
        });

    // Search for sibling nodes to fuse by visiting output edges from each
    // input edge in 'inEdges'.
    for (auto &inEdge : inEdges) {
      // Collect candidate output edges from each node 'inEdge.id' in
      // 'inEdges'.
      SmallVector<MemRefDependenceGraph::Edge, 2> outEdges;
      mdg->forEachMemRefOutputEdge(
          inEdge.id, [&](MemRefDependenceGraph::Edge outEdge) {
            unsigned sibNodeId = outEdge.id;
            if (visitedSibNodeIds->count(sibNodeId) > 0)
              return;
            // Skip output edge if not a sibling using the same memref.
            if (outEdge.id == dstNode->id || outEdge.value != inEdge.value)
              return;
            auto *sibNode = mdg->getNode(sibNodeId);
            if (!isa<AffineForOp>(sibNode->op))
              return;
            // Check if 'sibNode/dstNode' can be input-reuse fused on
            // 'memref'.
            if (canFuseWithSibNode(sibNode, outEdge.value)) {
              // Add candidate 'outEdge' to sibling node.
              outEdges.push_back(outEdge);
            }
          });

      // Add first candidate if any were returned.
      if (!outEdges.empty()) {
        visitedSibNodeIds->insert(outEdges[0].id);
        idAndMemrefToFuse->first = outEdges[0].id;
        idAndMemrefToFuse->second = outEdges[0].value;
        return true;
      }
    }
    return false;
  }

  /// Update data dependence graph state to reflect sibling fusion of 'sibNode'
  /// into 'dstNode'.
  void updateStateAfterSiblingFusion(Node *sibNode, Node *dstNode) {
    // Update 'sibNode' and 'dstNode' input/output edges to reflect fusion.
    mdg->updateEdges(sibNode->id, dstNode->id);

    // Collect dst loop stats after memref privatization transformation.
    auto dstForInst = cast<AffineForOp>(dstNode->op);
    LoopNestStateCollector dstLoopCollector;
    dstLoopCollector.collect(dstForInst);
    // Clear and add back loads and stores.
    mdg->clearNodeLoadAndStores(dstNode->id);
    mdg->addToNode(dstNode->id, dstLoopCollector.loadOpInsts,
                   dstLoopCollector.storeOpInsts);
    // Remove old sibling loop nest if it no longer has outgoing dependence
    // edges, and it does not write to a memref which escapes the block.
    if (mdg->getOutEdgeCount(sibNode->id) == 0) {
      Operation *op = sibNode->op;
      mdg->removeNode(sibNode->id);
      op->erase();
    }
  }

  // Clean up any allocs with no users.
  void eraseUnusedMemRefAllocations() {
    for (auto &pair : mdg->memrefEdgeCount) {
      if (pair.second > 0)
        continue;
      auto memref = pair.first;
      // Skip if there exist other uses (return operation or function calls).
      if (!memref.use_empty())
        continue;
      // Use list expected to match the dep graph info.
      auto *op = memref.getDefiningOp();
      if (isa_and_nonnull<memref::AllocOp>(op))
        op->erase();
    }
  }
};

} // namespace

/// Run fusion on `block`.
void LoopFusion::runOnBlock(Block *block) {
  MemRefDependenceGraph g(*block);
  if (!g.init()) {
    LLVM_DEBUG(llvm::dbgs() << "MDG init failed\n");
    return;
  }

  std::optional<unsigned> fastMemorySpaceOpt;
  if (fastMemorySpace.hasValue())
    fastMemorySpaceOpt = fastMemorySpace;
  unsigned localBufSizeThresholdBytes = localBufSizeThreshold * 1024;
  GreedyFusion fusion(&g, localBufSizeThresholdBytes, fastMemorySpaceOpt,
                      maximalFusion, computeToleranceThreshold);

  if (affineFusionMode == FusionMode::ProducerConsumer)
    fusion.runProducerConsumerFusionOnly();
  else if (affineFusionMode == FusionMode::Sibling)
    fusion.runSiblingFusionOnly();
  else
    fusion.runGreedyFusion();
}

void LoopFusion::runOnOperation() {
  for (Region &region : getOperation()->getRegions())
    for (Block &block : region.getBlocks())
      runOnBlock(&block);
}
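// Example usage (illustrative; the pass name and option spellings below come
// from the tablegen'erated pass registration and should be verified against
// Passes.td):
//
//   mlir-opt input.mlir -affine-loop-fusion
//   mlir-opt input.mlir -affine-loop-fusion="fusion-maximal=true"
//
// The pass can also be constructed programmatically via the
// createLoopFusionPass() factory defined earlier in this file.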