xref: /llvm-project/llvm/lib/Target/DirectX/DXILOpLowering.cpp (revision 0e51b54b7ac02b0920e20b8ccae26b32bd6b6982)
1 //===- DXILOpLowering.cpp - Lowering to DXIL operations -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "DXILOpLowering.h"
10 #include "DXILConstants.h"
11 #include "DXILIntrinsicExpansion.h"
12 #include "DXILOpBuilder.h"
13 #include "DXILResourceAnalysis.h"
14 #include "DXILShaderFlags.h"
15 #include "DirectX.h"
16 #include "llvm/ADT/SmallVector.h"
17 #include "llvm/Analysis/DXILMetadataAnalysis.h"
18 #include "llvm/Analysis/DXILResource.h"
19 #include "llvm/CodeGen/Passes.h"
20 #include "llvm/IR/DiagnosticInfo.h"
21 #include "llvm/IR/IRBuilder.h"
22 #include "llvm/IR/Instruction.h"
23 #include "llvm/IR/Instructions.h"
24 #include "llvm/IR/Intrinsics.h"
25 #include "llvm/IR/IntrinsicsDirectX.h"
26 #include "llvm/IR/Module.h"
27 #include "llvm/IR/PassManager.h"
28 #include "llvm/InitializePasses.h"
29 #include "llvm/Pass.h"
30 #include "llvm/Support/ErrorHandling.h"
31 
32 #define DEBUG_TYPE "dxil-op-lower"
33 
34 using namespace llvm;
35 using namespace llvm::dxil;
36 
37 static bool isVectorArgExpansion(Function &F) {
38   switch (F.getIntrinsicID()) {
39   case Intrinsic::dx_dot2:
40   case Intrinsic::dx_dot3:
41   case Intrinsic::dx_dot4:
42     return true;
43   }
44   return false;
45 }
46 
47 static SmallVector<Value *> populateOperands(Value *Arg, IRBuilder<> &Builder) {
48   SmallVector<Value *> ExtractedElements;
49   auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType());
50   for (unsigned I = 0; I < VecArg->getNumElements(); ++I) {
51     Value *Index = ConstantInt::get(Type::getInt32Ty(Arg->getContext()), I);
52     Value *ExtractedElement = Builder.CreateExtractElement(Arg, Index);
53     ExtractedElements.push_back(ExtractedElement);
54   }
55   return ExtractedElements;
56 }
57 
58 static SmallVector<Value *> argVectorFlatten(CallInst *Orig,
59                                              IRBuilder<> &Builder) {
60   // Note: arg[NumOperands-1] is a pointer and is not needed by our flattening.
61   unsigned NumOperands = Orig->getNumOperands() - 1;
62   assert(NumOperands > 0);
63   Value *Arg0 = Orig->getOperand(0);
64   [[maybe_unused]] auto *VecArg0 = dyn_cast<FixedVectorType>(Arg0->getType());
65   assert(VecArg0);
66   SmallVector<Value *> NewOperands = populateOperands(Arg0, Builder);
67   for (unsigned I = 1; I < NumOperands; ++I) {
68     Value *Arg = Orig->getOperand(I);
69     [[maybe_unused]] auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType());
70     assert(VecArg);
71     assert(VecArg0->getElementType() == VecArg->getElementType());
72     assert(VecArg0->getNumElements() == VecArg->getNumElements());
73     auto NextOperandList = populateOperands(Arg, Builder);
74     NewOperands.append(NextOperandList.begin(), NextOperandList.end());
75   }
76   return NewOperands;
77 }
78 
79 namespace {
/// Rewrites calls to intrinsics in a module into calls to DXIL operations
/// created through DXILOpBuilder.
class OpLowerer {
  /// Module being lowered; also used to reach the LLVMContext for
  /// diagnostics and the target triple.
  Module &M;
  /// Creates the DXIL op calls and hands out the IRBuilder (via getIRB())
  /// that the lowering routines use for insertion.
  DXILOpBuilder OpBuilder;
  /// Binding information, looked up by handle-creating call instruction.
  DXILBindingMap &DBM;
  /// Resource type information, looked up by handle type.
  DXILResourceTypeMap &DRTM;
  /// Temporary handle casts created by createTmpHandleCast; they are removed
  /// again by cleanupHandleCasts.
  SmallVector<CallInst *> CleanupCasts;

public:
  /// Set up a lowerer for \p M that consults \p DBM and \p DRTM for resource
  /// binding and resource type information.
  OpLowerer(Module &M, DXILBindingMap &DBM, DXILResourceTypeMap &DRTM)
      : M(M), OpBuilder(M), DBM(DBM), DRTM(DRTM) {}
90 
91   /// Replace every call to \c F using \c ReplaceCall, and then erase \c F. If
92   /// there is an error replacing a call, we emit a diagnostic and return true.
93   [[nodiscard]] bool
94   replaceFunction(Function &F,
95                   llvm::function_ref<Error(CallInst *CI)> ReplaceCall) {
96     for (User *U : make_early_inc_range(F.users())) {
97       CallInst *CI = dyn_cast<CallInst>(U);
98       if (!CI)
99         continue;
100 
101       if (Error E = ReplaceCall(CI)) {
102         std::string Message(toString(std::move(E)));
103         DiagnosticInfoUnsupported Diag(*CI->getFunction(), Message,
104                                        CI->getDebugLoc());
105         M.getContext().diagnose(Diag);
106         return true;
107       }
108     }
109     if (F.user_empty())
110       F.eraseFromParent();
111     return false;
112   }
113 
  /// Describes how one argument of a DXIL op is produced when lowering an
  /// intrinsic call (consumed by replaceFunctionWithOp).
  struct IntrinArgSelect {
    /// Kind of selection. The enumerators are generated by expanding
    /// DXIL_OP_INTRINSIC_ARG_SELECT_TYPE over DXILOperation.inc.
    enum class Type {
#define DXIL_OP_INTRINSIC_ARG_SELECT_TYPE(name) name,
#include "DXILOperation.inc"
    };
    Type Type;
    /// Payload interpreted according to Type: for `Index` it is the index of
    /// the call argument to forward, for `I8` / `I32` it is the constant to
    /// materialize.
    int Value;
  };
122 
123   [[nodiscard]] bool
124   replaceFunctionWithOp(Function &F, dxil::OpCode DXILOp,
125                         ArrayRef<IntrinArgSelect> ArgSelects) {
126     bool IsVectorArgExpansion = isVectorArgExpansion(F);
127     assert(!(IsVectorArgExpansion && ArgSelects.size()) &&
128            "Cann't do vector arg expansion when using arg selects.");
129     return replaceFunction(F, [&](CallInst *CI) -> Error {
130       OpBuilder.getIRB().SetInsertPoint(CI);
131       SmallVector<Value *> Args;
132       if (ArgSelects.size()) {
133         for (const IntrinArgSelect &A : ArgSelects) {
134           switch (A.Type) {
135           case IntrinArgSelect::Type::Index:
136             Args.push_back(CI->getArgOperand(A.Value));
137             break;
138           case IntrinArgSelect::Type::I8:
139             Args.push_back(OpBuilder.getIRB().getInt8((uint8_t)A.Value));
140             break;
141           case IntrinArgSelect::Type::I32:
142             Args.push_back(OpBuilder.getIRB().getInt32(A.Value));
143             break;
144           }
145         }
146       } else if (IsVectorArgExpansion) {
147         Args = argVectorFlatten(CI, OpBuilder.getIRB());
148       } else {
149         Args.append(CI->arg_begin(), CI->arg_end());
150       }
151 
152       Expected<CallInst *> OpCall =
153           OpBuilder.tryCreateOp(DXILOp, Args, CI->getName(), F.getReturnType());
154       if (Error E = OpCall.takeError())
155         return E;
156 
157       CI->replaceAllUsesWith(*OpCall);
158       CI->eraseFromParent();
159       return Error::success();
160     });
161   }
162 
163   [[nodiscard]] bool replaceFunctionWithNamedStructOp(
164       Function &F, dxil::OpCode DXILOp, Type *NewRetTy,
165       llvm::function_ref<Error(CallInst *CI, CallInst *Op)> ReplaceUses) {
166     bool IsVectorArgExpansion = isVectorArgExpansion(F);
167     return replaceFunction(F, [&](CallInst *CI) -> Error {
168       SmallVector<Value *> Args;
169       OpBuilder.getIRB().SetInsertPoint(CI);
170       if (IsVectorArgExpansion) {
171         SmallVector<Value *> NewArgs = argVectorFlatten(CI, OpBuilder.getIRB());
172         Args.append(NewArgs.begin(), NewArgs.end());
173       } else
174         Args.append(CI->arg_begin(), CI->arg_end());
175 
176       Expected<CallInst *> OpCall =
177           OpBuilder.tryCreateOp(DXILOp, Args, CI->getName(), NewRetTy);
178       if (Error E = OpCall.takeError())
179         return E;
180       if (Error E = ReplaceUses(CI, *OpCall))
181         return E;
182 
183       return Error::success();
184     });
185   }
186 
187   /// Create a cast between a `target("dx")` type and `dx.types.Handle`, which
188   /// is intended to be removed by the end of lowering. This is used to allow
189   /// lowering of ops which need to change their return or argument types in a
190   /// piecemeal way - we can add the casts in to avoid updating all of the uses
191   /// or defs, and by the end all of the casts will be redundant.
192   Value *createTmpHandleCast(Value *V, Type *Ty) {
193     CallInst *Cast = OpBuilder.getIRB().CreateIntrinsic(
194         Intrinsic::dx_resource_casthandle, {Ty, V->getType()}, {V});
195     CleanupCasts.push_back(Cast);
196     return Cast;
197   }
198 
  /// Remove all temporary handle casts recorded in CleanupCasts, erase the
  /// cast functions they call, and clear the list.
  ///
  /// Casts that produce `dx.types.Handle` are replaced by the operand of the
  /// cast that feeds them and erased immediately; the remaining casts are
  /// erased afterwards, once they are expected to have no users.
  void cleanupHandleCasts() {
    SmallVector<CallInst *> ToRemove;
    SmallVector<Function *> CastFns;

    for (CallInst *Cast : CleanupCasts) {
      // These casts were only put in to ease the move from `target("dx")` types
      // to `dx.types.Handle` in a piecemeal way. At this point, all of the
      // non-cast uses should now be `dx.types.Handle`, and remaining casts
      // should all form pairs to and from the now unused `target("dx")` type.
      CastFns.push_back(Cast->getCalledFunction());

      // If the cast is not to `dx.types.Handle`, it should be the first part of
      // the pair. Keep track so we can remove it once it has no more uses.
      if (Cast->getType() != OpBuilder.getHandleType()) {
        ToRemove.push_back(Cast);
        continue;
      }
      // Otherwise, we're the second handle in a pair. Forward the arguments and
      // remove the (second) cast.
      CallInst *Def = cast<CallInst>(Cast->getOperand(0));
      assert(Def->getIntrinsicID() == Intrinsic::dx_resource_casthandle &&
             "Unbalanced pair of temporary handle casts");
      Cast->replaceAllUsesWith(Def->getOperand(0));
      Cast->eraseFromParent();
    }
    // The first halves of the pairs can only go once their second halves
    // (their users) have been erased above.
    for (CallInst *Cast : ToRemove) {
      assert(Cast->user_empty() && "Temporary handle cast still has users");
      Cast->eraseFromParent();
    }

    // Deduplicate the cast functions so that we only erase each one once.
    llvm::sort(CastFns);
    CastFns.erase(llvm::unique(CastFns), CastFns.end());
    for (Function *F : CastFns)
      F->eraseFromParent();

    CleanupCasts.clear();
  }
237 
238   // Remove the resource global associated with the handleFromBinding call
239   // instruction and their uses as they aren't needed anymore.
240   // TODO: We should verify that all the globals get removed.
241   // It's expected we'll need a custom pass in the future that will eliminate
242   // the need for this here.
243   void removeResourceGlobals(CallInst *CI) {
244     for (User *User : make_early_inc_range(CI->users())) {
245       if (StoreInst *Store = dyn_cast<StoreInst>(User)) {
246         Value *V = Store->getOperand(1);
247         Store->eraseFromParent();
248         if (GlobalVariable *GV = dyn_cast<GlobalVariable>(V))
249           if (GV->use_empty()) {
250             GV->removeDeadConstantUsers();
251             GV->eraseFromParent();
252           }
253       }
254     }
255   }
256 
257   [[nodiscard]] bool lowerToCreateHandle(Function &F) {
258     IRBuilder<> &IRB = OpBuilder.getIRB();
259     Type *Int8Ty = IRB.getInt8Ty();
260     Type *Int32Ty = IRB.getInt32Ty();
261 
262     return replaceFunction(F, [&](CallInst *CI) -> Error {
263       IRB.SetInsertPoint(CI);
264 
265       auto *It = DBM.find(CI);
266       assert(It != DBM.end() && "Resource not in map?");
267       dxil::ResourceBindingInfo &RI = *It;
268 
269       const auto &Binding = RI.getBinding();
270       dxil::ResourceClass RC = DRTM[RI.getHandleTy()].getResourceClass();
271 
272       Value *IndexOp = CI->getArgOperand(3);
273       if (Binding.LowerBound != 0)
274         IndexOp = IRB.CreateAdd(IndexOp,
275                                 ConstantInt::get(Int32Ty, Binding.LowerBound));
276 
277       std::array<Value *, 4> Args{
278           ConstantInt::get(Int8Ty, llvm::to_underlying(RC)),
279           ConstantInt::get(Int32Ty, Binding.RecordID), IndexOp,
280           CI->getArgOperand(4)};
281       Expected<CallInst *> OpCall =
282           OpBuilder.tryCreateOp(OpCode::CreateHandle, Args, CI->getName());
283       if (Error E = OpCall.takeError())
284         return E;
285 
286       Value *Cast = createTmpHandleCast(*OpCall, CI->getType());
287 
288       removeResourceGlobals(CI);
289 
290       CI->replaceAllUsesWith(Cast);
291       CI->eraseFromParent();
292       return Error::success();
293     });
294   }
295 
296   [[nodiscard]] bool lowerToBindAndAnnotateHandle(Function &F) {
297     IRBuilder<> &IRB = OpBuilder.getIRB();
298     Type *Int32Ty = IRB.getInt32Ty();
299 
300     return replaceFunction(F, [&](CallInst *CI) -> Error {
301       IRB.SetInsertPoint(CI);
302 
303       auto *It = DBM.find(CI);
304       assert(It != DBM.end() && "Resource not in map?");
305       dxil::ResourceBindingInfo &RI = *It;
306 
307       const auto &Binding = RI.getBinding();
308       dxil::ResourceTypeInfo &RTI = DRTM[RI.getHandleTy()];
309       dxil::ResourceClass RC = RTI.getResourceClass();
310 
311       Value *IndexOp = CI->getArgOperand(3);
312       if (Binding.LowerBound != 0)
313         IndexOp = IRB.CreateAdd(IndexOp,
314                                 ConstantInt::get(Int32Ty, Binding.LowerBound));
315 
316       std::pair<uint32_t, uint32_t> Props =
317           RI.getAnnotateProps(*F.getParent(), RTI);
318 
319       // For `CreateHandleFromBinding` we need the upper bound rather than the
320       // size, so we need to be careful about the difference for "unbounded".
321       uint32_t Unbounded = std::numeric_limits<uint32_t>::max();
322       uint32_t UpperBound = Binding.Size == Unbounded
323                                 ? Unbounded
324                                 : Binding.LowerBound + Binding.Size - 1;
325       Constant *ResBind = OpBuilder.getResBind(Binding.LowerBound, UpperBound,
326                                                Binding.Space, RC);
327       std::array<Value *, 3> BindArgs{ResBind, IndexOp, CI->getArgOperand(4)};
328       Expected<CallInst *> OpBind = OpBuilder.tryCreateOp(
329           OpCode::CreateHandleFromBinding, BindArgs, CI->getName());
330       if (Error E = OpBind.takeError())
331         return E;
332 
333       std::array<Value *, 2> AnnotateArgs{
334           *OpBind, OpBuilder.getResProps(Props.first, Props.second)};
335       Expected<CallInst *> OpAnnotate = OpBuilder.tryCreateOp(
336           OpCode::AnnotateHandle, AnnotateArgs,
337           CI->hasName() ? CI->getName() + "_annot" : Twine());
338       if (Error E = OpAnnotate.takeError())
339         return E;
340 
341       Value *Cast = createTmpHandleCast(*OpAnnotate, CI->getType());
342 
343       removeResourceGlobals(CI);
344 
345       CI->replaceAllUsesWith(Cast);
346       CI->eraseFromParent();
347 
348       return Error::success();
349     });
350   }
351 
352   /// Lower `dx.resource.handlefrombinding` intrinsics depending on the shader
353   /// model and taking into account binding information from
354   /// DXILResourceBindingAnalysis.
355   bool lowerHandleFromBinding(Function &F) {
356     Triple TT(Triple(M.getTargetTriple()));
357     if (TT.getDXILVersion() < VersionTuple(1, 6))
358       return lowerToCreateHandle(F);
359     return lowerToBindAndAnnotateHandle(F);
360   }
361 
362   Error replaceSplitDoubleCallUsages(CallInst *Intrin, CallInst *Op) {
363     for (Use &U : make_early_inc_range(Intrin->uses())) {
364       if (auto *EVI = dyn_cast<ExtractValueInst>(U.getUser())) {
365 
366         if (EVI->getNumIndices() != 1)
367           return createStringError(std::errc::invalid_argument,
368                                    "Splitdouble has only 2 elements");
369         EVI->setOperand(0, Op);
370       } else {
371         return make_error<StringError>(
372             "Splitdouble use is not ExtractValueInst",
373             inconvertibleErrorCode());
374       }
375     }
376 
377     Intrin->eraseFromParent();
378 
379     return Error::success();
380   }
381 
  /// Replace uses of \c Intrin with the values in the `dx.ResRet` of \c Op.
  /// Since we expect to be post-scalarization, make an effort to avoid vectors.
  ///
  /// \param Intrin      The original load call; it is erased on success.
  /// \param Op          The DXIL op call whose struct result replaces it.
  /// \param HasCheckBit True when \c Intrin returns a struct whose element 0
  ///                    is the loaded value and whose element 1 is the check
  ///                    bit.
  /// \returns An error if creating the `CheckAccessFullyMapped` op fails.
  ///
  /// New instructions are created at the current insertion point of the
  /// OpBuilder's IRBuilder; this function does not set it.
  Error replaceResRetUses(CallInst *Intrin, CallInst *Op, bool HasCheckBit) {
    IRBuilder<> &IRB = OpBuilder.getIRB();

    // OldResult / OldTy track the value (and its type) that carries the
    // loaded data: the intrinsic itself, or the extract of struct element 0
    // when there is a check bit.
    Instruction *OldResult = Intrin;
    Type *OldTy = Intrin->getType();

    if (HasCheckBit) {
      auto *ST = cast<StructType>(OldTy);

      // Lazily created, then shared by every extract of the check bit.
      Value *CheckOp = nullptr;
      Type *Int32Ty = IRB.getInt32Ty();
      for (Use &U : make_early_inc_range(OldResult->uses())) {
        if (auto *EVI = dyn_cast<ExtractValueInst>(U.getUser())) {
          ArrayRef<unsigned> Indices = EVI->getIndices();
          assert(Indices.size() == 1);
          // We're only interested in uses of the check bit for now.
          if (Indices[0] != 1)
            continue;
          if (!CheckOp) {
            // The check bit is derived from element 4 of the op's result.
            Value *NewEVI = IRB.CreateExtractValue(Op, 4);
            Expected<CallInst *> OpCall = OpBuilder.tryCreateOp(
                OpCode::CheckAccessFullyMapped, {NewEVI},
                OldResult->hasName() ? OldResult->getName() + "_check"
                                     : Twine(),
                Int32Ty);
            if (Error E = OpCall.takeError())
              return E;
            CheckOp = *OpCall;
          }
          EVI->replaceAllUsesWith(CheckOp);
          EVI->eraseFromParent();
        }
      }

      if (OldResult->use_empty()) {
        // Only the check bit was used, so we're done here.
        OldResult->eraseFromParent();
        return Error::success();
      }

      assert(OldResult->hasOneUse() &&
             isa<ExtractValueInst>(*OldResult->user_begin()) &&
             "Expected only use to be extract of first element");
      // From here on, work on the extract of element 0 (the loaded data).
      OldResult = cast<Instruction>(*OldResult->user_begin());
      OldTy = ST->getElementType(0);
    }

    // For scalars, we just extract the first element.
    if (!isa<FixedVectorType>(OldTy)) {
      Value *EVI = IRB.CreateExtractValue(Op, 0);
      OldResult->replaceAllUsesWith(EVI);
      OldResult->eraseFromParent();
      // With a check bit, OldResult was the extract; the call goes as well.
      if (OldResult != Intrin) {
        assert(Intrin->use_empty() && "Intrinsic still has uses?");
        Intrin->eraseFromParent();
      }
      return Error::success();
    }

    // Cache of the `extractvalue`s created so far, indexed by element.
    std::array<Value *, 4> Extracts = {};
    // `extractelement` users whose index is not a constant.
    SmallVector<ExtractElementInst *> DynamicAccesses;

    // The users of the operation should all be scalarized, so we attempt to
    // replace the extractelements with extractvalues directly.
    for (Use &U : make_early_inc_range(OldResult->uses())) {
      if (auto *EEI = dyn_cast<ExtractElementInst>(U.getUser())) {
        if (auto *IndexOp = dyn_cast<ConstantInt>(EEI->getIndexOperand())) {
          size_t IndexVal = IndexOp->getZExtValue();
          assert(IndexVal < 4 && "Index into buffer load out of range");
          if (!Extracts[IndexVal])
            Extracts[IndexVal] = IRB.CreateExtractValue(Op, IndexVal);
          EEI->replaceAllUsesWith(Extracts[IndexVal]);
          EEI->eraseFromParent();
        } else {
          DynamicAccesses.push_back(EEI);
        }
      }
    }

    const auto *VecTy = cast<FixedVectorType>(OldTy);
    const unsigned N = VecTy->getNumElements();

    // If there's a dynamic access we need to round trip through stack memory so
    // that we don't leave vectors around.
    if (!DynamicAccesses.empty()) {
      Type *Int32Ty = IRB.getInt32Ty();
      Constant *Zero = ConstantInt::get(Int32Ty, 0);

      Type *ElTy = VecTy->getElementType();
      Type *ArrayTy = ArrayType::get(ElTy, N);
      Value *Alloca = IRB.CreateAlloca(ArrayTy);

      // Spill every element of the result into the stack array.
      for (int I = 0, E = N; I != E; ++I) {
        if (!Extracts[I])
          Extracts[I] = IRB.CreateExtractValue(Op, I);
        Value *GEP = IRB.CreateInBoundsGEP(
            ArrayTy, Alloca, {Zero, ConstantInt::get(Int32Ty, I)});
        IRB.CreateStore(Extracts[I], GEP);
      }

      // Each dynamic extract becomes a load from the array at its index.
      for (ExtractElementInst *EEI : DynamicAccesses) {
        Value *GEP = IRB.CreateInBoundsGEP(ArrayTy, Alloca,
                                           {Zero, EEI->getIndexOperand()});
        Value *Load = IRB.CreateLoad(ElTy, GEP);
        EEI->replaceAllUsesWith(Load);
        EEI->eraseFromParent();
      }
    }

    // If we still have uses, then we're not fully scalarized and need to
    // recreate the vector. This should only happen for things like exported
    // functions from libraries.
    if (!OldResult->use_empty()) {
      for (int I = 0, E = N; I != E; ++I)
        if (!Extracts[I])
          Extracts[I] = IRB.CreateExtractValue(Op, I);

      Value *Vec = UndefValue::get(OldTy);
      for (int I = 0, E = N; I != E; ++I)
        Vec = IRB.CreateInsertElement(Vec, Extracts[I], I);
      OldResult->replaceAllUsesWith(Vec);
    }

    OldResult->eraseFromParent();
    // With a check bit, OldResult was the extract; the call goes as well.
    if (OldResult != Intrin) {
      assert(Intrin->use_empty() && "Intrinsic still has uses?");
      Intrin->eraseFromParent();
    }

    return Error::success();
  }
515 
516   [[nodiscard]] bool lowerTypedBufferLoad(Function &F, bool HasCheckBit) {
517     IRBuilder<> &IRB = OpBuilder.getIRB();
518     Type *Int32Ty = IRB.getInt32Ty();
519 
520     return replaceFunction(F, [&](CallInst *CI) -> Error {
521       IRB.SetInsertPoint(CI);
522 
523       Value *Handle =
524           createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType());
525       Value *Index0 = CI->getArgOperand(1);
526       Value *Index1 = UndefValue::get(Int32Ty);
527 
528       Type *OldTy = CI->getType();
529       if (HasCheckBit)
530         OldTy = cast<StructType>(OldTy)->getElementType(0);
531       Type *NewRetTy = OpBuilder.getResRetType(OldTy->getScalarType());
532 
533       std::array<Value *, 3> Args{Handle, Index0, Index1};
534       Expected<CallInst *> OpCall = OpBuilder.tryCreateOp(
535           OpCode::BufferLoad, Args, CI->getName(), NewRetTy);
536       if (Error E = OpCall.takeError())
537         return E;
538       if (Error E = replaceResRetUses(CI, *OpCall, HasCheckBit))
539         return E;
540 
541       return Error::success();
542     });
543   }
544 
545   [[nodiscard]] bool lowerRawBufferLoad(Function &F) {
546     Triple TT(Triple(M.getTargetTriple()));
547     VersionTuple DXILVersion = TT.getDXILVersion();
548     const DataLayout &DL = F.getDataLayout();
549     IRBuilder<> &IRB = OpBuilder.getIRB();
550     Type *Int8Ty = IRB.getInt8Ty();
551     Type *Int32Ty = IRB.getInt32Ty();
552 
553     return replaceFunction(F, [&](CallInst *CI) -> Error {
554       IRB.SetInsertPoint(CI);
555 
556       Type *OldTy = cast<StructType>(CI->getType())->getElementType(0);
557       Type *ScalarTy = OldTy->getScalarType();
558       Type *NewRetTy = OpBuilder.getResRetType(ScalarTy);
559 
560       Value *Handle =
561           createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType());
562       Value *Index0 = CI->getArgOperand(1);
563       Value *Index1 = CI->getArgOperand(2);
564       uint64_t NumElements =
565           DL.getTypeSizeInBits(OldTy) / DL.getTypeSizeInBits(ScalarTy);
566       Value *Mask = ConstantInt::get(Int8Ty, ~(~0U << NumElements));
567       Value *Align =
568           ConstantInt::get(Int32Ty, DL.getPrefTypeAlign(ScalarTy).value());
569 
570       Expected<CallInst *> OpCall =
571           DXILVersion >= VersionTuple(1, 2)
572               ? OpBuilder.tryCreateOp(OpCode::RawBufferLoad,
573                                       {Handle, Index0, Index1, Mask, Align},
574                                       CI->getName(), NewRetTy)
575               : OpBuilder.tryCreateOp(OpCode::BufferLoad,
576                                       {Handle, Index0, Index1}, CI->getName(),
577                                       NewRetTy);
578       if (Error E = OpCall.takeError())
579         return E;
580       if (Error E = replaceResRetUses(CI, *OpCall, /*HasCheckBit=*/true))
581         return E;
582 
583       return Error::success();
584     });
585   }
586 
587   [[nodiscard]] bool lowerUpdateCounter(Function &F) {
588     IRBuilder<> &IRB = OpBuilder.getIRB();
589     Type *Int32Ty = IRB.getInt32Ty();
590 
591     return replaceFunction(F, [&](CallInst *CI) -> Error {
592       IRB.SetInsertPoint(CI);
593       Value *Handle =
594           createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType());
595       Value *Op1 = CI->getArgOperand(1);
596 
597       std::array<Value *, 2> Args{Handle, Op1};
598 
599       Expected<CallInst *> OpCall = OpBuilder.tryCreateOp(
600           OpCode::UpdateCounter, Args, CI->getName(), Int32Ty);
601 
602       if (Error E = OpCall.takeError())
603         return E;
604 
605       CI->replaceAllUsesWith(*OpCall);
606       CI->eraseFromParent();
607       return Error::success();
608     });
609   }
610 
  /// Erase the declaration \p F of the getpointer intrinsic.
  ///
  /// \p F is asserted to have no users left. Always returns false (no error),
  /// matching the convention of the other lowering routines.
  [[nodiscard]] bool lowerGetPointer(Function &F) {
    // These should have already been handled in DXILResourceAccess, so we can
    // just clean up the dead prototype.
    assert(F.user_empty() && "getpointer operations should have been removed");
    F.eraseFromParent();
    return false;
  }
