//===- SPIRVModuleAnalysis.cpp - analysis of global instrs & regs - C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// The analysis collects instructions that should be output at the module level
// and performs the global register numbering.
//
// The results of this analysis are used in AsmPrinter to rename registers
// globally and to output required instructions at the module level.
//
//===----------------------------------------------------------------------===//

#include "SPIRVModuleAnalysis.h"
#include "MCTargetDesc/SPIRVBaseInfo.h"
#include "MCTargetDesc/SPIRVMCTargetDesc.h"
#include "SPIRV.h"
#include "SPIRVSubtarget.h"
#include "SPIRVTargetMachine.h"
#include "SPIRVUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
#include "llvm/CodeGen/TargetPassConfig.h"

using namespace llvm;

#define DEBUG_TYPE "spirv-module-analysis"

// When set, the dependency graph built by processDefInstrs is dumped as MIR.
static cl::opt<bool>
    SPVDumpDeps("spv-dump-deps",
                cl::desc("Dump MIR with SPIR-V dependencies info"),
                cl::Optional, cl::init(false));

// Capabilities the user asked us to avoid when a feature has alternative
// enabling capabilities (see getSymbolicOperandRequirements below).
static cl::list<SPIRV::Capability::Capability>
    AvoidCapabilities("avoid-spirv-capabilities",
                      cl::desc("SPIR-V capabilities to avoid if there are "
                               "other options enabling a feature"),
                      cl::ZeroOrMore, cl::Hidden,
                      cl::values(clEnumValN(SPIRV::Capability::Shader, "Shader",
                                            "SPIR-V Shader capability")));

// Use sets instead of cl::list to check "if contains" condition
struct AvoidCapabilitiesSet {
  SmallSet<SPIRV::Capability::Capability, 4> S;
  AvoidCapabilitiesSet() {
    for (auto Cap : AvoidCapabilities)
      S.insert(Cap);
  }
};

char llvm::SPIRVModuleAnalysis::ID = 0;

namespace llvm {
void initializeSPIRVModuleAnalysisPass(PassRegistry &);
} // namespace llvm

INITIALIZE_PASS(SPIRVModuleAnalysis, DEBUG_TYPE, "SPIRV module analysis", true,
                true)

// Retrieve an unsigned from an MDNode with a list of them as operands.
// Returns DefaultVal when MdNode is null or OpIndex is out of range.
static unsigned getMetadataUInt(MDNode *MdNode, unsigned OpIndex,
                                unsigned DefaultVal = 0) {
  if (MdNode && OpIndex < MdNode->getNumOperands()) {
    const auto &Op = MdNode->getOperand(OpIndex);
    return mdconst::extract<ConstantInt>(Op)->getZExtValue();
  }
  return DefaultVal;
}

// Compute the version/capability/extension requirements for using the symbolic
// operand with value `i` from the given operand category on subtarget ST.
// `Reqs` is only consulted for capability availability; it is not modified.
static SPIRV::Requirements
getSymbolicOperandRequirements(SPIRV::OperandCategory::OperandCategory Category,
                               unsigned i, const SPIRVSubtarget &ST,
                               SPIRV::RequirementHandler &Reqs) {
  // Initialized once on first use; snapshots -avoid-spirv-capabilities into a
  // set for O(1) "contains" checks.
  static AvoidCapabilitiesSet
      AvoidCaps; // contains capabilities to avoid if there is another option

  VersionTuple ReqMinVer = getSymbolicOperandMinVersion(Category, i);
  VersionTuple ReqMaxVer = getSymbolicOperandMaxVersion(Category, i);
  VersionTuple SPIRVVersion = ST.getSPIRVVersion();
  // An empty subtarget version means "unversioned"; treat it as compatible.
  bool MinVerOK = SPIRVVersion.empty() || SPIRVVersion >= ReqMinVer;
  bool MaxVerOK =
      ReqMaxVer.empty() || SPIRVVersion.empty() || SPIRVVersion <= ReqMaxVer;
  CapabilityList ReqCaps = getSymbolicOperandCapabilities(Category, i);
  ExtensionList ReqExts = getSymbolicOperandExtensions(Category, i);
  if (ReqCaps.empty()) {
    if (ReqExts.empty()) {
      // Neither capabilities nor extensions involved: only versions gate use.
      if (MinVerOK && MaxVerOK)
        return {true, {}, {}, ReqMinVer, ReqMaxVer};
      return {false, {}, {}, VersionTuple(), VersionTuple()};
    }
  } else if (MinVerOK && MaxVerOK) {
    if (ReqCaps.size() == 1) {
      auto Cap = ReqCaps[0];
      if (Reqs.isCapabilityAvailable(Cap))
        return {true, {Cap}, ReqExts, ReqMinVer, ReqMaxVer};
    } else {
      // By SPIR-V specification: "If an instruction, enumerant, or other
      // feature specifies multiple enabling capabilities, only one such
      // capability needs to be declared to use the feature." However, one
      // capability may be preferred over another. We use command line
      // argument(s) and AvoidCapabilities to avoid selection of certain
      // capabilities if there are other options.
      CapabilityList UseCaps;
      for (auto Cap : ReqCaps)
        if (Reqs.isCapabilityAvailable(Cap))
          UseCaps.push_back(Cap);
      // Pick the first available capability not on the avoid list; the last
      // available one is taken unconditionally as a fallback.
      for (size_t i = 0, Sz = UseCaps.size(); i < Sz; ++i) {
        auto Cap = UseCaps[i];
        if (i == Sz - 1 || !AvoidCaps.S.contains(Cap))
          return {true, {Cap}, ReqExts, ReqMinVer, ReqMaxVer};
      }
    }
  }
  // If there are no capabilities, or we can't satisfy the version or
  // capability requirements, use the list of extensions (if the subtarget
  // can handle them all).
  if (llvm::all_of(ReqExts, [&ST](const SPIRV::Extension::Extension &Ext) {
        return ST.canUseExtension(Ext);
      })) {
    return {true,
            {},
            ReqExts,
            VersionTuple(),
            VersionTuple()}; // TODO: add versions to extensions.
  }
  return {false, {}, {}, VersionTuple(), VersionTuple()};
}

// Reset all module-analysis state and seed it from module-level metadata:
// memory/addressing model, source language and its version, source-level
// extensions, and the baseline requirements those imply.
void SPIRVModuleAnalysis::setBaseInfo(const Module &M) {
  MAI.MaxID = 0;
  for (int i = 0; i < SPIRV::NUM_MODULE_SECTIONS; i++)
    MAI.MS[i].clear();
  MAI.RegisterAliasTable.clear();
  MAI.InstrsToDelete.clear();
  MAI.FuncMap.clear();
  MAI.GlobalVarList.clear();
  MAI.ExtInstSetMap.clear();
  MAI.Reqs.clear();
  MAI.Reqs.initAvailableCapabilities(*ST);

  // TODO: determine memory model and source language from the configuration.
  if (auto MemModel = M.getNamedMetadata("spirv.MemoryModel")) {
    // Explicit module metadata wins: operand 0 is the addressing model,
    // operand 1 the memory model.
    auto MemMD = MemModel->getOperand(0);
    MAI.Addr = static_cast<SPIRV::AddressingModel::AddressingModel>(
        getMetadataUInt(MemMD, 0));
    MAI.Mem =
        static_cast<SPIRV::MemoryModel::MemoryModel>(getMetadataUInt(MemMD, 1));
  } else {
    // TODO: Add support for VulkanMemoryModel.
    MAI.Mem = ST->isOpenCLEnv() ? SPIRV::MemoryModel::OpenCL
                                : SPIRV::MemoryModel::GLSL450;
    if (MAI.Mem == SPIRV::MemoryModel::OpenCL) {
      // Physical addressing matching the target pointer width for OpenCL.
      unsigned PtrSize = ST->getPointerSize();
      MAI.Addr = PtrSize == 32   ? SPIRV::AddressingModel::Physical32
                 : PtrSize == 64 ? SPIRV::AddressingModel::Physical64
                                 : SPIRV::AddressingModel::Logical;
    } else {
      // TODO: Add support for PhysicalStorageBufferAddress.
      MAI.Addr = SPIRV::AddressingModel::Logical;
    }
  }
  // Get the OpenCL version number from metadata.
  // TODO: support other source languages.
  if (auto VerNode = M.getNamedMetadata("opencl.ocl.version")) {
    MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_C;
    // Construct version literal in accordance with SPIRV-LLVM-Translator.
    // TODO: support multiple OCL version metadata.
    assert(VerNode->getNumOperands() > 0 && "Invalid SPIR");
    auto VersionMD = VerNode->getOperand(0);
    unsigned MajorNum = getMetadataUInt(VersionMD, 0, 2);
    unsigned MinorNum = getMetadataUInt(VersionMD, 1);
    unsigned RevNum = getMetadataUInt(VersionMD, 2);
    // Prevent Major part of OpenCL version to be 0
    MAI.SrcLangVersion =
        (std::max(1U, MajorNum) * 100 + MinorNum) * 1000 + RevNum;
  } else {
    // If there is no information about OpenCL version we are forced to generate
    // OpenCL 1.0 by default for the OpenCL environment to avoid puzzling
    // run-times with Unknown/0.0 version output. For a reference, LLVM-SPIRV
    // Translator avoids potential issues with run-times in a similar manner.
    // NOTE(review): the comment above says "OpenCL 1.0" while SrcLang is set
    // to OpenCL_CPP with version literal 100000 — confirm this is intentional.
    if (ST->isOpenCLEnv()) {
      MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_CPP;
      MAI.SrcLangVersion = 100000;
    } else {
      MAI.SrcLang = SPIRV::SourceLanguage::Unknown;
      MAI.SrcLangVersion = 0;
    }
  }

  // Record every string listed in !opencl.used.extensions so it can be
  // emitted via OpSourceExtension.
  if (auto ExtNode = M.getNamedMetadata("opencl.used.extensions")) {
    for (unsigned I = 0, E = ExtNode->getNumOperands(); I != E; ++I) {
      MDNode *MD = ExtNode->getOperand(I);
      if (!MD || MD->getNumOperands() == 0)
        continue;
      for (unsigned J = 0, N = MD->getNumOperands(); J != N; ++J)
        MAI.SrcExt.insert(cast<MDString>(MD->getOperand(J))->getString());
    }
  }

  // Update required capabilities for this memory model, addressing model and
  // source language.
  MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand,
                                 MAI.Mem, *ST);
  MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::SourceLanguageOperand,
                                 MAI.SrcLang, *ST);
  MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,
                                 MAI.Addr, *ST);

  if (ST->isOpenCLEnv()) {
    // TODO: check if it's required by default.
    MAI.ExtInstSetMap[static_cast<unsigned>(
        SPIRV::InstructionSet::OpenCL_std)] =
        Register::index2VirtReg(MAI.getNextID());
  }
}

// Collect MI which defines the register in the given machine function.
static void collectDefInstr(Register Reg, const MachineFunction *MF,
                            SPIRV::ModuleAnalysisInfo *MAI,
                            SPIRV::ModuleSectionType MSType,
                            bool DoInsert = true) {
  assert(MAI->hasRegisterAlias(MF, Reg) && "Cannot find register alias");
  MachineInstr *MI = MF->getRegInfo().getUniqueVRegDef(Reg);
  assert(MI && "There should be an instruction that defines the register");
  // The defining instruction will be emitted at module level, so suppress its
  // function-local emission. It is appended to the section only on request
  // (DoInsert=false is used for duplicates of an already-collected entity).
  MAI->setSkipEmission(MI);
  if (DoInsert)
    MAI->MS[MSType].push_back(MI);
}

