//===- SPIRVModuleAnalysis.cpp - analysis of global instrs & regs - C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// The analysis collects instructions that should be output at the module level
// and performs the global register numbering.
//
// The results of this analysis are used in AsmPrinter to rename registers
// globally and to output required instructions at the module level.
//
//===----------------------------------------------------------------------===//

#include "SPIRVModuleAnalysis.h"
#include "SPIRV.h"
#include "SPIRVGlobalRegistry.h"
#include "SPIRVSubtarget.h"
#include "SPIRVTargetMachine.h"
#include "SPIRVUtils.h"
#include "TargetInfo/SPIRVTargetInfo.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
#include "llvm/CodeGen/TargetPassConfig.h"

using namespace llvm;

#define DEBUG_TYPE "spirv-module-analysis"

// Debug flag (-spv-dump-deps, off by default). When set, processDefInstrs
// hands MMI to GR->buildDepsGraph instead of nullptr so that the dependency
// information can be dumped alongside the MIR.
static cl::opt<bool>
    SPVDumpDeps("spv-dump-deps",
                cl::desc("Dump MIR with SPIR-V dependencies info"),
                cl::Optional, cl::init(false));

// Pass identification; LLVM identifies the pass by the address of this
// variable, not by its value.
char llvm::SPIRVModuleAnalysis::ID = 0;

// Forward declaration of the registration hook that INITIALIZE_PASS below
// defines.
namespace llvm {
void initializeSPIRVModuleAnalysisPass(PassRegistry &);
} // namespace llvm

// Register the pass under the DEBUG_TYPE name. The trailing `true, true`
// flags mark it as CFG-only and as an analysis pass.
INITIALIZE_PASS(SPIRVModuleAnalysis, DEBUG_TYPE, "SPIRV module analysis", true,
                true)

