1 //===- SPIRVModuleAnalysis.cpp - analysis of global instrs & regs - C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // The analysis collects instructions that should be output at the module level
10 // and performs the global register numbering.
11 //
12 // The results of this analysis are used in AsmPrinter to rename registers
13 // globally and to output required instructions at the module level.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "SPIRVModuleAnalysis.h"
18 #include "MCTargetDesc/SPIRVBaseInfo.h"
19 #include "MCTargetDesc/SPIRVMCTargetDesc.h"
20 #include "SPIRV.h"
21 #include "SPIRVSubtarget.h"
22 #include "SPIRVTargetMachine.h"
23 #include "SPIRVUtils.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/CodeGen/MachineModuleInfo.h"
26 #include "llvm/CodeGen/TargetPassConfig.h"
27 
28 using namespace llvm;
29 
30 #define DEBUG_TYPE "spirv-module-analysis"
31 
32 static cl::opt<bool>
33     SPVDumpDeps("spv-dump-deps",
34                 cl::desc("Dump MIR with SPIR-V dependencies info"),
35                 cl::Optional, cl::init(false));
36 
37 static cl::list<SPIRV::Capability::Capability>
38     AvoidCapabilities("avoid-spirv-capabilities",
39                       cl::desc("SPIR-V capabilities to avoid if there are "
40                                "other options enabling a feature"),
41                       cl::ZeroOrMore, cl::Hidden,
42                       cl::values(clEnumValN(SPIRV::Capability::Shader, "Shader",
43                                             "SPIR-V Shader capability")));
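// For example, passing -avoid-spirv-capabilities=Shader makes the analysis
// prefer any other enabling capability over Shader whenever a feature can be
// enabled by more than one capability (see getSymbolicOperandRequirements).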
44 // Use a set rather than the raw cl::list so we can efficiently check membership.
45 struct AvoidCapabilitiesSet {
46   SmallSet<SPIRV::Capability::Capability, 4> S;
47   AvoidCapabilitiesSet() {
48     for (auto Cap : AvoidCapabilities)
49       S.insert(Cap);
50   }
51 };
52 
53 char llvm::SPIRVModuleAnalysis::ID = 0;
54 
55 namespace llvm {
56 void initializeSPIRVModuleAnalysisPass(PassRegistry &);
57 } // namespace llvm
58 
59 INITIALIZE_PASS(SPIRVModuleAnalysis, DEBUG_TYPE, "SPIRV module analysis", true,
60                 true)
61 
62 // Retrieve an unsigned from an MDNode with a list of them as operands.
63 static unsigned getMetadataUInt(MDNode *MdNode, unsigned OpIndex,
64                                 unsigned DefaultVal = 0) {
65   if (MdNode && OpIndex < MdNode->getNumOperands()) {
66     const auto &Op = MdNode->getOperand(OpIndex);
67     return mdconst::extract<ConstantInt>(Op)->getZExtValue();
68   }
69   return DefaultVal;
70 }
71 
72 static SPIRV::Requirements
73 getSymbolicOperandRequirements(SPIRV::OperandCategory::OperandCategory Category,
74                                unsigned i, const SPIRVSubtarget &ST,
75                                SPIRV::RequirementHandler &Reqs) {
76   static AvoidCapabilitiesSet
77       AvoidCaps; // contains capabilities to avoid if there is another option
78 
79   VersionTuple ReqMinVer = getSymbolicOperandMinVersion(Category, i);
80   VersionTuple ReqMaxVer = getSymbolicOperandMaxVersion(Category, i);
81   VersionTuple SPIRVVersion = ST.getSPIRVVersion();
82   bool MinVerOK = SPIRVVersion.empty() || SPIRVVersion >= ReqMinVer;
83   bool MaxVerOK =
84       ReqMaxVer.empty() || SPIRVVersion.empty() || SPIRVVersion <= ReqMaxVer;
85   CapabilityList ReqCaps = getSymbolicOperandCapabilities(Category, i);
86   ExtensionList ReqExts = getSymbolicOperandExtensions(Category, i);
87   if (ReqCaps.empty()) {
88     if (ReqExts.empty()) {
89       if (MinVerOK && MaxVerOK)
90         return {true, {}, {}, ReqMinVer, ReqMaxVer};
91       return {false, {}, {}, VersionTuple(), VersionTuple()};
92     }
93   } else if (MinVerOK && MaxVerOK) {
94     if (ReqCaps.size() == 1) {
95       auto Cap = ReqCaps[0];
96       if (Reqs.isCapabilityAvailable(Cap))
97         return {true, {Cap}, ReqExts, ReqMinVer, ReqMaxVer};
98     } else {
99       // By SPIR-V specification: "If an instruction, enumerant, or other
100       // feature specifies multiple enabling capabilities, only one such
101       // capability needs to be declared to use the feature." However, one
102       // capability may be preferred over another. We use command line
103       // argument(s) and AvoidCapabilities to avoid selection of certain
104       // capabilities if there are other options.
105       CapabilityList UseCaps;
106       for (auto Cap : ReqCaps)
107         if (Reqs.isCapabilityAvailable(Cap))
108           UseCaps.push_back(Cap);
109       for (size_t i = 0, Sz = UseCaps.size(); i < Sz; ++i) {
110         auto Cap = UseCaps[i];
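        // Fall back to the last available capability even if it is in the
        // avoid set: some enabling capability must still be declared.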
111         if (i == Sz - 1 || !AvoidCaps.S.contains(Cap))
112           return {true, {Cap}, ReqExts, ReqMinVer, ReqMaxVer};
113       }
114     }
115   }
116   // If there are no capabilities, or we can't satisfy the version or
117   // capability requirements, use the list of extensions (if the subtarget
118   // can handle them all).
119   if (llvm::all_of(ReqExts, [&ST](const SPIRV::Extension::Extension &Ext) {
120         return ST.canUseExtension(Ext);
121       })) {
122     return {true,
123             {},
124             ReqExts,
125             VersionTuple(),
126             VersionTuple()}; // TODO: add versions to extensions.
127   }
128   return {false, {}, {}, VersionTuple(), VersionTuple()};
129 }
130 
131 void SPIRVModuleAnalysis::setBaseInfo(const Module &M) {
132   MAI.MaxID = 0;
133   for (int i = 0; i < SPIRV::NUM_MODULE_SECTIONS; i++)
134     MAI.MS[i].clear();
135   MAI.RegisterAliasTable.clear();
136   MAI.InstrsToDelete.clear();
137   MAI.FuncMap.clear();
138   MAI.GlobalVarList.clear();
139   MAI.ExtInstSetMap.clear();
140   MAI.Reqs.clear();
141   MAI.Reqs.initAvailableCapabilities(*ST);
142 
143   // TODO: determine memory model and source language from the configuration.
144   if (auto MemModel = M.getNamedMetadata("spirv.MemoryModel")) {
145     auto MemMD = MemModel->getOperand(0);
146     MAI.Addr = static_cast<SPIRV::AddressingModel::AddressingModel>(
147         getMetadataUInt(MemMD, 0));
148     MAI.Mem =
149         static_cast<SPIRV::MemoryModel::MemoryModel>(getMetadataUInt(MemMD, 1));
150   } else {
151     // TODO: Add support for VulkanMemoryModel.
152     MAI.Mem = ST->isOpenCLEnv() ? SPIRV::MemoryModel::OpenCL
153                                 : SPIRV::MemoryModel::GLSL450;
154     if (MAI.Mem == SPIRV::MemoryModel::OpenCL) {
155       unsigned PtrSize = ST->getPointerSize();
156       MAI.Addr = PtrSize == 32   ? SPIRV::AddressingModel::Physical32
157                  : PtrSize == 64 ? SPIRV::AddressingModel::Physical64
158                                  : SPIRV::AddressingModel::Logical;
159     } else {
160       // TODO: Add support for PhysicalStorageBufferAddress.
161       MAI.Addr = SPIRV::AddressingModel::Logical;
162     }
163   }
164   // Get the OpenCL version number from metadata.
165   // TODO: support other source languages.
166   if (auto VerNode = M.getNamedMetadata("opencl.ocl.version")) {
167     MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_C;
168     // Construct version literal in accordance with SPIRV-LLVM-Translator.
169     // TODO: support multiple OCL version metadata.
170     assert(VerNode->getNumOperands() > 0 && "Invalid SPIR");
171     auto VersionMD = VerNode->getOperand(0);
172     unsigned MajorNum = getMetadataUInt(VersionMD, 0, 2);
173     unsigned MinorNum = getMetadataUInt(VersionMD, 1);
174     unsigned RevNum = getMetadataUInt(VersionMD, 2);
175     // Prevent the major part of the OpenCL version from being 0.
176     MAI.SrcLangVersion =
177         (std::max(1U, MajorNum) * 100 + MinorNum) * 1000 + RevNum;
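    // For example, OpenCL C 2.0 metadata (2, 0, 0) yields
    // (2 * 100 + 0) * 1000 + 0 = 200000.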
178   } else {
179     // If there is no OpenCL version information, default to OpenCL 1.0 for
180     // the OpenCL environment to avoid confusing runtimes with an Unknown/0.0
181     // version. The SPIRV-LLVM-Translator avoids potential runtime issues in
182     // a similar manner.
183     if (ST->isOpenCLEnv()) {
184       MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_CPP;
185       MAI.SrcLangVersion = 100000;
186     } else {
187       MAI.SrcLang = SPIRV::SourceLanguage::Unknown;
188       MAI.SrcLangVersion = 0;
189     }
190   }
191 
192   if (auto ExtNode = M.getNamedMetadata("opencl.used.extensions")) {
193     for (unsigned I = 0, E = ExtNode->getNumOperands(); I != E; ++I) {
194       MDNode *MD = ExtNode->getOperand(I);
195       if (!MD || MD->getNumOperands() == 0)
196         continue;
197       for (unsigned J = 0, N = MD->getNumOperands(); J != N; ++J)
198         MAI.SrcExt.insert(cast<MDString>(MD->getOperand(J))->getString());
199     }
200   }
201 
202   // Update required capabilities for this memory model, addressing model and
203   // source language.
204   MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand,
205                                  MAI.Mem, *ST);
206   MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::SourceLanguageOperand,
207                                  MAI.SrcLang, *ST);
208   MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,
209                                  MAI.Addr, *ST);
210 
211   if (ST->isOpenCLEnv()) {
212     // TODO: check if it's required by default.
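    // Reserve a global ID for the OpenCL.std extended instruction set, which
    // is imported by default in the OpenCL environment.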
213     MAI.ExtInstSetMap[static_cast<unsigned>(
214         SPIRV::InstructionSet::OpenCL_std)] =
215         Register::index2VirtReg(MAI.getNextID());
216   }
217 }
218 
219 // Collect the MI that defines the register in the given machine function.
220 static void collectDefInstr(Register Reg, const MachineFunction *MF,
221                             SPIRV::ModuleAnalysisInfo *MAI,
222                             SPIRV::ModuleSectionType MSType,
223                             bool DoInsert = true) {
224   assert(MAI->hasRegisterAlias(MF, Reg) && "Cannot find register alias");
225   MachineInstr *MI = MF->getRegInfo().getUniqueVRegDef(Reg);
226   assert(MI && "There should be an instruction that defines the register");
227   MAI->setSkipEmission(MI);
228   if (DoInsert)
229     MAI->MS[MSType].push_back(MI);
230 }
231 
232 void SPIRVModuleAnalysis::collectGlobalEntities(
233     const std::vector<SPIRV::DTSortableEntry *> &DepsGraph,
234     SPIRV::ModuleSectionType MSType,
235     std::function<bool(const SPIRV::DTSortableEntry *)> Pred,
236     bool UsePreOrder = false) {
237   DenseSet<const SPIRV::DTSortableEntry *> Visited;
238   for (const auto *E : DepsGraph) {
239     std::function<void(const SPIRV::DTSortableEntry *)> RecHoistUtil;
240     // NOTE: we prefer a recursive approach over an iterative one because we
241     // don't expect dependency chains long enough to cause a stack overflow.
242     RecHoistUtil = [MSType, UsePreOrder, &Visited, &Pred,
243                     &RecHoistUtil](const SPIRV::DTSortableEntry *E) {
244       if (Visited.count(E) || !Pred(E))
245         return;
246       Visited.insert(E);
247 
248       // Traversing the dependency graph in post-order lets us avoid a
249       // separate register-alias preprocessing step.
250       // But pre-order is required to correctly process function
251       // declarations and their arguments.
252       if (!UsePreOrder)
253         for (auto *S : E->getDeps())
254           RecHoistUtil(S);
255 
256       Register GlobalReg = Register::index2VirtReg(MAI.getNextID());
257       bool IsFirst = true;
258       for (auto &U : *E) {
259         const MachineFunction *MF = U.first;
260         Register Reg = U.second;
261         MAI.setRegisterAlias(MF, Reg, GlobalReg);
262         if (!MF->getRegInfo().getUniqueVRegDef(Reg))
263           continue;
264         collectDefInstr(Reg, MF, &MAI, MSType, IsFirst);
265         IsFirst = false;
266         if (E->getIsGV())
267           MAI.GlobalVarList.push_back(MF->getRegInfo().getUniqueVRegDef(Reg));
268       }
269 
270       if (UsePreOrder)
271         for (auto *S : E->getDeps())
272           RecHoistUtil(S);
273     };
274     RecHoistUtil(E);
275   }
276 }
277 
278 // This function initializes the global register alias table for types,
279 // consts, global vars and func decls, and collects these instructions for
280 // output at the module level. It also collects explicit
281 // OpExtension/OpCapability instructions.
282 void SPIRVModuleAnalysis::processDefInstrs(const Module &M) {
283   std::vector<SPIRV::DTSortableEntry *> DepsGraph;
284 
285   GR->buildDepsGraph(DepsGraph, TII, SPVDumpDeps ? MMI : nullptr);
286 
287   collectGlobalEntities(
288       DepsGraph, SPIRV::MB_TypeConstVars,
289       [](const SPIRV::DTSortableEntry *E) { return !E->getIsFunc(); });
290 
291   for (auto F = M.begin(), E = M.end(); F != E; ++F) {
292     MachineFunction *MF = MMI->getMachineFunction(*F);
293     if (!MF)
294       continue;
295     // Iterate through and collect OpExtension/OpCapability instructions.
296     for (MachineBasicBlock &MBB : *MF) {
297       for (MachineInstr &MI : MBB) {
298         if (MI.getOpcode() == SPIRV::OpExtension) {
299           // Here, OpExtension just has a single enum operand, not a string.
300           auto Ext = SPIRV::Extension::Extension(MI.getOperand(0).getImm());
301           MAI.Reqs.addExtension(Ext);
302           MAI.setSkipEmission(&MI);
303         } else if (MI.getOpcode() == SPIRV::OpCapability) {
304           auto Cap = SPIRV::Capability::Capability(MI.getOperand(0).getImm());
305           MAI.Reqs.addCapability(Cap);
306           MAI.setSkipEmission(&MI);
307         }
308       }
309     }
310   }
311 
312   collectGlobalEntities(
313       DepsGraph, SPIRV::MB_ExtFuncDecls,
314       [](const SPIRV::DTSortableEntry *E) { return E->getIsFunc(); }, true);
315 }
316 
317 // Look for IDs declared with Import linkage, and map the corresponding
318 // function to the register defining that symbol (which will usually be the
319 // result of an OpFunction). This lets us call externally imported functions
320 // using the correct ID registers.
321 void SPIRVModuleAnalysis::collectFuncNames(MachineInstr &MI,
322                                            const Function *F) {
323   if (MI.getOpcode() == SPIRV::OpDecorate) {
324     // Check whether it has Import linkage.
325     auto Dec = MI.getOperand(1).getImm();
326     if (Dec == static_cast<unsigned>(SPIRV::Decoration::LinkageAttributes)) {
327       auto Lnk = MI.getOperand(MI.getNumOperands() - 1).getImm();
328       if (Lnk == static_cast<unsigned>(SPIRV::LinkageType::Import)) {
329         // Map imported function name to function ID register.
330         const Function *ImportedFunc =
331             F->getParent()->getFunction(getStringImm(MI, 2));
332         Register Target = MI.getOperand(0).getReg();
333         MAI.FuncMap[ImportedFunc] = MAI.getRegisterAlias(MI.getMF(), Target);
334       }
335     }
336   } else if (MI.getOpcode() == SPIRV::OpFunction) {
337     // Record all internal OpFunction declarations.
338     Register Reg = MI.defs().begin()->getReg();
339     Register GlobalReg = MAI.getRegisterAlias(MI.getMF(), Reg);
340     assert(GlobalReg.isValid());
341     MAI.FuncMap[F] = GlobalReg;
342   }
343 }
344 
345 // References to a function via function pointers generate virtual
346 // registers without a definition. We resolve such a reference to the
347 // defining OpFunction instruction using Global Registry info and replace
348 // the dummy operands with the corresponding global register references.
349 void SPIRVModuleAnalysis::collectFuncPtrs() {
350   for (auto &MI : MAI.MS[SPIRV::MB_TypeConstVars])
351     if (MI->getOpcode() == SPIRV::OpConstantFunctionPointerINTEL)
352       collectFuncPtrs(MI);
353 }
354 
355 void SPIRVModuleAnalysis::collectFuncPtrs(MachineInstr *MI) {
356   const MachineOperand *FunUse = &MI->getOperand(2);
357   if (const MachineOperand *FunDef = GR->getFunctionDefinitionByUse(FunUse)) {
358     const MachineInstr *FunDefMI = FunDef->getParent();
359     assert(FunDefMI->getOpcode() == SPIRV::OpFunction &&
360            "Constant function pointer must refer to function definition");
361     Register FunDefReg = FunDef->getReg();
362     Register GlobalFunDefReg =
363         MAI.getRegisterAlias(FunDefMI->getMF(), FunDefReg);
364     assert(GlobalFunDefReg.isValid() &&
365            "Function definition must refer to a global register");
366     Register FunPtrReg = FunUse->getReg();
367     MAI.setRegisterAlias(MI->getMF(), FunPtrReg, GlobalFunDefReg);
368   }
369 }
370 
371 using InstrSignature = SmallVector<size_t>;
372 using InstrTraces = std::set<InstrSignature>;
373 
374 // Returns a representation of an instruction as a vector of MachineOperand
375 // hash values (see llvm::hash_value(const MachineOperand &MO) for details).
376 // This creates a signature of the instruction with the same content
377 // that MachineOperand::isIdenticalTo uses for comparison.
378 static InstrSignature instrToSignature(MachineInstr &MI,
379                                        SPIRV::ModuleAnalysisInfo &MAI) {
380   InstrSignature Signature;
381   for (unsigned i = 0; i < MI.getNumOperands(); ++i) {
382     const MachineOperand &MO = MI.getOperand(i);
383     size_t h;
384     if (MO.isReg()) {
385       Register RegAlias = MAI.getRegisterAlias(MI.getMF(), MO.getReg());
386       // mimic llvm::hash_value(const MachineOperand &MO)
387       h = hash_combine(MO.getType(), (unsigned)RegAlias, MO.getSubReg(),
388                        MO.isDef());
389     } else {
390       h = hash_value(MO);
391     }
392     Signature.push_back(h);
393   }
394   return Signature;
395 }
396 
397 // Collect the given instruction in the specified MS. We assume global
398 // register numbering has already occurred by this point, so we can directly
399 // compare register arguments when detecting duplicates.
400 static void collectOtherInstr(MachineInstr &MI, SPIRV::ModuleAnalysisInfo &MAI,
401                               SPIRV::ModuleSectionType MSType, InstrTraces &IS,
402                               bool Append = true) {
403   MAI.setSkipEmission(&MI);
404   InstrSignature MISign = instrToSignature(MI, MAI);
405   auto FoundMI = IS.insert(MISign);
406   if (!FoundMI.second)
407     return; // insert failed, so we found a duplicate; don't add it to MAI.MS
408   // No duplicates, so add it.
409   if (Append)
410     MAI.MS[MSType].push_back(&MI);
411   else
412     MAI.MS[MSType].insert(MAI.MS[MSType].begin(), &MI);
413 }
414 
415 // Some global instructions refer to function-local ID registers, so they
416 // cannot be correctly collected until these registers are globally numbered.
417 void SPIRVModuleAnalysis::processOtherInstrs(const Module &M) {
418   InstrTraces IS;
419   for (auto F = M.begin(), E = M.end(); F != E; ++F) {
420     if ((*F).isDeclaration())
421       continue;
422     MachineFunction *MF = MMI->getMachineFunction(*F);
423     assert(MF);
424     for (MachineBasicBlock &MBB : *MF)
425       for (MachineInstr &MI : MBB) {
426         if (MAI.getSkipEmission(&MI))
427           continue;
428         const unsigned OpCode = MI.getOpcode();
429         if (OpCode == SPIRV::OpString) {
430           collectOtherInstr(MI, MAI, SPIRV::MB_DebugStrings, IS);
431         } else if (OpCode == SPIRV::OpExtInst && MI.getOperand(2).isImm() &&
432                    MI.getOperand(2).getImm() ==
433                        SPIRV::InstructionSet::
434                            NonSemantic_Shader_DebugInfo_100) {
435           MachineOperand Ins = MI.getOperand(3);
436           namespace NS = SPIRV::NonSemanticExtInst;
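          // Only these NonSemantic debug-info instructions are module-scope
          // and belong in the global DI section; other debug instructions
          // stay in their function bodies.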
437           static constexpr int64_t GlobalNonSemanticDITy[] = {
438               NS::DebugSource, NS::DebugCompilationUnit, NS::DebugInfoNone,
439               NS::DebugTypeBasic, NS::DebugTypePointer};
440           bool IsGlobalDI = false;
441           for (unsigned Idx = 0; Idx < std::size(GlobalNonSemanticDITy); ++Idx)
442             IsGlobalDI |= Ins.getImm() == GlobalNonSemanticDITy[Idx];
443           if (IsGlobalDI)
444             collectOtherInstr(MI, MAI, SPIRV::MB_NonSemanticGlobalDI, IS);
445         } else if (OpCode == SPIRV::OpName || OpCode == SPIRV::OpMemberName) {
446           collectOtherInstr(MI, MAI, SPIRV::MB_DebugNames, IS);
447         } else if (OpCode == SPIRV::OpEntryPoint) {
448           collectOtherInstr(MI, MAI, SPIRV::MB_EntryPoints, IS);
449         } else if (TII->isDecorationInstr(MI)) {
450           collectOtherInstr(MI, MAI, SPIRV::MB_Annotations, IS);
451           collectFuncNames(MI, &*F);
452         } else if (TII->isConstantInstr(MI)) {
453           // Currently OpSpecConstant* instructions are not tracked in the DT,
454           // but they still need to be collected.
455           collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, IS);
456         } else if (OpCode == SPIRV::OpFunction) {
457           collectFuncNames(MI, &*F);
458         } else if (OpCode == SPIRV::OpTypeForwardPointer) {
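          // Prepend (Append = false) so the forward pointer precedes the
          // types that reference it in the type/const/var section.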
459           collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, IS, false);
460         }
461       }
462   }
463 }
464 
465 // Number registers in all functions globally from 0 onwards and store
466 // the result in the global register alias table. Some registers are already
467 // numbered in collectGlobalEntities.
468 void SPIRVModuleAnalysis::numberRegistersGlobally(const Module &M) {
469   for (auto F = M.begin(), E = M.end(); F != E; ++F) {
470     if ((*F).isDeclaration())
471       continue;
472     MachineFunction *MF = MMI->getMachineFunction(*F);
473     assert(MF);
474     for (MachineBasicBlock &MBB : *MF) {
475       for (MachineInstr &MI : MBB) {
476         for (MachineOperand &Op : MI.operands()) {
477           if (!Op.isReg())
478             continue;
479           Register Reg = Op.getReg();
480           if (MAI.hasRegisterAlias(MF, Reg))
481             continue;
482           Register NewReg = Register::index2VirtReg(MAI.getNextID());
483           MAI.setRegisterAlias(MF, Reg, NewReg);
484         }
485         if (MI.getOpcode() != SPIRV::OpExtInst)
486           continue;
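        // Reserve a global ID for any extended instruction set referenced by
        // OpExtInst that has not been assigned one yet.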
487         auto Set = MI.getOperand(2).getImm();
488         if (!MAI.ExtInstSetMap.contains(Set))
489           MAI.ExtInstSetMap[Set] = Register::index2VirtReg(MAI.getNextID());
490       }
491     }
492   }
493 }
494 
495 // RequirementHandler implementations.
496 void SPIRV::RequirementHandler::getAndAddRequirements(
497     SPIRV::OperandCategory::OperandCategory Category, uint32_t i,
498     const SPIRVSubtarget &ST) {
499   addRequirements(getSymbolicOperandRequirements(Category, i, ST, *this));
500 }
501 
502 void SPIRV::RequirementHandler::recursiveAddCapabilities(
503     const CapabilityList &ToPrune) {
504   for (const auto &Cap : ToPrune) {
505     AllCaps.insert(Cap);
506     CapabilityList ImplicitDecls =
507         getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);
508     recursiveAddCapabilities(ImplicitDecls);
509   }
510 }
511 
512 void SPIRV::RequirementHandler::addCapabilities(const CapabilityList &ToAdd) {
513   for (const auto &Cap : ToAdd) {
514     bool IsNewlyInserted = AllCaps.insert(Cap).second;
515     if (!IsNewlyInserted) // Don't re-add if it's already been declared.
516       continue;
517     CapabilityList ImplicitDecls =
518         getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);
519     recursiveAddCapabilities(ImplicitDecls);
520     MinimalCaps.push_back(Cap);
521   }
522 }
523 
524 void SPIRV::RequirementHandler::addRequirements(
525     const SPIRV::Requirements &Req) {
526   if (!Req.IsSatisfiable)
527     report_fatal_error("Adding SPIR-V requirements this target can't satisfy.");
528 
529   if (Req.Cap.has_value())
530     addCapabilities({Req.Cap.value()});
531 
532   addExtensions(Req.Exts);
533 
534   if (!Req.MinVer.empty()) {
535     if (!MaxVersion.empty() && Req.MinVer > MaxVersion) {
536       LLVM_DEBUG(dbgs() << "Conflicting version requirements: >= " << Req.MinVer
537                         << " and <= " << MaxVersion << "\n");
538       report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");
539     }
540 
541     if (MinVersion.empty() || Req.MinVer > MinVersion)
542       MinVersion = Req.MinVer;
543   }
544 
545   if (!Req.MaxVer.empty()) {
546     if (!MinVersion.empty() && Req.MaxVer < MinVersion) {
547       LLVM_DEBUG(dbgs() << "Conflicting version requirements: <= " << Req.MaxVer
548                         << " and >= " << MinVersion << "\n");
549       report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");
550     }
551 
552     if (MaxVersion.empty() || Req.MaxVer < MaxVersion)
553       MaxVersion = Req.MaxVer;
554   }
555 }
556 
557 void SPIRV::RequirementHandler::checkSatisfiable(
558     const SPIRVSubtarget &ST) const {
559   // Report as many errors as possible before aborting the compilation.
560   bool IsSatisfiable = true;
561   auto TargetVer = ST.getSPIRVVersion();
562 
563   if (!MaxVersion.empty() && !TargetVer.empty() && MaxVersion < TargetVer) {
564     LLVM_DEBUG(
565         dbgs() << "Target SPIR-V version too high for required features\n"
566                << "Required max version: " << MaxVersion << " target version "
567                << TargetVer << "\n");
568     IsSatisfiable = false;
569   }
570 
571   if (!MinVersion.empty() && !TargetVer.empty() && MinVersion > TargetVer) {
572     LLVM_DEBUG(dbgs() << "Target SPIR-V version too low for required features\n"
573                       << "Required min version: " << MinVersion
574                       << " target version " << TargetVer << "\n");
575     IsSatisfiable = false;
576   }
577 
578   if (!MinVersion.empty() && !MaxVersion.empty() && MinVersion > MaxVersion) {
579     LLVM_DEBUG(
580         dbgs()
581         << "Version is too low for some features and too high for others.\n"
582         << "Required SPIR-V min version: " << MinVersion
583         << " required SPIR-V max version " << MaxVersion << "\n");
584     IsSatisfiable = false;
585   }
586 
587   for (auto Cap : MinimalCaps) {
588     if (AvailableCaps.contains(Cap))
589       continue;
590     LLVM_DEBUG(dbgs() << "Capability not supported: "
591                       << getSymbolicOperandMnemonic(
592                              OperandCategory::CapabilityOperand, Cap)
593                       << "\n");
594     IsSatisfiable = false;
595   }
596 
597   for (auto Ext : AllExtensions) {
598     if (ST.canUseExtension(Ext))
599       continue;
600     LLVM_DEBUG(dbgs() << "Extension not supported: "
601                       << getSymbolicOperandMnemonic(
602                              OperandCategory::ExtensionOperand, Ext)
603                       << "\n");
604     IsSatisfiable = false;
605   }
606 
607   if (!IsSatisfiable)
608     report_fatal_error("Unable to meet SPIR-V requirements for this target.");
609 }
610 
611 // Add the given capabilities and all their implicitly defined capabilities too.
612 void SPIRV::RequirementHandler::addAvailableCaps(const CapabilityList &ToAdd) {
613   for (const auto Cap : ToAdd)
614     if (AvailableCaps.insert(Cap).second)
615       addAvailableCaps(getSymbolicOperandCapabilities(
616           SPIRV::OperandCategory::CapabilityOperand, Cap));
617 }
618 
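// Remove ToRemove from the capability requirements if IfPresent has already
// been declared, e.g. to drop BitInstructions once Shader is required.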
619 void SPIRV::RequirementHandler::removeCapabilityIf(
620     const Capability::Capability ToRemove,
621     const Capability::Capability IfPresent) {
622   if (AllCaps.contains(IfPresent))
623     AllCaps.erase(ToRemove);
624 }
625 
626 namespace llvm {
627 namespace SPIRV {
628 void RequirementHandler::initAvailableCapabilities(const SPIRVSubtarget &ST) {
629   if (ST.isOpenCLEnv()) {
630     initAvailableCapabilitiesForOpenCL(ST);
631     return;
632   }
633 
634   if (ST.isVulkanEnv()) {
635     initAvailableCapabilitiesForVulkan(ST);
636     return;
637   }
638 
639   report_fatal_error("Unimplemented environment for SPIR-V generation.");
640 }
641 
642 void RequirementHandler::initAvailableCapabilitiesForOpenCL(
643     const SPIRVSubtarget &ST) {
644   // Add the min requirements for different OpenCL and SPIR-V versions.
645   addAvailableCaps({Capability::Addresses, Capability::Float16Buffer,
646                     Capability::Int16, Capability::Int8, Capability::Kernel,
647                     Capability::Linkage, Capability::Vector16,
648                     Capability::Groups, Capability::GenericPointer,
649                     Capability::Shader});
650   if (ST.hasOpenCLFullProfile())
651     addAvailableCaps({Capability::Int64, Capability::Int64Atomics});
652   if (ST.hasOpenCLImageSupport()) {
653     addAvailableCaps({Capability::ImageBasic, Capability::LiteralSampler,
654                       Capability::Image1D, Capability::SampledBuffer,
655                       Capability::ImageBuffer});
656     if (ST.isAtLeastOpenCLVer(VersionTuple(2, 0)))
657       addAvailableCaps({Capability::ImageReadWrite});
658   }
659   if (ST.isAtLeastSPIRVVer(VersionTuple(1, 1)) &&
660       ST.isAtLeastOpenCLVer(VersionTuple(2, 2)))
661     addAvailableCaps({Capability::SubgroupDispatch, Capability::PipeStorage});
662   if (ST.isAtLeastSPIRVVer(VersionTuple(1, 3)))
663     addAvailableCaps({Capability::GroupNonUniform,
664                       Capability::GroupNonUniformVote,
665                       Capability::GroupNonUniformArithmetic,
666                       Capability::GroupNonUniformBallot,
667                       Capability::GroupNonUniformClustered,
668                       Capability::GroupNonUniformShuffle,
669                       Capability::GroupNonUniformShuffleRelative});
670   if (ST.isAtLeastSPIRVVer(VersionTuple(1, 4)))
671     addAvailableCaps({Capability::DenormPreserve, Capability::DenormFlushToZero,
672                       Capability::SignedZeroInfNanPreserve,
673                       Capability::RoundingModeRTE,
674                       Capability::RoundingModeRTZ});
675   // TODO: verify if this needs some checks.
676   addAvailableCaps({Capability::Float16, Capability::Float64});
677 
678   // Add capabilities enabled by extensions.
679   for (auto Extension : ST.getAllAvailableExtensions()) {
680     CapabilityList EnabledCapabilities =
681         getCapabilitiesEnabledByExtension(Extension);
682     addAvailableCaps(EnabledCapabilities);
683   }
684 
685   // TODO: add OpenCL extensions.
686 }
687 
688 void RequirementHandler::initAvailableCapabilitiesForVulkan(
689     const SPIRVSubtarget &ST) {
690   addAvailableCaps({Capability::Shader, Capability::Linkage});
691 
692   // Provided by all supported Vulkan versions.
693   addAvailableCaps({Capability::Int16, Capability::Int64, Capability::Float16,
694                     Capability::Float64, Capability::GroupNonUniform,
695                     Capability::Image1D, Capability::SampledBuffer,
696                     Capability::ImageBuffer});
697 }
698 
699 } // namespace SPIRV
700 } // namespace llvm
701 
702 // Add the required capabilities from a decoration instruction (including
703 // BuiltIns).
704 static void addOpDecorateReqs(const MachineInstr &MI, unsigned DecIndex,
705                               SPIRV::RequirementHandler &Reqs,
706                               const SPIRVSubtarget &ST) {
707   int64_t DecOp = MI.getOperand(DecIndex).getImm();
708   auto Dec = static_cast<SPIRV::Decoration::Decoration>(DecOp);
709   Reqs.addRequirements(getSymbolicOperandRequirements(
710       SPIRV::OperandCategory::DecorationOperand, Dec, ST, Reqs));
711 
712   if (Dec == SPIRV::Decoration::BuiltIn) {
713     int64_t BuiltInOp = MI.getOperand(DecIndex + 1).getImm();
714     auto BuiltIn = static_cast<SPIRV::BuiltIn::BuiltIn>(BuiltInOp);
715     Reqs.addRequirements(getSymbolicOperandRequirements(
716         SPIRV::OperandCategory::BuiltInOperand, BuiltIn, ST, Reqs));
717   } else if (Dec == SPIRV::Decoration::LinkageAttributes) {
718     int64_t LinkageOp = MI.getOperand(MI.getNumOperands() - 1).getImm();
719     SPIRV::LinkageType::LinkageType LnkType =
720         static_cast<SPIRV::LinkageType::LinkageType>(LinkageOp);
721     if (LnkType == SPIRV::LinkageType::LinkOnceODR)
722       Reqs.addExtension(SPIRV::Extension::SPV_KHR_linkonce_odr);
723   } else if (Dec == SPIRV::Decoration::CacheControlLoadINTEL ||
724              Dec == SPIRV::Decoration::CacheControlStoreINTEL) {
725     Reqs.addExtension(SPIRV::Extension::SPV_INTEL_cache_controls);
726   } else if (Dec == SPIRV::Decoration::HostAccessINTEL) {
727     Reqs.addExtension(SPIRV::Extension::SPV_INTEL_global_variable_host_access);
728   } else if (Dec == SPIRV::Decoration::InitModeINTEL ||
729              Dec == SPIRV::Decoration::ImplementInRegisterMapINTEL) {
730     Reqs.addExtension(
731         SPIRV::Extension::SPV_INTEL_global_variable_fpga_decorations);
732   }
733 }
734 
735 // Add requirements for image handling.
736 static void addOpTypeImageReqs(const MachineInstr &MI,
737                                SPIRV::RequirementHandler &Reqs,
738                                const SPIRVSubtarget &ST) {
739   assert(MI.getNumOperands() >= 8 && "Insufficient operands for OpTypeImage");
740   // The operand indices used here are based on the OpTypeImage layout, which
741   // the MachineInstr follows as well.
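  // After the result and sampled-type operands these are: 2 = Dim, 3 = Depth,
  // 4 = Arrayed, 5 = MS, 6 = Sampled, 7 = Image Format and, optionally,
  // 8 = Access Qualifier.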
742   int64_t ImgFormatOp = MI.getOperand(7).getImm();
743   auto ImgFormat = static_cast<SPIRV::ImageFormat::ImageFormat>(ImgFormatOp);
744   Reqs.getAndAddRequirements(SPIRV::OperandCategory::ImageFormatOperand,
745                              ImgFormat, ST);
746 
747   bool IsArrayed = MI.getOperand(4).getImm() == 1;
748   bool IsMultisampled = MI.getOperand(5).getImm() == 1;
749   bool NoSampler = MI.getOperand(6).getImm() == 2;
750   // Add dimension requirements.
751   assert(MI.getOperand(2).isImm());
752   switch (MI.getOperand(2).getImm()) {
753   case SPIRV::Dim::DIM_1D:
754     Reqs.addRequirements(NoSampler ? SPIRV::Capability::Image1D
755                                    : SPIRV::Capability::Sampled1D);
756     break;
757   case SPIRV::Dim::DIM_2D:
758     if (IsMultisampled && NoSampler)
759       Reqs.addRequirements(SPIRV::Capability::ImageMSArray);
760     break;
761   case SPIRV::Dim::DIM_Cube:
762     Reqs.addRequirements(SPIRV::Capability::Shader);
763     if (IsArrayed)
764       Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageCubeArray
765                                      : SPIRV::Capability::SampledCubeArray);
766     break;
767   case SPIRV::Dim::DIM_Rect:
768     Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageRect
769                                    : SPIRV::Capability::SampledRect);
770     break;
771   case SPIRV::Dim::DIM_Buffer:
772     Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageBuffer
773                                    : SPIRV::Capability::SampledBuffer);
774     break;
775   case SPIRV::Dim::DIM_SubpassData:
776     Reqs.addRequirements(SPIRV::Capability::InputAttachment);
777     break;
778   }
779 
780   // Has optional access qualifier.
781   if (ST.isOpenCLEnv()) {
782     if (MI.getNumOperands() > 8 &&
783         MI.getOperand(8).getImm() == SPIRV::AccessQualifier::ReadWrite)
784       Reqs.addRequirements(SPIRV::Capability::ImageReadWrite);
785     else
786       Reqs.addRequirements(SPIRV::Capability::ImageBasic);
787   }
788 }
789 
790 // Add requirements for handling atomic float instructions
791 #define ATOM_FLT_REQ_EXT_MSG(ExtName)                                          \
792   "The atomic float instruction requires the following SPIR-V "                \
793   "extension: SPV_EXT_shader_atomic_float" ExtName
794 static void AddAtomicFloatRequirements(const MachineInstr &MI,
795                                        SPIRV::RequirementHandler &Reqs,
796                                        const SPIRVSubtarget &ST) {
797   assert(MI.getOperand(1).isReg() &&
798          "Expect register operand in atomic float instruction");
799   Register TypeReg = MI.getOperand(1).getReg();
800   SPIRVType *TypeDef = MI.getMF()->getRegInfo().getVRegDef(TypeReg);
801   if (TypeDef->getOpcode() != SPIRV::OpTypeFloat)
802     report_fatal_error("Result type of an atomic float instruction must be a "
803                        "floating-point type scalar");
804 
805   unsigned BitWidth = TypeDef->getOperand(1).getImm();
806   unsigned Op = MI.getOpcode();
807   if (Op == SPIRV::OpAtomicFAddEXT) {
808     if (!ST.canUseExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add))
809       report_fatal_error(ATOM_FLT_REQ_EXT_MSG("_add"), false);
810     Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add);
811     switch (BitWidth) {
812     case 16:
813       if (!ST.canUseExtension(
814               SPIRV::Extension::SPV_EXT_shader_atomic_float16_add))
815         report_fatal_error(ATOM_FLT_REQ_EXT_MSG("16_add"), false);
816       Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float16_add);
817       Reqs.addCapability(SPIRV::Capability::AtomicFloat16AddEXT);
818       break;
819     case 32:
820       Reqs.addCapability(SPIRV::Capability::AtomicFloat32AddEXT);
821       break;
822     case 64:
823       Reqs.addCapability(SPIRV::Capability::AtomicFloat64AddEXT);
824       break;
825     default:
826       report_fatal_error(
827           "Unexpected floating-point type width in atomic float instruction");
828     }
829   } else {
830     if (!ST.canUseExtension(
831             SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max))
832       report_fatal_error(ATOM_FLT_REQ_EXT_MSG("_min_max"), false);
833     Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max);
834     switch (BitWidth) {
835     case 16:
836       Reqs.addCapability(SPIRV::Capability::AtomicFloat16MinMaxEXT);
837       break;
838     case 32:
839       Reqs.addCapability(SPIRV::Capability::AtomicFloat32MinMaxEXT);
840       break;
841     case 64:
842       Reqs.addCapability(SPIRV::Capability::AtomicFloat64MinMaxEXT);
843       break;
844     default:
845       report_fatal_error(
846           "Unexpected floating-point type width in atomic float instruction");
847     }
848   }
849 }
850 
851 void addInstrRequirements(const MachineInstr &MI,
852                           SPIRV::RequirementHandler &Reqs,
853                           const SPIRVSubtarget &ST) {
854   switch (MI.getOpcode()) {
855   case SPIRV::OpMemoryModel: {
856     int64_t Addr = MI.getOperand(0).getImm();
857     Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,
858                                Addr, ST);
859     int64_t Mem = MI.getOperand(1).getImm();
860     Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand, Mem,
861                                ST);
862     break;
863   }
864   case SPIRV::OpEntryPoint: {
865     int64_t Exe = MI.getOperand(0).getImm();
866     Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModelOperand,
867                                Exe, ST);
868     break;
869   }
870   case SPIRV::OpExecutionMode:
871   case SPIRV::OpExecutionModeId: {
872     int64_t Exe = MI.getOperand(1).getImm();
873     Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModeOperand,
874                                Exe, ST);
875     break;
876   }
877   case SPIRV::OpTypeMatrix:
878     Reqs.addCapability(SPIRV::Capability::Matrix);
879     break;
880   case SPIRV::OpTypeInt: {
881     unsigned BitWidth = MI.getOperand(1).getImm();
882     if (BitWidth == 64)
883       Reqs.addCapability(SPIRV::Capability::Int64);
884     else if (BitWidth == 16)
885       Reqs.addCapability(SPIRV::Capability::Int16);
886     else if (BitWidth == 8)
887       Reqs.addCapability(SPIRV::Capability::Int8);
888     break;
889   }
890   case SPIRV::OpTypeFloat: {
891     unsigned BitWidth = MI.getOperand(1).getImm();
892     if (BitWidth == 64)
893       Reqs.addCapability(SPIRV::Capability::Float64);
894     else if (BitWidth == 16)
895       Reqs.addCapability(SPIRV::Capability::Float16);
896     break;
897   }
898   case SPIRV::OpTypeVector: {
899     unsigned NumComponents = MI.getOperand(2).getImm();
900     if (NumComponents == 8 || NumComponents == 16)
901       Reqs.addCapability(SPIRV::Capability::Vector16);
902     break;
903   }
904   case SPIRV::OpTypePointer: {
905     auto SC = MI.getOperand(1).getImm();
906     Reqs.getAndAddRequirements(SPIRV::OperandCategory::StorageClassOperand, SC,
907                                ST);
908     // If it is a pointer to a float16 type and we target OpenCL, add the
909     // Float16Buffer capability.
910     if (!ST.isOpenCLEnv())
911       break;
912     assert(MI.getOperand(2).isReg());
913     const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
914     SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(2).getReg());
915     if (TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
916         TypeDef->getOperand(1).getImm() == 16)
917       Reqs.addCapability(SPIRV::Capability::Float16Buffer);
918     break;
919   }
920   case SPIRV::OpExtInst: {
921     if (MI.getOperand(2).getImm() ==
922         static_cast<int64_t>(
923             SPIRV::InstructionSet::NonSemantic_Shader_DebugInfo_100)) {
924       Reqs.addExtension(SPIRV::Extension::SPV_KHR_non_semantic_info);
925     }
926     break;
927   }
928   case SPIRV::OpBitReverse:
929   case SPIRV::OpBitFieldInsert:
930   case SPIRV::OpBitFieldSExtract:
931   case SPIRV::OpBitFieldUExtract:
932     if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions)) {
933       Reqs.addCapability(SPIRV::Capability::Shader);
934       break;
935     }
936     Reqs.addExtension(SPIRV::Extension::SPV_KHR_bit_instructions);
937     Reqs.addCapability(SPIRV::Capability::BitInstructions);
938     break;
939   case SPIRV::OpTypeRuntimeArray:
940     Reqs.addCapability(SPIRV::Capability::Shader);
941     break;
942   case SPIRV::OpTypeOpaque:
943   case SPIRV::OpTypeEvent:
944     Reqs.addCapability(SPIRV::Capability::Kernel);
945     break;
946   case SPIRV::OpTypePipe:
947   case SPIRV::OpTypeReserveId:
948     Reqs.addCapability(SPIRV::Capability::Pipes);
949     break;
950   case SPIRV::OpTypeDeviceEvent:
951   case SPIRV::OpTypeQueue:
952   case SPIRV::OpBuildNDRange:
953     Reqs.addCapability(SPIRV::Capability::DeviceEnqueue);
954     break;
955   case SPIRV::OpDecorate:
956   case SPIRV::OpDecorateId:
957   case SPIRV::OpDecorateString:
958     addOpDecorateReqs(MI, 1, Reqs, ST);
959     break;
960   case SPIRV::OpMemberDecorate:
961   case SPIRV::OpMemberDecorateString:
962     addOpDecorateReqs(MI, 2, Reqs, ST);
963     break;
964   case SPIRV::OpInBoundsPtrAccessChain:
965     Reqs.addCapability(SPIRV::Capability::Addresses);
966     break;
967   case SPIRV::OpConstantSampler:
968     Reqs.addCapability(SPIRV::Capability::LiteralSampler);
969     break;
970   case SPIRV::OpTypeImage:
971     addOpTypeImageReqs(MI, Reqs, ST);
972     break;
973   case SPIRV::OpTypeSampler:
974     Reqs.addCapability(SPIRV::Capability::ImageBasic);
975     break;
976   case SPIRV::OpTypeForwardPointer:
977     // TODO: check if it's OpenCL's kernel.
978     Reqs.addCapability(SPIRV::Capability::Addresses);
979     break;
980   case SPIRV::OpAtomicFlagTestAndSet:
981   case SPIRV::OpAtomicLoad:
982   case SPIRV::OpAtomicStore:
983   case SPIRV::OpAtomicExchange:
984   case SPIRV::OpAtomicCompareExchange:
985   case SPIRV::OpAtomicIIncrement:
986   case SPIRV::OpAtomicIDecrement:
987   case SPIRV::OpAtomicIAdd:
988   case SPIRV::OpAtomicISub:
989   case SPIRV::OpAtomicUMin:
990   case SPIRV::OpAtomicUMax:
991   case SPIRV::OpAtomicSMin:
992   case SPIRV::OpAtomicSMax:
993   case SPIRV::OpAtomicAnd:
994   case SPIRV::OpAtomicOr:
995   case SPIRV::OpAtomicXor: {
996     const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
997     const MachineInstr *InstrPtr = &MI;
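    // OpAtomicStore produces no result, so take the value type from the
    // instruction that defines the stored value (operand 3) instead.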
998     if (MI.getOpcode() == SPIRV::OpAtomicStore) {
999       assert(MI.getOperand(3).isReg());
1000       InstrPtr = MRI.getVRegDef(MI.getOperand(3).getReg());
1001       assert(InstrPtr && "Unexpected type instruction for OpAtomicStore");
1002     }
1003     assert(InstrPtr->getOperand(1).isReg() && "Unexpected operand in atomic");
1004     Register TypeReg = InstrPtr->getOperand(1).getReg();
1005     SPIRVType *TypeDef = MRI.getVRegDef(TypeReg);
1006     if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
1007       unsigned BitWidth = TypeDef->getOperand(1).getImm();
1008       if (BitWidth == 64)
1009         Reqs.addCapability(SPIRV::Capability::Int64Atomics);
1010     }
1011     break;
1012   }
1013   case SPIRV::OpGroupNonUniformIAdd:
1014   case SPIRV::OpGroupNonUniformFAdd:
1015   case SPIRV::OpGroupNonUniformIMul:
1016   case SPIRV::OpGroupNonUniformFMul:
1017   case SPIRV::OpGroupNonUniformSMin:
1018   case SPIRV::OpGroupNonUniformUMin:
1019   case SPIRV::OpGroupNonUniformFMin:
1020   case SPIRV::OpGroupNonUniformSMax:
1021   case SPIRV::OpGroupNonUniformUMax:
1022   case SPIRV::OpGroupNonUniformFMax:
1023   case SPIRV::OpGroupNonUniformBitwiseAnd:
1024   case SPIRV::OpGroupNonUniformBitwiseOr:
1025   case SPIRV::OpGroupNonUniformBitwiseXor:
1026   case SPIRV::OpGroupNonUniformLogicalAnd:
1027   case SPIRV::OpGroupNonUniformLogicalOr:
1028   case SPIRV::OpGroupNonUniformLogicalXor: {
1029     assert(MI.getOperand(3).isImm());
1030     int64_t GroupOp = MI.getOperand(3).getImm();
1031     switch (GroupOp) {
1032     case SPIRV::GroupOperation::Reduce:
1033     case SPIRV::GroupOperation::InclusiveScan:
1034     case SPIRV::GroupOperation::ExclusiveScan:
1035       Reqs.addCapability(SPIRV::Capability::Kernel);
1036       Reqs.addCapability(SPIRV::Capability::GroupNonUniformArithmetic);
1037       Reqs.addCapability(SPIRV::Capability::GroupNonUniformBallot);
1038       break;
1039     case SPIRV::GroupOperation::ClusteredReduce:
1040       Reqs.addCapability(SPIRV::Capability::GroupNonUniformClustered);
1041       break;
1042     case SPIRV::GroupOperation::PartitionedReduceNV:
1043     case SPIRV::GroupOperation::PartitionedInclusiveScanNV:
1044     case SPIRV::GroupOperation::PartitionedExclusiveScanNV:
1045       Reqs.addCapability(SPIRV::Capability::GroupNonUniformPartitionedNV);
1046       break;
1047     }
1048     break;
1049   }
1050   case SPIRV::OpGroupNonUniformShuffle:
1051   case SPIRV::OpGroupNonUniformShuffleXor:
1052     Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffle);
1053     break;
1054   case SPIRV::OpGroupNonUniformShuffleUp:
1055   case SPIRV::OpGroupNonUniformShuffleDown:
1056     Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffleRelative);
1057     break;
1058   case SPIRV::OpGroupAll:
1059   case SPIRV::OpGroupAny:
1060   case SPIRV::OpGroupBroadcast:
1061   case SPIRV::OpGroupIAdd:
1062   case SPIRV::OpGroupFAdd:
1063   case SPIRV::OpGroupFMin:
1064   case SPIRV::OpGroupUMin:
1065   case SPIRV::OpGroupSMin:
1066   case SPIRV::OpGroupFMax:
1067   case SPIRV::OpGroupUMax:
1068   case SPIRV::OpGroupSMax:
1069     Reqs.addCapability(SPIRV::Capability::Groups);
1070     break;
1071   case SPIRV::OpGroupNonUniformElect:
1072     Reqs.addCapability(SPIRV::Capability::GroupNonUniform);
1073     break;
1074   case SPIRV::OpGroupNonUniformAll:
1075   case SPIRV::OpGroupNonUniformAny:
1076   case SPIRV::OpGroupNonUniformAllEqual:
1077     Reqs.addCapability(SPIRV::Capability::GroupNonUniformVote);
1078     break;
1079   case SPIRV::OpGroupNonUniformBroadcast:
1080   case SPIRV::OpGroupNonUniformBroadcastFirst:
1081   case SPIRV::OpGroupNonUniformBallot:
1082   case SPIRV::OpGroupNonUniformInverseBallot:
1083   case SPIRV::OpGroupNonUniformBallotBitExtract:
1084   case SPIRV::OpGroupNonUniformBallotBitCount:
1085   case SPIRV::OpGroupNonUniformBallotFindLSB:
1086   case SPIRV::OpGroupNonUniformBallotFindMSB:
1087     Reqs.addCapability(SPIRV::Capability::GroupNonUniformBallot);
1088     break;
1089   case SPIRV::OpSubgroupShuffleINTEL:
1090   case SPIRV::OpSubgroupShuffleDownINTEL:
1091   case SPIRV::OpSubgroupShuffleUpINTEL:
1092   case SPIRV::OpSubgroupShuffleXorINTEL:
1093     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
1094       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
1095       Reqs.addCapability(SPIRV::Capability::SubgroupShuffleINTEL);
1096     }
1097     break;
1098   case SPIRV::OpSubgroupBlockReadINTEL:
1099   case SPIRV::OpSubgroupBlockWriteINTEL:
1100     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
1101       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
1102       Reqs.addCapability(SPIRV::Capability::SubgroupBufferBlockIOINTEL);
1103     }
1104     break;
1105   case SPIRV::OpSubgroupImageBlockReadINTEL:
1106   case SPIRV::OpSubgroupImageBlockWriteINTEL:
1107     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
1108       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
1109       Reqs.addCapability(SPIRV::Capability::SubgroupImageBlockIOINTEL);
1110     }
1111     break;
1112   case SPIRV::OpAssumeTrueKHR:
1113   case SPIRV::OpExpectKHR:
1114     if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_expect_assume)) {
1115       Reqs.addExtension(SPIRV::Extension::SPV_KHR_expect_assume);
1116       Reqs.addCapability(SPIRV::Capability::ExpectAssumeKHR);
1117     }
1118     break;
1119   case SPIRV::OpPtrCastToCrossWorkgroupINTEL:
1120   case SPIRV::OpCrossWorkgroupCastToPtrINTEL:
1121     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)) {
1122       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes);
1123       Reqs.addCapability(SPIRV::Capability::USMStorageClassesINTEL);
1124     }
1125     break;
1126   case SPIRV::OpConstantFunctionPointerINTEL:
1127     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) {
1128       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
1129       Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL);
1130     }
1131     break;
1132   case SPIRV::OpGroupNonUniformRotateKHR:
1133     if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_subgroup_rotate))
1134       report_fatal_error("OpGroupNonUniformRotateKHR instruction requires the "
1135                          "following SPIR-V extension: SPV_KHR_subgroup_rotate",
1136                          false);
1137     Reqs.addExtension(SPIRV::Extension::SPV_KHR_subgroup_rotate);
1138     Reqs.addCapability(SPIRV::Capability::GroupNonUniformRotateKHR);
1139     Reqs.addCapability(SPIRV::Capability::GroupNonUniform);
1140     break;
1141   case SPIRV::OpGroupIMulKHR:
1142   case SPIRV::OpGroupFMulKHR:
1143   case SPIRV::OpGroupBitwiseAndKHR:
1144   case SPIRV::OpGroupBitwiseOrKHR:
1145   case SPIRV::OpGroupBitwiseXorKHR:
1146   case SPIRV::OpGroupLogicalAndKHR:
1147   case SPIRV::OpGroupLogicalOrKHR:
1148   case SPIRV::OpGroupLogicalXorKHR:
1149     if (ST.canUseExtension(
1150             SPIRV::Extension::SPV_KHR_uniform_group_instructions)) {
1151       Reqs.addExtension(SPIRV::Extension::SPV_KHR_uniform_group_instructions);
1152       Reqs.addCapability(SPIRV::Capability::GroupUniformArithmeticKHR);
1153     }
1154     break;
1155   case SPIRV::OpReadClockKHR:
1156     if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_shader_clock))
1157       report_fatal_error("OpReadClockKHR instruction requires the "
1158                          "following SPIR-V extension: SPV_KHR_shader_clock",
1159                          false);
1160     Reqs.addExtension(SPIRV::Extension::SPV_KHR_shader_clock);
1161     Reqs.addCapability(SPIRV::Capability::ShaderClockKHR);
1162     break;
1163   case SPIRV::OpFunctionPointerCallINTEL:
1164     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) {
1165       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
1166       Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL);
1167     }
1168     break;
1169   case SPIRV::OpAtomicFAddEXT:
1170   case SPIRV::OpAtomicFMinEXT:
1171   case SPIRV::OpAtomicFMaxEXT:
1172     AddAtomicFloatRequirements(MI, Reqs, ST);
1173     break;
1174   case SPIRV::OpConvertBF16ToFINTEL:
1175   case SPIRV::OpConvertFToBF16INTEL:
1176     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bfloat16_conversion)) {
1177       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bfloat16_conversion);
1178       Reqs.addCapability(SPIRV::Capability::BFloat16ConversionINTEL);
1179     }
1180     break;
1181   case SPIRV::OpVariableLengthArrayINTEL:
1182   case SPIRV::OpSaveMemoryINTEL:
1183   case SPIRV::OpRestoreMemoryINTEL:
1184     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_variable_length_array)) {
1185       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_variable_length_array);
1186       Reqs.addCapability(SPIRV::Capability::VariableLengthArrayINTEL);
1187     }
1188     break;
1189   case SPIRV::OpAsmTargetINTEL:
1190   case SPIRV::OpAsmINTEL:
1191   case SPIRV::OpAsmCallINTEL:
1192     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_inline_assembly)) {
1193       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_inline_assembly);
1194       Reqs.addCapability(SPIRV::Capability::AsmINTEL);
1195     }
1196     break;
1197   case SPIRV::OpTypeCooperativeMatrixKHR:
1198     if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
1199       report_fatal_error(
1200           "OpTypeCooperativeMatrixKHR type requires the "
1201           "following SPIR-V extension: SPV_KHR_cooperative_matrix",
1202           false);
1203     Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
1204     Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
1205     break;
1206   case SPIRV::OpArithmeticFenceEXT:
1207     if (!ST.canUseExtension(SPIRV::Extension::SPV_EXT_arithmetic_fence))
1208       report_fatal_error("OpArithmeticFenceEXT requires the "
1209                          "following SPIR-V extension: SPV_EXT_arithmetic_fence",
1210                          false);
1211     Reqs.addExtension(SPIRV::Extension::SPV_EXT_arithmetic_fence);
1212     Reqs.addCapability(SPIRV::Capability::ArithmeticFenceEXT);
1213     break;
1214   case SPIRV::OpControlBarrierArriveINTEL:
1215   case SPIRV::OpControlBarrierWaitINTEL:
1216     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_split_barrier)) {
1217       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_split_barrier);
1218       Reqs.addCapability(SPIRV::Capability::SplitBarrierINTEL);
1219     }
1220     break;
1221   default:
1222     break;
1223   }
1224 
1225   // If we require capability Shader, then we can remove the requirement for
1226   // the BitInstructions capability, since Shader is a superset capability
1227   // of BitInstructions.
1228   Reqs.removeCapabilityIf(SPIRV::Capability::BitInstructions,
1229                           SPIRV::Capability::Shader);
1230 }
1231 
1232 static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI,
1233                         MachineModuleInfo *MMI, const SPIRVSubtarget &ST) {
1234   // Collect requirements for existing instructions.
1235   for (auto F = M.begin(), E = M.end(); F != E; ++F) {
1236     MachineFunction *MF = MMI->getMachineFunction(*F);
1237     if (!MF)
1238       continue;
1239     for (const MachineBasicBlock &MBB : *MF)
1240       for (const MachineInstr &MI : MBB)
1241         addInstrRequirements(MI, MAI.Reqs, ST);
1242   }
1243   // Collect requirements for OpExecutionMode instructions.
1244   auto Node = M.getNamedMetadata("spirv.ExecutionMode");
1245   if (Node) {
1246     // Float-controls execution modes are core only from SPIR-V 1.4 onwards.
1247     bool RequireFloatControls = false,
1248          VerLower14 = !ST.isAtLeastSPIRVVer(VersionTuple(1, 4));
1249     for (unsigned i = 0; i < Node->getNumOperands(); i++) {
1250       MDNode *MDN = cast<MDNode>(Node->getOperand(i));
1251       const MDOperand &MDOp = MDN->getOperand(1);
1252       if (auto *CMeta = dyn_cast<ConstantAsMetadata>(MDOp)) {
1253         Constant *C = CMeta->getValue();
1254         if (ConstantInt *Const = dyn_cast<ConstantInt>(C)) {
1255           auto EM = Const->getZExtValue();
1256           MAI.Reqs.getAndAddRequirements(
1257               SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
1258           // Add SPV_KHR_float_controls if the SPIR-V version is below 1.4.
1259           switch (EM) {
1260           case SPIRV::ExecutionMode::DenormPreserve:
1261           case SPIRV::ExecutionMode::DenormFlushToZero:
1262           case SPIRV::ExecutionMode::SignedZeroInfNanPreserve:
1263           case SPIRV::ExecutionMode::RoundingModeRTE:
1264           case SPIRV::ExecutionMode::RoundingModeRTZ:
1265             RequireFloatControls = VerLower14;
1266             break;
1267           }
1268         }
1269       }
1270     }
1271     if (RequireFloatControls &&
1272         ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls))
1273       MAI.Reqs.addExtension(SPIRV::Extension::SPV_KHR_float_controls);
1274   }
1275   for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) {
1276     const Function &F = *FI;
1277     if (F.isDeclaration())
1278       continue;
1279     if (F.getMetadata("reqd_work_group_size"))
1280       MAI.Reqs.getAndAddRequirements(
1281           SPIRV::OperandCategory::ExecutionModeOperand,
1282           SPIRV::ExecutionMode::LocalSize, ST);
1283     if (F.getFnAttribute("hlsl.numthreads").isValid()) {
1284       MAI.Reqs.getAndAddRequirements(
1285           SPIRV::OperandCategory::ExecutionModeOperand,
1286           SPIRV::ExecutionMode::LocalSize, ST);
1287     }
1288     if (F.getMetadata("work_group_size_hint"))
1289       MAI.Reqs.getAndAddRequirements(
1290           SPIRV::OperandCategory::ExecutionModeOperand,
1291           SPIRV::ExecutionMode::LocalSizeHint, ST);
1292     if (F.getMetadata("intel_reqd_sub_group_size"))
1293       MAI.Reqs.getAndAddRequirements(
1294           SPIRV::OperandCategory::ExecutionModeOperand,
1295           SPIRV::ExecutionMode::SubgroupSize, ST);
1296     if (F.getMetadata("vec_type_hint"))
1297       MAI.Reqs.getAndAddRequirements(
1298           SPIRV::OperandCategory::ExecutionModeOperand,
1299           SPIRV::ExecutionMode::VecTypeHint, ST);
1300 
1301     if (F.hasOptNone() &&
1302         ST.canUseExtension(SPIRV::Extension::SPV_INTEL_optnone)) {
1303       // Output OpCapability OptNoneINTEL.
1304       MAI.Reqs.addExtension(SPIRV::Extension::SPV_INTEL_optnone);
1305       MAI.Reqs.addCapability(SPIRV::Capability::OptNoneINTEL);
1306     }
1307   }
1308 }
1309 
1310 static unsigned getFastMathFlags(const MachineInstr &I) {
1311   unsigned Flags = SPIRV::FPFastMathMode::None;
1312   if (I.getFlag(MachineInstr::MIFlag::FmNoNans))
1313     Flags |= SPIRV::FPFastMathMode::NotNaN;
1314   if (I.getFlag(MachineInstr::MIFlag::FmNoInfs))
1315     Flags |= SPIRV::FPFastMathMode::NotInf;
1316   if (I.getFlag(MachineInstr::MIFlag::FmNsz))
1317     Flags |= SPIRV::FPFastMathMode::NSZ;
1318   if (I.getFlag(MachineInstr::MIFlag::FmArcp))
1319     Flags |= SPIRV::FPFastMathMode::AllowRecip;
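  // Core FPFastMathMode has no dedicated reassociation bit, so reassociation
  // maps to the catch-all Fast mode.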
1320   if (I.getFlag(MachineInstr::MIFlag::FmReassoc))
1321     Flags |= SPIRV::FPFastMathMode::Fast;
1322   return Flags;
1323 }
1324 
1325 static void handleMIFlagDecoration(MachineInstr &I, const SPIRVSubtarget &ST,
1326                                    const SPIRVInstrInfo &TII,
1327                                    SPIRV::RequirementHandler &Reqs) {
1328   if (I.getFlag(MachineInstr::MIFlag::NoSWrap) && TII.canUseNSW(I) &&
1329       getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,
1330                                      SPIRV::Decoration::NoSignedWrap, ST, Reqs)
1331           .IsSatisfiable) {
1332     buildOpDecorate(I.getOperand(0).getReg(), I, TII,
1333                     SPIRV::Decoration::NoSignedWrap, {});
1334   }
1335   if (I.getFlag(MachineInstr::MIFlag::NoUWrap) && TII.canUseNUW(I) &&
1336       getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,
1337                                      SPIRV::Decoration::NoUnsignedWrap, ST,
1338                                      Reqs)
1339           .IsSatisfiable) {
1340     buildOpDecorate(I.getOperand(0).getReg(), I, TII,
1341                     SPIRV::Decoration::NoUnsignedWrap, {});
1342   }
1343   if (!TII.canUseFastMathFlags(I))
1344     return;
1345   unsigned FMFlags = getFastMathFlags(I);
1346   if (FMFlags == SPIRV::FPFastMathMode::None)
1347     return;
1348   Register DstReg = I.getOperand(0).getReg();
1349   buildOpDecorate(DstReg, I, TII, SPIRV::Decoration::FPFastMathMode, {FMFlags});
1350 }
1351 
1352 // Walk all functions and add decorations related to MI flags.
1353 static void addDecorations(const Module &M, const SPIRVInstrInfo &TII,
1354                            MachineModuleInfo *MMI, const SPIRVSubtarget &ST,
1355                            SPIRV::ModuleAnalysisInfo &MAI) {
1356   for (auto F = M.begin(), E = M.end(); F != E; ++F) {
1357     MachineFunction *MF = MMI->getMachineFunction(*F);
1358     if (!MF)
1359       continue;
1360     for (auto &MBB : *MF)
1361       for (auto &MI : MBB)
1362         handleMIFlagDecoration(MI, ST, TII, MAI.Reqs);
1363   }
1364 }
1365 
1366 struct SPIRV::ModuleAnalysisInfo SPIRVModuleAnalysis::MAI;
1367 
1368 void SPIRVModuleAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {
1369   AU.addRequired<TargetPassConfig>();
1370   AU.addRequired<MachineModuleInfoWrapperPass>();
1371 }
1372 
1373 bool SPIRVModuleAnalysis::runOnModule(Module &M) {
1374   SPIRVTargetMachine &TM =
1375       getAnalysis<TargetPassConfig>().getTM<SPIRVTargetMachine>();
1376   ST = TM.getSubtargetImpl();
1377   GR = ST->getSPIRVGlobalRegistry();
1378   TII = ST->getInstrInfo();
1379 
1380   MMI = &getAnalysis<MachineModuleInfoWrapperPass>().getMMI();
1381 
1382   setBaseInfo(M);
1383 
1384   addDecorations(M, *TII, MMI, *ST, MAI);
1385 
1386   collectReqs(M, MAI, MMI, *ST);
1387 
1388   // Process type/const/global var/func decl instructions, number their
1389   // destination registers from 0 to N, collect Extensions and Capabilities.
1390   processDefInstrs(M);
1391 
1392   // Number the rest of the registers from N+1 onwards.
1393   numberRegistersGlobally(M);
1394 
1395   // Update references to OpFunction instructions to use global registers.
1396   if (GR->hasConstFunPtr())
1397     collectFuncPtrs();
1398 
1399   // Collect OpName, OpEntryPoint, OpDecorate etc, process other instructions.
1400   processOtherInstrs(M);
1401 
1402   // If there are no entry points, we need the Linkage capability.
1403   if (MAI.MS[SPIRV::MB_EntryPoints].empty())
1404     MAI.Reqs.addCapability(SPIRV::Capability::Linkage);
1405 
1406   // Set maximum ID used.
1407   GR->setBound(MAI.MaxID);
1408 
1409   return false;
1410 }
1411