// Walk the dependency graph and, for every entry accepted by Pred, assign one
// global register shared by all per-function occurrences, collect the defining
// instruction into section MSType, and record global variables.
// UsePreOrder selects parents-before-deps traversal (needed for functions).
void SPIRVModuleAnalysis::collectGlobalEntities(
    const std::vector<SPIRV::DTSortableEntry *> &DepsGraph,
    SPIRV::ModuleSectionType MSType,
    std::function<bool(const SPIRV::DTSortableEntry *)> Pred,
    bool UsePreOrder = false) {
  DenseSet<const SPIRV::DTSortableEntry *> Visited;
  for (const auto *E : DepsGraph) {
    std::function<void(const SPIRV::DTSortableEntry *)> RecHoistUtil;
    // NOTE: here we prefer recursive approach over iterative because
    // we don't expect depchains long enough to cause SO.
    RecHoistUtil = [MSType, UsePreOrder, &Visited, &Pred,
                    &RecHoistUtil](const SPIRV::DTSortableEntry *E) {
      if (Visited.count(E) || !Pred(E))
        return;
      Visited.insert(E);

      // Traversing deps graph in post-order allows us to get rid of
      // register aliases preprocessing.
      // But pre-order is required for correct processing of function
      // declaration and arguments processing.
      if (!UsePreOrder)
        for (auto *S : E->getDeps())
          RecHoistUtil(S);

      // One fresh global register represents this entity in every function.
      Register GlobalReg = Register::index2VirtReg(MAI.getNextID());
      bool IsFirst = true;
      for (auto &U : *E) {
        const MachineFunction *MF = U.first;
        Register Reg = U.second;
        MAI.setRegisterAlias(MF, Reg, GlobalReg);
        if (!MF->getRegInfo().getUniqueVRegDef(Reg))
          continue;
        // Only the first occurrence is actually inserted into the section;
        // the rest are just marked skip-emission via DoInsert=false.
        collectDefInstr(Reg, MF, &MAI, MSType, IsFirst);
        IsFirst = false;
        if (E->getIsGV())
          MAI.GlobalVarList.push_back(MF->getRegInfo().getUniqueVRegDef(Reg));
      }

      if (UsePreOrder)
        for (auto *S : E->getDeps())
          RecHoistUtil(S);
    };
    RecHoistUtil(E);
  }
}

// The function initializes global register alias table for types, consts,
// global vars and func decls and collects these instruction for output
// at module level. Also it collects explicit OpExtension/OpCapability
// instructions.
void SPIRVModuleAnalysis::processDefInstrs(const Module &M) {
  std::vector<SPIRV::DTSortableEntry *> DepsGraph;

  GR->buildDepsGraph(DepsGraph, TII, SPVDumpDeps ? MMI : nullptr);

  // First pass: everything that is not a function (types/consts/vars),
  // collected in post-order.
  collectGlobalEntities(
      DepsGraph, SPIRV::MB_TypeConstVars,
      [](const SPIRV::DTSortableEntry *E) { return !E->getIsFunc(); });

  for (auto F = M.begin(), E = M.end(); F != E; ++F) {
    MachineFunction *MF = MMI->getMachineFunction(*F);
    if (!MF)
      continue;
    // Iterate through and collect OpExtension/OpCapability instructions.
    for (MachineBasicBlock &MBB : *MF) {
      for (MachineInstr &MI : MBB) {
        if (MI.getOpcode() == SPIRV::OpExtension) {
          // Here, OpExtension just has a single enum operand, not a string.
          auto Ext = SPIRV::Extension::Extension(MI.getOperand(0).getImm());
          MAI.Reqs.addExtension(Ext);
          MAI.setSkipEmission(&MI);
        } else if (MI.getOpcode() == SPIRV::OpCapability) {
          auto Cap = SPIRV::Capability::Capability(MI.getOperand(0).getImm());
          MAI.Reqs.addCapability(Cap);
          MAI.setSkipEmission(&MI);
        }
      }
    }
  }

  // Second pass: external function declarations, in pre-order so a function
  // is processed before its arguments.
  collectGlobalEntities(
      DepsGraph, SPIRV::MB_ExtFuncDecls,
      [](const SPIRV::DTSortableEntry *E) { return E->getIsFunc(); }, true);
}

// Look for IDs declared with Import linkage, and map the corresponding function
// to the register defining that variable (which will usually be the result of
// an OpFunction). This lets us call externally imported functions using
// the correct ID registers.
void SPIRVModuleAnalysis::collectFuncNames(MachineInstr &MI,
                                           const Function *F) {
  if (MI.getOpcode() == SPIRV::OpDecorate) {
    // If it's got Import linkage.
    auto Dec = MI.getOperand(1).getImm();
    if (Dec == static_cast<unsigned>(SPIRV::Decoration::LinkageAttributes)) {
      // The linkage type is always the last operand of the decoration.
      auto Lnk = MI.getOperand(MI.getNumOperands() - 1).getImm();
      if (Lnk == static_cast<unsigned>(SPIRV::LinkageType::Import)) {
        // Map imported function name to function ID register.
        const Function *ImportedFunc =
            F->getParent()->getFunction(getStringImm(MI, 2));
        Register Target = MI.getOperand(0).getReg();
        MAI.FuncMap[ImportedFunc] = MAI.getRegisterAlias(MI.getMF(), Target);
      }
    }
  } else if (MI.getOpcode() == SPIRV::OpFunction) {
    // Record all internal OpFunction declarations.
    Register Reg = MI.defs().begin()->getReg();
    Register GlobalReg = MAI.getRegisterAlias(MI.getMF(), Reg);
    assert(GlobalReg.isValid());
    MAI.FuncMap[F] = GlobalReg;
  }
}

// References to a function via function pointers generate virtual
// registers without a definition.
We are able to resolve this 347 // reference using Globar Register info into an OpFunction instruction 348 // and replace dummy operands by the corresponding global register references. 349 void SPIRVModuleAnalysis::collectFuncPtrs() { 350 for (auto &MI : MAI.MS[SPIRV::MB_TypeConstVars]) 351 if (MI->getOpcode() == SPIRV::OpConstantFunctionPointerINTEL) 352 collectFuncPtrs(MI); 353 } 354 355 void SPIRVModuleAnalysis::collectFuncPtrs(MachineInstr *MI) { 356 const MachineOperand *FunUse = &MI->getOperand(2); 357 if (const MachineOperand *FunDef = GR->getFunctionDefinitionByUse(FunUse)) { 358 const MachineInstr *FunDefMI = FunDef->getParent(); 359 assert(FunDefMI->getOpcode() == SPIRV::OpFunction && 360 "Constant function pointer must refer to function definition"); 361 Register FunDefReg = FunDef->getReg(); 362 Register GlobalFunDefReg = 363 MAI.getRegisterAlias(FunDefMI->getMF(), FunDefReg); 364 assert(GlobalFunDefReg.isValid() && 365 "Function definition must refer to a global register"); 366 Register FunPtrReg = FunUse->getReg(); 367 MAI.setRegisterAlias(MI->getMF(), FunPtrReg, GlobalFunDefReg); 368 } 369 } 370 371 using InstrSignature = SmallVector<size_t>; 372 using InstrTraces = std::set<InstrSignature>; 373 374 // Returns a representation of an instruction as a vector of MachineOperand 375 // hash values, see llvm::hash_value(const MachineOperand &MO) for details. 376 // This creates a signature of the instruction with the same content 377 // that MachineOperand::isIdenticalTo uses for comparison. 
378 static InstrSignature instrToSignature(MachineInstr &MI, 379 SPIRV::ModuleAnalysisInfo &MAI) { 380 InstrSignature Signature; 381 for (unsigned i = 0; i < MI.getNumOperands(); ++i) { 382 const MachineOperand &MO = MI.getOperand(i); 383 size_t h; 384 if (MO.isReg()) { 385 Register RegAlias = MAI.getRegisterAlias(MI.getMF(), MO.getReg()); 386 // mimic llvm::hash_value(const MachineOperand &MO) 387 h = hash_combine(MO.getType(), (unsigned)RegAlias, MO.getSubReg(), 388 MO.isDef()); 389 } else { 390 h = hash_value(MO); 391 } 392 Signature.push_back(h); 393 } 394 return Signature; 395 } 396 397 // Collect the given instruction in the specified MS. We assume global register 398 // numbering has already occurred by this point. We can directly compare reg 399 // arguments when detecting duplicates. 400 static void collectOtherInstr(MachineInstr &MI, SPIRV::ModuleAnalysisInfo &MAI, 401 SPIRV::ModuleSectionType MSType, InstrTraces &IS, 402 bool Append = true) { 403 MAI.setSkipEmission(&MI); 404 InstrSignature MISign = instrToSignature(MI, MAI); 405 auto FoundMI = IS.insert(MISign); 406 if (!FoundMI.second) 407 return; // insert failed, so we found a duplicate; don't add it to MAI.MS 408 // No duplicates, so add it. 409 if (Append) 410 MAI.MS[MSType].push_back(&MI); 411 else 412 MAI.MS[MSType].insert(MAI.MS[MSType].begin(), &MI); 413 } 414 415 // Some global instructions make reference to function-local ID regs, so cannot 416 // be correctly collected until these registers are globally numbered. 
void SPIRVModuleAnalysis::processOtherInstrs(const Module &M) {
  InstrTraces IS;
  for (auto F = M.begin(), E = M.end(); F != E; ++F) {
    if ((*F).isDeclaration())
      continue;
    MachineFunction *MF = MMI->getMachineFunction(*F);
    assert(MF);

    // Dispatch each not-yet-collected instruction into the module section
    // matching its opcode class.
    for (MachineBasicBlock &MBB : *MF)
      for (MachineInstr &MI : MBB) {
        if (MAI.getSkipEmission(&MI))
          continue;
        const unsigned OpCode = MI.getOpcode();
        if (OpCode == SPIRV::OpString) {
          collectOtherInstr(MI, MAI, SPIRV::MB_DebugStrings, IS);
        } else if (OpCode == SPIRV::OpExtInst && MI.getOperand(2).isImm() &&
                   MI.getOperand(2).getImm() ==
                       SPIRV::InstructionSet::
                           NonSemantic_Shader_DebugInfo_100) {
          // Only the global (module-scope) NonSemantic DebugInfo instructions
          // listed below are hoisted; others stay function-local.
          MachineOperand Ins = MI.getOperand(3);
          namespace NS = SPIRV::NonSemanticExtInst;
          static constexpr int64_t GlobalNonSemanticDITy[] = {
              NS::DebugSource, NS::DebugCompilationUnit, NS::DebugInfoNone,
              NS::DebugTypeBasic, NS::DebugTypePointer};
          bool IsGlobalDI = false;
          for (unsigned Idx = 0; Idx < std::size(GlobalNonSemanticDITy); ++Idx)
            IsGlobalDI |= Ins.getImm() == GlobalNonSemanticDITy[Idx];
          if (IsGlobalDI)
            collectOtherInstr(MI, MAI, SPIRV::MB_NonSemanticGlobalDI, IS);
        } else if (OpCode == SPIRV::OpName || OpCode == SPIRV::OpMemberName) {
          collectOtherInstr(MI, MAI, SPIRV::MB_DebugNames, IS);
        } else if (OpCode == SPIRV::OpEntryPoint) {
          collectOtherInstr(MI, MAI, SPIRV::MB_EntryPoints, IS);
        } else if (TII->isDecorationInstr(MI)) {
          collectOtherInstr(MI, MAI, SPIRV::MB_Annotations, IS);
          // Decorations may also declare imported-function linkage.
          collectFuncNames(MI, &*F);
        } else if (TII->isConstantInstr(MI)) {
          // Now OpSpecConstant*s are not in DT,
          // but they need to be collected anyway.
          collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, IS);
        } else if (OpCode == SPIRV::OpFunction) {
          collectFuncNames(MI, &*F);
        } else if (OpCode == SPIRV::OpTypeForwardPointer) {
          // Forward pointers must precede the section, hence Append=false.
          collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, IS, false);
        }
      }
  }
}

// Number registers in all functions globally from 0 onwards and store
// the result in global register alias table. Some registers are already
// numbered in collectGlobalEntities.
void SPIRVModuleAnalysis::numberRegistersGlobally(const Module &M) {
  for (auto F = M.begin(), E = M.end(); F != E; ++F) {
    if ((*F).isDeclaration())
      continue;
    MachineFunction *MF = MMI->getMachineFunction(*F);
    assert(MF);
    for (MachineBasicBlock &MBB : *MF) {
      for (MachineInstr &MI : MBB) {
        for (MachineOperand &Op : MI.operands()) {
          if (!Op.isReg())
            continue;
          Register Reg = Op.getReg();
          // Skip registers already numbered by collectGlobalEntities.
          if (MAI.hasRegisterAlias(MF, Reg))
            continue;
          Register NewReg = Register::index2VirtReg(MAI.getNextID());
          MAI.setRegisterAlias(MF, Reg, NewReg);
        }
        if (MI.getOpcode() != SPIRV::OpExtInst)
          continue;
        // Lazily assign an id register to each extended-instruction set used.
        auto Set = MI.getOperand(2).getImm();
        if (!MAI.ExtInstSetMap.contains(Set))
          MAI.ExtInstSetMap[Set] = Register::index2VirtReg(MAI.getNextID());
      }
    }
  }
}

// RequirementHandler implementations.
// Resolve the requirements of symbolic operand `i` in `Category` and fold
// them into this handler's accumulated state.
void SPIRV::RequirementHandler::getAndAddRequirements(
    SPIRV::OperandCategory::OperandCategory Category, uint32_t i,
    const SPIRVSubtarget &ST) {
  addRequirements(getSymbolicOperandRequirements(Category, i, ST, *this));
}

// Add capabilities (and transitively everything they imply) to AllCaps only;
// unlike addCapabilities, nothing is appended to MinimalCaps.
void SPIRV::RequirementHandler::recursiveAddCapabilities(
    const CapabilityList &ToPrune) {
  for (const auto &Cap : ToPrune) {
    AllCaps.insert(Cap);
    CapabilityList ImplicitDecls =
        getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);
    recursiveAddCapabilities(ImplicitDecls);
  }
}

