1 //===- DXILOpLowering.cpp - Lowering to DXIL operations -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "DXILOpLowering.h" 10 #include "DXILConstants.h" 11 #include "DXILIntrinsicExpansion.h" 12 #include "DXILOpBuilder.h" 13 #include "DXILResourceAnalysis.h" 14 #include "DXILShaderFlags.h" 15 #include "DirectX.h" 16 #include "llvm/ADT/SmallVector.h" 17 #include "llvm/Analysis/DXILMetadataAnalysis.h" 18 #include "llvm/Analysis/DXILResource.h" 19 #include "llvm/CodeGen/Passes.h" 20 #include "llvm/IR/DiagnosticInfo.h" 21 #include "llvm/IR/IRBuilder.h" 22 #include "llvm/IR/Instruction.h" 23 #include "llvm/IR/Instructions.h" 24 #include "llvm/IR/Intrinsics.h" 25 #include "llvm/IR/IntrinsicsDirectX.h" 26 #include "llvm/IR/Module.h" 27 #include "llvm/IR/PassManager.h" 28 #include "llvm/InitializePasses.h" 29 #include "llvm/Pass.h" 30 #include "llvm/Support/ErrorHandling.h" 31 32 #define DEBUG_TYPE "dxil-op-lower" 33 34 using namespace llvm; 35 using namespace llvm::dxil; 36 37 static bool isVectorArgExpansion(Function &F) { 38 switch (F.getIntrinsicID()) { 39 case Intrinsic::dx_dot2: 40 case Intrinsic::dx_dot3: 41 case Intrinsic::dx_dot4: 42 return true; 43 } 44 return false; 45 } 46 47 static SmallVector<Value *> populateOperands(Value *Arg, IRBuilder<> &Builder) { 48 SmallVector<Value *> ExtractedElements; 49 auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType()); 50 for (unsigned I = 0; I < VecArg->getNumElements(); ++I) { 51 Value *Index = ConstantInt::get(Type::getInt32Ty(Arg->getContext()), I); 52 Value *ExtractedElement = Builder.CreateExtractElement(Arg, Index); 53 ExtractedElements.push_back(ExtractedElement); 54 } 55 return ExtractedElements; 56 } 57 58 static SmallVector<Value *> argVectorFlatten(CallInst *Orig, 59 IRBuilder<> &Builder) { 60 // Note: arg[NumOperands-1] is a pointer and is not needed by our flattening. 61 unsigned NumOperands = Orig->getNumOperands() - 1; 62 assert(NumOperands > 0); 63 Value *Arg0 = Orig->getOperand(0); 64 [[maybe_unused]] auto *VecArg0 = dyn_cast<FixedVectorType>(Arg0->getType()); 65 assert(VecArg0); 66 SmallVector<Value *> NewOperands = populateOperands(Arg0, Builder); 67 for (unsigned I = 1; I < NumOperands; ++I) { 68 Value *Arg = Orig->getOperand(I); 69 [[maybe_unused]] auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType()); 70 assert(VecArg); 71 assert(VecArg0->getElementType() == VecArg->getElementType()); 72 assert(VecArg0->getNumElements() == VecArg->getNumElements()); 73 auto NextOperandList = populateOperands(Arg, Builder); 74 NewOperands.append(NextOperandList.begin(), NextOperandList.end()); 75 } 76 return NewOperands; 77 } 78 79 namespace { 80 class OpLowerer { 81 Module &M; 82 DXILOpBuilder OpBuilder; 83 DXILBindingMap &DBM; 84 DXILResourceTypeMap &DRTM; 85 SmallVector<CallInst *> CleanupCasts; 86 87 public: 88 OpLowerer(Module &M, DXILBindingMap &DBM, DXILResourceTypeMap &DRTM) 89 : M(M), OpBuilder(M), DBM(DBM), DRTM(DRTM) {} 90 91 /// Replace every call to \c F using \c ReplaceCall, and then erase \c F. If 92 /// there is an error replacing a call, we emit a diagnostic and return true. 93 [[nodiscard]] bool 94 replaceFunction(Function &F, 95 llvm::function_ref<Error(CallInst *CI)> ReplaceCall) { 96 for (User *U : make_early_inc_range(F.users())) { 97 CallInst *CI = dyn_cast<CallInst>(U); 98 if (!CI) 99 continue; 100 101 if (Error E = ReplaceCall(CI)) { 102 std::string Message(toString(std::move(E))); 103 DiagnosticInfoUnsupported Diag(*CI->getFunction(), Message, 104 CI->getDebugLoc()); 105 M.getContext().diagnose(Diag); 106 return true; 107 } 108 } 109 if (F.user_empty()) 110 F.eraseFromParent(); 111 return false; 112 } 113 114 struct IntrinArgSelect { 115 enum class Type { 116 #define DXIL_OP_INTRINSIC_ARG_SELECT_TYPE(name) name, 117 #include "DXILOperation.inc" 118 }; 119 Type Type; 120 int Value; 121 }; 122 123 [[nodiscard]] bool 124 replaceFunctionWithOp(Function &F, dxil::OpCode DXILOp, 125 ArrayRef<IntrinArgSelect> ArgSelects) { 126 bool IsVectorArgExpansion = isVectorArgExpansion(F); 127 assert(!(IsVectorArgExpansion && ArgSelects.size()) && 128 "Cann't do vector arg expansion when using arg selects."); 129 return replaceFunction(F, [&](CallInst *CI) -> Error { 130 OpBuilder.getIRB().SetInsertPoint(CI); 131 SmallVector<Value *> Args; 132 if (ArgSelects.size()) { 133 for (const IntrinArgSelect &A : ArgSelects) { 134 switch (A.Type) { 135 case IntrinArgSelect::Type::Index: 136 Args.push_back(CI->getArgOperand(A.Value)); 137 break; 138 case IntrinArgSelect::Type::I8: 139 Args.push_back(OpBuilder.getIRB().getInt8((uint8_t)A.Value)); 140 break; 141 case IntrinArgSelect::Type::I32: 142 Args.push_back(OpBuilder.getIRB().getInt32(A.Value)); 143 break; 144 } 145 } 146 } else if (IsVectorArgExpansion) { 147 Args = argVectorFlatten(CI, OpBuilder.getIRB()); 148 } else { 149 Args.append(CI->arg_begin(), CI->arg_end()); 150 } 151 152 Expected<CallInst *> OpCall = 153 OpBuilder.tryCreateOp(DXILOp, Args, CI->getName(), F.getReturnType()); 154 if (Error E = OpCall.takeError()) 155 return E; 156 157 CI->replaceAllUsesWith(*OpCall); 158 CI->eraseFromParent(); 159 return Error::success(); 160 }); 161 } 162 163 [[nodiscard]] bool replaceFunctionWithNamedStructOp( 164 Function &F, dxil::OpCode DXILOp, Type *NewRetTy, 165 llvm::function_ref<Error(CallInst *CI, CallInst *Op)> ReplaceUses) { 166 bool IsVectorArgExpansion = isVectorArgExpansion(F); 167 return replaceFunction(F, [&](CallInst *CI) -> Error { 168 SmallVector<Value *> Args; 169 OpBuilder.getIRB().SetInsertPoint(CI); 170 if (IsVectorArgExpansion) { 171 SmallVector<Value *> NewArgs = argVectorFlatten(CI, OpBuilder.getIRB()); 172 Args.append(NewArgs.begin(), NewArgs.end()); 173 } else 174 Args.append(CI->arg_begin(), CI->arg_end()); 175 176 Expected<CallInst *> OpCall = 177 OpBuilder.tryCreateOp(DXILOp, Args, CI->getName(), NewRetTy); 178 if (Error E = OpCall.takeError()) 179 return E; 180 if (Error E = ReplaceUses(CI, *OpCall)) 181 return E; 182 183 return Error::success(); 184 }); 185 } 186 187 /// Create a cast between a `target("dx")` type and `dx.types.Handle`, which 188 /// is intended to be removed by the end of lowering. This is used to allow 189 /// lowering of ops which need to change their return or argument types in a 190 /// piecemeal way - we can add the casts in to avoid updating all of the uses 191 /// or defs, and by the end all of the casts will be redundant. 192 Value *createTmpHandleCast(Value *V, Type *Ty) { 193 CallInst *Cast = OpBuilder.getIRB().CreateIntrinsic( 194 Intrinsic::dx_resource_casthandle, {Ty, V->getType()}, {V}); 195 CleanupCasts.push_back(Cast); 196 return Cast; 197 } 198 199 void cleanupHandleCasts() { 200 SmallVector<CallInst *> ToRemove; 201 SmallVector<Function *> CastFns; 202 203 for (CallInst *Cast : CleanupCasts) { 204 // These casts were only put in to ease the move from `target("dx")` types 205 // to `dx.types.Handle in a piecemeal way. At this point, all of the 206 // non-cast uses should now be `dx.types.Handle`, and remaining casts 207 // should all form pairs to and from the now unused `target("dx")` type. 208 CastFns.push_back(Cast->getCalledFunction()); 209 210 // If the cast is not to `dx.types.Handle`, it should be the first part of 211 // the pair. Keep track so we can remove it once it has no more uses. 212 if (Cast->getType() != OpBuilder.getHandleType()) { 213 ToRemove.push_back(Cast); 214 continue; 215 } 216 // Otherwise, we're the second handle in a pair. Forward the arguments and 217 // remove the (second) cast. 218 CallInst *Def = cast<CallInst>(Cast->getOperand(0)); 219 assert(Def->getIntrinsicID() == Intrinsic::dx_resource_casthandle && 220 "Unbalanced pair of temporary handle casts"); 221 Cast->replaceAllUsesWith(Def->getOperand(0)); 222 Cast->eraseFromParent(); 223 } 224 for (CallInst *Cast : ToRemove) { 225 assert(Cast->user_empty() && "Temporary handle cast still has users"); 226 Cast->eraseFromParent(); 227 } 228 229 // Deduplicate the cast functions so that we only erase each one once. 230 llvm::sort(CastFns); 231 CastFns.erase(llvm::unique(CastFns), CastFns.end()); 232 for (Function *F : CastFns) 233 F->eraseFromParent(); 234 235 CleanupCasts.clear(); 236 } 237 238 // Remove the resource global associated with the handleFromBinding call 239 // instruction and their uses as they aren't needed anymore. 240 // TODO: We should verify that all the globals get removed. 241 // It's expected we'll need a custom pass in the future that will eliminate 242 // the need for this here. 243 void removeResourceGlobals(CallInst *CI) { 244 for (User *User : make_early_inc_range(CI->users())) { 245 if (StoreInst *Store = dyn_cast<StoreInst>(User)) { 246 Value *V = Store->getOperand(1); 247 Store->eraseFromParent(); 248 if (GlobalVariable *GV = dyn_cast<GlobalVariable>(V)) 249 if (GV->use_empty()) { 250 GV->removeDeadConstantUsers(); 251 GV->eraseFromParent(); 252 } 253 } 254 } 255 } 256 257 [[nodiscard]] bool lowerToCreateHandle(Function &F) { 258 IRBuilder<> &IRB = OpBuilder.getIRB(); 259 Type *Int8Ty = IRB.getInt8Ty(); 260 Type *Int32Ty = IRB.getInt32Ty(); 261 262 return replaceFunction(F, [&](CallInst *CI) -> Error { 263 IRB.SetInsertPoint(CI); 264 265 auto *It = DBM.find(CI); 266 assert(It != DBM.end() && "Resource not in map?"); 267 dxil::ResourceBindingInfo &RI = *It; 268 269 const auto &Binding = RI.getBinding(); 270 dxil::ResourceClass RC = DRTM[RI.getHandleTy()].getResourceClass(); 271 272 Value *IndexOp = CI->getArgOperand(3); 273 if (Binding.LowerBound != 0) 274 IndexOp = IRB.CreateAdd(IndexOp, 275 ConstantInt::get(Int32Ty, Binding.LowerBound)); 276 277 std::array<Value *, 4> Args{ 278 ConstantInt::get(Int8Ty, llvm::to_underlying(RC)), 279 ConstantInt::get(Int32Ty, Binding.RecordID), IndexOp, 280 CI->getArgOperand(4)}; 281 Expected<CallInst *> OpCall = 282 OpBuilder.tryCreateOp(OpCode::CreateHandle, Args, CI->getName()); 283 if (Error E = OpCall.takeError()) 284 return E; 285 286 Value *Cast = createTmpHandleCast(*OpCall, CI->getType()); 287 288 removeResourceGlobals(CI); 289 290 CI->replaceAllUsesWith(Cast); 291 CI->eraseFromParent(); 292 return Error::success(); 293 }); 294 } 295 296 [[nodiscard]] bool lowerToBindAndAnnotateHandle(Function &F) { 297 IRBuilder<> &IRB = OpBuilder.getIRB(); 298 Type *Int32Ty = IRB.getInt32Ty(); 299 300 return replaceFunction(F, [&](CallInst *CI) -> Error { 301 IRB.SetInsertPoint(CI); 302 303 auto *It = DBM.find(CI); 304 assert(It != DBM.end() && "Resource not in map?"); 305 dxil::ResourceBindingInfo &RI = *It; 306 307 const auto &Binding = RI.getBinding(); 308 dxil::ResourceTypeInfo &RTI = DRTM[RI.getHandleTy()]; 309 dxil::ResourceClass RC = RTI.getResourceClass(); 310 311 Value *IndexOp = CI->getArgOperand(3); 312 if (Binding.LowerBound != 0) 313 IndexOp = IRB.CreateAdd(IndexOp, 314 ConstantInt::get(Int32Ty, Binding.LowerBound)); 315 316 std::pair<uint32_t, uint32_t> Props = 317 RI.getAnnotateProps(*F.getParent(), RTI); 318 319 // For `CreateHandleFromBinding` we need the upper bound rather than the 320 // size, so we need to be careful about the difference for "unbounded". 321 uint32_t Unbounded = std::numeric_limits<uint32_t>::max(); 322 uint32_t UpperBound = Binding.Size == Unbounded 323 ? Unbounded 324 : Binding.LowerBound + Binding.Size - 1; 325 Constant *ResBind = OpBuilder.getResBind(Binding.LowerBound, UpperBound, 326 Binding.Space, RC); 327 std::array<Value *, 3> BindArgs{ResBind, IndexOp, CI->getArgOperand(4)}; 328 Expected<CallInst *> OpBind = OpBuilder.tryCreateOp( 329 OpCode::CreateHandleFromBinding, BindArgs, CI->getName()); 330 if (Error E = OpBind.takeError()) 331 return E; 332 333 std::array<Value *, 2> AnnotateArgs{ 334 *OpBind, OpBuilder.getResProps(Props.first, Props.second)}; 335 Expected<CallInst *> OpAnnotate = OpBuilder.tryCreateOp( 336 OpCode::AnnotateHandle, AnnotateArgs, 337 CI->hasName() ? CI->getName() + "_annot" : Twine()); 338 if (Error E = OpAnnotate.takeError()) 339 return E; 340 341 Value *Cast = createTmpHandleCast(*OpAnnotate, CI->getType()); 342 343 removeResourceGlobals(CI); 344 345 CI->replaceAllUsesWith(Cast); 346 CI->eraseFromParent(); 347 348 return Error::success(); 349 }); 350 } 351 352 /// Lower `dx.resource.handlefrombinding` intrinsics depending on the shader 353 /// model and taking into account binding information from 354 /// DXILResourceBindingAnalysis. 355 bool lowerHandleFromBinding(Function &F) { 356 Triple TT(Triple(M.getTargetTriple())); 357 if (TT.getDXILVersion() < VersionTuple(1, 6)) 358 return lowerToCreateHandle(F); 359 return lowerToBindAndAnnotateHandle(F); 360 } 361 362 Error replaceSplitDoubleCallUsages(CallInst *Intrin, CallInst *Op) { 363 for (Use &U : make_early_inc_range(Intrin->uses())) { 364 if (auto *EVI = dyn_cast<ExtractValueInst>(U.getUser())) { 365 366 if (EVI->getNumIndices() != 1) 367 return createStringError(std::errc::invalid_argument, 368 "Splitdouble has only 2 elements"); 369 EVI->setOperand(0, Op); 370 } else { 371 return make_error<StringError>( 372 "Splitdouble use is not ExtractValueInst", 373 inconvertibleErrorCode()); 374 } 375 } 376 377 Intrin->eraseFromParent(); 378 379 return Error::success(); 380 } 381 382 /// Replace uses of \c Intrin with the values in the `dx.ResRet` of \c Op. 383 /// Since we expect to be post-scalarization, make an effort to avoid vectors. 384 Error replaceResRetUses(CallInst *Intrin, CallInst *Op, bool HasCheckBit) { 385 IRBuilder<> &IRB = OpBuilder.getIRB(); 386 387 Instruction *OldResult = Intrin; 388 Type *OldTy = Intrin->getType(); 389 390 if (HasCheckBit) { 391 auto *ST = cast<StructType>(OldTy); 392 393 Value *CheckOp = nullptr; 394 Type *Int32Ty = IRB.getInt32Ty(); 395 for (Use &U : make_early_inc_range(OldResult->uses())) { 396 if (auto *EVI = dyn_cast<ExtractValueInst>(U.getUser())) { 397 ArrayRef<unsigned> Indices = EVI->getIndices(); 398 assert(Indices.size() == 1); 399 // We're only interested in uses of the check bit for now. 400 if (Indices[0] != 1) 401 continue; 402 if (!CheckOp) { 403 Value *NewEVI = IRB.CreateExtractValue(Op, 4); 404 Expected<CallInst *> OpCall = OpBuilder.tryCreateOp( 405 OpCode::CheckAccessFullyMapped, {NewEVI}, 406 OldResult->hasName() ? OldResult->getName() + "_check" 407 : Twine(), 408 Int32Ty); 409 if (Error E = OpCall.takeError()) 410 return E; 411 CheckOp = *OpCall; 412 } 413 EVI->replaceAllUsesWith(CheckOp); 414 EVI->eraseFromParent(); 415 } 416 } 417 418 if (OldResult->use_empty()) { 419 // Only the check bit was used, so we're done here. 420 OldResult->eraseFromParent(); 421 return Error::success(); 422 } 423 424 assert(OldResult->hasOneUse() && 425 isa<ExtractValueInst>(*OldResult->user_begin()) && 426 "Expected only use to be extract of first element"); 427 OldResult = cast<Instruction>(*OldResult->user_begin()); 428 OldTy = ST->getElementType(0); 429 } 430 431 // For scalars, we just extract the first element. 432 if (!isa<FixedVectorType>(OldTy)) { 433 Value *EVI = IRB.CreateExtractValue(Op, 0); 434 OldResult->replaceAllUsesWith(EVI); 435 OldResult->eraseFromParent(); 436 if (OldResult != Intrin) { 437 assert(Intrin->use_empty() && "Intrinsic still has uses?"); 438 Intrin->eraseFromParent(); 439 } 440 return Error::success(); 441 } 442 443 std::array<Value *, 4> Extracts = {}; 444 SmallVector<ExtractElementInst *> DynamicAccesses; 445 446 // The users of the operation should all be scalarized, so we attempt to 447 // replace the extractelements with extractvalues directly. 448 for (Use &U : make_early_inc_range(OldResult->uses())) { 449 if (auto *EEI = dyn_cast<ExtractElementInst>(U.getUser())) { 450 if (auto *IndexOp = dyn_cast<ConstantInt>(EEI->getIndexOperand())) { 451 size_t IndexVal = IndexOp->getZExtValue(); 452 assert(IndexVal < 4 && "Index into buffer load out of range"); 453 if (!Extracts[IndexVal]) 454 Extracts[IndexVal] = IRB.CreateExtractValue(Op, IndexVal); 455 EEI->replaceAllUsesWith(Extracts[IndexVal]); 456 EEI->eraseFromParent(); 457 } else { 458 DynamicAccesses.push_back(EEI); 459 } 460 } 461 } 462 463 const auto *VecTy = cast<FixedVectorType>(OldTy); 464 const unsigned N = VecTy->getNumElements(); 465 466 // If there's a dynamic access we need to round trip through stack memory so 467 // that we don't leave vectors around. 468 if (!DynamicAccesses.empty()) { 469 Type *Int32Ty = IRB.getInt32Ty(); 470 Constant *Zero = ConstantInt::get(Int32Ty, 0); 471 472 Type *ElTy = VecTy->getElementType(); 473 Type *ArrayTy = ArrayType::get(ElTy, N); 474 Value *Alloca = IRB.CreateAlloca(ArrayTy); 475 476 for (int I = 0, E = N; I != E; ++I) { 477 if (!Extracts[I]) 478 Extracts[I] = IRB.CreateExtractValue(Op, I); 479 Value *GEP = IRB.CreateInBoundsGEP( 480 ArrayTy, Alloca, {Zero, ConstantInt::get(Int32Ty, I)}); 481 IRB.CreateStore(Extracts[I], GEP); 482 } 483 484 for (ExtractElementInst *EEI : DynamicAccesses) { 485 Value *GEP = IRB.CreateInBoundsGEP(ArrayTy, Alloca, 486 {Zero, EEI->getIndexOperand()}); 487 Value *Load = IRB.CreateLoad(ElTy, GEP); 488 EEI->replaceAllUsesWith(Load); 489 EEI->eraseFromParent(); 490 } 491 } 492 493 // If we still have uses, then we're not fully scalarized and need to 494 // recreate the vector. This should only happen for things like exported 495 // functions from libraries. 496 if (!OldResult->use_empty()) { 497 for (int I = 0, E = N; I != E; ++I) 498 if (!Extracts[I]) 499 Extracts[I] = IRB.CreateExtractValue(Op, I); 500 501 Value *Vec = UndefValue::get(OldTy); 502 for (int I = 0, E = N; I != E; ++I) 503 Vec = IRB.CreateInsertElement(Vec, Extracts[I], I); 504 OldResult->replaceAllUsesWith(Vec); 505 } 506 507 OldResult->eraseFromParent(); 508 if (OldResult != Intrin) { 509 assert(Intrin->use_empty() && "Intrinsic still has uses?"); 510 Intrin->eraseFromParent(); 511 } 512 513 return Error::success(); 514 } 515 516 [[nodiscard]] bool lowerTypedBufferLoad(Function &F, bool HasCheckBit) { 517 IRBuilder<> &IRB = OpBuilder.getIRB(); 518 Type *Int32Ty = IRB.getInt32Ty(); 519 520 return replaceFunction(F, [&](CallInst *CI) -> Error { 521 IRB.SetInsertPoint(CI); 522 523 Value *Handle = 524 createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType()); 525 Value *Index0 = CI->getArgOperand(1); 526 Value *Index1 = UndefValue::get(Int32Ty); 527 528 Type *OldTy = CI->getType(); 529 if (HasCheckBit) 530 OldTy = cast<StructType>(OldTy)->getElementType(0); 531 Type *NewRetTy = OpBuilder.getResRetType(OldTy->getScalarType()); 532 533 std::array<Value *, 3> Args{Handle, Index0, Index1}; 534 Expected<CallInst *> OpCall = OpBuilder.tryCreateOp( 535 OpCode::BufferLoad, Args, CI->getName(), NewRetTy); 536 if (Error E = OpCall.takeError()) 537 return E; 538 if (Error E = replaceResRetUses(CI, *OpCall, HasCheckBit)) 539 return E; 540 541 return Error::success(); 542 }); 543 } 544 545 [[nodiscard]] bool lowerRawBufferLoad(Function &F) { 546 Triple TT(Triple(M.getTargetTriple())); 547 VersionTuple DXILVersion = TT.getDXILVersion(); 548 const DataLayout &DL = F.getDataLayout(); 549 IRBuilder<> &IRB = OpBuilder.getIRB(); 550 Type *Int8Ty = IRB.getInt8Ty(); 551 Type *Int32Ty = IRB.getInt32Ty(); 552 553 return replaceFunction(F, [&](CallInst *CI) -> Error { 554 IRB.SetInsertPoint(CI); 555 556 Type *OldTy = cast<StructType>(CI->getType())->getElementType(0); 557 Type *ScalarTy = OldTy->getScalarType(); 558 Type *NewRetTy = OpBuilder.getResRetType(ScalarTy); 559 560 Value *Handle = 561 createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType()); 562 Value *Index0 = CI->getArgOperand(1); 563 Value *Index1 = CI->getArgOperand(2); 564 uint64_t NumElements = 565 DL.getTypeSizeInBits(OldTy) / DL.getTypeSizeInBits(ScalarTy); 566 Value *Mask = ConstantInt::get(Int8Ty, ~(~0U << NumElements)); 567 Value *Align = 568 ConstantInt::get(Int32Ty, DL.getPrefTypeAlign(ScalarTy).value()); 569 570 Expected<CallInst *> OpCall = 571 DXILVersion >= VersionTuple(1, 2) 572 ? OpBuilder.tryCreateOp(OpCode::RawBufferLoad, 573 {Handle, Index0, Index1, Mask, Align}, 574 CI->getName(), NewRetTy) 575 : OpBuilder.tryCreateOp(OpCode::BufferLoad, 576 {Handle, Index0, Index1}, CI->getName(), 577 NewRetTy); 578 if (Error E = OpCall.takeError()) 579 return E; 580 if (Error E = replaceResRetUses(CI, *OpCall, /*HasCheckBit=*/true)) 581 return E; 582 583 return Error::success(); 584 }); 585 } 586 587 [[nodiscard]] bool lowerUpdateCounter(Function &F) { 588 IRBuilder<> &IRB = OpBuilder.getIRB(); 589 Type *Int32Ty = IRB.getInt32Ty(); 590 591 return replaceFunction(F, [&](CallInst *CI) -> Error { 592 IRB.SetInsertPoint(CI); 593 Value *Handle = 594 createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType()); 595 Value *Op1 = CI->getArgOperand(1); 596 597 std::array<Value *, 2> Args{Handle, Op1}; 598 599 Expected<CallInst *> OpCall = OpBuilder.tryCreateOp( 600 OpCode::UpdateCounter, Args, CI->getName(), Int32Ty); 601 602 if (Error E = OpCall.takeError()) 603 return E; 604 605 CI->replaceAllUsesWith(*OpCall); 606 CI->eraseFromParent(); 607 return Error::success(); 608 }); 609 } 610 611 [[nodiscard]] bool lowerGetPointer(Function &F) { 612 // These should have already been handled in DXILResourceAccess, so we can 613 // just clean up the dead prototype. 614 assert(F.user_empty() && "getpointer operations should have been removed"); 615 F.eraseFromParent(); 616 return false; 617 } 618 619 [[nodiscard]] bool lowerBufferStore(Function &F, bool IsRaw) { 620 Triple TT(Triple(M.getTargetTriple())); 621 VersionTuple DXILVersion = TT.getDXILVersion(); 622 const DataLayout &DL = F.getDataLayout(); 623 IRBuilder<> &IRB = OpBuilder.getIRB(); 624 Type *Int8Ty = IRB.getInt8Ty(); 625 Type *Int32Ty = IRB.getInt32Ty(); 626 627 return replaceFunction(F, [&](CallInst *CI) -> Error { 628 IRB.SetInsertPoint(CI); 629 630 Value *Handle = 631 createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType()); 632 Value *Index0 = CI->getArgOperand(1); 633 Value *Index1 = IsRaw ? CI->getArgOperand(2) : UndefValue::get(Int32Ty); 634 635 Value *Data = CI->getArgOperand(IsRaw ? 3 : 2); 636 Type *DataTy = Data->getType(); 637 Type *ScalarTy = DataTy->getScalarType(); 638 639 uint64_t NumElements = 640 DL.getTypeSizeInBits(DataTy) / DL.getTypeSizeInBits(ScalarTy); 641 Value *Mask = ConstantInt::get(Int8Ty, ~(~0U << NumElements)); 642 643 // TODO: check that we only have vector or scalar... 644 if (!IsRaw && NumElements != 4) 645 return make_error<StringError>( 646 "typedBufferStore data must be a vector of 4 elements", 647 inconvertibleErrorCode()); 648 else if (NumElements > 4) 649 return make_error<StringError>( 650 "rawBufferStore data must have at most 4 elements", 651 inconvertibleErrorCode()); 652 653 std::array<Value *, 4> DataElements{nullptr, nullptr, nullptr, nullptr}; 654 if (DataTy == ScalarTy) 655 DataElements[0] = Data; 656 else { 657 // Since we're post-scalarizer, if we see a vector here it's likely 658 // constructed solely for the argument of the store. Just use the scalar 659 // values from before they're inserted into the temporary. 660 auto *IEI = dyn_cast<InsertElementInst>(Data); 661 while (IEI) { 662 auto *IndexOp = dyn_cast<ConstantInt>(IEI->getOperand(2)); 663 if (!IndexOp) 664 break; 665 size_t IndexVal = IndexOp->getZExtValue(); 666 assert(IndexVal < 4 && "Too many elements for buffer store"); 667 DataElements[IndexVal] = IEI->getOperand(1); 668 IEI = dyn_cast<InsertElementInst>(IEI->getOperand(0)); 669 } 670 } 671 672 // If for some reason we weren't able to forward the arguments from the 673 // scalarizer artifact, then we may need to actually extract elements from 674 // the vector. 675 for (int I = 0, E = NumElements; I < E; ++I) 676 if (DataElements[I] == nullptr) 677 DataElements[I] = 678 IRB.CreateExtractElement(Data, ConstantInt::get(Int32Ty, I)); 679 // For any elements beyond the length of the vector, fill up with undef. 680 for (int I = NumElements, E = 4; I < E; ++I) 681 if (DataElements[I] == nullptr) 682 DataElements[I] = UndefValue::get(ScalarTy); 683 684 dxil::OpCode Op = OpCode::BufferStore; 685 SmallVector<Value *, 9> Args{ 686 Handle, Index0, Index1, DataElements[0], 687 DataElements[1], DataElements[2], DataElements[3], Mask}; 688 if (IsRaw && DXILVersion >= VersionTuple(1, 2)) { 689 Op = OpCode::RawBufferStore; 690 // RawBufferStore requires the alignment 691 Args.push_back( 692 ConstantInt::get(Int32Ty, DL.getPrefTypeAlign(ScalarTy).value())); 693 } 694 Expected<CallInst *> OpCall = 695 OpBuilder.tryCreateOp(Op, Args, CI->getName()); 696 if (Error E = OpCall.takeError()) 697 return E; 698 699 CI->eraseFromParent(); 700 // Clean up any leftover `insertelement`s 701 auto *IEI = dyn_cast<InsertElementInst>(Data); 702 while (IEI && IEI->use_empty()) { 703 InsertElementInst *Tmp = IEI; 704 IEI = dyn_cast<InsertElementInst>(IEI->getOperand(0)); 705 Tmp->eraseFromParent(); 706 } 707 708 return Error::success(); 709 }); 710 } 711 712 [[nodiscard]] bool lowerCtpopToCountBits(Function &F) { 713 IRBuilder<> &IRB = OpBuilder.getIRB(); 714 Type *Int32Ty = IRB.getInt32Ty(); 715 716 return replaceFunction(F, [&](CallInst *CI) -> Error { 717 IRB.SetInsertPoint(CI); 718 SmallVector<Value *> Args; 719 Args.append(CI->arg_begin(), CI->arg_end()); 720 721 Type *RetTy = Int32Ty; 722 Type *FRT = F.getReturnType(); 723 if (const auto *VT = dyn_cast<VectorType>(FRT)) 724 RetTy = VectorType::get(RetTy, VT); 725 726 Expected<CallInst *> OpCall = OpBuilder.tryCreateOp( 727 dxil::OpCode::CountBits, Args, CI->getName(), RetTy); 728 if (Error E = OpCall.takeError()) 729 return E; 730 731 // If the result type is 32 bits we can do a direct replacement. 732 if (FRT->isIntOrIntVectorTy(32)) { 733 CI->replaceAllUsesWith(*OpCall); 734 CI->eraseFromParent(); 735 return Error::success(); 736 } 737 738 unsigned CastOp; 739 unsigned CastOp2; 740 if (FRT->isIntOrIntVectorTy(16)) { 741 CastOp = Instruction::ZExt; 742 CastOp2 = Instruction::SExt; 743 } else { // must be 64 bits 744 assert(FRT->isIntOrIntVectorTy(64) && 745 "Currently only lowering 16, 32, or 64 bit ctpop to CountBits \ 746 is supported."); 747 CastOp = Instruction::Trunc; 748 CastOp2 = Instruction::Trunc; 749 } 750 751 // It is correct to replace the ctpop with the dxil op and 752 // remove all casts to i32 753 bool NeedsCast = false; 754 for (User *User : make_early_inc_range(CI->users())) { 755 Instruction *I = dyn_cast<Instruction>(User); 756 if (I && (I->getOpcode() == CastOp || I->getOpcode() == CastOp2) && 757 I->getType() == RetTy) { 758 I->replaceAllUsesWith(*OpCall); 759 I->eraseFromParent(); 760 } else 761 NeedsCast = true; 762 } 763 764 // It is correct to replace a ctpop with the dxil op and 765 // a cast from i32 to the return type of the ctpop 766 // the cast is emitted here if there is a non-cast to i32 767 // instr which uses the ctpop 768 if (NeedsCast) { 769 Value *Cast = 770 IRB.CreateZExtOrTrunc(*OpCall, F.getReturnType(), "ctpop.cast"); 771 CI->replaceAllUsesWith(Cast); 772 } 773 774 CI->eraseFromParent(); 775 return Error::success(); 776 }); 777 } 778 779 bool lowerIntrinsics() { 780 bool Updated = false; 781 bool HasErrors = false; 782 783 for (Function &F : make_early_inc_range(M.functions())) { 784 if (!F.isDeclaration()) 785 continue; 786 Intrinsic::ID ID = F.getIntrinsicID(); 787 switch (ID) { 788 default: 789 continue; 790 #define DXIL_OP_INTRINSIC(OpCode, Intrin, ...) \ 791 case Intrin: \ 792 HasErrors |= replaceFunctionWithOp( \ 793 F, OpCode, ArrayRef<IntrinArgSelect>{__VA_ARGS__}); \ 794 break; 795 #include "DXILOperation.inc" 796 case Intrinsic::dx_resource_handlefrombinding: 797 HasErrors |= lowerHandleFromBinding(F); 798 break; 799 case Intrinsic::dx_resource_getpointer: 800 HasErrors |= lowerGetPointer(F); 801 break; 802 case Intrinsic::dx_resource_load_typedbuffer: 803 HasErrors |= lowerTypedBufferLoad(F, /*HasCheckBit=*/true); 804 break; 805 case Intrinsic::dx_resource_store_typedbuffer: 806 HasErrors |= lowerBufferStore(F, /*IsRaw=*/false); 807 break; 808 case Intrinsic::dx_resource_load_rawbuffer: 809 HasErrors |= lowerRawBufferLoad(F); 810 break; 811 case Intrinsic::dx_resource_store_rawbuffer: 812 HasErrors |= lowerBufferStore(F, /*IsRaw=*/true); 813 break; 814 case Intrinsic::dx_resource_updatecounter: 815 HasErrors |= lowerUpdateCounter(F); 816 break; 817 // TODO: this can be removed when 818 // https://github.com/llvm/llvm-project/issues/113192 is fixed 819 case Intrinsic::dx_splitdouble: 820 HasErrors |= replaceFunctionWithNamedStructOp( 821 F, OpCode::SplitDouble, 822 OpBuilder.getSplitDoubleType(M.getContext()), 823 [&](CallInst *CI, CallInst *Op) { 824 return replaceSplitDoubleCallUsages(CI, Op); 825 }); 826 break; 827 case Intrinsic::ctpop: 828 HasErrors |= lowerCtpopToCountBits(F); 829 break; 830 } 831 Updated = true; 832 } 833 if (Updated && !HasErrors) 834 cleanupHandleCasts(); 835 836 return Updated; 837 } 838 }; 839 } // namespace 840 841 PreservedAnalyses DXILOpLowering::run(Module &M, ModuleAnalysisManager &MAM) { 842 DXILBindingMap &DBM = MAM.getResult<DXILResourceBindingAnalysis>(M); 843 DXILResourceTypeMap &DRTM = MAM.getResult<DXILResourceTypeAnalysis>(M); 844 845 bool MadeChanges = OpLowerer(M, DBM, DRTM).lowerIntrinsics(); 846 if (!MadeChanges) 847 return PreservedAnalyses::all(); 848 PreservedAnalyses PA; 849 PA.preserve<DXILResourceBindingAnalysis>(); 850 PA.preserve<DXILMetadataAnalysis>(); 851 PA.preserve<ShaderFlagsAnalysis>(); 852 return PA; 853 } 854 855 namespace { 856 class DXILOpLoweringLegacy : public ModulePass { 857 public: 858 bool runOnModule(Module &M) override { 859 DXILBindingMap &DBM = 860 getAnalysis<DXILResourceBindingWrapperPass>().getBindingMap(); 861 DXILResourceTypeMap &DRTM = 862 getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap(); 863 864 return OpLowerer(M, DBM, DRTM).lowerIntrinsics(); 865 } 866 StringRef getPassName() const override { return "DXIL Op Lowering"; } 867 DXILOpLoweringLegacy() : ModulePass(ID) {} 868 869 static char ID; // Pass identification. 870 void getAnalysisUsage(llvm::AnalysisUsage &AU) const override { 871 AU.addRequired<DXILResourceTypeWrapperPass>(); 872 AU.addRequired<DXILResourceBindingWrapperPass>(); 873 AU.addPreserved<DXILResourceBindingWrapperPass>(); 874 AU.addPreserved<DXILResourceMDWrapper>(); 875 AU.addPreserved<DXILMetadataAnalysisWrapperPass>(); 876 AU.addPreserved<ShaderFlagsAnalysisWrapper>(); 877 } 878 }; 879 char DXILOpLoweringLegacy::ID = 0; 880 } // end anonymous namespace 881 882 INITIALIZE_PASS_BEGIN(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering", 883 false, false) 884 INITIALIZE_PASS_DEPENDENCY(DXILResourceTypeWrapperPass) 885 INITIALIZE_PASS_DEPENDENCY(DXILResourceBindingWrapperPass) 886 INITIALIZE_PASS_END(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering", false, 887 false) 888 889 ModulePass *llvm::createDXILOpLoweringLegacyPass() { 890 return new DXILOpLoweringLegacy(); 891 } 892