xref: /llvm-project/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (revision 1ed65febd996eaa018164e880c87a9e9afc6f68d)
1 //===-- SPIRVPreLegalizer.cpp - prepare IR for legalization -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// The pass prepares IR for legalization: it assigns SPIR-V types to registers
// and removes intrinsics which held these types during IR translation.
11 // Also it processes constants and registers them in GR to avoid duplication.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "SPIRV.h"
16 #include "SPIRVSubtarget.h"
17 #include "SPIRVUtils.h"
18 #include "llvm/ADT/PostOrderIterator.h"
19 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
20 #include "llvm/IR/Attributes.h"
21 #include "llvm/IR/Constants.h"
22 #include "llvm/IR/DebugInfoMetadata.h"
23 #include "llvm/IR/IntrinsicsSPIRV.h"
24 #include "llvm/Target/TargetIntrinsicInfo.h"
25 
26 #define DEBUG_TYPE "spirv-prelegalizer"
27 
28 using namespace llvm;
29 
30 namespace {
31 class SPIRVPreLegalizer : public MachineFunctionPass {
32 public:
33   static char ID;
34   SPIRVPreLegalizer() : MachineFunctionPass(ID) {
35     initializeSPIRVPreLegalizerPass(*PassRegistry::getPassRegistry());
36   }
37   bool runOnMachineFunction(MachineFunction &MF) override;
38 };
39 } // namespace
40 
41 static void
42 addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR,
43                     const SPIRVSubtarget &STI,
44                     DenseMap<MachineInstr *, Type *> &TargetExtConstTypes,
45                     SmallSet<Register, 4> &TrackedConstRegs) {
46   MachineRegisterInfo &MRI = MF.getRegInfo();
47   DenseMap<MachineInstr *, Register> RegsAlreadyAddedToDT;
48   SmallVector<MachineInstr *, 10> ToErase, ToEraseComposites;
49   for (MachineBasicBlock &MBB : MF) {
50     for (MachineInstr &MI : MBB) {
51       if (!isSpvIntrinsic(MI, Intrinsic::spv_track_constant))
52         continue;
53       ToErase.push_back(&MI);
54       Register SrcReg = MI.getOperand(2).getReg();
55       auto *Const =
56           cast<Constant>(cast<ConstantAsMetadata>(
57                              MI.getOperand(3).getMetadata()->getOperand(0))
58                              ->getValue());
59       if (auto *GV = dyn_cast<GlobalValue>(Const)) {
60         Register Reg = GR->find(GV, &MF);
61         if (!Reg.isValid())
62           GR->add(GV, &MF, SrcReg);
63         else
64           RegsAlreadyAddedToDT[&MI] = Reg;
65       } else {
66         Register Reg = GR->find(Const, &MF);
67         if (!Reg.isValid()) {
68           if (auto *ConstVec = dyn_cast<ConstantDataVector>(Const)) {
69             auto *BuildVec = MRI.getVRegDef(SrcReg);
70             assert(BuildVec &&
71                    BuildVec->getOpcode() == TargetOpcode::G_BUILD_VECTOR);
72             for (unsigned i = 0; i < ConstVec->getNumElements(); ++i) {
73               // Ensure that OpConstantComposite reuses a constant when it's
74               // already created and available in the same machine function.
75               Constant *ElemConst = ConstVec->getElementAsConstant(i);
76               Register ElemReg = GR->find(ElemConst, &MF);
77               if (!ElemReg.isValid())
78                 GR->add(ElemConst, &MF, BuildVec->getOperand(1 + i).getReg());
79               else
80                 BuildVec->getOperand(1 + i).setReg(ElemReg);
81             }
82           }
83           GR->add(Const, &MF, SrcReg);
84           TrackedConstRegs.insert(SrcReg);
85           if (Const->getType()->isTargetExtTy()) {
86             // remember association so that we can restore it when assign types
87             MachineInstr *SrcMI = MRI.getVRegDef(SrcReg);
88             if (SrcMI && (SrcMI->getOpcode() == TargetOpcode::G_CONSTANT ||
89                           SrcMI->getOpcode() == TargetOpcode::G_IMPLICIT_DEF))
90               TargetExtConstTypes[SrcMI] = Const->getType();
91             if (Const->isNullValue()) {
92               MachineIRBuilder MIB(MF);
93               SPIRVType *ExtType =
94                   GR->getOrCreateSPIRVType(Const->getType(), MIB);
95               SrcMI->setDesc(STI.getInstrInfo()->get(SPIRV::OpConstantNull));
96               SrcMI->addOperand(MachineOperand::CreateReg(
97                   GR->getSPIRVTypeID(ExtType), false));
98             }
99           }
100         } else {
101           RegsAlreadyAddedToDT[&MI] = Reg;
102           // This MI is unused and will be removed. If the MI uses
103           // const_composite, it will be unused and should be removed too.
104           assert(MI.getOperand(2).isReg() && "Reg operand is expected");
105           MachineInstr *SrcMI = MRI.getVRegDef(MI.getOperand(2).getReg());
106           if (SrcMI && isSpvIntrinsic(*SrcMI, Intrinsic::spv_const_composite))
107             ToEraseComposites.push_back(SrcMI);
108         }
109       }
110     }
111   }
112   for (MachineInstr *MI : ToErase) {
113     Register Reg = MI->getOperand(2).getReg();
114     if (RegsAlreadyAddedToDT.contains(MI))
115       Reg = RegsAlreadyAddedToDT[MI];
116     auto *RC = MRI.getRegClassOrNull(MI->getOperand(0).getReg());
117     if (!MRI.getRegClassOrNull(Reg) && RC)
118       MRI.setRegClass(Reg, RC);
119     MRI.replaceRegWith(MI->getOperand(0).getReg(), Reg);
120     MI->eraseFromParent();
121   }
122   for (MachineInstr *MI : ToEraseComposites)
123     MI->eraseFromParent();
124 }
125 
126 static void
127 foldConstantsIntoIntrinsics(MachineFunction &MF,
128                             const SmallSet<Register, 4> &TrackedConstRegs) {
129   SmallVector<MachineInstr *, 10> ToErase;
130   MachineRegisterInfo &MRI = MF.getRegInfo();
131   const unsigned AssignNameOperandShift = 2;
132   for (MachineBasicBlock &MBB : MF) {
133     for (MachineInstr &MI : MBB) {
134       if (!isSpvIntrinsic(MI, Intrinsic::spv_assign_name))
135         continue;
136       unsigned NumOp = MI.getNumExplicitDefs() + AssignNameOperandShift;
137       while (MI.getOperand(NumOp).isReg()) {
138         MachineOperand &MOp = MI.getOperand(NumOp);
139         MachineInstr *ConstMI = MRI.getVRegDef(MOp.getReg());
140         assert(ConstMI->getOpcode() == TargetOpcode::G_CONSTANT);
141         MI.removeOperand(NumOp);
142         MI.addOperand(MachineOperand::CreateImm(
143             ConstMI->getOperand(1).getCImm()->getZExtValue()));
144         Register DefReg = ConstMI->getOperand(0).getReg();
145         if (MRI.use_empty(DefReg) && !TrackedConstRegs.contains(DefReg))
146           ToErase.push_back(ConstMI);
147       }
148     }
149   }
150   for (MachineInstr *MI : ToErase)
151     MI->eraseFromParent();
152 }
153 
154 static MachineInstr *findAssignTypeInstr(Register Reg,
155                                          MachineRegisterInfo *MRI) {
156   for (MachineRegisterInfo::use_instr_iterator I = MRI->use_instr_begin(Reg),
157                                                IE = MRI->use_instr_end();
158        I != IE; ++I) {
159     MachineInstr *UseMI = &*I;
160     if ((isSpvIntrinsic(*UseMI, Intrinsic::spv_assign_ptr_type) ||
161          isSpvIntrinsic(*UseMI, Intrinsic::spv_assign_type)) &&
162         UseMI->getOperand(1).getReg() == Reg)
163       return UseMI;
164   }
165   return nullptr;
166 }
167 
168 static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
169                            MachineIRBuilder MIB) {
170   // Get access to information about available extensions
171   const SPIRVSubtarget *ST =
172       static_cast<const SPIRVSubtarget *>(&MIB.getMF().getSubtarget());
173   SmallVector<MachineInstr *, 10> ToErase;
174   for (MachineBasicBlock &MBB : MF) {
175     for (MachineInstr &MI : MBB) {
176       if (!isSpvIntrinsic(MI, Intrinsic::spv_bitcast) &&
177           !isSpvIntrinsic(MI, Intrinsic::spv_ptrcast))
178         continue;
179       assert(MI.getOperand(2).isReg());
180       MIB.setInsertPt(*MI.getParent(), MI);
181       ToErase.push_back(&MI);
182       if (isSpvIntrinsic(MI, Intrinsic::spv_bitcast)) {
183         MIB.buildBitcast(MI.getOperand(0).getReg(), MI.getOperand(2).getReg());
184         continue;
185       }
186       Register Def = MI.getOperand(0).getReg();
187       Register Source = MI.getOperand(2).getReg();
188       Type *ElemTy = getMDOperandAsType(MI.getOperand(3).getMetadata(), 0);
189       SPIRVType *BaseTy = GR->getOrCreateSPIRVType(ElemTy, MIB);
190       SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
191           BaseTy, MI, *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(),
192           addressSpaceToStorageClass(MI.getOperand(4).getImm(), *ST));
193 
194       // If the ptrcast would be redundant, replace all uses with the source
195       // register.
196       MachineRegisterInfo *MRI = MIB.getMRI();
197       if (GR->getSPIRVTypeForVReg(Source) == AssignedPtrType) {
198         // Erase Def's assign type instruction if we are going to replace Def.
199         if (MachineInstr *AssignMI = findAssignTypeInstr(Def, MRI))
200           ToErase.push_back(AssignMI);
201         MRI->replaceRegWith(Def, Source);
202       } else {
203         GR->assignSPIRVTypeToVReg(AssignedPtrType, Def, MF);
204         MIB.buildBitcast(Def, Source);
205         // MachineVerifier requires that bitcast must change the type.
206         // Change AddressSpace if needed to hint that Def and Source points to
207         // different types: this doesn't change actual code generation.
208         LLT DefType = MRI->getType(Def);
209         if (DefType == MRI->getType(Source))
210           MRI->setType(Def,
211                        LLT::pointer((DefType.getAddressSpace() + 1) %
212                                         SPIRVSubtarget::MaxLegalAddressSpace,
213                                     GR->getPointerSize()));
214       }
215     }
216   }
217   for (MachineInstr *MI : ToErase)
218     MI->eraseFromParent();
219 }
220 
221 // Translating GV, IRTranslator sometimes generates following IR:
222 //   %1 = G_GLOBAL_VALUE
223 //   %2 = COPY %1
224 //   %3 = G_ADDRSPACE_CAST %2
225 //
226 // or
227 //
228 //  %1 = G_ZEXT %2
229 //  G_MEMCPY ... %2 ...
230 //
231 // New registers have no SPIRVType and no register class info.
232 //
233 // Set SPIRVType for GV, propagate it from GV to other instructions,
234 // also set register classes.
235 static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
236                                      MachineRegisterInfo &MRI,
237                                      MachineIRBuilder &MIB) {
238   SPIRVType *SpvType = nullptr;
239   assert(MI && "Machine instr is expected");
240   if (MI->getOperand(0).isReg()) {
241     Register Reg = MI->getOperand(0).getReg();
242     SpvType = GR->getSPIRVTypeForVReg(Reg);
243     if (!SpvType) {
244       switch (MI->getOpcode()) {
245       case TargetOpcode::G_CONSTANT: {
246         MIB.setInsertPt(*MI->getParent(), MI);
247         Type *Ty = MI->getOperand(1).getCImm()->getType();
248         SpvType = GR->getOrCreateSPIRVType(Ty, MIB);
249         break;
250       }
251       case TargetOpcode::G_GLOBAL_VALUE: {
252         MIB.setInsertPt(*MI->getParent(), MI);
253         const GlobalValue *Global = MI->getOperand(1).getGlobal();
254         Type *ElementTy = toTypedPointer(GR->getDeducedGlobalValueType(Global));
255         auto *Ty = TypedPointerType::get(ElementTy,
256                                          Global->getType()->getAddressSpace());
257         SpvType = GR->getOrCreateSPIRVType(Ty, MIB);
258         break;
259       }
260       case TargetOpcode::G_ANYEXT:
261       case TargetOpcode::G_SEXT:
262       case TargetOpcode::G_ZEXT: {
263         if (MI->getOperand(1).isReg()) {
264           if (MachineInstr *DefInstr =
265                   MRI.getVRegDef(MI->getOperand(1).getReg())) {
266             if (SPIRVType *Def = propagateSPIRVType(DefInstr, GR, MRI, MIB)) {
267               unsigned CurrentBW = GR->getScalarOrVectorBitWidth(Def);
268               unsigned ExpectedBW =
269                   std::max(MRI.getType(Reg).getScalarSizeInBits(), CurrentBW);
270               unsigned NumElements = GR->getScalarOrVectorComponentCount(Def);
271               SpvType = GR->getOrCreateSPIRVIntegerType(ExpectedBW, MIB);
272               if (NumElements > 1)
273                 SpvType =
274                     GR->getOrCreateSPIRVVectorType(SpvType, NumElements, MIB);
275             }
276           }
277         }
278         break;
279       }
280       case TargetOpcode::G_PTRTOINT:
281         SpvType = GR->getOrCreateSPIRVIntegerType(
282             MRI.getType(Reg).getScalarSizeInBits(), MIB);
283         break;
284       case TargetOpcode::G_TRUNC:
285       case TargetOpcode::G_ADDRSPACE_CAST:
286       case TargetOpcode::G_PTR_ADD:
287       case TargetOpcode::COPY: {
288         MachineOperand &Op = MI->getOperand(1);
289         MachineInstr *Def = Op.isReg() ? MRI.getVRegDef(Op.getReg()) : nullptr;
290         if (Def)
291           SpvType = propagateSPIRVType(Def, GR, MRI, MIB);
292         break;
293       }
294       default:
295         break;
296       }
297       if (SpvType)
298         GR->assignSPIRVTypeToVReg(SpvType, Reg, MIB.getMF());
299       if (!MRI.getRegClassOrNull(Reg))
300         MRI.setRegClass(Reg, SpvType ? GR->getRegClass(SpvType)
301                                      : &SPIRV::iIDRegClass);
302     }
303   }
304   return SpvType;
305 }
306 
307 // To support current approach and limitations wrt. bit width here we widen a
308 // scalar register with a bit width greater than 1 to valid sizes and cap it to
309 // 64 width.
310 static void widenScalarLLTNextPow2(Register Reg, MachineRegisterInfo &MRI) {
311   LLT RegType = MRI.getType(Reg);
312   if (!RegType.isScalar())
313     return;
314   unsigned Sz = RegType.getScalarSizeInBits();
315   if (Sz == 1)
316     return;
317   unsigned NewSz = std::min(std::max(1u << Log2_32_Ceil(Sz), 8u), 64u);
318   if (NewSz != Sz)
319     MRI.setType(Reg, LLT::scalar(NewSz));
320 }
321 
322 static std::pair<Register, unsigned>
323 createNewIdReg(SPIRVType *SpvType, Register SrcReg, MachineRegisterInfo &MRI,
324                const SPIRVGlobalRegistry &GR) {
325   if (!SpvType)
326     SpvType = GR.getSPIRVTypeForVReg(SrcReg);
327   const TargetRegisterClass *RC = GR.getRegClass(SpvType);
328   Register Reg = MRI.createGenericVirtualRegister(GR.getRegType(SpvType));
329   MRI.setRegClass(Reg, RC);
330   unsigned GetIdOp = SPIRV::GET_ID;
331   if (RC == &SPIRV::fIDRegClass)
332     GetIdOp = SPIRV::GET_fID;
333   else if (RC == &SPIRV::pIDRegClass)
334     GetIdOp = SPIRV::GET_pID;
335   else if (RC == &SPIRV::vfIDRegClass)
336     GetIdOp = SPIRV::GET_vfID;
337   else if (RC == &SPIRV::vpIDRegClass)
338     GetIdOp = SPIRV::GET_vpID;
339   else if (RC == &SPIRV::vIDRegClass)
340     GetIdOp = SPIRV::GET_vID;
341   return {Reg, GetIdOp};
342 }
343 
344 // Insert ASSIGN_TYPE instuction between Reg and its definition, set NewReg as
345 // a dst of the definition, assign SPIRVType to both registers. If SpvType is
346 // provided, use it as SPIRVType in ASSIGN_TYPE, otherwise create it from Ty.
347 // It's used also in SPIRVBuiltins.cpp.
348 // TODO: maybe move to SPIRVUtils.
349 namespace llvm {
350 Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpvType,
351                            SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB,
352                            MachineRegisterInfo &MRI) {
353   MachineInstr *Def = MRI.getVRegDef(Reg);
354   assert((Ty || SpvType) && "Either LLVM or SPIRV type is expected.");
355   MIB.setInsertPt(*Def->getParent(),
356                   (Def->getNextNode() ? Def->getNextNode()->getIterator()
357                                       : Def->getParent()->end()));
358   SpvType = SpvType ? SpvType : GR->getOrCreateSPIRVType(Ty, MIB);
359   Register NewReg = MRI.createGenericVirtualRegister(MRI.getType(Reg));
360   if (auto *RC = MRI.getRegClassOrNull(Reg)) {
361     MRI.setRegClass(NewReg, RC);
362   } else {
363     auto RegClass = GR->getRegClass(SpvType);
364     MRI.setRegClass(NewReg, RegClass);
365     MRI.setRegClass(Reg, RegClass);
366   }
367   GR->assignSPIRVTypeToVReg(SpvType, Reg, MIB.getMF());
368   // This is to make it convenient for Legalizer to get the SPIRVType
369   // when processing the actual MI (i.e. not pseudo one).
370   GR->assignSPIRVTypeToVReg(SpvType, NewReg, MIB.getMF());
371   // Copy MIFlags from Def to ASSIGN_TYPE instruction. It's required to keep
372   // the flags after instruction selection.
373   const uint32_t Flags = Def->getFlags();
374   MIB.buildInstr(SPIRV::ASSIGN_TYPE)
375       .addDef(Reg)
376       .addUse(NewReg)
377       .addUse(GR->getSPIRVTypeID(SpvType))
378       .setMIFlags(Flags);
379   Def->getOperand(0).setReg(NewReg);
380   return NewReg;
381 }
382 
383 void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
384                   MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR) {
385   assert(MI.getNumDefs() > 0 && MRI.hasOneUse(MI.getOperand(0).getReg()));
386   MachineInstr &AssignTypeInst =
387       *(MRI.use_instr_begin(MI.getOperand(0).getReg()));
388   auto NewReg =
389       createNewIdReg(nullptr, MI.getOperand(0).getReg(), MRI, *GR).first;
390   AssignTypeInst.getOperand(1).setReg(NewReg);
391   MI.getOperand(0).setReg(NewReg);
392   MIB.setInsertPt(*MI.getParent(),
393                   (MI.getNextNode() ? MI.getNextNode()->getIterator()
394                                     : MI.getParent()->end()));
395   for (auto &Op : MI.operands()) {
396     if (!Op.isReg() || Op.isDef())
397       continue;
398     auto IdOpInfo = createNewIdReg(nullptr, Op.getReg(), MRI, *GR);
399     MIB.buildInstr(IdOpInfo.second).addDef(IdOpInfo.first).addUse(Op.getReg());
400     Op.setReg(IdOpInfo.first);
401   }
402 }
403 } // namespace llvm
404 
405 static void
406 generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
407                      MachineIRBuilder MIB,
408                      DenseMap<MachineInstr *, Type *> &TargetExtConstTypes) {
409   // Get access to information about available extensions
410   const SPIRVSubtarget *ST =
411       static_cast<const SPIRVSubtarget *>(&MIB.getMF().getSubtarget());
412 
413   MachineRegisterInfo &MRI = MF.getRegInfo();
414   SmallVector<MachineInstr *, 10> ToErase;
415   DenseMap<MachineInstr *, Register> RegsAlreadyAddedToDT;
416 
417   bool IsExtendedInts =
418       ST->canUseExtension(
419           SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) ||
420       ST->canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions);
421 
422   for (MachineBasicBlock *MBB : post_order(&MF)) {
423     if (MBB->empty())
424       continue;
425 
426     bool ReachedBegin = false;
427     for (auto MII = std::prev(MBB->end()), Begin = MBB->begin();
428          !ReachedBegin;) {
429       MachineInstr &MI = *MII;
430       unsigned MIOp = MI.getOpcode();
431 
432       if (!IsExtendedInts) {
433         // validate bit width of scalar registers
434         for (const auto &MOP : MI.operands())
435           if (MOP.isReg())
436             widenScalarLLTNextPow2(MOP.getReg(), MRI);
437       }
438 
439       if (isSpvIntrinsic(MI, Intrinsic::spv_assign_ptr_type)) {
440         Register Reg = MI.getOperand(1).getReg();
441         MIB.setInsertPt(*MI.getParent(), MI.getIterator());
442         Type *ElementTy = getMDOperandAsType(MI.getOperand(2).getMetadata(), 0);
443         SPIRVType *BaseTy = GR->getOrCreateSPIRVType(ElementTy, MIB);
444         SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
445             BaseTy, MI, *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(),
446             addressSpaceToStorageClass(MI.getOperand(3).getImm(), *ST));
447         MachineInstr *Def = MRI.getVRegDef(Reg);
448         assert(Def && "Expecting an instruction that defines the register");
449         // G_GLOBAL_VALUE already has type info.
450         if (Def->getOpcode() != TargetOpcode::G_GLOBAL_VALUE &&
451             Def->getOpcode() != SPIRV::ASSIGN_TYPE)
452           insertAssignInstr(Reg, nullptr, AssignedPtrType, GR, MIB,
453                             MF.getRegInfo());
454         ToErase.push_back(&MI);
455       } else if (isSpvIntrinsic(MI, Intrinsic::spv_assign_type)) {
456         Register Reg = MI.getOperand(1).getReg();
457         Type *Ty = getMDOperandAsType(MI.getOperand(2).getMetadata(), 0);
458         MachineInstr *Def = MRI.getVRegDef(Reg);
459         assert(Def && "Expecting an instruction that defines the register");
460         // G_GLOBAL_VALUE already has type info.
461         if (Def->getOpcode() != TargetOpcode::G_GLOBAL_VALUE &&
462             Def->getOpcode() != SPIRV::ASSIGN_TYPE)
463           insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MF.getRegInfo());
464         ToErase.push_back(&MI);
465       } else if (MIOp == TargetOpcode::G_CONSTANT ||
466                  MIOp == TargetOpcode::G_FCONSTANT ||
467                  MIOp == TargetOpcode::G_BUILD_VECTOR) {
468         // %rc = G_CONSTANT ty Val
469         // ===>
470         // %cty = OpType* ty
471         // %rctmp = G_CONSTANT ty Val
472         // %rc = ASSIGN_TYPE %rctmp, %cty
473         Register Reg = MI.getOperand(0).getReg();
474         bool NeedAssignType = true;
475         if (MRI.hasOneUse(Reg)) {
476           MachineInstr &UseMI = *MRI.use_instr_begin(Reg);
477           if (isSpvIntrinsic(UseMI, Intrinsic::spv_assign_type) ||
478               isSpvIntrinsic(UseMI, Intrinsic::spv_assign_name))
479             continue;
480           if (UseMI.getOpcode() == SPIRV::ASSIGN_TYPE)
481             NeedAssignType = false;
482         }
483         Type *Ty = nullptr;
484         if (MIOp == TargetOpcode::G_CONSTANT) {
485           auto TargetExtIt = TargetExtConstTypes.find(&MI);
486           Ty = TargetExtIt == TargetExtConstTypes.end()
487                    ? MI.getOperand(1).getCImm()->getType()
488                    : TargetExtIt->second;
489           const ConstantInt *OpCI = MI.getOperand(1).getCImm();
490           Register PrimaryReg = GR->find(OpCI, &MF);
491           if (!PrimaryReg.isValid()) {
492             GR->add(OpCI, &MF, Reg);
493           } else if (PrimaryReg != Reg &&
494                      MRI.getType(Reg) == MRI.getType(PrimaryReg)) {
495             auto *RCReg = MRI.getRegClassOrNull(Reg);
496             auto *RCPrimary = MRI.getRegClassOrNull(PrimaryReg);
497             if (!RCReg || RCPrimary == RCReg) {
498               RegsAlreadyAddedToDT[&MI] = PrimaryReg;
499               ToErase.push_back(&MI);
500               NeedAssignType = false;
501             }
502           }
503         } else if (MIOp == TargetOpcode::G_FCONSTANT) {
504           Ty = MI.getOperand(1).getFPImm()->getType();
505         } else {
506           assert(MIOp == TargetOpcode::G_BUILD_VECTOR);
507           Type *ElemTy = nullptr;
508           MachineInstr *ElemMI = MRI.getVRegDef(MI.getOperand(1).getReg());
509           assert(ElemMI);
510 
511           if (ElemMI->getOpcode() == TargetOpcode::G_CONSTANT) {
512             ElemTy = ElemMI->getOperand(1).getCImm()->getType();
513           } else if (ElemMI->getOpcode() == TargetOpcode::G_FCONSTANT) {
514             ElemTy = ElemMI->getOperand(1).getFPImm()->getType();
515           } else {
516             // There may be a case when we already know Reg's type.
517             MachineInstr *NextMI = MI.getNextNode();
518             if (!NextMI || NextMI->getOpcode() != SPIRV::ASSIGN_TYPE ||
519                 NextMI->getOperand(1).getReg() != Reg)
520               llvm_unreachable("Unexpected opcode");
521           }
522           if (ElemTy)
523             Ty = VectorType::get(
524                 ElemTy, MI.getNumExplicitOperands() - MI.getNumExplicitDefs(),
525                 false);
526           else
527             NeedAssignType = false;
528         }
529         if (NeedAssignType)
530           insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MRI);
531       } else if (MIOp == TargetOpcode::G_GLOBAL_VALUE) {
532         propagateSPIRVType(&MI, GR, MRI, MIB);
533       }
534 
535       if (MII == Begin)
536         ReachedBegin = true;
537       else
538         --MII;
539     }
540   }
541   for (MachineInstr *MI : ToErase) {
542     auto It = RegsAlreadyAddedToDT.find(MI);
543     if (RegsAlreadyAddedToDT.contains(MI))
544       MRI.replaceRegWith(MI->getOperand(0).getReg(), It->second);
545     MI->eraseFromParent();
546   }
547 
548   // Address the case when IRTranslator introduces instructions with new
549   // registers without SPIRVType associated.
550   for (MachineBasicBlock &MBB : MF) {
551     for (MachineInstr &MI : MBB) {
552       switch (MI.getOpcode()) {
553       case TargetOpcode::G_TRUNC:
554       case TargetOpcode::G_ANYEXT:
555       case TargetOpcode::G_SEXT:
556       case TargetOpcode::G_ZEXT:
557       case TargetOpcode::G_PTRTOINT:
558       case TargetOpcode::COPY:
559       case TargetOpcode::G_ADDRSPACE_CAST:
560         propagateSPIRVType(&MI, GR, MRI, MIB);
561         break;
562       }
563     }
564   }
565 }
566 
// Reports whether type folding is supported for the generic opcode Opcode.
// The definition lives in SPIRVLegalizerInfo.cpp.
bool isTypeFoldingSupported(unsigned Opcode);
569 
570 static void processInstrsWithTypeFolding(MachineFunction &MF,
571                                          SPIRVGlobalRegistry *GR,
572                                          MachineIRBuilder MIB) {
573   MachineRegisterInfo &MRI = MF.getRegInfo();
574   for (MachineBasicBlock &MBB : MF) {
575     for (MachineInstr &MI : MBB) {
576       if (isTypeFoldingSupported(MI.getOpcode()))
577         processInstr(MI, MIB, MRI, GR);
578     }
579   }
580 
581   for (MachineBasicBlock &MBB : MF) {
582     for (MachineInstr &MI : MBB) {
583       // We need to rewrite dst types for ASSIGN_TYPE instrs to be able
584       // to perform tblgen'erated selection and we can't do that on Legalizer
585       // as it operates on gMIR only.
586       if (MI.getOpcode() != SPIRV::ASSIGN_TYPE)
587         continue;
588       Register SrcReg = MI.getOperand(1).getReg();
589       unsigned Opcode = MRI.getVRegDef(SrcReg)->getOpcode();
590       if (!isTypeFoldingSupported(Opcode))
591         continue;
592       Register DstReg = MI.getOperand(0).getReg();
593       // Don't need to reset type of register holding constant and used in
594       // G_ADDRSPACE_CAST, since it breaks legalizer.
595       if (Opcode == TargetOpcode::G_CONSTANT && MRI.hasOneUse(DstReg)) {
596         MachineInstr &UseMI = *MRI.use_instr_begin(DstReg);
597         if (UseMI.getOpcode() == TargetOpcode::G_ADDRSPACE_CAST)
598           continue;
599       }
600     }
601   }
602 }
603 
604 static Register
605 collectInlineAsmInstrOperands(MachineInstr *MI,
606                               SmallVector<unsigned, 4> *Ops = nullptr) {
607   Register DefReg;
608   unsigned StartOp = InlineAsm::MIOp_FirstOperand,
609            AsmDescOp = InlineAsm::MIOp_FirstOperand;
610   for (unsigned Idx = StartOp, MISz = MI->getNumOperands(); Idx != MISz;
611        ++Idx) {
612     const MachineOperand &MO = MI->getOperand(Idx);
613     if (MO.isMetadata())
614       continue;
615     if (Idx == AsmDescOp && MO.isImm()) {
616       // compute the index of the next operand descriptor
617       const InlineAsm::Flag F(MO.getImm());
618       AsmDescOp += 1 + F.getNumOperandRegisters();
619       continue;
620     }
621     if (MO.isReg() && MO.isDef()) {
622       if (!Ops)
623         return MO.getReg();
624       else
625         DefReg = MO.getReg();
626     } else if (Ops) {
627       Ops->push_back(Idx);
628     }
629   }
630   return DefReg;
631 }
632 
633 static void
634 insertInlineAsmProcess(MachineFunction &MF, SPIRVGlobalRegistry *GR,
635                        const SPIRVSubtarget &ST, MachineIRBuilder MIRBuilder,
636                        const SmallVector<MachineInstr *> &ToProcess) {
637   MachineRegisterInfo &MRI = MF.getRegInfo();
638   Register AsmTargetReg;
639   for (unsigned i = 0, Sz = ToProcess.size(); i + 1 < Sz; i += 2) {
640     MachineInstr *I1 = ToProcess[i], *I2 = ToProcess[i + 1];
641     assert(isSpvIntrinsic(*I1, Intrinsic::spv_inline_asm) && I2->isInlineAsm());
642     MIRBuilder.setInsertPt(*I2->getParent(), *I2);
643 
644     if (!AsmTargetReg.isValid()) {
645       // define vendor specific assembly target or dialect
646       AsmTargetReg = MRI.createGenericVirtualRegister(LLT::scalar(32));
647       MRI.setRegClass(AsmTargetReg, &SPIRV::iIDRegClass);
648       auto AsmTargetMIB =
649           MIRBuilder.buildInstr(SPIRV::OpAsmTargetINTEL).addDef(AsmTargetReg);
650       addStringImm(ST.getTargetTripleAsStr(), AsmTargetMIB);
651       GR->add(AsmTargetMIB.getInstr(), &MF, AsmTargetReg);
652     }
653 
654     // create types
655     const MDNode *IAMD = I1->getOperand(1).getMetadata();
656     FunctionType *FTy = cast<FunctionType>(getMDOperandAsType(IAMD, 0));
657     SmallVector<SPIRVType *, 4> ArgTypes;
658     for (const auto &ArgTy : FTy->params())
659       ArgTypes.push_back(GR->getOrCreateSPIRVType(ArgTy, MIRBuilder));
660     SPIRVType *RetType =
661         GR->getOrCreateSPIRVType(FTy->getReturnType(), MIRBuilder);
662     SPIRVType *FuncType = GR->getOrCreateOpTypeFunctionWithArgs(
663         FTy, RetType, ArgTypes, MIRBuilder);
664 
665     // define vendor specific assembly instructions string
666     Register AsmReg = MRI.createGenericVirtualRegister(LLT::scalar(32));
667     MRI.setRegClass(AsmReg, &SPIRV::iIDRegClass);
668     auto AsmMIB = MIRBuilder.buildInstr(SPIRV::OpAsmINTEL)
669                       .addDef(AsmReg)
670                       .addUse(GR->getSPIRVTypeID(RetType))
671                       .addUse(GR->getSPIRVTypeID(FuncType))
672                       .addUse(AsmTargetReg);
673     // inline asm string:
674     addStringImm(I2->getOperand(InlineAsm::MIOp_AsmString).getSymbolName(),
675                  AsmMIB);
676     // inline asm constraint string:
677     addStringImm(cast<MDString>(I1->getOperand(2).getMetadata()->getOperand(0))
678                      ->getString(),
679                  AsmMIB);
680     GR->add(AsmMIB.getInstr(), &MF, AsmReg);
681 
682     // calls the inline assembly instruction
683     unsigned ExtraInfo = I2->getOperand(InlineAsm::MIOp_ExtraInfo).getImm();
684     if (ExtraInfo & InlineAsm::Extra_HasSideEffects)
685       MIRBuilder.buildInstr(SPIRV::OpDecorate)
686           .addUse(AsmReg)
687           .addImm(static_cast<uint32_t>(SPIRV::Decoration::SideEffectsINTEL));
688 
689     Register DefReg = collectInlineAsmInstrOperands(I2);
690     if (!DefReg.isValid()) {
691       DefReg = MRI.createGenericVirtualRegister(LLT::scalar(32));
692       MRI.setRegClass(DefReg, &SPIRV::iIDRegClass);
693       SPIRVType *VoidType = GR->getOrCreateSPIRVType(
694           Type::getVoidTy(MF.getFunction().getContext()), MIRBuilder);
695       GR->assignSPIRVTypeToVReg(VoidType, DefReg, MF);
696     }
697 
698     auto AsmCall = MIRBuilder.buildInstr(SPIRV::OpAsmCallINTEL)
699                        .addDef(DefReg)
700                        .addUse(GR->getSPIRVTypeID(RetType))
701                        .addUse(AsmReg);
702     for (unsigned IntrIdx = 3; IntrIdx < I1->getNumOperands(); ++IntrIdx)
703       AsmCall.addUse(I1->getOperand(IntrIdx).getReg());
704   }
705   for (MachineInstr *MI : ToProcess)
706     MI->eraseFromParent();
707 }
708 
709 static void insertInlineAsm(MachineFunction &MF, SPIRVGlobalRegistry *GR,
710                             const SPIRVSubtarget &ST,
711                             MachineIRBuilder MIRBuilder) {
712   SmallVector<MachineInstr *> ToProcess;
713   for (MachineBasicBlock &MBB : MF) {
714     for (MachineInstr &MI : MBB) {
715       if (isSpvIntrinsic(MI, Intrinsic::spv_inline_asm) ||
716           MI.getOpcode() == TargetOpcode::INLINEASM)
717         ToProcess.push_back(&MI);
718     }
719   }
720   if (ToProcess.size() == 0)
721     return;
722 
723   if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_inline_assembly))
724     report_fatal_error("Inline assembly instructions require the "
725                        "following SPIR-V extension: SPV_INTEL_inline_assembly",
726                        false);
727 
728   insertInlineAsmProcess(MF, GR, ST, MIRBuilder, ToProcess);
729 }
730 
731 static void insertSpirvDecorations(MachineFunction &MF, MachineIRBuilder MIB) {
732   SmallVector<MachineInstr *, 10> ToErase;
733   for (MachineBasicBlock &MBB : MF) {
734     for (MachineInstr &MI : MBB) {
735       if (!isSpvIntrinsic(MI, Intrinsic::spv_assign_decoration))
736         continue;
737       MIB.setInsertPt(*MI.getParent(), MI);
738       buildOpSpirvDecorations(MI.getOperand(1).getReg(), MIB,
739                               MI.getOperand(2).getMetadata());
740       ToErase.push_back(&MI);
741     }
742   }
743   for (MachineInstr *MI : ToErase)
744     MI->eraseFromParent();
745 }
746 
747 // LLVM allows the switches to use registers as cases, while SPIR-V required
748 // those to be immediate values. This function replaces such operands with the
749 // equivalent immediate constant.
750 static void processSwitchesConstants(MachineFunction &MF,
751                                      SPIRVGlobalRegistry *GR,
752                                      MachineIRBuilder MIB) {
753   MachineRegisterInfo &MRI = MF.getRegInfo();
754   for (MachineBasicBlock &MBB : MF) {
755     for (MachineInstr &MI : MBB) {
756       if (!isSpvIntrinsic(MI, Intrinsic::spv_switch))
757         continue;
758 
759       SmallVector<MachineOperand, 8> NewOperands;
760       NewOperands.push_back(MI.getOperand(0)); // Opcode
761       NewOperands.push_back(MI.getOperand(1)); // Condition
762       NewOperands.push_back(MI.getOperand(2)); // Default
763       for (unsigned i = 3; i < MI.getNumOperands(); i += 2) {
764         Register Reg = MI.getOperand(i).getReg();
765         MachineInstr *ConstInstr = getDefInstrMaybeConstant(Reg, &MRI);
766         NewOperands.push_back(
767             MachineOperand::CreateCImm(ConstInstr->getOperand(1).getCImm()));
768 
769         NewOperands.push_back(MI.getOperand(i + 1));
770       }
771 
772       assert(MI.getNumOperands() == NewOperands.size());
773       while (MI.getNumOperands() > 0)
774         MI.removeOperand(0);
775       for (auto &MO : NewOperands)
776         MI.addOperand(MO);
777     }
778   }
779 }
780 
781 // Some instructions are used during CodeGen but should never be emitted.
782 // Cleaning up those.
783 static void cleanupHelperInstructions(MachineFunction &MF) {
784   SmallVector<MachineInstr *, 8> ToEraseMI;
785   for (MachineBasicBlock &MBB : MF) {
786     for (MachineInstr &MI : MBB) {
787       if (isSpvIntrinsic(MI, Intrinsic::spv_track_constant) ||
788           MI.getOpcode() == TargetOpcode::G_BRINDIRECT)
789         ToEraseMI.push_back(&MI);
790     }
791   }
792 
793   for (MachineInstr *MI : ToEraseMI)
794     MI->eraseFromParent();
795 }
796 
797 // Find all usages of G_BLOCK_ADDR in our intrinsics and replace those
798 // operands/registers by the actual MBB it references.
799 static void processBlockAddr(MachineFunction &MF, SPIRVGlobalRegistry *GR,
800                              MachineIRBuilder MIB) {
801   // Gather the reverse-mapping BB -> MBB.
802   DenseMap<const BasicBlock *, MachineBasicBlock *> BB2MBB;
803   for (MachineBasicBlock &MBB : MF)
804     BB2MBB[MBB.getBasicBlock()] = &MBB;
805 
806   // Gather instructions requiring patching. For now, only those can use
807   // G_BLOCK_ADDR.
808   SmallVector<MachineInstr *, 8> InstructionsToPatch;
809   for (MachineBasicBlock &MBB : MF) {
810     for (MachineInstr &MI : MBB) {
811       if (isSpvIntrinsic(MI, Intrinsic::spv_switch) ||
812           isSpvIntrinsic(MI, Intrinsic::spv_loop_merge) ||
813           isSpvIntrinsic(MI, Intrinsic::spv_selection_merge))
814         InstructionsToPatch.push_back(&MI);
815     }
816   }
817 
818   // For each instruction to fix, we replace all the G_BLOCK_ADDR operands by
819   // the actual MBB it references. Once those references have been updated, we
820   // can cleanup remaining G_BLOCK_ADDR references.
821   SmallPtrSet<MachineBasicBlock *, 8> ClearAddressTaken;
822   SmallPtrSet<MachineInstr *, 8> ToEraseMI;
823   MachineRegisterInfo &MRI = MF.getRegInfo();
824   for (MachineInstr *MI : InstructionsToPatch) {
825     SmallVector<MachineOperand, 8> NewOps;
826     for (unsigned i = 0; i < MI->getNumOperands(); ++i) {
827       // The operand is not a register, keep as-is.
828       if (!MI->getOperand(i).isReg()) {
829         NewOps.push_back(MI->getOperand(i));
830         continue;
831       }
832 
833       Register Reg = MI->getOperand(i).getReg();
834       MachineInstr *BuildMBB = MRI.getVRegDef(Reg);
835       // The register is not the result of G_BLOCK_ADDR, keep as-is.
836       if (!BuildMBB || BuildMBB->getOpcode() != TargetOpcode::G_BLOCK_ADDR) {
837         NewOps.push_back(MI->getOperand(i));
838         continue;
839       }
840 
841       assert(BuildMBB && BuildMBB->getOpcode() == TargetOpcode::G_BLOCK_ADDR &&
842              BuildMBB->getOperand(1).isBlockAddress() &&
843              BuildMBB->getOperand(1).getBlockAddress());
844       BasicBlock *BB =
845           BuildMBB->getOperand(1).getBlockAddress()->getBasicBlock();
846       auto It = BB2MBB.find(BB);
847       if (It == BB2MBB.end())
848         report_fatal_error("cannot find a machine basic block by a basic block "
849                            "in a switch statement");
850       MachineBasicBlock *ReferencedBlock = It->second;
851       NewOps.push_back(MachineOperand::CreateMBB(ReferencedBlock));
852 
853       ClearAddressTaken.insert(ReferencedBlock);
854       ToEraseMI.insert(BuildMBB);
855     }
856 
857     // Replace the operands.
858     assert(MI->getNumOperands() == NewOps.size());
859     while (MI->getNumOperands() > 0)
860       MI->removeOperand(0);
861     for (auto &MO : NewOps)
862       MI->addOperand(MO);
863 
864     if (MachineInstr *Next = MI->getNextNode()) {
865       if (isSpvIntrinsic(*Next, Intrinsic::spv_track_constant)) {
866         ToEraseMI.insert(Next);
867         Next = MI->getNextNode();
868       }
869       if (Next && Next->getOpcode() == TargetOpcode::G_BRINDIRECT)
870         ToEraseMI.insert(Next);
871     }
872   }
873 
874   // BlockAddress operands were used to keep information between passes,
875   // let's undo the "address taken" status to reflect that Succ doesn't
876   // actually correspond to an IR-level basic block.
877   for (MachineBasicBlock *Succ : ClearAddressTaken)
878     Succ->setAddressTakenIRBlock(nullptr);
879 
880   // If we just delete G_BLOCK_ADDR instructions with BlockAddress operands,
881   // this leaves their BasicBlock counterparts in a "address taken" status. This
882   // would make AsmPrinter to generate a series of unneeded labels of a "Address
883   // of block that was removed by CodeGen" kind. Let's first ensure that we
884   // don't have a dangling BlockAddress constants by zapping the BlockAddress
885   // nodes, and only after that proceed with erasing G_BLOCK_ADDR instructions.
886   Constant *Replacement =
887       ConstantInt::get(Type::getInt32Ty(MF.getFunction().getContext()), 1);
888   for (MachineInstr *BlockAddrI : ToEraseMI) {
889     if (BlockAddrI->getOpcode() == TargetOpcode::G_BLOCK_ADDR) {
890       BlockAddress *BA = const_cast<BlockAddress *>(
891           BlockAddrI->getOperand(1).getBlockAddress());
892       BA->replaceAllUsesWith(
893           ConstantExpr::getIntToPtr(Replacement, BA->getType()));
894       BA->destroyConstant();
895     }
896     BlockAddrI->eraseFromParent();
897   }
898 }
899 
900 static bool isImplicitFallthrough(MachineBasicBlock &MBB) {
901   if (MBB.empty())
902     return true;
903 
904   // Branching SPIR-V intrinsics are not detected by this generic method.
905   // Thus, we can only trust negative result.
906   if (!MBB.canFallThrough())
907     return false;
908 
909   // Otherwise, we must manually check if we have a SPIR-V intrinsic which
910   // prevent an implicit fallthrough.
911   for (MachineBasicBlock::reverse_iterator It = MBB.rbegin(), E = MBB.rend();
912        It != E; ++It) {
913     if (isSpvIntrinsic(*It, Intrinsic::spv_switch))
914       return false;
915   }
916   return true;
917 }
918 
919 static void removeImplicitFallthroughs(MachineFunction &MF,
920                                        MachineIRBuilder MIB) {
921   // It is valid for MachineBasicBlocks to not finish with a branch instruction.
922   // In such cases, they will simply fallthrough their immediate successor.
923   for (MachineBasicBlock &MBB : MF) {
924     if (!isImplicitFallthrough(MBB))
925       continue;
926 
927     assert(std::distance(MBB.successors().begin(), MBB.successors().end()) ==
928            1);
929     MIB.setInsertPt(MBB, MBB.end());
930     MIB.buildBr(**MBB.successors().begin());
931   }
932 }
933 
934 bool SPIRVPreLegalizer::runOnMachineFunction(MachineFunction &MF) {
935   // Initialize the type registry.
936   const SPIRVSubtarget &ST = MF.getSubtarget<SPIRVSubtarget>();
937   SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
938   GR->setCurrentFunc(MF);
939   MachineIRBuilder MIB(MF);
940   // a registry of target extension constants
941   DenseMap<MachineInstr *, Type *> TargetExtConstTypes;
942   // to keep record of tracked constants
943   SmallSet<Register, 4> TrackedConstRegs;
944   addConstantsToTrack(MF, GR, ST, TargetExtConstTypes, TrackedConstRegs);
945   foldConstantsIntoIntrinsics(MF, TrackedConstRegs);
946   insertBitcasts(MF, GR, MIB);
947   generateAssignInstrs(MF, GR, MIB, TargetExtConstTypes);
948 
949   processSwitchesConstants(MF, GR, MIB);
950   processBlockAddr(MF, GR, MIB);
951   cleanupHelperInstructions(MF);
952 
953   processInstrsWithTypeFolding(MF, GR, MIB);
954   removeImplicitFallthroughs(MF, MIB);
955   insertSpirvDecorations(MF, MIB);
956   insertInlineAsm(MF, GR, ST, MIB);
957 
958   return true;
959 }
960 
961 INITIALIZE_PASS(SPIRVPreLegalizer, DEBUG_TYPE, "SPIRV pre legalizer", false,
962                 false)
963 
964 char SPIRVPreLegalizer::ID = 0;
965 
966 FunctionPass *llvm::createSPIRVPreLegalizerPass() {
967   return new SPIRVPreLegalizer();
968 }
969