// Add capabilities to the set that must be explicitly declared (MinimalCaps),
// while implied capabilities go into AllCaps only.
void SPIRV::RequirementHandler::addCapabilities(const CapabilityList &ToAdd) {
  for (const auto &Cap : ToAdd) {
    bool IsNewlyInserted = AllCaps.insert(Cap).second;
    if (!IsNewlyInserted) // Don't re-add if it's already been declared.
      continue;
    CapabilityList ImplicitDecls =
        getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);
    recursiveAddCapabilities(ImplicitDecls);
    MinimalCaps.push_back(Cap);
  }
}

// Merge a Requirements record into the handler, tightening the module's
// [MinVersion, MaxVersion] window; aborts on unsatisfiable/conflicting input.
void SPIRV::RequirementHandler::addRequirements(
    const SPIRV::Requirements &Req) {
  if (!Req.IsSatisfiable)
    report_fatal_error("Adding SPIR-V requirements this target can't satisfy.");

  if (Req.Cap.has_value())
    addCapabilities({Req.Cap.value()});

  addExtensions(Req.Exts);

  if (!Req.MinVer.empty()) {
    if (!MaxVersion.empty() && Req.MinVer > MaxVersion) {
      LLVM_DEBUG(dbgs() << "Conflicting version requirements: >= " << Req.MinVer
                        << " and <= " << MaxVersion << "\n");
      report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");
    }

    if (MinVersion.empty() || Req.MinVer > MinVersion)
      MinVersion = Req.MinVer;
  }

  if (!Req.MaxVer.empty()) {
    if (!MinVersion.empty() && Req.MaxVer < MinVersion) {
      LLVM_DEBUG(dbgs() << "Conflicting version requirements: <= " << Req.MaxVer
                        << " and >= " << MinVersion << "\n");
      report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");
    }

    if (MaxVersion.empty() || Req.MaxVer < MaxVersion)
      MaxVersion = Req.MaxVer;
  }
}

// Verify that the accumulated versions/capabilities/extensions can all be met
// by subtarget ST; aborts after reporting every violation found.
void SPIRV::RequirementHandler::checkSatisfiable(
    const SPIRVSubtarget &ST) const {
  // Report as many errors as possible before aborting the compilation.
  bool IsSatisfiable = true;
  auto TargetVer = ST.getSPIRVVersion();

  if (!MaxVersion.empty() && !TargetVer.empty() && MaxVersion < TargetVer) {
    LLVM_DEBUG(
        dbgs() << "Target SPIR-V version too high for required features\n"
               << "Required max version: " << MaxVersion << " target version "
               << TargetVer << "\n");
    IsSatisfiable = false;
  }

  if (!MinVersion.empty() && !TargetVer.empty() && MinVersion > TargetVer) {
    LLVM_DEBUG(dbgs() << "Target SPIR-V version too low for required features\n"
                      << "Required min version: " << MinVersion
                      << " target version " << TargetVer << "\n");
    IsSatisfiable = false;
  }

  if (!MinVersion.empty() && !MaxVersion.empty() && MinVersion > MaxVersion) {
    LLVM_DEBUG(
        dbgs()
        << "Version is too low for some features and too high for others.\n"
        << "Required SPIR-V min version: " << MinVersion
        << " required SPIR-V max version " << MaxVersion << "\n");
    IsSatisfiable = false;
  }

  for (auto Cap : MinimalCaps) {
    if (AvailableCaps.contains(Cap))
      continue;
    LLVM_DEBUG(dbgs() << "Capability not supported: "
                      << getSymbolicOperandMnemonic(
                             OperandCategory::CapabilityOperand, Cap)
                      << "\n");
    IsSatisfiable = false;
  }

  for (auto Ext : AllExtensions) {
    if (ST.canUseExtension(Ext))
      continue;
    LLVM_DEBUG(dbgs() << "Extension not supported: "
                      << getSymbolicOperandMnemonic(
                             OperandCategory::ExtensionOperand, Ext)
                      << "\n");
    IsSatisfiable = false;
  }

  if (!IsSatisfiable)
    report_fatal_error("Unable to meet SPIR-V requirements for this target.");
}
// Add the given capabilities and all their implicitly defined capabilities
// too.
void SPIRV::RequirementHandler::addAvailableCaps(const CapabilityList &ToAdd) {
  for (const auto Cap : ToAdd)
    if (AvailableCaps.insert(Cap).second)
      addAvailableCaps(getSymbolicOperandCapabilities(
          SPIRV::OperandCategory::CapabilityOperand, Cap));
}

// Drop ToRemove from the declared capability set when IfPresent already covers
// it (used to prune redundant declarations).
void SPIRV::RequirementHandler::removeCapabilityIf(
    const Capability::Capability ToRemove,
    const Capability::Capability IfPresent) {
  if (AllCaps.contains(IfPresent))
    AllCaps.erase(ToRemove);
}

namespace llvm {
namespace SPIRV {
// Seed AvailableCaps for the subtarget: common caps first, then version- and
// extension-gated ones, then the environment-specific (OpenCL/Vulkan) sets.
void RequirementHandler::initAvailableCapabilities(const SPIRVSubtarget &ST) {
  // Provided by both all supported Vulkan versions and OpenCl.
  addAvailableCaps({Capability::Shader, Capability::Linkage, Capability::Int8,
                    Capability::Int16});

  if (ST.isAtLeastSPIRVVer(VersionTuple(1, 3)))
    addAvailableCaps({Capability::GroupNonUniform,
                      Capability::GroupNonUniformVote,
                      Capability::GroupNonUniformArithmetic,
                      Capability::GroupNonUniformBallot,
                      Capability::GroupNonUniformClustered,
                      Capability::GroupNonUniformShuffle,
                      Capability::GroupNonUniformShuffleRelative});

  if (ST.isAtLeastSPIRVVer(VersionTuple(1, 6)))
    addAvailableCaps({Capability::DotProduct, Capability::DotProductInputAll,
                      Capability::DotProductInput4x8Bit,
                      Capability::DotProductInput4x8BitPacked,
                      Capability::DemoteToHelperInvocation});

  // Add capabilities enabled by extensions.
  for (auto Extension : ST.getAllAvailableExtensions()) {
    CapabilityList EnabledCapabilities =
        getCapabilitiesEnabledByExtension(Extension);
    addAvailableCaps(EnabledCapabilities);
  }

  if (ST.isOpenCLEnv()) {
    initAvailableCapabilitiesForOpenCL(ST);
    return;
  }

  if (ST.isVulkanEnv()) {
    initAvailableCapabilitiesForVulkan(ST);
    return;
  }

  report_fatal_error("Unimplemented environment for SPIR-V generation.");
}

// OpenCL-specific available capabilities, gated on profile, image support and
// the OpenCL/SPIR-V version pair.
void RequirementHandler::initAvailableCapabilitiesForOpenCL(
    const SPIRVSubtarget &ST) {
  // Add the min requirements for different OpenCL and SPIR-V versions.
  addAvailableCaps({Capability::Addresses, Capability::Float16Buffer,
                    Capability::Kernel, Capability::Vector16,
                    Capability::Groups, Capability::GenericPointer,
                    Capability::StorageImageWriteWithoutFormat,
                    Capability::StorageImageReadWithoutFormat});
  if (ST.hasOpenCLFullProfile())
    addAvailableCaps({Capability::Int64, Capability::Int64Atomics});
  if (ST.hasOpenCLImageSupport()) {
    addAvailableCaps({Capability::ImageBasic, Capability::LiteralSampler,
                      Capability::Image1D, Capability::SampledBuffer,
                      Capability::ImageBuffer});
    if (ST.isAtLeastOpenCLVer(VersionTuple(2, 0)))
      addAvailableCaps({Capability::ImageReadWrite});
  }
  if (ST.isAtLeastSPIRVVer(VersionTuple(1, 1)) &&
      ST.isAtLeastOpenCLVer(VersionTuple(2, 2)))
    addAvailableCaps({Capability::SubgroupDispatch, Capability::PipeStorage});
  if (ST.isAtLeastSPIRVVer(VersionTuple(1, 4)))
    addAvailableCaps({Capability::DenormPreserve, Capability::DenormFlushToZero,
                      Capability::SignedZeroInfNanPreserve,
                      Capability::RoundingModeRTE,
                      Capability::RoundingModeRTZ});
  // TODO: verify if this needs some checks.
  addAvailableCaps({Capability::Float16, Capability::Float64});

  // TODO: add OpenCL extensions.
}

// Vulkan-specific available capabilities, grouped by the Vulkan release in
// which each set became core (mapped here through SPIR-V versions).
void RequirementHandler::initAvailableCapabilitiesForVulkan(
    const SPIRVSubtarget &ST) {

  // Core in Vulkan 1.1 and earlier.
  addAvailableCaps({Capability::Int64, Capability::Float16, Capability::Float64,
                    Capability::GroupNonUniform, Capability::Image1D,
                    Capability::SampledBuffer, Capability::ImageBuffer,
                    Capability::UniformBufferArrayDynamicIndexing,
                    Capability::SampledImageArrayDynamicIndexing,
                    Capability::StorageBufferArrayDynamicIndexing,
                    Capability::StorageImageArrayDynamicIndexing});

  // Became core in Vulkan 1.2
  if (ST.isAtLeastSPIRVVer(VersionTuple(1, 5))) {
    addAvailableCaps(
        {Capability::ShaderNonUniformEXT, Capability::RuntimeDescriptorArrayEXT,
         Capability::InputAttachmentArrayDynamicIndexingEXT,
         Capability::UniformTexelBufferArrayDynamicIndexingEXT,
         Capability::StorageTexelBufferArrayDynamicIndexingEXT,
         Capability::UniformBufferArrayNonUniformIndexingEXT,
         Capability::SampledImageArrayNonUniformIndexingEXT,
         Capability::StorageBufferArrayNonUniformIndexingEXT,
         Capability::StorageImageArrayNonUniformIndexingEXT,
         Capability::InputAttachmentArrayNonUniformIndexingEXT,
         Capability::UniformTexelBufferArrayNonUniformIndexingEXT,
         Capability::StorageTexelBufferArrayNonUniformIndexingEXT});
  }

  // Became core in Vulkan 1.3
  if (ST.isAtLeastSPIRVVer(VersionTuple(1, 6)))
    addAvailableCaps({Capability::StorageImageWriteWithoutFormat,
                      Capability::StorageImageReadWithoutFormat});
}

} // namespace SPIRV
} // namespace llvm

