//===- SPIRVModuleAnalysis.cpp - analysis of global instrs & regs - C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// The analysis collects instructions that should be output at the module level
// and performs the global register numbering.
//
// The results of this analysis are used in AsmPrinter to rename registers
// globally and to output required instructions at the module level.
//
//===----------------------------------------------------------------------===//

#include "SPIRVModuleAnalysis.h"
#include "MCTargetDesc/SPIRVBaseInfo.h"
#include "MCTargetDesc/SPIRVMCTargetDesc.h"
#include "SPIRV.h"
#include "SPIRVSubtarget.h"
#include "SPIRVTargetMachine.h"
#include "SPIRVUtils.h"
#include "TargetInfo/SPIRVTargetInfo.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
#include "llvm/CodeGen/TargetPassConfig.h"

using namespace llvm;

#define DEBUG_TYPE "spirv-module-analysis"

static cl::opt<bool>
    SPVDumpDeps("spv-dump-deps",
                cl::desc("Dump MIR with SPIR-V dependencies info"),
                cl::Optional, cl::init(false));

static cl::list<SPIRV::Capability::Capability>
    AvoidCapabilities("avoid-spirv-capabilities",
                      cl::desc("SPIR-V capabilities to avoid if there are "
                               "other options enabling a feature"),
                      cl::ZeroOrMore, cl::Hidden,
                      cl::values(clEnumValN(SPIRV::Capability::Shader, "Shader",
                                            "SPIR-V Shader capability")));
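// A sketch of how this option might be used, assuming the backend is driven
// through the standard llc tool (hypothetical invocation, shown only to
// illustrate the flag declared above):
//   llc -mtriple=spirv64-unknown-unknown -avoid-spirv-capabilities=Shader ...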
// Use a set instead of cl::list so that a fast "contains" check is available.
struct AvoidCapabilitiesSet {
  SmallSet<SPIRV::Capability::Capability, 4> S;
  AvoidCapabilitiesSet() {
    for (auto Cap : AvoidCapabilities)
      S.insert(Cap);
  }
};

char llvm::SPIRVModuleAnalysis::ID = 0;

namespace llvm {
void initializeSPIRVModuleAnalysisPass(PassRegistry &);
} // namespace llvm

INITIALIZE_PASS(SPIRVModuleAnalysis, DEBUG_TYPE, "SPIRV module analysis", true,
                true)

// Retrieve an unsigned value from an MDNode whose operands are a list of them.
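// For example, given metadata such as !{i32 2, i32 0}: index 0 yields 2,
// index 1 yields 0, and an out-of-range index (or a null node) falls back to
// DefaultVal.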
static unsigned getMetadataUInt(MDNode *MdNode, unsigned OpIndex,
                                unsigned DefaultVal = 0) {
  if (MdNode && OpIndex < MdNode->getNumOperands()) {
    const auto &Op = MdNode->getOperand(OpIndex);
    return mdconst::extract<ConstantInt>(Op)->getZExtValue();
  }
  return DefaultVal;
}

static SPIRV::Requirements
getSymbolicOperandRequirements(SPIRV::OperandCategory::OperandCategory Category,
                               unsigned i, const SPIRVSubtarget &ST,
                               SPIRV::RequirementHandler &Reqs) {
  static AvoidCapabilitiesSet
      AvoidCaps; // contains capabilities to avoid if there is another option
  unsigned ReqMinVer = getSymbolicOperandMinVersion(Category, i);
  unsigned ReqMaxVer = getSymbolicOperandMaxVersion(Category, i);
  unsigned TargetVer = ST.getSPIRVVersion();
  bool MinVerOK = !ReqMinVer || !TargetVer || TargetVer >= ReqMinVer;
  bool MaxVerOK = !ReqMaxVer || !TargetVer || TargetVer <= ReqMaxVer;
  CapabilityList ReqCaps = getSymbolicOperandCapabilities(Category, i);
  ExtensionList ReqExts = getSymbolicOperandExtensions(Category, i);
  if (ReqCaps.empty()) {
    if (ReqExts.empty()) {
      if (MinVerOK && MaxVerOK)
        return {true, {}, {}, ReqMinVer, ReqMaxVer};
      return {false, {}, {}, 0, 0};
    }
  } else if (MinVerOK && MaxVerOK) {
    if (ReqCaps.size() == 1) {
      auto Cap = ReqCaps[0];
      if (Reqs.isCapabilityAvailable(Cap))
        return {true, {Cap}, {}, ReqMinVer, ReqMaxVer};
    } else {
      // By SPIR-V specification: "If an instruction, enumerant, or other
      // feature specifies multiple enabling capabilities, only one such
      // capability needs to be declared to use the feature." However, one
      // capability may be preferred over another. We use command line
      // argument(s) and AvoidCapabilities to avoid selection of certain
      // capabilities if there are other options.
      CapabilityList UseCaps;
      for (auto Cap : ReqCaps)
        if (Reqs.isCapabilityAvailable(Cap))
          UseCaps.push_back(Cap);
      for (size_t i = 0, Sz = UseCaps.size(); i < Sz; ++i) {
        auto Cap = UseCaps[i];
        if (i == Sz - 1 || !AvoidCaps.S.contains(Cap))
          return {true, {Cap}, {}, ReqMinVer, ReqMaxVer};
      }
    }
  }
  // If there are no capabilities, or we can't satisfy the version or
  // capability requirements, use the list of extensions (if the subtarget
  // can handle them all).
  if (llvm::all_of(ReqExts, [&ST](const SPIRV::Extension::Extension &Ext) {
        return ST.canUseExtension(Ext);
      })) {
    return {true, {}, ReqExts, 0, 0}; // TODO: add versions to extensions.
  }
  return {false, {}, {}, 0, 0};
}

void SPIRVModuleAnalysis::setBaseInfo(const Module &M) {
  MAI.MaxID = 0;
  for (int i = 0; i < SPIRV::NUM_MODULE_SECTIONS; i++)
    MAI.MS[i].clear();
  MAI.RegisterAliasTable.clear();
  MAI.InstrsToDelete.clear();
  MAI.FuncMap.clear();
  MAI.GlobalVarList.clear();
  MAI.ExtInstSetMap.clear();
  MAI.Reqs.clear();
  MAI.Reqs.initAvailableCapabilities(*ST);

  // TODO: determine memory model and source language from the configuration.
  if (auto MemModel = M.getNamedMetadata("spirv.MemoryModel")) {
    auto MemMD = MemModel->getOperand(0);
    MAI.Addr = static_cast<SPIRV::AddressingModel::AddressingModel>(
        getMetadataUInt(MemMD, 0));
    MAI.Mem =
        static_cast<SPIRV::MemoryModel::MemoryModel>(getMetadataUInt(MemMD, 1));
  } else {
    // TODO: Add support for VulkanMemoryModel.
    MAI.Mem = ST->isOpenCLEnv() ? SPIRV::MemoryModel::OpenCL
                                : SPIRV::MemoryModel::GLSL450;
    if (MAI.Mem == SPIRV::MemoryModel::OpenCL) {
      unsigned PtrSize = ST->getPointerSize();
      MAI.Addr = PtrSize == 32   ? SPIRV::AddressingModel::Physical32
                 : PtrSize == 64 ? SPIRV::AddressingModel::Physical64
                                 : SPIRV::AddressingModel::Logical;
    } else {
      // TODO: Add support for PhysicalStorageBufferAddress.
      MAI.Addr = SPIRV::AddressingModel::Logical;
    }
  }
  // Get the OpenCL version number from metadata.
  // TODO: support other source languages.
  if (auto VerNode = M.getNamedMetadata("opencl.ocl.version")) {
    MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_C;
    // Construct the version literal in accordance with SPIRV-LLVM-Translator.
    // TODO: support multiple OCL version metadata.
    assert(VerNode->getNumOperands() > 0 && "Invalid SPIR");
    auto VersionMD = VerNode->getOperand(0);
    unsigned MajorNum = getMetadataUInt(VersionMD, 0, 2);
    unsigned MinorNum = getMetadataUInt(VersionMD, 1);
    unsigned RevNum = getMetadataUInt(VersionMD, 2);
    MAI.SrcLangVersion = (MajorNum * 100 + MinorNum) * 1000 + RevNum;
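    // For example, OpenCL C 2.0.0 (MajorNum = 2, MinorNum = 0, RevNum = 0) is
    // encoded as (2 * 100 + 0) * 1000 + 0 = 200000.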
  } else {
    MAI.SrcLang = SPIRV::SourceLanguage::Unknown;
    MAI.SrcLangVersion = 0;
  }

  if (auto ExtNode = M.getNamedMetadata("opencl.used.extensions")) {
    for (unsigned I = 0, E = ExtNode->getNumOperands(); I != E; ++I) {
      MDNode *MD = ExtNode->getOperand(I);
      if (!MD || MD->getNumOperands() == 0)
        continue;
      for (unsigned J = 0, N = MD->getNumOperands(); J != N; ++J)
        MAI.SrcExt.insert(cast<MDString>(MD->getOperand(J))->getString());
    }
  }

  // Update required capabilities for this memory model, addressing model and
  // source language.
  MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand,
                                 MAI.Mem, *ST);
  MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::SourceLanguageOperand,
                                 MAI.SrcLang, *ST);
  MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,
                                 MAI.Addr, *ST);

  if (ST->isOpenCLEnv()) {
    // TODO: check if it's required by default.
    MAI.ExtInstSetMap[static_cast<unsigned>(
        SPIRV::InstructionSet::OpenCL_std)] =
        Register::index2VirtReg(MAI.getNextID());
  }
}

// Collect the MI that defines the register in the given machine function.
static void collectDefInstr(Register Reg, const MachineFunction *MF,
                            SPIRV::ModuleAnalysisInfo *MAI,
                            SPIRV::ModuleSectionType MSType,
                            bool DoInsert = true) {
  assert(MAI->hasRegisterAlias(MF, Reg) && "Cannot find register alias");
  MachineInstr *MI = MF->getRegInfo().getUniqueVRegDef(Reg);
  assert(MI && "There should be an instruction that defines the register");
  MAI->setSkipEmission(MI);
  if (DoInsert)
    MAI->MS[MSType].push_back(MI);
}

void SPIRVModuleAnalysis::collectGlobalEntities(
    const std::vector<SPIRV::DTSortableEntry *> &DepsGraph,
    SPIRV::ModuleSectionType MSType,
    std::function<bool(const SPIRV::DTSortableEntry *)> Pred,
    bool UsePreOrder = false) {
  DenseSet<const SPIRV::DTSortableEntry *> Visited;
  for (const auto *E : DepsGraph) {
    std::function<void(const SPIRV::DTSortableEntry *)> RecHoistUtil;
    // NOTE: we prefer the recursive approach over an iterative one because we
    // don't expect dependency chains long enough to cause a stack overflow.
    RecHoistUtil = [MSType, UsePreOrder, &Visited, &Pred,
                    &RecHoistUtil](const SPIRV::DTSortableEntry *E) {
      if (Visited.count(E) || !Pred(E))
        return;
      Visited.insert(E);

      // Traversing the deps graph in post-order lets us avoid a separate
      // register-alias preprocessing step, but pre-order is required for
      // correct processing of function declarations and their arguments.
      if (!UsePreOrder)
        for (auto *S : E->getDeps())
          RecHoistUtil(S);

      Register GlobalReg = Register::index2VirtReg(MAI.getNextID());
      bool IsFirst = true;
      for (auto &U : *E) {
        const MachineFunction *MF = U.first;
        Register Reg = U.second;
        MAI.setRegisterAlias(MF, Reg, GlobalReg);
        if (!MF->getRegInfo().getUniqueVRegDef(Reg))
          continue;
        collectDefInstr(Reg, MF, &MAI, MSType, IsFirst);
        IsFirst = false;
        if (E->getIsGV())
          MAI.GlobalVarList.push_back(MF->getRegInfo().getUniqueVRegDef(Reg));
      }

      if (UsePreOrder)
        for (auto *S : E->getDeps())
          RecHoistUtil(S);
    };
    RecHoistUtil(E);
  }
}

// This function initializes the global register alias table for types,
// constants, global variables and function declarations, and collects these
// instructions for output at the module level. It also collects explicit
// OpExtension/OpCapability instructions.
void SPIRVModuleAnalysis::processDefInstrs(const Module &M) {
  std::vector<SPIRV::DTSortableEntry *> DepsGraph;

  GR->buildDepsGraph(DepsGraph, SPVDumpDeps ? MMI : nullptr);

  collectGlobalEntities(
      DepsGraph, SPIRV::MB_TypeConstVars,
      [](const SPIRV::DTSortableEntry *E) { return !E->getIsFunc(); });

  for (auto F = M.begin(), E = M.end(); F != E; ++F) {
    MachineFunction *MF = MMI->getMachineFunction(*F);
    if (!MF)
      continue;
    // Iterate through and collect OpExtension/OpCapability instructions.
    for (MachineBasicBlock &MBB : *MF) {
      for (MachineInstr &MI : MBB) {
        if (MI.getOpcode() == SPIRV::OpExtension) {
          // Here, OpExtension just has a single enum operand, not a string.
          auto Ext = SPIRV::Extension::Extension(MI.getOperand(0).getImm());
          MAI.Reqs.addExtension(Ext);
          MAI.setSkipEmission(&MI);
        } else if (MI.getOpcode() == SPIRV::OpCapability) {
          auto Cap = SPIRV::Capability::Capability(MI.getOperand(0).getImm());
          MAI.Reqs.addCapability(Cap);
          MAI.setSkipEmission(&MI);
        }
      }
    }
  }

  collectGlobalEntities(
      DepsGraph, SPIRV::MB_ExtFuncDecls,
      [](const SPIRV::DTSortableEntry *E) { return E->getIsFunc(); }, true);
}

// Look for IDs declared with Import linkage, and map the corresponding function
// to the register defining that variable (which will usually be the result of
// an OpFunction). This lets us call externally imported functions using
// the correct ID registers.
void SPIRVModuleAnalysis::collectFuncNames(MachineInstr &MI,
                                           const Function *F) {
  if (MI.getOpcode() == SPIRV::OpDecorate) {
    // If it's got Import linkage.
    auto Dec = MI.getOperand(1).getImm();
    if (Dec == static_cast<unsigned>(SPIRV::Decoration::LinkageAttributes)) {
      auto Lnk = MI.getOperand(MI.getNumOperands() - 1).getImm();
      if (Lnk == static_cast<unsigned>(SPIRV::LinkageType::Import)) {
        // Map imported function name to function ID register.
        const Function *ImportedFunc =
            F->getParent()->getFunction(getStringImm(MI, 2));
        Register Target = MI.getOperand(0).getReg();
        MAI.FuncMap[ImportedFunc] = MAI.getRegisterAlias(MI.getMF(), Target);
      }
    }
  } else if (MI.getOpcode() == SPIRV::OpFunction) {
    // Record all internal OpFunction declarations.
    Register Reg = MI.defs().begin()->getReg();
    Register GlobalReg = MAI.getRegisterAlias(MI.getMF(), Reg);
    assert(GlobalReg.isValid());
    MAI.FuncMap[F] = GlobalReg;
  }
}

// References to a function via function pointers generate virtual registers
// without a definition. We resolve such references through the Global Registry
// (GR) into an OpFunction instruction and replace the dummy operands with the
// corresponding global register references.
void SPIRVModuleAnalysis::collectFuncPtrs() {
  for (auto &MI : MAI.MS[SPIRV::MB_TypeConstVars])
    if (MI->getOpcode() == SPIRV::OpConstantFunctionPointerINTEL)
      collectFuncPtrs(MI);
}

void SPIRVModuleAnalysis::collectFuncPtrs(MachineInstr *MI) {
  const MachineOperand *FunUse = &MI->getOperand(2);
  if (const MachineOperand *FunDef = GR->getFunctionDefinitionByUse(FunUse)) {
    const MachineInstr *FunDefMI = FunDef->getParent();
    assert(FunDefMI->getOpcode() == SPIRV::OpFunction &&
           "Constant function pointer must refer to function definition");
    Register FunDefReg = FunDef->getReg();
    Register GlobalFunDefReg =
        MAI.getRegisterAlias(FunDefMI->getMF(), FunDefReg);
    assert(GlobalFunDefReg.isValid() &&
           "Function definition must refer to a global register");
    Register FunPtrReg = FunUse->getReg();
    MAI.setRegisterAlias(MI->getMF(), FunPtrReg, GlobalFunDefReg);
  }
}

using InstrSignature = SmallVector<size_t>;
using InstrTraces = std::set<InstrSignature>;

// Returns a representation of an instruction as a vector of MachineOperand
// hash values, see llvm::hash_value(const MachineOperand &MO) for details.
// This creates a signature of the instruction with the same content
// that MachineOperand::isIdenticalTo uses for comparison.
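// Because register operands are hashed through their *global* aliases, two
// otherwise identical module-level instructions emitted in different machine
// functions produce the same signature, which is what lets collectOtherInstr
// below detect and drop duplicates.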
static InstrSignature instrToSignature(MachineInstr &MI,
                                       SPIRV::ModuleAnalysisInfo &MAI) {
  InstrSignature Signature;
  for (unsigned i = 0; i < MI.getNumOperands(); ++i) {
    const MachineOperand &MO = MI.getOperand(i);
    size_t h;
    if (MO.isReg()) {
      Register RegAlias = MAI.getRegisterAlias(MI.getMF(), MO.getReg());
      // Mimic llvm::hash_value(const MachineOperand &MO).
      h = hash_combine(MO.getType(), (unsigned)RegAlias, MO.getSubReg(),
                       MO.isDef());
    } else {
      h = hash_value(MO);
    }
    Signature.push_back(h);
  }
  return Signature;
}

// Collect the given instruction in the specified MS. We assume global register
// numbering has already occurred by this point. We can directly compare reg
// arguments when detecting duplicates.
static void collectOtherInstr(MachineInstr &MI, SPIRV::ModuleAnalysisInfo &MAI,
                              SPIRV::ModuleSectionType MSType, InstrTraces &IS,
                              bool Append = true) {
  MAI.setSkipEmission(&MI);
  InstrSignature MISign = instrToSignature(MI, MAI);
  auto FoundMI = IS.insert(MISign);
  if (!FoundMI.second)
    return; // insert failed, so we found a duplicate; don't add it to MAI.MS
  // No duplicates, so add it.
  if (Append)
    MAI.MS[MSType].push_back(&MI);
  else
    MAI.MS[MSType].insert(MAI.MS[MSType].begin(), &MI);
}

// Some global instructions make reference to function-local ID regs, so they
// cannot be correctly collected until these registers are globally numbered.
void SPIRVModuleAnalysis::processOtherInstrs(const Module &M) {
  InstrTraces IS;
  for (auto F = M.begin(), E = M.end(); F != E; ++F) {
    if ((*F).isDeclaration())
      continue;
    MachineFunction *MF = MMI->getMachineFunction(*F);
    assert(MF);
    for (MachineBasicBlock &MBB : *MF)
      for (MachineInstr &MI : MBB) {
        if (MAI.getSkipEmission(&MI))
          continue;
        const unsigned OpCode = MI.getOpcode();
        if (OpCode == SPIRV::OpName || OpCode == SPIRV::OpMemberName) {
          collectOtherInstr(MI, MAI, SPIRV::MB_DebugNames, IS);
        } else if (OpCode == SPIRV::OpEntryPoint) {
          collectOtherInstr(MI, MAI, SPIRV::MB_EntryPoints, IS);
        } else if (TII->isDecorationInstr(MI)) {
          collectOtherInstr(MI, MAI, SPIRV::MB_Annotations, IS);
          collectFuncNames(MI, &*F);
        } else if (TII->isConstantInstr(MI)) {
          // OpSpecConstant* instructions are not in the DT at the moment,
          // but they still need to be collected.
          collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, IS);
        } else if (OpCode == SPIRV::OpFunction) {
          collectFuncNames(MI, &*F);
        } else if (OpCode == SPIRV::OpTypeForwardPointer) {
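          // Prepend (Append = false) so the forward pointer declaration ends
          // up before the type and constant instructions that may refer to it.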
          collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, IS, false);
        }
      }
  }
}

// Number registers in all functions globally from 0 onwards and store
// the result in the global register alias table. Some registers are already
// numbered in collectGlobalEntities.
void SPIRVModuleAnalysis::numberRegistersGlobally(const Module &M) {
  for (auto F = M.begin(), E = M.end(); F != E; ++F) {
    if ((*F).isDeclaration())
      continue;
    MachineFunction *MF = MMI->getMachineFunction(*F);
    assert(MF);
    for (MachineBasicBlock &MBB : *MF) {
      for (MachineInstr &MI : MBB) {
        for (MachineOperand &Op : MI.operands()) {
          if (!Op.isReg())
            continue;
          Register Reg = Op.getReg();
          if (MAI.hasRegisterAlias(MF, Reg))
            continue;
          Register NewReg = Register::index2VirtReg(MAI.getNextID());
          MAI.setRegisterAlias(MF, Reg, NewReg);
        }
        if (MI.getOpcode() != SPIRV::OpExtInst)
          continue;
        auto Set = MI.getOperand(2).getImm();
        if (!MAI.ExtInstSetMap.contains(Set))
          MAI.ExtInstSetMap[Set] = Register::index2VirtReg(MAI.getNextID());
      }
    }
  }
}

// RequirementHandler implementations.
void SPIRV::RequirementHandler::getAndAddRequirements(
    SPIRV::OperandCategory::OperandCategory Category, uint32_t i,
    const SPIRVSubtarget &ST) {
  addRequirements(getSymbolicOperandRequirements(Category, i, ST, *this));
}

void SPIRV::RequirementHandler::recursiveAddCapabilities(
    const CapabilityList &ToPrune) {
  for (const auto &Cap : ToPrune) {
    AllCaps.insert(Cap);
    CapabilityList ImplicitDecls =
        getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);
    recursiveAddCapabilities(ImplicitDecls);
  }
}

void SPIRV::RequirementHandler::addCapabilities(const CapabilityList &ToAdd) {
  for (const auto &Cap : ToAdd) {
    bool IsNewlyInserted = AllCaps.insert(Cap).second;
    if (!IsNewlyInserted) // Don't re-add if it's already been declared.
      continue;
    CapabilityList ImplicitDecls =
        getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);
    recursiveAddCapabilities(ImplicitDecls);
    MinimalCaps.push_back(Cap);
  }
}

void SPIRV::RequirementHandler::addRequirements(
    const SPIRV::Requirements &Req) {
  if (!Req.IsSatisfiable)
    report_fatal_error("Adding SPIR-V requirements this target can't satisfy.");

  if (Req.Cap.has_value())
    addCapabilities({Req.Cap.value()});

  addExtensions(Req.Exts);

  if (Req.MinVer) {
    if (MaxVersion && Req.MinVer > MaxVersion) {
      LLVM_DEBUG(dbgs() << "Conflicting version requirements: >= " << Req.MinVer
                        << " and <= " << MaxVersion << "\n");
      report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");
    }

    if (MinVersion == 0 || Req.MinVer > MinVersion)
      MinVersion = Req.MinVer;
  }

  if (Req.MaxVer) {
    if (MinVersion && Req.MaxVer < MinVersion) {
      LLVM_DEBUG(dbgs() << "Conflicting version requirements: <= " << Req.MaxVer
                        << " and >= " << MinVersion << "\n");
      report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");
    }

    if (MaxVersion == 0 || Req.MaxVer < MaxVersion)
      MaxVersion = Req.MaxVer;
  }
}

void SPIRV::RequirementHandler::checkSatisfiable(
    const SPIRVSubtarget &ST) const {
  // Report as many errors as possible before aborting the compilation.
  bool IsSatisfiable = true;
  auto TargetVer = ST.getSPIRVVersion();

  if (MaxVersion && TargetVer && MaxVersion < TargetVer) {
    LLVM_DEBUG(
        dbgs() << "Target SPIR-V version too high for required features\n"
               << "Required max version: " << MaxVersion << " target version "
               << TargetVer << "\n");
    IsSatisfiable = false;
  }

  if (MinVersion && TargetVer && MinVersion > TargetVer) {
    LLVM_DEBUG(dbgs() << "Target SPIR-V version too low for required features\n"
                      << "Required min version: " << MinVersion
                      << " target version " << TargetVer << "\n");
    IsSatisfiable = false;
  }

  if (MinVersion && MaxVersion && MinVersion > MaxVersion) {
    LLVM_DEBUG(
        dbgs()
        << "Version is too low for some features and too high for others.\n"
        << "Required SPIR-V min version: " << MinVersion
        << " required SPIR-V max version " << MaxVersion << "\n");
    IsSatisfiable = false;
  }

  for (auto Cap : MinimalCaps) {
    if (AvailableCaps.contains(Cap))
      continue;
    LLVM_DEBUG(dbgs() << "Capability not supported: "
                      << getSymbolicOperandMnemonic(
                             OperandCategory::CapabilityOperand, Cap)
                      << "\n");
    IsSatisfiable = false;
  }

  for (auto Ext : AllExtensions) {
    if (ST.canUseExtension(Ext))
      continue;
    LLVM_DEBUG(dbgs() << "Extension not supported: "
                      << getSymbolicOperandMnemonic(
                             OperandCategory::ExtensionOperand, Ext)
                      << "\n");
    IsSatisfiable = false;
  }

  if (!IsSatisfiable)
    report_fatal_error("Unable to meet SPIR-V requirements for this target.");
}

// Add the given capabilities and all their implicitly defined capabilities too.
void SPIRV::RequirementHandler::addAvailableCaps(const CapabilityList &ToAdd) {
  for (const auto Cap : ToAdd)
    if (AvailableCaps.insert(Cap).second)
      addAvailableCaps(getSymbolicOperandCapabilities(
          SPIRV::OperandCategory::CapabilityOperand, Cap));
}

void SPIRV::RequirementHandler::removeCapabilityIf(
    const Capability::Capability ToRemove,
    const Capability::Capability IfPresent) {
  if (AllCaps.contains(IfPresent))
    AllCaps.erase(ToRemove);
}

namespace llvm {
namespace SPIRV {
void RequirementHandler::initAvailableCapabilities(const SPIRVSubtarget &ST) {
  if (ST.isOpenCLEnv()) {
    initAvailableCapabilitiesForOpenCL(ST);
    return;
  }

  if (ST.isVulkanEnv()) {
    initAvailableCapabilitiesForVulkan(ST);
    return;
  }

  report_fatal_error("Unimplemented environment for SPIR-V generation.");
}

void RequirementHandler::initAvailableCapabilitiesForOpenCL(
    const SPIRVSubtarget &ST) {
  // Add the min requirements for different OpenCL and SPIR-V versions.
  addAvailableCaps({Capability::Addresses, Capability::Float16Buffer,
                    Capability::Int16, Capability::Int8, Capability::Kernel,
                    Capability::Linkage, Capability::Vector16,
                    Capability::Groups, Capability::GenericPointer,
                    Capability::Shader});
  if (ST.hasOpenCLFullProfile())
    addAvailableCaps({Capability::Int64, Capability::Int64Atomics});
  if (ST.hasOpenCLImageSupport()) {
    addAvailableCaps({Capability::ImageBasic, Capability::LiteralSampler,
                      Capability::Image1D, Capability::SampledBuffer,
                      Capability::ImageBuffer});
    if (ST.isAtLeastOpenCLVer(20))
      addAvailableCaps({Capability::ImageReadWrite});
  }
  if (ST.isAtLeastSPIRVVer(11) && ST.isAtLeastOpenCLVer(22))
    addAvailableCaps({Capability::SubgroupDispatch, Capability::PipeStorage});
  if (ST.isAtLeastSPIRVVer(13))
    addAvailableCaps({Capability::GroupNonUniform,
                      Capability::GroupNonUniformVote,
                      Capability::GroupNonUniformArithmetic,
                      Capability::GroupNonUniformBallot,
                      Capability::GroupNonUniformClustered,
                      Capability::GroupNonUniformShuffle,
                      Capability::GroupNonUniformShuffleRelative});
  if (ST.isAtLeastSPIRVVer(14))
    addAvailableCaps({Capability::DenormPreserve, Capability::DenormFlushToZero,
                      Capability::SignedZeroInfNanPreserve,
                      Capability::RoundingModeRTE,
                      Capability::RoundingModeRTZ});
  // TODO: verify if this needs some checks.
  addAvailableCaps({Capability::Float16, Capability::Float64});

  // Add capabilities enabled by extensions.
  for (auto Extension : ST.getAllAvailableExtensions()) {
    CapabilityList EnabledCapabilities =
        getCapabilitiesEnabledByExtension(Extension);
    addAvailableCaps(EnabledCapabilities);
  }

  // TODO: add OpenCL extensions.
}

void RequirementHandler::initAvailableCapabilitiesForVulkan(
    const SPIRVSubtarget &ST) {
  addAvailableCaps({Capability::Shader, Capability::Linkage});

  // Provided by all supported Vulkan versions.
  addAvailableCaps({Capability::Int16, Capability::Int64, Capability::Float16,
                    Capability::Float64});
}

} // namespace SPIRV
} // namespace llvm

// Add the required capabilities from a decoration instruction (including
// BuiltIns).
static void addOpDecorateReqs(const MachineInstr &MI, unsigned DecIndex,
                              SPIRV::RequirementHandler &Reqs,
                              const SPIRVSubtarget &ST) {
  int64_t DecOp = MI.getOperand(DecIndex).getImm();
  auto Dec = static_cast<SPIRV::Decoration::Decoration>(DecOp);
  Reqs.addRequirements(getSymbolicOperandRequirements(
      SPIRV::OperandCategory::DecorationOperand, Dec, ST, Reqs));

  if (Dec == SPIRV::Decoration::BuiltIn) {
    int64_t BuiltInOp = MI.getOperand(DecIndex + 1).getImm();
    auto BuiltIn = static_cast<SPIRV::BuiltIn::BuiltIn>(BuiltInOp);
    Reqs.addRequirements(getSymbolicOperandRequirements(
        SPIRV::OperandCategory::BuiltInOperand, BuiltIn, ST, Reqs));
  } else if (Dec == SPIRV::Decoration::LinkageAttributes) {
    int64_t LinkageOp = MI.getOperand(MI.getNumOperands() - 1).getImm();
    SPIRV::LinkageType::LinkageType LnkType =
        static_cast<SPIRV::LinkageType::LinkageType>(LinkageOp);
    if (LnkType == SPIRV::LinkageType::LinkOnceODR)
      Reqs.addExtension(SPIRV::Extension::SPV_KHR_linkonce_odr);
  }
}

// Add requirements for image handling.
static void addOpTypeImageReqs(const MachineInstr &MI,
                               SPIRV::RequirementHandler &Reqs,
                               const SPIRVSubtarget &ST) {
  assert(MI.getNumOperands() >= 8 && "Insufficient operands for OpTypeImage");
  // The operand indices used here are based on the OpTypeImage layout, which
  // the MachineInstr follows as well.
  int64_t ImgFormatOp = MI.getOperand(7).getImm();
  auto ImgFormat = static_cast<SPIRV::ImageFormat::ImageFormat>(ImgFormatOp);
  Reqs.getAndAddRequirements(SPIRV::OperandCategory::ImageFormatOperand,
                             ImgFormat, ST);

  bool IsArrayed = MI.getOperand(4).getImm() == 1;
  bool IsMultisampled = MI.getOperand(5).getImm() == 1;
  bool NoSampler = MI.getOperand(6).getImm() == 2;
  // Add dimension requirements.
  assert(MI.getOperand(2).isImm());
  switch (MI.getOperand(2).getImm()) {
  case SPIRV::Dim::DIM_1D:
    Reqs.addRequirements(NoSampler ? SPIRV::Capability::Image1D
                                   : SPIRV::Capability::Sampled1D);
    break;
  case SPIRV::Dim::DIM_2D:
    if (IsMultisampled && NoSampler)
      Reqs.addRequirements(SPIRV::Capability::ImageMSArray);
    break;
  case SPIRV::Dim::DIM_Cube:
    Reqs.addRequirements(SPIRV::Capability::Shader);
    if (IsArrayed)
      Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageCubeArray
                                     : SPIRV::Capability::SampledCubeArray);
    break;
  case SPIRV::Dim::DIM_Rect:
    Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageRect
                                   : SPIRV::Capability::SampledRect);
    break;
  case SPIRV::Dim::DIM_Buffer:
    Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageBuffer
                                   : SPIRV::Capability::SampledBuffer);
    break;
  case SPIRV::Dim::DIM_SubpassData:
    Reqs.addRequirements(SPIRV::Capability::InputAttachment);
    break;
  }

  // Has optional access qualifier.
  // TODO: check if it's OpenCL's kernel.
  if (MI.getNumOperands() > 8 &&
      MI.getOperand(8).getImm() == SPIRV::AccessQualifier::ReadWrite)
    Reqs.addRequirements(SPIRV::Capability::ImageReadWrite);
  else
    Reqs.addRequirements(SPIRV::Capability::ImageBasic);
}

// Add requirements for handling atomic float instructions.
#define ATOM_FLT_REQ_EXT_MSG(ExtName)                                          \
  "The atomic float instruction requires the following SPIR-V "                \
  "extension: SPV_EXT_shader_atomic_float" ExtName
static void AddAtomicFloatRequirements(const MachineInstr &MI,
                                       SPIRV::RequirementHandler &Reqs,
                                       const SPIRVSubtarget &ST) {
  assert(MI.getOperand(1).isReg() &&
         "Expect register operand in atomic float instruction");
  Register TypeReg = MI.getOperand(1).getReg();
  SPIRVType *TypeDef = MI.getMF()->getRegInfo().getVRegDef(TypeReg);
  if (TypeDef->getOpcode() != SPIRV::OpTypeFloat)
    report_fatal_error("Result type of an atomic float instruction must be a "
                       "floating-point type scalar");

  unsigned BitWidth = TypeDef->getOperand(1).getImm();
  unsigned Op = MI.getOpcode();
  if (Op == SPIRV::OpAtomicFAddEXT) {
    if (!ST.canUseExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add))
      report_fatal_error(ATOM_FLT_REQ_EXT_MSG("_add"), false);
    Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add);
    switch (BitWidth) {
    case 16:
      if (!ST.canUseExtension(
              SPIRV::Extension::SPV_EXT_shader_atomic_float16_add))
        report_fatal_error(ATOM_FLT_REQ_EXT_MSG("16_add"), false);
      Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float16_add);
      Reqs.addCapability(SPIRV::Capability::AtomicFloat16AddEXT);
      break;
    case 32:
      Reqs.addCapability(SPIRV::Capability::AtomicFloat32AddEXT);
      break;
    case 64:
      Reqs.addCapability(SPIRV::Capability::AtomicFloat64AddEXT);
      break;
    default:
      report_fatal_error(
          "Unexpected floating-point type width in atomic float instruction");
    }
  } else {
    if (!ST.canUseExtension(
            SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max))
      report_fatal_error(ATOM_FLT_REQ_EXT_MSG("_min_max"), false);
    Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max);
    switch (BitWidth) {
    case 16:
      Reqs.addCapability(SPIRV::Capability::AtomicFloat16MinMaxEXT);
      break;
    case 32:
      Reqs.addCapability(SPIRV::Capability::AtomicFloat32MinMaxEXT);
      break;
    case 64:
      Reqs.addCapability(SPIRV::Capability::AtomicFloat64MinMaxEXT);
      break;
    default:
      report_fatal_error(
          "Unexpected floating-point type width in atomic float instruction");
    }
  }
}

void addInstrRequirements(const MachineInstr &MI,
                          SPIRV::RequirementHandler &Reqs,
                          const SPIRVSubtarget &ST) {
  switch (MI.getOpcode()) {
  case SPIRV::OpMemoryModel: {
    int64_t Addr = MI.getOperand(0).getImm();
    Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,
                               Addr, ST);
    int64_t Mem = MI.getOperand(1).getImm();
    Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand, Mem,
                               ST);
    break;
  }
  case SPIRV::OpEntryPoint: {
    int64_t Exe = MI.getOperand(0).getImm();
    Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModelOperand,
                               Exe, ST);
    break;
  }
  case SPIRV::OpExecutionMode:
  case SPIRV::OpExecutionModeId: {
    int64_t Exe = MI.getOperand(1).getImm();
    Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModeOperand,
                               Exe, ST);
    break;
  }
  case SPIRV::OpTypeMatrix:
    Reqs.addCapability(SPIRV::Capability::Matrix);
    break;
  case SPIRV::OpTypeInt: {
    unsigned BitWidth = MI.getOperand(1).getImm();
    if (BitWidth == 64)
      Reqs.addCapability(SPIRV::Capability::Int64);
    else if (BitWidth == 16)
      Reqs.addCapability(SPIRV::Capability::Int16);
    else if (BitWidth == 8)
      Reqs.addCapability(SPIRV::Capability::Int8);
    break;
  }
  case SPIRV::OpTypeFloat: {
    unsigned BitWidth = MI.getOperand(1).getImm();
    if (BitWidth == 64)
      Reqs.addCapability(SPIRV::Capability::Float64);
    else if (BitWidth == 16)
      Reqs.addCapability(SPIRV::Capability::Float16);
    break;
  }
  case SPIRV::OpTypeVector: {
    unsigned NumComponents = MI.getOperand(2).getImm();
    if (NumComponents == 8 || NumComponents == 16)
      Reqs.addCapability(SPIRV::Capability::Vector16);
    break;
  }
  case SPIRV::OpTypePointer: {
    auto SC = MI.getOperand(1).getImm();
    Reqs.getAndAddRequirements(SPIRV::OperandCategory::StorageClassOperand, SC,
                               ST);
    // If it's a pointer to a float16 type and we target OpenCL, add the
    // Float16Buffer capability.
    if (!ST.isOpenCLEnv())
      break;
    assert(MI.getOperand(2).isReg());
    const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
    SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(2).getReg());
    if (TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
        TypeDef->getOperand(1).getImm() == 16)
      Reqs.addCapability(SPIRV::Capability::Float16Buffer);
    break;
  }
  case SPIRV::OpBitReverse:
  case SPIRV::OpBitFieldInsert:
  case SPIRV::OpBitFieldSExtract:
  case SPIRV::OpBitFieldUExtract:
    if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions)) {
      Reqs.addCapability(SPIRV::Capability::Shader);
      break;
    }
    Reqs.addExtension(SPIRV::Extension::SPV_KHR_bit_instructions);
    Reqs.addCapability(SPIRV::Capability::BitInstructions);
    break;
  case SPIRV::OpTypeRuntimeArray:
    Reqs.addCapability(SPIRV::Capability::Shader);
    break;
  case SPIRV::OpTypeOpaque:
  case SPIRV::OpTypeEvent:
    Reqs.addCapability(SPIRV::Capability::Kernel);
    break;
  case SPIRV::OpTypePipe:
  case SPIRV::OpTypeReserveId:
    Reqs.addCapability(SPIRV::Capability::Pipes);
    break;
  case SPIRV::OpTypeDeviceEvent:
  case SPIRV::OpTypeQueue:
  case SPIRV::OpBuildNDRange:
    Reqs.addCapability(SPIRV::Capability::DeviceEnqueue);
    break;
  case SPIRV::OpDecorate:
  case SPIRV::OpDecorateId:
  case SPIRV::OpDecorateString:
    addOpDecorateReqs(MI, 1, Reqs, ST);
    break;
  case SPIRV::OpMemberDecorate:
  case SPIRV::OpMemberDecorateString:
    addOpDecorateReqs(MI, 2, Reqs, ST);
    break;
  case SPIRV::OpInBoundsPtrAccessChain:
    Reqs.addCapability(SPIRV::Capability::Addresses);
    break;
  case SPIRV::OpConstantSampler:
    Reqs.addCapability(SPIRV::Capability::LiteralSampler);
    break;
  case SPIRV::OpTypeImage:
    addOpTypeImageReqs(MI, Reqs, ST);
    break;
  case SPIRV::OpTypeSampler:
    Reqs.addCapability(SPIRV::Capability::ImageBasic);
    break;
  case SPIRV::OpTypeForwardPointer:
    // TODO: check if it's OpenCL's kernel.
    Reqs.addCapability(SPIRV::Capability::Addresses);
    break;
  case SPIRV::OpAtomicFlagTestAndSet:
  case SPIRV::OpAtomicLoad:
  case SPIRV::OpAtomicStore:
  case SPIRV::OpAtomicExchange:
  case SPIRV::OpAtomicCompareExchange:
  case SPIRV::OpAtomicIIncrement:
  case SPIRV::OpAtomicIDecrement:
  case SPIRV::OpAtomicIAdd:
  case SPIRV::OpAtomicISub:
  case SPIRV::OpAtomicUMin:
  case SPIRV::OpAtomicUMax:
  case SPIRV::OpAtomicSMin:
  case SPIRV::OpAtomicSMax:
  case SPIRV::OpAtomicAnd:
  case SPIRV::OpAtomicOr:
  case SPIRV::OpAtomicXor: {
    const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
    const MachineInstr *InstrPtr = &MI;
    if (MI.getOpcode() == SPIRV::OpAtomicStore) {
      assert(MI.getOperand(3).isReg());
      InstrPtr = MRI.getVRegDef(MI.getOperand(3).getReg());
      assert(InstrPtr && "Unexpected type instruction for OpAtomicStore");
    }
    assert(InstrPtr->getOperand(1).isReg() && "Unexpected operand in atomic");
    Register TypeReg = InstrPtr->getOperand(1).getReg();
    SPIRVType *TypeDef = MRI.getVRegDef(TypeReg);
    if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
      unsigned BitWidth = TypeDef->getOperand(1).getImm();
      if (BitWidth == 64)
        Reqs.addCapability(SPIRV::Capability::Int64Atomics);
    }
    break;
  }
  case SPIRV::OpGroupNonUniformIAdd:
  case SPIRV::OpGroupNonUniformFAdd:
  case SPIRV::OpGroupNonUniformIMul:
  case SPIRV::OpGroupNonUniformFMul:
  case SPIRV::OpGroupNonUniformSMin:
  case SPIRV::OpGroupNonUniformUMin:
  case SPIRV::OpGroupNonUniformFMin:
  case SPIRV::OpGroupNonUniformSMax:
  case SPIRV::OpGroupNonUniformUMax:
  case SPIRV::OpGroupNonUniformFMax:
  case SPIRV::OpGroupNonUniformBitwiseAnd:
  case SPIRV::OpGroupNonUniformBitwiseOr:
  case SPIRV::OpGroupNonUniformBitwiseXor:
  case SPIRV::OpGroupNonUniformLogicalAnd:
  case SPIRV::OpGroupNonUniformLogicalOr:
  case SPIRV::OpGroupNonUniformLogicalXor: {
    assert(MI.getOperand(3).isImm());
    int64_t GroupOp = MI.getOperand(3).getImm();
    switch (GroupOp) {
    case SPIRV::GroupOperation::Reduce:
    case SPIRV::GroupOperation::InclusiveScan:
    case SPIRV::GroupOperation::ExclusiveScan:
      Reqs.addCapability(SPIRV::Capability::Kernel);
      Reqs.addCapability(SPIRV::Capability::GroupNonUniformArithmetic);
      Reqs.addCapability(SPIRV::Capability::GroupNonUniformBallot);
      break;
    case SPIRV::GroupOperation::ClusteredReduce:
      Reqs.addCapability(SPIRV::Capability::GroupNonUniformClustered);
      break;
    case SPIRV::GroupOperation::PartitionedReduceNV:
    case SPIRV::GroupOperation::PartitionedInclusiveScanNV:
    case SPIRV::GroupOperation::PartitionedExclusiveScanNV:
      Reqs.addCapability(SPIRV::Capability::GroupNonUniformPartitionedNV);
      break;
    }
    break;
  }
  case SPIRV::OpGroupNonUniformShuffle:
  case SPIRV::OpGroupNonUniformShuffleXor:
    Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffle);
    break;
  case SPIRV::OpGroupNonUniformShuffleUp:
  case SPIRV::OpGroupNonUniformShuffleDown:
    Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffleRelative);
    break;
  case SPIRV::OpGroupAll:
  case SPIRV::OpGroupAny:
  case SPIRV::OpGroupBroadcast:
  case SPIRV::OpGroupIAdd:
  case SPIRV::OpGroupFAdd:
  case SPIRV::OpGroupFMin:
  case SPIRV::OpGroupUMin:
  case SPIRV::OpGroupSMin:
  case SPIRV::OpGroupFMax:
  case SPIRV::OpGroupUMax:
  case SPIRV::OpGroupSMax:
    Reqs.addCapability(SPIRV::Capability::Groups);
    break;
  case SPIRV::OpGroupNonUniformElect:
    Reqs.addCapability(SPIRV::Capability::GroupNonUniform);
    break;
  case SPIRV::OpGroupNonUniformAll:
  case SPIRV::OpGroupNonUniformAny:
  case SPIRV::OpGroupNonUniformAllEqual:
    Reqs.addCapability(SPIRV::Capability::GroupNonUniformVote);
    break;
  case SPIRV::OpGroupNonUniformBroadcast:
  case SPIRV::OpGroupNonUniformBroadcastFirst:
  case SPIRV::OpGroupNonUniformBallot:
  case SPIRV::OpGroupNonUniformInverseBallot:
  case SPIRV::OpGroupNonUniformBallotBitExtract:
  case SPIRV::OpGroupNonUniformBallotBitCount:
  case SPIRV::OpGroupNonUniformBallotFindLSB:
  case SPIRV::OpGroupNonUniformBallotFindMSB:
    Reqs.addCapability(SPIRV::Capability::GroupNonUniformBallot);
    break;
  case SPIRV::OpSubgroupShuffleINTEL:
  case SPIRV::OpSubgroupShuffleDownINTEL:
  case SPIRV::OpSubgroupShuffleUpINTEL:
  case SPIRV::OpSubgroupShuffleXorINTEL:
    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
      Reqs.addCapability(SPIRV::Capability::SubgroupShuffleINTEL);
    }
    break;
  case SPIRV::OpSubgroupBlockReadINTEL:
  case SPIRV::OpSubgroupBlockWriteINTEL:
    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
      Reqs.addCapability(SPIRV::Capability::SubgroupBufferBlockIOINTEL);
    }
    break;
  case SPIRV::OpSubgroupImageBlockReadINTEL:
  case SPIRV::OpSubgroupImageBlockWriteINTEL:
    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
      Reqs.addCapability(SPIRV::Capability::SubgroupImageBlockIOINTEL);
    }
    break;
  case SPIRV::OpAssumeTrueKHR:
  case SPIRV::OpExpectKHR:
    if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_expect_assume)) {
      Reqs.addExtension(SPIRV::Extension::SPV_KHR_expect_assume);
      Reqs.addCapability(SPIRV::Capability::ExpectAssumeKHR);
    }
    break;
  case SPIRV::OpConstantFunctionPointerINTEL:
    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) {
      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
      Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL);
    }
    break;
  case SPIRV::OpGroupNonUniformRotateKHR:
    if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_subgroup_rotate))
      report_fatal_error("OpGroupNonUniformRotateKHR instruction requires the "
                         "following SPIR-V extension: SPV_KHR_subgroup_rotate",
                         false);
    Reqs.addExtension(SPIRV::Extension::SPV_KHR_subgroup_rotate);
    Reqs.addCapability(SPIRV::Capability::GroupNonUniformRotateKHR);
    Reqs.addCapability(SPIRV::Capability::GroupNonUniform);
    break;
  case SPIRV::OpGroupIMulKHR:
  case SPIRV::OpGroupFMulKHR:
  case SPIRV::OpGroupBitwiseAndKHR:
  case SPIRV::OpGroupBitwiseOrKHR:
  case SPIRV::OpGroupBitwiseXorKHR:
  case SPIRV::OpGroupLogicalAndKHR:
  case SPIRV::OpGroupLogicalOrKHR:
  case SPIRV::OpGroupLogicalXorKHR:
    if (ST.canUseExtension(
            SPIRV::Extension::SPV_KHR_uniform_group_instructions)) {
      Reqs.addExtension(SPIRV::Extension::SPV_KHR_uniform_group_instructions);
      Reqs.addCapability(SPIRV::Capability::GroupUniformArithmeticKHR);
    }
    break;
  case SPIRV::OpFunctionPointerCallINTEL:
    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) {
      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
      Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL);
    }
    break;
  case SPIRV::OpAtomicFAddEXT:
  case SPIRV::OpAtomicFMinEXT:
  case SPIRV::OpAtomicFMaxEXT:
    AddAtomicFloatRequirements(MI, Reqs, ST);
    break;
  default:
    break;
  }

  // If we require capability Shader, then we can remove the requirement for
  // the BitInstructions capability, since Shader is a superset capability
  // of BitInstructions.
  Reqs.removeCapabilityIf(SPIRV::Capability::BitInstructions,
                          SPIRV::Capability::Shader);
}

static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI,
                        MachineModuleInfo *MMI, const SPIRVSubtarget &ST) {
  // Collect requirements for existing instructions.
  for (auto F = M.begin(), E = M.end(); F != E; ++F) {
    MachineFunction *MF = MMI->getMachineFunction(*F);
    if (!MF)
      continue;
    for (const MachineBasicBlock &MBB : *MF)
      for (const MachineInstr &MI : MBB)
        addInstrRequirements(MI, MAI.Reqs, ST);
  }
  // Collect requirements for OpExecutionMode instructions.
  auto Node = M.getNamedMetadata("spirv.ExecutionMode");
  if (Node) {
    for (unsigned i = 0; i < Node->getNumOperands(); i++) {
      MDNode *MDN = cast<MDNode>(Node->getOperand(i));
      const MDOperand &MDOp = MDN->getOperand(1);
      if (auto *CMeta = dyn_cast<ConstantAsMetadata>(MDOp)) {
        Constant *C = CMeta->getValue();
        if (ConstantInt *Const = dyn_cast<ConstantInt>(C)) {
          auto EM = Const->getZExtValue();
          MAI.Reqs.getAndAddRequirements(
              SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
        }
      }
    }
  }
  for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) {
    const Function &F = *FI;
    if (F.isDeclaration())
      continue;
    if (F.getMetadata("reqd_work_group_size"))
      MAI.Reqs.getAndAddRequirements(
          SPIRV::OperandCategory::ExecutionModeOperand,
          SPIRV::ExecutionMode::LocalSize, ST);
    if (F.getFnAttribute("hlsl.numthreads").isValid()) {
      MAI.Reqs.getAndAddRequirements(
          SPIRV::OperandCategory::ExecutionModeOperand,
          SPIRV::ExecutionMode::LocalSize, ST);
    }
    if (F.getMetadata("work_group_size_hint"))
      MAI.Reqs.getAndAddRequirements(
          SPIRV::OperandCategory::ExecutionModeOperand,
          SPIRV::ExecutionMode::LocalSizeHint, ST);
    if (F.getMetadata("intel_reqd_sub_group_size"))
      MAI.Reqs.getAndAddRequirements(
          SPIRV::OperandCategory::ExecutionModeOperand,
          SPIRV::ExecutionMode::SubgroupSize, ST);
    if (F.getMetadata("vec_type_hint"))
      MAI.Reqs.getAndAddRequirements(
          SPIRV::OperandCategory::ExecutionModeOperand,
          SPIRV::ExecutionMode::VecTypeHint, ST);

    if (F.hasOptNone() &&
        ST.canUseExtension(SPIRV::Extension::SPV_INTEL_optnone)) {
      // Output OpCapability OptNoneINTEL.
      MAI.Reqs.addExtension(SPIRV::Extension::SPV_INTEL_optnone);
      MAI.Reqs.addCapability(SPIRV::Capability::OptNoneINTEL);
    }
  }
}

static unsigned getFastMathFlags(const MachineInstr &I) {
  unsigned Flags = SPIRV::FPFastMathMode::None;
  if (I.getFlag(MachineInstr::MIFlag::FmNoNans))
    Flags |= SPIRV::FPFastMathMode::NotNaN;
  if (I.getFlag(MachineInstr::MIFlag::FmNoInfs))
    Flags |= SPIRV::FPFastMathMode::NotInf;
  if (I.getFlag(MachineInstr::MIFlag::FmNsz))
    Flags |= SPIRV::FPFastMathMode::NSZ;
  if (I.getFlag(MachineInstr::MIFlag::FmArcp))
    Flags |= SPIRV::FPFastMathMode::AllowRecip;
  if (I.getFlag(MachineInstr::MIFlag::FmReassoc))
    Flags |= SPIRV::FPFastMathMode::Fast;
  return Flags;
}

static void handleMIFlagDecoration(MachineInstr &I, const SPIRVSubtarget &ST,
                                   const SPIRVInstrInfo &TII,
                                   SPIRV::RequirementHandler &Reqs) {
  if (I.getFlag(MachineInstr::MIFlag::NoSWrap) && TII.canUseNSW(I) &&
      getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,
                                     SPIRV::Decoration::NoSignedWrap, ST, Reqs)
          .IsSatisfiable) {
    buildOpDecorate(I.getOperand(0).getReg(), I, TII,
                    SPIRV::Decoration::NoSignedWrap, {});
  }
  if (I.getFlag(MachineInstr::MIFlag::NoUWrap) && TII.canUseNUW(I) &&
      getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,
                                     SPIRV::Decoration::NoUnsignedWrap, ST,
                                     Reqs)
          .IsSatisfiable) {
    buildOpDecorate(I.getOperand(0).getReg(), I, TII,
                    SPIRV::Decoration::NoUnsignedWrap, {});
  }
  if (!TII.canUseFastMathFlags(I))
    return;
  unsigned FMFlags = getFastMathFlags(I);
  if (FMFlags == SPIRV::FPFastMathMode::None)
    return;
  Register DstReg = I.getOperand(0).getReg();
  buildOpDecorate(DstReg, I, TII, SPIRV::Decoration::FPFastMathMode, {FMFlags});
}

// Walk all functions and add decorations related to MI flags.
static void addDecorations(const Module &M, const SPIRVInstrInfo &TII,
                           MachineModuleInfo *MMI, const SPIRVSubtarget &ST,
                           SPIRV::ModuleAnalysisInfo &MAI) {
  for (auto F = M.begin(), E = M.end(); F != E; ++F) {
    MachineFunction *MF = MMI->getMachineFunction(*F);
    if (!MF)
      continue;
    for (auto &MBB : *MF)
      for (auto &MI : MBB)
        handleMIFlagDecoration(MI, ST, TII, MAI.Reqs);
  }
}

struct SPIRV::ModuleAnalysisInfo SPIRVModuleAnalysis::MAI;

void SPIRVModuleAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequired<TargetPassConfig>();
  AU.addRequired<MachineModuleInfoWrapperPass>();
}

bool SPIRVModuleAnalysis::runOnModule(Module &M) {
  SPIRVTargetMachine &TM =
      getAnalysis<TargetPassConfig>().getTM<SPIRVTargetMachine>();
  ST = TM.getSubtargetImpl();
  GR = ST->getSPIRVGlobalRegistry();
  TII = ST->getInstrInfo();

  MMI = &getAnalysis<MachineModuleInfoWrapperPass>().getMMI();

  setBaseInfo(M);

  addDecorations(M, *TII, MMI, *ST, MAI);

  collectReqs(M, MAI, MMI, *ST);

  // Process type/const/global var/func decl instructions, number their
  // destination registers from 0 to N, collect Extensions and Capabilities.
  processDefInstrs(M);

  // Number the rest of the registers from N+1 onwards.
  numberRegistersGlobally(M);

  // Update references to OpFunction instructions to use global registers.
  if (GR->hasConstFunPtr())
    collectFuncPtrs();

  // Collect OpName, OpEntryPoint, OpDecorate etc., process other instructions.
  processOtherInstrs(M);

  // If there are no entry points, we need the Linkage capability.
  if (MAI.MS[SPIRV::MB_EntryPoints].empty())
    MAI.Reqs.addCapability(SPIRV::Capability::Linkage);

  return false;
}