1 //===-- SPIRVPreLegalizer.cpp - prepare IR for legalization -----*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // The pass prepares IR for legalization: it assigns SPIR-V types to registers 10 // and removes intrinsics which holded these types during IR translation. 11 // Also it processes constants and registers them in GR to avoid duplication. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "SPIRV.h" 16 #include "SPIRVGlobalRegistry.h" 17 #include "SPIRVSubtarget.h" 18 #include "SPIRVUtils.h" 19 #include "llvm/ADT/PostOrderIterator.h" 20 #include "llvm/Analysis/OptimizationRemarkEmitter.h" 21 #include "llvm/IR/Attributes.h" 22 #include "llvm/IR/Constants.h" 23 #include "llvm/IR/DebugInfoMetadata.h" 24 #include "llvm/IR/IntrinsicsSPIRV.h" 25 #include "llvm/Target/TargetIntrinsicInfo.h" 26 27 #define DEBUG_TYPE "spirv-prelegalizer" 28 29 using namespace llvm; 30 31 namespace { 32 class SPIRVPreLegalizer : public MachineFunctionPass { 33 public: 34 static char ID; 35 SPIRVPreLegalizer() : MachineFunctionPass(ID) { 36 initializeSPIRVPreLegalizerPass(*PassRegistry::getPassRegistry()); 37 } 38 bool runOnMachineFunction(MachineFunction &MF) override; 39 }; 40 } // namespace 41 42 static bool isSpvIntrinsic(MachineInstr &MI, Intrinsic::ID IntrinsicID) { 43 if (MI.getOpcode() == TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS && 44 MI.getIntrinsicID() == IntrinsicID) 45 return true; 46 return false; 47 } 48 49 static void foldConstantsIntoIntrinsics(MachineFunction &MF) { 50 SmallVector<MachineInstr *, 10> ToErase; 51 MachineRegisterInfo &MRI = MF.getRegInfo(); 52 const unsigned AssignNameOperandShift = 2; 53 for (MachineBasicBlock &MBB : MF) { 54 for (MachineInstr &MI : MBB) { 55 if (!isSpvIntrinsic(MI, Intrinsic::spv_assign_name)) 56 continue; 57 unsigned NumOp = MI.getNumExplicitDefs() + AssignNameOperandShift; 58 while (MI.getOperand(NumOp).isReg()) { 59 MachineOperand &MOp = MI.getOperand(NumOp); 60 MachineInstr *ConstMI = MRI.getVRegDef(MOp.getReg()); 61 assert(ConstMI->getOpcode() == TargetOpcode::G_CONSTANT); 62 MI.removeOperand(NumOp); 63 MI.addOperand(MachineOperand::CreateImm( 64 ConstMI->getOperand(1).getCImm()->getZExtValue())); 65 if (MRI.use_empty(ConstMI->getOperand(0).getReg())) 66 ToErase.push_back(ConstMI); 67 } 68 } 69 } 70 for (MachineInstr *MI : ToErase) 71 MI->eraseFromParent(); 72 } 73 74 static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR, 75 MachineIRBuilder MIB) { 76 SmallVector<MachineInstr *, 10> ToErase; 77 for (MachineBasicBlock &MBB : MF) { 78 for (MachineInstr &MI : MBB) { 79 if (!isSpvIntrinsic(MI, Intrinsic::spv_bitcast)) 80 continue; 81 assert(MI.getOperand(2).isReg()); 82 MIB.setInsertPt(*MI.getParent(), MI); 83 MIB.buildBitcast(MI.getOperand(0).getReg(), MI.getOperand(2).getReg()); 84 ToErase.push_back(&MI); 85 } 86 } 87 for (MachineInstr *MI : ToErase) 88 MI->eraseFromParent(); 89 } 90 91 // Translating GV, IRTranslator sometimes generates following IR: 92 // %1 = G_GLOBAL_VALUE 93 // %2 = COPY %1 94 // %3 = G_ADDRSPACE_CAST %2 95 // New registers have no SPIRVType and no register class info. 96 // 97 // Set SPIRVType for GV, propagate it from GV to other instructions, 98 // also set register classes. 99 static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR, 100 MachineRegisterInfo &MRI, 101 MachineIRBuilder &MIB) { 102 SPIRVType *SpirvTy = nullptr; 103 assert(MI && "Machine instr is expected"); 104 if (MI->getOperand(0).isReg()) { 105 Register Reg = MI->getOperand(0).getReg(); 106 SpirvTy = GR->getSPIRVTypeForVReg(Reg); 107 if (!SpirvTy) { 108 switch (MI->getOpcode()) { 109 case TargetOpcode::G_CONSTANT: { 110 MIB.setInsertPt(*MI->getParent(), MI); 111 Type *Ty = MI->getOperand(1).getCImm()->getType(); 112 SpirvTy = GR->getOrCreateSPIRVType(Ty, MIB); 113 break; 114 } 115 case TargetOpcode::G_GLOBAL_VALUE: { 116 MIB.setInsertPt(*MI->getParent(), MI); 117 Type *Ty = MI->getOperand(1).getGlobal()->getType(); 118 SpirvTy = GR->getOrCreateSPIRVType(Ty, MIB); 119 break; 120 } 121 case TargetOpcode::G_TRUNC: 122 case TargetOpcode::G_ADDRSPACE_CAST: 123 case TargetOpcode::COPY: { 124 MachineOperand &Op = MI->getOperand(1); 125 MachineInstr *Def = Op.isReg() ? MRI.getVRegDef(Op.getReg()) : nullptr; 126 if (Def) 127 SpirvTy = propagateSPIRVType(Def, GR, MRI, MIB); 128 break; 129 } 130 default: 131 break; 132 } 133 if (SpirvTy) 134 GR->assignSPIRVTypeToVReg(SpirvTy, Reg, MIB.getMF()); 135 if (!MRI.getRegClassOrNull(Reg)) 136 MRI.setRegClass(Reg, &SPIRV::IDRegClass); 137 } 138 } 139 return SpirvTy; 140 } 141 142 // Insert ASSIGN_TYPE instuction between Reg and its definition, set NewReg as 143 // a dst of the definition, assign SPIRVType to both registers. If SpirvTy is 144 // provided, use it as SPIRVType in ASSIGN_TYPE, otherwise create it from Ty. 145 // TODO: maybe move to SPIRVUtils. 146 static Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy, 147 SPIRVGlobalRegistry *GR, 148 MachineIRBuilder &MIB, 149 MachineRegisterInfo &MRI) { 150 MachineInstr *Def = MRI.getVRegDef(Reg); 151 assert((Ty || SpirvTy) && "Either LLVM or SPIRV type is expected."); 152 MIB.setInsertPt(*Def->getParent(), 153 (Def->getNextNode() ? Def->getNextNode()->getIterator() 154 : Def->getParent()->end())); 155 Register NewReg = MRI.createGenericVirtualRegister(MRI.getType(Reg)); 156 if (auto *RC = MRI.getRegClassOrNull(Reg)) 157 MRI.setRegClass(NewReg, RC); 158 SpirvTy = SpirvTy ? SpirvTy : GR->getOrCreateSPIRVType(Ty, MIB); 159 GR->assignSPIRVTypeToVReg(SpirvTy, Reg, MIB.getMF()); 160 // This is to make it convenient for Legalizer to get the SPIRVType 161 // when processing the actual MI (i.e. not pseudo one). 162 GR->assignSPIRVTypeToVReg(SpirvTy, NewReg, MIB.getMF()); 163 MIB.buildInstr(SPIRV::ASSIGN_TYPE) 164 .addDef(Reg) 165 .addUse(NewReg) 166 .addUse(GR->getSPIRVTypeID(SpirvTy)); 167 Def->getOperand(0).setReg(NewReg); 168 MRI.setRegClass(Reg, &SPIRV::ANYIDRegClass); 169 return NewReg; 170 } 171 172 static void generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR, 173 MachineIRBuilder MIB) { 174 MachineRegisterInfo &MRI = MF.getRegInfo(); 175 SmallVector<MachineInstr *, 10> ToErase; 176 177 for (MachineBasicBlock *MBB : post_order(&MF)) { 178 if (MBB->empty()) 179 continue; 180 181 bool ReachedBegin = false; 182 for (auto MII = std::prev(MBB->end()), Begin = MBB->begin(); 183 !ReachedBegin;) { 184 MachineInstr &MI = *MII; 185 186 if (isSpvIntrinsic(MI, Intrinsic::spv_assign_type)) { 187 Register Reg = MI.getOperand(1).getReg(); 188 Type *Ty = getMDOperandAsType(MI.getOperand(2).getMetadata(), 0); 189 MachineInstr *Def = MRI.getVRegDef(Reg); 190 assert(Def && "Expecting an instruction that defines the register"); 191 // G_GLOBAL_VALUE already has type info. 192 if (Def->getOpcode() != TargetOpcode::G_GLOBAL_VALUE) 193 insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MF.getRegInfo()); 194 ToErase.push_back(&MI); 195 } else if (MI.getOpcode() == TargetOpcode::G_CONSTANT || 196 MI.getOpcode() == TargetOpcode::G_FCONSTANT || 197 MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR) { 198 // %rc = G_CONSTANT ty Val 199 // ===> 200 // %cty = OpType* ty 201 // %rctmp = G_CONSTANT ty Val 202 // %rc = ASSIGN_TYPE %rctmp, %cty 203 Register Reg = MI.getOperand(0).getReg(); 204 if (MRI.hasOneUse(Reg)) { 205 MachineInstr &UseMI = *MRI.use_instr_begin(Reg); 206 if (isSpvIntrinsic(UseMI, Intrinsic::spv_assign_type) || 207 isSpvIntrinsic(UseMI, Intrinsic::spv_assign_name)) 208 continue; 209 } 210 Type *Ty = nullptr; 211 if (MI.getOpcode() == TargetOpcode::G_CONSTANT) 212 Ty = MI.getOperand(1).getCImm()->getType(); 213 else if (MI.getOpcode() == TargetOpcode::G_FCONSTANT) 214 Ty = MI.getOperand(1).getFPImm()->getType(); 215 else { 216 assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR); 217 Type *ElemTy = nullptr; 218 MachineInstr *ElemMI = MRI.getVRegDef(MI.getOperand(1).getReg()); 219 assert(ElemMI); 220 221 if (ElemMI->getOpcode() == TargetOpcode::G_CONSTANT) 222 ElemTy = ElemMI->getOperand(1).getCImm()->getType(); 223 else if (ElemMI->getOpcode() == TargetOpcode::G_FCONSTANT) 224 ElemTy = ElemMI->getOperand(1).getFPImm()->getType(); 225 else 226 llvm_unreachable("Unexpected opcode"); 227 unsigned NumElts = 228 MI.getNumExplicitOperands() - MI.getNumExplicitDefs(); 229 Ty = VectorType::get(ElemTy, NumElts, false); 230 } 231 insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MRI); 232 } else if (MI.getOpcode() == TargetOpcode::G_TRUNC || 233 MI.getOpcode() == TargetOpcode::G_GLOBAL_VALUE || 234 MI.getOpcode() == TargetOpcode::COPY || 235 MI.getOpcode() == TargetOpcode::G_ADDRSPACE_CAST) { 236 propagateSPIRVType(&MI, GR, MRI, MIB); 237 } 238 239 if (MII == Begin) 240 ReachedBegin = true; 241 else 242 --MII; 243 } 244 } 245 for (MachineInstr *MI : ToErase) 246 MI->eraseFromParent(); 247 } 248 249 static std::pair<Register, unsigned> 250 createNewIdReg(Register ValReg, unsigned Opcode, MachineRegisterInfo &MRI, 251 const SPIRVGlobalRegistry &GR) { 252 LLT NewT = LLT::scalar(32); 253 SPIRVType *SpvType = GR.getSPIRVTypeForVReg(ValReg); 254 assert(SpvType && "VReg is expected to have SPIRV type"); 255 bool IsFloat = SpvType->getOpcode() == SPIRV::OpTypeFloat; 256 bool IsVectorFloat = 257 SpvType->getOpcode() == SPIRV::OpTypeVector && 258 GR.getSPIRVTypeForVReg(SpvType->getOperand(1).getReg())->getOpcode() == 259 SPIRV::OpTypeFloat; 260 IsFloat |= IsVectorFloat; 261 auto GetIdOp = IsFloat ? SPIRV::GET_fID : SPIRV::GET_ID; 262 auto DstClass = IsFloat ? &SPIRV::fIDRegClass : &SPIRV::IDRegClass; 263 if (MRI.getType(ValReg).isPointer()) { 264 NewT = LLT::pointer(0, 32); 265 GetIdOp = SPIRV::GET_pID; 266 DstClass = &SPIRV::pIDRegClass; 267 } else if (MRI.getType(ValReg).isVector()) { 268 NewT = LLT::fixed_vector(2, NewT); 269 GetIdOp = IsFloat ? SPIRV::GET_vfID : SPIRV::GET_vID; 270 DstClass = IsFloat ? &SPIRV::vfIDRegClass : &SPIRV::vIDRegClass; 271 } 272 Register IdReg = MRI.createGenericVirtualRegister(NewT); 273 MRI.setRegClass(IdReg, DstClass); 274 return {IdReg, GetIdOp}; 275 } 276 277 static void processInstr(MachineInstr &MI, MachineIRBuilder &MIB, 278 MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR) { 279 unsigned Opc = MI.getOpcode(); 280 assert(MI.getNumDefs() > 0 && MRI.hasOneUse(MI.getOperand(0).getReg())); 281 MachineInstr &AssignTypeInst = 282 *(MRI.use_instr_begin(MI.getOperand(0).getReg())); 283 auto NewReg = createNewIdReg(MI.getOperand(0).getReg(), Opc, MRI, *GR).first; 284 AssignTypeInst.getOperand(1).setReg(NewReg); 285 MI.getOperand(0).setReg(NewReg); 286 MIB.setInsertPt(*MI.getParent(), 287 (MI.getNextNode() ? MI.getNextNode()->getIterator() 288 : MI.getParent()->end())); 289 for (auto &Op : MI.operands()) { 290 if (!Op.isReg() || Op.isDef()) 291 continue; 292 auto IdOpInfo = createNewIdReg(Op.getReg(), Opc, MRI, *GR); 293 MIB.buildInstr(IdOpInfo.second).addDef(IdOpInfo.first).addUse(Op.getReg()); 294 Op.setReg(IdOpInfo.first); 295 } 296 } 297 298 // Defined in SPIRVLegalizerInfo.cpp. 299 extern bool isTypeFoldingSupported(unsigned Opcode); 300 301 static void processInstrsWithTypeFolding(MachineFunction &MF, 302 SPIRVGlobalRegistry *GR, 303 MachineIRBuilder MIB) { 304 MachineRegisterInfo &MRI = MF.getRegInfo(); 305 for (MachineBasicBlock &MBB : MF) { 306 for (MachineInstr &MI : MBB) { 307 if (isTypeFoldingSupported(MI.getOpcode())) 308 processInstr(MI, MIB, MRI, GR); 309 } 310 } 311 } 312 313 static void processSwitches(MachineFunction &MF, SPIRVGlobalRegistry *GR, 314 MachineIRBuilder MIB) { 315 DenseMap<Register, SmallDenseMap<uint64_t, MachineBasicBlock *>> 316 SwitchRegToMBB; 317 DenseMap<Register, MachineBasicBlock *> DefaultMBBs; 318 DenseSet<Register> SwitchRegs; 319 MachineRegisterInfo &MRI = MF.getRegInfo(); 320 // Before IRTranslator pass, spv_switch calls are inserted before each 321 // switch instruction. IRTranslator lowers switches to ICMP+CBr+Br triples. 322 // A switch with two cases may be translated to this MIR sequesnce: 323 // intrinsic(@llvm.spv.switch), %CmpReg, %Const0, %Const1 324 // %Dst0 = G_ICMP intpred(eq), %CmpReg, %Const0 325 // G_BRCOND %Dst0, %bb.2 326 // G_BR %bb.5 327 // bb.5.entry: 328 // %Dst1 = G_ICMP intpred(eq), %CmpReg, %Const1 329 // G_BRCOND %Dst1, %bb.3 330 // G_BR %bb.4 331 // bb.2.sw.bb: 332 // ... 333 // bb.3.sw.bb1: 334 // ... 335 // bb.4.sw.epilog: 336 // ... 337 // Walk MIs and collect information about destination MBBs to update 338 // spv_switch call. We assume that all spv_switch precede corresponding ICMPs. 339 for (MachineBasicBlock &MBB : MF) { 340 for (MachineInstr &MI : MBB) { 341 if (isSpvIntrinsic(MI, Intrinsic::spv_switch)) { 342 assert(MI.getOperand(1).isReg()); 343 Register Reg = MI.getOperand(1).getReg(); 344 SwitchRegs.insert(Reg); 345 // Set the first successor as default MBB to support empty switches. 346 DefaultMBBs[Reg] = *MBB.succ_begin(); 347 } 348 // Process only ICMPs that relate to spv_switches. 349 if (MI.getOpcode() == TargetOpcode::G_ICMP && MI.getOperand(2).isReg() && 350 SwitchRegs.contains(MI.getOperand(2).getReg())) { 351 assert(MI.getOperand(0).isReg() && MI.getOperand(1).isPredicate() && 352 MI.getOperand(3).isReg()); 353 Register Dst = MI.getOperand(0).getReg(); 354 // Set type info for destination register of switch's ICMP instruction. 355 if (GR->getSPIRVTypeForVReg(Dst) == nullptr) { 356 MIB.setInsertPt(*MI.getParent(), MI); 357 Type *LLVMTy = IntegerType::get(MF.getFunction().getContext(), 1); 358 SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(LLVMTy, MIB); 359 MRI.setRegClass(Dst, &SPIRV::IDRegClass); 360 GR->assignSPIRVTypeToVReg(SpirvTy, Dst, MIB.getMF()); 361 } 362 Register CmpReg = MI.getOperand(2).getReg(); 363 MachineOperand &PredOp = MI.getOperand(1); 364 const auto CC = static_cast<CmpInst::Predicate>(PredOp.getPredicate()); 365 assert(CC == CmpInst::ICMP_EQ && MRI.hasOneUse(Dst) && 366 MRI.hasOneDef(CmpReg)); 367 uint64_t Val = getIConstVal(MI.getOperand(3).getReg(), &MRI); 368 MachineInstr *CBr = MRI.use_begin(Dst)->getParent(); 369 assert(CBr->getOpcode() == SPIRV::G_BRCOND && 370 CBr->getOperand(1).isMBB()); 371 SwitchRegToMBB[CmpReg][Val] = CBr->getOperand(1).getMBB(); 372 // The next MI is always BR to either the next case or the default. 373 MachineInstr *NextMI = CBr->getNextNode(); 374 assert(NextMI->getOpcode() == SPIRV::G_BR && 375 NextMI->getOperand(0).isMBB()); 376 MachineBasicBlock *NextMBB = NextMI->getOperand(0).getMBB(); 377 assert(NextMBB != nullptr); 378 // The default MBB is not started by ICMP with switch's cmp register. 379 if (NextMBB->front().getOpcode() != SPIRV::G_ICMP || 380 (NextMBB->front().getOperand(2).isReg() && 381 NextMBB->front().getOperand(2).getReg() != CmpReg)) 382 DefaultMBBs[CmpReg] = NextMBB; 383 } 384 } 385 } 386 // Modify spv_switch's operands by collected values. For the example above, 387 // the result will be like this: 388 // intrinsic(@llvm.spv.switch), %CmpReg, %bb.4, i32 0, %bb.2, i32 1, %bb.3 389 // Note that ICMP+CBr+Br sequences are not removed, but ModuleAnalysis marks 390 // them as skipped and AsmPrinter does not output them. 391 for (MachineBasicBlock &MBB : MF) { 392 for (MachineInstr &MI : MBB) { 393 if (!isSpvIntrinsic(MI, Intrinsic::spv_switch)) 394 continue; 395 assert(MI.getOperand(1).isReg()); 396 Register Reg = MI.getOperand(1).getReg(); 397 unsigned NumOp = MI.getNumExplicitOperands(); 398 SmallVector<const ConstantInt *, 3> Vals; 399 SmallVector<MachineBasicBlock *, 3> MBBs; 400 for (unsigned i = 2; i < NumOp; i++) { 401 Register CReg = MI.getOperand(i).getReg(); 402 uint64_t Val = getIConstVal(CReg, &MRI); 403 MachineInstr *ConstInstr = getDefInstrMaybeConstant(CReg, &MRI); 404 Vals.push_back(ConstInstr->getOperand(1).getCImm()); 405 MBBs.push_back(SwitchRegToMBB[Reg][Val]); 406 } 407 for (unsigned i = MI.getNumExplicitOperands() - 1; i > 1; i--) 408 MI.removeOperand(i); 409 MI.addOperand(MachineOperand::CreateMBB(DefaultMBBs[Reg])); 410 for (unsigned i = 0; i < Vals.size(); i++) { 411 MI.addOperand(MachineOperand::CreateCImm(Vals[i])); 412 MI.addOperand(MachineOperand::CreateMBB(MBBs[i])); 413 } 414 } 415 } 416 } 417 418 bool SPIRVPreLegalizer::runOnMachineFunction(MachineFunction &MF) { 419 // Initialize the type registry. 420 const SPIRVSubtarget &ST = MF.getSubtarget<SPIRVSubtarget>(); 421 SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry(); 422 GR->setCurrentFunc(MF); 423 MachineIRBuilder MIB(MF); 424 foldConstantsIntoIntrinsics(MF); 425 insertBitcasts(MF, GR, MIB); 426 generateAssignInstrs(MF, GR, MIB); 427 processInstrsWithTypeFolding(MF, GR, MIB); 428 processSwitches(MF, GR, MIB); 429 430 return true; 431 } 432 433 INITIALIZE_PASS(SPIRVPreLegalizer, DEBUG_TYPE, "SPIRV pre legalizer", false, 434 false) 435 436 char SPIRVPreLegalizer::ID = 0; 437 438 FunctionPass *llvm::createSPIRVPreLegalizerPass() { 439 return new SPIRVPreLegalizer(); 440 } 441