618 
619   [[nodiscard]] bool lowerBufferStore(Function &F, bool IsRaw) {
620     Triple TT(Triple(M.getTargetTriple()));
621     VersionTuple DXILVersion = TT.getDXILVersion();
622     const DataLayout &DL = F.getDataLayout();
623     IRBuilder<> &IRB = OpBuilder.getIRB();
624     Type *Int8Ty = IRB.getInt8Ty();
625     Type *Int32Ty = IRB.getInt32Ty();
626 
627     return replaceFunction(F, [&](CallInst *CI) -> Error {
628       IRB.SetInsertPoint(CI);
629 
630       Value *Handle =
631           createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType());
632       Value *Index0 = CI->getArgOperand(1);
633       Value *Index1 = IsRaw ? CI->getArgOperand(2) : UndefValue::get(Int32Ty);
634 
635       Value *Data = CI->getArgOperand(IsRaw ? 3 : 2);
636       Type *DataTy = Data->getType();
637       Type *ScalarTy = DataTy->getScalarType();
638 
639       uint64_t NumElements =
640           DL.getTypeSizeInBits(DataTy) / DL.getTypeSizeInBits(ScalarTy);
641       Value *Mask = ConstantInt::get(Int8Ty, ~(~0U << NumElements));
642 
643       // TODO: check that we only have vector or scalar...
644       if (!IsRaw && NumElements != 4)
645         return make_error<StringError>(
646             "typedBufferStore data must be a vector of 4 elements",
647             inconvertibleErrorCode());
648       else if (NumElements > 4)
649         return make_error<StringError>(
650             "rawBufferStore data must have at most 4 elements",
651             inconvertibleErrorCode());
652 
653       std::array<Value *, 4> DataElements{nullptr, nullptr, nullptr, nullptr};
654       if (DataTy == ScalarTy)
655         DataElements[0] = Data;
656       else {
657         // Since we're post-scalarizer, if we see a vector here it's likely
658         // constructed solely for the argument of the store. Just use the scalar
659         // values from before they're inserted into the temporary.
660         auto *IEI = dyn_cast<InsertElementInst>(Data);
661         while (IEI) {
662           auto *IndexOp = dyn_cast<ConstantInt>(IEI->getOperand(2));
663           if (!IndexOp)
664             break;
665           size_t IndexVal = IndexOp->getZExtValue();
666           assert(IndexVal < 4 && "Too many elements for buffer store");
667           DataElements[IndexVal] = IEI->getOperand(1);
668           IEI = dyn_cast<InsertElementInst>(IEI->getOperand(0));
669         }
670       }
671 
672       // If for some reason we weren't able to forward the arguments from the
673       // scalarizer artifact, then we may need to actually extract elements from
674       // the vector.
675       for (int I = 0, E = NumElements; I < E; ++I)
676         if (DataElements[I] == nullptr)
677           DataElements[I] =
678               IRB.CreateExtractElement(Data, ConstantInt::get(Int32Ty, I));
679       // For any elements beyond the length of the vector, fill up with undef.
680       for (int I = NumElements, E = 4; I < E; ++I)
681         if (DataElements[I] == nullptr)
682           DataElements[I] = UndefValue::get(ScalarTy);
683 
684       dxil::OpCode Op = OpCode::BufferStore;
685       SmallVector<Value *, 9> Args{
686           Handle,          Index0,          Index1,          DataElements[0],
687           DataElements[1], DataElements[2], DataElements[3], Mask};
688       if (IsRaw && DXILVersion >= VersionTuple(1, 2)) {
689         Op = OpCode::RawBufferStore;
690         // RawBufferStore requires the alignment
691         Args.push_back(
692             ConstantInt::get(Int32Ty, DL.getPrefTypeAlign(ScalarTy).value()));
693       }
694       Expected<CallInst *> OpCall =
695           OpBuilder.tryCreateOp(Op, Args, CI->getName());
696       if (Error E = OpCall.takeError())
697         return E;
698 
699       CI->eraseFromParent();
700       // Clean up any leftover `insertelement`s
701       auto *IEI = dyn_cast<InsertElementInst>(Data);
702       while (IEI && IEI->use_empty()) {
703         InsertElementInst *Tmp = IEI;
704         IEI = dyn_cast<InsertElementInst>(IEI->getOperand(0));
705         Tmp->eraseFromParent();
706       }
707 
708       return Error::success();
709     });
710   }
711 
  /// Lower every call to the `ctpop` intrinsic \p F to a `CountBits` op.
  ///
  /// The op is created with an i32 result (or, for a vector ctpop, a vector
  /// of i32 derived from the ctpop's vector type). A 32-bit ctpop is replaced
  /// directly. For 16-bit and 64-bit ctpops, users that merely convert the
  /// result to the op's type are replaced by the op itself, and a cast back
  /// to the ctpop's type is emitted only if other users remain. Returns true
  /// if an error was diagnosed (see replaceFunction).
  [[nodiscard]] bool lowerCtpopToCountBits(Function &F) {
    IRBuilder<> &IRB = OpBuilder.getIRB();
    Type *Int32Ty = IRB.getInt32Ty();

    return replaceFunction(F, [&](CallInst *CI) -> Error {
      IRB.SetInsertPoint(CI);
      SmallVector<Value *> Args;
      Args.append(CI->arg_begin(), CI->arg_end());

      // Result type of the op: i32, widened to a vector for vector ctpops.
      Type *RetTy = Int32Ty;
      Type *FRT = F.getReturnType();
      if (const auto *VT = dyn_cast<VectorType>(FRT))
        RetTy = VectorType::get(RetTy, VT);

      Expected<CallInst *> OpCall = OpBuilder.tryCreateOp(
          dxil::OpCode::CountBits, Args, CI->getName(), RetTy);
      if (Error E = OpCall.takeError())
        return E;

      // If the result type is 32 bits we can do a direct replacement.
      if (FRT->isIntOrIntVectorTy(32)) {
        CI->replaceAllUsesWith(*OpCall);
        CI->eraseFromParent();
        return Error::success();
      }

      // Opcodes of the user casts that can be folded into the op: extensions
      // for a 16-bit ctpop, truncation for a 64-bit one.
      unsigned CastOp;
      unsigned CastOp2;
      if (FRT->isIntOrIntVectorTy(16)) {
        CastOp = Instruction::ZExt;
        CastOp2 = Instruction::SExt;
      } else { // must be 64 bits
        assert(FRT->isIntOrIntVectorTy(64) &&
               "Currently only lowering 16, 32, or 64 bit ctpop to CountBits \
                is supported.");
        CastOp = Instruction::Trunc;
        CastOp2 = Instruction::Trunc;
      }

      // It is correct to replace the ctpop with the dxil op and
      // remove all casts to i32
      bool NeedsCast = false;
      for (User *User : make_early_inc_range(CI->users())) {
        Instruction *I = dyn_cast<Instruction>(User);
        if (I && (I->getOpcode() == CastOp || I->getOpcode() == CastOp2) &&
            I->getType() == RetTy) {
          I->replaceAllUsesWith(*OpCall);
          I->eraseFromParent();
        } else
          NeedsCast = true;
      }

      // It is correct to replace a ctpop with the dxil op and
      // a cast from i32 to the return type of the ctpop
      // the cast is emitted here if there is a non-cast to i32
      // instr which uses the ctpop
      if (NeedsCast) {
        Value *Cast =
            IRB.CreateZExtOrTrunc(*OpCall, F.getReturnType(), "ctpop.cast");
        CI->replaceAllUsesWith(Cast);
      }

      CI->eraseFromParent();
      return Error::success();
    });
  }
