xref: /llvm-project/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (revision eddeb36cf1ced0e14e17ac90f60922366e382100)
1 //===-- SPIRVPreLegalizer.cpp - prepare IR for legalization -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// The pass prepares IR for legalization: it assigns SPIR-V types to registers
// and removes intrinsics which held these types during IR translation.
11 // Also it processes constants and registers them in GR to avoid duplication.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "SPIRV.h"
16 #include "SPIRVSubtarget.h"
17 #include "SPIRVUtils.h"
18 #include "llvm/ADT/PostOrderIterator.h"
19 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
20 #include "llvm/CodeGen/GlobalISel/CSEInfo.h"
21 #include "llvm/CodeGen/GlobalISel/GISelKnownBits.h"
22 #include "llvm/IR/Attributes.h"
23 #include "llvm/IR/Constants.h"
24 #include "llvm/IR/DebugInfoMetadata.h"
25 #include "llvm/IR/IntrinsicsSPIRV.h"
26 #include "llvm/Target/TargetIntrinsicInfo.h"
27 
28 #define DEBUG_TYPE "spirv-prelegalizer"
29 
30 using namespace llvm;
31 
32 namespace {
33 class SPIRVPreLegalizer : public MachineFunctionPass {
34 public:
35   static char ID;
36   SPIRVPreLegalizer() : MachineFunctionPass(ID) {
37     initializeSPIRVPreLegalizerPass(*PassRegistry::getPassRegistry());
38   }
39   bool runOnMachineFunction(MachineFunction &MF) override;
40   void getAnalysisUsage(AnalysisUsage &AU) const override;
41 };
42 } // namespace
43 
// Reports GISelKnownBitsAnalysis as preserved by this pass, then adds the
// requirements of the MachineFunctionPass base class.
void SPIRVPreLegalizer::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addPreserved<GISelKnownBitsAnalysis>();
  MachineFunctionPass::getAnalysisUsage(AU);
}
48 
49 static void
50 addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR,
51                     const SPIRVSubtarget &STI,
52                     DenseMap<MachineInstr *, Type *> &TargetExtConstTypes,
53                     SmallSet<Register, 4> &TrackedConstRegs) {
54   MachineRegisterInfo &MRI = MF.getRegInfo();
55   DenseMap<MachineInstr *, Register> RegsAlreadyAddedToDT;
56   SmallVector<MachineInstr *, 10> ToErase, ToEraseComposites;
57   for (MachineBasicBlock &MBB : MF) {
58     for (MachineInstr &MI : MBB) {
59       if (!isSpvIntrinsic(MI, Intrinsic::spv_track_constant))
60         continue;
61       ToErase.push_back(&MI);
62       Register SrcReg = MI.getOperand(2).getReg();
63       auto *Const =
64           cast<Constant>(cast<ConstantAsMetadata>(
65                              MI.getOperand(3).getMetadata()->getOperand(0))
66                              ->getValue());
67       if (auto *GV = dyn_cast<GlobalValue>(Const)) {
68         Register Reg = GR->find(GV, &MF);
69         if (!Reg.isValid()) {
70           GR->add(GV, &MF, SrcReg);
71           GR->addGlobalObject(GV, &MF, SrcReg);
72         } else
73           RegsAlreadyAddedToDT[&MI] = Reg;
74       } else {
75         Register Reg = GR->find(Const, &MF);
76         if (!Reg.isValid()) {
77           if (auto *ConstVec = dyn_cast<ConstantDataVector>(Const)) {
78             auto *BuildVec = MRI.getVRegDef(SrcReg);
79             assert(BuildVec &&
80                    BuildVec->getOpcode() == TargetOpcode::G_BUILD_VECTOR);
81             for (unsigned i = 0; i < ConstVec->getNumElements(); ++i) {
82               // Ensure that OpConstantComposite reuses a constant when it's
83               // already created and available in the same machine function.
84               Constant *ElemConst = ConstVec->getElementAsConstant(i);
85               Register ElemReg = GR->find(ElemConst, &MF);
86               if (!ElemReg.isValid())
87                 GR->add(ElemConst, &MF, BuildVec->getOperand(1 + i).getReg());
88               else
89                 BuildVec->getOperand(1 + i).setReg(ElemReg);
90             }
91           }
92           GR->add(Const, &MF, SrcReg);
93           TrackedConstRegs.insert(SrcReg);
94           if (Const->getType()->isTargetExtTy()) {
95             // remember association so that we can restore it when assign types
96             MachineInstr *SrcMI = MRI.getVRegDef(SrcReg);
97             if (SrcMI && (SrcMI->getOpcode() == TargetOpcode::G_CONSTANT ||
98                           SrcMI->getOpcode() == TargetOpcode::G_IMPLICIT_DEF))
99               TargetExtConstTypes[SrcMI] = Const->getType();
100             if (Const->isNullValue()) {
101               MachineIRBuilder MIB(MF);
102               SPIRVType *ExtType =
103                   GR->getOrCreateSPIRVType(Const->getType(), MIB);
104               SrcMI->setDesc(STI.getInstrInfo()->get(SPIRV::OpConstantNull));
105               SrcMI->addOperand(MachineOperand::CreateReg(
106                   GR->getSPIRVTypeID(ExtType), false));
107             }
108           }
109         } else {
110           RegsAlreadyAddedToDT[&MI] = Reg;
111           // This MI is unused and will be removed. If the MI uses
112           // const_composite, it will be unused and should be removed too.
113           assert(MI.getOperand(2).isReg() && "Reg operand is expected");
114           MachineInstr *SrcMI = MRI.getVRegDef(MI.getOperand(2).getReg());
115           if (SrcMI && isSpvIntrinsic(*SrcMI, Intrinsic::spv_const_composite))
116             ToEraseComposites.push_back(SrcMI);
117         }
118       }
119     }
120   }
121   for (MachineInstr *MI : ToErase) {
122     Register Reg = MI->getOperand(2).getReg();
123     if (RegsAlreadyAddedToDT.contains(MI))
124       Reg = RegsAlreadyAddedToDT[MI];
125     auto *RC = MRI.getRegClassOrNull(MI->getOperand(0).getReg());
126     if (!MRI.getRegClassOrNull(Reg) && RC)
127       MRI.setRegClass(Reg, RC);
128     MRI.replaceRegWith(MI->getOperand(0).getReg(), Reg);
129     MI->eraseFromParent();
130   }
131   for (MachineInstr *MI : ToEraseComposites)
132     MI->eraseFromParent();
133 }
134 
135 static void
136 foldConstantsIntoIntrinsics(MachineFunction &MF,
137                             const SmallSet<Register, 4> &TrackedConstRegs) {
138   SmallVector<MachineInstr *, 10> ToErase;
139   MachineRegisterInfo &MRI = MF.getRegInfo();
140   const unsigned AssignNameOperandShift = 2;
141   for (MachineBasicBlock &MBB : MF) {
142     for (MachineInstr &MI : MBB) {
143       if (!isSpvIntrinsic(MI, Intrinsic::spv_assign_name))
144         continue;
145       unsigned NumOp = MI.getNumExplicitDefs() + AssignNameOperandShift;
146       while (MI.getOperand(NumOp).isReg()) {
147         MachineOperand &MOp = MI.getOperand(NumOp);
148         MachineInstr *ConstMI = MRI.getVRegDef(MOp.getReg());
149         assert(ConstMI->getOpcode() == TargetOpcode::G_CONSTANT);
150         MI.removeOperand(NumOp);
151         MI.addOperand(MachineOperand::CreateImm(
152             ConstMI->getOperand(1).getCImm()->getZExtValue()));
153         Register DefReg = ConstMI->getOperand(0).getReg();
154         if (MRI.use_empty(DefReg) && !TrackedConstRegs.contains(DefReg))
155           ToErase.push_back(ConstMI);
156       }
157     }
158   }
159   for (MachineInstr *MI : ToErase)
160     MI->eraseFromParent();
161 }
162 
163 static MachineInstr *findAssignTypeInstr(Register Reg,
164                                          MachineRegisterInfo *MRI) {
165   for (MachineRegisterInfo::use_instr_iterator I = MRI->use_instr_begin(Reg),
166                                                IE = MRI->use_instr_end();
167        I != IE; ++I) {
168     MachineInstr *UseMI = &*I;
169     if ((isSpvIntrinsic(*UseMI, Intrinsic::spv_assign_ptr_type) ||
170          isSpvIntrinsic(*UseMI, Intrinsic::spv_assign_type)) &&
171         UseMI->getOperand(1).getReg() == Reg)
172       return UseMI;
173   }
174   return nullptr;
175 }
176 
177 static void buildOpBitcast(SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB,
178                            Register ResVReg, Register OpReg) {
179   SPIRVType *ResType = GR->getSPIRVTypeForVReg(ResVReg);
180   SPIRVType *OpType = GR->getSPIRVTypeForVReg(OpReg);
181   assert(ResType && OpType && "Operand types are expected");
182   if (!GR->isBitcastCompatible(ResType, OpType))
183     report_fatal_error("incompatible result and operand types in a bitcast");
184   MachineRegisterInfo *MRI = MIB.getMRI();
185   if (!MRI->getRegClassOrNull(ResVReg))
186     MRI->setRegClass(ResVReg, GR->getRegClass(ResType));
187   if (ResType == OpType)
188     MIB.buildInstr(TargetOpcode::COPY).addDef(ResVReg).addUse(OpReg);
189   else
190     MIB.buildInstr(SPIRV::OpBitcast)
191         .addDef(ResVReg)
192         .addUse(GR->getSPIRVTypeID(ResType))
193         .addUse(OpReg);
194 }
195 
196 // We do instruction selections early instead of calling MIB.buildBitcast()
197 // generating the general op code G_BITCAST. When MachineVerifier validates
198 // G_BITCAST we see a check of a kind: if Source Type is equal to Destination
199 // Type then report error "bitcast must change the type". This doesn't take into
200 // account the notion of a typed pointer that is important for SPIR-V where a
201 // user may and should use bitcast between pointers with different pointee types
202 // (https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpBitcast).
203 // It's important for correct lowering in SPIR-V, because interpretation of the
204 // data type is not left to instructions that utilize the pointer, but encoded
205 // by the pointer declaration, and the SPIRV target can and must handle the
206 // declaration and use of pointers that specify the type of data they point to.
207 // It's not feasible to improve validation of G_BITCAST using just information
208 // provided by low level types of source and destination. Therefore we don't
209 // produce G_BITCAST as the general op code with semantics different from
210 // OpBitcast, but rather lower to OpBitcast immediately. As for now, the only
211 // difference would be that CombinerHelper couldn't transform known patterns
212 // around G_BUILD_VECTOR. See discussion
213 // in https://github.com/llvm/llvm-project/pull/110270 for even more context.
214 static void selectOpBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
215                              MachineIRBuilder MIB) {
216   SmallVector<MachineInstr *, 16> ToErase;
217   for (MachineBasicBlock &MBB : MF) {
218     for (MachineInstr &MI : MBB) {
219       if (MI.getOpcode() != TargetOpcode::G_BITCAST)
220         continue;
221       MIB.setInsertPt(*MI.getParent(), MI);
222       buildOpBitcast(GR, MIB, MI.getOperand(0).getReg(),
223                      MI.getOperand(1).getReg());
224       ToErase.push_back(&MI);
225     }
226   }
227   for (MachineInstr *MI : ToErase)
228     MI->eraseFromParent();
229 }
230 
231 static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
232                            MachineIRBuilder MIB) {
233   // Get access to information about available extensions
234   const SPIRVSubtarget *ST =
235       static_cast<const SPIRVSubtarget *>(&MIB.getMF().getSubtarget());
236   SmallVector<MachineInstr *, 10> ToErase;
237   for (MachineBasicBlock &MBB : MF) {
238     for (MachineInstr &MI : MBB) {
239       if (!isSpvIntrinsic(MI, Intrinsic::spv_bitcast) &&
240           !isSpvIntrinsic(MI, Intrinsic::spv_ptrcast))
241         continue;
242       assert(MI.getOperand(2).isReg());
243       MIB.setInsertPt(*MI.getParent(), MI);
244       ToErase.push_back(&MI);
245       if (isSpvIntrinsic(MI, Intrinsic::spv_bitcast)) {
246         MIB.buildBitcast(MI.getOperand(0).getReg(), MI.getOperand(2).getReg());
247         continue;
248       }
249       Register Def = MI.getOperand(0).getReg();
250       Register Source = MI.getOperand(2).getReg();
251       Type *ElemTy = getMDOperandAsType(MI.getOperand(3).getMetadata(), 0);
252       SPIRVType *BaseTy = GR->getOrCreateSPIRVType(ElemTy, MIB);
253       SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
254           BaseTy, MI, *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(),
255           addressSpaceToStorageClass(MI.getOperand(4).getImm(), *ST));
256 
257       // If the ptrcast would be redundant, replace all uses with the source
258       // register.
259       MachineRegisterInfo *MRI = MIB.getMRI();
260       if (GR->getSPIRVTypeForVReg(Source) == AssignedPtrType) {
261         // Erase Def's assign type instruction if we are going to replace Def.
262         if (MachineInstr *AssignMI = findAssignTypeInstr(Def, MRI))
263           ToErase.push_back(AssignMI);
264         MRI->replaceRegWith(Def, Source);
265       } else {
266         GR->assignSPIRVTypeToVReg(AssignedPtrType, Def, MF);
267         MIB.buildBitcast(Def, Source);
268       }
269     }
270   }
271   for (MachineInstr *MI : ToErase)
272     MI->eraseFromParent();
273 }
274 
275 // Translating GV, IRTranslator sometimes generates following IR:
276 //   %1 = G_GLOBAL_VALUE
277 //   %2 = COPY %1
278 //   %3 = G_ADDRSPACE_CAST %2
279 //
280 // or
281 //
282 //  %1 = G_ZEXT %2
283 //  G_MEMCPY ... %2 ...
284 //
285 // New registers have no SPIRVType and no register class info.
286 //
287 // Set SPIRVType for GV, propagate it from GV to other instructions,
288 // also set register classes.
289 static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
290                                      MachineRegisterInfo &MRI,
291                                      MachineIRBuilder &MIB) {
292   SPIRVType *SpvType = nullptr;
293   assert(MI && "Machine instr is expected");
294   if (MI->getOperand(0).isReg()) {
295     Register Reg = MI->getOperand(0).getReg();
296     SpvType = GR->getSPIRVTypeForVReg(Reg);
297     if (!SpvType) {
298       switch (MI->getOpcode()) {
299       case TargetOpcode::G_FCONSTANT:
300       case TargetOpcode::G_CONSTANT: {
301         MIB.setInsertPt(*MI->getParent(), MI);
302         Type *Ty = MI->getOperand(1).getCImm()->getType();
303         SpvType = GR->getOrCreateSPIRVType(Ty, MIB);
304         break;
305       }
306       case TargetOpcode::G_GLOBAL_VALUE: {
307         MIB.setInsertPt(*MI->getParent(), MI);
308         const GlobalValue *Global = MI->getOperand(1).getGlobal();
309         Type *ElementTy = toTypedPointer(GR->getDeducedGlobalValueType(Global));
310         auto *Ty = TypedPointerType::get(ElementTy,
311                                          Global->getType()->getAddressSpace());
312         SpvType = GR->getOrCreateSPIRVType(Ty, MIB);
313         break;
314       }
315       case TargetOpcode::G_ANYEXT:
316       case TargetOpcode::G_SEXT:
317       case TargetOpcode::G_ZEXT: {
318         if (MI->getOperand(1).isReg()) {
319           if (MachineInstr *DefInstr =
320                   MRI.getVRegDef(MI->getOperand(1).getReg())) {
321             if (SPIRVType *Def = propagateSPIRVType(DefInstr, GR, MRI, MIB)) {
322               unsigned CurrentBW = GR->getScalarOrVectorBitWidth(Def);
323               unsigned ExpectedBW =
324                   std::max(MRI.getType(Reg).getScalarSizeInBits(), CurrentBW);
325               unsigned NumElements = GR->getScalarOrVectorComponentCount(Def);
326               SpvType = GR->getOrCreateSPIRVIntegerType(ExpectedBW, MIB);
327               if (NumElements > 1)
328                 SpvType =
329                     GR->getOrCreateSPIRVVectorType(SpvType, NumElements, MIB);
330             }
331           }
332         }
333         break;
334       }
335       case TargetOpcode::G_PTRTOINT:
336         SpvType = GR->getOrCreateSPIRVIntegerType(
337             MRI.getType(Reg).getScalarSizeInBits(), MIB);
338         break;
339       case TargetOpcode::G_TRUNC:
340       case TargetOpcode::G_ADDRSPACE_CAST:
341       case TargetOpcode::G_PTR_ADD:
342       case TargetOpcode::COPY: {
343         MachineOperand &Op = MI->getOperand(1);
344         MachineInstr *Def = Op.isReg() ? MRI.getVRegDef(Op.getReg()) : nullptr;
345         if (Def)
346           SpvType = propagateSPIRVType(Def, GR, MRI, MIB);
347         break;
348       }
349       default:
350         break;
351       }
352       if (SpvType) {
353         // check if the address space needs correction
354         LLT RegType = MRI.getType(Reg);
355         if (SpvType->getOpcode() == SPIRV::OpTypePointer &&
356             RegType.isPointer() &&
357             storageClassToAddressSpace(GR->getPointerStorageClass(SpvType)) !=
358                 RegType.getAddressSpace()) {
359           const SPIRVSubtarget &ST =
360               MI->getParent()->getParent()->getSubtarget<SPIRVSubtarget>();
361           SpvType = GR->getOrCreateSPIRVPointerType(
362               GR->getPointeeType(SpvType), *MI, *ST.getInstrInfo(),
363               addressSpaceToStorageClass(RegType.getAddressSpace(), ST));
364         }
365         GR->assignSPIRVTypeToVReg(SpvType, Reg, MIB.getMF());
366       }
367       if (!MRI.getRegClassOrNull(Reg))
368         MRI.setRegClass(Reg, SpvType ? GR->getRegClass(SpvType)
369                                      : &SPIRV::iIDRegClass);
370     }
371   }
372   return SpvType;
373 }
374 
375 // To support current approach and limitations wrt. bit width here we widen a
376 // scalar register with a bit width greater than 1 to valid sizes and cap it to
377 // 64 width.
378 static void widenScalarLLTNextPow2(Register Reg, MachineRegisterInfo &MRI) {
379   LLT RegType = MRI.getType(Reg);
380   if (!RegType.isScalar())
381     return;
382   unsigned Sz = RegType.getScalarSizeInBits();
383   if (Sz == 1)
384     return;
385   unsigned NewSz = std::min(std::max(1u << Log2_32_Ceil(Sz), 8u), 64u);
386   if (NewSz != Sz)
387     MRI.setType(Reg, LLT::scalar(NewSz));
388 }
389 
390 static std::pair<Register, unsigned>
391 createNewIdReg(SPIRVType *SpvType, Register SrcReg, MachineRegisterInfo &MRI,
392                const SPIRVGlobalRegistry &GR) {
393   if (!SpvType)
394     SpvType = GR.getSPIRVTypeForVReg(SrcReg);
395   const TargetRegisterClass *RC = GR.getRegClass(SpvType);
396   Register Reg = MRI.createGenericVirtualRegister(GR.getRegType(SpvType));
397   MRI.setRegClass(Reg, RC);
398   unsigned GetIdOp = SPIRV::GET_ID;
399   if (RC == &SPIRV::fIDRegClass)
400     GetIdOp = SPIRV::GET_fID;
401   else if (RC == &SPIRV::pIDRegClass)
402     GetIdOp = SPIRV::GET_pID;
403   else if (RC == &SPIRV::vfIDRegClass)
404     GetIdOp = SPIRV::GET_vfID;
405   else if (RC == &SPIRV::vpIDRegClass)
406     GetIdOp = SPIRV::GET_vpID;
407   else if (RC == &SPIRV::vIDRegClass)
408     GetIdOp = SPIRV::GET_vID;
409   return {Reg, GetIdOp};
410 }
411 
412 static void setInsertPtAfterDef(MachineIRBuilder &MIB, MachineInstr *Def) {
413   MachineBasicBlock &MBB = *Def->getParent();
414   MachineBasicBlock::iterator DefIt =
415       Def->getNextNode() ? Def->getNextNode()->getIterator() : MBB.end();
416   // Skip all the PHI and debug instructions.
417   while (DefIt != MBB.end() &&
418          (DefIt->isPHI() || DefIt->isDebugOrPseudoInstr()))
419     DefIt = std::next(DefIt);
420   MIB.setInsertPt(MBB, DefIt);
421 }
422 
// Inserts an ASSIGN_TYPE instruction between Reg and its definition.
// The instruction defining Reg is rewired to define a fresh register NewReg
// instead, and "Reg = ASSIGN_TYPE NewReg, <type id>" is emitted right after
// that definition (past any PHI, debug or pseudo instructions following it).
// The SPIR-V type used is SpvType when provided, otherwise one created from
// Ty; it is assigned to both Reg and NewReg. If Reg already has a register
// class, NewReg gets the same class; otherwise both registers get the class
// of the SPIR-V type. The definition's MI flags are copied onto the
// ASSIGN_TYPE. Returns NewReg.
// It's used also in SPIRVBuiltins.cpp.
// TODO: maybe move to SPIRVUtils.
namespace llvm {
Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpvType,
                           SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB,
                           MachineRegisterInfo &MRI) {
  assert((Ty || SpvType) && "Either LLVM or SPIRV type is expected.");
  MachineInstr *Def = MRI.getVRegDef(Reg);
  // The builder is positioned before the type is requested below. The same
  // builder is handed to getOrCreateSPIRVType, so presumably any new type
  // instruction lands here too (NOTE(review): confirm).
  setInsertPtAfterDef(MIB, Def);
  SpvType = SpvType ? SpvType : GR->getOrCreateSPIRVType(Ty, MIB);
  Register NewReg = MRI.createGenericVirtualRegister(MRI.getType(Reg));
  if (auto *RC = MRI.getRegClassOrNull(Reg)) {
    MRI.setRegClass(NewReg, RC);
  } else {
    auto RegClass = GR->getRegClass(SpvType);
    MRI.setRegClass(NewReg, RegClass);
    MRI.setRegClass(Reg, RegClass);
  }
  GR->assignSPIRVTypeToVReg(SpvType, Reg, MIB.getMF());
  // This is to make it convenient for Legalizer to get the SPIRVType
  // when processing the actual MI (i.e. not pseudo one).
  GR->assignSPIRVTypeToVReg(SpvType, NewReg, MIB.getMF());
  // Copy MIFlags from Def to ASSIGN_TYPE instruction. It's required to keep
  // the flags after instruction selection.
  const uint32_t Flags = Def->getFlags();
  MIB.buildInstr(SPIRV::ASSIGN_TYPE)
      .addDef(Reg)
      .addUse(NewReg)
      .addUse(GR->getSPIRVTypeID(SpvType))
      .setMIFlags(Flags);
  // Rewire the first def operand of Def that writes Reg so it writes NewReg;
  // afterwards Reg is defined by the ASSIGN_TYPE built above.
  for (unsigned I = 0, E = Def->getNumDefs(); I != E; ++I) {
    MachineOperand &MO = Def->getOperand(I);
    if (MO.getReg() == Reg) {
      MO.setReg(NewReg);
      break;
    }
  }
  return NewReg;
}

