//===- SPIRVModuleAnalysis.cpp - analysis of global instrs & regs - C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// The analysis collects instructions that should be output at the module level
// and performs the global register numbering.
//
// The results of this analysis are used in AsmPrinter to rename registers
// globally and to output required instructions at the module level.
//
//===----------------------------------------------------------------------===//

#include "SPIRVModuleAnalysis.h"
#include "SPIRV.h"
#include "SPIRVSubtarget.h"
#include "SPIRVTargetMachine.h"
#include "SPIRVUtils.h"
#include "TargetInfo/SPIRVTargetInfo.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
#include "llvm/CodeGen/TargetPassConfig.h"

using namespace llvm;

#define DEBUG_TYPE "spirv-module-analysis"

static cl::opt<bool>
    SPVDumpDeps("spv-dump-deps",
                cl::desc("Dump MIR with SPIR-V dependencies info"),
                cl::Optional, cl::init(false));

char llvm::SPIRVModuleAnalysis::ID = 0;

namespace llvm {
void initializeSPIRVModuleAnalysisPass(PassRegistry &);
} // namespace llvm

INITIALIZE_PASS(SPIRVModuleAnalysis, DEBUG_TYPE, "SPIRV module analysis", true,
                true)

// Retrieve an unsigned from an MDNode with a list of them as operands.
static unsigned getMetadataUInt(MDNode *MdNode, unsigned OpIndex,
                                unsigned DefaultVal = 0) {
  if (MdNode && OpIndex < MdNode->getNumOperands()) {
    const auto &Op = MdNode->getOperand(OpIndex);
    return mdconst::extract<ConstantInt>(Op)->getZExtValue();
  }
  return DefaultVal;
}

static SPIRV::Requirements
getSymbolicOperandRequirements(SPIRV::OperandCategory::OperandCategory Category,
                               unsigned i, const SPIRVSubtarget &ST,
                               SPIRV::RequirementHandler &Reqs) {
  unsigned ReqMinVer = getSymbolicOperandMinVersion(Category, i);
  unsigned ReqMaxVer = getSymbolicOperandMaxVersion(Category, i);
  unsigned TargetVer = ST.getSPIRVVersion();
  bool MinVerOK = !ReqMinVer || !TargetVer || TargetVer >= ReqMinVer;
  bool MaxVerOK = !ReqMaxVer || !TargetVer || TargetVer <= ReqMaxVer;
  CapabilityList ReqCaps = getSymbolicOperandCapabilities(Category, i);
  ExtensionList ReqExts = getSymbolicOperandExtensions(Category, i);
  if (ReqCaps.empty()) {
    if (ReqExts.empty()) {
      if (MinVerOK && MaxVerOK)
        return {true, {}, {}, ReqMinVer, ReqMaxVer};
      return {false, {}, {}, 0, 0};
    }
  } else if (MinVerOK && MaxVerOK) {
    for (auto Cap : ReqCaps) { // Only need 1 of the capabilities to work.
      if (Reqs.isCapabilityAvailable(Cap))
        return {true, {Cap}, {}, ReqMinVer, ReqMaxVer};
    }
  }
  // If there are no capabilities, or we can't satisfy the version or
  // capability requirements, use the list of extensions (if the subtarget
  // can handle them all).
  if (llvm::all_of(ReqExts, [&ST](const SPIRV::Extension::Extension &Ext) {
        return ST.canUseExtension(Ext);
      })) {
    return {true, {}, ReqExts, 0, 0}; // TODO: add versions to extensions.
  }
  return {false, {}, {}, 0, 0};
}

void SPIRVModuleAnalysis::setBaseInfo(const Module &M) {
  MAI.MaxID = 0;
  for (int i = 0; i < SPIRV::NUM_MODULE_SECTIONS; i++)
    MAI.MS[i].clear();
  MAI.RegisterAliasTable.clear();
  MAI.InstrsToDelete.clear();
  MAI.FuncMap.clear();
  MAI.GlobalVarList.clear();
  MAI.ExtInstSetMap.clear();
  MAI.Reqs.clear();
  MAI.Reqs.initAvailableCapabilities(*ST);

  // TODO: determine memory model and source language from the configuration.
  if (auto MemModel = M.getNamedMetadata("spirv.MemoryModel")) {
    auto MemMD = MemModel->getOperand(0);
    MAI.Addr = static_cast<SPIRV::AddressingModel::AddressingModel>(
        getMetadataUInt(MemMD, 0));
    MAI.Mem =
        static_cast<SPIRV::MemoryModel::MemoryModel>(getMetadataUInt(MemMD, 1));
  } else {
    MAI.Mem = SPIRV::MemoryModel::OpenCL;
    unsigned PtrSize = ST->getPointerSize();
    MAI.Addr = PtrSize == 32   ? SPIRV::AddressingModel::Physical32
               : PtrSize == 64 ? SPIRV::AddressingModel::Physical64
                               : SPIRV::AddressingModel::Logical;
  }
  // Get the OpenCL version number from metadata.
  // TODO: support other source languages.
  if (auto VerNode = M.getNamedMetadata("opencl.ocl.version")) {
    MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_C;
    // Construct the version literal in accordance with SPIRV-LLVM-Translator,
    // i.e. (Major * 100 + Minor) * 1000 + Revision; e.g. OpenCL 2.0 -> 200000.
    // TODO: support multiple OCL version metadata.
    assert(VerNode->getNumOperands() > 0 && "Invalid SPIR");
    auto VersionMD = VerNode->getOperand(0);
    unsigned MajorNum = getMetadataUInt(VersionMD, 0, 2);
    unsigned MinorNum = getMetadataUInt(VersionMD, 1);
    unsigned RevNum = getMetadataUInt(VersionMD, 2);
    MAI.SrcLangVersion = (MajorNum * 100 + MinorNum) * 1000 + RevNum;
  } else {
    MAI.SrcLang = SPIRV::SourceLanguage::Unknown;
    MAI.SrcLangVersion = 0;
  }

  if (auto ExtNode = M.getNamedMetadata("opencl.used.extensions")) {
    for (unsigned I = 0, E = ExtNode->getNumOperands(); I != E; ++I) {
      MDNode *MD = ExtNode->getOperand(I);
      if (!MD || MD->getNumOperands() == 0)
        continue;
      for (unsigned J = 0, N = MD->getNumOperands(); J != N; ++J)
        MAI.SrcExt.insert(cast<MDString>(MD->getOperand(J))->getString());
    }
  }

  // Update required capabilities for this memory model, addressing model and
  // source language.
  MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand,
                                 MAI.Mem, *ST);
  MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::SourceLanguageOperand,
                                 MAI.SrcLang, *ST);
  MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,
                                 MAI.Addr, *ST);

  // TODO: check if it's required by default.
  MAI.ExtInstSetMap[static_cast<unsigned>(SPIRV::InstructionSet::OpenCL_std)] =
      Register::index2VirtReg(MAI.getNextID());
}

// Collect MI which defines the register in the given machine function.
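// The defining instruction is marked as skipped for per-function emission and,
// when DoInsert is set, appended to the given module section so that it is
// emitted once at the module level instead.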
static void collectDefInstr(Register Reg, const MachineFunction *MF,
                            SPIRV::ModuleAnalysisInfo *MAI,
                            SPIRV::ModuleSectionType MSType,
                            bool DoInsert = true) {
  assert(MAI->hasRegisterAlias(MF, Reg) && "Cannot find register alias");
  MachineInstr *MI = MF->getRegInfo().getUniqueVRegDef(Reg);
  assert(MI && "There should be an instruction that defines the register");
  MAI->setSkipEmission(MI);
  if (DoInsert)
    MAI->MS[MSType].push_back(MI);
}

void SPIRVModuleAnalysis::collectGlobalEntities(
    const std::vector<SPIRV::DTSortableEntry *> &DepsGraph,
    SPIRV::ModuleSectionType MSType,
    std::function<bool(const SPIRV::DTSortableEntry *)> Pred,
    bool UsePreOrder = false) {
  DenseSet<const SPIRV::DTSortableEntry *> Visited;
  for (const auto *E : DepsGraph) {
    std::function<void(const SPIRV::DTSortableEntry *)> RecHoistUtil;
    // NOTE: here we prefer the recursive approach over the iterative one
    // because we don't expect dep chains long enough to cause a stack
    // overflow.
    RecHoistUtil = [MSType, UsePreOrder, &Visited, &Pred,
                    &RecHoistUtil](const SPIRV::DTSortableEntry *E) {
      if (Visited.count(E) || !Pred(E))
        return;
      Visited.insert(E);

      // Traversing the deps graph in post-order allows us to avoid a separate
      // register-alias preprocessing step, but pre-order is required for
      // correct processing of function declarations and their arguments.
      if (!UsePreOrder)
        for (auto *S : E->getDeps())
          RecHoistUtil(S);

      Register GlobalReg = Register::index2VirtReg(MAI.getNextID());
      bool IsFirst = true;
      for (auto &U : *E) {
        const MachineFunction *MF = U.first;
        Register Reg = U.second;
        MAI.setRegisterAlias(MF, Reg, GlobalReg);
        if (!MF->getRegInfo().getUniqueVRegDef(Reg))
          continue;
        collectDefInstr(Reg, MF, &MAI, MSType, IsFirst);
        IsFirst = false;
        if (E->getIsGV())
          MAI.GlobalVarList.push_back(MF->getRegInfo().getUniqueVRegDef(Reg));
      }

      if (UsePreOrder)
        for (auto *S : E->getDeps())
          RecHoistUtil(S);
    };
    RecHoistUtil(E);
  }
}

// The function initializes the global register alias table for types, consts,
// global vars and func decls and collects these instructions for output at
// the module level. It also collects explicit OpExtension/OpCapability
// instructions.
void SPIRVModuleAnalysis::processDefInstrs(const Module &M) {
  std::vector<SPIRV::DTSortableEntry *> DepsGraph;

  GR->buildDepsGraph(DepsGraph, SPVDumpDeps ? MMI : nullptr);

  collectGlobalEntities(
      DepsGraph, SPIRV::MB_TypeConstVars,
      [](const SPIRV::DTSortableEntry *E) { return !E->getIsFunc(); });

  for (auto F = M.begin(), E = M.end(); F != E; ++F) {
    MachineFunction *MF = MMI->getMachineFunction(*F);
    if (!MF)
      continue;
    // Iterate through and collect OpExtension/OpCapability instructions.
    for (MachineBasicBlock &MBB : *MF) {
      for (MachineInstr &MI : MBB) {
        if (MI.getOpcode() == SPIRV::OpExtension) {
          // Here, OpExtension just has a single enum operand, not a string.
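          // The extension is recorded in the module requirements and the
          // instruction itself is skipped during per-function emission; the
          // module-level OpExtension is emitted later from those requirements.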
          auto Ext = SPIRV::Extension::Extension(MI.getOperand(0).getImm());
          MAI.Reqs.addExtension(Ext);
          MAI.setSkipEmission(&MI);
        } else if (MI.getOpcode() == SPIRV::OpCapability) {
          auto Cap = SPIRV::Capability::Capability(MI.getOperand(0).getImm());
          MAI.Reqs.addCapability(Cap);
          MAI.setSkipEmission(&MI);
        }
      }
    }
  }

  collectGlobalEntities(
      DepsGraph, SPIRV::MB_ExtFuncDecls,
      [](const SPIRV::DTSortableEntry *E) { return E->getIsFunc(); }, true);
}

// True if there is an instruction in the MS list with all the same operands as
// the given instruction has (after the given starting index).
// TODO: maybe it needs to check Opcodes too.
static bool findSameInstrInMS(const MachineInstr &A,
                              SPIRV::ModuleSectionType MSType,
                              SPIRV::ModuleAnalysisInfo &MAI,
                              unsigned StartOpIndex = 0) {
  for (const auto *B : MAI.MS[MSType]) {
    const unsigned NumAOps = A.getNumOperands();
    if (NumAOps != B->getNumOperands() || A.getNumDefs() != B->getNumDefs())
      continue;
    bool AllOpsMatch = true;
    for (unsigned i = StartOpIndex; i < NumAOps && AllOpsMatch; ++i) {
      if (A.getOperand(i).isReg() && B->getOperand(i).isReg()) {
        Register RegA = A.getOperand(i).getReg();
        Register RegB = B->getOperand(i).getReg();
        AllOpsMatch = MAI.getRegisterAlias(A.getMF(), RegA) ==
                      MAI.getRegisterAlias(B->getMF(), RegB);
      } else {
        AllOpsMatch = A.getOperand(i).isIdenticalTo(B->getOperand(i));
      }
    }
    if (AllOpsMatch)
      return true;
  }
  return false;
}

// Look for IDs declared with Import linkage, and map the corresponding
// function to the register defining that ID (which will usually be the result
// of an OpFunction). This lets us call externally imported functions using
// the correct ID registers.
void SPIRVModuleAnalysis::collectFuncNames(MachineInstr &MI,
                                           const Function *F) {
  if (MI.getOpcode() == SPIRV::OpDecorate) {
    // If it's got Import linkage.
    auto Dec = MI.getOperand(1).getImm();
    if (Dec == static_cast<unsigned>(SPIRV::Decoration::LinkageAttributes)) {
      auto Lnk = MI.getOperand(MI.getNumOperands() - 1).getImm();
      if (Lnk == static_cast<unsigned>(SPIRV::LinkageType::Import)) {
        // Map imported function name to function ID register.
        const Function *ImportedFunc =
            F->getParent()->getFunction(getStringImm(MI, 2));
        Register Target = MI.getOperand(0).getReg();
        MAI.FuncMap[ImportedFunc] = MAI.getRegisterAlias(MI.getMF(), Target);
      }
    }
  } else if (MI.getOpcode() == SPIRV::OpFunction) {
    // Record all internal OpFunction declarations.
    Register Reg = MI.defs().begin()->getReg();
    Register GlobalReg = MAI.getRegisterAlias(MI.getMF(), Reg);
    assert(GlobalReg.isValid());
    MAI.FuncMap[F] = GlobalReg;
  }
}

// Collect the given instruction in the specified MS. We assume global register
// numbering has already occurred by this point, so we can directly compare
// register operands when detecting duplicates.
static void collectOtherInstr(MachineInstr &MI, SPIRV::ModuleAnalysisInfo &MAI,
                              SPIRV::ModuleSectionType MSType,
                              bool Append = true) {
  MAI.setSkipEmission(&MI);
  if (findSameInstrInMS(MI, MSType, MAI))
    return; // Found a duplicate, so don't add it.
  // No duplicates, so add it.
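  // Append == false prepends the instruction instead, which is used for
  // instructions that must precede others in the section (e.g.
  // OpTypeForwardPointer in MB_TypeConstVars).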
  if (Append)
    MAI.MS[MSType].push_back(&MI);
  else
    MAI.MS[MSType].insert(MAI.MS[MSType].begin(), &MI);
}

// Some global instructions make reference to function-local ID regs, so they
// cannot be correctly collected until these registers are globally numbered.
void SPIRVModuleAnalysis::processOtherInstrs(const Module &M) {
  for (auto F = M.begin(), E = M.end(); F != E; ++F) {
    if ((*F).isDeclaration())
      continue;
    MachineFunction *MF = MMI->getMachineFunction(*F);
    assert(MF);
    for (MachineBasicBlock &MBB : *MF)
      for (MachineInstr &MI : MBB) {
        if (MAI.getSkipEmission(&MI))
          continue;
        const unsigned OpCode = MI.getOpcode();
        if (OpCode == SPIRV::OpName || OpCode == SPIRV::OpMemberName) {
          collectOtherInstr(MI, MAI, SPIRV::MB_DebugNames);
        } else if (OpCode == SPIRV::OpEntryPoint) {
          collectOtherInstr(MI, MAI, SPIRV::MB_EntryPoints);
        } else if (TII->isDecorationInstr(MI)) {
          collectOtherInstr(MI, MAI, SPIRV::MB_Annotations);
          collectFuncNames(MI, &*F);
        } else if (TII->isConstantInstr(MI)) {
          // OpSpecConstant* instructions are not in DT yet,
          // but they need to be collected anyway.
          collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars);
        } else if (OpCode == SPIRV::OpFunction) {
          collectFuncNames(MI, &*F);
        } else if (OpCode == SPIRV::OpTypeForwardPointer) {
          collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, false);
        }
      }
  }
}

// Number registers in all functions globally from 0 onwards and store
// the result in the global register alias table. Some registers are already
// numbered in collectGlobalEntities.
void SPIRVModuleAnalysis::numberRegistersGlobally(const Module &M) {
  for (auto F = M.begin(), E = M.end(); F != E; ++F) {
    if ((*F).isDeclaration())
      continue;
    MachineFunction *MF = MMI->getMachineFunction(*F);
    assert(MF);
    for (MachineBasicBlock &MBB : *MF) {
      for (MachineInstr &MI : MBB) {
        for (MachineOperand &Op : MI.operands()) {
          if (!Op.isReg())
            continue;
          Register Reg = Op.getReg();
          if (MAI.hasRegisterAlias(MF, Reg))
            continue;
          Register NewReg = Register::index2VirtReg(MAI.getNextID());
          MAI.setRegisterAlias(MF, Reg, NewReg);
        }
        if (MI.getOpcode() != SPIRV::OpExtInst)
          continue;
        auto Set = MI.getOperand(2).getImm();
        if (MAI.ExtInstSetMap.find(Set) == MAI.ExtInstSetMap.end())
          MAI.ExtInstSetMap[Set] = Register::index2VirtReg(MAI.getNextID());
      }
    }
  }
}

// RequirementHandler implementations.
void SPIRV::RequirementHandler::getAndAddRequirements(
    SPIRV::OperandCategory::OperandCategory Category, uint32_t i,
    const SPIRVSubtarget &ST) {
  addRequirements(getSymbolicOperandRequirements(Category, i, ST, *this));
}

void SPIRV::RequirementHandler::pruneCapabilities(
    const CapabilityList &ToPrune) {
  for (const auto &Cap : ToPrune) {
    AllCaps.insert(Cap);
    auto FoundIndex = std::find(MinimalCaps.begin(), MinimalCaps.end(), Cap);
    if (FoundIndex != MinimalCaps.end())
      MinimalCaps.erase(FoundIndex);
    CapabilityList ImplicitDecls =
        getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);
    pruneCapabilities(ImplicitDecls);
  }
}

void SPIRV::RequirementHandler::addCapabilities(const CapabilityList &ToAdd) {
  for (const auto &Cap : ToAdd) {
    bool IsNewlyInserted = AllCaps.insert(Cap).second;
    if (!IsNewlyInserted) // Don't re-add if it's already been declared.
      continue;
    CapabilityList ImplicitDecls =
        getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);
    pruneCapabilities(ImplicitDecls);
    MinimalCaps.push_back(Cap);
  }
}

void SPIRV::RequirementHandler::addRequirements(
    const SPIRV::Requirements &Req) {
  if (!Req.IsSatisfiable)
    report_fatal_error("Adding SPIR-V requirements this target can't satisfy.");

  if (Req.Cap.has_value())
    addCapabilities({Req.Cap.value()});

  addExtensions(Req.Exts);

  if (Req.MinVer) {
    if (MaxVersion && Req.MinVer > MaxVersion) {
      LLVM_DEBUG(dbgs() << "Conflicting version requirements: >= " << Req.MinVer
                        << " and <= " << MaxVersion << "\n");
      report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");
    }

    if (MinVersion == 0 || Req.MinVer > MinVersion)
      MinVersion = Req.MinVer;
  }

  if (Req.MaxVer) {
    if (MinVersion && Req.MaxVer < MinVersion) {
      LLVM_DEBUG(dbgs() << "Conflicting version requirements: <= " << Req.MaxVer
                        << " and >= " << MinVersion << "\n");
      report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");
    }

    if (MaxVersion == 0 || Req.MaxVer < MaxVersion)
      MaxVersion = Req.MaxVer;
  }
}

void SPIRV::RequirementHandler::checkSatisfiable(
    const SPIRVSubtarget &ST) const {
  // Report as many errors as possible before aborting the compilation.
  bool IsSatisfiable = true;
  auto TargetVer = ST.getSPIRVVersion();

  if (MaxVersion && TargetVer && MaxVersion < TargetVer) {
    LLVM_DEBUG(
        dbgs() << "Target SPIR-V version too high for required features\n"
               << "Required max version: " << MaxVersion << " target version "
               << TargetVer << "\n");
    IsSatisfiable = false;
  }

  if (MinVersion && TargetVer && MinVersion > TargetVer) {
    LLVM_DEBUG(dbgs() << "Target SPIR-V version too low for required features\n"
                      << "Required min version: " << MinVersion
                      << " target version " << TargetVer << "\n");
    IsSatisfiable = false;
  }

  if (MinVersion && MaxVersion && MinVersion > MaxVersion) {
    LLVM_DEBUG(
        dbgs()
        << "Version is too low for some features and too high for others.\n"
        << "Required SPIR-V min version: " << MinVersion
        << " required SPIR-V max version " << MaxVersion << "\n");
    IsSatisfiable = false;
  }

  for (auto Cap : MinimalCaps) {
    if (AvailableCaps.contains(Cap))
      continue;
    LLVM_DEBUG(dbgs() << "Capability not supported: "
                      << getSymbolicOperandMnemonic(
                             OperandCategory::CapabilityOperand, Cap)
                      << "\n");
    IsSatisfiable = false;
  }

  for (auto Ext : AllExtensions) {
    if (ST.canUseExtension(Ext))
      continue;
    LLVM_DEBUG(dbgs() << "Extension not supported: "
                      << getSymbolicOperandMnemonic(
                             OperandCategory::ExtensionOperand, Ext)
                      << "\n");
    IsSatisfiable = false;
  }

  if (!IsSatisfiable)
    report_fatal_error("Unable to meet SPIR-V requirements for this target.");
}

// Add the given capabilities and all their implicitly declared capabilities
// too.
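// For example, making Int64Atomics available also makes Int64 available,
// since the former implicitly declares the latter in the SPIR-V spec.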
void SPIRV::RequirementHandler::addAvailableCaps(const CapabilityList &ToAdd) {
  for (const auto Cap : ToAdd)
    if (AvailableCaps.insert(Cap).second)
      addAvailableCaps(getSymbolicOperandCapabilities(
          SPIRV::OperandCategory::CapabilityOperand, Cap));
}

namespace llvm {
namespace SPIRV {
void RequirementHandler::initAvailableCapabilities(const SPIRVSubtarget &ST) {
  // TODO: implement this for targets other than OpenCL.
  if (!ST.isOpenCLEnv())
    return;
  // Add the min requirements for different OpenCL and SPIR-V versions.
  addAvailableCaps({Capability::Addresses, Capability::Float16Buffer,
                    Capability::Int16, Capability::Int8, Capability::Kernel,
                    Capability::Linkage, Capability::Vector16,
                    Capability::Groups, Capability::GenericPointer,
                    Capability::Shader});
  if (ST.hasOpenCLFullProfile())
    addAvailableCaps({Capability::Int64, Capability::Int64Atomics});
  if (ST.hasOpenCLImageSupport()) {
    addAvailableCaps({Capability::ImageBasic, Capability::LiteralSampler,
                      Capability::Image1D, Capability::SampledBuffer,
                      Capability::ImageBuffer});
    if (ST.isAtLeastOpenCLVer(20))
      addAvailableCaps({Capability::ImageReadWrite});
  }
  if (ST.isAtLeastSPIRVVer(11) && ST.isAtLeastOpenCLVer(22))
    addAvailableCaps({Capability::SubgroupDispatch, Capability::PipeStorage});
  if (ST.isAtLeastSPIRVVer(13))
    addAvailableCaps({Capability::GroupNonUniform,
                      Capability::GroupNonUniformVote,
                      Capability::GroupNonUniformArithmetic,
                      Capability::GroupNonUniformBallot,
                      Capability::GroupNonUniformClustered,
                      Capability::GroupNonUniformShuffle,
                      Capability::GroupNonUniformShuffleRelative});
  if (ST.isAtLeastSPIRVVer(14))
    addAvailableCaps({Capability::DenormPreserve, Capability::DenormFlushToZero,
                      Capability::SignedZeroInfNanPreserve,
                      Capability::RoundingModeRTE,
                      Capability::RoundingModeRTZ});
  // TODO: verify if this needs some checks.
  addAvailableCaps({Capability::Float16, Capability::Float64});

  // Add cap for SPV_INTEL_optnone.
  // FIXME: this should be added only if the target has the extension.
  addAvailableCaps({Capability::OptNoneINTEL});

  // TODO: add OpenCL extensions.
}
} // namespace SPIRV
} // namespace llvm

// Add the required capabilities from a decoration instruction (including
// BuiltIns).
static void addOpDecorateReqs(const MachineInstr &MI, unsigned DecIndex,
                              SPIRV::RequirementHandler &Reqs,
                              const SPIRVSubtarget &ST) {
  int64_t DecOp = MI.getOperand(DecIndex).getImm();
  auto Dec = static_cast<SPIRV::Decoration::Decoration>(DecOp);
  Reqs.addRequirements(getSymbolicOperandRequirements(
      SPIRV::OperandCategory::DecorationOperand, Dec, ST, Reqs));

  if (Dec == SPIRV::Decoration::BuiltIn) {
    int64_t BuiltInOp = MI.getOperand(DecIndex + 1).getImm();
    auto BuiltIn = static_cast<SPIRV::BuiltIn::BuiltIn>(BuiltInOp);
    Reqs.addRequirements(getSymbolicOperandRequirements(
        SPIRV::OperandCategory::BuiltInOperand, BuiltIn, ST, Reqs));
  }
}

// Add requirements for image handling.
static void addOpTypeImageReqs(const MachineInstr &MI,
                               SPIRV::RequirementHandler &Reqs,
                               const SPIRVSubtarget &ST) {
  assert(MI.getNumOperands() >= 8 && "Insufficient operands for OpTypeImage");
  // The operand indices used here are based on the OpTypeImage layout, which
  // the MachineInstr follows as well.
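  // That layout is: 0 - result, 1 - sampled type, 2 - Dim, 3 - Depth,
  // 4 - Arrayed, 5 - MS, 6 - Sampled, 7 - Image Format,
  // 8 - optional Access Qualifier.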
  int64_t ImgFormatOp = MI.getOperand(7).getImm();
  auto ImgFormat = static_cast<SPIRV::ImageFormat::ImageFormat>(ImgFormatOp);
  Reqs.getAndAddRequirements(SPIRV::OperandCategory::ImageFormatOperand,
                             ImgFormat, ST);

  bool IsArrayed = MI.getOperand(4).getImm() == 1;
  bool IsMultisampled = MI.getOperand(5).getImm() == 1;
  bool NoSampler = MI.getOperand(6).getImm() == 2;
  // Add dimension requirements.
  assert(MI.getOperand(2).isImm());
  switch (MI.getOperand(2).getImm()) {
  case SPIRV::Dim::DIM_1D:
    Reqs.addRequirements(NoSampler ? SPIRV::Capability::Image1D
                                   : SPIRV::Capability::Sampled1D);
    break;
  case SPIRV::Dim::DIM_2D:
    if (IsMultisampled && NoSampler)
      Reqs.addRequirements(SPIRV::Capability::ImageMSArray);
    break;
  case SPIRV::Dim::DIM_Cube:
    Reqs.addRequirements(SPIRV::Capability::Shader);
    if (IsArrayed)
      Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageCubeArray
                                     : SPIRV::Capability::SampledCubeArray);
    break;
  case SPIRV::Dim::DIM_Rect:
    Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageRect
                                   : SPIRV::Capability::SampledRect);
    break;
  case SPIRV::Dim::DIM_Buffer:
    Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageBuffer
                                   : SPIRV::Capability::SampledBuffer);
    break;
  case SPIRV::Dim::DIM_SubpassData:
    Reqs.addRequirements(SPIRV::Capability::InputAttachment);
    break;
  }

  // Has optional access qualifier.
  // TODO: check if it's OpenCL's kernel.
  if (MI.getNumOperands() > 8 &&
      MI.getOperand(8).getImm() == SPIRV::AccessQualifier::ReadWrite)
    Reqs.addRequirements(SPIRV::Capability::ImageReadWrite);
  else
    Reqs.addRequirements(SPIRV::Capability::ImageBasic);
}

void addInstrRequirements(const MachineInstr &MI,
                          SPIRV::RequirementHandler &Reqs,
                          const SPIRVSubtarget &ST) {
  switch (MI.getOpcode()) {
  case SPIRV::OpMemoryModel: {
    int64_t Addr = MI.getOperand(0).getImm();
    Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,
                               Addr, ST);
    int64_t Mem = MI.getOperand(1).getImm();
    Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand, Mem,
                               ST);
    break;
  }
  case SPIRV::OpEntryPoint: {
    int64_t Exe = MI.getOperand(0).getImm();
    Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModelOperand,
                               Exe, ST);
    break;
  }
  case SPIRV::OpExecutionMode:
  case SPIRV::OpExecutionModeId: {
    int64_t Exe = MI.getOperand(1).getImm();
    Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModeOperand,
                               Exe, ST);
    break;
  }
  case SPIRV::OpTypeMatrix:
    Reqs.addCapability(SPIRV::Capability::Matrix);
    break;
  case SPIRV::OpTypeInt: {
    unsigned BitWidth = MI.getOperand(1).getImm();
    if (BitWidth == 64)
      Reqs.addCapability(SPIRV::Capability::Int64);
    else if (BitWidth == 16)
      Reqs.addCapability(SPIRV::Capability::Int16);
    else if (BitWidth == 8)
      Reqs.addCapability(SPIRV::Capability::Int8);
    break;
  }
  case SPIRV::OpTypeFloat: {
    unsigned BitWidth = MI.getOperand(1).getImm();
    if (BitWidth == 64)
      Reqs.addCapability(SPIRV::Capability::Float64);
    else if (BitWidth == 16)
      Reqs.addCapability(SPIRV::Capability::Float16);
    break;
  }
  case SPIRV::OpTypeVector: {
    unsigned NumComponents = MI.getOperand(2).getImm();
    if (NumComponents == 8 || NumComponents == 16)
      Reqs.addCapability(SPIRV::Capability::Vector16);
    break;
  }
  case SPIRV::OpTypePointer: {
    auto SC = MI.getOperand(1).getImm();
    Reqs.getAndAddRequirements(SPIRV::OperandCategory::StorageClassOperand, SC,
                               ST);
    // If it's a pointer to a 16-bit float type, add the Float16Buffer
    // capability.
    assert(MI.getOperand(2).isReg());
    const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
    SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(2).getReg());
    if (TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
        TypeDef->getOperand(1).getImm() == 16)
      Reqs.addCapability(SPIRV::Capability::Float16Buffer);
    break;
  }
  case SPIRV::OpBitReverse:
  case SPIRV::OpTypeRuntimeArray:
    Reqs.addCapability(SPIRV::Capability::Shader);
    break;
  case SPIRV::OpTypeOpaque:
  case SPIRV::OpTypeEvent:
    Reqs.addCapability(SPIRV::Capability::Kernel);
    break;
  case SPIRV::OpTypePipe:
  case SPIRV::OpTypeReserveId:
    Reqs.addCapability(SPIRV::Capability::Pipes);
    break;
  case SPIRV::OpTypeDeviceEvent:
  case SPIRV::OpTypeQueue:
  case SPIRV::OpBuildNDRange:
    Reqs.addCapability(SPIRV::Capability::DeviceEnqueue);
    break;
  case SPIRV::OpDecorate:
  case SPIRV::OpDecorateId:
  case SPIRV::OpDecorateString:
    addOpDecorateReqs(MI, 1, Reqs, ST);
    break;
  case SPIRV::OpMemberDecorate:
  case SPIRV::OpMemberDecorateString:
    addOpDecorateReqs(MI, 2, Reqs, ST);
    break;
  case SPIRV::OpInBoundsPtrAccessChain:
    Reqs.addCapability(SPIRV::Capability::Addresses);
    break;
  case SPIRV::OpConstantSampler:
    Reqs.addCapability(SPIRV::Capability::LiteralSampler);
    break;
  case SPIRV::OpTypeImage:
    addOpTypeImageReqs(MI, Reqs, ST);
    break;
  case SPIRV::OpTypeSampler:
    Reqs.addCapability(SPIRV::Capability::ImageBasic);
    break;
  case SPIRV::OpTypeForwardPointer:
    // TODO: check if it's OpenCL's kernel.
    Reqs.addCapability(SPIRV::Capability::Addresses);
    break;
  case SPIRV::OpAtomicFlagTestAndSet:
  case SPIRV::OpAtomicLoad:
  case SPIRV::OpAtomicStore:
  case SPIRV::OpAtomicExchange:
  case SPIRV::OpAtomicCompareExchange:
  case SPIRV::OpAtomicIIncrement:
  case SPIRV::OpAtomicIDecrement:
  case SPIRV::OpAtomicIAdd:
  case SPIRV::OpAtomicISub:
  case SPIRV::OpAtomicUMin:
  case SPIRV::OpAtomicUMax:
  case SPIRV::OpAtomicSMin:
  case SPIRV::OpAtomicSMax:
  case SPIRV::OpAtomicAnd:
  case SPIRV::OpAtomicOr:
  case SPIRV::OpAtomicXor: {
    const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
    const MachineInstr *InstrPtr = &MI;
    if (MI.getOpcode() == SPIRV::OpAtomicStore) {
      assert(MI.getOperand(3).isReg());
      InstrPtr = MRI.getVRegDef(MI.getOperand(3).getReg());
      assert(InstrPtr && "Unexpected type instruction for OpAtomicStore");
    }
    assert(InstrPtr->getOperand(1).isReg() && "Unexpected operand in atomic");
    Register TypeReg = InstrPtr->getOperand(1).getReg();
    SPIRVType *TypeDef = MRI.getVRegDef(TypeReg);
    if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
      unsigned BitWidth = TypeDef->getOperand(1).getImm();
      if (BitWidth == 64)
        Reqs.addCapability(SPIRV::Capability::Int64Atomics);
    }
    break;
  }
  case SPIRV::OpGroupNonUniformIAdd:
  case SPIRV::OpGroupNonUniformFAdd:
  case SPIRV::OpGroupNonUniformIMul:
  case SPIRV::OpGroupNonUniformFMul:
  case SPIRV::OpGroupNonUniformSMin:
  case SPIRV::OpGroupNonUniformUMin:
  case SPIRV::OpGroupNonUniformFMin:
  case SPIRV::OpGroupNonUniformSMax:
  case SPIRV::OpGroupNonUniformUMax:
  case SPIRV::OpGroupNonUniformFMax:
  case SPIRV::OpGroupNonUniformBitwiseAnd:
  case SPIRV::OpGroupNonUniformBitwiseOr:
  case SPIRV::OpGroupNonUniformBitwiseXor:
  case SPIRV::OpGroupNonUniformLogicalAnd:
  case SPIRV::OpGroupNonUniformLogicalOr:
  case SPIRV::OpGroupNonUniformLogicalXor: {
    assert(MI.getOperand(3).isImm());
    int64_t GroupOp = MI.getOperand(3).getImm();
    switch (GroupOp) {
    case SPIRV::GroupOperation::Reduce:
    case SPIRV::GroupOperation::InclusiveScan:
    case SPIRV::GroupOperation::ExclusiveScan:
      Reqs.addCapability(SPIRV::Capability::Kernel);
      Reqs.addCapability(SPIRV::Capability::GroupNonUniformArithmetic);
      Reqs.addCapability(SPIRV::Capability::GroupNonUniformBallot);
      break;
    case SPIRV::GroupOperation::ClusteredReduce:
      Reqs.addCapability(SPIRV::Capability::GroupNonUniformClustered);
      break;
    case SPIRV::GroupOperation::PartitionedReduceNV:
    case SPIRV::GroupOperation::PartitionedInclusiveScanNV:
    case SPIRV::GroupOperation::PartitionedExclusiveScanNV:
      Reqs.addCapability(SPIRV::Capability::GroupNonUniformPartitionedNV);
      break;
    }
    break;
  }
  case SPIRV::OpGroupNonUniformShuffle:
  case SPIRV::OpGroupNonUniformShuffleXor:
    Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffle);
    break;
  case SPIRV::OpGroupNonUniformShuffleUp:
  case SPIRV::OpGroupNonUniformShuffleDown:
    Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffleRelative);
    break;
  case SPIRV::OpGroupAll:
  case SPIRV::OpGroupAny:
  case SPIRV::OpGroupBroadcast:
  case SPIRV::OpGroupIAdd:
  case SPIRV::OpGroupFAdd:
  case SPIRV::OpGroupFMin:
  case SPIRV::OpGroupUMin:
  case SPIRV::OpGroupSMin:
  case SPIRV::OpGroupFMax:
  case SPIRV::OpGroupUMax:
  case SPIRV::OpGroupSMax:
    Reqs.addCapability(SPIRV::Capability::Groups);
    break;
  case SPIRV::OpGroupNonUniformElect:
    Reqs.addCapability(SPIRV::Capability::GroupNonUniform);
    break;
  case SPIRV::OpGroupNonUniformAll:
  case SPIRV::OpGroupNonUniformAny:
  case SPIRV::OpGroupNonUniformAllEqual:
    Reqs.addCapability(SPIRV::Capability::GroupNonUniformVote);
    break;
  case SPIRV::OpGroupNonUniformBroadcast:
  case SPIRV::OpGroupNonUniformBroadcastFirst:
  case SPIRV::OpGroupNonUniformBallot:
  case SPIRV::OpGroupNonUniformInverseBallot:
  case SPIRV::OpGroupNonUniformBallotBitExtract:
  case SPIRV::OpGroupNonUniformBallotBitCount:
  case SPIRV::OpGroupNonUniformBallotFindLSB:
  case SPIRV::OpGroupNonUniformBallotFindMSB:
    Reqs.addCapability(SPIRV::Capability::GroupNonUniformBallot);
    break;
  default:
    break;
  }
}

static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI,
                        MachineModuleInfo *MMI, const SPIRVSubtarget &ST) {
  // Collect requirements for existing instructions.
  for (auto F = M.begin(), E = M.end(); F != E; ++F) {
    MachineFunction *MF = MMI->getMachineFunction(*F);
    if (!MF)
      continue;
    for (const MachineBasicBlock &MBB : *MF)
      for (const MachineInstr &MI : MBB)
        addInstrRequirements(MI, MAI.Reqs, ST);
  }
  // Collect requirements for OpExecutionMode instructions.
  auto Node = M.getNamedMetadata("spirv.ExecutionMode");
  if (Node) {
    for (unsigned i = 0; i < Node->getNumOperands(); i++) {
      MDNode *MDN = cast<MDNode>(Node->getOperand(i));
      const MDOperand &MDOp = MDN->getOperand(1);
      if (auto *CMeta = dyn_cast<ConstantAsMetadata>(MDOp)) {
        Constant *C = CMeta->getValue();
        if (ConstantInt *Const = dyn_cast<ConstantInt>(C)) {
          auto EM = Const->getZExtValue();
          MAI.Reqs.getAndAddRequirements(
              SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
        }
      }
    }
  }
  for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) {
    const Function &F = *FI;
    if (F.isDeclaration())
      continue;
    if (F.getMetadata("reqd_work_group_size"))
      MAI.Reqs.getAndAddRequirements(
          SPIRV::OperandCategory::ExecutionModeOperand,
          SPIRV::ExecutionMode::LocalSize, ST);
    if (F.getMetadata("work_group_size_hint"))
      MAI.Reqs.getAndAddRequirements(
          SPIRV::OperandCategory::ExecutionModeOperand,
          SPIRV::ExecutionMode::LocalSizeHint, ST);
    if (F.getMetadata("intel_reqd_sub_group_size"))
      MAI.Reqs.getAndAddRequirements(
          SPIRV::OperandCategory::ExecutionModeOperand,
          SPIRV::ExecutionMode::SubgroupSize, ST);
    if (F.getMetadata("vec_type_hint"))
      MAI.Reqs.getAndAddRequirements(
          SPIRV::OperandCategory::ExecutionModeOperand,
          SPIRV::ExecutionMode::VecTypeHint, ST);

    if (F.hasOptNone() &&
        ST.canUseExtension(SPIRV::Extension::SPV_INTEL_optnone)) {
      // Output OpCapability OptNoneINTEL.
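      // Both the extension and its capability are recorded so that the emitter
      // can output the matching OpExtension/OpCapability pair at module level.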
      MAI.Reqs.addExtension(SPIRV::Extension::SPV_INTEL_optnone);
      MAI.Reqs.addCapability(SPIRV::Capability::OptNoneINTEL);
    }
  }
}

static unsigned getFastMathFlags(const MachineInstr &I) {
  unsigned Flags = SPIRV::FPFastMathMode::None;
  if (I.getFlag(MachineInstr::MIFlag::FmNoNans))
    Flags |= SPIRV::FPFastMathMode::NotNaN;
  if (I.getFlag(MachineInstr::MIFlag::FmNoInfs))
    Flags |= SPIRV::FPFastMathMode::NotInf;
  if (I.getFlag(MachineInstr::MIFlag::FmNsz))
    Flags |= SPIRV::FPFastMathMode::NSZ;
  if (I.getFlag(MachineInstr::MIFlag::FmArcp))
    Flags |= SPIRV::FPFastMathMode::AllowRecip;
  if (I.getFlag(MachineInstr::MIFlag::FmReassoc))
    Flags |= SPIRV::FPFastMathMode::Fast;
  return Flags;
}

static void handleMIFlagDecoration(MachineInstr &I, const SPIRVSubtarget &ST,
                                   const SPIRVInstrInfo &TII,
                                   SPIRV::RequirementHandler &Reqs) {
  if (I.getFlag(MachineInstr::MIFlag::NoSWrap) && TII.canUseNSW(I) &&
      getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,
                                     SPIRV::Decoration::NoSignedWrap, ST, Reqs)
          .IsSatisfiable) {
    buildOpDecorate(I.getOperand(0).getReg(), I, TII,
                    SPIRV::Decoration::NoSignedWrap, {});
  }
  if (I.getFlag(MachineInstr::MIFlag::NoUWrap) && TII.canUseNUW(I) &&
      getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,
                                     SPIRV::Decoration::NoUnsignedWrap, ST,
                                     Reqs)
          .IsSatisfiable) {
    buildOpDecorate(I.getOperand(0).getReg(), I, TII,
                    SPIRV::Decoration::NoUnsignedWrap, {});
  }
  if (!TII.canUseFastMathFlags(I))
    return;
  unsigned FMFlags = getFastMathFlags(I);
  if (FMFlags == SPIRV::FPFastMathMode::None)
    return;
  Register DstReg = I.getOperand(0).getReg();
  buildOpDecorate(DstReg, I, TII, SPIRV::Decoration::FPFastMathMode, {FMFlags});
}

// Walk all functions and add decorations related to MI flags.
static void addDecorations(const Module &M, const SPIRVInstrInfo &TII,
                           MachineModuleInfo *MMI, const SPIRVSubtarget &ST,
                           SPIRV::ModuleAnalysisInfo &MAI) {
  for (auto F = M.begin(), E = M.end(); F != E; ++F) {
    MachineFunction *MF = MMI->getMachineFunction(*F);
    if (!MF)
      continue;
    for (auto &MBB : *MF)
      for (auto &MI : MBB)
        handleMIFlagDecoration(MI, ST, TII, MAI.Reqs);
  }
}

struct SPIRV::ModuleAnalysisInfo SPIRVModuleAnalysis::MAI;

void SPIRVModuleAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequired<TargetPassConfig>();
  AU.addRequired<MachineModuleInfoWrapperPass>();
}

bool SPIRVModuleAnalysis::runOnModule(Module &M) {
  SPIRVTargetMachine &TM =
      getAnalysis<TargetPassConfig>().getTM<SPIRVTargetMachine>();
  ST = TM.getSubtargetImpl();
  GR = ST->getSPIRVGlobalRegistry();
  TII = ST->getInstrInfo();

  MMI = &getAnalysis<MachineModuleInfoWrapperPass>().getMMI();

  setBaseInfo(M);

  addDecorations(M, *TII, MMI, *ST, MAI);

  collectReqs(M, MAI, MMI, *ST);

  // Process type/const/global var/func decl instructions, number their
  // destination registers from 0 to N, collect Extensions and Capabilities.
  processDefInstrs(M);

  // Number the rest of the registers from N+1 onwards.
  numberRegistersGlobally(M);

  // Collect OpName, OpEntryPoint, OpDecorate etc., and process other
  // instructions.
  processOtherInstrs(M);

  // If there are no entry points, we need the Linkage capability.
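  // (The SPIR-V spec requires at least one OpEntryPoint unless the Linkage
  // capability is declared.)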
  if (MAI.MS[SPIRV::MB_EntryPoints].empty())
    MAI.Reqs.addCapability(SPIRV::Capability::Linkage);

  return false;
}