xref: /llvm-project/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp (revision b4785cebfb0a756e19ff4ae3e41d2563f10cd292)
1 //===- LoopFusionUtils.cpp ---- Utilities for loop fusion ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop fusion transformation utility functions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/LoopFusionUtils.h"
14 #include "mlir/Analysis/SliceAnalysis.h"
15 #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
16 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
17 #include "mlir/Dialect/Affine/Analysis/Utils.h"
18 #include "mlir/Dialect/Affine/IR/AffineOps.h"
19 #include "mlir/Dialect/Affine/LoopUtils.h"
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/Operation.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "llvm/Support/Debug.h"
24 #include "llvm/Support/raw_ostream.h"
25 #include <optional>
26 
27 #define DEBUG_TYPE "loop-fusion-utils"
28 
29 using namespace mlir;
30 using namespace mlir::affine;
31 
32 // Gathers all load and store memref accesses in 'opA' into 'values', where
33 // 'values[memref] == true' for each store operation.
34 static void getLoadAndStoreMemRefAccesses(Operation *opA,
35                                           DenseMap<Value, bool> &values) {
36   opA->walk([&](Operation *op) {
37     if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) {
38       if (values.count(loadOp.getMemRef()) == 0)
39         values[loadOp.getMemRef()] = false;
40     } else if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) {
41       values[storeOp.getMemRef()] = true;
42     }
43   });
44 }
45 
46 /// Returns true if 'op' is a load or store operation which access a memref
47 /// accessed 'values' and at least one of the access is a store operation.
48 /// Returns false otherwise.
49 static bool isDependentLoadOrStoreOp(Operation *op,
50                                      DenseMap<Value, bool> &values) {
51   if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) {
52     return values.count(loadOp.getMemRef()) > 0 && values[loadOp.getMemRef()];
53   }
54   if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) {
55     return values.count(storeOp.getMemRef()) > 0;
56   }
57   return false;
58 }
59 
60 // Returns the first operation in range ('opA', 'opB') which has a data
61 // dependence on 'opA'. Returns 'nullptr' of no dependence exists.
62 static Operation *getFirstDependentOpInRange(Operation *opA, Operation *opB) {
63   // Record memref values from all loads/store in loop nest rooted at 'opA'.
64   // Map from memref value to bool which is true if store, false otherwise.
65   DenseMap<Value, bool> values;
66   getLoadAndStoreMemRefAccesses(opA, values);
67 
68   // For each 'opX' in block in range ('opA', 'opB'), check if there is a data
69   // dependence from 'opA' to 'opX' ('opA' and 'opX' access the same memref
70   // and at least one of the accesses is a store).
71   Operation *firstDepOp = nullptr;
72   for (Block::iterator it = std::next(Block::iterator(opA));
73        it != Block::iterator(opB); ++it) {
74     Operation *opX = &(*it);
75     opX->walk([&](Operation *op) {
76       if (!firstDepOp && isDependentLoadOrStoreOp(op, values))
77         firstDepOp = opX;
78     });
79     if (firstDepOp)
80       break;
81   }
82   return firstDepOp;
83 }
84 
// Returns the last operation 'opX' in range ('opA', 'opB'), for which there
// exists a data dependence from 'opX' to 'opB'.
// Returns 'nullptr' if no dependence exists.
static Operation *getLastDependentOpInRange(Operation *opA, Operation *opB) {
  // Record memref values from all loads/store in loop nest rooted at 'opB'.
  // Map from memref value to bool which is true if store, false otherwise.
  DenseMap<Value, bool> values;
  getLoadAndStoreMemRefAccesses(opB, values);

  // For each 'opX' in block in range ('opA', 'opB') in reverse order,
  // check if there is a data dependence from 'opX' to 'opB':
  // *) 'opX' and 'opB' access the same memref and at least one of the accesses
  //    is a store.
  // *) 'opX' produces an SSA Value which is used by 'opB'.
  Operation *lastDepOp = nullptr;
  for (Block::reverse_iterator it = std::next(Block::reverse_iterator(opB));
       it != Block::reverse_iterator(opA); ++it) {
    Operation *opX = &(*it);
    opX->walk([&](Operation *op) {
      // Memory dependence: a load/store nested under 'opX' conflicts with an
      // access recorded from 'opB'.
      if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) {
        if (isDependentLoadOrStoreOp(op, values)) {
          lastDepOp = opX;
          return WalkResult::interrupt();
        }
        return WalkResult::advance();
      }
      // SSA dependence: a result of 'op' is consumed by an operation nested
      // inside 'opB' ('opB' shows up among the user's surrounding loops).
      for (Value value : op->getResults()) {
        for (Operation *user : value.getUsers()) {
          SmallVector<AffineForOp, 4> loops;
          // Check if any loop in loop nest surrounding 'user' is 'opB'.
          getAffineForIVs(*user, &loops);
          if (llvm::is_contained(loops, cast<AffineForOp>(opB))) {
            lastDepOp = opX;
            return WalkResult::interrupt();
          }
        }
      }
      return WalkResult::advance();
    });
    // Iterating in reverse, so the first hit is the *last* dependent op.
    if (lastDepOp)
      break;
  }
  return lastDepOp;
}
129 
// Computes and returns an insertion point operation, before which the
// fused <srcForOp, dstForOp> loop nest can be inserted while preserving
// dependences. Returns nullptr if no such insertion point is found.
static Operation *getFusedLoopNestInsertionPoint(AffineForOp srcForOp,
                                                 AffineForOp dstForOp) {
  // Normalize so that 'forOpA' executes before 'forOpB' in the block,
  // regardless of which of src/dst comes first.
  bool isSrcForOpBeforeDstForOp = srcForOp->isBeforeInBlock(dstForOp);
  auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp;
  auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp;

  // 'firstDepOpA': first op after 'forOpA' that depends on 'forOpA'.
  // 'lastDepOpB': last op before 'forOpB' that 'forOpB' depends on.
  Operation *firstDepOpA = getFirstDependentOpInRange(forOpA, forOpB);
  Operation *lastDepOpB = getLastDependentOpInRange(forOpA, forOpB);
  // Block:
  //      ...
  //  |-- opA
  //  |   ...
  //  |   lastDepOpB --|
  //  |   ...          |
  //  |-> firstDepOpA  |
  //      ...          |
  //      opB <---------
  //
  // Valid insertion point range: (lastDepOpB, firstDepOpA)
  //
  if (firstDepOpA) {
    if (lastDepOpB) {
      if (firstDepOpA->isBeforeInBlock(lastDepOpB) || firstDepOpA == lastDepOpB)
        // No valid insertion point exists which preserves dependences.
        return nullptr;
    }
    // Return insertion point in valid range closest to 'opB'.
    // TODO: Consider other insertion points in valid range.
    return firstDepOpA;
  }
  // No dependences from 'opA' to operation in range ('opA', 'opB'), return
  // 'opB' insertion point.
  return forOpB;
}
167 
168 // Gathers all load and store ops in loop nest rooted at 'forOp' into
169 // 'loadAndStoreOps'.
170 static bool
171 gatherLoadsAndStores(AffineForOp forOp,
172                      SmallVectorImpl<Operation *> &loadAndStoreOps) {
173   bool hasIfOp = false;
174   forOp.walk([&](Operation *op) {
175     if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op))
176       loadAndStoreOps.push_back(op);
177     else if (isa<AffineIfOp>(op))
178       hasIfOp = true;
179   });
180   return !hasIfOp;
181 }
182 
/// Returns the maximum loop depth at which we could fuse producer loop
/// 'srcForOp' into consumer loop 'dstForOp' without violating data dependences.
// TODO: Generalize this check for sibling and more generic fusion scenarios.
// TODO: Support forward slice fusion.
static unsigned getMaxLoopDepth(ArrayRef<Operation *> srcOps,
                                ArrayRef<Operation *> dstOps) {
  if (dstOps.empty())
    // Expected at least one memory operation.
    // TODO: Revisit this case with a specific example.
    return 0;

  // Filter out ops in 'dstOps' that do not use the producer-consumer memref so
  // that they are not considered for analysis.
  DenseSet<Value> producerConsumerMemrefs;
  gatherProducerConsumerMemrefs(srcOps, dstOps, producerConsumerMemrefs);
  SmallVector<Operation *, 4> targetDstOps;
  for (Operation *dstOp : dstOps) {
    // Each 'dstOp' is either an affine read or an affine write op.
    auto loadOp = dyn_cast<AffineReadOpInterface>(dstOp);
    Value memref = loadOp ? loadOp.getMemRef()
                          : cast<AffineWriteOpInterface>(dstOp).getMemRef();
    if (producerConsumerMemrefs.count(memref) > 0)
      targetDstOps.push_back(dstOp);
  }

  assert(!targetDstOps.empty() &&
         "No dependences between 'srcForOp' and 'dstForOp'?");

  // Compute the innermost common loop depth for loads and stores.
  unsigned loopDepth = getInnermostCommonLoopDepth(targetDstOps);

  // Return common loop depth for loads if there are no store ops.
  if (all_of(targetDstOps,
             [&](Operation *op) { return isa<AffineReadOpInterface>(op); }))
    return loopDepth;

  // Check dependences on all pairs of ops in 'targetDstOps' and store the
  // minimum loop depth at which a dependence is satisfied.
  for (unsigned i = 0, e = targetDstOps.size(); i < e; ++i) {
    Operation *srcOpInst = targetDstOps[i];
    MemRefAccess srcAccess(srcOpInst);
    for (unsigned j = 0; j < e; ++j) {
      auto *dstOpInst = targetDstOps[j];
      MemRefAccess dstAccess(dstOpInst);

      unsigned numCommonLoops =
          getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst);
      // Depth 'numCommonLoops + 1' probes for a dependence not carried by
      // any common surrounding loop.
      for (unsigned d = 1; d <= numCommonLoops + 1; ++d) {
        // TODO: Cache dependence analysis results, check cache here.
        DependenceResult result =
            checkMemrefAccessDependence(srcAccess, dstAccess, d);
        if (hasDependence(result)) {
          // Store minimum loop depth and break because we want the min 'd' at
          // which there is a dependence.
          loopDepth = std::min(loopDepth, d - 1);
          break;
        }
      }
    }
  }

  return loopDepth;
}
245 
// TODO: This pass performs some computation that is the same for all the depths
// (e.g., getMaxLoopDepth). Implement a version of this utility that processes
// all the depths at once or only the legal maximal depth for maximal fusion.
FusionResult mlir::affine::canFuseLoops(AffineForOp srcForOp,
                                        AffineForOp dstForOp,
                                        unsigned dstLoopDepth,
                                        ComputationSliceState *srcSlice,
                                        FusionStrategy fusionStrategy) {
  // Return 'failure' if 'dstLoopDepth == 0'.
  if (dstLoopDepth == 0) {
    LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests at depth 0\n");
    return FusionResult::FailPrecondition;
  }
  // Return 'failure' if 'srcForOp' and 'dstForOp' are not in the same block.
  auto *block = srcForOp->getBlock();
  if (block != dstForOp->getBlock()) {
    LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests in different blocks\n");
    return FusionResult::FailPrecondition;
  }

  // Return 'failure' if no valid insertion point for fused loop nest in 'block'
  // exists which would preserve dependences.
  if (!getFusedLoopNestInsertionPoint(srcForOp, dstForOp)) {
    LLVM_DEBUG(llvm::dbgs() << "Fusion would violate dependences in block\n");
    return FusionResult::FailBlockDependence;
  }

  // Check if 'srcForOp' precedes 'dstForOp' in 'block'.
  bool isSrcForOpBeforeDstForOp = srcForOp->isBeforeInBlock(dstForOp);
  // 'forOpA' executes before 'forOpB' in 'block'.
  auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp;
  auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp;

  // Gather all load and store from 'forOpA' which precedes 'forOpB' in 'block'.
  SmallVector<Operation *, 4> opsA;
  if (!gatherLoadsAndStores(forOpA, opsA)) {
    LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n");
    return FusionResult::FailPrecondition;
  }

  // Gather all load and store from 'forOpB' which succeeds 'forOpA' in 'block'.
  SmallVector<Operation *, 4> opsB;
  if (!gatherLoadsAndStores(forOpB, opsB)) {
    LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n");
    return FusionResult::FailPrecondition;
  }

  // Return 'failure' if fusing loops at depth 'dstLoopDepth' wouldn't preserve
  // loop dependences.
  // TODO: Enable this check for sibling and more generic loop fusion
  // strategies.
  if (fusionStrategy.getStrategy() == FusionStrategy::ProducerConsumer) {
    // TODO: 'getMaxLoopDepth' does not support forward slice fusion.
    assert(isSrcForOpBeforeDstForOp && "Unexpected forward slice fusion");
    if (getMaxLoopDepth(opsA, opsB) < dstLoopDepth) {
      LLVM_DEBUG(llvm::dbgs() << "Fusion would violate loop dependences\n");
      return FusionResult::FailFusionDependence;
    }
  }

  // Calculate the number of common loops surrounding 'srcForOp' and 'dstForOp'.
  unsigned numCommonLoops =
      affine::getNumCommonSurroundingLoops(*srcForOp, *dstForOp);

  // Filter out ops in 'opsA' to compute the slice union based on the
  // assumptions made by the fusion strategy.
  SmallVector<Operation *, 4> strategyOpsA;
  switch (fusionStrategy.getStrategy()) {
  case FusionStrategy::Generic:
    // Generic fusion. Take into account all the memory operations to compute
    // the slice union.
    strategyOpsA.append(opsA.begin(), opsA.end());
    break;
  case FusionStrategy::ProducerConsumer:
    // Producer-consumer fusion (AffineLoopFusion pass) only takes into
    // account stores in 'srcForOp' to compute the slice union.
    for (Operation *op : opsA) {
      if (isa<AffineWriteOpInterface>(op))
        strategyOpsA.push_back(op);
    }
    break;
  case FusionStrategy::Sibling:
    // Sibling fusion (AffineLoopFusion pass) only takes into account the loads
    // to 'memref' in 'srcForOp' to compute the slice union.
    for (Operation *op : opsA) {
      auto load = dyn_cast<AffineReadOpInterface>(op);
      if (load && load.getMemRef() == fusionStrategy.getSiblingFusionMemRef())
        strategyOpsA.push_back(op);
    }
    break;
  }

  // Compute union of computation slices computed between all pairs of ops
  // from 'forOpA' and 'forOpB'.
  SliceComputationResult sliceComputationResult = affine::computeSliceUnion(
      strategyOpsA, opsB, dstLoopDepth, numCommonLoops,
      isSrcForOpBeforeDstForOp, srcSlice);
  if (sliceComputationResult.value == SliceComputationResult::GenericFailure) {
    LLVM_DEBUG(llvm::dbgs() << "computeSliceUnion failed\n");
    return FusionResult::FailPrecondition;
  }
  if (sliceComputationResult.value ==
      SliceComputationResult::IncorrectSliceFailure) {
    LLVM_DEBUG(llvm::dbgs() << "Incorrect slice computation\n");
    return FusionResult::FailIncorrectSlice;
  }

  return FusionResult::Success;
}
355 
/// Patch the loop body of a forOp that is a single iteration reduction loop
/// into its containing block. The parent loop absorbs 'forOp's iter args and
/// yields its results. Fails if 'forOp' is not a single-iteration loop or is
/// not directly nested inside another affine.for.
static LogicalResult promoteSingleIterReductionLoop(AffineForOp forOp,
                                                    bool siblingFusionUser) {
  // Check if the reduction loop is a single iteration loop.
  std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
  if (!tripCount || *tripCount != 1)
    return failure();
  // The parent must itself be an affine.for that can absorb the yields.
  auto *parentOp = forOp->getParentOp();
  if (!isa<AffineForOp>(parentOp))
    return failure();
  // The values yielded by 'forOp's terminator become extra yields of the
  // rebuilt parent loop.
  SmallVector<Value> newOperands;
  llvm::append_range(newOperands,
                     forOp.getBody()->getTerminator()->getOperands());
  IRRewriter rewriter(parentOp->getContext());
  // Number of results the parent had before absorbing 'forOp's results;
  // used below to index the appended results.
  int64_t parentOpNumResults = parentOp->getNumResults();
  // Replace the parent loop and add iteroperands and results from the `forOp`.
  AffineForOp parentForOp = forOp->getParentOfType<AffineForOp>();
  AffineForOp newLoop =
      cast<AffineForOp>(*parentForOp.replaceWithAdditionalYields(
          rewriter, forOp.getInits(), /*replaceInitOperandUsesInLoop=*/false,
          [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
            return newOperands;
          }));

  // For sibling-fusion users, collect operations that use the results of the
  // `forOp` outside the new parent loop that has absorbed all its iter args
  // and operands. These operations will be moved later after the results
  // have been replaced.
  SetVector<Operation *> forwardSlice;
  if (siblingFusionUser) {
    for (unsigned i = 0, e = forOp.getNumResults(); i != e; ++i) {
      SetVector<Operation *> tmpForwardSlice;
      getForwardSlice(forOp.getResult(i), &tmpForwardSlice);
      forwardSlice.set_union(tmpForwardSlice);
    }
  }
  // Update the results of the `forOp` in the new loop: 'forOp's i-th result
  // now lives at position 'parentOpNumResults + i' of the new parent loop.
  for (unsigned i = 0, e = forOp.getNumResults(); i != e; ++i) {
    forOp.getResult(i).replaceAllUsesWith(
        newLoop.getResult(i + parentOpNumResults));
  }
  // For sibling-fusion users, move operations that use the results of the
  // `forOp` outside the new parent loop
  if (siblingFusionUser) {
    topologicalSort(forwardSlice);
    for (Operation *op : llvm::reverse(forwardSlice))
      op->moveAfter(newLoop);
  }
  // Replace the induction variable.
  auto iv = forOp.getInductionVar();
  iv.replaceAllUsesWith(newLoop.getInductionVar());
  // Replace the iter args with the tail of the new loop's region iter args
  // (the ones appended by replaceWithAdditionalYields above).
  auto forOpIterArgs = forOp.getRegionIterArgs();
  for (auto it : llvm::zip(forOpIterArgs, newLoop.getRegionIterArgs().take_back(
                                              forOpIterArgs.size()))) {
    std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
  }
  // Move the loop body operations, except for its terminator, to the loop's
  // containing block.
  forOp.getBody()->back().erase();
  auto *parentBlock = forOp->getBlock();
  parentBlock->getOperations().splice(Block::iterator(forOp),
                                      forOp.getBody()->getOperations());
  forOp.erase();
  return success();
}
423 
/// Fuses 'srcForOp' into 'dstForOp' with destination loop block insertion point
/// and source slice loop bounds specified in 'srcSlice'. 'srcForOp' itself is
/// left in place; a clone with slice-adjusted bounds is inserted.
void mlir::affine::fuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
                             const ComputationSliceState &srcSlice,
                             bool isInnermostSiblingInsertion) {
  // Clone 'srcForOp' into 'dstForOp' at 'srcSlice->insertPoint'.
  OpBuilder b(srcSlice.insertPoint->getBlock(), srcSlice.insertPoint);
  IRMapping mapper;
  b.clone(*srcForOp, mapper);

  // Update 'sliceLoopNest' upper and lower bounds from computed 'srcSlice'.
  SmallVector<AffineForOp, 4> sliceLoops;
  for (unsigned i = 0, e = srcSlice.ivs.size(); i < e; ++i) {
    // Resolve the cloned counterpart of the i-th source IV; skip IVs that
    // were not mapped by the clone.
    auto loopIV = mapper.lookupOrNull(srcSlice.ivs[i]);
    if (!loopIV)
      continue;
    auto forOp = getForInductionVarOwner(loopIV);
    sliceLoops.push_back(forOp);
    // Apply the slice's lower/upper bounds where present (a null map means
    // the original bound is kept).
    if (AffineMap lbMap = srcSlice.lbs[i]) {
      auto lbOperands = srcSlice.lbOperands[i];
      canonicalizeMapAndOperands(&lbMap, &lbOperands);
      forOp.setLowerBound(lbOperands, lbMap);
    }
    if (AffineMap ubMap = srcSlice.ubs[i]) {
      auto ubOperands = srcSlice.ubOperands[i];
      canonicalizeMapAndOperands(&ubMap, &ubOperands);
      forOp.setUpperBound(ubOperands, ubMap);
    }
  }

  llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
  // True iff the whole slice executes exactly one iteration.
  auto srcIsUnitSlice = [&]() {
    return (buildSliceTripCountMap(srcSlice, &sliceTripCountMap) &&
            (getSliceIterationCount(sliceTripCountMap) == 1));
  };
  // Fix up and if possible, eliminate single iteration loops.
  for (AffineForOp forOp : sliceLoops) {
    if (isLoopParallelAndContainsReduction(forOp) &&
        isInnermostSiblingInsertion && srcIsUnitSlice())
      // Patch reduction loop - only ones that are sibling-fused with the
      // destination loop - into the parent loop.
      (void)promoteSingleIterReductionLoop(forOp, true);
    else
      // Promote any single iteration slice loops.
      (void)promoteIfSingleIteration(forOp);
  }
}
471 
472 /// Collect loop nest statistics (eg. loop trip count and operation count)
473 /// in 'stats' for loop nest rooted at 'forOp'. Returns true on success,
474 /// returns false otherwise.
475 bool mlir::affine::getLoopNestStats(AffineForOp forOpRoot,
476                                     LoopNestStats *stats) {
477   auto walkResult = forOpRoot.walk([&](AffineForOp forOp) {
478     auto *childForOp = forOp.getOperation();
479     auto *parentForOp = forOp->getParentOp();
480     if (forOp != forOpRoot) {
481       if (!isa<AffineForOp>(parentForOp)) {
482         LLVM_DEBUG(llvm::dbgs() << "Expected parent AffineForOp\n");
483         return WalkResult::interrupt();
484       }
485       // Add mapping to 'forOp' from its parent AffineForOp.
486       stats->loopMap[parentForOp].push_back(forOp);
487     }
488 
489     // Record the number of op operations in the body of 'forOp'.
490     unsigned count = 0;
491     stats->opCountMap[childForOp] = 0;
492     for (auto &op : *forOp.getBody()) {
493       if (!isa<AffineForOp, AffineIfOp>(op))
494         ++count;
495     }
496     stats->opCountMap[childForOp] = count;
497 
498     // Record trip count for 'forOp'. Set flag if trip count is not
499     // constant.
500     std::optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
501     if (!maybeConstTripCount) {
502       // Currently only constant trip count loop nests are supported.
503       LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count unsupported\n");
504       return WalkResult::interrupt();
505     }
506 
507     stats->tripCountMap[childForOp] = *maybeConstTripCount;
508     return WalkResult::advance();
509   });
510   return !walkResult.wasInterrupted();
511 }
512 
513 // Computes the total cost of the loop nest rooted at 'forOp'.
514 // Currently, the total cost is computed by counting the total operation
515 // instance count (i.e. total number of operations in the loop bodyloop
516 // operation count * loop trip count) for the entire loop nest.
517 // If 'tripCountOverrideMap' is non-null, overrides the trip count for loops
518 // specified in the map when computing the total op instance count.
519 // NOTEs: 1) This is used to compute the cost of computation slices, which are
520 // sliced along the iteration dimension, and thus reduce the trip count.
521 // If 'computeCostMap' is non-null, the total op count for forOps specified
522 // in the map is increased (not overridden) by adding the op count from the
523 // map to the existing op count for the for loop. This is done before
524 // multiplying by the loop's trip count, and is used to model the cost of
525 // inserting a sliced loop nest of known cost into the loop's body.
526 // 2) This is also used to compute the cost of fusing a slice of some loop nest
527 // within another loop.
528 static int64_t getComputeCostHelper(
529     Operation *forOp, LoopNestStats &stats,
530     llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountOverrideMap,
531     DenseMap<Operation *, int64_t> *computeCostMap) {
532   // 'opCount' is the total number operations in one iteration of 'forOp' body,
533   // minus terminator op which is a no-op.
534   int64_t opCount = stats.opCountMap[forOp] - 1;
535   if (stats.loopMap.count(forOp) > 0) {
536     for (auto childForOp : stats.loopMap[forOp]) {
537       opCount += getComputeCostHelper(childForOp, stats, tripCountOverrideMap,
538                                       computeCostMap);
539     }
540   }
541   // Add in additional op instances from slice (if specified in map).
542   if (computeCostMap) {
543     auto it = computeCostMap->find(forOp);
544     if (it != computeCostMap->end()) {
545       opCount += it->second;
546     }
547   }
548   // Override trip count (if specified in map).
549   int64_t tripCount = stats.tripCountMap[forOp];
550   if (tripCountOverrideMap) {
551     auto it = tripCountOverrideMap->find(forOp);
552     if (it != tripCountOverrideMap->end()) {
553       tripCount = it->second;
554     }
555   }
556   // Returns the total number of dynamic instances of operations in loop body.
557   return tripCount * opCount;
558 }
559 
560 /// Computes the total cost of the loop nest rooted at 'forOp' using 'stats'.
561 /// Currently, the total cost is computed by counting the total operation
562 /// instance count (i.e. total number of operations in the loop body * loop
563 /// trip count) for the entire loop nest.
564 int64_t mlir::affine::getComputeCost(AffineForOp forOp, LoopNestStats &stats) {
565   return getComputeCostHelper(forOp, stats,
566                               /*tripCountOverrideMap=*/nullptr,
567                               /*computeCostMap=*/nullptr);
568 }
569 
/// Computes and returns in 'computeCost', the total compute cost of fusing the
/// 'slice' of the loop nest rooted at 'srcForOp' into 'dstForOp'. Currently,
/// the total cost is computed by counting the total operation instance count
/// (i.e. total number of operations in the loop body * loop trip count) for
/// the entire loop nest. Returns false if the slice trip count map could not
/// be built, true otherwise.
bool mlir::affine::getFusionComputeCost(AffineForOp srcForOp,
                                        LoopNestStats &srcStats,
                                        AffineForOp dstForOp,
                                        LoopNestStats &dstStats,
                                        const ComputationSliceState &slice,
                                        int64_t *computeCost) {
  llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
  DenseMap<Operation *, int64_t> computeCostMap;

  // Build trip count map for computation slice.
  if (!buildSliceTripCountMap(slice, &sliceTripCountMap))
    return false;
  // Checks whether a store to load forwarding will happen.
  int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap);
  assert(sliceIterationCount > 0);
  // A single-iteration slice guarantees the stored value can be forwarded
  // directly to its loads, eliminating those memory ops.
  bool storeLoadFwdGuaranteed = (sliceIterationCount == 1);
  auto *insertPointParent = slice.insertPoint->getParentOp();

  // The store and loads to this memref will disappear.
  if (storeLoadFwdGuaranteed) {
    // Subtract from operation count the loads/store we expect load/store
    // forwarding to remove.
    unsigned storeCount = 0;
    llvm::SmallDenseSet<Value, 4> storeMemrefs;
    srcForOp.walk([&](AffineWriteOpInterface storeOp) {
      storeMemrefs.insert(storeOp.getMemRef());
      ++storeCount;
    });
    // Subtract out any store ops in single-iteration src slice loop nest.
    if (storeCount > 0)
      computeCostMap[insertPointParent] = -storeCount;
    // Subtract out any load users of 'storeMemrefs' nested below
    // 'insertPointParent'.
    for (Value memref : storeMemrefs) {
      for (Operation *user : memref.getUsers()) {
        if (!isa<AffineReadOpInterface>(user))
          continue;
        SmallVector<AffineForOp, 4> loops;
        // Check if any loop in loop nest surrounding 'user' is
        // 'insertPointParent'.
        getAffineForIVs(*user, &loops);
        if (llvm::is_contained(loops, cast<AffineForOp>(insertPointParent))) {
          // Charge the elimination of this load to its immediately
          // enclosing loop.
          if (auto forOp = dyn_cast_or_null<AffineForOp>(user->getParentOp())) {
            if (computeCostMap.count(forOp) == 0)
              computeCostMap[forOp] = 0;
            computeCostMap[forOp] -= 1;
          }
        }
      }
    }
  }

  // Compute op instance count for the src loop nest with iteration slicing.
  int64_t sliceComputeCost = getComputeCostHelper(
      srcForOp, srcStats, &sliceTripCountMap, &computeCostMap);

  // Compute cost of fusion for this depth: the slice cost is added to the
  // loop that will contain the inserted slice. Note this overwrites the
  // negative store adjustment recorded above, matching existing behavior.
  computeCostMap[insertPointParent] = sliceComputeCost;

  *computeCost =
      getComputeCostHelper(dstForOp, dstStats,
                           /*tripCountOverrideMap=*/nullptr, &computeCostMap);
  return true;
}
639 
640 /// Returns in 'producerConsumerMemrefs' the memrefs involved in a
641 /// producer-consumer dependence between write ops in 'srcOps' and read ops in
642 /// 'dstOps'.
643 void mlir::affine::gatherProducerConsumerMemrefs(
644     ArrayRef<Operation *> srcOps, ArrayRef<Operation *> dstOps,
645     DenseSet<Value> &producerConsumerMemrefs) {
646   // Gather memrefs from stores in 'srcOps'.
647   DenseSet<Value> srcStoreMemRefs;
648   for (Operation *op : srcOps)
649     if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op))
650       srcStoreMemRefs.insert(storeOp.getMemRef());
651 
652   // Compute the intersection between memrefs from stores in 'srcOps' and
653   // memrefs from loads in 'dstOps'.
654   for (Operation *op : dstOps)
655     if (auto loadOp = dyn_cast<AffineReadOpInterface>(op))
656       if (srcStoreMemRefs.count(loadOp.getMemRef()) > 0)
657         producerConsumerMemrefs.insert(loadOp.getMemRef());
658 }
659