// Add the required capabilities from a decoration instruction (including
// BuiltIns).
739 static void addOpDecorateReqs(const MachineInstr &MI, unsigned DecIndex, 740 SPIRV::RequirementHandler &Reqs, 741 const SPIRVSubtarget &ST) { 742 int64_t DecOp = MI.getOperand(DecIndex).getImm(); 743 auto Dec = static_cast<SPIRV::Decoration::Decoration>(DecOp); 744 Reqs.addRequirements(getSymbolicOperandRequirements( 745 SPIRV::OperandCategory::DecorationOperand, Dec, ST, Reqs)); 746 747 if (Dec == SPIRV::Decoration::BuiltIn) { 748 int64_t BuiltInOp = MI.getOperand(DecIndex + 1).getImm(); 749 auto BuiltIn = static_cast<SPIRV::BuiltIn::BuiltIn>(BuiltInOp); 750 Reqs.addRequirements(getSymbolicOperandRequirements( 751 SPIRV::OperandCategory::BuiltInOperand, BuiltIn, ST, Reqs)); 752 } else if (Dec == SPIRV::Decoration::LinkageAttributes) { 753 int64_t LinkageOp = MI.getOperand(MI.getNumOperands() - 1).getImm(); 754 SPIRV::LinkageType::LinkageType LnkType = 755 static_cast<SPIRV::LinkageType::LinkageType>(LinkageOp); 756 if (LnkType == SPIRV::LinkageType::LinkOnceODR) 757 Reqs.addExtension(SPIRV::Extension::SPV_KHR_linkonce_odr); 758 } else if (Dec == SPIRV::Decoration::CacheControlLoadINTEL || 759 Dec == SPIRV::Decoration::CacheControlStoreINTEL) { 760 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_cache_controls); 761 } else if (Dec == SPIRV::Decoration::HostAccessINTEL) { 762 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_global_variable_host_access); 763 } else if (Dec == SPIRV::Decoration::InitModeINTEL || 764 Dec == SPIRV::Decoration::ImplementInRegisterMapINTEL) { 765 Reqs.addExtension( 766 SPIRV::Extension::SPV_INTEL_global_variable_fpga_decorations); 767 } else if (Dec == SPIRV::Decoration::NonUniformEXT) { 768 Reqs.addRequirements(SPIRV::Capability::ShaderNonUniformEXT); 769 } 770 } 771 772 // Add requirements for image handling. 
// Collect the capabilities/extensions implied by an OpTypeImage: image
// format, dimensionality (combined with arrayed/multisampled/sampled
// fields), and — for OpenCL — the access-qualifier capabilities.
static void addOpTypeImageReqs(const MachineInstr &MI,
                               SPIRV::RequirementHandler &Reqs,
                               const SPIRVSubtarget &ST) {
  assert(MI.getNumOperands() >= 8 && "Insufficient operands for OpTypeImage");
  // The operand indices used here are based on the OpTypeImage layout, which
  // the MachineInstr follows as well.
  int64_t ImgFormatOp = MI.getOperand(7).getImm();
  auto ImgFormat = static_cast<SPIRV::ImageFormat::ImageFormat>(ImgFormatOp);
  Reqs.getAndAddRequirements(SPIRV::OperandCategory::ImageFormatOperand,
                             ImgFormat, ST);

  // Operand values follow the SPIR-V OpTypeImage encoding: Arrayed==1,
  // MS==1, Sampled==2 ("used without a sampler" / storage image).
  bool IsArrayed = MI.getOperand(4).getImm() == 1;
  bool IsMultisampled = MI.getOperand(5).getImm() == 1;
  bool NoSampler = MI.getOperand(6).getImm() == 2;
  // Add dimension requirements.
  assert(MI.getOperand(2).isImm());
  switch (MI.getOperand(2).getImm()) {
  case SPIRV::Dim::DIM_1D:
    Reqs.addRequirements(NoSampler ? SPIRV::Capability::Image1D
                                   : SPIRV::Capability::Sampled1D);
    break;
  case SPIRV::Dim::DIM_2D:
    // NOTE(review): ImageMSArray is requested for any multisampled storage
    // image here even though IsArrayed is not checked — confirm intended.
    if (IsMultisampled && NoSampler)
      Reqs.addRequirements(SPIRV::Capability::ImageMSArray);
    break;
  case SPIRV::Dim::DIM_Cube:
    Reqs.addRequirements(SPIRV::Capability::Shader);
    if (IsArrayed)
      Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageCubeArray
                                     : SPIRV::Capability::SampledCubeArray);
    break;
  case SPIRV::Dim::DIM_Rect:
    Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageRect
                                   : SPIRV::Capability::SampledRect);
    break;
  case SPIRV::Dim::DIM_Buffer:
    Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageBuffer
                                   : SPIRV::Capability::SampledBuffer);
    break;
  case SPIRV::Dim::DIM_SubpassData:
    Reqs.addRequirements(SPIRV::Capability::InputAttachment);
    break;
  }

  // Has optional access qualifier.
  if (ST.isOpenCLEnv()) {
    if (MI.getNumOperands() > 8 &&
        MI.getOperand(8).getImm() == SPIRV::AccessQualifier::ReadWrite)
      Reqs.addRequirements(SPIRV::Capability::ImageReadWrite);
    else
      Reqs.addRequirements(SPIRV::Capability::ImageBasic);
  }
}

// Add requirements for handling atomic float instructions
#define ATOM_FLT_REQ_EXT_MSG(ExtName)                                          \
  "The atomic float instruction requires the following SPIR-V "                \
  "extension: SPV_EXT_shader_atomic_float" ExtName
// Adds the extension/capability pair for OpAtomicFAddEXT and the FMin/FMax
// forms, selected by the scalar float bit width of the result type; fatal
// error if the required extension is not enabled or the width is unexpected.
static void AddAtomicFloatRequirements(const MachineInstr &MI,
                                       SPIRV::RequirementHandler &Reqs,
                                       const SPIRVSubtarget &ST) {
  assert(MI.getOperand(1).isReg() &&
         "Expect register operand in atomic float instruction");
  Register TypeReg = MI.getOperand(1).getReg();
  SPIRVType *TypeDef = MI.getMF()->getRegInfo().getVRegDef(TypeReg);
  if (TypeDef->getOpcode() != SPIRV::OpTypeFloat)
    report_fatal_error("Result type of an atomic float instruction must be a "
                       "floating-point type scalar");

  unsigned BitWidth = TypeDef->getOperand(1).getImm();
  unsigned Op = MI.getOpcode();
  if (Op == SPIRV::OpAtomicFAddEXT) {
    if (!ST.canUseExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add))
      report_fatal_error(ATOM_FLT_REQ_EXT_MSG("_add"), false);
    Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add);
    switch (BitWidth) {
    case 16:
      // 16-bit add additionally needs its own extension on top of the base.
      if (!ST.canUseExtension(
              SPIRV::Extension::SPV_EXT_shader_atomic_float16_add))
        report_fatal_error(ATOM_FLT_REQ_EXT_MSG("16_add"), false);
      Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float16_add);
      Reqs.addCapability(SPIRV::Capability::AtomicFloat16AddEXT);
      break;
    case 32:
      Reqs.addCapability(SPIRV::Capability::AtomicFloat32AddEXT);
      break;
    case 64:
      Reqs.addCapability(SPIRV::Capability::AtomicFloat64AddEXT);
      break;
    default:
      report_fatal_error(
          "Unexpected floating-point type width in atomic float instruction");
    }
  } else {
    // OpAtomicFMinEXT / OpAtomicFMaxEXT.
    if (!ST.canUseExtension(
            SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max))
      report_fatal_error(ATOM_FLT_REQ_EXT_MSG("_min_max"), false);
    Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max);
    switch (BitWidth) {
    case 16:
      Reqs.addCapability(SPIRV::Capability::AtomicFloat16MinMaxEXT);
      break;
    case 32:
      Reqs.addCapability(SPIRV::Capability::AtomicFloat32MinMaxEXT);
      break;
    case 64:
      Reqs.addCapability(SPIRV::Capability::AtomicFloat64MinMaxEXT);
      break;
    default:
      report_fatal_error(
          "Unexpected floating-point type width in atomic float instruction");
    }
  }
}

// The next helpers classify an OpTypeImage/OpTypeSampledImage instruction by
// its Dim (operand 2) and Sampled (operand 6) fields, mirroring the SPIR-V
// image categories used for descriptor-indexing capability selection.

// Buffer-dim image used with a sampler (Sampled == 1).
bool isUniformTexelBuffer(MachineInstr *ImageInst) {
  if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
    return false;
  uint32_t Dim = ImageInst->getOperand(2).getImm();
  uint32_t Sampled = ImageInst->getOperand(6).getImm();
  return Dim == SPIRV::Dim::DIM_Buffer && Sampled == 1;
}

// Buffer-dim image used without a sampler (Sampled == 2).
bool isStorageTexelBuffer(MachineInstr *ImageInst) {
  if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
    return false;
  uint32_t Dim = ImageInst->getOperand(2).getImm();
  uint32_t Sampled = ImageInst->getOperand(6).getImm();
  return Dim == SPIRV::Dim::DIM_Buffer && Sampled == 2;
}

