xref: /llvm-project/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp (revision b00e0c167186d69e1e6bceda57c09b272bd6acfc)
1a70aa7bbSRiver Riddle //===- LoopFusionUtils.cpp ---- Utilities for loop fusion ----------===//
2a70aa7bbSRiver Riddle //
3a70aa7bbSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4a70aa7bbSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
5a70aa7bbSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6a70aa7bbSRiver Riddle //
7a70aa7bbSRiver Riddle //===----------------------------------------------------------------------===//
8a70aa7bbSRiver Riddle //
9a70aa7bbSRiver Riddle // This file implements loop fusion transformation utility functions.
10a70aa7bbSRiver Riddle //
11a70aa7bbSRiver Riddle //===----------------------------------------------------------------------===//
12a70aa7bbSRiver Riddle 
13a70aa7bbSRiver Riddle #include "mlir/Dialect/Affine/LoopFusionUtils.h"
14a70aa7bbSRiver Riddle #include "mlir/Analysis/SliceAnalysis.h"
15*b00e0c16SChristian Ulmann #include "mlir/Analysis/TopologicalSortUtils.h"
16a70aa7bbSRiver Riddle #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
17a70aa7bbSRiver Riddle #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
18a70aa7bbSRiver Riddle #include "mlir/Dialect/Affine/Analysis/Utils.h"
19a70aa7bbSRiver Riddle #include "mlir/Dialect/Affine/IR/AffineOps.h"
20a70aa7bbSRiver Riddle #include "mlir/Dialect/Affine/LoopUtils.h"
214d67b278SJeff Niu #include "mlir/IR/IRMapping.h"
22a70aa7bbSRiver Riddle #include "mlir/IR/Operation.h"
2363086d6aSMatthias Springer #include "mlir/IR/PatternMatch.h"
24a70aa7bbSRiver Riddle #include "llvm/Support/Debug.h"
25a70aa7bbSRiver Riddle #include "llvm/Support/raw_ostream.h"
26a1fe1f5fSKazu Hirata #include <optional>
27a70aa7bbSRiver Riddle 
28a70aa7bbSRiver Riddle #define DEBUG_TYPE "loop-fusion-utils"
29a70aa7bbSRiver Riddle 
30a70aa7bbSRiver Riddle using namespace mlir;
314c48f016SMatthias Springer using namespace mlir::affine;
32a70aa7bbSRiver Riddle 
33a70aa7bbSRiver Riddle // Gathers all load and store memref accesses in 'opA' into 'values', where
34a70aa7bbSRiver Riddle // 'values[memref] == true' for each store operation.
getLoadAndStoreMemRefAccesses(Operation * opA,DenseMap<Value,bool> & values)35a70aa7bbSRiver Riddle static void getLoadAndStoreMemRefAccesses(Operation *opA,
36a70aa7bbSRiver Riddle                                           DenseMap<Value, bool> &values) {
37a70aa7bbSRiver Riddle   opA->walk([&](Operation *op) {
38a70aa7bbSRiver Riddle     if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) {
39a70aa7bbSRiver Riddle       if (values.count(loadOp.getMemRef()) == 0)
40a70aa7bbSRiver Riddle         values[loadOp.getMemRef()] = false;
41a70aa7bbSRiver Riddle     } else if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) {
42a70aa7bbSRiver Riddle       values[storeOp.getMemRef()] = true;
43a70aa7bbSRiver Riddle     }
44a70aa7bbSRiver Riddle   });
45a70aa7bbSRiver Riddle }
46a70aa7bbSRiver Riddle 
47a70aa7bbSRiver Riddle /// Returns true if 'op' is a load or store operation which access a memref
48a70aa7bbSRiver Riddle /// accessed 'values' and at least one of the access is a store operation.
49a70aa7bbSRiver Riddle /// Returns false otherwise.
isDependentLoadOrStoreOp(Operation * op,DenseMap<Value,bool> & values)50a70aa7bbSRiver Riddle static bool isDependentLoadOrStoreOp(Operation *op,
51a70aa7bbSRiver Riddle                                      DenseMap<Value, bool> &values) {
52a70aa7bbSRiver Riddle   if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) {
53a70aa7bbSRiver Riddle     return values.count(loadOp.getMemRef()) > 0 && values[loadOp.getMemRef()];
54a70aa7bbSRiver Riddle   }
55a70aa7bbSRiver Riddle   if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) {
56a70aa7bbSRiver Riddle     return values.count(storeOp.getMemRef()) > 0;
57a70aa7bbSRiver Riddle   }
58a70aa7bbSRiver Riddle   return false;
59a70aa7bbSRiver Riddle }
60a70aa7bbSRiver Riddle 
61a70aa7bbSRiver Riddle // Returns the first operation in range ('opA', 'opB') which has a data
62a70aa7bbSRiver Riddle // dependence on 'opA'. Returns 'nullptr' of no dependence exists.
getFirstDependentOpInRange(Operation * opA,Operation * opB)63a70aa7bbSRiver Riddle static Operation *getFirstDependentOpInRange(Operation *opA, Operation *opB) {
64a70aa7bbSRiver Riddle   // Record memref values from all loads/store in loop nest rooted at 'opA'.
65a70aa7bbSRiver Riddle   // Map from memref value to bool which is true if store, false otherwise.
66a70aa7bbSRiver Riddle   DenseMap<Value, bool> values;
67a70aa7bbSRiver Riddle   getLoadAndStoreMemRefAccesses(opA, values);
68a70aa7bbSRiver Riddle 
69a70aa7bbSRiver Riddle   // For each 'opX' in block in range ('opA', 'opB'), check if there is a data
70a70aa7bbSRiver Riddle   // dependence from 'opA' to 'opX' ('opA' and 'opX' access the same memref
71a70aa7bbSRiver Riddle   // and at least one of the accesses is a store).
72a70aa7bbSRiver Riddle   Operation *firstDepOp = nullptr;
73a70aa7bbSRiver Riddle   for (Block::iterator it = std::next(Block::iterator(opA));
74a70aa7bbSRiver Riddle        it != Block::iterator(opB); ++it) {
75a70aa7bbSRiver Riddle     Operation *opX = &(*it);
76a70aa7bbSRiver Riddle     opX->walk([&](Operation *op) {
77a70aa7bbSRiver Riddle       if (!firstDepOp && isDependentLoadOrStoreOp(op, values))
78a70aa7bbSRiver Riddle         firstDepOp = opX;
79a70aa7bbSRiver Riddle     });
80a70aa7bbSRiver Riddle     if (firstDepOp)
81a70aa7bbSRiver Riddle       break;
82a70aa7bbSRiver Riddle   }
83a70aa7bbSRiver Riddle   return firstDepOp;
84a70aa7bbSRiver Riddle }
85a70aa7bbSRiver Riddle 
86a70aa7bbSRiver Riddle // Returns the last operation 'opX' in range ('opA', 'opB'), for which there
87a70aa7bbSRiver Riddle // exists a data dependence from 'opX' to 'opB'.
88a70aa7bbSRiver Riddle // Returns 'nullptr' of no dependence exists.
getLastDependentOpInRange(Operation * opA,Operation * opB)89a70aa7bbSRiver Riddle static Operation *getLastDependentOpInRange(Operation *opA, Operation *opB) {
90a70aa7bbSRiver Riddle   // Record memref values from all loads/store in loop nest rooted at 'opB'.
91a70aa7bbSRiver Riddle   // Map from memref value to bool which is true if store, false otherwise.
92a70aa7bbSRiver Riddle   DenseMap<Value, bool> values;
93a70aa7bbSRiver Riddle   getLoadAndStoreMemRefAccesses(opB, values);
94a70aa7bbSRiver Riddle 
95a70aa7bbSRiver Riddle   // For each 'opX' in block in range ('opA', 'opB') in reverse order,
96a70aa7bbSRiver Riddle   // check if there is a data dependence from 'opX' to 'opB':
97a70aa7bbSRiver Riddle   // *) 'opX' and 'opB' access the same memref and at least one of the accesses
98a70aa7bbSRiver Riddle   //    is a store.
99a70aa7bbSRiver Riddle   // *) 'opX' produces an SSA Value which is used by 'opB'.
100a70aa7bbSRiver Riddle   Operation *lastDepOp = nullptr;
101a70aa7bbSRiver Riddle   for (Block::reverse_iterator it = std::next(Block::reverse_iterator(opB));
102a70aa7bbSRiver Riddle        it != Block::reverse_iterator(opA); ++it) {
103a70aa7bbSRiver Riddle     Operation *opX = &(*it);
104a70aa7bbSRiver Riddle     opX->walk([&](Operation *op) {
105a70aa7bbSRiver Riddle       if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) {
106a70aa7bbSRiver Riddle         if (isDependentLoadOrStoreOp(op, values)) {
107a70aa7bbSRiver Riddle           lastDepOp = opX;
108a70aa7bbSRiver Riddle           return WalkResult::interrupt();
109a70aa7bbSRiver Riddle         }
110a70aa7bbSRiver Riddle         return WalkResult::advance();
111a70aa7bbSRiver Riddle       }
112825b23a7SUday Bondhugula       for (Value value : op->getResults()) {
113a70aa7bbSRiver Riddle         for (Operation *user : value.getUsers()) {
114a70aa7bbSRiver Riddle           SmallVector<AffineForOp, 4> loops;
115a70aa7bbSRiver Riddle           // Check if any loop in loop nest surrounding 'user' is 'opB'.
11623bcd6b8SUday Bondhugula           getAffineForIVs(*user, &loops);
117a70aa7bbSRiver Riddle           if (llvm::is_contained(loops, cast<AffineForOp>(opB))) {
118a70aa7bbSRiver Riddle             lastDepOp = opX;
119a70aa7bbSRiver Riddle             return WalkResult::interrupt();
120a70aa7bbSRiver Riddle           }
121a70aa7bbSRiver Riddle         }
122a70aa7bbSRiver Riddle       }
123a70aa7bbSRiver Riddle       return WalkResult::advance();
124a70aa7bbSRiver Riddle     });
125a70aa7bbSRiver Riddle     if (lastDepOp)
126a70aa7bbSRiver Riddle       break;
127a70aa7bbSRiver Riddle   }
128a70aa7bbSRiver Riddle   return lastDepOp;
129a70aa7bbSRiver Riddle }
130a70aa7bbSRiver Riddle 
131a70aa7bbSRiver Riddle // Computes and returns an insertion point operation, before which the
132a70aa7bbSRiver Riddle // the fused <srcForOp, dstForOp> loop nest can be inserted while preserving
133a70aa7bbSRiver Riddle // dependences. Returns nullptr if no such insertion point is found.
getFusedLoopNestInsertionPoint(AffineForOp srcForOp,AffineForOp dstForOp)134a70aa7bbSRiver Riddle static Operation *getFusedLoopNestInsertionPoint(AffineForOp srcForOp,
135a70aa7bbSRiver Riddle                                                  AffineForOp dstForOp) {
136825b23a7SUday Bondhugula   bool isSrcForOpBeforeDstForOp = srcForOp->isBeforeInBlock(dstForOp);
137a70aa7bbSRiver Riddle   auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp;
138a70aa7bbSRiver Riddle   auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp;
139a70aa7bbSRiver Riddle 
140825b23a7SUday Bondhugula   Operation *firstDepOpA = getFirstDependentOpInRange(forOpA, forOpB);
141825b23a7SUday Bondhugula   Operation *lastDepOpB = getLastDependentOpInRange(forOpA, forOpB);
142a70aa7bbSRiver Riddle   // Block:
143a70aa7bbSRiver Riddle   //      ...
144a70aa7bbSRiver Riddle   //  |-- opA
145a70aa7bbSRiver Riddle   //  |   ...
146a70aa7bbSRiver Riddle   //  |   lastDepOpB --|
147a70aa7bbSRiver Riddle   //  |   ...          |
148a70aa7bbSRiver Riddle   //  |-> firstDepOpA  |
149a70aa7bbSRiver Riddle   //      ...          |
150a70aa7bbSRiver Riddle   //      opB <---------
151a70aa7bbSRiver Riddle   //
152a70aa7bbSRiver Riddle   // Valid insertion point range: (lastDepOpB, firstDepOpA)
153a70aa7bbSRiver Riddle   //
154b4785cebSUday Bondhugula   if (firstDepOpA) {
155b4785cebSUday Bondhugula     if (lastDepOpB) {
156a70aa7bbSRiver Riddle       if (firstDepOpA->isBeforeInBlock(lastDepOpB) || firstDepOpA == lastDepOpB)
157a70aa7bbSRiver Riddle         // No valid insertion point exists which preserves dependences.
158a70aa7bbSRiver Riddle         return nullptr;
159a70aa7bbSRiver Riddle     }
160a70aa7bbSRiver Riddle     // Return insertion point in valid range closest to 'opB'.
161a70aa7bbSRiver Riddle     // TODO: Consider other insertion points in valid range.
162a70aa7bbSRiver Riddle     return firstDepOpA;
163a70aa7bbSRiver Riddle   }
164a70aa7bbSRiver Riddle   // No dependences from 'opA' to operation in range ('opA', 'opB'), return
165a70aa7bbSRiver Riddle   // 'opB' insertion point.
166825b23a7SUday Bondhugula   return forOpB;
167a70aa7bbSRiver Riddle }
168a70aa7bbSRiver Riddle 
169a70aa7bbSRiver Riddle // Gathers all load and store ops in loop nest rooted at 'forOp' into
170a70aa7bbSRiver Riddle // 'loadAndStoreOps'.
171a70aa7bbSRiver Riddle static bool
gatherLoadsAndStores(AffineForOp forOp,SmallVectorImpl<Operation * > & loadAndStoreOps)172a70aa7bbSRiver Riddle gatherLoadsAndStores(AffineForOp forOp,
173a70aa7bbSRiver Riddle                      SmallVectorImpl<Operation *> &loadAndStoreOps) {
174a70aa7bbSRiver Riddle   bool hasIfOp = false;
175a70aa7bbSRiver Riddle   forOp.walk([&](Operation *op) {
176a70aa7bbSRiver Riddle     if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op))
177a70aa7bbSRiver Riddle       loadAndStoreOps.push_back(op);
178a70aa7bbSRiver Riddle     else if (isa<AffineIfOp>(op))
179a70aa7bbSRiver Riddle       hasIfOp = true;
180a70aa7bbSRiver Riddle   });
181a70aa7bbSRiver Riddle   return !hasIfOp;
182a70aa7bbSRiver Riddle }
183a70aa7bbSRiver Riddle 
184a70aa7bbSRiver Riddle /// Returns the maximum loop depth at which we could fuse producer loop
185a70aa7bbSRiver Riddle /// 'srcForOp' into consumer loop 'dstForOp' without violating data dependences.
186a70aa7bbSRiver Riddle // TODO: Generalize this check for sibling and more generic fusion scenarios.
187a70aa7bbSRiver Riddle // TODO: Support forward slice fusion.
getMaxLoopDepth(ArrayRef<Operation * > srcOps,ArrayRef<Operation * > dstOps)188a70aa7bbSRiver Riddle static unsigned getMaxLoopDepth(ArrayRef<Operation *> srcOps,
189a70aa7bbSRiver Riddle                                 ArrayRef<Operation *> dstOps) {
190a70aa7bbSRiver Riddle   if (dstOps.empty())
191a70aa7bbSRiver Riddle     // Expected at least one memory operation.
192a70aa7bbSRiver Riddle     // TODO: Revisit this case with a specific example.
193a70aa7bbSRiver Riddle     return 0;
194a70aa7bbSRiver Riddle 
195a70aa7bbSRiver Riddle   // Filter out ops in 'dstOps' that do not use the producer-consumer memref so
196a70aa7bbSRiver Riddle   // that they are not considered for analysis.
197a70aa7bbSRiver Riddle   DenseSet<Value> producerConsumerMemrefs;
198a70aa7bbSRiver Riddle   gatherProducerConsumerMemrefs(srcOps, dstOps, producerConsumerMemrefs);
199a70aa7bbSRiver Riddle   SmallVector<Operation *, 4> targetDstOps;
200a70aa7bbSRiver Riddle   for (Operation *dstOp : dstOps) {
201a70aa7bbSRiver Riddle     auto loadOp = dyn_cast<AffineReadOpInterface>(dstOp);
202a70aa7bbSRiver Riddle     Value memref = loadOp ? loadOp.getMemRef()
203a70aa7bbSRiver Riddle                           : cast<AffineWriteOpInterface>(dstOp).getMemRef();
204a70aa7bbSRiver Riddle     if (producerConsumerMemrefs.count(memref) > 0)
205a70aa7bbSRiver Riddle       targetDstOps.push_back(dstOp);
206a70aa7bbSRiver Riddle   }
207a70aa7bbSRiver Riddle 
208a70aa7bbSRiver Riddle   assert(!targetDstOps.empty() &&
209a70aa7bbSRiver Riddle          "No dependences between 'srcForOp' and 'dstForOp'?");
210a70aa7bbSRiver Riddle 
211a70aa7bbSRiver Riddle   // Compute the innermost common loop depth for loads and stores.
212a70aa7bbSRiver Riddle   unsigned loopDepth = getInnermostCommonLoopDepth(targetDstOps);
213a70aa7bbSRiver Riddle 
214a70aa7bbSRiver Riddle   // Return common loop depth for loads if there are no store ops.
215971b8525SJakub Kuderski   if (all_of(targetDstOps, llvm::IsaPred<AffineReadOpInterface>))
216a70aa7bbSRiver Riddle     return loopDepth;
217a70aa7bbSRiver Riddle 
218a70aa7bbSRiver Riddle   // Check dependences on all pairs of ops in 'targetDstOps' and store the
219a70aa7bbSRiver Riddle   // minimum loop depth at which a dependence is satisfied.
220a70aa7bbSRiver Riddle   for (unsigned i = 0, e = targetDstOps.size(); i < e; ++i) {
221b4785cebSUday Bondhugula     Operation *srcOpInst = targetDstOps[i];
222a70aa7bbSRiver Riddle     MemRefAccess srcAccess(srcOpInst);
223a70aa7bbSRiver Riddle     for (unsigned j = 0; j < e; ++j) {
224a70aa7bbSRiver Riddle       auto *dstOpInst = targetDstOps[j];
225a70aa7bbSRiver Riddle       MemRefAccess dstAccess(dstOpInst);
226a70aa7bbSRiver Riddle 
227a70aa7bbSRiver Riddle       unsigned numCommonLoops =
228a70aa7bbSRiver Riddle           getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst);
229a70aa7bbSRiver Riddle       for (unsigned d = 1; d <= numCommonLoops + 1; ++d) {
230a70aa7bbSRiver Riddle         // TODO: Cache dependence analysis results, check cache here.
231ee7c4741SMatthias Springer         DependenceResult result =
232ee7c4741SMatthias Springer             checkMemrefAccessDependence(srcAccess, dstAccess, d);
233a70aa7bbSRiver Riddle         if (hasDependence(result)) {
234a70aa7bbSRiver Riddle           // Store minimum loop depth and break because we want the min 'd' at
235a70aa7bbSRiver Riddle           // which there is a dependence.
236a70aa7bbSRiver Riddle           loopDepth = std::min(loopDepth, d - 1);
237a70aa7bbSRiver Riddle           break;
238a70aa7bbSRiver Riddle         }
239a70aa7bbSRiver Riddle       }
240a70aa7bbSRiver Riddle     }
241a70aa7bbSRiver Riddle   }
242a70aa7bbSRiver Riddle 
243a70aa7bbSRiver Riddle   return loopDepth;
244a70aa7bbSRiver Riddle }
245a70aa7bbSRiver Riddle 
246a70aa7bbSRiver Riddle // TODO: This pass performs some computation that is the same for all the depths
247a70aa7bbSRiver Riddle // (e.g., getMaxLoopDepth). Implement a version of this utility that processes
248a70aa7bbSRiver Riddle // all the depths at once or only the legal maximal depth for maximal fusion.
canFuseLoops(AffineForOp srcForOp,AffineForOp dstForOp,unsigned dstLoopDepth,ComputationSliceState * srcSlice,FusionStrategy fusionStrategy)2494c48f016SMatthias Springer FusionResult mlir::affine::canFuseLoops(AffineForOp srcForOp,
2504c48f016SMatthias Springer                                         AffineForOp dstForOp,
251a70aa7bbSRiver Riddle                                         unsigned dstLoopDepth,
252a70aa7bbSRiver Riddle                                         ComputationSliceState *srcSlice,
253a70aa7bbSRiver Riddle                                         FusionStrategy fusionStrategy) {
254a70aa7bbSRiver Riddle   // Return 'failure' if 'dstLoopDepth == 0'.
255a70aa7bbSRiver Riddle   if (dstLoopDepth == 0) {
256a70aa7bbSRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests at depth 0\n");
257a70aa7bbSRiver Riddle     return FusionResult::FailPrecondition;
258a70aa7bbSRiver Riddle   }
259a70aa7bbSRiver Riddle   // Return 'failure' if 'srcForOp' and 'dstForOp' are not in the same block.
260a70aa7bbSRiver Riddle   auto *block = srcForOp->getBlock();
261a70aa7bbSRiver Riddle   if (block != dstForOp->getBlock()) {
262a70aa7bbSRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests in different blocks\n");
263a70aa7bbSRiver Riddle     return FusionResult::FailPrecondition;
264a70aa7bbSRiver Riddle   }
265a70aa7bbSRiver Riddle 
266a70aa7bbSRiver Riddle   // Return 'failure' if no valid insertion point for fused loop nest in 'block'
267a70aa7bbSRiver Riddle   // exists which would preserve dependences.
268a70aa7bbSRiver Riddle   if (!getFusedLoopNestInsertionPoint(srcForOp, dstForOp)) {
269a70aa7bbSRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "Fusion would violate dependences in block\n");
270a70aa7bbSRiver Riddle     return FusionResult::FailBlockDependence;
271a70aa7bbSRiver Riddle   }
272a70aa7bbSRiver Riddle 
273a70aa7bbSRiver Riddle   // Check if 'srcForOp' precedes 'dstForOp' in 'block'.
274825b23a7SUday Bondhugula   bool isSrcForOpBeforeDstForOp = srcForOp->isBeforeInBlock(dstForOp);
275a70aa7bbSRiver Riddle   // 'forOpA' executes before 'forOpB' in 'block'.
276a70aa7bbSRiver Riddle   auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp;
277a70aa7bbSRiver Riddle   auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp;
278a70aa7bbSRiver Riddle 
279a70aa7bbSRiver Riddle   // Gather all load and store from 'forOpA' which precedes 'forOpB' in 'block'.
280a70aa7bbSRiver Riddle   SmallVector<Operation *, 4> opsA;
281a70aa7bbSRiver Riddle   if (!gatherLoadsAndStores(forOpA, opsA)) {
282a70aa7bbSRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n");
283a70aa7bbSRiver Riddle     return FusionResult::FailPrecondition;
284a70aa7bbSRiver Riddle   }
285a70aa7bbSRiver Riddle 
286a70aa7bbSRiver Riddle   // Gather all load and store from 'forOpB' which succeeds 'forOpA' in 'block'.
287a70aa7bbSRiver Riddle   SmallVector<Operation *, 4> opsB;
288a70aa7bbSRiver Riddle   if (!gatherLoadsAndStores(forOpB, opsB)) {
289a70aa7bbSRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n");
290a70aa7bbSRiver Riddle     return FusionResult::FailPrecondition;
291a70aa7bbSRiver Riddle   }
292a70aa7bbSRiver Riddle 
293a70aa7bbSRiver Riddle   // Return 'failure' if fusing loops at depth 'dstLoopDepth' wouldn't preserve
294a70aa7bbSRiver Riddle   // loop dependences.
295a70aa7bbSRiver Riddle   // TODO: Enable this check for sibling and more generic loop fusion
296a70aa7bbSRiver Riddle   // strategies.
297a70aa7bbSRiver Riddle   if (fusionStrategy.getStrategy() == FusionStrategy::ProducerConsumer) {
298a70aa7bbSRiver Riddle     // TODO: 'getMaxLoopDepth' does not support forward slice fusion.
299a70aa7bbSRiver Riddle     assert(isSrcForOpBeforeDstForOp && "Unexpected forward slice fusion");
300a70aa7bbSRiver Riddle     if (getMaxLoopDepth(opsA, opsB) < dstLoopDepth) {
301a70aa7bbSRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "Fusion would violate loop dependences\n");
302a70aa7bbSRiver Riddle       return FusionResult::FailFusionDependence;
303a70aa7bbSRiver Riddle     }
304a70aa7bbSRiver Riddle   }
305a70aa7bbSRiver Riddle 
306a70aa7bbSRiver Riddle   // Calculate the number of common loops surrounding 'srcForOp' and 'dstForOp'.
307825b23a7SUday Bondhugula   unsigned numCommonLoops =
3084c48f016SMatthias Springer       affine::getNumCommonSurroundingLoops(*srcForOp, *dstForOp);
309a70aa7bbSRiver Riddle 
310a70aa7bbSRiver Riddle   // Filter out ops in 'opsA' to compute the slice union based on the
311a70aa7bbSRiver Riddle   // assumptions made by the fusion strategy.
312a70aa7bbSRiver Riddle   SmallVector<Operation *, 4> strategyOpsA;
313a70aa7bbSRiver Riddle   switch (fusionStrategy.getStrategy()) {
314a70aa7bbSRiver Riddle   case FusionStrategy::Generic:
315a70aa7bbSRiver Riddle     // Generic fusion. Take into account all the memory operations to compute
316a70aa7bbSRiver Riddle     // the slice union.
317a70aa7bbSRiver Riddle     strategyOpsA.append(opsA.begin(), opsA.end());
318a70aa7bbSRiver Riddle     break;
319a70aa7bbSRiver Riddle   case FusionStrategy::ProducerConsumer:
320a70aa7bbSRiver Riddle     // Producer-consumer fusion (AffineLoopFusion pass) only takes into
321a70aa7bbSRiver Riddle     // account stores in 'srcForOp' to compute the slice union.
322a70aa7bbSRiver Riddle     for (Operation *op : opsA) {
323a70aa7bbSRiver Riddle       if (isa<AffineWriteOpInterface>(op))
324a70aa7bbSRiver Riddle         strategyOpsA.push_back(op);
325a70aa7bbSRiver Riddle     }
326a70aa7bbSRiver Riddle     break;
327a70aa7bbSRiver Riddle   case FusionStrategy::Sibling:
328a70aa7bbSRiver Riddle     // Sibling fusion (AffineLoopFusion pass) only takes into account the loads
329a70aa7bbSRiver Riddle     // to 'memref' in 'srcForOp' to compute the slice union.
330a70aa7bbSRiver Riddle     for (Operation *op : opsA) {
331a70aa7bbSRiver Riddle       auto load = dyn_cast<AffineReadOpInterface>(op);
332a70aa7bbSRiver Riddle       if (load && load.getMemRef() == fusionStrategy.getSiblingFusionMemRef())
333a70aa7bbSRiver Riddle         strategyOpsA.push_back(op);
334a70aa7bbSRiver Riddle     }
335a70aa7bbSRiver Riddle     break;
336a70aa7bbSRiver Riddle   }
337a70aa7bbSRiver Riddle 
338a70aa7bbSRiver Riddle   // Compute union of computation slices computed between all pairs of ops
339a70aa7bbSRiver Riddle   // from 'forOpA' and 'forOpB'.
3404c48f016SMatthias Springer   SliceComputationResult sliceComputationResult = affine::computeSliceUnion(
3414c48f016SMatthias Springer       strategyOpsA, opsB, dstLoopDepth, numCommonLoops,
342a70aa7bbSRiver Riddle       isSrcForOpBeforeDstForOp, srcSlice);
343a70aa7bbSRiver Riddle   if (sliceComputationResult.value == SliceComputationResult::GenericFailure) {
344a70aa7bbSRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "computeSliceUnion failed\n");
345a70aa7bbSRiver Riddle     return FusionResult::FailPrecondition;
346a70aa7bbSRiver Riddle   }
347a70aa7bbSRiver Riddle   if (sliceComputationResult.value ==
348a70aa7bbSRiver Riddle       SliceComputationResult::IncorrectSliceFailure) {
349a70aa7bbSRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "Incorrect slice computation\n");
350a70aa7bbSRiver Riddle     return FusionResult::FailIncorrectSlice;
351a70aa7bbSRiver Riddle   }
352a70aa7bbSRiver Riddle 
353a70aa7bbSRiver Riddle   return FusionResult::Success;
354a70aa7bbSRiver Riddle }
355a70aa7bbSRiver Riddle 
356a70aa7bbSRiver Riddle /// Patch the loop body of a forOp that is a single iteration reduction loop
357a70aa7bbSRiver Riddle /// into its containing block.
promoteSingleIterReductionLoop(AffineForOp forOp,bool siblingFusionUser)3584c48f016SMatthias Springer static LogicalResult promoteSingleIterReductionLoop(AffineForOp forOp,
359a70aa7bbSRiver Riddle                                                     bool siblingFusionUser) {
360a70aa7bbSRiver Riddle   // Check if the reduction loop is a single iteration loop.
3610a81ace0SKazu Hirata   std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
3626d5fc1e3SKazu Hirata   if (!tripCount || *tripCount != 1)
363a70aa7bbSRiver Riddle     return failure();
364a70aa7bbSRiver Riddle   auto *parentOp = forOp->getParentOp();
365a70aa7bbSRiver Riddle   if (!isa<AffineForOp>(parentOp))
366a70aa7bbSRiver Riddle     return failure();
36763086d6aSMatthias Springer   SmallVector<Value> newOperands;
36863086d6aSMatthias Springer   llvm::append_range(newOperands,
36963086d6aSMatthias Springer                      forOp.getBody()->getTerminator()->getOperands());
37063086d6aSMatthias Springer   IRRewriter rewriter(parentOp->getContext());
37163086d6aSMatthias Springer   int64_t parentOpNumResults = parentOp->getNumResults();
372a70aa7bbSRiver Riddle   // Replace the parent loop and add iteroperands and results from the `forOp`.
373a70aa7bbSRiver Riddle   AffineForOp parentForOp = forOp->getParentOfType<AffineForOp>();
37463086d6aSMatthias Springer   AffineForOp newLoop =
37563086d6aSMatthias Springer       cast<AffineForOp>(*parentForOp.replaceWithAdditionalYields(
37663086d6aSMatthias Springer           rewriter, forOp.getInits(), /*replaceInitOperandUsesInLoop=*/false,
37763086d6aSMatthias Springer           [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
37863086d6aSMatthias Springer             return newOperands;
37963086d6aSMatthias Springer           }));
380a70aa7bbSRiver Riddle 
381a70aa7bbSRiver Riddle   // For sibling-fusion users, collect operations that use the results of the
382a70aa7bbSRiver Riddle   // `forOp` outside the new parent loop that has absorbed all its iter args
383a70aa7bbSRiver Riddle   // and operands. These operations will be moved later after the results
384a70aa7bbSRiver Riddle   // have been replaced.
385a70aa7bbSRiver Riddle   SetVector<Operation *> forwardSlice;
386a70aa7bbSRiver Riddle   if (siblingFusionUser) {
387a70aa7bbSRiver Riddle     for (unsigned i = 0, e = forOp.getNumResults(); i != e; ++i) {
388a70aa7bbSRiver Riddle       SetVector<Operation *> tmpForwardSlice;
389a70aa7bbSRiver Riddle       getForwardSlice(forOp.getResult(i), &tmpForwardSlice);
390a70aa7bbSRiver Riddle       forwardSlice.set_union(tmpForwardSlice);
391a70aa7bbSRiver Riddle     }
392a70aa7bbSRiver Riddle   }
393a70aa7bbSRiver Riddle   // Update the results of the `forOp` in the new loop.
394a70aa7bbSRiver Riddle   for (unsigned i = 0, e = forOp.getNumResults(); i != e; ++i) {
395a70aa7bbSRiver Riddle     forOp.getResult(i).replaceAllUsesWith(
39663086d6aSMatthias Springer         newLoop.getResult(i + parentOpNumResults));
397a70aa7bbSRiver Riddle   }
398a70aa7bbSRiver Riddle   // For sibling-fusion users, move operations that use the results of the
399a70aa7bbSRiver Riddle   // `forOp` outside the new parent loop
400a70aa7bbSRiver Riddle   if (siblingFusionUser) {
401a70aa7bbSRiver Riddle     topologicalSort(forwardSlice);
402a70aa7bbSRiver Riddle     for (Operation *op : llvm::reverse(forwardSlice))
403a70aa7bbSRiver Riddle       op->moveAfter(newLoop);
404a70aa7bbSRiver Riddle   }
405a70aa7bbSRiver Riddle   // Replace the induction variable.
406a70aa7bbSRiver Riddle   auto iv = forOp.getInductionVar();
407a70aa7bbSRiver Riddle   iv.replaceAllUsesWith(newLoop.getInductionVar());
408a70aa7bbSRiver Riddle   // Replace the iter args.
409a70aa7bbSRiver Riddle   auto forOpIterArgs = forOp.getRegionIterArgs();
410a70aa7bbSRiver Riddle   for (auto it : llvm::zip(forOpIterArgs, newLoop.getRegionIterArgs().take_back(
411a70aa7bbSRiver Riddle                                               forOpIterArgs.size()))) {
412a70aa7bbSRiver Riddle     std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
413a70aa7bbSRiver Riddle   }
414a70aa7bbSRiver Riddle   // Move the loop body operations, except for its terminator, to the loop's
415a70aa7bbSRiver Riddle   // containing block.
416a70aa7bbSRiver Riddle   forOp.getBody()->back().erase();
417a70aa7bbSRiver Riddle   auto *parentBlock = forOp->getBlock();
418a70aa7bbSRiver Riddle   parentBlock->getOperations().splice(Block::iterator(forOp),
419a70aa7bbSRiver Riddle                                       forOp.getBody()->getOperations());
420a70aa7bbSRiver Riddle   forOp.erase();
421a70aa7bbSRiver Riddle   return success();
422a70aa7bbSRiver Riddle }
423a70aa7bbSRiver Riddle 
424a70aa7bbSRiver Riddle /// Fuses 'srcForOp' into 'dstForOp' with destination loop block insertion point
425a70aa7bbSRiver Riddle /// and source slice loop bounds specified in 'srcSlice'.
fuseLoops(AffineForOp srcForOp,AffineForOp dstForOp,const ComputationSliceState & srcSlice,bool isInnermostSiblingInsertion)4264c48f016SMatthias Springer void mlir::affine::fuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
427a70aa7bbSRiver Riddle                              const ComputationSliceState &srcSlice,
428a70aa7bbSRiver Riddle                              bool isInnermostSiblingInsertion) {
429a70aa7bbSRiver Riddle   // Clone 'srcForOp' into 'dstForOp' at 'srcSlice->insertPoint'.
430a70aa7bbSRiver Riddle   OpBuilder b(srcSlice.insertPoint->getBlock(), srcSlice.insertPoint);
4314d67b278SJeff Niu   IRMapping mapper;
432a70aa7bbSRiver Riddle   b.clone(*srcForOp, mapper);
433a70aa7bbSRiver Riddle 
434a70aa7bbSRiver Riddle   // Update 'sliceLoopNest' upper and lower bounds from computed 'srcSlice'.
435a70aa7bbSRiver Riddle   SmallVector<AffineForOp, 4> sliceLoops;
436a70aa7bbSRiver Riddle   for (unsigned i = 0, e = srcSlice.ivs.size(); i < e; ++i) {
437a70aa7bbSRiver Riddle     auto loopIV = mapper.lookupOrNull(srcSlice.ivs[i]);
438a70aa7bbSRiver Riddle     if (!loopIV)
439a70aa7bbSRiver Riddle       continue;
440a70aa7bbSRiver Riddle     auto forOp = getForInductionVarOwner(loopIV);
441a70aa7bbSRiver Riddle     sliceLoops.push_back(forOp);
442a70aa7bbSRiver Riddle     if (AffineMap lbMap = srcSlice.lbs[i]) {
443a70aa7bbSRiver Riddle       auto lbOperands = srcSlice.lbOperands[i];
444a70aa7bbSRiver Riddle       canonicalizeMapAndOperands(&lbMap, &lbOperands);
445a70aa7bbSRiver Riddle       forOp.setLowerBound(lbOperands, lbMap);
446a70aa7bbSRiver Riddle     }
447a70aa7bbSRiver Riddle     if (AffineMap ubMap = srcSlice.ubs[i]) {
448a70aa7bbSRiver Riddle       auto ubOperands = srcSlice.ubOperands[i];
449a70aa7bbSRiver Riddle       canonicalizeMapAndOperands(&ubMap, &ubOperands);
450a70aa7bbSRiver Riddle       forOp.setUpperBound(ubOperands, ubMap);
451a70aa7bbSRiver Riddle     }
452a70aa7bbSRiver Riddle   }
453a70aa7bbSRiver Riddle 
454a70aa7bbSRiver Riddle   llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
455a70aa7bbSRiver Riddle   auto srcIsUnitSlice = [&]() {
456a70aa7bbSRiver Riddle     return (buildSliceTripCountMap(srcSlice, &sliceTripCountMap) &&
457a70aa7bbSRiver Riddle             (getSliceIterationCount(sliceTripCountMap) == 1));
458a70aa7bbSRiver Riddle   };
459a70aa7bbSRiver Riddle   // Fix up and if possible, eliminate single iteration loops.
460a70aa7bbSRiver Riddle   for (AffineForOp forOp : sliceLoops) {
461a70aa7bbSRiver Riddle     if (isLoopParallelAndContainsReduction(forOp) &&
462a70aa7bbSRiver Riddle         isInnermostSiblingInsertion && srcIsUnitSlice())
463a70aa7bbSRiver Riddle       // Patch reduction loop - only ones that are sibling-fused with the
464a70aa7bbSRiver Riddle       // destination loop - into the parent loop.
465a70aa7bbSRiver Riddle       (void)promoteSingleIterReductionLoop(forOp, true);
466d1e9c7b6Smax     else
467d1e9c7b6Smax       // Promote any single iteration slice loops.
468d1e9c7b6Smax       (void)promoteIfSingleIteration(forOp);
469a70aa7bbSRiver Riddle   }
470a70aa7bbSRiver Riddle }
471a70aa7bbSRiver Riddle 
472a70aa7bbSRiver Riddle /// Collect loop nest statistics (eg. loop trip count and operation count)
473a70aa7bbSRiver Riddle /// in 'stats' for loop nest rooted at 'forOp'. Returns true on success,
474a70aa7bbSRiver Riddle /// returns false otherwise.
getLoopNestStats(AffineForOp forOpRoot,LoopNestStats * stats)4754c48f016SMatthias Springer bool mlir::affine::getLoopNestStats(AffineForOp forOpRoot,
4764c48f016SMatthias Springer                                     LoopNestStats *stats) {
477a70aa7bbSRiver Riddle   auto walkResult = forOpRoot.walk([&](AffineForOp forOp) {
478a70aa7bbSRiver Riddle     auto *childForOp = forOp.getOperation();
479a70aa7bbSRiver Riddle     auto *parentForOp = forOp->getParentOp();
480fe9d0a47SUday Bondhugula     if (forOp != forOpRoot) {
481a70aa7bbSRiver Riddle       if (!isa<AffineForOp>(parentForOp)) {
482a70aa7bbSRiver Riddle         LLVM_DEBUG(llvm::dbgs() << "Expected parent AffineForOp\n");
483a70aa7bbSRiver Riddle         return WalkResult::interrupt();
484a70aa7bbSRiver Riddle       }
485a70aa7bbSRiver Riddle       // Add mapping to 'forOp' from its parent AffineForOp.
486a70aa7bbSRiver Riddle       stats->loopMap[parentForOp].push_back(forOp);
487a70aa7bbSRiver Riddle     }
488a70aa7bbSRiver Riddle 
489a70aa7bbSRiver Riddle     // Record the number of op operations in the body of 'forOp'.
490a70aa7bbSRiver Riddle     unsigned count = 0;
491a70aa7bbSRiver Riddle     stats->opCountMap[childForOp] = 0;
492a70aa7bbSRiver Riddle     for (auto &op : *forOp.getBody()) {
493a70aa7bbSRiver Riddle       if (!isa<AffineForOp, AffineIfOp>(op))
494a70aa7bbSRiver Riddle         ++count;
495a70aa7bbSRiver Riddle     }
496a70aa7bbSRiver Riddle     stats->opCountMap[childForOp] = count;
497a70aa7bbSRiver Riddle 
498a70aa7bbSRiver Riddle     // Record trip count for 'forOp'. Set flag if trip count is not
499a70aa7bbSRiver Riddle     // constant.
5000a81ace0SKazu Hirata     std::optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
501037f0995SKazu Hirata     if (!maybeConstTripCount) {
502a70aa7bbSRiver Riddle       // Currently only constant trip count loop nests are supported.
503a70aa7bbSRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count unsupported\n");
504a70aa7bbSRiver Riddle       return WalkResult::interrupt();
505a70aa7bbSRiver Riddle     }
506a70aa7bbSRiver Riddle 
5076d5fc1e3SKazu Hirata     stats->tripCountMap[childForOp] = *maybeConstTripCount;
508a70aa7bbSRiver Riddle     return WalkResult::advance();
509a70aa7bbSRiver Riddle   });
510a70aa7bbSRiver Riddle   return !walkResult.wasInterrupted();
511a70aa7bbSRiver Riddle }
512a70aa7bbSRiver Riddle 
513a70aa7bbSRiver Riddle // Computes the total cost of the loop nest rooted at 'forOp'.
514a70aa7bbSRiver Riddle // Currently, the total cost is computed by counting the total operation
515a70aa7bbSRiver Riddle // instance count (i.e. total number of operations in the loop bodyloop
516a70aa7bbSRiver Riddle // operation count * loop trip count) for the entire loop nest.
517a70aa7bbSRiver Riddle // If 'tripCountOverrideMap' is non-null, overrides the trip count for loops
518a70aa7bbSRiver Riddle // specified in the map when computing the total op instance count.
519a70aa7bbSRiver Riddle // NOTEs: 1) This is used to compute the cost of computation slices, which are
520a70aa7bbSRiver Riddle // sliced along the iteration dimension, and thus reduce the trip count.
521a70aa7bbSRiver Riddle // If 'computeCostMap' is non-null, the total op count for forOps specified
522a70aa7bbSRiver Riddle // in the map is increased (not overridden) by adding the op count from the
523a70aa7bbSRiver Riddle // map to the existing op count for the for loop. This is done before
524a70aa7bbSRiver Riddle // multiplying by the loop's trip count, and is used to model the cost of
525a70aa7bbSRiver Riddle // inserting a sliced loop nest of known cost into the loop's body.
526a70aa7bbSRiver Riddle // 2) This is also used to compute the cost of fusing a slice of some loop nest
527a70aa7bbSRiver Riddle // within another loop.
getComputeCostHelper(Operation * forOp,LoopNestStats & stats,llvm::SmallDenseMap<Operation *,uint64_t,8> * tripCountOverrideMap,DenseMap<Operation *,int64_t> * computeCostMap)528a70aa7bbSRiver Riddle static int64_t getComputeCostHelper(
529a70aa7bbSRiver Riddle     Operation *forOp, LoopNestStats &stats,
530a70aa7bbSRiver Riddle     llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountOverrideMap,
531a70aa7bbSRiver Riddle     DenseMap<Operation *, int64_t> *computeCostMap) {
532a70aa7bbSRiver Riddle   // 'opCount' is the total number operations in one iteration of 'forOp' body,
533a70aa7bbSRiver Riddle   // minus terminator op which is a no-op.
534a70aa7bbSRiver Riddle   int64_t opCount = stats.opCountMap[forOp] - 1;
535a70aa7bbSRiver Riddle   if (stats.loopMap.count(forOp) > 0) {
536a70aa7bbSRiver Riddle     for (auto childForOp : stats.loopMap[forOp]) {
537825b23a7SUday Bondhugula       opCount += getComputeCostHelper(childForOp, stats, tripCountOverrideMap,
538825b23a7SUday Bondhugula                                       computeCostMap);
539a70aa7bbSRiver Riddle     }
540a70aa7bbSRiver Riddle   }
541a70aa7bbSRiver Riddle   // Add in additional op instances from slice (if specified in map).
542b4785cebSUday Bondhugula   if (computeCostMap) {
543a70aa7bbSRiver Riddle     auto it = computeCostMap->find(forOp);
544a70aa7bbSRiver Riddle     if (it != computeCostMap->end()) {
545a70aa7bbSRiver Riddle       opCount += it->second;
546a70aa7bbSRiver Riddle     }
547a70aa7bbSRiver Riddle   }
548a70aa7bbSRiver Riddle   // Override trip count (if specified in map).
549a70aa7bbSRiver Riddle   int64_t tripCount = stats.tripCountMap[forOp];
550b4785cebSUday Bondhugula   if (tripCountOverrideMap) {
551a70aa7bbSRiver Riddle     auto it = tripCountOverrideMap->find(forOp);
552a70aa7bbSRiver Riddle     if (it != tripCountOverrideMap->end()) {
553a70aa7bbSRiver Riddle       tripCount = it->second;
554a70aa7bbSRiver Riddle     }
555a70aa7bbSRiver Riddle   }
556a70aa7bbSRiver Riddle   // Returns the total number of dynamic instances of operations in loop body.
557a70aa7bbSRiver Riddle   return tripCount * opCount;
558a70aa7bbSRiver Riddle }
559a70aa7bbSRiver Riddle 
560a70aa7bbSRiver Riddle /// Computes the total cost of the loop nest rooted at 'forOp' using 'stats'.
561a70aa7bbSRiver Riddle /// Currently, the total cost is computed by counting the total operation
562a70aa7bbSRiver Riddle /// instance count (i.e. total number of operations in the loop body * loop
563a70aa7bbSRiver Riddle /// trip count) for the entire loop nest.
getComputeCost(AffineForOp forOp,LoopNestStats & stats)5644c48f016SMatthias Springer int64_t mlir::affine::getComputeCost(AffineForOp forOp, LoopNestStats &stats) {
565825b23a7SUday Bondhugula   return getComputeCostHelper(forOp, stats,
566a70aa7bbSRiver Riddle                               /*tripCountOverrideMap=*/nullptr,
567a70aa7bbSRiver Riddle                               /*computeCostMap=*/nullptr);
568a70aa7bbSRiver Riddle }
569a70aa7bbSRiver Riddle 
570a70aa7bbSRiver Riddle /// Computes and returns in 'computeCost', the total compute cost of fusing the
571a70aa7bbSRiver Riddle /// 'slice' of the loop nest rooted at 'srcForOp' into 'dstForOp'. Currently,
572a70aa7bbSRiver Riddle /// the total cost is computed by counting the total operation instance count
573a70aa7bbSRiver Riddle /// (i.e. total number of operations in the loop body * loop trip count) for
574a70aa7bbSRiver Riddle /// the entire loop nest.
getFusionComputeCost(AffineForOp srcForOp,LoopNestStats & srcStats,AffineForOp dstForOp,LoopNestStats & dstStats,const ComputationSliceState & slice,int64_t * computeCost)5754c48f016SMatthias Springer bool mlir::affine::getFusionComputeCost(AffineForOp srcForOp,
5764c48f016SMatthias Springer                                         LoopNestStats &srcStats,
5774c48f016SMatthias Springer                                         AffineForOp dstForOp,
5784c48f016SMatthias Springer                                         LoopNestStats &dstStats,
579a70aa7bbSRiver Riddle                                         const ComputationSliceState &slice,
580a70aa7bbSRiver Riddle                                         int64_t *computeCost) {
581a70aa7bbSRiver Riddle   llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
582a70aa7bbSRiver Riddle   DenseMap<Operation *, int64_t> computeCostMap;
583a70aa7bbSRiver Riddle 
584a70aa7bbSRiver Riddle   // Build trip count map for computation slice.
585a70aa7bbSRiver Riddle   if (!buildSliceTripCountMap(slice, &sliceTripCountMap))
586a70aa7bbSRiver Riddle     return false;
587a70aa7bbSRiver Riddle   // Checks whether a store to load forwarding will happen.
588a70aa7bbSRiver Riddle   int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap);
589a70aa7bbSRiver Riddle   assert(sliceIterationCount > 0);
590a70aa7bbSRiver Riddle   bool storeLoadFwdGuaranteed = (sliceIterationCount == 1);
591a70aa7bbSRiver Riddle   auto *insertPointParent = slice.insertPoint->getParentOp();
592a70aa7bbSRiver Riddle 
593a70aa7bbSRiver Riddle   // The store and loads to this memref will disappear.
594a70aa7bbSRiver Riddle   if (storeLoadFwdGuaranteed) {
595a70aa7bbSRiver Riddle     // Subtract from operation count the loads/store we expect load/store
596a70aa7bbSRiver Riddle     // forwarding to remove.
597a70aa7bbSRiver Riddle     unsigned storeCount = 0;
598a70aa7bbSRiver Riddle     llvm::SmallDenseSet<Value, 4> storeMemrefs;
599b4785cebSUday Bondhugula     srcForOp.walk([&](AffineWriteOpInterface storeOp) {
600a70aa7bbSRiver Riddle       storeMemrefs.insert(storeOp.getMemRef());
601a70aa7bbSRiver Riddle       ++storeCount;
602a70aa7bbSRiver Riddle     });
603a70aa7bbSRiver Riddle     // Subtract out any store ops in single-iteration src slice loop nest.
604a70aa7bbSRiver Riddle     if (storeCount > 0)
605a70aa7bbSRiver Riddle       computeCostMap[insertPointParent] = -storeCount;
606a70aa7bbSRiver Riddle     // Subtract out any load users of 'storeMemrefs' nested below
607a70aa7bbSRiver Riddle     // 'insertPointParent'.
608825b23a7SUday Bondhugula     for (Value memref : storeMemrefs) {
609b4785cebSUday Bondhugula       for (Operation *user : memref.getUsers()) {
610b4785cebSUday Bondhugula         if (!isa<AffineReadOpInterface>(user))
611b4785cebSUday Bondhugula           continue;
612a70aa7bbSRiver Riddle         SmallVector<AffineForOp, 4> loops;
613a70aa7bbSRiver Riddle         // Check if any loop in loop nest surrounding 'user' is
614a70aa7bbSRiver Riddle         // 'insertPointParent'.
61523bcd6b8SUday Bondhugula         getAffineForIVs(*user, &loops);
616a70aa7bbSRiver Riddle         if (llvm::is_contained(loops, cast<AffineForOp>(insertPointParent))) {
617b4785cebSUday Bondhugula           if (auto forOp = dyn_cast_or_null<AffineForOp>(user->getParentOp())) {
618a70aa7bbSRiver Riddle             if (computeCostMap.count(forOp) == 0)
619a70aa7bbSRiver Riddle               computeCostMap[forOp] = 0;
620a70aa7bbSRiver Riddle             computeCostMap[forOp] -= 1;
621a70aa7bbSRiver Riddle           }
622a70aa7bbSRiver Riddle         }
623a70aa7bbSRiver Riddle       }
624a70aa7bbSRiver Riddle     }
625a70aa7bbSRiver Riddle   }
626a70aa7bbSRiver Riddle 
627a70aa7bbSRiver Riddle   // Compute op instance count for the src loop nest with iteration slicing.
628a70aa7bbSRiver Riddle   int64_t sliceComputeCost = getComputeCostHelper(
629825b23a7SUday Bondhugula       srcForOp, srcStats, &sliceTripCountMap, &computeCostMap);
630a70aa7bbSRiver Riddle 
631a70aa7bbSRiver Riddle   // Compute cost of fusion for this depth.
632a70aa7bbSRiver Riddle   computeCostMap[insertPointParent] = sliceComputeCost;
633a70aa7bbSRiver Riddle 
634a70aa7bbSRiver Riddle   *computeCost =
635825b23a7SUday Bondhugula       getComputeCostHelper(dstForOp, dstStats,
636a70aa7bbSRiver Riddle                            /*tripCountOverrideMap=*/nullptr, &computeCostMap);
637a70aa7bbSRiver Riddle   return true;
638a70aa7bbSRiver Riddle }
639a70aa7bbSRiver Riddle 
640a70aa7bbSRiver Riddle /// Returns in 'producerConsumerMemrefs' the memrefs involved in a
641a70aa7bbSRiver Riddle /// producer-consumer dependence between write ops in 'srcOps' and read ops in
642a70aa7bbSRiver Riddle /// 'dstOps'.
gatherProducerConsumerMemrefs(ArrayRef<Operation * > srcOps,ArrayRef<Operation * > dstOps,DenseSet<Value> & producerConsumerMemrefs)6434c48f016SMatthias Springer void mlir::affine::gatherProducerConsumerMemrefs(
644a70aa7bbSRiver Riddle     ArrayRef<Operation *> srcOps, ArrayRef<Operation *> dstOps,
645a70aa7bbSRiver Riddle     DenseSet<Value> &producerConsumerMemrefs) {
646a70aa7bbSRiver Riddle   // Gather memrefs from stores in 'srcOps'.
647a70aa7bbSRiver Riddle   DenseSet<Value> srcStoreMemRefs;
648a70aa7bbSRiver Riddle   for (Operation *op : srcOps)
649a70aa7bbSRiver Riddle     if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op))
650a70aa7bbSRiver Riddle       srcStoreMemRefs.insert(storeOp.getMemRef());
651a70aa7bbSRiver Riddle 
652a70aa7bbSRiver Riddle   // Compute the intersection between memrefs from stores in 'srcOps' and
653a70aa7bbSRiver Riddle   // memrefs from loads in 'dstOps'.
654a70aa7bbSRiver Riddle   for (Operation *op : dstOps)
655a70aa7bbSRiver Riddle     if (auto loadOp = dyn_cast<AffineReadOpInterface>(op))
656a70aa7bbSRiver Riddle       if (srcStoreMemRefs.count(loadOp.getMemRef()) > 0)
657a70aa7bbSRiver Riddle         producerConsumerMemrefs.insert(loadOp.getMemRef());
658a70aa7bbSRiver Riddle }
659