//===--- SPIRVUtils.cpp ---- SPIR-V Utility Functions -----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains miscellaneous utility functions.
//
//===----------------------------------------------------------------------===//

#include "SPIRVUtils.h"
#include "MCTargetDesc/SPIRVBaseInfo.h"
#include "SPIRV.h"
#include "SPIRVGlobalRegistry.h"
#include "SPIRVInstrInfo.h"
#include "SPIRVSubtarget.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/Demangle/Demangle.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsSPIRV.h"
#include <queue>
#include <vector>

namespace llvm {

// The following functions are used to add string literals as a series of
// 32-bit integer operands with the correct format, and to unpack them when
// making string comparisons in compiler passes.
// SPIR-V requires null-terminated UTF-8 strings padded to 32-bit alignment.
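// For example, the three-character string "abc" packs into a single word:
// 'a', 'b', 'c' and a trailing null byte are placed at increasing byte
// positions, giving 0x00636261. A four-character string such as "abcd" needs
// a second, all-zero word so the null terminator still fits.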
static uint32_t convertCharsToWord(const StringRef &Str, unsigned i) {
  uint32_t Word = 0u; // Build up this 32-bit word from 4 8-bit chars.
  for (unsigned WordIndex = 0; WordIndex < 4; ++WordIndex) {
    unsigned StrIndex = i + WordIndex;
    uint8_t CharToAdd = 0;       // Initialize char as padding/null.
    if (StrIndex < Str.size()) { // If it's within the string, get a real char.
      CharToAdd = Str[StrIndex];
    }
    Word |= (CharToAdd << (WordIndex * 8));
  }
  return Word;
}

// Get length including padding and null terminator.
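// A few illustrative values: a 3-char string pads to 4 bytes (one word), a
// 4-char string to 8 bytes (two words, since the terminator needs its own
// byte), and a 5-char string also to 8 bytes.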
static size_t getPaddedLen(const StringRef &Str) {
  return (Str.size() + 4) & ~3;
}

void addStringImm(const StringRef &Str, MCInst &Inst) {
  const size_t PaddedLen = getPaddedLen(Str);
  for (unsigned i = 0; i < PaddedLen; i += 4) {
    // Add an operand for the 32-bits of chars or padding.
    Inst.addOperand(MCOperand::createImm(convertCharsToWord(Str, i)));
  }
}

void addStringImm(const StringRef &Str, MachineInstrBuilder &MIB) {
  const size_t PaddedLen = getPaddedLen(Str);
  for (unsigned i = 0; i < PaddedLen; i += 4) {
    // Add an operand for the 32-bits of chars or padding.
    MIB.addImm(convertCharsToWord(Str, i));
  }
}

void addStringImm(const StringRef &Str, IRBuilder<> &B,
                  std::vector<Value *> &Args) {
  const size_t PaddedLen = getPaddedLen(Str);
  for (unsigned i = 0; i < PaddedLen; i += 4) {
    // Add a vector element for the 32-bits of chars or padding.
    Args.push_back(B.getInt32(convertCharsToWord(Str, i)));
  }
}

std::string getStringImm(const MachineInstr &MI, unsigned StartIndex) {
  return getSPIRVStringOperand(MI, StartIndex);
}

void addNumImm(const APInt &Imm, MachineInstrBuilder &MIB) {
  const auto Bitwidth = Imm.getBitWidth();
  if (Bitwidth == 1)
    return; // Already handled
  else if (Bitwidth <= 32) {
    MIB.addImm(Imm.getZExtValue());
    // Asm printer needs this info to print floating-point types correctly.
    if (Bitwidth == 16)
      MIB.getInstr()->setAsmPrinterFlag(SPIRV::ASM_PRINTER_WIDTH16);
    return;
  } else if (Bitwidth <= 64) {
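    // Illustrative example: a 64-bit immediate 0x1111111122222222 is emitted
    // as the low word 0x22222222 followed by the high word 0x11111111,
    // matching SPIR-V's low-order-word-first layout for wide literals.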
    uint64_t FullImm = Imm.getZExtValue();
    uint32_t LowBits = FullImm & 0xffffffff;
    uint32_t HighBits = (FullImm >> 32) & 0xffffffff;
    MIB.addImm(LowBits).addImm(HighBits);
    return;
  }
  report_fatal_error("Unsupported constant bitwidth");
}

void buildOpName(Register Target, const StringRef &Name,
                 MachineIRBuilder &MIRBuilder) {
  if (!Name.empty()) {
    auto MIB = MIRBuilder.buildInstr(SPIRV::OpName).addUse(Target);
    addStringImm(Name, MIB);
  }
}

void buildOpName(Register Target, const StringRef &Name, MachineInstr &I,
                 const SPIRVInstrInfo &TII) {
  if (!Name.empty()) {
    auto MIB =
        BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpName))
            .addUse(Target);
    addStringImm(Name, MIB);
  }
}

static void finishBuildOpDecorate(MachineInstrBuilder &MIB,
                                  const std::vector<uint32_t> &DecArgs,
                                  StringRef StrImm) {
  if (!StrImm.empty())
    addStringImm(StrImm, MIB);
  for (const auto &DecArg : DecArgs)
    MIB.addImm(DecArg);
}

void buildOpDecorate(Register Reg, MachineIRBuilder &MIRBuilder,
                     SPIRV::Decoration::Decoration Dec,
                     const std::vector<uint32_t> &DecArgs, StringRef StrImm) {
  auto MIB = MIRBuilder.buildInstr(SPIRV::OpDecorate)
                 .addUse(Reg)
                 .addImm(static_cast<uint32_t>(Dec));
  finishBuildOpDecorate(MIB, DecArgs, StrImm);
}

void buildOpDecorate(Register Reg, MachineInstr &I, const SPIRVInstrInfo &TII,
                     SPIRV::Decoration::Decoration Dec,
                     const std::vector<uint32_t> &DecArgs, StringRef StrImm) {
  MachineBasicBlock &MBB = *I.getParent();
  auto MIB = BuildMI(MBB, I, I.getDebugLoc(), TII.get(SPIRV::OpDecorate))
                 .addUse(Reg)
                 .addImm(static_cast<uint32_t>(Dec));
  finishBuildOpDecorate(MIB, DecArgs, StrImm);
}

void buildOpSpirvDecorations(Register Reg, MachineIRBuilder &MIRBuilder,
                             const MDNode *GVarMD) {
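  // Illustrative sketch of the expected metadata shape: GVarMD is a list of
  // inner MDNodes, each describing one decoration, e.g.
  //   !0 = !{!1, !2}
  //   !1 = !{i32 <DecorationId>}                  ; no extra operands
  //   !2 = !{i32 <DecorationId>, i32 <ExtraArg>}  ; integer or string operands
  // Each inner node is lowered to one OpDecorate instruction on Reg.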
  for (unsigned I = 0, E = GVarMD->getNumOperands(); I != E; ++I) {
    auto *OpMD = dyn_cast<MDNode>(GVarMD->getOperand(I));
    if (!OpMD)
      report_fatal_error("Invalid decoration");
    if (OpMD->getNumOperands() == 0)
      report_fatal_error("Expect operand(s) of the decoration");
    ConstantInt *DecorationId =
        mdconst::dyn_extract<ConstantInt>(OpMD->getOperand(0));
    if (!DecorationId)
      report_fatal_error("Expect SPIR-V <Decoration> operand to be the first "
                         "element of the decoration");
    auto MIB = MIRBuilder.buildInstr(SPIRV::OpDecorate)
                   .addUse(Reg)
                   .addImm(static_cast<uint32_t>(DecorationId->getZExtValue()));
    for (unsigned OpI = 1, OpE = OpMD->getNumOperands(); OpI != OpE; ++OpI) {
      if (ConstantInt *OpV =
              mdconst::dyn_extract<ConstantInt>(OpMD->getOperand(OpI)))
        MIB.addImm(static_cast<uint32_t>(OpV->getZExtValue()));
      else if (MDString *OpV = dyn_cast<MDString>(OpMD->getOperand(OpI)))
        addStringImm(OpV->getString(), MIB);
      else
        report_fatal_error("Unexpected operand of the decoration");
    }
  }
}

MachineBasicBlock::iterator getOpVariableMBBIt(MachineInstr &I) {
  MachineFunction *MF = I.getParent()->getParent();
  MachineBasicBlock *MBB = &MF->front();
  MachineBasicBlock::iterator It = MBB->SkipPHIsAndLabels(MBB->begin()),
                              E = MBB->end();
  bool IsHeader = false;
  unsigned Opcode;
  for (; It != E && It != I; ++It) {
    Opcode = It->getOpcode();
    if (Opcode == SPIRV::OpFunction || Opcode == SPIRV::OpFunctionParameter) {
      IsHeader = true;
    } else if (IsHeader &&
               !(Opcode == SPIRV::ASSIGN_TYPE || Opcode == SPIRV::OpLabel)) {
      ++It;
      break;
    }
  }
  return It;
}

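// Walks backwards from MBB->end(), skipping terminators and debug values, and
// returns the iterator it stops at: the last non-terminator, non-debug
// instruction, or MBB->begin() if there is none.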
MachineBasicBlock::iterator getInsertPtValidEnd(MachineBasicBlock *MBB) {
  MachineBasicBlock::iterator I = MBB->end();
  if (I == MBB->begin())
    return I;
  --I;
  while (I->isTerminator() || I->isDebugValue()) {
    if (I == MBB->begin())
      break;
    --I;
  }
  return I;
}

SPIRV::StorageClass::StorageClass
addressSpaceToStorageClass(unsigned AddrSpace, const SPIRVSubtarget &STI) {
  switch (AddrSpace) {
  case 0:
    return SPIRV::StorageClass::Function;
  case 1:
    return SPIRV::StorageClass::CrossWorkgroup;
  case 2:
    return SPIRV::StorageClass::UniformConstant;
  case 3:
    return SPIRV::StorageClass::Workgroup;
  case 4:
    return SPIRV::StorageClass::Generic;
  case 5:
    return STI.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)
               ? SPIRV::StorageClass::DeviceOnlyINTEL
               : SPIRV::StorageClass::CrossWorkgroup;
  case 6:
    return STI.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)
               ? SPIRV::StorageClass::HostOnlyINTEL
               : SPIRV::StorageClass::CrossWorkgroup;
  case 7:
    return SPIRV::StorageClass::Input;
  case 8:
    return SPIRV::StorageClass::Output;
  case 9:
    return SPIRV::StorageClass::CodeSectionINTEL;
  case 10:
    return SPIRV::StorageClass::Private;
  default:
    report_fatal_error("Unknown address space");
  }
}

SPIRV::MemorySemantics::MemorySemantics
getMemSemanticsForStorageClass(SPIRV::StorageClass::StorageClass SC) {
  switch (SC) {
  case SPIRV::StorageClass::StorageBuffer:
  case SPIRV::StorageClass::Uniform:
    return SPIRV::MemorySemantics::UniformMemory;
  case SPIRV::StorageClass::Workgroup:
    return SPIRV::MemorySemantics::WorkgroupMemory;
  case SPIRV::StorageClass::CrossWorkgroup:
    return SPIRV::MemorySemantics::CrossWorkgroupMemory;
  case SPIRV::StorageClass::AtomicCounter:
    return SPIRV::MemorySemantics::AtomicCounterMemory;
  case SPIRV::StorageClass::Image:
    return SPIRV::MemorySemantics::ImageMemory;
  default:
    return SPIRV::MemorySemantics::None;
  }
}

SPIRV::MemorySemantics::MemorySemantics getMemSemantics(AtomicOrdering Ord) {
  switch (Ord) {
  case AtomicOrdering::Acquire:
    return SPIRV::MemorySemantics::Acquire;
  case AtomicOrdering::Release:
    return SPIRV::MemorySemantics::Release;
  case AtomicOrdering::AcquireRelease:
    return SPIRV::MemorySemantics::AcquireRelease;
  case AtomicOrdering::SequentiallyConsistent:
    return SPIRV::MemorySemantics::SequentiallyConsistent;
  case AtomicOrdering::Unordered:
  case AtomicOrdering::Monotonic:
  case AtomicOrdering::NotAtomic:
    return SPIRV::MemorySemantics::None;
  }
  llvm_unreachable(nullptr);
}

SPIRV::Scope::Scope getMemScope(LLVMContext &Ctx, SyncScope::ID Id) {
  // Named after the scopes in
  // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_scope_id.
  // We don't need aliases for Invocation and CrossDevice, as they are already
  // covered by the "singlethread" and "" strings respectively (see the
  // implementation of LLVMContext::LLVMContext()).
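  // For example, an LLVM atomic with syncscope("workgroup") maps to
  // SPIR-V Scope::Workgroup below, while the default system scope maps to
  // Scope::CrossDevice.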
  static const llvm::SyncScope::ID SubGroup =
      Ctx.getOrInsertSyncScopeID("subgroup");
  static const llvm::SyncScope::ID WorkGroup =
      Ctx.getOrInsertSyncScopeID("workgroup");
  static const llvm::SyncScope::ID Device =
      Ctx.getOrInsertSyncScopeID("device");

  if (Id == llvm::SyncScope::SingleThread)
    return SPIRV::Scope::Invocation;
  else if (Id == llvm::SyncScope::System)
    return SPIRV::Scope::CrossDevice;
  else if (Id == SubGroup)
    return SPIRV::Scope::Subgroup;
  else if (Id == WorkGroup)
    return SPIRV::Scope::Workgroup;
  else if (Id == Device)
    return SPIRV::Scope::Device;
  return SPIRV::Scope::CrossDevice;
}

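// Looks through G_TRUNC/G_ZEXT, the spv_track_constant intrinsic and
// ASSIGN_TYPE pseudos when searching for the instruction that defines a
// constant; ConstReg is updated to the register of the underlying definition
// where one is found.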
MachineInstr *getDefInstrMaybeConstant(Register &ConstReg,
                                       const MachineRegisterInfo *MRI) {
  MachineInstr *MI = MRI->getVRegDef(ConstReg);
  MachineInstr *ConstInstr =
      MI->getOpcode() == SPIRV::G_TRUNC || MI->getOpcode() == SPIRV::G_ZEXT
          ? MRI->getVRegDef(MI->getOperand(1).getReg())
          : MI;
  if (auto *GI = dyn_cast<GIntrinsic>(ConstInstr)) {
    if (GI->is(Intrinsic::spv_track_constant)) {
      ConstReg = ConstInstr->getOperand(2).getReg();
      return MRI->getVRegDef(ConstReg);
    }
  } else if (ConstInstr->getOpcode() == SPIRV::ASSIGN_TYPE) {
    ConstReg = ConstInstr->getOperand(1).getReg();
    return MRI->getVRegDef(ConstReg);
  }
  return MRI->getVRegDef(ConstReg);
}

uint64_t getIConstVal(Register ConstReg, const MachineRegisterInfo *MRI) {
  const MachineInstr *MI = getDefInstrMaybeConstant(ConstReg, MRI);
  assert(MI && MI->getOpcode() == TargetOpcode::G_CONSTANT);
  return MI->getOperand(1).getCImm()->getValue().getZExtValue();
}

bool isSpvIntrinsic(const MachineInstr &MI, Intrinsic::ID IntrinsicID) {
  if (const auto *GI = dyn_cast<GIntrinsic>(&MI))
    return GI->is(IntrinsicID);
  return false;
}

Type *getMDOperandAsType(const MDNode *N, unsigned I) {
  Type *ElementTy = cast<ValueAsMetadata>(N->getOperand(I))->getType();
  return toTypedPointer(ElementTy);
}

// The set of names is borrowed from the SPIR-V translator.
// TODO: maybe implement these in SPIRVBuiltins.td.
static bool isPipeOrAddressSpaceCastBI(const StringRef MangledName) {
  return MangledName == "write_pipe_2" || MangledName == "read_pipe_2" ||
         MangledName == "write_pipe_2_bl" || MangledName == "read_pipe_2_bl" ||
         MangledName == "write_pipe_4" || MangledName == "read_pipe_4" ||
         MangledName == "reserve_write_pipe" ||
         MangledName == "reserve_read_pipe" ||
         MangledName == "commit_write_pipe" ||
         MangledName == "commit_read_pipe" ||
         MangledName == "work_group_reserve_write_pipe" ||
         MangledName == "work_group_reserve_read_pipe" ||
         MangledName == "work_group_commit_write_pipe" ||
         MangledName == "work_group_commit_read_pipe" ||
         MangledName == "get_pipe_num_packets_ro" ||
         MangledName == "get_pipe_max_packets_ro" ||
         MangledName == "get_pipe_num_packets_wo" ||
         MangledName == "get_pipe_max_packets_wo" ||
         MangledName == "sub_group_reserve_write_pipe" ||
         MangledName == "sub_group_reserve_read_pipe" ||
         MangledName == "sub_group_commit_write_pipe" ||
         MangledName == "sub_group_commit_read_pipe" ||
         MangledName == "to_global" || MangledName == "to_local" ||
         MangledName == "to_private";
}

static bool isEnqueueKernelBI(const StringRef MangledName) {
  return MangledName == "__enqueue_kernel_basic" ||
         MangledName == "__enqueue_kernel_basic_events" ||
         MangledName == "__enqueue_kernel_varargs" ||
         MangledName == "__enqueue_kernel_events_varargs";
}

static bool isKernelQueryBI(const StringRef MangledName) {
  return MangledName == "__get_kernel_work_group_size_impl" ||
         MangledName == "__get_kernel_sub_group_count_for_ndrange_impl" ||
         MangledName == "__get_kernel_max_sub_group_size_for_ndrange_impl" ||
         MangledName == "__get_kernel_preferred_work_group_size_multiple_impl";
}

static bool isNonMangledOCLBuiltin(StringRef Name) {
  if (!Name.starts_with("__"))
    return false;

  return isEnqueueKernelBI(Name) || isKernelQueryBI(Name) ||
         isPipeOrAddressSpaceCastBI(Name.drop_front(2)) ||
         Name == "__translate_sampler_initializer";
}

std::string getOclOrSpirvBuiltinDemangledName(StringRef Name) {
  bool IsNonMangledOCL = isNonMangledOCLBuiltin(Name);
  bool IsNonMangledSPIRV = Name.starts_with("__spirv_");
  bool IsNonMangledHLSL = Name.starts_with("__hlsl_");
  bool IsMangled = Name.starts_with("_Z");

  // Non-mangled names (and recognized non-mangled built-ins) are returned
  // as-is.
  if (IsNonMangledOCL || IsNonMangledSPIRV || IsNonMangledHLSL || !IsMangled)
    return Name.str();

  // Try to use the Itanium demangler.
  if (char *DemangledName = itaniumDemangle(Name.data())) {
    std::string Result = DemangledName;
    free(DemangledName);
    return Result;
  }

  // Assume C++ here; we may need an explicit check of the source language.
  // OpenCL C++ built-ins are declared in the cl namespace.
  // TODO: consider using the 'St' abbreviation for the cl namespace mangling,
  // similar to ::std:: in C++.
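  // Illustrative example: for a plainly mangled name such as
  // "_Z12write_pipe_2..." the digits after "_Z" give the length 12, so the
  // returned name is "write_pipe_2"; for "_ZN2cl7__spirv..." the qualifiers
  // and namespace prefix are skipped first and the length-prefixed name that
  // follows is extracted instead.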
  size_t Start, Len = 0;
  size_t DemangledNameLenStart = 2;
  if (Name.starts_with("_ZN")) {
    // Skip CV and ref qualifiers.
    size_t NameSpaceStart = Name.find_first_not_of("rVKRO", 3);
    // All built-ins are expected to be in the ::cl::__spirv namespace.
    if (Name.substr(NameSpaceStart, 11) != "2cl7__spirv")
      return std::string();
    DemangledNameLenStart = NameSpaceStart + 11;
  }
  Start = Name.find_first_not_of("0123456789", DemangledNameLenStart);
  Name.substr(DemangledNameLenStart, Start - DemangledNameLenStart)
      .getAsInteger(10, Len);
  return Name.substr(Start, Len).str();
}

bool hasBuiltinTypePrefix(StringRef Name) {
  if (Name.starts_with("opencl.") || Name.starts_with("ocl_") ||
      Name.starts_with("spirv."))
    return true;
  return false;
}

bool isSpecialOpaqueType(const Type *Ty) {
  if (const TargetExtType *ExtTy = dyn_cast<TargetExtType>(Ty))
    return isTypedPointerWrapper(ExtTy)
               ? false
               : hasBuiltinTypePrefix(ExtTy->getName());

  return false;
}

bool isEntryPoint(const Function &F) {
  // OpenCL handling: any function with the SPIR_KERNEL
  // calling convention will be a potential entry point.
  if (F.getCallingConv() == CallingConv::SPIR_KERNEL)
    return true;

  // HLSL handling: a special attribute is emitted by the
  // front-end.
  if (F.getFnAttribute("hlsl.shader").isValid())
    return true;

  return false;
}

Type *parseBasicTypeName(StringRef &TypeName, LLVMContext &Ctx) {
  TypeName.consume_front("atomic_");
  if (TypeName.consume_front("void"))
    return Type::getVoidTy(Ctx);
  else if (TypeName.consume_front("bool") || TypeName.consume_front("_Bool"))
    return Type::getIntNTy(Ctx, 1);
  else if (TypeName.consume_front("char") ||
           TypeName.consume_front("signed char") ||
           TypeName.consume_front("unsigned char") ||
           TypeName.consume_front("uchar"))
    return Type::getInt8Ty(Ctx);
  else if (TypeName.consume_front("short") ||
           TypeName.consume_front("signed short") ||
           TypeName.consume_front("unsigned short") ||
           TypeName.consume_front("ushort"))
    return Type::getInt16Ty(Ctx);
  else if (TypeName.consume_front("int") ||
           TypeName.consume_front("signed int") ||
           TypeName.consume_front("unsigned int") ||
           TypeName.consume_front("uint"))
    return Type::getInt32Ty(Ctx);
  else if (TypeName.consume_front("long") ||
           TypeName.consume_front("signed long") ||
           TypeName.consume_front("unsigned long") ||
           TypeName.consume_front("ulong"))
    return Type::getInt64Ty(Ctx);
  else if (TypeName.consume_front("half") ||
           TypeName.consume_front("_Float16") ||
           TypeName.consume_front("__fp16"))
    return Type::getHalfTy(Ctx);
  else if (TypeName.consume_front("float"))
    return Type::getFloatTy(Ctx);
  else if (TypeName.consume_front("double"))
    return Type::getDoubleTy(Ctx);

  // Unable to recognize the SPIR-V type name.
  return nullptr;
}

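// PartialOrderingVisitor helpers. Each basic block is assigned a rank,
// roughly the length of the longest back-edge-free path from the entry block,
// and the final traversal order sorts blocks by rank, then by discovery
// index. For example, a simple diamond entry -> {then, else} -> merge gets
// ranks 0, 1, 1 and 2 respectively.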
std::unordered_set<BasicBlock *>
PartialOrderingVisitor::getReachableFrom(BasicBlock *Start) {
  std::queue<BasicBlock *> ToVisit;
  ToVisit.push(Start);

  std::unordered_set<BasicBlock *> Output;
  while (ToVisit.size() != 0) {
    BasicBlock *BB = ToVisit.front();
    ToVisit.pop();

    if (Output.count(BB) != 0)
      continue;
    Output.insert(BB);

    for (BasicBlock *Successor : successors(BB)) {
      if (DT.dominates(Successor, BB))
        continue;
      ToVisit.push(Successor);
    }
  }

  return Output;
}

bool PartialOrderingVisitor::CanBeVisited(BasicBlock *BB) const {
  for (BasicBlock *P : predecessors(BB)) {
    // Ignore back-edges.
    if (DT.dominates(BB, P))
      continue;

    // One of the predecessors hasn't been visited yet: not ready.
    if (BlockToOrder.count(P) == 0)
      return false;

    // If the block is a loop exit, the loop must be finished before
    // we can continue.
    Loop *L = LI.getLoopFor(P);
    if (L == nullptr || L->contains(BB))
      continue;

    // SPIR-V requires a single back-edge, and an earlier backend step
    // transforms loops into that simplified form. If we have more than one
    // back-edge, something is wrong.
    assert(L->getNumBackEdges() <= 1);

    // If the loop has no latch, the loop's rank won't matter, so we can
    // proceed.
    BasicBlock *Latch = L->getLoopLatch();
    assert(Latch);
    if (Latch == nullptr)
      continue;

    // The latch is not ready yet, let's wait.
    if (BlockToOrder.count(Latch) == 0)
      return false;
  }

  return true;
}

size_t PartialOrderingVisitor::GetNodeRank(BasicBlock *BB) const {
  auto It = BlockToOrder.find(BB);
  if (It != BlockToOrder.end())
    return It->second.Rank;

  size_t result = 0;
  for (BasicBlock *P : predecessors(BB)) {
    // Ignore back-edges.
    if (DT.dominates(BB, P))
      continue;

    auto Iterator = BlockToOrder.end();
    Loop *L = LI.getLoopFor(P);
    BasicBlock *Latch = L ? L->getLoopLatch() : nullptr;

    // If the predecessor is either outside a loop, or part of
    // the same loop, simply take its rank + 1.
    if (L == nullptr || L->contains(BB) || Latch == nullptr) {
      Iterator = BlockToOrder.find(P);
    } else {
      // Otherwise, take the loop's rank (the highest rank in the loop) as the
      // base. Since loops have a single latch, the highest rank is easy to
      // find. If the loop has no latch, then it doesn't matter.
      Iterator = BlockToOrder.find(Latch);
    }

    assert(Iterator != BlockToOrder.end());
    result = std::max(result, Iterator->second.Rank + 1);
  }

  return result;
}

size_t PartialOrderingVisitor::visit(BasicBlock *BB, size_t Unused) {
  ToVisit.push(BB);
  Queued.insert(BB);

  size_t QueueIndex = 0;
  while (ToVisit.size() != 0) {
    BasicBlock *BB = ToVisit.front();
    ToVisit.pop();

    if (!CanBeVisited(BB)) {
      ToVisit.push(BB);
      if (QueueIndex >= ToVisit.size())
        llvm::report_fatal_error(
            "No valid candidate in the queue. Is the graph reducible?");
      QueueIndex++;
      continue;
    }

    QueueIndex = 0;
    size_t Rank = GetNodeRank(BB);
    OrderInfo Info = {Rank, BlockToOrder.size()};
    BlockToOrder.emplace(BB, Info);

    for (BasicBlock *S : successors(BB)) {
      if (Queued.count(S) != 0)
        continue;
      ToVisit.push(S);
      Queued.insert(S);
    }
  }

  return 0;
}

PartialOrderingVisitor::PartialOrderingVisitor(Function &F) {
  DT.recalculate(F);
  LI = LoopInfo(DT);

  visit(&*F.begin(), 0);

  Order.reserve(F.size());
  for (auto &[BB, Info] : BlockToOrder)
    Order.emplace_back(BB);

  std::sort(Order.begin(), Order.end(), [&](const auto &LHS, const auto &RHS) {
    return compare(LHS, RHS);
  });
}

bool PartialOrderingVisitor::compare(const BasicBlock *LHS,
                                     const BasicBlock *RHS) const {
  const OrderInfo &InfoLHS = BlockToOrder.at(const_cast<BasicBlock *>(LHS));
  const OrderInfo &InfoRHS = BlockToOrder.at(const_cast<BasicBlock *>(RHS));
  if (InfoLHS.Rank != InfoRHS.Rank)
    return InfoLHS.Rank < InfoRHS.Rank;
  return InfoLHS.TraversalIndex < InfoRHS.TraversalIndex;
}

void PartialOrderingVisitor::partialOrderVisit(
    BasicBlock &Start, std::function<bool(BasicBlock *)> Op) {
  std::unordered_set<BasicBlock *> Reachable = getReachableFrom(&Start);
  assert(BlockToOrder.count(&Start) != 0);

  // Skip blocks that are ordered before |Start|.
  auto It = Order.begin();
  while (It != Order.end() && *It != &Start)
    ++It;

  // Reaching past-the-end would be unexpected: in the worst case |Start| is
  // the last block, so It should point to the last block, not past the end.
  assert(It != Order.end());

  // By default, there is no rank limit; EndRank is only set once Op returns
  // false.
  std::optional<size_t> EndRank = std::nullopt;
  for (; It != Order.end(); ++It) {
    if (EndRank.has_value() && BlockToOrder[*It].Rank > *EndRank)
      break;

    if (Reachable.count(*It) == 0) {
      continue;
    }

    if (!Op(*It)) {
      EndRank = BlockToOrder[*It].Rank;
    }
  }
}

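// Reorders the basic blocks of F to match a reverse post-order traversal,
// moving a block only when it is not already in place; returns true if the
// function was modified.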
bool sortBlocks(Function &F) {
  if (F.size() == 0)
    return false;

  bool Modified = false;
  std::vector<BasicBlock *> Order;
  Order.reserve(F.size());

  ReversePostOrderTraversal<Function *> RPOT(&F);
  for (BasicBlock *BB : RPOT)
    Order.push_back(BB);

  assert(&*F.begin() == Order[0]);
  BasicBlock *LastBlock = &*F.begin();
  for (BasicBlock *BB : Order) {
    if (BB != LastBlock && &*LastBlock->getNextNode() != BB) {
      Modified = true;
      BB->moveAfter(LastBlock);
    }
    LastBlock = BB;
  }

  return Modified;
}

MachineInstr *getVRegDef(MachineRegisterInfo &MRI, Register Reg) {
  MachineInstr *MaybeDef = MRI.getVRegDef(Reg);
  if (MaybeDef && MaybeDef->getOpcode() == SPIRV::ASSIGN_TYPE)
    MaybeDef = MRI.getVRegDef(MaybeDef->getOperand(1).getReg());
  return MaybeDef;
}

bool getVacantFunctionName(Module &M, std::string &Name) {
  // This is a bit paranoid, but we don't want any chance of the loop running
  // for too long.
  constexpr unsigned MaxIters = 1024;
  for (unsigned I = 0; I < MaxIters; ++I) {
    std::string OrdName = Name + Twine(I).str();
    if (!M.getFunction(OrdName)) {
      Name = OrdName;
      return true;
    }
  }
  return false;
}

// Assign a SPIR-V type to the register. If the register has no valid assigned
// class, set the register's LLT type and class according to the SPIR-V type.
void setRegClassType(Register Reg, SPIRVType *SpvType, SPIRVGlobalRegistry *GR,
                     MachineRegisterInfo *MRI, const MachineFunction &MF,
                     bool Force) {
  GR->assignSPIRVTypeToVReg(SpvType, Reg, MF);
  if (!MRI->getRegClassOrNull(Reg) || Force) {
    MRI->setRegClass(Reg, GR->getRegClass(SpvType));
    MRI->setType(Reg, GR->getRegType(SpvType));
  }
}

// Create a SPIR-V type and assign it to the register. If the register has no
// valid assigned class, set the register's LLT type and class according to
// the SPIR-V type.
void setRegClassType(Register Reg, const Type *Ty, SPIRVGlobalRegistry *GR,
                     MachineIRBuilder &MIRBuilder, bool Force) {
  setRegClassType(Reg, GR->getOrCreateSPIRVType(Ty, MIRBuilder), GR,
                  MIRBuilder.getMRI(), MIRBuilder.getMF(), Force);
}

// Create a virtual register and assign the SPIR-V type to it. Set the
// register's LLT type and class according to the SPIR-V type.
Register createVirtualRegister(SPIRVType *SpvType, SPIRVGlobalRegistry *GR,
                               MachineRegisterInfo *MRI,
                               const MachineFunction &MF) {
  Register Reg = MRI->createVirtualRegister(GR->getRegClass(SpvType));
  MRI->setType(Reg, GR->getRegType(SpvType));
  GR->assignSPIRVTypeToVReg(SpvType, Reg, MF);
  return Reg;
}

// Create a virtual register and assign the SPIR-V type to it. Set the
// register's LLT type and class according to the SPIR-V type.
Register createVirtualRegister(SPIRVType *SpvType, SPIRVGlobalRegistry *GR,
                               MachineIRBuilder &MIRBuilder) {
  return createVirtualRegister(SpvType, GR, MIRBuilder.getMRI(),
                               MIRBuilder.getMF());
}

// Create a SPIR-V type and a virtual register, and assign the type to the
// register. Set the register's LLT type and class according to the SPIR-V
// type.
Register createVirtualRegister(const Type *Ty, SPIRVGlobalRegistry *GR,
                               MachineIRBuilder &MIRBuilder) {
  return createVirtualRegister(GR->getOrCreateSPIRVType(Ty, MIRBuilder), GR,
                               MIRBuilder);
}

// Return true if there is an opaque pointer type nested in the argument.
bool isNestedPointer(const Type *Ty) {
  if (Ty->isPtrOrPtrVectorTy())
    return true;
  if (const FunctionType *RefTy = dyn_cast<FunctionType>(Ty)) {
    if (isNestedPointer(RefTy->getReturnType()))
      return true;
    for (const Type *ArgTy : RefTy->params())
      if (isNestedPointer(ArgTy))
        return true;
    return false;
  }
  if (const ArrayType *RefTy = dyn_cast<ArrayType>(Ty))
    return isNestedPointer(RefTy->getElementType());
  return false;
}

bool isSpvIntrinsic(const Value *Arg) {
  if (const auto *II = dyn_cast<IntrinsicInst>(Arg))
    if (Function *F = II->getCalledFunction())
      if (F->getName().starts_with("llvm.spv."))
        return true;
  return false;
}

} // namespace llvm