// Non-buffer image used with a sampler (Sampled == 1).
bool isSampledImage(MachineInstr *ImageInst) {
  if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
    return false;
  uint32_t Dim = ImageInst->getOperand(2).getImm();
  uint32_t Sampled = ImageInst->getOperand(6).getImm();
  return Dim != SPIRV::Dim::DIM_Buffer && Sampled == 1;
}

// SubpassData-dim image used without a sampler (Sampled == 2).
bool isInputAttachment(MachineInstr *ImageInst) {
  if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
    return false;
  uint32_t Dim = ImageInst->getOperand(2).getImm();
  uint32_t Sampled = ImageInst->getOperand(6).getImm();
  return Dim == SPIRV::Dim::DIM_SubpassData && Sampled == 2;
}

// Non-buffer image used without a sampler (Sampled == 2). Note this also
// matches SubpassData images; callers check isInputAttachment first.
bool isStorageImage(MachineInstr *ImageInst) {
  if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
    return false;
  uint32_t Dim = ImageInst->getOperand(2).getImm();
  uint32_t Sampled = ImageInst->getOperand(6).getImm();
  return Dim != SPIRV::Dim::DIM_Buffer && Sampled == 2;
}

// OpTypeSampledImage whose underlying image qualifies as a sampled image.
bool isCombinedImageSampler(MachineInstr *SampledImageInst) {
  if (SampledImageInst->getOpcode() != SPIRV::OpTypeSampledImage)
    return false;

  const MachineRegisterInfo &MRI = SampledImageInst->getMF()->getRegInfo();
  Register ImageReg = SampledImageInst->getOperand(1).getReg();
  auto *ImageInst = MRI.getUniqueVRegDef(ImageReg);
  return isSampledImage(ImageInst);
}

// True if any OpDecorate attached to Reg carries the NonUniformEXT
// decoration (operand 1 of OpDecorate is the decoration enumerant).
bool hasNonUniformDecoration(Register Reg, const MachineRegisterInfo &MRI) {
  for (const auto &MI : MRI.reg_instructions(Reg)) {
    if (MI.getOpcode() != SPIRV::OpDecorate)
      continue;

    uint32_t Dec = MI.getOperand(1).getImm();
    if (Dec == SPIRV::Decoration::NonUniformEXT)
      return true;
  }
  return false;
}

// For an OpAccessChain/OpInBoundsAccessChain whose result is a pointer to a
// resource type (image/sampled-image/sampler) in a descriptor storage class,
// add the matching descriptor-indexing capability, choosing the NonUniform
// variant when the result register is decorated NonUniformEXT.
void addOpAccessChainReqs(const MachineInstr &Instr,
                          SPIRV::RequirementHandler &Handler,
                          const SPIRVSubtarget &Subtarget) {
  const MachineRegisterInfo &MRI = Instr.getMF()->getRegInfo();
  // Get the result type. If it is an image type, then the shader uses
  // descriptor indexing. The appropriate capabilities will be added based
  // on the specifics of the image.
  Register ResTypeReg = Instr.getOperand(1).getReg();
  MachineInstr *ResTypeInst = MRI.getUniqueVRegDef(ResTypeReg);

  assert(ResTypeInst->getOpcode() == SPIRV::OpTypePointer);
  uint32_t StorageClass = ResTypeInst->getOperand(1).getImm();
  // Only descriptor-backed storage classes are relevant here.
  if (StorageClass != SPIRV::StorageClass::StorageClass::UniformConstant &&
      StorageClass != SPIRV::StorageClass::StorageClass::Uniform &&
      StorageClass != SPIRV::StorageClass::StorageClass::StorageBuffer) {
    return;
  }

  Register PointeeTypeReg = ResTypeInst->getOperand(2).getReg();
  MachineInstr *PointeeType = MRI.getUniqueVRegDef(PointeeTypeReg);
  if (PointeeType->getOpcode() != SPIRV::OpTypeImage &&
      PointeeType->getOpcode() != SPIRV::OpTypeSampledImage &&
      PointeeType->getOpcode() != SPIRV::OpTypeSampler) {
    return;
  }

  bool IsNonUniform =
      hasNonUniformDecoration(Instr.getOperand(0).getReg(), MRI);
  if (isUniformTexelBuffer(PointeeType)) {
    if (IsNonUniform)
      Handler.addRequirements(
          SPIRV::Capability::UniformTexelBufferArrayNonUniformIndexingEXT);
    else
      Handler.addRequirements(
          SPIRV::Capability::UniformTexelBufferArrayDynamicIndexingEXT);
  } else if (isInputAttachment(PointeeType)) {
    if (IsNonUniform)
      Handler.addRequirements(
          SPIRV::Capability::InputAttachmentArrayNonUniformIndexingEXT);
    else
      Handler.addRequirements(
          SPIRV::Capability::InputAttachmentArrayDynamicIndexingEXT);
  } else if (isStorageTexelBuffer(PointeeType)) {
    if (IsNonUniform)
      Handler.addRequirements(
          SPIRV::Capability::StorageTexelBufferArrayNonUniformIndexingEXT);
    else
      Handler.addRequirements(
          SPIRV::Capability::StorageTexelBufferArrayDynamicIndexingEXT);
  } else if (isSampledImage(PointeeType) ||
             isCombinedImageSampler(PointeeType) ||
             PointeeType->getOpcode() == SPIRV::OpTypeSampler) {
    if (IsNonUniform)
      Handler.addRequirements(
          SPIRV::Capability::SampledImageArrayNonUniformIndexingEXT);
    else
      Handler.addRequirements(
          SPIRV::Capability::SampledImageArrayDynamicIndexing);
  } else if (isStorageImage(PointeeType)) {
    if (IsNonUniform)
      Handler.addRequirements(
          SPIRV::Capability::StorageImageArrayNonUniformIndexingEXT);
    else
      Handler.addRequirements(
          SPIRV::Capability::StorageImageArrayDynamicIndexing);
  }
}

// True for an OpTypeImage whose Image Format operand is 0 (Unknown).
static bool isImageTypeWithUnknownFormat(SPIRVType *TypeInst) {
  if (TypeInst->getOpcode() != SPIRV::OpTypeImage)
    return false;
  assert(TypeInst->getOperand(7).isImm() && "The image format must be an imm.");
  return TypeInst->getOperand(7).getImm() == 0;
}

// Adds SPV_KHR_integer_dot_product requirements for OpSDot/OpUDot, selecting
// the input-capability flavor from the type of the dot operand: packed i32,
// 4 x i8 vector, or any other integer vector.
static void AddDotProductRequirements(const MachineInstr &MI,
                                      SPIRV::RequirementHandler &Reqs,
                                      const SPIRVSubtarget &ST) {
  if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_integer_dot_product))
    Reqs.addExtension(SPIRV::Extension::SPV_KHR_integer_dot_product);
  Reqs.addCapability(SPIRV::Capability::DotProduct);

  const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
  assert(MI.getOperand(2).isReg() && "Unexpected operand in dot");
  // We do not consider what the previous instruction is. This is just used
  // to get the input register and to check the type.
  const MachineInstr *Input = MRI.getVRegDef(MI.getOperand(2).getReg());
  assert(Input->getOperand(1).isReg() && "Unexpected operand in dot input");
  Register InputReg = Input->getOperand(1).getReg();

  SPIRVType *TypeDef = MRI.getVRegDef(InputReg);
  if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
    // Scalar integer input means the packed 4x8-bit form.
    assert(TypeDef->getOperand(1).getImm() == 32);
    Reqs.addCapability(SPIRV::Capability::DotProductInput4x8BitPacked);
  } else if (TypeDef->getOpcode() == SPIRV::OpTypeVector) {
    SPIRVType *ScalarTypeDef = MRI.getVRegDef(TypeDef->getOperand(1).getReg());
    assert(ScalarTypeDef->getOpcode() == SPIRV::OpTypeInt);
    if (ScalarTypeDef->getOperand(1).getImm() == 8) {
      assert(TypeDef->getOperand(2).getImm() == 4 &&
             "Dot operand of 8-bit integer type requires 4 components");
      Reqs.addCapability(SPIRV::Capability::DotProductInput4x8Bit);
    } else {
      Reqs.addCapability(SPIRV::Capability::DotProductInputAll);
    }
  }
}

