1 //===- LoopInvariantCodeMotionUtils.cpp - LICM Utils ------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains the implementation of the core LICM algorithm. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" 14 15 #include "mlir/IR/Operation.h" 16 #include "mlir/IR/PatternMatch.h" 17 #include "mlir/Interfaces/LoopLikeInterface.h" 18 #include "mlir/Interfaces/SideEffectInterfaces.h" 19 #include "mlir/Interfaces/SubsetOpInterface.h" 20 #include "llvm/Support/Debug.h" 21 #include <queue> 22 23 #define DEBUG_TYPE "licm" 24 25 using namespace mlir; 26 27 /// Checks whether the given op can be hoisted by checking that 28 /// - the op and none of its contained operations depend on values inside of the 29 /// loop (by means of calling definedOutside). 30 /// - the op has no side-effects. 31 static bool canBeHoisted(Operation *op, 32 function_ref<bool(OpOperand &)> condition) { 33 // Do not move terminators. 34 if (op->hasTrait<OpTrait::IsTerminator>()) 35 return false; 36 37 // Walk the nested operations and check that all used values are either 38 // defined outside of the loop or in a nested region, but not at the level of 39 // the loop body. 40 auto walkFn = [&](Operation *child) { 41 for (OpOperand &operand : child->getOpOperands()) { 42 // Ignore values defined in a nested region. 43 if (op->isAncestor(operand.get().getParentRegion()->getParentOp())) 44 continue; 45 if (!condition(operand)) 46 return WalkResult::interrupt(); 47 } 48 return WalkResult::advance(); 49 }; 50 return !op->walk(walkFn).wasInterrupted(); 51 } 52 53 static bool canBeHoisted(Operation *op, 54 function_ref<bool(Value)> definedOutside) { 55 return canBeHoisted( 56 op, [&](OpOperand &operand) { return definedOutside(operand.get()); }); 57 } 58 59 size_t mlir::moveLoopInvariantCode( 60 ArrayRef<Region *> regions, 61 function_ref<bool(Value, Region *)> isDefinedOutsideRegion, 62 function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion, 63 function_ref<void(Operation *, Region *)> moveOutOfRegion) { 64 size_t numMoved = 0; 65 66 for (Region *region : regions) { 67 LLVM_DEBUG(llvm::dbgs() << "Original loop:\n" 68 << *region->getParentOp() << "\n"); 69 70 std::queue<Operation *> worklist; 71 // Add top-level operations in the loop body to the worklist. 72 for (Operation &op : region->getOps()) 73 worklist.push(&op); 74 75 auto definedOutside = [&](Value value) { 76 return isDefinedOutsideRegion(value, region); 77 }; 78 79 while (!worklist.empty()) { 80 Operation *op = worklist.front(); 81 worklist.pop(); 82 // Skip ops that have already been moved. Check if the op can be hoisted. 83 if (op->getParentRegion() != region) 84 continue; 85 86 LLVM_DEBUG(llvm::dbgs() << "Checking op: " << *op << "\n"); 87 if (!shouldMoveOutOfRegion(op, region) || 88 !canBeHoisted(op, definedOutside)) 89 continue; 90 91 LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *op << "\n"); 92 moveOutOfRegion(op, region); 93 ++numMoved; 94 95 // Since the op has been moved, we need to check its users within the 96 // top-level of the loop body. 97 for (Operation *user : op->getUsers()) 98 if (user->getParentRegion() == region) 99 worklist.push(user); 100 } 101 } 102 103 return numMoved; 104 } 105 106 size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) { 107 return moveLoopInvariantCode( 108 loopLike.getLoopRegions(), 109 [&](Value value, Region *) { 110 return loopLike.isDefinedOutsideOfLoop(value); 111 }, 112 [&](Operation *op, Region *) { 113 return isMemoryEffectFree(op) && isSpeculatable(op); 114 }, 115 [&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); }); 116 } 117 118 namespace { 119 /// Helper data structure that keeps track of equivalent/disjoint subset ops. 120 class MatchingSubsets { 121 public: 122 /// Insert a subset op. 123 void insert(SubsetOpInterface op, bool collectHoistableOps = true) { 124 allSubsetOps.push_back(op); 125 if (!collectHoistableOps) 126 return; 127 if (auto extractionOp = 128 dyn_cast<SubsetExtractionOpInterface>(op.getOperation())) 129 insertExtractionOp(extractionOp); 130 if (auto insertionOp = 131 dyn_cast<SubsetInsertionOpInterface>(op.getOperation())) 132 insertInsertionOp(insertionOp); 133 } 134 135 /// Return a range of matching extraction-insertion subset ops. If there is no 136 /// matching extraction/insertion op, the respective value is empty. Ops are 137 /// skipped if there are other subset ops that are not guaranteed to operate 138 /// on disjoint subsets. 139 auto getHoistableSubsetOps() { 140 return llvm::make_filter_range( 141 llvm::zip(extractions, insertions), [&](auto pair) { 142 auto [extractionOp, insertionOp] = pair; 143 // Hoist only if the extracted and inserted values have the same type. 144 if (extractionOp && insertionOp && 145 extractionOp->getResult(0).getType() != 146 insertionOp.getSourceOperand().get().getType()) 147 return false; 148 // Hoist only if there are no conflicting subset ops. 149 return allDisjoint(extractionOp, insertionOp); 150 }); 151 } 152 153 /// Populate subset ops starting from the given region iter_arg. Return 154 /// "failure" if non-subset ops are found along the path to the loop yielding 155 /// op or if there is no single path to the tied yielded operand. If 156 /// `collectHoistableOps` is set to "false", subset ops are gathered 157 /// throughout the traversal, but not enumerated by `getHoistableSubsetOps`. 158 LogicalResult populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike, 159 BlockArgument iterArg, 160 bool collectHoistableOps = true); 161 162 private: 163 /// Helper function for equivalence of tensor values. Since only insertion 164 /// subset ops (that are also destination style ops) are followed when 165 /// traversing the SSA use-def chain, all tensor values are equivalent. 166 static bool isEquivalent(Value v1, Value v2) { return true; } 167 168 /// Return "true" if the subsets of the given extraction and insertion ops 169 /// are operating disjoint from the subsets that all other known subset ops 170 /// are operating on. 171 bool allDisjoint(SubsetExtractionOpInterface extractionOp, 172 SubsetInsertionOpInterface insertionOp) const { 173 for (SubsetOpInterface other : allSubsetOps) { 174 if (other == extractionOp || other == insertionOp) 175 continue; 176 if (extractionOp && 177 !other.operatesOnDisjointSubset(extractionOp, isEquivalent)) 178 return false; 179 if (insertionOp && 180 !other.operatesOnDisjointSubset(insertionOp, isEquivalent)) 181 return false; 182 } 183 return true; 184 } 185 186 /// Insert a subset extraction op. If the subset is equivalent to an existing 187 /// subset insertion op, pair them up. (If there is already a paired up subset 188 /// extraction op, overwrite the subset extraction op.) 189 void insertExtractionOp(SubsetExtractionOpInterface extractionOp) { 190 for (auto it : llvm::enumerate(insertions)) { 191 if (!it.value()) 192 continue; 193 auto other = cast<SubsetOpInterface>(it.value().getOperation()); 194 if (other.operatesOnEquivalentSubset(extractionOp, isEquivalent)) { 195 extractions[it.index()] = extractionOp; 196 return; 197 } 198 } 199 // There is no known equivalent insertion op. Create a new entry. 200 extractions.push_back(extractionOp); 201 insertions.push_back({}); 202 } 203 204 /// Insert a subset insertion op. If the subset is equivalent to an existing 205 /// subset extraction op, pair them up. (If there is already a paired up 206 /// subset insertion op, overwrite the subset insertion op.) 207 void insertInsertionOp(SubsetInsertionOpInterface insertionOp) { 208 for (auto it : llvm::enumerate(extractions)) { 209 if (!it.value()) 210 continue; 211 auto other = cast<SubsetOpInterface>(it.value().getOperation()); 212 if (other.operatesOnEquivalentSubset(insertionOp, isEquivalent)) { 213 insertions[it.index()] = insertionOp; 214 return; 215 } 216 } 217 // There is no known equivalent extraction op. Create a new entry. 218 extractions.push_back({}); 219 insertions.push_back(insertionOp); 220 } 221 222 SmallVector<SubsetExtractionOpInterface> extractions; 223 SmallVector<SubsetInsertionOpInterface> insertions; 224 SmallVector<SubsetOpInterface> allSubsetOps; 225 }; 226 } // namespace 227 228 /// If the given value has a single use by an op that is a terminator, return 229 /// that use. Otherwise, return nullptr. 230 static OpOperand *getSingleTerminatorUse(Value value) { 231 if (!value.hasOneUse()) 232 return nullptr; 233 OpOperand &use = *value.getUses().begin(); 234 if (use.getOwner()->hasTrait<OpTrait::IsTerminator>()) 235 return &use; 236 return nullptr; 237 } 238 239 LogicalResult 240 MatchingSubsets::populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike, 241 BlockArgument iterArg, 242 bool collectHoistableOps) { 243 assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg"); 244 Value value = iterArg; 245 246 // Traverse use-def chain. Subset ops can be hoisted only if all ops along the 247 // use-def chain starting from the region iter_arg are subset extraction or 248 // subset insertion ops. The chain must terminate at the corresponding yield 249 // operand (e.g., no swapping of iter_args). 250 OpOperand *yieldedOperand = nullptr; 251 // Iterate until the single use of the current SSA value is a terminator, 252 // which is expected to be the yielding operation of the loop. 253 while (!(yieldedOperand = getSingleTerminatorUse(value))) { 254 Value nextValue = {}; 255 256 for (OpOperand &use : value.getUses()) { 257 if (auto nestedLoop = dyn_cast<LoopLikeOpInterface>(use.getOwner())) { 258 // Subset ops in nested loops are collected to check if there are only 259 // disjoint subset ops, but such subset ops are not subject to hoisting. 260 // To hoist subset ops from nested loops, the hoisting transformation 261 // should be run on the nested loop. 262 auto nestedIterArg = nestedLoop.getTiedLoopRegionIterArg(&use); 263 if (!nestedIterArg) 264 return failure(); 265 // Note: `populateSubsetOpsAtIterArg` fails if there is no single SSA 266 // use-def chain starting at `nestedIterArg` and terminating in the 267 // tied, yielding operand. 268 if (failed(populateSubsetOpsAtIterArg(nestedLoop, nestedIterArg, 269 /*collectHoistableOps=*/false))) 270 return failure(); 271 nextValue = nestedLoop.getTiedLoopResult(&use); 272 continue; 273 } 274 275 auto subsetOp = dyn_cast<SubsetOpInterface>(use.getOwner()); 276 if (!subsetOp) 277 return failure(); 278 insert(subsetOp); 279 280 if (auto insertionOp = 281 dyn_cast<SubsetInsertionOpInterface>(use.getOwner())) { 282 // The value must be used as a destination. (In case of a source, the 283 // entire tensor would be read, which would prevent any hoisting.) 284 if (&use != &insertionOp.getDestinationOperand()) 285 return failure(); 286 // There must be a single use-def chain from the region iter_arg to the 287 // terminator. I.e., only one insertion op. Branches are not supported. 288 if (nextValue) 289 return failure(); 290 nextValue = insertionOp.getUpdatedDestination(); 291 } 292 } 293 294 // Nothing can be hoisted if the chain does not continue with loop yielding 295 // op or a subset insertion op. 296 if (!nextValue) 297 return failure(); 298 value = nextValue; 299 } 300 301 // Hoist only if the SSA use-def chain ends in the yielding terminator of the 302 // loop and the yielded value is the `idx`-th operand. (I.e., there is no 303 // swapping yield.) 304 if (loopLike.getTiedLoopYieldedValue(iterArg) != yieldedOperand) 305 return failure(); 306 307 return success(); 308 } 309 310 /// Hoist all subset ops that operate on the idx-th region iter_arg of the given 311 /// loop-like op and index into loop-invariant subset locations. Return the 312 /// newly created loop op (that has extra iter_args) or the original loop op if 313 /// nothing was hoisted. 314 static LoopLikeOpInterface hoistSubsetAtIterArg(RewriterBase &rewriter, 315 LoopLikeOpInterface loopLike, 316 BlockArgument iterArg) { 317 assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg"); 318 auto it = llvm::find(loopLike.getRegionIterArgs(), iterArg); 319 int64_t iterArgIdx = std::distance(loopLike.getRegionIterArgs().begin(), it); 320 MatchingSubsets subsets; 321 if (failed(subsets.populateSubsetOpsAtIterArg(loopLike, iterArg))) 322 return loopLike; 323 324 // Hoist all matching extraction-insertion pairs one-by-one. 325 for (auto it : subsets.getHoistableSubsetOps()) { 326 auto extractionOp = std::get<0>(it); 327 auto insertionOp = std::get<1>(it); 328 329 // Ops cannot be hoisted if they depend on loop-variant values. 330 if (extractionOp) { 331 if (!canBeHoisted(extractionOp, [&](OpOperand &operand) { 332 return loopLike.isDefinedOutsideOfLoop(operand.get()) || 333 &operand == &extractionOp.getSourceOperand(); 334 })) 335 extractionOp = {}; 336 } 337 if (insertionOp) { 338 if (!canBeHoisted(insertionOp, [&](OpOperand &operand) { 339 return loopLike.isDefinedOutsideOfLoop(operand.get()) || 340 &operand == &insertionOp.getSourceOperand() || 341 &operand == &insertionOp.getDestinationOperand(); 342 })) 343 insertionOp = {}; 344 } 345 346 // Only hoist extraction-insertion pairs for now. Standalone extractions/ 347 // insertions that are loop-invariant could be hoisted, but there may be 348 // easier ways to canonicalize the IR. 349 if (extractionOp && insertionOp) { 350 // Create a new loop with an additional iter_arg. 351 NewYieldValuesFn newYieldValuesFn = 352 [&](OpBuilder &b, Location loc, 353 ArrayRef<BlockArgument> innerNewBBArgs) -> SmallVector<Value> { 354 return {insertionOp.getSourceOperand().get()}; 355 }; 356 FailureOr<LoopLikeOpInterface> newLoop = 357 loopLike.replaceWithAdditionalYields( 358 rewriter, extractionOp.getResult(), 359 /*replaceInitOperandUsesInLoop=*/true, newYieldValuesFn); 360 if (failed(newLoop)) 361 return loopLike; 362 loopLike = *newLoop; 363 364 // Hoist the extraction/insertion ops. 365 iterArg = loopLike.getRegionIterArgs()[iterArgIdx]; 366 OpResult loopResult = loopLike.getTiedLoopResult(iterArg); 367 OpResult newLoopResult = loopLike.getLoopResults()->back(); 368 rewriter.moveOpBefore(extractionOp, loopLike); 369 rewriter.moveOpAfter(insertionOp, loopLike); 370 rewriter.replaceAllUsesWith(insertionOp.getUpdatedDestination(), 371 insertionOp.getDestinationOperand().get()); 372 extractionOp.getSourceOperand().set( 373 loopLike.getTiedLoopInit(iterArg)->get()); 374 rewriter.replaceAllUsesWith(loopResult, 375 insertionOp.getUpdatedDestination()); 376 insertionOp.getSourceOperand().set(newLoopResult); 377 insertionOp.getDestinationOperand().set(loopResult); 378 } 379 } 380 381 return loopLike; 382 } 383 384 LoopLikeOpInterface 385 mlir::hoistLoopInvariantSubsets(RewriterBase &rewriter, 386 LoopLikeOpInterface loopLike) { 387 // Note: As subset ops are getting hoisted, the number of region iter_args 388 // increases. This can enable further hoisting opportunities on the new 389 // iter_args. 390 for (int64_t i = 0; 391 i < static_cast<int64_t>(loopLike.getRegionIterArgs().size()); ++i) { 392 loopLike = hoistSubsetAtIterArg(rewriter, loopLike, 393 loopLike.getRegionIterArgs()[i]); 394 } 395 return loopLike; 396 } 397