xref: /freebsd-src/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (revision 753f127f3ace09432b2baeffd71a308760641a62)
1 //===-- SPIRVPreLegalizer.cpp - prepare IR for legalization -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// The pass prepares IR for legalization: it assigns SPIR-V types to registers
// and removes the intrinsics that held these types during IR translation.
// It also processes constants and registers them in the global registry (GR)
// to avoid duplication.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "SPIRV.h"
16 #include "SPIRVGlobalRegistry.h"
17 #include "SPIRVSubtarget.h"
18 #include "SPIRVUtils.h"
19 #include "llvm/ADT/PostOrderIterator.h"
20 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
21 #include "llvm/IR/Attributes.h"
22 #include "llvm/IR/Constants.h"
23 #include "llvm/IR/DebugInfoMetadata.h"
24 #include "llvm/IR/IntrinsicsSPIRV.h"
25 #include "llvm/Target/TargetIntrinsicInfo.h"
26 
27 #define DEBUG_TYPE "spirv-prelegalizer"
28 
29 using namespace llvm;
30 
31 namespace {
32 class SPIRVPreLegalizer : public MachineFunctionPass {
33 public:
34   static char ID;
35   SPIRVPreLegalizer() : MachineFunctionPass(ID) {
36     initializeSPIRVPreLegalizerPass(*PassRegistry::getPassRegistry());
37   }
38   bool runOnMachineFunction(MachineFunction &MF) override;
39 };
40 } // namespace
41 
42 static bool isSpvIntrinsic(MachineInstr &MI, Intrinsic::ID IntrinsicID) {
43   if (MI.getOpcode() == TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS &&
44       MI.getIntrinsicID() == IntrinsicID)
45     return true;
46   return false;
47 }
48 
49 static void foldConstantsIntoIntrinsics(MachineFunction &MF) {
50   SmallVector<MachineInstr *, 10> ToErase;
51   MachineRegisterInfo &MRI = MF.getRegInfo();
52   const unsigned AssignNameOperandShift = 2;
53   for (MachineBasicBlock &MBB : MF) {
54     for (MachineInstr &MI : MBB) {
55       if (!isSpvIntrinsic(MI, Intrinsic::spv_assign_name))
56         continue;
57       unsigned NumOp = MI.getNumExplicitDefs() + AssignNameOperandShift;
58       while (MI.getOperand(NumOp).isReg()) {
59         MachineOperand &MOp = MI.getOperand(NumOp);
60         MachineInstr *ConstMI = MRI.getVRegDef(MOp.getReg());
61         assert(ConstMI->getOpcode() == TargetOpcode::G_CONSTANT);
62         MI.removeOperand(NumOp);
63         MI.addOperand(MachineOperand::CreateImm(
64             ConstMI->getOperand(1).getCImm()->getZExtValue()));
65         if (MRI.use_empty(ConstMI->getOperand(0).getReg()))
66           ToErase.push_back(ConstMI);
67       }
68     }
69   }
70   for (MachineInstr *MI : ToErase)
71     MI->eraseFromParent();
72 }
73 
74 static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
75                            MachineIRBuilder MIB) {
76   SmallVector<MachineInstr *, 10> ToErase;
77   for (MachineBasicBlock &MBB : MF) {
78     for (MachineInstr &MI : MBB) {
79       if (!isSpvIntrinsic(MI, Intrinsic::spv_bitcast))
80         continue;
81       assert(MI.getOperand(2).isReg());
82       MIB.setInsertPt(*MI.getParent(), MI);
83       MIB.buildBitcast(MI.getOperand(0).getReg(), MI.getOperand(2).getReg());
84       ToErase.push_back(&MI);
85     }
86   }
87   for (MachineInstr *MI : ToErase)
88     MI->eraseFromParent();
89 }
90 
91 // Translating GV, IRTranslator sometimes generates following IR:
92 //   %1 = G_GLOBAL_VALUE
93 //   %2 = COPY %1
94 //   %3 = G_ADDRSPACE_CAST %2
95 // New registers have no SPIRVType and no register class info.
96 //
97 // Set SPIRVType for GV, propagate it from GV to other instructions,
98 // also set register classes.
99 static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
100                                      MachineRegisterInfo &MRI,
101                                      MachineIRBuilder &MIB) {
102   SPIRVType *SpirvTy = nullptr;
103   assert(MI && "Machine instr is expected");
104   if (MI->getOperand(0).isReg()) {
105     Register Reg = MI->getOperand(0).getReg();
106     SpirvTy = GR->getSPIRVTypeForVReg(Reg);
107     if (!SpirvTy) {
108       switch (MI->getOpcode()) {
109       case TargetOpcode::G_CONSTANT: {
110         MIB.setInsertPt(*MI->getParent(), MI);
111         Type *Ty = MI->getOperand(1).getCImm()->getType();
112         SpirvTy = GR->getOrCreateSPIRVType(Ty, MIB);
113         break;
114       }
115       case TargetOpcode::G_GLOBAL_VALUE: {
116         MIB.setInsertPt(*MI->getParent(), MI);
117         Type *Ty = MI->getOperand(1).getGlobal()->getType();
118         SpirvTy = GR->getOrCreateSPIRVType(Ty, MIB);
119         break;
120       }
121       case TargetOpcode::G_TRUNC:
122       case TargetOpcode::G_ADDRSPACE_CAST:
123       case TargetOpcode::COPY: {
124         MachineOperand &Op = MI->getOperand(1);
125         MachineInstr *Def = Op.isReg() ? MRI.getVRegDef(Op.getReg()) : nullptr;
126         if (Def)
127           SpirvTy = propagateSPIRVType(Def, GR, MRI, MIB);
128         break;
129       }
130       default:
131         break;
132       }
133       if (SpirvTy)
134         GR->assignSPIRVTypeToVReg(SpirvTy, Reg, MIB.getMF());
135       if (!MRI.getRegClassOrNull(Reg))
136         MRI.setRegClass(Reg, &SPIRV::IDRegClass);
137     }
138   }
139   return SpirvTy;
140 }
141 
142 // Insert ASSIGN_TYPE instuction between Reg and its definition, set NewReg as
143 // a dst of the definition, assign SPIRVType to both registers. If SpirvTy is
144 // provided, use it as SPIRVType in ASSIGN_TYPE, otherwise create it from Ty.
145 // TODO: maybe move to SPIRVUtils.
146 static Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
147                                   SPIRVGlobalRegistry *GR,
148                                   MachineIRBuilder &MIB,
149                                   MachineRegisterInfo &MRI) {
150   MachineInstr *Def = MRI.getVRegDef(Reg);
151   assert((Ty || SpirvTy) && "Either LLVM or SPIRV type is expected.");
152   MIB.setInsertPt(*Def->getParent(),
153                   (Def->getNextNode() ? Def->getNextNode()->getIterator()
154                                       : Def->getParent()->end()));
155   Register NewReg = MRI.createGenericVirtualRegister(MRI.getType(Reg));
156   if (auto *RC = MRI.getRegClassOrNull(Reg))
157     MRI.setRegClass(NewReg, RC);
158   SpirvTy = SpirvTy ? SpirvTy : GR->getOrCreateSPIRVType(Ty, MIB);
159   GR->assignSPIRVTypeToVReg(SpirvTy, Reg, MIB.getMF());
160   // This is to make it convenient for Legalizer to get the SPIRVType
161   // when processing the actual MI (i.e. not pseudo one).
162   GR->assignSPIRVTypeToVReg(SpirvTy, NewReg, MIB.getMF());
163   MIB.buildInstr(SPIRV::ASSIGN_TYPE)
164       .addDef(Reg)
165       .addUse(NewReg)
166       .addUse(GR->getSPIRVTypeID(SpirvTy));
167   Def->getOperand(0).setReg(NewReg);
168   MRI.setRegClass(Reg, &SPIRV::ANYIDRegClass);
169   return NewReg;
170 }
171 
172 static void generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
173                                  MachineIRBuilder MIB) {
174   MachineRegisterInfo &MRI = MF.getRegInfo();
175   SmallVector<MachineInstr *, 10> ToErase;
176 
177   for (MachineBasicBlock *MBB : post_order(&MF)) {
178     if (MBB->empty())
179       continue;
180 
181     bool ReachedBegin = false;
182     for (auto MII = std::prev(MBB->end()), Begin = MBB->begin();
183          !ReachedBegin;) {
184       MachineInstr &MI = *MII;
185 
186       if (isSpvIntrinsic(MI, Intrinsic::spv_assign_type)) {
187         Register Reg = MI.getOperand(1).getReg();
188         Type *Ty = getMDOperandAsType(MI.getOperand(2).getMetadata(), 0);
189         MachineInstr *Def = MRI.getVRegDef(Reg);
190         assert(Def && "Expecting an instruction that defines the register");
191         // G_GLOBAL_VALUE already has type info.
192         if (Def->getOpcode() != TargetOpcode::G_GLOBAL_VALUE)
193           insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MF.getRegInfo());
194         ToErase.push_back(&MI);
195       } else if (MI.getOpcode() == TargetOpcode::G_CONSTANT ||
196                  MI.getOpcode() == TargetOpcode::G_FCONSTANT ||
197                  MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR) {
198         // %rc = G_CONSTANT ty Val
199         // ===>
200         // %cty = OpType* ty
201         // %rctmp = G_CONSTANT ty Val
202         // %rc = ASSIGN_TYPE %rctmp, %cty
203         Register Reg = MI.getOperand(0).getReg();
204         if (MRI.hasOneUse(Reg)) {
205           MachineInstr &UseMI = *MRI.use_instr_begin(Reg);
206           if (isSpvIntrinsic(UseMI, Intrinsic::spv_assign_type) ||
207               isSpvIntrinsic(UseMI, Intrinsic::spv_assign_name))
208             continue;
209         }
210         Type *Ty = nullptr;
211         if (MI.getOpcode() == TargetOpcode::G_CONSTANT)
212           Ty = MI.getOperand(1).getCImm()->getType();
213         else if (MI.getOpcode() == TargetOpcode::G_FCONSTANT)
214           Ty = MI.getOperand(1).getFPImm()->getType();
215         else {
216           assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR);
217           Type *ElemTy = nullptr;
218           MachineInstr *ElemMI = MRI.getVRegDef(MI.getOperand(1).getReg());
219           assert(ElemMI);
220 
221           if (ElemMI->getOpcode() == TargetOpcode::G_CONSTANT)
222             ElemTy = ElemMI->getOperand(1).getCImm()->getType();
223           else if (ElemMI->getOpcode() == TargetOpcode::G_FCONSTANT)
224             ElemTy = ElemMI->getOperand(1).getFPImm()->getType();
225           else
226             llvm_unreachable("Unexpected opcode");
227           unsigned NumElts =
228               MI.getNumExplicitOperands() - MI.getNumExplicitDefs();
229           Ty = VectorType::get(ElemTy, NumElts, false);
230         }
231         insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MRI);
232       } else if (MI.getOpcode() == TargetOpcode::G_TRUNC ||
233                  MI.getOpcode() == TargetOpcode::G_GLOBAL_VALUE ||
234                  MI.getOpcode() == TargetOpcode::COPY ||
235                  MI.getOpcode() == TargetOpcode::G_ADDRSPACE_CAST) {
236         propagateSPIRVType(&MI, GR, MRI, MIB);
237       }
238 
239       if (MII == Begin)
240         ReachedBegin = true;
241       else
242         --MII;
243     }
244   }
245   for (MachineInstr *MI : ToErase)
246     MI->eraseFromParent();
247 }
248 
249 static std::pair<Register, unsigned>
250 createNewIdReg(Register ValReg, unsigned Opcode, MachineRegisterInfo &MRI,
251                const SPIRVGlobalRegistry &GR) {
252   LLT NewT = LLT::scalar(32);
253   SPIRVType *SpvType = GR.getSPIRVTypeForVReg(ValReg);
254   assert(SpvType && "VReg is expected to have SPIRV type");
255   bool IsFloat = SpvType->getOpcode() == SPIRV::OpTypeFloat;
256   bool IsVectorFloat =
257       SpvType->getOpcode() == SPIRV::OpTypeVector &&
258       GR.getSPIRVTypeForVReg(SpvType->getOperand(1).getReg())->getOpcode() ==
259           SPIRV::OpTypeFloat;
260   IsFloat |= IsVectorFloat;
261   auto GetIdOp = IsFloat ? SPIRV::GET_fID : SPIRV::GET_ID;
262   auto DstClass = IsFloat ? &SPIRV::fIDRegClass : &SPIRV::IDRegClass;
263   if (MRI.getType(ValReg).isPointer()) {
264     NewT = LLT::pointer(0, 32);
265     GetIdOp = SPIRV::GET_pID;
266     DstClass = &SPIRV::pIDRegClass;
267   } else if (MRI.getType(ValReg).isVector()) {
268     NewT = LLT::fixed_vector(2, NewT);
269     GetIdOp = IsFloat ? SPIRV::GET_vfID : SPIRV::GET_vID;
270     DstClass = IsFloat ? &SPIRV::vfIDRegClass : &SPIRV::vIDRegClass;
271   }
272   Register IdReg = MRI.createGenericVirtualRegister(NewT);
273   MRI.setRegClass(IdReg, DstClass);
274   return {IdReg, GetIdOp};
275 }
276 
277 static void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
278                          MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR) {
279   unsigned Opc = MI.getOpcode();
280   assert(MI.getNumDefs() > 0 && MRI.hasOneUse(MI.getOperand(0).getReg()));
281   MachineInstr &AssignTypeInst =
282       *(MRI.use_instr_begin(MI.getOperand(0).getReg()));
283   auto NewReg = createNewIdReg(MI.getOperand(0).getReg(), Opc, MRI, *GR).first;
284   AssignTypeInst.getOperand(1).setReg(NewReg);
285   MI.getOperand(0).setReg(NewReg);
286   MIB.setInsertPt(*MI.getParent(),
287                   (MI.getNextNode() ? MI.getNextNode()->getIterator()
288                                     : MI.getParent()->end()));
289   for (auto &Op : MI.operands()) {
290     if (!Op.isReg() || Op.isDef())
291       continue;
292     auto IdOpInfo = createNewIdReg(Op.getReg(), Opc, MRI, *GR);
293     MIB.buildInstr(IdOpInfo.second).addDef(IdOpInfo.first).addUse(Op.getReg());
294     Op.setReg(IdOpInfo.first);
295   }
296 }
297 
298 // Defined in SPIRVLegalizerInfo.cpp.
299 extern bool isTypeFoldingSupported(unsigned Opcode);
300 
301 static void processInstrsWithTypeFolding(MachineFunction &MF,
302                                          SPIRVGlobalRegistry *GR,
303                                          MachineIRBuilder MIB) {
304   MachineRegisterInfo &MRI = MF.getRegInfo();
305   for (MachineBasicBlock &MBB : MF) {
306     for (MachineInstr &MI : MBB) {
307       if (isTypeFoldingSupported(MI.getOpcode()))
308         processInstr(MI, MIB, MRI, GR);
309     }
310   }
311 }
312 
313 static void processSwitches(MachineFunction &MF, SPIRVGlobalRegistry *GR,
314                             MachineIRBuilder MIB) {
315   DenseMap<Register, SmallDenseMap<uint64_t, MachineBasicBlock *>>
316       SwitchRegToMBB;
317   DenseMap<Register, MachineBasicBlock *> DefaultMBBs;
318   DenseSet<Register> SwitchRegs;
319   MachineRegisterInfo &MRI = MF.getRegInfo();
320   // Before IRTranslator pass, spv_switch calls are inserted before each
321   // switch instruction. IRTranslator lowers switches to ICMP+CBr+Br triples.
322   // A switch with two cases may be translated to this MIR sequesnce:
323   //   intrinsic(@llvm.spv.switch), %CmpReg, %Const0, %Const1
324   //   %Dst0 = G_ICMP intpred(eq), %CmpReg, %Const0
325   //   G_BRCOND %Dst0, %bb.2
326   //   G_BR %bb.5
327   // bb.5.entry:
328   //   %Dst1 = G_ICMP intpred(eq), %CmpReg, %Const1
329   //   G_BRCOND %Dst1, %bb.3
330   //   G_BR %bb.4
331   // bb.2.sw.bb:
332   //   ...
333   // bb.3.sw.bb1:
334   //   ...
335   // bb.4.sw.epilog:
336   //   ...
337   // Walk MIs and collect information about destination MBBs to update
338   // spv_switch call. We assume that all spv_switch precede corresponding ICMPs.
339   for (MachineBasicBlock &MBB : MF) {
340     for (MachineInstr &MI : MBB) {
341       if (isSpvIntrinsic(MI, Intrinsic::spv_switch)) {
342         assert(MI.getOperand(1).isReg());
343         Register Reg = MI.getOperand(1).getReg();
344         SwitchRegs.insert(Reg);
345         // Set the first successor as default MBB to support empty switches.
346         DefaultMBBs[Reg] = *MBB.succ_begin();
347       }
348       // Process only ICMPs that relate to spv_switches.
349       if (MI.getOpcode() == TargetOpcode::G_ICMP && MI.getOperand(2).isReg() &&
350           SwitchRegs.contains(MI.getOperand(2).getReg())) {
351         assert(MI.getOperand(0).isReg() && MI.getOperand(1).isPredicate() &&
352                MI.getOperand(3).isReg());
353         Register Dst = MI.getOperand(0).getReg();
354         // Set type info for destination register of switch's ICMP instruction.
355         if (GR->getSPIRVTypeForVReg(Dst) == nullptr) {
356           MIB.setInsertPt(*MI.getParent(), MI);
357           Type *LLVMTy = IntegerType::get(MF.getFunction().getContext(), 1);
358           SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(LLVMTy, MIB);
359           MRI.setRegClass(Dst, &SPIRV::IDRegClass);
360           GR->assignSPIRVTypeToVReg(SpirvTy, Dst, MIB.getMF());
361         }
362         Register CmpReg = MI.getOperand(2).getReg();
363         MachineOperand &PredOp = MI.getOperand(1);
364         const auto CC = static_cast<CmpInst::Predicate>(PredOp.getPredicate());
365         assert(CC == CmpInst::ICMP_EQ && MRI.hasOneUse(Dst) &&
366                MRI.hasOneDef(CmpReg));
367         uint64_t Val = getIConstVal(MI.getOperand(3).getReg(), &MRI);
368         MachineInstr *CBr = MRI.use_begin(Dst)->getParent();
369         assert(CBr->getOpcode() == SPIRV::G_BRCOND &&
370                CBr->getOperand(1).isMBB());
371         SwitchRegToMBB[CmpReg][Val] = CBr->getOperand(1).getMBB();
372         // The next MI is always BR to either the next case or the default.
373         MachineInstr *NextMI = CBr->getNextNode();
374         assert(NextMI->getOpcode() == SPIRV::G_BR &&
375                NextMI->getOperand(0).isMBB());
376         MachineBasicBlock *NextMBB = NextMI->getOperand(0).getMBB();
377         assert(NextMBB != nullptr);
378         // The default MBB is not started by ICMP with switch's cmp register.
379         if (NextMBB->front().getOpcode() != SPIRV::G_ICMP ||
380             (NextMBB->front().getOperand(2).isReg() &&
381              NextMBB->front().getOperand(2).getReg() != CmpReg))
382           DefaultMBBs[CmpReg] = NextMBB;
383       }
384     }
385   }
386   // Modify spv_switch's operands by collected values. For the example above,
387   // the result will be like this:
388   //   intrinsic(@llvm.spv.switch), %CmpReg, %bb.4, i32 0, %bb.2, i32 1, %bb.3
389   // Note that ICMP+CBr+Br sequences are not removed, but ModuleAnalysis marks
390   // them as skipped and AsmPrinter does not output them.
391   for (MachineBasicBlock &MBB : MF) {
392     for (MachineInstr &MI : MBB) {
393       if (!isSpvIntrinsic(MI, Intrinsic::spv_switch))
394         continue;
395       assert(MI.getOperand(1).isReg());
396       Register Reg = MI.getOperand(1).getReg();
397       unsigned NumOp = MI.getNumExplicitOperands();
398       SmallVector<const ConstantInt *, 3> Vals;
399       SmallVector<MachineBasicBlock *, 3> MBBs;
400       for (unsigned i = 2; i < NumOp; i++) {
401         Register CReg = MI.getOperand(i).getReg();
402         uint64_t Val = getIConstVal(CReg, &MRI);
403         MachineInstr *ConstInstr = getDefInstrMaybeConstant(CReg, &MRI);
404         Vals.push_back(ConstInstr->getOperand(1).getCImm());
405         MBBs.push_back(SwitchRegToMBB[Reg][Val]);
406       }
407       for (unsigned i = MI.getNumExplicitOperands() - 1; i > 1; i--)
408         MI.removeOperand(i);
409       MI.addOperand(MachineOperand::CreateMBB(DefaultMBBs[Reg]));
410       for (unsigned i = 0; i < Vals.size(); i++) {
411         MI.addOperand(MachineOperand::CreateCImm(Vals[i]));
412         MI.addOperand(MachineOperand::CreateMBB(MBBs[i]));
413       }
414     }
415   }
416 }
417 
418 bool SPIRVPreLegalizer::runOnMachineFunction(MachineFunction &MF) {
419   // Initialize the type registry.
420   const SPIRVSubtarget &ST = MF.getSubtarget<SPIRVSubtarget>();
421   SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
422   GR->setCurrentFunc(MF);
423   MachineIRBuilder MIB(MF);
424   foldConstantsIntoIntrinsics(MF);
425   insertBitcasts(MF, GR, MIB);
426   generateAssignInstrs(MF, GR, MIB);
427   processInstrsWithTypeFolding(MF, GR, MIB);
428   processSwitches(MF, GR, MIB);
429 
430   return true;
431 }
432 
433 INITIALIZE_PASS(SPIRVPreLegalizer, DEBUG_TYPE, "SPIRV pre legalizer", false,
434                 false)
435 
436 char SPIRVPreLegalizer::ID = 0;
437 
438 FunctionPass *llvm::createSPIRVPreLegalizerPass() {
439   return new SPIRVPreLegalizer();
440 }
441