// Central per-instruction requirement collection: dispatches on the opcode
// of MI and records every capability/extension it implies in Reqs. For
// extension-gated instructions, either requires the extension to be enabled
// (report_fatal_error otherwise) or silently adds nothing/an alternative,
// depending on the instruction.
void addInstrRequirements(const MachineInstr &MI,
                          SPIRV::RequirementHandler &Reqs,
                          const SPIRVSubtarget &ST) {
  switch (MI.getOpcode()) {
  case SPIRV::OpMemoryModel: {
    int64_t Addr = MI.getOperand(0).getImm();
    Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,
                               Addr, ST);
    int64_t Mem = MI.getOperand(1).getImm();
    Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand, Mem,
                               ST);
    break;
  }
  case SPIRV::OpEntryPoint: {
    int64_t Exe = MI.getOperand(0).getImm();
    Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModelOperand,
                               Exe, ST);
    break;
  }
  case SPIRV::OpExecutionMode:
  case SPIRV::OpExecutionModeId: {
    int64_t Exe = MI.getOperand(1).getImm();
    Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModeOperand,
                               Exe, ST);
    break;
  }
  case SPIRV::OpTypeMatrix:
    Reqs.addCapability(SPIRV::Capability::Matrix);
    break;
  case SPIRV::OpTypeInt: {
    unsigned BitWidth = MI.getOperand(1).getImm();
    if (BitWidth == 64)
      Reqs.addCapability(SPIRV::Capability::Int64);
    else if (BitWidth == 16)
      Reqs.addCapability(SPIRV::Capability::Int16);
    else if (BitWidth == 8)
      Reqs.addCapability(SPIRV::Capability::Int8);
    break;
  }
  case SPIRV::OpTypeFloat: {
    unsigned BitWidth = MI.getOperand(1).getImm();
    if (BitWidth == 64)
      Reqs.addCapability(SPIRV::Capability::Float64);
    else if (BitWidth == 16)
      Reqs.addCapability(SPIRV::Capability::Float16);
    break;
  }
  case SPIRV::OpTypeVector: {
    unsigned NumComponents = MI.getOperand(2).getImm();
    if (NumComponents == 8 || NumComponents == 16)
      Reqs.addCapability(SPIRV::Capability::Vector16);
    break;
  }
  case SPIRV::OpTypePointer: {
    auto SC = MI.getOperand(1).getImm();
    Reqs.getAndAddRequirements(SPIRV::OperandCategory::StorageClassOperand, SC,
                               ST);
    // If it's a type of pointer to float16 targeting OpenCL, add Float16Buffer
    // capability.
    if (!ST.isOpenCLEnv())
      break;
    assert(MI.getOperand(2).isReg());
    const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
    SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(2).getReg());
    if (TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
        TypeDef->getOperand(1).getImm() == 16)
      Reqs.addCapability(SPIRV::Capability::Float16Buffer);
    break;
  }
  case SPIRV::OpExtInst: {
    // Non-semantic debug-info instruction set needs its enabling extension.
    if (MI.getOperand(2).getImm() ==
        static_cast<int64_t>(
            SPIRV::InstructionSet::NonSemantic_Shader_DebugInfo_100)) {
      Reqs.addExtension(SPIRV::Extension::SPV_KHR_non_semantic_info);
    }
    break;
  }
  case SPIRV::OpBitReverse:
  case SPIRV::OpBitFieldInsert:
  case SPIRV::OpBitFieldSExtract:
  case SPIRV::OpBitFieldUExtract:
    // Without SPV_KHR_bit_instructions these require the Shader capability.
    if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions)) {
      Reqs.addCapability(SPIRV::Capability::Shader);
      break;
    }
    Reqs.addExtension(SPIRV::Extension::SPV_KHR_bit_instructions);
    Reqs.addCapability(SPIRV::Capability::BitInstructions);
    break;
  case SPIRV::OpTypeRuntimeArray:
    Reqs.addCapability(SPIRV::Capability::Shader);
    break;
  case SPIRV::OpTypeOpaque:
  case SPIRV::OpTypeEvent:
    Reqs.addCapability(SPIRV::Capability::Kernel);
    break;
  case SPIRV::OpTypePipe:
  case SPIRV::OpTypeReserveId:
    Reqs.addCapability(SPIRV::Capability::Pipes);
    break;
  case SPIRV::OpTypeDeviceEvent:
  case SPIRV::OpTypeQueue:
  case SPIRV::OpBuildNDRange:
    Reqs.addCapability(SPIRV::Capability::DeviceEnqueue);
    break;
  case SPIRV::OpDecorate:
  case SPIRV::OpDecorateId:
  case SPIRV::OpDecorateString:
    // Decoration enumerant is operand 1 for OpDecorate* forms.
    addOpDecorateReqs(MI, 1, Reqs, ST);
    break;
  case SPIRV::OpMemberDecorate:
  case SPIRV::OpMemberDecorateString:
    // Member forms carry a member index first, so the decoration is operand 2.
    addOpDecorateReqs(MI, 2, Reqs, ST);
    break;
  case SPIRV::OpInBoundsPtrAccessChain:
    Reqs.addCapability(SPIRV::Capability::Addresses);
    break;
  case SPIRV::OpConstantSampler:
    Reqs.addCapability(SPIRV::Capability::LiteralSampler);
    break;
  case SPIRV::OpInBoundsAccessChain:
  case SPIRV::OpAccessChain:
    addOpAccessChainReqs(MI, Reqs, ST);
    break;
  case SPIRV::OpTypeImage:
    addOpTypeImageReqs(MI, Reqs, ST);
    break;
  case SPIRV::OpTypeSampler:
    if (!ST.isVulkanEnv()) {
      Reqs.addCapability(SPIRV::Capability::ImageBasic);
    }
    break;
  case SPIRV::OpTypeForwardPointer:
    // TODO: check if it's OpenCL's kernel.
    Reqs.addCapability(SPIRV::Capability::Addresses);
    break;
  case SPIRV::OpAtomicFlagTestAndSet:
  case SPIRV::OpAtomicLoad:
  case SPIRV::OpAtomicStore:
  case SPIRV::OpAtomicExchange:
  case SPIRV::OpAtomicCompareExchange:
  case SPIRV::OpAtomicIIncrement:
  case SPIRV::OpAtomicIDecrement:
  case SPIRV::OpAtomicIAdd:
  case SPIRV::OpAtomicISub:
  case SPIRV::OpAtomicUMin:
  case SPIRV::OpAtomicUMax:
  case SPIRV::OpAtomicSMin:
  case SPIRV::OpAtomicSMax:
  case SPIRV::OpAtomicAnd:
  case SPIRV::OpAtomicOr:
  case SPIRV::OpAtomicXor: {
    const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
    const MachineInstr *InstrPtr = &MI;
    if (MI.getOpcode() == SPIRV::OpAtomicStore) {
      // OpAtomicStore has no result type; take the type from the stored
      // value's defining instruction instead.
      assert(MI.getOperand(3).isReg());
      InstrPtr = MRI.getVRegDef(MI.getOperand(3).getReg());
      assert(InstrPtr && "Unexpected type instruction for OpAtomicStore");
    }
    assert(InstrPtr->getOperand(1).isReg() && "Unexpected operand in atomic");
    Register TypeReg = InstrPtr->getOperand(1).getReg();
    SPIRVType *TypeDef = MRI.getVRegDef(TypeReg);
    if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
      unsigned BitWidth = TypeDef->getOperand(1).getImm();
      if (BitWidth == 64)
        Reqs.addCapability(SPIRV::Capability::Int64Atomics);
    }
    break;
  }
  case SPIRV::OpGroupNonUniformIAdd:
  case SPIRV::OpGroupNonUniformFAdd:
  case SPIRV::OpGroupNonUniformIMul:
  case SPIRV::OpGroupNonUniformFMul:
  case SPIRV::OpGroupNonUniformSMin:
  case SPIRV::OpGroupNonUniformUMin:
  case SPIRV::OpGroupNonUniformFMin:
  case SPIRV::OpGroupNonUniformSMax:
  case SPIRV::OpGroupNonUniformUMax:
  case SPIRV::OpGroupNonUniformFMax:
  case SPIRV::OpGroupNonUniformBitwiseAnd:
  case SPIRV::OpGroupNonUniformBitwiseOr:
  case SPIRV::OpGroupNonUniformBitwiseXor:
  case SPIRV::OpGroupNonUniformLogicalAnd:
  case SPIRV::OpGroupNonUniformLogicalOr:
  case SPIRV::OpGroupNonUniformLogicalXor: {
    // Capability depends on the GroupOperation literal (operand 3).
    assert(MI.getOperand(3).isImm());
    int64_t GroupOp = MI.getOperand(3).getImm();
    switch (GroupOp) {
    case SPIRV::GroupOperation::Reduce:
    case SPIRV::GroupOperation::InclusiveScan:
    case SPIRV::GroupOperation::ExclusiveScan:
      Reqs.addCapability(SPIRV::Capability::Kernel);
      Reqs.addCapability(SPIRV::Capability::GroupNonUniformArithmetic);
      Reqs.addCapability(SPIRV::Capability::GroupNonUniformBallot);
      break;
    case SPIRV::GroupOperation::ClusteredReduce:
      Reqs.addCapability(SPIRV::Capability::GroupNonUniformClustered);
      break;
    case SPIRV::GroupOperation::PartitionedReduceNV:
    case SPIRV::GroupOperation::PartitionedInclusiveScanNV:
    case SPIRV::GroupOperation::PartitionedExclusiveScanNV:
      Reqs.addCapability(SPIRV::Capability::GroupNonUniformPartitionedNV);
      break;
    }
    break;
  }
  case SPIRV::OpGroupNonUniformShuffle:
  case SPIRV::OpGroupNonUniformShuffleXor:
    Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffle);
    break;
  case SPIRV::OpGroupNonUniformShuffleUp:
  case SPIRV::OpGroupNonUniformShuffleDown:
    Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffleRelative);
    break;
  case SPIRV::OpGroupAll:
  case SPIRV::OpGroupAny:
  case SPIRV::OpGroupBroadcast:
  case SPIRV::OpGroupIAdd:
  case SPIRV::OpGroupFAdd:
  case SPIRV::OpGroupFMin:
  case SPIRV::OpGroupUMin:
  case SPIRV::OpGroupSMin:
  case SPIRV::OpGroupFMax:
  case SPIRV::OpGroupUMax:
  case SPIRV::OpGroupSMax:
    Reqs.addCapability(SPIRV::Capability::Groups);
    break;
  case SPIRV::OpGroupNonUniformElect:
    Reqs.addCapability(SPIRV::Capability::GroupNonUniform);
    break;
  case SPIRV::OpGroupNonUniformAll:
  case SPIRV::OpGroupNonUniformAny:
  case SPIRV::OpGroupNonUniformAllEqual:
    Reqs.addCapability(SPIRV::Capability::GroupNonUniformVote);
    break;
  case SPIRV::OpGroupNonUniformBroadcast:
  case SPIRV::OpGroupNonUniformBroadcastFirst:
  case SPIRV::OpGroupNonUniformBallot:
  case SPIRV::OpGroupNonUniformInverseBallot:
  case SPIRV::OpGroupNonUniformBallotBitExtract:
  case SPIRV::OpGroupNonUniformBallotBitCount:
  case SPIRV::OpGroupNonUniformBallotFindLSB:
  case SPIRV::OpGroupNonUniformBallotFindMSB:
    Reqs.addCapability(SPIRV::Capability::GroupNonUniformBallot);
    break;
  case SPIRV::OpSubgroupShuffleINTEL:
  case SPIRV::OpSubgroupShuffleDownINTEL:
  case SPIRV::OpSubgroupShuffleUpINTEL:
  case SPIRV::OpSubgroupShuffleXorINTEL:
    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
      Reqs.addCapability(SPIRV::Capability::SubgroupShuffleINTEL);
    }
    break;
  case SPIRV::OpSubgroupBlockReadINTEL:
  case SPIRV::OpSubgroupBlockWriteINTEL:
    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
      Reqs.addCapability(SPIRV::Capability::SubgroupBufferBlockIOINTEL);
    }
    break;
  case SPIRV::OpSubgroupImageBlockReadINTEL:
  case SPIRV::OpSubgroupImageBlockWriteINTEL:
    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
      Reqs.addCapability(SPIRV::Capability::SubgroupImageBlockIOINTEL);
    }
    break;
  case SPIRV::OpSubgroupImageMediaBlockReadINTEL:
  case SPIRV::OpSubgroupImageMediaBlockWriteINTEL:
    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_media_block_io)) {
      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_media_block_io);
      Reqs.addCapability(SPIRV::Capability::SubgroupImageMediaBlockIOINTEL);
    }
    break;
  case SPIRV::OpAssumeTrueKHR:
  case SPIRV::OpExpectKHR:
    if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_expect_assume)) {
      Reqs.addExtension(SPIRV::Extension::SPV_KHR_expect_assume);
      Reqs.addCapability(SPIRV::Capability::ExpectAssumeKHR);
    }
    break;
  case SPIRV::OpPtrCastToCrossWorkgroupINTEL:
  case SPIRV::OpCrossWorkgroupCastToPtrINTEL:
    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)) {
      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes);
      Reqs.addCapability(SPIRV::Capability::USMStorageClassesINTEL);
    }
    break;
  case SPIRV::OpConstantFunctionPointerINTEL:
    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) {
      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
      Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL);
    }
    break;
  case SPIRV::OpGroupNonUniformRotateKHR:
    if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_subgroup_rotate))
      report_fatal_error("OpGroupNonUniformRotateKHR instruction requires the "
                         "following SPIR-V extension: SPV_KHR_subgroup_rotate",
                         false);
    Reqs.addExtension(SPIRV::Extension::SPV_KHR_subgroup_rotate);
    Reqs.addCapability(SPIRV::Capability::GroupNonUniformRotateKHR);
    Reqs.addCapability(SPIRV::Capability::GroupNonUniform);
    break;
  case SPIRV::OpGroupIMulKHR:
  case SPIRV::OpGroupFMulKHR:
  case SPIRV::OpGroupBitwiseAndKHR:
  case SPIRV::OpGroupBitwiseOrKHR:
  case SPIRV::OpGroupBitwiseXorKHR:
  case SPIRV::OpGroupLogicalAndKHR:
  case SPIRV::OpGroupLogicalOrKHR:
  case SPIRV::OpGroupLogicalXorKHR:
    if (ST.canUseExtension(
            SPIRV::Extension::SPV_KHR_uniform_group_instructions)) {
      Reqs.addExtension(SPIRV::Extension::SPV_KHR_uniform_group_instructions);
      Reqs.addCapability(SPIRV::Capability::GroupUniformArithmeticKHR);
    }
    break;
  case SPIRV::OpReadClockKHR:
    if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_shader_clock))
      report_fatal_error("OpReadClockKHR instruction requires the "
                         "following SPIR-V extension: SPV_KHR_shader_clock",
                         false);
    Reqs.addExtension(SPIRV::Extension::SPV_KHR_shader_clock);
    Reqs.addCapability(SPIRV::Capability::ShaderClockKHR);
    break;
  case SPIRV::OpFunctionPointerCallINTEL:
    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) {
      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
      Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL);
    }
    break;
  case SPIRV::OpAtomicFAddEXT:
  case SPIRV::OpAtomicFMinEXT:
  case SPIRV::OpAtomicFMaxEXT:
    AddAtomicFloatRequirements(MI, Reqs, ST);
    break;
  case SPIRV::OpConvertBF16ToFINTEL:
  case SPIRV::OpConvertFToBF16INTEL:
    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bfloat16_conversion)) {
      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bfloat16_conversion);
      Reqs.addCapability(SPIRV::Capability::BFloat16ConversionINTEL);
    }
    break;
  case SPIRV::OpVariableLengthArrayINTEL:
  case SPIRV::OpSaveMemoryINTEL:
  case SPIRV::OpRestoreMemoryINTEL:
    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_variable_length_array)) {
      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_variable_length_array);
      Reqs.addCapability(SPIRV::Capability::VariableLengthArrayINTEL);
    }
    break;
  case SPIRV::OpAsmTargetINTEL:
  case SPIRV::OpAsmINTEL:
  case SPIRV::OpAsmCallINTEL:
    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_inline_assembly)) {
      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_inline_assembly);
      Reqs.addCapability(SPIRV::Capability::AsmINTEL);
    }
    break;
  case SPIRV::OpTypeCooperativeMatrixKHR:
    if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
      report_fatal_error(
          "OpTypeCooperativeMatrixKHR type requires the "
          "following SPIR-V extension: SPV_KHR_cooperative_matrix",
          false);
    Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
    Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
    break;
  case SPIRV::OpArithmeticFenceEXT:
    if (!ST.canUseExtension(SPIRV::Extension::SPV_EXT_arithmetic_fence))
      report_fatal_error("OpArithmeticFenceEXT requires the "
                         "following SPIR-V extension: SPV_EXT_arithmetic_fence",
                         false);
    Reqs.addExtension(SPIRV::Extension::SPV_EXT_arithmetic_fence);
    Reqs.addCapability(SPIRV::Capability::ArithmeticFenceEXT);
    break;
  case SPIRV::OpControlBarrierArriveINTEL:
  case SPIRV::OpControlBarrierWaitINTEL:
    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_split_barrier)) {
      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_split_barrier);
      Reqs.addCapability(SPIRV::Capability::SplitBarrierINTEL);
    }
    break;
  case SPIRV::OpCooperativeMatrixMulAddKHR: {
    if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
      report_fatal_error("Cooperative matrix instructions require the "
                         "following SPIR-V extension: "
                         "SPV_KHR_cooperative_matrix",
                         false);
    Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
    Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
    // The optional Cooperative Matrix Operands mask is only present when the
    // instruction has its maximum operand count.
    constexpr unsigned MulAddMaxSize = 6;
    if (MI.getNumOperands() != MulAddMaxSize)
      break;
    const int64_t CoopOperands = MI.getOperand(MulAddMaxSize - 1).getImm();
    if (CoopOperands &
        SPIRV::CooperativeMatrixOperands::MatrixAAndBTF32ComponentsINTEL) {
      if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
        report_fatal_error("MatrixAAndBTF32ComponentsINTEL type interpretation "
                           "require the following SPIR-V extension: "
                           "SPV_INTEL_joint_matrix",
                           false);
      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
      Reqs.addCapability(
          SPIRV::Capability::CooperativeMatrixTF32ComponentTypeINTEL);
    }
    if (CoopOperands & SPIRV::CooperativeMatrixOperands::
                           MatrixAAndBBFloat16ComponentsINTEL ||
        CoopOperands &
            SPIRV::CooperativeMatrixOperands::MatrixCBFloat16ComponentsINTEL ||
        CoopOperands & SPIRV::CooperativeMatrixOperands::
                           MatrixResultBFloat16ComponentsINTEL) {
      if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
        report_fatal_error("***BF16ComponentsINTEL type interpretations "
                           "require the following SPIR-V extension: "
                           "SPV_INTEL_joint_matrix",
                           false);
      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
      Reqs.addCapability(
          SPIRV::Capability::CooperativeMatrixBFloat16ComponentTypeINTEL);
    }
    break;
  }
  case SPIRV::OpCooperativeMatrixLoadKHR:
  case SPIRV::OpCooperativeMatrixStoreKHR:
  case SPIRV::OpCooperativeMatrixLoadCheckedINTEL:
  case SPIRV::OpCooperativeMatrixStoreCheckedINTEL:
  case SPIRV::OpCooperativeMatrixPrefetchINTEL: {
    if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
      report_fatal_error("Cooperative matrix instructions require the "
                         "following SPIR-V extension: "
                         "SPV_KHR_cooperative_matrix",
                         false);
    Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
    Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);

    // Check Layout operand in case if it's not a standard one and add the
    // appropriate capability.
    // Maps each opcode to the operand index of its Layout operand.
    std::unordered_map<unsigned, unsigned> LayoutToInstMap = {
        {SPIRV::OpCooperativeMatrixLoadKHR, 3},
        {SPIRV::OpCooperativeMatrixStoreKHR, 2},
        {SPIRV::OpCooperativeMatrixLoadCheckedINTEL, 5},
        {SPIRV::OpCooperativeMatrixStoreCheckedINTEL, 4},
        {SPIRV::OpCooperativeMatrixPrefetchINTEL, 4}};

    const auto OpCode = MI.getOpcode();
    const unsigned LayoutNum = LayoutToInstMap[OpCode];
    Register RegLayout = MI.getOperand(LayoutNum).getReg();
    const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
    MachineInstr *MILayout = MRI.getUniqueVRegDef(RegLayout);
    if (MILayout->getOpcode() == SPIRV::OpConstantI) {
      const unsigned LayoutVal = MILayout->getOperand(2).getImm();
      if (LayoutVal ==
          static_cast<unsigned>(SPIRV::CooperativeMatrixLayout::PackedINTEL)) {
        if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
          report_fatal_error("PackedINTEL layout require the following SPIR-V "
                             "extension: SPV_INTEL_joint_matrix",
                             false);
        Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
        Reqs.addCapability(SPIRV::Capability::PackedCooperativeMatrixINTEL);
      }
    }

    // Nothing to do.
    if (OpCode == SPIRV::OpCooperativeMatrixLoadKHR ||
        OpCode == SPIRV::OpCooperativeMatrixStoreKHR)
      break;

    // The remaining opcodes are Intel joint-matrix instructions; build the
    // instruction name for the error message.
    std::string InstName;
    switch (OpCode) {
    case SPIRV::OpCooperativeMatrixPrefetchINTEL:
      InstName = "OpCooperativeMatrixPrefetchINTEL";
      break;
    case SPIRV::OpCooperativeMatrixLoadCheckedINTEL:
      InstName = "OpCooperativeMatrixLoadCheckedINTEL";
      break;
    case SPIRV::OpCooperativeMatrixStoreCheckedINTEL:
      InstName = "OpCooperativeMatrixStoreCheckedINTEL";
      break;
    }

    if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix)) {
      const std::string ErrorMsg =
          InstName + " instruction requires the "
                     "following SPIR-V extension: SPV_INTEL_joint_matrix";
      report_fatal_error(ErrorMsg.c_str(), false);
    }
    Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
    if (OpCode == SPIRV::OpCooperativeMatrixPrefetchINTEL) {
      Reqs.addCapability(SPIRV::Capability::CooperativeMatrixPrefetchINTEL);
      break;
    }
    Reqs.addCapability(
        SPIRV::Capability::CooperativeMatrixCheckedInstructionsINTEL);
    break;
  }
  case SPIRV::OpCooperativeMatrixConstructCheckedINTEL:
    if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
      report_fatal_error("OpCooperativeMatrixConstructCheckedINTEL "
                         "instructions require the following SPIR-V extension: "
                         "SPV_INTEL_joint_matrix",
                         false);
    Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
    Reqs.addCapability(
        SPIRV::Capability::CooperativeMatrixCheckedInstructionsINTEL);
    break;
  case SPIRV::OpCooperativeMatrixGetElementCoordINTEL:
    if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
      report_fatal_error("OpCooperativeMatrixGetElementCoordINTEL requires the "
                         "following SPIR-V extension: SPV_INTEL_joint_matrix",
                         false);
    Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
    Reqs.addCapability(
        SPIRV::Capability::CooperativeMatrixInvocationInstructionsINTEL);
    break;
  case SPIRV::OpKill: {
    Reqs.addCapability(SPIRV::Capability::Shader);
  } break;
  case SPIRV::OpDemoteToHelperInvocation:
    Reqs.addCapability(SPIRV::Capability::DemoteToHelperInvocation);

    if (ST.canUseExtension(
            SPIRV::Extension::SPV_EXT_demote_to_helper_invocation)) {
      // The extension is only needed below SPIR-V 1.6, where the capability
      // is not core.
      if (!ST.isAtLeastSPIRVVer(llvm::VersionTuple(1, 6)))
        Reqs.addExtension(
            SPIRV::Extension::SPV_EXT_demote_to_helper_invocation);
    }
    break;
  case SPIRV::OpSDot:
  case SPIRV::OpUDot:
    AddDotProductRequirements(MI, Reqs, ST);
    break;
  case SPIRV::OpImageRead: {
    Register ImageReg = MI.getOperand(2).getReg();
    SPIRVType *TypeDef = ST.getSPIRVGlobalRegistry()->getResultType(ImageReg);
    if (isImageTypeWithUnknownFormat(TypeDef))
      Reqs.addCapability(SPIRV::Capability::StorageImageReadWithoutFormat);
    break;
  }
  case SPIRV::OpImageWrite: {
    Register ImageReg = MI.getOperand(0).getReg();
    SPIRVType *TypeDef = ST.getSPIRVGlobalRegistry()->getResultType(ImageReg);
    if (isImageTypeWithUnknownFormat(TypeDef))
      Reqs.addCapability(SPIRV::Capability::StorageImageWriteWithoutFormat);
    break;
  }

  default:
    break;
  }

  // If we require capability Shader, then we can remove the requirement for
  // the BitInstructions capability, since Shader is a superset capability
  // of BitInstructions.
  Reqs.removeCapabilityIf(SPIRV::Capability::BitInstructions,
                          SPIRV::Capability::Shader);
}

