xref: /llvm-project/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (revision 02a334de6690202154ef09456c581618ff290f9a)
1 //===-- SPIRVPreLegalizer.cpp - prepare IR for legalization -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // The pass prepares IR for legalization: it assigns SPIR-V types to registers
// and removes intrinsics which held these types during IR translation.
11 // Also it processes constants and registers them in GR to avoid duplication.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "SPIRV.h"
16 #include "SPIRVSubtarget.h"
17 #include "SPIRVUtils.h"
18 #include "llvm/ADT/PostOrderIterator.h"
19 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
20 #include "llvm/IR/Attributes.h"
21 #include "llvm/IR/Constants.h"
22 #include "llvm/IR/DebugInfoMetadata.h"
23 #include "llvm/IR/IntrinsicsSPIRV.h"
24 #include "llvm/Target/TargetIntrinsicInfo.h"
25 
26 #define DEBUG_TYPE "spirv-prelegalizer"
27 
28 using namespace llvm;
29 
30 namespace {
31 class SPIRVPreLegalizer : public MachineFunctionPass {
32 public:
33   static char ID;
34   SPIRVPreLegalizer() : MachineFunctionPass(ID) {
35     initializeSPIRVPreLegalizerPass(*PassRegistry::getPassRegistry());
36   }
37   bool runOnMachineFunction(MachineFunction &MF) override;
38 };
39 } // namespace
40 
41 static void
42 addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR,
43                     const SPIRVSubtarget &STI,
44                     DenseMap<MachineInstr *, Type *> &TargetExtConstTypes,
45                     SmallSet<Register, 4> &TrackedConstRegs) {
46   MachineRegisterInfo &MRI = MF.getRegInfo();
47   DenseMap<MachineInstr *, Register> RegsAlreadyAddedToDT;
48   SmallVector<MachineInstr *, 10> ToErase, ToEraseComposites;
49   for (MachineBasicBlock &MBB : MF) {
50     for (MachineInstr &MI : MBB) {
51       if (!isSpvIntrinsic(MI, Intrinsic::spv_track_constant))
52         continue;
53       ToErase.push_back(&MI);
54       Register SrcReg = MI.getOperand(2).getReg();
55       auto *Const =
56           cast<Constant>(cast<ConstantAsMetadata>(
57                              MI.getOperand(3).getMetadata()->getOperand(0))
58                              ->getValue());
59       if (auto *GV = dyn_cast<GlobalValue>(Const)) {
60         Register Reg = GR->find(GV, &MF);
61         if (!Reg.isValid())
62           GR->add(GV, &MF, SrcReg);
63         else
64           RegsAlreadyAddedToDT[&MI] = Reg;
65       } else {
66         Register Reg = GR->find(Const, &MF);
67         if (!Reg.isValid()) {
68           if (auto *ConstVec = dyn_cast<ConstantDataVector>(Const)) {
69             auto *BuildVec = MRI.getVRegDef(SrcReg);
70             assert(BuildVec &&
71                    BuildVec->getOpcode() == TargetOpcode::G_BUILD_VECTOR);
72             for (unsigned i = 0; i < ConstVec->getNumElements(); ++i) {
73               // Ensure that OpConstantComposite reuses a constant when it's
74               // already created and available in the same machine function.
75               Constant *ElemConst = ConstVec->getElementAsConstant(i);
76               Register ElemReg = GR->find(ElemConst, &MF);
77               if (!ElemReg.isValid())
78                 GR->add(ElemConst, &MF, BuildVec->getOperand(1 + i).getReg());
79               else
80                 BuildVec->getOperand(1 + i).setReg(ElemReg);
81             }
82           }
83           GR->add(Const, &MF, SrcReg);
84           TrackedConstRegs.insert(SrcReg);
85           if (Const->getType()->isTargetExtTy()) {
86             // remember association so that we can restore it when assign types
87             MachineInstr *SrcMI = MRI.getVRegDef(SrcReg);
88             if (SrcMI && (SrcMI->getOpcode() == TargetOpcode::G_CONSTANT ||
89                           SrcMI->getOpcode() == TargetOpcode::G_IMPLICIT_DEF))
90               TargetExtConstTypes[SrcMI] = Const->getType();
91             if (Const->isNullValue()) {
92               MachineIRBuilder MIB(MF);
93               SPIRVType *ExtType =
94                   GR->getOrCreateSPIRVType(Const->getType(), MIB);
95               SrcMI->setDesc(STI.getInstrInfo()->get(SPIRV::OpConstantNull));
96               SrcMI->addOperand(MachineOperand::CreateReg(
97                   GR->getSPIRVTypeID(ExtType), false));
98             }
99           }
100         } else {
101           RegsAlreadyAddedToDT[&MI] = Reg;
102           // This MI is unused and will be removed. If the MI uses
103           // const_composite, it will be unused and should be removed too.
104           assert(MI.getOperand(2).isReg() && "Reg operand is expected");
105           MachineInstr *SrcMI = MRI.getVRegDef(MI.getOperand(2).getReg());
106           if (SrcMI && isSpvIntrinsic(*SrcMI, Intrinsic::spv_const_composite))
107             ToEraseComposites.push_back(SrcMI);
108         }
109       }
110     }
111   }
112   for (MachineInstr *MI : ToErase) {
113     Register Reg = MI->getOperand(2).getReg();
114     if (RegsAlreadyAddedToDT.contains(MI))
115       Reg = RegsAlreadyAddedToDT[MI];
116     auto *RC = MRI.getRegClassOrNull(MI->getOperand(0).getReg());
117     if (!MRI.getRegClassOrNull(Reg) && RC)
118       MRI.setRegClass(Reg, RC);
119     MRI.replaceRegWith(MI->getOperand(0).getReg(), Reg);
120     MI->eraseFromParent();
121   }
122   for (MachineInstr *MI : ToEraseComposites)
123     MI->eraseFromParent();
124 }
125 
126 static void
127 foldConstantsIntoIntrinsics(MachineFunction &MF,
128                             const SmallSet<Register, 4> &TrackedConstRegs) {
129   SmallVector<MachineInstr *, 10> ToErase;
130   MachineRegisterInfo &MRI = MF.getRegInfo();
131   const unsigned AssignNameOperandShift = 2;
132   for (MachineBasicBlock &MBB : MF) {
133     for (MachineInstr &MI : MBB) {
134       if (!isSpvIntrinsic(MI, Intrinsic::spv_assign_name))
135         continue;
136       unsigned NumOp = MI.getNumExplicitDefs() + AssignNameOperandShift;
137       while (MI.getOperand(NumOp).isReg()) {
138         MachineOperand &MOp = MI.getOperand(NumOp);
139         MachineInstr *ConstMI = MRI.getVRegDef(MOp.getReg());
140         assert(ConstMI->getOpcode() == TargetOpcode::G_CONSTANT);
141         MI.removeOperand(NumOp);
142         MI.addOperand(MachineOperand::CreateImm(
143             ConstMI->getOperand(1).getCImm()->getZExtValue()));
144         Register DefReg = ConstMI->getOperand(0).getReg();
145         if (MRI.use_empty(DefReg) && !TrackedConstRegs.contains(DefReg))
146           ToErase.push_back(ConstMI);
147       }
148     }
149   }
150   for (MachineInstr *MI : ToErase)
151     MI->eraseFromParent();
152 }
153 
154 static MachineInstr *findAssignTypeInstr(Register Reg,
155                                          MachineRegisterInfo *MRI) {
156   for (MachineRegisterInfo::use_instr_iterator I = MRI->use_instr_begin(Reg),
157                                                IE = MRI->use_instr_end();
158        I != IE; ++I) {
159     MachineInstr *UseMI = &*I;
160     if ((isSpvIntrinsic(*UseMI, Intrinsic::spv_assign_ptr_type) ||
161          isSpvIntrinsic(*UseMI, Intrinsic::spv_assign_type)) &&
162         UseMI->getOperand(1).getReg() == Reg)
163       return UseMI;
164   }
165   return nullptr;
166 }
167 
168 static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
169                            MachineIRBuilder MIB) {
170   // Get access to information about available extensions
171   const SPIRVSubtarget *ST =
172       static_cast<const SPIRVSubtarget *>(&MIB.getMF().getSubtarget());
173   SmallVector<MachineInstr *, 10> ToErase;
174   for (MachineBasicBlock &MBB : MF) {
175     for (MachineInstr &MI : MBB) {
176       if (!isSpvIntrinsic(MI, Intrinsic::spv_bitcast) &&
177           !isSpvIntrinsic(MI, Intrinsic::spv_ptrcast))
178         continue;
179       assert(MI.getOperand(2).isReg());
180       MIB.setInsertPt(*MI.getParent(), MI);
181       ToErase.push_back(&MI);
182       if (isSpvIntrinsic(MI, Intrinsic::spv_bitcast)) {
183         MIB.buildBitcast(MI.getOperand(0).getReg(), MI.getOperand(2).getReg());
184         continue;
185       }
186       Register Def = MI.getOperand(0).getReg();
187       Register Source = MI.getOperand(2).getReg();
188       Type *ElemTy = getMDOperandAsType(MI.getOperand(3).getMetadata(), 0);
189       SPIRVType *BaseTy = GR->getOrCreateSPIRVType(ElemTy, MIB);
190       SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
191           BaseTy, MI, *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(),
192           addressSpaceToStorageClass(MI.getOperand(4).getImm(), *ST));
193 
194       // If the ptrcast would be redundant, replace all uses with the source
195       // register.
196       MachineRegisterInfo *MRI = MIB.getMRI();
197       if (GR->getSPIRVTypeForVReg(Source) == AssignedPtrType) {
198         // Erase Def's assign type instruction if we are going to replace Def.
199         if (MachineInstr *AssignMI = findAssignTypeInstr(Def, MRI))
200           ToErase.push_back(AssignMI);
201         MRI->replaceRegWith(Def, Source);
202       } else {
203         GR->assignSPIRVTypeToVReg(AssignedPtrType, Def, MF);
204         MIB.buildBitcast(Def, Source);
205         // MachineVerifier requires that bitcast must change the type.
206         // Change AddressSpace if needed to hint that Def and Source points to
207         // different types: this doesn't change actual code generation.
208         LLT DefType = MRI->getType(Def);
209         if (DefType == MRI->getType(Source))
210           MRI->setType(Def,
211                        LLT::pointer((DefType.getAddressSpace() + 1) %
212                                         SPIRVSubtarget::MaxLegalAddressSpace,
213                                     GR->getPointerSize()));
214       }
215     }
216   }
217   for (MachineInstr *MI : ToErase)
218     MI->eraseFromParent();
219 }
220 
221 // Translating GV, IRTranslator sometimes generates following IR:
222 //   %1 = G_GLOBAL_VALUE
223 //   %2 = COPY %1
224 //   %3 = G_ADDRSPACE_CAST %2
225 //
226 // or
227 //
228 //  %1 = G_ZEXT %2
229 //  G_MEMCPY ... %2 ...
230 //
231 // New registers have no SPIRVType and no register class info.
232 //
233 // Set SPIRVType for GV, propagate it from GV to other instructions,
234 // also set register classes.
235 static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
236                                      MachineRegisterInfo &MRI,
237                                      MachineIRBuilder &MIB) {
238   SPIRVType *SpvType = nullptr;
239   assert(MI && "Machine instr is expected");
240   if (MI->getOperand(0).isReg()) {
241     Register Reg = MI->getOperand(0).getReg();
242     SpvType = GR->getSPIRVTypeForVReg(Reg);
243     if (!SpvType) {
244       switch (MI->getOpcode()) {
245       case TargetOpcode::G_CONSTANT: {
246         MIB.setInsertPt(*MI->getParent(), MI);
247         Type *Ty = MI->getOperand(1).getCImm()->getType();
248         SpvType = GR->getOrCreateSPIRVType(Ty, MIB);
249         break;
250       }
251       case TargetOpcode::G_GLOBAL_VALUE: {
252         MIB.setInsertPt(*MI->getParent(), MI);
253         const GlobalValue *Global = MI->getOperand(1).getGlobal();
254         Type *ElementTy = toTypedPointer(GR->getDeducedGlobalValueType(Global));
255         auto *Ty = TypedPointerType::get(ElementTy,
256                                          Global->getType()->getAddressSpace());
257         SpvType = GR->getOrCreateSPIRVType(Ty, MIB);
258         break;
259       }
260       case TargetOpcode::G_ANYEXT:
261       case TargetOpcode::G_SEXT:
262       case TargetOpcode::G_ZEXT: {
263         if (MI->getOperand(1).isReg()) {
264           if (MachineInstr *DefInstr =
265                   MRI.getVRegDef(MI->getOperand(1).getReg())) {
266             if (SPIRVType *Def = propagateSPIRVType(DefInstr, GR, MRI, MIB)) {
267               unsigned CurrentBW = GR->getScalarOrVectorBitWidth(Def);
268               unsigned ExpectedBW =
269                   std::max(MRI.getType(Reg).getScalarSizeInBits(), CurrentBW);
270               unsigned NumElements = GR->getScalarOrVectorComponentCount(Def);
271               SpvType = GR->getOrCreateSPIRVIntegerType(ExpectedBW, MIB);
272               if (NumElements > 1)
273                 SpvType =
274                     GR->getOrCreateSPIRVVectorType(SpvType, NumElements, MIB);
275             }
276           }
277         }
278         break;
279       }
280       case TargetOpcode::G_PTRTOINT:
281         SpvType = GR->getOrCreateSPIRVIntegerType(
282             MRI.getType(Reg).getScalarSizeInBits(), MIB);
283         break;
284       case TargetOpcode::G_TRUNC:
285       case TargetOpcode::G_ADDRSPACE_CAST:
286       case TargetOpcode::G_PTR_ADD:
287       case TargetOpcode::COPY: {
288         MachineOperand &Op = MI->getOperand(1);
289         MachineInstr *Def = Op.isReg() ? MRI.getVRegDef(Op.getReg()) : nullptr;
290         if (Def)
291           SpvType = propagateSPIRVType(Def, GR, MRI, MIB);
292         break;
293       }
294       default:
295         break;
296       }
297       if (SpvType)
298         GR->assignSPIRVTypeToVReg(SpvType, Reg, MIB.getMF());
299       if (!MRI.getRegClassOrNull(Reg))
300         MRI.setRegClass(Reg, SpvType ? GR->getRegClass(SpvType)
301                                      : &SPIRV::iIDRegClass);
302     }
303   }
304   return SpvType;
305 }
306 
307 // To support current approach and limitations wrt. bit width here we widen a
308 // scalar register with a bit width greater than 1 to valid sizes and cap it to
309 // 64 width.
310 static void widenScalarLLTNextPow2(Register Reg, MachineRegisterInfo &MRI) {
311   LLT RegType = MRI.getType(Reg);
312   if (!RegType.isScalar())
313     return;
314   unsigned Sz = RegType.getScalarSizeInBits();
315   if (Sz == 1)
316     return;
317   unsigned NewSz = std::min(std::max(1u << Log2_32_Ceil(Sz), 8u), 64u);
318   if (NewSz != Sz)
319     MRI.setType(Reg, LLT::scalar(NewSz));
320 }
321 
322 static std::pair<Register, unsigned>
323 createNewIdReg(SPIRVType *SpvType, Register SrcReg, MachineRegisterInfo &MRI,
324                const SPIRVGlobalRegistry &GR) {
325   if (!SpvType)
326     SpvType = GR.getSPIRVTypeForVReg(SrcReg);
327   const TargetRegisterClass *RC = GR.getRegClass(SpvType);
328   Register Reg = MRI.createGenericVirtualRegister(GR.getRegType(SpvType));
329   MRI.setRegClass(Reg, RC);
330   unsigned GetIdOp = SPIRV::GET_ID;
331   if (RC == &SPIRV::fIDRegClass)
332     GetIdOp = SPIRV::GET_fID;
333   else if (RC == &SPIRV::pIDRegClass)
334     GetIdOp = SPIRV::GET_pID;
335   else if (RC == &SPIRV::vfIDRegClass)
336     GetIdOp = SPIRV::GET_vfID;
337   else if (RC == &SPIRV::vpIDRegClass)
338     GetIdOp = SPIRV::GET_vpID;
339   else if (RC == &SPIRV::vIDRegClass)
340     GetIdOp = SPIRV::GET_vID;
341   return {Reg, GetIdOp};
342 }
343 
344 // Insert ASSIGN_TYPE instuction between Reg and its definition, set NewReg as
345 // a dst of the definition, assign SPIRVType to both registers. If SpvType is
346 // provided, use it as SPIRVType in ASSIGN_TYPE, otherwise create it from Ty.
347 // It's used also in SPIRVBuiltins.cpp.
348 // TODO: maybe move to SPIRVUtils.
349 namespace llvm {
350 Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpvType,
351                            SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB,
352                            MachineRegisterInfo &MRI) {
353   MachineInstr *Def = MRI.getVRegDef(Reg);
354   assert((Ty || SpvType) && "Either LLVM or SPIRV type is expected.");
355   MIB.setInsertPt(*Def->getParent(),
356                   (Def->getNextNode() ? Def->getNextNode()->getIterator()
357                                       : Def->getParent()->end()));
358   SpvType = SpvType ? SpvType : GR->getOrCreateSPIRVType(Ty, MIB);
359   Register NewReg = MRI.createGenericVirtualRegister(MRI.getType(Reg));
360   if (auto *RC = MRI.getRegClassOrNull(Reg)) {
361     MRI.setRegClass(NewReg, RC);
362   } else {
363     auto RegClass = GR->getRegClass(SpvType);
364     MRI.setRegClass(NewReg, RegClass);
365     MRI.setRegClass(Reg, RegClass);
366   }
367   GR->assignSPIRVTypeToVReg(SpvType, Reg, MIB.getMF());
368   // This is to make it convenient for Legalizer to get the SPIRVType
369   // when processing the actual MI (i.e. not pseudo one).
370   GR->assignSPIRVTypeToVReg(SpvType, NewReg, MIB.getMF());
371   // Copy MIFlags from Def to ASSIGN_TYPE instruction. It's required to keep
372   // the flags after instruction selection.
373   const uint32_t Flags = Def->getFlags();
374   MIB.buildInstr(SPIRV::ASSIGN_TYPE)
375       .addDef(Reg)
376       .addUse(NewReg)
377       .addUse(GR->getSPIRVTypeID(SpvType))
378       .setMIFlags(Flags);
379   Def->getOperand(0).setReg(NewReg);
380   return NewReg;
381 }
382 
383 void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
384                   MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR) {
385   assert(MI.getNumDefs() > 0 && MRI.hasOneUse(MI.getOperand(0).getReg()));
386   MachineInstr &AssignTypeInst =
387       *(MRI.use_instr_begin(MI.getOperand(0).getReg()));
388   auto NewReg =
389       createNewIdReg(nullptr, MI.getOperand(0).getReg(), MRI, *GR).first;
390   AssignTypeInst.getOperand(1).setReg(NewReg);
391   MI.getOperand(0).setReg(NewReg);
392   MIB.setInsertPt(*MI.getParent(), MI.getIterator());
393   for (auto &Op : MI.operands()) {
394     if (!Op.isReg() || Op.isDef())
395       continue;
396     auto IdOpInfo = createNewIdReg(nullptr, Op.getReg(), MRI, *GR);
397     MIB.buildInstr(IdOpInfo.second).addDef(IdOpInfo.first).addUse(Op.getReg());
398     Op.setReg(IdOpInfo.first);
399   }
400 }
401 } // namespace llvm
402 
403 static void
404 generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
405                      MachineIRBuilder MIB,
406                      DenseMap<MachineInstr *, Type *> &TargetExtConstTypes) {
407   // Get access to information about available extensions
408   const SPIRVSubtarget *ST =
409       static_cast<const SPIRVSubtarget *>(&MIB.getMF().getSubtarget());
410 
411   MachineRegisterInfo &MRI = MF.getRegInfo();
412   SmallVector<MachineInstr *, 10> ToErase;
413   DenseMap<MachineInstr *, Register> RegsAlreadyAddedToDT;
414 
415   bool IsExtendedInts =
416       ST->canUseExtension(
417           SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) ||
418       ST->canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions);
419 
420   for (MachineBasicBlock *MBB : post_order(&MF)) {
421     if (MBB->empty())
422       continue;
423 
424     bool ReachedBegin = false;
425     for (auto MII = std::prev(MBB->end()), Begin = MBB->begin();
426          !ReachedBegin;) {
427       MachineInstr &MI = *MII;
428       unsigned MIOp = MI.getOpcode();
429 
430       if (!IsExtendedInts) {
431         // validate bit width of scalar registers
432         for (const auto &MOP : MI.operands())
433           if (MOP.isReg())
434             widenScalarLLTNextPow2(MOP.getReg(), MRI);
435       }
436 
437       if (isSpvIntrinsic(MI, Intrinsic::spv_assign_ptr_type)) {
438         Register Reg = MI.getOperand(1).getReg();
439         MIB.setInsertPt(*MI.getParent(), MI.getIterator());
440         Type *ElementTy = getMDOperandAsType(MI.getOperand(2).getMetadata(), 0);
441         SPIRVType *BaseTy = GR->getOrCreateSPIRVType(ElementTy, MIB);
442         SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
443             BaseTy, MI, *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(),
444             addressSpaceToStorageClass(MI.getOperand(3).getImm(), *ST));
445         MachineInstr *Def = MRI.getVRegDef(Reg);
446         assert(Def && "Expecting an instruction that defines the register");
447         // G_GLOBAL_VALUE already has type info.
448         if (Def->getOpcode() != TargetOpcode::G_GLOBAL_VALUE &&
449             Def->getOpcode() != SPIRV::ASSIGN_TYPE)
450           insertAssignInstr(Reg, nullptr, AssignedPtrType, GR, MIB,
451                             MF.getRegInfo());
452         ToErase.push_back(&MI);
453       } else if (isSpvIntrinsic(MI, Intrinsic::spv_assign_type)) {
454         Register Reg = MI.getOperand(1).getReg();
455         Type *Ty = getMDOperandAsType(MI.getOperand(2).getMetadata(), 0);
456         MachineInstr *Def = MRI.getVRegDef(Reg);
457         assert(Def && "Expecting an instruction that defines the register");
458         // G_GLOBAL_VALUE already has type info.
459         if (Def->getOpcode() != TargetOpcode::G_GLOBAL_VALUE &&
460             Def->getOpcode() != SPIRV::ASSIGN_TYPE)
461           insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MF.getRegInfo());
462         ToErase.push_back(&MI);
463       } else if (MIOp == TargetOpcode::G_CONSTANT ||
464                  MIOp == TargetOpcode::G_FCONSTANT ||
465                  MIOp == TargetOpcode::G_BUILD_VECTOR) {
466         // %rc = G_CONSTANT ty Val
467         // ===>
468         // %cty = OpType* ty
469         // %rctmp = G_CONSTANT ty Val
470         // %rc = ASSIGN_TYPE %rctmp, %cty
471         Register Reg = MI.getOperand(0).getReg();
472         bool NeedAssignType = true;
473         if (MRI.hasOneUse(Reg)) {
474           MachineInstr &UseMI = *MRI.use_instr_begin(Reg);
475           if (isSpvIntrinsic(UseMI, Intrinsic::spv_assign_type) ||
476               isSpvIntrinsic(UseMI, Intrinsic::spv_assign_name))
477             continue;
478           if (UseMI.getOpcode() == SPIRV::ASSIGN_TYPE)
479             NeedAssignType = false;
480         }
481         Type *Ty = nullptr;
482         if (MIOp == TargetOpcode::G_CONSTANT) {
483           auto TargetExtIt = TargetExtConstTypes.find(&MI);
484           Ty = TargetExtIt == TargetExtConstTypes.end()
485                    ? MI.getOperand(1).getCImm()->getType()
486                    : TargetExtIt->second;
487           const ConstantInt *OpCI = MI.getOperand(1).getCImm();
488           Register PrimaryReg = GR->find(OpCI, &MF);
489           if (!PrimaryReg.isValid()) {
490             GR->add(OpCI, &MF, Reg);
491           } else if (PrimaryReg != Reg &&
492                      MRI.getType(Reg) == MRI.getType(PrimaryReg)) {
493             auto *RCReg = MRI.getRegClassOrNull(Reg);
494             auto *RCPrimary = MRI.getRegClassOrNull(PrimaryReg);
495             if (!RCReg || RCPrimary == RCReg) {
496               RegsAlreadyAddedToDT[&MI] = PrimaryReg;
497               ToErase.push_back(&MI);
498               NeedAssignType = false;
499             }
500           }
501         } else if (MIOp == TargetOpcode::G_FCONSTANT) {
502           Ty = MI.getOperand(1).getFPImm()->getType();
503         } else {
504           assert(MIOp == TargetOpcode::G_BUILD_VECTOR);
505           Type *ElemTy = nullptr;
506           MachineInstr *ElemMI = MRI.getVRegDef(MI.getOperand(1).getReg());
507           assert(ElemMI);
508 
509           if (ElemMI->getOpcode() == TargetOpcode::G_CONSTANT) {
510             ElemTy = ElemMI->getOperand(1).getCImm()->getType();
511           } else if (ElemMI->getOpcode() == TargetOpcode::G_FCONSTANT) {
512             ElemTy = ElemMI->getOperand(1).getFPImm()->getType();
513           } else {
514             // There may be a case when we already know Reg's type.
515             MachineInstr *NextMI = MI.getNextNode();
516             if (!NextMI || NextMI->getOpcode() != SPIRV::ASSIGN_TYPE ||
517                 NextMI->getOperand(1).getReg() != Reg)
518               llvm_unreachable("Unexpected opcode");
519           }
520           if (ElemTy)
521             Ty = VectorType::get(
522                 ElemTy, MI.getNumExplicitOperands() - MI.getNumExplicitDefs(),
523                 false);
524           else
525             NeedAssignType = false;
526         }
527         if (NeedAssignType)
528           insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MRI);
529       } else if (MIOp == TargetOpcode::G_GLOBAL_VALUE) {
530         propagateSPIRVType(&MI, GR, MRI, MIB);
531       }
532 
533       if (MII == Begin)
534         ReachedBegin = true;
535       else
536         --MII;
537     }
538   }
539   for (MachineInstr *MI : ToErase) {
540     auto It = RegsAlreadyAddedToDT.find(MI);
541     if (RegsAlreadyAddedToDT.contains(MI))
542       MRI.replaceRegWith(MI->getOperand(0).getReg(), It->second);
543     MI->eraseFromParent();
544   }
545 
546   // Address the case when IRTranslator introduces instructions with new
547   // registers without SPIRVType associated.
548   for (MachineBasicBlock &MBB : MF) {
549     for (MachineInstr &MI : MBB) {
550       switch (MI.getOpcode()) {
551       case TargetOpcode::G_TRUNC:
552       case TargetOpcode::G_ANYEXT:
553       case TargetOpcode::G_SEXT:
554       case TargetOpcode::G_ZEXT:
555       case TargetOpcode::G_PTRTOINT:
556       case TargetOpcode::COPY:
557       case TargetOpcode::G_ADDRSPACE_CAST:
558         propagateSPIRVType(&MI, GR, MRI, MIB);
559         break;
560       }
561     }
562   }
563 }
564 
565 // Defined in SPIRVLegalizerInfo.cpp.
566 extern bool isTypeFoldingSupported(unsigned Opcode);
567 
568 static void processInstrsWithTypeFolding(MachineFunction &MF,
569                                          SPIRVGlobalRegistry *GR,
570                                          MachineIRBuilder MIB) {
571   MachineRegisterInfo &MRI = MF.getRegInfo();
572   for (MachineBasicBlock &MBB : MF) {
573     for (MachineInstr &MI : MBB) {
574       if (isTypeFoldingSupported(MI.getOpcode()))
575         processInstr(MI, MIB, MRI, GR);
576     }
577   }
578 
579   for (MachineBasicBlock &MBB : MF) {
580     for (MachineInstr &MI : MBB) {
581       // We need to rewrite dst types for ASSIGN_TYPE instrs to be able
582       // to perform tblgen'erated selection and we can't do that on Legalizer
583       // as it operates on gMIR only.
584       if (MI.getOpcode() != SPIRV::ASSIGN_TYPE)
585         continue;
586       Register SrcReg = MI.getOperand(1).getReg();
587       unsigned Opcode = MRI.getVRegDef(SrcReg)->getOpcode();
588       if (!isTypeFoldingSupported(Opcode))
589         continue;
590       Register DstReg = MI.getOperand(0).getReg();
591       // Don't need to reset type of register holding constant and used in
592       // G_ADDRSPACE_CAST, since it breaks legalizer.
593       if (Opcode == TargetOpcode::G_CONSTANT && MRI.hasOneUse(DstReg)) {
594         MachineInstr &UseMI = *MRI.use_instr_begin(DstReg);
595         if (UseMI.getOpcode() == TargetOpcode::G_ADDRSPACE_CAST)
596           continue;
597       }
598     }
599   }
600 }
601 
602 static Register
603 collectInlineAsmInstrOperands(MachineInstr *MI,
604                               SmallVector<unsigned, 4> *Ops = nullptr) {
605   Register DefReg;
606   unsigned StartOp = InlineAsm::MIOp_FirstOperand,
607            AsmDescOp = InlineAsm::MIOp_FirstOperand;
608   for (unsigned Idx = StartOp, MISz = MI->getNumOperands(); Idx != MISz;
609        ++Idx) {
610     const MachineOperand &MO = MI->getOperand(Idx);
611     if (MO.isMetadata())
612       continue;
613     if (Idx == AsmDescOp && MO.isImm()) {
614       // compute the index of the next operand descriptor
615       const InlineAsm::Flag F(MO.getImm());
616       AsmDescOp += 1 + F.getNumOperandRegisters();
617       continue;
618     }
619     if (MO.isReg() && MO.isDef()) {
620       if (!Ops)
621         return MO.getReg();
622       else
623         DefReg = MO.getReg();
624     } else if (Ops) {
625       Ops->push_back(Idx);
626     }
627   }
628   return DefReg;
629 }
630 
631 static void
632 insertInlineAsmProcess(MachineFunction &MF, SPIRVGlobalRegistry *GR,
633                        const SPIRVSubtarget &ST, MachineIRBuilder MIRBuilder,
634                        const SmallVector<MachineInstr *> &ToProcess) {
635   MachineRegisterInfo &MRI = MF.getRegInfo();
636   Register AsmTargetReg;
637   for (unsigned i = 0, Sz = ToProcess.size(); i + 1 < Sz; i += 2) {
638     MachineInstr *I1 = ToProcess[i], *I2 = ToProcess[i + 1];
639     assert(isSpvIntrinsic(*I1, Intrinsic::spv_inline_asm) && I2->isInlineAsm());
640     MIRBuilder.setInsertPt(*I2->getParent(), *I2);
641 
642     if (!AsmTargetReg.isValid()) {
643       // define vendor specific assembly target or dialect
644       AsmTargetReg = MRI.createGenericVirtualRegister(LLT::scalar(32));
645       MRI.setRegClass(AsmTargetReg, &SPIRV::iIDRegClass);
646       auto AsmTargetMIB =
647           MIRBuilder.buildInstr(SPIRV::OpAsmTargetINTEL).addDef(AsmTargetReg);
648       addStringImm(ST.getTargetTripleAsStr(), AsmTargetMIB);
649       GR->add(AsmTargetMIB.getInstr(), &MF, AsmTargetReg);
650     }
651 
652     // create types
653     const MDNode *IAMD = I1->getOperand(1).getMetadata();
654     FunctionType *FTy = cast<FunctionType>(getMDOperandAsType(IAMD, 0));
655     SmallVector<SPIRVType *, 4> ArgTypes;
656     for (const auto &ArgTy : FTy->params())
657       ArgTypes.push_back(GR->getOrCreateSPIRVType(ArgTy, MIRBuilder));
658     SPIRVType *RetType =
659         GR->getOrCreateSPIRVType(FTy->getReturnType(), MIRBuilder);
660     SPIRVType *FuncType = GR->getOrCreateOpTypeFunctionWithArgs(
661         FTy, RetType, ArgTypes, MIRBuilder);
662 
663     // define vendor specific assembly instructions string
664     Register AsmReg = MRI.createGenericVirtualRegister(LLT::scalar(32));
665     MRI.setRegClass(AsmReg, &SPIRV::iIDRegClass);
666     auto AsmMIB = MIRBuilder.buildInstr(SPIRV::OpAsmINTEL)
667                       .addDef(AsmReg)
668                       .addUse(GR->getSPIRVTypeID(RetType))
669                       .addUse(GR->getSPIRVTypeID(FuncType))
670                       .addUse(AsmTargetReg);
671     // inline asm string:
672     addStringImm(I2->getOperand(InlineAsm::MIOp_AsmString).getSymbolName(),
673                  AsmMIB);
674     // inline asm constraint string:
675     addStringImm(cast<MDString>(I1->getOperand(2).getMetadata()->getOperand(0))
676                      ->getString(),
677                  AsmMIB);
678     GR->add(AsmMIB.getInstr(), &MF, AsmReg);
679 
680     // calls the inline assembly instruction
681     unsigned ExtraInfo = I2->getOperand(InlineAsm::MIOp_ExtraInfo).getImm();
682     if (ExtraInfo & InlineAsm::Extra_HasSideEffects)
683       MIRBuilder.buildInstr(SPIRV::OpDecorate)
684           .addUse(AsmReg)
685           .addImm(static_cast<uint32_t>(SPIRV::Decoration::SideEffectsINTEL));
686 
687     Register DefReg = collectInlineAsmInstrOperands(I2);
688     if (!DefReg.isValid()) {
689       DefReg = MRI.createGenericVirtualRegister(LLT::scalar(32));
690       MRI.setRegClass(DefReg, &SPIRV::iIDRegClass);
691       SPIRVType *VoidType = GR->getOrCreateSPIRVType(
692           Type::getVoidTy(MF.getFunction().getContext()), MIRBuilder);
693       GR->assignSPIRVTypeToVReg(VoidType, DefReg, MF);
694     }
695 
696     auto AsmCall = MIRBuilder.buildInstr(SPIRV::OpAsmCallINTEL)
697                        .addDef(DefReg)
698                        .addUse(GR->getSPIRVTypeID(RetType))
699                        .addUse(AsmReg);
700     for (unsigned IntrIdx = 3; IntrIdx < I1->getNumOperands(); ++IntrIdx)
701       AsmCall.addUse(I1->getOperand(IntrIdx).getReg());
702   }
703   for (MachineInstr *MI : ToProcess)
704     MI->eraseFromParent();
705 }
706 
707 static void insertInlineAsm(MachineFunction &MF, SPIRVGlobalRegistry *GR,
708                             const SPIRVSubtarget &ST,
709                             MachineIRBuilder MIRBuilder) {
710   SmallVector<MachineInstr *> ToProcess;
711   for (MachineBasicBlock &MBB : MF) {
712     for (MachineInstr &MI : MBB) {
713       if (isSpvIntrinsic(MI, Intrinsic::spv_inline_asm) ||
714           MI.getOpcode() == TargetOpcode::INLINEASM)
715         ToProcess.push_back(&MI);
716     }
717   }
718   if (ToProcess.size() == 0)
719     return;
720 
721   if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_inline_assembly))
722     report_fatal_error("Inline assembly instructions require the "
723                        "following SPIR-V extension: SPV_INTEL_inline_assembly",
724                        false);
725 
726   insertInlineAsmProcess(MF, GR, ST, MIRBuilder, ToProcess);
727 }
728 
729 static void insertSpirvDecorations(MachineFunction &MF, MachineIRBuilder MIB) {
730   SmallVector<MachineInstr *, 10> ToErase;
731   for (MachineBasicBlock &MBB : MF) {
732     for (MachineInstr &MI : MBB) {
733       if (!isSpvIntrinsic(MI, Intrinsic::spv_assign_decoration))
734         continue;
735       MIB.setInsertPt(*MI.getParent(), MI);
736       buildOpSpirvDecorations(MI.getOperand(1).getReg(), MIB,
737                               MI.getOperand(2).getMetadata());
738       ToErase.push_back(&MI);
739     }
740   }
741   for (MachineInstr *MI : ToErase)
742     MI->eraseFromParent();
743 }
744 
745 // LLVM allows the switches to use registers as cases, while SPIR-V required
746 // those to be immediate values. This function replaces such operands with the
747 // equivalent immediate constant.
748 static void processSwitchesConstants(MachineFunction &MF,
749                                      SPIRVGlobalRegistry *GR,
750                                      MachineIRBuilder MIB) {
751   MachineRegisterInfo &MRI = MF.getRegInfo();
752   for (MachineBasicBlock &MBB : MF) {
753     for (MachineInstr &MI : MBB) {
754       if (!isSpvIntrinsic(MI, Intrinsic::spv_switch))
755         continue;
756 
757       SmallVector<MachineOperand, 8> NewOperands;
758       NewOperands.push_back(MI.getOperand(0)); // Opcode
759       NewOperands.push_back(MI.getOperand(1)); // Condition
760       NewOperands.push_back(MI.getOperand(2)); // Default
761       for (unsigned i = 3; i < MI.getNumOperands(); i += 2) {
762         Register Reg = MI.getOperand(i).getReg();
763         MachineInstr *ConstInstr = getDefInstrMaybeConstant(Reg, &MRI);
764         NewOperands.push_back(
765             MachineOperand::CreateCImm(ConstInstr->getOperand(1).getCImm()));
766 
767         NewOperands.push_back(MI.getOperand(i + 1));
768       }
769 
770       assert(MI.getNumOperands() == NewOperands.size());
771       while (MI.getNumOperands() > 0)
772         MI.removeOperand(0);
773       for (auto &MO : NewOperands)
774         MI.addOperand(MO);
775     }
776   }
777 }
778 
779 // Some instructions are used during CodeGen but should never be emitted.
780 // Cleaning up those.
781 static void cleanupHelperInstructions(MachineFunction &MF) {
782   SmallVector<MachineInstr *, 8> ToEraseMI;
783   for (MachineBasicBlock &MBB : MF) {
784     for (MachineInstr &MI : MBB) {
785       if (isSpvIntrinsic(MI, Intrinsic::spv_track_constant) ||
786           MI.getOpcode() == TargetOpcode::G_BRINDIRECT)
787         ToEraseMI.push_back(&MI);
788     }
789   }
790 
791   for (MachineInstr *MI : ToEraseMI)
792     MI->eraseFromParent();
793 }
794 
795 // Find all usages of G_BLOCK_ADDR in our intrinsics and replace those
796 // operands/registers by the actual MBB it references.
797 static void processBlockAddr(MachineFunction &MF, SPIRVGlobalRegistry *GR,
798                              MachineIRBuilder MIB) {
799   // Gather the reverse-mapping BB -> MBB.
800   DenseMap<const BasicBlock *, MachineBasicBlock *> BB2MBB;
801   for (MachineBasicBlock &MBB : MF)
802     BB2MBB[MBB.getBasicBlock()] = &MBB;
803 
804   // Gather instructions requiring patching. For now, only those can use
805   // G_BLOCK_ADDR.
806   SmallVector<MachineInstr *, 8> InstructionsToPatch;
807   for (MachineBasicBlock &MBB : MF) {
808     for (MachineInstr &MI : MBB) {
809       if (isSpvIntrinsic(MI, Intrinsic::spv_switch) ||
810           isSpvIntrinsic(MI, Intrinsic::spv_loop_merge) ||
811           isSpvIntrinsic(MI, Intrinsic::spv_selection_merge))
812         InstructionsToPatch.push_back(&MI);
813     }
814   }
815 
816   // For each instruction to fix, we replace all the G_BLOCK_ADDR operands by
817   // the actual MBB it references. Once those references have been updated, we
818   // can cleanup remaining G_BLOCK_ADDR references.
819   SmallPtrSet<MachineBasicBlock *, 8> ClearAddressTaken;
820   SmallPtrSet<MachineInstr *, 8> ToEraseMI;
821   MachineRegisterInfo &MRI = MF.getRegInfo();
822   for (MachineInstr *MI : InstructionsToPatch) {
823     SmallVector<MachineOperand, 8> NewOps;
824     for (unsigned i = 0; i < MI->getNumOperands(); ++i) {
825       // The operand is not a register, keep as-is.
826       if (!MI->getOperand(i).isReg()) {
827         NewOps.push_back(MI->getOperand(i));
828         continue;
829       }
830 
831       Register Reg = MI->getOperand(i).getReg();
832       MachineInstr *BuildMBB = MRI.getVRegDef(Reg);
833       // The register is not the result of G_BLOCK_ADDR, keep as-is.
834       if (!BuildMBB || BuildMBB->getOpcode() != TargetOpcode::G_BLOCK_ADDR) {
835         NewOps.push_back(MI->getOperand(i));
836         continue;
837       }
838 
839       assert(BuildMBB && BuildMBB->getOpcode() == TargetOpcode::G_BLOCK_ADDR &&
840              BuildMBB->getOperand(1).isBlockAddress() &&
841              BuildMBB->getOperand(1).getBlockAddress());
842       BasicBlock *BB =
843           BuildMBB->getOperand(1).getBlockAddress()->getBasicBlock();
844       auto It = BB2MBB.find(BB);
845       if (It == BB2MBB.end())
846         report_fatal_error("cannot find a machine basic block by a basic block "
847                            "in a switch statement");
848       MachineBasicBlock *ReferencedBlock = It->second;
849       NewOps.push_back(MachineOperand::CreateMBB(ReferencedBlock));
850 
851       ClearAddressTaken.insert(ReferencedBlock);
852       ToEraseMI.insert(BuildMBB);
853     }
854 
855     // Replace the operands.
856     assert(MI->getNumOperands() == NewOps.size());
857     while (MI->getNumOperands() > 0)
858       MI->removeOperand(0);
859     for (auto &MO : NewOps)
860       MI->addOperand(MO);
861 
862     if (MachineInstr *Next = MI->getNextNode()) {
863       if (isSpvIntrinsic(*Next, Intrinsic::spv_track_constant)) {
864         ToEraseMI.insert(Next);
865         Next = MI->getNextNode();
866       }
867       if (Next && Next->getOpcode() == TargetOpcode::G_BRINDIRECT)
868         ToEraseMI.insert(Next);
869     }
870   }
871 
872   // BlockAddress operands were used to keep information between passes,
873   // let's undo the "address taken" status to reflect that Succ doesn't
874   // actually correspond to an IR-level basic block.
875   for (MachineBasicBlock *Succ : ClearAddressTaken)
876     Succ->setAddressTakenIRBlock(nullptr);
877 
878   // If we just delete G_BLOCK_ADDR instructions with BlockAddress operands,
879   // this leaves their BasicBlock counterparts in a "address taken" status. This
880   // would make AsmPrinter to generate a series of unneeded labels of a "Address
881   // of block that was removed by CodeGen" kind. Let's first ensure that we
882   // don't have a dangling BlockAddress constants by zapping the BlockAddress
883   // nodes, and only after that proceed with erasing G_BLOCK_ADDR instructions.
884   Constant *Replacement =
885       ConstantInt::get(Type::getInt32Ty(MF.getFunction().getContext()), 1);
886   for (MachineInstr *BlockAddrI : ToEraseMI) {
887     if (BlockAddrI->getOpcode() == TargetOpcode::G_BLOCK_ADDR) {
888       BlockAddress *BA = const_cast<BlockAddress *>(
889           BlockAddrI->getOperand(1).getBlockAddress());
890       BA->replaceAllUsesWith(
891           ConstantExpr::getIntToPtr(Replacement, BA->getType()));
892       BA->destroyConstant();
893     }
894     BlockAddrI->eraseFromParent();
895   }
896 }
897 
898 static bool isImplicitFallthrough(MachineBasicBlock &MBB) {
899   if (MBB.empty())
900     return true;
901 
902   // Branching SPIR-V intrinsics are not detected by this generic method.
903   // Thus, we can only trust negative result.
904   if (!MBB.canFallThrough())
905     return false;
906 
907   // Otherwise, we must manually check if we have a SPIR-V intrinsic which
908   // prevent an implicit fallthrough.
909   for (MachineBasicBlock::reverse_iterator It = MBB.rbegin(), E = MBB.rend();
910        It != E; ++It) {
911     if (isSpvIntrinsic(*It, Intrinsic::spv_switch))
912       return false;
913   }
914   return true;
915 }
916 
917 static void removeImplicitFallthroughs(MachineFunction &MF,
918                                        MachineIRBuilder MIB) {
919   // It is valid for MachineBasicBlocks to not finish with a branch instruction.
920   // In such cases, they will simply fallthrough their immediate successor.
921   for (MachineBasicBlock &MBB : MF) {
922     if (!isImplicitFallthrough(MBB))
923       continue;
924 
925     assert(std::distance(MBB.successors().begin(), MBB.successors().end()) ==
926            1);
927     MIB.setInsertPt(MBB, MBB.end());
928     MIB.buildBr(**MBB.successors().begin());
929   }
930 }
931 
932 bool SPIRVPreLegalizer::runOnMachineFunction(MachineFunction &MF) {
933   // Initialize the type registry.
934   const SPIRVSubtarget &ST = MF.getSubtarget<SPIRVSubtarget>();
935   SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
936   GR->setCurrentFunc(MF);
937   MachineIRBuilder MIB(MF);
938   // a registry of target extension constants
939   DenseMap<MachineInstr *, Type *> TargetExtConstTypes;
940   // to keep record of tracked constants
941   SmallSet<Register, 4> TrackedConstRegs;
942   addConstantsToTrack(MF, GR, ST, TargetExtConstTypes, TrackedConstRegs);
943   foldConstantsIntoIntrinsics(MF, TrackedConstRegs);
944   insertBitcasts(MF, GR, MIB);
945   generateAssignInstrs(MF, GR, MIB, TargetExtConstTypes);
946 
947   processSwitchesConstants(MF, GR, MIB);
948   processBlockAddr(MF, GR, MIB);
949   cleanupHelperInstructions(MF);
950 
951   processInstrsWithTypeFolding(MF, GR, MIB);
952   removeImplicitFallthroughs(MF, MIB);
953   insertSpirvDecorations(MF, MIB);
954   insertInlineAsm(MF, GR, ST, MIB);
955 
956   return true;
957 }
958 
959 INITIALIZE_PASS(SPIRVPreLegalizer, DEBUG_TYPE, "SPIRV pre legalizer", false,
960                 false)
961 
962 char SPIRVPreLegalizer::ID = 0;
963 
964 FunctionPass *llvm::createSPIRVPreLegalizerPass() {
965   return new SPIRVPreLegalizer();
966 }
967