//===- SPIRVModuleAnalysis.cpp - analysis of global instrs & regs - C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// The analysis collects instructions that should be output at the module level
// and performs the global register numbering.
//
// The results of this analysis are used in AsmPrinter to rename registers
// globally and to output required instructions at the module level.
//
//===----------------------------------------------------------------------===//

#include "SPIRVModuleAnalysis.h"
#include "MCTargetDesc/SPIRVBaseInfo.h"
#include "MCTargetDesc/SPIRVMCTargetDesc.h"
#include "SPIRV.h"
#include "SPIRVSubtarget.h"
#include "SPIRVTargetMachine.h"
#include "SPIRVUtils.h"
#include "TargetInfo/SPIRVTargetInfo.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
#include "llvm/CodeGen/TargetPassConfig.h"

using namespace llvm;

#define DEBUG_TYPE "spirv-module-analysis"

// Debug aid: dump MIR annotated with the SPIR-V dependencies graph.
static cl::opt<bool>
    SPVDumpDeps("spv-dump-deps",
                cl::desc("Dump MIR with SPIR-V dependencies info"),
                cl::Optional, cl::init(false));

// Capabilities the user prefers NOT to emit when an alternative capability
// (or extension) can enable the same feature.
static cl::list<SPIRV::Capability::Capability>
    AvoidCapabilities("avoid-spirv-capabilities",
                      cl::desc("SPIR-V capabilities to avoid if there are "
                               "other options enabling a feature"),
                      cl::ZeroOrMore, cl::Hidden,
                      cl::values(clEnumValN(SPIRV::Capability::Shader, "Shader",
                                            "SPIR-V Shader capability")));

// Use sets instead of cl::list to check "if contains" condition
struct AvoidCapabilitiesSet {
  SmallSet<SPIRV::Capability::Capability, 4> S;
  AvoidCapabilitiesSet() {
    // Snapshot the command-line list into a set for O(log n) membership tests.
    for (auto Cap : AvoidCapabilities)
      S.insert(Cap);
  }
};

char llvm::SPIRVModuleAnalysis::ID = 0;

namespace llvm {
void initializeSPIRVModuleAnalysisPass(PassRegistry &);
} // namespace llvm

INITIALIZE_PASS(SPIRVModuleAnalysis, DEBUG_TYPE, "SPIRV module analysis", true,
                true)

// Retrieve an unsigned from an MDNode with a list of them as operands.
// Returns \p DefaultVal when the node is null or \p OpIndex is out of range.
static unsigned getMetadataUInt(MDNode *MdNode, unsigned OpIndex,
                                unsigned DefaultVal = 0) {
  if (MdNode && OpIndex < MdNode->getNumOperands()) {
    const auto &Op = MdNode->getOperand(OpIndex);
    return mdconst::extract<ConstantInt>(Op)->getZExtValue();
  }
  return DefaultVal;
}

// Compute the requirements (capabilities, extensions, SPIR-V version bounds)
// needed to use symbolic operand value \p i of operand category \p Category
// on subtarget \p ST. Returns an unsatisfiable Requirements object when the
// target cannot meet version/capability/extension constraints.
static SPIRV::Requirements
getSymbolicOperandRequirements(SPIRV::OperandCategory::OperandCategory Category,
                               unsigned i, const SPIRVSubtarget &ST,
                               SPIRV::RequirementHandler &Reqs) {
  static AvoidCapabilitiesSet
      AvoidCaps; // contains capabilities to avoid if there is another option
  unsigned ReqMinVer = getSymbolicOperandMinVersion(Category, i);
  unsigned ReqMaxVer = getSymbolicOperandMaxVersion(Category, i);
  unsigned TargetVer = ST.getSPIRVVersion();
  // A requirement of 0 means "no constraint"; a target version of 0 means
  // "any version", so either side being 0 satisfies the bound.
  bool MinVerOK = !ReqMinVer || !TargetVer || TargetVer >= ReqMinVer;
  bool MaxVerOK = !ReqMaxVer || !TargetVer || TargetVer <= ReqMaxVer;
  CapabilityList ReqCaps = getSymbolicOperandCapabilities(Category, i);
  ExtensionList ReqExts = getSymbolicOperandExtensions(Category, i);
  if (ReqCaps.empty()) {
    if (ReqExts.empty()) {
      if (MinVerOK && MaxVerOK)
        return {true, {}, {}, ReqMinVer, ReqMaxVer};
      return {false, {}, {}, 0, 0};
    }
  } else if (MinVerOK && MaxVerOK) {
    if (ReqCaps.size() == 1) {
      auto Cap = ReqCaps[0];
      if (Reqs.isCapabilityAvailable(Cap))
        return {true, {Cap}, {}, ReqMinVer, ReqMaxVer};
    } else {
      // By SPIR-V specification: "If an instruction, enumerant, or other
      // feature specifies multiple enabling capabilities, only one such
      // capability needs to be declared to use the feature." However, one
      // capability may be preferred over another. We use command line
      // argument(s) and AvoidCapabilities to avoid selection of certain
      // capabilities if there are other options.
      CapabilityList UseCaps;
      for (auto Cap : ReqCaps)
        if (Reqs.isCapabilityAvailable(Cap))
          UseCaps.push_back(Cap);
      // Prefer the first non-avoided capability; fall back to the last
      // available one even if it is on the avoid list.
      for (size_t i = 0, Sz = UseCaps.size(); i < Sz; ++i) {
        auto Cap = UseCaps[i];
        if (i == Sz - 1 || !AvoidCaps.S.contains(Cap))
          return {true, {Cap}, {}, ReqMinVer, ReqMaxVer};
      }
    }
  }
  // If there are no capabilities, or we can't satisfy the version or
  // capability requirements, use the list of extensions (if the subtarget
  // can handle them all).
  if (llvm::all_of(ReqExts, [&ST](const SPIRV::Extension::Extension &Ext) {
        return ST.canUseExtension(Ext);
      })) {
    return {true, {}, ReqExts, 0, 0}; // TODO: add versions to extensions.
  }
  return {false, {}, {}, 0, 0};
}