778 
779   bool lowerIntrinsics() {
780     bool Updated = false;
781     bool HasErrors = false;
782 
783     for (Function &F : make_early_inc_range(M.functions())) {
784       if (!F.isDeclaration())
785         continue;
786       Intrinsic::ID ID = F.getIntrinsicID();
787       switch (ID) {
788       default:
789         continue;
790 #define DXIL_OP_INTRINSIC(OpCode, Intrin, ...)                                 \
791   case Intrin:                                                                 \
792     HasErrors |= replaceFunctionWithOp(                                        \
793         F, OpCode, ArrayRef<IntrinArgSelect>{__VA_ARGS__});                    \
794     break;
795 #include "DXILOperation.inc"
796       case Intrinsic::dx_resource_handlefrombinding:
797         HasErrors |= lowerHandleFromBinding(F);
798         break;
799       case Intrinsic::dx_resource_getpointer:
800         HasErrors |= lowerGetPointer(F);
801         break;
802       case Intrinsic::dx_resource_load_typedbuffer:
803         HasErrors |= lowerTypedBufferLoad(F, /*HasCheckBit=*/true);
804         break;
805       case Intrinsic::dx_resource_store_typedbuffer:
806         HasErrors |= lowerBufferStore(F, /*IsRaw=*/false);
807         break;
808       case Intrinsic::dx_resource_load_rawbuffer:
809         HasErrors |= lowerRawBufferLoad(F);
810         break;
811       case Intrinsic::dx_resource_store_rawbuffer:
812         HasErrors |= lowerBufferStore(F, /*IsRaw=*/true);
813         break;
814       case Intrinsic::dx_resource_updatecounter:
815         HasErrors |= lowerUpdateCounter(F);
816         break;
817       // TODO: this can be removed when
818       // https://github.com/llvm/llvm-project/issues/113192 is fixed
819       case Intrinsic::dx_splitdouble:
820         HasErrors |= replaceFunctionWithNamedStructOp(
821             F, OpCode::SplitDouble,
822             OpBuilder.getSplitDoubleType(M.getContext()),
823             [&](CallInst *CI, CallInst *Op) {
824               return replaceSplitDoubleCallUsages(CI, Op);
825             });
826         break;
827       case Intrinsic::ctpop:
828         HasErrors |= lowerCtpopToCountBits(F);
829         break;
830       }
831       Updated = true;
832     }
833     if (Updated && !HasErrors)
834       cleanupHandleCasts();
835 
836     return Updated;
837   }
838 };
839 } // namespace
840 
841 PreservedAnalyses DXILOpLowering::run(Module &M, ModuleAnalysisManager &MAM) {
842   DXILBindingMap &DBM = MAM.getResult<DXILResourceBindingAnalysis>(M);
843   DXILResourceTypeMap &DRTM = MAM.getResult<DXILResourceTypeAnalysis>(M);
844 
845   bool MadeChanges = OpLowerer(M, DBM, DRTM).lowerIntrinsics();
846   if (!MadeChanges)
847     return PreservedAnalyses::all();
848   PreservedAnalyses PA;
849   PA.preserve<DXILResourceBindingAnalysis>();
850   PA.preserve<DXILMetadataAnalysis>();
851   PA.preserve<ShaderFlagsAnalysis>();
852   return PA;
853 }
854 
855 namespace {
856 class DXILOpLoweringLegacy : public ModulePass {
857 public:
858   bool runOnModule(Module &M) override {
859     DXILBindingMap &DBM =
860         getAnalysis<DXILResourceBindingWrapperPass>().getBindingMap();
861     DXILResourceTypeMap &DRTM =
862         getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap();
863 
864     return OpLowerer(M, DBM, DRTM).lowerIntrinsics();
865   }
866   StringRef getPassName() const override { return "DXIL Op Lowering"; }
867   DXILOpLoweringLegacy() : ModulePass(ID) {}
868 
869   static char ID; // Pass identification.
870   void getAnalysisUsage(llvm::AnalysisUsage &AU) const override {
871     AU.addRequired<DXILResourceTypeWrapperPass>();
872     AU.addRequired<DXILResourceBindingWrapperPass>();
873     AU.addPreserved<DXILResourceBindingWrapperPass>();
874     AU.addPreserved<DXILResourceMDWrapper>();
875     AU.addPreserved<DXILMetadataAnalysisWrapperPass>();
876     AU.addPreserved<ShaderFlagsAnalysisWrapper>();
877   }
878 };
879 char DXILOpLoweringLegacy::ID = 0;
880 } // end anonymous namespace
881 
882 INITIALIZE_PASS_BEGIN(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering",
883                       false, false)
884 INITIALIZE_PASS_DEPENDENCY(DXILResourceTypeWrapperPass)
885 INITIALIZE_PASS_DEPENDENCY(DXILResourceBindingWrapperPass)
886 INITIALIZE_PASS_END(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering", false,
887                     false)
888 
889 ModulePass *llvm::createDXILOpLoweringLegacyPass() {
890   return new DXILOpLoweringLegacy();
891 }
892