xref: /llvm-project/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp (revision b25b507c779747697eb8ca35509b84451226e27b)
1 //===- SPIRVModuleAnalysis.cpp - analysis of global instrs & regs - C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // The analysis collects instructions that should be output at the module level
10 // and performs the global register numbering.
11 //
12 // The results of this analysis are used in AsmPrinter to rename registers
13 // globally and to output required instructions at the module level.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "SPIRVModuleAnalysis.h"
18 #include "SPIRV.h"
19 #include "SPIRVGlobalRegistry.h"
20 #include "SPIRVSubtarget.h"
21 #include "SPIRVTargetMachine.h"
22 #include "SPIRVUtils.h"
23 #include "TargetInfo/SPIRVTargetInfo.h"
24 #include "llvm/CodeGen/MachineModuleInfo.h"
25 #include "llvm/CodeGen/TargetPassConfig.h"
26 
27 using namespace llvm;
28 
29 #define DEBUG_TYPE "spirv-module-analysis"
30 
31 static cl::opt<bool>
32     SPVDumpDeps("spv-dump-deps",
33                 cl::desc("Dump MIR with SPIR-V dependencies info"),
34                 cl::Optional, cl::init(false));
35 
36 char llvm::SPIRVModuleAnalysis::ID = 0;
37 
38 namespace llvm {
39 void initializeSPIRVModuleAnalysisPass(PassRegistry &);
40 } // namespace llvm
41 
42 INITIALIZE_PASS(SPIRVModuleAnalysis, DEBUG_TYPE, "SPIRV module analysis", true,
43                 true)
44 
45 // Retrieve an unsigned from an MDNode with a list of them as operands.
46 static unsigned getMetadataUInt(MDNode *MdNode, unsigned OpIndex,
47                                 unsigned DefaultVal = 0) {
48   if (MdNode && OpIndex < MdNode->getNumOperands()) {
49     const auto &Op = MdNode->getOperand(OpIndex);
50     return mdconst::extract<ConstantInt>(Op)->getZExtValue();
51   }
52   return DefaultVal;
53 }
54 
55 void SPIRVModuleAnalysis::setBaseInfo(const Module &M) {
56   MAI.MaxID = 0;
57   for (int i = 0; i < SPIRV::NUM_MODULE_SECTIONS; i++)
58     MAI.MS[i].clear();
59   MAI.RegisterAliasTable.clear();
60   MAI.InstrsToDelete.clear();
61   MAI.FuncNameMap.clear();
62   MAI.GlobalVarList.clear();
63   MAI.ExtInstSetMap.clear();
64 
65   // TODO: determine memory model and source language from the configuratoin.
66   if (auto MemModel = M.getNamedMetadata("spirv.MemoryModel")) {
67     auto MemMD = MemModel->getOperand(0);
68     MAI.Addr = static_cast<SPIRV::AddressingModel::AddressingModel>(
69         getMetadataUInt(MemMD, 0));
70     MAI.Mem =
71         static_cast<SPIRV::MemoryModel::MemoryModel>(getMetadataUInt(MemMD, 1));
72   } else {
73     MAI.Mem = SPIRV::MemoryModel::OpenCL;
74     unsigned PtrSize = ST->getPointerSize();
75     MAI.Addr = PtrSize == 32   ? SPIRV::AddressingModel::Physical32
76                : PtrSize == 64 ? SPIRV::AddressingModel::Physical64
77                                : SPIRV::AddressingModel::Logical;
78   }
79   // Get the OpenCL version number from metadata.
80   // TODO: support other source languages.
81   if (auto VerNode = M.getNamedMetadata("opencl.ocl.version")) {
82     MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_C;
83     // Construct version literal in accordance with SPIRV-LLVM-Translator.
84     // TODO: support multiple OCL version metadata.
85     assert(VerNode->getNumOperands() > 0 && "Invalid SPIR");
86     auto VersionMD = VerNode->getOperand(0);
87     unsigned MajorNum = getMetadataUInt(VersionMD, 0, 2);
88     unsigned MinorNum = getMetadataUInt(VersionMD, 1);
89     unsigned RevNum = getMetadataUInt(VersionMD, 2);
90     MAI.SrcLangVersion = (MajorNum * 100 + MinorNum) * 1000 + RevNum;
91   } else {
92     MAI.SrcLang = SPIRV::SourceLanguage::Unknown;
93     MAI.SrcLangVersion = 0;
94   }
95 
96   if (auto ExtNode = M.getNamedMetadata("opencl.used.extensions")) {
97     for (unsigned I = 0, E = ExtNode->getNumOperands(); I != E; ++I) {
98       MDNode *MD = ExtNode->getOperand(I);
99       if (!MD || MD->getNumOperands() == 0)
100         continue;
101       for (unsigned J = 0, N = MD->getNumOperands(); J != N; ++J)
102         MAI.SrcExt.insert(cast<MDString>(MD->getOperand(J))->getString());
103     }
104   }
105 
106   // TODO: check if it's required by default.
107   MAI.ExtInstSetMap[static_cast<unsigned>(SPIRV::InstructionSet::OpenCL_std)] =
108       Register::index2VirtReg(MAI.getNextID());
109 }
110 
111 // Collect MI which defines the register in the given machine function.
112 static void collectDefInstr(Register Reg, const MachineFunction *MF,
113                             SPIRV::ModuleAnalysisInfo *MAI,
114                             SPIRV::ModuleSectionType MSType,
115                             bool DoInsert = true) {
116   assert(MAI->hasRegisterAlias(MF, Reg) && "Cannot find register alias");
117   MachineInstr *MI = MF->getRegInfo().getUniqueVRegDef(Reg);
118   assert(MI && "There should be an instruction that defines the register");
119   MAI->setSkipEmission(MI);
120   if (DoInsert)
121     MAI->MS[MSType].push_back(MI);
122 }
123 
124 void SPIRVModuleAnalysis::collectGlobalEntities(
125     const std::vector<SPIRV::DTSortableEntry *> &DepsGraph,
126     SPIRV::ModuleSectionType MSType,
127     std::function<bool(const SPIRV::DTSortableEntry *)> Pred,
128     bool UsePreOrder = false) {
129   DenseSet<const SPIRV::DTSortableEntry *> Visited;
130   for (const auto *E : DepsGraph) {
131     std::function<void(const SPIRV::DTSortableEntry *)> RecHoistUtil;
132     // NOTE: here we prefer recursive approach over iterative because
133     // we don't expect depchains long enough to cause SO.
134     RecHoistUtil = [MSType, UsePreOrder, &Visited, &Pred,
135                     &RecHoistUtil](const SPIRV::DTSortableEntry *E) {
136       if (Visited.count(E) || !Pred(E))
137         return;
138       Visited.insert(E);
139 
140       // Traversing deps graph in post-order allows us to get rid of
141       // register aliases preprocessing.
142       // But pre-order is required for correct processing of function
143       // declaration and arguments processing.
144       if (!UsePreOrder)
145         for (auto *S : E->getDeps())
146           RecHoistUtil(S);
147 
148       Register GlobalReg = Register::index2VirtReg(MAI.getNextID());
149       bool IsFirst = true;
150       for (auto &U : *E) {
151         const MachineFunction *MF = U.first;
152         Register Reg = U.second;
153         MAI.setRegisterAlias(MF, Reg, GlobalReg);
154         if (!MF->getRegInfo().getUniqueVRegDef(Reg))
155           continue;
156         collectDefInstr(Reg, MF, &MAI, MSType, IsFirst);
157         IsFirst = false;
158         if (E->getIsGV())
159           MAI.GlobalVarList.push_back(MF->getRegInfo().getUniqueVRegDef(Reg));
160       }
161 
162       if (UsePreOrder)
163         for (auto *S : E->getDeps())
164           RecHoistUtil(S);
165     };
166     RecHoistUtil(E);
167   }
168 }
169 
170 // The function initializes global register alias table for types, consts,
171 // global vars and func decls and collects these instruction for output
172 // at module level. Also it collects explicit OpExtension/OpCapability
173 // instructions.
174 void SPIRVModuleAnalysis::processDefInstrs(const Module &M) {
175   std::vector<SPIRV::DTSortableEntry *> DepsGraph;
176 
177   GR->buildDepsGraph(DepsGraph, SPVDumpDeps ? MMI : nullptr);
178 
179   collectGlobalEntities(
180       DepsGraph, SPIRV::MB_TypeConstVars,
181       [](const SPIRV::DTSortableEntry *E) { return !E->getIsFunc(); });
182 
183   collectGlobalEntities(
184       DepsGraph, SPIRV::MB_ExtFuncDecls,
185       [](const SPIRV::DTSortableEntry *E) { return E->getIsFunc(); }, true);
186 }
187 
188 // True if there is an instruction in the MS list with all the same operands as
189 // the given instruction has (after the given starting index).
190 // TODO: maybe it needs to check Opcodes too.
191 static bool findSameInstrInMS(const MachineInstr &A,
192                               SPIRV::ModuleSectionType MSType,
193                               SPIRV::ModuleAnalysisInfo &MAI,
194                               unsigned StartOpIndex = 0) {
195   for (const auto *B : MAI.MS[MSType]) {
196     const unsigned NumAOps = A.getNumOperands();
197     if (NumAOps != B->getNumOperands() || A.getNumDefs() != B->getNumDefs())
198       continue;
199     bool AllOpsMatch = true;
200     for (unsigned i = StartOpIndex; i < NumAOps && AllOpsMatch; ++i) {
201       if (A.getOperand(i).isReg() && B->getOperand(i).isReg()) {
202         Register RegA = A.getOperand(i).getReg();
203         Register RegB = B->getOperand(i).getReg();
204         AllOpsMatch = MAI.getRegisterAlias(A.getMF(), RegA) ==
205                       MAI.getRegisterAlias(B->getMF(), RegB);
206       } else {
207         AllOpsMatch = A.getOperand(i).isIdenticalTo(B->getOperand(i));
208       }
209     }
210     if (AllOpsMatch)
211       return true;
212   }
213   return false;
214 }
215 
216 // Look for IDs declared with Import linkage, and map the imported name string
217 // to the register defining that variable (which will usually be the result of
218 // an OpFunction). This lets us call externally imported functions using
219 // the correct ID registers.
220 void SPIRVModuleAnalysis::collectFuncNames(MachineInstr &MI,
221                                            const Function &F) {
222   if (MI.getOpcode() == SPIRV::OpDecorate) {
223     // If it's got Import linkage.
224     auto Dec = MI.getOperand(1).getImm();
225     if (Dec == static_cast<unsigned>(SPIRV::Decoration::LinkageAttributes)) {
226       auto Lnk = MI.getOperand(MI.getNumOperands() - 1).getImm();
227       if (Lnk == static_cast<unsigned>(SPIRV::LinkageType::Import)) {
228         // Map imported function name to function ID register.
229         std::string Name = getStringImm(MI, 2);
230         Register Target = MI.getOperand(0).getReg();
231         // TODO: check defs from different MFs.
232         MAI.FuncNameMap[Name] = MAI.getRegisterAlias(MI.getMF(), Target);
233       }
234     }
235   } else if (MI.getOpcode() == SPIRV::OpFunction) {
236     // Record all internal OpFunction declarations.
237     Register Reg = MI.defs().begin()->getReg();
238     Register GlobalReg = MAI.getRegisterAlias(MI.getMF(), Reg);
239     assert(GlobalReg.isValid());
240     // TODO: check that it does not conflict with existing entries.
241     MAI.FuncNameMap[F.getGlobalIdentifier()] = GlobalReg;
242   }
243 }
244 
245 // Collect the given instruction in the specified MS. We assume global register
246 // numbering has already occurred by this point. We can directly compare reg
247 // arguments when detecting duplicates.
248 static void collectOtherInstr(MachineInstr &MI, SPIRV::ModuleAnalysisInfo &MAI,
249                               SPIRV::ModuleSectionType MSType,
250                               bool Append = true) {
251   MAI.setSkipEmission(&MI);
252   if (findSameInstrInMS(MI, MSType, MAI))
253     return; // Found a duplicate, so don't add it.
254   // No duplicates, so add it.
255   if (Append)
256     MAI.MS[MSType].push_back(&MI);
257   else
258     MAI.MS[MSType].insert(MAI.MS[MSType].begin(), &MI);
259 }
260 
261 // Some global instructions make reference to function-local ID regs, so cannot
262 // be correctly collected until these registers are globally numbered.
263 void SPIRVModuleAnalysis::processOtherInstrs(const Module &M) {
264   for (auto F = M.begin(), E = M.end(); F != E; ++F) {
265     if ((*F).isDeclaration())
266       continue;
267     MachineFunction *MF = MMI->getMachineFunction(*F);
268     assert(MF);
269     for (MachineBasicBlock &MBB : *MF)
270       for (MachineInstr &MI : MBB) {
271         if (MAI.getSkipEmission(&MI))
272           continue;
273         const unsigned OpCode = MI.getOpcode();
274         if (OpCode == SPIRV::OpName || OpCode == SPIRV::OpMemberName) {
275           collectOtherInstr(MI, MAI, SPIRV::MB_DebugNames);
276         } else if (OpCode == SPIRV::OpEntryPoint) {
277           collectOtherInstr(MI, MAI, SPIRV::MB_EntryPoints);
278         } else if (TII->isDecorationInstr(MI)) {
279           collectOtherInstr(MI, MAI, SPIRV::MB_Annotations);
280           collectFuncNames(MI, *F);
281         } else if (TII->isConstantInstr(MI)) {
282           // Now OpSpecConstant*s are not in DT,
283           // but they need to be collected anyway.
284           collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars);
285         } else if (OpCode == SPIRV::OpFunction) {
286           collectFuncNames(MI, *F);
287         } else if (OpCode == SPIRV::OpTypeForwardPointer) {
288           collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, false);
289         }
290       }
291   }
292 }
293 
294 // Number registers in all functions globally from 0 onwards and store
295 // the result in global register alias table. Some registers are already
296 // numbered in collectGlobalEntities.
297 void SPIRVModuleAnalysis::numberRegistersGlobally(const Module &M) {
298   for (auto F = M.begin(), E = M.end(); F != E; ++F) {
299     if ((*F).isDeclaration())
300       continue;
301     MachineFunction *MF = MMI->getMachineFunction(*F);
302     assert(MF);
303     for (MachineBasicBlock &MBB : *MF) {
304       for (MachineInstr &MI : MBB) {
305         for (MachineOperand &Op : MI.operands()) {
306           if (!Op.isReg())
307             continue;
308           Register Reg = Op.getReg();
309           if (MAI.hasRegisterAlias(MF, Reg))
310             continue;
311           Register NewReg = Register::index2VirtReg(MAI.getNextID());
312           MAI.setRegisterAlias(MF, Reg, NewReg);
313         }
314         if (MI.getOpcode() != SPIRV::OpExtInst)
315           continue;
316         auto Set = MI.getOperand(2).getImm();
317         if (MAI.ExtInstSetMap.find(Set) == MAI.ExtInstSetMap.end())
318           MAI.ExtInstSetMap[Set] = Register::index2VirtReg(MAI.getNextID());
319       }
320     }
321   }
322 }
323 
324 // Find OpIEqual and OpBranchConditional instructions originating from
325 // OpSwitches, mark them skipped for emission. Also mark MBB skipped if it
326 // contains only these instructions.
327 static void processSwitches(const Module &M, SPIRV::ModuleAnalysisInfo &MAI,
328                             MachineModuleInfo *MMI) {
329   DenseSet<Register> SwitchRegs;
330   for (auto F = M.begin(), E = M.end(); F != E; ++F) {
331     MachineFunction *MF = MMI->getMachineFunction(*F);
332     if (!MF)
333       continue;
334     for (MachineBasicBlock &MBB : *MF)
335       for (MachineInstr &MI : MBB) {
336         if (MAI.getSkipEmission(&MI))
337           continue;
338         if (MI.getOpcode() == SPIRV::OpSwitch) {
339           assert(MI.getOperand(0).isReg());
340           SwitchRegs.insert(MI.getOperand(0).getReg());
341         }
342         if (MI.getOpcode() != SPIRV::OpIEqual || !MI.getOperand(2).isReg() ||
343             !SwitchRegs.contains(MI.getOperand(2).getReg()))
344           continue;
345         Register CmpReg = MI.getOperand(0).getReg();
346         MachineInstr *CBr = MI.getNextNode();
347         assert(CBr && CBr->getOpcode() == SPIRV::OpBranchConditional &&
348                CBr->getOperand(0).isReg() &&
349                CBr->getOperand(0).getReg() == CmpReg);
350         MAI.setSkipEmission(&MI);
351         MAI.setSkipEmission(CBr);
352         if (&MBB.front() == &MI && &MBB.back() == CBr)
353           MAI.MBBsToSkip.insert(&MBB);
354       }
355   }
356 }
357 
358 struct SPIRV::ModuleAnalysisInfo SPIRVModuleAnalysis::MAI;
359 
360 void SPIRVModuleAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {
361   AU.addRequired<TargetPassConfig>();
362   AU.addRequired<MachineModuleInfoWrapperPass>();
363 }
364 
365 bool SPIRVModuleAnalysis::runOnModule(Module &M) {
366   SPIRVTargetMachine &TM =
367       getAnalysis<TargetPassConfig>().getTM<SPIRVTargetMachine>();
368   ST = TM.getSubtargetImpl();
369   GR = ST->getSPIRVGlobalRegistry();
370   TII = ST->getInstrInfo();
371 
372   MMI = &getAnalysis<MachineModuleInfoWrapperPass>().getMMI();
373 
374   setBaseInfo(M);
375 
376   processSwitches(M, MAI, MMI);
377 
378   // Process type/const/global var/func decl instructions, number their
379   // destination registers from 0 to N, collect Extensions and Capabilities.
380   processDefInstrs(M);
381 
382   // Number rest of registers from N+1 onwards.
383   numberRegistersGlobally(M);
384 
385   // Collect OpName, OpEntryPoint, OpDecorate etc, process other instructions.
386   processOtherInstrs(M);
387 
388   return false;
389 }
390