//===-- SROA.cpp - Scalar Replacement Of Aggregates -------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Transforms/SROA.h" #include "mlir/Analysis/DataLayoutAnalysis.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Analysis/TopologicalSortUtils.h" #include "mlir/Interfaces/MemorySlotInterfaces.h" #include "mlir/Transforms/Passes.h" namespace mlir { #define GEN_PASS_DEF_SROA #include "mlir/Transforms/Passes.h.inc" } // namespace mlir #define DEBUG_TYPE "sroa" using namespace mlir; namespace { /// Information computed by destructurable memory slot analysis used to perform /// actual destructuring of the slot. This struct is only constructed if /// destructuring is possible, and contains the necessary data to perform it. struct MemorySlotDestructuringInfo { /// Set of the indices that are actually used when accessing the subelements. SmallPtrSet usedIndices; /// Blocking uses of a given user of the memory slot that must be eliminated. DenseMap> userToBlockingUses; /// List of potentially indirect accessors of the memory slot that need /// rewiring. SmallVector accessors; }; } // namespace /// Computes information for slot destructuring. This will compute whether this /// slot can be destructured and data to perform the destructuring. Returns /// nothing if the slot cannot be destructured or if there is no useful work to /// be done. static std::optional computeDestructuringInfo(DestructurableMemorySlot &slot, const DataLayout &dataLayout) { assert(isa(slot.elemType)); if (slot.ptr.use_empty()) return {}; MemorySlotDestructuringInfo info; SmallVector usedSafelyWorklist; auto scheduleAsBlockingUse = [&](OpOperand &use) { SmallPtrSetImpl &blockingUses = info.userToBlockingUses[use.getOwner()]; blockingUses.insert(&use); }; // Initialize the analysis with the immediate users of the slot. for (OpOperand &use : slot.ptr.getUses()) { if (auto accessor = dyn_cast(use.getOwner())) { if (accessor.canRewire(slot, info.usedIndices, usedSafelyWorklist, dataLayout)) { info.accessors.push_back(accessor); continue; } } // If it cannot be shown that the operation uses the slot safely, maybe it // can be promoted out of using the slot? scheduleAsBlockingUse(use); } SmallPtrSet visited; while (!usedSafelyWorklist.empty()) { MemorySlot mustBeUsedSafely = usedSafelyWorklist.pop_back_val(); for (OpOperand &subslotUse : mustBeUsedSafely.ptr.getUses()) { if (!visited.insert(&subslotUse).second) continue; Operation *subslotUser = subslotUse.getOwner(); if (auto memOp = dyn_cast(subslotUser)) if (succeeded(memOp.ensureOnlySafeAccesses( mustBeUsedSafely, usedSafelyWorklist, dataLayout))) continue; // If it cannot be shown that the operation uses the slot safely, maybe it // can be promoted out of using the slot? scheduleAsBlockingUse(subslotUse); } } SetVector forwardSlice; mlir::getForwardSlice(slot.ptr, &forwardSlice); for (Operation *user : forwardSlice) { // If the next operation has no blocking uses, everything is fine. auto it = info.userToBlockingUses.find(user); if (it == info.userToBlockingUses.end()) continue; SmallPtrSet &blockingUses = it->second; auto promotable = dyn_cast(user); // An operation that has blocking uses must be promoted. If it is not // promotable, destructuring must fail. if (!promotable) return {}; SmallVector newBlockingUses; // If the operation decides it cannot deal with removing the blocking uses, // destructuring must fail. if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses, dataLayout)) return {}; // Then, register any new blocking uses for coming operations. for (OpOperand *blockingUse : newBlockingUses) { assert(llvm::is_contained(user->getResults(), blockingUse->get())); SmallPtrSetImpl &newUserBlockingUseSet = info.userToBlockingUses[blockingUse->getOwner()]; newUserBlockingUseSet.insert(blockingUse); } } return info; } /// Performs the destructuring of a destructible slot given associated /// destructuring information. The provided slot will be destructured in /// subslots as specified by its allocator. static void destructureSlot( DestructurableMemorySlot &slot, DestructurableAllocationOpInterface allocator, OpBuilder &builder, const DataLayout &dataLayout, MemorySlotDestructuringInfo &info, SmallVectorImpl &newAllocators, const SROAStatistics &statistics) { OpBuilder::InsertionGuard guard(builder); builder.setInsertionPointToStart(slot.ptr.getParentBlock()); DenseMap subslots = allocator.destructure(slot, info.usedIndices, builder, newAllocators); if (statistics.slotsWithMemoryBenefit && slot.subelementTypes.size() != info.usedIndices.size()) (*statistics.slotsWithMemoryBenefit)++; if (statistics.maxSubelementAmount) statistics.maxSubelementAmount->updateMax(slot.subelementTypes.size()); SetVector usersToRewire; for (Operation *user : llvm::make_first_range(info.userToBlockingUses)) usersToRewire.insert(user); for (DestructurableAccessorOpInterface accessor : info.accessors) usersToRewire.insert(accessor); usersToRewire = mlir::topologicalSort(usersToRewire); llvm::SmallVector toErase; for (Operation *toRewire : llvm::reverse(usersToRewire)) { builder.setInsertionPointAfter(toRewire); if (auto accessor = dyn_cast(toRewire)) { if (accessor.rewire(slot, subslots, builder, dataLayout) == DeletionKind::Delete) toErase.push_back(accessor); continue; } auto promotable = cast(toRewire); if (promotable.removeBlockingUses(info.userToBlockingUses[promotable], builder) == DeletionKind::Delete) toErase.push_back(promotable); } for (Operation *toEraseOp : toErase) toEraseOp->erase(); assert(slot.ptr.use_empty() && "after destructuring, the original slot " "pointer should no longer be used"); LLVM_DEBUG(llvm::dbgs() << "[sroa] Destructured memory slot: " << slot.ptr << "\n"); if (statistics.destructuredAmount) (*statistics.destructuredAmount)++; std::optional newAllocator = allocator.handleDestructuringComplete(slot, builder); // Add newly created allocators to the worklist for further processing. if (newAllocator) newAllocators.push_back(*newAllocator); } LogicalResult mlir::tryToDestructureMemorySlots( ArrayRef allocators, OpBuilder &builder, const DataLayout &dataLayout, SROAStatistics statistics) { bool destructuredAny = false; SmallVector workList(allocators); SmallVector newWorkList; newWorkList.reserve(allocators.size()); // Destructuring a slot can allow for further destructuring of other // slots, destructuring is tried until no destructuring succeeds. while (true) { bool changesInThisRound = false; for (DestructurableAllocationOpInterface allocator : workList) { bool destructuredAnySlot = false; for (DestructurableMemorySlot slot : allocator.getDestructurableSlots()) { std::optional info = computeDestructuringInfo(slot, dataLayout); if (!info) continue; destructureSlot(slot, allocator, builder, dataLayout, *info, newWorkList, statistics); destructuredAnySlot = true; // A break is required, since destructuring a slot may invalidate the // remaning slots of an allocator. break; } if (!destructuredAnySlot) newWorkList.push_back(allocator); changesInThisRound |= destructuredAnySlot; } if (!changesInThisRound) break; destructuredAny |= changesInThisRound; // Swap the vector's backing memory and clear the entries in newWorkList // afterwards. This ensures that additional heap allocations can be avoided. workList.swap(newWorkList); newWorkList.clear(); } return success(destructuredAny); } namespace { struct SROA : public impl::SROABase { using impl::SROABase::SROABase; void runOnOperation() override { Operation *scopeOp = getOperation(); SROAStatistics statistics{&destructuredAmount, &slotsWithMemoryBenefit, &maxSubelementAmount}; auto &dataLayoutAnalysis = getAnalysis(); const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(scopeOp); bool changed = false; for (Region ®ion : scopeOp->getRegions()) { if (region.getBlocks().empty()) continue; OpBuilder builder(®ion.front(), region.front().begin()); SmallVector allocators; // Build a list of allocators to attempt to destructure the slots of. region.walk([&](DestructurableAllocationOpInterface allocator) { allocators.emplace_back(allocator); }); // Attempt to destructure as many slots as possible. if (succeeded(tryToDestructureMemorySlots(allocators, builder, dataLayout, statistics))) changed = true; } if (!changed) markAllAnalysesPreserved(); } }; } // namespace