// Reset the per-module analysis info (MAI) and seed it from module metadata:
// memory/addressing model, source language + version, and used extensions.
void SPIRVModuleAnalysis::setBaseInfo(const Module &M) {
  MAI.MaxID = 0;
  for (int i = 0; i < SPIRV::NUM_MODULE_SECTIONS; i++)
    MAI.MS[i].clear();
  MAI.RegisterAliasTable.clear();
  MAI.InstrsToDelete.clear();
  MAI.FuncMap.clear();
  MAI.GlobalVarList.clear();
  MAI.ExtInstSetMap.clear();
  MAI.Reqs.clear();
  MAI.Reqs.initAvailableCapabilities(*ST);

  // TODO: determine memory model and source language from the configuration.
  if (auto MemModel = M.getNamedMetadata("spirv.MemoryModel")) {
    // Explicit metadata wins: operand 0 is the addressing model, operand 1
    // the memory model.
    auto MemMD = MemModel->getOperand(0);
    MAI.Addr = static_cast<SPIRV::AddressingModel::AddressingModel>(
        getMetadataUInt(MemMD, 0));
    MAI.Mem =
        static_cast<SPIRV::MemoryModel::MemoryModel>(getMetadataUInt(MemMD, 1));
  } else {
    // TODO: Add support for VulkanMemoryModel.
    MAI.Mem = ST->isOpenCLEnv() ? SPIRV::MemoryModel::OpenCL
                                : SPIRV::MemoryModel::GLSL450;
    if (MAI.Mem == SPIRV::MemoryModel::OpenCL) {
      // OpenCL: derive the addressing model from the target pointer width.
      unsigned PtrSize = ST->getPointerSize();
      MAI.Addr = PtrSize == 32   ? SPIRV::AddressingModel::Physical32
                 : PtrSize == 64 ? SPIRV::AddressingModel::Physical64
                                 : SPIRV::AddressingModel::Logical;
    } else {
      // TODO: Add support for PhysicalStorageBufferAddress.
      MAI.Addr = SPIRV::AddressingModel::Logical;
    }
  }
  // Get the OpenCL version number from metadata.
  // TODO: support other source languages.
  if (auto VerNode = M.getNamedMetadata("opencl.ocl.version")) {
    MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_C;
    // Construct version literal in accordance with SPIRV-LLVM-Translator.
    // TODO: support multiple OCL version metadata.
    assert(VerNode->getNumOperands() > 0 && "Invalid SPIR");
    auto VersionMD = VerNode->getOperand(0);
    unsigned MajorNum = getMetadataUInt(VersionMD, 0, 2);
    unsigned MinorNum = getMetadataUInt(VersionMD, 1);
    unsigned RevNum = getMetadataUInt(VersionMD, 2);
    MAI.SrcLangVersion = (MajorNum * 100 + MinorNum) * 1000 + RevNum;
  } else {
    MAI.SrcLang = SPIRV::SourceLanguage::Unknown;
    MAI.SrcLangVersion = 0;
  }

  // Record the source-level extension strings declared by the producer.
  if (auto ExtNode = M.getNamedMetadata("opencl.used.extensions")) {
    for (unsigned I = 0, E = ExtNode->getNumOperands(); I != E; ++I) {
      MDNode *MD = ExtNode->getOperand(I);
      if (!MD || MD->getNumOperands() == 0)
        continue;
      for (unsigned J = 0, N = MD->getNumOperands(); J != N; ++J)
        MAI.SrcExt.insert(cast<MDString>(MD->getOperand(J))->getString());
    }
  }

  // Update required capabilities for this memory model, addressing model and
  // source language.
  MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand,
                                 MAI.Mem, *ST);
  MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::SourceLanguageOperand,
                                 MAI.SrcLang, *ST);
  MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,
                                 MAI.Addr, *ST);

  if (ST->isOpenCLEnv()) {
    // TODO: check if it's required by default.
    MAI.ExtInstSetMap[static_cast<unsigned>(
        SPIRV::InstructionSet::OpenCL_std)] =
        Register::index2VirtReg(MAI.getNextID());
  }
}

// Collect MI which defines the register in the given machine function.
// Marks the defining instruction as skipped for per-function emission and,
// when \p DoInsert is set, appends it to the module-level section \p MSType.
static void collectDefInstr(Register Reg, const MachineFunction *MF,
                            SPIRV::ModuleAnalysisInfo *MAI,
                            SPIRV::ModuleSectionType MSType,
                            bool DoInsert = true) {
  assert(MAI->hasRegisterAlias(MF, Reg) && "Cannot find register alias");
  MachineInstr *MI = MF->getRegInfo().getUniqueVRegDef(Reg);
  assert(MI && "There should be an instruction that defines the register");
  MAI->setSkipEmission(MI);
  if (DoInsert)
    MAI->MS[MSType].push_back(MI);
}

// Walk the dependency graph of duplicate-tracked entries that satisfy \p Pred,
// assign each entry one global register, alias every per-function use to it,
// and hoist one defining instruction into section \p MSType.
void SPIRVModuleAnalysis::collectGlobalEntities(
    const std::vector<SPIRV::DTSortableEntry *> &DepsGraph,
    SPIRV::ModuleSectionType MSType,
    std::function<bool(const SPIRV::DTSortableEntry *)> Pred,
    bool UsePreOrder = false) {
  DenseSet<const SPIRV::DTSortableEntry *> Visited;
  for (const auto *E : DepsGraph) {
    std::function<void(const SPIRV::DTSortableEntry *)> RecHoistUtil;
    // NOTE: here we prefer recursive approach over iterative because
    // we don't expect depchains long enough to cause SO.
    RecHoistUtil = [MSType, UsePreOrder, &Visited, &Pred,
                    &RecHoistUtil](const SPIRV::DTSortableEntry *E) {
      if (Visited.count(E) || !Pred(E))
        return;
      Visited.insert(E);

      // Traversing deps graph in post-order allows us to get rid of
      // register aliases preprocessing.
      // But pre-order is required for correct processing of function
      // declaration and arguments processing.
      if (!UsePreOrder)
        for (auto *S : E->getDeps())
          RecHoistUtil(S);

      Register GlobalReg = Register::index2VirtReg(MAI.getNextID());
      bool IsFirst = true;
      for (auto &U : *E) {
        const MachineFunction *MF = U.first;
        Register Reg = U.second;
        MAI.setRegisterAlias(MF, Reg, GlobalReg);
        if (!MF->getRegInfo().getUniqueVRegDef(Reg))
          continue;
        // Only the first occurrence is actually emitted; the rest just get
        // their defining instructions marked as skipped.
        collectDefInstr(Reg, MF, &MAI, MSType, IsFirst);
        IsFirst = false;
        if (E->getIsGV())
          MAI.GlobalVarList.push_back(MF->getRegInfo().getUniqueVRegDef(Reg));
      }

      if (UsePreOrder)
        for (auto *S : E->getDeps())
          RecHoistUtil(S);
    };
    RecHoistUtil(E);
  }
}

// The function initializes global register alias table for types, consts,
// global vars and func decls and collects these instruction for output
// at module level. Also it collects explicit OpExtension/OpCapability
// instructions.
void SPIRVModuleAnalysis::processDefInstrs(const Module &M) {
  std::vector<SPIRV::DTSortableEntry *> DepsGraph;

  GR->buildDepsGraph(DepsGraph, SPVDumpDeps ? MMI : nullptr);

  // Types/constants/globals use post-order traversal (see
  // collectGlobalEntities); function declarations use pre-order.
  collectGlobalEntities(
      DepsGraph, SPIRV::MB_TypeConstVars,
      [](const SPIRV::DTSortableEntry *E) { return !E->getIsFunc(); });

  for (auto F = M.begin(), E = M.end(); F != E; ++F) {
    MachineFunction *MF = MMI->getMachineFunction(*F);
    if (!MF)
      continue;
    // Iterate through and collect OpExtension/OpCapability instructions.
    for (MachineBasicBlock &MBB : *MF) {
      for (MachineInstr &MI : MBB) {
        if (MI.getOpcode() == SPIRV::OpExtension) {
          // Here, OpExtension just has a single enum operand, not a string.
          auto Ext = SPIRV::Extension::Extension(MI.getOperand(0).getImm());
          MAI.Reqs.addExtension(Ext);
          MAI.setSkipEmission(&MI);
        } else if (MI.getOpcode() == SPIRV::OpCapability) {
          auto Cap = SPIRV::Capability::Capability(MI.getOperand(0).getImm());
          MAI.Reqs.addCapability(Cap);
          MAI.setSkipEmission(&MI);
        }
      }
    }
  }

  collectGlobalEntities(
      DepsGraph, SPIRV::MB_ExtFuncDecls,
      [](const SPIRV::DTSortableEntry *E) { return E->getIsFunc(); }, true);
}

