//===- SPIRVModuleAnalysis.cpp - analysis of global instrs & regs - C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// The analysis collects instructions that should be output at the module level
// and performs the global register numbering.
//
// The results of this analysis are used in AsmPrinter to rename registers
// globally and to output required instructions at the module level.
//
//===----------------------------------------------------------------------===//

#include "SPIRVModuleAnalysis.h"
#include "MCTargetDesc/SPIRVBaseInfo.h"
#include "MCTargetDesc/SPIRVMCTargetDesc.h"
#include "SPIRV.h"
#include "SPIRVSubtarget.h"
#include "SPIRVTargetMachine.h"
#include "SPIRVUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
#include "llvm/CodeGen/TargetPassConfig.h"

using namespace llvm;

#define DEBUG_TYPE "spirv-module-analysis"

static cl::opt<bool>
    SPVDumpDeps("spv-dump-deps",
                cl::desc("Dump MIR with SPIR-V dependencies info"),
                cl::Optional, cl::init(false));

static cl::list<SPIRV::Capability::Capability>
    AvoidCapabilities("avoid-spirv-capabilities",
                      cl::desc("SPIR-V capabilities to avoid if there are "
                               "other options enabling a feature"),
                      cl::ZeroOrMore, cl::Hidden,
                      cl::values(clEnumValN(SPIRV::Capability::Shader, "Shader",
                                            "SPIR-V Shader capability")));

// Use a set instead of cl::list to check the "contains" condition.
struct AvoidCapabilitiesSet {
  SmallSet<SPIRV::Capability::Capability, 4> S;
  AvoidCapabilitiesSet() {
    for (auto Cap : AvoidCapabilities)
      S.insert(Cap);
  }
};

char llvm::SPIRVModuleAnalysis::ID = 0;

namespace llvm {
void initializeSPIRVModuleAnalysisPass(PassRegistry &);
} // namespace llvm

INITIALIZE_PASS(SPIRVModuleAnalysis, DEBUG_TYPE, "SPIRV module analysis", true,
                true)

// Retrieve an unsigned from an MDNode with a list of them as operands.
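// For example, for metadata of the form !{i32 2, i32 0}, OpIndex 0 yields 2
// and OpIndex 1 yields 0; out-of-range indices fall back to DefaultVal.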
static unsigned getMetadataUInt(MDNode *MdNode, unsigned OpIndex,
                                unsigned DefaultVal = 0) {
  if (MdNode && OpIndex < MdNode->getNumOperands()) {
    const auto &Op = MdNode->getOperand(OpIndex);
    return mdconst::extract<ConstantInt>(Op)->getZExtValue();
  }
  return DefaultVal;
}

static SPIRV::Requirements
getSymbolicOperandRequirements(SPIRV::OperandCategory::OperandCategory Category,
                               unsigned i, const SPIRVSubtarget &ST,
                               SPIRV::RequirementHandler &Reqs) {
  static AvoidCapabilitiesSet
      AvoidCaps; // contains capabilities to avoid if there is another option

  VersionTuple ReqMinVer = getSymbolicOperandMinVersion(Category, i);
  VersionTuple ReqMaxVer = getSymbolicOperandMaxVersion(Category, i);
  VersionTuple SPIRVVersion = ST.getSPIRVVersion();
  bool MinVerOK = SPIRVVersion.empty() || SPIRVVersion >= ReqMinVer;
  bool MaxVerOK =
      ReqMaxVer.empty() || SPIRVVersion.empty() || SPIRVVersion <= ReqMaxVer;
  CapabilityList ReqCaps = getSymbolicOperandCapabilities(Category, i);
  ExtensionList ReqExts = getSymbolicOperandExtensions(Category, i);
  if (ReqCaps.empty()) {
    if (ReqExts.empty()) {
      if (MinVerOK && MaxVerOK)
        return {true, {}, {}, ReqMinVer, ReqMaxVer};
      return {false, {}, {}, VersionTuple(), VersionTuple()};
    }
  } else if (MinVerOK && MaxVerOK) {
    if (ReqCaps.size() == 1) {
      auto Cap = ReqCaps[0];
      if (Reqs.isCapabilityAvailable(Cap))
        return {true, {Cap}, ReqExts, ReqMinVer, ReqMaxVer};
    } else {
      // By the SPIR-V specification: "If an instruction, enumerant, or other
      // feature specifies multiple enabling capabilities, only one such
      // capability needs to be declared to use the feature." However, one
      // capability may be preferred over another. We use command-line
      // argument(s) and AvoidCapabilities to avoid selection of certain
      // capabilities if there are other options.
      CapabilityList UseCaps;
      for (auto Cap : ReqCaps)
        if (Reqs.isCapabilityAvailable(Cap))
          UseCaps.push_back(Cap);
      for (size_t i = 0, Sz = UseCaps.size(); i < Sz; ++i) {
        auto Cap = UseCaps[i];
        if (i == Sz - 1 || !AvoidCaps.S.contains(Cap))
          return {true, {Cap}, ReqExts, ReqMinVer, ReqMaxVer};
      }
    }
  }
  // If there are no capabilities, or we can't satisfy the version or
  // capability requirements, use the list of extensions (if the subtarget
  // can handle them all).
  if (llvm::all_of(ReqExts, [&ST](const SPIRV::Extension::Extension &Ext) {
        return ST.canUseExtension(Ext);
      })) {
    return {true,
            {},
            ReqExts,
            VersionTuple(),
            VersionTuple()}; // TODO: add versions to extensions.
  }
  return {false, {}, {}, VersionTuple(), VersionTuple()};
}

void SPIRVModuleAnalysis::setBaseInfo(const Module &M) {
  MAI.MaxID = 0;
  for (int i = 0; i < SPIRV::NUM_MODULE_SECTIONS; i++)
    MAI.MS[i].clear();
  MAI.RegisterAliasTable.clear();
  MAI.InstrsToDelete.clear();
  MAI.FuncMap.clear();
  MAI.GlobalVarList.clear();
  MAI.ExtInstSetMap.clear();
  MAI.Reqs.clear();
  MAI.Reqs.initAvailableCapabilities(*ST);

  // TODO: determine memory model and source language from the configuration.
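  // The "spirv.MemoryModel" named metadata, when present, is expected to carry
  // a pair of integer constants: the addressing model followed by the memory
  // model (e.g. !{i32 2, i32 2} selects Physical64 addressing with the OpenCL
  // memory model).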
  if (auto MemModel = M.getNamedMetadata("spirv.MemoryModel")) {
    auto MemMD = MemModel->getOperand(0);
    MAI.Addr = static_cast<SPIRV::AddressingModel::AddressingModel>(
        getMetadataUInt(MemMD, 0));
    MAI.Mem =
        static_cast<SPIRV::MemoryModel::MemoryModel>(getMetadataUInt(MemMD, 1));
  } else {
    // TODO: Add support for VulkanMemoryModel.
    MAI.Mem = ST->isOpenCLEnv() ? SPIRV::MemoryModel::OpenCL
                                : SPIRV::MemoryModel::GLSL450;
    if (MAI.Mem == SPIRV::MemoryModel::OpenCL) {
      unsigned PtrSize = ST->getPointerSize();
      MAI.Addr = PtrSize == 32   ? SPIRV::AddressingModel::Physical32
                 : PtrSize == 64 ? SPIRV::AddressingModel::Physical64
                                 : SPIRV::AddressingModel::Logical;
    } else {
      // TODO: Add support for PhysicalStorageBufferAddress.
      MAI.Addr = SPIRV::AddressingModel::Logical;
    }
  }
  // Get the OpenCL version number from metadata.
  // TODO: support other source languages.
  if (auto VerNode = M.getNamedMetadata("opencl.ocl.version")) {
    MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_C;
    // Construct the version literal in accordance with SPIRV-LLVM-Translator.
    // TODO: support multiple OCL version metadata.
    assert(VerNode->getNumOperands() > 0 && "Invalid SPIR");
    auto VersionMD = VerNode->getOperand(0);
    unsigned MajorNum = getMetadataUInt(VersionMD, 0, 2);
    unsigned MinorNum = getMetadataUInt(VersionMD, 1);
    unsigned RevNum = getMetadataUInt(VersionMD, 2);
    // Prevent the major part of the OpenCL version from being 0.
    MAI.SrcLangVersion =
        (std::max(1U, MajorNum) * 100 + MinorNum) * 1000 + RevNum;
  } else {
    // If there is no information about the OpenCL version, we are forced to
    // generate OpenCL 1.0 by default for the OpenCL environment to avoid
    // puzzling run-times with an Unknown/0.0 version output. For reference,
    // the LLVM-SPIRV Translator avoids potential issues with run-times in a
    // similar manner.
    if (ST->isOpenCLEnv()) {
      MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_CPP;
      MAI.SrcLangVersion = 100000;
    } else {
      MAI.SrcLang = SPIRV::SourceLanguage::Unknown;
      MAI.SrcLangVersion = 0;
    }
  }

  if (auto ExtNode = M.getNamedMetadata("opencl.used.extensions")) {
    for (unsigned I = 0, E = ExtNode->getNumOperands(); I != E; ++I) {
      MDNode *MD = ExtNode->getOperand(I);
      if (!MD || MD->getNumOperands() == 0)
        continue;
      for (unsigned J = 0, N = MD->getNumOperands(); J != N; ++J)
        MAI.SrcExt.insert(cast<MDString>(MD->getOperand(J))->getString());
    }
  }

  // Update required capabilities for this memory model, addressing model and
  // source language.
  MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand,
                                 MAI.Mem, *ST);
  MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::SourceLanguageOperand,
                                 MAI.SrcLang, *ST);
  MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,
                                 MAI.Addr, *ST);

  if (ST->isOpenCLEnv()) {
    // TODO: check if it's required by default.
    MAI.ExtInstSetMap[static_cast<unsigned>(
        SPIRV::InstructionSet::OpenCL_std)] =
        Register::index2VirtReg(MAI.getNextID());
  }
}

// Collect the MI which defines the register in the given machine function.
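// The defining instruction is marked as skipped for function-local emission
// and, unless DoInsert is false, appended to the requested module section.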
static void collectDefInstr(Register Reg, const MachineFunction *MF,
                            SPIRV::ModuleAnalysisInfo *MAI,
                            SPIRV::ModuleSectionType MSType,
                            bool DoInsert = true) {
  assert(MAI->hasRegisterAlias(MF, Reg) && "Cannot find register alias");
  MachineInstr *MI = MF->getRegInfo().getUniqueVRegDef(Reg);
  assert(MI && "There should be an instruction that defines the register");
  MAI->setSkipEmission(MI);
  if (DoInsert)
    MAI->MS[MSType].push_back(MI);
}

void SPIRVModuleAnalysis::collectGlobalEntities(
    const std::vector<SPIRV::DTSortableEntry *> &DepsGraph,
    SPIRV::ModuleSectionType MSType,
    std::function<bool(const SPIRV::DTSortableEntry *)> Pred,
    bool UsePreOrder = false) {
  DenseSet<const SPIRV::DTSortableEntry *> Visited;
  for (const auto *E : DepsGraph) {
    std::function<void(const SPIRV::DTSortableEntry *)> RecHoistUtil;
    // NOTE: here we prefer the recursive approach over the iterative one
    // because we don't expect dependency chains long enough to cause a stack
    // overflow.
    RecHoistUtil = [MSType, UsePreOrder, &Visited, &Pred,
                    &RecHoistUtil](const SPIRV::DTSortableEntry *E) {
      if (Visited.count(E) || !Pred(E))
        return;
      Visited.insert(E);

      // Traversing the deps graph in post-order allows us to get rid of
      // register alias preprocessing.
      // But pre-order is required for correct processing of function
      // declarations and arguments.
      if (!UsePreOrder)
        for (auto *S : E->getDeps())
          RecHoistUtil(S);

      Register GlobalReg = Register::index2VirtReg(MAI.getNextID());
      bool IsFirst = true;
      for (auto &U : *E) {
        const MachineFunction *MF = U.first;
        Register Reg = U.second;
        MAI.setRegisterAlias(MF, Reg, GlobalReg);
        if (!MF->getRegInfo().getUniqueVRegDef(Reg))
          continue;
        collectDefInstr(Reg, MF, &MAI, MSType, IsFirst);
        IsFirst = false;
        if (E->getIsGV())
          MAI.GlobalVarList.push_back(MF->getRegInfo().getUniqueVRegDef(Reg));
      }

      if (UsePreOrder)
        for (auto *S : E->getDeps())
          RecHoistUtil(S);
    };
    RecHoistUtil(E);
  }
}

// The function initializes the global register alias table for types, consts,
// global vars and func decls and collects these instructions for output
// at the module level. It also collects explicit OpExtension/OpCapability
// instructions.
void SPIRVModuleAnalysis::processDefInstrs(const Module &M) {
  std::vector<SPIRV::DTSortableEntry *> DepsGraph;

  GR->buildDepsGraph(DepsGraph, TII, SPVDumpDeps ? MMI : nullptr);

  collectGlobalEntities(
      DepsGraph, SPIRV::MB_TypeConstVars,
      [](const SPIRV::DTSortableEntry *E) { return !E->getIsFunc(); });

  for (auto F = M.begin(), E = M.end(); F != E; ++F) {
    MachineFunction *MF = MMI->getMachineFunction(*F);
    if (!MF)
      continue;
    // Iterate through and collect OpExtension/OpCapability instructions.
    for (MachineBasicBlock &MBB : *MF) {
      for (MachineInstr &MI : MBB) {
        if (MI.getOpcode() == SPIRV::OpExtension) {
          // Here, OpExtension just has a single enum operand, not a string.
          auto Ext = SPIRV::Extension::Extension(MI.getOperand(0).getImm());
          MAI.Reqs.addExtension(Ext);
          MAI.setSkipEmission(&MI);
        } else if (MI.getOpcode() == SPIRV::OpCapability) {
          auto Cap = SPIRV::Capability::Capability(MI.getOperand(0).getImm());
          MAI.Reqs.addCapability(Cap);
          MAI.setSkipEmission(&MI);
        }
      }
    }
  }

  collectGlobalEntities(
      DepsGraph, SPIRV::MB_ExtFuncDecls,
      [](const SPIRV::DTSortableEntry *E) { return E->getIsFunc(); }, true);
}

// Look for IDs declared with Import linkage, and map the corresponding
// function to the register defining that variable (which will usually be the
// result of an OpFunction). This lets us call externally imported functions
// using the correct ID registers.
void SPIRVModuleAnalysis::collectFuncNames(MachineInstr &MI,
                                           const Function *F) {
  if (MI.getOpcode() == SPIRV::OpDecorate) {
    // If it's got Import linkage.
    auto Dec = MI.getOperand(1).getImm();
    if (Dec == static_cast<unsigned>(SPIRV::Decoration::LinkageAttributes)) {
      auto Lnk = MI.getOperand(MI.getNumOperands() - 1).getImm();
      if (Lnk == static_cast<unsigned>(SPIRV::LinkageType::Import)) {
        // Map the imported function name to the function ID register.
        const Function *ImportedFunc =
            F->getParent()->getFunction(getStringImm(MI, 2));
        Register Target = MI.getOperand(0).getReg();
        MAI.FuncMap[ImportedFunc] = MAI.getRegisterAlias(MI.getMF(), Target);
      }
    }
  } else if (MI.getOpcode() == SPIRV::OpFunction) {
    // Record all internal OpFunction declarations.
    Register Reg = MI.defs().begin()->getReg();
    Register GlobalReg = MAI.getRegisterAlias(MI.getMF(), Reg);
    assert(GlobalReg.isValid());
    MAI.FuncMap[F] = GlobalReg;
  }
}

// References to a function via function pointers generate virtual
// registers without a definition. We are able to resolve this
// reference using Global Register info into an OpFunction instruction
// and replace dummy operands by the corresponding global register references.
void SPIRVModuleAnalysis::collectFuncPtrs() {
  for (auto &MI : MAI.MS[SPIRV::MB_TypeConstVars])
    if (MI->getOpcode() == SPIRV::OpConstantFunctionPointerINTEL)
      collectFuncPtrs(MI);
}

void SPIRVModuleAnalysis::collectFuncPtrs(MachineInstr *MI) {
  const MachineOperand *FunUse = &MI->getOperand(2);
  if (const MachineOperand *FunDef = GR->getFunctionDefinitionByUse(FunUse)) {
    const MachineInstr *FunDefMI = FunDef->getParent();
    assert(FunDefMI->getOpcode() == SPIRV::OpFunction &&
           "Constant function pointer must refer to function definition");
    Register FunDefReg = FunDef->getReg();
    Register GlobalFunDefReg =
        MAI.getRegisterAlias(FunDefMI->getMF(), FunDefReg);
    assert(GlobalFunDefReg.isValid() &&
           "Function definition must refer to a global register");
    Register FunPtrReg = FunUse->getReg();
    MAI.setRegisterAlias(MI->getMF(), FunPtrReg, GlobalFunDefReg);
  }
}

using InstrSignature = SmallVector<size_t>;
using InstrTraces = std::set<InstrSignature>;

// Returns a representation of an instruction as a vector of MachineOperand
// hash values, see llvm::hash_value(const MachineOperand &MO) for details.
// This creates a signature of the instruction with the same content
// that MachineOperand::isIdenticalTo uses for comparison.
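// Register operands are hashed through their global alias so that identical
// instructions coming from different functions map to the same signature.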
static InstrSignature instrToSignature(MachineInstr &MI,
                                       SPIRV::ModuleAnalysisInfo &MAI) {
  InstrSignature Signature;
  for (unsigned i = 0; i < MI.getNumOperands(); ++i) {
    const MachineOperand &MO = MI.getOperand(i);
    size_t h;
    if (MO.isReg()) {
      Register RegAlias = MAI.getRegisterAlias(MI.getMF(), MO.getReg());
      // Mimic llvm::hash_value(const MachineOperand &MO).
      h = hash_combine(MO.getType(), (unsigned)RegAlias, MO.getSubReg(),
                       MO.isDef());
    } else {
      h = hash_value(MO);
    }
    Signature.push_back(h);
  }
  return Signature;
}

// Collect the given instruction in the specified MS. We assume global register
// numbering has already occurred by this point. We can directly compare reg
// arguments when detecting duplicates.
static void collectOtherInstr(MachineInstr &MI, SPIRV::ModuleAnalysisInfo &MAI,
                              SPIRV::ModuleSectionType MSType, InstrTraces &IS,
                              bool Append = true) {
  MAI.setSkipEmission(&MI);
  InstrSignature MISign = instrToSignature(MI, MAI);
  auto FoundMI = IS.insert(MISign);
  if (!FoundMI.second)
    return; // insert failed, so we found a duplicate; don't add it to MAI.MS
  // No duplicates, so add it.
  if (Append)
    MAI.MS[MSType].push_back(&MI);
  else
    MAI.MS[MSType].insert(MAI.MS[MSType].begin(), &MI);
}

// Some global instructions make reference to function-local ID regs, so they
// cannot be correctly collected until these registers are globally numbered.
void SPIRVModuleAnalysis::processOtherInstrs(const Module &M) {
  InstrTraces IS;
  for (auto F = M.begin(), E = M.end(); F != E; ++F) {
    if ((*F).isDeclaration())
      continue;
    MachineFunction *MF = MMI->getMachineFunction(*F);
    assert(MF);
    for (MachineBasicBlock &MBB : *MF)
      for (MachineInstr &MI : MBB) {
        if (MAI.getSkipEmission(&MI))
          continue;
        const unsigned OpCode = MI.getOpcode();
        if (OpCode == SPIRV::OpString) {
          collectOtherInstr(MI, MAI, SPIRV::MB_DebugStrings, IS);
        } else if (OpCode == SPIRV::OpExtInst && MI.getOperand(2).isImm() &&
                   MI.getOperand(2).getImm() ==
                       SPIRV::InstructionSet::
                           NonSemantic_Shader_DebugInfo_100) {
          MachineOperand Ins = MI.getOperand(3);
          namespace NS = SPIRV::NonSemanticExtInst;
          static constexpr int64_t GlobalNonSemanticDITy[] = {
              NS::DebugSource, NS::DebugCompilationUnit, NS::DebugInfoNone,
              NS::DebugTypeBasic, NS::DebugTypePointer};
          bool IsGlobalDI = false;
          for (unsigned Idx = 0; Idx < std::size(GlobalNonSemanticDITy); ++Idx)
            IsGlobalDI |= Ins.getImm() == GlobalNonSemanticDITy[Idx];
          if (IsGlobalDI)
            collectOtherInstr(MI, MAI, SPIRV::MB_NonSemanticGlobalDI, IS);
        } else if (OpCode == SPIRV::OpName || OpCode == SPIRV::OpMemberName) {
          collectOtherInstr(MI, MAI, SPIRV::MB_DebugNames, IS);
        } else if (OpCode == SPIRV::OpEntryPoint) {
          collectOtherInstr(MI, MAI, SPIRV::MB_EntryPoints, IS);
        } else if (TII->isDecorationInstr(MI)) {
          collectOtherInstr(MI, MAI, SPIRV::MB_Annotations, IS);
          collectFuncNames(MI, &*F);
        } else if (TII->isConstantInstr(MI)) {
          // Now OpSpecConstant*s are not in DT,
          // but they need to be collected anyway.
          collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, IS);
        } else if (OpCode == SPIRV::OpFunction) {
          collectFuncNames(MI, &*F);
        } else if (OpCode == SPIRV::OpTypeForwardPointer) {
          collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, IS, false);
        }
      }
  }
}

// Number registers in all functions globally from 0 onwards and store
// the result in the global register alias table. Some registers are already
// numbered in collectGlobalEntities.
void SPIRVModuleAnalysis::numberRegistersGlobally(const Module &M) {
  for (auto F = M.begin(), E = M.end(); F != E; ++F) {
    if ((*F).isDeclaration())
      continue;
    MachineFunction *MF = MMI->getMachineFunction(*F);
    assert(MF);
    for (MachineBasicBlock &MBB : *MF) {
      for (MachineInstr &MI : MBB) {
        for (MachineOperand &Op : MI.operands()) {
          if (!Op.isReg())
            continue;
          Register Reg = Op.getReg();
          if (MAI.hasRegisterAlias(MF, Reg))
            continue;
          Register NewReg = Register::index2VirtReg(MAI.getNextID());
          MAI.setRegisterAlias(MF, Reg, NewReg);
        }
        if (MI.getOpcode() != SPIRV::OpExtInst)
          continue;
        auto Set = MI.getOperand(2).getImm();
        if (!MAI.ExtInstSetMap.contains(Set))
          MAI.ExtInstSetMap[Set] = Register::index2VirtReg(MAI.getNextID());
      }
    }
  }
}

// RequirementHandler implementations.
void SPIRV::RequirementHandler::getAndAddRequirements(
    SPIRV::OperandCategory::OperandCategory Category, uint32_t i,
    const SPIRVSubtarget &ST) {
  addRequirements(getSymbolicOperandRequirements(Category, i, ST, *this));
}

void SPIRV::RequirementHandler::recursiveAddCapabilities(
    const CapabilityList &ToPrune) {
  for (const auto &Cap : ToPrune) {
    AllCaps.insert(Cap);
    CapabilityList ImplicitDecls =
        getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);
    recursiveAddCapabilities(ImplicitDecls);
  }
}

void SPIRV::RequirementHandler::addCapabilities(const CapabilityList &ToAdd) {
  for (const auto &Cap : ToAdd) {
    bool IsNewlyInserted = AllCaps.insert(Cap).second;
    if (!IsNewlyInserted) // Don't re-add if it's already been declared.
      continue;
    CapabilityList ImplicitDecls =
        getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);
    recursiveAddCapabilities(ImplicitDecls);
    MinimalCaps.push_back(Cap);
  }
}

void SPIRV::RequirementHandler::addRequirements(
    const SPIRV::Requirements &Req) {
  if (!Req.IsSatisfiable)
    report_fatal_error("Adding SPIR-V requirements this target can't satisfy.");

  if (Req.Cap.has_value())
    addCapabilities({Req.Cap.value()});

  addExtensions(Req.Exts);

  if (!Req.MinVer.empty()) {
    if (!MaxVersion.empty() && Req.MinVer > MaxVersion) {
      LLVM_DEBUG(dbgs() << "Conflicting version requirements: >= " << Req.MinVer
                        << " and <= " << MaxVersion << "\n");
      report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");
    }

    if (MinVersion.empty() || Req.MinVer > MinVersion)
      MinVersion = Req.MinVer;
  }

  if (!Req.MaxVer.empty()) {
    if (!MinVersion.empty() && Req.MaxVer < MinVersion) {
      LLVM_DEBUG(dbgs() << "Conflicting version requirements: <= " << Req.MaxVer
                        << " and >= " << MinVersion << "\n");
      report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");
    }

    if (MaxVersion.empty() || Req.MaxVer < MaxVersion)
      MaxVersion = Req.MaxVer;
  }
}

void SPIRV::RequirementHandler::checkSatisfiable(
    const SPIRVSubtarget &ST) const {
  // Report as many errors as possible before aborting the compilation.
  bool IsSatisfiable = true;
  auto TargetVer = ST.getSPIRVVersion();

  if (!MaxVersion.empty() && !TargetVer.empty() && MaxVersion < TargetVer) {
    LLVM_DEBUG(
        dbgs() << "Target SPIR-V version too high for required features\n"
               << "Required max version: " << MaxVersion << " target version "
               << TargetVer << "\n");
    IsSatisfiable = false;
  }

  if (!MinVersion.empty() && !TargetVer.empty() && MinVersion > TargetVer) {
    LLVM_DEBUG(dbgs() << "Target SPIR-V version too low for required features\n"
                      << "Required min version: " << MinVersion
                      << " target version " << TargetVer << "\n");
    IsSatisfiable = false;
  }

  if (!MinVersion.empty() && !MaxVersion.empty() && MinVersion > MaxVersion) {
    LLVM_DEBUG(
        dbgs()
        << "Version is too low for some features and too high for others.\n"
        << "Required SPIR-V min version: " << MinVersion
        << " required SPIR-V max version " << MaxVersion << "\n");
    IsSatisfiable = false;
  }

  for (auto Cap : MinimalCaps) {
    if (AvailableCaps.contains(Cap))
      continue;
    LLVM_DEBUG(dbgs() << "Capability not supported: "
                      << getSymbolicOperandMnemonic(
                             OperandCategory::CapabilityOperand, Cap)
                      << "\n");
    IsSatisfiable = false;
  }

  for (auto Ext : AllExtensions) {
    if (ST.canUseExtension(Ext))
      continue;
    LLVM_DEBUG(dbgs() << "Extension not supported: "
                      << getSymbolicOperandMnemonic(
                             OperandCategory::ExtensionOperand, Ext)
                      << "\n");
    IsSatisfiable = false;
  }

  if (!IsSatisfiable)
    report_fatal_error("Unable to meet SPIR-V requirements for this target.");
}

// Add the given capabilities and all their implicitly defined capabilities
// too.
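// The insertion check keeps the recursion finite even if the capability
// dependency graph contains cycles.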
void SPIRV::RequirementHandler::addAvailableCaps(const CapabilityList &ToAdd) {
  for (const auto Cap : ToAdd)
    if (AvailableCaps.insert(Cap).second)
      addAvailableCaps(getSymbolicOperandCapabilities(
          SPIRV::OperandCategory::CapabilityOperand, Cap));
}

void SPIRV::RequirementHandler::removeCapabilityIf(
    const Capability::Capability ToRemove,
    const Capability::Capability IfPresent) {
  if (AllCaps.contains(IfPresent))
    AllCaps.erase(ToRemove);
}

namespace llvm {
namespace SPIRV {
void RequirementHandler::initAvailableCapabilities(const SPIRVSubtarget &ST) {
  // Provided by all supported Vulkan versions and OpenCL.
  addAvailableCaps({Capability::Shader, Capability::Linkage, Capability::Int8,
                    Capability::Int16});

  if (ST.isAtLeastSPIRVVer(VersionTuple(1, 3)))
    addAvailableCaps({Capability::GroupNonUniform,
                      Capability::GroupNonUniformVote,
                      Capability::GroupNonUniformArithmetic,
                      Capability::GroupNonUniformBallot,
                      Capability::GroupNonUniformClustered,
                      Capability::GroupNonUniformShuffle,
                      Capability::GroupNonUniformShuffleRelative});

  if (ST.isAtLeastSPIRVVer(VersionTuple(1, 6)))
    addAvailableCaps({Capability::DotProduct, Capability::DotProductInputAll,
                      Capability::DotProductInput4x8Bit,
                      Capability::DotProductInput4x8BitPacked,
                      Capability::DemoteToHelperInvocation});

  // Add capabilities enabled by extensions.
  for (auto Extension : ST.getAllAvailableExtensions()) {
    CapabilityList EnabledCapabilities =
        getCapabilitiesEnabledByExtension(Extension);
    addAvailableCaps(EnabledCapabilities);
  }

  if (ST.isOpenCLEnv()) {
    initAvailableCapabilitiesForOpenCL(ST);
    return;
  }

  if (ST.isVulkanEnv()) {
    initAvailableCapabilitiesForVulkan(ST);
    return;
  }

  report_fatal_error("Unimplemented environment for SPIR-V generation.");
}

void RequirementHandler::initAvailableCapabilitiesForOpenCL(
    const SPIRVSubtarget &ST) {
  // Add the min requirements for different OpenCL and SPIR-V versions.
  addAvailableCaps({Capability::Addresses, Capability::Float16Buffer,
                    Capability::Kernel, Capability::Vector16,
                    Capability::Groups, Capability::GenericPointer,
                    Capability::StorageImageWriteWithoutFormat,
                    Capability::StorageImageReadWithoutFormat});
  if (ST.hasOpenCLFullProfile())
    addAvailableCaps({Capability::Int64, Capability::Int64Atomics});
  if (ST.hasOpenCLImageSupport()) {
    addAvailableCaps({Capability::ImageBasic, Capability::LiteralSampler,
                      Capability::Image1D, Capability::SampledBuffer,
                      Capability::ImageBuffer});
    if (ST.isAtLeastOpenCLVer(VersionTuple(2, 0)))
      addAvailableCaps({Capability::ImageReadWrite});
  }
  if (ST.isAtLeastSPIRVVer(VersionTuple(1, 1)) &&
      ST.isAtLeastOpenCLVer(VersionTuple(2, 2)))
    addAvailableCaps({Capability::SubgroupDispatch, Capability::PipeStorage});
  if (ST.isAtLeastSPIRVVer(VersionTuple(1, 4)))
    addAvailableCaps({Capability::DenormPreserve,
                      Capability::DenormFlushToZero,
                      Capability::SignedZeroInfNanPreserve,
                      Capability::RoundingModeRTE,
                      Capability::RoundingModeRTZ});
  // TODO: verify if this needs some checks.
  addAvailableCaps({Capability::Float16, Capability::Float64});

  // TODO: add OpenCL extensions.
}

void RequirementHandler::initAvailableCapabilitiesForVulkan(
    const SPIRVSubtarget &ST) {

  // Core in Vulkan 1.1 and earlier.
  addAvailableCaps({Capability::Int64, Capability::Float16,
                    Capability::Float64, Capability::GroupNonUniform,
                    Capability::Image1D, Capability::SampledBuffer,
                    Capability::ImageBuffer,
                    Capability::UniformBufferArrayDynamicIndexing,
                    Capability::SampledImageArrayDynamicIndexing,
                    Capability::StorageBufferArrayDynamicIndexing,
                    Capability::StorageImageArrayDynamicIndexing});

  // Became core in Vulkan 1.2.
  if (ST.isAtLeastSPIRVVer(VersionTuple(1, 5))) {
    addAvailableCaps(
        {Capability::ShaderNonUniformEXT,
         Capability::RuntimeDescriptorArrayEXT,
         Capability::InputAttachmentArrayDynamicIndexingEXT,
         Capability::UniformTexelBufferArrayDynamicIndexingEXT,
         Capability::StorageTexelBufferArrayDynamicIndexingEXT,
         Capability::UniformBufferArrayNonUniformIndexingEXT,
         Capability::SampledImageArrayNonUniformIndexingEXT,
         Capability::StorageBufferArrayNonUniformIndexingEXT,
         Capability::StorageImageArrayNonUniformIndexingEXT,
         Capability::InputAttachmentArrayNonUniformIndexingEXT,
         Capability::UniformTexelBufferArrayNonUniformIndexingEXT,
         Capability::StorageTexelBufferArrayNonUniformIndexingEXT});
  }

  // Became core in Vulkan 1.3.
  if (ST.isAtLeastSPIRVVer(VersionTuple(1, 6)))
    addAvailableCaps({Capability::StorageImageWriteWithoutFormat,
                      Capability::StorageImageReadWithoutFormat});
}

} // namespace SPIRV
} // namespace llvm

// Add the required capabilities from a decoration instruction (including
// BuiltIns).
static void addOpDecorateReqs(const MachineInstr &MI, unsigned DecIndex,
                              SPIRV::RequirementHandler &Reqs,
                              const SPIRVSubtarget &ST) {
  int64_t DecOp = MI.getOperand(DecIndex).getImm();
  auto Dec = static_cast<SPIRV::Decoration::Decoration>(DecOp);
  Reqs.addRequirements(getSymbolicOperandRequirements(
      SPIRV::OperandCategory::DecorationOperand, Dec, ST, Reqs));

  if (Dec == SPIRV::Decoration::BuiltIn) {
    int64_t BuiltInOp = MI.getOperand(DecIndex + 1).getImm();
    auto BuiltIn = static_cast<SPIRV::BuiltIn::BuiltIn>(BuiltInOp);
    Reqs.addRequirements(getSymbolicOperandRequirements(
        SPIRV::OperandCategory::BuiltInOperand, BuiltIn, ST, Reqs));
  } else if (Dec == SPIRV::Decoration::LinkageAttributes) {
    int64_t LinkageOp = MI.getOperand(MI.getNumOperands() - 1).getImm();
    SPIRV::LinkageType::LinkageType LnkType =
        static_cast<SPIRV::LinkageType::LinkageType>(LinkageOp);
    if (LnkType == SPIRV::LinkageType::LinkOnceODR)
      Reqs.addExtension(SPIRV::Extension::SPV_KHR_linkonce_odr);
  } else if (Dec == SPIRV::Decoration::CacheControlLoadINTEL ||
             Dec == SPIRV::Decoration::CacheControlStoreINTEL) {
    Reqs.addExtension(SPIRV::Extension::SPV_INTEL_cache_controls);
  } else if (Dec == SPIRV::Decoration::HostAccessINTEL) {
    Reqs.addExtension(SPIRV::Extension::SPV_INTEL_global_variable_host_access);
  } else if (Dec == SPIRV::Decoration::InitModeINTEL ||
             Dec == SPIRV::Decoration::ImplementInRegisterMapINTEL) {
    Reqs.addExtension(
        SPIRV::Extension::SPV_INTEL_global_variable_fpga_decorations);
  } else if (Dec == SPIRV::Decoration::NonUniformEXT) {
    Reqs.addRequirements(SPIRV::Capability::ShaderNonUniformEXT);
  }
}

// Add requirements for image handling.
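// OpTypeImage operands (as laid out in the MachineInstr): result, sampled
// type, dim, depth, arrayed, MS, sampled, image format and an optional
// access qualifier.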
static void addOpTypeImageReqs(const MachineInstr &MI,
                               SPIRV::RequirementHandler &Reqs,
                               const SPIRVSubtarget &ST) {
  assert(MI.getNumOperands() >= 8 && "Insufficient operands for OpTypeImage");
  // The operand indices used here are based on the OpTypeImage layout, which
  // the MachineInstr follows as well.
  int64_t ImgFormatOp = MI.getOperand(7).getImm();
  auto ImgFormat = static_cast<SPIRV::ImageFormat::ImageFormat>(ImgFormatOp);
  Reqs.getAndAddRequirements(SPIRV::OperandCategory::ImageFormatOperand,
                             ImgFormat, ST);

  bool IsArrayed = MI.getOperand(4).getImm() == 1;
  bool IsMultisampled = MI.getOperand(5).getImm() == 1;
  bool NoSampler = MI.getOperand(6).getImm() == 2;
  // Add dimension requirements.
  assert(MI.getOperand(2).isImm());
  switch (MI.getOperand(2).getImm()) {
  case SPIRV::Dim::DIM_1D:
    Reqs.addRequirements(NoSampler ? SPIRV::Capability::Image1D
                                   : SPIRV::Capability::Sampled1D);
    break;
  case SPIRV::Dim::DIM_2D:
    if (IsMultisampled && NoSampler)
      Reqs.addRequirements(SPIRV::Capability::ImageMSArray);
    break;
  case SPIRV::Dim::DIM_Cube:
    Reqs.addRequirements(SPIRV::Capability::Shader);
    if (IsArrayed)
      Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageCubeArray
                                     : SPIRV::Capability::SampledCubeArray);
    break;
  case SPIRV::Dim::DIM_Rect:
    Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageRect
                                   : SPIRV::Capability::SampledRect);
    break;
  case SPIRV::Dim::DIM_Buffer:
    Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageBuffer
                                   : SPIRV::Capability::SampledBuffer);
    break;
  case SPIRV::Dim::DIM_SubpassData:
    Reqs.addRequirements(SPIRV::Capability::InputAttachment);
    break;
  }

  // Has optional access qualifier.
  if (ST.isOpenCLEnv()) {
    if (MI.getNumOperands() > 8 &&
        MI.getOperand(8).getImm() == SPIRV::AccessQualifier::ReadWrite)
      Reqs.addRequirements(SPIRV::Capability::ImageReadWrite);
    else
      Reqs.addRequirements(SPIRV::Capability::ImageBasic);
  }
}

// Add requirements for handling atomic float instructions.
#define ATOM_FLT_REQ_EXT_MSG(ExtName)                                         \
  "The atomic float instruction requires the following SPIR-V "              \
  "extension: SPV_EXT_shader_atomic_float" ExtName
static void AddAtomicFloatRequirements(const MachineInstr &MI,
                                       SPIRV::RequirementHandler &Reqs,
                                       const SPIRVSubtarget &ST) {
  assert(MI.getOperand(1).isReg() &&
         "Expect register operand in atomic float instruction");
  Register TypeReg = MI.getOperand(1).getReg();
  SPIRVType *TypeDef = MI.getMF()->getRegInfo().getVRegDef(TypeReg);
  if (TypeDef->getOpcode() != SPIRV::OpTypeFloat)
    report_fatal_error("Result type of an atomic float instruction must be a "
                       "floating-point type scalar");

  unsigned BitWidth = TypeDef->getOperand(1).getImm();
  unsigned Op = MI.getOpcode();
  if (Op == SPIRV::OpAtomicFAddEXT) {
    if (!ST.canUseExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add))
      report_fatal_error(ATOM_FLT_REQ_EXT_MSG("_add"), false);
    Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add);
    switch (BitWidth) {
    case 16:
      if (!ST.canUseExtension(
              SPIRV::Extension::SPV_EXT_shader_atomic_float16_add))
        report_fatal_error(ATOM_FLT_REQ_EXT_MSG("16_add"), false);
      Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float16_add);
      Reqs.addCapability(SPIRV::Capability::AtomicFloat16AddEXT);
      break;
    case 32:
      Reqs.addCapability(SPIRV::Capability::AtomicFloat32AddEXT);
      break;
    case 64:
      Reqs.addCapability(SPIRV::Capability::AtomicFloat64AddEXT);
      break;
    default:
      report_fatal_error(
          "Unexpected floating-point type width in atomic float instruction");
    }
  } else {
    if (!ST.canUseExtension(
            SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max))
      report_fatal_error(ATOM_FLT_REQ_EXT_MSG("_min_max"), false);
    Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max);
    switch (BitWidth) {
    case 16:
      Reqs.addCapability(SPIRV::Capability::AtomicFloat16MinMaxEXT);
      break;
    case 32:
      Reqs.addCapability(SPIRV::Capability::AtomicFloat32MinMaxEXT);
      break;
    case 64:
      Reqs.addCapability(SPIRV::Capability::AtomicFloat64MinMaxEXT);
      break;
    default:
      report_fatal_error(
          "Unexpected floating-point type width in atomic float instruction");
    }
  }
}

bool isUniformTexelBuffer(MachineInstr *ImageInst) {
  if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
    return false;
  uint32_t Dim = ImageInst->getOperand(2).getImm();
  uint32_t Sampled = ImageInst->getOperand(6).getImm();
  return Dim == SPIRV::Dim::DIM_Buffer && Sampled == 1;
}

bool isStorageTexelBuffer(MachineInstr *ImageInst) {
  if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
    return false;
  uint32_t Dim = ImageInst->getOperand(2).getImm();
  uint32_t Sampled = ImageInst->getOperand(6).getImm();
  return Dim == SPIRV::Dim::DIM_Buffer && Sampled == 2;
}

bool isSampledImage(MachineInstr *ImageInst) {
  if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
    return false;
  uint32_t Dim = ImageInst->getOperand(2).getImm();
  uint32_t Sampled = ImageInst->getOperand(6).getImm();
  return Dim != SPIRV::Dim::DIM_Buffer && Sampled == 1;
}

bool isInputAttachment(MachineInstr *ImageInst) {
  if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
    return false;
  uint32_t Dim = ImageInst->getOperand(2).getImm();
  uint32_t Sampled = ImageInst->getOperand(6).getImm();
  return Dim == SPIRV::Dim::DIM_SubpassData && Sampled == 2;
}

bool isStorageImage(MachineInstr *ImageInst) {
  if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
    return false;
  uint32_t Dim = ImageInst->getOperand(2).getImm();
  uint32_t Sampled = ImageInst->getOperand(6).getImm();
  return Dim != SPIRV::Dim::DIM_Buffer && Sampled == 2;
}

bool isCombinedImageSampler(MachineInstr *SampledImageInst) {
  if (SampledImageInst->getOpcode() != SPIRV::OpTypeSampledImage)
    return false;

  const MachineRegisterInfo &MRI = SampledImageInst->getMF()->getRegInfo();
  Register ImageReg = SampledImageInst->getOperand(1).getReg();
  auto *ImageInst = MRI.getUniqueVRegDef(ImageReg);
  return isSampledImage(ImageInst);
}

bool hasNonUniformDecoration(Register Reg, const MachineRegisterInfo &MRI) {
  for (const auto &MI : MRI.reg_instructions(Reg)) {
    if (MI.getOpcode() != SPIRV::OpDecorate)
      continue;

    uint32_t Dec = MI.getOperand(1).getImm();
    if (Dec == SPIRV::Decoration::NonUniformEXT)
      return true;
  }
  return false;
}

void addOpAccessChainReqs(const MachineInstr &Instr,
                          SPIRV::RequirementHandler &Handler,
                          const SPIRVSubtarget &Subtarget) {
  const MachineRegisterInfo &MRI = Instr.getMF()->getRegInfo();
  // Get the result type. If it is an image type, then the shader uses
  // descriptor indexing. The appropriate capabilities will be added based
  // on the specifics of the image.
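  // Only access chains whose result type is a pointer in the UniformConstant,
  // Uniform or StorageBuffer storage class are considered; everything else
  // returns early below.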
  Register ResTypeReg = Instr.getOperand(1).getReg();
  MachineInstr *ResTypeInst = MRI.getUniqueVRegDef(ResTypeReg);

  assert(ResTypeInst->getOpcode() == SPIRV::OpTypePointer);
  uint32_t StorageClass = ResTypeInst->getOperand(1).getImm();
  if (StorageClass != SPIRV::StorageClass::StorageClass::UniformConstant &&
      StorageClass != SPIRV::StorageClass::StorageClass::Uniform &&
      StorageClass != SPIRV::StorageClass::StorageClass::StorageBuffer) {
    return;
  }

  Register PointeeTypeReg = ResTypeInst->getOperand(2).getReg();
  MachineInstr *PointeeType = MRI.getUniqueVRegDef(PointeeTypeReg);
  if (PointeeType->getOpcode() != SPIRV::OpTypeImage &&
      PointeeType->getOpcode() != SPIRV::OpTypeSampledImage &&
      PointeeType->getOpcode() != SPIRV::OpTypeSampler) {
    return;
  }

  bool IsNonUniform =
      hasNonUniformDecoration(Instr.getOperand(0).getReg(), MRI);
  if (isUniformTexelBuffer(PointeeType)) {
    if (IsNonUniform)
      Handler.addRequirements(
          SPIRV::Capability::UniformTexelBufferArrayNonUniformIndexingEXT);
    else
      Handler.addRequirements(
          SPIRV::Capability::UniformTexelBufferArrayDynamicIndexingEXT);
  } else if (isInputAttachment(PointeeType)) {
    if (IsNonUniform)
      Handler.addRequirements(
          SPIRV::Capability::InputAttachmentArrayNonUniformIndexingEXT);
    else
      Handler.addRequirements(
          SPIRV::Capability::InputAttachmentArrayDynamicIndexingEXT);
  } else if (isStorageTexelBuffer(PointeeType)) {
    if (IsNonUniform)
      Handler.addRequirements(
          SPIRV::Capability::StorageTexelBufferArrayNonUniformIndexingEXT);
    else
      Handler.addRequirements(
          SPIRV::Capability::StorageTexelBufferArrayDynamicIndexingEXT);
  } else if (isSampledImage(PointeeType) ||
             isCombinedImageSampler(PointeeType) ||
             PointeeType->getOpcode() == SPIRV::OpTypeSampler) {
    if (IsNonUniform)
      Handler.addRequirements(
          SPIRV::Capability::SampledImageArrayNonUniformIndexingEXT);
    else
      Handler.addRequirements(
          SPIRV::Capability::SampledImageArrayDynamicIndexing);
  } else if (isStorageImage(PointeeType)) {
    if (IsNonUniform)
      Handler.addRequirements(
          SPIRV::Capability::StorageImageArrayNonUniformIndexingEXT);
    else
      Handler.addRequirements(
          SPIRV::Capability::StorageImageArrayDynamicIndexing);
  }
}

static bool isImageTypeWithUnknownFormat(SPIRVType *TypeInst) {
  if (TypeInst->getOpcode() != SPIRV::OpTypeImage)
    return false;
  assert(TypeInst->getOperand(7).isImm() &&
         "The image format must be an imm.");
  return TypeInst->getOperand(7).getImm() == 0;
}

static void AddDotProductRequirements(const MachineInstr &MI,
                                      SPIRV::RequirementHandler &Reqs,
                                      const SPIRVSubtarget &ST) {
  if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_integer_dot_product))
    Reqs.addExtension(SPIRV::Extension::SPV_KHR_integer_dot_product);
  Reqs.addCapability(SPIRV::Capability::DotProduct);

  const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
  assert(MI.getOperand(2).isReg() && "Unexpected operand in dot");
  // We do not consider what the previous instruction is. This is just used
  // to get the input register and to check the type.
  const MachineInstr *Input = MRI.getVRegDef(MI.getOperand(2).getReg());
  assert(Input->getOperand(1).isReg() && "Unexpected operand in dot input");
  Register InputReg = Input->getOperand(1).getReg();

  SPIRVType *TypeDef = MRI.getVRegDef(InputReg);
  if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
    assert(TypeDef->getOperand(1).getImm() == 32);
    Reqs.addCapability(SPIRV::Capability::DotProductInput4x8BitPacked);
  } else if (TypeDef->getOpcode() == SPIRV::OpTypeVector) {
    SPIRVType *ScalarTypeDef = MRI.getVRegDef(TypeDef->getOperand(1).getReg());
    assert(ScalarTypeDef->getOpcode() == SPIRV::OpTypeInt);
    if (ScalarTypeDef->getOperand(1).getImm() == 8) {
      assert(TypeDef->getOperand(2).getImm() == 4 &&
             "Dot operand of 8-bit integer type requires 4 components");
      Reqs.addCapability(SPIRV::Capability::DotProductInput4x8Bit);
    } else {
      Reqs.addCapability(SPIRV::Capability::DotProductInputAll);
    }
  }
}

void addInstrRequirements(const MachineInstr &MI,
                          SPIRV::RequirementHandler &Reqs,
                          const SPIRVSubtarget &ST) {
  switch (MI.getOpcode()) {
  case SPIRV::OpMemoryModel: {
    int64_t Addr = MI.getOperand(0).getImm();
    Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,
                               Addr, ST);
    int64_t Mem = MI.getOperand(1).getImm();
    Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand, Mem,
                               ST);
    break;
  }
  case SPIRV::OpEntryPoint: {
    int64_t Exe = MI.getOperand(0).getImm();
    Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModelOperand,
                               Exe, ST);
    break;
  }
  case SPIRV::OpExecutionMode:
  case SPIRV::OpExecutionModeId: {
    int64_t Exe = MI.getOperand(1).getImm();
    Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModeOperand,
                               Exe, ST);
    break;
  }
  case SPIRV::OpTypeMatrix:
    Reqs.addCapability(SPIRV::Capability::Matrix);
    break;
  case SPIRV::OpTypeInt: {
    unsigned BitWidth = MI.getOperand(1).getImm();
    if (BitWidth == 64)
      Reqs.addCapability(SPIRV::Capability::Int64);
    else if (BitWidth == 16)
      Reqs.addCapability(SPIRV::Capability::Int16);
    else if (BitWidth == 8)
      Reqs.addCapability(SPIRV::Capability::Int8);
    break;
  }
  case SPIRV::OpTypeFloat: {
    unsigned BitWidth = MI.getOperand(1).getImm();
    if (BitWidth == 64)
      Reqs.addCapability(SPIRV::Capability::Float64);
    else if (BitWidth == 16)
      Reqs.addCapability(SPIRV::Capability::Float16);
    break;
  }
  case SPIRV::OpTypeVector: {
    unsigned NumComponents = MI.getOperand(2).getImm();
    if (NumComponents == 8 || NumComponents == 16)
      Reqs.addCapability(SPIRV::Capability::Vector16);
    break;
  }
  case SPIRV::OpTypePointer: {
    auto SC = MI.getOperand(1).getImm();
    Reqs.getAndAddRequirements(SPIRV::OperandCategory::StorageClassOperand, SC,
                               ST);
    // If it's a type of pointer to float16 targeting OpenCL, add Float16Buffer
    // capability.
    if (!ST.isOpenCLEnv())
      break;
    assert(MI.getOperand(2).isReg());
    const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
    SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(2).getReg());
    if (TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
        TypeDef->getOperand(1).getImm() == 16)
      Reqs.addCapability(SPIRV::Capability::Float16Buffer);
    break;
  }
  case SPIRV::OpExtInst: {
    if (MI.getOperand(2).getImm() ==
        static_cast<int64_t>(
            SPIRV::InstructionSet::NonSemantic_Shader_DebugInfo_100)) {
      Reqs.addExtension(SPIRV::Extension::SPV_KHR_non_semantic_info);
    }
    break;
  }
  case SPIRV::OpBitReverse:
  case SPIRV::OpBitFieldInsert:
  case SPIRV::OpBitFieldSExtract:
  case SPIRV::OpBitFieldUExtract:
    if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions)) {
      Reqs.addCapability(SPIRV::Capability::Shader);
      break;
    }
    Reqs.addExtension(SPIRV::Extension::SPV_KHR_bit_instructions);
    Reqs.addCapability(SPIRV::Capability::BitInstructions);
    break;
  case SPIRV::OpTypeRuntimeArray:
    Reqs.addCapability(SPIRV::Capability::Shader);
    break;
  case SPIRV::OpTypeOpaque:
  case SPIRV::OpTypeEvent:
    Reqs.addCapability(SPIRV::Capability::Kernel);
    break;
  case SPIRV::OpTypePipe:
  case SPIRV::OpTypeReserveId:
    Reqs.addCapability(SPIRV::Capability::Pipes);
    break;
  case SPIRV::OpTypeDeviceEvent:
  case SPIRV::OpTypeQueue:
  case SPIRV::OpBuildNDRange:
    Reqs.addCapability(SPIRV::Capability::DeviceEnqueue);
    break;
  case SPIRV::OpDecorate:
  case SPIRV::OpDecorateId:
  case SPIRV::OpDecorateString:
    addOpDecorateReqs(MI, 1, Reqs, ST);
    break;
  case SPIRV::OpMemberDecorate:
  case SPIRV::OpMemberDecorateString:
    addOpDecorateReqs(MI, 2, Reqs, ST);
    break;
  case SPIRV::OpInBoundsPtrAccessChain:
    Reqs.addCapability(SPIRV::Capability::Addresses);
    break;
  case SPIRV::OpConstantSampler:
    Reqs.addCapability(SPIRV::Capability::LiteralSampler);
    break;
  case SPIRV::OpInBoundsAccessChain:
  case SPIRV::OpAccessChain:
    addOpAccessChainReqs(MI, Reqs, ST);
    break;
  case SPIRV::OpTypeImage:
    addOpTypeImageReqs(MI, Reqs, ST);
    break;
  case SPIRV::OpTypeSampler:
    if (!ST.isVulkanEnv()) {
      Reqs.addCapability(SPIRV::Capability::ImageBasic);
    }
    break;
  case SPIRV::OpTypeForwardPointer:
    // TODO: check if it's OpenCL's kernel.
    Reqs.addCapability(SPIRV::Capability::Addresses);
    break;
  case SPIRV::OpAtomicFlagTestAndSet:
  case SPIRV::OpAtomicLoad:
  case SPIRV::OpAtomicStore:
  case SPIRV::OpAtomicExchange:
  case SPIRV::OpAtomicCompareExchange:
  case SPIRV::OpAtomicIIncrement:
  case SPIRV::OpAtomicIDecrement:
  case SPIRV::OpAtomicIAdd:
  case SPIRV::OpAtomicISub:
  case SPIRV::OpAtomicUMin:
  case SPIRV::OpAtomicUMax:
  case SPIRV::OpAtomicSMin:
  case SPIRV::OpAtomicSMax:
  case SPIRV::OpAtomicAnd:
  case SPIRV::OpAtomicOr:
  case SPIRV::OpAtomicXor: {
    const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
    const MachineInstr *InstrPtr = &MI;
    if (MI.getOpcode() == SPIRV::OpAtomicStore) {
      assert(MI.getOperand(3).isReg());
      InstrPtr = MRI.getVRegDef(MI.getOperand(3).getReg());
      assert(InstrPtr && "Unexpected type instruction for OpAtomicStore");
    }
    assert(InstrPtr->getOperand(1).isReg() && "Unexpected operand in atomic");
    Register TypeReg = InstrPtr->getOperand(1).getReg();
    SPIRVType *TypeDef = MRI.getVRegDef(TypeReg);
    if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
      unsigned BitWidth = TypeDef->getOperand(1).getImm();
      if (BitWidth == 64)
        Reqs.addCapability(SPIRV::Capability::Int64Atomics);
    }
    break;
  }
  case SPIRV::OpGroupNonUniformIAdd:
  case SPIRV::OpGroupNonUniformFAdd:
  case SPIRV::OpGroupNonUniformIMul:
  case SPIRV::OpGroupNonUniformFMul:
  case SPIRV::OpGroupNonUniformSMin:
  case SPIRV::OpGroupNonUniformUMin:
  case SPIRV::OpGroupNonUniformFMin:
  case SPIRV::OpGroupNonUniformSMax:
  case SPIRV::OpGroupNonUniformUMax:
  case SPIRV::OpGroupNonUniformFMax:
  case SPIRV::OpGroupNonUniformBitwiseAnd:
  case SPIRV::OpGroupNonUniformBitwiseOr:
  case SPIRV::OpGroupNonUniformBitwiseXor:
  case SPIRV::OpGroupNonUniformLogicalAnd:
  case SPIRV::OpGroupNonUniformLogicalOr:
  case SPIRV::OpGroupNonUniformLogicalXor: {
    assert(MI.getOperand(3).isImm());
    int64_t GroupOp = MI.getOperand(3).getImm();
    switch (GroupOp) {
    case SPIRV::GroupOperation::Reduce:
    case SPIRV::GroupOperation::InclusiveScan:
    case SPIRV::GroupOperation::ExclusiveScan:
      Reqs.addCapability(SPIRV::Capability::Kernel);
      Reqs.addCapability(SPIRV::Capability::GroupNonUniformArithmetic);
      Reqs.addCapability(SPIRV::Capability::GroupNonUniformBallot);
      break;
    case SPIRV::GroupOperation::ClusteredReduce:
      Reqs.addCapability(SPIRV::Capability::GroupNonUniformClustered);
      break;
    case SPIRV::GroupOperation::PartitionedReduceNV:
    case SPIRV::GroupOperation::PartitionedInclusiveScanNV:
    case SPIRV::GroupOperation::PartitionedExclusiveScanNV:
      Reqs.addCapability(SPIRV::Capability::GroupNonUniformPartitionedNV);
      break;
    }
    break;
  }
  case SPIRV::OpGroupNonUniformShuffle:
  case SPIRV::OpGroupNonUniformShuffleXor:
    Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffle);
    break;
  case SPIRV::OpGroupNonUniformShuffleUp:
  case SPIRV::OpGroupNonUniformShuffleDown:
    Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffleRelative);
    break;
  case SPIRV::OpGroupAll:
  case SPIRV::OpGroupAny:
  case SPIRV::OpGroupBroadcast:
  case SPIRV::OpGroupIAdd:
  case SPIRV::OpGroupFAdd:
  case SPIRV::OpGroupFMin:
  case SPIRV::OpGroupUMin:
  case SPIRV::OpGroupSMin:
  case SPIRV::OpGroupFMax:
  case SPIRV::OpGroupUMax:
  case SPIRV::OpGroupSMax:
    Reqs.addCapability(SPIRV::Capability::Groups);
    break;
  case SPIRV::OpGroupNonUniformElect:
    Reqs.addCapability(SPIRV::Capability::GroupNonUniform);
    break;
  case SPIRV::OpGroupNonUniformAll:
  case SPIRV::OpGroupNonUniformAny:
  case SPIRV::OpGroupNonUniformAllEqual:
    Reqs.addCapability(SPIRV::Capability::GroupNonUniformVote);
    break;
  case SPIRV::OpGroupNonUniformBroadcast:
  case SPIRV::OpGroupNonUniformBroadcastFirst:
  case SPIRV::OpGroupNonUniformBallot:
  case SPIRV::OpGroupNonUniformInverseBallot:
  case SPIRV::OpGroupNonUniformBallotBitExtract:
  case SPIRV::OpGroupNonUniformBallotBitCount:
  case SPIRV::OpGroupNonUniformBallotFindLSB:
  case SPIRV::OpGroupNonUniformBallotFindMSB:
    Reqs.addCapability(SPIRV::Capability::GroupNonUniformBallot);
    break;
  case SPIRV::OpSubgroupShuffleINTEL:
  case SPIRV::OpSubgroupShuffleDownINTEL:
  case SPIRV::OpSubgroupShuffleUpINTEL:
  case SPIRV::OpSubgroupShuffleXorINTEL:
    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
      Reqs.addCapability(SPIRV::Capability::SubgroupShuffleINTEL);
    }
    break;
  case SPIRV::OpSubgroupBlockReadINTEL:
  case SPIRV::OpSubgroupBlockWriteINTEL:
    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
      Reqs.addCapability(SPIRV::Capability::SubgroupBufferBlockIOINTEL);
    }
    break;
  case SPIRV::OpSubgroupImageBlockReadINTEL:
  case SPIRV::OpSubgroupImageBlockWriteINTEL:
    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
      Reqs.addCapability(SPIRV::Capability::SubgroupImageBlockIOINTEL);
    }
    break;
  case SPIRV::OpSubgroupImageMediaBlockReadINTEL:
  case SPIRV::OpSubgroupImageMediaBlockWriteINTEL:
    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_media_block_io)) {
      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_media_block_io);
      Reqs.addCapability(SPIRV::Capability::SubgroupImageMediaBlockIOINTEL);
    }
    break;
  case SPIRV::OpAssumeTrueKHR:
  case SPIRV::OpExpectKHR:
    if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_expect_assume)) {
      Reqs.addExtension(SPIRV::Extension::SPV_KHR_expect_assume);
      Reqs.addCapability(SPIRV::Capability::ExpectAssumeKHR);
    }
    break;
  case SPIRV::OpPtrCastToCrossWorkgroupINTEL:
  case SPIRV::OpCrossWorkgroupCastToPtrINTEL:
    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)) {
      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes);
      Reqs.addCapability(SPIRV::Capability::USMStorageClassesINTEL);
    }
    break;
  case SPIRV::OpConstantFunctionPointerINTEL:
    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) {
      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
      Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL);
    }
    break;
  case SPIRV::OpGroupNonUniformRotateKHR:
    if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_subgroup_rotate))
      report_fatal_error("OpGroupNonUniformRotateKHR instruction requires the "
                         "following SPIR-V extension: SPV_KHR_subgroup_rotate",
                         false);
    Reqs.addExtension(SPIRV::Extension::SPV_KHR_subgroup_rotate);
    Reqs.addCapability(SPIRV::Capability::GroupNonUniformRotateKHR);
    Reqs.addCapability(SPIRV::Capability::GroupNonUniform);
    break;
  case SPIRV::OpGroupIMulKHR:
  case SPIRV::OpGroupFMulKHR:
  case SPIRV::OpGroupBitwiseAndKHR:
  case SPIRV::OpGroupBitwiseOrKHR:
  case SPIRV::OpGroupBitwiseXorKHR:
  case SPIRV::OpGroupLogicalAndKHR:
  case SPIRV::OpGroupLogicalOrKHR:
  case SPIRV::OpGroupLogicalXorKHR:
    if (ST.canUseExtension(
            SPIRV::Extension::SPV_KHR_uniform_group_instructions)) {
      Reqs.addExtension(SPIRV::Extension::SPV_KHR_uniform_group_instructions);
      Reqs.addCapability(SPIRV::Capability::GroupUniformArithmeticKHR);
    }
    break;
  case SPIRV::OpReadClockKHR:
    if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_shader_clock))
      report_fatal_error("OpReadClockKHR instruction requires the "
                         "following SPIR-V extension: SPV_KHR_shader_clock",
                         false);
    Reqs.addExtension(SPIRV::Extension::SPV_KHR_shader_clock);
    Reqs.addCapability(SPIRV::Capability::ShaderClockKHR);
    break;
  case SPIRV::OpFunctionPointerCallINTEL:
    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) {
      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
      Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL);
    }
    break;
  case SPIRV::OpAtomicFAddEXT:
  case SPIRV::OpAtomicFMinEXT:
  case SPIRV::OpAtomicFMaxEXT:
    AddAtomicFloatRequirements(MI, Reqs, ST);
    break;
  case SPIRV::OpConvertBF16ToFINTEL:
  case SPIRV::OpConvertFToBF16INTEL:
    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bfloat16_conversion)) {
      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bfloat16_conversion);
      Reqs.addCapability(SPIRV::Capability::BFloat16ConversionINTEL);
    }
    break;
  case SPIRV::OpVariableLengthArrayINTEL:
  case SPIRV::OpSaveMemoryINTEL:
  case SPIRV::OpRestoreMemoryINTEL:
    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_variable_length_array)) {
      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_variable_length_array);
      Reqs.addCapability(SPIRV::Capability::VariableLengthArrayINTEL);
    }
    break;
  case SPIRV::OpAsmTargetINTEL:
  case SPIRV::OpAsmINTEL:
  case SPIRV::OpAsmCallINTEL:
    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_inline_assembly)) {
      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_inline_assembly);
      Reqs.addCapability(SPIRV::Capability::AsmINTEL);
    }
    break;
  case SPIRV::OpTypeCooperativeMatrixKHR:
    if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
      report_fatal_error(
          "OpTypeCooperativeMatrixKHR type requires the "
          "following SPIR-V extension: SPV_KHR_cooperative_matrix",
          false);
    Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
    Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
    break;
  case SPIRV::OpArithmeticFenceEXT:
    if (!ST.canUseExtension(SPIRV::Extension::SPV_EXT_arithmetic_fence))
      report_fatal_error("OpArithmeticFenceEXT requires the "
                         "following SPIR-V extension: SPV_EXT_arithmetic_fence",
                         false);
    Reqs.addExtension(SPIRV::Extension::SPV_EXT_arithmetic_fence);
    Reqs.addCapability(SPIRV::Capability::ArithmeticFenceEXT);
    break;
  case SPIRV::OpControlBarrierArriveINTEL:
  case SPIRV::OpControlBarrierWaitINTEL:
    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_split_barrier)) {
      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_split_barrier);
      Reqs.addCapability(SPIRV::Capability::SplitBarrierINTEL);
    }
    break;
  case SPIRV::OpKill: {
    Reqs.addCapability(SPIRV::Capability::Shader);
  } break;
  case SPIRV::OpDemoteToHelperInvocation:
    Reqs.addCapability(SPIRV::Capability::DemoteToHelperInvocation);

    if (ST.canUseExtension(
            SPIRV::Extension::SPV_EXT_demote_to_helper_invocation)) {
      // The capability is core from SPIR-V 1.6; only older versions need the
      // extension.
      if (!ST.isAtLeastSPIRVVer(llvm::VersionTuple(1, 6)))
        Reqs.addExtension(
            SPIRV::Extension::SPV_EXT_demote_to_helper_invocation);
    }
    break;
  case SPIRV::OpSDot:
  case SPIRV::OpUDot:
    AddDotProductRequirements(MI, Reqs, ST);
    break;
  case SPIRV::OpImageRead: {
    Register ImageReg = MI.getOperand(2).getReg();
    SPIRVType *TypeDef = ST.getSPIRVGlobalRegistry()->getResultType(ImageReg);
    if (isImageTypeWithUnknownFormat(TypeDef))
      Reqs.addCapability(SPIRV::Capability::StorageImageReadWithoutFormat);
    break;
  }
  case SPIRV::OpImageWrite: {
    Register ImageReg = MI.getOperand(0).getReg();
    SPIRVType *TypeDef = ST.getSPIRVGlobalRegistry()->getResultType(ImageReg);
    if (isImageTypeWithUnknownFormat(TypeDef))
      Reqs.addCapability(SPIRV::Capability::StorageImageWriteWithoutFormat);
    break;
  }

  default:
    break;
  }

  // If we require capability Shader, then we can remove the requirement for
  // the BitInstructions capability, since Shader is a superset capability
  // of BitInstructions.
  Reqs.removeCapabilityIf(SPIRV::Capability::BitInstructions,
                          SPIRV::Capability::Shader);
}

static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI,
                        MachineModuleInfo *MMI, const SPIRVSubtarget &ST) {
  // Collect requirements for existing instructions.
  for (auto F = M.begin(), E = M.end(); F != E; ++F) {
    MachineFunction *MF = MMI->getMachineFunction(*F);
    if (!MF)
      continue;
    for (const MachineBasicBlock &MBB : *MF)
      for (const MachineInstr &MI : MBB)
        addInstrRequirements(MI, MAI.Reqs, ST);
  }
  // Collect requirements for OpExecutionMode instructions.
  auto Node = M.getNamedMetadata("spirv.ExecutionMode");
  if (Node) {
    // Below SPIR-V 1.4 the float-controls execution modes are not core and
    // require the SPV_KHR_float_controls extension.
    bool RequireFloatControls = false,
         VerLower14 = !ST.isAtLeastSPIRVVer(VersionTuple(1, 4));
    for (unsigned i = 0; i < Node->getNumOperands(); i++) {
      MDNode *MDN = cast<MDNode>(Node->getOperand(i));
      const MDOperand &MDOp = MDN->getOperand(1);
      if (auto *CMeta = dyn_cast<ConstantAsMetadata>(MDOp)) {
        Constant *C = CMeta->getValue();
        if (ConstantInt *Const = dyn_cast<ConstantInt>(C)) {
          auto EM = Const->getZExtValue();
          MAI.Reqs.getAndAddRequirements(
              SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
          // Add SPV_KHR_float_controls if the SPIR-V version is too low.
          switch (EM) {
          case SPIRV::ExecutionMode::DenormPreserve:
          case SPIRV::ExecutionMode::DenormFlushToZero:
          case SPIRV::ExecutionMode::SignedZeroInfNanPreserve:
          case SPIRV::ExecutionMode::RoundingModeRTE:
          case SPIRV::ExecutionMode::RoundingModeRTZ:
            RequireFloatControls = VerLower14;
            break;
          }
        }
      }
    }
    if (RequireFloatControls &&
        ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls))
      MAI.Reqs.addExtension(SPIRV::Extension::SPV_KHR_float_controls);
  }
  for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) {
    const Function &F = *FI;
    if (F.isDeclaration())
      continue;
    if (F.getMetadata("reqd_work_group_size"))
      MAI.Reqs.getAndAddRequirements(
          SPIRV::OperandCategory::ExecutionModeOperand,
          SPIRV::ExecutionMode::LocalSize, ST);
    if (F.getFnAttribute("hlsl.numthreads").isValid()) {
      MAI.Reqs.getAndAddRequirements(
          SPIRV::OperandCategory::ExecutionModeOperand,
          SPIRV::ExecutionMode::LocalSize, ST);
    }
    if (F.getMetadata("work_group_size_hint"))
      MAI.Reqs.getAndAddRequirements(
          SPIRV::OperandCategory::ExecutionModeOperand,
          SPIRV::ExecutionMode::LocalSizeHint, ST);
    if (F.getMetadata("intel_reqd_sub_group_size"))
      MAI.Reqs.getAndAddRequirements(
          SPIRV::OperandCategory::ExecutionModeOperand,
          SPIRV::ExecutionMode::SubgroupSize, ST);
    if (F.getMetadata("vec_type_hint"))
      MAI.Reqs.getAndAddRequirements(
          SPIRV::OperandCategory::ExecutionModeOperand,
          SPIRV::ExecutionMode::VecTypeHint, ST);

    if (F.hasOptNone() &&
        ST.canUseExtension(SPIRV::Extension::SPV_INTEL_optnone)) {
      // Output OpCapability OptNoneINTEL.
      MAI.Reqs.addExtension(SPIRV::Extension::SPV_INTEL_optnone);
      MAI.Reqs.addCapability(SPIRV::Capability::OptNoneINTEL);
    }
  }
}

// Translate the instruction's fast-math MI flags into the corresponding
// SPIR-V FPFastMathMode bitmask.
static unsigned getFastMathFlags(const MachineInstr &I) {
  unsigned Flags = SPIRV::FPFastMathMode::None;
  if (I.getFlag(MachineInstr::MIFlag::FmNoNans))
    Flags |= SPIRV::FPFastMathMode::NotNaN;
  if (I.getFlag(MachineInstr::MIFlag::FmNoInfs))
    Flags |= SPIRV::FPFastMathMode::NotInf;
  if (I.getFlag(MachineInstr::MIFlag::FmNsz))
    Flags |= SPIRV::FPFastMathMode::NSZ;
  if (I.getFlag(MachineInstr::MIFlag::FmArcp))
    Flags |= SPIRV::FPFastMathMode::AllowRecip;
  if (I.getFlag(MachineInstr::MIFlag::FmReassoc))
    Flags |= SPIRV::FPFastMathMode::Fast;
  return Flags;
}

// Add decorations (NoSignedWrap, NoUnsignedWrap, FPFastMathMode) derived from
// the instruction's MI flags, provided their requirements are satisfiable.
static void handleMIFlagDecoration(MachineInstr &I, const SPIRVSubtarget &ST,
                                   const SPIRVInstrInfo &TII,
                                   SPIRV::RequirementHandler &Reqs) {
  if (I.getFlag(MachineInstr::MIFlag::NoSWrap) && TII.canUseNSW(I) &&
      getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,
                                     SPIRV::Decoration::NoSignedWrap, ST, Reqs)
          .IsSatisfiable) {
    buildOpDecorate(I.getOperand(0).getReg(), I, TII,
                    SPIRV::Decoration::NoSignedWrap, {});
  }
  if (I.getFlag(MachineInstr::MIFlag::NoUWrap) && TII.canUseNUW(I) &&
      getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,
                                     SPIRV::Decoration::NoUnsignedWrap, ST,
                                     Reqs)
          .IsSatisfiable) {
    buildOpDecorate(I.getOperand(0).getReg(), I, TII,
                    SPIRV::Decoration::NoUnsignedWrap, {});
  }
  if (!TII.canUseFastMathFlags(I))
    return;
  unsigned FMFlags = getFastMathFlags(I);
  if (FMFlags == SPIRV::FPFastMathMode::None)
    return;
  Register DstReg = I.getOperand(0).getReg();
  buildOpDecorate(DstReg, I, TII, SPIRV::Decoration::FPFastMathMode, {FMFlags});
}

// Walk all functions and add decorations related to MI flags.
static void addDecorations(const Module &M, const SPIRVInstrInfo &TII,
                           MachineModuleInfo *MMI, const SPIRVSubtarget &ST,
                           SPIRV::ModuleAnalysisInfo &MAI) {
  for (auto F = M.begin(), E = M.end(); F != E; ++F) {
    MachineFunction *MF = MMI->getMachineFunction(*F);
    if (!MF)
      continue;
    for (auto &MBB : *MF)
      for (auto &MI : MBB)
        handleMIFlagDecoration(MI, ST, TII, MAI.Reqs);
  }
}

struct SPIRV::ModuleAnalysisInfo SPIRVModuleAnalysis::MAI;

void SPIRVModuleAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequired<TargetPassConfig>();
  AU.addRequired<MachineModuleInfoWrapperPass>();
}

bool SPIRVModuleAnalysis::runOnModule(Module &M) {
  SPIRVTargetMachine &TM =
      getAnalysis<TargetPassConfig>().getTM<SPIRVTargetMachine>();
  ST = TM.getSubtargetImpl();
  GR = ST->getSPIRVGlobalRegistry();
  TII = ST->getInstrInfo();

  MMI = &getAnalysis<MachineModuleInfoWrapperPass>().getMMI();

  setBaseInfo(M);

  addDecorations(M, *TII, MMI, *ST, MAI);

  collectReqs(M, MAI, MMI, *ST);

  // Process type/const/global var/func decl instructions, number their
  // destination registers from 0 to N, collect Extensions and Capabilities.
  processDefInstrs(M);

  // Number the rest of the registers from N+1 onwards.
  numberRegistersGlobally(M);

  // Update references to OpFunction instructions to use global registers.
  if (GR->hasConstFunPtr())
    collectFuncPtrs();

  // Collect OpName, OpEntryPoint, OpDecorate etc., and process other
  // instructions.
  processOtherInstrs(M);

  // If there are no entry points, we need the Linkage capability.
  if (MAI.MS[SPIRV::MB_EntryPoints].empty())
    MAI.Reqs.addCapability(SPIRV::Capability::Linkage);

  // Set maximum ID used.
  GR->setBound(MAI.MaxID);

  return false;
}