1 //===-- AMDGPULowerBufferFatPointers.cpp ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass lowers operations on buffer fat pointers (addrspace 7) to
10 // operations on buffer resources (addrspace 8) and is needed for correct
11 // codegen.
12 //
13 // # Background
14 //
15 // Address space 7 (the buffer fat pointer) is a 160-bit pointer that consists
16 // of a 128-bit buffer descriptor and a 32-bit offset into the buffer it describes.
17 // The buffer resource part needs to be a "raw" buffer resource
18 // (it must have a stride of 0 and bounds checks must be in raw buffer mode
19 // or disabled).
20 //
21 // When these requirements are met, a buffer resource can be treated as a
22 // typical (though quite wide) pointer that follows typical LLVM pointer
23 // semantics. This allows the frontend to reason about such buffers (which are
24 // often encountered in the context of SPIR-V kernels).
25 //
26 // However, because of their non-power-of-2 size, these fat pointers cannot be
27 // present during translation to MIR (though this restriction may be lifted
28 // during the transition to GlobalISel). Therefore, this pass is needed in order
29 // to correctly implement these fat pointers.
30 //
31 // The resource intrinsics take the resource part (the address space 8 pointer)
32 // and the offset part (the 32-bit integer) as separate arguments. In addition,
33 // many users of these buffers manipulate the offset while leaving the resource
34 // part alone. For these reasons, we want to typically separate the resource
35 // and offset parts into separate variables, but combine them together when
36 // encountering cases where this is required, such as by inserting these values
37 // into aggregates or moving them to memory.
38 //
39 // Therefore, at a high level, `ptr addrspace(7) %x` becomes `ptr addrspace(8)
40 // %x.rsrc` and `i32 %x.off`, which will be combined into `{ptr addrspace(8),
41 // i32} %x = {%x.rsrc, %x.off}` if needed. Similarly, `vector<Nxp7>` becomes
42 // `{vector<Nxp8>, vector<Nxi32>}` and its component parts.
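//
// For example, a pointer increment and a load through `ptr addrspace(7)`,
// such as
// ```
//   %q = getelementptr i32, ptr addrspace(7) %p, i32 1
//   %v = load i32, ptr addrspace(7) %q
// ```
// conceptually become (an illustrative sketch; value names and the trailing
// intrinsic arguments are abbreviated)
// ```
//   %q.rsrc = %p.rsrc ; carried along unchanged
//   %q.off = add i32 %p.off, 4
//   %v = call @llvm.amdgcn.raw.ptr.buffer.load(ptr addrspace(8) %q.rsrc,
//     i32 %q.off, ...)
// ```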
43 //
44 // # Implementation
45 //
46 // This pass proceeds in three main phases:
47 //
48 // ## Rewriting loads and stores of p7
49 //
50 // The first phase is to rewrite away all loads and stores of `ptr addrspace(7)`,
51 // including aggregates containing such pointers, to ones that use `i160`. This
52 // is handled by `StoreFatPtrsAsIntsVisitor`, which visits loads, stores, and
53 // allocas and, if the loaded or stored type contains `ptr addrspace(7)`,
54 // rewrites that type to one where the p7s are replaced by i160s, copying other
55 // parts of aggregates as needed. In the case of a store, each pointer is
56 // `ptrtoint`d to i160 before storing, and loaded integers are `inttoptr`d back.
57 // This same transformation is applied to vectors of pointers.
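//
// For instance (a sketch of the shape of the rewrite, not exact output),
// ```
//   store ptr addrspace(7) %p, ptr %slot
// ```
// becomes
// ```
//   %p.int = ptrtoint ptr addrspace(7) %p to i160
//   store i160 %p.int, ptr %slot
// ```
// and the corresponding load casts the loaded i160 back with `inttoptr`.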
58 //
59 // Such a transformation allows the later phases of the pass to not need
60 // to handle buffer fat pointers moving to and from memory, where we would
61 // have to handle the incompatibility between a `{Nxp8, Nxi32}` representation
62 // and `Nxi160` directly. Instead, that transposing action (where the vectors
63 // of resources and vectors of offsets are concatenated before being stored to
64 // memory) is handled through implementing `inttoptr` and `ptrtoint` only.
65 //
66 // Atomic operations on `ptr addrspace(7)` values are not supported, as the
67 // hardware does not include a 160-bit atomic.
68 //
69 // ## Buffer contents type legalization
70 //
71 // The underlying buffer intrinsics only support types up to 128 bits long,
72 // and don't support complex types. If buffer operations were
73 // standard pointer operations that could be represented as MIR-level loads,
74 // this would be handled by the various legalization schemes in instruction
75 // selection. However, because we have to do the conversion from `load` and
76 // `store` to intrinsics at LLVM IR level, we must perform that legalization
77 // ourselves.
78 //
79 // This involves a combination of
80 // - Converting arrays to vectors where possible
81 // - Otherwise, splitting loads and stores of aggregates into loads/stores of
82 //   each component.
83 // - Zero-extending things to fill a whole number of bytes
84 // - Casting values of types that don't neatly correspond to supported machine
85 //   values (for example, an i96 or i256) into ones that would work
86 //   (like <3 x i32> and <8 x i32>, respectively); see the sketch after
87 //   this list
88 // - Splitting values that are too long (such as aforementioned <8 x i32>) into
89 //   multiple operations.
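//
// As an illustrative sketch of the casting rule above, a store of an i96
// through a fat pointer
// ```
//   store i96 %v, ptr addrspace(7) %p
// ```
// is rewritten to something like
// ```
//   %v.legal = bitcast i96 %v to <3 x i32>
//   store <3 x i32> %v.legal, ptr addrspace(7) %p
// ```
// which later phases can then lower to a single buffer-store intrinsic.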
90 //
91 // ## Type remapping
92 //
93 // We use a `ValueMapper` to mangle uses of [vectors of] buffer fat pointers
94 // to the corresponding struct type, which has a resource part and an offset
95 // part.
96 //
97 // This uses a `BufferFatPtrToStructTypeMap` and a `FatPtrConstMaterializer`
98 // to do so, usually by way of `setType`ing values. Constants are handled here
99 // because there isn't a good way to fix them up later.
100 //
101 // This has the downside of leaving the IR in an invalid state (for example,
102 // the instruction `getelementptr {ptr addrspace(8), i32} %p, ...` will exist),
103 // but all such invalid states will be resolved by the third phase.
104 //
105 // Functions that don't take buffer fat pointers are modified in place. Those
106 // that do take such pointers have their basic blocks moved to a new function
107 // with `{ptr addrspace(8), i32}` arguments and return values.
108 // This phase also records intrinsics so that they can be remangled or deleted
109 // later.
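//
// For example (a sketch showing only the remapped signature),
// ```
//   define i32 @sum(ptr addrspace(7) %p) { ... }
// ```
// becomes
// ```
//   define i32 @sum({ptr addrspace(8), i32} %p) { ... }
// ```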
110 //
111 // ## Splitting pointer structs
112 //
113 // The meat of this pass consists of defining semantics for operations that
114 // produce or consume [vectors of] buffer fat pointers in terms of their
115 // resource and offset parts. This is accomplished through the `SplitPtrStructs`
116 // visitor.
117 //
118 // In the first pass through each function that is being lowered, the splitter
119 // inserts new instructions to implement the split-structures behavior, which is
120 // needed for correctness and performance. It records a list of "split users",
121 // instructions that are being replaced by operations on the resource and offset
122 // parts.
123 //
124 // Split users do not necessarily need to produce parts themselves
125 // (a `load float, ptr addrspace(7)` does not, for example), but, if they do not
126 // generate fat buffer pointers, they must RAUW in their replacement
127 // instructions during the initial visit.
128 //
129 // When these new instructions are created, they use the split parts recorded
130 // for their initial arguments in order to generate their replacements, creating
131 // a parallel set of instructions that does not refer to the original fat
132 // pointer values but instead to their resource and offset components.
133 //
134 // Instructions, such as `extractvalue`, that produce buffer fat pointers from
135 // sources that do not have split parts, have such parts generated using
136 // `extractvalue`. This is also the initial handling of PHI nodes, which
137 // are then cleaned up.
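//
// For example (value names are illustrative), a value `%v` of the remapped
// type `{ptr addrspace(8), i32}` that has no recorded parts is given
// ```
//   %v.rsrc = extractvalue {ptr addrspace(8), i32} %v, 0
//   %v.off = extractvalue {ptr addrspace(8), i32} %v, 1
// ```
// as its temporary resource and offset parts.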
138 //
139 // ### Conditionals
140 //
141 // PHI nodes are initially given resource parts via `extractvalue`. However,
142 // this is not an efficient rewrite of such nodes, as, in most cases, the
143 // resource part in a conditional or loop remains constant throughout the loop
144 // and only the offset varies. Failing to optimize away these constant resources
145 // would cause additional registers to be sent around loops and might lead to
146 // waterfall loops being generated for buffer operations due to the
147 // "non-uniform" resource argument.
148 //
149 // Therefore, after all instructions have been visited, the pointer splitter
150 // post-processes all encountered conditionals. Given a PHI node or select,
151 // getPossibleRsrcRoots() collects all values that the resource parts of that
152 // conditional's input could come from as well as collecting all conditional
153 // instructions encountered during the search. If, after filtering out the
154 // initial node itself, the set of encountered conditionals is a subset of the
155 // potential roots and there is a single potential resource that isn't in the
156 // conditional set, that value is the only possible value the resource argument
157 // could have throughout the control flow.
158 //
159 // If that condition is met, then a PHI node can have its resource part changed
160 // to the singleton value and then be replaced by a PHI on the offsets.
161 // Otherwise, each PHI node is split into two, one for the resource part and one
162 // for the offset part, which replace the temporary `extractvalue` instructions
163 // that were added during the first pass.
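//
// As an illustrative sketch of the common single-resource case, a loop-header
// PHI such as
// ```
//   %p = phi ptr addrspace(7) [ %init, %entry ], [ %p.next, %loop ]
// ```
// ends up, after splitting and cleanup, as something like
// ```
//   %p.off = phi i32 [ %init.off, %entry ], [ %p.next.off, %loop ]
// ```
// with the resource part simply reusing `%init.rsrc`.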
164 //
165 // Similar logic applies to `select`, where
166 // `%z = select i1 %cond, ptr addrspace(7) %x, ptr addrspace(7) %y`
167 // can be split into `%z.rsrc = %x.rsrc` and
168 // `%z.off = select i1 %cond, i32 %x.off, i32 %y.off`
169 // if both `%x` and `%y` have the same resource part, but two `select`
170 // operations will be needed if they do not.
171 //
172 // ### Final processing
173 //
174 // After conditionals have been cleaned up, the IR for each function is
175 // rewritten to remove all the old instructions that have been split up.
176 //
177 // Any instruction that used to produce a buffer fat pointer (and therefore now
178 // produces a resource-and-offset struct after type remapping) is
179 // replaced as follows:
180 // 1. All debug value annotations are cloned to reflect that the resource part
181 //    and offset parts are computed separately and constitute different
182 //    fragments of the underlying source language variable.
183 // 2. All uses that were themselves split are replaced by a `poison` of the
184 //    struct type, as they will themselves be erased soon. This rule, combined
185 //    with debug handling, should leave the use lists of split instructions
186 //    empty in almost all cases.
187 // 3. If a user of the original struct-valued result remains, the structure
188 //    needed for the new types to work is constructed out of the newly-defined
189 //    parts, and the original instruction is replaced by this structure
190 //    before being erased. Instructions requiring this construction include
191 //    `ret` and `insertvalue`; a sketch follows this list.
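//
// For example (a sketch; value names are illustrative), a remaining `ret` of
// a split value `%p` becomes
// ```
//   %p.0 = insertvalue {ptr addrspace(8), i32} poison, ptr addrspace(8) %p.rsrc, 0
//   %p.1 = insertvalue {ptr addrspace(8), i32} %p.0, i32 %p.off, 1
//   ret {ptr addrspace(8), i32} %p.1
// ```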
192 //
193 // # Consequences
194 //
195 // This pass does not alter the CFG.
196 //
197 // Alias analysis information will become coarser, as the LLVM alias analyzer
198 // cannot handle the buffer intrinsics. Specifically, while we can determine
199 // that the following two loads do not alias:
200 // ```
201 //   %y = getelementptr i32, ptr addrspace(7) %x, i32 1
202 //   %a = load i32, ptr addrspace(7) %x
203 //   %b = load i32, ptr addrspace(7) %y
204 // ```
205 // we cannot (except through some code that runs during scheduling) determine
206 // that the rewritten loads below do not alias.
207 // ```
208 //   %y.off = add i32 %x.off, 1
209 //   %a = call @llvm.amdgcn.raw.ptr.buffer.load(ptr addrspace(8) %x.rsrc, i32
210 //     %x.off, ...)
211 //   %b = call @llvm.amdgcn.raw.ptr.buffer.load(ptr addrspace(8)
212 //     %x.rsrc, i32 %y.off, ...)
213 // ```
214 // However, existing alias information is preserved.
215 //===----------------------------------------------------------------------===//
216 
217 #include "AMDGPU.h"
218 #include "AMDGPUTargetMachine.h"
219 #include "GCNSubtarget.h"
220 #include "SIDefines.h"
221 #include "llvm/ADT/SetOperations.h"
222 #include "llvm/ADT/SmallVector.h"
223 #include "llvm/Analysis/ConstantFolding.h"
224 #include "llvm/Analysis/Utils/Local.h"
225 #include "llvm/CodeGen/TargetPassConfig.h"
226 #include "llvm/IR/AttributeMask.h"
227 #include "llvm/IR/Constants.h"
228 #include "llvm/IR/DebugInfo.h"
229 #include "llvm/IR/DerivedTypes.h"
230 #include "llvm/IR/IRBuilder.h"
231 #include "llvm/IR/InstIterator.h"
232 #include "llvm/IR/InstVisitor.h"
233 #include "llvm/IR/Instructions.h"
234 #include "llvm/IR/Intrinsics.h"
235 #include "llvm/IR/IntrinsicsAMDGPU.h"
236 #include "llvm/IR/Metadata.h"
237 #include "llvm/IR/Operator.h"
238 #include "llvm/IR/PatternMatch.h"
239 #include "llvm/IR/ReplaceConstant.h"
240 #include "llvm/InitializePasses.h"
241 #include "llvm/Pass.h"
242 #include "llvm/Support/Alignment.h"
243 #include "llvm/Support/AtomicOrdering.h"
244 #include "llvm/Support/Debug.h"
245 #include "llvm/Support/ErrorHandling.h"
246 #include "llvm/Transforms/Utils/Cloning.h"
247 #include "llvm/Transforms/Utils/Local.h"
248 #include "llvm/Transforms/Utils/ValueMapper.h"
249 
250 #define DEBUG_TYPE "amdgpu-lower-buffer-fat-pointers"
251 
252 using namespace llvm;
253 
254 static constexpr unsigned BufferOffsetWidth = 32;
255 
256 namespace {
257 /// Recursively replace instances of ptr addrspace(7) and vector<Nxptr
258 /// addrspace(7)> with some other type as defined by the relevant subclass.
259 class BufferFatPtrTypeLoweringBase : public ValueMapTypeRemapper {
260   DenseMap<Type *, Type *> Map;
261 
262   Type *remapTypeImpl(Type *Ty, SmallPtrSetImpl<StructType *> &Seen);
263 
264 protected:
265   virtual Type *remapScalar(PointerType *PT) = 0;
266   virtual Type *remapVector(VectorType *VT) = 0;
267 
268   const DataLayout &DL;
269 
270 public:
271   BufferFatPtrTypeLoweringBase(const DataLayout &DL) : DL(DL) {}
272   Type *remapType(Type *SrcTy) override;
273   void clear() { Map.clear(); }
274 };
275 
276 /// Remap ptr addrspace(7) to i160 and vector<Nxptr addrspace(7)> to
277 /// vector<Nxi160> in order to correctly handle loading/storing these values
278 /// from memory.
279 class BufferFatPtrToIntTypeMap : public BufferFatPtrTypeLoweringBase {
280   using BufferFatPtrTypeLoweringBase::BufferFatPtrTypeLoweringBase;
281 
282 protected:
283   Type *remapScalar(PointerType *PT) override { return DL.getIntPtrType(PT); }
284   Type *remapVector(VectorType *VT) override { return DL.getIntPtrType(VT); }
285 };
286 
287 /// Remap ptr addrspace(7) to {ptr addrspace(8), i32} (the resource and offset
288 /// parts of the pointer) so that we can easily rewrite operations on these
289 /// values that aren't loading them from or storing them to memory.
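/// For example (illustrative):
///   ptr addrspace(7)       -> {ptr addrspace(8), i32}
///   <4 x ptr addrspace(7)> -> {<4 x ptr addrspace(8)>, <4 x i32>}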
290 class BufferFatPtrToStructTypeMap : public BufferFatPtrTypeLoweringBase {
291   using BufferFatPtrTypeLoweringBase::BufferFatPtrTypeLoweringBase;
292 
293 protected:
294   Type *remapScalar(PointerType *PT) override;
295   Type *remapVector(VectorType *VT) override;
296 };
297 } // namespace
298 
299 // This code is adapted from the type remapper in lib/Linker/IRMover.cpp
300 Type *BufferFatPtrTypeLoweringBase::remapTypeImpl(
301     Type *Ty, SmallPtrSetImpl<StructType *> &Seen) {
302   Type **Entry = &Map[Ty];
303   if (*Entry)
304     return *Entry;
305   if (auto *PT = dyn_cast<PointerType>(Ty)) {
306     if (PT->getAddressSpace() == AMDGPUAS::BUFFER_FAT_POINTER) {
307       return *Entry = remapScalar(PT);
308     }
309   }
310   if (auto *VT = dyn_cast<VectorType>(Ty)) {
311     auto *PT = dyn_cast<PointerType>(VT->getElementType());
312     if (PT && PT->getAddressSpace() == AMDGPUAS::BUFFER_FAT_POINTER) {
313       return *Entry = remapVector(VT);
314     }
315     return *Entry = Ty;
316   }
317   // Whether the type is one that is structurally uniqued - that is, if it is
318   // not a named struct (the only kind of type where multiple structurally
319   // identical types can have distinct `Type*`s).
320   StructType *TyAsStruct = dyn_cast<StructType>(Ty);
321   bool IsUniqued = !TyAsStruct || TyAsStruct->isLiteral();
322   // Base case for ints, floats, opaque pointers, and so on, which don't
323   // require recursion.
324   if (Ty->getNumContainedTypes() == 0 && IsUniqued)
325     return *Entry = Ty;
326   if (!IsUniqued) {
327     // Create a dummy type for recursion purposes.
328     if (!Seen.insert(TyAsStruct).second) {
329       StructType *Placeholder = StructType::create(Ty->getContext());
330       return *Entry = Placeholder;
331     }
332   }
333   bool Changed = false;
334   SmallVector<Type *> ElementTypes(Ty->getNumContainedTypes(), nullptr);
335   for (unsigned int I = 0, E = Ty->getNumContainedTypes(); I < E; ++I) {
336     Type *OldElem = Ty->getContainedType(I);
337     Type *NewElem = remapTypeImpl(OldElem, Seen);
338     ElementTypes[I] = NewElem;
339     Changed |= (OldElem != NewElem);
340   }
341   // Recursive calls to remapTypeImpl() may have invalidated the `Entry` pointer.
342   Entry = &Map[Ty];
343   if (!Changed) {
344     return *Entry = Ty;
345   }
346   if (auto *ArrTy = dyn_cast<ArrayType>(Ty))
347     return *Entry = ArrayType::get(ElementTypes[0], ArrTy->getNumElements());
348   if (auto *FnTy = dyn_cast<FunctionType>(Ty))
349     return *Entry = FunctionType::get(ElementTypes[0],
350                                       ArrayRef(ElementTypes).slice(1),
351                                       FnTy->isVarArg());
352   if (auto *STy = dyn_cast<StructType>(Ty)) {
353     // Genuine opaque types don't have a remapping.
354     if (STy->isOpaque())
355       return *Entry = Ty;
356     bool IsPacked = STy->isPacked();
357     if (IsUniqued)
358       return *Entry = StructType::get(Ty->getContext(), ElementTypes, IsPacked);
359     SmallString<16> Name(STy->getName());
360     STy->setName("");
361     Type **RecursionEntry = &Map[Ty];
362     if (*RecursionEntry) {
363       auto *Placeholder = cast<StructType>(*RecursionEntry);
364       Placeholder->setBody(ElementTypes, IsPacked);
365       Placeholder->setName(Name);
366       return *Entry = Placeholder;
367     }
368     return *Entry = StructType::create(Ty->getContext(), ElementTypes, Name,
369                                        IsPacked);
370   }
371   llvm_unreachable("Unknown type of type that contains elements");
372 }
373 
374 Type *BufferFatPtrTypeLoweringBase::remapType(Type *SrcTy) {
375   SmallPtrSet<StructType *, 2> Visited;
376   return remapTypeImpl(SrcTy, Visited);
377 }
378 
379 Type *BufferFatPtrToStructTypeMap::remapScalar(PointerType *PT) {
380   LLVMContext &Ctx = PT->getContext();
381   return StructType::get(PointerType::get(Ctx, AMDGPUAS::BUFFER_RESOURCE),
382                          IntegerType::get(Ctx, BufferOffsetWidth));
383 }
384 
385 Type *BufferFatPtrToStructTypeMap::remapVector(VectorType *VT) {
386   ElementCount EC = VT->getElementCount();
387   LLVMContext &Ctx = VT->getContext();
388   Type *RsrcVec =
389       VectorType::get(PointerType::get(Ctx, AMDGPUAS::BUFFER_RESOURCE), EC);
390   Type *OffVec = VectorType::get(IntegerType::get(Ctx, BufferOffsetWidth), EC);
391   return StructType::get(RsrcVec, OffVec);
392 }
393 
394 static bool isBufferFatPtrOrVector(Type *Ty) {
395   if (auto *PT = dyn_cast<PointerType>(Ty->getScalarType()))
396     return PT->getAddressSpace() == AMDGPUAS::BUFFER_FAT_POINTER;
397   return false;
398 }
399 
400 // True if the type is {ptr addrspace(8), i32} or a struct containing vectors of
401 // those types. Used to quickly skip instructions we don't need to process.
402 static bool isSplitFatPtr(Type *Ty) {
403   auto *ST = dyn_cast<StructType>(Ty);
404   if (!ST)
405     return false;
406   if (!ST->isLiteral() || ST->getNumElements() != 2)
407     return false;
408   auto *MaybeRsrc =
409       dyn_cast<PointerType>(ST->getElementType(0)->getScalarType());
410   auto *MaybeOff =
411       dyn_cast<IntegerType>(ST->getElementType(1)->getScalarType());
412   return MaybeRsrc && MaybeOff &&
413          MaybeRsrc->getAddressSpace() == AMDGPUAS::BUFFER_RESOURCE &&
414          MaybeOff->getBitWidth() == BufferOffsetWidth;
415 }
416 
417 // True if the result type or any argument types are buffer fat pointers.
418 static bool isBufferFatPtrConst(Constant *C) {
419   Type *T = C->getType();
420   return isBufferFatPtrOrVector(T) || any_of(C->operands(), [](const Use &U) {
421            return isBufferFatPtrOrVector(U.get()->getType());
422          });
423 }
424 
425 namespace {
426 /// Convert [vectors of] buffer fat pointers to integers when they are read from
427 /// or stored to memory. This ensures that these pointers will have the same
428 /// memory layout as before they are lowered, even though they will no longer
429 /// have their previous layout in registers/in the program (they'll be broken
430 /// down into resource and offset parts). This has the downside of imposing
431 /// marshalling costs when reading or storing these values, but since placing
432 /// such pointers into memory is an uncommon operation at best, we feel that
433 /// this cost is acceptable for better performance in the common case.
434 class StoreFatPtrsAsIntsVisitor
435     : public InstVisitor<StoreFatPtrsAsIntsVisitor, bool> {
436   BufferFatPtrToIntTypeMap *TypeMap;
437 
438   ValueToValueMapTy ConvertedForStore;
439 
440   IRBuilder<> IRB;
441 
442   // Convert all the buffer fat pointers within the input value to integers
443   // so that it can be stored in memory.
444   Value *fatPtrsToInts(Value *V, Type *From, Type *To, const Twine &Name);
445   // Convert all the i160s that need to be buffer fat pointers (as specified
446   // by the To type) into those pointers to preserve the semantics of the rest
447   // of the program.
448   Value *intsToFatPtrs(Value *V, Type *From, Type *To, const Twine &Name);
449 
450 public:
451   StoreFatPtrsAsIntsVisitor(BufferFatPtrToIntTypeMap *TypeMap, LLVMContext &Ctx)
452       : TypeMap(TypeMap), IRB(Ctx) {}
453   bool processFunction(Function &F);
454 
455   bool visitInstruction(Instruction &I) { return false; }
456   bool visitAllocaInst(AllocaInst &I);
457   bool visitLoadInst(LoadInst &LI);
458   bool visitStoreInst(StoreInst &SI);
459   bool visitGetElementPtrInst(GetElementPtrInst &I);
460 };
461 } // namespace
462 
463 Value *StoreFatPtrsAsIntsVisitor::fatPtrsToInts(Value *V, Type *From, Type *To,
464                                                 const Twine &Name) {
465   if (From == To)
466     return V;
467   ValueToValueMapTy::iterator Find = ConvertedForStore.find(V);
468   if (Find != ConvertedForStore.end())
469     return Find->second;
470   if (isBufferFatPtrOrVector(From)) {
471     Value *Cast = IRB.CreatePtrToInt(V, To, Name + ".int");
472     ConvertedForStore[V] = Cast;
473     return Cast;
474   }
475   if (From->getNumContainedTypes() == 0)
476     return V;
477   // Structs, arrays, and other compound types.
478   Value *Ret = PoisonValue::get(To);
479   if (auto *AT = dyn_cast<ArrayType>(From)) {
480     Type *FromPart = AT->getArrayElementType();
481     Type *ToPart = cast<ArrayType>(To)->getElementType();
482     for (uint64_t I = 0, E = AT->getArrayNumElements(); I < E; ++I) {
483       Value *Field = IRB.CreateExtractValue(V, I);
484       Value *NewField =
485           fatPtrsToInts(Field, FromPart, ToPart, Name + "." + Twine(I));
486       Ret = IRB.CreateInsertValue(Ret, NewField, I);
487     }
488   } else {
489     for (auto [Idx, FromPart, ToPart] :
490          enumerate(From->subtypes(), To->subtypes())) {
491       Value *Field = IRB.CreateExtractValue(V, Idx);
492       Value *NewField =
493           fatPtrsToInts(Field, FromPart, ToPart, Name + "." + Twine(Idx));
494       Ret = IRB.CreateInsertValue(Ret, NewField, Idx);
495     }
496   }
497   ConvertedForStore[V] = Ret;
498   return Ret;
499 }
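// For example (an illustrative sketch; value names are not exact), for a value
// %v of type {ptr addrspace(7), float}, fatPtrsToInts emits roughly
//   %f0 = extractvalue {ptr addrspace(7), float} %v, 0
//   %f0.int = ptrtoint ptr addrspace(7) %f0 to i160
//   %r0 = insertvalue {i160, float} poison, i160 %f0.int, 0
//   %f1 = extractvalue {ptr addrspace(7), float} %v, 1
//   %r1 = insertvalue {i160, float} %r0, float %f1, 1
// while intsToFatPtrs below emits the mirror-image `inttoptr` sequence.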
500 
501 Value *StoreFatPtrsAsIntsVisitor::intsToFatPtrs(Value *V, Type *From, Type *To,
502                                                 const Twine &Name) {
503   if (From == To)
504     return V;
505   if (isBufferFatPtrOrVector(To)) {
506     Value *Cast = IRB.CreateIntToPtr(V, To, Name + ".ptr");
507     return Cast;
508   }
509   if (From->getNumContainedTypes() == 0)
510     return V;
511   // Structs, arrays, and other compound types.
512   Value *Ret = PoisonValue::get(To);
513   if (auto *AT = dyn_cast<ArrayType>(From)) {
514     Type *FromPart = AT->getArrayElementType();
515     Type *ToPart = cast<ArrayType>(To)->getElementType();
516     for (uint64_t I = 0, E = AT->getArrayNumElements(); I < E; ++I) {
517       Value *Field = IRB.CreateExtractValue(V, I);
518       Value *NewField =
519           intsToFatPtrs(Field, FromPart, ToPart, Name + "." + Twine(I));
520       Ret = IRB.CreateInsertValue(Ret, NewField, I);
521     }
522   } else {
523     for (auto [Idx, FromPart, ToPart] :
524          enumerate(From->subtypes(), To->subtypes())) {
525       Value *Field = IRB.CreateExtractValue(V, Idx);
526       Value *NewField =
527           intsToFatPtrs(Field, FromPart, ToPart, Name + "." + Twine(Idx));
528       Ret = IRB.CreateInsertValue(Ret, NewField, Idx);
529     }
530   }
531   return Ret;
532 }
533 
534 bool StoreFatPtrsAsIntsVisitor::processFunction(Function &F) {
535   bool Changed = false;
536   // The visitors will mutate GEPs and allocas, but will push loads and stores
537   // to the worklist to avoid invalidation.
538   for (Instruction &I : make_early_inc_range(instructions(F))) {
539     Changed |= visit(I);
540   }
541   ConvertedForStore.clear();
542   return Changed;
543 }
544 
545 bool StoreFatPtrsAsIntsVisitor::visitAllocaInst(AllocaInst &I) {
546   Type *Ty = I.getAllocatedType();
547   Type *NewTy = TypeMap->remapType(Ty);
548   if (Ty == NewTy)
549     return false;
550   I.setAllocatedType(NewTy);
551   return true;
552 }
553 
554 bool StoreFatPtrsAsIntsVisitor::visitGetElementPtrInst(GetElementPtrInst &I) {
555   Type *Ty = I.getSourceElementType();
556   Type *NewTy = TypeMap->remapType(Ty);
557   if (Ty == NewTy)
558     return false;
559   // We'll be rewriting the type `ptr addrspace(7)` out of existence soon, so
560   // make sure GEPs don't have different semantics with the new type.
561   I.setSourceElementType(NewTy);
562   I.setResultElementType(TypeMap->remapType(I.getResultElementType()));
563   return true;
564 }
565 
566 bool StoreFatPtrsAsIntsVisitor::visitLoadInst(LoadInst &LI) {
567   Type *Ty = LI.getType();
568   Type *IntTy = TypeMap->remapType(Ty);
569   if (Ty == IntTy)
570     return false;
571 
572   IRB.SetInsertPoint(&LI);
573   auto *NLI = cast<LoadInst>(LI.clone());
574   NLI->mutateType(IntTy);
575   NLI = IRB.Insert(NLI);
576   NLI->takeName(&LI);
577 
578   Value *CastBack = intsToFatPtrs(NLI, IntTy, Ty, NLI->getName());
579   LI.replaceAllUsesWith(CastBack);
580   LI.eraseFromParent();
581   return true;
582 }
583 
584 bool StoreFatPtrsAsIntsVisitor::visitStoreInst(StoreInst &SI) {
585   Value *V = SI.getValueOperand();
586   Type *Ty = V->getType();
587   Type *IntTy = TypeMap->remapType(Ty);
588   if (Ty == IntTy)
589     return false;
590 
591   IRB.SetInsertPoint(&SI);
592   Value *IntV = fatPtrsToInts(V, Ty, IntTy, V->getName());
593   for (auto *Dbg : at::getAssignmentMarkers(&SI))
594     Dbg->setValue(IntV);
595 
596   SI.setOperand(0, IntV);
597   return true;
598 }
599 
600 namespace {
601 /// Convert loads/stores of types that the buffer intrinsics can't handle into
602 /// one or more such loads/stores that consist of legal types.
603 ///
604 /// Do this by
605 /// 1. Recursing into structs (and arrays that don't share a memory layout with
606 /// vectors) since the intrinsics can't handle complex types.
607 /// 2. Converting arrays of non-aggregate, byte-sized types into their
608 /// corresponding vectors
609 /// 3. Bitcasting unsupported types, namely overly-long scalars and byte
610 /// vectors, into vectors of supported types.
611 /// 4. Splitting up excessively long reads/writes into multiple operations.
612 ///
613 /// Note that this doesn't handle complex data structures, but, in the future,
614 /// the aggregate load splitter from SROA could be refactored to allow for that
615 /// case.
616 class LegalizeBufferContentTypesVisitor
617     : public InstVisitor<LegalizeBufferContentTypesVisitor, bool> {
618   friend class InstVisitor<LegalizeBufferContentTypesVisitor, bool>;
619 
620   IRBuilder<> IRB;
621 
622   const DataLayout &DL;
623 
624   /// If T is [N x U], where U is a scalar type, return the vector type
625   /// <N x U>, otherwise, return T.
626   Type *scalarArrayTypeAsVector(Type *MaybeArrayType);
627   Value *arrayToVector(Value *V, Type *TargetType, const Twine &Name);
628   Value *vectorToArray(Value *V, Type *OrigType, const Twine &Name);
629 
630   /// Break up the loads of a struct into the loads of its components
631 
632   /// Convert a vector or scalar type that can't be operated on by buffer
633   /// intrinsics to one that would be legal through bitcasts and/or truncation.
634   /// Uses the wider of i32, i16, or i8 where possible.
635   Type *legalNonAggregateFor(Type *T);
636   Value *makeLegalNonAggregate(Value *V, Type *TargetType, const Twine &Name);
637   Value *makeIllegalNonAggregate(Value *V, Type *OrigType, const Twine &Name);
638 
639   struct VecSlice {
640     uint64_t Index = 0;
641     uint64_t Length = 0;
642     VecSlice() = delete;
643     // Needed for some Clangs
644     VecSlice(uint64_t Index, uint64_t Length) : Index(Index), Length(Length) {}
645   };
646   /// Return the [index, length] pairs into which `T` needs to be cut to form
647   /// legal buffer load or store operations. Clears `Slices`. Creates an empty
648   /// `Slices` for non-vector inputs and creates one slice if no slicing will be
649   /// needed.
650   void getVecSlices(Type *T, SmallVectorImpl<VecSlice> &Slices);
651 
652   Value *extractSlice(Value *Vec, VecSlice S, const Twine &Name);
653   Value *insertSlice(Value *Whole, Value *Part, VecSlice S, const Twine &Name);
654 
655   /// In most cases, return `LegalType`. However, when given an input that would
656   /// normally be a legal type for the buffer intrinsics to return but that
657   /// isn't hooked up through SelectionDAG, return a type of the same width that
658   /// can be used with the relevant intrinsics. Specifically, handle the cases:
659   /// - <1 x T> => T for all T
660   /// - <N x i8> <=> i16, i32, 2xi32, 4xi32 (as needed)
661   /// - <N x T> where T is under 32 bits and the total size is 96 bits <=> <3 x
662   /// i32>
663   Type *intrinsicTypeFor(Type *LegalType);
664 
665   bool visitLoadImpl(LoadInst &OrigLI, Type *PartType,
666                      SmallVectorImpl<uint32_t> &AggIdxs, uint64_t AggByteOffset,
667                      Value *&Result, const Twine &Name);
668   /// Return value is (Changed, ModifiedInPlace)
669   std::pair<bool, bool> visitStoreImpl(StoreInst &OrigSI, Type *PartType,
670                                        SmallVectorImpl<uint32_t> &AggIdxs,
671                                        uint64_t AggByteOffset,
672                                        const Twine &Name);
673 
674   bool visitInstruction(Instruction &I) { return false; }
675   bool visitLoadInst(LoadInst &LI);
676   bool visitStoreInst(StoreInst &SI);
677 
678 public:
679   LegalizeBufferContentTypesVisitor(const DataLayout &DL, LLVMContext &Ctx)
680       : IRB(Ctx), DL(DL) {}
681   bool processFunction(Function &F);
682 };
683 } // namespace
684 
685 Type *LegalizeBufferContentTypesVisitor::scalarArrayTypeAsVector(Type *T) {
686   ArrayType *AT = dyn_cast<ArrayType>(T);
687   if (!AT)
688     return T;
689   Type *ET = AT->getElementType();
690   if (!ET->isSingleValueType() || isa<VectorType>(ET))
691     report_fatal_error("loading non-scalar arrays from buffer fat pointers "
692                        "should have recursed");
693   if (!DL.typeSizeEqualsStoreSize(AT))
694     report_fatal_error(
695         "loading padded arrays from buffer fat pointers should have recursed");
696   return FixedVectorType::get(ET, AT->getNumElements());
697 }
698 
699 Value *LegalizeBufferContentTypesVisitor::arrayToVector(Value *V,
700                                                         Type *TargetType,
701                                                         const Twine &Name) {
702   Value *VectorRes = PoisonValue::get(TargetType);
703   auto *VT = cast<FixedVectorType>(TargetType);
704   unsigned EC = VT->getNumElements();
705   for (auto I : iota_range<unsigned>(0, EC, /*Inclusive=*/false)) {
706     Value *Elem = IRB.CreateExtractValue(V, I, Name + ".elem." + Twine(I));
707     VectorRes = IRB.CreateInsertElement(VectorRes, Elem, I,
708                                         Name + ".as.vec." + Twine(I));
709   }
710   return VectorRes;
711 }
712 
713 Value *LegalizeBufferContentTypesVisitor::vectorToArray(Value *V,
714                                                         Type *OrigType,
715                                                         const Twine &Name) {
716   Value *ArrayRes = PoisonValue::get(OrigType);
717   ArrayType *AT = cast<ArrayType>(OrigType);
718   unsigned EC = AT->getNumElements();
719   for (auto I : iota_range<unsigned>(0, EC, /*Inclusive=*/false)) {
720     Value *Elem = IRB.CreateExtractElement(V, I, Name + ".elem." + Twine(I));
721     ArrayRes = IRB.CreateInsertValue(ArrayRes, Elem, I,
722                                      Name + ".as.array." + Twine(I));
723   }
724   return ArrayRes;
725 }
726 
727 Type *LegalizeBufferContentTypesVisitor::legalNonAggregateFor(Type *T) {
728   TypeSize Size = DL.getTypeStoreSizeInBits(T);
729   // Implicitly zero-extend to the next byte if needed
730   if (!DL.typeSizeEqualsStoreSize(T))
731     T = IRB.getIntNTy(Size.getFixedValue());
732   Type *ElemTy = T->getScalarType();
733   if (isa<PointerType, ScalableVectorType>(ElemTy)) {
734     // Pointers are always big enough, and we'll let scalable vectors through to
735     // fail in codegen.
736     return T;
737   }
738   unsigned ElemSize = DL.getTypeSizeInBits(ElemTy).getFixedValue();
739   if (isPowerOf2_32(ElemSize) && ElemSize >= 16 && ElemSize <= 128) {
740     // [vectors of] anything that's 16/32/64/128 bits can be cast and split into
741     // legal buffer operations.
742     return T;
743   }
744   Type *BestVectorElemType = nullptr;
745   if (Size.isKnownMultipleOf(32))
746     BestVectorElemType = IRB.getInt32Ty();
747   else if (Size.isKnownMultipleOf(16))
748     BestVectorElemType = IRB.getInt16Ty();
749   else
750     BestVectorElemType = IRB.getInt8Ty();
751   unsigned NumCastElems =
752       Size.getFixedValue() / BestVectorElemType->getIntegerBitWidth();
753   if (NumCastElems == 1)
754     return BestVectorElemType;
755   return FixedVectorType::get(BestVectorElemType, NumCastElems);
756 }
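// For example (illustrative): an i96 becomes <3 x i32>, a <6 x i8> (48 bits)
// becomes <3 x i16>, while i64, <2 x float>, or <4 x i32> are already legal
// and are returned unchanged.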
757 
758 Value *LegalizeBufferContentTypesVisitor::makeLegalNonAggregate(
759     Value *V, Type *TargetType, const Twine &Name) {
760   Type *SourceType = V->getType();
761   TypeSize SourceSize = DL.getTypeSizeInBits(SourceType);
762   TypeSize TargetSize = DL.getTypeSizeInBits(TargetType);
763   if (SourceSize != TargetSize) {
764     Type *ShortScalarTy = IRB.getIntNTy(SourceSize.getFixedValue());
765     Type *ByteScalarTy = IRB.getIntNTy(TargetSize.getFixedValue());
766     Value *AsScalar = IRB.CreateBitCast(V, ShortScalarTy, Name + ".as.scalar");
767     Value *Zext = IRB.CreateZExt(AsScalar, ByteScalarTy, Name + ".zext");
768     V = Zext;
769     SourceType = ByteScalarTy;
770   }
771   return IRB.CreateBitCast(V, TargetType, Name + ".legal");
772 }
773 
774 Value *LegalizeBufferContentTypesVisitor::makeIllegalNonAggregate(
775     Value *V, Type *OrigType, const Twine &Name) {
776   Type *LegalType = V->getType();
777   TypeSize LegalSize = DL.getTypeSizeInBits(LegalType);
778   TypeSize OrigSize = DL.getTypeSizeInBits(OrigType);
779   if (LegalSize != OrigSize) {
780     Type *ShortScalarTy = IRB.getIntNTy(OrigSize.getFixedValue());
781     Type *ByteScalarTy = IRB.getIntNTy(LegalSize.getFixedValue());
782     Value *AsScalar = IRB.CreateBitCast(V, ByteScalarTy, Name + ".bytes.cast");
783     Value *Trunc = IRB.CreateTrunc(AsScalar, ShortScalarTy, Name + ".trunc");
784     return IRB.CreateBitCast(Trunc, OrigType, Name + ".orig");
785   }
786   return IRB.CreateBitCast(V, OrigType, Name + ".real.ty");
787 }
788 
789 Type *LegalizeBufferContentTypesVisitor::intrinsicTypeFor(Type *LegalType) {
790   auto *VT = dyn_cast<FixedVectorType>(LegalType);
791   if (!VT)
792     return LegalType;
793   Type *ET = VT->getElementType();
794   // Explicitly return the element type of 1-element vectors because the
795   // underlying intrinsics don't like <1 x T> even though it's a synonym for T.
796   if (VT->getNumElements() == 1)
797     return ET;
798   if (DL.getTypeSizeInBits(LegalType) == 96 && DL.getTypeSizeInBits(ET) < 32)
799     return FixedVectorType::get(IRB.getInt32Ty(), 3);
800   if (ET->isIntegerTy(8)) {
801     switch (VT->getNumElements()) {
802     default:
803       return LegalType; // Let it crash later
804     case 1:
805       return IRB.getInt8Ty();
806     case 2:
807       return IRB.getInt16Ty();
808     case 4:
809       return IRB.getInt32Ty();
810     case 8:
811       return FixedVectorType::get(IRB.getInt32Ty(), 2);
812     case 16:
813       return FixedVectorType::get(IRB.getInt32Ty(), 4);
814     }
815   }
816   return LegalType;
817 }
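// For example (illustrative): <1 x float> -> float, <4 x i8> -> i32,
// <8 x i8> -> <2 x i32>, and a 96-bit <6 x half> -> <3 x i32>; types such as
// <4 x i32> or <2 x half> are returned unchanged.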
818 
819 void LegalizeBufferContentTypesVisitor::getVecSlices(
820     Type *T, SmallVectorImpl<VecSlice> &Slices) {
821   Slices.clear();
822   auto *VT = dyn_cast<FixedVectorType>(T);
823   if (!VT)
824     return;
825 
826   uint64_t ElemBitWidth =
827       DL.getTypeSizeInBits(VT->getElementType()).getFixedValue();
828 
829   uint64_t ElemsPer4Words = 128 / ElemBitWidth;
830   uint64_t ElemsPer2Words = ElemsPer4Words / 2;
831   uint64_t ElemsPerWord = ElemsPer2Words / 2;
832   uint64_t ElemsPerShort = ElemsPerWord / 2;
833   uint64_t ElemsPerByte = ElemsPerShort / 2;
834   // If the elements evenly pack into 32-bit words, we can use 3-word stores,
835   // such as for <6 x bfloat> or <3 x i32>, but we can't do this for, for
836   // example, <3 x i64>, since that's not slicing.
837   uint64_t ElemsPer3Words = ElemsPerWord * 3;
838 
839   uint64_t TotalElems = VT->getNumElements();
840   uint64_t Index = 0;
841   auto TrySlice = [&](unsigned MaybeLen) {
842     if (MaybeLen > 0 && Index + MaybeLen <= TotalElems) {
843       VecSlice Slice{/*Index=*/Index, /*Length=*/MaybeLen};
844       Slices.push_back(Slice);
845       Index += MaybeLen;
846       return true;
847     }
848     return false;
849   };
850   while (Index < TotalElems) {
851     TrySlice(ElemsPer4Words) || TrySlice(ElemsPer3Words) ||
852         TrySlice(ElemsPer2Words) || TrySlice(ElemsPerWord) ||
853         TrySlice(ElemsPerShort) || TrySlice(ElemsPerByte);
854   }
855 }
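// For example (illustrative), slicing a <9 x float>: ElemsPer4Words is 4 and
// ElemsPerWord is 1, so the loop above produces the slices
//   {Index = 0, Length = 4}, {Index = 4, Length = 4}, {Index = 8, Length = 1}
// which later become two <4 x float> accesses and one scalar float access.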
856 
857 Value *LegalizeBufferContentTypesVisitor::extractSlice(Value *Vec, VecSlice S,
858                                                        const Twine &Name) {
859   auto *VecVT = dyn_cast<FixedVectorType>(Vec->getType());
860   if (!VecVT)
861     return Vec;
862   if (S.Length == VecVT->getNumElements() && S.Index == 0)
863     return Vec;
864   if (S.Length == 1)
865     return IRB.CreateExtractElement(Vec, S.Index,
866                                     Name + ".slice." + Twine(S.Index));
867   SmallVector<int> Mask = llvm::to_vector(
868       llvm::iota_range<int>(S.Index, S.Index + S.Length, /*Inclusive=*/false));
869   return IRB.CreateShuffleVector(Vec, Mask, Name + ".slice." + Twine(S.Index));
870 }
871 
872 Value *LegalizeBufferContentTypesVisitor::insertSlice(Value *Whole, Value *Part,
873                                                       VecSlice S,
874                                                       const Twine &Name) {
875   auto *WholeVT = dyn_cast<FixedVectorType>(Whole->getType());
876   if (!WholeVT)
877     return Part;
878   if (S.Length == WholeVT->getNumElements() && S.Index == 0)
879     return Part;
880   if (S.Length == 1) {
881     return IRB.CreateInsertElement(Whole, Part, S.Index,
882                                    Name + ".slice." + Twine(S.Index));
883   }
884   int NumElems = cast<FixedVectorType>(Whole->getType())->getNumElements();
885 
886   // Extend the slice with poisons to make the main shufflevector happy.
887   SmallVector<int> ExtPartMask(NumElems, -1);
888   for (auto [I, E] : llvm::enumerate(
889            MutableArrayRef<int>(ExtPartMask).take_front(S.Length))) {
890     E = I;
891   }
892   Value *ExtPart = IRB.CreateShuffleVector(Part, ExtPartMask,
893                                            Name + ".ext." + Twine(S.Index));
894 
895   SmallVector<int> Mask =
896       llvm::to_vector(llvm::iota_range<int>(0, NumElems, /*Inclusive=*/false));
897   for (auto [I, E] :
898        llvm::enumerate(MutableArrayRef<int>(Mask).slice(S.Index, S.Length)))
899     E = I + NumElems;
900   return IRB.CreateShuffleVector(Whole, ExtPart, Mask,
901                                  Name + ".parts." + Twine(S.Index));
902 }
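// For example (illustrative), inserting a 2-element slice at Index = 4 into an
// 8-element vector uses ExtPartMask = <0, 1, -1, -1, -1, -1, -1, -1> to widen
// the slice and Mask = <0, 1, 2, 3, 8, 9, 6, 7> to splice it into Whole.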
903 
904 bool LegalizeBufferContentTypesVisitor::visitLoadImpl(
905     LoadInst &OrigLI, Type *PartType, SmallVectorImpl<uint32_t> &AggIdxs,
906     uint64_t AggByteOff, Value *&Result, const Twine &Name) {
907   if (auto *ST = dyn_cast<StructType>(PartType)) {
908     const StructLayout *Layout = DL.getStructLayout(ST);
909     bool Changed = false;
910     for (auto [I, ElemTy, Offset] :
911          llvm::enumerate(ST->elements(), Layout->getMemberOffsets())) {
912       AggIdxs.push_back(I);
913       Changed |= visitLoadImpl(OrigLI, ElemTy, AggIdxs,
914                                AggByteOff + Offset.getFixedValue(), Result,
915                                Name + "." + Twine(I));
916       AggIdxs.pop_back();
917     }
918     return Changed;
919   }
920   if (auto *AT = dyn_cast<ArrayType>(PartType)) {
921     Type *ElemTy = AT->getElementType();
922     if (!ElemTy->isSingleValueType() || !DL.typeSizeEqualsStoreSize(ElemTy) ||
923         ElemTy->isVectorTy()) {
924       TypeSize ElemStoreSize = DL.getTypeStoreSize(ElemTy);
925       bool Changed = false;
926       for (auto I : llvm::iota_range<uint32_t>(0, AT->getNumElements(),
927                                                /*Inclusive=*/false)) {
928         AggIdxs.push_back(I);
929         Changed |= visitLoadImpl(OrigLI, ElemTy, AggIdxs,
930                                  AggByteOff + I * ElemStoreSize.getFixedValue(),
931                                  Result, Name + Twine(I));
932         AggIdxs.pop_back();
933       }
934       return Changed;
935     }
936   }
937 
938   // Typical case
939 
940   Type *ArrayAsVecType = scalarArrayTypeAsVector(PartType);
941   Type *LegalType = legalNonAggregateFor(ArrayAsVecType);
942 
943   SmallVector<VecSlice> Slices;
944   getVecSlices(LegalType, Slices);
945   bool HasSlices = Slices.size() > 1;
946   bool IsAggPart = !AggIdxs.empty();
947   Value *LoadsRes;
948   if (!HasSlices && !IsAggPart) {
949     Type *LoadableType = intrinsicTypeFor(LegalType);
950     if (LoadableType == PartType)
951       return false;
952 
953     IRB.SetInsertPoint(&OrigLI);
954     auto *NLI = cast<LoadInst>(OrigLI.clone());
955     NLI->mutateType(LoadableType);
956     NLI = IRB.Insert(NLI);
957     NLI->setName(Name + ".loadable");
958 
959     LoadsRes = IRB.CreateBitCast(NLI, LegalType, Name + ".from.loadable");
960   } else {
961     IRB.SetInsertPoint(&OrigLI);
962     LoadsRes = PoisonValue::get(LegalType);
963     Value *OrigPtr = OrigLI.getPointerOperand();
964     // If we need to split something into more than one load, its legal
965     // type will be a vector (ex. an i256 load will have LegalType = <8 x i32>).
966     // But if we're already a scalar (which can happen if we're splitting up a
967     // struct), the element type will be the legal type itself.
968     Type *ElemType = LegalType->getScalarType();
969     unsigned ElemBytes = DL.getTypeStoreSize(ElemType);
970     AAMDNodes AANodes = OrigLI.getAAMetadata();
971     if (IsAggPart && Slices.empty())
972       Slices.push_back(VecSlice{/*Index=*/0, /*Length=*/1});
973     for (VecSlice S : Slices) {
974       Type *SliceType =
975           S.Length != 1 ? FixedVectorType::get(ElemType, S.Length) : ElemType;
976       int64_t ByteOffset = AggByteOff + S.Index * ElemBytes;
977       // You can't reasonably expect loads to wrap around the edge of memory.
978       Value *NewPtr = IRB.CreateGEP(
979           IRB.getInt8Ty(), OrigLI.getPointerOperand(), IRB.getInt32(ByteOffset),
980           OrigPtr->getName() + ".off.ptr." + Twine(ByteOffset),
981           GEPNoWrapFlags::noUnsignedWrap());
982       Type *LoadableType = intrinsicTypeFor(SliceType);
983       LoadInst *NewLI = IRB.CreateAlignedLoad(
984           LoadableType, NewPtr, commonAlignment(OrigLI.getAlign(), ByteOffset),
985           Name + ".off." + Twine(ByteOffset));
986       copyMetadataForLoad(*NewLI, OrigLI);
987       NewLI->setAAMetadata(
988           AANodes.adjustForAccess(ByteOffset, LoadableType, DL));
989       NewLI->setAtomic(OrigLI.getOrdering(), OrigLI.getSyncScopeID());
990       NewLI->setVolatile(OrigLI.isVolatile());
991       Value *Loaded = IRB.CreateBitCast(NewLI, SliceType,
992                                         NewLI->getName() + ".from.loadable");
993       LoadsRes = insertSlice(LoadsRes, Loaded, S, Name);
994     }
995   }
996   if (LegalType != ArrayAsVecType)
997     LoadsRes = makeIllegalNonAggregate(LoadsRes, ArrayAsVecType, Name);
998   if (ArrayAsVecType != PartType)
999     LoadsRes = vectorToArray(LoadsRes, PartType, Name);
1000 
1001   if (IsAggPart)
1002     Result = IRB.CreateInsertValue(Result, LoadsRes, AggIdxs, Name);
1003   else
1004     Result = LoadsRes;
1005   return true;
1006 }
1007 
1008 bool LegalizeBufferContentTypesVisitor::visitLoadInst(LoadInst &LI) {
1009   if (LI.getPointerAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER)
1010     return false;
1011 
1012   SmallVector<uint32_t> AggIdxs;
1013   Type *OrigType = LI.getType();
1014   Value *Result = PoisonValue::get(OrigType);
1015   bool Changed = visitLoadImpl(LI, OrigType, AggIdxs, 0, Result, LI.getName());
1016   if (!Changed)
1017     return false;
1018   Result->takeName(&LI);
1019   LI.replaceAllUsesWith(Result);
1020   LI.eraseFromParent();
1021   return Changed;
1022 }
1023 
1024 std::pair<bool, bool> LegalizeBufferContentTypesVisitor::visitStoreImpl(
1025     StoreInst &OrigSI, Type *PartType, SmallVectorImpl<uint32_t> &AggIdxs,
1026     uint64_t AggByteOff, const Twine &Name) {
1027   if (auto *ST = dyn_cast<StructType>(PartType)) {
1028     const StructLayout *Layout = DL.getStructLayout(ST);
1029     bool Changed = false;
1030     for (auto [I, ElemTy, Offset] :
1031          llvm::enumerate(ST->elements(), Layout->getMemberOffsets())) {
1032       AggIdxs.push_back(I);
1033       Changed |= std::get<0>(visitStoreImpl(OrigSI, ElemTy, AggIdxs,
1034                                             AggByteOff + Offset.getFixedValue(),
1035                                             Name + "." + Twine(I)));
1036       AggIdxs.pop_back();
1037     }
1038     return std::make_pair(Changed, /*ModifiedInPlace=*/false);
1039   }
1040   if (auto *AT = dyn_cast<ArrayType>(PartType)) {
1041     Type *ElemTy = AT->getElementType();
1042     if (!ElemTy->isSingleValueType() || !DL.typeSizeEqualsStoreSize(ElemTy) ||
1043         ElemTy->isVectorTy()) {
1044       TypeSize ElemStoreSize = DL.getTypeStoreSize(ElemTy);
1045       bool Changed = false;
1046       for (auto I : llvm::iota_range<uint32_t>(0, AT->getNumElements(),
1047                                                /*Inclusive=*/false)) {
1048         AggIdxs.push_back(I);
1049         Changed |= std::get<0>(visitStoreImpl(
1050             OrigSI, ElemTy, AggIdxs,
1051             AggByteOff + I * ElemStoreSize.getFixedValue(), Name + Twine(I)));
1052         AggIdxs.pop_back();
1053       }
1054       return std::make_pair(Changed, /*ModifiedInPlace=*/false);
1055     }
1056   }
1057 
1058   Value *OrigData = OrigSI.getValueOperand();
1059   Value *NewData = OrigData;
1060 
1061   bool IsAggPart = !AggIdxs.empty();
1062   if (IsAggPart)
1063     NewData = IRB.CreateExtractValue(NewData, AggIdxs, Name);
1064 
1065   Type *ArrayAsVecType = scalarArrayTypeAsVector(PartType);
1066   if (ArrayAsVecType != PartType) {
1067     NewData = arrayToVector(NewData, ArrayAsVecType, Name);
1068   }
1069 
1070   Type *LegalType = legalNonAggregateFor(ArrayAsVecType);
1071   if (LegalType != ArrayAsVecType) {
1072     NewData = makeLegalNonAggregate(NewData, LegalType, Name);
1073   }
1074 
1075   SmallVector<VecSlice> Slices;
1076   getVecSlices(LegalType, Slices);
1077   bool NeedToSplit = Slices.size() > 1 || IsAggPart;
1078   if (!NeedToSplit) {
1079     Type *StorableType = intrinsicTypeFor(LegalType);
1080     if (StorableType == PartType)
1081       return std::make_pair(/*Changed=*/false, /*ModifiedInPlace=*/false);
1082     NewData = IRB.CreateBitCast(NewData, StorableType, Name + ".storable");
1083     OrigSI.setOperand(0, NewData);
1084     return std::make_pair(/*Changed=*/true, /*ModifiedInPlace=*/true);
1085   }
1086 
1087   Value *OrigPtr = OrigSI.getPointerOperand();
1088   Type *ElemType = LegalType->getScalarType();
1089   if (IsAggPart && Slices.empty())
1090     Slices.push_back(VecSlice{/*Index=*/0, /*Length=*/1});
1091   unsigned ElemBytes = DL.getTypeStoreSize(ElemType);
1092   AAMDNodes AANodes = OrigSI.getAAMetadata();
1093   for (VecSlice S : Slices) {
1094     Type *SliceType =
1095         S.Length != 1 ? FixedVectorType::get(ElemType, S.Length) : ElemType;
1096     int64_t ByteOffset = AggByteOff + S.Index * ElemBytes;
1097     Value *NewPtr =
1098         IRB.CreateGEP(IRB.getInt8Ty(), OrigPtr, IRB.getInt32(ByteOffset),
1099                       OrigPtr->getName() + ".part." + Twine(S.Index),
1100                       GEPNoWrapFlags::noUnsignedWrap());
1101     Value *DataSlice = extractSlice(NewData, S, Name);
1102     Type *StorableType = intrinsicTypeFor(SliceType);
1103     DataSlice = IRB.CreateBitCast(DataSlice, StorableType,
1104                                   DataSlice->getName() + ".storable");
1105     auto *NewSI = cast<StoreInst>(OrigSI.clone());
1106     NewSI->setAlignment(commonAlignment(OrigSI.getAlign(), ByteOffset));
1107     IRB.Insert(NewSI);
1108     NewSI->setOperand(0, DataSlice);
1109     NewSI->setOperand(1, NewPtr);
1110     NewSI->setAAMetadata(AANodes.adjustForAccess(ByteOffset, StorableType, DL));
1111   }
1112   return std::make_pair(/*Changed=*/true, /*ModifiedInPlace=*/false);
1113 }
1114 
1115 bool LegalizeBufferContentTypesVisitor::visitStoreInst(StoreInst &SI) {
1116   if (SI.getPointerAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER)
1117     return false;
1118   IRB.SetInsertPoint(&SI);
1119   SmallVector<uint32_t> AggIdxs;
1120   Value *OrigData = SI.getValueOperand();
1121   auto [Changed, ModifiedInPlace] =
1122       visitStoreImpl(SI, OrigData->getType(), AggIdxs, 0, OrigData->getName());
1123   if (Changed && !ModifiedInPlace)
1124     SI.eraseFromParent();
1125   return Changed;
1126 }
1127 
1128 bool LegalizeBufferContentTypesVisitor::processFunction(Function &F) {
1129   bool Changed = false;
1130   for (Instruction &I : make_early_inc_range(instructions(F))) {
1131     Changed |= visit(I);
1132   }
1133   return Changed;
1134 }
1135 
1136 /// Return the ptr addrspace(8) and i32 (resource and offset parts) in a lowered
1137 /// buffer fat pointer constant.
1138 static std::pair<Constant *, Constant *>
1139 splitLoweredFatBufferConst(Constant *C) {
1140   assert(isSplitFatPtr(C->getType()) && "Not a split fat buffer pointer");
1141   return std::make_pair(C->getAggregateElement(0u), C->getAggregateElement(1u));
1142 }
1143 
1144 namespace {
1145 /// Handle the remapping of ptr addrspace(7) constants.
1146 class FatPtrConstMaterializer final : public ValueMaterializer {
1147   BufferFatPtrToStructTypeMap *TypeMap;
1148   // An internal mapper that is used to recurse into the arguments of constants.
1149   // While the documentation for `ValueMapper` specifies not to use it
1150   // recursively, examination of the logic in mapValue() shows that it can
1151   // safely be used recursively when handling constants, like it does in its own
1152   // logic.
1153   ValueMapper InternalMapper;
1154 
1155   Constant *materializeBufferFatPtrConst(Constant *C);
1156 
1157 public:
1158   // UnderlyingMap is the value map this materializer will be filling.
1159   FatPtrConstMaterializer(BufferFatPtrToStructTypeMap *TypeMap,
1160                           ValueToValueMapTy &UnderlyingMap)
1161       : TypeMap(TypeMap),
1162         InternalMapper(UnderlyingMap, RF_None, TypeMap, this) {}
1163   virtual ~FatPtrConstMaterializer() = default;
1164 
1165   Value *materialize(Value *V) override;
1166 };
1167 } // namespace
1168 
1169 Constant *FatPtrConstMaterializer::materializeBufferFatPtrConst(Constant *C) {
1170   Type *SrcTy = C->getType();
1171   auto *NewTy = dyn_cast<StructType>(TypeMap->remapType(SrcTy));
1172   if (C->isNullValue())
1173     return ConstantAggregateZero::getNullValue(NewTy);
1174   if (isa<PoisonValue>(C)) {
1175     return ConstantStruct::get(NewTy,
1176                                {PoisonValue::get(NewTy->getElementType(0)),
1177                                 PoisonValue::get(NewTy->getElementType(1))});
1178   }
1179   if (isa<UndefValue>(C)) {
1180     return ConstantStruct::get(NewTy,
1181                                {UndefValue::get(NewTy->getElementType(0)),
1182                                 UndefValue::get(NewTy->getElementType(1))});
1183   }
1184 
1185   if (auto *VC = dyn_cast<ConstantVector>(C)) {
1186     if (Constant *S = VC->getSplatValue()) {
1187       Constant *NewS = InternalMapper.mapConstant(*S);
1188       if (!NewS)
1189         return nullptr;
1190       auto [Rsrc, Off] = splitLoweredFatBufferConst(NewS);
1191       auto EC = VC->getType()->getElementCount();
1192       return ConstantStruct::get(NewTy, {ConstantVector::getSplat(EC, Rsrc),
1193                                          ConstantVector::getSplat(EC, Off)});
1194     }
1195     SmallVector<Constant *> Rsrcs;
1196     SmallVector<Constant *> Offs;
1197     for (Value *Op : VC->operand_values()) {
1198       auto *NewOp = dyn_cast_or_null<Constant>(InternalMapper.mapValue(*Op));
1199       if (!NewOp)
1200         return nullptr;
1201       auto [Rsrc, Off] = splitLoweredFatBufferConst(NewOp);
1202       Rsrcs.push_back(Rsrc);
1203       Offs.push_back(Off);
1204     }
1205     Constant *RsrcVec = ConstantVector::get(Rsrcs);
1206     Constant *OffVec = ConstantVector::get(Offs);
1207     return ConstantStruct::get(NewTy, {RsrcVec, OffVec});
1208   }
1209 
1210   if (isa<GlobalValue>(C))
1211     report_fatal_error("Global values containing ptr addrspace(7) (buffer "
1212                        "fat pointer) values are not supported");
1213 
1214   if (isa<ConstantExpr>(C))
1215     report_fatal_error("Constant exprs containing ptr addrspace(7) (buffer "
1216                        "fat pointer) values should have been expanded earlier");
1217 
1218   return nullptr;
1219 }
1220 
1221 Value *FatPtrConstMaterializer::materialize(Value *V) {
1222   Constant *C = dyn_cast<Constant>(V);
1223   if (!C)
1224     return nullptr;
1225   // Structs and other types that happen to contain fat pointers get remapped
1226   // by the mapValue() logic.
1227   if (!isBufferFatPtrConst(C))
1228     return nullptr;
1229   return materializeBufferFatPtrConst(C);
1230 }
1231 
1232 using PtrParts = std::pair<Value *, Value *>;
1233 namespace {
1234 // The visitor returns the resource and offset parts for an instruction if they
1235 // can be computed, or (nullptr, nullptr) for cases that don't have a meaningful
1236 // value mapping.
1237 class SplitPtrStructs : public InstVisitor<SplitPtrStructs, PtrParts> {
1238   ValueToValueMapTy RsrcParts;
1239   ValueToValueMapTy OffParts;
1240 
1241   // Track instructions that have been rewritten into a user of the component
1242   // parts of their ptr addrspace(7) input. Instructions that produced
1243   // ptr addrspace(7) parts should **not** be RAUW'd before being added to this
1244   // set, as that replacement will be handled in a post-visit step. However,
1245   // instructions that yield values that aren't fat pointers (ex. ptrtoint)
1246   // should RAUW themselves with new instructions that use the split parts
1247   // of their arguments during processing.
1248   DenseSet<Instruction *> SplitUsers;
1249 
1250   // Nodes that need a second look once we've computed the parts for all other
1251   // instructions to see if, for example, we really need to phi on the resource
1252   // part.
1253   SmallVector<Instruction *> Conditionals;
1254   // Temporary instructions produced while lowering conditionals that should be
1255   // killed.
1256   SmallVector<Instruction *> ConditionalTemps;
1257 
1258   // Subtarget info, needed for determining what cache control bits to set.
1259   const TargetMachine *TM;
1260   const GCNSubtarget *ST = nullptr;
1261 
1262   IRBuilder<> IRB;
1263 
1264   // Copy metadata between instructions if applicable.
1265   void copyMetadata(Value *Dest, Value *Src);
1266 
1267   // Get the resource and offset parts of the value V, inserting appropriate
1268   // extractvalue calls if needed.
1269   PtrParts getPtrParts(Value *V);
1270 
1271   // Given an instruction that could produce multiple resource parts (a PHI or
1272   // select), collect the set of possible instructions that could have provided
1273   // its resource part (the `Roots`) and the set of
1274   // conditional instructions visited during the search (`Seen`). If, after
1275   // removing the root of the search from `Seen` and `Roots`, `Seen` is a subset
1276   // of `Roots` and `Roots - Seen` contains one element, the resource part of
1277   // that element can replace the resource part of all other elements in `Seen`.
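  //
  // A minimal sketch in terms of the original (pre-lowering) IR, with
  // hypothetical names: for a loop-carried pointer
  //   %p = phi ptr addrspace(7) [ %base, %entry ], [ %p.next, %loop ]
  //   %p.next = getelementptr i8, ptr addrspace(7) %p, i32 4
  // the roots are {%base, %p} (the GEP is looked through) and the seen set is
  // {%p}. After erasing %p from both sets, `Seen` is a subset of `Roots` and
  // `Roots - Seen` is {%base}, so %base's resource part can be reused for %p.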
1278   void getPossibleRsrcRoots(Instruction *I, SmallPtrSetImpl<Value *> &Roots,
1279                             SmallPtrSetImpl<Value *> &Seen);
1280   void processConditionals();
1281 
1282   // If an instruction has been split into resource and offset parts,
1283   // delete that instruction. If any of its uses have not themselves been split
1284   // into parts (for example, an insertvalue), construct the struct value that
1285   // the type rewrites declare the dying instruction should have produced and
1286   // use that value instead.
1287   // Also, kill the temporary extractvalue operations produced by the two-stage
1288   // lowering of PHIs and conditionals.
1289   void killAndReplaceSplitInstructions(SmallVectorImpl<Instruction *> &Origs);
1290 
1291   void setAlign(CallInst *Intr, Align A, unsigned RsrcArgIdx);
1292   void insertPreMemOpFence(AtomicOrdering Order, SyncScope::ID SSID);
1293   void insertPostMemOpFence(AtomicOrdering Order, SyncScope::ID SSID);
1294   Value *handleMemoryInst(Instruction *I, Value *Arg, Value *Ptr, Type *Ty,
1295                           Align Alignment, AtomicOrdering Order,
1296                           bool IsVolatile, SyncScope::ID SSID);
1297 
1298 public:
1299   SplitPtrStructs(LLVMContext &Ctx, const TargetMachine *TM)
1300       : TM(TM), IRB(Ctx) {}
1301 
1302   void processFunction(Function &F);
1303 
1304   PtrParts visitInstruction(Instruction &I);
1305   PtrParts visitLoadInst(LoadInst &LI);
1306   PtrParts visitStoreInst(StoreInst &SI);
1307   PtrParts visitAtomicRMWInst(AtomicRMWInst &AI);
1308   PtrParts visitAtomicCmpXchgInst(AtomicCmpXchgInst &AI);
1309   PtrParts visitGetElementPtrInst(GetElementPtrInst &GEP);
1310 
1311   PtrParts visitPtrToIntInst(PtrToIntInst &PI);
1312   PtrParts visitIntToPtrInst(IntToPtrInst &IP);
1313   PtrParts visitAddrSpaceCastInst(AddrSpaceCastInst &I);
1314   PtrParts visitICmpInst(ICmpInst &Cmp);
1315   PtrParts visitFreezeInst(FreezeInst &I);
1316 
1317   PtrParts visitExtractElementInst(ExtractElementInst &I);
1318   PtrParts visitInsertElementInst(InsertElementInst &I);
1319   PtrParts visitShuffleVectorInst(ShuffleVectorInst &I);
1320 
1321   PtrParts visitPHINode(PHINode &PHI);
1322   PtrParts visitSelectInst(SelectInst &SI);
1323 
1324   PtrParts visitIntrinsicInst(IntrinsicInst &II);
1325 };
1326 } // namespace
1327 
1328 void SplitPtrStructs::copyMetadata(Value *Dest, Value *Src) {
1329   auto *DestI = dyn_cast<Instruction>(Dest);
1330   auto *SrcI = dyn_cast<Instruction>(Src);
1331 
1332   if (!DestI || !SrcI)
1333     return;
1334 
1335   DestI->copyMetadata(*SrcI);
1336 }
1337 
1338 PtrParts SplitPtrStructs::getPtrParts(Value *V) {
1339   assert(isSplitFatPtr(V->getType()) && "it's not meaningful to get the parts "
1340                                         "of something that wasn't rewritten");
1341   auto *RsrcEntry = &RsrcParts[V];
1342   auto *OffEntry = &OffParts[V];
1343   if (*RsrcEntry && *OffEntry)
1344     return {*RsrcEntry, *OffEntry};
1345 
1346   if (auto *C = dyn_cast<Constant>(V)) {
1347     auto [Rsrc, Off] = splitLoweredFatBufferConst(C);
1348     return {*RsrcEntry = Rsrc, *OffEntry = Off};
1349   }
1350 
1351   IRBuilder<>::InsertPointGuard Guard(IRB);
1352   if (auto *I = dyn_cast<Instruction>(V)) {
1353     LLVM_DEBUG(dbgs() << "Recursing to split parts of " << *I << "\n");
1354     auto [Rsrc, Off] = visit(*I);
1355     if (Rsrc && Off)
1356       return {*RsrcEntry = Rsrc, *OffEntry = Off};
1357     // We'll be creating the new values after the relevant instruction.
1358     // This instruction generates a value and so isn't a terminator.
1359     IRB.SetInsertPoint(*I->getInsertionPointAfterDef());
1360     IRB.SetCurrentDebugLocation(I->getDebugLoc());
1361   } else if (auto *A = dyn_cast<Argument>(V)) {
1362     IRB.SetInsertPointPastAllocas(A->getParent());
1363     IRB.SetCurrentDebugLocation(DebugLoc());
1364   }
1365   Value *Rsrc = IRB.CreateExtractValue(V, 0, V->getName() + ".rsrc");
1366   Value *Off = IRB.CreateExtractValue(V, 1, V->getName() + ".off");
1367   return {*RsrcEntry = Rsrc, *OffEntry = Off};
1368 }
1369 
1370 /// Returns the value that defines the resource part of the value V.
1371 /// Note that this is not getUnderlyingObject(), since that looks through
1372 /// operations like ptrmask which might modify the resource part.
1373 ///
1374 /// We can limit ourselves to just looking through GEPs followed by looking
1375 /// through addrspacecasts because only those two operations preserve the
1376 /// resource part, and because operations on an `addrspace(8)` (which is the
1377 /// legal input to this addrspacecast) would produce a different resource part.
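///
/// For example (illustrative, hypothetical names), given
///   %cast = addrspacecast ptr addrspace(8) %rsrc to ptr addrspace(7)
///   %gep = getelementptr i8, ptr addrspace(7) %cast, i32 %i
/// rsrcPartRoot(%gep) is %rsrc: the GEP and then the addrspacecast are looked
/// through.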
1378 static Value *rsrcPartRoot(Value *V) {
1379   while (auto *GEP = dyn_cast<GEPOperator>(V))
1380     V = GEP->getPointerOperand();
1381   while (auto *ASC = dyn_cast<AddrSpaceCastOperator>(V))
1382     V = ASC->getPointerOperand();
1383   return V;
1384 }
1385 
1386 void SplitPtrStructs::getPossibleRsrcRoots(Instruction *I,
1387                                            SmallPtrSetImpl<Value *> &Roots,
1388                                            SmallPtrSetImpl<Value *> &Seen) {
1389   if (auto *PHI = dyn_cast<PHINode>(I)) {
1390     if (!Seen.insert(I).second)
1391       return;
1392     for (Value *In : PHI->incoming_values()) {
1393       In = rsrcPartRoot(In);
1394       Roots.insert(In);
1395       if (isa<PHINode, SelectInst>(In))
1396         getPossibleRsrcRoots(cast<Instruction>(In), Roots, Seen);
1397     }
1398   } else if (auto *SI = dyn_cast<SelectInst>(I)) {
1399     if (!Seen.insert(SI).second)
1400       return;
1401     Value *TrueVal = rsrcPartRoot(SI->getTrueValue());
1402     Value *FalseVal = rsrcPartRoot(SI->getFalseValue());
1403     Roots.insert(TrueVal);
1404     Roots.insert(FalseVal);
1405     if (isa<PHINode, SelectInst>(TrueVal))
1406       getPossibleRsrcRoots(cast<Instruction>(TrueVal), Roots, Seen);
1407     if (isa<PHINode, SelectInst>(FalseVal))
1408       getPossibleRsrcRoots(cast<Instruction>(FalseVal), Roots, Seen);
1409   } else {
1410     llvm_unreachable("getPossibleRsrcRoots() only works on phi and select");
1411   }
1412 }
1413 
1414 void SplitPtrStructs::processConditionals() {
1415   SmallDenseMap<Instruction *, Value *> FoundRsrcs;
1416   SmallPtrSet<Value *, 4> Roots;
1417   SmallPtrSet<Value *, 4> Seen;
1418   for (Instruction *I : Conditionals) {
1419     // These have to exist by now because we've visited these nodes.
1420     Value *Rsrc = RsrcParts[I];
1421     Value *Off = OffParts[I];
1422     assert(Rsrc && Off && "must have visited conditionals by now");
1423 
1424     std::optional<Value *> MaybeRsrc;
1425     auto MaybeFoundRsrc = FoundRsrcs.find(I);
1426     if (MaybeFoundRsrc != FoundRsrcs.end()) {
1427       MaybeRsrc = MaybeFoundRsrc->second;
1428     } else {
1429       IRBuilder<>::InsertPointGuard Guard(IRB);
1430       Roots.clear();
1431       Seen.clear();
1432       getPossibleRsrcRoots(I, Roots, Seen);
1433       LLVM_DEBUG(dbgs() << "Processing conditional: " << *I << "\n");
1434 #ifndef NDEBUG
1435       for (Value *V : Roots)
1436         LLVM_DEBUG(dbgs() << "Root: " << *V << "\n");
1437       for (Value *V : Seen)
1438         LLVM_DEBUG(dbgs() << "Seen: " << *V << "\n");
1439 #endif
1440       // If we are our own possible root, then we shouldn't block our
1441       // replacement with a valid incoming value.
1442       Roots.erase(I);
1443       // We don't want to block the optimization for conditionals that don't
1444       // refer to themselves but did see themselves during the traversal.
1445       Seen.erase(I);
1446 
1447       if (set_is_subset(Seen, Roots)) {
1448         auto Diff = set_difference(Roots, Seen);
1449         if (Diff.size() == 1) {
1450           Value *RootVal = *Diff.begin();
1451           // Handle the case where previous loops already looked through
1452           // an addrspacecast.
1453           if (isSplitFatPtr(RootVal->getType()))
1454             MaybeRsrc = std::get<0>(getPtrParts(RootVal));
1455           else
1456             MaybeRsrc = RootVal;
1457         }
1458       }
1459     }
1460 
1461     if (auto *PHI = dyn_cast<PHINode>(I)) {
1462       Value *NewRsrc;
1463       StructType *PHITy = cast<StructType>(PHI->getType());
1464       IRB.SetInsertPoint(*PHI->getInsertionPointAfterDef());
1465       IRB.SetCurrentDebugLocation(PHI->getDebugLoc());
1466       if (MaybeRsrc) {
1467         NewRsrc = *MaybeRsrc;
1468       } else {
1469         Type *RsrcTy = PHITy->getElementType(0);
1470         auto *RsrcPHI = IRB.CreatePHI(RsrcTy, PHI->getNumIncomingValues());
1471         RsrcPHI->takeName(Rsrc);
1472         for (auto [V, BB] : llvm::zip(PHI->incoming_values(), PHI->blocks())) {
1473           Value *VRsrc = std::get<0>(getPtrParts(V));
1474           RsrcPHI->addIncoming(VRsrc, BB);
1475         }
1476         copyMetadata(RsrcPHI, PHI);
1477         NewRsrc = RsrcPHI;
1478       }
1479 
1480       Type *OffTy = PHITy->getElementType(1);
1481       auto *NewOff = IRB.CreatePHI(OffTy, PHI->getNumIncomingValues());
1482       NewOff->takeName(Off);
1483       for (auto [V, BB] : llvm::zip(PHI->incoming_values(), PHI->blocks())) {
1484         assert(OffParts.count(V) && "An offset part had to be created by now");
1485         Value *VOff = std::get<1>(getPtrParts(V));
1486         NewOff->addIncoming(VOff, BB);
1487       }
1488       copyMetadata(NewOff, PHI);
1489 
1490       // Note: We don't eraseFromParent() the temporaries here because we don't
1491       // want to put the correction maps in an inconsistent state. That'll be
1492       // handled during the rest of the killing. Also, `ValueToValueMapTy`
1493       // guarantees that references in those maps will be updated as well.
1494       ConditionalTemps.push_back(cast<Instruction>(Rsrc));
1495       ConditionalTemps.push_back(cast<Instruction>(Off));
1496       Rsrc->replaceAllUsesWith(NewRsrc);
1497       Off->replaceAllUsesWith(NewOff);
1498 
1499       // Save on recomputing the cycle traversals in known-root cases.
1500       if (MaybeRsrc)
1501         for (Value *V : Seen)
1502           FoundRsrcs[cast<Instruction>(V)] = NewRsrc;
1503     } else if (isa<SelectInst>(I)) {
1504       if (MaybeRsrc) {
1505         ConditionalTemps.push_back(cast<Instruction>(Rsrc));
1506         Rsrc->replaceAllUsesWith(*MaybeRsrc);
1507         for (Value *V : Seen)
1508           FoundRsrcs[cast<Instruction>(V)] = *MaybeRsrc;
1509       }
1510     } else {
1511       llvm_unreachable("Only PHIs and selects go in the conditionals list");
1512     }
1513   }
1514 }
1515 
1516 void SplitPtrStructs::killAndReplaceSplitInstructions(
1517     SmallVectorImpl<Instruction *> &Origs) {
1518   for (Instruction *I : ConditionalTemps)
1519     I->eraseFromParent();
1520 
1521   for (Instruction *I : Origs) {
1522     if (!SplitUsers.contains(I))
1523       continue;
1524 
1525     SmallVector<DbgValueInst *> Dbgs;
1526     findDbgValues(Dbgs, I);
1527     for (auto *Dbg : Dbgs) {
1528       IRB.SetInsertPoint(Dbg);
1529       auto &DL = I->getDataLayout();
1530       assert(isSplitFatPtr(I->getType()) &&
1531              "We should've RAUW'd away loads, stores, etc. at this point");
1532       auto *OffDbg = cast<DbgValueInst>(Dbg->clone());
1533       copyMetadata(OffDbg, Dbg);
1534       auto [Rsrc, Off] = getPtrParts(I);
1535 
1536       int64_t RsrcSz = DL.getTypeSizeInBits(Rsrc->getType());
1537       int64_t OffSz = DL.getTypeSizeInBits(Off->getType());
1538 
1539       std::optional<DIExpression *> RsrcExpr =
1540           DIExpression::createFragmentExpression(Dbg->getExpression(), 0,
1541                                                  RsrcSz);
1542       std::optional<DIExpression *> OffExpr =
1543           DIExpression::createFragmentExpression(Dbg->getExpression(), RsrcSz,
1544                                                  OffSz);
1545       if (OffExpr) {
1546         OffDbg->setExpression(*OffExpr);
1547         OffDbg->replaceVariableLocationOp(I, Off);
1548         IRB.Insert(OffDbg);
1549       } else {
1550         OffDbg->deleteValue();
1551       }
1552       if (RsrcExpr) {
1553         Dbg->setExpression(*RsrcExpr);
1554         Dbg->replaceVariableLocationOp(I, Rsrc);
1555       } else {
1556         Dbg->replaceVariableLocationOp(I, UndefValue::get(I->getType()));
1557       }
1558     }
1559 
1560     Value *Poison = PoisonValue::get(I->getType());
1561     I->replaceUsesWithIf(Poison, [&](const Use &U) -> bool {
1562       if (const auto *UI = dyn_cast<Instruction>(U.getUser()))
1563         return SplitUsers.contains(UI);
1564       return false;
1565     });
1566 
1567     if (I->use_empty()) {
1568       I->eraseFromParent();
1569       continue;
1570     }
1571     IRB.SetInsertPoint(*I->getInsertionPointAfterDef());
1572     IRB.SetCurrentDebugLocation(I->getDebugLoc());
1573     auto [Rsrc, Off] = getPtrParts(I);
1574     Value *Struct = PoisonValue::get(I->getType());
1575     Struct = IRB.CreateInsertValue(Struct, Rsrc, 0);
1576     Struct = IRB.CreateInsertValue(Struct, Off, 1);
1577     copyMetadata(Struct, I);
1578     Struct->takeName(I);
1579     I->replaceAllUsesWith(Struct);
1580     I->eraseFromParent();
1581   }
1582 }
1583 
1584 void SplitPtrStructs::setAlign(CallInst *Intr, Align A, unsigned RsrcArgIdx) {
1585   LLVMContext &Ctx = Intr->getContext();
1586   Intr->addParamAttr(RsrcArgIdx, Attribute::getWithAlignment(Ctx, A));
1587 }
1588 
1589 void SplitPtrStructs::insertPreMemOpFence(AtomicOrdering Order,
1590                                           SyncScope::ID SSID) {
1591   switch (Order) {
1592   case AtomicOrdering::Release:
1593   case AtomicOrdering::AcquireRelease:
1594   case AtomicOrdering::SequentiallyConsistent:
1595     IRB.CreateFence(AtomicOrdering::Release, SSID);
1596     break;
1597   default:
1598     break;
1599   }
1600 }
1601 
1602 void SplitPtrStructs::insertPostMemOpFence(AtomicOrdering Order,
1603                                            SyncScope::ID SSID) {
1604   switch (Order) {
1605   case AtomicOrdering::Acquire:
1606   case AtomicOrdering::AcquireRelease:
1607   case AtomicOrdering::SequentiallyConsistent:
1608     IRB.CreateFence(AtomicOrdering::Acquire, SSID);
1609     break;
1610   default:
1611     break;
1612   }
1613 }
1614 
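// A rough sketch of the rewrite this performs (illustrative, hypothetical
// names): a simple non-atomic
//   %v = load i32, ptr addrspace(7) %p
// becomes approximately
//   %v = call i32 @llvm.amdgcn.raw.ptr.buffer.load.i32(
//            ptr addrspace(8) %p.rsrc, i32 %p.off,
//            i32 0 /*soffset*/, i32 0 /*aux*/)
// with the pre/post fences inserted around the call for atomic orderings.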
1615 Value *SplitPtrStructs::handleMemoryInst(Instruction *I, Value *Arg, Value *Ptr,
1616                                          Type *Ty, Align Alignment,
1617                                          AtomicOrdering Order, bool IsVolatile,
1618                                          SyncScope::ID SSID) {
1619   IRB.SetInsertPoint(I);
1620 
1621   auto [Rsrc, Off] = getPtrParts(Ptr);
1622   SmallVector<Value *, 5> Args;
1623   if (Arg)
1624     Args.push_back(Arg);
1625   Args.push_back(Rsrc);
1626   Args.push_back(Off);
1627   insertPreMemOpFence(Order, SSID);
1628   // soffset is always 0 in these cases: we always want any offset to be part
1629   // of bounds checking, and we don't know which parts of the GEPs are
1630   // uniform.
1631   Args.push_back(IRB.getInt32(0));
1632 
1633   uint32_t Aux = 0;
1634   if (IsVolatile)
1635     Aux |= AMDGPU::CPol::VOLATILE;
1636   Args.push_back(IRB.getInt32(Aux));
1637 
1638   Intrinsic::ID IID = Intrinsic::not_intrinsic;
1639   if (isa<LoadInst>(I))
1640     IID = Order == AtomicOrdering::NotAtomic
1641               ? Intrinsic::amdgcn_raw_ptr_buffer_load
1642               : Intrinsic::amdgcn_raw_ptr_atomic_buffer_load;
1643   else if (isa<StoreInst>(I))
1644     IID = Intrinsic::amdgcn_raw_ptr_buffer_store;
1645   else if (auto *RMW = dyn_cast<AtomicRMWInst>(I)) {
1646     switch (RMW->getOperation()) {
1647     case AtomicRMWInst::Xchg:
1648       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_swap;
1649       break;
1650     case AtomicRMWInst::Add:
1651       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_add;
1652       break;
1653     case AtomicRMWInst::Sub:
1654       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_sub;
1655       break;
1656     case AtomicRMWInst::And:
1657       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_and;
1658       break;
1659     case AtomicRMWInst::Or:
1660       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_or;
1661       break;
1662     case AtomicRMWInst::Xor:
1663       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_xor;
1664       break;
1665     case AtomicRMWInst::Max:
1666       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_smax;
1667       break;
1668     case AtomicRMWInst::Min:
1669       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_smin;
1670       break;
1671     case AtomicRMWInst::UMax:
1672       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_umax;
1673       break;
1674     case AtomicRMWInst::UMin:
1675       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_umin;
1676       break;
1677     case AtomicRMWInst::FAdd:
1678       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_fadd;
1679       break;
1680     case AtomicRMWInst::FMax:
1681       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_fmax;
1682       break;
1683     case AtomicRMWInst::FMin:
1684       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_fmin;
1685       break;
1686     case AtomicRMWInst::FSub: {
1687       report_fatal_error("atomic floating point subtraction not supported for "
1688                          "buffer resources and should've been expanded away");
1689       break;
1690     }
1691     case AtomicRMWInst::Nand:
1692       report_fatal_error("atomic nand not supported for buffer resources and "
1693                          "should've been expanded away");
1694       break;
1695     case AtomicRMWInst::UIncWrap:
1696     case AtomicRMWInst::UDecWrap:
1697       report_fatal_error("wrapping increment/decrement not supported for "
1698                          "buffer resources and should've been expanded away");
1699       break;
1700     case AtomicRMWInst::BAD_BINOP:
1701       llvm_unreachable("Not sure how we got a bad binop");
1702     case AtomicRMWInst::USubCond:
1703     case AtomicRMWInst::USubSat:
1704       break;
1705     }
1706   }
1707 
1708   auto *Call = IRB.CreateIntrinsic(IID, Ty, Args);
1709   copyMetadata(Call, I);
1710   setAlign(Call, Alignment, Arg ? 1 : 0);
1711   Call->takeName(I);
1712 
1713   insertPostMemOpFence(Order, SSID);
1714   // The "no moving p7 directly" rewrites ensure that this load or store won't
1715   // itself need to be split into parts.
1716   SplitUsers.insert(I);
1717   I->replaceAllUsesWith(Call);
1718   return Call;
1719 }
1720 
1721 PtrParts SplitPtrStructs::visitInstruction(Instruction &I) {
1722   return {nullptr, nullptr};
1723 }
1724 
1725 PtrParts SplitPtrStructs::visitLoadInst(LoadInst &LI) {
1726   if (!isSplitFatPtr(LI.getPointerOperandType()))
1727     return {nullptr, nullptr};
1728   handleMemoryInst(&LI, nullptr, LI.getPointerOperand(), LI.getType(),
1729                    LI.getAlign(), LI.getOrdering(), LI.isVolatile(),
1730                    LI.getSyncScopeID());
1731   return {nullptr, nullptr};
1732 }
1733 
1734 PtrParts SplitPtrStructs::visitStoreInst(StoreInst &SI) {
1735   if (!isSplitFatPtr(SI.getPointerOperandType()))
1736     return {nullptr, nullptr};
1737   Value *Arg = SI.getValueOperand();
1738   handleMemoryInst(&SI, Arg, SI.getPointerOperand(), Arg->getType(),
1739                    SI.getAlign(), SI.getOrdering(), SI.isVolatile(),
1740                    SI.getSyncScopeID());
1741   return {nullptr, nullptr};
1742 }
1743 
1744 PtrParts SplitPtrStructs::visitAtomicRMWInst(AtomicRMWInst &AI) {
1745   if (!isSplitFatPtr(AI.getPointerOperand()->getType()))
1746     return {nullptr, nullptr};
1747   Value *Arg = AI.getValOperand();
1748   handleMemoryInst(&AI, Arg, AI.getPointerOperand(), Arg->getType(),
1749                    AI.getAlign(), AI.getOrdering(), AI.isVolatile(),
1750                    AI.getSyncScopeID());
1751   return {nullptr, nullptr};
1752 }
1753 
1754 // Unlike load, store, and RMW, cmpxchg needs special handling to account
1755 // for the boolean success flag in its result.
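// Roughly (illustrative), `cmpxchg ptr addrspace(7) %p, i32 %old, i32 %new`
// becomes a call to @llvm.amdgcn.raw.ptr.buffer.atomic.cmpswap with the
// arguments (%new, %old, %p.rsrc, %p.off, 0, aux), and the {i32, i1} result is
// rebuilt with insertvalue plus an icmp against %old for the success flag.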
1756 PtrParts SplitPtrStructs::visitAtomicCmpXchgInst(AtomicCmpXchgInst &AI) {
1757   Value *Ptr = AI.getPointerOperand();
1758   if (!isSplitFatPtr(Ptr->getType()))
1759     return {nullptr, nullptr};
1760   IRB.SetInsertPoint(&AI);
1761 
1762   Type *Ty = AI.getNewValOperand()->getType();
1763   AtomicOrdering Order = AI.getMergedOrdering();
1764   SyncScope::ID SSID = AI.getSyncScopeID();
1765   bool IsNonTemporal = AI.getMetadata(LLVMContext::MD_nontemporal);
1766 
1767   auto [Rsrc, Off] = getPtrParts(Ptr);
1768   insertPreMemOpFence(Order, SSID);
1769 
1770   uint32_t Aux = 0;
1771   if (IsNonTemporal)
1772     Aux |= AMDGPU::CPol::SLC;
1773   if (AI.isVolatile())
1774     Aux |= AMDGPU::CPol::VOLATILE;
1775   auto *Call =
1776       IRB.CreateIntrinsic(Intrinsic::amdgcn_raw_ptr_buffer_atomic_cmpswap, Ty,
1777                           {AI.getNewValOperand(), AI.getCompareOperand(), Rsrc,
1778                            Off, IRB.getInt32(0), IRB.getInt32(Aux)});
1779   copyMetadata(Call, &AI);
1780   setAlign(Call, AI.getAlign(), 2);
1781   Call->takeName(&AI);
1782   insertPostMemOpFence(Order, SSID);
1783 
1784   Value *Res = PoisonValue::get(AI.getType());
1785   Res = IRB.CreateInsertValue(Res, Call, 0);
1786   if (!AI.isWeak()) {
1787     Value *Succeeded = IRB.CreateICmpEQ(Call, AI.getCompareOperand());
1788     Res = IRB.CreateInsertValue(Res, Succeeded, 1);
1789   }
1790   SplitUsers.insert(&AI);
1791   AI.replaceAllUsesWith(Res);
1792   return {nullptr, nullptr};
1793 }
1794 
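// Illustrative sketch (hypothetical names): a GEP that originally computed
//   %q = getelementptr i32, ptr addrspace(7) %p, i32 %i
// keeps the resource part %p.rsrc unchanged and becomes an offset update
// roughly of the form
//   %q.idx = mul i32 %i, 4
//   %q.off = add i32 %p.off, %q.idx
// with the offset arithmetic produced by emitGEPOffset().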
1795 PtrParts SplitPtrStructs::visitGetElementPtrInst(GetElementPtrInst &GEP) {
1796   using namespace llvm::PatternMatch;
1797   Value *Ptr = GEP.getPointerOperand();
1798   if (!isSplitFatPtr(Ptr->getType()))
1799     return {nullptr, nullptr};
1800   IRB.SetInsertPoint(&GEP);
1801 
1802   auto [Rsrc, Off] = getPtrParts(Ptr);
1803   const DataLayout &DL = GEP.getDataLayout();
1804   bool IsNUW = GEP.hasNoUnsignedWrap();
1805   bool IsNUSW = GEP.hasNoUnsignedSignedWrap();
1806 
1807   // In order to call emitGEPOffset() and thus not have to reimplement it,
1808   // we need the GEP result to have ptr addrspace(7) type.
1809   Type *FatPtrTy = IRB.getPtrTy(AMDGPUAS::BUFFER_FAT_POINTER);
1810   if (auto *VT = dyn_cast<VectorType>(Off->getType()))
1811     FatPtrTy = VectorType::get(FatPtrTy, VT->getElementCount());
1812   GEP.mutateType(FatPtrTy);
1813   Value *OffAccum = emitGEPOffset(&IRB, DL, &GEP);
1814   GEP.mutateType(Ptr->getType());
1815   if (match(OffAccum, m_Zero())) { // Constant-zero offset
1816     SplitUsers.insert(&GEP);
1817     return {Rsrc, Off};
1818   }
1819 
1820   bool HasNonNegativeOff = false;
1821   if (auto *CI = dyn_cast<ConstantInt>(OffAccum)) {
1822     HasNonNegativeOff = !CI->isNegative();
1823   }
1824   Value *NewOff;
1825   if (match(Off, m_Zero())) {
1826     NewOff = OffAccum;
1827   } else {
1828     NewOff = IRB.CreateAdd(Off, OffAccum, "",
1829                            /*hasNUW=*/IsNUW || (IsNUSW && HasNonNegativeOff),
1830                            /*hasNSW=*/false);
1831   }
1832   copyMetadata(NewOff, &GEP);
1833   NewOff->takeName(&GEP);
1834   SplitUsers.insert(&GEP);
1835   return {Rsrc, NewOff};
1836 }
1837 
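// Illustrative sketch: with a 32-bit offset part, `ptrtoint ptr addrspace(7)
// %p to i160` roughly becomes `(zext(ptrtoint %p.rsrc) << 32) | zext(%p.off)`,
// while results of 32 bits or fewer are taken from the offset part alone.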
1838 PtrParts SplitPtrStructs::visitPtrToIntInst(PtrToIntInst &PI) {
1839   Value *Ptr = PI.getPointerOperand();
1840   if (!isSplitFatPtr(Ptr->getType()))
1841     return {nullptr, nullptr};
1842   IRB.SetInsertPoint(&PI);
1843 
1844   Type *ResTy = PI.getType();
1845   unsigned Width = ResTy->getScalarSizeInBits();
1846 
1847   auto [Rsrc, Off] = getPtrParts(Ptr);
1848   const DataLayout &DL = PI.getDataLayout();
1849   unsigned FatPtrWidth = DL.getPointerSizeInBits(AMDGPUAS::BUFFER_FAT_POINTER);
1850 
1851   Value *Res;
1852   if (Width <= BufferOffsetWidth) {
1853     Res = IRB.CreateIntCast(Off, ResTy, /*isSigned=*/false,
1854                             PI.getName() + ".off");
1855   } else {
1856     Value *RsrcInt = IRB.CreatePtrToInt(Rsrc, ResTy, PI.getName() + ".rsrc");
1857     Value *Shl = IRB.CreateShl(
1858         RsrcInt,
1859         ConstantExpr::getIntegerValue(ResTy, APInt(Width, BufferOffsetWidth)),
1860         "", Width >= FatPtrWidth, Width > FatPtrWidth);
1861     Value *OffCast = IRB.CreateIntCast(Off, ResTy, /*isSigned=*/false,
1862                                        PI.getName() + ".off");
1863     Res = IRB.CreateOr(Shl, OffCast);
1864   }
1865 
1866   copyMetadata(Res, &PI);
1867   Res->takeName(&PI);
1868   SplitUsers.insert(&PI);
1869   PI.replaceAllUsesWith(Res);
1870   return {nullptr, nullptr};
1871 }
1872 
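// The inverse sketch (illustrative): for `inttoptr i160 %i to ptr
// addrspace(7)`, the resource part is an `inttoptr` of the high 128 bits
// (%i lshr 32, truncated) and the offset part is `trunc i160 %i to i32`.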
1873 PtrParts SplitPtrStructs::visitIntToPtrInst(IntToPtrInst &IP) {
1874   if (!isSplitFatPtr(IP.getType()))
1875     return {nullptr, nullptr};
1876   IRB.SetInsertPoint(&IP);
1877   const DataLayout &DL = IP.getDataLayout();
1878   unsigned RsrcPtrWidth = DL.getPointerSizeInBits(AMDGPUAS::BUFFER_RESOURCE);
1879   Value *Int = IP.getOperand(0);
1880   Type *IntTy = Int->getType();
1881   Type *RsrcIntTy = IntTy->getWithNewBitWidth(RsrcPtrWidth);
1882   unsigned Width = IntTy->getScalarSizeInBits();
1883 
1884   auto *RetTy = cast<StructType>(IP.getType());
1885   Type *RsrcTy = RetTy->getElementType(0);
1886   Type *OffTy = RetTy->getElementType(1);
1887   Value *RsrcPart = IRB.CreateLShr(
1888       Int,
1889       ConstantExpr::getIntegerValue(IntTy, APInt(Width, BufferOffsetWidth)));
1890   Value *RsrcInt = IRB.CreateIntCast(RsrcPart, RsrcIntTy, /*isSigned=*/false);
1891   Value *Rsrc = IRB.CreateIntToPtr(RsrcInt, RsrcTy, IP.getName() + ".rsrc");
1892   Value *Off =
1893       IRB.CreateIntCast(Int, OffTy, /*IsSigned=*/false, IP.getName() + ".off");
1894 
1895   copyMetadata(Rsrc, &IP);
1896   SplitUsers.insert(&IP);
1897   return {Rsrc, Off};
1898 }
1899 
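// Illustrative: `addrspacecast ptr addrspace(8) %r to ptr addrspace(7)` simply
// becomes the pair {%r, i32 0}, since such a cast leaves the offset at zero.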
1900 PtrParts SplitPtrStructs::visitAddrSpaceCastInst(AddrSpaceCastInst &I) {
1901   if (!isSplitFatPtr(I.getType()))
1902     return {nullptr, nullptr};
1903   IRB.SetInsertPoint(&I);
1904   Value *In = I.getPointerOperand();
1905   // No-op casts preserve parts
1906   if (In->getType() == I.getType()) {
1907     auto [Rsrc, Off] = getPtrParts(In);
1908     SplitUsers.insert(&I);
1909     return {Rsrc, Off};
1910   }
1911   if (I.getSrcAddressSpace() != AMDGPUAS::BUFFER_RESOURCE)
1912     report_fatal_error("Only buffer resources (addrspace 8) can be cast to "
1913                        "buffer fat pointers (addrspace 7)");
1914   Type *OffTy = cast<StructType>(I.getType())->getElementType(1);
1915   Value *ZeroOff = Constant::getNullValue(OffTy);
1916   SplitUsers.insert(&I);
1917   return {In, ZeroOff};
1918 }
1919 
1920 PtrParts SplitPtrStructs::visitICmpInst(ICmpInst &Cmp) {
1921   Value *Lhs = Cmp.getOperand(0);
1922   if (!isSplitFatPtr(Lhs->getType()))
1923     return {nullptr, nullptr};
1924   Value *Rhs = Cmp.getOperand(1);
1925   IRB.SetInsertPoint(&Cmp);
1926   ICmpInst::Predicate Pred = Cmp.getPredicate();
1927 
1928   assert((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) &&
1929          "Pointer comparison is only equal or unequal");
1930   auto [LhsRsrc, LhsOff] = getPtrParts(Lhs);
1931   auto [RhsRsrc, RhsOff] = getPtrParts(Rhs);
1932   Value *RsrcCmp =
1933       IRB.CreateICmp(Pred, LhsRsrc, RhsRsrc, Cmp.getName() + ".rsrc");
1934   copyMetadata(RsrcCmp, &Cmp);
1935   Value *OffCmp = IRB.CreateICmp(Pred, LhsOff, RhsOff, Cmp.getName() + ".off");
1936   copyMetadata(OffCmp, &Cmp);
1937 
1938   Value *Res = nullptr;
1939   if (Pred == ICmpInst::ICMP_EQ)
1940     Res = IRB.CreateAnd(RsrcCmp, OffCmp);
1941   else if (Pred == ICmpInst::ICMP_NE)
1942     Res = IRB.CreateOr(RsrcCmp, OffCmp);
1943   copyMetadata(Res, &Cmp);
1944   Res->takeName(&Cmp);
1945   SplitUsers.insert(&Cmp);
1946   Cmp.replaceAllUsesWith(Res);
1947   return {nullptr, nullptr};
1948 }
1949 
1950 PtrParts SplitPtrStructs::visitFreezeInst(FreezeInst &I) {
1951   if (!isSplitFatPtr(I.getType()))
1952     return {nullptr, nullptr};
1953   IRB.SetInsertPoint(&I);
1954   auto [Rsrc, Off] = getPtrParts(I.getOperand(0));
1955 
1956   Value *RsrcRes = IRB.CreateFreeze(Rsrc, I.getName() + ".rsrc");
1957   copyMetadata(RsrcRes, &I);
1958   Value *OffRes = IRB.CreateFreeze(Off, I.getName() + ".off");
1959   copyMetadata(OffRes, &I);
1960   SplitUsers.insert(&I);
1961   return {RsrcRes, OffRes};
1962 }
1963 
1964 PtrParts SplitPtrStructs::visitExtractElementInst(ExtractElementInst &I) {
1965   if (!isSplitFatPtr(I.getType()))
1966     return {nullptr, nullptr};
1967   IRB.SetInsertPoint(&I);
1968   Value *Vec = I.getVectorOperand();
1969   Value *Idx = I.getIndexOperand();
1970   auto [Rsrc, Off] = getPtrParts(Vec);
1971 
1972   Value *RsrcRes = IRB.CreateExtractElement(Rsrc, Idx, I.getName() + ".rsrc");
1973   copyMetadata(RsrcRes, &I);
1974   Value *OffRes = IRB.CreateExtractElement(Off, Idx, I.getName() + ".off");
1975   copyMetadata(OffRes, &I);
1976   SplitUsers.insert(&I);
1977   return {RsrcRes, OffRes};
1978 }
1979 
1980 PtrParts SplitPtrStructs::visitInsertElementInst(InsertElementInst &I) {
1981   // The mutated instructions temporarily don't return vectors, and so
1982   // we need the generic getType() here to avoid crashes.
1983   if (!isSplitFatPtr(cast<Instruction>(I).getType()))
1984     return {nullptr, nullptr};
1985   IRB.SetInsertPoint(&I);
1986   Value *Vec = I.getOperand(0);
1987   Value *Elem = I.getOperand(1);
1988   Value *Idx = I.getOperand(2);
1989   auto [VecRsrc, VecOff] = getPtrParts(Vec);
1990   auto [ElemRsrc, ElemOff] = getPtrParts(Elem);
1991 
1992   Value *RsrcRes =
1993       IRB.CreateInsertElement(VecRsrc, ElemRsrc, Idx, I.getName() + ".rsrc");
1994   copyMetadata(RsrcRes, &I);
1995   Value *OffRes =
1996       IRB.CreateInsertElement(VecOff, ElemOff, Idx, I.getName() + ".off");
1997   copyMetadata(OffRes, &I);
1998   SplitUsers.insert(&I);
1999   return {RsrcRes, OffRes};
2000 }
2001 
2002 PtrParts SplitPtrStructs::visitShuffleVectorInst(ShuffleVectorInst &I) {
2003   // Cast is needed for the same reason as insertelement's.
2004   if (!isSplitFatPtr(cast<Instruction>(I).getType()))
2005     return {nullptr, nullptr};
2006   IRB.SetInsertPoint(&I);
2007 
2008   Value *V1 = I.getOperand(0);
2009   Value *V2 = I.getOperand(1);
2010   ArrayRef<int> Mask = I.getShuffleMask();
2011   auto [V1Rsrc, V1Off] = getPtrParts(V1);
2012   auto [V2Rsrc, V2Off] = getPtrParts(V2);
2013 
2014   Value *RsrcRes =
2015       IRB.CreateShuffleVector(V1Rsrc, V2Rsrc, Mask, I.getName() + ".rsrc");
2016   copyMetadata(RsrcRes, &I);
2017   Value *OffRes =
2018       IRB.CreateShuffleVector(V1Off, V2Off, Mask, I.getName() + ".off");
2019   copyMetadata(OffRes, &I);
2020   SplitUsers.insert(&I);
2021   return {RsrcRes, OffRes};
2022 }
2023 
2024 PtrParts SplitPtrStructs::visitPHINode(PHINode &PHI) {
2025   if (!isSplitFatPtr(PHI.getType()))
2026     return {nullptr, nullptr};
2027   IRB.SetInsertPoint(*PHI.getInsertionPointAfterDef());
2028   // Phi nodes will be handled in post-processing after we've visited every
2029   // instruction. However, instead of just returning {nullptr, nullptr},
2030   // we explicitly create the temporary extractvalue operations that serve as
2031   // our interim results, so that they end up at the beginning of the block
2032   // with the PHIs.
2033   Value *TmpRsrc = IRB.CreateExtractValue(&PHI, 0, PHI.getName() + ".rsrc");
2034   Value *TmpOff = IRB.CreateExtractValue(&PHI, 1, PHI.getName() + ".off");
2035   Conditionals.push_back(&PHI);
2036   SplitUsers.insert(&PHI);
2037   return {TmpRsrc, TmpOff};
2038 }
2039 
2040 PtrParts SplitPtrStructs::visitSelectInst(SelectInst &SI) {
2041   if (!isSplitFatPtr(SI.getType()))
2042     return {nullptr, nullptr};
2043   IRB.SetInsertPoint(&SI);
2044 
2045   Value *Cond = SI.getCondition();
2046   Value *True = SI.getTrueValue();
2047   Value *False = SI.getFalseValue();
2048   auto [TrueRsrc, TrueOff] = getPtrParts(True);
2049   auto [FalseRsrc, FalseOff] = getPtrParts(False);
2050 
2051   Value *RsrcRes =
2052       IRB.CreateSelect(Cond, TrueRsrc, FalseRsrc, SI.getName() + ".rsrc", &SI);
2053   copyMetadata(RsrcRes, &SI);
2054   Conditionals.push_back(&SI);
2055   Value *OffRes =
2056       IRB.CreateSelect(Cond, TrueOff, FalseOff, SI.getName() + ".off", &SI);
2057   copyMetadata(OffRes, &SI);
2058   SplitUsers.insert(&SI);
2059   return {RsrcRes, OffRes};
2060 }
2061 
2062 /// Returns true if this intrinsic needs to be removed when it is
2063 /// applied to `ptr addrspace(7)` values. Calls to these intrinsics are
2064 /// rewritten into calls to versions of that intrinsic on the resource
2065 /// descriptor.
2066 static bool isRemovablePointerIntrinsic(Intrinsic::ID IID) {
2067   switch (IID) {
2068   default:
2069     return false;
2070   case Intrinsic::ptrmask:
2071   case Intrinsic::invariant_start:
2072   case Intrinsic::invariant_end:
2073   case Intrinsic::launder_invariant_group:
2074   case Intrinsic::strip_invariant_group:
2075     return true;
2076   }
2077 }
2078 
2079 PtrParts SplitPtrStructs::visitIntrinsicInst(IntrinsicInst &I) {
2080   Intrinsic::ID IID = I.getIntrinsicID();
2081   switch (IID) {
2082   default:
2083     break;
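  // ptrmask only affects the offset: e.g. (illustrative),
  // `llvm.ptrmask(%p, i32 -4)` keeps %p.rsrc as-is and yields
  // `and i32 %p.off, -4` as the new offset part.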
2084   case Intrinsic::ptrmask: {
2085     Value *Ptr = I.getArgOperand(0);
2086     if (!isSplitFatPtr(Ptr->getType()))
2087       return {nullptr, nullptr};
2088     Value *Mask = I.getArgOperand(1);
2089     IRB.SetInsertPoint(&I);
2090     auto [Rsrc, Off] = getPtrParts(Ptr);
2091     if (Mask->getType() != Off->getType())
2092       report_fatal_error("offset width is not equal to index width of fat "
2093                          "pointer (data layout not set up correctly?)");
2094     Value *OffRes = IRB.CreateAnd(Off, Mask, I.getName() + ".off");
2095     copyMetadata(OffRes, &I);
2096     SplitUsers.insert(&I);
2097     return {Rsrc, OffRes};
2098   }
2099   // Pointer annotation intrinsics that, given their object-wide nature,
2100   // operate on the resource part.
2101   case Intrinsic::invariant_start: {
2102     Value *Ptr = I.getArgOperand(1);
2103     if (!isSplitFatPtr(Ptr->getType()))
2104       return {nullptr, nullptr};
2105     IRB.SetInsertPoint(&I);
2106     auto [Rsrc, Off] = getPtrParts(Ptr);
2107     Type *NewTy = PointerType::get(I.getContext(), AMDGPUAS::BUFFER_RESOURCE);
2108     auto *NewRsrc = IRB.CreateIntrinsic(IID, {NewTy}, {I.getOperand(0), Rsrc});
2109     copyMetadata(NewRsrc, &I);
2110     NewRsrc->takeName(&I);
2111     SplitUsers.insert(&I);
2112     I.replaceAllUsesWith(NewRsrc);
2113     return {nullptr, nullptr};
2114   }
2115   case Intrinsic::invariant_end: {
2116     Value *RealPtr = I.getArgOperand(2);
2117     if (!isSplitFatPtr(RealPtr->getType()))
2118       return {nullptr, nullptr};
2119     IRB.SetInsertPoint(&I);
2120     Value *RealRsrc = getPtrParts(RealPtr).first;
2121     Value *InvPtr = I.getArgOperand(0);
2122     Value *Size = I.getArgOperand(1);
2123     Value *NewRsrc = IRB.CreateIntrinsic(IID, {RealRsrc->getType()},
2124                                          {InvPtr, Size, RealRsrc});
2125     copyMetadata(NewRsrc, &I);
2126     NewRsrc->takeName(&I);
2127     SplitUsers.insert(&I);
2128     I.replaceAllUsesWith(NewRsrc);
2129     return {nullptr, nullptr};
2130   }
2131   case Intrinsic::launder_invariant_group:
2132   case Intrinsic::strip_invariant_group: {
2133     Value *Ptr = I.getArgOperand(0);
2134     if (!isSplitFatPtr(Ptr->getType()))
2135       return {nullptr, nullptr};
2136     IRB.SetInsertPoint(&I);
2137     auto [Rsrc, Off] = getPtrParts(Ptr);
2138     Value *NewRsrc = IRB.CreateIntrinsic(IID, {Rsrc->getType()}, {Rsrc});
2139     copyMetadata(NewRsrc, &I);
2140     NewRsrc->takeName(&I);
2141     SplitUsers.insert(&I);
2142     return {NewRsrc, Off};
2143   }
2144   }
2145   return {nullptr, nullptr};
2146 }
2147 
2148 void SplitPtrStructs::processFunction(Function &F) {
2149   ST = &TM->getSubtarget<GCNSubtarget>(F);
2150   SmallVector<Instruction *, 0> Originals;
2151   LLVM_DEBUG(dbgs() << "Splitting pointer structs in function: " << F.getName()
2152                     << "\n");
2153   for (Instruction &I : instructions(F))
2154     Originals.push_back(&I);
2155   for (Instruction *I : Originals) {
2156     auto [Rsrc, Off] = visit(I);
2157     assert(((Rsrc && Off) || (!Rsrc && !Off)) &&
2158            "Can't have a resource but no offset");
2159     if (Rsrc)
2160       RsrcParts[I] = Rsrc;
2161     if (Off)
2162       OffParts[I] = Off;
2163   }
2164   processConditionals();
2165   killAndReplaceSplitInstructions(Originals);
2166 
2167   // Clean up after ourselves to save on memory.
2168   RsrcParts.clear();
2169   OffParts.clear();
2170   SplitUsers.clear();
2171   Conditionals.clear();
2172   ConditionalTemps.clear();
2173 }
2174 
2175 namespace {
2176 class AMDGPULowerBufferFatPointers : public ModulePass {
2177 public:
2178   static char ID;
2179 
2180   AMDGPULowerBufferFatPointers() : ModulePass(ID) {
2181     initializeAMDGPULowerBufferFatPointersPass(
2182         *PassRegistry::getPassRegistry());
2183   }
2184 
2185   bool run(Module &M, const TargetMachine &TM);
2186   bool runOnModule(Module &M) override;
2187 
2188   void getAnalysisUsage(AnalysisUsage &AU) const override;
2189 };
2190 } // namespace
2191 
2192 /// Returns true if there are values that have a buffer fat pointer in them,
2193 /// which means we'll need to perform rewrites on this function. As a side
2194 /// effect, this will populate the type remapping cache.
2195 static bool containsBufferFatPointers(const Function &F,
2196                                       BufferFatPtrToStructTypeMap *TypeMap) {
2197   bool HasFatPointers = false;
2198   for (const BasicBlock &BB : F)
2199     for (const Instruction &I : BB)
2200       HasFatPointers |= (I.getType() != TypeMap->remapType(I.getType()));
2201   return HasFatPointers;
2202 }
2203 
2204 static bool hasFatPointerInterface(const Function &F,
2205                                    BufferFatPtrToStructTypeMap *TypeMap) {
2206   Type *Ty = F.getFunctionType();
2207   return Ty != TypeMap->remapType(Ty);
2208 }
2209 
2210 /// Move the body of `OldF` into a new function, returning it.
2211 static Function *moveFunctionAdaptingType(Function *OldF, FunctionType *NewTy,
2212                                           ValueToValueMapTy &CloneMap) {
2213   bool IsIntrinsic = OldF->isIntrinsic();
2214   Function *NewF =
2215       Function::Create(NewTy, OldF->getLinkage(), OldF->getAddressSpace());
2216   NewF->IsNewDbgInfoFormat = OldF->IsNewDbgInfoFormat;
2217   NewF->copyAttributesFrom(OldF);
2218   NewF->copyMetadata(OldF, 0);
2219   NewF->takeName(OldF);
2220   NewF->updateAfterNameChange();
2221   NewF->setDLLStorageClass(OldF->getDLLStorageClass());
2222   OldF->getParent()->getFunctionList().insertAfter(OldF->getIterator(), NewF);
2223 
2224   while (!OldF->empty()) {
2225     BasicBlock *BB = &OldF->front();
2226     BB->removeFromParent();
2227     BB->insertInto(NewF);
2228     CloneMap[BB] = BB;
2229     for (Instruction &I : *BB) {
2230       CloneMap[&I] = &I;
2231     }
2232   }
2233 
2234   SmallVector<AttributeSet> ArgAttrs;
2235   AttributeList OldAttrs = OldF->getAttributes();
2236 
2237   for (auto [I, OldArg, NewArg] : enumerate(OldF->args(), NewF->args())) {
2238     CloneMap[&NewArg] = &OldArg;
2239     NewArg.takeName(&OldArg);
2240     Type *OldArgTy = OldArg.getType(), *NewArgTy = NewArg.getType();
2241     // Temporarily mutate type of `NewArg` to allow RAUW to work.
2242     NewArg.mutateType(OldArgTy);
2243     OldArg.replaceAllUsesWith(&NewArg);
2244     NewArg.mutateType(NewArgTy);
2245 
2246     AttributeSet ArgAttr = OldAttrs.getParamAttrs(I);
2247     // Intrinsics get their attributes fixed later.
2248     if (OldArgTy != NewArgTy && !IsIntrinsic)
2249       ArgAttr = ArgAttr.removeAttributes(
2250           NewF->getContext(),
2251           AttributeFuncs::typeIncompatible(NewArgTy, ArgAttr));
2252     ArgAttrs.push_back(ArgAttr);
2253   }
2254   AttributeSet RetAttrs = OldAttrs.getRetAttrs();
2255   if (OldF->getReturnType() != NewF->getReturnType() && !IsIntrinsic)
2256     RetAttrs = RetAttrs.removeAttributes(
2257         NewF->getContext(),
2258         AttributeFuncs::typeIncompatible(NewF->getReturnType(), RetAttrs));
2259   NewF->setAttributes(AttributeList::get(
2260       NewF->getContext(), OldAttrs.getFnAttrs(), RetAttrs, ArgAttrs));
2261   return NewF;
2262 }
2263 
2264 static void makeCloneInPlaceMap(Function *F, ValueToValueMapTy &CloneMap) {
2265   for (Argument &A : F->args())
2266     CloneMap[&A] = &A;
2267   for (BasicBlock &BB : *F) {
2268     CloneMap[&BB] = &BB;
2269     for (Instruction &I : BB)
2270       CloneMap[&I] = &I;
2271   }
2272 }
2273 
2274 bool AMDGPULowerBufferFatPointers::run(Module &M, const TargetMachine &TM) {
2275   bool Changed = false;
2276   const DataLayout &DL = M.getDataLayout();
2277   // Record the functions which need to be remapped.
2278   // The second element of the pair indicates whether the function has to have
2279   // its arguments or return types adjusted.
2280   SmallVector<std::pair<Function *, bool>> NeedsRemap;
2281 
2282   BufferFatPtrToStructTypeMap StructTM(DL);
2283   BufferFatPtrToIntTypeMap IntTM(DL);
2284   for (const GlobalVariable &GV : M.globals()) {
2285     if (GV.getAddressSpace() == AMDGPUAS::BUFFER_FAT_POINTER)
2286       report_fatal_error("Global variables with a buffer fat pointer address "
2287                          "space (7) are not supported");
2288     Type *VT = GV.getValueType();
2289     if (VT != StructTM.remapType(VT))
2290       report_fatal_error("Global variables that contain buffer fat pointers "
2291                          "(address space 7 pointers) are unsupported. Use "
2292                          "buffer resource pointers (address space 8) instead.");
2293   }
2294 
2295   {
2296     // Collect all constant exprs and aggregates referenced by any function.
2297     SmallVector<Constant *, 8> Worklist;
2298     for (Function &F : M.functions())
2299       for (Instruction &I : instructions(F))
2300         for (Value *Op : I.operands())
2301           if (isa<ConstantExpr>(Op) || isa<ConstantAggregate>(Op))
2302             Worklist.push_back(cast<Constant>(Op));
2303 
2304     // Recursively look for any referenced buffer pointer constants.
2305     SmallPtrSet<Constant *, 8> Visited;
2306     SetVector<Constant *> BufferFatPtrConsts;
2307     while (!Worklist.empty()) {
2308       Constant *C = Worklist.pop_back_val();
2309       if (!Visited.insert(C).second)
2310         continue;
2311       if (isBufferFatPtrOrVector(C->getType()))
2312         BufferFatPtrConsts.insert(C);
2313       for (Value *Op : C->operands())
2314         if (isa<ConstantExpr>(Op) || isa<ConstantAggregate>(Op))
2315           Worklist.push_back(cast<Constant>(Op));
2316     }
2317 
2318     // Expand all constant expressions using fat buffer pointers to
2319     // instructions.
2320     Changed |= convertUsersOfConstantsToInstructions(
2321         BufferFatPtrConsts.getArrayRef(), /*RestrictToFunc=*/nullptr,
2322         /*RemoveDeadConstants=*/false, /*IncludeSelf=*/true);
2323   }
2324 
2325   StoreFatPtrsAsIntsVisitor MemOpsRewrite(&IntTM, M.getContext());
2326   LegalizeBufferContentTypesVisitor BufferContentsTypeRewrite(DL,
2327                                                               M.getContext());
2328   for (Function &F : M.functions()) {
2329     bool InterfaceChange = hasFatPointerInterface(F, &StructTM);
2330     bool BodyChanges = containsBufferFatPointers(F, &StructTM);
2331     Changed |= MemOpsRewrite.processFunction(F);
2332     if (InterfaceChange || BodyChanges) {
2333       NeedsRemap.push_back(std::make_pair(&F, InterfaceChange));
2334       Changed |= BufferContentsTypeRewrite.processFunction(F);
2335     }
2336   }
2337   if (NeedsRemap.empty())
2338     return Changed;
2339 
2340   SmallVector<Function *> NeedsPostProcess;
2341   SmallVector<Function *> Intrinsics;
2342   // Keep one big map so as to memoize constants across functions.
2343   ValueToValueMapTy CloneMap;
2344   FatPtrConstMaterializer Materializer(&StructTM, CloneMap);
2345 
2346   ValueMapper LowerInFuncs(CloneMap, RF_None, &StructTM, &Materializer);
2347   for (auto [F, InterfaceChange] : NeedsRemap) {
2348     Function *NewF = F;
2349     if (InterfaceChange)
2350       NewF = moveFunctionAdaptingType(
2351           F, cast<FunctionType>(StructTM.remapType(F->getFunctionType())),
2352           CloneMap);
2353     else
2354       makeCloneInPlaceMap(F, CloneMap);
2355     LowerInFuncs.remapFunction(*NewF);
2356     if (NewF->isIntrinsic())
2357       Intrinsics.push_back(NewF);
2358     else
2359       NeedsPostProcess.push_back(NewF);
2360     if (InterfaceChange) {
2361       F->replaceAllUsesWith(NewF);
2362       F->eraseFromParent();
2363     }
2364     Changed = true;
2365   }
2366   StructTM.clear();
2367   IntTM.clear();
2368   CloneMap.clear();
2369 
2370   SplitPtrStructs Splitter(M.getContext(), &TM);
2371   for (Function *F : NeedsPostProcess)
2372     Splitter.processFunction(*F);
2373   for (Function *F : Intrinsics) {
2374     if (isRemovablePointerIntrinsic(F->getIntrinsicID())) {
2375       F->eraseFromParent();
2376     } else {
2377       std::optional<Function *> NewF = Intrinsic::remangleIntrinsicFunction(F);
2378       if (NewF)
2379         F->replaceAllUsesWith(*NewF);
2380     }
2381   }
2382   return Changed;
2383 }
2384 
2385 bool AMDGPULowerBufferFatPointers::runOnModule(Module &M) {
2386   TargetPassConfig &TPC = getAnalysis<TargetPassConfig>();
2387   const TargetMachine &TM = TPC.getTM<TargetMachine>();
2388   return run(M, TM);
2389 }
2390 
2391 char AMDGPULowerBufferFatPointers::ID = 0;
2392 
2393 char &llvm::AMDGPULowerBufferFatPointersID = AMDGPULowerBufferFatPointers::ID;
2394 
2395 void AMDGPULowerBufferFatPointers::getAnalysisUsage(AnalysisUsage &AU) const {
2396   AU.addRequired<TargetPassConfig>();
2397 }
2398 
2399 #define PASS_DESC "Lower buffer fat pointer operations to buffer resources"
2400 INITIALIZE_PASS_BEGIN(AMDGPULowerBufferFatPointers, DEBUG_TYPE, PASS_DESC,
2401                       false, false)
2402 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
2403 INITIALIZE_PASS_END(AMDGPULowerBufferFatPointers, DEBUG_TYPE, PASS_DESC, false,
2404                     false)
2405 #undef PASS_DESC
2406 
2407 ModulePass *llvm::createAMDGPULowerBufferFatPointersPass() {
2408   return new AMDGPULowerBufferFatPointers();
2409 }
2410 
2411 PreservedAnalyses
2412 AMDGPULowerBufferFatPointersPass::run(Module &M, ModuleAnalysisManager &MA) {
2413   return AMDGPULowerBufferFatPointers().run(M, TM) ? PreservedAnalyses::none()
2414                                                    : PreservedAnalyses::all();
2415 }
2416