xref: /llvm-project/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp (revision d35098bfa8e1e213f85a6b5035a5a7102f5da315)
1 //===- LLVMMemorySlot.cpp - MemorySlot interfaces ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements MemorySlot-related interfaces for LLVM dialect
10 // operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
16 #include "mlir/IR/Matchers.h"
17 #include "mlir/IR/PatternMatch.h"
18 #include "mlir/Interfaces/DataLayoutInterfaces.h"
19 #include "mlir/Interfaces/MemorySlotInterfaces.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 
23 #define DEBUG_TYPE "sroa"
24 
25 using namespace mlir;
26 
27 //===----------------------------------------------------------------------===//
28 // Interfaces for AllocaOp
29 //===----------------------------------------------------------------------===//
30 
31 llvm::SmallVector<MemorySlot> LLVM::AllocaOp::getPromotableSlots() {
32   if (!getOperation()->getBlock()->isEntryBlock())
33     return {};
34 
35   return {MemorySlot{getResult(), getElemType()}};
36 }
37 
38 Value LLVM::AllocaOp::getDefaultValue(const MemorySlot &slot,
39                                       OpBuilder &builder) {
40   return builder.create<LLVM::UndefOp>(getLoc(), slot.elemType);
41 }
42 
43 void LLVM::AllocaOp::handleBlockArgument(const MemorySlot &slot,
44                                          BlockArgument argument,
45                                          OpBuilder &builder) {
46   for (Operation *user : getOperation()->getUsers())
47     if (auto declareOp = llvm::dyn_cast<LLVM::DbgDeclareOp>(user))
48       builder.create<LLVM::DbgValueOp>(declareOp.getLoc(), argument,
49                                        declareOp.getVarInfo(),
50                                        declareOp.getLocationExpr());
51 }
52 
53 std::optional<PromotableAllocationOpInterface>
54 LLVM::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
55                                         Value defaultValue,
56                                         OpBuilder &builder) {
57   if (defaultValue && defaultValue.use_empty())
58     defaultValue.getDefiningOp()->erase();
59   this->erase();
60   return std::nullopt;
61 }
62 
63 SmallVector<DestructurableMemorySlot> LLVM::AllocaOp::getDestructurableSlots() {
64   if (!mlir::matchPattern(getArraySize(), m_One()))
65     return {};
66 
67   auto destructurable = dyn_cast<DestructurableTypeInterface>(getElemType());
68   if (!destructurable)
69     return {};
70 
71   std::optional<DenseMap<Attribute, Type>> destructuredType =
72       destructurable.getSubelementIndexMap();
73   if (!destructuredType)
74     return {};
75 
76   return {DestructurableMemorySlot{{getResult(), getElemType()},
77                                    *destructuredType}};
78 }
79 
80 DenseMap<Attribute, MemorySlot> LLVM::AllocaOp::destructure(
81     const DestructurableMemorySlot &slot,
82     const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder,
83     SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators) {
84   assert(slot.ptr == getResult());
85   builder.setInsertionPointAfter(*this);
86 
87   auto destructurableType = cast<DestructurableTypeInterface>(getElemType());
88   DenseMap<Attribute, MemorySlot> slotMap;
89   for (Attribute index : usedIndices) {
90     Type elemType = destructurableType.getTypeAtIndex(index);
91     assert(elemType && "used index must exist");
92     auto subAlloca = builder.create<LLVM::AllocaOp>(
93         getLoc(), LLVM::LLVMPointerType::get(getContext()), elemType,
94         getArraySize());
95     newAllocators.push_back(subAlloca);
96     slotMap.try_emplace<MemorySlot>(index, {subAlloca.getResult(), elemType});
97   }
98 
99   return slotMap;
100 }
101 
102 std::optional<DestructurableAllocationOpInterface>
103 LLVM::AllocaOp::handleDestructuringComplete(
104     const DestructurableMemorySlot &slot, OpBuilder &builder) {
105   assert(slot.ptr == getResult());
106   this->erase();
107   return std::nullopt;
108 }
109 
110 //===----------------------------------------------------------------------===//
111 // Interfaces for LoadOp/StoreOp
112 //===----------------------------------------------------------------------===//
113 
114 bool LLVM::LoadOp::loadsFrom(const MemorySlot &slot) {
115   return getAddr() == slot.ptr;
116 }
117 
118 bool LLVM::LoadOp::storesTo(const MemorySlot &slot) { return false; }
119 
120 Value LLVM::LoadOp::getStored(const MemorySlot &slot, OpBuilder &builder,
121                               Value reachingDef, const DataLayout &dataLayout) {
122   llvm_unreachable("getStored should not be called on LoadOp");
123 }
124 
125 bool LLVM::StoreOp::loadsFrom(const MemorySlot &slot) { return false; }
126 
127 bool LLVM::StoreOp::storesTo(const MemorySlot &slot) {
128   return getAddr() == slot.ptr;
129 }
130 
131 /// Checks if `type` can be used in any kind of conversion sequences.
132 static bool isSupportedTypeForConversion(Type type) {
133   // Aggregate types are not bitcastable.
134   if (isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(type))
135     return false;
136 
137   // LLVM vector types are only used for either pointers or target specific
138   // types. These types cannot be casted in the general case, thus the memory
139   // optimizations do not support them.
140   if (isa<LLVM::LLVMFixedVectorType, LLVM::LLVMScalableVectorType>(type))
141     return false;
142 
143   // Scalable types are not supported.
144   if (auto vectorType = dyn_cast<VectorType>(type))
145     return !vectorType.isScalable();
146   return true;
147 }
148 
149 /// Checks that `rhs` can be converted to `lhs` by a sequence of casts and
150 /// truncations. Checks for narrowing or widening conversion compatibility
151 /// depending on `narrowingConversion`.
152 static bool areConversionCompatible(const DataLayout &layout, Type targetType,
153                                     Type srcType, bool narrowingConversion) {
154   if (targetType == srcType)
155     return true;
156 
157   if (!isSupportedTypeForConversion(targetType) ||
158       !isSupportedTypeForConversion(srcType))
159     return false;
160 
161   uint64_t targetSize = layout.getTypeSize(targetType);
162   uint64_t srcSize = layout.getTypeSize(srcType);
163 
164   // Pointer casts will only be sane when the bitsize of both pointer types is
165   // the same.
166   if (isa<LLVM::LLVMPointerType>(targetType) &&
167       isa<LLVM::LLVMPointerType>(srcType))
168     return targetSize == srcSize;
169 
170   if (narrowingConversion)
171     return targetSize <= srcSize;
172   return targetSize >= srcSize;
173 }
174 
175 /// Checks if `dataLayout` describes a little endian layout.
176 static bool isBigEndian(const DataLayout &dataLayout) {
177   auto endiannessStr = dyn_cast_or_null<StringAttr>(dataLayout.getEndianness());
178   return endiannessStr && endiannessStr == "big";
179 }
180 
181 /// Converts a value to an integer type of the same size.
182 /// Assumes that the type can be converted.
183 static Value castToSameSizedInt(OpBuilder &builder, Location loc, Value val,
184                                 const DataLayout &dataLayout) {
185   Type type = val.getType();
186   assert(isSupportedTypeForConversion(type) &&
187          "expected value to have a convertible type");
188 
189   if (isa<IntegerType>(type))
190     return val;
191 
192   uint64_t typeBitSize = dataLayout.getTypeSizeInBits(type);
193   IntegerType valueSizeInteger = builder.getIntegerType(typeBitSize);
194 
195   if (isa<LLVM::LLVMPointerType>(type))
196     return builder.createOrFold<LLVM::PtrToIntOp>(loc, valueSizeInteger, val);
197   return builder.createOrFold<LLVM::BitcastOp>(loc, valueSizeInteger, val);
198 }
199 
200 /// Converts a value with an integer type to `targetType`.
201 static Value castIntValueToSameSizedType(OpBuilder &builder, Location loc,
202                                          Value val, Type targetType) {
203   assert(isa<IntegerType>(val.getType()) &&
204          "expected value to have an integer type");
205   assert(isSupportedTypeForConversion(targetType) &&
206          "expected the target type to be supported for conversions");
207   if (val.getType() == targetType)
208     return val;
209   if (isa<LLVM::LLVMPointerType>(targetType))
210     return builder.createOrFold<LLVM::IntToPtrOp>(loc, targetType, val);
211   return builder.createOrFold<LLVM::BitcastOp>(loc, targetType, val);
212 }
213 
214 /// Constructs operations that convert `srcValue` into a new value of type
215 /// `targetType`. Assumes the types have the same bitsize.
216 static Value castSameSizedTypes(OpBuilder &builder, Location loc,
217                                 Value srcValue, Type targetType,
218                                 const DataLayout &dataLayout) {
219   Type srcType = srcValue.getType();
220   assert(areConversionCompatible(dataLayout, targetType, srcType,
221                                  /*narrowingConversion=*/true) &&
222          "expected that the compatibility was checked before");
223 
224   // Nothing has to be done if the types are already the same.
225   if (srcType == targetType)
226     return srcValue;
227 
228   // In the special case of casting one pointer to another, we want to generate
229   // an address space cast. Bitcasts of pointers are not allowed and using
230   // pointer to integer conversions are not equivalent due to the loss of
231   // provenance.
232   if (isa<LLVM::LLVMPointerType>(targetType) &&
233       isa<LLVM::LLVMPointerType>(srcType))
234     return builder.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
235                                                        srcValue);
236 
237   // For all other castable types, casting through integers is necessary.
238   Value replacement = castToSameSizedInt(builder, loc, srcValue, dataLayout);
239   return castIntValueToSameSizedType(builder, loc, replacement, targetType);
240 }
241 
242 /// Constructs operations that convert `srcValue` into a new value of type
243 /// `targetType`. Performs bit-level extraction if the source type is larger
244 /// than the target type. Assumes that this conversion is possible.
245 static Value createExtractAndCast(OpBuilder &builder, Location loc,
246                                   Value srcValue, Type targetType,
247                                   const DataLayout &dataLayout) {
248   // Get the types of the source and target values.
249   Type srcType = srcValue.getType();
250   assert(areConversionCompatible(dataLayout, targetType, srcType,
251                                  /*narrowingConversion=*/true) &&
252          "expected that the compatibility was checked before");
253 
254   uint64_t srcTypeSize = dataLayout.getTypeSizeInBits(srcType);
255   uint64_t targetTypeSize = dataLayout.getTypeSizeInBits(targetType);
256   if (srcTypeSize == targetTypeSize)
257     return castSameSizedTypes(builder, loc, srcValue, targetType, dataLayout);
258 
259   // First, cast the value to a same-sized integer type.
260   Value replacement = castToSameSizedInt(builder, loc, srcValue, dataLayout);
261 
262   // Truncate the integer if the size of the target is less than the value.
263   if (isBigEndian(dataLayout)) {
264     uint64_t shiftAmount = srcTypeSize - targetTypeSize;
265     auto shiftConstant = builder.create<LLVM::ConstantOp>(
266         loc, builder.getIntegerAttr(srcType, shiftAmount));
267     replacement =
268         builder.createOrFold<LLVM::LShrOp>(loc, srcValue, shiftConstant);
269   }
270 
271   replacement = builder.create<LLVM::TruncOp>(
272       loc, builder.getIntegerType(targetTypeSize), replacement);
273 
274   // Now cast the integer to the actual target type if required.
275   return castIntValueToSameSizedType(builder, loc, replacement, targetType);
276 }
277 
278 /// Constructs operations that insert the bits of `srcValue` into the
279 /// "beginning" of `reachingDef` (beginning is endianness dependent).
280 /// Assumes that this conversion is possible.
281 static Value createInsertAndCast(OpBuilder &builder, Location loc,
282                                  Value srcValue, Value reachingDef,
283                                  const DataLayout &dataLayout) {
284 
285   assert(areConversionCompatible(dataLayout, reachingDef.getType(),
286                                  srcValue.getType(),
287                                  /*narrowingConversion=*/false) &&
288          "expected that the compatibility was checked before");
289   uint64_t valueTypeSize = dataLayout.getTypeSizeInBits(srcValue.getType());
290   uint64_t slotTypeSize = dataLayout.getTypeSizeInBits(reachingDef.getType());
291   if (slotTypeSize == valueTypeSize)
292     return castSameSizedTypes(builder, loc, srcValue, reachingDef.getType(),
293                               dataLayout);
294 
295   // In the case where the store only overwrites parts of the memory,
296   // bit fiddling is required to construct the new value.
297 
298   // First convert both values to integers of the same size.
299   Value defAsInt = castToSameSizedInt(builder, loc, reachingDef, dataLayout);
300   Value valueAsInt = castToSameSizedInt(builder, loc, srcValue, dataLayout);
301   // Extend the value to the size of the reaching definition.
302   valueAsInt =
303       builder.createOrFold<LLVM::ZExtOp>(loc, defAsInt.getType(), valueAsInt);
304   uint64_t sizeDifference = slotTypeSize - valueTypeSize;
305   if (isBigEndian(dataLayout)) {
306     // On big endian systems, a store to the base pointer overwrites the most
307     // significant bits. To accomodate for this, the stored value needs to be
308     // shifted into the according position.
309     Value bigEndianShift = builder.create<LLVM::ConstantOp>(
310         loc, builder.getIntegerAttr(defAsInt.getType(), sizeDifference));
311     valueAsInt =
312         builder.createOrFold<LLVM::ShlOp>(loc, valueAsInt, bigEndianShift);
313   }
314 
315   // Construct the mask that is used to erase the bits that are overwritten by
316   // the store.
317   APInt maskValue;
318   if (isBigEndian(dataLayout)) {
319     // Build a mask that has the most significant bits set to zero.
320     // Note: This is the same as 2^sizeDifference - 1
321     maskValue = APInt::getAllOnes(sizeDifference).zext(slotTypeSize);
322   } else {
323     // Build a mask that has the least significant bits set to zero.
324     // Note: This is the same as -(2^valueTypeSize)
325     maskValue = APInt::getAllOnes(valueTypeSize).zext(slotTypeSize);
326     maskValue.flipAllBits();
327   }
328 
329   // Mask out the affected bits ...
330   Value mask = builder.create<LLVM::ConstantOp>(
331       loc, builder.getIntegerAttr(defAsInt.getType(), maskValue));
332   Value masked = builder.createOrFold<LLVM::AndOp>(loc, defAsInt, mask);
333 
334   // ... and combine the result with the new value.
335   Value combined = builder.createOrFold<LLVM::OrOp>(loc, masked, valueAsInt);
336 
337   return castIntValueToSameSizedType(builder, loc, combined,
338                                      reachingDef.getType());
339 }
340 
341 Value LLVM::StoreOp::getStored(const MemorySlot &slot, OpBuilder &builder,
342                                Value reachingDef,
343                                const DataLayout &dataLayout) {
344   assert(reachingDef && reachingDef.getType() == slot.elemType &&
345          "expected the reaching definition's type to match the slot's type");
346   return createInsertAndCast(builder, getLoc(), getValue(), reachingDef,
347                              dataLayout);
348 }
349 
350 bool LLVM::LoadOp::canUsesBeRemoved(
351     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
352     SmallVectorImpl<OpOperand *> &newBlockingUses,
353     const DataLayout &dataLayout) {
354   if (blockingUses.size() != 1)
355     return false;
356   Value blockingUse = (*blockingUses.begin())->get();
357   // If the blocking use is the slot ptr itself, there will be enough
358   // context to reconstruct the result of the load at removal time, so it can
359   // be removed (provided it is not volatile).
360   return blockingUse == slot.ptr && getAddr() == slot.ptr &&
361          areConversionCompatible(dataLayout, getResult().getType(),
362                                  slot.elemType, /*narrowingConversion=*/true) &&
363          !getVolatile_();
364 }
365 
366 DeletionKind LLVM::LoadOp::removeBlockingUses(
367     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
368     OpBuilder &builder, Value reachingDefinition,
369     const DataLayout &dataLayout) {
370   // `canUsesBeRemoved` checked this blocking use must be the loaded slot
371   // pointer.
372   Value newResult = createExtractAndCast(builder, getLoc(), reachingDefinition,
373                                          getResult().getType(), dataLayout);
374   getResult().replaceAllUsesWith(newResult);
375   return DeletionKind::Delete;
376 }
377 
378 bool LLVM::StoreOp::canUsesBeRemoved(
379     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
380     SmallVectorImpl<OpOperand *> &newBlockingUses,
381     const DataLayout &dataLayout) {
382   if (blockingUses.size() != 1)
383     return false;
384   Value blockingUse = (*blockingUses.begin())->get();
385   // If the blocking use is the slot ptr itself, dropping the store is
386   // fine, provided we are currently promoting its target value. Don't allow a
387   // store OF the slot pointer, only INTO the slot pointer.
388   return blockingUse == slot.ptr && getAddr() == slot.ptr &&
389          getValue() != slot.ptr &&
390          areConversionCompatible(dataLayout, slot.elemType,
391                                  getValue().getType(),
392                                  /*narrowingConversion=*/false) &&
393          !getVolatile_();
394 }
395 
396 DeletionKind LLVM::StoreOp::removeBlockingUses(
397     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
398     OpBuilder &builder, Value reachingDefinition,
399     const DataLayout &dataLayout) {
400   return DeletionKind::Delete;
401 }
402 
403 /// Checks if `slot` can be accessed through the provided access type.
404 static bool isValidAccessType(const MemorySlot &slot, Type accessType,
405                               const DataLayout &dataLayout) {
406   return dataLayout.getTypeSize(accessType) <=
407          dataLayout.getTypeSize(slot.elemType);
408 }
409 
410 LogicalResult LLVM::LoadOp::ensureOnlySafeAccesses(
411     const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
412     const DataLayout &dataLayout) {
413   return success(getAddr() != slot.ptr ||
414                  isValidAccessType(slot, getType(), dataLayout));
415 }
416 
417 LogicalResult LLVM::StoreOp::ensureOnlySafeAccesses(
418     const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
419     const DataLayout &dataLayout) {
420   return success(getAddr() != slot.ptr ||
421                  isValidAccessType(slot, getValue().getType(), dataLayout));
422 }
423 
424 /// Returns the subslot's type at the requested index.
425 static Type getTypeAtIndex(const DestructurableMemorySlot &slot,
426                            Attribute index) {
427   auto subelementIndexMap =
428       cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap();
429   if (!subelementIndexMap)
430     return {};
431   assert(!subelementIndexMap->empty());
432 
433   // Note: Returns a null-type when no entry was found.
434   return subelementIndexMap->lookup(index);
435 }
436 
437 bool LLVM::LoadOp::canRewire(const DestructurableMemorySlot &slot,
438                              SmallPtrSetImpl<Attribute> &usedIndices,
439                              SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
440                              const DataLayout &dataLayout) {
441   if (getVolatile_())
442     return false;
443 
444   // A load always accesses the first element of the destructured slot.
445   auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
446   Type subslotType = getTypeAtIndex(slot, index);
447   if (!subslotType)
448     return false;
449 
450   // The access can only be replaced when the subslot is read within its bounds.
451   if (dataLayout.getTypeSize(getType()) > dataLayout.getTypeSize(subslotType))
452     return false;
453 
454   usedIndices.insert(index);
455   return true;
456 }
457 
458 DeletionKind LLVM::LoadOp::rewire(const DestructurableMemorySlot &slot,
459                                   DenseMap<Attribute, MemorySlot> &subslots,
460                                   OpBuilder &builder,
461                                   const DataLayout &dataLayout) {
462   auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
463   auto it = subslots.find(index);
464   assert(it != subslots.end());
465 
466   getAddrMutable().set(it->getSecond().ptr);
467   return DeletionKind::Keep;
468 }
469 
470 bool LLVM::StoreOp::canRewire(const DestructurableMemorySlot &slot,
471                               SmallPtrSetImpl<Attribute> &usedIndices,
472                               SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
473                               const DataLayout &dataLayout) {
474   if (getVolatile_())
475     return false;
476 
477   // Storing the pointer to memory cannot be dealt with.
478   if (getValue() == slot.ptr)
479     return false;
480 
481   // A store always accesses the first element of the destructured slot.
482   auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
483   Type subslotType = getTypeAtIndex(slot, index);
484   if (!subslotType)
485     return false;
486 
487   // The access can only be replaced when the subslot is read within its bounds.
488   if (dataLayout.getTypeSize(getValue().getType()) >
489       dataLayout.getTypeSize(subslotType))
490     return false;
491 
492   usedIndices.insert(index);
493   return true;
494 }
495 
496 DeletionKind LLVM::StoreOp::rewire(const DestructurableMemorySlot &slot,
497                                    DenseMap<Attribute, MemorySlot> &subslots,
498                                    OpBuilder &builder,
499                                    const DataLayout &dataLayout) {
500   auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
501   auto it = subslots.find(index);
502   assert(it != subslots.end());
503 
504   getAddrMutable().set(it->getSecond().ptr);
505   return DeletionKind::Keep;
506 }
507 
508 //===----------------------------------------------------------------------===//
509 // Interfaces for discardable OPs
510 //===----------------------------------------------------------------------===//
511 
512 /// Conditions the deletion of the operation to the removal of all its uses.
513 static bool forwardToUsers(Operation *op,
514                            SmallVectorImpl<OpOperand *> &newBlockingUses) {
515   for (Value result : op->getResults())
516     for (OpOperand &use : result.getUses())
517       newBlockingUses.push_back(&use);
518   return true;
519 }
520 
521 bool LLVM::BitcastOp::canUsesBeRemoved(
522     const SmallPtrSetImpl<OpOperand *> &blockingUses,
523     SmallVectorImpl<OpOperand *> &newBlockingUses,
524     const DataLayout &dataLayout) {
525   return forwardToUsers(*this, newBlockingUses);
526 }
527 
528 DeletionKind LLVM::BitcastOp::removeBlockingUses(
529     const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
530   return DeletionKind::Delete;
531 }
532 
533 bool LLVM::AddrSpaceCastOp::canUsesBeRemoved(
534     const SmallPtrSetImpl<OpOperand *> &blockingUses,
535     SmallVectorImpl<OpOperand *> &newBlockingUses,
536     const DataLayout &dataLayout) {
537   return forwardToUsers(*this, newBlockingUses);
538 }
539 
540 DeletionKind LLVM::AddrSpaceCastOp::removeBlockingUses(
541     const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
542   return DeletionKind::Delete;
543 }
544 
545 bool LLVM::LifetimeStartOp::canUsesBeRemoved(
546     const SmallPtrSetImpl<OpOperand *> &blockingUses,
547     SmallVectorImpl<OpOperand *> &newBlockingUses,
548     const DataLayout &dataLayout) {
549   return true;
550 }
551 
552 DeletionKind LLVM::LifetimeStartOp::removeBlockingUses(
553     const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
554   return DeletionKind::Delete;
555 }
556 
557 bool LLVM::LifetimeEndOp::canUsesBeRemoved(
558     const SmallPtrSetImpl<OpOperand *> &blockingUses,
559     SmallVectorImpl<OpOperand *> &newBlockingUses,
560     const DataLayout &dataLayout) {
561   return true;
562 }
563 
564 DeletionKind LLVM::LifetimeEndOp::removeBlockingUses(
565     const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
566   return DeletionKind::Delete;
567 }
568 
569 bool LLVM::InvariantStartOp::canUsesBeRemoved(
570     const SmallPtrSetImpl<OpOperand *> &blockingUses,
571     SmallVectorImpl<OpOperand *> &newBlockingUses,
572     const DataLayout &dataLayout) {
573   return true;
574 }
575 
576 DeletionKind LLVM::InvariantStartOp::removeBlockingUses(
577     const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
578   return DeletionKind::Delete;
579 }
580 
581 bool LLVM::InvariantEndOp::canUsesBeRemoved(
582     const SmallPtrSetImpl<OpOperand *> &blockingUses,
583     SmallVectorImpl<OpOperand *> &newBlockingUses,
584     const DataLayout &dataLayout) {
585   return true;
586 }
587 
588 DeletionKind LLVM::InvariantEndOp::removeBlockingUses(
589     const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
590   return DeletionKind::Delete;
591 }
592 
593 bool LLVM::LaunderInvariantGroupOp::canUsesBeRemoved(
594     const SmallPtrSetImpl<OpOperand *> &blockingUses,
595     SmallVectorImpl<OpOperand *> &newBlockingUses,
596     const DataLayout &dataLayout) {
597   return forwardToUsers(*this, newBlockingUses);
598 }
599 
600 DeletionKind LLVM::LaunderInvariantGroupOp::removeBlockingUses(
601     const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
602   return DeletionKind::Delete;
603 }
604 
605 bool LLVM::StripInvariantGroupOp::canUsesBeRemoved(
606     const SmallPtrSetImpl<OpOperand *> &blockingUses,
607     SmallVectorImpl<OpOperand *> &newBlockingUses,
608     const DataLayout &dataLayout) {
609   return forwardToUsers(*this, newBlockingUses);
610 }
611 
612 DeletionKind LLVM::StripInvariantGroupOp::removeBlockingUses(
613     const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
614   return DeletionKind::Delete;
615 }
616 
617 bool LLVM::DbgDeclareOp::canUsesBeRemoved(
618     const SmallPtrSetImpl<OpOperand *> &blockingUses,
619     SmallVectorImpl<OpOperand *> &newBlockingUses,
620     const DataLayout &dataLayout) {
621   return true;
622 }
623 
624 DeletionKind LLVM::DbgDeclareOp::removeBlockingUses(
625     const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
626   return DeletionKind::Delete;
627 }
628 
629 bool LLVM::DbgValueOp::canUsesBeRemoved(
630     const SmallPtrSetImpl<OpOperand *> &blockingUses,
631     SmallVectorImpl<OpOperand *> &newBlockingUses,
632     const DataLayout &dataLayout) {
633   // There is only one operand that we can remove the use of.
634   if (blockingUses.size() != 1)
635     return false;
636 
637   return (*blockingUses.begin())->get() == getValue();
638 }
639 
640 DeletionKind LLVM::DbgValueOp::removeBlockingUses(
641     const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
642   // builder by default is after '*this', but we need it before '*this'.
643   builder.setInsertionPoint(*this);
644 
645   // Rather than dropping the debug value, replace it with undef to preserve the
646   // debug local variable info. This allows the debugger to inform the user that
647   // the variable has been optimized out.
648   auto undef =
649       builder.create<UndefOp>(getValue().getLoc(), getValue().getType());
650   getValueMutable().assign(undef);
651   return DeletionKind::Keep;
652 }
653 
654 bool LLVM::DbgDeclareOp::requiresReplacedValues() { return true; }
655 
656 void LLVM::DbgDeclareOp::visitReplacedValues(
657     ArrayRef<std::pair<Operation *, Value>> definitions, OpBuilder &builder) {
658   for (auto [op, value] : definitions) {
659     builder.setInsertionPointAfter(op);
660     builder.create<LLVM::DbgValueOp>(getLoc(), value, getVarInfo(),
661                                      getLocationExpr());
662   }
663 }
664 
665 //===----------------------------------------------------------------------===//
666 // Interfaces for GEPOp
667 //===----------------------------------------------------------------------===//
668 
669 static bool hasAllZeroIndices(LLVM::GEPOp gepOp) {
670   return llvm::all_of(gepOp.getIndices(), [](auto index) {
671     auto indexAttr = llvm::dyn_cast_if_present<IntegerAttr>(index);
672     return indexAttr && indexAttr.getValue() == 0;
673   });
674 }
675 
676 bool LLVM::GEPOp::canUsesBeRemoved(
677     const SmallPtrSetImpl<OpOperand *> &blockingUses,
678     SmallVectorImpl<OpOperand *> &newBlockingUses,
679     const DataLayout &dataLayout) {
680   // GEP can be removed as long as it is a no-op and its users can be removed.
681   if (!hasAllZeroIndices(*this))
682     return false;
683   return forwardToUsers(*this, newBlockingUses);
684 }
685 
686 DeletionKind LLVM::GEPOp::removeBlockingUses(
687     const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
688   return DeletionKind::Delete;
689 }
690 
691 /// Returns the amount of bytes the provided GEP elements will offset the
692 /// pointer by. Returns nullopt if no constant offset could be computed.
693 static std::optional<uint64_t> gepToByteOffset(const DataLayout &dataLayout,
694                                                LLVM::GEPOp gep) {
695   // Collects all indices.
696   SmallVector<uint64_t> indices;
697   for (auto index : gep.getIndices()) {
698     auto constIndex = dyn_cast<IntegerAttr>(index);
699     if (!constIndex)
700       return {};
701     int64_t gepIndex = constIndex.getInt();
702     // Negative indices are not supported.
703     if (gepIndex < 0)
704       return {};
705     indices.push_back(gepIndex);
706   }
707 
708   Type currentType = gep.getElemType();
709   uint64_t offset = indices[0] * dataLayout.getTypeSize(currentType);
710 
711   for (uint64_t index : llvm::drop_begin(indices)) {
712     bool shouldCancel =
713         TypeSwitch<Type, bool>(currentType)
714             .Case([&](LLVM::LLVMArrayType arrayType) {
715               offset +=
716                   index * dataLayout.getTypeSize(arrayType.getElementType());
717               currentType = arrayType.getElementType();
718               return false;
719             })
720             .Case([&](LLVM::LLVMStructType structType) {
721               ArrayRef<Type> body = structType.getBody();
722               assert(index < body.size() && "expected valid struct indexing");
723               for (uint32_t i : llvm::seq(index)) {
724                 if (!structType.isPacked())
725                   offset = llvm::alignTo(
726                       offset, dataLayout.getTypeABIAlignment(body[i]));
727                 offset += dataLayout.getTypeSize(body[i]);
728               }
729 
730               // Align for the current type as well.
731               if (!structType.isPacked())
732                 offset = llvm::alignTo(
733                     offset, dataLayout.getTypeABIAlignment(body[index]));
734               currentType = body[index];
735               return false;
736             })
737             .Default([&](Type type) {
738               LLVM_DEBUG(llvm::dbgs()
739                          << "[sroa] Unsupported type for offset computations"
740                          << type << "\n");
741               return true;
742             });
743 
744     if (shouldCancel)
745       return std::nullopt;
746   }
747 
748   return offset;
749 }
750 
namespace {
/// Describes where a byte offset into a destructurable slot lands: which
/// subelement of the slot is hit, and how far into that subelement.
struct SubslotAccessInfo {
  /// Index of the subelement of the parent slot that the access falls into.
  uint32_t index;
  /// Byte offset of the access, relative to the start of that subelement.
  uint64_t subslotOffset;
};
} // namespace
761 
762 /// Computes subslot access information for an access into `slot` with the given
763 /// offset.
764 /// Returns nullopt when the offset is out-of-bounds or when the access is into
765 /// the padding of `slot`.
766 static std::optional<SubslotAccessInfo>
767 getSubslotAccessInfo(const DestructurableMemorySlot &slot,
768                      const DataLayout &dataLayout, LLVM::GEPOp gep) {
769   std::optional<uint64_t> offset = gepToByteOffset(dataLayout, gep);
770   if (!offset)
771     return {};
772 
773   // Helper to check that a constant index is in the bounds of the GEP index
774   // representation. LLVM dialects's GEP arguments have a limited bitwidth, thus
775   // this additional check is necessary.
776   auto isOutOfBoundsGEPIndex = [](uint64_t index) {
777     return index >= (1 << LLVM::kGEPConstantBitWidth);
778   };
779 
780   Type type = slot.elemType;
781   if (*offset >= dataLayout.getTypeSize(type))
782     return {};
783   return TypeSwitch<Type, std::optional<SubslotAccessInfo>>(type)
784       .Case([&](LLVM::LLVMArrayType arrayType)
785                 -> std::optional<SubslotAccessInfo> {
786         // Find which element of the array contains the offset.
787         uint64_t elemSize = dataLayout.getTypeSize(arrayType.getElementType());
788         uint64_t index = *offset / elemSize;
789         if (isOutOfBoundsGEPIndex(index))
790           return {};
791         return SubslotAccessInfo{static_cast<uint32_t>(index),
792                                  *offset - (index * elemSize)};
793       })
794       .Case([&](LLVM::LLVMStructType structType)
795                 -> std::optional<SubslotAccessInfo> {
796         uint64_t distanceToStart = 0;
797         // Walk over the elements of the struct to find in which of
798         // them the offset is.
799         for (auto [index, elem] : llvm::enumerate(structType.getBody())) {
800           uint64_t elemSize = dataLayout.getTypeSize(elem);
801           if (!structType.isPacked()) {
802             distanceToStart = llvm::alignTo(
803                 distanceToStart, dataLayout.getTypeABIAlignment(elem));
804             // If the offset is in padding, cancel the rewrite.
805             if (offset < distanceToStart)
806               return {};
807           }
808 
809           if (offset < distanceToStart + elemSize) {
810             if (isOutOfBoundsGEPIndex(index))
811               return {};
812             // The offset is within this element, stop iterating the
813             // struct and return the index.
814             return SubslotAccessInfo{static_cast<uint32_t>(index),
815                                      *offset - distanceToStart};
816           }
817 
818           // The offset is not within this element, continue walking
819           // over the struct.
820           distanceToStart += elemSize;
821         }
822 
823         return {};
824       });
825 }
826 
827 /// Constructs a byte array type of the given size.
828 static LLVM::LLVMArrayType getByteArrayType(MLIRContext *context,
829                                             unsigned size) {
830   auto byteType = IntegerType::get(context, 8);
831   return LLVM::LLVMArrayType::get(context, byteType, size);
832 }
833 
834 LogicalResult LLVM::GEPOp::ensureOnlySafeAccesses(
835     const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
836     const DataLayout &dataLayout) {
837   if (getBase() != slot.ptr)
838     return success();
839   std::optional<uint64_t> gepOffset = gepToByteOffset(dataLayout, *this);
840   if (!gepOffset)
841     return failure();
842   uint64_t slotSize = dataLayout.getTypeSize(slot.elemType);
843   // Check that the access is strictly inside the slot.
844   if (*gepOffset >= slotSize)
845     return failure();
846   // Every access that remains in bounds of the remaining slot is considered
847   // legal.
848   mustBeSafelyUsed.emplace_back<MemorySlot>(
849       {getRes(), getByteArrayType(getContext(), slotSize - *gepOffset)});
850   return success();
851 }
852 
853 bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
854                             SmallPtrSetImpl<Attribute> &usedIndices,
855                             SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
856                             const DataLayout &dataLayout) {
857   if (!isa<LLVM::LLVMPointerType>(getBase().getType()))
858     return false;
859 
860   if (getBase() != slot.ptr)
861     return false;
862   std::optional<SubslotAccessInfo> accessInfo =
863       getSubslotAccessInfo(slot, dataLayout, *this);
864   if (!accessInfo)
865     return false;
866   auto indexAttr =
867       IntegerAttr::get(IntegerType::get(getContext(), 32), accessInfo->index);
868   assert(slot.subelementTypes.contains(indexAttr));
869   usedIndices.insert(indexAttr);
870 
871   // The remainder of the subslot should be accesses in-bounds. Thus, we create
872   // a dummy slot with the size of the remainder.
873   Type subslotType = slot.subelementTypes.lookup(indexAttr);
874   uint64_t slotSize = dataLayout.getTypeSize(subslotType);
875   LLVM::LLVMArrayType remainingSlotType =
876       getByteArrayType(getContext(), slotSize - accessInfo->subslotOffset);
877   mustBeSafelyUsed.emplace_back<MemorySlot>({getRes(), remainingSlotType});
878 
879   return true;
880 }
881 
882 DeletionKind LLVM::GEPOp::rewire(const DestructurableMemorySlot &slot,
883                                  DenseMap<Attribute, MemorySlot> &subslots,
884                                  OpBuilder &builder,
885                                  const DataLayout &dataLayout) {
886   std::optional<SubslotAccessInfo> accessInfo =
887       getSubslotAccessInfo(slot, dataLayout, *this);
888   assert(accessInfo && "expected access info to be checked before");
889   auto indexAttr =
890       IntegerAttr::get(IntegerType::get(getContext(), 32), accessInfo->index);
891   const MemorySlot &newSlot = subslots.at(indexAttr);
892 
893   auto byteType = IntegerType::get(builder.getContext(), 8);
894   auto newPtr = builder.createOrFold<LLVM::GEPOp>(
895       getLoc(), getResult().getType(), byteType, newSlot.ptr,
896       ArrayRef<GEPArg>(accessInfo->subslotOffset), getInbounds());
897   getResult().replaceAllUsesWith(newPtr);
898   return DeletionKind::Delete;
899 }
900 
901 //===----------------------------------------------------------------------===//
902 // Utilities for memory intrinsics
903 //===----------------------------------------------------------------------===//
904 
905 namespace {
906 
907 /// Returns the length of the given memory intrinsic in bytes if it can be known
908 /// at compile-time on a best-effort basis, nothing otherwise.
909 template <class MemIntr>
910 std::optional<uint64_t> getStaticMemIntrLen(MemIntr op) {
911   APInt memIntrLen;
912   if (!matchPattern(op.getLen(), m_ConstantInt(&memIntrLen)))
913     return {};
914   if (memIntrLen.getBitWidth() > 64)
915     return {};
916   return memIntrLen.getZExtValue();
917 }
918 
919 /// Returns the length of the given memory intrinsic in bytes if it can be known
920 /// at compile-time on a best-effort basis, nothing otherwise.
921 /// Because MemcpyInlineOp has its length encoded as an attribute, this requires
922 /// specialized handling.
923 template <>
924 std::optional<uint64_t> getStaticMemIntrLen(LLVM::MemcpyInlineOp op) {
925   APInt memIntrLen = op.getLen();
926   if (memIntrLen.getBitWidth() > 64)
927     return {};
928   return memIntrLen.getZExtValue();
929 }
930 
931 /// Returns the length of the given memory intrinsic in bytes if it can be known
932 /// at compile-time on a best-effort basis, nothing otherwise.
933 /// Because MemsetInlineOp has its length encoded as an attribute, this requires
934 /// specialized handling.
935 template <>
936 std::optional<uint64_t> getStaticMemIntrLen(LLVM::MemsetInlineOp op) {
937   APInt memIntrLen = op.getLen();
938   if (memIntrLen.getBitWidth() > 64)
939     return {};
940   return memIntrLen.getZExtValue();
941 }
942 
943 /// Returns an integer attribute representing the length of a memset intrinsic
944 template <class MemsetIntr>
945 IntegerAttr createMemsetLenAttr(MemsetIntr op) {
946   IntegerAttr memsetLenAttr;
947   bool successfulMatch =
948       matchPattern(op.getLen(), m_Constant<IntegerAttr>(&memsetLenAttr));
949   (void)successfulMatch;
950   assert(successfulMatch);
951   return memsetLenAttr;
952 }
953 
/// Returns the integer attribute representing the length of a memset
/// intrinsic. MemsetInlineOp has its length encoded as an attribute, so this
/// specialization simply returns that attribute instead of matching a constant
/// operand.
template <>
IntegerAttr createMemsetLenAttr(LLVM::MemsetInlineOp op) {
  return op.getLenAttr();
}
961 
/// Creates a memset intrinsic of the same kind as the `toReplace` intrinsic,
/// writing `newMemsetSize` bytes to the subslot stored at `index` in
/// `subslots`. The length of the new intrinsic uses the type of
/// `memsetLenAttr`. There are template specializations for MemsetOp and
/// MemsetInlineOp.
template <class MemsetIntr>
void createMemsetIntr(OpBuilder &builder, MemsetIntr toReplace,
                      IntegerAttr memsetLenAttr, uint64_t newMemsetSize,
                      DenseMap<Attribute, MemorySlot> &subslots,
                      Attribute index);
970 
971 template <>
972 void createMemsetIntr(OpBuilder &builder, LLVM::MemsetOp toReplace,
973                       IntegerAttr memsetLenAttr, uint64_t newMemsetSize,
974                       DenseMap<Attribute, MemorySlot> &subslots,
975                       Attribute index) {
976   Value newMemsetSizeValue =
977       builder
978           .create<LLVM::ConstantOp>(
979               toReplace.getLen().getLoc(),
980               IntegerAttr::get(memsetLenAttr.getType(), newMemsetSize))
981           .getResult();
982 
983   builder.create<LLVM::MemsetOp>(toReplace.getLoc(), subslots.at(index).ptr,
984                                  toReplace.getVal(), newMemsetSizeValue,
985                                  toReplace.getIsVolatile());
986 }
987 
988 template <>
989 void createMemsetIntr(OpBuilder &builder, LLVM::MemsetInlineOp toReplace,
990                       IntegerAttr memsetLenAttr, uint64_t newMemsetSize,
991                       DenseMap<Attribute, MemorySlot> &subslots,
992                       Attribute index) {
993   auto newMemsetSizeValue =
994       IntegerAttr::get(memsetLenAttr.getType(), newMemsetSize);
995 
996   builder.create<LLVM::MemsetInlineOp>(
997       toReplace.getLoc(), subslots.at(index).ptr, toReplace.getVal(),
998       newMemsetSizeValue, toReplace.getIsVolatile());
999 }
1000 
1001 } // namespace
1002 
1003 /// Returns whether one can be sure the memory intrinsic does not write outside
1004 /// of the bounds of the given slot, on a best-effort basis.
1005 template <class MemIntr>
1006 static bool definitelyWritesOnlyWithinSlot(MemIntr op, const MemorySlot &slot,
1007                                            const DataLayout &dataLayout) {
1008   if (!isa<LLVM::LLVMPointerType>(slot.ptr.getType()) ||
1009       op.getDst() != slot.ptr)
1010     return false;
1011 
1012   std::optional<uint64_t> memIntrLen = getStaticMemIntrLen(op);
1013   return memIntrLen && *memIntrLen <= dataLayout.getTypeSize(slot.elemType);
1014 }
1015 
1016 /// Checks whether all indices are i32. This is used to check GEPs can index
1017 /// into them.
1018 static bool areAllIndicesI32(const DestructurableMemorySlot &slot) {
1019   Type i32 = IntegerType::get(slot.ptr.getContext(), 32);
1020   return llvm::all_of(llvm::make_first_range(slot.subelementTypes),
1021                       [&](Attribute index) {
1022                         auto intIndex = dyn_cast<IntegerAttr>(index);
1023                         return intIndex && intIndex.getType() == i32;
1024                       });
1025 }
1026 
1027 //===----------------------------------------------------------------------===//
1028 // Interfaces for memset and memset.inline
1029 //===----------------------------------------------------------------------===//
1030 
1031 template <class MemsetIntr>
1032 static bool memsetCanRewire(MemsetIntr op, const DestructurableMemorySlot &slot,
1033                             SmallPtrSetImpl<Attribute> &usedIndices,
1034                             SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
1035                             const DataLayout &dataLayout) {
1036   if (&slot.elemType.getDialect() != op.getOperation()->getDialect())
1037     return false;
1038 
1039   if (op.getIsVolatile())
1040     return false;
1041 
1042   if (!cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap())
1043     return false;
1044 
1045   if (!areAllIndicesI32(slot))
1046     return false;
1047 
1048   return definitelyWritesOnlyWithinSlot(op, slot, dataLayout);
1049 }
1050 
1051 template <class MemsetIntr>
1052 static Value memsetGetStored(MemsetIntr op, const MemorySlot &slot,
1053                              OpBuilder &builder) {
1054   // TODO: Support non-integer types.
1055   return TypeSwitch<Type, Value>(slot.elemType)
1056       .Case([&](IntegerType intType) -> Value {
1057         if (intType.getWidth() == 8)
1058           return op.getVal();
1059 
1060         assert(intType.getWidth() % 8 == 0);
1061 
1062         // Build the memset integer by repeatedly shifting the value and
1063         // or-ing it with the previous value.
1064         uint64_t coveredBits = 8;
1065         Value currentValue =
1066             builder.create<LLVM::ZExtOp>(op.getLoc(), intType, op.getVal());
1067         while (coveredBits < intType.getWidth()) {
1068           Value shiftBy = builder.create<LLVM::ConstantOp>(op.getLoc(), intType,
1069                                                            coveredBits);
1070           Value shifted =
1071               builder.create<LLVM::ShlOp>(op.getLoc(), currentValue, shiftBy);
1072           currentValue =
1073               builder.create<LLVM::OrOp>(op.getLoc(), currentValue, shifted);
1074           coveredBits *= 2;
1075         }
1076 
1077         return currentValue;
1078       })
1079       .Default([](Type) -> Value {
1080         llvm_unreachable(
1081             "getStored should not be called on memset to unsupported type");
1082       });
1083 }
1084 
1085 template <class MemsetIntr>
1086 static bool
1087 memsetCanUsesBeRemoved(MemsetIntr op, const MemorySlot &slot,
1088                        const SmallPtrSetImpl<OpOperand *> &blockingUses,
1089                        SmallVectorImpl<OpOperand *> &newBlockingUses,
1090                        const DataLayout &dataLayout) {
1091   // TODO: Support non-integer types.
1092   bool canConvertType =
1093       TypeSwitch<Type, bool>(slot.elemType)
1094           .Case([](IntegerType intType) {
1095             return intType.getWidth() % 8 == 0 && intType.getWidth() > 0;
1096           })
1097           .Default([](Type) { return false; });
1098   if (!canConvertType)
1099     return false;
1100 
1101   if (op.getIsVolatile())
1102     return false;
1103 
1104   return getStaticMemIntrLen(op) == dataLayout.getTypeSize(slot.elemType);
1105 }
1106 
1107 template <class MemsetIntr>
1108 static DeletionKind
1109 memsetRewire(MemsetIntr op, const DestructurableMemorySlot &slot,
1110              DenseMap<Attribute, MemorySlot> &subslots, OpBuilder &builder,
1111              const DataLayout &dataLayout) {
1112 
1113   std::optional<DenseMap<Attribute, Type>> types =
1114       cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap();
1115 
1116   IntegerAttr memsetLenAttr = createMemsetLenAttr(op);
1117 
1118   bool packed = false;
1119   if (auto structType = dyn_cast<LLVM::LLVMStructType>(slot.elemType))
1120     packed = structType.isPacked();
1121 
1122   Type i32 = IntegerType::get(op.getContext(), 32);
1123   uint64_t memsetLen = memsetLenAttr.getValue().getZExtValue();
1124   uint64_t covered = 0;
1125   for (size_t i = 0; i < types->size(); i++) {
1126     // Create indices on the fly to get elements in the right order.
1127     Attribute index = IntegerAttr::get(i32, i);
1128     Type elemType = types->at(index);
1129     uint64_t typeSize = dataLayout.getTypeSize(elemType);
1130 
1131     if (!packed)
1132       covered =
1133           llvm::alignTo(covered, dataLayout.getTypeABIAlignment(elemType));
1134 
1135     if (covered >= memsetLen)
1136       break;
1137 
1138     // If this subslot is used, apply a new memset to it.
1139     // Otherwise, only compute its offset within the original memset.
1140     if (subslots.contains(index)) {
1141       uint64_t newMemsetSize = std::min(memsetLen - covered, typeSize);
1142       createMemsetIntr(builder, op, memsetLenAttr, newMemsetSize, subslots,
1143                        index);
1144     }
1145 
1146     covered += typeSize;
1147   }
1148 
1149   return DeletionKind::Delete;
1150 }
1151 
/// Always returns false: this op is never reported as loading from `slot`.
bool LLVM::MemsetOp::loadsFrom(const MemorySlot &slot) { return false; }
1153 
/// Returns true iff the destination of this memset is the pointer of `slot`.
bool LLVM::MemsetOp::storesTo(const MemorySlot &slot) {
  return getDst() == slot.ptr;
}
1157 
/// Returns the value stored to `slot` by this memset, as built by the shared
/// `memsetGetStored` helper. `reachingDef` and `dataLayout` are not used.
Value LLVM::MemsetOp::getStored(const MemorySlot &slot, OpBuilder &builder,
                                Value reachingDef,
                                const DataLayout &dataLayout) {
  return memsetGetStored(*this, slot, builder);
}
1163 
/// Delegates to the shared `memsetCanUsesBeRemoved` helper, forwarding all
/// arguments.
bool LLVM::MemsetOp::canUsesBeRemoved(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return memsetCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses,
                                dataLayout);
}
1171 
/// Performs no rewriting and unconditionally requests the deletion of this
/// op. None of the arguments are used.
DeletionKind LLVM::MemsetOp::removeBlockingUses(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    OpBuilder &builder, Value reachingDefinition,
    const DataLayout &dataLayout) {
  return DeletionKind::Delete;
}
1178 
/// Succeeds iff `definitelyWritesOnlyWithinSlot` holds for this memset and
/// `slot`. `mustBeSafelyUsed` is left untouched.
LogicalResult LLVM::MemsetOp::ensureOnlySafeAccesses(
    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
    const DataLayout &dataLayout) {
  return success(definitelyWritesOnlyWithinSlot(*this, slot, dataLayout));
}
1184 
/// Delegates to the shared `memsetCanRewire` helper, forwarding all arguments.
bool LLVM::MemsetOp::canRewire(const DestructurableMemorySlot &slot,
                               SmallPtrSetImpl<Attribute> &usedIndices,
                               SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
                               const DataLayout &dataLayout) {
  return memsetCanRewire(*this, slot, usedIndices, mustBeSafelyUsed,
                         dataLayout);
}
1192 
/// Delegates to the shared `memsetRewire` helper, forwarding all arguments.
DeletionKind LLVM::MemsetOp::rewire(const DestructurableMemorySlot &slot,
                                    DenseMap<Attribute, MemorySlot> &subslots,
                                    OpBuilder &builder,
                                    const DataLayout &dataLayout) {
  return memsetRewire(*this, slot, subslots, builder, dataLayout);
}
1199 
/// Always returns false: this op is never reported as loading from `slot`.
bool LLVM::MemsetInlineOp::loadsFrom(const MemorySlot &slot) { return false; }
1201 
/// Returns true iff the destination of this memset is the pointer of `slot`.
bool LLVM::MemsetInlineOp::storesTo(const MemorySlot &slot) {
  return getDst() == slot.ptr;
}
1205 
/// Returns the value stored to `slot` by this memset, as built by the shared
/// `memsetGetStored` helper. `reachingDef` and `dataLayout` are not used.
Value LLVM::MemsetInlineOp::getStored(const MemorySlot &slot,
                                      OpBuilder &builder, Value reachingDef,
                                      const DataLayout &dataLayout) {
  return memsetGetStored(*this, slot, builder);
}
1211 
/// Delegates to the shared `memsetCanUsesBeRemoved` helper, forwarding all
/// arguments.
bool LLVM::MemsetInlineOp::canUsesBeRemoved(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return memsetCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses,
                                dataLayout);
}
1219 
/// Performs no rewriting and unconditionally requests the deletion of this
/// op. None of the arguments are used.
DeletionKind LLVM::MemsetInlineOp::removeBlockingUses(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    OpBuilder &builder, Value reachingDefinition,
    const DataLayout &dataLayout) {
  return DeletionKind::Delete;
}
1226 
/// Succeeds iff `definitelyWritesOnlyWithinSlot` holds for this memset and
/// `slot`. `mustBeSafelyUsed` is left untouched.
LogicalResult LLVM::MemsetInlineOp::ensureOnlySafeAccesses(
    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
    const DataLayout &dataLayout) {
  return success(definitelyWritesOnlyWithinSlot(*this, slot, dataLayout));
}
1232 
/// Delegates to the shared `memsetCanRewire` helper, forwarding all arguments.
bool LLVM::MemsetInlineOp::canRewire(
    const DestructurableMemorySlot &slot,
    SmallPtrSetImpl<Attribute> &usedIndices,
    SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
    const DataLayout &dataLayout) {
  return memsetCanRewire(*this, slot, usedIndices, mustBeSafelyUsed,
                         dataLayout);
}
1241 
/// Delegates to the shared `memsetRewire` helper, forwarding all arguments.
DeletionKind
LLVM::MemsetInlineOp::rewire(const DestructurableMemorySlot &slot,
                             DenseMap<Attribute, MemorySlot> &subslots,
                             OpBuilder &builder, const DataLayout &dataLayout) {
  return memsetRewire(*this, slot, subslots, builder, dataLayout);
}
1248 
1249 //===----------------------------------------------------------------------===//
1250 // Interfaces for memcpy/memmove
1251 //===----------------------------------------------------------------------===//
1252 
/// Returns true iff the source pointer of the memcpy-like `op` is the pointer
/// of `slot`.
template <class MemcpyLike>
static bool memcpyLoadsFrom(MemcpyLike op, const MemorySlot &slot) {
  return op.getSrc() == slot.ptr;
}
1257 
/// Returns true iff the destination pointer of the memcpy-like `op` is the
/// pointer of `slot`.
template <class MemcpyLike>
static bool memcpyStoresTo(MemcpyLike op, const MemorySlot &slot) {
  return op.getDst() == slot.ptr;
}
1262 
/// Creates, at the location of `op`, a load of type `slot.elemType` from the
/// source pointer of `op`, and returns the loaded value.
template <class MemcpyLike>
static Value memcpyGetStored(MemcpyLike op, const MemorySlot &slot,
                             OpBuilder &builder) {
  return builder.create<LLVM::LoadOp>(op.getLoc(), slot.elemType, op.getSrc());
}
1268 
1269 template <class MemcpyLike>
1270 static bool
1271 memcpyCanUsesBeRemoved(MemcpyLike op, const MemorySlot &slot,
1272                        const SmallPtrSetImpl<OpOperand *> &blockingUses,
1273                        SmallVectorImpl<OpOperand *> &newBlockingUses,
1274                        const DataLayout &dataLayout) {
1275   // If source and destination are the same, memcpy behavior is undefined and
1276   // memmove is a no-op. Because there is no memory change happening here,
1277   // simplifying such operations is left to canonicalization.
1278   if (op.getDst() == op.getSrc())
1279     return false;
1280 
1281   if (op.getIsVolatile())
1282     return false;
1283 
1284   return getStaticMemIntrLen(op) == dataLayout.getTypeSize(slot.elemType);
1285 }
1286 
1287 template <class MemcpyLike>
1288 static DeletionKind
1289 memcpyRemoveBlockingUses(MemcpyLike op, const MemorySlot &slot,
1290                          const SmallPtrSetImpl<OpOperand *> &blockingUses,
1291                          OpBuilder &builder, Value reachingDefinition) {
1292   if (op.loadsFrom(slot))
1293     builder.create<LLVM::StoreOp>(op.getLoc(), reachingDefinition, op.getDst());
1294   return DeletionKind::Delete;
1295 }
1296 
1297 template <class MemcpyLike>
1298 static LogicalResult
1299 memcpyEnsureOnlySafeAccesses(MemcpyLike op, const MemorySlot &slot,
1300                              SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
1301   DataLayout dataLayout = DataLayout::closest(op);
1302   // While rewiring memcpy-like intrinsics only supports full copies, partial
1303   // copies are still safe accesses so it is enough to only check for writes
1304   // within bounds.
1305   return success(definitelyWritesOnlyWithinSlot(op, slot, dataLayout));
1306 }
1307 
1308 template <class MemcpyLike>
1309 static bool memcpyCanRewire(MemcpyLike op, const DestructurableMemorySlot &slot,
1310                             SmallPtrSetImpl<Attribute> &usedIndices,
1311                             SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
1312                             const DataLayout &dataLayout) {
1313   if (op.getIsVolatile())
1314     return false;
1315 
1316   if (!cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap())
1317     return false;
1318 
1319   if (!areAllIndicesI32(slot))
1320     return false;
1321 
1322   // Only full copies are supported.
1323   if (getStaticMemIntrLen(op) != dataLayout.getTypeSize(slot.elemType))
1324     return false;
1325 
1326   if (op.getSrc() == slot.ptr)
1327     for (Attribute index : llvm::make_first_range(slot.subelementTypes))
1328       usedIndices.insert(index);
1329 
1330   return true;
1331 }
1332 
1333 namespace {
1334 
1335 template <class MemcpyLike>
1336 void createMemcpyLikeToReplace(OpBuilder &builder, const DataLayout &layout,
1337                                MemcpyLike toReplace, Value dst, Value src,
1338                                Type toCpy, bool isVolatile) {
1339   Value memcpySize = builder.create<LLVM::ConstantOp>(
1340       toReplace.getLoc(), IntegerAttr::get(toReplace.getLen().getType(),
1341                                            layout.getTypeSize(toCpy)));
1342   builder.create<MemcpyLike>(toReplace.getLoc(), dst, src, memcpySize,
1343                              isVolatile);
1344 }
1345 
1346 template <>
1347 void createMemcpyLikeToReplace(OpBuilder &builder, const DataLayout &layout,
1348                                LLVM::MemcpyInlineOp toReplace, Value dst,
1349                                Value src, Type toCpy, bool isVolatile) {
1350   Type lenType = IntegerType::get(toReplace->getContext(),
1351                                   toReplace.getLen().getBitWidth());
1352   builder.create<LLVM::MemcpyInlineOp>(
1353       toReplace.getLoc(), dst, src,
1354       IntegerAttr::get(lenType, layout.getTypeSize(toCpy)), isVolatile);
1355 }
1356 
1357 } // namespace
1358 
1359 /// Rewires a memcpy-like operation. Only copies to or from the full slot are
1360 /// supported.
1361 template <class MemcpyLike>
1362 static DeletionKind
1363 memcpyRewire(MemcpyLike op, const DestructurableMemorySlot &slot,
1364              DenseMap<Attribute, MemorySlot> &subslots, OpBuilder &builder,
1365              const DataLayout &dataLayout) {
1366   if (subslots.empty())
1367     return DeletionKind::Delete;
1368 
1369   assert((slot.ptr == op.getDst()) != (slot.ptr == op.getSrc()));
1370   bool isDst = slot.ptr == op.getDst();
1371 
1372 #ifndef NDEBUG
1373   size_t slotsTreated = 0;
1374 #endif
1375 
1376   // It was previously checked that index types are consistent, so this type can
1377   // be fetched now.
1378   Type indexType = cast<IntegerAttr>(subslots.begin()->first).getType();
1379   for (size_t i = 0, e = slot.subelementTypes.size(); i != e; i++) {
1380     Attribute index = IntegerAttr::get(indexType, i);
1381     if (!subslots.contains(index))
1382       continue;
1383     const MemorySlot &subslot = subslots.at(index);
1384 
1385 #ifndef NDEBUG
1386     slotsTreated++;
1387 #endif
1388 
1389     // First get a pointer to the equivalent of this subslot from the source
1390     // pointer.
1391     SmallVector<LLVM::GEPArg> gepIndices{
1392         0, static_cast<int32_t>(
1393                cast<IntegerAttr>(index).getValue().getZExtValue())};
1394     Value subslotPtrInOther = builder.create<LLVM::GEPOp>(
1395         op.getLoc(), LLVM::LLVMPointerType::get(op.getContext()), slot.elemType,
1396         isDst ? op.getSrc() : op.getDst(), gepIndices);
1397 
1398     // Then create a new memcpy out of this source pointer.
1399     createMemcpyLikeToReplace(builder, dataLayout, op,
1400                               isDst ? subslot.ptr : subslotPtrInOther,
1401                               isDst ? subslotPtrInOther : subslot.ptr,
1402                               subslot.elemType, op.getIsVolatile());
1403   }
1404 
1405   assert(subslots.size() == slotsTreated);
1406 
1407   return DeletionKind::Delete;
1408 }
1409 
/// Returns whether this memcpy reads from `slot`, as determined by
/// `memcpyLoadsFrom` (the source pointer is the slot pointer).
bool LLVM::MemcpyOp::loadsFrom(const MemorySlot &slot) {
  return memcpyLoadsFrom(*this, slot);
}
1413 
/// Returns whether this memcpy writes to `slot`, as determined by
/// `memcpyStoresTo` (the destination pointer is the slot pointer).
bool LLVM::MemcpyOp::storesTo(const MemorySlot &slot) {
  return memcpyStoresTo(*this, slot);
}
1417 
/// Returns the value stored to `slot` by this memcpy, as built by the shared
/// `memcpyGetStored` helper. `reachingDef` and `dataLayout` are not used.
Value LLVM::MemcpyOp::getStored(const MemorySlot &slot, OpBuilder &builder,
                                Value reachingDef,
                                const DataLayout &dataLayout) {
  return memcpyGetStored(*this, slot, builder);
}
1423 
/// Delegates to the shared `memcpyCanUsesBeRemoved` helper, forwarding all
/// arguments.
bool LLVM::MemcpyOp::canUsesBeRemoved(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return memcpyCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses,
                                dataLayout);
}
1431 
/// Delegates to the shared `memcpyRemoveBlockingUses` helper. `dataLayout` is
/// not forwarded.
DeletionKind LLVM::MemcpyOp::removeBlockingUses(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    OpBuilder &builder, Value reachingDefinition,
    const DataLayout &dataLayout) {
  return memcpyRemoveBlockingUses(*this, slot, blockingUses, builder,
                                  reachingDefinition);
}
1439 
/// Delegates to the shared `memcpyEnsureOnlySafeAccesses` helper. `dataLayout`
/// is not forwarded; the helper obtains a data layout on its own.
LogicalResult LLVM::MemcpyOp::ensureOnlySafeAccesses(
    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
    const DataLayout &dataLayout) {
  return memcpyEnsureOnlySafeAccesses(*this, slot, mustBeSafelyUsed);
}
1445 
/// Delegates to the shared `memcpyCanRewire` helper, forwarding all arguments.
bool LLVM::MemcpyOp::canRewire(const DestructurableMemorySlot &slot,
                               SmallPtrSetImpl<Attribute> &usedIndices,
                               SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
                               const DataLayout &dataLayout) {
  return memcpyCanRewire(*this, slot, usedIndices, mustBeSafelyUsed,
                         dataLayout);
}
1453 
/// Delegates to the shared `memcpyRewire` helper, forwarding all arguments.
DeletionKind LLVM::MemcpyOp::rewire(const DestructurableMemorySlot &slot,
                                    DenseMap<Attribute, MemorySlot> &subslots,
                                    OpBuilder &builder,
                                    const DataLayout &dataLayout) {
  return memcpyRewire(*this, slot, subslots, builder, dataLayout);
}
1460 
/// Returns whether this memcpy reads from `slot`, as determined by
/// `memcpyLoadsFrom` (the source pointer is the slot pointer).
bool LLVM::MemcpyInlineOp::loadsFrom(const MemorySlot &slot) {
  return memcpyLoadsFrom(*this, slot);
}
1464 
/// Returns whether this memcpy writes to `slot`, as determined by
/// `memcpyStoresTo` (the destination pointer is the slot pointer).
bool LLVM::MemcpyInlineOp::storesTo(const MemorySlot &slot) {
  return memcpyStoresTo(*this, slot);
}
1468 
/// Returns the value stored to `slot` by this memcpy, as built by the shared
/// `memcpyGetStored` helper. `reachingDef` and `dataLayout` are not used.
Value LLVM::MemcpyInlineOp::getStored(const MemorySlot &slot,
                                      OpBuilder &builder, Value reachingDef,
                                      const DataLayout &dataLayout) {
  return memcpyGetStored(*this, slot, builder);
}
1474 
/// Delegates to the shared `memcpyCanUsesBeRemoved` helper, forwarding all
/// arguments.
bool LLVM::MemcpyInlineOp::canUsesBeRemoved(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return memcpyCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses,
                                dataLayout);
}
1482 
1483 DeletionKind LLVM::MemcpyInlineOp::removeBlockingUses(
1484     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
1485     OpBuilder &builder, Value reachingDefinition,
1486     const DataLayout &dataLayout) {
1487   return memcpyRemoveBlockingUses(*this, slot, blockingUses, builder,
1488                                   reachingDefinition);
1489 }
1490 
1491 LogicalResult LLVM::MemcpyInlineOp::ensureOnlySafeAccesses(
1492     const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
1493     const DataLayout &dataLayout) {
1494   return memcpyEnsureOnlySafeAccesses(*this, slot, mustBeSafelyUsed);
1495 }
1496 
1497 bool LLVM::MemcpyInlineOp::canRewire(
1498     const DestructurableMemorySlot &slot,
1499     SmallPtrSetImpl<Attribute> &usedIndices,
1500     SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
1501     const DataLayout &dataLayout) {
1502   return memcpyCanRewire(*this, slot, usedIndices, mustBeSafelyUsed,
1503                          dataLayout);
1504 }
1505 
1506 DeletionKind
1507 LLVM::MemcpyInlineOp::rewire(const DestructurableMemorySlot &slot,
1508                              DenseMap<Attribute, MemorySlot> &subslots,
1509                              OpBuilder &builder, const DataLayout &dataLayout) {
1510   return memcpyRewire(*this, slot, subslots, builder, dataLayout);
1511 }
1512 
1513 bool LLVM::MemmoveOp::loadsFrom(const MemorySlot &slot) {
1514   return memcpyLoadsFrom(*this, slot);
1515 }
1516 
1517 bool LLVM::MemmoveOp::storesTo(const MemorySlot &slot) {
1518   return memcpyStoresTo(*this, slot);
1519 }
1520 
1521 Value LLVM::MemmoveOp::getStored(const MemorySlot &slot, OpBuilder &builder,
1522                                  Value reachingDef,
1523                                  const DataLayout &dataLayout) {
1524   return memcpyGetStored(*this, slot, builder);
1525 }
1526 
1527 bool LLVM::MemmoveOp::canUsesBeRemoved(
1528     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
1529     SmallVectorImpl<OpOperand *> &newBlockingUses,
1530     const DataLayout &dataLayout) {
1531   return memcpyCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses,
1532                                 dataLayout);
1533 }
1534 
1535 DeletionKind LLVM::MemmoveOp::removeBlockingUses(
1536     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
1537     OpBuilder &builder, Value reachingDefinition,
1538     const DataLayout &dataLayout) {
1539   return memcpyRemoveBlockingUses(*this, slot, blockingUses, builder,
1540                                   reachingDefinition);
1541 }
1542 
1543 LogicalResult LLVM::MemmoveOp::ensureOnlySafeAccesses(
1544     const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
1545     const DataLayout &dataLayout) {
1546   return memcpyEnsureOnlySafeAccesses(*this, slot, mustBeSafelyUsed);
1547 }
1548 
1549 bool LLVM::MemmoveOp::canRewire(const DestructurableMemorySlot &slot,
1550                                 SmallPtrSetImpl<Attribute> &usedIndices,
1551                                 SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
1552                                 const DataLayout &dataLayout) {
1553   return memcpyCanRewire(*this, slot, usedIndices, mustBeSafelyUsed,
1554                          dataLayout);
1555 }
1556 
1557 DeletionKind LLVM::MemmoveOp::rewire(const DestructurableMemorySlot &slot,
1558                                      DenseMap<Attribute, MemorySlot> &subslots,
1559                                      OpBuilder &builder,
1560                                      const DataLayout &dataLayout) {
1561   return memcpyRewire(*this, slot, subslots, builder, dataLayout);
1562 }
1563 
1564 //===----------------------------------------------------------------------===//
1565 // Interfaces for destructurable types
1566 //===----------------------------------------------------------------------===//
1567 
1568 std::optional<DenseMap<Attribute, Type>>
1569 LLVM::LLVMStructType::getSubelementIndexMap() const {
1570   Type i32 = IntegerType::get(getContext(), 32);
1571   DenseMap<Attribute, Type> destructured;
1572   for (const auto &[index, elemType] : llvm::enumerate(getBody()))
1573     destructured.insert({IntegerAttr::get(i32, index), elemType});
1574   return destructured;
1575 }
1576 
1577 Type LLVM::LLVMStructType::getTypeAtIndex(Attribute index) const {
1578   auto indexAttr = llvm::dyn_cast<IntegerAttr>(index);
1579   if (!indexAttr || !indexAttr.getType().isInteger(32))
1580     return {};
1581   int32_t indexInt = indexAttr.getInt();
1582   ArrayRef<Type> body = getBody();
1583   if (indexInt < 0 || body.size() <= static_cast<uint32_t>(indexInt))
1584     return {};
1585   return body[indexInt];
1586 }
1587 
1588 std::optional<DenseMap<Attribute, Type>>
1589 LLVM::LLVMArrayType::getSubelementIndexMap() const {
1590   constexpr size_t maxArraySizeForDestructuring = 16;
1591   if (getNumElements() > maxArraySizeForDestructuring)
1592     return {};
1593   int32_t numElements = getNumElements();
1594 
1595   Type i32 = IntegerType::get(getContext(), 32);
1596   DenseMap<Attribute, Type> destructured;
1597   for (int32_t index = 0; index < numElements; ++index)
1598     destructured.insert({IntegerAttr::get(i32, index), getElementType()});
1599   return destructured;
1600 }
1601 
1602 Type LLVM::LLVMArrayType::getTypeAtIndex(Attribute index) const {
1603   auto indexAttr = llvm::dyn_cast<IntegerAttr>(index);
1604   if (!indexAttr || !indexAttr.getType().isInteger(32))
1605     return {};
1606   int32_t indexInt = indexAttr.getInt();
1607   if (indexInt < 0 || getNumElements() <= static_cast<uint32_t>(indexInt))
1608     return {};
1609   return getElementType();
1610 }
1611