//===-- AMDGPULowerBufferFatPointers.cpp ---------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass lowers operations on buffer fat pointers (addrspace 7) to
// operations on buffer resources (addrspace 8) and is needed for correct
// codegen.
//
// # Background
//
// Address space 7 (the buffer fat pointer) is a 160-bit pointer that consists
// of a 128-bit buffer descriptor and a 32-bit offset into that descriptor.
// The buffer resource part needs to be a "raw" buffer resource (it must have
// a stride of 0 and bounds checks must be in raw buffer mode or disabled).
//
// When these requirements are met, a buffer resource can be treated as a
// typical (though quite wide) pointer that follows typical LLVM pointer
// semantics. This allows the frontend to reason about such buffers (which are
// often encountered in the context of SPIR-V kernels).
//
// However, because of their non-power-of-2 size, these fat pointers cannot be
// present during translation to MIR (though this restriction may be lifted
// during the transition to GlobalISel). Therefore, this pass is needed in
// order to correctly implement these fat pointers.
//
// The resource intrinsics take the resource part (the address space 8 pointer)
// and the offset part (the 32-bit integer) as separate arguments. In addition,
// many users of these buffers manipulate the offset while leaving the resource
// part alone. For these reasons, we typically want to keep the resource and
// offset parts in separate variables, but combine them together when
// encountering cases where this is required, such as when inserting these
// values into aggregates or moving them to memory.
//
// Therefore, at a high level, `ptr addrspace(7) %x` becomes `ptr addrspace(8)
// %x.rsrc` and `i32 %x.off`, which will be combined into `{ptr addrspace(8),
// i32} %x = {%x.rsrc, %x.off}` if needed. Similarly, `vector<Nxp7>` becomes
// `{vector<Nxp8>, vector<Nxi32>}` and its component parts.
//
// # Implementation
//
// This pass proceeds in four main phases:
//
// ## Rewriting loads and stores of p7
//
// The first phase is to rewrite away all loads and stores of `ptr
// addrspace(7)`, including aggregates containing such pointers, to ones that
// use `i160`. This is handled by `StoreFatPtrsAsIntsVisitor`, which visits
// loads, stores, and allocas and, if the loaded or stored type contains `ptr
// addrspace(7)`, rewrites that type to one where the p7s are replaced by
// i160s, copying other parts of aggregates as needed. In the case of a store,
// each pointer is `ptrtoint`d to i160 before storing, and loaded integers are
// `inttoptr`d back. This same transformation is applied to vectors of
// pointers.
//
// Such a transformation allows the later phases of the pass to avoid having
// to handle buffer fat pointers moving to and from memory, where we would
// have to handle the incompatibility between a `{Nxp8, Nxi32}` representation
// and `Nxi160` directly. Instead, that transposing action (where the vectors
// of resources and vectors of offsets are concatenated before being stored to
// memory) is handled by implementing only `inttoptr` and `ptrtoint`.
//
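// As a small illustration (the value names %p and %dest are hypothetical, and
// the exact temporaries are up to the implementation), a store of a fat
// pointer is rewritten roughly as
// ```
// ; before
// store ptr addrspace(7) %p, ptr %dest
// ; after
// %p.int = ptrtoint ptr addrspace(7) %p to i160
// store i160 %p.int, ptr %dest
// ```
// and a load performs the corresponding `load i160` followed by an
// `inttoptr`.
//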
// Atomic operations on `ptr addrspace(7)` values are not supported, as the
// hardware does not include a 160-bit atomic.
//
// ## Buffer contents type legalization
//
// The underlying buffer intrinsics only support types up to 128 bits long,
// and don't support complex types. If buffer operations were
// standard pointer operations that could be represented as MIR-level loads,
// this would be handled by the various legalization schemes in instruction
// selection. However, because we have to do the conversion from `load` and
// `store` to intrinsics at the LLVM IR level, we must perform that
// legalization ourselves.
//
// This involves a combination of
// - Converting arrays to vectors where possible
// - Otherwise, splitting loads and stores of aggregates into loads/stores of
//   each component.
// - Zero-extending things to fill a whole number of bytes
// - Casting values of types that don't neatly correspond to supported machine
//   value types (for example, an i96 or i256) into ones that would work
//   (like <3 x i32> and <8 x i32>, respectively)
// - Splitting values that are too long (such as the aforementioned <8 x i32>)
//   into multiple operations.
//
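// As a rough sketch of the casting rule above (the names %v and %p are
// hypothetical, and the exact sequence is up to the implementation), a load
// of an unsupported scalar such as i96 becomes a load of a supported type
// that is then cast back:
// ```
// ; before
// %v = load i96, ptr addrspace(7) %p
// ; after
// %v.legal = load <3 x i32>, ptr addrspace(7) %p
// %v = bitcast <3 x i32> %v.legal to i96
// ```
//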
// ## Type remapping
//
// We use a `ValueMapper` to mangle uses of [vectors of] buffer fat pointers
// to the corresponding struct type, which has a resource part and an offset
// part.
//
// This uses a `BufferFatPtrToStructTypeMap` and a `FatPtrConstMaterializer`
// to do so, usually by way of `setType`ing values. Constants are handled here
// because there isn't a good way to fix them up later.
//
// This has the downside of leaving the IR in an invalid state (for example,
// the instruction `getelementptr {ptr addrspace(8), i32} %p, ...` will exist),
// but all such invalid states will be resolved by the final phase.
//
// Functions that don't take buffer fat pointers are modified in place. Those
// that do take such pointers have their basic blocks moved to a new function
// whose arguments and return values use {ptr addrspace(8), i32} in place of
// the fat pointers. This phase also records intrinsics so that they can be
// remangled or deleted later.
//
// ## Splitting pointer structs
//
// The meat of this pass consists of defining semantics for operations that
// produce or consume [vectors of] buffer fat pointers in terms of their
// resource and offset parts. This is accomplished through the
// `SplitPtrStructs` visitor.
//
// In the first pass through each function that is being lowered, the splitter
// inserts new instructions to implement the split-structures behavior, which
// is needed for correctness and performance. It records a list of "split
// users", instructions that are being replaced by operations on the resource
// and offset parts.
//
// Split users do not necessarily need to produce parts themselves (a
// `load float, ptr addrspace(7)` does not, for example), but, if they do not
// generate buffer fat pointers, they must RAUW in their replacement
// instructions during the initial visit.
//
// When these new instructions are created, they use the split parts recorded
// for their initial arguments in order to generate their replacements,
// creating a parallel set of instructions that does not refer to the original
// fat pointer values but instead to their resource and offset components.
//
// Instructions, such as `extractvalue`, that produce buffer fat pointers from
// sources that do not have split parts have such parts generated using
// `extractvalue`. This is also the initial handling of PHI nodes, which
// are then cleaned up.
//
// ### Conditionals
//
// PHI nodes are initially given resource parts via `extractvalue`. However,
// this is not an efficient rewrite of such nodes, as, in most cases, the
// resource part in a conditional or loop remains constant throughout the loop
// and only the offset varies. Failing to optimize away these constant
// resources would cause additional registers to be sent around loops and
// might lead to waterfall loops being generated for buffer operations due to
// the "non-uniform" resource argument.
//
// Therefore, after all instructions have been visited, the pointer splitter
// post-processes all encountered conditionals. Given a PHI node or select,
// getPossibleRsrcRoots() collects all values that the resource parts of that
// conditional's inputs could come from, as well as all conditional
// instructions encountered during the search. If, after filtering out the
// initial node itself, the set of encountered conditionals is a subset of the
// potential roots and there is a single potential resource that isn't in the
// conditional set, that value is the only possible value the resource
// argument could have throughout the control flow.
//
// If that condition is met, then a PHI node can have its resource part
// changed to the singleton value and then be replaced by a PHI on the
// offsets. Otherwise, each PHI node is split into two, one for the resource
// part and one for the offset part, which replace the temporary
// `extractvalue` instructions that were added during the first pass.
//
// Similar logic applies to `select`, where
// `%z = select i1 %cond, ptr addrspace(7) %x, ptr addrspace(7) %y`
// can be split into `%z.rsrc = %x.rsrc` and
// `%z.off = select i1 %cond, i32 %x.off, i32 %y.off`
// if both `%x` and `%y` have the same resource part, but two `select`
// operations will be needed if they do not.
//
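// As an illustrative sketch of the PHI case (all names hypothetical), a
// loop-carried fat pointer whose resource part is loop-invariant, such as
// ```
// %p = phi ptr addrspace(7) [ %base, %entry ], [ %p.next, %loop ]
// ```
// is ideally rewritten so that only the offset is carried by a PHI, with the
// resource part of `%base` reused directly:
// ```
// %p.off = phi i32 [ %base.off, %entry ], [ %p.next.off, %loop ]
// ```
//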
// ### Final processing
//
// After conditionals have been cleaned up, the IR for each function is
// rewritten to remove all the old instructions that have been split up.
//
// Any instruction that used to produce a buffer fat pointer (and therefore
// now produces a resource-and-offset struct after type remapping) is
// replaced as follows:
// 1. All debug value annotations are cloned to reflect that the resource part
//    and offset parts are computed separately and constitute different
//    fragments of the underlying source language variable.
// 2. All uses that were themselves split are replaced by a `poison` of the
//    struct type, as they will themselves be erased soon. This rule, combined
//    with debug handling, should leave the use lists of split instructions
//    empty in almost all cases.
// 3. If a user of the original struct-valued result remains, the structure
//    needed for the new types to work is constructed out of the newly-defined
//    parts, and the original instruction is replaced by this structure
//    before being erased. Instructions requiring this construction include
//    `ret` and `insertvalue`.
//
// # Consequences
//
// This pass does not alter the CFG.
//
// Alias analysis information will become coarser, as the LLVM alias analyzer
// cannot handle the buffer intrinsics. Specifically, while we can determine
// that the following two loads do not alias:
// ```
// %y = getelementptr i32, ptr addrspace(7) %x, i32 1
// %a = load i32, ptr addrspace(7) %x
// %b = load i32, ptr addrspace(7) %y
// ```
// we cannot (except through some code that runs during scheduling) determine
// that the rewritten loads below do not alias:
// ```
// %y.off = add i32 %x.off, 1
// %a = call @llvm.amdgcn.raw.ptr.buffer.load(ptr addrspace(8) %x.rsrc, i32
//   %x.off, ...)
// %b = call @llvm.amdgcn.raw.ptr.buffer.load(ptr addrspace(8)
//   %x.rsrc, i32 %y.off, ...)
// ```
// However, existing alias information is preserved.
//===----------------------------------------------------------------------===//

#include "AMDGPU.h"
#include "AMDGPUTargetMachine.h"
#include "GCNSubtarget.h"
#include "SIDefines.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/Utils/Local.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/AttributeMask.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DebugInfo.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/InstVisitor.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/ReplaceConstant.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Alignment.h"
#include "llvm/Support/AtomicOrdering.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/ValueMapper.h"

#define DEBUG_TYPE "amdgpu-lower-buffer-fat-pointers"

using namespace llvm;

static constexpr unsigned BufferOffsetWidth = 32;

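// For reference, an illustrative (not exhaustive) summary of how the type
// maps defined below rewrite buffer fat pointer types; <4 x ...> stands in
// for an arbitrary vector width:
// ```
// ptr addrspace(7)       --> i160                                ; BufferFatPtrToIntTypeMap
// <4 x ptr addrspace(7)> --> <4 x i160>                          ; BufferFatPtrToIntTypeMap
// ptr addrspace(7)       --> {ptr addrspace(8), i32}             ; BufferFatPtrToStructTypeMap
// <4 x ptr addrspace(7)> --> {<4 x ptr addrspace(8)>, <4 x i32>} ; BufferFatPtrToStructTypeMap
// ```
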
namespace {
/// Recursively replace instances of ptr addrspace(7) and vector<Nxptr
/// addrspace(7)> with some other type as defined by the relevant subclass.
class BufferFatPtrTypeLoweringBase : public ValueMapTypeRemapper {
  DenseMap<Type *, Type *> Map;

  Type *remapTypeImpl(Type *Ty, SmallPtrSetImpl<StructType *> &Seen);

protected:
  virtual Type *remapScalar(PointerType *PT) = 0;
  virtual Type *remapVector(VectorType *VT) = 0;

  const DataLayout &DL;

public:
  BufferFatPtrTypeLoweringBase(const DataLayout &DL) : DL(DL) {}
  Type *remapType(Type *SrcTy) override;
  void clear() { Map.clear(); }
};

/// Remap ptr addrspace(7) to i160 and vector<Nxptr addrspace(7)> to
/// vector<Nxi160> in order to correctly handle loading/storing these values
/// from memory.
class BufferFatPtrToIntTypeMap : public BufferFatPtrTypeLoweringBase {
  using BufferFatPtrTypeLoweringBase::BufferFatPtrTypeLoweringBase;

protected:
  Type *remapScalar(PointerType *PT) override { return DL.getIntPtrType(PT); }
  Type *remapVector(VectorType *VT) override { return DL.getIntPtrType(VT); }
};

/// Remap ptr addrspace(7) to {ptr addrspace(8), i32} (the resource and offset
/// parts of the pointer) so that we can easily rewrite operations on these
/// values that aren't loading them from or storing them to memory.
class BufferFatPtrToStructTypeMap : public BufferFatPtrTypeLoweringBase {
  using BufferFatPtrTypeLoweringBase::BufferFatPtrTypeLoweringBase;

protected:
  Type *remapScalar(PointerType *PT) override;
  Type *remapVector(VectorType *VT) override;
};
} // namespace

// This code is adapted from the type remapper in lib/Linker/IRMover.cpp
Type *BufferFatPtrTypeLoweringBase::remapTypeImpl(
    Type *Ty, SmallPtrSetImpl<StructType *> &Seen) {
  Type **Entry = &Map[Ty];
  if (*Entry)
    return *Entry;
  if (auto *PT = dyn_cast<PointerType>(Ty)) {
    if (PT->getAddressSpace() == AMDGPUAS::BUFFER_FAT_POINTER) {
      return *Entry = remapScalar(PT);
    }
  }
  if (auto *VT = dyn_cast<VectorType>(Ty)) {
    auto *PT = dyn_cast<PointerType>(VT->getElementType());
    if (PT && PT->getAddressSpace() == AMDGPUAS::BUFFER_FAT_POINTER) {
      return *Entry = remapVector(VT);
    }
    return *Entry = Ty;
  }
  // Whether the type is one that is structurally uniqued - that is, if it is
  // not a named struct (the only kind of type where multiple structurally
  // identical types can have distinct `Type*`s).
  StructType *TyAsStruct = dyn_cast<StructType>(Ty);
  bool IsUniqued = !TyAsStruct || TyAsStruct->isLiteral();
  // Base case for ints, floats, opaque pointers, and so on, which don't
  // require recursion.
  if (Ty->getNumContainedTypes() == 0 && IsUniqued)
    return *Entry = Ty;
  if (!IsUniqued) {
    // Create a dummy type for recursion purposes.
    if (!Seen.insert(TyAsStruct).second) {
      StructType *Placeholder = StructType::create(Ty->getContext());
      return *Entry = Placeholder;
    }
  }
  bool Changed = false;
  SmallVector<Type *> ElementTypes(Ty->getNumContainedTypes(), nullptr);
  for (unsigned int I = 0, E = Ty->getNumContainedTypes(); I < E; ++I) {
    Type *OldElem = Ty->getContainedType(I);
    Type *NewElem = remapTypeImpl(OldElem, Seen);
    ElementTypes[I] = NewElem;
    Changed |= (OldElem != NewElem);
  }
  // Recursive calls to remapTypeImpl() may have invalidated the `Entry`
  // pointer, so look it up again.
342 Entry = &Map[Ty]; 343 if (!Changed) { 344 return *Entry = Ty; 345 } 346 if (auto *ArrTy = dyn_cast<ArrayType>(Ty)) 347 return *Entry = ArrayType::get(ElementTypes[0], ArrTy->getNumElements()); 348 if (auto *FnTy = dyn_cast<FunctionType>(Ty)) 349 return *Entry = FunctionType::get(ElementTypes[0], 350 ArrayRef(ElementTypes).slice(1), 351 FnTy->isVarArg()); 352 if (auto *STy = dyn_cast<StructType>(Ty)) { 353 // Genuine opaque types don't have a remapping. 354 if (STy->isOpaque()) 355 return *Entry = Ty; 356 bool IsPacked = STy->isPacked(); 357 if (IsUniqued) 358 return *Entry = StructType::get(Ty->getContext(), ElementTypes, IsPacked); 359 SmallString<16> Name(STy->getName()); 360 STy->setName(""); 361 Type **RecursionEntry = &Map[Ty]; 362 if (*RecursionEntry) { 363 auto *Placeholder = cast<StructType>(*RecursionEntry); 364 Placeholder->setBody(ElementTypes, IsPacked); 365 Placeholder->setName(Name); 366 return *Entry = Placeholder; 367 } 368 return *Entry = StructType::create(Ty->getContext(), ElementTypes, Name, 369 IsPacked); 370 } 371 llvm_unreachable("Unknown type of type that contains elements"); 372 } 373 374 Type *BufferFatPtrTypeLoweringBase::remapType(Type *SrcTy) { 375 SmallPtrSet<StructType *, 2> Visited; 376 return remapTypeImpl(SrcTy, Visited); 377 } 378 379 Type *BufferFatPtrToStructTypeMap::remapScalar(PointerType *PT) { 380 LLVMContext &Ctx = PT->getContext(); 381 return StructType::get(PointerType::get(Ctx, AMDGPUAS::BUFFER_RESOURCE), 382 IntegerType::get(Ctx, BufferOffsetWidth)); 383 } 384 385 Type *BufferFatPtrToStructTypeMap::remapVector(VectorType *VT) { 386 ElementCount EC = VT->getElementCount(); 387 LLVMContext &Ctx = VT->getContext(); 388 Type *RsrcVec = 389 VectorType::get(PointerType::get(Ctx, AMDGPUAS::BUFFER_RESOURCE), EC); 390 Type *OffVec = VectorType::get(IntegerType::get(Ctx, BufferOffsetWidth), EC); 391 return StructType::get(RsrcVec, OffVec); 392 } 393 394 static bool isBufferFatPtrOrVector(Type *Ty) { 395 if (auto *PT = dyn_cast<PointerType>(Ty->getScalarType())) 396 return PT->getAddressSpace() == AMDGPUAS::BUFFER_FAT_POINTER; 397 return false; 398 } 399 400 // True if the type is {ptr addrspace(8), i32} or a struct containing vectors of 401 // those types. Used to quickly skip instructions we don't need to process. 402 static bool isSplitFatPtr(Type *Ty) { 403 auto *ST = dyn_cast<StructType>(Ty); 404 if (!ST) 405 return false; 406 if (!ST->isLiteral() || ST->getNumElements() != 2) 407 return false; 408 auto *MaybeRsrc = 409 dyn_cast<PointerType>(ST->getElementType(0)->getScalarType()); 410 auto *MaybeOff = 411 dyn_cast<IntegerType>(ST->getElementType(1)->getScalarType()); 412 return MaybeRsrc && MaybeOff && 413 MaybeRsrc->getAddressSpace() == AMDGPUAS::BUFFER_RESOURCE && 414 MaybeOff->getBitWidth() == BufferOffsetWidth; 415 } 416 417 // True if the result type or any argument types are buffer fat pointers. 418 static bool isBufferFatPtrConst(Constant *C) { 419 Type *T = C->getType(); 420 return isBufferFatPtrOrVector(T) || any_of(C->operands(), [](const Use &U) { 421 return isBufferFatPtrOrVector(U.get()->getType()); 422 }); 423 } 424 425 namespace { 426 /// Convert [vectors of] buffer fat pointers to integers when they are read from 427 /// or stored to memory. This ensures that these pointers will have the same 428 /// memory layout as before they are lowered, even though they will no longer 429 /// have their previous layout in registers/in the program (they'll be broken 430 /// down into resource and offset parts). 
This has the downside of imposing
/// marshalling costs when reading or storing these values, but since placing
/// such pointers into memory is an uncommon operation at best, we feel that
/// this cost is acceptable for better performance in the common case.
class StoreFatPtrsAsIntsVisitor
    : public InstVisitor<StoreFatPtrsAsIntsVisitor, bool> {
  BufferFatPtrToIntTypeMap *TypeMap;

  ValueToValueMapTy ConvertedForStore;

  IRBuilder<> IRB;

  // Convert all the buffer fat pointers within the input value to integers
  // so that it can be stored in memory.
  Value *fatPtrsToInts(Value *V, Type *From, Type *To, const Twine &Name);
  // Convert all the i160s that need to be buffer fat pointers (as specified
  // by the To type) into those pointers to preserve the semantics of the rest
  // of the program.
  Value *intsToFatPtrs(Value *V, Type *From, Type *To, const Twine &Name);

public:
  StoreFatPtrsAsIntsVisitor(BufferFatPtrToIntTypeMap *TypeMap, LLVMContext &Ctx)
      : TypeMap(TypeMap), IRB(Ctx) {}
  bool processFunction(Function &F);

  bool visitInstruction(Instruction &I) { return false; }
  bool visitAllocaInst(AllocaInst &I);
  bool visitLoadInst(LoadInst &LI);
  bool visitStoreInst(StoreInst &SI);
  bool visitGetElementPtrInst(GetElementPtrInst &I);
};
} // namespace

Value *StoreFatPtrsAsIntsVisitor::fatPtrsToInts(Value *V, Type *From, Type *To,
                                                const Twine &Name) {
  if (From == To)
    return V;
  ValueToValueMapTy::iterator Find = ConvertedForStore.find(V);
  if (Find != ConvertedForStore.end())
    return Find->second;
  if (isBufferFatPtrOrVector(From)) {
    Value *Cast = IRB.CreatePtrToInt(V, To, Name + ".int");
    ConvertedForStore[V] = Cast;
    return Cast;
  }
  if (From->getNumContainedTypes() == 0)
    return V;
  // Structs, arrays, and other compound types.
  Value *Ret = PoisonValue::get(To);
  if (auto *AT = dyn_cast<ArrayType>(From)) {
    Type *FromPart = AT->getArrayElementType();
    Type *ToPart = cast<ArrayType>(To)->getElementType();
    for (uint64_t I = 0, E = AT->getArrayNumElements(); I < E; ++I) {
      Value *Field = IRB.CreateExtractValue(V, I);
      Value *NewField =
          fatPtrsToInts(Field, FromPart, ToPart, Name + "." + Twine(I));
      Ret = IRB.CreateInsertValue(Ret, NewField, I);
    }
  } else {
    for (auto [Idx, FromPart, ToPart] :
         enumerate(From->subtypes(), To->subtypes())) {
      Value *Field = IRB.CreateExtractValue(V, Idx);
      Value *NewField =
          fatPtrsToInts(Field, FromPart, ToPart, Name + "." + Twine(Idx));
      Ret = IRB.CreateInsertValue(Ret, NewField, Idx);
    }
  }
  ConvertedForStore[V] = Ret;
  return Ret;
}

Value *StoreFatPtrsAsIntsVisitor::intsToFatPtrs(Value *V, Type *From, Type *To,
                                                const Twine &Name) {
  if (From == To)
    return V;
  if (isBufferFatPtrOrVector(To)) {
    Value *Cast = IRB.CreateIntToPtr(V, To, Name + ".ptr");
    return Cast;
  }
  if (From->getNumContainedTypes() == 0)
    return V;
  // Structs, arrays, and other compound types.
  Value *Ret = PoisonValue::get(To);
  if (auto *AT = dyn_cast<ArrayType>(From)) {
    Type *FromPart = AT->getArrayElementType();
    Type *ToPart = cast<ArrayType>(To)->getElementType();
    for (uint64_t I = 0, E = AT->getArrayNumElements(); I < E; ++I) {
      Value *Field = IRB.CreateExtractValue(V, I);
      Value *NewField =
          intsToFatPtrs(Field, FromPart, ToPart, Name + "."
+ Twine(I)); 520 Ret = IRB.CreateInsertValue(Ret, NewField, I); 521 } 522 } else { 523 for (auto [Idx, FromPart, ToPart] : 524 enumerate(From->subtypes(), To->subtypes())) { 525 Value *Field = IRB.CreateExtractValue(V, Idx); 526 Value *NewField = 527 intsToFatPtrs(Field, FromPart, ToPart, Name + "." + Twine(Idx)); 528 Ret = IRB.CreateInsertValue(Ret, NewField, Idx); 529 } 530 } 531 return Ret; 532 } 533 534 bool StoreFatPtrsAsIntsVisitor::processFunction(Function &F) { 535 bool Changed = false; 536 // The visitors will mutate GEPs and allocas, but will push loads and stores 537 // to the worklist to avoid invalidation. 538 for (Instruction &I : make_early_inc_range(instructions(F))) { 539 Changed |= visit(I); 540 } 541 ConvertedForStore.clear(); 542 return Changed; 543 } 544 545 bool StoreFatPtrsAsIntsVisitor::visitAllocaInst(AllocaInst &I) { 546 Type *Ty = I.getAllocatedType(); 547 Type *NewTy = TypeMap->remapType(Ty); 548 if (Ty == NewTy) 549 return false; 550 I.setAllocatedType(NewTy); 551 return true; 552 } 553 554 bool StoreFatPtrsAsIntsVisitor::visitGetElementPtrInst(GetElementPtrInst &I) { 555 Type *Ty = I.getSourceElementType(); 556 Type *NewTy = TypeMap->remapType(Ty); 557 if (Ty == NewTy) 558 return false; 559 // We'll be rewriting the type `ptr addrspace(7)` out of existence soon, so 560 // make sure GEPs don't have different semantics with the new type. 561 I.setSourceElementType(NewTy); 562 I.setResultElementType(TypeMap->remapType(I.getResultElementType())); 563 return true; 564 } 565 566 bool StoreFatPtrsAsIntsVisitor::visitLoadInst(LoadInst &LI) { 567 Type *Ty = LI.getType(); 568 Type *IntTy = TypeMap->remapType(Ty); 569 if (Ty == IntTy) 570 return false; 571 572 IRB.SetInsertPoint(&LI); 573 auto *NLI = cast<LoadInst>(LI.clone()); 574 NLI->mutateType(IntTy); 575 NLI = IRB.Insert(NLI); 576 NLI->takeName(&LI); 577 578 Value *CastBack = intsToFatPtrs(NLI, IntTy, Ty, NLI->getName()); 579 LI.replaceAllUsesWith(CastBack); 580 LI.eraseFromParent(); 581 return true; 582 } 583 584 bool StoreFatPtrsAsIntsVisitor::visitStoreInst(StoreInst &SI) { 585 Value *V = SI.getValueOperand(); 586 Type *Ty = V->getType(); 587 Type *IntTy = TypeMap->remapType(Ty); 588 if (Ty == IntTy) 589 return false; 590 591 IRB.SetInsertPoint(&SI); 592 Value *IntV = fatPtrsToInts(V, Ty, IntTy, V->getName()); 593 for (auto *Dbg : at::getAssignmentMarkers(&SI)) 594 Dbg->setValue(IntV); 595 596 SI.setOperand(0, IntV); 597 return true; 598 } 599 600 namespace { 601 /// Convert loads/stores of types that the buffer intrinsics can't handle into 602 /// one ore more such loads/stores that consist of legal types. 603 /// 604 /// Do this by 605 /// 1. Recursing into structs (and arrays that don't share a memory layout with 606 /// vectors) since the intrinsics can't handle complex types. 607 /// 2. Converting arrays of non-aggregate, byte-sized types into their 608 /// corresponding vectors 609 /// 3. Bitcasting unsupported types, namely overly-long scalars and byte 610 /// vectors, into vectors of supported types. 611 /// 4. Splitting up excessively long reads/writes into multiple operations. 612 /// 613 /// Note that this doesn't handle complex data strucures, but, in the future, 614 /// the aggregate load splitter from SROA could be refactored to allow for that 615 /// case. 
616 class LegalizeBufferContentTypesVisitor 617 : public InstVisitor<LegalizeBufferContentTypesVisitor, bool> { 618 friend class InstVisitor<LegalizeBufferContentTypesVisitor, bool>; 619 620 IRBuilder<> IRB; 621 622 const DataLayout &DL; 623 624 /// If T is [N x U], where U is a scalar type, return the vector type 625 /// <N x U>, otherwise, return T. 626 Type *scalarArrayTypeAsVector(Type *MaybeArrayType); 627 Value *arrayToVector(Value *V, Type *TargetType, const Twine &Name); 628 Value *vectorToArray(Value *V, Type *OrigType, const Twine &Name); 629 630 /// Break up the loads of a struct into the loads of its components 631 632 /// Convert a vector or scalar type that can't be operated on by buffer 633 /// intrinsics to one that would be legal through bitcasts and/or truncation. 634 /// Uses the wider of i32, i16, or i8 where possible. 635 Type *legalNonAggregateFor(Type *T); 636 Value *makeLegalNonAggregate(Value *V, Type *TargetType, const Twine &Name); 637 Value *makeIllegalNonAggregate(Value *V, Type *OrigType, const Twine &Name); 638 639 struct VecSlice { 640 uint64_t Index = 0; 641 uint64_t Length = 0; 642 VecSlice() = delete; 643 // Needed for some Clangs 644 VecSlice(uint64_t Index, uint64_t Length) : Index(Index), Length(Length) {} 645 }; 646 /// Return the [index, length] pairs into which `T` needs to be cut to form 647 /// legal buffer load or store operations. Clears `Slices`. Creates an empty 648 /// `Slices` for non-vector inputs and creates one slice if no slicing will be 649 /// needed. 650 void getVecSlices(Type *T, SmallVectorImpl<VecSlice> &Slices); 651 652 Value *extractSlice(Value *Vec, VecSlice S, const Twine &Name); 653 Value *insertSlice(Value *Whole, Value *Part, VecSlice S, const Twine &Name); 654 655 /// In most cases, return `LegalType`. However, when given an input that would 656 /// normally be a legal type for the buffer intrinsics to return but that 657 /// isn't hooked up through SelectionDAG, return a type of the same width that 658 /// can be used with the relevant intrinsics. 
Specifically, handle the cases: 659 /// - <1 x T> => T for all T 660 /// - <N x i8> <=> i16, i32, 2xi32, 4xi32 (as needed) 661 /// - <N x T> where T is under 32 bits and the total size is 96 bits <=> <3 x 662 /// i32> 663 Type *intrinsicTypeFor(Type *LegalType); 664 665 bool visitLoadImpl(LoadInst &OrigLI, Type *PartType, 666 SmallVectorImpl<uint32_t> &AggIdxs, uint64_t AggByteOffset, 667 Value *&Result, const Twine &Name); 668 /// Return value is (Changed, ModifiedInPlace) 669 std::pair<bool, bool> visitStoreImpl(StoreInst &OrigSI, Type *PartType, 670 SmallVectorImpl<uint32_t> &AggIdxs, 671 uint64_t AggByteOffset, 672 const Twine &Name); 673 674 bool visitInstruction(Instruction &I) { return false; } 675 bool visitLoadInst(LoadInst &LI); 676 bool visitStoreInst(StoreInst &SI); 677 678 public: 679 LegalizeBufferContentTypesVisitor(const DataLayout &DL, LLVMContext &Ctx) 680 : IRB(Ctx), DL(DL) {} 681 bool processFunction(Function &F); 682 }; 683 } // namespace 684 685 Type *LegalizeBufferContentTypesVisitor::scalarArrayTypeAsVector(Type *T) { 686 ArrayType *AT = dyn_cast<ArrayType>(T); 687 if (!AT) 688 return T; 689 Type *ET = AT->getElementType(); 690 if (!ET->isSingleValueType() || isa<VectorType>(ET)) 691 report_fatal_error("loading non-scalar arrays from buffer fat pointers " 692 "should have recursed"); 693 if (!DL.typeSizeEqualsStoreSize(AT)) 694 report_fatal_error( 695 "loading padded arrays from buffer fat pinters should have recursed"); 696 return FixedVectorType::get(ET, AT->getNumElements()); 697 } 698 699 Value *LegalizeBufferContentTypesVisitor::arrayToVector(Value *V, 700 Type *TargetType, 701 const Twine &Name) { 702 Value *VectorRes = PoisonValue::get(TargetType); 703 auto *VT = cast<FixedVectorType>(TargetType); 704 unsigned EC = VT->getNumElements(); 705 for (auto I : iota_range<unsigned>(0, EC, /*Inclusive=*/false)) { 706 Value *Elem = IRB.CreateExtractValue(V, I, Name + ".elem." + Twine(I)); 707 VectorRes = IRB.CreateInsertElement(VectorRes, Elem, I, 708 Name + ".as.vec." + Twine(I)); 709 } 710 return VectorRes; 711 } 712 713 Value *LegalizeBufferContentTypesVisitor::vectorToArray(Value *V, 714 Type *OrigType, 715 const Twine &Name) { 716 Value *ArrayRes = PoisonValue::get(OrigType); 717 ArrayType *AT = cast<ArrayType>(OrigType); 718 unsigned EC = AT->getNumElements(); 719 for (auto I : iota_range<unsigned>(0, EC, /*Inclusive=*/false)) { 720 Value *Elem = IRB.CreateExtractElement(V, I, Name + ".elem." + Twine(I)); 721 ArrayRes = IRB.CreateInsertValue(ArrayRes, Elem, I, 722 Name + ".as.array." + Twine(I)); 723 } 724 return ArrayRes; 725 } 726 727 Type *LegalizeBufferContentTypesVisitor::legalNonAggregateFor(Type *T) { 728 TypeSize Size = DL.getTypeStoreSizeInBits(T); 729 // Implicitly zero-extend to the next byte if needed 730 if (!DL.typeSizeEqualsStoreSize(T)) 731 T = IRB.getIntNTy(Size.getFixedValue()); 732 Type *ElemTy = T->getScalarType(); 733 if (isa<PointerType, ScalableVectorType>(ElemTy)) { 734 // Pointers are always big enough, and we'll let scalable vectors through to 735 // fail in codegen. 736 return T; 737 } 738 unsigned ElemSize = DL.getTypeSizeInBits(ElemTy).getFixedValue(); 739 if (isPowerOf2_32(ElemSize) && ElemSize >= 16 && ElemSize <= 128) { 740 // [vectors of] anything that's 16/32/64/128 bits can be cast and split into 741 // legal buffer operations. 
742 return T; 743 } 744 Type *BestVectorElemType = nullptr; 745 if (Size.isKnownMultipleOf(32)) 746 BestVectorElemType = IRB.getInt32Ty(); 747 else if (Size.isKnownMultipleOf(16)) 748 BestVectorElemType = IRB.getInt16Ty(); 749 else 750 BestVectorElemType = IRB.getInt8Ty(); 751 unsigned NumCastElems = 752 Size.getFixedValue() / BestVectorElemType->getIntegerBitWidth(); 753 if (NumCastElems == 1) 754 return BestVectorElemType; 755 return FixedVectorType::get(BestVectorElemType, NumCastElems); 756 } 757 758 Value *LegalizeBufferContentTypesVisitor::makeLegalNonAggregate( 759 Value *V, Type *TargetType, const Twine &Name) { 760 Type *SourceType = V->getType(); 761 TypeSize SourceSize = DL.getTypeSizeInBits(SourceType); 762 TypeSize TargetSize = DL.getTypeSizeInBits(TargetType); 763 if (SourceSize != TargetSize) { 764 Type *ShortScalarTy = IRB.getIntNTy(SourceSize.getFixedValue()); 765 Type *ByteScalarTy = IRB.getIntNTy(TargetSize.getFixedValue()); 766 Value *AsScalar = IRB.CreateBitCast(V, ShortScalarTy, Name + ".as.scalar"); 767 Value *Zext = IRB.CreateZExt(AsScalar, ByteScalarTy, Name + ".zext"); 768 V = Zext; 769 SourceType = ByteScalarTy; 770 } 771 return IRB.CreateBitCast(V, TargetType, Name + ".legal"); 772 } 773 774 Value *LegalizeBufferContentTypesVisitor::makeIllegalNonAggregate( 775 Value *V, Type *OrigType, const Twine &Name) { 776 Type *LegalType = V->getType(); 777 TypeSize LegalSize = DL.getTypeSizeInBits(LegalType); 778 TypeSize OrigSize = DL.getTypeSizeInBits(OrigType); 779 if (LegalSize != OrigSize) { 780 Type *ShortScalarTy = IRB.getIntNTy(OrigSize.getFixedValue()); 781 Type *ByteScalarTy = IRB.getIntNTy(LegalSize.getFixedValue()); 782 Value *AsScalar = IRB.CreateBitCast(V, ByteScalarTy, Name + ".bytes.cast"); 783 Value *Trunc = IRB.CreateTrunc(AsScalar, ShortScalarTy, Name + ".trunc"); 784 return IRB.CreateBitCast(Trunc, OrigType, Name + ".orig"); 785 } 786 return IRB.CreateBitCast(V, OrigType, Name + ".real.ty"); 787 } 788 789 Type *LegalizeBufferContentTypesVisitor::intrinsicTypeFor(Type *LegalType) { 790 auto *VT = dyn_cast<FixedVectorType>(LegalType); 791 if (!VT) 792 return LegalType; 793 Type *ET = VT->getElementType(); 794 // Explicitly return the element type of 1-element vectors because the 795 // underlying intrinsics don't like <1 x T> even though it's a synonym for T. 
796 if (VT->getNumElements() == 1) 797 return ET; 798 if (DL.getTypeSizeInBits(LegalType) == 96 && DL.getTypeSizeInBits(ET) < 32) 799 return FixedVectorType::get(IRB.getInt32Ty(), 3); 800 if (ET->isIntegerTy(8)) { 801 switch (VT->getNumElements()) { 802 default: 803 return LegalType; // Let it crash later 804 case 1: 805 return IRB.getInt8Ty(); 806 case 2: 807 return IRB.getInt16Ty(); 808 case 4: 809 return IRB.getInt32Ty(); 810 case 8: 811 return FixedVectorType::get(IRB.getInt32Ty(), 2); 812 case 16: 813 return FixedVectorType::get(IRB.getInt32Ty(), 4); 814 } 815 } 816 return LegalType; 817 } 818 819 void LegalizeBufferContentTypesVisitor::getVecSlices( 820 Type *T, SmallVectorImpl<VecSlice> &Slices) { 821 Slices.clear(); 822 auto *VT = dyn_cast<FixedVectorType>(T); 823 if (!VT) 824 return; 825 826 uint64_t ElemBitWidth = 827 DL.getTypeSizeInBits(VT->getElementType()).getFixedValue(); 828 829 uint64_t ElemsPer4Words = 128 / ElemBitWidth; 830 uint64_t ElemsPer2Words = ElemsPer4Words / 2; 831 uint64_t ElemsPerWord = ElemsPer2Words / 2; 832 uint64_t ElemsPerShort = ElemsPerWord / 2; 833 uint64_t ElemsPerByte = ElemsPerShort / 2; 834 // If the elements evenly pack into 32-bit words, we can use 3-word stores, 835 // such as for <6 x bfloat> or <3 x i32>, but we can't dot his for, for 836 // example, <3 x i64>, since that's not slicing. 837 uint64_t ElemsPer3Words = ElemsPerWord * 3; 838 839 uint64_t TotalElems = VT->getNumElements(); 840 uint64_t Index = 0; 841 auto TrySlice = [&](unsigned MaybeLen) { 842 if (MaybeLen > 0 && Index + MaybeLen <= TotalElems) { 843 VecSlice Slice{/*Index=*/Index, /*Length=*/MaybeLen}; 844 Slices.push_back(Slice); 845 Index += MaybeLen; 846 return true; 847 } 848 return false; 849 }; 850 while (Index < TotalElems) { 851 TrySlice(ElemsPer4Words) || TrySlice(ElemsPer3Words) || 852 TrySlice(ElemsPer2Words) || TrySlice(ElemsPerWord) || 853 TrySlice(ElemsPerShort) || TrySlice(ElemsPerByte); 854 } 855 } 856 857 Value *LegalizeBufferContentTypesVisitor::extractSlice(Value *Vec, VecSlice S, 858 const Twine &Name) { 859 auto *VecVT = dyn_cast<FixedVectorType>(Vec->getType()); 860 if (!VecVT) 861 return Vec; 862 if (S.Length == VecVT->getNumElements() && S.Index == 0) 863 return Vec; 864 if (S.Length == 1) 865 return IRB.CreateExtractElement(Vec, S.Index, 866 Name + ".slice." + Twine(S.Index)); 867 SmallVector<int> Mask = llvm::to_vector( 868 llvm::iota_range<int>(S.Index, S.Index + S.Length, /*Inclusive=*/false)); 869 return IRB.CreateShuffleVector(Vec, Mask, Name + ".slice." + Twine(S.Index)); 870 } 871 872 Value *LegalizeBufferContentTypesVisitor::insertSlice(Value *Whole, Value *Part, 873 VecSlice S, 874 const Twine &Name) { 875 auto *WholeVT = dyn_cast<FixedVectorType>(Whole->getType()); 876 if (!WholeVT) 877 return Part; 878 if (S.Length == WholeVT->getNumElements() && S.Index == 0) 879 return Part; 880 if (S.Length == 1) { 881 return IRB.CreateInsertElement(Whole, Part, S.Index, 882 Name + ".slice." + Twine(S.Index)); 883 } 884 int NumElems = cast<FixedVectorType>(Whole->getType())->getNumElements(); 885 886 // Extend the slice with poisons to make the main shufflevector happy. 887 SmallVector<int> ExtPartMask(NumElems, -1); 888 for (auto [I, E] : llvm::enumerate( 889 MutableArrayRef<int>(ExtPartMask).take_front(S.Length))) { 890 E = I; 891 } 892 Value *ExtPart = IRB.CreateShuffleVector(Part, ExtPartMask, 893 Name + ".ext." 
+ Twine(S.Index)); 894 895 SmallVector<int> Mask = 896 llvm::to_vector(llvm::iota_range<int>(0, NumElems, /*Inclusive=*/false)); 897 for (auto [I, E] : 898 llvm::enumerate(MutableArrayRef<int>(Mask).slice(S.Index, S.Length))) 899 E = I + NumElems; 900 return IRB.CreateShuffleVector(Whole, ExtPart, Mask, 901 Name + ".parts." + Twine(S.Index)); 902 } 903 904 bool LegalizeBufferContentTypesVisitor::visitLoadImpl( 905 LoadInst &OrigLI, Type *PartType, SmallVectorImpl<uint32_t> &AggIdxs, 906 uint64_t AggByteOff, Value *&Result, const Twine &Name) { 907 if (auto *ST = dyn_cast<StructType>(PartType)) { 908 const StructLayout *Layout = DL.getStructLayout(ST); 909 bool Changed = false; 910 for (auto [I, ElemTy, Offset] : 911 llvm::enumerate(ST->elements(), Layout->getMemberOffsets())) { 912 AggIdxs.push_back(I); 913 Changed |= visitLoadImpl(OrigLI, ElemTy, AggIdxs, 914 AggByteOff + Offset.getFixedValue(), Result, 915 Name + "." + Twine(I)); 916 AggIdxs.pop_back(); 917 } 918 return Changed; 919 } 920 if (auto *AT = dyn_cast<ArrayType>(PartType)) { 921 Type *ElemTy = AT->getElementType(); 922 if (!ElemTy->isSingleValueType() || !DL.typeSizeEqualsStoreSize(ElemTy) || 923 ElemTy->isVectorTy()) { 924 TypeSize ElemStoreSize = DL.getTypeStoreSize(ElemTy); 925 bool Changed = false; 926 for (auto I : llvm::iota_range<uint32_t>(0, AT->getNumElements(), 927 /*Inclusive=*/false)) { 928 AggIdxs.push_back(I); 929 Changed |= visitLoadImpl(OrigLI, ElemTy, AggIdxs, 930 AggByteOff + I * ElemStoreSize.getFixedValue(), 931 Result, Name + Twine(I)); 932 AggIdxs.pop_back(); 933 } 934 return Changed; 935 } 936 } 937 938 // Typical case 939 940 Type *ArrayAsVecType = scalarArrayTypeAsVector(PartType); 941 Type *LegalType = legalNonAggregateFor(ArrayAsVecType); 942 943 SmallVector<VecSlice> Slices; 944 getVecSlices(LegalType, Slices); 945 bool HasSlices = Slices.size() > 1; 946 bool IsAggPart = !AggIdxs.empty(); 947 Value *LoadsRes; 948 if (!HasSlices && !IsAggPart) { 949 Type *LoadableType = intrinsicTypeFor(LegalType); 950 if (LoadableType == PartType) 951 return false; 952 953 IRB.SetInsertPoint(&OrigLI); 954 auto *NLI = cast<LoadInst>(OrigLI.clone()); 955 NLI->mutateType(LoadableType); 956 NLI = IRB.Insert(NLI); 957 NLI->setName(Name + ".loadable"); 958 959 LoadsRes = IRB.CreateBitCast(NLI, LegalType, Name + ".from.loadable"); 960 } else { 961 IRB.SetInsertPoint(&OrigLI); 962 LoadsRes = PoisonValue::get(LegalType); 963 Value *OrigPtr = OrigLI.getPointerOperand(); 964 // If we're needing to spill something into more than one load, its legal 965 // type will be a vector (ex. an i256 load will have LegalType = <8 x i32>). 966 // But if we're already a scalar (which can happen if we're splitting up a 967 // struct), the element type will be the legal type itself. 968 Type *ElemType = LegalType->getScalarType(); 969 unsigned ElemBytes = DL.getTypeStoreSize(ElemType); 970 AAMDNodes AANodes = OrigLI.getAAMetadata(); 971 if (IsAggPart && Slices.empty()) 972 Slices.push_back(VecSlice{/*Index=*/0, /*Length=*/1}); 973 for (VecSlice S : Slices) { 974 Type *SliceType = 975 S.Length != 1 ? FixedVectorType::get(ElemType, S.Length) : ElemType; 976 int64_t ByteOffset = AggByteOff + S.Index * ElemBytes; 977 // You can't reasonably expect loads to wrap around the edge of memory. 978 Value *NewPtr = IRB.CreateGEP( 979 IRB.getInt8Ty(), OrigLI.getPointerOperand(), IRB.getInt32(ByteOffset), 980 OrigPtr->getName() + ".off.ptr." 
+ Twine(ByteOffset), 981 GEPNoWrapFlags::noUnsignedWrap()); 982 Type *LoadableType = intrinsicTypeFor(SliceType); 983 LoadInst *NewLI = IRB.CreateAlignedLoad( 984 LoadableType, NewPtr, commonAlignment(OrigLI.getAlign(), ByteOffset), 985 Name + ".off." + Twine(ByteOffset)); 986 copyMetadataForLoad(*NewLI, OrigLI); 987 NewLI->setAAMetadata( 988 AANodes.adjustForAccess(ByteOffset, LoadableType, DL)); 989 NewLI->setAtomic(OrigLI.getOrdering(), OrigLI.getSyncScopeID()); 990 NewLI->setVolatile(OrigLI.isVolatile()); 991 Value *Loaded = IRB.CreateBitCast(NewLI, SliceType, 992 NewLI->getName() + ".from.loadable"); 993 LoadsRes = insertSlice(LoadsRes, Loaded, S, Name); 994 } 995 } 996 if (LegalType != ArrayAsVecType) 997 LoadsRes = makeIllegalNonAggregate(LoadsRes, ArrayAsVecType, Name); 998 if (ArrayAsVecType != PartType) 999 LoadsRes = vectorToArray(LoadsRes, PartType, Name); 1000 1001 if (IsAggPart) 1002 Result = IRB.CreateInsertValue(Result, LoadsRes, AggIdxs, Name); 1003 else 1004 Result = LoadsRes; 1005 return true; 1006 } 1007 1008 bool LegalizeBufferContentTypesVisitor::visitLoadInst(LoadInst &LI) { 1009 if (LI.getPointerAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER) 1010 return false; 1011 1012 SmallVector<uint32_t> AggIdxs; 1013 Type *OrigType = LI.getType(); 1014 Value *Result = PoisonValue::get(OrigType); 1015 bool Changed = visitLoadImpl(LI, OrigType, AggIdxs, 0, Result, LI.getName()); 1016 if (!Changed) 1017 return false; 1018 Result->takeName(&LI); 1019 LI.replaceAllUsesWith(Result); 1020 LI.eraseFromParent(); 1021 return Changed; 1022 } 1023 1024 std::pair<bool, bool> LegalizeBufferContentTypesVisitor::visitStoreImpl( 1025 StoreInst &OrigSI, Type *PartType, SmallVectorImpl<uint32_t> &AggIdxs, 1026 uint64_t AggByteOff, const Twine &Name) { 1027 if (auto *ST = dyn_cast<StructType>(PartType)) { 1028 const StructLayout *Layout = DL.getStructLayout(ST); 1029 bool Changed = false; 1030 for (auto [I, ElemTy, Offset] : 1031 llvm::enumerate(ST->elements(), Layout->getMemberOffsets())) { 1032 AggIdxs.push_back(I); 1033 Changed |= std::get<0>(visitStoreImpl(OrigSI, ElemTy, AggIdxs, 1034 AggByteOff + Offset.getFixedValue(), 1035 Name + "." 
+ Twine(I))); 1036 AggIdxs.pop_back(); 1037 } 1038 return std::make_pair(Changed, /*ModifiedInPlace=*/false); 1039 } 1040 if (auto *AT = dyn_cast<ArrayType>(PartType)) { 1041 Type *ElemTy = AT->getElementType(); 1042 if (!ElemTy->isSingleValueType() || !DL.typeSizeEqualsStoreSize(ElemTy) || 1043 ElemTy->isVectorTy()) { 1044 TypeSize ElemStoreSize = DL.getTypeStoreSize(ElemTy); 1045 bool Changed = false; 1046 for (auto I : llvm::iota_range<uint32_t>(0, AT->getNumElements(), 1047 /*Inclusive=*/false)) { 1048 AggIdxs.push_back(I); 1049 Changed |= std::get<0>(visitStoreImpl( 1050 OrigSI, ElemTy, AggIdxs, 1051 AggByteOff + I * ElemStoreSize.getFixedValue(), Name + Twine(I))); 1052 AggIdxs.pop_back(); 1053 } 1054 return std::make_pair(Changed, /*ModifiedInPlace=*/false); 1055 } 1056 } 1057 1058 Value *OrigData = OrigSI.getValueOperand(); 1059 Value *NewData = OrigData; 1060 1061 bool IsAggPart = !AggIdxs.empty(); 1062 if (IsAggPart) 1063 NewData = IRB.CreateExtractValue(NewData, AggIdxs, Name); 1064 1065 Type *ArrayAsVecType = scalarArrayTypeAsVector(PartType); 1066 if (ArrayAsVecType != PartType) { 1067 NewData = arrayToVector(NewData, ArrayAsVecType, Name); 1068 } 1069 1070 Type *LegalType = legalNonAggregateFor(ArrayAsVecType); 1071 if (LegalType != ArrayAsVecType) { 1072 NewData = makeLegalNonAggregate(NewData, LegalType, Name); 1073 } 1074 1075 SmallVector<VecSlice> Slices; 1076 getVecSlices(LegalType, Slices); 1077 bool NeedToSplit = Slices.size() > 1 || IsAggPart; 1078 if (!NeedToSplit) { 1079 Type *StorableType = intrinsicTypeFor(LegalType); 1080 if (StorableType == PartType) 1081 return std::make_pair(/*Changed=*/false, /*ModifiedInPlace=*/false); 1082 NewData = IRB.CreateBitCast(NewData, StorableType, Name + ".storable"); 1083 OrigSI.setOperand(0, NewData); 1084 return std::make_pair(/*Changed=*/true, /*ModifiedInPlace=*/true); 1085 } 1086 1087 Value *OrigPtr = OrigSI.getPointerOperand(); 1088 Type *ElemType = LegalType->getScalarType(); 1089 if (IsAggPart && Slices.empty()) 1090 Slices.push_back(VecSlice{/*Index=*/0, /*Length=*/1}); 1091 unsigned ElemBytes = DL.getTypeStoreSize(ElemType); 1092 AAMDNodes AANodes = OrigSI.getAAMetadata(); 1093 for (VecSlice S : Slices) { 1094 Type *SliceType = 1095 S.Length != 1 ? FixedVectorType::get(ElemType, S.Length) : ElemType; 1096 int64_t ByteOffset = AggByteOff + S.Index * ElemBytes; 1097 Value *NewPtr = 1098 IRB.CreateGEP(IRB.getInt8Ty(), OrigPtr, IRB.getInt32(ByteOffset), 1099 OrigPtr->getName() + ".part." 
+ Twine(S.Index), 1100 GEPNoWrapFlags::noUnsignedWrap()); 1101 Value *DataSlice = extractSlice(NewData, S, Name); 1102 Type *StorableType = intrinsicTypeFor(SliceType); 1103 DataSlice = IRB.CreateBitCast(DataSlice, StorableType, 1104 DataSlice->getName() + ".storable"); 1105 auto *NewSI = cast<StoreInst>(OrigSI.clone()); 1106 NewSI->setAlignment(commonAlignment(OrigSI.getAlign(), ByteOffset)); 1107 IRB.Insert(NewSI); 1108 NewSI->setOperand(0, DataSlice); 1109 NewSI->setOperand(1, NewPtr); 1110 NewSI->setAAMetadata(AANodes.adjustForAccess(ByteOffset, StorableType, DL)); 1111 } 1112 return std::make_pair(/*Changed=*/true, /*ModifiedInPlace=*/false); 1113 } 1114 1115 bool LegalizeBufferContentTypesVisitor::visitStoreInst(StoreInst &SI) { 1116 if (SI.getPointerAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER) 1117 return false; 1118 IRB.SetInsertPoint(&SI); 1119 SmallVector<uint32_t> AggIdxs; 1120 Value *OrigData = SI.getValueOperand(); 1121 auto [Changed, ModifiedInPlace] = 1122 visitStoreImpl(SI, OrigData->getType(), AggIdxs, 0, OrigData->getName()); 1123 if (Changed && !ModifiedInPlace) 1124 SI.eraseFromParent(); 1125 return Changed; 1126 } 1127 1128 bool LegalizeBufferContentTypesVisitor::processFunction(Function &F) { 1129 bool Changed = false; 1130 for (Instruction &I : make_early_inc_range(instructions(F))) { 1131 Changed |= visit(I); 1132 } 1133 return Changed; 1134 } 1135 1136 /// Return the ptr addrspace(8) and i32 (resource and offset parts) in a lowered 1137 /// buffer fat pointer constant. 1138 static std::pair<Constant *, Constant *> 1139 splitLoweredFatBufferConst(Constant *C) { 1140 assert(isSplitFatPtr(C->getType()) && "Not a split fat buffer pointer"); 1141 return std::make_pair(C->getAggregateElement(0u), C->getAggregateElement(1u)); 1142 } 1143 1144 namespace { 1145 /// Handle the remapping of ptr addrspace(7) constants. 1146 class FatPtrConstMaterializer final : public ValueMaterializer { 1147 BufferFatPtrToStructTypeMap *TypeMap; 1148 // An internal mapper that is used to recurse into the arguments of constants. 1149 // While the documentation for `ValueMapper` specifies not to use it 1150 // recursively, examination of the logic in mapValue() shows that it can 1151 // safely be used recursively when handling constants, like it does in its own 1152 // logic. 1153 ValueMapper InternalMapper; 1154 1155 Constant *materializeBufferFatPtrConst(Constant *C); 1156 1157 public: 1158 // UnderlyingMap is the value map this materializer will be filling. 
1159 FatPtrConstMaterializer(BufferFatPtrToStructTypeMap *TypeMap, 1160 ValueToValueMapTy &UnderlyingMap) 1161 : TypeMap(TypeMap), 1162 InternalMapper(UnderlyingMap, RF_None, TypeMap, this) {} 1163 virtual ~FatPtrConstMaterializer() = default; 1164 1165 Value *materialize(Value *V) override; 1166 }; 1167 } // namespace 1168 1169 Constant *FatPtrConstMaterializer::materializeBufferFatPtrConst(Constant *C) { 1170 Type *SrcTy = C->getType(); 1171 auto *NewTy = dyn_cast<StructType>(TypeMap->remapType(SrcTy)); 1172 if (C->isNullValue()) 1173 return ConstantAggregateZero::getNullValue(NewTy); 1174 if (isa<PoisonValue>(C)) { 1175 return ConstantStruct::get(NewTy, 1176 {PoisonValue::get(NewTy->getElementType(0)), 1177 PoisonValue::get(NewTy->getElementType(1))}); 1178 } 1179 if (isa<UndefValue>(C)) { 1180 return ConstantStruct::get(NewTy, 1181 {UndefValue::get(NewTy->getElementType(0)), 1182 UndefValue::get(NewTy->getElementType(1))}); 1183 } 1184 1185 if (auto *VC = dyn_cast<ConstantVector>(C)) { 1186 if (Constant *S = VC->getSplatValue()) { 1187 Constant *NewS = InternalMapper.mapConstant(*S); 1188 if (!NewS) 1189 return nullptr; 1190 auto [Rsrc, Off] = splitLoweredFatBufferConst(NewS); 1191 auto EC = VC->getType()->getElementCount(); 1192 return ConstantStruct::get(NewTy, {ConstantVector::getSplat(EC, Rsrc), 1193 ConstantVector::getSplat(EC, Off)}); 1194 } 1195 SmallVector<Constant *> Rsrcs; 1196 SmallVector<Constant *> Offs; 1197 for (Value *Op : VC->operand_values()) { 1198 auto *NewOp = dyn_cast_or_null<Constant>(InternalMapper.mapValue(*Op)); 1199 if (!NewOp) 1200 return nullptr; 1201 auto [Rsrc, Off] = splitLoweredFatBufferConst(NewOp); 1202 Rsrcs.push_back(Rsrc); 1203 Offs.push_back(Off); 1204 } 1205 Constant *RsrcVec = ConstantVector::get(Rsrcs); 1206 Constant *OffVec = ConstantVector::get(Offs); 1207 return ConstantStruct::get(NewTy, {RsrcVec, OffVec}); 1208 } 1209 1210 if (isa<GlobalValue>(C)) 1211 report_fatal_error("Global values containing ptr addrspace(7) (buffer " 1212 "fat pointer) values are not supported"); 1213 1214 if (isa<ConstantExpr>(C)) 1215 report_fatal_error("Constant exprs containing ptr addrspace(7) (buffer " 1216 "fat pointer) values should have been expanded earlier"); 1217 1218 return nullptr; 1219 } 1220 1221 Value *FatPtrConstMaterializer::materialize(Value *V) { 1222 Constant *C = dyn_cast<Constant>(V); 1223 if (!C) 1224 return nullptr; 1225 // Structs and other types that happen to contain fat pointers get remapped 1226 // by the mapValue() logic. 1227 if (!isBufferFatPtrConst(C)) 1228 return nullptr; 1229 return materializeBufferFatPtrConst(C); 1230 } 1231 1232 using PtrParts = std::pair<Value *, Value *>; 1233 namespace { 1234 // The visitor returns the resource and offset parts for an instruction if they 1235 // can be computed, or (nullptr, nullptr) for cases that don't have a meaningful 1236 // value mapping. 1237 class SplitPtrStructs : public InstVisitor<SplitPtrStructs, PtrParts> { 1238 ValueToValueMapTy RsrcParts; 1239 ValueToValueMapTy OffParts; 1240 1241 // Track instructions that have been rewritten into a user of the component 1242 // parts of their ptr addrspace(7) input. Instructions that produced 1243 // ptr addrspace(7) parts should **not** be RAUW'd before being added to this 1244 // set, as that replacement will be handled in a post-visit step. However, 1245 // instructions that yield values that aren't fat pointers (ex. 
ptrtoint) should RAUW themselves with new instructions that use the split
  // parts of their arguments during processing.
  DenseSet<Instruction *> SplitUsers;

  // Nodes that need a second look once we've computed the parts for all other
  // instructions to see if, for example, we really need to phi on the resource
  // part.
  SmallVector<Instruction *> Conditionals;
  // Temporary instructions produced while lowering conditionals that should be
  // killed.
  SmallVector<Instruction *> ConditionalTemps;

  // Subtarget info, needed for determining what cache control bits to set.
  const TargetMachine *TM;
  const GCNSubtarget *ST = nullptr;

  IRBuilder<> IRB;

  // Copy metadata between instructions if applicable.
  void copyMetadata(Value *Dest, Value *Src);

  // Get the resource and offset parts of the value V, inserting appropriate
  // extractvalue calls if needed.
  PtrParts getPtrParts(Value *V);

  // Given an instruction that could produce multiple resource parts (a PHI or
  // select), collect the set of values that could have provided its resource
  // part (the `Roots`) and the set of conditional instructions visited during
  // the search (`Seen`). If, after removing the root of the search from `Seen`
  // and `Roots`, `Seen` is a subset of `Roots` and `Roots - Seen` contains one
  // element, the resource part of that element can replace the resource part
  // of all other elements in `Seen`.
  void getPossibleRsrcRoots(Instruction *I, SmallPtrSetImpl<Value *> &Roots,
                            SmallPtrSetImpl<Value *> &Seen);
  void processConditionals();

  // If an instruction has been split into resource and offset parts,
  // delete that instruction. If any of its uses have not themselves been split
  // into parts (for example, an insertvalue), construct the struct that the
  // type rewrites declared the dying instruction should produce and use that
  // instead.
  // Also, kill the temporary extractvalue operations produced by the two-stage
  // lowering of PHIs and conditionals.
1289 void killAndReplaceSplitInstructions(SmallVectorImpl<Instruction *> &Origs); 1290 1291 void setAlign(CallInst *Intr, Align A, unsigned RsrcArgIdx); 1292 void insertPreMemOpFence(AtomicOrdering Order, SyncScope::ID SSID); 1293 void insertPostMemOpFence(AtomicOrdering Order, SyncScope::ID SSID); 1294 Value *handleMemoryInst(Instruction *I, Value *Arg, Value *Ptr, Type *Ty, 1295 Align Alignment, AtomicOrdering Order, 1296 bool IsVolatile, SyncScope::ID SSID); 1297 1298 public: 1299 SplitPtrStructs(LLVMContext &Ctx, const TargetMachine *TM) 1300 : TM(TM), IRB(Ctx) {} 1301 1302 void processFunction(Function &F); 1303 1304 PtrParts visitInstruction(Instruction &I); 1305 PtrParts visitLoadInst(LoadInst &LI); 1306 PtrParts visitStoreInst(StoreInst &SI); 1307 PtrParts visitAtomicRMWInst(AtomicRMWInst &AI); 1308 PtrParts visitAtomicCmpXchgInst(AtomicCmpXchgInst &AI); 1309 PtrParts visitGetElementPtrInst(GetElementPtrInst &GEP); 1310 1311 PtrParts visitPtrToIntInst(PtrToIntInst &PI); 1312 PtrParts visitIntToPtrInst(IntToPtrInst &IP); 1313 PtrParts visitAddrSpaceCastInst(AddrSpaceCastInst &I); 1314 PtrParts visitICmpInst(ICmpInst &Cmp); 1315 PtrParts visitFreezeInst(FreezeInst &I); 1316 1317 PtrParts visitExtractElementInst(ExtractElementInst &I); 1318 PtrParts visitInsertElementInst(InsertElementInst &I); 1319 PtrParts visitShuffleVectorInst(ShuffleVectorInst &I); 1320 1321 PtrParts visitPHINode(PHINode &PHI); 1322 PtrParts visitSelectInst(SelectInst &SI); 1323 1324 PtrParts visitIntrinsicInst(IntrinsicInst &II); 1325 }; 1326 } // namespace 1327 1328 void SplitPtrStructs::copyMetadata(Value *Dest, Value *Src) { 1329 auto *DestI = dyn_cast<Instruction>(Dest); 1330 auto *SrcI = dyn_cast<Instruction>(Src); 1331 1332 if (!DestI || !SrcI) 1333 return; 1334 1335 DestI->copyMetadata(*SrcI); 1336 } 1337 1338 PtrParts SplitPtrStructs::getPtrParts(Value *V) { 1339 assert(isSplitFatPtr(V->getType()) && "it's not meaningful to get the parts " 1340 "of something that wasn't rewritten"); 1341 auto *RsrcEntry = &RsrcParts[V]; 1342 auto *OffEntry = &OffParts[V]; 1343 if (*RsrcEntry && *OffEntry) 1344 return {*RsrcEntry, *OffEntry}; 1345 1346 if (auto *C = dyn_cast<Constant>(V)) { 1347 auto [Rsrc, Off] = splitLoweredFatBufferConst(C); 1348 return {*RsrcEntry = Rsrc, *OffEntry = Off}; 1349 } 1350 1351 IRBuilder<>::InsertPointGuard Guard(IRB); 1352 if (auto *I = dyn_cast<Instruction>(V)) { 1353 LLVM_DEBUG(dbgs() << "Recursing to split parts of " << *I << "\n"); 1354 auto [Rsrc, Off] = visit(*I); 1355 if (Rsrc && Off) 1356 return {*RsrcEntry = Rsrc, *OffEntry = Off}; 1357 // We'll be creating the new values after the relevant instruction. 1358 // This instruction generates a value and so isn't a terminator. 1359 IRB.SetInsertPoint(*I->getInsertionPointAfterDef()); 1360 IRB.SetCurrentDebugLocation(I->getDebugLoc()); 1361 } else if (auto *A = dyn_cast<Argument>(V)) { 1362 IRB.SetInsertPointPastAllocas(A->getParent()); 1363 IRB.SetCurrentDebugLocation(DebugLoc()); 1364 } 1365 Value *Rsrc = IRB.CreateExtractValue(V, 0, V->getName() + ".rsrc"); 1366 Value *Off = IRB.CreateExtractValue(V, 1, V->getName() + ".off"); 1367 return {*RsrcEntry = Rsrc, *OffEntry = Off}; 1368 } 1369 1370 /// Returns the instruction that defines the resource part of the value V. 1371 /// Note that this is not getUnderlyingObject(), since that looks through 1372 /// operations like ptrmask which might modify the resource part. 
1373 /// 1374 /// We can limit ourselves to just looking through GEPs followed by looking 1375 /// through addrspacecasts because only those two operations preserve the 1376 /// resource part, and because operations on an `addrspace(8)` (which is the 1377 /// legal input to this addrspacecast) would produce a different resource part. 1378 static Value *rsrcPartRoot(Value *V) { 1379 while (auto *GEP = dyn_cast<GEPOperator>(V)) 1380 V = GEP->getPointerOperand(); 1381 while (auto *ASC = dyn_cast<AddrSpaceCastOperator>(V)) 1382 V = ASC->getPointerOperand(); 1383 return V; 1384 } 1385 1386 void SplitPtrStructs::getPossibleRsrcRoots(Instruction *I, 1387 SmallPtrSetImpl<Value *> &Roots, 1388 SmallPtrSetImpl<Value *> &Seen) { 1389 if (auto *PHI = dyn_cast<PHINode>(I)) { 1390 if (!Seen.insert(I).second) 1391 return; 1392 for (Value *In : PHI->incoming_values()) { 1393 In = rsrcPartRoot(In); 1394 Roots.insert(In); 1395 if (isa<PHINode, SelectInst>(In)) 1396 getPossibleRsrcRoots(cast<Instruction>(In), Roots, Seen); 1397 } 1398 } else if (auto *SI = dyn_cast<SelectInst>(I)) { 1399 if (!Seen.insert(SI).second) 1400 return; 1401 Value *TrueVal = rsrcPartRoot(SI->getTrueValue()); 1402 Value *FalseVal = rsrcPartRoot(SI->getFalseValue()); 1403 Roots.insert(TrueVal); 1404 Roots.insert(FalseVal); 1405 if (isa<PHINode, SelectInst>(TrueVal)) 1406 getPossibleRsrcRoots(cast<Instruction>(TrueVal), Roots, Seen); 1407 if (isa<PHINode, SelectInst>(FalseVal)) 1408 getPossibleRsrcRoots(cast<Instruction>(FalseVal), Roots, Seen); 1409 } else { 1410 llvm_unreachable("getPossibleRsrcParts() only works on phi and select"); 1411 } 1412 } 1413 1414 void SplitPtrStructs::processConditionals() { 1415 SmallDenseMap<Instruction *, Value *> FoundRsrcs; 1416 SmallPtrSet<Value *, 4> Roots; 1417 SmallPtrSet<Value *, 4> Seen; 1418 for (Instruction *I : Conditionals) { 1419 // These have to exist by now because we've visited these nodes. 1420 Value *Rsrc = RsrcParts[I]; 1421 Value *Off = OffParts[I]; 1422 assert(Rsrc && Off && "must have visited conditionals by now"); 1423 1424 std::optional<Value *> MaybeRsrc; 1425 auto MaybeFoundRsrc = FoundRsrcs.find(I); 1426 if (MaybeFoundRsrc != FoundRsrcs.end()) { 1427 MaybeRsrc = MaybeFoundRsrc->second; 1428 } else { 1429 IRBuilder<>::InsertPointGuard Guard(IRB); 1430 Roots.clear(); 1431 Seen.clear(); 1432 getPossibleRsrcRoots(I, Roots, Seen); 1433 LLVM_DEBUG(dbgs() << "Processing conditional: " << *I << "\n"); 1434 #ifndef NDEBUG 1435 for (Value *V : Roots) 1436 LLVM_DEBUG(dbgs() << "Root: " << *V << "\n"); 1437 for (Value *V : Seen) 1438 LLVM_DEBUG(dbgs() << "Seen: " << *V << "\n"); 1439 #endif 1440 // If we are our own possible root, then we shouldn't block our 1441 // replacement with a valid incoming value. 1442 Roots.erase(I); 1443 // We don't want to block the optimization for conditionals that don't 1444 // refer to themselves but did see themselves during the traversal. 1445 Seen.erase(I); 1446 1447 if (set_is_subset(Seen, Roots)) { 1448 auto Diff = set_difference(Roots, Seen); 1449 if (Diff.size() == 1) { 1450 Value *RootVal = *Diff.begin(); 1451 // Handle the case where previous loops already looked through 1452 // an addrspacecast. 
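          // As a rough, illustrative example of what this optimization buys:
          // for a loop like
          //   %p = phi ptr addrspace(7) [ %base, %entry ], [ %p.next, %loop ]
          //   %p.next = getelementptr i8, ptr addrspace(7) %p, i32 4
          // the only root is %base, so the resource half of %p never needs a
          // PHI of its own and can reuse %base's resource part; only the i32
          // offset gets a real PHI below.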
1453           if (isSplitFatPtr(RootVal->getType()))
1454             MaybeRsrc = std::get<0>(getPtrParts(RootVal));
1455           else
1456             MaybeRsrc = RootVal;
1457         }
1458       }
1459     }
1460
1461     if (auto *PHI = dyn_cast<PHINode>(I)) {
1462       Value *NewRsrc;
1463       StructType *PHITy = cast<StructType>(PHI->getType());
1464       IRB.SetInsertPoint(*PHI->getInsertionPointAfterDef());
1465       IRB.SetCurrentDebugLocation(PHI->getDebugLoc());
1466       if (MaybeRsrc) {
1467         NewRsrc = *MaybeRsrc;
1468       } else {
1469         Type *RsrcTy = PHITy->getElementType(0);
1470         auto *RsrcPHI = IRB.CreatePHI(RsrcTy, PHI->getNumIncomingValues());
1471         RsrcPHI->takeName(Rsrc);
1472         for (auto [V, BB] : llvm::zip(PHI->incoming_values(), PHI->blocks())) {
1473           Value *VRsrc = std::get<0>(getPtrParts(V));
1474           RsrcPHI->addIncoming(VRsrc, BB);
1475         }
1476         copyMetadata(RsrcPHI, PHI);
1477         NewRsrc = RsrcPHI;
1478       }
1479
1480       Type *OffTy = PHITy->getElementType(1);
1481       auto *NewOff = IRB.CreatePHI(OffTy, PHI->getNumIncomingValues());
1482       NewOff->takeName(Off);
1483       for (auto [V, BB] : llvm::zip(PHI->incoming_values(), PHI->blocks())) {
1484         assert(OffParts.count(V) && "An offset part had to be created by now");
1485         Value *VOff = std::get<1>(getPtrParts(V));
1486         NewOff->addIncoming(VOff, BB);
1487       }
1488       copyMetadata(NewOff, PHI);
1489
1490       // Note: We don't eraseFromParent() the temporaries because we don't want
1491       // to put the correction maps in an inconsistent state. That'll be handled
1492       // during the rest of the killing. Also, `ValueToValueMapTy` guarantees
1493       // that references in that map will be updated as well.
1494       ConditionalTemps.push_back(cast<Instruction>(Rsrc));
1495       ConditionalTemps.push_back(cast<Instruction>(Off));
1496       Rsrc->replaceAllUsesWith(NewRsrc);
1497       Off->replaceAllUsesWith(NewOff);
1498
1499       // Save on recomputing the cycle traversals in known-root cases.
1500       if (MaybeRsrc)
1501         for (Value *V : Seen)
1502           FoundRsrcs[cast<Instruction>(V)] = NewRsrc;
1503     } else if (isa<SelectInst>(I)) {
1504       if (MaybeRsrc) {
1505         ConditionalTemps.push_back(cast<Instruction>(Rsrc));
1506         Rsrc->replaceAllUsesWith(*MaybeRsrc);
1507         for (Value *V : Seen)
1508           FoundRsrcs[cast<Instruction>(V)] = *MaybeRsrc;
1509       }
1510     } else {
1511       llvm_unreachable("Only PHIs and selects go in the conditionals list");
1512     }
1513   }
1514 }
1515
1516 void SplitPtrStructs::killAndReplaceSplitInstructions(
1517     SmallVectorImpl<Instruction *> &Origs) {
1518   for (Instruction *I : ConditionalTemps)
1519     I->eraseFromParent();
1520
1521   for (Instruction *I : Origs) {
1522     if (!SplitUsers.contains(I))
1523       continue;
1524
1525     SmallVector<DbgValueInst *> Dbgs;
1526     findDbgValues(Dbgs, I);
1527     for (auto *Dbg : Dbgs) {
1528       IRB.SetInsertPoint(Dbg);
1529       auto &DL = I->getDataLayout();
1530       assert(isSplitFatPtr(I->getType()) &&
1531              "We should've RAUW'd away loads, stores, etc.
at this point"); 1532 auto *OffDbg = cast<DbgValueInst>(Dbg->clone()); 1533 copyMetadata(OffDbg, Dbg); 1534 auto [Rsrc, Off] = getPtrParts(I); 1535 1536 int64_t RsrcSz = DL.getTypeSizeInBits(Rsrc->getType()); 1537 int64_t OffSz = DL.getTypeSizeInBits(Off->getType()); 1538 1539 std::optional<DIExpression *> RsrcExpr = 1540 DIExpression::createFragmentExpression(Dbg->getExpression(), 0, 1541 RsrcSz); 1542 std::optional<DIExpression *> OffExpr = 1543 DIExpression::createFragmentExpression(Dbg->getExpression(), RsrcSz, 1544 OffSz); 1545 if (OffExpr) { 1546 OffDbg->setExpression(*OffExpr); 1547 OffDbg->replaceVariableLocationOp(I, Off); 1548 IRB.Insert(OffDbg); 1549 } else { 1550 OffDbg->deleteValue(); 1551 } 1552 if (RsrcExpr) { 1553 Dbg->setExpression(*RsrcExpr); 1554 Dbg->replaceVariableLocationOp(I, Rsrc); 1555 } else { 1556 Dbg->replaceVariableLocationOp(I, UndefValue::get(I->getType())); 1557 } 1558 } 1559 1560 Value *Poison = PoisonValue::get(I->getType()); 1561 I->replaceUsesWithIf(Poison, [&](const Use &U) -> bool { 1562 if (const auto *UI = dyn_cast<Instruction>(U.getUser())) 1563 return SplitUsers.contains(UI); 1564 return false; 1565 }); 1566 1567 if (I->use_empty()) { 1568 I->eraseFromParent(); 1569 continue; 1570 } 1571 IRB.SetInsertPoint(*I->getInsertionPointAfterDef()); 1572 IRB.SetCurrentDebugLocation(I->getDebugLoc()); 1573 auto [Rsrc, Off] = getPtrParts(I); 1574 Value *Struct = PoisonValue::get(I->getType()); 1575 Struct = IRB.CreateInsertValue(Struct, Rsrc, 0); 1576 Struct = IRB.CreateInsertValue(Struct, Off, 1); 1577 copyMetadata(Struct, I); 1578 Struct->takeName(I); 1579 I->replaceAllUsesWith(Struct); 1580 I->eraseFromParent(); 1581 } 1582 } 1583 1584 void SplitPtrStructs::setAlign(CallInst *Intr, Align A, unsigned RsrcArgIdx) { 1585 LLVMContext &Ctx = Intr->getContext(); 1586 Intr->addParamAttr(RsrcArgIdx, Attribute::getWithAlignment(Ctx, A)); 1587 } 1588 1589 void SplitPtrStructs::insertPreMemOpFence(AtomicOrdering Order, 1590 SyncScope::ID SSID) { 1591 switch (Order) { 1592 case AtomicOrdering::Release: 1593 case AtomicOrdering::AcquireRelease: 1594 case AtomicOrdering::SequentiallyConsistent: 1595 IRB.CreateFence(AtomicOrdering::Release, SSID); 1596 break; 1597 default: 1598 break; 1599 } 1600 } 1601 1602 void SplitPtrStructs::insertPostMemOpFence(AtomicOrdering Order, 1603 SyncScope::ID SSID) { 1604 switch (Order) { 1605 case AtomicOrdering::Acquire: 1606 case AtomicOrdering::AcquireRelease: 1607 case AtomicOrdering::SequentiallyConsistent: 1608 IRB.CreateFence(AtomicOrdering::Acquire, SSID); 1609 break; 1610 default: 1611 break; 1612 } 1613 } 1614 1615 Value *SplitPtrStructs::handleMemoryInst(Instruction *I, Value *Arg, Value *Ptr, 1616 Type *Ty, Align Alignment, 1617 AtomicOrdering Order, bool IsVolatile, 1618 SyncScope::ID SSID) { 1619 IRB.SetInsertPoint(I); 1620 1621 auto [Rsrc, Off] = getPtrParts(Ptr); 1622 SmallVector<Value *, 5> Args; 1623 if (Arg) 1624 Args.push_back(Arg); 1625 Args.push_back(Rsrc); 1626 Args.push_back(Off); 1627 insertPreMemOpFence(Order, SSID); 1628 // soffset is always 0 for these cases, where we always want any offset to be 1629 // part of bounds checking and we don't know which parts of the GEPs is 1630 // uniform. 1631 Args.push_back(IRB.getInt32(0)); 1632 1633 uint32_t Aux = 0; 1634 if (IsVolatile) 1635 Aux |= AMDGPU::CPol::VOLATILE; 1636 Args.push_back(IRB.getInt32(Aux)); 1637 1638 Intrinsic::ID IID = Intrinsic::not_intrinsic; 1639 if (isa<LoadInst>(I)) 1640 IID = Order == AtomicOrdering::NotAtomic 1641 ? 
                  Intrinsic::amdgcn_raw_ptr_buffer_load
1642               : Intrinsic::amdgcn_raw_ptr_atomic_buffer_load;
1643   else if (isa<StoreInst>(I))
1644     IID = Intrinsic::amdgcn_raw_ptr_buffer_store;
1645   else if (auto *RMW = dyn_cast<AtomicRMWInst>(I)) {
1646     switch (RMW->getOperation()) {
1647     case AtomicRMWInst::Xchg:
1648       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_swap;
1649       break;
1650     case AtomicRMWInst::Add:
1651       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_add;
1652       break;
1653     case AtomicRMWInst::Sub:
1654       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_sub;
1655       break;
1656     case AtomicRMWInst::And:
1657       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_and;
1658       break;
1659     case AtomicRMWInst::Or:
1660       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_or;
1661       break;
1662     case AtomicRMWInst::Xor:
1663       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_xor;
1664       break;
1665     case AtomicRMWInst::Max:
1666       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_smax;
1667       break;
1668     case AtomicRMWInst::Min:
1669       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_smin;
1670       break;
1671     case AtomicRMWInst::UMax:
1672       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_umax;
1673       break;
1674     case AtomicRMWInst::UMin:
1675       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_umin;
1676       break;
1677     case AtomicRMWInst::FAdd:
1678       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_fadd;
1679       break;
1680     case AtomicRMWInst::FMax:
1681       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_fmax;
1682       break;
1683     case AtomicRMWInst::FMin:
1684       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_fmin;
1685       break;
1686     case AtomicRMWInst::FSub: {
1687       report_fatal_error("atomic floating point subtraction not supported for "
1688                          "buffer resources and should've been expanded away");
1689       break;
1690     }
1691     case AtomicRMWInst::Nand:
1692       report_fatal_error("atomic nand not supported for buffer resources and "
1693                          "should've been expanded away");
1694       break;
1695     case AtomicRMWInst::UIncWrap:
1696     case AtomicRMWInst::UDecWrap:
1697       report_fatal_error("wrapping increment/decrement not supported for "
1698                          "buffer resources and should've been expanded away");
1699       break;
1700     case AtomicRMWInst::BAD_BINOP:
1701       llvm_unreachable("Not sure how we got a bad binop");
1702     case AtomicRMWInst::USubCond:
1703     case AtomicRMWInst::USubSat:
1704       break;
1705     }
1706   }
1707
1708   auto *Call = IRB.CreateIntrinsic(IID, Ty, Args);
1709   copyMetadata(Call, I);
1710   setAlign(Call, Alignment, Arg ? 1 : 0);
1711   Call->takeName(I);
1712
1713   insertPostMemOpFence(Order, SSID);
1714   // The "no moving p7 directly" rewrites ensure that this load or store won't
1715   // itself need to be split into parts.
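  // As a concrete (illustrative) example, a simple
  //   %v = load float, ptr addrspace(7) %p, align 4
  // comes out of this function as roughly
  //   %v = call float @llvm.amdgcn.raw.ptr.buffer.load.f32(
  //            ptr addrspace(8) align 4 %p.rsrc, i32 %p.off, i32 0, i32 0)
  // with fences around it only for the atomic orderings handled above.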
1716 SplitUsers.insert(I); 1717 I->replaceAllUsesWith(Call); 1718 return Call; 1719 } 1720 1721 PtrParts SplitPtrStructs::visitInstruction(Instruction &I) { 1722 return {nullptr, nullptr}; 1723 } 1724 1725 PtrParts SplitPtrStructs::visitLoadInst(LoadInst &LI) { 1726 if (!isSplitFatPtr(LI.getPointerOperandType())) 1727 return {nullptr, nullptr}; 1728 handleMemoryInst(&LI, nullptr, LI.getPointerOperand(), LI.getType(), 1729 LI.getAlign(), LI.getOrdering(), LI.isVolatile(), 1730 LI.getSyncScopeID()); 1731 return {nullptr, nullptr}; 1732 } 1733 1734 PtrParts SplitPtrStructs::visitStoreInst(StoreInst &SI) { 1735 if (!isSplitFatPtr(SI.getPointerOperandType())) 1736 return {nullptr, nullptr}; 1737 Value *Arg = SI.getValueOperand(); 1738 handleMemoryInst(&SI, Arg, SI.getPointerOperand(), Arg->getType(), 1739 SI.getAlign(), SI.getOrdering(), SI.isVolatile(), 1740 SI.getSyncScopeID()); 1741 return {nullptr, nullptr}; 1742 } 1743 1744 PtrParts SplitPtrStructs::visitAtomicRMWInst(AtomicRMWInst &AI) { 1745 if (!isSplitFatPtr(AI.getPointerOperand()->getType())) 1746 return {nullptr, nullptr}; 1747 Value *Arg = AI.getValOperand(); 1748 handleMemoryInst(&AI, Arg, AI.getPointerOperand(), Arg->getType(), 1749 AI.getAlign(), AI.getOrdering(), AI.isVolatile(), 1750 AI.getSyncScopeID()); 1751 return {nullptr, nullptr}; 1752 } 1753 1754 // Unlike load, store, and RMW, cmpxchg needs special handling to account 1755 // for the boolean argument. 1756 PtrParts SplitPtrStructs::visitAtomicCmpXchgInst(AtomicCmpXchgInst &AI) { 1757 Value *Ptr = AI.getPointerOperand(); 1758 if (!isSplitFatPtr(Ptr->getType())) 1759 return {nullptr, nullptr}; 1760 IRB.SetInsertPoint(&AI); 1761 1762 Type *Ty = AI.getNewValOperand()->getType(); 1763 AtomicOrdering Order = AI.getMergedOrdering(); 1764 SyncScope::ID SSID = AI.getSyncScopeID(); 1765 bool IsNonTemporal = AI.getMetadata(LLVMContext::MD_nontemporal); 1766 1767 auto [Rsrc, Off] = getPtrParts(Ptr); 1768 insertPreMemOpFence(Order, SSID); 1769 1770 uint32_t Aux = 0; 1771 if (IsNonTemporal) 1772 Aux |= AMDGPU::CPol::SLC; 1773 if (AI.isVolatile()) 1774 Aux |= AMDGPU::CPol::VOLATILE; 1775 auto *Call = 1776 IRB.CreateIntrinsic(Intrinsic::amdgcn_raw_ptr_buffer_atomic_cmpswap, Ty, 1777 {AI.getNewValOperand(), AI.getCompareOperand(), Rsrc, 1778 Off, IRB.getInt32(0), IRB.getInt32(Aux)}); 1779 copyMetadata(Call, &AI); 1780 setAlign(Call, AI.getAlign(), 2); 1781 Call->takeName(&AI); 1782 insertPostMemOpFence(Order, SSID); 1783 1784 Value *Res = PoisonValue::get(AI.getType()); 1785 Res = IRB.CreateInsertValue(Res, Call, 0); 1786 if (!AI.isWeak()) { 1787 Value *Succeeded = IRB.CreateICmpEQ(Call, AI.getCompareOperand()); 1788 Res = IRB.CreateInsertValue(Res, Succeeded, 1); 1789 } 1790 SplitUsers.insert(&AI); 1791 AI.replaceAllUsesWith(Res); 1792 return {nullptr, nullptr}; 1793 } 1794 1795 PtrParts SplitPtrStructs::visitGetElementPtrInst(GetElementPtrInst &GEP) { 1796 using namespace llvm::PatternMatch; 1797 Value *Ptr = GEP.getPointerOperand(); 1798 if (!isSplitFatPtr(Ptr->getType())) 1799 return {nullptr, nullptr}; 1800 IRB.SetInsertPoint(&GEP); 1801 1802 auto [Rsrc, Off] = getPtrParts(Ptr); 1803 const DataLayout &DL = GEP.getDataLayout(); 1804 bool IsNUW = GEP.hasNoUnsignedWrap(); 1805 bool IsNUSW = GEP.hasNoUnsignedSignedWrap(); 1806 1807 // In order to call emitGEPOffset() and thus not have to reimplement it, 1808 // we need the GEP result to have ptr addrspace(7) type. 
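  // The end result (sketched on made-up IR) is that
  //   %q = getelementptr i32, ptr addrspace(7) %p, i32 %i
  // leaves the resource part alone and only advances the offset, roughly
  //   %q.idx = mul i32 %i, 4
  //   %q.off = add i32 %p.off, %q.idx
  // after which {%p.rsrc, %q.off} stands in for %q.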
1809 Type *FatPtrTy = IRB.getPtrTy(AMDGPUAS::BUFFER_FAT_POINTER); 1810 if (auto *VT = dyn_cast<VectorType>(Off->getType())) 1811 FatPtrTy = VectorType::get(FatPtrTy, VT->getElementCount()); 1812 GEP.mutateType(FatPtrTy); 1813 Value *OffAccum = emitGEPOffset(&IRB, DL, &GEP); 1814 GEP.mutateType(Ptr->getType()); 1815 if (match(OffAccum, m_Zero())) { // Constant-zero offset 1816 SplitUsers.insert(&GEP); 1817 return {Rsrc, Off}; 1818 } 1819 1820 bool HasNonNegativeOff = false; 1821 if (auto *CI = dyn_cast<ConstantInt>(OffAccum)) { 1822 HasNonNegativeOff = !CI->isNegative(); 1823 } 1824 Value *NewOff; 1825 if (match(Off, m_Zero())) { 1826 NewOff = OffAccum; 1827 } else { 1828 NewOff = IRB.CreateAdd(Off, OffAccum, "", 1829 /*hasNUW=*/IsNUW || (IsNUSW && HasNonNegativeOff), 1830 /*hasNSW=*/false); 1831 } 1832 copyMetadata(NewOff, &GEP); 1833 NewOff->takeName(&GEP); 1834 SplitUsers.insert(&GEP); 1835 return {Rsrc, NewOff}; 1836 } 1837 1838 PtrParts SplitPtrStructs::visitPtrToIntInst(PtrToIntInst &PI) { 1839 Value *Ptr = PI.getPointerOperand(); 1840 if (!isSplitFatPtr(Ptr->getType())) 1841 return {nullptr, nullptr}; 1842 IRB.SetInsertPoint(&PI); 1843 1844 Type *ResTy = PI.getType(); 1845 unsigned Width = ResTy->getScalarSizeInBits(); 1846 1847 auto [Rsrc, Off] = getPtrParts(Ptr); 1848 const DataLayout &DL = PI.getDataLayout(); 1849 unsigned FatPtrWidth = DL.getPointerSizeInBits(AMDGPUAS::BUFFER_FAT_POINTER); 1850 1851 Value *Res; 1852 if (Width <= BufferOffsetWidth) { 1853 Res = IRB.CreateIntCast(Off, ResTy, /*isSigned=*/false, 1854 PI.getName() + ".off"); 1855 } else { 1856 Value *RsrcInt = IRB.CreatePtrToInt(Rsrc, ResTy, PI.getName() + ".rsrc"); 1857 Value *Shl = IRB.CreateShl( 1858 RsrcInt, 1859 ConstantExpr::getIntegerValue(ResTy, APInt(Width, BufferOffsetWidth)), 1860 "", Width >= FatPtrWidth, Width > FatPtrWidth); 1861 Value *OffCast = IRB.CreateIntCast(Off, ResTy, /*isSigned=*/false, 1862 PI.getName() + ".off"); 1863 Res = IRB.CreateOr(Shl, OffCast); 1864 } 1865 1866 copyMetadata(Res, &PI); 1867 Res->takeName(&PI); 1868 SplitUsers.insert(&PI); 1869 PI.replaceAllUsesWith(Res); 1870 return {nullptr, nullptr}; 1871 } 1872 1873 PtrParts SplitPtrStructs::visitIntToPtrInst(IntToPtrInst &IP) { 1874 if (!isSplitFatPtr(IP.getType())) 1875 return {nullptr, nullptr}; 1876 IRB.SetInsertPoint(&IP); 1877 const DataLayout &DL = IP.getDataLayout(); 1878 unsigned RsrcPtrWidth = DL.getPointerSizeInBits(AMDGPUAS::BUFFER_RESOURCE); 1879 Value *Int = IP.getOperand(0); 1880 Type *IntTy = Int->getType(); 1881 Type *RsrcIntTy = IntTy->getWithNewBitWidth(RsrcPtrWidth); 1882 unsigned Width = IntTy->getScalarSizeInBits(); 1883 1884 auto *RetTy = cast<StructType>(IP.getType()); 1885 Type *RsrcTy = RetTy->getElementType(0); 1886 Type *OffTy = RetTy->getElementType(1); 1887 Value *RsrcPart = IRB.CreateLShr( 1888 Int, 1889 ConstantExpr::getIntegerValue(IntTy, APInt(Width, BufferOffsetWidth))); 1890 Value *RsrcInt = IRB.CreateIntCast(RsrcPart, RsrcIntTy, /*isSigned=*/false); 1891 Value *Rsrc = IRB.CreateIntToPtr(RsrcInt, RsrcTy, IP.getName() + ".rsrc"); 1892 Value *Off = 1893 IRB.CreateIntCast(Int, OffTy, /*IsSigned=*/false, IP.getName() + ".off"); 1894 1895 copyMetadata(Rsrc, &IP); 1896 SplitUsers.insert(&IP); 1897 return {Rsrc, Off}; 1898 } 1899 1900 PtrParts SplitPtrStructs::visitAddrSpaceCastInst(AddrSpaceCastInst &I) { 1901 if (!isSplitFatPtr(I.getType())) 1902 return {nullptr, nullptr}; 1903 IRB.SetInsertPoint(&I); 1904 Value *In = I.getPointerOperand(); 1905 // No-op casts preserve parts 1906 if (In->getType() == 
I.getType()) { 1907 auto [Rsrc, Off] = getPtrParts(In); 1908 SplitUsers.insert(&I); 1909 return {Rsrc, Off}; 1910 } 1911 if (I.getSrcAddressSpace() != AMDGPUAS::BUFFER_RESOURCE) 1912 report_fatal_error("Only buffer resources (addrspace 8) can be cast to " 1913 "buffer fat pointers (addrspace 7)"); 1914 Type *OffTy = cast<StructType>(I.getType())->getElementType(1); 1915 Value *ZeroOff = Constant::getNullValue(OffTy); 1916 SplitUsers.insert(&I); 1917 return {In, ZeroOff}; 1918 } 1919 1920 PtrParts SplitPtrStructs::visitICmpInst(ICmpInst &Cmp) { 1921 Value *Lhs = Cmp.getOperand(0); 1922 if (!isSplitFatPtr(Lhs->getType())) 1923 return {nullptr, nullptr}; 1924 Value *Rhs = Cmp.getOperand(1); 1925 IRB.SetInsertPoint(&Cmp); 1926 ICmpInst::Predicate Pred = Cmp.getPredicate(); 1927 1928 assert((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) && 1929 "Pointer comparison is only equal or unequal"); 1930 auto [LhsRsrc, LhsOff] = getPtrParts(Lhs); 1931 auto [RhsRsrc, RhsOff] = getPtrParts(Rhs); 1932 Value *RsrcCmp = 1933 IRB.CreateICmp(Pred, LhsRsrc, RhsRsrc, Cmp.getName() + ".rsrc"); 1934 copyMetadata(RsrcCmp, &Cmp); 1935 Value *OffCmp = IRB.CreateICmp(Pred, LhsOff, RhsOff, Cmp.getName() + ".off"); 1936 copyMetadata(OffCmp, &Cmp); 1937 1938 Value *Res = nullptr; 1939 if (Pred == ICmpInst::ICMP_EQ) 1940 Res = IRB.CreateAnd(RsrcCmp, OffCmp); 1941 else if (Pred == ICmpInst::ICMP_NE) 1942 Res = IRB.CreateOr(RsrcCmp, OffCmp); 1943 copyMetadata(Res, &Cmp); 1944 Res->takeName(&Cmp); 1945 SplitUsers.insert(&Cmp); 1946 Cmp.replaceAllUsesWith(Res); 1947 return {nullptr, nullptr}; 1948 } 1949 1950 PtrParts SplitPtrStructs::visitFreezeInst(FreezeInst &I) { 1951 if (!isSplitFatPtr(I.getType())) 1952 return {nullptr, nullptr}; 1953 IRB.SetInsertPoint(&I); 1954 auto [Rsrc, Off] = getPtrParts(I.getOperand(0)); 1955 1956 Value *RsrcRes = IRB.CreateFreeze(Rsrc, I.getName() + ".rsrc"); 1957 copyMetadata(RsrcRes, &I); 1958 Value *OffRes = IRB.CreateFreeze(Off, I.getName() + ".off"); 1959 copyMetadata(OffRes, &I); 1960 SplitUsers.insert(&I); 1961 return {RsrcRes, OffRes}; 1962 } 1963 1964 PtrParts SplitPtrStructs::visitExtractElementInst(ExtractElementInst &I) { 1965 if (!isSplitFatPtr(I.getType())) 1966 return {nullptr, nullptr}; 1967 IRB.SetInsertPoint(&I); 1968 Value *Vec = I.getVectorOperand(); 1969 Value *Idx = I.getIndexOperand(); 1970 auto [Rsrc, Off] = getPtrParts(Vec); 1971 1972 Value *RsrcRes = IRB.CreateExtractElement(Rsrc, Idx, I.getName() + ".rsrc"); 1973 copyMetadata(RsrcRes, &I); 1974 Value *OffRes = IRB.CreateExtractElement(Off, Idx, I.getName() + ".off"); 1975 copyMetadata(OffRes, &I); 1976 SplitUsers.insert(&I); 1977 return {RsrcRes, OffRes}; 1978 } 1979 1980 PtrParts SplitPtrStructs::visitInsertElementInst(InsertElementInst &I) { 1981 // The mutated instructions temporarily don't return vectors, and so 1982 // we need the generic getType() here to avoid crashes. 
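  // (Illustrative.) Once split, an insert into a fat-pointer vector such as
  //   %w = insertelement <2 x ptr addrspace(7)> %v, ptr addrspace(7) %p, i64 0
  // is just one insertelement on the <2 x ptr addrspace(8)> resource vector
  // and one on the <2 x i32> offset vector.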
1983 if (!isSplitFatPtr(cast<Instruction>(I).getType())) 1984 return {nullptr, nullptr}; 1985 IRB.SetInsertPoint(&I); 1986 Value *Vec = I.getOperand(0); 1987 Value *Elem = I.getOperand(1); 1988 Value *Idx = I.getOperand(2); 1989 auto [VecRsrc, VecOff] = getPtrParts(Vec); 1990 auto [ElemRsrc, ElemOff] = getPtrParts(Elem); 1991 1992 Value *RsrcRes = 1993 IRB.CreateInsertElement(VecRsrc, ElemRsrc, Idx, I.getName() + ".rsrc"); 1994 copyMetadata(RsrcRes, &I); 1995 Value *OffRes = 1996 IRB.CreateInsertElement(VecOff, ElemOff, Idx, I.getName() + ".off"); 1997 copyMetadata(OffRes, &I); 1998 SplitUsers.insert(&I); 1999 return {RsrcRes, OffRes}; 2000 } 2001 2002 PtrParts SplitPtrStructs::visitShuffleVectorInst(ShuffleVectorInst &I) { 2003 // Cast is needed for the same reason as insertelement's. 2004 if (!isSplitFatPtr(cast<Instruction>(I).getType())) 2005 return {nullptr, nullptr}; 2006 IRB.SetInsertPoint(&I); 2007 2008 Value *V1 = I.getOperand(0); 2009 Value *V2 = I.getOperand(1); 2010 ArrayRef<int> Mask = I.getShuffleMask(); 2011 auto [V1Rsrc, V1Off] = getPtrParts(V1); 2012 auto [V2Rsrc, V2Off] = getPtrParts(V2); 2013 2014 Value *RsrcRes = 2015 IRB.CreateShuffleVector(V1Rsrc, V2Rsrc, Mask, I.getName() + ".rsrc"); 2016 copyMetadata(RsrcRes, &I); 2017 Value *OffRes = 2018 IRB.CreateShuffleVector(V1Off, V2Off, Mask, I.getName() + ".off"); 2019 copyMetadata(OffRes, &I); 2020 SplitUsers.insert(&I); 2021 return {RsrcRes, OffRes}; 2022 } 2023 2024 PtrParts SplitPtrStructs::visitPHINode(PHINode &PHI) { 2025 if (!isSplitFatPtr(PHI.getType())) 2026 return {nullptr, nullptr}; 2027 IRB.SetInsertPoint(*PHI.getInsertionPointAfterDef()); 2028 // Phi nodes will be handled in post-processing after we've visited every 2029 // instruction. However, instead of just returning {nullptr, nullptr}, 2030 // we explicitly create the temporary extractvalue operations that are our 2031 // temporary results so that they end up at the beginning of the block with 2032 // the PHIs. 2033 Value *TmpRsrc = IRB.CreateExtractValue(&PHI, 0, PHI.getName() + ".rsrc"); 2034 Value *TmpOff = IRB.CreateExtractValue(&PHI, 1, PHI.getName() + ".off"); 2035 Conditionals.push_back(&PHI); 2036 SplitUsers.insert(&PHI); 2037 return {TmpRsrc, TmpOff}; 2038 } 2039 2040 PtrParts SplitPtrStructs::visitSelectInst(SelectInst &SI) { 2041 if (!isSplitFatPtr(SI.getType())) 2042 return {nullptr, nullptr}; 2043 IRB.SetInsertPoint(&SI); 2044 2045 Value *Cond = SI.getCondition(); 2046 Value *True = SI.getTrueValue(); 2047 Value *False = SI.getFalseValue(); 2048 auto [TrueRsrc, TrueOff] = getPtrParts(True); 2049 auto [FalseRsrc, FalseOff] = getPtrParts(False); 2050 2051 Value *RsrcRes = 2052 IRB.CreateSelect(Cond, TrueRsrc, FalseRsrc, SI.getName() + ".rsrc", &SI); 2053 copyMetadata(RsrcRes, &SI); 2054 Conditionals.push_back(&SI); 2055 Value *OffRes = 2056 IRB.CreateSelect(Cond, TrueOff, FalseOff, SI.getName() + ".off", &SI); 2057 copyMetadata(OffRes, &SI); 2058 SplitUsers.insert(&SI); 2059 return {RsrcRes, OffRes}; 2060 } 2061 2062 /// Returns true if this intrinsic needs to be removed when it is 2063 /// applied to `ptr addrspace(7)` values. Calls to these intrinsics are 2064 /// rewritten into calls to versions of that intrinsic on the resource 2065 /// descriptor. 
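/// For example, the invariant-related intrinsics below are re-issued on the
/// `ptr addrspace(8)` resource part, while `llvm.ptrmask` reduces to an `and`
/// on the 32-bit offset (see visitIntrinsicInst()); in both cases the original
/// addrspace(7)-mangled declarations become dead and are erased.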
2066 static bool isRemovablePointerIntrinsic(Intrinsic::ID IID) { 2067 switch (IID) { 2068 default: 2069 return false; 2070 case Intrinsic::ptrmask: 2071 case Intrinsic::invariant_start: 2072 case Intrinsic::invariant_end: 2073 case Intrinsic::launder_invariant_group: 2074 case Intrinsic::strip_invariant_group: 2075 return true; 2076 } 2077 } 2078 2079 PtrParts SplitPtrStructs::visitIntrinsicInst(IntrinsicInst &I) { 2080 Intrinsic::ID IID = I.getIntrinsicID(); 2081 switch (IID) { 2082 default: 2083 break; 2084 case Intrinsic::ptrmask: { 2085 Value *Ptr = I.getArgOperand(0); 2086 if (!isSplitFatPtr(Ptr->getType())) 2087 return {nullptr, nullptr}; 2088 Value *Mask = I.getArgOperand(1); 2089 IRB.SetInsertPoint(&I); 2090 auto [Rsrc, Off] = getPtrParts(Ptr); 2091 if (Mask->getType() != Off->getType()) 2092 report_fatal_error("offset width is not equal to index width of fat " 2093 "pointer (data layout not set up correctly?)"); 2094 Value *OffRes = IRB.CreateAnd(Off, Mask, I.getName() + ".off"); 2095 copyMetadata(OffRes, &I); 2096 SplitUsers.insert(&I); 2097 return {Rsrc, OffRes}; 2098 } 2099 // Pointer annotation intrinsics that, given their object-wide nature 2100 // operate on the resource part. 2101 case Intrinsic::invariant_start: { 2102 Value *Ptr = I.getArgOperand(1); 2103 if (!isSplitFatPtr(Ptr->getType())) 2104 return {nullptr, nullptr}; 2105 IRB.SetInsertPoint(&I); 2106 auto [Rsrc, Off] = getPtrParts(Ptr); 2107 Type *NewTy = PointerType::get(I.getContext(), AMDGPUAS::BUFFER_RESOURCE); 2108 auto *NewRsrc = IRB.CreateIntrinsic(IID, {NewTy}, {I.getOperand(0), Rsrc}); 2109 copyMetadata(NewRsrc, &I); 2110 NewRsrc->takeName(&I); 2111 SplitUsers.insert(&I); 2112 I.replaceAllUsesWith(NewRsrc); 2113 return {nullptr, nullptr}; 2114 } 2115 case Intrinsic::invariant_end: { 2116 Value *RealPtr = I.getArgOperand(2); 2117 if (!isSplitFatPtr(RealPtr->getType())) 2118 return {nullptr, nullptr}; 2119 IRB.SetInsertPoint(&I); 2120 Value *RealRsrc = getPtrParts(RealPtr).first; 2121 Value *InvPtr = I.getArgOperand(0); 2122 Value *Size = I.getArgOperand(1); 2123 Value *NewRsrc = IRB.CreateIntrinsic(IID, {RealRsrc->getType()}, 2124 {InvPtr, Size, RealRsrc}); 2125 copyMetadata(NewRsrc, &I); 2126 NewRsrc->takeName(&I); 2127 SplitUsers.insert(&I); 2128 I.replaceAllUsesWith(NewRsrc); 2129 return {nullptr, nullptr}; 2130 } 2131 case Intrinsic::launder_invariant_group: 2132 case Intrinsic::strip_invariant_group: { 2133 Value *Ptr = I.getArgOperand(0); 2134 if (!isSplitFatPtr(Ptr->getType())) 2135 return {nullptr, nullptr}; 2136 IRB.SetInsertPoint(&I); 2137 auto [Rsrc, Off] = getPtrParts(Ptr); 2138 Value *NewRsrc = IRB.CreateIntrinsic(IID, {Rsrc->getType()}, {Rsrc}); 2139 copyMetadata(NewRsrc, &I); 2140 NewRsrc->takeName(&I); 2141 SplitUsers.insert(&I); 2142 return {NewRsrc, Off}; 2143 } 2144 } 2145 return {nullptr, nullptr}; 2146 } 2147 2148 void SplitPtrStructs::processFunction(Function &F) { 2149 ST = &TM->getSubtarget<GCNSubtarget>(F); 2150 SmallVector<Instruction *, 0> Originals; 2151 LLVM_DEBUG(dbgs() << "Splitting pointer structs in function: " << F.getName() 2152 << "\n"); 2153 for (Instruction &I : instructions(F)) 2154 Originals.push_back(&I); 2155 for (Instruction *I : Originals) { 2156 auto [Rsrc, Off] = visit(I); 2157 assert(((Rsrc && Off) || (!Rsrc && !Off)) && 2158 "Can't have a resource but no offset"); 2159 if (Rsrc) 2160 RsrcParts[I] = Rsrc; 2161 if (Off) 2162 OffParts[I] = Off; 2163 } 2164 processConditionals(); 2165 killAndReplaceSplitInstructions(Originals); 2166 2167 // Clean up after ourselves 
to save on memory. 2168 RsrcParts.clear(); 2169 OffParts.clear(); 2170 SplitUsers.clear(); 2171 Conditionals.clear(); 2172 ConditionalTemps.clear(); 2173 } 2174 2175 namespace { 2176 class AMDGPULowerBufferFatPointers : public ModulePass { 2177 public: 2178 static char ID; 2179 2180 AMDGPULowerBufferFatPointers() : ModulePass(ID) { 2181 initializeAMDGPULowerBufferFatPointersPass( 2182 *PassRegistry::getPassRegistry()); 2183 } 2184 2185 bool run(Module &M, const TargetMachine &TM); 2186 bool runOnModule(Module &M) override; 2187 2188 void getAnalysisUsage(AnalysisUsage &AU) const override; 2189 }; 2190 } // namespace 2191 2192 /// Returns true if there are values that have a buffer fat pointer in them, 2193 /// which means we'll need to perform rewrites on this function. As a side 2194 /// effect, this will populate the type remapping cache. 2195 static bool containsBufferFatPointers(const Function &F, 2196 BufferFatPtrToStructTypeMap *TypeMap) { 2197 bool HasFatPointers = false; 2198 for (const BasicBlock &BB : F) 2199 for (const Instruction &I : BB) 2200 HasFatPointers |= (I.getType() != TypeMap->remapType(I.getType())); 2201 return HasFatPointers; 2202 } 2203 2204 static bool hasFatPointerInterface(const Function &F, 2205 BufferFatPtrToStructTypeMap *TypeMap) { 2206 Type *Ty = F.getFunctionType(); 2207 return Ty != TypeMap->remapType(Ty); 2208 } 2209 2210 /// Move the body of `OldF` into a new function, returning it. 2211 static Function *moveFunctionAdaptingType(Function *OldF, FunctionType *NewTy, 2212 ValueToValueMapTy &CloneMap) { 2213 bool IsIntrinsic = OldF->isIntrinsic(); 2214 Function *NewF = 2215 Function::Create(NewTy, OldF->getLinkage(), OldF->getAddressSpace()); 2216 NewF->IsNewDbgInfoFormat = OldF->IsNewDbgInfoFormat; 2217 NewF->copyAttributesFrom(OldF); 2218 NewF->copyMetadata(OldF, 0); 2219 NewF->takeName(OldF); 2220 NewF->updateAfterNameChange(); 2221 NewF->setDLLStorageClass(OldF->getDLLStorageClass()); 2222 OldF->getParent()->getFunctionList().insertAfter(OldF->getIterator(), NewF); 2223 2224 while (!OldF->empty()) { 2225 BasicBlock *BB = &OldF->front(); 2226 BB->removeFromParent(); 2227 BB->insertInto(NewF); 2228 CloneMap[BB] = BB; 2229 for (Instruction &I : *BB) { 2230 CloneMap[&I] = &I; 2231 } 2232 } 2233 2234 SmallVector<AttributeSet> ArgAttrs; 2235 AttributeList OldAttrs = OldF->getAttributes(); 2236 2237 for (auto [I, OldArg, NewArg] : enumerate(OldF->args(), NewF->args())) { 2238 CloneMap[&NewArg] = &OldArg; 2239 NewArg.takeName(&OldArg); 2240 Type *OldArgTy = OldArg.getType(), *NewArgTy = NewArg.getType(); 2241 // Temporarily mutate type of `NewArg` to allow RAUW to work. 2242 NewArg.mutateType(OldArgTy); 2243 OldArg.replaceAllUsesWith(&NewArg); 2244 NewArg.mutateType(NewArgTy); 2245 2246 AttributeSet ArgAttr = OldAttrs.getParamAttrs(I); 2247 // Intrinsics get their attributes fixed later. 
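    // (Sketch, made-up signature.) A function like
    //   define float @f(ptr addrspace(7) noalias %p)
    // is re-created as
    //   define float @f({ptr addrspace(8), i32} %p)
    // and pointer-only attributes such as noalias, which are invalid on the
    // struct type, are dropped just below via AttributeFuncs::typeIncompatible().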
2248 if (OldArgTy != NewArgTy && !IsIntrinsic) 2249 ArgAttr = ArgAttr.removeAttributes( 2250 NewF->getContext(), 2251 AttributeFuncs::typeIncompatible(NewArgTy, ArgAttr)); 2252 ArgAttrs.push_back(ArgAttr); 2253 } 2254 AttributeSet RetAttrs = OldAttrs.getRetAttrs(); 2255 if (OldF->getReturnType() != NewF->getReturnType() && !IsIntrinsic) 2256 RetAttrs = RetAttrs.removeAttributes( 2257 NewF->getContext(), 2258 AttributeFuncs::typeIncompatible(NewF->getReturnType(), RetAttrs)); 2259 NewF->setAttributes(AttributeList::get( 2260 NewF->getContext(), OldAttrs.getFnAttrs(), RetAttrs, ArgAttrs)); 2261 return NewF; 2262 } 2263 2264 static void makeCloneInPraceMap(Function *F, ValueToValueMapTy &CloneMap) { 2265 for (Argument &A : F->args()) 2266 CloneMap[&A] = &A; 2267 for (BasicBlock &BB : *F) { 2268 CloneMap[&BB] = &BB; 2269 for (Instruction &I : BB) 2270 CloneMap[&I] = &I; 2271 } 2272 } 2273 2274 bool AMDGPULowerBufferFatPointers::run(Module &M, const TargetMachine &TM) { 2275 bool Changed = false; 2276 const DataLayout &DL = M.getDataLayout(); 2277 // Record the functions which need to be remapped. 2278 // The second element of the pair indicates whether the function has to have 2279 // its arguments or return types adjusted. 2280 SmallVector<std::pair<Function *, bool>> NeedsRemap; 2281 2282 BufferFatPtrToStructTypeMap StructTM(DL); 2283 BufferFatPtrToIntTypeMap IntTM(DL); 2284 for (const GlobalVariable &GV : M.globals()) { 2285 if (GV.getAddressSpace() == AMDGPUAS::BUFFER_FAT_POINTER) 2286 report_fatal_error("Global variables with a buffer fat pointer address " 2287 "space (7) are not supported"); 2288 Type *VT = GV.getValueType(); 2289 if (VT != StructTM.remapType(VT)) 2290 report_fatal_error("Global variables that contain buffer fat pointers " 2291 "(address space 7 pointers) are unsupported. Use " 2292 "buffer resource pointers (address space 8) instead."); 2293 } 2294 2295 { 2296 // Collect all constant exprs and aggregates referenced by any function. 2297 SmallVector<Constant *, 8> Worklist; 2298 for (Function &F : M.functions()) 2299 for (Instruction &I : instructions(F)) 2300 for (Value *Op : I.operands()) 2301 if (isa<ConstantExpr>(Op) || isa<ConstantAggregate>(Op)) 2302 Worklist.push_back(cast<Constant>(Op)); 2303 2304 // Recursively look for any referenced buffer pointer constants. 2305 SmallPtrSet<Constant *, 8> Visited; 2306 SetVector<Constant *> BufferFatPtrConsts; 2307 while (!Worklist.empty()) { 2308 Constant *C = Worklist.pop_back_val(); 2309 if (!Visited.insert(C).second) 2310 continue; 2311 if (isBufferFatPtrOrVector(C->getType())) 2312 BufferFatPtrConsts.insert(C); 2313 for (Value *Op : C->operands()) 2314 if (isa<ConstantExpr>(Op) || isa<ConstantAggregate>(Op)) 2315 Worklist.push_back(cast<Constant>(Op)); 2316 } 2317 2318 // Expand all constant expressions using fat buffer pointers to 2319 // instructions. 
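    // For example (hypothetical global), an operand such as
    //   getelementptr (i8, ptr addrspace(7) addrspacecast
    //       (ptr addrspace(8) @buf to ptr addrspace(7)), i32 16)
    // is turned into separate addrspacecast and getelementptr instructions so
    // that the later per-instruction rewrites never have to deal with
    // ConstantExprs that produce fat pointers.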
2320 Changed |= convertUsersOfConstantsToInstructions( 2321 BufferFatPtrConsts.getArrayRef(), /*RestrictToFunc=*/nullptr, 2322 /*RemoveDeadConstants=*/false, /*IncludeSelf=*/true); 2323 } 2324 2325 StoreFatPtrsAsIntsVisitor MemOpsRewrite(&IntTM, M.getContext()); 2326 LegalizeBufferContentTypesVisitor BufferContentsTypeRewrite(DL, 2327 M.getContext()); 2328 for (Function &F : M.functions()) { 2329 bool InterfaceChange = hasFatPointerInterface(F, &StructTM); 2330 bool BodyChanges = containsBufferFatPointers(F, &StructTM); 2331 Changed |= MemOpsRewrite.processFunction(F); 2332 if (InterfaceChange || BodyChanges) { 2333 NeedsRemap.push_back(std::make_pair(&F, InterfaceChange)); 2334 Changed |= BufferContentsTypeRewrite.processFunction(F); 2335 } 2336 } 2337 if (NeedsRemap.empty()) 2338 return Changed; 2339 2340 SmallVector<Function *> NeedsPostProcess; 2341 SmallVector<Function *> Intrinsics; 2342 // Keep one big map so as to memoize constants across functions. 2343 ValueToValueMapTy CloneMap; 2344 FatPtrConstMaterializer Materializer(&StructTM, CloneMap); 2345 2346 ValueMapper LowerInFuncs(CloneMap, RF_None, &StructTM, &Materializer); 2347 for (auto [F, InterfaceChange] : NeedsRemap) { 2348 Function *NewF = F; 2349 if (InterfaceChange) 2350 NewF = moveFunctionAdaptingType( 2351 F, cast<FunctionType>(StructTM.remapType(F->getFunctionType())), 2352 CloneMap); 2353 else 2354 makeCloneInPraceMap(F, CloneMap); 2355 LowerInFuncs.remapFunction(*NewF); 2356 if (NewF->isIntrinsic()) 2357 Intrinsics.push_back(NewF); 2358 else 2359 NeedsPostProcess.push_back(NewF); 2360 if (InterfaceChange) { 2361 F->replaceAllUsesWith(NewF); 2362 F->eraseFromParent(); 2363 } 2364 Changed = true; 2365 } 2366 StructTM.clear(); 2367 IntTM.clear(); 2368 CloneMap.clear(); 2369 2370 SplitPtrStructs Splitter(M.getContext(), &TM); 2371 for (Function *F : NeedsPostProcess) 2372 Splitter.processFunction(*F); 2373 for (Function *F : Intrinsics) { 2374 if (isRemovablePointerIntrinsic(F->getIntrinsicID())) { 2375 F->eraseFromParent(); 2376 } else { 2377 std::optional<Function *> NewF = Intrinsic::remangleIntrinsicFunction(F); 2378 if (NewF) 2379 F->replaceAllUsesWith(*NewF); 2380 } 2381 } 2382 return Changed; 2383 } 2384 2385 bool AMDGPULowerBufferFatPointers::runOnModule(Module &M) { 2386 TargetPassConfig &TPC = getAnalysis<TargetPassConfig>(); 2387 const TargetMachine &TM = TPC.getTM<TargetMachine>(); 2388 return run(M, TM); 2389 } 2390 2391 char AMDGPULowerBufferFatPointers::ID = 0; 2392 2393 char &llvm::AMDGPULowerBufferFatPointersID = AMDGPULowerBufferFatPointers::ID; 2394 2395 void AMDGPULowerBufferFatPointers::getAnalysisUsage(AnalysisUsage &AU) const { 2396 AU.addRequired<TargetPassConfig>(); 2397 } 2398 2399 #define PASS_DESC "Lower buffer fat pointer operations to buffer resources" 2400 INITIALIZE_PASS_BEGIN(AMDGPULowerBufferFatPointers, DEBUG_TYPE, PASS_DESC, 2401 false, false) 2402 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig) 2403 INITIALIZE_PASS_END(AMDGPULowerBufferFatPointers, DEBUG_TYPE, PASS_DESC, false, 2404 false) 2405 #undef PASS_DESC 2406 2407 ModulePass *llvm::createAMDGPULowerBufferFatPointersPass() { 2408 return new AMDGPULowerBufferFatPointers(); 2409 } 2410 2411 PreservedAnalyses 2412 AMDGPULowerBufferFatPointersPass::run(Module &M, ModuleAnalysisManager &MA) { 2413 return AMDGPULowerBufferFatPointers().run(M, TM) ? PreservedAnalyses::none() 2414 : PreservedAnalyses::all(); 2415 } 2416