xref: /llvm-project/mlir/lib/Transforms/Mem2Reg.cpp (revision 5fdda4147495c6bcfb6c100dbf32a669c3b39874)
1f88f8fd0SThéo Degioanni //===- Mem2Reg.cpp - Promotes memory slots into values ----------*- C++ -*-===//
2f88f8fd0SThéo Degioanni //
3f88f8fd0SThéo Degioanni // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4f88f8fd0SThéo Degioanni // See https://llvm.org/LICENSE.txt for license information.
5f88f8fd0SThéo Degioanni // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6f88f8fd0SThéo Degioanni //
7f88f8fd0SThéo Degioanni //===----------------------------------------------------------------------===//
8f88f8fd0SThéo Degioanni 
9f88f8fd0SThéo Degioanni #include "mlir/Transforms/Mem2Reg.h"
1098c6bc53SChristian Ulmann #include "mlir/Analysis/DataLayoutAnalysis.h"
11f88f8fd0SThéo Degioanni #include "mlir/Analysis/SliceAnalysis.h"
12b00e0c16SChristian Ulmann #include "mlir/Analysis/TopologicalSortUtils.h"
13f88f8fd0SThéo Degioanni #include "mlir/IR/Builders.h"
1492cc30acSThéo Degioanni #include "mlir/IR/Dominance.h"
15ead8e9d7SThéo Degioanni #include "mlir/IR/PatternMatch.h"
16b084111cSThéo Degioanni #include "mlir/IR/RegionKindInterface.h"
17ead8e9d7SThéo Degioanni #include "mlir/IR/Value.h"
18f88f8fd0SThéo Degioanni #include "mlir/Interfaces/ControlFlowInterfaces.h"
1992cc30acSThéo Degioanni #include "mlir/Interfaces/MemorySlotInterfaces.h"
20f88f8fd0SThéo Degioanni #include "mlir/Transforms/Passes.h"
21f88f8fd0SThéo Degioanni #include "llvm/ADT/STLExtras.h"
22f88f8fd0SThéo Degioanni #include "llvm/Support/GenericIteratedDominanceFrontier.h"
23f88f8fd0SThéo Degioanni 
24f88f8fd0SThéo Degioanni namespace mlir {
25f88f8fd0SThéo Degioanni #define GEN_PASS_DEF_MEM2REG
26f88f8fd0SThéo Degioanni #include "mlir/Transforms/Passes.h.inc"
27f88f8fd0SThéo Degioanni } // namespace mlir
28f88f8fd0SThéo Degioanni 
293ba79a36SThéo Degioanni #define DEBUG_TYPE "mem2reg"
303ba79a36SThéo Degioanni 
31f88f8fd0SThéo Degioanni using namespace mlir;
32f88f8fd0SThéo Degioanni 
33f88f8fd0SThéo Degioanni /// mem2reg
34f88f8fd0SThéo Degioanni ///
35f88f8fd0SThéo Degioanni /// This pass turns unnecessary uses of automatically allocated memory slots
36f88f8fd0SThéo Degioanni /// into direct Value-based operations. For example, it will simplify storing a
37f88f8fd0SThéo Degioanni /// constant in a memory slot to immediately load it to a direct use of that
38f88f8fd0SThéo Degioanni /// constant. In other words, given a memory slot addressed by a non-aliased
39f88f8fd0SThéo Degioanni /// "pointer" Value, mem2reg removes all the uses of that pointer.
40f88f8fd0SThéo Degioanni ///
41f88f8fd0SThéo Degioanni /// Within a block, this is done by following the chain of stores and loads of
42f88f8fd0SThéo Degioanni /// the slot and replacing the results of loads with the values previously
43f88f8fd0SThéo Degioanni /// stored. If a load happens before any other store, a poison value is used
44f88f8fd0SThéo Degioanni /// instead.
45f88f8fd0SThéo Degioanni ///
46f88f8fd0SThéo Degioanni /// Control flow can create situations where a load could be replaced by
47f88f8fd0SThéo Degioanni /// multiple possible stores depending on the control flow path taken. As a
/// result, this pass must introduce new block arguments in some blocks to
/// accommodate the multiple possible definitions. Each predecessor will
50f88f8fd0SThéo Degioanni /// populate the block argument with the definition reached at its end. With
51f88f8fd0SThéo Degioanni /// this, the value stored can be well defined at block boundaries, allowing
52f88f8fd0SThéo Degioanni /// the propagation of replacement through blocks.
53f88f8fd0SThéo Degioanni ///
/// This pass computes this transformation in four main steps. The first two
/// steps are performed during an analysis phase that does not mutate IR.
5692cc30acSThéo Degioanni ///
5792cc30acSThéo Degioanni /// The two steps of the analysis phase are the following:
58f88f8fd0SThéo Degioanni /// - A first step computes the list of operations that transitively use the
59f88f8fd0SThéo Degioanni /// memory slot we would like to promote. The purpose of this phase is to
60f88f8fd0SThéo Degioanni /// identify which uses must be removed to promote the slot, either by rewiring
61f88f8fd0SThéo Degioanni /// the user or deleting it. Naturally, direct uses of the slot must be removed.
62f88f8fd0SThéo Degioanni /// Sometimes additional uses must also be removed: this is notably the case
63f88f8fd0SThéo Degioanni /// when a direct user of the slot cannot rewire its use and must delete itself,
64f88f8fd0SThéo Degioanni /// and thus must make its users no longer use it. If any of those uses cannot
65f88f8fd0SThéo Degioanni /// be removed by their users in any way, promotion cannot continue: this is
66f88f8fd0SThéo Degioanni /// decided at this step.
67f88f8fd0SThéo Degioanni /// - A second step computes the list of blocks where a block argument will be
68f88f8fd0SThéo Degioanni /// needed ("merge points") without mutating the IR. These blocks are the blocks
69f88f8fd0SThéo Degioanni /// leading to a definition clash between two predecessors. Such blocks happen
70f88f8fd0SThéo Degioanni /// to be the Iterated Dominance Frontier (IDF) of the set of blocks containing
71f88f8fd0SThéo Degioanni /// a store, as they represent the point where a clear defining dominator stops
72f88f8fd0SThéo Degioanni /// existing. Computing this information in advance allows making sure the
73f88f8fd0SThéo Degioanni /// terminators that will forward values are capable of doing so (inability to
74f88f8fd0SThéo Degioanni /// do so aborts promotion at this step).
7592cc30acSThéo Degioanni ///
7692cc30acSThéo Degioanni /// At this point, promotion is guaranteed to happen, and the mutation phase can
7792cc30acSThéo Degioanni /// begin with the following steps:
78f88f8fd0SThéo Degioanni /// - A third step computes the reaching definition of the memory slot at each
79f88f8fd0SThéo Degioanni /// blocking user. This is the core of the mem2reg algorithm, also known as
80f88f8fd0SThéo Degioanni /// load-store forwarding. This analyses loads and stores and propagates which
81f88f8fd0SThéo Degioanni /// value must be stored in the slot at each blocking user.  This is achieved by
82f88f8fd0SThéo Degioanni /// doing a depth-first walk of the dominator tree of the function. This is
83f88f8fd0SThéo Degioanni /// sufficient because the reaching definition at the beginning of a block is
84f88f8fd0SThéo Degioanni /// either its new block argument if it is a merge block, or the definition
85f88f8fd0SThéo Degioanni /// reaching the end of its immediate dominator (parent in the dominator tree).
86f88f8fd0SThéo Degioanni /// We can therefore propagate this information down the dominator tree to
87f88f8fd0SThéo Degioanni /// proceed with renaming within blocks.
88f88f8fd0SThéo Degioanni /// - The final fourth step uses the reaching definition to remove blocking uses
89f88f8fd0SThéo Degioanni /// in topological order.
90f88f8fd0SThéo Degioanni ///
91f88f8fd0SThéo Degioanni /// For further reading, chapter three of SSA-based Compiler Design [1]
92f88f8fd0SThéo Degioanni /// showcases SSA construction, where mem2reg is an adaptation of the same
93f88f8fd0SThéo Degioanni /// process.
94f88f8fd0SThéo Degioanni ///
95f88f8fd0SThéo Degioanni /// [1]: Rastello F. & Bouchez Tichadou F., SSA-based Compiler Design (2022),
96f88f8fd0SThéo Degioanni ///      Springer.
97f88f8fd0SThéo Degioanni 
98ead8e9d7SThéo Degioanni namespace {
99ead8e9d7SThéo Degioanni 
100ab6a66dbSChristian Ulmann using BlockingUsesMap =
101ab6a66dbSChristian Ulmann     llvm::MapVector<Operation *, SmallPtrSet<OpOperand *, 4>>;
102ab6a66dbSChristian Ulmann 
103ead8e9d7SThéo Degioanni /// Information computed during promotion analysis used to perform actual
104ead8e9d7SThéo Degioanni /// promotion.
105ead8e9d7SThéo Degioanni struct MemorySlotPromotionInfo {
106ead8e9d7SThéo Degioanni   /// Blocks for which at least two definitions of the slot values clash.
107ead8e9d7SThéo Degioanni   SmallPtrSet<Block *, 8> mergePoints;
108ead8e9d7SThéo Degioanni   /// Contains, for each operation, which uses must be eliminated by promotion.
109ead8e9d7SThéo Degioanni   /// This is a DAG structure because if an operation must eliminate some of
110ead8e9d7SThéo Degioanni   /// its uses, it is because the defining ops of the blocking uses requested
111ead8e9d7SThéo Degioanni   /// it. The defining ops therefore must also have blocking uses or be the
11260baaf15SCongcong Cai   /// starting point of the blocking uses.
113ab6a66dbSChristian Ulmann   BlockingUsesMap userToBlockingUses;
114ead8e9d7SThéo Degioanni };
115ead8e9d7SThéo Degioanni 
116ead8e9d7SThéo Degioanni /// Computes information for basic slot promotion. This will check that direct
117ead8e9d7SThéo Degioanni /// slot promotion can be performed, and provide the information to execute the
118ead8e9d7SThéo Degioanni /// promotion. This does not mutate IR.
119ead8e9d7SThéo Degioanni class MemorySlotPromotionAnalyzer {
120ead8e9d7SThéo Degioanni public:
12198c6bc53SChristian Ulmann   MemorySlotPromotionAnalyzer(MemorySlot slot, DominanceInfo &dominance,
12298c6bc53SChristian Ulmann                               const DataLayout &dataLayout)
12398c6bc53SChristian Ulmann       : slot(slot), dominance(dominance), dataLayout(dataLayout) {}
124ead8e9d7SThéo Degioanni 
125ead8e9d7SThéo Degioanni   /// Computes the information for slot promotion if promotion is possible,
126ead8e9d7SThéo Degioanni   /// returns nothing otherwise.
127ead8e9d7SThéo Degioanni   std::optional<MemorySlotPromotionInfo> computeInfo();
128ead8e9d7SThéo Degioanni 
129ead8e9d7SThéo Degioanni private:
130ead8e9d7SThéo Degioanni   /// Computes the transitive uses of the slot that block promotion. This finds
131ead8e9d7SThéo Degioanni   /// uses that would block the promotion, checks that the operation has a
132ead8e9d7SThéo Degioanni   /// solution to remove the blocking use, and potentially forwards the analysis
133ead8e9d7SThéo Degioanni   /// if the operation needs further blocking uses resolved to resolve its own
134ead8e9d7SThéo Degioanni   /// uses (typically, removing its users because it will delete itself to
135ead8e9d7SThéo Degioanni   /// resolve its own blocking uses). This will fail if one of the transitive
136ead8e9d7SThéo Degioanni   /// users cannot remove a requested use, and should prevent promotion.
137ab6a66dbSChristian Ulmann   LogicalResult computeBlockingUses(BlockingUsesMap &userToBlockingUses);
138ead8e9d7SThéo Degioanni 
139ead8e9d7SThéo Degioanni   /// Computes in which blocks the value stored in the slot is actually used,
140ead8e9d7SThéo Degioanni   /// meaning blocks leading to a load. This method uses `definingBlocks`, the
141ead8e9d7SThéo Degioanni   /// set of blocks containing a store to the slot (defining the value of the
142ead8e9d7SThéo Degioanni   /// slot).
143ead8e9d7SThéo Degioanni   SmallPtrSet<Block *, 16>
144ead8e9d7SThéo Degioanni   computeSlotLiveIn(SmallPtrSetImpl<Block *> &definingBlocks);
145ead8e9d7SThéo Degioanni 
146ead8e9d7SThéo Degioanni   /// Computes the points in which multiple re-definitions of the slot's value
147ead8e9d7SThéo Degioanni   /// (stores) may conflict.
148ead8e9d7SThéo Degioanni   void computeMergePoints(SmallPtrSetImpl<Block *> &mergePoints);
149ead8e9d7SThéo Degioanni 
150ead8e9d7SThéo Degioanni   /// Ensures predecessors of merge points can properly provide their current
151ead8e9d7SThéo Degioanni   /// definition of the value stored in the slot to the merge point. This can
152ead8e9d7SThéo Degioanni   /// notably be an issue if the terminator used does not have the ability to
153ead8e9d7SThéo Degioanni   /// forward values through block operands.
154ead8e9d7SThéo Degioanni   bool areMergePointsUsable(SmallPtrSetImpl<Block *> &mergePoints);
155ead8e9d7SThéo Degioanni 
156ead8e9d7SThéo Degioanni   MemorySlot slot;
157ead8e9d7SThéo Degioanni   DominanceInfo &dominance;
15898c6bc53SChristian Ulmann   const DataLayout &dataLayout;
159ead8e9d7SThéo Degioanni };
160ead8e9d7SThéo Degioanni 
161c6efcc92SChristian Ulmann using BlockIndexCache = DenseMap<Region *, DenseMap<Block *, size_t>>;
162c6efcc92SChristian Ulmann 
163ead8e9d7SThéo Degioanni /// The MemorySlotPromoter handles the state of promoting a memory slot. It
164ead8e9d7SThéo Degioanni /// wraps a slot and its associated allocator. This will perform the mutation of
165ead8e9d7SThéo Degioanni /// IR.
166ead8e9d7SThéo Degioanni class MemorySlotPromoter {
167ead8e9d7SThéo Degioanni public:
168ead8e9d7SThéo Degioanni   MemorySlotPromoter(MemorySlot slot, PromotableAllocationOpInterface allocator,
169084e2b53SChristian Ulmann                      OpBuilder &builder, DominanceInfo &dominance,
170ac39fa74SChristian Ulmann                      const DataLayout &dataLayout, MemorySlotPromotionInfo info,
171c6efcc92SChristian Ulmann                      const Mem2RegStatistics &statistics,
172c6efcc92SChristian Ulmann                      BlockIndexCache &blockIndexCache);
173ead8e9d7SThéo Degioanni 
174ead8e9d7SThéo Degioanni   /// Actually promotes the slot by mutating IR. Promoting a slot DOES
175ead8e9d7SThéo Degioanni   /// invalidate the MemorySlotPromotionInfo of other slots. Preparation of
176ead8e9d7SThéo Degioanni   /// promotion info should NOT be performed in batches.
177eeafc9daSChristian Ulmann   /// Returns a promotable allocation op if a new allocator was created, nullopt
178eeafc9daSChristian Ulmann   /// otherwise.
179eeafc9daSChristian Ulmann   std::optional<PromotableAllocationOpInterface> promoteSlot();
180ead8e9d7SThéo Degioanni 
181ead8e9d7SThéo Degioanni private:
182ead8e9d7SThéo Degioanni   /// Computes the reaching definition for all the operations that require
183ead8e9d7SThéo Degioanni   /// promotion. `reachingDef` is the value the slot should contain at the
184ead8e9d7SThéo Degioanni   /// beginning of the block. This method returns the reached definition at the
1858404b23aSThéo Degioanni   /// end of the block. This method must only be called at most once per block.
186ead8e9d7SThéo Degioanni   Value computeReachingDefInBlock(Block *block, Value reachingDef);
187ead8e9d7SThéo Degioanni 
188ead8e9d7SThéo Degioanni   /// Computes the reaching definition for all the operations that require
189ead8e9d7SThéo Degioanni   /// promotion. `reachingDef` corresponds to the initial value the
190ead8e9d7SThéo Degioanni   /// slot will contain before any write, typically a poison value.
1918404b23aSThéo Degioanni   /// This method must only be called at most once per region.
192ead8e9d7SThéo Degioanni   void computeReachingDefInRegion(Region *region, Value reachingDef);
193ead8e9d7SThéo Degioanni 
194ead8e9d7SThéo Degioanni   /// Removes the blocking uses of the slot, in topological order.
195ead8e9d7SThéo Degioanni   void removeBlockingUses();
196ead8e9d7SThéo Degioanni 
197ead8e9d7SThéo Degioanni   /// Lazily-constructed default value representing the content of the slot when
198ead8e9d7SThéo Degioanni   /// no store has been executed. This function may mutate IR.
1996e9ea6eaSChristian Ulmann   Value getOrCreateDefaultValue();
200ead8e9d7SThéo Degioanni 
201ead8e9d7SThéo Degioanni   MemorySlot slot;
202ead8e9d7SThéo Degioanni   PromotableAllocationOpInterface allocator;
203084e2b53SChristian Ulmann   OpBuilder &builder;
2046e9ea6eaSChristian Ulmann   /// Potentially non-initialized default value. Use `getOrCreateDefaultValue`
2056e9ea6eaSChristian Ulmann   /// to initialize it on demand.
206ead8e9d7SThéo Degioanni   Value defaultValue;
207ead8e9d7SThéo Degioanni   /// Contains the reaching definition at this operation. Reaching definitions
208ead8e9d7SThéo Degioanni   /// are only computed for promotable memory operations with blocking uses.
209ead8e9d7SThéo Degioanni   DenseMap<PromotableMemOpInterface, Value> reachingDefs;
210220cdf94SFabian Mora   DenseMap<PromotableMemOpInterface, Value> replacedValuesMap;
211ead8e9d7SThéo Degioanni   DominanceInfo &dominance;
212ac39fa74SChristian Ulmann   const DataLayout &dataLayout;
213ead8e9d7SThéo Degioanni   MemorySlotPromotionInfo info;
214ead8e9d7SThéo Degioanni   const Mem2RegStatistics &statistics;
215c6efcc92SChristian Ulmann 
216c6efcc92SChristian Ulmann   /// Shared cache of block indices of specific regions.
217c6efcc92SChristian Ulmann   BlockIndexCache &blockIndexCache;
218ead8e9d7SThéo Degioanni };
219ead8e9d7SThéo Degioanni 
220ead8e9d7SThéo Degioanni } // namespace
221ead8e9d7SThéo Degioanni 
22292cc30acSThéo Degioanni MemorySlotPromoter::MemorySlotPromoter(
22392cc30acSThéo Degioanni     MemorySlot slot, PromotableAllocationOpInterface allocator,
224084e2b53SChristian Ulmann     OpBuilder &builder, DominanceInfo &dominance, const DataLayout &dataLayout,
225c6efcc92SChristian Ulmann     MemorySlotPromotionInfo info, const Mem2RegStatistics &statistics,
226c6efcc92SChristian Ulmann     BlockIndexCache &blockIndexCache)
227084e2b53SChristian Ulmann     : slot(slot), allocator(allocator), builder(builder), dominance(dominance),
228c6efcc92SChristian Ulmann       dataLayout(dataLayout), info(std::move(info)), statistics(statistics),
229c6efcc92SChristian Ulmann       blockIndexCache(blockIndexCache) {
2301367c5d6SThéo Degioanni #ifndef NDEBUG
2311367c5d6SThéo Degioanni   auto isResultOrNewBlockArgument = [&]() {
2325550c821STres Popp     if (BlockArgument arg = dyn_cast<BlockArgument>(slot.ptr))
2331367c5d6SThéo Degioanni       return arg.getOwner()->getParentOp() == allocator;
2341367c5d6SThéo Degioanni     return slot.ptr.getDefiningOp() == allocator;
2351367c5d6SThéo Degioanni   };
2361367c5d6SThéo Degioanni 
2371367c5d6SThéo Degioanni   assert(isResultOrNewBlockArgument() &&
238f88f8fd0SThéo Degioanni          "a slot must be a result of the allocator or an argument of the child "
239f88f8fd0SThéo Degioanni          "regions of the allocator");
2401367c5d6SThéo Degioanni #endif // NDEBUG
241f88f8fd0SThéo Degioanni }
242f88f8fd0SThéo Degioanni 
2436e9ea6eaSChristian Ulmann Value MemorySlotPromoter::getOrCreateDefaultValue() {
244f88f8fd0SThéo Degioanni   if (defaultValue)
245f88f8fd0SThéo Degioanni     return defaultValue;
246f88f8fd0SThéo Degioanni 
247084e2b53SChristian Ulmann   OpBuilder::InsertionGuard guard(builder);
248084e2b53SChristian Ulmann   builder.setInsertionPointToStart(slot.ptr.getParentBlock());
249084e2b53SChristian Ulmann   return defaultValue = allocator.getDefaultValue(slot, builder);
250f88f8fd0SThéo Degioanni }
251f88f8fd0SThéo Degioanni 
25292cc30acSThéo Degioanni LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
253ab6a66dbSChristian Ulmann     BlockingUsesMap &userToBlockingUses) {
254f88f8fd0SThéo Degioanni   // The promotion of an operation may require the promotion of further
255f88f8fd0SThéo Degioanni   // operations (typically, removing operations that use an operation that must
256f88f8fd0SThéo Degioanni   // delete itself). We thus need to start from the use of the slot pointer and
257f88f8fd0SThéo Degioanni   // propagate further requests through the forward slice.
258f88f8fd0SThéo Degioanni 
259b084111cSThéo Degioanni   // Because this pass currently only supports analysing the parent region of
260b084111cSThéo Degioanni   // the slot pointer, if a promotable memory op that needs promotion is within
261b084111cSThéo Degioanni   // a graph region, the slot may only be used in a graph region and should
262b084111cSThéo Degioanni   // therefore be ignored.
263b084111cSThéo Degioanni   Region *slotPtrRegion = slot.ptr.getParentRegion();
264b084111cSThéo Degioanni   auto slotPtrRegionOp =
265b084111cSThéo Degioanni       dyn_cast<RegionKindInterface>(slotPtrRegion->getParentOp());
266b084111cSThéo Degioanni   if (slotPtrRegionOp &&
267b084111cSThéo Degioanni       slotPtrRegionOp.getRegionKind(slotPtrRegion->getRegionNumber()) ==
268b084111cSThéo Degioanni           RegionKind::Graph)
269b084111cSThéo Degioanni     return failure();
270b084111cSThéo Degioanni 
271f88f8fd0SThéo Degioanni   // First insert that all immediate users of the slot pointer must no longer
272f88f8fd0SThéo Degioanni   // use it.
273f88f8fd0SThéo Degioanni   for (OpOperand &use : slot.ptr.getUses()) {
274f88f8fd0SThéo Degioanni     SmallPtrSet<OpOperand *, 4> &blockingUses =
275ab6a66dbSChristian Ulmann         userToBlockingUses[use.getOwner()];
276f88f8fd0SThéo Degioanni     blockingUses.insert(&use);
277f88f8fd0SThéo Degioanni   }
278f88f8fd0SThéo Degioanni 
279f88f8fd0SThéo Degioanni   // Then, propagate the requirements for the removal of uses. The
280f88f8fd0SThéo Degioanni   // topologically-sorted forward slice allows for all blocking uses of an
28192cc30acSThéo Degioanni   // operation to have been computed before it is reached. Operations are
282f88f8fd0SThéo Degioanni   // traversed in topological order of their uses, starting from the slot
283f88f8fd0SThéo Degioanni   // pointer.
284f88f8fd0SThéo Degioanni   SetVector<Operation *> forwardSlice;
285f88f8fd0SThéo Degioanni   mlir::getForwardSlice(slot.ptr, &forwardSlice);
286f88f8fd0SThéo Degioanni   for (Operation *user : forwardSlice) {
287f88f8fd0SThéo Degioanni     // If the next operation has no blocking uses, everything is fine.
288*5fdda414SKazu Hirata     auto it = userToBlockingUses.find(user);
289*5fdda414SKazu Hirata     if (it == userToBlockingUses.end())
290f88f8fd0SThéo Degioanni       continue;
291f88f8fd0SThéo Degioanni 
292*5fdda414SKazu Hirata     SmallPtrSet<OpOperand *, 4> &blockingUses = it->second;
293f88f8fd0SThéo Degioanni 
294f88f8fd0SThéo Degioanni     SmallVector<OpOperand *> newBlockingUses;
295f88f8fd0SThéo Degioanni     // If the operation decides it cannot deal with removing the blocking uses,
296f88f8fd0SThéo Degioanni     // promotion must fail.
297f88f8fd0SThéo Degioanni     if (auto promotable = dyn_cast<PromotableOpInterface>(user)) {
29898c6bc53SChristian Ulmann       if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses,
29998c6bc53SChristian Ulmann                                        dataLayout))
300f88f8fd0SThéo Degioanni         return failure();
301f88f8fd0SThéo Degioanni     } else if (auto promotable = dyn_cast<PromotableMemOpInterface>(user)) {
30298c6bc53SChristian Ulmann       if (!promotable.canUsesBeRemoved(slot, blockingUses, newBlockingUses,
30398c6bc53SChristian Ulmann                                        dataLayout))
304f88f8fd0SThéo Degioanni         return failure();
305f88f8fd0SThéo Degioanni     } else {
306f88f8fd0SThéo Degioanni       // An operation that has blocking uses must be promoted. If it is not
307f88f8fd0SThéo Degioanni       // promotable, promotion must fail.
308f88f8fd0SThéo Degioanni       return failure();
309f88f8fd0SThéo Degioanni     }
310f88f8fd0SThéo Degioanni 
311f88f8fd0SThéo Degioanni     // Then, register any new blocking uses for coming operations.
312f88f8fd0SThéo Degioanni     for (OpOperand *blockingUse : newBlockingUses) {
3131367c5d6SThéo Degioanni       assert(llvm::is_contained(user->getResults(), blockingUse->get()));
314f88f8fd0SThéo Degioanni 
315f88f8fd0SThéo Degioanni       SmallPtrSetImpl<OpOperand *> &newUserBlockingUseSet =
316ab6a66dbSChristian Ulmann           userToBlockingUses[blockingUse->getOwner()];
317f88f8fd0SThéo Degioanni       newUserBlockingUseSet.insert(blockingUse);
318f88f8fd0SThéo Degioanni     }
319f88f8fd0SThéo Degioanni   }
320f88f8fd0SThéo Degioanni 
321f88f8fd0SThéo Degioanni   // Because this pass currently only supports analysing the parent region of
32292cc30acSThéo Degioanni   // the slot pointer, if a promotable memory op that needs promotion is outside
32392cc30acSThéo Degioanni   // of this region, promotion must fail because it will be impossible to
32492cc30acSThéo Degioanni   // provide a valid `reachingDef` for it.
325f88f8fd0SThéo Degioanni   for (auto &[toPromote, _] : userToBlockingUses)
326f88f8fd0SThéo Degioanni     if (isa<PromotableMemOpInterface>(toPromote) &&
327f88f8fd0SThéo Degioanni         toPromote->getParentRegion() != slot.ptr.getParentRegion())
328f88f8fd0SThéo Degioanni       return failure();
329f88f8fd0SThéo Degioanni 
330f88f8fd0SThéo Degioanni   return success();
331f88f8fd0SThéo Degioanni }
332f88f8fd0SThéo Degioanni 
33392cc30acSThéo Degioanni SmallPtrSet<Block *, 16> MemorySlotPromotionAnalyzer::computeSlotLiveIn(
33492cc30acSThéo Degioanni     SmallPtrSetImpl<Block *> &definingBlocks) {
335f88f8fd0SThéo Degioanni   SmallPtrSet<Block *, 16> liveIn;
336f88f8fd0SThéo Degioanni 
337f88f8fd0SThéo Degioanni   // The worklist contains blocks in which it is known that the slot value is
338f88f8fd0SThéo Degioanni   // live-in. The further blocks where this value is live-in will be inferred
339f88f8fd0SThéo Degioanni   // from these.
340f88f8fd0SThéo Degioanni   SmallVector<Block *> liveInWorkList;
341f88f8fd0SThéo Degioanni 
342f88f8fd0SThéo Degioanni   // Blocks with a load before any other store to the slot are the starting
343f88f8fd0SThéo Degioanni   // points of the analysis. The slot value is definitely live-in in those
344f88f8fd0SThéo Degioanni   // blocks.
345f88f8fd0SThéo Degioanni   SmallPtrSet<Block *, 16> visited;
346f88f8fd0SThéo Degioanni   for (Operation *user : slot.ptr.getUsers()) {
34742494e51SKazu Hirata     if (!visited.insert(user->getBlock()).second)
348f88f8fd0SThéo Degioanni       continue;
349f88f8fd0SThéo Degioanni 
350f88f8fd0SThéo Degioanni     for (Operation &op : user->getBlock()->getOperations()) {
351f88f8fd0SThéo Degioanni       if (auto memOp = dyn_cast<PromotableMemOpInterface>(op)) {
352f88f8fd0SThéo Degioanni         // If this operation loads the slot, it is loading from it before
353f88f8fd0SThéo Degioanni         // ever writing to it, so the value is live-in in this block.
354f88f8fd0SThéo Degioanni         if (memOp.loadsFrom(slot)) {
355f88f8fd0SThéo Degioanni           liveInWorkList.push_back(user->getBlock());
356f88f8fd0SThéo Degioanni           break;
357f88f8fd0SThéo Degioanni         }
358f88f8fd0SThéo Degioanni 
359f88f8fd0SThéo Degioanni         // If we store to the slot, further loads will see that value.
360f88f8fd0SThéo Degioanni         // Because we did not meet any load before, the value is not live-in.
3618404b23aSThéo Degioanni         if (memOp.storesTo(slot))
362f88f8fd0SThéo Degioanni           break;
363f88f8fd0SThéo Degioanni       }
364f88f8fd0SThéo Degioanni     }
365f88f8fd0SThéo Degioanni   }
366f88f8fd0SThéo Degioanni 
367f88f8fd0SThéo Degioanni   // The information is then propagated to the predecessors until a def site
368f88f8fd0SThéo Degioanni   // (store) is found.
369f88f8fd0SThéo Degioanni   while (!liveInWorkList.empty()) {
370f88f8fd0SThéo Degioanni     Block *liveInBlock = liveInWorkList.pop_back_val();
371f88f8fd0SThéo Degioanni 
372f88f8fd0SThéo Degioanni     if (!liveIn.insert(liveInBlock).second)
373f88f8fd0SThéo Degioanni       continue;
374f88f8fd0SThéo Degioanni 
375f88f8fd0SThéo Degioanni     // If a predecessor is a defining block, either:
376f88f8fd0SThéo Degioanni     // - It has a load before its first store, in which case it is live-in but
377f88f8fd0SThéo Degioanni     // has already been processed in the initialisation step.
378f88f8fd0SThéo Degioanni     // - It has a store before any load, in which case it is not live-in.
379f88f8fd0SThéo Degioanni     // We can thus at this stage insert to the worklist only predecessors that
380f88f8fd0SThéo Degioanni     // are not defining blocks.
381f88f8fd0SThéo Degioanni     for (Block *pred : liveInBlock->getPredecessors())
382f88f8fd0SThéo Degioanni       if (!definingBlocks.contains(pred))
383f88f8fd0SThéo Degioanni         liveInWorkList.push_back(pred);
384f88f8fd0SThéo Degioanni   }
385f88f8fd0SThéo Degioanni 
386f88f8fd0SThéo Degioanni   return liveIn;
387f88f8fd0SThéo Degioanni }
388f88f8fd0SThéo Degioanni 
389f88f8fd0SThéo Degioanni using IDFCalculator = llvm::IDFCalculatorBase<Block, false>;
39092cc30acSThéo Degioanni void MemorySlotPromotionAnalyzer::computeMergePoints(
39192cc30acSThéo Degioanni     SmallPtrSetImpl<Block *> &mergePoints) {
392f88f8fd0SThéo Degioanni   if (slot.ptr.getParentRegion()->hasOneBlock())
393f88f8fd0SThéo Degioanni     return;
394f88f8fd0SThéo Degioanni 
395f88f8fd0SThéo Degioanni   IDFCalculator idfCalculator(dominance.getDomTree(slot.ptr.getParentRegion()));
396f88f8fd0SThéo Degioanni 
397f88f8fd0SThéo Degioanni   SmallPtrSet<Block *, 16> definingBlocks;
398f88f8fd0SThéo Degioanni   for (Operation *user : slot.ptr.getUsers())
399f88f8fd0SThéo Degioanni     if (auto storeOp = dyn_cast<PromotableMemOpInterface>(user))
4008404b23aSThéo Degioanni       if (storeOp.storesTo(slot))
401f88f8fd0SThéo Degioanni         definingBlocks.insert(user->getBlock());
402f88f8fd0SThéo Degioanni 
403f88f8fd0SThéo Degioanni   idfCalculator.setDefiningBlocks(definingBlocks);
404f88f8fd0SThéo Degioanni 
405f88f8fd0SThéo Degioanni   SmallPtrSet<Block *, 16> liveIn = computeSlotLiveIn(definingBlocks);
406f88f8fd0SThéo Degioanni   idfCalculator.setLiveInBlocks(liveIn);
407f88f8fd0SThéo Degioanni 
408f88f8fd0SThéo Degioanni   SmallVector<Block *> mergePointsVec;
409f88f8fd0SThéo Degioanni   idfCalculator.calculate(mergePointsVec);
410f88f8fd0SThéo Degioanni 
411f88f8fd0SThéo Degioanni   mergePoints.insert(mergePointsVec.begin(), mergePointsVec.end());
412f88f8fd0SThéo Degioanni }
413f88f8fd0SThéo Degioanni 
41492cc30acSThéo Degioanni bool MemorySlotPromotionAnalyzer::areMergePointsUsable(
41592cc30acSThéo Degioanni     SmallPtrSetImpl<Block *> &mergePoints) {
416f88f8fd0SThéo Degioanni   for (Block *mergePoint : mergePoints)
417f88f8fd0SThéo Degioanni     for (Block *pred : mergePoint->getPredecessors())
418f88f8fd0SThéo Degioanni       if (!isa<BranchOpInterface>(pred->getTerminator()))
419f88f8fd0SThéo Degioanni         return false;
420f88f8fd0SThéo Degioanni 
421f88f8fd0SThéo Degioanni   return true;
422f88f8fd0SThéo Degioanni }
423f88f8fd0SThéo Degioanni 
42492cc30acSThéo Degioanni std::optional<MemorySlotPromotionInfo>
42592cc30acSThéo Degioanni MemorySlotPromotionAnalyzer::computeInfo() {
42692cc30acSThéo Degioanni   MemorySlotPromotionInfo info;
42792cc30acSThéo Degioanni 
42892cc30acSThéo Degioanni   // First, find the set of operations that will need to be changed for the
42992cc30acSThéo Degioanni   // promotion to happen. These operations need to resolve some of their uses,
43092cc30acSThéo Degioanni   // either by rewiring them or simply deleting themselves. If any of them
43192cc30acSThéo Degioanni   // cannot find a way to resolve their blocking uses, we abort the promotion.
43292cc30acSThéo Degioanni   if (failed(computeBlockingUses(info.userToBlockingUses)))
43392cc30acSThéo Degioanni     return {};
43492cc30acSThéo Degioanni 
43592cc30acSThéo Degioanni   // Then, compute blocks in which two or more definitions of the allocated
43692cc30acSThéo Degioanni   // variable may conflict. These blocks will need a new block argument to
43760baaf15SCongcong Cai   // accommodate this.
43892cc30acSThéo Degioanni   computeMergePoints(info.mergePoints);
43992cc30acSThéo Degioanni 
44092cc30acSThéo Degioanni   // The slot can be promoted if the block arguments to be created can
44192cc30acSThéo Degioanni   // actually be populated with values, which may not be possible depending
44292cc30acSThéo Degioanni   // on their predecessors.
44392cc30acSThéo Degioanni   if (!areMergePointsUsable(info.mergePoints))
44492cc30acSThéo Degioanni     return {};
44592cc30acSThéo Degioanni 
44692cc30acSThéo Degioanni   return info;
44792cc30acSThéo Degioanni }
44892cc30acSThéo Degioanni 
44992cc30acSThéo Degioanni Value MemorySlotPromoter::computeReachingDefInBlock(Block *block,
45092cc30acSThéo Degioanni                                                     Value reachingDef) {
4518404b23aSThéo Degioanni   SmallVector<Operation *> blockOps;
4528404b23aSThéo Degioanni   for (Operation &op : block->getOperations())
4538404b23aSThéo Degioanni     blockOps.push_back(&op);
4548404b23aSThéo Degioanni   for (Operation *op : blockOps) {
455f88f8fd0SThéo Degioanni     if (auto memOp = dyn_cast<PromotableMemOpInterface>(op)) {
45692cc30acSThéo Degioanni       if (info.userToBlockingUses.contains(memOp))
457f88f8fd0SThéo Degioanni         reachingDefs.insert({memOp, reachingDef});
458f88f8fd0SThéo Degioanni 
4598404b23aSThéo Degioanni       if (memOp.storesTo(slot)) {
460084e2b53SChristian Ulmann         builder.setInsertionPointAfter(memOp);
461084e2b53SChristian Ulmann         Value stored = memOp.getStored(slot, builder, reachingDef, dataLayout);
4628404b23aSThéo Degioanni         assert(stored && "a memory operation storing to a slot must provide a "
4638404b23aSThéo Degioanni                          "new definition of the slot");
464f88f8fd0SThéo Degioanni         reachingDef = stored;
465220cdf94SFabian Mora         replacedValuesMap[memOp] = stored;
466f88f8fd0SThéo Degioanni       }
467f88f8fd0SThéo Degioanni     }
4688404b23aSThéo Degioanni   }
469f88f8fd0SThéo Degioanni 
470f88f8fd0SThéo Degioanni   return reachingDef;
471f88f8fd0SThéo Degioanni }
472f88f8fd0SThéo Degioanni 
47392cc30acSThéo Degioanni void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
474f88f8fd0SThéo Degioanni                                                     Value reachingDef) {
4756e9ea6eaSChristian Ulmann   assert(reachingDef && "expected an initial reaching def to be provided");
476f88f8fd0SThéo Degioanni   if (region->hasOneBlock()) {
477f88f8fd0SThéo Degioanni     computeReachingDefInBlock(&region->front(), reachingDef);
478f88f8fd0SThéo Degioanni     return;
479f88f8fd0SThéo Degioanni   }
480f88f8fd0SThéo Degioanni 
481f88f8fd0SThéo Degioanni   struct DfsJob {
482f88f8fd0SThéo Degioanni     llvm::DomTreeNodeBase<Block> *block;
483f88f8fd0SThéo Degioanni     Value reachingDef;
484f88f8fd0SThéo Degioanni   };
485f88f8fd0SThéo Degioanni 
486f88f8fd0SThéo Degioanni   SmallVector<DfsJob> dfsStack;
487f88f8fd0SThéo Degioanni 
488f88f8fd0SThéo Degioanni   auto &domTree = dominance.getDomTree(slot.ptr.getParentRegion());
489f88f8fd0SThéo Degioanni 
490f88f8fd0SThéo Degioanni   dfsStack.emplace_back<DfsJob>(
491f88f8fd0SThéo Degioanni       {domTree.getNode(&region->front()), reachingDef});
492f88f8fd0SThéo Degioanni 
493f88f8fd0SThéo Degioanni   while (!dfsStack.empty()) {
494f88f8fd0SThéo Degioanni     DfsJob job = dfsStack.pop_back_val();
495f88f8fd0SThéo Degioanni     Block *block = job.block->getBlock();
496f88f8fd0SThéo Degioanni 
49792cc30acSThéo Degioanni     if (info.mergePoints.contains(block)) {
498084e2b53SChristian Ulmann       BlockArgument blockArgument =
499084e2b53SChristian Ulmann           block->addArgument(slot.elemType, slot.ptr.getLoc());
500084e2b53SChristian Ulmann       builder.setInsertionPointToStart(block);
501084e2b53SChristian Ulmann       allocator.handleBlockArgument(slot, blockArgument, builder);
502f88f8fd0SThéo Degioanni       job.reachingDef = blockArgument;
503ead8e9d7SThéo Degioanni 
504ead8e9d7SThéo Degioanni       if (statistics.newBlockArgumentAmount)
505ead8e9d7SThéo Degioanni         (*statistics.newBlockArgumentAmount)++;
506f88f8fd0SThéo Degioanni     }
507f88f8fd0SThéo Degioanni 
508f88f8fd0SThéo Degioanni     job.reachingDef = computeReachingDefInBlock(block, job.reachingDef);
5096e9ea6eaSChristian Ulmann     assert(job.reachingDef);
510f88f8fd0SThéo Degioanni 
511f88f8fd0SThéo Degioanni     if (auto terminator = dyn_cast<BranchOpInterface>(block->getTerminator())) {
512f88f8fd0SThéo Degioanni       for (BlockOperand &blockOperand : terminator->getBlockOperands()) {
51392cc30acSThéo Degioanni         if (info.mergePoints.contains(blockOperand.get())) {
514f88f8fd0SThéo Degioanni           terminator.getSuccessorOperands(blockOperand.getOperandNumber())
515f88f8fd0SThéo Degioanni               .append(job.reachingDef);
516f88f8fd0SThéo Degioanni         }
517f88f8fd0SThéo Degioanni       }
518f88f8fd0SThéo Degioanni     }
519f88f8fd0SThéo Degioanni 
520f88f8fd0SThéo Degioanni     for (auto *child : job.block->children())
521f88f8fd0SThéo Degioanni       dfsStack.emplace_back<DfsJob>({child, job.reachingDef});
522f88f8fd0SThéo Degioanni   }
523f88f8fd0SThéo Degioanni }
524f88f8fd0SThéo Degioanni 
525c6efcc92SChristian Ulmann /// Gets or creates a block index mapping for `region`.
526c6efcc92SChristian Ulmann static const DenseMap<Block *, size_t> &
527c6efcc92SChristian Ulmann getOrCreateBlockIndices(BlockIndexCache &blockIndexCache, Region *region) {
528c6efcc92SChristian Ulmann   auto [it, inserted] = blockIndexCache.try_emplace(region);
529c6efcc92SChristian Ulmann   if (!inserted)
530c6efcc92SChristian Ulmann     return it->second;
531c6efcc92SChristian Ulmann 
532c6efcc92SChristian Ulmann   DenseMap<Block *, size_t> &blockIndices = it->second;
533e919df57SChristian Ulmann   SetVector<Block *> topologicalOrder = getBlocksSortedByDominance(*region);
534c6efcc92SChristian Ulmann   for (auto [index, block] : llvm::enumerate(topologicalOrder))
535c6efcc92SChristian Ulmann     blockIndices[block] = index;
536c6efcc92SChristian Ulmann   return blockIndices;
537c6efcc92SChristian Ulmann }
538c6efcc92SChristian Ulmann 
539ab6a66dbSChristian Ulmann /// Sorts `ops` according to dominance. Relies on the topological order of basic
540c6efcc92SChristian Ulmann /// blocks to get a deterministic ordering. Uses `blockIndexCache` to avoid the
541c6efcc92SChristian Ulmann /// potentially expensive recomputation of a block index map.
542c6efcc92SChristian Ulmann static void dominanceSort(SmallVector<Operation *> &ops, Region &region,
543c6efcc92SChristian Ulmann                           BlockIndexCache &blockIndexCache) {
544ab6a66dbSChristian Ulmann   // Produce a topological block order and construct a map to lookup the indices
545ab6a66dbSChristian Ulmann   // of blocks.
546c6efcc92SChristian Ulmann   const DenseMap<Block *, size_t> &topoBlockIndices =
547c6efcc92SChristian Ulmann       getOrCreateBlockIndices(blockIndexCache, &region);
548ab6a66dbSChristian Ulmann 
549ab6a66dbSChristian Ulmann   // Combining the topological order of the basic blocks together with block
550ab6a66dbSChristian Ulmann   // internal operation order guarantees a deterministic, dominance respecting
551ab6a66dbSChristian Ulmann   // order.
552ab6a66dbSChristian Ulmann   llvm::sort(ops, [&](Operation *lhs, Operation *rhs) {
553ab6a66dbSChristian Ulmann     size_t lhsBlockIndex = topoBlockIndices.at(lhs->getBlock());
554ab6a66dbSChristian Ulmann     size_t rhsBlockIndex = topoBlockIndices.at(rhs->getBlock());
555ab6a66dbSChristian Ulmann     if (lhsBlockIndex == rhsBlockIndex)
556ab6a66dbSChristian Ulmann       return lhs->isBeforeInBlock(rhs);
557ab6a66dbSChristian Ulmann     return lhsBlockIndex < rhsBlockIndex;
558ab6a66dbSChristian Ulmann   });
559ab6a66dbSChristian Ulmann }
560ab6a66dbSChristian Ulmann 
56192cc30acSThéo Degioanni void MemorySlotPromoter::removeBlockingUses() {
562ab6a66dbSChristian Ulmann   llvm::SmallVector<Operation *> usersToRemoveUses(
563ab6a66dbSChristian Ulmann       llvm::make_first_range(info.userToBlockingUses));
564ab6a66dbSChristian Ulmann 
565ab6a66dbSChristian Ulmann   // Sort according to dominance.
566c6efcc92SChristian Ulmann   dominanceSort(usersToRemoveUses, *slot.ptr.getParentBlock()->getParent(),
567c6efcc92SChristian Ulmann                 blockIndexCache);
568f88f8fd0SThéo Degioanni 
569f88f8fd0SThéo Degioanni   llvm::SmallVector<Operation *> toErase;
570220cdf94SFabian Mora   // List of all replaced values in the slot.
571220cdf94SFabian Mora   llvm::SmallVector<std::pair<Operation *, Value>> replacedValuesList;
572220cdf94SFabian Mora   // Ops to visit with the `visitReplacedValues` method.
573220cdf94SFabian Mora   llvm::SmallVector<PromotableOpInterface> toVisit;
574ab6a66dbSChristian Ulmann   for (Operation *toPromote : llvm::reverse(usersToRemoveUses)) {
575f88f8fd0SThéo Degioanni     if (auto toPromoteMemOp = dyn_cast<PromotableMemOpInterface>(toPromote)) {
576f88f8fd0SThéo Degioanni       Value reachingDef = reachingDefs.lookup(toPromoteMemOp);
577f88f8fd0SThéo Degioanni       // If no reaching definition is known, this use is outside the reach of
578f88f8fd0SThéo Degioanni       // the slot. The default value should thus be used.
579f88f8fd0SThéo Degioanni       if (!reachingDef)
5806e9ea6eaSChristian Ulmann         reachingDef = getOrCreateDefaultValue();
581f88f8fd0SThéo Degioanni 
582084e2b53SChristian Ulmann       builder.setInsertionPointAfter(toPromote);
58392cc30acSThéo Degioanni       if (toPromoteMemOp.removeBlockingUses(
584084e2b53SChristian Ulmann               slot, info.userToBlockingUses[toPromote], builder, reachingDef,
585ac39fa74SChristian Ulmann               dataLayout) == DeletionKind::Delete)
586f88f8fd0SThéo Degioanni         toErase.push_back(toPromote);
587220cdf94SFabian Mora       if (toPromoteMemOp.storesTo(slot))
588220cdf94SFabian Mora         if (Value replacedValue = replacedValuesMap[toPromoteMemOp])
589220cdf94SFabian Mora           replacedValuesList.push_back({toPromoteMemOp, replacedValue});
590f88f8fd0SThéo Degioanni       continue;
591f88f8fd0SThéo Degioanni     }
592f88f8fd0SThéo Degioanni 
593f88f8fd0SThéo Degioanni     auto toPromoteBasic = cast<PromotableOpInterface>(toPromote);
594084e2b53SChristian Ulmann     builder.setInsertionPointAfter(toPromote);
59592cc30acSThéo Degioanni     if (toPromoteBasic.removeBlockingUses(info.userToBlockingUses[toPromote],
596084e2b53SChristian Ulmann                                           builder) == DeletionKind::Delete)
597f88f8fd0SThéo Degioanni       toErase.push_back(toPromote);
598220cdf94SFabian Mora     if (toPromoteBasic.requiresReplacedValues())
599220cdf94SFabian Mora       toVisit.push_back(toPromoteBasic);
600220cdf94SFabian Mora   }
601220cdf94SFabian Mora   for (PromotableOpInterface op : toVisit) {
602084e2b53SChristian Ulmann     builder.setInsertionPointAfter(op);
603084e2b53SChristian Ulmann     op.visitReplacedValues(replacedValuesList, builder);
604f88f8fd0SThéo Degioanni   }
605f88f8fd0SThéo Degioanni 
606f88f8fd0SThéo Degioanni   for (Operation *toEraseOp : toErase)
607084e2b53SChristian Ulmann     toEraseOp->erase();
608f88f8fd0SThéo Degioanni 
609f88f8fd0SThéo Degioanni   assert(slot.ptr.use_empty() &&
610f88f8fd0SThéo Degioanni          "after promotion, the slot pointer should not be used anymore");
611f88f8fd0SThéo Degioanni }
612f88f8fd0SThéo Degioanni 
613eeafc9daSChristian Ulmann std::optional<PromotableAllocationOpInterface>
614eeafc9daSChristian Ulmann MemorySlotPromoter::promoteSlot() {
6156e9ea6eaSChristian Ulmann   computeReachingDefInRegion(slot.ptr.getParentRegion(),
6166e9ea6eaSChristian Ulmann                              getOrCreateDefaultValue());
617f88f8fd0SThéo Degioanni 
618f88f8fd0SThéo Degioanni   // Now that reaching definitions are known, remove all users.
619f88f8fd0SThéo Degioanni   removeBlockingUses();
620f88f8fd0SThéo Degioanni 
621f88f8fd0SThéo Degioanni   // Update terminators in dead branches to forward default if they are
622f88f8fd0SThéo Degioanni   // succeeded by a merge points.
62392cc30acSThéo Degioanni   for (Block *mergePoint : info.mergePoints) {
624f88f8fd0SThéo Degioanni     for (BlockOperand &use : mergePoint->getUses()) {
625f88f8fd0SThéo Degioanni       auto user = cast<BranchOpInterface>(use.getOwner());
626f88f8fd0SThéo Degioanni       SuccessorOperands succOperands =
627f88f8fd0SThéo Degioanni           user.getSuccessorOperands(use.getOperandNumber());
628f88f8fd0SThéo Degioanni       assert(succOperands.size() == mergePoint->getNumArguments() ||
629f88f8fd0SThéo Degioanni              succOperands.size() + 1 == mergePoint->getNumArguments());
630f88f8fd0SThéo Degioanni       if (succOperands.size() + 1 == mergePoint->getNumArguments())
631084e2b53SChristian Ulmann         succOperands.append(getOrCreateDefaultValue());
632f88f8fd0SThéo Degioanni     }
633f88f8fd0SThéo Degioanni   }
634f88f8fd0SThéo Degioanni 
6353ba79a36SThéo Degioanni   LLVM_DEBUG(llvm::dbgs() << "[mem2reg] Promoted memory slot: " << slot.ptr
6363ba79a36SThéo Degioanni                           << "\n");
6373ba79a36SThéo Degioanni 
638ead8e9d7SThéo Degioanni   if (statistics.promotedAmount)
639ead8e9d7SThéo Degioanni     (*statistics.promotedAmount)++;
640ead8e9d7SThéo Degioanni 
641eeafc9daSChristian Ulmann   return allocator.handlePromotionComplete(slot, defaultValue, builder);
642f88f8fd0SThéo Degioanni }
643f88f8fd0SThéo Degioanni 
644f88f8fd0SThéo Degioanni LogicalResult mlir::tryToPromoteMemorySlots(
645084e2b53SChristian Ulmann     ArrayRef<PromotableAllocationOpInterface> allocators, OpBuilder &builder,
646c6efcc92SChristian Ulmann     const DataLayout &dataLayout, DominanceInfo &dominance,
647c6efcc92SChristian Ulmann     Mem2RegStatistics statistics) {
648ead8e9d7SThéo Degioanni   bool promotedAny = false;
649ead8e9d7SThéo Degioanni 
650c6efcc92SChristian Ulmann   // A cache that stores deterministic block indices which are used to determine
651c6efcc92SChristian Ulmann   // a valid operation modification order. The block index maps are computed
652c6efcc92SChristian Ulmann   // lazily and cached to avoid expensive recomputation.
653c6efcc92SChristian Ulmann   BlockIndexCache blockIndexCache;
654c6efcc92SChristian Ulmann 
6555262865aSKazu Hirata   SmallVector<PromotableAllocationOpInterface> workList(allocators);
656eeafc9daSChristian Ulmann 
657eeafc9daSChristian Ulmann   SmallVector<PromotableAllocationOpInterface> newWorkList;
658eeafc9daSChristian Ulmann   newWorkList.reserve(workList.size());
659eeafc9daSChristian Ulmann   while (true) {
660eeafc9daSChristian Ulmann     bool changesInThisRound = false;
661eeafc9daSChristian Ulmann     for (PromotableAllocationOpInterface allocator : workList) {
662eeafc9daSChristian Ulmann       bool changedAllocator = false;
663f88f8fd0SThéo Degioanni       for (MemorySlot slot : allocator.getPromotableSlots()) {
664f88f8fd0SThéo Degioanni         if (slot.ptr.use_empty())
665f88f8fd0SThéo Degioanni           continue;
666f88f8fd0SThéo Degioanni 
66798c6bc53SChristian Ulmann         MemorySlotPromotionAnalyzer analyzer(slot, dominance, dataLayout);
66892cc30acSThéo Degioanni         std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
669ead8e9d7SThéo Degioanni         if (info) {
670eeafc9daSChristian Ulmann           std::optional<PromotableAllocationOpInterface> newAllocator =
671eeafc9daSChristian Ulmann               MemorySlotPromoter(slot, allocator, builder, dominance,
672eeafc9daSChristian Ulmann                                  dataLayout, std::move(*info), statistics,
673eeafc9daSChristian Ulmann                                  blockIndexCache)
674ead8e9d7SThéo Degioanni                   .promoteSlot();
675eeafc9daSChristian Ulmann           changedAllocator = true;
676eeafc9daSChristian Ulmann           // Add newly created allocators to the worklist for further
677eeafc9daSChristian Ulmann           // processing.
678eeafc9daSChristian Ulmann           if (newAllocator)
679eeafc9daSChristian Ulmann             newWorkList.push_back(*newAllocator);
680eeafc9daSChristian Ulmann 
681eeafc9daSChristian Ulmann           // A break is required, since promoting a slot may invalidate the
682eeafc9daSChristian Ulmann           // remaining slots of an allocator.
683eeafc9daSChristian Ulmann           break;
684eeafc9daSChristian Ulmann         }
685eeafc9daSChristian Ulmann       }
686eeafc9daSChristian Ulmann       if (!changedAllocator)
687eeafc9daSChristian Ulmann         newWorkList.push_back(allocator);
688eeafc9daSChristian Ulmann       changesInThisRound |= changedAllocator;
689eeafc9daSChristian Ulmann     }
690eeafc9daSChristian Ulmann     if (!changesInThisRound)
691eeafc9daSChristian Ulmann       break;
692ead8e9d7SThéo Degioanni     promotedAny = true;
693eeafc9daSChristian Ulmann 
694eeafc9daSChristian Ulmann     // Swap the vector's backing memory and clear the entries in newWorkList
695eeafc9daSChristian Ulmann     // afterwards. This ensures that additional heap allocations can be avoided.
696eeafc9daSChristian Ulmann     workList.swap(newWorkList);
697eeafc9daSChristian Ulmann     newWorkList.clear();
698f88f8fd0SThéo Degioanni   }
699f88f8fd0SThéo Degioanni 
700ead8e9d7SThéo Degioanni   return success(promotedAny);
701f88f8fd0SThéo Degioanni }
702f88f8fd0SThéo Degioanni 
703f88f8fd0SThéo Degioanni namespace {
704f88f8fd0SThéo Degioanni 
705f88f8fd0SThéo Degioanni struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
706ead8e9d7SThéo Degioanni   using impl::Mem2RegBase<Mem2Reg>::Mem2RegBase;
707ead8e9d7SThéo Degioanni 
708f88f8fd0SThéo Degioanni   void runOnOperation() override {
709f88f8fd0SThéo Degioanni     Operation *scopeOp = getOperation();
710ead8e9d7SThéo Degioanni 
7113e2992f0SChristian Ulmann     Mem2RegStatistics statistics{&promotedAmount, &newBlockArgumentAmount};
712ead8e9d7SThéo Degioanni 
7133e2992f0SChristian Ulmann     bool changed = false;
714f88f8fd0SThéo Degioanni 
715c6efcc92SChristian Ulmann     auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
716c6efcc92SChristian Ulmann     const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(scopeOp);
717c6efcc92SChristian Ulmann     auto &dominance = getAnalysis<DominanceInfo>();
718c6efcc92SChristian Ulmann 
7193e2992f0SChristian Ulmann     for (Region &region : scopeOp->getRegions()) {
7203e2992f0SChristian Ulmann       if (region.getBlocks().empty())
7213e2992f0SChristian Ulmann         continue;
722f88f8fd0SThéo Degioanni 
7233e2992f0SChristian Ulmann       OpBuilder builder(&region.front(), region.front().begin());
7243e2992f0SChristian Ulmann 
7253e2992f0SChristian Ulmann       SmallVector<PromotableAllocationOpInterface> allocators;
7263e2992f0SChristian Ulmann       // Build a list of allocators to attempt to promote the slots of.
7273e2992f0SChristian Ulmann       region.walk([&](PromotableAllocationOpInterface allocator) {
7283e2992f0SChristian Ulmann         allocators.emplace_back(allocator);
7293e2992f0SChristian Ulmann       });
7303e2992f0SChristian Ulmann 
731eeafc9daSChristian Ulmann       // Attempt promoting as many of the slots as possible.
732eeafc9daSChristian Ulmann       if (succeeded(tryToPromoteMemorySlots(allocators, builder, dataLayout,
733c6efcc92SChristian Ulmann                                             dominance, statistics)))
7343e2992f0SChristian Ulmann         changed = true;
7353e2992f0SChristian Ulmann     }
7363e2992f0SChristian Ulmann     if (!changed)
7373e2992f0SChristian Ulmann       markAllAnalysesPreserved();
738f88f8fd0SThéo Degioanni   }
739f88f8fd0SThéo Degioanni };
740f88f8fd0SThéo Degioanni 
741f88f8fd0SThéo Degioanni } // namespace
742