//===-- SROA.cpp - Scalar Replacement Of Aggregates -------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/SROA.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Transforms/Passes.h"

namespace mlir {
// GEN_PASS_DEF_SROA is defined immediately before including the generated pass
// file; presumably this makes it emit `impl::SROABase`, which the SROA pass at
// the bottom of this file derives from.
#define GEN_PASS_DEF_SROA
#include "mlir/Transforms/Passes.h.inc"
} // namespace mlir

// Debug type for this file's LLVM_DEBUG output (e.g. `-debug-only=sroa`).
#define DEBUG_TYPE "sroa"

using namespace mlir;

namespace {

/// Information computed by the destructurable memory slot analysis and used to
/// perform the actual destructuring of the slot. An instance is always built
/// during the analysis, but it is only handed back to the caller when
/// destructuring is possible; it then contains the data needed to perform it.
struct MemorySlotDestructuringInfo {
  /// Set of the subelement indices that are actually used when accessing the
  /// slot. Filled by the accessors' `canRewire` calls and later passed to the
  /// allocator's `destructure`. Its size is compared with the number of
  /// subelement types to count slots with a memory benefit.
  SmallPtrSet<Attribute, 8> usedIndices;
  /// Blocking uses that must be eliminated, grouped by the operation owning
  /// them. Each owner later gets `removeBlockingUses` called with its set.
  DenseMap<Operation *, SmallPtrSet<OpOperand *, 4>> userToBlockingUses;
  /// Direct users of the slot pointer that accepted it in `canRewire` and must
  /// be rewired onto the subslots (they may access the slot indirectly).
  SmallVector<DestructurableAccessorOpInterface> accessors;
};

} // namespace

