xref: /llvm-project/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp (revision 14193f4320e26b75dc592abf952af1c63760665c)
1 //===- SPIRVModuleAnalysis.cpp - analysis of global instrs & regs - C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // The analysis collects instructions that should be output at the module level
10 // and performs the global register numbering.
11 //
12 // The results of this analysis are used in AsmPrinter to rename registers
13 // globally and to output required instructions at the module level.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "SPIRVModuleAnalysis.h"
18 #include "MCTargetDesc/SPIRVBaseInfo.h"
19 #include "MCTargetDesc/SPIRVMCTargetDesc.h"
20 #include "SPIRV.h"
21 #include "SPIRVSubtarget.h"
22 #include "SPIRVTargetMachine.h"
23 #include "SPIRVUtils.h"
24 #include "TargetInfo/SPIRVTargetInfo.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/CodeGen/MachineModuleInfo.h"
27 #include "llvm/CodeGen/TargetPassConfig.h"
28 
29 using namespace llvm;
30 
31 #define DEBUG_TYPE "spirv-module-analysis"
32 
// Command-line switch: when set, the MIR dump produced while building the
// dependency graph is annotated with SPIR-V dependency information.
static cl::opt<bool>
    SPVDumpDeps("spv-dump-deps",
                cl::desc("Dump MIR with SPIR-V dependencies info"),
                cl::Optional, cl::init(false));

// Command-line list of capabilities that should not be selected when a
// feature can be enabled by more than one capability; see
// getSymbolicOperandRequirements for how the preference is applied.
static cl::list<SPIRV::Capability::Capability>
    AvoidCapabilities("avoid-spirv-capabilities",
                      cl::desc("SPIR-V capabilities to avoid if there are "
                               "other options enabling a feature"),
                      cl::ZeroOrMore, cl::Hidden,
                      cl::values(clEnumValN(SPIRV::Capability::Shader, "Shader",
                                            "SPIR-V Shader capability")));
45 // Use sets instead of cl::list to check "if contains" condition
46 struct AvoidCapabilitiesSet {
47   SmallSet<SPIRV::Capability::Capability, 4> S;
48   AvoidCapabilitiesSet() {
49     for (auto Cap : AvoidCapabilities)
50       S.insert(Cap);
51   }
52 };
53 
// Pass identification token used by LLVM's pass infrastructure.
char llvm::SPIRVModuleAnalysis::ID = 0;

namespace llvm {
// Forward declaration required by INITIALIZE_PASS below.
void initializeSPIRVModuleAnalysisPass(PassRegistry &);
} // namespace llvm

INITIALIZE_PASS(SPIRVModuleAnalysis, DEBUG_TYPE, "SPIRV module analysis", true,
                true)
