xref: /llvm-project/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp (revision 76ead96c1d06ee0d828238bce96d0107e650b5fa)
1 //===- LoopInvariantCodeMotionUtils.cpp - LICM Utils ------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the implementation of the core LICM algorithm.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
14 
15 #include "mlir/IR/Operation.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/Interfaces/LoopLikeInterface.h"
18 #include "mlir/Interfaces/SideEffectInterfaces.h"
19 #include "mlir/Interfaces/SubsetOpInterface.h"
20 #include "llvm/Support/Debug.h"
21 #include <queue>
22 
23 #define DEBUG_TYPE "licm"
24 
25 using namespace mlir;
26 
27 /// Checks whether the given op can be hoisted by checking that
28 /// - the op and none of its contained operations depend on values inside of the
29 ///   loop (by means of calling definedOutside).
30 /// - the op has no side-effects.
canBeHoisted(Operation * op,function_ref<bool (OpOperand &)> condition)31 static bool canBeHoisted(Operation *op,
32                          function_ref<bool(OpOperand &)> condition) {
33   // Do not move terminators.
34   if (op->hasTrait<OpTrait::IsTerminator>())
35     return false;
36 
37   // Walk the nested operations and check that all used values are either
38   // defined outside of the loop or in a nested region, but not at the level of
39   // the loop body.
40   auto walkFn = [&](Operation *child) {
41     for (OpOperand &operand : child->getOpOperands()) {
42       // Ignore values defined in a nested region.
43       if (op->isAncestor(operand.get().getParentRegion()->getParentOp()))
44         continue;
45       if (!condition(operand))
46         return WalkResult::interrupt();
47     }
48     return WalkResult::advance();
49   };
50   return !op->walk(walkFn).wasInterrupted();
51 }
52 
canBeHoisted(Operation * op,function_ref<bool (Value)> definedOutside)53 static bool canBeHoisted(Operation *op,
54                          function_ref<bool(Value)> definedOutside) {
55   return canBeHoisted(
56       op, [&](OpOperand &operand) { return definedOutside(operand.get()); });
57 }
58 
moveLoopInvariantCode(ArrayRef<Region * > regions,function_ref<bool (Value,Region *)> isDefinedOutsideRegion,function_ref<bool (Operation *,Region *)> shouldMoveOutOfRegion,function_ref<void (Operation *,Region *)> moveOutOfRegion)59 size_t mlir::moveLoopInvariantCode(
60     ArrayRef<Region *> regions,
61     function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
62     function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
63     function_ref<void(Operation *, Region *)> moveOutOfRegion) {
64   size_t numMoved = 0;
65 
66   for (Region *region : regions) {
67     LLVM_DEBUG(llvm::dbgs() << "Original loop:\n"
68                             << *region->getParentOp() << "\n");
69 
70     std::queue<Operation *> worklist;
71     // Add top-level operations in the loop body to the worklist.
72     for (Operation &op : region->getOps())
73       worklist.push(&op);
74 
75     auto definedOutside = [&](Value value) {
76       return isDefinedOutsideRegion(value, region);
77     };
78 
79     while (!worklist.empty()) {
80       Operation *op = worklist.front();
81       worklist.pop();
82       // Skip ops that have already been moved. Check if the op can be hoisted.
83       if (op->getParentRegion() != region)
84         continue;
85 
86       LLVM_DEBUG(llvm::dbgs() << "Checking op: " << *op << "\n");
87       if (!shouldMoveOutOfRegion(op, region) ||
88           !canBeHoisted(op, definedOutside))
89         continue;
90 
91       LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *op << "\n");
92       moveOutOfRegion(op, region);
93       ++numMoved;
94 
95       // Since the op has been moved, we need to check its users within the
96       // top-level of the loop body.
97       for (Operation *user : op->getUsers())
98         if (user->getParentRegion() == region)
99           worklist.push(user);
100     }
101   }
102 
103   return numMoved;
104 }
105 
moveLoopInvariantCode(LoopLikeOpInterface loopLike)106 size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) {
107   return moveLoopInvariantCode(
108       loopLike.getLoopRegions(),
109       [&](Value value, Region *) {
110         return loopLike.isDefinedOutsideOfLoop(value);
111       },
112       [&](Operation *op, Region *) {
113         return isMemoryEffectFree(op) && isSpeculatable(op);
114       },
115       [&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); });
116 }
117 
118 namespace {
119 /// Helper data structure that keeps track of equivalent/disjoint subset ops.
120 class MatchingSubsets {
121 public:
122   /// Insert a subset op.
insert(SubsetOpInterface op,bool collectHoistableOps=true)123   void insert(SubsetOpInterface op, bool collectHoistableOps = true) {
124     allSubsetOps.push_back(op);
125     if (!collectHoistableOps)
126       return;
127     if (auto extractionOp =
128             dyn_cast<SubsetExtractionOpInterface>(op.getOperation()))
129       insertExtractionOp(extractionOp);
130     if (auto insertionOp =
131             dyn_cast<SubsetInsertionOpInterface>(op.getOperation()))
132       insertInsertionOp(insertionOp);
133   }
134 
135   /// Return a range of matching extraction-insertion subset ops. If there is no
136   /// matching extraction/insertion op, the respective value is empty. Ops are
137   /// skipped if there are other subset ops that are not guaranteed to operate
138   /// on disjoint subsets.
getHoistableSubsetOps()139   auto getHoistableSubsetOps() {
140     return llvm::make_filter_range(
141         llvm::zip(extractions, insertions), [&](auto pair) {
142           auto [extractionOp, insertionOp] = pair;
143           // Hoist only if the extracted and inserted values have the same type.
144           if (extractionOp && insertionOp &&
145               extractionOp->getResult(0).getType() !=
146                   insertionOp.getSourceOperand().get().getType())
147             return false;
148           // Hoist only if there are no conflicting subset ops.
149           return allDisjoint(extractionOp, insertionOp);
150         });
151   }
152 
153   /// Populate subset ops starting from the given region iter_arg. Return
154   /// "failure" if non-subset ops are found along the path to the loop yielding
155   /// op or if there is no single path to the tied yielded operand. If
156   /// `collectHoistableOps` is set to "false", subset ops are gathered
157   /// throughout the traversal, but not enumerated by `getHoistableSubsetOps`.
158   LogicalResult populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike,
159                                            BlockArgument iterArg,
160                                            bool collectHoistableOps = true);
161 
162 private:
163   /// Helper function for equivalence of tensor values. Since only insertion
164   /// subset ops (that are also destination style ops) are followed when
165   /// traversing the SSA use-def chain, all tensor values are equivalent.
isEquivalent(Value v1,Value v2)166   static bool isEquivalent(Value v1, Value v2) { return true; }
167 
168   /// Return "true" if the subsets of the given extraction and insertion ops
169   /// are operating disjoint from the subsets that all other known subset ops
170   /// are operating on.
allDisjoint(SubsetExtractionOpInterface extractionOp,SubsetInsertionOpInterface insertionOp) const171   bool allDisjoint(SubsetExtractionOpInterface extractionOp,
172                    SubsetInsertionOpInterface insertionOp) const {
173     for (SubsetOpInterface other : allSubsetOps) {
174       if (other == extractionOp || other == insertionOp)
175         continue;
176       if (extractionOp &&
177           !other.operatesOnDisjointSubset(extractionOp, isEquivalent))
178         return false;
179       if (insertionOp &&
180           !other.operatesOnDisjointSubset(insertionOp, isEquivalent))
181         return false;
182     }
183     return true;
184   }
185 
186   /// Insert a subset extraction op. If the subset is equivalent to an existing
187   /// subset insertion op, pair them up. (If there is already a paired up subset
188   /// extraction op, overwrite the subset extraction op.)
insertExtractionOp(SubsetExtractionOpInterface extractionOp)189   void insertExtractionOp(SubsetExtractionOpInterface extractionOp) {
190     for (auto it : llvm::enumerate(insertions)) {
191       if (!it.value())
192         continue;
193       auto other = cast<SubsetOpInterface>(it.value().getOperation());
194       if (other.operatesOnEquivalentSubset(extractionOp, isEquivalent)) {
195         extractions[it.index()] = extractionOp;
196         return;
197       }
198     }
199     // There is no known equivalent insertion op. Create a new entry.
200     extractions.push_back(extractionOp);
201     insertions.push_back({});
202   }
203 
204   /// Insert a subset insertion op. If the subset is equivalent to an existing
205   /// subset extraction op, pair them up. (If there is already a paired up
206   /// subset insertion op, overwrite the subset insertion op.)
insertInsertionOp(SubsetInsertionOpInterface insertionOp)207   void insertInsertionOp(SubsetInsertionOpInterface insertionOp) {
208     for (auto it : llvm::enumerate(extractions)) {
209       if (!it.value())
210         continue;
211       auto other = cast<SubsetOpInterface>(it.value().getOperation());
212       if (other.operatesOnEquivalentSubset(insertionOp, isEquivalent)) {
213         insertions[it.index()] = insertionOp;
214         return;
215       }
216     }
217     // There is no known equivalent extraction op. Create a new entry.
218     extractions.push_back({});
219     insertions.push_back(insertionOp);
220   }
221 
222   SmallVector<SubsetExtractionOpInterface> extractions;
223   SmallVector<SubsetInsertionOpInterface> insertions;
224   SmallVector<SubsetOpInterface> allSubsetOps;
225 };
226 } // namespace
227 
228 /// If the given value has a single use by an op that is a terminator, return
229 /// that use. Otherwise, return nullptr.
getSingleTerminatorUse(Value value)230 static OpOperand *getSingleTerminatorUse(Value value) {
231   if (!value.hasOneUse())
232     return nullptr;
233   OpOperand &use = *value.getUses().begin();
234   if (use.getOwner()->hasTrait<OpTrait::IsTerminator>())
235     return &use;
236   return nullptr;
237 }
238 
239 LogicalResult
populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike,BlockArgument iterArg,bool collectHoistableOps)240 MatchingSubsets::populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike,
241                                             BlockArgument iterArg,
242                                             bool collectHoistableOps) {
243   assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg");
244   Value value = iterArg;
245 
246   // Traverse use-def chain. Subset ops can be hoisted only if all ops along the
247   // use-def chain starting from the region iter_arg are subset extraction or
248   // subset insertion ops. The chain must terminate at the corresponding yield
249   // operand (e.g., no swapping of iter_args).
250   OpOperand *yieldedOperand = nullptr;
251   // Iterate until the single use of the current SSA value is a terminator,
252   // which is expected to be the yielding operation of the loop.
253   while (!(yieldedOperand = getSingleTerminatorUse(value))) {
254     Value nextValue = {};
255 
256     for (OpOperand &use : value.getUses()) {
257       if (auto nestedLoop = dyn_cast<LoopLikeOpInterface>(use.getOwner())) {
258         // Subset ops in nested loops are collected to check if there are only
259         // disjoint subset ops, but such subset ops are not subject to hoisting.
260         // To hoist subset ops from nested loops, the hoisting transformation
261         // should be run on the nested loop.
262         auto nestedIterArg = nestedLoop.getTiedLoopRegionIterArg(&use);
263         if (!nestedIterArg)
264           return failure();
265         // Note: `populateSubsetOpsAtIterArg` fails if there is no single SSA
266         // use-def chain starting at `nestedIterArg` and terminating in the
267         // tied, yielding operand.
268         if (failed(populateSubsetOpsAtIterArg(nestedLoop, nestedIterArg,
269                                               /*collectHoistableOps=*/false)))
270           return failure();
271         nextValue = nestedLoop.getTiedLoopResult(&use);
272         continue;
273       }
274 
275       auto subsetOp = dyn_cast<SubsetOpInterface>(use.getOwner());
276       if (!subsetOp)
277         return failure();
278       insert(subsetOp);
279 
280       if (auto insertionOp =
281               dyn_cast<SubsetInsertionOpInterface>(use.getOwner())) {
282         // Current implementation expects that the insertionOp implement
283         // the destinationStyleOpInterface as well. Abort if that tha is not
284         // the case
285         if (!isa<DestinationStyleOpInterface>(use.getOwner())) {
286           return failure();
287         }
288 
289         // The value must be used as a destination. (In case of a source, the
290         // entire tensor would be read, which would prevent any hoisting.)
291         if (&use != &insertionOp.getDestinationOperand())
292           return failure();
293         // There must be a single use-def chain from the region iter_arg to the
294         // terminator. I.e., only one insertion op. Branches are not supported.
295         if (nextValue)
296           return failure();
297         nextValue = insertionOp.getUpdatedDestination();
298       }
299     }
300 
301     // Nothing can be hoisted if the chain does not continue with loop yielding
302     // op or a subset insertion op.
303     if (!nextValue)
304       return failure();
305     value = nextValue;
306   }
307 
308   // Hoist only if the SSA use-def chain ends in the yielding terminator of the
309   // loop and the yielded value is the `idx`-th operand. (I.e., there is no
310   // swapping yield.)
311   if (loopLike.getTiedLoopYieldedValue(iterArg) != yieldedOperand)
312     return failure();
313 
314   return success();
315 }
316 
317 /// Hoist all subset ops that operate on the idx-th region iter_arg of the given
318 /// loop-like op and index into loop-invariant subset locations. Return the
319 /// newly created loop op (that has extra iter_args) or the original loop op if
320 /// nothing was hoisted.
hoistSubsetAtIterArg(RewriterBase & rewriter,LoopLikeOpInterface loopLike,BlockArgument iterArg)321 static LoopLikeOpInterface hoistSubsetAtIterArg(RewriterBase &rewriter,
322                                                 LoopLikeOpInterface loopLike,
323                                                 BlockArgument iterArg) {
324   assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg");
325   auto it = llvm::find(loopLike.getRegionIterArgs(), iterArg);
326   int64_t iterArgIdx = std::distance(loopLike.getRegionIterArgs().begin(), it);
327   MatchingSubsets subsets;
328   if (failed(subsets.populateSubsetOpsAtIterArg(loopLike, iterArg)))
329     return loopLike;
330 
331   // Hoist all matching extraction-insertion pairs one-by-one.
332   for (auto it : subsets.getHoistableSubsetOps()) {
333     auto extractionOp = std::get<0>(it);
334     auto insertionOp = std::get<1>(it);
335 
336     // Ops cannot be hoisted if they depend on loop-variant values.
337     if (extractionOp) {
338       if (!canBeHoisted(extractionOp, [&](OpOperand &operand) {
339             return loopLike.isDefinedOutsideOfLoop(operand.get()) ||
340                    &operand == &extractionOp.getSourceOperand();
341           }))
342         extractionOp = {};
343     }
344     if (insertionOp) {
345       if (!canBeHoisted(insertionOp, [&](OpOperand &operand) {
346             return loopLike.isDefinedOutsideOfLoop(operand.get()) ||
347                    &operand == &insertionOp.getSourceOperand() ||
348                    &operand == &insertionOp.getDestinationOperand();
349           }))
350         insertionOp = {};
351     }
352 
353     // Only hoist extraction-insertion pairs for now. Standalone extractions/
354     // insertions that are loop-invariant could be hoisted, but there may be
355     // easier ways to canonicalize the IR.
356     if (extractionOp && insertionOp) {
357       // Create a new loop with an additional iter_arg.
358       NewYieldValuesFn newYieldValuesFn =
359           [&](OpBuilder &b, Location loc,
360               ArrayRef<BlockArgument> innerNewBBArgs) -> SmallVector<Value> {
361         return {insertionOp.getSourceOperand().get()};
362       };
363       FailureOr<LoopLikeOpInterface> newLoop =
364           loopLike.replaceWithAdditionalYields(
365               rewriter, extractionOp.getResult(),
366               /*replaceInitOperandUsesInLoop=*/true, newYieldValuesFn);
367       if (failed(newLoop))
368         return loopLike;
369       loopLike = *newLoop;
370 
371       // Hoist the extraction/insertion ops.
372       iterArg = loopLike.getRegionIterArgs()[iterArgIdx];
373       OpResult loopResult = loopLike.getTiedLoopResult(iterArg);
374       OpResult newLoopResult = loopLike.getLoopResults()->back();
375       rewriter.moveOpBefore(extractionOp, loopLike);
376       rewriter.moveOpAfter(insertionOp, loopLike);
377       rewriter.replaceAllUsesWith(insertionOp.getUpdatedDestination(),
378                                   insertionOp.getDestinationOperand().get());
379       extractionOp.getSourceOperand().set(
380           loopLike.getTiedLoopInit(iterArg)->get());
381       rewriter.replaceAllUsesWith(loopResult,
382                                   insertionOp.getUpdatedDestination());
383       insertionOp.getSourceOperand().set(newLoopResult);
384       insertionOp.getDestinationOperand().set(loopResult);
385     }
386   }
387 
388   return loopLike;
389 }
390 
391 LoopLikeOpInterface
hoistLoopInvariantSubsets(RewriterBase & rewriter,LoopLikeOpInterface loopLike)392 mlir::hoistLoopInvariantSubsets(RewriterBase &rewriter,
393                                 LoopLikeOpInterface loopLike) {
394   // Note: As subset ops are getting hoisted, the number of region iter_args
395   // increases. This can enable further hoisting opportunities on the new
396   // iter_args.
397   for (int64_t i = 0;
398        i < static_cast<int64_t>(loopLike.getRegionIterArgs().size()); ++i) {
399     loopLike = hoistSubsetAtIterArg(rewriter, loopLike,
400                                     loopLike.getRegionIterArgs()[i]);
401   }
402   return loopLike;
403 }
404