xref: /llvm-project/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp (revision 23bcd6b86271f1c219a69183a5d90654faca64b8)
1 //===- LoopFusionUtils.cpp ---- Utilities for loop fusion ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop fusion transformation utility functions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/LoopFusionUtils.h"
14 #include "mlir/Analysis/SliceAnalysis.h"
15 #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
16 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
17 #include "mlir/Dialect/Affine/Analysis/Utils.h"
18 #include "mlir/Dialect/Affine/IR/AffineOps.h"
19 #include "mlir/Dialect/Affine/LoopUtils.h"
20 #include "mlir/IR/BlockAndValueMapping.h"
21 #include "mlir/IR/Operation.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/raw_ostream.h"
24 
25 #define DEBUG_TYPE "loop-fusion-utils"
26 
27 using namespace mlir;
28 
29 // Gathers all load and store memref accesses in 'opA' into 'values', where
30 // 'values[memref] == true' for each store operation.
31 static void getLoadAndStoreMemRefAccesses(Operation *opA,
32                                           DenseMap<Value, bool> &values) {
33   opA->walk([&](Operation *op) {
34     if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) {
35       if (values.count(loadOp.getMemRef()) == 0)
36         values[loadOp.getMemRef()] = false;
37     } else if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) {
38       values[storeOp.getMemRef()] = true;
39     }
40   });
41 }
42 
43 /// Returns true if 'op' is a load or store operation which access a memref
44 /// accessed 'values' and at least one of the access is a store operation.
45 /// Returns false otherwise.
46 static bool isDependentLoadOrStoreOp(Operation *op,
47                                      DenseMap<Value, bool> &values) {
48   if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) {
49     return values.count(loadOp.getMemRef()) > 0 && values[loadOp.getMemRef()];
50   }
51   if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) {
52     return values.count(storeOp.getMemRef()) > 0;
53   }
54   return false;
55 }
56 
57 // Returns the first operation in range ('opA', 'opB') which has a data
58 // dependence on 'opA'. Returns 'nullptr' of no dependence exists.
59 static Operation *getFirstDependentOpInRange(Operation *opA, Operation *opB) {
60   // Record memref values from all loads/store in loop nest rooted at 'opA'.
61   // Map from memref value to bool which is true if store, false otherwise.
62   DenseMap<Value, bool> values;
63   getLoadAndStoreMemRefAccesses(opA, values);
64 
65   // For each 'opX' in block in range ('opA', 'opB'), check if there is a data
66   // dependence from 'opA' to 'opX' ('opA' and 'opX' access the same memref
67   // and at least one of the accesses is a store).
68   Operation *firstDepOp = nullptr;
69   for (Block::iterator it = std::next(Block::iterator(opA));
70        it != Block::iterator(opB); ++it) {
71     Operation *opX = &(*it);
72     opX->walk([&](Operation *op) {
73       if (!firstDepOp && isDependentLoadOrStoreOp(op, values))
74         firstDepOp = opX;
75     });
76     if (firstDepOp)
77       break;
78   }
79   return firstDepOp;
80 }
81 
// Returns the last operation 'opX' in range ('opA', 'opB'), for which there
// exists a data dependence from 'opX' to 'opB'.
// Returns 'nullptr' if no dependence exists.
static Operation *getLastDependentOpInRange(Operation *opA, Operation *opB) {
  // Record memref values from all loads/store in loop nest rooted at 'opB'.
  // Map from memref value to bool which is true if store, false otherwise.
  DenseMap<Value, bool> values;
  getLoadAndStoreMemRefAccesses(opB, values);

  // For each 'opX' in block in range ('opA', 'opB') in reverse order,
  // check if there is a data dependence from 'opX' to 'opB':
  // *) 'opX' and 'opB' access the same memref and at least one of the accesses
  //    is a store.
  // *) 'opX' produces an SSA Value which is used by 'opB'.
  Operation *lastDepOp = nullptr;
  for (Block::reverse_iterator it = std::next(Block::reverse_iterator(opB));
       it != Block::reverse_iterator(opA); ++it) {
    Operation *opX = &(*it);
    // Walk the nest rooted at 'opX'; interrupt as soon as either kind of
    // dependence on 'opB' is established.
    opX->walk([&](Operation *op) {
      if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) {
        // Memory dependence: 'op' accesses a memref also accessed in 'opB'
        // with at least one store involved.
        if (isDependentLoadOrStoreOp(op, values)) {
          lastDepOp = opX;
          return WalkResult::interrupt();
        }
        return WalkResult::advance();
      }
      // SSA dependence: a result of 'op' is used by an operation nested
      // inside the loop nest rooted at 'opB'.
      for (Value value : op->getResults()) {
        for (Operation *user : value.getUsers()) {
          SmallVector<AffineForOp, 4> loops;
          // Check if any loop in loop nest surrounding 'user' is 'opB'.
          getAffineForIVs(*user, &loops);
          if (llvm::is_contained(loops, cast<AffineForOp>(opB))) {
            lastDepOp = opX;
            return WalkResult::interrupt();
          }
        }
      }
      return WalkResult::advance();
    });
    if (lastDepOp)
      break;
  }
  return lastDepOp;
}
126 
127 // Computes and returns an insertion point operation, before which the
128 // the fused <srcForOp, dstForOp> loop nest can be inserted while preserving
129 // dependences. Returns nullptr if no such insertion point is found.
130 static Operation *getFusedLoopNestInsertionPoint(AffineForOp srcForOp,
131                                                  AffineForOp dstForOp) {
132   bool isSrcForOpBeforeDstForOp = srcForOp->isBeforeInBlock(dstForOp);
133   auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp;
134   auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp;
135 
136   Operation *firstDepOpA = getFirstDependentOpInRange(forOpA, forOpB);
137   Operation *lastDepOpB = getLastDependentOpInRange(forOpA, forOpB);
138   // Block:
139   //      ...
140   //  |-- opA
141   //  |   ...
142   //  |   lastDepOpB --|
143   //  |   ...          |
144   //  |-> firstDepOpA  |
145   //      ...          |
146   //      opB <---------
147   //
148   // Valid insertion point range: (lastDepOpB, firstDepOpA)
149   //
150   if (firstDepOpA != nullptr) {
151     if (lastDepOpB != nullptr) {
152       if (firstDepOpA->isBeforeInBlock(lastDepOpB) || firstDepOpA == lastDepOpB)
153         // No valid insertion point exists which preserves dependences.
154         return nullptr;
155     }
156     // Return insertion point in valid range closest to 'opB'.
157     // TODO: Consider other insertion points in valid range.
158     return firstDepOpA;
159   }
160   // No dependences from 'opA' to operation in range ('opA', 'opB'), return
161   // 'opB' insertion point.
162   return forOpB;
163 }
164 
165 // Gathers all load and store ops in loop nest rooted at 'forOp' into
166 // 'loadAndStoreOps'.
167 static bool
168 gatherLoadsAndStores(AffineForOp forOp,
169                      SmallVectorImpl<Operation *> &loadAndStoreOps) {
170   bool hasIfOp = false;
171   forOp.walk([&](Operation *op) {
172     if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op))
173       loadAndStoreOps.push_back(op);
174     else if (isa<AffineIfOp>(op))
175       hasIfOp = true;
176   });
177   return !hasIfOp;
178 }
179 
/// Returns the maximum loop depth at which we could fuse producer loop
/// 'srcForOp' into consumer loop 'dstForOp' without violating data dependences.
// TODO: Generalize this check for sibling and more generic fusion scenarios.
// TODO: Support forward slice fusion.
static unsigned getMaxLoopDepth(ArrayRef<Operation *> srcOps,
                                ArrayRef<Operation *> dstOps) {
  if (dstOps.empty())
    // Expected at least one memory operation.
    // TODO: Revisit this case with a specific example.
    return 0;

  // Filter out ops in 'dstOps' that do not use the producer-consumer memref so
  // that they are not considered for analysis.
  DenseSet<Value> producerConsumerMemrefs;
  gatherProducerConsumerMemrefs(srcOps, dstOps, producerConsumerMemrefs);
  SmallVector<Operation *, 4> targetDstOps;
  for (Operation *dstOp : dstOps) {
    // 'dstOp' is known to be a load or a store; extract its memref either way.
    auto loadOp = dyn_cast<AffineReadOpInterface>(dstOp);
    Value memref = loadOp ? loadOp.getMemRef()
                          : cast<AffineWriteOpInterface>(dstOp).getMemRef();
    if (producerConsumerMemrefs.count(memref) > 0)
      targetDstOps.push_back(dstOp);
  }

  assert(!targetDstOps.empty() &&
         "No dependences between 'srcForOp' and 'dstForOp'?");

  // Compute the innermost common loop depth for loads and stores.
  unsigned loopDepth = getInnermostCommonLoopDepth(targetDstOps);

  // Return common loop depth for loads if there are no store ops.
  if (all_of(targetDstOps,
             [&](Operation *op) { return isa<AffineReadOpInterface>(op); }))
    return loopDepth;

  // Check dependences on all pairs of ops in 'targetDstOps' and store the
  // minimum loop depth at which a dependence is satisfied.
  for (unsigned i = 0, e = targetDstOps.size(); i < e; ++i) {
    auto *srcOpInst = targetDstOps[i];
    MemRefAccess srcAccess(srcOpInst);
    for (unsigned j = 0; j < e; ++j) {
      auto *dstOpInst = targetDstOps[j];
      MemRefAccess dstAccess(dstOpInst);

      unsigned numCommonLoops =
          getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst);
      // Probe increasing dependence depths up to numCommonLoops + 1: depth
      // numCommonLoops + 1 covers dependences at the block level, below all
      // common loops.
      for (unsigned d = 1; d <= numCommonLoops + 1; ++d) {
        FlatAffineValueConstraints dependenceConstraints;
        // TODO: Cache dependence analysis results, check cache here.
        DependenceResult result = checkMemrefAccessDependence(
            srcAccess, dstAccess, d, &dependenceConstraints,
            /*dependenceComponents=*/nullptr);
        if (hasDependence(result)) {
          // Store minimum loop depth and break because we want the min 'd' at
          // which there is a dependence.
          loopDepth = std::min(loopDepth, d - 1);
          break;
        }
      }
    }
  }

  return loopDepth;
}
244 
// TODO: Prevent fusion of loop nests with side-effecting operations.
// TODO: This pass performs some computation that is the same for all the depths
// (e.g., getMaxLoopDepth). Implement a version of this utility that processes
// all the depths at once or only the legal maximal depth for maximal fusion.
/// Checks whether 'srcForOp' can be fused into 'dstForOp' at 'dstLoopDepth'
/// under the given fusion strategy, computing the source computation slice
/// into 'srcSlice' on success. Returns a FusionResult describing success or
/// the specific failure reason.
FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
                                unsigned dstLoopDepth,
                                ComputationSliceState *srcSlice,
                                FusionStrategy fusionStrategy) {
  // Return 'failure' if 'dstLoopDepth == 0'.
  if (dstLoopDepth == 0) {
    LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests at depth 0\n");
    return FusionResult::FailPrecondition;
  }
  // Return 'failure' if 'srcForOp' and 'dstForOp' are not in the same block.
  auto *block = srcForOp->getBlock();
  if (block != dstForOp->getBlock()) {
    LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests in different blocks\n");
    return FusionResult::FailPrecondition;
  }

  // Return 'failure' if no valid insertion point for fused loop nest in 'block'
  // exists which would preserve dependences.
  if (!getFusedLoopNestInsertionPoint(srcForOp, dstForOp)) {
    LLVM_DEBUG(llvm::dbgs() << "Fusion would violate dependences in block\n");
    return FusionResult::FailBlockDependence;
  }

  // Check if 'srcForOp' precedes 'dstForOp' in 'block'.
  bool isSrcForOpBeforeDstForOp = srcForOp->isBeforeInBlock(dstForOp);
  // 'forOpA' executes before 'forOpB' in 'block'.
  auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp;
  auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp;

  // Gather all load and store from 'forOpA' which precedes 'forOpB' in 'block'.
  // gatherLoadsAndStores returns false if the nest contains an affine.if.
  SmallVector<Operation *, 4> opsA;
  if (!gatherLoadsAndStores(forOpA, opsA)) {
    LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n");
    return FusionResult::FailPrecondition;
  }

  // Gather all load and store from 'forOpB' which succeeds 'forOpA' in 'block'.
  SmallVector<Operation *, 4> opsB;
  if (!gatherLoadsAndStores(forOpB, opsB)) {
    LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n");
    return FusionResult::FailPrecondition;
  }

  // Return 'failure' if fusing loops at depth 'dstLoopDepth' wouldn't preserve
  // loop dependences.
  // TODO: Enable this check for sibling and more generic loop fusion
  // strategies.
  if (fusionStrategy.getStrategy() == FusionStrategy::ProducerConsumer) {
    // TODO: 'getMaxLoopDepth' does not support forward slice fusion.
    assert(isSrcForOpBeforeDstForOp && "Unexpected forward slice fusion");
    if (getMaxLoopDepth(opsA, opsB) < dstLoopDepth) {
      LLVM_DEBUG(llvm::dbgs() << "Fusion would violate loop dependences\n");
      return FusionResult::FailFusionDependence;
    }
  }

  // Calculate the number of common loops surrounding 'srcForOp' and 'dstForOp'.
  unsigned numCommonLoops =
      mlir::getNumCommonSurroundingLoops(*srcForOp, *dstForOp);

  // Filter out ops in 'opsA' to compute the slice union based on the
  // assumptions made by the fusion strategy.
  SmallVector<Operation *, 4> strategyOpsA;
  switch (fusionStrategy.getStrategy()) {
  case FusionStrategy::Generic:
    // Generic fusion. Take into account all the memory operations to compute
    // the slice union.
    strategyOpsA.append(opsA.begin(), opsA.end());
    break;
  case FusionStrategy::ProducerConsumer:
    // Producer-consumer fusion (AffineLoopFusion pass) only takes into
    // account stores in 'srcForOp' to compute the slice union.
    for (Operation *op : opsA) {
      if (isa<AffineWriteOpInterface>(op))
        strategyOpsA.push_back(op);
    }
    break;
  case FusionStrategy::Sibling:
    // Sibling fusion (AffineLoopFusion pass) only takes into account the loads
    // to 'memref' in 'srcForOp' to compute the slice union.
    for (Operation *op : opsA) {
      auto load = dyn_cast<AffineReadOpInterface>(op);
      if (load && load.getMemRef() == fusionStrategy.getSiblingFusionMemRef())
        strategyOpsA.push_back(op);
    }
    break;
  }

  // Compute union of computation slices computed between all pairs of ops
  // from 'forOpA' and 'forOpB'.
  SliceComputationResult sliceComputationResult =
      mlir::computeSliceUnion(strategyOpsA, opsB, dstLoopDepth, numCommonLoops,
                              isSrcForOpBeforeDstForOp, srcSlice);
  if (sliceComputationResult.value == SliceComputationResult::GenericFailure) {
    LLVM_DEBUG(llvm::dbgs() << "computeSliceUnion failed\n");
    return FusionResult::FailPrecondition;
  }
  if (sliceComputationResult.value ==
      SliceComputationResult::IncorrectSliceFailure) {
    LLVM_DEBUG(llvm::dbgs() << "Incorrect slice computation\n");
    return FusionResult::FailIncorrectSlice;
  }

  return FusionResult::Success;
}
354 
/// Patch the loop body of a forOp that is a single iteration reduction loop
/// into its containing block. The parent loop (which must be an AffineForOp)
/// is replaced by a new loop that absorbs 'forOp's iter operands and results.
/// Returns failure if 'forOp' is not a single-iteration loop or its parent is
/// not an AffineForOp.
LogicalResult promoteSingleIterReductionLoop(AffineForOp forOp,
                                             bool siblingFusionUser) {
  // Check if the reduction loop is a single iteration loop.
  Optional<uint64_t> tripCount = getConstantTripCount(forOp);
  if (!tripCount || *tripCount != 1)
    return failure();
  auto iterOperands = forOp.getIterOperands();
  auto *parentOp = forOp->getParentOp();
  // Only a parent affine.for can absorb the iter args of 'forOp'.
  if (!isa<AffineForOp>(parentOp))
    return failure();
  // The yielded values of 'forOp' become additional yields of the new parent.
  auto newOperands = forOp.getBody()->getTerminator()->getOperands();
  OpBuilder b(parentOp);
  // Replace the parent loop and add iteroperands and results from the `forOp`.
  AffineForOp parentForOp = forOp->getParentOfType<AffineForOp>();
  AffineForOp newLoop = replaceForOpWithNewYields(
      b, parentForOp, iterOperands, newOperands, forOp.getRegionIterArgs());

  // For sibling-fusion users, collect operations that use the results of the
  // `forOp` outside the new parent loop that has absorbed all its iter args
  // and operands. These operations will be moved later after the results
  // have been replaced.
  SetVector<Operation *> forwardSlice;
  if (siblingFusionUser) {
    for (unsigned i = 0, e = forOp.getNumResults(); i != e; ++i) {
      SetVector<Operation *> tmpForwardSlice;
      getForwardSlice(forOp.getResult(i), &tmpForwardSlice);
      forwardSlice.set_union(tmpForwardSlice);
    }
  }
  // Update the results of the `forOp` in the new loop: the absorbed results
  // are appended after the parent's original results.
  for (unsigned i = 0, e = forOp.getNumResults(); i != e; ++i) {
    forOp.getResult(i).replaceAllUsesWith(
        newLoop.getResult(i + parentOp->getNumResults()));
  }
  // For sibling-fusion users, move operations that use the results of the
  // `forOp` outside the new parent loop
  if (siblingFusionUser) {
    topologicalSort(forwardSlice);
    // Move in reverse so each op ends up after 'newLoop' in topological order.
    for (Operation *op : llvm::reverse(forwardSlice))
      op->moveAfter(newLoop);
  }
  // Replace the induction variable.
  auto iv = forOp.getInductionVar();
  iv.replaceAllUsesWith(newLoop.getInductionVar());
  // Replace the iter args: the absorbed iter args are the trailing region
  // iter args of the new loop.
  auto forOpIterArgs = forOp.getRegionIterArgs();
  for (auto it : llvm::zip(forOpIterArgs, newLoop.getRegionIterArgs().take_back(
                                              forOpIterArgs.size()))) {
    std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
  }
  // Move the loop body operations, except for its terminator, to the loop's
  // containing block.
  forOp.getBody()->back().erase();
  auto *parentBlock = forOp->getBlock();
  parentBlock->getOperations().splice(Block::iterator(forOp),
                                      forOp.getBody()->getOperations());
  // Erase the now-empty single-iteration loop and the replaced parent loop.
  forOp.erase();
  parentForOp.erase();
  return success();
}
417 
/// Fuses 'srcForOp' into 'dstForOp' with destination loop block insertion point
/// and source slice loop bounds specified in 'srcSlice'.
void mlir::fuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
                     const ComputationSliceState &srcSlice,
                     bool isInnermostSiblingInsertion) {
  // Clone 'srcForOp' into 'dstForOp' at 'srcSlice->insertPoint'.
  OpBuilder b(srcSlice.insertPoint->getBlock(), srcSlice.insertPoint);
  BlockAndValueMapping mapper;
  b.clone(*srcForOp, mapper);

  // Update 'sliceLoopNest' upper and lower bounds from computed 'srcSlice'.
  SmallVector<AffineForOp, 4> sliceLoops;
  for (unsigned i = 0, e = srcSlice.ivs.size(); i < e; ++i) {
    // Map the original slice IV to its clone; a null lookup means the IV was
    // not cloned and can be skipped.
    auto loopIV = mapper.lookupOrNull(srcSlice.ivs[i]);
    if (!loopIV)
      continue;
    auto forOp = getForInductionVarOwner(loopIV);
    sliceLoops.push_back(forOp);
    // Install the slice's lower/upper bounds on the cloned loop, if present.
    if (AffineMap lbMap = srcSlice.lbs[i]) {
      auto lbOperands = srcSlice.lbOperands[i];
      canonicalizeMapAndOperands(&lbMap, &lbOperands);
      forOp.setLowerBound(lbOperands, lbMap);
    }
    if (AffineMap ubMap = srcSlice.ubs[i]) {
      auto ubOperands = srcSlice.ubOperands[i];
      canonicalizeMapAndOperands(&ubMap, &ubOperands);
      forOp.setUpperBound(ubOperands, ubMap);
    }
  }

  llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
  // Lazily evaluated: true if the source slice executes exactly once overall.
  auto srcIsUnitSlice = [&]() {
    return (buildSliceTripCountMap(srcSlice, &sliceTripCountMap) &&
            (getSliceIterationCount(sliceTripCountMap) == 1));
  };
  // Fix up and if possible, eliminate single iteration loops.
  for (AffineForOp forOp : sliceLoops) {
    if (isLoopParallelAndContainsReduction(forOp) &&
        isInnermostSiblingInsertion && srcIsUnitSlice())
      // Patch reduction loop - only ones that are sibling-fused with the
      // destination loop - into the parent loop.
      (void)promoteSingleIterReductionLoop(forOp, true);
    else
      // Promote any single iteration slice loops.
      (void)promoteIfSingleIteration(forOp);
  }
}
465 
466 /// Collect loop nest statistics (eg. loop trip count and operation count)
467 /// in 'stats' for loop nest rooted at 'forOp'. Returns true on success,
468 /// returns false otherwise.
469 bool mlir::getLoopNestStats(AffineForOp forOpRoot, LoopNestStats *stats) {
470   auto walkResult = forOpRoot.walk([&](AffineForOp forOp) {
471     auto *childForOp = forOp.getOperation();
472     auto *parentForOp = forOp->getParentOp();
473     if (forOp != forOpRoot) {
474       if (!isa<AffineForOp>(parentForOp)) {
475         LLVM_DEBUG(llvm::dbgs() << "Expected parent AffineForOp\n");
476         return WalkResult::interrupt();
477       }
478       // Add mapping to 'forOp' from its parent AffineForOp.
479       stats->loopMap[parentForOp].push_back(forOp);
480     }
481 
482     // Record the number of op operations in the body of 'forOp'.
483     unsigned count = 0;
484     stats->opCountMap[childForOp] = 0;
485     for (auto &op : *forOp.getBody()) {
486       if (!isa<AffineForOp, AffineIfOp>(op))
487         ++count;
488     }
489     stats->opCountMap[childForOp] = count;
490 
491     // Record trip count for 'forOp'. Set flag if trip count is not
492     // constant.
493     Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
494     if (!maybeConstTripCount) {
495       // Currently only constant trip count loop nests are supported.
496       LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count unsupported\n");
497       return WalkResult::interrupt();
498     }
499 
500     stats->tripCountMap[childForOp] = *maybeConstTripCount;
501     return WalkResult::advance();
502   });
503   return !walkResult.wasInterrupted();
504 }
505 
506 // Computes the total cost of the loop nest rooted at 'forOp'.
507 // Currently, the total cost is computed by counting the total operation
508 // instance count (i.e. total number of operations in the loop bodyloop
509 // operation count * loop trip count) for the entire loop nest.
510 // If 'tripCountOverrideMap' is non-null, overrides the trip count for loops
511 // specified in the map when computing the total op instance count.
512 // NOTEs: 1) This is used to compute the cost of computation slices, which are
513 // sliced along the iteration dimension, and thus reduce the trip count.
514 // If 'computeCostMap' is non-null, the total op count for forOps specified
515 // in the map is increased (not overridden) by adding the op count from the
516 // map to the existing op count for the for loop. This is done before
517 // multiplying by the loop's trip count, and is used to model the cost of
518 // inserting a sliced loop nest of known cost into the loop's body.
519 // 2) This is also used to compute the cost of fusing a slice of some loop nest
520 // within another loop.
521 static int64_t getComputeCostHelper(
522     Operation *forOp, LoopNestStats &stats,
523     llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountOverrideMap,
524     DenseMap<Operation *, int64_t> *computeCostMap) {
525   // 'opCount' is the total number operations in one iteration of 'forOp' body,
526   // minus terminator op which is a no-op.
527   int64_t opCount = stats.opCountMap[forOp] - 1;
528   if (stats.loopMap.count(forOp) > 0) {
529     for (auto childForOp : stats.loopMap[forOp]) {
530       opCount += getComputeCostHelper(childForOp, stats, tripCountOverrideMap,
531                                       computeCostMap);
532     }
533   }
534   // Add in additional op instances from slice (if specified in map).
535   if (computeCostMap != nullptr) {
536     auto it = computeCostMap->find(forOp);
537     if (it != computeCostMap->end()) {
538       opCount += it->second;
539     }
540   }
541   // Override trip count (if specified in map).
542   int64_t tripCount = stats.tripCountMap[forOp];
543   if (tripCountOverrideMap != nullptr) {
544     auto it = tripCountOverrideMap->find(forOp);
545     if (it != tripCountOverrideMap->end()) {
546       tripCount = it->second;
547     }
548   }
549   // Returns the total number of dynamic instances of operations in loop body.
550   return tripCount * opCount;
551 }
552 
553 /// Computes the total cost of the loop nest rooted at 'forOp' using 'stats'.
554 /// Currently, the total cost is computed by counting the total operation
555 /// instance count (i.e. total number of operations in the loop body * loop
556 /// trip count) for the entire loop nest.
557 int64_t mlir::getComputeCost(AffineForOp forOp, LoopNestStats &stats) {
558   return getComputeCostHelper(forOp, stats,
559                               /*tripCountOverrideMap=*/nullptr,
560                               /*computeCostMap=*/nullptr);
561 }
562 
/// Computes and returns in 'computeCost', the total compute cost of fusing the
/// 'slice' of the loop nest rooted at 'srcForOp' into 'dstForOp'. Currently,
/// the total cost is computed by counting the total operation instance count
/// (i.e. total number of operations in the loop body * loop trip count) for
/// the entire loop nest.
/// Returns false if the slice trip count map could not be built.
bool mlir::getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats,
                                AffineForOp dstForOp, LoopNestStats &dstStats,
                                const ComputationSliceState &slice,
                                int64_t *computeCost) {
  llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
  DenseMap<Operation *, int64_t> computeCostMap;

  // Build trip count map for computation slice.
  if (!buildSliceTripCountMap(slice, &sliceTripCountMap))
    return false;
  // Checks whether a store to load forwarding will happen.
  int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap);
  assert(sliceIterationCount > 0);
  // A single-iteration slice guarantees store-to-load forwarding can remove
  // the slice's stores and the matching loads in the destination.
  bool storeLoadFwdGuaranteed = (sliceIterationCount == 1);
  auto *insertPointParent = slice.insertPoint->getParentOp();

  // The store and loads to this memref will disappear.
  // TODO: Add load coalescing to memref data flow opt pass.
  if (storeLoadFwdGuaranteed) {
    // Subtract from operation count the loads/store we expect load/store
    // forwarding to remove.
    unsigned storeCount = 0;
    llvm::SmallDenseSet<Value, 4> storeMemrefs;
    srcForOp.walk([&](Operation *op) {
      if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) {
        storeMemrefs.insert(storeOp.getMemRef());
        ++storeCount;
      }
    });
    // Subtract out any store ops in single-iteration src slice loop nest.
    // Negative entries in 'computeCostMap' reduce the modeled op count.
    if (storeCount > 0)
      computeCostMap[insertPointParent] = -storeCount;
    // Subtract out any load users of 'storeMemrefs' nested below
    // 'insertPointParent'.
    for (Value memref : storeMemrefs) {
      for (auto *user : memref.getUsers()) {
        if (auto loadOp = dyn_cast<AffineReadOpInterface>(user)) {
          SmallVector<AffineForOp, 4> loops;
          // Check if any loop in loop nest surrounding 'user' is
          // 'insertPointParent'.
          getAffineForIVs(*user, &loops);
          if (llvm::is_contained(loops, cast<AffineForOp>(insertPointParent))) {
            if (auto forOp =
                    dyn_cast_or_null<AffineForOp>(user->getParentOp())) {
              // Decrement the innermost loop's op count for each forwarded
              // load, creating the entry at zero if absent.
              if (computeCostMap.count(forOp) == 0)
                computeCostMap[forOp] = 0;
              computeCostMap[forOp] -= 1;
            }
          }
        }
      }
    }
  }

  // Compute op instance count for the src loop nest with iteration slicing.
  int64_t sliceComputeCost = getComputeCostHelper(
      srcForOp, srcStats, &sliceTripCountMap, &computeCostMap);

  // Compute cost of fusion for this depth: model inserting the slice of
  // known cost into the destination loop at the insertion point's parent.
  computeCostMap[insertPointParent] = sliceComputeCost;

  *computeCost =
      getComputeCostHelper(dstForOp, dstStats,
                           /*tripCountOverrideMap=*/nullptr, &computeCostMap);
  return true;
}
634 
635 /// Returns in 'producerConsumerMemrefs' the memrefs involved in a
636 /// producer-consumer dependence between write ops in 'srcOps' and read ops in
637 /// 'dstOps'.
638 void mlir::gatherProducerConsumerMemrefs(
639     ArrayRef<Operation *> srcOps, ArrayRef<Operation *> dstOps,
640     DenseSet<Value> &producerConsumerMemrefs) {
641   // Gather memrefs from stores in 'srcOps'.
642   DenseSet<Value> srcStoreMemRefs;
643   for (Operation *op : srcOps)
644     if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op))
645       srcStoreMemRefs.insert(storeOp.getMemRef());
646 
647   // Compute the intersection between memrefs from stores in 'srcOps' and
648   // memrefs from loads in 'dstOps'.
649   for (Operation *op : dstOps)
650     if (auto loadOp = dyn_cast<AffineReadOpInterface>(op))
651       if (srcStoreMemRefs.count(loadOp.getMemRef()) > 0)
652         producerConsumerMemrefs.insert(loadOp.getMemRef());
653 }
654