// Look for IDs declared with Import linkage, and map the corresponding function
// to the register defining that variable (which will usually be the result of
// an OpFunction). This lets us call externally imported functions using
// the correct ID registers.
void SPIRVModuleAnalysis::collectFuncNames(MachineInstr &MI,
                                           const Function *F) {
  if (MI.getOpcode() == SPIRV::OpDecorate) {
    // If it's got Import linkage.
    auto Dec = MI.getOperand(1).getImm();
    if (Dec == static_cast<unsigned>(SPIRV::Decoration::LinkageAttributes)) {
      // The linkage type is the last operand of the decoration.
      auto Lnk = MI.getOperand(MI.getNumOperands() - 1).getImm();
      if (Lnk == static_cast<unsigned>(SPIRV::LinkageType::Import)) {
        // Map imported function name to function ID register.
        const Function *ImportedFunc =
            F->getParent()->getFunction(getStringImm(MI, 2));
        Register Target = MI.getOperand(0).getReg();
        MAI.FuncMap[ImportedFunc] = MAI.getRegisterAlias(MI.getMF(), Target);
      }
    }
  } else if (MI.getOpcode() == SPIRV::OpFunction) {
    // Record all internal OpFunction declarations.
    Register Reg = MI.defs().begin()->getReg();
    Register GlobalReg = MAI.getRegisterAlias(MI.getMF(), Reg);
    assert(GlobalReg.isValid());
    MAI.FuncMap[F] = GlobalReg;
  }
}

// References to a function via function pointers generate virtual
// registers without a definition. We are able to resolve this
// reference using Global Register info into an OpFunction instruction
// and replace dummy operands by the corresponding global register references.
void SPIRVModuleAnalysis::collectFuncPtrs() {
  for (auto &MI : MAI.MS[SPIRV::MB_TypeConstVars])
    if (MI->getOpcode() == SPIRV::OpConstantFunctionPointerINTEL)
      collectFuncPtrs(MI);
}

// Resolve a single OpConstantFunctionPointerINTEL: alias its dangling
// function-pointer register to the global register of the referenced
// OpFunction definition.
void SPIRVModuleAnalysis::collectFuncPtrs(MachineInstr *MI) {
  const MachineOperand *FunUse = &MI->getOperand(2);
  if (const MachineOperand *FunDef = GR->getFunctionDefinitionByUse(FunUse)) {
    const MachineInstr *FunDefMI = FunDef->getParent();
    assert(FunDefMI->getOpcode() == SPIRV::OpFunction &&
           "Constant function pointer must refer to function definition");
    Register FunDefReg = FunDef->getReg();
    Register GlobalFunDefReg =
        MAI.getRegisterAlias(FunDefMI->getMF(), FunDefReg);
    assert(GlobalFunDefReg.isValid() &&
           "Function definition must refer to a global register");
    Register FunPtrReg = FunUse->getReg();
    MAI.setRegisterAlias(MI->getMF(), FunPtrReg, GlobalFunDefReg);
  }
}

using InstrSignature = SmallVector<size_t>;
using InstrTraces = std::set<InstrSignature>;

