xref: /llvm-project/llvm/lib/Target/DirectX/DXILOpLowering.cpp (revision 0e51b54b7ac02b0920e20b8ccae26b32bd6b6982)
181ee3855SJustin Bogner //===- DXILOpLowering.cpp - Lowering to DXIL operations -------------------===//
285285be9SXiang Li //
385285be9SXiang Li // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
485285be9SXiang Li // See https://llvm.org/LICENSE.txt for license information.
585285be9SXiang Li // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
685285be9SXiang Li //
785285be9SXiang Li //===----------------------------------------------------------------------===//
885285be9SXiang Li 
981ee3855SJustin Bogner #include "DXILOpLowering.h"
1085285be9SXiang Li #include "DXILConstants.h"
11de1a97dbSFarzon Lotfi #include "DXILIntrinsicExpansion.h"
1257006b14SXiang Li #include "DXILOpBuilder.h"
13bfd05102SJustin Bogner #include "DXILResourceAnalysis.h"
14bfd05102SJustin Bogner #include "DXILShaderFlags.h"
1585285be9SXiang Li #include "DirectX.h"
1685285be9SXiang Li #include "llvm/ADT/SmallVector.h"
17bfd05102SJustin Bogner #include "llvm/Analysis/DXILMetadataAnalysis.h"
18aa61925eSJustin Bogner #include "llvm/Analysis/DXILResource.h"
1985285be9SXiang Li #include "llvm/CodeGen/Passes.h"
208cf85653SJustin Bogner #include "llvm/IR/DiagnosticInfo.h"
2185285be9SXiang Li #include "llvm/IR/IRBuilder.h"
2285285be9SXiang Li #include "llvm/IR/Instruction.h"
23481bce01Sjoaosaffran #include "llvm/IR/Instructions.h"
2485285be9SXiang Li #include "llvm/IR/Intrinsics.h"
2543dc3190SXiang Li #include "llvm/IR/IntrinsicsDirectX.h"
2685285be9SXiang Li #include "llvm/IR/Module.h"
2785285be9SXiang Li #include "llvm/IR/PassManager.h"
28aa61925eSJustin Bogner #include "llvm/InitializePasses.h"
2985285be9SXiang Li #include "llvm/Pass.h"
3085285be9SXiang Li #include "llvm/Support/ErrorHandling.h"
3185285be9SXiang Li 
3285285be9SXiang Li #define DEBUG_TYPE "dxil-op-lower"
3385285be9SXiang Li 
3485285be9SXiang Li using namespace llvm;
35e77c40ffSChris Bieneman using namespace llvm::dxil;
3685285be9SXiang Li 
37060df78cSFarzon Lotfi static bool isVectorArgExpansion(Function &F) {
38060df78cSFarzon Lotfi   switch (F.getIntrinsicID()) {
39060df78cSFarzon Lotfi   case Intrinsic::dx_dot2:
40060df78cSFarzon Lotfi   case Intrinsic::dx_dot3:
41060df78cSFarzon Lotfi   case Intrinsic::dx_dot4:
42060df78cSFarzon Lotfi     return true;
43060df78cSFarzon Lotfi   }
44060df78cSFarzon Lotfi   return false;
45060df78cSFarzon Lotfi }
46060df78cSFarzon Lotfi 
47060df78cSFarzon Lotfi static SmallVector<Value *> populateOperands(Value *Arg, IRBuilder<> &Builder) {
48ba8a2adeSJessica Clarke   SmallVector<Value *> ExtractedElements;
49060df78cSFarzon Lotfi   auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType());
50060df78cSFarzon Lotfi   for (unsigned I = 0; I < VecArg->getNumElements(); ++I) {
51060df78cSFarzon Lotfi     Value *Index = ConstantInt::get(Type::getInt32Ty(Arg->getContext()), I);
52060df78cSFarzon Lotfi     Value *ExtractedElement = Builder.CreateExtractElement(Arg, Index);
53060df78cSFarzon Lotfi     ExtractedElements.push_back(ExtractedElement);
54060df78cSFarzon Lotfi   }
55060df78cSFarzon Lotfi   return ExtractedElements;
56060df78cSFarzon Lotfi }
57060df78cSFarzon Lotfi 
58060df78cSFarzon Lotfi static SmallVector<Value *> argVectorFlatten(CallInst *Orig,
59060df78cSFarzon Lotfi                                              IRBuilder<> &Builder) {
60060df78cSFarzon Lotfi   // Note: arg[NumOperands-1] is a pointer and is not needed by our flattening.
61060df78cSFarzon Lotfi   unsigned NumOperands = Orig->getNumOperands() - 1;
62060df78cSFarzon Lotfi   assert(NumOperands > 0);
63060df78cSFarzon Lotfi   Value *Arg0 = Orig->getOperand(0);
64060df78cSFarzon Lotfi   [[maybe_unused]] auto *VecArg0 = dyn_cast<FixedVectorType>(Arg0->getType());
65060df78cSFarzon Lotfi   assert(VecArg0);
66060df78cSFarzon Lotfi   SmallVector<Value *> NewOperands = populateOperands(Arg0, Builder);
67060df78cSFarzon Lotfi   for (unsigned I = 1; I < NumOperands; ++I) {
68060df78cSFarzon Lotfi     Value *Arg = Orig->getOperand(I);
69060df78cSFarzon Lotfi     [[maybe_unused]] auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType());
70060df78cSFarzon Lotfi     assert(VecArg);
71060df78cSFarzon Lotfi     assert(VecArg0->getElementType() == VecArg->getElementType());
72060df78cSFarzon Lotfi     assert(VecArg0->getNumElements() == VecArg->getNumElements());
73060df78cSFarzon Lotfi     auto NextOperandList = populateOperands(Arg, Builder);
74060df78cSFarzon Lotfi     NewOperands.append(NextOperandList.begin(), NextOperandList.end());
75060df78cSFarzon Lotfi   }
76060df78cSFarzon Lotfi   return NewOperands;
77060df78cSFarzon Lotfi }
78060df78cSFarzon Lotfi 
79e56ad22bSJustin Bogner namespace {
80e56ad22bSJustin Bogner class OpLowerer {
81e56ad22bSJustin Bogner   Module &M;
82e56ad22bSJustin Bogner   DXILOpBuilder OpBuilder;
833eca15cbSJustin Bogner   DXILBindingMap &DBM;
843eca15cbSJustin Bogner   DXILResourceTypeMap &DRTM;
85aa61925eSJustin Bogner   SmallVector<CallInst *> CleanupCasts;
86e56ad22bSJustin Bogner 
87e56ad22bSJustin Bogner public:
883eca15cbSJustin Bogner   OpLowerer(Module &M, DXILBindingMap &DBM, DXILResourceTypeMap &DRTM)
893eca15cbSJustin Bogner       : M(M), OpBuilder(M), DBM(DBM), DRTM(DRTM) {}
90e56ad22bSJustin Bogner 
9190e84113SJustin Bogner   /// Replace every call to \c F using \c ReplaceCall, and then erase \c F. If
9290e84113SJustin Bogner   /// there is an error replacing a call, we emit a diagnostic and return true.
9390e84113SJustin Bogner   [[nodiscard]] bool
9490e84113SJustin Bogner   replaceFunction(Function &F,
95e56ad22bSJustin Bogner                   llvm::function_ref<Error(CallInst *CI)> ReplaceCall) {
9685285be9SXiang Li     for (User *U : make_early_inc_range(F.users())) {
9785285be9SXiang Li       CallInst *CI = dyn_cast<CallInst>(U);
9885285be9SXiang Li       if (!CI)
9985285be9SXiang Li         continue;
10085285be9SXiang Li 
101e56ad22bSJustin Bogner       if (Error E = ReplaceCall(CI)) {
1028cf85653SJustin Bogner         std::string Message(toString(std::move(E)));
1038cf85653SJustin Bogner         DiagnosticInfoUnsupported Diag(*CI->getFunction(), Message,
1048cf85653SJustin Bogner                                        CI->getDebugLoc());
1058cf85653SJustin Bogner         M.getContext().diagnose(Diag);
10690e84113SJustin Bogner         return true;
1078cf85653SJustin Bogner       }
10885285be9SXiang Li     }
10985285be9SXiang Li     if (F.user_empty())
11085285be9SXiang Li       F.eraseFromParent();
11190e84113SJustin Bogner     return false;
11285285be9SXiang Li   }
11385285be9SXiang Li 
1140a44b24dSAdam Yang   struct IntrinArgSelect {
1150a44b24dSAdam Yang     enum class Type {
1160a44b24dSAdam Yang #define DXIL_OP_INTRINSIC_ARG_SELECT_TYPE(name) name,
1170a44b24dSAdam Yang #include "DXILOperation.inc"
1180a44b24dSAdam Yang     };
1190a44b24dSAdam Yang     Type Type;
1200a44b24dSAdam Yang     int Value;
1210a44b24dSAdam Yang   };
1220a44b24dSAdam Yang 
1230a44b24dSAdam Yang   [[nodiscard]] bool
1240a44b24dSAdam Yang   replaceFunctionWithOp(Function &F, dxil::OpCode DXILOp,
1250a44b24dSAdam Yang                         ArrayRef<IntrinArgSelect> ArgSelects) {
126e56ad22bSJustin Bogner     bool IsVectorArgExpansion = isVectorArgExpansion(F);
1270a44b24dSAdam Yang     assert(!(IsVectorArgExpansion && ArgSelects.size()) &&
1280a44b24dSAdam Yang            "Cann't do vector arg expansion when using arg selects.");
12990e84113SJustin Bogner     return replaceFunction(F, [&](CallInst *CI) -> Error {
130948249d8SAdam Yang       OpBuilder.getIRB().SetInsertPoint(CI);
1310a44b24dSAdam Yang       SmallVector<Value *> Args;
1320a44b24dSAdam Yang       if (ArgSelects.size()) {
1330a44b24dSAdam Yang         for (const IntrinArgSelect &A : ArgSelects) {
1340a44b24dSAdam Yang           switch (A.Type) {
1350a44b24dSAdam Yang           case IntrinArgSelect::Type::Index:
1360a44b24dSAdam Yang             Args.push_back(CI->getArgOperand(A.Value));
1370a44b24dSAdam Yang             break;
1380a44b24dSAdam Yang           case IntrinArgSelect::Type::I8:
1390a44b24dSAdam Yang             Args.push_back(OpBuilder.getIRB().getInt8((uint8_t)A.Value));
1400a44b24dSAdam Yang             break;
1410a44b24dSAdam Yang           case IntrinArgSelect::Type::I32:
1420a44b24dSAdam Yang             Args.push_back(OpBuilder.getIRB().getInt32(A.Value));
1430a44b24dSAdam Yang             break;
1440a44b24dSAdam Yang           }
1450a44b24dSAdam Yang         }
1460a44b24dSAdam Yang       } else if (IsVectorArgExpansion) {
1470a44b24dSAdam Yang         Args = argVectorFlatten(CI, OpBuilder.getIRB());
1480a44b24dSAdam Yang       } else {
149e56ad22bSJustin Bogner         Args.append(CI->arg_begin(), CI->arg_end());
1500a44b24dSAdam Yang       }
151e56ad22bSJustin Bogner 
152e56ad22bSJustin Bogner       Expected<CallInst *> OpCall =
1533d129016SJustin Bogner           OpBuilder.tryCreateOp(DXILOp, Args, CI->getName(), F.getReturnType());
154e56ad22bSJustin Bogner       if (Error E = OpCall.takeError())
155e56ad22bSJustin Bogner         return E;
156e56ad22bSJustin Bogner 
157e56ad22bSJustin Bogner       CI->replaceAllUsesWith(*OpCall);
158e56ad22bSJustin Bogner       CI->eraseFromParent();
159e56ad22bSJustin Bogner       return Error::success();
160e56ad22bSJustin Bogner     });
161e56ad22bSJustin Bogner   }
162e56ad22bSJustin Bogner 
163481bce01Sjoaosaffran   [[nodiscard]] bool replaceFunctionWithNamedStructOp(
164481bce01Sjoaosaffran       Function &F, dxil::OpCode DXILOp, Type *NewRetTy,
165481bce01Sjoaosaffran       llvm::function_ref<Error(CallInst *CI, CallInst *Op)> ReplaceUses) {
166481bce01Sjoaosaffran     bool IsVectorArgExpansion = isVectorArgExpansion(F);
167481bce01Sjoaosaffran     return replaceFunction(F, [&](CallInst *CI) -> Error {
168481bce01Sjoaosaffran       SmallVector<Value *> Args;
169481bce01Sjoaosaffran       OpBuilder.getIRB().SetInsertPoint(CI);
170481bce01Sjoaosaffran       if (IsVectorArgExpansion) {
171481bce01Sjoaosaffran         SmallVector<Value *> NewArgs = argVectorFlatten(CI, OpBuilder.getIRB());
172481bce01Sjoaosaffran         Args.append(NewArgs.begin(), NewArgs.end());
173481bce01Sjoaosaffran       } else
174481bce01Sjoaosaffran         Args.append(CI->arg_begin(), CI->arg_end());
175481bce01Sjoaosaffran 
176481bce01Sjoaosaffran       Expected<CallInst *> OpCall =
177481bce01Sjoaosaffran           OpBuilder.tryCreateOp(DXILOp, Args, CI->getName(), NewRetTy);
178481bce01Sjoaosaffran       if (Error E = OpCall.takeError())
179481bce01Sjoaosaffran         return E;
180481bce01Sjoaosaffran       if (Error E = ReplaceUses(CI, *OpCall))
181481bce01Sjoaosaffran         return E;
182481bce01Sjoaosaffran 
183481bce01Sjoaosaffran       return Error::success();
184481bce01Sjoaosaffran     });
185481bce01Sjoaosaffran   }
186481bce01Sjoaosaffran 
187aa61925eSJustin Bogner   /// Create a cast between a `target("dx")` type and `dx.types.Handle`, which
188aa61925eSJustin Bogner   /// is intended to be removed by the end of lowering. This is used to allow
189aa61925eSJustin Bogner   /// lowering of ops which need to change their return or argument types in a
190aa61925eSJustin Bogner   /// piecemeal way - we can add the casts in to avoid updating all of the uses
191aa61925eSJustin Bogner   /// or defs, and by the end all of the casts will be redundant.
192aa61925eSJustin Bogner   Value *createTmpHandleCast(Value *V, Type *Ty) {
19385c17e40SJay Foad     CallInst *Cast = OpBuilder.getIRB().CreateIntrinsic(
194aa07f922SJustin Bogner         Intrinsic::dx_resource_casthandle, {Ty, V->getType()}, {V});
195aa61925eSJustin Bogner     CleanupCasts.push_back(Cast);
196aa61925eSJustin Bogner     return Cast;
197aa61925eSJustin Bogner   }
198aa61925eSJustin Bogner 
199aa61925eSJustin Bogner   void cleanupHandleCasts() {
200aa61925eSJustin Bogner     SmallVector<CallInst *> ToRemove;
201aa61925eSJustin Bogner     SmallVector<Function *> CastFns;
202aa61925eSJustin Bogner 
203aa61925eSJustin Bogner     for (CallInst *Cast : CleanupCasts) {
204aa61925eSJustin Bogner       // These casts were only put in to ease the move from `target("dx")` types
205aa61925eSJustin Bogner       // to `dx.types.Handle in a piecemeal way. At this point, all of the
206aa61925eSJustin Bogner       // non-cast uses should now be `dx.types.Handle`, and remaining casts
207aa61925eSJustin Bogner       // should all form pairs to and from the now unused `target("dx")` type.
208aa61925eSJustin Bogner       CastFns.push_back(Cast->getCalledFunction());
209aa61925eSJustin Bogner 
210aa61925eSJustin Bogner       // If the cast is not to `dx.types.Handle`, it should be the first part of
211aa61925eSJustin Bogner       // the pair. Keep track so we can remove it once it has no more uses.
212aa61925eSJustin Bogner       if (Cast->getType() != OpBuilder.getHandleType()) {
213aa61925eSJustin Bogner         ToRemove.push_back(Cast);
214aa61925eSJustin Bogner         continue;
215aa61925eSJustin Bogner       }
216aa61925eSJustin Bogner       // Otherwise, we're the second handle in a pair. Forward the arguments and
217aa61925eSJustin Bogner       // remove the (second) cast.
218aa61925eSJustin Bogner       CallInst *Def = cast<CallInst>(Cast->getOperand(0));
219aa07f922SJustin Bogner       assert(Def->getIntrinsicID() == Intrinsic::dx_resource_casthandle &&
220aa61925eSJustin Bogner              "Unbalanced pair of temporary handle casts");
221aa61925eSJustin Bogner       Cast->replaceAllUsesWith(Def->getOperand(0));
222aa61925eSJustin Bogner       Cast->eraseFromParent();
223aa61925eSJustin Bogner     }
224aa61925eSJustin Bogner     for (CallInst *Cast : ToRemove) {
225aa61925eSJustin Bogner       assert(Cast->user_empty() && "Temporary handle cast still has users");
226aa61925eSJustin Bogner       Cast->eraseFromParent();
227aa61925eSJustin Bogner     }
228aa61925eSJustin Bogner 
229aa61925eSJustin Bogner     // Deduplicate the cast functions so that we only erase each one once.
230aa61925eSJustin Bogner     llvm::sort(CastFns);
231aa61925eSJustin Bogner     CastFns.erase(llvm::unique(CastFns), CastFns.end());
232aa61925eSJustin Bogner     for (Function *F : CastFns)
233aa61925eSJustin Bogner       F->eraseFromParent();
234aa61925eSJustin Bogner 
235aa61925eSJustin Bogner     CleanupCasts.clear();
236aa61925eSJustin Bogner   }
237aa61925eSJustin Bogner 
23847ef3a09SGreg Roth   // Remove the resource global associated with the handleFromBinding call
23947ef3a09SGreg Roth   // instruction and their uses as they aren't needed anymore.
24047ef3a09SGreg Roth   // TODO: We should verify that all the globals get removed.
24147ef3a09SGreg Roth   // It's expected we'll need a custom pass in the future that will eliminate
24247ef3a09SGreg Roth   // the need for this here.
24347ef3a09SGreg Roth   void removeResourceGlobals(CallInst *CI) {
24447ef3a09SGreg Roth     for (User *User : make_early_inc_range(CI->users())) {
24547ef3a09SGreg Roth       if (StoreInst *Store = dyn_cast<StoreInst>(User)) {
24647ef3a09SGreg Roth         Value *V = Store->getOperand(1);
24747ef3a09SGreg Roth         Store->eraseFromParent();
24847ef3a09SGreg Roth         if (GlobalVariable *GV = dyn_cast<GlobalVariable>(V))
24947ef3a09SGreg Roth           if (GV->use_empty()) {
25047ef3a09SGreg Roth             GV->removeDeadConstantUsers();
25147ef3a09SGreg Roth             GV->eraseFromParent();
25247ef3a09SGreg Roth           }
25347ef3a09SGreg Roth       }
25447ef3a09SGreg Roth     }
25547ef3a09SGreg Roth   }
25647ef3a09SGreg Roth 
25790e84113SJustin Bogner   [[nodiscard]] bool lowerToCreateHandle(Function &F) {
258aa61925eSJustin Bogner     IRBuilder<> &IRB = OpBuilder.getIRB();
259aa61925eSJustin Bogner     Type *Int8Ty = IRB.getInt8Ty();
260aa61925eSJustin Bogner     Type *Int32Ty = IRB.getInt32Ty();
261aa61925eSJustin Bogner 
26290e84113SJustin Bogner     return replaceFunction(F, [&](CallInst *CI) -> Error {
263aa61925eSJustin Bogner       IRB.SetInsertPoint(CI);
264aa61925eSJustin Bogner 
2653eca15cbSJustin Bogner       auto *It = DBM.find(CI);
2663eca15cbSJustin Bogner       assert(It != DBM.end() && "Resource not in map?");
2673eca15cbSJustin Bogner       dxil::ResourceBindingInfo &RI = *It;
2683eca15cbSJustin Bogner 
269aa61925eSJustin Bogner       const auto &Binding = RI.getBinding();
2703eca15cbSJustin Bogner       dxil::ResourceClass RC = DRTM[RI.getHandleTy()].getResourceClass();
271aa61925eSJustin Bogner 
272bb88fd17SJustin Bogner       Value *IndexOp = CI->getArgOperand(3);
273bb88fd17SJustin Bogner       if (Binding.LowerBound != 0)
274bb88fd17SJustin Bogner         IndexOp = IRB.CreateAdd(IndexOp,
275bb88fd17SJustin Bogner                                 ConstantInt::get(Int32Ty, Binding.LowerBound));
276bb88fd17SJustin Bogner 
277aa61925eSJustin Bogner       std::array<Value *, 4> Args{
2783eca15cbSJustin Bogner           ConstantInt::get(Int8Ty, llvm::to_underlying(RC)),
279bb88fd17SJustin Bogner           ConstantInt::get(Int32Ty, Binding.RecordID), IndexOp,
280aa61925eSJustin Bogner           CI->getArgOperand(4)};
281aa61925eSJustin Bogner       Expected<CallInst *> OpCall =
2823d129016SJustin Bogner           OpBuilder.tryCreateOp(OpCode::CreateHandle, Args, CI->getName());
283aa61925eSJustin Bogner       if (Error E = OpCall.takeError())
284aa61925eSJustin Bogner         return E;
285aa61925eSJustin Bogner 
286aa61925eSJustin Bogner       Value *Cast = createTmpHandleCast(*OpCall, CI->getType());
287aa61925eSJustin Bogner 
28847ef3a09SGreg Roth       removeResourceGlobals(CI);
28947ef3a09SGreg Roth 
290aa61925eSJustin Bogner       CI->replaceAllUsesWith(Cast);
291aa61925eSJustin Bogner       CI->eraseFromParent();
292aa61925eSJustin Bogner       return Error::success();
293aa61925eSJustin Bogner     });
294aa61925eSJustin Bogner   }
295aa61925eSJustin Bogner 
29690e84113SJustin Bogner   [[nodiscard]] bool lowerToBindAndAnnotateHandle(Function &F) {
297aa61925eSJustin Bogner     IRBuilder<> &IRB = OpBuilder.getIRB();
298bb88fd17SJustin Bogner     Type *Int32Ty = IRB.getInt32Ty();
299aa61925eSJustin Bogner 
30090e84113SJustin Bogner     return replaceFunction(F, [&](CallInst *CI) -> Error {
301aa61925eSJustin Bogner       IRB.SetInsertPoint(CI);
302aa61925eSJustin Bogner 
3033eca15cbSJustin Bogner       auto *It = DBM.find(CI);
3043eca15cbSJustin Bogner       assert(It != DBM.end() && "Resource not in map?");
3053eca15cbSJustin Bogner       dxil::ResourceBindingInfo &RI = *It;
306aa61925eSJustin Bogner 
307aa61925eSJustin Bogner       const auto &Binding = RI.getBinding();
3083eca15cbSJustin Bogner       dxil::ResourceTypeInfo &RTI = DRTM[RI.getHandleTy()];
3093eca15cbSJustin Bogner       dxil::ResourceClass RC = RTI.getResourceClass();
310bb88fd17SJustin Bogner 
311bb88fd17SJustin Bogner       Value *IndexOp = CI->getArgOperand(3);
312bb88fd17SJustin Bogner       if (Binding.LowerBound != 0)
313bb88fd17SJustin Bogner         IndexOp = IRB.CreateAdd(IndexOp,
314bb88fd17SJustin Bogner                                 ConstantInt::get(Int32Ty, Binding.LowerBound));
315bb88fd17SJustin Bogner 
3163eca15cbSJustin Bogner       std::pair<uint32_t, uint32_t> Props =
3173eca15cbSJustin Bogner           RI.getAnnotateProps(*F.getParent(), RTI);
318aa61925eSJustin Bogner 
319aa61925eSJustin Bogner       // For `CreateHandleFromBinding` we need the upper bound rather than the
320aa61925eSJustin Bogner       // size, so we need to be careful about the difference for "unbounded".
321aa61925eSJustin Bogner       uint32_t Unbounded = std::numeric_limits<uint32_t>::max();
322aa61925eSJustin Bogner       uint32_t UpperBound = Binding.Size == Unbounded
323aa61925eSJustin Bogner                                 ? Unbounded
324aa61925eSJustin Bogner                                 : Binding.LowerBound + Binding.Size - 1;
3253eca15cbSJustin Bogner       Constant *ResBind = OpBuilder.getResBind(Binding.LowerBound, UpperBound,
3263eca15cbSJustin Bogner                                                Binding.Space, RC);
327bb88fd17SJustin Bogner       std::array<Value *, 3> BindArgs{ResBind, IndexOp, CI->getArgOperand(4)};
3283d129016SJustin Bogner       Expected<CallInst *> OpBind = OpBuilder.tryCreateOp(
3293d129016SJustin Bogner           OpCode::CreateHandleFromBinding, BindArgs, CI->getName());
330aa61925eSJustin Bogner       if (Error E = OpBind.takeError())
331aa61925eSJustin Bogner         return E;
332aa61925eSJustin Bogner 
333aa61925eSJustin Bogner       std::array<Value *, 2> AnnotateArgs{
334aa61925eSJustin Bogner           *OpBind, OpBuilder.getResProps(Props.first, Props.second)};
3353d129016SJustin Bogner       Expected<CallInst *> OpAnnotate = OpBuilder.tryCreateOp(
3363d129016SJustin Bogner           OpCode::AnnotateHandle, AnnotateArgs,
3373d129016SJustin Bogner           CI->hasName() ? CI->getName() + "_annot" : Twine());
338aa61925eSJustin Bogner       if (Error E = OpAnnotate.takeError())
339aa61925eSJustin Bogner         return E;
340aa61925eSJustin Bogner 
341aa61925eSJustin Bogner       Value *Cast = createTmpHandleCast(*OpAnnotate, CI->getType());
342aa61925eSJustin Bogner 
34347ef3a09SGreg Roth       removeResourceGlobals(CI);
34447ef3a09SGreg Roth 
345aa61925eSJustin Bogner       CI->replaceAllUsesWith(Cast);
346aa61925eSJustin Bogner       CI->eraseFromParent();
347aa61925eSJustin Bogner 
348aa61925eSJustin Bogner       return Error::success();
349aa61925eSJustin Bogner     });
350aa61925eSJustin Bogner   }
351aa61925eSJustin Bogner 
352aa07f922SJustin Bogner   /// Lower `dx.resource.handlefrombinding` intrinsics depending on the shader
353aa07f922SJustin Bogner   /// model and taking into account binding information from
354aa07f922SJustin Bogner   /// DXILResourceBindingAnalysis.
35590e84113SJustin Bogner   bool lowerHandleFromBinding(Function &F) {
356aa61925eSJustin Bogner     Triple TT(Triple(M.getTargetTriple()));
357aa61925eSJustin Bogner     if (TT.getDXILVersion() < VersionTuple(1, 6))
35890e84113SJustin Bogner       return lowerToCreateHandle(F);
35990e84113SJustin Bogner     return lowerToBindAndAnnotateHandle(F);
360aa61925eSJustin Bogner   }
361aa61925eSJustin Bogner 
362481bce01Sjoaosaffran   Error replaceSplitDoubleCallUsages(CallInst *Intrin, CallInst *Op) {
363481bce01Sjoaosaffran     for (Use &U : make_early_inc_range(Intrin->uses())) {
364481bce01Sjoaosaffran       if (auto *EVI = dyn_cast<ExtractValueInst>(U.getUser())) {
365481bce01Sjoaosaffran 
366481bce01Sjoaosaffran         if (EVI->getNumIndices() != 1)
367481bce01Sjoaosaffran           return createStringError(std::errc::invalid_argument,
368481bce01Sjoaosaffran                                    "Splitdouble has only 2 elements");
369481bce01Sjoaosaffran         EVI->setOperand(0, Op);
370481bce01Sjoaosaffran       } else {
371481bce01Sjoaosaffran         return make_error<StringError>(
372481bce01Sjoaosaffran             "Splitdouble use is not ExtractValueInst",
373481bce01Sjoaosaffran             inconvertibleErrorCode());
374481bce01Sjoaosaffran       }
375481bce01Sjoaosaffran     }
376481bce01Sjoaosaffran 
377481bce01Sjoaosaffran     Intrin->eraseFromParent();
378481bce01Sjoaosaffran 
379481bce01Sjoaosaffran     return Error::success();
380481bce01Sjoaosaffran   }
381481bce01Sjoaosaffran 
3823f22756fSJustin Bogner   /// Replace uses of \c Intrin with the values in the `dx.ResRet` of \c Op.
3833f22756fSJustin Bogner   /// Since we expect to be post-scalarization, make an effort to avoid vectors.
38434e20f18SJustin Bogner   Error replaceResRetUses(CallInst *Intrin, CallInst *Op, bool HasCheckBit) {
3853f22756fSJustin Bogner     IRBuilder<> &IRB = OpBuilder.getIRB();
3863f22756fSJustin Bogner 
38734e20f18SJustin Bogner     Instruction *OldResult = Intrin;
3883f22756fSJustin Bogner     Type *OldTy = Intrin->getType();
3893f22756fSJustin Bogner 
39034e20f18SJustin Bogner     if (HasCheckBit) {
39134e20f18SJustin Bogner       auto *ST = cast<StructType>(OldTy);
39234e20f18SJustin Bogner 
39334e20f18SJustin Bogner       Value *CheckOp = nullptr;
39434e20f18SJustin Bogner       Type *Int32Ty = IRB.getInt32Ty();
39534e20f18SJustin Bogner       for (Use &U : make_early_inc_range(OldResult->uses())) {
39634e20f18SJustin Bogner         if (auto *EVI = dyn_cast<ExtractValueInst>(U.getUser())) {
39734e20f18SJustin Bogner           ArrayRef<unsigned> Indices = EVI->getIndices();
39834e20f18SJustin Bogner           assert(Indices.size() == 1);
39934e20f18SJustin Bogner           // We're only interested in uses of the check bit for now.
40034e20f18SJustin Bogner           if (Indices[0] != 1)
40134e20f18SJustin Bogner             continue;
40234e20f18SJustin Bogner           if (!CheckOp) {
40334e20f18SJustin Bogner             Value *NewEVI = IRB.CreateExtractValue(Op, 4);
40434e20f18SJustin Bogner             Expected<CallInst *> OpCall = OpBuilder.tryCreateOp(
4053d129016SJustin Bogner                 OpCode::CheckAccessFullyMapped, {NewEVI},
4063d129016SJustin Bogner                 OldResult->hasName() ? OldResult->getName() + "_check"
4073d129016SJustin Bogner                                      : Twine(),
4083d129016SJustin Bogner                 Int32Ty);
40934e20f18SJustin Bogner             if (Error E = OpCall.takeError())
41034e20f18SJustin Bogner               return E;
41134e20f18SJustin Bogner             CheckOp = *OpCall;
41234e20f18SJustin Bogner           }
41334e20f18SJustin Bogner           EVI->replaceAllUsesWith(CheckOp);
41434e20f18SJustin Bogner           EVI->eraseFromParent();
41534e20f18SJustin Bogner         }
41634e20f18SJustin Bogner       }
41734e20f18SJustin Bogner 
4182c7c07dfSJustin Bogner       if (OldResult->use_empty()) {
4192c7c07dfSJustin Bogner         // Only the check bit was used, so we're done here.
4202c7c07dfSJustin Bogner         OldResult->eraseFromParent();
4212c7c07dfSJustin Bogner         return Error::success();
4222c7c07dfSJustin Bogner       }
4232c7c07dfSJustin Bogner 
4242c7c07dfSJustin Bogner       assert(OldResult->hasOneUse() &&
4252c7c07dfSJustin Bogner              isa<ExtractValueInst>(*OldResult->user_begin()) &&
4262c7c07dfSJustin Bogner              "Expected only use to be extract of first element");
4272c7c07dfSJustin Bogner       OldResult = cast<Instruction>(*OldResult->user_begin());
42834e20f18SJustin Bogner       OldTy = ST->getElementType(0);
42934e20f18SJustin Bogner     }
43034e20f18SJustin Bogner 
4313f22756fSJustin Bogner     // For scalars, we just extract the first element.
4323f22756fSJustin Bogner     if (!isa<FixedVectorType>(OldTy)) {
4333f22756fSJustin Bogner       Value *EVI = IRB.CreateExtractValue(Op, 0);
43434e20f18SJustin Bogner       OldResult->replaceAllUsesWith(EVI);
43534e20f18SJustin Bogner       OldResult->eraseFromParent();
43634e20f18SJustin Bogner       if (OldResult != Intrin) {
43734e20f18SJustin Bogner         assert(Intrin->use_empty() && "Intrinsic still has uses?");
4383f22756fSJustin Bogner         Intrin->eraseFromParent();
43934e20f18SJustin Bogner       }
4403f22756fSJustin Bogner       return Error::success();
4413f22756fSJustin Bogner     }
4423f22756fSJustin Bogner 
4433f22756fSJustin Bogner     std::array<Value *, 4> Extracts = {};
4443f22756fSJustin Bogner     SmallVector<ExtractElementInst *> DynamicAccesses;
4453f22756fSJustin Bogner 
4463f22756fSJustin Bogner     // The users of the operation should all be scalarized, so we attempt to
4473f22756fSJustin Bogner     // replace the extractelements with extractvalues directly.
44834e20f18SJustin Bogner     for (Use &U : make_early_inc_range(OldResult->uses())) {
4493f22756fSJustin Bogner       if (auto *EEI = dyn_cast<ExtractElementInst>(U.getUser())) {
4503f22756fSJustin Bogner         if (auto *IndexOp = dyn_cast<ConstantInt>(EEI->getIndexOperand())) {
4513f22756fSJustin Bogner           size_t IndexVal = IndexOp->getZExtValue();
4523f22756fSJustin Bogner           assert(IndexVal < 4 && "Index into buffer load out of range");
4533f22756fSJustin Bogner           if (!Extracts[IndexVal])
4543f22756fSJustin Bogner             Extracts[IndexVal] = IRB.CreateExtractValue(Op, IndexVal);
4553f22756fSJustin Bogner           EEI->replaceAllUsesWith(Extracts[IndexVal]);
4563f22756fSJustin Bogner           EEI->eraseFromParent();
4573f22756fSJustin Bogner         } else {
4583f22756fSJustin Bogner           DynamicAccesses.push_back(EEI);
4593f22756fSJustin Bogner         }
4603f22756fSJustin Bogner       }
4613f22756fSJustin Bogner     }
4623f22756fSJustin Bogner 
4633f22756fSJustin Bogner     const auto *VecTy = cast<FixedVectorType>(OldTy);
4643f22756fSJustin Bogner     const unsigned N = VecTy->getNumElements();
4653f22756fSJustin Bogner 
4663f22756fSJustin Bogner     // If there's a dynamic access we need to round trip through stack memory so
4673f22756fSJustin Bogner     // that we don't leave vectors around.
4683f22756fSJustin Bogner     if (!DynamicAccesses.empty()) {
4693f22756fSJustin Bogner       Type *Int32Ty = IRB.getInt32Ty();
4703f22756fSJustin Bogner       Constant *Zero = ConstantInt::get(Int32Ty, 0);
4713f22756fSJustin Bogner 
4723f22756fSJustin Bogner       Type *ElTy = VecTy->getElementType();
4733f22756fSJustin Bogner       Type *ArrayTy = ArrayType::get(ElTy, N);
4743f22756fSJustin Bogner       Value *Alloca = IRB.CreateAlloca(ArrayTy);
4753f22756fSJustin Bogner 
4763f22756fSJustin Bogner       for (int I = 0, E = N; I != E; ++I) {
4773f22756fSJustin Bogner         if (!Extracts[I])
4783f22756fSJustin Bogner           Extracts[I] = IRB.CreateExtractValue(Op, I);
4793f22756fSJustin Bogner         Value *GEP = IRB.CreateInBoundsGEP(
4803f22756fSJustin Bogner             ArrayTy, Alloca, {Zero, ConstantInt::get(Int32Ty, I)});
4813f22756fSJustin Bogner         IRB.CreateStore(Extracts[I], GEP);
4823f22756fSJustin Bogner       }
4833f22756fSJustin Bogner 
4843f22756fSJustin Bogner       for (ExtractElementInst *EEI : DynamicAccesses) {
4853f22756fSJustin Bogner         Value *GEP = IRB.CreateInBoundsGEP(ArrayTy, Alloca,
4863f22756fSJustin Bogner                                            {Zero, EEI->getIndexOperand()});
4873f22756fSJustin Bogner         Value *Load = IRB.CreateLoad(ElTy, GEP);
4883f22756fSJustin Bogner         EEI->replaceAllUsesWith(Load);
4893f22756fSJustin Bogner         EEI->eraseFromParent();
4903f22756fSJustin Bogner       }
4913f22756fSJustin Bogner     }
4923f22756fSJustin Bogner 
4933f22756fSJustin Bogner     // If we still have uses, then we're not fully scalarized and need to
4943f22756fSJustin Bogner     // recreate the vector. This should only happen for things like exported
4953f22756fSJustin Bogner     // functions from libraries.
49634e20f18SJustin Bogner     if (!OldResult->use_empty()) {
4973f22756fSJustin Bogner       for (int I = 0, E = N; I != E; ++I)
4983f22756fSJustin Bogner         if (!Extracts[I])
4993f22756fSJustin Bogner           Extracts[I] = IRB.CreateExtractValue(Op, I);
5003f22756fSJustin Bogner 
5013f22756fSJustin Bogner       Value *Vec = UndefValue::get(OldTy);
5023f22756fSJustin Bogner       for (int I = 0, E = N; I != E; ++I)
5033f22756fSJustin Bogner         Vec = IRB.CreateInsertElement(Vec, Extracts[I], I);
50434e20f18SJustin Bogner       OldResult->replaceAllUsesWith(Vec);
5053f22756fSJustin Bogner     }
5063f22756fSJustin Bogner 
50734e20f18SJustin Bogner     OldResult->eraseFromParent();
50834e20f18SJustin Bogner     if (OldResult != Intrin) {
50934e20f18SJustin Bogner       assert(Intrin->use_empty() && "Intrinsic still has uses?");
5103f22756fSJustin Bogner       Intrin->eraseFromParent();
51134e20f18SJustin Bogner     }
51234e20f18SJustin Bogner 
5133f22756fSJustin Bogner     return Error::success();
5143f22756fSJustin Bogner   }
5153f22756fSJustin Bogner 
51634e20f18SJustin Bogner   [[nodiscard]] bool lowerTypedBufferLoad(Function &F, bool HasCheckBit) {
5173f22756fSJustin Bogner     IRBuilder<> &IRB = OpBuilder.getIRB();
5183f22756fSJustin Bogner     Type *Int32Ty = IRB.getInt32Ty();
5193f22756fSJustin Bogner 
52090e84113SJustin Bogner     return replaceFunction(F, [&](CallInst *CI) -> Error {
5213f22756fSJustin Bogner       IRB.SetInsertPoint(CI);
5223f22756fSJustin Bogner 
5233f22756fSJustin Bogner       Value *Handle =
5243f22756fSJustin Bogner           createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType());
5253f22756fSJustin Bogner       Value *Index0 = CI->getArgOperand(1);
5263f22756fSJustin Bogner       Value *Index1 = UndefValue::get(Int32Ty);
5273f22756fSJustin Bogner 
52834e20f18SJustin Bogner       Type *OldTy = CI->getType();
52934e20f18SJustin Bogner       if (HasCheckBit)
53034e20f18SJustin Bogner         OldTy = cast<StructType>(OldTy)->getElementType(0);
53134e20f18SJustin Bogner       Type *NewRetTy = OpBuilder.getResRetType(OldTy->getScalarType());
5323f22756fSJustin Bogner 
5333f22756fSJustin Bogner       std::array<Value *, 3> Args{Handle, Index0, Index1};
5343d129016SJustin Bogner       Expected<CallInst *> OpCall = OpBuilder.tryCreateOp(
5353d129016SJustin Bogner           OpCode::BufferLoad, Args, CI->getName(), NewRetTy);
5363f22756fSJustin Bogner       if (Error E = OpCall.takeError())
5373f22756fSJustin Bogner         return E;
53834e20f18SJustin Bogner       if (Error E = replaceResRetUses(CI, *OpCall, HasCheckBit))
5393f22756fSJustin Bogner         return E;
5403f22756fSJustin Bogner 
5413f22756fSJustin Bogner       return Error::success();
5423f22756fSJustin Bogner     });
5433f22756fSJustin Bogner   }
5443f22756fSJustin Bogner 
545cba9bd5cSJustin Bogner   [[nodiscard]] bool lowerRawBufferLoad(Function &F) {
546cba9bd5cSJustin Bogner     Triple TT(Triple(M.getTargetTriple()));
547cba9bd5cSJustin Bogner     VersionTuple DXILVersion = TT.getDXILVersion();
548cba9bd5cSJustin Bogner     const DataLayout &DL = F.getDataLayout();
549cba9bd5cSJustin Bogner     IRBuilder<> &IRB = OpBuilder.getIRB();
550cba9bd5cSJustin Bogner     Type *Int8Ty = IRB.getInt8Ty();
551cba9bd5cSJustin Bogner     Type *Int32Ty = IRB.getInt32Ty();
552cba9bd5cSJustin Bogner 
553cba9bd5cSJustin Bogner     return replaceFunction(F, [&](CallInst *CI) -> Error {
554cba9bd5cSJustin Bogner       IRB.SetInsertPoint(CI);
555cba9bd5cSJustin Bogner 
556cba9bd5cSJustin Bogner       Type *OldTy = cast<StructType>(CI->getType())->getElementType(0);
557cba9bd5cSJustin Bogner       Type *ScalarTy = OldTy->getScalarType();
558cba9bd5cSJustin Bogner       Type *NewRetTy = OpBuilder.getResRetType(ScalarTy);
559cba9bd5cSJustin Bogner 
560cba9bd5cSJustin Bogner       Value *Handle =
561cba9bd5cSJustin Bogner           createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType());
562cba9bd5cSJustin Bogner       Value *Index0 = CI->getArgOperand(1);
563cba9bd5cSJustin Bogner       Value *Index1 = CI->getArgOperand(2);
564cba9bd5cSJustin Bogner       uint64_t NumElements =
565cba9bd5cSJustin Bogner           DL.getTypeSizeInBits(OldTy) / DL.getTypeSizeInBits(ScalarTy);
566cba9bd5cSJustin Bogner       Value *Mask = ConstantInt::get(Int8Ty, ~(~0U << NumElements));
567cba9bd5cSJustin Bogner       Value *Align =
568cba9bd5cSJustin Bogner           ConstantInt::get(Int32Ty, DL.getPrefTypeAlign(ScalarTy).value());
569cba9bd5cSJustin Bogner 
570cba9bd5cSJustin Bogner       Expected<CallInst *> OpCall =
571cba9bd5cSJustin Bogner           DXILVersion >= VersionTuple(1, 2)
572cba9bd5cSJustin Bogner               ? OpBuilder.tryCreateOp(OpCode::RawBufferLoad,
573cba9bd5cSJustin Bogner                                       {Handle, Index0, Index1, Mask, Align},
574cba9bd5cSJustin Bogner                                       CI->getName(), NewRetTy)
575cba9bd5cSJustin Bogner               : OpBuilder.tryCreateOp(OpCode::BufferLoad,
576cba9bd5cSJustin Bogner                                       {Handle, Index0, Index1}, CI->getName(),
577cba9bd5cSJustin Bogner                                       NewRetTy);
578cba9bd5cSJustin Bogner       if (Error E = OpCall.takeError())
579cba9bd5cSJustin Bogner         return E;
580cba9bd5cSJustin Bogner       if (Error E = replaceResRetUses(CI, *OpCall, /*HasCheckBit=*/true))
581cba9bd5cSJustin Bogner         return E;
582cba9bd5cSJustin Bogner 
583cba9bd5cSJustin Bogner       return Error::success();
584cba9bd5cSJustin Bogner     });
585cba9bd5cSJustin Bogner   }
586cba9bd5cSJustin Bogner 
5871f250999Sjoaosaffran   [[nodiscard]] bool lowerUpdateCounter(Function &F) {
5881f250999Sjoaosaffran     IRBuilder<> &IRB = OpBuilder.getIRB();
589691bd184Sjoaosaffran     Type *Int32Ty = IRB.getInt32Ty();
5901f250999Sjoaosaffran 
5911f250999Sjoaosaffran     return replaceFunction(F, [&](CallInst *CI) -> Error {
5921f250999Sjoaosaffran       IRB.SetInsertPoint(CI);
5931f250999Sjoaosaffran       Value *Handle =
5941f250999Sjoaosaffran           createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType());
5951f250999Sjoaosaffran       Value *Op1 = CI->getArgOperand(1);
5961f250999Sjoaosaffran 
5971f250999Sjoaosaffran       std::array<Value *, 2> Args{Handle, Op1};
5981f250999Sjoaosaffran 
599691bd184Sjoaosaffran       Expected<CallInst *> OpCall = OpBuilder.tryCreateOp(
600691bd184Sjoaosaffran           OpCode::UpdateCounter, Args, CI->getName(), Int32Ty);
6011f250999Sjoaosaffran 
6021f250999Sjoaosaffran       if (Error E = OpCall.takeError())
6031f250999Sjoaosaffran         return E;
6041f250999Sjoaosaffran 
605691bd184Sjoaosaffran       CI->replaceAllUsesWith(*OpCall);
6061f250999Sjoaosaffran       CI->eraseFromParent();
6071f250999Sjoaosaffran       return Error::success();
6081f250999Sjoaosaffran     });
6091f250999Sjoaosaffran   }
6101f250999Sjoaosaffran 
6110fca76d5SJustin Bogner   [[nodiscard]] bool lowerGetPointer(Function &F) {
6120fca76d5SJustin Bogner     // These should have already been handled in DXILResourceAccess, so we can
6130fca76d5SJustin Bogner     // just clean up the dead prototype.
6140fca76d5SJustin Bogner     assert(F.user_empty() && "getpointer operations should have been removed");
6150fca76d5SJustin Bogner     F.eraseFromParent();
6160fca76d5SJustin Bogner     return false;
6170fca76d5SJustin Bogner   }
6180fca76d5SJustin Bogner 
619*0e51b54bSJustin Bogner   [[nodiscard]] bool lowerBufferStore(Function &F, bool IsRaw) {
620*0e51b54bSJustin Bogner     Triple TT(Triple(M.getTargetTriple()));
621*0e51b54bSJustin Bogner     VersionTuple DXILVersion = TT.getDXILVersion();
622*0e51b54bSJustin Bogner     const DataLayout &DL = F.getDataLayout();
62390e84113SJustin Bogner     IRBuilder<> &IRB = OpBuilder.getIRB();
62490e84113SJustin Bogner     Type *Int8Ty = IRB.getInt8Ty();
62590e84113SJustin Bogner     Type *Int32Ty = IRB.getInt32Ty();
62690e84113SJustin Bogner 
62790e84113SJustin Bogner     return replaceFunction(F, [&](CallInst *CI) -> Error {
62890e84113SJustin Bogner       IRB.SetInsertPoint(CI);
62990e84113SJustin Bogner 
63090e84113SJustin Bogner       Value *Handle =
63190e84113SJustin Bogner           createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType());
63290e84113SJustin Bogner       Value *Index0 = CI->getArgOperand(1);
633*0e51b54bSJustin Bogner       Value *Index1 = IsRaw ? CI->getArgOperand(2) : UndefValue::get(Int32Ty);
63490e84113SJustin Bogner 
635*0e51b54bSJustin Bogner       Value *Data = CI->getArgOperand(IsRaw ? 3 : 2);
636*0e51b54bSJustin Bogner       Type *DataTy = Data->getType();
637*0e51b54bSJustin Bogner       Type *ScalarTy = DataTy->getScalarType();
638*0e51b54bSJustin Bogner 
639*0e51b54bSJustin Bogner       uint64_t NumElements =
640*0e51b54bSJustin Bogner           DL.getTypeSizeInBits(DataTy) / DL.getTypeSizeInBits(ScalarTy);
641*0e51b54bSJustin Bogner       Value *Mask = ConstantInt::get(Int8Ty, ~(~0U << NumElements));
642*0e51b54bSJustin Bogner 
643*0e51b54bSJustin Bogner       // TODO: check that we only have vector or scalar...
644*0e51b54bSJustin Bogner       if (!IsRaw && NumElements != 4)
64590e84113SJustin Bogner         return make_error<StringError>(
64690e84113SJustin Bogner             "typedBufferStore data must be a vector of 4 elements",
64790e84113SJustin Bogner             inconvertibleErrorCode());
648*0e51b54bSJustin Bogner       else if (NumElements > 4)
649*0e51b54bSJustin Bogner         return make_error<StringError>(
650*0e51b54bSJustin Bogner             "rawBufferStore data must have at most 4 elements",
651*0e51b54bSJustin Bogner             inconvertibleErrorCode());
65290e84113SJustin Bogner 
6532c88ac9dSJustin Bogner       std::array<Value *, 4> DataElements{nullptr, nullptr, nullptr, nullptr};
654*0e51b54bSJustin Bogner       if (DataTy == ScalarTy)
655*0e51b54bSJustin Bogner         DataElements[0] = Data;
656*0e51b54bSJustin Bogner       else {
657*0e51b54bSJustin Bogner         // Since we're post-scalarizer, if we see a vector here it's likely
658*0e51b54bSJustin Bogner         // constructed solely for the argument of the store. Just use the scalar
659*0e51b54bSJustin Bogner         // values from before they're inserted into the temporary.
6602c88ac9dSJustin Bogner         auto *IEI = dyn_cast<InsertElementInst>(Data);
6612c88ac9dSJustin Bogner         while (IEI) {
6622c88ac9dSJustin Bogner           auto *IndexOp = dyn_cast<ConstantInt>(IEI->getOperand(2));
6632c88ac9dSJustin Bogner           if (!IndexOp)
6642c88ac9dSJustin Bogner             break;
6652c88ac9dSJustin Bogner           size_t IndexVal = IndexOp->getZExtValue();
6662c88ac9dSJustin Bogner           assert(IndexVal < 4 && "Too many elements for buffer store");
6672c88ac9dSJustin Bogner           DataElements[IndexVal] = IEI->getOperand(1);
6682c88ac9dSJustin Bogner           IEI = dyn_cast<InsertElementInst>(IEI->getOperand(0));
6692c88ac9dSJustin Bogner         }
670*0e51b54bSJustin Bogner       }
6712c88ac9dSJustin Bogner 
6722c88ac9dSJustin Bogner       // If for some reason we weren't able to forward the arguments from the
673*0e51b54bSJustin Bogner       // scalarizer artifact, then we may need to actually extract elements from
674*0e51b54bSJustin Bogner       // the vector.
675*0e51b54bSJustin Bogner       for (int I = 0, E = NumElements; I < E; ++I)
6762c88ac9dSJustin Bogner         if (DataElements[I] == nullptr)
6772c88ac9dSJustin Bogner           DataElements[I] =
6782c88ac9dSJustin Bogner               IRB.CreateExtractElement(Data, ConstantInt::get(Int32Ty, I));
679*0e51b54bSJustin Bogner       // For any elements beyond the length of the vector, fill up with undef.
680*0e51b54bSJustin Bogner       for (int I = NumElements, E = 4; I < E; ++I)
681*0e51b54bSJustin Bogner         if (DataElements[I] == nullptr)
682*0e51b54bSJustin Bogner           DataElements[I] = UndefValue::get(ScalarTy);
6832c88ac9dSJustin Bogner 
684*0e51b54bSJustin Bogner       dxil::OpCode Op = OpCode::BufferStore;
685*0e51b54bSJustin Bogner       SmallVector<Value *, 9> Args{
6862c88ac9dSJustin Bogner           Handle,          Index0,          Index1,          DataElements[0],
6872c88ac9dSJustin Bogner           DataElements[1], DataElements[2], DataElements[3], Mask};
688*0e51b54bSJustin Bogner       if (IsRaw && DXILVersion >= VersionTuple(1, 2)) {
689*0e51b54bSJustin Bogner         Op = OpCode::RawBufferStore;
690*0e51b54bSJustin Bogner         // RawBufferStore requires the alignment
691*0e51b54bSJustin Bogner         Args.push_back(
692*0e51b54bSJustin Bogner             ConstantInt::get(Int32Ty, DL.getPrefTypeAlign(ScalarTy).value()));
693*0e51b54bSJustin Bogner       }
69490e84113SJustin Bogner       Expected<CallInst *> OpCall =
695*0e51b54bSJustin Bogner           OpBuilder.tryCreateOp(Op, Args, CI->getName());
69690e84113SJustin Bogner       if (Error E = OpCall.takeError())
69790e84113SJustin Bogner         return E;
69890e84113SJustin Bogner 
69990e84113SJustin Bogner       CI->eraseFromParent();
7002c88ac9dSJustin Bogner       // Clean up any leftover `insertelement`s
701*0e51b54bSJustin Bogner       auto *IEI = dyn_cast<InsertElementInst>(Data);
7022c88ac9dSJustin Bogner       while (IEI && IEI->use_empty()) {
7032c88ac9dSJustin Bogner         InsertElementInst *Tmp = IEI;
7042c88ac9dSJustin Bogner         IEI = dyn_cast<InsertElementInst>(IEI->getOperand(0));
7052c88ac9dSJustin Bogner         Tmp->eraseFromParent();
7062c88ac9dSJustin Bogner       }
7072c88ac9dSJustin Bogner 
70890e84113SJustin Bogner       return Error::success();
70990e84113SJustin Bogner     });
71090e84113SJustin Bogner   }
71190e84113SJustin Bogner 
71275e7ba8cSSarah Spall   [[nodiscard]] bool lowerCtpopToCountBits(Function &F) {
71375e7ba8cSSarah Spall     IRBuilder<> &IRB = OpBuilder.getIRB();
71475e7ba8cSSarah Spall     Type *Int32Ty = IRB.getInt32Ty();
71575e7ba8cSSarah Spall 
71675e7ba8cSSarah Spall     return replaceFunction(F, [&](CallInst *CI) -> Error {
71775e7ba8cSSarah Spall       IRB.SetInsertPoint(CI);
71875e7ba8cSSarah Spall       SmallVector<Value *> Args;
71975e7ba8cSSarah Spall       Args.append(CI->arg_begin(), CI->arg_end());
72075e7ba8cSSarah Spall 
72175e7ba8cSSarah Spall       Type *RetTy = Int32Ty;
72275e7ba8cSSarah Spall       Type *FRT = F.getReturnType();
72375e7ba8cSSarah Spall       if (const auto *VT = dyn_cast<VectorType>(FRT))
72475e7ba8cSSarah Spall         RetTy = VectorType::get(RetTy, VT);
72575e7ba8cSSarah Spall 
72675e7ba8cSSarah Spall       Expected<CallInst *> OpCall = OpBuilder.tryCreateOp(
72775e7ba8cSSarah Spall           dxil::OpCode::CountBits, Args, CI->getName(), RetTy);
72875e7ba8cSSarah Spall       if (Error E = OpCall.takeError())
72975e7ba8cSSarah Spall         return E;
73075e7ba8cSSarah Spall 
73175e7ba8cSSarah Spall       // If the result type is 32 bits we can do a direct replacement.
73275e7ba8cSSarah Spall       if (FRT->isIntOrIntVectorTy(32)) {
73375e7ba8cSSarah Spall         CI->replaceAllUsesWith(*OpCall);
73475e7ba8cSSarah Spall         CI->eraseFromParent();
73575e7ba8cSSarah Spall         return Error::success();
73675e7ba8cSSarah Spall       }
73775e7ba8cSSarah Spall 
73875e7ba8cSSarah Spall       unsigned CastOp;
73975e7ba8cSSarah Spall       unsigned CastOp2;
74075e7ba8cSSarah Spall       if (FRT->isIntOrIntVectorTy(16)) {
74175e7ba8cSSarah Spall         CastOp = Instruction::ZExt;
74275e7ba8cSSarah Spall         CastOp2 = Instruction::SExt;
74375e7ba8cSSarah Spall       } else { // must be 64 bits
74475e7ba8cSSarah Spall         assert(FRT->isIntOrIntVectorTy(64) &&
74575e7ba8cSSarah Spall                "Currently only lowering 16, 32, or 64 bit ctpop to CountBits \
74675e7ba8cSSarah Spall                 is supported.");
74775e7ba8cSSarah Spall         CastOp = Instruction::Trunc;
74875e7ba8cSSarah Spall         CastOp2 = Instruction::Trunc;
74975e7ba8cSSarah Spall       }
75075e7ba8cSSarah Spall 
75175e7ba8cSSarah Spall       // It is correct to replace the ctpop with the dxil op and
75275e7ba8cSSarah Spall       // remove all casts to i32
75375e7ba8cSSarah Spall       bool NeedsCast = false;
75475e7ba8cSSarah Spall       for (User *User : make_early_inc_range(CI->users())) {
75575e7ba8cSSarah Spall         Instruction *I = dyn_cast<Instruction>(User);
75675e7ba8cSSarah Spall         if (I && (I->getOpcode() == CastOp || I->getOpcode() == CastOp2) &&
75775e7ba8cSSarah Spall             I->getType() == RetTy) {
75875e7ba8cSSarah Spall           I->replaceAllUsesWith(*OpCall);
75975e7ba8cSSarah Spall           I->eraseFromParent();
76075e7ba8cSSarah Spall         } else
76175e7ba8cSSarah Spall           NeedsCast = true;
76275e7ba8cSSarah Spall       }
76375e7ba8cSSarah Spall 
76475e7ba8cSSarah Spall       // It is correct to replace a ctpop with the dxil op and
76575e7ba8cSSarah Spall       // a cast from i32 to the return type of the ctpop
76675e7ba8cSSarah Spall       // the cast is emitted here if there is a non-cast to i32
76775e7ba8cSSarah Spall       // instr which uses the ctpop
76875e7ba8cSSarah Spall       if (NeedsCast) {
76975e7ba8cSSarah Spall         Value *Cast =
77075e7ba8cSSarah Spall             IRB.CreateZExtOrTrunc(*OpCall, F.getReturnType(), "ctpop.cast");
77175e7ba8cSSarah Spall         CI->replaceAllUsesWith(Cast);
77275e7ba8cSSarah Spall       }
77375e7ba8cSSarah Spall 
77475e7ba8cSSarah Spall       CI->eraseFromParent();
77575e7ba8cSSarah Spall       return Error::success();
77675e7ba8cSSarah Spall     });
77775e7ba8cSSarah Spall   }
77875e7ba8cSSarah Spall 
779e56ad22bSJustin Bogner   bool lowerIntrinsics() {
78085285be9SXiang Li     bool Updated = false;
78190e84113SJustin Bogner     bool HasErrors = false;
782264c09b7SXiang Li 
78385285be9SXiang Li     for (Function &F : make_early_inc_range(M.functions())) {
78485285be9SXiang Li       if (!F.isDeclaration())
78585285be9SXiang Li         continue;
78685285be9SXiang Li       Intrinsic::ID ID = F.getIntrinsicID();
78794da6bfbSJustin Bogner       switch (ID) {
78894da6bfbSJustin Bogner       default:
789264c09b7SXiang Li         continue;
7900a44b24dSAdam Yang #define DXIL_OP_INTRINSIC(OpCode, Intrin, ...)                                 \
79194da6bfbSJustin Bogner   case Intrin:                                                                 \
7920a44b24dSAdam Yang     HasErrors |= replaceFunctionWithOp(                                        \
7930a44b24dSAdam Yang         F, OpCode, ArrayRef<IntrinArgSelect>{__VA_ARGS__});                    \
79494da6bfbSJustin Bogner     break;
79594da6bfbSJustin Bogner #include "DXILOperation.inc"
796aa07f922SJustin Bogner       case Intrinsic::dx_resource_handlefrombinding:
79790e84113SJustin Bogner         HasErrors |= lowerHandleFromBinding(F);
7983f22756fSJustin Bogner         break;
7990fca76d5SJustin Bogner       case Intrinsic::dx_resource_getpointer:
8000fca76d5SJustin Bogner         HasErrors |= lowerGetPointer(F);
8010fca76d5SJustin Bogner         break;
802aa07f922SJustin Bogner       case Intrinsic::dx_resource_load_typedbuffer:
80334e20f18SJustin Bogner         HasErrors |= lowerTypedBufferLoad(F, /*HasCheckBit=*/true);
80490e84113SJustin Bogner         break;
805aa07f922SJustin Bogner       case Intrinsic::dx_resource_store_typedbuffer:
806*0e51b54bSJustin Bogner         HasErrors |= lowerBufferStore(F, /*IsRaw=*/false);
8073f22756fSJustin Bogner         break;
808cba9bd5cSJustin Bogner       case Intrinsic::dx_resource_load_rawbuffer:
809cba9bd5cSJustin Bogner         HasErrors |= lowerRawBufferLoad(F);
810cba9bd5cSJustin Bogner         break;
811*0e51b54bSJustin Bogner       case Intrinsic::dx_resource_store_rawbuffer:
812*0e51b54bSJustin Bogner         HasErrors |= lowerBufferStore(F, /*IsRaw=*/true);
813*0e51b54bSJustin Bogner         break;
814aa07f922SJustin Bogner       case Intrinsic::dx_resource_updatecounter:
8151f250999Sjoaosaffran         HasErrors |= lowerUpdateCounter(F);
8161f250999Sjoaosaffran         break;
817481bce01Sjoaosaffran       // TODO: this can be removed when
818481bce01Sjoaosaffran       // https://github.com/llvm/llvm-project/issues/113192 is fixed
819481bce01Sjoaosaffran       case Intrinsic::dx_splitdouble:
820481bce01Sjoaosaffran         HasErrors |= replaceFunctionWithNamedStructOp(
821481bce01Sjoaosaffran             F, OpCode::SplitDouble,
822481bce01Sjoaosaffran             OpBuilder.getSplitDoubleType(M.getContext()),
823481bce01Sjoaosaffran             [&](CallInst *CI, CallInst *Op) {
824481bce01Sjoaosaffran               return replaceSplitDoubleCallUsages(CI, Op);
825481bce01Sjoaosaffran             });
826481bce01Sjoaosaffran         break;
82775e7ba8cSSarah Spall       case Intrinsic::ctpop:
82875e7ba8cSSarah Spall         HasErrors |= lowerCtpopToCountBits(F);
82975e7ba8cSSarah Spall         break;
83094da6bfbSJustin Bogner       }
83185285be9SXiang Li       Updated = true;
83285285be9SXiang Li     }
83390e84113SJustin Bogner     if (Updated && !HasErrors)
834aa61925eSJustin Bogner       cleanupHandleCasts();
835aa61925eSJustin Bogner 
83685285be9SXiang Li     return Updated;
83785285be9SXiang Li   }
838e56ad22bSJustin Bogner };
839e56ad22bSJustin Bogner } // namespace
84085285be9SXiang Li 
841aa61925eSJustin Bogner PreservedAnalyses DXILOpLowering::run(Module &M, ModuleAnalysisManager &MAM) {
8423eca15cbSJustin Bogner   DXILBindingMap &DBM = MAM.getResult<DXILResourceBindingAnalysis>(M);
8433eca15cbSJustin Bogner   DXILResourceTypeMap &DRTM = MAM.getResult<DXILResourceTypeAnalysis>(M);
844aa61925eSJustin Bogner 
8453eca15cbSJustin Bogner   bool MadeChanges = OpLowerer(M, DBM, DRTM).lowerIntrinsics();
846aa61925eSJustin Bogner   if (!MadeChanges)
84785285be9SXiang Li     return PreservedAnalyses::all();
848aa61925eSJustin Bogner   PreservedAnalyses PA;
8493eca15cbSJustin Bogner   PA.preserve<DXILResourceBindingAnalysis>();
850bfd05102SJustin Bogner   PA.preserve<DXILMetadataAnalysis>();
851bfd05102SJustin Bogner   PA.preserve<ShaderFlagsAnalysis>();
852aa61925eSJustin Bogner   return PA;
85385285be9SXiang Li }
85485285be9SXiang Li 
85585285be9SXiang Li namespace {
85685285be9SXiang Li class DXILOpLoweringLegacy : public ModulePass {
85785285be9SXiang Li public:
858e56ad22bSJustin Bogner   bool runOnModule(Module &M) override {
8593eca15cbSJustin Bogner     DXILBindingMap &DBM =
8603eca15cbSJustin Bogner         getAnalysis<DXILResourceBindingWrapperPass>().getBindingMap();
8613eca15cbSJustin Bogner     DXILResourceTypeMap &DRTM =
8623eca15cbSJustin Bogner         getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap();
863aa61925eSJustin Bogner 
8643eca15cbSJustin Bogner     return OpLowerer(M, DBM, DRTM).lowerIntrinsics();
865e56ad22bSJustin Bogner   }
86685285be9SXiang Li   StringRef getPassName() const override { return "DXIL Op Lowering"; }
86785285be9SXiang Li   DXILOpLoweringLegacy() : ModulePass(ID) {}
86885285be9SXiang Li 
86985285be9SXiang Li   static char ID; // Pass identification.
870de1a97dbSFarzon Lotfi   void getAnalysisUsage(llvm::AnalysisUsage &AU) const override {
8713eca15cbSJustin Bogner     AU.addRequired<DXILResourceTypeWrapperPass>();
8723eca15cbSJustin Bogner     AU.addRequired<DXILResourceBindingWrapperPass>();
8733eca15cbSJustin Bogner     AU.addPreserved<DXILResourceBindingWrapperPass>();
874bfd05102SJustin Bogner     AU.addPreserved<DXILResourceMDWrapper>();
875bfd05102SJustin Bogner     AU.addPreserved<DXILMetadataAnalysisWrapperPass>();
876bfd05102SJustin Bogner     AU.addPreserved<ShaderFlagsAnalysisWrapper>();
877de1a97dbSFarzon Lotfi   }
87885285be9SXiang Li };
87985285be9SXiang Li char DXILOpLoweringLegacy::ID = 0;
88085285be9SXiang Li } // end anonymous namespace
88185285be9SXiang Li 
88285285be9SXiang Li INITIALIZE_PASS_BEGIN(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering",
88385285be9SXiang Li                       false, false)
8843eca15cbSJustin Bogner INITIALIZE_PASS_DEPENDENCY(DXILResourceTypeWrapperPass)
8853eca15cbSJustin Bogner INITIALIZE_PASS_DEPENDENCY(DXILResourceBindingWrapperPass)
88685285be9SXiang Li INITIALIZE_PASS_END(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering", false,
88785285be9SXiang Li                     false)
88885285be9SXiang Li 
88985285be9SXiang Li ModulePass *llvm::createDXILOpLoweringLegacyPass() {
89085285be9SXiang Li   return new DXILOpLoweringLegacy();
89185285be9SXiang Li }
892