//===-- AMDGPULowerBufferFatPointers.cpp ---------------------------=//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass lowers operations on buffer fat pointers (addrspace 7) to
// operations on buffer resources (addrspace 8) and is needed for correct
// codegen.
//
// # Background
//
// Address space 7 (the buffer fat pointer) is a 160-bit pointer that consists
// of a 128-bit buffer descriptor and a 32-bit offset into that descriptor.
// The buffer resource part needs to be a "raw" buffer resource (it must have
// a stride of 0 and bounds checks must be in raw buffer mode or disabled).
//
// When these requirements are met, a buffer resource can be treated as a
// typical (though quite wide) pointer that follows typical LLVM pointer
// semantics. This allows the frontend to reason about such buffers (which are
// often encountered in the context of SPIR-V kernels).
//
// However, because of their non-power-of-2 size, these fat pointers cannot be
// present during translation to MIR (though this restriction may be lifted
// during the transition to GlobalISel). Therefore, this pass is needed in
// order to correctly implement these fat pointers.
//
// The resource intrinsics take the resource part (the address space 8 pointer)
// and the offset part (the 32-bit integer) as separate arguments. In addition,
// many users of these buffers manipulate the offset while leaving the resource
// part alone. For these reasons, we want to typically separate the resource
// and offset parts into separate variables, but combine them together when
// encountering cases where this is required, such as by inserting these values
// into aggregates or moving them to memory.
//
// Therefore, at a high level, `ptr addrspace(7) %x` becomes `ptr addrspace(8)
// %x.rsrc` and `i32 %x.off`, which will be combined into `{ptr addrspace(8),
// i32} %x = {%x.rsrc, %x.off}` if needed. Similarly, `vector<Nxp7>` becomes
// `{vector<Nxp8>, vector<Nxi32>}` and its component parts.
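//
// As a rough sketch (the names and exact IR here are illustrative, not the
// literal output of the pass), a simple offset computation such as
// ```
// %q = getelementptr i32, ptr addrspace(7) %p, i32 4
// ```
// conceptually becomes (using the same shorthand as above)
// ```
// %q.rsrc = %p.rsrc           ; resource part carried along unchanged
// %q.off = add i32 %p.off, 16 ; only the 32-bit offset is updated
// ```
// and the `{ptr addrspace(8), i32}` struct for `%q` is only materialized if
// some user (an `insertvalue`, a store, a `ret`, ...) needs the whole value.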
//
// # Implementation
//
// This pass proceeds in three main phases:
//
// ## Rewriting loads and stores of p7
//
// The first phase is to rewrite away all loads and stores of
// `ptr addrspace(7)`, including aggregates containing such pointers, to ones
// that use `i160`. This is handled by `StoreFatPtrsAsIntsVisitor`, which
// visits loads, stores, and allocas and, if the loaded or stored type
// contains `ptr addrspace(7)`, rewrites that type to one where the p7s are
// replaced by i160s, copying other parts of aggregates as needed. In the case
// of a store, each pointer is `ptrtoint`d to i160 before storing, and loaded
// integers are `inttoptr`d back. This same transformation is applied to
// vectors of pointers.
//
// Such a transformation allows the later phases of the pass to not need
// to handle buffer fat pointers moving to and from memory, where we would
// have to handle the incompatibility between a `{Nxp8, Nxi32}` representation
// and `Nxi160` directly. Instead, that transposing action (where the vectors
// of resources and vectors of offsets are concatenated before being stored to
// memory) is handled through implementing `inttoptr` and `ptrtoint` only.
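//
// For example (illustrative only; the real rewrite also carries over names
// and metadata), the pair of accesses
// ```
// store ptr addrspace(7) %p, ptr %slot
// %r = load ptr addrspace(7), ptr %slot
// ```
// becomes, roughly,
// ```
// %p.int = ptrtoint ptr addrspace(7) %p to i160
// store i160 %p.int, ptr %slot
// %r.int = load i160, ptr %slot
// %r = inttoptr i160 %r.int to ptr addrspace(7)
// ```
// with the allocated type of any `alloca` that contains such pointers
// rewritten in the same way, so that later phases only ever see buffer fat
// pointers as SSA values, never in memory.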
//
// Atomic operations on `ptr addrspace(7)` values are not supported, as the
// hardware does not include a 160-bit atomic.
//
// ## Buffer contents type legalization
//
// The underlying buffer intrinsics only support types up to 128 bits long,
// and don't support complex types. If buffer operations were
// standard pointer operations that could be represented as MIR-level loads,
// this would be handled by the various legalization schemes in instruction
// selection. However, because we have to do the conversion from `load` and
// `store` to intrinsics at LLVM IR level, we must perform that legalization
// ourselves.
//
// This involves a combination of
// - Converting arrays to vectors where possible
// - Otherwise, splitting loads and stores of aggregates into loads/stores of
//   each component.
// - Zero-extending things to fill a whole number of bytes
// - Casting values of types that don't neatly correspond to supported machine
//   value types (for example, an i96 or i256) into ones that would work (like
//   <3 x i32> and <8 x i32>, respectively)
// - Splitting values that are too long (such as the aforementioned <8 x i32>)
//   into multiple operations.
//
// ## Type remapping
//
// We use a `ValueMapper` to mangle uses of [vectors of] buffer fat pointers
// to the corresponding struct type, which has a resource part and an offset
// part.
//
// This is performed by a `BufferFatPtrToStructTypeMap` and a
// `FatPtrConstMaterializer`, usually by way of `setType`ing values. Constants
// are handled here because there isn't a good way to fix them up later.
//
// This has the downside of leaving the IR in an invalid state (for example,
// the instruction `getelementptr {ptr addrspace(8), i32} %p, ...` will exist),
// but all such invalid states will be resolved by the third phase.
//
// Functions that don't take buffer fat pointers are modified in place. Those
// that do take such pointers have their basic blocks moved to a new function
// whose arguments and return values use `{ptr addrspace(8), i32}` in place of
// the fat pointers. This phase also records intrinsics so that they can be
// remangled or deleted later.
//
// ## Splitting pointer structs
//
// The meat of this pass consists of defining semantics for operations that
// produce or consume [vectors of] buffer fat pointers in terms of their
// resource and offset parts. This is accomplished through the
// `SplitPtrStructs` visitor.
//
// In the first pass through each function that is being lowered, the splitter
// inserts new instructions to implement the split-structures behavior, which
// is needed for correctness and performance. It records a list of "split
// users", instructions that are being replaced by operations on the resource
// and offset parts.
//
// Split users do not necessarily need to produce parts themselves (a
// `load float, ptr addrspace(7)` does not, for example), but, if they do not
// generate fat buffer pointers, they must RAUW in their replacement
// instructions during the initial visit.
//
// When these new instructions are created, they use the split parts recorded
// for their initial arguments in order to generate their replacements,
// creating a parallel set of instructions that does not refer to the original
// fat pointer values but instead to their resource and offset components.
//
// Instructions, such as `extractvalue`, that produce buffer fat pointers from
// sources that do not have split parts have such parts generated using
// `extractvalue`. This is also the initial handling of PHI nodes, which
// are then cleaned up.
//
// ### Conditionals
//
// PHI nodes are initially given resource parts via `extractvalue`. However,
// this is not an efficient rewrite of such nodes, as, in most cases, the
// resource part in a conditional or loop remains constant throughout the loop
// and only the offset varies. Failing to optimize away these constant
// resources would cause additional registers to be sent around loops and
// might lead to waterfall loops being generated for buffer operations due to
// the "non-uniform" resource argument.
//
// Therefore, after all instructions have been visited, the pointer splitter
// post-processes all encountered conditionals. Given a PHI node or select,
// getPossibleRsrcRoots() collects all values that the resource parts of that
// conditional's input could come from, as well as all conditional
// instructions encountered during the search. If, after filtering out the
// initial node itself, the set of encountered conditionals is a subset of the
// potential roots and there is a single potential resource that isn't in the
// conditional set, that value is the only possible value the resource
// argument could have throughout the control flow.
//
// If that condition is met, then a PHI node can have its resource part
// changed to the singleton value and then be replaced by a PHI on the
// offsets. Otherwise, each PHI node is split into two, one for the resource
// part and one for the offset part, which replace the temporary
// `extractvalue` instructions that were added during the first pass.
//
// Similar logic applies to `select`, where
// `%z = select i1 %cond, ptr addrspace(7) %x, ptr addrspace(7) %y`
// can be split into `%z.rsrc = %x.rsrc` and
// `%z.off = select i1 %cond, i32 %x.off, i32 %y.off`
// if both `%x` and `%y` have the same resource part, but two `select`
// operations will be needed if they do not.
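//
// As a sketch (eliding the temporary `extractvalue`s that the first pass
// inserts and that are later cleaned up), a pointer-increment loop such as
// ```
// loop:
//   %p = phi ptr addrspace(7) [ %base, %entry ], [ %p.next, %loop ]
//   %p.next = getelementptr i32, ptr addrspace(7) %p, i32 1
// ```
// ideally ends up carrying only the offset through the PHI:
// ```
// loop:
//   %p.off = phi i32 [ %base.off, %entry ], [ %p.next.off, %loop ]
//   %p.next.off = add i32 %p.off, 4
// ```
// with `%base.rsrc` used directly wherever the resource part is needed, so
// that no extra resource register is carried around the loop.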
//
// ### Final processing
//
// After conditionals have been cleaned up, the IR for each function is
// rewritten to remove all the old instructions that have been split up.
//
// Any instruction that used to produce a buffer fat pointer (and therefore now
// produces a resource-and-offset struct after type remapping) is
// replaced as follows:
// 1. All debug value annotations are cloned to reflect that the resource part
//    and offset parts are computed separately and constitute different
//    fragments of the underlying source language variable.
// 2. All uses that were themselves split are replaced by a `poison` of the
//    struct type, as they will themselves be erased soon. This rule, combined
//    with debug handling, should leave the use lists of split instructions
//    empty in almost all cases.
// 3. If a user of the original struct-valued result remains, the structure
//    needed for the new types to work is constructed out of the newly-defined
//    parts, and the original instruction is replaced by this structure
//    before being erased. Instructions requiring this construction include
//    `ret` and `insertvalue`.
//
// # Consequences
//
// This pass does not alter the CFG.
//
// Alias analysis information will become coarser, as the LLVM alias analyzer
// cannot handle the buffer intrinsics. Specifically, while we can determine
// that the following two loads do not alias:
// ```
// %y = getelementptr i32, ptr addrspace(7) %x, i32 1
// %a = load i32, ptr addrspace(7) %x
// %b = load i32, ptr addrspace(7) %y
// ```
// we cannot (except through some code that runs during scheduling) determine
// that the rewritten loads below do not alias.
// ```
// %y.off = add i32 %x.off, 1
// %a = call @llvm.amdgcn.raw.ptr.buffer.load(ptr addrspace(8) %x.rsrc, i32
//     %x.off, ...)
// %b = call @llvm.amdgcn.raw.ptr.buffer.load(ptr addrspace(8)
//     %x.rsrc, i32 %y.off, ...)
// ```
// However, existing alias information is preserved.
//===----------------------------------------------------------------------===//

#include "AMDGPU.h"
#include "AMDGPUTargetMachine.h"
#include "GCNSubtarget.h"
#include "SIDefines.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/Utils/Local.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/AttributeMask.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DebugInfo.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/InstVisitor.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/ReplaceConstant.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Alignment.h"
#include "llvm/Support/AtomicOrdering.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/ValueMapper.h"

#define DEBUG_TYPE "amdgpu-lower-buffer-fat-pointers"

using namespace llvm;

static constexpr unsigned BufferOffsetWidth = 32;

namespace {
/// Recursively replace instances of ptr addrspace(7) and vector<Nxptr
/// addrspace(7)> with some other type as defined by the relevant subclass.
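/// As a concrete sketch of what the two subclasses below produce:
///   BufferFatPtrToIntTypeMap:
///     ptr addrspace(7)       => i160
///     <4 x ptr addrspace(7)> => <4 x i160>
///   BufferFatPtrToStructTypeMap:
///     ptr addrspace(7)       => {ptr addrspace(8), i32}
///     <4 x ptr addrspace(7)> => {<4 x ptr addrspace(8)>, <4 x i32>}
/// (The <4 x ...> rows are illustrative; any element count is handled the
/// same way.)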
259 class BufferFatPtrTypeLoweringBase : public ValueMapTypeRemapper { 260 DenseMap<Type *, Type *> Map; 261 262 Type *remapTypeImpl(Type *Ty, SmallPtrSetImpl<StructType *> &Seen); 263 264 protected: 265 virtual Type *remapScalar(PointerType *PT) = 0; 266 virtual Type *remapVector(VectorType *VT) = 0; 267 268 const DataLayout &DL; 269 270 public: 271 BufferFatPtrTypeLoweringBase(const DataLayout &DL) : DL(DL) {} 272 Type *remapType(Type *SrcTy) override; 273 void clear() { Map.clear(); } 274 }; 275 276 /// Remap ptr addrspace(7) to i160 and vector<Nxptr addrspace(7)> to 277 /// vector<Nxi60> in order to correctly handling loading/storing these values 278 /// from memory. 279 class BufferFatPtrToIntTypeMap : public BufferFatPtrTypeLoweringBase { 280 using BufferFatPtrTypeLoweringBase::BufferFatPtrTypeLoweringBase; 281 282 protected: 283 Type *remapScalar(PointerType *PT) override { return DL.getIntPtrType(PT); } 284 Type *remapVector(VectorType *VT) override { return DL.getIntPtrType(VT); } 285 }; 286 287 /// Remap ptr addrspace(7) to {ptr addrspace(8), i32} (the resource and offset 288 /// parts of the pointer) so that we can easily rewrite operations on these 289 /// values that aren't loading them from or storing them to memory. 290 class BufferFatPtrToStructTypeMap : public BufferFatPtrTypeLoweringBase { 291 using BufferFatPtrTypeLoweringBase::BufferFatPtrTypeLoweringBase; 292 293 protected: 294 Type *remapScalar(PointerType *PT) override; 295 Type *remapVector(VectorType *VT) override; 296 }; 297 } // namespace 298 299 // This code is adapted from the type remapper in lib/Linker/IRMover.cpp 300 Type *BufferFatPtrTypeLoweringBase::remapTypeImpl( 301 Type *Ty, SmallPtrSetImpl<StructType *> &Seen) { 302 Type **Entry = &Map[Ty]; 303 if (*Entry) 304 return *Entry; 305 if (auto *PT = dyn_cast<PointerType>(Ty)) { 306 if (PT->getAddressSpace() == AMDGPUAS::BUFFER_FAT_POINTER) { 307 return *Entry = remapScalar(PT); 308 } 309 } 310 if (auto *VT = dyn_cast<VectorType>(Ty)) { 311 auto *PT = dyn_cast<PointerType>(VT->getElementType()); 312 if (PT && PT->getAddressSpace() == AMDGPUAS::BUFFER_FAT_POINTER) { 313 return *Entry = remapVector(VT); 314 } 315 return *Entry = Ty; 316 } 317 // Whether the type is one that is structurally uniqued - that is, if it is 318 // not a named struct (the only kind of type where multiple structurally 319 // identical types that have a distinct `Type*`) 320 StructType *TyAsStruct = dyn_cast<StructType>(Ty); 321 bool IsUniqued = !TyAsStruct || TyAsStruct->isLiteral(); 322 // Base case for ints, floats, opaque pointers, and so on, which don't 323 // require recursion. 324 if (Ty->getNumContainedTypes() == 0 && IsUniqued) 325 return *Entry = Ty; 326 if (!IsUniqued) { 327 // Create a dummy type for recursion purposes. 328 if (!Seen.insert(TyAsStruct).second) { 329 StructType *Placeholder = StructType::create(Ty->getContext()); 330 return *Entry = Placeholder; 331 } 332 } 333 bool Changed = false; 334 SmallVector<Type *> ElementTypes(Ty->getNumContainedTypes(), nullptr); 335 for (unsigned int I = 0, E = Ty->getNumContainedTypes(); I < E; ++I) { 336 Type *OldElem = Ty->getContainedType(I); 337 Type *NewElem = remapTypeImpl(OldElem, Seen); 338 ElementTypes[I] = NewElem; 339 Changed |= (OldElem != NewElem); 340 } 341 // Recursive calls to remapTypeImpl() may have invalidated pointer. 
342 Entry = &Map[Ty]; 343 if (!Changed) { 344 return *Entry = Ty; 345 } 346 if (auto *ArrTy = dyn_cast<ArrayType>(Ty)) 347 return *Entry = ArrayType::get(ElementTypes[0], ArrTy->getNumElements()); 348 if (auto *FnTy = dyn_cast<FunctionType>(Ty)) 349 return *Entry = FunctionType::get(ElementTypes[0], 350 ArrayRef(ElementTypes).slice(1), 351 FnTy->isVarArg()); 352 if (auto *STy = dyn_cast<StructType>(Ty)) { 353 // Genuine opaque types don't have a remapping. 354 if (STy->isOpaque()) 355 return *Entry = Ty; 356 bool IsPacked = STy->isPacked(); 357 if (IsUniqued) 358 return *Entry = StructType::get(Ty->getContext(), ElementTypes, IsPacked); 359 SmallString<16> Name(STy->getName()); 360 STy->setName(""); 361 Type **RecursionEntry = &Map[Ty]; 362 if (*RecursionEntry) { 363 auto *Placeholder = cast<StructType>(*RecursionEntry); 364 Placeholder->setBody(ElementTypes, IsPacked); 365 Placeholder->setName(Name); 366 return *Entry = Placeholder; 367 } 368 return *Entry = StructType::create(Ty->getContext(), ElementTypes, Name, 369 IsPacked); 370 } 371 llvm_unreachable("Unknown type of type that contains elements"); 372 } 373 374 Type *BufferFatPtrTypeLoweringBase::remapType(Type *SrcTy) { 375 SmallPtrSet<StructType *, 2> Visited; 376 return remapTypeImpl(SrcTy, Visited); 377 } 378 379 Type *BufferFatPtrToStructTypeMap::remapScalar(PointerType *PT) { 380 LLVMContext &Ctx = PT->getContext(); 381 return StructType::get(PointerType::get(Ctx, AMDGPUAS::BUFFER_RESOURCE), 382 IntegerType::get(Ctx, BufferOffsetWidth)); 383 } 384 385 Type *BufferFatPtrToStructTypeMap::remapVector(VectorType *VT) { 386 ElementCount EC = VT->getElementCount(); 387 LLVMContext &Ctx = VT->getContext(); 388 Type *RsrcVec = 389 VectorType::get(PointerType::get(Ctx, AMDGPUAS::BUFFER_RESOURCE), EC); 390 Type *OffVec = VectorType::get(IntegerType::get(Ctx, BufferOffsetWidth), EC); 391 return StructType::get(RsrcVec, OffVec); 392 } 393 394 static bool isBufferFatPtrOrVector(Type *Ty) { 395 if (auto *PT = dyn_cast<PointerType>(Ty->getScalarType())) 396 return PT->getAddressSpace() == AMDGPUAS::BUFFER_FAT_POINTER; 397 return false; 398 } 399 400 // True if the type is {ptr addrspace(8), i32} or a struct containing vectors of 401 // those types. Used to quickly skip instructions we don't need to process. 402 static bool isSplitFatPtr(Type *Ty) { 403 auto *ST = dyn_cast<StructType>(Ty); 404 if (!ST) 405 return false; 406 if (!ST->isLiteral() || ST->getNumElements() != 2) 407 return false; 408 auto *MaybeRsrc = 409 dyn_cast<PointerType>(ST->getElementType(0)->getScalarType()); 410 auto *MaybeOff = 411 dyn_cast<IntegerType>(ST->getElementType(1)->getScalarType()); 412 return MaybeRsrc && MaybeOff && 413 MaybeRsrc->getAddressSpace() == AMDGPUAS::BUFFER_RESOURCE && 414 MaybeOff->getBitWidth() == BufferOffsetWidth; 415 } 416 417 // True if the result type or any argument types are buffer fat pointers. 418 static bool isBufferFatPtrConst(Constant *C) { 419 Type *T = C->getType(); 420 return isBufferFatPtrOrVector(T) || any_of(C->operands(), [](const Use &U) { 421 return isBufferFatPtrOrVector(U.get()->getType()); 422 }); 423 } 424 425 namespace { 426 /// Convert [vectors of] buffer fat pointers to integers when they are read from 427 /// or stored to memory. This ensures that these pointers will have the same 428 /// memory layout as before they are lowered, even though they will no longer 429 /// have their previous layout in registers/in the program (they'll be broken 430 /// down into resource and offset parts). 
This has the downside of imposing 431 /// marshalling costs when reading or storing these values, but since placing 432 /// such pointers into memory is an uncommon operation at best, we feel that 433 /// this cost is acceptable for better performance in the common case. 434 class StoreFatPtrsAsIntsVisitor 435 : public InstVisitor<StoreFatPtrsAsIntsVisitor, bool> { 436 BufferFatPtrToIntTypeMap *TypeMap; 437 438 ValueToValueMapTy ConvertedForStore; 439 440 IRBuilder<> IRB; 441 442 // Convert all the buffer fat pointers within the input value to inttegers 443 // so that it can be stored in memory. 444 Value *fatPtrsToInts(Value *V, Type *From, Type *To, const Twine &Name); 445 // Convert all the i160s that need to be buffer fat pointers (as specified) 446 // by the To type) into those pointers to preserve the semantics of the rest 447 // of the program. 448 Value *intsToFatPtrs(Value *V, Type *From, Type *To, const Twine &Name); 449 450 public: 451 StoreFatPtrsAsIntsVisitor(BufferFatPtrToIntTypeMap *TypeMap, LLVMContext &Ctx) 452 : TypeMap(TypeMap), IRB(Ctx) {} 453 bool processFunction(Function &F); 454 455 bool visitInstruction(Instruction &I) { return false; } 456 bool visitAllocaInst(AllocaInst &I); 457 bool visitLoadInst(LoadInst &LI); 458 bool visitStoreInst(StoreInst &SI); 459 bool visitGetElementPtrInst(GetElementPtrInst &I); 460 }; 461 } // namespace 462 463 Value *StoreFatPtrsAsIntsVisitor::fatPtrsToInts(Value *V, Type *From, Type *To, 464 const Twine &Name) { 465 if (From == To) 466 return V; 467 ValueToValueMapTy::iterator Find = ConvertedForStore.find(V); 468 if (Find != ConvertedForStore.end()) 469 return Find->second; 470 if (isBufferFatPtrOrVector(From)) { 471 Value *Cast = IRB.CreatePtrToInt(V, To, Name + ".int"); 472 ConvertedForStore[V] = Cast; 473 return Cast; 474 } 475 if (From->getNumContainedTypes() == 0) 476 return V; 477 // Structs, arrays, and other compound types. 478 Value *Ret = PoisonValue::get(To); 479 if (auto *AT = dyn_cast<ArrayType>(From)) { 480 Type *FromPart = AT->getArrayElementType(); 481 Type *ToPart = cast<ArrayType>(To)->getElementType(); 482 for (uint64_t I = 0, E = AT->getArrayNumElements(); I < E; ++I) { 483 Value *Field = IRB.CreateExtractValue(V, I); 484 Value *NewField = 485 fatPtrsToInts(Field, FromPart, ToPart, Name + "." + Twine(I)); 486 Ret = IRB.CreateInsertValue(Ret, NewField, I); 487 } 488 } else { 489 for (auto [Idx, FromPart, ToPart] : 490 enumerate(From->subtypes(), To->subtypes())) { 491 Value *Field = IRB.CreateExtractValue(V, Idx); 492 Value *NewField = 493 fatPtrsToInts(Field, FromPart, ToPart, Name + "." + Twine(Idx)); 494 Ret = IRB.CreateInsertValue(Ret, NewField, Idx); 495 } 496 } 497 ConvertedForStore[V] = Ret; 498 return Ret; 499 } 500 501 Value *StoreFatPtrsAsIntsVisitor::intsToFatPtrs(Value *V, Type *From, Type *To, 502 const Twine &Name) { 503 if (From == To) 504 return V; 505 if (isBufferFatPtrOrVector(To)) { 506 Value *Cast = IRB.CreateIntToPtr(V, To, Name + ".ptr"); 507 return Cast; 508 } 509 if (From->getNumContainedTypes() == 0) 510 return V; 511 // Structs, arrays, and other compound types. 512 Value *Ret = PoisonValue::get(To); 513 if (auto *AT = dyn_cast<ArrayType>(From)) { 514 Type *FromPart = AT->getArrayElementType(); 515 Type *ToPart = cast<ArrayType>(To)->getElementType(); 516 for (uint64_t I = 0, E = AT->getArrayNumElements(); I < E; ++I) { 517 Value *Field = IRB.CreateExtractValue(V, I); 518 Value *NewField = 519 intsToFatPtrs(Field, FromPart, ToPart, Name + "." 
+ Twine(I)); 520 Ret = IRB.CreateInsertValue(Ret, NewField, I); 521 } 522 } else { 523 for (auto [Idx, FromPart, ToPart] : 524 enumerate(From->subtypes(), To->subtypes())) { 525 Value *Field = IRB.CreateExtractValue(V, Idx); 526 Value *NewField = 527 intsToFatPtrs(Field, FromPart, ToPart, Name + "." + Twine(Idx)); 528 Ret = IRB.CreateInsertValue(Ret, NewField, Idx); 529 } 530 } 531 return Ret; 532 } 533 534 bool StoreFatPtrsAsIntsVisitor::processFunction(Function &F) { 535 bool Changed = false; 536 // The visitors will mutate GEPs and allocas, but will push loads and stores 537 // to the worklist to avoid invalidation. 538 for (Instruction &I : make_early_inc_range(instructions(F))) { 539 Changed |= visit(I); 540 } 541 ConvertedForStore.clear(); 542 return Changed; 543 } 544 545 bool StoreFatPtrsAsIntsVisitor::visitAllocaInst(AllocaInst &I) { 546 Type *Ty = I.getAllocatedType(); 547 Type *NewTy = TypeMap->remapType(Ty); 548 if (Ty == NewTy) 549 return false; 550 I.setAllocatedType(NewTy); 551 return true; 552 } 553 554 bool StoreFatPtrsAsIntsVisitor::visitGetElementPtrInst(GetElementPtrInst &I) { 555 Type *Ty = I.getSourceElementType(); 556 Type *NewTy = TypeMap->remapType(Ty); 557 if (Ty == NewTy) 558 return false; 559 // We'll be rewriting the type `ptr addrspace(7)` out of existence soon, so 560 // make sure GEPs don't have different semantics with the new type. 561 I.setSourceElementType(NewTy); 562 I.setResultElementType(TypeMap->remapType(I.getResultElementType())); 563 return true; 564 } 565 566 bool StoreFatPtrsAsIntsVisitor::visitLoadInst(LoadInst &LI) { 567 Type *Ty = LI.getType(); 568 Type *IntTy = TypeMap->remapType(Ty); 569 if (Ty == IntTy) 570 return false; 571 572 IRB.SetInsertPoint(&LI); 573 auto *NLI = cast<LoadInst>(LI.clone()); 574 NLI->mutateType(IntTy); 575 NLI = IRB.Insert(NLI); 576 NLI->takeName(&LI); 577 578 Value *CastBack = intsToFatPtrs(NLI, IntTy, Ty, NLI->getName()); 579 LI.replaceAllUsesWith(CastBack); 580 LI.eraseFromParent(); 581 return true; 582 } 583 584 bool StoreFatPtrsAsIntsVisitor::visitStoreInst(StoreInst &SI) { 585 Value *V = SI.getValueOperand(); 586 Type *Ty = V->getType(); 587 Type *IntTy = TypeMap->remapType(Ty); 588 if (Ty == IntTy) 589 return false; 590 591 IRB.SetInsertPoint(&SI); 592 Value *IntV = fatPtrsToInts(V, Ty, IntTy, V->getName()); 593 for (auto *Dbg : at::getAssignmentMarkers(&SI)) 594 Dbg->setValue(IntV); 595 596 SI.setOperand(0, IntV); 597 return true; 598 } 599 600 namespace { 601 /// Convert loads/stores of types that the buffer intrinsics can't handle into 602 /// one ore more such loads/stores that consist of legal types. 603 /// 604 /// Do this by 605 /// 1. Recursing into structs (and arrays that don't share a memory layout with 606 /// vectors) since the intrinsics can't handle complex types. 607 /// 2. Converting arrays of non-aggregate, byte-sized types into their 608 /// corresponding vectors 609 /// 3. Bitcasting unsupported types, namely overly-long scalars and byte 610 /// vectors, into vectors of supported types. 611 /// 4. Splitting up excessively long reads/writes into multiple operations. 612 /// 613 /// Note that this doesn't handle complex data strucures, but, in the future, 614 /// the aggregate load splitter from SROA could be refactored to allow for that 615 /// case. 
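///
/// A few illustrative cases (a sketch, not an exhaustive list):
///   [4 x i8]    - treated as <4 x i8> and ultimately loaded/stored as an i32
///   i96         - handled as <3 x i32>
///   <4 x float> - already legal; left unchanged
///   <8 x i64>   - split into four 128-bit (<2 x i64>) loads/stores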
616 class LegalizeBufferContentTypesVisitor 617 : public InstVisitor<LegalizeBufferContentTypesVisitor, bool> { 618 friend class InstVisitor<LegalizeBufferContentTypesVisitor, bool>; 619 620 IRBuilder<> IRB; 621 622 const DataLayout &DL; 623 624 /// If T is [N x U], where U is a scalar type, return the vector type 625 /// <N x U>, otherwise, return T. 626 Type *scalarArrayTypeAsVector(Type *MaybeArrayType); 627 Value *arrayToVector(Value *V, Type *TargetType, const Twine &Name); 628 Value *vectorToArray(Value *V, Type *OrigType, const Twine &Name); 629 630 /// Break up the loads of a struct into the loads of its components 631 632 /// Convert a vector or scalar type that can't be operated on by buffer 633 /// intrinsics to one that would be legal through bitcasts and/or truncation. 634 /// Uses the wider of i32, i16, or i8 where possible. 635 Type *legalNonAggregateFor(Type *T); 636 Value *makeLegalNonAggregate(Value *V, Type *TargetType, const Twine &Name); 637 Value *makeIllegalNonAggregate(Value *V, Type *OrigType, const Twine &Name); 638 639 struct VecSlice { 640 uint64_t Index = 0; 641 uint64_t Length = 0; 642 VecSlice() = delete; 643 }; 644 /// Return the [index, length] pairs into which `T` needs to be cut to form 645 /// legal buffer load or store operations. Clears `Slices`. Creates an empty 646 /// `Slices` for non-vector inputs and creates one slice if no slicing will be 647 /// needed. 648 void getVecSlices(Type *T, SmallVectorImpl<VecSlice> &Slices); 649 650 Value *extractSlice(Value *Vec, VecSlice S, const Twine &Name); 651 Value *insertSlice(Value *Whole, Value *Part, VecSlice S, const Twine &Name); 652 653 /// In most cases, return `LegalType`. However, when given an input that would 654 /// normally be a legal type for the buffer intrinsics to return but that 655 /// isn't hooked up through SelectionDAG, return a type of the same width that 656 /// can be used with the relevant intrinsics. 
Specifically, handle the cases: 657 /// - <1 x T> => T for all T 658 /// - <N x i8> <=> i16, i32, 2xi32, 4xi32 (as needed) 659 /// - <N x T> where T is under 32 bits and the total size is 96 bits <=> <3 x 660 /// i32> 661 Type *intrinsicTypeFor(Type *LegalType); 662 663 bool visitLoadImpl(LoadInst &OrigLI, Type *PartType, 664 SmallVectorImpl<uint32_t> &AggIdxs, uint64_t AggByteOffset, 665 Value *&Result, const Twine &Name); 666 /// Return value is (Changed, ModifiedInPlace) 667 std::pair<bool, bool> visitStoreImpl(StoreInst &OrigSI, Type *PartType, 668 SmallVectorImpl<uint32_t> &AggIdxs, 669 uint64_t AggByteOffset, 670 const Twine &Name); 671 672 bool visitInstruction(Instruction &I) { return false; } 673 bool visitLoadInst(LoadInst &LI); 674 bool visitStoreInst(StoreInst &SI); 675 676 public: 677 LegalizeBufferContentTypesVisitor(const DataLayout &DL, LLVMContext &Ctx) 678 : IRB(Ctx), DL(DL) {} 679 bool processFunction(Function &F); 680 }; 681 } // namespace 682 683 Type *LegalizeBufferContentTypesVisitor::scalarArrayTypeAsVector(Type *T) { 684 ArrayType *AT = dyn_cast<ArrayType>(T); 685 if (!AT) 686 return T; 687 Type *ET = AT->getElementType(); 688 if (!ET->isSingleValueType() || isa<VectorType>(ET)) 689 report_fatal_error("loading non-scalar arrays from buffer fat pointers " 690 "should have recursed"); 691 if (!DL.typeSizeEqualsStoreSize(AT)) 692 report_fatal_error( 693 "loading padded arrays from buffer fat pinters should have recursed"); 694 return FixedVectorType::get(ET, AT->getNumElements()); 695 } 696 697 Value *LegalizeBufferContentTypesVisitor::arrayToVector(Value *V, 698 Type *TargetType, 699 const Twine &Name) { 700 Value *VectorRes = PoisonValue::get(TargetType); 701 auto *VT = cast<FixedVectorType>(TargetType); 702 unsigned EC = VT->getNumElements(); 703 for (auto I : iota_range<unsigned>(0, EC, /*Inclusive=*/false)) { 704 Value *Elem = IRB.CreateExtractValue(V, I, Name + ".elem." + Twine(I)); 705 VectorRes = IRB.CreateInsertElement(VectorRes, Elem, I, 706 Name + ".as.vec." + Twine(I)); 707 } 708 return VectorRes; 709 } 710 711 Value *LegalizeBufferContentTypesVisitor::vectorToArray(Value *V, 712 Type *OrigType, 713 const Twine &Name) { 714 Value *ArrayRes = PoisonValue::get(OrigType); 715 ArrayType *AT = cast<ArrayType>(OrigType); 716 unsigned EC = AT->getNumElements(); 717 for (auto I : iota_range<unsigned>(0, EC, /*Inclusive=*/false)) { 718 Value *Elem = IRB.CreateExtractElement(V, I, Name + ".elem." + Twine(I)); 719 ArrayRes = IRB.CreateInsertValue(ArrayRes, Elem, I, 720 Name + ".as.array." + Twine(I)); 721 } 722 return ArrayRes; 723 } 724 725 Type *LegalizeBufferContentTypesVisitor::legalNonAggregateFor(Type *T) { 726 TypeSize Size = DL.getTypeStoreSizeInBits(T); 727 // Implicitly zero-extend to the next byte if needed 728 if (!DL.typeSizeEqualsStoreSize(T)) 729 T = IRB.getIntNTy(Size.getFixedValue()); 730 Type *ElemTy = T->getScalarType(); 731 if (isa<PointerType, ScalableVectorType>(ElemTy)) { 732 // Pointers are always big enough, and we'll let scalable vectors through to 733 // fail in codegen. 734 return T; 735 } 736 unsigned ElemSize = DL.getTypeSizeInBits(ElemTy).getFixedValue(); 737 if (isPowerOf2_32(ElemSize) && ElemSize >= 16 && ElemSize <= 128) { 738 // [vectors of] anything that's 16/32/64/128 bits can be cast and split into 739 // legal buffer operations. 
740 return T; 741 } 742 Type *BestVectorElemType = nullptr; 743 if (Size.isKnownMultipleOf(32)) 744 BestVectorElemType = IRB.getInt32Ty(); 745 else if (Size.isKnownMultipleOf(16)) 746 BestVectorElemType = IRB.getInt16Ty(); 747 else 748 BestVectorElemType = IRB.getInt8Ty(); 749 unsigned NumCastElems = 750 Size.getFixedValue() / BestVectorElemType->getIntegerBitWidth(); 751 if (NumCastElems == 1) 752 return BestVectorElemType; 753 return FixedVectorType::get(BestVectorElemType, NumCastElems); 754 } 755 756 Value *LegalizeBufferContentTypesVisitor::makeLegalNonAggregate( 757 Value *V, Type *TargetType, const Twine &Name) { 758 Type *SourceType = V->getType(); 759 TypeSize SourceSize = DL.getTypeSizeInBits(SourceType); 760 TypeSize TargetSize = DL.getTypeSizeInBits(TargetType); 761 if (SourceSize != TargetSize) { 762 Type *ShortScalarTy = IRB.getIntNTy(SourceSize.getFixedValue()); 763 Type *ByteScalarTy = IRB.getIntNTy(TargetSize.getFixedValue()); 764 Value *AsScalar = IRB.CreateBitCast(V, ShortScalarTy, Name + ".as.scalar"); 765 Value *Zext = IRB.CreateZExt(AsScalar, ByteScalarTy, Name + ".zext"); 766 V = Zext; 767 SourceType = ByteScalarTy; 768 } 769 return IRB.CreateBitCast(V, TargetType, Name + ".legal"); 770 } 771 772 Value *LegalizeBufferContentTypesVisitor::makeIllegalNonAggregate( 773 Value *V, Type *OrigType, const Twine &Name) { 774 Type *LegalType = V->getType(); 775 TypeSize LegalSize = DL.getTypeSizeInBits(LegalType); 776 TypeSize OrigSize = DL.getTypeSizeInBits(OrigType); 777 if (LegalSize != OrigSize) { 778 Type *ShortScalarTy = IRB.getIntNTy(OrigSize.getFixedValue()); 779 Type *ByteScalarTy = IRB.getIntNTy(LegalSize.getFixedValue()); 780 Value *AsScalar = IRB.CreateBitCast(V, ByteScalarTy, Name + ".bytes.cast"); 781 Value *Trunc = IRB.CreateTrunc(AsScalar, ShortScalarTy, Name + ".trunc"); 782 return IRB.CreateBitCast(Trunc, OrigType, Name + ".orig"); 783 } 784 return IRB.CreateBitCast(V, OrigType, Name + ".real.ty"); 785 } 786 787 Type *LegalizeBufferContentTypesVisitor::intrinsicTypeFor(Type *LegalType) { 788 auto *VT = dyn_cast<FixedVectorType>(LegalType); 789 if (!VT) 790 return LegalType; 791 Type *ET = VT->getElementType(); 792 // Explicitly return the element type of 1-element vectors because the 793 // underlying intrinsics don't like <1 x T> even though it's a synonym for T. 
794 if (VT->getNumElements() == 1) 795 return ET; 796 if (DL.getTypeSizeInBits(LegalType) == 96 && DL.getTypeSizeInBits(ET) < 32) 797 return FixedVectorType::get(IRB.getInt32Ty(), 3); 798 if (ET->isIntegerTy(8)) { 799 switch (VT->getNumElements()) { 800 default: 801 return LegalType; // Let it crash later 802 case 1: 803 return IRB.getInt8Ty(); 804 case 2: 805 return IRB.getInt16Ty(); 806 case 4: 807 return IRB.getInt32Ty(); 808 case 8: 809 return FixedVectorType::get(IRB.getInt32Ty(), 2); 810 case 16: 811 return FixedVectorType::get(IRB.getInt32Ty(), 4); 812 } 813 } 814 return LegalType; 815 } 816 817 void LegalizeBufferContentTypesVisitor::getVecSlices( 818 Type *T, SmallVectorImpl<VecSlice> &Slices) { 819 Slices.clear(); 820 auto *VT = dyn_cast<FixedVectorType>(T); 821 if (!VT) 822 return; 823 824 uint64_t ElemBitWidth = 825 DL.getTypeSizeInBits(VT->getElementType()).getFixedValue(); 826 827 uint64_t ElemsPer4Words = 128 / ElemBitWidth; 828 uint64_t ElemsPer2Words = ElemsPer4Words / 2; 829 uint64_t ElemsPerWord = ElemsPer2Words / 2; 830 uint64_t ElemsPerShort = ElemsPerWord / 2; 831 uint64_t ElemsPerByte = ElemsPerShort / 2; 832 // If the elements evenly pack into 32-bit words, we can use 3-word stores, 833 // such as for <6 x bfloat> or <3 x i32>, but we can't dot his for, for 834 // example, <3 x i64>, since that's not slicing. 835 uint64_t ElemsPer3Words = ElemsPerWord * 3; 836 837 uint64_t TotalElems = VT->getNumElements(); 838 uint64_t Index = 0; 839 auto TrySlice = [&](unsigned MaybeLen) { 840 if (MaybeLen > 0 && Index + MaybeLen <= TotalElems) { 841 VecSlice Slice{/*Index=*/Index, /*Length=*/MaybeLen}; 842 Slices.push_back(Slice); 843 Index += MaybeLen; 844 return true; 845 } 846 return false; 847 }; 848 while (Index < TotalElems) { 849 TrySlice(ElemsPer4Words) || TrySlice(ElemsPer3Words) || 850 TrySlice(ElemsPer2Words) || TrySlice(ElemsPerWord) || 851 TrySlice(ElemsPerShort) || TrySlice(ElemsPerByte); 852 } 853 } 854 855 Value *LegalizeBufferContentTypesVisitor::extractSlice(Value *Vec, VecSlice S, 856 const Twine &Name) { 857 auto *VecVT = dyn_cast<FixedVectorType>(Vec->getType()); 858 if (!VecVT) 859 return Vec; 860 if (S.Length == VecVT->getNumElements() && S.Index == 0) 861 return Vec; 862 if (S.Length == 1) 863 return IRB.CreateExtractElement(Vec, S.Index, 864 Name + ".slice." + Twine(S.Index)); 865 SmallVector<int> Mask = llvm::to_vector( 866 llvm::iota_range<int>(S.Index, S.Index + S.Length, /*Inclusive=*/false)); 867 return IRB.CreateShuffleVector(Vec, Mask, Name + ".slice." + Twine(S.Index)); 868 } 869 870 Value *LegalizeBufferContentTypesVisitor::insertSlice(Value *Whole, Value *Part, 871 VecSlice S, 872 const Twine &Name) { 873 auto *WholeVT = dyn_cast<FixedVectorType>(Whole->getType()); 874 if (!WholeVT) 875 return Part; 876 if (S.Length == WholeVT->getNumElements() && S.Index == 0) 877 return Part; 878 if (S.Length == 1) { 879 return IRB.CreateInsertElement(Whole, Part, S.Index, 880 Name + ".slice." + Twine(S.Index)); 881 } 882 int NumElems = cast<FixedVectorType>(Whole->getType())->getNumElements(); 883 884 // Extend the slice with poisons to make the main shufflevector happy. 885 SmallVector<int> ExtPartMask(NumElems, -1); 886 for (auto [I, E] : llvm::enumerate( 887 MutableArrayRef<int>(ExtPartMask).take_front(S.Length))) { 888 E = I; 889 } 890 Value *ExtPart = IRB.CreateShuffleVector(Part, ExtPartMask, 891 Name + ".ext." 
+ Twine(S.Index)); 892 893 SmallVector<int> Mask = 894 llvm::to_vector(llvm::iota_range<int>(0, NumElems, /*Inclusive=*/false)); 895 for (auto [I, E] : 896 llvm::enumerate(MutableArrayRef<int>(Mask).slice(S.Index, S.Length))) 897 E = I + NumElems; 898 return IRB.CreateShuffleVector(Whole, ExtPart, Mask, 899 Name + ".parts." + Twine(S.Index)); 900 } 901 902 bool LegalizeBufferContentTypesVisitor::visitLoadImpl( 903 LoadInst &OrigLI, Type *PartType, SmallVectorImpl<uint32_t> &AggIdxs, 904 uint64_t AggByteOff, Value *&Result, const Twine &Name) { 905 if (auto *ST = dyn_cast<StructType>(PartType)) { 906 const StructLayout *Layout = DL.getStructLayout(ST); 907 bool Changed = false; 908 for (auto [I, ElemTy, Offset] : 909 llvm::enumerate(ST->elements(), Layout->getMemberOffsets())) { 910 AggIdxs.push_back(I); 911 Changed |= visitLoadImpl(OrigLI, ElemTy, AggIdxs, 912 AggByteOff + Offset.getFixedValue(), Result, 913 Name + "." + Twine(I)); 914 AggIdxs.pop_back(); 915 } 916 return Changed; 917 } 918 if (auto *AT = dyn_cast<ArrayType>(PartType)) { 919 Type *ElemTy = AT->getElementType(); 920 if (!ElemTy->isSingleValueType() || !DL.typeSizeEqualsStoreSize(ElemTy) || 921 ElemTy->isVectorTy()) { 922 TypeSize ElemStoreSize = DL.getTypeStoreSize(ElemTy); 923 bool Changed = false; 924 for (auto I : llvm::iota_range<uint32_t>(0, AT->getNumElements(), 925 /*Inclusive=*/false)) { 926 AggIdxs.push_back(I); 927 Changed |= visitLoadImpl(OrigLI, ElemTy, AggIdxs, 928 AggByteOff + I * ElemStoreSize.getFixedValue(), 929 Result, Name + Twine(I)); 930 AggIdxs.pop_back(); 931 } 932 return Changed; 933 } 934 } 935 936 // Typical case 937 938 Type *ArrayAsVecType = scalarArrayTypeAsVector(PartType); 939 Type *LegalType = legalNonAggregateFor(ArrayAsVecType); 940 941 SmallVector<VecSlice> Slices; 942 getVecSlices(LegalType, Slices); 943 bool HasSlices = Slices.size() > 1; 944 bool IsAggPart = !AggIdxs.empty(); 945 Value *LoadsRes; 946 if (!HasSlices && !IsAggPart) { 947 Type *LoadableType = intrinsicTypeFor(LegalType); 948 if (LoadableType == PartType) 949 return false; 950 951 IRB.SetInsertPoint(&OrigLI); 952 auto *NLI = cast<LoadInst>(OrigLI.clone()); 953 NLI->mutateType(LoadableType); 954 NLI = IRB.Insert(NLI); 955 NLI->setName(Name + ".loadable"); 956 957 LoadsRes = IRB.CreateBitCast(NLI, LegalType, Name + ".from.loadable"); 958 } else { 959 IRB.SetInsertPoint(&OrigLI); 960 LoadsRes = PoisonValue::get(LegalType); 961 Value *OrigPtr = OrigLI.getPointerOperand(); 962 // If we're needing to spill something into more than one load, its legal 963 // type will be a vector (ex. an i256 load will have LegalType = <8 x i32>). 964 // But if we're already a scalar (which can happen if we're splitting up a 965 // struct), the element type will be the legal type itself. 966 Type *ElemType = LegalType->getScalarType(); 967 unsigned ElemBytes = DL.getTypeStoreSize(ElemType); 968 AAMDNodes AANodes = OrigLI.getAAMetadata(); 969 if (IsAggPart && Slices.empty()) 970 Slices.push_back(VecSlice{/*Index=*/0, /*Length=*/1}); 971 for (VecSlice S : Slices) { 972 Type *SliceType = 973 S.Length != 1 ? FixedVectorType::get(ElemType, S.Length) : ElemType; 974 int64_t ByteOffset = AggByteOff + S.Index * ElemBytes; 975 // You can't reasonably expect loads to wrap around the edge of memory. 976 Value *NewPtr = IRB.CreateGEP( 977 IRB.getInt8Ty(), OrigLI.getPointerOperand(), IRB.getInt32(ByteOffset), 978 OrigPtr->getName() + ".off.ptr." 
+ Twine(ByteOffset), 979 GEPNoWrapFlags::noUnsignedWrap()); 980 Type *LoadableType = intrinsicTypeFor(SliceType); 981 LoadInst *NewLI = IRB.CreateAlignedLoad( 982 LoadableType, NewPtr, commonAlignment(OrigLI.getAlign(), ByteOffset), 983 Name + ".off." + Twine(ByteOffset)); 984 copyMetadataForLoad(*NewLI, OrigLI); 985 NewLI->setAAMetadata( 986 AANodes.adjustForAccess(ByteOffset, LoadableType, DL)); 987 NewLI->setAtomic(OrigLI.getOrdering(), OrigLI.getSyncScopeID()); 988 NewLI->setVolatile(OrigLI.isVolatile()); 989 Value *Loaded = IRB.CreateBitCast(NewLI, SliceType, 990 NewLI->getName() + ".from.loadable"); 991 LoadsRes = insertSlice(LoadsRes, Loaded, S, Name); 992 } 993 } 994 if (LegalType != ArrayAsVecType) 995 LoadsRes = makeIllegalNonAggregate(LoadsRes, ArrayAsVecType, Name); 996 if (ArrayAsVecType != PartType) 997 LoadsRes = vectorToArray(LoadsRes, PartType, Name); 998 999 if (IsAggPart) 1000 Result = IRB.CreateInsertValue(Result, LoadsRes, AggIdxs, Name); 1001 else 1002 Result = LoadsRes; 1003 return true; 1004 } 1005 1006 bool LegalizeBufferContentTypesVisitor::visitLoadInst(LoadInst &LI) { 1007 if (LI.getPointerAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER) 1008 return false; 1009 1010 SmallVector<uint32_t> AggIdxs; 1011 Type *OrigType = LI.getType(); 1012 Value *Result = PoisonValue::get(OrigType); 1013 bool Changed = visitLoadImpl(LI, OrigType, AggIdxs, 0, Result, LI.getName()); 1014 if (!Changed) 1015 return false; 1016 Result->takeName(&LI); 1017 LI.replaceAllUsesWith(Result); 1018 LI.eraseFromParent(); 1019 return Changed; 1020 } 1021 1022 std::pair<bool, bool> LegalizeBufferContentTypesVisitor::visitStoreImpl( 1023 StoreInst &OrigSI, Type *PartType, SmallVectorImpl<uint32_t> &AggIdxs, 1024 uint64_t AggByteOff, const Twine &Name) { 1025 if (auto *ST = dyn_cast<StructType>(PartType)) { 1026 const StructLayout *Layout = DL.getStructLayout(ST); 1027 bool Changed = false; 1028 for (auto [I, ElemTy, Offset] : 1029 llvm::enumerate(ST->elements(), Layout->getMemberOffsets())) { 1030 AggIdxs.push_back(I); 1031 Changed |= std::get<0>(visitStoreImpl(OrigSI, ElemTy, AggIdxs, 1032 AggByteOff + Offset.getFixedValue(), 1033 Name + "." 
+ Twine(I))); 1034 AggIdxs.pop_back(); 1035 } 1036 return std::make_pair(Changed, /*ModifiedInPlace=*/false); 1037 } 1038 if (auto *AT = dyn_cast<ArrayType>(PartType)) { 1039 Type *ElemTy = AT->getElementType(); 1040 if (!ElemTy->isSingleValueType() || !DL.typeSizeEqualsStoreSize(ElemTy) || 1041 ElemTy->isVectorTy()) { 1042 TypeSize ElemStoreSize = DL.getTypeStoreSize(ElemTy); 1043 bool Changed = false; 1044 for (auto I : llvm::iota_range<uint32_t>(0, AT->getNumElements(), 1045 /*Inclusive=*/false)) { 1046 AggIdxs.push_back(I); 1047 Changed |= std::get<0>(visitStoreImpl( 1048 OrigSI, ElemTy, AggIdxs, 1049 AggByteOff + I * ElemStoreSize.getFixedValue(), Name + Twine(I))); 1050 AggIdxs.pop_back(); 1051 } 1052 return std::make_pair(Changed, /*ModifiedInPlace=*/false); 1053 } 1054 } 1055 1056 Value *OrigData = OrigSI.getValueOperand(); 1057 Value *NewData = OrigData; 1058 1059 bool IsAggPart = !AggIdxs.empty(); 1060 if (IsAggPart) 1061 NewData = IRB.CreateExtractValue(NewData, AggIdxs, Name); 1062 1063 Type *ArrayAsVecType = scalarArrayTypeAsVector(PartType); 1064 if (ArrayAsVecType != PartType) { 1065 NewData = arrayToVector(NewData, ArrayAsVecType, Name); 1066 } 1067 1068 Type *LegalType = legalNonAggregateFor(ArrayAsVecType); 1069 if (LegalType != ArrayAsVecType) { 1070 NewData = makeLegalNonAggregate(NewData, LegalType, Name); 1071 } 1072 1073 SmallVector<VecSlice> Slices; 1074 getVecSlices(LegalType, Slices); 1075 bool NeedToSplit = Slices.size() > 1 || IsAggPart; 1076 if (!NeedToSplit) { 1077 Type *StorableType = intrinsicTypeFor(LegalType); 1078 if (StorableType == PartType) 1079 return std::make_pair(/*Changed=*/false, /*ModifiedInPlace=*/false); 1080 NewData = IRB.CreateBitCast(NewData, StorableType, Name + ".storable"); 1081 OrigSI.setOperand(0, NewData); 1082 return std::make_pair(/*Changed=*/true, /*ModifiedInPlace=*/true); 1083 } 1084 1085 Value *OrigPtr = OrigSI.getPointerOperand(); 1086 Type *ElemType = LegalType->getScalarType(); 1087 if (IsAggPart && Slices.empty()) 1088 Slices.push_back(VecSlice{/*Index=*/0, /*Length=*/1}); 1089 unsigned ElemBytes = DL.getTypeStoreSize(ElemType); 1090 AAMDNodes AANodes = OrigSI.getAAMetadata(); 1091 for (VecSlice S : Slices) { 1092 Type *SliceType = 1093 S.Length != 1 ? FixedVectorType::get(ElemType, S.Length) : ElemType; 1094 int64_t ByteOffset = AggByteOff + S.Index * ElemBytes; 1095 Value *NewPtr = 1096 IRB.CreateGEP(IRB.getInt8Ty(), OrigPtr, IRB.getInt32(ByteOffset), 1097 OrigPtr->getName() + ".part." 
+ Twine(S.Index), 1098 GEPNoWrapFlags::noUnsignedWrap()); 1099 Value *DataSlice = extractSlice(NewData, S, Name); 1100 Type *StorableType = intrinsicTypeFor(SliceType); 1101 DataSlice = IRB.CreateBitCast(DataSlice, StorableType, 1102 DataSlice->getName() + ".storable"); 1103 auto *NewSI = cast<StoreInst>(OrigSI.clone()); 1104 NewSI->setAlignment(commonAlignment(OrigSI.getAlign(), ByteOffset)); 1105 IRB.Insert(NewSI); 1106 NewSI->setOperand(0, DataSlice); 1107 NewSI->setOperand(1, NewPtr); 1108 NewSI->setAAMetadata(AANodes.adjustForAccess(ByteOffset, StorableType, DL)); 1109 } 1110 return std::make_pair(/*Changed=*/true, /*ModifiedInPlace=*/false); 1111 } 1112 1113 bool LegalizeBufferContentTypesVisitor::visitStoreInst(StoreInst &SI) { 1114 if (SI.getPointerAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER) 1115 return false; 1116 IRB.SetInsertPoint(&SI); 1117 SmallVector<uint32_t> AggIdxs; 1118 Value *OrigData = SI.getValueOperand(); 1119 auto [Changed, ModifiedInPlace] = 1120 visitStoreImpl(SI, OrigData->getType(), AggIdxs, 0, OrigData->getName()); 1121 if (Changed && !ModifiedInPlace) 1122 SI.eraseFromParent(); 1123 return Changed; 1124 } 1125 1126 bool LegalizeBufferContentTypesVisitor::processFunction(Function &F) { 1127 bool Changed = false; 1128 for (Instruction &I : make_early_inc_range(instructions(F))) { 1129 Changed |= visit(I); 1130 } 1131 return Changed; 1132 } 1133 1134 /// Return the ptr addrspace(8) and i32 (resource and offset parts) in a lowered 1135 /// buffer fat pointer constant. 1136 static std::pair<Constant *, Constant *> 1137 splitLoweredFatBufferConst(Constant *C) { 1138 assert(isSplitFatPtr(C->getType()) && "Not a split fat buffer pointer"); 1139 return std::make_pair(C->getAggregateElement(0u), C->getAggregateElement(1u)); 1140 } 1141 1142 namespace { 1143 /// Handle the remapping of ptr addrspace(7) constants. 1144 class FatPtrConstMaterializer final : public ValueMaterializer { 1145 BufferFatPtrToStructTypeMap *TypeMap; 1146 // An internal mapper that is used to recurse into the arguments of constants. 1147 // While the documentation for `ValueMapper` specifies not to use it 1148 // recursively, examination of the logic in mapValue() shows that it can 1149 // safely be used recursively when handling constants, like it does in its own 1150 // logic. 1151 ValueMapper InternalMapper; 1152 1153 Constant *materializeBufferFatPtrConst(Constant *C); 1154 1155 public: 1156 // UnderlyingMap is the value map this materializer will be filling. 
1157 FatPtrConstMaterializer(BufferFatPtrToStructTypeMap *TypeMap, 1158 ValueToValueMapTy &UnderlyingMap) 1159 : TypeMap(TypeMap), 1160 InternalMapper(UnderlyingMap, RF_None, TypeMap, this) {} 1161 virtual ~FatPtrConstMaterializer() = default; 1162 1163 Value *materialize(Value *V) override; 1164 }; 1165 } // namespace 1166 1167 Constant *FatPtrConstMaterializer::materializeBufferFatPtrConst(Constant *C) { 1168 Type *SrcTy = C->getType(); 1169 auto *NewTy = dyn_cast<StructType>(TypeMap->remapType(SrcTy)); 1170 if (C->isNullValue()) 1171 return ConstantAggregateZero::getNullValue(NewTy); 1172 if (isa<PoisonValue>(C)) { 1173 return ConstantStruct::get(NewTy, 1174 {PoisonValue::get(NewTy->getElementType(0)), 1175 PoisonValue::get(NewTy->getElementType(1))}); 1176 } 1177 if (isa<UndefValue>(C)) { 1178 return ConstantStruct::get(NewTy, 1179 {UndefValue::get(NewTy->getElementType(0)), 1180 UndefValue::get(NewTy->getElementType(1))}); 1181 } 1182 1183 if (auto *VC = dyn_cast<ConstantVector>(C)) { 1184 if (Constant *S = VC->getSplatValue()) { 1185 Constant *NewS = InternalMapper.mapConstant(*S); 1186 if (!NewS) 1187 return nullptr; 1188 auto [Rsrc, Off] = splitLoweredFatBufferConst(NewS); 1189 auto EC = VC->getType()->getElementCount(); 1190 return ConstantStruct::get(NewTy, {ConstantVector::getSplat(EC, Rsrc), 1191 ConstantVector::getSplat(EC, Off)}); 1192 } 1193 SmallVector<Constant *> Rsrcs; 1194 SmallVector<Constant *> Offs; 1195 for (Value *Op : VC->operand_values()) { 1196 auto *NewOp = dyn_cast_or_null<Constant>(InternalMapper.mapValue(*Op)); 1197 if (!NewOp) 1198 return nullptr; 1199 auto [Rsrc, Off] = splitLoweredFatBufferConst(NewOp); 1200 Rsrcs.push_back(Rsrc); 1201 Offs.push_back(Off); 1202 } 1203 Constant *RsrcVec = ConstantVector::get(Rsrcs); 1204 Constant *OffVec = ConstantVector::get(Offs); 1205 return ConstantStruct::get(NewTy, {RsrcVec, OffVec}); 1206 } 1207 1208 if (isa<GlobalValue>(C)) 1209 report_fatal_error("Global values containing ptr addrspace(7) (buffer " 1210 "fat pointer) values are not supported"); 1211 1212 if (isa<ConstantExpr>(C)) 1213 report_fatal_error("Constant exprs containing ptr addrspace(7) (buffer " 1214 "fat pointer) values should have been expanded earlier"); 1215 1216 return nullptr; 1217 } 1218 1219 Value *FatPtrConstMaterializer::materialize(Value *V) { 1220 Constant *C = dyn_cast<Constant>(V); 1221 if (!C) 1222 return nullptr; 1223 // Structs and other types that happen to contain fat pointers get remapped 1224 // by the mapValue() logic. 1225 if (!isBufferFatPtrConst(C)) 1226 return nullptr; 1227 return materializeBufferFatPtrConst(C); 1228 } 1229 1230 using PtrParts = std::pair<Value *, Value *>; 1231 namespace { 1232 // The visitor returns the resource and offset parts for an instruction if they 1233 // can be computed, or (nullptr, nullptr) for cases that don't have a meaningful 1234 // value mapping. 1235 class SplitPtrStructs : public InstVisitor<SplitPtrStructs, PtrParts> { 1236 ValueToValueMapTy RsrcParts; 1237 ValueToValueMapTy OffParts; 1238 1239 // Track instructions that have been rewritten into a user of the component 1240 // parts of their ptr addrspace(7) input. Instructions that produced 1241 // ptr addrspace(7) parts should **not** be RAUW'd before being added to this 1242 // set, as that replacement will be handled in a post-visit step. However, 1243 // instructions that yield values that aren't fat pointers (ex. 
ptrtoint) 1244 // should RAUW themselves with new instructions that use the split parts 1245 // of their arguments during processing. 1246 DenseSet<Instruction *> SplitUsers; 1247 1248 // Nodes that need a second look once we've computed the parts for all other 1249 // instructions to see if, for example, we really need to phi on the resource 1250 // part. 1251 SmallVector<Instruction *> Conditionals; 1252 // Temporary instructions produced while lowering conditionals that should be 1253 // killed. 1254 SmallVector<Instruction *> ConditionalTemps; 1255 1256 // Subtarget info, needed for determining what cache control bits to set. 1257 const TargetMachine *TM; 1258 const GCNSubtarget *ST = nullptr; 1259 1260 IRBuilder<> IRB; 1261 1262 // Copy metadata between instructions if applicable. 1263 void copyMetadata(Value *Dest, Value *Src); 1264 1265 // Get the resource and offset parts of the value V, inserting appropriate 1266 // extractvalue calls if needed. 1267 PtrParts getPtrParts(Value *V); 1268 1269 // Given an instruction that could produce multiple resource parts (a PHI or 1270 // select), collect the set of possible instructions that could have provided 1271 // its resource parts that it could have (the `Roots`) and the set of 1272 // conditional instructions visited during the search (`Seen`). If, after 1273 // removing the root of the search from `Seen` and `Roots`, `Seen` is a subset 1274 // of `Roots` and `Roots - Seen` contains one element, the resource part of 1275 // that element can replace the resource part of all other elements in `Seen`. 1276 void getPossibleRsrcRoots(Instruction *I, SmallPtrSetImpl<Value *> &Roots, 1277 SmallPtrSetImpl<Value *> &Seen); 1278 void processConditionals(); 1279 1280 // If an instruction hav been split into resource and offset parts, 1281 // delete that instruction. If any of its uses have not themselves been split 1282 // into parts (for example, an insertvalue), construct the structure 1283 // that the type rewrites declared should be produced by the dying instruction 1284 // and use that. 1285 // Also, kill the temporary extractvalue operations produced by the two-stage 1286 // lowering of PHIs and conditionals. 
1287 void killAndReplaceSplitInstructions(SmallVectorImpl<Instruction *> &Origs); 1288 1289 void setAlign(CallInst *Intr, Align A, unsigned RsrcArgIdx); 1290 void insertPreMemOpFence(AtomicOrdering Order, SyncScope::ID SSID); 1291 void insertPostMemOpFence(AtomicOrdering Order, SyncScope::ID SSID); 1292 Value *handleMemoryInst(Instruction *I, Value *Arg, Value *Ptr, Type *Ty, 1293 Align Alignment, AtomicOrdering Order, 1294 bool IsVolatile, SyncScope::ID SSID); 1295 1296 public: 1297 SplitPtrStructs(LLVMContext &Ctx, const TargetMachine *TM) 1298 : TM(TM), IRB(Ctx) {} 1299 1300 void processFunction(Function &F); 1301 1302 PtrParts visitInstruction(Instruction &I); 1303 PtrParts visitLoadInst(LoadInst &LI); 1304 PtrParts visitStoreInst(StoreInst &SI); 1305 PtrParts visitAtomicRMWInst(AtomicRMWInst &AI); 1306 PtrParts visitAtomicCmpXchgInst(AtomicCmpXchgInst &AI); 1307 PtrParts visitGetElementPtrInst(GetElementPtrInst &GEP); 1308 1309 PtrParts visitPtrToIntInst(PtrToIntInst &PI); 1310 PtrParts visitIntToPtrInst(IntToPtrInst &IP); 1311 PtrParts visitAddrSpaceCastInst(AddrSpaceCastInst &I); 1312 PtrParts visitICmpInst(ICmpInst &Cmp); 1313 PtrParts visitFreezeInst(FreezeInst &I); 1314 1315 PtrParts visitExtractElementInst(ExtractElementInst &I); 1316 PtrParts visitInsertElementInst(InsertElementInst &I); 1317 PtrParts visitShuffleVectorInst(ShuffleVectorInst &I); 1318 1319 PtrParts visitPHINode(PHINode &PHI); 1320 PtrParts visitSelectInst(SelectInst &SI); 1321 1322 PtrParts visitIntrinsicInst(IntrinsicInst &II); 1323 }; 1324 } // namespace 1325 1326 void SplitPtrStructs::copyMetadata(Value *Dest, Value *Src) { 1327 auto *DestI = dyn_cast<Instruction>(Dest); 1328 auto *SrcI = dyn_cast<Instruction>(Src); 1329 1330 if (!DestI || !SrcI) 1331 return; 1332 1333 DestI->copyMetadata(*SrcI); 1334 } 1335 1336 PtrParts SplitPtrStructs::getPtrParts(Value *V) { 1337 assert(isSplitFatPtr(V->getType()) && "it's not meaningful to get the parts " 1338 "of something that wasn't rewritten"); 1339 auto *RsrcEntry = &RsrcParts[V]; 1340 auto *OffEntry = &OffParts[V]; 1341 if (*RsrcEntry && *OffEntry) 1342 return {*RsrcEntry, *OffEntry}; 1343 1344 if (auto *C = dyn_cast<Constant>(V)) { 1345 auto [Rsrc, Off] = splitLoweredFatBufferConst(C); 1346 return {*RsrcEntry = Rsrc, *OffEntry = Off}; 1347 } 1348 1349 IRBuilder<>::InsertPointGuard Guard(IRB); 1350 if (auto *I = dyn_cast<Instruction>(V)) { 1351 LLVM_DEBUG(dbgs() << "Recursing to split parts of " << *I << "\n"); 1352 auto [Rsrc, Off] = visit(*I); 1353 if (Rsrc && Off) 1354 return {*RsrcEntry = Rsrc, *OffEntry = Off}; 1355 // We'll be creating the new values after the relevant instruction. 1356 // This instruction generates a value and so isn't a terminator. 1357 IRB.SetInsertPoint(*I->getInsertionPointAfterDef()); 1358 IRB.SetCurrentDebugLocation(I->getDebugLoc()); 1359 } else if (auto *A = dyn_cast<Argument>(V)) { 1360 IRB.SetInsertPointPastAllocas(A->getParent()); 1361 IRB.SetCurrentDebugLocation(DebugLoc()); 1362 } 1363 Value *Rsrc = IRB.CreateExtractValue(V, 0, V->getName() + ".rsrc"); 1364 Value *Off = IRB.CreateExtractValue(V, 1, V->getName() + ".off"); 1365 return {*RsrcEntry = Rsrc, *OffEntry = Off}; 1366 } 1367 1368 /// Returns the instruction that defines the resource part of the value V. 1369 /// Note that this is not getUnderlyingObject(), since that looks through 1370 /// operations like ptrmask which might modify the resource part. 
1371 /// 1372 /// We can limit ourselves to just looking through GEPs followed by looking 1373 /// through addrspacecasts because only those two operations preserve the 1374 /// resource part, and because operations on an `addrspace(8)` (which is the 1375 /// legal input to this addrspacecast) would produce a different resource part. 1376 static Value *rsrcPartRoot(Value *V) { 1377 while (auto *GEP = dyn_cast<GEPOperator>(V)) 1378 V = GEP->getPointerOperand(); 1379 while (auto *ASC = dyn_cast<AddrSpaceCastOperator>(V)) 1380 V = ASC->getPointerOperand(); 1381 return V; 1382 } 1383 1384 void SplitPtrStructs::getPossibleRsrcRoots(Instruction *I, 1385 SmallPtrSetImpl<Value *> &Roots, 1386 SmallPtrSetImpl<Value *> &Seen) { 1387 if (auto *PHI = dyn_cast<PHINode>(I)) { 1388 if (!Seen.insert(I).second) 1389 return; 1390 for (Value *In : PHI->incoming_values()) { 1391 In = rsrcPartRoot(In); 1392 Roots.insert(In); 1393 if (isa<PHINode, SelectInst>(In)) 1394 getPossibleRsrcRoots(cast<Instruction>(In), Roots, Seen); 1395 } 1396 } else if (auto *SI = dyn_cast<SelectInst>(I)) { 1397 if (!Seen.insert(SI).second) 1398 return; 1399 Value *TrueVal = rsrcPartRoot(SI->getTrueValue()); 1400 Value *FalseVal = rsrcPartRoot(SI->getFalseValue()); 1401 Roots.insert(TrueVal); 1402 Roots.insert(FalseVal); 1403 if (isa<PHINode, SelectInst>(TrueVal)) 1404 getPossibleRsrcRoots(cast<Instruction>(TrueVal), Roots, Seen); 1405 if (isa<PHINode, SelectInst>(FalseVal)) 1406 getPossibleRsrcRoots(cast<Instruction>(FalseVal), Roots, Seen); 1407 } else { 1408 llvm_unreachable("getPossibleRsrcParts() only works on phi and select"); 1409 } 1410 } 1411 1412 void SplitPtrStructs::processConditionals() { 1413 SmallDenseMap<Instruction *, Value *> FoundRsrcs; 1414 SmallPtrSet<Value *, 4> Roots; 1415 SmallPtrSet<Value *, 4> Seen; 1416 for (Instruction *I : Conditionals) { 1417 // These have to exist by now because we've visited these nodes. 1418 Value *Rsrc = RsrcParts[I]; 1419 Value *Off = OffParts[I]; 1420 assert(Rsrc && Off && "must have visited conditionals by now"); 1421 1422 std::optional<Value *> MaybeRsrc; 1423 auto MaybeFoundRsrc = FoundRsrcs.find(I); 1424 if (MaybeFoundRsrc != FoundRsrcs.end()) { 1425 MaybeRsrc = MaybeFoundRsrc->second; 1426 } else { 1427 IRBuilder<>::InsertPointGuard Guard(IRB); 1428 Roots.clear(); 1429 Seen.clear(); 1430 getPossibleRsrcRoots(I, Roots, Seen); 1431 LLVM_DEBUG(dbgs() << "Processing conditional: " << *I << "\n"); 1432 #ifndef NDEBUG 1433 for (Value *V : Roots) 1434 LLVM_DEBUG(dbgs() << "Root: " << *V << "\n"); 1435 for (Value *V : Seen) 1436 LLVM_DEBUG(dbgs() << "Seen: " << *V << "\n"); 1437 #endif 1438 // If we are our own possible root, then we shouldn't block our 1439 // replacement with a valid incoming value. 1440 Roots.erase(I); 1441 // We don't want to block the optimization for conditionals that don't 1442 // refer to themselves but did see themselves during the traversal. 1443 Seen.erase(I); 1444 1445 if (set_is_subset(Seen, Roots)) { 1446 auto Diff = set_difference(Roots, Seen); 1447 if (Diff.size() == 1) { 1448 Value *RootVal = *Diff.begin(); 1449 // Handle the case where previous loops already looked through 1450 // an addrspacecast. 
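// If the root is still a split fat-pointer struct, take its resource part;
// otherwise (for example, a `ptr addrspace(8)` value that was reached through
// an addrspacecast) the root itself already is the resource part.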
1451 if (isSplitFatPtr(RootVal->getType())) 1452 MaybeRsrc = std::get<0>(getPtrParts(RootVal)); 1453 else 1454 MaybeRsrc = RootVal; 1455 } 1456 } 1457 } 1458 1459 if (auto *PHI = dyn_cast<PHINode>(I)) { 1460 Value *NewRsrc; 1461 StructType *PHITy = cast<StructType>(PHI->getType()); 1462 IRB.SetInsertPoint(*PHI->getInsertionPointAfterDef()); 1463 IRB.SetCurrentDebugLocation(PHI->getDebugLoc()); 1464 if (MaybeRsrc) { 1465 NewRsrc = *MaybeRsrc; 1466 } else { 1467 Type *RsrcTy = PHITy->getElementType(0); 1468 auto *RsrcPHI = IRB.CreatePHI(RsrcTy, PHI->getNumIncomingValues()); 1469 RsrcPHI->takeName(Rsrc); 1470 for (auto [V, BB] : llvm::zip(PHI->incoming_values(), PHI->blocks())) { 1471 Value *VRsrc = std::get<0>(getPtrParts(V)); 1472 RsrcPHI->addIncoming(VRsrc, BB); 1473 } 1474 copyMetadata(RsrcPHI, PHI); 1475 NewRsrc = RsrcPHI; 1476 } 1477 1478 Type *OffTy = PHITy->getElementType(1); 1479 auto *NewOff = IRB.CreatePHI(OffTy, PHI->getNumIncomingValues()); 1480 NewOff->takeName(Off); 1481 for (auto [V, BB] : llvm::zip(PHI->incoming_values(), PHI->blocks())) { 1482 assert(OffParts.count(V) && "An offset part had to be created by now"); 1483 Value *VOff = std::get<1>(getPtrParts(V)); 1484 NewOff->addIncoming(VOff, BB); 1485 } 1486 copyMetadata(NewOff, PHI); 1487 1488 // Note: We don't eraseFromParent() the temporaries because we don't want 1489 // to put the correction maps in an inconsistent state. That'll be handled 1490 // during the rest of the killing. Also, `ValueToValueMapTy` guarantees 1491 // that references in that map will be updated as well. 1492 ConditionalTemps.push_back(cast<Instruction>(Rsrc)); 1493 ConditionalTemps.push_back(cast<Instruction>(Off)); 1494 Rsrc->replaceAllUsesWith(NewRsrc); 1495 Off->replaceAllUsesWith(NewOff); 1496 1497 // Save on recomputing the cycle traversals in known-root cases. 1498 if (MaybeRsrc) 1499 for (Value *V : Seen) 1500 FoundRsrcs[cast<Instruction>(V)] = NewRsrc; 1501 } else if (isa<SelectInst>(I)) { 1502 if (MaybeRsrc) { 1503 ConditionalTemps.push_back(cast<Instruction>(Rsrc)); 1504 Rsrc->replaceAllUsesWith(*MaybeRsrc); 1505 for (Value *V : Seen) 1506 FoundRsrcs[cast<Instruction>(V)] = *MaybeRsrc; 1507 } 1508 } else { 1509 llvm_unreachable("Only PHIs and selects go in the conditionals list"); 1510 } 1511 } 1512 } 1513 1514 void SplitPtrStructs::killAndReplaceSplitInstructions( 1515 SmallVectorImpl<Instruction *> &Origs) { 1516 for (Instruction *I : ConditionalTemps) 1517 I->eraseFromParent(); 1518 1519 for (Instruction *I : Origs) { 1520 if (!SplitUsers.contains(I)) 1521 continue; 1522 1523 SmallVector<DbgValueInst *> Dbgs; 1524 findDbgValues(Dbgs, I); 1525 for (auto *Dbg : Dbgs) { 1526 IRB.SetInsertPoint(Dbg); 1527 auto &DL = I->getDataLayout(); 1528 assert(isSplitFatPtr(I->getType()) && 1529 "We should've RAUW'd away loads, stores, etc.
at this point"); 1530 auto *OffDbg = cast<DbgValueInst>(Dbg->clone()); 1531 copyMetadata(OffDbg, Dbg); 1532 auto [Rsrc, Off] = getPtrParts(I); 1533 1534 int64_t RsrcSz = DL.getTypeSizeInBits(Rsrc->getType()); 1535 int64_t OffSz = DL.getTypeSizeInBits(Off->getType()); 1536 1537 std::optional<DIExpression *> RsrcExpr = 1538 DIExpression::createFragmentExpression(Dbg->getExpression(), 0, 1539 RsrcSz); 1540 std::optional<DIExpression *> OffExpr = 1541 DIExpression::createFragmentExpression(Dbg->getExpression(), RsrcSz, 1542 OffSz); 1543 if (OffExpr) { 1544 OffDbg->setExpression(*OffExpr); 1545 OffDbg->replaceVariableLocationOp(I, Off); 1546 IRB.Insert(OffDbg); 1547 } else { 1548 OffDbg->deleteValue(); 1549 } 1550 if (RsrcExpr) { 1551 Dbg->setExpression(*RsrcExpr); 1552 Dbg->replaceVariableLocationOp(I, Rsrc); 1553 } else { 1554 Dbg->replaceVariableLocationOp(I, UndefValue::get(I->getType())); 1555 } 1556 } 1557 1558 Value *Poison = PoisonValue::get(I->getType()); 1559 I->replaceUsesWithIf(Poison, [&](const Use &U) -> bool { 1560 if (const auto *UI = dyn_cast<Instruction>(U.getUser())) 1561 return SplitUsers.contains(UI); 1562 return false; 1563 }); 1564 1565 if (I->use_empty()) { 1566 I->eraseFromParent(); 1567 continue; 1568 } 1569 IRB.SetInsertPoint(*I->getInsertionPointAfterDef()); 1570 IRB.SetCurrentDebugLocation(I->getDebugLoc()); 1571 auto [Rsrc, Off] = getPtrParts(I); 1572 Value *Struct = PoisonValue::get(I->getType()); 1573 Struct = IRB.CreateInsertValue(Struct, Rsrc, 0); 1574 Struct = IRB.CreateInsertValue(Struct, Off, 1); 1575 copyMetadata(Struct, I); 1576 Struct->takeName(I); 1577 I->replaceAllUsesWith(Struct); 1578 I->eraseFromParent(); 1579 } 1580 } 1581 1582 void SplitPtrStructs::setAlign(CallInst *Intr, Align A, unsigned RsrcArgIdx) { 1583 LLVMContext &Ctx = Intr->getContext(); 1584 Intr->addParamAttr(RsrcArgIdx, Attribute::getWithAlignment(Ctx, A)); 1585 } 1586 1587 void SplitPtrStructs::insertPreMemOpFence(AtomicOrdering Order, 1588 SyncScope::ID SSID) { 1589 switch (Order) { 1590 case AtomicOrdering::Release: 1591 case AtomicOrdering::AcquireRelease: 1592 case AtomicOrdering::SequentiallyConsistent: 1593 IRB.CreateFence(AtomicOrdering::Release, SSID); 1594 break; 1595 default: 1596 break; 1597 } 1598 } 1599 1600 void SplitPtrStructs::insertPostMemOpFence(AtomicOrdering Order, 1601 SyncScope::ID SSID) { 1602 switch (Order) { 1603 case AtomicOrdering::Acquire: 1604 case AtomicOrdering::AcquireRelease: 1605 case AtomicOrdering::SequentiallyConsistent: 1606 IRB.CreateFence(AtomicOrdering::Acquire, SSID); 1607 break; 1608 default: 1609 break; 1610 } 1611 } 1612 1613 Value *SplitPtrStructs::handleMemoryInst(Instruction *I, Value *Arg, Value *Ptr, 1614 Type *Ty, Align Alignment, 1615 AtomicOrdering Order, bool IsVolatile, 1616 SyncScope::ID SSID) { 1617 IRB.SetInsertPoint(I); 1618 1619 auto [Rsrc, Off] = getPtrParts(Ptr); 1620 SmallVector<Value *, 5> Args; 1621 if (Arg) 1622 Args.push_back(Arg); 1623 Args.push_back(Rsrc); 1624 Args.push_back(Off); 1625 insertPreMemOpFence(Order, SSID); 1626 // soffset is always 0 for these cases, where we always want any offset to be 1627 // part of bounds checking and we don't know which parts of the GEPs is 1628 // uniform. 1629 Args.push_back(IRB.getInt32(0)); 1630 1631 uint32_t Aux = 0; 1632 if (IsVolatile) 1633 Aux |= AMDGPU::CPol::VOLATILE; 1634 Args.push_back(IRB.getInt32(Aux)); 1635 1636 Intrinsic::ID IID = Intrinsic::not_intrinsic; 1637 if (isa<LoadInst>(I)) 1638 IID = Order == AtomicOrdering::NotAtomic 1639 ? 
Intrinsic::amdgcn_raw_ptr_buffer_load 1640 : Intrinsic::amdgcn_raw_ptr_atomic_buffer_load; 1641 else if (isa<StoreInst>(I)) 1642 IID = Intrinsic::amdgcn_raw_ptr_buffer_store; 1643 else if (auto *RMW = dyn_cast<AtomicRMWInst>(I)) { 1644 switch (RMW->getOperation()) { 1645 case AtomicRMWInst::Xchg: 1646 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_swap; 1647 break; 1648 case AtomicRMWInst::Add: 1649 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_add; 1650 break; 1651 case AtomicRMWInst::Sub: 1652 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_sub; 1653 break; 1654 case AtomicRMWInst::And: 1655 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_and; 1656 break; 1657 case AtomicRMWInst::Or: 1658 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_or; 1659 break; 1660 case AtomicRMWInst::Xor: 1661 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_xor; 1662 break; 1663 case AtomicRMWInst::Max: 1664 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_smax; 1665 break; 1666 case AtomicRMWInst::Min: 1667 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_smin; 1668 break; 1669 case AtomicRMWInst::UMax: 1670 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_umax; 1671 break; 1672 case AtomicRMWInst::UMin: 1673 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_umin; 1674 break; 1675 case AtomicRMWInst::FAdd: 1676 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_fadd; 1677 break; 1678 case AtomicRMWInst::FMax: 1679 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_fmax; 1680 break; 1681 case AtomicRMWInst::FMin: 1682 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_fmin; 1683 break; 1684 case AtomicRMWInst::FSub: { 1685 report_fatal_error("atomic floating point subtraction not supported for " 1686 "buffer resources and should've been expanded away"); 1687 break; 1688 } 1689 case AtomicRMWInst::Nand: 1690 report_fatal_error("atomic nand not supported for buffer resources and " 1691 "should've been expanded away"); 1692 break; 1693 case AtomicRMWInst::UIncWrap: 1694 case AtomicRMWInst::UDecWrap: 1695 report_fatal_error("wrapping increment/decrement not supported for " 1696 "buffer resources and should've been expanded away"); 1697 break; 1698 case AtomicRMWInst::BAD_BINOP: 1699 llvm_unreachable("Not sure how we got a bad binop"); 1700 case AtomicRMWInst::USubCond: 1701 case AtomicRMWInst::USubSat: 1702 report_fatal_error("conditional and saturating subtraction not yet " "implemented for buffer fat pointers"); break; 1703 } 1704 } 1705 1706 auto *Call = IRB.CreateIntrinsic(IID, Ty, Args); 1707 copyMetadata(Call, I); 1708 setAlign(Call, Alignment, Arg ? 1 : 0); 1709 Call->takeName(I); 1710 1711 insertPostMemOpFence(Order, SSID); 1712 // The "no moving p7 directly" rewrites ensure that this load or store won't 1713 // itself need to be split into parts.
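// At this point (illustrative), a `load float, ptr addrspace(7) %p` has become
//   %v = call float @llvm.amdgcn.raw.ptr.buffer.load.f32(
//            ptr addrspace(8) %p.rsrc, i32 %p.off, i32 0, i32 0)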
1714 SplitUsers.insert(I); 1715 I->replaceAllUsesWith(Call); 1716 return Call; 1717 } 1718 1719 PtrParts SplitPtrStructs::visitInstruction(Instruction &I) { 1720 return {nullptr, nullptr}; 1721 } 1722 1723 PtrParts SplitPtrStructs::visitLoadInst(LoadInst &LI) { 1724 if (!isSplitFatPtr(LI.getPointerOperandType())) 1725 return {nullptr, nullptr}; 1726 handleMemoryInst(&LI, nullptr, LI.getPointerOperand(), LI.getType(), 1727 LI.getAlign(), LI.getOrdering(), LI.isVolatile(), 1728 LI.getSyncScopeID()); 1729 return {nullptr, nullptr}; 1730 } 1731 1732 PtrParts SplitPtrStructs::visitStoreInst(StoreInst &SI) { 1733 if (!isSplitFatPtr(SI.getPointerOperandType())) 1734 return {nullptr, nullptr}; 1735 Value *Arg = SI.getValueOperand(); 1736 handleMemoryInst(&SI, Arg, SI.getPointerOperand(), Arg->getType(), 1737 SI.getAlign(), SI.getOrdering(), SI.isVolatile(), 1738 SI.getSyncScopeID()); 1739 return {nullptr, nullptr}; 1740 } 1741 1742 PtrParts SplitPtrStructs::visitAtomicRMWInst(AtomicRMWInst &AI) { 1743 if (!isSplitFatPtr(AI.getPointerOperand()->getType())) 1744 return {nullptr, nullptr}; 1745 Value *Arg = AI.getValOperand(); 1746 handleMemoryInst(&AI, Arg, AI.getPointerOperand(), Arg->getType(), 1747 AI.getAlign(), AI.getOrdering(), AI.isVolatile(), 1748 AI.getSyncScopeID()); 1749 return {nullptr, nullptr}; 1750 } 1751 1752 // Unlike load, store, and RMW, cmpxchg needs special handling to account 1753 // for the boolean argument. 1754 PtrParts SplitPtrStructs::visitAtomicCmpXchgInst(AtomicCmpXchgInst &AI) { 1755 Value *Ptr = AI.getPointerOperand(); 1756 if (!isSplitFatPtr(Ptr->getType())) 1757 return {nullptr, nullptr}; 1758 IRB.SetInsertPoint(&AI); 1759 1760 Type *Ty = AI.getNewValOperand()->getType(); 1761 AtomicOrdering Order = AI.getMergedOrdering(); 1762 SyncScope::ID SSID = AI.getSyncScopeID(); 1763 bool IsNonTemporal = AI.getMetadata(LLVMContext::MD_nontemporal); 1764 1765 auto [Rsrc, Off] = getPtrParts(Ptr); 1766 insertPreMemOpFence(Order, SSID); 1767 1768 uint32_t Aux = 0; 1769 if (IsNonTemporal) 1770 Aux |= AMDGPU::CPol::SLC; 1771 if (AI.isVolatile()) 1772 Aux |= AMDGPU::CPol::VOLATILE; 1773 auto *Call = 1774 IRB.CreateIntrinsic(Intrinsic::amdgcn_raw_ptr_buffer_atomic_cmpswap, Ty, 1775 {AI.getNewValOperand(), AI.getCompareOperand(), Rsrc, 1776 Off, IRB.getInt32(0), IRB.getInt32(Aux)}); 1777 copyMetadata(Call, &AI); 1778 setAlign(Call, AI.getAlign(), 2); 1779 Call->takeName(&AI); 1780 insertPostMemOpFence(Order, SSID); 1781 1782 Value *Res = PoisonValue::get(AI.getType()); 1783 Res = IRB.CreateInsertValue(Res, Call, 0); 1784 if (!AI.isWeak()) { 1785 Value *Succeeded = IRB.CreateICmpEQ(Call, AI.getCompareOperand()); 1786 Res = IRB.CreateInsertValue(Res, Succeeded, 1); 1787 } 1788 SplitUsers.insert(&AI); 1789 AI.replaceAllUsesWith(Res); 1790 return {nullptr, nullptr}; 1791 } 1792 1793 PtrParts SplitPtrStructs::visitGetElementPtrInst(GetElementPtrInst &GEP) { 1794 using namespace llvm::PatternMatch; 1795 Value *Ptr = GEP.getPointerOperand(); 1796 if (!isSplitFatPtr(Ptr->getType())) 1797 return {nullptr, nullptr}; 1798 IRB.SetInsertPoint(&GEP); 1799 1800 auto [Rsrc, Off] = getPtrParts(Ptr); 1801 const DataLayout &DL = GEP.getDataLayout(); 1802 bool IsNUW = GEP.hasNoUnsignedWrap(); 1803 bool IsNUSW = GEP.hasNoUnsignedSignedWrap(); 1804 1805 // In order to call emitGEPOffset() and thus not have to reimplement it, 1806 // we need the GEP result to have ptr addrspace(7) type. 
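// So we temporarily mutate the GEP's type back to `ptr addrspace(7)` (or a
// vector thereof), let emitGEPOffset() accumulate the byte offset, and then
// restore the split struct type before continuing.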
1807 Type *FatPtrTy = IRB.getPtrTy(AMDGPUAS::BUFFER_FAT_POINTER); 1808 if (auto *VT = dyn_cast<VectorType>(Off->getType())) 1809 FatPtrTy = VectorType::get(FatPtrTy, VT->getElementCount()); 1810 GEP.mutateType(FatPtrTy); 1811 Value *OffAccum = emitGEPOffset(&IRB, DL, &GEP); 1812 GEP.mutateType(Ptr->getType()); 1813 if (match(OffAccum, m_Zero())) { // Constant-zero offset 1814 SplitUsers.insert(&GEP); 1815 return {Rsrc, Off}; 1816 } 1817 1818 bool HasNonNegativeOff = false; 1819 if (auto *CI = dyn_cast<ConstantInt>(OffAccum)) { 1820 HasNonNegativeOff = !CI->isNegative(); 1821 } 1822 Value *NewOff; 1823 if (match(Off, m_Zero())) { 1824 NewOff = OffAccum; 1825 } else { 1826 NewOff = IRB.CreateAdd(Off, OffAccum, "", 1827 /*hasNUW=*/IsNUW || (IsNUSW && HasNonNegativeOff), 1828 /*hasNSW=*/false); 1829 } 1830 copyMetadata(NewOff, &GEP); 1831 NewOff->takeName(&GEP); 1832 SplitUsers.insert(&GEP); 1833 return {Rsrc, NewOff}; 1834 } 1835 1836 PtrParts SplitPtrStructs::visitPtrToIntInst(PtrToIntInst &PI) { 1837 Value *Ptr = PI.getPointerOperand(); 1838 if (!isSplitFatPtr(Ptr->getType())) 1839 return {nullptr, nullptr}; 1840 IRB.SetInsertPoint(&PI); 1841 1842 Type *ResTy = PI.getType(); 1843 unsigned Width = ResTy->getScalarSizeInBits(); 1844 1845 auto [Rsrc, Off] = getPtrParts(Ptr); 1846 const DataLayout &DL = PI.getDataLayout(); 1847 unsigned FatPtrWidth = DL.getPointerSizeInBits(AMDGPUAS::BUFFER_FAT_POINTER); 1848 1849 Value *Res; 1850 if (Width <= BufferOffsetWidth) { 1851 Res = IRB.CreateIntCast(Off, ResTy, /*isSigned=*/false, 1852 PI.getName() + ".off"); 1853 } else { 1854 Value *RsrcInt = IRB.CreatePtrToInt(Rsrc, ResTy, PI.getName() + ".rsrc"); 1855 Value *Shl = IRB.CreateShl( 1856 RsrcInt, 1857 ConstantExpr::getIntegerValue(ResTy, APInt(Width, BufferOffsetWidth)), 1858 "", Width >= FatPtrWidth, Width > FatPtrWidth); 1859 Value *OffCast = IRB.CreateIntCast(Off, ResTy, /*isSigned=*/false, 1860 PI.getName() + ".off"); 1861 Res = IRB.CreateOr(Shl, OffCast); 1862 } 1863 1864 copyMetadata(Res, &PI); 1865 Res->takeName(&PI); 1866 SplitUsers.insert(&PI); 1867 PI.replaceAllUsesWith(Res); 1868 return {nullptr, nullptr}; 1869 } 1870 1871 PtrParts SplitPtrStructs::visitIntToPtrInst(IntToPtrInst &IP) { 1872 if (!isSplitFatPtr(IP.getType())) 1873 return {nullptr, nullptr}; 1874 IRB.SetInsertPoint(&IP); 1875 const DataLayout &DL = IP.getDataLayout(); 1876 unsigned RsrcPtrWidth = DL.getPointerSizeInBits(AMDGPUAS::BUFFER_RESOURCE); 1877 Value *Int = IP.getOperand(0); 1878 Type *IntTy = Int->getType(); 1879 Type *RsrcIntTy = IntTy->getWithNewBitWidth(RsrcPtrWidth); 1880 unsigned Width = IntTy->getScalarSizeInBits(); 1881 1882 auto *RetTy = cast<StructType>(IP.getType()); 1883 Type *RsrcTy = RetTy->getElementType(0); 1884 Type *OffTy = RetTy->getElementType(1); 1885 Value *RsrcPart = IRB.CreateLShr( 1886 Int, 1887 ConstantExpr::getIntegerValue(IntTy, APInt(Width, BufferOffsetWidth))); 1888 Value *RsrcInt = IRB.CreateIntCast(RsrcPart, RsrcIntTy, /*isSigned=*/false); 1889 Value *Rsrc = IRB.CreateIntToPtr(RsrcInt, RsrcTy, IP.getName() + ".rsrc"); 1890 Value *Off = 1891 IRB.CreateIntCast(Int, OffTy, /*IsSigned=*/false, IP.getName() + ".off"); 1892 1893 copyMetadata(Rsrc, &IP); 1894 SplitUsers.insert(&IP); 1895 return {Rsrc, Off}; 1896 } 1897 1898 PtrParts SplitPtrStructs::visitAddrSpaceCastInst(AddrSpaceCastInst &I) { 1899 if (!isSplitFatPtr(I.getType())) 1900 return {nullptr, nullptr}; 1901 IRB.SetInsertPoint(&I); 1902 Value *In = I.getPointerOperand(); 1903 // No-op casts preserve parts 1904 if (In->getType() == 
I.getType()) { 1905 auto [Rsrc, Off] = getPtrParts(In); 1906 SplitUsers.insert(&I); 1907 return {Rsrc, Off}; 1908 } 1909 if (I.getSrcAddressSpace() != AMDGPUAS::BUFFER_RESOURCE) 1910 report_fatal_error("Only buffer resources (addrspace 8) can be cast to " 1911 "buffer fat pointers (addrspace 7)"); 1912 Type *OffTy = cast<StructType>(I.getType())->getElementType(1); 1913 Value *ZeroOff = Constant::getNullValue(OffTy); 1914 SplitUsers.insert(&I); 1915 return {In, ZeroOff}; 1916 } 1917 1918 PtrParts SplitPtrStructs::visitICmpInst(ICmpInst &Cmp) { 1919 Value *Lhs = Cmp.getOperand(0); 1920 if (!isSplitFatPtr(Lhs->getType())) 1921 return {nullptr, nullptr}; 1922 Value *Rhs = Cmp.getOperand(1); 1923 IRB.SetInsertPoint(&Cmp); 1924 ICmpInst::Predicate Pred = Cmp.getPredicate(); 1925 1926 assert((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) && 1927 "Pointer comparison is only equal or unequal"); 1928 auto [LhsRsrc, LhsOff] = getPtrParts(Lhs); 1929 auto [RhsRsrc, RhsOff] = getPtrParts(Rhs); 1930 Value *RsrcCmp = 1931 IRB.CreateICmp(Pred, LhsRsrc, RhsRsrc, Cmp.getName() + ".rsrc"); 1932 copyMetadata(RsrcCmp, &Cmp); 1933 Value *OffCmp = IRB.CreateICmp(Pred, LhsOff, RhsOff, Cmp.getName() + ".off"); 1934 copyMetadata(OffCmp, &Cmp); 1935 1936 Value *Res = nullptr; 1937 if (Pred == ICmpInst::ICMP_EQ) 1938 Res = IRB.CreateAnd(RsrcCmp, OffCmp); 1939 else if (Pred == ICmpInst::ICMP_NE) 1940 Res = IRB.CreateOr(RsrcCmp, OffCmp); 1941 copyMetadata(Res, &Cmp); 1942 Res->takeName(&Cmp); 1943 SplitUsers.insert(&Cmp); 1944 Cmp.replaceAllUsesWith(Res); 1945 return {nullptr, nullptr}; 1946 } 1947 1948 PtrParts SplitPtrStructs::visitFreezeInst(FreezeInst &I) { 1949 if (!isSplitFatPtr(I.getType())) 1950 return {nullptr, nullptr}; 1951 IRB.SetInsertPoint(&I); 1952 auto [Rsrc, Off] = getPtrParts(I.getOperand(0)); 1953 1954 Value *RsrcRes = IRB.CreateFreeze(Rsrc, I.getName() + ".rsrc"); 1955 copyMetadata(RsrcRes, &I); 1956 Value *OffRes = IRB.CreateFreeze(Off, I.getName() + ".off"); 1957 copyMetadata(OffRes, &I); 1958 SplitUsers.insert(&I); 1959 return {RsrcRes, OffRes}; 1960 } 1961 1962 PtrParts SplitPtrStructs::visitExtractElementInst(ExtractElementInst &I) { 1963 if (!isSplitFatPtr(I.getType())) 1964 return {nullptr, nullptr}; 1965 IRB.SetInsertPoint(&I); 1966 Value *Vec = I.getVectorOperand(); 1967 Value *Idx = I.getIndexOperand(); 1968 auto [Rsrc, Off] = getPtrParts(Vec); 1969 1970 Value *RsrcRes = IRB.CreateExtractElement(Rsrc, Idx, I.getName() + ".rsrc"); 1971 copyMetadata(RsrcRes, &I); 1972 Value *OffRes = IRB.CreateExtractElement(Off, Idx, I.getName() + ".off"); 1973 copyMetadata(OffRes, &I); 1974 SplitUsers.insert(&I); 1975 return {RsrcRes, OffRes}; 1976 } 1977 1978 PtrParts SplitPtrStructs::visitInsertElementInst(InsertElementInst &I) { 1979 // The mutated instructions temporarily don't return vectors, and so 1980 // we need the generic getType() here to avoid crashes. 
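// (Calling InsertElementInst::getType() directly would cast<VectorType> the
// temporarily struct-typed result and assert.)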
1981 if (!isSplitFatPtr(cast<Instruction>(I).getType())) 1982 return {nullptr, nullptr}; 1983 IRB.SetInsertPoint(&I); 1984 Value *Vec = I.getOperand(0); 1985 Value *Elem = I.getOperand(1); 1986 Value *Idx = I.getOperand(2); 1987 auto [VecRsrc, VecOff] = getPtrParts(Vec); 1988 auto [ElemRsrc, ElemOff] = getPtrParts(Elem); 1989 1990 Value *RsrcRes = 1991 IRB.CreateInsertElement(VecRsrc, ElemRsrc, Idx, I.getName() + ".rsrc"); 1992 copyMetadata(RsrcRes, &I); 1993 Value *OffRes = 1994 IRB.CreateInsertElement(VecOff, ElemOff, Idx, I.getName() + ".off"); 1995 copyMetadata(OffRes, &I); 1996 SplitUsers.insert(&I); 1997 return {RsrcRes, OffRes}; 1998 } 1999 2000 PtrParts SplitPtrStructs::visitShuffleVectorInst(ShuffleVectorInst &I) { 2001 // Cast is needed for the same reason as insertelement's. 2002 if (!isSplitFatPtr(cast<Instruction>(I).getType())) 2003 return {nullptr, nullptr}; 2004 IRB.SetInsertPoint(&I); 2005 2006 Value *V1 = I.getOperand(0); 2007 Value *V2 = I.getOperand(1); 2008 ArrayRef<int> Mask = I.getShuffleMask(); 2009 auto [V1Rsrc, V1Off] = getPtrParts(V1); 2010 auto [V2Rsrc, V2Off] = getPtrParts(V2); 2011 2012 Value *RsrcRes = 2013 IRB.CreateShuffleVector(V1Rsrc, V2Rsrc, Mask, I.getName() + ".rsrc"); 2014 copyMetadata(RsrcRes, &I); 2015 Value *OffRes = 2016 IRB.CreateShuffleVector(V1Off, V2Off, Mask, I.getName() + ".off"); 2017 copyMetadata(OffRes, &I); 2018 SplitUsers.insert(&I); 2019 return {RsrcRes, OffRes}; 2020 } 2021 2022 PtrParts SplitPtrStructs::visitPHINode(PHINode &PHI) { 2023 if (!isSplitFatPtr(PHI.getType())) 2024 return {nullptr, nullptr}; 2025 IRB.SetInsertPoint(*PHI.getInsertionPointAfterDef()); 2026 // Phi nodes will be handled in post-processing after we've visited every 2027 // instruction. However, instead of just returning {nullptr, nullptr}, 2028 // we explicitly create the temporary extractvalue operations that are our 2029 // temporary results so that they end up at the beginning of the block with 2030 // the PHIs. 2031 Value *TmpRsrc = IRB.CreateExtractValue(&PHI, 0, PHI.getName() + ".rsrc"); 2032 Value *TmpOff = IRB.CreateExtractValue(&PHI, 1, PHI.getName() + ".off"); 2033 Conditionals.push_back(&PHI); 2034 SplitUsers.insert(&PHI); 2035 return {TmpRsrc, TmpOff}; 2036 } 2037 2038 PtrParts SplitPtrStructs::visitSelectInst(SelectInst &SI) { 2039 if (!isSplitFatPtr(SI.getType())) 2040 return {nullptr, nullptr}; 2041 IRB.SetInsertPoint(&SI); 2042 2043 Value *Cond = SI.getCondition(); 2044 Value *True = SI.getTrueValue(); 2045 Value *False = SI.getFalseValue(); 2046 auto [TrueRsrc, TrueOff] = getPtrParts(True); 2047 auto [FalseRsrc, FalseOff] = getPtrParts(False); 2048 2049 Value *RsrcRes = 2050 IRB.CreateSelect(Cond, TrueRsrc, FalseRsrc, SI.getName() + ".rsrc", &SI); 2051 copyMetadata(RsrcRes, &SI); 2052 Conditionals.push_back(&SI); 2053 Value *OffRes = 2054 IRB.CreateSelect(Cond, TrueOff, FalseOff, SI.getName() + ".off", &SI); 2055 copyMetadata(OffRes, &SI); 2056 SplitUsers.insert(&SI); 2057 return {RsrcRes, OffRes}; 2058 } 2059 2060 /// Returns true if this intrinsic needs to be removed when it is 2061 /// applied to `ptr addrspace(7)` values. Calls to these intrinsics are 2062 /// rewritten into calls to versions of that intrinsic on the resource 2063 /// descriptor. 
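/// For example, `llvm.ptrmask` on a fat pointer is implemented by masking only
/// the offset word, while the invariant-related intrinsics (invariant.start/end,
/// launder/strip.invariant.group) are re-issued on the `ptr addrspace(8)`
/// resource part (see visitIntrinsicInst below).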
2064 static bool isRemovablePointerIntrinsic(Intrinsic::ID IID) { 2065 switch (IID) { 2066 default: 2067 return false; 2068 case Intrinsic::ptrmask: 2069 case Intrinsic::invariant_start: 2070 case Intrinsic::invariant_end: 2071 case Intrinsic::launder_invariant_group: 2072 case Intrinsic::strip_invariant_group: 2073 return true; 2074 } 2075 } 2076 2077 PtrParts SplitPtrStructs::visitIntrinsicInst(IntrinsicInst &I) { 2078 Intrinsic::ID IID = I.getIntrinsicID(); 2079 switch (IID) { 2080 default: 2081 break; 2082 case Intrinsic::ptrmask: { 2083 Value *Ptr = I.getArgOperand(0); 2084 if (!isSplitFatPtr(Ptr->getType())) 2085 return {nullptr, nullptr}; 2086 Value *Mask = I.getArgOperand(1); 2087 IRB.SetInsertPoint(&I); 2088 auto [Rsrc, Off] = getPtrParts(Ptr); 2089 if (Mask->getType() != Off->getType()) 2090 report_fatal_error("offset width is not equal to index width of fat " 2091 "pointer (data layout not set up correctly?)"); 2092 Value *OffRes = IRB.CreateAnd(Off, Mask, I.getName() + ".off"); 2093 copyMetadata(OffRes, &I); 2094 SplitUsers.insert(&I); 2095 return {Rsrc, OffRes}; 2096 } 2097 // Pointer annotation intrinsics that, given their object-wide nature 2098 // operate on the resource part. 2099 case Intrinsic::invariant_start: { 2100 Value *Ptr = I.getArgOperand(1); 2101 if (!isSplitFatPtr(Ptr->getType())) 2102 return {nullptr, nullptr}; 2103 IRB.SetInsertPoint(&I); 2104 auto [Rsrc, Off] = getPtrParts(Ptr); 2105 Type *NewTy = PointerType::get(I.getContext(), AMDGPUAS::BUFFER_RESOURCE); 2106 auto *NewRsrc = IRB.CreateIntrinsic(IID, {NewTy}, {I.getOperand(0), Rsrc}); 2107 copyMetadata(NewRsrc, &I); 2108 NewRsrc->takeName(&I); 2109 SplitUsers.insert(&I); 2110 I.replaceAllUsesWith(NewRsrc); 2111 return {nullptr, nullptr}; 2112 } 2113 case Intrinsic::invariant_end: { 2114 Value *RealPtr = I.getArgOperand(2); 2115 if (!isSplitFatPtr(RealPtr->getType())) 2116 return {nullptr, nullptr}; 2117 IRB.SetInsertPoint(&I); 2118 Value *RealRsrc = getPtrParts(RealPtr).first; 2119 Value *InvPtr = I.getArgOperand(0); 2120 Value *Size = I.getArgOperand(1); 2121 Value *NewRsrc = IRB.CreateIntrinsic(IID, {RealRsrc->getType()}, 2122 {InvPtr, Size, RealRsrc}); 2123 copyMetadata(NewRsrc, &I); 2124 NewRsrc->takeName(&I); 2125 SplitUsers.insert(&I); 2126 I.replaceAllUsesWith(NewRsrc); 2127 return {nullptr, nullptr}; 2128 } 2129 case Intrinsic::launder_invariant_group: 2130 case Intrinsic::strip_invariant_group: { 2131 Value *Ptr = I.getArgOperand(0); 2132 if (!isSplitFatPtr(Ptr->getType())) 2133 return {nullptr, nullptr}; 2134 IRB.SetInsertPoint(&I); 2135 auto [Rsrc, Off] = getPtrParts(Ptr); 2136 Value *NewRsrc = IRB.CreateIntrinsic(IID, {Rsrc->getType()}, {Rsrc}); 2137 copyMetadata(NewRsrc, &I); 2138 NewRsrc->takeName(&I); 2139 SplitUsers.insert(&I); 2140 return {NewRsrc, Off}; 2141 } 2142 } 2143 return {nullptr, nullptr}; 2144 } 2145 2146 void SplitPtrStructs::processFunction(Function &F) { 2147 ST = &TM->getSubtarget<GCNSubtarget>(F); 2148 SmallVector<Instruction *, 0> Originals; 2149 LLVM_DEBUG(dbgs() << "Splitting pointer structs in function: " << F.getName() 2150 << "\n"); 2151 for (Instruction &I : instructions(F)) 2152 Originals.push_back(&I); 2153 for (Instruction *I : Originals) { 2154 auto [Rsrc, Off] = visit(I); 2155 assert(((Rsrc && Off) || (!Rsrc && !Off)) && 2156 "Can't have a resource but no offset"); 2157 if (Rsrc) 2158 RsrcParts[I] = Rsrc; 2159 if (Off) 2160 OffParts[I] = Off; 2161 } 2162 processConditionals(); 2163 killAndReplaceSplitInstructions(Originals); 2164 2165 // Clean up after ourselves 
to save on memory. 2166 RsrcParts.clear(); 2167 OffParts.clear(); 2168 SplitUsers.clear(); 2169 Conditionals.clear(); 2170 ConditionalTemps.clear(); 2171 } 2172 2173 namespace { 2174 class AMDGPULowerBufferFatPointers : public ModulePass { 2175 public: 2176 static char ID; 2177 2178 AMDGPULowerBufferFatPointers() : ModulePass(ID) { 2179 initializeAMDGPULowerBufferFatPointersPass( 2180 *PassRegistry::getPassRegistry()); 2181 } 2182 2183 bool run(Module &M, const TargetMachine &TM); 2184 bool runOnModule(Module &M) override; 2185 2186 void getAnalysisUsage(AnalysisUsage &AU) const override; 2187 }; 2188 } // namespace 2189 2190 /// Returns true if there are values that have a buffer fat pointer in them, 2191 /// which means we'll need to perform rewrites on this function. As a side 2192 /// effect, this will populate the type remapping cache. 2193 static bool containsBufferFatPointers(const Function &F, 2194 BufferFatPtrToStructTypeMap *TypeMap) { 2195 bool HasFatPointers = false; 2196 for (const BasicBlock &BB : F) 2197 for (const Instruction &I : BB) 2198 HasFatPointers |= (I.getType() != TypeMap->remapType(I.getType())); 2199 return HasFatPointers; 2200 } 2201 2202 static bool hasFatPointerInterface(const Function &F, 2203 BufferFatPtrToStructTypeMap *TypeMap) { 2204 Type *Ty = F.getFunctionType(); 2205 return Ty != TypeMap->remapType(Ty); 2206 } 2207 2208 /// Move the body of `OldF` into a new function, returning it. 2209 static Function *moveFunctionAdaptingType(Function *OldF, FunctionType *NewTy, 2210 ValueToValueMapTy &CloneMap) { 2211 bool IsIntrinsic = OldF->isIntrinsic(); 2212 Function *NewF = 2213 Function::Create(NewTy, OldF->getLinkage(), OldF->getAddressSpace()); 2214 NewF->IsNewDbgInfoFormat = OldF->IsNewDbgInfoFormat; 2215 NewF->copyAttributesFrom(OldF); 2216 NewF->copyMetadata(OldF, 0); 2217 NewF->takeName(OldF); 2218 NewF->updateAfterNameChange(); 2219 NewF->setDLLStorageClass(OldF->getDLLStorageClass()); 2220 OldF->getParent()->getFunctionList().insertAfter(OldF->getIterator(), NewF); 2221 2222 while (!OldF->empty()) { 2223 BasicBlock *BB = &OldF->front(); 2224 BB->removeFromParent(); 2225 BB->insertInto(NewF); 2226 CloneMap[BB] = BB; 2227 for (Instruction &I : *BB) { 2228 CloneMap[&I] = &I; 2229 } 2230 } 2231 2232 SmallVector<AttributeSet> ArgAttrs; 2233 AttributeList OldAttrs = OldF->getAttributes(); 2234 2235 for (auto [I, OldArg, NewArg] : enumerate(OldF->args(), NewF->args())) { 2236 CloneMap[&NewArg] = &OldArg; 2237 NewArg.takeName(&OldArg); 2238 Type *OldArgTy = OldArg.getType(), *NewArgTy = NewArg.getType(); 2239 // Temporarily mutate type of `NewArg` to allow RAUW to work. 2240 NewArg.mutateType(OldArgTy); 2241 OldArg.replaceAllUsesWith(&NewArg); 2242 NewArg.mutateType(NewArgTy); 2243 2244 AttributeSet ArgAttr = OldAttrs.getParamAttrs(I); 2245 // Intrinsics get their attributes fixed later. 
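// ("Later" here is the intrinsic remangling step at the end of run(), which is
// expected to recreate the declarations with attributes appropriate to their
// new signatures.)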
2246 if (OldArgTy != NewArgTy && !IsIntrinsic) 2247 ArgAttr = ArgAttr.removeAttributes( 2248 NewF->getContext(), 2249 AttributeFuncs::typeIncompatible(NewArgTy, ArgAttr)); 2250 ArgAttrs.push_back(ArgAttr); 2251 } 2252 AttributeSet RetAttrs = OldAttrs.getRetAttrs(); 2253 if (OldF->getReturnType() != NewF->getReturnType() && !IsIntrinsic) 2254 RetAttrs = RetAttrs.removeAttributes( 2255 NewF->getContext(), 2256 AttributeFuncs::typeIncompatible(NewF->getReturnType(), RetAttrs)); 2257 NewF->setAttributes(AttributeList::get( 2258 NewF->getContext(), OldAttrs.getFnAttrs(), RetAttrs, ArgAttrs)); 2259 return NewF; 2260 } 2261 2262 static void makeCloneInPraceMap(Function *F, ValueToValueMapTy &CloneMap) { 2263 for (Argument &A : F->args()) 2264 CloneMap[&A] = &A; 2265 for (BasicBlock &BB : *F) { 2266 CloneMap[&BB] = &BB; 2267 for (Instruction &I : BB) 2268 CloneMap[&I] = &I; 2269 } 2270 } 2271 2272 bool AMDGPULowerBufferFatPointers::run(Module &M, const TargetMachine &TM) { 2273 bool Changed = false; 2274 const DataLayout &DL = M.getDataLayout(); 2275 // Record the functions which need to be remapped. 2276 // The second element of the pair indicates whether the function has to have 2277 // its arguments or return types adjusted. 2278 SmallVector<std::pair<Function *, bool>> NeedsRemap; 2279 2280 BufferFatPtrToStructTypeMap StructTM(DL); 2281 BufferFatPtrToIntTypeMap IntTM(DL); 2282 for (const GlobalVariable &GV : M.globals()) { 2283 if (GV.getAddressSpace() == AMDGPUAS::BUFFER_FAT_POINTER) 2284 report_fatal_error("Global variables with a buffer fat pointer address " 2285 "space (7) are not supported"); 2286 Type *VT = GV.getValueType(); 2287 if (VT != StructTM.remapType(VT)) 2288 report_fatal_error("Global variables that contain buffer fat pointers " 2289 "(address space 7 pointers) are unsupported. Use " 2290 "buffer resource pointers (address space 8) instead."); 2291 } 2292 2293 { 2294 // Collect all constant exprs and aggregates referenced by any function. 2295 SmallVector<Constant *, 8> Worklist; 2296 for (Function &F : M.functions()) 2297 for (Instruction &I : instructions(F)) 2298 for (Value *Op : I.operands()) 2299 if (isa<ConstantExpr>(Op) || isa<ConstantAggregate>(Op)) 2300 Worklist.push_back(cast<Constant>(Op)); 2301 2302 // Recursively look for any referenced buffer pointer constants. 2303 SmallPtrSet<Constant *, 8> Visited; 2304 SetVector<Constant *> BufferFatPtrConsts; 2305 while (!Worklist.empty()) { 2306 Constant *C = Worklist.pop_back_val(); 2307 if (!Visited.insert(C).second) 2308 continue; 2309 if (isBufferFatPtrOrVector(C->getType())) 2310 BufferFatPtrConsts.insert(C); 2311 for (Value *Op : C->operands()) 2312 if (isa<ConstantExpr>(Op) || isa<ConstantAggregate>(Op)) 2313 Worklist.push_back(cast<Constant>(Op)); 2314 } 2315 2316 // Expand all constant expressions using fat buffer pointers to 2317 // instructions. 
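// Doing this up front means the per-instruction rewrites later in the pass
// only ever see instructions, never constant expressions, which could not be
// split in place. (Illustrative example: a constant `getelementptr` over
// `addrspacecast (ptr addrspace(8) @buf to ptr addrspace(7))` used directly
// as an operand.)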
2318 Changed |= convertUsersOfConstantsToInstructions( 2319 BufferFatPtrConsts.getArrayRef(), /*RestrictToFunc=*/nullptr, 2320 /*RemoveDeadConstants=*/false, /*IncludeSelf=*/true); 2321 } 2322 2323 StoreFatPtrsAsIntsVisitor MemOpsRewrite(&IntTM, M.getContext()); 2324 LegalizeBufferContentTypesVisitor BufferContentsTypeRewrite(DL, 2325 M.getContext()); 2326 for (Function &F : M.functions()) { 2327 bool InterfaceChange = hasFatPointerInterface(F, &StructTM); 2328 bool BodyChanges = containsBufferFatPointers(F, &StructTM); 2329 Changed |= MemOpsRewrite.processFunction(F); 2330 if (InterfaceChange || BodyChanges) { 2331 NeedsRemap.push_back(std::make_pair(&F, InterfaceChange)); 2332 Changed |= BufferContentsTypeRewrite.processFunction(F); 2333 } 2334 } 2335 if (NeedsRemap.empty()) 2336 return Changed; 2337 2338 SmallVector<Function *> NeedsPostProcess; 2339 SmallVector<Function *> Intrinsics; 2340 // Keep one big map so as to memoize constants across functions. 2341 ValueToValueMapTy CloneMap; 2342 FatPtrConstMaterializer Materializer(&StructTM, CloneMap); 2343 2344 ValueMapper LowerInFuncs(CloneMap, RF_None, &StructTM, &Materializer); 2345 for (auto [F, InterfaceChange] : NeedsRemap) { 2346 Function *NewF = F; 2347 if (InterfaceChange) 2348 NewF = moveFunctionAdaptingType( 2349 F, cast<FunctionType>(StructTM.remapType(F->getFunctionType())), 2350 CloneMap); 2351 else 2352 makeCloneInPraceMap(F, CloneMap); 2353 LowerInFuncs.remapFunction(*NewF); 2354 if (NewF->isIntrinsic()) 2355 Intrinsics.push_back(NewF); 2356 else 2357 NeedsPostProcess.push_back(NewF); 2358 if (InterfaceChange) { 2359 F->replaceAllUsesWith(NewF); 2360 F->eraseFromParent(); 2361 } 2362 Changed = true; 2363 } 2364 StructTM.clear(); 2365 IntTM.clear(); 2366 CloneMap.clear(); 2367 2368 SplitPtrStructs Splitter(M.getContext(), &TM); 2369 for (Function *F : NeedsPostProcess) 2370 Splitter.processFunction(*F); 2371 for (Function *F : Intrinsics) { 2372 if (isRemovablePointerIntrinsic(F->getIntrinsicID())) { 2373 F->eraseFromParent(); 2374 } else { 2375 std::optional<Function *> NewF = Intrinsic::remangleIntrinsicFunction(F); 2376 if (NewF) 2377 F->replaceAllUsesWith(*NewF); 2378 } 2379 } 2380 return Changed; 2381 } 2382 2383 bool AMDGPULowerBufferFatPointers::runOnModule(Module &M) { 2384 TargetPassConfig &TPC = getAnalysis<TargetPassConfig>(); 2385 const TargetMachine &TM = TPC.getTM<TargetMachine>(); 2386 return run(M, TM); 2387 } 2388 2389 char AMDGPULowerBufferFatPointers::ID = 0; 2390 2391 char &llvm::AMDGPULowerBufferFatPointersID = AMDGPULowerBufferFatPointers::ID; 2392 2393 void AMDGPULowerBufferFatPointers::getAnalysisUsage(AnalysisUsage &AU) const { 2394 AU.addRequired<TargetPassConfig>(); 2395 } 2396 2397 #define PASS_DESC "Lower buffer fat pointer operations to buffer resources" 2398 INITIALIZE_PASS_BEGIN(AMDGPULowerBufferFatPointers, DEBUG_TYPE, PASS_DESC, 2399 false, false) 2400 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig) 2401 INITIALIZE_PASS_END(AMDGPULowerBufferFatPointers, DEBUG_TYPE, PASS_DESC, false, 2402 false) 2403 #undef PASS_DESC 2404 2405 ModulePass *llvm::createAMDGPULowerBufferFatPointersPass() { 2406 return new AMDGPULowerBufferFatPointers(); 2407 } 2408 2409 PreservedAnalyses 2410 AMDGPULowerBufferFatPointersPass::run(Module &M, ModuleAnalysisManager &MA) { 2411 return AMDGPULowerBufferFatPointers().run(M, TM) ? PreservedAnalyses::none() 2412 : PreservedAnalyses::all(); 2413 } 2414