//===--- SPIRVUtils.cpp ---- SPIR-V Utility Functions -----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains miscellaneous utility functions.
//
//===----------------------------------------------------------------------===//

#include "SPIRVUtils.h"
#include "MCTargetDesc/SPIRVBaseInfo.h"
#include "SPIRV.h"
#include "SPIRVGlobalRegistry.h"
#include "SPIRVInstrInfo.h"
#include "SPIRVSubtarget.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/Demangle/Demangle.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsSPIRV.h"
#include <queue>
#include <vector>

namespace llvm {

// The following functions are used to add string literals as a series of
// 32-bit integer operands in the required format, and to unpack them when
// making string comparisons in compiler passes.
// SPIR-V requires null-terminated UTF-8 strings padded to 32-bit alignment.
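// For example, the string "abc" packs into the single word 0x00636261:
// 'a' (0x61) in the lowest byte, then 'b', 'c', and a null/padding byte.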
static uint32_t convertCharsToWord(const StringRef &Str, unsigned i) {
  uint32_t Word = 0u; // Build up this 32-bit word from 4 8-bit chars.
  for (unsigned WordIndex = 0; WordIndex < 4; ++WordIndex) {
    unsigned StrIndex = i + WordIndex;
    uint8_t CharToAdd = 0;       // Initialize char as padding/null.
    if (StrIndex < Str.size()) { // If it's within the string, get a real char.
      CharToAdd = Str[StrIndex];
    }
    Word |= (CharToAdd << (WordIndex * 8));
  }
  return Word;
}

// Get length including padding and null terminator.
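// For example, a 3-character string occupies 4 bytes (3 chars + null), and a
// 4-character string occupies 8 (4 chars + null + 3 bytes of padding).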
static size_t getPaddedLen(const StringRef &Str) {
  return (Str.size() + 4) & ~3;
}

void addStringImm(const StringRef &Str, MCInst &Inst) {
  const size_t PaddedLen = getPaddedLen(Str);
  for (unsigned i = 0; i < PaddedLen; i += 4) {
    // Add an operand for the 32-bits of chars or padding.
    Inst.addOperand(MCOperand::createImm(convertCharsToWord(Str, i)));
  }
}

void addStringImm(const StringRef &Str, MachineInstrBuilder &MIB) {
  const size_t PaddedLen = getPaddedLen(Str);
  for (unsigned i = 0; i < PaddedLen; i += 4) {
    // Add an operand for the 32-bits of chars or padding.
    MIB.addImm(convertCharsToWord(Str, i));
  }
}

void addStringImm(const StringRef &Str, IRBuilder<> &B,
                  std::vector<Value *> &Args) {
  const size_t PaddedLen = getPaddedLen(Str);
  for (unsigned i = 0; i < PaddedLen; i += 4) {
    // Add a vector element for the 32-bits of chars or padding.
    Args.push_back(B.getInt32(convertCharsToWord(Str, i)));
  }
}

std::string getStringImm(const MachineInstr &MI, unsigned StartIndex) {
  return getSPIRVStringOperand(MI, StartIndex);
}

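// Add an APInt as one or two 32-bit immediate operands: a value of up to 32
// bits becomes a single word, and a value of up to 64 bits is split low word
// first (e.g. 0x0000000100000002 is emitted as 0x00000002 then 0x00000001).
// 1-bit values are expected to have been handled by the caller already.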
void addNumImm(const APInt &Imm, MachineInstrBuilder &MIB) {
  const auto Bitwidth = Imm.getBitWidth();
  if (Bitwidth == 1)
    return; // Already handled
  else if (Bitwidth <= 32) {
    MIB.addImm(Imm.getZExtValue());
    // Asm Printer needs this info to print floating-point types correctly.
    if (Bitwidth == 16)
      MIB.getInstr()->setAsmPrinterFlag(SPIRV::ASM_PRINTER_WIDTH16);
    return;
  } else if (Bitwidth <= 64) {
    uint64_t FullImm = Imm.getZExtValue();
    uint32_t LowBits = FullImm & 0xffffffff;
    uint32_t HighBits = (FullImm >> 32) & 0xffffffff;
    MIB.addImm(LowBits).addImm(HighBits);
    return;
  }
  report_fatal_error("Unsupported constant bitwidth");
}

void buildOpName(Register Target, const StringRef &Name,
                 MachineIRBuilder &MIRBuilder) {
  if (!Name.empty()) {
    auto MIB = MIRBuilder.buildInstr(SPIRV::OpName).addUse(Target);
    addStringImm(Name, MIB);
  }
}

void buildOpName(Register Target, const StringRef &Name, MachineInstr &I,
                 const SPIRVInstrInfo &TII) {
  if (!Name.empty()) {
    auto MIB =
        BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpName))
            .addUse(Target);
    addStringImm(Name, MIB);
  }
}

static void finishBuildOpDecorate(MachineInstrBuilder &MIB,
                                  const std::vector<uint32_t> &DecArgs,
                                  StringRef StrImm) {
  if (!StrImm.empty())
    addStringImm(StrImm, MIB);
  for (const auto &DecArg : DecArgs)
    MIB.addImm(DecArg);
}

void buildOpDecorate(Register Reg, MachineIRBuilder &MIRBuilder,
                     SPIRV::Decoration::Decoration Dec,
                     const std::vector<uint32_t> &DecArgs, StringRef StrImm) {
  auto MIB = MIRBuilder.buildInstr(SPIRV::OpDecorate)
                 .addUse(Reg)
                 .addImm(static_cast<uint32_t>(Dec));
  finishBuildOpDecorate(MIB, DecArgs, StrImm);
}

void buildOpDecorate(Register Reg, MachineInstr &I, const SPIRVInstrInfo &TII,
                     SPIRV::Decoration::Decoration Dec,
                     const std::vector<uint32_t> &DecArgs, StringRef StrImm) {
  MachineBasicBlock &MBB = *I.getParent();
  auto MIB = BuildMI(MBB, I, I.getDebugLoc(), TII.get(SPIRV::OpDecorate))
                 .addUse(Reg)
                 .addImm(static_cast<uint32_t>(Dec));
  finishBuildOpDecorate(MIB, DecArgs, StrImm);
}

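// Emit OpDecorate instructions described by metadata: each operand of GVarMD
// is expected to be an MDNode whose first operand is the integer SPIR-V
// <Decoration> id, followed by optional integer or string literal arguments.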
void buildOpSpirvDecorations(Register Reg, MachineIRBuilder &MIRBuilder,
                             const MDNode *GVarMD) {
  for (unsigned I = 0, E = GVarMD->getNumOperands(); I != E; ++I) {
    auto *OpMD = dyn_cast<MDNode>(GVarMD->getOperand(I));
    if (!OpMD)
      report_fatal_error("Invalid decoration");
    if (OpMD->getNumOperands() == 0)
      report_fatal_error("Expect operand(s) of the decoration");
    ConstantInt *DecorationId =
        mdconst::dyn_extract<ConstantInt>(OpMD->getOperand(0));
    if (!DecorationId)
      report_fatal_error("Expect SPIR-V <Decoration> operand to be the first "
                         "element of the decoration");
    auto MIB = MIRBuilder.buildInstr(SPIRV::OpDecorate)
                   .addUse(Reg)
                   .addImm(static_cast<uint32_t>(DecorationId->getZExtValue()));
    for (unsigned OpI = 1, OpE = OpMD->getNumOperands(); OpI != OpE; ++OpI) {
      if (ConstantInt *OpV =
              mdconst::dyn_extract<ConstantInt>(OpMD->getOperand(OpI)))
        MIB.addImm(static_cast<uint32_t>(OpV->getZExtValue()));
      else if (MDString *OpV = dyn_cast<MDString>(OpMD->getOperand(OpI)))
        addStringImm(OpV->getString(), MIB);
      else
        report_fatal_error("Unexpected operand of the decoration");
    }
  }
}

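// Return an iterator into the entry block positioned past the function
// header (OpFunction, OpFunctionParameter, ASSIGN_TYPE and OpLabel
// instructions), suitable for inserting OpVariable instructions.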
MachineBasicBlock::iterator getOpVariableMBBIt(MachineInstr &I) {
  MachineFunction *MF = I.getParent()->getParent();
  MachineBasicBlock *MBB = &MF->front();
  MachineBasicBlock::iterator It = MBB->SkipPHIsAndLabels(MBB->begin()),
                              E = MBB->end();
  bool IsHeader = false;
  unsigned Opcode;
  for (; It != E && It != I; ++It) {
    Opcode = It->getOpcode();
    if (Opcode == SPIRV::OpFunction || Opcode == SPIRV::OpFunctionParameter) {
      IsHeader = true;
    } else if (IsHeader &&
               !(Opcode == SPIRV::ASSIGN_TYPE || Opcode == SPIRV::OpLabel)) {
      ++It;
      break;
    }
  }
  return It;
}

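// Translate an LLVM IR address space number into the corresponding SPIR-V
// storage class, mapping address spaces 5 and 6 to the USM storage classes
// only when the SPV_INTEL_usm_storage_classes extension is available.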
SPIRV::StorageClass::StorageClass
addressSpaceToStorageClass(unsigned AddrSpace, const SPIRVSubtarget &STI) {
  switch (AddrSpace) {
  case 0:
    return SPIRV::StorageClass::Function;
  case 1:
    return SPIRV::StorageClass::CrossWorkgroup;
  case 2:
    return SPIRV::StorageClass::UniformConstant;
  case 3:
    return SPIRV::StorageClass::Workgroup;
  case 4:
    return SPIRV::StorageClass::Generic;
  case 5:
    return STI.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)
               ? SPIRV::StorageClass::DeviceOnlyINTEL
               : SPIRV::StorageClass::CrossWorkgroup;
  case 6:
    return STI.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)
               ? SPIRV::StorageClass::HostOnlyINTEL
               : SPIRV::StorageClass::CrossWorkgroup;
  case 7:
    return SPIRV::StorageClass::Input;
  case 8:
    return SPIRV::StorageClass::Output;
  case 9:
    return SPIRV::StorageClass::CodeSectionINTEL;
  case 10:
    return SPIRV::StorageClass::Private;
  default:
    report_fatal_error("Unknown address space");
  }
}

SPIRV::MemorySemantics::MemorySemantics
getMemSemanticsForStorageClass(SPIRV::StorageClass::StorageClass SC) {
  switch (SC) {
  case SPIRV::StorageClass::StorageBuffer:
  case SPIRV::StorageClass::Uniform:
    return SPIRV::MemorySemantics::UniformMemory;
  case SPIRV::StorageClass::Workgroup:
    return SPIRV::MemorySemantics::WorkgroupMemory;
  case SPIRV::StorageClass::CrossWorkgroup:
    return SPIRV::MemorySemantics::CrossWorkgroupMemory;
  case SPIRV::StorageClass::AtomicCounter:
    return SPIRV::MemorySemantics::AtomicCounterMemory;
  case SPIRV::StorageClass::Image:
    return SPIRV::MemorySemantics::ImageMemory;
  default:
    return SPIRV::MemorySemantics::None;
  }
}

SPIRV::MemorySemantics::MemorySemantics getMemSemantics(AtomicOrdering Ord) {
  switch (Ord) {
  case AtomicOrdering::Acquire:
    return SPIRV::MemorySemantics::Acquire;
  case AtomicOrdering::Release:
    return SPIRV::MemorySemantics::Release;
  case AtomicOrdering::AcquireRelease:
    return SPIRV::MemorySemantics::AcquireRelease;
  case AtomicOrdering::SequentiallyConsistent:
    return SPIRV::MemorySemantics::SequentiallyConsistent;
  case AtomicOrdering::Unordered:
  case AtomicOrdering::Monotonic:
  case AtomicOrdering::NotAtomic:
    return SPIRV::MemorySemantics::None;
  }
  llvm_unreachable(nullptr);
}

SPIRV::Scope::Scope getMemScope(LLVMContext &Ctx, SyncScope::ID Id) {
  // Named by
  // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_scope_id.
  // We don't need aliases for Invocation and CrossDevice, as we already have
  // them covered by "singlethread" and "" strings respectively (see
  // implementation of LLVMContext::LLVMContext()).
  static const llvm::SyncScope::ID SubGroup =
      Ctx.getOrInsertSyncScopeID("subgroup");
  static const llvm::SyncScope::ID WorkGroup =
      Ctx.getOrInsertSyncScopeID("workgroup");
  static const llvm::SyncScope::ID Device =
      Ctx.getOrInsertSyncScopeID("device");

  if (Id == llvm::SyncScope::SingleThread)
    return SPIRV::Scope::Invocation;
  else if (Id == llvm::SyncScope::System)
    return SPIRV::Scope::CrossDevice;
  else if (Id == SubGroup)
    return SPIRV::Scope::Subgroup;
  else if (Id == WorkGroup)
    return SPIRV::Scope::Workgroup;
  else if (Id == Device)
    return SPIRV::Scope::Device;
  return SPIRV::Scope::CrossDevice;
}

MachineInstr *getDefInstrMaybeConstant(Register &ConstReg,
                                       const MachineRegisterInfo *MRI) {
  MachineInstr *MI = MRI->getVRegDef(ConstReg);
  MachineInstr *ConstInstr =
      MI->getOpcode() == SPIRV::G_TRUNC || MI->getOpcode() == SPIRV::G_ZEXT
          ? MRI->getVRegDef(MI->getOperand(1).getReg())
          : MI;
  if (auto *GI = dyn_cast<GIntrinsic>(ConstInstr)) {
    if (GI->is(Intrinsic::spv_track_constant)) {
      ConstReg = ConstInstr->getOperand(2).getReg();
      return MRI->getVRegDef(ConstReg);
    }
  } else if (ConstInstr->getOpcode() == SPIRV::ASSIGN_TYPE) {
    ConstReg = ConstInstr->getOperand(1).getReg();
    return MRI->getVRegDef(ConstReg);
  }
  return MRI->getVRegDef(ConstReg);
}

uint64_t getIConstVal(Register ConstReg, const MachineRegisterInfo *MRI) {
  const MachineInstr *MI = getDefInstrMaybeConstant(ConstReg, MRI);
  assert(MI && MI->getOpcode() == TargetOpcode::G_CONSTANT);
  return MI->getOperand(1).getCImm()->getValue().getZExtValue();
}

bool isSpvIntrinsic(const MachineInstr &MI, Intrinsic::ID IntrinsicID) {
  if (const auto *GI = dyn_cast<GIntrinsic>(&MI))
    return GI->is(IntrinsicID);
  return false;
}

Type *getMDOperandAsType(const MDNode *N, unsigned I) {
  Type *ElementTy = cast<ValueAsMetadata>(N->getOperand(I))->getType();
  return toTypedPointer(ElementTy);
}

// The set of names is borrowed from the SPIR-V translator.
// TODO: may be implemented in SPIRVBuiltins.td.
static bool isPipeOrAddressSpaceCastBI(const StringRef MangledName) {
  return MangledName == "write_pipe_2" || MangledName == "read_pipe_2" ||
         MangledName == "write_pipe_2_bl" || MangledName == "read_pipe_2_bl" ||
         MangledName == "write_pipe_4" || MangledName == "read_pipe_4" ||
         MangledName == "reserve_write_pipe" ||
         MangledName == "reserve_read_pipe" ||
         MangledName == "commit_write_pipe" ||
         MangledName == "commit_read_pipe" ||
         MangledName == "work_group_reserve_write_pipe" ||
         MangledName == "work_group_reserve_read_pipe" ||
         MangledName == "work_group_commit_write_pipe" ||
         MangledName == "work_group_commit_read_pipe" ||
         MangledName == "get_pipe_num_packets_ro" ||
         MangledName == "get_pipe_max_packets_ro" ||
         MangledName == "get_pipe_num_packets_wo" ||
         MangledName == "get_pipe_max_packets_wo" ||
         MangledName == "sub_group_reserve_write_pipe" ||
         MangledName == "sub_group_reserve_read_pipe" ||
         MangledName == "sub_group_commit_write_pipe" ||
         MangledName == "sub_group_commit_read_pipe" ||
         MangledName == "to_global" || MangledName == "to_local" ||
         MangledName == "to_private";
}

static bool isEnqueueKernelBI(const StringRef MangledName) {
  return MangledName == "__enqueue_kernel_basic" ||
         MangledName == "__enqueue_kernel_basic_events" ||
         MangledName == "__enqueue_kernel_varargs" ||
         MangledName == "__enqueue_kernel_events_varargs";
}

static bool isKernelQueryBI(const StringRef MangledName) {
  return MangledName == "__get_kernel_work_group_size_impl" ||
         MangledName == "__get_kernel_sub_group_count_for_ndrange_impl" ||
         MangledName == "__get_kernel_max_sub_group_size_for_ndrange_impl" ||
         MangledName == "__get_kernel_preferred_work_group_size_multiple_impl";
}

static bool isNonMangledOCLBuiltin(StringRef Name) {
  if (!Name.starts_with("__"))
    return false;

  return isEnqueueKernelBI(Name) || isKernelQueryBI(Name) ||
         isPipeOrAddressSpaceCastBI(Name.drop_front(2)) ||
         Name == "__translate_sampler_initializer";
}

std::string getOclOrSpirvBuiltinDemangledName(StringRef Name) {
  bool IsNonMangledOCL = isNonMangledOCLBuiltin(Name);
  bool IsNonMangledSPIRV = Name.starts_with("__spirv_");
  bool IsNonMangledHLSL = Name.starts_with("__hlsl_");
  bool IsMangled = Name.starts_with("_Z");

  // If the name is not mangled, return the function name as-is.
  if (IsNonMangledOCL || IsNonMangledSPIRV || IsNonMangledHLSL || !IsMangled)
    return Name.str();

  // Try to use the Itanium demangler.
  if (char *DemangledName = itaniumDemangle(Name.data())) {
    std::string Result = DemangledName;
    free(DemangledName);
    return Result;
  }

  // Assume C++ here; an explicit check of the source language may be needed.
  // OpenCL C++ built-ins are declared in the cl namespace.
  // TODO: consider using the 'St' abbreviation for cl namespace mangling,
  // similar to ::std:: in C++.
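  // For example, given the plain Itanium mangling "_Z7barrierj", the code
  // below skips the "_Z" prefix, parses the length (7) and returns the next
  // 7 characters, "barrier".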
  size_t Start, Len = 0;
  size_t DemangledNameLenStart = 2;
  if (Name.starts_with("_ZN")) {
    // Skip CV and ref qualifiers.
    size_t NameSpaceStart = Name.find_first_not_of("rVKRO", 3);
    // All built-ins are in the ::cl::__spirv:: namespace.
    if (Name.substr(NameSpaceStart, 11) != "2cl7__spirv")
      return std::string();
    DemangledNameLenStart = NameSpaceStart + 11;
  }
  Start = Name.find_first_not_of("0123456789", DemangledNameLenStart);
  Name.substr(DemangledNameLenStart, Start - DemangledNameLenStart)
      .getAsInteger(10, Len);
  return Name.substr(Start, Len).str();
}

bool hasBuiltinTypePrefix(StringRef Name) {
  if (Name.starts_with("opencl.") || Name.starts_with("ocl_") ||
      Name.starts_with("spirv."))
    return true;
  return false;
}

bool isSpecialOpaqueType(const Type *Ty) {
  if (const TargetExtType *ExtTy = dyn_cast<TargetExtType>(Ty))
    return isTypedPointerWrapper(ExtTy)
               ? false
               : hasBuiltinTypePrefix(ExtTy->getName());

  return false;
}

bool isEntryPoint(const Function &F) {
  // OpenCL handling: any function with the SPIR_KERNEL
  // calling convention will be a potential entry point.
  if (F.getCallingConv() == CallingConv::SPIR_KERNEL)
    return true;

  // HLSL handling: a special attribute is emitted by the
  // front-end.
  if (F.getFnAttribute("hlsl.shader").isValid())
    return true;

  return false;
}

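// Map an OpenCL/SPIR-V builtin scalar type name (optionally prefixed with
// "atomic_") to the corresponding LLVM type, consuming the matched prefix
// from TypeName; returns nullptr if the name is not recognized.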
Type *parseBasicTypeName(StringRef &TypeName, LLVMContext &Ctx) {
  TypeName.consume_front("atomic_");
  if (TypeName.consume_front("void"))
    return Type::getVoidTy(Ctx);
  else if (TypeName.consume_front("bool"))
    return Type::getIntNTy(Ctx, 1);
  else if (TypeName.consume_front("char") ||
           TypeName.consume_front("unsigned char") ||
           TypeName.consume_front("uchar"))
    return Type::getInt8Ty(Ctx);
  else if (TypeName.consume_front("short") ||
           TypeName.consume_front("unsigned short") ||
           TypeName.consume_front("ushort"))
    return Type::getInt16Ty(Ctx);
  else if (TypeName.consume_front("int") ||
           TypeName.consume_front("unsigned int") ||
           TypeName.consume_front("uint"))
    return Type::getInt32Ty(Ctx);
  else if (TypeName.consume_front("long") ||
           TypeName.consume_front("unsigned long") ||
           TypeName.consume_front("ulong"))
    return Type::getInt64Ty(Ctx);
  else if (TypeName.consume_front("half"))
    return Type::getHalfTy(Ctx);
  else if (TypeName.consume_front("float"))
    return Type::getFloatTy(Ctx);
  else if (TypeName.consume_front("double"))
    return Type::getDoubleTy(Ctx);

  // Unable to recognize SPIR-V type name.
  return nullptr;
}

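// Collect every basic block reachable from Start by a breadth-first walk of
// the successor graph that ignores back-edges (edges into a dominator).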
std::unordered_set<BasicBlock *>
PartialOrderingVisitor::getReachableFrom(BasicBlock *Start) {
  std::queue<BasicBlock *> ToVisit;
  ToVisit.push(Start);

  std::unordered_set<BasicBlock *> Output;
  while (ToVisit.size() != 0) {
    BasicBlock *BB = ToVisit.front();
    ToVisit.pop();

    if (Output.count(BB) != 0)
      continue;
    Output.insert(BB);

    for (BasicBlock *Successor : successors(BB)) {
      if (DT.dominates(Successor, BB))
        continue;
      ToVisit.push(Successor);
    }
  }

  return Output;
}

bool PartialOrderingVisitor::CanBeVisited(BasicBlock *BB) const {
  for (BasicBlock *P : predecessors(BB)) {
    // Ignore back-edges.
    if (DT.dominates(BB, P))
      continue;

    // One of the predecessors hasn't been visited yet; this block is not
    // ready.
    if (BlockToOrder.count(P) == 0)
      return false;

    // If the block is a loop exit, the loop must be finished before
    // we can continue.
    Loop *L = LI.getLoopFor(P);
    if (L == nullptr || L->contains(BB))
      continue;

    // SPIR-V requires loops to have a single back-edge, and the backend's
    // first step transforms loops into this simplified form. If we have
    // more than one back-edge, something is wrong.
    assert(L->getNumBackEdges() <= 1);

    // If the loop has no latch, the loop's rank won't matter, so we can
    // proceed.
    BasicBlock *Latch = L->getLoopLatch();
    assert(Latch);
    if (Latch == nullptr)
      continue;

    // The latch is not ready yet, let's wait.
    if (BlockToOrder.count(Latch) == 0)
      return false;
  }

  return true;
}

size_t PartialOrderingVisitor::GetNodeRank(BasicBlock *BB) const {
  auto It = BlockToOrder.find(BB);
  if (It != BlockToOrder.end())
    return It->second.Rank;

  size_t result = 0;
  for (BasicBlock *P : predecessors(BB)) {
    // Ignore back-edges.
    if (DT.dominates(BB, P))
      continue;

    auto Iterator = BlockToOrder.end();
    Loop *L = LI.getLoopFor(P);
    BasicBlock *Latch = L ? L->getLoopLatch() : nullptr;

    // If the predecessor is either outside a loop, or part of
    // the same loop, simply take its rank + 1.
    if (L == nullptr || L->contains(BB) || Latch == nullptr) {
      Iterator = BlockToOrder.find(P);
    } else {
      // Otherwise, take the loop's rank (highest rank in the loop) as base.
      // Since loops have a single latch, highest rank is easy to find.
      // If the loop has no latch, then it doesn't matter.
      Iterator = BlockToOrder.find(Latch);
    }

    assert(Iterator != BlockToOrder.end());
    result = std::max(result, Iterator->second.Rank + 1);
  }

  return result;
}

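// Worklist-based traversal starting at BB: a block is ordered only once all
// of its non-back-edge predecessors (and, for loop exits, the loop latch)
// have been ordered, recording for each block its rank and a traversal index
// used as a tie-breaker when sorting.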
size_t PartialOrderingVisitor::visit(BasicBlock *BB, size_t Unused) {
  ToVisit.push(BB);
  Queued.insert(BB);

  size_t QueueIndex = 0;
  while (ToVisit.size() != 0) {
    BasicBlock *BB = ToVisit.front();
    ToVisit.pop();

    if (!CanBeVisited(BB)) {
      ToVisit.push(BB);
      assert(QueueIndex < ToVisit.size() &&
             "No valid candidate in the queue. Is the graph reducible?");
      QueueIndex++;
      continue;
    }

    QueueIndex = 0;
    size_t Rank = GetNodeRank(BB);
    OrderInfo Info = {Rank, BlockToOrder.size()};
    BlockToOrder.emplace(BB, Info);

    for (BasicBlock *S : successors(BB)) {
      if (Queued.count(S) != 0)
        continue;
      ToVisit.push(S);
      Queued.insert(S);
    }
  }

  return 0;
}

PartialOrderingVisitor::PartialOrderingVisitor(Function &F) {
  DT.recalculate(F);
  LI = LoopInfo(DT);

  visit(&*F.begin(), 0);

  Order.reserve(F.size());
  for (auto &[BB, Info] : BlockToOrder)
    Order.emplace_back(BB);

  std::sort(Order.begin(), Order.end(), [&](const auto &LHS, const auto &RHS) {
    return compare(LHS, RHS);
  });
}

bool PartialOrderingVisitor::compare(const BasicBlock *LHS,
                                     const BasicBlock *RHS) const {
  const OrderInfo &InfoLHS = BlockToOrder.at(const_cast<BasicBlock *>(LHS));
  const OrderInfo &InfoRHS = BlockToOrder.at(const_cast<BasicBlock *>(RHS));
  if (InfoLHS.Rank != InfoRHS.Rank)
    return InfoLHS.Rank < InfoRHS.Rank;
  return InfoLHS.TraversalIndex < InfoRHS.TraversalIndex;
}

void PartialOrderingVisitor::partialOrderVisit(
    BasicBlock &Start, std::function<bool(BasicBlock *)> Op) {
  std::unordered_set<BasicBlock *> Reachable = getReachableFrom(&Start);
  assert(BlockToOrder.count(&Start) != 0);

  // Skip blocks that come before |Start| in the partial order.
  auto It = Order.begin();
  while (It != Order.end() && *It != &Start)
    ++It;

  // |Start| is expected to be in the order: in the worst case it is the last
  // block, so It should point to the last element, never past the end.
  assert(It != Order.end());

  // By default, there is no rank limit.
  std::optional<size_t> EndRank = std::nullopt;
  for (; It != Order.end(); ++It) {
    if (EndRank.has_value() && BlockToOrder[*It].Rank > *EndRank)
      break;

    if (Reachable.count(*It) == 0) {
      continue;
    }

    if (!Op(*It)) {
      EndRank = BlockToOrder[*It].Rank;
    }
  }
}

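// Reorder the basic blocks of F to match a reverse post-order traversal,
// moving a block only when it is out of place; returns true if the function
// was modified.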
bool sortBlocks(Function &F) {
  if (F.size() == 0)
    return false;

  bool Modified = false;
  std::vector<BasicBlock *> Order;
  Order.reserve(F.size());

  ReversePostOrderTraversal<Function *> RPOT(&F);
  for (BasicBlock *BB : RPOT)
    Order.push_back(BB);

  assert(&*F.begin() == Order[0]);
  BasicBlock *LastBlock = &*F.begin();
  for (BasicBlock *BB : Order) {
    if (BB != LastBlock && &*LastBlock->getNextNode() != BB) {
      Modified = true;
      BB->moveAfter(LastBlock);
    }
    LastBlock = BB;
  }

  return Modified;
}

MachineInstr *getVRegDef(MachineRegisterInfo &MRI, Register Reg) {
  MachineInstr *MaybeDef = MRI.getVRegDef(Reg);
  if (MaybeDef && MaybeDef->getOpcode() == SPIRV::ASSIGN_TYPE)
    MaybeDef = MRI.getVRegDef(MaybeDef->getOperand(1).getReg());
  return MaybeDef;
}

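// Append a numeric suffix (0, 1, ...) to Name until the result does not clash
// with an existing function in M; returns false if no vacant name was found
// within the iteration limit.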
bool getVacantFunctionName(Module &M, std::string &Name) {
  // This is a bit of paranoia, but we don't want to take any chance that the
  // loop runs for too long.
  constexpr unsigned MaxIters = 1024;
  for (unsigned I = 0; I < MaxIters; ++I) {
    std::string OrdName = Name + Twine(I).str();
    if (!M.getFunction(OrdName)) {
      Name = OrdName;
      return true;
    }
  }
  return false;
}

// Assign a SPIR-V type to the register. If the register has no valid assigned
// class, set the register's LLT type and class according to the SPIR-V type.
void setRegClassType(Register Reg, SPIRVType *SpvType, SPIRVGlobalRegistry *GR,
                     MachineRegisterInfo *MRI, const MachineFunction &MF,
                     bool Force) {
  GR->assignSPIRVTypeToVReg(SpvType, Reg, MF);
  if (!MRI->getRegClassOrNull(Reg) || Force) {
    MRI->setRegClass(Reg, GR->getRegClass(SpvType));
    MRI->setType(Reg, GR->getRegType(SpvType));
  }
}

// Create a SPIR-V type and assign it to the register. If the register has no
// valid assigned class, set the register's LLT type and class according to
// the SPIR-V type.
void setRegClassType(Register Reg, const Type *Ty, SPIRVGlobalRegistry *GR,
                     MachineIRBuilder &MIRBuilder, bool Force) {
  setRegClassType(Reg, GR->getOrCreateSPIRVType(Ty, MIRBuilder), GR,
                  MIRBuilder.getMRI(), MIRBuilder.getMF(), Force);
}

// Create a virtual register and assign the SPIR-V type to it. Set the
// register's LLT type and class according to the SPIR-V type.
Register createVirtualRegister(SPIRVType *SpvType, SPIRVGlobalRegistry *GR,
                               MachineRegisterInfo *MRI,
                               const MachineFunction &MF) {
  Register Reg = MRI->createVirtualRegister(GR->getRegClass(SpvType));
  MRI->setType(Reg, GR->getRegType(SpvType));
  GR->assignSPIRVTypeToVReg(SpvType, Reg, MF);
  return Reg;
}

// Create a virtual register and assign the SPIR-V type to it. Set the
// register's LLT type and class according to the SPIR-V type.
Register createVirtualRegister(SPIRVType *SpvType, SPIRVGlobalRegistry *GR,
                               MachineIRBuilder &MIRBuilder) {
  return createVirtualRegister(SpvType, GR, MIRBuilder.getMRI(),
                               MIRBuilder.getMF());
}

// Create a SPIR-V type and a virtual register, and assign the SPIR-V type to
// the register. Set the register's LLT type and class according to the
// SPIR-V type.
Register createVirtualRegister(const Type *Ty, SPIRVGlobalRegistry *GR,
                               MachineIRBuilder &MIRBuilder) {
  return createVirtualRegister(GR->getOrCreateSPIRVType(Ty, MIRBuilder), GR,
                               MIRBuilder);
}

// Return true if there is an opaque pointer type nested in the argument.
bool isNestedPointer(const Type *Ty) {
  if (Ty->isPtrOrPtrVectorTy())
    return true;
  if (const FunctionType *RefTy = dyn_cast<FunctionType>(Ty)) {
    if (isNestedPointer(RefTy->getReturnType()))
      return true;
    for (const Type *ArgTy : RefTy->params())
      if (isNestedPointer(ArgTy))
        return true;
    return false;
  }
  if (const ArrayType *RefTy = dyn_cast<ArrayType>(Ty))
    return isNestedPointer(RefTy->getElementType());
  return false;
}

bool isSpvIntrinsic(const Value *Arg) {
  if (const auto *II = dyn_cast<IntrinsicInst>(Arg))
    if (Function *F = II->getCalledFunction())
      if (F->getName().starts_with("llvm.spv."))
        return true;
  return false;
}

} // namespace llvm