xref: /llvm-project/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp (revision 9b5ef2bea8e3de8fd564ab74b8123b446883216f)
1 //===- LoopInvariantCodeMotionUtils.cpp - LICM Utils ------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the implementation of the core LICM algorithm.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
14 #include "mlir/IR/Operation.h"
15 #include "mlir/Interfaces/LoopLikeInterface.h"
16 #include "mlir/Interfaces/SideEffectInterfaces.h"
17 #include "llvm/Support/Debug.h"
18 #include <queue>
19 
20 #define DEBUG_TYPE "licm"
21 
22 using namespace mlir;
23 
24 /// Checks whether the given op can be hoisted by checking that
25 /// - the op and none of its contained operations depend on values inside of the
26 ///   loop (by means of calling definedOutside).
27 /// - the op has no side-effects.
28 static bool canBeHoisted(Operation *op,
29                          function_ref<bool(Value)> definedOutside) {
30   // Do not move terminators.
31   if (op->hasTrait<OpTrait::IsTerminator>())
32     return false;
33 
34   // Walk the nested operations and check that all used values are either
35   // defined outside of the loop or in a nested region, but not at the level of
36   // the loop body.
37   auto walkFn = [&](Operation *child) {
38     for (Value operand : child->getOperands()) {
39       // Ignore values defined in a nested region.
40       if (op->isAncestor(operand.getParentRegion()->getParentOp()))
41         continue;
42       if (!definedOutside(operand))
43         return WalkResult::interrupt();
44     }
45     return WalkResult::advance();
46   };
47   return !op->walk(walkFn).wasInterrupted();
48 }
49 
50 size_t mlir::moveLoopInvariantCode(
51     ArrayRef<Region *> regions,
52     function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
53     function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
54     function_ref<void(Operation *, Region *)> moveOutOfRegion) {
55   size_t numMoved = 0;
56 
57   for (Region *region : regions) {
58     LLVM_DEBUG(llvm::dbgs() << "Original loop:\n"
59                             << *region->getParentOp() << "\n");
60 
61     std::queue<Operation *> worklist;
62     // Add top-level operations in the loop body to the worklist.
63     for (Operation &op : region->getOps())
64       worklist.push(&op);
65 
66     auto definedOutside = [&](Value value) {
67       return isDefinedOutsideRegion(value, region);
68     };
69 
70     while (!worklist.empty()) {
71       Operation *op = worklist.front();
72       worklist.pop();
73       // Skip ops that have already been moved. Check if the op can be hoisted.
74       if (op->getParentRegion() != region)
75         continue;
76 
77       LLVM_DEBUG(llvm::dbgs() << "Checking op: " << *op << "\n");
78       if (!shouldMoveOutOfRegion(op, region) ||
79           !canBeHoisted(op, definedOutside))
80         continue;
81 
82       LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *op << "\n");
83       moveOutOfRegion(op, region);
84       ++numMoved;
85 
86       // Since the op has been moved, we need to check its users within the
87       // top-level of the loop body.
88       for (Operation *user : op->getUsers())
89         if (user->getParentRegion() == region)
90           worklist.push(user);
91     }
92   }
93 
94   return numMoved;
95 }
96 
97 size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) {
98   return moveLoopInvariantCode(
99       loopLike.getLoopRegions(),
100       [&](Value value, Region *) {
101         return loopLike.isDefinedOutsideOfLoop(value);
102       },
103       [&](Operation *op, Region *) {
104         return isMemoryEffectFree(op) && isSpeculatable(op);
105       },
106       [&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); });
107 }
108