//===- LoopFusion.cpp - Code to perform loop fusion -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements affine fusion.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/Passes.h"

#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopFusionUtils.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <iomanip>
#include <optional>
#include <sstream>

namespace mlir {
#define GEN_PASS_DEF_AFFINELOOPFUSION
#include "mlir/Dialect/Affine/Passes.h.inc"
} // namespace mlir

#define DEBUG_TYPE "affine-loop-fusion"

using namespace mlir;

namespace {
/// Loop fusion pass. This pass currently supports a greedy fusion policy,
/// which fuses loop nests with single-writer/single-reader memref dependences
/// with the goal of improving locality.

// TODO: Support fusion of source loop nests which write to multiple
// memrefs, where each memref can have multiple users (if profitable).
// TODO: Extend this pass to check for fusion preventing dependences,
// and add support for more general loop fusion algorithms.

struct LoopFusion : public impl::AffineLoopFusionBase<LoopFusion> {
  LoopFusion() = default;
  LoopFusion(unsigned fastMemorySpace, uint64_t localBufSizeThresholdBytes,
             bool maximalFusion, enum FusionMode affineFusionMode) {
    this->fastMemorySpace = fastMemorySpace;
    this->localBufSizeThreshold = localBufSizeThresholdBytes / 1024;
    this->maximalFusion = maximalFusion;
    this->affineFusionMode = affineFusionMode;
  }

  void runOnBlock(Block *block);
  void runOnOperation() override;
};

} // namespace

std::unique_ptr<Pass>
mlir::createLoopFusionPass(unsigned fastMemorySpace,
                           uint64_t localBufSizeThreshold, bool maximalFusion,
                           enum FusionMode affineFusionMode) {
  return std::make_unique<LoopFusion>(fastMemorySpace, localBufSizeThreshold,
                                      maximalFusion, affineFusionMode);
}
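
// Illustrative usage sketch (added for exposition; the exact pipeline setup
// depends on the client and is not defined in this file):
//
//   mlir::PassManager pm(&context);
//   pm.addNestedPass<mlir::func::FuncOp>(mlir::createLoopFusionPass(
//       /*fastMemorySpace=*/0, /*localBufSizeThreshold=*/0,
//       /*maximalFusion=*/false, mlir::FusionMode::Greedy));
//   (void)pm.run(module);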

namespace {

// LoopNestStateCollector walks loop nests and collects load and store
// operations, and whether or not a region holding op other than ForOp and IfOp
// was encountered in the loop nest.
struct LoopNestStateCollector {
  SmallVector<AffineForOp, 4> forOps;
  SmallVector<Operation *, 4> loadOpInsts;
  SmallVector<Operation *, 4> storeOpInsts;
  bool hasNonAffineRegionOp = false;

  void collect(Operation *opToWalk) {
    opToWalk->walk([&](Operation *op) {
      if (isa<AffineForOp>(op))
        forOps.push_back(cast<AffineForOp>(op));
      else if (op->getNumRegions() != 0 && !isa<AffineIfOp>(op))
        hasNonAffineRegionOp = true;
      else if (isa<AffineReadOpInterface>(op))
        loadOpInsts.push_back(op);
      else if (isa<AffineWriteOpInterface>(op))
        storeOpInsts.push_back(op);
    });
  }
};

// MemRefDependenceGraph is a graph data structure where graph nodes are
// top-level operations in a `Block` which contain load/store ops, and edges
// are memref dependences between the nodes.
// TODO: Add a more flexible dependence graph representation.
// TODO: Add a depth parameter to dependence graph construction.
struct MemRefDependenceGraph {
public:
  // Node represents a node in the graph. A Node is either an entire loop nest
  // rooted at the top level which contains loads/stores, or a top level
  // load/store.
  struct Node {
    // The unique identifier of this node in the graph.
    unsigned id;
    // The top-level statement which is (or contains) a load/store.
    Operation *op;
    // List of load operations.
    SmallVector<Operation *, 4> loads;
    // List of store op insts.
    SmallVector<Operation *, 4> stores;
    Node(unsigned id, Operation *op) : id(id), op(op) {}

    // Returns the load op count for 'memref'.
    unsigned getLoadOpCount(Value memref) {
      unsigned loadOpCount = 0;
      for (auto *loadOpInst : loads) {
        if (memref == cast<AffineReadOpInterface>(loadOpInst).getMemRef())
          ++loadOpCount;
      }
      return loadOpCount;
    }

    // Returns the store op count for 'memref'.
    unsigned getStoreOpCount(Value memref) {
      unsigned storeOpCount = 0;
      for (auto *storeOpInst : stores) {
        if (memref == cast<AffineWriteOpInterface>(storeOpInst).getMemRef())
          ++storeOpCount;
      }
      return storeOpCount;
    }

    // Returns all store ops in 'storeOps' which access 'memref'.
    void getStoreOpsForMemref(Value memref,
                              SmallVectorImpl<Operation *> *storeOps) {
      for (auto *storeOpInst : stores) {
        if (memref == cast<AffineWriteOpInterface>(storeOpInst).getMemRef())
          storeOps->push_back(storeOpInst);
      }
    }

    // Returns all load ops in 'loadOps' which access 'memref'.
    void getLoadOpsForMemref(Value memref,
                             SmallVectorImpl<Operation *> *loadOps) {
      for (auto *loadOpInst : loads) {
        if (memref == cast<AffineReadOpInterface>(loadOpInst).getMemRef())
          loadOps->push_back(loadOpInst);
      }
    }

    // Returns all memrefs in 'loadAndStoreMemrefSet' for which this node
    // has at least one load and store operation.
    void getLoadAndStoreMemrefSet(DenseSet<Value> *loadAndStoreMemrefSet) {
      llvm::SmallDenseSet<Value, 2> loadMemrefs;
      for (auto *loadOpInst : loads) {
        loadMemrefs.insert(cast<AffineReadOpInterface>(loadOpInst).getMemRef());
      }
      for (auto *storeOpInst : stores) {
        auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
        if (loadMemrefs.count(memref) > 0)
          loadAndStoreMemrefSet->insert(memref);
      }
    }
  };

  // Edge represents a data dependence between nodes in the graph.
  struct Edge {
    // The id of the node at the other end of the edge.
    // If this edge is stored in Edge = Node.inEdges[i], then
    // 'Node.inEdges[i].id' is the identifier of the source node of the edge.
    // If this edge is stored in Edge = Node.outEdges[i], then
    // 'Node.outEdges[i].id' is the identifier of the dest node of the edge.
    unsigned id;
    // The SSA value on which this edge represents a dependence.
    // If the value is a memref, then the dependence is between graph nodes
    // which contain accesses to the same memref 'value'. If the value is a
    // non-memref value, then the dependence is between a graph node which
    // defines an SSA value and another graph node which uses the SSA value
    // (e.g. a constant or load operation defining a value which is used inside
    // a loop nest).
    Value value;
  };

  // Map from node id to Node.
  DenseMap<unsigned, Node> nodes;
  // Map from node id to list of input edges.
  DenseMap<unsigned, SmallVector<Edge, 2>> inEdges;
  // Map from node id to list of output edges.
  DenseMap<unsigned, SmallVector<Edge, 2>> outEdges;
  // Map from memref to a count on the dependence edges associated with that
  // memref.
  DenseMap<Value, unsigned> memrefEdgeCount;
  // The next unique identifier to use for newly created graph nodes.
  unsigned nextNodeId = 0;

  MemRefDependenceGraph(Block &block) : block(block) {}

  // Initializes the dependence graph based on operations in 'block'.
  // Returns true on success, false otherwise.
  bool init(Block *block);

  // Returns the graph node for 'id'.
  Node *getNode(unsigned id) {
    auto it = nodes.find(id);
    assert(it != nodes.end());
    return &it->second;
  }

  // Returns the graph node for 'forOp'.
  Node *getForOpNode(AffineForOp forOp) {
    for (auto &idAndNode : nodes)
      if (idAndNode.second.op == forOp)
        return &idAndNode.second;
    return nullptr;
  }

  // Adds a node with 'op' to the graph and returns its unique identifier.
  unsigned addNode(Operation *op) {
    Node node(nextNodeId++, op);
    nodes.insert({node.id, node});
    return node.id;
  }

  // Remove node 'id' (and its associated edges) from graph.
  void removeNode(unsigned id) {
    // Remove each edge in 'inEdges[id]'.
    if (inEdges.count(id) > 0) {
      SmallVector<Edge, 2> oldInEdges = inEdges[id];
      for (auto &inEdge : oldInEdges) {
        removeEdge(inEdge.id, id, inEdge.value);
      }
    }
    // Remove each edge in 'outEdges[id]'.
    if (outEdges.count(id) > 0) {
      SmallVector<Edge, 2> oldOutEdges = outEdges[id];
      for (auto &outEdge : oldOutEdges) {
        removeEdge(id, outEdge.id, outEdge.value);
      }
    }
    // Erase remaining node state.
    inEdges.erase(id);
    outEdges.erase(id);
    nodes.erase(id);
  }

  // Returns true if node 'id' writes to any memref which escapes (or is an
  // argument to) the block. Returns false otherwise.
  bool writesToLiveInOrEscapingMemrefs(unsigned id) {
    Node *node = getNode(id);
    for (auto *storeOpInst : node->stores) {
      auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
      auto *op = memref.getDefiningOp();
      // Return true if 'memref' is a block argument.
      if (!op)
        return true;
      // Return true if any use of 'memref' does not dereference it in an
      // affine way.
      for (auto *user : memref.getUsers())
        if (!isa<AffineMapAccessInterface>(*user))
          return true;
    }
    return false;
  }

  // Returns true iff there is an edge from node 'srcId' to node 'dstId' which
  // is for 'value' if non-null, or for any value otherwise. Returns false
  // otherwise.
  bool hasEdge(unsigned srcId, unsigned dstId, Value value = nullptr) {
    if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
      return false;
    }
    bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) {
      return edge.id == dstId && (!value || edge.value == value);
    });
    bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) {
      return edge.id == srcId && (!value || edge.value == value);
    });
    return hasOutEdge && hasInEdge;
  }

  // Adds an edge from node 'srcId' to node 'dstId' for 'value'.
  void addEdge(unsigned srcId, unsigned dstId, Value value) {
    if (!hasEdge(srcId, dstId, value)) {
      outEdges[srcId].push_back({dstId, value});
      inEdges[dstId].push_back({srcId, value});
      if (value.getType().isa<MemRefType>())
        memrefEdgeCount[value]++;
    }
  }

  // Removes an edge from node 'srcId' to node 'dstId' for 'value'.
  void removeEdge(unsigned srcId, unsigned dstId, Value value) {
    assert(inEdges.count(dstId) > 0);
    assert(outEdges.count(srcId) > 0);
    if (value.getType().isa<MemRefType>()) {
      assert(memrefEdgeCount.count(value) > 0);
      memrefEdgeCount[value]--;
    }
    // Remove 'srcId' from 'inEdges[dstId]'.
    for (auto *it = inEdges[dstId].begin(); it != inEdges[dstId].end(); ++it) {
      if ((*it).id == srcId && (*it).value == value) {
        inEdges[dstId].erase(it);
        break;
      }
    }
    // Remove 'dstId' from 'outEdges[srcId]'.
    for (auto *it = outEdges[srcId].begin(); it != outEdges[srcId].end();
         ++it) {
      if ((*it).id == dstId && (*it).value == value) {
        outEdges[srcId].erase(it);
        break;
      }
    }
  }

  // Returns true if there is a path in the dependence graph from node 'srcId'
  // to node 'dstId'. Returns false otherwise.
  bool hasDependencePath(unsigned srcId, unsigned dstId) {
    // Worklist state is: <node-id, next-output-edge-index-to-visit>
    SmallVector<std::pair<unsigned, unsigned>, 4> worklist;
    worklist.push_back({srcId, 0});
    // Run DFS traversal to see if 'dstId' is reachable from 'srcId'.
    while (!worklist.empty()) {
      auto &idAndIndex = worklist.back();
      // Return true if we have reached 'dstId'.
      if (idAndIndex.first == dstId)
        return true;
      // Pop and continue if node has no out edges, or if all out edges have
      // already been visited.
      if (outEdges.count(idAndIndex.first) == 0 ||
          idAndIndex.second == outEdges[idAndIndex.first].size()) {
        worklist.pop_back();
        continue;
      }
      // Get graph edge to traverse.
      Edge edge = outEdges[idAndIndex.first][idAndIndex.second];
      // Increment next output edge index for 'idAndIndex'.
      ++idAndIndex.second;
      // Add node at 'edge.id' to worklist.
      worklist.push_back({edge.id, 0});
    }
    return false;
  }
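
  // Illustrative note (added for exposition): the traversal above encodes its
  // DFS stack explicitly as <node-id, next-out-edge-index> pairs. For example,
  // with edges 0 -> 1 and 1 -> 2, hasDependencePath(0, 2) visits nodes 0, 1,
  // and 2 and returns true, whereas hasDependencePath(2, 0) returns false.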

  // Returns the input edge count for node 'id' and 'memref' from src nodes
  // which access 'memref' with a store operation.
  unsigned getIncomingMemRefAccesses(unsigned id, Value memref) {
    unsigned inEdgeCount = 0;
    if (inEdges.count(id) > 0)
      for (auto &inEdge : inEdges[id])
        if (inEdge.value == memref) {
          Node *srcNode = getNode(inEdge.id);
          // Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'.
          if (srcNode->getStoreOpCount(memref) > 0)
            ++inEdgeCount;
        }
    return inEdgeCount;
  }

  // Returns the output edge count for node 'id' and 'memref' (if non-null),
  // otherwise returns the total output edge count from node 'id'.
  unsigned getOutEdgeCount(unsigned id, Value memref = nullptr) {
    unsigned outEdgeCount = 0;
    if (outEdges.count(id) > 0)
      for (auto &outEdge : outEdges[id])
        if (!memref || outEdge.value == memref)
          ++outEdgeCount;
    return outEdgeCount;
  }

  /// Return all nodes which define SSA values used in node 'id'.
  void gatherDefiningNodes(unsigned id, DenseSet<unsigned> &definingNodes) {
    for (MemRefDependenceGraph::Edge edge : inEdges[id])
      // By definition of edge, if the edge value is a non-memref value,
      // then the dependence is between a graph node which defines an SSA value
      // and another graph node which uses the SSA value.
      if (!edge.value.getType().isa<MemRefType>())
        definingNodes.insert(edge.id);
  }

  // Computes and returns an insertion point operation, before which the
  // fused <srcId, dstId> loop nest can be inserted while preserving
  // dependences. Returns nullptr if no such insertion point is found.
  Operation *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId) {
    if (outEdges.count(srcId) == 0)
      return getNode(dstId)->op;

    // Skip if there is any defining node of 'dstId' that depends on 'srcId'.
    DenseSet<unsigned> definingNodes;
    gatherDefiningNodes(dstId, definingNodes);
    if (llvm::any_of(definingNodes, [&](unsigned id) {
          return hasDependencePath(srcId, id);
        })) {
      LLVM_DEBUG(llvm::dbgs()
                 << "Can't fuse: a defining op with a user in the dst "
                    "loop has dependence from the src loop\n");
      return nullptr;
    }

    // Build set of insts in range (srcId, dstId) which depend on 'srcId'.
    SmallPtrSet<Operation *, 2> srcDepInsts;
    for (auto &outEdge : outEdges[srcId])
      if (outEdge.id != dstId)
        srcDepInsts.insert(getNode(outEdge.id)->op);

    // Build set of insts in range (srcId, dstId) on which 'dstId' depends.
    SmallPtrSet<Operation *, 2> dstDepInsts;
    for (auto &inEdge : inEdges[dstId])
      if (inEdge.id != srcId)
        dstDepInsts.insert(getNode(inEdge.id)->op);

    Operation *srcNodeInst = getNode(srcId)->op;
    Operation *dstNodeInst = getNode(dstId)->op;

    // Computing insertion point:
    // *) Walk all operation positions in Block operation list in the
    //    range (src, dst). For each operation 'op' visited in this search:
    //    *) Store in 'firstSrcDepPos' the first position where 'op' has a
    //       dependence edge from 'srcNode'.
    //    *) Store in 'lastDstDepPos' the last position where 'op' has a
    //       dependence edge to 'dstNode'.
    //    *) Compare 'firstSrcDepPos' and 'lastDstDepPos' to determine the
    //       operation insertion point (or return null pointer if no such
    //       insertion point exists: 'firstSrcDepPos' <= 'lastDstDepPos').
    SmallVector<Operation *, 2> depInsts;
    std::optional<unsigned> firstSrcDepPos;
    std::optional<unsigned> lastDstDepPos;
    unsigned pos = 0;
    for (Block::iterator it = std::next(Block::iterator(srcNodeInst));
         it != Block::iterator(dstNodeInst); ++it) {
      Operation *op = &(*it);
      if (srcDepInsts.count(op) > 0 && firstSrcDepPos == std::nullopt)
        firstSrcDepPos = pos;
      if (dstDepInsts.count(op) > 0)
        lastDstDepPos = pos;
      depInsts.push_back(op);
      ++pos;
    }

    if (firstSrcDepPos.has_value()) {
      if (lastDstDepPos.has_value()) {
        if (*firstSrcDepPos <= *lastDstDepPos) {
          // No valid insertion point exists which preserves dependences.
          return nullptr;
        }
      }
      // Return the insertion point at 'firstSrcDepPos'.
      return depInsts[*firstSrcDepPos];
    }
    // No dependence targets in range (or only dst deps in range), return
    // 'dstNodeInst' insertion point.
    return dstNodeInst;
  }

  // Updates edge mappings from node 'srcId' to node 'dstId' after fusing them,
  // taking into account that:
  // *) if 'removeSrcId' is true, 'srcId' will be removed after fusion,
  // *) memrefs in 'privateMemRefs' have been replaced in the node at 'dstId'
  //    by a private memref.
  void updateEdges(unsigned srcId, unsigned dstId,
                   const DenseSet<Value> &privateMemRefs, bool removeSrcId) {
    // For each edge in 'inEdges[srcId]': add new edge remapping to 'dstId'.
    if (inEdges.count(srcId) > 0) {
      SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
      for (auto &inEdge : oldInEdges) {
        // Add edge from 'inEdge.id' to 'dstId' if it's not a private memref.
        if (privateMemRefs.count(inEdge.value) == 0)
          addEdge(inEdge.id, dstId, inEdge.value);
      }
    }
    // For each edge in 'outEdges[srcId]': remove edge from 'srcId' to 'dstId'.
    // If 'srcId' is going to be removed, remap all the out edges to 'dstId'.
    if (outEdges.count(srcId) > 0) {
      SmallVector<Edge, 2> oldOutEdges = outEdges[srcId];
      for (auto &outEdge : oldOutEdges) {
        // Remove any out edges from 'srcId' to 'dstId' across memrefs.
        if (outEdge.id == dstId)
          removeEdge(srcId, outEdge.id, outEdge.value);
        else if (removeSrcId) {
          addEdge(dstId, outEdge.id, outEdge.value);
          removeEdge(srcId, outEdge.id, outEdge.value);
        }
      }
    }
    // Remove any edges in 'inEdges[dstId]' on 'oldMemRef' (which is being
    // replaced by a private memref). These edges could come from nodes
    // other than 'srcId' which were removed in the previous step.
    if (inEdges.count(dstId) > 0 && !privateMemRefs.empty()) {
      SmallVector<Edge, 2> oldInEdges = inEdges[dstId];
      for (auto &inEdge : oldInEdges)
        if (privateMemRefs.count(inEdge.value) > 0)
          removeEdge(inEdge.id, dstId, inEdge.value);
    }
  }

  // Update edge mappings for nodes 'sibId' and 'dstId' to reflect fusion
  // of sibling node 'sibId' into node 'dstId'.
  void updateEdges(unsigned sibId, unsigned dstId) {
    // For each edge in 'inEdges[sibId]':
    // *) Add new edge from source node 'inEdge.id' to 'dstNode'.
    // *) Remove edge from source node 'inEdge.id' to 'sibNode'.
    if (inEdges.count(sibId) > 0) {
      SmallVector<Edge, 2> oldInEdges = inEdges[sibId];
      for (auto &inEdge : oldInEdges) {
        addEdge(inEdge.id, dstId, inEdge.value);
        removeEdge(inEdge.id, sibId, inEdge.value);
      }
    }

    // For each edge in 'outEdges[sibId]' to node 'id':
    // *) Add new edge from 'dstId' to 'outEdge.id'.
    // *) Remove edge from 'sibId' to 'outEdge.id'.
    if (outEdges.count(sibId) > 0) {
      SmallVector<Edge, 2> oldOutEdges = outEdges[sibId];
      for (auto &outEdge : oldOutEdges) {
        addEdge(dstId, outEdge.id, outEdge.value);
        removeEdge(sibId, outEdge.id, outEdge.value);
      }
    }
  }

  // Adds ops in 'loads' and 'stores' to node at 'id'.
  void addToNode(unsigned id, const SmallVectorImpl<Operation *> &loads,
                 const SmallVectorImpl<Operation *> &stores) {
    Node *node = getNode(id);
    llvm::append_range(node->loads, loads);
    llvm::append_range(node->stores, stores);
  }

  void clearNodeLoadAndStores(unsigned id) {
    Node *node = getNode(id);
    node->loads.clear();
    node->stores.clear();
  }

  // Calls 'callback' for each input edge incident to node 'id' which carries a
  // memref dependence.
  void forEachMemRefInputEdge(unsigned id,
                              const std::function<void(Edge)> &callback) {
    if (inEdges.count(id) > 0)
      forEachMemRefEdge(inEdges[id], callback);
  }

  // Calls 'callback' for each output edge from node 'id' which carries a
  // memref dependence.
  void forEachMemRefOutputEdge(unsigned id,
                               const std::function<void(Edge)> &callback) {
    if (outEdges.count(id) > 0)
      forEachMemRefEdge(outEdges[id], callback);
  }

  // Calls 'callback' for each edge in 'edges' which carries a memref
  // dependence.
  void forEachMemRefEdge(ArrayRef<Edge> edges,
                         const std::function<void(Edge)> &callback) {
    for (const auto &edge : edges) {
      // Skip if 'edge' is not a memref dependence edge.
      if (!edge.value.getType().isa<MemRefType>())
        continue;
      assert(nodes.count(edge.id) > 0);
      // Skip if 'edge.id' is not a loop nest.
      if (!isa<AffineForOp>(getNode(edge.id)->op))
        continue;
      // Visit current input edge 'edge'.
      callback(edge);
    }
  }

  void print(raw_ostream &os) const {
    os << "\nMemRefDependenceGraph\n";
    os << "\nNodes:\n";
    for (const auto &idAndNode : nodes) {
      os << "Node: " << idAndNode.first << "\n";
      auto it = inEdges.find(idAndNode.first);
      if (it != inEdges.end()) {
        for (const auto &e : it->second)
          os << "  InEdge: " << e.id << " " << e.value << "\n";
      }
      it = outEdges.find(idAndNode.first);
      if (it != outEdges.end()) {
        for (const auto &e : it->second)
          os << "  OutEdge: " << e.id << " " << e.value << "\n";
      }
    }
  }
  void dump() const { print(llvm::errs()); }

  /// The block for which this graph is created to perform fusion.
  Block &block;
};
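
// Illustrative example (added for exposition; %v stands for some value defined
// elsewhere): for a block with a memref block argument %m and two top-level
// loop nests such as
//
//   affine.for %i = 0 to 10 {
//     affine.store %v, %m[%i] : memref<10xf32>   // node 0
//   }
//   affine.for %i = 0 to 10 {
//     %x = affine.load %m[%i] : memref<10xf32>   // node 1
//   }
//
// the graph contains one node per loop nest and a memref dependence edge on
// %m from node 0 to node 1, since node 0 writes a memref that node 1 reads.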

/// Returns true if node 'srcId' can be removed after fusing it with node
/// 'dstId'. The node can be removed if any of the following conditions are
/// met:
///   1. 'srcId' has no output dependences after fusion and no escaping
///      memrefs.
///   2. 'srcId' has no output dependences after fusion, has escaping memrefs
///      and the fusion slice is maximal.
///   3. 'srcId' has output dependences after fusion, the fusion slice is
///      maximal and the fusion insertion point dominates all the dependences.
static bool canRemoveSrcNodeAfterFusion(
    unsigned srcId, unsigned dstId, const ComputationSliceState &fusionSlice,
    Operation *fusedLoopInsPoint, const DenseSet<Value> &escapingMemRefs,
    MemRefDependenceGraph *mdg) {

  Operation *dstNodeOp = mdg->getNode(dstId)->op;
  bool hasOutDepsAfterFusion = false;

  for (auto &outEdge : mdg->outEdges[srcId]) {
    Operation *depNodeOp = mdg->getNode(outEdge.id)->op;
    // Skip dependence with dstOp since it will be removed after fusion.
    if (depNodeOp == dstNodeOp)
      continue;

    // Only fusion within the same block is supported. Use domination analysis
    // when needed.
    if (depNodeOp->getBlock() != dstNodeOp->getBlock())
      return false;

    // Check if the insertion point of the fused loop dominates the dependence.
    // Otherwise, the src loop can't be removed.
    if (fusedLoopInsPoint != depNodeOp &&
        !fusedLoopInsPoint->isBeforeInBlock(depNodeOp)) {
      LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: dst loop doesn't "
                                 "dominate dependence\n");
      return false;
    }

    hasOutDepsAfterFusion = true;
  }

  // If the src loop has dependences after fusion or it writes to a live-out or
  // escaping memref, we can only remove it if the fusion slice is maximal so
  // that all the dependences are preserved.
  if (hasOutDepsAfterFusion || !escapingMemRefs.empty()) {
    Optional<bool> isMaximal = fusionSlice.isMaximal();
    if (!isMaximal) {
      LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: can't determine "
                                 "if fusion is maximal\n");
      return false;
    }

    if (!*isMaximal) {
      LLVM_DEBUG(llvm::dbgs()
                 << "Src loop can't be removed: fusion is not maximal\n");
      return false;
    }
  }

  return true;
}

/// Returns in 'srcIdCandidates' the producer fusion candidates for consumer
/// 'dstId'. Candidates are sorted by node id order. This order corresponds to
/// the program order when the 'mdg' is created. However, program order is not
/// guaranteed and must not be required by the client. Program order won't be
/// held if the 'mdg' is reused from a previous fusion step or if the node
/// creation order changes in the future to support more advanced cases.
// TODO: Move this to a loop fusion utility once 'mdg' is also moved.
static void getProducerCandidates(unsigned dstId, MemRefDependenceGraph *mdg,
                                  SmallVectorImpl<unsigned> &srcIdCandidates) {
  // Skip if no input edges along which to fuse.
  if (mdg->inEdges.count(dstId) == 0)
    return;

  // Gather memrefs from loads in 'dstId'.
  auto *dstNode = mdg->getNode(dstId);
  DenseSet<Value> consumedMemrefs;
  for (Operation *load : dstNode->loads)
    consumedMemrefs.insert(cast<AffineReadOpInterface>(load).getMemRef());

  // Traverse 'dstId' incoming edges and gather the nodes that contain a store
  // to one of the consumed memrefs.
  for (auto &srcEdge : mdg->inEdges[dstId]) {
    auto *srcNode = mdg->getNode(srcEdge.id);
    // Skip if 'srcNode' is not a loop nest.
    if (!isa<AffineForOp>(srcNode->op))
      continue;

    if (any_of(srcNode->stores, [&](Operation *op) {
          auto storeOp = cast<AffineWriteOpInterface>(op);
          return consumedMemrefs.count(storeOp.getMemRef()) > 0;
        }))
      srcIdCandidates.push_back(srcNode->id);
  }

  llvm::sort(srcIdCandidates);
  srcIdCandidates.erase(
      std::unique(srcIdCandidates.begin(), srcIdCandidates.end()),
      srcIdCandidates.end());
}

/// Returns in 'producerConsumerMemrefs' the memrefs involved in a
/// producer-consumer dependence between 'srcId' and 'dstId'.
static void
gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId,
                              MemRefDependenceGraph *mdg,
                              DenseSet<Value> &producerConsumerMemrefs) {
  auto *dstNode = mdg->getNode(dstId);
  auto *srcNode = mdg->getNode(srcId);
  gatherProducerConsumerMemrefs(srcNode->stores, dstNode->loads,
                                producerConsumerMemrefs);
}

/// A memref escapes in the context of the fusion pass if any of the following
/// holds:
///   1. it (or its alias) is a block argument, or
///   2. it is created by an op not known to guarantee alias freedom, or
///   3. it (or its alias) is used by ops other than affine dereferencing ops
///      (e.g., by call ops, memref load/store ops, alias creating ops, unknown
///      ops, terminator ops, etc.); such ops do not dereference the memref in
///      an affine way.
static bool isEscapingMemref(Value memref, Block *block) {
  Operation *defOp = memref.getDefiningOp();
  // Check if 'memref' is a block argument.
  if (!defOp)
    return true;

  // Check if this is defined to be an alias of another memref.
  if (auto viewOp = dyn_cast<mlir::ViewLikeOpInterface>(defOp))
    if (isEscapingMemref(viewOp.getViewSource(), block))
      return true;

  // Any op besides allocating ops wouldn't guarantee alias freedom.
  if (!hasSingleEffect<mlir::MemoryEffects::Allocate>(defOp, memref))
    return true;

  // Check if 'memref' is used by a non-dereferencing op (including unknown
  // ones, e.g., call ops, alias creating ops, etc.).
  for (Operation *user : memref.getUsers()) {
    // Ignore users outside of `block`.
    if (block->getParent()->findAncestorOpInRegion(*user)->getBlock() != block)
      continue;
    if (!isa<AffineMapAccessInterface>(*user))
      return true;
  }
  return false;
}

/// Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id'
/// that escape the block or are accessed in a non-affine way.
void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg,
                           DenseSet<Value> &escapingMemRefs) {
  auto *node = mdg->getNode(id);
  for (Operation *storeOp : node->stores) {
    auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
    if (escapingMemRefs.count(memref))
      continue;
    if (isEscapingMemref(memref, &mdg->block))
      escapingMemRefs.insert(memref);
  }
}

} // namespace

// Initializes the data dependence graph by walking operations in `block`.
// Assigns each node in the graph a node id based on program order in 'block'.
// TODO: Add support for taking a Block arg to construct the
// dependence graph at a different depth.
bool MemRefDependenceGraph::init(Block *block) {
  LLVM_DEBUG(llvm::dbgs() << "--- Initializing MDG ---\n");
  // Map from a memref to the set of ids of the nodes that have ops accessing
  // the memref.
  DenseMap<Value, SetVector<unsigned>> memrefAccesses;

  DenseMap<Operation *, unsigned> forToNodeMap;
  for (Operation &op : *block) {
    if (auto forOp = dyn_cast<AffineForOp>(op)) {
      // Create graph node 'id' to represent top-level 'forOp' and record
      // all loads and store accesses it contains.
      LoopNestStateCollector collector;
      collector.collect(&op);
      // Return false if a region holding op other than 'affine.for' and
      // 'affine.if' was found (not currently supported).
      if (collector.hasNonAffineRegionOp)
        return false;
      Node node(nextNodeId++, &op);
      for (auto *opInst : collector.loadOpInsts) {
        node.loads.push_back(opInst);
        auto memref = cast<AffineReadOpInterface>(opInst).getMemRef();
        memrefAccesses[memref].insert(node.id);
      }
      for (auto *opInst : collector.storeOpInsts) {
        node.stores.push_back(opInst);
        auto memref = cast<AffineWriteOpInterface>(opInst).getMemRef();
        memrefAccesses[memref].insert(node.id);
      }
      forToNodeMap[&op] = node.id;
      nodes.insert({node.id, node});
    } else if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) {
      // Create graph node for top-level load op.
      Node node(nextNodeId++, &op);
      node.loads.push_back(&op);
      auto memref = cast<AffineReadOpInterface>(op).getMemRef();
      memrefAccesses[memref].insert(node.id);
      nodes.insert({node.id, node});
    } else if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) {
      // Create graph node for top-level store op.
      Node node(nextNodeId++, &op);
      node.stores.push_back(&op);
      auto memref = cast<AffineWriteOpInterface>(op).getMemRef();
      memrefAccesses[memref].insert(node.id);
      nodes.insert({node.id, node});
    } else if (op.getNumRegions() != 0) {
      // Return false if another region is found (not currently supported).
      return false;
    } else if (op.getNumResults() > 0 && !op.use_empty()) {
      // Create graph node for top-level producer of SSA values, which
      // could be used by loop nest nodes.
      Node node(nextNodeId++, &op);
      nodes.insert({node.id, node});
    } else if (isa<CallOpInterface>(op)) {
      // Create graph node for top-level Call Op that takes any argument of
      // memref type. Call Op that returns one or more memref type results
      // is already taken care of by the previous conditions.
      if (llvm::any_of(op.getOperandTypes(),
                       [&](Type t) { return t.isa<MemRefType>(); })) {
        Node node(nextNodeId++, &op);
        nodes.insert({node.id, node});
      }
    } else if (hasEffect<MemoryEffects::Write, MemoryEffects::Free>(&op)) {
      // Create graph node for top-level op, which could have a memory write
      // side effect.
      Node node(nextNodeId++, &op);
      nodes.insert({node.id, node});
    }
  }

  for (auto &idAndNode : nodes) {
    LLVM_DEBUG(llvm::dbgs() << "Create node " << idAndNode.first << " for:\n"
                            << *(idAndNode.second.op) << "\n");
    (void)idAndNode;
  }

  // Add dependence edges between nodes which produce SSA values and their
  // users. Load ops can be considered as the ones producing SSA values.
  for (auto &idAndNode : nodes) {
    const Node &node = idAndNode.second;
    // Stores don't define SSA values, skip them.
    if (!node.stores.empty())
      continue;
    Operation *opInst = node.op;
    for (Value value : opInst->getResults()) {
      for (Operation *user : value.getUsers()) {
        // Ignore users outside of the block.
        if (block->getParent()->findAncestorOpInRegion(*user)->getBlock() !=
            block)
          continue;
        SmallVector<AffineForOp, 4> loops;
        getLoopIVs(*user, &loops);
        if (loops.empty())
          continue;
        assert(forToNodeMap.count(loops[0]) > 0 && "missing mapping");
        unsigned userLoopNestId = forToNodeMap[loops[0]];
        addEdge(node.id, userLoopNestId, value);
      }
    }
  }

  // Walk memref access lists and add graph edges between dependent nodes.
  for (auto &memrefAndList : memrefAccesses) {
    unsigned n = memrefAndList.second.size();
    for (unsigned i = 0; i < n; ++i) {
      unsigned srcId = memrefAndList.second[i];
      bool srcHasStore =
          getNode(srcId)->getStoreOpCount(memrefAndList.first) > 0;
      for (unsigned j = i + 1; j < n; ++j) {
        unsigned dstId = memrefAndList.second[j];
        bool dstHasStore =
            getNode(dstId)->getStoreOpCount(memrefAndList.first) > 0;
        if (srcHasStore || dstHasStore)
          addEdge(srcId, dstId, memrefAndList.first);
      }
    }
  }
  return true;
}

// Sinks all sequential loops to the innermost levels (while preserving
// relative order among them) and moves all parallel loops to the
// outermost (while again preserving relative order among them).
// This can increase the loop depth at which we can fuse a slice, since we are
// pushing loop carried dependence to a greater depth in the loop nest.
static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) {
  assert(isa<AffineForOp>(node->op));
  AffineForOp newRootForOp = sinkSequentialLoops(cast<AffineForOp>(node->op));
  node->op = newRootForOp;
}

// TODO: improve/complete this when we have target data.
static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
  auto elementType = memRefType.getElementType();

  unsigned sizeInBits;
  if (elementType.isIntOrFloat()) {
    sizeInBits = elementType.getIntOrFloatBitWidth();
  } else {
    auto vectorType = elementType.cast<VectorType>();
    sizeInBits =
        vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
  }
  return llvm::divideCeil(sizeInBits, 8);
}

// Creates and returns a private (single-user) memref for fused loop rooted
// at 'forOp', with (potentially reduced) memref size based on the
// MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
// TODO: consider refactoring the common code from generateDma and
// this one.
static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
                                 unsigned dstLoopDepth,
                                 Optional<unsigned> fastMemorySpace,
                                 uint64_t localBufSizeThreshold) {
  Operation *forInst = forOp.getOperation();

  // Create builder to insert alloc op just before 'forOp'.
  OpBuilder b(forInst);
  // Builder to create constants at the top level.
  OpBuilder top(forInst->getParentRegion());
  // Create new memref type based on slice bounds.
  auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOpInst).getMemRef();
  auto oldMemRefType = oldMemRef.getType().cast<MemRefType>();
  unsigned rank = oldMemRefType.getRank();
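
  // Illustrative note (added for exposition, with made-up shapes): if
  // 'srcStoreOpInst' writes to a memref<100x100xf32> but, once sliced at
  // 'dstLoopDepth', each destination iteration only writes a 1x100 window of
  // it, the region computed below would yield newShape = {1, 100} and the
  // private buffer would be created as memref<1x100xf32>, shrinking the
  // intermediary storage.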

  // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'.
  MemRefRegion region(srcStoreOpInst->getLoc());
  bool validRegion = succeeded(region.compute(srcStoreOpInst, dstLoopDepth));
  (void)validRegion;
  assert(validRegion && "unexpected memref region failure");
  SmallVector<int64_t, 4> newShape;
  std::vector<SmallVector<int64_t, 4>> lbs;
  SmallVector<int64_t, 8> lbDivisors;
  lbs.reserve(rank);
  // Query 'region' for 'newShape' and lower bounds of MemRefRegion accessed
  // by 'srcStoreOpInst' at depth 'dstLoopDepth'.
  Optional<int64_t> numElements =
      region.getConstantBoundingSizeAndShape(&newShape, &lbs, &lbDivisors);
  assert(numElements && "non-constant number of elts in local buffer");

  const FlatAffineValueConstraints *cst = region.getConstraints();
  // 'outerIVs' holds the values that this memory region is symbolic/parametric
  // on; this would correspond to loop IVs surrounding the level at which the
  // slice is being materialized.
  SmallVector<Value, 8> outerIVs;
  cst->getValues(rank, cst->getNumVars(), &outerIVs);

  // Build 'rank' AffineExprs from MemRefRegion 'lbs'.
  SmallVector<AffineExpr, 4> offsets;
  offsets.reserve(rank);
  for (unsigned d = 0; d < rank; ++d) {
    assert(lbs[d].size() == cst->getNumCols() - rank && "incorrect bound size");

    AffineExpr offset = top.getAffineConstantExpr(0);
    for (unsigned j = 0, e = cst->getNumCols() - rank - 1; j < e; j++) {
      offset = offset + lbs[d][j] * top.getAffineDimExpr(j);
    }
    assert(lbDivisors[d] > 0);
    offset =
        (offset + lbs[d][cst->getNumCols() - 1 - rank]).floorDiv(lbDivisors[d]);
    offsets.push_back(offset);
  }

  // Create 'newMemRefType' using 'newShape' from MemRefRegion accessed
  // by 'srcStoreOpInst'.
  uint64_t bufSize = getMemRefEltSizeInBytes(oldMemRefType) * *numElements;
  unsigned newMemSpace;
  if (bufSize <= localBufSizeThreshold && fastMemorySpace.has_value()) {
    newMemSpace = *fastMemorySpace;
  } else {
    newMemSpace = oldMemRefType.getMemorySpaceAsInt();
  }
  auto newMemRefType = MemRefType::get(newShape, oldMemRefType.getElementType(),
                                       {}, newMemSpace);

  // Create new private memref for fused loop 'forOp'. 'newShape' is always
  // a constant shape.
  // TODO: Create/move alloc ops for private memrefs closer to their
  // consumer loop nests to reduce their live range. Currently they are added
  // at the beginning of the block, because loop nests can be reordered
  // during the fusion pass.
  Value newMemRef = top.create<memref::AllocOp>(forOp.getLoc(), newMemRefType);

  // Build an AffineMap to remap access functions based on lower bound offsets.
  SmallVector<AffineExpr, 4> remapExprs;
  remapExprs.reserve(rank);
  for (unsigned i = 0; i < rank; i++) {
    auto dimExpr = b.getAffineDimExpr(outerIVs.size() + i);

    auto remapExpr =
        simplifyAffineExpr(dimExpr - offsets[i], outerIVs.size() + rank, 0);
    remapExprs.push_back(remapExpr);
  }

  auto indexRemap =
      AffineMap::get(outerIVs.size() + rank, 0, remapExprs, forOp.getContext());
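
  // Illustrative example (added for exposition): with rank == 1 and a single
  // outer IV, offsets[0] might simplify to d0, in which case 'indexRemap' is
  // (d0, d1) -> (d1 - d0): the original access index (d1) is rebased by the
  // region's lower bound (d0) so that it addresses the smaller private buffer
  // starting at offset zero.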

  // Replace all users of 'oldMemRef' with 'newMemRef'.
  LogicalResult res =
      replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,
                               /*extraOperands=*/outerIVs,
                               /*symbolOperands=*/{},
                               /*domOpFilter=*/&*forOp.getBody()->begin());
  assert(succeeded(res) &&
         "replaceAllMemrefUsesWith should always succeed here");
  (void)res;
  return newMemRef;
}

/// Walking from node 'srcId' to node 'dstId' (exclusive of 'srcId' and
/// 'dstId'), if there is any non-affine operation accessing 'memref', return
/// true. Otherwise, return false.
static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
                                       Value memref,
                                       MemRefDependenceGraph *mdg) {
  auto *srcNode = mdg->getNode(srcId);
  auto *dstNode = mdg->getNode(dstId);
  Value::user_range users = memref.getUsers();
  // For each MemRefDependenceGraph's node that is between 'srcNode' and
  // 'dstNode' (exclusive of 'srcNode' and 'dstNode'), check whether any
  // non-affine operation in the node accesses the 'memref'.
  for (auto &idAndNode : mdg->nodes) {
    Operation *op = idAndNode.second.op;
    // Take care of operations between 'srcNode' and 'dstNode'.
    if (srcNode->op->isBeforeInBlock(op) && op->isBeforeInBlock(dstNode->op)) {
      // Walk inside the operation to find any use of the memref.
      // Interrupt the walk if found.
      auto walkResult = op->walk([&](Operation *user) {
        // Skip affine ops.
        if (isa<AffineMapAccessInterface>(*user))
          return WalkResult::advance();
        // Find a non-affine op that uses the memref.
        if (llvm::is_contained(users, user))
          return WalkResult::interrupt();
        return WalkResult::advance();
      });
      if (walkResult.wasInterrupted())
        return true;
    }
  }
  return false;
}

/// Check whether a memref value in node 'srcId' has a non-affine user that is
/// between node 'srcId' and node 'dstId' (exclusive of 'srcNode' and
/// 'dstNode').
static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
                                       MemRefDependenceGraph *mdg) {
  // Collect memref values in node 'srcId'.
  auto *srcNode = mdg->getNode(srcId);
  llvm::SmallDenseSet<Value, 2> memRefValues;
  srcNode->op->walk([&](Operation *op) {
    // Skip affine ops.
    if (isa<AffineForOp>(op))
      return WalkResult::advance();
    for (Value v : op->getOperands())
      // Collect memref values only.
      if (v.getType().isa<MemRefType>())
        memRefValues.insert(v);
    return WalkResult::advance();
  });
  // Looking for users between node 'srcId' and node 'dstId'.
  for (Value memref : memRefValues)
    if (hasNonAffineUsersOnThePath(srcId, dstId, memref, mdg))
      return true;
  return false;
}

// Checks the profitability of fusing a backwards slice of the loop nest
// surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'.
// The argument 'srcStoreOpInst' is used to calculate the storage reduction on
// the memref being produced and consumed, which is an input to the cost model.
// For producer-consumer fusion, 'srcStoreOpInst' will be the same as
// 'srcOpInst', as we are slicing w.r.t. that producer. For input-reuse
// fusion, 'srcOpInst' will be the src loop nest LoadOp which reads from the
// same memref as dst loop nest load ops, and 'srcStoreOpInst' will be the
// unique store op in the src node, which will be used to check that the write
// region is the same after input-reuse fusion.
// Computation slices are provided in 'depthSliceUnions' for each legal fusion
// depth. The maximal depth at which fusion is legal is provided in
// 'maxLegalFusionDepth'. Returns true if it is profitable to fuse the
// candidate loop nests. Returns false otherwise. `dstLoopDepth` is set to the
// most profitable depth at which to materialize the source loop nest slice.
// The profitability model executes the following steps:
// *) Computes the backward computation slice at 'srcOpInst'. This
//    computation slice of the loop nest surrounding 'srcOpInst' is
//    represented by modified src loop bounds in 'sliceState', which are
//    functions of loop IVs in the loop nest surrounding 'srcOpInst'.
// *) Computes the cost of unfused src/dst loop nests (currently the cost of a
//    loop nest is the total number of dynamic operation instances in the loop
//    nest).
// *) Computes the cost of fusing a slice of the src loop nest into the dst
//    loop nest at various values of dst loop depth, attempting to fuse
//    the largest computation slice at the maximal dst loop depth (closest to
//    the load) to minimize reuse distance and potentially enable subsequent
//    load/store forwarding.
//    NOTE: 'dstLoopDepth' refers to the loop depth within the destination loop
//    nest, at which the src computation slice is inserted/fused.
//    NOTE: We attempt to maximize the dst loop depth, but there are cases
//    where a particular setting for 'dstLoopDepth' might fuse an unsliced
//    loop (within the src computation slice) at a depth which results in
//    excessive recomputation (see unit tests for examples).
// *) Compares the total cost of the unfused loop nests to the min cost fused
//    loop nest computed in the previous step, and returns true if the latter
//    is lower.
// TODO: Extend profitability analysis to support scenarios with multiple
// stores.
static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
                               AffineForOp dstForOp,
                               ArrayRef<ComputationSliceState> depthSliceUnions,
                               unsigned maxLegalFusionDepth,
                               unsigned *dstLoopDepth,
                               double computeToleranceThreshold) {
  LLVM_DEBUG({
    llvm::dbgs() << "Checking whether fusion is profitable between src op:\n";
    llvm::dbgs() << ' ' << *srcOpInst << " and destination loop:\n";
    llvm::dbgs() << dstForOp << "\n";
  });

  if (maxLegalFusionDepth == 0) {
    LLVM_DEBUG(llvm::dbgs() << "Can't fuse: maxLegalFusionDepth == 0.\n");
    return false;
  }

  // Compute cost of sliced and unsliced src loop nest.
  SmallVector<AffineForOp, 4> srcLoopIVs;
  getLoopIVs(*srcOpInst, &srcLoopIVs);

  // Walk src loop nest and collect stats.
  LoopNestStats srcLoopNestStats;
  if (!getLoopNestStats(srcLoopIVs[0], &srcLoopNestStats))
    return false;

  // Compute cost of dst loop nest.
  LoopNestStats dstLoopNestStats;
  if (!getLoopNestStats(dstForOp, &dstLoopNestStats))
    return false;
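
  // Illustrative note (added for exposition, numbers made up): the search
  // below evaluates, for each candidate depth, the additional compute fraction
  //   fusedLoopNestComputeCost / (srcLoopNestCost + dstLoopNestCost) - 1.
  // For example, if the unfused src and dst nests cost 100 operation instances
  // each and a fused candidate costs 250, the fraction is 250 / 200 - 1 =
  // 0.25, i.e. 25% redundant computation, which is then compared against
  // 'computeToleranceThreshold'.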

  // Search for min cost value for 'dstLoopDepth'. At each value of
  // 'dstLoopDepth' from 'maxLegalLoopDepth' to '1', compute computation slice
  // bounds between 'srcOpInst' and each op in 'dstOpInsts' (taking the union
  // of these bounds). Next the union slice bounds are used to calculate
  // the cost of the slice and the cost of the slice inserted into the dst
  // loop nest at 'dstLoopDepth'.
  uint64_t minFusedLoopNestComputeCost = std::numeric_limits<uint64_t>::max();
  double maxStorageReduction = 0.0;
  std::optional<uint64_t> sliceMemEstimate;

  // The best loop depth at which to materialize the slice.
  std::optional<unsigned> bestDstLoopDepth;

  // Compute op instance count for the src loop nest without iteration slicing.
  uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], srcLoopNestStats);

  // Compute src loop nest write region size.
  MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc());
  if (failed(srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0))) {
    LLVM_DEBUG(llvm::dbgs()
               << "Unable to compute MemRefRegion for source operation.\n");
    return false;
  }

  Optional<int64_t> maybeSrcWriteRegionSizeBytes =
      srcWriteRegion.getRegionSize();
  if (!maybeSrcWriteRegionSizeBytes.has_value())
    return false;
  int64_t srcWriteRegionSizeBytes = *maybeSrcWriteRegionSizeBytes;

  // Compute op instance count for the dst loop nest.
  uint64_t dstLoopNestCost = getComputeCost(dstForOp, dstLoopNestStats);

  // Evaluate all depth choices for materializing the slice in the destination
  // loop nest.
  for (unsigned i = maxLegalFusionDepth; i >= 1; --i) {
    const ComputationSliceState &slice = depthSliceUnions[i - 1];
    // Skip slice union if it wasn't computed for this depth.
    if (slice.isEmpty())
      continue;

    int64_t fusedLoopNestComputeCost;
    if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstForOp,
                              dstLoopNestStats, slice,
                              &fusedLoopNestComputeCost)) {
      LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost.\n");
      continue;
    }

    double additionalComputeFraction =
        fusedLoopNestComputeCost /
            (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
        1;

    // Determine what the slice write MemRefRegion would be, if the src loop
    // nest slice 'slice' were to be inserted into the dst loop nest at loop
    // depth 'i'.
    MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc());
    if (failed(sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0,
                                        &slice))) {
      LLVM_DEBUG(llvm::dbgs()
                 << "Failed to compute slice write region at loopDepth: " << i
                 << "\n");
      continue;
    }

    Optional<int64_t> maybeSliceWriteRegionSizeBytes =
        sliceWriteRegion.getRegionSize();
    if (!maybeSliceWriteRegionSizeBytes.has_value() ||
        *maybeSliceWriteRegionSizeBytes == 0) {
      LLVM_DEBUG(llvm::dbgs()
                 << "Failed to get slice write region size at loopDepth: " << i
                 << "\n");
      continue;
    }
    int64_t sliceWriteRegionSizeBytes = *maybeSliceWriteRegionSizeBytes;

    // If we are fusing for reuse, check that write regions remain the same.
    // TODO: Write region check should check sizes and offsets in
    // each dimension, so that we are sure they are covering the same memref
    // region. Also, move this out to a isMemRefRegionSuperSet helper function.
    if (srcOpInst != srcStoreOpInst &&
        sliceWriteRegionSizeBytes != srcWriteRegionSizeBytes)
      continue;

    double storageReduction = static_cast<double>(srcWriteRegionSizeBytes) /
                              static_cast<double>(sliceWriteRegionSizeBytes);

    LLVM_DEBUG({
      std::stringstream msg;
      msg << "  evaluating fusion profitability at depth : " << i << "\n"
          << std::fixed << std::setprecision(2)
          << "   additional compute fraction: "
          << 100.0 * additionalComputeFraction << "%\n"
          << "   storage reduction factor: " << storageReduction << "x\n"
          << "   fused nest cost: " << fusedLoopNestComputeCost << "\n"
          << "   src write region size: " << srcWriteRegionSizeBytes << "\n"
          << "   slice write region size: " << sliceWriteRegionSizeBytes
          << "\n";
      llvm::dbgs() << msg.str();
    });

    // TODO: This is a placeholder cost model.
    // Among all choices that add an acceptable amount of redundant computation
    // (as per computeToleranceThreshold), we will simply pick the one that
    // reduces the intermediary size the most.
    if ((storageReduction > maxStorageReduction) &&
        (additionalComputeFraction < computeToleranceThreshold)) {
      maxStorageReduction = storageReduction;
      bestDstLoopDepth = i;
      minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
      sliceMemEstimate = sliceWriteRegionSizeBytes;
    }
  }

  // A simple cost model: fuse if it reduces the memory footprint.

  if (!bestDstLoopDepth) {
    LLVM_DEBUG(
        llvm::dbgs()
        << "All fusion choices involve more than the threshold amount of "
           "redundant computation; NOT fusing.\n");
    return false;
  }

  if (!bestDstLoopDepth) {
    LLVM_DEBUG(llvm::dbgs() << "no fusion depth could be evaluated.\n");
    return false;
  }
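
  // Illustrative note (added for exposition, numbers made up): the
  // memory-footprint check further below accepts fusion only if the estimated
  // fused footprint does not exceed the sum of the unfused footprints. E.g.,
  // with a src footprint of 4096 bytes, a dst footprint of 4096 bytes and a
  // slice estimate of 1024 bytes, the fused estimate is 4096 + 1024 = 5120
  // <= 8192, so fusion is considered profitable and the reported storage
  // reduction is 100 * (1 - 5120 / 8192) = 37.5%.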

  // Set dstLoopDepth based on best values from search.
  *dstLoopDepth = *bestDstLoopDepth;

  LLVM_DEBUG(
      llvm::dbgs() << " LoopFusion fusion stats:"
                   << "\n  best loop depth: " << bestDstLoopDepth
                   << "\n  src loop nest compute cost: " << srcLoopNestCost
                   << "\n  dst loop nest compute cost: " << dstLoopNestCost
                   << "\n  fused loop nest compute cost: "
                   << minFusedLoopNestComputeCost << "\n");

  auto dstMemSize = getMemoryFootprintBytes(dstForOp);
  auto srcMemSize = getMemoryFootprintBytes(srcLoopIVs[0]);

  std::optional<double> storageReduction;

  if (!dstMemSize || !srcMemSize) {
    LLVM_DEBUG(llvm::dbgs()
               << "  fusion memory benefit cannot be evaluated; NOT fusing.\n");
    return false;
  }

  auto srcMemSizeVal = *srcMemSize;
  auto dstMemSizeVal = *dstMemSize;

  assert(sliceMemEstimate && "expected value");
  auto fusedMem = dstMemSizeVal + *sliceMemEstimate;

  LLVM_DEBUG(llvm::dbgs() << "   src mem: " << srcMemSizeVal << "\n"
                          << "   dst mem: " << dstMemSizeVal << "\n"
                          << "   fused mem: " << fusedMem << "\n"
                          << "   slice mem: " << sliceMemEstimate << "\n");

  if (static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) {
    LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n");
    return false;
  }
  storageReduction =
      100.0 *
      (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal));

  double additionalComputeFraction =
      100.0 * (minFusedLoopNestComputeCost /
                   (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
               1);
  (void)additionalComputeFraction;
  LLVM_DEBUG({
    std::stringstream msg;
    msg << " fusion is most profitable at depth " << *dstLoopDepth << " with "
        << std::setprecision(2) << additionalComputeFraction
        << "% redundant computation and a ";
    msg << (storageReduction ? std::to_string(*storageReduction) : "<unknown>");
    msg << "% storage reduction.\n";
    llvm::dbgs() << msg.str();
  });

  return true;
}

namespace {

// GreedyFusion greedily fuses loop nests which have a producer/consumer or
// input-reuse relationship on a memref, with the goal of improving locality.
//
// The steps of the producer-consumer fusion algorithm are as follows:
//
// *) A worklist is initialized with node ids from the dependence graph.
// *) For each node id in the worklist:
//   *) Pop an AffineForOp off the worklist. This 'dstAffineForOp' will be a
//      candidate destination AffineForOp into which fusion will be attempted.
//   *) Add each LoadOp currently in 'dstAffineForOp' into list 'dstLoadOps'.
//   *) For each LoadOp in 'dstLoadOps' do:
//      *) Look up dependent loop nests which have a single store op to the
//         same memref.
//      *) Check if dependences would be violated by the fusion.
//      *) Get a computation slice of 'srcLoopNest', which adjusts its loop
//         bounds to be functions of 'dstLoopNest' IVs and symbols.
//      *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest',
//         at a loop depth determined by the cost model in 'isFusionProfitable'.
//      *) Add the newly fused load/store operations to the state,
//         and also add newly fused load ops to 'dstLoopOps' to be considered
//         as fusion dst load ops in another iteration.
//      *) Remove old src loop nest and its associated state.
//
// The steps of the input-reuse fusion algorithm are as follows:
//
// *) Initialize 'worklist' with node ids from the dependence graph.
// *) For each 'dstNode' in the worklist:
//   *) Find a candidate sibling node 'sibNode' to fuse with 'dstNode' which
//      loads from the same memref, but which has no dependence paths to/from
//      'dstNode'.
//   *) Get a computation slice of 'sibLoopNest', which adjusts its loop
//      bounds to be functions of 'dstLoopNest' IVs and symbols.
//   *) Fuse the 'sibLoopNest' computation slice into the 'dstLoopNest',
//      at a loop depth determined by the cost model in 'isFusionProfitable'.
//      This function also checks that the memref write region of 'sibLoopNest'
//      is preserved in the fused loop nest.
//   *) Update graph state to reflect the fusion of 'sibNode' into 'dstNode'.
//
// Given a graph where top-level operations are vertices in the set 'V' and
// edges in the set 'E' are dependences between vertices, this algorithm
// takes O(V) time for initialization, and has runtime O(V + E).
//
// This greedy algorithm is not 'maximal' due to the current restriction of
// fusing along single producer consumer edges, but there is a TODO to fix
// this.
//
// TODO: Experiment with other fusion policies.
struct GreedyFusion {
public:
  // The data dependence graph to traverse during fusion.
  MemRefDependenceGraph *mdg;
  // Worklist of graph nodes visited during the fusion pass.
  SmallVector<unsigned, 8> worklist;
  // Parameter for local buffer size threshold.
  unsigned localBufSizeThreshold;
  // Parameter for fast memory space.
  Optional<unsigned> fastMemorySpace;
  // If true, ignore any additional (redundant) computation tolerance threshold
  // that would have prevented fusion.
  bool maximalFusion;
  // The amount of additional computation that is tolerated while fusing
  // pair-wise as a fraction of the total computation.
  double computeToleranceThreshold;

  using Node = MemRefDependenceGraph::Node;

  GreedyFusion(MemRefDependenceGraph *mdg, unsigned localBufSizeThreshold,
               Optional<unsigned> fastMemorySpace, bool maximalFusion,
               double computeToleranceThreshold)
      : mdg(mdg), localBufSizeThreshold(localBufSizeThreshold),
        fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion),
        computeToleranceThreshold(computeToleranceThreshold) {}

  /// Initializes 'worklist' with nodes from 'mdg'.
  void init() {
    // TODO: Add a priority queue for prioritizing nodes by different
    // metrics (e.g. arithmetic intensity/flops-to-bytes ratio).
    worklist.clear();
    for (auto &idAndNode : mdg->nodes) {
      const Node &node = idAndNode.second;
      worklist.push_back(node.id);
    }
  }
  /// Run only sibling fusion on the `mdg`.
  void runSiblingFusionOnly() {
    fuseSiblingNodes();
    eraseUnusedMemRefAllocations();
  }

  /// Run only producer/consumer fusion on the `mdg`.
  void runProducerConsumerFusionOnly() {
    fuseProducerConsumerNodes(
        /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max());
    eraseUnusedMemRefAllocations();
  }

  // Run the GreedyFusion pass.
  // *) First pass through the nodes fuses single-use producer nodes into their
  //    unique consumer.
  // *) Second pass fuses sibling nodes which share no dependence edges.
1439 // *) Third pass fuses any remaining producer nodes into their users. 1440 void runGreedyFusion() { 1441 // TODO: Run this repeatedly until a fixed-point is reached. 1442 fuseProducerConsumerNodes(/*maxSrcUserCount=*/1); 1443 fuseSiblingNodes(); 1444 fuseProducerConsumerNodes( 1445 /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max()); 1446 eraseUnusedMemRefAllocations(); 1447 } 1448 1449 /// Visit each node in the graph, and for each node, attempt to fuse it with 1450 /// producer-consumer candidates. No fusion is performed when producers with a 1451 /// user count greater than `maxSrcUserCount` for any of the memrefs involved 1452 /// are encountered. 1453 void fuseProducerConsumerNodes(unsigned maxSrcUserCount) { 1454 LLVM_DEBUG(llvm::dbgs() << "--- Producer/Consumer Fusion ---\n"); 1455 init(); 1456 while (!worklist.empty()) { 1457 unsigned dstId = worklist.back(); 1458 worklist.pop_back(); 1459 1460 // Skip if this node was removed (fused into another node). 1461 if (mdg->nodes.count(dstId) == 0) 1462 continue; 1463 // Get 'dstNode' into which to attempt fusion. 1464 auto *dstNode = mdg->getNode(dstId); 1465 // Skip if 'dstNode' is not a loop nest. 1466 if (!isa<AffineForOp>(dstNode->op)) 1467 continue; 1468 // Skip if 'dstNode' is a loop nest returning values. 1469 // TODO: support loop nests that return values. 1470 if (dstNode->op->getNumResults() > 0) 1471 continue; 1472 1473 LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n"); 1474 1475 // Sink sequential loops in 'dstNode' (and thus raise parallel loops) 1476 // while preserving relative order. This can increase the maximum loop 1477 // depth at which we can fuse a slice of a producer loop nest into a 1478 // consumer loop nest. 1479 sinkSequentialLoops(dstNode); 1480 auto dstAffineForOp = cast<AffineForOp>(dstNode->op); 1481 1482 // Try to fuse 'dstNode' with candidate producer loops until a fixed point 1483 // is reached. Fusing two loops may expose new fusion opportunities. 1484 bool dstNodeChanged; 1485 do { 1486 // Gather src loop candidates for 'dstNode' and visit them in "quasi" 1487 // reverse program order to minimize the number of iterations needed to 1488 // reach the fixed point. Note that this is a best effort approach since 1489 // 'getProducerCandidates' does not always guarantee that program order 1490 // in 'srcIdCandidates'. 1491 dstNodeChanged = false; 1492 SmallVector<unsigned, 16> srcIdCandidates; 1493 getProducerCandidates(dstId, mdg, srcIdCandidates); 1494 1495 for (unsigned srcId : llvm::reverse(srcIdCandidates)) { 1496 // Get 'srcNode' from which to attempt fusion into 'dstNode'. 1497 auto *srcNode = mdg->getNode(srcId); 1498 auto srcAffineForOp = cast<AffineForOp>(srcNode->op); 1499 LLVM_DEBUG(llvm::dbgs() << "Evaluating src loop " << srcId 1500 << " for dst loop " << dstId << "\n"); 1501 1502 // Skip if 'srcNode' is a loop nest returning values. 1503 // TODO: support loop nests that return values. 1504 if (isa<AffineForOp>(srcNode->op) && srcNode->op->getNumResults() > 0) 1505 continue; 1506 1507 DenseSet<Value> producerConsumerMemrefs; 1508 gatherProducerConsumerMemrefs(srcId, dstId, mdg, 1509 producerConsumerMemrefs); 1510 1511 // Skip if 'srcNode' out edge count on any memref is greater than 1512 // 'maxSrcUserCount'. 
1513 if (any_of(producerConsumerMemrefs, [&](Value memref) { 1514 return mdg->getOutEdgeCount(srcNode->id, memref) > 1515 maxSrcUserCount; 1516 })) 1517 continue; 1518 1519 // Gather memrefs in 'srcNode' that are written and escape out of the 1520 // block (e.g., memref block arguments, returned memrefs, 1521 // memrefs passed to function calls, etc.). 1522 DenseSet<Value> srcEscapingMemRefs; 1523 gatherEscapingMemrefs(srcNode->id, mdg, srcEscapingMemRefs); 1524 1525 // Skip if there are non-affine operations in between the 'srcNode' 1526 // and 'dstNode' using their memrefs. If so, we wouldn't be able to 1527 // compute a legal insertion point for now. 'srcNode' and 'dstNode' 1528 // memrefs with non-affine operation users would be considered 1529 // escaping memrefs so we can limit this check to only scenarios with 1530 // escaping memrefs. 1531 if (!srcEscapingMemRefs.empty() && 1532 hasNonAffineUsersOnThePath(srcId, dstId, mdg)) { 1533 LLVM_DEBUG( 1534 llvm::dbgs() 1535 << "Can't fuse: non-affine users in between the loops\n."); 1536 continue; 1537 } 1538 1539 // Compute an operation list insertion point for the fused loop 1540 // nest which preserves dependences. 1541 Operation *fusedLoopInsPoint = 1542 mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id); 1543 if (fusedLoopInsPoint == nullptr) 1544 continue; 1545 1546 // Compute the innermost common loop depth for dstNode 1547 // producer-consumer loads/stores. 1548 SmallVector<Operation *, 2> dstMemrefOps; 1549 for (Operation *op : dstNode->loads) 1550 if (producerConsumerMemrefs.count( 1551 cast<AffineReadOpInterface>(op).getMemRef()) > 0) 1552 dstMemrefOps.push_back(op); 1553 for (Operation *op : dstNode->stores) 1554 if (producerConsumerMemrefs.count( 1555 cast<AffineWriteOpInterface>(op).getMemRef())) 1556 dstMemrefOps.push_back(op); 1557 unsigned dstLoopDepthTest = getInnermostCommonLoopDepth(dstMemrefOps); 1558 1559 // Check the feasibility of fusing src loop nest into dst loop nest 1560 // at loop depths in range [1, dstLoopDepthTest]. 1561 unsigned maxLegalFusionDepth = 0; 1562 SmallVector<ComputationSliceState, 8> depthSliceUnions; 1563 depthSliceUnions.resize(dstLoopDepthTest); 1564 FusionStrategy strategy(FusionStrategy::ProducerConsumer); 1565 for (unsigned i = 1; i <= dstLoopDepthTest; ++i) { 1566 FusionResult result = mlir::canFuseLoops( 1567 srcAffineForOp, dstAffineForOp, 1568 /*dstLoopDepth=*/i, &depthSliceUnions[i - 1], strategy); 1569 1570 if (result.value == FusionResult::Success) 1571 maxLegalFusionDepth = i; 1572 } 1573 1574 if (maxLegalFusionDepth == 0) { 1575 LLVM_DEBUG(llvm::dbgs() 1576 << "Can't fuse: fusion is not legal at any depth\n"); 1577 continue; 1578 } 1579 1580 // Check if fusion would be profitable. We skip profitability analysis 1581 // for maximal fusion since we already know the maximal legal depth to 1582 // fuse. 1583 unsigned bestDstLoopDepth = maxLegalFusionDepth; 1584 if (!maximalFusion) { 1585 // Retrieve producer stores from the src loop. 1586 SmallVector<Operation *, 2> producerStores; 1587 for (Operation *op : srcNode->stores) 1588 if (producerConsumerMemrefs.count( 1589 cast<AffineWriteOpInterface>(op).getMemRef())) 1590 producerStores.push_back(op); 1591 1592 // TODO: Suppport multiple producer stores in profitability 1593 // analysis. We limit profitability analysis to only scenarios with 1594 // a single producer store for now. 
Note that some multi-store 1595 // producer scenarios will still go through profitability analysis 1596 // if only one of the stores is involved the producer-consumer 1597 // relationship of the candidate loops. 1598 assert(!producerStores.empty() && "Expected producer store"); 1599 if (producerStores.size() > 1) 1600 LLVM_DEBUG(llvm::dbgs() << "Skipping profitability analysis. Not " 1601 "supported for this case\n"); 1602 else if (!isFusionProfitable(producerStores[0], producerStores[0], 1603 dstAffineForOp, depthSliceUnions, 1604 maxLegalFusionDepth, &bestDstLoopDepth, 1605 computeToleranceThreshold)) 1606 continue; 1607 } 1608 1609 assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth"); 1610 ComputationSliceState &bestSlice = 1611 depthSliceUnions[bestDstLoopDepth - 1]; 1612 assert(!bestSlice.isEmpty() && "Missing slice union for depth"); 1613 1614 // Determine if 'srcId' can be removed after fusion, taking into 1615 // account remaining dependences, escaping memrefs and the fusion 1616 // insertion point. 1617 bool removeSrcNode = canRemoveSrcNodeAfterFusion( 1618 srcId, dstId, bestSlice, fusedLoopInsPoint, srcEscapingMemRefs, 1619 mdg); 1620 1621 DenseSet<Value> privateMemrefs; 1622 for (Value memref : producerConsumerMemrefs) { 1623 // If `memref` is an escaping one, do not create a private memref 1624 // for the below scenarios, since doing so will leave the escaping 1625 // memref unmodified as all the writes originally meant for the 1626 // escaping memref would be performed on the private memref: 1627 // 1. The source is to be removed after fusion, 1628 // OR 1629 // 2. The destination writes to `memref`. 1630 if (srcEscapingMemRefs.count(memref) > 0 && 1631 (removeSrcNode || dstNode->getStoreOpCount(memref) > 0)) 1632 continue; 1633 1634 // Don't create a private memref if 'srcNode' has in edges on 1635 // 'memref' or 'dstNode' has out edges on 'memref'. 1636 if (mdg->getIncomingMemRefAccesses(srcId, memref) > 0 || 1637 mdg->getOutEdgeCount(dstId, memref) > 0) 1638 continue; 1639 1640 // If 'srcNode' will be removed but it has out edges on 'memref' to 1641 // nodes other than 'dstNode', we have to preserve dependences and 1642 // cannot create a private memref. 1643 if (removeSrcNode && 1644 any_of(mdg->outEdges[srcId], [&](const auto &edge) { 1645 return edge.value == memref && edge.id != dstId; 1646 })) 1647 continue; 1648 1649 // Create a private version of this memref. 1650 privateMemrefs.insert(memref); 1651 } 1652 1653 // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'. 1654 fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice); 1655 dstNodeChanged = true; 1656 1657 LLVM_DEBUG(llvm::dbgs() 1658 << "Fused src loop " << srcId << " into dst loop " << dstId 1659 << " at depth " << bestDstLoopDepth << ":\n" 1660 << dstAffineForOp << "\n"); 1661 1662 // Move 'dstAffineForOp' before 'insertPointInst' if needed. 1663 if (fusedLoopInsPoint != dstAffineForOp) 1664 dstAffineForOp->moveBefore(fusedLoopInsPoint); 1665 1666 // Update edges between 'srcNode' and 'dstNode'. 1667 mdg->updateEdges(srcNode->id, dstNode->id, privateMemrefs, 1668 removeSrcNode); 1669 1670 // Create private memrefs. 1671 if (!privateMemrefs.empty()) { 1672 // Gather stores for all the private-to-be memrefs. 
1673 DenseMap<Value, SmallVector<Operation *, 4>> privateMemRefToStores; 1674 dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) { 1675 Value storeMemRef = storeOp.getMemRef(); 1676 if (privateMemrefs.count(storeMemRef) > 0) 1677 privateMemRefToStores[storeMemRef].push_back(storeOp); 1678 }); 1679 1680 // Replace original memrefs with private memrefs. Note that all the 1681 // loads and stores on these memrefs will be replaced with a new 1682 // loads and stores. Any reference to the original ones becomes 1683 // invalid after this point. 1684 for (auto &memrefToStoresPair : privateMemRefToStores) { 1685 // TODO: Use union of memref write regions to compute 1686 // private memref footprint. 1687 SmallVector<Operation *, 4> &storesForMemref = 1688 memrefToStoresPair.second; 1689 Value newMemRef = createPrivateMemRef( 1690 dstAffineForOp, storesForMemref[0], bestDstLoopDepth, 1691 fastMemorySpace, localBufSizeThreshold); 1692 // Create new node in dependence graph for 'newMemRef' alloc op. 1693 unsigned newMemRefNodeId = 1694 mdg->addNode(newMemRef.getDefiningOp()); 1695 // Add edge from 'newMemRef' node to dstNode. 1696 mdg->addEdge(newMemRefNodeId, dstId, newMemRef); 1697 } 1698 // One or more entries for 'newMemRef' alloc op are inserted into 1699 // the DenseMap mdg->nodes. Since an insertion may cause DenseMap to 1700 // reallocate, update dstNode. 1701 dstNode = mdg->getNode(dstId); 1702 } 1703 1704 // Collect dst loop stats after memref privatization transformation. 1705 LoopNestStateCollector dstLoopCollector; 1706 dstLoopCollector.collect(dstAffineForOp); 1707 1708 // Clear and add back loads and stores. 1709 mdg->clearNodeLoadAndStores(dstNode->id); 1710 mdg->addToNode(dstId, dstLoopCollector.loadOpInsts, 1711 dstLoopCollector.storeOpInsts); 1712 1713 if (removeSrcNode) { 1714 LLVM_DEBUG(llvm::dbgs() 1715 << "Removing src loop " << srcId << " after fusion\n"); 1716 // srcNode is no longer valid after it is removed from mdg. 1717 srcAffineForOp.erase(); 1718 mdg->removeNode(srcId); 1719 srcNode = nullptr; 1720 } 1721 } 1722 } while (dstNodeChanged); 1723 } 1724 } 1725 1726 // Visits each node in the graph, and for each node, attempts to fuse it with 1727 // its sibling nodes (nodes which share a parent, but no dependence edges). 1728 void fuseSiblingNodes() { 1729 LLVM_DEBUG(llvm::dbgs() << "--- Sibling Fusion ---\n"); 1730 init(); 1731 while (!worklist.empty()) { 1732 unsigned dstId = worklist.back(); 1733 worklist.pop_back(); 1734 1735 // Skip if this node was removed (fused into another node). 1736 if (mdg->nodes.count(dstId) == 0) 1737 continue; 1738 // Get 'dstNode' into which to attempt fusion. 1739 auto *dstNode = mdg->getNode(dstId); 1740 // Skip if 'dstNode' is not a loop nest. 1741 if (!isa<AffineForOp>(dstNode->op)) 1742 continue; 1743 // Attempt to fuse 'dstNode' with its sibling nodes in the graph. 1744 fuseWithSiblingNodes(dstNode); 1745 } 1746 } 1747 1748 // Attempt to fuse 'dstNode' with sibling nodes in the graph. 1749 void fuseWithSiblingNodes(Node *dstNode) { 1750 DenseSet<unsigned> visitedSibNodeIds; 1751 std::pair<unsigned, Value> idAndMemref; 1752 auto dstAffineForOp = cast<AffineForOp>(dstNode->op); 1753 1754 while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) { 1755 unsigned sibId = idAndMemref.first; 1756 Value memref = idAndMemref.second; 1757 // TODO: Check that 'sibStoreOpInst' post-dominates all other 1758 // stores to the same memref in 'sibNode' loop nest. 
1759 auto *sibNode = mdg->getNode(sibId); 1760 // Compute an operation list insertion point for the fused loop 1761 // nest which preserves dependences. 1762 assert(sibNode->op->getBlock() == dstNode->op->getBlock()); 1763 Operation *insertPointInst = 1764 sibNode->op->isBeforeInBlock(dstNode->op) 1765 ? mdg->getFusedLoopNestInsertionPoint(sibNode->id, dstNode->id) 1766 : mdg->getFusedLoopNestInsertionPoint(dstNode->id, sibNode->id); 1767 if (insertPointInst == nullptr) 1768 continue; 1769 1770 // Check if fusion would be profitable and at what depth. 1771 1772 // Get unique 'sibNode' load op to 'memref'. 1773 SmallVector<Operation *, 2> sibLoadOpInsts; 1774 sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts); 1775 // Currently findSiblingNodeToFuse searches for siblings with one load. 1776 assert(sibLoadOpInsts.size() == 1); 1777 Operation *sibLoadOpInst = sibLoadOpInsts[0]; 1778 assert(!sibNode->stores.empty()); 1779 // TODO: Choose the store which postdominates all other stores. 1780 auto *sibStoreOpInst = sibNode->stores.back(); 1781 1782 // Gather 'dstNode' load ops to 'memref'. 1783 SmallVector<Operation *, 2> dstLoadOpInsts; 1784 dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts); 1785 1786 SmallVector<AffineForOp, 4> dstLoopIVs; 1787 getLoopIVs(*dstLoadOpInsts[0], &dstLoopIVs); 1788 unsigned dstLoopDepthTest = dstLoopIVs.size(); 1789 auto sibAffineForOp = cast<AffineForOp>(sibNode->op); 1790 1791 // Compute loop depth and slice union for fusion. 1792 SmallVector<ComputationSliceState, 8> depthSliceUnions; 1793 depthSliceUnions.resize(dstLoopDepthTest); 1794 unsigned maxLegalFusionDepth = 0; 1795 FusionStrategy strategy(memref); 1796 for (unsigned i = 1; i <= dstLoopDepthTest; ++i) { 1797 FusionResult result = mlir::canFuseLoops( 1798 sibAffineForOp, dstAffineForOp, 1799 /*dstLoopDepth=*/i, &depthSliceUnions[i - 1], strategy); 1800 1801 if (result.value == FusionResult::Success) 1802 maxLegalFusionDepth = i; 1803 } 1804 1805 // Skip if fusion is not feasible at any loop depths. 1806 if (maxLegalFusionDepth == 0) 1807 continue; 1808 1809 unsigned bestDstLoopDepth = maxLegalFusionDepth; 1810 if (!maximalFusion) { 1811 // Check if fusion would be profitable. 1812 if (!isFusionProfitable(sibLoadOpInst, sibStoreOpInst, dstAffineForOp, 1813 depthSliceUnions, maxLegalFusionDepth, 1814 &bestDstLoopDepth, computeToleranceThreshold)) 1815 continue; 1816 } 1817 1818 assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth"); 1819 assert(!depthSliceUnions[bestDstLoopDepth - 1].isEmpty() && 1820 "Fusion depth has no computed slice union"); 1821 // Check if source loop is being inserted in the innermost 1822 // destination loop. Based on this, the fused loop may be optimized 1823 // further inside `fuseLoops`. 1824 bool isInnermostInsertion = (bestDstLoopDepth == dstLoopDepthTest); 1825 // Fuse computation slice of 'sibLoopNest' into 'dstLoopNest'. 1826 mlir::fuseLoops(sibAffineForOp, dstAffineForOp, 1827 depthSliceUnions[bestDstLoopDepth - 1], 1828 isInnermostInsertion); 1829 1830 auto dstForInst = cast<AffineForOp>(dstNode->op); 1831 // Update operation position of fused loop nest (if needed). 1832 if (insertPointInst != dstForInst) { 1833 dstForInst->moveBefore(insertPointInst); 1834 } 1835 // Update data dependence graph state post fusion. 
1836 updateStateAfterSiblingFusion(sibNode, dstNode); 1837 } 1838 } 1839 1840 // Searches block argument uses and the graph from 'dstNode' looking for a 1841 // fusion candidate sibling node which shares no dependences with 'dstNode' 1842 // but which loads from the same memref. Returns true and sets 1843 // 'idAndMemrefToFuse' on success. Returns false otherwise. 1844 bool findSiblingNodeToFuse(Node *dstNode, 1845 DenseSet<unsigned> *visitedSibNodeIds, 1846 std::pair<unsigned, Value> *idAndMemrefToFuse) { 1847 // Returns true if 'sibNode' can be fused with 'dstNode' for input reuse 1848 // on 'memref'. 1849 auto canFuseWithSibNode = [&](Node *sibNode, Value memref) { 1850 // Skip if 'outEdge' is not a read-after-write dependence. 1851 // TODO: Remove restrict to single load op restriction. 1852 if (sibNode->getLoadOpCount(memref) != 1) 1853 return false; 1854 // Skip if there exists a path of dependent edges between 1855 // 'sibNode' and 'dstNode'. 1856 if (mdg->hasDependencePath(sibNode->id, dstNode->id) || 1857 mdg->hasDependencePath(dstNode->id, sibNode->id)) 1858 return false; 1859 // Skip sib node if it loads to (and stores from) the same memref on 1860 // which it also has an input dependence edge. 1861 DenseSet<Value> loadAndStoreMemrefSet; 1862 sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet); 1863 if (llvm::any_of(loadAndStoreMemrefSet, [=](Value memref) { 1864 return mdg->getIncomingMemRefAccesses(sibNode->id, memref) > 0; 1865 })) 1866 return false; 1867 1868 // Check that all stores are to the same memref. 1869 DenseSet<Value> storeMemrefs; 1870 for (auto *storeOpInst : sibNode->stores) { 1871 storeMemrefs.insert( 1872 cast<AffineWriteOpInterface>(storeOpInst).getMemRef()); 1873 } 1874 if (storeMemrefs.size() != 1) 1875 return false; 1876 1877 // Skip if a memref value in one node is used by a non-affine memref 1878 // access that lies between 'dstNode' and 'sibNode'. 1879 if (hasNonAffineUsersOnThePath(dstNode->id, sibNode->id, mdg) || 1880 hasNonAffineUsersOnThePath(sibNode->id, dstNode->id, mdg)) 1881 return false; 1882 return true; 1883 }; 1884 1885 // Search for siblings which load the same memref block argument. 1886 Block *block = dstNode->op->getBlock(); 1887 for (unsigned i = 0, e = block->getNumArguments(); i != e; ++i) { 1888 for (Operation *user : block->getArgument(i).getUsers()) { 1889 auto loadOp = dyn_cast<AffineReadOpInterface>(user); 1890 if (!loadOp) 1891 continue; 1892 // Gather loops surrounding 'use'. 1893 SmallVector<AffineForOp, 4> loops; 1894 getLoopIVs(*user, &loops); 1895 // Skip 'use' if it is not within a loop nest. 1896 if (loops.empty()) 1897 continue; 1898 Node *sibNode = mdg->getForOpNode(loops[0]); 1899 assert(sibNode != nullptr); 1900 // Skip 'use' if it not a sibling to 'dstNode'. 1901 if (sibNode->id == dstNode->id) 1902 continue; 1903 // Skip 'use' if it has been visited. 1904 if (visitedSibNodeIds->count(sibNode->id) > 0) 1905 continue; 1906 // Skip 'use' if it does not load from the same memref as 'dstNode'. 1907 auto memref = loadOp.getMemRef(); 1908 if (dstNode->getLoadOpCount(memref) == 0) 1909 continue; 1910 // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'. 1911 if (canFuseWithSibNode(sibNode, memref)) { 1912 visitedSibNodeIds->insert(sibNode->id); 1913 idAndMemrefToFuse->first = sibNode->id; 1914 idAndMemrefToFuse->second = memref; 1915 return true; 1916 } 1917 } 1918 } 1919 1920 // Search for siblings by following edges through an intermediate src node. 
1921 // Collect candidate 'dstNode' input edges in 'inEdges'. 1922 SmallVector<MemRefDependenceGraph::Edge, 2> inEdges; 1923 mdg->forEachMemRefInputEdge( 1924 dstNode->id, [&](MemRefDependenceGraph::Edge inEdge) { 1925 // Add 'inEdge' if it is a read-after-write dependence. 1926 if (dstNode->getLoadOpCount(inEdge.value) > 0 && 1927 mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0) 1928 inEdges.push_back(inEdge); 1929 }); 1930 1931 // Search for sibling nodes to fuse by visiting output edges from each input 1932 // edge in 'inEdges'. 1933 for (auto &inEdge : inEdges) { 1934 // Collect candidate output edges from each node 'inEdge.id' in 'inEdges'. 1935 SmallVector<MemRefDependenceGraph::Edge, 2> outEdges; 1936 mdg->forEachMemRefOutputEdge( 1937 inEdge.id, [&](MemRefDependenceGraph::Edge outEdge) { 1938 unsigned sibNodeId = outEdge.id; 1939 if (visitedSibNodeIds->count(sibNodeId) > 0) 1940 return; 1941 // Skip output edge if not a sibling using the same memref. 1942 if (outEdge.id == dstNode->id || outEdge.value != inEdge.value) 1943 return; 1944 auto *sibNode = mdg->getNode(sibNodeId); 1945 if (!isa<AffineForOp>(sibNode->op)) 1946 return; 1947 // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'. 1948 if (canFuseWithSibNode(sibNode, outEdge.value)) { 1949 // Add candidate 'outEdge' to sibling node. 1950 outEdges.push_back(outEdge); 1951 } 1952 }); 1953 1954 // Add first candidate if any were returned. 1955 if (!outEdges.empty()) { 1956 visitedSibNodeIds->insert(outEdges[0].id); 1957 idAndMemrefToFuse->first = outEdges[0].id; 1958 idAndMemrefToFuse->second = outEdges[0].value; 1959 return true; 1960 } 1961 } 1962 return false; 1963 } 1964 1965 /// Update data dependence graph state to reflect sibling fusion of 'sibNode' 1966 /// into 'dstNode'. 1967 void updateStateAfterSiblingFusion(Node *sibNode, Node *dstNode) { 1968 // Update 'sibNode' and 'dstNode' input/output edges to reflect fusion. 1969 mdg->updateEdges(sibNode->id, dstNode->id); 1970 1971 // Collect dst loop stats after memref privatization transformation. 1972 auto dstForInst = cast<AffineForOp>(dstNode->op); 1973 LoopNestStateCollector dstLoopCollector; 1974 dstLoopCollector.collect(dstForInst); 1975 // Clear and add back loads and stores 1976 mdg->clearNodeLoadAndStores(dstNode->id); 1977 mdg->addToNode(dstNode->id, dstLoopCollector.loadOpInsts, 1978 dstLoopCollector.storeOpInsts); 1979 // Remove old sibling loop nest if it no longer has outgoing dependence 1980 // edges, and it does not write to a memref which escapes the block. 1981 if (mdg->getOutEdgeCount(sibNode->id) == 0) { 1982 Operation *op = sibNode->op; 1983 mdg->removeNode(sibNode->id); 1984 op->erase(); 1985 } 1986 } 1987 1988 // Clean up any allocs with no users. 1989 void eraseUnusedMemRefAllocations() { 1990 for (auto &pair : mdg->memrefEdgeCount) { 1991 if (pair.second > 0) 1992 continue; 1993 auto memref = pair.first; 1994 // Skip if there exist other uses (return operation or function calls). 1995 if (!memref.use_empty()) 1996 continue; 1997 // Use list expected to match the dep graph info. 1998 auto *op = memref.getDefiningOp(); 1999 if (isa_and_nonnull<memref::AllocOp>(op)) 2000 op->erase(); 2001 } 2002 } 2003 }; 2004 2005 } // namespace 2006 2007 /// Run fusion on `block`. 
2008 void LoopFusion::runOnBlock(Block *block) { 2009 MemRefDependenceGraph g(*block); 2010 if (!g.init(block)) 2011 return; 2012 2013 Optional<unsigned> fastMemorySpaceOpt; 2014 if (fastMemorySpace.hasValue()) 2015 fastMemorySpaceOpt = fastMemorySpace; 2016 unsigned localBufSizeThresholdBytes = localBufSizeThreshold * 1024; 2017 GreedyFusion fusion(&g, localBufSizeThresholdBytes, fastMemorySpaceOpt, 2018 maximalFusion, computeToleranceThreshold); 2019 2020 if (affineFusionMode == FusionMode::ProducerConsumer) 2021 fusion.runProducerConsumerFusionOnly(); 2022 else if (affineFusionMode == FusionMode::Sibling) 2023 fusion.runSiblingFusionOnly(); 2024 else 2025 fusion.runGreedyFusion(); 2026 } 2027 2028 void LoopFusion::runOnOperation() { 2029 for (Region ®ion : getOperation()->getRegions()) 2030 for (Block &block : region.getBlocks()) 2031 runOnBlock(&block); 2032 } 2033