xref: /llvm-project/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp (revision 42633cf27bd2cfb44e9f332c33cfd6750b9d7be4)
1 //===- SPIRVModuleAnalysis.cpp - analysis of global instrs & regs - C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // The analysis collects instructions that should be output at the module level
10 // and performs the global register numbering.
11 //
12 // The results of this analysis are used in AsmPrinter to rename registers
13 // globally and to output required instructions at the module level.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "SPIRVModuleAnalysis.h"
18 #include "MCTargetDesc/SPIRVBaseInfo.h"
19 #include "MCTargetDesc/SPIRVMCTargetDesc.h"
20 #include "SPIRV.h"
21 #include "SPIRVSubtarget.h"
22 #include "SPIRVTargetMachine.h"
23 #include "SPIRVUtils.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/CodeGen/MachineModuleInfo.h"
26 #include "llvm/CodeGen/TargetPassConfig.h"
27 
28 using namespace llvm;
29 
30 #define DEBUG_TYPE "spirv-module-analysis"
31 
32 static cl::opt<bool>
33     SPVDumpDeps("spv-dump-deps",
34                 cl::desc("Dump MIR with SPIR-V dependencies info"),
35                 cl::Optional, cl::init(false));
36 
37 static cl::list<SPIRV::Capability::Capability>
38     AvoidCapabilities("avoid-spirv-capabilities",
39                       cl::desc("SPIR-V capabilities to avoid if there are "
40                                "other options enabling a feature"),
41                       cl::ZeroOrMore, cl::Hidden,
42                       cl::values(clEnumValN(SPIRV::Capability::Shader, "Shader",
43                                             "SPIR-V Shader capability")));
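// Illustrative use of the flag above (hypothetical invocation; the exact llc
// options depend on the build): passing --avoid-spirv-capabilities=Shader
// steers the selection logic below away from the Shader capability whenever an
// alternative capability can enable the same feature, e.g.:
//   llc -mtriple=spirv64-unknown-unknown --avoid-spirv-capabilities=Shader in.ll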
44 // Use a set rather than cl::list directly so that a "contains" check is cheap.
45 struct AvoidCapabilitiesSet {
46   SmallSet<SPIRV::Capability::Capability, 4> S;
47   AvoidCapabilitiesSet() {
48     for (auto Cap : AvoidCapabilities)
49       S.insert(Cap);
50   }
51 };
52 
53 char llvm::SPIRVModuleAnalysis::ID = 0;
54 
55 namespace llvm {
56 void initializeSPIRVModuleAnalysisPass(PassRegistry &);
57 } // namespace llvm
58 
59 INITIALIZE_PASS(SPIRVModuleAnalysis, DEBUG_TYPE, "SPIRV module analysis", true,
60                 true)
61 
62 // Retrieve an unsigned from an MDNode with a list of them as operands.
63 static unsigned getMetadataUInt(MDNode *MdNode, unsigned OpIndex,
64                                 unsigned DefaultVal = 0) {
65   if (MdNode && OpIndex < MdNode->getNumOperands()) {
66     const auto &Op = MdNode->getOperand(OpIndex);
67     return mdconst::extract<ConstantInt>(Op)->getZExtValue();
68   }
69   return DefaultVal;
70 }
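
// For illustration: the "opencl.ocl.version" named metadata handled later in
// this file typically carries a node such as !{i32 2, i32 0}, for which
// getMetadataUInt(MD, 0) returns 2 and getMetadataUInt(MD, 1) returns 0.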
71 
72 static SPIRV::Requirements
73 getSymbolicOperandRequirements(SPIRV::OperandCategory::OperandCategory Category,
74                                unsigned i, const SPIRVSubtarget &ST,
75                                SPIRV::RequirementHandler &Reqs) {
76   static AvoidCapabilitiesSet
77       AvoidCaps; // contains capabilities to avoid if there is another option
78 
79   VersionTuple ReqMinVer = getSymbolicOperandMinVersion(Category, i);
80   VersionTuple ReqMaxVer = getSymbolicOperandMaxVersion(Category, i);
81   VersionTuple SPIRVVersion = ST.getSPIRVVersion();
82   bool MinVerOK = SPIRVVersion.empty() || SPIRVVersion >= ReqMinVer;
83   bool MaxVerOK =
84       ReqMaxVer.empty() || SPIRVVersion.empty() || SPIRVVersion <= ReqMaxVer;
85   CapabilityList ReqCaps = getSymbolicOperandCapabilities(Category, i);
86   ExtensionList ReqExts = getSymbolicOperandExtensions(Category, i);
87   if (ReqCaps.empty()) {
88     if (ReqExts.empty()) {
89       if (MinVerOK && MaxVerOK)
90         return {true, {}, {}, ReqMinVer, ReqMaxVer};
91       return {false, {}, {}, VersionTuple(), VersionTuple()};
92     }
93   } else if (MinVerOK && MaxVerOK) {
94     if (ReqCaps.size() == 1) {
95       auto Cap = ReqCaps[0];
96       if (Reqs.isCapabilityAvailable(Cap))
97         return {true, {Cap}, ReqExts, ReqMinVer, ReqMaxVer};
98     } else {
99       // By SPIR-V specification: "If an instruction, enumerant, or other
100       // feature specifies multiple enabling capabilities, only one such
101       // capability needs to be declared to use the feature." However, one
102       // capability may be preferred over another. We use command line
103       // argument(s) and AvoidCapabilities to avoid selection of certain
104       // capabilities if there are other options.
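      // For example (illustrative capability names): if ReqCaps were
      // {Shader, SomeOtherCap} and Shader appeared in AvoidCapabilities, the
      // loop below would return SomeOtherCap; Shader is only selected when it
      // is the last remaining candidate.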
105       CapabilityList UseCaps;
106       for (auto Cap : ReqCaps)
107         if (Reqs.isCapabilityAvailable(Cap))
108           UseCaps.push_back(Cap);
109       for (size_t i = 0, Sz = UseCaps.size(); i < Sz; ++i) {
110         auto Cap = UseCaps[i];
111         if (i == Sz - 1 || !AvoidCaps.S.contains(Cap))
112           return {true, {Cap}, ReqExts, ReqMinVer, ReqMaxVer};
113       }
114     }
115   }
116   // If there are no capabilities, or we can't satisfy the version or
117   // capability requirements, use the list of extensions (if the subtarget
118   // can handle them all).
119   if (llvm::all_of(ReqExts, [&ST](const SPIRV::Extension::Extension &Ext) {
120         return ST.canUseExtension(Ext);
121       })) {
122     return {true,
123             {},
124             ReqExts,
125             VersionTuple(),
126             VersionTuple()}; // TODO: add versions to extensions.
127   }
128   return {false, {}, {}, VersionTuple(), VersionTuple()};
129 }
130 
131 void SPIRVModuleAnalysis::setBaseInfo(const Module &M) {
132   MAI.MaxID = 0;
133   for (int i = 0; i < SPIRV::NUM_MODULE_SECTIONS; i++)
134     MAI.MS[i].clear();
135   MAI.RegisterAliasTable.clear();
136   MAI.InstrsToDelete.clear();
137   MAI.FuncMap.clear();
138   MAI.GlobalVarList.clear();
139   MAI.ExtInstSetMap.clear();
140   MAI.Reqs.clear();
141   MAI.Reqs.initAvailableCapabilities(*ST);
142 
143   // TODO: determine memory model and source language from the configuration.
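  // Illustrative shape of the named metadata consumed below (the operands are
  // SPIR-V enum encodings; here 2/2 would mean Physical64 addressing with the
  // OpenCL memory model):
  //   !spirv.MemoryModel = !{!0}
  //   !0 = !{i32 2, i32 2}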
144   if (auto MemModel = M.getNamedMetadata("spirv.MemoryModel")) {
145     auto MemMD = MemModel->getOperand(0);
146     MAI.Addr = static_cast<SPIRV::AddressingModel::AddressingModel>(
147         getMetadataUInt(MemMD, 0));
148     MAI.Mem =
149         static_cast<SPIRV::MemoryModel::MemoryModel>(getMetadataUInt(MemMD, 1));
150   } else {
151     // TODO: Add support for VulkanMemoryModel.
152     MAI.Mem = ST->isOpenCLEnv() ? SPIRV::MemoryModel::OpenCL
153                                 : SPIRV::MemoryModel::GLSL450;
154     if (MAI.Mem == SPIRV::MemoryModel::OpenCL) {
155       unsigned PtrSize = ST->getPointerSize();
156       MAI.Addr = PtrSize == 32   ? SPIRV::AddressingModel::Physical32
157                  : PtrSize == 64 ? SPIRV::AddressingModel::Physical64
158                                  : SPIRV::AddressingModel::Logical;
159     } else {
160       // TODO: Add support for PhysicalStorageBufferAddress.
161       MAI.Addr = SPIRV::AddressingModel::Logical;
162     }
163   }
164   // Get the OpenCL version number from metadata.
165   // TODO: support other source languages.
166   if (auto VerNode = M.getNamedMetadata("opencl.ocl.version")) {
167     MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_C;
168     // Construct version literal in accordance with SPIRV-LLVM-Translator.
169     // TODO: support multiple OCL version metadata.
170     assert(VerNode->getNumOperands() > 0 && "Invalid SPIR");
171     auto VersionMD = VerNode->getOperand(0);
172     unsigned MajorNum = getMetadataUInt(VersionMD, 0, 2);
173     unsigned MinorNum = getMetadataUInt(VersionMD, 1);
174     unsigned RevNum = getMetadataUInt(VersionMD, 2);
175     // Prevent the major part of the OpenCL version from being 0.
176     MAI.SrcLangVersion =
177         (std::max(1U, MajorNum) * 100 + MinorNum) * 1000 + RevNum;
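    // For example, OpenCL C 3.0.0 metadata of !{i32 3, i32 0} yields
    // (3 * 100 + 0) * 1000 + 0 = 300000.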
178   } else {
179     // If there is no information about the OpenCL version, default to OpenCL
180     // 1.0 for the OpenCL environment to avoid confusing runtimes with an
181     // Unknown/0.0 version output. For reference, the SPIRV-LLVM Translator
182     // avoids potential runtime issues in a similar manner.
183     if (ST->isOpenCLEnv()) {
184       MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_CPP;
185       MAI.SrcLangVersion = 100000;
186     } else {
187       MAI.SrcLang = SPIRV::SourceLanguage::Unknown;
188       MAI.SrcLangVersion = 0;
189     }
190   }
191 
192   if (auto ExtNode = M.getNamedMetadata("opencl.used.extensions")) {
193     for (unsigned I = 0, E = ExtNode->getNumOperands(); I != E; ++I) {
194       MDNode *MD = ExtNode->getOperand(I);
195       if (!MD || MD->getNumOperands() == 0)
196         continue;
197       for (unsigned J = 0, N = MD->getNumOperands(); J != N; ++J)
198         MAI.SrcExt.insert(cast<MDString>(MD->getOperand(J))->getString());
199     }
200   }
201 
202   // Update required capabilities for this memory model, addressing model and
203   // source language.
204   MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand,
205                                  MAI.Mem, *ST);
206   MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::SourceLanguageOperand,
207                                  MAI.SrcLang, *ST);
208   MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,
209                                  MAI.Addr, *ST);
210 
211   if (ST->isOpenCLEnv()) {
212     // TODO: check if it's required by default.
213     MAI.ExtInstSetMap[static_cast<unsigned>(
214         SPIRV::InstructionSet::OpenCL_std)] =
215         Register::index2VirtReg(MAI.getNextID());
216   }
217 }
218 
219 // Collect MI which defines the register in the given machine function.
220 static void collectDefInstr(Register Reg, const MachineFunction *MF,
221                             SPIRV::ModuleAnalysisInfo *MAI,
222                             SPIRV::ModuleSectionType MSType,
223                             bool DoInsert = true) {
224   assert(MAI->hasRegisterAlias(MF, Reg) && "Cannot find register alias");
225   MachineInstr *MI = MF->getRegInfo().getUniqueVRegDef(Reg);
226   assert(MI && "There should be an instruction that defines the register");
227   MAI->setSkipEmission(MI);
228   if (DoInsert)
229     MAI->MS[MSType].push_back(MI);
230 }
231 
232 void SPIRVModuleAnalysis::collectGlobalEntities(
233     const std::vector<SPIRV::DTSortableEntry *> &DepsGraph,
234     SPIRV::ModuleSectionType MSType,
235     std::function<bool(const SPIRV::DTSortableEntry *)> Pred,
236     bool UsePreOrder = false) {
237   DenseSet<const SPIRV::DTSortableEntry *> Visited;
238   for (const auto *E : DepsGraph) {
239     std::function<void(const SPIRV::DTSortableEntry *)> RecHoistUtil;
240     // NOTE: here we prefer a recursive approach over an iterative one because
241     // we don't expect dependency chains long enough to cause a stack overflow.
242     RecHoistUtil = [MSType, UsePreOrder, &Visited, &Pred,
243                     &RecHoistUtil](const SPIRV::DTSortableEntry *E) {
244       if (Visited.count(E) || !Pred(E))
245         return;
246       Visited.insert(E);
247 
248       // Traversing the deps graph in post-order lets us avoid a separate
249       // register-alias preprocessing step.
250       // But pre-order is required for correct processing of function
251       // declarations and their arguments.
252       if (!UsePreOrder)
253         for (auto *S : E->getDeps())
254           RecHoistUtil(S);
255 
256       Register GlobalReg = Register::index2VirtReg(MAI.getNextID());
257       bool IsFirst = true;
258       for (auto &U : *E) {
259         const MachineFunction *MF = U.first;
260         Register Reg = U.second;
261         MAI.setRegisterAlias(MF, Reg, GlobalReg);
262         if (!MF->getRegInfo().getUniqueVRegDef(Reg))
263           continue;
264         collectDefInstr(Reg, MF, &MAI, MSType, IsFirst);
265         IsFirst = false;
266         if (E->getIsGV())
267           MAI.GlobalVarList.push_back(MF->getRegInfo().getUniqueVRegDef(Reg));
268       }
269 
270       if (UsePreOrder)
271         for (auto *S : E->getDeps())
272           RecHoistUtil(S);
273     };
274     RecHoistUtil(E);
275   }
276 }
277 
278 // The function initializes the global register alias table for types, consts,
279 // global vars and func decls, and collects these instructions for output
280 // at the module level. It also collects explicit OpExtension/OpCapability
281 // instructions.
282 void SPIRVModuleAnalysis::processDefInstrs(const Module &M) {
283   std::vector<SPIRV::DTSortableEntry *> DepsGraph;
284 
285   GR->buildDepsGraph(DepsGraph, TII, SPVDumpDeps ? MMI : nullptr);
286 
287   collectGlobalEntities(
288       DepsGraph, SPIRV::MB_TypeConstVars,
289       [](const SPIRV::DTSortableEntry *E) { return !E->getIsFunc(); });
290 
291   for (auto F = M.begin(), E = M.end(); F != E; ++F) {
292     MachineFunction *MF = MMI->getMachineFunction(*F);
293     if (!MF)
294       continue;
295     // Iterate through and collect OpExtension/OpCapability instructions.
296     for (MachineBasicBlock &MBB : *MF) {
297       for (MachineInstr &MI : MBB) {
298         if (MI.getOpcode() == SPIRV::OpExtension) {
299           // Here, OpExtension just has a single enum operand, not a string.
300           auto Ext = SPIRV::Extension::Extension(MI.getOperand(0).getImm());
301           MAI.Reqs.addExtension(Ext);
302           MAI.setSkipEmission(&MI);
303         } else if (MI.getOpcode() == SPIRV::OpCapability) {
304           auto Cap = SPIRV::Capability::Capability(MI.getOperand(0).getImm());
305           MAI.Reqs.addCapability(Cap);
306           MAI.setSkipEmission(&MI);
307         }
308       }
309     }
310   }
311 
312   collectGlobalEntities(
313       DepsGraph, SPIRV::MB_ExtFuncDecls,
314       [](const SPIRV::DTSortableEntry *E) { return E->getIsFunc(); }, true);
315 }
316 
317 // Look for IDs declared with Import linkage, and map the corresponding function
318 // to the register defining that ID (which will usually be the result of
319 // an OpFunction). This lets us call externally imported functions using
320 // the correct ID registers.
321 void SPIRVModuleAnalysis::collectFuncNames(MachineInstr &MI,
322                                            const Function *F) {
323   if (MI.getOpcode() == SPIRV::OpDecorate) {
324     // If it's got Import linkage.
325     auto Dec = MI.getOperand(1).getImm();
326     if (Dec == static_cast<unsigned>(SPIRV::Decoration::LinkageAttributes)) {
327       auto Lnk = MI.getOperand(MI.getNumOperands() - 1).getImm();
328       if (Lnk == static_cast<unsigned>(SPIRV::LinkageType::Import)) {
329         // Map imported function name to function ID register.
330         const Function *ImportedFunc =
331             F->getParent()->getFunction(getStringImm(MI, 2));
332         Register Target = MI.getOperand(0).getReg();
333         MAI.FuncMap[ImportedFunc] = MAI.getRegisterAlias(MI.getMF(), Target);
334       }
335     }
336   } else if (MI.getOpcode() == SPIRV::OpFunction) {
337     // Record all internal OpFunction declarations.
338     Register Reg = MI.defs().begin()->getReg();
339     Register GlobalReg = MAI.getRegisterAlias(MI.getMF(), Reg);
340     assert(GlobalReg.isValid());
341     MAI.FuncMap[F] = GlobalReg;
342   }
343 }
344 
345 // References to a function via function pointers generate virtual
346 // registers without a definition. We resolve such a reference using
347 // Global Register info to the defining OpFunction instruction and
348 // replace dummy operands with the corresponding global register references.
349 void SPIRVModuleAnalysis::collectFuncPtrs() {
350   for (auto &MI : MAI.MS[SPIRV::MB_TypeConstVars])
351     if (MI->getOpcode() == SPIRV::OpConstantFunctionPointerINTEL)
352       collectFuncPtrs(MI);
353 }
354 
355 void SPIRVModuleAnalysis::collectFuncPtrs(MachineInstr *MI) {
356   const MachineOperand *FunUse = &MI->getOperand(2);
357   if (const MachineOperand *FunDef = GR->getFunctionDefinitionByUse(FunUse)) {
358     const MachineInstr *FunDefMI = FunDef->getParent();
359     assert(FunDefMI->getOpcode() == SPIRV::OpFunction &&
360            "Constant function pointer must refer to function definition");
361     Register FunDefReg = FunDef->getReg();
362     Register GlobalFunDefReg =
363         MAI.getRegisterAlias(FunDefMI->getMF(), FunDefReg);
364     assert(GlobalFunDefReg.isValid() &&
365            "Function definition must refer to a global register");
366     Register FunPtrReg = FunUse->getReg();
367     MAI.setRegisterAlias(MI->getMF(), FunPtrReg, GlobalFunDefReg);
368   }
369 }
370 
371 using InstrSignature = SmallVector<size_t>;
372 using InstrTraces = std::set<InstrSignature>;
373 
374 // Returns a representation of an instruction as a vector of MachineOperand
375 // hash values, see llvm::hash_value(const MachineOperand &MO) for details.
376 // This creates a signature of the instruction with the same content
377 // that MachineOperand::isIdenticalTo uses for comparison.
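// Register operands are hashed via their global alias so that semantically
// identical instructions coming from different functions (and therefore using
// different local vregs) produce the same signature and can be deduplicated.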
378 static InstrSignature instrToSignature(MachineInstr &MI,
379                                        SPIRV::ModuleAnalysisInfo &MAI) {
380   InstrSignature Signature;
381   for (unsigned i = 0; i < MI.getNumOperands(); ++i) {
382     const MachineOperand &MO = MI.getOperand(i);
383     size_t h;
384     if (MO.isReg()) {
385       Register RegAlias = MAI.getRegisterAlias(MI.getMF(), MO.getReg());
386       // mimic llvm::hash_value(const MachineOperand &MO)
387       h = hash_combine(MO.getType(), (unsigned)RegAlias, MO.getSubReg(),
388                        MO.isDef());
389     } else {
390       h = hash_value(MO);
391     }
392     Signature.push_back(h);
393   }
394   return Signature;
395 }
396 
397 // Collect the given instruction in the specified MS. We assume global register
398 // numbering has already occurred by this point. We can directly compare reg
399 // arguments when detecting duplicates.
400 static void collectOtherInstr(MachineInstr &MI, SPIRV::ModuleAnalysisInfo &MAI,
401                               SPIRV::ModuleSectionType MSType, InstrTraces &IS,
402                               bool Append = true) {
403   MAI.setSkipEmission(&MI);
404   InstrSignature MISign = instrToSignature(MI, MAI);
405   auto FoundMI = IS.insert(MISign);
406   if (!FoundMI.second)
407     return; // insert failed, so we found a duplicate; don't add it to MAI.MS
408   // No duplicates, so add it.
409   if (Append)
410     MAI.MS[MSType].push_back(&MI);
411   else
412     MAI.MS[MSType].insert(MAI.MS[MSType].begin(), &MI);
413 }
414 
415 // Some global instructions make reference to function-local ID regs, so cannot
416 // be correctly collected until these registers are globally numbered.
417 void SPIRVModuleAnalysis::processOtherInstrs(const Module &M) {
418   InstrTraces IS;
419   for (auto F = M.begin(), E = M.end(); F != E; ++F) {
420     if ((*F).isDeclaration())
421       continue;
422     MachineFunction *MF = MMI->getMachineFunction(*F);
423     assert(MF);
424 
425     for (MachineBasicBlock &MBB : *MF)
426       for (MachineInstr &MI : MBB) {
427         if (MAI.getSkipEmission(&MI))
428           continue;
429         const unsigned OpCode = MI.getOpcode();
430         if (OpCode == SPIRV::OpString) {
431           collectOtherInstr(MI, MAI, SPIRV::MB_DebugStrings, IS);
432         } else if (OpCode == SPIRV::OpExtInst && MI.getOperand(2).isImm() &&
433                    MI.getOperand(2).getImm() ==
434                        SPIRV::InstructionSet::
435                            NonSemantic_Shader_DebugInfo_100) {
436           MachineOperand Ins = MI.getOperand(3);
437           namespace NS = SPIRV::NonSemanticExtInst;
438           static constexpr int64_t GlobalNonSemanticDITy[] = {
439               NS::DebugSource, NS::DebugCompilationUnit, NS::DebugInfoNone,
440               NS::DebugTypeBasic, NS::DebugTypePointer};
441           bool IsGlobalDI = false;
442           for (unsigned Idx = 0; Idx < std::size(GlobalNonSemanticDITy); ++Idx)
443             IsGlobalDI |= Ins.getImm() == GlobalNonSemanticDITy[Idx];
444           if (IsGlobalDI)
445             collectOtherInstr(MI, MAI, SPIRV::MB_NonSemanticGlobalDI, IS);
446         } else if (OpCode == SPIRV::OpName || OpCode == SPIRV::OpMemberName) {
447           collectOtherInstr(MI, MAI, SPIRV::MB_DebugNames, IS);
448         } else if (OpCode == SPIRV::OpEntryPoint) {
449           collectOtherInstr(MI, MAI, SPIRV::MB_EntryPoints, IS);
450         } else if (TII->isDecorationInstr(MI)) {
451           collectOtherInstr(MI, MAI, SPIRV::MB_Annotations, IS);
452           collectFuncNames(MI, &*F);
453         } else if (TII->isConstantInstr(MI)) {
454           // OpSpecConstant*s are currently not in the DT,
455           // but they still need to be collected.
456           collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, IS);
457         } else if (OpCode == SPIRV::OpFunction) {
458           collectFuncNames(MI, &*F);
459         } else if (OpCode == SPIRV::OpTypeForwardPointer) {
460           collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, IS, false);
461         }
462       }
463   }
464 }
465 
466 // Number registers in all functions globally from 0 onwards and store
467 // the result in the global register alias table. Some registers are already
468 // numbered in collectGlobalEntities.
469 void SPIRVModuleAnalysis::numberRegistersGlobally(const Module &M) {
470   for (auto F = M.begin(), E = M.end(); F != E; ++F) {
471     if ((*F).isDeclaration())
472       continue;
473     MachineFunction *MF = MMI->getMachineFunction(*F);
474     assert(MF);
475     for (MachineBasicBlock &MBB : *MF) {
476       for (MachineInstr &MI : MBB) {
477         for (MachineOperand &Op : MI.operands()) {
478           if (!Op.isReg())
479             continue;
480           Register Reg = Op.getReg();
481           if (MAI.hasRegisterAlias(MF, Reg))
482             continue;
483           Register NewReg = Register::index2VirtReg(MAI.getNextID());
484           MAI.setRegisterAlias(MF, Reg, NewReg);
485         }
486         if (MI.getOpcode() != SPIRV::OpExtInst)
487           continue;
488         auto Set = MI.getOperand(2).getImm();
489         if (!MAI.ExtInstSetMap.contains(Set))
490           MAI.ExtInstSetMap[Set] = Register::index2VirtReg(MAI.getNextID());
491       }
492     }
493   }
494 }
495 
496 // RequirementHandler implementations.
497 void SPIRV::RequirementHandler::getAndAddRequirements(
498     SPIRV::OperandCategory::OperandCategory Category, uint32_t i,
499     const SPIRVSubtarget &ST) {
500   addRequirements(getSymbolicOperandRequirements(Category, i, ST, *this));
501 }
502 
503 void SPIRV::RequirementHandler::recursiveAddCapabilities(
504     const CapabilityList &ToPrune) {
505   for (const auto &Cap : ToPrune) {
506     AllCaps.insert(Cap);
507     CapabilityList ImplicitDecls =
508         getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);
509     recursiveAddCapabilities(ImplicitDecls);
510   }
511 }
512 
513 void SPIRV::RequirementHandler::addCapabilities(const CapabilityList &ToAdd) {
514   for (const auto &Cap : ToAdd) {
515     bool IsNewlyInserted = AllCaps.insert(Cap).second;
516     if (!IsNewlyInserted) // Don't re-add if it's already been declared.
517       continue;
518     CapabilityList ImplicitDecls =
519         getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);
520     recursiveAddCapabilities(ImplicitDecls);
521     MinimalCaps.push_back(Cap);
522   }
523 }
524 
525 void SPIRV::RequirementHandler::addRequirements(
526     const SPIRV::Requirements &Req) {
527   if (!Req.IsSatisfiable)
528     report_fatal_error("Adding SPIR-V requirements this target can't satisfy.");
529 
530   if (Req.Cap.has_value())
531     addCapabilities({Req.Cap.value()});
532 
533   addExtensions(Req.Exts);
534 
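  // The checks below maintain a single [MinVersion, MaxVersion] window: a
  // requirement such as MinVer = 1.3 against an already recorded MaxVersion of
  // 1.1 cannot be satisfied and aborts, otherwise the window is narrowed.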
535   if (!Req.MinVer.empty()) {
536     if (!MaxVersion.empty() && Req.MinVer > MaxVersion) {
537       LLVM_DEBUG(dbgs() << "Conflicting version requirements: >= " << Req.MinVer
538                         << " and <= " << MaxVersion << "\n");
539       report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");
540     }
541 
542     if (MinVersion.empty() || Req.MinVer > MinVersion)
543       MinVersion = Req.MinVer;
544   }
545 
546   if (!Req.MaxVer.empty()) {
547     if (!MinVersion.empty() && Req.MaxVer < MinVersion) {
548       LLVM_DEBUG(dbgs() << "Conflicting version requirements: <= " << Req.MaxVer
549                         << " and >= " << MinVersion << "\n");
550       report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");
551     }
552 
553     if (MaxVersion.empty() || Req.MaxVer < MaxVersion)
554       MaxVersion = Req.MaxVer;
555   }
556 }
557 
558 void SPIRV::RequirementHandler::checkSatisfiable(
559     const SPIRVSubtarget &ST) const {
560   // Report as many errors as possible before aborting the compilation.
561   bool IsSatisfiable = true;
562   auto TargetVer = ST.getSPIRVVersion();
563 
564   if (!MaxVersion.empty() && !TargetVer.empty() && MaxVersion < TargetVer) {
565     LLVM_DEBUG(
566         dbgs() << "Target SPIR-V version too high for required features\n"
567                << "Required max version: " << MaxVersion << " target version "
568                << TargetVer << "\n");
569     IsSatisfiable = false;
570   }
571 
572   if (!MinVersion.empty() && !TargetVer.empty() && MinVersion > TargetVer) {
573     LLVM_DEBUG(dbgs() << "Target SPIR-V version too low for required features\n"
574                       << "Required min version: " << MinVersion
575                       << " target version " << TargetVer << "\n");
576     IsSatisfiable = false;
577   }
578 
579   if (!MinVersion.empty() && !MaxVersion.empty() && MinVersion > MaxVersion) {
580     LLVM_DEBUG(
581         dbgs()
582         << "Version is too low for some features and too high for others.\n"
583         << "Required SPIR-V min version: " << MinVersion
584         << " required SPIR-V max version " << MaxVersion << "\n");
585     IsSatisfiable = false;
586   }
587 
588   for (auto Cap : MinimalCaps) {
589     if (AvailableCaps.contains(Cap))
590       continue;
591     LLVM_DEBUG(dbgs() << "Capability not supported: "
592                       << getSymbolicOperandMnemonic(
593                              OperandCategory::CapabilityOperand, Cap)
594                       << "\n");
595     IsSatisfiable = false;
596   }
597 
598   for (auto Ext : AllExtensions) {
599     if (ST.canUseExtension(Ext))
600       continue;
601     LLVM_DEBUG(dbgs() << "Extension not supported: "
602                       << getSymbolicOperandMnemonic(
603                              OperandCategory::ExtensionOperand, Ext)
604                       << "\n");
605     IsSatisfiable = false;
606   }
607 
608   if (!IsSatisfiable)
609     report_fatal_error("Unable to meet SPIR-V requirements for this target.");
610 }
611 
612 // Add the given capabilities and all their implicitly defined capabilities too.
613 void SPIRV::RequirementHandler::addAvailableCaps(const CapabilityList &ToAdd) {
614   for (const auto Cap : ToAdd)
615     if (AvailableCaps.insert(Cap).second)
616       addAvailableCaps(getSymbolicOperandCapabilities(
617           SPIRV::OperandCategory::CapabilityOperand, Cap));
618 }
619 
620 void SPIRV::RequirementHandler::removeCapabilityIf(
621     const Capability::Capability ToRemove,
622     const Capability::Capability IfPresent) {
623   if (AllCaps.contains(IfPresent))
624     AllCaps.erase(ToRemove);
625 }
626 
627 namespace llvm {
628 namespace SPIRV {
629 void RequirementHandler::initAvailableCapabilities(const SPIRVSubtarget &ST) {
630   // Provided both by all supported Vulkan versions and by OpenCL.
631   addAvailableCaps({Capability::Shader, Capability::Linkage, Capability::Int8,
632                     Capability::Int16});
633 
634   if (ST.isAtLeastSPIRVVer(VersionTuple(1, 3)))
635     addAvailableCaps({Capability::GroupNonUniform,
636                       Capability::GroupNonUniformVote,
637                       Capability::GroupNonUniformArithmetic,
638                       Capability::GroupNonUniformBallot,
639                       Capability::GroupNonUniformClustered,
640                       Capability::GroupNonUniformShuffle,
641                       Capability::GroupNonUniformShuffleRelative});
642 
643   if (ST.isAtLeastSPIRVVer(VersionTuple(1, 6)))
644     addAvailableCaps({Capability::DotProduct, Capability::DotProductInputAll,
645                       Capability::DotProductInput4x8Bit,
646                       Capability::DotProductInput4x8BitPacked,
647                       Capability::DemoteToHelperInvocation});
648 
649   // Add capabilities enabled by extensions.
650   for (auto Extension : ST.getAllAvailableExtensions()) {
651     CapabilityList EnabledCapabilities =
652         getCapabilitiesEnabledByExtension(Extension);
653     addAvailableCaps(EnabledCapabilities);
654   }
655 
656   if (ST.isOpenCLEnv()) {
657     initAvailableCapabilitiesForOpenCL(ST);
658     return;
659   }
660 
661   if (ST.isVulkanEnv()) {
662     initAvailableCapabilitiesForVulkan(ST);
663     return;
664   }
665 
666   report_fatal_error("Unimplemented environment for SPIR-V generation.");
667 }
668 
669 void RequirementHandler::initAvailableCapabilitiesForOpenCL(
670     const SPIRVSubtarget &ST) {
671   // Add the min requirements for different OpenCL and SPIR-V versions.
672   addAvailableCaps({Capability::Addresses, Capability::Float16Buffer,
673                     Capability::Kernel, Capability::Vector16,
674                     Capability::Groups, Capability::GenericPointer,
675                     Capability::StorageImageWriteWithoutFormat,
676                     Capability::StorageImageReadWithoutFormat});
677   if (ST.hasOpenCLFullProfile())
678     addAvailableCaps({Capability::Int64, Capability::Int64Atomics});
679   if (ST.hasOpenCLImageSupport()) {
680     addAvailableCaps({Capability::ImageBasic, Capability::LiteralSampler,
681                       Capability::Image1D, Capability::SampledBuffer,
682                       Capability::ImageBuffer});
683     if (ST.isAtLeastOpenCLVer(VersionTuple(2, 0)))
684       addAvailableCaps({Capability::ImageReadWrite});
685   }
686   if (ST.isAtLeastSPIRVVer(VersionTuple(1, 1)) &&
687       ST.isAtLeastOpenCLVer(VersionTuple(2, 2)))
688     addAvailableCaps({Capability::SubgroupDispatch, Capability::PipeStorage});
689   if (ST.isAtLeastSPIRVVer(VersionTuple(1, 4)))
690     addAvailableCaps({Capability::DenormPreserve, Capability::DenormFlushToZero,
691                       Capability::SignedZeroInfNanPreserve,
692                       Capability::RoundingModeRTE,
693                       Capability::RoundingModeRTZ});
694   // TODO: verify if this needs some checks.
695   addAvailableCaps({Capability::Float16, Capability::Float64});
696 
697   // TODO: add OpenCL extensions.
698 }
699 
700 void RequirementHandler::initAvailableCapabilitiesForVulkan(
701     const SPIRVSubtarget &ST) {
702 
703   // Core in Vulkan 1.1 and earlier.
704   addAvailableCaps({Capability::Int64, Capability::Float16, Capability::Float64,
705                     Capability::GroupNonUniform, Capability::Image1D,
706                     Capability::SampledBuffer, Capability::ImageBuffer,
707                     Capability::UniformBufferArrayDynamicIndexing,
708                     Capability::SampledImageArrayDynamicIndexing,
709                     Capability::StorageBufferArrayDynamicIndexing,
710                     Capability::StorageImageArrayDynamicIndexing});
711 
712   // Became core in Vulkan 1.2
713   if (ST.isAtLeastSPIRVVer(VersionTuple(1, 5))) {
714     addAvailableCaps(
715         {Capability::ShaderNonUniformEXT, Capability::RuntimeDescriptorArrayEXT,
716          Capability::InputAttachmentArrayDynamicIndexingEXT,
717          Capability::UniformTexelBufferArrayDynamicIndexingEXT,
718          Capability::StorageTexelBufferArrayDynamicIndexingEXT,
719          Capability::UniformBufferArrayNonUniformIndexingEXT,
720          Capability::SampledImageArrayNonUniformIndexingEXT,
721          Capability::StorageBufferArrayNonUniformIndexingEXT,
722          Capability::StorageImageArrayNonUniformIndexingEXT,
723          Capability::InputAttachmentArrayNonUniformIndexingEXT,
724          Capability::UniformTexelBufferArrayNonUniformIndexingEXT,
725          Capability::StorageTexelBufferArrayNonUniformIndexingEXT});
726   }
727 
728   // Became core in Vulkan 1.3
729   if (ST.isAtLeastSPIRVVer(VersionTuple(1, 6)))
730     addAvailableCaps({Capability::StorageImageWriteWithoutFormat,
731                       Capability::StorageImageReadWithoutFormat});
732 }
733 
734 } // namespace SPIRV
735 } // namespace llvm
736 
737 // Add the required capabilities from a decoration instruction (including
738 // BuiltIns).
739 static void addOpDecorateReqs(const MachineInstr &MI, unsigned DecIndex,
740                               SPIRV::RequirementHandler &Reqs,
741                               const SPIRVSubtarget &ST) {
742   int64_t DecOp = MI.getOperand(DecIndex).getImm();
743   auto Dec = static_cast<SPIRV::Decoration::Decoration>(DecOp);
744   Reqs.addRequirements(getSymbolicOperandRequirements(
745       SPIRV::OperandCategory::DecorationOperand, Dec, ST, Reqs));
746 
747   if (Dec == SPIRV::Decoration::BuiltIn) {
748     int64_t BuiltInOp = MI.getOperand(DecIndex + 1).getImm();
749     auto BuiltIn = static_cast<SPIRV::BuiltIn::BuiltIn>(BuiltInOp);
750     Reqs.addRequirements(getSymbolicOperandRequirements(
751         SPIRV::OperandCategory::BuiltInOperand, BuiltIn, ST, Reqs));
752   } else if (Dec == SPIRV::Decoration::LinkageAttributes) {
753     int64_t LinkageOp = MI.getOperand(MI.getNumOperands() - 1).getImm();
754     SPIRV::LinkageType::LinkageType LnkType =
755         static_cast<SPIRV::LinkageType::LinkageType>(LinkageOp);
756     if (LnkType == SPIRV::LinkageType::LinkOnceODR)
757       Reqs.addExtension(SPIRV::Extension::SPV_KHR_linkonce_odr);
758   } else if (Dec == SPIRV::Decoration::CacheControlLoadINTEL ||
759              Dec == SPIRV::Decoration::CacheControlStoreINTEL) {
760     Reqs.addExtension(SPIRV::Extension::SPV_INTEL_cache_controls);
761   } else if (Dec == SPIRV::Decoration::HostAccessINTEL) {
762     Reqs.addExtension(SPIRV::Extension::SPV_INTEL_global_variable_host_access);
763   } else if (Dec == SPIRV::Decoration::InitModeINTEL ||
764              Dec == SPIRV::Decoration::ImplementInRegisterMapINTEL) {
765     Reqs.addExtension(
766         SPIRV::Extension::SPV_INTEL_global_variable_fpga_decorations);
767   } else if (Dec == SPIRV::Decoration::NonUniformEXT) {
768     Reqs.addRequirements(SPIRV::Capability::ShaderNonUniformEXT);
769   }
770 }
771 
772 // Add requirements for image handling.
773 static void addOpTypeImageReqs(const MachineInstr &MI,
774                                SPIRV::RequirementHandler &Reqs,
775                                const SPIRVSubtarget &ST) {
776   assert(MI.getNumOperands() >= 8 && "Insufficient operands for OpTypeImage");
777   // The operand indices used here are based on the OpTypeImage layout, which
778   // the MachineInstr follows as well.
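  // For reference, the operand order used below is: 0 result, 1 sampled type,
  // 2 Dim, 3 Depth, 4 Arrayed, 5 MS, 6 Sampled, 7 Image Format and, optionally,
  // 8 Access Qualifier.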
779   int64_t ImgFormatOp = MI.getOperand(7).getImm();
780   auto ImgFormat = static_cast<SPIRV::ImageFormat::ImageFormat>(ImgFormatOp);
781   Reqs.getAndAddRequirements(SPIRV::OperandCategory::ImageFormatOperand,
782                              ImgFormat, ST);
783 
784   bool IsArrayed = MI.getOperand(4).getImm() == 1;
785   bool IsMultisampled = MI.getOperand(5).getImm() == 1;
786   bool NoSampler = MI.getOperand(6).getImm() == 2;
787   // Add dimension requirements.
788   assert(MI.getOperand(2).isImm());
789   switch (MI.getOperand(2).getImm()) {
790   case SPIRV::Dim::DIM_1D:
791     Reqs.addRequirements(NoSampler ? SPIRV::Capability::Image1D
792                                    : SPIRV::Capability::Sampled1D);
793     break;
794   case SPIRV::Dim::DIM_2D:
795     if (IsMultisampled && NoSampler)
796       Reqs.addRequirements(SPIRV::Capability::ImageMSArray);
797     break;
798   case SPIRV::Dim::DIM_Cube:
799     Reqs.addRequirements(SPIRV::Capability::Shader);
800     if (IsArrayed)
801       Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageCubeArray
802                                      : SPIRV::Capability::SampledCubeArray);
803     break;
804   case SPIRV::Dim::DIM_Rect:
805     Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageRect
806                                    : SPIRV::Capability::SampledRect);
807     break;
808   case SPIRV::Dim::DIM_Buffer:
809     Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageBuffer
810                                    : SPIRV::Capability::SampledBuffer);
811     break;
812   case SPIRV::Dim::DIM_SubpassData:
813     Reqs.addRequirements(SPIRV::Capability::InputAttachment);
814     break;
815   }
816 
817   // Has optional access qualifier.
818   if (ST.isOpenCLEnv()) {
819     if (MI.getNumOperands() > 8 &&
820         MI.getOperand(8).getImm() == SPIRV::AccessQualifier::ReadWrite)
821       Reqs.addRequirements(SPIRV::Capability::ImageReadWrite);
822     else
823       Reqs.addRequirements(SPIRV::Capability::ImageBasic);
824   }
825 }
826 
827 // Add requirements for handling atomic float instructions
828 #define ATOM_FLT_REQ_EXT_MSG(ExtName)                                          \
829   "The atomic float instruction requires the following SPIR-V "                \
830   "extension: SPV_EXT_shader_atomic_float" ExtName
831 static void AddAtomicFloatRequirements(const MachineInstr &MI,
832                                        SPIRV::RequirementHandler &Reqs,
833                                        const SPIRVSubtarget &ST) {
834   assert(MI.getOperand(1).isReg() &&
835          "Expect register operand in atomic float instruction");
836   Register TypeReg = MI.getOperand(1).getReg();
837   SPIRVType *TypeDef = MI.getMF()->getRegInfo().getVRegDef(TypeReg);
838   if (TypeDef->getOpcode() != SPIRV::OpTypeFloat)
839     report_fatal_error("Result type of an atomic float instruction must be a "
840                        "floating-point type scalar");
841 
842   unsigned BitWidth = TypeDef->getOperand(1).getImm();
843   unsigned Op = MI.getOpcode();
844   if (Op == SPIRV::OpAtomicFAddEXT) {
845     if (!ST.canUseExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add))
846       report_fatal_error(ATOM_FLT_REQ_EXT_MSG("_add"), false);
847     Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add);
848     switch (BitWidth) {
849     case 16:
850       if (!ST.canUseExtension(
851               SPIRV::Extension::SPV_EXT_shader_atomic_float16_add))
852         report_fatal_error(ATOM_FLT_REQ_EXT_MSG("16_add"), false);
853       Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float16_add);
854       Reqs.addCapability(SPIRV::Capability::AtomicFloat16AddEXT);
855       break;
856     case 32:
857       Reqs.addCapability(SPIRV::Capability::AtomicFloat32AddEXT);
858       break;
859     case 64:
860       Reqs.addCapability(SPIRV::Capability::AtomicFloat64AddEXT);
861       break;
862     default:
863       report_fatal_error(
864           "Unexpected floating-point type width in atomic float instruction");
865     }
866   } else {
867     if (!ST.canUseExtension(
868             SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max))
869       report_fatal_error(ATOM_FLT_REQ_EXT_MSG("_min_max"), false);
870     Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max);
871     switch (BitWidth) {
872     case 16:
873       Reqs.addCapability(SPIRV::Capability::AtomicFloat16MinMaxEXT);
874       break;
875     case 32:
876       Reqs.addCapability(SPIRV::Capability::AtomicFloat32MinMaxEXT);
877       break;
878     case 64:
879       Reqs.addCapability(SPIRV::Capability::AtomicFloat64MinMaxEXT);
880       break;
881     default:
882       report_fatal_error(
883           "Unexpected floating-point type width in atomic float instruction");
884     }
885   }
886 }
887 
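// The helpers below classify image-related types by the Dim (operand 2) and
// Sampled (operand 6) fields of OpTypeImage: per the SPIR-V specification,
// Sampled == 1 means the image is used with a sampler and Sampled == 2 means
// it is read from or written to without one.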
888 bool isUniformTexelBuffer(MachineInstr *ImageInst) {
889   if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
890     return false;
891   uint32_t Dim = ImageInst->getOperand(2).getImm();
892   uint32_t Sampled = ImageInst->getOperand(6).getImm();
893   return Dim == SPIRV::Dim::DIM_Buffer && Sampled == 1;
894 }
895 
896 bool isStorageTexelBuffer(MachineInstr *ImageInst) {
897   if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
898     return false;
899   uint32_t Dim = ImageInst->getOperand(2).getImm();
900   uint32_t Sampled = ImageInst->getOperand(6).getImm();
901   return Dim == SPIRV::Dim::DIM_Buffer && Sampled == 2;
902 }
903 
904 bool isSampledImage(MachineInstr *ImageInst) {
905   if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
906     return false;
907   uint32_t Dim = ImageInst->getOperand(2).getImm();
908   uint32_t Sampled = ImageInst->getOperand(6).getImm();
909   return Dim != SPIRV::Dim::DIM_Buffer && Sampled == 1;
910 }
911 
912 bool isInputAttachment(MachineInstr *ImageInst) {
913   if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
914     return false;
915   uint32_t Dim = ImageInst->getOperand(2).getImm();
916   uint32_t Sampled = ImageInst->getOperand(6).getImm();
917   return Dim == SPIRV::Dim::DIM_SubpassData && Sampled == 2;
918 }
919 
920 bool isStorageImage(MachineInstr *ImageInst) {
921   if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
922     return false;
923   uint32_t Dim = ImageInst->getOperand(2).getImm();
924   uint32_t Sampled = ImageInst->getOperand(6).getImm();
925   return Dim != SPIRV::Dim::DIM_Buffer && Sampled == 2;
926 }
927 
928 bool isCombinedImageSampler(MachineInstr *SampledImageInst) {
929   if (SampledImageInst->getOpcode() != SPIRV::OpTypeSampledImage)
930     return false;
931 
932   const MachineRegisterInfo &MRI = SampledImageInst->getMF()->getRegInfo();
933   Register ImageReg = SampledImageInst->getOperand(1).getReg();
934   auto *ImageInst = MRI.getUniqueVRegDef(ImageReg);
935   return isSampledImage(ImageInst);
936 }
937 
938 bool hasNonUniformDecoration(Register Reg, const MachineRegisterInfo &MRI) {
939   for (const auto &MI : MRI.reg_instructions(Reg)) {
940     if (MI.getOpcode() != SPIRV::OpDecorate)
941       continue;
942 
943     uint32_t Dec = MI.getOperand(1).getImm();
944     if (Dec == SPIRV::Decoration::NonUniformEXT)
945       return true;
946   }
947   return false;
948 }
949 
950 void addOpAccessChainReqs(const MachineInstr &Instr,
951                           SPIRV::RequirementHandler &Handler,
952                           const SPIRVSubtarget &Subtarget) {
953   const MachineRegisterInfo &MRI = Instr.getMF()->getRegInfo();
954   // Get the result type. If it is an image type, then the shader uses
955   // descriptor indexing. The appropriate capabilities will be added based
956   // on the specifics of the image.
957   Register ResTypeReg = Instr.getOperand(1).getReg();
958   MachineInstr *ResTypeInst = MRI.getUniqueVRegDef(ResTypeReg);
959 
960   assert(ResTypeInst->getOpcode() == SPIRV::OpTypePointer);
961   uint32_t StorageClass = ResTypeInst->getOperand(1).getImm();
962   if (StorageClass != SPIRV::StorageClass::StorageClass::UniformConstant &&
963       StorageClass != SPIRV::StorageClass::StorageClass::Uniform &&
964       StorageClass != SPIRV::StorageClass::StorageClass::StorageBuffer) {
965     return;
966   }
967 
968   Register PointeeTypeReg = ResTypeInst->getOperand(2).getReg();
969   MachineInstr *PointeeType = MRI.getUniqueVRegDef(PointeeTypeReg);
970   if (PointeeType->getOpcode() != SPIRV::OpTypeImage &&
971       PointeeType->getOpcode() != SPIRV::OpTypeSampledImage &&
972       PointeeType->getOpcode() != SPIRV::OpTypeSampler) {
973     return;
974   }
975 
976   bool IsNonUniform =
977       hasNonUniformDecoration(Instr.getOperand(0).getReg(), MRI);
978   if (isUniformTexelBuffer(PointeeType)) {
979     if (IsNonUniform)
980       Handler.addRequirements(
981           SPIRV::Capability::UniformTexelBufferArrayNonUniformIndexingEXT);
982     else
983       Handler.addRequirements(
984           SPIRV::Capability::UniformTexelBufferArrayDynamicIndexingEXT);
985   } else if (isInputAttachment(PointeeType)) {
986     if (IsNonUniform)
987       Handler.addRequirements(
988           SPIRV::Capability::InputAttachmentArrayNonUniformIndexingEXT);
989     else
990       Handler.addRequirements(
991           SPIRV::Capability::InputAttachmentArrayDynamicIndexingEXT);
992   } else if (isStorageTexelBuffer(PointeeType)) {
993     if (IsNonUniform)
994       Handler.addRequirements(
995           SPIRV::Capability::StorageTexelBufferArrayNonUniformIndexingEXT);
996     else
997       Handler.addRequirements(
998           SPIRV::Capability::StorageTexelBufferArrayDynamicIndexingEXT);
999   } else if (isSampledImage(PointeeType) ||
1000              isCombinedImageSampler(PointeeType) ||
1001              PointeeType->getOpcode() == SPIRV::OpTypeSampler) {
1002     if (IsNonUniform)
1003       Handler.addRequirements(
1004           SPIRV::Capability::SampledImageArrayNonUniformIndexingEXT);
1005     else
1006       Handler.addRequirements(
1007           SPIRV::Capability::SampledImageArrayDynamicIndexing);
1008   } else if (isStorageImage(PointeeType)) {
1009     if (IsNonUniform)
1010       Handler.addRequirements(
1011           SPIRV::Capability::StorageImageArrayNonUniformIndexingEXT);
1012     else
1013       Handler.addRequirements(
1014           SPIRV::Capability::StorageImageArrayDynamicIndexing);
1015   }
1016 }
1017 
1018 static bool isImageTypeWithUnknownFormat(SPIRVType *TypeInst) {
1019   if (TypeInst->getOpcode() != SPIRV::OpTypeImage)
1020     return false;
1021   assert(TypeInst->getOperand(7).isImm() && "The image format must be an imm.");
1022   return TypeInst->getOperand(7).getImm() == 0;
1023 }
1024 
1025 static void AddDotProductRequirements(const MachineInstr &MI,
1026                                       SPIRV::RequirementHandler &Reqs,
1027                                       const SPIRVSubtarget &ST) {
1028   if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_integer_dot_product))
1029     Reqs.addExtension(SPIRV::Extension::SPV_KHR_integer_dot_product);
1030   Reqs.addCapability(SPIRV::Capability::DotProduct);
1031 
1032   const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1033   assert(MI.getOperand(2).isReg() && "Unexpected operand in dot");
1034   // We do not consider what the previous instruction is. This is just used
1035   // to get the input register and to check the type.
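  // For example, a packed 4x8-bit dot product takes 32-bit integer inputs
  // (DotProductInput4x8BitPacked), whereas a <4 x i8> vector input maps to
  // DotProductInput4x8Bit and any other integer vector to DotProductInputAll.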
1036   const MachineInstr *Input = MRI.getVRegDef(MI.getOperand(2).getReg());
1037   assert(Input->getOperand(1).isReg() && "Unexpected operand in dot input");
1038   Register InputReg = Input->getOperand(1).getReg();
1039 
1040   SPIRVType *TypeDef = MRI.getVRegDef(InputReg);
1041   if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
1042     assert(TypeDef->getOperand(1).getImm() == 32);
1043     Reqs.addCapability(SPIRV::Capability::DotProductInput4x8BitPacked);
1044   } else if (TypeDef->getOpcode() == SPIRV::OpTypeVector) {
1045     SPIRVType *ScalarTypeDef = MRI.getVRegDef(TypeDef->getOperand(1).getReg());
1046     assert(ScalarTypeDef->getOpcode() == SPIRV::OpTypeInt);
1047     if (ScalarTypeDef->getOperand(1).getImm() == 8) {
1048       assert(TypeDef->getOperand(2).getImm() == 4 &&
1049              "Dot operand of 8-bit integer type requires 4 components");
1050       Reqs.addCapability(SPIRV::Capability::DotProductInput4x8Bit);
1051     } else {
1052       Reqs.addCapability(SPIRV::Capability::DotProductInputAll);
1053     }
1054   }
1055 }
1056 
1057 void addInstrRequirements(const MachineInstr &MI,
1058                           SPIRV::RequirementHandler &Reqs,
1059                           const SPIRVSubtarget &ST) {
1060   switch (MI.getOpcode()) {
1061   case SPIRV::OpMemoryModel: {
1062     int64_t Addr = MI.getOperand(0).getImm();
1063     Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,
1064                                Addr, ST);
1065     int64_t Mem = MI.getOperand(1).getImm();
1066     Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand, Mem,
1067                                ST);
1068     break;
1069   }
1070   case SPIRV::OpEntryPoint: {
1071     int64_t Exe = MI.getOperand(0).getImm();
1072     Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModelOperand,
1073                                Exe, ST);
1074     break;
1075   }
1076   case SPIRV::OpExecutionMode:
1077   case SPIRV::OpExecutionModeId: {
1078     int64_t Exe = MI.getOperand(1).getImm();
1079     Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModeOperand,
1080                                Exe, ST);
1081     break;
1082   }
1083   case SPIRV::OpTypeMatrix:
1084     Reqs.addCapability(SPIRV::Capability::Matrix);
1085     break;
1086   case SPIRV::OpTypeInt: {
1087     unsigned BitWidth = MI.getOperand(1).getImm();
1088     if (BitWidth == 64)
1089       Reqs.addCapability(SPIRV::Capability::Int64);
1090     else if (BitWidth == 16)
1091       Reqs.addCapability(SPIRV::Capability::Int16);
1092     else if (BitWidth == 8)
1093       Reqs.addCapability(SPIRV::Capability::Int8);
1094     break;
1095   }
1096   case SPIRV::OpTypeFloat: {
1097     unsigned BitWidth = MI.getOperand(1).getImm();
1098     if (BitWidth == 64)
1099       Reqs.addCapability(SPIRV::Capability::Float64);
1100     else if (BitWidth == 16)
1101       Reqs.addCapability(SPIRV::Capability::Float16);
1102     break;
1103   }
1104   case SPIRV::OpTypeVector: {
1105     unsigned NumComponents = MI.getOperand(2).getImm();
1106     if (NumComponents == 8 || NumComponents == 16)
1107       Reqs.addCapability(SPIRV::Capability::Vector16);
1108     break;
1109   }
1110   case SPIRV::OpTypePointer: {
1111     auto SC = MI.getOperand(1).getImm();
1112     Reqs.getAndAddRequirements(SPIRV::OperandCategory::StorageClassOperand, SC,
1113                                ST);
1114     // If it's a pointer type to float16 and we target OpenCL, add the
1115     // Float16Buffer capability.
1116     if (!ST.isOpenCLEnv())
1117       break;
1118     assert(MI.getOperand(2).isReg());
1119     const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1120     SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(2).getReg());
1121     if (TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
1122         TypeDef->getOperand(1).getImm() == 16)
1123       Reqs.addCapability(SPIRV::Capability::Float16Buffer);
1124     break;
1125   }
1126   case SPIRV::OpExtInst: {
1127     if (MI.getOperand(2).getImm() ==
1128         static_cast<int64_t>(
1129             SPIRV::InstructionSet::NonSemantic_Shader_DebugInfo_100)) {
1130       Reqs.addExtension(SPIRV::Extension::SPV_KHR_non_semantic_info);
1131     }
1132     break;
1133   }
1134   case SPIRV::OpBitReverse:
1135   case SPIRV::OpBitFieldInsert:
1136   case SPIRV::OpBitFieldSExtract:
1137   case SPIRV::OpBitFieldUExtract:
1138     if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions)) {
1139       Reqs.addCapability(SPIRV::Capability::Shader);
1140       break;
1141     }
1142     Reqs.addExtension(SPIRV::Extension::SPV_KHR_bit_instructions);
1143     Reqs.addCapability(SPIRV::Capability::BitInstructions);
1144     break;
1145   case SPIRV::OpTypeRuntimeArray:
1146     Reqs.addCapability(SPIRV::Capability::Shader);
1147     break;
1148   case SPIRV::OpTypeOpaque:
1149   case SPIRV::OpTypeEvent:
1150     Reqs.addCapability(SPIRV::Capability::Kernel);
1151     break;
1152   case SPIRV::OpTypePipe:
1153   case SPIRV::OpTypeReserveId:
1154     Reqs.addCapability(SPIRV::Capability::Pipes);
1155     break;
1156   case SPIRV::OpTypeDeviceEvent:
1157   case SPIRV::OpTypeQueue:
1158   case SPIRV::OpBuildNDRange:
1159     Reqs.addCapability(SPIRV::Capability::DeviceEnqueue);
1160     break;
1161   case SPIRV::OpDecorate:
1162   case SPIRV::OpDecorateId:
1163   case SPIRV::OpDecorateString:
1164     addOpDecorateReqs(MI, 1, Reqs, ST);
1165     break;
1166   case SPIRV::OpMemberDecorate:
1167   case SPIRV::OpMemberDecorateString:
1168     addOpDecorateReqs(MI, 2, Reqs, ST);
1169     break;
1170   case SPIRV::OpInBoundsPtrAccessChain:
1171     Reqs.addCapability(SPIRV::Capability::Addresses);
1172     break;
1173   case SPIRV::OpConstantSampler:
1174     Reqs.addCapability(SPIRV::Capability::LiteralSampler);
1175     break;
1176   case SPIRV::OpInBoundsAccessChain:
1177   case SPIRV::OpAccessChain:
1178     addOpAccessChainReqs(MI, Reqs, ST);
1179     break;
1180   case SPIRV::OpTypeImage:
1181     addOpTypeImageReqs(MI, Reqs, ST);
1182     break;
1183   case SPIRV::OpTypeSampler:
1184     if (!ST.isVulkanEnv()) {
1185       Reqs.addCapability(SPIRV::Capability::ImageBasic);
1186     }
1187     break;
1188   case SPIRV::OpTypeForwardPointer:
1189     // TODO: check if it's OpenCL's kernel.
1190     Reqs.addCapability(SPIRV::Capability::Addresses);
1191     break;
1192   case SPIRV::OpAtomicFlagTestAndSet:
1193   case SPIRV::OpAtomicLoad:
1194   case SPIRV::OpAtomicStore:
1195   case SPIRV::OpAtomicExchange:
1196   case SPIRV::OpAtomicCompareExchange:
1197   case SPIRV::OpAtomicIIncrement:
1198   case SPIRV::OpAtomicIDecrement:
1199   case SPIRV::OpAtomicIAdd:
1200   case SPIRV::OpAtomicISub:
1201   case SPIRV::OpAtomicUMin:
1202   case SPIRV::OpAtomicUMax:
1203   case SPIRV::OpAtomicSMin:
1204   case SPIRV::OpAtomicSMax:
1205   case SPIRV::OpAtomicAnd:
1206   case SPIRV::OpAtomicOr:
1207   case SPIRV::OpAtomicXor: {
1208     const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1209     const MachineInstr *InstrPtr = &MI;
1210     if (MI.getOpcode() == SPIRV::OpAtomicStore) {
1211       assert(MI.getOperand(3).isReg());
1212       InstrPtr = MRI.getVRegDef(MI.getOperand(3).getReg());
1213       assert(InstrPtr && "Unexpected type instruction for OpAtomicStore");
1214     }
1215     assert(InstrPtr->getOperand(1).isReg() && "Unexpected operand in atomic");
1216     Register TypeReg = InstrPtr->getOperand(1).getReg();
1217     SPIRVType *TypeDef = MRI.getVRegDef(TypeReg);
1218     if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
1219       unsigned BitWidth = TypeDef->getOperand(1).getImm();
1220       if (BitWidth == 64)
1221         Reqs.addCapability(SPIRV::Capability::Int64Atomics);
1222     }
1223     break;
1224   }
1225   case SPIRV::OpGroupNonUniformIAdd:
1226   case SPIRV::OpGroupNonUniformFAdd:
1227   case SPIRV::OpGroupNonUniformIMul:
1228   case SPIRV::OpGroupNonUniformFMul:
1229   case SPIRV::OpGroupNonUniformSMin:
1230   case SPIRV::OpGroupNonUniformUMin:
1231   case SPIRV::OpGroupNonUniformFMin:
1232   case SPIRV::OpGroupNonUniformSMax:
1233   case SPIRV::OpGroupNonUniformUMax:
1234   case SPIRV::OpGroupNonUniformFMax:
1235   case SPIRV::OpGroupNonUniformBitwiseAnd:
1236   case SPIRV::OpGroupNonUniformBitwiseOr:
1237   case SPIRV::OpGroupNonUniformBitwiseXor:
1238   case SPIRV::OpGroupNonUniformLogicalAnd:
1239   case SPIRV::OpGroupNonUniformLogicalOr:
1240   case SPIRV::OpGroupNonUniformLogicalXor: {
1241     assert(MI.getOperand(3).isImm());
1242     int64_t GroupOp = MI.getOperand(3).getImm();
1243     switch (GroupOp) {
1244     case SPIRV::GroupOperation::Reduce:
1245     case SPIRV::GroupOperation::InclusiveScan:
1246     case SPIRV::GroupOperation::ExclusiveScan:
1247       Reqs.addCapability(SPIRV::Capability::Kernel);
1248       Reqs.addCapability(SPIRV::Capability::GroupNonUniformArithmetic);
1249       Reqs.addCapability(SPIRV::Capability::GroupNonUniformBallot);
1250       break;
1251     case SPIRV::GroupOperation::ClusteredReduce:
1252       Reqs.addCapability(SPIRV::Capability::GroupNonUniformClustered);
1253       break;
1254     case SPIRV::GroupOperation::PartitionedReduceNV:
1255     case SPIRV::GroupOperation::PartitionedInclusiveScanNV:
1256     case SPIRV::GroupOperation::PartitionedExclusiveScanNV:
1257       Reqs.addCapability(SPIRV::Capability::GroupNonUniformPartitionedNV);
1258       break;
1259     }
1260     break;
1261   }
1262   case SPIRV::OpGroupNonUniformShuffle:
1263   case SPIRV::OpGroupNonUniformShuffleXor:
1264     Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffle);
1265     break;
1266   case SPIRV::OpGroupNonUniformShuffleUp:
1267   case SPIRV::OpGroupNonUniformShuffleDown:
1268     Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffleRelative);
1269     break;
1270   case SPIRV::OpGroupAll:
1271   case SPIRV::OpGroupAny:
1272   case SPIRV::OpGroupBroadcast:
1273   case SPIRV::OpGroupIAdd:
1274   case SPIRV::OpGroupFAdd:
1275   case SPIRV::OpGroupFMin:
1276   case SPIRV::OpGroupUMin:
1277   case SPIRV::OpGroupSMin:
1278   case SPIRV::OpGroupFMax:
1279   case SPIRV::OpGroupUMax:
1280   case SPIRV::OpGroupSMax:
1281     Reqs.addCapability(SPIRV::Capability::Groups);
1282     break;
1283   case SPIRV::OpGroupNonUniformElect:
1284     Reqs.addCapability(SPIRV::Capability::GroupNonUniform);
1285     break;
1286   case SPIRV::OpGroupNonUniformAll:
1287   case SPIRV::OpGroupNonUniformAny:
1288   case SPIRV::OpGroupNonUniformAllEqual:
1289     Reqs.addCapability(SPIRV::Capability::GroupNonUniformVote);
1290     break;
1291   case SPIRV::OpGroupNonUniformBroadcast:
1292   case SPIRV::OpGroupNonUniformBroadcastFirst:
1293   case SPIRV::OpGroupNonUniformBallot:
1294   case SPIRV::OpGroupNonUniformInverseBallot:
1295   case SPIRV::OpGroupNonUniformBallotBitExtract:
1296   case SPIRV::OpGroupNonUniformBallotBitCount:
1297   case SPIRV::OpGroupNonUniformBallotFindLSB:
1298   case SPIRV::OpGroupNonUniformBallotFindMSB:
1299     Reqs.addCapability(SPIRV::Capability::GroupNonUniformBallot);
1300     break;
1301   case SPIRV::OpSubgroupShuffleINTEL:
1302   case SPIRV::OpSubgroupShuffleDownINTEL:
1303   case SPIRV::OpSubgroupShuffleUpINTEL:
1304   case SPIRV::OpSubgroupShuffleXorINTEL:
1305     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
1306       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
1307       Reqs.addCapability(SPIRV::Capability::SubgroupShuffleINTEL);
1308     }
1309     break;
1310   case SPIRV::OpSubgroupBlockReadINTEL:
1311   case SPIRV::OpSubgroupBlockWriteINTEL:
1312     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
1313       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
1314       Reqs.addCapability(SPIRV::Capability::SubgroupBufferBlockIOINTEL);
1315     }
1316     break;
1317   case SPIRV::OpSubgroupImageBlockReadINTEL:
1318   case SPIRV::OpSubgroupImageBlockWriteINTEL:
1319     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
1320       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
1321       Reqs.addCapability(SPIRV::Capability::SubgroupImageBlockIOINTEL);
1322     }
1323     break;
1324   case SPIRV::OpSubgroupImageMediaBlockReadINTEL:
1325   case SPIRV::OpSubgroupImageMediaBlockWriteINTEL:
1326     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_media_block_io)) {
1327       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_media_block_io);
1328       Reqs.addCapability(SPIRV::Capability::SubgroupImageMediaBlockIOINTEL);
1329     }
1330     break;
1331   case SPIRV::OpAssumeTrueKHR:
1332   case SPIRV::OpExpectKHR:
1333     if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_expect_assume)) {
1334       Reqs.addExtension(SPIRV::Extension::SPV_KHR_expect_assume);
1335       Reqs.addCapability(SPIRV::Capability::ExpectAssumeKHR);
1336     }
1337     break;
1338   case SPIRV::OpPtrCastToCrossWorkgroupINTEL:
1339   case SPIRV::OpCrossWorkgroupCastToPtrINTEL:
1340     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)) {
1341       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes);
1342       Reqs.addCapability(SPIRV::Capability::USMStorageClassesINTEL);
1343     }
1344     break;
1345   case SPIRV::OpConstantFunctionPointerINTEL:
1346     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) {
1347       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
1348       Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL);
1349     }
1350     break;
1351   case SPIRV::OpGroupNonUniformRotateKHR:
1352     if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_subgroup_rotate))
1353       report_fatal_error("OpGroupNonUniformRotateKHR instruction requires the "
1354                          "following SPIR-V extension: SPV_KHR_subgroup_rotate",
1355                          false);
1356     Reqs.addExtension(SPIRV::Extension::SPV_KHR_subgroup_rotate);
1357     Reqs.addCapability(SPIRV::Capability::GroupNonUniformRotateKHR);
1358     Reqs.addCapability(SPIRV::Capability::GroupNonUniform);
1359     break;
1360   case SPIRV::OpGroupIMulKHR:
1361   case SPIRV::OpGroupFMulKHR:
1362   case SPIRV::OpGroupBitwiseAndKHR:
1363   case SPIRV::OpGroupBitwiseOrKHR:
1364   case SPIRV::OpGroupBitwiseXorKHR:
1365   case SPIRV::OpGroupLogicalAndKHR:
1366   case SPIRV::OpGroupLogicalOrKHR:
1367   case SPIRV::OpGroupLogicalXorKHR:
1368     if (ST.canUseExtension(
1369             SPIRV::Extension::SPV_KHR_uniform_group_instructions)) {
1370       Reqs.addExtension(SPIRV::Extension::SPV_KHR_uniform_group_instructions);
1371       Reqs.addCapability(SPIRV::Capability::GroupUniformArithmeticKHR);
1372     }
1373     break;
1374   case SPIRV::OpReadClockKHR:
1375     if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_shader_clock))
1376       report_fatal_error("OpReadClockKHR instruction requires the "
1377                          "following SPIR-V extension: SPV_KHR_shader_clock",
1378                          false);
1379     Reqs.addExtension(SPIRV::Extension::SPV_KHR_shader_clock);
1380     Reqs.addCapability(SPIRV::Capability::ShaderClockKHR);
1381     break;
1382   case SPIRV::OpFunctionPointerCallINTEL:
1383     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) {
1384       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
1385       Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL);
1386     }
1387     break;
1388   case SPIRV::OpAtomicFAddEXT:
1389   case SPIRV::OpAtomicFMinEXT:
1390   case SPIRV::OpAtomicFMaxEXT:
1391     AddAtomicFloatRequirements(MI, Reqs, ST);
1392     break;
1393   case SPIRV::OpConvertBF16ToFINTEL:
1394   case SPIRV::OpConvertFToBF16INTEL:
1395     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bfloat16_conversion)) {
1396       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bfloat16_conversion);
1397       Reqs.addCapability(SPIRV::Capability::BFloat16ConversionINTEL);
1398     }
1399     break;
1400   case SPIRV::OpVariableLengthArrayINTEL:
1401   case SPIRV::OpSaveMemoryINTEL:
1402   case SPIRV::OpRestoreMemoryINTEL:
1403     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_variable_length_array)) {
1404       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_variable_length_array);
1405       Reqs.addCapability(SPIRV::Capability::VariableLengthArrayINTEL);
1406     }
1407     break;
1408   case SPIRV::OpAsmTargetINTEL:
1409   case SPIRV::OpAsmINTEL:
1410   case SPIRV::OpAsmCallINTEL:
1411     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_inline_assembly)) {
1412       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_inline_assembly);
1413       Reqs.addCapability(SPIRV::Capability::AsmINTEL);
1414     }
1415     break;
1416   case SPIRV::OpTypeCooperativeMatrixKHR:
1417     if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
1418       report_fatal_error(
1419           "OpTypeCooperativeMatrixKHR type requires the "
1420           "following SPIR-V extension: SPV_KHR_cooperative_matrix",
1421           false);
1422     Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
1423     Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
1424     break;
1425   case SPIRV::OpArithmeticFenceEXT:
1426     if (!ST.canUseExtension(SPIRV::Extension::SPV_EXT_arithmetic_fence))
1427       report_fatal_error("OpArithmeticFenceEXT requires the "
1428                          "following SPIR-V extension: SPV_EXT_arithmetic_fence",
1429                          false);
1430     Reqs.addExtension(SPIRV::Extension::SPV_EXT_arithmetic_fence);
1431     Reqs.addCapability(SPIRV::Capability::ArithmeticFenceEXT);
1432     break;
1433   case SPIRV::OpControlBarrierArriveINTEL:
1434   case SPIRV::OpControlBarrierWaitINTEL:
1435     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_split_barrier)) {
1436       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_split_barrier);
1437       Reqs.addCapability(SPIRV::Capability::SplitBarrierINTEL);
1438     }
1439     break;
1440   case SPIRV::OpCooperativeMatrixMulAddKHR: {
1441     if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
1442       report_fatal_error("Cooperative matrix instructions require the "
1443                          "following SPIR-V extension: "
1444                          "SPV_KHR_cooperative_matrix",
1445                          false);
1446     Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
1447     Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
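    // The optional Cooperative Matrix Operands literal is the last operand;
    // when present it may request INTEL component type interpretations that
    // need SPV_INTEL_joint_matrix.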
1448     constexpr unsigned MulAddMaxSize = 6;
1449     if (MI.getNumOperands() != MulAddMaxSize)
1450       break;
1451     const int64_t CoopOperands = MI.getOperand(MulAddMaxSize - 1).getImm();
1452     if (CoopOperands &
1453         SPIRV::CooperativeMatrixOperands::MatrixAAndBTF32ComponentsINTEL) {
1454       if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
1455         report_fatal_error("MatrixAAndBTF32ComponentsINTEL type interpretation "
1456                            "requires the following SPIR-V extension: "
1457                            "SPV_INTEL_joint_matrix",
1458                            false);
1459       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
1460       Reqs.addCapability(
1461           SPIRV::Capability::CooperativeMatrixTF32ComponentTypeINTEL);
1462     }
1463     if (CoopOperands & SPIRV::CooperativeMatrixOperands::
1464                            MatrixAAndBBFloat16ComponentsINTEL ||
1465         CoopOperands &
1466             SPIRV::CooperativeMatrixOperands::MatrixCBFloat16ComponentsINTEL ||
1467         CoopOperands & SPIRV::CooperativeMatrixOperands::
1468                            MatrixResultBFloat16ComponentsINTEL) {
1469       if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
1470         report_fatal_error("***BF16ComponentsINTEL type interpretations "
1471                            "require the following SPIR-V extension: "
1472                            "SPV_INTEL_joint_matrix",
1473                            false);
1474       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
1475       Reqs.addCapability(
1476           SPIRV::Capability::CooperativeMatrixBFloat16ComponentTypeINTEL);
1477     }
1478     break;
1479   }
1480   case SPIRV::OpCooperativeMatrixLoadKHR:
1481   case SPIRV::OpCooperativeMatrixStoreKHR:
1482   case SPIRV::OpCooperativeMatrixLoadCheckedINTEL:
1483   case SPIRV::OpCooperativeMatrixStoreCheckedINTEL:
1484   case SPIRV::OpCooperativeMatrixPrefetchINTEL: {
1485     if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
1486       report_fatal_error("Cooperative matrix instructions require the "
1487                          "following SPIR-V extension: "
1488                          "SPV_KHR_cooperative_matrix",
1489                          false);
1490     Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
1491     Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
1492 
1493     // Check the Layout operand and, if it is not a standard layout, add the
1494     // appropriate capability.
1495     std::unordered_map<unsigned, unsigned> LayoutToInstMap = {
1496         {SPIRV::OpCooperativeMatrixLoadKHR, 3},
1497         {SPIRV::OpCooperativeMatrixStoreKHR, 2},
1498         {SPIRV::OpCooperativeMatrixLoadCheckedINTEL, 5},
1499         {SPIRV::OpCooperativeMatrixStoreCheckedINTEL, 4},
1500         {SPIRV::OpCooperativeMatrixPrefetchINTEL, 4}};
1501 
1502     const auto OpCode = MI.getOpcode();
1503     const unsigned LayoutNum = LayoutToInstMap[OpCode];
1504     Register RegLayout = MI.getOperand(LayoutNum).getReg();
1505     const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1506     MachineInstr *MILayout = MRI.getUniqueVRegDef(RegLayout);
1507     if (MILayout->getOpcode() == SPIRV::OpConstantI) {
1508       const unsigned LayoutVal = MILayout->getOperand(2).getImm();
1509       if (LayoutVal ==
1510           static_cast<unsigned>(SPIRV::CooperativeMatrixLayout::PackedINTEL)) {
1511         if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
1512           report_fatal_error("PackedINTEL layout requires the following SPIR-V "
1513                              "extension: SPV_INTEL_joint_matrix",
1514                              false);
1515         Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
1516         Reqs.addCapability(SPIRV::Capability::PackedCooperativeMatrixINTEL);
1517       }
1518     }
1519 
1520     // Nothing more to do for the standard KHR load/store instructions.
1521     if (OpCode == SPIRV::OpCooperativeMatrixLoadKHR ||
1522         OpCode == SPIRV::OpCooperativeMatrixStoreKHR)
1523       break;
1524 
1525     std::string InstName;
1526     switch (OpCode) {
1527     case SPIRV::OpCooperativeMatrixPrefetchINTEL:
1528       InstName = "OpCooperativeMatrixPrefetchINTEL";
1529       break;
1530     case SPIRV::OpCooperativeMatrixLoadCheckedINTEL:
1531       InstName = "OpCooperativeMatrixLoadCheckedINTEL";
1532       break;
1533     case SPIRV::OpCooperativeMatrixStoreCheckedINTEL:
1534       InstName = "OpCooperativeMatrixStoreCheckedINTEL";
1535       break;
1536     }
1537 
1538     if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix)) {
1539       const std::string ErrorMsg =
1540           InstName + " instruction requires the "
1541                      "following SPIR-V extension: SPV_INTEL_joint_matrix";
1542       report_fatal_error(ErrorMsg.c_str(), false);
1543     }
1544     Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
1545     if (OpCode == SPIRV::OpCooperativeMatrixPrefetchINTEL) {
1546       Reqs.addCapability(SPIRV::Capability::CooperativeMatrixPrefetchINTEL);
1547       break;
1548     }
1549     Reqs.addCapability(
1550         SPIRV::Capability::CooperativeMatrixCheckedInstructionsINTEL);
1551     break;
1552   }
1553   case SPIRV::OpCooperativeMatrixConstructCheckedINTEL:
1554     if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
1555       report_fatal_error("OpCooperativeMatrixConstructCheckedINTEL "
1556                          "instruction requires the following SPIR-V extension: "
1557                          "SPV_INTEL_joint_matrix",
1558                          false);
1559     Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
1560     Reqs.addCapability(
1561         SPIRV::Capability::CooperativeMatrixCheckedInstructionsINTEL);
1562     break;
1563   case SPIRV::OpCooperativeMatrixGetElementCoordINTEL:
1564     if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
1565       report_fatal_error("OpCooperativeMatrixGetElementCoordINTEL requires the "
1566                          "following SPIR-V extension: SPV_INTEL_joint_matrix",
1567                          false);
1568     Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
1569     Reqs.addCapability(
1570         SPIRV::Capability::CooperativeMatrixInvocationInstructionsINTEL);
1571     break;
1572   case SPIRV::OpKill: {
1573     Reqs.addCapability(SPIRV::Capability::Shader);
1574   } break;
1575   case SPIRV::OpDemoteToHelperInvocation:
1576     Reqs.addCapability(SPIRV::Capability::DemoteToHelperInvocation);
1577 
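    // DemoteToHelperInvocation is core from SPIR-V 1.6; only older targets
    // need to declare the SPV_EXT_demote_to_helper_invocation extension.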
1578     if (ST.canUseExtension(
1579             SPIRV::Extension::SPV_EXT_demote_to_helper_invocation)) {
1580       if (!ST.isAtLeastSPIRVVer(llvm::VersionTuple(1, 6)))
1581         Reqs.addExtension(
1582             SPIRV::Extension::SPV_EXT_demote_to_helper_invocation);
1583     }
1584     break;
1585   case SPIRV::OpSDot:
1586   case SPIRV::OpUDot:
1587     AddDotProductRequirements(MI, Reqs, ST);
1588     break;
1589   case SPIRV::OpImageRead: {
1590     Register ImageReg = MI.getOperand(2).getReg();
1591     SPIRVType *TypeDef = ST.getSPIRVGlobalRegistry()->getResultType(ImageReg);
1592     if (isImageTypeWithUnknownFormat(TypeDef))
1593       Reqs.addCapability(SPIRV::Capability::StorageImageReadWithoutFormat);
1594     break;
1595   }
1596   case SPIRV::OpImageWrite: {
1597     Register ImageReg = MI.getOperand(0).getReg();
1598     SPIRVType *TypeDef = ST.getSPIRVGlobalRegistry()->getResultType(ImageReg);
1599     if (isImageTypeWithUnknownFormat(TypeDef))
1600       Reqs.addCapability(SPIRV::Capability::StorageImageWriteWithoutFormat);
1601     break;
1602   }
1603 
1604   default:
1605     break;
1606   }
1607 
1608   // If we require capability Shader, then we can remove the requirement for
1609   // the BitInstructions capability, since Shader is a superset capability
1610   // of BitInstructions.
1611   Reqs.removeCapabilityIf(SPIRV::Capability::BitInstructions,
1612                           SPIRV::Capability::Shader);
1613 }
1614 
1615 static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI,
1616                         MachineModuleInfo *MMI, const SPIRVSubtarget &ST) {
1617   // Collect requirements for existing instructions.
1618   for (auto F = M.begin(), E = M.end(); F != E; ++F) {
1619     MachineFunction *MF = MMI->getMachineFunction(*F);
1620     if (!MF)
1621       continue;
1622     for (const MachineBasicBlock &MBB : *MF)
1623       for (const MachineInstr &MI : MBB)
1624         addInstrRequirements(MI, MAI.Reqs, ST);
1625   }
1626   // Collect requirements for OpExecutionMode instructions.
1627   auto Node = M.getNamedMetadata("spirv.ExecutionMode");
1628   if (Node) {
1629     // The float controls execution modes need SPV_KHR_float_controls before v1.4.
1630     bool RequireFloatControls = false,
1631          VerLower14 = !ST.isAtLeastSPIRVVer(VersionTuple(1, 4));
1632     for (unsigned i = 0; i < Node->getNumOperands(); i++) {
1633       MDNode *MDN = cast<MDNode>(Node->getOperand(i));
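      // Operand 1 of each spirv.ExecutionMode node holds the execution mode id.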
1634       const MDOperand &MDOp = MDN->getOperand(1);
1635       if (auto *CMeta = dyn_cast<ConstantAsMetadata>(MDOp)) {
1636         Constant *C = CMeta->getValue();
1637         if (ConstantInt *Const = dyn_cast<ConstantInt>(C)) {
1638           auto EM = Const->getZExtValue();
1639           MAI.Reqs.getAndAddRequirements(
1640               SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
1641           // add SPV_KHR_float_controls if the version is too low
1642           switch (EM) {
1643           case SPIRV::ExecutionMode::DenormPreserve:
1644           case SPIRV::ExecutionMode::DenormFlushToZero:
1645           case SPIRV::ExecutionMode::SignedZeroInfNanPreserve:
1646           case SPIRV::ExecutionMode::RoundingModeRTE:
1647           case SPIRV::ExecutionMode::RoundingModeRTZ:
1648             RequireFloatControls = VerLower14;
1649             break;
1650           }
1651         }
1652       }
1653     }
1654     if (RequireFloatControls &&
1655         ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls))
1656       MAI.Reqs.addExtension(SPIRV::Extension::SPV_KHR_float_controls);
1657   }
1658   for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) {
1659     const Function &F = *FI;
1660     if (F.isDeclaration())
1661       continue;
1662     if (F.getMetadata("reqd_work_group_size"))
1663       MAI.Reqs.getAndAddRequirements(
1664           SPIRV::OperandCategory::ExecutionModeOperand,
1665           SPIRV::ExecutionMode::LocalSize, ST);
1666     if (F.getFnAttribute("hlsl.numthreads").isValid()) {
1667       MAI.Reqs.getAndAddRequirements(
1668           SPIRV::OperandCategory::ExecutionModeOperand,
1669           SPIRV::ExecutionMode::LocalSize, ST);
1670     }
1671     if (F.getMetadata("work_group_size_hint"))
1672       MAI.Reqs.getAndAddRequirements(
1673           SPIRV::OperandCategory::ExecutionModeOperand,
1674           SPIRV::ExecutionMode::LocalSizeHint, ST);
1675     if (F.getMetadata("intel_reqd_sub_group_size"))
1676       MAI.Reqs.getAndAddRequirements(
1677           SPIRV::OperandCategory::ExecutionModeOperand,
1678           SPIRV::ExecutionMode::SubgroupSize, ST);
1679     if (F.getMetadata("vec_type_hint"))
1680       MAI.Reqs.getAndAddRequirements(
1681           SPIRV::OperandCategory::ExecutionModeOperand,
1682           SPIRV::ExecutionMode::VecTypeHint, ST);
1683 
1684     if (F.hasOptNone()) {
1685       if (ST.canUseExtension(SPIRV::Extension::SPV_EXT_optnone)) {
1686         MAI.Reqs.addExtension(SPIRV::Extension::SPV_EXT_optnone);
1687         MAI.Reqs.addCapability(SPIRV::Capability::OptNoneEXT);
1688       } else if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_optnone)) {
1689         MAI.Reqs.addExtension(SPIRV::Extension::SPV_INTEL_optnone);
1690         MAI.Reqs.addCapability(SPIRV::Capability::OptNoneINTEL);
1691       }
1692     }
1693   }
1694 }
1695 
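// Map the fast-math MI flags onto the SPIR-V FPFastMathMode bitmask used by
// the FPFastMathMode decoration.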
1696 static unsigned getFastMathFlags(const MachineInstr &I) {
1697   unsigned Flags = SPIRV::FPFastMathMode::None;
1698   if (I.getFlag(MachineInstr::MIFlag::FmNoNans))
1699     Flags |= SPIRV::FPFastMathMode::NotNaN;
1700   if (I.getFlag(MachineInstr::MIFlag::FmNoInfs))
1701     Flags |= SPIRV::FPFastMathMode::NotInf;
1702   if (I.getFlag(MachineInstr::MIFlag::FmNsz))
1703     Flags |= SPIRV::FPFastMathMode::NSZ;
1704   if (I.getFlag(MachineInstr::MIFlag::FmArcp))
1705     Flags |= SPIRV::FPFastMathMode::AllowRecip;
1706   if (I.getFlag(MachineInstr::MIFlag::FmReassoc))
1707     Flags |= SPIRV::FPFastMathMode::Fast;
1708   return Flags;
1709 }
1710 
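// Attach NoSignedWrap, NoUnsignedWrap and FPFastMathMode decorations for the
// corresponding MI flags when the decorations are supported by the current
// SPIR-V environment.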
1711 static void handleMIFlagDecoration(MachineInstr &I, const SPIRVSubtarget &ST,
1712                                    const SPIRVInstrInfo &TII,
1713                                    SPIRV::RequirementHandler &Reqs) {
1714   if (I.getFlag(MachineInstr::MIFlag::NoSWrap) && TII.canUseNSW(I) &&
1715       getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,
1716                                      SPIRV::Decoration::NoSignedWrap, ST, Reqs)
1717           .IsSatisfiable) {
1718     buildOpDecorate(I.getOperand(0).getReg(), I, TII,
1719                     SPIRV::Decoration::NoSignedWrap, {});
1720   }
1721   if (I.getFlag(MachineInstr::MIFlag::NoUWrap) && TII.canUseNUW(I) &&
1722       getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,
1723                                      SPIRV::Decoration::NoUnsignedWrap, ST,
1724                                      Reqs)
1725           .IsSatisfiable) {
1726     buildOpDecorate(I.getOperand(0).getReg(), I, TII,
1727                     SPIRV::Decoration::NoUnsignedWrap, {});
1728   }
1729   if (!TII.canUseFastMathFlags(I))
1730     return;
1731   unsigned FMFlags = getFastMathFlags(I);
1732   if (FMFlags == SPIRV::FPFastMathMode::None)
1733     return;
1734   Register DstReg = I.getOperand(0).getReg();
1735   buildOpDecorate(DstReg, I, TII, SPIRV::Decoration::FPFastMathMode, {FMFlags});
1736 }
1737 
1738 // Walk all functions and add decorations related to MI flags.
1739 static void addDecorations(const Module &M, const SPIRVInstrInfo &TII,
1740                            MachineModuleInfo *MMI, const SPIRVSubtarget &ST,
1741                            SPIRV::ModuleAnalysisInfo &MAI) {
1742   for (auto F = M.begin(), E = M.end(); F != E; ++F) {
1743     MachineFunction *MF = MMI->getMachineFunction(*F);
1744     if (!MF)
1745       continue;
1746     for (auto &MBB : *MF)
1747       for (auto &MI : MBB)
1748         handleMIFlagDecoration(MI, ST, TII, MAI.Reqs);
1749   }
1750 }
1751 
1752 static void addMBBNames(const Module &M, const SPIRVInstrInfo &TII,
1753                         MachineModuleInfo *MMI, const SPIRVSubtarget &ST,
1754                         SPIRV::ModuleAnalysisInfo &MAI) {
1755   for (auto F = M.begin(), E = M.end(); F != E; ++F) {
1756     MachineFunction *MF = MMI->getMachineFunction(*F);
1757     if (!MF)
1758       continue;
1759     MachineRegisterInfo &MRI = MF->getRegInfo();
1760     for (auto &MBB : *MF) {
1761       if (!MBB.hasName() || MBB.empty())
1762         continue;
1763       // Emit basic block names.
1764       Register Reg = MRI.createGenericVirtualRegister(LLT::scalar(64));
1765       MRI.setRegClass(Reg, &SPIRV::IDRegClass);
1766       buildOpName(Reg, MBB.getName(), *std::prev(MBB.end()), TII);
1767       Register GlobalReg = MAI.getOrCreateMBBRegister(MBB);
1768       MAI.setRegisterAlias(MF, Reg, GlobalReg);
1769     }
1770   }
1771 }
1772 
1773 // Patch TargetOpcode::PHI instructions to SPIRV::OpPhi.
1774 static void patchPhis(const Module &M, SPIRVGlobalRegistry *GR,
1775                       const SPIRVInstrInfo &TII, MachineModuleInfo *MMI) {
1776   for (auto F = M.begin(), E = M.end(); F != E; ++F) {
1777     MachineFunction *MF = MMI->getMachineFunction(*F);
1778     if (!MF)
1779       continue;
1780     for (auto &MBB : *MF) {
1781       for (MachineInstr &MI : MBB) {
1782         if (MI.getOpcode() != TargetOpcode::PHI)
1783           continue;
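        // Rewrite the generic PHI into OpPhi and insert the result type as the
        // second operand, as the OpPhi form requires.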
1784         MI.setDesc(TII.get(SPIRV::OpPhi));
1785         Register ResTypeReg = GR->getSPIRVTypeID(
1786             GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg(), MF));
1787         MI.insert(MI.operands_begin() + 1,
1788                   {MachineOperand::CreateReg(ResTypeReg, false)});
1789       }
1790     }
1791   }
1792 }
1793 
1794 struct SPIRV::ModuleAnalysisInfo SPIRVModuleAnalysis::MAI;
1795 
1796 void SPIRVModuleAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {
1797   AU.addRequired<TargetPassConfig>();
1798   AU.addRequired<MachineModuleInfoWrapperPass>();
1799 }
1800 
1801 bool SPIRVModuleAnalysis::runOnModule(Module &M) {
1802   SPIRVTargetMachine &TM =
1803       getAnalysis<TargetPassConfig>().getTM<SPIRVTargetMachine>();
1804   ST = TM.getSubtargetImpl();
1805   GR = ST->getSPIRVGlobalRegistry();
1806   TII = ST->getInstrInfo();
1807 
1808   MMI = &getAnalysis<MachineModuleInfoWrapperPass>().getMMI();
1809 
1810   setBaseInfo(M);
1811 
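  // Rewrite PHI nodes into the OpPhi form expected by the rest of the pipeline.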
1812   patchPhis(M, GR, *TII, MMI);
1813 
1814   addMBBNames(M, *TII, MMI, *ST, MAI);
1815   addDecorations(M, *TII, MMI, *ST, MAI);
1816 
1817   collectReqs(M, MAI, MMI, *ST);
1818 
1819   // Process type/const/global var/func decl instructions, number their
1820   // destination registers from 0 to N, collect Extensions and Capabilities.
1821   processDefInstrs(M);
1822 
1823   // Number rest of registers from N+1 onwards.
1824   numberRegistersGlobally(M);
1825 
1826   // Update references to OpFunction instructions to use global registers.
1827   if (GR->hasConstFunPtr())
1828     collectFuncPtrs();
1829 
1830   // Collect OpName, OpEntryPoint, OpDecorate etc., and process other instructions.
1831   processOtherInstrs(M);
1832 
1833   // If there are no entry points, we need the Linkage capability.
1834   if (MAI.MS[SPIRV::MB_EntryPoints].empty())
1835     MAI.Reqs.addCapability(SPIRV::Capability::Linkage);
1836 
1837   // Set maximum ID used.
1838   GR->setBound(MAI.MaxID);
1839 
1840   return false;
1841 }
1842