xref: /llvm-project/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (revision d459784cbea334d167b2dca48e0c26115c68e5d3)
1 //===-- SPIRVGlobalRegistry.cpp - SPIR-V Global Registry --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the implementation of the SPIRVGlobalRegistry class,
10 // which is used to maintain rich type information required for SPIR-V even
11 // after lowering from LLVM IR to GMIR. It can convert an llvm::Type into
12 // an OpTypeXXX instruction, and map it to a virtual register. Also it builds
13 // and supports consistency of constants and global variables.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "SPIRVGlobalRegistry.h"
18 #include "SPIRV.h"
19 #include "SPIRVBuiltins.h"
20 #include "SPIRVSubtarget.h"
21 #include "SPIRVTargetMachine.h"
22 #include "SPIRVUtils.h"
23 #include "llvm/ADT/APInt.h"
24 #include "llvm/IR/Constants.h"
25 #include "llvm/IR/Type.h"
26 #include "llvm/Support/Casting.h"
27 #include <cassert>
28 #include <functional>
29 
30 using namespace llvm;
31 
32 inline unsigned typeToAddressSpace(const Type *Ty) {
33   if (auto PType = dyn_cast<TypedPointerType>(Ty))
34     return PType->getAddressSpace();
35   if (auto PType = dyn_cast<PointerType>(Ty))
36     return PType->getAddressSpace();
37   if (auto *ExtTy = dyn_cast<TargetExtType>(Ty);
38       ExtTy && isTypedPointerWrapper(ExtTy))
39     return ExtTy->getIntParameter(0);
40   report_fatal_error("Unable to convert LLVM type to SPIRVType", true);
41 }
42 
43 SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize)
44     : PointerSize(PointerSize), Bound(0) {}
45 
46 SPIRVType *SPIRVGlobalRegistry::assignIntTypeToVReg(unsigned BitWidth,
47                                                     Register VReg,
48                                                     MachineInstr &I,
49                                                     const SPIRVInstrInfo &TII) {
50   SPIRVType *SpirvType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
51   assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF);
52   return SpirvType;
53 }
54 
55 SPIRVType *
56 SPIRVGlobalRegistry::assignFloatTypeToVReg(unsigned BitWidth, Register VReg,
57                                            MachineInstr &I,
58                                            const SPIRVInstrInfo &TII) {
59   SPIRVType *SpirvType = getOrCreateSPIRVFloatType(BitWidth, I, TII);
60   assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF);
61   return SpirvType;
62 }
63 
64 SPIRVType *SPIRVGlobalRegistry::assignVectTypeToVReg(
65     SPIRVType *BaseType, unsigned NumElements, Register VReg, MachineInstr &I,
66     const SPIRVInstrInfo &TII) {
67   SPIRVType *SpirvType =
68       getOrCreateSPIRVVectorType(BaseType, NumElements, I, TII);
69   assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF);
70   return SpirvType;
71 }
72 
73 SPIRVType *SPIRVGlobalRegistry::assignTypeToVReg(
74     const Type *Type, Register VReg, MachineIRBuilder &MIRBuilder,
75     SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
76   SPIRVType *SpirvType =
77       getOrCreateSPIRVType(Type, MIRBuilder, AccessQual, EmitIR);
78   assignSPIRVTypeToVReg(SpirvType, VReg, MIRBuilder.getMF());
79   return SpirvType;
80 }
81 
82 void SPIRVGlobalRegistry::assignSPIRVTypeToVReg(SPIRVType *SpirvType,
83                                                 Register VReg,
84                                                 const MachineFunction &MF) {
85   VRegToTypeMap[&MF][VReg] = SpirvType;
86 }
87 
88 static Register createTypeVReg(MachineRegisterInfo &MRI) {
89   auto Res = MRI.createGenericVirtualRegister(LLT::scalar(64));
90   MRI.setRegClass(Res, &SPIRV::TYPERegClass);
91   return Res;
92 }
93 
94 inline Register createTypeVReg(MachineIRBuilder &MIRBuilder) {
95   return createTypeVReg(MIRBuilder.getMF().getRegInfo());
96 }
97 
98 SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
99   return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
100     return MIRBuilder.buildInstr(SPIRV::OpTypeBool)
101         .addDef(createTypeVReg(MIRBuilder));
102   });
103 }
104 
105 unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth(unsigned Width) const {
106   if (Width > 64)
107     report_fatal_error("Unsupported integer width!");
108   const SPIRVSubtarget &ST = cast<SPIRVSubtarget>(CurMF->getSubtarget());
109   if (ST.canUseExtension(
110           SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers))
111     return Width;
112   if (Width <= 8)
113     Width = 8;
114   else if (Width <= 16)
115     Width = 16;
116   else if (Width <= 32)
117     Width = 32;
118   else
119     Width = 64;
120   return Width;
121 }
122 
123 SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
124                                              MachineIRBuilder &MIRBuilder,
125                                              bool IsSigned) {
126   Width = adjustOpTypeIntWidth(Width);
127   const SPIRVSubtarget &ST =
128       cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
129   return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
130     if (ST.canUseExtension(
131             SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {
132       MIRBuilder.buildInstr(SPIRV::OpExtension)
133           .addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers);
134       MIRBuilder.buildInstr(SPIRV::OpCapability)
135           .addImm(SPIRV::Capability::ArbitraryPrecisionIntegersINTEL);
136     }
137     return MIRBuilder.buildInstr(SPIRV::OpTypeInt)
138         .addDef(createTypeVReg(MIRBuilder))
139         .addImm(Width)
140         .addImm(IsSigned ? 1 : 0);
141   });
142 }
143 
144 SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
145                                                MachineIRBuilder &MIRBuilder) {
146   return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
147     return MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
148         .addDef(createTypeVReg(MIRBuilder))
149         .addImm(Width);
150   });
151 }
152 
153 SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) {
154   return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
155     return MIRBuilder.buildInstr(SPIRV::OpTypeVoid)
156         .addDef(createTypeVReg(MIRBuilder));
157   });
158 }
159 
160 void SPIRVGlobalRegistry::invalidateMachineInstr(MachineInstr *MI) {
161   // TODO:
162   // - take into account duplicate tracker case which is a known issue,
163   // - review other data structure wrt. possible issues related to removal
164   //   of a machine instruction during instruction selection.
165   const MachineFunction *MF = MI->getParent()->getParent();
166   auto It = LastInsertedTypeMap.find(MF);
167   if (It == LastInsertedTypeMap.end())
168     return;
169   if (It->second == MI)
170     LastInsertedTypeMap.erase(MF);
171 }
172 
173 SPIRVType *SPIRVGlobalRegistry::createOpType(
174     MachineIRBuilder &MIRBuilder,
175     std::function<MachineInstr *(MachineIRBuilder &)> Op) {
176   auto oldInsertPoint = MIRBuilder.getInsertPt();
177   MachineBasicBlock *OldMBB = &MIRBuilder.getMBB();
178   MachineBasicBlock *NewMBB = &*MIRBuilder.getMF().begin();
179 
180   auto LastInsertedType = LastInsertedTypeMap.find(CurMF);
181   if (LastInsertedType != LastInsertedTypeMap.end()) {
182     auto It = LastInsertedType->second->getIterator();
183     // It might happen that this instruction was removed from the first MBB,
184     // hence the Parent's check.
185     MachineBasicBlock::iterator InsertAt;
186     if (It->getParent() != NewMBB)
187       InsertAt = oldInsertPoint->getParent() == NewMBB
188                      ? oldInsertPoint
189                      : getInsertPtValidEnd(NewMBB);
190     else if (It->getNextNode())
191       InsertAt = It->getNextNode()->getIterator();
192     else
193       InsertAt = getInsertPtValidEnd(NewMBB);
194     MIRBuilder.setInsertPt(*NewMBB, InsertAt);
195   } else {
196     MIRBuilder.setInsertPt(*NewMBB, NewMBB->begin());
197     auto Result = LastInsertedTypeMap.try_emplace(CurMF, nullptr);
198     assert(Result.second);
199     LastInsertedType = Result.first;
200   }
201 
202   MachineInstr *Type = Op(MIRBuilder);
203   // We expect all users of this function to insert definitions at the insertion
204   // point set above that is always the first MBB.
205   assert(Type->getParent() == NewMBB);
206   LastInsertedType->second = Type;
207 
208   MIRBuilder.setInsertPt(*OldMBB, oldInsertPoint);
209   return Type;
210 }
211 
212 SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems,
213                                                 SPIRVType *ElemType,
214                                                 MachineIRBuilder &MIRBuilder) {
215   auto EleOpc = ElemType->getOpcode();
216   (void)EleOpc;
217   assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat ||
218           EleOpc == SPIRV::OpTypeBool) &&
219          "Invalid vector element type");
220 
221   return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
222     return MIRBuilder.buildInstr(SPIRV::OpTypeVector)
223         .addDef(createTypeVReg(MIRBuilder))
224         .addUse(getSPIRVTypeID(ElemType))
225         .addImm(NumElems);
226   });
227 }
228 
229 std::tuple<Register, ConstantInt *, bool, unsigned>
230 SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val, SPIRVType *SpvType,
231                                             MachineIRBuilder *MIRBuilder,
232                                             MachineInstr *I,
233                                             const SPIRVInstrInfo *TII) {
234   assert(SpvType);
235   const IntegerType *LLVMIntTy =
236       cast<IntegerType>(getTypeForSPIRVType(SpvType));
237   unsigned BitWidth = getScalarOrVectorBitWidth(SpvType);
238   bool NewInstr = false;
239   // Find a constant in DT or build a new one.
240   ConstantInt *CI = ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val);
241   Register Res = DT.find(CI, CurMF);
242   if (!Res.isValid()) {
243     Res =
244         CurMF->getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth));
245     CurMF->getRegInfo().setRegClass(Res, &SPIRV::iIDRegClass);
246     if (MIRBuilder)
247       assignTypeToVReg(LLVMIntTy, Res, *MIRBuilder);
248     else
249       assignIntTypeToVReg(BitWidth, Res, *I, *TII);
250     DT.add(CI, CurMF, Res);
251     NewInstr = true;
252   }
253   return std::make_tuple(Res, CI, NewInstr, BitWidth);
254 }
255 
256 std::tuple<Register, ConstantFP *, bool, unsigned>
257 SPIRVGlobalRegistry::getOrCreateConstFloatReg(APFloat Val, SPIRVType *SpvType,
258                                               MachineIRBuilder *MIRBuilder,
259                                               MachineInstr *I,
260                                               const SPIRVInstrInfo *TII) {
261   assert(SpvType);
262   LLVMContext &Ctx = CurMF->getFunction().getContext();
263   const Type *LLVMFloatTy = getTypeForSPIRVType(SpvType);
264   unsigned BitWidth = getScalarOrVectorBitWidth(SpvType);
265   bool NewInstr = false;
266   // Find a constant in DT or build a new one.
267   auto *const CI = ConstantFP::get(Ctx, Val);
268   Register Res = DT.find(CI, CurMF);
269   if (!Res.isValid()) {
270     Res =
271         CurMF->getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth));
272     CurMF->getRegInfo().setRegClass(Res, &SPIRV::fIDRegClass);
273     if (MIRBuilder)
274       assignTypeToVReg(LLVMFloatTy, Res, *MIRBuilder);
275     else
276       assignFloatTypeToVReg(BitWidth, Res, *I, *TII);
277     DT.add(CI, CurMF, Res);
278     NewInstr = true;
279   }
280   return std::make_tuple(Res, CI, NewInstr, BitWidth);
281 }
282 
283 Register SPIRVGlobalRegistry::getOrCreateConstFP(APFloat Val, MachineInstr &I,
284                                                  SPIRVType *SpvType,
285                                                  const SPIRVInstrInfo &TII,
286                                                  bool ZeroAsNull) {
287   assert(SpvType);
288   ConstantFP *CI;
289   Register Res;
290   bool New;
291   unsigned BitWidth;
292   std::tie(Res, CI, New, BitWidth) =
293       getOrCreateConstFloatReg(Val, SpvType, nullptr, &I, &TII);
294   // If we have found Res register which is defined by the passed G_CONSTANT
295   // machine instruction, a new constant instruction should be created.
296   if (!New && (!I.getOperand(0).isReg() || Res != I.getOperand(0).getReg()))
297     return Res;
298   MachineIRBuilder MIRBuilder(I);
299   createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
300     MachineInstrBuilder MIB;
301     // In OpenCL OpConstantNull - Scalar floating point: +0.0 (all bits 0)
302     if (Val.isPosZero() && ZeroAsNull) {
303       MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
304                 .addDef(Res)
305                 .addUse(getSPIRVTypeID(SpvType));
306     } else {
307       MIB = MIRBuilder.buildInstr(SPIRV::OpConstantF)
308                 .addDef(Res)
309                 .addUse(getSPIRVTypeID(SpvType));
310       addNumImm(
311           APInt(BitWidth, CI->getValueAPF().bitcastToAPInt().getZExtValue()),
312           MIB);
313     }
314     const auto &ST = CurMF->getSubtarget();
315     constrainSelectedInstRegOperands(
316         *MIB, *ST.getInstrInfo(), *ST.getRegisterInfo(), *ST.getRegBankInfo());
317     return MIB;
318   });
319   return Res;
320 }
321 
322 Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
323                                                   SPIRVType *SpvType,
324                                                   const SPIRVInstrInfo &TII,
325                                                   bool ZeroAsNull) {
326   assert(SpvType);
327   ConstantInt *CI;
328   Register Res;
329   bool New;
330   unsigned BitWidth;
331   std::tie(Res, CI, New, BitWidth) =
332       getOrCreateConstIntReg(Val, SpvType, nullptr, &I, &TII);
333   // If we have found Res register which is defined by the passed G_CONSTANT
334   // machine instruction, a new constant instruction should be created.
335   if (!New && (!I.getOperand(0).isReg() || Res != I.getOperand(0).getReg()))
336     return Res;
337 
338   MachineIRBuilder MIRBuilder(I);
339   createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
340     MachineInstrBuilder MIB;
341     if (Val || !ZeroAsNull) {
342       MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI)
343                 .addDef(Res)
344                 .addUse(getSPIRVTypeID(SpvType));
345       addNumImm(APInt(BitWidth, Val), MIB);
346     } else {
347       MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
348                 .addDef(Res)
349                 .addUse(getSPIRVTypeID(SpvType));
350     }
351     const auto &ST = CurMF->getSubtarget();
352     constrainSelectedInstRegOperands(
353         *MIB, *ST.getInstrInfo(), *ST.getRegisterInfo(), *ST.getRegBankInfo());
354     return MIB;
355   });
356   return Res;
357 }
358 
359 Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val,
360                                                MachineIRBuilder &MIRBuilder,
361                                                SPIRVType *SpvType, bool EmitIR,
362                                                bool ZeroAsNull) {
363   assert(SpvType);
364   auto &MF = MIRBuilder.getMF();
365   const IntegerType *LLVMIntTy =
366       cast<IntegerType>(getTypeForSPIRVType(SpvType));
367   // Find a constant in DT or build a new one.
368   const auto ConstInt =
369       ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val);
370   Register Res = DT.find(ConstInt, &MF);
371   if (!Res.isValid()) {
372     unsigned BitWidth = getScalarOrVectorBitWidth(SpvType);
373     LLT LLTy = LLT::scalar(BitWidth);
374     Res = MF.getRegInfo().createGenericVirtualRegister(LLTy);
375     MF.getRegInfo().setRegClass(Res, &SPIRV::iIDRegClass);
376     assignTypeToVReg(LLVMIntTy, Res, MIRBuilder,
377                      SPIRV::AccessQualifier::ReadWrite, EmitIR);
378     DT.add(ConstInt, &MIRBuilder.getMF(), Res);
379     if (EmitIR) {
380       MIRBuilder.buildConstant(Res, *ConstInt);
381     } else {
382       Register SpvTypeReg = getSPIRVTypeID(SpvType);
383       createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
384         MachineInstrBuilder MIB;
385         if (Val || !ZeroAsNull) {
386           MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI)
387                     .addDef(Res)
388                     .addUse(SpvTypeReg);
389           addNumImm(APInt(BitWidth, Val), MIB);
390         } else {
391           MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
392                     .addDef(Res)
393                     .addUse(SpvTypeReg);
394         }
395         const auto &Subtarget = CurMF->getSubtarget();
396         constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
397                                          *Subtarget.getRegisterInfo(),
398                                          *Subtarget.getRegBankInfo());
399         return MIB;
400       });
401     }
402   }
403   return Res;
404 }
405 
406 Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val,
407                                               MachineIRBuilder &MIRBuilder,
408                                               SPIRVType *SpvType) {
409   auto &MF = MIRBuilder.getMF();
410   auto &Ctx = MF.getFunction().getContext();
411   if (!SpvType) {
412     const Type *LLVMFPTy = Type::getFloatTy(Ctx);
413     SpvType = getOrCreateSPIRVType(LLVMFPTy, MIRBuilder);
414   }
415   // Find a constant in DT or build a new one.
416   const auto ConstFP = ConstantFP::get(Ctx, Val);
417   Register Res = DT.find(ConstFP, &MF);
418   if (!Res.isValid()) {
419     Res = MF.getRegInfo().createGenericVirtualRegister(
420         LLT::scalar(getScalarOrVectorBitWidth(SpvType)));
421     MF.getRegInfo().setRegClass(Res, &SPIRV::fIDRegClass);
422     assignSPIRVTypeToVReg(SpvType, Res, MF);
423     DT.add(ConstFP, &MF, Res);
424     createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
425       MachineInstrBuilder MIB;
426       MIB = MIRBuilder.buildInstr(SPIRV::OpConstantF)
427                 .addDef(Res)
428                 .addUse(getSPIRVTypeID(SpvType));
429       addNumImm(ConstFP->getValueAPF().bitcastToAPInt(), MIB);
430       return MIB;
431     });
432   }
433 
434   return Res;
435 }
436 
437 Register SPIRVGlobalRegistry::getOrCreateBaseRegister(
438     Constant *Val, MachineInstr &I, SPIRVType *SpvType,
439     const SPIRVInstrInfo &TII, unsigned BitWidth, bool ZeroAsNull) {
440   SPIRVType *Type = SpvType;
441   if (SpvType->getOpcode() == SPIRV::OpTypeVector ||
442       SpvType->getOpcode() == SPIRV::OpTypeArray) {
443     auto EleTypeReg = SpvType->getOperand(1).getReg();
444     Type = getSPIRVTypeForVReg(EleTypeReg);
445   }
446   if (Type->getOpcode() == SPIRV::OpTypeFloat) {
447     SPIRVType *SpvBaseType = getOrCreateSPIRVFloatType(BitWidth, I, TII);
448     return getOrCreateConstFP(dyn_cast<ConstantFP>(Val)->getValue(), I,
449                               SpvBaseType, TII, ZeroAsNull);
450   }
451   assert(Type->getOpcode() == SPIRV::OpTypeInt);
452   SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
453   return getOrCreateConstInt(Val->getUniqueInteger().getZExtValue(), I,
454                              SpvBaseType, TII, ZeroAsNull);
455 }
456 
457 Register SPIRVGlobalRegistry::getOrCreateCompositeOrNull(
458     Constant *Val, MachineInstr &I, SPIRVType *SpvType,
459     const SPIRVInstrInfo &TII, Constant *CA, unsigned BitWidth,
460     unsigned ElemCnt, bool ZeroAsNull) {
461   // Find a constant vector or array in DT or build a new one.
462   Register Res = DT.find(CA, CurMF);
463   // If no values are attached, the composite is null constant.
464   bool IsNull = Val->isNullValue() && ZeroAsNull;
465   if (!Res.isValid()) {
466     // SpvScalConst should be created before SpvVecConst to avoid undefined ID
467     // error on validation.
468     // TODO: can moved below once sorting of types/consts/defs is implemented.
469     Register SpvScalConst;
470     if (!IsNull)
471       SpvScalConst =
472           getOrCreateBaseRegister(Val, I, SpvType, TII, BitWidth, ZeroAsNull);
473 
474     LLT LLTy = LLT::scalar(64);
475     Register SpvVecConst =
476         CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
477     CurMF->getRegInfo().setRegClass(SpvVecConst, getRegClass(SpvType));
478     assignSPIRVTypeToVReg(SpvType, SpvVecConst, *CurMF);
479     DT.add(CA, CurMF, SpvVecConst);
480     MachineIRBuilder MIRBuilder(I);
481     createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
482       MachineInstrBuilder MIB;
483       if (!IsNull) {
484         MIB = MIRBuilder.buildInstr(SPIRV::OpConstantComposite)
485                   .addDef(SpvVecConst)
486                   .addUse(getSPIRVTypeID(SpvType));
487         for (unsigned i = 0; i < ElemCnt; ++i)
488           MIB.addUse(SpvScalConst);
489       } else {
490         MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
491                   .addDef(SpvVecConst)
492                   .addUse(getSPIRVTypeID(SpvType));
493       }
494       const auto &Subtarget = CurMF->getSubtarget();
495       constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
496                                        *Subtarget.getRegisterInfo(),
497                                        *Subtarget.getRegBankInfo());
498       return MIB;
499     });
500     return SpvVecConst;
501   }
502   return Res;
503 }
504 
505 Register SPIRVGlobalRegistry::getOrCreateConstVector(uint64_t Val,
506                                                      MachineInstr &I,
507                                                      SPIRVType *SpvType,
508                                                      const SPIRVInstrInfo &TII,
509                                                      bool ZeroAsNull) {
510   const Type *LLVMTy = getTypeForSPIRVType(SpvType);
511   assert(LLVMTy->isVectorTy());
512   const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy);
513   Type *LLVMBaseTy = LLVMVecTy->getElementType();
514   assert(LLVMBaseTy->isIntegerTy());
515   auto *ConstVal = ConstantInt::get(LLVMBaseTy, Val);
516   auto *ConstVec =
517       ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstVal);
518   unsigned BW = getScalarOrVectorBitWidth(SpvType);
519   return getOrCreateCompositeOrNull(ConstVal, I, SpvType, TII, ConstVec, BW,
520                                     SpvType->getOperand(2).getImm(),
521                                     ZeroAsNull);
522 }
523 
524 Register SPIRVGlobalRegistry::getOrCreateConstVector(APFloat Val,
525                                                      MachineInstr &I,
526                                                      SPIRVType *SpvType,
527                                                      const SPIRVInstrInfo &TII,
528                                                      bool ZeroAsNull) {
529   const Type *LLVMTy = getTypeForSPIRVType(SpvType);
530   assert(LLVMTy->isVectorTy());
531   const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy);
532   Type *LLVMBaseTy = LLVMVecTy->getElementType();
533   assert(LLVMBaseTy->isFloatingPointTy());
534   auto *ConstVal = ConstantFP::get(LLVMBaseTy, Val);
535   auto *ConstVec =
536       ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstVal);
537   unsigned BW = getScalarOrVectorBitWidth(SpvType);
538   return getOrCreateCompositeOrNull(ConstVal, I, SpvType, TII, ConstVec, BW,
539                                     SpvType->getOperand(2).getImm(),
540                                     ZeroAsNull);
541 }
542 
543 Register SPIRVGlobalRegistry::getOrCreateConstIntArray(
544     uint64_t Val, size_t Num, MachineInstr &I, SPIRVType *SpvType,
545     const SPIRVInstrInfo &TII) {
546   const Type *LLVMTy = getTypeForSPIRVType(SpvType);
547   assert(LLVMTy->isArrayTy());
548   const ArrayType *LLVMArrTy = cast<ArrayType>(LLVMTy);
549   Type *LLVMBaseTy = LLVMArrTy->getElementType();
550   Constant *CI = ConstantInt::get(LLVMBaseTy, Val);
551   SPIRVType *SpvBaseTy = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
552   unsigned BW = getScalarOrVectorBitWidth(SpvBaseTy);
553   // The following is reasonably unique key that is better that [Val]. The naive
554   // alternative would be something along the lines of:
555   //   SmallVector<Constant *> NumCI(Num, CI);
556   //   Constant *UniqueKey =
557   //     ConstantArray::get(const_cast<ArrayType*>(LLVMArrTy), NumCI);
558   // that would be a truly unique but dangerous key, because it could lead to
559   // the creation of constants of arbitrary length (that is, the parameter of
560   // memset) which were missing in the original module.
561   Constant *UniqueKey = ConstantStruct::getAnon(
562       {PoisonValue::get(const_cast<ArrayType *>(LLVMArrTy)),
563        ConstantInt::get(LLVMBaseTy, Val), ConstantInt::get(LLVMBaseTy, Num)});
564   return getOrCreateCompositeOrNull(CI, I, SpvType, TII, UniqueKey, BW,
565                                     LLVMArrTy->getNumElements());
566 }
567 
568 Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
569     uint64_t Val, MachineIRBuilder &MIRBuilder, SPIRVType *SpvType, bool EmitIR,
570     Constant *CA, unsigned BitWidth, unsigned ElemCnt) {
571   Register Res = DT.find(CA, CurMF);
572   if (!Res.isValid()) {
573     Register SpvScalConst;
574     if (Val || EmitIR) {
575       SPIRVType *SpvBaseType =
576           getOrCreateSPIRVIntegerType(BitWidth, MIRBuilder);
577       SpvScalConst = buildConstantInt(Val, MIRBuilder, SpvBaseType, EmitIR);
578     }
579     LLT LLTy = EmitIR ? LLT::fixed_vector(ElemCnt, BitWidth) : LLT::scalar(64);
580     Register SpvVecConst =
581         CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
582     CurMF->getRegInfo().setRegClass(SpvVecConst, &SPIRV::iIDRegClass);
583     assignSPIRVTypeToVReg(SpvType, SpvVecConst, *CurMF);
584     DT.add(CA, CurMF, SpvVecConst);
585     if (EmitIR) {
586       MIRBuilder.buildSplatBuildVector(SpvVecConst, SpvScalConst);
587     } else {
588       createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
589         if (Val) {
590           auto MIB = MIRBuilder.buildInstr(SPIRV::OpConstantComposite)
591                          .addDef(SpvVecConst)
592                          .addUse(getSPIRVTypeID(SpvType));
593           for (unsigned i = 0; i < ElemCnt; ++i)
594             MIB.addUse(SpvScalConst);
595           return MIB;
596         } else {
597           return MIRBuilder.buildInstr(SPIRV::OpConstantNull)
598               .addDef(SpvVecConst)
599               .addUse(getSPIRVTypeID(SpvType));
600         }
601       });
602     }
603     return SpvVecConst;
604   }
605   return Res;
606 }
607 
608 Register
609 SPIRVGlobalRegistry::getOrCreateConsIntVector(uint64_t Val,
610                                               MachineIRBuilder &MIRBuilder,
611                                               SPIRVType *SpvType, bool EmitIR) {
612   const Type *LLVMTy = getTypeForSPIRVType(SpvType);
613   assert(LLVMTy->isVectorTy());
614   const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy);
615   Type *LLVMBaseTy = LLVMVecTy->getElementType();
616   const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val);
617   auto ConstVec =
618       ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstInt);
619   unsigned BW = getScalarOrVectorBitWidth(SpvType);
620   return getOrCreateIntCompositeOrNull(Val, MIRBuilder, SpvType, EmitIR,
621                                        ConstVec, BW,
622                                        SpvType->getOperand(2).getImm());
623 }
624 
625 Register
626 SPIRVGlobalRegistry::getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder,
627                                              SPIRVType *SpvType) {
628   const Type *LLVMTy = getTypeForSPIRVType(SpvType);
629   unsigned AddressSpace = typeToAddressSpace(LLVMTy);
630   // Find a constant in DT or build a new one.
631   Constant *CP = ConstantPointerNull::get(
632       PointerType::get(LLVMTy->getContext(), AddressSpace));
633   Register Res = DT.find(CP, CurMF);
634   if (!Res.isValid()) {
635     LLT LLTy = LLT::pointer(AddressSpace, PointerSize);
636     Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
637     CurMF->getRegInfo().setRegClass(Res, &SPIRV::pIDRegClass);
638     assignSPIRVTypeToVReg(SpvType, Res, *CurMF);
639     createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
640       return MIRBuilder.buildInstr(SPIRV::OpConstantNull)
641           .addDef(Res)
642           .addUse(getSPIRVTypeID(SpvType));
643     });
644     DT.add(CP, CurMF, Res);
645   }
646   return Res;
647 }
648 
649 Register SPIRVGlobalRegistry::buildConstantSampler(
650     Register ResReg, unsigned AddrMode, unsigned Param, unsigned FilerMode,
651     MachineIRBuilder &MIRBuilder, SPIRVType *SpvType) {
652   SPIRVType *SampTy;
653   if (SpvType)
654     SampTy = getOrCreateSPIRVType(getTypeForSPIRVType(SpvType), MIRBuilder);
655   else if ((SampTy = getOrCreateSPIRVTypeByName("opencl.sampler_t",
656                                                 MIRBuilder)) == nullptr)
657     report_fatal_error("Unable to recognize SPIRV type name: opencl.sampler_t");
658 
659   auto Sampler =
660       ResReg.isValid()
661           ? ResReg
662           : MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::iIDRegClass);
663   auto Res = createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
664     return MIRBuilder.buildInstr(SPIRV::OpConstantSampler)
665         .addDef(Sampler)
666         .addUse(getSPIRVTypeID(SampTy))
667         .addImm(AddrMode)
668         .addImm(Param)
669         .addImm(FilerMode);
670   });
671   assert(Res->getOperand(0).isReg());
672   return Res->getOperand(0).getReg();
673 }
674 
675 Register SPIRVGlobalRegistry::buildGlobalVariable(
676     Register ResVReg, SPIRVType *BaseType, StringRef Name,
677     const GlobalValue *GV, SPIRV::StorageClass::StorageClass Storage,
678     const MachineInstr *Init, bool IsConst, bool HasLinkageTy,
679     SPIRV::LinkageType::LinkageType LinkageType, MachineIRBuilder &MIRBuilder,
680     bool IsInstSelector) {
681   const GlobalVariable *GVar = nullptr;
682   if (GV)
683     GVar = cast<const GlobalVariable>(GV);
684   else {
685     // If GV is not passed explicitly, use the name to find or construct
686     // the global variable.
687     Module *M = MIRBuilder.getMF().getFunction().getParent();
688     GVar = M->getGlobalVariable(Name);
689     if (GVar == nullptr) {
690       const Type *Ty = getTypeForSPIRVType(BaseType); // TODO: check type.
691       // Module takes ownership of the global var.
692       GVar = new GlobalVariable(*M, const_cast<Type *>(Ty), false,
693                                 GlobalValue::ExternalLinkage, nullptr,
694                                 Twine(Name));
695     }
696     GV = GVar;
697   }
698   Register Reg = DT.find(GVar, &MIRBuilder.getMF());
699   if (Reg.isValid()) {
700     if (Reg != ResVReg)
701       MIRBuilder.buildCopy(ResVReg, Reg);
702     return ResVReg;
703   }
704 
705   auto MIB = MIRBuilder.buildInstr(SPIRV::OpVariable)
706                  .addDef(ResVReg)
707                  .addUse(getSPIRVTypeID(BaseType))
708                  .addImm(static_cast<uint32_t>(Storage));
709 
710   if (Init != 0) {
711     MIB.addUse(Init->getOperand(0).getReg());
712   }
713 
714   // ISel may introduce a new register on this step, so we need to add it to
715   // DT and correct its type avoiding fails on the next stage.
716   if (IsInstSelector) {
717     const auto &Subtarget = CurMF->getSubtarget();
718     constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
719                                      *Subtarget.getRegisterInfo(),
720                                      *Subtarget.getRegBankInfo());
721   }
722   Reg = MIB->getOperand(0).getReg();
723   DT.add(GVar, &MIRBuilder.getMF(), Reg);
724   addGlobalObject(GVar, &MIRBuilder.getMF(), Reg);
725 
726   // Set to Reg the same type as ResVReg has.
727   auto MRI = MIRBuilder.getMRI();
728   if (Reg != ResVReg) {
729     LLT RegLLTy =
730         LLT::pointer(MRI->getType(ResVReg).getAddressSpace(), getPointerSize());
731     MRI->setType(Reg, RegLLTy);
732     assignSPIRVTypeToVReg(BaseType, Reg, MIRBuilder.getMF());
733   } else {
734     // Our knowledge about the type may be updated.
735     // If that's the case, we need to update a type
736     // associated with the register.
737     SPIRVType *DefType = getSPIRVTypeForVReg(ResVReg);
738     if (!DefType || DefType != BaseType)
739       assignSPIRVTypeToVReg(BaseType, Reg, MIRBuilder.getMF());
740   }
741 
742   // If it's a global variable with name, output OpName for it.
743   if (GVar && GVar->hasName())
744     buildOpName(Reg, GVar->getName(), MIRBuilder);
745 
746   // Output decorations for the GV.
747   // TODO: maybe move to GenerateDecorations pass.
748   const SPIRVSubtarget &ST =
749       cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
750   if (IsConst && ST.isOpenCLEnv())
751     buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Constant, {});
752 
753   if (GVar && GVar->getAlign().valueOrOne().value() != 1) {
754     unsigned Alignment = (unsigned)GVar->getAlign().valueOrOne().value();
755     buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Alignment, {Alignment});
756   }
757 
758   if (HasLinkageTy)
759     buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::LinkageAttributes,
760                     {static_cast<uint32_t>(LinkageType)}, Name);
761 
762   SPIRV::BuiltIn::BuiltIn BuiltInId;
763   if (getSpirvBuiltInIdByName(Name, BuiltInId))
764     buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::BuiltIn,
765                     {static_cast<uint32_t>(BuiltInId)});
766 
767   // If it's a global variable with "spirv.Decorations" metadata node
768   // recognize it as a SPIR-V friendly LLVM IR and parse "spirv.Decorations"
769   // arguments.
770   MDNode *GVarMD = nullptr;
771   if (GVar && (GVarMD = GVar->getMetadata("spirv.Decorations")) != nullptr)
772     buildOpSpirvDecorations(Reg, MIRBuilder, GVarMD);
773 
774   return Reg;
775 }
776 
777 static std::string GetSpirvImageTypeName(const SPIRVType *Type,
778                                          MachineIRBuilder &MIRBuilder,
779                                          const std::string &Prefix);
780 
781 static std::string buildSpirvTypeName(const SPIRVType *Type,
782                                       MachineIRBuilder &MIRBuilder) {
783   switch (Type->getOpcode()) {
784   case SPIRV::OpTypeSampledImage: {
785     return GetSpirvImageTypeName(Type, MIRBuilder, "sampled_image_");
786   }
787   case SPIRV::OpTypeImage: {
788     return GetSpirvImageTypeName(Type, MIRBuilder, "image_");
789   }
790   case SPIRV::OpTypeArray: {
791     MachineRegisterInfo *MRI = MIRBuilder.getMRI();
792     Register ElementTypeReg = Type->getOperand(1).getReg();
793     auto *ElementType = MRI->getUniqueVRegDef(ElementTypeReg);
794     const SPIRVType *TypeInst = MRI->getVRegDef(Type->getOperand(2).getReg());
795     assert(TypeInst->getOpcode() != SPIRV::OpConstantI);
796     MachineInstr *ImmInst = MRI->getVRegDef(TypeInst->getOperand(1).getReg());
797     assert(ImmInst->getOpcode() == TargetOpcode::G_CONSTANT);
798     uint32_t ArraySize = ImmInst->getOperand(1).getCImm()->getZExtValue();
799     return (buildSpirvTypeName(ElementType, MIRBuilder) + Twine("[") +
800             Twine(ArraySize) + Twine("]"))
801         .str();
802   }
803   case SPIRV::OpTypeFloat:
804     return ("f" + Twine(Type->getOperand(1).getImm())).str();
805   case SPIRV::OpTypeSampler:
806     return ("sampler");
807   case SPIRV::OpTypeInt:
808     if (Type->getOperand(2).getImm())
809       return ("i" + Twine(Type->getOperand(1).getImm())).str();
810     return ("u" + Twine(Type->getOperand(1).getImm())).str();
811   default:
812     llvm_unreachable("Trying to the the name of an unknown type.");
813   }
814 }
815 
816 static std::string GetSpirvImageTypeName(const SPIRVType *Type,
817                                          MachineIRBuilder &MIRBuilder,
818                                          const std::string &Prefix) {
819   Register SampledTypeReg = Type->getOperand(1).getReg();
820   auto *SampledType = MIRBuilder.getMRI()->getUniqueVRegDef(SampledTypeReg);
821   std::string TypeName = Prefix + buildSpirvTypeName(SampledType, MIRBuilder);
822   for (uint32_t I = 2; I < Type->getNumOperands(); ++I) {
823     TypeName = (TypeName + '_' + Twine(Type->getOperand(I).getImm())).str();
824   }
825   return TypeName;
826 }
827 
828 Register SPIRVGlobalRegistry::getOrCreateGlobalVariableWithBinding(
829     const SPIRVType *VarType, uint32_t Set, uint32_t Binding,
830     MachineIRBuilder &MIRBuilder) {
831   SPIRVType *VarPointerTypeReg = getOrCreateSPIRVPointerType(
832       VarType, MIRBuilder, SPIRV::StorageClass::UniformConstant);
833   Register VarReg =
834       MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::iIDRegClass);
835 
836   // TODO: The name should come from the llvm-ir, but how that name will be
837   // passed from the HLSL to the backend has not been decided. Using this place
838   // holder for now.
839   std::string Name = ("__resource_" + buildSpirvTypeName(VarType, MIRBuilder) +
840                       "_" + Twine(Set) + "_" + Twine(Binding))
841                          .str();
842   buildGlobalVariable(VarReg, VarPointerTypeReg, Name, nullptr,
843                       SPIRV::StorageClass::UniformConstant, nullptr, false,
844                       false, SPIRV::LinkageType::Import, MIRBuilder, false);
845 
846   buildOpDecorate(VarReg, MIRBuilder, SPIRV::Decoration::DescriptorSet, {Set});
847   buildOpDecorate(VarReg, MIRBuilder, SPIRV::Decoration::Binding, {Binding});
848   return VarReg;
849 }
850 
851 SPIRVType *SPIRVGlobalRegistry::getOpTypeArray(uint32_t NumElems,
852                                                SPIRVType *ElemType,
853                                                MachineIRBuilder &MIRBuilder,
854                                                bool EmitIR) {
855   assert((ElemType->getOpcode() != SPIRV::OpTypeVoid) &&
856          "Invalid array element type");
857   SPIRVType *SpvTypeInt32 = getOrCreateSPIRVIntegerType(32, MIRBuilder);
858   Register NumElementsVReg =
859       buildConstantInt(NumElems, MIRBuilder, SpvTypeInt32, EmitIR);
860   return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
861     return MIRBuilder.buildInstr(SPIRV::OpTypeArray)
862         .addDef(createTypeVReg(MIRBuilder))
863         .addUse(getSPIRVTypeID(ElemType))
864         .addUse(NumElementsVReg);
865   });
866 }
867 
868 SPIRVType *SPIRVGlobalRegistry::getOpTypeOpaque(const StructType *Ty,
869                                                 MachineIRBuilder &MIRBuilder) {
870   assert(Ty->hasName());
871   const StringRef Name = Ty->hasName() ? Ty->getName() : "";
872   Register ResVReg = createTypeVReg(MIRBuilder);
873   return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
874     auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeOpaque).addDef(ResVReg);
875     addStringImm(Name, MIB);
876     buildOpName(ResVReg, Name, MIRBuilder);
877     return MIB;
878   });
879 }
880 
881 SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(const StructType *Ty,
882                                                 MachineIRBuilder &MIRBuilder,
883                                                 bool EmitIR) {
884   SmallVector<Register, 4> FieldTypes;
885   for (const auto &Elem : Ty->elements()) {
886     SPIRVType *ElemTy = findSPIRVType(toTypedPointer(Elem), MIRBuilder);
887     assert(ElemTy && ElemTy->getOpcode() != SPIRV::OpTypeVoid &&
888            "Invalid struct element type");
889     FieldTypes.push_back(getSPIRVTypeID(ElemTy));
890   }
891   Register ResVReg = createTypeVReg(MIRBuilder);
892   return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
893     auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeStruct).addDef(ResVReg);
894     for (const auto &Ty : FieldTypes)
895       MIB.addUse(Ty);
896     if (Ty->hasName())
897       buildOpName(ResVReg, Ty->getName(), MIRBuilder);
898     if (Ty->isPacked())
899       buildOpDecorate(ResVReg, MIRBuilder, SPIRV::Decoration::CPacked, {});
900     return MIB;
901   });
902 }
903 
904 SPIRVType *SPIRVGlobalRegistry::getOrCreateSpecialType(
905     const Type *Ty, MachineIRBuilder &MIRBuilder,
906     SPIRV::AccessQualifier::AccessQualifier AccQual) {
907   assert(isSpecialOpaqueType(Ty) && "Not a special opaque builtin type");
908   return SPIRV::lowerBuiltinType(Ty, AccQual, MIRBuilder, this);
909 }
910 
911 SPIRVType *SPIRVGlobalRegistry::getOpTypePointer(
912     SPIRV::StorageClass::StorageClass SC, SPIRVType *ElemType,
913     MachineIRBuilder &MIRBuilder, Register Reg) {
914   if (!Reg.isValid())
915     Reg = createTypeVReg(MIRBuilder);
916 
917   return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
918     return MIRBuilder.buildInstr(SPIRV::OpTypePointer)
919         .addDef(Reg)
920         .addImm(static_cast<uint32_t>(SC))
921         .addUse(getSPIRVTypeID(ElemType));
922   });
923 }
924 
925 SPIRVType *SPIRVGlobalRegistry::getOpTypeForwardPointer(
926     SPIRV::StorageClass::StorageClass SC, MachineIRBuilder &MIRBuilder) {
927   return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
928     return MIRBuilder.buildInstr(SPIRV::OpTypeForwardPointer)
929         .addUse(createTypeVReg(MIRBuilder))
930         .addImm(static_cast<uint32_t>(SC));
931   });
932 }
933 
934 SPIRVType *SPIRVGlobalRegistry::getOpTypeFunction(
935     SPIRVType *RetType, const SmallVectorImpl<SPIRVType *> &ArgTypes,
936     MachineIRBuilder &MIRBuilder) {
937   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFunction)
938                  .addDef(createTypeVReg(MIRBuilder))
939                  .addUse(getSPIRVTypeID(RetType));
940   for (const SPIRVType *ArgType : ArgTypes)
941     MIB.addUse(getSPIRVTypeID(ArgType));
942   return MIB;
943 }
944 
945 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeFunctionWithArgs(
946     const Type *Ty, SPIRVType *RetType,
947     const SmallVectorImpl<SPIRVType *> &ArgTypes,
948     MachineIRBuilder &MIRBuilder) {
949   Register Reg = DT.find(Ty, &MIRBuilder.getMF());
950   if (Reg.isValid())
951     return getSPIRVTypeForVReg(Reg);
952   SPIRVType *SpirvType = getOpTypeFunction(RetType, ArgTypes, MIRBuilder);
953   DT.add(Ty, CurMF, getSPIRVTypeID(SpirvType));
954   return finishCreatingSPIRVType(Ty, SpirvType);
955 }
956 
957 SPIRVType *SPIRVGlobalRegistry::findSPIRVType(
958     const Type *Ty, MachineIRBuilder &MIRBuilder,
959     SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
960   Ty = adjustIntTypeByWidth(Ty);
961   Register Reg = DT.find(Ty, &MIRBuilder.getMF());
962   if (Reg.isValid())
963     return getSPIRVTypeForVReg(Reg);
964   if (ForwardPointerTypes.contains(Ty))
965     return ForwardPointerTypes[Ty];
966   return restOfCreateSPIRVType(Ty, MIRBuilder, AccQual, EmitIR);
967 }
968 
969 Register SPIRVGlobalRegistry::getSPIRVTypeID(const SPIRVType *SpirvType) const {
970   assert(SpirvType && "Attempting to get type id for nullptr type.");
971   if (SpirvType->getOpcode() == SPIRV::OpTypeForwardPointer)
972     return SpirvType->uses().begin()->getReg();
973   return SpirvType->defs().begin()->getReg();
974 }
975 
976 // We need to use a new LLVM integer type if there is a mismatch between
977 // number of bits in LLVM and SPIRV integer types to let DuplicateTracker
978 // ensure uniqueness of a SPIRV type by the corresponding LLVM type. Without
979 // such an adjustment SPIRVGlobalRegistry::getOpTypeInt() could create the
980 // same "OpTypeInt 8" type for a series of LLVM integer types with number of
981 // bits less than 8. This would lead to duplicate type definitions
982 // eventually due to the method that DuplicateTracker utilizes to reason
983 // about uniqueness of type records.
984 const Type *SPIRVGlobalRegistry::adjustIntTypeByWidth(const Type *Ty) const {
985   if (auto IType = dyn_cast<IntegerType>(Ty)) {
986     unsigned SrcBitWidth = IType->getBitWidth();
987     if (SrcBitWidth > 1) {
988       unsigned BitWidth = adjustOpTypeIntWidth(SrcBitWidth);
989       // Maybe change source LLVM type to keep DuplicateTracker consistent.
990       if (SrcBitWidth != BitWidth)
991         Ty = IntegerType::get(Ty->getContext(), BitWidth);
992     }
993   }
994   return Ty;
995 }
996 
// Build the SPIR-V type instruction for the LLVM type Ty.
//
// restOfCreateSPIRVType invokes this worker and records the result in the
// registry; it also maintains the TypesInProcessing recursion guard. Ty is
// dispatched on its kind:
//   - special opaque types;
//   - integer, floating-point and void types;
//   - fixed vectors and arrays;
//   - structs;
//   - function types.
// Anything else is treated as pointer-like and must be accepted by
// typeToAddressSpace, which reports a fatal error otherwise.
SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
    const Type *Ty, MachineIRBuilder &MIRBuilder,
    SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
  // Special opaque builtin types have a dedicated lowering.
  if (isSpecialOpaqueType(Ty))
    return getOrCreateSpecialType(Ty, MIRBuilder, AccQual);
  // If a SPIR-V type for Ty is already recorded for this machine function,
  // return it rather than building a duplicate.
  auto &TypeToSPIRVTypeMap = DT.getTypes()->getAllUses();
  auto t = TypeToSPIRVTypeMap.find(Ty);
  if (t != TypeToSPIRVTypeMap.end()) {
    auto tt = t->second.find(&MIRBuilder.getMF());
    if (tt != t->second.end())
      return getSPIRVTypeForVReg(tt->second);
  }

  // i1 becomes OpTypeBool and wider integers become OpTypeInt. The trailing
  // `false` is getOpTypeInt's third argument (presumably signedness; confirm
  // against its declaration).
  if (auto IType = dyn_cast<IntegerType>(Ty)) {
    const unsigned Width = IType->getBitWidth();
    return Width == 1 ? getOpTypeBool(MIRBuilder)
                      : getOpTypeInt(Width, MIRBuilder, false);
  }
  if (Ty->isFloatingPointTy())
    return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
  if (Ty->isVoidTy())
    return getOpTypeVoid(MIRBuilder);
  // Only fixed-length vectors are handled: cast<FixedVectorType> asserts on
  // any other vector type.
  if (Ty->isVectorTy()) {
    SPIRVType *El =
        findSPIRVType(cast<FixedVectorType>(Ty)->getElementType(), MIRBuilder);
    return getOpTypeVector(cast<FixedVectorType>(Ty)->getNumElements(), El,
                           MIRBuilder);
  }
  if (Ty->isArrayTy()) {
    SPIRVType *El = findSPIRVType(Ty->getArrayElementType(), MIRBuilder);
    return getOpTypeArray(Ty->getArrayNumElements(), El, MIRBuilder, EmitIR);
  }
  // Opaque structs become OpTypeOpaque; other structs become OpTypeStruct.
  if (auto SType = dyn_cast<StructType>(Ty)) {
    if (SType->isOpaque())
      return getOpTypeOpaque(SType, MIRBuilder);
    return getOpTypeStruct(SType, MIRBuilder, EmitIR);
  }
  if (auto FType = dyn_cast<FunctionType>(Ty)) {
    SPIRVType *RetTy = findSPIRVType(FType->getReturnType(), MIRBuilder);
    SmallVector<SPIRVType *, 4> ParamTypes;
    for (const auto &t : FType->params()) {
      ParamTypes.push_back(findSPIRVType(t, MIRBuilder));
    }
    return getOpTypeFunction(RetTy, ParamTypes, MIRBuilder);
  }

  // Pointer-like types remain: PointerType, TypedPointerType, or a
  // typed-pointer wrapper target extension type. typeToAddressSpace reports a
  // fatal error for anything else.
  unsigned AddrSpace = typeToAddressSpace(Ty);
  // The pointee is the known element type if there is one, otherwise i8.
  // NOTE(review): getOrCreateSPIRVType clears TypesInProcessing before it
  // creates a type. Confirm this cannot defeat the cycle detection in
  // restOfCreateSPIRVType for self-referential element types.
  SPIRVType *SpvElementType = nullptr;
  if (Type *ElemTy = ::getPointeeType(Ty))
    SpvElementType = getOrCreateSPIRVType(ElemTy, MIRBuilder, AccQual, EmitIR);
  else
    SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder);

  // Get access to information about available extensions
  const SPIRVSubtarget *ST =
      static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
  auto SC = addressSpaceToStorageClass(AddrSpace, *ST);
  // Null pointer means we have a loop in type definitions, make and
  // return corresponding OpTypeForwardPointer.
  if (SpvElementType == nullptr) {
    if (!ForwardPointerTypes.contains(Ty))
      ForwardPointerTypes[Ty] = getOpTypeForwardPointer(SC, MIRBuilder);
    return ForwardPointerTypes[Ty];
  }
  // If we have forward pointer associated with this type, use its register
  // operand to create OpTypePointer.
  if (ForwardPointerTypes.contains(Ty)) {
    Register Reg = getSPIRVTypeID(ForwardPointerTypes[Ty]);
    return getOpTypePointer(SC, SpvElementType, MIRBuilder, Reg);
  }

  return getOrCreateSPIRVPointerType(SpvElementType, MIRBuilder, SC);
}
1070 
// Create the SPIR-V type for Ty and register it.
//
// Returns nullptr when a non-pointer Ty is already being created further up
// the call stack, i.e. the type graph contains a cycle. The pointer path of
// createSPIRVType turns such a null element type into an OpTypeForwardPointer.
// Otherwise the new type is recorded in VRegToTypeMap and SPIRVToLLVMType.
// It is also added to the duplicate tracker under the key shape that
// getOrCreateSPIRVType uses for lookups.
SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
    const Type *Ty, MachineIRBuilder &MIRBuilder,
    SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
  // Re-entering a non-pointer type signals a cycle. Pointer-like types may
  // recurse, because their pointee can be forward-declared.
  if (TypesInProcessing.count(Ty) && !isPointerTyOrWrapper(Ty))
    return nullptr;
  TypesInProcessing.insert(Ty);
  SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR);
  TypesInProcessing.erase(Ty);
  // Map the type id back to its defining instruction and to its LLVM type.
  VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType;
  SPIRVToLLVMType[SpirvType] = unifyPtrType(Ty);
  Register Reg = DT.find(Ty, &MIRBuilder.getMF());
  // Do not add OpTypeForwardPointer to DT, a corresponding normal pointer type
  // will be added later. For special types it is already added to DT.
  if (SpirvType->getOpcode() != SPIRV::OpTypeForwardPointer && !Reg.isValid() &&
      !isSpecialOpaqueType(Ty)) {
    // Keys: typed-pointer wrapper -> (element, address space from its int
    // parameter); non-pointer -> Ty; typed pointer -> (element, address
    // space); untyped pointer -> (i8, address space).
    if (auto *ExtTy = dyn_cast<TargetExtType>(Ty);
        ExtTy && isTypedPointerWrapper(ExtTy))
      DT.add(ExtTy->getTypeParameter(0), ExtTy->getIntParameter(0),
             &MIRBuilder.getMF(), getSPIRVTypeID(SpirvType));
    else if (!isPointerTy(Ty))
      DT.add(Ty, &MIRBuilder.getMF(), getSPIRVTypeID(SpirvType));
    else if (isTypedPointerTy(Ty))
      DT.add(cast<TypedPointerType>(Ty)->getElementType(),
             getPointerAddressSpace(Ty), &MIRBuilder.getMF(),
             getSPIRVTypeID(SpirvType));
    else
      DT.add(Type::getInt8Ty(MIRBuilder.getMF().getFunction().getContext()),
             getPointerAddressSpace(Ty), &MIRBuilder.getMF(),
             getSPIRVTypeID(SpirvType));
  }

  return SpirvType;
}
1104 
1105 SPIRVType *
1106 SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg,
1107                                          const MachineFunction *MF) const {
1108   auto t = VRegToTypeMap.find(MF ? MF : CurMF);
1109   if (t != VRegToTypeMap.end()) {
1110     auto tt = t->second.find(VReg);
1111     if (tt != t->second.end())
1112       return tt->second;
1113   }
1114   return nullptr;
1115 }
1116 
1117 SPIRVType *SPIRVGlobalRegistry::getResultType(Register VReg,
1118                                               MachineFunction *MF) {
1119   if (!MF)
1120     MF = CurMF;
1121   MachineInstr *Instr = getVRegDef(MF->getRegInfo(), VReg);
1122   return getSPIRVTypeForVReg(Instr->getOperand(1).getReg(), MF);
1123 }
1124 
// Return the SPIR-V type for the LLVM type Ty in the builder's machine
// function. If needed, create it together with the types it depends on.
//
// The duplicate-tracker lookup key depends on the kind of Ty:
//  - typed-pointer wrapper (TargetExtType): (wrapped type, address space
//    taken from its integer parameter);
//  - non-pointer: Ty itself, after adjustIntTypeByWidth;
//  - TypedPointerType: (element type, address space);
//  - untyped pointer: (i8, address space).
// Special opaque types are routed to creation even when found. Presumably
// their lowering performs its own deduplication; confirm.
// After Ty is created, every OpTypeForwardPointer recorded on the way is
// resolved to a real pointer type, and ForwardPointerTypes is cleared.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(
    const Type *Ty, MachineIRBuilder &MIRBuilder,
    SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
  Register Reg;
  if (auto *ExtTy = dyn_cast<TargetExtType>(Ty);
      ExtTy && isTypedPointerWrapper(ExtTy)) {
    Reg = DT.find(ExtTy->getTypeParameter(0), ExtTy->getIntParameter(0),
                  &MIRBuilder.getMF());
  } else if (!isPointerTy(Ty)) {
    Ty = adjustIntTypeByWidth(Ty);
    Reg = DT.find(Ty, &MIRBuilder.getMF());
  } else if (isTypedPointerTy(Ty)) {
    Reg = DT.find(cast<TypedPointerType>(Ty)->getElementType(),
                  getPointerAddressSpace(Ty), &MIRBuilder.getMF());
  } else {
    Reg =
        DT.find(Type::getInt8Ty(MIRBuilder.getMF().getFunction().getContext()),
                getPointerAddressSpace(Ty), &MIRBuilder.getMF());
  }

  if (Reg.isValid() && !isSpecialOpaqueType(Ty))
    return getSPIRVTypeForVReg(Reg);
  // Start a fresh recursion-guard set for this top-level creation.
  TypesInProcessing.clear();
  SPIRVType *STy = restOfCreateSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR);
  // Create normal pointer types for the corresponding OpTypeForwardPointers.
  // NOTE(review): restOfCreateSPIRVType may itself insert into
  // ForwardPointerTypes while this loop iterates over it. If that container
  // is a DenseMap, this can invalidate the loop iterators; confirm.
  for (auto &CU : ForwardPointerTypes) {
    const Type *Ty2 = CU.first;
    // This initial value is always overwritten by one of the branches below.
    SPIRVType *STy2 = CU.second;
    if ((Reg = DT.find(Ty2, &MIRBuilder.getMF())).isValid())
      STy2 = getSPIRVTypeForVReg(Reg);
    else
      STy2 = restOfCreateSPIRVType(Ty2, MIRBuilder, AccessQual, EmitIR);
    // If Ty itself was forward-declared, return the resolved pointer type
    // rather than the forward pointer.
    if (Ty == Ty2)
      STy = STy2;
  }
  ForwardPointerTypes.clear();
  return STy;
}
1163 
1164 bool SPIRVGlobalRegistry::isScalarOfType(Register VReg,
1165                                          unsigned TypeOpcode) const {
1166   SPIRVType *Type = getSPIRVTypeForVReg(VReg);
1167   assert(Type && "isScalarOfType VReg has no type assigned");
1168   return Type->getOpcode() == TypeOpcode;
1169 }
1170 
1171 bool SPIRVGlobalRegistry::isScalarOrVectorOfType(Register VReg,
1172                                                  unsigned TypeOpcode) const {
1173   SPIRVType *Type = getSPIRVTypeForVReg(VReg);
1174   assert(Type && "isScalarOrVectorOfType VReg has no type assigned");
1175   if (Type->getOpcode() == TypeOpcode)
1176     return true;
1177   if (Type->getOpcode() == SPIRV::OpTypeVector) {
1178     Register ScalarTypeVReg = Type->getOperand(1).getReg();
1179     SPIRVType *ScalarType = getSPIRVTypeForVReg(ScalarTypeVReg);
1180     return ScalarType->getOpcode() == TypeOpcode;
1181   }
1182   return false;
1183 }
1184 
1185 unsigned
1186 SPIRVGlobalRegistry::getScalarOrVectorComponentCount(Register VReg) const {
1187   return getScalarOrVectorComponentCount(getSPIRVTypeForVReg(VReg));
1188 }
1189 
1190 unsigned
1191 SPIRVGlobalRegistry::getScalarOrVectorComponentCount(SPIRVType *Type) const {
1192   if (!Type)
1193     return 0;
1194   return Type->getOpcode() == SPIRV::OpTypeVector
1195              ? static_cast<unsigned>(Type->getOperand(2).getImm())
1196              : 1;
1197 }
1198 
1199 SPIRVType *
1200 SPIRVGlobalRegistry::getScalarOrVectorComponentType(Register VReg) const {
1201   return getScalarOrVectorComponentType(getSPIRVTypeForVReg(VReg));
1202 }
1203 
1204 SPIRVType *
1205 SPIRVGlobalRegistry::getScalarOrVectorComponentType(SPIRVType *Type) const {
1206   if (!Type)
1207     return nullptr;
1208   Register ScalarReg = Type->getOpcode() == SPIRV::OpTypeVector
1209                            ? Type->getOperand(1).getReg()
1210                            : Type->getOperand(0).getReg();
1211   SPIRVType *ScalarType = getSPIRVTypeForVReg(ScalarReg);
1212   assert(isScalarOrVectorOfType(Type->getOperand(0).getReg(),
1213                                 ScalarType->getOpcode()));
1214   return ScalarType;
1215 }
1216 
1217 unsigned
1218 SPIRVGlobalRegistry::getScalarOrVectorBitWidth(const SPIRVType *Type) const {
1219   assert(Type && "Invalid Type pointer");
1220   if (Type->getOpcode() == SPIRV::OpTypeVector) {
1221     auto EleTypeReg = Type->getOperand(1).getReg();
1222     Type = getSPIRVTypeForVReg(EleTypeReg);
1223   }
1224   if (Type->getOpcode() == SPIRV::OpTypeInt ||
1225       Type->getOpcode() == SPIRV::OpTypeFloat)
1226     return Type->getOperand(1).getImm();
1227   if (Type->getOpcode() == SPIRV::OpTypeBool)
1228     return 1;
1229   llvm_unreachable("Attempting to get bit width of non-integer/float type.");
1230 }
1231 
1232 unsigned SPIRVGlobalRegistry::getNumScalarOrVectorTotalBitWidth(
1233     const SPIRVType *Type) const {
1234   assert(Type && "Invalid Type pointer");
1235   unsigned NumElements = 1;
1236   if (Type->getOpcode() == SPIRV::OpTypeVector) {
1237     NumElements = static_cast<unsigned>(Type->getOperand(2).getImm());
1238     Type = getSPIRVTypeForVReg(Type->getOperand(1).getReg());
1239   }
1240   return Type->getOpcode() == SPIRV::OpTypeInt ||
1241                  Type->getOpcode() == SPIRV::OpTypeFloat
1242              ? NumElements * Type->getOperand(1).getImm()
1243              : 0;
1244 }
1245 
1246 const SPIRVType *SPIRVGlobalRegistry::retrieveScalarOrVectorIntType(
1247     const SPIRVType *Type) const {
1248   if (Type && Type->getOpcode() == SPIRV::OpTypeVector)
1249     Type = getSPIRVTypeForVReg(Type->getOperand(1).getReg());
1250   return Type && Type->getOpcode() == SPIRV::OpTypeInt ? Type : nullptr;
1251 }
1252 
1253 bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const {
1254   const SPIRVType *IntType = retrieveScalarOrVectorIntType(Type);
1255   return IntType && IntType->getOperand(2).getImm() != 0;
1256 }
1257 
1258 SPIRVType *SPIRVGlobalRegistry::getPointeeType(SPIRVType *PtrType) {
1259   return PtrType && PtrType->getOpcode() == SPIRV::OpTypePointer
1260              ? getSPIRVTypeForVReg(PtrType->getOperand(2).getReg())
1261              : nullptr;
1262 }
1263 
1264 unsigned SPIRVGlobalRegistry::getPointeeTypeOp(Register PtrReg) {
1265   SPIRVType *ElemType = getPointeeType(getSPIRVTypeForVReg(PtrReg));
1266   return ElemType ? ElemType->getOpcode() : 0;
1267 }
1268 
1269 bool SPIRVGlobalRegistry::isBitcastCompatible(const SPIRVType *Type1,
1270                                               const SPIRVType *Type2) const {
1271   if (!Type1 || !Type2)
1272     return false;
1273   auto Op1 = Type1->getOpcode(), Op2 = Type2->getOpcode();
1274   // Ignore difference between <1.5 and >=1.5 protocol versions:
1275   // it's valid if either Result Type or Operand is a pointer, and the other
1276   // is a pointer, an integer scalar, or an integer vector.
1277   if (Op1 == SPIRV::OpTypePointer &&
1278       (Op2 == SPIRV::OpTypePointer || retrieveScalarOrVectorIntType(Type2)))
1279     return true;
1280   if (Op2 == SPIRV::OpTypePointer &&
1281       (Op1 == SPIRV::OpTypePointer || retrieveScalarOrVectorIntType(Type1)))
1282     return true;
1283   unsigned Bits1 = getNumScalarOrVectorTotalBitWidth(Type1),
1284            Bits2 = getNumScalarOrVectorTotalBitWidth(Type2);
1285   return Bits1 > 0 && Bits1 == Bits2;
1286 }
1287 
1288 SPIRV::StorageClass::StorageClass
1289 SPIRVGlobalRegistry::getPointerStorageClass(Register VReg) const {
1290   SPIRVType *Type = getSPIRVTypeForVReg(VReg);
1291   assert(Type && Type->getOpcode() == SPIRV::OpTypePointer &&
1292          Type->getOperand(1).isImm() && "Pointer type is expected");
1293   return getPointerStorageClass(Type);
1294 }
1295 
1296 SPIRV::StorageClass::StorageClass
1297 SPIRVGlobalRegistry::getPointerStorageClass(const SPIRVType *Type) const {
1298   return static_cast<SPIRV::StorageClass::StorageClass>(
1299       Type->getOperand(1).getImm());
1300 }
1301 
1302 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeImage(
1303     MachineIRBuilder &MIRBuilder, SPIRVType *SampledType, SPIRV::Dim::Dim Dim,
1304     uint32_t Depth, uint32_t Arrayed, uint32_t Multisampled, uint32_t Sampled,
1305     SPIRV::ImageFormat::ImageFormat ImageFormat,
1306     SPIRV::AccessQualifier::AccessQualifier AccessQual) {
1307   auto TD = SPIRV::make_descr_image(SPIRVToLLVMType.lookup(SampledType), Dim,
1308                                     Depth, Arrayed, Multisampled, Sampled,
1309                                     ImageFormat, AccessQual);
1310   if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
1311     return Res;
1312   Register ResVReg = createTypeVReg(MIRBuilder);
1313   DT.add(TD, &MIRBuilder.getMF(), ResVReg);
1314   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeImage)
1315                  .addDef(ResVReg)
1316                  .addUse(getSPIRVTypeID(SampledType))
1317                  .addImm(Dim)
1318                  .addImm(Depth)   // Depth (whether or not it is a Depth image).
1319                  .addImm(Arrayed) // Arrayed.
1320                  .addImm(Multisampled) // Multisampled (0 = only single-sample).
1321                  .addImm(Sampled)      // Sampled (0 = usage known at runtime).
1322                  .addImm(ImageFormat);
1323 
1324   if (AccessQual != SPIRV::AccessQualifier::None)
1325     MIB.addImm(AccessQual);
1326   return MIB;
1327 }
1328 
1329 SPIRVType *
1330 SPIRVGlobalRegistry::getOrCreateOpTypeSampler(MachineIRBuilder &MIRBuilder) {
1331   auto TD = SPIRV::make_descr_sampler();
1332   if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
1333     return Res;
1334   Register ResVReg = createTypeVReg(MIRBuilder);
1335   DT.add(TD, &MIRBuilder.getMF(), ResVReg);
1336   return MIRBuilder.buildInstr(SPIRV::OpTypeSampler).addDef(ResVReg);
1337 }
1338 
1339 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypePipe(
1340     MachineIRBuilder &MIRBuilder,
1341     SPIRV::AccessQualifier::AccessQualifier AccessQual) {
1342   auto TD = SPIRV::make_descr_pipe(AccessQual);
1343   if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
1344     return Res;
1345   Register ResVReg = createTypeVReg(MIRBuilder);
1346   DT.add(TD, &MIRBuilder.getMF(), ResVReg);
1347   return MIRBuilder.buildInstr(SPIRV::OpTypePipe)
1348       .addDef(ResVReg)
1349       .addImm(AccessQual);
1350 }
1351 
1352 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeDeviceEvent(
1353     MachineIRBuilder &MIRBuilder) {
1354   auto TD = SPIRV::make_descr_event();
1355   if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
1356     return Res;
1357   Register ResVReg = createTypeVReg(MIRBuilder);
1358   DT.add(TD, &MIRBuilder.getMF(), ResVReg);
1359   return MIRBuilder.buildInstr(SPIRV::OpTypeDeviceEvent).addDef(ResVReg);
1360 }
1361 
1362 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeSampledImage(
1363     SPIRVType *ImageType, MachineIRBuilder &MIRBuilder) {
1364   auto TD = SPIRV::make_descr_sampled_image(
1365       SPIRVToLLVMType.lookup(MIRBuilder.getMF().getRegInfo().getVRegDef(
1366           ImageType->getOperand(1).getReg())),
1367       ImageType);
1368   if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
1369     return Res;
1370   Register ResVReg = createTypeVReg(MIRBuilder);
1371   DT.add(TD, &MIRBuilder.getMF(), ResVReg);
1372   return MIRBuilder.buildInstr(SPIRV::OpTypeSampledImage)
1373       .addDef(ResVReg)
1374       .addUse(getSPIRVTypeID(ImageType));
1375 }
1376 
1377 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeCoopMatr(
1378     MachineIRBuilder &MIRBuilder, const TargetExtType *ExtensionType,
1379     const SPIRVType *ElemType, uint32_t Scope, uint32_t Rows, uint32_t Columns,
1380     uint32_t Use) {
1381   Register ResVReg = DT.find(ExtensionType, &MIRBuilder.getMF());
1382   if (ResVReg.isValid())
1383     return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(ResVReg);
1384   ResVReg = createTypeVReg(MIRBuilder);
1385   SPIRVType *SpvTypeInt32 = getOrCreateSPIRVIntegerType(32, MIRBuilder);
1386   SPIRVType *SpirvTy =
1387       MIRBuilder.buildInstr(SPIRV::OpTypeCooperativeMatrixKHR)
1388           .addDef(ResVReg)
1389           .addUse(getSPIRVTypeID(ElemType))
1390           .addUse(buildConstantInt(Scope, MIRBuilder, SpvTypeInt32, true))
1391           .addUse(buildConstantInt(Rows, MIRBuilder, SpvTypeInt32, true))
1392           .addUse(buildConstantInt(Columns, MIRBuilder, SpvTypeInt32, true))
1393           .addUse(buildConstantInt(Use, MIRBuilder, SpvTypeInt32, true));
1394   DT.add(ExtensionType, &MIRBuilder.getMF(), ResVReg);
1395   return SpirvTy;
1396 }
1397 
1398 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeByOpcode(
1399     const Type *Ty, MachineIRBuilder &MIRBuilder, unsigned Opcode) {
1400   Register ResVReg = DT.find(Ty, &MIRBuilder.getMF());
1401   if (ResVReg.isValid())
1402     return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(ResVReg);
1403   ResVReg = createTypeVReg(MIRBuilder);
1404   SPIRVType *SpirvTy = MIRBuilder.buildInstr(Opcode).addDef(ResVReg);
1405   DT.add(Ty, &MIRBuilder.getMF(), ResVReg);
1406   return SpirvTy;
1407 }
1408 
1409 const MachineInstr *
1410 SPIRVGlobalRegistry::checkSpecialInstr(const SPIRV::SpecialTypeDescriptor &TD,
1411                                        MachineIRBuilder &MIRBuilder) {
1412   Register Reg = DT.find(TD, &MIRBuilder.getMF());
1413   if (Reg.isValid())
1414     return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(Reg);
1415   return nullptr;
1416 }
1417 
1418 // Returns nullptr if unable to recognize SPIRV type name
1419 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName(
1420     StringRef TypeStr, MachineIRBuilder &MIRBuilder,
1421     SPIRV::StorageClass::StorageClass SC,
1422     SPIRV::AccessQualifier::AccessQualifier AQ) {
1423   unsigned VecElts = 0;
1424   auto &Ctx = MIRBuilder.getMF().getFunction().getContext();
1425 
1426   // Parse strings representing either a SPIR-V or OpenCL builtin type.
1427   if (hasBuiltinTypePrefix(TypeStr))
1428     return getOrCreateSPIRVType(SPIRV::parseBuiltinTypeNameToTargetExtType(
1429                                     TypeStr.str(), MIRBuilder.getContext()),
1430                                 MIRBuilder, AQ);
1431 
1432   // Parse type name in either "typeN" or "type vector[N]" format, where
1433   // N is the number of elements of the vector.
1434   Type *Ty;
1435 
1436   Ty = parseBasicTypeName(TypeStr, Ctx);
1437   if (!Ty)
1438     // Unable to recognize SPIRV type name
1439     return nullptr;
1440 
1441   auto SpirvTy = getOrCreateSPIRVType(Ty, MIRBuilder, AQ);
1442 
1443   // Handle "type*" or  "type* vector[N]".
1444   if (TypeStr.starts_with("*")) {
1445     SpirvTy = getOrCreateSPIRVPointerType(SpirvTy, MIRBuilder, SC);
1446     TypeStr = TypeStr.substr(strlen("*"));
1447   }
1448 
1449   // Handle "typeN*" or  "type vector[N]*".
1450   bool IsPtrToVec = TypeStr.consume_back("*");
1451 
1452   if (TypeStr.consume_front(" vector[")) {
1453     TypeStr = TypeStr.substr(0, TypeStr.find(']'));
1454   }
1455   TypeStr.getAsInteger(10, VecElts);
1456   if (VecElts > 0)
1457     SpirvTy = getOrCreateSPIRVVectorType(SpirvTy, VecElts, MIRBuilder);
1458 
1459   if (IsPtrToVec)
1460     SpirvTy = getOrCreateSPIRVPointerType(SpirvTy, MIRBuilder, SC);
1461 
1462   return SpirvTy;
1463 }
1464 
1465 SPIRVType *
1466 SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(unsigned BitWidth,
1467                                                  MachineIRBuilder &MIRBuilder) {
1468   return getOrCreateSPIRVType(
1469       IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), BitWidth),
1470       MIRBuilder);
1471 }
1472 
1473 SPIRVType *SPIRVGlobalRegistry::finishCreatingSPIRVType(const Type *LLVMTy,
1474                                                         SPIRVType *SpirvType) {
1475   assert(CurMF == SpirvType->getMF());
1476   VRegToTypeMap[CurMF][getSPIRVTypeID(SpirvType)] = SpirvType;
1477   SPIRVToLLVMType[SpirvType] = unifyPtrType(LLVMTy);
1478   return SpirvType;
1479 }
1480 
1481 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth,
1482                                                      MachineInstr &I,
1483                                                      const SPIRVInstrInfo &TII,
1484                                                      unsigned SPIRVOPcode,
1485                                                      Type *LLVMTy) {
1486   Register Reg = DT.find(LLVMTy, CurMF);
1487   if (Reg.isValid())
1488     return getSPIRVTypeForVReg(Reg);
1489   MachineBasicBlock &BB = *I.getParent();
1490   auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRVOPcode))
1491                  .addDef(createTypeVReg(CurMF->getRegInfo()))
1492                  .addImm(BitWidth)
1493                  .addImm(0);
1494   DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB));
1495   return finishCreatingSPIRVType(LLVMTy, MIB);
1496 }
1497 
1498 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(
1499     unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
1500   // Maybe adjust bit width to keep DuplicateTracker consistent. Without
1501   // such an adjustment SPIRVGlobalRegistry::getOpTypeInt() could create, for
1502   // example, the same "OpTypeInt 8" type for a series of LLVM integer types
1503   // with number of bits less than 8, causing duplicate type definitions.
1504   BitWidth = adjustOpTypeIntWidth(BitWidth);
1505   Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), BitWidth);
1506   return getOrCreateSPIRVType(BitWidth, I, TII, SPIRV::OpTypeInt, LLVMTy);
1507 }
1508 
1509 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVFloatType(
1510     unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
1511   LLVMContext &Ctx = CurMF->getFunction().getContext();
1512   Type *LLVMTy;
1513   switch (BitWidth) {
1514   case 16:
1515     LLVMTy = Type::getHalfTy(Ctx);
1516     break;
1517   case 32:
1518     LLVMTy = Type::getFloatTy(Ctx);
1519     break;
1520   case 64:
1521     LLVMTy = Type::getDoubleTy(Ctx);
1522     break;
1523   default:
1524     llvm_unreachable("Bit width is of unexpected size.");
1525   }
1526   return getOrCreateSPIRVType(BitWidth, I, TII, SPIRV::OpTypeFloat, LLVMTy);
1527 }
1528 
1529 SPIRVType *
1530 SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder) {
1531   return getOrCreateSPIRVType(
1532       IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), 1),
1533       MIRBuilder);
1534 }
1535 
1536 SPIRVType *
1537 SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineInstr &I,
1538                                               const SPIRVInstrInfo &TII) {
1539   Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), 1);
1540   Register Reg = DT.find(LLVMTy, CurMF);
1541   if (Reg.isValid())
1542     return getSPIRVTypeForVReg(Reg);
1543   MachineBasicBlock &BB = *I.getParent();
1544   auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeBool))
1545                  .addDef(createTypeVReg(CurMF->getRegInfo()));
1546   DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB));
1547   return finishCreatingSPIRVType(LLVMTy, MIB);
1548 }
1549 
1550 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
1551     SPIRVType *BaseType, unsigned NumElements, MachineIRBuilder &MIRBuilder) {
1552   return getOrCreateSPIRVType(
1553       FixedVectorType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)),
1554                            NumElements),
1555       MIRBuilder);
1556 }
1557 
1558 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
1559     SPIRVType *BaseType, unsigned NumElements, MachineInstr &I,
1560     const SPIRVInstrInfo &TII) {
1561   Type *LLVMTy = FixedVectorType::get(
1562       const_cast<Type *>(getTypeForSPIRVType(BaseType)), NumElements);
1563   Register Reg = DT.find(LLVMTy, CurMF);
1564   if (Reg.isValid())
1565     return getSPIRVTypeForVReg(Reg);
1566   MachineBasicBlock &BB = *I.getParent();
1567   auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeVector))
1568                  .addDef(createTypeVReg(CurMF->getRegInfo()))
1569                  .addUse(getSPIRVTypeID(BaseType))
1570                  .addImm(NumElements);
1571   DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB));
1572   return finishCreatingSPIRVType(LLVMTy, MIB);
1573 }
1574 
1575 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVArrayType(
1576     SPIRVType *BaseType, unsigned NumElements, MachineInstr &I,
1577     const SPIRVInstrInfo &TII) {
1578   Type *LLVMTy = ArrayType::get(
1579       const_cast<Type *>(getTypeForSPIRVType(BaseType)), NumElements);
1580   Register Reg = DT.find(LLVMTy, CurMF);
1581   if (Reg.isValid())
1582     return getSPIRVTypeForVReg(Reg);
1583   MachineBasicBlock &BB = *I.getParent();
1584   SPIRVType *SpvTypeInt32 = getOrCreateSPIRVIntegerType(32, I, TII);
1585   Register Len = getOrCreateConstInt(NumElements, I, SpvTypeInt32, TII);
1586   auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeArray))
1587                  .addDef(createTypeVReg(CurMF->getRegInfo()))
1588                  .addUse(getSPIRVTypeID(BaseType))
1589                  .addUse(Len);
1590   DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB));
1591   return finishCreatingSPIRVType(LLVMTy, MIB);
1592 }
1593 
1594 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
1595     SPIRVType *BaseType, MachineIRBuilder &MIRBuilder,
1596     SPIRV::StorageClass::StorageClass SC) {
1597   const Type *PointerElementType = getTypeForSPIRVType(BaseType);
1598   unsigned AddressSpace = storageClassToAddressSpace(SC);
1599   Type *LLVMTy = TypedPointerType::get(const_cast<Type *>(PointerElementType),
1600                                        AddressSpace);
1601   // check if this type is already available
1602   Register Reg = DT.find(PointerElementType, AddressSpace, CurMF);
1603   if (Reg.isValid())
1604     return getSPIRVTypeForVReg(Reg);
1605   // create a new type
1606   return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
1607     auto MIB = BuildMI(MIRBuilder.getMBB(), MIRBuilder.getInsertPt(),
1608                        MIRBuilder.getDebugLoc(),
1609                        MIRBuilder.getTII().get(SPIRV::OpTypePointer))
1610                    .addDef(createTypeVReg(CurMF->getRegInfo()))
1611                    .addImm(static_cast<uint32_t>(SC))
1612                    .addUse(getSPIRVTypeID(BaseType));
1613     DT.add(PointerElementType, AddressSpace, CurMF, getSPIRVTypeID(MIB));
1614     finishCreatingSPIRVType(LLVMTy, MIB);
1615     return MIB;
1616   });
1617 }
1618 
1619 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
1620     SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &,
1621     SPIRV::StorageClass::StorageClass SC) {
1622   MachineIRBuilder MIRBuilder(I);
1623   return getOrCreateSPIRVPointerType(BaseType, MIRBuilder, SC);
1624 }
1625 
1626 Register SPIRVGlobalRegistry::getOrCreateUndef(MachineInstr &I,
1627                                                SPIRVType *SpvType,
1628                                                const SPIRVInstrInfo &TII) {
1629   assert(SpvType);
1630   const Type *LLVMTy = getTypeForSPIRVType(SpvType);
1631   assert(LLVMTy);
1632   // Find a constant in DT or build a new one.
1633   UndefValue *UV = UndefValue::get(const_cast<Type *>(LLVMTy));
1634   Register Res = DT.find(UV, CurMF);
1635   if (Res.isValid())
1636     return Res;
1637   LLT LLTy = LLT::scalar(64);
1638   Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
1639   CurMF->getRegInfo().setRegClass(Res, &SPIRV::iIDRegClass);
1640   assignSPIRVTypeToVReg(SpvType, Res, *CurMF);
1641   DT.add(UV, CurMF, Res);
1642 
1643   MachineInstrBuilder MIB;
1644   MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpUndef))
1645             .addDef(Res)
1646             .addUse(getSPIRVTypeID(SpvType));
1647   const auto &ST = CurMF->getSubtarget();
1648   constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(),
1649                                    *ST.getRegisterInfo(), *ST.getRegBankInfo());
1650   return Res;
1651 }
1652 
1653 const TargetRegisterClass *
1654 SPIRVGlobalRegistry::getRegClass(SPIRVType *SpvType) const {
1655   unsigned Opcode = SpvType->getOpcode();
1656   switch (Opcode) {
1657   case SPIRV::OpTypeFloat:
1658     return &SPIRV::fIDRegClass;
1659   case SPIRV::OpTypePointer:
1660     return &SPIRV::pIDRegClass;
1661   case SPIRV::OpTypeVector: {
1662     SPIRVType *ElemType = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
1663     unsigned ElemOpcode = ElemType ? ElemType->getOpcode() : 0;
1664     if (ElemOpcode == SPIRV::OpTypeFloat)
1665       return &SPIRV::vfIDRegClass;
1666     if (ElemOpcode == SPIRV::OpTypePointer)
1667       return &SPIRV::vpIDRegClass;
1668     return &SPIRV::vIDRegClass;
1669   }
1670   }
1671   return &SPIRV::iIDRegClass;
1672 }
1673 
1674 inline unsigned getAS(SPIRVType *SpvType) {
1675   return storageClassToAddressSpace(
1676       static_cast<SPIRV::StorageClass::StorageClass>(
1677           SpvType->getOperand(1).getImm()));
1678 }
1679 
1680 LLT SPIRVGlobalRegistry::getRegType(SPIRVType *SpvType) const {
1681   unsigned Opcode = SpvType ? SpvType->getOpcode() : 0;
1682   switch (Opcode) {
1683   case SPIRV::OpTypeInt:
1684   case SPIRV::OpTypeFloat:
1685   case SPIRV::OpTypeBool:
1686     return LLT::scalar(getScalarOrVectorBitWidth(SpvType));
1687   case SPIRV::OpTypePointer:
1688     return LLT::pointer(getAS(SpvType), getPointerSize());
1689   case SPIRV::OpTypeVector: {
1690     SPIRVType *ElemType = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
1691     LLT ET;
1692     switch (ElemType ? ElemType->getOpcode() : 0) {
1693     case SPIRV::OpTypePointer:
1694       ET = LLT::pointer(getAS(ElemType), getPointerSize());
1695       break;
1696     case SPIRV::OpTypeInt:
1697     case SPIRV::OpTypeFloat:
1698     case SPIRV::OpTypeBool:
1699       ET = LLT::scalar(getScalarOrVectorBitWidth(ElemType));
1700       break;
1701     default:
1702       ET = LLT::scalar(64);
1703     }
1704     return LLT::fixed_vector(
1705         static_cast<unsigned>(SpvType->getOperand(2).getImm()), ET);
1706   }
1707   }
1708   return LLT::scalar(64);
1709 }
1710