14ed502efSThéo Degioanni //===-- SROA.cpp - Scalar Replacement Of Aggregates -------------*- C++ -*-===// 24ed502efSThéo Degioanni // 34ed502efSThéo Degioanni // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 44ed502efSThéo Degioanni // See https://llvm.org/LICENSE.txt for license information. 54ed502efSThéo Degioanni // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 64ed502efSThéo Degioanni // 74ed502efSThéo Degioanni //===----------------------------------------------------------------------===// 84ed502efSThéo Degioanni 94ed502efSThéo Degioanni #include "mlir/Transforms/SROA.h" 1098c6bc53SChristian Ulmann #include "mlir/Analysis/DataLayoutAnalysis.h" 114ed502efSThéo Degioanni #include "mlir/Analysis/SliceAnalysis.h" 12b00e0c16SChristian Ulmann #include "mlir/Analysis/TopologicalSortUtils.h" 134ed502efSThéo Degioanni #include "mlir/Interfaces/MemorySlotInterfaces.h" 144ed502efSThéo Degioanni #include "mlir/Transforms/Passes.h" 154ed502efSThéo Degioanni 164ed502efSThéo Degioanni namespace mlir { 174ed502efSThéo Degioanni #define GEN_PASS_DEF_SROA 184ed502efSThéo Degioanni #include "mlir/Transforms/Passes.h.inc" 194ed502efSThéo Degioanni } // namespace mlir 204ed502efSThéo Degioanni 214ed502efSThéo Degioanni #define DEBUG_TYPE "sroa" 224ed502efSThéo Degioanni 234ed502efSThéo Degioanni using namespace mlir; 244ed502efSThéo Degioanni 254ed502efSThéo Degioanni namespace { 264ed502efSThéo Degioanni 274ed502efSThéo Degioanni /// Information computed by destructurable memory slot analysis used to perform 284ed502efSThéo Degioanni /// actual destructuring of the slot. This struct is only constructed if 294ed502efSThéo Degioanni /// destructuring is possible, and contains the necessary data to perform it. 304ed502efSThéo Degioanni struct MemorySlotDestructuringInfo { 314ed502efSThéo Degioanni /// Set of the indices that are actually used when accessing the subelements. 324ed502efSThéo Degioanni SmallPtrSet<Attribute, 8> usedIndices; 334ed502efSThéo Degioanni /// Blocking uses of a given user of the memory slot that must be eliminated. 344ed502efSThéo Degioanni DenseMap<Operation *, SmallPtrSet<OpOperand *, 4>> userToBlockingUses; 354ed502efSThéo Degioanni /// List of potentially indirect accessors of the memory slot that need 364ed502efSThéo Degioanni /// rewiring. 374ed502efSThéo Degioanni SmallVector<DestructurableAccessorOpInterface> accessors; 384ed502efSThéo Degioanni }; 394ed502efSThéo Degioanni 404ed502efSThéo Degioanni } // namespace 414ed502efSThéo Degioanni 424ed502efSThéo Degioanni /// Computes information for slot destructuring. This will compute whether this 434ed502efSThéo Degioanni /// slot can be destructured and data to perform the destructuring. Returns 444ed502efSThéo Degioanni /// nothing if the slot cannot be destructured or if there is no useful work to 454ed502efSThéo Degioanni /// be done. 464ed502efSThéo Degioanni static std::optional<MemorySlotDestructuringInfo> 4798c6bc53SChristian Ulmann computeDestructuringInfo(DestructurableMemorySlot &slot, 4898c6bc53SChristian Ulmann const DataLayout &dataLayout) { 494ed502efSThéo Degioanni assert(isa<DestructurableTypeInterface>(slot.elemType)); 504ed502efSThéo Degioanni 514ed502efSThéo Degioanni if (slot.ptr.use_empty()) 524ed502efSThéo Degioanni return {}; 534ed502efSThéo Degioanni 544ed502efSThéo Degioanni MemorySlotDestructuringInfo info; 554ed502efSThéo Degioanni 564ed502efSThéo Degioanni SmallVector<MemorySlot> usedSafelyWorklist; 574ed502efSThéo Degioanni 584ed502efSThéo Degioanni auto scheduleAsBlockingUse = [&](OpOperand &use) { 594ed502efSThéo Degioanni SmallPtrSetImpl<OpOperand *> &blockingUses = 6059a3b415SKazu Hirata info.userToBlockingUses[use.getOwner()]; 614ed502efSThéo Degioanni blockingUses.insert(&use); 624ed502efSThéo Degioanni }; 634ed502efSThéo Degioanni 644ed502efSThéo Degioanni // Initialize the analysis with the immediate users of the slot. 654ed502efSThéo Degioanni for (OpOperand &use : slot.ptr.getUses()) { 664ed502efSThéo Degioanni if (auto accessor = 674ed502efSThéo Degioanni dyn_cast<DestructurableAccessorOpInterface>(use.getOwner())) { 6898c6bc53SChristian Ulmann if (accessor.canRewire(slot, info.usedIndices, usedSafelyWorklist, 6998c6bc53SChristian Ulmann dataLayout)) { 704ed502efSThéo Degioanni info.accessors.push_back(accessor); 714ed502efSThéo Degioanni continue; 724ed502efSThéo Degioanni } 734ed502efSThéo Degioanni } 744ed502efSThéo Degioanni 754ed502efSThéo Degioanni // If it cannot be shown that the operation uses the slot safely, maybe it 764ed502efSThéo Degioanni // can be promoted out of using the slot? 774ed502efSThéo Degioanni scheduleAsBlockingUse(use); 784ed502efSThéo Degioanni } 794ed502efSThéo Degioanni 804ed502efSThéo Degioanni SmallPtrSet<OpOperand *, 16> visited; 814ed502efSThéo Degioanni while (!usedSafelyWorklist.empty()) { 824ed502efSThéo Degioanni MemorySlot mustBeUsedSafely = usedSafelyWorklist.pop_back_val(); 834ed502efSThéo Degioanni for (OpOperand &subslotUse : mustBeUsedSafely.ptr.getUses()) { 844ed502efSThéo Degioanni if (!visited.insert(&subslotUse).second) 854ed502efSThéo Degioanni continue; 864ed502efSThéo Degioanni Operation *subslotUser = subslotUse.getOwner(); 874ed502efSThéo Degioanni 884ed502efSThéo Degioanni if (auto memOp = dyn_cast<SafeMemorySlotAccessOpInterface>(subslotUser)) 8998c6bc53SChristian Ulmann if (succeeded(memOp.ensureOnlySafeAccesses( 9098c6bc53SChristian Ulmann mustBeUsedSafely, usedSafelyWorklist, dataLayout))) 914ed502efSThéo Degioanni continue; 924ed502efSThéo Degioanni 934ed502efSThéo Degioanni // If it cannot be shown that the operation uses the slot safely, maybe it 944ed502efSThéo Degioanni // can be promoted out of using the slot? 954ed502efSThéo Degioanni scheduleAsBlockingUse(subslotUse); 964ed502efSThéo Degioanni } 974ed502efSThéo Degioanni } 984ed502efSThéo Degioanni 994ed502efSThéo Degioanni SetVector<Operation *> forwardSlice; 1004ed502efSThéo Degioanni mlir::getForwardSlice(slot.ptr, &forwardSlice); 1014ed502efSThéo Degioanni for (Operation *user : forwardSlice) { 1024ed502efSThéo Degioanni // If the next operation has no blocking uses, everything is fine. 103*ef436f3fSKazu Hirata auto it = info.userToBlockingUses.find(user); 104*ef436f3fSKazu Hirata if (it == info.userToBlockingUses.end()) 1054ed502efSThéo Degioanni continue; 1064ed502efSThéo Degioanni 107*ef436f3fSKazu Hirata SmallPtrSet<OpOperand *, 4> &blockingUses = it->second; 1084ed502efSThéo Degioanni auto promotable = dyn_cast<PromotableOpInterface>(user); 1094ed502efSThéo Degioanni 1104ed502efSThéo Degioanni // An operation that has blocking uses must be promoted. If it is not 1114ed502efSThéo Degioanni // promotable, destructuring must fail. 1124ed502efSThéo Degioanni if (!promotable) 1134ed502efSThéo Degioanni return {}; 1144ed502efSThéo Degioanni 1154ed502efSThéo Degioanni SmallVector<OpOperand *> newBlockingUses; 1164ed502efSThéo Degioanni // If the operation decides it cannot deal with removing the blocking uses, 1174ed502efSThéo Degioanni // destructuring must fail. 11898c6bc53SChristian Ulmann if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses, dataLayout)) 1194ed502efSThéo Degioanni return {}; 1204ed502efSThéo Degioanni 1214ed502efSThéo Degioanni // Then, register any new blocking uses for coming operations. 1224ed502efSThéo Degioanni for (OpOperand *blockingUse : newBlockingUses) { 1234ed502efSThéo Degioanni assert(llvm::is_contained(user->getResults(), blockingUse->get())); 1244ed502efSThéo Degioanni 1254ed502efSThéo Degioanni SmallPtrSetImpl<OpOperand *> &newUserBlockingUseSet = 12659a3b415SKazu Hirata info.userToBlockingUses[blockingUse->getOwner()]; 1274ed502efSThéo Degioanni newUserBlockingUseSet.insert(blockingUse); 1284ed502efSThéo Degioanni } 1294ed502efSThéo Degioanni } 1304ed502efSThéo Degioanni 1314ed502efSThéo Degioanni return info; 1324ed502efSThéo Degioanni } 1334ed502efSThéo Degioanni 1344ed502efSThéo Degioanni /// Performs the destructuring of a destructible slot given associated 1354ed502efSThéo Degioanni /// destructuring information. The provided slot will be destructured in 1364ed502efSThéo Degioanni /// subslots as specified by its allocator. 1370b5b2027SChristian Ulmann static void destructureSlot( 1380b5b2027SChristian Ulmann DestructurableMemorySlot &slot, 1390b5b2027SChristian Ulmann DestructurableAllocationOpInterface allocator, OpBuilder &builder, 1400b5b2027SChristian Ulmann const DataLayout &dataLayout, MemorySlotDestructuringInfo &info, 1410b5b2027SChristian Ulmann SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators, 1424ed502efSThéo Degioanni const SROAStatistics &statistics) { 143084e2b53SChristian Ulmann OpBuilder::InsertionGuard guard(builder); 1444ed502efSThéo Degioanni 145084e2b53SChristian Ulmann builder.setInsertionPointToStart(slot.ptr.getParentBlock()); 1464ed502efSThéo Degioanni DenseMap<Attribute, MemorySlot> subslots = 1470b5b2027SChristian Ulmann allocator.destructure(slot, info.usedIndices, builder, newAllocators); 1484ed502efSThéo Degioanni 1494ed502efSThéo Degioanni if (statistics.slotsWithMemoryBenefit && 15069d3793fSThéo Degioanni slot.subelementTypes.size() != info.usedIndices.size()) 1514ed502efSThéo Degioanni (*statistics.slotsWithMemoryBenefit)++; 1524ed502efSThéo Degioanni 1534ed502efSThéo Degioanni if (statistics.maxSubelementAmount) 15469d3793fSThéo Degioanni statistics.maxSubelementAmount->updateMax(slot.subelementTypes.size()); 1554ed502efSThéo Degioanni 1564ed502efSThéo Degioanni SetVector<Operation *> usersToRewire; 1574ed502efSThéo Degioanni for (Operation *user : llvm::make_first_range(info.userToBlockingUses)) 1584ed502efSThéo Degioanni usersToRewire.insert(user); 1594ed502efSThéo Degioanni for (DestructurableAccessorOpInterface accessor : info.accessors) 1604ed502efSThéo Degioanni usersToRewire.insert(accessor); 1614ed502efSThéo Degioanni usersToRewire = mlir::topologicalSort(usersToRewire); 1624ed502efSThéo Degioanni 1634ed502efSThéo Degioanni llvm::SmallVector<Operation *> toErase; 1644ed502efSThéo Degioanni for (Operation *toRewire : llvm::reverse(usersToRewire)) { 165084e2b53SChristian Ulmann builder.setInsertionPointAfter(toRewire); 1664ed502efSThéo Degioanni if (auto accessor = dyn_cast<DestructurableAccessorOpInterface>(toRewire)) { 167084e2b53SChristian Ulmann if (accessor.rewire(slot, subslots, builder, dataLayout) == 16898c6bc53SChristian Ulmann DeletionKind::Delete) 1694ed502efSThéo Degioanni toErase.push_back(accessor); 1704ed502efSThéo Degioanni continue; 1714ed502efSThéo Degioanni } 1724ed502efSThéo Degioanni 1734ed502efSThéo Degioanni auto promotable = cast<PromotableOpInterface>(toRewire); 1744ed502efSThéo Degioanni if (promotable.removeBlockingUses(info.userToBlockingUses[promotable], 175084e2b53SChristian Ulmann builder) == DeletionKind::Delete) 1764ed502efSThéo Degioanni toErase.push_back(promotable); 1774ed502efSThéo Degioanni } 1784ed502efSThéo Degioanni 1794ed502efSThéo Degioanni for (Operation *toEraseOp : toErase) 180084e2b53SChristian Ulmann toEraseOp->erase(); 1814ed502efSThéo Degioanni 1824ed502efSThéo Degioanni assert(slot.ptr.use_empty() && "after destructuring, the original slot " 1834ed502efSThéo Degioanni "pointer should no longer be used"); 1844ed502efSThéo Degioanni 1854ed502efSThéo Degioanni LLVM_DEBUG(llvm::dbgs() << "[sroa] Destructured memory slot: " << slot.ptr 1864ed502efSThéo Degioanni << "\n"); 1874ed502efSThéo Degioanni 1884ed502efSThéo Degioanni if (statistics.destructuredAmount) 1894ed502efSThéo Degioanni (*statistics.destructuredAmount)++; 1904ed502efSThéo Degioanni 1910b5b2027SChristian Ulmann std::optional<DestructurableAllocationOpInterface> newAllocator = 192084e2b53SChristian Ulmann allocator.handleDestructuringComplete(slot, builder); 1930b5b2027SChristian Ulmann // Add newly created allocators to the worklist for further processing. 1940b5b2027SChristian Ulmann if (newAllocator) 1950b5b2027SChristian Ulmann newAllocators.push_back(*newAllocator); 1964ed502efSThéo Degioanni } 1974ed502efSThéo Degioanni 1984ed502efSThéo Degioanni LogicalResult mlir::tryToDestructureMemorySlots( 1994ed502efSThéo Degioanni ArrayRef<DestructurableAllocationOpInterface> allocators, 200084e2b53SChristian Ulmann OpBuilder &builder, const DataLayout &dataLayout, 20198c6bc53SChristian Ulmann SROAStatistics statistics) { 2024ed502efSThéo Degioanni bool destructuredAny = false; 2034ed502efSThéo Degioanni 2045262865aSKazu Hirata SmallVector<DestructurableAllocationOpInterface> workList(allocators); 2050b5b2027SChristian Ulmann SmallVector<DestructurableAllocationOpInterface> newWorkList; 2060b5b2027SChristian Ulmann newWorkList.reserve(allocators.size()); 2070b5b2027SChristian Ulmann // Destructuring a slot can allow for further destructuring of other 2080b5b2027SChristian Ulmann // slots, destructuring is tried until no destructuring succeeds. 2090b5b2027SChristian Ulmann while (true) { 2100b5b2027SChristian Ulmann bool changesInThisRound = false; 2110b5b2027SChristian Ulmann 2120b5b2027SChristian Ulmann for (DestructurableAllocationOpInterface allocator : workList) { 2130b5b2027SChristian Ulmann bool destructuredAnySlot = false; 2144ed502efSThéo Degioanni for (DestructurableMemorySlot slot : allocator.getDestructurableSlots()) { 2154ed502efSThéo Degioanni std::optional<MemorySlotDestructuringInfo> info = 21698c6bc53SChristian Ulmann computeDestructuringInfo(slot, dataLayout); 2174ed502efSThéo Degioanni if (!info) 2184ed502efSThéo Degioanni continue; 2194ed502efSThéo Degioanni 2200b5b2027SChristian Ulmann destructureSlot(slot, allocator, builder, dataLayout, *info, 2210b5b2027SChristian Ulmann newWorkList, statistics); 2220b5b2027SChristian Ulmann destructuredAnySlot = true; 2230b5b2027SChristian Ulmann 2240b5b2027SChristian Ulmann // A break is required, since destructuring a slot may invalidate the 2250b5b2027SChristian Ulmann // remaning slots of an allocator. 2260b5b2027SChristian Ulmann break; 2274ed502efSThéo Degioanni } 2280b5b2027SChristian Ulmann if (!destructuredAnySlot) 2290b5b2027SChristian Ulmann newWorkList.push_back(allocator); 2300b5b2027SChristian Ulmann changesInThisRound |= destructuredAnySlot; 2310b5b2027SChristian Ulmann } 2320b5b2027SChristian Ulmann 2330b5b2027SChristian Ulmann if (!changesInThisRound) 2340b5b2027SChristian Ulmann break; 2350b5b2027SChristian Ulmann destructuredAny |= changesInThisRound; 2360b5b2027SChristian Ulmann 2370b5b2027SChristian Ulmann // Swap the vector's backing memory and clear the entries in newWorkList 2380b5b2027SChristian Ulmann // afterwards. This ensures that additional heap allocations can be avoided. 2390b5b2027SChristian Ulmann workList.swap(newWorkList); 2400b5b2027SChristian Ulmann newWorkList.clear(); 2414ed502efSThéo Degioanni } 2424ed502efSThéo Degioanni 2434ed502efSThéo Degioanni return success(destructuredAny); 2444ed502efSThéo Degioanni } 2454ed502efSThéo Degioanni 2464ed502efSThéo Degioanni namespace { 2474ed502efSThéo Degioanni 2484ed502efSThéo Degioanni struct SROA : public impl::SROABase<SROA> { 2494ed502efSThéo Degioanni using impl::SROABase<SROA>::SROABase; 2504ed502efSThéo Degioanni 2514ed502efSThéo Degioanni void runOnOperation() override { 2524ed502efSThéo Degioanni Operation *scopeOp = getOperation(); 2534ed502efSThéo Degioanni 2544ed502efSThéo Degioanni SROAStatistics statistics{&destructuredAmount, &slotsWithMemoryBenefit, 2554ed502efSThéo Degioanni &maxSubelementAmount}; 2564ed502efSThéo Degioanni 25798c6bc53SChristian Ulmann auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>(); 25898c6bc53SChristian Ulmann const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(scopeOp); 25963897a59SChristian Ulmann bool changed = false; 2604ed502efSThéo Degioanni 26163897a59SChristian Ulmann for (Region ®ion : scopeOp->getRegions()) { 26263897a59SChristian Ulmann if (region.getBlocks().empty()) 26363897a59SChristian Ulmann continue; 26463897a59SChristian Ulmann 26563897a59SChristian Ulmann OpBuilder builder(®ion.front(), region.front().begin()); 26663897a59SChristian Ulmann 26763897a59SChristian Ulmann SmallVector<DestructurableAllocationOpInterface> allocators; 26863897a59SChristian Ulmann // Build a list of allocators to attempt to destructure the slots of. 26963897a59SChristian Ulmann region.walk([&](DestructurableAllocationOpInterface allocator) { 27063897a59SChristian Ulmann allocators.emplace_back(allocator); 27163897a59SChristian Ulmann }); 27263897a59SChristian Ulmann 2730b5b2027SChristian Ulmann // Attempt to destructure as many slots as possible. 2740b5b2027SChristian Ulmann if (succeeded(tryToDestructureMemorySlots(allocators, builder, dataLayout, 27598c6bc53SChristian Ulmann statistics))) 27663897a59SChristian Ulmann changed = true; 27763897a59SChristian Ulmann } 27863897a59SChristian Ulmann if (!changed) 27963897a59SChristian Ulmann markAllAnalysesPreserved(); 2804ed502efSThéo Degioanni } 2814ed502efSThéo Degioanni }; 2824ed502efSThéo Degioanni 2834ed502efSThéo Degioanni } // namespace 284