62 
63 // Retrieve an unsigned from an MDNode with a list of them as operands.
64 static unsigned getMetadataUInt(MDNode *MdNode, unsigned OpIndex,
65                                 unsigned DefaultVal = 0) {
66   if (MdNode && OpIndex < MdNode->getNumOperands()) {
67     const auto &Op = MdNode->getOperand(OpIndex);
68     return mdconst::extract<ConstantInt>(Op)->getZExtValue();
69   }
70   return DefaultVal;
71 }
72 
// Compute the SPIR-V requirements (capability, extensions, min/max version)
// needed to use symbolic operand `i` of operand category `Category` on
// subtarget `ST`. Returns an unsatisfiable result (first field false) when no
// combination of available capabilities, extensions and versions enables the
// operand.
static SPIRV::Requirements
getSymbolicOperandRequirements(SPIRV::OperandCategory::OperandCategory Category,
                               unsigned i, const SPIRVSubtarget &ST,
                               SPIRV::RequirementHandler &Reqs) {
  static AvoidCapabilitiesSet
      AvoidCaps; // contains capabilities to avoid if there is another option
  // A zero version on either side means "unconstrained".
  unsigned ReqMinVer = getSymbolicOperandMinVersion(Category, i);
  unsigned ReqMaxVer = getSymbolicOperandMaxVersion(Category, i);
  unsigned TargetVer = ST.getSPIRVVersion();
  bool MinVerOK = !ReqMinVer || !TargetVer || TargetVer >= ReqMinVer;
  bool MaxVerOK = !ReqMaxVer || !TargetVer || TargetVer <= ReqMaxVer;
  CapabilityList ReqCaps = getSymbolicOperandCapabilities(Category, i);
  ExtensionList ReqExts = getSymbolicOperandExtensions(Category, i);
  if (ReqCaps.empty()) {
    if (ReqExts.empty()) {
      // Neither capabilities nor extensions required: only versions matter.
      if (MinVerOK && MaxVerOK)
        return {true, {}, {}, ReqMinVer, ReqMaxVer};
      return {false, {}, {}, 0, 0};
    }
  } else if (MinVerOK && MaxVerOK) {
    if (ReqCaps.size() == 1) {
      auto Cap = ReqCaps[0];
      if (Reqs.isCapabilityAvailable(Cap))
        return {true, {Cap}, {}, ReqMinVer, ReqMaxVer};
    } else {
      // By SPIR-V specification: "If an instruction, enumerant, or other
      // feature specifies multiple enabling capabilities, only one such
      // capability needs to be declared to use the feature." However, one
      // capability may be preferred over another. We use command line
      // argument(s) and AvoidCapabilities to avoid selection of certain
      // capabilities if there are other options.
      CapabilityList UseCaps;
      for (auto Cap : ReqCaps)
        if (Reqs.isCapabilityAvailable(Cap))
          UseCaps.push_back(Cap);
      // Take the first available capability not on the avoid list; the last
      // one is accepted unconditionally as a fallback.
      for (size_t i = 0, Sz = UseCaps.size(); i < Sz; ++i) {
        auto Cap = UseCaps[i];
        if (i == Sz - 1 || !AvoidCaps.S.contains(Cap))
          return {true, {Cap}, {}, ReqMinVer, ReqMaxVer};
      }
    }
  }
  // If there are no capabilities, or we can't satisfy the version or
  // capability requirements, use the list of extensions (if the subtarget
  // can handle them all).
  if (llvm::all_of(ReqExts, [&ST](const SPIRV::Extension::Extension &Ext) {
        return ST.canUseExtension(Ext);
      })) {
    return {true, {}, ReqExts, 0, 0}; // TODO: add versions to extensions.
  }
  return {false, {}, {}, 0, 0};
}
125 
// Reset per-module state (sections, alias/function/extinst tables) and derive
// the addressing model, memory model, source language and their baseline
// requirements from module metadata and the subtarget.
void SPIRVModuleAnalysis::setBaseInfo(const Module &M) {
  MAI.MaxID = 0;
  for (int i = 0; i < SPIRV::NUM_MODULE_SECTIONS; i++)
    MAI.MS[i].clear();
  MAI.RegisterAliasTable.clear();
  MAI.InstrsToDelete.clear();
  MAI.FuncMap.clear();
  MAI.GlobalVarList.clear();
  MAI.ExtInstSetMap.clear();
  MAI.Reqs.clear();
  MAI.Reqs.initAvailableCapabilities(*ST);

  // TODO: determine memory model and source language from the configuration.
  if (auto MemModel = M.getNamedMetadata("spirv.MemoryModel")) {
    // Explicit "spirv.MemoryModel" metadata wins: operand 0 is the addressing
    // model, operand 1 the memory model.
    auto MemMD = MemModel->getOperand(0);
    MAI.Addr = static_cast<SPIRV::AddressingModel::AddressingModel>(
        getMetadataUInt(MemMD, 0));
    MAI.Mem =
        static_cast<SPIRV::MemoryModel::MemoryModel>(getMetadataUInt(MemMD, 1));
  } else {
    // TODO: Add support for VulkanMemoryModel.
    MAI.Mem = ST->isOpenCLEnv() ? SPIRV::MemoryModel::OpenCL
                                : SPIRV::MemoryModel::GLSL450;
    if (MAI.Mem == SPIRV::MemoryModel::OpenCL) {
      // OpenCL uses physical addressing matching the pointer width; any
      // other width falls back to logical addressing.
      unsigned PtrSize = ST->getPointerSize();
      MAI.Addr = PtrSize == 32   ? SPIRV::AddressingModel::Physical32
                 : PtrSize == 64 ? SPIRV::AddressingModel::Physical64
                                 : SPIRV::AddressingModel::Logical;
    } else {
      // TODO: Add support for PhysicalStorageBufferAddress.
      MAI.Addr = SPIRV::AddressingModel::Logical;
    }
  }
  // Get the OpenCL version number from metadata.
  // TODO: support other source languages.
  if (auto VerNode = M.getNamedMetadata("opencl.ocl.version")) {
    MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_C;
    // Construct version literal in accordance with SPIRV-LLVM-Translator.
    // TODO: support multiple OCL version metadata.
    assert(VerNode->getNumOperands() > 0 && "Invalid SPIR");
    auto VersionMD = VerNode->getOperand(0);
    unsigned MajorNum = getMetadataUInt(VersionMD, 0, 2);
    unsigned MinorNum = getMetadataUInt(VersionMD, 1);
    unsigned RevNum = getMetadataUInt(VersionMD, 2);
    // Prevent Major part of OpenCL version to be 0
    MAI.SrcLangVersion =
        (std::max(1U, MajorNum) * 100 + MinorNum) * 1000 + RevNum;
  } else {
    // If there is no information about OpenCL version we are forced to generate
    // OpenCL 1.0 by default for the OpenCL environment to avoid puzzling
    // run-times with Unknown/0.0 version output. For a reference, LLVM-SPIRV
    // Translator avoids potential issues with run-times in a similar manner.
    if (ST->isOpenCLEnv()) {
      MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_CPP;
      MAI.SrcLangVersion = 100000; // encodes version 1.0.0
    } else {
      MAI.SrcLang = SPIRV::SourceLanguage::Unknown;
      MAI.SrcLangVersion = 0;
    }
  }

  if (auto ExtNode = M.getNamedMetadata("opencl.used.extensions")) {
    // Record every string operand of every node as a source-level extension.
    for (unsigned I = 0, E = ExtNode->getNumOperands(); I != E; ++I) {
      MDNode *MD = ExtNode->getOperand(I);
      if (!MD || MD->getNumOperands() == 0)
        continue;
      for (unsigned J = 0, N = MD->getNumOperands(); J != N; ++J)
        MAI.SrcExt.insert(cast<MDString>(MD->getOperand(J))->getString());
    }
  }

  // Update required capabilities for this memory model, addressing model and
  // source language.
  MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand,
                                 MAI.Mem, *ST);
  MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::SourceLanguageOperand,
                                 MAI.SrcLang, *ST);
  MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,
                                 MAI.Addr, *ST);

  if (ST->isOpenCLEnv()) {
    // TODO: check if it's required by default.
    MAI.ExtInstSetMap[static_cast<unsigned>(
        SPIRV::InstructionSet::OpenCL_std)] =
        Register::index2VirtReg(MAI.getNextID());
  }
}
213 
214 // Collect MI which defines the register in the given machine function.
215 static void collectDefInstr(Register Reg, const MachineFunction *MF,
216                             SPIRV::ModuleAnalysisInfo *MAI,
217                             SPIRV::ModuleSectionType MSType,
218                             bool DoInsert = true) {
219   assert(MAI->hasRegisterAlias(MF, Reg) && "Cannot find register alias");
220   MachineInstr *MI = MF->getRegInfo().getUniqueVRegDef(Reg);
221   assert(MI && "There should be an instruction that defines the register");
222   MAI->setSkipEmission(MI);
223   if (DoInsert)
224     MAI->MS[MSType].push_back(MI);
225 }
226 
// Walk the dependency graph and, for every entry accepted by Pred, assign one
// global register shared by all per-function occurrences of the entity, mark
// the defining instructions for module-level emission, and record them in the
// MSType section. Traversal order (pre vs post) controls whether an entity is
// numbered before or after its dependencies.
void SPIRVModuleAnalysis::collectGlobalEntities(
    const std::vector<SPIRV::DTSortableEntry *> &DepsGraph,
    SPIRV::ModuleSectionType MSType,
    std::function<bool(const SPIRV::DTSortableEntry *)> Pred,
    bool UsePreOrder = false) {
  DenseSet<const SPIRV::DTSortableEntry *> Visited;
  for (const auto *E : DepsGraph) {
    std::function<void(const SPIRV::DTSortableEntry *)> RecHoistUtil;
    // NOTE: here we prefer recursive approach over iterative because
    // we don't expect depchains long enough to cause SO.
    RecHoistUtil = [MSType, UsePreOrder, &Visited, &Pred,
                    &RecHoistUtil](const SPIRV::DTSortableEntry *E) {
      if (Visited.count(E) || !Pred(E))
        return;
      Visited.insert(E);

      // Traversing deps graph in post-order allows us to get rid of
      // register aliases preprocessing.
      // But pre-order is required for correct processing of function
      // declaration and arguments processing.
      if (!UsePreOrder)
        for (auto *S : E->getDeps())
          RecHoistUtil(S);

      // One global register represents this entity in every function.
      Register GlobalReg = Register::index2VirtReg(MAI.getNextID());
      bool IsFirst = true;
      for (auto &U : *E) {
        const MachineFunction *MF = U.first;
        Register Reg = U.second;
        MAI.setRegisterAlias(MF, Reg, GlobalReg);
        // Occurrences without a unique def (e.g. pure uses) get only the
        // alias, no instruction is collected for them.
        if (!MF->getRegInfo().getUniqueVRegDef(Reg))
          continue;
        // Only the first defining instruction is inserted into the section;
        // later ones are merely marked as skipped.
        collectDefInstr(Reg, MF, &MAI, MSType, IsFirst);
        IsFirst = false;
        if (E->getIsGV())
          MAI.GlobalVarList.push_back(MF->getRegInfo().getUniqueVRegDef(Reg));
      }

      if (UsePreOrder)
        for (auto *S : E->getDeps())
          RecHoistUtil(S);
    };
    RecHoistUtil(E);
  }
}
272 
273 // The function initializes global register alias table for types, consts,
274 // global vars and func decls and collects these instruction for output
275 // at module level. Also it collects explicit OpExtension/OpCapability
276 // instructions.
277 void SPIRVModuleAnalysis::processDefInstrs(const Module &M) {
278   std::vector<SPIRV::DTSortableEntry *> DepsGraph;
279 
280   GR->buildDepsGraph(DepsGraph, SPVDumpDeps ? MMI : nullptr);
281 
282   collectGlobalEntities(
283       DepsGraph, SPIRV::MB_TypeConstVars,
284       [](const SPIRV::DTSortableEntry *E) { return !E->getIsFunc(); });
285 
286   for (auto F = M.begin(), E = M.end(); F != E; ++F) {
287     MachineFunction *MF = MMI->getMachineFunction(*F);
288     if (!MF)
289       continue;
290     // Iterate through and collect OpExtension/OpCapability instructions.
291     for (MachineBasicBlock &MBB : *MF) {
292       for (MachineInstr &MI : MBB) {
293         if (MI.getOpcode() == SPIRV::OpExtension) {
294           // Here, OpExtension just has a single enum operand, not a string.
295           auto Ext = SPIRV::Extension::Extension(MI.getOperand(0).getImm());
296           MAI.Reqs.addExtension(Ext);
297           MAI.setSkipEmission(&MI);
298         } else if (MI.getOpcode() == SPIRV::OpCapability) {
299           auto Cap = SPIRV::Capability::Capability(MI.getOperand(0).getImm());
300           MAI.Reqs.addCapability(Cap);
301           MAI.setSkipEmission(&MI);
302         }
303       }
304     }
305   }
306 
307   collectGlobalEntities(
308       DepsGraph, SPIRV::MB_ExtFuncDecls,
309       [](const SPIRV::DTSortableEntry *E) { return E->getIsFunc(); }, true);
310 }
311 
312 // Look for IDs declared with Import linkage, and map the corresponding function
313 // to the register defining that variable (which will usually be the result of
314 // an OpFunction). This lets us call externally imported functions using
315 // the correct ID registers.
316 void SPIRVModuleAnalysis::collectFuncNames(MachineInstr &MI,
317                                            const Function *F) {
318   if (MI.getOpcode() == SPIRV::OpDecorate) {
319     // If it's got Import linkage.
320     auto Dec = MI.getOperand(1).getImm();
321     if (Dec == static_cast<unsigned>(SPIRV::Decoration::LinkageAttributes)) {
322       auto Lnk = MI.getOperand(MI.getNumOperands() - 1).getImm();
323       if (Lnk == static_cast<unsigned>(SPIRV::LinkageType::Import)) {
324         // Map imported function name to function ID register.
325         const Function *ImportedFunc =
326             F->getParent()->getFunction(getStringImm(MI, 2));
327         Register Target = MI.getOperand(0).getReg();
328         MAI.FuncMap[ImportedFunc] = MAI.getRegisterAlias(MI.getMF(), Target);
329       }
330     }
331   } else if (MI.getOpcode() == SPIRV::OpFunction) {
332     // Record all internal OpFunction declarations.
333     Register Reg = MI.defs().begin()->getReg();
334     Register GlobalReg = MAI.getRegisterAlias(MI.getMF(), Reg);
335     assert(GlobalReg.isValid());
336     MAI.FuncMap[F] = GlobalReg;
337   }
338 }
339 
340 // References to a function via function pointers generate virtual
341 // registers without a definition. We are able to resolve this
342 // reference using Globar Register info into an OpFunction instruction
343 // and replace dummy operands by the corresponding global register references.
344 void SPIRVModuleAnalysis::collectFuncPtrs() {
345   for (auto &MI : MAI.MS[SPIRV::MB_TypeConstVars])
346     if (MI->getOpcode() == SPIRV::OpConstantFunctionPointerINTEL)
347       collectFuncPtrs(MI);
348 }
349 
350 void SPIRVModuleAnalysis::collectFuncPtrs(MachineInstr *MI) {
351   const MachineOperand *FunUse = &MI->getOperand(2);
352   if (const MachineOperand *FunDef = GR->getFunctionDefinitionByUse(FunUse)) {
353     const MachineInstr *FunDefMI = FunDef->getParent();
354     assert(FunDefMI->getOpcode() == SPIRV::OpFunction &&
355            "Constant function pointer must refer to function definition");
356     Register FunDefReg = FunDef->getReg();
357     Register GlobalFunDefReg =
358         MAI.getRegisterAlias(FunDefMI->getMF(), FunDefReg);
359     assert(GlobalFunDefReg.isValid() &&
360            "Function definition must refer to a global register");
361     Register FunPtrReg = FunUse->getReg();
362     MAI.setRegisterAlias(MI->getMF(), FunPtrReg, GlobalFunDefReg);
363   }
364 }
365 
366 using InstrSignature = SmallVector<size_t>;
367 using InstrTraces = std::set<InstrSignature>;
368 
369 // Returns a representation of an instruction as a vector of MachineOperand
370 // hash values, see llvm::hash_value(const MachineOperand &MO) for details.
371 // This creates a signature of the instruction with the same content
372 // that MachineOperand::isIdenticalTo uses for comparison.
373 static InstrSignature instrToSignature(MachineInstr &MI,
374                                        SPIRV::ModuleAnalysisInfo &MAI) {
375   InstrSignature Signature;
376   for (unsigned i = 0; i < MI.getNumOperands(); ++i) {
377     const MachineOperand &MO = MI.getOperand(i);
378     size_t h;
379     if (MO.isReg()) {
380       Register RegAlias = MAI.getRegisterAlias(MI.getMF(), MO.getReg());
381       // mimic llvm::hash_value(const MachineOperand &MO)
382       h = hash_combine(MO.getType(), (unsigned)RegAlias, MO.getSubReg(),
383                        MO.isDef());
384     } else {
385       h = hash_value(MO);
386     }
387     Signature.push_back(h);
388   }
389   return Signature;
390 }
391 
392 // Collect the given instruction in the specified MS. We assume global register
393 // numbering has already occurred by this point. We can directly compare reg
394 // arguments when detecting duplicates.
395 static void collectOtherInstr(MachineInstr &MI, SPIRV::ModuleAnalysisInfo &MAI,
396                               SPIRV::ModuleSectionType MSType, InstrTraces &IS,
397                               bool Append = true) {
398   MAI.setSkipEmission(&MI);
399   InstrSignature MISign = instrToSignature(MI, MAI);
400   auto FoundMI = IS.insert(MISign);
401   if (!FoundMI.second)
402     return; // insert failed, so we found a duplicate; don't add it to MAI.MS
403   // No duplicates, so add it.
404   if (Append)
405     MAI.MS[MSType].push_back(&MI);
406   else
407     MAI.MS[MSType].insert(MAI.MS[MSType].begin(), &MI);
408 }
409 
410 // Some global instructions make reference to function-local ID regs, so cannot
411 // be correctly collected until these registers are globally numbered.
412 void SPIRVModuleAnalysis::processOtherInstrs(const Module &M) {
413   InstrTraces IS;
414   for (auto F = M.begin(), E = M.end(); F != E; ++F) {
415     if ((*F).isDeclaration())
416       continue;
417     MachineFunction *MF = MMI->getMachineFunction(*F);
418     assert(MF);
419     for (MachineBasicBlock &MBB : *MF)
420       for (MachineInstr &MI : MBB) {
421         if (MAI.getSkipEmission(&MI))
422           continue;
423         const unsigned OpCode = MI.getOpcode();
424         if (OpCode == SPIRV::OpName || OpCode == SPIRV::OpMemberName) {
425           collectOtherInstr(MI, MAI, SPIRV::MB_DebugNames, IS);
426         } else if (OpCode == SPIRV::OpEntryPoint) {
427           collectOtherInstr(MI, MAI, SPIRV::MB_EntryPoints, IS);
428         } else if (TII->isDecorationInstr(MI)) {
429           collectOtherInstr(MI, MAI, SPIRV::MB_Annotations, IS);
430           collectFuncNames(MI, &*F);
431         } else if (TII->isConstantInstr(MI)) {
432           // Now OpSpecConstant*s are not in DT,
433           // but they need to be collected anyway.
434           collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, IS);
435         } else if (OpCode == SPIRV::OpFunction) {
436           collectFuncNames(MI, &*F);
437         } else if (OpCode == SPIRV::OpTypeForwardPointer) {
438           collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, IS, false);
439         }
440       }
441   }
442 }
443 
444 // Number registers in all functions globally from 0 onwards and store
445 // the result in global register alias table. Some registers are already
446 // numbered in collectGlobalEntities.
447 void SPIRVModuleAnalysis::numberRegistersGlobally(const Module &M) {
448   for (auto F = M.begin(), E = M.end(); F != E; ++F) {
449     if ((*F).isDeclaration())
450       continue;
451     MachineFunction *MF = MMI->getMachineFunction(*F);
452     assert(MF);
453     for (MachineBasicBlock &MBB : *MF) {
454       for (MachineInstr &MI : MBB) {
455         for (MachineOperand &Op : MI.operands()) {
456           if (!Op.isReg())
457             continue;
458           Register Reg = Op.getReg();
459           if (MAI.hasRegisterAlias(MF, Reg))
460             continue;
461           Register NewReg = Register::index2VirtReg(MAI.getNextID());
462           MAI.setRegisterAlias(MF, Reg, NewReg);
463         }
464         if (MI.getOpcode() != SPIRV::OpExtInst)
465           continue;
466         auto Set = MI.getOperand(2).getImm();
467         if (!MAI.ExtInstSetMap.contains(Set))
468           MAI.ExtInstSetMap[Set] = Register::index2VirtReg(MAI.getNextID());
469       }
470     }
471   }
472 }
473 
// RequirementHandler implementations.

