//===- LoopFusionUtils.cpp ---- Utilities for loop fusion ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements loop fusion transformation utility functions.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/LoopFusionUtils.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>

#define DEBUG_TYPE "loop-fusion-utils"

using namespace mlir;
using namespace mlir::affine;

// Gathers all load and store memref accesses in 'opA' into 'values', where
// 'values[memref] == true' for each memref that is stored to, and
// 'values[memref] == false' for each memref that is only loaded from.
static void getLoadAndStoreMemRefAccesses(Operation *opA,
                                          DenseMap<Value, bool> &values) {
  opA->walk([&](Operation *op) {
    if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) {
      if (values.count(loadOp.getMemRef()) == 0)
        values[loadOp.getMemRef()] = false;
    } else if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) {
      values[storeOp.getMemRef()] = true;
    }
  });
}
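
// For example, given a hypothetical nest that loads from '%A' and stores to
// '%B' (memref names are illustrative):
//
//   affine.for %i = 0 to 10 {
//     %v = affine.load %A[%i] : memref<10xf32>
//     affine.store %v, %B[%i] : memref<10xf32>
//   }
//
// the function populates 'values' with {%A -> false, %B -> true}.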

/// Returns true if 'op' is a load or store operation that accesses a memref
/// recorded in 'values' such that at least one of the two accesses is a
/// store. Returns false otherwise.
static bool isDependentLoadOrStoreOp(Operation *op,
                                     DenseMap<Value, bool> &values) {
  if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) {
    return values.count(loadOp.getMemRef()) > 0 && values[loadOp.getMemRef()];
  }
  if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) {
    return values.count(storeOp.getMemRef()) > 0;
  }
  return false;
}

// Returns the first operation in range ('opA', 'opB') which has a data
// dependence on 'opA'. Returns 'nullptr' if no such dependence exists.
static Operation *getFirstDependentOpInRange(Operation *opA, Operation *opB) {
  // Record memref values from all loads/stores in the loop nest rooted at
  // 'opA'. Map from memref value to bool which is true if store, false
  // otherwise.
  DenseMap<Value, bool> values;
  getLoadAndStoreMemRefAccesses(opA, values);

  // For each 'opX' in block in range ('opA', 'opB'), check if there is a data
  // dependence from 'opA' to 'opX' ('opA' and 'opX' access the same memref
  // and at least one of the accesses is a store).
  Operation *firstDepOp = nullptr;
  for (Block::iterator it = std::next(Block::iterator(opA));
       it != Block::iterator(opB); ++it) {
    Operation *opX = &(*it);
    opX->walk([&](Operation *op) {
      if (!firstDepOp && isDependentLoadOrStoreOp(op, values))
        firstDepOp = opX;
    });
    if (firstDepOp)
      break;
  }
  return firstDepOp;
}

// Returns the last operation 'opX' in range ('opA', 'opB'), for which there
// exists a data dependence from 'opX' to 'opB'.
// Returns 'nullptr' if no such dependence exists.
static Operation *getLastDependentOpInRange(Operation *opA, Operation *opB) {
  // Record memref values from all loads/stores in the loop nest rooted at
  // 'opB'. Map from memref value to bool which is true if store, false
  // otherwise.
  DenseMap<Value, bool> values;
  getLoadAndStoreMemRefAccesses(opB, values);

  // For each 'opX' in block in range ('opA', 'opB') in reverse order,
  // check if there is a data dependence from 'opX' to 'opB':
  // *) 'opX' and 'opB' access the same memref and at least one of the
  //    accesses is a store.
  // *) 'opX' produces an SSA Value which is used by 'opB'.
  Operation *lastDepOp = nullptr;
  for (Block::reverse_iterator it = std::next(Block::reverse_iterator(opB));
       it != Block::reverse_iterator(opA); ++it) {
    Operation *opX = &(*it);
    opX->walk([&](Operation *op) {
      if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) {
        if (isDependentLoadOrStoreOp(op, values)) {
          lastDepOp = opX;
          return WalkResult::interrupt();
        }
        return WalkResult::advance();
      }
      for (Value value : op->getResults()) {
        for (Operation *user : value.getUsers()) {
          SmallVector<AffineForOp, 4> loops;
          // Check if any loop in the loop nest surrounding 'user' is 'opB'.
          getAffineForIVs(*user, &loops);
          if (llvm::is_contained(loops, cast<AffineForOp>(opB))) {
            lastDepOp = opX;
            return WalkResult::interrupt();
          }
        }
      }
      return WalkResult::advance();
    });
    if (lastDepOp)
      break;
  }
  return lastDepOp;
}

// Computes and returns an insertion point operation, before which the
// fused <srcForOp, dstForOp> loop nest can be inserted while preserving
// dependences. Returns nullptr if no such insertion point is found.
static Operation *getFusedLoopNestInsertionPoint(AffineForOp srcForOp,
                                                 AffineForOp dstForOp) {
  bool isSrcForOpBeforeDstForOp = srcForOp->isBeforeInBlock(dstForOp);
  auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp;
  auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp;

  Operation *firstDepOpA = getFirstDependentOpInRange(forOpA, forOpB);
  Operation *lastDepOpB = getLastDependentOpInRange(forOpA, forOpB);
  // Block:
  //      ...
  //  |-- opA
  //  |   ...
  //  |   lastDepOpB --|
  //  |   ...          |
  //  |-> firstDepOpA  |
  //      ...          |
  //      opB <---------
  //
  // Valid insertion point range: (lastDepOpB, firstDepOpA)
  //
  if (firstDepOpA) {
    if (lastDepOpB) {
      if (firstDepOpA->isBeforeInBlock(lastDepOpB) || firstDepOpA == lastDepOpB)
        // No valid insertion point exists which preserves dependences.
        return nullptr;
    }
    // Return insertion point in valid range closest to 'opB'.
    // TODO: Consider other insertion points in valid range.
    return firstDepOpA;
  }
  // No dependences from 'opA' to any operation in range ('opA', 'opB');
  // return 'opB' as the insertion point.
  return forOpB;
}

// Gathers all load and store ops in the loop nest rooted at 'forOp' into
// 'loadAndStoreOps'. Returns false if the nest contains an 'affine.if' op,
// which is currently unsupported by fusion; returns true otherwise.
static bool
gatherLoadsAndStores(AffineForOp forOp,
                     SmallVectorImpl<Operation *> &loadAndStoreOps) {
  bool hasIfOp = false;
  forOp.walk([&](Operation *op) {
    if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op))
      loadAndStoreOps.push_back(op);
    else if (isa<AffineIfOp>(op))
      hasIfOp = true;
  });
  return !hasIfOp;
}

/// Returns the maximum loop depth at which we could fuse producer loop
/// 'srcForOp' into consumer loop 'dstForOp' without violating data
/// dependences.
// TODO: Generalize this check for sibling and more generic fusion scenarios.
// TODO: Support forward slice fusion.
static unsigned getMaxLoopDepth(ArrayRef<Operation *> srcOps,
                                ArrayRef<Operation *> dstOps) {
  if (dstOps.empty())
    // Expected at least one memory operation.
    // TODO: Revisit this case with a specific example.
    return 0;

  // Filter out ops in 'dstOps' that do not use the producer-consumer memref
  // so that they are not considered for analysis.
  DenseSet<Value> producerConsumerMemrefs;
  gatherProducerConsumerMemrefs(srcOps, dstOps, producerConsumerMemrefs);
  SmallVector<Operation *, 4> targetDstOps;
  for (Operation *dstOp : dstOps) {
    auto loadOp = dyn_cast<AffineReadOpInterface>(dstOp);
    Value memref = loadOp ? loadOp.getMemRef()
                          : cast<AffineWriteOpInterface>(dstOp).getMemRef();
    if (producerConsumerMemrefs.count(memref) > 0)
      targetDstOps.push_back(dstOp);
  }

  assert(!targetDstOps.empty() &&
         "No dependences between 'srcForOp' and 'dstForOp'?");

  // Compute the innermost common loop depth for loads and stores.
  unsigned loopDepth = getInnermostCommonLoopDepth(targetDstOps);

  // Return common loop depth for loads if there are no store ops.
  if (all_of(targetDstOps, llvm::IsaPred<AffineReadOpInterface>))
    return loopDepth;

  // Check dependences on all pairs of ops in 'targetDstOps' and store the
  // minimum loop depth at which a dependence is satisfied.
  for (unsigned i = 0, e = targetDstOps.size(); i < e; ++i) {
    Operation *srcOpInst = targetDstOps[i];
    MemRefAccess srcAccess(srcOpInst);
    for (unsigned j = 0; j < e; ++j) {
      auto *dstOpInst = targetDstOps[j];
      MemRefAccess dstAccess(dstOpInst);

      unsigned numCommonLoops =
          getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst);
      for (unsigned d = 1; d <= numCommonLoops + 1; ++d) {
        // TODO: Cache dependence analysis results, check cache here.
        DependenceResult result =
            checkMemrefAccessDependence(srcAccess, dstAccess, d);
        if (hasDependence(result)) {
          // Store the minimum loop depth and break because we want the min
          // 'd' at which there is a dependence.
          loopDepth = std::min(loopDepth, d - 1);
          break;
        }
      }
    }
  }

  return loopDepth;
}
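
// Illustrative example (hypothetical): suppose 'targetDstOps' holds a store
// and a load on the same memref with three common surrounding loops, so the
// scan above runs d = 1, 2, 3, 4. If 'checkMemrefAccessDependence' first
// reports a dependence at d == 2, 'loopDepth' is lowered to d - 1 == 1:
// fusing at any depth greater than 1 would reorder the dependent accesses.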

// TODO: This utility performs some computation that is the same for all the
// depths (e.g., getMaxLoopDepth). Implement a version of this utility that
// processes all the depths at once, or only the maximal legal depth for
// maximal fusion.
FusionResult mlir::affine::canFuseLoops(AffineForOp srcForOp,
                                        AffineForOp dstForOp,
                                        unsigned dstLoopDepth,
                                        ComputationSliceState *srcSlice,
                                        FusionStrategy fusionStrategy) {
  // Return 'failure' if 'dstLoopDepth == 0'.
  if (dstLoopDepth == 0) {
    LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests at depth 0\n");
    return FusionResult::FailPrecondition;
  }
  // Return 'failure' if 'srcForOp' and 'dstForOp' are not in the same block.
  auto *block = srcForOp->getBlock();
  if (block != dstForOp->getBlock()) {
    LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests in different blocks\n");
    return FusionResult::FailPrecondition;
  }

  // Return 'failure' if no valid insertion point for the fused loop nest
  // exists in 'block' which would preserve dependences.
  if (!getFusedLoopNestInsertionPoint(srcForOp, dstForOp)) {
    LLVM_DEBUG(llvm::dbgs() << "Fusion would violate dependences in block\n");
    return FusionResult::FailBlockDependence;
  }

  // Check if 'srcForOp' precedes 'dstForOp' in 'block'.
  bool isSrcForOpBeforeDstForOp = srcForOp->isBeforeInBlock(dstForOp);
  // 'forOpA' executes before 'forOpB' in 'block'.
  auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp;
  auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp;

  // Gather all loads and stores from 'forOpA', which precedes 'forOpB' in
  // 'block'.
  SmallVector<Operation *, 4> opsA;
  if (!gatherLoadsAndStores(forOpA, opsA)) {
    LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n");
    return FusionResult::FailPrecondition;
  }

  // Gather all loads and stores from 'forOpB', which succeeds 'forOpA' in
  // 'block'.
  SmallVector<Operation *, 4> opsB;
  if (!gatherLoadsAndStores(forOpB, opsB)) {
    LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n");
    return FusionResult::FailPrecondition;
  }

  // Return 'failure' if fusing loops at depth 'dstLoopDepth' wouldn't
  // preserve loop dependences.
  // TODO: Enable this check for sibling and more generic loop fusion
  // strategies.
  if (fusionStrategy.getStrategy() == FusionStrategy::ProducerConsumer) {
    // TODO: 'getMaxLoopDepth' does not support forward slice fusion.
    assert(isSrcForOpBeforeDstForOp && "Unexpected forward slice fusion");
    if (getMaxLoopDepth(opsA, opsB) < dstLoopDepth) {
      LLVM_DEBUG(llvm::dbgs() << "Fusion would violate loop dependences\n");
      return FusionResult::FailFusionDependence;
    }
  }

  // Calculate the number of common loops surrounding 'srcForOp' and
  // 'dstForOp'.
  unsigned numCommonLoops =
      affine::getNumCommonSurroundingLoops(*srcForOp, *dstForOp);

  // Filter out ops in 'opsA' to compute the slice union based on the
  // assumptions made by the fusion strategy.
  SmallVector<Operation *, 4> strategyOpsA;
  switch (fusionStrategy.getStrategy()) {
  case FusionStrategy::Generic:
    // Generic fusion. Take into account all the memory operations to compute
    // the slice union.
    strategyOpsA.append(opsA.begin(), opsA.end());
    break;
  case FusionStrategy::ProducerConsumer:
    // Producer-consumer fusion (AffineLoopFusion pass) only takes into
    // account stores in 'srcForOp' to compute the slice union.
    for (Operation *op : opsA) {
      if (isa<AffineWriteOpInterface>(op))
        strategyOpsA.push_back(op);
    }
    break;
  case FusionStrategy::Sibling:
    // Sibling fusion (AffineLoopFusion pass) only takes into account the
    // loads from the sibling fusion memref in 'srcForOp' to compute the
    // slice union.
    for (Operation *op : opsA) {
      auto load = dyn_cast<AffineReadOpInterface>(op);
      if (load && load.getMemRef() == fusionStrategy.getSiblingFusionMemRef())
        strategyOpsA.push_back(op);
    }
    break;
  }

  // Compute the union of the computation slices computed between all pairs
  // of ops from 'forOpA' and 'forOpB'.
  SliceComputationResult sliceComputationResult = affine::computeSliceUnion(
      strategyOpsA, opsB, dstLoopDepth, numCommonLoops,
      isSrcForOpBeforeDstForOp, srcSlice);
  if (sliceComputationResult.value == SliceComputationResult::GenericFailure) {
    LLVM_DEBUG(llvm::dbgs() << "computeSliceUnion failed\n");
    return FusionResult::FailPrecondition;
  }
  if (sliceComputationResult.value ==
      SliceComputationResult::IncorrectSliceFailure) {
    LLVM_DEBUG(llvm::dbgs() << "Incorrect slice computation\n");
    return FusionResult::FailIncorrectSlice;
  }

  return FusionResult::Success;
}
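
// As a schematic illustration (hypothetical IR; memref and op names are made
// up), a successful producer-consumer check at 'dstLoopDepth' == 1 allows
// fuseLoops (below) to rewrite:
//
//   affine.for %i = 0 to 100 {   // 'srcForOp' (producer)
//     affine.store %cst, %T[%i] : memref<100xf32>
//   }
//   affine.for %i = 0 to 100 {   // 'dstForOp' (consumer)
//     %v = affine.load %T[%i] : memref<100xf32>
//     "test.use"(%v) : (f32) -> ()
//   }
//
// into a single loop in which the sliced producer iteration is computed
// immediately before its consumer:
//
//   affine.for %i = 0 to 100 {
//     affine.store %cst, %T[%i] : memref<100xf32>
//     %v = affine.load %T[%i] : memref<100xf32>
//     "test.use"(%v) : (f32) -> ()
//   }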

/// Patch the loop body of a forOp that is a single iteration reduction loop
/// into its containing block.
static LogicalResult promoteSingleIterReductionLoop(AffineForOp forOp,
                                                    bool siblingFusionUser) {
  // Check if the reduction loop is a single iteration loop.
  std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
  if (!tripCount || *tripCount != 1)
    return failure();
  auto *parentOp = forOp->getParentOp();
  if (!isa<AffineForOp>(parentOp))
    return failure();
  SmallVector<Value> newOperands;
  llvm::append_range(newOperands,
                     forOp.getBody()->getTerminator()->getOperands());
  IRRewriter rewriter(parentOp->getContext());
  int64_t parentOpNumResults = parentOp->getNumResults();
  // Replace the parent loop and add iter operands and results from the
  // `forOp`.
  AffineForOp parentForOp = forOp->getParentOfType<AffineForOp>();
  AffineForOp newLoop =
      cast<AffineForOp>(*parentForOp.replaceWithAdditionalYields(
          rewriter, forOp.getInits(), /*replaceInitOperandUsesInLoop=*/false,
          [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
            return newOperands;
          }));

  // For sibling-fusion users, collect operations that use the results of the
  // `forOp` outside the new parent loop that has absorbed all its iter args
  // and operands. These operations will be moved later after the results
  // have been replaced.
  SetVector<Operation *> forwardSlice;
  if (siblingFusionUser) {
    for (unsigned i = 0, e = forOp.getNumResults(); i != e; ++i) {
      SetVector<Operation *> tmpForwardSlice;
      getForwardSlice(forOp.getResult(i), &tmpForwardSlice);
      forwardSlice.set_union(tmpForwardSlice);
    }
  }
  // Update the results of the `forOp` in the new loop.
  for (unsigned i = 0, e = forOp.getNumResults(); i != e; ++i) {
    forOp.getResult(i).replaceAllUsesWith(
        newLoop.getResult(i + parentOpNumResults));
  }
  // For sibling-fusion users, move operations that use the results of the
  // `forOp` outside the new parent loop.
  if (siblingFusionUser) {
    topologicalSort(forwardSlice);
    for (Operation *op : llvm::reverse(forwardSlice))
      op->moveAfter(newLoop);
  }
  // Replace the induction variable.
  auto iv = forOp.getInductionVar();
  iv.replaceAllUsesWith(newLoop.getInductionVar());
  // Replace the iter args.
  auto forOpIterArgs = forOp.getRegionIterArgs();
  for (auto it : llvm::zip(forOpIterArgs, newLoop.getRegionIterArgs().take_back(
                                              forOpIterArgs.size()))) {
    std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
  }
  // Move the loop body operations, except for its terminator, to the loop's
  // containing block.
  forOp.getBody()->back().erase();
  auto *parentBlock = forOp->getBlock();
  parentBlock->getOperations().splice(Block::iterator(forOp),
                                      forOp.getBody()->getOperations());
  forOp.erase();
  return success();
}
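
// Roughly, for a sibling-fused nest like the following (hypothetical IR;
// '%v' and bounds are illustrative), the single-iteration reduction loop
// over %j is absorbed into its parent, which gains its iter args and yields:
//
//   affine.for %i = 0 to 8 {
//     %r = affine.for %j = 0 to 1 iter_args(%acc = %init) -> (f32) {
//       %s = arith.addf %acc, %v : f32
//       affine.yield %s : f32
//     }
//   }
//  -->
//   %r = affine.for %i = 0 to 8 iter_args(%acc = %init) -> (f32) {
//     %s = arith.addf %acc, %v : f32
//     affine.yield %s : f32
//   }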

/// Fuses 'srcForOp' into 'dstForOp' with destination loop block insertion
/// point and source slice loop bounds specified in 'srcSlice'.
void mlir::affine::fuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
                             const ComputationSliceState &srcSlice,
                             bool isInnermostSiblingInsertion) {
  // Clone 'srcForOp' into 'dstForOp' at 'srcSlice->insertPoint'.
  OpBuilder b(srcSlice.insertPoint->getBlock(), srcSlice.insertPoint);
  IRMapping mapper;
  b.clone(*srcForOp, mapper);

  // Update the cloned slice loops' upper and lower bounds from the computed
  // 'srcSlice'.
  SmallVector<AffineForOp, 4> sliceLoops;
  for (unsigned i = 0, e = srcSlice.ivs.size(); i < e; ++i) {
    auto loopIV = mapper.lookupOrNull(srcSlice.ivs[i]);
    if (!loopIV)
      continue;
    auto forOp = getForInductionVarOwner(loopIV);
    sliceLoops.push_back(forOp);
    if (AffineMap lbMap = srcSlice.lbs[i]) {
      auto lbOperands = srcSlice.lbOperands[i];
      canonicalizeMapAndOperands(&lbMap, &lbOperands);
      forOp.setLowerBound(lbOperands, lbMap);
    }
    if (AffineMap ubMap = srcSlice.ubs[i]) {
      auto ubOperands = srcSlice.ubOperands[i];
      canonicalizeMapAndOperands(&ubMap, &ubOperands);
      forOp.setUpperBound(ubOperands, ubMap);
    }
  }

  llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
  auto srcIsUnitSlice = [&]() {
    return (buildSliceTripCountMap(srcSlice, &sliceTripCountMap) &&
            (getSliceIterationCount(sliceTripCountMap) == 1));
  };
  // Fix up and, if possible, eliminate single iteration loops.
  for (AffineForOp forOp : sliceLoops) {
    if (isLoopParallelAndContainsReduction(forOp) &&
        isInnermostSiblingInsertion && srcIsUnitSlice())
      // Patch reduction loops - only the ones that are sibling-fused with
      // the destination loop - into the parent loop.
      (void)promoteSingleIterReductionLoop(forOp, true);
    else
      // Promote any single iteration slice loops.
      (void)promoteIfSingleIteration(forOp);
  }
}
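
// A minimal usage sketch (assuming 'srcForOp', 'dstForOp', and 'depth' come
// from the caller's own analysis of the candidate loops):
//
//   ComputationSliceState slice;
//   FusionResult result =
//       canFuseLoops(srcForOp, dstForOp, /*dstLoopDepth=*/depth, &slice);
//   if (result.value == FusionResult::Success)
//     fuseLoops(srcForOp, dstForOp, slice);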

/// Collect loop nest statistics (e.g., loop trip count and operation count)
/// in 'stats' for the loop nest rooted at 'forOp'. Returns true on success,
/// false otherwise.
bool mlir::affine::getLoopNestStats(AffineForOp forOpRoot,
                                    LoopNestStats *stats) {
  auto walkResult = forOpRoot.walk([&](AffineForOp forOp) {
    auto *childForOp = forOp.getOperation();
    auto *parentForOp = forOp->getParentOp();
    if (forOp != forOpRoot) {
      if (!isa<AffineForOp>(parentForOp)) {
        LLVM_DEBUG(llvm::dbgs() << "Expected parent AffineForOp\n");
        return WalkResult::interrupt();
      }
      // Add mapping to 'forOp' from its parent AffineForOp.
      stats->loopMap[parentForOp].push_back(forOp);
    }

    // Record the number of operations in the body of 'forOp', excluding
    // nested loops and ifs.
    unsigned count = 0;
    for (auto &op : *forOp.getBody()) {
      if (!isa<AffineForOp, AffineIfOp>(op))
        ++count;
    }
    stats->opCountMap[childForOp] = count;

    // Record the trip count for 'forOp'; bail out if it is not constant.
    std::optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
    if (!maybeConstTripCount) {
      // Currently only constant trip count loop nests are supported.
      LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count unsupported\n");
      return WalkResult::interrupt();
    }

    stats->tripCountMap[childForOp] = *maybeConstTripCount;
    return WalkResult::advance();
  });
  return !walkResult.wasInterrupted();
}
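
// For instance, for a hypothetical two-deep nest:
//
//   affine.for %i = 0 to 128 {       // 3 non-loop ops in body, incl. yield
//     affine.for %j = 0 to 64 {      // 4 non-loop ops in body, incl. yield
//       ...
//     }
//   }
//
// the stats would record tripCountMap = {%i -> 128, %j -> 64} and
// opCountMap = {%i -> 3, %j -> 4}; note that each loop's terminator counts
// toward its op count here and is discounted in getComputeCostHelper below.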

// Computes the total cost of the loop nest rooted at 'forOp'.
// Currently, the total cost is computed by counting the total operation
// instance count (i.e. total number of operations in the loop body * loop
// trip count) for the entire loop nest.
// If 'tripCountOverrideMap' is non-null, overrides the trip count for loops
// specified in the map when computing the total op instance count.
// NOTEs: 1) This is used to compute the cost of computation slices, which
// are sliced along the iteration dimension, and thus reduce the trip count.
// If 'computeCostMap' is non-null, the total op count for forOps specified
// in the map is increased (not overridden) by adding the op count from the
// map to the existing op count for the for loop. This is done before
// multiplying by the loop's trip count, and is used to model the cost of
// inserting a sliced loop nest of known cost into the loop's body.
// 2) This is also used to compute the cost of fusing a slice of some loop
// nest within another loop.
static int64_t getComputeCostHelper(
    Operation *forOp, LoopNestStats &stats,
    llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountOverrideMap,
    DenseMap<Operation *, int64_t> *computeCostMap) {
  // 'opCount' is the total number of operations in one iteration of the
  // 'forOp' body, minus the terminator op which is a no-op.
  int64_t opCount = stats.opCountMap[forOp] - 1;
  if (stats.loopMap.count(forOp) > 0) {
    for (auto childForOp : stats.loopMap[forOp]) {
      opCount += getComputeCostHelper(childForOp, stats, tripCountOverrideMap,
                                      computeCostMap);
    }
  }
  // Add in additional op instances from the slice (if specified in the map).
  if (computeCostMap) {
    auto it = computeCostMap->find(forOp);
    if (it != computeCostMap->end()) {
      opCount += it->second;
    }
  }
  // Override the trip count (if specified in the map).
  int64_t tripCount = stats.tripCountMap[forOp];
  if (tripCountOverrideMap) {
    auto it = tripCountOverrideMap->find(forOp);
    if (it != tripCountOverrideMap->end()) {
      tripCount = it->second;
    }
  }
  // Return the total number of dynamic instances of operations in the loop
  // body.
  return tripCount * opCount;
}
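
// Worked example (using the hypothetical two-deep nest sketched above): with
// op counts of 3 (outer) and 4 (inner), trip counts of 128 and 64, and no
// override or compute-cost maps, the inner loop costs 64 * (4 - 1) = 192 op
// instances, and the total cost is 128 * ((3 - 1) + 192) = 24832.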

/// Computes the total cost of the loop nest rooted at 'forOp' using 'stats'.
/// Currently, the total cost is computed by counting the total operation
/// instance count (i.e. total number of operations in the loop body * loop
/// trip count) for the entire loop nest.
int64_t mlir::affine::getComputeCost(AffineForOp forOp, LoopNestStats &stats) {
  return getComputeCostHelper(forOp, stats,
                              /*tripCountOverrideMap=*/nullptr,
                              /*computeCostMap=*/nullptr);
}

/// Computes and returns in 'computeCost', the total compute cost of fusing
/// the 'slice' of the loop nest rooted at 'srcForOp' into 'dstForOp'.
/// Currently, the total cost is computed by counting the total operation
/// instance count (i.e. total number of operations in the loop body * loop
/// trip count) for the entire loop nest.
bool mlir::affine::getFusionComputeCost(AffineForOp srcForOp,
                                        LoopNestStats &srcStats,
                                        AffineForOp dstForOp,
                                        LoopNestStats &dstStats,
                                        const ComputationSliceState &slice,
                                        int64_t *computeCost) {
  llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
  DenseMap<Operation *, int64_t> computeCostMap;

  // Build the trip count map for the computation slice.
  if (!buildSliceTripCountMap(slice, &sliceTripCountMap))
    return false;
  // Check whether a store to load forwarding will happen.
  int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap);
  assert(sliceIterationCount > 0);
  bool storeLoadFwdGuaranteed = (sliceIterationCount == 1);
  auto *insertPointParent = slice.insertPoint->getParentOp();

  // The stores and loads to these memrefs will disappear.
  if (storeLoadFwdGuaranteed) {
    // Subtract from the operation count the loads/stores we expect
    // load/store forwarding to remove.
    unsigned storeCount = 0;
    llvm::SmallDenseSet<Value, 4> storeMemrefs;
    srcForOp.walk([&](AffineWriteOpInterface storeOp) {
      storeMemrefs.insert(storeOp.getMemRef());
      ++storeCount;
    });
    // Subtract out any store ops in the single-iteration src slice loop
    // nest.
    if (storeCount > 0)
      computeCostMap[insertPointParent] = -storeCount;
    // Subtract out any load users of 'storeMemrefs' nested below
    // 'insertPointParent'.
    for (Value memref : storeMemrefs) {
      for (Operation *user : memref.getUsers()) {
        if (!isa<AffineReadOpInterface>(user))
          continue;
        SmallVector<AffineForOp, 4> loops;
        // Check if any loop in the loop nest surrounding 'user' is
        // 'insertPointParent'.
        getAffineForIVs(*user, &loops);
        if (llvm::is_contained(loops, cast<AffineForOp>(insertPointParent))) {
          if (auto forOp =
                  dyn_cast_or_null<AffineForOp>(user->getParentOp())) {
            if (computeCostMap.count(forOp) == 0)
              computeCostMap[forOp] = 0;
            computeCostMap[forOp] -= 1;
          }
        }
      }
    }
  }

  // Compute the op instance count for the src loop nest with iteration
  // slicing.
  int64_t sliceComputeCost = getComputeCostHelper(
      srcForOp, srcStats, &sliceTripCountMap, &computeCostMap);

  // Compute the cost of fusion for this depth.
  computeCostMap[insertPointParent] = sliceComputeCost;

  *computeCost =
      getComputeCostHelper(dstForOp, dstStats,
                           /*tripCountOverrideMap=*/nullptr, &computeCostMap);
  return true;
}

/// Returns in 'producerConsumerMemrefs' the memrefs involved in a
/// producer-consumer dependence between write ops in 'srcOps' and read ops
/// in 'dstOps'.
void mlir::affine::gatherProducerConsumerMemrefs(
    ArrayRef<Operation *> srcOps, ArrayRef<Operation *> dstOps,
    DenseSet<Value> &producerConsumerMemrefs) {
  // Gather memrefs from stores in 'srcOps'.
  DenseSet<Value> srcStoreMemRefs;
  for (Operation *op : srcOps)
    if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op))
      srcStoreMemRefs.insert(storeOp.getMemRef());

  // Compute the intersection between memrefs from stores in 'srcOps' and
  // memrefs from loads in 'dstOps'.
  for (Operation *op : dstOps)
    if (auto loadOp = dyn_cast<AffineReadOpInterface>(op))
      if (srcStoreMemRefs.count(loadOp.getMemRef()) > 0)
        producerConsumerMemrefs.insert(loadOp.getMemRef());
}