//===--- SPIRVUtils.cpp ---- SPIR-V Utility Functions -----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains miscellaneous utility functions.
//
//===----------------------------------------------------------------------===//

#include "SPIRVUtils.h"
#include "MCTargetDesc/SPIRVBaseInfo.h"
#include "SPIRV.h"
#include "SPIRVGlobalRegistry.h"
#include "SPIRVInstrInfo.h"
#include "SPIRVSubtarget.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/Demangle/Demangle.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsSPIRV.h"
#include <queue>
#include <vector>

namespace llvm {

// SPIR-V requires null-terminated UTF-8 strings padded to 32-bit alignment.
// The following functions add such string literals as a series of 32-bit
// integer operands in the correct format, and unpack them if necessary when
// making string comparisons in compiler passes.
static uint32_t convertCharsToWord(const StringRef &Str, unsigned i) {
  uint32_t Word = 0u; // Build up this 32-bit word from 4 8-bit chars.
  for (unsigned WordIndex = 0; WordIndex < 4; ++WordIndex) {
    unsigned StrIndex = i + WordIndex;
    uint8_t CharToAdd = 0;       // Initialize char as padding/null.
    if (StrIndex < Str.size()) { // If it's within the string, get a real char.
      CharToAdd = Str[StrIndex];
    }
    Word |= (CharToAdd << (WordIndex * 8));
  }
  return Word;
}

// Get length including padding and null terminator.
static size_t getPaddedLen(const StringRef &Str) {
  return (Str.size() + 4) & ~3;
}

void addStringImm(const StringRef &Str, MCInst &Inst) {
  const size_t PaddedLen = getPaddedLen(Str);
  for (unsigned i = 0; i < PaddedLen; i += 4) {
    // Add an operand for the 32 bits of chars or padding.
    Inst.addOperand(MCOperand::createImm(convertCharsToWord(Str, i)));
  }
}

void addStringImm(const StringRef &Str, MachineInstrBuilder &MIB) {
  const size_t PaddedLen = getPaddedLen(Str);
  for (unsigned i = 0; i < PaddedLen; i += 4) {
    // Add an operand for the 32 bits of chars or padding.
    MIB.addImm(convertCharsToWord(Str, i));
  }
}

void addStringImm(const StringRef &Str, IRBuilder<> &B,
                  std::vector<Value *> &Args) {
  const size_t PaddedLen = getPaddedLen(Str);
  for (unsigned i = 0; i < PaddedLen; i += 4) {
    // Add a vector element for the 32 bits of chars or padding.
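    // For example (illustrative): "abc" is emitted as the single word
    // 0x00636261 -- 'a' in the low byte, then 'b', 'c', and a zero pad byte.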
    Args.push_back(B.getInt32(convertCharsToWord(Str, i)));
  }
}

std::string getStringImm(const MachineInstr &MI, unsigned StartIndex) {
  return getSPIRVStringOperand(MI, StartIndex);
}

void addNumImm(const APInt &Imm, MachineInstrBuilder &MIB) {
  const auto Bitwidth = Imm.getBitWidth();
  if (Bitwidth == 1)
    return; // Already handled
  else if (Bitwidth <= 32) {
    MIB.addImm(Imm.getZExtValue());
    // The asm printer needs this info to print floating-point types correctly.
    if (Bitwidth == 16)
      MIB.getInstr()->setAsmPrinterFlag(SPIRV::ASM_PRINTER_WIDTH16);
    return;
  } else if (Bitwidth <= 64) {
    uint64_t FullImm = Imm.getZExtValue();
    uint32_t LowBits = FullImm & 0xffffffff;
    uint32_t HighBits = (FullImm >> 32) & 0xffffffff;
    MIB.addImm(LowBits).addImm(HighBits);
    return;
  }
  report_fatal_error("Unsupported constant bitwidth");
}

void buildOpName(Register Target, const StringRef &Name,
                 MachineIRBuilder &MIRBuilder) {
  if (!Name.empty()) {
    auto MIB = MIRBuilder.buildInstr(SPIRV::OpName).addUse(Target);
    addStringImm(Name, MIB);
  }
}

void buildOpName(Register Target, const StringRef &Name, MachineInstr &I,
                 const SPIRVInstrInfo &TII) {
  if (!Name.empty()) {
    auto MIB =
        BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpName))
            .addUse(Target);
    addStringImm(Name, MIB);
  }
}

static void finishBuildOpDecorate(MachineInstrBuilder &MIB,
                                  const std::vector<uint32_t> &DecArgs,
                                  StringRef StrImm) {
  if (!StrImm.empty())
    addStringImm(StrImm, MIB);
  for (const auto &DecArg : DecArgs)
    MIB.addImm(DecArg);
}

void buildOpDecorate(Register Reg, MachineIRBuilder &MIRBuilder,
                     SPIRV::Decoration::Decoration Dec,
                     const std::vector<uint32_t> &DecArgs, StringRef StrImm) {
  auto MIB = MIRBuilder.buildInstr(SPIRV::OpDecorate)
                 .addUse(Reg)
                 .addImm(static_cast<uint32_t>(Dec));
  finishBuildOpDecorate(MIB, DecArgs, StrImm);
}

void buildOpDecorate(Register Reg, MachineInstr &I, const SPIRVInstrInfo &TII,
                     SPIRV::Decoration::Decoration Dec,
                     const std::vector<uint32_t> &DecArgs, StringRef StrImm) {
  MachineBasicBlock &MBB = *I.getParent();
  auto MIB = BuildMI(MBB, I, I.getDebugLoc(), TII.get(SPIRV::OpDecorate))
                 .addUse(Reg)
                 .addImm(static_cast<uint32_t>(Dec));
  finishBuildOpDecorate(MIB, DecArgs, StrImm);
}

void buildOpSpirvDecorations(Register Reg, MachineIRBuilder &MIRBuilder,
                             const MDNode *GVarMD) {
  for (unsigned I = 0, E = GVarMD->getNumOperands(); I != E; ++I) {
    auto *OpMD = dyn_cast<MDNode>(GVarMD->getOperand(I));
    if (!OpMD)
      report_fatal_error("Invalid decoration");
    if (OpMD->getNumOperands() == 0)
      report_fatal_error("Expect operand(s) of the decoration");
    ConstantInt *DecorationId =
        mdconst::dyn_extract<ConstantInt>(OpMD->getOperand(0));
    if (!DecorationId)
      report_fatal_error("Expect SPIR-V <Decoration> operand to be the first "
                         "element of the decoration");
    auto MIB = MIRBuilder.buildInstr(SPIRV::OpDecorate)
                   .addUse(Reg)
                   .addImm(static_cast<uint32_t>(DecorationId->getZExtValue()));
    for (unsigned OpI = 1, OpE = OpMD->getNumOperands(); OpI != OpE; ++OpI) {
      if (ConstantInt *OpV =
              mdconst::dyn_extract<ConstantInt>(OpMD->getOperand(OpI)))
        MIB.addImm(static_cast<uint32_t>(OpV->getZExtValue()));
      else if (MDString *OpV = dyn_cast<MDString>(OpMD->getOperand(OpI)))
        addStringImm(OpV->getString(), MIB);
      else
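        // Decoration operands other than integer and string literals are not
        // supported.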
report_fatal_error("Unexpected operand of the decoration"); 173 } 174 } 175 } 176 177 MachineBasicBlock::iterator getOpVariableMBBIt(MachineInstr &I) { 178 MachineFunction *MF = I.getParent()->getParent(); 179 MachineBasicBlock *MBB = &MF->front(); 180 MachineBasicBlock::iterator It = MBB->SkipPHIsAndLabels(MBB->begin()), 181 E = MBB->end(); 182 bool IsHeader = false; 183 unsigned Opcode; 184 for (; It != E && It != I; ++It) { 185 Opcode = It->getOpcode(); 186 if (Opcode == SPIRV::OpFunction || Opcode == SPIRV::OpFunctionParameter) { 187 IsHeader = true; 188 } else if (IsHeader && 189 !(Opcode == SPIRV::ASSIGN_TYPE || Opcode == SPIRV::OpLabel)) { 190 ++It; 191 break; 192 } 193 } 194 return It; 195 } 196 197 MachineBasicBlock::iterator getInsertPtValidEnd(MachineBasicBlock *MBB) { 198 MachineBasicBlock::iterator I = MBB->end(); 199 if (I == MBB->begin()) 200 return I; 201 --I; 202 while (I->isTerminator() || I->isDebugValue()) { 203 if (I == MBB->begin()) 204 break; 205 --I; 206 } 207 return I; 208 } 209 210 SPIRV::StorageClass::StorageClass 211 addressSpaceToStorageClass(unsigned AddrSpace, const SPIRVSubtarget &STI) { 212 switch (AddrSpace) { 213 case 0: 214 return SPIRV::StorageClass::Function; 215 case 1: 216 return SPIRV::StorageClass::CrossWorkgroup; 217 case 2: 218 return SPIRV::StorageClass::UniformConstant; 219 case 3: 220 return SPIRV::StorageClass::Workgroup; 221 case 4: 222 return SPIRV::StorageClass::Generic; 223 case 5: 224 return STI.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes) 225 ? SPIRV::StorageClass::DeviceOnlyINTEL 226 : SPIRV::StorageClass::CrossWorkgroup; 227 case 6: 228 return STI.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes) 229 ? SPIRV::StorageClass::HostOnlyINTEL 230 : SPIRV::StorageClass::CrossWorkgroup; 231 case 7: 232 return SPIRV::StorageClass::Input; 233 case 8: 234 return SPIRV::StorageClass::Output; 235 case 9: 236 return SPIRV::StorageClass::CodeSectionINTEL; 237 case 10: 238 return SPIRV::StorageClass::Private; 239 default: 240 report_fatal_error("Unknown address space"); 241 } 242 } 243 244 SPIRV::MemorySemantics::MemorySemantics 245 getMemSemanticsForStorageClass(SPIRV::StorageClass::StorageClass SC) { 246 switch (SC) { 247 case SPIRV::StorageClass::StorageBuffer: 248 case SPIRV::StorageClass::Uniform: 249 return SPIRV::MemorySemantics::UniformMemory; 250 case SPIRV::StorageClass::Workgroup: 251 return SPIRV::MemorySemantics::WorkgroupMemory; 252 case SPIRV::StorageClass::CrossWorkgroup: 253 return SPIRV::MemorySemantics::CrossWorkgroupMemory; 254 case SPIRV::StorageClass::AtomicCounter: 255 return SPIRV::MemorySemantics::AtomicCounterMemory; 256 case SPIRV::StorageClass::Image: 257 return SPIRV::MemorySemantics::ImageMemory; 258 default: 259 return SPIRV::MemorySemantics::None; 260 } 261 } 262 263 SPIRV::MemorySemantics::MemorySemantics getMemSemantics(AtomicOrdering Ord) { 264 switch (Ord) { 265 case AtomicOrdering::Acquire: 266 return SPIRV::MemorySemantics::Acquire; 267 case AtomicOrdering::Release: 268 return SPIRV::MemorySemantics::Release; 269 case AtomicOrdering::AcquireRelease: 270 return SPIRV::MemorySemantics::AcquireRelease; 271 case AtomicOrdering::SequentiallyConsistent: 272 return SPIRV::MemorySemantics::SequentiallyConsistent; 273 case AtomicOrdering::Unordered: 274 case AtomicOrdering::Monotonic: 275 case AtomicOrdering::NotAtomic: 276 return SPIRV::MemorySemantics::None; 277 } 278 llvm_unreachable(nullptr); 279 } 280 281 SPIRV::Scope::Scope getMemScope(LLVMContext &Ctx, SyncScope::ID Id) { 282 // 
SPIRV::Scope::Scope getMemScope(LLVMContext &Ctx, SyncScope::ID Id) {
  // Named by
  // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_scope_id.
  // We don't need aliases for Invocation and CrossDevice, as we already have
  // them covered by "singlethread" and "" strings respectively (see
  // implementation of LLVMContext::LLVMContext()).
  static const llvm::SyncScope::ID SubGroup =
      Ctx.getOrInsertSyncScopeID("subgroup");
  static const llvm::SyncScope::ID WorkGroup =
      Ctx.getOrInsertSyncScopeID("workgroup");
  static const llvm::SyncScope::ID Device =
      Ctx.getOrInsertSyncScopeID("device");

  if (Id == llvm::SyncScope::SingleThread)
    return SPIRV::Scope::Invocation;
  else if (Id == llvm::SyncScope::System)
    return SPIRV::Scope::CrossDevice;
  else if (Id == SubGroup)
    return SPIRV::Scope::Subgroup;
  else if (Id == WorkGroup)
    return SPIRV::Scope::Workgroup;
  else if (Id == Device)
    return SPIRV::Scope::Device;
  return SPIRV::Scope::CrossDevice;
}

MachineInstr *getDefInstrMaybeConstant(Register &ConstReg,
                                       const MachineRegisterInfo *MRI) {
  MachineInstr *MI = MRI->getVRegDef(ConstReg);
  MachineInstr *ConstInstr =
      MI->getOpcode() == SPIRV::G_TRUNC || MI->getOpcode() == SPIRV::G_ZEXT
          ? MRI->getVRegDef(MI->getOperand(1).getReg())
          : MI;
  if (auto *GI = dyn_cast<GIntrinsic>(ConstInstr)) {
    if (GI->is(Intrinsic::spv_track_constant)) {
      ConstReg = ConstInstr->getOperand(2).getReg();
      return MRI->getVRegDef(ConstReg);
    }
  } else if (ConstInstr->getOpcode() == SPIRV::ASSIGN_TYPE) {
    ConstReg = ConstInstr->getOperand(1).getReg();
    return MRI->getVRegDef(ConstReg);
  }
  return MRI->getVRegDef(ConstReg);
}

uint64_t getIConstVal(Register ConstReg, const MachineRegisterInfo *MRI) {
  const MachineInstr *MI = getDefInstrMaybeConstant(ConstReg, MRI);
  assert(MI && MI->getOpcode() == TargetOpcode::G_CONSTANT);
  return MI->getOperand(1).getCImm()->getValue().getZExtValue();
}

bool isSpvIntrinsic(const MachineInstr &MI, Intrinsic::ID IntrinsicID) {
  if (const auto *GI = dyn_cast<GIntrinsic>(&MI))
    return GI->is(IntrinsicID);
  return false;
}

Type *getMDOperandAsType(const MDNode *N, unsigned I) {
  Type *ElementTy = cast<ValueAsMetadata>(N->getOperand(I))->getType();
  return toTypedPointer(ElementTy);
}

// The set of names is borrowed from the SPIR-V translator.
// TODO: may be implemented in SPIRVBuiltins.td.
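// Matches non-mangled OpenCL built-ins that operate on pipes or perform
// address space casts, e.g. "read_pipe_2" or "to_global" (see the full list
// below).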
static bool isPipeOrAddressSpaceCastBI(const StringRef MangledName) {
  return MangledName == "write_pipe_2" || MangledName == "read_pipe_2" ||
         MangledName == "write_pipe_2_bl" || MangledName == "read_pipe_2_bl" ||
         MangledName == "write_pipe_4" || MangledName == "read_pipe_4" ||
         MangledName == "reserve_write_pipe" ||
         MangledName == "reserve_read_pipe" ||
         MangledName == "commit_write_pipe" ||
         MangledName == "commit_read_pipe" ||
         MangledName == "work_group_reserve_write_pipe" ||
         MangledName == "work_group_reserve_read_pipe" ||
         MangledName == "work_group_commit_write_pipe" ||
         MangledName == "work_group_commit_read_pipe" ||
         MangledName == "get_pipe_num_packets_ro" ||
         MangledName == "get_pipe_max_packets_ro" ||
         MangledName == "get_pipe_num_packets_wo" ||
         MangledName == "get_pipe_max_packets_wo" ||
         MangledName == "sub_group_reserve_write_pipe" ||
         MangledName == "sub_group_reserve_read_pipe" ||
         MangledName == "sub_group_commit_write_pipe" ||
         MangledName == "sub_group_commit_read_pipe" ||
         MangledName == "to_global" || MangledName == "to_local" ||
         MangledName == "to_private";
}

static bool isEnqueueKernelBI(const StringRef MangledName) {
  return MangledName == "__enqueue_kernel_basic" ||
         MangledName == "__enqueue_kernel_basic_events" ||
         MangledName == "__enqueue_kernel_varargs" ||
         MangledName == "__enqueue_kernel_events_varargs";
}

static bool isKernelQueryBI(const StringRef MangledName) {
  return MangledName == "__get_kernel_work_group_size_impl" ||
         MangledName == "__get_kernel_sub_group_count_for_ndrange_impl" ||
         MangledName == "__get_kernel_max_sub_group_size_for_ndrange_impl" ||
         MangledName == "__get_kernel_preferred_work_group_size_multiple_impl";
}

static bool isNonMangledOCLBuiltin(StringRef Name) {
  if (!Name.starts_with("__"))
    return false;

  return isEnqueueKernelBI(Name) || isKernelQueryBI(Name) ||
         isPipeOrAddressSpaceCastBI(Name.drop_front(2)) ||
         Name == "__translate_sampler_initializer";
}

std::string getOclOrSpirvBuiltinDemangledName(StringRef Name) {
  bool IsNonMangledOCL = isNonMangledOCLBuiltin(Name);
  bool IsNonMangledSPIRV = Name.starts_with("__spirv_");
  bool IsNonMangledHLSL = Name.starts_with("__hlsl_");
  bool IsMangled = Name.starts_with("_Z");

  // Non-mangled names (or anything that is not Itanium-mangled) are returned
  // as is.
  if (IsNonMangledOCL || IsNonMangledSPIRV || IsNonMangledHLSL || !IsMangled)
    return Name.str();

  // Try to use the itanium demangler.
  if (char *DemangledName = itaniumDemangle(Name.data())) {
    std::string Result = DemangledName;
    free(DemangledName);
    return Result;
  }

  // Autocheck C++; we may need an explicit check of the source language.
  // OpenCL C++ built-ins are declared in the cl namespace.
  // TODO: consider using the 'St' abbreviation for cl namespace mangling,
  // similar to ::std:: in C++.
  size_t Start, Len = 0;
  size_t DemangledNameLenStart = 2;
  if (Name.starts_with("_ZN")) {
    // Skip CV and ref qualifiers.
    size_t NameSpaceStart = Name.find_first_not_of("rVKRO", 3);
    // All built-ins are in the ::cl:: namespace.
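    // "2cl7__spirv" is the Itanium mangling of the nested namespace
    // cl::__spirv (each component is prefixed with its length).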
    if (Name.substr(NameSpaceStart, 11) != "2cl7__spirv")
      return std::string();
    DemangledNameLenStart = NameSpaceStart + 11;
  }
  Start = Name.find_first_not_of("0123456789", DemangledNameLenStart);
  Name.substr(DemangledNameLenStart, Start - DemangledNameLenStart)
      .getAsInteger(10, Len);
  return Name.substr(Start, Len).str();
}

bool hasBuiltinTypePrefix(StringRef Name) {
  if (Name.starts_with("opencl.") || Name.starts_with("ocl_") ||
      Name.starts_with("spirv."))
    return true;
  return false;
}

bool isSpecialOpaqueType(const Type *Ty) {
  if (const TargetExtType *ExtTy = dyn_cast<TargetExtType>(Ty))
    return isTypedPointerWrapper(ExtTy)
               ? false
               : hasBuiltinTypePrefix(ExtTy->getName());

  return false;
}

bool isEntryPoint(const Function &F) {
  // OpenCL handling: any function with the SPIR_KERNEL
  // calling convention will be a potential entry point.
  if (F.getCallingConv() == CallingConv::SPIR_KERNEL)
    return true;

  // HLSL handling: special attributes are emitted from the
  // front-end.
  if (F.getFnAttribute("hlsl.shader").isValid())
    return true;

  return false;
}

Type *parseBasicTypeName(StringRef &TypeName, LLVMContext &Ctx) {
  TypeName.consume_front("atomic_");
  if (TypeName.consume_front("void"))
    return Type::getVoidTy(Ctx);
  else if (TypeName.consume_front("bool") || TypeName.consume_front("_Bool"))
    return Type::getIntNTy(Ctx, 1);
  else if (TypeName.consume_front("char") ||
           TypeName.consume_front("signed char") ||
           TypeName.consume_front("unsigned char") ||
           TypeName.consume_front("uchar"))
    return Type::getInt8Ty(Ctx);
  else if (TypeName.consume_front("short") ||
           TypeName.consume_front("signed short") ||
           TypeName.consume_front("unsigned short") ||
           TypeName.consume_front("ushort"))
    return Type::getInt16Ty(Ctx);
  else if (TypeName.consume_front("int") ||
           TypeName.consume_front("signed int") ||
           TypeName.consume_front("unsigned int") ||
           TypeName.consume_front("uint"))
    return Type::getInt32Ty(Ctx);
  else if (TypeName.consume_front("long") ||
           TypeName.consume_front("signed long") ||
           TypeName.consume_front("unsigned long") ||
           TypeName.consume_front("ulong"))
    return Type::getInt64Ty(Ctx);
  else if (TypeName.consume_front("half") ||
           TypeName.consume_front("_Float16") ||
           TypeName.consume_front("__fp16"))
    return Type::getHalfTy(Ctx);
  else if (TypeName.consume_front("float"))
    return Type::getFloatTy(Ctx);
  else if (TypeName.consume_front("double"))
    return Type::getDoubleTy(Ctx);

  // Unable to recognize the SPIR-V type name.
  return nullptr;
}

std::unordered_set<BasicBlock *>
PartialOrderingVisitor::getReachableFrom(BasicBlock *Start) {
  std::queue<BasicBlock *> ToVisit;
  ToVisit.push(Start);

  std::unordered_set<BasicBlock *> Output;
  while (ToVisit.size() != 0) {
    BasicBlock *BB = ToVisit.front();
    ToVisit.pop();

    if (Output.count(BB) != 0)
      continue;
    Output.insert(BB);

    for (BasicBlock *Successor : successors(BB)) {
      if (DT.dominates(Successor, BB))
        continue;
      ToVisit.push(Successor);
    }
  }

  return Output;
}

bool PartialOrderingVisitor::CanBeVisited(BasicBlock *BB) const {
  for (BasicBlock *P : predecessors(BB)) {
    // Ignore back-edges.
    if (DT.dominates(BB, P))
      continue;

    // One of the predecessors hasn't been visited. Not ready yet.
    if (BlockToOrder.count(P) == 0)
      return false;

    // If the block is a loop exit, the loop must be finished before
    // we can continue.
    Loop *L = LI.getLoopFor(P);
    if (L == nullptr || L->contains(BB))
      continue;

    // SPIR-V requires a single back-edge, and the backend's first step
    // transforms loops into this simplified form. If we have more than one
    // back-edge, something is wrong.
    assert(L->getNumBackEdges() <= 1);

    // If the loop has no latch, the loop's rank won't matter, so we can
    // proceed.
    BasicBlock *Latch = L->getLoopLatch();
    assert(Latch);
    if (Latch == nullptr)
      continue;

    // The latch is not ready yet, let's wait.
    if (BlockToOrder.count(Latch) == 0)
      return false;
  }

  return true;
}

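// A block's rank is, roughly, the length of the longest path of forward edges
// from the entry block to it. Illustrative example: for a diamond
// A -> {B, C} -> D, A has rank 0, B and C rank 1, and D rank 2. Loop exits
// take the latch's rank as their base, so the whole loop body is ranked
// before the exit.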
size_t PartialOrderingVisitor::GetNodeRank(BasicBlock *BB) const {
  auto It = BlockToOrder.find(BB);
  if (It != BlockToOrder.end())
    return It->second.Rank;

  size_t result = 0;
  for (BasicBlock *P : predecessors(BB)) {
    // Ignore back-edges.
    if (DT.dominates(BB, P))
      continue;

    auto Iterator = BlockToOrder.end();
    Loop *L = LI.getLoopFor(P);
    BasicBlock *Latch = L ? L->getLoopLatch() : nullptr;

    // If the predecessor is either outside a loop, or part of
    // the same loop, simply take its rank + 1.
    if (L == nullptr || L->contains(BB) || Latch == nullptr) {
      Iterator = BlockToOrder.find(P);
    } else {
      // Otherwise, take the loop's rank (the highest rank in the loop) as the
      // base. Since loops have a single latch, the highest rank is easy to
      // find. If the loop has no latch, then it doesn't matter.
      Iterator = BlockToOrder.find(Latch);
    }

    assert(Iterator != BlockToOrder.end());
    result = std::max(result, Iterator->second.Rank + 1);
  }

  return result;
}

size_t PartialOrderingVisitor::visit(BasicBlock *BB, size_t Unused) {
  ToVisit.push(BB);
  Queued.insert(BB);

  size_t QueueIndex = 0;
  while (ToVisit.size() != 0) {
    BasicBlock *BB = ToVisit.front();
    ToVisit.pop();

    if (!CanBeVisited(BB)) {
      ToVisit.push(BB);
      if (QueueIndex >= ToVisit.size())
        llvm::report_fatal_error(
            "No valid candidate in the queue. Is the graph reducible?");
      QueueIndex++;
      continue;
    }

    QueueIndex = 0;
    size_t Rank = GetNodeRank(BB);
    OrderInfo Info = {Rank, BlockToOrder.size()};
    BlockToOrder.emplace(BB, Info);

    for (BasicBlock *S : successors(BB)) {
      if (Queued.count(S) != 0)
        continue;
      ToVisit.push(S);
      Queued.insert(S);
    }
  }

  return 0;
}

PartialOrderingVisitor::PartialOrderingVisitor(Function &F) {
  DT.recalculate(F);
  LI = LoopInfo(DT);

  visit(&*F.begin(), 0);

  Order.reserve(F.size());
  for (auto &[BB, Info] : BlockToOrder)
    Order.emplace_back(BB);

  std::sort(Order.begin(), Order.end(), [&](const auto &LHS, const auto &RHS) {
    return compare(LHS, RHS);
  });
}

bool PartialOrderingVisitor::compare(const BasicBlock *LHS,
                                     const BasicBlock *RHS) const {
  const OrderInfo &InfoLHS = BlockToOrder.at(const_cast<BasicBlock *>(LHS));
  const OrderInfo &InfoRHS = BlockToOrder.at(const_cast<BasicBlock *>(RHS));
  if (InfoLHS.Rank != InfoRHS.Rank)
    return InfoLHS.Rank < InfoRHS.Rank;
  return InfoLHS.TraversalIndex < InfoRHS.TraversalIndex;
}

void PartialOrderingVisitor::partialOrderVisit(
    BasicBlock &Start, std::function<bool(BasicBlock *)> Op) {
  std::unordered_set<BasicBlock *> Reachable = getReachableFrom(&Start);
  assert(BlockToOrder.count(&Start) != 0);

  // Skip blocks with a rank lower than |Start|'s rank.
  auto It = Order.begin();
  while (It != Order.end() && *It != &Start)
    ++It;

  // Reaching the end would be unexpected: in the worst case |Start| is the
  // last block, so It should point to it, not past the end.
  assert(It != Order.end());

  // By default there is no rank limit, so EndRank stays unset.
  std::optional<size_t> EndRank = std::nullopt;
  for (; It != Order.end(); ++It) {
    if (EndRank.has_value() && BlockToOrder[*It].Rank > *EndRank)
      break;

    if (Reachable.count(*It) == 0) {
      continue;
    }

    if (!Op(*It)) {
      EndRank = BlockToOrder[*It].Rank;
    }
  }
}

bool sortBlocks(Function &F) {
  if (F.size() == 0)
    return false;

  bool Modified = false;
  std::vector<BasicBlock *> Order;
  Order.reserve(F.size());

  ReversePostOrderTraversal<Function *> RPOT(&F);
  for (BasicBlock *BB : RPOT)
    Order.push_back(BB);

  assert(&*F.begin() == Order[0]);
  BasicBlock *LastBlock = &*F.begin();
  for (BasicBlock *BB : Order) {
    if (BB != LastBlock && &*LastBlock->getNextNode() != BB) {
      Modified = true;
      BB->moveAfter(LastBlock);
    }
    LastBlock = BB;
  }

  return Modified;
}

MachineInstr *getVRegDef(MachineRegisterInfo &MRI, Register Reg) {
  MachineInstr *MaybeDef = MRI.getVRegDef(Reg);
  if (MaybeDef && MaybeDef->getOpcode() == SPIRV::ASSIGN_TYPE)
    MaybeDef = MRI.getVRegDef(MaybeDef->getOperand(1).getReg());
  return MaybeDef;
}

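// Pick a name not yet taken in the module by appending a numeric suffix to
// |Name|. Illustrative example: with Name == "foo" and a function "foo0"
// already present, Name becomes "foo1".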
bool getVacantFunctionName(Module &M, std::string &Name) {
  // This is a bit of paranoia, but we don't want even a chance that the loop
  // runs for too long.
  constexpr unsigned MaxIters = 1024;
  for (unsigned I = 0; I < MaxIters; ++I) {
    std::string OrdName = Name + Twine(I).str();
    if (!M.getFunction(OrdName)) {
      Name = OrdName;
      return true;
    }
  }
  return false;
}

// Assign a SPIR-V type to the register. If the register has no valid assigned
// class, set the register's LLT type and class according to the SPIR-V type.
void setRegClassType(Register Reg, SPIRVType *SpvType, SPIRVGlobalRegistry *GR,
                     MachineRegisterInfo *MRI, const MachineFunction &MF,
                     bool Force) {
  GR->assignSPIRVTypeToVReg(SpvType, Reg, MF);
  if (!MRI->getRegClassOrNull(Reg) || Force) {
    MRI->setRegClass(Reg, GR->getRegClass(SpvType));
    MRI->setType(Reg, GR->getRegType(SpvType));
  }
}

// Create a SPIR-V type and assign it to the register. If the register has no
// valid assigned class, set the register's LLT type and class according to
// the SPIR-V type.
void setRegClassType(Register Reg, const Type *Ty, SPIRVGlobalRegistry *GR,
                     MachineIRBuilder &MIRBuilder, bool Force) {
  setRegClassType(Reg, GR->getOrCreateSPIRVType(Ty, MIRBuilder), GR,
                  MIRBuilder.getMRI(), MIRBuilder.getMF(), Force);
}

// Create a virtual register and assign a SPIR-V type to it. Set the
// register's LLT type and class according to the SPIR-V type.
Register createVirtualRegister(SPIRVType *SpvType, SPIRVGlobalRegistry *GR,
                               MachineRegisterInfo *MRI,
                               const MachineFunction &MF) {
  Register Reg = MRI->createVirtualRegister(GR->getRegClass(SpvType));
  MRI->setType(Reg, GR->getRegType(SpvType));
  GR->assignSPIRVTypeToVReg(SpvType, Reg, MF);
  return Reg;
}

// Create a virtual register and assign a SPIR-V type to it. Set the
// register's LLT type and class according to the SPIR-V type.
Register createVirtualRegister(SPIRVType *SpvType, SPIRVGlobalRegistry *GR,
                               MachineIRBuilder &MIRBuilder) {
  return createVirtualRegister(SpvType, GR, MIRBuilder.getMRI(),
                               MIRBuilder.getMF());
}

// Create a SPIR-V type and a virtual register, and assign the type to the
// register. Set the register's LLT type and class according to the SPIR-V
// type.
Register createVirtualRegister(const Type *Ty, SPIRVGlobalRegistry *GR,
                               MachineIRBuilder &MIRBuilder) {
  return createVirtualRegister(GR->getOrCreateSPIRVType(Ty, MIRBuilder), GR,
                               MIRBuilder);
}

// Return true if there is an opaque pointer type nested in the argument.
bool isNestedPointer(const Type *Ty) {
  if (Ty->isPtrOrPtrVectorTy())
    return true;
  if (const FunctionType *RefTy = dyn_cast<FunctionType>(Ty)) {
    if (isNestedPointer(RefTy->getReturnType()))
      return true;
    for (const Type *ArgTy : RefTy->params())
      if (isNestedPointer(ArgTy))
        return true;
    return false;
  }
  if (const ArrayType *RefTy = dyn_cast<ArrayType>(Ty))
    return isNestedPointer(RefTy->getElementType());
  return false;
}

bool isSpvIntrinsic(const Value *Arg) {
  if (const auto *II = dyn_cast<IntrinsicInst>(Arg))
    if (Function *F = II->getCalledFunction())
      if (F->getName().starts_with("llvm.spv."))
        return true;
  return false;
}

} // namespace llvm