xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp (revision 2bff9d9ffe3a4813961c1cf3af2e9ac5a20190bd)
1 //===- Hoisting.cpp - Linalg hoisting transformations ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements functions concerned with hoisting invariant operations
10 // in the context of Linalg transformations.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
15 #include "mlir/Analysis/SliceAnalysis.h"
16 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
17 #include "mlir/Dialect/Affine/IR/AffineOps.h"
18 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
19 #include "mlir/Dialect/Affine/Utils.h"
20 #include "mlir/Dialect/Arith/IR/Arith.h"
21 #include "mlir/Dialect/Func/IR/FuncOps.h"
22 #include "mlir/Dialect/Linalg/IR/Linalg.h"
23 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
24 #include "mlir/Dialect/SCF/IR/SCF.h"
25 #include "mlir/Dialect/SCF/Utils/Utils.h"
26 #include "mlir/Dialect/Tensor/IR/Tensor.h"
27 #include "mlir/Dialect/Vector/IR/VectorOps.h"
28 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
29 #include "mlir/IR/BuiltinOps.h"
30 #include "mlir/IR/Dominance.h"
31 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
32 #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
33 #include "llvm/ADT/StringRef.h"
34 #include "llvm/ADT/TypeSwitch.h"
35 #include "llvm/Support/Debug.h"
36 
37 using llvm::dbgs;
38 
39 #define DEBUG_TYPE "linalg-hoisting"
40 
41 #define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")
42 
43 using namespace mlir;
44 using namespace mlir::linalg;
45 
46 /// Replace `loop` with a new loop that has a different init operand at
47 /// position `index`. The body of this loop is moved over to the new loop.
48 ///
49 /// `newInitOperands` specifies the replacement "init" operands.
50 /// `newYieldValue` is the replacement yield value of the loop at position
51 /// `index`.
52 static scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter,
53                                             scf::ForOp loop,
54                                             Value newInitOperand,
55                                             unsigned index,
56                                             Value newYieldValue) {
57   OpBuilder::InsertionGuard g(rewriter);
58   rewriter.setInsertionPoint(loop.getOperation());
59   auto inits = llvm::to_vector(loop.getInits());
60 
61   // Replace the init value with the new operand.
62   assert(index < inits.size());
63   inits[index] = newInitOperand;
64 
65   scf::ForOp newLoop = rewriter.create<scf::ForOp>(
66       loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
67       inits, [](OpBuilder &, Location, Value, ValueRange) {});
68 
69   // Generate the new yield with the replaced operand.
70   auto yieldOp = cast<scf::YieldOp>(loop.getBody()->getTerminator());
71   yieldOp.setOperand(index, newYieldValue);
72 
73   // Move the loop body to the new op.
74   rewriter.mergeBlocks(loop.getBody(), newLoop.getBody(),
75                        newLoop.getBody()->getArguments());
76 
77   // Replace the old loop.
78   rewriter.replaceOp(loop.getOperation(), newLoop->getResults());
79   return newLoop;
80 }
81 
82 // Hoist out a pair of corresponding vector.extract+vector.broadcast
83 // operations. This function transforms a loop like this:
84 //  %res = scf.for _ = _ to _ step _ iter_args(%iarg = %v) -> (t1) {
85 //   %e = vector.extract %iarg : t1 to t2
86 //   %u = "some_use"(%e) : (t2) -> t2
87 //   %b = vector.broadcast %u : t2 to t1
88 //   scf.yield %b : t1
89 //  }
90 // into the following:
91 //  %e = vector.extract %v: t1 to t2
92 //  %res' = scf.for _ = _ to _ step _ iter_args(%iarg = %e) -> (t2) {
93 //   %u' = "some_use"(%iarg) : (t2) -> t2
94 //   scf.yield %u' : t2
95 //  }
96 //  %res = vector.broadcast %res' : t2 to t1
97 void mlir::linalg::hoistRedundantVectorBroadcasts(RewriterBase &rewriter,
98                                                   Operation *root) {
99   bool changed = true;
100   while (changed) {
101     changed = false;
102     // First move loop invariant ops outside of their loop. This needs to be
103     // done before as we cannot move ops without interrupting the function walk.
104     root->walk(
105         [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });
106 
107     root->walk([&](vector::ExtractOp extractOp) {
108       LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
109                         << *extractOp.getOperation() << "\n");
110 
111       auto loop = dyn_cast<scf::ForOp>(extractOp->getParentOp());
112       if (!loop)
113         return WalkResult::advance();
114 
115       // Check that the vector to extract from is a BlockArgument.
116       auto blockArg = dyn_cast<BlockArgument>(extractOp.getVector());
117       if (!blockArg)
118         return WalkResult::advance();
119 
120       // Check that the blockArg is an iter_arg of the loop.
121       OpOperand *initArg = loop.getTiedLoopInit(blockArg);
122       if (!initArg)
123         return WalkResult::advance();
124 
125       // If the iter_arg does not have only one use, it won't be possible to
126       // hoist the extractOp out.
127       if (!blockArg.hasOneUse())
128         return WalkResult::advance();
129 
130       unsigned index = blockArg.getArgNumber() - loop.getNumInductionVars();
131 
132       // Check that the loop yields a broadcast that has just one use.
133       Operation *yieldedVal =
134           loop.getTiedLoopYieldedValue(blockArg)->get().getDefiningOp();
135       auto broadcast = dyn_cast<vector::BroadcastOp>(yieldedVal);
136       if (!broadcast || !broadcast.getResult().hasOneUse())
137         return WalkResult::advance();
138 
139       LLVM_DEBUG(DBGS() << "Candidate broadcast: " << broadcast << "\n");
140 
141       Type broadcastInputType = broadcast.getSourceType();
142       if (broadcastInputType != extractOp.getType())
143         return WalkResult::advance();
144 
145       // The position of the extract must be defined outside of the loop if
146       // it is dynamic.
147       for (auto operand : extractOp.getDynamicPosition())
148         if (!loop.isDefinedOutsideOfLoop(operand))
149           return WalkResult::advance();
150 
151       rewriter.modifyOpInPlace(broadcast, [&] {
152         extractOp.getVectorMutable().assign(initArg->get());
153       });
154       loop.moveOutOfLoop(extractOp);
155       rewriter.moveOpAfter(broadcast, loop);
156 
157       scf::ForOp newLoop = replaceWithDifferentYield(
158           rewriter, loop, extractOp.getResult(), index, broadcast.getSource());
159 
160       LLVM_DEBUG(DBGS() << "New loop: " << newLoop << "\n");
161 
162       rewriter.replaceAllUsesWith(newLoop.getResult(index), broadcast);
163       rewriter.modifyOpInPlace(
164           broadcast, [&] { broadcast.setOperand(newLoop.getResult(index)); });
165 
166       changed = true;
167       return WalkResult::interrupt();
168     });
169   }
170 }
171 
172 static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
173                                 LoopLikeOpInterface loop) {
174   Value source = transferRead.getSource();
175 
176   // Skip view-like Ops and retrive the actual soruce Operation
177   while (auto srcOp =
178              dyn_cast_or_null<ViewLikeOpInterface>(source.getDefiningOp()))
179     source = srcOp.getViewSource();
180 
181   llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
182                                            source.getUsers().end());
183   llvm::SmallDenseSet<Operation *, 32> processed;
184   while (!users.empty()) {
185     Operation *user = users.pop_back_val();
186     // If the user has already been processed skip.
187     if (!processed.insert(user).second)
188       continue;
189     if (auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) {
190       users.append(viewLike->getUsers().begin(), viewLike->getUsers().end());
191       continue;
192     }
193     if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
194       continue;
195     if (!loop->isAncestor(user))
196       continue;
197     return false;
198   }
199   return true;
200 }
201 
202 void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
203                                                  bool verifyNonZeroTrip) {
204   bool changed = true;
205   while (changed) {
206     changed = false;
207     // First move loop invariant ops outside of their loop. This needs to be
208     // done before as we cannot move ops without interrupting the function walk.
209     root->walk(
210         [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });
211 
212     // Find all loops that are certain to have non zero trip count. Any loops
213     // that are not part of this set cannot be hoisted from, since hoisting from
214     // a potentially zero trip count loop may cause a vector transfer to be
215     // executed when it shouldn't be.
216     llvm::DenseSet<LoopLikeOpInterface> definiteNonZeroTripCountLoops;
217     if (verifyNonZeroTrip) {
218       root->walk([&](LoopLikeOpInterface loopLike) {
219         std::optional<SmallVector<OpFoldResult>> lbs =
220             loopLike.getLoopLowerBounds();
221         std::optional<SmallVector<OpFoldResult>> ubs =
222             loopLike.getLoopUpperBounds();
223         // If loop bounds cannot be found, assume possibly zero trip count.
224         if (!lbs || !ubs)
225           return;
226 
227         // Otherwise, use ValueBounds to find the maximum lower bound and
228         // minimum upper bound. If the bounds are found, and maxLb is less
229         // than the minUb, then the loop will not have zero trip count.
230         for (auto [lb, ub] : llvm::zip_equal(lbs.value(), ubs.value())) {
231           FailureOr<int64_t> maxLb =
232               ValueBoundsConstraintSet::computeConstantBound(
233                   presburger::BoundType::UB, lb,
234                   /*stopCondition=*/nullptr, /*closedUB=*/true);
235           if (failed(maxLb))
236             return;
237           FailureOr<int64_t> minUb =
238               ValueBoundsConstraintSet::computeConstantBound(
239                   presburger::BoundType::LB, ub);
240           if (failed(minUb))
241             return;
242           if (minUb.value() <= maxLb.value())
243             return;
244           definiteNonZeroTripCountLoops.insert(loopLike);
245         }
246       });
247     }
248 
249     root->walk([&](vector::TransferReadOp transferRead) {
250       if (!isa<MemRefType>(transferRead.getShapedType()))
251         return WalkResult::advance();
252 
253       LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
254                         << *transferRead.getOperation() << "\n");
255       auto loop = dyn_cast<LoopLikeOpInterface>(transferRead->getParentOp());
256       LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp()
257                         << "\n");
258       if (!isa_and_nonnull<scf::ForOp, affine::AffineForOp>(loop))
259         return WalkResult::advance();
260 
261       if (verifyNonZeroTrip && !definiteNonZeroTripCountLoops.contains(loop)) {
262         LLVM_DEBUG(DBGS() << "Loop may have zero trip count: " << *loop
263                           << "\n");
264         return WalkResult::advance();
265       }
266 
267       LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation()
268                         << "\n");
269 
270       SetVector<Operation *> forwardSlice;
271       getForwardSlice(transferRead.getOperation(), &forwardSlice);
272 
273       // Look for the last TransferWriteOp in the forwardSlice of
274       // `transferRead` that operates on the same memref.
275       vector::TransferWriteOp transferWrite;
276       for (auto *sliceOp : llvm::reverse(forwardSlice)) {
277         auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp);
278         if (!candidateWrite ||
279             candidateWrite.getSource() != transferRead.getSource())
280           continue;
281         transferWrite = candidateWrite;
282       }
283 
284       // All operands of the TransferRead must be defined outside of the loop.
285       for (auto operand : transferRead.getOperands())
286         if (!loop.isDefinedOutsideOfLoop(operand))
287           return WalkResult::advance();
288 
289       // Only hoist transfer_read / transfer_write pairs and singleton
290       // transfer_reads for now.
291       if (!transferWrite) {
292         // Make sure there are no other accesses to the memref before
293         // hoisting transfer_read.
294         if (noAliasingUseInLoop(transferRead, loop))
295           loop.moveOutOfLoop(transferRead);
296         return WalkResult::advance();
297       }
298 
299       LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation()
300                         << "\n");
301 
302       // Approximate aliasing by checking that:
303       //   1. indices, vector type and permutation map are the same (i.e., the
304       //      transfer_read/transfer_write ops are matching),
305       //   2. source operands for transfer.{read|write} do not originate from
306       //      Ops implementing ViewLikeOpInterface.
307       //   3. no other operations in the loop access the same memref except
308       //      for transfer_read/transfer_write accessing statically disjoint
309       //      slices.
310       if (transferRead.getIndices() != transferWrite.getIndices() ||
311           transferRead.getVectorType() != transferWrite.getVectorType() ||
312           transferRead.getPermutationMap() != transferWrite.getPermutationMap())
313         return WalkResult::advance();
314 
315       auto *source = transferRead.getSource().getDefiningOp();
316       if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
317         return WalkResult::advance();
318 
319       source = transferWrite.getSource().getDefiningOp();
320       if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
321         return WalkResult::advance();
322 
323       // TODO: may want to memoize this information for performance but it
324       // likely gets invalidated often.
325       DominanceInfo dom(loop);
326       if (!dom.properlyDominates(transferRead.getOperation(), transferWrite))
327         return WalkResult::advance();
328       for (auto &use : transferRead.getSource().getUses()) {
329         if (!loop->isAncestor(use.getOwner()))
330           continue;
331         if (use.getOwner() == transferRead.getOperation() ||
332             use.getOwner() == transferWrite.getOperation())
333           continue;
334         if (auto transferWriteUse =
335                 dyn_cast<vector::TransferWriteOp>(use.getOwner())) {
336           if (!vector::isDisjointTransferSet(
337                   cast<VectorTransferOpInterface>(*transferWrite),
338                   cast<VectorTransferOpInterface>(*transferWriteUse),
339                   /*testDynamicValueUsingBounds=*/true))
340             return WalkResult::advance();
341         } else if (auto transferReadUse =
342                        dyn_cast<vector::TransferReadOp>(use.getOwner())) {
343           if (!vector::isDisjointTransferSet(
344                   cast<VectorTransferOpInterface>(*transferWrite),
345                   cast<VectorTransferOpInterface>(*transferReadUse),
346                   /*testDynamicValueUsingBounds=*/true))
347             return WalkResult::advance();
348         } else {
349           // Unknown use, we cannot prove that it doesn't alias with the
350           // transferRead/transferWrite operations.
351           return WalkResult::advance();
352         }
353       }
354 
355       // Hoist read before.
356       loop.moveOutOfLoop(transferRead);
357 
358       // Hoist write after.
359       transferWrite->moveAfter(loop);
360 
361       // Rewrite `loop` with new yields by cloning and erase the original loop.
362       IRRewriter rewriter(transferRead.getContext());
363       NewYieldValuesFn yieldFn = [&](OpBuilder &b, Location loc,
364                                      ArrayRef<BlockArgument> newBBArgs) {
365         return SmallVector<Value>{transferWrite.getVector()};
366       };
367 
368       auto maybeNewLoop = loop.replaceWithAdditionalYields(
369           rewriter, transferRead.getVector(),
370           /*replaceInitOperandUsesInLoop=*/true, yieldFn);
371       if (failed(maybeNewLoop))
372         return WalkResult::interrupt();
373 
374       transferWrite.getVectorMutable().assign(
375           maybeNewLoop->getOperation()->getResults().back());
376       changed = true;
377       // Need to interrupt and restart because erasing the loop messes up
378       // the walk.
379       return WalkResult::interrupt();
380     });
381   }
382 }
383