xref: /llvm-project/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp (revision f9c8c01d38f8fbea81db99ab90b7d0f2bdcc8b4d)
1 //===-- NVPTXAsmPrinter.cpp - NVPTX LLVM assembly writer ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains a printer that converts from our internal representation
10 // of machine-dependent LLVM code to NVPTX assembly language.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "NVPTXAsmPrinter.h"
15 #include "MCTargetDesc/NVPTXBaseInfo.h"
16 #include "MCTargetDesc/NVPTXInstPrinter.h"
17 #include "MCTargetDesc/NVPTXMCAsmInfo.h"
18 #include "MCTargetDesc/NVPTXTargetStreamer.h"
19 #include "NVPTX.h"
20 #include "NVPTXMCExpr.h"
21 #include "NVPTXMachineFunctionInfo.h"
22 #include "NVPTXRegisterInfo.h"
23 #include "NVPTXSubtarget.h"
24 #include "NVPTXTargetMachine.h"
25 #include "NVPTXUtilities.h"
26 #include "TargetInfo/NVPTXTargetInfo.h"
27 #include "cl_common_defines.h"
28 #include "llvm/ADT/APFloat.h"
29 #include "llvm/ADT/APInt.h"
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/DenseSet.h"
32 #include "llvm/ADT/SmallString.h"
33 #include "llvm/ADT/SmallVector.h"
34 #include "llvm/ADT/StringExtras.h"
35 #include "llvm/ADT/StringRef.h"
36 #include "llvm/ADT/Twine.h"
37 #include "llvm/Analysis/ConstantFolding.h"
38 #include "llvm/CodeGen/Analysis.h"
39 #include "llvm/CodeGen/MachineBasicBlock.h"
40 #include "llvm/CodeGen/MachineFrameInfo.h"
41 #include "llvm/CodeGen/MachineFunction.h"
42 #include "llvm/CodeGen/MachineInstr.h"
43 #include "llvm/CodeGen/MachineLoopInfo.h"
44 #include "llvm/CodeGen/MachineModuleInfo.h"
45 #include "llvm/CodeGen/MachineOperand.h"
46 #include "llvm/CodeGen/MachineRegisterInfo.h"
47 #include "llvm/CodeGen/TargetRegisterInfo.h"
48 #include "llvm/CodeGen/ValueTypes.h"
49 #include "llvm/CodeGenTypes/MachineValueType.h"
50 #include "llvm/IR/Attributes.h"
51 #include "llvm/IR/BasicBlock.h"
52 #include "llvm/IR/Constant.h"
53 #include "llvm/IR/Constants.h"
54 #include "llvm/IR/DataLayout.h"
55 #include "llvm/IR/DebugInfo.h"
56 #include "llvm/IR/DebugInfoMetadata.h"
57 #include "llvm/IR/DebugLoc.h"
58 #include "llvm/IR/DerivedTypes.h"
59 #include "llvm/IR/Function.h"
60 #include "llvm/IR/GlobalAlias.h"
61 #include "llvm/IR/GlobalValue.h"
62 #include "llvm/IR/GlobalVariable.h"
63 #include "llvm/IR/Instruction.h"
64 #include "llvm/IR/LLVMContext.h"
65 #include "llvm/IR/Module.h"
66 #include "llvm/IR/Operator.h"
67 #include "llvm/IR/Type.h"
68 #include "llvm/IR/User.h"
69 #include "llvm/MC/MCExpr.h"
70 #include "llvm/MC/MCInst.h"
71 #include "llvm/MC/MCInstrDesc.h"
72 #include "llvm/MC/MCStreamer.h"
73 #include "llvm/MC/MCSymbol.h"
74 #include "llvm/MC/TargetRegistry.h"
75 #include "llvm/Support/Alignment.h"
76 #include "llvm/Support/Casting.h"
77 #include "llvm/Support/CommandLine.h"
78 #include "llvm/Support/Endian.h"
79 #include "llvm/Support/ErrorHandling.h"
80 #include "llvm/Support/NativeFormatting.h"
81 #include "llvm/Support/raw_ostream.h"
82 #include "llvm/Target/TargetLoweringObjectFile.h"
83 #include "llvm/Target/TargetMachine.h"
84 #include "llvm/Transforms/Utils/UnrollLoop.h"
85 #include <cassert>
86 #include <cstdint>
87 #include <cstring>
88 #include <string>
89 #include <utility>
90 #include <vector>
91 
92 using namespace llvm;
93 
// Command-line switch: when enabled, modules with global ctors/dtors are
// accepted and the ctor/dtor lists are lowered to device globals (consumed by
// the runtime) instead of being rejected in doInitialization().
static cl::opt<bool>
    LowerCtorDtor("nvptx-lower-global-ctor-dtor",
                  cl::desc("Lower GPU ctor / dtors to globals on the device."),
                  cl::init(false), cl::Hidden);
