xref: /llvm-project/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (revision 93cda6d6a75e98d5516fbf12ce984604be834f01)
1 //===-- SPIRVPreLegalizer.cpp - prepare IR for legalization -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// The pass prepares IR for legalization: it assigns SPIR-V types to registers
// and removes intrinsics which held these types during IR translation.
// It also processes constants and registers them in GR to avoid duplication.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "SPIRV.h"
16 #include "SPIRVSubtarget.h"
17 #include "SPIRVUtils.h"
18 #include "llvm/ADT/PostOrderIterator.h"
19 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
20 #include "llvm/IR/Attributes.h"
21 #include "llvm/IR/Constants.h"
22 #include "llvm/IR/DebugInfoMetadata.h"
23 #include "llvm/IR/IntrinsicsSPIRV.h"
24 #include "llvm/Target/TargetIntrinsicInfo.h"
25 
26 #define DEBUG_TYPE "spirv-prelegalizer"
27 
28 using namespace llvm;
29 
30 namespace {
31 class SPIRVPreLegalizer : public MachineFunctionPass {
32 public:
33   static char ID;
34   SPIRVPreLegalizer() : MachineFunctionPass(ID) {
35     initializeSPIRVPreLegalizerPass(*PassRegistry::getPassRegistry());
36   }
37   bool runOnMachineFunction(MachineFunction &MF) override;
38 };
39 } // namespace
40 
41 static void
42 addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR,
43                     const SPIRVSubtarget &STI,
44                     DenseMap<MachineInstr *, Type *> &TargetExtConstTypes,
45                     SmallSet<Register, 4> &TrackedConstRegs) {
46   MachineRegisterInfo &MRI = MF.getRegInfo();
47   DenseMap<MachineInstr *, Register> RegsAlreadyAddedToDT;
48   SmallVector<MachineInstr *, 10> ToErase, ToEraseComposites;
49   for (MachineBasicBlock &MBB : MF) {
50     for (MachineInstr &MI : MBB) {
51       if (!isSpvIntrinsic(MI, Intrinsic::spv_track_constant))
52         continue;
53       ToErase.push_back(&MI);
54       Register SrcReg = MI.getOperand(2).getReg();
55       auto *Const =
56           cast<Constant>(cast<ConstantAsMetadata>(
57                              MI.getOperand(3).getMetadata()->getOperand(0))
58                              ->getValue());
59       if (auto *GV = dyn_cast<GlobalValue>(Const)) {
60         Register Reg = GR->find(GV, &MF);
61         if (!Reg.isValid())
62           GR->add(GV, &MF, SrcReg);
63         else
64           RegsAlreadyAddedToDT[&MI] = Reg;
65       } else {
66         Register Reg = GR->find(Const, &MF);
67         if (!Reg.isValid()) {
68           if (auto *ConstVec = dyn_cast<ConstantDataVector>(Const)) {
69             auto *BuildVec = MRI.getVRegDef(SrcReg);
70             assert(BuildVec &&
71                    BuildVec->getOpcode() == TargetOpcode::G_BUILD_VECTOR);
72             for (unsigned i = 0; i < ConstVec->getNumElements(); ++i) {
73               // Ensure that OpConstantComposite reuses a constant when it's
74               // already created and available in the same machine function.
75               Constant *ElemConst = ConstVec->getElementAsConstant(i);
76               Register ElemReg = GR->find(ElemConst, &MF);
77               if (!ElemReg.isValid())
78                 GR->add(ElemConst, &MF, BuildVec->getOperand(1 + i).getReg());
79               else
80                 BuildVec->getOperand(1 + i).setReg(ElemReg);
81             }
82           }
83           GR->add(Const, &MF, SrcReg);
84           TrackedConstRegs.insert(SrcReg);
85           if (Const->getType()->isTargetExtTy()) {
86             // remember association so that we can restore it when assign types
87             MachineInstr *SrcMI = MRI.getVRegDef(SrcReg);
88             if (SrcMI && (SrcMI->getOpcode() == TargetOpcode::G_CONSTANT ||
89                           SrcMI->getOpcode() == TargetOpcode::G_IMPLICIT_DEF))
90               TargetExtConstTypes[SrcMI] = Const->getType();
91             if (Const->isNullValue()) {
92               MachineIRBuilder MIB(MF);
93               SPIRVType *ExtType =
94                   GR->getOrCreateSPIRVType(Const->getType(), MIB);
95               SrcMI->setDesc(STI.getInstrInfo()->get(SPIRV::OpConstantNull));
96               SrcMI->addOperand(MachineOperand::CreateReg(
97                   GR->getSPIRVTypeID(ExtType), false));
98             }
99           }
100         } else {
101           RegsAlreadyAddedToDT[&MI] = Reg;
102           // This MI is unused and will be removed. If the MI uses
103           // const_composite, it will be unused and should be removed too.
104           assert(MI.getOperand(2).isReg() && "Reg operand is expected");
105           MachineInstr *SrcMI = MRI.getVRegDef(MI.getOperand(2).getReg());
106           if (SrcMI && isSpvIntrinsic(*SrcMI, Intrinsic::spv_const_composite))
107             ToEraseComposites.push_back(SrcMI);
108         }
109       }
110     }
111   }
112   for (MachineInstr *MI : ToErase) {
113     Register Reg = MI->getOperand(2).getReg();
114     if (RegsAlreadyAddedToDT.contains(MI))
115       Reg = RegsAlreadyAddedToDT[MI];
116     auto *RC = MRI.getRegClassOrNull(MI->getOperand(0).getReg());
117     if (!MRI.getRegClassOrNull(Reg) && RC)
118       MRI.setRegClass(Reg, RC);
119     MRI.replaceRegWith(MI->getOperand(0).getReg(), Reg);
120     MI->eraseFromParent();
121   }
122   for (MachineInstr *MI : ToEraseComposites)
123     MI->eraseFromParent();
124 }
125 
126 static void
127 foldConstantsIntoIntrinsics(MachineFunction &MF,
128                             const SmallSet<Register, 4> &TrackedConstRegs) {
129   SmallVector<MachineInstr *, 10> ToErase;
130   MachineRegisterInfo &MRI = MF.getRegInfo();
131   const unsigned AssignNameOperandShift = 2;
132   for (MachineBasicBlock &MBB : MF) {
133     for (MachineInstr &MI : MBB) {
134       if (!isSpvIntrinsic(MI, Intrinsic::spv_assign_name))
135         continue;
136       unsigned NumOp = MI.getNumExplicitDefs() + AssignNameOperandShift;
137       while (MI.getOperand(NumOp).isReg()) {
138         MachineOperand &MOp = MI.getOperand(NumOp);
139         MachineInstr *ConstMI = MRI.getVRegDef(MOp.getReg());
140         assert(ConstMI->getOpcode() == TargetOpcode::G_CONSTANT);
141         MI.removeOperand(NumOp);
142         MI.addOperand(MachineOperand::CreateImm(
143             ConstMI->getOperand(1).getCImm()->getZExtValue()));
144         Register DefReg = ConstMI->getOperand(0).getReg();
145         if (MRI.use_empty(DefReg) && !TrackedConstRegs.contains(DefReg))
146           ToErase.push_back(ConstMI);
147       }
148     }
149   }
150   for (MachineInstr *MI : ToErase)
151     MI->eraseFromParent();
152 }
153 
154 static MachineInstr *findAssignTypeInstr(Register Reg,
155                                          MachineRegisterInfo *MRI) {
156   for (MachineRegisterInfo::use_instr_iterator I = MRI->use_instr_begin(Reg),
157                                                IE = MRI->use_instr_end();
158        I != IE; ++I) {
159     MachineInstr *UseMI = &*I;
160     if ((isSpvIntrinsic(*UseMI, Intrinsic::spv_assign_ptr_type) ||
161          isSpvIntrinsic(*UseMI, Intrinsic::spv_assign_type)) &&
162         UseMI->getOperand(1).getReg() == Reg)
163       return UseMI;
164   }
165   return nullptr;
166 }
167 
168 static void buildOpBitcast(SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB,
169                            Register ResVReg, Register OpReg) {
170   SPIRVType *ResType = GR->getSPIRVTypeForVReg(ResVReg);
171   SPIRVType *OpType = GR->getSPIRVTypeForVReg(OpReg);
172   assert(ResType && OpType && "Operand types are expected");
173   if (!GR->isBitcastCompatible(ResType, OpType))
174     report_fatal_error("incompatible result and operand types in a bitcast");
175   MachineRegisterInfo *MRI = MIB.getMRI();
176   if (!MRI->getRegClassOrNull(ResVReg))
177     MRI->setRegClass(ResVReg, GR->getRegClass(ResType));
178   if (ResType == OpType)
179     MIB.buildInstr(TargetOpcode::COPY).addDef(ResVReg).addUse(OpReg);
180   else
181     MIB.buildInstr(SPIRV::OpBitcast)
182         .addDef(ResVReg)
183         .addUse(GR->getSPIRVTypeID(ResType))
184         .addUse(OpReg);
185 }
186 
// We do instruction selection early here instead of leaving the generic opcode
// G_BITCAST (as generated by MIB.buildBitcast()) in place. When
// MachineVerifier validates G_BITCAST, there is a check of the form: if Source
// Type is equal to Destination Type, report the error "bitcast must change the
// type". This doesn't take into account the notion of a typed pointer, which
// is important for SPIR-V, where a user may and should use bitcast between
// pointers with different pointee types
// (https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpBitcast).
// This is important for correct lowering in SPIR-V, because interpretation of
// the data type is not left to the instructions that use the pointer, but is
// encoded by the pointer declaration, and the SPIRV target can and must handle
// the declaration and use of pointers that specify the type of data they point
// to. It's not feasible to improve validation of G_BITCAST using only the
// information provided by the low-level types of source and destination.
// Therefore we don't produce G_BITCAST as a generic opcode with semantics
// different from OpBitcast, but rather lower it to OpBitcast immediately. For
// now, the only difference is that CombinerHelper can't transform known
// patterns around G_BUILD_VECTOR. See the discussion in
// https://github.com/llvm/llvm-project/pull/110270 for more context.
205 static void selectOpBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
206                              MachineIRBuilder MIB) {
207   SmallVector<MachineInstr *, 16> ToErase;
208   for (MachineBasicBlock &MBB : MF) {
209     for (MachineInstr &MI : MBB) {
210       if (MI.getOpcode() != TargetOpcode::G_BITCAST)
211         continue;
212       MIB.setInsertPt(*MI.getParent(), MI);
213       buildOpBitcast(GR, MIB, MI.getOperand(0).getReg(),
214                      MI.getOperand(1).getReg());
215       ToErase.push_back(&MI);
216     }
217   }
218   for (MachineInstr *MI : ToErase)
219     MI->eraseFromParent();
220 }
221 
222 static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
223                            MachineIRBuilder MIB) {
224   // Get access to information about available extensions
225   const SPIRVSubtarget *ST =
226       static_cast<const SPIRVSubtarget *>(&MIB.getMF().getSubtarget());
227   SmallVector<MachineInstr *, 10> ToErase;
228   for (MachineBasicBlock &MBB : MF) {
229     for (MachineInstr &MI : MBB) {
230       if (!isSpvIntrinsic(MI, Intrinsic::spv_bitcast) &&
231           !isSpvIntrinsic(MI, Intrinsic::spv_ptrcast))
232         continue;
233       assert(MI.getOperand(2).isReg());
234       MIB.setInsertPt(*MI.getParent(), MI);
235       ToErase.push_back(&MI);
236       if (isSpvIntrinsic(MI, Intrinsic::spv_bitcast)) {
237         MIB.buildBitcast(MI.getOperand(0).getReg(), MI.getOperand(2).getReg());
238         continue;
239       }
240       Register Def = MI.getOperand(0).getReg();
241       Register Source = MI.getOperand(2).getReg();
242       Type *ElemTy = getMDOperandAsType(MI.getOperand(3).getMetadata(), 0);
243       SPIRVType *BaseTy = GR->getOrCreateSPIRVType(ElemTy, MIB);
244       SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
245           BaseTy, MI, *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(),
246           addressSpaceToStorageClass(MI.getOperand(4).getImm(), *ST));
247 
248       // If the ptrcast would be redundant, replace all uses with the source
249       // register.
250       MachineRegisterInfo *MRI = MIB.getMRI();
251       if (GR->getSPIRVTypeForVReg(Source) == AssignedPtrType) {
252         // Erase Def's assign type instruction if we are going to replace Def.
253         if (MachineInstr *AssignMI = findAssignTypeInstr(Def, MRI))
254           ToErase.push_back(AssignMI);
255         MRI->replaceRegWith(Def, Source);
256       } else {
257         GR->assignSPIRVTypeToVReg(AssignedPtrType, Def, MF);
258         MIB.buildBitcast(Def, Source);
259       }
260     }
261   }
262   for (MachineInstr *MI : ToErase)
263     MI->eraseFromParent();
264 }
265 
// While translating global values, IRTranslator sometimes generates IR like:
//   %1 = G_GLOBAL_VALUE
//   %2 = COPY %1
//   %3 = G_ADDRSPACE_CAST %2
//
// or
//
//  %1 = G_ZEXT %2
//  G_MEMCPY ... %1 ...
//
// New registers have no SPIRVType and no register class info.
//
// Set SPIRVType for GV, propagate it from GV to other instructions,
// also set register classes.
//
// propagateSPIRVType returns the SPIR-V type of MI's operand 0 register. If
// that register has no type yet, one is derived from MI's opcode:
//  * G_CONSTANT: the type of the ConstantInt immediate;
//  * G_GLOBAL_VALUE: a typed pointer to the global's deduced value type, in
//    the global's address space;
//  * G_ANYEXT/G_SEXT/G_ZEXT: an integer (vector) type with the component
//    count of the source's type and a scalar width of
//    max(result LLT scalar width, source scalar width);
//  * G_PTRTOINT: an integer of the result LLT's scalar width;
//  * G_TRUNC/G_ADDRSPACE_CAST/G_PTR_ADD/COPY: the type of the source operand,
//    obtained by recursing into its definition.
// In that no-type-yet case, a derived pointer type whose storage class does
// not match the register LLT's address space is rebuilt for the LLT's address
// space, the type is recorded in GR, and the register gets a register class
// (from the type, or iIDRegClass when no type could be derived) if it has
// none. Returns nullptr if operand 0 is not a register or no type is known.
static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
                                     MachineRegisterInfo &MRI,
                                     MachineIRBuilder &MIB) {
  SPIRVType *SpvType = nullptr;
  assert(MI && "Machine instr is expected");
  if (MI->getOperand(0).isReg()) {
    Register Reg = MI->getOperand(0).getReg();
    SpvType = GR->getSPIRVTypeForVReg(Reg);
    // Only registers without a SPIR-V type are processed further.
    if (!SpvType) {
      switch (MI->getOpcode()) {
      case TargetOpcode::G_CONSTANT: {
        // Use the LLVM type of the constant immediate.
        MIB.setInsertPt(*MI->getParent(), MI);
        Type *Ty = MI->getOperand(1).getCImm()->getType();
        SpvType = GR->getOrCreateSPIRVType(Ty, MIB);
        break;
      }
      case TargetOpcode::G_GLOBAL_VALUE: {
        // Typed pointer to the deduced value type, in the global's address
        // space.
        MIB.setInsertPt(*MI->getParent(), MI);
        const GlobalValue *Global = MI->getOperand(1).getGlobal();
        Type *ElementTy = toTypedPointer(GR->getDeducedGlobalValueType(Global));
        auto *Ty = TypedPointerType::get(ElementTy,
                                         Global->getType()->getAddressSpace());
        SpvType = GR->getOrCreateSPIRVType(Ty, MIB);
        break;
      }
      case TargetOpcode::G_ANYEXT:
      case TargetOpcode::G_SEXT:
      case TargetOpcode::G_ZEXT: {
        // Derive from the source's type (recursively): keep its component
        // count, and use the wider of the result LLT width and the source
        // width as the scalar width.
        if (MI->getOperand(1).isReg()) {
          if (MachineInstr *DefInstr =
                  MRI.getVRegDef(MI->getOperand(1).getReg())) {
            if (SPIRVType *Def = propagateSPIRVType(DefInstr, GR, MRI, MIB)) {
              unsigned CurrentBW = GR->getScalarOrVectorBitWidth(Def);
              unsigned ExpectedBW =
                  std::max(MRI.getType(Reg).getScalarSizeInBits(), CurrentBW);
              unsigned NumElements = GR->getScalarOrVectorComponentCount(Def);
              SpvType = GR->getOrCreateSPIRVIntegerType(ExpectedBW, MIB);
              if (NumElements > 1)
                SpvType =
                    GR->getOrCreateSPIRVVectorType(SpvType, NumElements, MIB);
            }
          }
        }
        break;
      }
      case TargetOpcode::G_PTRTOINT:
        // Integer of the result LLT's scalar width.
        // NOTE(review): a vector result would get a scalar type here -- confirm
        // vector G_PTRTOINT cannot reach this point.
        SpvType = GR->getOrCreateSPIRVIntegerType(
            MRI.getType(Reg).getScalarSizeInBits(), MIB);
        break;
      case TargetOpcode::G_TRUNC:
      case TargetOpcode::G_ADDRSPACE_CAST:
      case TargetOpcode::G_PTR_ADD:
      case TargetOpcode::COPY: {
        // Reuse the source operand's type as-is (address space is fixed up
        // below).
        // NOTE(review): for G_TRUNC this assigns the source's wider type to
        // the truncated result -- confirm this is intended.
        MachineOperand &Op = MI->getOperand(1);
        MachineInstr *Def = Op.isReg() ? MRI.getVRegDef(Op.getReg()) : nullptr;
        if (Def)
          SpvType = propagateSPIRVType(Def, GR, MRI, MIB);
        break;
      }
      default:
        break;
      }
      if (SpvType) {
        // check if the address space needs correction
        LLT RegType = MRI.getType(Reg);
        if (SpvType->getOpcode() == SPIRV::OpTypePointer &&
            RegType.isPointer() &&
            storageClassToAddressSpace(GR->getPointerStorageClass(SpvType)) !=
                RegType.getAddressSpace()) {
          const SPIRVSubtarget &ST =
              MI->getParent()->getParent()->getSubtarget<SPIRVSubtarget>();
          SpvType = GR->getOrCreateSPIRVPointerType(
              GR->getPointeeType(SpvType), *MI, *ST.getInstrInfo(),
              addressSpaceToStorageClass(RegType.getAddressSpace(), ST));
        }
        GR->assignSPIRVTypeToVReg(SpvType, Reg, MIB.getMF());
      }
      // Fall back to the generic ID class when no type could be derived.
      if (!MRI.getRegClassOrNull(Reg))
        MRI.setRegClass(Reg, SpvType ? GR->getRegClass(SpvType)
                                     : &SPIRV::iIDRegClass);
    }
  }
  return SpvType;
}
364 
365 // To support current approach and limitations wrt. bit width here we widen a
366 // scalar register with a bit width greater than 1 to valid sizes and cap it to
367 // 64 width.
368 static void widenScalarLLTNextPow2(Register Reg, MachineRegisterInfo &MRI) {
369   LLT RegType = MRI.getType(Reg);
370   if (!RegType.isScalar())
371     return;
372   unsigned Sz = RegType.getScalarSizeInBits();
373   if (Sz == 1)
374     return;
375   unsigned NewSz = std::min(std::max(1u << Log2_32_Ceil(Sz), 8u), 64u);
376   if (NewSz != Sz)
377     MRI.setType(Reg, LLT::scalar(NewSz));
378 }
379 
380 static std::pair<Register, unsigned>
381 createNewIdReg(SPIRVType *SpvType, Register SrcReg, MachineRegisterInfo &MRI,
382                const SPIRVGlobalRegistry &GR) {
383   if (!SpvType)
384     SpvType = GR.getSPIRVTypeForVReg(SrcReg);
385   const TargetRegisterClass *RC = GR.getRegClass(SpvType);
386   Register Reg = MRI.createGenericVirtualRegister(GR.getRegType(SpvType));
387   MRI.setRegClass(Reg, RC);
388   unsigned GetIdOp = SPIRV::GET_ID;
389   if (RC == &SPIRV::fIDRegClass)
390     GetIdOp = SPIRV::GET_fID;
391   else if (RC == &SPIRV::pIDRegClass)
392     GetIdOp = SPIRV::GET_pID;
393   else if (RC == &SPIRV::vfIDRegClass)
394     GetIdOp = SPIRV::GET_vfID;
395   else if (RC == &SPIRV::vpIDRegClass)
396     GetIdOp = SPIRV::GET_vpID;
397   else if (RC == &SPIRV::vIDRegClass)
398     GetIdOp = SPIRV::GET_vID;
399   return {Reg, GetIdOp};
400 }
401 
402 static void setInsertPtAfterDef(MachineIRBuilder &MIB, MachineInstr *Def) {
403   MachineBasicBlock &MBB = *Def->getParent();
404   MachineBasicBlock::iterator DefIt =
405       Def->getNextNode() ? Def->getNextNode()->getIterator() : MBB.end();
406   // Skip all the PHI and debug instructions.
407   while (DefIt != MBB.end() &&
408          (DefIt->isPHI() || DefIt->isDebugOrPseudoInstr()))
409     DefIt = std::next(DefIt);
410   MIB.setInsertPt(MBB, DefIt);
411 }
412 
413 // Insert ASSIGN_TYPE instuction between Reg and its definition, set NewReg as
414 // a dst of the definition, assign SPIRVType to both registers. If SpvType is
415 // provided, use it as SPIRVType in ASSIGN_TYPE, otherwise create it from Ty.
416 // It's used also in SPIRVBuiltins.cpp.
417 // TODO: maybe move to SPIRVUtils.
418 namespace llvm {
419 Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpvType,
420                            SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB,
421                            MachineRegisterInfo &MRI) {
422   assert((Ty || SpvType) && "Either LLVM or SPIRV type is expected.");
423   MachineInstr *Def = MRI.getVRegDef(Reg);
424   setInsertPtAfterDef(MIB, Def);
425   SpvType = SpvType ? SpvType : GR->getOrCreateSPIRVType(Ty, MIB);
426   Register NewReg = MRI.createGenericVirtualRegister(MRI.getType(Reg));
427   if (auto *RC = MRI.getRegClassOrNull(Reg)) {
428     MRI.setRegClass(NewReg, RC);
429   } else {
430     auto RegClass = GR->getRegClass(SpvType);
431     MRI.setRegClass(NewReg, RegClass);
432     MRI.setRegClass(Reg, RegClass);
433   }
434   GR->assignSPIRVTypeToVReg(SpvType, Reg, MIB.getMF());
435   // This is to make it convenient for Legalizer to get the SPIRVType
436   // when processing the actual MI (i.e. not pseudo one).
437   GR->assignSPIRVTypeToVReg(SpvType, NewReg, MIB.getMF());
438   // Copy MIFlags from Def to ASSIGN_TYPE instruction. It's required to keep
439   // the flags after instruction selection.
440   const uint32_t Flags = Def->getFlags();
441   MIB.buildInstr(SPIRV::ASSIGN_TYPE)
442       .addDef(Reg)
443       .addUse(NewReg)
444       .addUse(GR->getSPIRVTypeID(SpvType))
445       .setMIFlags(Flags);
446   for (unsigned I = 0, E = Def->getNumDefs(); I != E; ++I) {
447     MachineOperand &MO = Def->getOperand(I);
448     if (MO.getReg() == Reg) {
449       MO.setReg(NewReg);
450       break;
451     }
452   }
453   return NewReg;
454 }
455 
456 void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
457                   MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR) {
458   assert(MI.getNumDefs() > 0 && MRI.hasOneUse(MI.getOperand(0).getReg()));
459   MachineInstr &AssignTypeInst =
460       *(MRI.use_instr_begin(MI.getOperand(0).getReg()));
461   auto NewReg =
462       createNewIdReg(nullptr, MI.getOperand(0).getReg(), MRI, *GR).first;
463   AssignTypeInst.getOperand(1).setReg(NewReg);
464   MI.getOperand(0).setReg(NewReg);
465   MIB.setInsertPt(*MI.getParent(), MI.getIterator());
466   for (auto &Op : MI.operands()) {
467     if (!Op.isReg() || Op.isDef())
468       continue;
469     auto IdOpInfo = createNewIdReg(nullptr, Op.getReg(), MRI, *GR);
470     MIB.buildInstr(IdOpInfo.second).addDef(IdOpInfo.first).addUse(Op.getReg());
471     Op.setReg(IdOpInfo.first);
472   }
473 }
474 } // namespace llvm
475 
476 static void
477 generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
478                      MachineIRBuilder MIB,
479                      DenseMap<MachineInstr *, Type *> &TargetExtConstTypes) {
480   // Get access to information about available extensions
481   const SPIRVSubtarget *ST =
482       static_cast<const SPIRVSubtarget *>(&MIB.getMF().getSubtarget());
483 
484   MachineRegisterInfo &MRI = MF.getRegInfo();
485   SmallVector<MachineInstr *, 10> ToErase;
486   DenseMap<MachineInstr *, Register> RegsAlreadyAddedToDT;
487 
488   bool IsExtendedInts =
489       ST->canUseExtension(
490           SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) ||
491       ST->canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions);
492 
493   for (MachineBasicBlock *MBB : post_order(&MF)) {
494     if (MBB->empty())
495       continue;
496 
497     bool ReachedBegin = false;
498     for (auto MII = std::prev(MBB->end()), Begin = MBB->begin();
499          !ReachedBegin;) {
500       MachineInstr &MI = *MII;
501       unsigned MIOp = MI.getOpcode();
502 
503       if (!IsExtendedInts) {
504         // validate bit width of scalar registers
505         for (const auto &MOP : MI.operands())
506           if (MOP.isReg())
507             widenScalarLLTNextPow2(MOP.getReg(), MRI);
508       }
509 
510       if (isSpvIntrinsic(MI, Intrinsic::spv_assign_ptr_type)) {
511         Register Reg = MI.getOperand(1).getReg();
512         MIB.setInsertPt(*MI.getParent(), MI.getIterator());
513         Type *ElementTy = getMDOperandAsType(MI.getOperand(2).getMetadata(), 0);
514         SPIRVType *BaseTy = GR->getOrCreateSPIRVType(ElementTy, MIB);
515         SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
516             BaseTy, MI, *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(),
517             addressSpaceToStorageClass(MI.getOperand(3).getImm(), *ST));
518         MachineInstr *Def = MRI.getVRegDef(Reg);
519         assert(Def && "Expecting an instruction that defines the register");
520         // G_GLOBAL_VALUE already has type info.
521         if (Def->getOpcode() != TargetOpcode::G_GLOBAL_VALUE &&
522             Def->getOpcode() != SPIRV::ASSIGN_TYPE)
523           insertAssignInstr(Reg, nullptr, AssignedPtrType, GR, MIB,
524                             MF.getRegInfo());
525         ToErase.push_back(&MI);
526       } else if (isSpvIntrinsic(MI, Intrinsic::spv_assign_type)) {
527         Register Reg = MI.getOperand(1).getReg();
528         Type *Ty = getMDOperandAsType(MI.getOperand(2).getMetadata(), 0);
529         MachineInstr *Def = MRI.getVRegDef(Reg);
530         assert(Def && "Expecting an instruction that defines the register");
531         // G_GLOBAL_VALUE already has type info.
532         if (Def->getOpcode() != TargetOpcode::G_GLOBAL_VALUE &&
533             Def->getOpcode() != SPIRV::ASSIGN_TYPE)
534           insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MF.getRegInfo());
535         ToErase.push_back(&MI);
536       } else if (MIOp == TargetOpcode::FAKE_USE && MI.getNumOperands() > 0) {
537         MachineInstr *MdMI = MI.getPrevNode();
538         if (MdMI && isSpvIntrinsic(*MdMI, Intrinsic::spv_value_md)) {
539           // It's an internal service info from before IRTranslator passes.
540           MachineInstr *Def = getVRegDef(MRI, MI.getOperand(0).getReg());
541           for (unsigned I = 1, E = MI.getNumOperands(); I != E && Def; ++I)
542             if (getVRegDef(MRI, MI.getOperand(I).getReg()) != Def)
543               Def = nullptr;
544           if (Def) {
545             const MDNode *MD = MdMI->getOperand(1).getMetadata();
546             StringRef ValueName =
547                 cast<MDString>(MD->getOperand(1))->getString();
548             const MDNode *TypeMD = cast<MDNode>(MD->getOperand(0));
549             Type *ValueTy = getMDOperandAsType(TypeMD, 0);
550             GR->addValueAttrs(Def, std::make_pair(ValueTy, ValueName.str()));
551           }
552           ToErase.push_back(MdMI);
553         }
554         ToErase.push_back(&MI);
555       } else if (MIOp == TargetOpcode::G_CONSTANT ||
556                  MIOp == TargetOpcode::G_FCONSTANT ||
557                  MIOp == TargetOpcode::G_BUILD_VECTOR) {
558         // %rc = G_CONSTANT ty Val
559         // ===>
560         // %cty = OpType* ty
561         // %rctmp = G_CONSTANT ty Val
562         // %rc = ASSIGN_TYPE %rctmp, %cty
563         Register Reg = MI.getOperand(0).getReg();
564         bool NeedAssignType = true;
565         if (MRI.hasOneUse(Reg)) {
566           MachineInstr &UseMI = *MRI.use_instr_begin(Reg);
567           if (isSpvIntrinsic(UseMI, Intrinsic::spv_assign_type) ||
568               isSpvIntrinsic(UseMI, Intrinsic::spv_assign_name))
569             continue;
570           if (UseMI.getOpcode() == SPIRV::ASSIGN_TYPE)
571             NeedAssignType = false;
572         }
573         Type *Ty = nullptr;
574         if (MIOp == TargetOpcode::G_CONSTANT) {
575           auto TargetExtIt = TargetExtConstTypes.find(&MI);
576           Ty = TargetExtIt == TargetExtConstTypes.end()
577                    ? MI.getOperand(1).getCImm()->getType()
578                    : TargetExtIt->second;
579           const ConstantInt *OpCI = MI.getOperand(1).getCImm();
580           // TODO: we may wish to analyze here if OpCI is zero and LLT RegType =
581           // MRI.getType(Reg); RegType.isPointer() is true, so that we observe
582           // at this point not i64/i32 constant but null pointer in the
583           // corresponding address space of RegType.getAddressSpace(). This may
584           // help to successfully validate the case when a OpConstantComposite's
585           // constituent has type that does not match Result Type of
586           // OpConstantComposite (see, for example,
587           // pointers/PtrCast-null-in-OpSpecConstantOp.ll).
588           Register PrimaryReg = GR->find(OpCI, &MF);
589           if (!PrimaryReg.isValid()) {
590             GR->add(OpCI, &MF, Reg);
591           } else if (PrimaryReg != Reg &&
592                      MRI.getType(Reg) == MRI.getType(PrimaryReg)) {
593             auto *RCReg = MRI.getRegClassOrNull(Reg);
594             auto *RCPrimary = MRI.getRegClassOrNull(PrimaryReg);
595             if (!RCReg || RCPrimary == RCReg) {
596               RegsAlreadyAddedToDT[&MI] = PrimaryReg;
597               ToErase.push_back(&MI);
598               NeedAssignType = false;
599             }
600           }
601         } else if (MIOp == TargetOpcode::G_FCONSTANT) {
602           Ty = MI.getOperand(1).getFPImm()->getType();
603         } else {
604           assert(MIOp == TargetOpcode::G_BUILD_VECTOR);
605           Type *ElemTy = nullptr;
606           MachineInstr *ElemMI = MRI.getVRegDef(MI.getOperand(1).getReg());
607           assert(ElemMI);
608 
609           if (ElemMI->getOpcode() == TargetOpcode::G_CONSTANT) {
610             ElemTy = ElemMI->getOperand(1).getCImm()->getType();
611           } else if (ElemMI->getOpcode() == TargetOpcode::G_FCONSTANT) {
612             ElemTy = ElemMI->getOperand(1).getFPImm()->getType();
613           } else {
614             // There may be a case when we already know Reg's type.
615             MachineInstr *NextMI = MI.getNextNode();
616             if (!NextMI || NextMI->getOpcode() != SPIRV::ASSIGN_TYPE ||
617                 NextMI->getOperand(1).getReg() != Reg)
618               llvm_unreachable("Unexpected opcode");
619           }
620           if (ElemTy)
621             Ty = VectorType::get(
622                 ElemTy, MI.getNumExplicitOperands() - MI.getNumExplicitDefs(),
623                 false);
624           else
625             NeedAssignType = false;
626         }
627         if (NeedAssignType)
628           insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MRI);
629       } else if (MIOp == TargetOpcode::G_GLOBAL_VALUE) {
630         propagateSPIRVType(&MI, GR, MRI, MIB);
631       }
632 
633       if (MII == Begin)
634         ReachedBegin = true;
635       else
636         --MII;
637     }
638   }
639   for (MachineInstr *MI : ToErase) {
640     auto It = RegsAlreadyAddedToDT.find(MI);
641     if (RegsAlreadyAddedToDT.contains(MI))
642       MRI.replaceRegWith(MI->getOperand(0).getReg(), It->second);
643     MI->eraseFromParent();
644   }
645 
646   // Address the case when IRTranslator introduces instructions with new
647   // registers without SPIRVType associated.
648   for (MachineBasicBlock &MBB : MF) {
649     for (MachineInstr &MI : MBB) {
650       switch (MI.getOpcode()) {
651       case TargetOpcode::G_TRUNC:
652       case TargetOpcode::G_ANYEXT:
653       case TargetOpcode::G_SEXT:
654       case TargetOpcode::G_ZEXT:
655       case TargetOpcode::G_PTRTOINT:
656       case TargetOpcode::COPY:
657       case TargetOpcode::G_ADDRSPACE_CAST:
658         propagateSPIRVType(&MI, GR, MRI, MIB);
659         break;
660       }
661     }
662   }
663 }
664 
665 // Defined in SPIRVLegalizerInfo.cpp.
666 extern bool isTypeFoldingSupported(unsigned Opcode);
667 
668 static void processInstrsWithTypeFolding(MachineFunction &MF,
669                                          SPIRVGlobalRegistry *GR,
670                                          MachineIRBuilder MIB) {
671   MachineRegisterInfo &MRI = MF.getRegInfo();
672   for (MachineBasicBlock &MBB : MF) {
673     for (MachineInstr &MI : MBB) {
674       if (isTypeFoldingSupported(MI.getOpcode()))
675         processInstr(MI, MIB, MRI, GR);
676     }
677   }
678 
679   for (MachineBasicBlock &MBB : MF) {
680     for (MachineInstr &MI : MBB) {
681       // We need to rewrite dst types for ASSIGN_TYPE instrs to be able
682       // to perform tblgen'erated selection and we can't do that on Legalizer
683       // as it operates on gMIR only.
684       if (MI.getOpcode() != SPIRV::ASSIGN_TYPE)
685         continue;
686       Register SrcReg = MI.getOperand(1).getReg();
687       unsigned Opcode = MRI.getVRegDef(SrcReg)->getOpcode();
688       if (!isTypeFoldingSupported(Opcode))
689         continue;
690       Register DstReg = MI.getOperand(0).getReg();
691       // Don't need to reset type of register holding constant and used in
692       // G_ADDRSPACE_CAST, since it breaks legalizer.
693       if (Opcode == TargetOpcode::G_CONSTANT && MRI.hasOneUse(DstReg)) {
694         MachineInstr &UseMI = *MRI.use_instr_begin(DstReg);
695         if (UseMI.getOpcode() == TargetOpcode::G_ADDRSPACE_CAST)
696           continue;
697       }
698     }
699   }
700 }
701 
702 static Register
703 collectInlineAsmInstrOperands(MachineInstr *MI,
704                               SmallVector<unsigned, 4> *Ops = nullptr) {
705   Register DefReg;
706   unsigned StartOp = InlineAsm::MIOp_FirstOperand,
707            AsmDescOp = InlineAsm::MIOp_FirstOperand;
708   for (unsigned Idx = StartOp, MISz = MI->getNumOperands(); Idx != MISz;
709        ++Idx) {
710     const MachineOperand &MO = MI->getOperand(Idx);
711     if (MO.isMetadata())
712       continue;
713     if (Idx == AsmDescOp && MO.isImm()) {
714       // compute the index of the next operand descriptor
715       const InlineAsm::Flag F(MO.getImm());
716       AsmDescOp += 1 + F.getNumOperandRegisters();
717       continue;
718     }
719     if (MO.isReg() && MO.isDef()) {
720       if (!Ops)
721         return MO.getReg();
722       else
723         DefReg = MO.getReg();
724     } else if (Ops) {
725       Ops->push_back(Idx);
726     }
727   }
728   return DefReg;
729 }
730 
// Lowers inline assembly into SPV_INTEL_inline_assembly instructions.
//
// ToProcess is consumed as consecutive pairs (I1, I2), where I1 is an
// spv_inline_asm intrinsic call and I2 is the matching inline-asm machine
// instruction (checked by the assert below). For each pair this emits, right
// before I2:
//   - once per call of this function, an OpAsmTargetINTEL holding the target
//     triple string;
//   - an OpAsmINTEL built from the function type stored in I1's metadata, the
//     asm string of I2 and the constraint string stored in I1's metadata;
//   - an OpDecorate SideEffectsINTEL when I2 is marked as having side effects;
//   - an OpAsmCallINTEL whose arguments are I1's operands from index 3 on.
// Finally every instruction in ToProcess is erased.
// NOTE(review): with an odd-sized ToProcess the loop skips the trailing entry,
// but it is still erased below without being lowered.
static void
insertInlineAsmProcess(MachineFunction &MF, SPIRVGlobalRegistry *GR,
                       const SPIRVSubtarget &ST, MachineIRBuilder MIRBuilder,
                       const SmallVector<MachineInstr *> &ToProcess) {
  MachineRegisterInfo &MRI = MF.getRegInfo();
  Register AsmTargetReg;
  for (unsigned i = 0, Sz = ToProcess.size(); i + 1 < Sz; i += 2) {
    MachineInstr *I1 = ToProcess[i], *I2 = ToProcess[i + 1];
    assert(isSpvIntrinsic(*I1, Intrinsic::spv_inline_asm) && I2->isInlineAsm());
    MIRBuilder.setInsertPt(*I2->getParent(), *I2);

    if (!AsmTargetReg.isValid()) {
      // Define the vendor-specific assembly target or dialect (only once,
      // for the first pair processed).
      AsmTargetReg = MRI.createGenericVirtualRegister(LLT::scalar(32));
      MRI.setRegClass(AsmTargetReg, &SPIRV::iIDRegClass);
      auto AsmTargetMIB =
          MIRBuilder.buildInstr(SPIRV::OpAsmTargetINTEL).addDef(AsmTargetReg);
      addStringImm(ST.getTargetTripleAsStr(), AsmTargetMIB);
      GR->add(AsmTargetMIB.getInstr(), &MF, AsmTargetReg);
    }

    // Create the SPIR-V argument, return and function types from the
    // FunctionType stored as the first operand of I1's metadata operand.
    const MDNode *IAMD = I1->getOperand(1).getMetadata();
    FunctionType *FTy = cast<FunctionType>(getMDOperandAsType(IAMD, 0));
    SmallVector<SPIRVType *, 4> ArgTypes;
    for (const auto &ArgTy : FTy->params())
      ArgTypes.push_back(GR->getOrCreateSPIRVType(ArgTy, MIRBuilder));
    SPIRVType *RetType =
        GR->getOrCreateSPIRVType(FTy->getReturnType(), MIRBuilder);
    SPIRVType *FuncType = GR->getOrCreateOpTypeFunctionWithArgs(
        FTy, RetType, ArgTypes, MIRBuilder);

    // Define the vendor-specific assembly instruction string (OpAsmINTEL).
    Register AsmReg = MRI.createGenericVirtualRegister(LLT::scalar(32));
    MRI.setRegClass(AsmReg, &SPIRV::iIDRegClass);
    auto AsmMIB = MIRBuilder.buildInstr(SPIRV::OpAsmINTEL)
                      .addDef(AsmReg)
                      .addUse(GR->getSPIRVTypeID(RetType))
                      .addUse(GR->getSPIRVTypeID(FuncType))
                      .addUse(AsmTargetReg);
    // Inline asm string, taken from the machine instruction:
    addStringImm(I2->getOperand(InlineAsm::MIOp_AsmString).getSymbolName(),
                 AsmMIB);
    // Inline asm constraint string, taken from I1's second metadata operand:
    addStringImm(cast<MDString>(I1->getOperand(2).getMetadata()->getOperand(0))
                     ->getString(),
                 AsmMIB);
    GR->add(AsmMIB.getInstr(), &MF, AsmReg);

    // Propagate the "has side effects" flag of the asm statement as a
    // SideEffectsINTEL decoration on the OpAsmINTEL result.
    unsigned ExtraInfo = I2->getOperand(InlineAsm::MIOp_ExtraInfo).getImm();
    if (ExtraInfo & InlineAsm::Extra_HasSideEffects)
      MIRBuilder.buildInstr(SPIRV::OpDecorate)
          .addUse(AsmReg)
          .addImm(static_cast<uint32_t>(SPIRV::Decoration::SideEffectsINTEL));

    // Reuse the first register def of I2 as the call result; if there is
    // none, create a fresh register typed as void.
    Register DefReg = collectInlineAsmInstrOperands(I2);
    if (!DefReg.isValid()) {
      DefReg = MRI.createGenericVirtualRegister(LLT::scalar(32));
      MRI.setRegClass(DefReg, &SPIRV::iIDRegClass);
      SPIRVType *VoidType = GR->getOrCreateSPIRVType(
          Type::getVoidTy(MF.getFunction().getContext()), MIRBuilder);
      GR->assignSPIRVTypeToVReg(VoidType, DefReg, MF);
    }

    // Call the inline assembly instruction, passing I1's operands from
    // index 3 onwards as the call arguments.
    auto AsmCall = MIRBuilder.buildInstr(SPIRV::OpAsmCallINTEL)
                       .addDef(DefReg)
                       .addUse(GR->getSPIRVTypeID(RetType))
                       .addUse(AsmReg);
    for (unsigned IntrIdx = 3; IntrIdx < I1->getNumOperands(); ++IntrIdx)
      AsmCall.addUse(I1->getOperand(IntrIdx).getReg());
  }
  // The original intrinsic calls and inline-asm instructions are replaced.
  for (MachineInstr *MI : ToProcess)
    MI->eraseFromParent();
}
806 
807 static void insertInlineAsm(MachineFunction &MF, SPIRVGlobalRegistry *GR,
808                             const SPIRVSubtarget &ST,
809                             MachineIRBuilder MIRBuilder) {
810   SmallVector<MachineInstr *> ToProcess;
811   for (MachineBasicBlock &MBB : MF) {
812     for (MachineInstr &MI : MBB) {
813       if (isSpvIntrinsic(MI, Intrinsic::spv_inline_asm) ||
814           MI.getOpcode() == TargetOpcode::INLINEASM)
815         ToProcess.push_back(&MI);
816     }
817   }
818   if (ToProcess.size() == 0)
819     return;
820 
821   if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_inline_assembly))
822     report_fatal_error("Inline assembly instructions require the "
823                        "following SPIR-V extension: SPV_INTEL_inline_assembly",
824                        false);
825 
826   insertInlineAsmProcess(MF, GR, ST, MIRBuilder, ToProcess);
827 }
828 
829 static void insertSpirvDecorations(MachineFunction &MF, MachineIRBuilder MIB) {
830   SmallVector<MachineInstr *, 10> ToErase;
831   for (MachineBasicBlock &MBB : MF) {
832     for (MachineInstr &MI : MBB) {
833       if (!isSpvIntrinsic(MI, Intrinsic::spv_assign_decoration))
834         continue;
835       MIB.setInsertPt(*MI.getParent(), MI.getNextNode());
836       buildOpSpirvDecorations(MI.getOperand(1).getReg(), MIB,
837                               MI.getOperand(2).getMetadata());
838       ToErase.push_back(&MI);
839     }
840   }
841   for (MachineInstr *MI : ToErase)
842     MI->eraseFromParent();
843 }
844 
845 // LLVM allows the switches to use registers as cases, while SPIR-V required
846 // those to be immediate values. This function replaces such operands with the
847 // equivalent immediate constant.
848 static void processSwitchesConstants(MachineFunction &MF,
849                                      SPIRVGlobalRegistry *GR,
850                                      MachineIRBuilder MIB) {
851   MachineRegisterInfo &MRI = MF.getRegInfo();
852   for (MachineBasicBlock &MBB : MF) {
853     for (MachineInstr &MI : MBB) {
854       if (!isSpvIntrinsic(MI, Intrinsic::spv_switch))
855         continue;
856 
857       SmallVector<MachineOperand, 8> NewOperands;
858       NewOperands.push_back(MI.getOperand(0)); // Opcode
859       NewOperands.push_back(MI.getOperand(1)); // Condition
860       NewOperands.push_back(MI.getOperand(2)); // Default
861       for (unsigned i = 3; i < MI.getNumOperands(); i += 2) {
862         Register Reg = MI.getOperand(i).getReg();
863         MachineInstr *ConstInstr = getDefInstrMaybeConstant(Reg, &MRI);
864         NewOperands.push_back(
865             MachineOperand::CreateCImm(ConstInstr->getOperand(1).getCImm()));
866 
867         NewOperands.push_back(MI.getOperand(i + 1));
868       }
869 
870       assert(MI.getNumOperands() == NewOperands.size());
871       while (MI.getNumOperands() > 0)
872         MI.removeOperand(0);
873       for (auto &MO : NewOperands)
874         MI.addOperand(MO);
875     }
876   }
877 }
878 
879 // Some instructions are used during CodeGen but should never be emitted.
880 // Cleaning up those.
881 static void cleanupHelperInstructions(MachineFunction &MF) {
882   SmallVector<MachineInstr *, 8> ToEraseMI;
883   for (MachineBasicBlock &MBB : MF) {
884     for (MachineInstr &MI : MBB) {
885       if (isSpvIntrinsic(MI, Intrinsic::spv_track_constant) ||
886           MI.getOpcode() == TargetOpcode::G_BRINDIRECT)
887         ToEraseMI.push_back(&MI);
888     }
889   }
890 
891   for (MachineInstr *MI : ToEraseMI)
892     MI->eraseFromParent();
893 }
894 
895 // Find all usages of G_BLOCK_ADDR in our intrinsics and replace those
896 // operands/registers by the actual MBB it references.
897 static void processBlockAddr(MachineFunction &MF, SPIRVGlobalRegistry *GR,
898                              MachineIRBuilder MIB) {
899   // Gather the reverse-mapping BB -> MBB.
900   DenseMap<const BasicBlock *, MachineBasicBlock *> BB2MBB;
901   for (MachineBasicBlock &MBB : MF)
902     BB2MBB[MBB.getBasicBlock()] = &MBB;
903 
904   // Gather instructions requiring patching. For now, only those can use
905   // G_BLOCK_ADDR.
906   SmallVector<MachineInstr *, 8> InstructionsToPatch;
907   for (MachineBasicBlock &MBB : MF) {
908     for (MachineInstr &MI : MBB) {
909       if (isSpvIntrinsic(MI, Intrinsic::spv_switch) ||
910           isSpvIntrinsic(MI, Intrinsic::spv_loop_merge) ||
911           isSpvIntrinsic(MI, Intrinsic::spv_selection_merge))
912         InstructionsToPatch.push_back(&MI);
913     }
914   }
915 
916   // For each instruction to fix, we replace all the G_BLOCK_ADDR operands by
917   // the actual MBB it references. Once those references have been updated, we
918   // can cleanup remaining G_BLOCK_ADDR references.
919   SmallPtrSet<MachineBasicBlock *, 8> ClearAddressTaken;
920   SmallPtrSet<MachineInstr *, 8> ToEraseMI;
921   MachineRegisterInfo &MRI = MF.getRegInfo();
922   for (MachineInstr *MI : InstructionsToPatch) {
923     SmallVector<MachineOperand, 8> NewOps;
924     for (unsigned i = 0; i < MI->getNumOperands(); ++i) {
925       // The operand is not a register, keep as-is.
926       if (!MI->getOperand(i).isReg()) {
927         NewOps.push_back(MI->getOperand(i));
928         continue;
929       }
930 
931       Register Reg = MI->getOperand(i).getReg();
932       MachineInstr *BuildMBB = MRI.getVRegDef(Reg);
933       // The register is not the result of G_BLOCK_ADDR, keep as-is.
934       if (!BuildMBB || BuildMBB->getOpcode() != TargetOpcode::G_BLOCK_ADDR) {
935         NewOps.push_back(MI->getOperand(i));
936         continue;
937       }
938 
939       assert(BuildMBB && BuildMBB->getOpcode() == TargetOpcode::G_BLOCK_ADDR &&
940              BuildMBB->getOperand(1).isBlockAddress() &&
941              BuildMBB->getOperand(1).getBlockAddress());
942       BasicBlock *BB =
943           BuildMBB->getOperand(1).getBlockAddress()->getBasicBlock();
944       auto It = BB2MBB.find(BB);
945       if (It == BB2MBB.end())
946         report_fatal_error("cannot find a machine basic block by a basic block "
947                            "in a switch statement");
948       MachineBasicBlock *ReferencedBlock = It->second;
949       NewOps.push_back(MachineOperand::CreateMBB(ReferencedBlock));
950 
951       ClearAddressTaken.insert(ReferencedBlock);
952       ToEraseMI.insert(BuildMBB);
953     }
954 
955     // Replace the operands.
956     assert(MI->getNumOperands() == NewOps.size());
957     while (MI->getNumOperands() > 0)
958       MI->removeOperand(0);
959     for (auto &MO : NewOps)
960       MI->addOperand(MO);
961 
962     if (MachineInstr *Next = MI->getNextNode()) {
963       if (isSpvIntrinsic(*Next, Intrinsic::spv_track_constant)) {
964         ToEraseMI.insert(Next);
965         Next = MI->getNextNode();
966       }
967       if (Next && Next->getOpcode() == TargetOpcode::G_BRINDIRECT)
968         ToEraseMI.insert(Next);
969     }
970   }
971 
972   // BlockAddress operands were used to keep information between passes,
973   // let's undo the "address taken" status to reflect that Succ doesn't
974   // actually correspond to an IR-level basic block.
975   for (MachineBasicBlock *Succ : ClearAddressTaken)
976     Succ->setAddressTakenIRBlock(nullptr);
977 
978   // If we just delete G_BLOCK_ADDR instructions with BlockAddress operands,
979   // this leaves their BasicBlock counterparts in a "address taken" status. This
980   // would make AsmPrinter to generate a series of unneeded labels of a "Address
981   // of block that was removed by CodeGen" kind. Let's first ensure that we
982   // don't have a dangling BlockAddress constants by zapping the BlockAddress
983   // nodes, and only after that proceed with erasing G_BLOCK_ADDR instructions.
984   Constant *Replacement =
985       ConstantInt::get(Type::getInt32Ty(MF.getFunction().getContext()), 1);
986   for (MachineInstr *BlockAddrI : ToEraseMI) {
987     if (BlockAddrI->getOpcode() == TargetOpcode::G_BLOCK_ADDR) {
988       BlockAddress *BA = const_cast<BlockAddress *>(
989           BlockAddrI->getOperand(1).getBlockAddress());
990       BA->replaceAllUsesWith(
991           ConstantExpr::getIntToPtr(Replacement, BA->getType()));
992       BA->destroyConstant();
993     }
994     BlockAddrI->eraseFromParent();
995   }
996 }
997 
998 static bool isImplicitFallthrough(MachineBasicBlock &MBB) {
999   if (MBB.empty())
1000     return true;
1001 
1002   // Branching SPIR-V intrinsics are not detected by this generic method.
1003   // Thus, we can only trust negative result.
1004   if (!MBB.canFallThrough())
1005     return false;
1006 
1007   // Otherwise, we must manually check if we have a SPIR-V intrinsic which
1008   // prevent an implicit fallthrough.
1009   for (MachineBasicBlock::reverse_iterator It = MBB.rbegin(), E = MBB.rend();
1010        It != E; ++It) {
1011     if (isSpvIntrinsic(*It, Intrinsic::spv_switch))
1012       return false;
1013   }
1014   return true;
1015 }
1016 
1017 static void removeImplicitFallthroughs(MachineFunction &MF,
1018                                        MachineIRBuilder MIB) {
1019   // It is valid for MachineBasicBlocks to not finish with a branch instruction.
1020   // In such cases, they will simply fallthrough their immediate successor.
1021   for (MachineBasicBlock &MBB : MF) {
1022     if (!isImplicitFallthrough(MBB))
1023       continue;
1024 
1025     assert(std::distance(MBB.successors().begin(), MBB.successors().end()) ==
1026            1);
1027     MIB.setInsertPt(MBB, MBB.end());
1028     MIB.buildBr(**MBB.successors().begin());
1029   }
1030 }
1031 
// Pass entry point. Runs the pre-legalization steps on MF one after another,
// in the fixed order below, and always returns true (the function is reported
// as modified).
bool SPIRVPreLegalizer::runOnMachineFunction(MachineFunction &MF) {
  // Initialize the type registry.
  const SPIRVSubtarget &ST = MF.getSubtarget<SPIRVSubtarget>();
  SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
  // Point the global registry at the function being processed.
  GR->setCurrentFunc(MF);
  MachineIRBuilder MIB(MF);
  // A registry of target extension constants (filled by addConstantsToTrack,
  // consumed by generateAssignInstrs).
  DenseMap<MachineInstr *, Type *> TargetExtConstTypes;
  // To keep record of tracked constants (filled by addConstantsToTrack,
  // consumed by foldConstantsIntoIntrinsics).
  SmallSet<Register, 4> TrackedConstRegs;
  addConstantsToTrack(MF, GR, ST, TargetExtConstTypes, TrackedConstRegs);
  foldConstantsIntoIntrinsics(MF, TrackedConstRegs);
  insertBitcasts(MF, GR, MIB);
  generateAssignInstrs(MF, GR, MIB, TargetExtConstTypes);

  // Turn spv_switch case-value registers into immediates, replace
  // G_BLOCK_ADDR operands of the branching intrinsics with MBB operands, then
  // drop the remaining spv_track_constant / G_BRINDIRECT helpers.
  processSwitchesConstants(MF, GR, MIB);
  processBlockAddr(MF, GR, MIB);
  cleanupHelperInstructions(MF);

  processInstrsWithTypeFolding(MF, GR, MIB);
  // Add explicit branches for blocks that fall through implicitly.
  removeImplicitFallthroughs(MF, MIB);
  // Lower spv_assign_decoration and inline assembly.
  insertSpirvDecorations(MF, MIB);
  insertInlineAsm(MF, GR, ST, MIB);
  selectOpBitcasts(MF, GR, MIB);

  return true;
}
1059 
// Register the pass with the legacy pass registry under the command-line name
// DEBUG_TYPE ("spirv-prelegalizer") and the description "SPIRV pre legalizer".
// The two trailing `false` arguments mark it as neither CFG-only nor an
// analysis pass.
INITIALIZE_PASS(SPIRVPreLegalizer, DEBUG_TYPE, "SPIRV pre legalizer", false,
                false)

// Pass identifier, handed to MachineFunctionPass in the constructor.
char SPIRVPreLegalizer::ID = 0;

// Factory for the pass; returns a newly allocated SPIRVPreLegalizer.
FunctionPass *llvm::createSPIRVPreLegalizerPass() {
  return new SPIRVPreLegalizer();
}
1068