//===-- AMDGPULowerBufferFatPointers.cpp ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass lowers operations on buffer fat pointers (addrspace 7) to
10 // operations on buffer resources (addrspace 8) and is needed for correct
11 // codegen.
12 //
13 // # Background
14 //
15 // Address space 7 (the buffer fat pointer) is a 160-bit pointer that consists
// of a 128-bit buffer descriptor and a 32-bit offset into the buffer that the
// descriptor describes. The buffer resource part needs to be a "raw" buffer
// resource (it must have a stride of 0 and bounds checks must be in raw buffer
// mode or disabled).
20 //
21 // When these requirements are met, a buffer resource can be treated as a
22 // typical (though quite wide) pointer that follows typical LLVM pointer
23 // semantics. This allows the frontend to reason about such buffers (which are
24 // often encountered in the context of SPIR-V kernels).
25 //
26 // However, because of their non-power-of-2 size, these fat pointers cannot be
27 // present during translation to MIR (though this restriction may be lifted
28 // during the transition to GlobalISel). Therefore, this pass is needed in order
29 // to correctly implement these fat pointers.
30 //
31 // The resource intrinsics take the resource part (the address space 8 pointer)
32 // and the offset part (the 32-bit integer) as separate arguments. In addition,
33 // many users of these buffers manipulate the offset while leaving the resource
34 // part alone. For these reasons, we want to typically separate the resource
35 // and offset parts into separate variables, but combine them together when
36 // encountering cases where this is required, such as by inserting these values
// into aggregates or moving them to memory.
38 //
39 // Therefore, at a high level, `ptr addrspace(7) %x` becomes `ptr addrspace(8)
40 // %x.rsrc` and `i32 %x.off`, which will be combined into `{ptr addrspace(8),
41 // i32} %x = {%x.rsrc, %x.off}` if needed. Similarly, `vector<Nxp7>` becomes
// `{vector<Nxp8>, vector<Nxi32>}` and its component parts.
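//
// As an illustrative sketch (the names here are made up and the exact IR the
// pass emits may differ), a fat-pointer GEP such as
// ```
//   %q = getelementptr i8, ptr addrspace(7) %p, i32 4
// ```
// is, conceptually, lowered to operations on the two parts:
// ```
//   %q.rsrc = %p.rsrc              ; resource part carried through unchanged
//   %q.off = add i32 %p.off, 4     ; only the 32-bit offset is adjusted
// ```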
43 //
44 // # Implementation
45 //
46 // This pass proceeds in three main phases:
47 //
48 // ## Rewriting loads and stores of p7
49 //
// The first phase is to rewrite away all loads and stores of
// `ptr addrspace(7)`, including aggregates containing such pointers, to ones
// that use `i160`. This is handled by `StoreFatPtrsAsIntsVisitor`, which
// visits loads, stores, and allocas and, if the loaded or stored type contains
// `ptr addrspace(7)`, rewrites that type to one where the p7s are replaced by
// i160s, copying other parts of aggregates as needed. In the case of a store,
// each pointer is `ptrtoint`d to i160 before storing, and loaded integers are
// `inttoptr`d back.
57 // This same transformation is applied to vectors of pointers.
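//
// For example (a sketch; not necessarily the verbatim output of this phase),
// ```
//   store ptr addrspace(7) %p, ptr %slot
// ```
// becomes
// ```
//   %p.int = ptrtoint ptr addrspace(7) %p to i160
//   store i160 %p.int, ptr %slot
// ```
// and the corresponding load is followed by an `inttoptr` back to p7.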
58 //
// Such a transformation allows the later phases of the pass to not need
// to handle buffer fat pointers moving to and from memory, where we would
// have to handle the incompatibility between a `{Nxp8, Nxi32}` representation
// and `Nxi160` directly. Instead, that transposing action (where the vectors
// of resources and vectors of offsets are concatenated before being stored to
// memory) is handled solely through the implementations of `inttoptr` and
// `ptrtoint`.
65 //
// Atomic operations on `ptr addrspace(7)` values are not supported, as the
67 // hardware does not include a 160-bit atomic.
68 //
69 // ## Buffer contents type legalization
70 //
71 // The underlying buffer intrinsics only support types up to 128 bits long,
72 // and don't support complex types. If buffer operations were
73 // standard pointer operations that could be represented as MIR-level loads,
74 // this would be handled by the various legalization schemes in instruction
75 // selection. However, because we have to do the conversion from `load` and
76 // `store` to intrinsics at LLVM IR level, we must perform that legalization
77 // ourselves.
78 //
79 // This involves a combination of
80 // - Converting arrays to vectors where possible
81 // - Otherwise, splitting loads and stores of aggregates into loads/stores of
82 //   each component.
83 // - Zero-extending things to fill a whole number of bytes
// - Casting values of types that don't neatly correspond to supported machine
//   values (for example, an i96 or i256) into ones that would work
//   (like <3 x i32> and <8 x i32>, respectively; see the sketch after this
//   list)
88 // - Splitting values that are too long (such as aforementioned <8 x i32>) into
89 //   multiple operations.
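//
// As a sketch of the i96 case mentioned above (illustrative only; the exact
// IR may differ), a load such as
// ```
//   %v = load i96, ptr addrspace(7) %p
// ```
// becomes, roughly,
// ```
//   %v.vec = load <3 x i32>, ptr addrspace(7) %p
//   %v.cast = bitcast <3 x i32> %v.vec to i96
// ```
// while an i256 is first cast to <8 x i32> and then split into two <4 x i32>
// operations.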
90 //
91 // ## Type remapping
92 //
93 // We use a `ValueMapper` to mangle uses of [vectors of] buffer fat pointers
94 // to the corresponding struct type, which has a resource part and an offset
95 // part.
96 //
// This uses a `BufferFatPtrToStructTypeMap` and a `FatPtrConstMaterializer`
// to perform the remapping, usually by way of `setType`ing values. Constants
// are handled here
99 // because there isn't a good way to fix them up later.
100 //
101 // This has the downside of leaving the IR in an invalid state (for example,
102 // the instruction `getelementptr {ptr addrspace(8), i32} %p, ...` will exist),
103 // but all such invalid states will be resolved by the third phase.
104 //
105 // Functions that don't take buffer fat pointers are modified in place. Those
106 // that do take such pointers have their basic blocks moved to a new function
// whose arguments and return values use {ptr addrspace(8), i32} instead.
108 // This phase also records intrinsics so that they can be remangled or deleted
109 // later.
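//
// As an illustrative sketch of the signature rewrite (the function name is
// made up):
// ```
//   define float @foo(ptr addrspace(7) %p)
// ```
// becomes
// ```
//   define float @foo({ptr addrspace(8), i32} %p)
// ```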
110 //
111 // ## Splitting pointer structs
112 //
113 // The meat of this pass consists of defining semantics for operations that
114 // produce or consume [vectors of] buffer fat pointers in terms of their
// resource and offset parts. This is accomplished through the `SplitPtrStructs`
116 // visitor.
117 //
118 // In the first pass through each function that is being lowered, the splitter
119 // inserts new instructions to implement the split-structures behavior, which is
120 // needed for correctness and performance. It records a list of "split users",
121 // instructions that are being replaced by operations on the resource and offset
122 // parts.
123 //
// Split users do not necessarily need to produce parts themselves
// (a `load float, ptr addrspace(7)` does not, for example), but, if they do
// not generate buffer fat pointers, they must RAUW in their replacement
// instructions during the initial visit.
128 //
129 // When these new instructions are created, they use the split parts recorded
130 // for their initial arguments in order to generate their replacements, creating
131 // a parallel set of instructions that does not refer to the original fat
132 // pointer values but instead to their resource and offset components.
133 //
134 // Instructions, such as `extractvalue`, that produce buffer fat pointers from
135 // sources that do not have split parts, have such parts generated using
136 // `extractvalue`. This is also the initial handling of PHI nodes, which
137 // are then cleaned up.
138 //
139 // ### Conditionals
140 //
141 // PHI nodes are initially given resource parts via `extractvalue`. However,
142 // this is not an efficient rewrite of such nodes, as, in most cases, the
143 // resource part in a conditional or loop remains constant throughout the loop
144 // and only the offset varies. Failing to optimize away these constant resources
145 // would cause additional registers to be sent around loops and might lead to
146 // waterfall loops being generated for buffer operations due to the
147 // "non-uniform" resource argument.
148 //
149 // Therefore, after all instructions have been visited, the pointer splitter
150 // post-processes all encountered conditionals. Given a PHI node or select,
// getPossibleRsrcRoots() collects all values that the resource parts of that
// conditional's inputs could come from, as well as all conditional
153 // instructions encountered during the search. If, after filtering out the
154 // initial node itself, the set of encountered conditionals is a subset of the
155 // potential roots and there is a single potential resource that isn't in the
156 // conditional set, that value is the only possible value the resource argument
157 // could have throughout the control flow.
158 //
159 // If that condition is met, then a PHI node can have its resource part changed
160 // to the singleton value and then be replaced by a PHI on the offsets.
161 // Otherwise, each PHI node is split into two, one for the resource part and one
162 // for the offset part, which replace the temporary `extractvalue` instructions
163 // that were added during the first pass.
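//
// That is, in the general case where the resource parts may differ, a sketch
// of the rewrite (illustrative names) is:
// ```
//   %p = phi ptr addrspace(7) [ %a, %lhs ], [ %b, %rhs ]
// ```
// becomes
// ```
//   %p.rsrc = phi ptr addrspace(8) [ %a.rsrc, %lhs ], [ %b.rsrc, %rhs ]
//   %p.off = phi i32 [ %a.off, %lhs ], [ %b.off, %rhs ]
// ```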
164 //
165 // Similar logic applies to `select`, where
// `%z = select i1 %cond, ptr addrspace(7) %x, ptr addrspace(7) %y`
// can be split into `%z.rsrc = %x.rsrc` and
// `%z.off = select i1 %cond, i32 %x.off, i32 %y.off`
169 // if both `%x` and `%y` have the same resource part, but two `select`
170 // operations will be needed if they do not.
171 //
172 // ### Final processing
173 //
174 // After conditionals have been cleaned up, the IR for each function is
175 // rewritten to remove all the old instructions that have been split up.
176 //
177 // Any instruction that used to produce a buffer fat pointer (and therefore now
178 // produces a resource-and-offset struct after type remapping) is
179 // replaced as follows:
180 // 1. All debug value annotations are cloned to reflect that the resource part
181 //    and offset parts are computed separately and constitute different
182 //    fragments of the underlying source language variable.
183 // 2. All uses that were themselves split are replaced by a `poison` of the
184 //    struct type, as they will themselves be erased soon. This rule, combined
185 //    with debug handling, should leave the use lists of split instructions
186 //    empty in almost all cases.
187 // 3. If a user of the original struct-valued result remains, the structure
188 //    needed for the new types to work is constructed out of the newly-defined
189 //    parts, and the original instruction is replaced by this structure
190 //    before being erased. Instructions requiring this construction include
191 //    `ret` and `insertvalue`.
192 //
193 // # Consequences
194 //
195 // This pass does not alter the CFG.
196 //
197 // Alias analysis information will become coarser, as the LLVM alias analyzer
198 // cannot handle the buffer intrinsics. Specifically, while we can determine
199 // that the following two loads do not alias:
200 // ```
201 //   %y = getelementptr i32, ptr addrspace(7) %x, i32 1
202 //   %a = load i32, ptr addrspace(7) %x
203 //   %b = load i32, ptr addrspace(7) %y
204 // ```
205 // we cannot (except through some code that runs during scheduling) determine
206 // that the rewritten loads below do not alias.
207 // ```
208 //   %y.off = add i32 %x.off, 1
209 //   %a = call @llvm.amdgcn.raw.ptr.buffer.load(ptr addrspace(8) %x.rsrc, i32
210 //     %x.off, ...)
211 //   %b = call @llvm.amdgcn.raw.ptr.buffer.load(ptr addrspace(8)
212 //     %x.rsrc, i32 %y.off, ...)
213 // ```
214 // However, existing alias information is preserved.
215 //===----------------------------------------------------------------------===//
216 
217 #include "AMDGPU.h"
218 #include "AMDGPUTargetMachine.h"
219 #include "GCNSubtarget.h"
220 #include "SIDefines.h"
221 #include "llvm/ADT/SetOperations.h"
222 #include "llvm/ADT/SmallVector.h"
223 #include "llvm/Analysis/ConstantFolding.h"
224 #include "llvm/Analysis/Utils/Local.h"
225 #include "llvm/CodeGen/TargetPassConfig.h"
226 #include "llvm/IR/AttributeMask.h"
227 #include "llvm/IR/Constants.h"
228 #include "llvm/IR/DebugInfo.h"
229 #include "llvm/IR/DerivedTypes.h"
230 #include "llvm/IR/IRBuilder.h"
231 #include "llvm/IR/InstIterator.h"
232 #include "llvm/IR/InstVisitor.h"
233 #include "llvm/IR/Instructions.h"
234 #include "llvm/IR/Intrinsics.h"
235 #include "llvm/IR/IntrinsicsAMDGPU.h"
236 #include "llvm/IR/Metadata.h"
237 #include "llvm/IR/Operator.h"
238 #include "llvm/IR/PatternMatch.h"
239 #include "llvm/IR/ReplaceConstant.h"
240 #include "llvm/InitializePasses.h"
241 #include "llvm/Pass.h"
242 #include "llvm/Support/Alignment.h"
243 #include "llvm/Support/AtomicOrdering.h"
244 #include "llvm/Support/Debug.h"
245 #include "llvm/Support/ErrorHandling.h"
246 #include "llvm/Transforms/Utils/Cloning.h"
247 #include "llvm/Transforms/Utils/Local.h"
248 #include "llvm/Transforms/Utils/ValueMapper.h"
249 
250 #define DEBUG_TYPE "amdgpu-lower-buffer-fat-pointers"
251 
252 using namespace llvm;
253 
254 static constexpr unsigned BufferOffsetWidth = 32;
255 
256 namespace {
257 /// Recursively replace instances of ptr addrspace(7) and vector<Nxptr
258 /// addrspace(7)> with some other type as defined by the relevant subclass.
259 class BufferFatPtrTypeLoweringBase : public ValueMapTypeRemapper {
260   DenseMap<Type *, Type *> Map;
261 
262   Type *remapTypeImpl(Type *Ty, SmallPtrSetImpl<StructType *> &Seen);
263 
264 protected:
265   virtual Type *remapScalar(PointerType *PT) = 0;
266   virtual Type *remapVector(VectorType *VT) = 0;
267 
268   const DataLayout &DL;
269 
270 public:
271   BufferFatPtrTypeLoweringBase(const DataLayout &DL) : DL(DL) {}
272   Type *remapType(Type *SrcTy) override;
273   void clear() { Map.clear(); }
274 };
275 
276 /// Remap ptr addrspace(7) to i160 and vector<Nxptr addrspace(7)> to
/// vector<Nxi160> in order to correctly handle loading/storing these values
278 /// from memory.
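/// For example, <4 x ptr addrspace(7)> is remapped to <4 x i160>.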
279 class BufferFatPtrToIntTypeMap : public BufferFatPtrTypeLoweringBase {
280   using BufferFatPtrTypeLoweringBase::BufferFatPtrTypeLoweringBase;
281 
282 protected:
283   Type *remapScalar(PointerType *PT) override { return DL.getIntPtrType(PT); }
284   Type *remapVector(VectorType *VT) override { return DL.getIntPtrType(VT); }
285 };
286 
287 /// Remap ptr addrspace(7) to {ptr addrspace(8), i32} (the resource and offset
288 /// parts of the pointer) so that we can easily rewrite operations on these
289 /// values that aren't loading them from or storing them to memory.
290 class BufferFatPtrToStructTypeMap : public BufferFatPtrTypeLoweringBase {
291   using BufferFatPtrTypeLoweringBase::BufferFatPtrTypeLoweringBase;
292 
293 protected:
294   Type *remapScalar(PointerType *PT) override;
295   Type *remapVector(VectorType *VT) override;
296 };
297 } // namespace
298 
299 // This code is adapted from the type remapper in lib/Linker/IRMover.cpp
300 Type *BufferFatPtrTypeLoweringBase::remapTypeImpl(
301     Type *Ty, SmallPtrSetImpl<StructType *> &Seen) {
302   Type **Entry = &Map[Ty];
303   if (*Entry)
304     return *Entry;
305   if (auto *PT = dyn_cast<PointerType>(Ty)) {
306     if (PT->getAddressSpace() == AMDGPUAS::BUFFER_FAT_POINTER) {
307       return *Entry = remapScalar(PT);
308     }
309   }
310   if (auto *VT = dyn_cast<VectorType>(Ty)) {
311     auto *PT = dyn_cast<PointerType>(VT->getElementType());
312     if (PT && PT->getAddressSpace() == AMDGPUAS::BUFFER_FAT_POINTER) {
313       return *Entry = remapVector(VT);
314     }
315     return *Entry = Ty;
316   }
  // Whether the type is one that is structurally uniqued - that is, if it is
  // not a named struct (the only kind of type where multiple structurally
  // identical types can have distinct `Type*`s).
320   StructType *TyAsStruct = dyn_cast<StructType>(Ty);
321   bool IsUniqued = !TyAsStruct || TyAsStruct->isLiteral();
322   // Base case for ints, floats, opaque pointers, and so on, which don't
323   // require recursion.
324   if (Ty->getNumContainedTypes() == 0 && IsUniqued)
325     return *Entry = Ty;
326   if (!IsUniqued) {
327     // Create a dummy type for recursion purposes.
328     if (!Seen.insert(TyAsStruct).second) {
329       StructType *Placeholder = StructType::create(Ty->getContext());
330       return *Entry = Placeholder;
331     }
332   }
333   bool Changed = false;
334   SmallVector<Type *> ElementTypes(Ty->getNumContainedTypes(), nullptr);
335   for (unsigned int I = 0, E = Ty->getNumContainedTypes(); I < E; ++I) {
336     Type *OldElem = Ty->getContainedType(I);
337     Type *NewElem = remapTypeImpl(OldElem, Seen);
338     ElementTypes[I] = NewElem;
339     Changed |= (OldElem != NewElem);
340   }
  // Recursive calls to remapTypeImpl() may have invalidated the `Entry` pointer.
342   Entry = &Map[Ty];
343   if (!Changed) {
344     return *Entry = Ty;
345   }
346   if (auto *ArrTy = dyn_cast<ArrayType>(Ty))
347     return *Entry = ArrayType::get(ElementTypes[0], ArrTy->getNumElements());
348   if (auto *FnTy = dyn_cast<FunctionType>(Ty))
349     return *Entry = FunctionType::get(ElementTypes[0],
350                                       ArrayRef(ElementTypes).slice(1),
351                                       FnTy->isVarArg());
352   if (auto *STy = dyn_cast<StructType>(Ty)) {
353     // Genuine opaque types don't have a remapping.
354     if (STy->isOpaque())
355       return *Entry = Ty;
356     bool IsPacked = STy->isPacked();
357     if (IsUniqued)
358       return *Entry = StructType::get(Ty->getContext(), ElementTypes, IsPacked);
359     SmallString<16> Name(STy->getName());
360     STy->setName("");
361     Type **RecursionEntry = &Map[Ty];
362     if (*RecursionEntry) {
363       auto *Placeholder = cast<StructType>(*RecursionEntry);
364       Placeholder->setBody(ElementTypes, IsPacked);
365       Placeholder->setName(Name);
366       return *Entry = Placeholder;
367     }
368     return *Entry = StructType::create(Ty->getContext(), ElementTypes, Name,
369                                        IsPacked);
370   }
371   llvm_unreachable("Unknown type of type that contains elements");
372 }
373 
374 Type *BufferFatPtrTypeLoweringBase::remapType(Type *SrcTy) {
375   SmallPtrSet<StructType *, 2> Visited;
376   return remapTypeImpl(SrcTy, Visited);
377 }
378 
379 Type *BufferFatPtrToStructTypeMap::remapScalar(PointerType *PT) {
380   LLVMContext &Ctx = PT->getContext();
381   return StructType::get(PointerType::get(Ctx, AMDGPUAS::BUFFER_RESOURCE),
382                          IntegerType::get(Ctx, BufferOffsetWidth));
383 }
384 
385 Type *BufferFatPtrToStructTypeMap::remapVector(VectorType *VT) {
386   ElementCount EC = VT->getElementCount();
387   LLVMContext &Ctx = VT->getContext();
388   Type *RsrcVec =
389       VectorType::get(PointerType::get(Ctx, AMDGPUAS::BUFFER_RESOURCE), EC);
390   Type *OffVec = VectorType::get(IntegerType::get(Ctx, BufferOffsetWidth), EC);
391   return StructType::get(RsrcVec, OffVec);
392 }
393 
394 static bool isBufferFatPtrOrVector(Type *Ty) {
395   if (auto *PT = dyn_cast<PointerType>(Ty->getScalarType()))
396     return PT->getAddressSpace() == AMDGPUAS::BUFFER_FAT_POINTER;
397   return false;
398 }
399 
400 // True if the type is {ptr addrspace(8), i32} or a struct containing vectors of
401 // those types. Used to quickly skip instructions we don't need to process.
402 static bool isSplitFatPtr(Type *Ty) {
403   auto *ST = dyn_cast<StructType>(Ty);
404   if (!ST)
405     return false;
406   if (!ST->isLiteral() || ST->getNumElements() != 2)
407     return false;
408   auto *MaybeRsrc =
409       dyn_cast<PointerType>(ST->getElementType(0)->getScalarType());
410   auto *MaybeOff =
411       dyn_cast<IntegerType>(ST->getElementType(1)->getScalarType());
412   return MaybeRsrc && MaybeOff &&
413          MaybeRsrc->getAddressSpace() == AMDGPUAS::BUFFER_RESOURCE &&
414          MaybeOff->getBitWidth() == BufferOffsetWidth;
415 }
416 
// True if the constant's type, or the type of any of its operands, is a
// buffer fat pointer [vector].
418 static bool isBufferFatPtrConst(Constant *C) {
419   Type *T = C->getType();
420   return isBufferFatPtrOrVector(T) || any_of(C->operands(), [](const Use &U) {
421            return isBufferFatPtrOrVector(U.get()->getType());
422          });
423 }
424 
425 namespace {
426 /// Convert [vectors of] buffer fat pointers to integers when they are read from
427 /// or stored to memory. This ensures that these pointers will have the same
428 /// memory layout as before they are lowered, even though they will no longer
429 /// have their previous layout in registers/in the program (they'll be broken
430 /// down into resource and offset parts). This has the downside of imposing
431 /// marshalling costs when reading or storing these values, but since placing
432 /// such pointers into memory is an uncommon operation at best, we feel that
433 /// this cost is acceptable for better performance in the common case.
434 class StoreFatPtrsAsIntsVisitor
435     : public InstVisitor<StoreFatPtrsAsIntsVisitor, bool> {
436   BufferFatPtrToIntTypeMap *TypeMap;
437 
438   ValueToValueMapTy ConvertedForStore;
439 
440   IRBuilder<> IRB;
441 
  // Convert all the buffer fat pointers within the input value to integers
443   // so that it can be stored in memory.
444   Value *fatPtrsToInts(Value *V, Type *From, Type *To, const Twine &Name);
  // Convert all the i160s that need to be buffer fat pointers (as specified
  // by the To type) into those pointers to preserve the semantics of the rest
447   // of the program.
448   Value *intsToFatPtrs(Value *V, Type *From, Type *To, const Twine &Name);
449 
450 public:
451   StoreFatPtrsAsIntsVisitor(BufferFatPtrToIntTypeMap *TypeMap, LLVMContext &Ctx)
452       : TypeMap(TypeMap), IRB(Ctx) {}
453   bool processFunction(Function &F);
454 
455   bool visitInstruction(Instruction &I) { return false; }
456   bool visitAllocaInst(AllocaInst &I);
457   bool visitLoadInst(LoadInst &LI);
458   bool visitStoreInst(StoreInst &SI);
459   bool visitGetElementPtrInst(GetElementPtrInst &I);
460 };
461 } // namespace
462 
463 Value *StoreFatPtrsAsIntsVisitor::fatPtrsToInts(Value *V, Type *From, Type *To,
464                                                 const Twine &Name) {
465   if (From == To)
466     return V;
467   ValueToValueMapTy::iterator Find = ConvertedForStore.find(V);
468   if (Find != ConvertedForStore.end())
469     return Find->second;
470   if (isBufferFatPtrOrVector(From)) {
471     Value *Cast = IRB.CreatePtrToInt(V, To, Name + ".int");
472     ConvertedForStore[V] = Cast;
473     return Cast;
474   }
475   if (From->getNumContainedTypes() == 0)
476     return V;
477   // Structs, arrays, and other compound types.
478   Value *Ret = PoisonValue::get(To);
479   if (auto *AT = dyn_cast<ArrayType>(From)) {
480     Type *FromPart = AT->getArrayElementType();
481     Type *ToPart = cast<ArrayType>(To)->getElementType();
482     for (uint64_t I = 0, E = AT->getArrayNumElements(); I < E; ++I) {
483       Value *Field = IRB.CreateExtractValue(V, I);
484       Value *NewField =
485           fatPtrsToInts(Field, FromPart, ToPart, Name + "." + Twine(I));
486       Ret = IRB.CreateInsertValue(Ret, NewField, I);
487     }
488   } else {
489     for (auto [Idx, FromPart, ToPart] :
490          enumerate(From->subtypes(), To->subtypes())) {
491       Value *Field = IRB.CreateExtractValue(V, Idx);
492       Value *NewField =
493           fatPtrsToInts(Field, FromPart, ToPart, Name + "." + Twine(Idx));
494       Ret = IRB.CreateInsertValue(Ret, NewField, Idx);
495     }
496   }
497   ConvertedForStore[V] = Ret;
498   return Ret;
499 }
500 
501 Value *StoreFatPtrsAsIntsVisitor::intsToFatPtrs(Value *V, Type *From, Type *To,
502                                                 const Twine &Name) {
503   if (From == To)
504     return V;
505   if (isBufferFatPtrOrVector(To)) {
506     Value *Cast = IRB.CreateIntToPtr(V, To, Name + ".ptr");
507     return Cast;
508   }
509   if (From->getNumContainedTypes() == 0)
510     return V;
511   // Structs, arrays, and other compound types.
512   Value *Ret = PoisonValue::get(To);
513   if (auto *AT = dyn_cast<ArrayType>(From)) {
514     Type *FromPart = AT->getArrayElementType();
515     Type *ToPart = cast<ArrayType>(To)->getElementType();
516     for (uint64_t I = 0, E = AT->getArrayNumElements(); I < E; ++I) {
517       Value *Field = IRB.CreateExtractValue(V, I);
518       Value *NewField =
519           intsToFatPtrs(Field, FromPart, ToPart, Name + "." + Twine(I));
520       Ret = IRB.CreateInsertValue(Ret, NewField, I);
521     }
522   } else {
523     for (auto [Idx, FromPart, ToPart] :
524          enumerate(From->subtypes(), To->subtypes())) {
525       Value *Field = IRB.CreateExtractValue(V, Idx);
526       Value *NewField =
527           intsToFatPtrs(Field, FromPart, ToPart, Name + "." + Twine(Idx));
528       Ret = IRB.CreateInsertValue(Ret, NewField, Idx);
529     }
530   }
531   return Ret;
532 }
533 
534 bool StoreFatPtrsAsIntsVisitor::processFunction(Function &F) {
535   bool Changed = false;
  // The visitors mutate GEPs and allocas in place but may erase and replace
  // loads and stores, so iterate with make_early_inc_range to avoid iterator
  // invalidation.
538   for (Instruction &I : make_early_inc_range(instructions(F))) {
539     Changed |= visit(I);
540   }
541   ConvertedForStore.clear();
542   return Changed;
543 }
544 
545 bool StoreFatPtrsAsIntsVisitor::visitAllocaInst(AllocaInst &I) {
546   Type *Ty = I.getAllocatedType();
547   Type *NewTy = TypeMap->remapType(Ty);
548   if (Ty == NewTy)
549     return false;
550   I.setAllocatedType(NewTy);
551   return true;
552 }
553 
554 bool StoreFatPtrsAsIntsVisitor::visitGetElementPtrInst(GetElementPtrInst &I) {
555   Type *Ty = I.getSourceElementType();
556   Type *NewTy = TypeMap->remapType(Ty);
557   if (Ty == NewTy)
558     return false;
559   // We'll be rewriting the type `ptr addrspace(7)` out of existence soon, so
560   // make sure GEPs don't have different semantics with the new type.
561   I.setSourceElementType(NewTy);
562   I.setResultElementType(TypeMap->remapType(I.getResultElementType()));
563   return true;
564 }
565 
566 bool StoreFatPtrsAsIntsVisitor::visitLoadInst(LoadInst &LI) {
567   Type *Ty = LI.getType();
568   Type *IntTy = TypeMap->remapType(Ty);
569   if (Ty == IntTy)
570     return false;
571 
572   IRB.SetInsertPoint(&LI);
573   auto *NLI = cast<LoadInst>(LI.clone());
574   NLI->mutateType(IntTy);
575   NLI = IRB.Insert(NLI);
576   NLI->takeName(&LI);
577 
578   Value *CastBack = intsToFatPtrs(NLI, IntTy, Ty, NLI->getName());
579   LI.replaceAllUsesWith(CastBack);
580   LI.eraseFromParent();
581   return true;
582 }
583 
584 bool StoreFatPtrsAsIntsVisitor::visitStoreInst(StoreInst &SI) {
585   Value *V = SI.getValueOperand();
586   Type *Ty = V->getType();
587   Type *IntTy = TypeMap->remapType(Ty);
588   if (Ty == IntTy)
589     return false;
590 
591   IRB.SetInsertPoint(&SI);
592   Value *IntV = fatPtrsToInts(V, Ty, IntTy, V->getName());
593   for (auto *Dbg : at::getAssignmentMarkers(&SI))
594     Dbg->setValue(IntV);
595 
596   SI.setOperand(0, IntV);
597   return true;
598 }
599 
600 namespace {
601 /// Convert loads/stores of types that the buffer intrinsics can't handle into
/// one or more such loads/stores that consist of legal types.
603 ///
604 /// Do this by
605 /// 1. Recursing into structs (and arrays that don't share a memory layout with
606 /// vectors) since the intrinsics can't handle complex types.
607 /// 2. Converting arrays of non-aggregate, byte-sized types into their
608 /// corresponding vectors
609 /// 3. Bitcasting unsupported types, namely overly-long scalars and byte
610 /// vectors, into vectors of supported types.
611 /// 4. Splitting up excessively long reads/writes into multiple operations.
612 ///
/// Note that this doesn't handle complex data structures, but, in the future,
614 /// the aggregate load splitter from SROA could be refactored to allow for that
615 /// case.
616 class LegalizeBufferContentTypesVisitor
617     : public InstVisitor<LegalizeBufferContentTypesVisitor, bool> {
618   friend class InstVisitor<LegalizeBufferContentTypesVisitor, bool>;
619 
620   IRBuilder<> IRB;
621 
622   const DataLayout &DL;
623 
624   /// If T is [N x U], where U is a scalar type, return the vector type
625   /// <N x U>, otherwise, return T.
626   Type *scalarArrayTypeAsVector(Type *MaybeArrayType);
627   Value *arrayToVector(Value *V, Type *TargetType, const Twine &Name);
628   Value *vectorToArray(Value *V, Type *OrigType, const Twine &Name);
629 
630   /// Break up the loads of a struct into the loads of its components
631 
632   /// Convert a vector or scalar type that can't be operated on by buffer
633   /// intrinsics to one that would be legal through bitcasts and/or truncation.
634   /// Uses the wider of i32, i16, or i8 where possible.
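  /// For example, an i96 becomes <3 x i32> and a <6 x i8> becomes <3 x i16>.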
635   Type *legalNonAggregateFor(Type *T);
636   Value *makeLegalNonAggregate(Value *V, Type *TargetType, const Twine &Name);
637   Value *makeIllegalNonAggregate(Value *V, Type *OrigType, const Twine &Name);
638 
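  /// A contiguous run of elements within a legalized vector value: `Index` is
  /// the first element covered and `Length` is the number of elements.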
639   struct VecSlice {
640     uint64_t Index = 0;
641     uint64_t Length = 0;
642     VecSlice() = delete;
643   };
644   /// Return the [index, length] pairs into which `T` needs to be cut to form
645   /// legal buffer load or store operations. Clears `Slices`. Creates an empty
646   /// `Slices` for non-vector inputs and creates one slice if no slicing will be
647   /// needed.
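  /// For example (following the slicing rules in the implementation),
  /// <8 x i32> is cut into two 4-element slices, while <6 x half> fits into a
  /// single 3-word slice.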
648   void getVecSlices(Type *T, SmallVectorImpl<VecSlice> &Slices);
649 
650   Value *extractSlice(Value *Vec, VecSlice S, const Twine &Name);
651   Value *insertSlice(Value *Whole, Value *Part, VecSlice S, const Twine &Name);
652 
653   /// In most cases, return `LegalType`. However, when given an input that would
654   /// normally be a legal type for the buffer intrinsics to return but that
655   /// isn't hooked up through SelectionDAG, return a type of the same width that
656   /// can be used with the relevant intrinsics. Specifically, handle the cases:
657   /// - <1 x T> => T for all T
658   /// - <N x i8> <=> i16, i32, 2xi32, 4xi32 (as needed)
659   /// - <N x T> where T is under 32 bits and the total size is 96 bits <=> <3 x
660   /// i32>
661   Type *intrinsicTypeFor(Type *LegalType);
662 
663   bool visitLoadImpl(LoadInst &OrigLI, Type *PartType,
664                      SmallVectorImpl<uint32_t> &AggIdxs, uint64_t AggByteOffset,
665                      Value *&Result, const Twine &Name);
666   /// Return value is (Changed, ModifiedInPlace)
667   std::pair<bool, bool> visitStoreImpl(StoreInst &OrigSI, Type *PartType,
668                                        SmallVectorImpl<uint32_t> &AggIdxs,
669                                        uint64_t AggByteOffset,
670                                        const Twine &Name);
671 
672   bool visitInstruction(Instruction &I) { return false; }
673   bool visitLoadInst(LoadInst &LI);
674   bool visitStoreInst(StoreInst &SI);
675 
676 public:
677   LegalizeBufferContentTypesVisitor(const DataLayout &DL, LLVMContext &Ctx)
678       : IRB(Ctx), DL(DL) {}
679   bool processFunction(Function &F);
680 };
681 } // namespace
682 
683 Type *LegalizeBufferContentTypesVisitor::scalarArrayTypeAsVector(Type *T) {
684   ArrayType *AT = dyn_cast<ArrayType>(T);
685   if (!AT)
686     return T;
687   Type *ET = AT->getElementType();
688   if (!ET->isSingleValueType() || isa<VectorType>(ET))
689     report_fatal_error("loading non-scalar arrays from buffer fat pointers "
690                        "should have recursed");
691   if (!DL.typeSizeEqualsStoreSize(AT))
692     report_fatal_error(
        "loading padded arrays from buffer fat pointers should have recursed");
694   return FixedVectorType::get(ET, AT->getNumElements());
695 }
696 
697 Value *LegalizeBufferContentTypesVisitor::arrayToVector(Value *V,
698                                                         Type *TargetType,
699                                                         const Twine &Name) {
700   Value *VectorRes = PoisonValue::get(TargetType);
701   auto *VT = cast<FixedVectorType>(TargetType);
702   unsigned EC = VT->getNumElements();
703   for (auto I : iota_range<unsigned>(0, EC, /*Inclusive=*/false)) {
704     Value *Elem = IRB.CreateExtractValue(V, I, Name + ".elem." + Twine(I));
705     VectorRes = IRB.CreateInsertElement(VectorRes, Elem, I,
706                                         Name + ".as.vec." + Twine(I));
707   }
708   return VectorRes;
709 }
710 
711 Value *LegalizeBufferContentTypesVisitor::vectorToArray(Value *V,
712                                                         Type *OrigType,
713                                                         const Twine &Name) {
714   Value *ArrayRes = PoisonValue::get(OrigType);
715   ArrayType *AT = cast<ArrayType>(OrigType);
716   unsigned EC = AT->getNumElements();
717   for (auto I : iota_range<unsigned>(0, EC, /*Inclusive=*/false)) {
718     Value *Elem = IRB.CreateExtractElement(V, I, Name + ".elem." + Twine(I));
719     ArrayRes = IRB.CreateInsertValue(ArrayRes, Elem, I,
720                                      Name + ".as.array." + Twine(I));
721   }
722   return ArrayRes;
723 }
724 
725 Type *LegalizeBufferContentTypesVisitor::legalNonAggregateFor(Type *T) {
726   TypeSize Size = DL.getTypeStoreSizeInBits(T);
727   // Implicitly zero-extend to the next byte if needed
728   if (!DL.typeSizeEqualsStoreSize(T))
729     T = IRB.getIntNTy(Size.getFixedValue());
730   Type *ElemTy = T->getScalarType();
731   if (isa<PointerType, ScalableVectorType>(ElemTy)) {
732     // Pointers are always big enough, and we'll let scalable vectors through to
733     // fail in codegen.
734     return T;
735   }
736   unsigned ElemSize = DL.getTypeSizeInBits(ElemTy).getFixedValue();
737   if (isPowerOf2_32(ElemSize) && ElemSize >= 16 && ElemSize <= 128) {
738     // [vectors of] anything that's 16/32/64/128 bits can be cast and split into
739     // legal buffer operations.
740     return T;
741   }
742   Type *BestVectorElemType = nullptr;
743   if (Size.isKnownMultipleOf(32))
744     BestVectorElemType = IRB.getInt32Ty();
745   else if (Size.isKnownMultipleOf(16))
746     BestVectorElemType = IRB.getInt16Ty();
747   else
748     BestVectorElemType = IRB.getInt8Ty();
749   unsigned NumCastElems =
750       Size.getFixedValue() / BestVectorElemType->getIntegerBitWidth();
751   if (NumCastElems == 1)
752     return BestVectorElemType;
753   return FixedVectorType::get(BestVectorElemType, NumCastElems);
754 }
755 
756 Value *LegalizeBufferContentTypesVisitor::makeLegalNonAggregate(
757     Value *V, Type *TargetType, const Twine &Name) {
758   Type *SourceType = V->getType();
759   TypeSize SourceSize = DL.getTypeSizeInBits(SourceType);
760   TypeSize TargetSize = DL.getTypeSizeInBits(TargetType);
761   if (SourceSize != TargetSize) {
762     Type *ShortScalarTy = IRB.getIntNTy(SourceSize.getFixedValue());
763     Type *ByteScalarTy = IRB.getIntNTy(TargetSize.getFixedValue());
764     Value *AsScalar = IRB.CreateBitCast(V, ShortScalarTy, Name + ".as.scalar");
765     Value *Zext = IRB.CreateZExt(AsScalar, ByteScalarTy, Name + ".zext");
766     V = Zext;
767     SourceType = ByteScalarTy;
768   }
769   return IRB.CreateBitCast(V, TargetType, Name + ".legal");
770 }
771 
772 Value *LegalizeBufferContentTypesVisitor::makeIllegalNonAggregate(
773     Value *V, Type *OrigType, const Twine &Name) {
774   Type *LegalType = V->getType();
775   TypeSize LegalSize = DL.getTypeSizeInBits(LegalType);
776   TypeSize OrigSize = DL.getTypeSizeInBits(OrigType);
777   if (LegalSize != OrigSize) {
778     Type *ShortScalarTy = IRB.getIntNTy(OrigSize.getFixedValue());
779     Type *ByteScalarTy = IRB.getIntNTy(LegalSize.getFixedValue());
780     Value *AsScalar = IRB.CreateBitCast(V, ByteScalarTy, Name + ".bytes.cast");
781     Value *Trunc = IRB.CreateTrunc(AsScalar, ShortScalarTy, Name + ".trunc");
782     return IRB.CreateBitCast(Trunc, OrigType, Name + ".orig");
783   }
784   return IRB.CreateBitCast(V, OrigType, Name + ".real.ty");
785 }
786 
787 Type *LegalizeBufferContentTypesVisitor::intrinsicTypeFor(Type *LegalType) {
788   auto *VT = dyn_cast<FixedVectorType>(LegalType);
789   if (!VT)
790     return LegalType;
791   Type *ET = VT->getElementType();
792   // Explicitly return the element type of 1-element vectors because the
793   // underlying intrinsics don't like <1 x T> even though it's a synonym for T.
794   if (VT->getNumElements() == 1)
795     return ET;
796   if (DL.getTypeSizeInBits(LegalType) == 96 && DL.getTypeSizeInBits(ET) < 32)
797     return FixedVectorType::get(IRB.getInt32Ty(), 3);
798   if (ET->isIntegerTy(8)) {
799     switch (VT->getNumElements()) {
800     default:
801       return LegalType; // Let it crash later
802     case 1:
803       return IRB.getInt8Ty();
804     case 2:
805       return IRB.getInt16Ty();
806     case 4:
807       return IRB.getInt32Ty();
808     case 8:
809       return FixedVectorType::get(IRB.getInt32Ty(), 2);
810     case 16:
811       return FixedVectorType::get(IRB.getInt32Ty(), 4);
812     }
813   }
814   return LegalType;
815 }
816 
817 void LegalizeBufferContentTypesVisitor::getVecSlices(
818     Type *T, SmallVectorImpl<VecSlice> &Slices) {
819   Slices.clear();
820   auto *VT = dyn_cast<FixedVectorType>(T);
821   if (!VT)
822     return;
823 
824   uint64_t ElemBitWidth =
825       DL.getTypeSizeInBits(VT->getElementType()).getFixedValue();
826 
827   uint64_t ElemsPer4Words = 128 / ElemBitWidth;
828   uint64_t ElemsPer2Words = ElemsPer4Words / 2;
829   uint64_t ElemsPerWord = ElemsPer2Words / 2;
830   uint64_t ElemsPerShort = ElemsPerWord / 2;
831   uint64_t ElemsPerByte = ElemsPerShort / 2;
832   // If the elements evenly pack into 32-bit words, we can use 3-word stores,
  // such as for <6 x bfloat> or <3 x i32>, but we can't do this for, say,
  // <3 x i64>, since a 3-word slice would not cover a whole number of elements.
835   uint64_t ElemsPer3Words = ElemsPerWord * 3;
836 
837   uint64_t TotalElems = VT->getNumElements();
838   uint64_t Index = 0;
839   auto TrySlice = [&](unsigned MaybeLen) {
840     if (MaybeLen > 0 && Index + MaybeLen <= TotalElems) {
841       VecSlice Slice{/*Index=*/Index, /*Length=*/MaybeLen};
842       Slices.push_back(Slice);
843       Index += MaybeLen;
844       return true;
845     }
846     return false;
847   };
848   while (Index < TotalElems) {
849     TrySlice(ElemsPer4Words) || TrySlice(ElemsPer3Words) ||
850         TrySlice(ElemsPer2Words) || TrySlice(ElemsPerWord) ||
851         TrySlice(ElemsPerShort) || TrySlice(ElemsPerByte);
852   }
853 }
854 
855 Value *LegalizeBufferContentTypesVisitor::extractSlice(Value *Vec, VecSlice S,
856                                                        const Twine &Name) {
857   auto *VecVT = dyn_cast<FixedVectorType>(Vec->getType());
858   if (!VecVT)
859     return Vec;
860   if (S.Length == VecVT->getNumElements() && S.Index == 0)
861     return Vec;
862   if (S.Length == 1)
863     return IRB.CreateExtractElement(Vec, S.Index,
864                                     Name + ".slice." + Twine(S.Index));
865   SmallVector<int> Mask = llvm::to_vector(
866       llvm::iota_range<int>(S.Index, S.Index + S.Length, /*Inclusive=*/false));
867   return IRB.CreateShuffleVector(Vec, Mask, Name + ".slice." + Twine(S.Index));
868 }
869 
870 Value *LegalizeBufferContentTypesVisitor::insertSlice(Value *Whole, Value *Part,
871                                                       VecSlice S,
872                                                       const Twine &Name) {
873   auto *WholeVT = dyn_cast<FixedVectorType>(Whole->getType());
874   if (!WholeVT)
875     return Part;
876   if (S.Length == WholeVT->getNumElements() && S.Index == 0)
877     return Part;
878   if (S.Length == 1) {
879     return IRB.CreateInsertElement(Whole, Part, S.Index,
880                                    Name + ".slice." + Twine(S.Index));
881   }
882   int NumElems = cast<FixedVectorType>(Whole->getType())->getNumElements();
883 
884   // Extend the slice with poisons to make the main shufflevector happy.
885   SmallVector<int> ExtPartMask(NumElems, -1);
886   for (auto [I, E] : llvm::enumerate(
887            MutableArrayRef<int>(ExtPartMask).take_front(S.Length))) {
888     E = I;
889   }
890   Value *ExtPart = IRB.CreateShuffleVector(Part, ExtPartMask,
891                                            Name + ".ext." + Twine(S.Index));
892 
893   SmallVector<int> Mask =
894       llvm::to_vector(llvm::iota_range<int>(0, NumElems, /*Inclusive=*/false));
895   for (auto [I, E] :
896        llvm::enumerate(MutableArrayRef<int>(Mask).slice(S.Index, S.Length)))
897     E = I + NumElems;
898   return IRB.CreateShuffleVector(Whole, ExtPart, Mask,
899                                  Name + ".parts." + Twine(S.Index));
900 }
901 
902 bool LegalizeBufferContentTypesVisitor::visitLoadImpl(
903     LoadInst &OrigLI, Type *PartType, SmallVectorImpl<uint32_t> &AggIdxs,
904     uint64_t AggByteOff, Value *&Result, const Twine &Name) {
905   if (auto *ST = dyn_cast<StructType>(PartType)) {
906     const StructLayout *Layout = DL.getStructLayout(ST);
907     bool Changed = false;
908     for (auto [I, ElemTy, Offset] :
909          llvm::enumerate(ST->elements(), Layout->getMemberOffsets())) {
910       AggIdxs.push_back(I);
911       Changed |= visitLoadImpl(OrigLI, ElemTy, AggIdxs,
912                                AggByteOff + Offset.getFixedValue(), Result,
913                                Name + "." + Twine(I));
914       AggIdxs.pop_back();
915     }
916     return Changed;
917   }
918   if (auto *AT = dyn_cast<ArrayType>(PartType)) {
919     Type *ElemTy = AT->getElementType();
920     if (!ElemTy->isSingleValueType() || !DL.typeSizeEqualsStoreSize(ElemTy) ||
921         ElemTy->isVectorTy()) {
922       TypeSize ElemStoreSize = DL.getTypeStoreSize(ElemTy);
923       bool Changed = false;
924       for (auto I : llvm::iota_range<uint32_t>(0, AT->getNumElements(),
925                                                /*Inclusive=*/false)) {
926         AggIdxs.push_back(I);
927         Changed |= visitLoadImpl(OrigLI, ElemTy, AggIdxs,
928                                  AggByteOff + I * ElemStoreSize.getFixedValue(),
929                                  Result, Name + Twine(I));
930         AggIdxs.pop_back();
931       }
932       return Changed;
933     }
934   }
935 
936   // Typical case
937 
938   Type *ArrayAsVecType = scalarArrayTypeAsVector(PartType);
939   Type *LegalType = legalNonAggregateFor(ArrayAsVecType);
940 
941   SmallVector<VecSlice> Slices;
942   getVecSlices(LegalType, Slices);
943   bool HasSlices = Slices.size() > 1;
944   bool IsAggPart = !AggIdxs.empty();
945   Value *LoadsRes;
946   if (!HasSlices && !IsAggPart) {
947     Type *LoadableType = intrinsicTypeFor(LegalType);
948     if (LoadableType == PartType)
949       return false;
950 
951     IRB.SetInsertPoint(&OrigLI);
952     auto *NLI = cast<LoadInst>(OrigLI.clone());
953     NLI->mutateType(LoadableType);
954     NLI = IRB.Insert(NLI);
955     NLI->setName(Name + ".loadable");
956 
957     LoadsRes = IRB.CreateBitCast(NLI, LegalType, Name + ".from.loadable");
958   } else {
959     IRB.SetInsertPoint(&OrigLI);
960     LoadsRes = PoisonValue::get(LegalType);
961     Value *OrigPtr = OrigLI.getPointerOperand();
    // If we need to split something into more than one load, its legal
963     // type will be a vector (ex. an i256 load will have LegalType = <8 x i32>).
964     // But if we're already a scalar (which can happen if we're splitting up a
965     // struct), the element type will be the legal type itself.
966     Type *ElemType = LegalType->getScalarType();
967     unsigned ElemBytes = DL.getTypeStoreSize(ElemType);
968     AAMDNodes AANodes = OrigLI.getAAMetadata();
969     if (IsAggPart && Slices.empty())
970       Slices.push_back(VecSlice{/*Index=*/0, /*Length=*/1});
971     for (VecSlice S : Slices) {
972       Type *SliceType =
973           S.Length != 1 ? FixedVectorType::get(ElemType, S.Length) : ElemType;
974       int64_t ByteOffset = AggByteOff + S.Index * ElemBytes;
975       // You can't reasonably expect loads to wrap around the edge of memory.
976       Value *NewPtr = IRB.CreateGEP(
977           IRB.getInt8Ty(), OrigLI.getPointerOperand(), IRB.getInt32(ByteOffset),
978           OrigPtr->getName() + ".off.ptr." + Twine(ByteOffset),
979           GEPNoWrapFlags::noUnsignedWrap());
980       Type *LoadableType = intrinsicTypeFor(SliceType);
981       LoadInst *NewLI = IRB.CreateAlignedLoad(
982           LoadableType, NewPtr, commonAlignment(OrigLI.getAlign(), ByteOffset),
983           Name + ".off." + Twine(ByteOffset));
984       copyMetadataForLoad(*NewLI, OrigLI);
985       NewLI->setAAMetadata(
986           AANodes.adjustForAccess(ByteOffset, LoadableType, DL));
987       NewLI->setAtomic(OrigLI.getOrdering(), OrigLI.getSyncScopeID());
988       NewLI->setVolatile(OrigLI.isVolatile());
989       Value *Loaded = IRB.CreateBitCast(NewLI, SliceType,
990                                         NewLI->getName() + ".from.loadable");
991       LoadsRes = insertSlice(LoadsRes, Loaded, S, Name);
992     }
993   }
994   if (LegalType != ArrayAsVecType)
995     LoadsRes = makeIllegalNonAggregate(LoadsRes, ArrayAsVecType, Name);
996   if (ArrayAsVecType != PartType)
997     LoadsRes = vectorToArray(LoadsRes, PartType, Name);
998 
999   if (IsAggPart)
1000     Result = IRB.CreateInsertValue(Result, LoadsRes, AggIdxs, Name);
1001   else
1002     Result = LoadsRes;
1003   return true;
1004 }
1005 
1006 bool LegalizeBufferContentTypesVisitor::visitLoadInst(LoadInst &LI) {
1007   if (LI.getPointerAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER)
1008     return false;
1009 
1010   SmallVector<uint32_t> AggIdxs;
1011   Type *OrigType = LI.getType();
1012   Value *Result = PoisonValue::get(OrigType);
1013   bool Changed = visitLoadImpl(LI, OrigType, AggIdxs, 0, Result, LI.getName());
1014   if (!Changed)
1015     return false;
1016   Result->takeName(&LI);
1017   LI.replaceAllUsesWith(Result);
1018   LI.eraseFromParent();
1019   return Changed;
1020 }
1021 
1022 std::pair<bool, bool> LegalizeBufferContentTypesVisitor::visitStoreImpl(
1023     StoreInst &OrigSI, Type *PartType, SmallVectorImpl<uint32_t> &AggIdxs,
1024     uint64_t AggByteOff, const Twine &Name) {
1025   if (auto *ST = dyn_cast<StructType>(PartType)) {
1026     const StructLayout *Layout = DL.getStructLayout(ST);
1027     bool Changed = false;
1028     for (auto [I, ElemTy, Offset] :
1029          llvm::enumerate(ST->elements(), Layout->getMemberOffsets())) {
1030       AggIdxs.push_back(I);
1031       Changed |= std::get<0>(visitStoreImpl(OrigSI, ElemTy, AggIdxs,
1032                                             AggByteOff + Offset.getFixedValue(),
1033                                             Name + "." + Twine(I)));
1034       AggIdxs.pop_back();
1035     }
1036     return std::make_pair(Changed, /*ModifiedInPlace=*/false);
1037   }
1038   if (auto *AT = dyn_cast<ArrayType>(PartType)) {
1039     Type *ElemTy = AT->getElementType();
1040     if (!ElemTy->isSingleValueType() || !DL.typeSizeEqualsStoreSize(ElemTy) ||
1041         ElemTy->isVectorTy()) {
1042       TypeSize ElemStoreSize = DL.getTypeStoreSize(ElemTy);
1043       bool Changed = false;
1044       for (auto I : llvm::iota_range<uint32_t>(0, AT->getNumElements(),
1045                                                /*Inclusive=*/false)) {
1046         AggIdxs.push_back(I);
1047         Changed |= std::get<0>(visitStoreImpl(
1048             OrigSI, ElemTy, AggIdxs,
1049             AggByteOff + I * ElemStoreSize.getFixedValue(), Name + Twine(I)));
1050         AggIdxs.pop_back();
1051       }
1052       return std::make_pair(Changed, /*ModifiedInPlace=*/false);
1053     }
1054   }
1055 
1056   Value *OrigData = OrigSI.getValueOperand();
1057   Value *NewData = OrigData;
1058 
1059   bool IsAggPart = !AggIdxs.empty();
1060   if (IsAggPart)
1061     NewData = IRB.CreateExtractValue(NewData, AggIdxs, Name);
1062 
1063   Type *ArrayAsVecType = scalarArrayTypeAsVector(PartType);
1064   if (ArrayAsVecType != PartType) {
1065     NewData = arrayToVector(NewData, ArrayAsVecType, Name);
1066   }
1067 
1068   Type *LegalType = legalNonAggregateFor(ArrayAsVecType);
1069   if (LegalType != ArrayAsVecType) {
1070     NewData = makeLegalNonAggregate(NewData, LegalType, Name);
1071   }
1072 
1073   SmallVector<VecSlice> Slices;
1074   getVecSlices(LegalType, Slices);
1075   bool NeedToSplit = Slices.size() > 1 || IsAggPart;
1076   if (!NeedToSplit) {
1077     Type *StorableType = intrinsicTypeFor(LegalType);
1078     if (StorableType == PartType)
1079       return std::make_pair(/*Changed=*/false, /*ModifiedInPlace=*/false);
1080     NewData = IRB.CreateBitCast(NewData, StorableType, Name + ".storable");
1081     OrigSI.setOperand(0, NewData);
1082     return std::make_pair(/*Changed=*/true, /*ModifiedInPlace=*/true);
1083   }
1084 
1085   Value *OrigPtr = OrigSI.getPointerOperand();
1086   Type *ElemType = LegalType->getScalarType();
1087   if (IsAggPart && Slices.empty())
1088     Slices.push_back(VecSlice{/*Index=*/0, /*Length=*/1});
1089   unsigned ElemBytes = DL.getTypeStoreSize(ElemType);
1090   AAMDNodes AANodes = OrigSI.getAAMetadata();
1091   for (VecSlice S : Slices) {
1092     Type *SliceType =
1093         S.Length != 1 ? FixedVectorType::get(ElemType, S.Length) : ElemType;
1094     int64_t ByteOffset = AggByteOff + S.Index * ElemBytes;
1095     Value *NewPtr =
1096         IRB.CreateGEP(IRB.getInt8Ty(), OrigPtr, IRB.getInt32(ByteOffset),
1097                       OrigPtr->getName() + ".part." + Twine(S.Index),
1098                       GEPNoWrapFlags::noUnsignedWrap());
1099     Value *DataSlice = extractSlice(NewData, S, Name);
1100     Type *StorableType = intrinsicTypeFor(SliceType);
1101     DataSlice = IRB.CreateBitCast(DataSlice, StorableType,
1102                                   DataSlice->getName() + ".storable");
1103     auto *NewSI = cast<StoreInst>(OrigSI.clone());
1104     NewSI->setAlignment(commonAlignment(OrigSI.getAlign(), ByteOffset));
1105     IRB.Insert(NewSI);
1106     NewSI->setOperand(0, DataSlice);
1107     NewSI->setOperand(1, NewPtr);
1108     NewSI->setAAMetadata(AANodes.adjustForAccess(ByteOffset, StorableType, DL));
1109   }
1110   return std::make_pair(/*Changed=*/true, /*ModifiedInPlace=*/false);
1111 }
1112 
1113 bool LegalizeBufferContentTypesVisitor::visitStoreInst(StoreInst &SI) {
1114   if (SI.getPointerAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER)
1115     return false;
1116   IRB.SetInsertPoint(&SI);
1117   SmallVector<uint32_t> AggIdxs;
1118   Value *OrigData = SI.getValueOperand();
1119   auto [Changed, ModifiedInPlace] =
1120       visitStoreImpl(SI, OrigData->getType(), AggIdxs, 0, OrigData->getName());
1121   if (Changed && !ModifiedInPlace)
1122     SI.eraseFromParent();
1123   return Changed;
1124 }
1125 
1126 bool LegalizeBufferContentTypesVisitor::processFunction(Function &F) {
1127   bool Changed = false;
1128   for (Instruction &I : make_early_inc_range(instructions(F))) {
1129     Changed |= visit(I);
1130   }
1131   return Changed;
1132 }
1133 
1134 /// Return the ptr addrspace(8) and i32 (resource and offset parts) in a lowered
1135 /// buffer fat pointer constant.
1136 static std::pair<Constant *, Constant *>
1137 splitLoweredFatBufferConst(Constant *C) {
1138   assert(isSplitFatPtr(C->getType()) && "Not a split fat buffer pointer");
1139   return std::make_pair(C->getAggregateElement(0u), C->getAggregateElement(1u));
1140 }
1141 
1142 namespace {
1143 /// Handle the remapping of ptr addrspace(7) constants.
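/// For example, `ptr addrspace(7) null` becomes a zero-initialized
/// {ptr addrspace(8), i32} struct, and splat vector constants become a splat
/// of resource parts paired with a splat of offset parts.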
1144 class FatPtrConstMaterializer final : public ValueMaterializer {
1145   BufferFatPtrToStructTypeMap *TypeMap;
1146   // An internal mapper that is used to recurse into the arguments of constants.
1147   // While the documentation for `ValueMapper` specifies not to use it
1148   // recursively, examination of the logic in mapValue() shows that it can
  // safely be used recursively when handling constants, as mapValue() itself
  // does internally.
1151   ValueMapper InternalMapper;
1152 
1153   Constant *materializeBufferFatPtrConst(Constant *C);
1154 
1155 public:
1156   // UnderlyingMap is the value map this materializer will be filling.
1157   FatPtrConstMaterializer(BufferFatPtrToStructTypeMap *TypeMap,
1158                           ValueToValueMapTy &UnderlyingMap)
1159       : TypeMap(TypeMap),
1160         InternalMapper(UnderlyingMap, RF_None, TypeMap, this) {}
1161   virtual ~FatPtrConstMaterializer() = default;
1162 
1163   Value *materialize(Value *V) override;
1164 };
1165 } // namespace
1166 
1167 Constant *FatPtrConstMaterializer::materializeBufferFatPtrConst(Constant *C) {
1168   Type *SrcTy = C->getType();
1169   auto *NewTy = dyn_cast<StructType>(TypeMap->remapType(SrcTy));
1170   if (C->isNullValue())
1171     return ConstantAggregateZero::getNullValue(NewTy);
1172   if (isa<PoisonValue>(C)) {
1173     return ConstantStruct::get(NewTy,
1174                                {PoisonValue::get(NewTy->getElementType(0)),
1175                                 PoisonValue::get(NewTy->getElementType(1))});
1176   }
1177   if (isa<UndefValue>(C)) {
1178     return ConstantStruct::get(NewTy,
1179                                {UndefValue::get(NewTy->getElementType(0)),
1180                                 UndefValue::get(NewTy->getElementType(1))});
1181   }
1182 
1183   if (auto *VC = dyn_cast<ConstantVector>(C)) {
1184     if (Constant *S = VC->getSplatValue()) {
1185       Constant *NewS = InternalMapper.mapConstant(*S);
1186       if (!NewS)
1187         return nullptr;
1188       auto [Rsrc, Off] = splitLoweredFatBufferConst(NewS);
1189       auto EC = VC->getType()->getElementCount();
1190       return ConstantStruct::get(NewTy, {ConstantVector::getSplat(EC, Rsrc),
1191                                          ConstantVector::getSplat(EC, Off)});
1192     }
1193     SmallVector<Constant *> Rsrcs;
1194     SmallVector<Constant *> Offs;
1195     for (Value *Op : VC->operand_values()) {
1196       auto *NewOp = dyn_cast_or_null<Constant>(InternalMapper.mapValue(*Op));
1197       if (!NewOp)
1198         return nullptr;
1199       auto [Rsrc, Off] = splitLoweredFatBufferConst(NewOp);
1200       Rsrcs.push_back(Rsrc);
1201       Offs.push_back(Off);
1202     }
1203     Constant *RsrcVec = ConstantVector::get(Rsrcs);
1204     Constant *OffVec = ConstantVector::get(Offs);
1205     return ConstantStruct::get(NewTy, {RsrcVec, OffVec});
1206   }
1207 
1208   if (isa<GlobalValue>(C))
1209     report_fatal_error("Global values containing ptr addrspace(7) (buffer "
1210                        "fat pointer) values are not supported");
1211 
1212   if (isa<ConstantExpr>(C))
1213     report_fatal_error("Constant exprs containing ptr addrspace(7) (buffer "
1214                        "fat pointer) values should have been expanded earlier");
1215 
1216   return nullptr;
1217 }
1218 
1219 Value *FatPtrConstMaterializer::materialize(Value *V) {
1220   Constant *C = dyn_cast<Constant>(V);
1221   if (!C)
1222     return nullptr;
1223   // Structs and other types that happen to contain fat pointers get remapped
1224   // by the mapValue() logic.
1225   if (!isBufferFatPtrConst(C))
1226     return nullptr;
1227   return materializeBufferFatPtrConst(C);
1228 }
1229 
1230 using PtrParts = std::pair<Value *, Value *>;
1231 namespace {
1232 // The visitor returns the resource and offset parts for an instruction if they
1233 // can be computed, or (nullptr, nullptr) for cases that don't have a meaningful
1234 // value mapping.
1235 class SplitPtrStructs : public InstVisitor<SplitPtrStructs, PtrParts> {
1236   ValueToValueMapTy RsrcParts;
1237   ValueToValueMapTy OffParts;
1238 
1239   // Track instructions that have been rewritten into a user of the component
1240   // parts of their ptr addrspace(7) input. Instructions that produced
1241   // ptr addrspace(7) parts should **not** be RAUW'd before being added to this
1242   // set, as that replacement will be handled in a post-visit step. However,
1243   // instructions that yield values that aren't fat pointers (ex. ptrtoint)
1244   // should RAUW themselves with new instructions that use the split parts
1245   // of their arguments during processing.
1246   DenseSet<Instruction *> SplitUsers;
1247 
1248   // Nodes that need a second look once we've computed the parts for all other
1249   // instructions to see if, for example, we really need to phi on the resource
1250   // part.
1251   SmallVector<Instruction *> Conditionals;
1252   // Temporary instructions produced while lowering conditionals that should be
1253   // killed.
1254   SmallVector<Instruction *> ConditionalTemps;
1255 
1256   // Subtarget info, needed for determining what cache control bits to set.
1257   const TargetMachine *TM;
1258   const GCNSubtarget *ST = nullptr;
1259 
1260   IRBuilder<> IRB;
1261 
1262   // Copy metadata between instructions if applicable.
1263   void copyMetadata(Value *Dest, Value *Src);
1264 
1265   // Get the resource and offset parts of the value V, inserting appropriate
1266   // extractvalue calls if needed.
1267   PtrParts getPtrParts(Value *V);
1268 
1269   // Given an instruction that could take its resource part from multiple
1270   // sources (a PHI or select), collect the set of possible values that could
1271   // have provided that resource part (the `Roots`) and the set of
1272   // conditional instructions visited during the search (`Seen`). If, after
1273   // removing the root of the search from `Seen` and `Roots`, `Seen` is a subset
1274   // of `Roots` and `Roots - Seen` contains one element, the resource part of
1275   // that element can replace the resource part of all other elements in `Seen`.
1276   void getPossibleRsrcRoots(Instruction *I, SmallPtrSetImpl<Value *> &Roots,
1277                             SmallPtrSetImpl<Value *> &Seen);
1278   void processConditionals();
1279 
1280   // If an instruction has been split into resource and offset parts,
1281   // delete that instruction. If any of its uses have not themselves been split
1282   // into parts (for example, an insertvalue), construct the struct that the
1283   // type rewrites declared the dying instruction should produce, and use that
1284   // value instead.
1285   // Also, kill the temporary extractvalue operations produced by the two-stage
1286   // lowering of PHIs and conditionals.
1287   void killAndReplaceSplitInstructions(SmallVectorImpl<Instruction *> &Origs);
1288 
1289   void setAlign(CallInst *Intr, Align A, unsigned RsrcArgIdx);
1290   void insertPreMemOpFence(AtomicOrdering Order, SyncScope::ID SSID);
1291   void insertPostMemOpFence(AtomicOrdering Order, SyncScope::ID SSID);
1292   Value *handleMemoryInst(Instruction *I, Value *Arg, Value *Ptr, Type *Ty,
1293                           Align Alignment, AtomicOrdering Order,
1294                           bool IsVolatile, SyncScope::ID SSID);
1295 
1296 public:
1297   SplitPtrStructs(LLVMContext &Ctx, const TargetMachine *TM)
1298       : TM(TM), IRB(Ctx) {}
1299 
1300   void processFunction(Function &F);
1301 
1302   PtrParts visitInstruction(Instruction &I);
1303   PtrParts visitLoadInst(LoadInst &LI);
1304   PtrParts visitStoreInst(StoreInst &SI);
1305   PtrParts visitAtomicRMWInst(AtomicRMWInst &AI);
1306   PtrParts visitAtomicCmpXchgInst(AtomicCmpXchgInst &AI);
1307   PtrParts visitGetElementPtrInst(GetElementPtrInst &GEP);
1308 
1309   PtrParts visitPtrToIntInst(PtrToIntInst &PI);
1310   PtrParts visitIntToPtrInst(IntToPtrInst &IP);
1311   PtrParts visitAddrSpaceCastInst(AddrSpaceCastInst &I);
1312   PtrParts visitICmpInst(ICmpInst &Cmp);
1313   PtrParts visitFreezeInst(FreezeInst &I);
1314 
1315   PtrParts visitExtractElementInst(ExtractElementInst &I);
1316   PtrParts visitInsertElementInst(InsertElementInst &I);
1317   PtrParts visitShuffleVectorInst(ShuffleVectorInst &I);
1318 
1319   PtrParts visitPHINode(PHINode &PHI);
1320   PtrParts visitSelectInst(SelectInst &SI);
1321 
1322   PtrParts visitIntrinsicInst(IntrinsicInst &II);
1323 };
1324 } // namespace
1325 
1326 void SplitPtrStructs::copyMetadata(Value *Dest, Value *Src) {
1327   auto *DestI = dyn_cast<Instruction>(Dest);
1328   auto *SrcI = dyn_cast<Instruction>(Src);
1329 
1330   if (!DestI || !SrcI)
1331     return;
1332 
1333   DestI->copyMetadata(*SrcI);
1334 }
1335 
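// Illustrative sketch: for a value %p of the lowered type
// { ptr addrspace(8), i32 }, a cache miss here emits (after %p's definition)
//   %p.rsrc = extractvalue { ptr addrspace(8), i32 } %p, 0
//   %p.off = extractvalue { ptr addrspace(8), i32 } %p, 1
// and records both parts in RsrcParts/OffParts for later lookups.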
1336 PtrParts SplitPtrStructs::getPtrParts(Value *V) {
1337   assert(isSplitFatPtr(V->getType()) && "it's not meaningful to get the parts "
1338                                         "of something that wasn't rewritten");
1339   auto *RsrcEntry = &RsrcParts[V];
1340   auto *OffEntry = &OffParts[V];
1341   if (*RsrcEntry && *OffEntry)
1342     return {*RsrcEntry, *OffEntry};
1343 
1344   if (auto *C = dyn_cast<Constant>(V)) {
1345     auto [Rsrc, Off] = splitLoweredFatBufferConst(C);
1346     return {*RsrcEntry = Rsrc, *OffEntry = Off};
1347   }
1348 
1349   IRBuilder<>::InsertPointGuard Guard(IRB);
1350   if (auto *I = dyn_cast<Instruction>(V)) {
1351     LLVM_DEBUG(dbgs() << "Recursing to split parts of " << *I << "\n");
1352     auto [Rsrc, Off] = visit(*I);
1353     if (Rsrc && Off)
1354       return {*RsrcEntry = Rsrc, *OffEntry = Off};
1355     // We'll be creating the new values after the relevant instruction.
1356     // This instruction generates a value and so isn't a terminator.
1357     IRB.SetInsertPoint(*I->getInsertionPointAfterDef());
1358     IRB.SetCurrentDebugLocation(I->getDebugLoc());
1359   } else if (auto *A = dyn_cast<Argument>(V)) {
1360     IRB.SetInsertPointPastAllocas(A->getParent());
1361     IRB.SetCurrentDebugLocation(DebugLoc());
1362   }
1363   Value *Rsrc = IRB.CreateExtractValue(V, 0, V->getName() + ".rsrc");
1364   Value *Off = IRB.CreateExtractValue(V, 1, V->getName() + ".off");
1365   return {*RsrcEntry = Rsrc, *OffEntry = Off};
1366 }
1367 
1368 /// Returns the value that provides the resource part of the value V.
1369 /// Note that this is not getUnderlyingObject(), since that looks through
1370 /// operations like ptrmask which might modify the resource part.
1371 ///
1372 /// We can limit ourselves to just looking through GEPs followed by looking
1373 /// through addrspacecasts because only those two operations preserve the
1374 /// resource part, and because any operation on an `addrspace(8)` pointer (the
1375 /// legal input to such a cast) would produce a different resource part.
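/// For instance (illustrative only), for
///   %q = getelementptr i32, ptr addrspace(7) %base, i32 %i
/// rsrcPartRoot(%q) walks back through the GEP (and through any addrspacecast
/// from addrspace(8)) to %base, since those operations only affect the offset.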
1376 static Value *rsrcPartRoot(Value *V) {
1377   while (auto *GEP = dyn_cast<GEPOperator>(V))
1378     V = GEP->getPointerOperand();
1379   while (auto *ASC = dyn_cast<AddrSpaceCastOperator>(V))
1380     V = ASC->getPointerOperand();
1381   return V;
1382 }
1383 
1384 void SplitPtrStructs::getPossibleRsrcRoots(Instruction *I,
1385                                            SmallPtrSetImpl<Value *> &Roots,
1386                                            SmallPtrSetImpl<Value *> &Seen) {
1387   if (auto *PHI = dyn_cast<PHINode>(I)) {
1388     if (!Seen.insert(I).second)
1389       return;
1390     for (Value *In : PHI->incoming_values()) {
1391       In = rsrcPartRoot(In);
1392       Roots.insert(In);
1393       if (isa<PHINode, SelectInst>(In))
1394         getPossibleRsrcRoots(cast<Instruction>(In), Roots, Seen);
1395     }
1396   } else if (auto *SI = dyn_cast<SelectInst>(I)) {
1397     if (!Seen.insert(SI).second)
1398       return;
1399     Value *TrueVal = rsrcPartRoot(SI->getTrueValue());
1400     Value *FalseVal = rsrcPartRoot(SI->getFalseValue());
1401     Roots.insert(TrueVal);
1402     Roots.insert(FalseVal);
1403     if (isa<PHINode, SelectInst>(TrueVal))
1404       getPossibleRsrcRoots(cast<Instruction>(TrueVal), Roots, Seen);
1405     if (isa<PHINode, SelectInst>(FalseVal))
1406       getPossibleRsrcRoots(cast<Instruction>(FalseVal), Roots, Seen);
1407   } else {
1408     llvm_unreachable("getPossibleRsrcRoots() only works on phi and select");
1409   }
1410 }
1411 
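// Illustrative sketch: for a loop that, in the original IR, only advanced the
// offset, e.g.
//   %p = phi ptr addrspace(7) [ %base, %entry ], [ %p.next, %loop ]
//   %p.next = getelementptr i8, ptr addrspace(7) %p, i32 4
// the root search below finds %base as the only root once the PHI itself is
// removed from consideration, so the resource part simply reuses %base's
// resource and only the offset gets a real i32 PHI.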
1412 void SplitPtrStructs::processConditionals() {
1413   SmallDenseMap<Instruction *, Value *> FoundRsrcs;
1414   SmallPtrSet<Value *, 4> Roots;
1415   SmallPtrSet<Value *, 4> Seen;
1416   for (Instruction *I : Conditionals) {
1417     // These have to exist by now because we've visited these nodes.
1418     Value *Rsrc = RsrcParts[I];
1419     Value *Off = OffParts[I];
1420     assert(Rsrc && Off && "must have visited conditionals by now");
1421 
1422     std::optional<Value *> MaybeRsrc;
1423     auto MaybeFoundRsrc = FoundRsrcs.find(I);
1424     if (MaybeFoundRsrc != FoundRsrcs.end()) {
1425       MaybeRsrc = MaybeFoundRsrc->second;
1426     } else {
1427       IRBuilder<>::InsertPointGuard Guard(IRB);
1428       Roots.clear();
1429       Seen.clear();
1430       getPossibleRsrcRoots(I, Roots, Seen);
1431       LLVM_DEBUG(dbgs() << "Processing conditional: " << *I << "\n");
1432 #ifndef NDEBUG
1433       for (Value *V : Roots)
1434         LLVM_DEBUG(dbgs() << "Root: " << *V << "\n");
1435       for (Value *V : Seen)
1436         LLVM_DEBUG(dbgs() << "Seen: " << *V << "\n");
1437 #endif
1438       // If we are our own possible root, then we shouldn't block our
1439       // replacement with a valid incoming value.
1440       Roots.erase(I);
1441       // We don't want to block the optimization for conditionals that don't
1442       // refer to themselves but did see themselves during the traversal.
1443       Seen.erase(I);
1444 
1445       if (set_is_subset(Seen, Roots)) {
1446         auto Diff = set_difference(Roots, Seen);
1447         if (Diff.size() == 1) {
1448           Value *RootVal = *Diff.begin();
1449           // Handle the case where previous loops already looked through
1450           // an addrspacecast.
1451           if (isSplitFatPtr(RootVal->getType()))
1452             MaybeRsrc = std::get<0>(getPtrParts(RootVal));
1453           else
1454             MaybeRsrc = RootVal;
1455         }
1456       }
1457     }
1458 
1459     if (auto *PHI = dyn_cast<PHINode>(I)) {
1460       Value *NewRsrc;
1461       StructType *PHITy = cast<StructType>(PHI->getType());
1462       IRB.SetInsertPoint(*PHI->getInsertionPointAfterDef());
1463       IRB.SetCurrentDebugLocation(PHI->getDebugLoc());
1464       if (MaybeRsrc) {
1465         NewRsrc = *MaybeRsrc;
1466       } else {
1467         Type *RsrcTy = PHITy->getElementType(0);
1468         auto *RsrcPHI = IRB.CreatePHI(RsrcTy, PHI->getNumIncomingValues());
1469         RsrcPHI->takeName(Rsrc);
1470         for (auto [V, BB] : llvm::zip(PHI->incoming_values(), PHI->blocks())) {
1471           Value *VRsrc = std::get<0>(getPtrParts(V));
1472           RsrcPHI->addIncoming(VRsrc, BB);
1473         }
1474         copyMetadata(RsrcPHI, PHI);
1475         NewRsrc = RsrcPHI;
1476       }
1477 
1478       Type *OffTy = PHITy->getElementType(1);
1479       auto *NewOff = IRB.CreatePHI(OffTy, PHI->getNumIncomingValues());
1480       NewOff->takeName(Off);
1481       for (auto [V, BB] : llvm::zip(PHI->incoming_values(), PHI->blocks())) {
1482         assert(OffParts.count(V) && "An offset part had to be created by now");
1483         Value *VOff = std::get<1>(getPtrParts(V));
1484         NewOff->addIncoming(VOff, BB);
1485       }
1486       copyMetadata(NewOff, PHI);
1487 
1488       // Note: We don't eraseFromParent() the temporaries because we don't want
1489   // to put the correction maps in an inconsistent state. That'll be handled
1490   // during the rest of the killing. Also, `ValueToValueMapTy` guarantees
1491       // that references in that map will be updated as well.
1492       ConditionalTemps.push_back(cast<Instruction>(Rsrc));
1493       ConditionalTemps.push_back(cast<Instruction>(Off));
1494       Rsrc->replaceAllUsesWith(NewRsrc);
1495       Off->replaceAllUsesWith(NewOff);
1496 
1497       // Save on recomputing the cycle traversals in known-root cases.
1498       if (MaybeRsrc)
1499         for (Value *V : Seen)
1500           FoundRsrcs[cast<Instruction>(V)] = NewRsrc;
1501     } else if (isa<SelectInst>(I)) {
1502       if (MaybeRsrc) {
1503         ConditionalTemps.push_back(cast<Instruction>(Rsrc));
1504         Rsrc->replaceAllUsesWith(*MaybeRsrc);
1505         for (Value *V : Seen)
1506           FoundRsrcs[cast<Instruction>(V)] = *MaybeRsrc;
1507       }
1508     } else {
1509       llvm_unreachable("Only PHIs and selects go in the conditionals list");
1510     }
1511   }
1512 }
1513 
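// Note (illustrative, assuming the usual 128-bit resource and 32-bit offset):
// a llvm.dbg.value describing a split fat pointer is rewritten into two
// fragment records, DW_OP_LLVM_fragment 0, 128 pointing at the resource part
// and DW_OP_LLVM_fragment 128, 32 pointing at the offset, before the original
// instruction is erased.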
1514 void SplitPtrStructs::killAndReplaceSplitInstructions(
1515     SmallVectorImpl<Instruction *> &Origs) {
1516   for (Instruction *I : ConditionalTemps)
1517     I->eraseFromParent();
1518 
1519   for (Instruction *I : Origs) {
1520     if (!SplitUsers.contains(I))
1521       continue;
1522 
1523     SmallVector<DbgValueInst *> Dbgs;
1524     findDbgValues(Dbgs, I);
1525     for (auto *Dbg : Dbgs) {
1526       IRB.SetInsertPoint(Dbg);
1527       auto &DL = I->getDataLayout();
1528       assert(isSplitFatPtr(I->getType()) &&
1529              "We should've RAUW'd away loads, stores, etc. at this point");
1530       auto *OffDbg = cast<DbgValueInst>(Dbg->clone());
1531       copyMetadata(OffDbg, Dbg);
1532       auto [Rsrc, Off] = getPtrParts(I);
1533 
1534       int64_t RsrcSz = DL.getTypeSizeInBits(Rsrc->getType());
1535       int64_t OffSz = DL.getTypeSizeInBits(Off->getType());
1536 
1537       std::optional<DIExpression *> RsrcExpr =
1538           DIExpression::createFragmentExpression(Dbg->getExpression(), 0,
1539                                                  RsrcSz);
1540       std::optional<DIExpression *> OffExpr =
1541           DIExpression::createFragmentExpression(Dbg->getExpression(), RsrcSz,
1542                                                  OffSz);
1543       if (OffExpr) {
1544         OffDbg->setExpression(*OffExpr);
1545         OffDbg->replaceVariableLocationOp(I, Off);
1546         IRB.Insert(OffDbg);
1547       } else {
1548         OffDbg->deleteValue();
1549       }
1550       if (RsrcExpr) {
1551         Dbg->setExpression(*RsrcExpr);
1552         Dbg->replaceVariableLocationOp(I, Rsrc);
1553       } else {
1554         Dbg->replaceVariableLocationOp(I, UndefValue::get(I->getType()));
1555       }
1556     }
1557 
1558     Value *Poison = PoisonValue::get(I->getType());
1559     I->replaceUsesWithIf(Poison, [&](const Use &U) -> bool {
1560       if (const auto *UI = dyn_cast<Instruction>(U.getUser()))
1561         return SplitUsers.contains(UI);
1562       return false;
1563     });
1564 
1565     if (I->use_empty()) {
1566       I->eraseFromParent();
1567       continue;
1568     }
1569     IRB.SetInsertPoint(*I->getInsertionPointAfterDef());
1570     IRB.SetCurrentDebugLocation(I->getDebugLoc());
1571     auto [Rsrc, Off] = getPtrParts(I);
1572     Value *Struct = PoisonValue::get(I->getType());
1573     Struct = IRB.CreateInsertValue(Struct, Rsrc, 0);
1574     Struct = IRB.CreateInsertValue(Struct, Off, 1);
1575     copyMetadata(Struct, I);
1576     Struct->takeName(I);
1577     I->replaceAllUsesWith(Struct);
1578     I->eraseFromParent();
1579   }
1580 }
1581 
1582 void SplitPtrStructs::setAlign(CallInst *Intr, Align A, unsigned RsrcArgIdx) {
1583   LLVMContext &Ctx = Intr->getContext();
1584   Intr->addParamAttr(RsrcArgIdx, Attribute::getWithAlignment(Ctx, A));
1585 }
1586 
1587 void SplitPtrStructs::insertPreMemOpFence(AtomicOrdering Order,
1588                                           SyncScope::ID SSID) {
1589   switch (Order) {
1590   case AtomicOrdering::Release:
1591   case AtomicOrdering::AcquireRelease:
1592   case AtomicOrdering::SequentiallyConsistent:
1593     IRB.CreateFence(AtomicOrdering::Release, SSID);
1594     break;
1595   default:
1596     break;
1597   }
1598 }
1599 
1600 void SplitPtrStructs::insertPostMemOpFence(AtomicOrdering Order,
1601                                            SyncScope::ID SSID) {
1602   switch (Order) {
1603   case AtomicOrdering::Acquire:
1604   case AtomicOrdering::AcquireRelease:
1605   case AtomicOrdering::SequentiallyConsistent:
1606     IRB.CreateFence(AtomicOrdering::Acquire, SSID);
1607     break;
1608   default:
1609     break;
1610   }
1611 }
1612 
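// Illustrative sketch: a plain load such as
//   %v = load i32, ptr addrspace(7) %p, align 4
// is lowered here to roughly
//   %v = call i32 @llvm.amdgcn.raw.ptr.buffer.load.i32(
//            ptr addrspace(8) align 4 %p.rsrc, i32 %p.off, i32 0, i32 0)
// with release/acquire fences inserted around atomic accesses as needed.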
1613 Value *SplitPtrStructs::handleMemoryInst(Instruction *I, Value *Arg, Value *Ptr,
1614                                          Type *Ty, Align Alignment,
1615                                          AtomicOrdering Order, bool IsVolatile,
1616                                          SyncScope::ID SSID) {
1617   IRB.SetInsertPoint(I);
1618 
1619   auto [Rsrc, Off] = getPtrParts(Ptr);
1620   SmallVector<Value *, 5> Args;
1621   if (Arg)
1622     Args.push_back(Arg);
1623   Args.push_back(Rsrc);
1624   Args.push_back(Off);
1625   insertPreMemOpFence(Order, SSID);
1626   // soffset is always 0 for these cases, where we always want any offset to be
1627   // part of bounds checking and we don't know which parts of the GEPs are
1628   // uniform.
1629   Args.push_back(IRB.getInt32(0));
1630 
1631   uint32_t Aux = 0;
1632   if (IsVolatile)
1633     Aux |= AMDGPU::CPol::VOLATILE;
1634   Args.push_back(IRB.getInt32(Aux));
1635 
1636   Intrinsic::ID IID = Intrinsic::not_intrinsic;
1637   if (isa<LoadInst>(I))
1638     IID = Order == AtomicOrdering::NotAtomic
1639               ? Intrinsic::amdgcn_raw_ptr_buffer_load
1640               : Intrinsic::amdgcn_raw_ptr_atomic_buffer_load;
1641   else if (isa<StoreInst>(I))
1642     IID = Intrinsic::amdgcn_raw_ptr_buffer_store;
1643   else if (auto *RMW = dyn_cast<AtomicRMWInst>(I)) {
1644     switch (RMW->getOperation()) {
1645     case AtomicRMWInst::Xchg:
1646       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_swap;
1647       break;
1648     case AtomicRMWInst::Add:
1649       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_add;
1650       break;
1651     case AtomicRMWInst::Sub:
1652       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_sub;
1653       break;
1654     case AtomicRMWInst::And:
1655       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_and;
1656       break;
1657     case AtomicRMWInst::Or:
1658       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_or;
1659       break;
1660     case AtomicRMWInst::Xor:
1661       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_xor;
1662       break;
1663     case AtomicRMWInst::Max:
1664       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_smax;
1665       break;
1666     case AtomicRMWInst::Min:
1667       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_smin;
1668       break;
1669     case AtomicRMWInst::UMax:
1670       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_umax;
1671       break;
1672     case AtomicRMWInst::UMin:
1673       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_umin;
1674       break;
1675     case AtomicRMWInst::FAdd:
1676       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_fadd;
1677       break;
1678     case AtomicRMWInst::FMax:
1679       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_fmax;
1680       break;
1681     case AtomicRMWInst::FMin:
1682       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_fmin;
1683       break;
1684     case AtomicRMWInst::FSub: {
1685       report_fatal_error("atomic floating point subtraction not supported for "
1686                          "buffer resources and should've been expanded away");
1687       break;
1688     }
1689     case AtomicRMWInst::Nand:
1690       report_fatal_error("atomic nand not supported for buffer resources and "
1691                          "should've been expanded away");
1692       break;
1693     case AtomicRMWInst::UIncWrap:
1694     case AtomicRMWInst::UDecWrap:
1695       report_fatal_error("wrapping increment/decrement not supported for "
1696                          "buffer resources and should've been expanded away");
1697       break;
1698     case AtomicRMWInst::BAD_BINOP:
1699       llvm_unreachable("Not sure how we got a bad binop");
1700     case AtomicRMWInst::USubCond:
1701     case AtomicRMWInst::USubSat:
1702       break;
1703     }
1704   }
1705 
1706   auto *Call = IRB.CreateIntrinsic(IID, Ty, Args);
1707   copyMetadata(Call, I);
1708   setAlign(Call, Alignment, Arg ? 1 : 0);
1709   Call->takeName(I);
1710 
1711   insertPostMemOpFence(Order, SSID);
1712   // The "no moving p7 directly" rewrites ensure that this load or store won't
1713   // itself need to be split into parts.
1714   SplitUsers.insert(I);
1715   I->replaceAllUsesWith(Call);
1716   return Call;
1717 }
1718 
1719 PtrParts SplitPtrStructs::visitInstruction(Instruction &I) {
1720   return {nullptr, nullptr};
1721 }
1722 
1723 PtrParts SplitPtrStructs::visitLoadInst(LoadInst &LI) {
1724   if (!isSplitFatPtr(LI.getPointerOperandType()))
1725     return {nullptr, nullptr};
1726   handleMemoryInst(&LI, nullptr, LI.getPointerOperand(), LI.getType(),
1727                    LI.getAlign(), LI.getOrdering(), LI.isVolatile(),
1728                    LI.getSyncScopeID());
1729   return {nullptr, nullptr};
1730 }
1731 
1732 PtrParts SplitPtrStructs::visitStoreInst(StoreInst &SI) {
1733   if (!isSplitFatPtr(SI.getPointerOperandType()))
1734     return {nullptr, nullptr};
1735   Value *Arg = SI.getValueOperand();
1736   handleMemoryInst(&SI, Arg, SI.getPointerOperand(), Arg->getType(),
1737                    SI.getAlign(), SI.getOrdering(), SI.isVolatile(),
1738                    SI.getSyncScopeID());
1739   return {nullptr, nullptr};
1740 }
1741 
1742 PtrParts SplitPtrStructs::visitAtomicRMWInst(AtomicRMWInst &AI) {
1743   if (!isSplitFatPtr(AI.getPointerOperand()->getType()))
1744     return {nullptr, nullptr};
1745   Value *Arg = AI.getValOperand();
1746   handleMemoryInst(&AI, Arg, AI.getPointerOperand(), Arg->getType(),
1747                    AI.getAlign(), AI.getOrdering(), AI.isVolatile(),
1748                    AI.getSyncScopeID());
1749   return {nullptr, nullptr};
1750 }
1751 
1752 // Unlike load, store, and RMW, cmpxchg needs special handling to account
1753 // for the boolean success value that it returns.
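// Illustrative sketch:
//   %r = cmpxchg ptr addrspace(7) %p, i32 %old, i32 %new seq_cst seq_cst
// becomes roughly
//   fence release
//   %r.val = call i32 @llvm.amdgcn.raw.ptr.buffer.atomic.cmpswap.i32(
//                i32 %new, i32 %old, ptr addrspace(8) %p.rsrc, i32 %p.off,
//                i32 0, i32 0)
//   fence acquire
//   %r.ok = icmp eq i32 %r.val, %old
// with the original {i32, i1} result reassembled via insertvalue.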
1754 PtrParts SplitPtrStructs::visitAtomicCmpXchgInst(AtomicCmpXchgInst &AI) {
1755   Value *Ptr = AI.getPointerOperand();
1756   if (!isSplitFatPtr(Ptr->getType()))
1757     return {nullptr, nullptr};
1758   IRB.SetInsertPoint(&AI);
1759 
1760   Type *Ty = AI.getNewValOperand()->getType();
1761   AtomicOrdering Order = AI.getMergedOrdering();
1762   SyncScope::ID SSID = AI.getSyncScopeID();
1763   bool IsNonTemporal = AI.getMetadata(LLVMContext::MD_nontemporal);
1764 
1765   auto [Rsrc, Off] = getPtrParts(Ptr);
1766   insertPreMemOpFence(Order, SSID);
1767 
1768   uint32_t Aux = 0;
1769   if (IsNonTemporal)
1770     Aux |= AMDGPU::CPol::SLC;
1771   if (AI.isVolatile())
1772     Aux |= AMDGPU::CPol::VOLATILE;
1773   auto *Call =
1774       IRB.CreateIntrinsic(Intrinsic::amdgcn_raw_ptr_buffer_atomic_cmpswap, Ty,
1775                           {AI.getNewValOperand(), AI.getCompareOperand(), Rsrc,
1776                            Off, IRB.getInt32(0), IRB.getInt32(Aux)});
1777   copyMetadata(Call, &AI);
1778   setAlign(Call, AI.getAlign(), 2);
1779   Call->takeName(&AI);
1780   insertPostMemOpFence(Order, SSID);
1781 
1782   Value *Res = PoisonValue::get(AI.getType());
1783   Res = IRB.CreateInsertValue(Res, Call, 0);
1784   if (!AI.isWeak()) {
1785     Value *Succeeded = IRB.CreateICmpEQ(Call, AI.getCompareOperand());
1786     Res = IRB.CreateInsertValue(Res, Succeeded, 1);
1787   }
1788   SplitUsers.insert(&AI);
1789   AI.replaceAllUsesWith(Res);
1790   return {nullptr, nullptr};
1791 }
1792 
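// Illustrative sketch: %q = getelementptr i32, ptr addrspace(7) %p, i32 %i
// keeps %p's resource part unchanged and only adjusts the offset, roughly
//   %q.idx = mul i32 %i, 4            ; computed via emitGEPOffset()
//   %q.off = add i32 %p.off, %q.idx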
1793 PtrParts SplitPtrStructs::visitGetElementPtrInst(GetElementPtrInst &GEP) {
1794   using namespace llvm::PatternMatch;
1795   Value *Ptr = GEP.getPointerOperand();
1796   if (!isSplitFatPtr(Ptr->getType()))
1797     return {nullptr, nullptr};
1798   IRB.SetInsertPoint(&GEP);
1799 
1800   auto [Rsrc, Off] = getPtrParts(Ptr);
1801   const DataLayout &DL = GEP.getDataLayout();
1802   bool IsNUW = GEP.hasNoUnsignedWrap();
1803   bool IsNUSW = GEP.hasNoUnsignedSignedWrap();
1804 
1805   // In order to call emitGEPOffset() and thus not have to reimplement it,
1806   // we need the GEP result to have ptr addrspace(7) type.
1807   Type *FatPtrTy = IRB.getPtrTy(AMDGPUAS::BUFFER_FAT_POINTER);
1808   if (auto *VT = dyn_cast<VectorType>(Off->getType()))
1809     FatPtrTy = VectorType::get(FatPtrTy, VT->getElementCount());
1810   GEP.mutateType(FatPtrTy);
1811   Value *OffAccum = emitGEPOffset(&IRB, DL, &GEP);
1812   GEP.mutateType(Ptr->getType());
1813   if (match(OffAccum, m_Zero())) { // Constant-zero offset
1814     SplitUsers.insert(&GEP);
1815     return {Rsrc, Off};
1816   }
1817 
1818   bool HasNonNegativeOff = false;
1819   if (auto *CI = dyn_cast<ConstantInt>(OffAccum)) {
1820     HasNonNegativeOff = !CI->isNegative();
1821   }
1822   Value *NewOff;
1823   if (match(Off, m_Zero())) {
1824     NewOff = OffAccum;
1825   } else {
1826     NewOff = IRB.CreateAdd(Off, OffAccum, "",
1827                            /*hasNUW=*/IsNUW || (IsNUSW && HasNonNegativeOff),
1828                            /*hasNSW=*/false);
1829   }
1830   copyMetadata(NewOff, &GEP);
1831   NewOff->takeName(&GEP);
1832   SplitUsers.insert(&GEP);
1833   return {Rsrc, NewOff};
1834 }
1835 
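// Illustrative sketch: %i = ptrtoint ptr addrspace(7) %p to i160 is rebuilt as
//   %i = or i160 (shl (ptrtoint ptr addrspace(8) %p.rsrc to i160), 32),
//                (zext i32 %p.off to i160)
// while results of 32 bits or fewer are taken from the offset alone.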
1836 PtrParts SplitPtrStructs::visitPtrToIntInst(PtrToIntInst &PI) {
1837   Value *Ptr = PI.getPointerOperand();
1838   if (!isSplitFatPtr(Ptr->getType()))
1839     return {nullptr, nullptr};
1840   IRB.SetInsertPoint(&PI);
1841 
1842   Type *ResTy = PI.getType();
1843   unsigned Width = ResTy->getScalarSizeInBits();
1844 
1845   auto [Rsrc, Off] = getPtrParts(Ptr);
1846   const DataLayout &DL = PI.getDataLayout();
1847   unsigned FatPtrWidth = DL.getPointerSizeInBits(AMDGPUAS::BUFFER_FAT_POINTER);
1848 
1849   Value *Res;
1850   if (Width <= BufferOffsetWidth) {
1851     Res = IRB.CreateIntCast(Off, ResTy, /*isSigned=*/false,
1852                             PI.getName() + ".off");
1853   } else {
1854     Value *RsrcInt = IRB.CreatePtrToInt(Rsrc, ResTy, PI.getName() + ".rsrc");
1855     Value *Shl = IRB.CreateShl(
1856         RsrcInt,
1857         ConstantExpr::getIntegerValue(ResTy, APInt(Width, BufferOffsetWidth)),
1858         "", Width >= FatPtrWidth, Width > FatPtrWidth);
1859     Value *OffCast = IRB.CreateIntCast(Off, ResTy, /*isSigned=*/false,
1860                                        PI.getName() + ".off");
1861     Res = IRB.CreateOr(Shl, OffCast);
1862   }
1863 
1864   copyMetadata(Res, &PI);
1865   Res->takeName(&PI);
1866   SplitUsers.insert(&PI);
1867   PI.replaceAllUsesWith(Res);
1868   return {nullptr, nullptr};
1869 }
1870 
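// Illustrative sketch: %p = inttoptr i160 %i to ptr addrspace(7) splits into
//   %p.rsrc = inttoptr i128 (trunc (lshr %i, 32)) to ptr addrspace(8)
//   %p.off  = trunc i160 %i to i32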
1871 PtrParts SplitPtrStructs::visitIntToPtrInst(IntToPtrInst &IP) {
1872   if (!isSplitFatPtr(IP.getType()))
1873     return {nullptr, nullptr};
1874   IRB.SetInsertPoint(&IP);
1875   const DataLayout &DL = IP.getDataLayout();
1876   unsigned RsrcPtrWidth = DL.getPointerSizeInBits(AMDGPUAS::BUFFER_RESOURCE);
1877   Value *Int = IP.getOperand(0);
1878   Type *IntTy = Int->getType();
1879   Type *RsrcIntTy = IntTy->getWithNewBitWidth(RsrcPtrWidth);
1880   unsigned Width = IntTy->getScalarSizeInBits();
1881 
1882   auto *RetTy = cast<StructType>(IP.getType());
1883   Type *RsrcTy = RetTy->getElementType(0);
1884   Type *OffTy = RetTy->getElementType(1);
1885   Value *RsrcPart = IRB.CreateLShr(
1886       Int,
1887       ConstantExpr::getIntegerValue(IntTy, APInt(Width, BufferOffsetWidth)));
1888   Value *RsrcInt = IRB.CreateIntCast(RsrcPart, RsrcIntTy, /*isSigned=*/false);
1889   Value *Rsrc = IRB.CreateIntToPtr(RsrcInt, RsrcTy, IP.getName() + ".rsrc");
1890   Value *Off =
1891       IRB.CreateIntCast(Int, OffTy, /*IsSigned=*/false, IP.getName() + ".off");
1892 
1893   copyMetadata(Rsrc, &IP);
1894   SplitUsers.insert(&IP);
1895   return {Rsrc, Off};
1896 }
1897 
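// Illustrative sketch: %p = addrspacecast ptr addrspace(8) %r to ptr
// addrspace(7) splits into the pair {%r, i32 0}; casts from any other source
// address space are rejected below.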
1898 PtrParts SplitPtrStructs::visitAddrSpaceCastInst(AddrSpaceCastInst &I) {
1899   if (!isSplitFatPtr(I.getType()))
1900     return {nullptr, nullptr};
1901   IRB.SetInsertPoint(&I);
1902   Value *In = I.getPointerOperand();
1903   // No-op casts preserve parts
1904   if (In->getType() == I.getType()) {
1905     auto [Rsrc, Off] = getPtrParts(In);
1906     SplitUsers.insert(&I);
1907     return {Rsrc, Off};
1908   }
1909   if (I.getSrcAddressSpace() != AMDGPUAS::BUFFER_RESOURCE)
1910     report_fatal_error("Only buffer resources (addrspace 8) can be cast to "
1911                        "buffer fat pointers (addrspace 7)");
1912   Type *OffTy = cast<StructType>(I.getType())->getElementType(1);
1913   Value *ZeroOff = Constant::getNullValue(OffTy);
1914   SplitUsers.insert(&I);
1915   return {In, ZeroOff};
1916 }
1917 
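// Illustrative sketch: %c = icmp eq ptr addrspace(7) %a, %b becomes roughly
//   %c.rsrc = icmp eq ptr addrspace(8) %a.rsrc, %b.rsrc
//   %c.off  = icmp eq i32 %a.off, %b.off
//   %c      = and i1 %c.rsrc, %c.off
// with `or` taking the place of `and` for icmp ne.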
1918 PtrParts SplitPtrStructs::visitICmpInst(ICmpInst &Cmp) {
1919   Value *Lhs = Cmp.getOperand(0);
1920   if (!isSplitFatPtr(Lhs->getType()))
1921     return {nullptr, nullptr};
1922   Value *Rhs = Cmp.getOperand(1);
1923   IRB.SetInsertPoint(&Cmp);
1924   ICmpInst::Predicate Pred = Cmp.getPredicate();
1925 
1926   assert((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) &&
1927          "Pointer comparison is only equal or unequal");
1928   auto [LhsRsrc, LhsOff] = getPtrParts(Lhs);
1929   auto [RhsRsrc, RhsOff] = getPtrParts(Rhs);
1930   Value *RsrcCmp =
1931       IRB.CreateICmp(Pred, LhsRsrc, RhsRsrc, Cmp.getName() + ".rsrc");
1932   copyMetadata(RsrcCmp, &Cmp);
1933   Value *OffCmp = IRB.CreateICmp(Pred, LhsOff, RhsOff, Cmp.getName() + ".off");
1934   copyMetadata(OffCmp, &Cmp);
1935 
1936   Value *Res = nullptr;
1937   if (Pred == ICmpInst::ICMP_EQ)
1938     Res = IRB.CreateAnd(RsrcCmp, OffCmp);
1939   else if (Pred == ICmpInst::ICMP_NE)
1940     Res = IRB.CreateOr(RsrcCmp, OffCmp);
1941   copyMetadata(Res, &Cmp);
1942   Res->takeName(&Cmp);
1943   SplitUsers.insert(&Cmp);
1944   Cmp.replaceAllUsesWith(Res);
1945   return {nullptr, nullptr};
1946 }
1947 
1948 PtrParts SplitPtrStructs::visitFreezeInst(FreezeInst &I) {
1949   if (!isSplitFatPtr(I.getType()))
1950     return {nullptr, nullptr};
1951   IRB.SetInsertPoint(&I);
1952   auto [Rsrc, Off] = getPtrParts(I.getOperand(0));
1953 
1954   Value *RsrcRes = IRB.CreateFreeze(Rsrc, I.getName() + ".rsrc");
1955   copyMetadata(RsrcRes, &I);
1956   Value *OffRes = IRB.CreateFreeze(Off, I.getName() + ".off");
1957   copyMetadata(OffRes, &I);
1958   SplitUsers.insert(&I);
1959   return {RsrcRes, OffRes};
1960 }
1961 
1962 PtrParts SplitPtrStructs::visitExtractElementInst(ExtractElementInst &I) {
1963   if (!isSplitFatPtr(I.getType()))
1964     return {nullptr, nullptr};
1965   IRB.SetInsertPoint(&I);
1966   Value *Vec = I.getVectorOperand();
1967   Value *Idx = I.getIndexOperand();
1968   auto [Rsrc, Off] = getPtrParts(Vec);
1969 
1970   Value *RsrcRes = IRB.CreateExtractElement(Rsrc, Idx, I.getName() + ".rsrc");
1971   copyMetadata(RsrcRes, &I);
1972   Value *OffRes = IRB.CreateExtractElement(Off, Idx, I.getName() + ".off");
1973   copyMetadata(OffRes, &I);
1974   SplitUsers.insert(&I);
1975   return {RsrcRes, OffRes};
1976 }
1977 
1978 PtrParts SplitPtrStructs::visitInsertElementInst(InsertElementInst &I) {
1979   // The mutated instructions temporarily don't return vectors, and so
1980   // we need the generic getType() here to avoid crashes.
1981   if (!isSplitFatPtr(cast<Instruction>(I).getType()))
1982     return {nullptr, nullptr};
1983   IRB.SetInsertPoint(&I);
1984   Value *Vec = I.getOperand(0);
1985   Value *Elem = I.getOperand(1);
1986   Value *Idx = I.getOperand(2);
1987   auto [VecRsrc, VecOff] = getPtrParts(Vec);
1988   auto [ElemRsrc, ElemOff] = getPtrParts(Elem);
1989 
1990   Value *RsrcRes =
1991       IRB.CreateInsertElement(VecRsrc, ElemRsrc, Idx, I.getName() + ".rsrc");
1992   copyMetadata(RsrcRes, &I);
1993   Value *OffRes =
1994       IRB.CreateInsertElement(VecOff, ElemOff, Idx, I.getName() + ".off");
1995   copyMetadata(OffRes, &I);
1996   SplitUsers.insert(&I);
1997   return {RsrcRes, OffRes};
1998 }
1999 
2000 PtrParts SplitPtrStructs::visitShuffleVectorInst(ShuffleVectorInst &I) {
2001   // Cast is needed for the same reason as insertelement's.
2002   if (!isSplitFatPtr(cast<Instruction>(I).getType()))
2003     return {nullptr, nullptr};
2004   IRB.SetInsertPoint(&I);
2005 
2006   Value *V1 = I.getOperand(0);
2007   Value *V2 = I.getOperand(1);
2008   ArrayRef<int> Mask = I.getShuffleMask();
2009   auto [V1Rsrc, V1Off] = getPtrParts(V1);
2010   auto [V2Rsrc, V2Off] = getPtrParts(V2);
2011 
2012   Value *RsrcRes =
2013       IRB.CreateShuffleVector(V1Rsrc, V2Rsrc, Mask, I.getName() + ".rsrc");
2014   copyMetadata(RsrcRes, &I);
2015   Value *OffRes =
2016       IRB.CreateShuffleVector(V1Off, V2Off, Mask, I.getName() + ".off");
2017   copyMetadata(OffRes, &I);
2018   SplitUsers.insert(&I);
2019   return {RsrcRes, OffRes};
2020 }
2021 
2022 PtrParts SplitPtrStructs::visitPHINode(PHINode &PHI) {
2023   if (!isSplitFatPtr(PHI.getType()))
2024     return {nullptr, nullptr};
2025   IRB.SetInsertPoint(*PHI.getInsertionPointAfterDef());
2026   // Phi nodes will be handled in post-processing after we've visited every
2027   // instruction. However, instead of just returning {nullptr, nullptr},
2028   // we explicitly create the temporary extractvalue operations that are our
2029   // temporary results so that they end up at the beginning of the block with
2030   // the PHIs.
2031   Value *TmpRsrc = IRB.CreateExtractValue(&PHI, 0, PHI.getName() + ".rsrc");
2032   Value *TmpOff = IRB.CreateExtractValue(&PHI, 1, PHI.getName() + ".off");
2033   Conditionals.push_back(&PHI);
2034   SplitUsers.insert(&PHI);
2035   return {TmpRsrc, TmpOff};
2036 }
2037 
2038 PtrParts SplitPtrStructs::visitSelectInst(SelectInst &SI) {
2039   if (!isSplitFatPtr(SI.getType()))
2040     return {nullptr, nullptr};
2041   IRB.SetInsertPoint(&SI);
2042 
2043   Value *Cond = SI.getCondition();
2044   Value *True = SI.getTrueValue();
2045   Value *False = SI.getFalseValue();
2046   auto [TrueRsrc, TrueOff] = getPtrParts(True);
2047   auto [FalseRsrc, FalseOff] = getPtrParts(False);
2048 
2049   Value *RsrcRes =
2050       IRB.CreateSelect(Cond, TrueRsrc, FalseRsrc, SI.getName() + ".rsrc", &SI);
2051   copyMetadata(RsrcRes, &SI);
2052   Conditionals.push_back(&SI);
2053   Value *OffRes =
2054       IRB.CreateSelect(Cond, TrueOff, FalseOff, SI.getName() + ".off", &SI);
2055   copyMetadata(OffRes, &SI);
2056   SplitUsers.insert(&SI);
2057   return {RsrcRes, OffRes};
2058 }
2059 
2060 /// Returns true if this intrinsic needs to be removed when it is
2061 /// applied to `ptr addrspace(7)` values. Calls to these intrinsics are
2062 /// rewritten in terms of the split parts, typically as a call to the same
2063 /// intrinsic on the resource descriptor.
2064 static bool isRemovablePointerIntrinsic(Intrinsic::ID IID) {
2065   switch (IID) {
2066   default:
2067     return false;
2068   case Intrinsic::ptrmask:
2069   case Intrinsic::invariant_start:
2070   case Intrinsic::invariant_end:
2071   case Intrinsic::launder_invariant_group:
2072   case Intrinsic::strip_invariant_group:
2073     return true;
2074   }
2075 }
2076 
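// Illustrative sketch: a ptrmask call such as
//   %q = call ptr addrspace(7) @llvm.ptrmask.p7.i32(ptr addrspace(7) %p, i32 %m)
// keeps %p's resource part and masks only the offset:
//   %q.off = and i32 %p.off, %m
// (the intrinsic name mangling shown here is illustrative).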
2077 PtrParts SplitPtrStructs::visitIntrinsicInst(IntrinsicInst &I) {
2078   Intrinsic::ID IID = I.getIntrinsicID();
2079   switch (IID) {
2080   default:
2081     break;
2082   case Intrinsic::ptrmask: {
2083     Value *Ptr = I.getArgOperand(0);
2084     if (!isSplitFatPtr(Ptr->getType()))
2085       return {nullptr, nullptr};
2086     Value *Mask = I.getArgOperand(1);
2087     IRB.SetInsertPoint(&I);
2088     auto [Rsrc, Off] = getPtrParts(Ptr);
2089     if (Mask->getType() != Off->getType())
2090       report_fatal_error("offset width is not equal to index width of fat "
2091                          "pointer (data layout not set up correctly?)");
2092     Value *OffRes = IRB.CreateAnd(Off, Mask, I.getName() + ".off");
2093     copyMetadata(OffRes, &I);
2094     SplitUsers.insert(&I);
2095     return {Rsrc, OffRes};
2096   }
2097   // Pointer annotation intrinsics that, given their object-wide nature,
2098   // operate on the resource part.
2099   case Intrinsic::invariant_start: {
2100     Value *Ptr = I.getArgOperand(1);
2101     if (!isSplitFatPtr(Ptr->getType()))
2102       return {nullptr, nullptr};
2103     IRB.SetInsertPoint(&I);
2104     auto [Rsrc, Off] = getPtrParts(Ptr);
2105     Type *NewTy = PointerType::get(I.getContext(), AMDGPUAS::BUFFER_RESOURCE);
2106     auto *NewRsrc = IRB.CreateIntrinsic(IID, {NewTy}, {I.getOperand(0), Rsrc});
2107     copyMetadata(NewRsrc, &I);
2108     NewRsrc->takeName(&I);
2109     SplitUsers.insert(&I);
2110     I.replaceAllUsesWith(NewRsrc);
2111     return {nullptr, nullptr};
2112   }
2113   case Intrinsic::invariant_end: {
2114     Value *RealPtr = I.getArgOperand(2);
2115     if (!isSplitFatPtr(RealPtr->getType()))
2116       return {nullptr, nullptr};
2117     IRB.SetInsertPoint(&I);
2118     Value *RealRsrc = getPtrParts(RealPtr).first;
2119     Value *InvPtr = I.getArgOperand(0);
2120     Value *Size = I.getArgOperand(1);
2121     Value *NewRsrc = IRB.CreateIntrinsic(IID, {RealRsrc->getType()},
2122                                          {InvPtr, Size, RealRsrc});
2123     copyMetadata(NewRsrc, &I);
2124     NewRsrc->takeName(&I);
2125     SplitUsers.insert(&I);
2126     I.replaceAllUsesWith(NewRsrc);
2127     return {nullptr, nullptr};
2128   }
2129   case Intrinsic::launder_invariant_group:
2130   case Intrinsic::strip_invariant_group: {
2131     Value *Ptr = I.getArgOperand(0);
2132     if (!isSplitFatPtr(Ptr->getType()))
2133       return {nullptr, nullptr};
2134     IRB.SetInsertPoint(&I);
2135     auto [Rsrc, Off] = getPtrParts(Ptr);
2136     Value *NewRsrc = IRB.CreateIntrinsic(IID, {Rsrc->getType()}, {Rsrc});
2137     copyMetadata(NewRsrc, &I);
2138     NewRsrc->takeName(&I);
2139     SplitUsers.insert(&I);
2140     return {NewRsrc, Off};
2141   }
2142   }
2143   return {nullptr, nullptr};
2144 }
2145 
2146 void SplitPtrStructs::processFunction(Function &F) {
2147   ST = &TM->getSubtarget<GCNSubtarget>(F);
2148   SmallVector<Instruction *, 0> Originals;
2149   LLVM_DEBUG(dbgs() << "Splitting pointer structs in function: " << F.getName()
2150                     << "\n");
2151   for (Instruction &I : instructions(F))
2152     Originals.push_back(&I);
2153   for (Instruction *I : Originals) {
2154     auto [Rsrc, Off] = visit(I);
2155     assert(((Rsrc && Off) || (!Rsrc && !Off)) &&
2156            "Can't have a resource but no offset");
2157     if (Rsrc)
2158       RsrcParts[I] = Rsrc;
2159     if (Off)
2160       OffParts[I] = Off;
2161   }
2162   processConditionals();
2163   killAndReplaceSplitInstructions(Originals);
2164 
2165   // Clean up after ourselves to save on memory.
2166   RsrcParts.clear();
2167   OffParts.clear();
2168   SplitUsers.clear();
2169   Conditionals.clear();
2170   ConditionalTemps.clear();
2171 }
2172 
2173 namespace {
2174 class AMDGPULowerBufferFatPointers : public ModulePass {
2175 public:
2176   static char ID;
2177 
2178   AMDGPULowerBufferFatPointers() : ModulePass(ID) {
2179     initializeAMDGPULowerBufferFatPointersPass(
2180         *PassRegistry::getPassRegistry());
2181   }
2182 
2183   bool run(Module &M, const TargetMachine &TM);
2184   bool runOnModule(Module &M) override;
2185 
2186   void getAnalysisUsage(AnalysisUsage &AU) const override;
2187 };
2188 } // namespace
2189 
2190 /// Returns true if there are values that have a buffer fat pointer in them,
2191 /// which means we'll need to perform rewrites on this function. As a side
2192 /// effect, this will populate the type remapping cache.
2193 static bool containsBufferFatPointers(const Function &F,
2194                                       BufferFatPtrToStructTypeMap *TypeMap) {
2195   bool HasFatPointers = false;
2196   for (const BasicBlock &BB : F)
2197     for (const Instruction &I : BB)
2198       HasFatPointers |= (I.getType() != TypeMap->remapType(I.getType()));
2199   return HasFatPointers;
2200 }
2201 
2202 static bool hasFatPointerInterface(const Function &F,
2203                                    BufferFatPtrToStructTypeMap *TypeMap) {
2204   Type *Ty = F.getFunctionType();
2205   return Ty != TypeMap->remapType(Ty);
2206 }
2207 
2208 /// Move the body of `OldF` into a new function, returning it.
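/// For example (illustrative), a signature such as
///   declare ptr addrspace(7) @f(ptr addrspace(7))
/// is rebuilt as
///   declare { ptr addrspace(8), i32 } @f({ ptr addrspace(8), i32 })
/// with the body, attributes, metadata, and name transferred to the new
/// function.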
2209 static Function *moveFunctionAdaptingType(Function *OldF, FunctionType *NewTy,
2210                                           ValueToValueMapTy &CloneMap) {
2211   bool IsIntrinsic = OldF->isIntrinsic();
2212   Function *NewF =
2213       Function::Create(NewTy, OldF->getLinkage(), OldF->getAddressSpace());
2214   NewF->IsNewDbgInfoFormat = OldF->IsNewDbgInfoFormat;
2215   NewF->copyAttributesFrom(OldF);
2216   NewF->copyMetadata(OldF, 0);
2217   NewF->takeName(OldF);
2218   NewF->updateAfterNameChange();
2219   NewF->setDLLStorageClass(OldF->getDLLStorageClass());
2220   OldF->getParent()->getFunctionList().insertAfter(OldF->getIterator(), NewF);
2221 
2222   while (!OldF->empty()) {
2223     BasicBlock *BB = &OldF->front();
2224     BB->removeFromParent();
2225     BB->insertInto(NewF);
2226     CloneMap[BB] = BB;
2227     for (Instruction &I : *BB) {
2228       CloneMap[&I] = &I;
2229     }
2230   }
2231 
2232   SmallVector<AttributeSet> ArgAttrs;
2233   AttributeList OldAttrs = OldF->getAttributes();
2234 
2235   for (auto [I, OldArg, NewArg] : enumerate(OldF->args(), NewF->args())) {
2236     CloneMap[&NewArg] = &OldArg;
2237     NewArg.takeName(&OldArg);
2238     Type *OldArgTy = OldArg.getType(), *NewArgTy = NewArg.getType();
2239     // Temporarily mutate type of `NewArg` to allow RAUW to work.
2240     NewArg.mutateType(OldArgTy);
2241     OldArg.replaceAllUsesWith(&NewArg);
2242     NewArg.mutateType(NewArgTy);
2243 
2244     AttributeSet ArgAttr = OldAttrs.getParamAttrs(I);
2245     // Intrinsics get their attributes fixed later.
2246     if (OldArgTy != NewArgTy && !IsIntrinsic)
2247       ArgAttr = ArgAttr.removeAttributes(
2248           NewF->getContext(),
2249           AttributeFuncs::typeIncompatible(NewArgTy, ArgAttr));
2250     ArgAttrs.push_back(ArgAttr);
2251   }
2252   AttributeSet RetAttrs = OldAttrs.getRetAttrs();
2253   if (OldF->getReturnType() != NewF->getReturnType() && !IsIntrinsic)
2254     RetAttrs = RetAttrs.removeAttributes(
2255         NewF->getContext(),
2256         AttributeFuncs::typeIncompatible(NewF->getReturnType(), RetAttrs));
2257   NewF->setAttributes(AttributeList::get(
2258       NewF->getContext(), OldAttrs.getFnAttrs(), RetAttrs, ArgAttrs));
2259   return NewF;
2260 }
2261 
2262 static void makeCloneInPlaceMap(Function *F, ValueToValueMapTy &CloneMap) {
2263   for (Argument &A : F->args())
2264     CloneMap[&A] = &A;
2265   for (BasicBlock &BB : *F) {
2266     CloneMap[&BB] = &BB;
2267     for (Instruction &I : BB)
2268       CloneMap[&I] = &I;
2269   }
2270 }
2271 
2272 bool AMDGPULowerBufferFatPointers::run(Module &M, const TargetMachine &TM) {
2273   bool Changed = false;
2274   const DataLayout &DL = M.getDataLayout();
2275   // Record the functions which need to be remapped.
2276   // The second element of the pair indicates whether the function has to have
2277   // its arguments or return types adjusted.
2278   SmallVector<std::pair<Function *, bool>> NeedsRemap;
2279 
2280   BufferFatPtrToStructTypeMap StructTM(DL);
2281   BufferFatPtrToIntTypeMap IntTM(DL);
2282   for (const GlobalVariable &GV : M.globals()) {
2283     if (GV.getAddressSpace() == AMDGPUAS::BUFFER_FAT_POINTER)
2284       report_fatal_error("Global variables with a buffer fat pointer address "
2285                          "space (7) are not supported");
2286     Type *VT = GV.getValueType();
2287     if (VT != StructTM.remapType(VT))
2288       report_fatal_error("Global variables that contain buffer fat pointers "
2289                          "(address space 7 pointers) are unsupported. Use "
2290                          "buffer resource pointers (address space 8) instead.");
2291   }
2292 
2293   {
2294     // Collect all constant exprs and aggregates referenced by any function.
2295     SmallVector<Constant *, 8> Worklist;
2296     for (Function &F : M.functions())
2297       for (Instruction &I : instructions(F))
2298         for (Value *Op : I.operands())
2299           if (isa<ConstantExpr>(Op) || isa<ConstantAggregate>(Op))
2300             Worklist.push_back(cast<Constant>(Op));
2301 
2302     // Recursively look for any referenced buffer pointer constants.
2303     SmallPtrSet<Constant *, 8> Visited;
2304     SetVector<Constant *> BufferFatPtrConsts;
2305     while (!Worklist.empty()) {
2306       Constant *C = Worklist.pop_back_val();
2307       if (!Visited.insert(C).second)
2308         continue;
2309       if (isBufferFatPtrOrVector(C->getType()))
2310         BufferFatPtrConsts.insert(C);
2311       for (Value *Op : C->operands())
2312         if (isa<ConstantExpr>(Op) || isa<ConstantAggregate>(Op))
2313           Worklist.push_back(cast<Constant>(Op));
2314     }
2315 
2316     // Expand all constant expressions using fat buffer pointers to
2317     // instructions.
2318     Changed |= convertUsersOfConstantsToInstructions(
2319         BufferFatPtrConsts.getArrayRef(), /*RestrictToFunc=*/nullptr,
2320         /*RemoveDeadConstants=*/false, /*IncludeSelf=*/true);
2321   }
2322 
2323   StoreFatPtrsAsIntsVisitor MemOpsRewrite(&IntTM, M.getContext());
2324   LegalizeBufferContentTypesVisitor BufferContentsTypeRewrite(DL,
2325                                                               M.getContext());
2326   for (Function &F : M.functions()) {
2327     bool InterfaceChange = hasFatPointerInterface(F, &StructTM);
2328     bool BodyChanges = containsBufferFatPointers(F, &StructTM);
2329     Changed |= MemOpsRewrite.processFunction(F);
2330     if (InterfaceChange || BodyChanges) {
2331       NeedsRemap.push_back(std::make_pair(&F, InterfaceChange));
2332       Changed |= BufferContentsTypeRewrite.processFunction(F);
2333     }
2334   }
2335   if (NeedsRemap.empty())
2336     return Changed;
2337 
2338   SmallVector<Function *> NeedsPostProcess;
2339   SmallVector<Function *> Intrinsics;
2340   // Keep one big map so as to memoize constants across functions.
2341   ValueToValueMapTy CloneMap;
2342   FatPtrConstMaterializer Materializer(&StructTM, CloneMap);
2343 
2344   ValueMapper LowerInFuncs(CloneMap, RF_None, &StructTM, &Materializer);
2345   for (auto [F, InterfaceChange] : NeedsRemap) {
2346     Function *NewF = F;
2347     if (InterfaceChange)
2348       NewF = moveFunctionAdaptingType(
2349           F, cast<FunctionType>(StructTM.remapType(F->getFunctionType())),
2350           CloneMap);
2351     else
2352       makeCloneInPlaceMap(F, CloneMap);
2353     LowerInFuncs.remapFunction(*NewF);
2354     if (NewF->isIntrinsic())
2355       Intrinsics.push_back(NewF);
2356     else
2357       NeedsPostProcess.push_back(NewF);
2358     if (InterfaceChange) {
2359       F->replaceAllUsesWith(NewF);
2360       F->eraseFromParent();
2361     }
2362     Changed = true;
2363   }
2364   StructTM.clear();
2365   IntTM.clear();
2366   CloneMap.clear();
2367 
2368   SplitPtrStructs Splitter(M.getContext(), &TM);
2369   for (Function *F : NeedsPostProcess)
2370     Splitter.processFunction(*F);
2371   for (Function *F : Intrinsics) {
2372     if (isRemovablePointerIntrinsic(F->getIntrinsicID())) {
2373       F->eraseFromParent();
2374     } else {
2375       std::optional<Function *> NewF = Intrinsic::remangleIntrinsicFunction(F);
2376       if (NewF)
2377         F->replaceAllUsesWith(*NewF);
2378     }
2379   }
2380   return Changed;
2381 }
2382 
2383 bool AMDGPULowerBufferFatPointers::runOnModule(Module &M) {
2384   TargetPassConfig &TPC = getAnalysis<TargetPassConfig>();
2385   const TargetMachine &TM = TPC.getTM<TargetMachine>();
2386   return run(M, TM);
2387 }
2388 
2389 char AMDGPULowerBufferFatPointers::ID = 0;
2390 
2391 char &llvm::AMDGPULowerBufferFatPointersID = AMDGPULowerBufferFatPointers::ID;
2392 
2393 void AMDGPULowerBufferFatPointers::getAnalysisUsage(AnalysisUsage &AU) const {
2394   AU.addRequired<TargetPassConfig>();
2395 }
2396 
2397 #define PASS_DESC "Lower buffer fat pointer operations to buffer resources"
2398 INITIALIZE_PASS_BEGIN(AMDGPULowerBufferFatPointers, DEBUG_TYPE, PASS_DESC,
2399                       false, false)
2400 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
2401 INITIALIZE_PASS_END(AMDGPULowerBufferFatPointers, DEBUG_TYPE, PASS_DESC, false,
2402                     false)
2403 #undef PASS_DESC
2404 
2405 ModulePass *llvm::createAMDGPULowerBufferFatPointersPass() {
2406   return new AMDGPULowerBufferFatPointers();
2407 }
2408 
2409 PreservedAnalyses
2410 AMDGPULowerBufferFatPointersPass::run(Module &M, ModuleAnalysisManager &MA) {
2411   return AMDGPULowerBufferFatPointers().run(M, TM) ? PreservedAnalyses::none()
2412                                                    : PreservedAnalyses::all();
2413 }
2414