xref: /freebsd-src/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (revision fcaf7f8644a9988098ac6be2165bce3ea4786e91)
181ad6265SDimitry Andric //===-- SPIRVPreLegalizer.cpp - prepare IR for legalization -----*- C++ -*-===//
281ad6265SDimitry Andric //
381ad6265SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
481ad6265SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
581ad6265SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
681ad6265SDimitry Andric //
781ad6265SDimitry Andric //===----------------------------------------------------------------------===//
881ad6265SDimitry Andric //
// The pass prepares IR for legalization: it assigns SPIR-V types to registers
// and removes intrinsics which held these types during IR translation.
// It also processes constants and registers them in GR to avoid duplication.
1281ad6265SDimitry Andric //
1381ad6265SDimitry Andric //===----------------------------------------------------------------------===//
1481ad6265SDimitry Andric 
1581ad6265SDimitry Andric #include "SPIRV.h"
1681ad6265SDimitry Andric #include "SPIRVGlobalRegistry.h"
1781ad6265SDimitry Andric #include "SPIRVSubtarget.h"
1881ad6265SDimitry Andric #include "SPIRVUtils.h"
1981ad6265SDimitry Andric #include "llvm/ADT/PostOrderIterator.h"
2081ad6265SDimitry Andric #include "llvm/Analysis/OptimizationRemarkEmitter.h"
2181ad6265SDimitry Andric #include "llvm/IR/Attributes.h"
2281ad6265SDimitry Andric #include "llvm/IR/Constants.h"
2381ad6265SDimitry Andric #include "llvm/IR/DebugInfoMetadata.h"
2481ad6265SDimitry Andric #include "llvm/IR/IntrinsicsSPIRV.h"
2581ad6265SDimitry Andric #include "llvm/Target/TargetIntrinsicInfo.h"
2681ad6265SDimitry Andric 
2781ad6265SDimitry Andric #define DEBUG_TYPE "spirv-prelegalizer"
2881ad6265SDimitry Andric 
2981ad6265SDimitry Andric using namespace llvm;
3081ad6265SDimitry Andric 
3181ad6265SDimitry Andric namespace {
3281ad6265SDimitry Andric class SPIRVPreLegalizer : public MachineFunctionPass {
3381ad6265SDimitry Andric public:
3481ad6265SDimitry Andric   static char ID;
3581ad6265SDimitry Andric   SPIRVPreLegalizer() : MachineFunctionPass(ID) {
3681ad6265SDimitry Andric     initializeSPIRVPreLegalizerPass(*PassRegistry::getPassRegistry());
3781ad6265SDimitry Andric   }
3881ad6265SDimitry Andric   bool runOnMachineFunction(MachineFunction &MF) override;
3981ad6265SDimitry Andric };
4081ad6265SDimitry Andric } // namespace
4181ad6265SDimitry Andric 
42*fcaf7f86SDimitry Andric static void addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR) {
43*fcaf7f86SDimitry Andric   MachineRegisterInfo &MRI = MF.getRegInfo();
44*fcaf7f86SDimitry Andric   DenseMap<MachineInstr *, Register> RegsAlreadyAddedToDT;
45*fcaf7f86SDimitry Andric   SmallVector<MachineInstr *, 10> ToErase, ToEraseComposites;
46*fcaf7f86SDimitry Andric   for (MachineBasicBlock &MBB : MF) {
47*fcaf7f86SDimitry Andric     for (MachineInstr &MI : MBB) {
48*fcaf7f86SDimitry Andric       if (!isSpvIntrinsic(MI, Intrinsic::spv_track_constant))
49*fcaf7f86SDimitry Andric         continue;
50*fcaf7f86SDimitry Andric       ToErase.push_back(&MI);
51*fcaf7f86SDimitry Andric       auto *Const =
52*fcaf7f86SDimitry Andric           cast<Constant>(cast<ConstantAsMetadata>(
53*fcaf7f86SDimitry Andric                              MI.getOperand(3).getMetadata()->getOperand(0))
54*fcaf7f86SDimitry Andric                              ->getValue());
55*fcaf7f86SDimitry Andric       if (auto *GV = dyn_cast<GlobalValue>(Const)) {
56*fcaf7f86SDimitry Andric         Register Reg = GR->find(GV, &MF);
57*fcaf7f86SDimitry Andric         if (!Reg.isValid())
58*fcaf7f86SDimitry Andric           GR->add(GV, &MF, MI.getOperand(2).getReg());
59*fcaf7f86SDimitry Andric         else
60*fcaf7f86SDimitry Andric           RegsAlreadyAddedToDT[&MI] = Reg;
61*fcaf7f86SDimitry Andric       } else {
62*fcaf7f86SDimitry Andric         Register Reg = GR->find(Const, &MF);
63*fcaf7f86SDimitry Andric         if (!Reg.isValid()) {
64*fcaf7f86SDimitry Andric           if (auto *ConstVec = dyn_cast<ConstantDataVector>(Const)) {
65*fcaf7f86SDimitry Andric             auto *BuildVec = MRI.getVRegDef(MI.getOperand(2).getReg());
66*fcaf7f86SDimitry Andric             assert(BuildVec &&
67*fcaf7f86SDimitry Andric                    BuildVec->getOpcode() == TargetOpcode::G_BUILD_VECTOR);
68*fcaf7f86SDimitry Andric             for (unsigned i = 0; i < ConstVec->getNumElements(); ++i)
69*fcaf7f86SDimitry Andric               GR->add(ConstVec->getElementAsConstant(i), &MF,
70*fcaf7f86SDimitry Andric                       BuildVec->getOperand(1 + i).getReg());
71*fcaf7f86SDimitry Andric           }
72*fcaf7f86SDimitry Andric           GR->add(Const, &MF, MI.getOperand(2).getReg());
73*fcaf7f86SDimitry Andric         } else {
74*fcaf7f86SDimitry Andric           RegsAlreadyAddedToDT[&MI] = Reg;
75*fcaf7f86SDimitry Andric           // This MI is unused and will be removed. If the MI uses
76*fcaf7f86SDimitry Andric           // const_composite, it will be unused and should be removed too.
77*fcaf7f86SDimitry Andric           assert(MI.getOperand(2).isReg() && "Reg operand is expected");
78*fcaf7f86SDimitry Andric           MachineInstr *SrcMI = MRI.getVRegDef(MI.getOperand(2).getReg());
79*fcaf7f86SDimitry Andric           if (SrcMI && isSpvIntrinsic(*SrcMI, Intrinsic::spv_const_composite))
80*fcaf7f86SDimitry Andric             ToEraseComposites.push_back(SrcMI);
81*fcaf7f86SDimitry Andric         }
82*fcaf7f86SDimitry Andric       }
83*fcaf7f86SDimitry Andric     }
84*fcaf7f86SDimitry Andric   }
85*fcaf7f86SDimitry Andric   for (MachineInstr *MI : ToErase) {
86*fcaf7f86SDimitry Andric     Register Reg = MI->getOperand(2).getReg();
87*fcaf7f86SDimitry Andric     if (RegsAlreadyAddedToDT.find(MI) != RegsAlreadyAddedToDT.end())
88*fcaf7f86SDimitry Andric       Reg = RegsAlreadyAddedToDT[MI];
89*fcaf7f86SDimitry Andric     MRI.replaceRegWith(MI->getOperand(0).getReg(), Reg);
90*fcaf7f86SDimitry Andric     MI->eraseFromParent();
91*fcaf7f86SDimitry Andric   }
92*fcaf7f86SDimitry Andric   for (MachineInstr *MI : ToEraseComposites)
93*fcaf7f86SDimitry Andric     MI->eraseFromParent();
9481ad6265SDimitry Andric }
9581ad6265SDimitry Andric 
9681ad6265SDimitry Andric static void foldConstantsIntoIntrinsics(MachineFunction &MF) {
9781ad6265SDimitry Andric   SmallVector<MachineInstr *, 10> ToErase;
9881ad6265SDimitry Andric   MachineRegisterInfo &MRI = MF.getRegInfo();
9981ad6265SDimitry Andric   const unsigned AssignNameOperandShift = 2;
10081ad6265SDimitry Andric   for (MachineBasicBlock &MBB : MF) {
10181ad6265SDimitry Andric     for (MachineInstr &MI : MBB) {
10281ad6265SDimitry Andric       if (!isSpvIntrinsic(MI, Intrinsic::spv_assign_name))
10381ad6265SDimitry Andric         continue;
10481ad6265SDimitry Andric       unsigned NumOp = MI.getNumExplicitDefs() + AssignNameOperandShift;
10581ad6265SDimitry Andric       while (MI.getOperand(NumOp).isReg()) {
10681ad6265SDimitry Andric         MachineOperand &MOp = MI.getOperand(NumOp);
10781ad6265SDimitry Andric         MachineInstr *ConstMI = MRI.getVRegDef(MOp.getReg());
10881ad6265SDimitry Andric         assert(ConstMI->getOpcode() == TargetOpcode::G_CONSTANT);
10981ad6265SDimitry Andric         MI.removeOperand(NumOp);
11081ad6265SDimitry Andric         MI.addOperand(MachineOperand::CreateImm(
11181ad6265SDimitry Andric             ConstMI->getOperand(1).getCImm()->getZExtValue()));
11281ad6265SDimitry Andric         if (MRI.use_empty(ConstMI->getOperand(0).getReg()))
11381ad6265SDimitry Andric           ToErase.push_back(ConstMI);
11481ad6265SDimitry Andric       }
11581ad6265SDimitry Andric     }
11681ad6265SDimitry Andric   }
11781ad6265SDimitry Andric   for (MachineInstr *MI : ToErase)
11881ad6265SDimitry Andric     MI->eraseFromParent();
11981ad6265SDimitry Andric }
12081ad6265SDimitry Andric 
12181ad6265SDimitry Andric static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
12281ad6265SDimitry Andric                            MachineIRBuilder MIB) {
12381ad6265SDimitry Andric   SmallVector<MachineInstr *, 10> ToErase;
12481ad6265SDimitry Andric   for (MachineBasicBlock &MBB : MF) {
12581ad6265SDimitry Andric     for (MachineInstr &MI : MBB) {
12681ad6265SDimitry Andric       if (!isSpvIntrinsic(MI, Intrinsic::spv_bitcast))
12781ad6265SDimitry Andric         continue;
12881ad6265SDimitry Andric       assert(MI.getOperand(2).isReg());
12981ad6265SDimitry Andric       MIB.setInsertPt(*MI.getParent(), MI);
13081ad6265SDimitry Andric       MIB.buildBitcast(MI.getOperand(0).getReg(), MI.getOperand(2).getReg());
13181ad6265SDimitry Andric       ToErase.push_back(&MI);
13281ad6265SDimitry Andric     }
13381ad6265SDimitry Andric   }
13481ad6265SDimitry Andric   for (MachineInstr *MI : ToErase)
13581ad6265SDimitry Andric     MI->eraseFromParent();
13681ad6265SDimitry Andric }
13781ad6265SDimitry Andric 
13881ad6265SDimitry Andric // Translating GV, IRTranslator sometimes generates following IR:
13981ad6265SDimitry Andric //   %1 = G_GLOBAL_VALUE
14081ad6265SDimitry Andric //   %2 = COPY %1
14181ad6265SDimitry Andric //   %3 = G_ADDRSPACE_CAST %2
14281ad6265SDimitry Andric // New registers have no SPIRVType and no register class info.
14381ad6265SDimitry Andric //
14481ad6265SDimitry Andric // Set SPIRVType for GV, propagate it from GV to other instructions,
14581ad6265SDimitry Andric // also set register classes.
14681ad6265SDimitry Andric static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
14781ad6265SDimitry Andric                                      MachineRegisterInfo &MRI,
14881ad6265SDimitry Andric                                      MachineIRBuilder &MIB) {
14981ad6265SDimitry Andric   SPIRVType *SpirvTy = nullptr;
15081ad6265SDimitry Andric   assert(MI && "Machine instr is expected");
15181ad6265SDimitry Andric   if (MI->getOperand(0).isReg()) {
15281ad6265SDimitry Andric     Register Reg = MI->getOperand(0).getReg();
15381ad6265SDimitry Andric     SpirvTy = GR->getSPIRVTypeForVReg(Reg);
15481ad6265SDimitry Andric     if (!SpirvTy) {
15581ad6265SDimitry Andric       switch (MI->getOpcode()) {
15681ad6265SDimitry Andric       case TargetOpcode::G_CONSTANT: {
15781ad6265SDimitry Andric         MIB.setInsertPt(*MI->getParent(), MI);
15881ad6265SDimitry Andric         Type *Ty = MI->getOperand(1).getCImm()->getType();
15981ad6265SDimitry Andric         SpirvTy = GR->getOrCreateSPIRVType(Ty, MIB);
16081ad6265SDimitry Andric         break;
16181ad6265SDimitry Andric       }
16281ad6265SDimitry Andric       case TargetOpcode::G_GLOBAL_VALUE: {
16381ad6265SDimitry Andric         MIB.setInsertPt(*MI->getParent(), MI);
16481ad6265SDimitry Andric         Type *Ty = MI->getOperand(1).getGlobal()->getType();
16581ad6265SDimitry Andric         SpirvTy = GR->getOrCreateSPIRVType(Ty, MIB);
16681ad6265SDimitry Andric         break;
16781ad6265SDimitry Andric       }
16881ad6265SDimitry Andric       case TargetOpcode::G_TRUNC:
16981ad6265SDimitry Andric       case TargetOpcode::G_ADDRSPACE_CAST:
170*fcaf7f86SDimitry Andric       case TargetOpcode::G_PTR_ADD:
17181ad6265SDimitry Andric       case TargetOpcode::COPY: {
17281ad6265SDimitry Andric         MachineOperand &Op = MI->getOperand(1);
17381ad6265SDimitry Andric         MachineInstr *Def = Op.isReg() ? MRI.getVRegDef(Op.getReg()) : nullptr;
17481ad6265SDimitry Andric         if (Def)
17581ad6265SDimitry Andric           SpirvTy = propagateSPIRVType(Def, GR, MRI, MIB);
17681ad6265SDimitry Andric         break;
17781ad6265SDimitry Andric       }
17881ad6265SDimitry Andric       default:
17981ad6265SDimitry Andric         break;
18081ad6265SDimitry Andric       }
18181ad6265SDimitry Andric       if (SpirvTy)
18281ad6265SDimitry Andric         GR->assignSPIRVTypeToVReg(SpirvTy, Reg, MIB.getMF());
18381ad6265SDimitry Andric       if (!MRI.getRegClassOrNull(Reg))
18481ad6265SDimitry Andric         MRI.setRegClass(Reg, &SPIRV::IDRegClass);
18581ad6265SDimitry Andric     }
18681ad6265SDimitry Andric   }
18781ad6265SDimitry Andric   return SpirvTy;
18881ad6265SDimitry Andric }
18981ad6265SDimitry Andric 
19081ad6265SDimitry Andric // Insert ASSIGN_TYPE instuction between Reg and its definition, set NewReg as
19181ad6265SDimitry Andric // a dst of the definition, assign SPIRVType to both registers. If SpirvTy is
19281ad6265SDimitry Andric // provided, use it as SPIRVType in ASSIGN_TYPE, otherwise create it from Ty.
19381ad6265SDimitry Andric // TODO: maybe move to SPIRVUtils.
19481ad6265SDimitry Andric static Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
19581ad6265SDimitry Andric                                   SPIRVGlobalRegistry *GR,
19681ad6265SDimitry Andric                                   MachineIRBuilder &MIB,
19781ad6265SDimitry Andric                                   MachineRegisterInfo &MRI) {
19881ad6265SDimitry Andric   MachineInstr *Def = MRI.getVRegDef(Reg);
19981ad6265SDimitry Andric   assert((Ty || SpirvTy) && "Either LLVM or SPIRV type is expected.");
20081ad6265SDimitry Andric   MIB.setInsertPt(*Def->getParent(),
20181ad6265SDimitry Andric                   (Def->getNextNode() ? Def->getNextNode()->getIterator()
20281ad6265SDimitry Andric                                       : Def->getParent()->end()));
20381ad6265SDimitry Andric   Register NewReg = MRI.createGenericVirtualRegister(MRI.getType(Reg));
20481ad6265SDimitry Andric   if (auto *RC = MRI.getRegClassOrNull(Reg))
20581ad6265SDimitry Andric     MRI.setRegClass(NewReg, RC);
20681ad6265SDimitry Andric   SpirvTy = SpirvTy ? SpirvTy : GR->getOrCreateSPIRVType(Ty, MIB);
20781ad6265SDimitry Andric   GR->assignSPIRVTypeToVReg(SpirvTy, Reg, MIB.getMF());
20881ad6265SDimitry Andric   // This is to make it convenient for Legalizer to get the SPIRVType
20981ad6265SDimitry Andric   // when processing the actual MI (i.e. not pseudo one).
21081ad6265SDimitry Andric   GR->assignSPIRVTypeToVReg(SpirvTy, NewReg, MIB.getMF());
21181ad6265SDimitry Andric   MIB.buildInstr(SPIRV::ASSIGN_TYPE)
21281ad6265SDimitry Andric       .addDef(Reg)
21381ad6265SDimitry Andric       .addUse(NewReg)
21481ad6265SDimitry Andric       .addUse(GR->getSPIRVTypeID(SpirvTy));
21581ad6265SDimitry Andric   Def->getOperand(0).setReg(NewReg);
21681ad6265SDimitry Andric   MRI.setRegClass(Reg, &SPIRV::ANYIDRegClass);
21781ad6265SDimitry Andric   return NewReg;
21881ad6265SDimitry Andric }
21981ad6265SDimitry Andric 
22081ad6265SDimitry Andric static void generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
22181ad6265SDimitry Andric                                  MachineIRBuilder MIB) {
22281ad6265SDimitry Andric   MachineRegisterInfo &MRI = MF.getRegInfo();
22381ad6265SDimitry Andric   SmallVector<MachineInstr *, 10> ToErase;
22481ad6265SDimitry Andric 
22581ad6265SDimitry Andric   for (MachineBasicBlock *MBB : post_order(&MF)) {
22681ad6265SDimitry Andric     if (MBB->empty())
22781ad6265SDimitry Andric       continue;
22881ad6265SDimitry Andric 
22981ad6265SDimitry Andric     bool ReachedBegin = false;
23081ad6265SDimitry Andric     for (auto MII = std::prev(MBB->end()), Begin = MBB->begin();
23181ad6265SDimitry Andric          !ReachedBegin;) {
23281ad6265SDimitry Andric       MachineInstr &MI = *MII;
23381ad6265SDimitry Andric 
23481ad6265SDimitry Andric       if (isSpvIntrinsic(MI, Intrinsic::spv_assign_type)) {
23581ad6265SDimitry Andric         Register Reg = MI.getOperand(1).getReg();
23681ad6265SDimitry Andric         Type *Ty = getMDOperandAsType(MI.getOperand(2).getMetadata(), 0);
23781ad6265SDimitry Andric         MachineInstr *Def = MRI.getVRegDef(Reg);
23881ad6265SDimitry Andric         assert(Def && "Expecting an instruction that defines the register");
23981ad6265SDimitry Andric         // G_GLOBAL_VALUE already has type info.
24081ad6265SDimitry Andric         if (Def->getOpcode() != TargetOpcode::G_GLOBAL_VALUE)
24181ad6265SDimitry Andric           insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MF.getRegInfo());
24281ad6265SDimitry Andric         ToErase.push_back(&MI);
24381ad6265SDimitry Andric       } else if (MI.getOpcode() == TargetOpcode::G_CONSTANT ||
24481ad6265SDimitry Andric                  MI.getOpcode() == TargetOpcode::G_FCONSTANT ||
24581ad6265SDimitry Andric                  MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR) {
24681ad6265SDimitry Andric         // %rc = G_CONSTANT ty Val
24781ad6265SDimitry Andric         // ===>
24881ad6265SDimitry Andric         // %cty = OpType* ty
24981ad6265SDimitry Andric         // %rctmp = G_CONSTANT ty Val
25081ad6265SDimitry Andric         // %rc = ASSIGN_TYPE %rctmp, %cty
25181ad6265SDimitry Andric         Register Reg = MI.getOperand(0).getReg();
25281ad6265SDimitry Andric         if (MRI.hasOneUse(Reg)) {
25381ad6265SDimitry Andric           MachineInstr &UseMI = *MRI.use_instr_begin(Reg);
25481ad6265SDimitry Andric           if (isSpvIntrinsic(UseMI, Intrinsic::spv_assign_type) ||
25581ad6265SDimitry Andric               isSpvIntrinsic(UseMI, Intrinsic::spv_assign_name))
25681ad6265SDimitry Andric             continue;
25781ad6265SDimitry Andric         }
25881ad6265SDimitry Andric         Type *Ty = nullptr;
25981ad6265SDimitry Andric         if (MI.getOpcode() == TargetOpcode::G_CONSTANT)
26081ad6265SDimitry Andric           Ty = MI.getOperand(1).getCImm()->getType();
26181ad6265SDimitry Andric         else if (MI.getOpcode() == TargetOpcode::G_FCONSTANT)
26281ad6265SDimitry Andric           Ty = MI.getOperand(1).getFPImm()->getType();
26381ad6265SDimitry Andric         else {
26481ad6265SDimitry Andric           assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR);
26581ad6265SDimitry Andric           Type *ElemTy = nullptr;
26681ad6265SDimitry Andric           MachineInstr *ElemMI = MRI.getVRegDef(MI.getOperand(1).getReg());
26781ad6265SDimitry Andric           assert(ElemMI);
26881ad6265SDimitry Andric 
26981ad6265SDimitry Andric           if (ElemMI->getOpcode() == TargetOpcode::G_CONSTANT)
27081ad6265SDimitry Andric             ElemTy = ElemMI->getOperand(1).getCImm()->getType();
27181ad6265SDimitry Andric           else if (ElemMI->getOpcode() == TargetOpcode::G_FCONSTANT)
27281ad6265SDimitry Andric             ElemTy = ElemMI->getOperand(1).getFPImm()->getType();
27381ad6265SDimitry Andric           else
27481ad6265SDimitry Andric             llvm_unreachable("Unexpected opcode");
27581ad6265SDimitry Andric           unsigned NumElts =
27681ad6265SDimitry Andric               MI.getNumExplicitOperands() - MI.getNumExplicitDefs();
27781ad6265SDimitry Andric           Ty = VectorType::get(ElemTy, NumElts, false);
27881ad6265SDimitry Andric         }
27981ad6265SDimitry Andric         insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MRI);
28081ad6265SDimitry Andric       } else if (MI.getOpcode() == TargetOpcode::G_TRUNC ||
28181ad6265SDimitry Andric                  MI.getOpcode() == TargetOpcode::G_GLOBAL_VALUE ||
28281ad6265SDimitry Andric                  MI.getOpcode() == TargetOpcode::COPY ||
28381ad6265SDimitry Andric                  MI.getOpcode() == TargetOpcode::G_ADDRSPACE_CAST) {
28481ad6265SDimitry Andric         propagateSPIRVType(&MI, GR, MRI, MIB);
28581ad6265SDimitry Andric       }
28681ad6265SDimitry Andric 
28781ad6265SDimitry Andric       if (MII == Begin)
28881ad6265SDimitry Andric         ReachedBegin = true;
28981ad6265SDimitry Andric       else
29081ad6265SDimitry Andric         --MII;
29181ad6265SDimitry Andric     }
29281ad6265SDimitry Andric   }
29381ad6265SDimitry Andric   for (MachineInstr *MI : ToErase)
29481ad6265SDimitry Andric     MI->eraseFromParent();
29581ad6265SDimitry Andric }
29681ad6265SDimitry Andric 
29781ad6265SDimitry Andric static std::pair<Register, unsigned>
29881ad6265SDimitry Andric createNewIdReg(Register ValReg, unsigned Opcode, MachineRegisterInfo &MRI,
29981ad6265SDimitry Andric                const SPIRVGlobalRegistry &GR) {
30081ad6265SDimitry Andric   LLT NewT = LLT::scalar(32);
30181ad6265SDimitry Andric   SPIRVType *SpvType = GR.getSPIRVTypeForVReg(ValReg);
30281ad6265SDimitry Andric   assert(SpvType && "VReg is expected to have SPIRV type");
30381ad6265SDimitry Andric   bool IsFloat = SpvType->getOpcode() == SPIRV::OpTypeFloat;
30481ad6265SDimitry Andric   bool IsVectorFloat =
30581ad6265SDimitry Andric       SpvType->getOpcode() == SPIRV::OpTypeVector &&
30681ad6265SDimitry Andric       GR.getSPIRVTypeForVReg(SpvType->getOperand(1).getReg())->getOpcode() ==
30781ad6265SDimitry Andric           SPIRV::OpTypeFloat;
30881ad6265SDimitry Andric   IsFloat |= IsVectorFloat;
30981ad6265SDimitry Andric   auto GetIdOp = IsFloat ? SPIRV::GET_fID : SPIRV::GET_ID;
31081ad6265SDimitry Andric   auto DstClass = IsFloat ? &SPIRV::fIDRegClass : &SPIRV::IDRegClass;
31181ad6265SDimitry Andric   if (MRI.getType(ValReg).isPointer()) {
31281ad6265SDimitry Andric     NewT = LLT::pointer(0, 32);
31381ad6265SDimitry Andric     GetIdOp = SPIRV::GET_pID;
31481ad6265SDimitry Andric     DstClass = &SPIRV::pIDRegClass;
31581ad6265SDimitry Andric   } else if (MRI.getType(ValReg).isVector()) {
31681ad6265SDimitry Andric     NewT = LLT::fixed_vector(2, NewT);
31781ad6265SDimitry Andric     GetIdOp = IsFloat ? SPIRV::GET_vfID : SPIRV::GET_vID;
31881ad6265SDimitry Andric     DstClass = IsFloat ? &SPIRV::vfIDRegClass : &SPIRV::vIDRegClass;
31981ad6265SDimitry Andric   }
32081ad6265SDimitry Andric   Register IdReg = MRI.createGenericVirtualRegister(NewT);
32181ad6265SDimitry Andric   MRI.setRegClass(IdReg, DstClass);
32281ad6265SDimitry Andric   return {IdReg, GetIdOp};
32381ad6265SDimitry Andric }
32481ad6265SDimitry Andric 
32581ad6265SDimitry Andric static void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
32681ad6265SDimitry Andric                          MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR) {
32781ad6265SDimitry Andric   unsigned Opc = MI.getOpcode();
32881ad6265SDimitry Andric   assert(MI.getNumDefs() > 0 && MRI.hasOneUse(MI.getOperand(0).getReg()));
32981ad6265SDimitry Andric   MachineInstr &AssignTypeInst =
33081ad6265SDimitry Andric       *(MRI.use_instr_begin(MI.getOperand(0).getReg()));
33181ad6265SDimitry Andric   auto NewReg = createNewIdReg(MI.getOperand(0).getReg(), Opc, MRI, *GR).first;
33281ad6265SDimitry Andric   AssignTypeInst.getOperand(1).setReg(NewReg);
33381ad6265SDimitry Andric   MI.getOperand(0).setReg(NewReg);
33481ad6265SDimitry Andric   MIB.setInsertPt(*MI.getParent(),
33581ad6265SDimitry Andric                   (MI.getNextNode() ? MI.getNextNode()->getIterator()
33681ad6265SDimitry Andric                                     : MI.getParent()->end()));
33781ad6265SDimitry Andric   for (auto &Op : MI.operands()) {
33881ad6265SDimitry Andric     if (!Op.isReg() || Op.isDef())
33981ad6265SDimitry Andric       continue;
34081ad6265SDimitry Andric     auto IdOpInfo = createNewIdReg(Op.getReg(), Opc, MRI, *GR);
34181ad6265SDimitry Andric     MIB.buildInstr(IdOpInfo.second).addDef(IdOpInfo.first).addUse(Op.getReg());
34281ad6265SDimitry Andric     Op.setReg(IdOpInfo.first);
34381ad6265SDimitry Andric   }
34481ad6265SDimitry Andric }
34581ad6265SDimitry Andric 
34681ad6265SDimitry Andric // Defined in SPIRVLegalizerInfo.cpp.
34781ad6265SDimitry Andric extern bool isTypeFoldingSupported(unsigned Opcode);
34881ad6265SDimitry Andric 
34981ad6265SDimitry Andric static void processInstrsWithTypeFolding(MachineFunction &MF,
35081ad6265SDimitry Andric                                          SPIRVGlobalRegistry *GR,
35181ad6265SDimitry Andric                                          MachineIRBuilder MIB) {
35281ad6265SDimitry Andric   MachineRegisterInfo &MRI = MF.getRegInfo();
35381ad6265SDimitry Andric   for (MachineBasicBlock &MBB : MF) {
35481ad6265SDimitry Andric     for (MachineInstr &MI : MBB) {
35581ad6265SDimitry Andric       if (isTypeFoldingSupported(MI.getOpcode()))
35681ad6265SDimitry Andric         processInstr(MI, MIB, MRI, GR);
35781ad6265SDimitry Andric     }
35881ad6265SDimitry Andric   }
359*fcaf7f86SDimitry Andric   for (MachineBasicBlock &MBB : MF) {
360*fcaf7f86SDimitry Andric     for (MachineInstr &MI : MBB) {
361*fcaf7f86SDimitry Andric       // We need to rewrite dst types for ASSIGN_TYPE instrs to be able
362*fcaf7f86SDimitry Andric       // to perform tblgen'erated selection and we can't do that on Legalizer
363*fcaf7f86SDimitry Andric       // as it operates on gMIR only.
364*fcaf7f86SDimitry Andric       if (MI.getOpcode() != SPIRV::ASSIGN_TYPE)
365*fcaf7f86SDimitry Andric         continue;
366*fcaf7f86SDimitry Andric       Register SrcReg = MI.getOperand(1).getReg();
367*fcaf7f86SDimitry Andric       if (!isTypeFoldingSupported(MRI.getVRegDef(SrcReg)->getOpcode()))
368*fcaf7f86SDimitry Andric         continue;
369*fcaf7f86SDimitry Andric       Register DstReg = MI.getOperand(0).getReg();
370*fcaf7f86SDimitry Andric       if (MRI.getType(DstReg).isVector())
371*fcaf7f86SDimitry Andric         MRI.setRegClass(DstReg, &SPIRV::IDRegClass);
372*fcaf7f86SDimitry Andric       MRI.setType(DstReg, LLT::scalar(32));
373*fcaf7f86SDimitry Andric     }
374*fcaf7f86SDimitry Andric   }
37581ad6265SDimitry Andric }
37681ad6265SDimitry Andric 
37781ad6265SDimitry Andric static void processSwitches(MachineFunction &MF, SPIRVGlobalRegistry *GR,
37881ad6265SDimitry Andric                             MachineIRBuilder MIB) {
37981ad6265SDimitry Andric   DenseMap<Register, SmallDenseMap<uint64_t, MachineBasicBlock *>>
38081ad6265SDimitry Andric       SwitchRegToMBB;
38181ad6265SDimitry Andric   DenseMap<Register, MachineBasicBlock *> DefaultMBBs;
38281ad6265SDimitry Andric   DenseSet<Register> SwitchRegs;
38381ad6265SDimitry Andric   MachineRegisterInfo &MRI = MF.getRegInfo();
38481ad6265SDimitry Andric   // Before IRTranslator pass, spv_switch calls are inserted before each
38581ad6265SDimitry Andric   // switch instruction. IRTranslator lowers switches to ICMP+CBr+Br triples.
38681ad6265SDimitry Andric   // A switch with two cases may be translated to this MIR sequesnce:
38781ad6265SDimitry Andric   //   intrinsic(@llvm.spv.switch), %CmpReg, %Const0, %Const1
38881ad6265SDimitry Andric   //   %Dst0 = G_ICMP intpred(eq), %CmpReg, %Const0
38981ad6265SDimitry Andric   //   G_BRCOND %Dst0, %bb.2
39081ad6265SDimitry Andric   //   G_BR %bb.5
39181ad6265SDimitry Andric   // bb.5.entry:
39281ad6265SDimitry Andric   //   %Dst1 = G_ICMP intpred(eq), %CmpReg, %Const1
39381ad6265SDimitry Andric   //   G_BRCOND %Dst1, %bb.3
39481ad6265SDimitry Andric   //   G_BR %bb.4
39581ad6265SDimitry Andric   // bb.2.sw.bb:
39681ad6265SDimitry Andric   //   ...
39781ad6265SDimitry Andric   // bb.3.sw.bb1:
39881ad6265SDimitry Andric   //   ...
39981ad6265SDimitry Andric   // bb.4.sw.epilog:
40081ad6265SDimitry Andric   //   ...
40181ad6265SDimitry Andric   // Walk MIs and collect information about destination MBBs to update
40281ad6265SDimitry Andric   // spv_switch call. We assume that all spv_switch precede corresponding ICMPs.
40381ad6265SDimitry Andric   for (MachineBasicBlock &MBB : MF) {
40481ad6265SDimitry Andric     for (MachineInstr &MI : MBB) {
40581ad6265SDimitry Andric       if (isSpvIntrinsic(MI, Intrinsic::spv_switch)) {
40681ad6265SDimitry Andric         assert(MI.getOperand(1).isReg());
40781ad6265SDimitry Andric         Register Reg = MI.getOperand(1).getReg();
40881ad6265SDimitry Andric         SwitchRegs.insert(Reg);
40981ad6265SDimitry Andric         // Set the first successor as default MBB to support empty switches.
41081ad6265SDimitry Andric         DefaultMBBs[Reg] = *MBB.succ_begin();
41181ad6265SDimitry Andric       }
41281ad6265SDimitry Andric       // Process only ICMPs that relate to spv_switches.
41381ad6265SDimitry Andric       if (MI.getOpcode() == TargetOpcode::G_ICMP && MI.getOperand(2).isReg() &&
41481ad6265SDimitry Andric           SwitchRegs.contains(MI.getOperand(2).getReg())) {
41581ad6265SDimitry Andric         assert(MI.getOperand(0).isReg() && MI.getOperand(1).isPredicate() &&
41681ad6265SDimitry Andric                MI.getOperand(3).isReg());
41781ad6265SDimitry Andric         Register Dst = MI.getOperand(0).getReg();
41881ad6265SDimitry Andric         // Set type info for destination register of switch's ICMP instruction.
41981ad6265SDimitry Andric         if (GR->getSPIRVTypeForVReg(Dst) == nullptr) {
42081ad6265SDimitry Andric           MIB.setInsertPt(*MI.getParent(), MI);
42181ad6265SDimitry Andric           Type *LLVMTy = IntegerType::get(MF.getFunction().getContext(), 1);
42281ad6265SDimitry Andric           SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(LLVMTy, MIB);
42381ad6265SDimitry Andric           MRI.setRegClass(Dst, &SPIRV::IDRegClass);
42481ad6265SDimitry Andric           GR->assignSPIRVTypeToVReg(SpirvTy, Dst, MIB.getMF());
42581ad6265SDimitry Andric         }
42681ad6265SDimitry Andric         Register CmpReg = MI.getOperand(2).getReg();
42781ad6265SDimitry Andric         MachineOperand &PredOp = MI.getOperand(1);
42881ad6265SDimitry Andric         const auto CC = static_cast<CmpInst::Predicate>(PredOp.getPredicate());
42981ad6265SDimitry Andric         assert(CC == CmpInst::ICMP_EQ && MRI.hasOneUse(Dst) &&
43081ad6265SDimitry Andric                MRI.hasOneDef(CmpReg));
43181ad6265SDimitry Andric         uint64_t Val = getIConstVal(MI.getOperand(3).getReg(), &MRI);
43281ad6265SDimitry Andric         MachineInstr *CBr = MRI.use_begin(Dst)->getParent();
43381ad6265SDimitry Andric         assert(CBr->getOpcode() == SPIRV::G_BRCOND &&
43481ad6265SDimitry Andric                CBr->getOperand(1).isMBB());
43581ad6265SDimitry Andric         SwitchRegToMBB[CmpReg][Val] = CBr->getOperand(1).getMBB();
43681ad6265SDimitry Andric         // The next MI is always BR to either the next case or the default.
43781ad6265SDimitry Andric         MachineInstr *NextMI = CBr->getNextNode();
43881ad6265SDimitry Andric         assert(NextMI->getOpcode() == SPIRV::G_BR &&
43981ad6265SDimitry Andric                NextMI->getOperand(0).isMBB());
44081ad6265SDimitry Andric         MachineBasicBlock *NextMBB = NextMI->getOperand(0).getMBB();
44181ad6265SDimitry Andric         assert(NextMBB != nullptr);
44281ad6265SDimitry Andric         // The default MBB is not started by ICMP with switch's cmp register.
44381ad6265SDimitry Andric         if (NextMBB->front().getOpcode() != SPIRV::G_ICMP ||
44481ad6265SDimitry Andric             (NextMBB->front().getOperand(2).isReg() &&
44581ad6265SDimitry Andric              NextMBB->front().getOperand(2).getReg() != CmpReg))
44681ad6265SDimitry Andric           DefaultMBBs[CmpReg] = NextMBB;
44781ad6265SDimitry Andric       }
44881ad6265SDimitry Andric     }
44981ad6265SDimitry Andric   }
45081ad6265SDimitry Andric   // Modify spv_switch's operands by collected values. For the example above,
45181ad6265SDimitry Andric   // the result will be like this:
45281ad6265SDimitry Andric   //   intrinsic(@llvm.spv.switch), %CmpReg, %bb.4, i32 0, %bb.2, i32 1, %bb.3
45381ad6265SDimitry Andric   // Note that ICMP+CBr+Br sequences are not removed, but ModuleAnalysis marks
45481ad6265SDimitry Andric   // them as skipped and AsmPrinter does not output them.
45581ad6265SDimitry Andric   for (MachineBasicBlock &MBB : MF) {
45681ad6265SDimitry Andric     for (MachineInstr &MI : MBB) {
45781ad6265SDimitry Andric       if (!isSpvIntrinsic(MI, Intrinsic::spv_switch))
45881ad6265SDimitry Andric         continue;
45981ad6265SDimitry Andric       assert(MI.getOperand(1).isReg());
46081ad6265SDimitry Andric       Register Reg = MI.getOperand(1).getReg();
46181ad6265SDimitry Andric       unsigned NumOp = MI.getNumExplicitOperands();
46281ad6265SDimitry Andric       SmallVector<const ConstantInt *, 3> Vals;
46381ad6265SDimitry Andric       SmallVector<MachineBasicBlock *, 3> MBBs;
46481ad6265SDimitry Andric       for (unsigned i = 2; i < NumOp; i++) {
46581ad6265SDimitry Andric         Register CReg = MI.getOperand(i).getReg();
46681ad6265SDimitry Andric         uint64_t Val = getIConstVal(CReg, &MRI);
46781ad6265SDimitry Andric         MachineInstr *ConstInstr = getDefInstrMaybeConstant(CReg, &MRI);
46881ad6265SDimitry Andric         Vals.push_back(ConstInstr->getOperand(1).getCImm());
46981ad6265SDimitry Andric         MBBs.push_back(SwitchRegToMBB[Reg][Val]);
47081ad6265SDimitry Andric       }
47181ad6265SDimitry Andric       for (unsigned i = MI.getNumExplicitOperands() - 1; i > 1; i--)
47281ad6265SDimitry Andric         MI.removeOperand(i);
47381ad6265SDimitry Andric       MI.addOperand(MachineOperand::CreateMBB(DefaultMBBs[Reg]));
47481ad6265SDimitry Andric       for (unsigned i = 0; i < Vals.size(); i++) {
47581ad6265SDimitry Andric         MI.addOperand(MachineOperand::CreateCImm(Vals[i]));
47681ad6265SDimitry Andric         MI.addOperand(MachineOperand::CreateMBB(MBBs[i]));
47781ad6265SDimitry Andric       }
47881ad6265SDimitry Andric     }
47981ad6265SDimitry Andric   }
48081ad6265SDimitry Andric }
48181ad6265SDimitry Andric 
48281ad6265SDimitry Andric bool SPIRVPreLegalizer::runOnMachineFunction(MachineFunction &MF) {
48381ad6265SDimitry Andric   // Initialize the type registry.
48481ad6265SDimitry Andric   const SPIRVSubtarget &ST = MF.getSubtarget<SPIRVSubtarget>();
48581ad6265SDimitry Andric   SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
48681ad6265SDimitry Andric   GR->setCurrentFunc(MF);
48781ad6265SDimitry Andric   MachineIRBuilder MIB(MF);
488*fcaf7f86SDimitry Andric   addConstantsToTrack(MF, GR);
48981ad6265SDimitry Andric   foldConstantsIntoIntrinsics(MF);
49081ad6265SDimitry Andric   insertBitcasts(MF, GR, MIB);
49181ad6265SDimitry Andric   generateAssignInstrs(MF, GR, MIB);
49281ad6265SDimitry Andric   processInstrsWithTypeFolding(MF, GR, MIB);
49381ad6265SDimitry Andric   processSwitches(MF, GR, MIB);
49481ad6265SDimitry Andric 
49581ad6265SDimitry Andric   return true;
49681ad6265SDimitry Andric }
49781ad6265SDimitry Andric 
49881ad6265SDimitry Andric INITIALIZE_PASS(SPIRVPreLegalizer, DEBUG_TYPE, "SPIRV pre legalizer", false,
49981ad6265SDimitry Andric                 false)
50081ad6265SDimitry Andric 
50181ad6265SDimitry Andric char SPIRVPreLegalizer::ID = 0;
50281ad6265SDimitry Andric 
50381ad6265SDimitry Andric FunctionPass *llvm::createSPIRVPreLegalizerPass() {
50481ad6265SDimitry Andric   return new SPIRVPreLegalizer();
50581ad6265SDimitry Andric }
506