// Returns a representation of an instruction as a vector of MachineOperand
// hash values, see llvm::hash_value(const MachineOperand &MO) for details.
// This creates a signature of the instruction with the same content
// that MachineOperand::isIdenticalTo uses for comparison.
362 static InstrSignature instrToSignature(MachineInstr &MI, 363 SPIRV::ModuleAnalysisInfo &MAI) { 364 InstrSignature Signature; 365 for (unsigned i = 0; i < MI.getNumOperands(); ++i) { 366 const MachineOperand &MO = MI.getOperand(i); 367 size_t h; 368 if (MO.isReg()) { 369 Register RegAlias = MAI.getRegisterAlias(MI.getMF(), MO.getReg()); 370 // mimic llvm::hash_value(const MachineOperand &MO) 371 h = hash_combine(MO.getType(), (unsigned)RegAlias, MO.getSubReg(), 372 MO.isDef()); 373 } else { 374 h = hash_value(MO); 375 } 376 Signature.push_back(h); 377 } 378 return Signature; 379 } 380 381 // Collect the given instruction in the specified MS. We assume global register 382 // numbering has already occurred by this point. We can directly compare reg 383 // arguments when detecting duplicates. 384 static void collectOtherInstr(MachineInstr &MI, SPIRV::ModuleAnalysisInfo &MAI, 385 SPIRV::ModuleSectionType MSType, InstrTraces &IS, 386 bool Append = true) { 387 MAI.setSkipEmission(&MI); 388 InstrSignature MISign = instrToSignature(MI, MAI); 389 auto FoundMI = IS.insert(MISign); 390 if (!FoundMI.second) 391 return; // insert failed, so we found a duplicate; don't add it to MAI.MS 392 // No duplicates, so add it. 393 if (Append) 394 MAI.MS[MSType].push_back(&MI); 395 else 396 MAI.MS[MSType].insert(MAI.MS[MSType].begin(), &MI); 397 } 398 399 // Some global instructions make reference to function-local ID regs, so cannot 400 // be correctly collected until these registers are globally numbered. 
401 void SPIRVModuleAnalysis::processOtherInstrs(const Module &M) { 402 InstrTraces IS; 403 for (auto F = M.begin(), E = M.end(); F != E; ++F) { 404 if ((*F).isDeclaration()) 405 continue; 406 MachineFunction *MF = MMI->getMachineFunction(*F); 407 assert(MF); 408 for (MachineBasicBlock &MBB : *MF) 409 for (MachineInstr &MI : MBB) { 410 if (MAI.getSkipEmission(&MI)) 411 continue; 412 const unsigned OpCode = MI.getOpcode(); 413 if (OpCode == SPIRV::OpName || OpCode == SPIRV::OpMemberName) { 414 collectOtherInstr(MI, MAI, SPIRV::MB_DebugNames, IS); 415 } else if (OpCode == SPIRV::OpEntryPoint) { 416 collectOtherInstr(MI, MAI, SPIRV::MB_EntryPoints, IS); 417 } else if (TII->isDecorationInstr(MI)) { 418 collectOtherInstr(MI, MAI, SPIRV::MB_Annotations, IS); 419 collectFuncNames(MI, &*F); 420 } else if (TII->isConstantInstr(MI)) { 421 // Now OpSpecConstant*s are not in DT, 422 // but they need to be collected anyway. 423 collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, IS); 424 } else if (OpCode == SPIRV::OpFunction) { 425 collectFuncNames(MI, &*F); 426 } else if (OpCode == SPIRV::OpTypeForwardPointer) { 427 collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, IS, false); 428 } 429 } 430 } 431 } 432 433 // Number registers in all functions globally from 0 onwards and store 434 // the result in global register alias table. Some registers are already 435 // numbered in collectGlobalEntities. 
void SPIRVModuleAnalysis::numberRegistersGlobally(const Module &M) {
  for (auto F = M.begin(), E = M.end(); F != E; ++F) {
    if ((*F).isDeclaration())
      continue;
    MachineFunction *MF = MMI->getMachineFunction(*F);
    assert(MF);
    for (MachineBasicBlock &MBB : *MF) {
      for (MachineInstr &MI : MBB) {
        for (MachineOperand &Op : MI.operands()) {
          if (!Op.isReg())
            continue;
          Register Reg = Op.getReg();
          // Registers already numbered (e.g. by collectGlobalEntities) keep
          // their existing global alias.
          if (MAI.hasRegisterAlias(MF, Reg))
            continue;
          Register NewReg = Register::index2VirtReg(MAI.getNextID());
          MAI.setRegisterAlias(MF, Reg, NewReg);
        }
        if (MI.getOpcode() != SPIRV::OpExtInst)
          continue;
        // Ensure every referenced extended-instruction set gets an ID.
        auto Set = MI.getOperand(2).getImm();
        if (!MAI.ExtInstSetMap.contains(Set))
          MAI.ExtInstSetMap[Set] = Register::index2VirtReg(MAI.getNextID());
      }
    }
  }
}

// RequirementHandler implementations.

// Resolve the requirements of symbolic operand \p i in \p Category and fold
// them into this handler (may abort compilation if unsatisfiable).
void SPIRV::RequirementHandler::getAndAddRequirements(
    SPIRV::OperandCategory::OperandCategory Category, uint32_t i,
    const SPIRVSubtarget &ST) {
  addRequirements(getSymbolicOperandRequirements(Category, i, ST, *this));
}

// Record \p ToPrune and all capabilities they implicitly declare in AllCaps
// (but not in MinimalCaps — these need not be emitted explicitly).
void SPIRV::RequirementHandler::recursiveAddCapabilities(
    const CapabilityList &ToPrune) {
  for (const auto &Cap : ToPrune) {
    AllCaps.insert(Cap);
    CapabilityList ImplicitDecls =
        getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);
    recursiveAddCapabilities(ImplicitDecls);
  }
}

// Add capabilities to be emitted explicitly (MinimalCaps), skipping any that
// are already implied by previously added capabilities.
void SPIRV::RequirementHandler::addCapabilities(const CapabilityList &ToAdd) {
  for (const auto &Cap : ToAdd) {
    bool IsNewlyInserted = AllCaps.insert(Cap).second;
    if (!IsNewlyInserted) // Don't re-add if it's already been declared.
      continue;
    CapabilityList ImplicitDecls =
        getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);
    recursiveAddCapabilities(ImplicitDecls);
    MinimalCaps.push_back(Cap);
  }
}

// Merge a Requirements record into the handler: capabilities, extensions,
// and the intersection of the version interval [MinVer, MaxVer].
// Aborts compilation on an unsatisfiable or contradictory requirement.
void SPIRV::RequirementHandler::addRequirements(
    const SPIRV::Requirements &Req) {
  if (!Req.IsSatisfiable)
    report_fatal_error("Adding SPIR-V requirements this target can't satisfy.");

  if (Req.Cap.has_value())
    addCapabilities({Req.Cap.value()});

  addExtensions(Req.Exts);

  if (Req.MinVer) {
    if (MaxVersion && Req.MinVer > MaxVersion) {
      LLVM_DEBUG(dbgs() << "Conflicting version requirements: >= " << Req.MinVer
                        << " and <= " << MaxVersion << "\n");
      report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");
    }

    if (MinVersion == 0 || Req.MinVer > MinVersion)
      MinVersion = Req.MinVer;
  }

  if (Req.MaxVer) {
    if (MinVersion && Req.MaxVer < MinVersion) {
      LLVM_DEBUG(dbgs() << "Conflicting version requirements: <= " << Req.MaxVer
                        << " and >= " << MinVersion << "\n");
      report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");
    }

    if (MaxVersion == 0 || Req.MaxVer < MaxVersion)
      MaxVersion = Req.MaxVer;
  }
}

// Verify all collected requirements against the subtarget; aborts with a
// fatal error if any cannot be met.
void SPIRV::RequirementHandler::checkSatisfiable(
    const SPIRVSubtarget &ST) const {
  // Report as many errors as possible before aborting the compilation.
  bool IsSatisfiable = true;
  auto TargetVer = ST.getSPIRVVersion();

  if (MaxVersion && TargetVer && MaxVersion < TargetVer) {
    LLVM_DEBUG(
        dbgs() << "Target SPIR-V version too high for required features\n"
               << "Required max version: " << MaxVersion << " target version "
               << TargetVer << "\n");
    IsSatisfiable = false;
  }

  if (MinVersion && TargetVer && MinVersion > TargetVer) {
    LLVM_DEBUG(dbgs() << "Target SPIR-V version too low for required features\n"
                      << "Required min version: " << MinVersion
                      << " target version " << TargetVer << "\n");
    IsSatisfiable = false;
  }

  if (MinVersion && MaxVersion && MinVersion > MaxVersion) {
    LLVM_DEBUG(
        dbgs()
        << "Version is too low for some features and too high for others.\n"
        << "Required SPIR-V min version: " << MinVersion
        << " required SPIR-V max version " << MaxVersion << "\n");
    IsSatisfiable = false;
  }

  for (auto Cap : MinimalCaps) {
    if (AvailableCaps.contains(Cap))
      continue;
    LLVM_DEBUG(dbgs() << "Capability not supported: "
                      << getSymbolicOperandMnemonic(
                             OperandCategory::CapabilityOperand, Cap)
                      << "\n");
    IsSatisfiable = false;
  }

  for (auto Ext : AllExtensions) {
    if (ST.canUseExtension(Ext))
      continue;
    LLVM_DEBUG(dbgs() << "Extension not supported: "
                      << getSymbolicOperandMnemonic(
                             OperandCategory::ExtensionOperand, Ext)
                      << "\n");
    IsSatisfiable = false;
  }

  if (!IsSatisfiable)
    report_fatal_error("Unable to meet SPIR-V requirements for this target.");
}

// Add the given capabilities and all their implicitly defined capabilities too.
void SPIRV::RequirementHandler::addAvailableCaps(const CapabilityList &ToAdd) {
  for (const auto Cap : ToAdd)
    if (AvailableCaps.insert(Cap).second)
      addAvailableCaps(getSymbolicOperandCapabilities(
          SPIRV::OperandCategory::CapabilityOperand, Cap));
}

// Remove \p ToRemove from the declared capabilities when \p IfPresent has
// already been declared (it subsumes the removed one).
void SPIRV::RequirementHandler::removeCapabilityIf(
    const Capability::Capability ToRemove,
    const Capability::Capability IfPresent) {
  if (AllCaps.contains(IfPresent))
    AllCaps.erase(ToRemove);
}

namespace llvm {
namespace SPIRV {
// Populate AvailableCaps according to the target environment (OpenCL or
// Vulkan). Unknown environments are a fatal error.
void RequirementHandler::initAvailableCapabilities(const SPIRVSubtarget &ST) {
  if (ST.isOpenCLEnv()) {
    initAvailableCapabilitiesForOpenCL(ST);
    return;
  }

  if (ST.isVulkanEnv()) {
    initAvailableCapabilitiesForVulkan(ST);
    return;
  }

  report_fatal_error("Unimplemented environment for SPIR-V generation.");
}

void RequirementHandler::initAvailableCapabilitiesForOpenCL(
    const SPIRVSubtarget &ST) {
  // Add the min requirements for different OpenCL and SPIR-V versions.
  addAvailableCaps({Capability::Addresses, Capability::Float16Buffer,
                    Capability::Int16, Capability::Int8, Capability::Kernel,
                    Capability::Linkage, Capability::Vector16,
                    Capability::Groups, Capability::GenericPointer,
                    Capability::Shader});
  if (ST.hasOpenCLFullProfile())
    addAvailableCaps({Capability::Int64, Capability::Int64Atomics});
  if (ST.hasOpenCLImageSupport()) {
    addAvailableCaps({Capability::ImageBasic, Capability::LiteralSampler,
                      Capability::Image1D, Capability::SampledBuffer,
                      Capability::ImageBuffer});
    if (ST.isAtLeastOpenCLVer(20))
      addAvailableCaps({Capability::ImageReadWrite});
  }
  if (ST.isAtLeastSPIRVVer(11) && ST.isAtLeastOpenCLVer(22))
    addAvailableCaps({Capability::SubgroupDispatch, Capability::PipeStorage});
  if (ST.isAtLeastSPIRVVer(13))
    addAvailableCaps({Capability::GroupNonUniform,
                      Capability::GroupNonUniformVote,
                      Capability::GroupNonUniformArithmetic,
                      Capability::GroupNonUniformBallot,
                      Capability::GroupNonUniformClustered,
                      Capability::GroupNonUniformShuffle,
                      Capability::GroupNonUniformShuffleRelative});
  if (ST.isAtLeastSPIRVVer(14))
    addAvailableCaps({Capability::DenormPreserve, Capability::DenormFlushToZero,
                      Capability::SignedZeroInfNanPreserve,
                      Capability::RoundingModeRTE,
                      Capability::RoundingModeRTZ});
  // TODO: verify if this needs some checks.
  addAvailableCaps({Capability::Float16, Capability::Float64});

  // Add capabilities enabled by extensions.
  for (auto Extension : ST.getAllAvailableExtensions()) {
    CapabilityList EnabledCapabilities =
        getCapabilitiesEnabledByExtension(Extension);
    addAvailableCaps(EnabledCapabilities);
  }

  // TODO: add OpenCL extensions.
}

void RequirementHandler::initAvailableCapabilitiesForVulkan(
    const SPIRVSubtarget &ST) {
  addAvailableCaps({Capability::Shader, Capability::Linkage});

  // Provided by all supported Vulkan versions.
  addAvailableCaps({Capability::Int16, Capability::Int64, Capability::Float16,
                    Capability::Float64});
}

} // namespace SPIRV
} // namespace llvm

// Add the required capabilities from a decoration instruction (including
// BuiltIns).
static void addOpDecorateReqs(const MachineInstr &MI, unsigned DecIndex,
                              SPIRV::RequirementHandler &Reqs,
                              const SPIRVSubtarget &ST) {
  int64_t DecOp = MI.getOperand(DecIndex).getImm();
  auto Dec = static_cast<SPIRV::Decoration::Decoration>(DecOp);
  Reqs.addRequirements(getSymbolicOperandRequirements(
      SPIRV::OperandCategory::DecorationOperand, Dec, ST, Reqs));

  if (Dec == SPIRV::Decoration::BuiltIn) {
    // The BuiltIn kind immediately follows the decoration operand.
    int64_t BuiltInOp = MI.getOperand(DecIndex + 1).getImm();
    auto BuiltIn = static_cast<SPIRV::BuiltIn::BuiltIn>(BuiltInOp);
    Reqs.addRequirements(getSymbolicOperandRequirements(
        SPIRV::OperandCategory::BuiltInOperand, BuiltIn, ST, Reqs));
  } else if (Dec == SPIRV::Decoration::LinkageAttributes) {
    // The linkage type is the last operand; LinkOnceODR needs an extension.
    int64_t LinkageOp = MI.getOperand(MI.getNumOperands() - 1).getImm();
    SPIRV::LinkageType::LinkageType LnkType =
        static_cast<SPIRV::LinkageType::LinkageType>(LinkageOp);
    if (LnkType == SPIRV::LinkageType::LinkOnceODR)
      Reqs.addExtension(SPIRV::Extension::SPV_KHR_linkonce_odr);
  }
}

// Add requirements for image handling.
static void addOpTypeImageReqs(const MachineInstr &MI,
                               SPIRV::RequirementHandler &Reqs,
                               const SPIRVSubtarget &ST) {
  assert(MI.getNumOperands() >= 8 && "Insufficient operands for OpTypeImage");
  // The operand indices used here are based on the OpTypeImage layout, which
  // the MachineInstr follows as well.
  int64_t ImgFormatOp = MI.getOperand(7).getImm();
  auto ImgFormat = static_cast<SPIRV::ImageFormat::ImageFormat>(ImgFormatOp);
  Reqs.getAndAddRequirements(SPIRV::OperandCategory::ImageFormatOperand,
                             ImgFormat, ST);

  bool IsArrayed = MI.getOperand(4).getImm() == 1;
  bool IsMultisampled = MI.getOperand(5).getImm() == 1;
  // Sampled operand value 2 means "no sampler" (read/write without sampler).
  bool NoSampler = MI.getOperand(6).getImm() == 2;
  // Add dimension requirements.
  assert(MI.getOperand(2).isImm());
  switch (MI.getOperand(2).getImm()) {
  case SPIRV::Dim::DIM_1D:
    Reqs.addRequirements(NoSampler ? SPIRV::Capability::Image1D
                                   : SPIRV::Capability::Sampled1D);
    break;
  case SPIRV::Dim::DIM_2D:
    if (IsMultisampled && NoSampler)
      Reqs.addRequirements(SPIRV::Capability::ImageMSArray);
    break;
  case SPIRV::Dim::DIM_Cube:
    Reqs.addRequirements(SPIRV::Capability::Shader);
    if (IsArrayed)
      Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageCubeArray
                                     : SPIRV::Capability::SampledCubeArray);
    break;
  case SPIRV::Dim::DIM_Rect:
    Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageRect
                                   : SPIRV::Capability::SampledRect);
    break;
  case SPIRV::Dim::DIM_Buffer:
    Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageBuffer
                                   : SPIRV::Capability::SampledBuffer);
    break;
  case SPIRV::Dim::DIM_SubpassData:
    Reqs.addRequirements(SPIRV::Capability::InputAttachment);
    break;
  }

  // Has optional access qualifier.
  // TODO: check if it's OpenCL's kernel.
  if (MI.getNumOperands() > 8 &&
      MI.getOperand(8).getImm() == SPIRV::AccessQualifier::ReadWrite)
    Reqs.addRequirements(SPIRV::Capability::ImageReadWrite);
  else
    Reqs.addRequirements(SPIRV::Capability::ImageBasic);
}

// Add requirements for handling atomic float instructions
#define ATOM_FLT_REQ_EXT_MSG(ExtName)                                          \
  "The atomic float instruction requires the following SPIR-V "                \
  "extension: SPV_EXT_shader_atomic_float" ExtName
static void AddAtomicFloatRequirements(const MachineInstr &MI,
                                       SPIRV::RequirementHandler &Reqs,
                                       const SPIRVSubtarget &ST) {
  assert(MI.getOperand(1).isReg() &&
         "Expect register operand in atomic float instruction");
  Register TypeReg = MI.getOperand(1).getReg();
  SPIRVType *TypeDef = MI.getMF()->getRegInfo().getVRegDef(TypeReg);
  if (TypeDef->getOpcode() != SPIRV::OpTypeFloat)
    report_fatal_error("Result type of an atomic float instruction must be a "
                       "floating-point type scalar");

  unsigned BitWidth = TypeDef->getOperand(1).getImm();
  unsigned Op = MI.getOpcode();
  if (Op == SPIRV::OpAtomicFAddEXT) {
    if (!ST.canUseExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add))
      report_fatal_error(ATOM_FLT_REQ_EXT_MSG("_add"), false);
    Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add);
    switch (BitWidth) {
    case 16:
      // fp16 add needs a further dedicated extension on top of _add.
      if (!ST.canUseExtension(
              SPIRV::Extension::SPV_EXT_shader_atomic_float16_add))
        report_fatal_error(ATOM_FLT_REQ_EXT_MSG("16_add"), false);
      Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float16_add);
      Reqs.addCapability(SPIRV::Capability::AtomicFloat16AddEXT);
      break;
    case 32:
      Reqs.addCapability(SPIRV::Capability::AtomicFloat32AddEXT);
      break;
    case 64:
      Reqs.addCapability(SPIRV::Capability::AtomicFloat64AddEXT);
      break;
    default:
      report_fatal_error(
          "Unexpected floating-point type width in atomic float instruction");
    }
  } else {
    // Min/max flavors of the atomic float instructions.
    if (!ST.canUseExtension(
            SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max))
      report_fatal_error(ATOM_FLT_REQ_EXT_MSG("_min_max"), false);
    Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max);
    switch (BitWidth) {
    case 16:
      Reqs.addCapability(SPIRV::Capability::AtomicFloat16MinMaxEXT);
      break;
    case 32:
      Reqs.addCapability(SPIRV::Capability::AtomicFloat32MinMaxEXT);
      break;
    case 64:
      Reqs.addCapability(SPIRV::Capability::AtomicFloat64MinMaxEXT);
      break;
    default:
      report_fatal_error(
          "Unexpected floating-point type width in atomic float instruction");
    }
  }
}