/// Computes whether `slot` can be destructured and, if so, the data needed to
/// perform the destructuring. Returns nothing (std::nullopt) when the slot
/// pointer has no uses, or when an operation in the pointer's forward slice
/// holds blocking uses but either does not implement PromotableOpInterface or
/// refuses to have those uses removed (`canUsesBeRemoved` returns false).
46 static std::optional<MemorySlotDestructuringInfo> 47 computeDestructuringInfo(DestructurableMemorySlot &slot, 48 const DataLayout &dataLayout) { 49 assert(isa<DestructurableTypeInterface>(slot.elemType)); 50 51 if (slot.ptr.use_empty()) 52 return {}; 53 54 MemorySlotDestructuringInfo info; 55 56 SmallVector<MemorySlot> usedSafelyWorklist; 57 58 auto scheduleAsBlockingUse = [&](OpOperand &use) { 59 SmallPtrSetImpl<OpOperand *> &blockingUses = 60 info.userToBlockingUses[use.getOwner()]; 61 blockingUses.insert(&use); 62 }; 63 64 // Initialize the analysis with the immediate users of the slot. 65 for (OpOperand &use : slot.ptr.getUses()) { 66 if (auto accessor = 67 dyn_cast<DestructurableAccessorOpInterface>(use.getOwner())) { 68 if (accessor.canRewire(slot, info.usedIndices, usedSafelyWorklist, 69 dataLayout)) { 70 info.accessors.push_back(accessor); 71 continue; 72 } 73 } 74 75 // If it cannot be shown that the operation uses the slot safely, maybe it 76 // can be promoted out of using the slot? 77 scheduleAsBlockingUse(use); 78 } 79 80 SmallPtrSet<OpOperand *, 16> visited; 81 while (!usedSafelyWorklist.empty()) { 82 MemorySlot mustBeUsedSafely = usedSafelyWorklist.pop_back_val(); 83 for (OpOperand &subslotUse : mustBeUsedSafely.ptr.getUses()) { 84 if (!visited.insert(&subslotUse).second) 85 continue; 86 Operation *subslotUser = subslotUse.getOwner(); 87 88 if (auto memOp = dyn_cast<SafeMemorySlotAccessOpInterface>(subslotUser)) 89 if (succeeded(memOp.ensureOnlySafeAccesses( 90 mustBeUsedSafely, usedSafelyWorklist, dataLayout))) 91 continue; 92 93 // If it cannot be shown that the operation uses the slot safely, maybe it 94 // can be promoted out of using the slot? 95 scheduleAsBlockingUse(subslotUse); 96 } 97 } 98 99 SetVector<Operation *> forwardSlice; 100 mlir::getForwardSlice(slot.ptr, &forwardSlice); 101 for (Operation *user : forwardSlice) { 102 // If the next operation has no blocking uses, everything is fine. 
103 auto it = info.userToBlockingUses.find(user); 104 if (it == info.userToBlockingUses.end()) 105 continue; 106 107 SmallPtrSet<OpOperand *, 4> &blockingUses = it->second; 108 auto promotable = dyn_cast<PromotableOpInterface>(user); 109 110 // An operation that has blocking uses must be promoted. If it is not 111 // promotable, destructuring must fail. 112 if (!promotable) 113 return {}; 114 115 SmallVector<OpOperand *> newBlockingUses; 116 // If the operation decides it cannot deal with removing the blocking uses, 117 // destructuring must fail. 118 if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses, dataLayout)) 119 return {}; 120 121 // Then, register any new blocking uses for coming operations. 122 for (OpOperand *blockingUse : newBlockingUses) { 123 assert(llvm::is_contained(user->getResults(), blockingUse->get())); 124 125 SmallPtrSetImpl<OpOperand *> &newUserBlockingUseSet = 126 info.userToBlockingUses[blockingUse->getOwner()]; 127 newUserBlockingUseSet.insert(blockingUse); 128 } 129 } 130 131 return info; 132 } 133 134 /// Performs the destructuring of a destructible slot given associated 135 /// destructuring information. The provided slot will be destructured in 136 /// subslots as specified by its allocator. 
137 static void destructureSlot( 138 DestructurableMemorySlot &slot, 139 DestructurableAllocationOpInterface allocator, OpBuilder &builder, 140 const DataLayout &dataLayout, MemorySlotDestructuringInfo &info, 141 SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators, 142 const SROAStatistics &statistics) { 143 OpBuilder::InsertionGuard guard(builder); 144 145 builder.setInsertionPointToStart(slot.ptr.getParentBlock()); 146 DenseMap<Attribute, MemorySlot> subslots = 147 allocator.destructure(slot, info.usedIndices, builder, newAllocators); 148 149 if (statistics.slotsWithMemoryBenefit && 150 slot.subelementTypes.size() != info.usedIndices.size()) 151 (*statistics.slotsWithMemoryBenefit)++; 152 153 if (statistics.maxSubelementAmount) 154 statistics.maxSubelementAmount->updateMax(slot.subelementTypes.size()); 155 156 SetVector<Operation *> usersToRewire; 157 for (Operation *user : llvm::make_first_range(info.userToBlockingUses)) 158 usersToRewire.insert(user); 159 for (DestructurableAccessorOpInterface accessor : info.accessors) 160 usersToRewire.insert(accessor); 161 usersToRewire = mlir::topologicalSort(usersToRewire); 162 163 llvm::SmallVector<Operation *> toErase; 164 for (Operation *toRewire : llvm::reverse(usersToRewire)) { 165 builder.setInsertionPointAfter(toRewire); 166 if (auto accessor = dyn_cast<DestructurableAccessorOpInterface>(toRewire)) { 167 if (accessor.rewire(slot, subslots, builder, dataLayout) == 168 DeletionKind::Delete) 169 toErase.push_back(accessor); 170 continue; 171 } 172 173 auto promotable = cast<PromotableOpInterface>(toRewire); 174 if (promotable.removeBlockingUses(info.userToBlockingUses[promotable], 175 builder) == DeletionKind::Delete) 176 toErase.push_back(promotable); 177 } 178 179 for (Operation *toEraseOp : toErase) 180 toEraseOp->erase(); 181 182 assert(slot.ptr.use_empty() && "after destructuring, the original slot " 183 "pointer should no longer be used"); 184 185 LLVM_DEBUG(llvm::dbgs() << "[sroa] Destructured 
memory slot: " << slot.ptr 186 << "\n"); 187 188 if (statistics.destructuredAmount) 189 (*statistics.destructuredAmount)++; 190 191 std::optional<DestructurableAllocationOpInterface> newAllocator = 192 allocator.handleDestructuringComplete(slot, builder); 193 // Add newly created allocators to the worklist for further processing. 194 if (newAllocator) 195 newAllocators.push_back(*newAllocator); 196 } 197 198 LogicalResult mlir::tryToDestructureMemorySlots( 199 ArrayRef<DestructurableAllocationOpInterface> allocators, 200 OpBuilder &builder, const DataLayout &dataLayout, 201 SROAStatistics statistics) { 202 bool destructuredAny = false; 203 204 SmallVector<DestructurableAllocationOpInterface> workList(allocators); 205 SmallVector<DestructurableAllocationOpInterface> newWorkList; 206 newWorkList.reserve(allocators.size()); 207 // Destructuring a slot can allow for further destructuring of other 208 // slots, destructuring is tried until no destructuring succeeds. 209 while (true) { 210 bool changesInThisRound = false; 211 212 for (DestructurableAllocationOpInterface allocator : workList) { 213 bool destructuredAnySlot = false; 214 for (DestructurableMemorySlot slot : allocator.getDestructurableSlots()) { 215 std::optional<MemorySlotDestructuringInfo> info = 216 computeDestructuringInfo(slot, dataLayout); 217 if (!info) 218 continue; 219 220 destructureSlot(slot, allocator, builder, dataLayout, *info, 221 newWorkList, statistics); 222 destructuredAnySlot = true; 223 224 // A break is required, since destructuring a slot may invalidate the 225 // remaning slots of an allocator. 226 break; 227 } 228 if (!destructuredAnySlot) 229 newWorkList.push_back(allocator); 230 changesInThisRound |= destructuredAnySlot; 231 } 232 233 if (!changesInThisRound) 234 break; 235 destructuredAny |= changesInThisRound; 236 237 // Swap the vector's backing memory and clear the entries in newWorkList 238 // afterwards. This ensures that additional heap allocations can be avoided. 
239 workList.swap(newWorkList); 240 newWorkList.clear(); 241 } 242 243 return success(destructuredAny); 244 } 245 246 namespace { 247 248 struct SROA : public impl::SROABase<SROA> { 249 using impl::SROABase<SROA>::SROABase; 250 251 void runOnOperation() override { 252 Operation *scopeOp = getOperation(); 253 254 SROAStatistics statistics{&destructuredAmount, &slotsWithMemoryBenefit, 255 &maxSubelementAmount}; 256 257 auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>(); 258 const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(scopeOp); 259 bool changed = false; 260 261 for (Region ®ion : scopeOp->getRegions()) { 262 if (region.getBlocks().empty()) 263 continue; 264 265 OpBuilder builder(®ion.front(), region.front().begin()); 266 267 SmallVector<DestructurableAllocationOpInterface> allocators; 268 // Build a list of allocators to attempt to destructure the slots of. 269 region.walk([&](DestructurableAllocationOpInterface allocator) { 270 allocators.emplace_back(allocator); 271 }); 272 273 // Attempt to destructure as many slots as possible. 274 if (succeeded(tryToDestructureMemorySlots(allocators, builder, dataLayout, 275 statistics))) 276 changed = true; 277 } 278 if (!changed) 279 markAllAnalysesPreserved(); 280 } 281 }; 282 283 } // namespace 284