static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI,
                        MachineModuleInfo *MMI, const SPIRVSubtarget &ST) {
  // Collect requirements for existing instructions.
  // Scan every machine instruction of every codegen'ed function and record
  // the requirements each one implies.
  for (auto F = M.begin(), E = M.end(); F != E; ++F) {
    MachineFunction *MF = MMI->getMachineFunction(*F);
    if (!MF)
      continue;
    for (const MachineBasicBlock &MBB : *MF)
      for (const MachineInstr &MI : MBB)
        addInstrRequirements(MI, MAI.Reqs, ST);
  }
  // Collect requirements for OpExecutionMode instructions.
  auto Node = M.getNamedMetadata("spirv.ExecutionMode");
  if (Node) {
    // SPV_KHR_float_controls is not available until v1.4
    bool RequireFloatControls = false,
         VerLower14 = !ST.isAtLeastSPIRVVer(VersionTuple(1, 4));
    for (unsigned i = 0; i < Node->getNumOperands(); i++) {
      MDNode *MDN = cast<MDNode>(Node->getOperand(i));
      // Operand 1 of each metadata node holds the execution-mode constant.
      const MDOperand &MDOp = MDN->getOperand(1);
      if (auto *CMeta = dyn_cast<ConstantAsMetadata>(MDOp)) {
        Constant *C = CMeta->getValue();
        if (ConstantInt *Const = dyn_cast<ConstantInt>(C)) {
          auto EM = Const->getZExtValue();
          MAI.Reqs.getAndAddRequirements(
              SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
          // add SPV_KHR_float_controls if the version is too low
          switch (EM) {
          case SPIRV::ExecutionMode::DenormPreserve:
          case SPIRV::ExecutionMode::DenormFlushToZero:
          case SPIRV::ExecutionMode::SignedZeroInfNanPreserve:
          case SPIRV::ExecutionMode::RoundingModeRTE:
          case SPIRV::ExecutionMode::RoundingModeRTZ:
            RequireFloatControls = VerLower14;
            break;
          }
        }
      }
    }
    if (RequireFloatControls &&
        ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls))
      MAI.Reqs.addExtension(SPIRV::Extension::SPV_KHR_float_controls);
  }
  // Function-level metadata/attributes map onto execution modes that carry
  // their own requirements.
  for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) {
    const Function &F = *FI;
    if (F.isDeclaration())
      continue;
    if (F.getMetadata("reqd_work_group_size"))
      MAI.Reqs.getAndAddRequirements(
          SPIRV::OperandCategory::ExecutionModeOperand,
          SPIRV::ExecutionMode::LocalSize, ST);
    if (F.getFnAttribute("hlsl.numthreads").isValid()) {
      MAI.Reqs.getAndAddRequirements(
          SPIRV::OperandCategory::ExecutionModeOperand,
          SPIRV::ExecutionMode::LocalSize, ST);
    }
    if (F.getMetadata("work_group_size_hint"))
      MAI.Reqs.getAndAddRequirements(
          SPIRV::OperandCategory::ExecutionModeOperand,
          SPIRV::ExecutionMode::LocalSizeHint, ST);
    if (F.getMetadata("intel_reqd_sub_group_size"))
      MAI.Reqs.getAndAddRequirements(
          SPIRV::OperandCategory::ExecutionModeOperand,
          SPIRV::ExecutionMode::SubgroupSize, ST);
    if (F.getMetadata("vec_type_hint"))
      MAI.Reqs.getAndAddRequirements(
          SPIRV::OperandCategory::ExecutionModeOperand,
          SPIRV::ExecutionMode::VecTypeHint, ST);

    // For optnone functions, prefer the EXT extension/capability pair and
    // fall back to the INTEL one when only it is available.
    if (F.hasOptNone()) {
      if (ST.canUseExtension(SPIRV::Extension::SPV_EXT_optnone)) {
        MAI.Reqs.addExtension(SPIRV::Extension::SPV_EXT_optnone);
        MAI.Reqs.addCapability(SPIRV::Capability::OptNoneEXT);
      } else if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_optnone)) {
        MAI.Reqs.addExtension(SPIRV::Extension::SPV_INTEL_optnone);
        MAI.Reqs.addCapability(SPIRV::Capability::OptNoneINTEL);
      }
    }
  }
}