// Look up the requirements (capability/extensions/version bounds) of symbolic
// operand `i` in `Category` and record them in this handler. Fatal error if
// the requirements cannot be satisfied (see addRequirements).
void SPIRV::RequirementHandler::getAndAddRequirements(
    SPIRV::OperandCategory::OperandCategory Category, uint32_t i,
    const SPIRVSubtarget &ST) {
  addRequirements(getSymbolicOperandRequirements(Category, i, ST, *this));
}
480 
481 void SPIRV::RequirementHandler::recursiveAddCapabilities(
482     const CapabilityList &ToPrune) {
483   for (const auto &Cap : ToPrune) {
484     AllCaps.insert(Cap);
485     CapabilityList ImplicitDecls =
486         getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);
487     recursiveAddCapabilities(ImplicitDecls);
488   }
489 }
490 
491 void SPIRV::RequirementHandler::addCapabilities(const CapabilityList &ToAdd) {
492   for (const auto &Cap : ToAdd) {
493     bool IsNewlyInserted = AllCaps.insert(Cap).second;
494     if (!IsNewlyInserted) // Don't re-add if it's already been declared.
495       continue;
496     CapabilityList ImplicitDecls =
497         getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);
498     recursiveAddCapabilities(ImplicitDecls);
499     MinimalCaps.push_back(Cap);
500   }
501 }
502 
// Merge one Requirements record into this handler: declare its capability (if
// any), add its extensions, and tighten the [MinVersion, MaxVersion] window.
// Fatal error on an unsatisfiable requirement or a version conflict.
void SPIRV::RequirementHandler::addRequirements(
    const SPIRV::Requirements &Req) {
  if (!Req.IsSatisfiable)
    report_fatal_error("Adding SPIR-V requirements this target can't satisfy.");

  if (Req.Cap.has_value())
    addCapabilities({Req.Cap.value()});

  addExtensions(Req.Exts);

  // Note: MinVersion is raised before MaxVersion is checked against it below,
  // so the two blocks must stay in this order.
  if (Req.MinVer) {
    if (MaxVersion && Req.MinVer > MaxVersion) {
      LLVM_DEBUG(dbgs() << "Conflicting version requirements: >= " << Req.MinVer
                        << " and <= " << MaxVersion << "\n");
      report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");
    }

    // Zero means "no constraint recorded yet".
    if (MinVersion == 0 || Req.MinVer > MinVersion)
      MinVersion = Req.MinVer;
  }

  if (Req.MaxVer) {
    if (MinVersion && Req.MaxVer < MinVersion) {
      LLVM_DEBUG(dbgs() << "Conflicting version requirements: <= " << Req.MaxVer
                        << " and >= " << MinVersion << "\n");
      report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");
    }

    if (MaxVersion == 0 || Req.MaxVer < MaxVersion)
      MaxVersion = Req.MaxVer;
  }
}
535 
// Verify that every accumulated requirement (version window, minimal
// capabilities, extensions) is met by subtarget ST. All violations are logged
// before a single fatal error is raised.
void SPIRV::RequirementHandler::checkSatisfiable(
    const SPIRVSubtarget &ST) const {
  // Report as many errors as possible before aborting the compilation.
  bool IsSatisfiable = true;
  auto TargetVer = ST.getSPIRVVersion();

  if (MaxVersion && TargetVer && MaxVersion < TargetVer) {
    LLVM_DEBUG(
        dbgs() << "Target SPIR-V version too high for required features\n"
               << "Required max version: " << MaxVersion << " target version "
               << TargetVer << "\n");
    IsSatisfiable = false;
  }

  if (MinVersion && TargetVer && MinVersion > TargetVer) {
    LLVM_DEBUG(dbgs() << "Target SPIR-V version too low for required features\n"
                      << "Required min version: " << MinVersion
                      << " target version " << TargetVer << "\n");
    IsSatisfiable = false;
  }

  if (MinVersion && MaxVersion && MinVersion > MaxVersion) {
    LLVM_DEBUG(
        dbgs()
        << "Version is too low for some features and too high for others.\n"
        << "Required SPIR-V min version: " << MinVersion
        << " required SPIR-V max version " << MaxVersion << "\n");
    IsSatisfiable = false;
  }

  // Every explicitly declared (minimal) capability must be available.
  for (auto Cap : MinimalCaps) {
    if (AvailableCaps.contains(Cap))
      continue;
    LLVM_DEBUG(dbgs() << "Capability not supported: "
                      << getSymbolicOperandMnemonic(
                             OperandCategory::CapabilityOperand, Cap)
                      << "\n");
    IsSatisfiable = false;
  }

  // Every required extension must be usable on this subtarget.
  for (auto Ext : AllExtensions) {
    if (ST.canUseExtension(Ext))
      continue;
    LLVM_DEBUG(dbgs() << "Extension not supported: "
                      << getSymbolicOperandMnemonic(
                             OperandCategory::ExtensionOperand, Ext)
                      << "\n");
    IsSatisfiable = false;
  }

  if (!IsSatisfiable)
    report_fatal_error("Unable to meet SPIR-V requirements for this target.");
}
589 
590 // Add the given capabilities and all their implicitly defined capabilities too.
591 void SPIRV::RequirementHandler::addAvailableCaps(const CapabilityList &ToAdd) {
592   for (const auto Cap : ToAdd)
593     if (AvailableCaps.insert(Cap).second)
594       addAvailableCaps(getSymbolicOperandCapabilities(
595           SPIRV::OperandCategory::CapabilityOperand, Cap));
596 }
597 
598 void SPIRV::RequirementHandler::removeCapabilityIf(
599     const Capability::Capability ToRemove,
600     const Capability::Capability IfPresent) {
601   if (AllCaps.contains(IfPresent))
602     AllCaps.erase(ToRemove);
603 }
604 
605 namespace llvm {
606 namespace SPIRV {
607 void RequirementHandler::initAvailableCapabilities(const SPIRVSubtarget &ST) {
608   if (ST.isOpenCLEnv()) {
609     initAvailableCapabilitiesForOpenCL(ST);
610     return;
611   }
612 
613   if (ST.isVulkanEnv()) {
614     initAvailableCapabilitiesForVulkan(ST);
615     return;
616   }
617 
618   report_fatal_error("Unimplemented environment for SPIR-V generation.");
619 }
620 
621 void RequirementHandler::initAvailableCapabilitiesForOpenCL(
622     const SPIRVSubtarget &ST) {
623   // Add the min requirements for different OpenCL and SPIR-V versions.
624   addAvailableCaps({Capability::Addresses, Capability::Float16Buffer,
625                     Capability::Int16, Capability::Int8, Capability::Kernel,
626                     Capability::Linkage, Capability::Vector16,
627                     Capability::Groups, Capability::GenericPointer,
628                     Capability::Shader});
629   if (ST.hasOpenCLFullProfile())
630     addAvailableCaps({Capability::Int64, Capability::Int64Atomics});
631   if (ST.hasOpenCLImageSupport()) {
632     addAvailableCaps({Capability::ImageBasic, Capability::LiteralSampler,
633                       Capability::Image1D, Capability::SampledBuffer,
634                       Capability::ImageBuffer});
635     if (ST.isAtLeastOpenCLVer(20))
636       addAvailableCaps({Capability::ImageReadWrite});
637   }
638   if (ST.isAtLeastSPIRVVer(11) && ST.isAtLeastOpenCLVer(22))
639     addAvailableCaps({Capability::SubgroupDispatch, Capability::PipeStorage});
640   if (ST.isAtLeastSPIRVVer(13))
641     addAvailableCaps({Capability::GroupNonUniform,
642                       Capability::GroupNonUniformVote,
643                       Capability::GroupNonUniformArithmetic,
644                       Capability::GroupNonUniformBallot,
645                       Capability::GroupNonUniformClustered,
646                       Capability::GroupNonUniformShuffle,
647                       Capability::GroupNonUniformShuffleRelative});
648   if (ST.isAtLeastSPIRVVer(14))
649     addAvailableCaps({Capability::DenormPreserve, Capability::DenormFlushToZero,
650                       Capability::SignedZeroInfNanPreserve,
651                       Capability::RoundingModeRTE,
652                       Capability::RoundingModeRTZ});
653   // TODO: verify if this needs some checks.
654   addAvailableCaps({Capability::Float16, Capability::Float64});
655 
656   // Add capabilities enabled by extensions.
657   for (auto Extension : ST.getAllAvailableExtensions()) {
658     CapabilityList EnabledCapabilities =
659         getCapabilitiesEnabledByExtension(Extension);
660     addAvailableCaps(EnabledCapabilities);
661   }
662 
663   // TODO: add OpenCL extensions.
664 }
665 
666 void RequirementHandler::initAvailableCapabilitiesForVulkan(
667     const SPIRVSubtarget &ST) {
668   addAvailableCaps({Capability::Shader, Capability::Linkage});
669 
670   // Provided by all supported Vulkan versions.
671   addAvailableCaps({Capability::Int16, Capability::Int64, Capability::Float16,
672                     Capability::Float64, Capability::GroupNonUniform});
673 }
674 
675 } // namespace SPIRV
676 } // namespace llvm
677 
678 // Add the required capabilities from a decoration instruction (including
679 // BuiltIns).
680 static void addOpDecorateReqs(const MachineInstr &MI, unsigned DecIndex,
681                               SPIRV::RequirementHandler &Reqs,
682                               const SPIRVSubtarget &ST) {
683   int64_t DecOp = MI.getOperand(DecIndex).getImm();
684   auto Dec = static_cast<SPIRV::Decoration::Decoration>(DecOp);
685   Reqs.addRequirements(getSymbolicOperandRequirements(
686       SPIRV::OperandCategory::DecorationOperand, Dec, ST, Reqs));
687 
688   if (Dec == SPIRV::Decoration::BuiltIn) {
689     int64_t BuiltInOp = MI.getOperand(DecIndex + 1).getImm();
690     auto BuiltIn = static_cast<SPIRV::BuiltIn::BuiltIn>(BuiltInOp);
691     Reqs.addRequirements(getSymbolicOperandRequirements(
692         SPIRV::OperandCategory::BuiltInOperand, BuiltIn, ST, Reqs));
693   } else if (Dec == SPIRV::Decoration::LinkageAttributes) {
694     int64_t LinkageOp = MI.getOperand(MI.getNumOperands() - 1).getImm();
695     SPIRV::LinkageType::LinkageType LnkType =
696         static_cast<SPIRV::LinkageType::LinkageType>(LinkageOp);
697     if (LnkType == SPIRV::LinkageType::LinkOnceODR)
698       Reqs.addExtension(SPIRV::Extension::SPV_KHR_linkonce_odr);
699   }
700 }
701 
// Add requirements for image handling.
static void addOpTypeImageReqs(const MachineInstr &MI,
                               SPIRV::RequirementHandler &Reqs,
                               const SPIRVSubtarget &ST) {
  assert(MI.getNumOperands() >= 8 && "Insufficient operands for OpTypeImage");
  // The operand indices used here are based on the OpTypeImage layout, which
  // the MachineInstr follows as well.
  int64_t ImgFormatOp = MI.getOperand(7).getImm();
  auto ImgFormat = static_cast<SPIRV::ImageFormat::ImageFormat>(ImgFormatOp);
  Reqs.getAndAddRequirements(SPIRV::OperandCategory::ImageFormatOperand,
                             ImgFormat, ST);

  // Operand 4 is Arrayed, 5 is MS (multisampled), 6 is Sampled (2 means
  // "no sampler; read/write without a sampler").
  bool IsArrayed = MI.getOperand(4).getImm() == 1;
  bool IsMultisampled = MI.getOperand(5).getImm() == 1;
  bool NoSampler = MI.getOperand(6).getImm() == 2;
  // Add dimension requirements.
  assert(MI.getOperand(2).isImm());
  switch (MI.getOperand(2).getImm()) {
  case SPIRV::Dim::DIM_1D:
    Reqs.addRequirements(NoSampler ? SPIRV::Capability::Image1D
                                   : SPIRV::Capability::Sampled1D);
    break;
  case SPIRV::Dim::DIM_2D:
    if (IsMultisampled && NoSampler)
      Reqs.addRequirements(SPIRV::Capability::ImageMSArray);
    break;
  case SPIRV::Dim::DIM_Cube:
    Reqs.addRequirements(SPIRV::Capability::Shader);
    if (IsArrayed)
      Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageCubeArray
                                     : SPIRV::Capability::SampledCubeArray);
    break;
  case SPIRV::Dim::DIM_Rect:
    Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageRect
                                   : SPIRV::Capability::SampledRect);
    break;
  case SPIRV::Dim::DIM_Buffer:
    Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageBuffer
                                   : SPIRV::Capability::SampledBuffer);
    break;
  case SPIRV::Dim::DIM_SubpassData:
    Reqs.addRequirements(SPIRV::Capability::InputAttachment);
    break;
  }

  // Has optional access qualifier.
  // TODO: check if it's OpenCL's kernel.
  if (MI.getNumOperands() > 8 &&
      MI.getOperand(8).getImm() == SPIRV::AccessQualifier::ReadWrite)
    Reqs.addRequirements(SPIRV::Capability::ImageReadWrite);
  else
    Reqs.addRequirements(SPIRV::Capability::ImageBasic);
}
755 
// Add requirements for handling atomic float instructions
#define ATOM_FLT_REQ_EXT_MSG(ExtName)                                          \
  "The atomic float instruction requires the following SPIR-V "                \
  "extension: SPV_EXT_shader_atomic_float" ExtName
