xref: /llvm-project/llvm/lib/Target/SPIRV/SPIRVUtils.cpp (revision f8b4182f076f8fe55f9d5f617b5a25008a77b22f)
1 //===--- SPIRVUtils.cpp ---- SPIR-V Utility Functions -----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains miscellaneous utility functions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "SPIRVUtils.h"
14 #include "MCTargetDesc/SPIRVBaseInfo.h"
15 #include "SPIRV.h"
16 #include "SPIRVGlobalRegistry.h"
17 #include "SPIRVInstrInfo.h"
18 #include "SPIRVSubtarget.h"
19 #include "llvm/ADT/StringRef.h"
20 #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
21 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
22 #include "llvm/CodeGen/MachineInstr.h"
23 #include "llvm/CodeGen/MachineInstrBuilder.h"
24 #include "llvm/Demangle/Demangle.h"
25 #include "llvm/IR/IntrinsicInst.h"
26 #include "llvm/IR/IntrinsicsSPIRV.h"
27 #include <queue>
28 #include <vector>
29 
30 namespace llvm {
31 
32 // The following functions are used to add string literals as a series of
33 // 32-bit integer operands in the correct format, and to unpack them if necessary
34 // when making string comparisons in compiler passes.
35 // SPIR-V requires null-terminated UTF-8 strings padded to 32-bit alignment.
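// For example (illustrative): convertCharsToWord("abc", 0) packs 'a', 'b',
// 'c' and one null padding byte into the word 0x00636261 (char 0 in the low byte).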
36 static uint32_t convertCharsToWord(const StringRef &Str, unsigned i) {
37   uint32_t Word = 0u; // Build up this 32-bit word from 4 8-bit chars.
38   for (unsigned WordIndex = 0; WordIndex < 4; ++WordIndex) {
39     unsigned StrIndex = i + WordIndex;
40     uint8_t CharToAdd = 0;       // Initialize char as padding/null.
41     if (StrIndex < Str.size()) { // If it's within the string, get a real char.
42       CharToAdd = Str[StrIndex];
43     }
44     Word |= (CharToAdd << (WordIndex * 8));
45   }
46   return Word;
47 }
48 
49 // Get length including padding and null terminator.
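// For example, getPaddedLen("abc") is 4 (3 chars plus the null terminator) and
// getPaddedLen("abcd") is 8 (the null terminator spills into a new word).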
50 static size_t getPaddedLen(const StringRef &Str) {
51   return (Str.size() + 4) & ~3;
52 }
53 
54 void addStringImm(const StringRef &Str, MCInst &Inst) {
55   const size_t PaddedLen = getPaddedLen(Str);
56   for (unsigned i = 0; i < PaddedLen; i += 4) {
57     // Add an operand for the 32-bits of chars or padding.
58     Inst.addOperand(MCOperand::createImm(convertCharsToWord(Str, i)));
59   }
60 }
61 
62 void addStringImm(const StringRef &Str, MachineInstrBuilder &MIB) {
63   const size_t PaddedLen = getPaddedLen(Str);
64   for (unsigned i = 0; i < PaddedLen; i += 4) {
65     // Add an operand for the 32-bits of chars or padding.
66     MIB.addImm(convertCharsToWord(Str, i));
67   }
68 }
69 
70 void addStringImm(const StringRef &Str, IRBuilder<> &B,
71                   std::vector<Value *> &Args) {
72   const size_t PaddedLen = getPaddedLen(Str);
73   for (unsigned i = 0; i < PaddedLen; i += 4) {
74     // Add a vector element for the 32-bits of chars or padding.
75     Args.push_back(B.getInt32(convertCharsToWord(Str, i)));
76   }
77 }
78 
79 std::string getStringImm(const MachineInstr &MI, unsigned StartIndex) {
80   return getSPIRVStringOperand(MI, StartIndex);
81 }
82 
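// Add an integer immediate as 32-bit literal operand(s): one word for widths
// up to 32 bits, two words (low word first, then high word) for widths up to
// 64 bits. For example, 0x0000000100000002 is emitted as 0x00000002 followed
// by 0x00000001.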
83 void addNumImm(const APInt &Imm, MachineInstrBuilder &MIB) {
84   const auto Bitwidth = Imm.getBitWidth();
85   if (Bitwidth == 1)
86     return; // Already handled
87   else if (Bitwidth <= 32) {
88     MIB.addImm(Imm.getZExtValue());
89     // The asm printer needs this info to print 16-bit floats correctly
90     if (Bitwidth == 16)
91       MIB.getInstr()->setAsmPrinterFlag(SPIRV::ASM_PRINTER_WIDTH16);
92     return;
93   } else if (Bitwidth <= 64) {
94     uint64_t FullImm = Imm.getZExtValue();
95     uint32_t LowBits = FullImm & 0xffffffff;
96     uint32_t HighBits = (FullImm >> 32) & 0xffffffff;
97     MIB.addImm(LowBits).addImm(HighBits);
98     return;
99   }
100   report_fatal_error("Unsupported constant bitwidth");
101 }
102 
103 void buildOpName(Register Target, const StringRef &Name,
104                  MachineIRBuilder &MIRBuilder) {
105   if (!Name.empty()) {
106     auto MIB = MIRBuilder.buildInstr(SPIRV::OpName).addUse(Target);
107     addStringImm(Name, MIB);
108   }
109 }
110 
111 static void finishBuildOpDecorate(MachineInstrBuilder &MIB,
112                                   const std::vector<uint32_t> &DecArgs,
113                                   StringRef StrImm) {
114   if (!StrImm.empty())
115     addStringImm(StrImm, MIB);
116   for (const auto &DecArg : DecArgs)
117     MIB.addImm(DecArg);
118 }
119 
120 void buildOpDecorate(Register Reg, MachineIRBuilder &MIRBuilder,
121                      SPIRV::Decoration::Decoration Dec,
122                      const std::vector<uint32_t> &DecArgs, StringRef StrImm) {
123   auto MIB = MIRBuilder.buildInstr(SPIRV::OpDecorate)
124                  .addUse(Reg)
125                  .addImm(static_cast<uint32_t>(Dec));
126   finishBuildOpDecorate(MIB, DecArgs, StrImm);
127 }
128 
129 void buildOpDecorate(Register Reg, MachineInstr &I, const SPIRVInstrInfo &TII,
130                      SPIRV::Decoration::Decoration Dec,
131                      const std::vector<uint32_t> &DecArgs, StringRef StrImm) {
132   MachineBasicBlock &MBB = *I.getParent();
133   auto MIB = BuildMI(MBB, I, I.getDebugLoc(), TII.get(SPIRV::OpDecorate))
134                  .addUse(Reg)
135                  .addImm(static_cast<uint32_t>(Dec));
136   finishBuildOpDecorate(MIB, DecArgs, StrImm);
137 }
138 
139 void buildOpSpirvDecorations(Register Reg, MachineIRBuilder &MIRBuilder,
140                              const MDNode *GVarMD) {
141   for (unsigned I = 0, E = GVarMD->getNumOperands(); I != E; ++I) {
142     auto *OpMD = dyn_cast<MDNode>(GVarMD->getOperand(I));
143     if (!OpMD)
144       report_fatal_error("Invalid decoration");
145     if (OpMD->getNumOperands() == 0)
146       report_fatal_error("Expect operand(s) of the decoration");
147     ConstantInt *DecorationId =
148         mdconst::dyn_extract<ConstantInt>(OpMD->getOperand(0));
149     if (!DecorationId)
150       report_fatal_error("Expect SPIR-V <Decoration> operand to be the first "
151                          "element of the decoration");
152     auto MIB = MIRBuilder.buildInstr(SPIRV::OpDecorate)
153                    .addUse(Reg)
154                    .addImm(static_cast<uint32_t>(DecorationId->getZExtValue()));
155     for (unsigned OpI = 1, OpE = OpMD->getNumOperands(); OpI != OpE; ++OpI) {
156       if (ConstantInt *OpV =
157               mdconst::dyn_extract<ConstantInt>(OpMD->getOperand(OpI)))
158         MIB.addImm(static_cast<uint32_t>(OpV->getZExtValue()));
159       else if (MDString *OpV = dyn_cast<MDString>(OpMD->getOperand(OpI)))
160         addStringImm(OpV->getString(), MIB);
161       else
162         report_fatal_error("Unexpected operand of the decoration");
163     }
164   }
165 }
166 
167 MachineBasicBlock::iterator getOpVariableMBBIt(MachineInstr &I) {
168   MachineFunction *MF = I.getParent()->getParent();
169   MachineBasicBlock *MBB = &MF->front();
170   MachineBasicBlock::iterator It = MBB->SkipPHIsAndLabels(MBB->begin()),
171                               E = MBB->end();
172   bool IsHeader = false;
173   unsigned Opcode;
174   for (; It != E && It != I; ++It) {
175     Opcode = It->getOpcode();
176     if (Opcode == SPIRV::OpFunction || Opcode == SPIRV::OpFunctionParameter) {
177       IsHeader = true;
178     } else if (IsHeader &&
179                !(Opcode == SPIRV::ASSIGN_TYPE || Opcode == SPIRV::OpLabel)) {
180       ++It;
181       break;
182     }
183   }
184   return It;
185 }
186 
187 SPIRV::StorageClass::StorageClass
188 addressSpaceToStorageClass(unsigned AddrSpace, const SPIRVSubtarget &STI) {
189   switch (AddrSpace) {
190   case 0:
191     return SPIRV::StorageClass::Function;
192   case 1:
193     return SPIRV::StorageClass::CrossWorkgroup;
194   case 2:
195     return SPIRV::StorageClass::UniformConstant;
196   case 3:
197     return SPIRV::StorageClass::Workgroup;
198   case 4:
199     return SPIRV::StorageClass::Generic;
200   case 5:
201     return STI.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)
202                ? SPIRV::StorageClass::DeviceOnlyINTEL
203                : SPIRV::StorageClass::CrossWorkgroup;
204   case 6:
205     return STI.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)
206                ? SPIRV::StorageClass::HostOnlyINTEL
207                : SPIRV::StorageClass::CrossWorkgroup;
208   case 7:
209     return SPIRV::StorageClass::Input;
210   case 9:
211     return SPIRV::StorageClass::CodeSectionINTEL;
212   default:
213     report_fatal_error("Unknown address space");
214   }
215 }
216 
217 SPIRV::MemorySemantics::MemorySemantics
218 getMemSemanticsForStorageClass(SPIRV::StorageClass::StorageClass SC) {
219   switch (SC) {
220   case SPIRV::StorageClass::StorageBuffer:
221   case SPIRV::StorageClass::Uniform:
222     return SPIRV::MemorySemantics::UniformMemory;
223   case SPIRV::StorageClass::Workgroup:
224     return SPIRV::MemorySemantics::WorkgroupMemory;
225   case SPIRV::StorageClass::CrossWorkgroup:
226     return SPIRV::MemorySemantics::CrossWorkgroupMemory;
227   case SPIRV::StorageClass::AtomicCounter:
228     return SPIRV::MemorySemantics::AtomicCounterMemory;
229   case SPIRV::StorageClass::Image:
230     return SPIRV::MemorySemantics::ImageMemory;
231   default:
232     return SPIRV::MemorySemantics::None;
233   }
234 }
235 
236 SPIRV::MemorySemantics::MemorySemantics getMemSemantics(AtomicOrdering Ord) {
237   switch (Ord) {
238   case AtomicOrdering::Acquire:
239     return SPIRV::MemorySemantics::Acquire;
240   case AtomicOrdering::Release:
241     return SPIRV::MemorySemantics::Release;
242   case AtomicOrdering::AcquireRelease:
243     return SPIRV::MemorySemantics::AcquireRelease;
244   case AtomicOrdering::SequentiallyConsistent:
245     return SPIRV::MemorySemantics::SequentiallyConsistent;
246   case AtomicOrdering::Unordered:
247   case AtomicOrdering::Monotonic:
248   case AtomicOrdering::NotAtomic:
249     return SPIRV::MemorySemantics::None;
250   }
251   llvm_unreachable(nullptr);
252 }
253 
254 SPIRV::Scope::Scope getMemScope(LLVMContext &Ctx, SyncScope::ID Id) {
255   // Named by
256   // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_scope_id.
257   // We don't need aliases for Invocation and CrossDevice, as we already have
258   // them covered by "singlethread" and "" strings respectively (see
259   // implementation of LLVMContext::LLVMContext()).
260   static const llvm::SyncScope::ID SubGroup =
261       Ctx.getOrInsertSyncScopeID("subgroup");
262   static const llvm::SyncScope::ID WorkGroup =
263       Ctx.getOrInsertSyncScopeID("workgroup");
264   static const llvm::SyncScope::ID Device =
265       Ctx.getOrInsertSyncScopeID("device");
266 
267   if (Id == llvm::SyncScope::SingleThread)
268     return SPIRV::Scope::Invocation;
269   else if (Id == llvm::SyncScope::System)
270     return SPIRV::Scope::CrossDevice;
271   else if (Id == SubGroup)
272     return SPIRV::Scope::Subgroup;
273   else if (Id == WorkGroup)
274     return SPIRV::Scope::Workgroup;
275   else if (Id == Device)
276     return SPIRV::Scope::Device;
277   return SPIRV::Scope::CrossDevice;
278 }
279 
280 MachineInstr *getDefInstrMaybeConstant(Register &ConstReg,
281                                        const MachineRegisterInfo *MRI) {
282   MachineInstr *MI = MRI->getVRegDef(ConstReg);
283   MachineInstr *ConstInstr =
284       MI->getOpcode() == SPIRV::G_TRUNC || MI->getOpcode() == SPIRV::G_ZEXT
285           ? MRI->getVRegDef(MI->getOperand(1).getReg())
286           : MI;
287   if (auto *GI = dyn_cast<GIntrinsic>(ConstInstr)) {
288     if (GI->is(Intrinsic::spv_track_constant)) {
289       ConstReg = ConstInstr->getOperand(2).getReg();
290       return MRI->getVRegDef(ConstReg);
291     }
292   } else if (ConstInstr->getOpcode() == SPIRV::ASSIGN_TYPE) {
293     ConstReg = ConstInstr->getOperand(1).getReg();
294     return MRI->getVRegDef(ConstReg);
295   }
296   return MRI->getVRegDef(ConstReg);
297 }
298 
299 uint64_t getIConstVal(Register ConstReg, const MachineRegisterInfo *MRI) {
300   const MachineInstr *MI = getDefInstrMaybeConstant(ConstReg, MRI);
301   assert(MI && MI->getOpcode() == TargetOpcode::G_CONSTANT);
302   return MI->getOperand(1).getCImm()->getValue().getZExtValue();
303 }
304 
305 bool isSpvIntrinsic(const MachineInstr &MI, Intrinsic::ID IntrinsicID) {
306   if (const auto *GI = dyn_cast<GIntrinsic>(&MI))
307     return GI->is(IntrinsicID);
308   return false;
309 }
310 
311 Type *getMDOperandAsType(const MDNode *N, unsigned I) {
312   Type *ElementTy = cast<ValueAsMetadata>(N->getOperand(I))->getType();
313   return toTypedPointer(ElementTy);
314 }
315 
316 // The set of names is borrowed from the SPIR-V translator.
317 // TODO: may be implemented in SPIRVBuiltins.td.
318 static bool isPipeOrAddressSpaceCastBI(const StringRef MangledName) {
319   return MangledName == "write_pipe_2" || MangledName == "read_pipe_2" ||
320          MangledName == "write_pipe_2_bl" || MangledName == "read_pipe_2_bl" ||
321          MangledName == "write_pipe_4" || MangledName == "read_pipe_4" ||
322          MangledName == "reserve_write_pipe" ||
323          MangledName == "reserve_read_pipe" ||
324          MangledName == "commit_write_pipe" ||
325          MangledName == "commit_read_pipe" ||
326          MangledName == "work_group_reserve_write_pipe" ||
327          MangledName == "work_group_reserve_read_pipe" ||
328          MangledName == "work_group_commit_write_pipe" ||
329          MangledName == "work_group_commit_read_pipe" ||
330          MangledName == "get_pipe_num_packets_ro" ||
331          MangledName == "get_pipe_max_packets_ro" ||
332          MangledName == "get_pipe_num_packets_wo" ||
333          MangledName == "get_pipe_max_packets_wo" ||
334          MangledName == "sub_group_reserve_write_pipe" ||
335          MangledName == "sub_group_reserve_read_pipe" ||
336          MangledName == "sub_group_commit_write_pipe" ||
337          MangledName == "sub_group_commit_read_pipe" ||
338          MangledName == "to_global" || MangledName == "to_local" ||
339          MangledName == "to_private";
340 }
341 
342 static bool isEnqueueKernelBI(const StringRef MangledName) {
343   return MangledName == "__enqueue_kernel_basic" ||
344          MangledName == "__enqueue_kernel_basic_events" ||
345          MangledName == "__enqueue_kernel_varargs" ||
346          MangledName == "__enqueue_kernel_events_varargs";
347 }
348 
349 static bool isKernelQueryBI(const StringRef MangledName) {
350   return MangledName == "__get_kernel_work_group_size_impl" ||
351          MangledName == "__get_kernel_sub_group_count_for_ndrange_impl" ||
352          MangledName == "__get_kernel_max_sub_group_size_for_ndrange_impl" ||
353          MangledName == "__get_kernel_preferred_work_group_size_multiple_impl";
354 }
355 
356 static bool isNonMangledOCLBuiltin(StringRef Name) {
357   if (!Name.starts_with("__"))
358     return false;
359 
360   return isEnqueueKernelBI(Name) || isKernelQueryBI(Name) ||
361          isPipeOrAddressSpaceCastBI(Name.drop_front(2)) ||
362          Name == "__translate_sampler_initializer";
363 }
364 
365 std::string getOclOrSpirvBuiltinDemangledName(StringRef Name) {
366   bool IsNonMangledOCL = isNonMangledOCLBuiltin(Name);
367   bool IsNonMangledSPIRV = Name.starts_with("__spirv_");
368   bool IsNonMangledHLSL = Name.starts_with("__hlsl_");
369   bool IsMangled = Name.starts_with("_Z");
370 
371   // If the name is not mangled, return it as is.
372   if (IsNonMangledOCL || IsNonMangledSPIRV || IsNonMangledHLSL || !IsMangled)
373     return Name.str();
374 
375   // Try to use the itanium demangler.
376   if (char *DemangledName = itaniumDemangle(Name.data())) {
377     std::string Result = DemangledName;
378     free(DemangledName);
379     return Result;
380   }
381 
382   // Assume C++ for now; an explicit check of the source language may be needed.
383   // OpenCL C++ built-ins are declared in cl namespace.
384   // TODO: consider using the 'St' abbreviation for cl namespace mangling.
385   // Similar to ::std:: in C++.
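  // Illustrative example of the fallback parsing below (only reached when the
  // itanium demangler failed): for a hypothetical name "_Z12someBuiltinXi",
  // the digits "12" give the name length, so "someBuiltinX" is returned.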
386   size_t Start, Len = 0;
387   size_t DemangledNameLenStart = 2;
388   if (Name.starts_with("_ZN")) {
389     // Skip CV and ref qualifiers.
390     size_t NameSpaceStart = Name.find_first_not_of("rVKRO", 3);
391     // All built-ins are expected in the ::cl::__spirv:: namespace.
392     if (Name.substr(NameSpaceStart, 11) != "2cl7__spirv")
393       return std::string();
394     DemangledNameLenStart = NameSpaceStart + 11;
395   }
396   Start = Name.find_first_not_of("0123456789", DemangledNameLenStart);
397   Name.substr(DemangledNameLenStart, Start - DemangledNameLenStart)
398       .getAsInteger(10, Len);
399   return Name.substr(Start, Len).str();
400 }
401 
402 bool hasBuiltinTypePrefix(StringRef Name) {
403   if (Name.starts_with("opencl.") || Name.starts_with("ocl_") ||
404       Name.starts_with("spirv."))
405     return true;
406   return false;
407 }
408 
409 bool isSpecialOpaqueType(const Type *Ty) {
410   if (const TargetExtType *ExtTy = dyn_cast<TargetExtType>(Ty))
411     return isTypedPointerWrapper(ExtTy)
412                ? false
413                : hasBuiltinTypePrefix(ExtTy->getName());
414 
415   return false;
416 }
417 
418 bool isEntryPoint(const Function &F) {
419   // OpenCL handling: any function with the SPIR_KERNEL
420   // calling convention will be a potential entry point.
421   if (F.getCallingConv() == CallingConv::SPIR_KERNEL)
422     return true;
423 
424   // HLSL handling: special attribute are emitted from the
425   // front-end.
426   if (F.getFnAttribute("hlsl.shader").isValid())
427     return true;
428 
429   return false;
430 }
431 
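// Parse a builtin scalar type name and return the corresponding LLVM type,
// advancing TypeName past the consumed prefix. For example, "atomic_uint" and
// "unsigned int" both yield the 32-bit integer type, while an unrecognized
// name yields nullptr.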
432 Type *parseBasicTypeName(StringRef &TypeName, LLVMContext &Ctx) {
433   TypeName.consume_front("atomic_");
434   if (TypeName.consume_front("void"))
435     return Type::getVoidTy(Ctx);
436   else if (TypeName.consume_front("bool"))
437     return Type::getIntNTy(Ctx, 1);
438   else if (TypeName.consume_front("char") ||
439            TypeName.consume_front("unsigned char") ||
440            TypeName.consume_front("uchar"))
441     return Type::getInt8Ty(Ctx);
442   else if (TypeName.consume_front("short") ||
443            TypeName.consume_front("unsigned short") ||
444            TypeName.consume_front("ushort"))
445     return Type::getInt16Ty(Ctx);
446   else if (TypeName.consume_front("int") ||
447            TypeName.consume_front("unsigned int") ||
448            TypeName.consume_front("uint"))
449     return Type::getInt32Ty(Ctx);
450   else if (TypeName.consume_front("long") ||
451            TypeName.consume_front("unsigned long") ||
452            TypeName.consume_front("ulong"))
453     return Type::getInt64Ty(Ctx);
454   else if (TypeName.consume_front("half"))
455     return Type::getHalfTy(Ctx);
456   else if (TypeName.consume_front("float"))
457     return Type::getFloatTy(Ctx);
458   else if (TypeName.consume_front("double"))
459     return Type::getDoubleTy(Ctx);
460 
461   // Unable to recognize the SPIR-V type name
462   return nullptr;
463 }
464 
465 std::unordered_set<BasicBlock *>
466 PartialOrderingVisitor::getReachableFrom(BasicBlock *Start) {
467   std::queue<BasicBlock *> ToVisit;
468   ToVisit.push(Start);
469 
470   std::unordered_set<BasicBlock *> Output;
471   while (ToVisit.size() != 0) {
472     BasicBlock *BB = ToVisit.front();
473     ToVisit.pop();
474 
475     if (Output.count(BB) != 0)
476       continue;
477     Output.insert(BB);
478 
479     for (BasicBlock *Successor : successors(BB)) {
480       if (DT.dominates(Successor, BB))
481         continue;
482       ToVisit.push(Successor);
483     }
484   }
485 
486   return Output;
487 }
488 
489 bool PartialOrderingVisitor::CanBeVisited(BasicBlock *BB) const {
490   for (BasicBlock *P : predecessors(BB)) {
491     // Ignore back-edges.
492     if (DT.dominates(BB, P))
493       continue;
494 
495     // One of the predecessors hasn't been visited. Not ready yet.
496     if (BlockToOrder.count(P) == 0)
497       return false;
498 
499     // If the block is a loop exit, the loop must be finished before
500     // we can continue.
501     Loop *L = LI.getLoopFor(P);
502     if (L == nullptr || L->contains(BB))
503       continue;
504 
505     // SPIR-V requires a single back-edge, and the backend's first step
506     // transforms loops into the simplified form. If we have more than
507     // one back-edge, something is wrong.
508     assert(L->getNumBackEdges() <= 1);
509 
510     // If the loop has no latch, the loop's rank won't matter, so we can
511     // proceed.
512     BasicBlock *Latch = L->getLoopLatch();
513     assert(Latch);
514     if (Latch == nullptr)
515       continue;
516 
517     // The latch is not ready yet, let's wait.
518     if (BlockToOrder.count(Latch) == 0)
519       return false;
520   }
521 
522   return true;
523 }
524 
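// Compute the rank of BB: one more than the highest rank among its
// non-back-edge predecessors (substituting the latch's rank for predecessors
// that belong to a loop not containing BB). Illustrative example, assuming no
// loops: for a diamond A->B, A->C, B->D, C->D the ranks are A=0, B=C=1, D=2.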
525 size_t PartialOrderingVisitor::GetNodeRank(BasicBlock *BB) const {
526   auto It = BlockToOrder.find(BB);
527   if (It != BlockToOrder.end())
528     return It->second.Rank;
529 
530   size_t result = 0;
531   for (BasicBlock *P : predecessors(BB)) {
532     // Ignore back-edges.
533     if (DT.dominates(BB, P))
534       continue;
535 
536     auto Iterator = BlockToOrder.end();
537     Loop *L = LI.getLoopFor(P);
538     BasicBlock *Latch = L ? L->getLoopLatch() : nullptr;
539 
540     // If the predecessor is either outside a loop, or part of
541     // the same loop, simply take its rank + 1.
542     if (L == nullptr || L->contains(BB) || Latch == nullptr) {
543       Iterator = BlockToOrder.find(P);
544     } else {
545       // Otherwise, take the loop's rank (highest rank in the loop) as base.
546       // Since loops have a single latch, highest rank is easy to find.
547       // If the loop has no latch, then it doesn't matter.
548       Iterator = BlockToOrder.find(Latch);
549     }
550 
551     assert(Iterator != BlockToOrder.end());
552     result = std::max(result, Iterator->second.Rank + 1);
553   }
554 
555   return result;
556 }
557 
558 size_t PartialOrderingVisitor::visit(BasicBlock *BB, size_t Unused) {
559   ToVisit.push(BB);
560   Queued.insert(BB);
561 
562   size_t QueueIndex = 0;
563   while (ToVisit.size() != 0) {
564     BasicBlock *BB = ToVisit.front();
565     ToVisit.pop();
566 
567     if (!CanBeVisited(BB)) {
568       ToVisit.push(BB);
569       assert(QueueIndex < ToVisit.size() &&
570              "No valid candidate in the queue. Is the graph reducible?");
571       QueueIndex++;
572       continue;
573     }
574 
575     QueueIndex = 0;
576     size_t Rank = GetNodeRank(BB);
577     OrderInfo Info = {Rank, BlockToOrder.size()};
578     BlockToOrder.emplace(BB, Info);
579 
580     for (BasicBlock *S : successors(BB)) {
581       if (Queued.count(S) != 0)
582         continue;
583       ToVisit.push(S);
584       Queued.insert(S);
585     }
586   }
587 
588   return 0;
589 }
590 
591 PartialOrderingVisitor::PartialOrderingVisitor(Function &F) {
592   DT.recalculate(F);
593   LI = LoopInfo(DT);
594 
595   visit(&*F.begin(), 0);
596 
597   Order.reserve(F.size());
598   for (auto &[BB, Info] : BlockToOrder)
599     Order.emplace_back(BB);
600 
601   std::sort(Order.begin(), Order.end(), [&](const auto &LHS, const auto &RHS) {
602     return compare(LHS, RHS);
603   });
604 }
605 
606 bool PartialOrderingVisitor::compare(const BasicBlock *LHS,
607                                      const BasicBlock *RHS) const {
608   const OrderInfo &InfoLHS = BlockToOrder.at(const_cast<BasicBlock *>(LHS));
609   const OrderInfo &InfoRHS = BlockToOrder.at(const_cast<BasicBlock *>(RHS));
610   if (InfoLHS.Rank != InfoRHS.Rank)
611     return InfoLHS.Rank < InfoRHS.Rank;
612   return InfoLHS.TraversalIndex < InfoRHS.TraversalIndex;
613 }
614 
615 void PartialOrderingVisitor::partialOrderVisit(
616     BasicBlock &Start, std::function<bool(BasicBlock *)> Op) {
617   std::unordered_set<BasicBlock *> Reachable = getReachableFrom(&Start);
618   assert(BlockToOrder.count(&Start) != 0);
619 
620   // Skip blocks whose rank is lower than |Start|'s rank.
621   auto It = Order.begin();
622   while (It != Order.end() && *It != &Start)
623     ++It;
624 
625   // Reaching the end would be unexpected: at worst, |Start| is the last block,
626   // so It should point to the last block, never past-the-end.
627   assert(It != Order.end());
628 
629   // By default, there is no rank limit (EndRank is set once Op returns false).
630   std::optional<size_t> EndRank = std::nullopt;
631   for (; It != Order.end(); ++It) {
632     if (EndRank.has_value() && BlockToOrder[*It].Rank > *EndRank)
633       break;
634 
635     if (Reachable.count(*It) == 0) {
636       continue;
637     }
638 
639     if (!Op(*It)) {
640       EndRank = BlockToOrder[*It].Rank;
641     }
642   }
643 }
644 
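// Reorder the basic blocks of F to match a reverse post-order traversal,
// moving blocks whose current layout position differs. Returns true if any
// block was moved.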
645 bool sortBlocks(Function &F) {
646   if (F.size() == 0)
647     return false;
648 
649   bool Modified = false;
650   std::vector<BasicBlock *> Order;
651   Order.reserve(F.size());
652 
653   ReversePostOrderTraversal<Function *> RPOT(&F);
654   for (BasicBlock *BB : RPOT)
655     Order.push_back(BB);
656 
657   assert(&*F.begin() == Order[0]);
658   BasicBlock *LastBlock = &*F.begin();
659   for (BasicBlock *BB : Order) {
660     if (BB != LastBlock && &*LastBlock->getNextNode() != BB) {
661       Modified = true;
662       BB->moveAfter(LastBlock);
663     }
664     LastBlock = BB;
665   }
666 
667   return Modified;
668 }
669 
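// Return the instruction defining Reg, looking through an ASSIGN_TYPE wrapper
// to the original definition when one is present.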
670 MachineInstr *getVRegDef(MachineRegisterInfo &MRI, Register Reg) {
671   MachineInstr *MaybeDef = MRI.getVRegDef(Reg);
672   if (MaybeDef && MaybeDef->getOpcode() == SPIRV::ASSIGN_TYPE)
673     MaybeDef = MRI.getVRegDef(MaybeDef->getOperand(1).getReg());
674   return MaybeDef;
675 }
676 
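// Append a decimal suffix to Name until the result does not clash with a
// function already present in M, then return true. For example, if "foo0" and
// "foo1" already exist, Name "foo" becomes "foo2".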
677 bool getVacantFunctionName(Module &M, std::string &Name) {
678   // It's a bit of paranoia, but we don't want to take even a small chance
679   // that the loop will run for too long.
680   constexpr unsigned MaxIters = 1024;
681   for (unsigned I = 0; I < MaxIters; ++I) {
682     std::string OrdName = Name + Twine(I).str();
683     if (!M.getFunction(OrdName)) {
684       Name = OrdName;
685       return true;
686     }
687   }
688   return false;
689 }
690 
691 // Assign SPIR-V type to the register. If the register has no valid assigned
692 // class, set register LLT type and class according to the SPIR-V type.
693 void setRegClassType(Register Reg, SPIRVType *SpvType, SPIRVGlobalRegistry *GR,
694                      MachineRegisterInfo *MRI, const MachineFunction &MF,
695                      bool Force) {
696   GR->assignSPIRVTypeToVReg(SpvType, Reg, MF);
697   if (!MRI->getRegClassOrNull(Reg) || Force) {
698     MRI->setRegClass(Reg, GR->getRegClass(SpvType));
699     MRI->setType(Reg, GR->getRegType(SpvType));
700   }
701 }
702 
703 // Create a SPIR-V type and assign it to the register. If the register has
704 // no valid assigned class, set register LLT type and class according to the
705 // SPIR-V type.
706 void setRegClassType(Register Reg, const Type *Ty, SPIRVGlobalRegistry *GR,
707                      MachineIRBuilder &MIRBuilder, bool Force) {
708   setRegClassType(Reg, GR->getOrCreateSPIRVType(Ty, MIRBuilder), GR,
709                   MIRBuilder.getMRI(), MIRBuilder.getMF(), Force);
710 }
711 
712 // Create a virtual register and assign SPIR-V type to the register. Set
713 // register LLT type and class according to the SPIR-V type.
714 Register createVirtualRegister(SPIRVType *SpvType, SPIRVGlobalRegistry *GR,
715                                MachineRegisterInfo *MRI,
716                                const MachineFunction &MF) {
717   Register Reg = MRI->createVirtualRegister(GR->getRegClass(SpvType));
718   MRI->setType(Reg, GR->getRegType(SpvType));
719   GR->assignSPIRVTypeToVReg(SpvType, Reg, MF);
720   return Reg;
721 }
722 
723 // Create a virtual register and assign SPIR-V type to the register. Set
724 // register LLT type and class according to the SPIR-V type.
725 Register createVirtualRegister(SPIRVType *SpvType, SPIRVGlobalRegistry *GR,
726                                MachineIRBuilder &MIRBuilder) {
727   return createVirtualRegister(SpvType, GR, MIRBuilder.getMRI(),
728                                MIRBuilder.getMF());
729 }
730 
731 // Create a SPIR-V type and a virtual register, and assign the type to the
732 // register. Set register LLT type and class according to the SPIR-V type.
733 Register createVirtualRegister(const Type *Ty, SPIRVGlobalRegistry *GR,
734                                MachineIRBuilder &MIRBuilder) {
735   return createVirtualRegister(GR->getOrCreateSPIRVType(Ty, MIRBuilder), GR,
736                                MIRBuilder);
737 }
738 
739 // Return true if there is an opaque pointer type nested in the argument.
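// For example, this holds for ptr, <4 x ptr>, [8 x ptr] and a function type
// such as void(ptr), but not for i32 or [8 x float].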
740 bool isNestedPointer(const Type *Ty) {
741   if (Ty->isPtrOrPtrVectorTy())
742     return true;
743   if (const FunctionType *RefTy = dyn_cast<FunctionType>(Ty)) {
744     if (isNestedPointer(RefTy->getReturnType()))
745       return true;
746     for (const Type *ArgTy : RefTy->params())
747       if (isNestedPointer(ArgTy))
748         return true;
749     return false;
750   }
751   if (const ArrayType *RefTy = dyn_cast<ArrayType>(Ty))
752     return isNestedPointer(RefTy->getElementType());
753   return false;
754 }
755 
756 bool isSpvIntrinsic(const Value *Arg) {
757   if (const auto *II = dyn_cast<IntrinsicInst>(Arg))
758     if (Function *F = II->getCalledFunction())
759       if (F->getName().starts_with("llvm.spv."))
760         return true;
761   return false;
762 }
763 
764 } // namespace llvm
765