// Retrieve an unsigned from an MDNode with a list of them as operands.
46 static unsigned getMetadataUInt(MDNode *MdNode, unsigned OpIndex, 47 unsigned DefaultVal = 0) { 48 if (MdNode && OpIndex < MdNode->getNumOperands()) { 49 const auto &Op = MdNode->getOperand(OpIndex); 50 return mdconst::extract<ConstantInt>(Op)->getZExtValue(); 51 } 52 return DefaultVal; 53 } 54 55 void SPIRVModuleAnalysis::setBaseInfo(const Module &M) { 56 MAI.MaxID = 0; 57 for (int i = 0; i < SPIRV::NUM_MODULE_SECTIONS; i++) 58 MAI.MS[i].clear(); 59 MAI.RegisterAliasTable.clear(); 60 MAI.InstrsToDelete.clear(); 61 MAI.FuncNameMap.clear(); 62 MAI.GlobalVarList.clear(); 63 MAI.ExtInstSetMap.clear(); 64 65 // TODO: determine memory model and source language from the configuratoin. 66 if (auto MemModel = M.getNamedMetadata("spirv.MemoryModel")) { 67 auto MemMD = MemModel->getOperand(0); 68 MAI.Addr = static_cast<SPIRV::AddressingModel::AddressingModel>( 69 getMetadataUInt(MemMD, 0)); 70 MAI.Mem = 71 static_cast<SPIRV::MemoryModel::MemoryModel>(getMetadataUInt(MemMD, 1)); 72 } else { 73 MAI.Mem = SPIRV::MemoryModel::OpenCL; 74 unsigned PtrSize = ST->getPointerSize(); 75 MAI.Addr = PtrSize == 32 ? SPIRV::AddressingModel::Physical32 76 : PtrSize == 64 ? SPIRV::AddressingModel::Physical64 77 : SPIRV::AddressingModel::Logical; 78 } 79 // Get the OpenCL version number from metadata. 80 // TODO: support other source languages. 81 if (auto VerNode = M.getNamedMetadata("opencl.ocl.version")) { 82 MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_C; 83 // Construct version literal in accordance with SPIRV-LLVM-Translator. 84 // TODO: support multiple OCL version metadata. 
85 assert(VerNode->getNumOperands() > 0 && "Invalid SPIR"); 86 auto VersionMD = VerNode->getOperand(0); 87 unsigned MajorNum = getMetadataUInt(VersionMD, 0, 2); 88 unsigned MinorNum = getMetadataUInt(VersionMD, 1); 89 unsigned RevNum = getMetadataUInt(VersionMD, 2); 90 MAI.SrcLangVersion = (MajorNum * 100 + MinorNum) * 1000 + RevNum; 91 } else { 92 MAI.SrcLang = SPIRV::SourceLanguage::Unknown; 93 MAI.SrcLangVersion = 0; 94 } 95 96 if (auto ExtNode = M.getNamedMetadata("opencl.used.extensions")) { 97 for (unsigned I = 0, E = ExtNode->getNumOperands(); I != E; ++I) { 98 MDNode *MD = ExtNode->getOperand(I); 99 if (!MD || MD->getNumOperands() == 0) 100 continue; 101 for (unsigned J = 0, N = MD->getNumOperands(); J != N; ++J) 102 MAI.SrcExt.insert(cast<MDString>(MD->getOperand(J))->getString()); 103 } 104 } 105 106 // TODO: check if it's required by default. 107 MAI.ExtInstSetMap[static_cast<unsigned>(SPIRV::InstructionSet::OpenCL_std)] = 108 Register::index2VirtReg(MAI.getNextID()); 109 } 110 111 // Collect MI which defines the register in the given machine function. 
112 static void collectDefInstr(Register Reg, const MachineFunction *MF, 113 SPIRV::ModuleAnalysisInfo *MAI, 114 SPIRV::ModuleSectionType MSType, 115 bool DoInsert = true) { 116 assert(MAI->hasRegisterAlias(MF, Reg) && "Cannot find register alias"); 117 MachineInstr *MI = MF->getRegInfo().getUniqueVRegDef(Reg); 118 assert(MI && "There should be an instruction that defines the register"); 119 MAI->setSkipEmission(MI); 120 if (DoInsert) 121 MAI->MS[MSType].push_back(MI); 122 } 123 124 void SPIRVModuleAnalysis::collectGlobalEntities( 125 const std::vector<SPIRV::DTSortableEntry *> &DepsGraph, 126 SPIRV::ModuleSectionType MSType, 127 std::function<bool(const SPIRV::DTSortableEntry *)> Pred, 128 bool UsePreOrder = false) { 129 DenseSet<const SPIRV::DTSortableEntry *> Visited; 130 for (const auto *E : DepsGraph) { 131 std::function<void(const SPIRV::DTSortableEntry *)> RecHoistUtil; 132 // NOTE: here we prefer recursive approach over iterative because 133 // we don't expect depchains long enough to cause SO. 134 RecHoistUtil = [MSType, UsePreOrder, &Visited, &Pred, 135 &RecHoistUtil](const SPIRV::DTSortableEntry *E) { 136 if (Visited.count(E) || !Pred(E)) 137 return; 138 Visited.insert(E); 139 140 // Traversing deps graph in post-order allows us to get rid of 141 // register aliases preprocessing. 142 // But pre-order is required for correct processing of function 143 // declaration and arguments processing. 
144 if (!UsePreOrder) 145 for (auto *S : E->getDeps()) 146 RecHoistUtil(S); 147 148 Register GlobalReg = Register::index2VirtReg(MAI.getNextID()); 149 bool IsFirst = true; 150 for (auto &U : *E) { 151 const MachineFunction *MF = U.first; 152 Register Reg = U.second; 153 MAI.setRegisterAlias(MF, Reg, GlobalReg); 154 if (!MF->getRegInfo().getUniqueVRegDef(Reg)) 155 continue; 156 collectDefInstr(Reg, MF, &MAI, MSType, IsFirst); 157 IsFirst = false; 158 if (E->getIsGV()) 159 MAI.GlobalVarList.push_back(MF->getRegInfo().getUniqueVRegDef(Reg)); 160 } 161 162 if (UsePreOrder) 163 for (auto *S : E->getDeps()) 164 RecHoistUtil(S); 165 }; 166 RecHoistUtil(E); 167 } 168 } 169 170 // The function initializes global register alias table for types, consts, 171 // global vars and func decls and collects these instruction for output 172 // at module level. Also it collects explicit OpExtension/OpCapability 173 // instructions. 174 void SPIRVModuleAnalysis::processDefInstrs(const Module &M) { 175 std::vector<SPIRV::DTSortableEntry *> DepsGraph; 176 177 GR->buildDepsGraph(DepsGraph, SPVDumpDeps ? MMI : nullptr); 178 179 collectGlobalEntities( 180 DepsGraph, SPIRV::MB_TypeConstVars, 181 [](const SPIRV::DTSortableEntry *E) { return !E->getIsFunc(); }); 182 183 collectGlobalEntities( 184 DepsGraph, SPIRV::MB_ExtFuncDecls, 185 [](const SPIRV::DTSortableEntry *E) { return E->getIsFunc(); }, true); 186 } 187 188 // True if there is an instruction in the MS list with all the same operands as 189 // the given instruction has (after the given starting index). 190 // TODO: maybe it needs to check Opcodes too. 
191 static bool findSameInstrInMS(const MachineInstr &A, 192 SPIRV::ModuleSectionType MSType, 193 SPIRV::ModuleAnalysisInfo &MAI, 194 unsigned StartOpIndex = 0) { 195 for (const auto *B : MAI.MS[MSType]) { 196 const unsigned NumAOps = A.getNumOperands(); 197 if (NumAOps != B->getNumOperands() || A.getNumDefs() != B->getNumDefs()) 198 continue; 199 bool AllOpsMatch = true; 200 for (unsigned i = StartOpIndex; i < NumAOps && AllOpsMatch; ++i) { 201 if (A.getOperand(i).isReg() && B->getOperand(i).isReg()) { 202 Register RegA = A.getOperand(i).getReg(); 203 Register RegB = B->getOperand(i).getReg(); 204 AllOpsMatch = MAI.getRegisterAlias(A.getMF(), RegA) == 205 MAI.getRegisterAlias(B->getMF(), RegB); 206 } else { 207 AllOpsMatch = A.getOperand(i).isIdenticalTo(B->getOperand(i)); 208 } 209 } 210 if (AllOpsMatch) 211 return true; 212 } 213 return false; 214 } 215 216 // Look for IDs declared with Import linkage, and map the imported name string 217 // to the register defining that variable (which will usually be the result of 218 // an OpFunction). This lets us call externally imported functions using 219 // the correct ID registers. 220 void SPIRVModuleAnalysis::collectFuncNames(MachineInstr &MI, 221 const Function &F) { 222 if (MI.getOpcode() == SPIRV::OpDecorate) { 223 // If it's got Import linkage. 224 auto Dec = MI.getOperand(1).getImm(); 225 if (Dec == static_cast<unsigned>(SPIRV::Decoration::LinkageAttributes)) { 226 auto Lnk = MI.getOperand(MI.getNumOperands() - 1).getImm(); 227 if (Lnk == static_cast<unsigned>(SPIRV::LinkageType::Import)) { 228 // Map imported function name to function ID register. 229 std::string Name = getStringImm(MI, 2); 230 Register Target = MI.getOperand(0).getReg(); 231 // TODO: check defs from different MFs. 232 MAI.FuncNameMap[Name] = MAI.getRegisterAlias(MI.getMF(), Target); 233 } 234 } 235 } else if (MI.getOpcode() == SPIRV::OpFunction) { 236 // Record all internal OpFunction declarations. 
237 Register Reg = MI.defs().begin()->getReg(); 238 Register GlobalReg = MAI.getRegisterAlias(MI.getMF(), Reg); 239 assert(GlobalReg.isValid()); 240 // TODO: check that it does not conflict with existing entries. 241 MAI.FuncNameMap[F.getGlobalIdentifier()] = GlobalReg; 242 } 243 } 244 245 // Collect the given instruction in the specified MS. We assume global register 246 // numbering has already occurred by this point. We can directly compare reg 247 // arguments when detecting duplicates. 248 static void collectOtherInstr(MachineInstr &MI, SPIRV::ModuleAnalysisInfo &MAI, 249 SPIRV::ModuleSectionType MSType, 250 bool Append = true) { 251 MAI.setSkipEmission(&MI); 252 if (findSameInstrInMS(MI, MSType, MAI)) 253 return; // Found a duplicate, so don't add it. 254 // No duplicates, so add it. 255 if (Append) 256 MAI.MS[MSType].push_back(&MI); 257 else 258 MAI.MS[MSType].insert(MAI.MS[MSType].begin(), &MI); 259 } 260 261 // Some global instructions make reference to function-local ID regs, so cannot 262 // be correctly collected until these registers are globally numbered. 263 void SPIRVModuleAnalysis::processOtherInstrs(const Module &M) { 264 for (auto F = M.begin(), E = M.end(); F != E; ++F) { 265 if ((*F).isDeclaration()) 266 continue; 267 MachineFunction *MF = MMI->getMachineFunction(*F); 268 assert(MF); 269 for (MachineBasicBlock &MBB : *MF) 270 for (MachineInstr &MI : MBB) { 271 if (MAI.getSkipEmission(&MI)) 272 continue; 273 const unsigned OpCode = MI.getOpcode(); 274 if (OpCode == SPIRV::OpName || OpCode == SPIRV::OpMemberName) { 275 collectOtherInstr(MI, MAI, SPIRV::MB_DebugNames); 276 } else if (OpCode == SPIRV::OpEntryPoint) { 277 collectOtherInstr(MI, MAI, SPIRV::MB_EntryPoints); 278 } else if (TII->isDecorationInstr(MI)) { 279 collectOtherInstr(MI, MAI, SPIRV::MB_Annotations); 280 collectFuncNames(MI, *F); 281 } else if (TII->isConstantInstr(MI)) { 282 // Now OpSpecConstant*s are not in DT, 283 // but they need to be collected anyway. 
284 collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars); 285 } else if (OpCode == SPIRV::OpFunction) { 286 collectFuncNames(MI, *F); 287 } else if (OpCode == SPIRV::OpTypeForwardPointer) { 288 collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, false); 289 } 290 } 291 } 292 } 293 294 // Number registers in all functions globally from 0 onwards and store 295 // the result in global register alias table. Some registers are already 296 // numbered in collectGlobalEntities. 297 void SPIRVModuleAnalysis::numberRegistersGlobally(const Module &M) { 298 for (auto F = M.begin(), E = M.end(); F != E; ++F) { 299 if ((*F).isDeclaration()) 300 continue; 301 MachineFunction *MF = MMI->getMachineFunction(*F); 302 assert(MF); 303 for (MachineBasicBlock &MBB : *MF) { 304 for (MachineInstr &MI : MBB) { 305 for (MachineOperand &Op : MI.operands()) { 306 if (!Op.isReg()) 307 continue; 308 Register Reg = Op.getReg(); 309 if (MAI.hasRegisterAlias(MF, Reg)) 310 continue; 311 Register NewReg = Register::index2VirtReg(MAI.getNextID()); 312 MAI.setRegisterAlias(MF, Reg, NewReg); 313 } 314 if (MI.getOpcode() != SPIRV::OpExtInst) 315 continue; 316 auto Set = MI.getOperand(2).getImm(); 317 if (MAI.ExtInstSetMap.find(Set) == MAI.ExtInstSetMap.end()) 318 MAI.ExtInstSetMap[Set] = Register::index2VirtReg(MAI.getNextID()); 319 } 320 } 321 } 322 } 323 324 // Find OpIEqual and OpBranchConditional instructions originating from 325 // OpSwitches, mark them skipped for emission. Also mark MBB skipped if it 326 // contains only these instructions. 
327 static void processSwitches(const Module &M, SPIRV::ModuleAnalysisInfo &MAI, 328 MachineModuleInfo *MMI) { 329 DenseSet<Register> SwitchRegs; 330 for (auto F = M.begin(), E = M.end(); F != E; ++F) { 331 MachineFunction *MF = MMI->getMachineFunction(*F); 332 if (!MF) 333 continue; 334 for (MachineBasicBlock &MBB : *MF) 335 for (MachineInstr &MI : MBB) { 336 if (MAI.getSkipEmission(&MI)) 337 continue; 338 if (MI.getOpcode() == SPIRV::OpSwitch) { 339 assert(MI.getOperand(0).isReg()); 340 SwitchRegs.insert(MI.getOperand(0).getReg()); 341 } 342 if (MI.getOpcode() != SPIRV::OpIEqual || !MI.getOperand(2).isReg() || 343 !SwitchRegs.contains(MI.getOperand(2).getReg())) 344 continue; 345 Register CmpReg = MI.getOperand(0).getReg(); 346 MachineInstr *CBr = MI.getNextNode(); 347 assert(CBr && CBr->getOpcode() == SPIRV::OpBranchConditional && 348 CBr->getOperand(0).isReg() && 349 CBr->getOperand(0).getReg() == CmpReg); 350 MAI.setSkipEmission(&MI); 351 MAI.setSkipEmission(CBr); 352 if (&MBB.front() == &MI && &MBB.back() == CBr) 353 MAI.MBBsToSkip.insert(&MBB); 354 } 355 } 356 } 357 358 struct SPIRV::ModuleAnalysisInfo SPIRVModuleAnalysis::MAI; 359 360 void SPIRVModuleAnalysis::getAnalysisUsage(AnalysisUsage &AU) const { 361 AU.addRequired<TargetPassConfig>(); 362 AU.addRequired<MachineModuleInfoWrapperPass>(); 363 } 364 365 bool SPIRVModuleAnalysis::runOnModule(Module &M) { 366 SPIRVTargetMachine &TM = 367 getAnalysis<TargetPassConfig>().getTM<SPIRVTargetMachine>(); 368 ST = TM.getSubtargetImpl(); 369 GR = ST->getSPIRVGlobalRegistry(); 370 TII = ST->getInstrInfo(); 371 372 MMI = &getAnalysis<MachineModuleInfoWrapperPass>().getMMI(); 373 374 setBaseInfo(M); 375 376 processSwitches(M, MAI, MMI); 377 378 // Process type/const/global var/func decl instructions, number their 379 // destination registers from 0 to N, collect Extensions and Capabilities. 380 processDefInstrs(M); 381 382 // Number rest of registers from N+1 onwards. 
383 numberRegistersGlobally(M); 384 385 // Collect OpName, OpEntryPoint, OpDecorate etc, process other instructions. 386 processOtherInstrs(M); 387 388 return false; 389 } 390