xref: /llvm-project/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h (revision 83c1d003118a2cb8136fe49e2ec43958c93d9d6b)
1 //===-- SPIRVDuplicatesTracker.h - SPIR-V Duplicates Tracker ----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // General infrastructure for keeping track of the values that according to
10 // the SPIR-V binary layout should be global to the whole module.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef LLVM_LIB_TARGET_SPIRV_SPIRVDUPLICATESTRACKER_H
15 #define LLVM_LIB_TARGET_SPIRV_SPIRVDUPLICATESTRACKER_H
16 
17 #include "MCTargetDesc/SPIRVBaseInfo.h"
18 #include "MCTargetDesc/SPIRVMCTargetDesc.h"
19 #include "SPIRVUtils.h"
20 #include "llvm/ADT/DenseMap.h"
21 #include "llvm/ADT/MapVector.h"
22 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
23 #include "llvm/CodeGen/MachineModuleInfo.h"
24 
25 #include <type_traits>
26 
27 namespace llvm {
28 namespace SPIRV {
29 class SPIRVInstrInfo;
30 // NOTE: using MapVector instead of DenseMap because it helps getting
31 // everything ordered in a stable manner for a price of extra (NumKeys)*PtrSize
32 // memory and expensive removals which do not happen anyway.
33 class DTSortableEntry : public MapVector<const MachineFunction *, Register> {
34   SmallVector<DTSortableEntry *, 2> Deps;
35 
36   struct FlagsTy {
37     unsigned IsFunc : 1;
38     unsigned IsGV : 1;
39     unsigned IsConst : 1;
40     // NOTE: bit-field default init is a C++20 feature.
41     FlagsTy() : IsFunc(0), IsGV(0), IsConst(0) {}
42   };
43   FlagsTy Flags;
44 
45 public:
46   // Common hoisting utility doesn't support function, because their hoisting
47   // require hoisting of params as well.
48   bool getIsFunc() const { return Flags.IsFunc; }
49   bool getIsGV() const { return Flags.IsGV; }
50   bool getIsConst() const { return Flags.IsConst; }
51   void setIsFunc(bool V) { Flags.IsFunc = V; }
52   void setIsGV(bool V) { Flags.IsGV = V; }
53   void setIsConst(bool V) { Flags.IsConst = V; }
54 
55   const SmallVector<DTSortableEntry *, 2> &getDeps() const { return Deps; }
56   void addDep(DTSortableEntry *E) { Deps.push_back(E); }
57 };
58 
// Discriminator for the kinds of "special" type descriptors. The numeric
// values are spelled out so the value of each kind is visible at a glance.
enum SpecialTypeKind {
  STK_Empty = 0,
  STK_Image = 1,
  STK_SampledImage = 2,
  STK_Sampler = 3,
  STK_Pipe = 4,
  STK_DeviceEvent = 5,
  STK_Pointer = 6,
  // NOTE(review): deliberately -1 rather than "one past the end"; its users
  // are not visible in this header.
  STK_Last = -1
};
69 
70 using SpecialTypeDescriptor = std::tuple<const Type *, unsigned, unsigned>;
71 
// Packs the attributes of an image type into a single unsigned word.
// `Flags` gives field-wise access and `Val` exposes the same storage as one
// integer.
// NOTE(review): reading `Val` after writing `Flags` relies on union type
// punning as implemented by the supported compilers.
union ImageAttrs {
  struct BitFlags {
    unsigned Dim : 3;
    unsigned Depth : 2;
    unsigned Arrayed : 1;
    unsigned MS : 1;
    unsigned Sampled : 2;
    unsigned ImageFormat : 6;
    unsigned AQ : 2;
  } Flags;
  unsigned Val;

  // Stores each argument in the bit-field of the same name; values wider than
  // a field are truncated to the field's width. The whole word is zeroed
  // first so that bits not covered by any field are deterministic.
  ImageAttrs(unsigned Dim, unsigned Depth, unsigned Arrayed, unsigned MS,
             unsigned Sampled, unsigned ImageFormat, unsigned AQ = 0)
      : Val(0) {
    Flags.Dim = Dim;
    Flags.Depth = Depth;
    Flags.Arrayed = Arrayed;
    Flags.MS = MS;
    Flags.Sampled = Sampled;
    Flags.ImageFormat = ImageFormat;
    Flags.AQ = AQ;
  }
};
96 
97 inline SpecialTypeDescriptor
98 make_descr_image(const Type *SampledTy, unsigned Dim, unsigned Depth,
99                  unsigned Arrayed, unsigned MS, unsigned Sampled,
100                  unsigned ImageFormat, unsigned AQ = 0) {
101   return std::make_tuple(
102       SampledTy,
103       ImageAttrs(Dim, Depth, Arrayed, MS, Sampled, ImageFormat, AQ).Val,
104       SpecialTypeKind::STK_Image);
105 }
106 
107 inline SpecialTypeDescriptor
108 make_descr_sampled_image(const Type *SampledTy, const MachineInstr *ImageTy) {
109   assert(ImageTy->getOpcode() == SPIRV::OpTypeImage);
110   unsigned AC = AccessQualifier::AccessQualifier::None;
111   if (ImageTy->getNumOperands() > 8)
112     AC = ImageTy->getOperand(8).getImm();
113   return std::make_tuple(
114       SampledTy,
115       ImageAttrs(
116           ImageTy->getOperand(2).getImm(), ImageTy->getOperand(3).getImm(),
117           ImageTy->getOperand(4).getImm(), ImageTy->getOperand(5).getImm(),
118           ImageTy->getOperand(6).getImm(), ImageTy->getOperand(7).getImm(), AC)
119           .Val,
120       SpecialTypeKind::STK_SampledImage);
121 }
122 
123 inline SpecialTypeDescriptor make_descr_sampler() {
124   return std::make_tuple(nullptr, 0U, SpecialTypeKind::STK_Sampler);
125 }
126 
127 inline SpecialTypeDescriptor make_descr_pipe(uint8_t AQ) {
128   return std::make_tuple(nullptr, AQ, SpecialTypeKind::STK_Pipe);
129 }
130 
131 inline SpecialTypeDescriptor make_descr_event() {
132   return std::make_tuple(nullptr, 0U, SpecialTypeKind::STK_DeviceEvent);
133 }
134 
135 inline SpecialTypeDescriptor make_descr_pointee(const Type *ElementType,
136                                                 unsigned AddressSpace) {
137   return std::make_tuple(ElementType, AddressSpace,
138                          SpecialTypeKind::STK_Pointer);
139 }
140 } // namespace SPIRV
141 
142 template <typename KeyTy> class SPIRVDuplicatesTrackerBase {
143 public:
144   // NOTE: using MapVector instead of DenseMap helps getting everything ordered
145   // in a stable manner for a price of extra (NumKeys)*PtrSize memory and
146   // expensive removals which don't happen anyway.
147   using StorageTy = MapVector<KeyTy, SPIRV::DTSortableEntry>;
148 
149 private:
150   StorageTy Storage;
151 
152 public:
153   void add(KeyTy V, const MachineFunction *MF, Register R) {
154     if (find(V, MF).isValid())
155       return;
156 
157     Storage[V][MF] = R;
158     if (std::is_same<Function,
159                      typename std::remove_const<
160                          typename std::remove_pointer<KeyTy>::type>::type>() ||
161         std::is_same<Argument,
162                      typename std::remove_const<
163                          typename std::remove_pointer<KeyTy>::type>::type>())
164       Storage[V].setIsFunc(true);
165     if (std::is_same<GlobalVariable,
166                      typename std::remove_const<
167                          typename std::remove_pointer<KeyTy>::type>::type>())
168       Storage[V].setIsGV(true);
169     if (std::is_same<Constant,
170                      typename std::remove_const<
171                          typename std::remove_pointer<KeyTy>::type>::type>())
172       Storage[V].setIsConst(true);
173   }
174 
175   Register find(KeyTy V, const MachineFunction *MF) const {
176     auto iter = Storage.find(V);
177     if (iter != Storage.end()) {
178       auto Map = iter->second;
179       auto iter2 = Map.find(MF);
180       if (iter2 != Map.end())
181         return iter2->second;
182     }
183     return Register();
184   }
185 
186   const StorageTy &getAllUses() const { return Storage; }
187 
188 private:
189   StorageTy &getAllUses() { return Storage; }
190 
191   // The friend class needs to have access to the internal storage
192   // to be able to build dependency graph, can't declare only one
193   // function a 'friend' due to the incomplete declaration at this point
194   // and mutual dependency problems.
195   friend class SPIRVGeneralDuplicatesTracker;
196 };
197 
198 template <typename T>
199 class SPIRVDuplicatesTracker : public SPIRVDuplicatesTrackerBase<const T *> {};
200 
201 template <>
202 class SPIRVDuplicatesTracker<SPIRV::SpecialTypeDescriptor>
203     : public SPIRVDuplicatesTrackerBase<SPIRV::SpecialTypeDescriptor> {};
204 
205 class SPIRVGeneralDuplicatesTracker {
206   SPIRVDuplicatesTracker<Type> TT;
207   SPIRVDuplicatesTracker<Constant> CT;
208   SPIRVDuplicatesTracker<GlobalVariable> GT;
209   SPIRVDuplicatesTracker<Function> FT;
210   SPIRVDuplicatesTracker<Argument> AT;
211   SPIRVDuplicatesTracker<MachineInstr> MT;
212   SPIRVDuplicatesTracker<SPIRV::SpecialTypeDescriptor> ST;
213 
214 public:
215   void add(const Type *Ty, const MachineFunction *MF, Register R) {
216     TT.add(unifyPtrType(Ty), MF, R);
217   }
218 
219   void add(const Type *PointeeTy, unsigned AddressSpace,
220            const MachineFunction *MF, Register R) {
221     ST.add(SPIRV::make_descr_pointee(unifyPtrType(PointeeTy), AddressSpace), MF,
222            R);
223   }
224 
225   void add(const Constant *C, const MachineFunction *MF, Register R) {
226     CT.add(C, MF, R);
227   }
228 
229   void add(const GlobalVariable *GV, const MachineFunction *MF, Register R) {
230     GT.add(GV, MF, R);
231   }
232 
233   void add(const Function *F, const MachineFunction *MF, Register R) {
234     FT.add(F, MF, R);
235   }
236 
237   void add(const Argument *Arg, const MachineFunction *MF, Register R) {
238     AT.add(Arg, MF, R);
239   }
240 
241   void add(const MachineInstr *MI, const MachineFunction *MF, Register R) {
242     MT.add(MI, MF, R);
243   }
244 
245   void add(const SPIRV::SpecialTypeDescriptor &TD, const MachineFunction *MF,
246            Register R) {
247     ST.add(TD, MF, R);
248   }
249 
250   Register find(const Type *Ty, const MachineFunction *MF) {
251     return TT.find(unifyPtrType(Ty), MF);
252   }
253 
254   Register find(const Type *PointeeTy, unsigned AddressSpace,
255                 const MachineFunction *MF) {
256     return ST.find(
257         SPIRV::make_descr_pointee(unifyPtrType(PointeeTy), AddressSpace), MF);
258   }
259 
260   Register find(const Constant *C, const MachineFunction *MF) {
261     return CT.find(const_cast<Constant *>(C), MF);
262   }
263 
264   Register find(const GlobalVariable *GV, const MachineFunction *MF) {
265     return GT.find(const_cast<GlobalVariable *>(GV), MF);
266   }
267 
268   Register find(const Function *F, const MachineFunction *MF) {
269     return FT.find(const_cast<Function *>(F), MF);
270   }
271 
272   Register find(const Argument *Arg, const MachineFunction *MF) {
273     return AT.find(const_cast<Argument *>(Arg), MF);
274   }
275 
276   Register find(const MachineInstr *MI, const MachineFunction *MF) {
277     return MT.find(const_cast<MachineInstr *>(MI), MF);
278   }
279 
280   Register find(const SPIRV::SpecialTypeDescriptor &TD,
281                 const MachineFunction *MF) {
282     return ST.find(TD, MF);
283   }
284 
285   const SPIRVDuplicatesTracker<Type> *getTypes() { return &TT; }
286 };
287 } // namespace llvm
288 #endif // LLVM_LIB_TARGET_SPIRV_SPIRVDUPLICATESTRACKER_H
289