1 //===--- SPIRVUtils.cpp ---- SPIR-V Utility Functions -----------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains miscellaneous utility functions. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "SPIRVUtils.h" 14 #include "MCTargetDesc/SPIRVBaseInfo.h" 15 #include "SPIRV.h" 16 #include "SPIRVGlobalRegistry.h" 17 #include "SPIRVInstrInfo.h" 18 #include "SPIRVSubtarget.h" 19 #include "llvm/ADT/StringRef.h" 20 #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h" 21 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" 22 #include "llvm/CodeGen/MachineInstr.h" 23 #include "llvm/CodeGen/MachineInstrBuilder.h" 24 #include "llvm/Demangle/Demangle.h" 25 #include "llvm/IR/IntrinsicInst.h" 26 #include "llvm/IR/IntrinsicsSPIRV.h" 27 #include <queue> 28 #include <vector> 29 30 namespace llvm { 31 32 // The following functions are used to add these string literals as a series of 33 // 32-bit integer operands with the correct format, and unpack them if necessary 34 // when making string comparisons in compiler passes. 35 // SPIR-V requires null-terminated UTF-8 strings padded to 32-bit alignment. 36 static uint32_t convertCharsToWord(const StringRef &Str, unsigned i) { 37 uint32_t Word = 0u; // Build up this 32-bit word from 4 8-bit chars. 38 for (unsigned WordIndex = 0; WordIndex < 4; ++WordIndex) { 39 unsigned StrIndex = i + WordIndex; 40 uint8_t CharToAdd = 0; // Initilize char as padding/null. 41 if (StrIndex < Str.size()) { // If it's within the string, get a real char. 42 CharToAdd = Str[StrIndex]; 43 } 44 Word |= (CharToAdd << (WordIndex * 8)); 45 } 46 return Word; 47 } 48 49 // Get length including padding and null terminator. 50 static size_t getPaddedLen(const StringRef &Str) { 51 return (Str.size() + 4) & ~3; 52 } 53 54 void addStringImm(const StringRef &Str, MCInst &Inst) { 55 const size_t PaddedLen = getPaddedLen(Str); 56 for (unsigned i = 0; i < PaddedLen; i += 4) { 57 // Add an operand for the 32-bits of chars or padding. 58 Inst.addOperand(MCOperand::createImm(convertCharsToWord(Str, i))); 59 } 60 } 61 62 void addStringImm(const StringRef &Str, MachineInstrBuilder &MIB) { 63 const size_t PaddedLen = getPaddedLen(Str); 64 for (unsigned i = 0; i < PaddedLen; i += 4) { 65 // Add an operand for the 32-bits of chars or padding. 66 MIB.addImm(convertCharsToWord(Str, i)); 67 } 68 } 69 70 void addStringImm(const StringRef &Str, IRBuilder<> &B, 71 std::vector<Value *> &Args) { 72 const size_t PaddedLen = getPaddedLen(Str); 73 for (unsigned i = 0; i < PaddedLen; i += 4) { 74 // Add a vector element for the 32-bits of chars or padding. 75 Args.push_back(B.getInt32(convertCharsToWord(Str, i))); 76 } 77 } 78 79 std::string getStringImm(const MachineInstr &MI, unsigned StartIndex) { 80 return getSPIRVStringOperand(MI, StartIndex); 81 } 82 83 void addNumImm(const APInt &Imm, MachineInstrBuilder &MIB) { 84 const auto Bitwidth = Imm.getBitWidth(); 85 if (Bitwidth == 1) 86 return; // Already handled 87 else if (Bitwidth <= 32) { 88 MIB.addImm(Imm.getZExtValue()); 89 // Asm Printer needs this info to print floating-type correctly 90 if (Bitwidth == 16) 91 MIB.getInstr()->setAsmPrinterFlag(SPIRV::ASM_PRINTER_WIDTH16); 92 return; 93 } else if (Bitwidth <= 64) { 94 uint64_t FullImm = Imm.getZExtValue(); 95 uint32_t LowBits = FullImm & 0xffffffff; 96 uint32_t HighBits = (FullImm >> 32) & 0xffffffff; 97 MIB.addImm(LowBits).addImm(HighBits); 98 return; 99 } 100 report_fatal_error("Unsupported constant bitwidth"); 101 } 102 103 void buildOpName(Register Target, const StringRef &Name, 104 MachineIRBuilder &MIRBuilder) { 105 if (!Name.empty()) { 106 auto MIB = MIRBuilder.buildInstr(SPIRV::OpName).addUse(Target); 107 addStringImm(Name, MIB); 108 } 109 } 110 111 static void finishBuildOpDecorate(MachineInstrBuilder &MIB, 112 const std::vector<uint32_t> &DecArgs, 113 StringRef StrImm) { 114 if (!StrImm.empty()) 115 addStringImm(StrImm, MIB); 116 for (const auto &DecArg : DecArgs) 117 MIB.addImm(DecArg); 118 } 119 120 void buildOpDecorate(Register Reg, MachineIRBuilder &MIRBuilder, 121 SPIRV::Decoration::Decoration Dec, 122 const std::vector<uint32_t> &DecArgs, StringRef StrImm) { 123 auto MIB = MIRBuilder.buildInstr(SPIRV::OpDecorate) 124 .addUse(Reg) 125 .addImm(static_cast<uint32_t>(Dec)); 126 finishBuildOpDecorate(MIB, DecArgs, StrImm); 127 } 128 129 void buildOpDecorate(Register Reg, MachineInstr &I, const SPIRVInstrInfo &TII, 130 SPIRV::Decoration::Decoration Dec, 131 const std::vector<uint32_t> &DecArgs, StringRef StrImm) { 132 MachineBasicBlock &MBB = *I.getParent(); 133 auto MIB = BuildMI(MBB, I, I.getDebugLoc(), TII.get(SPIRV::OpDecorate)) 134 .addUse(Reg) 135 .addImm(static_cast<uint32_t>(Dec)); 136 finishBuildOpDecorate(MIB, DecArgs, StrImm); 137 } 138 139 void buildOpSpirvDecorations(Register Reg, MachineIRBuilder &MIRBuilder, 140 const MDNode *GVarMD) { 141 for (unsigned I = 0, E = GVarMD->getNumOperands(); I != E; ++I) { 142 auto *OpMD = dyn_cast<MDNode>(GVarMD->getOperand(I)); 143 if (!OpMD) 144 report_fatal_error("Invalid decoration"); 145 if (OpMD->getNumOperands() == 0) 146 report_fatal_error("Expect operand(s) of the decoration"); 147 ConstantInt *DecorationId = 148 mdconst::dyn_extract<ConstantInt>(OpMD->getOperand(0)); 149 if (!DecorationId) 150 report_fatal_error("Expect SPIR-V <Decoration> operand to be the first " 151 "element of the decoration"); 152 auto MIB = MIRBuilder.buildInstr(SPIRV::OpDecorate) 153 .addUse(Reg) 154 .addImm(static_cast<uint32_t>(DecorationId->getZExtValue())); 155 for (unsigned OpI = 1, OpE = OpMD->getNumOperands(); OpI != OpE; ++OpI) { 156 if (ConstantInt *OpV = 157 mdconst::dyn_extract<ConstantInt>(OpMD->getOperand(OpI))) 158 MIB.addImm(static_cast<uint32_t>(OpV->getZExtValue())); 159 else if (MDString *OpV = dyn_cast<MDString>(OpMD->getOperand(OpI))) 160 addStringImm(OpV->getString(), MIB); 161 else 162 report_fatal_error("Unexpected operand of the decoration"); 163 } 164 } 165 } 166 167 MachineBasicBlock::iterator getOpVariableMBBIt(MachineInstr &I) { 168 MachineFunction *MF = I.getParent()->getParent(); 169 MachineBasicBlock *MBB = &MF->front(); 170 MachineBasicBlock::iterator It = MBB->SkipPHIsAndLabels(MBB->begin()), 171 E = MBB->end(); 172 bool IsHeader = false; 173 unsigned Opcode; 174 for (; It != E && It != I; ++It) { 175 Opcode = It->getOpcode(); 176 if (Opcode == SPIRV::OpFunction || Opcode == SPIRV::OpFunctionParameter) { 177 IsHeader = true; 178 } else if (IsHeader && 179 !(Opcode == SPIRV::ASSIGN_TYPE || Opcode == SPIRV::OpLabel)) { 180 ++It; 181 break; 182 } 183 } 184 return It; 185 } 186 187 SPIRV::StorageClass::StorageClass 188 addressSpaceToStorageClass(unsigned AddrSpace, const SPIRVSubtarget &STI) { 189 switch (AddrSpace) { 190 case 0: 191 return SPIRV::StorageClass::Function; 192 case 1: 193 return SPIRV::StorageClass::CrossWorkgroup; 194 case 2: 195 return SPIRV::StorageClass::UniformConstant; 196 case 3: 197 return SPIRV::StorageClass::Workgroup; 198 case 4: 199 return SPIRV::StorageClass::Generic; 200 case 5: 201 return STI.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes) 202 ? SPIRV::StorageClass::DeviceOnlyINTEL 203 : SPIRV::StorageClass::CrossWorkgroup; 204 case 6: 205 return STI.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes) 206 ? SPIRV::StorageClass::HostOnlyINTEL 207 : SPIRV::StorageClass::CrossWorkgroup; 208 case 7: 209 return SPIRV::StorageClass::Input; 210 case 9: 211 return SPIRV::StorageClass::CodeSectionINTEL; 212 default: 213 report_fatal_error("Unknown address space"); 214 } 215 } 216 217 SPIRV::MemorySemantics::MemorySemantics 218 getMemSemanticsForStorageClass(SPIRV::StorageClass::StorageClass SC) { 219 switch (SC) { 220 case SPIRV::StorageClass::StorageBuffer: 221 case SPIRV::StorageClass::Uniform: 222 return SPIRV::MemorySemantics::UniformMemory; 223 case SPIRV::StorageClass::Workgroup: 224 return SPIRV::MemorySemantics::WorkgroupMemory; 225 case SPIRV::StorageClass::CrossWorkgroup: 226 return SPIRV::MemorySemantics::CrossWorkgroupMemory; 227 case SPIRV::StorageClass::AtomicCounter: 228 return SPIRV::MemorySemantics::AtomicCounterMemory; 229 case SPIRV::StorageClass::Image: 230 return SPIRV::MemorySemantics::ImageMemory; 231 default: 232 return SPIRV::MemorySemantics::None; 233 } 234 } 235 236 SPIRV::MemorySemantics::MemorySemantics getMemSemantics(AtomicOrdering Ord) { 237 switch (Ord) { 238 case AtomicOrdering::Acquire: 239 return SPIRV::MemorySemantics::Acquire; 240 case AtomicOrdering::Release: 241 return SPIRV::MemorySemantics::Release; 242 case AtomicOrdering::AcquireRelease: 243 return SPIRV::MemorySemantics::AcquireRelease; 244 case AtomicOrdering::SequentiallyConsistent: 245 return SPIRV::MemorySemantics::SequentiallyConsistent; 246 case AtomicOrdering::Unordered: 247 case AtomicOrdering::Monotonic: 248 case AtomicOrdering::NotAtomic: 249 return SPIRV::MemorySemantics::None; 250 } 251 llvm_unreachable(nullptr); 252 } 253 254 SPIRV::Scope::Scope getMemScope(LLVMContext &Ctx, SyncScope::ID Id) { 255 // Named by 256 // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_scope_id. 257 // We don't need aliases for Invocation and CrossDevice, as we already have 258 // them covered by "singlethread" and "" strings respectively (see 259 // implementation of LLVMContext::LLVMContext()). 260 static const llvm::SyncScope::ID SubGroup = 261 Ctx.getOrInsertSyncScopeID("subgroup"); 262 static const llvm::SyncScope::ID WorkGroup = 263 Ctx.getOrInsertSyncScopeID("workgroup"); 264 static const llvm::SyncScope::ID Device = 265 Ctx.getOrInsertSyncScopeID("device"); 266 267 if (Id == llvm::SyncScope::SingleThread) 268 return SPIRV::Scope::Invocation; 269 else if (Id == llvm::SyncScope::System) 270 return SPIRV::Scope::CrossDevice; 271 else if (Id == SubGroup) 272 return SPIRV::Scope::Subgroup; 273 else if (Id == WorkGroup) 274 return SPIRV::Scope::Workgroup; 275 else if (Id == Device) 276 return SPIRV::Scope::Device; 277 return SPIRV::Scope::CrossDevice; 278 } 279 280 MachineInstr *getDefInstrMaybeConstant(Register &ConstReg, 281 const MachineRegisterInfo *MRI) { 282 MachineInstr *MI = MRI->getVRegDef(ConstReg); 283 MachineInstr *ConstInstr = 284 MI->getOpcode() == SPIRV::G_TRUNC || MI->getOpcode() == SPIRV::G_ZEXT 285 ? MRI->getVRegDef(MI->getOperand(1).getReg()) 286 : MI; 287 if (auto *GI = dyn_cast<GIntrinsic>(ConstInstr)) { 288 if (GI->is(Intrinsic::spv_track_constant)) { 289 ConstReg = ConstInstr->getOperand(2).getReg(); 290 return MRI->getVRegDef(ConstReg); 291 } 292 } else if (ConstInstr->getOpcode() == SPIRV::ASSIGN_TYPE) { 293 ConstReg = ConstInstr->getOperand(1).getReg(); 294 return MRI->getVRegDef(ConstReg); 295 } 296 return MRI->getVRegDef(ConstReg); 297 } 298 299 uint64_t getIConstVal(Register ConstReg, const MachineRegisterInfo *MRI) { 300 const MachineInstr *MI = getDefInstrMaybeConstant(ConstReg, MRI); 301 assert(MI && MI->getOpcode() == TargetOpcode::G_CONSTANT); 302 return MI->getOperand(1).getCImm()->getValue().getZExtValue(); 303 } 304 305 bool isSpvIntrinsic(const MachineInstr &MI, Intrinsic::ID IntrinsicID) { 306 if (const auto *GI = dyn_cast<GIntrinsic>(&MI)) 307 return GI->is(IntrinsicID); 308 return false; 309 } 310 311 Type *getMDOperandAsType(const MDNode *N, unsigned I) { 312 Type *ElementTy = cast<ValueAsMetadata>(N->getOperand(I))->getType(); 313 return toTypedPointer(ElementTy); 314 } 315 316 // The set of names is borrowed from the SPIR-V translator. 317 // TODO: may be implemented in SPIRVBuiltins.td. 318 static bool isPipeOrAddressSpaceCastBI(const StringRef MangledName) { 319 return MangledName == "write_pipe_2" || MangledName == "read_pipe_2" || 320 MangledName == "write_pipe_2_bl" || MangledName == "read_pipe_2_bl" || 321 MangledName == "write_pipe_4" || MangledName == "read_pipe_4" || 322 MangledName == "reserve_write_pipe" || 323 MangledName == "reserve_read_pipe" || 324 MangledName == "commit_write_pipe" || 325 MangledName == "commit_read_pipe" || 326 MangledName == "work_group_reserve_write_pipe" || 327 MangledName == "work_group_reserve_read_pipe" || 328 MangledName == "work_group_commit_write_pipe" || 329 MangledName == "work_group_commit_read_pipe" || 330 MangledName == "get_pipe_num_packets_ro" || 331 MangledName == "get_pipe_max_packets_ro" || 332 MangledName == "get_pipe_num_packets_wo" || 333 MangledName == "get_pipe_max_packets_wo" || 334 MangledName == "sub_group_reserve_write_pipe" || 335 MangledName == "sub_group_reserve_read_pipe" || 336 MangledName == "sub_group_commit_write_pipe" || 337 MangledName == "sub_group_commit_read_pipe" || 338 MangledName == "to_global" || MangledName == "to_local" || 339 MangledName == "to_private"; 340 } 341 342 static bool isEnqueueKernelBI(const StringRef MangledName) { 343 return MangledName == "__enqueue_kernel_basic" || 344 MangledName == "__enqueue_kernel_basic_events" || 345 MangledName == "__enqueue_kernel_varargs" || 346 MangledName == "__enqueue_kernel_events_varargs"; 347 } 348 349 static bool isKernelQueryBI(const StringRef MangledName) { 350 return MangledName == "__get_kernel_work_group_size_impl" || 351 MangledName == "__get_kernel_sub_group_count_for_ndrange_impl" || 352 MangledName == "__get_kernel_max_sub_group_size_for_ndrange_impl" || 353 MangledName == "__get_kernel_preferred_work_group_size_multiple_impl"; 354 } 355 356 static bool isNonMangledOCLBuiltin(StringRef Name) { 357 if (!Name.starts_with("__")) 358 return false; 359 360 return isEnqueueKernelBI(Name) || isKernelQueryBI(Name) || 361 isPipeOrAddressSpaceCastBI(Name.drop_front(2)) || 362 Name == "__translate_sampler_initializer"; 363 } 364 365 std::string getOclOrSpirvBuiltinDemangledName(StringRef Name) { 366 bool IsNonMangledOCL = isNonMangledOCLBuiltin(Name); 367 bool IsNonMangledSPIRV = Name.starts_with("__spirv_"); 368 bool IsNonMangledHLSL = Name.starts_with("__hlsl_"); 369 bool IsMangled = Name.starts_with("_Z"); 370 371 // Otherwise use simple demangling to return the function name. 372 if (IsNonMangledOCL || IsNonMangledSPIRV || IsNonMangledHLSL || !IsMangled) 373 return Name.str(); 374 375 // Try to use the itanium demangler. 376 if (char *DemangledName = itaniumDemangle(Name.data())) { 377 std::string Result = DemangledName; 378 free(DemangledName); 379 return Result; 380 } 381 382 // Autocheck C++, maybe need to do explicit check of the source language. 383 // OpenCL C++ built-ins are declared in cl namespace. 384 // TODO: consider using 'St' abbriviation for cl namespace mangling. 385 // Similar to ::std:: in C++. 386 size_t Start, Len = 0; 387 size_t DemangledNameLenStart = 2; 388 if (Name.starts_with("_ZN")) { 389 // Skip CV and ref qualifiers. 390 size_t NameSpaceStart = Name.find_first_not_of("rVKRO", 3); 391 // All built-ins are in the ::cl:: namespace. 392 if (Name.substr(NameSpaceStart, 11) != "2cl7__spirv") 393 return std::string(); 394 DemangledNameLenStart = NameSpaceStart + 11; 395 } 396 Start = Name.find_first_not_of("0123456789", DemangledNameLenStart); 397 Name.substr(DemangledNameLenStart, Start - DemangledNameLenStart) 398 .getAsInteger(10, Len); 399 return Name.substr(Start, Len).str(); 400 } 401 402 bool hasBuiltinTypePrefix(StringRef Name) { 403 if (Name.starts_with("opencl.") || Name.starts_with("ocl_") || 404 Name.starts_with("spirv.")) 405 return true; 406 return false; 407 } 408 409 bool isSpecialOpaqueType(const Type *Ty) { 410 if (const TargetExtType *ExtTy = dyn_cast<TargetExtType>(Ty)) 411 return isTypedPointerWrapper(ExtTy) 412 ? false 413 : hasBuiltinTypePrefix(ExtTy->getName()); 414 415 return false; 416 } 417 418 bool isEntryPoint(const Function &F) { 419 // OpenCL handling: any function with the SPIR_KERNEL 420 // calling convention will be a potential entry point. 421 if (F.getCallingConv() == CallingConv::SPIR_KERNEL) 422 return true; 423 424 // HLSL handling: special attribute are emitted from the 425 // front-end. 426 if (F.getFnAttribute("hlsl.shader").isValid()) 427 return true; 428 429 return false; 430 } 431 432 Type *parseBasicTypeName(StringRef &TypeName, LLVMContext &Ctx) { 433 TypeName.consume_front("atomic_"); 434 if (TypeName.consume_front("void")) 435 return Type::getVoidTy(Ctx); 436 else if (TypeName.consume_front("bool")) 437 return Type::getIntNTy(Ctx, 1); 438 else if (TypeName.consume_front("char") || 439 TypeName.consume_front("unsigned char") || 440 TypeName.consume_front("uchar")) 441 return Type::getInt8Ty(Ctx); 442 else if (TypeName.consume_front("short") || 443 TypeName.consume_front("unsigned short") || 444 TypeName.consume_front("ushort")) 445 return Type::getInt16Ty(Ctx); 446 else if (TypeName.consume_front("int") || 447 TypeName.consume_front("unsigned int") || 448 TypeName.consume_front("uint")) 449 return Type::getInt32Ty(Ctx); 450 else if (TypeName.consume_front("long") || 451 TypeName.consume_front("unsigned long") || 452 TypeName.consume_front("ulong")) 453 return Type::getInt64Ty(Ctx); 454 else if (TypeName.consume_front("half")) 455 return Type::getHalfTy(Ctx); 456 else if (TypeName.consume_front("float")) 457 return Type::getFloatTy(Ctx); 458 else if (TypeName.consume_front("double")) 459 return Type::getDoubleTy(Ctx); 460 461 // Unable to recognize SPIRV type name 462 return nullptr; 463 } 464 465 std::unordered_set<BasicBlock *> 466 PartialOrderingVisitor::getReachableFrom(BasicBlock *Start) { 467 std::queue<BasicBlock *> ToVisit; 468 ToVisit.push(Start); 469 470 std::unordered_set<BasicBlock *> Output; 471 while (ToVisit.size() != 0) { 472 BasicBlock *BB = ToVisit.front(); 473 ToVisit.pop(); 474 475 if (Output.count(BB) != 0) 476 continue; 477 Output.insert(BB); 478 479 for (BasicBlock *Successor : successors(BB)) { 480 if (DT.dominates(Successor, BB)) 481 continue; 482 ToVisit.push(Successor); 483 } 484 } 485 486 return Output; 487 } 488 489 bool PartialOrderingVisitor::CanBeVisited(BasicBlock *BB) const { 490 for (BasicBlock *P : predecessors(BB)) { 491 // Ignore back-edges. 492 if (DT.dominates(BB, P)) 493 continue; 494 495 // One of the predecessor hasn't been visited. Not ready yet. 496 if (BlockToOrder.count(P) == 0) 497 return false; 498 499 // If the block is a loop exit, the loop must be finished before 500 // we can continue. 501 Loop *L = LI.getLoopFor(P); 502 if (L == nullptr || L->contains(BB)) 503 continue; 504 505 // SPIR-V requires a single back-edge. And the backend first 506 // step transforms loops into the simplified format. If we have 507 // more than 1 back-edge, something is wrong. 508 assert(L->getNumBackEdges() <= 1); 509 510 // If the loop has no latch, loop's rank won't matter, so we can 511 // proceed. 512 BasicBlock *Latch = L->getLoopLatch(); 513 assert(Latch); 514 if (Latch == nullptr) 515 continue; 516 517 // The latch is not ready yet, let's wait. 518 if (BlockToOrder.count(Latch) == 0) 519 return false; 520 } 521 522 return true; 523 } 524 525 size_t PartialOrderingVisitor::GetNodeRank(BasicBlock *BB) const { 526 auto It = BlockToOrder.find(BB); 527 if (It != BlockToOrder.end()) 528 return It->second.Rank; 529 530 size_t result = 0; 531 for (BasicBlock *P : predecessors(BB)) { 532 // Ignore back-edges. 533 if (DT.dominates(BB, P)) 534 continue; 535 536 auto Iterator = BlockToOrder.end(); 537 Loop *L = LI.getLoopFor(P); 538 BasicBlock *Latch = L ? L->getLoopLatch() : nullptr; 539 540 // If the predecessor is either outside a loop, or part of 541 // the same loop, simply take its rank + 1. 542 if (L == nullptr || L->contains(BB) || Latch == nullptr) { 543 Iterator = BlockToOrder.find(P); 544 } else { 545 // Otherwise, take the loop's rank (highest rank in the loop) as base. 546 // Since loops have a single latch, highest rank is easy to find. 547 // If the loop has no latch, then it doesn't matter. 548 Iterator = BlockToOrder.find(Latch); 549 } 550 551 assert(Iterator != BlockToOrder.end()); 552 result = std::max(result, Iterator->second.Rank + 1); 553 } 554 555 return result; 556 } 557 558 size_t PartialOrderingVisitor::visit(BasicBlock *BB, size_t Unused) { 559 ToVisit.push(BB); 560 Queued.insert(BB); 561 562 size_t QueueIndex = 0; 563 while (ToVisit.size() != 0) { 564 BasicBlock *BB = ToVisit.front(); 565 ToVisit.pop(); 566 567 if (!CanBeVisited(BB)) { 568 ToVisit.push(BB); 569 assert(QueueIndex < ToVisit.size() && 570 "No valid candidate in the queue. Is the graph reducible?"); 571 QueueIndex++; 572 continue; 573 } 574 575 QueueIndex = 0; 576 size_t Rank = GetNodeRank(BB); 577 OrderInfo Info = {Rank, BlockToOrder.size()}; 578 BlockToOrder.emplace(BB, Info); 579 580 for (BasicBlock *S : successors(BB)) { 581 if (Queued.count(S) != 0) 582 continue; 583 ToVisit.push(S); 584 Queued.insert(S); 585 } 586 } 587 588 return 0; 589 } 590 591 PartialOrderingVisitor::PartialOrderingVisitor(Function &F) { 592 DT.recalculate(F); 593 LI = LoopInfo(DT); 594 595 visit(&*F.begin(), 0); 596 597 Order.reserve(F.size()); 598 for (auto &[BB, Info] : BlockToOrder) 599 Order.emplace_back(BB); 600 601 std::sort(Order.begin(), Order.end(), [&](const auto &LHS, const auto &RHS) { 602 return compare(LHS, RHS); 603 }); 604 } 605 606 bool PartialOrderingVisitor::compare(const BasicBlock *LHS, 607 const BasicBlock *RHS) const { 608 const OrderInfo &InfoLHS = BlockToOrder.at(const_cast<BasicBlock *>(LHS)); 609 const OrderInfo &InfoRHS = BlockToOrder.at(const_cast<BasicBlock *>(RHS)); 610 if (InfoLHS.Rank != InfoRHS.Rank) 611 return InfoLHS.Rank < InfoRHS.Rank; 612 return InfoLHS.TraversalIndex < InfoRHS.TraversalIndex; 613 } 614 615 void PartialOrderingVisitor::partialOrderVisit( 616 BasicBlock &Start, std::function<bool(BasicBlock *)> Op) { 617 std::unordered_set<BasicBlock *> Reachable = getReachableFrom(&Start); 618 assert(BlockToOrder.count(&Start) != 0); 619 620 // Skipping blocks with a rank inferior to |Start|'s rank. 621 auto It = Order.begin(); 622 while (It != Order.end() && *It != &Start) 623 ++It; 624 625 // This is unexpected. Worst case |Start| is the last block, 626 // so It should point to the last block, not past-end. 627 assert(It != Order.end()); 628 629 // By default, there is no rank limit. Setting it to the maximum value. 630 std::optional<size_t> EndRank = std::nullopt; 631 for (; It != Order.end(); ++It) { 632 if (EndRank.has_value() && BlockToOrder[*It].Rank > *EndRank) 633 break; 634 635 if (Reachable.count(*It) == 0) { 636 continue; 637 } 638 639 if (!Op(*It)) { 640 EndRank = BlockToOrder[*It].Rank; 641 } 642 } 643 } 644 645 bool sortBlocks(Function &F) { 646 if (F.size() == 0) 647 return false; 648 649 bool Modified = false; 650 std::vector<BasicBlock *> Order; 651 Order.reserve(F.size()); 652 653 ReversePostOrderTraversal<Function *> RPOT(&F); 654 for (BasicBlock *BB : RPOT) 655 Order.push_back(BB); 656 657 assert(&*F.begin() == Order[0]); 658 BasicBlock *LastBlock = &*F.begin(); 659 for (BasicBlock *BB : Order) { 660 if (BB != LastBlock && &*LastBlock->getNextNode() != BB) { 661 Modified = true; 662 BB->moveAfter(LastBlock); 663 } 664 LastBlock = BB; 665 } 666 667 return Modified; 668 } 669 670 MachineInstr *getVRegDef(MachineRegisterInfo &MRI, Register Reg) { 671 MachineInstr *MaybeDef = MRI.getVRegDef(Reg); 672 if (MaybeDef && MaybeDef->getOpcode() == SPIRV::ASSIGN_TYPE) 673 MaybeDef = MRI.getVRegDef(MaybeDef->getOperand(1).getReg()); 674 return MaybeDef; 675 } 676 677 bool getVacantFunctionName(Module &M, std::string &Name) { 678 // It's a bit of paranoia, but still we don't want to have even a chance that 679 // the loop will work for too long. 680 constexpr unsigned MaxIters = 1024; 681 for (unsigned I = 0; I < MaxIters; ++I) { 682 std::string OrdName = Name + Twine(I).str(); 683 if (!M.getFunction(OrdName)) { 684 Name = OrdName; 685 return true; 686 } 687 } 688 return false; 689 } 690 691 // Assign SPIR-V type to the register. If the register has no valid assigned 692 // class, set register LLT type and class according to the SPIR-V type. 693 void setRegClassType(Register Reg, SPIRVType *SpvType, SPIRVGlobalRegistry *GR, 694 MachineRegisterInfo *MRI, const MachineFunction &MF, 695 bool Force) { 696 GR->assignSPIRVTypeToVReg(SpvType, Reg, MF); 697 if (!MRI->getRegClassOrNull(Reg) || Force) { 698 MRI->setRegClass(Reg, GR->getRegClass(SpvType)); 699 MRI->setType(Reg, GR->getRegType(SpvType)); 700 } 701 } 702 703 // Create a SPIR-V type, assign SPIR-V type to the register. If the register has 704 // no valid assigned class, set register LLT type and class according to the 705 // SPIR-V type. 706 void setRegClassType(Register Reg, const Type *Ty, SPIRVGlobalRegistry *GR, 707 MachineIRBuilder &MIRBuilder, bool Force) { 708 setRegClassType(Reg, GR->getOrCreateSPIRVType(Ty, MIRBuilder), GR, 709 MIRBuilder.getMRI(), MIRBuilder.getMF(), Force); 710 } 711 712 // Create a virtual register and assign SPIR-V type to the register. Set 713 // register LLT type and class according to the SPIR-V type. 714 Register createVirtualRegister(SPIRVType *SpvType, SPIRVGlobalRegistry *GR, 715 MachineRegisterInfo *MRI, 716 const MachineFunction &MF) { 717 Register Reg = MRI->createVirtualRegister(GR->getRegClass(SpvType)); 718 MRI->setType(Reg, GR->getRegType(SpvType)); 719 GR->assignSPIRVTypeToVReg(SpvType, Reg, MF); 720 return Reg; 721 } 722 723 // Create a virtual register and assign SPIR-V type to the register. Set 724 // register LLT type and class according to the SPIR-V type. 725 Register createVirtualRegister(SPIRVType *SpvType, SPIRVGlobalRegistry *GR, 726 MachineIRBuilder &MIRBuilder) { 727 return createVirtualRegister(SpvType, GR, MIRBuilder.getMRI(), 728 MIRBuilder.getMF()); 729 } 730 731 // Create a SPIR-V type, virtual register and assign SPIR-V type to the 732 // register. Set register LLT type and class according to the SPIR-V type. 733 Register createVirtualRegister(const Type *Ty, SPIRVGlobalRegistry *GR, 734 MachineIRBuilder &MIRBuilder) { 735 return createVirtualRegister(GR->getOrCreateSPIRVType(Ty, MIRBuilder), GR, 736 MIRBuilder); 737 } 738 739 // Return true if there is an opaque pointer type nested in the argument. 740 bool isNestedPointer(const Type *Ty) { 741 if (Ty->isPtrOrPtrVectorTy()) 742 return true; 743 if (const FunctionType *RefTy = dyn_cast<FunctionType>(Ty)) { 744 if (isNestedPointer(RefTy->getReturnType())) 745 return true; 746 for (const Type *ArgTy : RefTy->params()) 747 if (isNestedPointer(ArgTy)) 748 return true; 749 return false; 750 } 751 if (const ArrayType *RefTy = dyn_cast<ArrayType>(Ty)) 752 return isNestedPointer(RefTy->getElementType()); 753 return false; 754 } 755 756 bool isSpvIntrinsic(const Value *Arg) { 757 if (const auto *II = dyn_cast<IntrinsicInst>(Arg)) 758 if (Function *F = II->getCalledFunction()) 759 if (F->getName().starts_with("llvm.spv.")) 760 return true; 761 return false; 762 } 763 764 } // namespace llvm 765