//===- LoopInvariantCodeMotionUtils.cpp - LICM Utils ------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the implementation of the core LICM algorithm.
//
//===----------------------------------------------------------------------===//
12fa26c7ffSMogball
13fa26c7ffSMogball #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
142164a449SMatthias Springer
15fa26c7ffSMogball #include "mlir/IR/Operation.h"
162164a449SMatthias Springer #include "mlir/IR/PatternMatch.h"
17fa26c7ffSMogball #include "mlir/Interfaces/LoopLikeInterface.h"
18fc367dfaSMahesh Ravishankar #include "mlir/Interfaces/SideEffectInterfaces.h"
192164a449SMatthias Springer #include "mlir/Interfaces/SubsetOpInterface.h"
20fa26c7ffSMogball #include "llvm/Support/Debug.h"
21fa26c7ffSMogball #include <queue>
22fa26c7ffSMogball
23fa26c7ffSMogball #define DEBUG_TYPE "licm"
24fa26c7ffSMogball
25fa26c7ffSMogball using namespace mlir;
26fa26c7ffSMogball
27fa26c7ffSMogball /// Checks whether the given op can be hoisted by checking that
28fa26c7ffSMogball /// - the op and none of its contained operations depend on values inside of the
29fa26c7ffSMogball /// loop (by means of calling definedOutside).
30fa26c7ffSMogball /// - the op has no side-effects.
canBeHoisted(Operation * op,function_ref<bool (OpOperand &)> condition)31fa26c7ffSMogball static bool canBeHoisted(Operation *op,
322164a449SMatthias Springer function_ref<bool(OpOperand &)> condition) {
33fa26c7ffSMogball // Do not move terminators.
34fa26c7ffSMogball if (op->hasTrait<OpTrait::IsTerminator>())
35fa26c7ffSMogball return false;
36fa26c7ffSMogball
37fa26c7ffSMogball // Walk the nested operations and check that all used values are either
38fa26c7ffSMogball // defined outside of the loop or in a nested region, but not at the level of
39fa26c7ffSMogball // the loop body.
40fa26c7ffSMogball auto walkFn = [&](Operation *child) {
412164a449SMatthias Springer for (OpOperand &operand : child->getOpOperands()) {
42fa26c7ffSMogball // Ignore values defined in a nested region.
432164a449SMatthias Springer if (op->isAncestor(operand.get().getParentRegion()->getParentOp()))
44fa26c7ffSMogball continue;
452164a449SMatthias Springer if (!condition(operand))
46fa26c7ffSMogball return WalkResult::interrupt();
47fa26c7ffSMogball }
48fa26c7ffSMogball return WalkResult::advance();
49fa26c7ffSMogball };
50fa26c7ffSMogball return !op->walk(walkFn).wasInterrupted();
51fa26c7ffSMogball }
52fa26c7ffSMogball
canBeHoisted(Operation * op,function_ref<bool (Value)> definedOutside)532164a449SMatthias Springer static bool canBeHoisted(Operation *op,
542164a449SMatthias Springer function_ref<bool(Value)> definedOutside) {
552164a449SMatthias Springer return canBeHoisted(
562164a449SMatthias Springer op, [&](OpOperand &operand) { return definedOutside(operand.get()); });
572164a449SMatthias Springer }
582164a449SMatthias Springer
moveLoopInvariantCode(ArrayRef<Region * > regions,function_ref<bool (Value,Region *)> isDefinedOutsideRegion,function_ref<bool (Operation *,Region *)> shouldMoveOutOfRegion,function_ref<void (Operation *,Region *)> moveOutOfRegion)59fa26c7ffSMogball size_t mlir::moveLoopInvariantCode(
609b5ef2beSMatthias Springer ArrayRef<Region *> regions,
61fa26c7ffSMogball function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
62fa26c7ffSMogball function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
63fa26c7ffSMogball function_ref<void(Operation *, Region *)> moveOutOfRegion) {
64fa26c7ffSMogball size_t numMoved = 0;
65fa26c7ffSMogball
66fa26c7ffSMogball for (Region *region : regions) {
679bb0f461SCullen Rhodes LLVM_DEBUG(llvm::dbgs() << "Original loop:\n"
689bb0f461SCullen Rhodes << *region->getParentOp() << "\n");
69fa26c7ffSMogball
70fa26c7ffSMogball std::queue<Operation *> worklist;
71fa26c7ffSMogball // Add top-level operations in the loop body to the worklist.
72fa26c7ffSMogball for (Operation &op : region->getOps())
73fa26c7ffSMogball worklist.push(&op);
74fa26c7ffSMogball
75fa26c7ffSMogball auto definedOutside = [&](Value value) {
76fa26c7ffSMogball return isDefinedOutsideRegion(value, region);
77fa26c7ffSMogball };
78fa26c7ffSMogball
79fa26c7ffSMogball while (!worklist.empty()) {
80fa26c7ffSMogball Operation *op = worklist.front();
81fa26c7ffSMogball worklist.pop();
82fa26c7ffSMogball // Skip ops that have already been moved. Check if the op can be hoisted.
83fa26c7ffSMogball if (op->getParentRegion() != region)
84fa26c7ffSMogball continue;
85fa26c7ffSMogball
869bb0f461SCullen Rhodes LLVM_DEBUG(llvm::dbgs() << "Checking op: " << *op << "\n");
87fa26c7ffSMogball if (!shouldMoveOutOfRegion(op, region) ||
88fa26c7ffSMogball !canBeHoisted(op, definedOutside))
89fa26c7ffSMogball continue;
90fa26c7ffSMogball
919bb0f461SCullen Rhodes LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *op << "\n");
92fa26c7ffSMogball moveOutOfRegion(op, region);
93fa26c7ffSMogball ++numMoved;
94fa26c7ffSMogball
95fa26c7ffSMogball // Since the op has been moved, we need to check its users within the
96fa26c7ffSMogball // top-level of the loop body.
97fa26c7ffSMogball for (Operation *user : op->getUsers())
98fa26c7ffSMogball if (user->getParentRegion() == region)
99fa26c7ffSMogball worklist.push(user);
100fa26c7ffSMogball }
101fa26c7ffSMogball }
102fa26c7ffSMogball
103fa26c7ffSMogball return numMoved;
104fa26c7ffSMogball }
105fa26c7ffSMogball
moveLoopInvariantCode(LoopLikeOpInterface loopLike)106fa26c7ffSMogball size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) {
107fa26c7ffSMogball return moveLoopInvariantCode(
1089b5ef2beSMatthias Springer loopLike.getLoopRegions(),
109fa26c7ffSMogball [&](Value value, Region *) {
110fa26c7ffSMogball return loopLike.isDefinedOutsideOfLoop(value);
111fa26c7ffSMogball },
11286771d0bSSanjoy Das [&](Operation *op, Region *) {
11386771d0bSSanjoy Das return isMemoryEffectFree(op) && isSpeculatable(op);
11486771d0bSSanjoy Das },
115fa26c7ffSMogball [&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); });
116fa26c7ffSMogball }
1172164a449SMatthias Springer
1182164a449SMatthias Springer namespace {
1192164a449SMatthias Springer /// Helper data structure that keeps track of equivalent/disjoint subset ops.
1202164a449SMatthias Springer class MatchingSubsets {
1212164a449SMatthias Springer public:
1222164a449SMatthias Springer /// Insert a subset op.
insert(SubsetOpInterface op,bool collectHoistableOps=true)1237ea1c395SMatthias Springer void insert(SubsetOpInterface op, bool collectHoistableOps = true) {
1242164a449SMatthias Springer allSubsetOps.push_back(op);
1257ea1c395SMatthias Springer if (!collectHoistableOps)
1267ea1c395SMatthias Springer return;
1272164a449SMatthias Springer if (auto extractionOp =
1282164a449SMatthias Springer dyn_cast<SubsetExtractionOpInterface>(op.getOperation()))
1292164a449SMatthias Springer insertExtractionOp(extractionOp);
1302164a449SMatthias Springer if (auto insertionOp =
1312164a449SMatthias Springer dyn_cast<SubsetInsertionOpInterface>(op.getOperation()))
1322164a449SMatthias Springer insertInsertionOp(insertionOp);
1332164a449SMatthias Springer }
1342164a449SMatthias Springer
1352164a449SMatthias Springer /// Return a range of matching extraction-insertion subset ops. If there is no
1362164a449SMatthias Springer /// matching extraction/insertion op, the respective value is empty. Ops are
1372164a449SMatthias Springer /// skipped if there are other subset ops that are not guaranteed to operate
1382164a449SMatthias Springer /// on disjoint subsets.
getHoistableSubsetOps()1392164a449SMatthias Springer auto getHoistableSubsetOps() {
1402164a449SMatthias Springer return llvm::make_filter_range(
1412164a449SMatthias Springer llvm::zip(extractions, insertions), [&](auto pair) {
1422164a449SMatthias Springer auto [extractionOp, insertionOp] = pair;
1432164a449SMatthias Springer // Hoist only if the extracted and inserted values have the same type.
1442164a449SMatthias Springer if (extractionOp && insertionOp &&
1452164a449SMatthias Springer extractionOp->getResult(0).getType() !=
1462164a449SMatthias Springer insertionOp.getSourceOperand().get().getType())
1472164a449SMatthias Springer return false;
1482164a449SMatthias Springer // Hoist only if there are no conflicting subset ops.
1492164a449SMatthias Springer return allDisjoint(extractionOp, insertionOp);
1502164a449SMatthias Springer });
1512164a449SMatthias Springer }
1522164a449SMatthias Springer
1537ea1c395SMatthias Springer /// Populate subset ops starting from the given region iter_arg. Return
1547ea1c395SMatthias Springer /// "failure" if non-subset ops are found along the path to the loop yielding
1557ea1c395SMatthias Springer /// op or if there is no single path to the tied yielded operand. If
1567ea1c395SMatthias Springer /// `collectHoistableOps` is set to "false", subset ops are gathered
1577ea1c395SMatthias Springer /// throughout the traversal, but not enumerated by `getHoistableSubsetOps`.
1587ea1c395SMatthias Springer LogicalResult populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike,
1597ea1c395SMatthias Springer BlockArgument iterArg,
1607ea1c395SMatthias Springer bool collectHoistableOps = true);
1617ea1c395SMatthias Springer
1622164a449SMatthias Springer private:
1632164a449SMatthias Springer /// Helper function for equivalence of tensor values. Since only insertion
1642164a449SMatthias Springer /// subset ops (that are also destination style ops) are followed when
1652164a449SMatthias Springer /// traversing the SSA use-def chain, all tensor values are equivalent.
isEquivalent(Value v1,Value v2)1662164a449SMatthias Springer static bool isEquivalent(Value v1, Value v2) { return true; }
1672164a449SMatthias Springer
1682164a449SMatthias Springer /// Return "true" if the subsets of the given extraction and insertion ops
1692164a449SMatthias Springer /// are operating disjoint from the subsets that all other known subset ops
1702164a449SMatthias Springer /// are operating on.
allDisjoint(SubsetExtractionOpInterface extractionOp,SubsetInsertionOpInterface insertionOp) const1712164a449SMatthias Springer bool allDisjoint(SubsetExtractionOpInterface extractionOp,
1722164a449SMatthias Springer SubsetInsertionOpInterface insertionOp) const {
1732164a449SMatthias Springer for (SubsetOpInterface other : allSubsetOps) {
1742164a449SMatthias Springer if (other == extractionOp || other == insertionOp)
1752164a449SMatthias Springer continue;
1762164a449SMatthias Springer if (extractionOp &&
1772164a449SMatthias Springer !other.operatesOnDisjointSubset(extractionOp, isEquivalent))
1782164a449SMatthias Springer return false;
1792164a449SMatthias Springer if (insertionOp &&
1802164a449SMatthias Springer !other.operatesOnDisjointSubset(insertionOp, isEquivalent))
1812164a449SMatthias Springer return false;
1822164a449SMatthias Springer }
1832164a449SMatthias Springer return true;
1842164a449SMatthias Springer }
1852164a449SMatthias Springer
1862164a449SMatthias Springer /// Insert a subset extraction op. If the subset is equivalent to an existing
1872164a449SMatthias Springer /// subset insertion op, pair them up. (If there is already a paired up subset
1882164a449SMatthias Springer /// extraction op, overwrite the subset extraction op.)
insertExtractionOp(SubsetExtractionOpInterface extractionOp)1892164a449SMatthias Springer void insertExtractionOp(SubsetExtractionOpInterface extractionOp) {
1902164a449SMatthias Springer for (auto it : llvm::enumerate(insertions)) {
1912164a449SMatthias Springer if (!it.value())
1922164a449SMatthias Springer continue;
1932164a449SMatthias Springer auto other = cast<SubsetOpInterface>(it.value().getOperation());
1942164a449SMatthias Springer if (other.operatesOnEquivalentSubset(extractionOp, isEquivalent)) {
1952164a449SMatthias Springer extractions[it.index()] = extractionOp;
1962164a449SMatthias Springer return;
1972164a449SMatthias Springer }
1982164a449SMatthias Springer }
1992164a449SMatthias Springer // There is no known equivalent insertion op. Create a new entry.
2002164a449SMatthias Springer extractions.push_back(extractionOp);
2012164a449SMatthias Springer insertions.push_back({});
2022164a449SMatthias Springer }
2032164a449SMatthias Springer
2042164a449SMatthias Springer /// Insert a subset insertion op. If the subset is equivalent to an existing
2052164a449SMatthias Springer /// subset extraction op, pair them up. (If there is already a paired up
2062164a449SMatthias Springer /// subset insertion op, overwrite the subset insertion op.)
insertInsertionOp(SubsetInsertionOpInterface insertionOp)2072164a449SMatthias Springer void insertInsertionOp(SubsetInsertionOpInterface insertionOp) {
2082164a449SMatthias Springer for (auto it : llvm::enumerate(extractions)) {
2092164a449SMatthias Springer if (!it.value())
2102164a449SMatthias Springer continue;
2112164a449SMatthias Springer auto other = cast<SubsetOpInterface>(it.value().getOperation());
2122164a449SMatthias Springer if (other.operatesOnEquivalentSubset(insertionOp, isEquivalent)) {
2132164a449SMatthias Springer insertions[it.index()] = insertionOp;
2142164a449SMatthias Springer return;
2152164a449SMatthias Springer }
2162164a449SMatthias Springer }
2172164a449SMatthias Springer // There is no known equivalent extraction op. Create a new entry.
2182164a449SMatthias Springer extractions.push_back({});
2192164a449SMatthias Springer insertions.push_back(insertionOp);
2202164a449SMatthias Springer }
2212164a449SMatthias Springer
2222164a449SMatthias Springer SmallVector<SubsetExtractionOpInterface> extractions;
2232164a449SMatthias Springer SmallVector<SubsetInsertionOpInterface> insertions;
2242164a449SMatthias Springer SmallVector<SubsetOpInterface> allSubsetOps;
2252164a449SMatthias Springer };
2262164a449SMatthias Springer } // namespace
2272164a449SMatthias Springer
2282164a449SMatthias Springer /// If the given value has a single use by an op that is a terminator, return
2292164a449SMatthias Springer /// that use. Otherwise, return nullptr.
getSingleTerminatorUse(Value value)2302164a449SMatthias Springer static OpOperand *getSingleTerminatorUse(Value value) {
2312164a449SMatthias Springer if (!value.hasOneUse())
2322164a449SMatthias Springer return nullptr;
2332164a449SMatthias Springer OpOperand &use = *value.getUses().begin();
2342164a449SMatthias Springer if (use.getOwner()->hasTrait<OpTrait::IsTerminator>())
2352164a449SMatthias Springer return &use;
2362164a449SMatthias Springer return nullptr;
2372164a449SMatthias Springer }
2382164a449SMatthias Springer
2397ea1c395SMatthias Springer LogicalResult
populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike,BlockArgument iterArg,bool collectHoistableOps)2407ea1c395SMatthias Springer MatchingSubsets::populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike,
2417ea1c395SMatthias Springer BlockArgument iterArg,
2427ea1c395SMatthias Springer bool collectHoistableOps) {
2432164a449SMatthias Springer assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg");
2442164a449SMatthias Springer Value value = iterArg;
2452164a449SMatthias Springer
2462164a449SMatthias Springer // Traverse use-def chain. Subset ops can be hoisted only if all ops along the
2472164a449SMatthias Springer // use-def chain starting from the region iter_arg are subset extraction or
2482164a449SMatthias Springer // subset insertion ops. The chain must terminate at the corresponding yield
2492164a449SMatthias Springer // operand (e.g., no swapping of iter_args).
2502164a449SMatthias Springer OpOperand *yieldedOperand = nullptr;
2512164a449SMatthias Springer // Iterate until the single use of the current SSA value is a terminator,
2522164a449SMatthias Springer // which is expected to be the yielding operation of the loop.
2532164a449SMatthias Springer while (!(yieldedOperand = getSingleTerminatorUse(value))) {
2542164a449SMatthias Springer Value nextValue = {};
2552164a449SMatthias Springer
2562164a449SMatthias Springer for (OpOperand &use : value.getUses()) {
2577ea1c395SMatthias Springer if (auto nestedLoop = dyn_cast<LoopLikeOpInterface>(use.getOwner())) {
2587ea1c395SMatthias Springer // Subset ops in nested loops are collected to check if there are only
2597ea1c395SMatthias Springer // disjoint subset ops, but such subset ops are not subject to hoisting.
2607ea1c395SMatthias Springer // To hoist subset ops from nested loops, the hoisting transformation
2617ea1c395SMatthias Springer // should be run on the nested loop.
2627ea1c395SMatthias Springer auto nestedIterArg = nestedLoop.getTiedLoopRegionIterArg(&use);
2637ea1c395SMatthias Springer if (!nestedIterArg)
2647ea1c395SMatthias Springer return failure();
2657ea1c395SMatthias Springer // Note: `populateSubsetOpsAtIterArg` fails if there is no single SSA
2667ea1c395SMatthias Springer // use-def chain starting at `nestedIterArg` and terminating in the
2677ea1c395SMatthias Springer // tied, yielding operand.
2687ea1c395SMatthias Springer if (failed(populateSubsetOpsAtIterArg(nestedLoop, nestedIterArg,
2697ea1c395SMatthias Springer /*collectHoistableOps=*/false)))
2707ea1c395SMatthias Springer return failure();
2717ea1c395SMatthias Springer nextValue = nestedLoop.getTiedLoopResult(&use);
2727ea1c395SMatthias Springer continue;
2737ea1c395SMatthias Springer }
2747ea1c395SMatthias Springer
2752164a449SMatthias Springer auto subsetOp = dyn_cast<SubsetOpInterface>(use.getOwner());
2762164a449SMatthias Springer if (!subsetOp)
2777ea1c395SMatthias Springer return failure();
2787ea1c395SMatthias Springer insert(subsetOp);
2792164a449SMatthias Springer
2802164a449SMatthias Springer if (auto insertionOp =
2812164a449SMatthias Springer dyn_cast<SubsetInsertionOpInterface>(use.getOwner())) {
282*76ead96cSMaheshRavishankar // Current implementation expects that the insertionOp implement
283*76ead96cSMaheshRavishankar // the destinationStyleOpInterface as well. Abort if that tha is not
284*76ead96cSMaheshRavishankar // the case
285*76ead96cSMaheshRavishankar if (!isa<DestinationStyleOpInterface>(use.getOwner())) {
286*76ead96cSMaheshRavishankar return failure();
287*76ead96cSMaheshRavishankar }
288*76ead96cSMaheshRavishankar
2892164a449SMatthias Springer // The value must be used as a destination. (In case of a source, the
2902164a449SMatthias Springer // entire tensor would be read, which would prevent any hoisting.)
2912164a449SMatthias Springer if (&use != &insertionOp.getDestinationOperand())
2927ea1c395SMatthias Springer return failure();
2932164a449SMatthias Springer // There must be a single use-def chain from the region iter_arg to the
2942164a449SMatthias Springer // terminator. I.e., only one insertion op. Branches are not supported.
2952164a449SMatthias Springer if (nextValue)
2967ea1c395SMatthias Springer return failure();
2972164a449SMatthias Springer nextValue = insertionOp.getUpdatedDestination();
2982164a449SMatthias Springer }
2992164a449SMatthias Springer }
3002164a449SMatthias Springer
3012164a449SMatthias Springer // Nothing can be hoisted if the chain does not continue with loop yielding
3022164a449SMatthias Springer // op or a subset insertion op.
3032164a449SMatthias Springer if (!nextValue)
3047ea1c395SMatthias Springer return failure();
3052164a449SMatthias Springer value = nextValue;
3062164a449SMatthias Springer }
3072164a449SMatthias Springer
3082164a449SMatthias Springer // Hoist only if the SSA use-def chain ends in the yielding terminator of the
3092164a449SMatthias Springer // loop and the yielded value is the `idx`-th operand. (I.e., there is no
3102164a449SMatthias Springer // swapping yield.)
3112164a449SMatthias Springer if (loopLike.getTiedLoopYieldedValue(iterArg) != yieldedOperand)
3127ea1c395SMatthias Springer return failure();
3137ea1c395SMatthias Springer
3147ea1c395SMatthias Springer return success();
3157ea1c395SMatthias Springer }
3167ea1c395SMatthias Springer
3177ea1c395SMatthias Springer /// Hoist all subset ops that operate on the idx-th region iter_arg of the given
3187ea1c395SMatthias Springer /// loop-like op and index into loop-invariant subset locations. Return the
3197ea1c395SMatthias Springer /// newly created loop op (that has extra iter_args) or the original loop op if
3207ea1c395SMatthias Springer /// nothing was hoisted.
hoistSubsetAtIterArg(RewriterBase & rewriter,LoopLikeOpInterface loopLike,BlockArgument iterArg)321b9fe461eSMatthias Springer static LoopLikeOpInterface hoistSubsetAtIterArg(RewriterBase &rewriter,
322b9fe461eSMatthias Springer LoopLikeOpInterface loopLike,
3237ea1c395SMatthias Springer BlockArgument iterArg) {
3247ea1c395SMatthias Springer assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg");
3257ea1c395SMatthias Springer auto it = llvm::find(loopLike.getRegionIterArgs(), iterArg);
3267ea1c395SMatthias Springer int64_t iterArgIdx = std::distance(loopLike.getRegionIterArgs().begin(), it);
3277ea1c395SMatthias Springer MatchingSubsets subsets;
3287ea1c395SMatthias Springer if (failed(subsets.populateSubsetOpsAtIterArg(loopLike, iterArg)))
3292164a449SMatthias Springer return loopLike;
3302164a449SMatthias Springer
3312164a449SMatthias Springer // Hoist all matching extraction-insertion pairs one-by-one.
3322164a449SMatthias Springer for (auto it : subsets.getHoistableSubsetOps()) {
3332164a449SMatthias Springer auto extractionOp = std::get<0>(it);
3342164a449SMatthias Springer auto insertionOp = std::get<1>(it);
3352164a449SMatthias Springer
3362164a449SMatthias Springer // Ops cannot be hoisted if they depend on loop-variant values.
3372164a449SMatthias Springer if (extractionOp) {
3382164a449SMatthias Springer if (!canBeHoisted(extractionOp, [&](OpOperand &operand) {
3392164a449SMatthias Springer return loopLike.isDefinedOutsideOfLoop(operand.get()) ||
3402164a449SMatthias Springer &operand == &extractionOp.getSourceOperand();
3412164a449SMatthias Springer }))
3422164a449SMatthias Springer extractionOp = {};
3432164a449SMatthias Springer }
3442164a449SMatthias Springer if (insertionOp) {
3452164a449SMatthias Springer if (!canBeHoisted(insertionOp, [&](OpOperand &operand) {
3462164a449SMatthias Springer return loopLike.isDefinedOutsideOfLoop(operand.get()) ||
3472164a449SMatthias Springer &operand == &insertionOp.getSourceOperand() ||
3482164a449SMatthias Springer &operand == &insertionOp.getDestinationOperand();
3492164a449SMatthias Springer }))
3502164a449SMatthias Springer insertionOp = {};
3512164a449SMatthias Springer }
3522164a449SMatthias Springer
3532164a449SMatthias Springer // Only hoist extraction-insertion pairs for now. Standalone extractions/
3542164a449SMatthias Springer // insertions that are loop-invariant could be hoisted, but there may be
3552164a449SMatthias Springer // easier ways to canonicalize the IR.
3562164a449SMatthias Springer if (extractionOp && insertionOp) {
3572164a449SMatthias Springer // Create a new loop with an additional iter_arg.
3582164a449SMatthias Springer NewYieldValuesFn newYieldValuesFn =
3592164a449SMatthias Springer [&](OpBuilder &b, Location loc,
3602164a449SMatthias Springer ArrayRef<BlockArgument> innerNewBBArgs) -> SmallVector<Value> {
3612164a449SMatthias Springer return {insertionOp.getSourceOperand().get()};
3622164a449SMatthias Springer };
3632164a449SMatthias Springer FailureOr<LoopLikeOpInterface> newLoop =
3642164a449SMatthias Springer loopLike.replaceWithAdditionalYields(
3652164a449SMatthias Springer rewriter, extractionOp.getResult(),
3662164a449SMatthias Springer /*replaceInitOperandUsesInLoop=*/true, newYieldValuesFn);
3672164a449SMatthias Springer if (failed(newLoop))
3682164a449SMatthias Springer return loopLike;
3692164a449SMatthias Springer loopLike = *newLoop;
3702164a449SMatthias Springer
3712164a449SMatthias Springer // Hoist the extraction/insertion ops.
3722164a449SMatthias Springer iterArg = loopLike.getRegionIterArgs()[iterArgIdx];
3732164a449SMatthias Springer OpResult loopResult = loopLike.getTiedLoopResult(iterArg);
3742164a449SMatthias Springer OpResult newLoopResult = loopLike.getLoopResults()->back();
3755cc0f76dSMatthias Springer rewriter.moveOpBefore(extractionOp, loopLike);
3765cc0f76dSMatthias Springer rewriter.moveOpAfter(insertionOp, loopLike);
377b9fe461eSMatthias Springer rewriter.replaceAllUsesWith(insertionOp.getUpdatedDestination(),
3782164a449SMatthias Springer insertionOp.getDestinationOperand().get());
3792164a449SMatthias Springer extractionOp.getSourceOperand().set(
3802164a449SMatthias Springer loopLike.getTiedLoopInit(iterArg)->get());
381b9fe461eSMatthias Springer rewriter.replaceAllUsesWith(loopResult,
382b9fe461eSMatthias Springer insertionOp.getUpdatedDestination());
3832164a449SMatthias Springer insertionOp.getSourceOperand().set(newLoopResult);
3842164a449SMatthias Springer insertionOp.getDestinationOperand().set(loopResult);
3852164a449SMatthias Springer }
3862164a449SMatthias Springer }
3872164a449SMatthias Springer
3882164a449SMatthias Springer return loopLike;
3892164a449SMatthias Springer }
3902164a449SMatthias Springer
3912164a449SMatthias Springer LoopLikeOpInterface
hoistLoopInvariantSubsets(RewriterBase & rewriter,LoopLikeOpInterface loopLike)392b9fe461eSMatthias Springer mlir::hoistLoopInvariantSubsets(RewriterBase &rewriter,
393b9fe461eSMatthias Springer LoopLikeOpInterface loopLike) {
3942164a449SMatthias Springer // Note: As subset ops are getting hoisted, the number of region iter_args
3952164a449SMatthias Springer // increases. This can enable further hoisting opportunities on the new
3962164a449SMatthias Springer // iter_args.
3972164a449SMatthias Springer for (int64_t i = 0;
3982164a449SMatthias Springer i < static_cast<int64_t>(loopLike.getRegionIterArgs().size()); ++i) {
399b9fe461eSMatthias Springer loopLike = hoistSubsetAtIterArg(rewriter, loopLike,
400b9fe461eSMatthias Springer loopLike.getRegionIterArgs()[i]);
4012164a449SMatthias Springer }
4022164a449SMatthias Springer return loopLike;
4032164a449SMatthias Springer }
404