// Records the extension/capability requirements of an atomic floating-point
// instruction (dispatched here for OpAtomicFAddEXT, OpAtomicFMinEXT and
// OpAtomicFMaxEXT). Raises a fatal error when the subtarget cannot enable the
// required SPV_EXT_shader_atomic_float* extension, or when the result type is
// not a scalar float of width 16/32/64.
static void AddAtomicFloatRequirements(const MachineInstr &MI,
                                       SPIRV::RequirementHandler &Reqs,
                                       const SPIRVSubtarget &ST) {
  assert(MI.getOperand(1).isReg() &&
         "Expect register operand in atomic float instruction");
  // Operand 1 is the result-type register; resolve its defining instruction,
  // which must be an OpTypeFloat.
  Register TypeReg = MI.getOperand(1).getReg();
  SPIRVType *TypeDef = MI.getMF()->getRegInfo().getVRegDef(TypeReg);
  if (TypeDef->getOpcode() != SPIRV::OpTypeFloat)
    report_fatal_error("Result type of an atomic float instruction must be a "
                       "floating-point type scalar");

  // Bit-width literal of the OpTypeFloat determines which capability applies.
  unsigned BitWidth = TypeDef->getOperand(1).getImm();
  unsigned Op = MI.getOpcode();
  if (Op == SPIRV::OpAtomicFAddEXT) {
    // FAdd path: base extension plus a per-width capability; 16-bit adds need
    // the additional SPV_EXT_shader_atomic_float16_add extension.
    if (!ST.canUseExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add))
      report_fatal_error(ATOM_FLT_REQ_EXT_MSG("_add"), false);
    Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add);
    switch (BitWidth) {
    case 16:
      if (!ST.canUseExtension(
              SPIRV::Extension::SPV_EXT_shader_atomic_float16_add))
        report_fatal_error(ATOM_FLT_REQ_EXT_MSG("16_add"), false);
      Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float16_add);
      Reqs.addCapability(SPIRV::Capability::AtomicFloat16AddEXT);
      break;
    case 32:
      Reqs.addCapability(SPIRV::Capability::AtomicFloat32AddEXT);
      break;
    case 64:
      Reqs.addCapability(SPIRV::Capability::AtomicFloat64AddEXT);
      break;
    default:
      report_fatal_error(
          "Unexpected floating-point type width in atomic float instruction");
    }
  } else {
    // FMin/FMax path: all widths are covered by the single
    // SPV_EXT_shader_atomic_float_min_max extension.
    if (!ST.canUseExtension(
            SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max))
      report_fatal_error(ATOM_FLT_REQ_EXT_MSG("_min_max"), false);
    Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max);
    switch (BitWidth) {
    case 16:
      Reqs.addCapability(SPIRV::Capability::AtomicFloat16MinMaxEXT);
      break;
    case 32:
      Reqs.addCapability(SPIRV::Capability::AtomicFloat32MinMaxEXT);
      break;
    case 64:
      Reqs.addCapability(SPIRV::Capability::AtomicFloat64MinMaxEXT);
      break;
    default:
      report_fatal_error(
          "Unexpected floating-point type width in atomic float instruction");
    }
  }
}
816 
// Records into Reqs the capabilities and extensions implied by a single
// machine instruction. Opcodes without an entry in the switch impose no
// additional requirements.
void addInstrRequirements(const MachineInstr &MI,
                          SPIRV::RequirementHandler &Reqs,
                          const SPIRVSubtarget &ST) {
  switch (MI.getOpcode()) {
  case SPIRV::OpMemoryModel: {
    // Operand 0 is the addressing model, operand 1 the memory model.
    int64_t Addr = MI.getOperand(0).getImm();
    Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,
                               Addr, ST);
    int64_t Mem = MI.getOperand(1).getImm();
    Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand, Mem,
                               ST);
    break;
  }
  case SPIRV::OpEntryPoint: {
    // Operand 0 is the execution model of the entry point.
    int64_t Exe = MI.getOperand(0).getImm();
    Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModelOperand,
                               Exe, ST);
    break;
  }
  case SPIRV::OpExecutionMode:
  case SPIRV::OpExecutionModeId: {
    // Operand 1 is the execution mode literal.
    int64_t Exe = MI.getOperand(1).getImm();
    Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModeOperand,
                               Exe, ST);
    break;
  }
  case SPIRV::OpTypeMatrix:
    Reqs.addCapability(SPIRV::Capability::Matrix);
    break;
  case SPIRV::OpTypeInt: {
    // Non-32-bit integer widths each require a dedicated capability.
    unsigned BitWidth = MI.getOperand(1).getImm();
    if (BitWidth == 64)
      Reqs.addCapability(SPIRV::Capability::Int64);
    else if (BitWidth == 16)
      Reqs.addCapability(SPIRV::Capability::Int16);
    else if (BitWidth == 8)
      Reqs.addCapability(SPIRV::Capability::Int8);
    break;
  }
  case SPIRV::OpTypeFloat: {
    // 64- and 16-bit floats require dedicated capabilities.
    unsigned BitWidth = MI.getOperand(1).getImm();
    if (BitWidth == 64)
      Reqs.addCapability(SPIRV::Capability::Float64);
    else if (BitWidth == 16)
      Reqs.addCapability(SPIRV::Capability::Float16);
    break;
  }
  case SPIRV::OpTypeVector: {
    // Vectors with 8 or 16 components need the Vector16 capability.
    unsigned NumComponents = MI.getOperand(2).getImm();
    if (NumComponents == 8 || NumComponents == 16)
      Reqs.addCapability(SPIRV::Capability::Vector16);
    break;
  }
  case SPIRV::OpTypePointer: {
    auto SC = MI.getOperand(1).getImm();
    Reqs.getAndAddRequirements(SPIRV::OperandCategory::StorageClassOperand, SC,
                               ST);
    // If it's a type of pointer to float16 targeting OpenCL, add Float16Buffer
    // capability.
    if (!ST.isOpenCLEnv())
      break;
    assert(MI.getOperand(2).isReg());
    const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
    SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(2).getReg());
    if (TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
        TypeDef->getOperand(1).getImm() == 16)
      Reqs.addCapability(SPIRV::Capability::Float16Buffer);
    break;
  }
  case SPIRV::OpBitReverse:
  case SPIRV::OpBitFieldInsert:
  case SPIRV::OpBitFieldSExtract:
  case SPIRV::OpBitFieldUExtract:
    // Without SPV_KHR_bit_instructions these instructions are only available
    // under the core Shader capability.
    if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions)) {
      Reqs.addCapability(SPIRV::Capability::Shader);
      break;
    }
    Reqs.addExtension(SPIRV::Extension::SPV_KHR_bit_instructions);
    Reqs.addCapability(SPIRV::Capability::BitInstructions);
    break;
  case SPIRV::OpTypeRuntimeArray:
    Reqs.addCapability(SPIRV::Capability::Shader);
    break;
  case SPIRV::OpTypeOpaque:
  case SPIRV::OpTypeEvent:
    Reqs.addCapability(SPIRV::Capability::Kernel);
    break;
  case SPIRV::OpTypePipe:
  case SPIRV::OpTypeReserveId:
    Reqs.addCapability(SPIRV::Capability::Pipes);
    break;
  case SPIRV::OpTypeDeviceEvent:
  case SPIRV::OpTypeQueue:
  case SPIRV::OpBuildNDRange:
    Reqs.addCapability(SPIRV::Capability::DeviceEnqueue);
    break;
  case SPIRV::OpDecorate:
  case SPIRV::OpDecorateId:
  case SPIRV::OpDecorateString:
    // The decoration literal is at operand index 1.
    addOpDecorateReqs(MI, 1, Reqs, ST);
    break;
  case SPIRV::OpMemberDecorate:
  case SPIRV::OpMemberDecorateString:
    // The decoration literal is at operand index 2 (after the member index).
    addOpDecorateReqs(MI, 2, Reqs, ST);
    break;
  case SPIRV::OpInBoundsPtrAccessChain:
    Reqs.addCapability(SPIRV::Capability::Addresses);
    break;
  case SPIRV::OpConstantSampler:
    Reqs.addCapability(SPIRV::Capability::LiteralSampler);
    break;
  case SPIRV::OpTypeImage:
    addOpTypeImageReqs(MI, Reqs, ST);
    break;
  case SPIRV::OpTypeSampler:
    Reqs.addCapability(SPIRV::Capability::ImageBasic);
    break;
  case SPIRV::OpTypeForwardPointer:
    // TODO: check if it's OpenCL's kernel.
    Reqs.addCapability(SPIRV::Capability::Addresses);
    break;
  case SPIRV::OpAtomicFlagTestAndSet:
  case SPIRV::OpAtomicLoad:
  case SPIRV::OpAtomicStore:
  case SPIRV::OpAtomicExchange:
  case SPIRV::OpAtomicCompareExchange:
  case SPIRV::OpAtomicIIncrement:
  case SPIRV::OpAtomicIDecrement:
  case SPIRV::OpAtomicIAdd:
  case SPIRV::OpAtomicISub:
  case SPIRV::OpAtomicUMin:
  case SPIRV::OpAtomicUMax:
  case SPIRV::OpAtomicSMin:
  case SPIRV::OpAtomicSMax:
  case SPIRV::OpAtomicAnd:
  case SPIRV::OpAtomicOr:
  case SPIRV::OpAtomicXor: {
    // 64-bit integer atomics additionally require Int64Atomics. For
    // OpAtomicStore (which has no result) the type is taken from the defining
    // instruction of the stored value (operand 3); otherwise operand 1 holds
    // the result-type register.
    const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
    const MachineInstr *InstrPtr = &MI;
    if (MI.getOpcode() == SPIRV::OpAtomicStore) {
      assert(MI.getOperand(3).isReg());
      InstrPtr = MRI.getVRegDef(MI.getOperand(3).getReg());
      assert(InstrPtr && "Unexpected type instruction for OpAtomicStore");
    }
    assert(InstrPtr->getOperand(1).isReg() && "Unexpected operand in atomic");
    Register TypeReg = InstrPtr->getOperand(1).getReg();
    SPIRVType *TypeDef = MRI.getVRegDef(TypeReg);
    if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
      unsigned BitWidth = TypeDef->getOperand(1).getImm();
      if (BitWidth == 64)
        Reqs.addCapability(SPIRV::Capability::Int64Atomics);
    }
    break;
  }
  case SPIRV::OpGroupNonUniformIAdd:
  case SPIRV::OpGroupNonUniformFAdd:
  case SPIRV::OpGroupNonUniformIMul:
  case SPIRV::OpGroupNonUniformFMul:
  case SPIRV::OpGroupNonUniformSMin:
  case SPIRV::OpGroupNonUniformUMin:
  case SPIRV::OpGroupNonUniformFMin:
  case SPIRV::OpGroupNonUniformSMax:
  case SPIRV::OpGroupNonUniformUMax:
  case SPIRV::OpGroupNonUniformFMax:
  case SPIRV::OpGroupNonUniformBitwiseAnd:
  case SPIRV::OpGroupNonUniformBitwiseOr:
  case SPIRV::OpGroupNonUniformBitwiseXor:
  case SPIRV::OpGroupNonUniformLogicalAnd:
  case SPIRV::OpGroupNonUniformLogicalOr:
  case SPIRV::OpGroupNonUniformLogicalXor: {
    // The required capability depends on the group-operation literal
    // (operand 3).
    assert(MI.getOperand(3).isImm());
    int64_t GroupOp = MI.getOperand(3).getImm();
    switch (GroupOp) {
    case SPIRV::GroupOperation::Reduce:
    case SPIRV::GroupOperation::InclusiveScan:
    case SPIRV::GroupOperation::ExclusiveScan:
      Reqs.addCapability(SPIRV::Capability::Kernel);
      Reqs.addCapability(SPIRV::Capability::GroupNonUniformArithmetic);
      Reqs.addCapability(SPIRV::Capability::GroupNonUniformBallot);
      break;
    case SPIRV::GroupOperation::ClusteredReduce:
      Reqs.addCapability(SPIRV::Capability::GroupNonUniformClustered);
      break;
    case SPIRV::GroupOperation::PartitionedReduceNV:
    case SPIRV::GroupOperation::PartitionedInclusiveScanNV:
    case SPIRV::GroupOperation::PartitionedExclusiveScanNV:
      Reqs.addCapability(SPIRV::Capability::GroupNonUniformPartitionedNV);
      break;
    }
    break;
  }
  case SPIRV::OpGroupNonUniformShuffle:
  case SPIRV::OpGroupNonUniformShuffleXor:
    Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffle);
    break;
  case SPIRV::OpGroupNonUniformShuffleUp:
  case SPIRV::OpGroupNonUniformShuffleDown:
    Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffleRelative);
    break;
  case SPIRV::OpGroupAll:
  case SPIRV::OpGroupAny:
  case SPIRV::OpGroupBroadcast:
  case SPIRV::OpGroupIAdd:
  case SPIRV::OpGroupFAdd:
  case SPIRV::OpGroupFMin:
  case SPIRV::OpGroupUMin:
  case SPIRV::OpGroupSMin:
  case SPIRV::OpGroupFMax:
  case SPIRV::OpGroupUMax:
  case SPIRV::OpGroupSMax:
    Reqs.addCapability(SPIRV::Capability::Groups);
    break;
  case SPIRV::OpGroupNonUniformElect:
    Reqs.addCapability(SPIRV::Capability::GroupNonUniform);
    break;
  case SPIRV::OpGroupNonUniformAll:
  case SPIRV::OpGroupNonUniformAny:
  case SPIRV::OpGroupNonUniformAllEqual:
    Reqs.addCapability(SPIRV::Capability::GroupNonUniformVote);
    break;
  case SPIRV::OpGroupNonUniformBroadcast:
  case SPIRV::OpGroupNonUniformBroadcastFirst:
  case SPIRV::OpGroupNonUniformBallot:
  case SPIRV::OpGroupNonUniformInverseBallot:
  case SPIRV::OpGroupNonUniformBallotBitExtract:
  case SPIRV::OpGroupNonUniformBallotBitCount:
  case SPIRV::OpGroupNonUniformBallotFindLSB:
  case SPIRV::OpGroupNonUniformBallotFindMSB:
    Reqs.addCapability(SPIRV::Capability::GroupNonUniformBallot);
    break;
  // Extension-gated instructions below add their extension/capability pair
  // only when the subtarget allows the extension.
  case SPIRV::OpSubgroupShuffleINTEL:
  case SPIRV::OpSubgroupShuffleDownINTEL:
  case SPIRV::OpSubgroupShuffleUpINTEL:
  case SPIRV::OpSubgroupShuffleXorINTEL:
    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
      Reqs.addCapability(SPIRV::Capability::SubgroupShuffleINTEL);
    }
    break;
  case SPIRV::OpSubgroupBlockReadINTEL:
  case SPIRV::OpSubgroupBlockWriteINTEL:
    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
      Reqs.addCapability(SPIRV::Capability::SubgroupBufferBlockIOINTEL);
    }
    break;
  case SPIRV::OpSubgroupImageBlockReadINTEL:
  case SPIRV::OpSubgroupImageBlockWriteINTEL:
    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
      Reqs.addCapability(SPIRV::Capability::SubgroupImageBlockIOINTEL);
    }
    break;
  case SPIRV::OpAssumeTrueKHR:
  case SPIRV::OpExpectKHR:
    if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_expect_assume)) {
      Reqs.addExtension(SPIRV::Extension::SPV_KHR_expect_assume);
      Reqs.addCapability(SPIRV::Capability::ExpectAssumeKHR);
    }
    break;
  case SPIRV::OpPtrCastToCrossWorkgroupINTEL:
  case SPIRV::OpCrossWorkgroupCastToPtrINTEL:
    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)) {
      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes);
      Reqs.addCapability(SPIRV::Capability::USMStorageClassesINTEL);
    }
    break;
  case SPIRV::OpConstantFunctionPointerINTEL:
    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) {
      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
      Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL);
    }
    break;
  case SPIRV::OpGroupNonUniformRotateKHR:
    // Unlike the silently-skipped cases above, a missing extension here is a
    // hard error.
    if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_subgroup_rotate))
      report_fatal_error("OpGroupNonUniformRotateKHR instruction requires the "
                         "following SPIR-V extension: SPV_KHR_subgroup_rotate",
                         false);
    Reqs.addExtension(SPIRV::Extension::SPV_KHR_subgroup_rotate);
    Reqs.addCapability(SPIRV::Capability::GroupNonUniformRotateKHR);
    Reqs.addCapability(SPIRV::Capability::GroupNonUniform);
    break;
  case SPIRV::OpGroupIMulKHR:
  case SPIRV::OpGroupFMulKHR:
  case SPIRV::OpGroupBitwiseAndKHR:
  case SPIRV::OpGroupBitwiseOrKHR:
  case SPIRV::OpGroupBitwiseXorKHR:
  case SPIRV::OpGroupLogicalAndKHR:
  case SPIRV::OpGroupLogicalOrKHR:
  case SPIRV::OpGroupLogicalXorKHR:
    if (ST.canUseExtension(
            SPIRV::Extension::SPV_KHR_uniform_group_instructions)) {
      Reqs.addExtension(SPIRV::Extension::SPV_KHR_uniform_group_instructions);
      Reqs.addCapability(SPIRV::Capability::GroupUniformArithmeticKHR);
    }
    break;
  case SPIRV::OpFunctionPointerCallINTEL:
    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) {
      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
      Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL);
    }
    break;
  case SPIRV::OpAtomicFAddEXT:
  case SPIRV::OpAtomicFMinEXT:
  case SPIRV::OpAtomicFMaxEXT:
    AddAtomicFloatRequirements(MI, Reqs, ST);
    break;
  case SPIRV::OpConvertBF16ToFINTEL:
  case SPIRV::OpConvertFToBF16INTEL:
    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bfloat16_conversion)) {
      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bfloat16_conversion);
      Reqs.addCapability(SPIRV::Capability::BFloat16ConversionINTEL);
    }
    break;
  case SPIRV::OpVariableLengthArrayINTEL:
  case SPIRV::OpSaveMemoryINTEL:
  case SPIRV::OpRestoreMemoryINTEL:
    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_variable_length_array)) {
      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_variable_length_array);
      Reqs.addCapability(SPIRV::Capability::VariableLengthArrayINTEL);
    }
    break;
  default:
    break;
  }

  // If we require capability Shader, then we can remove the requirement for
  // the BitInstructions capability, since Shader is a superset capability
  // of BitInstructions.
  Reqs.removeCapabilityIf(SPIRV::Capability::BitInstructions,
                          SPIRV::Capability::Shader);
}
1149 
1150 static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI,
1151                         MachineModuleInfo *MMI, const SPIRVSubtarget &ST) {
1152   // Collect requirements for existing instructions.
1153   for (auto F = M.begin(), E = M.end(); F != E; ++F) {
1154     MachineFunction *MF = MMI->getMachineFunction(*F);
1155     if (!MF)
1156       continue;
1157     for (const MachineBasicBlock &MBB : *MF)
1158       for (const MachineInstr &MI : MBB)
1159         addInstrRequirements(MI, MAI.Reqs, ST);
1160   }
1161   // Collect requirements for OpExecutionMode instructions.
1162   auto Node = M.getNamedMetadata("spirv.ExecutionMode");
1163   if (Node) {
1164     // SPV_KHR_float_controls is not available until v1.4
1165     bool RequireFloatControls = false, VerLower14 = !ST.isAtLeastSPIRVVer(14);
1166     for (unsigned i = 0; i < Node->getNumOperands(); i++) {
1167       MDNode *MDN = cast<MDNode>(Node->getOperand(i));
1168       const MDOperand &MDOp = MDN->getOperand(1);
1169       if (auto *CMeta = dyn_cast<ConstantAsMetadata>(MDOp)) {
1170         Constant *C = CMeta->getValue();
1171         if (ConstantInt *Const = dyn_cast<ConstantInt>(C)) {
1172           auto EM = Const->getZExtValue();
1173           MAI.Reqs.getAndAddRequirements(
1174               SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
1175           // add SPV_KHR_float_controls if the version is too low
1176           switch (EM) {
1177           case SPIRV::ExecutionMode::DenormPreserve:
1178           case SPIRV::ExecutionMode::DenormFlushToZero:
1179           case SPIRV::ExecutionMode::SignedZeroInfNanPreserve:
1180           case SPIRV::ExecutionMode::RoundingModeRTE:
1181           case SPIRV::ExecutionMode::RoundingModeRTZ:
1182             RequireFloatControls = VerLower14;
1183             break;
1184           }
1185         }
1186       }
1187     }
1188     if (RequireFloatControls &&
1189         ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls))
1190       MAI.Reqs.addExtension(SPIRV::Extension::SPV_KHR_float_controls);
1191   }
1192   for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) {
1193     const Function &F = *FI;
1194     if (F.isDeclaration())
1195       continue;
1196     if (F.getMetadata("reqd_work_group_size"))
1197       MAI.Reqs.getAndAddRequirements(
1198           SPIRV::OperandCategory::ExecutionModeOperand,
1199           SPIRV::ExecutionMode::LocalSize, ST);
1200     if (F.getFnAttribute("hlsl.numthreads").isValid()) {
1201       MAI.Reqs.getAndAddRequirements(
1202           SPIRV::OperandCategory::ExecutionModeOperand,
1203           SPIRV::ExecutionMode::LocalSize, ST);
1204     }
1205     if (F.getMetadata("work_group_size_hint"))
1206       MAI.Reqs.getAndAddRequirements(
1207           SPIRV::OperandCategory::ExecutionModeOperand,
1208           SPIRV::ExecutionMode::LocalSizeHint, ST);
1209     if (F.getMetadata("intel_reqd_sub_group_size"))
1210       MAI.Reqs.getAndAddRequirements(
1211           SPIRV::OperandCategory::ExecutionModeOperand,
1212           SPIRV::ExecutionMode::SubgroupSize, ST);
1213     if (F.getMetadata("vec_type_hint"))
1214       MAI.Reqs.getAndAddRequirements(
1215           SPIRV::OperandCategory::ExecutionModeOperand,
1216           SPIRV::ExecutionMode::VecTypeHint, ST);
1217 
1218     if (F.hasOptNone() &&
1219         ST.canUseExtension(SPIRV::Extension::SPV_INTEL_optnone)) {
1220       // Output OpCapability OptNoneINTEL.
1221       MAI.Reqs.addExtension(SPIRV::Extension::SPV_INTEL_optnone);
1222       MAI.Reqs.addCapability(SPIRV::Capability::OptNoneINTEL);
1223     }
1224   }
1225 }
1226 
1227 static unsigned getFastMathFlags(const MachineInstr &I) {
1228   unsigned Flags = SPIRV::FPFastMathMode::None;
1229   if (I.getFlag(MachineInstr::MIFlag::FmNoNans))
1230     Flags |= SPIRV::FPFastMathMode::NotNaN;
1231   if (I.getFlag(MachineInstr::MIFlag::FmNoInfs))
1232     Flags |= SPIRV::FPFastMathMode::NotInf;
1233   if (I.getFlag(MachineInstr::MIFlag::FmNsz))
1234     Flags |= SPIRV::FPFastMathMode::NSZ;
1235   if (I.getFlag(MachineInstr::MIFlag::FmArcp))
1236     Flags |= SPIRV::FPFastMathMode::AllowRecip;
1237   if (I.getFlag(MachineInstr::MIFlag::FmReassoc))
1238     Flags |= SPIRV::FPFastMathMode::Fast;
1239   return Flags;
1240 }
1241 
1242 static void handleMIFlagDecoration(MachineInstr &I, const SPIRVSubtarget &ST,
1243                                    const SPIRVInstrInfo &TII,
1244                                    SPIRV::RequirementHandler &Reqs) {
1245   if (I.getFlag(MachineInstr::MIFlag::NoSWrap) && TII.canUseNSW(I) &&
1246       getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,
1247                                      SPIRV::Decoration::NoSignedWrap, ST, Reqs)
1248           .IsSatisfiable) {
1249     buildOpDecorate(I.getOperand(0).getReg(), I, TII,
1250                     SPIRV::Decoration::NoSignedWrap, {});
1251   }
1252   if (I.getFlag(MachineInstr::MIFlag::NoUWrap) && TII.canUseNUW(I) &&
1253       getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,
1254                                      SPIRV::Decoration::NoUnsignedWrap, ST,
1255                                      Reqs)
1256           .IsSatisfiable) {
1257     buildOpDecorate(I.getOperand(0).getReg(), I, TII,
1258                     SPIRV::Decoration::NoUnsignedWrap, {});
1259   }
1260   if (!TII.canUseFastMathFlags(I))
1261     return;
1262   unsigned FMFlags = getFastMathFlags(I);
1263   if (FMFlags == SPIRV::FPFastMathMode::None)
1264     return;
1265   Register DstReg = I.getOperand(0).getReg();
1266   buildOpDecorate(DstReg, I, TII, SPIRV::Decoration::FPFastMathMode, {FMFlags});
1267 }
1268 
1269 // Walk all functions and add decorations related to MI flags.
1270 static void addDecorations(const Module &M, const SPIRVInstrInfo &TII,
1271                            MachineModuleInfo *MMI, const SPIRVSubtarget &ST,
1272                            SPIRV::ModuleAnalysisInfo &MAI) {
1273   for (auto F = M.begin(), E = M.end(); F != E; ++F) {
1274     MachineFunction *MF = MMI->getMachineFunction(*F);
1275     if (!MF)
1276       continue;
1277     for (auto &MBB : *MF)
1278       for (auto &MI : MBB)
1279         handleMIFlagDecoration(MI, ST, TII, MAI.Reqs);
1280   }
1281 }
1282 
// Out-of-line definition of the pass's static analysis-results object; the
// collected module-level info is consumed by the AsmPrinter (see file header).
struct SPIRV::ModuleAnalysisInfo SPIRVModuleAnalysis::MAI;
1284 
void SPIRVModuleAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {
  // Needs the target configuration (to reach the SPIRV subtarget) and the
  // machine module info (to reach each function's MachineFunction).
  AU.addRequired<TargetPassConfig>();
  AU.addRequired<MachineModuleInfoWrapperPass>();
}
1289 
bool SPIRVModuleAnalysis::runOnModule(Module &M) {
  // Cache the target objects used throughout the analysis.
  SPIRVTargetMachine &TM =
      getAnalysis<TargetPassConfig>().getTM<SPIRVTargetMachine>();
  ST = TM.getSubtargetImpl();
  GR = ST->getSPIRVGlobalRegistry();
  TII = ST->getInstrInfo();

  MMI = &getAnalysis<MachineModuleInfoWrapperPass>().getMMI();

  setBaseInfo(M);

  // Attach OpDecorate instructions derived from MI flags (nsw/nuw/fast-math)
  // before requirements are collected, so they are accounted for below.
  addDecorations(M, *TII, MMI, *ST, MAI);

  // Gather all capability/extension requirements of the module.
  collectReqs(M, MAI, MMI, *ST);

  // Process type/const/global var/func decl instructions, number their
  // destination registers from 0 to N, collect Extensions and Capabilities.
  processDefInstrs(M);

  // Number rest of registers from N+1 onwards.
  numberRegistersGlobally(M);

  // Update references to OpFunction instructions to use Global Registers
  if (GR->hasConstFunPtr())
    collectFuncPtrs();

  // Collect OpName, OpEntryPoint, OpDecorate etc, process other instructions.
  processOtherInstrs(M);

  // If there are no entry points, we need the Linkage capability.
  if (MAI.MS[SPIRV::MB_EntryPoints].empty())
    MAI.Reqs.addCapability(SPIRV::Capability::Linkage);

  // Set maximum ID used.
  GR->setBound(MAI.MaxID);

  // Report no change to the LLVM IR Module; the work above operates on
  // machine-level instructions and the analysis-info object.
  return false;
}
1328