// Derive capability/extension/version requirements implied by a single
// instruction and record them in \p Reqs.
void addInstrRequirements(const MachineInstr &MI,
                          SPIRV::RequirementHandler &Reqs,
                          const SPIRVSubtarget &ST) {
  switch (MI.getOpcode()) {
  case SPIRV::OpMemoryModel: {
    int64_t Addr = MI.getOperand(0).getImm();
    Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,
                               Addr, ST);
    int64_t Mem = MI.getOperand(1).getImm();
    Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand, Mem,
                               ST);
    break;
  }
  case SPIRV::OpEntryPoint: {
    int64_t Exe = MI.getOperand(0).getImm();
    Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModelOperand,
                               Exe, ST);
    break;
  }
  case SPIRV::OpExecutionMode:
  case SPIRV::OpExecutionModeId: {
    int64_t Exe = MI.getOperand(1).getImm();
    Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModeOperand,
                               Exe, ST);
    break;
  }
  case SPIRV::OpTypeMatrix:
    Reqs.addCapability(SPIRV::Capability::Matrix);
    break;
  case SPIRV::OpTypeInt: {
    unsigned BitWidth = MI.getOperand(1).getImm();
    if (BitWidth == 64)
      Reqs.addCapability(SPIRV::Capability::Int64);
    else if (BitWidth == 16)
      Reqs.addCapability(SPIRV::Capability::Int16);
    else if (BitWidth == 8)
      Reqs.addCapability(SPIRV::Capability::Int8);
    break;
  }
  case SPIRV::OpTypeFloat: {
unsigned BitWidth = MI.getOperand(1).getImm(); 847 if (BitWidth == 64) 848 Reqs.addCapability(SPIRV::Capability::Float64); 849 else if (BitWidth == 16) 850 Reqs.addCapability(SPIRV::Capability::Float16); 851 break; 852 } 853 case SPIRV::OpTypeVector: { 854 unsigned NumComponents = MI.getOperand(2).getImm(); 855 if (NumComponents == 8 || NumComponents == 16) 856 Reqs.addCapability(SPIRV::Capability::Vector16); 857 break; 858 } 859 case SPIRV::OpTypePointer: { 860 auto SC = MI.getOperand(1).getImm(); 861 Reqs.getAndAddRequirements(SPIRV::OperandCategory::StorageClassOperand, SC, 862 ST); 863 // If it's a type of pointer to float16 targeting OpenCL, add Float16Buffer 864 // capability. 865 if (!ST.isOpenCLEnv()) 866 break; 867 assert(MI.getOperand(2).isReg()); 868 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo(); 869 SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(2).getReg()); 870 if (TypeDef->getOpcode() == SPIRV::OpTypeFloat && 871 TypeDef->getOperand(1).getImm() == 16) 872 Reqs.addCapability(SPIRV::Capability::Float16Buffer); 873 break; 874 } 875 case SPIRV::OpBitReverse: 876 case SPIRV::OpBitFieldInsert: 877 case SPIRV::OpBitFieldSExtract: 878 case SPIRV::OpBitFieldUExtract: 879 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions)) { 880 Reqs.addCapability(SPIRV::Capability::Shader); 881 break; 882 } 883 Reqs.addExtension(SPIRV::Extension::SPV_KHR_bit_instructions); 884 Reqs.addCapability(SPIRV::Capability::BitInstructions); 885 break; 886 case SPIRV::OpTypeRuntimeArray: 887 Reqs.addCapability(SPIRV::Capability::Shader); 888 break; 889 case SPIRV::OpTypeOpaque: 890 case SPIRV::OpTypeEvent: 891 Reqs.addCapability(SPIRV::Capability::Kernel); 892 break; 893 case SPIRV::OpTypePipe: 894 case SPIRV::OpTypeReserveId: 895 Reqs.addCapability(SPIRV::Capability::Pipes); 896 break; 897 case SPIRV::OpTypeDeviceEvent: 898 case SPIRV::OpTypeQueue: 899 case SPIRV::OpBuildNDRange: 900 Reqs.addCapability(SPIRV::Capability::DeviceEnqueue); 901 break; 902 
case SPIRV::OpDecorate: 903 case SPIRV::OpDecorateId: 904 case SPIRV::OpDecorateString: 905 addOpDecorateReqs(MI, 1, Reqs, ST); 906 break; 907 case SPIRV::OpMemberDecorate: 908 case SPIRV::OpMemberDecorateString: 909 addOpDecorateReqs(MI, 2, Reqs, ST); 910 break; 911 case SPIRV::OpInBoundsPtrAccessChain: 912 Reqs.addCapability(SPIRV::Capability::Addresses); 913 break; 914 case SPIRV::OpConstantSampler: 915 Reqs.addCapability(SPIRV::Capability::LiteralSampler); 916 break; 917 case SPIRV::OpTypeImage: 918 addOpTypeImageReqs(MI, Reqs, ST); 919 break; 920 case SPIRV::OpTypeSampler: 921 Reqs.addCapability(SPIRV::Capability::ImageBasic); 922 break; 923 case SPIRV::OpTypeForwardPointer: 924 // TODO: check if it's OpenCL's kernel. 925 Reqs.addCapability(SPIRV::Capability::Addresses); 926 break; 927 case SPIRV::OpAtomicFlagTestAndSet: 928 case SPIRV::OpAtomicLoad: 929 case SPIRV::OpAtomicStore: 930 case SPIRV::OpAtomicExchange: 931 case SPIRV::OpAtomicCompareExchange: 932 case SPIRV::OpAtomicIIncrement: 933 case SPIRV::OpAtomicIDecrement: 934 case SPIRV::OpAtomicIAdd: 935 case SPIRV::OpAtomicISub: 936 case SPIRV::OpAtomicUMin: 937 case SPIRV::OpAtomicUMax: 938 case SPIRV::OpAtomicSMin: 939 case SPIRV::OpAtomicSMax: 940 case SPIRV::OpAtomicAnd: 941 case SPIRV::OpAtomicOr: 942 case SPIRV::OpAtomicXor: { 943 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo(); 944 const MachineInstr *InstrPtr = &MI; 945 if (MI.getOpcode() == SPIRV::OpAtomicStore) { 946 assert(MI.getOperand(3).isReg()); 947 InstrPtr = MRI.getVRegDef(MI.getOperand(3).getReg()); 948 assert(InstrPtr && "Unexpected type instruction for OpAtomicStore"); 949 } 950 assert(InstrPtr->getOperand(1).isReg() && "Unexpected operand in atomic"); 951 Register TypeReg = InstrPtr->getOperand(1).getReg(); 952 SPIRVType *TypeDef = MRI.getVRegDef(TypeReg); 953 if (TypeDef->getOpcode() == SPIRV::OpTypeInt) { 954 unsigned BitWidth = TypeDef->getOperand(1).getImm(); 955 if (BitWidth == 64) 956 
Reqs.addCapability(SPIRV::Capability::Int64Atomics); 957 } 958 break; 959 } 960 case SPIRV::OpGroupNonUniformIAdd: 961 case SPIRV::OpGroupNonUniformFAdd: 962 case SPIRV::OpGroupNonUniformIMul: 963 case SPIRV::OpGroupNonUniformFMul: 964 case SPIRV::OpGroupNonUniformSMin: 965 case SPIRV::OpGroupNonUniformUMin: 966 case SPIRV::OpGroupNonUniformFMin: 967 case SPIRV::OpGroupNonUniformSMax: 968 case SPIRV::OpGroupNonUniformUMax: 969 case SPIRV::OpGroupNonUniformFMax: 970 case SPIRV::OpGroupNonUniformBitwiseAnd: 971 case SPIRV::OpGroupNonUniformBitwiseOr: 972 case SPIRV::OpGroupNonUniformBitwiseXor: 973 case SPIRV::OpGroupNonUniformLogicalAnd: 974 case SPIRV::OpGroupNonUniformLogicalOr: 975 case SPIRV::OpGroupNonUniformLogicalXor: { 976 assert(MI.getOperand(3).isImm()); 977 int64_t GroupOp = MI.getOperand(3).getImm(); 978 switch (GroupOp) { 979 case SPIRV::GroupOperation::Reduce: 980 case SPIRV::GroupOperation::InclusiveScan: 981 case SPIRV::GroupOperation::ExclusiveScan: 982 Reqs.addCapability(SPIRV::Capability::Kernel); 983 Reqs.addCapability(SPIRV::Capability::GroupNonUniformArithmetic); 984 Reqs.addCapability(SPIRV::Capability::GroupNonUniformBallot); 985 break; 986 case SPIRV::GroupOperation::ClusteredReduce: 987 Reqs.addCapability(SPIRV::Capability::GroupNonUniformClustered); 988 break; 989 case SPIRV::GroupOperation::PartitionedReduceNV: 990 case SPIRV::GroupOperation::PartitionedInclusiveScanNV: 991 case SPIRV::GroupOperation::PartitionedExclusiveScanNV: 992 Reqs.addCapability(SPIRV::Capability::GroupNonUniformPartitionedNV); 993 break; 994 } 995 break; 996 } 997 case SPIRV::OpGroupNonUniformShuffle: 998 case SPIRV::OpGroupNonUniformShuffleXor: 999 Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffle); 1000 break; 1001 case SPIRV::OpGroupNonUniformShuffleUp: 1002 case SPIRV::OpGroupNonUniformShuffleDown: 1003 Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffleRelative); 1004 break; 1005 case SPIRV::OpGroupAll: 1006 case SPIRV::OpGroupAny: 1007 
case SPIRV::OpGroupBroadcast: 1008 case SPIRV::OpGroupIAdd: 1009 case SPIRV::OpGroupFAdd: 1010 case SPIRV::OpGroupFMin: 1011 case SPIRV::OpGroupUMin: 1012 case SPIRV::OpGroupSMin: 1013 case SPIRV::OpGroupFMax: 1014 case SPIRV::OpGroupUMax: 1015 case SPIRV::OpGroupSMax: 1016 Reqs.addCapability(SPIRV::Capability::Groups); 1017 break; 1018 case SPIRV::OpGroupNonUniformElect: 1019 Reqs.addCapability(SPIRV::Capability::GroupNonUniform); 1020 break; 1021 case SPIRV::OpGroupNonUniformAll: 1022 case SPIRV::OpGroupNonUniformAny: 1023 case SPIRV::OpGroupNonUniformAllEqual: 1024 Reqs.addCapability(SPIRV::Capability::GroupNonUniformVote); 1025 break; 1026 case SPIRV::OpGroupNonUniformBroadcast: 1027 case SPIRV::OpGroupNonUniformBroadcastFirst: 1028 case SPIRV::OpGroupNonUniformBallot: 1029 case SPIRV::OpGroupNonUniformInverseBallot: 1030 case SPIRV::OpGroupNonUniformBallotBitExtract: 1031 case SPIRV::OpGroupNonUniformBallotBitCount: 1032 case SPIRV::OpGroupNonUniformBallotFindLSB: 1033 case SPIRV::OpGroupNonUniformBallotFindMSB: 1034 Reqs.addCapability(SPIRV::Capability::GroupNonUniformBallot); 1035 break; 1036 case SPIRV::OpSubgroupShuffleINTEL: 1037 case SPIRV::OpSubgroupShuffleDownINTEL: 1038 case SPIRV::OpSubgroupShuffleUpINTEL: 1039 case SPIRV::OpSubgroupShuffleXorINTEL: 1040 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) { 1041 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups); 1042 Reqs.addCapability(SPIRV::Capability::SubgroupShuffleINTEL); 1043 } 1044 break; 1045 case SPIRV::OpSubgroupBlockReadINTEL: 1046 case SPIRV::OpSubgroupBlockWriteINTEL: 1047 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) { 1048 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups); 1049 Reqs.addCapability(SPIRV::Capability::SubgroupBufferBlockIOINTEL); 1050 } 1051 break; 1052 case SPIRV::OpSubgroupImageBlockReadINTEL: 1053 case SPIRV::OpSubgroupImageBlockWriteINTEL: 1054 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) { 1055 
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups); 1056 Reqs.addCapability(SPIRV::Capability::SubgroupImageBlockIOINTEL); 1057 } 1058 break; 1059 case SPIRV::OpAssumeTrueKHR: 1060 case SPIRV::OpExpectKHR: 1061 if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_expect_assume)) { 1062 Reqs.addExtension(SPIRV::Extension::SPV_KHR_expect_assume); 1063 Reqs.addCapability(SPIRV::Capability::ExpectAssumeKHR); 1064 } 1065 break; 1066 case SPIRV::OpConstantFunctionPointerINTEL: 1067 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) { 1068 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers); 1069 Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL); 1070 } 1071 break; 1072 case SPIRV::OpGroupNonUniformRotateKHR: 1073 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_subgroup_rotate)) 1074 report_fatal_error("OpGroupNonUniformRotateKHR instruction requires the " 1075 "following SPIR-V extension: SPV_KHR_subgroup_rotate", 1076 false); 1077 Reqs.addExtension(SPIRV::Extension::SPV_KHR_subgroup_rotate); 1078 Reqs.addCapability(SPIRV::Capability::GroupNonUniformRotateKHR); 1079 Reqs.addCapability(SPIRV::Capability::GroupNonUniform); 1080 break; 1081 case SPIRV::OpGroupIMulKHR: 1082 case SPIRV::OpGroupFMulKHR: 1083 case SPIRV::OpGroupBitwiseAndKHR: 1084 case SPIRV::OpGroupBitwiseOrKHR: 1085 case SPIRV::OpGroupBitwiseXorKHR: 1086 case SPIRV::OpGroupLogicalAndKHR: 1087 case SPIRV::OpGroupLogicalOrKHR: 1088 case SPIRV::OpGroupLogicalXorKHR: 1089 if (ST.canUseExtension( 1090 SPIRV::Extension::SPV_KHR_uniform_group_instructions)) { 1091 Reqs.addExtension(SPIRV::Extension::SPV_KHR_uniform_group_instructions); 1092 Reqs.addCapability(SPIRV::Capability::GroupUniformArithmeticKHR); 1093 } 1094 break; 1095 case SPIRV::OpFunctionPointerCallINTEL: 1096 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) { 1097 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers); 1098 
Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL); 1099 } 1100 break; 1101 case SPIRV::OpAtomicFAddEXT: 1102 case SPIRV::OpAtomicFMinEXT: 1103 case SPIRV::OpAtomicFMaxEXT: 1104 AddAtomicFloatRequirements(MI, Reqs, ST); 1105 break; 1106 default: 1107 break; 1108 } 1109 1110 // If we require capability Shader, then we can remove the requirement for 1111 // the BitInstructions capability, since Shader is a superset capability 1112 // of BitInstructions. 1113 Reqs.removeCapabilityIf(SPIRV::Capability::BitInstructions, 1114 SPIRV::Capability::Shader); 1115 } 1116 1117 static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI, 1118 MachineModuleInfo *MMI, const SPIRVSubtarget &ST) { 1119 // Collect requirements for existing instructions. 1120 for (auto F = M.begin(), E = M.end(); F != E; ++F) { 1121 MachineFunction *MF = MMI->getMachineFunction(*F); 1122 if (!MF) 1123 continue; 1124 for (const MachineBasicBlock &MBB : *MF) 1125 for (const MachineInstr &MI : MBB) 1126 addInstrRequirements(MI, MAI.Reqs, ST); 1127 } 1128 // Collect requirements for OpExecutionMode instructions. 
1129 auto Node = M.getNamedMetadata("spirv.ExecutionMode"); 1130 if (Node) { 1131 for (unsigned i = 0; i < Node->getNumOperands(); i++) { 1132 MDNode *MDN = cast<MDNode>(Node->getOperand(i)); 1133 const MDOperand &MDOp = MDN->getOperand(1); 1134 if (auto *CMeta = dyn_cast<ConstantAsMetadata>(MDOp)) { 1135 Constant *C = CMeta->getValue(); 1136 if (ConstantInt *Const = dyn_cast<ConstantInt>(C)) { 1137 auto EM = Const->getZExtValue(); 1138 MAI.Reqs.getAndAddRequirements( 1139 SPIRV::OperandCategory::ExecutionModeOperand, EM, ST); 1140 } 1141 } 1142 } 1143 } 1144 for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) { 1145 const Function &F = *FI; 1146 if (F.isDeclaration()) 1147 continue; 1148 if (F.getMetadata("reqd_work_group_size")) 1149 MAI.Reqs.getAndAddRequirements( 1150 SPIRV::OperandCategory::ExecutionModeOperand, 1151 SPIRV::ExecutionMode::LocalSize, ST); 1152 if (F.getFnAttribute("hlsl.numthreads").isValid()) { 1153 MAI.Reqs.getAndAddRequirements( 1154 SPIRV::OperandCategory::ExecutionModeOperand, 1155 SPIRV::ExecutionMode::LocalSize, ST); 1156 } 1157 if (F.getMetadata("work_group_size_hint")) 1158 MAI.Reqs.getAndAddRequirements( 1159 SPIRV::OperandCategory::ExecutionModeOperand, 1160 SPIRV::ExecutionMode::LocalSizeHint, ST); 1161 if (F.getMetadata("intel_reqd_sub_group_size")) 1162 MAI.Reqs.getAndAddRequirements( 1163 SPIRV::OperandCategory::ExecutionModeOperand, 1164 SPIRV::ExecutionMode::SubgroupSize, ST); 1165 if (F.getMetadata("vec_type_hint")) 1166 MAI.Reqs.getAndAddRequirements( 1167 SPIRV::OperandCategory::ExecutionModeOperand, 1168 SPIRV::ExecutionMode::VecTypeHint, ST); 1169 1170 if (F.hasOptNone() && 1171 ST.canUseExtension(SPIRV::Extension::SPV_INTEL_optnone)) { 1172 // Output OpCapability OptNoneINTEL. 
1173 MAI.Reqs.addExtension(SPIRV::Extension::SPV_INTEL_optnone); 1174 MAI.Reqs.addCapability(SPIRV::Capability::OptNoneINTEL); 1175 } 1176 } 1177 } 1178 1179 static unsigned getFastMathFlags(const MachineInstr &I) { 1180 unsigned Flags = SPIRV::FPFastMathMode::None; 1181 if (I.getFlag(MachineInstr::MIFlag::FmNoNans)) 1182 Flags |= SPIRV::FPFastMathMode::NotNaN; 1183 if (I.getFlag(MachineInstr::MIFlag::FmNoInfs)) 1184 Flags |= SPIRV::FPFastMathMode::NotInf; 1185 if (I.getFlag(MachineInstr::MIFlag::FmNsz)) 1186 Flags |= SPIRV::FPFastMathMode::NSZ; 1187 if (I.getFlag(MachineInstr::MIFlag::FmArcp)) 1188 Flags |= SPIRV::FPFastMathMode::AllowRecip; 1189 if (I.getFlag(MachineInstr::MIFlag::FmReassoc)) 1190 Flags |= SPIRV::FPFastMathMode::Fast; 1191 return Flags; 1192 } 1193 1194 static void handleMIFlagDecoration(MachineInstr &I, const SPIRVSubtarget &ST, 1195 const SPIRVInstrInfo &TII, 1196 SPIRV::RequirementHandler &Reqs) { 1197 if (I.getFlag(MachineInstr::MIFlag::NoSWrap) && TII.canUseNSW(I) && 1198 getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand, 1199 SPIRV::Decoration::NoSignedWrap, ST, Reqs) 1200 .IsSatisfiable) { 1201 buildOpDecorate(I.getOperand(0).getReg(), I, TII, 1202 SPIRV::Decoration::NoSignedWrap, {}); 1203 } 1204 if (I.getFlag(MachineInstr::MIFlag::NoUWrap) && TII.canUseNUW(I) && 1205 getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand, 1206 SPIRV::Decoration::NoUnsignedWrap, ST, 1207 Reqs) 1208 .IsSatisfiable) { 1209 buildOpDecorate(I.getOperand(0).getReg(), I, TII, 1210 SPIRV::Decoration::NoUnsignedWrap, {}); 1211 } 1212 if (!TII.canUseFastMathFlags(I)) 1213 return; 1214 unsigned FMFlags = getFastMathFlags(I); 1215 if (FMFlags == SPIRV::FPFastMathMode::None) 1216 return; 1217 Register DstReg = I.getOperand(0).getReg(); 1218 buildOpDecorate(DstReg, I, TII, SPIRV::Decoration::FPFastMathMode, {FMFlags}); 1219 } 1220 1221 // Walk all functions and add decorations related to MI flags. 
1222 static void addDecorations(const Module &M, const SPIRVInstrInfo &TII, 1223 MachineModuleInfo *MMI, const SPIRVSubtarget &ST, 1224 SPIRV::ModuleAnalysisInfo &MAI) { 1225 for (auto F = M.begin(), E = M.end(); F != E; ++F) { 1226 MachineFunction *MF = MMI->getMachineFunction(*F); 1227 if (!MF) 1228 continue; 1229 for (auto &MBB : *MF) 1230 for (auto &MI : MBB) 1231 handleMIFlagDecoration(MI, ST, TII, MAI.Reqs); 1232 } 1233 } 1234 1235 struct SPIRV::ModuleAnalysisInfo SPIRVModuleAnalysis::MAI; 1236 1237 void SPIRVModuleAnalysis::getAnalysisUsage(AnalysisUsage &AU) const { 1238 AU.addRequired<TargetPassConfig>(); 1239 AU.addRequired<MachineModuleInfoWrapperPass>(); 1240 } 1241 1242 bool SPIRVModuleAnalysis::runOnModule(Module &M) { 1243 SPIRVTargetMachine &TM = 1244 getAnalysis<TargetPassConfig>().getTM<SPIRVTargetMachine>(); 1245 ST = TM.getSubtargetImpl(); 1246 GR = ST->getSPIRVGlobalRegistry(); 1247 TII = ST->getInstrInfo(); 1248 1249 MMI = &getAnalysis<MachineModuleInfoWrapperPass>().getMMI(); 1250 1251 setBaseInfo(M); 1252 1253 addDecorations(M, *TII, MMI, *ST, MAI); 1254 1255 collectReqs(M, MAI, MMI, *ST); 1256 1257 // Process type/const/global var/func decl instructions, number their 1258 // destination registers from 0 to N, collect Extensions and Capabilities. 1259 processDefInstrs(M); 1260 1261 // Number rest of registers from N+1 onwards. 1262 numberRegistersGlobally(M); 1263 1264 // Update references to OpFunction instructions to use Global Registers 1265 if (GR->hasConstFunPtr()) 1266 collectFuncPtrs(); 1267 1268 // Collect OpName, OpEntryPoint, OpDecorate etc, process other instructions. 1269 processOtherInstrs(M); 1270 1271 // If there are no entry points, we need the Linkage capability. 1272 if (MAI.MS[SPIRV::MB_EntryPoints].empty()) 1273 MAI.Reqs.addCapability(SPIRV::Capability::Linkage); 1274 1275 return false; 1276 } 1277