181ee3855SJustin Bogner //===- DXILOpLowering.cpp - Lowering to DXIL operations -------------------===// 285285be9SXiang Li // 385285be9SXiang Li // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 485285be9SXiang Li // See https://llvm.org/LICENSE.txt for license information. 585285be9SXiang Li // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 685285be9SXiang Li // 785285be9SXiang Li //===----------------------------------------------------------------------===// 885285be9SXiang Li 981ee3855SJustin Bogner #include "DXILOpLowering.h" 1085285be9SXiang Li #include "DXILConstants.h" 11de1a97dbSFarzon Lotfi #include "DXILIntrinsicExpansion.h" 1257006b14SXiang Li #include "DXILOpBuilder.h" 13bfd05102SJustin Bogner #include "DXILResourceAnalysis.h" 14bfd05102SJustin Bogner #include "DXILShaderFlags.h" 1585285be9SXiang Li #include "DirectX.h" 1685285be9SXiang Li #include "llvm/ADT/SmallVector.h" 17bfd05102SJustin Bogner #include "llvm/Analysis/DXILMetadataAnalysis.h" 18aa61925eSJustin Bogner #include "llvm/Analysis/DXILResource.h" 1985285be9SXiang Li #include "llvm/CodeGen/Passes.h" 208cf85653SJustin Bogner #include "llvm/IR/DiagnosticInfo.h" 2185285be9SXiang Li #include "llvm/IR/IRBuilder.h" 2285285be9SXiang Li #include "llvm/IR/Instruction.h" 23481bce01Sjoaosaffran #include "llvm/IR/Instructions.h" 2485285be9SXiang Li #include "llvm/IR/Intrinsics.h" 2543dc3190SXiang Li #include "llvm/IR/IntrinsicsDirectX.h" 2685285be9SXiang Li #include "llvm/IR/Module.h" 2785285be9SXiang Li #include "llvm/IR/PassManager.h" 28aa61925eSJustin Bogner #include "llvm/InitializePasses.h" 2985285be9SXiang Li #include "llvm/Pass.h" 3085285be9SXiang Li #include "llvm/Support/ErrorHandling.h" 3185285be9SXiang Li 3285285be9SXiang Li #define DEBUG_TYPE "dxil-op-lower" 3385285be9SXiang Li 3485285be9SXiang Li using namespace llvm; 35e77c40ffSChris Bieneman using namespace llvm::dxil; 3685285be9SXiang Li 37060df78cSFarzon Lotfi static bool isVectorArgExpansion(Function &F) { 38060df78cSFarzon Lotfi switch (F.getIntrinsicID()) { 39060df78cSFarzon Lotfi case Intrinsic::dx_dot2: 40060df78cSFarzon Lotfi case Intrinsic::dx_dot3: 41060df78cSFarzon Lotfi case Intrinsic::dx_dot4: 42060df78cSFarzon Lotfi return true; 43060df78cSFarzon Lotfi } 44060df78cSFarzon Lotfi return false; 45060df78cSFarzon Lotfi } 46060df78cSFarzon Lotfi 47060df78cSFarzon Lotfi static SmallVector<Value *> populateOperands(Value *Arg, IRBuilder<> &Builder) { 48ba8a2adeSJessica Clarke SmallVector<Value *> ExtractedElements; 49060df78cSFarzon Lotfi auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType()); 50060df78cSFarzon Lotfi for (unsigned I = 0; I < VecArg->getNumElements(); ++I) { 51060df78cSFarzon Lotfi Value *Index = ConstantInt::get(Type::getInt32Ty(Arg->getContext()), I); 52060df78cSFarzon Lotfi Value *ExtractedElement = Builder.CreateExtractElement(Arg, Index); 53060df78cSFarzon Lotfi ExtractedElements.push_back(ExtractedElement); 54060df78cSFarzon Lotfi } 55060df78cSFarzon Lotfi return ExtractedElements; 56060df78cSFarzon Lotfi } 57060df78cSFarzon Lotfi 58060df78cSFarzon Lotfi static SmallVector<Value *> argVectorFlatten(CallInst *Orig, 59060df78cSFarzon Lotfi IRBuilder<> &Builder) { 60060df78cSFarzon Lotfi // Note: arg[NumOperands-1] is a pointer and is not needed by our flattening. 61060df78cSFarzon Lotfi unsigned NumOperands = Orig->getNumOperands() - 1; 62060df78cSFarzon Lotfi assert(NumOperands > 0); 63060df78cSFarzon Lotfi Value *Arg0 = Orig->getOperand(0); 64060df78cSFarzon Lotfi [[maybe_unused]] auto *VecArg0 = dyn_cast<FixedVectorType>(Arg0->getType()); 65060df78cSFarzon Lotfi assert(VecArg0); 66060df78cSFarzon Lotfi SmallVector<Value *> NewOperands = populateOperands(Arg0, Builder); 67060df78cSFarzon Lotfi for (unsigned I = 1; I < NumOperands; ++I) { 68060df78cSFarzon Lotfi Value *Arg = Orig->getOperand(I); 69060df78cSFarzon Lotfi [[maybe_unused]] auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType()); 70060df78cSFarzon Lotfi assert(VecArg); 71060df78cSFarzon Lotfi assert(VecArg0->getElementType() == VecArg->getElementType()); 72060df78cSFarzon Lotfi assert(VecArg0->getNumElements() == VecArg->getNumElements()); 73060df78cSFarzon Lotfi auto NextOperandList = populateOperands(Arg, Builder); 74060df78cSFarzon Lotfi NewOperands.append(NextOperandList.begin(), NextOperandList.end()); 75060df78cSFarzon Lotfi } 76060df78cSFarzon Lotfi return NewOperands; 77060df78cSFarzon Lotfi } 78060df78cSFarzon Lotfi 79e56ad22bSJustin Bogner namespace { 80e56ad22bSJustin Bogner class OpLowerer { 81e56ad22bSJustin Bogner Module &M; 82e56ad22bSJustin Bogner DXILOpBuilder OpBuilder; 833eca15cbSJustin Bogner DXILBindingMap &DBM; 843eca15cbSJustin Bogner DXILResourceTypeMap &DRTM; 85aa61925eSJustin Bogner SmallVector<CallInst *> CleanupCasts; 86e56ad22bSJustin Bogner 87e56ad22bSJustin Bogner public: 883eca15cbSJustin Bogner OpLowerer(Module &M, DXILBindingMap &DBM, DXILResourceTypeMap &DRTM) 893eca15cbSJustin Bogner : M(M), OpBuilder(M), DBM(DBM), DRTM(DRTM) {} 90e56ad22bSJustin Bogner 9190e84113SJustin Bogner /// Replace every call to \c F using \c ReplaceCall, and then erase \c F. If 9290e84113SJustin Bogner /// there is an error replacing a call, we emit a diagnostic and return true. 9390e84113SJustin Bogner [[nodiscard]] bool 9490e84113SJustin Bogner replaceFunction(Function &F, 95e56ad22bSJustin Bogner llvm::function_ref<Error(CallInst *CI)> ReplaceCall) { 9685285be9SXiang Li for (User *U : make_early_inc_range(F.users())) { 9785285be9SXiang Li CallInst *CI = dyn_cast<CallInst>(U); 9885285be9SXiang Li if (!CI) 9985285be9SXiang Li continue; 10085285be9SXiang Li 101e56ad22bSJustin Bogner if (Error E = ReplaceCall(CI)) { 1028cf85653SJustin Bogner std::string Message(toString(std::move(E))); 1038cf85653SJustin Bogner DiagnosticInfoUnsupported Diag(*CI->getFunction(), Message, 1048cf85653SJustin Bogner CI->getDebugLoc()); 1058cf85653SJustin Bogner M.getContext().diagnose(Diag); 10690e84113SJustin Bogner return true; 1078cf85653SJustin Bogner } 10885285be9SXiang Li } 10985285be9SXiang Li if (F.user_empty()) 11085285be9SXiang Li F.eraseFromParent(); 11190e84113SJustin Bogner return false; 11285285be9SXiang Li } 11385285be9SXiang Li 1140a44b24dSAdam Yang struct IntrinArgSelect { 1150a44b24dSAdam Yang enum class Type { 1160a44b24dSAdam Yang #define DXIL_OP_INTRINSIC_ARG_SELECT_TYPE(name) name, 1170a44b24dSAdam Yang #include "DXILOperation.inc" 1180a44b24dSAdam Yang }; 1190a44b24dSAdam Yang Type Type; 1200a44b24dSAdam Yang int Value; 1210a44b24dSAdam Yang }; 1220a44b24dSAdam Yang 1230a44b24dSAdam Yang [[nodiscard]] bool 1240a44b24dSAdam Yang replaceFunctionWithOp(Function &F, dxil::OpCode DXILOp, 1250a44b24dSAdam Yang ArrayRef<IntrinArgSelect> ArgSelects) { 126e56ad22bSJustin Bogner bool IsVectorArgExpansion = isVectorArgExpansion(F); 1270a44b24dSAdam Yang assert(!(IsVectorArgExpansion && ArgSelects.size()) && 1280a44b24dSAdam Yang "Cann't do vector arg expansion when using arg selects."); 12990e84113SJustin Bogner return replaceFunction(F, [&](CallInst *CI) -> Error { 130948249d8SAdam Yang OpBuilder.getIRB().SetInsertPoint(CI); 1310a44b24dSAdam Yang SmallVector<Value *> Args; 1320a44b24dSAdam Yang if (ArgSelects.size()) { 1330a44b24dSAdam Yang for (const IntrinArgSelect &A : ArgSelects) { 1340a44b24dSAdam Yang switch (A.Type) { 1350a44b24dSAdam Yang case IntrinArgSelect::Type::Index: 1360a44b24dSAdam Yang Args.push_back(CI->getArgOperand(A.Value)); 1370a44b24dSAdam Yang break; 1380a44b24dSAdam Yang case IntrinArgSelect::Type::I8: 1390a44b24dSAdam Yang Args.push_back(OpBuilder.getIRB().getInt8((uint8_t)A.Value)); 1400a44b24dSAdam Yang break; 1410a44b24dSAdam Yang case IntrinArgSelect::Type::I32: 1420a44b24dSAdam Yang Args.push_back(OpBuilder.getIRB().getInt32(A.Value)); 1430a44b24dSAdam Yang break; 1440a44b24dSAdam Yang } 1450a44b24dSAdam Yang } 1460a44b24dSAdam Yang } else if (IsVectorArgExpansion) { 1470a44b24dSAdam Yang Args = argVectorFlatten(CI, OpBuilder.getIRB()); 1480a44b24dSAdam Yang } else { 149e56ad22bSJustin Bogner Args.append(CI->arg_begin(), CI->arg_end()); 1500a44b24dSAdam Yang } 151e56ad22bSJustin Bogner 152e56ad22bSJustin Bogner Expected<CallInst *> OpCall = 1533d129016SJustin Bogner OpBuilder.tryCreateOp(DXILOp, Args, CI->getName(), F.getReturnType()); 154e56ad22bSJustin Bogner if (Error E = OpCall.takeError()) 155e56ad22bSJustin Bogner return E; 156e56ad22bSJustin Bogner 157e56ad22bSJustin Bogner CI->replaceAllUsesWith(*OpCall); 158e56ad22bSJustin Bogner CI->eraseFromParent(); 159e56ad22bSJustin Bogner return Error::success(); 160e56ad22bSJustin Bogner }); 161e56ad22bSJustin Bogner } 162e56ad22bSJustin Bogner 163481bce01Sjoaosaffran [[nodiscard]] bool replaceFunctionWithNamedStructOp( 164481bce01Sjoaosaffran Function &F, dxil::OpCode DXILOp, Type *NewRetTy, 165481bce01Sjoaosaffran llvm::function_ref<Error(CallInst *CI, CallInst *Op)> ReplaceUses) { 166481bce01Sjoaosaffran bool IsVectorArgExpansion = isVectorArgExpansion(F); 167481bce01Sjoaosaffran return replaceFunction(F, [&](CallInst *CI) -> Error { 168481bce01Sjoaosaffran SmallVector<Value *> Args; 169481bce01Sjoaosaffran OpBuilder.getIRB().SetInsertPoint(CI); 170481bce01Sjoaosaffran if (IsVectorArgExpansion) { 171481bce01Sjoaosaffran SmallVector<Value *> NewArgs = argVectorFlatten(CI, OpBuilder.getIRB()); 172481bce01Sjoaosaffran Args.append(NewArgs.begin(), NewArgs.end()); 173481bce01Sjoaosaffran } else 174481bce01Sjoaosaffran Args.append(CI->arg_begin(), CI->arg_end()); 175481bce01Sjoaosaffran 176481bce01Sjoaosaffran Expected<CallInst *> OpCall = 177481bce01Sjoaosaffran OpBuilder.tryCreateOp(DXILOp, Args, CI->getName(), NewRetTy); 178481bce01Sjoaosaffran if (Error E = OpCall.takeError()) 179481bce01Sjoaosaffran return E; 180481bce01Sjoaosaffran if (Error E = ReplaceUses(CI, *OpCall)) 181481bce01Sjoaosaffran return E; 182481bce01Sjoaosaffran 183481bce01Sjoaosaffran return Error::success(); 184481bce01Sjoaosaffran }); 185481bce01Sjoaosaffran } 186481bce01Sjoaosaffran 187aa61925eSJustin Bogner /// Create a cast between a `target("dx")` type and `dx.types.Handle`, which 188aa61925eSJustin Bogner /// is intended to be removed by the end of lowering. This is used to allow 189aa61925eSJustin Bogner /// lowering of ops which need to change their return or argument types in a 190aa61925eSJustin Bogner /// piecemeal way - we can add the casts in to avoid updating all of the uses 191aa61925eSJustin Bogner /// or defs, and by the end all of the casts will be redundant. 192aa61925eSJustin Bogner Value *createTmpHandleCast(Value *V, Type *Ty) { 19385c17e40SJay Foad CallInst *Cast = OpBuilder.getIRB().CreateIntrinsic( 194aa07f922SJustin Bogner Intrinsic::dx_resource_casthandle, {Ty, V->getType()}, {V}); 195aa61925eSJustin Bogner CleanupCasts.push_back(Cast); 196aa61925eSJustin Bogner return Cast; 197aa61925eSJustin Bogner } 198aa61925eSJustin Bogner 199aa61925eSJustin Bogner void cleanupHandleCasts() { 200aa61925eSJustin Bogner SmallVector<CallInst *> ToRemove; 201aa61925eSJustin Bogner SmallVector<Function *> CastFns; 202aa61925eSJustin Bogner 203aa61925eSJustin Bogner for (CallInst *Cast : CleanupCasts) { 204aa61925eSJustin Bogner // These casts were only put in to ease the move from `target("dx")` types 205aa61925eSJustin Bogner // to `dx.types.Handle in a piecemeal way. At this point, all of the 206aa61925eSJustin Bogner // non-cast uses should now be `dx.types.Handle`, and remaining casts 207aa61925eSJustin Bogner // should all form pairs to and from the now unused `target("dx")` type. 208aa61925eSJustin Bogner CastFns.push_back(Cast->getCalledFunction()); 209aa61925eSJustin Bogner 210aa61925eSJustin Bogner // If the cast is not to `dx.types.Handle`, it should be the first part of 211aa61925eSJustin Bogner // the pair. Keep track so we can remove it once it has no more uses. 212aa61925eSJustin Bogner if (Cast->getType() != OpBuilder.getHandleType()) { 213aa61925eSJustin Bogner ToRemove.push_back(Cast); 214aa61925eSJustin Bogner continue; 215aa61925eSJustin Bogner } 216aa61925eSJustin Bogner // Otherwise, we're the second handle in a pair. Forward the arguments and 217aa61925eSJustin Bogner // remove the (second) cast. 218aa61925eSJustin Bogner CallInst *Def = cast<CallInst>(Cast->getOperand(0)); 219aa07f922SJustin Bogner assert(Def->getIntrinsicID() == Intrinsic::dx_resource_casthandle && 220aa61925eSJustin Bogner "Unbalanced pair of temporary handle casts"); 221aa61925eSJustin Bogner Cast->replaceAllUsesWith(Def->getOperand(0)); 222aa61925eSJustin Bogner Cast->eraseFromParent(); 223aa61925eSJustin Bogner } 224aa61925eSJustin Bogner for (CallInst *Cast : ToRemove) { 225aa61925eSJustin Bogner assert(Cast->user_empty() && "Temporary handle cast still has users"); 226aa61925eSJustin Bogner Cast->eraseFromParent(); 227aa61925eSJustin Bogner } 228aa61925eSJustin Bogner 229aa61925eSJustin Bogner // Deduplicate the cast functions so that we only erase each one once. 230aa61925eSJustin Bogner llvm::sort(CastFns); 231aa61925eSJustin Bogner CastFns.erase(llvm::unique(CastFns), CastFns.end()); 232aa61925eSJustin Bogner for (Function *F : CastFns) 233aa61925eSJustin Bogner F->eraseFromParent(); 234aa61925eSJustin Bogner 235aa61925eSJustin Bogner CleanupCasts.clear(); 236aa61925eSJustin Bogner } 237aa61925eSJustin Bogner 23847ef3a09SGreg Roth // Remove the resource global associated with the handleFromBinding call 23947ef3a09SGreg Roth // instruction and their uses as they aren't needed anymore. 24047ef3a09SGreg Roth // TODO: We should verify that all the globals get removed. 24147ef3a09SGreg Roth // It's expected we'll need a custom pass in the future that will eliminate 24247ef3a09SGreg Roth // the need for this here. 24347ef3a09SGreg Roth void removeResourceGlobals(CallInst *CI) { 24447ef3a09SGreg Roth for (User *User : make_early_inc_range(CI->users())) { 24547ef3a09SGreg Roth if (StoreInst *Store = dyn_cast<StoreInst>(User)) { 24647ef3a09SGreg Roth Value *V = Store->getOperand(1); 24747ef3a09SGreg Roth Store->eraseFromParent(); 24847ef3a09SGreg Roth if (GlobalVariable *GV = dyn_cast<GlobalVariable>(V)) 24947ef3a09SGreg Roth if (GV->use_empty()) { 25047ef3a09SGreg Roth GV->removeDeadConstantUsers(); 25147ef3a09SGreg Roth GV->eraseFromParent(); 25247ef3a09SGreg Roth } 25347ef3a09SGreg Roth } 25447ef3a09SGreg Roth } 25547ef3a09SGreg Roth } 25647ef3a09SGreg Roth 25790e84113SJustin Bogner [[nodiscard]] bool lowerToCreateHandle(Function &F) { 258aa61925eSJustin Bogner IRBuilder<> &IRB = OpBuilder.getIRB(); 259aa61925eSJustin Bogner Type *Int8Ty = IRB.getInt8Ty(); 260aa61925eSJustin Bogner Type *Int32Ty = IRB.getInt32Ty(); 261aa61925eSJustin Bogner 26290e84113SJustin Bogner return replaceFunction(F, [&](CallInst *CI) -> Error { 263aa61925eSJustin Bogner IRB.SetInsertPoint(CI); 264aa61925eSJustin Bogner 2653eca15cbSJustin Bogner auto *It = DBM.find(CI); 2663eca15cbSJustin Bogner assert(It != DBM.end() && "Resource not in map?"); 2673eca15cbSJustin Bogner dxil::ResourceBindingInfo &RI = *It; 2683eca15cbSJustin Bogner 269aa61925eSJustin Bogner const auto &Binding = RI.getBinding(); 2703eca15cbSJustin Bogner dxil::ResourceClass RC = DRTM[RI.getHandleTy()].getResourceClass(); 271aa61925eSJustin Bogner 272bb88fd17SJustin Bogner Value *IndexOp = CI->getArgOperand(3); 273bb88fd17SJustin Bogner if (Binding.LowerBound != 0) 274bb88fd17SJustin Bogner IndexOp = IRB.CreateAdd(IndexOp, 275bb88fd17SJustin Bogner ConstantInt::get(Int32Ty, Binding.LowerBound)); 276bb88fd17SJustin Bogner 277aa61925eSJustin Bogner std::array<Value *, 4> Args{ 2783eca15cbSJustin Bogner ConstantInt::get(Int8Ty, llvm::to_underlying(RC)), 279bb88fd17SJustin Bogner ConstantInt::get(Int32Ty, Binding.RecordID), IndexOp, 280aa61925eSJustin Bogner CI->getArgOperand(4)}; 281aa61925eSJustin Bogner Expected<CallInst *> OpCall = 2823d129016SJustin Bogner OpBuilder.tryCreateOp(OpCode::CreateHandle, Args, CI->getName()); 283aa61925eSJustin Bogner if (Error E = OpCall.takeError()) 284aa61925eSJustin Bogner return E; 285aa61925eSJustin Bogner 286aa61925eSJustin Bogner Value *Cast = createTmpHandleCast(*OpCall, CI->getType()); 287aa61925eSJustin Bogner 28847ef3a09SGreg Roth removeResourceGlobals(CI); 28947ef3a09SGreg Roth 290aa61925eSJustin Bogner CI->replaceAllUsesWith(Cast); 291aa61925eSJustin Bogner CI->eraseFromParent(); 292aa61925eSJustin Bogner return Error::success(); 293aa61925eSJustin Bogner }); 294aa61925eSJustin Bogner } 295aa61925eSJustin Bogner 29690e84113SJustin Bogner [[nodiscard]] bool lowerToBindAndAnnotateHandle(Function &F) { 297aa61925eSJustin Bogner IRBuilder<> &IRB = OpBuilder.getIRB(); 298bb88fd17SJustin Bogner Type *Int32Ty = IRB.getInt32Ty(); 299aa61925eSJustin Bogner 30090e84113SJustin Bogner return replaceFunction(F, [&](CallInst *CI) -> Error { 301aa61925eSJustin Bogner IRB.SetInsertPoint(CI); 302aa61925eSJustin Bogner 3033eca15cbSJustin Bogner auto *It = DBM.find(CI); 3043eca15cbSJustin Bogner assert(It != DBM.end() && "Resource not in map?"); 3053eca15cbSJustin Bogner dxil::ResourceBindingInfo &RI = *It; 306aa61925eSJustin Bogner 307aa61925eSJustin Bogner const auto &Binding = RI.getBinding(); 3083eca15cbSJustin Bogner dxil::ResourceTypeInfo &RTI = DRTM[RI.getHandleTy()]; 3093eca15cbSJustin Bogner dxil::ResourceClass RC = RTI.getResourceClass(); 310bb88fd17SJustin Bogner 311bb88fd17SJustin Bogner Value *IndexOp = CI->getArgOperand(3); 312bb88fd17SJustin Bogner if (Binding.LowerBound != 0) 313bb88fd17SJustin Bogner IndexOp = IRB.CreateAdd(IndexOp, 314bb88fd17SJustin Bogner ConstantInt::get(Int32Ty, Binding.LowerBound)); 315bb88fd17SJustin Bogner 3163eca15cbSJustin Bogner std::pair<uint32_t, uint32_t> Props = 3173eca15cbSJustin Bogner RI.getAnnotateProps(*F.getParent(), RTI); 318aa61925eSJustin Bogner 319aa61925eSJustin Bogner // For `CreateHandleFromBinding` we need the upper bound rather than the 320aa61925eSJustin Bogner // size, so we need to be careful about the difference for "unbounded". 321aa61925eSJustin Bogner uint32_t Unbounded = std::numeric_limits<uint32_t>::max(); 322aa61925eSJustin Bogner uint32_t UpperBound = Binding.Size == Unbounded 323aa61925eSJustin Bogner ? Unbounded 324aa61925eSJustin Bogner : Binding.LowerBound + Binding.Size - 1; 3253eca15cbSJustin Bogner Constant *ResBind = OpBuilder.getResBind(Binding.LowerBound, UpperBound, 3263eca15cbSJustin Bogner Binding.Space, RC); 327bb88fd17SJustin Bogner std::array<Value *, 3> BindArgs{ResBind, IndexOp, CI->getArgOperand(4)}; 3283d129016SJustin Bogner Expected<CallInst *> OpBind = OpBuilder.tryCreateOp( 3293d129016SJustin Bogner OpCode::CreateHandleFromBinding, BindArgs, CI->getName()); 330aa61925eSJustin Bogner if (Error E = OpBind.takeError()) 331aa61925eSJustin Bogner return E; 332aa61925eSJustin Bogner 333aa61925eSJustin Bogner std::array<Value *, 2> AnnotateArgs{ 334aa61925eSJustin Bogner *OpBind, OpBuilder.getResProps(Props.first, Props.second)}; 3353d129016SJustin Bogner Expected<CallInst *> OpAnnotate = OpBuilder.tryCreateOp( 3363d129016SJustin Bogner OpCode::AnnotateHandle, AnnotateArgs, 3373d129016SJustin Bogner CI->hasName() ? CI->getName() + "_annot" : Twine()); 338aa61925eSJustin Bogner if (Error E = OpAnnotate.takeError()) 339aa61925eSJustin Bogner return E; 340aa61925eSJustin Bogner 341aa61925eSJustin Bogner Value *Cast = createTmpHandleCast(*OpAnnotate, CI->getType()); 342aa61925eSJustin Bogner 34347ef3a09SGreg Roth removeResourceGlobals(CI); 34447ef3a09SGreg Roth 345aa61925eSJustin Bogner CI->replaceAllUsesWith(Cast); 346aa61925eSJustin Bogner CI->eraseFromParent(); 347aa61925eSJustin Bogner 348aa61925eSJustin Bogner return Error::success(); 349aa61925eSJustin Bogner }); 350aa61925eSJustin Bogner } 351aa61925eSJustin Bogner 352aa07f922SJustin Bogner /// Lower `dx.resource.handlefrombinding` intrinsics depending on the shader 353aa07f922SJustin Bogner /// model and taking into account binding information from 354aa07f922SJustin Bogner /// DXILResourceBindingAnalysis. 35590e84113SJustin Bogner bool lowerHandleFromBinding(Function &F) { 356aa61925eSJustin Bogner Triple TT(Triple(M.getTargetTriple())); 357aa61925eSJustin Bogner if (TT.getDXILVersion() < VersionTuple(1, 6)) 35890e84113SJustin Bogner return lowerToCreateHandle(F); 35990e84113SJustin Bogner return lowerToBindAndAnnotateHandle(F); 360aa61925eSJustin Bogner } 361aa61925eSJustin Bogner 362481bce01Sjoaosaffran Error replaceSplitDoubleCallUsages(CallInst *Intrin, CallInst *Op) { 363481bce01Sjoaosaffran for (Use &U : make_early_inc_range(Intrin->uses())) { 364481bce01Sjoaosaffran if (auto *EVI = dyn_cast<ExtractValueInst>(U.getUser())) { 365481bce01Sjoaosaffran 366481bce01Sjoaosaffran if (EVI->getNumIndices() != 1) 367481bce01Sjoaosaffran return createStringError(std::errc::invalid_argument, 368481bce01Sjoaosaffran "Splitdouble has only 2 elements"); 369481bce01Sjoaosaffran EVI->setOperand(0, Op); 370481bce01Sjoaosaffran } else { 371481bce01Sjoaosaffran return make_error<StringError>( 372481bce01Sjoaosaffran "Splitdouble use is not ExtractValueInst", 373481bce01Sjoaosaffran inconvertibleErrorCode()); 374481bce01Sjoaosaffran } 375481bce01Sjoaosaffran } 376481bce01Sjoaosaffran 377481bce01Sjoaosaffran Intrin->eraseFromParent(); 378481bce01Sjoaosaffran 379481bce01Sjoaosaffran return Error::success(); 380481bce01Sjoaosaffran } 381481bce01Sjoaosaffran 3823f22756fSJustin Bogner /// Replace uses of \c Intrin with the values in the `dx.ResRet` of \c Op. 3833f22756fSJustin Bogner /// Since we expect to be post-scalarization, make an effort to avoid vectors. 38434e20f18SJustin Bogner Error replaceResRetUses(CallInst *Intrin, CallInst *Op, bool HasCheckBit) { 3853f22756fSJustin Bogner IRBuilder<> &IRB = OpBuilder.getIRB(); 3863f22756fSJustin Bogner 38734e20f18SJustin Bogner Instruction *OldResult = Intrin; 3883f22756fSJustin Bogner Type *OldTy = Intrin->getType(); 3893f22756fSJustin Bogner 39034e20f18SJustin Bogner if (HasCheckBit) { 39134e20f18SJustin Bogner auto *ST = cast<StructType>(OldTy); 39234e20f18SJustin Bogner 39334e20f18SJustin Bogner Value *CheckOp = nullptr; 39434e20f18SJustin Bogner Type *Int32Ty = IRB.getInt32Ty(); 39534e20f18SJustin Bogner for (Use &U : make_early_inc_range(OldResult->uses())) { 39634e20f18SJustin Bogner if (auto *EVI = dyn_cast<ExtractValueInst>(U.getUser())) { 39734e20f18SJustin Bogner ArrayRef<unsigned> Indices = EVI->getIndices(); 39834e20f18SJustin Bogner assert(Indices.size() == 1); 39934e20f18SJustin Bogner // We're only interested in uses of the check bit for now. 40034e20f18SJustin Bogner if (Indices[0] != 1) 40134e20f18SJustin Bogner continue; 40234e20f18SJustin Bogner if (!CheckOp) { 40334e20f18SJustin Bogner Value *NewEVI = IRB.CreateExtractValue(Op, 4); 40434e20f18SJustin Bogner Expected<CallInst *> OpCall = OpBuilder.tryCreateOp( 4053d129016SJustin Bogner OpCode::CheckAccessFullyMapped, {NewEVI}, 4063d129016SJustin Bogner OldResult->hasName() ? OldResult->getName() + "_check" 4073d129016SJustin Bogner : Twine(), 4083d129016SJustin Bogner Int32Ty); 40934e20f18SJustin Bogner if (Error E = OpCall.takeError()) 41034e20f18SJustin Bogner return E; 41134e20f18SJustin Bogner CheckOp = *OpCall; 41234e20f18SJustin Bogner } 41334e20f18SJustin Bogner EVI->replaceAllUsesWith(CheckOp); 41434e20f18SJustin Bogner EVI->eraseFromParent(); 41534e20f18SJustin Bogner } 41634e20f18SJustin Bogner } 41734e20f18SJustin Bogner 4182c7c07dfSJustin Bogner if (OldResult->use_empty()) { 4192c7c07dfSJustin Bogner // Only the check bit was used, so we're done here. 4202c7c07dfSJustin Bogner OldResult->eraseFromParent(); 4212c7c07dfSJustin Bogner return Error::success(); 4222c7c07dfSJustin Bogner } 4232c7c07dfSJustin Bogner 4242c7c07dfSJustin Bogner assert(OldResult->hasOneUse() && 4252c7c07dfSJustin Bogner isa<ExtractValueInst>(*OldResult->user_begin()) && 4262c7c07dfSJustin Bogner "Expected only use to be extract of first element"); 4272c7c07dfSJustin Bogner OldResult = cast<Instruction>(*OldResult->user_begin()); 42834e20f18SJustin Bogner OldTy = ST->getElementType(0); 42934e20f18SJustin Bogner } 43034e20f18SJustin Bogner 4313f22756fSJustin Bogner // For scalars, we just extract the first element. 4323f22756fSJustin Bogner if (!isa<FixedVectorType>(OldTy)) { 4333f22756fSJustin Bogner Value *EVI = IRB.CreateExtractValue(Op, 0); 43434e20f18SJustin Bogner OldResult->replaceAllUsesWith(EVI); 43534e20f18SJustin Bogner OldResult->eraseFromParent(); 43634e20f18SJustin Bogner if (OldResult != Intrin) { 43734e20f18SJustin Bogner assert(Intrin->use_empty() && "Intrinsic still has uses?"); 4383f22756fSJustin Bogner Intrin->eraseFromParent(); 43934e20f18SJustin Bogner } 4403f22756fSJustin Bogner return Error::success(); 4413f22756fSJustin Bogner } 4423f22756fSJustin Bogner 4433f22756fSJustin Bogner std::array<Value *, 4> Extracts = {}; 4443f22756fSJustin Bogner SmallVector<ExtractElementInst *> DynamicAccesses; 4453f22756fSJustin Bogner 4463f22756fSJustin Bogner // The users of the operation should all be scalarized, so we attempt to 4473f22756fSJustin Bogner // replace the extractelements with extractvalues directly. 44834e20f18SJustin Bogner for (Use &U : make_early_inc_range(OldResult->uses())) { 4493f22756fSJustin Bogner if (auto *EEI = dyn_cast<ExtractElementInst>(U.getUser())) { 4503f22756fSJustin Bogner if (auto *IndexOp = dyn_cast<ConstantInt>(EEI->getIndexOperand())) { 4513f22756fSJustin Bogner size_t IndexVal = IndexOp->getZExtValue(); 4523f22756fSJustin Bogner assert(IndexVal < 4 && "Index into buffer load out of range"); 4533f22756fSJustin Bogner if (!Extracts[IndexVal]) 4543f22756fSJustin Bogner Extracts[IndexVal] = IRB.CreateExtractValue(Op, IndexVal); 4553f22756fSJustin Bogner EEI->replaceAllUsesWith(Extracts[IndexVal]); 4563f22756fSJustin Bogner EEI->eraseFromParent(); 4573f22756fSJustin Bogner } else { 4583f22756fSJustin Bogner DynamicAccesses.push_back(EEI); 4593f22756fSJustin Bogner } 4603f22756fSJustin Bogner } 4613f22756fSJustin Bogner } 4623f22756fSJustin Bogner 4633f22756fSJustin Bogner const auto *VecTy = cast<FixedVectorType>(OldTy); 4643f22756fSJustin Bogner const unsigned N = VecTy->getNumElements(); 4653f22756fSJustin Bogner 4663f22756fSJustin Bogner // If there's a dynamic access we need to round trip through stack memory so 4673f22756fSJustin Bogner // that we don't leave vectors around. 4683f22756fSJustin Bogner if (!DynamicAccesses.empty()) { 4693f22756fSJustin Bogner Type *Int32Ty = IRB.getInt32Ty(); 4703f22756fSJustin Bogner Constant *Zero = ConstantInt::get(Int32Ty, 0); 4713f22756fSJustin Bogner 4723f22756fSJustin Bogner Type *ElTy = VecTy->getElementType(); 4733f22756fSJustin Bogner Type *ArrayTy = ArrayType::get(ElTy, N); 4743f22756fSJustin Bogner Value *Alloca = IRB.CreateAlloca(ArrayTy); 4753f22756fSJustin Bogner 4763f22756fSJustin Bogner for (int I = 0, E = N; I != E; ++I) { 4773f22756fSJustin Bogner if (!Extracts[I]) 4783f22756fSJustin Bogner Extracts[I] = IRB.CreateExtractValue(Op, I); 4793f22756fSJustin Bogner Value *GEP = IRB.CreateInBoundsGEP( 4803f22756fSJustin Bogner ArrayTy, Alloca, {Zero, ConstantInt::get(Int32Ty, I)}); 4813f22756fSJustin Bogner IRB.CreateStore(Extracts[I], GEP); 4823f22756fSJustin Bogner } 4833f22756fSJustin Bogner 4843f22756fSJustin Bogner for (ExtractElementInst *EEI : DynamicAccesses) { 4853f22756fSJustin Bogner Value *GEP = IRB.CreateInBoundsGEP(ArrayTy, Alloca, 4863f22756fSJustin Bogner {Zero, EEI->getIndexOperand()}); 4873f22756fSJustin Bogner Value *Load = IRB.CreateLoad(ElTy, GEP); 4883f22756fSJustin Bogner EEI->replaceAllUsesWith(Load); 4893f22756fSJustin Bogner EEI->eraseFromParent(); 4903f22756fSJustin Bogner } 4913f22756fSJustin Bogner } 4923f22756fSJustin Bogner 4933f22756fSJustin Bogner // If we still have uses, then we're not fully scalarized and need to 4943f22756fSJustin Bogner // recreate the vector. This should only happen for things like exported 4953f22756fSJustin Bogner // functions from libraries. 49634e20f18SJustin Bogner if (!OldResult->use_empty()) { 4973f22756fSJustin Bogner for (int I = 0, E = N; I != E; ++I) 4983f22756fSJustin Bogner if (!Extracts[I]) 4993f22756fSJustin Bogner Extracts[I] = IRB.CreateExtractValue(Op, I); 5003f22756fSJustin Bogner 5013f22756fSJustin Bogner Value *Vec = UndefValue::get(OldTy); 5023f22756fSJustin Bogner for (int I = 0, E = N; I != E; ++I) 5033f22756fSJustin Bogner Vec = IRB.CreateInsertElement(Vec, Extracts[I], I); 50434e20f18SJustin Bogner OldResult->replaceAllUsesWith(Vec); 5053f22756fSJustin Bogner } 5063f22756fSJustin Bogner 50734e20f18SJustin Bogner OldResult->eraseFromParent(); 50834e20f18SJustin Bogner if (OldResult != Intrin) { 50934e20f18SJustin Bogner assert(Intrin->use_empty() && "Intrinsic still has uses?"); 5103f22756fSJustin Bogner Intrin->eraseFromParent(); 51134e20f18SJustin Bogner } 51234e20f18SJustin Bogner 5133f22756fSJustin Bogner return Error::success(); 5143f22756fSJustin Bogner } 5153f22756fSJustin Bogner 51634e20f18SJustin Bogner [[nodiscard]] bool lowerTypedBufferLoad(Function &F, bool HasCheckBit) { 5173f22756fSJustin Bogner IRBuilder<> &IRB = OpBuilder.getIRB(); 5183f22756fSJustin Bogner Type *Int32Ty = IRB.getInt32Ty(); 5193f22756fSJustin Bogner 52090e84113SJustin Bogner return replaceFunction(F, [&](CallInst *CI) -> Error { 5213f22756fSJustin Bogner IRB.SetInsertPoint(CI); 5223f22756fSJustin Bogner 5233f22756fSJustin Bogner Value *Handle = 5243f22756fSJustin Bogner createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType()); 5253f22756fSJustin Bogner Value *Index0 = CI->getArgOperand(1); 5263f22756fSJustin Bogner Value *Index1 = UndefValue::get(Int32Ty); 5273f22756fSJustin Bogner 52834e20f18SJustin Bogner Type *OldTy = CI->getType(); 52934e20f18SJustin Bogner if (HasCheckBit) 53034e20f18SJustin Bogner OldTy = cast<StructType>(OldTy)->getElementType(0); 53134e20f18SJustin Bogner Type *NewRetTy = OpBuilder.getResRetType(OldTy->getScalarType()); 5323f22756fSJustin Bogner 5333f22756fSJustin Bogner std::array<Value *, 3> Args{Handle, Index0, Index1}; 5343d129016SJustin Bogner Expected<CallInst *> OpCall = OpBuilder.tryCreateOp( 5353d129016SJustin Bogner OpCode::BufferLoad, Args, CI->getName(), NewRetTy); 5363f22756fSJustin Bogner if (Error E = OpCall.takeError()) 5373f22756fSJustin Bogner return E; 53834e20f18SJustin Bogner if (Error E = replaceResRetUses(CI, *OpCall, HasCheckBit)) 5393f22756fSJustin Bogner return E; 5403f22756fSJustin Bogner 5413f22756fSJustin Bogner return Error::success(); 5423f22756fSJustin Bogner }); 5433f22756fSJustin Bogner } 5443f22756fSJustin Bogner 545cba9bd5cSJustin Bogner [[nodiscard]] bool lowerRawBufferLoad(Function &F) { 546cba9bd5cSJustin Bogner Triple TT(Triple(M.getTargetTriple())); 547cba9bd5cSJustin Bogner VersionTuple DXILVersion = TT.getDXILVersion(); 548cba9bd5cSJustin Bogner const DataLayout &DL = F.getDataLayout(); 549cba9bd5cSJustin Bogner IRBuilder<> &IRB = OpBuilder.getIRB(); 550cba9bd5cSJustin Bogner Type *Int8Ty = IRB.getInt8Ty(); 551cba9bd5cSJustin Bogner Type *Int32Ty = IRB.getInt32Ty(); 552cba9bd5cSJustin Bogner 553cba9bd5cSJustin Bogner return replaceFunction(F, [&](CallInst *CI) -> Error { 554cba9bd5cSJustin Bogner IRB.SetInsertPoint(CI); 555cba9bd5cSJustin Bogner 556cba9bd5cSJustin Bogner Type *OldTy = cast<StructType>(CI->getType())->getElementType(0); 557cba9bd5cSJustin Bogner Type *ScalarTy = OldTy->getScalarType(); 558cba9bd5cSJustin Bogner Type *NewRetTy = OpBuilder.getResRetType(ScalarTy); 559cba9bd5cSJustin Bogner 560cba9bd5cSJustin Bogner Value *Handle = 561cba9bd5cSJustin Bogner createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType()); 562cba9bd5cSJustin Bogner Value *Index0 = CI->getArgOperand(1); 563cba9bd5cSJustin Bogner Value *Index1 = CI->getArgOperand(2); 564cba9bd5cSJustin Bogner uint64_t NumElements = 565cba9bd5cSJustin Bogner DL.getTypeSizeInBits(OldTy) / DL.getTypeSizeInBits(ScalarTy); 566cba9bd5cSJustin Bogner Value *Mask = ConstantInt::get(Int8Ty, ~(~0U << NumElements)); 567cba9bd5cSJustin Bogner Value *Align = 568cba9bd5cSJustin Bogner ConstantInt::get(Int32Ty, DL.getPrefTypeAlign(ScalarTy).value()); 569cba9bd5cSJustin Bogner 570cba9bd5cSJustin Bogner Expected<CallInst *> OpCall = 571cba9bd5cSJustin Bogner DXILVersion >= VersionTuple(1, 2) 572cba9bd5cSJustin Bogner ? OpBuilder.tryCreateOp(OpCode::RawBufferLoad, 573cba9bd5cSJustin Bogner {Handle, Index0, Index1, Mask, Align}, 574cba9bd5cSJustin Bogner CI->getName(), NewRetTy) 575cba9bd5cSJustin Bogner : OpBuilder.tryCreateOp(OpCode::BufferLoad, 576cba9bd5cSJustin Bogner {Handle, Index0, Index1}, CI->getName(), 577cba9bd5cSJustin Bogner NewRetTy); 578cba9bd5cSJustin Bogner if (Error E = OpCall.takeError()) 579cba9bd5cSJustin Bogner return E; 580cba9bd5cSJustin Bogner if (Error E = replaceResRetUses(CI, *OpCall, /*HasCheckBit=*/true)) 581cba9bd5cSJustin Bogner return E; 582cba9bd5cSJustin Bogner 583cba9bd5cSJustin Bogner return Error::success(); 584cba9bd5cSJustin Bogner }); 585cba9bd5cSJustin Bogner } 586cba9bd5cSJustin Bogner 5871f250999Sjoaosaffran [[nodiscard]] bool lowerUpdateCounter(Function &F) { 5881f250999Sjoaosaffran IRBuilder<> &IRB = OpBuilder.getIRB(); 589691bd184Sjoaosaffran Type *Int32Ty = IRB.getInt32Ty(); 5901f250999Sjoaosaffran 5911f250999Sjoaosaffran return replaceFunction(F, [&](CallInst *CI) -> Error { 5921f250999Sjoaosaffran IRB.SetInsertPoint(CI); 5931f250999Sjoaosaffran Value *Handle = 5941f250999Sjoaosaffran createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType()); 5951f250999Sjoaosaffran Value *Op1 = CI->getArgOperand(1); 5961f250999Sjoaosaffran 5971f250999Sjoaosaffran std::array<Value *, 2> Args{Handle, Op1}; 5981f250999Sjoaosaffran 599691bd184Sjoaosaffran Expected<CallInst *> OpCall = OpBuilder.tryCreateOp( 600691bd184Sjoaosaffran OpCode::UpdateCounter, Args, CI->getName(), Int32Ty); 6011f250999Sjoaosaffran 6021f250999Sjoaosaffran if (Error E = OpCall.takeError()) 6031f250999Sjoaosaffran return E; 6041f250999Sjoaosaffran 605691bd184Sjoaosaffran CI->replaceAllUsesWith(*OpCall); 6061f250999Sjoaosaffran CI->eraseFromParent(); 6071f250999Sjoaosaffran return Error::success(); 6081f250999Sjoaosaffran }); 6091f250999Sjoaosaffran } 6101f250999Sjoaosaffran 6110fca76d5SJustin Bogner [[nodiscard]] bool lowerGetPointer(Function &F) { 6120fca76d5SJustin Bogner // These should have already been handled in DXILResourceAccess, so we can 6130fca76d5SJustin Bogner // just clean up the dead prototype. 6140fca76d5SJustin Bogner assert(F.user_empty() && "getpointer operations should have been removed"); 6150fca76d5SJustin Bogner F.eraseFromParent(); 6160fca76d5SJustin Bogner return false; 6170fca76d5SJustin Bogner } 6180fca76d5SJustin Bogner 619*0e51b54bSJustin Bogner [[nodiscard]] bool lowerBufferStore(Function &F, bool IsRaw) { 620*0e51b54bSJustin Bogner Triple TT(Triple(M.getTargetTriple())); 621*0e51b54bSJustin Bogner VersionTuple DXILVersion = TT.getDXILVersion(); 622*0e51b54bSJustin Bogner const DataLayout &DL = F.getDataLayout(); 62390e84113SJustin Bogner IRBuilder<> &IRB = OpBuilder.getIRB(); 62490e84113SJustin Bogner Type *Int8Ty = IRB.getInt8Ty(); 62590e84113SJustin Bogner Type *Int32Ty = IRB.getInt32Ty(); 62690e84113SJustin Bogner 62790e84113SJustin Bogner return replaceFunction(F, [&](CallInst *CI) -> Error { 62890e84113SJustin Bogner IRB.SetInsertPoint(CI); 62990e84113SJustin Bogner 63090e84113SJustin Bogner Value *Handle = 63190e84113SJustin Bogner createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType()); 63290e84113SJustin Bogner Value *Index0 = CI->getArgOperand(1); 633*0e51b54bSJustin Bogner Value *Index1 = IsRaw ? CI->getArgOperand(2) : UndefValue::get(Int32Ty); 63490e84113SJustin Bogner 635*0e51b54bSJustin Bogner Value *Data = CI->getArgOperand(IsRaw ? 3 : 2); 636*0e51b54bSJustin Bogner Type *DataTy = Data->getType(); 637*0e51b54bSJustin Bogner Type *ScalarTy = DataTy->getScalarType(); 638*0e51b54bSJustin Bogner 639*0e51b54bSJustin Bogner uint64_t NumElements = 640*0e51b54bSJustin Bogner DL.getTypeSizeInBits(DataTy) / DL.getTypeSizeInBits(ScalarTy); 641*0e51b54bSJustin Bogner Value *Mask = ConstantInt::get(Int8Ty, ~(~0U << NumElements)); 642*0e51b54bSJustin Bogner 643*0e51b54bSJustin Bogner // TODO: check that we only have vector or scalar... 644*0e51b54bSJustin Bogner if (!IsRaw && NumElements != 4) 64590e84113SJustin Bogner return make_error<StringError>( 64690e84113SJustin Bogner "typedBufferStore data must be a vector of 4 elements", 64790e84113SJustin Bogner inconvertibleErrorCode()); 648*0e51b54bSJustin Bogner else if (NumElements > 4) 649*0e51b54bSJustin Bogner return make_error<StringError>( 650*0e51b54bSJustin Bogner "rawBufferStore data must have at most 4 elements", 651*0e51b54bSJustin Bogner inconvertibleErrorCode()); 65290e84113SJustin Bogner 6532c88ac9dSJustin Bogner std::array<Value *, 4> DataElements{nullptr, nullptr, nullptr, nullptr}; 654*0e51b54bSJustin Bogner if (DataTy == ScalarTy) 655*0e51b54bSJustin Bogner DataElements[0] = Data; 656*0e51b54bSJustin Bogner else { 657*0e51b54bSJustin Bogner // Since we're post-scalarizer, if we see a vector here it's likely 658*0e51b54bSJustin Bogner // constructed solely for the argument of the store. Just use the scalar 659*0e51b54bSJustin Bogner // values from before they're inserted into the temporary. 6602c88ac9dSJustin Bogner auto *IEI = dyn_cast<InsertElementInst>(Data); 6612c88ac9dSJustin Bogner while (IEI) { 6622c88ac9dSJustin Bogner auto *IndexOp = dyn_cast<ConstantInt>(IEI->getOperand(2)); 6632c88ac9dSJustin Bogner if (!IndexOp) 6642c88ac9dSJustin Bogner break; 6652c88ac9dSJustin Bogner size_t IndexVal = IndexOp->getZExtValue(); 6662c88ac9dSJustin Bogner assert(IndexVal < 4 && "Too many elements for buffer store"); 6672c88ac9dSJustin Bogner DataElements[IndexVal] = IEI->getOperand(1); 6682c88ac9dSJustin Bogner IEI = dyn_cast<InsertElementInst>(IEI->getOperand(0)); 6692c88ac9dSJustin Bogner } 670*0e51b54bSJustin Bogner } 6712c88ac9dSJustin Bogner 6722c88ac9dSJustin Bogner // If for some reason we weren't able to forward the arguments from the 673*0e51b54bSJustin Bogner // scalarizer artifact, then we may need to actually extract elements from 674*0e51b54bSJustin Bogner // the vector. 675*0e51b54bSJustin Bogner for (int I = 0, E = NumElements; I < E; ++I) 6762c88ac9dSJustin Bogner if (DataElements[I] == nullptr) 6772c88ac9dSJustin Bogner DataElements[I] = 6782c88ac9dSJustin Bogner IRB.CreateExtractElement(Data, ConstantInt::get(Int32Ty, I)); 679*0e51b54bSJustin Bogner // For any elements beyond the length of the vector, fill up with undef. 680*0e51b54bSJustin Bogner for (int I = NumElements, E = 4; I < E; ++I) 681*0e51b54bSJustin Bogner if (DataElements[I] == nullptr) 682*0e51b54bSJustin Bogner DataElements[I] = UndefValue::get(ScalarTy); 6832c88ac9dSJustin Bogner 684*0e51b54bSJustin Bogner dxil::OpCode Op = OpCode::BufferStore; 685*0e51b54bSJustin Bogner SmallVector<Value *, 9> Args{ 6862c88ac9dSJustin Bogner Handle, Index0, Index1, DataElements[0], 6872c88ac9dSJustin Bogner DataElements[1], DataElements[2], DataElements[3], Mask}; 688*0e51b54bSJustin Bogner if (IsRaw && DXILVersion >= VersionTuple(1, 2)) { 689*0e51b54bSJustin Bogner Op = OpCode::RawBufferStore; 690*0e51b54bSJustin Bogner // RawBufferStore requires the alignment 691*0e51b54bSJustin Bogner Args.push_back( 692*0e51b54bSJustin Bogner ConstantInt::get(Int32Ty, DL.getPrefTypeAlign(ScalarTy).value())); 693*0e51b54bSJustin Bogner } 69490e84113SJustin Bogner Expected<CallInst *> OpCall = 695*0e51b54bSJustin Bogner OpBuilder.tryCreateOp(Op, Args, CI->getName()); 69690e84113SJustin Bogner if (Error E = OpCall.takeError()) 69790e84113SJustin Bogner return E; 69890e84113SJustin Bogner 69990e84113SJustin Bogner CI->eraseFromParent(); 7002c88ac9dSJustin Bogner // Clean up any leftover `insertelement`s 701*0e51b54bSJustin Bogner auto *IEI = dyn_cast<InsertElementInst>(Data); 7022c88ac9dSJustin Bogner while (IEI && IEI->use_empty()) { 7032c88ac9dSJustin Bogner InsertElementInst *Tmp = IEI; 7042c88ac9dSJustin Bogner IEI = dyn_cast<InsertElementInst>(IEI->getOperand(0)); 7052c88ac9dSJustin Bogner Tmp->eraseFromParent(); 7062c88ac9dSJustin Bogner } 7072c88ac9dSJustin Bogner 70890e84113SJustin Bogner return Error::success(); 70990e84113SJustin Bogner }); 71090e84113SJustin Bogner } 71190e84113SJustin Bogner 71275e7ba8cSSarah Spall [[nodiscard]] bool lowerCtpopToCountBits(Function &F) { 71375e7ba8cSSarah Spall IRBuilder<> &IRB = OpBuilder.getIRB(); 71475e7ba8cSSarah Spall Type *Int32Ty = IRB.getInt32Ty(); 71575e7ba8cSSarah Spall 71675e7ba8cSSarah Spall return replaceFunction(F, [&](CallInst *CI) -> Error { 71775e7ba8cSSarah Spall IRB.SetInsertPoint(CI); 71875e7ba8cSSarah Spall SmallVector<Value *> Args; 71975e7ba8cSSarah Spall Args.append(CI->arg_begin(), CI->arg_end()); 72075e7ba8cSSarah Spall 72175e7ba8cSSarah Spall Type *RetTy = Int32Ty; 72275e7ba8cSSarah Spall Type *FRT = F.getReturnType(); 72375e7ba8cSSarah Spall if (const auto *VT = dyn_cast<VectorType>(FRT)) 72475e7ba8cSSarah Spall RetTy = VectorType::get(RetTy, VT); 72575e7ba8cSSarah Spall 72675e7ba8cSSarah Spall Expected<CallInst *> OpCall = OpBuilder.tryCreateOp( 72775e7ba8cSSarah Spall dxil::OpCode::CountBits, Args, CI->getName(), RetTy); 72875e7ba8cSSarah Spall if (Error E = OpCall.takeError()) 72975e7ba8cSSarah Spall return E; 73075e7ba8cSSarah Spall 73175e7ba8cSSarah Spall // If the result type is 32 bits we can do a direct replacement. 73275e7ba8cSSarah Spall if (FRT->isIntOrIntVectorTy(32)) { 73375e7ba8cSSarah Spall CI->replaceAllUsesWith(*OpCall); 73475e7ba8cSSarah Spall CI->eraseFromParent(); 73575e7ba8cSSarah Spall return Error::success(); 73675e7ba8cSSarah Spall } 73775e7ba8cSSarah Spall 73875e7ba8cSSarah Spall unsigned CastOp; 73975e7ba8cSSarah Spall unsigned CastOp2; 74075e7ba8cSSarah Spall if (FRT->isIntOrIntVectorTy(16)) { 74175e7ba8cSSarah Spall CastOp = Instruction::ZExt; 74275e7ba8cSSarah Spall CastOp2 = Instruction::SExt; 74375e7ba8cSSarah Spall } else { // must be 64 bits 74475e7ba8cSSarah Spall assert(FRT->isIntOrIntVectorTy(64) && 74575e7ba8cSSarah Spall "Currently only lowering 16, 32, or 64 bit ctpop to CountBits \ 74675e7ba8cSSarah Spall is supported."); 74775e7ba8cSSarah Spall CastOp = Instruction::Trunc; 74875e7ba8cSSarah Spall CastOp2 = Instruction::Trunc; 74975e7ba8cSSarah Spall } 75075e7ba8cSSarah Spall 75175e7ba8cSSarah Spall // It is correct to replace the ctpop with the dxil op and 75275e7ba8cSSarah Spall // remove all casts to i32 75375e7ba8cSSarah Spall bool NeedsCast = false; 75475e7ba8cSSarah Spall for (User *User : make_early_inc_range(CI->users())) { 75575e7ba8cSSarah Spall Instruction *I = dyn_cast<Instruction>(User); 75675e7ba8cSSarah Spall if (I && (I->getOpcode() == CastOp || I->getOpcode() == CastOp2) && 75775e7ba8cSSarah Spall I->getType() == RetTy) { 75875e7ba8cSSarah Spall I->replaceAllUsesWith(*OpCall); 75975e7ba8cSSarah Spall I->eraseFromParent(); 76075e7ba8cSSarah Spall } else 76175e7ba8cSSarah Spall NeedsCast = true; 76275e7ba8cSSarah Spall } 76375e7ba8cSSarah Spall 76475e7ba8cSSarah Spall // It is correct to replace a ctpop with the dxil op and 76575e7ba8cSSarah Spall // a cast from i32 to the return type of the ctpop 76675e7ba8cSSarah Spall // the cast is emitted here if there is a non-cast to i32 76775e7ba8cSSarah Spall // instr which uses the ctpop 76875e7ba8cSSarah Spall if (NeedsCast) { 76975e7ba8cSSarah Spall Value *Cast = 77075e7ba8cSSarah Spall IRB.CreateZExtOrTrunc(*OpCall, F.getReturnType(), "ctpop.cast"); 77175e7ba8cSSarah Spall CI->replaceAllUsesWith(Cast); 77275e7ba8cSSarah Spall } 77375e7ba8cSSarah Spall 77475e7ba8cSSarah Spall CI->eraseFromParent(); 77575e7ba8cSSarah Spall return Error::success(); 77675e7ba8cSSarah Spall }); 77775e7ba8cSSarah Spall } 77875e7ba8cSSarah Spall 779e56ad22bSJustin Bogner bool lowerIntrinsics() { 78085285be9SXiang Li bool Updated = false; 78190e84113SJustin Bogner bool HasErrors = false; 782264c09b7SXiang Li 78385285be9SXiang Li for (Function &F : make_early_inc_range(M.functions())) { 78485285be9SXiang Li if (!F.isDeclaration()) 78585285be9SXiang Li continue; 78685285be9SXiang Li Intrinsic::ID ID = F.getIntrinsicID(); 78794da6bfbSJustin Bogner switch (ID) { 78894da6bfbSJustin Bogner default: 789264c09b7SXiang Li continue; 7900a44b24dSAdam Yang #define DXIL_OP_INTRINSIC(OpCode, Intrin, ...) \ 79194da6bfbSJustin Bogner case Intrin: \ 7920a44b24dSAdam Yang HasErrors |= replaceFunctionWithOp( \ 7930a44b24dSAdam Yang F, OpCode, ArrayRef<IntrinArgSelect>{__VA_ARGS__}); \ 79494da6bfbSJustin Bogner break; 79594da6bfbSJustin Bogner #include "DXILOperation.inc" 796aa07f922SJustin Bogner case Intrinsic::dx_resource_handlefrombinding: 79790e84113SJustin Bogner HasErrors |= lowerHandleFromBinding(F); 7983f22756fSJustin Bogner break; 7990fca76d5SJustin Bogner case Intrinsic::dx_resource_getpointer: 8000fca76d5SJustin Bogner HasErrors |= lowerGetPointer(F); 8010fca76d5SJustin Bogner break; 802aa07f922SJustin Bogner case Intrinsic::dx_resource_load_typedbuffer: 80334e20f18SJustin Bogner HasErrors |= lowerTypedBufferLoad(F, /*HasCheckBit=*/true); 80490e84113SJustin Bogner break; 805aa07f922SJustin Bogner case Intrinsic::dx_resource_store_typedbuffer: 806*0e51b54bSJustin Bogner HasErrors |= lowerBufferStore(F, /*IsRaw=*/false); 8073f22756fSJustin Bogner break; 808cba9bd5cSJustin Bogner case Intrinsic::dx_resource_load_rawbuffer: 809cba9bd5cSJustin Bogner HasErrors |= lowerRawBufferLoad(F); 810cba9bd5cSJustin Bogner break; 811*0e51b54bSJustin Bogner case Intrinsic::dx_resource_store_rawbuffer: 812*0e51b54bSJustin Bogner HasErrors |= lowerBufferStore(F, /*IsRaw=*/true); 813*0e51b54bSJustin Bogner break; 814aa07f922SJustin Bogner case Intrinsic::dx_resource_updatecounter: 8151f250999Sjoaosaffran HasErrors |= lowerUpdateCounter(F); 8161f250999Sjoaosaffran break; 817481bce01Sjoaosaffran // TODO: this can be removed when 818481bce01Sjoaosaffran // https://github.com/llvm/llvm-project/issues/113192 is fixed 819481bce01Sjoaosaffran case Intrinsic::dx_splitdouble: 820481bce01Sjoaosaffran HasErrors |= replaceFunctionWithNamedStructOp( 821481bce01Sjoaosaffran F, OpCode::SplitDouble, 822481bce01Sjoaosaffran OpBuilder.getSplitDoubleType(M.getContext()), 823481bce01Sjoaosaffran [&](CallInst *CI, CallInst *Op) { 824481bce01Sjoaosaffran return replaceSplitDoubleCallUsages(CI, Op); 825481bce01Sjoaosaffran }); 826481bce01Sjoaosaffran break; 82775e7ba8cSSarah Spall case Intrinsic::ctpop: 82875e7ba8cSSarah Spall HasErrors |= lowerCtpopToCountBits(F); 82975e7ba8cSSarah Spall break; 83094da6bfbSJustin Bogner } 83185285be9SXiang Li Updated = true; 83285285be9SXiang Li } 83390e84113SJustin Bogner if (Updated && !HasErrors) 834aa61925eSJustin Bogner cleanupHandleCasts(); 835aa61925eSJustin Bogner 83685285be9SXiang Li return Updated; 83785285be9SXiang Li } 838e56ad22bSJustin Bogner }; 839e56ad22bSJustin Bogner } // namespace 84085285be9SXiang Li 841aa61925eSJustin Bogner PreservedAnalyses DXILOpLowering::run(Module &M, ModuleAnalysisManager &MAM) { 8423eca15cbSJustin Bogner DXILBindingMap &DBM = MAM.getResult<DXILResourceBindingAnalysis>(M); 8433eca15cbSJustin Bogner DXILResourceTypeMap &DRTM = MAM.getResult<DXILResourceTypeAnalysis>(M); 844aa61925eSJustin Bogner 8453eca15cbSJustin Bogner bool MadeChanges = OpLowerer(M, DBM, DRTM).lowerIntrinsics(); 846aa61925eSJustin Bogner if (!MadeChanges) 84785285be9SXiang Li return PreservedAnalyses::all(); 848aa61925eSJustin Bogner PreservedAnalyses PA; 8493eca15cbSJustin Bogner PA.preserve<DXILResourceBindingAnalysis>(); 850bfd05102SJustin Bogner PA.preserve<DXILMetadataAnalysis>(); 851bfd05102SJustin Bogner PA.preserve<ShaderFlagsAnalysis>(); 852aa61925eSJustin Bogner return PA; 85385285be9SXiang Li } 85485285be9SXiang Li 85585285be9SXiang Li namespace { 85685285be9SXiang Li class DXILOpLoweringLegacy : public ModulePass { 85785285be9SXiang Li public: 858e56ad22bSJustin Bogner bool runOnModule(Module &M) override { 8593eca15cbSJustin Bogner DXILBindingMap &DBM = 8603eca15cbSJustin Bogner getAnalysis<DXILResourceBindingWrapperPass>().getBindingMap(); 8613eca15cbSJustin Bogner DXILResourceTypeMap &DRTM = 8623eca15cbSJustin Bogner getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap(); 863aa61925eSJustin Bogner 8643eca15cbSJustin Bogner return OpLowerer(M, DBM, DRTM).lowerIntrinsics(); 865e56ad22bSJustin Bogner } 86685285be9SXiang Li StringRef getPassName() const override { return "DXIL Op Lowering"; } 86785285be9SXiang Li DXILOpLoweringLegacy() : ModulePass(ID) {} 86885285be9SXiang Li 86985285be9SXiang Li static char ID; // Pass identification. 870de1a97dbSFarzon Lotfi void getAnalysisUsage(llvm::AnalysisUsage &AU) const override { 8713eca15cbSJustin Bogner AU.addRequired<DXILResourceTypeWrapperPass>(); 8723eca15cbSJustin Bogner AU.addRequired<DXILResourceBindingWrapperPass>(); 8733eca15cbSJustin Bogner AU.addPreserved<DXILResourceBindingWrapperPass>(); 874bfd05102SJustin Bogner AU.addPreserved<DXILResourceMDWrapper>(); 875bfd05102SJustin Bogner AU.addPreserved<DXILMetadataAnalysisWrapperPass>(); 876bfd05102SJustin Bogner AU.addPreserved<ShaderFlagsAnalysisWrapper>(); 877de1a97dbSFarzon Lotfi } 87885285be9SXiang Li }; 87985285be9SXiang Li char DXILOpLoweringLegacy::ID = 0; 88085285be9SXiang Li } // end anonymous namespace 88185285be9SXiang Li 88285285be9SXiang Li INITIALIZE_PASS_BEGIN(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering", 88385285be9SXiang Li false, false) 8843eca15cbSJustin Bogner INITIALIZE_PASS_DEPENDENCY(DXILResourceTypeWrapperPass) 8853eca15cbSJustin Bogner INITIALIZE_PASS_DEPENDENCY(DXILResourceBindingWrapperPass) 88685285be9SXiang Li INITIALIZE_PASS_END(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering", false, 88785285be9SXiang Li false) 88885285be9SXiang Li 88985285be9SXiang Li ModulePass *llvm::createDXILOpLoweringLegacyPass() { 89085285be9SXiang Li return new DXILOpLoweringLegacy(); 89185285be9SXiang Li } 892