//===-- SROA.cpp - Scalar Replacement Of Aggregates -------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/SROA.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Transforms/Passes.h"

namespace mlir {
#define GEN_PASS_DEF_SROA
#include "mlir/Transforms/Passes.h.inc"
} // namespace mlir

#define DEBUG_TYPE "sroa"

using namespace mlir;

namespace {

/// Information computed by destructurable memory slot analysis used to perform
/// actual destructuring of the slot. This struct is only constructed if
/// destructuring is possible, and contains the necessary data to perform it.
struct MemorySlotDestructuringInfo {
  /// Set of the indices that are actually used when accessing the subelements.
  SmallPtrSet<Attribute, 8> usedIndices;
  /// Blocking uses of a given user of the memory slot that must be eliminated.
  DenseMap<Operation *, SmallPtrSet<OpOperand *, 4>> userToBlockingUses;
  /// List of potentially indirect accessors of the memory slot that need
  /// rewiring.
  SmallVector<DestructurableAccessorOpInterface> accessors;
};

} // namespace

/// Computes information for slot destructuring. This determines whether the
/// slot can be destructured and gathers the data required to perform the
/// destructuring. Returns nothing if the slot cannot be destructured or if
/// there is no useful work to be done.
static std::optional<MemorySlotDestructuringInfo>
computeDestructuringInfo(DestructurableMemorySlot &slot,
                         const DataLayout &dataLayout) {
  assert(isa<DestructurableTypeInterface>(slot.elemType));

  if (slot.ptr.use_empty())
    return {};

  MemorySlotDestructuringInfo info;

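  // Worklist of pointers into the slot (e.g. subslot pointers created by
  // accessors) whose uses must all be proven safe for destructuring to
  // proceed.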
  SmallVector<MemorySlot> usedSafelyWorklist;

  auto scheduleAsBlockingUse = [&](OpOperand &use) {
    SmallPtrSetImpl<OpOperand *> &blockingUses =
        info.userToBlockingUses[use.getOwner()];
    blockingUses.insert(&use);
  };

  // Initialize the analysis with the immediate users of the slot.
  for (OpOperand &use : slot.ptr.getUses()) {
    if (auto accessor =
            dyn_cast<DestructurableAccessorOpInterface>(use.getOwner())) {
      if (accessor.canRewire(slot, info.usedIndices, usedSafelyWorklist,
                             dataLayout)) {
        info.accessors.push_back(accessor);
        continue;
      }
    }

    // If the operation cannot be shown to use the slot safely, it may still be
    // possible to promote it out of using the slot; record the use as
    // blocking.
    scheduleAsBlockingUse(use);
  }

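  // Propagate the analysis through subslot pointers: every use of a pointer
  // that must be used safely is checked, and accesses that cannot be proven
  // safe are recorded as blocking uses.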
  SmallPtrSet<OpOperand *, 16> visited;
  while (!usedSafelyWorklist.empty()) {
    MemorySlot mustBeUsedSafely = usedSafelyWorklist.pop_back_val();
    for (OpOperand &subslotUse : mustBeUsedSafely.ptr.getUses()) {
      if (!visited.insert(&subslotUse).second)
        continue;
      Operation *subslotUser = subslotUse.getOwner();

      if (auto memOp = dyn_cast<SafeMemorySlotAccessOpInterface>(subslotUser))
        if (succeeded(memOp.ensureOnlySafeAccesses(
                mustBeUsedSafely, usedSafelyWorklist, dataLayout)))
          continue;

      // If the operation cannot be shown to use the slot safely, it may still
      // be possible to promote it out of using the slot; record the use as
      // blocking.
      scheduleAsBlockingUse(subslotUse);
    }
  }

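  // Walk the forward slice of the slot pointer so that, when an operation with
  // blocking uses is visited, any new blocking uses it introduces on its
  // results are registered before the corresponding users are processed.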
  SetVector<Operation *> forwardSlice;
  mlir::getForwardSlice(slot.ptr, &forwardSlice);
  for (Operation *user : forwardSlice) {
    // If the next operation has no blocking uses, everything is fine.
    auto it = info.userToBlockingUses.find(user);
    if (it == info.userToBlockingUses.end())
      continue;

    SmallPtrSet<OpOperand *, 4> &blockingUses = it->second;
    auto promotable = dyn_cast<PromotableOpInterface>(user);

    // An operation that has blocking uses must be promoted. If it is not
    // promotable, destructuring must fail.
    if (!promotable)
      return {};

    SmallVector<OpOperand *> newBlockingUses;
    // If the operation decides it cannot deal with removing the blocking uses,
    // destructuring must fail.
    if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses, dataLayout))
      return {};

    // Then, register any new blocking uses for the operations yet to be
    // visited.
    for (OpOperand *blockingUse : newBlockingUses) {
      assert(llvm::is_contained(user->getResults(), blockingUse->get()));

      SmallPtrSetImpl<OpOperand *> &newUserBlockingUseSet =
          info.userToBlockingUses[blockingUse->getOwner()];
      newUserBlockingUseSet.insert(blockingUse);
    }
  }

  return info;
}

/// Performs the destructuring of a destructurable slot given the associated
/// destructuring information. The provided slot is destructured into subslots
/// as specified by its allocator.
static void destructureSlot(
    DestructurableMemorySlot &slot,
    DestructurableAllocationOpInterface allocator, OpBuilder &builder,
    const DataLayout &dataLayout, MemorySlotDestructuringInfo &info,
    SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators,
    const SROAStatistics &statistics) {
  OpBuilder::InsertionGuard guard(builder);

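  // Create the subslots at the start of the block containing the slot pointer.
  // Only the indices that are actually used get a dedicated subslot.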
  builder.setInsertionPointToStart(slot.ptr.getParentBlock());
  DenseMap<Attribute, MemorySlot> subslots =
      allocator.destructure(slot, info.usedIndices, builder, newAllocators);

  if (statistics.slotsWithMemoryBenefit &&
      slot.subelementTypes.size() != info.usedIndices.size())
    (*statistics.slotsWithMemoryBenefit)++;

  if (statistics.maxSubelementAmount)
    statistics.maxSubelementAmount->updateMax(slot.subelementTypes.size());

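  // Rewire accessors and remove blocking uses in reverse topological order, so
  // that every operation is rewritten only after the operations using its
  // results have already been handled.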
  SetVector<Operation *> usersToRewire;
  for (Operation *user : llvm::make_first_range(info.userToBlockingUses))
    usersToRewire.insert(user);
  for (DestructurableAccessorOpInterface accessor : info.accessors)
    usersToRewire.insert(accessor);
  usersToRewire = mlir::topologicalSort(usersToRewire);

  llvm::SmallVector<Operation *> toErase;
  for (Operation *toRewire : llvm::reverse(usersToRewire)) {
    builder.setInsertionPointAfter(toRewire);
    if (auto accessor = dyn_cast<DestructurableAccessorOpInterface>(toRewire)) {
      if (accessor.rewire(slot, subslots, builder, dataLayout) ==
          DeletionKind::Delete)
        toErase.push_back(accessor);
      continue;
    }

    auto promotable = cast<PromotableOpInterface>(toRewire);
    if (promotable.removeBlockingUses(info.userToBlockingUses[promotable],
                                      builder) == DeletionKind::Delete)
      toErase.push_back(promotable);
  }

  for (Operation *toEraseOp : toErase)
    toEraseOp->erase();

  assert(slot.ptr.use_empty() && "after destructuring, the original slot "
                                 "pointer should no longer be used");

  LLVM_DEBUG(llvm::dbgs() << "[sroa] Destructured memory slot: " << slot.ptr
                          << "\n");

  if (statistics.destructuredAmount)
    (*statistics.destructuredAmount)++;

  std::optional<DestructurableAllocationOpInterface> newAllocator =
      allocator.handleDestructuringComplete(slot, builder);
  // Add newly created allocators to the worklist for further processing.
  if (newAllocator)
    newAllocators.push_back(*newAllocator);
}

LogicalResult mlir::tryToDestructureMemorySlots(
    ArrayRef<DestructurableAllocationOpInterface> allocators,
    OpBuilder &builder, const DataLayout &dataLayout,
    SROAStatistics statistics) {
  bool destructuredAny = false;

  SmallVector<DestructurableAllocationOpInterface> workList(allocators);
  SmallVector<DestructurableAllocationOpInterface> newWorkList;
  newWorkList.reserve(allocators.size());
  // Destructuring a slot can enable the destructuring of other slots, so
  // destructuring is retried until no further destructuring succeeds.
  while (true) {
    bool changesInThisRound = false;

    for (DestructurableAllocationOpInterface allocator : workList) {
      bool destructuredAnySlot = false;
      for (DestructurableMemorySlot slot : allocator.getDestructurableSlots()) {
        std::optional<MemorySlotDestructuringInfo> info =
            computeDestructuringInfo(slot, dataLayout);
        if (!info)
          continue;

        destructureSlot(slot, allocator, builder, dataLayout, *info,
                        newWorkList, statistics);
        destructuredAnySlot = true;

        // A break is required, since destructuring a slot may invalidate the
        // remaining slots of an allocator.
        break;
      }
      if (!destructuredAnySlot)
        newWorkList.push_back(allocator);
      changesInThisRound |= destructuredAnySlot;
    }

    if (!changesInThisRound)
      break;
    destructuredAny |= changesInThisRound;

    // Swap the two vectors' backing memory and clear newWorkList afterwards.
    // This avoids additional heap allocations in the next round.
    workList.swap(newWorkList);
    newWorkList.clear();
  }

  return success(destructuredAny);
}

namespace {

struct SROA : public impl::SROABase<SROA> {
  using impl::SROABase<SROA>::SROABase;

  void runOnOperation() override {
    Operation *scopeOp = getOperation();

    SROAStatistics statistics{&destructuredAmount, &slotsWithMemoryBenefit,
                              &maxSubelementAmount};

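    // Query the data layout applying to the scope operation once and reuse it
    // for all destructuring queries.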
    auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
    const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(scopeOp);
    bool changed = false;

    for (Region &region : scopeOp->getRegions()) {
      if (region.getBlocks().empty())
        continue;

      OpBuilder builder(&region.front(), region.front().begin());

      SmallVector<DestructurableAllocationOpInterface> allocators;
      // Build a list of allocators whose slots we will attempt to destructure.
      region.walk([&](DestructurableAllocationOpInterface allocator) {
        allocators.emplace_back(allocator);
      });

      // Attempt to destructure as many slots as possible.
      if (succeeded(tryToDestructureMemorySlots(allocators, builder, dataLayout,
                                                statistics)))
        changed = true;
    }
    if (!changed)
      markAllAnalysesPreserved();
  }
};

} // namespace