Lines Matching +full:depth +full:- +full:wise

1 //===- LoopFusion.cpp - Code to perform loop fusion -----------------------===//
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
11 //===----------------------------------------------------------------------===//
44 #define DEBUG_TYPE "affine-loop-fusion"
51 /// which fuses loop nests with single-writer/single-reader memref dependences
59 this->fastMemorySpace = fastMemorySpace;
60 this->localBufSizeThreshold = localBufSizeThresholdBytes / 1024;
61 this->maximalFusion = maximalFusion;
62 this->affineFusionMode = affineFusionMode;
83 Operation *dstNodeOp = mdg->getNode(dstId)->op;
86 for (auto &outEdge : mdg->outEdges[srcId]) {
87 Operation *depNodeOp = mdg->getNode(outEdge.id)->op;
94 if (depNodeOp->getBlock() != dstNodeOp->getBlock())
100 !fusedLoopInsPoint->isBeforeInBlock(depNodeOp)) {
109 // If src loop has dependences after fusion or it writes to a live-out or
140 if (mdg->inEdges.count(dstId) == 0)
144 auto *dstNode = mdg->getNode(dstId);
146 for (Operation *load : dstNode->loads)
151 for (auto &srcEdge : mdg->inEdges[dstId]) {
152 auto *srcNode = mdg->getNode(srcEdge.id);
154 if (!isa<AffineForOp>(srcNode->op))
157 if (any_of(srcNode->stores, [&](Operation *op) {
161 srcIdCandidates.push_back(srcNode->id);
169 /// producer-consumer dependence between 'srcId' and 'dstId'.
174 auto *dstNode = mdg->getNode(dstId);
175 auto *srcNode = mdg->getNode(srcId);
176 gatherProducerConsumerMemrefs(srcNode->stores, dstNode->loads,
202 // Check if 'memref' is used by a non-dereferencing op (including unknown ones)
206 Operation *ancestorOp = block->getParent()->findAncestorOpInRegion(*user);
209 if (ancestorOp->getBlock() != block)
216 /// that escape the block or are accessed in a non-affine way.
219 auto *node = mdg->getNode(id);
220 for (Operation *storeOp : node->stores) {
224 if (isEscapingMemref(memref, &mdg->block))
232 // This can increase the loop depth at which we can fuse a slice, since we are
233 // pushing loop carried dependence to a greater depth in the loop nest.
235 assert(isa<AffineForOp>(node->op));
236 AffineForOp newRootForOp = sinkSequentialLoops(cast<AffineForOp>(node->op));
237 node->op = newRootForOp;
240 // Creates and returns a private (single-user) memref for fused loop rooted
242 // MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
254 OpBuilder top(forInst->getParentRegion());
260 // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'.
261 MemRefRegion region(srcStoreOpInst->getLoc());
270 // by 'srcStoreOpInst' at depth 'dstLoopDepth'.
273 assert(numElements && "non-constant number of elts in local buffer");
280 cst->getValues(rank, cst->getNumVars(), &outerIVs);
286 assert(lbs[d].size() == cst->getNumCols() - rank && "incorrect bound size");
289 for (unsigned j = 0, e = cst->getNumCols() - rank - 1; j < e; j++) {
294 (offset + lbs[d][cst->getNumCols() - 1 - rank]).floorDiv(lbDivisors[d]);
327 simplifyAffineExpr(dimExpr - offsets[i], outerIVs.size() + rank, 0);
339 /*domOpFilter=*/&*forOp.getBody()->begin());
346 /// Returns true if there are any non-affine uses of `memref` in any of
348 /// than affine read/write are treated as non-affine uses of `memref`.
351 assert(start->getBlock() == end->getBlock());
352 assert(start->isBeforeInBlock(end) && "start expected to be before end");
353 Block *block = start->getBlock();
354 // Check if there is a non-affine memref user in any op between `start` and
359 Operation *ancestor = block->findAncestorOpInBlock(*user);
360 return ancestor && start->isBeforeInBlock(ancestor) &&
361 ancestor->isBeforeInBlock(end);
366 /// non-affine operation that is between `src` and `end` (exclusive of `src`
368 /// Any users other than affine read/write are treated as non-affine uses of memref.
370 assert(src->getBlock() == end->getBlock() && "same block expected");
373 if (src == end || end->isBeforeInBlock(src))
378 src->walk([&](Operation *op) {
379 for (Value v : op->getOperands())
385 // Look for non-affine users between `src` and `end`.
395 // For producer-consumer fusion, 'srcStoreOpInst' will be the same as
396 // 'srcOpInst', as we are slicing w.r.t. that producer. For input-reuse
400 // region is the same after input-reuse fusion. Computation slices are provided
401 // in 'depthSliceUnions' for each legal fusion depth. The maximal depth at which
404 // `dstLoopDepth` is set to the most profitable depth at which to materialize
415 // loop nest at various values of dst loop depth, attempting to fuse
416 // the largest computation slice at the maximal dst loop depth (closest to
419 // NOTE: 'dstLoopDepth' refers to the loop depth within the destination loop
421 // NOTE: We attempt to maximize the dst loop depth, but there are cases
423 // loop (within the src computation slice) at a depth which results in
471 // The best loop depth at which to materialize the slice.
478 MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc());
494 // Evaluate all depth choices for materializing the slice in the destination
496 for (unsigned i = maxLegalFusionDepth; i >= 1; --i) {
497 const ComputationSliceState &slice = depthSliceUnions[i - 1];
498 // Skip slice union if it wasn't computed for this depth.
512 (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
517 // depth 'i'.
518 MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc());
551 msg << " evaluating fusion profitability at depth : " << i << "\n"
587 LLVM_DEBUG(llvm::dbgs() << "no fusion depth could be evaluated.\n");
596 << "\n best loop depth: " << bestDstLoopDepth
630 (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal));
634 (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
639 msg << " fusion is most profitable at depth " << *dstLoopDepth << " with "
653 // input-reuse relationship on a memref, with the goal of improving locality.
655 // The steps of the producer-consumer fusion algorithm are as follows:
669 // at a loop depth determined by the cost model in 'isFusionProfitable'.
675 // The steps of the input-reuse fusion algorithm are as follows:
684 // at a loop depth determined by the cost model in 'isFusionProfitable'.
689 // Given a graph where top-level operations are vertices in the set 'V' and
712 // pair-wise as a fraction of the total computation.
727 // metrics (e.g. arithmetic intensity/flops-to-bytes ratio).
729 for (auto &idAndNode : mdg->nodes) {
748 // *) First pass through the nodes fuses single-use producer nodes into their
753 // TODO: Run this repeatedly until a fixed-point is reached.
767 const Node *consumerNode = mdg->getNode(consumerId);
776 (removeSrcNode || consumerNode->getStoreOpCount(memref) > 0))
781 if (mdg->getIncomingMemRefAccesses(producerId, memref) > 0 ||
782 mdg->getOutEdgeCount(consumerId, memref) > 0)
789 any_of(mdg->outEdges[producerId], [&](const auto &edge) {
803 if (mdg->nodes.count(dstId) == 0)
806 auto *dstNode = mdg->getNode(dstId);
808 if (!isa<AffineForOp>(dstNode->op))
812 if (dstNode->op->getNumResults() > 0)
819 // depth at which we can fuse a slice of a producer loop nest into a
822 auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
839 auto *srcNode = mdg->getNode(srcId);
840 auto srcAffineForOp = cast<AffineForOp>(srcNode->op);
846 if (isa<AffineForOp>(srcNode->op) && srcNode->op->getNumResults() > 0)
856 return mdg->getOutEdgeCount(srcNode->id, memref) >
865 gatherEscapingMemrefs(srcNode->id, mdg, srcEscapingMemRefs);
867 // Skip if there are non-affine operations in between the 'srcNode'
870 // memrefs with non-affine operation users would be considered
874 hasNonAffineUsersOnPath(srcNode->op, dstNode->op)) {
876 << "Can't fuse: non-affine users in between the loops\n");
883 mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id);
887 // It's possible this fusion is at an inner depth (i.e., there are
890 // needs to be passed the absolute depth. The max legal depth and the
892 // the common depth.
897 // Compute the innermost common loop depth for dstNode
898 // producer-consumer loads/stores.
900 for (Operation *op : dstNode->loads)
904 for (Operation *op : dstNode->stores)
909 getInnermostCommonLoopDepth(dstMemrefOps) - numSurroundingLoops;
921 &depthSliceUnions[i - 1], strategy);
929 << "Can't fuse: fusion is not legal at any depth\n");
934 // for maximal fusion since we already know the maximal legal depth to
940 for (Operation *op : srcNode->stores)
947 // a single producer store for now. Note that some multi-store
949 // if only one of the stores is involved in the producer-consumer
962 assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
964 depthSliceUnions[bestDstLoopDepth - 1];
965 assert(!bestSlice.isEmpty() && "Missing slice union for depth");
992 << " at depth " << bestDstLoopDepth << ":\n"
997 dstAffineForOp->moveBefore(fusedLoopInsPoint);
1000 mdg->updateEdges(srcNode->id, dstNode->id, privateMemrefs,
1005 // Gather stores for all the private-to-be memrefs.
1026 unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp());
1028 mdg->addEdge(newMemRefNodeId, dstId, newMemRef);
1031 // the DenseMap mdg->nodes. Since an insertion may cause DenseMap to
1033 dstNode = mdg->getNode(dstId);
1041 mdg->clearNodeLoadAndStores(dstNode->id);
1042 mdg->addToNode(dstId, dstLoopCollector.loadOpInsts,
1050 mdg->removeNode(srcId);
1058 /// producer-consumer candidates. No fusion is performed when producers with a
1062 LLVM_DEBUG(llvm::dbgs() << "--- Producer/Consumer Fusion ---\n");
1074 LLVM_DEBUG(llvm::dbgs() << "--- Sibling Fusion ---\n");
1081 if (mdg->nodes.count(dstId) == 0)
1084 auto *dstNode = mdg->getNode(dstId);
1086 if (!isa<AffineForOp>(dstNode->op))
1097 auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
1102 // TODO: Check that 'sibStoreOpInst' post-dominates all other
1104 auto *sibNode = mdg->getNode(sibId);
1107 assert(sibNode->op->getBlock() == dstNode->op->getBlock());
1109 sibNode->op->isBeforeInBlock(dstNode->op)
1110 ? mdg->getFusedLoopNestInsertionPoint(sibNode->id, dstNode->id)
1111 : mdg->getFusedLoopNestInsertionPoint(dstNode->id, sibNode->id);
1115 // Check if fusion would be profitable and at what depth.
1119 sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts);
1126 dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts);
1128 // It's possible this fusion is at an inner depth (i.e., there are common
1131 // passed the absolute depth. The max legal depth and the depths we try
1133 // depth.
1139 unsigned dstLoopDepthTest = dstLoopIVs.size() - numSurroundingLoops;
1140 auto sibAffineForOp = cast<AffineForOp>(sibNode->op);
1142 // Compute loop depth and slice union for fusion.
1151 &depthSliceUnions[i - 1], strategy);
1157 LLVM_DEBUG(llvm::dbgs() << "Max legal depth for fusion: "
1176 assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
1177 assert(!depthSliceUnions[bestDstLoopDepth - 1].isEmpty() &&
1178 "Fusion depth has no computed slice union");
1185 depthSliceUnions[bestDstLoopDepth - 1],
1188 auto dstForInst = cast<AffineForOp>(dstNode->op);
1191 dstForInst->moveBefore(insertPointInst);
1208 // Skip if 'outEdge' is not a read-after-write dependence.
1210 if (sibNode->getLoadOpCount(memref) != 1)
1214 if (mdg->hasDependencePath(sibNode->id, dstNode->id) ||
1215 mdg->hasDependencePath(dstNode->id, sibNode->id))
1220 sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet);
1222 return mdg->getIncomingMemRefAccesses(sibNode->id, memref) > 0;
1228 for (auto *storeOpInst : sibNode->stores) {
1235 // Skip if a memref value in one node is used by a non-affine memref
1237 if (hasNonAffineUsersOnPath(dstNode->op, sibNode->op) ||
1238 hasNonAffineUsersOnPath(sibNode->op, dstNode->op))
1244 Block *block = dstNode->op->getBlock();
1245 for (unsigned i = 0, e = block->getNumArguments(); i != e; ++i) {
1246 for (Operation *user : block->getArgument(i).getUsers()) {
1257 return loop->getBlock() == &mdg->block;
1262 Node *sibNode = mdg->getForOpNode(*it);
1265 if (sibNode->id == dstNode->id)
1268 if (visitedSibNodeIds->count(sibNode->id) > 0)
1272 if (dstNode->getLoadOpCount(memref) == 0)
1274 // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
1276 visitedSibNodeIds->insert(sibNode->id);
1277 idAndMemrefToFuse->first = sibNode->id;
1278 idAndMemrefToFuse->second = memref;
1287 mdg->forEachMemRefInputEdge(
1288 dstNode->id, [&](MemRefDependenceGraph::Edge inEdge) {
1289 // Add 'inEdge' if it is a read-after-write dependence.
1290 if (dstNode->getLoadOpCount(inEdge.value) > 0 &&
1291 mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0)
1300 mdg->forEachMemRefOutputEdge(
1303 if (visitedSibNodeIds->count(sibNodeId) > 0)
1306 if (outEdge.id == dstNode->id || outEdge.value != inEdge.value)
1308 auto *sibNode = mdg->getNode(sibNodeId);
1309 if (!isa<AffineForOp>(sibNode->op))
1311 // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
1320 visitedSibNodeIds->insert(outEdges[0].id);
1321 idAndMemrefToFuse->first = outEdges[0].id;
1322 idAndMemrefToFuse->second = outEdges[0].value;
1333 mdg->updateEdges(sibNode->id, dstNode->id);
1336 auto dstForInst = cast<AffineForOp>(dstNode->op);
1340 mdg->clearNodeLoadAndStores(dstNode->id);
1341 mdg->addToNode(dstNode->id, dstLoopCollector.loadOpInsts,
1345 if (mdg->getOutEdgeCount(sibNode->id) == 0) {
1346 Operation *op = sibNode->op;
1347 mdg->removeNode(sibNode->id);
1348 op->erase();
1354 for (auto &pair : mdg->memrefEdgeCount) {
1364 op->erase();
1397 getOperation()->walk([&](Operation *op) {
1398 for (Region &region : op->getRegions()) {