xref: /llvm-project/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp (revision 76ead96c1d06ee0d828238bce96d0107e650b5fa)
1fa26c7ffSMogball //===- LoopInvariantCodeMotionUtils.cpp - LICM Utils ------------*- C++ -*-===//
2fa26c7ffSMogball //
3fa26c7ffSMogball // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4fa26c7ffSMogball // See https://llvm.org/LICENSE.txt for license information.
5fa26c7ffSMogball // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6fa26c7ffSMogball //
7fa26c7ffSMogball //===----------------------------------------------------------------------===//
8fa26c7ffSMogball //
9fa26c7ffSMogball // This file contains the implementation of the core LICM algorithm.
10fa26c7ffSMogball //
11fa26c7ffSMogball //===----------------------------------------------------------------------===//
12fa26c7ffSMogball 
13fa26c7ffSMogball #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
142164a449SMatthias Springer 
15fa26c7ffSMogball #include "mlir/IR/Operation.h"
162164a449SMatthias Springer #include "mlir/IR/PatternMatch.h"
17fa26c7ffSMogball #include "mlir/Interfaces/LoopLikeInterface.h"
18fc367dfaSMahesh Ravishankar #include "mlir/Interfaces/SideEffectInterfaces.h"
192164a449SMatthias Springer #include "mlir/Interfaces/SubsetOpInterface.h"
20fa26c7ffSMogball #include "llvm/Support/Debug.h"
21fa26c7ffSMogball #include <queue>
22fa26c7ffSMogball 
23fa26c7ffSMogball #define DEBUG_TYPE "licm"
24fa26c7ffSMogball 
25fa26c7ffSMogball using namespace mlir;
26fa26c7ffSMogball 
27fa26c7ffSMogball /// Checks whether the given op can be hoisted by checking that
28fa26c7ffSMogball /// - the op and none of its contained operations depend on values inside of the
29fa26c7ffSMogball ///   loop (by means of calling definedOutside).
30fa26c7ffSMogball /// - the op has no side-effects.
canBeHoisted(Operation * op,function_ref<bool (OpOperand &)> condition)31fa26c7ffSMogball static bool canBeHoisted(Operation *op,
322164a449SMatthias Springer                          function_ref<bool(OpOperand &)> condition) {
33fa26c7ffSMogball   // Do not move terminators.
34fa26c7ffSMogball   if (op->hasTrait<OpTrait::IsTerminator>())
35fa26c7ffSMogball     return false;
36fa26c7ffSMogball 
37fa26c7ffSMogball   // Walk the nested operations and check that all used values are either
38fa26c7ffSMogball   // defined outside of the loop or in a nested region, but not at the level of
39fa26c7ffSMogball   // the loop body.
40fa26c7ffSMogball   auto walkFn = [&](Operation *child) {
412164a449SMatthias Springer     for (OpOperand &operand : child->getOpOperands()) {
42fa26c7ffSMogball       // Ignore values defined in a nested region.
432164a449SMatthias Springer       if (op->isAncestor(operand.get().getParentRegion()->getParentOp()))
44fa26c7ffSMogball         continue;
452164a449SMatthias Springer       if (!condition(operand))
46fa26c7ffSMogball         return WalkResult::interrupt();
47fa26c7ffSMogball     }
48fa26c7ffSMogball     return WalkResult::advance();
49fa26c7ffSMogball   };
50fa26c7ffSMogball   return !op->walk(walkFn).wasInterrupted();
51fa26c7ffSMogball }
52fa26c7ffSMogball 
canBeHoisted(Operation * op,function_ref<bool (Value)> definedOutside)532164a449SMatthias Springer static bool canBeHoisted(Operation *op,
542164a449SMatthias Springer                          function_ref<bool(Value)> definedOutside) {
552164a449SMatthias Springer   return canBeHoisted(
562164a449SMatthias Springer       op, [&](OpOperand &operand) { return definedOutside(operand.get()); });
572164a449SMatthias Springer }
582164a449SMatthias Springer 
moveLoopInvariantCode(ArrayRef<Region * > regions,function_ref<bool (Value,Region *)> isDefinedOutsideRegion,function_ref<bool (Operation *,Region *)> shouldMoveOutOfRegion,function_ref<void (Operation *,Region *)> moveOutOfRegion)59fa26c7ffSMogball size_t mlir::moveLoopInvariantCode(
609b5ef2beSMatthias Springer     ArrayRef<Region *> regions,
61fa26c7ffSMogball     function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
62fa26c7ffSMogball     function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
63fa26c7ffSMogball     function_ref<void(Operation *, Region *)> moveOutOfRegion) {
64fa26c7ffSMogball   size_t numMoved = 0;
65fa26c7ffSMogball 
66fa26c7ffSMogball   for (Region *region : regions) {
679bb0f461SCullen Rhodes     LLVM_DEBUG(llvm::dbgs() << "Original loop:\n"
689bb0f461SCullen Rhodes                             << *region->getParentOp() << "\n");
69fa26c7ffSMogball 
70fa26c7ffSMogball     std::queue<Operation *> worklist;
71fa26c7ffSMogball     // Add top-level operations in the loop body to the worklist.
72fa26c7ffSMogball     for (Operation &op : region->getOps())
73fa26c7ffSMogball       worklist.push(&op);
74fa26c7ffSMogball 
75fa26c7ffSMogball     auto definedOutside = [&](Value value) {
76fa26c7ffSMogball       return isDefinedOutsideRegion(value, region);
77fa26c7ffSMogball     };
78fa26c7ffSMogball 
79fa26c7ffSMogball     while (!worklist.empty()) {
80fa26c7ffSMogball       Operation *op = worklist.front();
81fa26c7ffSMogball       worklist.pop();
82fa26c7ffSMogball       // Skip ops that have already been moved. Check if the op can be hoisted.
83fa26c7ffSMogball       if (op->getParentRegion() != region)
84fa26c7ffSMogball         continue;
85fa26c7ffSMogball 
869bb0f461SCullen Rhodes       LLVM_DEBUG(llvm::dbgs() << "Checking op: " << *op << "\n");
87fa26c7ffSMogball       if (!shouldMoveOutOfRegion(op, region) ||
88fa26c7ffSMogball           !canBeHoisted(op, definedOutside))
89fa26c7ffSMogball         continue;
90fa26c7ffSMogball 
919bb0f461SCullen Rhodes       LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *op << "\n");
92fa26c7ffSMogball       moveOutOfRegion(op, region);
93fa26c7ffSMogball       ++numMoved;
94fa26c7ffSMogball 
95fa26c7ffSMogball       // Since the op has been moved, we need to check its users within the
96fa26c7ffSMogball       // top-level of the loop body.
97fa26c7ffSMogball       for (Operation *user : op->getUsers())
98fa26c7ffSMogball         if (user->getParentRegion() == region)
99fa26c7ffSMogball           worklist.push(user);
100fa26c7ffSMogball     }
101fa26c7ffSMogball   }
102fa26c7ffSMogball 
103fa26c7ffSMogball   return numMoved;
104fa26c7ffSMogball }
105fa26c7ffSMogball 
moveLoopInvariantCode(LoopLikeOpInterface loopLike)106fa26c7ffSMogball size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) {
107fa26c7ffSMogball   return moveLoopInvariantCode(
1089b5ef2beSMatthias Springer       loopLike.getLoopRegions(),
109fa26c7ffSMogball       [&](Value value, Region *) {
110fa26c7ffSMogball         return loopLike.isDefinedOutsideOfLoop(value);
111fa26c7ffSMogball       },
11286771d0bSSanjoy Das       [&](Operation *op, Region *) {
11386771d0bSSanjoy Das         return isMemoryEffectFree(op) && isSpeculatable(op);
11486771d0bSSanjoy Das       },
115fa26c7ffSMogball       [&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); });
116fa26c7ffSMogball }
1172164a449SMatthias Springer 
1182164a449SMatthias Springer namespace {
1192164a449SMatthias Springer /// Helper data structure that keeps track of equivalent/disjoint subset ops.
1202164a449SMatthias Springer class MatchingSubsets {
1212164a449SMatthias Springer public:
1222164a449SMatthias Springer   /// Insert a subset op.
insert(SubsetOpInterface op,bool collectHoistableOps=true)1237ea1c395SMatthias Springer   void insert(SubsetOpInterface op, bool collectHoistableOps = true) {
1242164a449SMatthias Springer     allSubsetOps.push_back(op);
1257ea1c395SMatthias Springer     if (!collectHoistableOps)
1267ea1c395SMatthias Springer       return;
1272164a449SMatthias Springer     if (auto extractionOp =
1282164a449SMatthias Springer             dyn_cast<SubsetExtractionOpInterface>(op.getOperation()))
1292164a449SMatthias Springer       insertExtractionOp(extractionOp);
1302164a449SMatthias Springer     if (auto insertionOp =
1312164a449SMatthias Springer             dyn_cast<SubsetInsertionOpInterface>(op.getOperation()))
1322164a449SMatthias Springer       insertInsertionOp(insertionOp);
1332164a449SMatthias Springer   }
1342164a449SMatthias Springer 
1352164a449SMatthias Springer   /// Return a range of matching extraction-insertion subset ops. If there is no
1362164a449SMatthias Springer   /// matching extraction/insertion op, the respective value is empty. Ops are
1372164a449SMatthias Springer   /// skipped if there are other subset ops that are not guaranteed to operate
1382164a449SMatthias Springer   /// on disjoint subsets.
getHoistableSubsetOps()1392164a449SMatthias Springer   auto getHoistableSubsetOps() {
1402164a449SMatthias Springer     return llvm::make_filter_range(
1412164a449SMatthias Springer         llvm::zip(extractions, insertions), [&](auto pair) {
1422164a449SMatthias Springer           auto [extractionOp, insertionOp] = pair;
1432164a449SMatthias Springer           // Hoist only if the extracted and inserted values have the same type.
1442164a449SMatthias Springer           if (extractionOp && insertionOp &&
1452164a449SMatthias Springer               extractionOp->getResult(0).getType() !=
1462164a449SMatthias Springer                   insertionOp.getSourceOperand().get().getType())
1472164a449SMatthias Springer             return false;
1482164a449SMatthias Springer           // Hoist only if there are no conflicting subset ops.
1492164a449SMatthias Springer           return allDisjoint(extractionOp, insertionOp);
1502164a449SMatthias Springer         });
1512164a449SMatthias Springer   }
1522164a449SMatthias Springer 
1537ea1c395SMatthias Springer   /// Populate subset ops starting from the given region iter_arg. Return
1547ea1c395SMatthias Springer   /// "failure" if non-subset ops are found along the path to the loop yielding
1557ea1c395SMatthias Springer   /// op or if there is no single path to the tied yielded operand. If
1567ea1c395SMatthias Springer   /// `collectHoistableOps` is set to "false", subset ops are gathered
1577ea1c395SMatthias Springer   /// throughout the traversal, but not enumerated by `getHoistableSubsetOps`.
1587ea1c395SMatthias Springer   LogicalResult populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike,
1597ea1c395SMatthias Springer                                            BlockArgument iterArg,
1607ea1c395SMatthias Springer                                            bool collectHoistableOps = true);
1617ea1c395SMatthias Springer 
1622164a449SMatthias Springer private:
1632164a449SMatthias Springer   /// Helper function for equivalence of tensor values. Since only insertion
1642164a449SMatthias Springer   /// subset ops (that are also destination style ops) are followed when
1652164a449SMatthias Springer   /// traversing the SSA use-def chain, all tensor values are equivalent.
isEquivalent(Value v1,Value v2)1662164a449SMatthias Springer   static bool isEquivalent(Value v1, Value v2) { return true; }
1672164a449SMatthias Springer 
1682164a449SMatthias Springer   /// Return "true" if the subsets of the given extraction and insertion ops
1692164a449SMatthias Springer   /// are operating disjoint from the subsets that all other known subset ops
1702164a449SMatthias Springer   /// are operating on.
allDisjoint(SubsetExtractionOpInterface extractionOp,SubsetInsertionOpInterface insertionOp) const1712164a449SMatthias Springer   bool allDisjoint(SubsetExtractionOpInterface extractionOp,
1722164a449SMatthias Springer                    SubsetInsertionOpInterface insertionOp) const {
1732164a449SMatthias Springer     for (SubsetOpInterface other : allSubsetOps) {
1742164a449SMatthias Springer       if (other == extractionOp || other == insertionOp)
1752164a449SMatthias Springer         continue;
1762164a449SMatthias Springer       if (extractionOp &&
1772164a449SMatthias Springer           !other.operatesOnDisjointSubset(extractionOp, isEquivalent))
1782164a449SMatthias Springer         return false;
1792164a449SMatthias Springer       if (insertionOp &&
1802164a449SMatthias Springer           !other.operatesOnDisjointSubset(insertionOp, isEquivalent))
1812164a449SMatthias Springer         return false;
1822164a449SMatthias Springer     }
1832164a449SMatthias Springer     return true;
1842164a449SMatthias Springer   }
1852164a449SMatthias Springer 
1862164a449SMatthias Springer   /// Insert a subset extraction op. If the subset is equivalent to an existing
1872164a449SMatthias Springer   /// subset insertion op, pair them up. (If there is already a paired up subset
1882164a449SMatthias Springer   /// extraction op, overwrite the subset extraction op.)
insertExtractionOp(SubsetExtractionOpInterface extractionOp)1892164a449SMatthias Springer   void insertExtractionOp(SubsetExtractionOpInterface extractionOp) {
1902164a449SMatthias Springer     for (auto it : llvm::enumerate(insertions)) {
1912164a449SMatthias Springer       if (!it.value())
1922164a449SMatthias Springer         continue;
1932164a449SMatthias Springer       auto other = cast<SubsetOpInterface>(it.value().getOperation());
1942164a449SMatthias Springer       if (other.operatesOnEquivalentSubset(extractionOp, isEquivalent)) {
1952164a449SMatthias Springer         extractions[it.index()] = extractionOp;
1962164a449SMatthias Springer         return;
1972164a449SMatthias Springer       }
1982164a449SMatthias Springer     }
1992164a449SMatthias Springer     // There is no known equivalent insertion op. Create a new entry.
2002164a449SMatthias Springer     extractions.push_back(extractionOp);
2012164a449SMatthias Springer     insertions.push_back({});
2022164a449SMatthias Springer   }
2032164a449SMatthias Springer 
2042164a449SMatthias Springer   /// Insert a subset insertion op. If the subset is equivalent to an existing
2052164a449SMatthias Springer   /// subset extraction op, pair them up. (If there is already a paired up
2062164a449SMatthias Springer   /// subset insertion op, overwrite the subset insertion op.)
insertInsertionOp(SubsetInsertionOpInterface insertionOp)2072164a449SMatthias Springer   void insertInsertionOp(SubsetInsertionOpInterface insertionOp) {
2082164a449SMatthias Springer     for (auto it : llvm::enumerate(extractions)) {
2092164a449SMatthias Springer       if (!it.value())
2102164a449SMatthias Springer         continue;
2112164a449SMatthias Springer       auto other = cast<SubsetOpInterface>(it.value().getOperation());
2122164a449SMatthias Springer       if (other.operatesOnEquivalentSubset(insertionOp, isEquivalent)) {
2132164a449SMatthias Springer         insertions[it.index()] = insertionOp;
2142164a449SMatthias Springer         return;
2152164a449SMatthias Springer       }
2162164a449SMatthias Springer     }
2172164a449SMatthias Springer     // There is no known equivalent extraction op. Create a new entry.
2182164a449SMatthias Springer     extractions.push_back({});
2192164a449SMatthias Springer     insertions.push_back(insertionOp);
2202164a449SMatthias Springer   }
2212164a449SMatthias Springer 
2222164a449SMatthias Springer   SmallVector<SubsetExtractionOpInterface> extractions;
2232164a449SMatthias Springer   SmallVector<SubsetInsertionOpInterface> insertions;
2242164a449SMatthias Springer   SmallVector<SubsetOpInterface> allSubsetOps;
2252164a449SMatthias Springer };
2262164a449SMatthias Springer } // namespace
2272164a449SMatthias Springer 
2282164a449SMatthias Springer /// If the given value has a single use by an op that is a terminator, return
2292164a449SMatthias Springer /// that use. Otherwise, return nullptr.
getSingleTerminatorUse(Value value)2302164a449SMatthias Springer static OpOperand *getSingleTerminatorUse(Value value) {
2312164a449SMatthias Springer   if (!value.hasOneUse())
2322164a449SMatthias Springer     return nullptr;
2332164a449SMatthias Springer   OpOperand &use = *value.getUses().begin();
2342164a449SMatthias Springer   if (use.getOwner()->hasTrait<OpTrait::IsTerminator>())
2352164a449SMatthias Springer     return &use;
2362164a449SMatthias Springer   return nullptr;
2372164a449SMatthias Springer }
2382164a449SMatthias Springer 
2397ea1c395SMatthias Springer LogicalResult
populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike,BlockArgument iterArg,bool collectHoistableOps)2407ea1c395SMatthias Springer MatchingSubsets::populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike,
2417ea1c395SMatthias Springer                                             BlockArgument iterArg,
2427ea1c395SMatthias Springer                                             bool collectHoistableOps) {
2432164a449SMatthias Springer   assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg");
2442164a449SMatthias Springer   Value value = iterArg;
2452164a449SMatthias Springer 
2462164a449SMatthias Springer   // Traverse use-def chain. Subset ops can be hoisted only if all ops along the
2472164a449SMatthias Springer   // use-def chain starting from the region iter_arg are subset extraction or
2482164a449SMatthias Springer   // subset insertion ops. The chain must terminate at the corresponding yield
2492164a449SMatthias Springer   // operand (e.g., no swapping of iter_args).
2502164a449SMatthias Springer   OpOperand *yieldedOperand = nullptr;
2512164a449SMatthias Springer   // Iterate until the single use of the current SSA value is a terminator,
2522164a449SMatthias Springer   // which is expected to be the yielding operation of the loop.
2532164a449SMatthias Springer   while (!(yieldedOperand = getSingleTerminatorUse(value))) {
2542164a449SMatthias Springer     Value nextValue = {};
2552164a449SMatthias Springer 
2562164a449SMatthias Springer     for (OpOperand &use : value.getUses()) {
2577ea1c395SMatthias Springer       if (auto nestedLoop = dyn_cast<LoopLikeOpInterface>(use.getOwner())) {
2587ea1c395SMatthias Springer         // Subset ops in nested loops are collected to check if there are only
2597ea1c395SMatthias Springer         // disjoint subset ops, but such subset ops are not subject to hoisting.
2607ea1c395SMatthias Springer         // To hoist subset ops from nested loops, the hoisting transformation
2617ea1c395SMatthias Springer         // should be run on the nested loop.
2627ea1c395SMatthias Springer         auto nestedIterArg = nestedLoop.getTiedLoopRegionIterArg(&use);
2637ea1c395SMatthias Springer         if (!nestedIterArg)
2647ea1c395SMatthias Springer           return failure();
2657ea1c395SMatthias Springer         // Note: `populateSubsetOpsAtIterArg` fails if there is no single SSA
2667ea1c395SMatthias Springer         // use-def chain starting at `nestedIterArg` and terminating in the
2677ea1c395SMatthias Springer         // tied, yielding operand.
2687ea1c395SMatthias Springer         if (failed(populateSubsetOpsAtIterArg(nestedLoop, nestedIterArg,
2697ea1c395SMatthias Springer                                               /*collectHoistableOps=*/false)))
2707ea1c395SMatthias Springer           return failure();
2717ea1c395SMatthias Springer         nextValue = nestedLoop.getTiedLoopResult(&use);
2727ea1c395SMatthias Springer         continue;
2737ea1c395SMatthias Springer       }
2747ea1c395SMatthias Springer 
2752164a449SMatthias Springer       auto subsetOp = dyn_cast<SubsetOpInterface>(use.getOwner());
2762164a449SMatthias Springer       if (!subsetOp)
2777ea1c395SMatthias Springer         return failure();
2787ea1c395SMatthias Springer       insert(subsetOp);
2792164a449SMatthias Springer 
2802164a449SMatthias Springer       if (auto insertionOp =
2812164a449SMatthias Springer               dyn_cast<SubsetInsertionOpInterface>(use.getOwner())) {
282*76ead96cSMaheshRavishankar         // Current implementation expects that the insertionOp implement
283*76ead96cSMaheshRavishankar         // the destinationStyleOpInterface as well. Abort if that tha is not
284*76ead96cSMaheshRavishankar         // the case
285*76ead96cSMaheshRavishankar         if (!isa<DestinationStyleOpInterface>(use.getOwner())) {
286*76ead96cSMaheshRavishankar           return failure();
287*76ead96cSMaheshRavishankar         }
288*76ead96cSMaheshRavishankar 
2892164a449SMatthias Springer         // The value must be used as a destination. (In case of a source, the
2902164a449SMatthias Springer         // entire tensor would be read, which would prevent any hoisting.)
2912164a449SMatthias Springer         if (&use != &insertionOp.getDestinationOperand())
2927ea1c395SMatthias Springer           return failure();
2932164a449SMatthias Springer         // There must be a single use-def chain from the region iter_arg to the
2942164a449SMatthias Springer         // terminator. I.e., only one insertion op. Branches are not supported.
2952164a449SMatthias Springer         if (nextValue)
2967ea1c395SMatthias Springer           return failure();
2972164a449SMatthias Springer         nextValue = insertionOp.getUpdatedDestination();
2982164a449SMatthias Springer       }
2992164a449SMatthias Springer     }
3002164a449SMatthias Springer 
3012164a449SMatthias Springer     // Nothing can be hoisted if the chain does not continue with loop yielding
3022164a449SMatthias Springer     // op or a subset insertion op.
3032164a449SMatthias Springer     if (!nextValue)
3047ea1c395SMatthias Springer       return failure();
3052164a449SMatthias Springer     value = nextValue;
3062164a449SMatthias Springer   }
3072164a449SMatthias Springer 
3082164a449SMatthias Springer   // Hoist only if the SSA use-def chain ends in the yielding terminator of the
3092164a449SMatthias Springer   // loop and the yielded value is the `idx`-th operand. (I.e., there is no
3102164a449SMatthias Springer   // swapping yield.)
3112164a449SMatthias Springer   if (loopLike.getTiedLoopYieldedValue(iterArg) != yieldedOperand)
3127ea1c395SMatthias Springer     return failure();
3137ea1c395SMatthias Springer 
3147ea1c395SMatthias Springer   return success();
3157ea1c395SMatthias Springer }
3167ea1c395SMatthias Springer 
3177ea1c395SMatthias Springer /// Hoist all subset ops that operate on the idx-th region iter_arg of the given
3187ea1c395SMatthias Springer /// loop-like op and index into loop-invariant subset locations. Return the
3197ea1c395SMatthias Springer /// newly created loop op (that has extra iter_args) or the original loop op if
3207ea1c395SMatthias Springer /// nothing was hoisted.
hoistSubsetAtIterArg(RewriterBase & rewriter,LoopLikeOpInterface loopLike,BlockArgument iterArg)321b9fe461eSMatthias Springer static LoopLikeOpInterface hoistSubsetAtIterArg(RewriterBase &rewriter,
322b9fe461eSMatthias Springer                                                 LoopLikeOpInterface loopLike,
3237ea1c395SMatthias Springer                                                 BlockArgument iterArg) {
3247ea1c395SMatthias Springer   assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg");
3257ea1c395SMatthias Springer   auto it = llvm::find(loopLike.getRegionIterArgs(), iterArg);
3267ea1c395SMatthias Springer   int64_t iterArgIdx = std::distance(loopLike.getRegionIterArgs().begin(), it);
3277ea1c395SMatthias Springer   MatchingSubsets subsets;
3287ea1c395SMatthias Springer   if (failed(subsets.populateSubsetOpsAtIterArg(loopLike, iterArg)))
3292164a449SMatthias Springer     return loopLike;
3302164a449SMatthias Springer 
3312164a449SMatthias Springer   // Hoist all matching extraction-insertion pairs one-by-one.
3322164a449SMatthias Springer   for (auto it : subsets.getHoistableSubsetOps()) {
3332164a449SMatthias Springer     auto extractionOp = std::get<0>(it);
3342164a449SMatthias Springer     auto insertionOp = std::get<1>(it);
3352164a449SMatthias Springer 
3362164a449SMatthias Springer     // Ops cannot be hoisted if they depend on loop-variant values.
3372164a449SMatthias Springer     if (extractionOp) {
3382164a449SMatthias Springer       if (!canBeHoisted(extractionOp, [&](OpOperand &operand) {
3392164a449SMatthias Springer             return loopLike.isDefinedOutsideOfLoop(operand.get()) ||
3402164a449SMatthias Springer                    &operand == &extractionOp.getSourceOperand();
3412164a449SMatthias Springer           }))
3422164a449SMatthias Springer         extractionOp = {};
3432164a449SMatthias Springer     }
3442164a449SMatthias Springer     if (insertionOp) {
3452164a449SMatthias Springer       if (!canBeHoisted(insertionOp, [&](OpOperand &operand) {
3462164a449SMatthias Springer             return loopLike.isDefinedOutsideOfLoop(operand.get()) ||
3472164a449SMatthias Springer                    &operand == &insertionOp.getSourceOperand() ||
3482164a449SMatthias Springer                    &operand == &insertionOp.getDestinationOperand();
3492164a449SMatthias Springer           }))
3502164a449SMatthias Springer         insertionOp = {};
3512164a449SMatthias Springer     }
3522164a449SMatthias Springer 
3532164a449SMatthias Springer     // Only hoist extraction-insertion pairs for now. Standalone extractions/
3542164a449SMatthias Springer     // insertions that are loop-invariant could be hoisted, but there may be
3552164a449SMatthias Springer     // easier ways to canonicalize the IR.
3562164a449SMatthias Springer     if (extractionOp && insertionOp) {
3572164a449SMatthias Springer       // Create a new loop with an additional iter_arg.
3582164a449SMatthias Springer       NewYieldValuesFn newYieldValuesFn =
3592164a449SMatthias Springer           [&](OpBuilder &b, Location loc,
3602164a449SMatthias Springer               ArrayRef<BlockArgument> innerNewBBArgs) -> SmallVector<Value> {
3612164a449SMatthias Springer         return {insertionOp.getSourceOperand().get()};
3622164a449SMatthias Springer       };
3632164a449SMatthias Springer       FailureOr<LoopLikeOpInterface> newLoop =
3642164a449SMatthias Springer           loopLike.replaceWithAdditionalYields(
3652164a449SMatthias Springer               rewriter, extractionOp.getResult(),
3662164a449SMatthias Springer               /*replaceInitOperandUsesInLoop=*/true, newYieldValuesFn);
3672164a449SMatthias Springer       if (failed(newLoop))
3682164a449SMatthias Springer         return loopLike;
3692164a449SMatthias Springer       loopLike = *newLoop;
3702164a449SMatthias Springer 
3712164a449SMatthias Springer       // Hoist the extraction/insertion ops.
3722164a449SMatthias Springer       iterArg = loopLike.getRegionIterArgs()[iterArgIdx];
3732164a449SMatthias Springer       OpResult loopResult = loopLike.getTiedLoopResult(iterArg);
3742164a449SMatthias Springer       OpResult newLoopResult = loopLike.getLoopResults()->back();
3755cc0f76dSMatthias Springer       rewriter.moveOpBefore(extractionOp, loopLike);
3765cc0f76dSMatthias Springer       rewriter.moveOpAfter(insertionOp, loopLike);
377b9fe461eSMatthias Springer       rewriter.replaceAllUsesWith(insertionOp.getUpdatedDestination(),
3782164a449SMatthias Springer                                   insertionOp.getDestinationOperand().get());
3792164a449SMatthias Springer       extractionOp.getSourceOperand().set(
3802164a449SMatthias Springer           loopLike.getTiedLoopInit(iterArg)->get());
381b9fe461eSMatthias Springer       rewriter.replaceAllUsesWith(loopResult,
382b9fe461eSMatthias Springer                                   insertionOp.getUpdatedDestination());
3832164a449SMatthias Springer       insertionOp.getSourceOperand().set(newLoopResult);
3842164a449SMatthias Springer       insertionOp.getDestinationOperand().set(loopResult);
3852164a449SMatthias Springer     }
3862164a449SMatthias Springer   }
3872164a449SMatthias Springer 
3882164a449SMatthias Springer   return loopLike;
3892164a449SMatthias Springer }
3902164a449SMatthias Springer 
3912164a449SMatthias Springer LoopLikeOpInterface
hoistLoopInvariantSubsets(RewriterBase & rewriter,LoopLikeOpInterface loopLike)392b9fe461eSMatthias Springer mlir::hoistLoopInvariantSubsets(RewriterBase &rewriter,
393b9fe461eSMatthias Springer                                 LoopLikeOpInterface loopLike) {
3942164a449SMatthias Springer   // Note: As subset ops are getting hoisted, the number of region iter_args
3952164a449SMatthias Springer   // increases. This can enable further hoisting opportunities on the new
3962164a449SMatthias Springer   // iter_args.
3972164a449SMatthias Springer   for (int64_t i = 0;
3982164a449SMatthias Springer        i < static_cast<int64_t>(loopLike.getRegionIterArgs().size()); ++i) {
399b9fe461eSMatthias Springer     loopLike = hoistSubsetAtIterArg(rewriter, loopLike,
400b9fe461eSMatthias Springer                                     loopLike.getRegionIterArgs()[i]);
4012164a449SMatthias Springer   }
4022164a449SMatthias Springer   return loopLike;
4032164a449SMatthias Springer }
404