1 //===- LoopFusionUtils.cpp ---- Utilities for loop fusion ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements loop fusion transformation utility functions. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/LoopFusionUtils.h" 14 #include "mlir/Analysis/SliceAnalysis.h" 15 #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h" 16 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h" 17 #include "mlir/Dialect/Affine/Analysis/Utils.h" 18 #include "mlir/Dialect/Affine/IR/AffineOps.h" 19 #include "mlir/Dialect/Affine/LoopUtils.h" 20 #include "mlir/IR/IRMapping.h" 21 #include "mlir/IR/Operation.h" 22 #include "llvm/Support/Debug.h" 23 #include "llvm/Support/raw_ostream.h" 24 #include <optional> 25 26 #define DEBUG_TYPE "loop-fusion-utils" 27 28 using namespace mlir; 29 30 // Gathers all load and store memref accesses in 'opA' into 'values', where 31 // 'values[memref] == true' for each store operation. 32 static void getLoadAndStoreMemRefAccesses(Operation *opA, 33 DenseMap<Value, bool> &values) { 34 opA->walk([&](Operation *op) { 35 if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) { 36 if (values.count(loadOp.getMemRef()) == 0) 37 values[loadOp.getMemRef()] = false; 38 } else if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) { 39 values[storeOp.getMemRef()] = true; 40 } 41 }); 42 } 43 44 /// Returns true if 'op' is a load or store operation which access a memref 45 /// accessed 'values' and at least one of the access is a store operation. 46 /// Returns false otherwise. 47 static bool isDependentLoadOrStoreOp(Operation *op, 48 DenseMap<Value, bool> &values) { 49 if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) { 50 return values.count(loadOp.getMemRef()) > 0 && values[loadOp.getMemRef()]; 51 } 52 if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) { 53 return values.count(storeOp.getMemRef()) > 0; 54 } 55 return false; 56 } 57 58 // Returns the first operation in range ('opA', 'opB') which has a data 59 // dependence on 'opA'. Returns 'nullptr' of no dependence exists. 60 static Operation *getFirstDependentOpInRange(Operation *opA, Operation *opB) { 61 // Record memref values from all loads/store in loop nest rooted at 'opA'. 62 // Map from memref value to bool which is true if store, false otherwise. 63 DenseMap<Value, bool> values; 64 getLoadAndStoreMemRefAccesses(opA, values); 65 66 // For each 'opX' in block in range ('opA', 'opB'), check if there is a data 67 // dependence from 'opA' to 'opX' ('opA' and 'opX' access the same memref 68 // and at least one of the accesses is a store). 69 Operation *firstDepOp = nullptr; 70 for (Block::iterator it = std::next(Block::iterator(opA)); 71 it != Block::iterator(opB); ++it) { 72 Operation *opX = &(*it); 73 opX->walk([&](Operation *op) { 74 if (!firstDepOp && isDependentLoadOrStoreOp(op, values)) 75 firstDepOp = opX; 76 }); 77 if (firstDepOp) 78 break; 79 } 80 return firstDepOp; 81 } 82 83 // Returns the last operation 'opX' in range ('opA', 'opB'), for which there 84 // exists a data dependence from 'opX' to 'opB'. 85 // Returns 'nullptr' of no dependence exists. 86 static Operation *getLastDependentOpInRange(Operation *opA, Operation *opB) { 87 // Record memref values from all loads/store in loop nest rooted at 'opB'. 88 // Map from memref value to bool which is true if store, false otherwise. 89 DenseMap<Value, bool> values; 90 getLoadAndStoreMemRefAccesses(opB, values); 91 92 // For each 'opX' in block in range ('opA', 'opB') in reverse order, 93 // check if there is a data dependence from 'opX' to 'opB': 94 // *) 'opX' and 'opB' access the same memref and at least one of the accesses 95 // is a store. 96 // *) 'opX' produces an SSA Value which is used by 'opB'. 97 Operation *lastDepOp = nullptr; 98 for (Block::reverse_iterator it = std::next(Block::reverse_iterator(opB)); 99 it != Block::reverse_iterator(opA); ++it) { 100 Operation *opX = &(*it); 101 opX->walk([&](Operation *op) { 102 if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) { 103 if (isDependentLoadOrStoreOp(op, values)) { 104 lastDepOp = opX; 105 return WalkResult::interrupt(); 106 } 107 return WalkResult::advance(); 108 } 109 for (Value value : op->getResults()) { 110 for (Operation *user : value.getUsers()) { 111 SmallVector<AffineForOp, 4> loops; 112 // Check if any loop in loop nest surrounding 'user' is 'opB'. 113 getAffineForIVs(*user, &loops); 114 if (llvm::is_contained(loops, cast<AffineForOp>(opB))) { 115 lastDepOp = opX; 116 return WalkResult::interrupt(); 117 } 118 } 119 } 120 return WalkResult::advance(); 121 }); 122 if (lastDepOp) 123 break; 124 } 125 return lastDepOp; 126 } 127 128 // Computes and returns an insertion point operation, before which the 129 // the fused <srcForOp, dstForOp> loop nest can be inserted while preserving 130 // dependences. Returns nullptr if no such insertion point is found. 131 static Operation *getFusedLoopNestInsertionPoint(AffineForOp srcForOp, 132 AffineForOp dstForOp) { 133 bool isSrcForOpBeforeDstForOp = srcForOp->isBeforeInBlock(dstForOp); 134 auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp; 135 auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp; 136 137 Operation *firstDepOpA = getFirstDependentOpInRange(forOpA, forOpB); 138 Operation *lastDepOpB = getLastDependentOpInRange(forOpA, forOpB); 139 // Block: 140 // ... 141 // |-- opA 142 // | ... 143 // | lastDepOpB --| 144 // | ... | 145 // |-> firstDepOpA | 146 // ... | 147 // opB <--------- 148 // 149 // Valid insertion point range: (lastDepOpB, firstDepOpA) 150 // 151 if (firstDepOpA != nullptr) { 152 if (lastDepOpB != nullptr) { 153 if (firstDepOpA->isBeforeInBlock(lastDepOpB) || firstDepOpA == lastDepOpB) 154 // No valid insertion point exists which preserves dependences. 155 return nullptr; 156 } 157 // Return insertion point in valid range closest to 'opB'. 158 // TODO: Consider other insertion points in valid range. 159 return firstDepOpA; 160 } 161 // No dependences from 'opA' to operation in range ('opA', 'opB'), return 162 // 'opB' insertion point. 163 return forOpB; 164 } 165 166 // Gathers all load and store ops in loop nest rooted at 'forOp' into 167 // 'loadAndStoreOps'. 168 static bool 169 gatherLoadsAndStores(AffineForOp forOp, 170 SmallVectorImpl<Operation *> &loadAndStoreOps) { 171 bool hasIfOp = false; 172 forOp.walk([&](Operation *op) { 173 if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) 174 loadAndStoreOps.push_back(op); 175 else if (isa<AffineIfOp>(op)) 176 hasIfOp = true; 177 }); 178 return !hasIfOp; 179 } 180 181 /// Returns the maximum loop depth at which we could fuse producer loop 182 /// 'srcForOp' into consumer loop 'dstForOp' without violating data dependences. 183 // TODO: Generalize this check for sibling and more generic fusion scenarios. 184 // TODO: Support forward slice fusion. 185 static unsigned getMaxLoopDepth(ArrayRef<Operation *> srcOps, 186 ArrayRef<Operation *> dstOps) { 187 if (dstOps.empty()) 188 // Expected at least one memory operation. 189 // TODO: Revisit this case with a specific example. 190 return 0; 191 192 // Filter out ops in 'dstOps' that do not use the producer-consumer memref so 193 // that they are not considered for analysis. 194 DenseSet<Value> producerConsumerMemrefs; 195 gatherProducerConsumerMemrefs(srcOps, dstOps, producerConsumerMemrefs); 196 SmallVector<Operation *, 4> targetDstOps; 197 for (Operation *dstOp : dstOps) { 198 auto loadOp = dyn_cast<AffineReadOpInterface>(dstOp); 199 Value memref = loadOp ? loadOp.getMemRef() 200 : cast<AffineWriteOpInterface>(dstOp).getMemRef(); 201 if (producerConsumerMemrefs.count(memref) > 0) 202 targetDstOps.push_back(dstOp); 203 } 204 205 assert(!targetDstOps.empty() && 206 "No dependences between 'srcForOp' and 'dstForOp'?"); 207 208 // Compute the innermost common loop depth for loads and stores. 209 unsigned loopDepth = getInnermostCommonLoopDepth(targetDstOps); 210 211 // Return common loop depth for loads if there are no store ops. 212 if (all_of(targetDstOps, 213 [&](Operation *op) { return isa<AffineReadOpInterface>(op); })) 214 return loopDepth; 215 216 // Check dependences on all pairs of ops in 'targetDstOps' and store the 217 // minimum loop depth at which a dependence is satisfied. 218 for (unsigned i = 0, e = targetDstOps.size(); i < e; ++i) { 219 auto *srcOpInst = targetDstOps[i]; 220 MemRefAccess srcAccess(srcOpInst); 221 for (unsigned j = 0; j < e; ++j) { 222 auto *dstOpInst = targetDstOps[j]; 223 MemRefAccess dstAccess(dstOpInst); 224 225 unsigned numCommonLoops = 226 getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst); 227 for (unsigned d = 1; d <= numCommonLoops + 1; ++d) { 228 FlatAffineValueConstraints dependenceConstraints; 229 // TODO: Cache dependence analysis results, check cache here. 230 DependenceResult result = checkMemrefAccessDependence( 231 srcAccess, dstAccess, d, &dependenceConstraints, 232 /*dependenceComponents=*/nullptr); 233 if (hasDependence(result)) { 234 // Store minimum loop depth and break because we want the min 'd' at 235 // which there is a dependence. 236 loopDepth = std::min(loopDepth, d - 1); 237 break; 238 } 239 } 240 } 241 } 242 243 return loopDepth; 244 } 245 246 // TODO: Prevent fusion of loop nests with side-effecting operations. 247 // TODO: This pass performs some computation that is the same for all the depths 248 // (e.g., getMaxLoopDepth). Implement a version of this utility that processes 249 // all the depths at once or only the legal maximal depth for maximal fusion. 250 FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, 251 unsigned dstLoopDepth, 252 ComputationSliceState *srcSlice, 253 FusionStrategy fusionStrategy) { 254 // Return 'failure' if 'dstLoopDepth == 0'. 255 if (dstLoopDepth == 0) { 256 LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests at depth 0\n"); 257 return FusionResult::FailPrecondition; 258 } 259 // Return 'failure' if 'srcForOp' and 'dstForOp' are not in the same block. 260 auto *block = srcForOp->getBlock(); 261 if (block != dstForOp->getBlock()) { 262 LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests in different blocks\n"); 263 return FusionResult::FailPrecondition; 264 } 265 266 // Return 'failure' if no valid insertion point for fused loop nest in 'block' 267 // exists which would preserve dependences. 268 if (!getFusedLoopNestInsertionPoint(srcForOp, dstForOp)) { 269 LLVM_DEBUG(llvm::dbgs() << "Fusion would violate dependences in block\n"); 270 return FusionResult::FailBlockDependence; 271 } 272 273 // Check if 'srcForOp' precedes 'dstForOp' in 'block'. 274 bool isSrcForOpBeforeDstForOp = srcForOp->isBeforeInBlock(dstForOp); 275 // 'forOpA' executes before 'forOpB' in 'block'. 276 auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp; 277 auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp; 278 279 // Gather all load and store from 'forOpA' which precedes 'forOpB' in 'block'. 280 SmallVector<Operation *, 4> opsA; 281 if (!gatherLoadsAndStores(forOpA, opsA)) { 282 LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n"); 283 return FusionResult::FailPrecondition; 284 } 285 286 // Gather all load and store from 'forOpB' which succeeds 'forOpA' in 'block'. 287 SmallVector<Operation *, 4> opsB; 288 if (!gatherLoadsAndStores(forOpB, opsB)) { 289 LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n"); 290 return FusionResult::FailPrecondition; 291 } 292 293 // Return 'failure' if fusing loops at depth 'dstLoopDepth' wouldn't preserve 294 // loop dependences. 295 // TODO: Enable this check for sibling and more generic loop fusion 296 // strategies. 297 if (fusionStrategy.getStrategy() == FusionStrategy::ProducerConsumer) { 298 // TODO: 'getMaxLoopDepth' does not support forward slice fusion. 299 assert(isSrcForOpBeforeDstForOp && "Unexpected forward slice fusion"); 300 if (getMaxLoopDepth(opsA, opsB) < dstLoopDepth) { 301 LLVM_DEBUG(llvm::dbgs() << "Fusion would violate loop dependences\n"); 302 return FusionResult::FailFusionDependence; 303 } 304 } 305 306 // Calculate the number of common loops surrounding 'srcForOp' and 'dstForOp'. 307 unsigned numCommonLoops = 308 mlir::getNumCommonSurroundingLoops(*srcForOp, *dstForOp); 309 310 // Filter out ops in 'opsA' to compute the slice union based on the 311 // assumptions made by the fusion strategy. 312 SmallVector<Operation *, 4> strategyOpsA; 313 switch (fusionStrategy.getStrategy()) { 314 case FusionStrategy::Generic: 315 // Generic fusion. Take into account all the memory operations to compute 316 // the slice union. 317 strategyOpsA.append(opsA.begin(), opsA.end()); 318 break; 319 case FusionStrategy::ProducerConsumer: 320 // Producer-consumer fusion (AffineLoopFusion pass) only takes into 321 // account stores in 'srcForOp' to compute the slice union. 322 for (Operation *op : opsA) { 323 if (isa<AffineWriteOpInterface>(op)) 324 strategyOpsA.push_back(op); 325 } 326 break; 327 case FusionStrategy::Sibling: 328 // Sibling fusion (AffineLoopFusion pass) only takes into account the loads 329 // to 'memref' in 'srcForOp' to compute the slice union. 330 for (Operation *op : opsA) { 331 auto load = dyn_cast<AffineReadOpInterface>(op); 332 if (load && load.getMemRef() == fusionStrategy.getSiblingFusionMemRef()) 333 strategyOpsA.push_back(op); 334 } 335 break; 336 } 337 338 // Compute union of computation slices computed between all pairs of ops 339 // from 'forOpA' and 'forOpB'. 340 SliceComputationResult sliceComputationResult = 341 mlir::computeSliceUnion(strategyOpsA, opsB, dstLoopDepth, numCommonLoops, 342 isSrcForOpBeforeDstForOp, srcSlice); 343 if (sliceComputationResult.value == SliceComputationResult::GenericFailure) { 344 LLVM_DEBUG(llvm::dbgs() << "computeSliceUnion failed\n"); 345 return FusionResult::FailPrecondition; 346 } 347 if (sliceComputationResult.value == 348 SliceComputationResult::IncorrectSliceFailure) { 349 LLVM_DEBUG(llvm::dbgs() << "Incorrect slice computation\n"); 350 return FusionResult::FailIncorrectSlice; 351 } 352 353 return FusionResult::Success; 354 } 355 356 /// Patch the loop body of a forOp that is a single iteration reduction loop 357 /// into its containing block. 358 LogicalResult promoteSingleIterReductionLoop(AffineForOp forOp, 359 bool siblingFusionUser) { 360 // Check if the reduction loop is a single iteration loop. 361 std::optional<uint64_t> tripCount = getConstantTripCount(forOp); 362 if (!tripCount || *tripCount != 1) 363 return failure(); 364 auto iterOperands = forOp.getIterOperands(); 365 auto *parentOp = forOp->getParentOp(); 366 if (!isa<AffineForOp>(parentOp)) 367 return failure(); 368 auto newOperands = forOp.getBody()->getTerminator()->getOperands(); 369 OpBuilder b(parentOp); 370 // Replace the parent loop and add iteroperands and results from the `forOp`. 371 AffineForOp parentForOp = forOp->getParentOfType<AffineForOp>(); 372 AffineForOp newLoop = replaceForOpWithNewYields( 373 b, parentForOp, iterOperands, newOperands, forOp.getRegionIterArgs()); 374 375 // For sibling-fusion users, collect operations that use the results of the 376 // `forOp` outside the new parent loop that has absorbed all its iter args 377 // and operands. These operations will be moved later after the results 378 // have been replaced. 379 SetVector<Operation *> forwardSlice; 380 if (siblingFusionUser) { 381 for (unsigned i = 0, e = forOp.getNumResults(); i != e; ++i) { 382 SetVector<Operation *> tmpForwardSlice; 383 getForwardSlice(forOp.getResult(i), &tmpForwardSlice); 384 forwardSlice.set_union(tmpForwardSlice); 385 } 386 } 387 // Update the results of the `forOp` in the new loop. 388 for (unsigned i = 0, e = forOp.getNumResults(); i != e; ++i) { 389 forOp.getResult(i).replaceAllUsesWith( 390 newLoop.getResult(i + parentOp->getNumResults())); 391 } 392 // For sibling-fusion users, move operations that use the results of the 393 // `forOp` outside the new parent loop 394 if (siblingFusionUser) { 395 topologicalSort(forwardSlice); 396 for (Operation *op : llvm::reverse(forwardSlice)) 397 op->moveAfter(newLoop); 398 } 399 // Replace the induction variable. 400 auto iv = forOp.getInductionVar(); 401 iv.replaceAllUsesWith(newLoop.getInductionVar()); 402 // Replace the iter args. 403 auto forOpIterArgs = forOp.getRegionIterArgs(); 404 for (auto it : llvm::zip(forOpIterArgs, newLoop.getRegionIterArgs().take_back( 405 forOpIterArgs.size()))) { 406 std::get<0>(it).replaceAllUsesWith(std::get<1>(it)); 407 } 408 // Move the loop body operations, except for its terminator, to the loop's 409 // containing block. 410 forOp.getBody()->back().erase(); 411 auto *parentBlock = forOp->getBlock(); 412 parentBlock->getOperations().splice(Block::iterator(forOp), 413 forOp.getBody()->getOperations()); 414 forOp.erase(); 415 parentForOp.erase(); 416 return success(); 417 } 418 419 /// Fuses 'srcForOp' into 'dstForOp' with destination loop block insertion point 420 /// and source slice loop bounds specified in 'srcSlice'. 421 void mlir::fuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, 422 const ComputationSliceState &srcSlice, 423 bool isInnermostSiblingInsertion) { 424 // Clone 'srcForOp' into 'dstForOp' at 'srcSlice->insertPoint'. 425 OpBuilder b(srcSlice.insertPoint->getBlock(), srcSlice.insertPoint); 426 IRMapping mapper; 427 b.clone(*srcForOp, mapper); 428 429 // Update 'sliceLoopNest' upper and lower bounds from computed 'srcSlice'. 430 SmallVector<AffineForOp, 4> sliceLoops; 431 for (unsigned i = 0, e = srcSlice.ivs.size(); i < e; ++i) { 432 auto loopIV = mapper.lookupOrNull(srcSlice.ivs[i]); 433 if (!loopIV) 434 continue; 435 auto forOp = getForInductionVarOwner(loopIV); 436 sliceLoops.push_back(forOp); 437 if (AffineMap lbMap = srcSlice.lbs[i]) { 438 auto lbOperands = srcSlice.lbOperands[i]; 439 canonicalizeMapAndOperands(&lbMap, &lbOperands); 440 forOp.setLowerBound(lbOperands, lbMap); 441 } 442 if (AffineMap ubMap = srcSlice.ubs[i]) { 443 auto ubOperands = srcSlice.ubOperands[i]; 444 canonicalizeMapAndOperands(&ubMap, &ubOperands); 445 forOp.setUpperBound(ubOperands, ubMap); 446 } 447 } 448 449 llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap; 450 auto srcIsUnitSlice = [&]() { 451 return (buildSliceTripCountMap(srcSlice, &sliceTripCountMap) && 452 (getSliceIterationCount(sliceTripCountMap) == 1)); 453 }; 454 // Fix up and if possible, eliminate single iteration loops. 455 for (AffineForOp forOp : sliceLoops) { 456 if (isLoopParallelAndContainsReduction(forOp) && 457 isInnermostSiblingInsertion && srcIsUnitSlice()) 458 // Patch reduction loop - only ones that are sibling-fused with the 459 // destination loop - into the parent loop. 460 (void)promoteSingleIterReductionLoop(forOp, true); 461 else 462 // Promote any single iteration slice loops. 463 (void)promoteIfSingleIteration(forOp); 464 } 465 } 466 467 /// Collect loop nest statistics (eg. loop trip count and operation count) 468 /// in 'stats' for loop nest rooted at 'forOp'. Returns true on success, 469 /// returns false otherwise. 470 bool mlir::getLoopNestStats(AffineForOp forOpRoot, LoopNestStats *stats) { 471 auto walkResult = forOpRoot.walk([&](AffineForOp forOp) { 472 auto *childForOp = forOp.getOperation(); 473 auto *parentForOp = forOp->getParentOp(); 474 if (forOp != forOpRoot) { 475 if (!isa<AffineForOp>(parentForOp)) { 476 LLVM_DEBUG(llvm::dbgs() << "Expected parent AffineForOp\n"); 477 return WalkResult::interrupt(); 478 } 479 // Add mapping to 'forOp' from its parent AffineForOp. 480 stats->loopMap[parentForOp].push_back(forOp); 481 } 482 483 // Record the number of op operations in the body of 'forOp'. 484 unsigned count = 0; 485 stats->opCountMap[childForOp] = 0; 486 for (auto &op : *forOp.getBody()) { 487 if (!isa<AffineForOp, AffineIfOp>(op)) 488 ++count; 489 } 490 stats->opCountMap[childForOp] = count; 491 492 // Record trip count for 'forOp'. Set flag if trip count is not 493 // constant. 494 std::optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp); 495 if (!maybeConstTripCount) { 496 // Currently only constant trip count loop nests are supported. 497 LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count unsupported\n"); 498 return WalkResult::interrupt(); 499 } 500 501 stats->tripCountMap[childForOp] = *maybeConstTripCount; 502 return WalkResult::advance(); 503 }); 504 return !walkResult.wasInterrupted(); 505 } 506 507 // Computes the total cost of the loop nest rooted at 'forOp'. 508 // Currently, the total cost is computed by counting the total operation 509 // instance count (i.e. total number of operations in the loop bodyloop 510 // operation count * loop trip count) for the entire loop nest. 511 // If 'tripCountOverrideMap' is non-null, overrides the trip count for loops 512 // specified in the map when computing the total op instance count. 513 // NOTEs: 1) This is used to compute the cost of computation slices, which are 514 // sliced along the iteration dimension, and thus reduce the trip count. 515 // If 'computeCostMap' is non-null, the total op count for forOps specified 516 // in the map is increased (not overridden) by adding the op count from the 517 // map to the existing op count for the for loop. This is done before 518 // multiplying by the loop's trip count, and is used to model the cost of 519 // inserting a sliced loop nest of known cost into the loop's body. 520 // 2) This is also used to compute the cost of fusing a slice of some loop nest 521 // within another loop. 522 static int64_t getComputeCostHelper( 523 Operation *forOp, LoopNestStats &stats, 524 llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountOverrideMap, 525 DenseMap<Operation *, int64_t> *computeCostMap) { 526 // 'opCount' is the total number operations in one iteration of 'forOp' body, 527 // minus terminator op which is a no-op. 528 int64_t opCount = stats.opCountMap[forOp] - 1; 529 if (stats.loopMap.count(forOp) > 0) { 530 for (auto childForOp : stats.loopMap[forOp]) { 531 opCount += getComputeCostHelper(childForOp, stats, tripCountOverrideMap, 532 computeCostMap); 533 } 534 } 535 // Add in additional op instances from slice (if specified in map). 536 if (computeCostMap != nullptr) { 537 auto it = computeCostMap->find(forOp); 538 if (it != computeCostMap->end()) { 539 opCount += it->second; 540 } 541 } 542 // Override trip count (if specified in map). 543 int64_t tripCount = stats.tripCountMap[forOp]; 544 if (tripCountOverrideMap != nullptr) { 545 auto it = tripCountOverrideMap->find(forOp); 546 if (it != tripCountOverrideMap->end()) { 547 tripCount = it->second; 548 } 549 } 550 // Returns the total number of dynamic instances of operations in loop body. 551 return tripCount * opCount; 552 } 553 554 /// Computes the total cost of the loop nest rooted at 'forOp' using 'stats'. 555 /// Currently, the total cost is computed by counting the total operation 556 /// instance count (i.e. total number of operations in the loop body * loop 557 /// trip count) for the entire loop nest. 558 int64_t mlir::getComputeCost(AffineForOp forOp, LoopNestStats &stats) { 559 return getComputeCostHelper(forOp, stats, 560 /*tripCountOverrideMap=*/nullptr, 561 /*computeCostMap=*/nullptr); 562 } 563 564 /// Computes and returns in 'computeCost', the total compute cost of fusing the 565 /// 'slice' of the loop nest rooted at 'srcForOp' into 'dstForOp'. Currently, 566 /// the total cost is computed by counting the total operation instance count 567 /// (i.e. total number of operations in the loop body * loop trip count) for 568 /// the entire loop nest. 569 bool mlir::getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats, 570 AffineForOp dstForOp, LoopNestStats &dstStats, 571 const ComputationSliceState &slice, 572 int64_t *computeCost) { 573 llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap; 574 DenseMap<Operation *, int64_t> computeCostMap; 575 576 // Build trip count map for computation slice. 577 if (!buildSliceTripCountMap(slice, &sliceTripCountMap)) 578 return false; 579 // Checks whether a store to load forwarding will happen. 580 int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap); 581 assert(sliceIterationCount > 0); 582 bool storeLoadFwdGuaranteed = (sliceIterationCount == 1); 583 auto *insertPointParent = slice.insertPoint->getParentOp(); 584 585 // The store and loads to this memref will disappear. 586 // TODO: Add load coalescing to memref data flow opt pass. 587 if (storeLoadFwdGuaranteed) { 588 // Subtract from operation count the loads/store we expect load/store 589 // forwarding to remove. 590 unsigned storeCount = 0; 591 llvm::SmallDenseSet<Value, 4> storeMemrefs; 592 srcForOp.walk([&](Operation *op) { 593 if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) { 594 storeMemrefs.insert(storeOp.getMemRef()); 595 ++storeCount; 596 } 597 }); 598 // Subtract out any store ops in single-iteration src slice loop nest. 599 if (storeCount > 0) 600 computeCostMap[insertPointParent] = -storeCount; 601 // Subtract out any load users of 'storeMemrefs' nested below 602 // 'insertPointParent'. 603 for (Value memref : storeMemrefs) { 604 for (auto *user : memref.getUsers()) { 605 if (auto loadOp = dyn_cast<AffineReadOpInterface>(user)) { 606 SmallVector<AffineForOp, 4> loops; 607 // Check if any loop in loop nest surrounding 'user' is 608 // 'insertPointParent'. 609 getAffineForIVs(*user, &loops); 610 if (llvm::is_contained(loops, cast<AffineForOp>(insertPointParent))) { 611 if (auto forOp = 612 dyn_cast_or_null<AffineForOp>(user->getParentOp())) { 613 if (computeCostMap.count(forOp) == 0) 614 computeCostMap[forOp] = 0; 615 computeCostMap[forOp] -= 1; 616 } 617 } 618 } 619 } 620 } 621 } 622 623 // Compute op instance count for the src loop nest with iteration slicing. 624 int64_t sliceComputeCost = getComputeCostHelper( 625 srcForOp, srcStats, &sliceTripCountMap, &computeCostMap); 626 627 // Compute cost of fusion for this depth. 628 computeCostMap[insertPointParent] = sliceComputeCost; 629 630 *computeCost = 631 getComputeCostHelper(dstForOp, dstStats, 632 /*tripCountOverrideMap=*/nullptr, &computeCostMap); 633 return true; 634 } 635 636 /// Returns in 'producerConsumerMemrefs' the memrefs involved in a 637 /// producer-consumer dependence between write ops in 'srcOps' and read ops in 638 /// 'dstOps'. 639 void mlir::gatherProducerConsumerMemrefs( 640 ArrayRef<Operation *> srcOps, ArrayRef<Operation *> dstOps, 641 DenseSet<Value> &producerConsumerMemrefs) { 642 // Gather memrefs from stores in 'srcOps'. 643 DenseSet<Value> srcStoreMemRefs; 644 for (Operation *op : srcOps) 645 if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) 646 srcStoreMemRefs.insert(storeOp.getMemRef()); 647 648 // Compute the intersection between memrefs from stores in 'srcOps' and 649 // memrefs from loads in 'dstOps'. 650 for (Operation *op : dstOps) 651 if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) 652 if (srcStoreMemRefs.count(loadOp.getMemRef()) > 0) 653 producerConsumerMemrefs.insert(loadOp.getMemRef()); 654 } 655