//===- LoopFusionUtils.cpp ---- Utilities for loop fusion ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements loop fusion transformation utility functions.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/LoopFusionUtils.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Operation.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "loop-fusion-utils"

using namespace mlir;

// Gathers all load and store memref accesses in 'opA' into 'values', where
// 'values[memref] == true' for each store operation.
static void getLoadAndStoreMemRefAccesses(Operation *opA,
                                          DenseMap<Value, bool> &values) {
  opA->walk([&](Operation *op) {
    if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) {
      if (values.count(loadOp.getMemRef()) == 0)
        values[loadOp.getMemRef()] = false;
    } else if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) {
      values[storeOp.getMemRef()] = true;
    }
  });
}

/// Returns true if 'op' is a load or store operation that accesses a memref
/// tracked in 'values' and at least one of the accesses to that memref is a
/// store. Returns false otherwise.
static bool isDependentLoadOrStoreOp(Operation *op,
                                     DenseMap<Value, bool> &values) {
  if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) {
    return values.count(loadOp.getMemRef()) > 0 && values[loadOp.getMemRef()];
  }
  if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) {
    return values.count(storeOp.getMemRef()) > 0;
  }
  return false;
}

// Returns the first operation in range ('opA', 'opB') which has a data
// dependence on 'opA'. Returns 'nullptr' if no dependence exists.
static Operation *getFirstDependentOpInRange(Operation *opA, Operation *opB) {
  // Record memref values from all loads/stores in the loop nest rooted at
  // 'opA'. Map from memref value to bool which is true if store, false
  // otherwise.
  DenseMap<Value, bool> values;
  getLoadAndStoreMemRefAccesses(opA, values);

  // For each 'opX' in block in range ('opA', 'opB'), check if there is a data
  // dependence from 'opA' to 'opX' ('opA' and 'opX' access the same memref
  // and at least one of the accesses is a store).
  Operation *firstDepOp = nullptr;
  for (Block::iterator it = std::next(Block::iterator(opA));
       it != Block::iterator(opB); ++it) {
    Operation *opX = &(*it);
    opX->walk([&](Operation *op) {
      if (!firstDepOp && isDependentLoadOrStoreOp(op, values))
        firstDepOp = opX;
    });
    if (firstDepOp)
      break;
  }
  return firstDepOp;
}

// Returns the last operation 'opX' in range ('opA', 'opB'), for which there
// exists a data dependence from 'opX' to 'opB'.
// Returns 'nullptr' if no dependence exists.
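// For example (illustrative IR; '%A' and '%v' are assumed values):
//
//   ...                                           // 'opA'
//   affine.for %i = 0 to 10 {                     // 'opX'
//     affine.store %v, %A[%i] : memref<10xf32>
//   }
//   affine.for %j = 0 to 10 {                     // 'opB'
//     %0 = affine.load %A[%j] : memref<10xf32>
//   }
//
// The store loop 'opX' is returned as the last dependent op: it writes a
// memref that the nest rooted at 'opB' reads, so a data dependence exists
// from 'opX' to 'opB'.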
static Operation *getLastDependentOpInRange(Operation *opA, Operation *opB) {
  // Record memref values from all loads/stores in the loop nest rooted at
  // 'opB'. Map from memref value to bool which is true if store, false
  // otherwise.
  DenseMap<Value, bool> values;
  getLoadAndStoreMemRefAccesses(opB, values);

  // For each 'opX' in block in range ('opA', 'opB') in reverse order,
  // check if there is a data dependence from 'opX' to 'opB':
  // *) 'opX' and 'opB' access the same memref and at least one of the accesses
  //    is a store.
  // *) 'opX' produces an SSA Value which is used by 'opB'.
  Operation *lastDepOp = nullptr;
  for (Block::reverse_iterator it = std::next(Block::reverse_iterator(opB));
       it != Block::reverse_iterator(opA); ++it) {
    Operation *opX = &(*it);
    opX->walk([&](Operation *op) {
      if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) {
        if (isDependentLoadOrStoreOp(op, values)) {
          lastDepOp = opX;
          return WalkResult::interrupt();
        }
        return WalkResult::advance();
      }
      for (Value value : op->getResults()) {
        for (Operation *user : value.getUsers()) {
          SmallVector<AffineForOp, 4> loops;
          // Check if any loop in the loop nest surrounding 'user' is 'opB'.
          getAffineForIVs(*user, &loops);
          if (llvm::is_contained(loops, cast<AffineForOp>(opB))) {
            lastDepOp = opX;
            return WalkResult::interrupt();
          }
        }
      }
      return WalkResult::advance();
    });
    if (lastDepOp)
      break;
  }
  return lastDepOp;
}

// Computes and returns an insertion point operation, before which the
// fused <srcForOp, dstForOp> loop nest can be inserted while preserving
// dependences. Returns nullptr if no such insertion point is found.
static Operation *getFusedLoopNestInsertionPoint(AffineForOp srcForOp,
                                                 AffineForOp dstForOp) {
  bool isSrcForOpBeforeDstForOp = srcForOp->isBeforeInBlock(dstForOp);
  auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp;
  auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp;

  Operation *firstDepOpA = getFirstDependentOpInRange(forOpA, forOpB);
  Operation *lastDepOpB = getLastDependentOpInRange(forOpA, forOpB);
  // Block:
  //      ...
  //  |-- opA
  //  |   ...
  //  |   lastDepOpB --|
  //  |   ...          |
  //  |-> firstDepOpA  |
  //      ...          |
  //      opB <---------
  //
  // Valid insertion point range: (lastDepOpB, firstDepOpA)
  //
  if (firstDepOpA != nullptr) {
    if (lastDepOpB != nullptr) {
      if (firstDepOpA->isBeforeInBlock(lastDepOpB) || firstDepOpA == lastDepOpB)
        // No valid insertion point exists which preserves dependences.
        return nullptr;
    }
    // Return the insertion point in the valid range closest to 'opB'.
    // TODO: Consider other insertion points in valid range.
    return firstDepOpA;
  }
  // No dependences from 'opA' to any operation in range ('opA', 'opB');
  // return 'opB' as the insertion point.
  return forOpB;
}

// Gathers all load and store ops in the loop nest rooted at 'forOp' into
// 'loadAndStoreOps'.
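// Returns true on success; returns false if the walk encounters an affine.if
// op, since fusion of loop nests containing conditionals is not supported.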
static bool
gatherLoadsAndStores(AffineForOp forOp,
                     SmallVectorImpl<Operation *> &loadAndStoreOps) {
  bool hasIfOp = false;
  forOp.walk([&](Operation *op) {
    if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op))
      loadAndStoreOps.push_back(op);
    else if (isa<AffineIfOp>(op))
      hasIfOp = true;
  });
  return !hasIfOp;
}

/// Returns the maximum loop depth at which we could fuse producer loop
/// 'srcForOp' into consumer loop 'dstForOp' without violating data
/// dependences.
// TODO: Generalize this check for sibling and more generic fusion scenarios.
// TODO: Support forward slice fusion.
static unsigned getMaxLoopDepth(ArrayRef<Operation *> srcOps,
                                ArrayRef<Operation *> dstOps) {
  if (dstOps.empty())
    // Expected at least one memory operation.
    // TODO: Revisit this case with a specific example.
    return 0;

  // Filter out ops in 'dstOps' that do not use the producer-consumer memref so
  // that they are not considered for analysis.
  DenseSet<Value> producerConsumerMemrefs;
  gatherProducerConsumerMemrefs(srcOps, dstOps, producerConsumerMemrefs);
  SmallVector<Operation *, 4> targetDstOps;
  for (Operation *dstOp : dstOps) {
    auto loadOp = dyn_cast<AffineReadOpInterface>(dstOp);
    Value memref = loadOp ? loadOp.getMemRef()
                          : cast<AffineWriteOpInterface>(dstOp).getMemRef();
    if (producerConsumerMemrefs.count(memref) > 0)
      targetDstOps.push_back(dstOp);
  }

  assert(!targetDstOps.empty() &&
         "No dependences between 'srcForOp' and 'dstForOp'?");

  // Compute the innermost common loop depth for loads and stores.
  unsigned loopDepth = getInnermostCommonLoopDepth(targetDstOps);

  // Return common loop depth for loads if there are no store ops.
  if (all_of(targetDstOps,
             [&](Operation *op) { return isa<AffineReadOpInterface>(op); }))
    return loopDepth;

  // Check dependences on all pairs of ops in 'targetDstOps' and store the
  // minimum loop depth at which a dependence is satisfied.
  for (unsigned i = 0, e = targetDstOps.size(); i < e; ++i) {
    auto *srcOpInst = targetDstOps[i];
    MemRefAccess srcAccess(srcOpInst);
    for (unsigned j = 0; j < e; ++j) {
      auto *dstOpInst = targetDstOps[j];
      MemRefAccess dstAccess(dstOpInst);

      unsigned numCommonLoops =
          getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst);
      for (unsigned d = 1; d <= numCommonLoops + 1; ++d) {
        FlatAffineValueConstraints dependenceConstraints;
        // TODO: Cache dependence analysis results, check cache here.
        DependenceResult result = checkMemrefAccessDependence(
            srcAccess, dstAccess, d, &dependenceConstraints,
            /*dependenceComponents=*/nullptr);
        if (hasDependence(result)) {
          // Store the minimum loop depth and break because we want the min
          // 'd' at which there is a dependence.
          loopDepth = std::min(loopDepth, d - 1);
          break;
        }
      }
    }
  }

  return loopDepth;
}

// TODO: Prevent fusion of loop nests with side-effecting operations.
// TODO: This pass performs some computation that is the same for all the
// depths (e.g., getMaxLoopDepth). Implement a version of this utility that
// processes all the depths at once or only the legal maximal depth for
// maximal fusion.
FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
                                unsigned dstLoopDepth,
                                ComputationSliceState *srcSlice,
                                FusionStrategy fusionStrategy) {
  // Return 'failure' if 'dstLoopDepth == 0'.
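  // A depth of zero would place the source slice outside of all loops in
  // 'dstForOp', i.e., no fusion would occur.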
  if (dstLoopDepth == 0) {
    LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests at depth 0\n");
    return FusionResult::FailPrecondition;
  }
  // Return 'failure' if 'srcForOp' and 'dstForOp' are not in the same block.
  auto *block = srcForOp->getBlock();
  if (block != dstForOp->getBlock()) {
    LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests in different blocks\n");
    return FusionResult::FailPrecondition;
  }

  // Return 'failure' if no valid insertion point for the fused loop nest in
  // 'block' exists which would preserve dependences.
  if (!getFusedLoopNestInsertionPoint(srcForOp, dstForOp)) {
    LLVM_DEBUG(llvm::dbgs() << "Fusion would violate dependences in block\n");
    return FusionResult::FailBlockDependence;
  }

  // Check if 'srcForOp' precedes 'dstForOp' in 'block'.
  bool isSrcForOpBeforeDstForOp = srcForOp->isBeforeInBlock(dstForOp);
  // 'forOpA' executes before 'forOpB' in 'block'.
  auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp;
  auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp;

  // Gather all loads and stores from 'forOpA', which precedes 'forOpB' in
  // 'block'.
  SmallVector<Operation *, 4> opsA;
  if (!gatherLoadsAndStores(forOpA, opsA)) {
    LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n");
    return FusionResult::FailPrecondition;
  }

  // Gather all loads and stores from 'forOpB', which succeeds 'forOpA' in
  // 'block'.
  SmallVector<Operation *, 4> opsB;
  if (!gatherLoadsAndStores(forOpB, opsB)) {
    LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n");
    return FusionResult::FailPrecondition;
  }

  // Return 'failure' if fusing loops at depth 'dstLoopDepth' wouldn't preserve
  // loop dependences.
  // TODO: Enable this check for sibling and more generic loop fusion
  // strategies.
  if (fusionStrategy.getStrategy() == FusionStrategy::ProducerConsumer) {
    // TODO: 'getMaxLoopDepth' does not support forward slice fusion.
    assert(isSrcForOpBeforeDstForOp && "Unexpected forward slice fusion");
    if (getMaxLoopDepth(opsA, opsB) < dstLoopDepth) {
      LLVM_DEBUG(llvm::dbgs() << "Fusion would violate loop dependences\n");
      return FusionResult::FailFusionDependence;
    }
  }

  // Calculate the number of common loops surrounding 'srcForOp' and
  // 'dstForOp'.
  unsigned numCommonLoops =
      mlir::getNumCommonSurroundingLoops(*srcForOp, *dstForOp);

  // Filter out ops in 'opsA' to compute the slice union based on the
  // assumptions made by the fusion strategy.
  SmallVector<Operation *, 4> strategyOpsA;
  switch (fusionStrategy.getStrategy()) {
  case FusionStrategy::Generic:
    // Generic fusion. Take into account all the memory operations to compute
    // the slice union.
    strategyOpsA.append(opsA.begin(), opsA.end());
    break;
  case FusionStrategy::ProducerConsumer:
    // Producer-consumer fusion (AffineLoopFusion pass) only takes into
    // account stores in 'srcForOp' to compute the slice union.
    for (Operation *op : opsA) {
      if (isa<AffineWriteOpInterface>(op))
        strategyOpsA.push_back(op);
    }
    break;
  case FusionStrategy::Sibling:
    // Sibling fusion (AffineLoopFusion pass) only takes into account the
    // loads from the sibling memref in 'srcForOp' to compute the slice union.
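    // (The sibling memref is supplied by the caller through 'fusionStrategy';
    // sibling fusion assumes the two nests share reads of this memref rather
    // than a producer-consumer relationship.)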
    for (Operation *op : opsA) {
      auto load = dyn_cast<AffineReadOpInterface>(op);
      if (load && load.getMemRef() == fusionStrategy.getSiblingFusionMemRef())
        strategyOpsA.push_back(op);
    }
    break;
  }

  // Compute the union of the computation slices computed between all pairs
  // of ops from 'forOpA' and 'forOpB'.
  SliceComputationResult sliceComputationResult =
      mlir::computeSliceUnion(strategyOpsA, opsB, dstLoopDepth, numCommonLoops,
                              isSrcForOpBeforeDstForOp, srcSlice);
  if (sliceComputationResult.value == SliceComputationResult::GenericFailure) {
    LLVM_DEBUG(llvm::dbgs() << "computeSliceUnion failed\n");
    return FusionResult::FailPrecondition;
  }
  if (sliceComputationResult.value ==
      SliceComputationResult::IncorrectSliceFailure) {
    LLVM_DEBUG(llvm::dbgs() << "Incorrect slice computation\n");
    return FusionResult::FailIncorrectSlice;
  }

  return FusionResult::Success;
}

/// Patches the loop body of a 'forOp' that is a single-iteration reduction
/// loop into its containing block.
LogicalResult promoteSingleIterReductionLoop(AffineForOp forOp,
                                             bool siblingFusionUser) {
  // Check if the reduction loop is a single iteration loop.
  Optional<uint64_t> tripCount = getConstantTripCount(forOp);
  if (!tripCount || *tripCount != 1)
    return failure();
  auto iterOperands = forOp.getIterOperands();
  auto *parentOp = forOp->getParentOp();
  if (!isa<AffineForOp>(parentOp))
    return failure();
  auto newOperands = forOp.getBody()->getTerminator()->getOperands();
  OpBuilder b(parentOp);
  // Replace the parent loop and add the iter operands and results from the
  // `forOp`.
  AffineForOp parentForOp = forOp->getParentOfType<AffineForOp>();
  AffineForOp newLoop = replaceForOpWithNewYields(
      b, parentForOp, iterOperands, newOperands, forOp.getRegionIterArgs());

  // For sibling-fusion users, collect operations that use the results of the
  // `forOp` outside the new parent loop that has absorbed all its iter args
  // and operands. These operations will be moved later after the results
  // have been replaced.
  SetVector<Operation *> forwardSlice;
  if (siblingFusionUser) {
    for (unsigned i = 0, e = forOp.getNumResults(); i != e; ++i) {
      SetVector<Operation *> tmpForwardSlice;
      getForwardSlice(forOp.getResult(i), &tmpForwardSlice);
      forwardSlice.set_union(tmpForwardSlice);
    }
  }
  // Update the results of the `forOp` in the new loop.
  for (unsigned i = 0, e = forOp.getNumResults(); i != e; ++i) {
    forOp.getResult(i).replaceAllUsesWith(
        newLoop.getResult(i + parentOp->getNumResults()));
  }
  // For sibling-fusion users, move operations that use the results of the
  // `forOp` outside the new parent loop.
  if (siblingFusionUser) {
    topologicalSort(forwardSlice);
    for (Operation *op : llvm::reverse(forwardSlice))
      op->moveAfter(newLoop);
  }
  // Replace the induction variable.
  auto iv = forOp.getInductionVar();
  iv.replaceAllUsesWith(newLoop.getInductionVar());
  // Replace the iter args.
  auto forOpIterArgs = forOp.getRegionIterArgs();
  for (auto it : llvm::zip(forOpIterArgs, newLoop.getRegionIterArgs().take_back(
                               forOpIterArgs.size()))) {
    std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
  }
  // Move the loop body operations, except for its terminator, to the loop's
  // containing block.
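  // At this point all uses of the single-iteration loop's induction variable,
  // iter args, and results have been rewired to the new parent loop, so the
  // body can be spliced out safely.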
  forOp.getBody()->back().erase();
  auto *parentBlock = forOp->getBlock();
  parentBlock->getOperations().splice(Block::iterator(forOp),
                                      forOp.getBody()->getOperations());
  forOp.erase();
  parentForOp.erase();
  return success();
}

/// Fuses 'srcForOp' into 'dstForOp' with destination loop block insertion
/// point and source slice loop bounds specified in 'srcSlice'.
void mlir::fuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
                     const ComputationSliceState &srcSlice,
                     bool isInnermostSiblingInsertion) {
  // Clone 'srcForOp' into 'dstForOp' at 'srcSlice->insertPoint'.
  OpBuilder b(srcSlice.insertPoint->getBlock(), srcSlice.insertPoint);
  BlockAndValueMapping mapper;
  b.clone(*srcForOp, mapper);

  // Update 'sliceLoopNest' upper and lower bounds from the computed
  // 'srcSlice'.
  SmallVector<AffineForOp, 4> sliceLoops;
  for (unsigned i = 0, e = srcSlice.ivs.size(); i < e; ++i) {
    auto loopIV = mapper.lookupOrNull(srcSlice.ivs[i]);
    if (!loopIV)
      continue;
    auto forOp = getForInductionVarOwner(loopIV);
    sliceLoops.push_back(forOp);
    if (AffineMap lbMap = srcSlice.lbs[i]) {
      auto lbOperands = srcSlice.lbOperands[i];
      canonicalizeMapAndOperands(&lbMap, &lbOperands);
      forOp.setLowerBound(lbOperands, lbMap);
    }
    if (AffineMap ubMap = srcSlice.ubs[i]) {
      auto ubOperands = srcSlice.ubOperands[i];
      canonicalizeMapAndOperands(&ubMap, &ubOperands);
      forOp.setUpperBound(ubOperands, ubMap);
    }
  }

  llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
  auto srcIsUnitSlice = [&]() {
    return (buildSliceTripCountMap(srcSlice, &sliceTripCountMap) &&
            (getSliceIterationCount(sliceTripCountMap) == 1));
  };
  // Fix up and, if possible, eliminate single iteration loops.
  for (AffineForOp forOp : sliceLoops) {
    if (isLoopParallelAndContainsReduction(forOp) &&
        isInnermostSiblingInsertion && srcIsUnitSlice())
      // Patch reduction loops - only ones that are sibling-fused with the
      // destination loop - into the parent loop.
      (void)promoteSingleIterReductionLoop(forOp, true);
    else
      // Promote any single iteration slice loops.
      (void)promoteIfSingleIteration(forOp);
  }
}

/// Collects loop nest statistics (e.g., loop trip count and operation count)
/// in 'stats' for the loop nest rooted at 'forOp'. Returns true on success,
/// returns false otherwise.
bool mlir::getLoopNestStats(AffineForOp forOpRoot, LoopNestStats *stats) {
  auto walkResult = forOpRoot.walk([&](AffineForOp forOp) {
    auto *childForOp = forOp.getOperation();
    auto *parentForOp = forOp->getParentOp();
    if (forOp != forOpRoot) {
      if (!isa<AffineForOp>(parentForOp)) {
        LLVM_DEBUG(llvm::dbgs() << "Expected parent AffineForOp\n");
        return WalkResult::interrupt();
      }
      // Add mapping to 'forOp' from its parent AffineForOp.
      stats->loopMap[parentForOp].push_back(forOp);
    }

    // Record the number of operations in the body of 'forOp'.
    unsigned count = 0;
    stats->opCountMap[childForOp] = 0;
    for (auto &op : *forOp.getBody()) {
      if (!isa<AffineForOp, AffineIfOp>(op))
        ++count;
    }
    stats->opCountMap[childForOp] = count;

    // Record the trip count for 'forOp'. Set a flag if the trip count is not
    // constant.
    Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
    if (!maybeConstTripCount) {
      // Currently only constant trip count loop nests are supported.
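      // For example, a loop with a symbolic bound such as
      // 'affine.for %i = 0 to %N' has no constant trip count, so stats
      // collection fails for the enclosing nest.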
      LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count unsupported\n");
      return WalkResult::interrupt();
    }

    stats->tripCountMap[childForOp] = *maybeConstTripCount;
    return WalkResult::advance();
  });
  return !walkResult.wasInterrupted();
}

// Computes the total cost of the loop nest rooted at 'forOp'.
// Currently, the total cost is computed by counting the total operation
// instance count (i.e. total number of operations in the loop body * loop
// trip count) for the entire loop nest.
// If 'tripCountOverrideMap' is non-null, it overrides the trip count for
// loops specified in the map when computing the total op instance count.
// NOTE: This is used to compute the cost of computation slices, which are
// sliced along the iteration dimension, and thus reduce the trip count.
// If 'computeCostMap' is non-null, the total op count for forOps specified
// in the map is increased (not overridden) by adding the op count from the
// map to the existing op count for the for loop. This is done before
// multiplying by the loop's trip count, and is used to model the cost of
// inserting a sliced loop nest of known cost into the loop's body.
// NOTE: This is also used to compute the cost of fusing a slice of some loop
// nest within another loop.
static int64_t getComputeCostHelper(
    Operation *forOp, LoopNestStats &stats,
    llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountOverrideMap,
    DenseMap<Operation *, int64_t> *computeCostMap) {
  // 'opCount' is the total number of operations in one iteration of the
  // 'forOp' body, minus the terminator op which is a no-op.
  int64_t opCount = stats.opCountMap[forOp] - 1;
  if (stats.loopMap.count(forOp) > 0) {
    for (auto childForOp : stats.loopMap[forOp]) {
      opCount += getComputeCostHelper(childForOp, stats, tripCountOverrideMap,
                                      computeCostMap);
    }
  }
  // Add in additional op instances from the slice (if specified in the map).
  if (computeCostMap != nullptr) {
    auto it = computeCostMap->find(forOp);
    if (it != computeCostMap->end()) {
      opCount += it->second;
    }
  }
  // Override the trip count (if specified in the map).
  int64_t tripCount = stats.tripCountMap[forOp];
  if (tripCountOverrideMap != nullptr) {
    auto it = tripCountOverrideMap->find(forOp);
    if (it != tripCountOverrideMap->end()) {
      tripCount = it->second;
    }
  }
  // Returns the total number of dynamic instances of operations in the loop
  // body.
  return tripCount * opCount;
}

/// Computes the total cost of the loop nest rooted at 'forOp' using 'stats'.
/// Currently, the total cost is computed by counting the total operation
/// instance count (i.e. total number of operations in the loop body * loop
/// trip count) for the entire loop nest.
int64_t mlir::getComputeCost(AffineForOp forOp, LoopNestStats &stats) {
  return getComputeCostHelper(forOp, stats,
                              /*tripCountOverrideMap=*/nullptr,
                              /*computeCostMap=*/nullptr);
}

/// Computes and returns in 'computeCost', the total compute cost of fusing the
/// 'slice' of the loop nest rooted at 'srcForOp' into 'dstForOp'. Currently,
/// the total cost is computed by counting the total operation instance count
/// (i.e. total number of operations in the loop body * loop trip count) for
/// the entire loop nest.
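/// For example (illustrative): a nest 'affine.for %i = 0 to 100' enclosing
/// 'affine.for %j = 0 to 50' whose body holds 4 non-terminator operations
/// yields a cost of 100 * 50 * 4 = 20000 operation instances under this
/// model.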
bool mlir::getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats,
                                AffineForOp dstForOp, LoopNestStats &dstStats,
                                const ComputationSliceState &slice,
                                int64_t *computeCost) {
  llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
  DenseMap<Operation *, int64_t> computeCostMap;

  // Build the trip count map for the computation slice.
  if (!buildSliceTripCountMap(slice, &sliceTripCountMap))
    return false;
  // Check whether store-to-load forwarding will happen.
  int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap);
  assert(sliceIterationCount > 0);
  bool storeLoadFwdGuaranteed = (sliceIterationCount == 1);
  auto *insertPointParent = slice.insertPoint->getParentOp();

  // The stores and loads to these memrefs will disappear.
  // TODO: Add load coalescing to memref data flow opt pass.
  if (storeLoadFwdGuaranteed) {
    // Subtract from the operation count the loads/stores we expect load/store
    // forwarding to remove.
    unsigned storeCount = 0;
    llvm::SmallDenseSet<Value, 4> storeMemrefs;
    srcForOp.walk([&](Operation *op) {
      if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) {
        storeMemrefs.insert(storeOp.getMemRef());
        ++storeCount;
      }
    });
    // Subtract out any store ops in the single-iteration src slice loop nest.
    if (storeCount > 0)
      computeCostMap[insertPointParent] = -storeCount;
    // Subtract out any load users of 'storeMemrefs' nested below
    // 'insertPointParent'.
    for (Value memref : storeMemrefs) {
      for (auto *user : memref.getUsers()) {
        if (auto loadOp = dyn_cast<AffineReadOpInterface>(user)) {
          SmallVector<AffineForOp, 4> loops;
          // Check if any loop in the loop nest surrounding 'user' is
          // 'insertPointParent'.
          getAffineForIVs(*user, &loops);
          if (llvm::is_contained(loops, cast<AffineForOp>(insertPointParent))) {
            if (auto forOp =
                    dyn_cast_or_null<AffineForOp>(user->getParentOp())) {
              if (computeCostMap.count(forOp) == 0)
                computeCostMap[forOp] = 0;
              computeCostMap[forOp] -= 1;
            }
          }
        }
      }
    }
  }

  // Compute the op instance count for the src loop nest with iteration
  // slicing.
  int64_t sliceComputeCost = getComputeCostHelper(
      srcForOp, srcStats, &sliceTripCountMap, &computeCostMap);

  // Compute the cost of fusion for this depth.
  computeCostMap[insertPointParent] = sliceComputeCost;

  *computeCost =
      getComputeCostHelper(dstForOp, dstStats,
                           /*tripCountOverrideMap=*/nullptr, &computeCostMap);
  return true;
}

/// Returns in 'producerConsumerMemrefs' the memrefs involved in a
/// producer-consumer dependence between write ops in 'srcOps' and read ops in
/// 'dstOps'.
void mlir::gatherProducerConsumerMemrefs(
    ArrayRef<Operation *> srcOps, ArrayRef<Operation *> dstOps,
    DenseSet<Value> &producerConsumerMemrefs) {
  // Gather memrefs from stores in 'srcOps'.
  DenseSet<Value> srcStoreMemRefs;
  for (Operation *op : srcOps)
    if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op))
      srcStoreMemRefs.insert(storeOp.getMemRef());

  // Compute the intersection between memrefs from stores in 'srcOps' and
  // memrefs from loads in 'dstOps'.
  for (Operation *op : dstOps)
    if (auto loadOp = dyn_cast<AffineReadOpInterface>(op))
      if (srcStoreMemRefs.count(loadOp.getMemRef()) > 0)
        producerConsumerMemrefs.insert(loadOp.getMemRef());
}