//===- Hoisting.cpp - Linalg hoisting transformations ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with hoisting invariant operations
// in the context of Linalg transformations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"

using llvm::dbgs;

#define DEBUG_TYPE "linalg-hoisting"

#define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")

using namespace mlir;
using namespace mlir::linalg;

/// Replace `loop` with a new loop that has a different init operand at
/// position `index`. The body of this loop is moved over to the new loop.
48 /// 49 /// `newInitOperands` specifies the replacement "init" operands. 50 /// `newYieldValue` is the replacement yield value of the loop at position 51 /// `index`. 52 static scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter, 53 scf::ForOp loop, 54 Value newInitOperand, 55 unsigned index, 56 Value newYieldValue) { 57 OpBuilder::InsertionGuard g(rewriter); 58 rewriter.setInsertionPoint(loop.getOperation()); 59 auto inits = llvm::to_vector(loop.getInits()); 60 61 // Replace the init value with the new operand. 62 assert(index < inits.size()); 63 inits[index] = newInitOperand; 64 65 scf::ForOp newLoop = rewriter.create<scf::ForOp>( 66 loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(), 67 inits, [](OpBuilder &, Location, Value, ValueRange) {}); 68 69 // Generate the new yield with the replaced operand. 70 auto yieldOp = cast<scf::YieldOp>(loop.getBody()->getTerminator()); 71 yieldOp.setOperand(index, newYieldValue); 72 73 // Move the loop body to the new op. 74 rewriter.mergeBlocks(loop.getBody(), newLoop.getBody(), 75 newLoop.getBody()->getArguments()); 76 77 // Replace the old loop. 78 rewriter.replaceOp(loop.getOperation(), newLoop->getResults()); 79 return newLoop; 80 } 81 82 // Hoist out a pair of corresponding vector.extract+vector.broadcast 83 // operations. 
This function transforms a loop like this: 84 // %res = scf.for _ = _ to _ step _ iter_args(%iarg = %v) -> (t1) { 85 // %e = vector.extract %iarg : t1 to t2 86 // %u = "some_use"(%e) : (t2) -> t2 87 // %b = vector.broadcast %u : t2 to t1 88 // scf.yield %b : t1 89 // } 90 // into the following: 91 // %e = vector.extract %v: t1 to t2 92 // %res' = scf.for _ = _ to _ step _ iter_args(%iarg = %e) -> (t2) { 93 // %u' = "some_use"(%iarg) : (t2) -> t2 94 // scf.yield %u' : t2 95 // } 96 // %res = vector.broadcast %res' : t2 to t1 97 void mlir::linalg::hoistRedundantVectorBroadcasts(RewriterBase &rewriter, 98 Operation *root) { 99 bool changed = true; 100 while (changed) { 101 changed = false; 102 // First move loop invariant ops outside of their loop. This needs to be 103 // done before as we cannot move ops without interrupting the function walk. 104 root->walk( 105 [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); }); 106 107 root->walk([&](vector::ExtractOp extractOp) { 108 LLVM_DEBUG(DBGS() << "Candidate for hoisting: " 109 << *extractOp.getOperation() << "\n"); 110 111 auto loop = dyn_cast<scf::ForOp>(extractOp->getParentOp()); 112 if (!loop) 113 return WalkResult::advance(); 114 115 // Check that the vector to extract from is a BlockArgument. 116 auto blockArg = dyn_cast<BlockArgument>(extractOp.getVector()); 117 if (!blockArg) 118 return WalkResult::advance(); 119 120 // Check that the blockArg is an iter_arg of the loop. 121 OpOperand *initArg = loop.getTiedLoopInit(blockArg); 122 if (!initArg) 123 return WalkResult::advance(); 124 125 // If the iter_arg does not have only one use, it won't be possible to 126 // hoist the extractOp out. 127 if (!blockArg.hasOneUse()) 128 return WalkResult::advance(); 129 130 unsigned index = blockArg.getArgNumber() - loop.getNumInductionVars(); 131 132 // Check that the loop yields a broadcast that has just one use. 
133 Operation *yieldedVal = 134 loop.getTiedLoopYieldedValue(blockArg)->get().getDefiningOp(); 135 auto broadcast = dyn_cast<vector::BroadcastOp>(yieldedVal); 136 if (!broadcast || !broadcast.getResult().hasOneUse()) 137 return WalkResult::advance(); 138 139 LLVM_DEBUG(DBGS() << "Candidate broadcast: " << broadcast << "\n"); 140 141 Type broadcastInputType = broadcast.getSourceType(); 142 if (broadcastInputType != extractOp.getType()) 143 return WalkResult::advance(); 144 145 // The position of the extract must be defined outside of the loop if 146 // it is dynamic. 147 for (auto operand : extractOp.getDynamicPosition()) 148 if (!loop.isDefinedOutsideOfLoop(operand)) 149 return WalkResult::advance(); 150 151 rewriter.modifyOpInPlace(broadcast, [&] { 152 extractOp.getVectorMutable().assign(initArg->get()); 153 }); 154 loop.moveOutOfLoop(extractOp); 155 rewriter.moveOpAfter(broadcast, loop); 156 157 scf::ForOp newLoop = replaceWithDifferentYield( 158 rewriter, loop, extractOp.getResult(), index, broadcast.getSource()); 159 160 LLVM_DEBUG(DBGS() << "New loop: " << newLoop << "\n"); 161 162 rewriter.replaceAllUsesWith(newLoop.getResult(index), broadcast); 163 rewriter.modifyOpInPlace( 164 broadcast, [&] { broadcast.setOperand(newLoop.getResult(index)); }); 165 166 changed = true; 167 return WalkResult::interrupt(); 168 }); 169 } 170 } 171 172 static bool noAliasingUseInLoop(vector::TransferReadOp transferRead, 173 LoopLikeOpInterface loop) { 174 Value source = transferRead.getSource(); 175 176 // Skip view-like Ops and retrive the actual soruce Operation 177 while (auto srcOp = 178 dyn_cast_or_null<ViewLikeOpInterface>(source.getDefiningOp())) 179 source = srcOp.getViewSource(); 180 181 llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(), 182 source.getUsers().end()); 183 llvm::SmallDenseSet<Operation *, 32> processed; 184 while (!users.empty()) { 185 Operation *user = users.pop_back_val(); 186 // If the user has already been processed skip. 
187 if (!processed.insert(user).second) 188 continue; 189 if (auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) { 190 users.append(viewLike->getUsers().begin(), viewLike->getUsers().end()); 191 continue; 192 } 193 if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user)) 194 continue; 195 if (!loop->isAncestor(user)) 196 continue; 197 return false; 198 } 199 return true; 200 } 201 202 void mlir::linalg::hoistRedundantVectorTransfers(Operation *root, 203 bool verifyNonZeroTrip) { 204 bool changed = true; 205 while (changed) { 206 changed = false; 207 // First move loop invariant ops outside of their loop. This needs to be 208 // done before as we cannot move ops without interrupting the function walk. 209 root->walk( 210 [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); }); 211 212 // Find all loops that are certain to have non zero trip count. Any loops 213 // that are not part of this set cannot be hoisted from, since hoisting from 214 // a potentially zero trip count loop may cause a vector transfer to be 215 // executed when it shouldn't be. 216 llvm::DenseSet<LoopLikeOpInterface> definiteNonZeroTripCountLoops; 217 if (verifyNonZeroTrip) { 218 root->walk([&](LoopLikeOpInterface loopLike) { 219 std::optional<SmallVector<OpFoldResult>> lbs = 220 loopLike.getLoopLowerBounds(); 221 std::optional<SmallVector<OpFoldResult>> ubs = 222 loopLike.getLoopUpperBounds(); 223 // If loop bounds cannot be found, assume possibly zero trip count. 224 if (!lbs || !ubs) 225 return; 226 227 // Otherwise, use ValueBounds to find the maximum lower bound and 228 // minimum upper bound. If the bounds are found, and maxLb is less 229 // than the minUb, then the loop will not have zero trip count. 
230 for (auto [lb, ub] : llvm::zip_equal(lbs.value(), ubs.value())) { 231 FailureOr<int64_t> maxLb = 232 ValueBoundsConstraintSet::computeConstantBound( 233 presburger::BoundType::UB, lb, 234 /*stopCondition=*/nullptr, /*closedUB=*/true); 235 if (failed(maxLb)) 236 return; 237 FailureOr<int64_t> minUb = 238 ValueBoundsConstraintSet::computeConstantBound( 239 presburger::BoundType::LB, ub); 240 if (failed(minUb)) 241 return; 242 if (minUb.value() <= maxLb.value()) 243 return; 244 definiteNonZeroTripCountLoops.insert(loopLike); 245 } 246 }); 247 } 248 249 root->walk([&](vector::TransferReadOp transferRead) { 250 if (!isa<MemRefType>(transferRead.getShapedType())) 251 return WalkResult::advance(); 252 253 LLVM_DEBUG(DBGS() << "Candidate for hoisting: " 254 << *transferRead.getOperation() << "\n"); 255 auto loop = dyn_cast<LoopLikeOpInterface>(transferRead->getParentOp()); 256 LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp() 257 << "\n"); 258 if (!isa_and_nonnull<scf::ForOp, affine::AffineForOp>(loop)) 259 return WalkResult::advance(); 260 261 if (verifyNonZeroTrip && !definiteNonZeroTripCountLoops.contains(loop)) { 262 LLVM_DEBUG(DBGS() << "Loop may have zero trip count: " << *loop 263 << "\n"); 264 return WalkResult::advance(); 265 } 266 267 LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation() 268 << "\n"); 269 270 SetVector<Operation *> forwardSlice; 271 getForwardSlice(transferRead.getOperation(), &forwardSlice); 272 273 // Look for the last TransferWriteOp in the forwardSlice of 274 // `transferRead` that operates on the same memref. 275 vector::TransferWriteOp transferWrite; 276 for (auto *sliceOp : llvm::reverse(forwardSlice)) { 277 auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp); 278 if (!candidateWrite || 279 candidateWrite.getSource() != transferRead.getSource()) 280 continue; 281 transferWrite = candidateWrite; 282 } 283 284 // All operands of the TransferRead must be defined outside of the loop. 
285 for (auto operand : transferRead.getOperands()) 286 if (!loop.isDefinedOutsideOfLoop(operand)) 287 return WalkResult::advance(); 288 289 // Only hoist transfer_read / transfer_write pairs and singleton 290 // transfer_reads for now. 291 if (!transferWrite) { 292 // Make sure there are no other accesses to the memref before 293 // hoisting transfer_read. 294 if (noAliasingUseInLoop(transferRead, loop)) 295 loop.moveOutOfLoop(transferRead); 296 return WalkResult::advance(); 297 } 298 299 LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation() 300 << "\n"); 301 302 // Approximate aliasing by checking that: 303 // 1. indices, vector type and permutation map are the same (i.e., the 304 // transfer_read/transfer_write ops are matching), 305 // 2. source operands for transfer.{read|write} do not originate from 306 // Ops implementing ViewLikeOpInterface. 307 // 3. no other operations in the loop access the same memref except 308 // for transfer_read/transfer_write accessing statically disjoint 309 // slices. 310 if (transferRead.getIndices() != transferWrite.getIndices() || 311 transferRead.getVectorType() != transferWrite.getVectorType() || 312 transferRead.getPermutationMap() != transferWrite.getPermutationMap()) 313 return WalkResult::advance(); 314 315 auto *source = transferRead.getSource().getDefiningOp(); 316 if (source && isa_and_nonnull<ViewLikeOpInterface>(source)) 317 return WalkResult::advance(); 318 319 source = transferWrite.getSource().getDefiningOp(); 320 if (source && isa_and_nonnull<ViewLikeOpInterface>(source)) 321 return WalkResult::advance(); 322 323 // TODO: may want to memoize this information for performance but it 324 // likely gets invalidated often. 
325 DominanceInfo dom(loop); 326 if (!dom.properlyDominates(transferRead.getOperation(), transferWrite)) 327 return WalkResult::advance(); 328 for (auto &use : transferRead.getSource().getUses()) { 329 if (!loop->isAncestor(use.getOwner())) 330 continue; 331 if (use.getOwner() == transferRead.getOperation() || 332 use.getOwner() == transferWrite.getOperation()) 333 continue; 334 if (auto transferWriteUse = 335 dyn_cast<vector::TransferWriteOp>(use.getOwner())) { 336 if (!vector::isDisjointTransferSet( 337 cast<VectorTransferOpInterface>(*transferWrite), 338 cast<VectorTransferOpInterface>(*transferWriteUse), 339 /*testDynamicValueUsingBounds=*/true)) 340 return WalkResult::advance(); 341 } else if (auto transferReadUse = 342 dyn_cast<vector::TransferReadOp>(use.getOwner())) { 343 if (!vector::isDisjointTransferSet( 344 cast<VectorTransferOpInterface>(*transferWrite), 345 cast<VectorTransferOpInterface>(*transferReadUse), 346 /*testDynamicValueUsingBounds=*/true)) 347 return WalkResult::advance(); 348 } else { 349 // Unknown use, we cannot prove that it doesn't alias with the 350 // transferRead/transferWrite operations. 351 return WalkResult::advance(); 352 } 353 } 354 355 // Hoist read before. 356 loop.moveOutOfLoop(transferRead); 357 358 // Hoist write after. 359 transferWrite->moveAfter(loop); 360 361 // Rewrite `loop` with new yields by cloning and erase the original loop. 
362 IRRewriter rewriter(transferRead.getContext()); 363 NewYieldValuesFn yieldFn = [&](OpBuilder &b, Location loc, 364 ArrayRef<BlockArgument> newBBArgs) { 365 return SmallVector<Value>{transferWrite.getVector()}; 366 }; 367 368 auto maybeNewLoop = loop.replaceWithAdditionalYields( 369 rewriter, transferRead.getVector(), 370 /*replaceInitOperandUsesInLoop=*/true, yieldFn); 371 if (failed(maybeNewLoop)) 372 return WalkResult::interrupt(); 373 374 transferWrite.getVectorMutable().assign( 375 maybeNewLoop->getOperation()->getResults().back()); 376 changed = true; 377 // Need to interrupt and restart because erasing the loop messes up 378 // the walk. 379 return WalkResult::interrupt(); 380 }); 381 } 382 } 383