// Translate the instruction's fast-math MI flags into a SPIR-V
// FPFastMathMode bitmask (FmReassoc maps to the aggregate 'Fast' bit).
static unsigned getFastMathFlags(const MachineInstr &I) {
  unsigned Flags = SPIRV::FPFastMathMode::None;
  if (I.getFlag(MachineInstr::MIFlag::FmNoNans))
    Flags |= SPIRV::FPFastMathMode::NotNaN;
  if (I.getFlag(MachineInstr::MIFlag::FmNoInfs))
    Flags |= SPIRV::FPFastMathMode::NotInf;
  if (I.getFlag(MachineInstr::MIFlag::FmNsz))
    Flags |= SPIRV::FPFastMathMode::NSZ;
  if (I.getFlag(MachineInstr::MIFlag::FmArcp))
    Flags |= SPIRV::FPFastMathMode::AllowRecip;
  if (I.getFlag(MachineInstr::MIFlag::FmReassoc))
    Flags |= SPIRV::FPFastMathMode::Fast;
  return Flags;
}

// Build OpDecorate instructions (NoSignedWrap / NoUnsignedWrap /
// FPFastMathMode) for I's destination register when the corresponding MI
// flag is set and the decoration's requirements are satisfiable on ST.
static void handleMIFlagDecoration(MachineInstr &I, const SPIRVSubtarget &ST,
                                   const SPIRVInstrInfo &TII,
                                   SPIRV::RequirementHandler &Reqs) {
  if (I.getFlag(MachineInstr::MIFlag::NoSWrap) && TII.canUseNSW(I) &&
      getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,
                                     SPIRV::Decoration::NoSignedWrap, ST, Reqs)
          .IsSatisfiable) {
    buildOpDecorate(I.getOperand(0).getReg(), I, TII,
                    SPIRV::Decoration::NoSignedWrap, {});
  }
  if (I.getFlag(MachineInstr::MIFlag::NoUWrap) && TII.canUseNUW(I) &&
      getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,
                                     SPIRV::Decoration::NoUnsignedWrap, ST,
                                     Reqs)
          .IsSatisfiable) {
    buildOpDecorate(I.getOperand(0).getReg(), I, TII,
                    SPIRV::Decoration::NoUnsignedWrap, {});
  }
  if (!TII.canUseFastMathFlags(I))
    return;
  // Only emit FPFastMathMode when at least one fast-math bit is set.
  unsigned FMFlags = getFastMathFlags(I);
  if (FMFlags == SPIRV::FPFastMathMode::None)
    return;
  Register DstReg = I.getOperand(0).getReg();
  buildOpDecorate(DstReg, I, TII, SPIRV::Decoration::FPFastMathMode, {FMFlags});
}

// Walk all functions and add decorations related to MI flags.
static void addDecorations(const Module &M, const SPIRVInstrInfo &TII,
                           MachineModuleInfo *MMI, const SPIRVSubtarget &ST,
                           SPIRV::ModuleAnalysisInfo &MAI) {
  for (auto F = M.begin(), E = M.end(); F != E; ++F) {
    MachineFunction *MF = MMI->getMachineFunction(*F);
    if (!MF)
      continue;
    for (auto &MBB : *MF)
      for (auto &MI : MBB)
        handleMIFlagDecoration(MI, ST, TII, MAI.Reqs);
  }
}

// Emit an OpName for every named, non-empty machine basic block and record
// a global register alias for the block's id register in MAI.
static void addMBBNames(const Module &M, const SPIRVInstrInfo &TII,
                        MachineModuleInfo *MMI, const SPIRVSubtarget &ST,
                        SPIRV::ModuleAnalysisInfo &MAI) {
  for (auto F = M.begin(), E = M.end(); F != E; ++F) {
    MachineFunction *MF = MMI->getMachineFunction(*F);
    if (!MF)
      continue;
    MachineRegisterInfo &MRI = MF->getRegInfo();
    for (auto &MBB : *MF) {
      if (!MBB.hasName() || MBB.empty())
        continue;
      // Emit basic block names.
      Register Reg = MRI.createGenericVirtualRegister(LLT::scalar(64));
      MRI.setRegClass(Reg, &SPIRV::IDRegClass);
      // Insert the OpName before the block's last instruction (the block is
      // known non-empty here).
      buildOpName(Reg, MBB.getName(), *std::prev(MBB.end()), TII);
      Register GlobalReg = MAI.getOrCreateMBBRegister(MBB);
      MAI.setRegisterAlias(MF, Reg, GlobalReg);
    }
  }
}

// patching Instruction::PHI to SPIRV::OpPhi
static void patchPhis(const Module &M, SPIRVGlobalRegistry *GR,
                      const SPIRVInstrInfo &TII, MachineModuleInfo *MMI) {
  for (auto F = M.begin(), E = M.end(); F != E; ++F) {
    MachineFunction *MF = MMI->getMachineFunction(*F);
    if (!MF)
      continue;
    for (auto &MBB : *MF) {
      for (MachineInstr &MI : MBB) {
        if (MI.getOpcode() != TargetOpcode::PHI)
          continue;
        MI.setDesc(TII.get(SPIRV::OpPhi));
        // OpPhi carries its result type as an explicit operand; insert it
        // right after the destination register.
        Register ResTypeReg = GR->getSPIRVTypeID(
            GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg(), MF));
        MI.insert(MI.operands_begin() + 1,
                  {MachineOperand::CreateReg(ResTypeReg, false)});
      }
    }
  }
}

// Module-level analysis results shared with the AsmPrinter.
struct SPIRV::ModuleAnalysisInfo SPIRVModuleAnalysis::MAI;

void SPIRVModuleAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequired<TargetPassConfig>();
  AU.addRequired<MachineModuleInfoWrapperPass>();
}

// Entry point of the analysis: collects module-level instructions and
// performs global register numbering over all machine functions.
bool SPIRVModuleAnalysis::runOnModule(Module &M) {
  SPIRVTargetMachine &TM =
      getAnalysis<TargetPassConfig>().getTM<SPIRVTargetMachine>();
  ST = TM.getSubtargetImpl();
  GR = ST->getSPIRVGlobalRegistry();
  TII = ST->getInstrInfo();

  MMI = &getAnalysis<MachineModuleInfoWrapperPass>().getMMI();

  setBaseInfo(M);

  patchPhis(M, GR, *TII, MMI);

  addMBBNames(M, *TII, MMI, *ST, MAI);
  addDecorations(M, *TII, MMI, *ST, MAI);

  collectReqs(M, MAI, MMI, *ST);

  // Process type/const/global var/func decl instructions, number their
  // destination registers from 0 to N, collect Extensions and Capabilities.
1821 processDefInstrs(M); 1822 1823 // Number rest of registers from N+1 onwards. 1824 numberRegistersGlobally(M); 1825 1826 // Update references to OpFunction instructions to use Global Registers 1827 if (GR->hasConstFunPtr()) 1828 collectFuncPtrs(); 1829 1830 // Collect OpName, OpEntryPoint, OpDecorate etc, process other instructions. 1831 processOtherInstrs(M); 1832 1833 // If there are no entry points, we need the Linkage capability. 1834 if (MAI.MS[SPIRV::MB_EntryPoints].empty()) 1835 MAI.Reqs.addCapability(SPIRV::Capability::Linkage); 1836 1837 // Set maximum ID used. 1838 GR->setBound(MAI.MaxID); 1839 1840 return false; 1841 } 1842