xref: /llvm-project/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (revision 7c917e8268225735bf6fe0f7d8491fc944358e47)
1 //===-- SPIRVPreLegalizer.cpp - prepare IR for legalization -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // The pass prepares IR for legalization: it assigns SPIR-V types to registers
10 // and removes intrinsics which held these types during IR translation.
11 // Also it processes constants and registers them in GR to avoid duplication.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "SPIRV.h"
16 #include "SPIRVSubtarget.h"
17 #include "SPIRVUtils.h"
18 #include "llvm/ADT/PostOrderIterator.h"
19 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
20 #include "llvm/IR/Attributes.h"
21 #include "llvm/IR/Constants.h"
22 #include "llvm/IR/DebugInfoMetadata.h"
23 #include "llvm/IR/IntrinsicsSPIRV.h"
24 #include "llvm/Target/TargetIntrinsicInfo.h"
25 
26 #define DEBUG_TYPE "spirv-prelegalizer"
27 
28 using namespace llvm;
29 
30 namespace {
31 class SPIRVPreLegalizer : public MachineFunctionPass {
32 public:
33   static char ID;
34   SPIRVPreLegalizer() : MachineFunctionPass(ID) {
35     initializeSPIRVPreLegalizerPass(*PassRegistry::getPassRegistry());
36   }
37   bool runOnMachineFunction(MachineFunction &MF) override;
38 };
39 } // namespace
40 
41 static void
42 addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR,
43                     const SPIRVSubtarget &STI,
44                     DenseMap<MachineInstr *, Type *> &TargetExtConstTypes) {
45   MachineRegisterInfo &MRI = MF.getRegInfo();
46   DenseMap<MachineInstr *, Register> RegsAlreadyAddedToDT;
47   SmallVector<MachineInstr *, 10> ToErase, ToEraseComposites;
48   for (MachineBasicBlock &MBB : MF) {
49     for (MachineInstr &MI : MBB) {
50       if (!isSpvIntrinsic(MI, Intrinsic::spv_track_constant))
51         continue;
52       ToErase.push_back(&MI);
53       Register SrcReg = MI.getOperand(2).getReg();
54       auto *Const =
55           cast<Constant>(cast<ConstantAsMetadata>(
56                              MI.getOperand(3).getMetadata()->getOperand(0))
57                              ->getValue());
58       if (auto *GV = dyn_cast<GlobalValue>(Const)) {
59         Register Reg = GR->find(GV, &MF);
60         if (!Reg.isValid())
61           GR->add(GV, &MF, SrcReg);
62         else
63           RegsAlreadyAddedToDT[&MI] = Reg;
64       } else {
65         Register Reg = GR->find(Const, &MF);
66         if (!Reg.isValid()) {
67           if (auto *ConstVec = dyn_cast<ConstantDataVector>(Const)) {
68             auto *BuildVec = MRI.getVRegDef(SrcReg);
69             assert(BuildVec &&
70                    BuildVec->getOpcode() == TargetOpcode::G_BUILD_VECTOR);
71             for (unsigned i = 0; i < ConstVec->getNumElements(); ++i) {
72               // Ensure that OpConstantComposite reuses a constant when it's
73               // already created and available in the same machine function.
74               Constant *ElemConst = ConstVec->getElementAsConstant(i);
75               Register ElemReg = GR->find(ElemConst, &MF);
76               if (!ElemReg.isValid())
77                 GR->add(ElemConst, &MF, BuildVec->getOperand(1 + i).getReg());
78               else
79                 BuildVec->getOperand(1 + i).setReg(ElemReg);
80             }
81           }
82           GR->add(Const, &MF, SrcReg);
83           if (Const->getType()->isTargetExtTy()) {
84             // remember association so that we can restore it when assign types
85             MachineInstr *SrcMI = MRI.getVRegDef(SrcReg);
86             if (SrcMI && (SrcMI->getOpcode() == TargetOpcode::G_CONSTANT ||
87                           SrcMI->getOpcode() == TargetOpcode::G_IMPLICIT_DEF))
88               TargetExtConstTypes[SrcMI] = Const->getType();
89             if (Const->isNullValue()) {
90               MachineIRBuilder MIB(MF);
91               SPIRVType *ExtType =
92                   GR->getOrCreateSPIRVType(Const->getType(), MIB);
93               SrcMI->setDesc(STI.getInstrInfo()->get(SPIRV::OpConstantNull));
94               SrcMI->addOperand(MachineOperand::CreateReg(
95                   GR->getSPIRVTypeID(ExtType), false));
96             }
97           }
98         } else {
99           RegsAlreadyAddedToDT[&MI] = Reg;
100           // This MI is unused and will be removed. If the MI uses
101           // const_composite, it will be unused and should be removed too.
102           assert(MI.getOperand(2).isReg() && "Reg operand is expected");
103           MachineInstr *SrcMI = MRI.getVRegDef(MI.getOperand(2).getReg());
104           if (SrcMI && isSpvIntrinsic(*SrcMI, Intrinsic::spv_const_composite))
105             ToEraseComposites.push_back(SrcMI);
106         }
107       }
108     }
109   }
110   for (MachineInstr *MI : ToErase) {
111     Register Reg = MI->getOperand(2).getReg();
112     if (RegsAlreadyAddedToDT.contains(MI))
113       Reg = RegsAlreadyAddedToDT[MI];
114     auto *RC = MRI.getRegClassOrNull(MI->getOperand(0).getReg());
115     if (!MRI.getRegClassOrNull(Reg) && RC)
116       MRI.setRegClass(Reg, RC);
117     MRI.replaceRegWith(MI->getOperand(0).getReg(), Reg);
118     MI->eraseFromParent();
119   }
120   for (MachineInstr *MI : ToEraseComposites)
121     MI->eraseFromParent();
122 }
123 
124 static void foldConstantsIntoIntrinsics(MachineFunction &MF) {
125   SmallVector<MachineInstr *, 10> ToErase;
126   MachineRegisterInfo &MRI = MF.getRegInfo();
127   const unsigned AssignNameOperandShift = 2;
128   for (MachineBasicBlock &MBB : MF) {
129     for (MachineInstr &MI : MBB) {
130       if (!isSpvIntrinsic(MI, Intrinsic::spv_assign_name))
131         continue;
132       unsigned NumOp = MI.getNumExplicitDefs() + AssignNameOperandShift;
133       while (MI.getOperand(NumOp).isReg()) {
134         MachineOperand &MOp = MI.getOperand(NumOp);
135         MachineInstr *ConstMI = MRI.getVRegDef(MOp.getReg());
136         assert(ConstMI->getOpcode() == TargetOpcode::G_CONSTANT);
137         MI.removeOperand(NumOp);
138         MI.addOperand(MachineOperand::CreateImm(
139             ConstMI->getOperand(1).getCImm()->getZExtValue()));
140         if (MRI.use_empty(ConstMI->getOperand(0).getReg()))
141           ToErase.push_back(ConstMI);
142       }
143     }
144   }
145   for (MachineInstr *MI : ToErase)
146     MI->eraseFromParent();
147 }
148 
149 static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
150                            MachineIRBuilder MIB) {
151   // Get access to information about available extensions
152   const SPIRVSubtarget *ST =
153       static_cast<const SPIRVSubtarget *>(&MIB.getMF().getSubtarget());
154   SmallVector<MachineInstr *, 10> ToErase;
155   for (MachineBasicBlock &MBB : MF) {
156     for (MachineInstr &MI : MBB) {
157       if (!isSpvIntrinsic(MI, Intrinsic::spv_bitcast) &&
158           !isSpvIntrinsic(MI, Intrinsic::spv_ptrcast))
159         continue;
160       assert(MI.getOperand(2).isReg());
161       MIB.setInsertPt(*MI.getParent(), MI);
162       ToErase.push_back(&MI);
163       if (isSpvIntrinsic(MI, Intrinsic::spv_bitcast)) {
164         MIB.buildBitcast(MI.getOperand(0).getReg(), MI.getOperand(2).getReg());
165         continue;
166       }
167       Register Def = MI.getOperand(0).getReg();
168       Register Source = MI.getOperand(2).getReg();
169       SPIRVType *BaseTy = GR->getOrCreateSPIRVType(
170           getMDOperandAsType(MI.getOperand(3).getMetadata(), 0), MIB);
171       SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
172           BaseTy, MI, *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(),
173           addressSpaceToStorageClass(MI.getOperand(4).getImm(), *ST));
174 
175       // If the bitcast would be redundant, replace all uses with the source
176       // register.
177       if (GR->getSPIRVTypeForVReg(Source) == AssignedPtrType) {
178         MIB.getMRI()->replaceRegWith(Def, Source);
179       } else {
180         GR->assignSPIRVTypeToVReg(AssignedPtrType, Def, MF);
181         MIB.buildBitcast(Def, Source);
182       }
183     }
184   }
185   for (MachineInstr *MI : ToErase)
186     MI->eraseFromParent();
187 }
188 
189 // Translating GV, IRTranslator sometimes generates following IR:
190 //   %1 = G_GLOBAL_VALUE
191 //   %2 = COPY %1
192 //   %3 = G_ADDRSPACE_CAST %2
193 //
194 // or
195 //
196 //  %1 = G_ZEXT %2
197 //  G_MEMCPY ... %2 ...
198 //
199 // New registers have no SPIRVType and no register class info.
200 //
201 // Set SPIRVType for GV, propagate it from GV to other instructions,
202 // also set register classes.
203 static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
204                                      MachineRegisterInfo &MRI,
205                                      MachineIRBuilder &MIB) {
206   SPIRVType *SpirvTy = nullptr;
207   assert(MI && "Machine instr is expected");
208   if (MI->getOperand(0).isReg()) {
209     Register Reg = MI->getOperand(0).getReg();
210     SpirvTy = GR->getSPIRVTypeForVReg(Reg);
211     if (!SpirvTy) {
212       switch (MI->getOpcode()) {
213       case TargetOpcode::G_CONSTANT: {
214         MIB.setInsertPt(*MI->getParent(), MI);
215         Type *Ty = MI->getOperand(1).getCImm()->getType();
216         SpirvTy = GR->getOrCreateSPIRVType(Ty, MIB);
217         break;
218       }
219       case TargetOpcode::G_GLOBAL_VALUE: {
220         MIB.setInsertPt(*MI->getParent(), MI);
221         const GlobalValue *Global = MI->getOperand(1).getGlobal();
222         Type *ElementTy = GR->getDeducedGlobalValueType(Global);
223         auto *Ty = TypedPointerType::get(ElementTy,
224                                          Global->getType()->getAddressSpace());
225         SpirvTy = GR->getOrCreateSPIRVType(Ty, MIB);
226         break;
227       }
228       case TargetOpcode::G_ANYEXT:
229       case TargetOpcode::G_SEXT:
230       case TargetOpcode::G_ZEXT: {
231         if (MI->getOperand(1).isReg()) {
232           if (MachineInstr *DefInstr =
233                   MRI.getVRegDef(MI->getOperand(1).getReg())) {
234             if (SPIRVType *Def = propagateSPIRVType(DefInstr, GR, MRI, MIB)) {
235               unsigned CurrentBW = GR->getScalarOrVectorBitWidth(Def);
236               unsigned ExpectedBW =
237                   std::max(MRI.getType(Reg).getScalarSizeInBits(), CurrentBW);
238               unsigned NumElements = GR->getScalarOrVectorComponentCount(Def);
239               SpirvTy = GR->getOrCreateSPIRVIntegerType(ExpectedBW, MIB);
240               if (NumElements > 1)
241                 SpirvTy =
242                     GR->getOrCreateSPIRVVectorType(SpirvTy, NumElements, MIB);
243             }
244           }
245         }
246         break;
247       }
248       case TargetOpcode::G_PTRTOINT:
249         SpirvTy = GR->getOrCreateSPIRVIntegerType(
250             MRI.getType(Reg).getScalarSizeInBits(), MIB);
251         break;
252       case TargetOpcode::G_TRUNC:
253       case TargetOpcode::G_ADDRSPACE_CAST:
254       case TargetOpcode::G_PTR_ADD:
255       case TargetOpcode::COPY: {
256         MachineOperand &Op = MI->getOperand(1);
257         MachineInstr *Def = Op.isReg() ? MRI.getVRegDef(Op.getReg()) : nullptr;
258         if (Def)
259           SpirvTy = propagateSPIRVType(Def, GR, MRI, MIB);
260         break;
261       }
262       default:
263         break;
264       }
265       if (SpirvTy)
266         GR->assignSPIRVTypeToVReg(SpirvTy, Reg, MIB.getMF());
267       if (!MRI.getRegClassOrNull(Reg))
268         MRI.setRegClass(Reg, &SPIRV::IDRegClass);
269     }
270   }
271   return SpirvTy;
272 }
273 
274 static std::pair<Register, unsigned>
275 createNewIdReg(SPIRVType *SpvType, Register SrcReg, MachineRegisterInfo &MRI,
276                const SPIRVGlobalRegistry &GR) {
277   if (!SpvType)
278     SpvType = GR.getSPIRVTypeForVReg(SrcReg);
279   assert(SpvType && "VReg is expected to have SPIRV type");
280   LLT SrcLLT = MRI.getType(SrcReg);
281   LLT NewT = LLT::scalar(32);
282   bool IsFloat = SpvType->getOpcode() == SPIRV::OpTypeFloat;
283   bool IsVectorFloat =
284       SpvType->getOpcode() == SPIRV::OpTypeVector &&
285       GR.getSPIRVTypeForVReg(SpvType->getOperand(1).getReg())->getOpcode() ==
286           SPIRV::OpTypeFloat;
287   IsFloat |= IsVectorFloat;
288   auto GetIdOp = IsFloat ? SPIRV::GET_fID : SPIRV::GET_ID;
289   auto DstClass = IsFloat ? &SPIRV::fIDRegClass : &SPIRV::IDRegClass;
290   if (SrcLLT.isPointer()) {
291     unsigned PtrSz = GR.getPointerSize();
292     NewT = LLT::pointer(0, PtrSz);
293     bool IsVec = SrcLLT.isVector();
294     if (IsVec)
295       NewT = LLT::fixed_vector(2, NewT);
296     if (PtrSz == 64) {
297       if (IsVec) {
298         GetIdOp = SPIRV::GET_vpID64;
299         DstClass = &SPIRV::vpID64RegClass;
300       } else {
301         GetIdOp = SPIRV::GET_pID64;
302         DstClass = &SPIRV::pID64RegClass;
303       }
304     } else {
305       if (IsVec) {
306         GetIdOp = SPIRV::GET_vpID32;
307         DstClass = &SPIRV::vpID32RegClass;
308       } else {
309         GetIdOp = SPIRV::GET_pID32;
310         DstClass = &SPIRV::pID32RegClass;
311       }
312     }
313   } else if (SrcLLT.isVector()) {
314     NewT = LLT::fixed_vector(2, NewT);
315     if (IsFloat) {
316       GetIdOp = SPIRV::GET_vfID;
317       DstClass = &SPIRV::vfIDRegClass;
318     } else {
319       GetIdOp = SPIRV::GET_vID;
320       DstClass = &SPIRV::vIDRegClass;
321     }
322   }
323   Register IdReg = MRI.createGenericVirtualRegister(NewT);
324   MRI.setRegClass(IdReg, DstClass);
325   return {IdReg, GetIdOp};
326 }
327 
328 // Insert ASSIGN_TYPE instruction between Reg and its definition, set NewReg as
329 // a dst of the definition, assign SPIRVType to both registers. If SpirvTy is
330 // provided, use it as SPIRVType in ASSIGN_TYPE, otherwise create it from Ty.
331 // It's used also in SPIRVBuiltins.cpp.
332 // TODO: maybe move to SPIRVUtils.
333 namespace llvm {
334 Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
335                            SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB,
336                            MachineRegisterInfo &MRI) {
337   MachineInstr *Def = MRI.getVRegDef(Reg);
338   assert((Ty || SpirvTy) && "Either LLVM or SPIRV type is expected.");
339   MIB.setInsertPt(*Def->getParent(),
340                   (Def->getNextNode() ? Def->getNextNode()->getIterator()
341                                       : Def->getParent()->end()));
342   SpirvTy = SpirvTy ? SpirvTy : GR->getOrCreateSPIRVType(Ty, MIB);
343   Register NewReg = MRI.createGenericVirtualRegister(MRI.getType(Reg));
344   if (auto *RC = MRI.getRegClassOrNull(Reg)) {
345     MRI.setRegClass(NewReg, RC);
346   } else {
347     MRI.setRegClass(NewReg, &SPIRV::IDRegClass);
348     MRI.setRegClass(Reg, &SPIRV::IDRegClass);
349   }
350   GR->assignSPIRVTypeToVReg(SpirvTy, Reg, MIB.getMF());
351   // This is to make it convenient for Legalizer to get the SPIRVType
352   // when processing the actual MI (i.e. not pseudo one).
353   GR->assignSPIRVTypeToVReg(SpirvTy, NewReg, MIB.getMF());
354   // Copy MIFlags from Def to ASSIGN_TYPE instruction. It's required to keep
355   // the flags after instruction selection.
356   const uint32_t Flags = Def->getFlags();
357   MIB.buildInstr(SPIRV::ASSIGN_TYPE)
358       .addDef(Reg)
359       .addUse(NewReg)
360       .addUse(GR->getSPIRVTypeID(SpirvTy))
361       .setMIFlags(Flags);
362   Def->getOperand(0).setReg(NewReg);
363   return NewReg;
364 }
365 
366 void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
367                   MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR) {
368   assert(MI.getNumDefs() > 0 && MRI.hasOneUse(MI.getOperand(0).getReg()));
369   MachineInstr &AssignTypeInst =
370       *(MRI.use_instr_begin(MI.getOperand(0).getReg()));
371   auto NewReg =
372       createNewIdReg(nullptr, MI.getOperand(0).getReg(), MRI, *GR).first;
373   AssignTypeInst.getOperand(1).setReg(NewReg);
374   MI.getOperand(0).setReg(NewReg);
375   MIB.setInsertPt(*MI.getParent(),
376                   (MI.getNextNode() ? MI.getNextNode()->getIterator()
377                                     : MI.getParent()->end()));
378   for (auto &Op : MI.operands()) {
379     if (!Op.isReg() || Op.isDef())
380       continue;
381     auto IdOpInfo = createNewIdReg(nullptr, Op.getReg(), MRI, *GR);
382     MIB.buildInstr(IdOpInfo.second).addDef(IdOpInfo.first).addUse(Op.getReg());
383     Op.setReg(IdOpInfo.first);
384   }
385 }
386 } // namespace llvm
387 
388 static void
389 generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
390                      MachineIRBuilder MIB,
391                      DenseMap<MachineInstr *, Type *> &TargetExtConstTypes) {
392   // Get access to information about available extensions
393   const SPIRVSubtarget *ST =
394       static_cast<const SPIRVSubtarget *>(&MIB.getMF().getSubtarget());
395 
396   MachineRegisterInfo &MRI = MF.getRegInfo();
397   SmallVector<MachineInstr *, 10> ToErase;
398 
399   for (MachineBasicBlock *MBB : post_order(&MF)) {
400     if (MBB->empty())
401       continue;
402 
403     bool ReachedBegin = false;
404     for (auto MII = std::prev(MBB->end()), Begin = MBB->begin();
405          !ReachedBegin;) {
406       MachineInstr &MI = *MII;
407       unsigned MIOp = MI.getOpcode();
408 
409       if (isSpvIntrinsic(MI, Intrinsic::spv_assign_ptr_type)) {
410         Register Reg = MI.getOperand(1).getReg();
411         MIB.setInsertPt(*MI.getParent(), MI.getIterator());
412         SPIRVType *BaseTy = GR->getOrCreateSPIRVType(
413             getMDOperandAsType(MI.getOperand(2).getMetadata(), 0), MIB);
414         SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
415             BaseTy, MI, *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(),
416             addressSpaceToStorageClass(MI.getOperand(3).getImm(), *ST));
417         MachineInstr *Def = MRI.getVRegDef(Reg);
418         assert(Def && "Expecting an instruction that defines the register");
419         // G_GLOBAL_VALUE already has type info.
420         if (Def->getOpcode() != TargetOpcode::G_GLOBAL_VALUE)
421           insertAssignInstr(Reg, nullptr, AssignedPtrType, GR, MIB,
422                             MF.getRegInfo());
423         ToErase.push_back(&MI);
424       } else if (isSpvIntrinsic(MI, Intrinsic::spv_assign_type)) {
425         Register Reg = MI.getOperand(1).getReg();
426         Type *Ty = getMDOperandAsType(MI.getOperand(2).getMetadata(), 0);
427         MachineInstr *Def = MRI.getVRegDef(Reg);
428         assert(Def && "Expecting an instruction that defines the register");
429         // G_GLOBAL_VALUE already has type info.
430         if (Def->getOpcode() != TargetOpcode::G_GLOBAL_VALUE)
431           insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MF.getRegInfo());
432         ToErase.push_back(&MI);
433       } else if (MIOp == TargetOpcode::G_CONSTANT ||
434                  MIOp == TargetOpcode::G_FCONSTANT ||
435                  MIOp == TargetOpcode::G_BUILD_VECTOR) {
436         // %rc = G_CONSTANT ty Val
437         // ===>
438         // %cty = OpType* ty
439         // %rctmp = G_CONSTANT ty Val
440         // %rc = ASSIGN_TYPE %rctmp, %cty
441         Register Reg = MI.getOperand(0).getReg();
442         if (MRI.hasOneUse(Reg)) {
443           MachineInstr &UseMI = *MRI.use_instr_begin(Reg);
444           if (isSpvIntrinsic(UseMI, Intrinsic::spv_assign_type) ||
445               isSpvIntrinsic(UseMI, Intrinsic::spv_assign_name))
446             continue;
447         }
448         Type *Ty = nullptr;
449         if (MIOp == TargetOpcode::G_CONSTANT) {
450           auto TargetExtIt = TargetExtConstTypes.find(&MI);
451           Ty = TargetExtIt == TargetExtConstTypes.end()
452                    ? MI.getOperand(1).getCImm()->getType()
453                    : TargetExtIt->second;
454         } else if (MIOp == TargetOpcode::G_FCONSTANT) {
455           Ty = MI.getOperand(1).getFPImm()->getType();
456         } else {
457           assert(MIOp == TargetOpcode::G_BUILD_VECTOR);
458           Type *ElemTy = nullptr;
459           MachineInstr *ElemMI = MRI.getVRegDef(MI.getOperand(1).getReg());
460           assert(ElemMI);
461 
462           if (ElemMI->getOpcode() == TargetOpcode::G_CONSTANT)
463             ElemTy = ElemMI->getOperand(1).getCImm()->getType();
464           else if (ElemMI->getOpcode() == TargetOpcode::G_FCONSTANT)
465             ElemTy = ElemMI->getOperand(1).getFPImm()->getType();
466           else
467             llvm_unreachable("Unexpected opcode");
468           unsigned NumElts =
469               MI.getNumExplicitOperands() - MI.getNumExplicitDefs();
470           Ty = VectorType::get(ElemTy, NumElts, false);
471         }
472         insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MRI);
473       } else if (MIOp == TargetOpcode::G_GLOBAL_VALUE) {
474         propagateSPIRVType(&MI, GR, MRI, MIB);
475       }
476 
477       if (MII == Begin)
478         ReachedBegin = true;
479       else
480         --MII;
481     }
482   }
483   for (MachineInstr *MI : ToErase)
484     MI->eraseFromParent();
485 
486   // Address the case when IRTranslator introduces instructions with new
487   // registers without SPIRVType associated.
488   for (MachineBasicBlock &MBB : MF) {
489     for (MachineInstr &MI : MBB) {
490       switch (MI.getOpcode()) {
491       case TargetOpcode::G_TRUNC:
492       case TargetOpcode::G_ANYEXT:
493       case TargetOpcode::G_SEXT:
494       case TargetOpcode::G_ZEXT:
495       case TargetOpcode::G_PTRTOINT:
496       case TargetOpcode::COPY:
497       case TargetOpcode::G_ADDRSPACE_CAST:
498         propagateSPIRVType(&MI, GR, MRI, MIB);
499         break;
500       }
501     }
502   }
503 }
504 
505 // Defined in SPIRVLegalizerInfo.cpp.
506 extern bool isTypeFoldingSupported(unsigned Opcode);
507 
508 static void processInstrsWithTypeFolding(MachineFunction &MF,
509                                          SPIRVGlobalRegistry *GR,
510                                          MachineIRBuilder MIB) {
511   MachineRegisterInfo &MRI = MF.getRegInfo();
512   for (MachineBasicBlock &MBB : MF) {
513     for (MachineInstr &MI : MBB) {
514       if (isTypeFoldingSupported(MI.getOpcode()))
515         processInstr(MI, MIB, MRI, GR);
516     }
517   }
518 
519   for (MachineBasicBlock &MBB : MF) {
520     for (MachineInstr &MI : MBB) {
521       // We need to rewrite dst types for ASSIGN_TYPE instrs to be able
522       // to perform tblgen'erated selection and we can't do that on Legalizer
523       // as it operates on gMIR only.
524       if (MI.getOpcode() != SPIRV::ASSIGN_TYPE)
525         continue;
526       Register SrcReg = MI.getOperand(1).getReg();
527       unsigned Opcode = MRI.getVRegDef(SrcReg)->getOpcode();
528       if (!isTypeFoldingSupported(Opcode))
529         continue;
530       Register DstReg = MI.getOperand(0).getReg();
531       bool IsDstPtr = MRI.getType(DstReg).isPointer();
532       bool isDstVec = MRI.getType(DstReg).isVector();
533       if (IsDstPtr || isDstVec)
534         MRI.setRegClass(DstReg, &SPIRV::IDRegClass);
535       // Don't need to reset type of register holding constant and used in
536       // G_ADDRSPACE_CAST, since it breaks legalizer.
537       if (Opcode == TargetOpcode::G_CONSTANT && MRI.hasOneUse(DstReg)) {
538         MachineInstr &UseMI = *MRI.use_instr_begin(DstReg);
539         if (UseMI.getOpcode() == TargetOpcode::G_ADDRSPACE_CAST)
540           continue;
541       }
542       MRI.setType(DstReg, IsDstPtr ? LLT::pointer(0, GR->getPointerSize())
543                                    : LLT::scalar(32));
544     }
545   }
546 }
547 
548 static void
549 insertInlineAsmProcess(MachineFunction &MF, SPIRVGlobalRegistry *GR,
550                        const SPIRVSubtarget &ST, MachineIRBuilder MIRBuilder,
551                        const SmallVector<MachineInstr *> &ToProcess) {
552   MachineRegisterInfo &MRI = MF.getRegInfo();
553   Register AsmTargetReg;
554   for (unsigned i = 0, Sz = ToProcess.size(); i + 1 < Sz; i += 2) {
555     MachineInstr *I1 = ToProcess[i], *I2 = ToProcess[i + 1];
556     assert(isSpvIntrinsic(*I1, Intrinsic::spv_inline_asm) && I2->isInlineAsm());
557     MIRBuilder.setInsertPt(*I1->getParent(), *I1);
558 
559     if (!AsmTargetReg.isValid()) {
560       // define vendor specific assembly target or dialect
561       AsmTargetReg = MRI.createGenericVirtualRegister(LLT::scalar(32));
562       MRI.setRegClass(AsmTargetReg, &SPIRV::IDRegClass);
563       auto AsmTargetMIB =
564           MIRBuilder.buildInstr(SPIRV::OpAsmTargetINTEL).addDef(AsmTargetReg);
565       addStringImm(ST.getTargetTripleAsStr(), AsmTargetMIB);
566       GR->add(AsmTargetMIB.getInstr(), &MF, AsmTargetReg);
567     }
568 
569     // create types
570     const MDNode *IAMD = I1->getOperand(1).getMetadata();
571     FunctionType *FTy = cast<FunctionType>(getMDOperandAsType(IAMD, 0));
572     SmallVector<SPIRVType *, 4> ArgTypes;
573     for (const auto &ArgTy : FTy->params())
574       ArgTypes.push_back(GR->getOrCreateSPIRVType(ArgTy, MIRBuilder));
575     SPIRVType *RetType =
576         GR->getOrCreateSPIRVType(FTy->getReturnType(), MIRBuilder);
577     SPIRVType *FuncType = GR->getOrCreateOpTypeFunctionWithArgs(
578         FTy, RetType, ArgTypes, MIRBuilder);
579 
580     // define vendor specific assembly instructions string
581     Register AsmReg = MRI.createGenericVirtualRegister(LLT::scalar(32));
582     MRI.setRegClass(AsmReg, &SPIRV::IDRegClass);
583     auto AsmMIB = MIRBuilder.buildInstr(SPIRV::OpAsmINTEL)
584                       .addDef(AsmReg)
585                       .addUse(GR->getSPIRVTypeID(RetType))
586                       .addUse(GR->getSPIRVTypeID(FuncType))
587                       .addUse(AsmTargetReg);
588     // inline asm string:
589     addStringImm(I2->getOperand(InlineAsm::MIOp_AsmString).getSymbolName(),
590                  AsmMIB);
591     // inline asm constraint string:
592     addStringImm(cast<MDString>(I1->getOperand(2).getMetadata()->getOperand(0))
593                      ->getString(),
594                  AsmMIB);
595     GR->add(AsmMIB.getInstr(), &MF, AsmReg);
596 
597     // calls the inline assembly instruction
598     unsigned ExtraInfo = I2->getOperand(InlineAsm::MIOp_ExtraInfo).getImm();
599     if (ExtraInfo & InlineAsm::Extra_HasSideEffects)
600       MIRBuilder.buildInstr(SPIRV::OpDecorate)
601           .addUse(AsmReg)
602           .addImm(static_cast<uint32_t>(SPIRV::Decoration::SideEffectsINTEL));
603     Register DefReg;
604     SmallVector<unsigned, 4> Ops;
605     unsigned StartOp = InlineAsm::MIOp_FirstOperand,
606              AsmDescOp = InlineAsm::MIOp_FirstOperand;
607     unsigned I2Sz = I2->getNumOperands();
608     for (unsigned Idx = StartOp; Idx != I2Sz; ++Idx) {
609       const MachineOperand &MO = I2->getOperand(Idx);
610       if (MO.isMetadata())
611         continue;
612       if (Idx == AsmDescOp && MO.isImm()) {
613         // compute the index of the next operand descriptor
614         const InlineAsm::Flag F(MO.getImm());
615         AsmDescOp += 1 + F.getNumOperandRegisters();
616       } else {
617         if (MO.isReg() && MO.isDef())
618           DefReg = MO.getReg();
619         else
620           Ops.push_back(Idx);
621       }
622     }
623     if (!DefReg.isValid()) {
624       DefReg = MRI.createGenericVirtualRegister(LLT::scalar(32));
625       MRI.setRegClass(DefReg, &SPIRV::IDRegClass);
626       SPIRVType *VoidType = GR->getOrCreateSPIRVType(
627           Type::getVoidTy(MF.getFunction().getContext()), MIRBuilder);
628       GR->assignSPIRVTypeToVReg(VoidType, DefReg, MF);
629     }
630     auto AsmCall = MIRBuilder.buildInstr(SPIRV::OpAsmCallINTEL)
631                        .addDef(DefReg)
632                        .addUse(GR->getSPIRVTypeID(RetType))
633                        .addUse(AsmReg);
634     unsigned IntrIdx = 2;
635     for (unsigned Idx : Ops) {
636       ++IntrIdx;
637       const MachineOperand &MO = I2->getOperand(Idx);
638       if (MO.isReg())
639         AsmCall.addUse(MO.getReg());
640       else
641         AsmCall.addUse(I1->getOperand(IntrIdx).getReg());
642     }
643   }
644   for (MachineInstr *MI : ToProcess)
645     MI->eraseFromParent();
646 }
647 
648 static void insertInlineAsm(MachineFunction &MF, SPIRVGlobalRegistry *GR,
649                             const SPIRVSubtarget &ST,
650                             MachineIRBuilder MIRBuilder) {
651   SmallVector<MachineInstr *> ToProcess;
652   for (MachineBasicBlock &MBB : MF) {
653     for (MachineInstr &MI : MBB) {
654       if (isSpvIntrinsic(MI, Intrinsic::spv_inline_asm) ||
655           MI.getOpcode() == TargetOpcode::INLINEASM)
656         ToProcess.push_back(&MI);
657     }
658   }
659   if (ToProcess.size() == 0)
660     return;
661 
662   if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_inline_assembly))
663     report_fatal_error("Inline assembly instructions require the "
664                        "following SPIR-V extension: SPV_INTEL_inline_assembly",
665                        false);
666 
667   insertInlineAsmProcess(MF, GR, ST, MIRBuilder, ToProcess);
668 }
669 
670 static void insertSpirvDecorations(MachineFunction &MF, MachineIRBuilder MIB) {
671   SmallVector<MachineInstr *, 10> ToErase;
672   for (MachineBasicBlock &MBB : MF) {
673     for (MachineInstr &MI : MBB) {
674       if (!isSpvIntrinsic(MI, Intrinsic::spv_assign_decoration))
675         continue;
676       MIB.setInsertPt(*MI.getParent(), MI);
677       buildOpSpirvDecorations(MI.getOperand(1).getReg(), MIB,
678                               MI.getOperand(2).getMetadata());
679       ToErase.push_back(&MI);
680     }
681   }
682   for (MachineInstr *MI : ToErase)
683     MI->eraseFromParent();
684 }
685 
686 // Find basic blocks of the switch and replace registers in spv_switch() by its
687 // MBB equivalent.
688 static void processSwitches(MachineFunction &MF, SPIRVGlobalRegistry *GR,
689                             MachineIRBuilder MIB) {
690   DenseMap<const BasicBlock *, MachineBasicBlock *> BB2MBB;
691   SmallVector<std::pair<MachineInstr *, SmallVector<MachineInstr *, 8>>>
692       Switches;
693   for (MachineBasicBlock &MBB : MF) {
694     MachineRegisterInfo &MRI = MF.getRegInfo();
695     BB2MBB[MBB.getBasicBlock()] = &MBB;
696     for (MachineInstr &MI : MBB) {
697       if (!isSpvIntrinsic(MI, Intrinsic::spv_switch))
698         continue;
699       // Calls to spv_switch intrinsics representing IR switches.
700       SmallVector<MachineInstr *, 8> NewOps;
701       for (unsigned i = 2; i < MI.getNumOperands(); ++i) {
702         Register Reg = MI.getOperand(i).getReg();
703         if (i % 2 == 1) {
704           MachineInstr *ConstInstr = getDefInstrMaybeConstant(Reg, &MRI);
705           NewOps.push_back(ConstInstr);
706         } else {
707           MachineInstr *BuildMBB = MRI.getVRegDef(Reg);
708           assert(BuildMBB &&
709                  BuildMBB->getOpcode() == TargetOpcode::G_BLOCK_ADDR &&
710                  BuildMBB->getOperand(1).isBlockAddress() &&
711                  BuildMBB->getOperand(1).getBlockAddress());
712           NewOps.push_back(BuildMBB);
713         }
714       }
715       Switches.push_back(std::make_pair(&MI, NewOps));
716     }
717   }
718 
719   SmallPtrSet<MachineInstr *, 8> ToEraseMI;
720   for (auto &SwIt : Switches) {
721     MachineInstr &MI = *SwIt.first;
722     SmallVector<MachineInstr *, 8> &Ins = SwIt.second;
723     SmallVector<MachineOperand, 8> NewOps;
724     for (unsigned i = 0; i < Ins.size(); ++i) {
725       if (Ins[i]->getOpcode() == TargetOpcode::G_BLOCK_ADDR) {
726         BasicBlock *CaseBB =
727             Ins[i]->getOperand(1).getBlockAddress()->getBasicBlock();
728         auto It = BB2MBB.find(CaseBB);
729         if (It == BB2MBB.end())
730           report_fatal_error("cannot find a machine basic block by a basic "
731                              "block in a switch statement");
732         NewOps.push_back(MachineOperand::CreateMBB(It->second));
733         MI.getParent()->addSuccessor(It->second);
734         ToEraseMI.insert(Ins[i]);
735       } else {
736         NewOps.push_back(
737             MachineOperand::CreateCImm(Ins[i]->getOperand(1).getCImm()));
738       }
739     }
740     for (unsigned i = MI.getNumOperands() - 1; i > 1; --i)
741       MI.removeOperand(i);
742     for (auto &MO : NewOps)
743       MI.addOperand(MO);
744     if (MachineInstr *Next = MI.getNextNode()) {
745       if (isSpvIntrinsic(*Next, Intrinsic::spv_track_constant)) {
746         ToEraseMI.insert(Next);
747         Next = MI.getNextNode();
748       }
749       if (Next && Next->getOpcode() == TargetOpcode::G_BRINDIRECT)
750         ToEraseMI.insert(Next);
751     }
752   }
753 
754   // If we just delete G_BLOCK_ADDR instructions with BlockAddress operands,
755   // this leaves their BasicBlock counterparts in a "address taken" status. This
756   // would make AsmPrinter to generate a series of unneeded labels of a "Address
757   // of block that was removed by CodeGen" kind. Let's first ensure that we
758   // don't have a dangling BlockAddress constants by zapping the BlockAddress
759   // nodes, and only after that proceed with erasing G_BLOCK_ADDR instructions.
760   Constant *Replacement =
761       ConstantInt::get(Type::getInt32Ty(MF.getFunction().getContext()), 1);
762   for (MachineInstr *BlockAddrI : ToEraseMI) {
763     if (BlockAddrI->getOpcode() == TargetOpcode::G_BLOCK_ADDR) {
764       BlockAddress *BA = const_cast<BlockAddress *>(
765           BlockAddrI->getOperand(1).getBlockAddress());
766       BA->replaceAllUsesWith(
767           ConstantExpr::getIntToPtr(Replacement, BA->getType()));
768       BA->destroyConstant();
769     }
770     BlockAddrI->eraseFromParent();
771   }
772 }
773 
774 static bool isImplicitFallthrough(MachineBasicBlock &MBB) {
775   if (MBB.empty())
776     return true;
777 
778   // Branching SPIR-V intrinsics are not detected by this generic method.
779   // Thus, we can only trust negative result.
780   if (!MBB.canFallThrough())
781     return false;
782 
783   // Otherwise, we must manually check if we have a SPIR-V intrinsic which
784   // prevent an implicit fallthrough.
785   for (MachineBasicBlock::reverse_iterator It = MBB.rbegin(), E = MBB.rend();
786        It != E; ++It) {
787     if (isSpvIntrinsic(*It, Intrinsic::spv_switch))
788       return false;
789   }
790   return true;
791 }
792 
793 static void removeImplicitFallthroughs(MachineFunction &MF,
794                                        MachineIRBuilder MIB) {
795   // It is valid for MachineBasicBlocks to not finish with a branch instruction.
796   // In such cases, they will simply fallthrough their immediate successor.
797   for (MachineBasicBlock &MBB : MF) {
798     if (!isImplicitFallthrough(MBB))
799       continue;
800 
801     assert(std::distance(MBB.successors().begin(), MBB.successors().end()) ==
802            1);
803     MIB.setInsertPt(MBB, MBB.end());
804     MIB.buildBr(**MBB.successors().begin());
805   }
806 }
807 
808 bool SPIRVPreLegalizer::runOnMachineFunction(MachineFunction &MF) {
809   // Initialize the type registry.
810   const SPIRVSubtarget &ST = MF.getSubtarget<SPIRVSubtarget>();
811   SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
812   GR->setCurrentFunc(MF);
813   MachineIRBuilder MIB(MF);
814   // a registry of target extension constants
815   DenseMap<MachineInstr *, Type *> TargetExtConstTypes;
816   addConstantsToTrack(MF, GR, ST, TargetExtConstTypes);
817   foldConstantsIntoIntrinsics(MF);
818   insertBitcasts(MF, GR, MIB);
819   generateAssignInstrs(MF, GR, MIB, TargetExtConstTypes);
820   processSwitches(MF, GR, MIB);
821   processInstrsWithTypeFolding(MF, GR, MIB);
822   removeImplicitFallthroughs(MF, MIB);
823   insertSpirvDecorations(MF, MIB);
824   insertInlineAsm(MF, GR, ST, MIB);
825 
826   return true;
827 }
828 
829 INITIALIZE_PASS(SPIRVPreLegalizer, DEBUG_TYPE, "SPIRV pre legalizer", false,
830                 false)
831 
832 char SPIRVPreLegalizer::ID = 0;
833 
834 FunctionPass *llvm::createSPIRVPreLegalizerPass() {
835   return new SPIRVPreLegalizer();
836 }
837