98 
99 #define DEPOTNAME "__local_depot"
100 
101 /// DiscoverDependentGlobals - Return a set of GlobalVariables on which \p V
102 /// depends.
103 static void
104 DiscoverDependentGlobals(const Value *V,
105                          DenseSet<const GlobalVariable *> &Globals) {
106   if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(V))
107     Globals.insert(GV);
108   else {
109     if (const User *U = dyn_cast<User>(V)) {
110       for (unsigned i = 0, e = U->getNumOperands(); i != e; ++i) {
111         DiscoverDependentGlobals(U->getOperand(i), Globals);
112       }
113     }
114   }
115 }
116 
117 /// VisitGlobalVariableForEmission - Add \p GV to the list of GlobalVariable
118 /// instances to be emitted, but only after any dependents have been added
119 /// first.s
120 static void
121 VisitGlobalVariableForEmission(const GlobalVariable *GV,
122                                SmallVectorImpl<const GlobalVariable *> &Order,
123                                DenseSet<const GlobalVariable *> &Visited,
124                                DenseSet<const GlobalVariable *> &Visiting) {
125   // Have we already visited this one?
126   if (Visited.count(GV))
127     return;
128 
129   // Do we have a circular dependency?
130   if (!Visiting.insert(GV).second)
131     report_fatal_error("Circular dependency found in global variable set");
132 
133   // Make sure we visit all dependents first
134   DenseSet<const GlobalVariable *> Others;
135   for (unsigned i = 0, e = GV->getNumOperands(); i != e; ++i)
136     DiscoverDependentGlobals(GV->getOperand(i), Others);
137 
138   for (const GlobalVariable *GV : Others)
139     VisitGlobalVariableForEmission(GV, Order, Visited, Visiting);
140 
141   // Now we can visit ourself
142   Order.push_back(GV);
143   Visited.insert(GV);
144   Visiting.erase(GV);
145 }
146 
147 void NVPTXAsmPrinter::emitInstruction(const MachineInstr *MI) {
148   NVPTX_MC::verifyInstructionPredicates(MI->getOpcode(),
149                                         getSubtargetInfo().getFeatureBits());
150 
151   MCInst Inst;
152   lowerToMCInst(MI, Inst);
153   EmitToStreamer(*OutStreamer, Inst);
154 }
155 
156 // Handle symbol backtracking for targets that do not support image handles
157 bool NVPTXAsmPrinter::lowerImageHandleOperand(const MachineInstr *MI,
158                                            unsigned OpNo, MCOperand &MCOp) {
159   const MachineOperand &MO = MI->getOperand(OpNo);
160   const MCInstrDesc &MCID = MI->getDesc();
161 
162   if (MCID.TSFlags & NVPTXII::IsTexFlag) {
163     // This is a texture fetch, so operand 4 is a texref and operand 5 is
164     // a samplerref
165     if (OpNo == 4 && MO.isImm()) {
166       lowerImageHandleSymbol(MO.getImm(), MCOp);
167       return true;
168     }
169     if (OpNo == 5 && MO.isImm() && !(MCID.TSFlags & NVPTXII::IsTexModeUnifiedFlag)) {
170       lowerImageHandleSymbol(MO.getImm(), MCOp);
171       return true;
172     }
173 
174     return false;
175   } else if (MCID.TSFlags & NVPTXII::IsSuldMask) {
176     unsigned VecSize =
177       1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) - 1);
178 
179     // For a surface load of vector size N, the Nth operand will be the surfref
180     if (OpNo == VecSize && MO.isImm()) {
181       lowerImageHandleSymbol(MO.getImm(), MCOp);
182       return true;
183     }
184 
185     return false;
186   } else if (MCID.TSFlags & NVPTXII::IsSustFlag) {
187     // This is a surface store, so operand 0 is a surfref
188     if (OpNo == 0 && MO.isImm()) {
189       lowerImageHandleSymbol(MO.getImm(), MCOp);
190       return true;
191     }
192 
193     return false;
194   } else if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) {
195     // This is a query, so operand 1 is a surfref/texref
196     if (OpNo == 1 && MO.isImm()) {
197       lowerImageHandleSymbol(MO.getImm(), MCOp);
198       return true;
199     }
200 
201     return false;
202   }
203 
204   return false;
205 }
206 
207 void NVPTXAsmPrinter::lowerImageHandleSymbol(unsigned Index, MCOperand &MCOp) {
208   // Ewwww
209   TargetMachine &TM = const_cast<TargetMachine &>(MF->getTarget());
210   NVPTXTargetMachine &nvTM = static_cast<NVPTXTargetMachine &>(TM);
211   const NVPTXMachineFunctionInfo *MFI = MF->getInfo<NVPTXMachineFunctionInfo>();
212   StringRef Sym = MFI->getImageHandleSymbol(Index);
213   StringRef SymName = nvTM.getStrPool().save(Sym);
214   MCOp = GetSymbolRef(OutContext.getOrCreateSymbol(SymName));
215 }
216 
217 void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) {
218   OutMI.setOpcode(MI->getOpcode());
219   // Special: Do not mangle symbol operand of CALL_PROTOTYPE
220   if (MI->getOpcode() == NVPTX::CALL_PROTOTYPE) {
221     const MachineOperand &MO = MI->getOperand(0);
222     OutMI.addOperand(GetSymbolRef(
223       OutContext.getOrCreateSymbol(Twine(MO.getSymbolName()))));
224     return;
225   }
226 
227   for (unsigned i = 0, e = MI->getNumOperands(); i != e; ++i) {
228     const MachineOperand &MO = MI->getOperand(i);
229 
230     MCOperand MCOp;
231     if (lowerImageHandleOperand(MI, i, MCOp)) {
232       OutMI.addOperand(MCOp);
233       continue;
234     }
235 
236     if (lowerOperand(MO, MCOp))
237       OutMI.addOperand(MCOp);
238   }
239 }
240 
241 bool NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO,
242                                    MCOperand &MCOp) {
243   switch (MO.getType()) {
244   default: llvm_unreachable("unknown operand type");
245   case MachineOperand::MO_Register:
246     MCOp = MCOperand::createReg(encodeVirtualRegister(MO.getReg()));
247     break;
248   case MachineOperand::MO_Immediate:
249     MCOp = MCOperand::createImm(MO.getImm());
250     break;
251   case MachineOperand::MO_MachineBasicBlock:
252     MCOp = MCOperand::createExpr(MCSymbolRefExpr::create(
253         MO.getMBB()->getSymbol(), OutContext));
254     break;
255   case MachineOperand::MO_ExternalSymbol:
256     MCOp = GetSymbolRef(GetExternalSymbolSymbol(MO.getSymbolName()));
257     break;
258   case MachineOperand::MO_GlobalAddress:
259     MCOp = GetSymbolRef(getSymbol(MO.getGlobal()));
260     break;
261   case MachineOperand::MO_FPImmediate: {
262     const ConstantFP *Cnt = MO.getFPImm();
263     const APFloat &Val = Cnt->getValueAPF();
264 
265     switch (Cnt->getType()->getTypeID()) {
266     default: report_fatal_error("Unsupported FP type"); break;
267     case Type::HalfTyID:
268       MCOp = MCOperand::createExpr(
269         NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext));
270       break;
271     case Type::BFloatTyID:
272       MCOp = MCOperand::createExpr(
273           NVPTXFloatMCExpr::createConstantBFPHalf(Val, OutContext));
274       break;
275     case Type::FloatTyID:
276       MCOp = MCOperand::createExpr(
277         NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext));
278       break;
279     case Type::DoubleTyID:
280       MCOp = MCOperand::createExpr(
281         NVPTXFloatMCExpr::createConstantFPDouble(Val, OutContext));
282       break;
283     }
284     break;
285   }
286   }
287   return true;
288 }
289 
290 unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) {
291   if (Register::isVirtualRegister(Reg)) {
292     const TargetRegisterClass *RC = MRI->getRegClass(Reg);
293 
294     DenseMap<unsigned, unsigned> &RegMap = VRegMapping[RC];
295     unsigned RegNum = RegMap[Reg];
296 
297     // Encode the register class in the upper 4 bits
298     // Must be kept in sync with NVPTXInstPrinter::printRegName
299     unsigned Ret = 0;
300     if (RC == &NVPTX::Int1RegsRegClass) {
301       Ret = (1 << 28);
302     } else if (RC == &NVPTX::Int16RegsRegClass) {
303       Ret = (2 << 28);
304     } else if (RC == &NVPTX::Int32RegsRegClass) {
305       Ret = (3 << 28);
306     } else if (RC == &NVPTX::Int64RegsRegClass) {
307       Ret = (4 << 28);
308     } else if (RC == &NVPTX::Float32RegsRegClass) {
309       Ret = (5 << 28);
310     } else if (RC == &NVPTX::Float64RegsRegClass) {
311       Ret = (6 << 28);
312     } else if (RC == &NVPTX::Int128RegsRegClass) {
313       Ret = (7 << 28);
314     } else {
315       report_fatal_error("Bad register class");
316     }
317 
318     // Insert the vreg number
319     Ret |= (RegNum & 0x0FFFFFFF);
320     return Ret;
321   } else {
322     // Some special-use registers are actually physical registers.
323     // Encode this as the register class ID of 0 and the real register ID.
324     return Reg & 0x0FFFFFFF;
325   }
326 }
327 
328 MCOperand NVPTXAsmPrinter::GetSymbolRef(const MCSymbol *Symbol) {
329   const MCExpr *Expr;
330   Expr = MCSymbolRefExpr::create(Symbol, MCSymbolRefExpr::VK_None,
331                                  OutContext);
332   return MCOperand::createExpr(Expr);
333 }
334 
335 static bool ShouldPassAsArray(Type *Ty) {
336   return Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128) ||
337          Ty->isHalfTy() || Ty->isBFloatTy();
338 }
339 
// Print the PTX return-value declaration for F (e.g. " (.param .b32
// func_retval0) ") into O. Prints nothing for a void return type.
void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
  const DataLayout &DL = getDataLayout();
  const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
  const auto *TLI = cast<NVPTXTargetLowering>(STI.getTargetLowering());

  Type *Ty = F->getReturnType();

  // sm_20+ follows the PTX ABI (.param-based return); older targets fall
  // back to .reg-based returns below.
  bool isABI = (STI.getSmVersion() >= 20);

  if (Ty->getTypeID() == Type::VoidTyID)
    return;
  O << " (";

  if (isABI) {
    if ((Ty->isFloatingPointTy() || Ty->isIntegerTy()) &&
        !ShouldPassAsArray(Ty)) {
      // Scalar int/FP return: one .param .b<N> value, with sub-register-size
      // scalars promoted to the minimum PTX register width.
      unsigned size = 0;
      if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
        size = ITy->getBitWidth();
      } else {
        assert(Ty->isFloatingPointTy() && "Floating point type expected here");
        size = Ty->getPrimitiveSizeInBits();
      }
      size = promoteScalarArgumentSize(size);
      O << ".param .b" << size << " func_retval0";
    } else if (isa<PointerType>(Ty)) {
      // Pointers are returned as an integer of pointer width.
      O << ".param .b" << TLI->getPointerTy(DL).getSizeInBits()
        << " func_retval0";
    } else if (ShouldPassAsArray(Ty)) {
      // Aggregates, vectors, i128 and 16-bit FP are returned as an aligned
      // .b8 array sized by the type's alloc size.
      unsigned totalsz = DL.getTypeAllocSize(Ty);
      Align RetAlignment = TLI->getFunctionArgumentAlignment(
          F, Ty, AttributeList::ReturnIndex, DL);
      O << ".param .align " << RetAlignment.value() << " .b8 func_retval0["
        << totalsz << "]";
    } else
      llvm_unreachable("Unknown return type");
  } else {
    // Pre-ABI path: expand the return type into its legal value types and
    // emit one .reg per scalar element, named func_retval<idx>.
    SmallVector<EVT, 16> vtparts;
    ComputeValueVTs(*TLI, DL, Ty, vtparts);
    unsigned idx = 0;
    for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
      unsigned elems = 1;
      EVT elemtype = vtparts[i];
      // Vectors are flattened into their scalar elements.
      if (vtparts[i].isVector()) {
        elems = vtparts[i].getVectorNumElements();
        elemtype = vtparts[i].getVectorElementType();
      }

      for (unsigned j = 0, je = elems; j != je; ++j) {
        unsigned sz = elemtype.getSizeInBits();
        if (elemtype.isInteger())
          sz = promoteScalarArgumentSize(sz);
        O << ".reg .b" << sz << " func_retval" << idx;
        if (j < je - 1)
          O << ", ";
        ++idx;
      }
      if (i < e - 1)
        O << ", ";
    }
  }
  O << ") ";
}
403 
404 void NVPTXAsmPrinter::printReturnValStr(const MachineFunction &MF,
405                                         raw_ostream &O) {
406   const Function &F = MF.getFunction();
407   printReturnValStr(&F, O);
408 }
409 
410 // Return true if MBB is the header of a loop marked with
411 // llvm.loop.unroll.disable or llvm.loop.unroll.count=1.
412 bool NVPTXAsmPrinter::isLoopHeaderOfNoUnroll(
413     const MachineBasicBlock &MBB) const {
414   MachineLoopInfo &LI = getAnalysis<MachineLoopInfoWrapperPass>().getLI();
415   // We insert .pragma "nounroll" only to the loop header.
416   if (!LI.isLoopHeader(&MBB))
417     return false;
418 
419   // llvm.loop.unroll.disable is marked on the back edges of a loop. Therefore,
420   // we iterate through each back edge of the loop with header MBB, and check
421   // whether its metadata contains llvm.loop.unroll.disable.
422   for (const MachineBasicBlock *PMBB : MBB.predecessors()) {
423     if (LI.getLoopFor(PMBB) != LI.getLoopFor(&MBB)) {
424       // Edges from other loops to MBB are not back edges.
425       continue;
426     }
427     if (const BasicBlock *PBB = PMBB->getBasicBlock()) {
428       if (MDNode *LoopID =
429               PBB->getTerminator()->getMetadata(LLVMContext::MD_loop)) {
430         if (GetUnrollMetadata(LoopID, "llvm.loop.unroll.disable"))
431           return true;
432         if (MDNode *UnrollCountMD =
433                 GetUnrollMetadata(LoopID, "llvm.loop.unroll.count")) {
434           if (mdconst::extract<ConstantInt>(UnrollCountMD->getOperand(1))
435                   ->isOne())
436             return true;
437         }
438       }
439     }
440   }
441   return false;
442 }
443 
444 void NVPTXAsmPrinter::emitBasicBlockStart(const MachineBasicBlock &MBB) {
445   AsmPrinter::emitBasicBlockStart(MBB);
446   if (isLoopHeaderOfNoUnroll(MBB))
447     OutStreamer->emitRawText(StringRef("\t.pragma \"nounroll\";\n"));
448 }
449 
// Emit the full PTX function signature (.entry/.func line, parameter list,
// kernel directives) followed by the opening brace of the function body.
void NVPTXAsmPrinter::emitFunctionEntryLabel() {
  SmallString<128> Str;
  raw_svector_ostream O(Str);

  // Module-level globals are emitted lazily, just before the first function,
  // so that all declarations they depend on have been collected.
  if (!GlobalsEmitted) {
    emitGlobals(*MF->getFunction().getParent());
    GlobalsEmitted = true;
  }

  // Set up
  MRI = &MF->getRegInfo();
  F = &MF->getFunction();
  emitLinkageDirective(F, O);
  // Kernels are declared ".entry"; device functions ".func" with an explicit
  // return-value declaration.
  if (isKernelFunction(*F))
    O << ".entry ";
  else {
    O << ".func ";
    printReturnValStr(*MF, O);
  }

  CurrentFnSym->print(O, MAI);

  emitFunctionParamList(F, O);
  O << "\n";

  // Launch-bound directives (.reqntid/.maxntid/...) only apply to kernels.
  if (isKernelFunction(*F))
    emitKernelFunctionDirectives(*F, O);

  if (shouldEmitPTXNoReturn(F, TM))
    O << ".noreturn";

  OutStreamer->emitRawText(O.str());

  // The virtual register numbering restarts for every function.
  VRegMapping.clear();
  // Emit open brace for function body.
  OutStreamer->emitRawText(StringRef("{\n"));
  setAndEmitFunctionVirtualRegisters(*MF);
  encodeDebugInfoRegisterNumbers(*MF);
  // Emit initial .loc debug directive for correct relocation symbol data.
  if (const DISubprogram *SP = MF->getFunction().getSubprogram()) {
    assert(SP->getUnit());
    if (!SP->getUnit()->isDebugDirectivesOnly())
      emitInitialRawDwarfLocDirective(*MF);
  }
}
495 
496 bool NVPTXAsmPrinter::runOnMachineFunction(MachineFunction &F) {
497   bool Result = AsmPrinter::runOnMachineFunction(F);
498   // Emit closing brace for the body of function F.
499   // The closing brace must be emitted here because we need to emit additional
500   // debug labels/data after the last basic block.
501   // We need to emit the closing brace here because we don't have function that
502   // finished emission of the function body.
503   OutStreamer->emitRawText(StringRef("}\n"));
504   return Result;
505 }
506 
507 void NVPTXAsmPrinter::emitFunctionBodyStart() {
508   SmallString<128> Str;
509   raw_svector_ostream O(Str);
510   emitDemotedVars(&MF->getFunction(), O);
511   OutStreamer->emitRawText(O.str());
512 }
513 
514 void NVPTXAsmPrinter::emitFunctionBodyEnd() {
515   VRegMapping.clear();
516 }
517 
518 const MCSymbol *NVPTXAsmPrinter::getFunctionFrameSymbol() const {
519     SmallString<128> Str;
520     raw_svector_ostream(Str) << DEPOTNAME << getFunctionNumber();
521     return OutContext.getOrCreateSymbol(Str);
522 }
523 
524 void NVPTXAsmPrinter::emitImplicitDef(const MachineInstr *MI) const {
525   Register RegNo = MI->getOperand(0).getReg();
526   if (RegNo.isVirtual()) {
527     OutStreamer->AddComment(Twine("implicit-def: ") +
528                             getVirtualRegisterName(RegNo));
529   } else {
530     const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>();
531     OutStreamer->AddComment(Twine("implicit-def: ") +
532                             STI.getRegisterInfo()->getName(RegNo));
533   }
534   OutStreamer->addBlankLine();
535 }
536 
// Print the kernel launch-bound directives (.reqntid, .maxntid,
// .minnctapersm, .maxnreg, and the SM_90+ cluster directives) for F into O.
void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F,
                                                   raw_ostream &O) const {
  // If the NVVM IR has some of reqntid* specified, then output
  // the reqntid directive, and set the unspecified ones to 1.
  // If none of Reqntid* is specified, don't output reqntid directive.
  std::optional<unsigned> Reqntidx = getReqNTIDx(F);
  std::optional<unsigned> Reqntidy = getReqNTIDy(F);
  std::optional<unsigned> Reqntidz = getReqNTIDz(F);

  if (Reqntidx || Reqntidy || Reqntidz)
    O << ".reqntid " << Reqntidx.value_or(1) << ", " << Reqntidy.value_or(1)
      << ", " << Reqntidz.value_or(1) << "\n";

  // If the NVVM IR has some of maxntid* specified, then output
  // the maxntid directive, and set the unspecified ones to 1.
  // If none of maxntid* is specified, don't output maxntid directive.
  std::optional<unsigned> Maxntidx = getMaxNTIDx(F);
  std::optional<unsigned> Maxntidy = getMaxNTIDy(F);
  std::optional<unsigned> Maxntidz = getMaxNTIDz(F);

  if (Maxntidx || Maxntidy || Maxntidz)
    O << ".maxntid " << Maxntidx.value_or(1) << ", " << Maxntidy.value_or(1)
      << ", " << Maxntidz.value_or(1) << "\n";

  if (const auto Mincta = getMinCTASm(F))
    O << ".minnctapersm " << *Mincta << "\n";

  if (const auto Maxnreg = getMaxNReg(F))
    O << ".maxnreg " << *Maxnreg << "\n";

  // .maxclusterrank directive requires SM_90 or higher, make sure that we
  // filter it out for lower SM versions, as it causes a hard ptxas crash.
  const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
  const auto *STI = static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());

  if (STI->getSmVersion() >= 90) {
    std::optional<unsigned> ClusterX = getClusterDimx(F);
    std::optional<unsigned> ClusterY = getClusterDimy(F);
    std::optional<unsigned> ClusterZ = getClusterDimz(F);

    if (ClusterX || ClusterY || ClusterZ) {
      O << ".explicitcluster\n";
      // A cluster_dim_x of 0 means the cluster shape is decided at launch
      // time; otherwise all three dimensions must be concrete (non-zero).
      if (ClusterX.value_or(1) != 0) {
        assert(ClusterY.value_or(1) && ClusterZ.value_or(1) &&
               "cluster_dim_x != 0 implies cluster_dim_y and cluster_dim_z "
               "should be non-zero as well");

        O << ".reqnctapercluster " << ClusterX.value_or(1) << ", "
          << ClusterY.value_or(1) << ", " << ClusterZ.value_or(1) << "\n";
      } else {
        assert(!ClusterY.value_or(1) && !ClusterZ.value_or(1) &&
               "cluster_dim_x == 0 implies cluster_dim_y and cluster_dim_z "
               "should be 0 as well");
      }
    }
    if (const auto Maxclusterrank = getMaxClusterRank(F))
      O << ".maxclusterrank " << *Maxclusterrank << "\n";
  }
}
596 
597 std::string NVPTXAsmPrinter::getVirtualRegisterName(unsigned Reg) const {
598   const TargetRegisterClass *RC = MRI->getRegClass(Reg);
599 
600   std::string Name;
601   raw_string_ostream NameStr(Name);
602 
603   VRegRCMap::const_iterator I = VRegMapping.find(RC);
604   assert(I != VRegMapping.end() && "Bad register class");
605   const DenseMap<unsigned, unsigned> &RegMap = I->second;
606 
607   VRegMap::const_iterator VI = RegMap.find(Reg);
608   assert(VI != RegMap.end() && "Bad virtual register");
609   unsigned MappedVR = VI->second;
610 
611   NameStr << getNVPTXRegClassStr(RC) << MappedVR;
612 
613   return Name;
614 }
615 
616 void NVPTXAsmPrinter::emitVirtualRegister(unsigned int vr,
617                                           raw_ostream &O) {
618   O << getVirtualRegisterName(vr);
619 }
620 
621 void NVPTXAsmPrinter::emitAliasDeclaration(const GlobalAlias *GA,
622                                            raw_ostream &O) {
623   const Function *F = dyn_cast_or_null<Function>(GA->getAliaseeObject());
624   if (!F || isKernelFunction(*F) || F->isDeclaration())
625     report_fatal_error(
626         "NVPTX aliasee must be a non-kernel function definition");
627 
628   if (GA->hasLinkOnceLinkage() || GA->hasWeakLinkage() ||
629       GA->hasAvailableExternallyLinkage() || GA->hasCommonLinkage())
630     report_fatal_error("NVPTX aliasee must not be '.weak'");
631 
632   emitDeclarationWithName(F, getSymbol(GA), O);
633 }
634 
635 void NVPTXAsmPrinter::emitDeclaration(const Function *F, raw_ostream &O) {
636   emitDeclarationWithName(F, getSymbol(F), O);
637 }
638 
639 void NVPTXAsmPrinter::emitDeclarationWithName(const Function *F, MCSymbol *S,
640                                               raw_ostream &O) {
641   emitLinkageDirective(F, O);
642   if (isKernelFunction(*F))
643     O << ".entry ";
644   else
645     O << ".func ";
646   printReturnValStr(F, O);
647   S->print(O, MAI);
648   O << "\n";
649   emitFunctionParamList(F, O);
650   O << "\n";
651   if (shouldEmitPTXNoReturn(F, TM))
652     O << ".noreturn";
653   O << ";\n";
654 }
655 
656 static bool usedInGlobalVarDef(const Constant *C) {
657   if (!C)
658     return false;
659 
660   if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(C)) {
661     return GV->getName() != "llvm.used";
662   }
663 
664   for (const User *U : C->users())
665     if (const Constant *C = dyn_cast<Constant>(U))
666       if (usedInGlobalVarDef(C))
667         return true;
668 
669   return false;
670 }
671 
672 static bool usedInOneFunc(const User *U, Function const *&oneFunc) {
673   if (const GlobalVariable *othergv = dyn_cast<GlobalVariable>(U)) {
674     if (othergv->getName() == "llvm.used")
675       return true;
676   }
677 
678   if (const Instruction *instr = dyn_cast<Instruction>(U)) {
679     if (instr->getParent() && instr->getParent()->getParent()) {
680       const Function *curFunc = instr->getParent()->getParent();
681       if (oneFunc && (curFunc != oneFunc))
682         return false;
683       oneFunc = curFunc;
684       return true;
685     } else
686       return false;
687   }
688 
689   for (const User *UU : U->users())
690     if (!usedInOneFunc(UU, oneFunc))
691       return false;
692 
693   return true;
694 }
695 
696 /* Find out if a global variable can be demoted to local scope.
697  * Currently, this is valid for CUDA shared variables, which have local
698  * scope and global lifetime. So the conditions to check are :
699  * 1. Is the global variable in shared address space?
700  * 2. Does it have local linkage?
701  * 3. Is the global variable referenced only in one function?
702  */
703 static bool canDemoteGlobalVar(const GlobalVariable *gv, Function const *&f) {
704   if (!gv->hasLocalLinkage())
705     return false;
706   PointerType *Pty = gv->getType();
707   if (Pty->getAddressSpace() != ADDRESS_SPACE_SHARED)
708     return false;
709 
710   const Function *oneFunc = nullptr;
711 
712   bool flag = usedInOneFunc(gv, oneFunc);
713   if (!flag)
714     return false;
715   if (!oneFunc)
716     return false;
717   f = oneFunc;
718   return true;
719 }
720 
721 static bool useFuncSeen(const Constant *C,
722                         DenseMap<const Function *, bool> &seenMap) {
723   for (const User *U : C->users()) {
724     if (const Constant *cu = dyn_cast<Constant>(U)) {
725       if (useFuncSeen(cu, seenMap))
726         return true;
727     } else if (const Instruction *I = dyn_cast<Instruction>(U)) {
728       const BasicBlock *bb = I->getParent();
729       if (!bb)
730         continue;
731       const Function *caller = bb->getParent();
732       if (!caller)
733         continue;
734       if (seenMap.contains(caller))
735         return true;
736     }
737   }
738   return false;
739 }
740 
// Emit forward declarations for functions that are referenced before their
// definitions appear in the module (ptxas disallows forward references).
void NVPTXAsmPrinter::emitDeclarations(const Module &M, raw_ostream &O) {
  // Functions whose definitions have already been walked. A use coming from
  // one of these means the user precedes the used function in module order.
  DenseMap<const Function *, bool> seenMap;
  for (const Function &F : M) {
    // Libcall callees are always declared, regardless of visible uses.
    if (F.getAttributes().hasFnAttr("nvptx-libcall-callee")) {
      emitDeclaration(&F, O);
      continue;
    }

    if (F.isDeclaration()) {
      // External functions are declared only if actually used; intrinsics
      // never need a PTX declaration.
      if (F.use_empty())
        continue;
      if (F.getIntrinsicID())
        continue;
      emitDeclaration(&F, O);
      continue;
    }
    for (const User *U : F.users()) {
      if (const Constant *C = dyn_cast<Constant>(U)) {
        if (usedInGlobalVarDef(C)) {
          // The use is in the initialization of a global variable
          // that is a function pointer, so print a declaration
          // for the original function
          emitDeclaration(&F, O);
          break;
        }
        // Emit a declaration of this function if the function that
        // uses this constant expr has already been seen.
        if (useFuncSeen(C, seenMap)) {
          emitDeclaration(&F, O);
          break;
        }
      }

      if (!isa<Instruction>(U))
        continue;
      const Instruction *instr = cast<Instruction>(U);
      const BasicBlock *bb = instr->getParent();
      if (!bb)
        continue;
      const Function *caller = bb->getParent();
      if (!caller)
        continue;

      // If a caller has already been seen, then the caller is
      // appearing in the module before the callee. so print out
      // a declaration for the callee.
      if (seenMap.contains(caller)) {
        emitDeclaration(&F, O);
        break;
      }
    }
    seenMap[&F] = true;
  }
  // Aliases always need their aliasee declared up front.
  for (const GlobalAlias &GA : M.aliases())
    emitAliasDeclaration(&GA, O);
}
797 
798 static bool isEmptyXXStructor(GlobalVariable *GV) {
799   if (!GV) return true;
800   const ConstantArray *InitList = dyn_cast<ConstantArray>(GV->getInitializer());
801   if (!InitList) return true;  // Not an array; we don't know how to parse.
802   return InitList->getNumOperands() == 0;
803 }
804 
805 void NVPTXAsmPrinter::emitStartOfAsmFile(Module &M) {
806   // Construct a default subtarget off of the TargetMachine defaults. The
807   // rest of NVPTX isn't friendly to change subtargets per function and
808   // so the default TargetMachine will have all of the options.
809   const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
810   const auto* STI = static_cast<const NVPTXSubtarget*>(NTM.getSubtargetImpl());
811   SmallString<128> Str1;
812   raw_svector_ostream OS1(Str1);
813 
814   // Emit header before any dwarf directives are emitted below.
815   emitHeader(M, OS1, *STI);
816   OutStreamer->emitRawText(OS1.str());
817 }
818 
819 bool NVPTXAsmPrinter::doInitialization(Module &M) {
820   const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
821   const NVPTXSubtarget &STI =
822       *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
823   if (M.alias_size() && (STI.getPTXVersion() < 63 || STI.getSmVersion() < 30))
824     report_fatal_error(".alias requires PTX version >= 6.3 and sm_30");
825 
826   // OpenMP supports NVPTX global constructors and destructors.
827   bool IsOpenMP = M.getModuleFlag("openmp") != nullptr;
828 
829   if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_ctors")) &&
830       !LowerCtorDtor && !IsOpenMP) {
831     report_fatal_error(
832         "Module has a nontrivial global ctor, which NVPTX does not support.");
833     return true;  // error
834   }
835   if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_dtors")) &&
836       !LowerCtorDtor && !IsOpenMP) {
837     report_fatal_error(
838         "Module has a nontrivial global dtor, which NVPTX does not support.");
839     return true;  // error
840   }
841 
842   // We need to call the parent's one explicitly.
843   bool Result = AsmPrinter::doInitialization(M);
844 
845   GlobalsEmitted = false;
846 
847   return Result;
848 }
849 
// Emit all module-level global variables (preceded by any function
// declarations they require), ordered so every global appears after the
// globals it references.
void NVPTXAsmPrinter::emitGlobals(const Module &M) {
  SmallString<128> Str2;
  raw_svector_ostream OS2(Str2);

  emitDeclarations(M, OS2);

  // As ptxas does not support forward references of globals, we need to first
  // sort the list of module-level globals in def-use order. We visit each
  // global variable in order, and ensure that we emit it *after* its dependent
  // globals. We use a little extra memory maintaining both a set and a list to
  // have fast searches while maintaining a strict ordering.
  SmallVector<const GlobalVariable *, 8> Globals;
  DenseSet<const GlobalVariable *> GVVisited;
  DenseSet<const GlobalVariable *> GVVisiting;

  // Visit each global variable, in order
  for (const GlobalVariable &I : M.globals())
    VisitGlobalVariableForEmission(&I, Globals, GVVisited, GVVisiting);

  // Sanity: the DFS must have covered every global and left no cycle state.
  assert(GVVisited.size() == M.global_size() && "Missed a global variable");
  assert(GVVisiting.size() == 0 && "Did not fully process a global variable");

  const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
  const NVPTXSubtarget &STI =
      *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());

  // Print out module-level global variables in proper order
  for (const GlobalVariable *GV : Globals)
    printModuleLevelGV(GV, OS2, /*processDemoted=*/false, STI);

  OS2 << '\n';

  OutStreamer->emitRawText(OS2.str());
}
884 
885 void NVPTXAsmPrinter::emitGlobalAlias(const Module &M, const GlobalAlias &GA) {
886   SmallString<128> Str;
887   raw_svector_ostream OS(Str);
888 
889   MCSymbol *Name = getSymbol(&GA);
890 
891   OS << ".alias " << Name->getName() << ", " << GA.getAliaseeObject()->getName()
892      << ";\n";
893 
894   OutStreamer->emitRawText(OS.str());
895 }
896 
897 void NVPTXAsmPrinter::emitHeader(Module &M, raw_ostream &O,
898                                  const NVPTXSubtarget &STI) {
899   O << "//\n";
900   O << "// Generated by LLVM NVPTX Back-End\n";
901   O << "//\n";
902   O << "\n";
903 
904   unsigned PTXVersion = STI.getPTXVersion();
905   O << ".version " << (PTXVersion / 10) << "." << (PTXVersion % 10) << "\n";
906 
907   O << ".target ";
908   O << STI.getTargetName();
909 
910   const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
911   if (NTM.getDrvInterface() == NVPTX::NVCL)
912     O << ", texmode_independent";
913 
914   bool HasFullDebugInfo = false;
915   for (DICompileUnit *CU : M.debug_compile_units()) {
916     switch(CU->getEmissionKind()) {
917     case DICompileUnit::NoDebug:
918     case DICompileUnit::DebugDirectivesOnly:
919       break;
920     case DICompileUnit::LineTablesOnly:
921     case DICompileUnit::FullDebug:
922       HasFullDebugInfo = true;
923       break;
924     }
925     if (HasFullDebugInfo)
926       break;
927   }
928   if (HasFullDebugInfo)
929     O << ", debug";
930 
931   O << "\n";
932 
933   O << ".address_size ";
934   if (NTM.is64Bit())
935     O << "64";
936   else
937     O << "32";
938   O << "\n";
939 
940   O << "\n";
941 }
942 
943 bool NVPTXAsmPrinter::doFinalization(Module &M) {
944   // If we did not emit any functions, then the global declarations have not
945   // yet been emitted.
946   if (!GlobalsEmitted) {
947     emitGlobals(M);
948     GlobalsEmitted = true;
949   }
950 
951   // call doFinalization
952   bool ret = AsmPrinter::doFinalization(M);
953 
954   clearAnnotationCache(&M);
955 
956   auto *TS =
957       static_cast<NVPTXTargetStreamer *>(OutStreamer->getTargetStreamer());
958   // Close the last emitted section
959   if (hasDebugInfo()) {
960     TS->closeLastSection();
961     // Emit empty .debug_macinfo section for better support of the empty files.
962     OutStreamer->emitRawText("\t.section\t.debug_macinfo\t{\t}");
963   }
964 
965   // Output last DWARF .file directives, if any.
966   TS->outputDwarfFileDirectives();
967 
968   return ret;
969 }
970 
971 // This function emits appropriate linkage directives for
972 // functions and global variables.
973 //
974 // extern function declaration            -> .extern
975 // extern function definition             -> .visible
976 // external global variable with init     -> .visible
977 // external without init                  -> .extern
978 // appending                              -> not allowed, assert.
// for any linkage other than
// internal and private (the linker_private linkage
// family no longer exists in LLVM),
// we emit                                -> .weak.
983 
984 void NVPTXAsmPrinter::emitLinkageDirective(const GlobalValue *V,
985                                            raw_ostream &O) {
986   if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() == NVPTX::CUDA) {
987     if (V->hasExternalLinkage()) {
988       if (isa<GlobalVariable>(V)) {
989         const GlobalVariable *GVar = cast<GlobalVariable>(V);
990         if (GVar) {
991           if (GVar->hasInitializer())
992             O << ".visible ";
993           else
994             O << ".extern ";
995         }
996       } else if (V->isDeclaration())
997         O << ".extern ";
998       else
999         O << ".visible ";
1000     } else if (V->hasAppendingLinkage()) {
1001       std::string msg;
1002       msg.append("Error: ");
1003       msg.append("Symbol ");
1004       if (V->hasName())
1005         msg.append(std::string(V->getName()));
1006       msg.append("has unsupported appending linkage type");
1007       llvm_unreachable(msg.c_str());
1008     } else if (!V->hasInternalLinkage() &&
1009                !V->hasPrivateLinkage()) {
1010       O << ".weak ";
1011     }
1012   }
1013 }
1014 
// Emit one module-level global variable in PTX syntax: linkage directive,
// state space, optional .attribute(.managed), alignment, type, name, array
// bounds, and (where the state space allows) an initializer. Textures,
// surfaces and samplers get their own opaque-ref forms. With
// \p processDemoted == false, globals that can be demoted to function scope
// are only recorded in localDecls (printed later by emitDemotedVars).
void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
                                         raw_ostream &O, bool processDemoted,
                                         const NVPTXSubtarget &STI) {
  // Skip meta data
  if (GVar->hasSection()) {
    if (GVar->getSection() == "llvm.metadata")
      return;
  }

  // Skip LLVM intrinsic global variables
  if (GVar->getName().starts_with("llvm.") ||
      GVar->getName().starts_with("nvvm."))
    return;

  const DataLayout &DL = getDataLayout();

  // GlobalVariables are always constant pointers themselves.
  Type *ETy = GVar->getValueType();

  // Linkage directive. .common is only available from PTX 5.0 and only in the
  // global state space; otherwise common falls through to .weak below.
  if (GVar->hasExternalLinkage()) {
    if (GVar->hasInitializer())
      O << ".visible ";
    else
      O << ".extern ";
  } else if (STI.getPTXVersion() >= 50 && GVar->hasCommonLinkage() &&
             GVar->getAddressSpace() == ADDRESS_SPACE_GLOBAL) {
    O << ".common ";
  } else if (GVar->hasLinkOnceLinkage() || GVar->hasWeakLinkage() ||
             GVar->hasAvailableExternallyLinkage() ||
             GVar->hasCommonLinkage()) {
    O << ".weak ";
  }

  // Textures and surfaces are emitted as opaque references, no initializer.
  if (isTexture(*GVar)) {
    O << ".global .texref " << getTextureName(*GVar) << ";\n";
    return;
  }

  if (isSurface(*GVar)) {
    O << ".global .surfref " << getSurfaceName(*GVar) << ";\n";
    return;
  }

  if (GVar->isDeclaration()) {
    // (extern) declarations, no definition or initializer
    // Currently the only known declaration is for an automatic __local
    // (.shared) promoted to global.
    emitPTXGlobalVariable(GVar, O, STI);
    O << ";\n";
    return;
  }

  if (isSampler(*GVar)) {
    O << ".global .samplerref " << getSamplerName(*GVar);

    // A sampler's initializer, if present, is a packed integer whose fields
    // (address mode, filter mode, normalized flag) are decoded via the
    // __CLK_* masks from cl_common_defines.h.
    const Constant *Initializer = nullptr;
    if (GVar->hasInitializer())
      Initializer = GVar->getInitializer();
    const ConstantInt *CI = nullptr;
    if (Initializer)
      CI = dyn_cast<ConstantInt>(Initializer);
    if (CI) {
      unsigned sample = CI->getZExtValue();

      O << " = { ";

      // NOTE(review): `addr` is computed once and never shifted inside the
      // loop, so all three addr_mode_<i> fields print the same mode —
      // confirm whether per-dimension modes were intended.
      for (int i = 0,
               addr = ((sample & __CLK_ADDRESS_MASK) >> __CLK_ADDRESS_BASE);
           i < 3; i++) {
        O << "addr_mode_" << i << " = ";
        switch (addr) {
        case 0:
          O << "wrap";
          break;
        case 1:
          O << "clamp_to_border";
          break;
        case 2:
          O << "clamp_to_edge";
          break;
        case 3:
          O << "wrap";
          break;
        case 4:
          O << "mirror";
          break;
        }
        O << ", ";
      }
      O << "filter_mode = ";
      switch ((sample & __CLK_FILTER_MASK) >> __CLK_FILTER_BASE) {
      case 0:
        O << "nearest";
        break;
      case 1:
        O << "linear";
        break;
      case 2:
        llvm_unreachable("Anisotropic filtering is not supported");
      default:
        O << "nearest";
        break;
      }
      // A cleared normalized bit forces unnormalized coordinates.
      if (!((sample & __CLK_NORMALIZED_MASK) >> __CLK_NORMALIZED_BASE)) {
        O << ", force_unnormalized_coords = 1";
      }
      O << " }";
    }

    O << ";\n";
    return;
  }

  // Certain private frontend-generated helper globals are suppressed.
  if (GVar->hasPrivateLinkage()) {
    if (strncmp(GVar->getName().data(), "unrollpragma", 12) == 0)
      return;

    // FIXME - need better way (e.g. Metadata) to avoid generating this global
    if (strncmp(GVar->getName().data(), "filename", 8) == 0)
      return;
    if (GVar->use_empty())
      return;
  }

  // Demotable globals (used by a single function) are deferred: record them
  // for emitDemotedVars and emit only a comment here.
  const Function *demotedFunc = nullptr;
  if (!processDemoted && canDemoteGlobalVar(GVar, demotedFunc)) {
    O << "// " << GVar->getName() << " has been demoted\n";
    localDecls[demotedFunc].push_back(GVar);
    return;
  }

  O << ".";
  emitPTXAddressSpace(GVar->getAddressSpace(), O);

  if (isManaged(*GVar)) {
    if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) {
      report_fatal_error(
          ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
    }
    O << " .attribute(.managed)";
  }

  // Explicit alignment if present, otherwise the preferred alignment of the
  // value type from the data layout.
  if (MaybeAlign A = GVar->getAlign())
    O << " .align " << A->value();
  else
    O << " .align " << (int)DL.getPrefTypeAlign(ETy).value();

  // Scalar path: floats, pointers, and integers up to 64 bits get a PTX
  // fundamental type; everything else is lowered to a byte array below.
  if (ETy->isFloatingPointTy() || ETy->isPointerTy() ||
      (ETy->isIntegerTy() && ETy->getScalarSizeInBits() <= 64)) {
    O << " .";
    // Special case: ABI requires that we use .u8 for predicates
    if (ETy->isIntegerTy(1))
      O << "u8";
    else
      O << getPTXFundamentalTypeStr(ETy, false);
    O << " ";
    getSymbol(GVar)->print(O, MAI);

    // Ptx allows variable initilization only for constant and global state
    // spaces.
    if (GVar->hasInitializer()) {
      if ((GVar->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
          (GVar->getAddressSpace() == ADDRESS_SPACE_CONST)) {
        const Constant *Initializer = GVar->getInitializer();
        // 'undef' is treated as there is no value specified.
        if (!Initializer->isNullValue() && !isa<UndefValue>(Initializer)) {
          O << " = ";
          printScalarConstant(Initializer, O);
        }
      } else {
        // The frontend adds zero-initializer to device and constant variables
        // that don't have an initial value, and UndefValue to shared
        // variables, so skip warning for this case.
        if (!GVar->getInitializer()->isNullValue() &&
            !isa<UndefValue>(GVar->getInitializer())) {
          report_fatal_error("initial value of '" + GVar->getName() +
                             "' is not allowed in addrspace(" +
                             Twine(GVar->getAddressSpace()) + ")");
        }
      }
    }
  } else {
    uint64_t ElementSize = 0;

    // Although PTX has direct support for struct type and array type and
    // LLVM IR is very similar to PTX, the LLVM CodeGen does not support for
    // targets that support these high level field accesses. Structs, arrays
    // and vectors are lowered into arrays of bytes.
    switch (ETy->getTypeID()) {
    case Type::IntegerTyID: // Integers larger than 64 bits
    case Type::StructTyID:
    case Type::ArrayTyID:
    case Type::FixedVectorTyID:
      ElementSize = DL.getTypeStoreSize(ETy);
      // Ptx allows variable initilization only for constant and
      // global state spaces.
      if (((GVar->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
           (GVar->getAddressSpace() == ADDRESS_SPACE_CONST)) &&
          GVar->hasInitializer()) {
        const Constant *Initializer = GVar->getInitializer();
        if (!isa<UndefValue>(Initializer) && !Initializer->isNullValue()) {
          // Flatten the aggregate into a byte buffer plus symbol positions.
          AggBuffer aggBuffer(ElementSize, *this);
          bufferAggregateConstant(Initializer, &aggBuffer);
          if (aggBuffer.numSymbols()) {
            // With embedded symbols: prefer word-sized emission when all
            // symbols are pointer-aligned, otherwise fall back to bytes with
            // the mask() operator (requires PTX >= 7.1).
            unsigned int ptrSize = MAI->getCodePointerSize();
            if (ElementSize % ptrSize ||
                !aggBuffer.allSymbolsAligned(ptrSize)) {
              // Print in bytes and use the mask() operator for pointers.
              if (!STI.hasMaskOperator())
                report_fatal_error(
                    "initialized packed aggregate with pointers '" +
                    GVar->getName() +
                    "' requires at least PTX ISA version 7.1");
              O << " .u8 ";
              getSymbol(GVar)->print(O, MAI);
              O << "[" << ElementSize << "] = {";
              aggBuffer.printBytes(O);
              O << "}";
            } else {
              O << " .u" << ptrSize * 8 << " ";
              getSymbol(GVar)->print(O, MAI);
              O << "[" << ElementSize / ptrSize << "] = {";
              aggBuffer.printWords(O);
              O << "}";
            }
          } else {
            O << " .b8 ";
            getSymbol(GVar)->print(O, MAI);
            O << "[" << ElementSize << "] = {";
            aggBuffer.printBytes(O);
            O << "}";
          }
        } else {
          // Zero/undef initializer: emit only the byte-array declaration.
          O << " .b8 ";
          getSymbol(GVar)->print(O, MAI);
          if (ElementSize) {
            O << "[";
            O << ElementSize;
            O << "]";
          }
        }
      } else {
        // Non-initializable state space: declaration only.
        O << " .b8 ";
        getSymbol(GVar)->print(O, MAI);
        if (ElementSize) {
          O << "[";
          O << ElementSize;
          O << "]";
        }
      }
      break;
    default:
      llvm_unreachable("type not supported yet");
    }
  }
  O << ";\n";
}
1272 
1273 void NVPTXAsmPrinter::AggBuffer::printSymbol(unsigned nSym, raw_ostream &os) {
1274   const Value *v = Symbols[nSym];
1275   const Value *v0 = SymbolsBeforeStripping[nSym];
1276   if (const GlobalValue *GVar = dyn_cast<GlobalValue>(v)) {
1277     MCSymbol *Name = AP.getSymbol(GVar);
1278     PointerType *PTy = dyn_cast<PointerType>(v0->getType());
1279     // Is v0 a generic pointer?
1280     bool isGenericPointer = PTy && PTy->getAddressSpace() == 0;
1281     if (EmitGeneric && isGenericPointer && !isa<Function>(v)) {
1282       os << "generic(";
1283       Name->print(os, AP.MAI);
1284       os << ")";
1285     } else {
1286       Name->print(os, AP.MAI);
1287     }
1288   } else if (const ConstantExpr *CExpr = dyn_cast<ConstantExpr>(v0)) {
1289     const MCExpr *Expr = AP.lowerConstantForGV(cast<Constant>(CExpr), false);
1290     AP.printMCExpr(*Expr, os);
1291   } else
1292     llvm_unreachable("symbol type unknown");
1293 }
1294 
// Print the aggregate initializer as a comma-separated list of bytes,
// replacing each recorded symbol position with ptrSize per-byte mask()
// references to the symbol.
void NVPTXAsmPrinter::AggBuffer::printBytes(raw_ostream &os) {
  unsigned int ptrSize = AP.MAI->getCodePointerSize();
  // Do not emit trailing zero initializers. They will be zero-initialized by
  // ptxas. This saves on both space requirements for the generated PTX and on
  // memory use by ptxas. (See:
  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#global-state-space)
  unsigned int InitializerCount = size;
  // TODO: symbols make this harder, but it would still be good to trim trailing
  // 0s for aggs with symbols as well.
  if (numSymbols() == 0)
    while (InitializerCount >= 1 && !buffer[InitializerCount - 1])
      InitializerCount--;

  // Push a sentinel position one past the emitted range so the lookahead
  // below never walks off the end of symbolPosInBuffer.
  symbolPosInBuffer.push_back(InitializerCount);
  unsigned int nSym = 0;
  unsigned int nextSymbolPos = symbolPosInBuffer[nSym];
  for (unsigned int pos = 0; pos < InitializerCount;) {
    if (pos)
      os << ", ";
    // Plain data byte: print it and advance by one.
    if (pos != nextSymbolPos) {
      os << (unsigned int)buffer[pos];
      ++pos;
      continue;
    }
    // Generate a per-byte mask() operator for the symbol, which looks like:
    //   .global .u8 addr[] = {0xFF(foo), 0xFF00(foo), 0xFF0000(foo), ...};
    // See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#initializers
    std::string symText;
    llvm::raw_string_ostream oss(symText);
    printSymbol(nSym, oss);
    for (unsigned i = 0; i < ptrSize; ++i) {
      if (i)
        os << ", ";
      llvm::write_hex(os, 0xFFULL << i * 8, HexPrintStyle::PrefixUpper);
      os << "(" << symText << ")";
    }
    // A symbol occupies ptrSize bytes of the buffer.
    pos += ptrSize;
    nextSymbolPos = symbolPosInBuffer[++nSym];
    assert(nextSymbolPos >= pos);
  }
}
1336 
// Print the aggregate initializer as pointer-sized little-endian words,
// substituting a symbol reference at each recorded (pointer-aligned) symbol
// position. Callers guarantee alignment (see printModuleLevelGV).
void NVPTXAsmPrinter::AggBuffer::printWords(raw_ostream &os) {
  unsigned int ptrSize = AP.MAI->getCodePointerSize();
  // Sentinel position one past the buffer so the lookahead below never walks
  // off the end of symbolPosInBuffer.
  symbolPosInBuffer.push_back(size);
  unsigned int nSym = 0;
  unsigned int nextSymbolPos = symbolPosInBuffer[nSym];
  assert(nextSymbolPos % ptrSize == 0);
  for (unsigned int pos = 0; pos < size; pos += ptrSize) {
    if (pos)
      os << ", ";
    if (pos == nextSymbolPos) {
      // A symbol fills this word; advance to the next recorded position.
      printSymbol(nSym, os);
      nextSymbolPos = symbolPosInBuffer[++nSym];
      assert(nextSymbolPos % ptrSize == 0);
      assert(nextSymbolPos >= pos + ptrSize);
    } else if (ptrSize == 4)
      os << support::endian::read32le(&buffer[pos]);
    else
      os << support::endian::read64le(&buffer[pos]);
  }
}
1357 
1358 void NVPTXAsmPrinter::emitDemotedVars(const Function *f, raw_ostream &O) {
1359   auto It = localDecls.find(f);
1360   if (It == localDecls.end())
1361     return;
1362 
1363   std::vector<const GlobalVariable *> &gvars = It->second;
1364 
1365   const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
1366   const NVPTXSubtarget &STI =
1367       *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
1368 
1369   for (const GlobalVariable *GV : gvars) {
1370     O << "\t// demoted variable\n\t";
1371     printModuleLevelGV(GV, O, /*processDemoted=*/true, STI);
1372   }
1373 }
1374 
1375 void NVPTXAsmPrinter::emitPTXAddressSpace(unsigned int AddressSpace,
1376                                           raw_ostream &O) const {
1377   switch (AddressSpace) {
1378   case ADDRESS_SPACE_LOCAL:
1379     O << "local";
1380     break;
1381   case ADDRESS_SPACE_GLOBAL:
1382     O << "global";
1383     break;
1384   case ADDRESS_SPACE_CONST:
1385     O << "const";
1386     break;
1387   case ADDRESS_SPACE_SHARED:
1388     O << "shared";
1389     break;
1390   default:
1391     report_fatal_error("Bad address space found while emitting PTX: " +
1392                        llvm::Twine(AddressSpace));
1393     break;
1394   }
1395 }
1396 
1397 std::string
1398 NVPTXAsmPrinter::getPTXFundamentalTypeStr(Type *Ty, bool useB4PTR) const {
1399   switch (Ty->getTypeID()) {
1400   case Type::IntegerTyID: {
1401     unsigned NumBits = cast<IntegerType>(Ty)->getBitWidth();
1402     if (NumBits == 1)
1403       return "pred";
1404     else if (NumBits <= 64) {
1405       std::string name = "u";
1406       return name + utostr(NumBits);
1407     } else {
1408       llvm_unreachable("Integer too large");
1409       break;
1410     }
1411     break;
1412   }
1413   case Type::BFloatTyID:
1414   case Type::HalfTyID:
1415     // fp16 and bf16 are stored as .b16 for compatibility with pre-sm_53
1416     // PTX assembly.
1417     return "b16";
1418   case Type::FloatTyID:
1419     return "f32";
1420   case Type::DoubleTyID:
1421     return "f64";
1422   case Type::PointerTyID: {
1423     unsigned PtrSize = TM.getPointerSizeInBits(Ty->getPointerAddressSpace());
1424     assert((PtrSize == 64 || PtrSize == 32) && "Unexpected pointer size");
1425 
1426     if (PtrSize == 64)
1427       if (useB4PTR)
1428         return "b64";
1429       else
1430         return "u64";
1431     else if (useB4PTR)
1432       return "b32";
1433     else
1434       return "u32";
1435   }
1436   default:
1437     break;
1438   }
1439   llvm_unreachable("unexpected type");
1440 }
1441 
// Emit the declaration portion of a global variable (state space, optional
// .attribute(.managed), alignment, type, name, and array bounds) WITHOUT a
// trailing ';' or initializer. Used for extern declarations from
// printModuleLevelGV.
void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
                                            raw_ostream &O,
                                            const NVPTXSubtarget &STI) {
  const DataLayout &DL = getDataLayout();

  // GlobalVariables are always constant pointers themselves.
  Type *ETy = GVar->getValueType();

  O << ".";
  emitPTXAddressSpace(GVar->getType()->getAddressSpace(), O);
  if (isManaged(*GVar)) {
    if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) {
      report_fatal_error(
          ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
    }
    O << " .attribute(.managed)";
  }
  // Explicit alignment if present, otherwise the data layout's preferred
  // alignment for the value type.
  if (MaybeAlign A = GVar->getAlign())
    O << " .align " << A->value();
  else
    O << " .align " << (int)DL.getPrefTypeAlign(ETy).value();

  // Special case for i128
  if (ETy->isIntegerTy(128)) {
    O << " .b8 ";
    getSymbol(GVar)->print(O, MAI);
    O << "[16]";
    return;
  }

  // Scalar types map onto a PTX fundamental type.
  if (ETy->isFloatingPointTy() || ETy->isIntOrPtrTy()) {
    O << " .";
    O << getPTXFundamentalTypeStr(ETy);
    O << " ";
    getSymbol(GVar)->print(O, MAI);
    return;
  }

  int64_t ElementSize = 0;

  // Although PTX has direct support for struct type and array type and LLVM IR
  // is very similar to PTX, the LLVM CodeGen does not support for targets that
  // support these high level field accesses. Structs and arrays are lowered
  // into arrays of bytes.
  switch (ETy->getTypeID()) {
  case Type::StructTyID:
  case Type::ArrayTyID:
  case Type::FixedVectorTyID:
    ElementSize = DL.getTypeStoreSize(ETy);
    O << " .b8 ";
    getSymbol(GVar)->print(O, MAI);
    O << "[";
    if (ElementSize) {
      O << ElementSize;
    }
    O << "]";
    break;
  default:
    llvm_unreachable("type not supported yet");
  }
}
1503 
// Print the PTX parameter list for function \p F into \p O, one .param/.reg
// declaration per line, wrapped in "(" ... ")". Handles kernel vs. device
// functions, image/sampler kernel parameters, byval aggregates, and varargs.
void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
  const DataLayout &DL = getDataLayout();
  const AttributeList &PAL = F->getAttributes();
  const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
  const auto *TLI = cast<NVPTXTargetLowering>(STI.getTargetLowering());
  // MF may be null when printing a declaration outside function emission.
  const NVPTXMachineFunctionInfo *MFI =
      MF ? MF->getInfo<NVPTXMachineFunctionInfo>() : nullptr;

  Function::const_arg_iterator I, E;
  unsigned paramIndex = 0;
  bool first = true;
  bool isKernelFunc = isKernelFunction(*F);
  // Pre-sm_20 targets use .reg instead of .param for device functions.
  bool isABI = (STI.getSmVersion() >= 20);

  if (F->arg_empty() && !F->isVarArg()) {
    O << "()";
    return;
  }

  O << "(\n";

  for (I = F->arg_begin(), E = F->arg_end(); I != E; ++I, paramIndex++) {
    Type *Ty = I->getType();

    if (!first)
      O << ",\n";

    first = false;

    // Handle image/sampler parameters
    if (isKernelFunc) {
      if (isSampler(*I) || isImage(*I)) {
        // If the image-handle symbol was already registered for this param,
        // emit the opaque ref directly; otherwise emit a pointer to it.
        std::string ParamSym;
        raw_string_ostream ParamStr(ParamSym);
        ParamStr << F->getName() << "_param_" << paramIndex;
        ParamStr.flush();
        bool EmitImagePtr = !MFI || !MFI->checkImageHandleSymbol(ParamSym);
        if (isImage(*I)) {
          if (isImageWriteOnly(*I) || isImageReadWrite(*I)) {
            if (EmitImagePtr)
              O << "\t.param .u64 .ptr .surfref ";
            else
              O << "\t.param .surfref ";
            O << TLI->getParamName(F, paramIndex);
          }
          else { // Default image is read_only
            if (EmitImagePtr)
              O << "\t.param .u64 .ptr .texref ";
            else
              O << "\t.param .texref ";
            O << TLI->getParamName(F, paramIndex);
          }
        } else {
          if (EmitImagePtr)
            O << "\t.param .u64 .ptr .samplerref ";
          else
            O << "\t.param .samplerref ";
          O << TLI->getParamName(F, paramIndex);
        }
        continue;
      }
    }

    // Alignment for an aggregate param: explicit stack alignment attribute
    // wins; otherwise the max of the target-optimized and declared alignment.
    auto getOptimalAlignForParam = [TLI, &DL, &PAL, F,
                                    paramIndex](Type *Ty) -> Align {
      if (MaybeAlign StackAlign =
              getAlign(*F, paramIndex + AttributeList::FirstArgIndex))
        return StackAlign.value();

      Align TypeAlign = TLI->getFunctionParamOptimizedAlign(F, Ty, DL);
      MaybeAlign ParamAlign = PAL.getParamAlignment(paramIndex);
      return std::max(TypeAlign, ParamAlign.valueOrOne());
    };

    if (!PAL.hasParamAttr(paramIndex, Attribute::ByVal)) {
      if (ShouldPassAsArray(Ty)) {
        // Just print .param .align <a> .b8 .param[size];
        // <a>  = optimal alignment for the element type; always multiple of
        //        PAL.getParamAlignment
        // size = typeallocsize of element type
        Align OptimalAlign = getOptimalAlignForParam(Ty);

        O << "\t.param .align " << OptimalAlign.value() << " .b8 ";
        O << TLI->getParamName(F, paramIndex);
        O << "[" << DL.getTypeAllocSize(Ty) << "]";

        continue;
      }
      // Just a scalar
      auto *PTy = dyn_cast<PointerType>(Ty);
      unsigned PTySizeInBits = 0;
      if (PTy) {
        PTySizeInBits =
            TLI->getPointerTy(DL, PTy->getAddressSpace()).getSizeInBits();
        assert(PTySizeInBits && "Invalid pointer size");
      }

      if (isKernelFunc) {
        if (PTy) {
          // Kernel pointer param: annotate with its state space when known.
          O << "\t.param .u" << PTySizeInBits << " .ptr";

          switch (PTy->getAddressSpace()) {
          default:
            break;
          case ADDRESS_SPACE_GLOBAL:
            O << " .global";
            break;
          case ADDRESS_SPACE_SHARED:
            O << " .shared";
            break;
          case ADDRESS_SPACE_CONST:
            O << " .const";
            break;
          case ADDRESS_SPACE_LOCAL:
            O << " .local";
            break;
          }

          O << " .align " << I->getParamAlign().valueOrOne().value();
          O << " " << TLI->getParamName(F, paramIndex);
          continue;
        }

        // non-pointer scalar to kernel func
        O << "\t.param .";
        // Special case: predicate operands become .u8 types
        if (Ty->isIntegerTy(1))
          O << "u8";
        else
          O << getPTXFundamentalTypeStr(Ty);
        O << " ";
        O << TLI->getParamName(F, paramIndex);
        continue;
      }
      // Non-kernel function, just print .param .b<size> for ABI
      // and .reg .b<size> for non-ABI
      unsigned sz = 0;
      if (isa<IntegerType>(Ty)) {
        sz = cast<IntegerType>(Ty)->getBitWidth();
        sz = promoteScalarArgumentSize(sz);
      } else if (PTy) {
        assert(PTySizeInBits && "Invalid pointer size");
        sz = PTySizeInBits;
      } else
        sz = Ty->getPrimitiveSizeInBits();
      if (isABI)
        O << "\t.param .b" << sz << " ";
      else
        O << "\t.reg .b" << sz << " ";
      O << TLI->getParamName(F, paramIndex);
      continue;
    }

    // param has byVal attribute.
    Type *ETy = PAL.getParamByValType(paramIndex);
    assert(ETy && "Param should have byval type");

    if (isABI || isKernelFunc) {
      // Just print .param .align <a> .b8 .param[size];
      // <a>  = optimal alignment for the element type; always multiple of
      //        PAL.getParamAlignment
      // size = typeallocsize of element type
      Align OptimalAlign =
          isKernelFunc
              ? getOptimalAlignForParam(ETy)
              : TLI->getFunctionByValParamAlign(
                    F, ETy, PAL.getParamAlignment(paramIndex).valueOrOne(), DL);

      unsigned sz = DL.getTypeAllocSize(ETy);
      O << "\t.param .align " << OptimalAlign.value() << " .b8 ";
      O << TLI->getParamName(F, paramIndex);
      O << "[" << sz << "]";
      continue;
    } else {
      // Split the ETy into constituent parts and
      // print .param .b<size> <name> for each part.
      // Further, if a part is vector, print the above for
      // each vector element.
      SmallVector<EVT, 16> vtparts;
      ComputeValueVTs(*TLI, DL, ETy, vtparts);
      for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
        unsigned elems = 1;
        EVT elemtype = vtparts[i];
        if (vtparts[i].isVector()) {
          elems = vtparts[i].getVectorNumElements();
          elemtype = vtparts[i].getVectorElementType();
        }

        // paramIndex is advanced once per emitted part; the extra increment
        // is undone below since the outer loop also increments it.
        for (unsigned j = 0, je = elems; j != je; ++j) {
          unsigned sz = elemtype.getSizeInBits();
          if (elemtype.isInteger())
            sz = promoteScalarArgumentSize(sz);
          O << "\t.reg .b" << sz << " ";
          O << TLI->getParamName(F, paramIndex);
          if (j < je - 1)
            O << ",\n";
          ++paramIndex;
        }
        if (i < e - 1)
          O << ",\n";
      }
      --paramIndex;
      continue;
    }
  }

  // Varargs are passed as a single byte array aligned to the maximum
  // alignment the target may require.
  if (F->isVarArg()) {
    if (!first)
      O << ",\n";
    O << "\t.param .align " << STI.getMaxRequiredAlignment();
    O << " .b8 ";
    O << TLI->getParamName(F, /* vararg */ -1) << "[]";
  }

  O << "\n)";
}
1720 
1721 void NVPTXAsmPrinter::setAndEmitFunctionVirtualRegisters(
1722     const MachineFunction &MF) {
1723   SmallString<128> Str;
1724   raw_svector_ostream O(Str);
1725 
1726   // Map the global virtual register number to a register class specific
1727   // virtual register number starting from 1 with that class.
1728   const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo();
1729   //unsigned numRegClasses = TRI->getNumRegClasses();
1730 
1731   // Emit the Fake Stack Object
1732   const MachineFrameInfo &MFI = MF.getFrameInfo();
1733   int64_t NumBytes = MFI.getStackSize();
1734   if (NumBytes) {
1735     O << "\t.local .align " << MFI.getMaxAlign().value() << " .b8 \t"
1736       << DEPOTNAME << getFunctionNumber() << "[" << NumBytes << "];\n";
1737     if (static_cast<const NVPTXTargetMachine &>(MF.getTarget()).is64Bit()) {
1738       O << "\t.reg .b64 \t%SP;\n";
1739       O << "\t.reg .b64 \t%SPL;\n";
1740     } else {
1741       O << "\t.reg .b32 \t%SP;\n";
1742       O << "\t.reg .b32 \t%SPL;\n";
1743     }
1744   }
1745 
1746   // Go through all virtual registers to establish the mapping between the
1747   // global virtual
1748   // register number and the per class virtual register number.
1749   // We use the per class virtual register number in the ptx output.
1750   unsigned int numVRs = MRI->getNumVirtRegs();
1751   for (unsigned i = 0; i < numVRs; i++) {
1752     Register vr = Register::index2VirtReg(i);
1753     const TargetRegisterClass *RC = MRI->getRegClass(vr);
1754     DenseMap<unsigned, unsigned> &regmap = VRegMapping[RC];
1755     int n = regmap.size();
1756     regmap.insert(std::make_pair(vr, n + 1));
1757   }
1758 
1759   // Emit register declarations
1760   // @TODO: Extract out the real register usage
1761   // O << "\t.reg .pred %p<" << NVPTXNumRegisters << ">;\n";
1762   // O << "\t.reg .s16 %rc<" << NVPTXNumRegisters << ">;\n";
1763   // O << "\t.reg .s16 %rs<" << NVPTXNumRegisters << ">;\n";
1764   // O << "\t.reg .s32 %r<" << NVPTXNumRegisters << ">;\n";
1765   // O << "\t.reg .s64 %rd<" << NVPTXNumRegisters << ">;\n";
1766   // O << "\t.reg .f32 %f<" << NVPTXNumRegisters << ">;\n";
1767   // O << "\t.reg .f64 %fd<" << NVPTXNumRegisters << ">;\n";
1768 
1769   // Emit declaration of the virtual registers or 'physical' registers for
1770   // each register class
1771   for (unsigned i=0; i< TRI->getNumRegClasses(); i++) {
1772     const TargetRegisterClass *RC = TRI->getRegClass(i);
1773     DenseMap<unsigned, unsigned> &regmap = VRegMapping[RC];
1774     std::string rcname = getNVPTXRegClassName(RC);
1775     std::string rcStr = getNVPTXRegClassStr(RC);
1776     int n = regmap.size();
1777 
1778     // Only declare those registers that may be used.
1779     if (n) {
1780        O << "\t.reg " << rcname << " \t" << rcStr << "<" << (n+1)
1781          << ">;\n";
1782     }
1783   }
1784 
1785   OutStreamer->emitRawText(O.str());
1786 }
1787 
1788 /// Translate virtual register numbers in DebugInfo locations to their printed
1789 /// encodings, as used by CUDA-GDB.
1790 void NVPTXAsmPrinter::encodeDebugInfoRegisterNumbers(
1791     const MachineFunction &MF) {
1792   const NVPTXSubtarget &STI = MF.getSubtarget<NVPTXSubtarget>();
1793   const NVPTXRegisterInfo *registerInfo = STI.getRegisterInfo();
1794 
1795   // Clear the old mapping, and add the new one.  This mapping is used after the
1796   // printing of the current function is complete, but before the next function
1797   // is printed.
1798   registerInfo->clearDebugRegisterMap();
1799 
1800   for (auto &classMap : VRegMapping) {
1801     for (auto &registerMapping : classMap.getSecond()) {
1802       auto reg = registerMapping.getFirst();
1803       registerInfo->addToDebugRegisterMap(reg, getVirtualRegisterName(reg));
1804     }
1805   }
1806 }
1807 
1808 void NVPTXAsmPrinter::printFPConstant(const ConstantFP *Fp, raw_ostream &O) {
1809   APFloat APF = APFloat(Fp->getValueAPF()); // make a copy
1810   bool ignored;
1811   unsigned int numHex;
1812   const char *lead;
1813 
1814   if (Fp->getType()->getTypeID() == Type::FloatTyID) {
1815     numHex = 8;
1816     lead = "0f";
1817     APF.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, &ignored);
1818   } else if (Fp->getType()->getTypeID() == Type::DoubleTyID) {
1819     numHex = 16;
1820     lead = "0d";
1821     APF.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, &ignored);
1822   } else
1823     llvm_unreachable("unsupported fp type");
1824 
1825   APInt API = APF.bitcastToAPInt();
1826   O << lead << format_hex_no_prefix(API.getZExtValue(), numHex, /*Upper=*/true);
1827 }
1828 
1829 void NVPTXAsmPrinter::printScalarConstant(const Constant *CPV, raw_ostream &O) {
1830   if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
1831     O << CI->getValue();
1832     return;
1833   }
1834   if (const ConstantFP *CFP = dyn_cast<ConstantFP>(CPV)) {
1835     printFPConstant(CFP, O);
1836     return;
1837   }
1838   if (isa<ConstantPointerNull>(CPV)) {
1839     O << "0";
1840     return;
1841   }
1842   if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) {
1843     bool IsNonGenericPointer = false;
1844     if (GVar->getType()->getAddressSpace() != 0) {
1845       IsNonGenericPointer = true;
1846     }
1847     if (EmitGeneric && !isa<Function>(CPV) && !IsNonGenericPointer) {
1848       O << "generic(";
1849       getSymbol(GVar)->print(O, MAI);
1850       O << ")";
1851     } else {
1852       getSymbol(GVar)->print(O, MAI);
1853     }
1854     return;
1855   }
1856   if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
1857     const MCExpr *E = lowerConstantForGV(cast<Constant>(Cexpr), false);
1858     printMCExpr(*E, O);
1859     return;
1860   }
1861   llvm_unreachable("Not scalar type found in printScalarConstant()");
1862 }
1863 
1864 void NVPTXAsmPrinter::bufferLEByte(const Constant *CPV, int Bytes,
1865                                    AggBuffer *AggBuffer) {
1866   const DataLayout &DL = getDataLayout();
1867   int AllocSize = DL.getTypeAllocSize(CPV->getType());
1868   if (isa<UndefValue>(CPV) || CPV->isNullValue()) {
1869     // Non-zero Bytes indicates that we need to zero-fill everything. Otherwise,
1870     // only the space allocated by CPV.
1871     AggBuffer->addZeros(Bytes ? Bytes : AllocSize);
1872     return;
1873   }
1874 
1875   // Helper for filling AggBuffer with APInts.
1876   auto AddIntToBuffer = [AggBuffer, Bytes](const APInt &Val) {
1877     size_t NumBytes = (Val.getBitWidth() + 7) / 8;
1878     SmallVector<unsigned char, 16> Buf(NumBytes);
1879     // `extractBitsAsZExtValue` does not allow the extraction of bits beyond the
1880     // input's bit width, and i1 arrays may not have a length that is a multuple
1881     // of 8. We handle the last byte separately, so we never request out of
1882     // bounds bits.
1883     for (unsigned I = 0; I < NumBytes - 1; ++I) {
1884       Buf[I] = Val.extractBitsAsZExtValue(8, I * 8);
1885     }
1886     size_t LastBytePosition = (NumBytes - 1) * 8;
1887     size_t LastByteBits = Val.getBitWidth() - LastBytePosition;
1888     Buf[NumBytes - 1] =
1889         Val.extractBitsAsZExtValue(LastByteBits, LastBytePosition);
1890     AggBuffer->addBytes(Buf.data(), NumBytes, Bytes);
1891   };
1892 
1893   switch (CPV->getType()->getTypeID()) {
1894   case Type::IntegerTyID:
1895     if (const auto CI = dyn_cast<ConstantInt>(CPV)) {
1896       AddIntToBuffer(CI->getValue());
1897       break;
1898     }
1899     if (const auto *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
1900       if (const auto *CI =
1901               dyn_cast<ConstantInt>(ConstantFoldConstant(Cexpr, DL))) {
1902         AddIntToBuffer(CI->getValue());
1903         break;
1904       }
1905       if (Cexpr->getOpcode() == Instruction::PtrToInt) {
1906         Value *V = Cexpr->getOperand(0)->stripPointerCasts();
1907         AggBuffer->addSymbol(V, Cexpr->getOperand(0));
1908         AggBuffer->addZeros(AllocSize);
1909         break;
1910       }
1911     }
1912     llvm_unreachable("unsupported integer const type");
1913     break;
1914 
1915   case Type::HalfTyID:
1916   case Type::BFloatTyID:
1917   case Type::FloatTyID:
1918   case Type::DoubleTyID:
1919     AddIntToBuffer(cast<ConstantFP>(CPV)->getValueAPF().bitcastToAPInt());
1920     break;
1921 
1922   case Type::PointerTyID: {
1923     if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) {
1924       AggBuffer->addSymbol(GVar, GVar);
1925     } else if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
1926       const Value *v = Cexpr->stripPointerCasts();
1927       AggBuffer->addSymbol(v, Cexpr);
1928     }
1929     AggBuffer->addZeros(AllocSize);
1930     break;
1931   }
1932 
1933   case Type::ArrayTyID:
1934   case Type::FixedVectorTyID:
1935   case Type::StructTyID: {
1936     if (isa<ConstantAggregate>(CPV) || isa<ConstantDataSequential>(CPV)) {
1937       bufferAggregateConstant(CPV, AggBuffer);
1938       if (Bytes > AllocSize)
1939         AggBuffer->addZeros(Bytes - AllocSize);
1940     } else if (isa<ConstantAggregateZero>(CPV))
1941       AggBuffer->addZeros(Bytes);
1942     else
1943       llvm_unreachable("Unexpected Constant type");
1944     break;
1945   }
1946 
1947   default:
1948     llvm_unreachable("unsupported type");
1949   }
1950 }
1951 
1952 void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
1953                                               AggBuffer *aggBuffer) {
1954   const DataLayout &DL = getDataLayout();
1955   int Bytes;
1956 
1957   // Integers of arbitrary width
1958   if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
1959     APInt Val = CI->getValue();
1960     for (unsigned I = 0, E = DL.getTypeAllocSize(CPV->getType()); I < E; ++I) {
1961       uint8_t Byte = Val.getLoBits(8).getZExtValue();
1962       aggBuffer->addBytes(&Byte, 1, 1);
1963       Val.lshrInPlace(8);
1964     }
1965     return;
1966   }
1967 
1968   // Old constants
1969   if (isa<ConstantArray>(CPV) || isa<ConstantVector>(CPV)) {
1970     if (CPV->getNumOperands())
1971       for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i)
1972         bufferLEByte(cast<Constant>(CPV->getOperand(i)), 0, aggBuffer);
1973     return;
1974   }
1975 
1976   if (const ConstantDataSequential *CDS =
1977           dyn_cast<ConstantDataSequential>(CPV)) {
1978     if (CDS->getNumElements())
1979       for (unsigned i = 0; i < CDS->getNumElements(); ++i)
1980         bufferLEByte(cast<Constant>(CDS->getElementAsConstant(i)), 0,
1981                      aggBuffer);
1982     return;
1983   }
1984 
1985   if (isa<ConstantStruct>(CPV)) {
1986     if (CPV->getNumOperands()) {
1987       StructType *ST = cast<StructType>(CPV->getType());
1988       for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i) {
1989         if (i == (e - 1))
1990           Bytes = DL.getStructLayout(ST)->getElementOffset(0) +
1991                   DL.getTypeAllocSize(ST) -
1992                   DL.getStructLayout(ST)->getElementOffset(i);
1993         else
1994           Bytes = DL.getStructLayout(ST)->getElementOffset(i + 1) -
1995                   DL.getStructLayout(ST)->getElementOffset(i);
1996         bufferLEByte(cast<Constant>(CPV->getOperand(i)), Bytes, aggBuffer);
1997       }
1998     }
1999     return;
2000   }
2001   llvm_unreachable("unsupported constant type in printAggregateConstant()");
2002 }
2003 
/// lowerConstantForGV - Return an MCExpr for the given Constant.  This is mostly
/// a copy from AsmPrinter::lowerConstant, except customized to only handle
/// expressions that are representable in PTX and create
/// NVPTXGenericMCSymbolRefExpr nodes for addrspacecast instructions.
/// ProcessingGeneric is true while lowering the operand of an addrspacecast
/// to the generic address space, so that symbol references made below it can
/// be wrapped for printing as generic(<sym>).
const MCExpr *
NVPTXAsmPrinter::lowerConstantForGV(const Constant *CV, bool ProcessingGeneric) {
  MCContext &Ctx = OutContext;

  // Null and undef both lower to a literal 0.
  if (CV->isNullValue() || isa<UndefValue>(CV))
    return MCConstantExpr::create(0, Ctx);

  if (const ConstantInt *CI = dyn_cast<ConstantInt>(CV))
    return MCConstantExpr::create(CI->getZExtValue(), Ctx);

  if (const GlobalValue *GV = dyn_cast<GlobalValue>(CV)) {
    const MCSymbolRefExpr *Expr =
      MCSymbolRefExpr::create(getSymbol(GV), Ctx);
    // Under an addrspacecast-to-generic, wrap the symbol reference so
    // printMCExpr/the target streamer can emit generic(<sym>).
    if (ProcessingGeneric) {
      return NVPTXGenericMCSymbolRefExpr::create(Expr, Ctx);
    } else {
      return Expr;
    }
  }

  const ConstantExpr *CE = dyn_cast<ConstantExpr>(CV);
  if (!CE) {
    llvm_unreachable("Unknown constant value to lower!");
  }

  switch (CE->getOpcode()) {
  default:
    break; // Error

  case Instruction::AddrSpaceCast: {
    // Strip the addrspacecast and pass along the operand.  Only casts *to*
    // the generic address space (0) are representable; other casts fall
    // through to the error path below.
    PointerType *DstTy = cast<PointerType>(CE->getType());
    if (DstTy->getAddressSpace() == 0)
      return lowerConstantForGV(cast<const Constant>(CE->getOperand(0)), true);

    break; // Error
  }

  case Instruction::GetElementPtr: {
    const DataLayout &DL = getDataLayout();

    // Generate a symbolic expression for the byte address
    APInt OffsetAI(DL.getPointerTypeSizeInBits(CE->getType()), 0);
    cast<GEPOperator>(CE)->accumulateConstantOffset(DL, OffsetAI);

    const MCExpr *Base = lowerConstantForGV(CE->getOperand(0),
                                            ProcessingGeneric);
    // A zero offset lowers to the base expression alone.
    if (!OffsetAI)
      return Base;

    int64_t Offset = OffsetAI.getSExtValue();
    return MCBinaryExpr::createAdd(Base, MCConstantExpr::create(Offset, Ctx),
                                   Ctx);
  }

  case Instruction::Trunc:
    // We emit the value and depend on the assembler to truncate the generated
    // expression properly.  This is important for differences between
    // blockaddress labels.  Since the two labels are in the same function, it
    // is reasonable to treat their delta as a 32-bit value.
    [[fallthrough]];
  case Instruction::BitCast:
    return lowerConstantForGV(CE->getOperand(0), ProcessingGeneric);

  case Instruction::IntToPtr: {
    const DataLayout &DL = getDataLayout();

    // Handle casts to pointers by changing them into casts to the appropriate
    // integer type.  This promotes constant folding and simplifies this code.
    Constant *Op = CE->getOperand(0);
    Op = ConstantFoldIntegerCast(Op, DL.getIntPtrType(CV->getType()),
                                 /*IsSigned*/ false, DL);
    if (Op)
      return lowerConstantForGV(Op, ProcessingGeneric);

    break; // Error
  }

  case Instruction::PtrToInt: {
    const DataLayout &DL = getDataLayout();

    // Support only foldable casts to/from pointers that can be eliminated by
    // changing the pointer to the appropriately sized integer type.
    Constant *Op = CE->getOperand(0);
    Type *Ty = CE->getType();

    const MCExpr *OpExpr = lowerConstantForGV(Op, ProcessingGeneric);

    // We can emit the pointer value into this slot if the slot is an
    // integer slot equal to the size of the pointer.
    if (DL.getTypeAllocSize(Ty) == DL.getTypeAllocSize(Op->getType()))
      return OpExpr;

    // Otherwise the pointer is smaller than the resultant integer, mask off
    // the high bits so we are sure to get a proper truncation if the input is
    // a constant expr.
    unsigned InBits = DL.getTypeAllocSizeInBits(Op->getType());
    const MCExpr *MaskExpr = MCConstantExpr::create(~0ULL >> (64-InBits), Ctx);
    return MCBinaryExpr::createAnd(OpExpr, MaskExpr, Ctx);
  }

  // Add is the only binary operator lowered here.  The inner switch mirrors
  // the structure of AsmPrinter::lowerConstant, which handles more opcodes.
  case Instruction::Add: {
    const MCExpr *LHS = lowerConstantForGV(CE->getOperand(0), ProcessingGeneric);
    const MCExpr *RHS = lowerConstantForGV(CE->getOperand(1), ProcessingGeneric);
    switch (CE->getOpcode()) {
    default: llvm_unreachable("Unknown binary operator constant cast expr");
    case Instruction::Add: return MCBinaryExpr::createAdd(LHS, RHS, Ctx);
    }
  }
  }

  // If the code isn't optimized, there may be outstanding folding
  // opportunities. Attempt to fold the expression using DataLayout as a
  // last resort before giving up.
  Constant *C = ConstantFoldConstant(CE, getDataLayout());
  if (C != CE)
    return lowerConstantForGV(C, ProcessingGeneric);

  // Otherwise report the problem to the user.
  std::string S;
  raw_string_ostream OS(S);
  OS << "Unsupported expression in static initializer: ";
  CE->printAsOperand(OS, /*PrintType=*/false,
                 !MF ? nullptr : MF->getFunction().getParent());
  report_fatal_error(Twine(OS.str()));
}
2136 
// Copy of MCExpr::print customized for NVPTX.  Differences from the generic
// printer: target expressions (NVPTXGenericMCSymbolRefExpr) print through
// their own printImpl, and only Add is supported among binary operators.
void NVPTXAsmPrinter::printMCExpr(const MCExpr &Expr, raw_ostream &OS) {
  switch (Expr.getKind()) {
  case MCExpr::Target:
    return cast<MCTargetExpr>(&Expr)->printImpl(OS, MAI);
  case MCExpr::Constant:
    OS << cast<MCConstantExpr>(Expr).getValue();
    return;

  case MCExpr::SymbolRef: {
    const MCSymbolRefExpr &SRE = cast<MCSymbolRefExpr>(Expr);
    const MCSymbol &Sym = SRE.getSymbol();
    Sym.print(OS, MAI);
    return;
  }

  case MCExpr::Unary: {
    const MCUnaryExpr &UE = cast<MCUnaryExpr>(Expr);
    switch (UE.getOpcode()) {
    case MCUnaryExpr::LNot:  OS << '!'; break;
    case MCUnaryExpr::Minus: OS << '-'; break;
    case MCUnaryExpr::Not:   OS << '~'; break;
    case MCUnaryExpr::Plus:  OS << '+'; break;
    }
    printMCExpr(*UE.getSubExpr(), OS);
    return;
  }

  case MCExpr::Binary: {
    const MCBinaryExpr &BE = cast<MCBinaryExpr>(Expr);

    // Only print parens around the LHS if it is non-trivial.
    if (isa<MCConstantExpr>(BE.getLHS()) || isa<MCSymbolRefExpr>(BE.getLHS()) ||
        isa<NVPTXGenericMCSymbolRefExpr>(BE.getLHS())) {
      printMCExpr(*BE.getLHS(), OS);
    } else {
      OS << '(';
      printMCExpr(*BE.getLHS(), OS);
      OS<< ')';
    }

    switch (BE.getOpcode()) {
    case MCBinaryExpr::Add:
      // Print "X-42" instead of "X+-42".
      if (const MCConstantExpr *RHSC = dyn_cast<MCConstantExpr>(BE.getRHS())) {
        if (RHSC->getValue() < 0) {
          OS << RHSC->getValue();
          return;
        }
      }

      OS <<  '+';
      break;
    default: llvm_unreachable("Unhandled binary operator");
    }

    // Only print parens around the RHS if it is non-trivial.
    if (isa<MCConstantExpr>(BE.getRHS()) || isa<MCSymbolRefExpr>(BE.getRHS())) {
      printMCExpr(*BE.getRHS(), OS);
    } else {
      OS << '(';
      printMCExpr(*BE.getRHS(), OS);
      OS << ')';
    }
    return;
  }
  }

  llvm_unreachable("Invalid expression kind!");
}
2207 
2208 /// PrintAsmOperand - Print out an operand for an inline asm expression.
2209 ///
2210 bool NVPTXAsmPrinter::PrintAsmOperand(const MachineInstr *MI, unsigned OpNo,
2211                                       const char *ExtraCode, raw_ostream &O) {
2212   if (ExtraCode && ExtraCode[0]) {
2213     if (ExtraCode[1] != 0)
2214       return true; // Unknown modifier.
2215 
2216     switch (ExtraCode[0]) {
2217     default:
2218       // See if this is a generic print operand
2219       return AsmPrinter::PrintAsmOperand(MI, OpNo, ExtraCode, O);
2220     case 'r':
2221       break;
2222     }
2223   }
2224 
2225   printOperand(MI, OpNo, O);
2226 
2227   return false;
2228 }
2229 
2230 bool NVPTXAsmPrinter::PrintAsmMemoryOperand(const MachineInstr *MI,
2231                                             unsigned OpNo,
2232                                             const char *ExtraCode,
2233                                             raw_ostream &O) {
2234   if (ExtraCode && ExtraCode[0])
2235     return true; // Unknown modifier
2236 
2237   O << '[';
2238   printMemOperand(MI, OpNo, O);
2239   O << ']';
2240 
2241   return false;
2242 }
2243 
2244 void NVPTXAsmPrinter::printOperand(const MachineInstr *MI, unsigned OpNum,
2245                                    raw_ostream &O) {
2246   const MachineOperand &MO = MI->getOperand(OpNum);
2247   switch (MO.getType()) {
2248   case MachineOperand::MO_Register:
2249     if (MO.getReg().isPhysical()) {
2250       if (MO.getReg() == NVPTX::VRDepot)
2251         O << DEPOTNAME << getFunctionNumber();
2252       else
2253         O << NVPTXInstPrinter::getRegisterName(MO.getReg());
2254     } else {
2255       emitVirtualRegister(MO.getReg(), O);
2256     }
2257     break;
2258 
2259   case MachineOperand::MO_Immediate:
2260     O << MO.getImm();
2261     break;
2262 
2263   case MachineOperand::MO_FPImmediate:
2264     printFPConstant(MO.getFPImm(), O);
2265     break;
2266 
2267   case MachineOperand::MO_GlobalAddress:
2268     PrintSymbolOperand(MO, O);
2269     break;
2270 
2271   case MachineOperand::MO_MachineBasicBlock:
2272     MO.getMBB()->getSymbol()->print(O, MAI);
2273     break;
2274 
2275   default:
2276     llvm_unreachable("Operand type not supported.");
2277   }
2278 }
2279 
2280 void NVPTXAsmPrinter::printMemOperand(const MachineInstr *MI, unsigned OpNum,
2281                                       raw_ostream &O, const char *Modifier) {
2282   printOperand(MI, OpNum, O);
2283 
2284   if (Modifier && strcmp(Modifier, "add") == 0) {
2285     O << ", ";
2286     printOperand(MI, OpNum + 1, O);
2287   } else {
2288     if (MI->getOperand(OpNum + 1).isImm() &&
2289         MI->getOperand(OpNum + 1).getImm() == 0)
2290       return; // don't print ',0' or '+0'
2291     O << "+";
2292     printOperand(MI, OpNum + 1, O);
2293   }
2294 }
2295 
2296 // Force static initialization.
2297 extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeNVPTXAsmPrinter() {
2298   RegisterAsmPrinter<NVPTXAsmPrinter> X(getTheNVPTXTarget32());
2299   RegisterAsmPrinter<NVPTXAsmPrinter> Y(getTheNVPTXTarget64());
2300 }
2301