//===- LoopFusionUtils.cpp ---- Utilities for loop fusion ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements loop fusion transformation utility functions.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/LoopFusionUtils.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>

#define DEBUG_TYPE "loop-fusion-utils"

using namespace mlir;
using namespace mlir::affine;

// Gathers all load and store memref accesses in 'opA' into 'values', where
// 'values[memref] == true' for each store operation.
static void getLoadAndStoreMemRefAccesses(Operation *opA,
                                          DenseMap<Value, bool> &values) {
  opA->walk([&](Operation *op) {
    if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) {
      if (values.count(loadOp.getMemRef()) == 0)
        values[loadOp.getMemRef()] = false;
    } else if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) {
      values[storeOp.getMemRef()] = true;
    }
  });
}

/// Returns true if 'op' is a load or store operation which accesses a memref
/// recorded in 'values', where at least one of the accesses to that memref is
/// a store. Returns false otherwise.
static bool isDependentLoadOrStoreOp(Operation *op,
                                     DenseMap<Value, bool> &values) {
  if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) {
    return values.count(loadOp.getMemRef()) > 0 && values[loadOp.getMemRef()];
  }
  if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) {
    return values.count(storeOp.getMemRef()) > 0;
  }
  return false;
}
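
// Illustrative example (hypothetical IR, not from this file): given a loop
// nest 'opA' such as
//
//   affine.for %i = 0 to 16 {
//     %v = affine.load %A[%i] : memref<16xf32>
//     affine.store %v, %B[%i] : memref<16xf32>
//   }
//
// 'getLoadAndStoreMemRefAccesses' records values[%A] == false (only loaded)
// and values[%B] == true (stored to). A later load from %B, or a store to
// either %A or %B, is then flagged by 'isDependentLoadOrStoreOp'; a later
// load from %A is not, since a read-after-read is not a dependence.
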
// Returns the first operation in range ('opA', 'opB') which has a data
// dependence on 'opA'. Returns 'nullptr' if no dependence exists.
static Operation *getFirstDependentOpInRange(Operation *opA, Operation *opB) {
  // Record memref values from all loads/stores in the loop nest rooted at
  // 'opA'. Map from memref value to bool which is true if store, false
  // otherwise.
  DenseMap<Value, bool> values;
  getLoadAndStoreMemRefAccesses(opA, values);

  // For each 'opX' in block in range ('opA', 'opB'), check if there is a data
  // dependence from 'opA' to 'opX' ('opA' and 'opX' access the same memref
  // and at least one of the accesses is a store).
  Operation *firstDepOp = nullptr;
  for (Block::iterator it = std::next(Block::iterator(opA));
       it != Block::iterator(opB); ++it) {
    Operation *opX = &(*it);
    opX->walk([&](Operation *op) {
      if (!firstDepOp && isDependentLoadOrStoreOp(op, values))
        firstDepOp = opX;
    });
    if (firstDepOp)
      break;
  }
  return firstDepOp;
}

// Returns the last operation 'opX' in range ('opA', 'opB'), for which there
// exists a data dependence from 'opX' to 'opB'.
// Returns 'nullptr' if no dependence exists.
static Operation *getLastDependentOpInRange(Operation *opA, Operation *opB) {
  // Record memref values from all loads/stores in the loop nest rooted at
  // 'opB'. Map from memref value to bool which is true if store, false
  // otherwise.
  DenseMap<Value, bool> values;
  getLoadAndStoreMemRefAccesses(opB, values);

  // For each 'opX' in block in range ('opA', 'opB') in reverse order,
  // check if there is a data dependence from 'opX' to 'opB':
  // *) 'opX' and 'opB' access the same memref and at least one of the accesses
  //    is a store.
  // *) 'opX' produces an SSA Value which is used by 'opB'.
  Operation *lastDepOp = nullptr;
  for (Block::reverse_iterator it = std::next(Block::reverse_iterator(opB));
       it != Block::reverse_iterator(opA); ++it) {
    Operation *opX = &(*it);
    opX->walk([&](Operation *op) {
      if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) {
        if (isDependentLoadOrStoreOp(op, values)) {
          lastDepOp = opX;
          return WalkResult::interrupt();
        }
        return WalkResult::advance();
      }
      for (Value value : op->getResults()) {
        for (Operation *user : value.getUsers()) {
          SmallVector<AffineForOp, 4> loops;
          // Check if any loop in loop nest surrounding 'user' is 'opB'.
          getAffineForIVs(*user, &loops);
          if (llvm::is_contained(loops, cast<AffineForOp>(opB))) {
            lastDepOp = opX;
            return WalkResult::interrupt();
          }
        }
      }
      return WalkResult::advance();
    });
    if (lastDepOp)
      break;
  }
  return lastDepOp;
}

// Computes and returns an insertion point operation, before which the
// fused <srcForOp, dstForOp> loop nest can be inserted while preserving
// dependences. Returns nullptr if no such insertion point is found.
static Operation *getFusedLoopNestInsertionPoint(AffineForOp srcForOp,
                                                 AffineForOp dstForOp) {
  bool isSrcForOpBeforeDstForOp = srcForOp->isBeforeInBlock(dstForOp);
  auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp;
  auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp;

  Operation *firstDepOpA = getFirstDependentOpInRange(forOpA, forOpB);
  Operation *lastDepOpB = getLastDependentOpInRange(forOpA, forOpB);
  // Block:
  //      ...
  //  |-- opA
  //  |   ...
  //  |   lastDepOpB --|
  //  |   ...          |
  //  |-> firstDepOpA  |
  //      ...          |
  //      opB <---------
  //
  // Valid insertion point range: (lastDepOpB, firstDepOpA)
  //
  if (firstDepOpA != nullptr) {
    if (lastDepOpB != nullptr) {
      if (firstDepOpA->isBeforeInBlock(lastDepOpB) || firstDepOpA == lastDepOpB)
        // No valid insertion point exists which preserves dependences.
        return nullptr;
    }
    // Return insertion point in valid range closest to 'opB'.
    // TODO: Consider other insertion points in valid range.
    return firstDepOpA;
  }
  // No dependences from 'opA' to operation in range ('opA', 'opB'), return
  // 'opB' insertion point.
  return forOpB;
}
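
// Illustrative example (hypothetical IR; %X, %Y, and %c0 are defined above
// the block shown): with
//
//   affine.for %i = 0 to 8 { affine.store %v, %X[%i] : memref<8xf32> } // opA
//   affine.store %w, %Y[%c0] : memref<8xf32>  // lastDepOpB: 'opB' reads %Y
//   %u = affine.load %X[%c0] : memref<8xf32>  // firstDepOpA: reads opA's %X
//   affine.for %j = 0 to 8 { %t = affine.load %Y[%j] : memref<8xf32> } // opB
//
// the valid insertion point range is (lastDepOpB, firstDepOpA), and the
// utility returns 'firstDepOpA': inserting the fused nest immediately before
// the load from %X keeps it after the store to %Y and before any use of %X.
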
// Gathers all load and store ops in the loop nest rooted at 'forOp' into
// 'loadAndStoreOps'. Returns false if the nest contains an 'affine.if' op,
// true otherwise.
static bool
gatherLoadsAndStores(AffineForOp forOp,
                     SmallVectorImpl<Operation *> &loadAndStoreOps) {
  bool hasIfOp = false;
  forOp.walk([&](Operation *op) {
    if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op))
      loadAndStoreOps.push_back(op);
    else if (isa<AffineIfOp>(op))
      hasIfOp = true;
  });
  return !hasIfOp;
}

/// Returns the maximum loop depth at which we could fuse producer loop
/// 'srcForOp' into consumer loop 'dstForOp' without violating data
/// dependences. 'srcOps' and 'dstOps' hold the memory operations of the
/// respective loop nests.
// TODO: Generalize this check for sibling and more generic fusion scenarios.
// TODO: Support forward slice fusion.
static unsigned getMaxLoopDepth(ArrayRef<Operation *> srcOps,
                                ArrayRef<Operation *> dstOps) {
  if (dstOps.empty())
    // Expected at least one memory operation.
    // TODO: Revisit this case with a specific example.
    return 0;

  // Filter out ops in 'dstOps' that do not use the producer-consumer memref so
  // that they are not considered for analysis.
  DenseSet<Value> producerConsumerMemrefs;
  gatherProducerConsumerMemrefs(srcOps, dstOps, producerConsumerMemrefs);
  SmallVector<Operation *, 4> targetDstOps;
  for (Operation *dstOp : dstOps) {
    auto loadOp = dyn_cast<AffineReadOpInterface>(dstOp);
    Value memref = loadOp ? loadOp.getMemRef()
                          : cast<AffineWriteOpInterface>(dstOp).getMemRef();
    if (producerConsumerMemrefs.count(memref) > 0)
      targetDstOps.push_back(dstOp);
  }

  assert(!targetDstOps.empty() &&
         "No dependences between 'srcForOp' and 'dstForOp'?");

  // Compute the innermost common loop depth for loads and stores.
  unsigned loopDepth = getInnermostCommonLoopDepth(targetDstOps);

  // Return common loop depth for loads if there are no store ops.
  if (all_of(targetDstOps,
             [&](Operation *op) { return isa<AffineReadOpInterface>(op); }))
    return loopDepth;

  // Check dependences on all pairs of ops in 'targetDstOps' and store the
  // minimum loop depth at which a dependence is satisfied.
  for (unsigned i = 0, e = targetDstOps.size(); i < e; ++i) {
    auto *srcOpInst = targetDstOps[i];
    MemRefAccess srcAccess(srcOpInst);
    for (unsigned j = 0; j < e; ++j) {
      auto *dstOpInst = targetDstOps[j];
      MemRefAccess dstAccess(dstOpInst);

      unsigned numCommonLoops =
          getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst);
      for (unsigned d = 1; d <= numCommonLoops + 1; ++d) {
        // TODO: Cache dependence analysis results, check cache here.
        DependenceResult result =
            checkMemrefAccessDependence(srcAccess, dstAccess, d);
        if (hasDependence(result)) {
          // Store minimum loop depth and break because we want the min 'd' at
          // which there is a dependence.
          loopDepth = std::min(loopDepth, d - 1);
          break;
        }
      }
    }
  }

  return loopDepth;
}
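
// Worked example (illustrative): suppose two accesses in 'targetDstOps' share
// two surrounding loops, so 'getInnermostCommonLoopDepth' returns 2. If
// 'checkMemrefAccessDependence' reports no dependence at depth d = 1 but a
// dependence at depth d = 2, then loopDepth = min(2, d - 1) = 1: the loops
// can only be fused at the outermost depth without violating the dependence.
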
// TODO: Prevent fusion of loop nests with side-effecting operations.
// TODO: This utility performs some computation that is the same for all the
// depths (e.g., getMaxLoopDepth). Implement a version of this utility that
// processes all the depths at once or only the legal maximal depth for
// maximal fusion.
FusionResult mlir::affine::canFuseLoops(AffineForOp srcForOp,
                                        AffineForOp dstForOp,
                                        unsigned dstLoopDepth,
                                        ComputationSliceState *srcSlice,
                                        FusionStrategy fusionStrategy) {
  // Return 'failure' if 'dstLoopDepth == 0'.
  if (dstLoopDepth == 0) {
    LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests at depth 0\n");
    return FusionResult::FailPrecondition;
  }
  // Return 'failure' if 'srcForOp' and 'dstForOp' are not in the same block.
  auto *block = srcForOp->getBlock();
  if (block != dstForOp->getBlock()) {
    LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests in different blocks\n");
    return FusionResult::FailPrecondition;
  }

  // Return 'failure' if no valid insertion point for the fused loop nest
  // exists in 'block' which would preserve dependences.
  if (!getFusedLoopNestInsertionPoint(srcForOp, dstForOp)) {
    LLVM_DEBUG(llvm::dbgs() << "Fusion would violate dependences in block\n");
    return FusionResult::FailBlockDependence;
  }

  // Check if 'srcForOp' precedes 'dstForOp' in 'block'.
  bool isSrcForOpBeforeDstForOp = srcForOp->isBeforeInBlock(dstForOp);
  // 'forOpA' executes before 'forOpB' in 'block'.
  auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp;
  auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp;

  // Gather all loads and stores from 'forOpA', which precedes 'forOpB' in
  // 'block'.
  SmallVector<Operation *, 4> opsA;
  if (!gatherLoadsAndStores(forOpA, opsA)) {
    LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n");
    return FusionResult::FailPrecondition;
  }

  // Gather all loads and stores from 'forOpB', which succeeds 'forOpA' in
  // 'block'.
  SmallVector<Operation *, 4> opsB;
  if (!gatherLoadsAndStores(forOpB, opsB)) {
    LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n");
    return FusionResult::FailPrecondition;
  }

  // Return 'failure' if fusing loops at depth 'dstLoopDepth' wouldn't preserve
  // loop dependences.
  // TODO: Enable this check for sibling and more generic loop fusion
  // strategies.
  if (fusionStrategy.getStrategy() == FusionStrategy::ProducerConsumer) {
    // TODO: 'getMaxLoopDepth' does not support forward slice fusion.
    assert(isSrcForOpBeforeDstForOp && "Unexpected forward slice fusion");
    if (getMaxLoopDepth(opsA, opsB) < dstLoopDepth) {
      LLVM_DEBUG(llvm::dbgs() << "Fusion would violate loop dependences\n");
      return FusionResult::FailFusionDependence;
    }
  }

  // Calculate the number of common loops surrounding 'srcForOp' and
  // 'dstForOp'.
  unsigned numCommonLoops =
      affine::getNumCommonSurroundingLoops(*srcForOp, *dstForOp);

  // Filter out ops in 'opsA' to compute the slice union based on the
  // assumptions made by the fusion strategy.
  SmallVector<Operation *, 4> strategyOpsA;
  switch (fusionStrategy.getStrategy()) {
  case FusionStrategy::Generic:
    // Generic fusion. Take into account all the memory operations to compute
    // the slice union.
    strategyOpsA.append(opsA.begin(), opsA.end());
    break;
  case FusionStrategy::ProducerConsumer:
    // Producer-consumer fusion (AffineLoopFusion pass) only takes into
    // account stores in 'srcForOp' to compute the slice union.
    for (Operation *op : opsA) {
      if (isa<AffineWriteOpInterface>(op))
        strategyOpsA.push_back(op);
    }
    break;
  case FusionStrategy::Sibling:
    // Sibling fusion (AffineLoopFusion pass) only takes into account the
    // loads from the sibling fusion memref in 'srcForOp' to compute the
    // slice union.
    for (Operation *op : opsA) {
      auto load = dyn_cast<AffineReadOpInterface>(op);
      if (load && load.getMemRef() == fusionStrategy.getSiblingFusionMemRef())
        strategyOpsA.push_back(op);
    }
    break;
  }

  // Compute the union of the computation slices computed between all pairs of
  // ops from 'forOpA' and 'forOpB'.
  SliceComputationResult sliceComputationResult = affine::computeSliceUnion(
      strategyOpsA, opsB, dstLoopDepth, numCommonLoops,
      isSrcForOpBeforeDstForOp, srcSlice);
  if (sliceComputationResult.value == SliceComputationResult::GenericFailure) {
    LLVM_DEBUG(llvm::dbgs() << "computeSliceUnion failed\n");
    return FusionResult::FailPrecondition;
  }
  if (sliceComputationResult.value ==
      SliceComputationResult::IncorrectSliceFailure) {
    LLVM_DEBUG(llvm::dbgs() << "Incorrect slice computation\n");
    return FusionResult::FailIncorrectSlice;
  }

  return FusionResult::Success;
}
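
// Typical usage of the analysis above (a minimal sketch; 'srcLoop' and
// 'dstLoop' are placeholder AffineForOps, and the default generic fusion
// strategy is assumed):
//
//   ComputationSliceState slice;
//   FusionResult result =
//       canFuseLoops(srcLoop, dstLoop, /*dstLoopDepth=*/1, &slice);
//   if (result.value == FusionResult::Success)
//     fuseLoops(srcLoop, dstLoop, slice);
//
// 'canFuseLoops' only analyzes and fills in 'slice'; the IR is mutated by
// 'fuseLoops' below.
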
/// Patch the loop body of a forOp that is a single iteration reduction loop
/// into its containing block.
static LogicalResult promoteSingleIterReductionLoop(AffineForOp forOp,
                                                    bool siblingFusionUser) {
  // Check if the reduction loop is a single iteration loop.
  std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
  if (!tripCount || *tripCount != 1)
    return failure();
  auto iterOperands = forOp.getIterOperands();
  auto *parentOp = forOp->getParentOp();
  if (!isa<AffineForOp>(parentOp))
    return failure();
  auto newOperands = forOp.getBody()->getTerminator()->getOperands();
  OpBuilder b(parentOp);
  // Replace the parent loop, adding the iter operands and results from the
  // `forOp`.
  AffineForOp parentForOp = forOp->getParentOfType<AffineForOp>();
  AffineForOp newLoop = replaceForOpWithNewYields(
      b, parentForOp, iterOperands, newOperands, forOp.getRegionIterArgs());

  // For sibling-fusion users, collect operations that use the results of the
  // `forOp` outside the new parent loop that has absorbed all its iter args
  // and operands. These operations will be moved later after the results
  // have been replaced.
  SetVector<Operation *> forwardSlice;
  if (siblingFusionUser) {
    for (unsigned i = 0, e = forOp.getNumResults(); i != e; ++i) {
      SetVector<Operation *> tmpForwardSlice;
      getForwardSlice(forOp.getResult(i), &tmpForwardSlice);
      forwardSlice.set_union(tmpForwardSlice);
    }
  }
  // Update the results of the `forOp` in the new loop.
  for (unsigned i = 0, e = forOp.getNumResults(); i != e; ++i) {
    forOp.getResult(i).replaceAllUsesWith(
        newLoop.getResult(i + parentOp->getNumResults()));
  }
  // For sibling-fusion users, move operations that use the results of the
  // `forOp` outside the new parent loop.
  if (siblingFusionUser) {
    topologicalSort(forwardSlice);
    for (Operation *op : llvm::reverse(forwardSlice))
      op->moveAfter(newLoop);
  }
  // Replace the induction variable.
  auto iv = forOp.getInductionVar();
  iv.replaceAllUsesWith(newLoop.getInductionVar());
  // Replace the iter args.
  auto forOpIterArgs = forOp.getRegionIterArgs();
  for (auto it : llvm::zip(forOpIterArgs, newLoop.getRegionIterArgs().take_back(
                                              forOpIterArgs.size()))) {
    std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
  }
  // Move the loop body operations, except for its terminator, to the loop's
  // containing block.
  forOp.getBody()->back().erase();
  auto *parentBlock = forOp->getBlock();
  parentBlock->getOperations().splice(Block::iterator(forOp),
                                      forOp.getBody()->getOperations());
  forOp.erase();
  parentForOp.erase();
  return success();
}
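
// Schematic effect of the promotion above (illustrative): a single-iteration
// reduction loop such as
//
//   affine.for %i = 0 to 1 iter_args(%acc = %init) -> (f32) { ... }
//
// nested inside a parent affine.for is absorbed into it: the parent is
// rebuilt by 'replaceForOpWithNewYields' to carry the inner loop's iter args
// and results, the single iteration's body is spliced into the surrounding
// block, and both the inner loop and the old parent loop are erased.
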
/// Fuses 'srcForOp' into 'dstForOp' with destination loop block insertion
/// point and source slice loop bounds specified in 'srcSlice'.
void mlir::affine::fuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
                             const ComputationSliceState &srcSlice,
                             bool isInnermostSiblingInsertion) {
  // Clone 'srcForOp' into 'dstForOp' at 'srcSlice.insertPoint'.
  OpBuilder b(srcSlice.insertPoint->getBlock(), srcSlice.insertPoint);
  IRMapping mapper;
  b.clone(*srcForOp, mapper);

  // Update the cloned slice loops' upper and lower bounds from the computed
  // 'srcSlice'.
  SmallVector<AffineForOp, 4> sliceLoops;
  for (unsigned i = 0, e = srcSlice.ivs.size(); i < e; ++i) {
    auto loopIV = mapper.lookupOrNull(srcSlice.ivs[i]);
    if (!loopIV)
      continue;
    auto forOp = getForInductionVarOwner(loopIV);
    sliceLoops.push_back(forOp);
    if (AffineMap lbMap = srcSlice.lbs[i]) {
      auto lbOperands = srcSlice.lbOperands[i];
      canonicalizeMapAndOperands(&lbMap, &lbOperands);
      forOp.setLowerBound(lbOperands, lbMap);
    }
    if (AffineMap ubMap = srcSlice.ubs[i]) {
      auto ubOperands = srcSlice.ubOperands[i];
      canonicalizeMapAndOperands(&ubMap, &ubOperands);
      forOp.setUpperBound(ubOperands, ubMap);
    }
  }

  llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
  auto srcIsUnitSlice = [&]() {
    return (buildSliceTripCountMap(srcSlice, &sliceTripCountMap) &&
            (getSliceIterationCount(sliceTripCountMap) == 1));
  };
  // Fix up and, if possible, eliminate single iteration loops.
  for (AffineForOp forOp : sliceLoops) {
    if (isLoopParallelAndContainsReduction(forOp) &&
        isInnermostSiblingInsertion && srcIsUnitSlice())
      // Patch reduction loops - only ones that are sibling-fused with the
      // destination loop - into the parent loop.
      (void)promoteSingleIterReductionLoop(forOp, true);
    else
      // Promote any single iteration slice loops.
      (void)promoteIfSingleIteration(forOp);
  }
}
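
// Schematic example of producer-consumer fusion at 'dstLoopDepth' == 1
// (illustrative; slice bounds shown informally rather than as affine maps):
//
//   affine.for %i = 0 to 32 { affine.store %v, %A[%i] }  // srcForOp
//   affine.for %j = 0 to 32 { %u = affine.load %A[%j] }  // dstForOp
//
// becomes, after 'fuseLoops' with the slice computed by 'canFuseLoops':
//
//   affine.for %j = 0 to 32 {
//     affine.for %i = %j to %j + 1 { affine.store %v, %A[%i] }  // slice
//     %u = affine.load %A[%j]
//   }
//
// The single-iteration slice loop is then cleaned up by
// 'promoteIfSingleIteration' in the loop above.
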
/// Collect loop nest statistics (e.g., loop trip count and operation count)
/// in 'stats' for the loop nest rooted at 'forOpRoot'. Returns true on
/// success, false otherwise.
bool mlir::affine::getLoopNestStats(AffineForOp forOpRoot,
                                    LoopNestStats *stats) {
  auto walkResult = forOpRoot.walk([&](AffineForOp forOp) {
    auto *childForOp = forOp.getOperation();
    auto *parentForOp = forOp->getParentOp();
    if (forOp != forOpRoot) {
      if (!isa<AffineForOp>(parentForOp)) {
        LLVM_DEBUG(llvm::dbgs() << "Expected parent AffineForOp\n");
        return WalkResult::interrupt();
      }
      // Add mapping to 'forOp' from its parent AffineForOp.
      stats->loopMap[parentForOp].push_back(forOp);
    }

    // Record the number of operations in the body of 'forOp', excluding
    // nested loops and ifs.
    unsigned count = 0;
    stats->opCountMap[childForOp] = 0;
    for (auto &op : *forOp.getBody()) {
      if (!isa<AffineForOp, AffineIfOp>(op))
        ++count;
    }
    stats->opCountMap[childForOp] = count;

    // Record the trip count for 'forOp'; bail out if it is not constant.
    std::optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
    if (!maybeConstTripCount) {
      // Currently only constant trip count loop nests are supported.
      LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count unsupported\n");
      return WalkResult::interrupt();
    }

    stats->tripCountMap[childForOp] = *maybeConstTripCount;
    return WalkResult::advance();
  });
  return !walkResult.wasInterrupted();
}

// Computes the total cost of the loop nest rooted at 'forOp'.
// Currently, the total cost is computed by counting the total operation
// instance count (i.e. total number of operations in the loop body * loop
// trip count) for the entire loop nest.
// If 'tripCountOverrideMap' is non-null, overrides the trip count for loops
// specified in the map when computing the total op instance count. This is
// used to compute the cost of computation slices, which are sliced along the
// iteration dimension and thus reduce the trip count.
// If 'computeCostMap' is non-null, the total op count for forOps specified in
// the map is increased (not overridden) by adding the op count from the map
// to the existing op count for the for loop. This is done before multiplying
// by the loop's trip count, and is used to model the cost of inserting a
// sliced loop nest of known cost into the loop's body, i.e., the cost of
// fusing a slice of some loop nest within another loop.
static int64_t getComputeCostHelper(
    Operation *forOp, LoopNestStats &stats,
    llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountOverrideMap,
    DenseMap<Operation *, int64_t> *computeCostMap) {
  // 'opCount' is the total number of operations in one iteration of the
  // 'forOp' body, minus the terminator op, which is a no-op.
  int64_t opCount = stats.opCountMap[forOp] - 1;
  if (stats.loopMap.count(forOp) > 0) {
    for (auto childForOp : stats.loopMap[forOp]) {
      opCount += getComputeCostHelper(childForOp, stats, tripCountOverrideMap,
                                      computeCostMap);
    }
  }
  // Add in additional op instances from slice (if specified in map).
  if (computeCostMap != nullptr) {
    auto it = computeCostMap->find(forOp);
    if (it != computeCostMap->end()) {
      opCount += it->second;
    }
  }
  // Override trip count (if specified in map).
  int64_t tripCount = stats.tripCountMap[forOp];
  if (tripCountOverrideMap != nullptr) {
    auto it = tripCountOverrideMap->find(forOp);
    if (it != tripCountOverrideMap->end()) {
      tripCount = it->second;
    }
  }
  // Return the total number of dynamic instances of operations in the loop
  // body.
  return tripCount * opCount;
}
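
// Worked example (illustrative): for a nest with 2 non-loop ops in the outer
// body (trip count 100) and 3 non-loop ops in the inner body (trip count 50),
// terminators excluded, the helper returns 100 * (2 + 50 * 3) = 15200 dynamic
// op instances when no override or cost map entries apply.
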
/// Computes the total cost of the loop nest rooted at 'forOp' using 'stats'.
/// Currently, the total cost is computed by counting the total operation
/// instance count (i.e. total number of operations in the loop body * loop
/// trip count) for the entire loop nest.
int64_t mlir::affine::getComputeCost(AffineForOp forOp, LoopNestStats &stats) {
  return getComputeCostHelper(forOp, stats,
                              /*tripCountOverrideMap=*/nullptr,
                              /*computeCostMap=*/nullptr);
}

/// Computes and returns in 'computeCost' the total compute cost of fusing the
/// 'slice' of the loop nest rooted at 'srcForOp' into 'dstForOp'. Currently,
/// the total cost is computed by counting the total operation instance count
/// (i.e. total number of operations in the loop body * loop trip count) for
/// the entire loop nest.
bool mlir::affine::getFusionComputeCost(AffineForOp srcForOp,
                                        LoopNestStats &srcStats,
                                        AffineForOp dstForOp,
                                        LoopNestStats &dstStats,
                                        const ComputationSliceState &slice,
                                        int64_t *computeCost) {
  llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
  DenseMap<Operation *, int64_t> computeCostMap;

  // Build trip count map for the computation slice.
  if (!buildSliceTripCountMap(slice, &sliceTripCountMap))
    return false;
  // Check whether store-to-load forwarding will happen.
  int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap);
  assert(sliceIterationCount > 0);
  bool storeLoadFwdGuaranteed = (sliceIterationCount == 1);
  auto *insertPointParent = slice.insertPoint->getParentOp();

  // The stores and loads of the forwarded memrefs will disappear.
  // TODO: Add load coalescing to memref data flow opt pass.
  if (storeLoadFwdGuaranteed) {
    // Subtract from the operation count the loads/stores we expect
    // store-to-load forwarding to remove.
    unsigned storeCount = 0;
    llvm::SmallDenseSet<Value, 4> storeMemrefs;
    srcForOp.walk([&](Operation *op) {
      if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) {
        storeMemrefs.insert(storeOp.getMemRef());
        ++storeCount;
      }
    });
    // Subtract out any store ops in the single-iteration src slice loop nest.
    if (storeCount > 0)
      computeCostMap[insertPointParent] = -storeCount;
    // Subtract out any load users of 'storeMemrefs' nested below
    // 'insertPointParent'.
    for (Value memref : storeMemrefs) {
      for (auto *user : memref.getUsers()) {
        if (auto loadOp = dyn_cast<AffineReadOpInterface>(user)) {
          SmallVector<AffineForOp, 4> loops;
          // Check if any loop in the loop nest surrounding 'user' is
          // 'insertPointParent'.
          getAffineForIVs(*user, &loops);
          if (llvm::is_contained(loops,
                                 cast<AffineForOp>(insertPointParent))) {
            if (auto forOp =
                    dyn_cast_or_null<AffineForOp>(user->getParentOp())) {
              if (computeCostMap.count(forOp) == 0)
                computeCostMap[forOp] = 0;
              computeCostMap[forOp] -= 1;
            }
          }
        }
      }
    }
  }

  // Compute op instance count for the src loop nest with iteration slicing.
  int64_t sliceComputeCost = getComputeCostHelper(
      srcForOp, srcStats, &sliceTripCountMap, &computeCostMap);

  // Compute the cost of fusion for this depth.
  computeCostMap[insertPointParent] = sliceComputeCost;

  *computeCost =
      getComputeCostHelper(dstForOp, dstStats,
                           /*tripCountOverrideMap=*/nullptr, &computeCostMap);
  return true;
}

/// Returns in 'producerConsumerMemrefs' the memrefs involved in a
/// producer-consumer dependence between write ops in 'srcOps' and read ops in
/// 'dstOps'.
void mlir::affine::gatherProducerConsumerMemrefs(
    ArrayRef<Operation *> srcOps, ArrayRef<Operation *> dstOps,
    DenseSet<Value> &producerConsumerMemrefs) {
  // Gather memrefs from stores in 'srcOps'.
  DenseSet<Value> srcStoreMemRefs;
  for (Operation *op : srcOps)
    if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op))
      srcStoreMemRefs.insert(storeOp.getMemRef());

  // Compute the intersection between memrefs from stores in 'srcOps' and
  // memrefs from loads in 'dstOps'.
  for (Operation *op : dstOps)
    if (auto loadOp = dyn_cast<AffineReadOpInterface>(op))
      if (srcStoreMemRefs.count(loadOp.getMemRef()) > 0)
        producerConsumerMemrefs.insert(loadOp.getMemRef());
}
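
// Illustrative example: if 'srcOps' contains stores to %A and %B while
// 'dstOps' contains loads from %B and %C, only %B is recorded in
// 'producerConsumerMemrefs': it is the sole memref written by the source ops
// and read by the destination ops.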