// Prepares MI for type folding. For every register use of MI, this:
//  * emits, right before MI, a GET_* pseudo (chosen by createNewIdReg) that
//    takes the used register and defines a fresh ID register,
//  * gives the used register the class of its SPIR-V type if it differs,
//  * rewires MI's operand to the fresh ID register.
void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
                  MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR) {
  MIB.setInsertPt(*MI.getParent(), MI.getIterator());
  for (auto &Op : MI.operands()) {
    if (!Op.isReg() || Op.isDef())
      continue;
    Register OpReg = Op.getReg();
    SPIRVType *SpvType = GR->getSPIRVTypeForVReg(OpReg);
    auto IdOpInfo = createNewIdReg(SpvType, OpReg, MRI, *GR);
    MIB.buildInstr(IdOpInfo.second).addDef(IdOpInfo.first).addUse(OpReg);
    const TargetRegisterClass *RC = GR->getRegClass(SpvType);
    if (RC != MRI.getRegClassOrNull(OpReg))
      MRI.setRegClass(OpReg, RC);
    Op.setReg(IdOpInfo.first);
  }
}
} // namespace llvm
483 
484 static void
485 generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
486                      MachineIRBuilder MIB,
487                      DenseMap<MachineInstr *, Type *> &TargetExtConstTypes) {
488   // Get access to information about available extensions
489   const SPIRVSubtarget *ST =
490       static_cast<const SPIRVSubtarget *>(&MIB.getMF().getSubtarget());
491 
492   MachineRegisterInfo &MRI = MF.getRegInfo();
493   SmallVector<MachineInstr *, 10> ToErase;
494   DenseMap<MachineInstr *, Register> RegsAlreadyAddedToDT;
495 
496   bool IsExtendedInts =
497       ST->canUseExtension(
498           SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) ||
499       ST->canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions);
500 
501   for (MachineBasicBlock *MBB : post_order(&MF)) {
502     if (MBB->empty())
503       continue;
504 
505     bool ReachedBegin = false;
506     for (auto MII = std::prev(MBB->end()), Begin = MBB->begin();
507          !ReachedBegin;) {
508       MachineInstr &MI = *MII;
509       unsigned MIOp = MI.getOpcode();
510 
511       if (!IsExtendedInts) {
512         // validate bit width of scalar registers
513         for (const auto &MOP : MI.operands())
514           if (MOP.isReg())
515             widenScalarLLTNextPow2(MOP.getReg(), MRI);
516       }
517 
518       if (isSpvIntrinsic(MI, Intrinsic::spv_assign_ptr_type)) {
519         Register Reg = MI.getOperand(1).getReg();
520         MIB.setInsertPt(*MI.getParent(), MI.getIterator());
521         Type *ElementTy = getMDOperandAsType(MI.getOperand(2).getMetadata(), 0);
522         SPIRVType *BaseTy = GR->getOrCreateSPIRVType(ElementTy, MIB);
523         SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
524             BaseTy, MI, *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(),
525             addressSpaceToStorageClass(MI.getOperand(3).getImm(), *ST));
526         MachineInstr *Def = MRI.getVRegDef(Reg);
527         assert(Def && "Expecting an instruction that defines the register");
528         // G_GLOBAL_VALUE already has type info.
529         if (Def->getOpcode() != TargetOpcode::G_GLOBAL_VALUE &&
530             Def->getOpcode() != SPIRV::ASSIGN_TYPE)
531           insertAssignInstr(Reg, nullptr, AssignedPtrType, GR, MIB,
532                             MF.getRegInfo());
533         ToErase.push_back(&MI);
534       } else if (isSpvIntrinsic(MI, Intrinsic::spv_assign_type)) {
535         Register Reg = MI.getOperand(1).getReg();
536         Type *Ty = getMDOperandAsType(MI.getOperand(2).getMetadata(), 0);
537         MachineInstr *Def = MRI.getVRegDef(Reg);
538         assert(Def && "Expecting an instruction that defines the register");
539         // G_GLOBAL_VALUE already has type info.
540         if (Def->getOpcode() != TargetOpcode::G_GLOBAL_VALUE &&
541             Def->getOpcode() != SPIRV::ASSIGN_TYPE)
542           insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MF.getRegInfo());
543         ToErase.push_back(&MI);
544       } else if (MIOp == TargetOpcode::FAKE_USE && MI.getNumOperands() > 0) {
545         MachineInstr *MdMI = MI.getPrevNode();
546         if (MdMI && isSpvIntrinsic(*MdMI, Intrinsic::spv_value_md)) {
547           // It's an internal service info from before IRTranslator passes.
548           MachineInstr *Def = getVRegDef(MRI, MI.getOperand(0).getReg());
549           for (unsigned I = 1, E = MI.getNumOperands(); I != E && Def; ++I)
550             if (getVRegDef(MRI, MI.getOperand(I).getReg()) != Def)
551               Def = nullptr;
552           if (Def) {
553             const MDNode *MD = MdMI->getOperand(1).getMetadata();
554             StringRef ValueName =
555                 cast<MDString>(MD->getOperand(1))->getString();
556             const MDNode *TypeMD = cast<MDNode>(MD->getOperand(0));
557             Type *ValueTy = getMDOperandAsType(TypeMD, 0);
558             GR->addValueAttrs(Def, std::make_pair(ValueTy, ValueName.str()));
559           }
560           ToErase.push_back(MdMI);
561         }
562         ToErase.push_back(&MI);
563       } else if (MIOp == TargetOpcode::G_CONSTANT ||
564                  MIOp == TargetOpcode::G_FCONSTANT ||
565                  MIOp == TargetOpcode::G_BUILD_VECTOR) {
566         // %rc = G_CONSTANT ty Val
567         // ===>
568         // %cty = OpType* ty
569         // %rctmp = G_CONSTANT ty Val
570         // %rc = ASSIGN_TYPE %rctmp, %cty
571         Register Reg = MI.getOperand(0).getReg();
572         bool NeedAssignType = true;
573         if (MRI.hasOneUse(Reg)) {
574           MachineInstr &UseMI = *MRI.use_instr_begin(Reg);
575           if (isSpvIntrinsic(UseMI, Intrinsic::spv_assign_type) ||
576               isSpvIntrinsic(UseMI, Intrinsic::spv_assign_name))
577             continue;
578           if (UseMI.getOpcode() == SPIRV::ASSIGN_TYPE)
579             NeedAssignType = false;
580         }
581         Type *Ty = nullptr;
582         if (MIOp == TargetOpcode::G_CONSTANT) {
583           auto TargetExtIt = TargetExtConstTypes.find(&MI);
584           Ty = TargetExtIt == TargetExtConstTypes.end()
585                    ? MI.getOperand(1).getCImm()->getType()
586                    : TargetExtIt->second;
587           const ConstantInt *OpCI = MI.getOperand(1).getCImm();
588           // TODO: we may wish to analyze here if OpCI is zero and LLT RegType =
589           // MRI.getType(Reg); RegType.isPointer() is true, so that we observe
590           // at this point not i64/i32 constant but null pointer in the
591           // corresponding address space of RegType.getAddressSpace(). This may
592           // help to successfully validate the case when a OpConstantComposite's
593           // constituent has type that does not match Result Type of
594           // OpConstantComposite (see, for example,
595           // pointers/PtrCast-null-in-OpSpecConstantOp.ll).
596           Register PrimaryReg = GR->find(OpCI, &MF);
597           if (!PrimaryReg.isValid()) {
598             GR->add(OpCI, &MF, Reg);
599           } else if (PrimaryReg != Reg &&
600                      MRI.getType(Reg) == MRI.getType(PrimaryReg)) {
601             auto *RCReg = MRI.getRegClassOrNull(Reg);
602             auto *RCPrimary = MRI.getRegClassOrNull(PrimaryReg);
603             if (!RCReg || RCPrimary == RCReg) {
604               RegsAlreadyAddedToDT[&MI] = PrimaryReg;
605               ToErase.push_back(&MI);
606               NeedAssignType = false;
607             }
608           }
609         } else if (MIOp == TargetOpcode::G_FCONSTANT) {
610           Ty = MI.getOperand(1).getFPImm()->getType();
611         } else {
612           assert(MIOp == TargetOpcode::G_BUILD_VECTOR);
613           Type *ElemTy = nullptr;
614           MachineInstr *ElemMI = MRI.getVRegDef(MI.getOperand(1).getReg());
615           assert(ElemMI);
616 
617           if (ElemMI->getOpcode() == TargetOpcode::G_CONSTANT) {
618             ElemTy = ElemMI->getOperand(1).getCImm()->getType();
619           } else if (ElemMI->getOpcode() == TargetOpcode::G_FCONSTANT) {
620             ElemTy = ElemMI->getOperand(1).getFPImm()->getType();
621           } else {
622             // There may be a case when we already know Reg's type.
623             MachineInstr *NextMI = MI.getNextNode();
624             if (!NextMI || NextMI->getOpcode() != SPIRV::ASSIGN_TYPE ||
625                 NextMI->getOperand(1).getReg() != Reg)
626               llvm_unreachable("Unexpected opcode");
627           }
628           if (ElemTy)
629             Ty = VectorType::get(
630                 ElemTy, MI.getNumExplicitOperands() - MI.getNumExplicitDefs(),
631                 false);
632           else
633             NeedAssignType = false;
634         }
635         if (NeedAssignType)
636           insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MRI);
637       } else if (MIOp == TargetOpcode::G_GLOBAL_VALUE) {
638         propagateSPIRVType(&MI, GR, MRI, MIB);
639       }
640 
641       if (MII == Begin)
642         ReachedBegin = true;
643       else
644         --MII;
645     }
646   }
647   for (MachineInstr *MI : ToErase) {
648     auto It = RegsAlreadyAddedToDT.find(MI);
649     if (RegsAlreadyAddedToDT.contains(MI))
650       MRI.replaceRegWith(MI->getOperand(0).getReg(), It->second);
651     MI->eraseFromParent();
652   }
653 
654   // Address the case when IRTranslator introduces instructions with new
655   // registers without SPIRVType associated.
656   for (MachineBasicBlock &MBB : MF) {
657     for (MachineInstr &MI : MBB) {
658       switch (MI.getOpcode()) {
659       case TargetOpcode::G_TRUNC:
660       case TargetOpcode::G_ANYEXT:
661       case TargetOpcode::G_SEXT:
662       case TargetOpcode::G_ZEXT:
663       case TargetOpcode::G_PTRTOINT:
664       case TargetOpcode::COPY:
665       case TargetOpcode::G_ADDRSPACE_CAST:
666         propagateSPIRVType(&MI, GR, MRI, MIB);
667         break;
668       }
669     }
670   }
671 }
672 
673 // Defined in SPIRVLegalizerInfo.cpp.
674 extern bool isTypeFoldingSupported(unsigned Opcode);
675 
676 static void processInstrsWithTypeFolding(MachineFunction &MF,
677                                          SPIRVGlobalRegistry *GR,
678                                          MachineIRBuilder MIB) {
679   MachineRegisterInfo &MRI = MF.getRegInfo();
680   for (MachineBasicBlock &MBB : MF)
681     for (MachineInstr &MI : MBB)
682       if (isTypeFoldingSupported(MI.getOpcode()))
683         processInstr(MI, MIB, MRI, GR);
684 }
685 
686 static Register
687 collectInlineAsmInstrOperands(MachineInstr *MI,
688                               SmallVector<unsigned, 4> *Ops = nullptr) {
689   Register DefReg;
690   unsigned StartOp = InlineAsm::MIOp_FirstOperand,
691            AsmDescOp = InlineAsm::MIOp_FirstOperand;
692   for (unsigned Idx = StartOp, MISz = MI->getNumOperands(); Idx != MISz;
693        ++Idx) {
694     const MachineOperand &MO = MI->getOperand(Idx);
695     if (MO.isMetadata())
696       continue;
697     if (Idx == AsmDescOp && MO.isImm()) {
698       // compute the index of the next operand descriptor
699       const InlineAsm::Flag F(MO.getImm());
700       AsmDescOp += 1 + F.getNumOperandRegisters();
701       continue;
702     }
703     if (MO.isReg() && MO.isDef()) {
704       if (!Ops)
705         return MO.getReg();
706       else
707         DefReg = MO.getReg();
708     } else if (Ops) {
709       Ops->push_back(Idx);
710     }
711   }
712   return DefReg;
713 }
714 
/// Lowers inline assembly to SPV_INTEL_inline_assembly instructions.
///
/// \p ToProcess is consumed in pairs (ToProcess[i], ToProcess[i + 1]): the
/// first element must be a spv_inline_asm intrinsic and the second an
/// inline-asm instruction (asserted per pair). For the first pair, a single
/// OpAsmTargetINTEL carrying the target triple string is emitted and shared by
/// all pairs. Then, for every pair:
///   - SPIR-V types for the asm's function signature are created from the
///     FunctionType stored in the intrinsic's operand-1 metadata;
///   - an OpAsmINTEL is built with the asm string (taken from the INLINEASM)
///     and the constraint string (taken from the intrinsic's operand-2
///     metadata);
///   - an OpDecorate SideEffectsINTEL is added when the INLINEASM's extra-info
///     operand has Extra_HasSideEffects set;
///   - an OpAsmCallINTEL is built whose arguments are the intrinsic's register
///     operands from index 3 onward.
/// All new instructions are inserted before the INLINEASM of the pair, and
/// every instruction in \p ToProcess is erased at the end.
///
/// NOTE(review): if ToProcess.size() is odd, the trailing element is not
/// lowered by the loop but is still erased by the final loop.
static void
insertInlineAsmProcess(MachineFunction &MF, SPIRVGlobalRegistry *GR,
                       const SPIRVSubtarget &ST, MachineIRBuilder MIRBuilder,
                       const SmallVector<MachineInstr *> &ToProcess) {
  MachineRegisterInfo &MRI = MF.getRegInfo();
  // Created on the first pair, then reused by every OpAsmINTEL.
  Register AsmTargetReg;
  for (unsigned i = 0, Sz = ToProcess.size(); i + 1 < Sz; i += 2) {
    // I1: spv_inline_asm intrinsic; I2: the matching INLINEASM instruction.
    MachineInstr *I1 = ToProcess[i], *I2 = ToProcess[i + 1];
    assert(isSpvIntrinsic(*I1, Intrinsic::spv_inline_asm) && I2->isInlineAsm());
    // Everything built for this pair is placed right before the INLINEASM.
    MIRBuilder.setInsertPt(*I2->getParent(), *I2);

    if (!AsmTargetReg.isValid()) {
      // define vendor specific assembly target or dialect
      AsmTargetReg = MRI.createGenericVirtualRegister(LLT::scalar(32));
      MRI.setRegClass(AsmTargetReg, &SPIRV::iIDRegClass);
      auto AsmTargetMIB =
          MIRBuilder.buildInstr(SPIRV::OpAsmTargetINTEL).addDef(AsmTargetReg);
      addStringImm(ST.getTargetTripleAsStr(), AsmTargetMIB);
      GR->add(AsmTargetMIB.getInstr(), &MF, AsmTargetReg);
    }

    // create types: argument, return and function types from the FunctionType
    // recorded in the intrinsic's operand-1 metadata
    const MDNode *IAMD = I1->getOperand(1).getMetadata();
    FunctionType *FTy = cast<FunctionType>(getMDOperandAsType(IAMD, 0));
    SmallVector<SPIRVType *, 4> ArgTypes;
    for (const auto &ArgTy : FTy->params())
      ArgTypes.push_back(GR->getOrCreateSPIRVType(ArgTy, MIRBuilder));
    SPIRVType *RetType =
        GR->getOrCreateSPIRVType(FTy->getReturnType(), MIRBuilder);
    SPIRVType *FuncType = GR->getOrCreateOpTypeFunctionWithArgs(
        FTy, RetType, ArgTypes, MIRBuilder);

    // define vendor specific assembly instructions string
    Register AsmReg = MRI.createGenericVirtualRegister(LLT::scalar(32));
    MRI.setRegClass(AsmReg, &SPIRV::iIDRegClass);
    auto AsmMIB = MIRBuilder.buildInstr(SPIRV::OpAsmINTEL)
                      .addDef(AsmReg)
                      .addUse(GR->getSPIRVTypeID(RetType))
                      .addUse(GR->getSPIRVTypeID(FuncType))
                      .addUse(AsmTargetReg);
    // inline asm string:
    addStringImm(I2->getOperand(InlineAsm::MIOp_AsmString).getSymbolName(),
                 AsmMIB);
    // inline asm constraint string:
    addStringImm(cast<MDString>(I1->getOperand(2).getMetadata()->getOperand(0))
                     ->getString(),
                 AsmMIB);
    GR->add(AsmMIB.getInstr(), &MF, AsmReg);

    // decorate the OpAsmINTEL with SideEffectsINTEL when the INLINEASM's
    // extra-info flags request side effects
    unsigned ExtraInfo = I2->getOperand(InlineAsm::MIOp_ExtraInfo).getImm();
    if (ExtraInfo & InlineAsm::Extra_HasSideEffects)
      MIRBuilder.buildInstr(SPIRV::OpDecorate)
          .addUse(AsmReg)
          .addImm(static_cast<uint32_t>(SPIRV::Decoration::SideEffectsINTEL));

    // Result register of the call: the INLINEASM's first def if it has one,
    // otherwise a fresh vreg assigned the SPIR-V void type.
    Register DefReg = collectInlineAsmInstrOperands(I2);
    if (!DefReg.isValid()) {
      DefReg = MRI.createGenericVirtualRegister(LLT::scalar(32));
      MRI.setRegClass(DefReg, &SPIRV::iIDRegClass);
      SPIRVType *VoidType = GR->getOrCreateSPIRVType(
          Type::getVoidTy(MF.getFunction().getContext()), MIRBuilder);
      GR->assignSPIRVTypeToVReg(VoidType, DefReg, MF);
    }

    // calls the inline assembly instruction; the call arguments are the
    // intrinsic's operands starting at index 3
    auto AsmCall = MIRBuilder.buildInstr(SPIRV::OpAsmCallINTEL)
                       .addDef(DefReg)
                       .addUse(GR->getSPIRVTypeID(RetType))
                       .addUse(AsmReg);
    for (unsigned IntrIdx = 3; IntrIdx < I1->getNumOperands(); ++IntrIdx)
      AsmCall.addUse(I1->getOperand(IntrIdx).getReg());
  }
  // Both the intrinsics and the original INLINEASM instructions are dropped.
  for (MachineInstr *MI : ToProcess)
    MI->eraseFromParent();
}
790 
791 static void insertInlineAsm(MachineFunction &MF, SPIRVGlobalRegistry *GR,
792                             const SPIRVSubtarget &ST,
793                             MachineIRBuilder MIRBuilder) {
794   SmallVector<MachineInstr *> ToProcess;
795   for (MachineBasicBlock &MBB : MF) {
796     for (MachineInstr &MI : MBB) {
797       if (isSpvIntrinsic(MI, Intrinsic::spv_inline_asm) ||
798           MI.getOpcode() == TargetOpcode::INLINEASM)
799         ToProcess.push_back(&MI);
800     }
801   }
802   if (ToProcess.size() == 0)
803     return;
804 
805   if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_inline_assembly))
806     report_fatal_error("Inline assembly instructions require the "
807                        "following SPIR-V extension: SPV_INTEL_inline_assembly",
808                        false);
809 
810   insertInlineAsmProcess(MF, GR, ST, MIRBuilder, ToProcess);
811 }
812 
813 static void insertSpirvDecorations(MachineFunction &MF, MachineIRBuilder MIB) {
814   SmallVector<MachineInstr *, 10> ToErase;
815   for (MachineBasicBlock &MBB : MF) {
816     for (MachineInstr &MI : MBB) {
817       if (!isSpvIntrinsic(MI, Intrinsic::spv_assign_decoration))
818         continue;
819       MIB.setInsertPt(*MI.getParent(), MI.getNextNode());
820       buildOpSpirvDecorations(MI.getOperand(1).getReg(), MIB,
821                               MI.getOperand(2).getMetadata());
822       ToErase.push_back(&MI);
823     }
824   }
825   for (MachineInstr *MI : ToErase)
826     MI->eraseFromParent();
827 }
828 
// LLVM allows switch cases to be given as registers, while SPIR-V requires
// them to be immediate values. This function replaces each such operand with
// the equivalent immediate constant.
832 static void processSwitchesConstants(MachineFunction &MF,
833                                      SPIRVGlobalRegistry *GR,
834                                      MachineIRBuilder MIB) {
835   MachineRegisterInfo &MRI = MF.getRegInfo();
836   for (MachineBasicBlock &MBB : MF) {
837     for (MachineInstr &MI : MBB) {
838       if (!isSpvIntrinsic(MI, Intrinsic::spv_switch))
839         continue;
840 
841       SmallVector<MachineOperand, 8> NewOperands;
842       NewOperands.push_back(MI.getOperand(0)); // Opcode
843       NewOperands.push_back(MI.getOperand(1)); // Condition
844       NewOperands.push_back(MI.getOperand(2)); // Default
845       for (unsigned i = 3; i < MI.getNumOperands(); i += 2) {
846         Register Reg = MI.getOperand(i).getReg();
847         MachineInstr *ConstInstr = getDefInstrMaybeConstant(Reg, &MRI);
848         NewOperands.push_back(
849             MachineOperand::CreateCImm(ConstInstr->getOperand(1).getCImm()));
850 
851         NewOperands.push_back(MI.getOperand(i + 1));
852       }
853 
854       assert(MI.getNumOperands() == NewOperands.size());
855       while (MI.getNumOperands() > 0)
856         MI.removeOperand(0);
857       for (auto &MO : NewOperands)
858         MI.addOperand(MO);
859     }
860   }
861 }
862 
// Some instructions are used during CodeGen but must never be emitted; this
// function removes them.
865 static void cleanupHelperInstructions(MachineFunction &MF) {
866   SmallVector<MachineInstr *, 8> ToEraseMI;
867   for (MachineBasicBlock &MBB : MF) {
868     for (MachineInstr &MI : MBB) {
869       if (isSpvIntrinsic(MI, Intrinsic::spv_track_constant) ||
870           MI.getOpcode() == TargetOpcode::G_BRINDIRECT)
871         ToEraseMI.push_back(&MI);
872     }
873   }
874 
875   for (MachineInstr *MI : ToEraseMI)
876     MI->eraseFromParent();
877 }
878 
// Find all uses of G_BLOCK_ADDR in our intrinsics and replace those
// operands/registers with the actual MBBs they reference.
881 static void processBlockAddr(MachineFunction &MF, SPIRVGlobalRegistry *GR,
882                              MachineIRBuilder MIB) {
883   // Gather the reverse-mapping BB -> MBB.
884   DenseMap<const BasicBlock *, MachineBasicBlock *> BB2MBB;
885   for (MachineBasicBlock &MBB : MF)
886     BB2MBB[MBB.getBasicBlock()] = &MBB;
887 
888   // Gather instructions requiring patching. For now, only those can use
889   // G_BLOCK_ADDR.
890   SmallVector<MachineInstr *, 8> InstructionsToPatch;
891   for (MachineBasicBlock &MBB : MF) {
892     for (MachineInstr &MI : MBB) {
893       if (isSpvIntrinsic(MI, Intrinsic::spv_switch) ||
894           isSpvIntrinsic(MI, Intrinsic::spv_loop_merge) ||
895           isSpvIntrinsic(MI, Intrinsic::spv_selection_merge))
896         InstructionsToPatch.push_back(&MI);
897     }
898   }
899 
900   // For each instruction to fix, we replace all the G_BLOCK_ADDR operands by
901   // the actual MBB it references. Once those references have been updated, we
902   // can cleanup remaining G_BLOCK_ADDR references.
903   SmallPtrSet<MachineBasicBlock *, 8> ClearAddressTaken;
904   SmallPtrSet<MachineInstr *, 8> ToEraseMI;
905   MachineRegisterInfo &MRI = MF.getRegInfo();
906   for (MachineInstr *MI : InstructionsToPatch) {
907     SmallVector<MachineOperand, 8> NewOps;
908     for (unsigned i = 0; i < MI->getNumOperands(); ++i) {
909       // The operand is not a register, keep as-is.
910       if (!MI->getOperand(i).isReg()) {
911         NewOps.push_back(MI->getOperand(i));
912         continue;
913       }
914 
915       Register Reg = MI->getOperand(i).getReg();
916       MachineInstr *BuildMBB = MRI.getVRegDef(Reg);
917       // The register is not the result of G_BLOCK_ADDR, keep as-is.
918       if (!BuildMBB || BuildMBB->getOpcode() != TargetOpcode::G_BLOCK_ADDR) {
919         NewOps.push_back(MI->getOperand(i));
920         continue;
921       }
922 
923       assert(BuildMBB && BuildMBB->getOpcode() == TargetOpcode::G_BLOCK_ADDR &&
924              BuildMBB->getOperand(1).isBlockAddress() &&
925              BuildMBB->getOperand(1).getBlockAddress());
926       BasicBlock *BB =
927           BuildMBB->getOperand(1).getBlockAddress()->getBasicBlock();
928       auto It = BB2MBB.find(BB);
929       if (It == BB2MBB.end())
930         report_fatal_error("cannot find a machine basic block by a basic block "
931                            "in a switch statement");
932       MachineBasicBlock *ReferencedBlock = It->second;
933       NewOps.push_back(MachineOperand::CreateMBB(ReferencedBlock));
934 
935       ClearAddressTaken.insert(ReferencedBlock);
936       ToEraseMI.insert(BuildMBB);
937     }
938 
939     // Replace the operands.
940     assert(MI->getNumOperands() == NewOps.size());
941     while (MI->getNumOperands() > 0)
942       MI->removeOperand(0);
943     for (auto &MO : NewOps)
944       MI->addOperand(MO);
945 
946     if (MachineInstr *Next = MI->getNextNode()) {
947       if (isSpvIntrinsic(*Next, Intrinsic::spv_track_constant)) {
948         ToEraseMI.insert(Next);
949         Next = MI->getNextNode();
950       }
951       if (Next && Next->getOpcode() == TargetOpcode::G_BRINDIRECT)
952         ToEraseMI.insert(Next);
953     }
954   }
955 
956   // BlockAddress operands were used to keep information between passes,
957   // let's undo the "address taken" status to reflect that Succ doesn't
958   // actually correspond to an IR-level basic block.
959   for (MachineBasicBlock *Succ : ClearAddressTaken)
960     Succ->setAddressTakenIRBlock(nullptr);
961 
962   // If we just delete G_BLOCK_ADDR instructions with BlockAddress operands,
963   // this leaves their BasicBlock counterparts in a "address taken" status. This
964   // would make AsmPrinter to generate a series of unneeded labels of a "Address
965   // of block that was removed by CodeGen" kind. Let's first ensure that we
966   // don't have a dangling BlockAddress constants by zapping the BlockAddress
967   // nodes, and only after that proceed with erasing G_BLOCK_ADDR instructions.
968   Constant *Replacement =
969       ConstantInt::get(Type::getInt32Ty(MF.getFunction().getContext()), 1);
970   for (MachineInstr *BlockAddrI : ToEraseMI) {
971     if (BlockAddrI->getOpcode() == TargetOpcode::G_BLOCK_ADDR) {
972       BlockAddress *BA = const_cast<BlockAddress *>(
973           BlockAddrI->getOperand(1).getBlockAddress());
974       BA->replaceAllUsesWith(
975           ConstantExpr::getIntToPtr(Replacement, BA->getType()));
976       BA->destroyConstant();
977     }
978     BlockAddrI->eraseFromParent();
979   }
980 }
981 
982 static bool isImplicitFallthrough(MachineBasicBlock &MBB) {
983   if (MBB.empty())
984     return true;
985 
986   // Branching SPIR-V intrinsics are not detected by this generic method.
987   // Thus, we can only trust negative result.
988   if (!MBB.canFallThrough())
989     return false;
990 
991   // Otherwise, we must manually check if we have a SPIR-V intrinsic which
992   // prevent an implicit fallthrough.
993   for (MachineBasicBlock::reverse_iterator It = MBB.rbegin(), E = MBB.rend();
994        It != E; ++It) {
995     if (isSpvIntrinsic(*It, Intrinsic::spv_switch))
996       return false;
997   }
998   return true;
999 }
1000 
1001 static void removeImplicitFallthroughs(MachineFunction &MF,
1002                                        MachineIRBuilder MIB) {
1003   // It is valid for MachineBasicBlocks to not finish with a branch instruction.
1004   // In such cases, they will simply fallthrough their immediate successor.
1005   for (MachineBasicBlock &MBB : MF) {
1006     if (!isImplicitFallthrough(MBB))
1007       continue;
1008 
1009     assert(std::distance(MBB.successors().begin(), MBB.successors().end()) ==
1010            1);
1011     MIB.setInsertPt(MBB, MBB.end());
1012     MIB.buildBr(**MBB.successors().begin());
1013   }
1014 }
1015 
1016 bool SPIRVPreLegalizer::runOnMachineFunction(MachineFunction &MF) {
1017   // Initialize the type registry.
1018   const SPIRVSubtarget &ST = MF.getSubtarget<SPIRVSubtarget>();
1019   SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
1020   GR->setCurrentFunc(MF);
1021   MachineIRBuilder MIB(MF);
1022   // a registry of target extension constants
1023   DenseMap<MachineInstr *, Type *> TargetExtConstTypes;
1024   // to keep record of tracked constants
1025   SmallSet<Register, 4> TrackedConstRegs;
1026   addConstantsToTrack(MF, GR, ST, TargetExtConstTypes, TrackedConstRegs);
1027   foldConstantsIntoIntrinsics(MF, TrackedConstRegs);
1028   insertBitcasts(MF, GR, MIB);
1029   generateAssignInstrs(MF, GR, MIB, TargetExtConstTypes);
1030 
1031   processSwitchesConstants(MF, GR, MIB);
1032   processBlockAddr(MF, GR, MIB);
1033   cleanupHelperInstructions(MF);
1034 
1035   processInstrsWithTypeFolding(MF, GR, MIB);
1036   removeImplicitFallthroughs(MF, MIB);
1037   insertSpirvDecorations(MF, MIB);
1038   insertInlineAsm(MF, GR, ST, MIB);
1039   selectOpBitcasts(MF, GR, MIB);
1040 
1041   return true;
1042 }
1043 
1044 INITIALIZE_PASS(SPIRVPreLegalizer, DEBUG_TYPE, "SPIRV pre legalizer", false,
1045                 false)
1046 
1047 char SPIRVPreLegalizer::ID = 0;
1048 
1049 FunctionPass *llvm::createSPIRVPreLegalizerPass() {
1050   return new SPIRVPreLegalizer();
1051 }
1052