//===- SPIRVModuleAnalysis.cpp - analysis of global instrs & regs - C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// The analysis collects instructions that should be output at the module level
// and performs the global register numbering.
//
// The results of this analysis are used in AsmPrinter to rename registers
// globally and to output required instructions at the module level.
//
//===----------------------------------------------------------------------===//

#include "SPIRVModuleAnalysis.h"
#include "MCTargetDesc/SPIRVBaseInfo.h"
#include "MCTargetDesc/SPIRVMCTargetDesc.h"
#include "SPIRV.h"
#include "SPIRVSubtarget.h"
#include "SPIRVTargetMachine.h"
#include "SPIRVUtils.h"
#include "TargetInfo/SPIRVTargetInfo.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
#include "llvm/CodeGen/TargetPassConfig.h"

using namespace llvm;

#define DEBUG_TYPE "spirv-module-analysis"

static cl::opt<bool>
    SPVDumpDeps("spv-dump-deps",
                cl::desc("Dump MIR with SPIR-V dependencies info"),
                cl::Optional, cl::init(false));

static cl::list<SPIRV::Capability::Capability>
    AvoidCapabilities("avoid-spirv-capabilities",
                      cl::desc("SPIR-V capabilities to avoid if there are "
                               "other options enabling a feature"),
                      cl::ZeroOrMore, cl::Hidden,
                      cl::values(clEnumValN(SPIRV::Capability::Shader, "Shader",
                                            "SPIR-V Shader capability")));

// Use a set instead of cl::list to check the "contains" condition.
struct AvoidCapabilitiesSet {
  SmallSet<SPIRV::Capability::Capability, 4> S;
  AvoidCapabilitiesSet() {
    for (auto Cap : AvoidCapabilities)
      S.insert(Cap);
  }
};

char llvm::SPIRVModuleAnalysis::ID = 0;

namespace llvm {
void initializeSPIRVModuleAnalysisPass(PassRegistry &);
} // namespace llvm

INITIALIZE_PASS(SPIRVModuleAnalysis, DEBUG_TYPE, "SPIRV module analysis", true,
                true)

// Retrieve an unsigned from an MDNode with a list of them as operands.
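// Returns DefaultVal if the node is null or the operand index is out of range.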
static unsigned getMetadataUInt(MDNode *MdNode, unsigned OpIndex,
                                unsigned DefaultVal = 0) {
  if (MdNode && OpIndex < MdNode->getNumOperands()) {
    const auto &Op = MdNode->getOperand(OpIndex);
    return mdconst::extract<ConstantInt>(Op)->getZExtValue();
  }
  return DefaultVal;
}

static SPIRV::Requirements
getSymbolicOperandRequirements(SPIRV::OperandCategory::OperandCategory Category,
                               unsigned i, const SPIRVSubtarget &ST,
                               SPIRV::RequirementHandler &Reqs) {
  static AvoidCapabilitiesSet
      AvoidCaps; // contains capabilities to avoid if there is another option
  unsigned ReqMinVer = getSymbolicOperandMinVersion(Category, i);
  unsigned ReqMaxVer = getSymbolicOperandMaxVersion(Category, i);
  unsigned TargetVer = ST.getSPIRVVersion();
  bool MinVerOK = !ReqMinVer || !TargetVer || TargetVer >= ReqMinVer;
  bool MaxVerOK = !ReqMaxVer || !TargetVer || TargetVer <= ReqMaxVer;
  CapabilityList ReqCaps = getSymbolicOperandCapabilities(Category, i);
  ExtensionList ReqExts = getSymbolicOperandExtensions(Category, i);
  if (ReqCaps.empty()) {
    if (ReqExts.empty()) {
      if (MinVerOK && MaxVerOK)
        return {true, {}, {}, ReqMinVer, ReqMaxVer};
      return {false, {}, {}, 0, 0};
    }
  } else if (MinVerOK && MaxVerOK) {
    if (ReqCaps.size() == 1) {
      auto Cap = ReqCaps[0];
      if (Reqs.isCapabilityAvailable(Cap))
        return {true, {Cap}, {}, ReqMinVer, ReqMaxVer};
    } else {
      // By the SPIR-V specification: "If an instruction, enumerant, or other
      // feature specifies multiple enabling capabilities, only one such
      // capability needs to be declared to use the feature." However, one
      // capability may be preferred over another. We use command line
      // argument(s) and AvoidCapabilities to avoid selection of certain
      // capabilities if there are other options.
      CapabilityList UseCaps;
      for (auto Cap : ReqCaps)
        if (Reqs.isCapabilityAvailable(Cap))
          UseCaps.push_back(Cap);
      for (size_t i = 0, Sz = UseCaps.size(); i < Sz; ++i) {
        auto Cap = UseCaps[i];
        if (i == Sz - 1 || !AvoidCaps.S.contains(Cap))
          return {true, {Cap}, {}, ReqMinVer, ReqMaxVer};
      }
    }
  }
  // If there are no capabilities, or we can't satisfy the version or
  // capability requirements, use the list of extensions (if the subtarget
  // can handle them all).
  if (llvm::all_of(ReqExts, [&ST](const SPIRV::Extension::Extension &Ext) {
        return ST.canUseExtension(Ext);
      })) {
    return {true, {}, ReqExts, 0, 0}; // TODO: add versions to extensions.
  }
  return {false, {}, {}, 0, 0};
}

void SPIRVModuleAnalysis::setBaseInfo(const Module &M) {
  MAI.MaxID = 0;
  for (int i = 0; i < SPIRV::NUM_MODULE_SECTIONS; i++)
    MAI.MS[i].clear();
  MAI.RegisterAliasTable.clear();
  MAI.InstrsToDelete.clear();
  MAI.FuncMap.clear();
  MAI.GlobalVarList.clear();
  MAI.ExtInstSetMap.clear();
  MAI.Reqs.clear();
  MAI.Reqs.initAvailableCapabilities(*ST);

  // TODO: determine memory model and source language from the configuration.
  if (auto MemModel = M.getNamedMetadata("spirv.MemoryModel")) {
    auto MemMD = MemModel->getOperand(0);
    MAI.Addr = static_cast<SPIRV::AddressingModel::AddressingModel>(
        getMetadataUInt(MemMD, 0));
    MAI.Mem =
        static_cast<SPIRV::MemoryModel::MemoryModel>(getMetadataUInt(MemMD, 1));
  } else {
    // TODO: Add support for VulkanMemoryModel.
    MAI.Mem = ST->isOpenCLEnv() ? SPIRV::MemoryModel::OpenCL
                                : SPIRV::MemoryModel::GLSL450;
    if (MAI.Mem == SPIRV::MemoryModel::OpenCL) {
      unsigned PtrSize = ST->getPointerSize();
      MAI.Addr = PtrSize == 32   ? SPIRV::AddressingModel::Physical32
                 : PtrSize == 64 ? SPIRV::AddressingModel::Physical64
                                 : SPIRV::AddressingModel::Logical;
    } else {
      // TODO: Add support for PhysicalStorageBufferAddress.
      MAI.Addr = SPIRV::AddressingModel::Logical;
    }
  }
  // Get the OpenCL version number from metadata.
  // TODO: support other source languages.
  if (auto VerNode = M.getNamedMetadata("opencl.ocl.version")) {
    MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_C;
    // Construct the version literal in accordance with SPIRV-LLVM-Translator.
    // TODO: support multiple OCL version metadata.
    assert(VerNode->getNumOperands() > 0 && "Invalid SPIR");
    auto VersionMD = VerNode->getOperand(0);
    unsigned MajorNum = getMetadataUInt(VersionMD, 0, 2);
    unsigned MinorNum = getMetadataUInt(VersionMD, 1);
    unsigned RevNum = getMetadataUInt(VersionMD, 2);
    // Prevent the major part of the OpenCL version from being 0.
    MAI.SrcLangVersion =
        (std::max(1U, MajorNum) * 100 + MinorNum) * 1000 + RevNum;
  } else {
    // If there is no information about the OpenCL version, we are forced to
    // generate OpenCL 1.0 by default for the OpenCL environment to avoid
    // puzzling run-times with an Unknown/0.0 version output. For a reference,
    // the LLVM-SPIRV Translator avoids potential issues with run-times in a
    // similar manner.
    if (ST->isOpenCLEnv()) {
      MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_CPP;
      MAI.SrcLangVersion = 100000;
    } else {
      MAI.SrcLang = SPIRV::SourceLanguage::Unknown;
      MAI.SrcLangVersion = 0;
    }
  }

  if (auto ExtNode = M.getNamedMetadata("opencl.used.extensions")) {
    for (unsigned I = 0, E = ExtNode->getNumOperands(); I != E; ++I) {
      MDNode *MD = ExtNode->getOperand(I);
      if (!MD || MD->getNumOperands() == 0)
        continue;
      for (unsigned J = 0, N = MD->getNumOperands(); J != N; ++J)
        MAI.SrcExt.insert(cast<MDString>(MD->getOperand(J))->getString());
    }
  }

  // Update required capabilities for this memory model, addressing model and
  // source language.
  MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand,
                                 MAI.Mem, *ST);
  MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::SourceLanguageOperand,
                                 MAI.SrcLang, *ST);
  MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,
                                 MAI.Addr, *ST);

  if (ST->isOpenCLEnv()) {
    // TODO: check if it's required by default.
    MAI.ExtInstSetMap[static_cast<unsigned>(
        SPIRV::InstructionSet::OpenCL_std)] =
        Register::index2VirtReg(MAI.getNextID());
  }
}

// Collect MI which defines the register in the given machine function.
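// The defining instruction is marked to be skipped during per-function
// emission and, when DoInsert is true, appended to the given module section.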
static void collectDefInstr(Register Reg, const MachineFunction *MF,
                            SPIRV::ModuleAnalysisInfo *MAI,
                            SPIRV::ModuleSectionType MSType,
                            bool DoInsert = true) {
  assert(MAI->hasRegisterAlias(MF, Reg) && "Cannot find register alias");
  MachineInstr *MI = MF->getRegInfo().getUniqueVRegDef(Reg);
  assert(MI && "There should be an instruction that defines the register");
  MAI->setSkipEmission(MI);
  if (DoInsert)
    MAI->MS[MSType].push_back(MI);
}

void SPIRVModuleAnalysis::collectGlobalEntities(
    const std::vector<SPIRV::DTSortableEntry *> &DepsGraph,
    SPIRV::ModuleSectionType MSType,
    std::function<bool(const SPIRV::DTSortableEntry *)> Pred,
    bool UsePreOrder = false) {
  DenseSet<const SPIRV::DTSortableEntry *> Visited;
  for (const auto *E : DepsGraph) {
    std::function<void(const SPIRV::DTSortableEntry *)> RecHoistUtil;
    // NOTE: here we prefer the recursive approach over the iterative one
    // because we don't expect dependency chains long enough to cause a stack
    // overflow.
    RecHoistUtil = [MSType, UsePreOrder, &Visited, &Pred,
                    &RecHoistUtil](const SPIRV::DTSortableEntry *E) {
      if (Visited.count(E) || !Pred(E))
        return;
      Visited.insert(E);

      // Traversing the deps graph in post-order allows us to get rid of
      // register alias preprocessing. However, pre-order is required for
      // correct processing of function declarations and their arguments.
      if (!UsePreOrder)
        for (auto *S : E->getDeps())
          RecHoistUtil(S);

      Register GlobalReg = Register::index2VirtReg(MAI.getNextID());
      bool IsFirst = true;
      for (auto &U : *E) {
        const MachineFunction *MF = U.first;
        Register Reg = U.second;
        MAI.setRegisterAlias(MF, Reg, GlobalReg);
        if (!MF->getRegInfo().getUniqueVRegDef(Reg))
          continue;
        collectDefInstr(Reg, MF, &MAI, MSType, IsFirst);
        IsFirst = false;
        if (E->getIsGV())
          MAI.GlobalVarList.push_back(MF->getRegInfo().getUniqueVRegDef(Reg));
      }

      if (UsePreOrder)
        for (auto *S : E->getDeps())
          RecHoistUtil(S);
    };
    RecHoistUtil(E);
  }
}

// The function initializes the global register alias table for types, consts,
// global vars and func decls and collects these instructions for output at
// the module level. It also collects explicit OpExtension/OpCapability
// instructions.
void SPIRVModuleAnalysis::processDefInstrs(const Module &M) {
  std::vector<SPIRV::DTSortableEntry *> DepsGraph;

  GR->buildDepsGraph(DepsGraph, SPVDumpDeps ? MMI : nullptr);

  collectGlobalEntities(
      DepsGraph, SPIRV::MB_TypeConstVars,
      [](const SPIRV::DTSortableEntry *E) { return !E->getIsFunc(); });

  for (auto F = M.begin(), E = M.end(); F != E; ++F) {
    MachineFunction *MF = MMI->getMachineFunction(*F);
    if (!MF)
      continue;
    // Iterate through and collect OpExtension/OpCapability instructions.
    for (MachineBasicBlock &MBB : *MF) {
      for (MachineInstr &MI : MBB) {
        if (MI.getOpcode() == SPIRV::OpExtension) {
          // Here, OpExtension just has a single enum operand, not a string.
          auto Ext = SPIRV::Extension::Extension(MI.getOperand(0).getImm());
          MAI.Reqs.addExtension(Ext);
          MAI.setSkipEmission(&MI);
        } else if (MI.getOpcode() == SPIRV::OpCapability) {
          auto Cap = SPIRV::Capability::Capability(MI.getOperand(0).getImm());
          MAI.Reqs.addCapability(Cap);
          MAI.setSkipEmission(&MI);
        }
      }
    }
  }

  collectGlobalEntities(
      DepsGraph, SPIRV::MB_ExtFuncDecls,
      [](const SPIRV::DTSortableEntry *E) { return E->getIsFunc(); }, true);
}

// Look for IDs declared with Import linkage, and map the corresponding
// function to the register defining that variable (which will usually be the
// result of an OpFunction). This lets us call externally imported functions
// using the correct ID registers.
void SPIRVModuleAnalysis::collectFuncNames(MachineInstr &MI,
                                           const Function *F) {
  if (MI.getOpcode() == SPIRV::OpDecorate) {
    // If it's got Import linkage.
    auto Dec = MI.getOperand(1).getImm();
    if (Dec == static_cast<unsigned>(SPIRV::Decoration::LinkageAttributes)) {
      auto Lnk = MI.getOperand(MI.getNumOperands() - 1).getImm();
      if (Lnk == static_cast<unsigned>(SPIRV::LinkageType::Import)) {
        // Map the imported function name to the function ID register.
        const Function *ImportedFunc =
            F->getParent()->getFunction(getStringImm(MI, 2));
        Register Target = MI.getOperand(0).getReg();
        MAI.FuncMap[ImportedFunc] = MAI.getRegisterAlias(MI.getMF(), Target);
      }
    }
  } else if (MI.getOpcode() == SPIRV::OpFunction) {
    // Record all internal OpFunction declarations.
    Register Reg = MI.defs().begin()->getReg();
    Register GlobalReg = MAI.getRegisterAlias(MI.getMF(), Reg);
    assert(GlobalReg.isValid());
    MAI.FuncMap[F] = GlobalReg;
  }
}

// References to a function via function pointers generate virtual
// registers without a definition. We are able to resolve such a reference
// into an OpFunction instruction using Global Register info and replace the
// dummy operands with the corresponding global register references.
void SPIRVModuleAnalysis::collectFuncPtrs() {
  for (auto &MI : MAI.MS[SPIRV::MB_TypeConstVars])
    if (MI->getOpcode() == SPIRV::OpConstantFunctionPointerINTEL)
      collectFuncPtrs(MI);
}

void SPIRVModuleAnalysis::collectFuncPtrs(MachineInstr *MI) {
  const MachineOperand *FunUse = &MI->getOperand(2);
  if (const MachineOperand *FunDef = GR->getFunctionDefinitionByUse(FunUse)) {
    const MachineInstr *FunDefMI = FunDef->getParent();
    assert(FunDefMI->getOpcode() == SPIRV::OpFunction &&
           "Constant function pointer must refer to function definition");
    Register FunDefReg = FunDef->getReg();
    Register GlobalFunDefReg =
        MAI.getRegisterAlias(FunDefMI->getMF(), FunDefReg);
    assert(GlobalFunDefReg.isValid() &&
           "Function definition must refer to a global register");
    Register FunPtrReg = FunUse->getReg();
    MAI.setRegisterAlias(MI->getMF(), FunPtrReg, GlobalFunDefReg);
  }
}

using InstrSignature = SmallVector<size_t>;
using InstrTraces = std::set<InstrSignature>;

// Returns a representation of an instruction as a vector of MachineOperand
// hash values, see llvm::hash_value(const MachineOperand &MO) for details.
// This creates a signature of the instruction with the same content
// that MachineOperand::isIdenticalTo uses for comparison.
static InstrSignature instrToSignature(MachineInstr &MI,
                                       SPIRV::ModuleAnalysisInfo &MAI) {
  InstrSignature Signature;
  for (unsigned i = 0; i < MI.getNumOperands(); ++i) {
    const MachineOperand &MO = MI.getOperand(i);
    size_t h;
    if (MO.isReg()) {
      Register RegAlias = MAI.getRegisterAlias(MI.getMF(), MO.getReg());
      // mimic llvm::hash_value(const MachineOperand &MO)
      h = hash_combine(MO.getType(), (unsigned)RegAlias, MO.getSubReg(),
                       MO.isDef());
    } else {
      h = hash_value(MO);
    }
    Signature.push_back(h);
  }
  return Signature;
}

// Collect the given instruction in the specified MS. We assume global register
// numbering has already occurred by this point. We can directly compare reg
// arguments when detecting duplicates.
static void collectOtherInstr(MachineInstr &MI, SPIRV::ModuleAnalysisInfo &MAI,
                              SPIRV::ModuleSectionType MSType, InstrTraces &IS,
                              bool Append = true) {
  MAI.setSkipEmission(&MI);
  InstrSignature MISign = instrToSignature(MI, MAI);
  auto FoundMI = IS.insert(MISign);
  if (!FoundMI.second)
    return; // insert failed, so we found a duplicate; don't add it to MAI.MS
  // No duplicates, so add it.
  if (Append)
    MAI.MS[MSType].push_back(&MI);
  else
    MAI.MS[MSType].insert(MAI.MS[MSType].begin(), &MI);
}

// Some global instructions make reference to function-local ID regs, so cannot
// be correctly collected until these registers are globally numbered.
void SPIRVModuleAnalysis::processOtherInstrs(const Module &M) {
  InstrTraces IS;
  for (auto F = M.begin(), E = M.end(); F != E; ++F) {
    if ((*F).isDeclaration())
      continue;
    MachineFunction *MF = MMI->getMachineFunction(*F);
    assert(MF);
    for (MachineBasicBlock &MBB : *MF)
      for (MachineInstr &MI : MBB) {
        if (MAI.getSkipEmission(&MI))
          continue;
        const unsigned OpCode = MI.getOpcode();
        if (OpCode == SPIRV::OpName || OpCode == SPIRV::OpMemberName) {
          collectOtherInstr(MI, MAI, SPIRV::MB_DebugNames, IS);
        } else if (OpCode == SPIRV::OpEntryPoint) {
          collectOtherInstr(MI, MAI, SPIRV::MB_EntryPoints, IS);
        } else if (TII->isDecorationInstr(MI)) {
          collectOtherInstr(MI, MAI, SPIRV::MB_Annotations, IS);
          collectFuncNames(MI, &*F);
        } else if (TII->isConstantInstr(MI)) {
          // Now OpSpecConstant*s are not in DT,
          // but they need to be collected anyway.
          collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, IS);
        } else if (OpCode == SPIRV::OpFunction) {
          collectFuncNames(MI, &*F);
        } else if (OpCode == SPIRV::OpTypeForwardPointer) {
          collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, IS, false);
        }
      }
  }
}

// Number registers in all functions globally from 0 onwards and store
// the result in the global register alias table. Some registers are already
// numbered in collectGlobalEntities.
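// Extended instruction sets referenced by OpExtInst also get a global ID here
// via ExtInstSetMap.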
void SPIRVModuleAnalysis::numberRegistersGlobally(const Module &M) {
  for (auto F = M.begin(), E = M.end(); F != E; ++F) {
    if ((*F).isDeclaration())
      continue;
    MachineFunction *MF = MMI->getMachineFunction(*F);
    assert(MF);
    for (MachineBasicBlock &MBB : *MF) {
      for (MachineInstr &MI : MBB) {
        for (MachineOperand &Op : MI.operands()) {
          if (!Op.isReg())
            continue;
          Register Reg = Op.getReg();
          if (MAI.hasRegisterAlias(MF, Reg))
            continue;
          Register NewReg = Register::index2VirtReg(MAI.getNextID());
          MAI.setRegisterAlias(MF, Reg, NewReg);
        }
        if (MI.getOpcode() != SPIRV::OpExtInst)
          continue;
        auto Set = MI.getOperand(2).getImm();
        if (!MAI.ExtInstSetMap.contains(Set))
          MAI.ExtInstSetMap[Set] = Register::index2VirtReg(MAI.getNextID());
      }
    }
  }
}

// RequirementHandler implementations.
void SPIRV::RequirementHandler::getAndAddRequirements(
    SPIRV::OperandCategory::OperandCategory Category, uint32_t i,
    const SPIRVSubtarget &ST) {
  addRequirements(getSymbolicOperandRequirements(Category, i, ST, *this));
}

void SPIRV::RequirementHandler::recursiveAddCapabilities(
    const CapabilityList &ToPrune) {
  for (const auto &Cap : ToPrune) {
    AllCaps.insert(Cap);
    CapabilityList ImplicitDecls =
        getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);
    recursiveAddCapabilities(ImplicitDecls);
  }
}

void SPIRV::RequirementHandler::addCapabilities(const CapabilityList &ToAdd) {
  for (const auto &Cap : ToAdd) {
    bool IsNewlyInserted = AllCaps.insert(Cap).second;
    if (!IsNewlyInserted) // Don't re-add if it's already been declared.
      continue;
    CapabilityList ImplicitDecls =
        getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);
    recursiveAddCapabilities(ImplicitDecls);
    MinimalCaps.push_back(Cap);
  }
}

void SPIRV::RequirementHandler::addRequirements(
    const SPIRV::Requirements &Req) {
  if (!Req.IsSatisfiable)
    report_fatal_error("Adding SPIR-V requirements this target can't satisfy.");

  if (Req.Cap.has_value())
    addCapabilities({Req.Cap.value()});

  addExtensions(Req.Exts);

  if (Req.MinVer) {
    if (MaxVersion && Req.MinVer > MaxVersion) {
      LLVM_DEBUG(dbgs() << "Conflicting version requirements: >= " << Req.MinVer
                        << " and <= " << MaxVersion << "\n");
      report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");
    }

    if (MinVersion == 0 || Req.MinVer > MinVersion)
      MinVersion = Req.MinVer;
  }

  if (Req.MaxVer) {
    if (MinVersion && Req.MaxVer < MinVersion) {
      LLVM_DEBUG(dbgs() << "Conflicting version requirements: <= " << Req.MaxVer
                        << " and >= " << MinVersion << "\n");
      report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");
    }

    if (MaxVersion == 0 || Req.MaxVer < MaxVersion)
      MaxVersion = Req.MaxVer;
  }
}

void SPIRV::RequirementHandler::checkSatisfiable(
    const SPIRVSubtarget &ST) const {
  // Report as many errors as possible before aborting the compilation.
  bool IsSatisfiable = true;
  auto TargetVer = ST.getSPIRVVersion();

  if (MaxVersion && TargetVer && MaxVersion < TargetVer) {
    LLVM_DEBUG(
        dbgs() << "Target SPIR-V version too high for required features\n"
               << "Required max version: " << MaxVersion << " target version "
               << TargetVer << "\n");
    IsSatisfiable = false;
  }

  if (MinVersion && TargetVer && MinVersion > TargetVer) {
    LLVM_DEBUG(dbgs() << "Target SPIR-V version too low for required features\n"
                      << "Required min version: " << MinVersion
                      << " target version " << TargetVer << "\n");
    IsSatisfiable = false;
  }

  if (MinVersion && MaxVersion && MinVersion > MaxVersion) {
    LLVM_DEBUG(
        dbgs()
        << "Version is too low for some features and too high for others.\n"
        << "Required SPIR-V min version: " << MinVersion
        << " required SPIR-V max version " << MaxVersion << "\n");
    IsSatisfiable = false;
  }

  for (auto Cap : MinimalCaps) {
    if (AvailableCaps.contains(Cap))
      continue;
    LLVM_DEBUG(dbgs() << "Capability not supported: "
                      << getSymbolicOperandMnemonic(
                             OperandCategory::CapabilityOperand, Cap)
                      << "\n");
    IsSatisfiable = false;
  }

  for (auto Ext : AllExtensions) {
    if (ST.canUseExtension(Ext))
      continue;
    LLVM_DEBUG(dbgs() << "Extension not supported: "
                      << getSymbolicOperandMnemonic(
                             OperandCategory::ExtensionOperand, Ext)
                      << "\n");
    IsSatisfiable = false;
  }

  if (!IsSatisfiable)
    report_fatal_error("Unable to meet SPIR-V requirements for this target.");
}

// Add the given capabilities and all their implicitly defined capabilities too.
void SPIRV::RequirementHandler::addAvailableCaps(const CapabilityList &ToAdd) {
  for (const auto Cap : ToAdd)
    if (AvailableCaps.insert(Cap).second)
      addAvailableCaps(getSymbolicOperandCapabilities(
          SPIRV::OperandCategory::CapabilityOperand, Cap));
}

void SPIRV::RequirementHandler::removeCapabilityIf(
    const Capability::Capability ToRemove,
    const Capability::Capability IfPresent) {
  if (AllCaps.contains(IfPresent))
    AllCaps.erase(ToRemove);
}

namespace llvm {
namespace SPIRV {
void RequirementHandler::initAvailableCapabilities(const SPIRVSubtarget &ST) {
  if (ST.isOpenCLEnv()) {
    initAvailableCapabilitiesForOpenCL(ST);
    return;
  }

  if (ST.isVulkanEnv()) {
    initAvailableCapabilitiesForVulkan(ST);
    return;
  }

  report_fatal_error("Unimplemented environment for SPIR-V generation.");
}

void RequirementHandler::initAvailableCapabilitiesForOpenCL(
    const SPIRVSubtarget &ST) {
  // Add the min requirements for different OpenCL and SPIR-V versions.
  addAvailableCaps({Capability::Addresses, Capability::Float16Buffer,
                    Capability::Int16, Capability::Int8, Capability::Kernel,
                    Capability::Linkage, Capability::Vector16,
                    Capability::Groups, Capability::GenericPointer,
                    Capability::Shader});
  if (ST.hasOpenCLFullProfile())
    addAvailableCaps({Capability::Int64, Capability::Int64Atomics});
  if (ST.hasOpenCLImageSupport()) {
    addAvailableCaps({Capability::ImageBasic, Capability::LiteralSampler,
                      Capability::Image1D, Capability::SampledBuffer,
                      Capability::ImageBuffer});
    if (ST.isAtLeastOpenCLVer(20))
      addAvailableCaps({Capability::ImageReadWrite});
  }
  if (ST.isAtLeastSPIRVVer(11) && ST.isAtLeastOpenCLVer(22))
    addAvailableCaps({Capability::SubgroupDispatch, Capability::PipeStorage});
  if (ST.isAtLeastSPIRVVer(13))
    addAvailableCaps({Capability::GroupNonUniform,
                      Capability::GroupNonUniformVote,
                      Capability::GroupNonUniformArithmetic,
                      Capability::GroupNonUniformBallot,
                      Capability::GroupNonUniformClustered,
                      Capability::GroupNonUniformShuffle,
                      Capability::GroupNonUniformShuffleRelative});
  if (ST.isAtLeastSPIRVVer(14))
    addAvailableCaps({Capability::DenormPreserve, Capability::DenormFlushToZero,
                      Capability::SignedZeroInfNanPreserve,
                      Capability::RoundingModeRTE,
                      Capability::RoundingModeRTZ});
  // TODO: verify if this needs some checks.
  addAvailableCaps({Capability::Float16, Capability::Float64});

  // Add capabilities enabled by extensions.
  for (auto Extension : ST.getAllAvailableExtensions()) {
    CapabilityList EnabledCapabilities =
        getCapabilitiesEnabledByExtension(Extension);
    addAvailableCaps(EnabledCapabilities);
  }

  // TODO: add OpenCL extensions.
}

void RequirementHandler::initAvailableCapabilitiesForVulkan(
    const SPIRVSubtarget &ST) {
  addAvailableCaps({Capability::Shader, Capability::Linkage});

  // Provided by all supported Vulkan versions.
  addAvailableCaps({Capability::Int16, Capability::Int64, Capability::Float16,
                    Capability::Float64, Capability::GroupNonUniform});
}

} // namespace SPIRV
} // namespace llvm

// Add the required capabilities from a decoration instruction (including
// BuiltIns).
static void addOpDecorateReqs(const MachineInstr &MI, unsigned DecIndex,
                              SPIRV::RequirementHandler &Reqs,
                              const SPIRVSubtarget &ST) {
  int64_t DecOp = MI.getOperand(DecIndex).getImm();
  auto Dec = static_cast<SPIRV::Decoration::Decoration>(DecOp);
  Reqs.addRequirements(getSymbolicOperandRequirements(
      SPIRV::OperandCategory::DecorationOperand, Dec, ST, Reqs));

  if (Dec == SPIRV::Decoration::BuiltIn) {
    int64_t BuiltInOp = MI.getOperand(DecIndex + 1).getImm();
    auto BuiltIn = static_cast<SPIRV::BuiltIn::BuiltIn>(BuiltInOp);
    Reqs.addRequirements(getSymbolicOperandRequirements(
        SPIRV::OperandCategory::BuiltInOperand, BuiltIn, ST, Reqs));
  } else if (Dec == SPIRV::Decoration::LinkageAttributes) {
    int64_t LinkageOp = MI.getOperand(MI.getNumOperands() - 1).getImm();
    SPIRV::LinkageType::LinkageType LnkType =
        static_cast<SPIRV::LinkageType::LinkageType>(LinkageOp);
    if (LnkType == SPIRV::LinkageType::LinkOnceODR)
      Reqs.addExtension(SPIRV::Extension::SPV_KHR_linkonce_odr);
  }
}

// Add requirements for image handling.
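// Operand layout after the result ID: sampled type, dim, depth, arrayed, MS,
// sampled, image format and an optional access qualifier, which is why
// operands 2 and 4-8 are inspected below.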
static void addOpTypeImageReqs(const MachineInstr &MI,
                               SPIRV::RequirementHandler &Reqs,
                               const SPIRVSubtarget &ST) {
  assert(MI.getNumOperands() >= 8 && "Insufficient operands for OpTypeImage");
  // The operand indices used here are based on the OpTypeImage layout, which
  // the MachineInstr follows as well.
  int64_t ImgFormatOp = MI.getOperand(7).getImm();
  auto ImgFormat = static_cast<SPIRV::ImageFormat::ImageFormat>(ImgFormatOp);
  Reqs.getAndAddRequirements(SPIRV::OperandCategory::ImageFormatOperand,
                             ImgFormat, ST);

  bool IsArrayed = MI.getOperand(4).getImm() == 1;
  bool IsMultisampled = MI.getOperand(5).getImm() == 1;
  bool NoSampler = MI.getOperand(6).getImm() == 2;
  // Add dimension requirements.
  assert(MI.getOperand(2).isImm());
  switch (MI.getOperand(2).getImm()) {
  case SPIRV::Dim::DIM_1D:
    Reqs.addRequirements(NoSampler ? SPIRV::Capability::Image1D
                                   : SPIRV::Capability::Sampled1D);
    break;
  case SPIRV::Dim::DIM_2D:
    if (IsMultisampled && NoSampler)
      Reqs.addRequirements(SPIRV::Capability::ImageMSArray);
    break;
  case SPIRV::Dim::DIM_Cube:
    Reqs.addRequirements(SPIRV::Capability::Shader);
    if (IsArrayed)
      Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageCubeArray
                                     : SPIRV::Capability::SampledCubeArray);
    break;
  case SPIRV::Dim::DIM_Rect:
    Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageRect
                                   : SPIRV::Capability::SampledRect);
    break;
  case SPIRV::Dim::DIM_Buffer:
    Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageBuffer
                                   : SPIRV::Capability::SampledBuffer);
    break;
  case SPIRV::Dim::DIM_SubpassData:
    Reqs.addRequirements(SPIRV::Capability::InputAttachment);
    break;
  }

  // Has optional access qualifier.
  // TODO: check if it's OpenCL's kernel.
  if (MI.getNumOperands() > 8 &&
      MI.getOperand(8).getImm() == SPIRV::AccessQualifier::ReadWrite)
    Reqs.addRequirements(SPIRV::Capability::ImageReadWrite);
  else
    Reqs.addRequirements(SPIRV::Capability::ImageBasic);
}

// Add requirements for handling atomic float instructions
#define ATOM_FLT_REQ_EXT_MSG(ExtName)                                          \
  "The atomic float instruction requires the following SPIR-V "               \
  "extension: SPV_EXT_shader_atomic_float" ExtName
static void AddAtomicFloatRequirements(const MachineInstr &MI,
                                       SPIRV::RequirementHandler &Reqs,
                                       const SPIRVSubtarget &ST) {
  assert(MI.getOperand(1).isReg() &&
         "Expect register operand in atomic float instruction");
  Register TypeReg = MI.getOperand(1).getReg();
  SPIRVType *TypeDef = MI.getMF()->getRegInfo().getVRegDef(TypeReg);
  if (TypeDef->getOpcode() != SPIRV::OpTypeFloat)
    report_fatal_error("Result type of an atomic float instruction must be a "
                       "floating-point type scalar");

  unsigned BitWidth = TypeDef->getOperand(1).getImm();
  unsigned Op = MI.getOpcode();
  if (Op == SPIRV::OpAtomicFAddEXT) {
    if (!ST.canUseExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add))
      report_fatal_error(ATOM_FLT_REQ_EXT_MSG("_add"), false);
    Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add);
    switch (BitWidth) {
    case 16:
      if (!ST.canUseExtension(
              SPIRV::Extension::SPV_EXT_shader_atomic_float16_add))
        report_fatal_error(ATOM_FLT_REQ_EXT_MSG("16_add"), false);
      Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float16_add);
      Reqs.addCapability(SPIRV::Capability::AtomicFloat16AddEXT);
      break;
    case 32:
      Reqs.addCapability(SPIRV::Capability::AtomicFloat32AddEXT);
      break;
    case 64:
      Reqs.addCapability(SPIRV::Capability::AtomicFloat64AddEXT);
      break;
    default:
      report_fatal_error(
          "Unexpected floating-point type width in atomic float instruction");
    }
  } else {
    if (!ST.canUseExtension(
            SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max))
      report_fatal_error(ATOM_FLT_REQ_EXT_MSG("_min_max"), false);
    Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max);
    switch (BitWidth) {
    case 16:
      Reqs.addCapability(SPIRV::Capability::AtomicFloat16MinMaxEXT);
      break;
    case 32:
      Reqs.addCapability(SPIRV::Capability::AtomicFloat32MinMaxEXT);
      break;
    case 64:
      Reqs.addCapability(SPIRV::Capability::AtomicFloat64MinMaxEXT);
      break;
    default:
      report_fatal_error(
          "Unexpected floating-point type width in atomic float instruction");
    }
  }
}

void addInstrRequirements(const MachineInstr &MI,
                          SPIRV::RequirementHandler &Reqs,
                          const SPIRVSubtarget &ST) {
  switch (MI.getOpcode()) {
  case SPIRV::OpMemoryModel: {
    int64_t Addr = MI.getOperand(0).getImm();
    Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,
                               Addr, ST);
    int64_t Mem = MI.getOperand(1).getImm();
    Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand, Mem,
                               ST);
    break;
  }
  case SPIRV::OpEntryPoint: {
    int64_t Exe = MI.getOperand(0).getImm();
    Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModelOperand,
                               Exe, ST);
    break;
  }
  case SPIRV::OpExecutionMode:
  case SPIRV::OpExecutionModeId: {
    int64_t Exe = MI.getOperand(1).getImm();
    Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModeOperand,
                               Exe, ST);
    break;
  }
  case SPIRV::OpTypeMatrix:
    Reqs.addCapability(SPIRV::Capability::Matrix);
    break;
  case SPIRV::OpTypeInt: {
    unsigned BitWidth = MI.getOperand(1).getImm();
    if (BitWidth == 64)
      Reqs.addCapability(SPIRV::Capability::Int64);
    else if (BitWidth == 16)
      Reqs.addCapability(SPIRV::Capability::Int16);
    else if (BitWidth == 8)
      Reqs.addCapability(SPIRV::Capability::Int8);
    break;
  }
  case SPIRV::OpTypeFloat: {
    unsigned BitWidth = MI.getOperand(1).getImm();
    if (BitWidth == 64)
      Reqs.addCapability(SPIRV::Capability::Float64);
    else if (BitWidth == 16)
      Reqs.addCapability(SPIRV::Capability::Float16);
    break;
  }
  case SPIRV::OpTypeVector: {
    unsigned NumComponents = MI.getOperand(2).getImm();
    if (NumComponents == 8 || NumComponents == 16)
      Reqs.addCapability(SPIRV::Capability::Vector16);
    break;
  }
  case SPIRV::OpTypePointer: {
    auto SC = MI.getOperand(1).getImm();
    Reqs.getAndAddRequirements(SPIRV::OperandCategory::StorageClassOperand, SC,
                               ST);
    // If it's a type of pointer to float16 targeting OpenCL, add Float16Buffer
    // capability.
    if (!ST.isOpenCLEnv())
      break;
    assert(MI.getOperand(2).isReg());
    const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
    SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(2).getReg());
    if (TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
        TypeDef->getOperand(1).getImm() == 16)
      Reqs.addCapability(SPIRV::Capability::Float16Buffer);
    break;
  }
  case SPIRV::OpBitReverse:
  case SPIRV::OpBitFieldInsert:
  case SPIRV::OpBitFieldSExtract:
  case SPIRV::OpBitFieldUExtract:
    if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions)) {
      Reqs.addCapability(SPIRV::Capability::Shader);
      break;
    }
    Reqs.addExtension(SPIRV::Extension::SPV_KHR_bit_instructions);
    Reqs.addCapability(SPIRV::Capability::BitInstructions);
    break;
  case SPIRV::OpTypeRuntimeArray:
    Reqs.addCapability(SPIRV::Capability::Shader);
    break;
  case SPIRV::OpTypeOpaque:
  case SPIRV::OpTypeEvent:
    Reqs.addCapability(SPIRV::Capability::Kernel);
    break;
  case SPIRV::OpTypePipe:
  case SPIRV::OpTypeReserveId:
    Reqs.addCapability(SPIRV::Capability::Pipes);
    break;
  case SPIRV::OpTypeDeviceEvent:
  case SPIRV::OpTypeQueue:
  case SPIRV::OpBuildNDRange:
    Reqs.addCapability(SPIRV::Capability::DeviceEnqueue);
    break;
  case SPIRV::OpDecorate:
  case SPIRV::OpDecorateId:
  case SPIRV::OpDecorateString:
    addOpDecorateReqs(MI, 1, Reqs, ST);
    break;
  case SPIRV::OpMemberDecorate:
  case SPIRV::OpMemberDecorateString:
    addOpDecorateReqs(MI, 2, Reqs, ST);
    break;
  case SPIRV::OpInBoundsPtrAccessChain:
    Reqs.addCapability(SPIRV::Capability::Addresses);
    break;
  case SPIRV::OpConstantSampler:
    Reqs.addCapability(SPIRV::Capability::LiteralSampler);
    break;
  case SPIRV::OpTypeImage:
    addOpTypeImageReqs(MI, Reqs, ST);
    break;
  case SPIRV::OpTypeSampler:
    Reqs.addCapability(SPIRV::Capability::ImageBasic);
    break;
  case SPIRV::OpTypeForwardPointer:
    // TODO: check if it's OpenCL's kernel.
    Reqs.addCapability(SPIRV::Capability::Addresses);
    break;
  case SPIRV::OpAtomicFlagTestAndSet:
  case SPIRV::OpAtomicLoad:
  case SPIRV::OpAtomicStore:
  case SPIRV::OpAtomicExchange:
  case SPIRV::OpAtomicCompareExchange:
  case SPIRV::OpAtomicIIncrement:
  case SPIRV::OpAtomicIDecrement:
  case SPIRV::OpAtomicIAdd:
  case SPIRV::OpAtomicISub:
  case SPIRV::OpAtomicUMin:
  case SPIRV::OpAtomicUMax:
  case SPIRV::OpAtomicSMin:
  case SPIRV::OpAtomicSMax:
  case SPIRV::OpAtomicAnd:
  case SPIRV::OpAtomicOr:
  case SPIRV::OpAtomicXor: {
    const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
    const MachineInstr *InstrPtr = &MI;
    if (MI.getOpcode() == SPIRV::OpAtomicStore) {
      assert(MI.getOperand(3).isReg());
      InstrPtr = MRI.getVRegDef(MI.getOperand(3).getReg());
      assert(InstrPtr && "Unexpected type instruction for OpAtomicStore");
    }
    assert(InstrPtr->getOperand(1).isReg() && "Unexpected operand in atomic");
    Register TypeReg = InstrPtr->getOperand(1).getReg();
    SPIRVType *TypeDef = MRI.getVRegDef(TypeReg);
    if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
      unsigned BitWidth = TypeDef->getOperand(1).getImm();
      if (BitWidth == 64)
        Reqs.addCapability(SPIRV::Capability::Int64Atomics);
    }
    break;
  }
  case SPIRV::OpGroupNonUniformIAdd:
  case SPIRV::OpGroupNonUniformFAdd:
  case SPIRV::OpGroupNonUniformIMul:
  case SPIRV::OpGroupNonUniformFMul:
  case SPIRV::OpGroupNonUniformSMin:
  case SPIRV::OpGroupNonUniformUMin:
  case SPIRV::OpGroupNonUniformFMin:
  case SPIRV::OpGroupNonUniformSMax:
  case SPIRV::OpGroupNonUniformUMax:
  case SPIRV::OpGroupNonUniformFMax:
  case SPIRV::OpGroupNonUniformBitwiseAnd:
  case SPIRV::OpGroupNonUniformBitwiseOr:
  case SPIRV::OpGroupNonUniformBitwiseXor:
  case SPIRV::OpGroupNonUniformLogicalAnd:
  case SPIRV::OpGroupNonUniformLogicalOr:
  case SPIRV::OpGroupNonUniformLogicalXor: {
    assert(MI.getOperand(3).isImm());
    int64_t GroupOp = MI.getOperand(3).getImm();
    switch (GroupOp) {
    case SPIRV::GroupOperation::Reduce:
    case SPIRV::GroupOperation::InclusiveScan:
    case SPIRV::GroupOperation::ExclusiveScan:
      Reqs.addCapability(SPIRV::Capability::Kernel);
      Reqs.addCapability(SPIRV::Capability::GroupNonUniformArithmetic);
      Reqs.addCapability(SPIRV::Capability::GroupNonUniformBallot);
      break;
    case SPIRV::GroupOperation::ClusteredReduce:
      Reqs.addCapability(SPIRV::Capability::GroupNonUniformClustered);
      break;
    case SPIRV::GroupOperation::PartitionedReduceNV:
    case SPIRV::GroupOperation::PartitionedInclusiveScanNV:
    case SPIRV::GroupOperation::PartitionedExclusiveScanNV:
      Reqs.addCapability(SPIRV::Capability::GroupNonUniformPartitionedNV);
      break;
    }
    break;
  }
  case SPIRV::OpGroupNonUniformShuffle:
  case SPIRV::OpGroupNonUniformShuffleXor:
    Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffle);
    break;
  case SPIRV::OpGroupNonUniformShuffleUp:
  case SPIRV::OpGroupNonUniformShuffleDown:
    Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffleRelative);
    break;
  case SPIRV::OpGroupAll:
  case SPIRV::OpGroupAny:
  case SPIRV::OpGroupBroadcast:
  case SPIRV::OpGroupIAdd:
  case SPIRV::OpGroupFAdd:
  case SPIRV::OpGroupFMin:
  case SPIRV::OpGroupUMin:
  case SPIRV::OpGroupSMin:
  case SPIRV::OpGroupFMax:
  case SPIRV::OpGroupUMax:
  case SPIRV::OpGroupSMax:
    Reqs.addCapability(SPIRV::Capability::Groups);
    break;
  case SPIRV::OpGroupNonUniformElect:
    Reqs.addCapability(SPIRV::Capability::GroupNonUniform);
    break;
  case SPIRV::OpGroupNonUniformAll:
  case SPIRV::OpGroupNonUniformAny:
  case SPIRV::OpGroupNonUniformAllEqual:
    Reqs.addCapability(SPIRV::Capability::GroupNonUniformVote);
    break;
  case SPIRV::OpGroupNonUniformBroadcast:
  case SPIRV::OpGroupNonUniformBroadcastFirst:
  case SPIRV::OpGroupNonUniformBallot:
  case SPIRV::OpGroupNonUniformInverseBallot:
  case SPIRV::OpGroupNonUniformBallotBitExtract:
  case SPIRV::OpGroupNonUniformBallotBitCount:
  case SPIRV::OpGroupNonUniformBallotFindLSB:
  case SPIRV::OpGroupNonUniformBallotFindMSB:
    Reqs.addCapability(SPIRV::Capability::GroupNonUniformBallot);
    break;
  case SPIRV::OpSubgroupShuffleINTEL:
  case SPIRV::OpSubgroupShuffleDownINTEL:
  case SPIRV::OpSubgroupShuffleUpINTEL:
  case SPIRV::OpSubgroupShuffleXorINTEL:
    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
      Reqs.addCapability(SPIRV::Capability::SubgroupShuffleINTEL);
    }
    break;
  case SPIRV::OpSubgroupBlockReadINTEL:
  case SPIRV::OpSubgroupBlockWriteINTEL:
    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
      Reqs.addCapability(SPIRV::Capability::SubgroupBufferBlockIOINTEL);
    }
    break;
  case SPIRV::OpSubgroupImageBlockReadINTEL:
  case SPIRV::OpSubgroupImageBlockWriteINTEL:
    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
      Reqs.addCapability(SPIRV::Capability::SubgroupImageBlockIOINTEL);
    }
    break;
  case SPIRV::OpAssumeTrueKHR:
  case SPIRV::OpExpectKHR:
    if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_expect_assume)) {
      Reqs.addExtension(SPIRV::Extension::SPV_KHR_expect_assume);
      Reqs.addCapability(SPIRV::Capability::ExpectAssumeKHR);
    }
    break;
  case SPIRV::OpPtrCastToCrossWorkgroupINTEL:
  case SPIRV::OpCrossWorkgroupCastToPtrINTEL:
    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)) {
      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes);
      Reqs.addCapability(SPIRV::Capability::USMStorageClassesINTEL);
    }
    break;
  case SPIRV::OpConstantFunctionPointerINTEL:
    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) {
      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
      Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL);
    }
    break;
  case SPIRV::OpGroupNonUniformRotateKHR:
    if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_subgroup_rotate))
      report_fatal_error("OpGroupNonUniformRotateKHR instruction requires the "
                         "following SPIR-V extension: SPV_KHR_subgroup_rotate",
                         false);
    Reqs.addExtension(SPIRV::Extension::SPV_KHR_subgroup_rotate);
    Reqs.addCapability(SPIRV::Capability::GroupNonUniformRotateKHR);
    Reqs.addCapability(SPIRV::Capability::GroupNonUniform);
    break;
  case SPIRV::OpGroupIMulKHR:
  case SPIRV::OpGroupFMulKHR:
  case SPIRV::OpGroupBitwiseAndKHR:
  case SPIRV::OpGroupBitwiseOrKHR:
  case SPIRV::OpGroupBitwiseXorKHR:
  case SPIRV::OpGroupLogicalAndKHR:
  case SPIRV::OpGroupLogicalOrKHR:
  case SPIRV::OpGroupLogicalXorKHR:
    if (ST.canUseExtension(
            SPIRV::Extension::SPV_KHR_uniform_group_instructions)) {
      Reqs.addExtension(SPIRV::Extension::SPV_KHR_uniform_group_instructions);
      Reqs.addCapability(SPIRV::Capability::GroupUniformArithmeticKHR);
    }
    break;
  case SPIRV::OpFunctionPointerCallINTEL:
    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) {
      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
      Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL);
    }
    break;
  case SPIRV::OpAtomicFAddEXT:
  case SPIRV::OpAtomicFMinEXT:
  case SPIRV::OpAtomicFMaxEXT:
    AddAtomicFloatRequirements(MI, Reqs, ST);
    break;
  case SPIRV::OpConvertBF16ToFINTEL:
  case SPIRV::OpConvertFToBF16INTEL:
    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bfloat16_conversion)) {
      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bfloat16_conversion);
      Reqs.addCapability(SPIRV::Capability::BFloat16ConversionINTEL);
    }
    break;
  case SPIRV::OpVariableLengthArrayINTEL:
  case SPIRV::OpSaveMemoryINTEL:
  case SPIRV::OpRestoreMemoryINTEL:
    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_variable_length_array)) {
      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_variable_length_array);
      Reqs.addCapability(SPIRV::Capability::VariableLengthArrayINTEL);
    }
    break;
  default:
    break;
  }

  // If we require capability Shader, then we can remove the requirement for
  // the BitInstructions capability, since Shader is a superset capability
  // of BitInstructions.
  Reqs.removeCapabilityIf(SPIRV::Capability::BitInstructions,
                          SPIRV::Capability::Shader);
}

static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI,
                        MachineModuleInfo *MMI, const SPIRVSubtarget &ST) {
  // Collect requirements for existing instructions.
  for (auto F = M.begin(), E = M.end(); F != E; ++F) {
    MachineFunction *MF = MMI->getMachineFunction(*F);
    if (!MF)
      continue;
    for (const MachineBasicBlock &MBB : *MF)
      for (const MachineInstr &MI : MBB)
        addInstrRequirements(MI, MAI.Reqs, ST);
  }
  // Collect requirements for OpExecutionMode instructions.
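  // Each operand of the named metadata is an MDNode whose second operand
  // holds the execution mode constant.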
  auto Node = M.getNamedMetadata("spirv.ExecutionMode");
  if (Node) {
    // SPV_KHR_float_controls is not available until v1.4
    bool RequireFloatControls = false, VerLower14 = !ST.isAtLeastSPIRVVer(14);
    for (unsigned i = 0; i < Node->getNumOperands(); i++) {
      MDNode *MDN = cast<MDNode>(Node->getOperand(i));
      const MDOperand &MDOp = MDN->getOperand(1);
      if (auto *CMeta = dyn_cast<ConstantAsMetadata>(MDOp)) {
        Constant *C = CMeta->getValue();
        if (ConstantInt *Const = dyn_cast<ConstantInt>(C)) {
          auto EM = Const->getZExtValue();
          MAI.Reqs.getAndAddRequirements(
              SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
          // add SPV_KHR_float_controls if the version is too low
          switch (EM) {
          case SPIRV::ExecutionMode::DenormPreserve:
          case SPIRV::ExecutionMode::DenormFlushToZero:
          case SPIRV::ExecutionMode::SignedZeroInfNanPreserve:
          case SPIRV::ExecutionMode::RoundingModeRTE:
          case SPIRV::ExecutionMode::RoundingModeRTZ:
            RequireFloatControls = VerLower14;
            break;
          }
        }
      }
    }
    if (RequireFloatControls &&
        ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls))
      MAI.Reqs.addExtension(SPIRV::Extension::SPV_KHR_float_controls);
  }
  for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) {
    const Function &F = *FI;
    if (F.isDeclaration())
      continue;
    if (F.getMetadata("reqd_work_group_size"))
      MAI.Reqs.getAndAddRequirements(
          SPIRV::OperandCategory::ExecutionModeOperand,
          SPIRV::ExecutionMode::LocalSize, ST);
    if (F.getFnAttribute("hlsl.numthreads").isValid()) {
      MAI.Reqs.getAndAddRequirements(
          SPIRV::OperandCategory::ExecutionModeOperand,
          SPIRV::ExecutionMode::LocalSize, ST);
    }
    if (F.getMetadata("work_group_size_hint"))
      MAI.Reqs.getAndAddRequirements(
          SPIRV::OperandCategory::ExecutionModeOperand,
          SPIRV::ExecutionMode::LocalSizeHint, ST);
    if (F.getMetadata("intel_reqd_sub_group_size"))
      MAI.Reqs.getAndAddRequirements(
          SPIRV::OperandCategory::ExecutionModeOperand,
          SPIRV::ExecutionMode::SubgroupSize, ST);
    if (F.getMetadata("vec_type_hint"))
      MAI.Reqs.getAndAddRequirements(
          SPIRV::OperandCategory::ExecutionModeOperand,
          SPIRV::ExecutionMode::VecTypeHint, ST);

    if (F.hasOptNone() &&
        ST.canUseExtension(SPIRV::Extension::SPV_INTEL_optnone)) {
      // Output OpCapability OptNoneINTEL.
      MAI.Reqs.addExtension(SPIRV::Extension::SPV_INTEL_optnone);
      MAI.Reqs.addCapability(SPIRV::Capability::OptNoneINTEL);
    }
  }
}

static unsigned getFastMathFlags(const MachineInstr &I) {
  unsigned Flags = SPIRV::FPFastMathMode::None;
  if (I.getFlag(MachineInstr::MIFlag::FmNoNans))
    Flags |= SPIRV::FPFastMathMode::NotNaN;
  if (I.getFlag(MachineInstr::MIFlag::FmNoInfs))
    Flags |= SPIRV::FPFastMathMode::NotInf;
  if (I.getFlag(MachineInstr::MIFlag::FmNsz))
    Flags |= SPIRV::FPFastMathMode::NSZ;
  if (I.getFlag(MachineInstr::MIFlag::FmArcp))
    Flags |= SPIRV::FPFastMathMode::AllowRecip;
  if (I.getFlag(MachineInstr::MIFlag::FmReassoc))
    Flags |= SPIRV::FPFastMathMode::Fast;
  return Flags;
}

static void handleMIFlagDecoration(MachineInstr &I, const SPIRVSubtarget &ST,
                                   const SPIRVInstrInfo &TII,
                                   SPIRV::RequirementHandler &Reqs) {
  if (I.getFlag(MachineInstr::MIFlag::NoSWrap) && TII.canUseNSW(I) &&
      getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,
                                     SPIRV::Decoration::NoSignedWrap, ST, Reqs)
          .IsSatisfiable) {
    buildOpDecorate(I.getOperand(0).getReg(), I, TII,
                    SPIRV::Decoration::NoSignedWrap, {});
  }
  if (I.getFlag(MachineInstr::MIFlag::NoUWrap) && TII.canUseNUW(I) &&
      getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,
                                     SPIRV::Decoration::NoUnsignedWrap, ST,
                                     Reqs)
          .IsSatisfiable) {
    buildOpDecorate(I.getOperand(0).getReg(), I, TII,
                    SPIRV::Decoration::NoUnsignedWrap, {});
  }
  if (!TII.canUseFastMathFlags(I))
    return;
  unsigned FMFlags = getFastMathFlags(I);
  if (FMFlags == SPIRV::FPFastMathMode::None)
    return;
  Register DstReg = I.getOperand(0).getReg();
  buildOpDecorate(DstReg, I, TII, SPIRV::Decoration::FPFastMathMode, {FMFlags});
}

// Walk all functions and add decorations related to MI flags.
static void addDecorations(const Module &M, const SPIRVInstrInfo &TII,
                           MachineModuleInfo *MMI, const SPIRVSubtarget &ST,
                           SPIRV::ModuleAnalysisInfo &MAI) {
  for (auto F = M.begin(), E = M.end(); F != E; ++F) {
    MachineFunction *MF = MMI->getMachineFunction(*F);
    if (!MF)
      continue;
    for (auto &MBB : *MF)
      for (auto &MI : MBB)
        handleMIFlagDecoration(MI, ST, TII, MAI.Reqs);
  }
}

struct SPIRV::ModuleAnalysisInfo SPIRVModuleAnalysis::MAI;

void SPIRVModuleAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequired<TargetPassConfig>();
  AU.addRequired<MachineModuleInfoWrapperPass>();
}

bool SPIRVModuleAnalysis::runOnModule(Module &M) {
  SPIRVTargetMachine &TM =
      getAnalysis<TargetPassConfig>().getTM<SPIRVTargetMachine>();
  ST = TM.getSubtargetImpl();
  GR = ST->getSPIRVGlobalRegistry();
  TII = ST->getInstrInfo();

  MMI = &getAnalysis<MachineModuleInfoWrapperPass>().getMMI();

  setBaseInfo(M);

  addDecorations(M, *TII, MMI, *ST, MAI);

  collectReqs(M, MAI, MMI, *ST);

  // Process type/const/global var/func decl instructions, number their
  // destination registers from 0 to N, collect Extensions and Capabilities.
  processDefInstrs(M);

  // Number rest of registers from N+1 onwards.
  numberRegistersGlobally(M);

  // Update references to OpFunction instructions to use Global Registers.
  if (GR->hasConstFunPtr())
    collectFuncPtrs();

  // Collect OpName, OpEntryPoint, OpDecorate etc, process other instructions.
  processOtherInstrs(M);

  // If there are no entry points, we need the Linkage capability.
  if (MAI.MS[SPIRV::MB_EntryPoints].empty())
    MAI.Reqs.addCapability(SPIRV::Capability::Linkage);

  // Set maximum ID used.
  GR->setBound(MAI.MaxID);

  return false;
}