1 //===--- SPIRVUtils.cpp ---- SPIR-V Utility Functions -----------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains miscellaneous utility functions. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "SPIRVUtils.h" 14 #include "MCTargetDesc/SPIRVBaseInfo.h" 15 #include "SPIRV.h" 16 #include "SPIRVGlobalRegistry.h" 17 #include "SPIRVInstrInfo.h" 18 #include "SPIRVSubtarget.h" 19 #include "llvm/ADT/StringRef.h" 20 #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h" 21 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" 22 #include "llvm/CodeGen/MachineInstr.h" 23 #include "llvm/CodeGen/MachineInstrBuilder.h" 24 #include "llvm/Demangle/Demangle.h" 25 #include "llvm/IR/IntrinsicInst.h" 26 #include "llvm/IR/IntrinsicsSPIRV.h" 27 #include <queue> 28 #include <vector> 29 30 namespace llvm { 31 32 // The following functions are used to add these string literals as a series of 33 // 32-bit integer operands with the correct format, and unpack them if necessary 34 // when making string comparisons in compiler passes. 35 // SPIR-V requires null-terminated UTF-8 strings padded to 32-bit alignment. 36 static uint32_t convertCharsToWord(const StringRef &Str, unsigned i) { 37 uint32_t Word = 0u; // Build up this 32-bit word from 4 8-bit chars. 38 for (unsigned WordIndex = 0; WordIndex < 4; ++WordIndex) { 39 unsigned StrIndex = i + WordIndex; 40 uint8_t CharToAdd = 0; // Initilize char as padding/null. 41 if (StrIndex < Str.size()) { // If it's within the string, get a real char. 42 CharToAdd = Str[StrIndex]; 43 } 44 Word |= (CharToAdd << (WordIndex * 8)); 45 } 46 return Word; 47 } 48 49 // Get length including padding and null terminator. 50 static size_t getPaddedLen(const StringRef &Str) { 51 return (Str.size() + 4) & ~3; 52 } 53 54 void addStringImm(const StringRef &Str, MCInst &Inst) { 55 const size_t PaddedLen = getPaddedLen(Str); 56 for (unsigned i = 0; i < PaddedLen; i += 4) { 57 // Add an operand for the 32-bits of chars or padding. 58 Inst.addOperand(MCOperand::createImm(convertCharsToWord(Str, i))); 59 } 60 } 61 62 void addStringImm(const StringRef &Str, MachineInstrBuilder &MIB) { 63 const size_t PaddedLen = getPaddedLen(Str); 64 for (unsigned i = 0; i < PaddedLen; i += 4) { 65 // Add an operand for the 32-bits of chars or padding. 66 MIB.addImm(convertCharsToWord(Str, i)); 67 } 68 } 69 70 void addStringImm(const StringRef &Str, IRBuilder<> &B, 71 std::vector<Value *> &Args) { 72 const size_t PaddedLen = getPaddedLen(Str); 73 for (unsigned i = 0; i < PaddedLen; i += 4) { 74 // Add a vector element for the 32-bits of chars or padding. 75 Args.push_back(B.getInt32(convertCharsToWord(Str, i))); 76 } 77 } 78 79 std::string getStringImm(const MachineInstr &MI, unsigned StartIndex) { 80 return getSPIRVStringOperand(MI, StartIndex); 81 } 82 83 void addNumImm(const APInt &Imm, MachineInstrBuilder &MIB) { 84 const auto Bitwidth = Imm.getBitWidth(); 85 if (Bitwidth == 1) 86 return; // Already handled 87 else if (Bitwidth <= 32) { 88 MIB.addImm(Imm.getZExtValue()); 89 // Asm Printer needs this info to print floating-type correctly 90 if (Bitwidth == 16) 91 MIB.getInstr()->setAsmPrinterFlag(SPIRV::ASM_PRINTER_WIDTH16); 92 return; 93 } else if (Bitwidth <= 64) { 94 uint64_t FullImm = Imm.getZExtValue(); 95 uint32_t LowBits = FullImm & 0xffffffff; 96 uint32_t HighBits = (FullImm >> 32) & 0xffffffff; 97 MIB.addImm(LowBits).addImm(HighBits); 98 return; 99 } 100 report_fatal_error("Unsupported constant bitwidth"); 101 } 102 103 void buildOpName(Register Target, const StringRef &Name, 104 MachineIRBuilder &MIRBuilder) { 105 if (!Name.empty()) { 106 auto MIB = MIRBuilder.buildInstr(SPIRV::OpName).addUse(Target); 107 addStringImm(Name, MIB); 108 } 109 } 110 111 void buildOpName(Register Target, const StringRef &Name, MachineInstr &I, 112 const SPIRVInstrInfo &TII) { 113 if (!Name.empty()) { 114 auto MIB = 115 BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpName)) 116 .addUse(Target); 117 addStringImm(Name, MIB); 118 } 119 } 120 121 static void finishBuildOpDecorate(MachineInstrBuilder &MIB, 122 const std::vector<uint32_t> &DecArgs, 123 StringRef StrImm) { 124 if (!StrImm.empty()) 125 addStringImm(StrImm, MIB); 126 for (const auto &DecArg : DecArgs) 127 MIB.addImm(DecArg); 128 } 129 130 void buildOpDecorate(Register Reg, MachineIRBuilder &MIRBuilder, 131 SPIRV::Decoration::Decoration Dec, 132 const std::vector<uint32_t> &DecArgs, StringRef StrImm) { 133 auto MIB = MIRBuilder.buildInstr(SPIRV::OpDecorate) 134 .addUse(Reg) 135 .addImm(static_cast<uint32_t>(Dec)); 136 finishBuildOpDecorate(MIB, DecArgs, StrImm); 137 } 138 139 void buildOpDecorate(Register Reg, MachineInstr &I, const SPIRVInstrInfo &TII, 140 SPIRV::Decoration::Decoration Dec, 141 const std::vector<uint32_t> &DecArgs, StringRef StrImm) { 142 MachineBasicBlock &MBB = *I.getParent(); 143 auto MIB = BuildMI(MBB, I, I.getDebugLoc(), TII.get(SPIRV::OpDecorate)) 144 .addUse(Reg) 145 .addImm(static_cast<uint32_t>(Dec)); 146 finishBuildOpDecorate(MIB, DecArgs, StrImm); 147 } 148 149 void buildOpSpirvDecorations(Register Reg, MachineIRBuilder &MIRBuilder, 150 const MDNode *GVarMD) { 151 for (unsigned I = 0, E = GVarMD->getNumOperands(); I != E; ++I) { 152 auto *OpMD = dyn_cast<MDNode>(GVarMD->getOperand(I)); 153 if (!OpMD) 154 report_fatal_error("Invalid decoration"); 155 if (OpMD->getNumOperands() == 0) 156 report_fatal_error("Expect operand(s) of the decoration"); 157 ConstantInt *DecorationId = 158 mdconst::dyn_extract<ConstantInt>(OpMD->getOperand(0)); 159 if (!DecorationId) 160 report_fatal_error("Expect SPIR-V <Decoration> operand to be the first " 161 "element of the decoration"); 162 auto MIB = MIRBuilder.buildInstr(SPIRV::OpDecorate) 163 .addUse(Reg) 164 .addImm(static_cast<uint32_t>(DecorationId->getZExtValue())); 165 for (unsigned OpI = 1, OpE = OpMD->getNumOperands(); OpI != OpE; ++OpI) { 166 if (ConstantInt *OpV = 167 mdconst::dyn_extract<ConstantInt>(OpMD->getOperand(OpI))) 168 MIB.addImm(static_cast<uint32_t>(OpV->getZExtValue())); 169 else if (MDString *OpV = dyn_cast<MDString>(OpMD->getOperand(OpI))) 170 addStringImm(OpV->getString(), MIB); 171 else 172 report_fatal_error("Unexpected operand of the decoration"); 173 } 174 } 175 } 176 177 MachineBasicBlock::iterator getOpVariableMBBIt(MachineInstr &I) { 178 MachineFunction *MF = I.getParent()->getParent(); 179 MachineBasicBlock *MBB = &MF->front(); 180 MachineBasicBlock::iterator It = MBB->SkipPHIsAndLabels(MBB->begin()), 181 E = MBB->end(); 182 bool IsHeader = false; 183 unsigned Opcode; 184 for (; It != E && It != I; ++It) { 185 Opcode = It->getOpcode(); 186 if (Opcode == SPIRV::OpFunction || Opcode == SPIRV::OpFunctionParameter) { 187 IsHeader = true; 188 } else if (IsHeader && 189 !(Opcode == SPIRV::ASSIGN_TYPE || Opcode == SPIRV::OpLabel)) { 190 ++It; 191 break; 192 } 193 } 194 return It; 195 } 196 197 SPIRV::StorageClass::StorageClass 198 addressSpaceToStorageClass(unsigned AddrSpace, const SPIRVSubtarget &STI) { 199 switch (AddrSpace) { 200 case 0: 201 return SPIRV::StorageClass::Function; 202 case 1: 203 return SPIRV::StorageClass::CrossWorkgroup; 204 case 2: 205 return SPIRV::StorageClass::UniformConstant; 206 case 3: 207 return SPIRV::StorageClass::Workgroup; 208 case 4: 209 return SPIRV::StorageClass::Generic; 210 case 5: 211 return STI.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes) 212 ? SPIRV::StorageClass::DeviceOnlyINTEL 213 : SPIRV::StorageClass::CrossWorkgroup; 214 case 6: 215 return STI.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes) 216 ? SPIRV::StorageClass::HostOnlyINTEL 217 : SPIRV::StorageClass::CrossWorkgroup; 218 case 7: 219 return SPIRV::StorageClass::Input; 220 case 8: 221 return SPIRV::StorageClass::Output; 222 case 9: 223 return SPIRV::StorageClass::CodeSectionINTEL; 224 case 10: 225 return SPIRV::StorageClass::Private; 226 default: 227 report_fatal_error("Unknown address space"); 228 } 229 } 230 231 SPIRV::MemorySemantics::MemorySemantics 232 getMemSemanticsForStorageClass(SPIRV::StorageClass::StorageClass SC) { 233 switch (SC) { 234 case SPIRV::StorageClass::StorageBuffer: 235 case SPIRV::StorageClass::Uniform: 236 return SPIRV::MemorySemantics::UniformMemory; 237 case SPIRV::StorageClass::Workgroup: 238 return SPIRV::MemorySemantics::WorkgroupMemory; 239 case SPIRV::StorageClass::CrossWorkgroup: 240 return SPIRV::MemorySemantics::CrossWorkgroupMemory; 241 case SPIRV::StorageClass::AtomicCounter: 242 return SPIRV::MemorySemantics::AtomicCounterMemory; 243 case SPIRV::StorageClass::Image: 244 return SPIRV::MemorySemantics::ImageMemory; 245 default: 246 return SPIRV::MemorySemantics::None; 247 } 248 } 249 250 SPIRV::MemorySemantics::MemorySemantics getMemSemantics(AtomicOrdering Ord) { 251 switch (Ord) { 252 case AtomicOrdering::Acquire: 253 return SPIRV::MemorySemantics::Acquire; 254 case AtomicOrdering::Release: 255 return SPIRV::MemorySemantics::Release; 256 case AtomicOrdering::AcquireRelease: 257 return SPIRV::MemorySemantics::AcquireRelease; 258 case AtomicOrdering::SequentiallyConsistent: 259 return SPIRV::MemorySemantics::SequentiallyConsistent; 260 case AtomicOrdering::Unordered: 261 case AtomicOrdering::Monotonic: 262 case AtomicOrdering::NotAtomic: 263 return SPIRV::MemorySemantics::None; 264 } 265 llvm_unreachable(nullptr); 266 } 267 268 SPIRV::Scope::Scope getMemScope(LLVMContext &Ctx, SyncScope::ID Id) { 269 // Named by 270 // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_scope_id. 271 // We don't need aliases for Invocation and CrossDevice, as we already have 272 // them covered by "singlethread" and "" strings respectively (see 273 // implementation of LLVMContext::LLVMContext()). 274 static const llvm::SyncScope::ID SubGroup = 275 Ctx.getOrInsertSyncScopeID("subgroup"); 276 static const llvm::SyncScope::ID WorkGroup = 277 Ctx.getOrInsertSyncScopeID("workgroup"); 278 static const llvm::SyncScope::ID Device = 279 Ctx.getOrInsertSyncScopeID("device"); 280 281 if (Id == llvm::SyncScope::SingleThread) 282 return SPIRV::Scope::Invocation; 283 else if (Id == llvm::SyncScope::System) 284 return SPIRV::Scope::CrossDevice; 285 else if (Id == SubGroup) 286 return SPIRV::Scope::Subgroup; 287 else if (Id == WorkGroup) 288 return SPIRV::Scope::Workgroup; 289 else if (Id == Device) 290 return SPIRV::Scope::Device; 291 return SPIRV::Scope::CrossDevice; 292 } 293 294 MachineInstr *getDefInstrMaybeConstant(Register &ConstReg, 295 const MachineRegisterInfo *MRI) { 296 MachineInstr *MI = MRI->getVRegDef(ConstReg); 297 MachineInstr *ConstInstr = 298 MI->getOpcode() == SPIRV::G_TRUNC || MI->getOpcode() == SPIRV::G_ZEXT 299 ? MRI->getVRegDef(MI->getOperand(1).getReg()) 300 : MI; 301 if (auto *GI = dyn_cast<GIntrinsic>(ConstInstr)) { 302 if (GI->is(Intrinsic::spv_track_constant)) { 303 ConstReg = ConstInstr->getOperand(2).getReg(); 304 return MRI->getVRegDef(ConstReg); 305 } 306 } else if (ConstInstr->getOpcode() == SPIRV::ASSIGN_TYPE) { 307 ConstReg = ConstInstr->getOperand(1).getReg(); 308 return MRI->getVRegDef(ConstReg); 309 } 310 return MRI->getVRegDef(ConstReg); 311 } 312 313 uint64_t getIConstVal(Register ConstReg, const MachineRegisterInfo *MRI) { 314 const MachineInstr *MI = getDefInstrMaybeConstant(ConstReg, MRI); 315 assert(MI && MI->getOpcode() == TargetOpcode::G_CONSTANT); 316 return MI->getOperand(1).getCImm()->getValue().getZExtValue(); 317 } 318 319 bool isSpvIntrinsic(const MachineInstr &MI, Intrinsic::ID IntrinsicID) { 320 if (const auto *GI = dyn_cast<GIntrinsic>(&MI)) 321 return GI->is(IntrinsicID); 322 return false; 323 } 324 325 Type *getMDOperandAsType(const MDNode *N, unsigned I) { 326 Type *ElementTy = cast<ValueAsMetadata>(N->getOperand(I))->getType(); 327 return toTypedPointer(ElementTy); 328 } 329 330 // The set of names is borrowed from the SPIR-V translator. 331 // TODO: may be implemented in SPIRVBuiltins.td. 332 static bool isPipeOrAddressSpaceCastBI(const StringRef MangledName) { 333 return MangledName == "write_pipe_2" || MangledName == "read_pipe_2" || 334 MangledName == "write_pipe_2_bl" || MangledName == "read_pipe_2_bl" || 335 MangledName == "write_pipe_4" || MangledName == "read_pipe_4" || 336 MangledName == "reserve_write_pipe" || 337 MangledName == "reserve_read_pipe" || 338 MangledName == "commit_write_pipe" || 339 MangledName == "commit_read_pipe" || 340 MangledName == "work_group_reserve_write_pipe" || 341 MangledName == "work_group_reserve_read_pipe" || 342 MangledName == "work_group_commit_write_pipe" || 343 MangledName == "work_group_commit_read_pipe" || 344 MangledName == "get_pipe_num_packets_ro" || 345 MangledName == "get_pipe_max_packets_ro" || 346 MangledName == "get_pipe_num_packets_wo" || 347 MangledName == "get_pipe_max_packets_wo" || 348 MangledName == "sub_group_reserve_write_pipe" || 349 MangledName == "sub_group_reserve_read_pipe" || 350 MangledName == "sub_group_commit_write_pipe" || 351 MangledName == "sub_group_commit_read_pipe" || 352 MangledName == "to_global" || MangledName == "to_local" || 353 MangledName == "to_private"; 354 } 355 356 static bool isEnqueueKernelBI(const StringRef MangledName) { 357 return MangledName == "__enqueue_kernel_basic" || 358 MangledName == "__enqueue_kernel_basic_events" || 359 MangledName == "__enqueue_kernel_varargs" || 360 MangledName == "__enqueue_kernel_events_varargs"; 361 } 362 363 static bool isKernelQueryBI(const StringRef MangledName) { 364 return MangledName == "__get_kernel_work_group_size_impl" || 365 MangledName == "__get_kernel_sub_group_count_for_ndrange_impl" || 366 MangledName == "__get_kernel_max_sub_group_size_for_ndrange_impl" || 367 MangledName == "__get_kernel_preferred_work_group_size_multiple_impl"; 368 } 369 370 static bool isNonMangledOCLBuiltin(StringRef Name) { 371 if (!Name.starts_with("__")) 372 return false; 373 374 return isEnqueueKernelBI(Name) || isKernelQueryBI(Name) || 375 isPipeOrAddressSpaceCastBI(Name.drop_front(2)) || 376 Name == "__translate_sampler_initializer"; 377 } 378 379 std::string getOclOrSpirvBuiltinDemangledName(StringRef Name) { 380 bool IsNonMangledOCL = isNonMangledOCLBuiltin(Name); 381 bool IsNonMangledSPIRV = Name.starts_with("__spirv_"); 382 bool IsNonMangledHLSL = Name.starts_with("__hlsl_"); 383 bool IsMangled = Name.starts_with("_Z"); 384 385 // Otherwise use simple demangling to return the function name. 386 if (IsNonMangledOCL || IsNonMangledSPIRV || IsNonMangledHLSL || !IsMangled) 387 return Name.str(); 388 389 // Try to use the itanium demangler. 390 if (char *DemangledName = itaniumDemangle(Name.data())) { 391 std::string Result = DemangledName; 392 free(DemangledName); 393 return Result; 394 } 395 396 // Autocheck C++, maybe need to do explicit check of the source language. 397 // OpenCL C++ built-ins are declared in cl namespace. 398 // TODO: consider using 'St' abbriviation for cl namespace mangling. 399 // Similar to ::std:: in C++. 400 size_t Start, Len = 0; 401 size_t DemangledNameLenStart = 2; 402 if (Name.starts_with("_ZN")) { 403 // Skip CV and ref qualifiers. 404 size_t NameSpaceStart = Name.find_first_not_of("rVKRO", 3); 405 // All built-ins are in the ::cl:: namespace. 406 if (Name.substr(NameSpaceStart, 11) != "2cl7__spirv") 407 return std::string(); 408 DemangledNameLenStart = NameSpaceStart + 11; 409 } 410 Start = Name.find_first_not_of("0123456789", DemangledNameLenStart); 411 Name.substr(DemangledNameLenStart, Start - DemangledNameLenStart) 412 .getAsInteger(10, Len); 413 return Name.substr(Start, Len).str(); 414 } 415 416 bool hasBuiltinTypePrefix(StringRef Name) { 417 if (Name.starts_with("opencl.") || Name.starts_with("ocl_") || 418 Name.starts_with("spirv.")) 419 return true; 420 return false; 421 } 422 423 bool isSpecialOpaqueType(const Type *Ty) { 424 if (const TargetExtType *ExtTy = dyn_cast<TargetExtType>(Ty)) 425 return isTypedPointerWrapper(ExtTy) 426 ? false 427 : hasBuiltinTypePrefix(ExtTy->getName()); 428 429 return false; 430 } 431 432 bool isEntryPoint(const Function &F) { 433 // OpenCL handling: any function with the SPIR_KERNEL 434 // calling convention will be a potential entry point. 435 if (F.getCallingConv() == CallingConv::SPIR_KERNEL) 436 return true; 437 438 // HLSL handling: special attribute are emitted from the 439 // front-end. 440 if (F.getFnAttribute("hlsl.shader").isValid()) 441 return true; 442 443 return false; 444 } 445 446 Type *parseBasicTypeName(StringRef &TypeName, LLVMContext &Ctx) { 447 TypeName.consume_front("atomic_"); 448 if (TypeName.consume_front("void")) 449 return Type::getVoidTy(Ctx); 450 else if (TypeName.consume_front("bool")) 451 return Type::getIntNTy(Ctx, 1); 452 else if (TypeName.consume_front("char") || 453 TypeName.consume_front("unsigned char") || 454 TypeName.consume_front("uchar")) 455 return Type::getInt8Ty(Ctx); 456 else if (TypeName.consume_front("short") || 457 TypeName.consume_front("unsigned short") || 458 TypeName.consume_front("ushort")) 459 return Type::getInt16Ty(Ctx); 460 else if (TypeName.consume_front("int") || 461 TypeName.consume_front("unsigned int") || 462 TypeName.consume_front("uint")) 463 return Type::getInt32Ty(Ctx); 464 else if (TypeName.consume_front("long") || 465 TypeName.consume_front("unsigned long") || 466 TypeName.consume_front("ulong")) 467 return Type::getInt64Ty(Ctx); 468 else if (TypeName.consume_front("half")) 469 return Type::getHalfTy(Ctx); 470 else if (TypeName.consume_front("float")) 471 return Type::getFloatTy(Ctx); 472 else if (TypeName.consume_front("double")) 473 return Type::getDoubleTy(Ctx); 474 475 // Unable to recognize SPIRV type name 476 return nullptr; 477 } 478 479 std::unordered_set<BasicBlock *> 480 PartialOrderingVisitor::getReachableFrom(BasicBlock *Start) { 481 std::queue<BasicBlock *> ToVisit; 482 ToVisit.push(Start); 483 484 std::unordered_set<BasicBlock *> Output; 485 while (ToVisit.size() != 0) { 486 BasicBlock *BB = ToVisit.front(); 487 ToVisit.pop(); 488 489 if (Output.count(BB) != 0) 490 continue; 491 Output.insert(BB); 492 493 for (BasicBlock *Successor : successors(BB)) { 494 if (DT.dominates(Successor, BB)) 495 continue; 496 ToVisit.push(Successor); 497 } 498 } 499 500 return Output; 501 } 502 503 bool PartialOrderingVisitor::CanBeVisited(BasicBlock *BB) const { 504 for (BasicBlock *P : predecessors(BB)) { 505 // Ignore back-edges. 506 if (DT.dominates(BB, P)) 507 continue; 508 509 // One of the predecessor hasn't been visited. Not ready yet. 510 if (BlockToOrder.count(P) == 0) 511 return false; 512 513 // If the block is a loop exit, the loop must be finished before 514 // we can continue. 515 Loop *L = LI.getLoopFor(P); 516 if (L == nullptr || L->contains(BB)) 517 continue; 518 519 // SPIR-V requires a single back-edge. And the backend first 520 // step transforms loops into the simplified format. If we have 521 // more than 1 back-edge, something is wrong. 522 assert(L->getNumBackEdges() <= 1); 523 524 // If the loop has no latch, loop's rank won't matter, so we can 525 // proceed. 526 BasicBlock *Latch = L->getLoopLatch(); 527 assert(Latch); 528 if (Latch == nullptr) 529 continue; 530 531 // The latch is not ready yet, let's wait. 532 if (BlockToOrder.count(Latch) == 0) 533 return false; 534 } 535 536 return true; 537 } 538 539 size_t PartialOrderingVisitor::GetNodeRank(BasicBlock *BB) const { 540 auto It = BlockToOrder.find(BB); 541 if (It != BlockToOrder.end()) 542 return It->second.Rank; 543 544 size_t result = 0; 545 for (BasicBlock *P : predecessors(BB)) { 546 // Ignore back-edges. 547 if (DT.dominates(BB, P)) 548 continue; 549 550 auto Iterator = BlockToOrder.end(); 551 Loop *L = LI.getLoopFor(P); 552 BasicBlock *Latch = L ? L->getLoopLatch() : nullptr; 553 554 // If the predecessor is either outside a loop, or part of 555 // the same loop, simply take its rank + 1. 556 if (L == nullptr || L->contains(BB) || Latch == nullptr) { 557 Iterator = BlockToOrder.find(P); 558 } else { 559 // Otherwise, take the loop's rank (highest rank in the loop) as base. 560 // Since loops have a single latch, highest rank is easy to find. 561 // If the loop has no latch, then it doesn't matter. 562 Iterator = BlockToOrder.find(Latch); 563 } 564 565 assert(Iterator != BlockToOrder.end()); 566 result = std::max(result, Iterator->second.Rank + 1); 567 } 568 569 return result; 570 } 571 572 size_t PartialOrderingVisitor::visit(BasicBlock *BB, size_t Unused) { 573 ToVisit.push(BB); 574 Queued.insert(BB); 575 576 size_t QueueIndex = 0; 577 while (ToVisit.size() != 0) { 578 BasicBlock *BB = ToVisit.front(); 579 ToVisit.pop(); 580 581 if (!CanBeVisited(BB)) { 582 ToVisit.push(BB); 583 assert(QueueIndex < ToVisit.size() && 584 "No valid candidate in the queue. Is the graph reducible?"); 585 QueueIndex++; 586 continue; 587 } 588 589 QueueIndex = 0; 590 size_t Rank = GetNodeRank(BB); 591 OrderInfo Info = {Rank, BlockToOrder.size()}; 592 BlockToOrder.emplace(BB, Info); 593 594 for (BasicBlock *S : successors(BB)) { 595 if (Queued.count(S) != 0) 596 continue; 597 ToVisit.push(S); 598 Queued.insert(S); 599 } 600 } 601 602 return 0; 603 } 604 605 PartialOrderingVisitor::PartialOrderingVisitor(Function &F) { 606 DT.recalculate(F); 607 LI = LoopInfo(DT); 608 609 visit(&*F.begin(), 0); 610 611 Order.reserve(F.size()); 612 for (auto &[BB, Info] : BlockToOrder) 613 Order.emplace_back(BB); 614 615 std::sort(Order.begin(), Order.end(), [&](const auto &LHS, const auto &RHS) { 616 return compare(LHS, RHS); 617 }); 618 } 619 620 bool PartialOrderingVisitor::compare(const BasicBlock *LHS, 621 const BasicBlock *RHS) const { 622 const OrderInfo &InfoLHS = BlockToOrder.at(const_cast<BasicBlock *>(LHS)); 623 const OrderInfo &InfoRHS = BlockToOrder.at(const_cast<BasicBlock *>(RHS)); 624 if (InfoLHS.Rank != InfoRHS.Rank) 625 return InfoLHS.Rank < InfoRHS.Rank; 626 return InfoLHS.TraversalIndex < InfoRHS.TraversalIndex; 627 } 628 629 void PartialOrderingVisitor::partialOrderVisit( 630 BasicBlock &Start, std::function<bool(BasicBlock *)> Op) { 631 std::unordered_set<BasicBlock *> Reachable = getReachableFrom(&Start); 632 assert(BlockToOrder.count(&Start) != 0); 633 634 // Skipping blocks with a rank inferior to |Start|'s rank. 635 auto It = Order.begin(); 636 while (It != Order.end() && *It != &Start) 637 ++It; 638 639 // This is unexpected. Worst case |Start| is the last block, 640 // so It should point to the last block, not past-end. 641 assert(It != Order.end()); 642 643 // By default, there is no rank limit. Setting it to the maximum value. 644 std::optional<size_t> EndRank = std::nullopt; 645 for (; It != Order.end(); ++It) { 646 if (EndRank.has_value() && BlockToOrder[*It].Rank > *EndRank) 647 break; 648 649 if (Reachable.count(*It) == 0) { 650 continue; 651 } 652 653 if (!Op(*It)) { 654 EndRank = BlockToOrder[*It].Rank; 655 } 656 } 657 } 658 659 bool sortBlocks(Function &F) { 660 if (F.size() == 0) 661 return false; 662 663 bool Modified = false; 664 std::vector<BasicBlock *> Order; 665 Order.reserve(F.size()); 666 667 ReversePostOrderTraversal<Function *> RPOT(&F); 668 for (BasicBlock *BB : RPOT) 669 Order.push_back(BB); 670 671 assert(&*F.begin() == Order[0]); 672 BasicBlock *LastBlock = &*F.begin(); 673 for (BasicBlock *BB : Order) { 674 if (BB != LastBlock && &*LastBlock->getNextNode() != BB) { 675 Modified = true; 676 BB->moveAfter(LastBlock); 677 } 678 LastBlock = BB; 679 } 680 681 return Modified; 682 } 683 684 MachineInstr *getVRegDef(MachineRegisterInfo &MRI, Register Reg) { 685 MachineInstr *MaybeDef = MRI.getVRegDef(Reg); 686 if (MaybeDef && MaybeDef->getOpcode() == SPIRV::ASSIGN_TYPE) 687 MaybeDef = MRI.getVRegDef(MaybeDef->getOperand(1).getReg()); 688 return MaybeDef; 689 } 690 691 bool getVacantFunctionName(Module &M, std::string &Name) { 692 // It's a bit of paranoia, but still we don't want to have even a chance that 693 // the loop will work for too long. 694 constexpr unsigned MaxIters = 1024; 695 for (unsigned I = 0; I < MaxIters; ++I) { 696 std::string OrdName = Name + Twine(I).str(); 697 if (!M.getFunction(OrdName)) { 698 Name = OrdName; 699 return true; 700 } 701 } 702 return false; 703 } 704 705 // Assign SPIR-V type to the register. If the register has no valid assigned 706 // class, set register LLT type and class according to the SPIR-V type. 707 void setRegClassType(Register Reg, SPIRVType *SpvType, SPIRVGlobalRegistry *GR, 708 MachineRegisterInfo *MRI, const MachineFunction &MF, 709 bool Force) { 710 GR->assignSPIRVTypeToVReg(SpvType, Reg, MF); 711 if (!MRI->getRegClassOrNull(Reg) || Force) { 712 MRI->setRegClass(Reg, GR->getRegClass(SpvType)); 713 MRI->setType(Reg, GR->getRegType(SpvType)); 714 } 715 } 716 717 // Create a SPIR-V type, assign SPIR-V type to the register. If the register has 718 // no valid assigned class, set register LLT type and class according to the 719 // SPIR-V type. 720 void setRegClassType(Register Reg, const Type *Ty, SPIRVGlobalRegistry *GR, 721 MachineIRBuilder &MIRBuilder, bool Force) { 722 setRegClassType(Reg, GR->getOrCreateSPIRVType(Ty, MIRBuilder), GR, 723 MIRBuilder.getMRI(), MIRBuilder.getMF(), Force); 724 } 725 726 // Create a virtual register and assign SPIR-V type to the register. Set 727 // register LLT type and class according to the SPIR-V type. 728 Register createVirtualRegister(SPIRVType *SpvType, SPIRVGlobalRegistry *GR, 729 MachineRegisterInfo *MRI, 730 const MachineFunction &MF) { 731 Register Reg = MRI->createVirtualRegister(GR->getRegClass(SpvType)); 732 MRI->setType(Reg, GR->getRegType(SpvType)); 733 GR->assignSPIRVTypeToVReg(SpvType, Reg, MF); 734 return Reg; 735 } 736 737 // Create a virtual register and assign SPIR-V type to the register. Set 738 // register LLT type and class according to the SPIR-V type. 739 Register createVirtualRegister(SPIRVType *SpvType, SPIRVGlobalRegistry *GR, 740 MachineIRBuilder &MIRBuilder) { 741 return createVirtualRegister(SpvType, GR, MIRBuilder.getMRI(), 742 MIRBuilder.getMF()); 743 } 744 745 // Create a SPIR-V type, virtual register and assign SPIR-V type to the 746 // register. Set register LLT type and class according to the SPIR-V type. 747 Register createVirtualRegister(const Type *Ty, SPIRVGlobalRegistry *GR, 748 MachineIRBuilder &MIRBuilder) { 749 return createVirtualRegister(GR->getOrCreateSPIRVType(Ty, MIRBuilder), GR, 750 MIRBuilder); 751 } 752 753 // Return true if there is an opaque pointer type nested in the argument. 754 bool isNestedPointer(const Type *Ty) { 755 if (Ty->isPtrOrPtrVectorTy()) 756 return true; 757 if (const FunctionType *RefTy = dyn_cast<FunctionType>(Ty)) { 758 if (isNestedPointer(RefTy->getReturnType())) 759 return true; 760 for (const Type *ArgTy : RefTy->params()) 761 if (isNestedPointer(ArgTy)) 762 return true; 763 return false; 764 } 765 if (const ArrayType *RefTy = dyn_cast<ArrayType>(Ty)) 766 return isNestedPointer(RefTy->getElementType()); 767 return false; 768 } 769 770 bool isSpvIntrinsic(const Value *Arg) { 771 if (const auto *II = dyn_cast<IntrinsicInst>(Arg)) 772 if (Function *F = II->getCalledFunction()) 773 if (F->getName().starts_with("llvm.spv.")) 774 return true; 775 return false; 776 } 777 778 } // namespace llvm 779