xref: /llvm-project/mlir/lib/Transforms/SROA.cpp (revision ef436f3f60c14935d244d73f11d008f272d123e1)
14ed502efSThéo Degioanni //===-- SROA.cpp - Scalar Replacement Of Aggregates -------------*- C++ -*-===//
24ed502efSThéo Degioanni //
34ed502efSThéo Degioanni // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44ed502efSThéo Degioanni // See https://llvm.org/LICENSE.txt for license information.
54ed502efSThéo Degioanni // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
64ed502efSThéo Degioanni //
74ed502efSThéo Degioanni //===----------------------------------------------------------------------===//
84ed502efSThéo Degioanni 
94ed502efSThéo Degioanni #include "mlir/Transforms/SROA.h"
1098c6bc53SChristian Ulmann #include "mlir/Analysis/DataLayoutAnalysis.h"
114ed502efSThéo Degioanni #include "mlir/Analysis/SliceAnalysis.h"
12b00e0c16SChristian Ulmann #include "mlir/Analysis/TopologicalSortUtils.h"
134ed502efSThéo Degioanni #include "mlir/Interfaces/MemorySlotInterfaces.h"
144ed502efSThéo Degioanni #include "mlir/Transforms/Passes.h"
154ed502efSThéo Degioanni 
164ed502efSThéo Degioanni namespace mlir {
174ed502efSThéo Degioanni #define GEN_PASS_DEF_SROA
184ed502efSThéo Degioanni #include "mlir/Transforms/Passes.h.inc"
194ed502efSThéo Degioanni } // namespace mlir
204ed502efSThéo Degioanni 
214ed502efSThéo Degioanni #define DEBUG_TYPE "sroa"
224ed502efSThéo Degioanni 
234ed502efSThéo Degioanni using namespace mlir;
244ed502efSThéo Degioanni 
254ed502efSThéo Degioanni namespace {
264ed502efSThéo Degioanni 
274ed502efSThéo Degioanni /// Information computed by destructurable memory slot analysis used to perform
284ed502efSThéo Degioanni /// actual destructuring of the slot. This struct is only constructed if
294ed502efSThéo Degioanni /// destructuring is possible, and contains the necessary data to perform it.
304ed502efSThéo Degioanni struct MemorySlotDestructuringInfo {
314ed502efSThéo Degioanni   /// Set of the indices that are actually used when accessing the subelements.
324ed502efSThéo Degioanni   SmallPtrSet<Attribute, 8> usedIndices;
334ed502efSThéo Degioanni   /// Blocking uses of a given user of the memory slot that must be eliminated.
344ed502efSThéo Degioanni   DenseMap<Operation *, SmallPtrSet<OpOperand *, 4>> userToBlockingUses;
354ed502efSThéo Degioanni   /// List of potentially indirect accessors of the memory slot that need
364ed502efSThéo Degioanni   /// rewiring.
374ed502efSThéo Degioanni   SmallVector<DestructurableAccessorOpInterface> accessors;
384ed502efSThéo Degioanni };
394ed502efSThéo Degioanni 
404ed502efSThéo Degioanni } // namespace
414ed502efSThéo Degioanni 
424ed502efSThéo Degioanni /// Computes information for slot destructuring. This will compute whether this
434ed502efSThéo Degioanni /// slot can be destructured and data to perform the destructuring. Returns
444ed502efSThéo Degioanni /// nothing if the slot cannot be destructured or if there is no useful work to
454ed502efSThéo Degioanni /// be done.
464ed502efSThéo Degioanni static std::optional<MemorySlotDestructuringInfo>
4798c6bc53SChristian Ulmann computeDestructuringInfo(DestructurableMemorySlot &slot,
4898c6bc53SChristian Ulmann                          const DataLayout &dataLayout) {
494ed502efSThéo Degioanni   assert(isa<DestructurableTypeInterface>(slot.elemType));
504ed502efSThéo Degioanni 
514ed502efSThéo Degioanni   if (slot.ptr.use_empty())
524ed502efSThéo Degioanni     return {};
534ed502efSThéo Degioanni 
544ed502efSThéo Degioanni   MemorySlotDestructuringInfo info;
554ed502efSThéo Degioanni 
564ed502efSThéo Degioanni   SmallVector<MemorySlot> usedSafelyWorklist;
574ed502efSThéo Degioanni 
584ed502efSThéo Degioanni   auto scheduleAsBlockingUse = [&](OpOperand &use) {
594ed502efSThéo Degioanni     SmallPtrSetImpl<OpOperand *> &blockingUses =
6059a3b415SKazu Hirata         info.userToBlockingUses[use.getOwner()];
614ed502efSThéo Degioanni     blockingUses.insert(&use);
624ed502efSThéo Degioanni   };
634ed502efSThéo Degioanni 
644ed502efSThéo Degioanni   // Initialize the analysis with the immediate users of the slot.
654ed502efSThéo Degioanni   for (OpOperand &use : slot.ptr.getUses()) {
664ed502efSThéo Degioanni     if (auto accessor =
674ed502efSThéo Degioanni             dyn_cast<DestructurableAccessorOpInterface>(use.getOwner())) {
6898c6bc53SChristian Ulmann       if (accessor.canRewire(slot, info.usedIndices, usedSafelyWorklist,
6998c6bc53SChristian Ulmann                              dataLayout)) {
704ed502efSThéo Degioanni         info.accessors.push_back(accessor);
714ed502efSThéo Degioanni         continue;
724ed502efSThéo Degioanni       }
734ed502efSThéo Degioanni     }
744ed502efSThéo Degioanni 
754ed502efSThéo Degioanni     // If it cannot be shown that the operation uses the slot safely, maybe it
764ed502efSThéo Degioanni     // can be promoted out of using the slot?
774ed502efSThéo Degioanni     scheduleAsBlockingUse(use);
784ed502efSThéo Degioanni   }
794ed502efSThéo Degioanni 
804ed502efSThéo Degioanni   SmallPtrSet<OpOperand *, 16> visited;
814ed502efSThéo Degioanni   while (!usedSafelyWorklist.empty()) {
824ed502efSThéo Degioanni     MemorySlot mustBeUsedSafely = usedSafelyWorklist.pop_back_val();
834ed502efSThéo Degioanni     for (OpOperand &subslotUse : mustBeUsedSafely.ptr.getUses()) {
844ed502efSThéo Degioanni       if (!visited.insert(&subslotUse).second)
854ed502efSThéo Degioanni         continue;
864ed502efSThéo Degioanni       Operation *subslotUser = subslotUse.getOwner();
874ed502efSThéo Degioanni 
884ed502efSThéo Degioanni       if (auto memOp = dyn_cast<SafeMemorySlotAccessOpInterface>(subslotUser))
8998c6bc53SChristian Ulmann         if (succeeded(memOp.ensureOnlySafeAccesses(
9098c6bc53SChristian Ulmann                 mustBeUsedSafely, usedSafelyWorklist, dataLayout)))
914ed502efSThéo Degioanni           continue;
924ed502efSThéo Degioanni 
934ed502efSThéo Degioanni       // If it cannot be shown that the operation uses the slot safely, maybe it
944ed502efSThéo Degioanni       // can be promoted out of using the slot?
954ed502efSThéo Degioanni       scheduleAsBlockingUse(subslotUse);
964ed502efSThéo Degioanni     }
974ed502efSThéo Degioanni   }
984ed502efSThéo Degioanni 
994ed502efSThéo Degioanni   SetVector<Operation *> forwardSlice;
1004ed502efSThéo Degioanni   mlir::getForwardSlice(slot.ptr, &forwardSlice);
1014ed502efSThéo Degioanni   for (Operation *user : forwardSlice) {
1024ed502efSThéo Degioanni     // If the next operation has no blocking uses, everything is fine.
103*ef436f3fSKazu Hirata     auto it = info.userToBlockingUses.find(user);
104*ef436f3fSKazu Hirata     if (it == info.userToBlockingUses.end())
1054ed502efSThéo Degioanni       continue;
1064ed502efSThéo Degioanni 
107*ef436f3fSKazu Hirata     SmallPtrSet<OpOperand *, 4> &blockingUses = it->second;
1084ed502efSThéo Degioanni     auto promotable = dyn_cast<PromotableOpInterface>(user);
1094ed502efSThéo Degioanni 
1104ed502efSThéo Degioanni     // An operation that has blocking uses must be promoted. If it is not
1114ed502efSThéo Degioanni     // promotable, destructuring must fail.
1124ed502efSThéo Degioanni     if (!promotable)
1134ed502efSThéo Degioanni       return {};
1144ed502efSThéo Degioanni 
1154ed502efSThéo Degioanni     SmallVector<OpOperand *> newBlockingUses;
1164ed502efSThéo Degioanni     // If the operation decides it cannot deal with removing the blocking uses,
1174ed502efSThéo Degioanni     // destructuring must fail.
11898c6bc53SChristian Ulmann     if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses, dataLayout))
1194ed502efSThéo Degioanni       return {};
1204ed502efSThéo Degioanni 
1214ed502efSThéo Degioanni     // Then, register any new blocking uses for coming operations.
1224ed502efSThéo Degioanni     for (OpOperand *blockingUse : newBlockingUses) {
1234ed502efSThéo Degioanni       assert(llvm::is_contained(user->getResults(), blockingUse->get()));
1244ed502efSThéo Degioanni 
1254ed502efSThéo Degioanni       SmallPtrSetImpl<OpOperand *> &newUserBlockingUseSet =
12659a3b415SKazu Hirata           info.userToBlockingUses[blockingUse->getOwner()];
1274ed502efSThéo Degioanni       newUserBlockingUseSet.insert(blockingUse);
1284ed502efSThéo Degioanni     }
1294ed502efSThéo Degioanni   }
1304ed502efSThéo Degioanni 
1314ed502efSThéo Degioanni   return info;
1324ed502efSThéo Degioanni }
1334ed502efSThéo Degioanni 
1344ed502efSThéo Degioanni /// Performs the destructuring of a destructible slot given associated
1354ed502efSThéo Degioanni /// destructuring information. The provided slot will be destructured in
1364ed502efSThéo Degioanni /// subslots as specified by its allocator.
1370b5b2027SChristian Ulmann static void destructureSlot(
1380b5b2027SChristian Ulmann     DestructurableMemorySlot &slot,
1390b5b2027SChristian Ulmann     DestructurableAllocationOpInterface allocator, OpBuilder &builder,
1400b5b2027SChristian Ulmann     const DataLayout &dataLayout, MemorySlotDestructuringInfo &info,
1410b5b2027SChristian Ulmann     SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators,
1424ed502efSThéo Degioanni     const SROAStatistics &statistics) {
143084e2b53SChristian Ulmann   OpBuilder::InsertionGuard guard(builder);
1444ed502efSThéo Degioanni 
145084e2b53SChristian Ulmann   builder.setInsertionPointToStart(slot.ptr.getParentBlock());
1464ed502efSThéo Degioanni   DenseMap<Attribute, MemorySlot> subslots =
1470b5b2027SChristian Ulmann       allocator.destructure(slot, info.usedIndices, builder, newAllocators);
1484ed502efSThéo Degioanni 
1494ed502efSThéo Degioanni   if (statistics.slotsWithMemoryBenefit &&
15069d3793fSThéo Degioanni       slot.subelementTypes.size() != info.usedIndices.size())
1514ed502efSThéo Degioanni     (*statistics.slotsWithMemoryBenefit)++;
1524ed502efSThéo Degioanni 
1534ed502efSThéo Degioanni   if (statistics.maxSubelementAmount)
15469d3793fSThéo Degioanni     statistics.maxSubelementAmount->updateMax(slot.subelementTypes.size());
1554ed502efSThéo Degioanni 
1564ed502efSThéo Degioanni   SetVector<Operation *> usersToRewire;
1574ed502efSThéo Degioanni   for (Operation *user : llvm::make_first_range(info.userToBlockingUses))
1584ed502efSThéo Degioanni     usersToRewire.insert(user);
1594ed502efSThéo Degioanni   for (DestructurableAccessorOpInterface accessor : info.accessors)
1604ed502efSThéo Degioanni     usersToRewire.insert(accessor);
1614ed502efSThéo Degioanni   usersToRewire = mlir::topologicalSort(usersToRewire);
1624ed502efSThéo Degioanni 
1634ed502efSThéo Degioanni   llvm::SmallVector<Operation *> toErase;
1644ed502efSThéo Degioanni   for (Operation *toRewire : llvm::reverse(usersToRewire)) {
165084e2b53SChristian Ulmann     builder.setInsertionPointAfter(toRewire);
1664ed502efSThéo Degioanni     if (auto accessor = dyn_cast<DestructurableAccessorOpInterface>(toRewire)) {
167084e2b53SChristian Ulmann       if (accessor.rewire(slot, subslots, builder, dataLayout) ==
16898c6bc53SChristian Ulmann           DeletionKind::Delete)
1694ed502efSThéo Degioanni         toErase.push_back(accessor);
1704ed502efSThéo Degioanni       continue;
1714ed502efSThéo Degioanni     }
1724ed502efSThéo Degioanni 
1734ed502efSThéo Degioanni     auto promotable = cast<PromotableOpInterface>(toRewire);
1744ed502efSThéo Degioanni     if (promotable.removeBlockingUses(info.userToBlockingUses[promotable],
175084e2b53SChristian Ulmann                                       builder) == DeletionKind::Delete)
1764ed502efSThéo Degioanni       toErase.push_back(promotable);
1774ed502efSThéo Degioanni   }
1784ed502efSThéo Degioanni 
1794ed502efSThéo Degioanni   for (Operation *toEraseOp : toErase)
180084e2b53SChristian Ulmann     toEraseOp->erase();
1814ed502efSThéo Degioanni 
1824ed502efSThéo Degioanni   assert(slot.ptr.use_empty() && "after destructuring, the original slot "
1834ed502efSThéo Degioanni                                  "pointer should no longer be used");
1844ed502efSThéo Degioanni 
1854ed502efSThéo Degioanni   LLVM_DEBUG(llvm::dbgs() << "[sroa] Destructured memory slot: " << slot.ptr
1864ed502efSThéo Degioanni                           << "\n");
1874ed502efSThéo Degioanni 
1884ed502efSThéo Degioanni   if (statistics.destructuredAmount)
1894ed502efSThéo Degioanni     (*statistics.destructuredAmount)++;
1904ed502efSThéo Degioanni 
1910b5b2027SChristian Ulmann   std::optional<DestructurableAllocationOpInterface> newAllocator =
192084e2b53SChristian Ulmann       allocator.handleDestructuringComplete(slot, builder);
1930b5b2027SChristian Ulmann   // Add newly created allocators to the worklist for further processing.
1940b5b2027SChristian Ulmann   if (newAllocator)
1950b5b2027SChristian Ulmann     newAllocators.push_back(*newAllocator);
1964ed502efSThéo Degioanni }
1974ed502efSThéo Degioanni 
1984ed502efSThéo Degioanni LogicalResult mlir::tryToDestructureMemorySlots(
1994ed502efSThéo Degioanni     ArrayRef<DestructurableAllocationOpInterface> allocators,
200084e2b53SChristian Ulmann     OpBuilder &builder, const DataLayout &dataLayout,
20198c6bc53SChristian Ulmann     SROAStatistics statistics) {
2024ed502efSThéo Degioanni   bool destructuredAny = false;
2034ed502efSThéo Degioanni 
2045262865aSKazu Hirata   SmallVector<DestructurableAllocationOpInterface> workList(allocators);
2050b5b2027SChristian Ulmann   SmallVector<DestructurableAllocationOpInterface> newWorkList;
2060b5b2027SChristian Ulmann   newWorkList.reserve(allocators.size());
2070b5b2027SChristian Ulmann   // Destructuring a slot can allow for further destructuring of other
2080b5b2027SChristian Ulmann   // slots, destructuring is tried until no destructuring succeeds.
2090b5b2027SChristian Ulmann   while (true) {
2100b5b2027SChristian Ulmann     bool changesInThisRound = false;
2110b5b2027SChristian Ulmann 
2120b5b2027SChristian Ulmann     for (DestructurableAllocationOpInterface allocator : workList) {
2130b5b2027SChristian Ulmann       bool destructuredAnySlot = false;
2144ed502efSThéo Degioanni       for (DestructurableMemorySlot slot : allocator.getDestructurableSlots()) {
2154ed502efSThéo Degioanni         std::optional<MemorySlotDestructuringInfo> info =
21698c6bc53SChristian Ulmann             computeDestructuringInfo(slot, dataLayout);
2174ed502efSThéo Degioanni         if (!info)
2184ed502efSThéo Degioanni           continue;
2194ed502efSThéo Degioanni 
2200b5b2027SChristian Ulmann         destructureSlot(slot, allocator, builder, dataLayout, *info,
2210b5b2027SChristian Ulmann                         newWorkList, statistics);
2220b5b2027SChristian Ulmann         destructuredAnySlot = true;
2230b5b2027SChristian Ulmann 
2240b5b2027SChristian Ulmann         // A break is required, since destructuring a slot may invalidate the
2250b5b2027SChristian Ulmann         // remaning slots of an allocator.
2260b5b2027SChristian Ulmann         break;
2274ed502efSThéo Degioanni       }
2280b5b2027SChristian Ulmann       if (!destructuredAnySlot)
2290b5b2027SChristian Ulmann         newWorkList.push_back(allocator);
2300b5b2027SChristian Ulmann       changesInThisRound |= destructuredAnySlot;
2310b5b2027SChristian Ulmann     }
2320b5b2027SChristian Ulmann 
2330b5b2027SChristian Ulmann     if (!changesInThisRound)
2340b5b2027SChristian Ulmann       break;
2350b5b2027SChristian Ulmann     destructuredAny |= changesInThisRound;
2360b5b2027SChristian Ulmann 
2370b5b2027SChristian Ulmann     // Swap the vector's backing memory and clear the entries in newWorkList
2380b5b2027SChristian Ulmann     // afterwards. This ensures that additional heap allocations can be avoided.
2390b5b2027SChristian Ulmann     workList.swap(newWorkList);
2400b5b2027SChristian Ulmann     newWorkList.clear();
2414ed502efSThéo Degioanni   }
2424ed502efSThéo Degioanni 
2434ed502efSThéo Degioanni   return success(destructuredAny);
2444ed502efSThéo Degioanni }
2454ed502efSThéo Degioanni 
2464ed502efSThéo Degioanni namespace {
2474ed502efSThéo Degioanni 
2484ed502efSThéo Degioanni struct SROA : public impl::SROABase<SROA> {
2494ed502efSThéo Degioanni   using impl::SROABase<SROA>::SROABase;
2504ed502efSThéo Degioanni 
2514ed502efSThéo Degioanni   void runOnOperation() override {
2524ed502efSThéo Degioanni     Operation *scopeOp = getOperation();
2534ed502efSThéo Degioanni 
2544ed502efSThéo Degioanni     SROAStatistics statistics{&destructuredAmount, &slotsWithMemoryBenefit,
2554ed502efSThéo Degioanni                               &maxSubelementAmount};
2564ed502efSThéo Degioanni 
25798c6bc53SChristian Ulmann     auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
25898c6bc53SChristian Ulmann     const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(scopeOp);
25963897a59SChristian Ulmann     bool changed = false;
2604ed502efSThéo Degioanni 
26163897a59SChristian Ulmann     for (Region &region : scopeOp->getRegions()) {
26263897a59SChristian Ulmann       if (region.getBlocks().empty())
26363897a59SChristian Ulmann         continue;
26463897a59SChristian Ulmann 
26563897a59SChristian Ulmann       OpBuilder builder(&region.front(), region.front().begin());
26663897a59SChristian Ulmann 
26763897a59SChristian Ulmann       SmallVector<DestructurableAllocationOpInterface> allocators;
26863897a59SChristian Ulmann       // Build a list of allocators to attempt to destructure the slots of.
26963897a59SChristian Ulmann       region.walk([&](DestructurableAllocationOpInterface allocator) {
27063897a59SChristian Ulmann         allocators.emplace_back(allocator);
27163897a59SChristian Ulmann       });
27263897a59SChristian Ulmann 
2730b5b2027SChristian Ulmann       // Attempt to destructure as many slots as possible.
2740b5b2027SChristian Ulmann       if (succeeded(tryToDestructureMemorySlots(allocators, builder, dataLayout,
27598c6bc53SChristian Ulmann                                                 statistics)))
27663897a59SChristian Ulmann         changed = true;
27763897a59SChristian Ulmann     }
27863897a59SChristian Ulmann     if (!changed)
27963897a59SChristian Ulmann       markAllAnalysesPreserved();
2804ed502efSThéo Degioanni   }
2814ed502efSThéo Degioanni };
2824ed502efSThéo Degioanni 
2834ed502efSThéo Degioanni } // namespace
284