1a70aa7bbSRiver Riddle //===- LoopFusion.cpp - Code to perform loop fusion -----------------------===// 2a70aa7bbSRiver Riddle // 3a70aa7bbSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4a70aa7bbSRiver Riddle // See https://llvm.org/LICENSE.txt for license information. 5a70aa7bbSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6a70aa7bbSRiver Riddle // 7a70aa7bbSRiver Riddle //===----------------------------------------------------------------------===// 8a70aa7bbSRiver Riddle // 9fe9d0a47SUday Bondhugula // This file implements affine fusion. 10a70aa7bbSRiver Riddle // 11a70aa7bbSRiver Riddle //===----------------------------------------------------------------------===// 12a70aa7bbSRiver Riddle 1367d0d7acSMichele Scuttari #include "mlir/Dialect/Affine/Passes.h" 1467d0d7acSMichele Scuttari 15a70aa7bbSRiver Riddle #include "mlir/Dialect/Affine/Analysis/AffineStructures.h" 16a70aa7bbSRiver Riddle #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h" 17a70aa7bbSRiver Riddle #include "mlir/Dialect/Affine/Analysis/Utils.h" 18a70aa7bbSRiver Riddle #include "mlir/Dialect/Affine/IR/AffineOps.h" 19a70aa7bbSRiver Riddle #include "mlir/Dialect/Affine/LoopFusionUtils.h" 20a70aa7bbSRiver Riddle #include "mlir/Dialect/Affine/LoopUtils.h" 21a70aa7bbSRiver Riddle #include "mlir/Dialect/Affine/Utils.h" 22a70aa7bbSRiver Riddle #include "mlir/Dialect/MemRef/IR/MemRef.h" 23a70aa7bbSRiver Riddle #include "mlir/IR/AffineExpr.h" 24a70aa7bbSRiver Riddle #include "mlir/IR/AffineMap.h" 25a70aa7bbSRiver Riddle #include "mlir/IR/Builders.h" 26a70aa7bbSRiver Riddle #include "mlir/Transforms/Passes.h" 27a70aa7bbSRiver Riddle #include "llvm/ADT/DenseMap.h" 28a70aa7bbSRiver Riddle #include "llvm/ADT/DenseSet.h" 29aba43035SDmitri Gribenko #include "llvm/ADT/STLExtras.h" 30a70aa7bbSRiver Riddle #include "llvm/Support/CommandLine.h" 31a70aa7bbSRiver Riddle #include "llvm/Support/Debug.h" 32a70aa7bbSRiver Riddle #include 
"llvm/Support/raw_ostream.h" 33a70aa7bbSRiver Riddle #include <iomanip> 348ada9cf7SKazu Hirata #include <optional> 35a70aa7bbSRiver Riddle #include <sstream> 3667d0d7acSMichele Scuttari 3767d0d7acSMichele Scuttari namespace mlir { 384c48f016SMatthias Springer namespace affine { 3967d0d7acSMichele Scuttari #define GEN_PASS_DEF_AFFINELOOPFUSION 4067d0d7acSMichele Scuttari #include "mlir/Dialect/Affine/Passes.h.inc" 414c48f016SMatthias Springer } // namespace affine 4267d0d7acSMichele Scuttari } // namespace mlir 4367d0d7acSMichele Scuttari 44a70aa7bbSRiver Riddle #define DEBUG_TYPE "affine-loop-fusion" 45a70aa7bbSRiver Riddle 46a70aa7bbSRiver Riddle using namespace mlir; 474c48f016SMatthias Springer using namespace mlir::affine; 48a70aa7bbSRiver Riddle 49a70aa7bbSRiver Riddle namespace { 50a70aa7bbSRiver Riddle /// Loop fusion pass. This pass currently supports a greedy fusion policy, 51a70aa7bbSRiver Riddle /// which fuses loop nests with single-writer/single-reader memref dependences 52a70aa7bbSRiver Riddle /// with the goal of improving locality. 53a70aa7bbSRiver Riddle // TODO: Support fusion of source loop nests which write to multiple 54a70aa7bbSRiver Riddle // memrefs, where each memref can have multiple users (if profitable). 
554c48f016SMatthias Springer struct LoopFusion : public affine::impl::AffineLoopFusionBase<LoopFusion> { 56a70aa7bbSRiver Riddle LoopFusion() = default; 57a70aa7bbSRiver Riddle LoopFusion(unsigned fastMemorySpace, uint64_t localBufSizeThresholdBytes, 58a70aa7bbSRiver Riddle bool maximalFusion, enum FusionMode affineFusionMode) { 59a70aa7bbSRiver Riddle this->fastMemorySpace = fastMemorySpace; 60a70aa7bbSRiver Riddle this->localBufSizeThreshold = localBufSizeThresholdBytes / 1024; 61a70aa7bbSRiver Riddle this->maximalFusion = maximalFusion; 62a70aa7bbSRiver Riddle this->affineFusionMode = affineFusionMode; 63a70aa7bbSRiver Riddle } 64a70aa7bbSRiver Riddle 65fe9d0a47SUday Bondhugula void runOnBlock(Block *block); 66a70aa7bbSRiver Riddle void runOnOperation() override; 67a70aa7bbSRiver Riddle }; 68a70aa7bbSRiver Riddle 69a70aa7bbSRiver Riddle } // namespace 70a70aa7bbSRiver Riddle 71a70aa7bbSRiver Riddle /// Returns true if node 'srcId' can be removed after fusing it with node 72a70aa7bbSRiver Riddle /// 'dstId'. The node can be removed if any of the following conditions are met: 73a70aa7bbSRiver Riddle /// 1. 'srcId' has no output dependences after fusion and no escaping memrefs. 74a70aa7bbSRiver Riddle /// 2. 'srcId' has no output dependences after fusion, has escaping memrefs 75a70aa7bbSRiver Riddle /// and the fusion slice is maximal. 76a70aa7bbSRiver Riddle /// 3. 'srcId' has output dependences after fusion, the fusion slice is 77a70aa7bbSRiver Riddle /// maximal and the fusion insertion point dominates all the dependences. 
78a70aa7bbSRiver Riddle static bool canRemoveSrcNodeAfterFusion( 79a70aa7bbSRiver Riddle unsigned srcId, unsigned dstId, const ComputationSliceState &fusionSlice, 80a70aa7bbSRiver Riddle Operation *fusedLoopInsPoint, const DenseSet<Value> &escapingMemRefs, 81a70aa7bbSRiver Riddle MemRefDependenceGraph *mdg) { 82a70aa7bbSRiver Riddle 83a70aa7bbSRiver Riddle Operation *dstNodeOp = mdg->getNode(dstId)->op; 84a70aa7bbSRiver Riddle bool hasOutDepsAfterFusion = false; 85a70aa7bbSRiver Riddle 86a70aa7bbSRiver Riddle for (auto &outEdge : mdg->outEdges[srcId]) { 87a70aa7bbSRiver Riddle Operation *depNodeOp = mdg->getNode(outEdge.id)->op; 88a70aa7bbSRiver Riddle // Skip dependence with dstOp since it will be removed after fusion. 89a70aa7bbSRiver Riddle if (depNodeOp == dstNodeOp) 90a70aa7bbSRiver Riddle continue; 91a70aa7bbSRiver Riddle 92a70aa7bbSRiver Riddle // Only fusion within the same block is supported. Use domination analysis 93a70aa7bbSRiver Riddle // when needed. 94a70aa7bbSRiver Riddle if (depNodeOp->getBlock() != dstNodeOp->getBlock()) 95a70aa7bbSRiver Riddle return false; 96a70aa7bbSRiver Riddle 97a70aa7bbSRiver Riddle // Check if the insertion point of the fused loop dominates the dependence. 98a70aa7bbSRiver Riddle // Otherwise, the src loop can't be removed. 
99a70aa7bbSRiver Riddle if (fusedLoopInsPoint != depNodeOp && 100a70aa7bbSRiver Riddle !fusedLoopInsPoint->isBeforeInBlock(depNodeOp)) { 101a70aa7bbSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: dst loop doesn't " 102a70aa7bbSRiver Riddle "dominate dependence\n"); 103a70aa7bbSRiver Riddle return false; 104a70aa7bbSRiver Riddle } 105a70aa7bbSRiver Riddle 106a70aa7bbSRiver Riddle hasOutDepsAfterFusion = true; 107a70aa7bbSRiver Riddle } 108a70aa7bbSRiver Riddle 109a70aa7bbSRiver Riddle // If src loop has dependences after fusion or it writes to an live-out or 110a70aa7bbSRiver Riddle // escaping memref, we can only remove it if the fusion slice is maximal so 111a70aa7bbSRiver Riddle // that all the dependences are preserved. 112a70aa7bbSRiver Riddle if (hasOutDepsAfterFusion || !escapingMemRefs.empty()) { 1130a81ace0SKazu Hirata std::optional<bool> isMaximal = fusionSlice.isMaximal(); 114037f0995SKazu Hirata if (!isMaximal) { 115a70aa7bbSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: can't determine " 116a70aa7bbSRiver Riddle "if fusion is maximal\n"); 117a70aa7bbSRiver Riddle return false; 118a70aa7bbSRiver Riddle } 119a70aa7bbSRiver Riddle 1206d5fc1e3SKazu Hirata if (!*isMaximal) { 121a70aa7bbSRiver Riddle LLVM_DEBUG(llvm::dbgs() 122a70aa7bbSRiver Riddle << "Src loop can't be removed: fusion is not maximal\n"); 123a70aa7bbSRiver Riddle return false; 124a70aa7bbSRiver Riddle } 125a70aa7bbSRiver Riddle } 126a70aa7bbSRiver Riddle 127a70aa7bbSRiver Riddle return true; 128a70aa7bbSRiver Riddle } 129a70aa7bbSRiver Riddle 130a70aa7bbSRiver Riddle /// Returns in 'srcIdCandidates' the producer fusion candidates for consumer 131a70aa7bbSRiver Riddle /// 'dstId'. Candidates are sorted by node id order. This order corresponds to 132a70aa7bbSRiver Riddle /// the program order when the 'mdg' is created. However, program order is not 133a70aa7bbSRiver Riddle /// guaranteed and must not be required by the client. 
Program order won't be 134a70aa7bbSRiver Riddle /// held if the 'mdg' is reused from a previous fusion step or if the node 135a70aa7bbSRiver Riddle /// creation order changes in the future to support more advance cases. 136a70aa7bbSRiver Riddle // TODO: Move this to a loop fusion utility once 'mdg' is also moved. 137a70aa7bbSRiver Riddle static void getProducerCandidates(unsigned dstId, MemRefDependenceGraph *mdg, 138a70aa7bbSRiver Riddle SmallVectorImpl<unsigned> &srcIdCandidates) { 139a70aa7bbSRiver Riddle // Skip if no input edges along which to fuse. 140a70aa7bbSRiver Riddle if (mdg->inEdges.count(dstId) == 0) 141a70aa7bbSRiver Riddle return; 142a70aa7bbSRiver Riddle 143a70aa7bbSRiver Riddle // Gather memrefs from loads in 'dstId'. 144a70aa7bbSRiver Riddle auto *dstNode = mdg->getNode(dstId); 145a70aa7bbSRiver Riddle DenseSet<Value> consumedMemrefs; 146a70aa7bbSRiver Riddle for (Operation *load : dstNode->loads) 147a70aa7bbSRiver Riddle consumedMemrefs.insert(cast<AffineReadOpInterface>(load).getMemRef()); 148a70aa7bbSRiver Riddle 149a70aa7bbSRiver Riddle // Traverse 'dstId' incoming edges and gather the nodes that contain a store 150a70aa7bbSRiver Riddle // to one of the consumed memrefs. 151a70aa7bbSRiver Riddle for (auto &srcEdge : mdg->inEdges[dstId]) { 152a70aa7bbSRiver Riddle auto *srcNode = mdg->getNode(srcEdge.id); 153a70aa7bbSRiver Riddle // Skip if 'srcNode' is not a loop nest. 
154a70aa7bbSRiver Riddle if (!isa<AffineForOp>(srcNode->op)) 155a70aa7bbSRiver Riddle continue; 156a70aa7bbSRiver Riddle 157a70aa7bbSRiver Riddle if (any_of(srcNode->stores, [&](Operation *op) { 158a70aa7bbSRiver Riddle auto storeOp = cast<AffineWriteOpInterface>(op); 159a70aa7bbSRiver Riddle return consumedMemrefs.count(storeOp.getMemRef()) > 0; 160a70aa7bbSRiver Riddle })) 161a70aa7bbSRiver Riddle srcIdCandidates.push_back(srcNode->id); 162a70aa7bbSRiver Riddle } 163a70aa7bbSRiver Riddle 164aba43035SDmitri Gribenko llvm::sort(srcIdCandidates); 165b7b337fbSKazu Hirata srcIdCandidates.erase(llvm::unique(srcIdCandidates), srcIdCandidates.end()); 166a70aa7bbSRiver Riddle } 167a70aa7bbSRiver Riddle 168a70aa7bbSRiver Riddle /// Returns in 'producerConsumerMemrefs' the memrefs involved in a 169a70aa7bbSRiver Riddle /// producer-consumer dependence between 'srcId' and 'dstId'. 170a70aa7bbSRiver Riddle static void 171a70aa7bbSRiver Riddle gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId, 172a70aa7bbSRiver Riddle MemRefDependenceGraph *mdg, 173a70aa7bbSRiver Riddle DenseSet<Value> &producerConsumerMemrefs) { 174a70aa7bbSRiver Riddle auto *dstNode = mdg->getNode(dstId); 175a70aa7bbSRiver Riddle auto *srcNode = mdg->getNode(srcId); 176a70aa7bbSRiver Riddle gatherProducerConsumerMemrefs(srcNode->stores, dstNode->loads, 177a70aa7bbSRiver Riddle producerConsumerMemrefs); 178a70aa7bbSRiver Riddle } 179a70aa7bbSRiver Riddle 180fe9d0a47SUday Bondhugula /// A memref escapes in the context of the fusion pass if either: 181dc44acc9SUday Bondhugula /// 1. it (or its alias) is a block argument, or 182dc44acc9SUday Bondhugula /// 2. created by an op not known to guarantee alias freedom, 183fe9d0a47SUday Bondhugula /// 3. 
it (or its alias) are used by ops other than affine dereferencing ops 184fe9d0a47SUday Bondhugula /// (e.g., by call op, memref load/store ops, alias creating ops, unknown ops, 185fe9d0a47SUday Bondhugula /// terminator ops, etc.); such ops do not deference the memref in an affine 186fe9d0a47SUday Bondhugula /// way. 187fe9d0a47SUday Bondhugula static bool isEscapingMemref(Value memref, Block *block) { 188dc44acc9SUday Bondhugula Operation *defOp = memref.getDefiningOp(); 189dc44acc9SUday Bondhugula // Check if 'memref' is a block argument. 190dc44acc9SUday Bondhugula if (!defOp) 191825b23a7SUday Bondhugula return true; 192825b23a7SUday Bondhugula 193dc44acc9SUday Bondhugula // Check if this is defined to be an alias of another memref. 194dc44acc9SUday Bondhugula if (auto viewOp = dyn_cast<mlir::ViewLikeOpInterface>(defOp)) 195fe9d0a47SUday Bondhugula if (isEscapingMemref(viewOp.getViewSource(), block)) 196dc44acc9SUday Bondhugula return true; 197dc44acc9SUday Bondhugula 198dc44acc9SUday Bondhugula // Any op besides allocating ops wouldn't guarantee alias freedom 199dc44acc9SUday Bondhugula if (!hasSingleEffect<mlir::MemoryEffects::Allocate>(defOp, memref)) 200dc44acc9SUday Bondhugula return true; 201dc44acc9SUday Bondhugula 202dc44acc9SUday Bondhugula // Check if 'memref' is used by a non-deferencing op (including unknown ones) 203dc44acc9SUday Bondhugula // (e.g., call ops, alias creating ops, etc.). 2046b2e29c5SUday Bondhugula return llvm::any_of(memref.getUsers(), [&](Operation *user) { 205fe9d0a47SUday Bondhugula // Ignore users outside of `block`. 
206eaa32d20Slong.chen Operation *ancestorOp = block->getParent()->findAncestorOpInRegion(*user); 207eaa32d20Slong.chen if (!ancestorOp) 208eaa32d20Slong.chen return true; 209eaa32d20Slong.chen if (ancestorOp->getBlock() != block) 210825b23a7SUday Bondhugula return false; 2116b2e29c5SUday Bondhugula return !isa<AffineMapAccessInterface>(*user); 2126b2e29c5SUday Bondhugula }); 213825b23a7SUday Bondhugula } 214825b23a7SUday Bondhugula 215a70aa7bbSRiver Riddle /// Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id' 216fe9d0a47SUday Bondhugula /// that escape the block or are accessed in a non-affine way. 217c910570fSUday Bondhugula static void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg, 218a70aa7bbSRiver Riddle DenseSet<Value> &escapingMemRefs) { 219a70aa7bbSRiver Riddle auto *node = mdg->getNode(id); 220825b23a7SUday Bondhugula for (Operation *storeOp : node->stores) { 221825b23a7SUday Bondhugula auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef(); 222a70aa7bbSRiver Riddle if (escapingMemRefs.count(memref)) 223a70aa7bbSRiver Riddle continue; 224fe9d0a47SUday Bondhugula if (isEscapingMemref(memref, &mdg->block)) 225a70aa7bbSRiver Riddle escapingMemRefs.insert(memref); 226a70aa7bbSRiver Riddle } 227a70aa7bbSRiver Riddle } 228a70aa7bbSRiver Riddle 229a70aa7bbSRiver Riddle // Sinks all sequential loops to the innermost levels (while preserving 230a70aa7bbSRiver Riddle // relative order among them) and moves all parallel loops to the 231a70aa7bbSRiver Riddle // outermost (while again preserving relative order among them). 232a70aa7bbSRiver Riddle // This can increase the loop depth at which we can fuse a slice, since we are 233a70aa7bbSRiver Riddle // pushing loop carried dependence to a greater depth in the loop nest. 
234a70aa7bbSRiver Riddle static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) { 235a70aa7bbSRiver Riddle assert(isa<AffineForOp>(node->op)); 236a70aa7bbSRiver Riddle AffineForOp newRootForOp = sinkSequentialLoops(cast<AffineForOp>(node->op)); 237825b23a7SUday Bondhugula node->op = newRootForOp; 238a70aa7bbSRiver Riddle } 239a70aa7bbSRiver Riddle 240a70aa7bbSRiver Riddle // Creates and returns a private (single-user) memref for fused loop rooted 241a70aa7bbSRiver Riddle // at 'forOp', with (potentially reduced) memref size based on the 242a70aa7bbSRiver Riddle // MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'. 243a70aa7bbSRiver Riddle // TODO: consider refactoring the common code from generateDma and 244a70aa7bbSRiver Riddle // this one. 245a70aa7bbSRiver Riddle static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst, 246a70aa7bbSRiver Riddle unsigned dstLoopDepth, 2470a81ace0SKazu Hirata std::optional<unsigned> fastMemorySpace, 248a70aa7bbSRiver Riddle uint64_t localBufSizeThreshold) { 249825b23a7SUday Bondhugula Operation *forInst = forOp.getOperation(); 250a70aa7bbSRiver Riddle 251a70aa7bbSRiver Riddle // Create builder to insert alloc op just before 'forOp'. 252a70aa7bbSRiver Riddle OpBuilder b(forInst); 253a70aa7bbSRiver Riddle // Builder to create constants at the top level. 254fe9d0a47SUday Bondhugula OpBuilder top(forInst->getParentRegion()); 255a70aa7bbSRiver Riddle // Create new memref type based on slice bounds. 256a70aa7bbSRiver Riddle auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOpInst).getMemRef(); 2575550c821STres Popp auto oldMemRefType = cast<MemRefType>(oldMemRef.getType()); 258a70aa7bbSRiver Riddle unsigned rank = oldMemRefType.getRank(); 259a70aa7bbSRiver Riddle 260a70aa7bbSRiver Riddle // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'. 
261a70aa7bbSRiver Riddle MemRefRegion region(srcStoreOpInst->getLoc()); 262a70aa7bbSRiver Riddle bool validRegion = succeeded(region.compute(srcStoreOpInst, dstLoopDepth)); 263a70aa7bbSRiver Riddle (void)validRegion; 264a70aa7bbSRiver Riddle assert(validRegion && "unexpected memref region failure"); 265a70aa7bbSRiver Riddle SmallVector<int64_t, 4> newShape; 266a70aa7bbSRiver Riddle std::vector<SmallVector<int64_t, 4>> lbs; 267a70aa7bbSRiver Riddle SmallVector<int64_t, 8> lbDivisors; 268a70aa7bbSRiver Riddle lbs.reserve(rank); 269a70aa7bbSRiver Riddle // Query 'region' for 'newShape' and lower bounds of MemRefRegion accessed 270a70aa7bbSRiver Riddle // by 'srcStoreOpInst' at depth 'dstLoopDepth'. 2710a81ace0SKazu Hirata std::optional<int64_t> numElements = 272a70aa7bbSRiver Riddle region.getConstantBoundingSizeAndShape(&newShape, &lbs, &lbDivisors); 273ad7ce1e7SKazu Hirata assert(numElements && "non-constant number of elts in local buffer"); 274a70aa7bbSRiver Riddle 275a70aa7bbSRiver Riddle const FlatAffineValueConstraints *cst = region.getConstraints(); 276a70aa7bbSRiver Riddle // 'outerIVs' holds the values that this memory region is symbolic/parametric 277a70aa7bbSRiver Riddle // on; this would correspond to loop IVs surrounding the level at which the 278a70aa7bbSRiver Riddle // slice is being materialized. 
279a70aa7bbSRiver Riddle SmallVector<Value, 8> outerIVs; 280d95140a5SGroverkss cst->getValues(rank, cst->getNumVars(), &outerIVs); 281a70aa7bbSRiver Riddle 282a70aa7bbSRiver Riddle // Build 'rank' AffineExprs from MemRefRegion 'lbs' 283a70aa7bbSRiver Riddle SmallVector<AffineExpr, 4> offsets; 284a70aa7bbSRiver Riddle offsets.reserve(rank); 285a70aa7bbSRiver Riddle for (unsigned d = 0; d < rank; ++d) { 286a70aa7bbSRiver Riddle assert(lbs[d].size() == cst->getNumCols() - rank && "incorrect bound size"); 287a70aa7bbSRiver Riddle 288a70aa7bbSRiver Riddle AffineExpr offset = top.getAffineConstantExpr(0); 289a70aa7bbSRiver Riddle for (unsigned j = 0, e = cst->getNumCols() - rank - 1; j < e; j++) { 290a70aa7bbSRiver Riddle offset = offset + lbs[d][j] * top.getAffineDimExpr(j); 291a70aa7bbSRiver Riddle } 292a70aa7bbSRiver Riddle assert(lbDivisors[d] > 0); 293a70aa7bbSRiver Riddle offset = 294a70aa7bbSRiver Riddle (offset + lbs[d][cst->getNumCols() - 1 - rank]).floorDiv(lbDivisors[d]); 295a70aa7bbSRiver Riddle offsets.push_back(offset); 296a70aa7bbSRiver Riddle } 297a70aa7bbSRiver Riddle 298a70aa7bbSRiver Riddle // Create 'newMemRefType' using 'newShape' from MemRefRegion accessed 299a70aa7bbSRiver Riddle // by 'srcStoreOpInst'. 
300d25e022cSUday Bondhugula auto eltSize = getMemRefIntOrFloatEltSizeInBytes(oldMemRefType); 301d25e022cSUday Bondhugula assert(eltSize && "memrefs with size elt types expected"); 302d25e022cSUday Bondhugula uint64_t bufSize = *eltSize * *numElements; 303a70aa7bbSRiver Riddle unsigned newMemSpace; 304491d2701SKazu Hirata if (bufSize <= localBufSizeThreshold && fastMemorySpace.has_value()) { 3054913e5daSFangrui Song newMemSpace = *fastMemorySpace; 306a70aa7bbSRiver Riddle } else { 307a70aa7bbSRiver Riddle newMemSpace = oldMemRefType.getMemorySpaceAsInt(); 308a70aa7bbSRiver Riddle } 309a70aa7bbSRiver Riddle auto newMemRefType = MemRefType::get(newShape, oldMemRefType.getElementType(), 310a70aa7bbSRiver Riddle {}, newMemSpace); 311a70aa7bbSRiver Riddle 312a70aa7bbSRiver Riddle // Create new private memref for fused loop 'forOp'. 'newShape' is always 313a70aa7bbSRiver Riddle // a constant shape. 314a70aa7bbSRiver Riddle // TODO: Create/move alloc ops for private memrefs closer to their 315a70aa7bbSRiver Riddle // consumer loop nests to reduce their live range. Currently they are added 316fe9d0a47SUday Bondhugula // at the beginning of the block, because loop nests can be reordered 317a70aa7bbSRiver Riddle // during the fusion pass. 318a70aa7bbSRiver Riddle Value newMemRef = top.create<memref::AllocOp>(forOp.getLoc(), newMemRefType); 319a70aa7bbSRiver Riddle 320a70aa7bbSRiver Riddle // Build an AffineMap to remap access functions based on lower bound offsets. 
321a70aa7bbSRiver Riddle SmallVector<AffineExpr, 4> remapExprs; 322a70aa7bbSRiver Riddle remapExprs.reserve(rank); 323a70aa7bbSRiver Riddle for (unsigned i = 0; i < rank; i++) { 324a70aa7bbSRiver Riddle auto dimExpr = b.getAffineDimExpr(outerIVs.size() + i); 325a70aa7bbSRiver Riddle 326a70aa7bbSRiver Riddle auto remapExpr = 327a70aa7bbSRiver Riddle simplifyAffineExpr(dimExpr - offsets[i], outerIVs.size() + rank, 0); 328a70aa7bbSRiver Riddle remapExprs.push_back(remapExpr); 329a70aa7bbSRiver Riddle } 330a70aa7bbSRiver Riddle 331a70aa7bbSRiver Riddle auto indexRemap = 332a70aa7bbSRiver Riddle AffineMap::get(outerIVs.size() + rank, 0, remapExprs, forOp.getContext()); 333a70aa7bbSRiver Riddle 334a70aa7bbSRiver Riddle // Replace all users of 'oldMemRef' with 'newMemRef'. 335a70aa7bbSRiver Riddle LogicalResult res = 336a70aa7bbSRiver Riddle replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap, 337a70aa7bbSRiver Riddle /*extraOperands=*/outerIVs, 338a70aa7bbSRiver Riddle /*symbolOperands=*/{}, 339a70aa7bbSRiver Riddle /*domOpFilter=*/&*forOp.getBody()->begin()); 340a70aa7bbSRiver Riddle assert(succeeded(res) && 341a70aa7bbSRiver Riddle "replaceAllMemrefUsesWith should always succeed here"); 342a70aa7bbSRiver Riddle (void)res; 343a70aa7bbSRiver Riddle return newMemRef; 344a70aa7bbSRiver Riddle } 345a70aa7bbSRiver Riddle 346*469f9d5fSUday Bondhugula /// Returns true if there are any non-affine uses of `memref` in any of 347*469f9d5fSUday Bondhugula /// the operations between `start` and `end` (both exclusive). Any other 348*469f9d5fSUday Bondhugula /// than affine read/write are treated as non-affine uses of `memref`. 
349*469f9d5fSUday Bondhugula static bool hasNonAffineUsersOnPath(Operation *start, Operation *end, 350*469f9d5fSUday Bondhugula Value memref) { 351*469f9d5fSUday Bondhugula assert(start->getBlock() == end->getBlock()); 352*469f9d5fSUday Bondhugula assert(start->isBeforeInBlock(end) && "start expected to be before end"); 353*469f9d5fSUday Bondhugula Block *block = start->getBlock(); 354*469f9d5fSUday Bondhugula // Check if there is a non-affine memref user in any op between `start` and 355*469f9d5fSUday Bondhugula // `end`. 356*469f9d5fSUday Bondhugula return llvm::any_of(memref.getUsers(), [&](Operation *user) { 357*469f9d5fSUday Bondhugula if (isa<AffineReadOpInterface, AffineWriteOpInterface>(user)) 358a70aa7bbSRiver Riddle return false; 359*469f9d5fSUday Bondhugula Operation *ancestor = block->findAncestorOpInBlock(*user); 360*469f9d5fSUday Bondhugula return ancestor && start->isBeforeInBlock(ancestor) && 361*469f9d5fSUday Bondhugula ancestor->isBeforeInBlock(end); 362*469f9d5fSUday Bondhugula }); 363a70aa7bbSRiver Riddle } 364a70aa7bbSRiver Riddle 365*469f9d5fSUday Bondhugula /// Check whether a memref value used in any operation of 'src' has a 366*469f9d5fSUday Bondhugula /// non-affine operation that is between `src` and `end` (exclusive of `src` 367*469f9d5fSUday Bondhugula /// and `end`) where `src` and `end` are expected to be in the same Block. 368*469f9d5fSUday Bondhugula /// Any other than affine read/write are treated as non-affine uses of memref. 369*469f9d5fSUday Bondhugula static bool hasNonAffineUsersOnPath(Operation *src, Operation *end) { 370*469f9d5fSUday Bondhugula assert(src->getBlock() == end->getBlock() && "same block expected"); 371*469f9d5fSUday Bondhugula 372*469f9d5fSUday Bondhugula // Trivial case. `src` and `end` are exclusive. 373*469f9d5fSUday Bondhugula if (src == end || end->isBeforeInBlock(src)) 374*469f9d5fSUday Bondhugula return false; 375*469f9d5fSUday Bondhugula 376*469f9d5fSUday Bondhugula // Collect relevant memref values. 
377a70aa7bbSRiver Riddle llvm::SmallDenseSet<Value, 2> memRefValues; 378*469f9d5fSUday Bondhugula src->walk([&](Operation *op) { 379a70aa7bbSRiver Riddle for (Value v : op->getOperands()) 380a70aa7bbSRiver Riddle // Collect memref values only. 3815550c821STres Popp if (isa<MemRefType>(v.getType())) 382a70aa7bbSRiver Riddle memRefValues.insert(v); 383a70aa7bbSRiver Riddle return WalkResult::advance(); 384a70aa7bbSRiver Riddle }); 385*469f9d5fSUday Bondhugula // Look for non-affine users between `src` and `end`. 3866b2e29c5SUday Bondhugula return llvm::any_of(memRefValues, [&](Value memref) { 387*469f9d5fSUday Bondhugula return hasNonAffineUsersOnPath(src, end, memref); 3886b2e29c5SUday Bondhugula }); 389a70aa7bbSRiver Riddle } 390a70aa7bbSRiver Riddle 391a70aa7bbSRiver Riddle // Checks the profitability of fusing a backwards slice of the loop nest 392a70aa7bbSRiver Riddle // surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'. 393a70aa7bbSRiver Riddle // The argument 'srcStoreOpInst' is used to calculate the storage reduction on 394a70aa7bbSRiver Riddle // the memref being produced and consumed, which is an input to the cost model. 395a70aa7bbSRiver Riddle // For producer-consumer fusion, 'srcStoreOpInst' will be the same as 396a70aa7bbSRiver Riddle // 'srcOpInst', as we are slicing w.r.t to that producer. For input-reuse 397a70aa7bbSRiver Riddle // fusion, 'srcOpInst' will be the src loop nest LoadOp which reads from the 398a70aa7bbSRiver Riddle // same memref as dst loop nest load ops, and 'srcStoreOpInst' will be the 399a70aa7bbSRiver Riddle // unique store op in the src node, which will be used to check that the write 400a70aa7bbSRiver Riddle // region is the same after input-reuse fusion. Computation slices are provided 401a70aa7bbSRiver Riddle // in 'depthSliceUnions' for each legal fusion depth. The maximal depth at which 402a70aa7bbSRiver Riddle // fusion is legal is provided in 'maxLegalFusionDepth'. 
Returns true if it is 403a70aa7bbSRiver Riddle // profitable to fuse the candidate loop nests. Returns false otherwise. 404a70aa7bbSRiver Riddle // `dstLoopDepth` is set to the most profitable depth at which to materialize 405a70aa7bbSRiver Riddle // the source loop nest slice. 406a70aa7bbSRiver Riddle // The profitability model executes the following steps: 407a70aa7bbSRiver Riddle // *) Computes the backward computation slice at 'srcOpInst'. This 408a70aa7bbSRiver Riddle // computation slice of the loop nest surrounding 'srcOpInst' is 409a70aa7bbSRiver Riddle // represented by modified src loop bounds in 'sliceState', which are 410a70aa7bbSRiver Riddle // functions of loop IVs in the loop nest surrounding 'srcOpInst'. 411a70aa7bbSRiver Riddle // *) Computes the cost of unfused src/dst loop nests (currently the cost of a 412a70aa7bbSRiver Riddle // loop nest is the total number of dynamic operation instances in the loop 413a70aa7bbSRiver Riddle // nest). 414a70aa7bbSRiver Riddle // *) Computes the cost of fusing a slice of the src loop nest into the dst 415a70aa7bbSRiver Riddle // loop nest at various values of dst loop depth, attempting to fuse 416a70aa7bbSRiver Riddle // the largest computation slice at the maximal dst loop depth (closest to 417a70aa7bbSRiver Riddle // the load) to minimize reuse distance and potentially enable subsequent 418a70aa7bbSRiver Riddle // load/store forwarding. 419a70aa7bbSRiver Riddle // NOTE: 'dstLoopDepth' refers to the loop depth within the destination loop 420a70aa7bbSRiver Riddle // nest, at which the src computation slice is inserted/fused. 421a70aa7bbSRiver Riddle // NOTE: We attempt to maximize the dst loop depth, but there are cases 422a70aa7bbSRiver Riddle // where a particular setting for 'dstLoopNest' might fuse an unsliced 423a70aa7bbSRiver Riddle // loop (within the src computation slice) at a depth which results in 424a70aa7bbSRiver Riddle // excessive recomputation (see unit tests for examples). 
425a70aa7bbSRiver Riddle // *) Compares the total cost of the unfused loop nests to the min cost fused 426a70aa7bbSRiver Riddle // loop nest computed in the previous step, and returns true if the latter 427a70aa7bbSRiver Riddle // is lower. 428a70aa7bbSRiver Riddle // TODO: Extend profitability analysis to support scenarios with multiple 429a70aa7bbSRiver Riddle // stores. 430a70aa7bbSRiver Riddle static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, 431a70aa7bbSRiver Riddle AffineForOp dstForOp, 432a70aa7bbSRiver Riddle ArrayRef<ComputationSliceState> depthSliceUnions, 433a70aa7bbSRiver Riddle unsigned maxLegalFusionDepth, 434a70aa7bbSRiver Riddle unsigned *dstLoopDepth, 435a70aa7bbSRiver Riddle double computeToleranceThreshold) { 436a70aa7bbSRiver Riddle LLVM_DEBUG({ 437a70aa7bbSRiver Riddle llvm::dbgs() << "Checking whether fusion is profitable between src op:\n"; 438a70aa7bbSRiver Riddle llvm::dbgs() << ' ' << *srcOpInst << " and destination loop:\n"; 439a70aa7bbSRiver Riddle llvm::dbgs() << dstForOp << "\n"; 440a70aa7bbSRiver Riddle }); 441a70aa7bbSRiver Riddle 442a70aa7bbSRiver Riddle if (maxLegalFusionDepth == 0) { 4433ab88e79SUday Bondhugula LLVM_DEBUG(llvm::dbgs() << "Can't fuse: maxLegalFusionDepth is 0\n"); 444a70aa7bbSRiver Riddle return false; 445a70aa7bbSRiver Riddle } 446a70aa7bbSRiver Riddle 447a70aa7bbSRiver Riddle // Compute cost of sliced and unsliced src loop nest. 448a70aa7bbSRiver Riddle SmallVector<AffineForOp, 4> srcLoopIVs; 44923bcd6b8SUday Bondhugula getAffineForIVs(*srcOpInst, &srcLoopIVs); 450a70aa7bbSRiver Riddle 451a70aa7bbSRiver Riddle // Walk src loop nest and collect stats. 452a70aa7bbSRiver Riddle LoopNestStats srcLoopNestStats; 453a70aa7bbSRiver Riddle if (!getLoopNestStats(srcLoopIVs[0], &srcLoopNestStats)) 454a70aa7bbSRiver Riddle return false; 455a70aa7bbSRiver Riddle 456a70aa7bbSRiver Riddle // Compute cost of dst loop nest. 
457a70aa7bbSRiver Riddle LoopNestStats dstLoopNestStats; 458a70aa7bbSRiver Riddle if (!getLoopNestStats(dstForOp, &dstLoopNestStats)) 459a70aa7bbSRiver Riddle return false; 460a70aa7bbSRiver Riddle 461a70aa7bbSRiver Riddle // Search for min cost value for 'dstLoopDepth'. At each value of 462a70aa7bbSRiver Riddle // 'dstLoopDepth' from 'maxLegalLoopDepth' to '1', compute computation slice 463a70aa7bbSRiver Riddle // bounds between 'srcOpInst' and each op in 'dstOpinsts' (taking the union 464a70aa7bbSRiver Riddle // of these bounds). Next the union slice bounds are used to calculate 465a70aa7bbSRiver Riddle // the cost of the slice and the cost of the slice inserted into the dst 466a70aa7bbSRiver Riddle // loop nest at 'dstLoopDepth'. 467a70aa7bbSRiver Riddle uint64_t minFusedLoopNestComputeCost = std::numeric_limits<uint64_t>::max(); 468a70aa7bbSRiver Riddle double maxStorageReduction = 0.0; 4698ada9cf7SKazu Hirata std::optional<uint64_t> sliceMemEstimate; 470a70aa7bbSRiver Riddle 471a70aa7bbSRiver Riddle // The best loop depth at which to materialize the slice. 4728ada9cf7SKazu Hirata std::optional<unsigned> bestDstLoopDepth; 473a70aa7bbSRiver Riddle 474a70aa7bbSRiver Riddle // Compute op instance count for the src loop nest without iteration slicing. 475a70aa7bbSRiver Riddle uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], srcLoopNestStats); 476a70aa7bbSRiver Riddle 477a70aa7bbSRiver Riddle // Compute src loop nest write region size. 
478a70aa7bbSRiver Riddle MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc()); 479a70aa7bbSRiver Riddle if (failed(srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0))) { 480a70aa7bbSRiver Riddle LLVM_DEBUG(llvm::dbgs() 4813ab88e79SUday Bondhugula << "Unable to compute MemRefRegion for source operation\n"); 482a70aa7bbSRiver Riddle return false; 483a70aa7bbSRiver Riddle } 484a70aa7bbSRiver Riddle 4850a81ace0SKazu Hirata std::optional<int64_t> maybeSrcWriteRegionSizeBytes = 486a70aa7bbSRiver Riddle srcWriteRegion.getRegionSize(); 487491d2701SKazu Hirata if (!maybeSrcWriteRegionSizeBytes.has_value()) 488a70aa7bbSRiver Riddle return false; 4894913e5daSFangrui Song int64_t srcWriteRegionSizeBytes = *maybeSrcWriteRegionSizeBytes; 490a70aa7bbSRiver Riddle 491a70aa7bbSRiver Riddle // Compute op instance count for the src loop nest. 492a70aa7bbSRiver Riddle uint64_t dstLoopNestCost = getComputeCost(dstForOp, dstLoopNestStats); 493a70aa7bbSRiver Riddle 494a70aa7bbSRiver Riddle // Evaluate all depth choices for materializing the slice in the destination 495a70aa7bbSRiver Riddle // loop nest. 496a70aa7bbSRiver Riddle for (unsigned i = maxLegalFusionDepth; i >= 1; --i) { 497a70aa7bbSRiver Riddle const ComputationSliceState &slice = depthSliceUnions[i - 1]; 498a70aa7bbSRiver Riddle // Skip slice union if it wasn't computed for this depth. 
499a70aa7bbSRiver Riddle if (slice.isEmpty()) 500a70aa7bbSRiver Riddle continue; 501a70aa7bbSRiver Riddle 502a70aa7bbSRiver Riddle int64_t fusedLoopNestComputeCost; 503a70aa7bbSRiver Riddle if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstForOp, 504a70aa7bbSRiver Riddle dstLoopNestStats, slice, 505a70aa7bbSRiver Riddle &fusedLoopNestComputeCost)) { 5063ab88e79SUday Bondhugula LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost\n"); 507a70aa7bbSRiver Riddle continue; 508a70aa7bbSRiver Riddle } 509a70aa7bbSRiver Riddle 510a70aa7bbSRiver Riddle double additionalComputeFraction = 511a70aa7bbSRiver Riddle fusedLoopNestComputeCost / 512a70aa7bbSRiver Riddle (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) - 513a70aa7bbSRiver Riddle 1; 514a70aa7bbSRiver Riddle 515a70aa7bbSRiver Riddle // Determine what the slice write MemRefRegion would be, if the src loop 516a70aa7bbSRiver Riddle // nest slice 'slice' were to be inserted into the dst loop nest at loop 517a70aa7bbSRiver Riddle // depth 'i'. 
518a70aa7bbSRiver Riddle MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc()); 519a70aa7bbSRiver Riddle if (failed(sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0, 520a70aa7bbSRiver Riddle &slice))) { 521a70aa7bbSRiver Riddle LLVM_DEBUG(llvm::dbgs() 522a70aa7bbSRiver Riddle << "Failed to compute slice write region at loopDepth: " << i 523a70aa7bbSRiver Riddle << "\n"); 524a70aa7bbSRiver Riddle continue; 525a70aa7bbSRiver Riddle } 526a70aa7bbSRiver Riddle 5270a81ace0SKazu Hirata std::optional<int64_t> maybeSliceWriteRegionSizeBytes = 528a70aa7bbSRiver Riddle sliceWriteRegion.getRegionSize(); 529491d2701SKazu Hirata if (!maybeSliceWriteRegionSizeBytes.has_value() || 5304913e5daSFangrui Song *maybeSliceWriteRegionSizeBytes == 0) { 531a70aa7bbSRiver Riddle LLVM_DEBUG(llvm::dbgs() 532a70aa7bbSRiver Riddle << "Failed to get slice write region size at loopDepth: " << i 533a70aa7bbSRiver Riddle << "\n"); 534a70aa7bbSRiver Riddle continue; 535a70aa7bbSRiver Riddle } 5364913e5daSFangrui Song int64_t sliceWriteRegionSizeBytes = *maybeSliceWriteRegionSizeBytes; 537a70aa7bbSRiver Riddle 538a70aa7bbSRiver Riddle // If we are fusing for reuse, check that write regions remain the same. 539a70aa7bbSRiver Riddle // TODO: Write region check should check sizes and offsets in 540a70aa7bbSRiver Riddle // each dimension, so that we are sure they are covering the same memref 541a70aa7bbSRiver Riddle // region. Also, move this out to a isMemRefRegionSuperSet helper function. 
542a70aa7bbSRiver Riddle if (srcOpInst != srcStoreOpInst && 543a70aa7bbSRiver Riddle sliceWriteRegionSizeBytes != srcWriteRegionSizeBytes) 544a70aa7bbSRiver Riddle continue; 545a70aa7bbSRiver Riddle 546a70aa7bbSRiver Riddle double storageReduction = static_cast<double>(srcWriteRegionSizeBytes) / 547a70aa7bbSRiver Riddle static_cast<double>(sliceWriteRegionSizeBytes); 548a70aa7bbSRiver Riddle 549a70aa7bbSRiver Riddle LLVM_DEBUG({ 550a70aa7bbSRiver Riddle std::stringstream msg; 551a70aa7bbSRiver Riddle msg << " evaluating fusion profitability at depth : " << i << "\n" 552a70aa7bbSRiver Riddle << std::fixed << std::setprecision(2) 553a70aa7bbSRiver Riddle << " additional compute fraction: " 554a70aa7bbSRiver Riddle << 100.0 * additionalComputeFraction << "%\n" 555a70aa7bbSRiver Riddle << " storage reduction factor: " << storageReduction << "x\n" 556a70aa7bbSRiver Riddle << " fused nest cost: " << fusedLoopNestComputeCost << "\n" 557a70aa7bbSRiver Riddle << " src write region size: " << srcWriteRegionSizeBytes << "\n" 558a70aa7bbSRiver Riddle << " slice write region size: " << sliceWriteRegionSizeBytes 559a70aa7bbSRiver Riddle << "\n"; 560a70aa7bbSRiver Riddle llvm::dbgs() << msg.str(); 561a70aa7bbSRiver Riddle }); 562a70aa7bbSRiver Riddle 563a70aa7bbSRiver Riddle // TODO: This is a placeholder cost model. 564a70aa7bbSRiver Riddle // Among all choices that add an acceptable amount of redundant computation 565a70aa7bbSRiver Riddle // (as per computeToleranceThreshold), we will simply pick the one that 566a70aa7bbSRiver Riddle // reduces the intermediary size the most. 
567a70aa7bbSRiver Riddle if ((storageReduction > maxStorageReduction) && 568a70aa7bbSRiver Riddle (additionalComputeFraction < computeToleranceThreshold)) { 569a70aa7bbSRiver Riddle maxStorageReduction = storageReduction; 570a70aa7bbSRiver Riddle bestDstLoopDepth = i; 571a70aa7bbSRiver Riddle minFusedLoopNestComputeCost = fusedLoopNestComputeCost; 572a70aa7bbSRiver Riddle sliceMemEstimate = sliceWriteRegionSizeBytes; 573a70aa7bbSRiver Riddle } 574a70aa7bbSRiver Riddle } 575a70aa7bbSRiver Riddle 576a70aa7bbSRiver Riddle // A simple cost model: fuse if it reduces the memory footprint. 577a70aa7bbSRiver Riddle 578037f0995SKazu Hirata if (!bestDstLoopDepth) { 579a70aa7bbSRiver Riddle LLVM_DEBUG( 580a70aa7bbSRiver Riddle llvm::dbgs() 581a70aa7bbSRiver Riddle << "All fusion choices involve more than the threshold amount of " 582a70aa7bbSRiver Riddle "redundant computation; NOT fusing.\n"); 583a70aa7bbSRiver Riddle return false; 584a70aa7bbSRiver Riddle } 585a70aa7bbSRiver Riddle 586037f0995SKazu Hirata if (!bestDstLoopDepth) { 587a70aa7bbSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "no fusion depth could be evaluated.\n"); 588a70aa7bbSRiver Riddle return false; 589a70aa7bbSRiver Riddle } 590a70aa7bbSRiver Riddle 591a70aa7bbSRiver Riddle // Set dstLoopDepth based on best values from search. 
5926d5fc1e3SKazu Hirata *dstLoopDepth = *bestDstLoopDepth; 593a70aa7bbSRiver Riddle 594a70aa7bbSRiver Riddle LLVM_DEBUG( 595a70aa7bbSRiver Riddle llvm::dbgs() << " LoopFusion fusion stats:" 596a70aa7bbSRiver Riddle << "\n best loop depth: " << bestDstLoopDepth 597a70aa7bbSRiver Riddle << "\n src loop nest compute cost: " << srcLoopNestCost 598a70aa7bbSRiver Riddle << "\n dst loop nest compute cost: " << dstLoopNestCost 599a70aa7bbSRiver Riddle << "\n fused loop nest compute cost: " 600a70aa7bbSRiver Riddle << minFusedLoopNestComputeCost << "\n"); 601a70aa7bbSRiver Riddle 602a70aa7bbSRiver Riddle auto dstMemSize = getMemoryFootprintBytes(dstForOp); 603a70aa7bbSRiver Riddle auto srcMemSize = getMemoryFootprintBytes(srcLoopIVs[0]); 604a70aa7bbSRiver Riddle 6058ada9cf7SKazu Hirata std::optional<double> storageReduction; 606a70aa7bbSRiver Riddle 607037f0995SKazu Hirata if (!dstMemSize || !srcMemSize) { 608a70aa7bbSRiver Riddle LLVM_DEBUG(llvm::dbgs() 609a70aa7bbSRiver Riddle << " fusion memory benefit cannot be evaluated; NOT fusing.\n"); 610a70aa7bbSRiver Riddle return false; 611a70aa7bbSRiver Riddle } 612a70aa7bbSRiver Riddle 6134913e5daSFangrui Song auto srcMemSizeVal = *srcMemSize; 6144913e5daSFangrui Song auto dstMemSizeVal = *dstMemSize; 615a70aa7bbSRiver Riddle 6165413bf1bSKazu Hirata assert(sliceMemEstimate && "expected value"); 6174913e5daSFangrui Song auto fusedMem = dstMemSizeVal + *sliceMemEstimate; 618a70aa7bbSRiver Riddle 619a70aa7bbSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " src mem: " << srcMemSizeVal << "\n" 620a70aa7bbSRiver Riddle << " dst mem: " << dstMemSizeVal << "\n" 621a70aa7bbSRiver Riddle << " fused mem: " << fusedMem << "\n" 622a70aa7bbSRiver Riddle << " slice mem: " << sliceMemEstimate << "\n"); 623a70aa7bbSRiver Riddle 624a70aa7bbSRiver Riddle if (static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) { 625a70aa7bbSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n"); 626a70aa7bbSRiver Riddle return false; 
627a70aa7bbSRiver Riddle } 628a70aa7bbSRiver Riddle storageReduction = 629a70aa7bbSRiver Riddle 100.0 * 630a70aa7bbSRiver Riddle (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal)); 631a70aa7bbSRiver Riddle 632a70aa7bbSRiver Riddle double additionalComputeFraction = 633a70aa7bbSRiver Riddle 100.0 * (minFusedLoopNestComputeCost / 634a70aa7bbSRiver Riddle (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) - 635a70aa7bbSRiver Riddle 1); 636a70aa7bbSRiver Riddle (void)additionalComputeFraction; 637a70aa7bbSRiver Riddle LLVM_DEBUG({ 638a70aa7bbSRiver Riddle std::stringstream msg; 639a70aa7bbSRiver Riddle msg << " fusion is most profitable at depth " << *dstLoopDepth << " with " 640a70aa7bbSRiver Riddle << std::setprecision(2) << additionalComputeFraction 641a70aa7bbSRiver Riddle << "% redundant computation and a "; 642d66cbc56SKazu Hirata msg << (storageReduction ? std::to_string(*storageReduction) : "<unknown>"); 643a70aa7bbSRiver Riddle msg << "% storage reduction.\n"; 644a70aa7bbSRiver Riddle llvm::dbgs() << msg.str(); 645a70aa7bbSRiver Riddle }); 646a70aa7bbSRiver Riddle 647a70aa7bbSRiver Riddle return true; 648a70aa7bbSRiver Riddle } 649a70aa7bbSRiver Riddle 650a70aa7bbSRiver Riddle namespace { 651a70aa7bbSRiver Riddle 652a70aa7bbSRiver Riddle // GreedyFusion greedily fuses loop nests which have a producer/consumer or 653a70aa7bbSRiver Riddle // input-reuse relationship on a memref, with the goal of improving locality. 654a70aa7bbSRiver Riddle // 655a70aa7bbSRiver Riddle // The steps of the producer-consumer fusion algorithm are as follows: 656a70aa7bbSRiver Riddle // 657a70aa7bbSRiver Riddle // *) A worklist is initialized with node ids from the dependence graph. 658a70aa7bbSRiver Riddle // *) For each node id in the worklist: 659a70aa7bbSRiver Riddle // *) Pop an AffineForOp of the worklist. This 'dstAffineForOp' will be a 660a70aa7bbSRiver Riddle // candidate destination AffineForOp into which fusion will be attempted. 
//    *) Add each LoadOp currently in 'dstAffineForOp' into list 'dstLoadOps'.
//    *) For each LoadOp in 'dstLoadOps' do:
//       *) Look up dependent loop nests which have a single store op to the
//          same memref.
//       *) Check if dependences would be violated by the fusion.
//       *) Get a computation slice of 'srcLoopNest', which adjusts its loop
//          bounds to be functions of 'dstLoopNest' IVs and symbols.
//       *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest',
//          at a loop depth determined by the cost model in
//          'isFusionProfitable'.
//       *) Add the newly fused load/store operations to the state,
//          and also add newly fused load ops to 'dstLoadOps' to be considered
//          as fusion dst load ops in another iteration.
//       *) Remove old src loop nest and its associated state.
//
// The steps of the input-reuse fusion algorithm are as follows:
//
// *) Initialize 'worklist' with node ids from the dependence graph.
// *) For each 'dstNode' in the worklist:
//    *) Find a candidate sibling node 'sibNode' to fuse with 'dstNode' which
//       loads from the same memref, but which has no dependence paths to/from.
//    *) Get a computation slice of 'sibLoopNest', which adjusts its loop
//       bounds to be functions of 'dstLoopNest' IVs and symbols.
//    *) Fuse the 'sibLoopNest' computation slice into the 'dstLoopNest',
//       at a loop depth determined by the cost model in 'isFusionProfitable'.
685a70aa7bbSRiver Riddle // This function also checks that the memref write region of 'sibLoopNest', 686a70aa7bbSRiver Riddle // is preserved in the fused loop nest. 687a70aa7bbSRiver Riddle // *) Update graph state to reflect the fusion of 'sibNode' into 'dstNode'. 688a70aa7bbSRiver Riddle // 689a70aa7bbSRiver Riddle // Given a graph where top-level operations are vertices in the set 'V' and 690a70aa7bbSRiver Riddle // edges in the set 'E' are dependences between vertices, this algorithm 691a70aa7bbSRiver Riddle // takes O(V) time for initialization, and has runtime O(V + E). 692a70aa7bbSRiver Riddle // 693a70aa7bbSRiver Riddle // This greedy algorithm is not 'maximal' due to the current restriction of 694a70aa7bbSRiver Riddle // fusing along single producer consumer edges, but there is a TODO: to fix 695a70aa7bbSRiver Riddle // this. 696a70aa7bbSRiver Riddle // 697a70aa7bbSRiver Riddle // TODO: Experiment with other fusion policies. 698a70aa7bbSRiver Riddle struct GreedyFusion { 699a70aa7bbSRiver Riddle public: 700a70aa7bbSRiver Riddle // The data dependence graph to traverse during fusion. 701a70aa7bbSRiver Riddle MemRefDependenceGraph *mdg; 702a70aa7bbSRiver Riddle // Worklist of graph nodes visited during the fusion pass. 703a70aa7bbSRiver Riddle SmallVector<unsigned, 8> worklist; 704a70aa7bbSRiver Riddle // Parameter for local buffer size threshold. 705a70aa7bbSRiver Riddle unsigned localBufSizeThreshold; 706a70aa7bbSRiver Riddle // Parameter for fast memory space. 7070a81ace0SKazu Hirata std::optional<unsigned> fastMemorySpace; 708a70aa7bbSRiver Riddle // If true, ignore any additional (redundant) computation tolerance threshold 709a70aa7bbSRiver Riddle // that would have prevented fusion. 710a70aa7bbSRiver Riddle bool maximalFusion; 711a70aa7bbSRiver Riddle // The amount of additional computation that is tolerated while fusing 712a70aa7bbSRiver Riddle // pair-wise as a fraction of the total computation. 
713a70aa7bbSRiver Riddle double computeToleranceThreshold; 714a70aa7bbSRiver Riddle 715a70aa7bbSRiver Riddle using Node = MemRefDependenceGraph::Node; 716a70aa7bbSRiver Riddle 717a70aa7bbSRiver Riddle GreedyFusion(MemRefDependenceGraph *mdg, unsigned localBufSizeThreshold, 7180a81ace0SKazu Hirata std::optional<unsigned> fastMemorySpace, bool maximalFusion, 719a70aa7bbSRiver Riddle double computeToleranceThreshold) 720a70aa7bbSRiver Riddle : mdg(mdg), localBufSizeThreshold(localBufSizeThreshold), 721a70aa7bbSRiver Riddle fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion), 722a70aa7bbSRiver Riddle computeToleranceThreshold(computeToleranceThreshold) {} 723a70aa7bbSRiver Riddle 724a70aa7bbSRiver Riddle /// Initializes 'worklist' with nodes from 'mdg'. 725a70aa7bbSRiver Riddle void init() { 726a70aa7bbSRiver Riddle // TODO: Add a priority queue for prioritizing nodes by different 727a70aa7bbSRiver Riddle // metrics (e.g. arithmetic intensity/flops-to-bytes ratio). 728a70aa7bbSRiver Riddle worklist.clear(); 729a70aa7bbSRiver Riddle for (auto &idAndNode : mdg->nodes) { 730a70aa7bbSRiver Riddle const Node &node = idAndNode.second; 731a70aa7bbSRiver Riddle worklist.push_back(node.id); 732a70aa7bbSRiver Riddle } 733a70aa7bbSRiver Riddle } 734a70aa7bbSRiver Riddle /// Run only sibling fusion on the `mdg`. 735a70aa7bbSRiver Riddle void runSiblingFusionOnly() { 736a70aa7bbSRiver Riddle fuseSiblingNodes(); 737a70aa7bbSRiver Riddle eraseUnusedMemRefAllocations(); 738a70aa7bbSRiver Riddle } 739a70aa7bbSRiver Riddle 740a70aa7bbSRiver Riddle /// Run only producer/consumer fusion on the `mdg`. 741a70aa7bbSRiver Riddle void runProducerConsumerFusionOnly() { 742a70aa7bbSRiver Riddle fuseProducerConsumerNodes( 743a70aa7bbSRiver Riddle /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max()); 744a70aa7bbSRiver Riddle eraseUnusedMemRefAllocations(); 745a70aa7bbSRiver Riddle } 746a70aa7bbSRiver Riddle 747a70aa7bbSRiver Riddle // Run the GreedyFusion pass. 
748a70aa7bbSRiver Riddle // *) First pass through the nodes fuses single-use producer nodes into their 749a70aa7bbSRiver Riddle // unique consumer. 750a70aa7bbSRiver Riddle // *) Second pass fuses sibling nodes which share no dependence edges. 751a70aa7bbSRiver Riddle // *) Third pass fuses any remaining producer nodes into their users. 752a70aa7bbSRiver Riddle void runGreedyFusion() { 753a70aa7bbSRiver Riddle // TODO: Run this repeatedly until a fixed-point is reached. 754a70aa7bbSRiver Riddle fuseProducerConsumerNodes(/*maxSrcUserCount=*/1); 755a70aa7bbSRiver Riddle fuseSiblingNodes(); 756a70aa7bbSRiver Riddle fuseProducerConsumerNodes( 757a70aa7bbSRiver Riddle /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max()); 758a70aa7bbSRiver Riddle eraseUnusedMemRefAllocations(); 759a70aa7bbSRiver Riddle } 760a70aa7bbSRiver Riddle 7616b2e29c5SUday Bondhugula /// Returns true if a private memref can be created for `memref` given 7626b2e29c5SUday Bondhugula /// the fusion scenario reflected by the other arguments. 7636b2e29c5SUday Bondhugula bool canCreatePrivateMemRef(Value memref, 7646b2e29c5SUday Bondhugula const DenseSet<Value> &srcEscapingMemRefs, 7656b2e29c5SUday Bondhugula unsigned producerId, unsigned consumerId, 7666b2e29c5SUday Bondhugula bool removeSrcNode) { 7676b2e29c5SUday Bondhugula const Node *consumerNode = mdg->getNode(consumerId); 7686b2e29c5SUday Bondhugula // If `memref` is an escaping one, do not create a private memref 7696b2e29c5SUday Bondhugula // for the below scenarios, since doing so will leave the escaping 7706b2e29c5SUday Bondhugula // memref unmodified as all the writes originally meant for the 7716b2e29c5SUday Bondhugula // escaping memref would be performed on the private memref: 7726b2e29c5SUday Bondhugula // 1. The source is to be removed after fusion, 7736b2e29c5SUday Bondhugula // OR 7746b2e29c5SUday Bondhugula // 2. The destination writes to `memref`. 
7756b2e29c5SUday Bondhugula if (srcEscapingMemRefs.count(memref) > 0 && 7766b2e29c5SUday Bondhugula (removeSrcNode || consumerNode->getStoreOpCount(memref) > 0)) 7776b2e29c5SUday Bondhugula return false; 778a70aa7bbSRiver Riddle 7796b2e29c5SUday Bondhugula // Don't create a private memref if 'srcNode' has in edges on 7806b2e29c5SUday Bondhugula // 'memref' or 'dstNode' has out edges on 'memref'. 7816b2e29c5SUday Bondhugula if (mdg->getIncomingMemRefAccesses(producerId, memref) > 0 || 7826b2e29c5SUday Bondhugula mdg->getOutEdgeCount(consumerId, memref) > 0) 7836b2e29c5SUday Bondhugula return false; 7846b2e29c5SUday Bondhugula 7856b2e29c5SUday Bondhugula // If 'srcNode' will be removed but it has out edges on 'memref' to 7866b2e29c5SUday Bondhugula // nodes other than 'dstNode', we have to preserve dependences and 7876b2e29c5SUday Bondhugula // cannot create a private memref. 7886b2e29c5SUday Bondhugula if (removeSrcNode && 7896b2e29c5SUday Bondhugula any_of(mdg->outEdges[producerId], [&](const auto &edge) { 7906b2e29c5SUday Bondhugula return edge.value == memref && edge.id != consumerId; 7916b2e29c5SUday Bondhugula })) 7926b2e29c5SUday Bondhugula return false; 7936b2e29c5SUday Bondhugula 7946b2e29c5SUday Bondhugula return true; 7956b2e29c5SUday Bondhugula } 7966b2e29c5SUday Bondhugula 7976b2e29c5SUday Bondhugula /// Perform fusions with node `dstId` as the destination of fusion, with 7986b2e29c5SUday Bondhugula /// No fusion is performed when producers with a user count greater than 7996b2e29c5SUday Bondhugula /// `maxSrcUserCount` for any of the memrefs involved. 8006b2e29c5SUday Bondhugula void performFusionsIntoDest(unsigned dstId, unsigned maxSrcUserCount) { 8016b2e29c5SUday Bondhugula LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n"); 802a70aa7bbSRiver Riddle // Skip if this node was removed (fused into another node). 
803a70aa7bbSRiver Riddle if (mdg->nodes.count(dstId) == 0) 8046b2e29c5SUday Bondhugula return; 805a70aa7bbSRiver Riddle // Get 'dstNode' into which to attempt fusion. 806a70aa7bbSRiver Riddle auto *dstNode = mdg->getNode(dstId); 807a70aa7bbSRiver Riddle // Skip if 'dstNode' is not a loop nest. 808a70aa7bbSRiver Riddle if (!isa<AffineForOp>(dstNode->op)) 8096b2e29c5SUday Bondhugula return; 810a70aa7bbSRiver Riddle // Skip if 'dstNode' is a loop nest returning values. 811a70aa7bbSRiver Riddle // TODO: support loop nests that return values. 812a70aa7bbSRiver Riddle if (dstNode->op->getNumResults() > 0) 8136b2e29c5SUday Bondhugula return; 814a70aa7bbSRiver Riddle 815a70aa7bbSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n"); 816a70aa7bbSRiver Riddle 817a70aa7bbSRiver Riddle // Sink sequential loops in 'dstNode' (and thus raise parallel loops) 818a70aa7bbSRiver Riddle // while preserving relative order. This can increase the maximum loop 819a70aa7bbSRiver Riddle // depth at which we can fuse a slice of a producer loop nest into a 820a70aa7bbSRiver Riddle // consumer loop nest. 821a70aa7bbSRiver Riddle sinkSequentialLoops(dstNode); 822a70aa7bbSRiver Riddle auto dstAffineForOp = cast<AffineForOp>(dstNode->op); 823a70aa7bbSRiver Riddle 824a70aa7bbSRiver Riddle // Try to fuse 'dstNode' with candidate producer loops until a fixed point 825a70aa7bbSRiver Riddle // is reached. Fusing two loops may expose new fusion opportunities. 826a70aa7bbSRiver Riddle bool dstNodeChanged; 827a70aa7bbSRiver Riddle do { 828a70aa7bbSRiver Riddle // Gather src loop candidates for 'dstNode' and visit them in "quasi" 829a70aa7bbSRiver Riddle // reverse program order to minimize the number of iterations needed to 830a70aa7bbSRiver Riddle // reach the fixed point. Note that this is a best effort approach since 831a70aa7bbSRiver Riddle // 'getProducerCandidates' does not always guarantee that program order 832a70aa7bbSRiver Riddle // in 'srcIdCandidates'. 
833a70aa7bbSRiver Riddle dstNodeChanged = false; 834a70aa7bbSRiver Riddle SmallVector<unsigned, 16> srcIdCandidates; 835a70aa7bbSRiver Riddle getProducerCandidates(dstId, mdg, srcIdCandidates); 836a70aa7bbSRiver Riddle 837a70aa7bbSRiver Riddle for (unsigned srcId : llvm::reverse(srcIdCandidates)) { 838a70aa7bbSRiver Riddle // Get 'srcNode' from which to attempt fusion into 'dstNode'. 839a70aa7bbSRiver Riddle auto *srcNode = mdg->getNode(srcId); 840a70aa7bbSRiver Riddle auto srcAffineForOp = cast<AffineForOp>(srcNode->op); 841a70aa7bbSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Evaluating src loop " << srcId 842a70aa7bbSRiver Riddle << " for dst loop " << dstId << "\n"); 843a70aa7bbSRiver Riddle 844a70aa7bbSRiver Riddle // Skip if 'srcNode' is a loop nest returning values. 845a70aa7bbSRiver Riddle // TODO: support loop nests that return values. 846a70aa7bbSRiver Riddle if (isa<AffineForOp>(srcNode->op) && srcNode->op->getNumResults() > 0) 847a70aa7bbSRiver Riddle continue; 848a70aa7bbSRiver Riddle 849a70aa7bbSRiver Riddle DenseSet<Value> producerConsumerMemrefs; 850a70aa7bbSRiver Riddle gatherProducerConsumerMemrefs(srcId, dstId, mdg, 851a70aa7bbSRiver Riddle producerConsumerMemrefs); 852a70aa7bbSRiver Riddle 853a70aa7bbSRiver Riddle // Skip if 'srcNode' out edge count on any memref is greater than 854a70aa7bbSRiver Riddle // 'maxSrcUserCount'. 855a70aa7bbSRiver Riddle if (any_of(producerConsumerMemrefs, [&](Value memref) { 856a70aa7bbSRiver Riddle return mdg->getOutEdgeCount(srcNode->id, memref) > 857a70aa7bbSRiver Riddle maxSrcUserCount; 858a70aa7bbSRiver Riddle })) 859a70aa7bbSRiver Riddle continue; 860a70aa7bbSRiver Riddle 861fe9d0a47SUday Bondhugula // Gather memrefs in 'srcNode' that are written and escape out of the 862fe9d0a47SUday Bondhugula // block (e.g., memref block arguments, returned memrefs, 863a70aa7bbSRiver Riddle // memrefs passed to function calls, etc.). 
864a70aa7bbSRiver Riddle DenseSet<Value> srcEscapingMemRefs; 865a70aa7bbSRiver Riddle gatherEscapingMemrefs(srcNode->id, mdg, srcEscapingMemRefs); 866a70aa7bbSRiver Riddle 867a70aa7bbSRiver Riddle // Skip if there are non-affine operations in between the 'srcNode' 868a70aa7bbSRiver Riddle // and 'dstNode' using their memrefs. If so, we wouldn't be able to 869a70aa7bbSRiver Riddle // compute a legal insertion point for now. 'srcNode' and 'dstNode' 870a70aa7bbSRiver Riddle // memrefs with non-affine operation users would be considered 871a70aa7bbSRiver Riddle // escaping memrefs so we can limit this check to only scenarios with 872a70aa7bbSRiver Riddle // escaping memrefs. 873a70aa7bbSRiver Riddle if (!srcEscapingMemRefs.empty() && 874*469f9d5fSUday Bondhugula hasNonAffineUsersOnPath(srcNode->op, dstNode->op)) { 8756b2e29c5SUday Bondhugula LLVM_DEBUG(llvm::dbgs() 8763ab88e79SUday Bondhugula << "Can't fuse: non-affine users in between the loops\n"); 877a70aa7bbSRiver Riddle continue; 878a70aa7bbSRiver Riddle } 879a70aa7bbSRiver Riddle 880a70aa7bbSRiver Riddle // Compute an operation list insertion point for the fused loop 881a70aa7bbSRiver Riddle // nest which preserves dependences. 882a70aa7bbSRiver Riddle Operation *fusedLoopInsPoint = 883a70aa7bbSRiver Riddle mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id); 884a70aa7bbSRiver Riddle if (fusedLoopInsPoint == nullptr) 885a70aa7bbSRiver Riddle continue; 886a70aa7bbSRiver Riddle 887c79ffb02SUday Bondhugula // It's possible this fusion is at an inner depth (i.e., there are 888c79ffb02SUday Bondhugula // common surrounding affine loops for the source and destination for 889c79ffb02SUday Bondhugula // ops). We need to get this number because the call to canFuseLoops 890c79ffb02SUday Bondhugula // needs to be passed the absolute depth. The max legal depth and the 891c79ffb02SUday Bondhugula // depths we try below are however *relative* and as such don't include 892c79ffb02SUday Bondhugula // the common depth. 
893c79ffb02SUday Bondhugula SmallVector<AffineForOp, 4> surroundingLoops; 894c79ffb02SUday Bondhugula getAffineForIVs(*dstAffineForOp, &surroundingLoops); 895c79ffb02SUday Bondhugula unsigned numSurroundingLoops = surroundingLoops.size(); 896c79ffb02SUday Bondhugula 897a70aa7bbSRiver Riddle // Compute the innermost common loop depth for dstNode 898a70aa7bbSRiver Riddle // producer-consumer loads/stores. 899a70aa7bbSRiver Riddle SmallVector<Operation *, 2> dstMemrefOps; 900a70aa7bbSRiver Riddle for (Operation *op : dstNode->loads) 901a70aa7bbSRiver Riddle if (producerConsumerMemrefs.count( 902a70aa7bbSRiver Riddle cast<AffineReadOpInterface>(op).getMemRef()) > 0) 903a70aa7bbSRiver Riddle dstMemrefOps.push_back(op); 904a70aa7bbSRiver Riddle for (Operation *op : dstNode->stores) 905a70aa7bbSRiver Riddle if (producerConsumerMemrefs.count( 906a70aa7bbSRiver Riddle cast<AffineWriteOpInterface>(op).getMemRef())) 907a70aa7bbSRiver Riddle dstMemrefOps.push_back(op); 908c79ffb02SUday Bondhugula unsigned dstLoopDepthTest = 909c79ffb02SUday Bondhugula getInnermostCommonLoopDepth(dstMemrefOps) - numSurroundingLoops; 910a70aa7bbSRiver Riddle 911a70aa7bbSRiver Riddle // Check the feasibility of fusing src loop nest into dst loop nest 912a70aa7bbSRiver Riddle // at loop depths in range [1, dstLoopDepthTest]. 
913a70aa7bbSRiver Riddle unsigned maxLegalFusionDepth = 0; 914a70aa7bbSRiver Riddle SmallVector<ComputationSliceState, 8> depthSliceUnions; 915a70aa7bbSRiver Riddle depthSliceUnions.resize(dstLoopDepthTest); 916a70aa7bbSRiver Riddle FusionStrategy strategy(FusionStrategy::ProducerConsumer); 917a70aa7bbSRiver Riddle for (unsigned i = 1; i <= dstLoopDepthTest; ++i) { 918c79ffb02SUday Bondhugula FusionResult result = 919c79ffb02SUday Bondhugula affine::canFuseLoops(srcAffineForOp, dstAffineForOp, 920c79ffb02SUday Bondhugula /*dstLoopDepth=*/i + numSurroundingLoops, 921c79ffb02SUday Bondhugula &depthSliceUnions[i - 1], strategy); 922a70aa7bbSRiver Riddle 923a70aa7bbSRiver Riddle if (result.value == FusionResult::Success) 924a70aa7bbSRiver Riddle maxLegalFusionDepth = i; 925a70aa7bbSRiver Riddle } 926a70aa7bbSRiver Riddle 927a70aa7bbSRiver Riddle if (maxLegalFusionDepth == 0) { 928a70aa7bbSRiver Riddle LLVM_DEBUG(llvm::dbgs() 929a70aa7bbSRiver Riddle << "Can't fuse: fusion is not legal at any depth\n"); 930a70aa7bbSRiver Riddle continue; 931a70aa7bbSRiver Riddle } 932a70aa7bbSRiver Riddle 933a70aa7bbSRiver Riddle // Check if fusion would be profitable. We skip profitability analysis 934a70aa7bbSRiver Riddle // for maximal fusion since we already know the maximal legal depth to 935a70aa7bbSRiver Riddle // fuse. 936a70aa7bbSRiver Riddle unsigned bestDstLoopDepth = maxLegalFusionDepth; 937a70aa7bbSRiver Riddle if (!maximalFusion) { 938a70aa7bbSRiver Riddle // Retrieve producer stores from the src loop. 939a70aa7bbSRiver Riddle SmallVector<Operation *, 2> producerStores; 940a70aa7bbSRiver Riddle for (Operation *op : srcNode->stores) 941a70aa7bbSRiver Riddle if (producerConsumerMemrefs.count( 942a70aa7bbSRiver Riddle cast<AffineWriteOpInterface>(op).getMemRef())) 943a70aa7bbSRiver Riddle producerStores.push_back(op); 944a70aa7bbSRiver Riddle 945a70aa7bbSRiver Riddle // TODO: Suppport multiple producer stores in profitability 946a70aa7bbSRiver Riddle // analysis. 
We limit profitability analysis to only scenarios with 947a70aa7bbSRiver Riddle // a single producer store for now. Note that some multi-store 948a70aa7bbSRiver Riddle // producer scenarios will still go through profitability analysis 949a70aa7bbSRiver Riddle // if only one of the stores is involved the producer-consumer 950a70aa7bbSRiver Riddle // relationship of the candidate loops. 951a70aa7bbSRiver Riddle assert(!producerStores.empty() && "Expected producer store"); 952a70aa7bbSRiver Riddle if (producerStores.size() > 1) 953a70aa7bbSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Skipping profitability analysis. Not " 954a70aa7bbSRiver Riddle "supported for this case\n"); 955a70aa7bbSRiver Riddle else if (!isFusionProfitable(producerStores[0], producerStores[0], 956a70aa7bbSRiver Riddle dstAffineForOp, depthSliceUnions, 957a70aa7bbSRiver Riddle maxLegalFusionDepth, &bestDstLoopDepth, 958a70aa7bbSRiver Riddle computeToleranceThreshold)) 959a70aa7bbSRiver Riddle continue; 960a70aa7bbSRiver Riddle } 961a70aa7bbSRiver Riddle 962a70aa7bbSRiver Riddle assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth"); 963a70aa7bbSRiver Riddle ComputationSliceState &bestSlice = 964a70aa7bbSRiver Riddle depthSliceUnions[bestDstLoopDepth - 1]; 965a70aa7bbSRiver Riddle assert(!bestSlice.isEmpty() && "Missing slice union for depth"); 966a70aa7bbSRiver Riddle 967a70aa7bbSRiver Riddle // Determine if 'srcId' can be removed after fusion, taking into 968a70aa7bbSRiver Riddle // account remaining dependences, escaping memrefs and the fusion 969a70aa7bbSRiver Riddle // insertion point. 
970a70aa7bbSRiver Riddle bool removeSrcNode = canRemoveSrcNodeAfterFusion( 971a70aa7bbSRiver Riddle srcId, dstId, bestSlice, fusedLoopInsPoint, srcEscapingMemRefs, 972a70aa7bbSRiver Riddle mdg); 973a70aa7bbSRiver Riddle 974a70aa7bbSRiver Riddle DenseSet<Value> privateMemrefs; 975a70aa7bbSRiver Riddle for (Value memref : producerConsumerMemrefs) { 9766b2e29c5SUday Bondhugula if (canCreatePrivateMemRef(memref, srcEscapingMemRefs, srcId, dstId, 9776b2e29c5SUday Bondhugula removeSrcNode)) { 9786b2e29c5SUday Bondhugula // Create a private version of this memref. 9796b2e29c5SUday Bondhugula LLVM_DEBUG(llvm::dbgs() 9806b2e29c5SUday Bondhugula << "Creating private memref for " << memref << '\n'); 981a70aa7bbSRiver Riddle // Create a private version of this memref. 982a70aa7bbSRiver Riddle privateMemrefs.insert(memref); 983a70aa7bbSRiver Riddle } 9846b2e29c5SUday Bondhugula } 985a70aa7bbSRiver Riddle 986a70aa7bbSRiver Riddle // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'. 987a70aa7bbSRiver Riddle fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice); 988a70aa7bbSRiver Riddle dstNodeChanged = true; 989a70aa7bbSRiver Riddle 990a70aa7bbSRiver Riddle LLVM_DEBUG(llvm::dbgs() 991a70aa7bbSRiver Riddle << "Fused src loop " << srcId << " into dst loop " << dstId 992a70aa7bbSRiver Riddle << " at depth " << bestDstLoopDepth << ":\n" 993a70aa7bbSRiver Riddle << dstAffineForOp << "\n"); 994a70aa7bbSRiver Riddle 995a70aa7bbSRiver Riddle // Move 'dstAffineForOp' before 'insertPointInst' if needed. 996825b23a7SUday Bondhugula if (fusedLoopInsPoint != dstAffineForOp) 997825b23a7SUday Bondhugula dstAffineForOp->moveBefore(fusedLoopInsPoint); 998a70aa7bbSRiver Riddle 999a70aa7bbSRiver Riddle // Update edges between 'srcNode' and 'dstNode'. 1000a70aa7bbSRiver Riddle mdg->updateEdges(srcNode->id, dstNode->id, privateMemrefs, 1001a70aa7bbSRiver Riddle removeSrcNode); 1002a70aa7bbSRiver Riddle 1003a70aa7bbSRiver Riddle // Create private memrefs. 
1004a70aa7bbSRiver Riddle if (!privateMemrefs.empty()) { 1005a70aa7bbSRiver Riddle // Gather stores for all the private-to-be memrefs. 1006a70aa7bbSRiver Riddle DenseMap<Value, SmallVector<Operation *, 4>> privateMemRefToStores; 1007a70aa7bbSRiver Riddle dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) { 1008a70aa7bbSRiver Riddle Value storeMemRef = storeOp.getMemRef(); 1009a70aa7bbSRiver Riddle if (privateMemrefs.count(storeMemRef) > 0) 1010825b23a7SUday Bondhugula privateMemRefToStores[storeMemRef].push_back(storeOp); 1011a70aa7bbSRiver Riddle }); 1012a70aa7bbSRiver Riddle 1013a70aa7bbSRiver Riddle // Replace original memrefs with private memrefs. Note that all the 1014a70aa7bbSRiver Riddle // loads and stores on these memrefs will be replaced with a new 1015a70aa7bbSRiver Riddle // loads and stores. Any reference to the original ones becomes 1016a70aa7bbSRiver Riddle // invalid after this point. 1017a70aa7bbSRiver Riddle for (auto &memrefToStoresPair : privateMemRefToStores) { 1018a70aa7bbSRiver Riddle // TODO: Use union of memref write regions to compute 1019a70aa7bbSRiver Riddle // private memref footprint. 1020a70aa7bbSRiver Riddle SmallVector<Operation *, 4> &storesForMemref = 1021a70aa7bbSRiver Riddle memrefToStoresPair.second; 1022a70aa7bbSRiver Riddle Value newMemRef = createPrivateMemRef( 1023a70aa7bbSRiver Riddle dstAffineForOp, storesForMemref[0], bestDstLoopDepth, 1024a70aa7bbSRiver Riddle fastMemorySpace, localBufSizeThreshold); 1025a70aa7bbSRiver Riddle // Create new node in dependence graph for 'newMemRef' alloc op. 10266b2e29c5SUday Bondhugula unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp()); 1027a70aa7bbSRiver Riddle // Add edge from 'newMemRef' node to dstNode. 1028a70aa7bbSRiver Riddle mdg->addEdge(newMemRefNodeId, dstId, newMemRef); 1029a70aa7bbSRiver Riddle } 1030a70aa7bbSRiver Riddle // One or more entries for 'newMemRef' alloc op are inserted into 1031a70aa7bbSRiver Riddle // the DenseMap mdg->nodes. 
Since an insertion may cause DenseMap to 1032a70aa7bbSRiver Riddle // reallocate, update dstNode. 1033a70aa7bbSRiver Riddle dstNode = mdg->getNode(dstId); 1034a70aa7bbSRiver Riddle } 1035a70aa7bbSRiver Riddle 1036a70aa7bbSRiver Riddle // Collect dst loop stats after memref privatization transformation. 1037a70aa7bbSRiver Riddle LoopNestStateCollector dstLoopCollector; 1038825b23a7SUday Bondhugula dstLoopCollector.collect(dstAffineForOp); 1039a70aa7bbSRiver Riddle 1040a70aa7bbSRiver Riddle // Clear and add back loads and stores. 1041a70aa7bbSRiver Riddle mdg->clearNodeLoadAndStores(dstNode->id); 1042a70aa7bbSRiver Riddle mdg->addToNode(dstId, dstLoopCollector.loadOpInsts, 1043a70aa7bbSRiver Riddle dstLoopCollector.storeOpInsts); 1044a70aa7bbSRiver Riddle 1045a70aa7bbSRiver Riddle if (removeSrcNode) { 1046a70aa7bbSRiver Riddle LLVM_DEBUG(llvm::dbgs() 1047a70aa7bbSRiver Riddle << "Removing src loop " << srcId << " after fusion\n"); 1048a70aa7bbSRiver Riddle // srcNode is no longer valid after it is removed from mdg. 1049a70aa7bbSRiver Riddle srcAffineForOp.erase(); 1050a70aa7bbSRiver Riddle mdg->removeNode(srcId); 1051a70aa7bbSRiver Riddle srcNode = nullptr; 1052a70aa7bbSRiver Riddle } 1053a70aa7bbSRiver Riddle } 1054a70aa7bbSRiver Riddle } while (dstNodeChanged); 1055a70aa7bbSRiver Riddle } 10566b2e29c5SUday Bondhugula 10576b2e29c5SUday Bondhugula /// Visit each node in the graph, and for each node, attempt to fuse it with 10586b2e29c5SUday Bondhugula /// producer-consumer candidates. No fusion is performed when producers with a 10596b2e29c5SUday Bondhugula /// user count greater than `maxSrcUserCount` for any of the memrefs involved 10606b2e29c5SUday Bondhugula /// are encountered. 
10616b2e29c5SUday Bondhugula void fuseProducerConsumerNodes(unsigned maxSrcUserCount) { 10626b2e29c5SUday Bondhugula LLVM_DEBUG(llvm::dbgs() << "--- Producer/Consumer Fusion ---\n"); 10636b2e29c5SUday Bondhugula init(); 10646b2e29c5SUday Bondhugula while (!worklist.empty()) { 10656b2e29c5SUday Bondhugula unsigned dstId = worklist.back(); 10666b2e29c5SUday Bondhugula worklist.pop_back(); 10676b2e29c5SUday Bondhugula performFusionsIntoDest(dstId, maxSrcUserCount); 10686b2e29c5SUday Bondhugula } 1069a70aa7bbSRiver Riddle } 1070a70aa7bbSRiver Riddle 1071a70aa7bbSRiver Riddle // Visits each node in the graph, and for each node, attempts to fuse it with 1072a70aa7bbSRiver Riddle // its sibling nodes (nodes which share a parent, but no dependence edges). 1073a70aa7bbSRiver Riddle void fuseSiblingNodes() { 1074a70aa7bbSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "--- Sibling Fusion ---\n"); 1075a70aa7bbSRiver Riddle init(); 1076a70aa7bbSRiver Riddle while (!worklist.empty()) { 1077a70aa7bbSRiver Riddle unsigned dstId = worklist.back(); 1078a70aa7bbSRiver Riddle worklist.pop_back(); 1079a70aa7bbSRiver Riddle 1080a70aa7bbSRiver Riddle // Skip if this node was removed (fused into another node). 1081a70aa7bbSRiver Riddle if (mdg->nodes.count(dstId) == 0) 1082a70aa7bbSRiver Riddle continue; 1083a70aa7bbSRiver Riddle // Get 'dstNode' into which to attempt fusion. 1084a70aa7bbSRiver Riddle auto *dstNode = mdg->getNode(dstId); 1085a70aa7bbSRiver Riddle // Skip if 'dstNode' is not a loop nest. 1086a70aa7bbSRiver Riddle if (!isa<AffineForOp>(dstNode->op)) 1087a70aa7bbSRiver Riddle continue; 1088a70aa7bbSRiver Riddle // Attempt to fuse 'dstNode' with its sibling nodes in the graph. 1089a70aa7bbSRiver Riddle fuseWithSiblingNodes(dstNode); 1090a70aa7bbSRiver Riddle } 1091a70aa7bbSRiver Riddle } 1092a70aa7bbSRiver Riddle 1093a70aa7bbSRiver Riddle // Attempt to fuse 'dstNode' with sibling nodes in the graph. 
1094a70aa7bbSRiver Riddle void fuseWithSiblingNodes(Node *dstNode) { 1095a70aa7bbSRiver Riddle DenseSet<unsigned> visitedSibNodeIds; 1096a70aa7bbSRiver Riddle std::pair<unsigned, Value> idAndMemref; 1097a70aa7bbSRiver Riddle auto dstAffineForOp = cast<AffineForOp>(dstNode->op); 1098a70aa7bbSRiver Riddle 1099a70aa7bbSRiver Riddle while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) { 1100a70aa7bbSRiver Riddle unsigned sibId = idAndMemref.first; 1101a70aa7bbSRiver Riddle Value memref = idAndMemref.second; 1102a70aa7bbSRiver Riddle // TODO: Check that 'sibStoreOpInst' post-dominates all other 1103a70aa7bbSRiver Riddle // stores to the same memref in 'sibNode' loop nest. 1104a70aa7bbSRiver Riddle auto *sibNode = mdg->getNode(sibId); 1105a70aa7bbSRiver Riddle // Compute an operation list insertion point for the fused loop 1106a70aa7bbSRiver Riddle // nest which preserves dependences. 1107a70aa7bbSRiver Riddle assert(sibNode->op->getBlock() == dstNode->op->getBlock()); 1108a70aa7bbSRiver Riddle Operation *insertPointInst = 1109a70aa7bbSRiver Riddle sibNode->op->isBeforeInBlock(dstNode->op) 1110a70aa7bbSRiver Riddle ? mdg->getFusedLoopNestInsertionPoint(sibNode->id, dstNode->id) 1111a70aa7bbSRiver Riddle : mdg->getFusedLoopNestInsertionPoint(dstNode->id, sibNode->id); 1112a70aa7bbSRiver Riddle if (insertPointInst == nullptr) 1113a70aa7bbSRiver Riddle continue; 1114a70aa7bbSRiver Riddle 1115a70aa7bbSRiver Riddle // Check if fusion would be profitable and at what depth. 1116a70aa7bbSRiver Riddle 1117a70aa7bbSRiver Riddle // Get unique 'sibNode' load op to 'memref'. 1118a70aa7bbSRiver Riddle SmallVector<Operation *, 2> sibLoadOpInsts; 1119a70aa7bbSRiver Riddle sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts); 1120a70aa7bbSRiver Riddle // Currently findSiblingNodeToFuse searches for siblings with one load. 
1121a70aa7bbSRiver Riddle assert(sibLoadOpInsts.size() == 1); 1122a70aa7bbSRiver Riddle Operation *sibLoadOpInst = sibLoadOpInsts[0]; 1123a70aa7bbSRiver Riddle 1124a70aa7bbSRiver Riddle // Gather 'dstNode' load ops to 'memref'. 1125a70aa7bbSRiver Riddle SmallVector<Operation *, 2> dstLoadOpInsts; 1126a70aa7bbSRiver Riddle dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts); 1127a70aa7bbSRiver Riddle 1128c79ffb02SUday Bondhugula // It's possible this fusion is at an inner depth (i.e., there are common 1129c79ffb02SUday Bondhugula // surrounding affine loops for the source and destination for ops). We 1130c79ffb02SUday Bondhugula // need to get this number because the call to canFuseLoops needs to be 1131c79ffb02SUday Bondhugula // passed the absolute depth. The max legal depth and the depths we try 1132c79ffb02SUday Bondhugula // below are however *relative* and as such don't include the common 1133c79ffb02SUday Bondhugula // depth. 1134c79ffb02SUday Bondhugula SmallVector<AffineForOp, 4> surroundingLoops; 1135c79ffb02SUday Bondhugula getAffineForIVs(*dstAffineForOp, &surroundingLoops); 1136c79ffb02SUday Bondhugula unsigned numSurroundingLoops = surroundingLoops.size(); 1137a70aa7bbSRiver Riddle SmallVector<AffineForOp, 4> dstLoopIVs; 113823bcd6b8SUday Bondhugula getAffineForIVs(*dstLoadOpInsts[0], &dstLoopIVs); 1139c79ffb02SUday Bondhugula unsigned dstLoopDepthTest = dstLoopIVs.size() - numSurroundingLoops; 1140a70aa7bbSRiver Riddle auto sibAffineForOp = cast<AffineForOp>(sibNode->op); 1141a70aa7bbSRiver Riddle 1142a70aa7bbSRiver Riddle // Compute loop depth and slice union for fusion. 
1143a70aa7bbSRiver Riddle SmallVector<ComputationSliceState, 8> depthSliceUnions; 1144a70aa7bbSRiver Riddle depthSliceUnions.resize(dstLoopDepthTest); 1145a70aa7bbSRiver Riddle unsigned maxLegalFusionDepth = 0; 1146a70aa7bbSRiver Riddle FusionStrategy strategy(memref); 1147a70aa7bbSRiver Riddle for (unsigned i = 1; i <= dstLoopDepthTest; ++i) { 1148c79ffb02SUday Bondhugula FusionResult result = 1149c79ffb02SUday Bondhugula affine::canFuseLoops(sibAffineForOp, dstAffineForOp, 1150c79ffb02SUday Bondhugula /*dstLoopDepth=*/i + numSurroundingLoops, 1151c79ffb02SUday Bondhugula &depthSliceUnions[i - 1], strategy); 1152a70aa7bbSRiver Riddle 1153a70aa7bbSRiver Riddle if (result.value == FusionResult::Success) 1154a70aa7bbSRiver Riddle maxLegalFusionDepth = i; 1155a70aa7bbSRiver Riddle } 1156a70aa7bbSRiver Riddle 1157c79ffb02SUday Bondhugula LLVM_DEBUG(llvm::dbgs() << "Max legal depth for fusion: " 1158c79ffb02SUday Bondhugula << maxLegalFusionDepth << '\n'); 1159c79ffb02SUday Bondhugula 1160a70aa7bbSRiver Riddle // Skip if fusion is not feasible at any loop depths. 1161a70aa7bbSRiver Riddle if (maxLegalFusionDepth == 0) 1162a70aa7bbSRiver Riddle continue; 1163a70aa7bbSRiver Riddle 1164a70aa7bbSRiver Riddle unsigned bestDstLoopDepth = maxLegalFusionDepth; 1165a70aa7bbSRiver Riddle if (!maximalFusion) { 1166721defb4SUday Bondhugula // Check if fusion would be profitable. For sibling fusion, the sibling 1167721defb4SUday Bondhugula // load op is treated as the src "store" op for fusion profitability 1168721defb4SUday Bondhugula // purposes. The footprint of the load in the slice relative to the 1169721defb4SUday Bondhugula // unfused source's determines reuse. 
1170721defb4SUday Bondhugula if (!isFusionProfitable(sibLoadOpInst, sibLoadOpInst, dstAffineForOp, 1171a70aa7bbSRiver Riddle depthSliceUnions, maxLegalFusionDepth, 1172a70aa7bbSRiver Riddle &bestDstLoopDepth, computeToleranceThreshold)) 1173a70aa7bbSRiver Riddle continue; 1174a70aa7bbSRiver Riddle } 1175a70aa7bbSRiver Riddle 1176a70aa7bbSRiver Riddle assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth"); 1177a70aa7bbSRiver Riddle assert(!depthSliceUnions[bestDstLoopDepth - 1].isEmpty() && 1178a70aa7bbSRiver Riddle "Fusion depth has no computed slice union"); 1179a70aa7bbSRiver Riddle // Check if source loop is being inserted in the innermost 1180a70aa7bbSRiver Riddle // destination loop. Based on this, the fused loop may be optimized 1181a70aa7bbSRiver Riddle // further inside `fuseLoops`. 1182a70aa7bbSRiver Riddle bool isInnermostInsertion = (bestDstLoopDepth == dstLoopDepthTest); 1183a70aa7bbSRiver Riddle // Fuse computation slice of 'sibLoopNest' into 'dstLoopNest'. 11844c48f016SMatthias Springer affine::fuseLoops(sibAffineForOp, dstAffineForOp, 1185a70aa7bbSRiver Riddle depthSliceUnions[bestDstLoopDepth - 1], 1186a70aa7bbSRiver Riddle isInnermostInsertion); 1187a70aa7bbSRiver Riddle 1188a70aa7bbSRiver Riddle auto dstForInst = cast<AffineForOp>(dstNode->op); 1189a70aa7bbSRiver Riddle // Update operation position of fused loop nest (if needed). 1190825b23a7SUday Bondhugula if (insertPointInst != dstForInst) { 1191a70aa7bbSRiver Riddle dstForInst->moveBefore(insertPointInst); 1192a70aa7bbSRiver Riddle } 1193a70aa7bbSRiver Riddle // Update data dependence graph state post fusion. 
1194a70aa7bbSRiver Riddle updateStateAfterSiblingFusion(sibNode, dstNode); 1195a70aa7bbSRiver Riddle } 1196a70aa7bbSRiver Riddle } 1197a70aa7bbSRiver Riddle 1198fe9d0a47SUday Bondhugula // Searches block argument uses and the graph from 'dstNode' looking for a 1199a70aa7bbSRiver Riddle // fusion candidate sibling node which shares no dependences with 'dstNode' 1200a70aa7bbSRiver Riddle // but which loads from the same memref. Returns true and sets 1201a70aa7bbSRiver Riddle // 'idAndMemrefToFuse' on success. Returns false otherwise. 1202a70aa7bbSRiver Riddle bool findSiblingNodeToFuse(Node *dstNode, 1203a70aa7bbSRiver Riddle DenseSet<unsigned> *visitedSibNodeIds, 1204a70aa7bbSRiver Riddle std::pair<unsigned, Value> *idAndMemrefToFuse) { 1205a70aa7bbSRiver Riddle // Returns true if 'sibNode' can be fused with 'dstNode' for input reuse 1206a70aa7bbSRiver Riddle // on 'memref'. 1207a70aa7bbSRiver Riddle auto canFuseWithSibNode = [&](Node *sibNode, Value memref) { 1208a70aa7bbSRiver Riddle // Skip if 'outEdge' is not a read-after-write dependence. 1209a70aa7bbSRiver Riddle // TODO: Remove restrict to single load op restriction. 1210a70aa7bbSRiver Riddle if (sibNode->getLoadOpCount(memref) != 1) 1211a70aa7bbSRiver Riddle return false; 1212a70aa7bbSRiver Riddle // Skip if there exists a path of dependent edges between 1213a70aa7bbSRiver Riddle // 'sibNode' and 'dstNode'. 1214a70aa7bbSRiver Riddle if (mdg->hasDependencePath(sibNode->id, dstNode->id) || 1215a70aa7bbSRiver Riddle mdg->hasDependencePath(dstNode->id, sibNode->id)) 1216a70aa7bbSRiver Riddle return false; 1217a70aa7bbSRiver Riddle // Skip sib node if it loads to (and stores from) the same memref on 1218a70aa7bbSRiver Riddle // which it also has an input dependence edge. 
1219a70aa7bbSRiver Riddle DenseSet<Value> loadAndStoreMemrefSet; 1220a70aa7bbSRiver Riddle sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet); 1221a70aa7bbSRiver Riddle if (llvm::any_of(loadAndStoreMemrefSet, [=](Value memref) { 1222a70aa7bbSRiver Riddle return mdg->getIncomingMemRefAccesses(sibNode->id, memref) > 0; 1223a70aa7bbSRiver Riddle })) 1224a70aa7bbSRiver Riddle return false; 1225a70aa7bbSRiver Riddle 1226721defb4SUday Bondhugula // Check that all stores are to the same memref if any. 1227a70aa7bbSRiver Riddle DenseSet<Value> storeMemrefs; 1228a70aa7bbSRiver Riddle for (auto *storeOpInst : sibNode->stores) { 1229a70aa7bbSRiver Riddle storeMemrefs.insert( 1230a70aa7bbSRiver Riddle cast<AffineWriteOpInterface>(storeOpInst).getMemRef()); 1231a70aa7bbSRiver Riddle } 1232721defb4SUday Bondhugula if (storeMemrefs.size() > 1) 1233a70aa7bbSRiver Riddle return false; 1234a70aa7bbSRiver Riddle 1235a70aa7bbSRiver Riddle // Skip if a memref value in one node is used by a non-affine memref 1236a70aa7bbSRiver Riddle // access that lies between 'dstNode' and 'sibNode'. 1237*469f9d5fSUday Bondhugula if (hasNonAffineUsersOnPath(dstNode->op, sibNode->op) || 1238*469f9d5fSUday Bondhugula hasNonAffineUsersOnPath(sibNode->op, dstNode->op)) 1239a70aa7bbSRiver Riddle return false; 1240a70aa7bbSRiver Riddle return true; 1241a70aa7bbSRiver Riddle }; 1242a70aa7bbSRiver Riddle 1243fe9d0a47SUday Bondhugula // Search for siblings which load the same memref block argument. 1244fe9d0a47SUday Bondhugula Block *block = dstNode->op->getBlock(); 1245fe9d0a47SUday Bondhugula for (unsigned i = 0, e = block->getNumArguments(); i != e; ++i) { 1246fe9d0a47SUday Bondhugula for (Operation *user : block->getArgument(i).getUsers()) { 1247fe9d0a47SUday Bondhugula auto loadOp = dyn_cast<AffineReadOpInterface>(user); 1248fe9d0a47SUday Bondhugula if (!loadOp) 1249fe9d0a47SUday Bondhugula continue; 1250a70aa7bbSRiver Riddle // Gather loops surrounding 'use'. 
1251a70aa7bbSRiver Riddle SmallVector<AffineForOp, 4> loops; 125223bcd6b8SUday Bondhugula getAffineForIVs(*user, &loops); 1253a70aa7bbSRiver Riddle // Skip 'use' if it is not within a loop nest. 1254c79ffb02SUday Bondhugula // Find the surrounding affine.for nested immediately within the 1255c79ffb02SUday Bondhugula // block. 1256c79ffb02SUday Bondhugula auto *it = llvm::find_if(loops, [&](AffineForOp loop) { 1257c79ffb02SUday Bondhugula return loop->getBlock() == &mdg->block; 1258c79ffb02SUday Bondhugula }); 1259c79ffb02SUday Bondhugula // Skip 'use' if it is not within a loop nest in `block`. 1260c79ffb02SUday Bondhugula if (it == loops.end()) 1261a70aa7bbSRiver Riddle continue; 1262c79ffb02SUday Bondhugula Node *sibNode = mdg->getForOpNode(*it); 1263a70aa7bbSRiver Riddle assert(sibNode != nullptr); 1264a70aa7bbSRiver Riddle // Skip 'use' if it not a sibling to 'dstNode'. 1265a70aa7bbSRiver Riddle if (sibNode->id == dstNode->id) 1266a70aa7bbSRiver Riddle continue; 1267a70aa7bbSRiver Riddle // Skip 'use' if it has been visited. 1268a70aa7bbSRiver Riddle if (visitedSibNodeIds->count(sibNode->id) > 0) 1269a70aa7bbSRiver Riddle continue; 1270a70aa7bbSRiver Riddle // Skip 'use' if it does not load from the same memref as 'dstNode'. 1271a70aa7bbSRiver Riddle auto memref = loadOp.getMemRef(); 1272a70aa7bbSRiver Riddle if (dstNode->getLoadOpCount(memref) == 0) 1273a70aa7bbSRiver Riddle continue; 1274a70aa7bbSRiver Riddle // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'. 
1275a70aa7bbSRiver Riddle if (canFuseWithSibNode(sibNode, memref)) { 1276a70aa7bbSRiver Riddle visitedSibNodeIds->insert(sibNode->id); 1277a70aa7bbSRiver Riddle idAndMemrefToFuse->first = sibNode->id; 1278a70aa7bbSRiver Riddle idAndMemrefToFuse->second = memref; 1279a70aa7bbSRiver Riddle return true; 1280a70aa7bbSRiver Riddle } 1281a70aa7bbSRiver Riddle } 1282a70aa7bbSRiver Riddle } 1283a70aa7bbSRiver Riddle 1284a70aa7bbSRiver Riddle // Search for siblings by following edges through an intermediate src node. 1285a70aa7bbSRiver Riddle // Collect candidate 'dstNode' input edges in 'inEdges'. 1286a70aa7bbSRiver Riddle SmallVector<MemRefDependenceGraph::Edge, 2> inEdges; 1287a70aa7bbSRiver Riddle mdg->forEachMemRefInputEdge( 1288a70aa7bbSRiver Riddle dstNode->id, [&](MemRefDependenceGraph::Edge inEdge) { 1289a70aa7bbSRiver Riddle // Add 'inEdge' if it is a read-after-write dependence. 1290a70aa7bbSRiver Riddle if (dstNode->getLoadOpCount(inEdge.value) > 0 && 1291a70aa7bbSRiver Riddle mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0) 1292a70aa7bbSRiver Riddle inEdges.push_back(inEdge); 1293a70aa7bbSRiver Riddle }); 1294a70aa7bbSRiver Riddle 1295a70aa7bbSRiver Riddle // Search for sibling nodes to fuse by visiting output edges from each input 1296a70aa7bbSRiver Riddle // edge in 'inEdges'. 1297a70aa7bbSRiver Riddle for (auto &inEdge : inEdges) { 1298a70aa7bbSRiver Riddle // Collect candidate output edges from each node 'inEdge.id' in 'inEdges'. 1299a70aa7bbSRiver Riddle SmallVector<MemRefDependenceGraph::Edge, 2> outEdges; 1300a70aa7bbSRiver Riddle mdg->forEachMemRefOutputEdge( 1301a70aa7bbSRiver Riddle inEdge.id, [&](MemRefDependenceGraph::Edge outEdge) { 1302a70aa7bbSRiver Riddle unsigned sibNodeId = outEdge.id; 1303a70aa7bbSRiver Riddle if (visitedSibNodeIds->count(sibNodeId) > 0) 1304a70aa7bbSRiver Riddle return; 1305a70aa7bbSRiver Riddle // Skip output edge if not a sibling using the same memref. 
1306a70aa7bbSRiver Riddle if (outEdge.id == dstNode->id || outEdge.value != inEdge.value) 1307a70aa7bbSRiver Riddle return; 1308a70aa7bbSRiver Riddle auto *sibNode = mdg->getNode(sibNodeId); 1309a70aa7bbSRiver Riddle if (!isa<AffineForOp>(sibNode->op)) 1310a70aa7bbSRiver Riddle return; 1311a70aa7bbSRiver Riddle // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'. 1312a70aa7bbSRiver Riddle if (canFuseWithSibNode(sibNode, outEdge.value)) { 1313a70aa7bbSRiver Riddle // Add candidate 'outEdge' to sibling node. 1314a70aa7bbSRiver Riddle outEdges.push_back(outEdge); 1315a70aa7bbSRiver Riddle } 1316a70aa7bbSRiver Riddle }); 1317a70aa7bbSRiver Riddle 1318a70aa7bbSRiver Riddle // Add first candidate if any were returned. 1319a70aa7bbSRiver Riddle if (!outEdges.empty()) { 1320a70aa7bbSRiver Riddle visitedSibNodeIds->insert(outEdges[0].id); 1321a70aa7bbSRiver Riddle idAndMemrefToFuse->first = outEdges[0].id; 1322a70aa7bbSRiver Riddle idAndMemrefToFuse->second = outEdges[0].value; 1323a70aa7bbSRiver Riddle return true; 1324a70aa7bbSRiver Riddle } 1325a70aa7bbSRiver Riddle } 1326a70aa7bbSRiver Riddle return false; 1327a70aa7bbSRiver Riddle } 1328a70aa7bbSRiver Riddle 1329a70aa7bbSRiver Riddle /// Update data dependence graph state to reflect sibling fusion of 'sibNode' 1330a70aa7bbSRiver Riddle /// into 'dstNode'. 1331a70aa7bbSRiver Riddle void updateStateAfterSiblingFusion(Node *sibNode, Node *dstNode) { 1332a70aa7bbSRiver Riddle // Update 'sibNode' and 'dstNode' input/output edges to reflect fusion. 1333a70aa7bbSRiver Riddle mdg->updateEdges(sibNode->id, dstNode->id); 1334a70aa7bbSRiver Riddle 1335a70aa7bbSRiver Riddle // Collect dst loop stats after memref privatization transformation. 
1336a70aa7bbSRiver Riddle auto dstForInst = cast<AffineForOp>(dstNode->op); 1337a70aa7bbSRiver Riddle LoopNestStateCollector dstLoopCollector; 1338825b23a7SUday Bondhugula dstLoopCollector.collect(dstForInst); 1339a70aa7bbSRiver Riddle // Clear and add back loads and stores 1340a70aa7bbSRiver Riddle mdg->clearNodeLoadAndStores(dstNode->id); 1341a70aa7bbSRiver Riddle mdg->addToNode(dstNode->id, dstLoopCollector.loadOpInsts, 1342a70aa7bbSRiver Riddle dstLoopCollector.storeOpInsts); 1343a70aa7bbSRiver Riddle // Remove old sibling loop nest if it no longer has outgoing dependence 1344fe9d0a47SUday Bondhugula // edges, and it does not write to a memref which escapes the block. 1345a70aa7bbSRiver Riddle if (mdg->getOutEdgeCount(sibNode->id) == 0) { 13469f235a88SVitaly Buka Operation *op = sibNode->op; 1347a70aa7bbSRiver Riddle mdg->removeNode(sibNode->id); 13489f235a88SVitaly Buka op->erase(); 1349a70aa7bbSRiver Riddle } 1350a70aa7bbSRiver Riddle } 1351a70aa7bbSRiver Riddle 1352a70aa7bbSRiver Riddle // Clean up any allocs with no users. 1353a70aa7bbSRiver Riddle void eraseUnusedMemRefAllocations() { 1354a70aa7bbSRiver Riddle for (auto &pair : mdg->memrefEdgeCount) { 1355a70aa7bbSRiver Riddle if (pair.second > 0) 1356a70aa7bbSRiver Riddle continue; 1357a70aa7bbSRiver Riddle auto memref = pair.first; 1358a70aa7bbSRiver Riddle // Skip if there exist other uses (return operation or function calls). 1359a70aa7bbSRiver Riddle if (!memref.use_empty()) 1360a70aa7bbSRiver Riddle continue; 1361a70aa7bbSRiver Riddle // Use list expected to match the dep graph info. 1362a70aa7bbSRiver Riddle auto *op = memref.getDefiningOp(); 1363a70aa7bbSRiver Riddle if (isa_and_nonnull<memref::AllocOp>(op)) 1364a70aa7bbSRiver Riddle op->erase(); 1365a70aa7bbSRiver Riddle } 1366a70aa7bbSRiver Riddle } 1367a70aa7bbSRiver Riddle }; 1368a70aa7bbSRiver Riddle 1369a70aa7bbSRiver Riddle } // namespace 1370a70aa7bbSRiver Riddle 1371fe9d0a47SUday Bondhugula /// Run fusion on `block`. 
1372fe9d0a47SUday Bondhugula void LoopFusion::runOnBlock(Block *block) { 1373fe9d0a47SUday Bondhugula MemRefDependenceGraph g(*block); 13743ab88e79SUday Bondhugula if (!g.init()) { 13753ab88e79SUday Bondhugula LLVM_DEBUG(llvm::dbgs() << "MDG init failed\n"); 1376a70aa7bbSRiver Riddle return; 13773ab88e79SUday Bondhugula } 1378a70aa7bbSRiver Riddle 13790a81ace0SKazu Hirata std::optional<unsigned> fastMemorySpaceOpt; 1380a70aa7bbSRiver Riddle if (fastMemorySpace.hasValue()) 1381a70aa7bbSRiver Riddle fastMemorySpaceOpt = fastMemorySpace; 1382a70aa7bbSRiver Riddle unsigned localBufSizeThresholdBytes = localBufSizeThreshold * 1024; 1383a70aa7bbSRiver Riddle GreedyFusion fusion(&g, localBufSizeThresholdBytes, fastMemorySpaceOpt, 1384a70aa7bbSRiver Riddle maximalFusion, computeToleranceThreshold); 1385a70aa7bbSRiver Riddle 1386a70aa7bbSRiver Riddle if (affineFusionMode == FusionMode::ProducerConsumer) 1387a70aa7bbSRiver Riddle fusion.runProducerConsumerFusionOnly(); 1388a70aa7bbSRiver Riddle else if (affineFusionMode == FusionMode::Sibling) 1389a70aa7bbSRiver Riddle fusion.runSiblingFusionOnly(); 1390a70aa7bbSRiver Riddle else 1391a70aa7bbSRiver Riddle fusion.runGreedyFusion(); 1392a70aa7bbSRiver Riddle } 1393fe9d0a47SUday Bondhugula 1394fe9d0a47SUday Bondhugula void LoopFusion::runOnOperation() { 1395c79ffb02SUday Bondhugula // Call fusion on every op that has at least two affine.for nests (in post 1396c79ffb02SUday Bondhugula // order). 
1397c79ffb02SUday Bondhugula getOperation()->walk([&](Operation *op) { 1398c79ffb02SUday Bondhugula for (Region ®ion : op->getRegions()) { 1399c79ffb02SUday Bondhugula for (Block &block : region.getBlocks()) { 1400c79ffb02SUday Bondhugula auto affineFors = block.getOps<AffineForOp>(); 1401c79ffb02SUday Bondhugula if (!affineFors.empty() && !llvm::hasSingleElement(affineFors)) 1402fe9d0a47SUday Bondhugula runOnBlock(&block); 1403fe9d0a47SUday Bondhugula } 1404c79ffb02SUday Bondhugula } 1405c79ffb02SUday Bondhugula }); 1406c79ffb02SUday Bondhugula } 1407c910570fSUday Bondhugula 14084c48f016SMatthias Springer std::unique_ptr<Pass> mlir::affine::createLoopFusionPass( 14094c48f016SMatthias Springer unsigned fastMemorySpace, uint64_t localBufSizeThreshold, 14104c48f016SMatthias Springer bool maximalFusion, enum FusionMode affineFusionMode) { 1411c910570fSUday Bondhugula return std::make_unique<LoopFusion>(fastMemorySpace, localBufSizeThreshold, 1412c910570fSUday Bondhugula maximalFusion, affineFusionMode); 1413c910570fSUday Bondhugula } 1414