xref: /freebsd-src/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (revision 81ad626541db97eb356e2c1d4a20eb2a26a766ab)
1*81ad6265SDimitry Andric //===-- SPIRVPreLegalizer.cpp - prepare IR for legalization -----*- C++ -*-===//
2*81ad6265SDimitry Andric //
3*81ad6265SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*81ad6265SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5*81ad6265SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*81ad6265SDimitry Andric //
7*81ad6265SDimitry Andric //===----------------------------------------------------------------------===//
8*81ad6265SDimitry Andric //
9*81ad6265SDimitry Andric // The pass prepares IR for legalization: it assigns SPIR-V types to registers
// and removes intrinsics which held these types during IR translation.
11*81ad6265SDimitry Andric // Also it processes constants and registers them in GR to avoid duplication.
12*81ad6265SDimitry Andric //
13*81ad6265SDimitry Andric //===----------------------------------------------------------------------===//
14*81ad6265SDimitry Andric 
15*81ad6265SDimitry Andric #include "SPIRV.h"
16*81ad6265SDimitry Andric #include "SPIRVGlobalRegistry.h"
17*81ad6265SDimitry Andric #include "SPIRVSubtarget.h"
18*81ad6265SDimitry Andric #include "SPIRVUtils.h"
19*81ad6265SDimitry Andric #include "llvm/ADT/PostOrderIterator.h"
20*81ad6265SDimitry Andric #include "llvm/Analysis/OptimizationRemarkEmitter.h"
21*81ad6265SDimitry Andric #include "llvm/IR/Attributes.h"
22*81ad6265SDimitry Andric #include "llvm/IR/Constants.h"
23*81ad6265SDimitry Andric #include "llvm/IR/DebugInfoMetadata.h"
24*81ad6265SDimitry Andric #include "llvm/IR/IntrinsicsSPIRV.h"
25*81ad6265SDimitry Andric #include "llvm/Target/TargetIntrinsicInfo.h"
26*81ad6265SDimitry Andric 
27*81ad6265SDimitry Andric #define DEBUG_TYPE "spirv-prelegalizer"
28*81ad6265SDimitry Andric 
29*81ad6265SDimitry Andric using namespace llvm;
30*81ad6265SDimitry Andric 
31*81ad6265SDimitry Andric namespace {
32*81ad6265SDimitry Andric class SPIRVPreLegalizer : public MachineFunctionPass {
33*81ad6265SDimitry Andric public:
34*81ad6265SDimitry Andric   static char ID;
35*81ad6265SDimitry Andric   SPIRVPreLegalizer() : MachineFunctionPass(ID) {
36*81ad6265SDimitry Andric     initializeSPIRVPreLegalizerPass(*PassRegistry::getPassRegistry());
37*81ad6265SDimitry Andric   }
38*81ad6265SDimitry Andric   bool runOnMachineFunction(MachineFunction &MF) override;
39*81ad6265SDimitry Andric };
40*81ad6265SDimitry Andric } // namespace
41*81ad6265SDimitry Andric 
42*81ad6265SDimitry Andric static bool isSpvIntrinsic(MachineInstr &MI, Intrinsic::ID IntrinsicID) {
43*81ad6265SDimitry Andric   if (MI.getOpcode() == TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS &&
44*81ad6265SDimitry Andric       MI.getIntrinsicID() == IntrinsicID)
45*81ad6265SDimitry Andric     return true;
46*81ad6265SDimitry Andric   return false;
47*81ad6265SDimitry Andric }
48*81ad6265SDimitry Andric 
49*81ad6265SDimitry Andric static void foldConstantsIntoIntrinsics(MachineFunction &MF) {
50*81ad6265SDimitry Andric   SmallVector<MachineInstr *, 10> ToErase;
51*81ad6265SDimitry Andric   MachineRegisterInfo &MRI = MF.getRegInfo();
52*81ad6265SDimitry Andric   const unsigned AssignNameOperandShift = 2;
53*81ad6265SDimitry Andric   for (MachineBasicBlock &MBB : MF) {
54*81ad6265SDimitry Andric     for (MachineInstr &MI : MBB) {
55*81ad6265SDimitry Andric       if (!isSpvIntrinsic(MI, Intrinsic::spv_assign_name))
56*81ad6265SDimitry Andric         continue;
57*81ad6265SDimitry Andric       unsigned NumOp = MI.getNumExplicitDefs() + AssignNameOperandShift;
58*81ad6265SDimitry Andric       while (MI.getOperand(NumOp).isReg()) {
59*81ad6265SDimitry Andric         MachineOperand &MOp = MI.getOperand(NumOp);
60*81ad6265SDimitry Andric         MachineInstr *ConstMI = MRI.getVRegDef(MOp.getReg());
61*81ad6265SDimitry Andric         assert(ConstMI->getOpcode() == TargetOpcode::G_CONSTANT);
62*81ad6265SDimitry Andric         MI.removeOperand(NumOp);
63*81ad6265SDimitry Andric         MI.addOperand(MachineOperand::CreateImm(
64*81ad6265SDimitry Andric             ConstMI->getOperand(1).getCImm()->getZExtValue()));
65*81ad6265SDimitry Andric         if (MRI.use_empty(ConstMI->getOperand(0).getReg()))
66*81ad6265SDimitry Andric           ToErase.push_back(ConstMI);
67*81ad6265SDimitry Andric       }
68*81ad6265SDimitry Andric     }
69*81ad6265SDimitry Andric   }
70*81ad6265SDimitry Andric   for (MachineInstr *MI : ToErase)
71*81ad6265SDimitry Andric     MI->eraseFromParent();
72*81ad6265SDimitry Andric }
73*81ad6265SDimitry Andric 
74*81ad6265SDimitry Andric static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
75*81ad6265SDimitry Andric                            MachineIRBuilder MIB) {
76*81ad6265SDimitry Andric   SmallVector<MachineInstr *, 10> ToErase;
77*81ad6265SDimitry Andric   for (MachineBasicBlock &MBB : MF) {
78*81ad6265SDimitry Andric     for (MachineInstr &MI : MBB) {
79*81ad6265SDimitry Andric       if (!isSpvIntrinsic(MI, Intrinsic::spv_bitcast))
80*81ad6265SDimitry Andric         continue;
81*81ad6265SDimitry Andric       assert(MI.getOperand(2).isReg());
82*81ad6265SDimitry Andric       MIB.setInsertPt(*MI.getParent(), MI);
83*81ad6265SDimitry Andric       MIB.buildBitcast(MI.getOperand(0).getReg(), MI.getOperand(2).getReg());
84*81ad6265SDimitry Andric       ToErase.push_back(&MI);
85*81ad6265SDimitry Andric     }
86*81ad6265SDimitry Andric   }
87*81ad6265SDimitry Andric   for (MachineInstr *MI : ToErase)
88*81ad6265SDimitry Andric     MI->eraseFromParent();
89*81ad6265SDimitry Andric }
90*81ad6265SDimitry Andric 
91*81ad6265SDimitry Andric // Translating GV, IRTranslator sometimes generates following IR:
92*81ad6265SDimitry Andric //   %1 = G_GLOBAL_VALUE
93*81ad6265SDimitry Andric //   %2 = COPY %1
94*81ad6265SDimitry Andric //   %3 = G_ADDRSPACE_CAST %2
95*81ad6265SDimitry Andric // New registers have no SPIRVType and no register class info.
96*81ad6265SDimitry Andric //
97*81ad6265SDimitry Andric // Set SPIRVType for GV, propagate it from GV to other instructions,
98*81ad6265SDimitry Andric // also set register classes.
99*81ad6265SDimitry Andric static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
100*81ad6265SDimitry Andric                                      MachineRegisterInfo &MRI,
101*81ad6265SDimitry Andric                                      MachineIRBuilder &MIB) {
102*81ad6265SDimitry Andric   SPIRVType *SpirvTy = nullptr;
103*81ad6265SDimitry Andric   assert(MI && "Machine instr is expected");
104*81ad6265SDimitry Andric   if (MI->getOperand(0).isReg()) {
105*81ad6265SDimitry Andric     Register Reg = MI->getOperand(0).getReg();
106*81ad6265SDimitry Andric     SpirvTy = GR->getSPIRVTypeForVReg(Reg);
107*81ad6265SDimitry Andric     if (!SpirvTy) {
108*81ad6265SDimitry Andric       switch (MI->getOpcode()) {
109*81ad6265SDimitry Andric       case TargetOpcode::G_CONSTANT: {
110*81ad6265SDimitry Andric         MIB.setInsertPt(*MI->getParent(), MI);
111*81ad6265SDimitry Andric         Type *Ty = MI->getOperand(1).getCImm()->getType();
112*81ad6265SDimitry Andric         SpirvTy = GR->getOrCreateSPIRVType(Ty, MIB);
113*81ad6265SDimitry Andric         break;
114*81ad6265SDimitry Andric       }
115*81ad6265SDimitry Andric       case TargetOpcode::G_GLOBAL_VALUE: {
116*81ad6265SDimitry Andric         MIB.setInsertPt(*MI->getParent(), MI);
117*81ad6265SDimitry Andric         Type *Ty = MI->getOperand(1).getGlobal()->getType();
118*81ad6265SDimitry Andric         SpirvTy = GR->getOrCreateSPIRVType(Ty, MIB);
119*81ad6265SDimitry Andric         break;
120*81ad6265SDimitry Andric       }
121*81ad6265SDimitry Andric       case TargetOpcode::G_TRUNC:
122*81ad6265SDimitry Andric       case TargetOpcode::G_ADDRSPACE_CAST:
123*81ad6265SDimitry Andric       case TargetOpcode::COPY: {
124*81ad6265SDimitry Andric         MachineOperand &Op = MI->getOperand(1);
125*81ad6265SDimitry Andric         MachineInstr *Def = Op.isReg() ? MRI.getVRegDef(Op.getReg()) : nullptr;
126*81ad6265SDimitry Andric         if (Def)
127*81ad6265SDimitry Andric           SpirvTy = propagateSPIRVType(Def, GR, MRI, MIB);
128*81ad6265SDimitry Andric         break;
129*81ad6265SDimitry Andric       }
130*81ad6265SDimitry Andric       default:
131*81ad6265SDimitry Andric         break;
132*81ad6265SDimitry Andric       }
133*81ad6265SDimitry Andric       if (SpirvTy)
134*81ad6265SDimitry Andric         GR->assignSPIRVTypeToVReg(SpirvTy, Reg, MIB.getMF());
135*81ad6265SDimitry Andric       if (!MRI.getRegClassOrNull(Reg))
136*81ad6265SDimitry Andric         MRI.setRegClass(Reg, &SPIRV::IDRegClass);
137*81ad6265SDimitry Andric     }
138*81ad6265SDimitry Andric   }
139*81ad6265SDimitry Andric   return SpirvTy;
140*81ad6265SDimitry Andric }
141*81ad6265SDimitry Andric 
142*81ad6265SDimitry Andric // Insert ASSIGN_TYPE instuction between Reg and its definition, set NewReg as
143*81ad6265SDimitry Andric // a dst of the definition, assign SPIRVType to both registers. If SpirvTy is
144*81ad6265SDimitry Andric // provided, use it as SPIRVType in ASSIGN_TYPE, otherwise create it from Ty.
145*81ad6265SDimitry Andric // TODO: maybe move to SPIRVUtils.
146*81ad6265SDimitry Andric static Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
147*81ad6265SDimitry Andric                                   SPIRVGlobalRegistry *GR,
148*81ad6265SDimitry Andric                                   MachineIRBuilder &MIB,
149*81ad6265SDimitry Andric                                   MachineRegisterInfo &MRI) {
150*81ad6265SDimitry Andric   MachineInstr *Def = MRI.getVRegDef(Reg);
151*81ad6265SDimitry Andric   assert((Ty || SpirvTy) && "Either LLVM or SPIRV type is expected.");
152*81ad6265SDimitry Andric   MIB.setInsertPt(*Def->getParent(),
153*81ad6265SDimitry Andric                   (Def->getNextNode() ? Def->getNextNode()->getIterator()
154*81ad6265SDimitry Andric                                       : Def->getParent()->end()));
155*81ad6265SDimitry Andric   Register NewReg = MRI.createGenericVirtualRegister(MRI.getType(Reg));
156*81ad6265SDimitry Andric   if (auto *RC = MRI.getRegClassOrNull(Reg))
157*81ad6265SDimitry Andric     MRI.setRegClass(NewReg, RC);
158*81ad6265SDimitry Andric   SpirvTy = SpirvTy ? SpirvTy : GR->getOrCreateSPIRVType(Ty, MIB);
159*81ad6265SDimitry Andric   GR->assignSPIRVTypeToVReg(SpirvTy, Reg, MIB.getMF());
160*81ad6265SDimitry Andric   // This is to make it convenient for Legalizer to get the SPIRVType
161*81ad6265SDimitry Andric   // when processing the actual MI (i.e. not pseudo one).
162*81ad6265SDimitry Andric   GR->assignSPIRVTypeToVReg(SpirvTy, NewReg, MIB.getMF());
163*81ad6265SDimitry Andric   MIB.buildInstr(SPIRV::ASSIGN_TYPE)
164*81ad6265SDimitry Andric       .addDef(Reg)
165*81ad6265SDimitry Andric       .addUse(NewReg)
166*81ad6265SDimitry Andric       .addUse(GR->getSPIRVTypeID(SpirvTy));
167*81ad6265SDimitry Andric   Def->getOperand(0).setReg(NewReg);
168*81ad6265SDimitry Andric   MRI.setRegClass(Reg, &SPIRV::ANYIDRegClass);
169*81ad6265SDimitry Andric   return NewReg;
170*81ad6265SDimitry Andric }
171*81ad6265SDimitry Andric 
172*81ad6265SDimitry Andric static void generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
173*81ad6265SDimitry Andric                                  MachineIRBuilder MIB) {
174*81ad6265SDimitry Andric   MachineRegisterInfo &MRI = MF.getRegInfo();
175*81ad6265SDimitry Andric   SmallVector<MachineInstr *, 10> ToErase;
176*81ad6265SDimitry Andric 
177*81ad6265SDimitry Andric   for (MachineBasicBlock *MBB : post_order(&MF)) {
178*81ad6265SDimitry Andric     if (MBB->empty())
179*81ad6265SDimitry Andric       continue;
180*81ad6265SDimitry Andric 
181*81ad6265SDimitry Andric     bool ReachedBegin = false;
182*81ad6265SDimitry Andric     for (auto MII = std::prev(MBB->end()), Begin = MBB->begin();
183*81ad6265SDimitry Andric          !ReachedBegin;) {
184*81ad6265SDimitry Andric       MachineInstr &MI = *MII;
185*81ad6265SDimitry Andric 
186*81ad6265SDimitry Andric       if (isSpvIntrinsic(MI, Intrinsic::spv_assign_type)) {
187*81ad6265SDimitry Andric         Register Reg = MI.getOperand(1).getReg();
188*81ad6265SDimitry Andric         Type *Ty = getMDOperandAsType(MI.getOperand(2).getMetadata(), 0);
189*81ad6265SDimitry Andric         MachineInstr *Def = MRI.getVRegDef(Reg);
190*81ad6265SDimitry Andric         assert(Def && "Expecting an instruction that defines the register");
191*81ad6265SDimitry Andric         // G_GLOBAL_VALUE already has type info.
192*81ad6265SDimitry Andric         if (Def->getOpcode() != TargetOpcode::G_GLOBAL_VALUE)
193*81ad6265SDimitry Andric           insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MF.getRegInfo());
194*81ad6265SDimitry Andric         ToErase.push_back(&MI);
195*81ad6265SDimitry Andric       } else if (MI.getOpcode() == TargetOpcode::G_CONSTANT ||
196*81ad6265SDimitry Andric                  MI.getOpcode() == TargetOpcode::G_FCONSTANT ||
197*81ad6265SDimitry Andric                  MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR) {
198*81ad6265SDimitry Andric         // %rc = G_CONSTANT ty Val
199*81ad6265SDimitry Andric         // ===>
200*81ad6265SDimitry Andric         // %cty = OpType* ty
201*81ad6265SDimitry Andric         // %rctmp = G_CONSTANT ty Val
202*81ad6265SDimitry Andric         // %rc = ASSIGN_TYPE %rctmp, %cty
203*81ad6265SDimitry Andric         Register Reg = MI.getOperand(0).getReg();
204*81ad6265SDimitry Andric         if (MRI.hasOneUse(Reg)) {
205*81ad6265SDimitry Andric           MachineInstr &UseMI = *MRI.use_instr_begin(Reg);
206*81ad6265SDimitry Andric           if (isSpvIntrinsic(UseMI, Intrinsic::spv_assign_type) ||
207*81ad6265SDimitry Andric               isSpvIntrinsic(UseMI, Intrinsic::spv_assign_name))
208*81ad6265SDimitry Andric             continue;
209*81ad6265SDimitry Andric         }
210*81ad6265SDimitry Andric         Type *Ty = nullptr;
211*81ad6265SDimitry Andric         if (MI.getOpcode() == TargetOpcode::G_CONSTANT)
212*81ad6265SDimitry Andric           Ty = MI.getOperand(1).getCImm()->getType();
213*81ad6265SDimitry Andric         else if (MI.getOpcode() == TargetOpcode::G_FCONSTANT)
214*81ad6265SDimitry Andric           Ty = MI.getOperand(1).getFPImm()->getType();
215*81ad6265SDimitry Andric         else {
216*81ad6265SDimitry Andric           assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR);
217*81ad6265SDimitry Andric           Type *ElemTy = nullptr;
218*81ad6265SDimitry Andric           MachineInstr *ElemMI = MRI.getVRegDef(MI.getOperand(1).getReg());
219*81ad6265SDimitry Andric           assert(ElemMI);
220*81ad6265SDimitry Andric 
221*81ad6265SDimitry Andric           if (ElemMI->getOpcode() == TargetOpcode::G_CONSTANT)
222*81ad6265SDimitry Andric             ElemTy = ElemMI->getOperand(1).getCImm()->getType();
223*81ad6265SDimitry Andric           else if (ElemMI->getOpcode() == TargetOpcode::G_FCONSTANT)
224*81ad6265SDimitry Andric             ElemTy = ElemMI->getOperand(1).getFPImm()->getType();
225*81ad6265SDimitry Andric           else
226*81ad6265SDimitry Andric             llvm_unreachable("Unexpected opcode");
227*81ad6265SDimitry Andric           unsigned NumElts =
228*81ad6265SDimitry Andric               MI.getNumExplicitOperands() - MI.getNumExplicitDefs();
229*81ad6265SDimitry Andric           Ty = VectorType::get(ElemTy, NumElts, false);
230*81ad6265SDimitry Andric         }
231*81ad6265SDimitry Andric         insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MRI);
232*81ad6265SDimitry Andric       } else if (MI.getOpcode() == TargetOpcode::G_TRUNC ||
233*81ad6265SDimitry Andric                  MI.getOpcode() == TargetOpcode::G_GLOBAL_VALUE ||
234*81ad6265SDimitry Andric                  MI.getOpcode() == TargetOpcode::COPY ||
235*81ad6265SDimitry Andric                  MI.getOpcode() == TargetOpcode::G_ADDRSPACE_CAST) {
236*81ad6265SDimitry Andric         propagateSPIRVType(&MI, GR, MRI, MIB);
237*81ad6265SDimitry Andric       }
238*81ad6265SDimitry Andric 
239*81ad6265SDimitry Andric       if (MII == Begin)
240*81ad6265SDimitry Andric         ReachedBegin = true;
241*81ad6265SDimitry Andric       else
242*81ad6265SDimitry Andric         --MII;
243*81ad6265SDimitry Andric     }
244*81ad6265SDimitry Andric   }
245*81ad6265SDimitry Andric   for (MachineInstr *MI : ToErase)
246*81ad6265SDimitry Andric     MI->eraseFromParent();
247*81ad6265SDimitry Andric }
248*81ad6265SDimitry Andric 
249*81ad6265SDimitry Andric static std::pair<Register, unsigned>
250*81ad6265SDimitry Andric createNewIdReg(Register ValReg, unsigned Opcode, MachineRegisterInfo &MRI,
251*81ad6265SDimitry Andric                const SPIRVGlobalRegistry &GR) {
252*81ad6265SDimitry Andric   LLT NewT = LLT::scalar(32);
253*81ad6265SDimitry Andric   SPIRVType *SpvType = GR.getSPIRVTypeForVReg(ValReg);
254*81ad6265SDimitry Andric   assert(SpvType && "VReg is expected to have SPIRV type");
255*81ad6265SDimitry Andric   bool IsFloat = SpvType->getOpcode() == SPIRV::OpTypeFloat;
256*81ad6265SDimitry Andric   bool IsVectorFloat =
257*81ad6265SDimitry Andric       SpvType->getOpcode() == SPIRV::OpTypeVector &&
258*81ad6265SDimitry Andric       GR.getSPIRVTypeForVReg(SpvType->getOperand(1).getReg())->getOpcode() ==
259*81ad6265SDimitry Andric           SPIRV::OpTypeFloat;
260*81ad6265SDimitry Andric   IsFloat |= IsVectorFloat;
261*81ad6265SDimitry Andric   auto GetIdOp = IsFloat ? SPIRV::GET_fID : SPIRV::GET_ID;
262*81ad6265SDimitry Andric   auto DstClass = IsFloat ? &SPIRV::fIDRegClass : &SPIRV::IDRegClass;
263*81ad6265SDimitry Andric   if (MRI.getType(ValReg).isPointer()) {
264*81ad6265SDimitry Andric     NewT = LLT::pointer(0, 32);
265*81ad6265SDimitry Andric     GetIdOp = SPIRV::GET_pID;
266*81ad6265SDimitry Andric     DstClass = &SPIRV::pIDRegClass;
267*81ad6265SDimitry Andric   } else if (MRI.getType(ValReg).isVector()) {
268*81ad6265SDimitry Andric     NewT = LLT::fixed_vector(2, NewT);
269*81ad6265SDimitry Andric     GetIdOp = IsFloat ? SPIRV::GET_vfID : SPIRV::GET_vID;
270*81ad6265SDimitry Andric     DstClass = IsFloat ? &SPIRV::vfIDRegClass : &SPIRV::vIDRegClass;
271*81ad6265SDimitry Andric   }
272*81ad6265SDimitry Andric   Register IdReg = MRI.createGenericVirtualRegister(NewT);
273*81ad6265SDimitry Andric   MRI.setRegClass(IdReg, DstClass);
274*81ad6265SDimitry Andric   return {IdReg, GetIdOp};
275*81ad6265SDimitry Andric }
276*81ad6265SDimitry Andric 
277*81ad6265SDimitry Andric static void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
278*81ad6265SDimitry Andric                          MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR) {
279*81ad6265SDimitry Andric   unsigned Opc = MI.getOpcode();
280*81ad6265SDimitry Andric   assert(MI.getNumDefs() > 0 && MRI.hasOneUse(MI.getOperand(0).getReg()));
281*81ad6265SDimitry Andric   MachineInstr &AssignTypeInst =
282*81ad6265SDimitry Andric       *(MRI.use_instr_begin(MI.getOperand(0).getReg()));
283*81ad6265SDimitry Andric   auto NewReg = createNewIdReg(MI.getOperand(0).getReg(), Opc, MRI, *GR).first;
284*81ad6265SDimitry Andric   AssignTypeInst.getOperand(1).setReg(NewReg);
285*81ad6265SDimitry Andric   MI.getOperand(0).setReg(NewReg);
286*81ad6265SDimitry Andric   MIB.setInsertPt(*MI.getParent(),
287*81ad6265SDimitry Andric                   (MI.getNextNode() ? MI.getNextNode()->getIterator()
288*81ad6265SDimitry Andric                                     : MI.getParent()->end()));
289*81ad6265SDimitry Andric   for (auto &Op : MI.operands()) {
290*81ad6265SDimitry Andric     if (!Op.isReg() || Op.isDef())
291*81ad6265SDimitry Andric       continue;
292*81ad6265SDimitry Andric     auto IdOpInfo = createNewIdReg(Op.getReg(), Opc, MRI, *GR);
293*81ad6265SDimitry Andric     MIB.buildInstr(IdOpInfo.second).addDef(IdOpInfo.first).addUse(Op.getReg());
294*81ad6265SDimitry Andric     Op.setReg(IdOpInfo.first);
295*81ad6265SDimitry Andric   }
296*81ad6265SDimitry Andric }
297*81ad6265SDimitry Andric 
298*81ad6265SDimitry Andric // Defined in SPIRVLegalizerInfo.cpp.
299*81ad6265SDimitry Andric extern bool isTypeFoldingSupported(unsigned Opcode);
300*81ad6265SDimitry Andric 
301*81ad6265SDimitry Andric static void processInstrsWithTypeFolding(MachineFunction &MF,
302*81ad6265SDimitry Andric                                          SPIRVGlobalRegistry *GR,
303*81ad6265SDimitry Andric                                          MachineIRBuilder MIB) {
304*81ad6265SDimitry Andric   MachineRegisterInfo &MRI = MF.getRegInfo();
305*81ad6265SDimitry Andric   for (MachineBasicBlock &MBB : MF) {
306*81ad6265SDimitry Andric     for (MachineInstr &MI : MBB) {
307*81ad6265SDimitry Andric       if (isTypeFoldingSupported(MI.getOpcode()))
308*81ad6265SDimitry Andric         processInstr(MI, MIB, MRI, GR);
309*81ad6265SDimitry Andric     }
310*81ad6265SDimitry Andric   }
311*81ad6265SDimitry Andric }
312*81ad6265SDimitry Andric 
313*81ad6265SDimitry Andric static void processSwitches(MachineFunction &MF, SPIRVGlobalRegistry *GR,
314*81ad6265SDimitry Andric                             MachineIRBuilder MIB) {
315*81ad6265SDimitry Andric   DenseMap<Register, SmallDenseMap<uint64_t, MachineBasicBlock *>>
316*81ad6265SDimitry Andric       SwitchRegToMBB;
317*81ad6265SDimitry Andric   DenseMap<Register, MachineBasicBlock *> DefaultMBBs;
318*81ad6265SDimitry Andric   DenseSet<Register> SwitchRegs;
319*81ad6265SDimitry Andric   MachineRegisterInfo &MRI = MF.getRegInfo();
320*81ad6265SDimitry Andric   // Before IRTranslator pass, spv_switch calls are inserted before each
321*81ad6265SDimitry Andric   // switch instruction. IRTranslator lowers switches to ICMP+CBr+Br triples.
322*81ad6265SDimitry Andric   // A switch with two cases may be translated to this MIR sequesnce:
323*81ad6265SDimitry Andric   //   intrinsic(@llvm.spv.switch), %CmpReg, %Const0, %Const1
324*81ad6265SDimitry Andric   //   %Dst0 = G_ICMP intpred(eq), %CmpReg, %Const0
325*81ad6265SDimitry Andric   //   G_BRCOND %Dst0, %bb.2
326*81ad6265SDimitry Andric   //   G_BR %bb.5
327*81ad6265SDimitry Andric   // bb.5.entry:
328*81ad6265SDimitry Andric   //   %Dst1 = G_ICMP intpred(eq), %CmpReg, %Const1
329*81ad6265SDimitry Andric   //   G_BRCOND %Dst1, %bb.3
330*81ad6265SDimitry Andric   //   G_BR %bb.4
331*81ad6265SDimitry Andric   // bb.2.sw.bb:
332*81ad6265SDimitry Andric   //   ...
333*81ad6265SDimitry Andric   // bb.3.sw.bb1:
334*81ad6265SDimitry Andric   //   ...
335*81ad6265SDimitry Andric   // bb.4.sw.epilog:
336*81ad6265SDimitry Andric   //   ...
337*81ad6265SDimitry Andric   // Walk MIs and collect information about destination MBBs to update
338*81ad6265SDimitry Andric   // spv_switch call. We assume that all spv_switch precede corresponding ICMPs.
339*81ad6265SDimitry Andric   for (MachineBasicBlock &MBB : MF) {
340*81ad6265SDimitry Andric     for (MachineInstr &MI : MBB) {
341*81ad6265SDimitry Andric       if (isSpvIntrinsic(MI, Intrinsic::spv_switch)) {
342*81ad6265SDimitry Andric         assert(MI.getOperand(1).isReg());
343*81ad6265SDimitry Andric         Register Reg = MI.getOperand(1).getReg();
344*81ad6265SDimitry Andric         SwitchRegs.insert(Reg);
345*81ad6265SDimitry Andric         // Set the first successor as default MBB to support empty switches.
346*81ad6265SDimitry Andric         DefaultMBBs[Reg] = *MBB.succ_begin();
347*81ad6265SDimitry Andric       }
348*81ad6265SDimitry Andric       // Process only ICMPs that relate to spv_switches.
349*81ad6265SDimitry Andric       if (MI.getOpcode() == TargetOpcode::G_ICMP && MI.getOperand(2).isReg() &&
350*81ad6265SDimitry Andric           SwitchRegs.contains(MI.getOperand(2).getReg())) {
351*81ad6265SDimitry Andric         assert(MI.getOperand(0).isReg() && MI.getOperand(1).isPredicate() &&
352*81ad6265SDimitry Andric                MI.getOperand(3).isReg());
353*81ad6265SDimitry Andric         Register Dst = MI.getOperand(0).getReg();
354*81ad6265SDimitry Andric         // Set type info for destination register of switch's ICMP instruction.
355*81ad6265SDimitry Andric         if (GR->getSPIRVTypeForVReg(Dst) == nullptr) {
356*81ad6265SDimitry Andric           MIB.setInsertPt(*MI.getParent(), MI);
357*81ad6265SDimitry Andric           Type *LLVMTy = IntegerType::get(MF.getFunction().getContext(), 1);
358*81ad6265SDimitry Andric           SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(LLVMTy, MIB);
359*81ad6265SDimitry Andric           MRI.setRegClass(Dst, &SPIRV::IDRegClass);
360*81ad6265SDimitry Andric           GR->assignSPIRVTypeToVReg(SpirvTy, Dst, MIB.getMF());
361*81ad6265SDimitry Andric         }
362*81ad6265SDimitry Andric         Register CmpReg = MI.getOperand(2).getReg();
363*81ad6265SDimitry Andric         MachineOperand &PredOp = MI.getOperand(1);
364*81ad6265SDimitry Andric         const auto CC = static_cast<CmpInst::Predicate>(PredOp.getPredicate());
365*81ad6265SDimitry Andric         assert(CC == CmpInst::ICMP_EQ && MRI.hasOneUse(Dst) &&
366*81ad6265SDimitry Andric                MRI.hasOneDef(CmpReg));
367*81ad6265SDimitry Andric         uint64_t Val = getIConstVal(MI.getOperand(3).getReg(), &MRI);
368*81ad6265SDimitry Andric         MachineInstr *CBr = MRI.use_begin(Dst)->getParent();
369*81ad6265SDimitry Andric         assert(CBr->getOpcode() == SPIRV::G_BRCOND &&
370*81ad6265SDimitry Andric                CBr->getOperand(1).isMBB());
371*81ad6265SDimitry Andric         SwitchRegToMBB[CmpReg][Val] = CBr->getOperand(1).getMBB();
372*81ad6265SDimitry Andric         // The next MI is always BR to either the next case or the default.
373*81ad6265SDimitry Andric         MachineInstr *NextMI = CBr->getNextNode();
374*81ad6265SDimitry Andric         assert(NextMI->getOpcode() == SPIRV::G_BR &&
375*81ad6265SDimitry Andric                NextMI->getOperand(0).isMBB());
376*81ad6265SDimitry Andric         MachineBasicBlock *NextMBB = NextMI->getOperand(0).getMBB();
377*81ad6265SDimitry Andric         assert(NextMBB != nullptr);
378*81ad6265SDimitry Andric         // The default MBB is not started by ICMP with switch's cmp register.
379*81ad6265SDimitry Andric         if (NextMBB->front().getOpcode() != SPIRV::G_ICMP ||
380*81ad6265SDimitry Andric             (NextMBB->front().getOperand(2).isReg() &&
381*81ad6265SDimitry Andric              NextMBB->front().getOperand(2).getReg() != CmpReg))
382*81ad6265SDimitry Andric           DefaultMBBs[CmpReg] = NextMBB;
383*81ad6265SDimitry Andric       }
384*81ad6265SDimitry Andric     }
385*81ad6265SDimitry Andric   }
386*81ad6265SDimitry Andric   // Modify spv_switch's operands by collected values. For the example above,
387*81ad6265SDimitry Andric   // the result will be like this:
388*81ad6265SDimitry Andric   //   intrinsic(@llvm.spv.switch), %CmpReg, %bb.4, i32 0, %bb.2, i32 1, %bb.3
389*81ad6265SDimitry Andric   // Note that ICMP+CBr+Br sequences are not removed, but ModuleAnalysis marks
390*81ad6265SDimitry Andric   // them as skipped and AsmPrinter does not output them.
391*81ad6265SDimitry Andric   for (MachineBasicBlock &MBB : MF) {
392*81ad6265SDimitry Andric     for (MachineInstr &MI : MBB) {
393*81ad6265SDimitry Andric       if (!isSpvIntrinsic(MI, Intrinsic::spv_switch))
394*81ad6265SDimitry Andric         continue;
395*81ad6265SDimitry Andric       assert(MI.getOperand(1).isReg());
396*81ad6265SDimitry Andric       Register Reg = MI.getOperand(1).getReg();
397*81ad6265SDimitry Andric       unsigned NumOp = MI.getNumExplicitOperands();
398*81ad6265SDimitry Andric       SmallVector<const ConstantInt *, 3> Vals;
399*81ad6265SDimitry Andric       SmallVector<MachineBasicBlock *, 3> MBBs;
400*81ad6265SDimitry Andric       for (unsigned i = 2; i < NumOp; i++) {
401*81ad6265SDimitry Andric         Register CReg = MI.getOperand(i).getReg();
402*81ad6265SDimitry Andric         uint64_t Val = getIConstVal(CReg, &MRI);
403*81ad6265SDimitry Andric         MachineInstr *ConstInstr = getDefInstrMaybeConstant(CReg, &MRI);
404*81ad6265SDimitry Andric         Vals.push_back(ConstInstr->getOperand(1).getCImm());
405*81ad6265SDimitry Andric         MBBs.push_back(SwitchRegToMBB[Reg][Val]);
406*81ad6265SDimitry Andric       }
407*81ad6265SDimitry Andric       for (unsigned i = MI.getNumExplicitOperands() - 1; i > 1; i--)
408*81ad6265SDimitry Andric         MI.removeOperand(i);
409*81ad6265SDimitry Andric       MI.addOperand(MachineOperand::CreateMBB(DefaultMBBs[Reg]));
410*81ad6265SDimitry Andric       for (unsigned i = 0; i < Vals.size(); i++) {
411*81ad6265SDimitry Andric         MI.addOperand(MachineOperand::CreateCImm(Vals[i]));
412*81ad6265SDimitry Andric         MI.addOperand(MachineOperand::CreateMBB(MBBs[i]));
413*81ad6265SDimitry Andric       }
414*81ad6265SDimitry Andric     }
415*81ad6265SDimitry Andric   }
416*81ad6265SDimitry Andric }
417*81ad6265SDimitry Andric 
418*81ad6265SDimitry Andric bool SPIRVPreLegalizer::runOnMachineFunction(MachineFunction &MF) {
419*81ad6265SDimitry Andric   // Initialize the type registry.
420*81ad6265SDimitry Andric   const SPIRVSubtarget &ST = MF.getSubtarget<SPIRVSubtarget>();
421*81ad6265SDimitry Andric   SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
422*81ad6265SDimitry Andric   GR->setCurrentFunc(MF);
423*81ad6265SDimitry Andric   MachineIRBuilder MIB(MF);
424*81ad6265SDimitry Andric   foldConstantsIntoIntrinsics(MF);
425*81ad6265SDimitry Andric   insertBitcasts(MF, GR, MIB);
426*81ad6265SDimitry Andric   generateAssignInstrs(MF, GR, MIB);
427*81ad6265SDimitry Andric   processInstrsWithTypeFolding(MF, GR, MIB);
428*81ad6265SDimitry Andric   processSwitches(MF, GR, MIB);
429*81ad6265SDimitry Andric 
430*81ad6265SDimitry Andric   return true;
431*81ad6265SDimitry Andric }
432*81ad6265SDimitry Andric 
433*81ad6265SDimitry Andric INITIALIZE_PASS(SPIRVPreLegalizer, DEBUG_TYPE, "SPIRV pre legalizer", false,
434*81ad6265SDimitry Andric                 false)
435*81ad6265SDimitry Andric 
436*81ad6265SDimitry Andric char SPIRVPreLegalizer::ID = 0;
437*81ad6265SDimitry Andric 
438*81ad6265SDimitry Andric FunctionPass *llvm::createSPIRVPreLegalizerPass() {
439*81ad6265SDimitry Andric   return new SPIRVPreLegalizer();
440*81ad6265SDimitry Andric }
441