xref: /llvm-project/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (revision 2fc7a72733762d685a07f846b44dc17a0585098e)
1 //===-- SPIRVPreLegalizer.cpp - prepare IR for legalization -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // The pass prepares IR for legalization: it assigns SPIR-V types to registers
// and removes intrinsics which held these types during IR translation.
11 // Also it processes constants and registers them in GR to avoid duplication.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "SPIRV.h"
16 #include "SPIRVSubtarget.h"
17 #include "SPIRVUtils.h"
18 #include "llvm/ADT/PostOrderIterator.h"
19 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
20 #include "llvm/IR/Attributes.h"
21 #include "llvm/IR/Constants.h"
22 #include "llvm/IR/DebugInfoMetadata.h"
23 #include "llvm/IR/IntrinsicsSPIRV.h"
24 #include "llvm/Target/TargetIntrinsicInfo.h"
25 
26 #define DEBUG_TYPE "spirv-prelegalizer"
27 
28 using namespace llvm;
29 
namespace {
// Machine-function pass that prepares gMIR for SPIR-V legalization: it
// assigns SPIR-V types to virtual registers and lowers/removes the helper
// intrinsics produced during IR translation (see file header comment).
// The work is done in runOnMachineFunction, implemented later in this file.
class SPIRVPreLegalizer : public MachineFunctionPass {
public:
  static char ID; // Pass identification.
  SPIRVPreLegalizer() : MachineFunctionPass(ID) {
    initializeSPIRVPreLegalizerPass(*PassRegistry::getPassRegistry());
  }
  bool runOnMachineFunction(MachineFunction &MF) override;
};
} // namespace
40 
// Process all llvm.spv.track_constant intrinsics in MF: register each tracked
// constant in the global registry (GR) so duplicates can be reused, then
// erase the intrinsics, rewiring their users to the constant-defining (or
// canonical pre-existing) register.
//
// Side outputs:
//  * TargetExtConstTypes - maps a constant-defining G_CONSTANT /
//    G_IMPLICIT_DEF to the target-extension type of its constant, so the
//    type can be restored during type assignment.
//  * TrackedConstRegs - registers of constants newly added to GR.
static void
addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR,
                    const SPIRVSubtarget &STI,
                    DenseMap<MachineInstr *, Type *> &TargetExtConstTypes,
                    SmallSet<Register, 4> &TrackedConstRegs) {
  MachineRegisterInfo &MRI = MF.getRegInfo();
  // For intrinsics whose constant was already registered, remember the
  // canonical register so uses can be redirected to it on erase.
  DenseMap<MachineInstr *, Register> RegsAlreadyAddedToDT;
  SmallVector<MachineInstr *, 10> ToErase, ToEraseComposites;
  for (MachineBasicBlock &MBB : MF) {
    for (MachineInstr &MI : MBB) {
      if (!isSpvIntrinsic(MI, Intrinsic::spv_track_constant))
        continue;
      ToErase.push_back(&MI);
      Register SrcReg = MI.getOperand(2).getReg();
      // The tracked Constant travels as metadata in operand 3.
      auto *Const =
          cast<Constant>(cast<ConstantAsMetadata>(
                             MI.getOperand(3).getMetadata()->getOperand(0))
                             ->getValue());
      if (auto *GV = dyn_cast<GlobalValue>(Const)) {
        Register Reg = GR->find(GV, &MF);
        if (!Reg.isValid())
          GR->add(GV, &MF, SrcReg);
        else
          RegsAlreadyAddedToDT[&MI] = Reg;
      } else {
        Register Reg = GR->find(Const, &MF);
        if (!Reg.isValid()) {
          if (auto *ConstVec = dyn_cast<ConstantDataVector>(Const)) {
            auto *BuildVec = MRI.getVRegDef(SrcReg);
            assert(BuildVec &&
                   BuildVec->getOpcode() == TargetOpcode::G_BUILD_VECTOR);
            for (unsigned i = 0; i < ConstVec->getNumElements(); ++i) {
              // Ensure that OpConstantComposite reuses a constant when it's
              // already created and available in the same machine function.
              Constant *ElemConst = ConstVec->getElementAsConstant(i);
              Register ElemReg = GR->find(ElemConst, &MF);
              if (!ElemReg.isValid())
                GR->add(ElemConst, &MF, BuildVec->getOperand(1 + i).getReg());
              else
                BuildVec->getOperand(1 + i).setReg(ElemReg);
            }
          }
          GR->add(Const, &MF, SrcReg);
          TrackedConstRegs.insert(SrcReg);
          if (Const->getType()->isTargetExtTy()) {
            // Remember the association so that it can be restored when types
            // are assigned later.
            MachineInstr *SrcMI = MRI.getVRegDef(SrcReg);
            if (SrcMI && (SrcMI->getOpcode() == TargetOpcode::G_CONSTANT ||
                          SrcMI->getOpcode() == TargetOpcode::G_IMPLICIT_DEF))
              TargetExtConstTypes[SrcMI] = Const->getType();
            if (Const->isNullValue()) {
              // A null target-extension constant becomes OpConstantNull with
              // an explicit SPIR-V result-type operand.
              MachineIRBuilder MIB(MF);
              SPIRVType *ExtType =
                  GR->getOrCreateSPIRVType(Const->getType(), MIB);
              SrcMI->setDesc(STI.getInstrInfo()->get(SPIRV::OpConstantNull));
              SrcMI->addOperand(MachineOperand::CreateReg(
                  GR->getSPIRVTypeID(ExtType), false));
            }
          }
        } else {
          RegsAlreadyAddedToDT[&MI] = Reg;
          // This MI is unused and will be removed. If the MI uses
          // const_composite, it will be unused and should be removed too.
          assert(MI.getOperand(2).isReg() && "Reg operand is expected");
          MachineInstr *SrcMI = MRI.getVRegDef(MI.getOperand(2).getReg());
          if (SrcMI && isSpvIntrinsic(*SrcMI, Intrinsic::spv_const_composite))
            ToEraseComposites.push_back(SrcMI);
        }
      }
    }
  }
  // Rewire each intrinsic's def to either the tracked source register or the
  // canonical register found above, then drop the intrinsic.
  for (MachineInstr *MI : ToErase) {
    Register Reg = MI->getOperand(2).getReg();
    if (RegsAlreadyAddedToDT.contains(MI))
      Reg = RegsAlreadyAddedToDT[MI];
    auto *RC = MRI.getRegClassOrNull(MI->getOperand(0).getReg());
    if (!MRI.getRegClassOrNull(Reg) && RC)
      MRI.setRegClass(Reg, RC);
    MRI.replaceRegWith(MI->getOperand(0).getReg(), Reg);
    MI->eraseFromParent();
  }
  for (MachineInstr *MI : ToEraseComposites)
    MI->eraseFromParent();
}
125 
126 static void
127 foldConstantsIntoIntrinsics(MachineFunction &MF,
128                             const SmallSet<Register, 4> &TrackedConstRegs) {
129   SmallVector<MachineInstr *, 10> ToErase;
130   MachineRegisterInfo &MRI = MF.getRegInfo();
131   const unsigned AssignNameOperandShift = 2;
132   for (MachineBasicBlock &MBB : MF) {
133     for (MachineInstr &MI : MBB) {
134       if (!isSpvIntrinsic(MI, Intrinsic::spv_assign_name))
135         continue;
136       unsigned NumOp = MI.getNumExplicitDefs() + AssignNameOperandShift;
137       while (MI.getOperand(NumOp).isReg()) {
138         MachineOperand &MOp = MI.getOperand(NumOp);
139         MachineInstr *ConstMI = MRI.getVRegDef(MOp.getReg());
140         assert(ConstMI->getOpcode() == TargetOpcode::G_CONSTANT);
141         MI.removeOperand(NumOp);
142         MI.addOperand(MachineOperand::CreateImm(
143             ConstMI->getOperand(1).getCImm()->getZExtValue()));
144         Register DefReg = ConstMI->getOperand(0).getReg();
145         if (MRI.use_empty(DefReg) && !TrackedConstRegs.contains(DefReg))
146           ToErase.push_back(ConstMI);
147       }
148     }
149   }
150   for (MachineInstr *MI : ToErase)
151     MI->eraseFromParent();
152 }
153 
154 static MachineInstr *findAssignTypeInstr(Register Reg,
155                                          MachineRegisterInfo *MRI) {
156   for (MachineRegisterInfo::use_instr_iterator I = MRI->use_instr_begin(Reg),
157                                                IE = MRI->use_instr_end();
158        I != IE; ++I) {
159     MachineInstr *UseMI = &*I;
160     if ((isSpvIntrinsic(*UseMI, Intrinsic::spv_assign_ptr_type) ||
161          isSpvIntrinsic(*UseMI, Intrinsic::spv_assign_type)) &&
162         UseMI->getOperand(1).getReg() == Reg)
163       return UseMI;
164   }
165   return nullptr;
166 }
167 
// Lower llvm.spv.bitcast and llvm.spv.ptrcast intrinsics into G_BITCAST.
// For ptrcast, the pointee type comes from metadata (operand 3) and the
// address space from an immediate (operand 4); when the source pointer
// already carries exactly the requested SPIR-V pointer type, the cast is
// redundant and uses are rewired to the source register instead.
static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
                           MachineIRBuilder MIB) {
  // Get access to information about available extensions
  const SPIRVSubtarget *ST =
      static_cast<const SPIRVSubtarget *>(&MIB.getMF().getSubtarget());
  SmallVector<MachineInstr *, 10> ToErase;
  for (MachineBasicBlock &MBB : MF) {
    for (MachineInstr &MI : MBB) {
      if (!isSpvIntrinsic(MI, Intrinsic::spv_bitcast) &&
          !isSpvIntrinsic(MI, Intrinsic::spv_ptrcast))
        continue;
      assert(MI.getOperand(2).isReg());
      MIB.setInsertPt(*MI.getParent(), MI);
      ToErase.push_back(&MI);
      if (isSpvIntrinsic(MI, Intrinsic::spv_bitcast)) {
        // Plain bitcast: a single G_BITCAST, no extra type bookkeeping.
        MIB.buildBitcast(MI.getOperand(0).getReg(), MI.getOperand(2).getReg());
        continue;
      }
      Register Def = MI.getOperand(0).getReg();
      Register Source = MI.getOperand(2).getReg();
      Type *ElemTy = getMDOperandAsType(MI.getOperand(3).getMetadata(), 0);
      SPIRVType *BaseTy = GR->getOrCreateSPIRVType(ElemTy, MIB);
      SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
          BaseTy, MI, *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(),
          addressSpaceToStorageClass(MI.getOperand(4).getImm(), *ST));

      // If the ptrcast would be redundant, replace all uses with the source
      // register.
      if (GR->getSPIRVTypeForVReg(Source) == AssignedPtrType) {
        // Erase Def's assign type instruction if we are going to replace Def.
        if (MachineInstr *AssignMI = findAssignTypeInstr(Def, MIB.getMRI()))
          ToErase.push_back(AssignMI);
        MIB.getMRI()->replaceRegWith(Def, Source);
      } else {
        GR->assignSPIRVTypeToVReg(AssignedPtrType, Def, MF);
        MIB.buildBitcast(Def, Source);
      }
    }
  }
  for (MachineInstr *MI : ToErase)
    MI->eraseFromParent();
}
210 
211 // Translating GV, IRTranslator sometimes generates following IR:
212 //   %1 = G_GLOBAL_VALUE
213 //   %2 = COPY %1
214 //   %3 = G_ADDRSPACE_CAST %2
215 //
216 // or
217 //
218 //  %1 = G_ZEXT %2
219 //  G_MEMCPY ... %2 ...
220 //
221 // New registers have no SPIRVType and no register class info.
222 //
223 // Set SPIRVType for GV, propagate it from GV to other instructions,
224 // also set register classes.
// Compute (and record in GR) the SPIR-V type of MI's def register when it
// does not have one yet, recursing through pass-through and extension
// opcodes. Returns the type, or nullptr when it cannot be deduced.
static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
                                     MachineRegisterInfo &MRI,
                                     MachineIRBuilder &MIB) {
  SPIRVType *SpvType = nullptr;
  assert(MI && "Machine instr is expected");
  if (MI->getOperand(0).isReg()) {
    Register Reg = MI->getOperand(0).getReg();
    SpvType = GR->getSPIRVTypeForVReg(Reg);
    if (!SpvType) {
      switch (MI->getOpcode()) {
      case TargetOpcode::G_CONSTANT: {
        // The constant's IR type determines the SPIR-V type directly.
        MIB.setInsertPt(*MI->getParent(), MI);
        Type *Ty = MI->getOperand(1).getCImm()->getType();
        SpvType = GR->getOrCreateSPIRVType(Ty, MIB);
        break;
      }
      case TargetOpcode::G_GLOBAL_VALUE: {
        // Build a typed-pointer type from the deduced global value type and
        // the global's address space.
        MIB.setInsertPt(*MI->getParent(), MI);
        const GlobalValue *Global = MI->getOperand(1).getGlobal();
        Type *ElementTy = toTypedPointer(GR->getDeducedGlobalValueType(Global));
        auto *Ty = TypedPointerType::get(ElementTy,
                                         Global->getType()->getAddressSpace());
        SpvType = GR->getOrCreateSPIRVType(Ty, MIB);
        break;
      }
      case TargetOpcode::G_ANYEXT:
      case TargetOpcode::G_SEXT:
      case TargetOpcode::G_ZEXT: {
        // Recurse into the source; result is an integer (or vector of
        // integers) at least as wide as both the source type and the
        // destination LLT's scalar size.
        if (MI->getOperand(1).isReg()) {
          if (MachineInstr *DefInstr =
                  MRI.getVRegDef(MI->getOperand(1).getReg())) {
            if (SPIRVType *Def = propagateSPIRVType(DefInstr, GR, MRI, MIB)) {
              unsigned CurrentBW = GR->getScalarOrVectorBitWidth(Def);
              unsigned ExpectedBW =
                  std::max(MRI.getType(Reg).getScalarSizeInBits(), CurrentBW);
              unsigned NumElements = GR->getScalarOrVectorComponentCount(Def);
              SpvType = GR->getOrCreateSPIRVIntegerType(ExpectedBW, MIB);
              if (NumElements > 1)
                SpvType =
                    GR->getOrCreateSPIRVVectorType(SpvType, NumElements, MIB);
            }
          }
        }
        break;
      }
      case TargetOpcode::G_PTRTOINT:
        // Integer of the destination LLT's scalar width.
        SpvType = GR->getOrCreateSPIRVIntegerType(
            MRI.getType(Reg).getScalarSizeInBits(), MIB);
        break;
      case TargetOpcode::G_TRUNC:
      case TargetOpcode::G_ADDRSPACE_CAST:
      case TargetOpcode::G_PTR_ADD:
      case TargetOpcode::COPY: {
        // Pass-through opcodes: inherit the type of the register source.
        MachineOperand &Op = MI->getOperand(1);
        MachineInstr *Def = Op.isReg() ? MRI.getVRegDef(Op.getReg()) : nullptr;
        if (Def)
          SpvType = propagateSPIRVType(Def, GR, MRI, MIB);
        break;
      }
      default:
        break;
      }
      if (SpvType)
        GR->assignSPIRVTypeToVReg(SpvType, Reg, MIB.getMF());
      // Fall back to the generic ID class when no class is set yet.
      if (!MRI.getRegClassOrNull(Reg))
        MRI.setRegClass(Reg, &SPIRV::iIDRegClass);
    }
  }
  return SpvType;
}
295 
296 // To support current approach and limitations wrt. bit width here we widen a
297 // scalar register with a bit width greater than 1 to valid sizes and cap it to
298 // 64 width.
299 static void widenScalarLLTNextPow2(Register Reg, MachineRegisterInfo &MRI) {
300   LLT RegType = MRI.getType(Reg);
301   if (!RegType.isScalar())
302     return;
303   unsigned Sz = RegType.getScalarSizeInBits();
304   if (Sz == 1)
305     return;
306   unsigned NewSz = std::min(std::max(1u << Log2_32_Ceil(Sz), 8u), 64u);
307   if (NewSz != Sz)
308     MRI.setType(Reg, LLT::scalar(NewSz));
309 }
310 
311 inline bool getIsFloat(SPIRVType *SpvType, const SPIRVGlobalRegistry &GR) {
312   bool IsFloat = SpvType->getOpcode() == SPIRV::OpTypeFloat;
313   return IsFloat ? true
314                  : SpvType->getOpcode() == SPIRV::OpTypeVector &&
315                        GR.getSPIRVTypeForVReg(SpvType->getOperand(1).getReg())
316                                ->getOpcode() == SPIRV::OpTypeFloat;
317 }
318 
319 static const TargetRegisterClass *getRegClass(SPIRVType *SpvType,
320                                               const SPIRVGlobalRegistry &GR) {
321   unsigned Opcode = SpvType->getOpcode();
322   switch (Opcode) {
323   case SPIRV::OpTypeFloat:
324     return &SPIRV::fIDRegClass;
325   case SPIRV::OpTypePointer:
326     return GR.getPointerSize() == 64 ? &SPIRV::pID64RegClass
327                                      : &SPIRV::pID32RegClass;
328   case SPIRV::OpTypeVector: {
329     SPIRVType *ElemType =
330         GR.getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
331     unsigned ElemOpcode = ElemType ? ElemType->getOpcode() : 0;
332     if (ElemOpcode == SPIRV::OpTypeFloat)
333       return &SPIRV::vfIDRegClass;
334     if (ElemOpcode == SPIRV::OpTypePointer)
335       return GR.getPointerSize() == 64 ? &SPIRV::vpID64RegClass
336                                        : &SPIRV::vpID32RegClass;
337     return &SPIRV::vIDRegClass;
338   }
339   }
340   return &SPIRV::iIDRegClass;
341 }
342 
// Create a fresh virtual register for the "ID" form of SrcReg's value,
// choosing its LLT and the GET_* pseudo opcode from SpvType (looked up from
// SrcReg when not provided). Returns {new register, GET_* opcode}.
static std::pair<Register, unsigned>
createNewIdReg(SPIRVType *SpvType, Register SrcReg, MachineRegisterInfo &MRI,
               const SPIRVGlobalRegistry &GR) {
  if (!SpvType)
    SpvType = GR.getSPIRVTypeForVReg(SrcReg);
  assert(SpvType && "VReg is expected to have SPIRV type");
  LLT NewT;
  LLT SrcLLT = MRI.getType(SrcReg);
  bool IsFloat = getIsFloat(SpvType, GR);
  auto GetIdOp = IsFloat ? SPIRV::GET_fID : SPIRV::GET_ID;
  if (SrcLLT.isPointer()) {
    // Pointers (and vectors of pointers) use the target pointer width.
    unsigned PtrSz = GR.getPointerSize();
    NewT = LLT::pointer(0, PtrSz);
    bool IsVec = SrcLLT.isVector();
    if (IsVec)
      NewT = LLT::fixed_vector(2, NewT); // NOTE: element count hard-coded to 2.
    if (PtrSz == 64)
      GetIdOp = IsVec ? SPIRV::GET_vpID64 : SPIRV::GET_pID64;
    else
      GetIdOp = IsVec ? SPIRV::GET_vpID32 : SPIRV::GET_pID32;
  } else if (SrcLLT.isVector()) {
    // Non-pointer vectors: scalar width comes from the SPIR-V type.
    NewT = LLT::scalar(GR.getScalarOrVectorBitWidth(SpvType));
    NewT = LLT::fixed_vector(2, NewT); // NOTE: element count hard-coded to 2.
    GetIdOp = IsFloat ? SPIRV::GET_vfID : SPIRV::GET_vID;
  } else {
    NewT = LLT::scalar(GR.getScalarOrVectorBitWidth(SpvType));
  }
  Register IdReg = MRI.createGenericVirtualRegister(NewT);
  MRI.setRegClass(IdReg, getRegClass(SpvType, GR));
  return {IdReg, GetIdOp};
}
374 
// Insert ASSIGN_TYPE instruction between Reg and its definition, set NewReg as
376 // a dst of the definition, assign SPIRVType to both registers. If SpvType is
377 // provided, use it as SPIRVType in ASSIGN_TYPE, otherwise create it from Ty.
378 // It's used also in SPIRVBuiltins.cpp.
379 // TODO: maybe move to SPIRVUtils.
380 namespace llvm {
Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpvType,
                           SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB,
                           MachineRegisterInfo &MRI) {
  MachineInstr *Def = MRI.getVRegDef(Reg);
  assert((Ty || SpvType) && "Either LLVM or SPIRV type is expected.");
  // Insert right after the defining instruction (or at the block's end).
  MIB.setInsertPt(*Def->getParent(),
                  (Def->getNextNode() ? Def->getNextNode()->getIterator()
                                      : Def->getParent()->end()));
  SpvType = SpvType ? SpvType : GR->getOrCreateSPIRVType(Ty, MIB);
  // NewReg takes over Def's result; Reg becomes the ASSIGN_TYPE's def.
  Register NewReg = MRI.createGenericVirtualRegister(MRI.getType(Reg));
  if (auto *RC = MRI.getRegClassOrNull(Reg)) {
    // Reuse Reg's existing register class for the intermediate register.
    MRI.setRegClass(NewReg, RC);
  } else {
    // No class set yet: derive one from the SPIR-V type for both registers.
    auto RegClass = getRegClass(SpvType, *GR);
    MRI.setRegClass(NewReg, RegClass);
    MRI.setRegClass(Reg, RegClass);
  }
  GR->assignSPIRVTypeToVReg(SpvType, Reg, MIB.getMF());
  // This is to make it convenient for Legalizer to get the SPIRVType
  // when processing the actual MI (i.e. not pseudo one).
  GR->assignSPIRVTypeToVReg(SpvType, NewReg, MIB.getMF());
  // Copy MIFlags from Def to ASSIGN_TYPE instruction. It's required to keep
  // the flags after instruction selection.
  const uint32_t Flags = Def->getFlags();
  MIB.buildInstr(SPIRV::ASSIGN_TYPE)
      .addDef(Reg)
      .addUse(NewReg)
      .addUse(GR->getSPIRVTypeID(SpvType))
      .setMIFlags(Flags);
  Def->getOperand(0).setReg(NewReg);
  return NewReg;
}
413 
// Rewrite an instruction whose opcode supports type folding so that its def
// and register uses go through fresh "ID" registers: the def is renamed to a
// new register consumed by its (single) ASSIGN_TYPE user, and every register
// use is funneled through a GET_* pseudo into a new ID register.
void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
                  MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR) {
  // MI's def is expected to have exactly one user: its ASSIGN_TYPE.
  assert(MI.getNumDefs() > 0 && MRI.hasOneUse(MI.getOperand(0).getReg()));
  MachineInstr &AssignTypeInst =
      *(MRI.use_instr_begin(MI.getOperand(0).getReg()));
  auto NewReg =
      createNewIdReg(nullptr, MI.getOperand(0).getReg(), MRI, *GR).first;
  AssignTypeInst.getOperand(1).setReg(NewReg);
  MI.getOperand(0).setReg(NewReg);
  // Build the GET_* pseudos at the point following MI.
  MIB.setInsertPt(*MI.getParent(),
                  (MI.getNextNode() ? MI.getNextNode()->getIterator()
                                    : MI.getParent()->end()));
  for (auto &Op : MI.operands()) {
    if (!Op.isReg() || Op.isDef())
      continue;
    auto IdOpInfo = createNewIdReg(nullptr, Op.getReg(), MRI, *GR);
    MIB.buildInstr(IdOpInfo.second).addDef(IdOpInfo.first).addUse(Op.getReg());
    Op.setReg(IdOpInfo.first);
  }
}
434 } // namespace llvm
435 
// Walk MF (blocks in post-order, instructions bottom-up) and materialize
// ASSIGN_TYPE pseudos for registers whose types are known: from
// spv_assign_ptr_type / spv_assign_type intrinsics (which are then erased),
// from constant/build-vector operands, and from G_GLOBAL_VALUE. Constants
// already present in GR are deduplicated via RegsAlreadyAddedToDT. Finally,
// types are propagated through cast/copy-like instructions whose registers
// the IRTranslator created without SPIR-V type info.
static void
generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
                     MachineIRBuilder MIB,
                     DenseMap<MachineInstr *, Type *> &TargetExtConstTypes) {
  // Get access to information about available extensions
  const SPIRVSubtarget *ST =
      static_cast<const SPIRVSubtarget *>(&MIB.getMF().getSubtarget());

  MachineRegisterInfo &MRI = MF.getRegInfo();
  SmallVector<MachineInstr *, 10> ToErase;
  DenseMap<MachineInstr *, Register> RegsAlreadyAddedToDT;

  for (MachineBasicBlock *MBB : post_order(&MF)) {
    if (MBB->empty())
      continue;

    // Iterate the block backwards so a use's type intrinsic is seen before
    // the defining instruction is processed.
    bool ReachedBegin = false;
    for (auto MII = std::prev(MBB->end()), Begin = MBB->begin();
         !ReachedBegin;) {
      MachineInstr &MI = *MII;
      unsigned MIOp = MI.getOpcode();

      // validate bit width of scalar registers
      for (const auto &MOP : MI.operands())
        if (MOP.isReg())
          widenScalarLLTNextPow2(MOP.getReg(), MRI);

      if (isSpvIntrinsic(MI, Intrinsic::spv_assign_ptr_type)) {
        // Pointee type is metadata (operand 2); address space is an
        // immediate (operand 3).
        Register Reg = MI.getOperand(1).getReg();
        MIB.setInsertPt(*MI.getParent(), MI.getIterator());
        Type *ElementTy = getMDOperandAsType(MI.getOperand(2).getMetadata(), 0);
        SPIRVType *BaseTy = GR->getOrCreateSPIRVType(ElementTy, MIB);
        SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
            BaseTy, MI, *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(),
            addressSpaceToStorageClass(MI.getOperand(3).getImm(), *ST));
        MachineInstr *Def = MRI.getVRegDef(Reg);
        assert(Def && "Expecting an instruction that defines the register");
        // G_GLOBAL_VALUE already has type info.
        if (Def->getOpcode() != TargetOpcode::G_GLOBAL_VALUE &&
            Def->getOpcode() != SPIRV::ASSIGN_TYPE)
          insertAssignInstr(Reg, nullptr, AssignedPtrType, GR, MIB,
                            MF.getRegInfo());
        ToErase.push_back(&MI);
      } else if (isSpvIntrinsic(MI, Intrinsic::spv_assign_type)) {
        Register Reg = MI.getOperand(1).getReg();
        Type *Ty = getMDOperandAsType(MI.getOperand(2).getMetadata(), 0);
        MachineInstr *Def = MRI.getVRegDef(Reg);
        assert(Def && "Expecting an instruction that defines the register");
        // G_GLOBAL_VALUE already has type info.
        if (Def->getOpcode() != TargetOpcode::G_GLOBAL_VALUE &&
            Def->getOpcode() != SPIRV::ASSIGN_TYPE)
          insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MF.getRegInfo());
        ToErase.push_back(&MI);
      } else if (MIOp == TargetOpcode::G_CONSTANT ||
                 MIOp == TargetOpcode::G_FCONSTANT ||
                 MIOp == TargetOpcode::G_BUILD_VECTOR) {
        // %rc = G_CONSTANT ty Val
        // ===>
        // %cty = OpType* ty
        // %rctmp = G_CONSTANT ty Val
        // %rc = ASSIGN_TYPE %rctmp, %cty
        Register Reg = MI.getOperand(0).getReg();
        bool NeedAssignType = true;
        if (MRI.hasOneUse(Reg)) {
          // An assign_type/assign_name intrinsic user will handle this
          // constant itself; skip it here.
          MachineInstr &UseMI = *MRI.use_instr_begin(Reg);
          if (isSpvIntrinsic(UseMI, Intrinsic::spv_assign_type) ||
              isSpvIntrinsic(UseMI, Intrinsic::spv_assign_name))
            continue;
        }
        Type *Ty = nullptr;
        if (MIOp == TargetOpcode::G_CONSTANT) {
          // Prefer the target-extension type remembered earlier over the
          // immediate's own type.
          auto TargetExtIt = TargetExtConstTypes.find(&MI);
          Ty = TargetExtIt == TargetExtConstTypes.end()
                   ? MI.getOperand(1).getCImm()->getType()
                   : TargetExtIt->second;
          const ConstantInt *OpCI = MI.getOperand(1).getCImm();
          Register PrimaryReg = GR->find(OpCI, &MF);
          if (!PrimaryReg.isValid()) {
            GR->add(OpCI, &MF, Reg);
          } else if (PrimaryReg != Reg &&
                     MRI.getType(Reg) == MRI.getType(PrimaryReg)) {
            // Duplicate constant: redirect uses to the primary register and
            // erase this definition, if register classes are compatible.
            auto *RCReg = MRI.getRegClassOrNull(Reg);
            auto *RCPrimary = MRI.getRegClassOrNull(PrimaryReg);
            if (!RCReg || RCPrimary == RCReg) {
              RegsAlreadyAddedToDT[&MI] = PrimaryReg;
              ToErase.push_back(&MI);
              NeedAssignType = false;
            }
          }
        } else if (MIOp == TargetOpcode::G_FCONSTANT) {
          Ty = MI.getOperand(1).getFPImm()->getType();
        } else {
          assert(MIOp == TargetOpcode::G_BUILD_VECTOR);
          // Derive the vector type from the first element's constant type.
          Type *ElemTy = nullptr;
          MachineInstr *ElemMI = MRI.getVRegDef(MI.getOperand(1).getReg());
          assert(ElemMI);

          if (ElemMI->getOpcode() == TargetOpcode::G_CONSTANT) {
            ElemTy = ElemMI->getOperand(1).getCImm()->getType();
          } else if (ElemMI->getOpcode() == TargetOpcode::G_FCONSTANT) {
            ElemTy = ElemMI->getOperand(1).getFPImm()->getType();
          } else {
            // There may be a case when we already know Reg's type.
            MachineInstr *NextMI = MI.getNextNode();
            if (!NextMI || NextMI->getOpcode() != SPIRV::ASSIGN_TYPE ||
                NextMI->getOperand(1).getReg() != Reg)
              llvm_unreachable("Unexpected opcode");
          }
          if (ElemTy)
            Ty = VectorType::get(
                ElemTy, MI.getNumExplicitOperands() - MI.getNumExplicitDefs(),
                false);
          else
            NeedAssignType = false;
        }
        if (NeedAssignType)
          insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MRI);
      } else if (MIOp == TargetOpcode::G_GLOBAL_VALUE) {
        propagateSPIRVType(&MI, GR, MRI, MIB);
      }

      if (MII == Begin)
        ReachedBegin = true;
      else
        --MII;
    }
  }
  // Erase the processed intrinsics / duplicate constants, redirecting uses
  // of erased duplicates to their primary registers.
  for (MachineInstr *MI : ToErase) {
    auto It = RegsAlreadyAddedToDT.find(MI);
    if (RegsAlreadyAddedToDT.contains(MI))
      MRI.replaceRegWith(MI->getOperand(0).getReg(), It->second);
    MI->eraseFromParent();
  }

  // Address the case when IRTranslator introduces instructions with new
  // registers without SPIRVType associated.
  for (MachineBasicBlock &MBB : MF) {
    for (MachineInstr &MI : MBB) {
      switch (MI.getOpcode()) {
      case TargetOpcode::G_TRUNC:
      case TargetOpcode::G_ANYEXT:
      case TargetOpcode::G_SEXT:
      case TargetOpcode::G_ZEXT:
      case TargetOpcode::G_PTRTOINT:
      case TargetOpcode::COPY:
      case TargetOpcode::G_ADDRSPACE_CAST:
        propagateSPIRVType(&MI, GR, MRI, MIB);
        break;
      }
    }
  }
}
588 
589 // Defined in SPIRVLegalizerInfo.cpp.
590 extern bool isTypeFoldingSupported(unsigned Opcode);
591 
// For every instruction whose opcode supports type folding, funnel its def
// and uses through ID registers (processInstr), then normalize pointer-typed
// ASSIGN_TYPE destinations to the target pointer LLT so tblgen'erated
// instruction selection can match them.
static void processInstrsWithTypeFolding(MachineFunction &MF,
                                         SPIRVGlobalRegistry *GR,
                                         MachineIRBuilder MIB) {
  MachineRegisterInfo &MRI = MF.getRegInfo();
  for (MachineBasicBlock &MBB : MF) {
    for (MachineInstr &MI : MBB) {
      if (isTypeFoldingSupported(MI.getOpcode()))
        processInstr(MI, MIB, MRI, GR);
    }
  }

  for (MachineBasicBlock &MBB : MF) {
    for (MachineInstr &MI : MBB) {
      // We need to rewrite dst types for ASSIGN_TYPE instrs to be able
      // to perform tblgen'erated selection and we can't do that on Legalizer
      // as it operates on gMIR only.
      if (MI.getOpcode() != SPIRV::ASSIGN_TYPE)
        continue;
      Register SrcReg = MI.getOperand(1).getReg();
      unsigned Opcode = MRI.getVRegDef(SrcReg)->getOpcode();
      if (!isTypeFoldingSupported(Opcode))
        continue;
      Register DstReg = MI.getOperand(0).getReg();
      // Don't need to reset type of register holding constant and used in
      // G_ADDRSPACE_CAST, since it breaks legalizer.
      if (Opcode == TargetOpcode::G_CONSTANT && MRI.hasOneUse(DstReg)) {
        MachineInstr &UseMI = *MRI.use_instr_begin(DstReg);
        if (UseMI.getOpcode() == TargetOpcode::G_ADDRSPACE_CAST)
          continue;
      }
      if (MRI.getType(DstReg).isPointer())
        MRI.setType(DstReg, LLT::pointer(0, GR->getPointerSize()));
    }
  }
}
627 
// Lower consecutive pairs of (spv_inline_asm intrinsic, INLINEASM
// instruction) from ToProcess into the SPV_INTEL_inline_assembly form:
// one OpAsmTargetINTEL per function, an OpAsmINTEL per asm blob, and an
// OpAsmCallINTEL per call site. All processed instructions are erased.
static void
insertInlineAsmProcess(MachineFunction &MF, SPIRVGlobalRegistry *GR,
                       const SPIRVSubtarget &ST, MachineIRBuilder MIRBuilder,
                       const SmallVector<MachineInstr *> &ToProcess) {
  MachineRegisterInfo &MRI = MF.getRegInfo();
  Register AsmTargetReg;
  for (unsigned i = 0, Sz = ToProcess.size(); i + 1 < Sz; i += 2) {
    MachineInstr *I1 = ToProcess[i], *I2 = ToProcess[i + 1];
    assert(isSpvIntrinsic(*I1, Intrinsic::spv_inline_asm) && I2->isInlineAsm());
    MIRBuilder.setInsertPt(*I1->getParent(), *I1);

    if (!AsmTargetReg.isValid()) {
      // define vendor specific assembly target or dialect
      AsmTargetReg = MRI.createGenericVirtualRegister(LLT::scalar(32));
      MRI.setRegClass(AsmTargetReg, &SPIRV::iIDRegClass);
      auto AsmTargetMIB =
          MIRBuilder.buildInstr(SPIRV::OpAsmTargetINTEL).addDef(AsmTargetReg);
      addStringImm(ST.getTargetTripleAsStr(), AsmTargetMIB);
      GR->add(AsmTargetMIB.getInstr(), &MF, AsmTargetReg);
    }

    // create types
    const MDNode *IAMD = I1->getOperand(1).getMetadata();
    FunctionType *FTy = cast<FunctionType>(getMDOperandAsType(IAMD, 0));
    SmallVector<SPIRVType *, 4> ArgTypes;
    for (const auto &ArgTy : FTy->params())
      ArgTypes.push_back(GR->getOrCreateSPIRVType(ArgTy, MIRBuilder));
    SPIRVType *RetType =
        GR->getOrCreateSPIRVType(FTy->getReturnType(), MIRBuilder);
    SPIRVType *FuncType = GR->getOrCreateOpTypeFunctionWithArgs(
        FTy, RetType, ArgTypes, MIRBuilder);

    // define vendor specific assembly instructions string
    Register AsmReg = MRI.createGenericVirtualRegister(LLT::scalar(32));
    MRI.setRegClass(AsmReg, &SPIRV::iIDRegClass);
    auto AsmMIB = MIRBuilder.buildInstr(SPIRV::OpAsmINTEL)
                      .addDef(AsmReg)
                      .addUse(GR->getSPIRVTypeID(RetType))
                      .addUse(GR->getSPIRVTypeID(FuncType))
                      .addUse(AsmTargetReg);
    // inline asm string:
    addStringImm(I2->getOperand(InlineAsm::MIOp_AsmString).getSymbolName(),
                 AsmMIB);
    // inline asm constraint string:
    addStringImm(cast<MDString>(I1->getOperand(2).getMetadata()->getOperand(0))
                     ->getString(),
                 AsmMIB);
    GR->add(AsmMIB.getInstr(), &MF, AsmReg);

    // calls the inline assembly instruction
    unsigned ExtraInfo = I2->getOperand(InlineAsm::MIOp_ExtraInfo).getImm();
    if (ExtraInfo & InlineAsm::Extra_HasSideEffects)
      MIRBuilder.buildInstr(SPIRV::OpDecorate)
          .addUse(AsmReg)
          .addImm(static_cast<uint32_t>(SPIRV::Decoration::SideEffectsINTEL));
    // Walk the INLINEASM operands: skip metadata and operand descriptors,
    // capture the single def (if any) and collect the indices of inputs.
    Register DefReg;
    SmallVector<unsigned, 4> Ops;
    unsigned StartOp = InlineAsm::MIOp_FirstOperand,
             AsmDescOp = InlineAsm::MIOp_FirstOperand;
    unsigned I2Sz = I2->getNumOperands();
    for (unsigned Idx = StartOp; Idx != I2Sz; ++Idx) {
      const MachineOperand &MO = I2->getOperand(Idx);
      if (MO.isMetadata())
        continue;
      if (Idx == AsmDescOp && MO.isImm()) {
        // compute the index of the next operand descriptor
        const InlineAsm::Flag F(MO.getImm());
        AsmDescOp += 1 + F.getNumOperandRegisters();
      } else {
        if (MO.isReg() && MO.isDef())
          DefReg = MO.getReg();
        else
          Ops.push_back(Idx);
      }
    }
    if (!DefReg.isValid()) {
      // No result: synthesize a void-typed destination register.
      DefReg = MRI.createGenericVirtualRegister(LLT::scalar(32));
      MRI.setRegClass(DefReg, &SPIRV::iIDRegClass);
      SPIRVType *VoidType = GR->getOrCreateSPIRVType(
          Type::getVoidTy(MF.getFunction().getContext()), MIRBuilder);
      GR->assignSPIRVTypeToVReg(VoidType, DefReg, MF);
    }
    auto AsmCall = MIRBuilder.buildInstr(SPIRV::OpAsmCallINTEL)
                       .addDef(DefReg)
                       .addUse(GR->getSPIRVTypeID(RetType))
                       .addUse(AsmReg);
    // Use register inputs from the INLINEASM MI directly; non-register
    // inputs come from the matching intrinsic operand instead.
    unsigned IntrIdx = 2;
    for (unsigned Idx : Ops) {
      ++IntrIdx;
      const MachineOperand &MO = I2->getOperand(Idx);
      if (MO.isReg())
        AsmCall.addUse(MO.getReg());
      else
        AsmCall.addUse(I1->getOperand(IntrIdx).getReg());
    }
  }
  for (MachineInstr *MI : ToProcess)
    MI->eraseFromParent();
}
727 
728 static void insertInlineAsm(MachineFunction &MF, SPIRVGlobalRegistry *GR,
729                             const SPIRVSubtarget &ST,
730                             MachineIRBuilder MIRBuilder) {
731   SmallVector<MachineInstr *> ToProcess;
732   for (MachineBasicBlock &MBB : MF) {
733     for (MachineInstr &MI : MBB) {
734       if (isSpvIntrinsic(MI, Intrinsic::spv_inline_asm) ||
735           MI.getOpcode() == TargetOpcode::INLINEASM)
736         ToProcess.push_back(&MI);
737     }
738   }
739   if (ToProcess.size() == 0)
740     return;
741 
742   if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_inline_assembly))
743     report_fatal_error("Inline assembly instructions require the "
744                        "following SPIR-V extension: SPV_INTEL_inline_assembly",
745                        false);
746 
747   insertInlineAsmProcess(MF, GR, ST, MIRBuilder, ToProcess);
748 }
749 
750 static void insertSpirvDecorations(MachineFunction &MF, MachineIRBuilder MIB) {
751   SmallVector<MachineInstr *, 10> ToErase;
752   for (MachineBasicBlock &MBB : MF) {
753     for (MachineInstr &MI : MBB) {
754       if (!isSpvIntrinsic(MI, Intrinsic::spv_assign_decoration))
755         continue;
756       MIB.setInsertPt(*MI.getParent(), MI);
757       buildOpSpirvDecorations(MI.getOperand(1).getReg(), MIB,
758                               MI.getOperand(2).getMetadata());
759       ToErase.push_back(&MI);
760     }
761   }
762   for (MachineInstr *MI : ToErase)
763     MI->eraseFromParent();
764 }
765 
766 // Find basic blocks of the switch and replace registers in spv_switch() by its
767 // MBB equivalent.
768 static void processSwitches(MachineFunction &MF, SPIRVGlobalRegistry *GR,
769                             MachineIRBuilder MIB) {
770   DenseMap<const BasicBlock *, MachineBasicBlock *> BB2MBB;
771   SmallVector<std::pair<MachineInstr *, SmallVector<MachineInstr *, 8>>>
772       Switches;
773   for (MachineBasicBlock &MBB : MF) {
774     MachineRegisterInfo &MRI = MF.getRegInfo();
775     BB2MBB[MBB.getBasicBlock()] = &MBB;
776     for (MachineInstr &MI : MBB) {
777       if (!isSpvIntrinsic(MI, Intrinsic::spv_switch))
778         continue;
779       // Calls to spv_switch intrinsics representing IR switches.
780       SmallVector<MachineInstr *, 8> NewOps;
781       for (unsigned i = 2; i < MI.getNumOperands(); ++i) {
782         Register Reg = MI.getOperand(i).getReg();
783         if (i % 2 == 1) {
784           MachineInstr *ConstInstr = getDefInstrMaybeConstant(Reg, &MRI);
785           NewOps.push_back(ConstInstr);
786         } else {
787           MachineInstr *BuildMBB = MRI.getVRegDef(Reg);
788           assert(BuildMBB &&
789                  BuildMBB->getOpcode() == TargetOpcode::G_BLOCK_ADDR &&
790                  BuildMBB->getOperand(1).isBlockAddress() &&
791                  BuildMBB->getOperand(1).getBlockAddress());
792           NewOps.push_back(BuildMBB);
793         }
794       }
795       Switches.push_back(std::make_pair(&MI, NewOps));
796     }
797   }
798 
799   SmallPtrSet<MachineInstr *, 8> ToEraseMI;
800   for (auto &SwIt : Switches) {
801     MachineInstr &MI = *SwIt.first;
802     SmallVector<MachineInstr *, 8> &Ins = SwIt.second;
803     SmallVector<MachineOperand, 8> NewOps;
804     for (unsigned i = 0; i < Ins.size(); ++i) {
805       if (Ins[i]->getOpcode() == TargetOpcode::G_BLOCK_ADDR) {
806         BasicBlock *CaseBB =
807             Ins[i]->getOperand(1).getBlockAddress()->getBasicBlock();
808         auto It = BB2MBB.find(CaseBB);
809         if (It == BB2MBB.end())
810           report_fatal_error("cannot find a machine basic block by a basic "
811                              "block in a switch statement");
812         NewOps.push_back(MachineOperand::CreateMBB(It->second));
813         MI.getParent()->addSuccessor(It->second);
814         ToEraseMI.insert(Ins[i]);
815       } else {
816         NewOps.push_back(
817             MachineOperand::CreateCImm(Ins[i]->getOperand(1).getCImm()));
818       }
819     }
820     for (unsigned i = MI.getNumOperands() - 1; i > 1; --i)
821       MI.removeOperand(i);
822     for (auto &MO : NewOps)
823       MI.addOperand(MO);
824     if (MachineInstr *Next = MI.getNextNode()) {
825       if (isSpvIntrinsic(*Next, Intrinsic::spv_track_constant)) {
826         ToEraseMI.insert(Next);
827         Next = MI.getNextNode();
828       }
829       if (Next && Next->getOpcode() == TargetOpcode::G_BRINDIRECT)
830         ToEraseMI.insert(Next);
831     }
832   }
833 
834   // If we just delete G_BLOCK_ADDR instructions with BlockAddress operands,
835   // this leaves their BasicBlock counterparts in a "address taken" status. This
836   // would make AsmPrinter to generate a series of unneeded labels of a "Address
837   // of block that was removed by CodeGen" kind. Let's first ensure that we
838   // don't have a dangling BlockAddress constants by zapping the BlockAddress
839   // nodes, and only after that proceed with erasing G_BLOCK_ADDR instructions.
840   Constant *Replacement =
841       ConstantInt::get(Type::getInt32Ty(MF.getFunction().getContext()), 1);
842   for (MachineInstr *BlockAddrI : ToEraseMI) {
843     if (BlockAddrI->getOpcode() == TargetOpcode::G_BLOCK_ADDR) {
844       BlockAddress *BA = const_cast<BlockAddress *>(
845           BlockAddrI->getOperand(1).getBlockAddress());
846       BA->replaceAllUsesWith(
847           ConstantExpr::getIntToPtr(Replacement, BA->getType()));
848       BA->destroyConstant();
849     }
850     BlockAddrI->eraseFromParent();
851   }
852 }
853 
854 static bool isImplicitFallthrough(MachineBasicBlock &MBB) {
855   if (MBB.empty())
856     return true;
857 
858   // Branching SPIR-V intrinsics are not detected by this generic method.
859   // Thus, we can only trust negative result.
860   if (!MBB.canFallThrough())
861     return false;
862 
863   // Otherwise, we must manually check if we have a SPIR-V intrinsic which
864   // prevent an implicit fallthrough.
865   for (MachineBasicBlock::reverse_iterator It = MBB.rbegin(), E = MBB.rend();
866        It != E; ++It) {
867     if (isSpvIntrinsic(*It, Intrinsic::spv_switch))
868       return false;
869   }
870   return true;
871 }
872 
873 static void removeImplicitFallthroughs(MachineFunction &MF,
874                                        MachineIRBuilder MIB) {
875   // It is valid for MachineBasicBlocks to not finish with a branch instruction.
876   // In such cases, they will simply fallthrough their immediate successor.
877   for (MachineBasicBlock &MBB : MF) {
878     if (!isImplicitFallthrough(MBB))
879       continue;
880 
881     assert(std::distance(MBB.successors().begin(), MBB.successors().end()) ==
882            1);
883     MIB.setInsertPt(MBB, MBB.end());
884     MIB.buildBr(**MBB.successors().begin());
885   }
886 }
887 
// Runs the SPIR-V pre-legalization pipeline on MF. The stages below are
// order-dependent: constants must be tracked before folding, types must be
// assigned before type folding, and inline asm lowering relies on types
// created by the earlier stages. Always returns true (the function is
// modified).
bool SPIRVPreLegalizer::runOnMachineFunction(MachineFunction &MF) {
  // Initialize the type registry.
  const SPIRVSubtarget &ST = MF.getSubtarget<SPIRVSubtarget>();
  SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
  GR->setCurrentFunc(MF);
  MachineIRBuilder MIB(MF);
  // a registry of target extension constants
  DenseMap<MachineInstr *, Type *> TargetExtConstTypes;
  // to keep record of tracked constants
  SmallSet<Register, 4> TrackedConstRegs;
  // Register constants in GR and fold them into intrinsic operands.
  addConstantsToTrack(MF, GR, ST, TargetExtConstTypes, TrackedConstRegs);
  foldConstantsIntoIntrinsics(MF, TrackedConstRegs);
  // Materialize pointer casts and assign SPIR-V types to virtual registers.
  insertBitcasts(MF, GR, MIB);
  generateAssignInstrs(MF, GR, MIB, TargetExtConstTypes);
  // Rewrite spv_switch operands to (constant, MBB) pairs.
  processSwitches(MF, GR, MIB);
  processInstrsWithTypeFolding(MF, GR, MIB);
  // Make implicit fallthroughs explicit branches, then lower remaining
  // decoration and inline-asm intrinsics.
  removeImplicitFallthroughs(MF, MIB);
  insertSpirvDecorations(MF, MIB);
  insertInlineAsm(MF, GR, ST, MIB);

  return true;
}
910 
// Register the pass with the LLVM pass registry under DEBUG_TYPE
// ("spirv-prelegalizer"); not a CFG-only pass, not an analysis.
INITIALIZE_PASS(SPIRVPreLegalizer, DEBUG_TYPE, "SPIRV pre legalizer", false,
                false)

// Out-of-line definition of the pass identifier declared in the class.
char SPIRVPreLegalizer::ID = 0;
915 
916 FunctionPass *llvm::createSPIRVPreLegalizerPass() {
917   return new SPIRVPreLegalizer();
918 }
919