xref: /llvm-project/llvm/lib/Target/DirectX/DXILResourceAccess.cpp (revision 2f39d138dc38a1fdf4754e4e26dd0aeb7409b13d)
1 //===- DXILResourceAccess.cpp - Resource access via load/store ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "DXILResourceAccess.h"
10 #include "DirectX.h"
11 #include "llvm/Analysis/DXILResource.h"
12 #include "llvm/IR/Dominators.h"
13 #include "llvm/IR/IRBuilder.h"
14 #include "llvm/IR/Instructions.h"
15 #include "llvm/IR/IntrinsicInst.h"
16 #include "llvm/IR/Intrinsics.h"
17 #include "llvm/IR/IntrinsicsDirectX.h"
18 #include "llvm/InitializePasses.h"
19 
20 #define DEBUG_TYPE "dxil-resource-access"
21 
22 using namespace llvm;
23 
24 static Value *calculateGEPOffset(GetElementPtrInst *GEP, Value *PrevOffset,
25                                  dxil::ResourceTypeInfo &RTI) {
26   assert(!PrevOffset && "Non-constant GEP chains not handled yet");
27 
28   const DataLayout &DL = GEP->getDataLayout();
29 
30   uint64_t ScalarSize = 1;
31   if (RTI.isTyped()) {
32     Type *ContainedType = RTI.getHandleTy()->getTypeParameter(0);
33     // We need the size of an element in bytes so that we can calculate the
34     // offset in elements given a total offset in bytes.
35     Type *ScalarType = ContainedType->getScalarType();
36     ScalarSize = DL.getTypeSizeInBits(ScalarType) / 8;
37   }
38 
39   APInt ConstantOffset(DL.getIndexTypeSizeInBits(GEP->getType()), 0);
40   if (GEP->accumulateConstantOffset(DL, ConstantOffset)) {
41     APInt Scaled = ConstantOffset.udiv(ScalarSize);
42     return ConstantInt::get(Type::getInt32Ty(GEP->getContext()), Scaled);
43   }
44 
45   auto IndexIt = GEP->idx_begin();
46   assert(cast<ConstantInt>(IndexIt)->getZExtValue() == 0 &&
47          "GEP is not indexing through pointer");
48   ++IndexIt;
49   Value *Offset = *IndexIt;
50   assert(++IndexIt == GEP->idx_end() && "Too many indices in GEP");
51   return Offset;
52 }
53 
54 static void createTypedBufferStore(IntrinsicInst *II, StoreInst *SI,
55                                    Value *Offset, dxil::ResourceTypeInfo &RTI) {
56   IRBuilder<> Builder(SI);
57   Type *ContainedType = RTI.getHandleTy()->getTypeParameter(0);
58   Type *LoadType = StructType::get(ContainedType, Builder.getInt1Ty());
59 
60   Value *V = SI->getValueOperand();
61   if (V->getType() == ContainedType) {
62     // V is already the right type.
63     assert(!Offset && "store of whole element has offset?");
64   } else if (V->getType() == ContainedType->getScalarType()) {
65     // We're storing a scalar, so we need to load the current value and only
66     // replace the relevant part.
67     auto *Load = Builder.CreateIntrinsic(
68         LoadType, Intrinsic::dx_resource_load_typedbuffer,
69         {II->getOperand(0), II->getOperand(1)});
70     auto *Struct = Builder.CreateExtractValue(Load, {0});
71 
72     // If we have an offset from seeing a GEP earlier, use that. Otherwise, 0.
73     if (!Offset)
74       Offset = ConstantInt::get(Builder.getInt32Ty(), 0);
75     V = Builder.CreateInsertElement(Struct, V, Offset);
76   } else {
77     llvm_unreachable("Store to typed resource has invalid type");
78   }
79 
80   auto *Inst = Builder.CreateIntrinsic(
81       Builder.getVoidTy(), Intrinsic::dx_resource_store_typedbuffer,
82       {II->getOperand(0), II->getOperand(1), V});
83   SI->replaceAllUsesWith(Inst);
84 }
85 
86 static void createRawStore(IntrinsicInst *II, StoreInst *SI, Value *Offset) {
87   IRBuilder<> Builder(SI);
88 
89   if (!Offset)
90     Offset = ConstantInt::get(Builder.getInt32Ty(), 0);
91   Value *V = SI->getValueOperand();
92   // TODO: break up larger types
93   auto *Inst = Builder.CreateIntrinsic(
94       Builder.getVoidTy(), Intrinsic::dx_resource_store_rawbuffer,
95       {II->getOperand(0), II->getOperand(1), Offset, V});
96   SI->replaceAllUsesWith(Inst);
97 }
98 
99 static void createStoreIntrinsic(IntrinsicInst *II, StoreInst *SI,
100                                  Value *Offset, dxil::ResourceTypeInfo &RTI) {
101   switch (RTI.getResourceKind()) {
102   case dxil::ResourceKind::TypedBuffer:
103     return createTypedBufferStore(II, SI, Offset, RTI);
104   case dxil::ResourceKind::RawBuffer:
105   case dxil::ResourceKind::StructuredBuffer:
106     return createRawStore(II, SI, Offset);
107   case dxil::ResourceKind::Texture1D:
108   case dxil::ResourceKind::Texture2D:
109   case dxil::ResourceKind::Texture2DMS:
110   case dxil::ResourceKind::Texture3D:
111   case dxil::ResourceKind::TextureCube:
112   case dxil::ResourceKind::Texture1DArray:
113   case dxil::ResourceKind::Texture2DArray:
114   case dxil::ResourceKind::Texture2DMSArray:
115   case dxil::ResourceKind::TextureCubeArray:
116   case dxil::ResourceKind::FeedbackTexture2D:
117   case dxil::ResourceKind::FeedbackTexture2DArray:
118     report_fatal_error("DXIL Load not implemented yet",
119                        /*gen_crash_diag=*/false);
120     return;
121   case dxil::ResourceKind::CBuffer:
122   case dxil::ResourceKind::Sampler:
123   case dxil::ResourceKind::TBuffer:
124   case dxil::ResourceKind::RTAccelerationStructure:
125   case dxil::ResourceKind::Invalid:
126   case dxil::ResourceKind::NumEntries:
127     llvm_unreachable("Invalid resource kind for store");
128   }
129   llvm_unreachable("Unhandled case in switch");
130 }
131 
132 static void createTypedBufferLoad(IntrinsicInst *II, LoadInst *LI,
133                                   Value *Offset, dxil::ResourceTypeInfo &RTI) {
134   IRBuilder<> Builder(LI);
135   Type *ContainedType = RTI.getHandleTy()->getTypeParameter(0);
136   Type *LoadType = StructType::get(ContainedType, Builder.getInt1Ty());
137 
138   Value *V =
139       Builder.CreateIntrinsic(LoadType, Intrinsic::dx_resource_load_typedbuffer,
140                               {II->getOperand(0), II->getOperand(1)});
141   V = Builder.CreateExtractValue(V, {0});
142 
143   if (Offset)
144     V = Builder.CreateExtractElement(V, Offset);
145 
146   LI->replaceAllUsesWith(V);
147 }
148 
149 static void createRawLoad(IntrinsicInst *II, LoadInst *LI, Value *Offset) {
150   IRBuilder<> Builder(LI);
151   // TODO: break up larger types
152   Type *LoadType = StructType::get(LI->getType(), Builder.getInt1Ty());
153   if (!Offset)
154     Offset = ConstantInt::get(Builder.getInt32Ty(), 0);
155   Value *V =
156       Builder.CreateIntrinsic(LoadType, Intrinsic::dx_resource_load_rawbuffer,
157                               {II->getOperand(0), II->getOperand(1), Offset});
158   V = Builder.CreateExtractValue(V, {0});
159 
160   LI->replaceAllUsesWith(V);
161 }
162 
163 static void createLoadIntrinsic(IntrinsicInst *II, LoadInst *LI, Value *Offset,
164                                 dxil::ResourceTypeInfo &RTI) {
165   switch (RTI.getResourceKind()) {
166   case dxil::ResourceKind::TypedBuffer:
167     return createTypedBufferLoad(II, LI, Offset, RTI);
168   case dxil::ResourceKind::RawBuffer:
169   case dxil::ResourceKind::StructuredBuffer:
170     return createRawLoad(II, LI, Offset);
171   case dxil::ResourceKind::Texture1D:
172   case dxil::ResourceKind::Texture2D:
173   case dxil::ResourceKind::Texture2DMS:
174   case dxil::ResourceKind::Texture3D:
175   case dxil::ResourceKind::TextureCube:
176   case dxil::ResourceKind::Texture1DArray:
177   case dxil::ResourceKind::Texture2DArray:
178   case dxil::ResourceKind::Texture2DMSArray:
179   case dxil::ResourceKind::TextureCubeArray:
180   case dxil::ResourceKind::FeedbackTexture2D:
181   case dxil::ResourceKind::FeedbackTexture2DArray:
182   case dxil::ResourceKind::CBuffer:
183   case dxil::ResourceKind::TBuffer:
184     // TODO: handle these
185     return;
186   case dxil::ResourceKind::Sampler:
187   case dxil::ResourceKind::RTAccelerationStructure:
188   case dxil::ResourceKind::Invalid:
189   case dxil::ResourceKind::NumEntries:
190     llvm_unreachable("Invalid resource kind for load");
191   }
192   llvm_unreachable("Unhandled case in switch");
193 }
194 
195 static void replaceAccess(IntrinsicInst *II, dxil::ResourceTypeInfo &RTI) {
196   // Process users keeping track of indexing accumulated from GEPs.
197   struct AccessAndOffset {
198     User *Access;
199     Value *Offset;
200   };
201   SmallVector<AccessAndOffset> Worklist;
202   for (User *U : II->users())
203     Worklist.push_back({U, nullptr});
204 
205   SmallVector<Instruction *> DeadInsts;
206   while (!Worklist.empty()) {
207     AccessAndOffset Current = Worklist.back();
208     Worklist.pop_back();
209 
210     if (auto *GEP = dyn_cast<GetElementPtrInst>(Current.Access)) {
211       IRBuilder<> Builder(GEP);
212 
213       Value *Offset = calculateGEPOffset(GEP, Current.Offset, RTI);
214       for (User *U : GEP->users())
215         Worklist.push_back({U, Offset});
216       DeadInsts.push_back(GEP);
217 
218     } else if (auto *SI = dyn_cast<StoreInst>(Current.Access)) {
219       assert(SI->getValueOperand() != II && "Pointer escaped!");
220       createStoreIntrinsic(II, SI, Current.Offset, RTI);
221       DeadInsts.push_back(SI);
222 
223     } else if (auto *LI = dyn_cast<LoadInst>(Current.Access)) {
224       createLoadIntrinsic(II, LI, Current.Offset, RTI);
225       DeadInsts.push_back(LI);
226 
227     } else
228       llvm_unreachable("Unhandled instruction - pointer escaped?");
229   }
230 
231   // Traverse the now-dead instructions in RPO and remove them.
232   for (Instruction *Dead : llvm::reverse(DeadInsts))
233     Dead->eraseFromParent();
234   II->eraseFromParent();
235 }
236 
237 static bool transformResourcePointers(Function &F, DXILResourceTypeMap &DRTM) {
238   bool Changed = false;
239   SmallVector<std::pair<IntrinsicInst *, dxil::ResourceTypeInfo>> Resources;
240   for (BasicBlock &BB : F)
241     for (Instruction &I : BB)
242       if (auto *II = dyn_cast<IntrinsicInst>(&I))
243         if (II->getIntrinsicID() == Intrinsic::dx_resource_getpointer) {
244           auto *HandleTy = cast<TargetExtType>(II->getArgOperand(0)->getType());
245           Resources.emplace_back(II, DRTM[HandleTy]);
246         }
247 
248   for (auto &[II, RI] : Resources)
249     replaceAccess(II, RI);
250 
251   return Changed;
252 }
253 
254 PreservedAnalyses DXILResourceAccess::run(Function &F,
255                                           FunctionAnalysisManager &FAM) {
256   auto &MAMProxy = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F);
257   DXILResourceTypeMap *DRTM =
258       MAMProxy.getCachedResult<DXILResourceTypeAnalysis>(*F.getParent());
259   assert(DRTM && "DXILResourceTypeAnalysis must be available");
260 
261   bool MadeChanges = transformResourcePointers(F, *DRTM);
262   if (!MadeChanges)
263     return PreservedAnalyses::all();
264 
265   PreservedAnalyses PA;
266   PA.preserve<DXILResourceTypeAnalysis>();
267   PA.preserve<DominatorTreeAnalysis>();
268   return PA;
269 }
270 
271 namespace {
272 class DXILResourceAccessLegacy : public FunctionPass {
273 public:
274   bool runOnFunction(Function &F) override {
275     DXILResourceTypeMap &DRTM =
276         getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap();
277 
278     return transformResourcePointers(F, DRTM);
279   }
280   StringRef getPassName() const override { return "DXIL Resource Access"; }
281   DXILResourceAccessLegacy() : FunctionPass(ID) {}
282 
283   static char ID; // Pass identification.
284   void getAnalysisUsage(llvm::AnalysisUsage &AU) const override {
285     AU.addRequired<DXILResourceTypeWrapperPass>();
286     AU.addPreserved<DominatorTreeWrapperPass>();
287   }
288 };
289 char DXILResourceAccessLegacy::ID = 0;
290 } // end anonymous namespace
291 
// Register the legacy pass under DEBUG_TYPE ("dxil-resource-access") with the
// description "DXIL Resource Access". The registration declares that the pass
// depends on DXILResourceTypeWrapperPass.
// NOTE(review): the trailing `false, false` are presumably the CFG-only and
// is-analysis flags of INITIALIZE_PASS_* (see llvm/PassSupport.h) — confirm.
INITIALIZE_PASS_BEGIN(DXILResourceAccessLegacy, DEBUG_TYPE,
                      "DXIL Resource Access", false, false)
INITIALIZE_PASS_DEPENDENCY(DXILResourceTypeWrapperPass)
INITIALIZE_PASS_END(DXILResourceAccessLegacy, DEBUG_TYPE,
                    "DXIL Resource Access", false, false)
297 
298 FunctionPass *llvm::createDXILResourceAccessLegacyPass() {
299   return new DXILResourceAccessLegacy();
300 }
301