//===- LoopFusionUtils.cpp ---- Utilities for loop fusion ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements loop fusion transformation utility functions.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/LoopFusionUtils.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>

#define DEBUG_TYPE "loop-fusion-utils"

using namespace mlir;
using namespace mlir::affine;

// Gathers all load and store memref accesses in 'opA' into 'values', where
// 'values[memref]' is true if at least one of the accesses to 'memref' is a
// store operation.
static void getLoadAndStoreMemRefAccesses(Operation *opA,
                                          DenseMap<Value, bool> &values) {
  opA->walk([&](Operation *op) {
    if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) {
      if (values.count(loadOp.getMemRef()) == 0)
        values[loadOp.getMemRef()] = false;
    } else if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) {
      values[storeOp.getMemRef()] = true;
    }
  });
}

/// Returns true if 'op' is a load or store operation which accesses a memref
/// in 'values', and at least one of the accesses is a store operation.
/// Returns false otherwise.
static bool isDependentLoadOrStoreOp(Operation *op,
                                     DenseMap<Value, bool> &values) {
  if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) {
    return values.count(loadOp.getMemRef()) > 0 && values[loadOp.getMemRef()];
  }
  if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) {
    return values.count(storeOp.getMemRef()) > 0;
  }
  return false;
}

// Returns the first operation in range ('opA', 'opB') which has a data
// dependence on 'opA'. Returns 'nullptr' if no dependence exists.
static Operation *getFirstDependentOpInRange(Operation *opA, Operation *opB) {
  // Record memref values from all loads/stores in the loop nest rooted at
  // 'opA'. Map from memref value to bool which is true if store, false
  // otherwise.
  DenseMap<Value, bool> values;
  getLoadAndStoreMemRefAccesses(opA, values);

  // For each 'opX' in block in range ('opA', 'opB'), check if there is a data
  // dependence from 'opA' to 'opX' ('opA' and 'opX' access the same memref
  // and at least one of the accesses is a store).
  Operation *firstDepOp = nullptr;
  for (Block::iterator it = std::next(Block::iterator(opA));
       it != Block::iterator(opB); ++it) {
    Operation *opX = &(*it);
    opX->walk([&](Operation *op) {
      if (!firstDepOp && isDependentLoadOrStoreOp(op, values))
        firstDepOp = opX;
    });
    if (firstDepOp)
      break;
  }
  return firstDepOp;
}
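
// For illustration, a minimal sketch of the scan above on made-up IR (the
// names and shapes are invented for this example):
//
//   affine.for %i = 0 to 8 {                   // 'opA': stores to %A
//     affine.store %v0, %A[%i] : memref<8xf32>
//   }
//   %x = affine.load %B[%c0] : memref<8xf32>   // accesses only %B: skipped
//   affine.for %j = 0 to 8 {                   // loads from %A: returned as
//     %y = affine.load %A[%j] : memref<8xf32>  // the first dependent op
//   }
//   affine.for %k = 0 to 8 { ... }             // 'opB'
//
// 'values' maps %A -> true (stored to in 'opA'), so the %j loop is the first
// operation in range ('opA', 'opB') with a data dependence on 'opA'.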

// Returns the last operation 'opX' in range ('opA', 'opB'), for which there
// exists a data dependence from 'opX' to 'opB'.
// Returns 'nullptr' if no dependence exists.
static Operation *getLastDependentOpInRange(Operation *opA, Operation *opB) {
  // Record memref values from all loads/stores in the loop nest rooted at
  // 'opB'. Map from memref value to bool which is true if store, false
  // otherwise.
  DenseMap<Value, bool> values;
  getLoadAndStoreMemRefAccesses(opB, values);

  // For each 'opX' in block in range ('opA', 'opB') in reverse order,
  // check if there is a data dependence from 'opX' to 'opB':
  // *) 'opX' and 'opB' access the same memref and at least one of the
  //    accesses is a store.
  // *) 'opX' produces an SSA Value which is used by 'opB'.
  Operation *lastDepOp = nullptr;
  for (Block::reverse_iterator it = std::next(Block::reverse_iterator(opB));
       it != Block::reverse_iterator(opA); ++it) {
    Operation *opX = &(*it);
    opX->walk([&](Operation *op) {
      if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) {
        if (isDependentLoadOrStoreOp(op, values)) {
          lastDepOp = opX;
          return WalkResult::interrupt();
        }
        return WalkResult::advance();
      }
      for (Value value : op->getResults()) {
        for (Operation *user : value.getUsers()) {
          SmallVector<AffineForOp, 4> loops;
          // Check if any loop in loop nest surrounding 'user' is 'opB'.
          getAffineForIVs(*user, &loops);
          if (llvm::is_contained(loops, cast<AffineForOp>(opB))) {
            lastDepOp = opX;
            return WalkResult::interrupt();
          }
        }
      }
      return WalkResult::advance();
    });
    if (lastDepOp)
      break;
  }
  return lastDepOp;
}

// Computes and returns an insertion point operation, before which the
// fused <srcForOp, dstForOp> loop nest can be inserted while preserving
// dependences. Returns nullptr if no such insertion point is found.
static Operation *getFusedLoopNestInsertionPoint(AffineForOp srcForOp,
                                                 AffineForOp dstForOp) {
  bool isSrcForOpBeforeDstForOp = srcForOp->isBeforeInBlock(dstForOp);
  auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp;
  auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp;

  Operation *firstDepOpA = getFirstDependentOpInRange(forOpA, forOpB);
  Operation *lastDepOpB = getLastDependentOpInRange(forOpA, forOpB);
  // Block:
  //      ...
  //  |-- opA
  //  |   ...
  //  |   lastDepOpB --|
  //  |   ...          |
  //  |-> firstDepOpA  |
  //      ...          |
  //      opB <--------|
  //
  // Valid insertion point range: (lastDepOpB, firstDepOpA)
  //
  if (firstDepOpA) {
    if (lastDepOpB) {
      if (firstDepOpA->isBeforeInBlock(lastDepOpB) || firstDepOpA == lastDepOpB)
        // No valid insertion point exists which preserves dependences.
        return nullptr;
    }
    // Return insertion point in valid range closest to 'opB'.
    // TODO: Consider other insertion points in valid range.
    return firstDepOpA;
  }
  // No dependences from 'opA' to operation in range ('opA', 'opB'), return
  // 'opB' insertion point.
  return forOpB;
}
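
// For illustration, a made-up case where the SSA-use check in
// getLastDependentOpInRange fires (names invented for the example):
//
//   affine.for %i = 0 to 8 { ... }               // 'opA'
//   %v = arith.mulf %a, %b : f32                 // 'opX': defines %v
//   affine.for %j = 0 to 8 {                     // 'opB'
//     affine.store %v, %C[%j] : memref<8xf32>    // user of %v inside 'opB'
//   }
//
// The walk over 'opX' finds a user of %v whose surrounding loop nest
// contains 'opB', so 'opX' is the last dependent op in range: a fused loop
// nest cannot be inserted above the definition of %v.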

// Gathers all load and store ops in the loop nest rooted at 'forOp' into
// 'loadAndStoreOps'. Returns false if an 'affine.if' op was encountered in
// the loop nest (currently unsupported); returns true otherwise.
static bool
gatherLoadsAndStores(AffineForOp forOp,
                     SmallVectorImpl<Operation *> &loadAndStoreOps) {
  bool hasIfOp = false;
  forOp.walk([&](Operation *op) {
    if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op))
      loadAndStoreOps.push_back(op);
    else if (isa<AffineIfOp>(op))
      hasIfOp = true;
  });
  return !hasIfOp;
}

/// Returns the maximum loop depth at which we could fuse the producer loop
/// into the consumer loop without violating data dependences. 'srcOps' and
/// 'dstOps' hold the memory operations of the producer and consumer loop
/// nests, respectively.
// TODO: Generalize this check for sibling and more generic fusion scenarios.
// TODO: Support forward slice fusion.
static unsigned getMaxLoopDepth(ArrayRef<Operation *> srcOps,
                                ArrayRef<Operation *> dstOps) {
  if (dstOps.empty())
    // Expected at least one memory operation.
    // TODO: Revisit this case with a specific example.
    return 0;

  // Filter out ops in 'dstOps' that do not use the producer-consumer memref so
  // that they are not considered for analysis.
  DenseSet<Value> producerConsumerMemrefs;
  gatherProducerConsumerMemrefs(srcOps, dstOps, producerConsumerMemrefs);
  SmallVector<Operation *, 4> targetDstOps;
  for (Operation *dstOp : dstOps) {
    auto loadOp = dyn_cast<AffineReadOpInterface>(dstOp);
    Value memref = loadOp ? loadOp.getMemRef()
                          : cast<AffineWriteOpInterface>(dstOp).getMemRef();
    if (producerConsumerMemrefs.count(memref) > 0)
      targetDstOps.push_back(dstOp);
  }

  assert(!targetDstOps.empty() &&
         "No dependences between 'srcForOp' and 'dstForOp'?");

  // Compute the innermost common loop depth for loads and stores.
  unsigned loopDepth = getInnermostCommonLoopDepth(targetDstOps);

  // Return common loop depth for loads if there are no store ops.
  if (all_of(targetDstOps,
             [&](Operation *op) { return isa<AffineReadOpInterface>(op); }))
    return loopDepth;

  // Check dependences on all pairs of ops in 'targetDstOps' and store the
  // minimum loop depth at which a dependence is satisfied.
  for (unsigned i = 0, e = targetDstOps.size(); i < e; ++i) {
    Operation *srcOpInst = targetDstOps[i];
    MemRefAccess srcAccess(srcOpInst);
    for (unsigned j = 0; j < e; ++j) {
      auto *dstOpInst = targetDstOps[j];
      MemRefAccess dstAccess(dstOpInst);

      unsigned numCommonLoops =
          getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst);
      for (unsigned d = 1; d <= numCommonLoops + 1; ++d) {
        // TODO: Cache dependence analysis results, check cache here.
        DependenceResult result =
            checkMemrefAccessDependence(srcAccess, dstAccess, d);
        if (hasDependence(result)) {
          // Store minimum loop depth and break because we want the min 'd' at
          // which there is a dependence.
          loopDepth = std::min(loopDepth, d - 1);
          break;
        }
      }
    }
  }

  return loopDepth;
}
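
// Worked example of the depth computation above (a sketch, not taken from a
// test case): suppose 'targetDstOps' contains a load and a store nested under
// two common loops, so the innermost common loop depth is 2. If
// checkMemrefAccessDependence first reports a dependence between the pair at
// depth d = 2, then loopDepth = min(2, d - 1) = 1, i.e. the source slice may
// only be fused at depth 1; fusing at depth 2 would reorder the dependent
// accesses carried by the inner loop.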

// TODO: This utility performs some computation that is the same for all the
// depths (e.g., getMaxLoopDepth). Implement a version of this utility that
// processes all the depths at once or only the legal maximal depth for
// maximal fusion.
FusionResult mlir::affine::canFuseLoops(AffineForOp srcForOp,
                                        AffineForOp dstForOp,
                                        unsigned dstLoopDepth,
                                        ComputationSliceState *srcSlice,
                                        FusionStrategy fusionStrategy) {
  // Return 'failure' if 'dstLoopDepth == 0'.
  if (dstLoopDepth == 0) {
    LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests at depth 0\n");
    return FusionResult::FailPrecondition;
  }
  // Return 'failure' if 'srcForOp' and 'dstForOp' are not in the same block.
  auto *block = srcForOp->getBlock();
  if (block != dstForOp->getBlock()) {
    LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests in different blocks\n");
    return FusionResult::FailPrecondition;
  }

  // Return 'failure' if no valid insertion point for the fused loop nest
  // exists in 'block' which would preserve dependences.
  if (!getFusedLoopNestInsertionPoint(srcForOp, dstForOp)) {
    LLVM_DEBUG(llvm::dbgs() << "Fusion would violate dependences in block\n");
    return FusionResult::FailBlockDependence;
  }

  // Check if 'srcForOp' precedes 'dstForOp' in 'block'.
  bool isSrcForOpBeforeDstForOp = srcForOp->isBeforeInBlock(dstForOp);
  // 'forOpA' executes before 'forOpB' in 'block'.
  auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp;
  auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp;

  // Gather all loads and stores from 'forOpA', which precedes 'forOpB' in
  // 'block'.
  SmallVector<Operation *, 4> opsA;
  if (!gatherLoadsAndStores(forOpA, opsA)) {
    LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n");
    return FusionResult::FailPrecondition;
  }

  // Gather all loads and stores from 'forOpB', which follows 'forOpA' in
  // 'block'.
  SmallVector<Operation *, 4> opsB;
  if (!gatherLoadsAndStores(forOpB, opsB)) {
    LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n");
    return FusionResult::FailPrecondition;
  }

  // Return 'failure' if fusing loops at depth 'dstLoopDepth' wouldn't preserve
  // loop dependences.
  // TODO: Enable this check for sibling and more generic loop fusion
  // strategies.
  if (fusionStrategy.getStrategy() == FusionStrategy::ProducerConsumer) {
    // TODO: 'getMaxLoopDepth' does not support forward slice fusion.
    assert(isSrcForOpBeforeDstForOp && "Unexpected forward slice fusion");
    if (getMaxLoopDepth(opsA, opsB) < dstLoopDepth) {
      LLVM_DEBUG(llvm::dbgs() << "Fusion would violate loop dependences\n");
      return FusionResult::FailFusionDependence;
    }
  }

  // Calculate the number of common loops surrounding 'srcForOp' and
  // 'dstForOp'.
  unsigned numCommonLoops =
      affine::getNumCommonSurroundingLoops(*srcForOp, *dstForOp);

  // Filter out ops in 'opsA' to compute the slice union based on the
  // assumptions made by the fusion strategy.
  SmallVector<Operation *, 4> strategyOpsA;
  switch (fusionStrategy.getStrategy()) {
  case FusionStrategy::Generic:
    // Generic fusion. Take into account all the memory operations to compute
    // the slice union.
    strategyOpsA.append(opsA.begin(), opsA.end());
    break;
  case FusionStrategy::ProducerConsumer:
    // Producer-consumer fusion (AffineLoopFusion pass) only takes into
    // account stores in 'srcForOp' to compute the slice union.
    for (Operation *op : opsA) {
      if (isa<AffineWriteOpInterface>(op))
        strategyOpsA.push_back(op);
    }
    break;
  case FusionStrategy::Sibling:
    // Sibling fusion (AffineLoopFusion pass) only takes into account the
    // loads from 'memref' in 'srcForOp' to compute the slice union.
    for (Operation *op : opsA) {
      auto load = dyn_cast<AffineReadOpInterface>(op);
      if (load && load.getMemRef() == fusionStrategy.getSiblingFusionMemRef())
        strategyOpsA.push_back(op);
    }
    break;
  }

  // Compute the union of the computation slices computed between all pairs of
  // ops from 'forOpA' and 'forOpB'.
  SliceComputationResult sliceComputationResult = affine::computeSliceUnion(
      strategyOpsA, opsB, dstLoopDepth, numCommonLoops,
      isSrcForOpBeforeDstForOp, srcSlice);
  if (sliceComputationResult.value == SliceComputationResult::GenericFailure) {
    LLVM_DEBUG(llvm::dbgs() << "computeSliceUnion failed\n");
    return FusionResult::FailPrecondition;
  }
  if (sliceComputationResult.value ==
      SliceComputationResult::IncorrectSliceFailure) {
    LLVM_DEBUG(llvm::dbgs() << "Incorrect slice computation\n");
    return FusionResult::FailIncorrectSlice;
  }

  return FusionResult::Success;
}
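
// For illustration, a minimal sketch of how a caller might drive this
// utility (hypothetical 'producer'/'consumer' loops; error handling elided):
//
//   ComputationSliceState slice;
//   FusionResult result =
//       canFuseLoops(producer, consumer, /*dstLoopDepth=*/1, &slice,
//                    FusionStrategy::ProducerConsumer);
//   if (result.value == FusionResult::Success)
//     fuseLoops(producer, consumer, slice);
//
// The legality check and the actual rewrite are deliberately split so that a
// cost model (e.g. getFusionComputeCost below) can inspect 'slice' before
// committing to the transformation.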

/// Patch the loop body of a forOp that is a single iteration reduction loop
/// into its containing block.
static LogicalResult promoteSingleIterReductionLoop(AffineForOp forOp,
                                                    bool siblingFusionUser) {
  // Check if the reduction loop is a single iteration loop.
  std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
  if (!tripCount || *tripCount != 1)
    return failure();
  auto *parentOp = forOp->getParentOp();
  if (!isa<AffineForOp>(parentOp))
    return failure();
  SmallVector<Value> newOperands;
  llvm::append_range(newOperands,
                     forOp.getBody()->getTerminator()->getOperands());
  IRRewriter rewriter(parentOp->getContext());
  int64_t parentOpNumResults = parentOp->getNumResults();
  // Replace the parent loop and add iter operands and results from the
  // `forOp`.
  AffineForOp parentForOp = forOp->getParentOfType<AffineForOp>();
  AffineForOp newLoop =
      cast<AffineForOp>(*parentForOp.replaceWithAdditionalYields(
          rewriter, forOp.getInits(), /*replaceInitOperandUsesInLoop=*/false,
          [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
            return newOperands;
          }));

  // For sibling-fusion users, collect operations that use the results of the
  // `forOp` outside the new parent loop that has absorbed all its iter args
  // and operands. These operations will be moved later after the results
  // have been replaced.
  SetVector<Operation *> forwardSlice;
  if (siblingFusionUser) {
    for (unsigned i = 0, e = forOp.getNumResults(); i != e; ++i) {
      SetVector<Operation *> tmpForwardSlice;
      getForwardSlice(forOp.getResult(i), &tmpForwardSlice);
      forwardSlice.set_union(tmpForwardSlice);
    }
  }
  // Update the results of the `forOp` in the new loop.
  for (unsigned i = 0, e = forOp.getNumResults(); i != e; ++i) {
    forOp.getResult(i).replaceAllUsesWith(
        newLoop.getResult(i + parentOpNumResults));
  }
  // For sibling-fusion users, move operations that use the results of the
  // `forOp` outside the new parent loop.
  if (siblingFusionUser) {
    topologicalSort(forwardSlice);
    for (Operation *op : llvm::reverse(forwardSlice))
      op->moveAfter(newLoop);
  }
  // Replace the induction variable.
  auto iv = forOp.getInductionVar();
  iv.replaceAllUsesWith(newLoop.getInductionVar());
  // Replace the iter args.
  auto forOpIterArgs = forOp.getRegionIterArgs();
  for (auto it : llvm::zip(forOpIterArgs, newLoop.getRegionIterArgs().take_back(
                                              forOpIterArgs.size()))) {
    std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
  }
  // Move the loop body operations, except for its terminator, to the loop's
  // containing block.
  forOp.getBody()->back().erase();
  auto *parentBlock = forOp->getBlock();
  parentBlock->getOperations().splice(Block::iterator(forOp),
                                      forOp.getBody()->getOperations());
  forOp.erase();
  return success();
}
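
// For illustration, a sketch of the rewrite above (simplified IR, not a
// verbatim test case). Given a single-iteration reduction loop nested in a
// parent affine.for over %i:
//
//   %red = affine.for %j = 0 to 1 iter_args(%acc = %cst) -> (f32) {
//     %v = affine.load %A[%i] : memref<128xf32>
//     %s = arith.addf %acc, %v : f32
//     affine.yield %s : f32
//   }
//
// the parent loop is rebuilt with an extra iter arg initialized to %cst and
// an extra yielded value %s, uses of %red are redirected to the new parent
// loop result, and the loop body (the load and addf) is spliced into the
// parent block in place of the %j loop.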

/// Fuses 'srcForOp' into 'dstForOp' with destination loop block insertion
/// point and source slice loop bounds specified in 'srcSlice'.
void mlir::affine::fuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
                             const ComputationSliceState &srcSlice,
                             bool isInnermostSiblingInsertion) {
  // Clone 'srcForOp' into 'dstForOp' at 'srcSlice.insertPoint'.
  OpBuilder b(srcSlice.insertPoint->getBlock(), srcSlice.insertPoint);
  IRMapping mapper;
  b.clone(*srcForOp, mapper);

  // Update the cloned loop nest's upper and lower bounds from the computed
  // 'srcSlice'.
  SmallVector<AffineForOp, 4> sliceLoops;
  for (unsigned i = 0, e = srcSlice.ivs.size(); i < e; ++i) {
    auto loopIV = mapper.lookupOrNull(srcSlice.ivs[i]);
    if (!loopIV)
      continue;
    auto forOp = getForInductionVarOwner(loopIV);
    sliceLoops.push_back(forOp);
    if (AffineMap lbMap = srcSlice.lbs[i]) {
      auto lbOperands = srcSlice.lbOperands[i];
      canonicalizeMapAndOperands(&lbMap, &lbOperands);
      forOp.setLowerBound(lbOperands, lbMap);
    }
    if (AffineMap ubMap = srcSlice.ubs[i]) {
      auto ubOperands = srcSlice.ubOperands[i];
      canonicalizeMapAndOperands(&ubMap, &ubOperands);
      forOp.setUpperBound(ubOperands, ubMap);
    }
  }

  llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
  auto srcIsUnitSlice = [&]() {
    return (buildSliceTripCountMap(srcSlice, &sliceTripCountMap) &&
            (getSliceIterationCount(sliceTripCountMap) == 1));
  };
  // Fix up and, if possible, eliminate single-iteration loops.
  for (AffineForOp forOp : sliceLoops) {
    if (isLoopParallelAndContainsReduction(forOp) &&
        isInnermostSiblingInsertion && srcIsUnitSlice())
      // Patch reduction loops - only ones that are sibling-fused with the
      // destination loop - into the parent loop.
      (void)promoteSingleIterReductionLoop(forOp, true);
    else
      // Promote any single iteration slice loops.
      (void)promoteIfSingleIteration(forOp);
  }
}
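
// For illustration, a sketch of what the bound update above produces when a
// producer loop writing %T[0..99] is fused at depth 1 into a consumer loop
// over %i (simplified IR, names invented for the example):
//
//   affine.for %i = 0 to 100 {
//     // Cloned source slice: bounds rewritten from [0, 100) to a
//     // single-iteration range that depends on %i.
//     affine.for %j = affine_map<(d0) -> (d0)>(%i)
//         to affine_map<(d0) -> (d0 + 1)>(%i) {
//       ... compute %T[%j] ...
//     }
//     ... consume %T[%i] ...
//   }
//
// promoteIfSingleIteration then folds the single-iteration %j loop away.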

/// Collect loop nest statistics (e.g., loop trip count and operation count)
/// in 'stats' for the loop nest rooted at 'forOp'. Returns true on success,
/// returns false otherwise.
bool mlir::affine::getLoopNestStats(AffineForOp forOpRoot,
                                    LoopNestStats *stats) {
  auto walkResult = forOpRoot.walk([&](AffineForOp forOp) {
    auto *childForOp = forOp.getOperation();
    auto *parentForOp = forOp->getParentOp();
    if (forOp != forOpRoot) {
      if (!isa<AffineForOp>(parentForOp)) {
        LLVM_DEBUG(llvm::dbgs() << "Expected parent AffineForOp\n");
        return WalkResult::interrupt();
      }
      // Add mapping to 'forOp' from its parent AffineForOp.
      stats->loopMap[parentForOp].push_back(forOp);
    }

    // Record the number of operations in the body of 'forOp', excluding
    // nested 'affine.for' and 'affine.if' ops.
    unsigned count = 0;
    stats->opCountMap[childForOp] = 0;
    for (auto &op : *forOp.getBody()) {
      if (!isa<AffineForOp, AffineIfOp>(op))
        ++count;
    }
    stats->opCountMap[childForOp] = count;

    // Record trip count for 'forOp'. Set flag if trip count is not
    // constant.
    std::optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
    if (!maybeConstTripCount) {
      // Currently only constant trip count loop nests are supported.
      LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count unsupported\n");
      return WalkResult::interrupt();
    }

    stats->tripCountMap[childForOp] = *maybeConstTripCount;
    return WalkResult::advance();
  });
  return !walkResult.wasInterrupted();
}

// Computes the total cost of the loop nest rooted at 'forOp'.
// Currently, the total cost is computed by counting the total operation
// instance count (i.e., total number of operations in the loop body * loop
// trip count) for the entire loop nest.
// If 'tripCountOverrideMap' is non-null, overrides the trip count for loops
// specified in the map when computing the total op instance count.
// NOTEs: 1) This is used to compute the cost of computation slices, which are
// sliced along the iteration dimension, and thus reduce the trip count.
// If 'computeCostMap' is non-null, the total op count for forOps specified
// in the map is increased (not overridden) by adding the op count from the
// map to the existing op count for the for loop. This is done before
// multiplying by the loop's trip count, and is used to model the cost of
// inserting a sliced loop nest of known cost into the loop's body.
// 2) This is also used to compute the cost of fusing a slice of some loop
// nest within another loop.
static int64_t getComputeCostHelper(
    Operation *forOp, LoopNestStats &stats,
    llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountOverrideMap,
    DenseMap<Operation *, int64_t> *computeCostMap) {
  // 'opCount' is the total number of operations in one iteration of the
  // 'forOp' body, minus the terminator op which is a no-op.
  int64_t opCount = stats.opCountMap[forOp] - 1;
  if (stats.loopMap.count(forOp) > 0) {
    for (auto childForOp : stats.loopMap[forOp]) {
      opCount += getComputeCostHelper(childForOp, stats, tripCountOverrideMap,
                                      computeCostMap);
    }
  }
  // Add in additional op instances from slice (if specified in map).
  if (computeCostMap) {
    auto it = computeCostMap->find(forOp);
    if (it != computeCostMap->end()) {
      opCount += it->second;
    }
  }
  // Override trip count (if specified in map).
  int64_t tripCount = stats.tripCountMap[forOp];
  if (tripCountOverrideMap) {
    auto it = tripCountOverrideMap->find(forOp);
    if (it != tripCountOverrideMap->end()) {
      tripCount = it->second;
    }
  }
  // Returns the total number of dynamic instances of operations in the loop
  // body.
  return tripCount * opCount;
}
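
// Worked example for the recursion above (made-up trip counts): consider a
// two-deep nest where the outer loop (trip count 100) contains only the
// inner loop (trip count 50), whose body holds 4 compute ops plus the
// terminator. Then:
//
//   inner: opCount = 5 - 1 = 4,       cost = 4 * 50    = 200
//   outer: opCount = 1 - 1 + 200,     cost = 200 * 100 = 20000
//
// i.e., 20000 dynamic op instances for the whole nest. Overriding the inner
// trip count to 1 (as for a unit slice) would instead give 4 * 1 = 4 for the
// inner loop and 4 * 100 = 400 overall.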

/// Computes the total cost of the loop nest rooted at 'forOp' using 'stats'.
/// Currently, the total cost is computed by counting the total operation
/// instance count (i.e., total number of operations in the loop body * loop
/// trip count) for the entire loop nest.
int64_t mlir::affine::getComputeCost(AffineForOp forOp, LoopNestStats &stats) {
  return getComputeCostHelper(forOp, stats,
                              /*tripCountOverrideMap=*/nullptr,
                              /*computeCostMap=*/nullptr);
}

/// Computes and returns in 'computeCost' the total compute cost of fusing the
/// 'slice' of the loop nest rooted at 'srcForOp' into 'dstForOp'. Currently,
/// the total cost is computed by counting the total operation instance count
/// (i.e., total number of operations in the loop body * loop trip count) for
/// the entire loop nest.
bool mlir::affine::getFusionComputeCost(AffineForOp srcForOp,
                                        LoopNestStats &srcStats,
                                        AffineForOp dstForOp,
                                        LoopNestStats &dstStats,
                                        const ComputationSliceState &slice,
                                        int64_t *computeCost) {
  llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
  DenseMap<Operation *, int64_t> computeCostMap;

  // Build trip count map for computation slice.
  if (!buildSliceTripCountMap(slice, &sliceTripCountMap))
    return false;
  // Check whether store-to-load forwarding will happen.
  int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap);
  assert(sliceIterationCount > 0);
  bool storeLoadFwdGuaranteed = (sliceIterationCount == 1);
  auto *insertPointParent = slice.insertPoint->getParentOp();

  // The stores to and loads from the forwarded memrefs will disappear.
  if (storeLoadFwdGuaranteed) {
    // Subtract from the operation count the loads/stores we expect load/store
    // forwarding to remove.
    unsigned storeCount = 0;
    llvm::SmallDenseSet<Value, 4> storeMemrefs;
    srcForOp.walk([&](AffineWriteOpInterface storeOp) {
      storeMemrefs.insert(storeOp.getMemRef());
      ++storeCount;
    });
    // Subtract out any store ops in single-iteration src slice loop nest.
    if (storeCount > 0)
      computeCostMap[insertPointParent] = -storeCount;
    // Subtract out any load users of 'storeMemrefs' nested below
    // 'insertPointParent'.
    for (Value memref : storeMemrefs) {
      for (Operation *user : memref.getUsers()) {
        if (!isa<AffineReadOpInterface>(user))
          continue;
        SmallVector<AffineForOp, 4> loops;
        // Check if any loop in loop nest surrounding 'user' is
        // 'insertPointParent'.
        getAffineForIVs(*user, &loops);
        if (llvm::is_contained(loops, cast<AffineForOp>(insertPointParent))) {
          if (auto forOp = dyn_cast_or_null<AffineForOp>(user->getParentOp())) {
            if (computeCostMap.count(forOp) == 0)
              computeCostMap[forOp] = 0;
            computeCostMap[forOp] -= 1;
          }
        }
      }
    }
  }

  // Compute op instance count for the src loop nest with iteration slicing.
  int64_t sliceComputeCost = getComputeCostHelper(
      srcForOp, srcStats, &sliceTripCountMap, &computeCostMap);

  // Compute cost of fusion for this depth.
  computeCostMap[insertPointParent] = sliceComputeCost;

  *computeCost =
      getComputeCostHelper(dstForOp, dstStats,
                           /*tripCountOverrideMap=*/nullptr, &computeCostMap);
  return true;
}
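
// For illustration, a minimal sketch of driving the cost model (hypothetical
// caller; 'threshold' and the loop variables are assumptions, and error
// handling is elided):
//
//   LoopNestStats srcStats, dstStats;
//   if (getLoopNestStats(srcForOp, &srcStats) &&
//       getLoopNestStats(dstForOp, &dstStats)) {
//     int64_t fusedCost;
//     if (getFusionComputeCost(srcForOp, srcStats, dstForOp, dstStats, slice,
//                              &fusedCost)) {
//       int64_t originalCost = getComputeCost(srcForOp, srcStats) +
//                              getComputeCost(dstForOp, dstStats);
//       // Fuse only if the redundant recomputation introduced by the slice
//       // stays within budget, e.g. fusedCost <= originalCost * (1 + threshold).
//     }
//   }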

/// Returns in 'producerConsumerMemrefs' the memrefs involved in a
/// producer-consumer dependence between write ops in 'srcOps' and read ops in
/// 'dstOps'.
void mlir::affine::gatherProducerConsumerMemrefs(
    ArrayRef<Operation *> srcOps, ArrayRef<Operation *> dstOps,
    DenseSet<Value> &producerConsumerMemrefs) {
  // Gather memrefs from stores in 'srcOps'.
  DenseSet<Value> srcStoreMemRefs;
  for (Operation *op : srcOps)
    if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op))
      srcStoreMemRefs.insert(storeOp.getMemRef());

  // Compute the intersection between memrefs from stores in 'srcOps' and
  // memrefs from loads in 'dstOps'.
  for (Operation *op : dstOps)
    if (auto loadOp = dyn_cast<AffineReadOpInterface>(op))
      if (srcStoreMemRefs.count(loadOp.getMemRef()) > 0)
        producerConsumerMemrefs.insert(loadOp.getMemRef());
}